@stellarapp/tfjs-stellar 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,236 @@
1
+ import * as tf from "@tensorflow/tfjs";
2
+ import { type Kwargs } from "@tensorflow/tfjs-layers/dist/types";
3
+ import { type ActivationIdentifier } from "@tensorflow/tfjs-layers/dist/keras_format/activation_config";
4
+
5
+ import { type MultiHeadAttentionArgs } from "@/layers/multihead_attention";
6
+ import { CachedRoPEMultiHeadAttention } from "@/layers/cached_rope_multihead_attention";
7
+
8
+
9
/**
 * Constructor options for the transformer decoder layer.
 *
 * Re-declares `causal` on top of the multi-head attention options and adds
 * the feed-forward sub-layer settings.
 */
export interface TransformerDecoderArgs extends Omit<MultiHeadAttentionArgs, "causal"> {
  /** whether the self-attention over the inputs uses a causal mask */
  causal?: boolean;
  /** size of the intermediate feed-forward layer */
  dimsFeedForward?: number;
  /** activation of the intermediate feed-forward layer */
  activation?: "relu" | "gelu";
}
14
+
15
+
16
/**
 * This class implements the transformer decoder architecture from
 * the 2017 paper "Attention Is All You Need".
 *
 * This decoder-only transformer layer runs a self-attention sub-layer followed by
 * a position-wise feed-forward sub-layer. Each sub-layer's output goes through
 * dropout, is added to the sub-layer's input (residual connection) and is then
 * layer-normalized.
 *
 * The layer expects one tensor input of shape `[ batch, sequences, embedding dims ]`
 * and returns a tensor of the same shape. A second input tensor is tolerated
 * (as in `build`) but ignored.
 *
 * Causal masking is enabled by default for the self-attention sub-layer.
 *
 * @param numHeads number of attention heads to use
 * @param embedDim the embedding size of the input (input embeddings, typically the last dimension)
 * @param causal use causal masking on inputs (masks future inputs to prevent looking ahead), default `true`
 * @param dropout dropout rate within `[0, 1)`, passed to the attention layer and applied after each sub-layer, default `0.1`
 * @param activation the activation of the intermediate feed forward layer, default `relu`
 * @param dimsFeedForward the size of the intermediate feed forward layer, default `4 * embedDim`
 * @param useBias use bias for the dense sublayers and multiHead attention's dense sublayers, default `true`
 */
export class TransformerDecoder extends tf.layers.Layer {
  static className = "TransformerDecoder";

  protected readonly causalSelfAttention: tf.layers.Layer;
  protected readonly causalSelfAttentionDropout: tf.layers.Layer;
  protected readonly causalSelfAttentionNorm: tf.layers.Layer;

  protected readonly feedforward1: tf.layers.Layer;
  protected readonly feedforward2: tf.layers.Layer;
  protected readonly feedForwardDropout: tf.layers.Layer;
  // the "Foward" spelling is kept: this protected field may be used by subclasses
  protected readonly feedFowardNorm: tf.layers.Layer;

  protected readonly numHeads: number;
  protected readonly embedDim: number;
  protected readonly causal: boolean;
  protected readonly useBias: boolean;
  protected readonly dropout: number;
  protected readonly activation: ActivationIdentifier;
  protected readonly dimsFeedForward: number;

  constructor({ numHeads, embedDim, causal, useBias, dropout, activation, dimsFeedForward, ...args }: TransformerDecoderArgs) {
    // `causal` is destructured above so it is honoured here instead of
    // silently leaking into the base Layer arguments
    super(args);

    this.numHeads = numHeads;
    this.embedDim = embedDim;
    this.causal = causal ?? true;
    this.useBias = useBias ?? true;
    this.dropout = dropout ?? 0.1;
    this.activation = activation ?? "relu";

    // negated range check so negative rates and NaN are rejected as well
    if (!(this.dropout >= 0 && this.dropout < 1)) {
      throw Error(`${this.getClassName()}::constructor dropout must be within [0, 1)`);
    }

    // in the paper section 3.3, d_model=512 (embedDim) and the first dense layer
    // outputs d_ff=2048, i.e. d_ff = 4 * d_model
    this.dimsFeedForward = dimsFeedForward ?? embedDim * 4;

    // self attention sub-block
    this.causalSelfAttention = new CachedRoPEMultiHeadAttention({
      numHeads: this.numHeads, embedDim: this.embedDim,
      useBias: this.useBias, dropout: this.dropout,
      causal: this.causal
    });
    this.causalSelfAttentionDropout = tf.layers.dropout({ rate: this.dropout });
    this.causalSelfAttentionNorm = tf.layers.layerNormalization({ epsilon: 1e-6 });

    // feed forward sub-block: expand to dimsFeedForward, then project back to embedDim
    this.feedforward1 = tf.layers.dense({
      units: this.dimsFeedForward,
      activation: this.activation,
      useBias: this.useBias,
    });
    this.feedforward2 = tf.layers.dense({
      units: this.embedDim,
      activation: "linear",
      useBias: this.useBias
    });
    this.feedForwardDropout = tf.layers.dropout({ rate: this.dropout });
    this.feedFowardNorm = tf.layers.layerNormalization({ epsilon: 1e-6 });
  }


