@stellarapp/tfjs-stellar 1.0.4 → 1.0.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +17 -0
- package/dist/index.d.ts +2 -1
- package/dist/index.d.ts.map +1 -1
- package/dist/index.js +2 -1
- package/dist/index.js.map +1 -1
- package/dist/kv_cache.d.ts +2 -0
- package/dist/kv_cache.d.ts.map +1 -1
- package/dist/kv_cache.js +6 -0
- package/dist/kv_cache.js.map +1 -1
- package/dist/models/index.d.ts +2 -1
- package/dist/models/index.d.ts.map +1 -1
- package/dist/models/index.js +2 -1
- package/dist/models/index.js.map +1 -1
- package/package.json +1 -1
- package/dist/jest.config.d.ts +0 -8
- package/dist/jest.config.d.ts.map +0 -1
- package/dist/jest.config.js +0 -147
- package/dist/jest.config.js.map +0 -1
- package/dist/src/index.d.ts +0 -6
- package/dist/src/index.d.ts.map +0 -1
- package/dist/src/index.js +0 -6
- package/dist/src/index.js.map +0 -1
- package/dist/src/kv_cache.d.ts +0 -53
- package/dist/src/kv_cache.d.ts.map +0 -1
- package/dist/src/kv_cache.js +0 -135
- package/dist/src/kv_cache.js.map +0 -1
- package/dist/src/layers/cached_rope_multihead_attention.d.ts +0 -31
- package/dist/src/layers/cached_rope_multihead_attention.d.ts.map +0 -1
- package/dist/src/layers/cached_rope_multihead_attention.js +0 -76
- package/dist/src/layers/cached_rope_multihead_attention.js.map +0 -1
- package/dist/src/layers/cached_rope_multihead_attention.test.d.ts +0 -2
- package/dist/src/layers/cached_rope_multihead_attention.test.d.ts.map +0 -1
- package/dist/src/layers/cached_rope_multihead_attention.test.js +0 -43
- package/dist/src/layers/cached_rope_multihead_attention.test.js.map +0 -1
- package/dist/src/layers/gpt_decoder_block.d.ts +0 -34
- package/dist/src/layers/gpt_decoder_block.d.ts.map +0 -1
- package/dist/src/layers/gpt_decoder_block.js +0 -51
- package/dist/src/layers/gpt_decoder_block.js.map +0 -1
- package/dist/src/layers/index.d.ts +0 -17
- package/dist/src/layers/index.d.ts.map +0 -1
- package/dist/src/layers/index.js +0 -33
- package/dist/src/layers/index.js.map +0 -1
- package/dist/src/layers/multihead_attention.d.ts +0 -106
- package/dist/src/layers/multihead_attention.d.ts.map +0 -1
- package/dist/src/layers/multihead_attention.js +0 -269
- package/dist/src/layers/multihead_attention.js.map +0 -1
- package/dist/src/layers/multihead_attention.test.d.ts +0 -2
- package/dist/src/layers/multihead_attention.test.d.ts.map +0 -1
- package/dist/src/layers/multihead_attention.test.js +0 -160
- package/dist/src/layers/multihead_attention.test.js.map +0 -1
- package/dist/src/layers/positional_encoding.d.ts +0 -37
- package/dist/src/layers/positional_encoding.d.ts.map +0 -1
- package/dist/src/layers/positional_encoding.js +0 -115
- package/dist/src/layers/positional_encoding.js.map +0 -1
- package/dist/src/layers/positional_encoding.test.d.ts +0 -2
- package/dist/src/layers/positional_encoding.test.d.ts.map +0 -1
- package/dist/src/layers/positional_encoding.test.js +0 -95
- package/dist/src/layers/positional_encoding.test.js.map +0 -1
- package/dist/src/layers/rotary_position_embedding.d.ts +0 -39
- package/dist/src/layers/rotary_position_embedding.d.ts.map +0 -1
- package/dist/src/layers/rotary_position_embedding.js +0 -99
- package/dist/src/layers/rotary_position_embedding.js.map +0 -1
- package/dist/src/layers/rotary_position_embedding.test.d.ts +0 -2
- package/dist/src/layers/rotary_position_embedding.test.d.ts.map +0 -1
- package/dist/src/layers/rotary_position_embedding.test.js +0 -88
- package/dist/src/layers/rotary_position_embedding.test.js.map +0 -1
- package/dist/src/layers/token_and_positional_embedding.d.ts +0 -47
- package/dist/src/layers/token_and_positional_embedding.d.ts.map +0 -1
- package/dist/src/layers/token_and_positional_embedding.js +0 -109
- package/dist/src/layers/token_and_positional_embedding.js.map +0 -1
- package/dist/src/layers/token_and_positional_embedding.test.d.ts +0 -2
- package/dist/src/layers/token_and_positional_embedding.test.d.ts.map +0 -1
- package/dist/src/layers/token_and_positional_embedding.test.js +0 -58
- package/dist/src/layers/token_and_positional_embedding.test.js.map +0 -1
- package/dist/src/layers/transformer_decoder.d.ts +0 -69
- package/dist/src/layers/transformer_decoder.d.ts.map +0 -1
- package/dist/src/layers/transformer_decoder.js +0 -182
- package/dist/src/layers/transformer_decoder.js.map +0 -1
- package/dist/src/layers/transformer_decoder.test.d.ts +0 -2
- package/dist/src/layers/transformer_decoder.test.d.ts.map +0 -1
- package/dist/src/layers/transformer_decoder.test.js +0 -72
- package/dist/src/layers/transformer_decoder.test.js.map +0 -1
- package/dist/src/layers/transformer_encoder.d.ts +0 -55
- package/dist/src/layers/transformer_encoder.d.ts.map +0 -1
- package/dist/src/layers/transformer_encoder.js +0 -175
- package/dist/src/layers/transformer_encoder.js.map +0 -1
- package/dist/src/layers/transformer_encoder.test.d.ts +0 -2
- package/dist/src/layers/transformer_encoder.test.d.ts.map +0 -1
- package/dist/src/layers/transformer_encoder.test.js +0 -58
- package/dist/src/layers/transformer_encoder.test.js.map +0 -1
- package/dist/src/losses/dice.d.ts +0 -30
- package/dist/src/losses/dice.d.ts.map +0 -1
- package/dist/src/losses/dice.js +0 -93
- package/dist/src/losses/dice.js.map +0 -1
- package/dist/src/losses/index.d.ts +0 -2
- package/dist/src/losses/index.d.ts.map +0 -1
- package/dist/src/losses/index.js +0 -2
- package/dist/src/losses/index.js.map +0 -1
- package/dist/src/masks.d.ts +0 -20
- package/dist/src/masks.d.ts.map +0 -1
- package/dist/src/masks.js +0 -37
- package/dist/src/masks.js.map +0 -1
- package/dist/src/metrics.d.ts +0 -20
- package/dist/src/metrics.d.ts.map +0 -1
- package/dist/src/metrics.js +0 -28
- package/dist/src/metrics.js.map +0 -1
- package/dist/src/models/gpt_model.d.ts +0 -94
- package/dist/src/models/gpt_model.d.ts.map +0 -1
- package/dist/src/models/gpt_model.js +0 -154
- package/dist/src/models/gpt_model.js.map +0 -1
- package/dist/src/models/index.d.ts +0 -3
- package/dist/src/models/index.d.ts.map +0 -1
- package/dist/src/models/index.js +0 -3
- package/dist/src/models/index.js.map +0 -1
- package/dist/src/models/llm_model.d.ts +0 -87
- package/dist/src/models/llm_model.d.ts.map +0 -1
- package/dist/src/models/llm_model.js +0 -245
- package/dist/src/models/llm_model.js.map +0 -1
- package/dist/src/models/u_net.d.ts +0 -40
- package/dist/src/models/u_net.d.ts.map +0 -1
- package/dist/src/models/u_net.js +0 -151
- package/dist/src/models/u_net.js.map +0 -1
- package/dist/src/tfjs_types.d.ts +0 -10
- package/dist/src/tfjs_types.d.ts.map +0 -1
- package/dist/src/tfjs_types.js +0 -2
- package/dist/src/tfjs_types.js.map +0 -1
- package/dist/src/utils.d.ts +0 -28
- package/dist/src/utils.d.ts.map +0 -1
- package/dist/src/utils.js +0 -63
- package/dist/src/utils.js.map +0 -1
- package/dist/src/utils.test.d.ts +0 -2
- package/dist/src/utils.test.d.ts.map +0 -1
- package/dist/src/utils.test.js +0 -73
- package/dist/src/utils.test.js.map +0 -1
package/dist/src/models/u_net.js
DELETED
|
@@ -1,151 +0,0 @@
|
|
|
1
|
-
import * as tf from "@tensorflow/tfjs";
|
|
2
|
-
/**
 * A `tf.Sequential` wrapper whose single layer is the functional U-Net built
 * by `createUNet`.
 *
 * `args` carries the `createUNet` options (`filters`, `units`, `activation`,
 * `depth`, `residual`, `batchNorm`, `inputShape`); every remaining key is
 * forwarded to the `tf.Sequential` constructor. The model name defaults to
 * "unet_model". When `activation` is omitted it defaults to "sigmoid" for a
 * single output unit and "softmax" otherwise.
 */
