@stellarapp/tfjs-stellar 1.0.3 → 1.0.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +17 -0
- package/dist/index.d.ts +3 -1
- package/dist/index.d.ts.map +1 -1
- package/dist/index.js +3 -1
- package/dist/index.js.map +1 -1
- package/dist/kv_cache.d.ts +2 -0
- package/dist/kv_cache.d.ts.map +1 -1
- package/dist/kv_cache.js +6 -0
- package/dist/kv_cache.js.map +1 -1
- package/dist/models/index.d.ts +2 -1
- package/dist/models/index.d.ts.map +1 -1
- package/dist/models/index.js +2 -1
- package/dist/models/index.js.map +1 -1
- package/package.json +1 -1
- package/dist/jest.config.d.ts +0 -8
- package/dist/jest.config.d.ts.map +0 -1
- package/dist/jest.config.js +0 -147
- package/dist/jest.config.js.map +0 -1
- package/dist/src/index.d.ts +0 -6
- package/dist/src/index.d.ts.map +0 -1
- package/dist/src/index.js +0 -6
- package/dist/src/index.js.map +0 -1
- package/dist/src/kv_cache.d.ts +0 -53
- package/dist/src/kv_cache.d.ts.map +0 -1
- package/dist/src/kv_cache.js +0 -135
- package/dist/src/kv_cache.js.map +0 -1
- package/dist/src/layers/cached_rope_multihead_attention.d.ts +0 -31
- package/dist/src/layers/cached_rope_multihead_attention.d.ts.map +0 -1
- package/dist/src/layers/cached_rope_multihead_attention.js +0 -76
- package/dist/src/layers/cached_rope_multihead_attention.js.map +0 -1
- package/dist/src/layers/cached_rope_multihead_attention.test.d.ts +0 -2
- package/dist/src/layers/cached_rope_multihead_attention.test.d.ts.map +0 -1
- package/dist/src/layers/cached_rope_multihead_attention.test.js +0 -43
- package/dist/src/layers/cached_rope_multihead_attention.test.js.map +0 -1
- package/dist/src/layers/gpt_decoder_block.d.ts +0 -34
- package/dist/src/layers/gpt_decoder_block.d.ts.map +0 -1
- package/dist/src/layers/gpt_decoder_block.js +0 -51
- package/dist/src/layers/gpt_decoder_block.js.map +0 -1
- package/dist/src/layers/index.d.ts +0 -17
- package/dist/src/layers/index.d.ts.map +0 -1
- package/dist/src/layers/index.js +0 -33
- package/dist/src/layers/index.js.map +0 -1
- package/dist/src/layers/multihead_attention.d.ts +0 -106
- package/dist/src/layers/multihead_attention.d.ts.map +0 -1
- package/dist/src/layers/multihead_attention.js +0 -269
- package/dist/src/layers/multihead_attention.js.map +0 -1
- package/dist/src/layers/multihead_attention.test.d.ts +0 -2
- package/dist/src/layers/multihead_attention.test.d.ts.map +0 -1
- package/dist/src/layers/multihead_attention.test.js +0 -160
- package/dist/src/layers/multihead_attention.test.js.map +0 -1
- package/dist/src/layers/positional_encoding.d.ts +0 -37
- package/dist/src/layers/positional_encoding.d.ts.map +0 -1
- package/dist/src/layers/positional_encoding.js +0 -115
- package/dist/src/layers/positional_encoding.js.map +0 -1
- package/dist/src/layers/positional_encoding.test.d.ts +0 -2
- package/dist/src/layers/positional_encoding.test.d.ts.map +0 -1
- package/dist/src/layers/positional_encoding.test.js +0 -95
- package/dist/src/layers/positional_encoding.test.js.map +0 -1
- package/dist/src/layers/rotary_position_embedding.d.ts +0 -39
- package/dist/src/layers/rotary_position_embedding.d.ts.map +0 -1
- package/dist/src/layers/rotary_position_embedding.js +0 -99
- package/dist/src/layers/rotary_position_embedding.js.map +0 -1
- package/dist/src/layers/rotary_position_embedding.test.d.ts +0 -2
- package/dist/src/layers/rotary_position_embedding.test.d.ts.map +0 -1
- package/dist/src/layers/rotary_position_embedding.test.js +0 -88
- package/dist/src/layers/rotary_position_embedding.test.js.map +0 -1
- package/dist/src/layers/token_and_positional_embedding.d.ts +0 -47
- package/dist/src/layers/token_and_positional_embedding.d.ts.map +0 -1
- package/dist/src/layers/token_and_positional_embedding.js +0 -109
- package/dist/src/layers/token_and_positional_embedding.js.map +0 -1
- package/dist/src/layers/token_and_positional_embedding.test.d.ts +0 -2
- package/dist/src/layers/token_and_positional_embedding.test.d.ts.map +0 -1
- package/dist/src/layers/token_and_positional_embedding.test.js +0 -58
- package/dist/src/layers/token_and_positional_embedding.test.js.map +0 -1
- package/dist/src/layers/transformer_decoder.d.ts +0 -69
- package/dist/src/layers/transformer_decoder.d.ts.map +0 -1
- package/dist/src/layers/transformer_decoder.js +0 -182
- package/dist/src/layers/transformer_decoder.js.map +0 -1
- package/dist/src/layers/transformer_decoder.test.d.ts +0 -2
- package/dist/src/layers/transformer_decoder.test.d.ts.map +0 -1
- package/dist/src/layers/transformer_decoder.test.js +0 -72
- package/dist/src/layers/transformer_decoder.test.js.map +0 -1
- package/dist/src/layers/transformer_encoder.d.ts +0 -55
- package/dist/src/layers/transformer_encoder.d.ts.map +0 -1
- package/dist/src/layers/transformer_encoder.js +0 -175
- package/dist/src/layers/transformer_encoder.js.map +0 -1
- package/dist/src/layers/transformer_encoder.test.d.ts +0 -2
- package/dist/src/layers/transformer_encoder.test.d.ts.map +0 -1
- package/dist/src/layers/transformer_encoder.test.js +0 -58
- package/dist/src/layers/transformer_encoder.test.js.map +0 -1
- package/dist/src/losses/dice.d.ts +0 -30
- package/dist/src/losses/dice.d.ts.map +0 -1
- package/dist/src/losses/dice.js +0 -93
- package/dist/src/losses/dice.js.map +0 -1
- package/dist/src/losses/index.d.ts +0 -2
- package/dist/src/losses/index.d.ts.map +0 -1
- package/dist/src/losses/index.js +0 -2
- package/dist/src/losses/index.js.map +0 -1
- package/dist/src/masks.d.ts +0 -20
- package/dist/src/masks.d.ts.map +0 -1
- package/dist/src/masks.js +0 -37
- package/dist/src/masks.js.map +0 -1
- package/dist/src/metrics.d.ts +0 -20
- package/dist/src/metrics.d.ts.map +0 -1
- package/dist/src/metrics.js +0 -28
- package/dist/src/metrics.js.map +0 -1
- package/dist/src/models/gpt_model.d.ts +0 -94
- package/dist/src/models/gpt_model.d.ts.map +0 -1
- package/dist/src/models/gpt_model.js +0 -154
- package/dist/src/models/gpt_model.js.map +0 -1
- package/dist/src/models/index.d.ts +0 -3
- package/dist/src/models/index.d.ts.map +0 -1
- package/dist/src/models/index.js +0 -3
- package/dist/src/models/index.js.map +0 -1
- package/dist/src/models/llm_model.d.ts +0 -87
- package/dist/src/models/llm_model.d.ts.map +0 -1
- package/dist/src/models/llm_model.js +0 -245
- package/dist/src/models/llm_model.js.map +0 -1
- package/dist/src/models/u_net.d.ts +0 -40
- package/dist/src/models/u_net.d.ts.map +0 -1
- package/dist/src/models/u_net.js +0 -151
- package/dist/src/models/u_net.js.map +0 -1
- package/dist/src/tfjs_types.d.ts +0 -10
- package/dist/src/tfjs_types.d.ts.map +0 -1
- package/dist/src/tfjs_types.js +0 -2
- package/dist/src/tfjs_types.js.map +0 -1
- package/dist/src/utils.d.ts +0 -28
- package/dist/src/utils.d.ts.map +0 -1
- package/dist/src/utils.js +0 -63
- package/dist/src/utils.js.map +0 -1
- package/dist/src/utils.test.d.ts +0 -2
- package/dist/src/utils.test.d.ts.map +0 -1
- package/dist/src/utils.test.js +0 -73
- package/dist/src/utils.test.js.map +0 -1
|
@@ -1,175 +0,0 @@
|
|
|
1
|
-
import * as tf from "@tensorflow/tfjs";
|
|
2
|
-
import { MultiHeadAttention } from "@/layers/multihead_attention";
|
|
3
|
-
/**
 * Transformer encoder block following the architecture of the 2017 paper
 * "Attention Is All You Need": a self attention sub-block followed by a
 * position-wise feed forward sub-block. Each sub-block applies dropout, adds
 * a residual connection and then layer normalization.
 *
 * This layer accepts exactly one tensor input with the shape
 * `[ batch, sequences, embedding dims ]`.
 *
 * @param numHeads number of attention heads to use
 * @param embedDim the embedding size of the input (input embeddings, typically the last dimension)
 * @param causal use causal masking, default `false` for encoders
 * @param dropout dropout rate used by the attention and feed forward sub-blocks,
 *                must be within `[0, 1)`, default `0.1`
 * @param activation the activation of the intermediate feed forward layer, default `relu`
 * @param dimsFeedForward the size of the intermediate feed forward layer, default `2048`
 * @param useBias use bias for the dense sublayers and multiHead attention's dense sublayers, default `true`
 */