  /**
   * Forward propagation
   *
   * @param inputs the input tensor, or an array holding it (a second element is ignored)
   * @param kwargs forwarded to every sublayer's `apply`
   * @return the output tensor, same shape as the input
   */
  override call(inputs: tf.Tensor | tf.Tensor[], kwargs: Kwargs): tf.Tensor | tf.Tensor[] {
    // validate the input tensors
    if (Array.isArray(inputs) && inputs.length != 1 && inputs.length != 2) {
      throw Error(`${this.getClassName()}::call ${this.name} expects one input tensor, got ${inputs.length} inputs.`);
    }

    // a const local keeps the narrowed type inside the tidy closure
    const input = (Array.isArray(inputs) ? inputs[0] : inputs) as tf.Tensor;

    // perform forward propagation
    return tf.tidy(() => {
      const attention = this.causalSelfAttentionBlock(input, kwargs);
      return this.feedForwardBlock(attention, kwargs);
    });
  }


  /** Self-attention -> dropout -> residual add -> layer norm. */
  protected causalSelfAttentionBlock(x: tf.Tensor, kwargs: Kwargs): tf.Tensor {
    return tf.tidy(() => {
      const residual = x;

      let attention = this.causalSelfAttention.apply(x, kwargs) as tf.Tensor;
      attention = this.causalSelfAttentionDropout.apply(attention, kwargs) as tf.Tensor;
      attention = tf.add(attention, residual);
      attention = this.causalSelfAttentionNorm.apply(attention, kwargs) as tf.Tensor;

      return attention;
    });
  }


  /** Dense (activation) -> dense (linear) -> dropout -> residual add -> layer norm. */
  protected feedForwardBlock(x: tf.Tensor, kwargs: Kwargs): tf.Tensor {
    return tf.tidy(() => {
      const residual = x;

      let feedForward = this.feedforward1.apply(x, kwargs) as tf.Tensor;
      feedForward = this.feedforward2.apply(feedForward, kwargs) as tf.Tensor;
      feedForward = this.feedForwardDropout.apply(feedForward, kwargs) as tf.Tensor;
      feedForward = tf.add(feedForward, residual);
      feedForward = this.feedFowardNorm.apply(feedForward, kwargs) as tf.Tensor;

      return feedForward;
    });
  }