export class UNetModel extends tf.Sequential {
    constructor(args) {
        const {
            filters,
            units,
            activation = units == 1 ? "sigmoid" : "softmax",
            depth,
            residual = false,
            batchNorm = false,
            inputShape = [null, null, 3],
            ...sequentialArgs
        } = args;
        // Built once here; calling code should not modify the layers after instantiation.
        const unet = createUNet({ filters, units, activation, depth, residual, batchNorm, inputShape });
        super({
            ...sequentialArgs,
            name: sequentialArgs.name ?? "unet_model",
            layers: [unet],
        });
    }

    /**
     * Print this wrapper's summary, then the summary of the wrapped U-Net so
     * its individual layers are listed as well. Arguments are forwarded
     * unchanged to both `summary` calls.
     */
    summary(lineLength, positions, printFn) {
        super.summary(lineLength, positions, printFn);
        const [innerUNet] = this.layers;
        innerUNet.summary(lineLength, positions, printFn);
    }
}
|
|
17
|
-
/**
 * Build a functional U-Net `tf.LayersModel` named "u_net".
 *
 * The encoder stacks `depth` contraction blocks with `filters * 2^i` filters
 * (i = 0..depth-1), each followed by a 2x2 max-pool. A bottleneck block with
 * twice the last filter count follows. Then `depth` expansion blocks upsample
 * and concatenate the matching pre-pool encoder output. A final 1x1
 * convolution maps to `units` output channels.
 *
 * @param filters number of filters in the first contraction block
 * @param depth number of contraction/expansion levels; must be an integer >= 1
 * @param units number of output channels; must be an integer >= 1
 * @param activation output activation; defaults to "sigmoid" when `units` is 1, otherwise "softmax"
 * @param residual add residual connections inside every block
 * @param batchNorm apply batch normalization inside every block
 * @param inputShape `[height, width, channels]`; height/width may be `null` (unknown)
 * @returns the functional U-Net model
 * @throws if `units` or `depth` is invalid, or a known height/width is not divisible by `2^depth`
 */
export function createUNet({ filters, depth, units, activation, residual = false, batchNorm = false, inputShape = [null, null, 3] }) {
    // `!Number.isInteger` also rejects undefined/NaN, which `units < 1` let through.
    if (!Number.isInteger(units) || units < 1) {
        throw new Error(`createUNet: units should be >= 1, got ${units}`);
    }
    // depth 0 would leave no encoder levels, making the bottleneck's filter count NaN.
    if (!Number.isInteger(depth) || depth < 1) {
        throw new Error(`createUNet: depth should be an integer >= 1, got ${depth}`);
    }

    const [imageHeight, imageWidth] = inputShape;
    // every pooling level halves height and width, so they must divide evenly 2^depth times
    const divisibleBy = 2 ** depth;
    const isIndivisible = (size) => size != null && size % divisibleBy !== 0;
    if (isIndivisible(imageHeight) || isIndivisible(imageWidth)) {
        throw new Error(`createUNet: the input height and width must be divisible by 2^depth (${divisibleBy})`);
    }

    const input = tf.input({ shape: inputShape });
    const skipConnections = [];
    let x = input;

    // the filter count doubles at each encoder level
    const filterSizes = Array.from({ length: depth }, (_, i) => filters * 2 ** i);

    for (const filterSize of filterSizes) {
        const contraction = contractionBlock(x, filterSize, residual, batchNorm, `contraction-f${filterSize}`);
        x = tf.layers.maxPooling2d({ poolSize: 2, strides: 2, name: `pool-f${filterSize}` }).apply(contraction);
        // the pre-pool output is what the matching expansion block concatenates with
        skipConnections.push(contraction);
    }

    x = contractionBlock(x, filterSizes.at(-1) * 2, residual, batchNorm, "bottleneck");

    // walk back up the levels in reverse order, pairing each with its skip connection
    for (let i = filterSizes.length - 1; i >= 0; i--) {
        x = expansionBlock(x, skipConnections[i], filterSizes[i], residual, batchNorm, `expansion-f${filterSizes[i]}`);
    }

    const output = tf.layers.conv2d({
        filters: units,
        kernelSize: 1,
        padding: "same",
        activation: activation ?? (units === 1 ? "sigmoid" : "softmax"),
        name: "output-conv"
    }).apply(x);

    return tf.model({ inputs: input, outputs: output, name: "u_net" });
}
|
|
50
|
-
/**
 * Load a saved model with `tf.loadLayersModel` and graft its state onto a
 * freshly built U-Net.
 *
 * A placeholder U-Net is built with small dummy arguments. Every own
 * enumerable property of the loaded model except `name` is then copied onto
 * that placeholder. The returned object therefore keeps the placeholder's
 * name while taking over the loaded model's other own properties.
 *
 * NOTE(review): the placeholder's own layers/weights created by `createUNet`
 * are not disposed here — confirm whether this leaks tensor memory.
 *
 * @param pathOrIOHandler forwarded to `tf.loadLayersModel`
 * @param options forwarded to `tf.loadLayersModel`
 * @returns the placeholder model object, mutated in place
 */