export class TransformerEncoder extends tf.layers.Layer {
    static className = "TransformerEncoder";

    // self attention sub-block
    selfAttention;
    selfAttentionDropout;
    selfAttentionNorm;

    // feed forward sub-block
    reluLayer;
    linearLayer;
    feedForwardDropout;
    // NOTE: the misspelled name (sic) is kept so existing accesses keep working
    feedFowardNorm;

    // hyperparameters, echoed back by getConfig()
    numHeads;
    embedDim;
    causal;
    useBias;
    dropout;
    activation;
    dimsFeedForward;

    /**
     * Stores the hyperparameters (applying defaults) and instantiates the sublayers.
     * Any remaining properties are forwarded to the base `tf.layers.Layer` constructor.
     *
     * @throws Error if `dropout` is outside `[0, 1)`
     */
    constructor({ numHeads, embedDim, causal, useBias, dropout, activation, dimsFeedForward, ...args }) {
        super(args);

        this.numHeads = numHeads;
        this.embedDim = embedDim;
        this.causal = causal ?? false;
        this.useBias = useBias ?? true;
        this.dropout = dropout ?? 0.1;
        this.activation = activation ?? "relu";
        this.dimsFeedForward = dimsFeedForward ?? 2048;

        // enforce the documented [0, 1) range on both ends
        if (this.dropout < 0 || this.dropout >= 1) {
            throw Error(`${this.getClassName()}::constructor dropout must be within [0, 1)`);
        }

        // self attention sub-block
        this.selfAttention = new MultiHeadAttention({
            numHeads: this.numHeads, embedDim: this.embedDim, useBias: this.useBias,
            dropout: this.dropout, causal: this.causal
        });
        this.selfAttentionDropout = tf.layers.dropout({ rate: this.dropout });
        this.selfAttentionNorm = tf.layers.layerNormalization({ epsilon: 1e-6 });

        // feed forward sub-block
        this.reluLayer = tf.layers.dense({
            units: this.dimsFeedForward, activation: this.activation,
            useBias: this.useBias
        });
        this.linearLayer = tf.layers.dense({
            units: this.embedDim, activation: "linear",
            useBias: this.useBias
        });
        this.feedForwardDropout = tf.layers.dropout({ rate: this.dropout });
        this.feedFowardNorm = tf.layers.layerNormalization({ epsilon: 1e-6 });
    }

    /**
     * Forward propagation: self attention sub-block, then feed forward sub-block.
     *
     * @param inputs a single tensor, or an array holding exactly one tensor
     * @param kwargs forwarded to the attention and dropout sublayers
     * @returns the output of the feed forward sub-block
     * @throws Error if `inputs` is an array that does not hold exactly one tensor
     */
    call(inputs, kwargs) {
        // validate the input tensors
        let input;

        if (Array.isArray(inputs)) {
            if (inputs.length !== 1) {
                throw Error(`${this.getClassName()}::call ${this.name} expects exactly 1 tensor` +
                    ` input, got ${inputs.length} inputs instead.`);
            }
            input = inputs[0];
        }
        else {
            input = inputs;
        }

        // perform forward propagation
        return tf.tidy(() => {
            const attention = this.selfAttentionBlock(input, kwargs);
            const feedforward = this.feedForwardBlock(attention, kwargs);
            return feedforward;
        });
    }

    /**
     * Self attention sub-block: attention -> dropout -> residual add -> layer norm.
     */
    selfAttentionBlock(x, kwargs) {
        return tf.tidy(() => {
            const residual = x;
            let attention = this.selfAttention.apply(x, kwargs);
            attention = this.selfAttentionDropout.apply(attention, kwargs);
            attention = tf.add(attention, residual);
            attention = this.selfAttentionNorm.apply(attention);
            return attention;
        });
    }

    /**
     * Feed forward sub-block: dense (activation) -> dense (linear) -> dropout
     * -> residual add -> layer norm.
     */
    feedForwardBlock(x, kwargs) {
        return tf.tidy(() => {
            const residual = x;
            let feedForward = this.reluLayer.apply(x);
            feedForward = this.linearLayer.apply(feedForward);
            feedForward = this.feedForwardDropout.apply(feedForward, kwargs);
            feedForward = tf.add(feedForward, residual);
            feedForward = this.feedFowardNorm.apply(feedForward);
            return feedForward;
        });
    }

    /**
     * Initialize the sublayers' weights and track them to enable backpropagation.
     *
     * @param inputShape a single rank 3 shape, or an array holding exactly one such shape
     * @throws Error if there is not exactly one input shape of rank 3
     */
    build(inputShape) {
        let input_shapes = [];

        if (Array.isArray(inputShape) && Array.isArray(inputShape[0])) {
            // input is an array of shapes
            input_shapes = inputShape;
        }
        else if (inputShape.length !== 0) {
            // input is a single shape
            input_shapes = [inputShape];
        }

        // expects only 1 rank 3 tensor input
        if (input_shapes.length !== 1 || input_shapes[0].length !== 3) {
            throw Error(`${this.getClassName()}::build ${this.name} expects a single input shape of [batch, seq, embed_dim], got ${JSON.stringify(inputShape)}`);
        }

        // initialize self attention sub-block's weights
        this.selfAttention.build(inputShape);
        this.selfAttentionNorm.build(inputShape);

        // initialize feedforward sub-block's weights; each layer is built with
        // the output shape of the layer that precedes it
        const reluLayerOutputShape = this.reluLayer.computeOutputShape(inputShape);
        const linearLayerOutputShape = this.linearLayer.computeOutputShape(reluLayerOutputShape);
        this.reluLayer.build(inputShape);
        this.linearLayer.build(reluLayerOutputShape);
        this.feedFowardNorm.build(linearLayerOutputShape);

        // track sublayers' weights
        this.trainableWeights = [
            ...this.selfAttention.trainableWeights,
            ...this.selfAttentionDropout.trainableWeights,
            ...this.selfAttentionNorm.trainableWeights,
            ...this.reluLayer.trainableWeights,
            ...this.linearLayer.trainableWeights,
            ...this.feedForwardDropout.trainableWeights,
            ...this.feedFowardNorm.trainableWeights
        ];

        // rename the weights otherwise they'll take on the default naming and overlap
        // each other which breaks model loading due to duplicate weight names
        let indexing = 0;

        for (const weight of this.trainableWeights) {
            const unique_name = `${this.getClassName()}_${indexing}`;
            weight.name += unique_name;
            weight.originalName += unique_name;
            indexing++;
        }

        super.build(inputShape);
    }

    /**
     * Save the layer's hyperparameters for serialization.
     *
     * @returns the hyperparameters merged with the base layer config; on a key
     *          clash the base config value wins
     */
    getConfig() {
        const base_config = super.getConfig();
        const config = {
            numHeads: this.numHeads,
            embedDim: this.embedDim,
            causal: this.causal,
            useBias: this.useBias,
            dropout: this.dropout,
            activation: this.activation,
            dimsFeedForward: this.dimsFeedForward
        };
        Object.assign(config, base_config);
        return config;
    }
}