  /**
   * Initialize the sublayers' weights and track them to enable serialization
   *
   * @throws if the input shape is not (one or two shapes, the first being) rank 3
   */
  override build(inputShape: tf.Shape | tf.Shape[]): void {
    let input_shapes: tf.Shape[] = [];

    if (Array.isArray(inputShape) && Array.isArray(inputShape[0])) {
      // input is an array of shapes
      input_shapes = inputShape as tf.Shape[];
    } else if (inputShape.length != 0) {
      // input is a single shape
      input_shapes = [inputShape as tf.Shape];
    }

    if (input_shapes.length != 1 && input_shapes.length != 2) {
      throw Error(`${this.getClassName()}::build ${this.name} expects an input shape` +
        ` of [batch, seq, embed_dim], got ${JSON.stringify(inputShape)}`)
    }

    // only the first shape is used; a second one is tolerated, matching call()
    const [decoderInputShape] = input_shapes;

    if (decoderInputShape?.length != 3) {
      throw Error(`${this.getClassName()}::build ${this.name} expects an input shape` +
        ` of [batch, seq, embed_dim], got ${JSON.stringify(inputShape)}`)
    }

    // initialize causal self attention sub-block's weights
    this.causalSelfAttention.build(decoderInputShape);
    this.causalSelfAttentionNorm.build(this.causalSelfAttention.computeOutputShape(decoderInputShape));

    // initialize feedforward sub-block's weights
    const feedforward1OutputShape = this.feedforward1.computeOutputShape(decoderInputShape);
    const feedforward2OutputShape = this.feedforward2.computeOutputShape(feedforward1OutputShape);

    this.feedforward1.build(decoderInputShape);
    this.feedforward2.build(feedforward1OutputShape);
    this.feedFowardNorm.build(feedforward2OutputShape);

    // track sublayers' weights
    this.trainableWeights = [
      ...this.causalSelfAttention.trainableWeights,
      ...this.causalSelfAttentionDropout.trainableWeights,
      ...this.causalSelfAttentionNorm.trainableWeights,
      ...this.feedforward1.trainableWeights,
      ...this.feedforward2.trainableWeights,
      ...this.feedForwardDropout.trainableWeights,
      ...this.feedFowardNorm.trainableWeights
    ];

    // rename the weights otherwise they'll take on the default naming and overlap
    // each other which breaks model loading due to duplicate weight names
    let indexing = 0;

    for (const weight of this.trainableWeights) {
      const unique_name = `${this.getClassName()}_${indexing}`;
      (weight as any).name += unique_name;
      (weight as any).originalName += unique_name;
      indexing++;
    }

    super.build(inputShape);
  }


  /**
   * Save the layer's hyperparameters for serialization
   */
  override getConfig(): tf.serialization.ConfigDict {
    const base_config = super.getConfig();

    const config = {
      numHeads: this.numHeads,
      embedDim: this.embedDim,
      causal: this.causal,
      useBias: this.useBias,
      dropout: this.dropout,
      activation: this.activation,
      dimsFeedForward: this.dimsFeedForward
    };

    Object.assign(config, base_config);

    return config;
  }

}


// make the layer known to TFJS's serialization registry so saved models
// containing it can be loaded again
tf.serialization.registerClass(TransformerDecoder);
@@ -0,0 +1,85 @@
1
+ import * as tf from "@tensorflow/tfjs";
2
+
3
+ import { TransformerEncoder } from "@/layers/transformer_encoder";
4
+
5
// disables warning for using the faster node backend,
// https://github.com/tensorflow/tfjs/issues/5349#issuecomment-885170504
tf.env().set('IS_NODE', false);