export async function loadUNetModel(pathOrIOHandler, options) {
    const loaded = await tf.loadLayersModel(pathOrIOHandler, options);
    // the arguments only need to be valid; the resulting state is overwritten below
    const placeholder = createUNet({ depth: 1, filters: 4, units: 1 });
    const { name: _discardedName, ...loadedState } = loaded;
    return Object.assign(placeholder, loadedState);
}
|
|
57
|
-
/**
 * One U-Net convolution block (the decoder path reuses it as well).
 *
 * Layout: Conv3x3 > [BN] > ReLU > Conv3x3 > [BN] (+ residual) > ReLU
 *
 * With `residual`, the block input is added to the second convolution's
 * output just before the final ReLU. If the input's last-axis size differs
 * from `filters`, the input is first projected with a 1x1 convolution. With
 * `batchNorm` as well, that residual branch is batch-normalized too.
 *
 * Every convolution uses "same" padding and He-normal initialization, and has
 * no bias when `batchNorm` is set.
 *
 * TODO: for residual, change order to (BN > ReLU > Conv)x2 + residual
 *
 * @param x symbolic output of the previous layer
 * @param filters number of filters for every convolution in the block
 * @param residual add a residual connection
 * @param batchNorm insert batch normalization before each ReLU and on the residual branch
 * @param name unique prefix for the names of the layers created here
 * @returns the symbolic output of the block's final ReLU
 */
function contractionBlock(x, filters, residual, batchNorm, name) {
    // Settings shared by every convolution in this block; presumably the bias
    // is dropped because a following batch norm's offset makes it redundant.
    const convDefaults = {
        filters,
        padding: "same",
        useBias: !batchNorm,
        kernelInitializer: "heNormal",
    };

    const firstConv = tf.layers.conv2d({ ...convDefaults, kernelSize: 3, name: `${name}-1-conv2d` });
    const firstRelu = tf.layers.reLU({ name: `${name}-1-relu` });
    const secondConv = tf.layers.conv2d({ ...convDefaults, kernelSize: 3, name: `${name}-2-conv2d` });
    const secondRelu = tf.layers.reLU({ name: `${name}-2-relu` });

    // Applies a freshly named batch-norm layer when enabled, otherwise passes through.
    const maybeBatchNorm = (tensor, suffix) =>
        batchNorm ? tf.layers.batchNormalization({ name: `${name}-${suffix}` }).apply(tensor) : tensor;

    let hidden = maybeBatchNorm(firstConv.apply(x), "1-batchnorm");
    hidden = maybeBatchNorm(secondConv.apply(firstRelu.apply(hidden)), "2-batchnorm");

    if (residual) {
        const inputChannels = x.shape[x.shape.length - 1];
        let shortcut = x;
        if (inputChannels != filters) {
            // 1x1 projection so the shortcut's channel count matches `hidden`
            shortcut = tf.layers.conv2d({ ...convDefaults, kernelSize: 1, name: `${name}-residual` }).apply(x);
        }
        shortcut = maybeBatchNorm(shortcut, "residual-batchnorm");
        hidden = tf.layers.add().apply([shortcut, hidden]);
    }

    return secondRelu.apply(hidden);
}
|
|
125
|
-
/**
 * One U-Net expansion (decoder) step: upsample, merge with the skip
 * connection, then convolve.
 *
 * `x` is upsampled with a 2x2, stride-2 transposed convolution ("same"
 * padding, He-normal initialization). The result is concatenated with `skip`
 * along the last axis and fed through a contraction block.
 *
 * @param x symbolic output of the previous (coarser) decoder level or the bottleneck
 * @param skip output of the matching contraction block (taken before pooling);
 *   it is concatenated with the upsampled `x`, so its height and width must
 *   match the transposed convolution's output rather than `x` itself
 * @param filters filters for the transposed convolution and the following contraction block
 * @param residual add residual connections in the contraction block
 * @param batchNorm apply batch normalization, should be `false` when batch size is `1`
 * @param name unique prefix for the names of the layers created for this expansion block
 * @returns the symbolic output of the contraction block
 */