// register the class with TFJS's serialization registry under `className`
tf.serialization.registerClass(TransformerEncoder);
|
|
175
|
-
//# sourceMappingURL=transformer_encoder.js.map
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"transformer_encoder.js","sourceRoot":"","sources":["../../../src/layers/transformer_encoder.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAIvC,OAAO,EAAE,kBAAkB,EAA+B,MAAM,8BAA8B,CAAC;AAS/F;;;;;;;;;;;;;;GAcG;AACH,MAAM,OAAO,kBAAmB,SAAQ,EAAE,CAAC,MAAM,CAAC,KAAK;IACnD,MAAM,CAAC,SAAS,GAAG,oBAAoB,CAAC;IAEvB,aAAa,CAAkB;IAC/B,oBAAoB,CAAkB;IACtC,iBAAiB,CAAkB;IAEnC,SAAS,CAAkB;IAC3B,WAAW,CAAkB;IAC7B,kBAAkB,CAAkB;IACpC,cAAc,CAAkB;IAEhC,QAAQ,CAAS;IACjB,QAAQ,CAAS;IACjB,MAAM,CAAU;IAChB,OAAO,CAAU;IACjB,OAAO,CAAS;IAChB,UAAU,CAAuB;IACjC,eAAe,CAAS;IAGzC,YAAY,EAAE,QAAQ,EAAE,QAAQ,EAAE,MAAM,EAAE,OAAO,EAAE,OAAO,EAAE,UAAU,EAAE,eAAe,EAAE,GAAG,IAAI,EAA0B;QACtH,KAAK,CAAC,IAAI,CAAC,CAAC;QAEZ,IAAI,CAAC,QAAQ,GAAG,QAAQ,CAAC;QACzB,IAAI,CAAC,QAAQ,GAAG,QAAQ,CAAC;QACzB,IAAI,CAAC,MAAM,GAAG,MAAM,IAAI,KAAK,CAAC;QAC9B,IAAI,CAAC,OAAO,GAAG,OAAO,IAAI,IAAI,CAAC;QAC/B,IAAI,CAAC,OAAO,GAAG,OAAO,IAAI,GAAG,CAAC;QAC9B,IAAI,CAAC,UAAU,GAAG,UAAU,IAAI,MAAM,CAAC;QACvC,IAAI,CAAC,eAAe,GAAG,eAAe,IAAI,IAAI,CAAC;QAE/C,IAAI,IAAI,CAAC,OAAO,IAAI,CAAC,EAAE,CAAC;YACpB,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,6CAA6C,CAAC,CAAC;QACrF,CAAC;QAED,2BAA2B;QAC3B,IAAI,CAAC,aAAa,GAAG,IAAI,kBAAkB,CAAC;YACxC,QAAQ,EAAE,IAAI,CAAC,QAAQ,EAAE,QAAQ,EAAE,IAAI,CAAC,QAAQ,EAAE,OAAO,EAAE,IAAI,CAAC,OAAO;YACvE,OAAO,EAAE,IAAI,CAAC,OAAO,EAAE,MAAM,EAAE,IAAI,CAAC,MAAM;SAC7C,CAAC,CAAC;QACH,IAAI,CAAC,oBAAoB,GAAG,EAAE,CAAC,MAAM,CAAC,OAAO,CAAC,EAAE,IAAI,EAAE,IAAI,CAAC,OAAO,EAAE,CAAC,CAAA;QACrE,IAAI,CAAC,iBAAiB,GAAG,EAAE,CAAC,MAAM,CAAC,kBAAkB,CAAC,EAAE,OAAO,EAAE,IAAI,EAAE,CAAC,CAAC;QAEzE,yBAAyB;QACzB,IAAI,CAAC,SAAS,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;YAC7B,KAAK,EAAE,IAAI,CAAC,eAAe,EAAE,UAAU,EAAE,IAAI,CAAC,UAAU;YACxD,OAAO,EAAE,IAAI,CAAC,OAAO;SACxB,CAAC,CAAC;QACH,IAAI,CAAC,WAAW,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;YAC/B,KAAK,EAAE,IAAI,CAAC,QAAQ,EAAE,UAAU,EAAE,QAAQ;YAC1C,OAAO,EAAE,IAAI,CAAC,OAAO;SACxB,CAAC,CAAC;QACH,IAAI,CAAC,kBAAkB,GAAG,EAAE,CAAC,MAAM,CAAC,OAAO,CAAC,EAAE,IAAI,EAAE,IAAI,CAAC,OAAO,EAAE,CAAC,CAAC;QACpE,IAAI,CAAC,cA
Ac,GAAG,EAAE,CAAC,MAAM,CAAC,kBAAkB,CAAC,EAAE,OAAO,EAAE,IAAI,EAAE,CAAC,CAAC;IAC1E,CAAC;IAGD;;OAEG;IACM,IAAI,CAAC,MAA+B,EAAE,MAAc;QACzD,6BAA6B;QAC7B,IAAI,KAAgB,CAAC;QAErB,IAAI,KAAK,CAAC,OAAO,CAAC,MAAM,CAAC,EAAE,CAAC;YACxB,IAAI,MAAM,CAAC,MAAM,IAAI,CAAC,EAAE,CAAC;gBACrB,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,UAAU,IAAI,CAAC,IAAI,2BAA2B;oBAC1E,eAAe,MAAM,CAAC,MAAM,kBAAkB,CAAC,CAAC;YACxD,CAAC;YAED,KAAK,GAAG,MAAM,CAAC,CAAC,CAAC,CAAC;QACtB,CAAC;aAAM,CAAC;YACJ,KAAK,GAAG,MAAM,CAAC;QACnB,CAAC;QAED,8BAA8B;QAC9B,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,MAAM,SAAS,GAAG,IAAI,CAAC,kBAAkB,CAAC,KAAK,EAAE,MAAM,CAAC,CAAC;YACzD,MAAM,WAAW,GAAG,IAAI,CAAC,gBAAgB,CAAC,SAAS,EAAE,MAAM,CAAC,CAAC;YAE7D,OAAO,WAAW,CAAC;QACvB,CAAC,CAAC,CAAC;IACP,CAAC;IAGO,kBAAkB,CAAC,CAAY,EAAE,MAAc;QACnD,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,MAAM,QAAQ,GAAG,CAAC,CAAC;YAEnB,IAAI,SAAS,GAAG,IAAI,CAAC,aAAa,CAAC,KAAK,CAAC,CAAC,EAAE,MAAM,CAAc,CAAC;YACjE,SAAS,GAAG,IAAI,CAAC,oBAAoB,CAAC,KAAK,CAAC,SAAS,EAAE,MAAM,CAAc,CAAC;YAC5E,SAAS,GAAG,EAAE,CAAC,GAAG,CAAC,SAAS,EAAE,QAAQ,CAAC,CAAC;YACxC,SAAS,GAAG,IAAI,CAAC,iBAAiB,CAAC,KAAK,CAAC,SAAS,CAAc,CAAC;YAEjE,OAAO,SAAS,CAAC;QACrB,CAAC,CAAC,CAAC;IACP,CAAC;IAGO,gBAAgB,CAAC,CAAY,EAAE,MAAc;QACjD,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,MAAM,QAAQ,GAAG,CAAC,CAAC;YAEnB,IAAI,WAAW,GAAG,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;YAC1C,WAAW,GAAG,IAAI,CAAC,WAAW,CAAC,KAAK,CAAC,WAAW,CAAC,CAAC;YAClD,WAAW,GAAG,IAAI,CAAC,kBAAkB,CAAC,KAAK,CAAC,WAAW,EAAE,MAAM,CAAc,CAAC;YAC9E,WAAW,GAAG,EAAE,CAAC,GAAG,CAAC,WAAW,EAAE,QAAQ,CAAC,CAAC;YAC5C,WAAW,GAAG,IAAI,CAAC,cAAc,CAAC,KAAK,CAAC,WAAW,CAAc,CAAC;YAElE,OAAO,WAAW,CAAC;QACvB,CAAC,CAAC,CAAC;IACP,CAAC;IAGD;;OAEG;IACM,KAAK,CAAC,UAAiC;QAC5C,IAAI,YAAY,GAAe,EAAE,CAAC;QAElC,IAAI,KAAK,CAAC,OAAO,CAAC,UAAU,CAAC,IAAI,KAAK,CAAC,OAAO,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC;YAC5D,8BAA8B;YAC9B,YAAY,GAAG,UAAwB,CAAC;QAC5C,CAAC;aAAM,IAAI,UAAU,CAAC,MAAM,IAAI,CAAC,EAAE,CAAC;YAChC,0BAA0B;YAC1B,YAAY,GAAG,CAAC,UAAsB,CAAC,CAAC;QAC5C,CAAC;QAED,qCAAqC;QACrC,IAAI,YAAY,CAAC,MAAM,IAAI,CAAC,IA
AI,YAAY,CAAC,CAAC,CAAC,CAAC,MAAM,IAAI,CAAC,EAAE,CAAC;YAC1D,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,WAAW,IAAI,CAAC,IAAI,iEAAiE,IAAI,CAAC,SAAS,CAAC,UAAU,CAAC,EAAE,CAAC,CAAA;QACxJ,CAAC;QAED,gDAAgD;QAChD,IAAI,CAAC,aAAa,CAAC,KAAK,CAAC,UAAU,CAAC,CAAC;QACrC,IAAI,CAAC,iBAAiB,CAAC,KAAK,CAAC,UAAU,CAAC,CAAC;QAEzC,8CAA8C;QAC9C,MAAM,oBAAoB,GAAG,IAAI,CAAC,SAAS,CAAC,kBAAkB,CAAC,UAAU,CAAC,CAAC;QAC3E,MAAM,sBAAsB,GAAG,IAAI,CAAC,WAAW,CAAC,kBAAkB,CAAC,oBAAoB,CAAC,CAAC;QAEzF,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,UAAU,CAAC,CAAC;QACjC,IAAI,CAAC,WAAW,CAAC,KAAK,CAAC,oBAAoB,CAAC,CAAC;QAC7C,IAAI,CAAC,cAAc,CAAC,KAAK,CAAC,sBAAsB,CAAC,CAAC;QAElD,2BAA2B;QAC3B,IAAI,CAAC,gBAAgB,GAAG;YACpB,GAAG,IAAI,CAAC,aAAa,CAAC,gBAAgB;YACtC,GAAG,IAAI,CAAC,oBAAoB,CAAC,gBAAgB;YAC7C,GAAG,IAAI,CAAC,iBAAiB,CAAC,gBAAgB;YAC1C,GAAG,IAAI,CAAC,SAAS,CAAC,gBAAgB;YAClC,GAAG,IAAI,CAAC,WAAW,CAAC,gBAAgB;YACpC,GAAG,IAAI,CAAC,kBAAkB,CAAC,gBAAgB;YAC3C,GAAG,IAAI,CAAC,cAAc,CAAC,gBAAgB;SAC1C,CAAC;QAEF,8EAA8E;QAC9E,sEAAsE;QACtE,IAAI,QAAQ,GAAG,CAAC,CAAC;QAEjB,KAAK,MAAM,MAAM,IAAI,IAAI,CAAC,gBAAgB,EAAE,CAAC;YACzC,MAAM,WAAW,GAAG,GAAG,IAAI,CAAC,YAAY,EAAE,IAAI,QAAQ,EAAE,CAAC;YACxD,MAAc,CAAC,IAAI,IAAI,WAAW,CAAC;YACnC,MAAc,CAAC,YAAY,IAAI,WAAW,CAAC;YAC5C,QAAQ,EAAE,CAAC;QACf,CAAC;QAED,KAAK,CAAC,KAAK,CAAC,UAAU,CAAC,CAAC;IAC5B,CAAC;IAGD;;OAEG;IACM,SAAS;QACd,MAAM,WAAW,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QAEtC,MAAM,MAAM,GAAG;YACX,QAAQ,EAAE,IAAI,CAAC,QAAQ;YACvB,QAAQ,EAAE,IAAI,CAAC,QAAQ;YACvB,MAAM,EAAE,IAAI,CAAC,MAAM;YACnB,OAAO,EAAE,IAAI,CAAC,OAAO;YACrB,OAAO,EAAE,IAAI,CAAC,OAAO;YACrB,UAAU,EAAE,IAAI,CAAC,UAAU;YAC3B,eAAe,EAAE,IAAI,CAAC,eAAe;SACxC,CAAC;QAEF,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,WAAW,CAAC,CAAC;QAEnC,OAAO,MAAM,CAAC;IAClB,CAAC;;AAIL,EAAE,CAAC,aAAa,CAAC,aAAa,CAAC,kBAAkB,CAAC,CAAC"}
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"transformer_encoder.test.d.ts","sourceRoot":"","sources":["../../../src/layers/transformer_encoder.test.ts"],"names":[],"mappings":""}
|
|
@@ -1,58 +0,0 @@
|
|
|
1
|
-
import * as tf from "@tensorflow/tfjs";
|
|
2
|
-
import { TransformerEncoder } from "@/layers/transformer_encoder";
|
|
3
|
-
// disables warning for using the faster node backend,
// https://github.com/tensorflow/tfjs/issues/5349#issuecomment-885170504
tf.env().set('IS_NODE', false);