describe("TransformerEncoder tests", () => {
  it("should return an output with the same shape as the input", () => {
    const input = tf.randomUniform([2, 3, 10]);

    const encoder = new TransformerEncoder({
      numHeads: 2, embedDim: input.shape.at(-1)!,
      dropout: 0.5, activation: "gelu", dimsFeedForward: 512, useBias: true
    });

    const output = encoder.apply(input) as tf.Tensor;

    // compare every dimension, not only the rank
    expect(output.shape).toEqual(input.shape);
  });


  it("correct forward calls", () => {
    const input = tf.randomUniform([2, 3, 10]);

    const encoder = new TransformerEncoder({ numHeads: 2, embedDim: input.shape.at(-1)! });
    expect(() => encoder.apply(input)).not.toThrow();
    expect(() => encoder.apply([input])).not.toThrow();

    const causal = new TransformerEncoder({ numHeads: 2, embedDim: input.shape.at(-1)!, causal: true });
    expect(() => causal.apply(input)).not.toThrow();
    expect(() => causal.apply([input])).not.toThrow();
  });


  it("should fail to instantiate a layer if the embedding dimension is not divisible by the heads count", () => {
    const input = tf.randomUniform([2, 3, 10]);

    expect(() => new TransformerEncoder({ numHeads: 3, embedDim: input.shape.at(-1)! })).toThrow();
    expect(() => new TransformerEncoder({ numHeads: 5, embedDim: input.shape.at(-1)! })).not.toThrow();
  });


  it("should not accept non-rank 3 tensor inputs", () => {
    const incorrect_input = tf.randomUniform([2, 3, 10, 10]);
    const incorrect_input2 = tf.randomUniform([2, 3]);
    const correct_input = tf.randomUniform([2, 3, 10]);

    const encoder = new TransformerEncoder({ numHeads: 2, embedDim: correct_input.shape.at(-1)! });
    expect(() => encoder.apply([correct_input, correct_input])).toThrow();

    expect(() => encoder.apply(incorrect_input)).toThrow();
    expect(() => encoder.apply(incorrect_input2)).toThrow();

    expect(() => encoder.apply([correct_input, incorrect_input])).toThrow();
    expect(() => encoder.apply([incorrect_input, correct_input])).toThrow();

    expect(() => encoder.apply([correct_input, incorrect_input2])).toThrow();
    expect(() => encoder.apply([incorrect_input2, correct_input])).toThrow();
  });


  it("should accept exactly one input", () => {
    const input = tf.randomUniform([2, 3, 10]);

    const encoder = new TransformerEncoder({ numHeads: 1, embedDim: input.shape.at(-1)! });
    expect(() => encoder.apply(input)).not.toThrow();
    expect(() => encoder.apply([input])).not.toThrow();

    expect(() => encoder.apply([])).toThrow();
    expect(() => encoder.apply([input, input])).toThrow();
    expect(() => encoder.apply([input, input, input])).toThrow();
  });


  it("should return a non-empty config dict", () => {
    const input = tf.randomUniform([2, 3, 10]);

    const encoder = new TransformerEncoder({ numHeads: 1, embedDim: input.shape.at(-1)! });
    const config = encoder.getConfig();

    // check the number of keys (comparing the key array itself to 0 can never fail)
    expect(Object.keys(config).length).toBeGreaterThan(0);
    expect(config).toMatchObject({ numHeads: 1, embedDim: 10 });
  });
});
@@ -0,0 +1,224 @@
1
+ import * as tf from "@tensorflow/tfjs";
2
+ import { type Kwargs } from "@tensorflow/tfjs-layers/dist/types";
3
+ import { type ActivationIdentifier } from "@tensorflow/tfjs-layers/dist/keras_format/activation_config";
4
+
5
+ import { MultiHeadAttention, type MultiHeadAttentionArgs } from "@/layers/multihead_attention";
6
+
7
+
8
/**
 * Constructor options for the transformer encoder layer: the multi-head
 * attention options plus the feed-forward sub-layer settings.
 */
export interface TransformerEncoderArgs extends MultiHeadAttentionArgs {
  /** size of the intermediate feed-forward layer */
  dimsFeedForward?: number;
  /** activation of the intermediate feed-forward layer */
  activation?: "relu" | "gelu";
}
12
+
13
+
14
/**
 * This class implements the transformer encoder architecture from the 2017 paper
 * "Attention Is All You Need".
 *
 * The layer runs a self-attention sub-layer followed by a position-wise
 * feed-forward sub-layer. Each sub-layer's output goes through dropout, is added
 * to the sub-layer's input (residual connection) and is then layer-normalized.
 *
 * This layer accepts exactly one tensor input with the shape
 * `[ batch, sequences, embedding dims ]` and returns a tensor of the same shape.
 *
 * @param numHeads number of attention heads to use
 * @param embedDim the embedding size of the input (input embeddings, typically the last dimension)
 * @param causal use causal masking, default `false` for encoders
 * @param dropout dropout rate within `[0, 1)`, passed to the attention layer and applied after each sub-layer, default `0.1`
 * @param activation the activation of the intermediate feed forward layer, default `relu`
 * @param dimsFeedForward the size of the intermediate feed forward layer, default `2048`
 * @param useBias use bias for the dense sublayers and multiHead attention's dense sublayers, default `true`
 */
export class TransformerEncoder extends tf.layers.Layer {
  static className = "TransformerEncoder";

  private readonly selfAttention: tf.layers.Layer;
  private readonly selfAttentionDropout: tf.layers.Layer;
  private readonly selfAttentionNorm: tf.layers.Layer;

  private readonly reluLayer: tf.layers.Layer;
  private readonly linearLayer: tf.layers.Layer;
  private readonly feedForwardDropout: tf.layers.Layer;
  private readonly feedFowardNorm: tf.layers.Layer;

  private readonly numHeads: number;
  private readonly embedDim: number;
  private readonly causal: boolean;
  private readonly useBias: boolean;
  private readonly dropout: number;
  private readonly activation: ActivationIdentifier;
  private readonly dimsFeedForward: number;