function expansionBlock(x, skip, filters, residual, batchNorm, name) {
    const upsampled = tf.layers.conv2dTranspose({
        filters,
        padding: "same",
        kernelSize: 2,
        strides: 2,
        kernelInitializer: "heNormal",
        name: `${name}-upconv`
    }).apply(x);
    const merged = tf.layers
        .concatenate({ axis: -1, name: `${name}-concat-upconv-skip` })
        .apply([upsampled, skip]);
    return contractionBlock(merged, filters, residual, batchNorm, name);
}
|
|
151
|
-
//# sourceMappingURL=u_net.js.map
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"u_net.js","sourceRoot":"","sources":["../../../src/models/u_net.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAuCvC,MAAM,OAAO,SAAU,SAAQ,EAAE,CAAC,UAAU;IAExC,YAAY,IAAmB;QAC3B,MAAM,EACF,OAAO,EACP,KAAK,EACL,UAAU,GAAG,KAAK,IAAI,CAAC,CAAC,CAAC,CAAC,SAAS,CAAC,CAAC,CAAC,SAAS,EAC/C,KAAK,EACL,QAAQ,GAAG,KAAK,EAChB,SAAS,GAAG,KAAK,EACjB,UAAU,GAAG,CAAC,IAAI,EAAE,IAAI,EAAE,CAAC,CAAC,EAC5B,GAAG,cAAc,EACpB,GAAG,IAAI,CAAC;QAET,cAAc,CAAC,IAAI,GAAG,cAAc,CAAC,IAAI,IAAI,YAAY,CAAC;QAE1D,KAAK,CAAC;YACF,GAAG,cAAc;YACjB,gEAAgE;YAChE,MAAM,EAAE,CAAC,UAAU,CAAC,EAAE,OAAO,EAAE,KAAK,EAAE,UAAU,EAAE,KAAK,EAAE,QAAQ,EAAE,SAAS,EAAE,UAAU,EAAE,CAAC,CAAC;SAC/F,CAAC,CAAC;IACP,CAAC;IAGQ,OAAO,CAAC,UAAmB,EAAE,SAAoB,EAAE,OAA2D;QACnH,KAAK,CAAC,OAAO,CAAC,UAAU,EAAE,SAAS,EAAE,OAAO,CAAC,CAAC;QAC7C,IAAI,CAAC,MAAM,CAAC,CAAC,CAAoB,CAAC,OAAO,CAAC,UAAU,EAAE,SAAS,EAAE,OAAO,CAAC,CAAC;IAC/E,CAAC;CACJ;AAGD,MAAM,UAAU,UAAU,CAAC,EAAE,OAAO,EAAE,KAAK,EAAE,KAAK,EAAE,UAAU,EAAE,QAAQ,GAAG,KAAK,EAAE,SAAS,GAAG,KAAK,EAAE,UAAU,GAAG,CAAC,IAAI,EAAE,IAAI,EAAE,CAAC,CAAC,EAAiB;IAC9I,IAAI,KAAK,GAAG,CAAC,EAAE,CAAC;QACZ,MAAM,KAAK,CAAC,yCAAyC,KAAK,EAAE,CAAC,CAAC;IAClE,CAAC;IAED,MAAM,CAAC,YAAY,EAAE,WAAW,CAAC,GAAG,UAAU,CAAC;IAC/C,MAAM,WAAW,GAAG,CAAC,IAAI,KAAK,CAAC;IAE/B,IAAI,CAAC,YAAY,IAAI,IAAI,IAAI,YAAY,GAAG,WAAW,IAAI,CAAC,CAAC;QACzD,WAAW,IAAI,IAAI,IAAI,WAAW,GAAG,WAAW,IAAI,CAAC,EAAE,CAAC;QACxD,MAAM,KAAK,CAAC,wEAAwE,WAAW,GAAG,CAAC,CAAA;IACvG,CAAC;IAED,MAAM,KAAK,GAAG,EAAE,CAAC,KAAK,CAAC,EAAE,KAAK,EAAE,UAAU,EAAE,CAAC,CAAC;IAE9C,MAAM,eAAe,GAAwB,EAAE,CAAC;IAEhD,IAAI,CAAC,GAAG,KAAK,CAAC;IAEd,4CAA4C;IAC5C,MAAM,YAAY,GAAG,KAAK,CAAC,IAAI,CAAC,EAAE,MAAM,EAAE,KAAK,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,OAAO,GAAG,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;IAEjF,KAAK,MAAM,WAAW,IAAI,YAAY,EAAE,CAAC;QACrC,MAAM,WAAW,GAAG,gBAAgB,CAAC,CAAC,EAAE,WAAW,EAAE,QAAQ,EAAE,SAAS,EAAE,gBAAgB,WAAW,EAAE,CAAC,CAAC;QAEzG,CAAC,GAAG,EAAE,CAAC,MAAM,CAAC,YAAY,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,OAAO,EAAE,CAAC,EAAE,IAAI,EAAE,SAAS,WAAW,EAAE,EAAE,CAAC,CA
AC,KAAK,CAAC,WAAW,CAAsB,CAAC;QAC9H,eAAe,CAAC,IAAI,CAAC,WAAW,CAAC,CAAC;IACtC,CAAC;IAED,CAAC,GAAG,gBAAgB,CAAC,CAAC,EAAE,YAAY,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,GAAG,CAAC,EAAE,QAAQ,EAAE,SAAS,EAAE,YAAY,CAAC,CAAC;IAErF,KAAK,IAAI,CAAC,GAAG,YAAY,CAAC,MAAM,GAAG,CAAC,EAAE,CAAC,IAAI,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC;QAChD,CAAC,GAAG,cAAc,CAAC,CAAC,EAAE,eAAe,CAAC,CAAC,CAAC,EAAE,YAAY,CAAC,CAAC,CAAC,EAAE,QAAQ,EAAE,SAAS,EAAE,cAAc,YAAY,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC;IACrH,CAAC;IAED,MAAM,MAAM,GAAG,EAAE,CAAC,MAAM,CAAC,MAAM,CAAC;QAC5B,OAAO,EAAE,KAAK;QACd,UAAU,EAAE,CAAC;QACb,OAAO,EAAE,MAAM;QACf,UAAU,EAAE,UAAU,IAAI,CAAC,KAAK,IAAI,CAAC,CAAC,CAAC,CAAC,SAAS,CAAC,CAAC,CAAC,SAAS,CAAC;QAC9D,IAAI,EAAE,aAAa;KACtB,CAAC,CAAC,KAAK,CAAC,CAAC,CAAsB,CAAC;IAEjC,OAAO,EAAE,CAAC,KAAK,CAAC,EAAE,MAAM,EAAE,KAAK,EAAE,OAAO,EAAE,MAAM,EAAE,IAAI,EAAE,OAAO,EAAE,CAAC,CAAC;AACvE,CAAC;AAGD,MAAM,CAAC,KAAK,UAAU,aAAa,CAAC,eAAyC,EAAE,OAA2B;IACtG,MAAM,KAAK,GAAG,MAAM,EAAE,CAAC,eAAe,CAAC,eAAe,EAAE,OAAO,CAAC,CAAC;IACjE,MAAM,IAAI,GAAG,UAAU,CAAC,EAAE,KAAK,EAAE,CAAC,EAAE,OAAO,EAAE,CAAC,EAAE,KAAK,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,4CAA4C;IACzG,MAAM,EAAE,IAAI,EAAE,GAAG,IAAI,EAAE,GAAG,KAAK,CAAC;IAChC,MAAM,CAAC,MAAM,CAAC,IAAI,EAAE,IAAI,CAAC,CAAC;IAE1B,OAAO,IAAI,CAAC;AAChB,CAAC;AAGD;;;;;;;;;;;;GAYG;AACH,SAAS,gBAAgB,CAAC,CAAoB,EAAE,OAAe,EAAE,QAAiB,EAAE,SAAkB,EAAE,IAAY;IAEhH,MAAM,KAAK,GAAG,EAAE,CAAC,MAAM,CAAC,MAAM,CAAC;QAC3B,OAAO;QACP,UAAU,EAAE,CAAC;QACb,OAAO,EAAE,MAAM;QACf,OAAO,EAAE,CAAC,SAAS;QACnB,iBAAiB,EAAE,UAAU;QAC7B,IAAI,EAAE,GAAG,IAAI,WAAW;KAC3B,CAAC,CAAC;IACH,MAAM,KAAK,GAAG,EAAE,CAAC,MAAM,CAAC,IAAI,CAAC,EAAE,IAAI,EAAE,GAAG,IAAI,SAAS,EAAE,CAAC,CAAC;IAEzD,MAAM,KAAK,GAAG,EAAE,CAAC,MAAM,CAAC,MAAM,CAAC;QAC3B,OAAO;QACP,UAAU,EAAE,CAAC;QACb,OAAO,EAAE,MAAM;QACf,OAAO,EAAE,CAAC,SAAS;QACnB,iBAAiB,EAAE,UAAU;QAC7B,IAAI,EAAE,GAAG,IAAI,WAAW;KAC3B,CAAC,CAAC;IACH,MAAM,KAAK,GAAG,EAAE,CAAC,MAAM,CAAC,IAAI,CAAC,EAAE,IAAI,EAAE,GAAG,IAAI,SAAS,EAAE,CAAC,CAAC;IAEzD,IAAI,OAAO,GAAG,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;IAE7B,IAAI,SAAS,EAAE,CAAC;QACZ,OAAO,GAAG,EAAE,CAAC,M
AAM,CAAC,kBAAkB,CAAC,EAAE,IAAI,EAAE,GAAG,IAAI,cAAc,EAAE,CAAC,CAAC,KAAK,CAAC,OAAO,CAAC,CAAC;IAC3F,CAAC;IAED,OAAO,GAAG,KAAK,CAAC,KAAK,CAAC,OAAO,CAAC,CAAC;IAE/B,OAAO,GAAG,KAAK,CAAC,KAAK,CAAC,OAAO,CAAC,CAAC;IAE/B,IAAI,SAAS,EAAE,CAAC;QACZ,OAAO,GAAG,EAAE,CAAC,MAAM,CAAC,kBAAkB,CAAC,EAAE,IAAI,EAAE,GAAG,IAAI,cAAc,EAAE,CAAC,CAAC,KAAK,CAAC,OAAO,CAAC,CAAC;IAC3F,CAAC;IAED,IAAI,QAAQ,EAAE,CAAC;QACX,IAAI,aAAa,GAAG,CAAC,CAAC;QAEtB,IAAI,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC,IAAI,OAAO,EAAE,CAAC;YACzC,qEAAqE;YACrE,sDAAsD;YACtD,aAAa,GAAG,EAAE,CAAC,MAAM,CAAC,MAAM,CAAC;gBAC7B,OAAO;gBACP,UAAU,EAAE,CAAC;gBACb,OAAO,EAAE,MAAM;gBACf,OAAO,EAAE,CAAC,SAAS;gBACnB,iBAAiB,EAAE,UAAU;gBAC7B,IAAI,EAAE,GAAG,IAAI,WAAW;aAC3B,CAAC,CAAC,KAAK,CAAC,CAAC,CAAsB,CAAC;QACrC,CAAC;QAED,IAAI,SAAS,EAAE,CAAC;YACZ,aAAa,GAAG,EAAE,CAAC,MAAM,CAAC,kBAAkB,CAAC;gBACzC,IAAI,EAAE,GAAG,IAAI,qBAAqB;aACrC,CAAC,CAAC,KAAK,CAAC,aAAa,CAAsB,CAAC;QACjD,CAAC;QAED,OAAO,GAAG,EAAE,CAAC,MAAM,CAAC,GAAG,EAAE,CAAC,KAAK,CAAC;YAC5B,aAAkC;YAClC,OAA4B;SAC/B,CAAC,CAAA;IACN,CAAC;IAED,OAAO,GAAG,KAAK,CAAC,KAAK,CAAC,OAAO,CAAC,CAAC;IAE/B,OAAO,OAA4B,CAAC;AACxC,CAAC;AAGD;;;;;;;;;;;GAWG;AACH,SAAS,cAAc,CAAC,CAAoB,EAAE,IAAuB,EAAE,OAAe,EAAE,QAAiB,EAAE,SAAkB,EAAE,IAAY;IAEvI,MAAM,MAAM,GAAG,EAAE,CAAC,MAAM,CAAC,eAAe,CAAC;QACrC,OAAO;QACP,OAAO,EAAE,MAAM;QACf,UAAU,EAAE,CAAC;QACb,OAAO,EAAE,CAAC;QACV,iBAAiB,EAAE,UAAU;QAC7B,IAAI,EAAE,GAAG,IAAI,SAAS;KACzB,CAAC,CAAC;IAEH,MAAM,MAAM,GAAG,EAAE,CAAC,MAAM,CAAC,WAAW,CAAC,EAAE,IAAI,EAAE,CAAC,CAAC,EAAE,IAAI,EAAE,GAAG,IAAI,qBAAqB,EAAE,CAAC,CAAC;IAEvF,IAAI,OAAO,GAAG,MAAM,CAAC,KAAK,CAAC,CAAC,CAAsB,CAAC;IACnD,OAAO,GAAG,MAAM,CAAC,KAAK,CAAC,CAAC,OAAO,EAAE,IAAI,CAAC,CAAsB,CAAC;IAE7D,OAAO,gBAAgB,CAAC,OAAO,EAAE,OAAO,EAAE,QAAQ,EAAE,SAAS,EAAE,IAAI,CAAC,CAAC;AACzE,CAAC"}
|
package/dist/src/tfjs_types.d.ts
DELETED
|
@@ -1,10 +0,0 @@
|
|
|
1
|
-
import type { Tensor } from "@tensorflow/tfjs";
|
|
2
|
-
/**
 * Abstract asynchronous iterator: each `next()` call resolves to the next
 * `IteratorResult`.
 */
export declare abstract class LazyIterator<T> {
    /** Resolve the next iteration result. */
    abstract next(): Promise<IteratorResult<T>>;
}
/**
 * Abstract dataset: it reports a `size` and hands out a `LazyIterator` over
 * its elements.
 */
export declare abstract class Dataset<T> {
    /** Number of elements in the dataset. */
    size: number;
    /** Resolve an iterator over the dataset's elements. */
    abstract iterator(): Promise<LazyIterator<T>>;
}
/**
 * Signature shared by loss and metric functions: takes the true and the
 * predicted tensors and returns a tensor.
 */
export type LossOrMetricFn = (yTrue: Tensor, yPred: Tensor) => Tensor;
|
|
10
|
-
//# sourceMappingURL=tfjs_types.d.ts.map
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"tfjs_types.d.ts","sourceRoot":"","sources":["../../src/tfjs_types.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,EAAE,MAAM,kBAAkB,CAAC;AAG/C,MAAM,CAAC,OAAO,CAAC,QAAQ,OAAO,YAAY,CAAC,CAAC;IACxC,QAAQ,CAAC,IAAI,IAAI,OAAO,CAAC,cAAc,CAAC,CAAC,CAAC,CAAC;CAC9C;AAGD,MAAM,CAAC,OAAO,CAAC,QAAQ,OAAO,OAAO,CAAC,CAAC;IACnC,QAAQ,CAAC,QAAQ,IAAI,OAAO,CAAC,YAAY,CAAC,CAAC,CAAC,CAAC;IAC7C,IAAI,EAAE,MAAM,CAAC;CAChB;AAGD,MAAM,MAAM,cAAc,GAAG,CAAC,KAAK,EAAE,MAAM,EAAE,KAAK,EAAE,MAAM,KAAK,MAAM,CAAC"}
|
package/dist/src/tfjs_types.js
DELETED
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"tfjs_types.js","sourceRoot":"","sources":["../../src/tfjs_types.ts"],"names":[],"mappings":""}
|
package/dist/src/utils.d.ts
DELETED
|
@@ -1,28 +0,0 @@
|
|
|
1
|
-
import * as tf from "@tensorflow/tfjs";
|
|
2
|
-
/**
 * Compute the `[height, width]` to resize an image to, using one
 * aspect-preserving scale factor. The factor is chosen so that neither edge
 * ends up smaller than the corresponding edge of `target_shape`; the other
 * edge may end up larger than intended.
 *
 * @param image_shape the image shape; its first two entries are `[height, width]`
 * @param target_shape the minimum `[height, width]` of the scaled image
 * @returns the rounded `[scaled_height, scaled_width]`
 */