describe("TransformerEncoder tests", () => {
    it("should return an output with the same shape as the input", () => {
        const input = tf.randomUniform([2, 3, 10]);
        const encoder = new TransformerEncoder({
            numHeads: 2, embedDim: input.shape.at(-1),
            dropout: 0.5, activation: "gelu", dimsFeedForward: 512, useBias: true
        });

        const output = encoder.apply(input);

        // compare every dimension, not just the rank
        expect(output.shape).toEqual(input.shape);
    });

    test("correct forward calls", () => {
        const input = tf.randomUniform([2, 3, 10]);

        const encoder = new TransformerEncoder({ numHeads: 2, embedDim: input.shape.at(-1) });
        expect(() => encoder.apply(input)).not.toThrow();
        expect(() => encoder.apply([input])).not.toThrow();

        const causal = new TransformerEncoder({ numHeads: 2, embedDim: input.shape.at(-1), causal: true });
        expect(() => causal.apply(input)).not.toThrow();
        expect(() => causal.apply([input])).not.toThrow();
    });

    it("should fail to instantiate a layer if the input's embedding dimension is not divisible by the heads count", () => {
        const input = tf.randomUniform([2, 3, 10]);

        // 10 % 3 !== 0 -> rejected, 10 % 5 === 0 -> accepted
        expect(() => new TransformerEncoder({ numHeads: 3, embedDim: input.shape.at(-1) })).toThrow();
        expect(() => new TransformerEncoder({ numHeads: 5, embedDim: input.shape.at(-1) })).not.toThrow();
    });

    it("should not accept non-rank 3 tensor inputs", () => {
        const incorrect_input = tf.randomUniform([2, 3, 10, 10]);
        const incorrect_input2 = tf.randomUniform([2, 3]);
        const correct_input = tf.randomUniform([2, 3, 10]);

        // size the layer from the valid input so only the rank is under test
        const encoder = new TransformerEncoder({ numHeads: 2, embedDim: correct_input.shape.at(-1) });
        expect(() => encoder.apply([correct_input, correct_input])).toThrow();

        expect(() => encoder.apply(incorrect_input)).toThrow();
        expect(() => encoder.apply(incorrect_input2)).toThrow();

        expect(() => encoder.apply([correct_input, incorrect_input])).toThrow();
        expect(() => encoder.apply([incorrect_input, correct_input])).toThrow();

        expect(() => encoder.apply([correct_input, incorrect_input2])).toThrow();
        expect(() => encoder.apply([incorrect_input2, correct_input])).toThrow();
    });

    it("should accept exactly one input", () => {
        const input = tf.randomUniform([2, 3, 10]);

        const encoder = new TransformerEncoder({ numHeads: 1, embedDim: input.shape.at(-1) });
        expect(() => encoder.apply(input)).not.toThrow();
        expect(() => encoder.apply([input])).not.toThrow();

        expect(() => encoder.apply([])).toThrow();
        expect(() => encoder.apply([input, input])).toThrow();
        expect(() => encoder.apply([input, input, input])).toThrow();
    });

    it("should return a non-empty config dict", () => {
        const input = tf.randomUniform([2, 3, 10]);

        const encoder = new TransformerEncoder({ numHeads: 1, embedDim: input.shape.at(-1) });
        // assert on the number of keys; comparing the array itself to 0 can never fail
        expect(Object.keys(encoder.getConfig()).length).not.toBe(0);
    });
});
|
|
58
|
-
//# sourceMappingURL=transformer_encoder.test.js.map
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"transformer_encoder.test.js","sourceRoot":"","sources":["../../../src/layers/transformer_encoder.test.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAEvC,OAAO,EAAE,kBAAkB,EAAE,MAAM,8BAA8B,CAAC;AAElE,sDAAsD;AACtD,wEAAwE;AACxE,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC,SAAS,EAAE,KAAK,CAAC,CAAC;AAG/B,QAAQ,CAAC,0BAA0B,EAAE,GAAG,EAAE;IACtC,EAAE,CAAC,0DAA0D,EAAE,GAAG,EAAE;QAChE,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE3C,MAAM,OAAO,GAAG,IAAI,kBAAkB,CAAC;YACnC,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE;YAC1C,OAAO,EAAE,GAAG,EAAE,UAAU,EAAE,MAAM,EAAE,eAAe,EAAE,GAAG,EAAE,OAAO,EAAE,IAAI;SACxE,CAAC,CAAC;QAEH,MAAM,MAAM,GAAG,OAAO,CAAC,KAAK,CAAC,KAAK,CAAc,CAAC;QAEjD,MAAM,CAAC,MAAM,CAAC,KAAK,CAAC,MAAM,CAAC,CAAC,IAAI,CAAC,KAAK,CAAC,KAAK,CAAC,MAAM,CAAC,CAAC;IACzD,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,uBAAuB,EAAE,GAAG,EAAE;QAC/B,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE3C,MAAM,OAAO,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QACvF,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACjD,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAEnD,MAAM,MAAM,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,MAAM,EAAE,IAAI,EAAE,CAAC,CAAC;QACpG,MAAM,CAAC,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAChD,MAAM,CAAC,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;IACtD,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,uGAAuG,EAAE,GAAG,EAAE;QAC7G,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE3C,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAC/F
,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;IACvG,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,4CAA4C,EAAE,GAAG,EAAE;QAClD,MAAM,eAAe,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC;QACzD,MAAM,gBAAgB,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAClD,MAAM,aAAa,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAGnD,MAAM,OAAO,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,eAAe,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QACjG,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,aAAa,EAAE,aAAa,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAEtE,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,eAAe,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACvD,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,gBAAgB,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAExD,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,aAAa,EAAE,eAAe,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACxE,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,eAAe,EAAE,aAAa,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAExE,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,aAAa,EAAE,gBAAgB,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACzE,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,gBAAgB,EAAE,aAAa,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;IAC7E,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,iCAAiC,EAAE,GAAG,EAAE;QACvC,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE3C,MAAM,OAAO,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QACvF,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACjD,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAEnD,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAC1C,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,KAAK,EAAE,KAAK,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QA
CtD,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,KAAK,EAAE,KAAK,EAAE,KAAK,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAA;IAChE,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,uCAAuC,EAAE,GAAG,EAAE;QAC7C,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE3C,MAAM,OAAO,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QACvF,MAAM,CAAC,MAAM,CAAC,IAAI,CAAC,OAAO,CAAC,SAAS,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;IACzD,CAAC,CAAC,CAAA;AACN,CAAC,CAAC,CAAA"}
|
|
@@ -1,30 +0,0 @@
|
|
|
1
|
-
import * as tf from "@tensorflow/tfjs";
|
|
2
|
-
/**
 * Standard (Sorensen) Dice loss computed per sample: each input is reshaped
 * to `[ batch, -1 ]`, and the loss is
 * `1 - (2 * sum(y_true * y_pred) + eps) / (sum(y_true) + sum(y_pred) + eps)`
 * with the sums taken over the flattened non-batch axis.
 *
 * @param y_true the label tensor
 * @param y_pred the prediction tensor
 * @returns a tensor of shape `[ batch ]`
 */
export declare function diceBinaryStandard(y_true: tf.Tensor, y_pred: tf.Tensor): tf.Tensor;
|
|
3
|
-
/**
 * Global Dice loss: both inputs are flattened to one dimension (batch
 * included), and the loss is
 * `1 - (2 * sum(y_true * y_pred) + eps) / (sum(y_true) + sum(y_pred) + eps)`
 * with the sums taken over every element.
 *
 * @param y_true the label tensor
 * @param y_pred the prediction tensor
 * @returns a single loss value for the whole batch
 */
export declare function diceBinaryGlobal(y_true: tf.Tensor, y_pred: tf.Tensor): tf.Tensor;
|
|
4
|
-
/**
 * Dice loss for categorical labels. The implementation starts by summing
 * `y_true * y_pred` over axes 1 and 2.
 * NOTE(review): presumably expects `[ batch, height, width, classes ]` inputs
 * and a per-class reduction; the rest of the implementation is not visible
 * here, confirm against dice.js.
 *
 * @param y_true the label tensor
 * @param y_pred the prediction tensor
 */
export declare function diceCategoricalStandard(y_true: tf.Tensor, y_pred: tf.Tensor): tf.Tensor;
|
|
5
|
-
export declare function diceCategoricalGeneralized(y_true: tf.Tensor, y_pred: tf.Tensor): tf.Tensor;
|
|
6
|
-
export declare function diceCategoricalGlobal(y_true: tf.Tensor, y_pred: tf.Tensor): tf.Tensor;
|
|
7
|
-
/**
|
|
8
|
-
* Calculates the Sorensen-Dice coefficient and the binary cross entropy losses.
|
|
9
|
-
* Both have equal weight.
|
|
10
|
-
*
|
|
11
|
-
* @param y_true the label tensor
|
|
12
|
-
* @param y_pred the prediction tensor (not sparse)
|
|
13
|
-
* @returns a tensor of shape `[ batch ]`
|
|
14
|
-
*/
|
|
15
|
-
export declare function diceBinaryCrossentropy(y_true: tf.Tensor, y_pred: tf.Tensor): tf.Tensor;
|
|
16
|
-
/**
|
|
17
|
-
* Calculates the Sorensen-Dice coefficient and the categorical cross entropy losses.
|
|
18
|
-
* Both have equal weight. Expects dense (non-sparse) label tensors.
|
|
19
|
-
*
|
|
20
|
-
* This does not support sparse tensors because TFJS's
|
|
21
|
-
* sparseCategoricalCrossentropy loss onehots the label
|
|
22
|
-
* and calls categoricalCrossentropy. See
|
|
23
|
-
* https://github.com/tensorflow/tfjs/blob/0fc04d958ea592f3b8db79a8b3b497b5c8904097/tfjs-layers/src/losses.ts#L143-L146
|
|
24
|
-
*
|
|
25
|
-
* @param y_true the label
|
|
26
|
-
* @param y_pred the prediction tensor (not sparse)
|
|
27
|
-
* @returns a tensor of shape `[ batch ]`
|
|
28
|
-
*/
|
|
29
|
-
export declare function diceCategoricalCrossentropy(y_true: tf.Tensor, y_pred: tf.Tensor): tf.Tensor;
|
|
30
|
-
//# sourceMappingURL=dice.d.ts.map
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"dice.d.ts","sourceRoot":"","sources":["../../../src/losses/dice.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAWvC,wBAAgB,kBAAkB,CAAC,MAAM,EAAE,EAAE,CAAC,MAAM,EAAE,MAAM,EAAE,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,CAclF;AAQD,wBAAgB,gBAAgB,CAAC,MAAM,EAAE,EAAE,CAAC,MAAM,EAAE,MAAM,EAAE,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,CAahF;AAOD,wBAAgB,uBAAuB,CAAC,MAAM,EAAE,EAAE,CAAC,MAAM,EAAE,MAAM,EAAE,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,CAUvF;AAOD,wBAAgB,0BAA0B,CAAC,MAAM,EAAE,EAAE,CAAC,MAAM,EAAE,MAAM,EAAE,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,CAgB1F;AAOD,wBAAgB,qBAAqB,CAAC,MAAM,EAAE,EAAE,CAAC,MAAM,EAAE,MAAM,EAAE,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,CAWrF;AAOD;;;;;;;GAOG;AACH,wBAAgB,sBAAsB,CAAC,MAAM,EAAE,EAAE,CAAC,MAAM,EAAE,MAAM,EAAE,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,CAMtF;AAOD;;;;;;;;;;;;GAYG;AACH,wBAAgB,2BAA2B,CAAC,MAAM,EAAE,EAAE,CAAC,MAAM,EAAE,MAAM,EAAE,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,CAM3F"}
|
package/dist/src/losses/dice.js
DELETED
|
@@ -1,93 +0,0 @@
|
|
|
1
|
-
import * as tf from "@tensorflow/tfjs";
|
|
2
|
-
import { categoricalCrossentropy, binaryCrossentropy } from "@tensorflow/tfjs-layers/dist/losses";
|
|
3
|
-
const epsilon = 1e-7;
|
|
4
|
-
const REDUCE_HW = [1, 2]; // reduce over width and height
|
|
5
|
-
const REDUCE_BHW = [0, 1, 2]; // reduce over batch, width, height
|
|
6
|
-
const REDUCE_BHWC = [0, 1, 2, 3]; // reduce all dimensions
|
|
7
|
-
// Standard (Sorensen) Dice Loss
|
|
8
|
-
export function diceBinaryStandard(y_true, y_pred) {
|
|
9
|
-
const y_true_flat = tf.reshape(y_true, [y_true.shape[0], -1]);
|
|
10
|
-
const y_pred_flat = tf.reshape(y_pred, [y_pred.shape[0], -1]);
|
|
11
|
-
const intersection = tf.sum(tf.mul(y_true_flat, y_pred_flat), 1);
|
|
12
|
-
const union = tf.add(tf.sum(y_true_flat, 1), tf.sum(y_pred_flat, 1));
|
|
13
|
-
const dice = tf.div(intersection.mul(2).add(epsilon), union.add(epsilon));
|
|
14
|
-
return tf.scalar(1).sub(dice);
|
|
15
|
-
}
|
|
16
|
-
// prevents minification of function name which TFJS relies on
|
|
17
|
-
Object.defineProperty(diceBinaryStandard, "name", { value: "diceBinaryStandard", configurable: false });
|
|
18
|
-
// https://github.com/keras-team/keras/blob/v3.3.3/keras/src/losses/losses.py#L1983-L2010
|
|
19
|
-
export function diceBinaryGlobal(y_true, y_pred) {
|
|
20
|
-
const y_true_flat = tf.reshape(y_true, [-1]);
|
|
21
|
-
const y_pred_flat = tf.reshape(y_pred, [-1]);
|
|
22
|
-
const intersection = tf.sum(tf.mul(y_true_flat, y_pred_flat));
|
|
23
|
-
const union = tf.add(tf.sum(y_true_flat), tf.sum(y_pred_flat));
|
|
24
|
-
const dice = tf.div(intersection.mul(2).add(epsilon), union.add(epsilon));
|
|
25
|
-
return tf.scalar(1).sub(dice);
|
|
26
|
-
}
|
|
27
|
-
// prevents minification of function name which TFJS relies on
|
|
28
|
-
Object.defineProperty(diceBinaryGlobal, "name", { value: "diceBinaryGlobal", configurable: false });
|
|
29
|
-
export function diceCategoricalStandard(y_true, y_pred) {
|
|
30
|
-
const intersection = tf.sum(tf.mul(y_true, y_pred), REDUCE_HW);
|
|
31
|
-
const union = tf.add(y_true, y_pred).sum(REDUCE_HW);
|
|
32
|
-
const dice = tf.div(intersection.mul(2).add(epsilon), union.add(epsilon));
|
|
33
|
-
return tf.scalar(1).sub(tf.mean(dice, -1));
|
|
34
|
-
}
|
|
35
|
-
// prevents minification of function name which TFJS relies on
|
|
36
|
-
Object.defineProperty(diceCategoricalStandard, "name", { value: "diceCategoricalStandard", configurable: false });
|
|
37
|
-
export function diceCategoricalGeneralized(y_true, y_pred) {
|
|
38
|
-
// this is done twice so we calculate it once
|
|
39
|
-
const y_true_sum = y_true.sum(REDUCE_BHW);
|
|
40
|
-
const weighting = tf.div(1, y_true_sum.square().add(epsilon));
|
|
41
|
-
const intersection = tf.sum(tf.mul(y_true, y_pred), REDUCE_BHW).mul(weighting).sum();
|
|
42
|
-
const union = tf.add(y_true_sum, y_pred.sum(REDUCE_BHW)).mul(weighting).sum();
|
|
43
|
-
const dice = tf.div(intersection.mul(2).add(epsilon), union.add(epsilon));
|
|
44
|
-
return tf.scalar(1).sub(dice);
|
|
45
|
-
}
|
|
46
|
-
// prevents minification of function name which TFJS relies on
|
|
47
|
-
Object.defineProperty(diceCategoricalGeneralized, "name", { value: "diceCategoricalGeneralized", configurable: false });
|
|
48
|
-
export function diceCategoricalGlobal(y_true, y_pred) {
|
|
49
|
-
const intersection = tf.sum(tf.mul(y_true, y_pred), REDUCE_BHWC);
|
|
50
|
-
const union = tf.add(tf.sum(y_true, REDUCE_BHWC), tf.sum(y_pred, REDUCE_BHWC));
|
|
51
|
-
const dice = tf.div(intersection.mul(2).add(epsilon), union.add(epsilon));
|
|
52
|
-
return tf.scalar(1).sub(dice);
|
|
53
|
-
}
|
|
54
|
-
// prevents minification of function name which TFJS relies on
|
|
55
|
-
Object.defineProperty(diceCategoricalGlobal, "name", { value: "diceCategoricalGlobal", configurable: false });
|
|
56
|
-
/**
|
|
57
|
-
* Calculates the Sorensen-Dice coefficient and the binary cross entropy losses.
|
|
58
|
-
* Both have equal weight.
|
|
59
|
-
*
|
|
60
|
-
* @param y_true the label tensor
|
|
61
|
-
* @param y_pred the prediction tensor (not sparse)
|
|
62
|
-
* @returns a tensor of shape `[ batch ]`
|
|
63
|
-
*/
|
|
64
|
-
export function diceBinaryCrossentropy(y_true, y_pred) {
|
|
65
|
-
// reduce cross entropy shape from [B, H, W] to [B] to match dice
|
|
66
|
-
const bce = binaryCrossentropy(y_true, y_pred).mean(REDUCE_HW);
|
|
67
|
-
const dice = diceBinaryStandard(y_true, y_pred);
|
|
68
|
-
return tf.add(bce.mul(0.5), dice.mul(0.5));
|
|
69
|
-
}
|
|
70
|
-
// prevents minification of function name which TFJS relies on
|
|
71
|
-
Object.defineProperty(diceBinaryCrossentropy, "name", { value: "diceBinaryCrossentropy", configurable: false });
|
|
72
|
-
/**
|
|
73
|
-
* Calculates the Sorensen-Dice coefficient and the categorical cross entropy losses.
|
|
74
|
-
* Both have equal weight. Expects dense (non-sparse) label tensors.
|
|
75
|
-
*
|
|
76
|
-
* This does not support sparse tensors because TFJS's
|
|
77
|
-
* sparseCategoricalCrossentropy loss onehots the label
|
|
78
|
-
* and calls categoricalCrossentropy. See
|
|
79
|
-
* https://github.com/tensorflow/tfjs/blob/0fc04d958ea592f3b8db79a8b3b497b5c8904097/tfjs-layers/src/losses.ts#L143-L146
|
|
80
|
-
*
|
|
81
|
-
* @param y_true the label
|
|
82
|
-
* @param y_pred the prediction tensor (not sparse)
|
|
83
|
-
* @returns a tensor of shape `[ batch ]`
|
|
84
|
-
*/
|
|
85
|
-
export function diceCategoricalCrossentropy(y_true, y_pred) {
|
|
86
|
-
// reduce cross entropy shape from [B, H, W] to [B] to match dice
|
|
87
|
-
const cce = categoricalCrossentropy(y_true, y_pred).mean(REDUCE_HW);
|
|
88
|
-
const dice = diceCategoricalStandard(y_true, y_pred);
|
|
89
|
-
return tf.add(cce.mul(0.5), dice.mul(0.5));
|
|
90
|
-
}
|
|
91
|
-
// prevents minification of function name which TFJS relies on
|
|
92
|
-
Object.defineProperty(diceCategoricalCrossentropy, "name", { value: "diceCategoricalCrossentropy", configurable: false });
|
|
93
|
-
//# sourceMappingURL=dice.js.map
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"dice.js","sourceRoot":"","sources":["../../../src/losses/dice.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AACvC,OAAO,EAAE,uBAAuB,EAAE,kBAAkB,EAAE,MAAM,qCAAqC,CAAC;AAElG,MAAM,OAAO,GAAG,IAAI,CAAC;AAErB,MAAM,SAAS,GAAG,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,+BAA+B;AACzD,MAAM,UAAU,GAAG,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,mCAAmC;AACjE,MAAM,WAAW,GAAG,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,wBAAwB;AAG1D,gCAAgC;AAChC,MAAM,UAAU,kBAAkB,CAAC,MAAiB,EAAE,MAAiB;IAEnE,MAAM,WAAW,GAAG,EAAE,CAAC,OAAO,CAAC,MAAM,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC;IAC9D,MAAM,WAAW,GAAG,EAAE,CAAC,OAAO,CAAC,MAAM,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC;IAE9D,MAAM,YAAY,GAAG,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,WAAW,EAAE,WAAW,CAAC,EAAE,CAAC,CAAC,CAAC;IACjE,MAAM,KAAK,GAAG,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,WAAW,EAAE,CAAC,CAAC,EAAE,EAAE,CAAC,GAAG,CAAC,WAAW,EAAE,CAAC,CAAC,CAAC,CAAC;IAErE,MAAM,IAAI,GAAG,EAAE,CAAC,GAAG,CACf,YAAY,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,CAAC,EAChC,KAAK,CAAC,GAAG,CAAC,OAAO,CAAC,CACrB,CAAC;IAEF,OAAO,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC;AAClC,CAAC;AAGD,8DAA8D;AAC9D,MAAM,CAAC,cAAc,CAAC,kBAAkB,EAAE,MAAM,EAAE,EAAE,KAAK,EAAE,oBAAoB,EAAE,YAAY,EAAE,KAAK,EAAE,CAAC,CAAC;AAGxG,yFAAyF;AACzF,MAAM,UAAU,gBAAgB,CAAC,MAAiB,EAAE,MAAiB;IACjE,MAAM,WAAW,GAAG,EAAE,CAAC,OAAO,CAAC,MAAM,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;IAC7C,MAAM,WAAW,GAAG,EAAE,CAAC,OAAO,CAAC,MAAM,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;IAE7C,MAAM,YAAY,GAAG,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,WAAW,EAAE,WAAW,CAAC,CAAC,CAAC;IAC9D,MAAM,KAAK,GAAG,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,WAAW,CAAC,EAAE,EAAE,CAAC,GAAG,CAAC,WAAW,CAAC,CAAC,CAAC;IAE/D,MAAM,IAAI,GAAG,EAAE,CAAC,GAAG,CACf,YAAY,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,CAAC,EAChC,KAAK,CAAC,GAAG,CAAC,OAAO,CAAC,CACrB,CAAC;IAEF,OAAO,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC;AAClC,CAAC;AAGD,8DAA8D;AAC9D,MAAM,CAA
C,cAAc,CAAC,gBAAgB,EAAE,MAAM,EAAE,EAAE,KAAK,EAAE,kBAAkB,EAAE,YAAY,EAAE,KAAK,EAAE,CAAC,CAAC;AAGpG,MAAM,UAAU,uBAAuB,CAAC,MAAiB,EAAE,MAAiB;IACxE,MAAM,YAAY,GAAG,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,MAAM,EAAE,MAAM,CAAC,EAAE,SAAS,CAAC,CAAC;IAC/D,MAAM,KAAK,GAAG,EAAE,CAAC,GAAG,CAAC,MAAM,EAAE,MAAM,CAAC,CAAC,GAAG,CAAC,SAAS,CAAC,CAAC;IAEpD,MAAM,IAAI,GAAG,EAAE,CAAC,GAAG,CACf,YAAY,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,CAAC,EAChC,KAAK,CAAC,GAAG,CAAC,OAAO,CAAC,CACrB,CAAC;IAEF,OAAO,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,EAAE,CAAC,IAAI,CAAC,IAAI,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC;AAC/C,CAAC;AAGD,8DAA8D;AAC9D,MAAM,CAAC,cAAc,CAAC,uBAAuB,EAAE,MAAM,EAAE,EAAE,KAAK,EAAE,yBAAyB,EAAE,YAAY,EAAE,KAAK,EAAE,CAAC,CAAC;AAGlH,MAAM,UAAU,0BAA0B,CAAC,MAAiB,EAAE,MAAiB;IAE3E,6CAA6C;IAC7C,MAAM,UAAU,GAAG,MAAM,CAAC,GAAG,CAAC,UAAU,CAAC,CAAC;IAE1C,MAAM,SAAS,GAAG,EAAE,CAAC,GAAG,CAAC,CAAC,EAAE,UAAU,CAAC,MAAM,EAAE,CAAC,GAAG,CAAC,OAAO,CAAC,CAAC,CAAC;IAE9D,MAAM,YAAY,GAAG,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,MAAM,EAAE,MAAM,CAAC,EAAE,UAAU,CAAC,CAAC,GAAG,CAAC,SAAS,CAAC,CAAC,GAAG,EAAE,CAAC;IACrF,MAAM,KAAK,GAAG,EAAE,CAAC,GAAG,CAAC,UAAU,EAAE,MAAM,CAAC,GAAG,CAAC,UAAU,CAAC,CAAC,CAAC,GAAG,CAAC,SAAS,CAAC,CAAC,GAAG,EAAE,CAAC;IAE9E,MAAM,IAAI,GAAG,EAAE,CAAC,GAAG,CACf,YAAY,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,CAAC,EAChC,KAAK,CAAC,GAAG,CAAC,OAAO,CAAC,CACrB,CAAC;IAEF,OAAO,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC;AAClC,CAAC;AAGD,8DAA8D;AAC9D,MAAM,CAAC,cAAc,CAAC,0BAA0B,EAAE,MAAM,EAAE,EAAE,KAAK,EAAE,4BAA4B,EAAE,YAAY,EAAE,KAAK,EAAE,CAAC,CAAC;AAGxH,MAAM,UAAU,qBAAqB,CAAC,MAAiB,EAAE,MAAiB;IAEtE,MAAM,YAAY,GAAG,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,MAAM,EAAE,MAAM,CAAC,EAAE,WAAW,CAAC,CAAC;IACjE,MAAM,KAAK,GAAG,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,MAAM,EAAE,WAAW,CAAC,EAAE,EAAE,CAAC,GAAG,CAAC,MAAM,EAAE,WAAW,CAAC,CAAC,CAAC;IAE/E,MAAM,IAAI,GAAG,EAAE,CAAC,GAAG,CACf,YAAY,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,CAAC,EAChC,KAAK,CAAC,GAAG,CAAC,OAAO,CAAC,CACrB,CAAC;IAEF,OAAO,EAAE,CAAC,MAAM,CAAC,CAA
C,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC;AAClC,CAAC;AAGD,8DAA8D;AAC9D,MAAM,CAAC,cAAc,CAAC,qBAAqB,EAAE,MAAM,EAAE,EAAE,KAAK,EAAE,uBAAuB,EAAE,YAAY,EAAE,KAAK,EAAE,CAAC,CAAC;AAG9G;;;;;;;GAOG;AACH,MAAM,UAAU,sBAAsB,CAAC,MAAiB,EAAE,MAAiB;IACvE,iEAAiE;IACjE,MAAM,GAAG,GAAG,kBAAkB,CAAC,MAAM,EAAE,MAAM,CAAC,CAAC,IAAI,CAAC,SAAS,CAAC,CAAC;IAC/D,MAAM,IAAI,GAAG,kBAAkB,CAAC,MAAM,EAAE,MAAM,CAAC,CAAC;IAEhD,OAAO,EAAE,CAAC,GAAG,CAAC,GAAG,CAAC,GAAG,CAAC,GAAG,CAAC,EAAE,IAAI,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,CAAC;AAC/C,CAAC;AAGD,8DAA8D;AAC9D,MAAM,CAAC,cAAc,CAAC,sBAAsB,EAAE,MAAM,EAAE,EAAE,KAAK,EAAE,wBAAwB,EAAE,YAAY,EAAE,KAAK,EAAE,CAAC,CAAC;AAGhH;;;;;;;;;;;;GAYG;AACH,MAAM,UAAU,2BAA2B,CAAC,MAAiB,EAAE,MAAiB;IAC5E,iEAAiE;IACjE,MAAM,GAAG,GAAG,uBAAuB,CAAC,MAAM,EAAE,MAAM,CAAC,CAAC,IAAI,CAAC,SAAS,CAAC,CAAC;IACpE,MAAM,IAAI,GAAG,uBAAuB,CAAC,MAAM,EAAE,MAAM,CAAC,CAAC;IAErD,OAAO,EAAE,CAAC,GAAG,CAAC,GAAG,CAAC,GAAG,CAAC,GAAG,CAAC,EAAE,IAAI,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,CAAC;AAC/C,CAAC;AAGD,8DAA8D;AAC9D,MAAM,CAAC,cAAc,CAAC,2BAA2B,EAAE,MAAM,EAAE,EAAE,KAAK,EAAE,6BAA6B,EAAE,YAAY,EAAE,KAAK,EAAE,CAAC,CAAC"}
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"index.d.ts","sourceRoot":"","sources":["../../../src/losses/index.ts"],"names":[],"mappings":"AAAA,cAAc,QAAQ,CAAC"}
|
package/dist/src/losses/index.js
DELETED
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"index.js","sourceRoot":"","sources":["../../../src/losses/index.ts"],"names":[],"mappings":"AAAA,cAAc,QAAQ,CAAC"}
|
package/dist/src/masks.d.ts
DELETED
|
@@ -1,20 +0,0 @@
|
|
|
1
|
-
import * as tf from "@tensorflow/tfjs";
|
|
2
|
-
/**
|
|
3
|
-
* Generate a causal mask used in self-attention to prevent tokens from looking
|
|
4
|
-
* ahead. The values in the upper right portion of the mask matrix are set to
|
|
5
|
-
* -1e7 so that they have no impact during scaled dot product attention.
|
|
6
|
-
*/
|
|
7
|
-
export declare function causal(query_seq_length: number, key_seq_length: number): tf.Tensor<tf.Rank>;
|
|
8
|
-
/**
|
|
9
|
-
* Generate a self-attention mask that prevents packed sequences from crossing document
|
|
10
|
-
* boundaries and attending to each other. The result is a tensor of diagonally
|
|
11
|
-
* positioned blocks of zeroes (allow attention) and -1e7 values (prevent attention).
|
|
12
|
-
* The latter is scored zero during the scaled dot product attention's softmax operation.
|
|
13
|
-
*
|
|
14
|
-
* @param boundaries an array of ones (denotes start of a new sample or document) and zeroes
|
|
15
|
-
*
|
|
16
|
-
* Example boundary of 3 samples that are packed into one:
|
|
17
|
-
* `[1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0]`
|
|
18
|
-
*/
|
|
19
|
-
export declare function packing(boundaries: Int32Array): tf.Tensor<tf.Rank>;
|
|
20
|
-
//# sourceMappingURL=masks.d.ts.map
|
package/dist/src/masks.d.ts.map
DELETED
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"masks.d.ts","sourceRoot":"","sources":["../../src/masks.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAGvC;;;;GAIG;AACH,wBAAgB,MAAM,CAAC,gBAAgB,EAAE,MAAM,EAAE,cAAc,EAAE,MAAM,sBAItE;AAGD;;;;;;;;;;GAUG;AACH,wBAAgB,OAAO,CAAC,UAAU,EAAE,UAAU,sBAc7C"}
|
package/dist/src/masks.js
DELETED
|
@@ -1,37 +0,0 @@
|
|
|
1
|
-
import * as tf from "@tensorflow/tfjs";
|
|
2
|
-
/**
|
|
3
|
-
* Generate a causal mask used in self-attention to prevent tokens from looking
|
|
4
|
-
* ahead. The values in the upper right portion of the mask matrix are set to
|
|
5
|
-
* -1e7 so that they have no impact during scaled dot product attention.
|
|
6
|
-
*/
|
|
7
|
-
export function causal(query_seq_length, key_seq_length) {
|
|
8
|
-
return tf.linalg.bandPart(tf.ones([query_seq_length, key_seq_length]), -1, 0)
|
|
9
|
-
.sub(1)
|
|
10
|
-
.mul(1e7);
|
|
11
|
-
}
|
|
12
|
-
/**
|
|
13
|
-
* Generate a self-attention mask that prevents packed sequences from crossing document
|
|
14
|
-
* boundaries and attending to each other. The result is a tensor of diagonally
|
|
15
|
-
* positioned blocks of zeroes (allow attention) and -1e7 values (prevent attention).
|
|
16
|
-
* The latter is scored zero during the scaled dot product attention's softmax operation.
|
|
17
|
-
*
|
|
18
|
-
* @param boundaries an array of ones (denotes start of a new sample or document) and zeroes
|
|
19
|
-
*
|
|
20
|
-
* Example boundary of 3 samples that are packed into one:
|
|
21
|
-
* `[1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0]`
|
|
22
|
-
*/
|
|
23
|
-
export function packing(boundaries) {
|
|
24
|
-
// see images at
|
|
25
|
-
// https://reddit.com/r/LocalLLaMA/comments/197efaz/training_llama_mistral_and_mixtralmoe_faster_with/
|
|
26
|
-
return tf.tidy(() => {
|
|
27
|
-
// cumsum transforms the tensor such that each sequence in the pack gets its own id.
|
|
28
|
-
const partitions = tf.tensor1d(boundaries).cumsum();
|
|
29
|
-
return partitions.expandDims(1)
|
|
30
|
-
.equal(partitions.expandDims(0))
|
|
31
|
-
.sub(1)
|
|
32
|
-
.mul(1e7)
|
|
33
|
-
// introduce a head dimension so it can be broadcasted
|
|
34
|
-
.expandDims(0);
|
|
35
|
-
});
|
|
36
|
-
}
|
|
37
|
-
//# sourceMappingURL=masks.js.map
|
package/dist/src/masks.js.map
DELETED
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"masks.js","sourceRoot":"","sources":["../../src/masks.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAGvC;;;;GAIG;AACH,MAAM,UAAU,MAAM,CAAC,gBAAwB,EAAE,cAAsB;IACnE,OAAO,EAAE,CAAC,MAAM,CAAC,QAAQ,CAAC,EAAE,CAAC,IAAI,CAAC,CAAC,gBAAgB,EAAE,cAAc,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC;SACxE,GAAG,CAAC,CAAC,CAAC;SACN,GAAG,CAAC,GAAG,CAAC,CAAC;AAClB,CAAC;AAGD;;;;;;;;;;GAUG;AACH,MAAM,UAAU,OAAO,CAAC,UAAsB;IAC1C,gBAAgB;IAChB,sGAAsG;IACtG,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;QAChB,oFAAoF;QACpF,MAAM,UAAU,GAAG,EAAE,CAAC,QAAQ,CAAC,UAAU,CAAC,CAAC,MAAM,EAAE,CAAC;QAEpD,OAAO,UAAU,CAAC,UAAU,CAAC,CAAC,CAAC;aAC1B,KAAK,CAAC,UAAU,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC;aAC/B,GAAG,CAAC,CAAC,CAAC;aACN,GAAG,CAAC,GAAG,CAAC;YACT,sDAAsD;aACrD,UAAU,CAAC,CAAC,CAAC,CAAC;IACvB,CAAC,CAAC,CAAA;AACN,CAAC"}
|
package/dist/src/metrics.d.ts
DELETED
|
@@ -1,20 +0,0 @@
|
|
|
1
|
-
import { Tensor } from "@tensorflow/tfjs";
|
|
2
|
-
/**
|
|
3
|
-
* Applies the recall metric with the prediction rounded based on a threshold
|
|
4
|
-
*
|
|
5
|
-
* @param y_true the label tensor
|
|
6
|
-
* @param y_pred the prediction tensor
|
|
7
|
-
* @param threshold threshold value to be considered a positive prediction, defaults to `0.5`
|
|
8
|
-
* @returns
|
|
9
|
-
*/
|
|
10
|
-
export declare function recall(y_true: Tensor, y_pred: Tensor, threshold?: number): Tensor<import("@tensorflow/tfjs-core").Rank>;
|
|
11
|
-
/**
|
|
12
|
-
* Applies the precision metric with the prediction rounded based on a threshold
|
|
13
|
-
*
|
|
14
|
-
* @param y_true the label tensor
|
|
15
|
-
* @param y_pred the prediction tensor
|
|
16
|
-
* @param threshold threshold value to be considered a positive prediction, defaults to `0.5`
|
|
17
|
-
* @returns
|
|
18
|
-
*/
|
|
19
|
-
export declare function precision(y_true: Tensor, y_pred: Tensor, threshold?: number): Tensor<import("@tensorflow/tfjs-core").Rank>;
|
|
20
|
-
//# sourceMappingURL=metrics.d.ts.map
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"metrics.d.ts","sourceRoot":"","sources":["../../src/metrics.ts"],"names":[],"mappings":"AAAA,OAAO,EAAW,MAAM,EAAE,MAAM,kBAAkB,CAAC;AAGnD;;;;;;;GAOG;AACH,wBAAgB,MAAM,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,EAAE,MAAM,EAAE,SAAS,GAAE,MAAY,gDAE7E;AAKD;;;;;;;GAOG;AACH,wBAAgB,SAAS,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,EAAE,MAAM,EAAE,SAAS,GAAE,MAAY,gDAEhF"}
|
package/dist/src/metrics.js
DELETED
|
@@ -1,28 +0,0 @@
|
|
|
1
|
-
import { metrics } from "@tensorflow/tfjs";
|
|
2
|
-
/**
|
|
3
|
-
* Applies the recall metric with the prediction rounded based on a threshold
|
|
4
|
-
*
|
|
5
|
-
* @param y_true the label tensor
|
|
6
|
-
* @param y_pred the prediction tensor
|
|
7
|
-
* @param threshold threshold value to be considered a positive prediction, defaults to `0.5`
|
|
8
|
-
* @returns
|
|
9
|
-
*/
|
|
10
|
-
export function recall(y_true, y_pred, threshold = 0.5) {
|
|
11
|
-
return metrics.recall(y_true, y_pred.greaterEqual(threshold));
|
|
12
|
-
}
|
|
13
|
-
// prevents minification of function name which TFJS relies on
|
|
14
|
-
Object.defineProperty(recall, "name", { value: "recall", configurable: false });
|
|
15
|
-
/**
|
|
16
|
-
* Applies the precision metric with the prediction rounded based on a threshold
|
|
17
|
-
*
|
|
18
|
-
* @param y_true the label tensor
|
|
19
|
-
* @param y_pred the prediction tensor
|
|
20
|
-
* @param threshold threshold value to be considered a positive prediction, defaults to `0.5`
|
|
21
|
-
* @returns
|
|
22
|
-
*/
|
|
23
|
-
export function precision(y_true, y_pred, threshold = 0.5) {
|
|
24
|
-
return metrics.precision(y_true, y_pred.greaterEqual(threshold));
|
|
25
|
-
}
|
|
26
|
-
// prevents minification of function name which TFJS relies on
|
|
27
|
-
Object.defineProperty(precision, "name", { value: "precision", configurable: false });
|
|
28
|
-
//# sourceMappingURL=metrics.js.map
|
package/dist/src/metrics.js.map
DELETED
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"metrics.js","sourceRoot":"","sources":["../../src/metrics.ts"],"names":[],"mappings":"AAAA,OAAO,EAAE,OAAO,EAAU,MAAM,kBAAkB,CAAC;AAGnD;;;;;;;GAOG;AACH,MAAM,UAAU,MAAM,CAAC,MAAc,EAAE,MAAc,EAAE,YAAoB,GAAG;IAC1E,OAAO,OAAO,CAAC,MAAM,CAAC,MAAM,EAAE,MAAM,CAAC,YAAY,CAAC,SAAS,CAAC,CAAC,CAAC;AAClE,CAAC;AAED,8DAA8D;AAC9D,MAAM,CAAC,cAAc,CAAC,MAAM,EAAE,MAAM,EAAE,EAAE,KAAK,EAAE,QAAQ,EAAE,YAAY,EAAE,KAAK,EAAE,CAAC,CAAC;AAEhF;;;;;;;GAOG;AACH,MAAM,UAAU,SAAS,CAAC,MAAc,EAAE,MAAc,EAAE,YAAoB,GAAG;IAC7E,OAAO,OAAO,CAAC,SAAS,CAAC,MAAM,EAAE,MAAM,CAAC,YAAY,CAAC,SAAS,CAAC,CAAC,CAAC;AACrE,CAAC;AAED,8DAA8D;AAC9D,MAAM,CAAC,cAAc,CAAC,SAAS,EAAE,MAAM,EAAE,EAAE,KAAK,EAAE,WAAW,EAAE,YAAY,EAAE,KAAK,EAAE,CAAC,CAAC"}
|
|
@@ -1,94 +0,0 @@
|
|
|
1
|
-
import * as tf from "@tensorflow/tfjs";
|
|
2
|
-
import { type LossOrMetricFn } from "@/tfjs_types";
|
|
3
|
-
import { LlmModel, type LlmModelArgs } from "@/models/llm_model";
|
|
4
|
-
import { KvCacheContainer } from "@/kv_cache";
|
|
5
|
-
import { type DisposeResult } from "@tensorflow/tfjs-layers/dist/engine/topology";
|
|
6
|
-
export interface GptModelArgs extends LlmModelArgs {
|
|
7
|
-
/**
|
|
8
|
-
* Number of heads per attention layer.
|
|
9
|
-
*/
|
|
10
|
-
numHeads: number;
|
|
11
|
-
/**
|
|
12
|
-
* Number of GPT decoder blocks.
|
|
13
|
-
*/
|
|
14
|
-
numLayers: number;
|
|
15
|
-
/**
|
|
16
|
-
* The embedding size of each token.
|
|
17
|
-
*/
|
|
18
|
-
embedDim: number;
|
|
19
|
-
/**
|
|
20
|
-
* The vocabulary size of the embedding layer and number of units of the output
|
|
21
|
-
* layer. This is also the tokenizer vocabulary size.
|
|
22
|
-
*/
|
|
23
|
-
vocabSize: number;
|
|
24
|
-
/**
|
|
25
|
-
* Pad the embeddings' vocab size and output layer's units to the next nearest
|
|
26
|
-
* multiple of 64 to optimize hardware efficiency. Defaults to `true`.
|
|
27
|
-
*
|
|
28
|
-
* For example: if a tokenizer has 50,257 tokens, the model uses 50,304 for the
|
|
29
|
-
* vocab size and output units count.
|
|
30
|
-
*/
|
|
31
|
-
padToMultipleOf64?: boolean;
|
|
32
|
-
}
|
|
33
|
-
/**
|
|
34
|
-
* This is a subclass of tf.Sequential that creates a GPT-like model and
|
|
35
|
-
* automatically handles padding (and masking) the vocab size for hardware
|
|
36
|
-
* efficiency.
|
|
37
|
-
*
|
|
38
|
-
* Example:
|
|
39
|
-
*
|
|
40
|
-
* ```javascript
|
|
41
|
-
*
|
|
42
|
-
* const model = new GptModel({ numLayers: 1, numHeads: 1, embedDim: 16, vocabSize: 64 });
|
|
43
|
-
* model.compile({ loss: "sparseCategoricalCrossentropy", optimizer: "adam" });
|
|
44
|
-
*
|
|
45
|
-
* // use fitDataset() instead of fit for masking support
|
|
46
|
-
* model.fitDataset(your_batched_generator_dataset, { epochs: 1 });
|
|
47
|
-
*
|
|
48
|
-
* const kv_cache = new KvCacheContainer(your_preferred_max_sequence_length);
|
|
49
|
-
*
|
|
50
|
-
* // use generate() and predictNextToken() instead of predict() for masking and auto memory cleanup
|
|
51
|
-
* model.generate(tokenized_tensor1d_input, kv_cache, onPredict_callback)
|
|
52
|
-
*
|
|
53
|
-
*
|
|
54
|
-
* ```
|
|
55
|
-
*/
|
|
56
|
-
export declare class GptModel extends LlmModel {
|
|
57
|
-
static className: string;
|
|
58
|
-
protected readonly numHeads: number;
|
|
59
|
-
protected readonly numLayers: number;
|
|
60
|
-
protected readonly embedDim: number;
|
|
61
|
-
protected readonly vocabSize: number;
|
|
62
|
-
protected readonly padToMultipleOf64: boolean;
|
|
63
|
-
protected readonly vocabSizePadded: number;
|
|
64
|
-
protected vocab_padding_mask?: tf.Tensor1D;
|
|
65
|
-
/**
|
|
66
|
-
* DO NOT add layers in the constructor or it will break tf.loadLayersModel().
|
|
67
|
-
* It should be done in build() instead.
|
|
68
|
-
*/
|
|
69
|
-
constructor(args: GptModelArgs);
|
|
70
|
-
protected fitBatch(xs: tf.Tensor, ys: tf.Tensor, loss_mask: tf.Tensor | undefined, loss_function: LossOrMetricFn, other_masks?: {
|
|
71
|
-
[key: string]: tf.Tensor | undefined;
|
|
72
|
-
}): {
|
|
73
|
-
y_pred: tf.Tensor<tf.Rank>;
|
|
74
|
-
loss: tf.Scalar;
|
|
75
|
-
};
|
|
76
|
-
/**
|
|
77
|
-
* Overrides LlmModel.predictNextToken to add softmax before argMax because the final
|
|
78
|
-
* dense layer doesn't have an activation.
|
|
79
|
-
*
|
|
80
|
-
* TODO: implement temperature and multinomial sampling so that the model has varied outputs
|
|
81
|
-
*/
|
|
82
|
-
predictNextToken(input: tf.Tensor2D, kv_cache: KvCacheContainer): tf.Tensor2D;
|
|
83
|
-
build(inputShape?: tf.Shape | tf.Shape[]): void;
|
|
84
|
-
dispose(): DisposeResult;
|
|
85
|
-
getConfig(): {
|
|
86
|
-
numHeads: number;
|
|
87
|
-
numLayers: number;
|
|
88
|
-
embedDim: number;
|
|
89
|
-
vocabSize: number;
|
|
90
|
-
vocabSizePadded: number;
|
|
91
|
-
padToMultipleOf64: boolean;
|
|
92
|
-
};
|
|
93
|
-
}
|
|
94
|
-
//# sourceMappingURL=gpt_model.d.ts.map
|