  constructor({ numHeads, embedDim, causal, useBias, dropout, activation, dimsFeedForward, ...args }: TransformerEncoderArgs) {
    super(args);

    this.numHeads = numHeads;
    this.embedDim = embedDim;
    this.causal = causal ?? false;
    this.useBias = useBias ?? true;
    this.dropout = dropout ?? 0.1;
    this.activation = activation ?? "relu";
    this.dimsFeedForward = dimsFeedForward ?? 2048;

    // negated range check so negative rates and NaN are rejected as well
    if (!(this.dropout >= 0 && this.dropout < 1)) {
      throw Error(`${this.getClassName()}::constructor dropout must be within [0, 1)`);
    }

    // self attention sub-block
    this.selfAttention = new MultiHeadAttention({
      numHeads: this.numHeads, embedDim: this.embedDim, useBias: this.useBias,
      dropout: this.dropout, causal: this.causal
    });
    this.selfAttentionDropout = tf.layers.dropout({ rate: this.dropout });
    this.selfAttentionNorm = tf.layers.layerNormalization({ epsilon: 1e-6 });

    // feed forward sub-block: expand to dimsFeedForward, then project back to embedDim
    this.reluLayer = tf.layers.dense({
      units: this.dimsFeedForward, activation: this.activation,
      useBias: this.useBias
    });
    this.linearLayer = tf.layers.dense({
      units: this.embedDim, activation: "linear",
      useBias: this.useBias
    });
    this.feedForwardDropout = tf.layers.dropout({ rate: this.dropout });
    this.feedFowardNorm = tf.layers.layerNormalization({ epsilon: 1e-6 });
  }


  /**
   * Forward propagation
   *
   * @param inputs the input tensor, or an array holding exactly that one tensor
   * @param kwargs forwarded to the attention and dropout sublayers
   * @returns the output tensor, same shape as the input
   */
  override call(inputs: tf.Tensor | tf.Tensor[], kwargs: Kwargs): tf.Tensor | tf.Tensor[] {
    // validate the input tensors
    if (Array.isArray(inputs) && inputs.length != 1) {
      // getClassName must be called: interpolating the bare method printed its source
      throw Error(`${this.getClassName()}::call ${this.name} expects exactly 1 tensor` +
        ` input, got ${inputs.length} inputs instead.`);
    }

    const input = (Array.isArray(inputs) ? inputs[0] : inputs) as tf.Tensor;

    // perform forward propagation
    return tf.tidy(() => {
      const attention = this.selfAttentionBlock(input, kwargs);
      return this.feedForwardBlock(attention, kwargs);
    });
  }


  /** Self-attention -> dropout -> residual add -> layer norm. */
  private selfAttentionBlock(x: tf.Tensor, kwargs: Kwargs): tf.Tensor {
    return tf.tidy(() => {
      const residual = x;

      let attention = this.selfAttention.apply(x, kwargs) as tf.Tensor;
      attention = this.selfAttentionDropout.apply(attention, kwargs) as tf.Tensor;
      attention = tf.add(attention, residual);
      attention = this.selfAttentionNorm.apply(attention) as tf.Tensor;

      return attention;
    });
  }


  /** Dense (activation) -> dense (linear) -> dropout -> residual add -> layer norm. */
  private feedForwardBlock(x: tf.Tensor, kwargs: Kwargs): tf.Tensor {
    return tf.tidy(() => {
      const residual = x;

      let feedForward = this.reluLayer.apply(x) as tf.Tensor;
      feedForward = this.linearLayer.apply(feedForward) as tf.Tensor;
      feedForward = this.feedForwardDropout.apply(feedForward, kwargs) as tf.Tensor;
      feedForward = tf.add(feedForward, residual);
      feedForward = this.feedFowardNorm.apply(feedForward) as tf.Tensor;

      return feedForward;
    });
  }