export declare function getScaleShape(image_shape: tf.Shape, target_shape: [number, number]): [scaled_height: number, scaled_width: number];
/**
 * Pick a random start position for cropping (slicing) an image tensor with
 * shape `[height, width, channels]` down to `target_shape`.
 *
 * @param image_shape the `[height, width]` of the image
 * @param target_shape the intended `[height, width]` of the final cropped image
 * @returns `[start_row, start_col, 0]`; channels are never cropped
 * @throws if the crop is larger than the image along either edge
 */
export declare function getRandomCropStart(image_shape: [height: number, width: number], target_shape: [height: number, width: number]): [number, number, number];
/**
 * Compute how many rows and columns to pad so that the image's height and
 * width both become divisible by 2^depth.
 *
 * U-Net's contraction and concatenate operations require the input image's
 * height and width to be divisible by 2^depth.
 *
 * @param image the image tensor; its shape starts with `[height, width]`
 * @param depth the number of U-Net down-sampling levels
 * @returns `[height_padding, width_padding]`
 */
export declare function getPaddingForSegmentation(image: tf.Tensor3D, depth: number): [height: number, width: number];
|
|
28
|
-
//# sourceMappingURL=utils.d.ts.map
|
package/dist/src/utils.d.ts.map
DELETED
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"utils.d.ts","sourceRoot":"","sources":["../../src/utils.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAGvC;;;;;;;GAOG;AACH,wBAAgB,aAAa,CAAC,WAAW,EAAE,EAAE,CAAC,KAAK,EAAE,YAAY,EAAE,CAAC,MAAM,EAAE,MAAM,CAAC,GAAG,CAAC,aAAa,EAAE,MAAM,EAAE,YAAY,EAAE,MAAM,CAAC,CAelI;AAGD;;;;;;GAMG;AACH,wBAAgB,kBAAkB,CAC9B,WAAW,EAAE,CAAC,MAAM,EAAE,MAAM,EAAE,KAAK,EAAE,MAAM,CAAC,EAC5C,YAAY,EAAE,CAAC,MAAM,EAAE,MAAM,EAAE,KAAK,EAAE,MAAM,CAAC,GAC9C,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,CAAC,CAiB1B;AAGD;;;;;;;GAOG;AACH,wBAAgB,yBAAyB,CAAC,KAAK,EAAE,EAAE,CAAC,QAAQ,EAAE,KAAK,EAAE,MAAM,GAAG,CAAC,MAAM,EAAE,MAAM,EAAE,KAAK,EAAE,MAAM,CAAC,CAS5G"}
|
package/dist/src/utils.js
DELETED
|
@@ -1,63 +0,0 @@
|
|
|
1
|
-
/**
 * Compute the `[height, width]` an image should be resized to, using one
 * aspect-preserving scale factor. The factor is chosen so that neither
 * scaled edge is smaller than its target edge, so one edge may end up larger
 * than intended. Both results are rounded to the nearest integer.
 *
 * @param image_shape the image shape; only the first two entries, `[height, width]`, are read
 * @param target_shape the minimum `[height, width]` of the scaled image
 * @returns `[scaled_height, scaled_width]`
 */