  /**
   * Initialize the sublayers' weights and track them to enable backpropagation.
   *
   * @throws if the input is not exactly one rank-3 shape
   */
  override build(inputShape: tf.Shape | tf.Shape[]): void {
    let input_shapes: tf.Shape[] = [];

    if (Array.isArray(inputShape) && Array.isArray(inputShape[0])) {
      // input is an array of shapes
      input_shapes = inputShape as tf.Shape[];
    } else if (inputShape.length != 0) {
      // input is a single shape
      input_shapes = [inputShape as tf.Shape];
    }

    const [encoderInputShape] = input_shapes;

    // expects only 1 rank 3 tensor input
    if (input_shapes.length != 1 || encoderInputShape?.length != 3) {
      throw Error(`${this.getClassName()}::build ${this.name} expects a single input shape of [batch, seq, embed_dim], got ${JSON.stringify(inputShape)}`)
    }

    // build sublayers from the validated single shape: `inputShape` itself may
    // still be wrapped as `[shape]`
    // initialize self attention sub-block's weights
    this.selfAttention.build(encoderInputShape);
    this.selfAttentionNorm.build(encoderInputShape);

    // initialize feedforward sub-block's weights
    const reluLayerOutputShape = this.reluLayer.computeOutputShape(encoderInputShape);
    const linearLayerOutputShape = this.linearLayer.computeOutputShape(reluLayerOutputShape);

    this.reluLayer.build(encoderInputShape);
    this.linearLayer.build(reluLayerOutputShape);
    this.feedFowardNorm.build(linearLayerOutputShape);

    // track sublayers' weights
    this.trainableWeights = [
      ...this.selfAttention.trainableWeights,
      ...this.selfAttentionDropout.trainableWeights,
      ...this.selfAttentionNorm.trainableWeights,
      ...this.reluLayer.trainableWeights,
      ...this.linearLayer.trainableWeights,
      ...this.feedForwardDropout.trainableWeights,
      ...this.feedFowardNorm.trainableWeights
    ];

    // rename the weights otherwise they'll take on the default naming and overlap
    // each other which breaks model loading due to duplicate weight names
    let indexing = 0;

    for (const weight of this.trainableWeights) {
      const unique_name = `${this.getClassName()}_${indexing}`;
      (weight as any).name += unique_name;
      (weight as any).originalName += unique_name;
      indexing++;
    }

    super.build(inputShape);
  }


  /**
   * Save the layer's hyperparameters for serialization
   */
  override getConfig(): tf.serialization.ConfigDict {
    const base_config = super.getConfig();

    const config = {
      numHeads: this.numHeads,
      embedDim: this.embedDim,
      causal: this.causal,
      useBias: this.useBias,
      dropout: this.dropout,
      activation: this.activation,
      dimsFeedForward: this.dimsFeedForward
    };

    Object.assign(config, base_config);

    return config;
  }
}


// make the layer known to TFJS's serialization registry so saved models
// containing it can be loaded again
tf.serialization.registerClass(TransformerEncoder);
@@ -0,0 +1,156 @@
1
+ import * as tf from "@tensorflow/tfjs";
2
+ import { categoricalCrossentropy, binaryCrossentropy } from "@tensorflow/tfjs-layers/dist/losses";
3
+
4
// smoothing term added to the numerator and denominator of every Dice ratio,
// keeping the ratio finite when both masks are empty
const epsilon = 1e-7;

// reduction axes, assuming `[ batch, height, width, channels ]` tensors
const REDUCE_HW = [1, 2]; // spatial axes only (height, width)
const REDUCE_BHW = [0, ...REDUCE_HW]; // batch axis plus the spatial axes
const REDUCE_BHWC = [...REDUCE_BHW, 3]; // every axis of a rank-4 tensor
9
+
10
+
11
/**
 * Standard (Sorensen) Dice loss, computed per sample.
 *
 * Each sample is flattened to a vector and the loss is
 * `1 - (2 * sum(y_true * y_pred) + eps) / (sum(y_true) + sum(y_pred) + eps)`.
 *
 * @param y_true the label tensor, batch on axis 0
 * @param y_pred the prediction tensor, expected to have the same shape as `y_true`
 * @returns a tensor of shape `[ batch ]`
 */
export function diceBinaryStandard(y_true: tf.Tensor, y_pred: tf.Tensor): tf.Tensor {
  // tidy disposes the intermediate tensors, which otherwise leak when the
  // loss is called outside an enclosing tidy scope
  return tf.tidy(() => {
    const y_true_flat = tf.reshape(y_true, [y_true.shape[0], -1]);
    const y_pred_flat = tf.reshape(y_pred, [y_pred.shape[0], -1]);

    const intersection = tf.sum(tf.mul(y_true_flat, y_pred_flat), 1);
    const union = tf.add(tf.sum(y_true_flat, 1), tf.sum(y_pred_flat, 1));

    const dice = tf.div(
      intersection.mul(2).add(epsilon),
      union.add(epsilon)
    );

    return tf.scalar(1).sub(dice);
  });
}


// prevents minification of function name which TFJS relies on
Object.defineProperty(diceBinaryStandard, "name", { value: "diceBinaryStandard", configurable: false });
31
+
32
+
33
/**
 * Dice loss computed globally: the whole batch is flattened into a single
 * vector before the Dice ratio is taken.
 *
 * Based on the Keras Dice loss:
 * https://github.com/keras-team/keras/blob/v3.3.3/keras/src/losses/losses.py#L1983-L2010
 *
 * @param y_true the label tensor
 * @param y_pred the prediction tensor, expected to have the same shape as `y_true`
 * @returns a scalar tensor
 */
export function diceBinaryGlobal(y_true: tf.Tensor, y_pred: tf.Tensor): tf.Tensor {
  // tidy disposes the intermediate tensors, which otherwise leak when the
  // loss is called outside an enclosing tidy scope
  return tf.tidy(() => {
    const y_true_flat = tf.reshape(y_true, [-1]);
    const y_pred_flat = tf.reshape(y_pred, [-1]);

    const intersection = tf.sum(tf.mul(y_true_flat, y_pred_flat));
    const union = tf.add(tf.sum(y_true_flat), tf.sum(y_pred_flat));

    const dice = tf.div(
      intersection.mul(2).add(epsilon),
      union.add(epsilon)
    );

    return tf.scalar(1).sub(dice);
  });
}


// prevents minification of function name which TFJS relies on
Object.defineProperty(diceBinaryGlobal, "name", { value: "diceBinaryGlobal", configurable: false });
52
+
53
+
54
/**
 * Standard Dice loss for dense (one-hot) multi-class masks, computed per sample.
 *
 * The Dice ratio is computed per sample and per class over axes 1 and 2,
 * averaged over the last (class) axis and subtracted from 1.
 * Presumably expects `[ batch, height, width, classes ]` tensors — TODO confirm.
 *
 * @param y_true the one-hot label tensor
 * @param y_pred the prediction tensor, same shape as `y_true`
 * @returns a tensor of shape `[ batch ]` for rank-4 inputs
 */
export function diceCategoricalStandard(y_true: tf.Tensor, y_pred: tf.Tensor): tf.Tensor {
  // tidy disposes the intermediate tensors, which otherwise leak when the
  // loss is called outside an enclosing tidy scope
  return tf.tidy(() => {
    const intersection = tf.sum(tf.mul(y_true, y_pred), REDUCE_HW);
    const union = tf.add(y_true, y_pred).sum(REDUCE_HW);

    const dice = tf.div(
      intersection.mul(2).add(epsilon),
      union.add(epsilon)
    );

    // average the per-class scores
    return tf.scalar(1).sub(tf.mean(dice, -1));
  });
}


// prevents minification of function name which TFJS relies on
Object.defineProperty(diceCategoricalStandard, "name", { value: "diceCategoricalStandard", configurable: false });
69
+
70
+
71
/**
 * Generalized Dice loss for dense (one-hot) multi-class masks.
 *
 * Each class is weighted by `1 / (sum(y_true[..., c])^2 + eps)`, its squared
 * label volume over axes 0-2, so classes with few labelled pixels weigh more.
 * The weighted intersections and unions are summed over classes into one ratio.
 *
 * @param y_true the one-hot label tensor
 * @param y_pred the prediction tensor, same shape as `y_true`
 * @returns a scalar tensor
 */
export function diceCategoricalGeneralized(y_true: tf.Tensor, y_pred: tf.Tensor): tf.Tensor {
  // tidy disposes the intermediate tensors, which otherwise leak when the
  // loss is called outside an enclosing tidy scope
  return tf.tidy(() => {
    // the per-class label volume feeds both the weighting and the union,
    // so compute it only once
    const y_true_sum = y_true.sum(REDUCE_BHW);

    const weighting = tf.div(1, y_true_sum.square().add(epsilon));

    const intersection = tf.sum(tf.mul(y_true, y_pred), REDUCE_BHW).mul(weighting).sum();
    const union = tf.add(y_true_sum, y_pred.sum(REDUCE_BHW)).mul(weighting).sum();

    const dice = tf.div(
      intersection.mul(2).add(epsilon),
      union.add(epsilon)
    );

    return tf.scalar(1).sub(dice);
  });
}