export function getScaleShape(image_shape, target_shape) {
    const [heightIn, widthIn] = image_shape;
    const [heightOut, widthOut] = target_shape;
    // We need both
    //   heightIn * scale >= heightOut   and   widthIn * scale >= widthOut,
    // i.e. scale >= heightOut / heightIn and scale >= widthOut / widthIn.
    // Taking the larger of the two ratios satisfies both inequalities.
    const scale = Math.max(heightOut / heightIn, widthOut / widthIn);
    return [heightIn, widthIn].map((edge) => Math.round(edge * scale));
}
|
|
24
|
-
/**
 * Pick a random start position for cropping (slicing) an image tensor with
 * shape `[height, width, channels]` down to `target_shape`.
 *
 * Every start row in `[0, height - crop_height]` and start column in
 * `[0, width - crop_width]` can be returned. The channel start is always 0
 * because channels are not cropped.
 *
 * @param image_shape the `[height, width]` of the image
 * @param target_shape the intended `[height, width]` of the final cropped image
 * @returns `[start_row, start_col, 0]`
 * @throws if the crop is larger than the image along either edge
 */
export function getRandomCropStart(image_shape, target_shape) {
    const [img_height, img_width] = image_shape;
    const [crop_height, crop_width] = target_shape;
    if (img_height < crop_height || img_width < crop_width) {
        throw new Error(`getRandomCropStart: cannot crop with a size that's bigger than` +
            ` the image. Original [${img_height}, ${img_width}], crop [${crop_height}, ${crop_width}].`);
    }
    // Math.random() lies in [0, 1), so floor(random * (range + 1)) is an
    // integer in 0..range inclusive; the +1 makes the last valid start reachable.
    const randomStart = (range) => Math.floor(Math.random() * (range + 1));
    return [
        randomStart(img_height - crop_height),
        randomStart(img_width - crop_width),
        0, // not cropping channels, so it starts at the first index
    ];
}
|
|
47
|
-
/**
 * Compute how many rows and columns must be added so that an image's height
 * and width both become multiples of 2^depth.
 *
 * In U-Net image segmentation, the contraction and concatenate operations
 * require the input image's height and width to be divisible by 2^depth.
 *
 * @param image a tensor whose `shape` starts with `[height, width]`
 * @param depth the number of U-Net down-sampling levels
 * @returns `[height_padding, width_padding]`
 */