// prevents minification of function name which TFJS relies on
Object.defineProperty(diceCategoricalGeneralized, "name", { value: "diceCategoricalGeneralized", configurable: false });
92
+
93
+
94
/**
 * Dice loss over all classes and samples at once: intersections and sums are
 * taken over every axis of a rank-4 tensor, ignoring class boundaries.
 *
 * @param y_true the one-hot label tensor
 * @param y_pred the prediction tensor, same shape as `y_true`
 * @returns a scalar tensor
 */
export function diceCategoricalGlobal(y_true: tf.Tensor, y_pred: tf.Tensor): tf.Tensor {
  // tidy disposes the intermediate tensors, which otherwise leak when the
  // loss is called outside an enclosing tidy scope
  return tf.tidy(() => {
    const intersection = tf.sum(tf.mul(y_true, y_pred), REDUCE_BHWC);
    const union = tf.add(tf.sum(y_true, REDUCE_BHWC), tf.sum(y_pred, REDUCE_BHWC));

    const dice = tf.div(
      intersection.mul(2).add(epsilon),
      union.add(epsilon)
    );

    return tf.scalar(1).sub(dice);
  });
}


// prevents minification of function name which TFJS relies on
Object.defineProperty(diceCategoricalGlobal, "name", { value: "diceCategoricalGlobal", configurable: false });
110
+
111
+
112
/**
 * Calculates the Sorensen-Dice coefficient and the binary cross entropy losses.
 * Both have equal weight.
 *
 * The cross entropy is averaged over axes 1 and 2 so it matches the per-sample
 * Dice loss; this presumably expects `[ batch, height, width, 1 ]` inputs — TODO confirm.
 *
 * @param y_true the label tensor
 * @param y_pred the prediction tensor (not sparse)
 * @returns a tensor of shape `[ batch ]`
 */
export function diceBinaryCrossentropy(y_true: tf.Tensor, y_pred: tf.Tensor): tf.Tensor {
  // tidy disposes the intermediate tensors, which otherwise leak when the
  // loss is called outside an enclosing tidy scope
  return tf.tidy(() => {
    // reduce cross entropy shape from [B, H, W] to [B] to match dice
    const bce = binaryCrossentropy(y_true, y_pred).mean(REDUCE_HW);
    const dice = diceBinaryStandard(y_true, y_pred);

    return tf.add(bce.mul(0.5), dice.mul(0.5));
  });
}


// prevents minification of function name which TFJS relies on
Object.defineProperty(diceBinaryCrossentropy, "name", { value: "diceBinaryCrossentropy", configurable: false });
131
+
132
+
133
/**
 * Calculates the Sorensen-Dice coefficient and the categorical cross entropy losses.
 * Both have equal weight. Expects dense (non-sparse) label tensors.
 *
 * This does not support sparse tensors because TFJS's
 * sparseCategoricalCrossentropy loss one-hot encodes the label
 * and calls categoricalCrossentropy. See
 * https://github.com/tensorflow/tfjs/blob/0fc04d958ea592f3b8db79a8b3b497b5c8904097/tfjs-layers/src/losses.ts#L143-L146
 *
 * @param y_true the label (one-hot)
 * @param y_pred the prediction tensor (not sparse)
 * @returns a tensor of shape `[ batch ]`
 */
export function diceCategoricalCrossentropy(y_true: tf.Tensor, y_pred: tf.Tensor): tf.Tensor {
  // tidy disposes the intermediate tensors, which otherwise leak when the
  // loss is called outside an enclosing tidy scope
  return tf.tidy(() => {
    // reduce cross entropy shape from [B, H, W] to [B] to match dice
    const cce = categoricalCrossentropy(y_true, y_pred).mean(REDUCE_HW);
    const dice = diceCategoricalStandard(y_true, y_pred);

    return tf.add(cce.mul(0.5), dice.mul(0.5));
  });
}


// prevents minification of function name which TFJS relies on
Object.defineProperty(diceCategoricalCrossentropy, "name", { value: "diceCategoricalCrossentropy", configurable: false });