export function getPaddingForSegmentation(image, depth) {
    const multiple = 2 ** depth;
    // distance from `size` up to the next multiple of `multiple` (0 if already aligned)
    const padTo = (size) => Math.ceil(size / multiple) * multiple - size;
    const [height, width] = image.shape;
    return [padTo(height), padTo(width)];
}
|
|
63
|
-
//# sourceMappingURL=utils.js.map
|
package/dist/src/utils.js.map
DELETED
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"utils.js","sourceRoot":"","sources":["../../src/utils.ts"],"names":[],"mappings":"AAGA;;;;;;;GAOG;AACH,MAAM,UAAU,aAAa,CAAC,WAAqB,EAAE,YAA8B;IAC/E,MAAM,CAAC,UAAU,EAAE,SAAS,CAAC,GAAG,WAAuC,CAAC;IACxE,MAAM,CAAC,aAAa,EAAE,YAAY,CAAC,GAAG,YAAY,CAAC;IAEnD,mEAAmE;IACnE,uCAAuC;IACvC,kCAAkC;IAClC,iCAAiC;IACjC,mDAAmD;IACnD,kCAAkC;IAClC,gCAAgC;IAChC,0FAA0F;IAC1F,oEAAoE;IACpE,MAAM,YAAY,GAAG,IAAI,CAAC,GAAG,CAAC,aAAa,GAAG,UAAU,EAAE,YAAY,GAAG,SAAS,CAAC,CAAC;IACpF,OAAO,CAAC,IAAI,CAAC,KAAK,CAAC,UAAU,GAAG,YAAY,CAAC,EAAE,IAAI,CAAC,KAAK,CAAC,SAAS,GAAG,YAAY,CAAC,CAAC,CAAC;AACzF,CAAC;AAGD;;;;;;GAMG;AACH,MAAM,UAAU,kBAAkB,CAC9B,WAA4C,EAC5C,YAA6C;IAE7C,MAAM,CAAC,UAAU,EAAE,SAAS,CAAC,GAAG,WAAW,CAAC;IAC5C,MAAM,CAAC,MAAM,EAAE,MAAM,CAAC,GAAG,YAAY,CAAC;IAEtC,IAAI,UAAU,GAAG,MAAM,IAAI,SAAS,GAAG,MAAM,EAAE,CAAC;QAC5C,MAAM,KAAK,CAAC,iEAAiE;YACzE,yBAAyB,UAAU,KAAK,SAAS,YAAY,MAAM,KAAK,MAAM,IAAI,CAAC,CAAC;IAC5F,CAAC;IAED,qEAAqE;IACrE,qEAAqE;IACrE,OAAO;QACH,uBAAuB;QACvB,IAAI,CAAC,KAAK,CAAC,IAAI,CAAC,MAAM,EAAE,GAAG,CAAC,UAAU,GAAG,MAAM,GAAG,CAAC,CAAC,CAAC;QACrD,IAAI,CAAC,KAAK,CAAC,IAAI,CAAC,MAAM,EAAE,GAAG,CAAC,SAAS,GAAG,MAAM,GAAG,CAAC,CAAC,CAAC;QACpD,CAAC,CAAC,yDAAyD;KAC9D,CAAA;AACL,CAAC;AAGD;;;;;;;GAOG;AACH,MAAM,UAAU,yBAAyB,CAAC,KAAkB,EAAE,KAAa;IACvE,MAAM,SAAS,GAAG,IAAI,CAAC,GAAG,CAAC,CAAC,EAAE,KAAK,CAAC,CAAC;IAErC,MAAM,CAAC,MAAM,EAAE,KAAK,CAAC,GAAG,KAAK,CAAC,KAAK,CAAC;IAEpC,OAAO;QACH,CAAC,IAAI,CAAC,IAAI,CAAC,MAAM,GAAG,SAAS,CAAC,CAAC,GAAG,SAAS,GAAG,MAAM;QACpD,CAAC,IAAI,CAAC,IAAI,CAAC,KAAK,GAAG,SAAS,CAAC,CAAC,GAAG,SAAS,GAAG,KAAK;KACrD,CAAA;AACL,CAAC"}
|
package/dist/src/utils.test.d.ts
DELETED
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"utils.test.d.ts","sourceRoot":"","sources":["../../src/utils.test.ts"],"names":[],"mappings":""}
|
package/dist/src/utils.test.js
DELETED
|
@@ -1,73 +0,0 @@
|
|
|
1
|
-
import * as tf from "@tensorflow/tfjs";
|
|
2
|
-
import { getScaleShape, getRandomCropStart } from "@/utils";
|
|
3
|
-
import { causal } from "@/masks";
|
|
4
|
-
// avoid TFJS node message during Jest testing
tf.env().set('IS_NODE', false);
describe("test custom TFJS utility functions", () => {
    test("crop an image using the same shape, results in same shape", async () => {
        // cropping an image to its own shape: the only valid start is the origin
        const img_size = [133, 84];
        const target_size = [133, 84];
        expect(getRandomCropStart(img_size, target_size)).toEqual([0, 0, 0]);
    });
    it("should throw when crop is larger than image", async () => {
        expect(() => getRandomCropStart([128, 128], [1000, 2000])).toThrow();
    });
    test("cropped image shape", async () => {
        // the crop start must keep the whole crop inside the image
        const cases = [
            // a wide crop taken from a tall image
            { img_size: [4923, 832], target_size: [333, 739] },
            // a tall crop taken from a wide image
            { img_size: [381, 999], target_size: [300, 157] },
        ];
        for (const { img_size, target_size } of cases) {
            for (let i = 0; i < 100; i++) {
                const [crop_start_h, crop_start_w, crop_start_c] = getRandomCropStart(img_size, target_size);
                expect(crop_start_h).toBeGreaterThanOrEqual(0);
                expect(crop_start_h).toBeLessThanOrEqual(img_size[0] - target_size[0]);
                expect(crop_start_w).toBeGreaterThanOrEqual(0);
                expect(crop_start_w).toBeLessThanOrEqual(img_size[1] - target_size[1]);
                expect(crop_start_c).toBe(0);
            }
        }
    });
    test("scale 1:1, results in the same shape", async () => {
        const scale = getScaleShape([256, 256], [256, 256]);
        expect(scale).toEqual([256, 256]);
    });
    test("scaled image shape", async () => {
        // scaling squares result in squares
        const scale1 = getScaleShape([256, 256], [128, 128]);
        expect(scale1).toEqual([128, 128]);
        const scale2 = getScaleShape([128, 128], [256, 256]);
        expect(scale2).toEqual([256, 256]);
        const scale3 = getScaleShape([123, 123], [321, 321]);
        expect(scale3).toEqual([321, 321]);
        const scale4 = getScaleShape([321, 321], [123, 123]);
        expect(scale4).toEqual([123, 123]);
        // scaling rectangles result in rectangles
        const scale5 = getScaleShape([640, 480], [1280, 960]);
        expect(scale5).toEqual([1280, 960]);
        const scale6 = getScaleShape([480, 640], [960, 1280]);
        expect(scale6).toEqual([960, 1280]);
        const [scale7_h, scale7_w] = getScaleShape([777, 555], [555, 333]);
        expect(scale7_h).toBeGreaterThan(scale7_w);
        const [scale8_h, scale8_w] = getScaleShape([555, 777], [333, 555]);
        expect(scale8_h).toBeLessThan(scale8_w);
    });
    test("causal attention map", async () => {
        const seq_len = 4;
        const causal_mask = causal(seq_len, seq_len);
        const _ = -1e7;
        const expected_mask = tf.tensor([
            [0, _, _, _],
            [0, 0, _, _],
            [0, 0, 0, _],
            [0, 0, 0, 0]
        ]);
        // Sum absolute differences so mismatches of opposite sign cannot cancel out.
        // If this becomes flaky due to float precision on the masked positions,
        // compare against a small tolerance instead of exact zero.
        expect((await causal_mask.sub(expected_mask).abs().sum().data())[0]).toEqual(0);
    });
});
|
|
73
|
-
//# sourceMappingURL=utils.test.js.map
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"utils.test.js","sourceRoot":"","sources":["../../src/utils.test.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AACvC,OAAO,EAAE,aAAa,EAAE,kBAAkB,EAAE,MAAM,SAAS,CAAC;AAC5D,OAAO,EAAE,MAAM,EAAE,MAAM,SAAS,CAAC;AAEjC,8CAA8C;AAC9C,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC,SAAS,EAAE,KAAK,CAAC,CAAC;AAG/B,QAAQ,CAAC,oCAAoC,EAAE,GAAG,EAAE;IAEhD,IAAI,CAAC,2DAA2D,EAAE,KAAK,IAAI,EAAE;QACzE,sCAAsC;QACtC,MAAM,QAAQ,GAAG,CAAC,GAAG,EAAE,EAAE,CAAqB,CAAC;QAC/C,MAAM,WAAW,GAAG,CAAC,GAAG,EAAE,EAAE,CAAqB,CAAC;QAElD,MAAM,CAAC,kBAAkB,CAAC,QAAQ,EAAE,WAAW,CAAC,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;IACzE,CAAC,CAAC,CAAC;IAGH,EAAE,CAAC,6CAA6C,EAAE,KAAK,IAAI,EAAE;QACzD,MAAM,CAAC,GAAG,EAAE,CAAC,kBAAkB,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,IAAI,EAAE,IAAI,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;IACzE,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,qBAAqB,EAAE,KAAK,IAAI,EAAE;QACnC,mCAAmC;QACnC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,GAAG,EAAE,CAAC,EAAE,EAAE,CAAC;YAC3B,MAAM,QAAQ,GAAG,CAAC,IAAI,EAAE,GAAG,CAAqB,CAAC;YACjD,MAAM,WAAW,GAAG,CAAC,GAAG,EAAE,GAAG,CAAqB,CAAC;YAEnD,MAAM,CAAC,YAAY,EAAE,YAAY,EAAE,QAAQ,CAAC,GAAG,kBAAkB,CAAC,QAAQ,EAAE,WAAW,CAAC,CAAA;YAExF,MAAM,CAAC,YAAY,CAAC,CAAC,mBAAmB,CAAC,QAAQ,CAAC,CAAC,CAAC,GAAG,WAAW,CAAC,CAAC,CAAC,CAAC,CAAC;YACvE,MAAM,CAAC,YAAY,CAAC,CAAC,mBAAmB,CAAC,QAAQ,CAAC,CAAC,CAAC,GAAG,WAAW,CAAC,CAAC,CAAC,CAAC,CAAC;QAC3E,CAAC;QAED,mCAAmC;QACnC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,GAAG,EAAE,CAAC,EAAE,EAAE,CAAC;YAC3B,MAAM,QAAQ,GAAG,CAAC,GAAG,EAAE,GAAG,CAAqB,CAAC;YAChD,MAAM,WAAW,GAAG,CAAC,GAAG,EAAE,GAAG,CAAqB,CAAC;YAEnD,MAAM,CAAC,YAAY,EAAE,YAAY,EAAE,QAAQ,CAAC,GAAG,kBAAkB,CAAC,QAAQ,EAAE,WAAW,CAAC,CAAA;YAExF,MAAM,CAAC,YAAY,CAAC,CAAC,mBAAmB,CAAC,QAAQ,CAAC,CAAC,CAAC,GAAG,WAAW,CAAC,CAAC,CAAC,CAAC,CAAC;YACvE,MAAM,CAAC,YAAY,CAAC,CAAC,mBAAmB,CAAC,QAAQ,CAAC,CAAC,CAAC,GAAG,WAAW,CAAC,CAAC,CAAC,CAAC,CAAC;QAC3E,CAAC;IACL,CAAC,CAAC,CAAC;IAGH,IAAI,CAAC,sCAAsC,EAAE,KAAK,IAAI,EAAE;QACpD,MAAM,KAAK,GAAG,aAAa,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE
,GAAG,CAAC,CAAC,CAAA;QACnD,MAAM,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAC;IACtC,CAAC,CAAC,CAAC;IAGH,IAAI,CAAC,oBAAoB,EAAE,KAAK,IAAI,EAAE;QAClC,oCAAoC;QACpC,MAAM,MAAM,GAAG,aAAa,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAA;QACpD,MAAM,CAAC,MAAM,CAAC,CAAC,OAAO,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAC;QAEnC,MAAM,MAAM,GAAG,aAAa,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAA;QACpD,MAAM,CAAC,MAAM,CAAC,CAAC,OAAO,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAC;QAEnC,MAAM,MAAM,GAAG,aAAa,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAA;QACpD,MAAM,CAAC,MAAM,CAAC,CAAC,OAAO,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAC;QAEnC,MAAM,MAAM,GAAG,aAAa,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAA;QACpD,MAAM,CAAC,MAAM,CAAC,CAAC,OAAO,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAC;QAEnC,0CAA0C;QAC1C,MAAM,MAAM,GAAG,aAAa,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,IAAI,EAAE,GAAG,CAAC,CAAC,CAAA;QACrD,MAAM,CAAC,MAAM,CAAC,CAAC,OAAO,CAAC,CAAC,IAAI,EAAE,GAAG,CAAC,CAAC,CAAC;QAEpC,MAAM,MAAM,GAAG,aAAa,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,IAAI,CAAC,CAAC,CAAA;QACrD,MAAM,CAAC,MAAM,CAAC,CAAC,OAAO,CAAC,CAAC,GAAG,EAAE,IAAI,CAAC,CAAC,CAAC;QAEpC,MAAM,CAAC,QAAQ,EAAE,QAAQ,CAAC,GAAG,aAAa,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAA;QAClE,MAAM,CAAC,QAAQ,CAAC,CAAC,eAAe,CAAC,QAAQ,CAAC,CAAC;QAE3C,MAAM,CAAC,QAAQ,EAAE,QAAQ,CAAC,GAAG,aAAa,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAA;QAClE,MAAM,CAAC,QAAQ,CAAC,CAAC,YAAY,CAAC,QAAQ,CAAC,CAAC;IAC5C,CAAC,CAAC,CAAC;IAGH,IAAI,CAAC,sBAAsB,EAAE,KAAK,IAAI,EAAE;QACpC,MAAM,OAAO,GAAG,CAAC,CAAC;QAClB,MAAM,WAAW,GAAG,MAAM,CAAC,OAAO,EAAE,OAAO,CAAC,CAAC;QAE7C,MAAM,CAAC,GAAG,CAAC,GAAG,CAAC;QACf,MAAM,aAAa,GAAG,EAAE,CAAC,MAAM,CAAC;YAC5B,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC;YACZ,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC;YACZ,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC;YACZ,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC;SACf,CAAC,C
AAC;QAEH,mEAAmE;QACnE,uEAAuE;QACvE,MAAM,CAAC,CAAC,MAAM,WAAW,CAAC,GAAG,CAAC,aAAa,CAAC,CAAC,GAAG,EAAE,CAAC,IAAI,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC;IAC9E,CAAC,CAAC,CAAC;AAEP,CAAC,CAAC,CAAC"}
|