@stellarapp/tfjs-stellar 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/jest.config.ts +203 -0
- package/package.json +24 -0
- package/src/index.ts +93 -0
- package/src/kv_cache.ts +205 -0
- package/src/layers/cached_rope_multihead_attention.test.ts +59 -0
- package/src/layers/cached_rope_multihead_attention.ts +113 -0
- package/src/layers/gpt_decoder_block.ts +77 -0
- package/src/layers/multihead_attention.test.ts +212 -0
- package/src/layers/multihead_attention.ts +371 -0
- package/src/layers/positional_encoding.test.ts +113 -0
- package/src/layers/positional_encoding.ts +158 -0
- package/src/layers/rotary_position_embedding.test.ts +107 -0
- package/src/layers/rotary_position_embedding.ts +163 -0
- package/src/layers/token_and_positional_embedding.test.ts +81 -0
- package/src/layers/token_and_positional_embedding.ts +149 -0
- package/src/layers/transformer_decoder.test.ts +100 -0
- package/src/layers/transformer_decoder.ts +236 -0
- package/src/layers/transformer_encoder.test.ts +85 -0
- package/src/layers/transformer_encoder.ts +224 -0
- package/src/losses/dice.ts +156 -0
- package/src/losses/index.ts +1 -0
- package/src/metrics.ts +32 -0
- package/src/models/gpt_model.ts +232 -0
- package/src/models/index.ts +2 -0
- package/src/models/llm_model.ts +355 -0
- package/src/models/u_net.ts +240 -0
- package/src/packing_mask.ts +28 -0
- package/src/testing.ts +1 -0
- package/src/tfjs_types.ts +15 -0
- package/src/utils.test.ts +101 -0
- package/src/utils.ts +86 -0
- package/tsconfig.json +49 -0
|
@@ -0,0 +1,240 @@
|
|
|
1
|
+
import * as tf from "@tensorflow/tfjs";
|
|
2
|
+
import { type ActivationIdentifier } from "@tensorflow/tfjs-layers/dist/keras_format/activation_config";
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
/**
 * Architecture options of a U-Net segmentation model.
 */
export interface UNetArgs {
  /**
   * The number of filters of the first contraction level. Each deeper level
   * uses twice as many filters as the level above it.
   */
  filters: number;
  /**
   * The number of categories. For binary segmentation, `units=1`.
   */
  units: number;
  /**
   * The activation of the final output convolution layer. Defaults to `sigmoid` if `units=1`, else `softmax`.
   */
  activation?: ActivationIdentifier;
  /**
   * The depth of the U-Net, i.e. the number of contractions and the number of expansions.
   */
  depth: number;
  /**
   * Adds residual connections to transform the model into a ResUNet. Defaults to `false`.
   */
  residual?: boolean;
  /**
   * Adds batch normalization to convolutions. Best used for batched inputs. Defaults to `false`.
   */
  batchNorm?: boolean;
  /**
   * The unbatched input shape of the U-Net in the format `[height, width, channels]`.
   * Defaults to `[null, null, 3]`. Height and width may be `null` (unknown); only
   * channels is mandatory. A non-null height or width must be divisible by `2^depth`.
   */
  inputShape?: [number | null, number | null, number];
}
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
/**
 * Constructor arguments of `UNetModel`: the U-Net architecture options combined
 * with the `tf.Sequential` options, minus `layers` (the model supplies its own).
 */
export type UNetModelArgs = UNetArgs & Omit<tf.SequentialArgs, "layers">;
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
/**
 * A `tf.Sequential` whose single layer is the functional U-Net returned by
 * `createUNet`. The sequential is named `unet_model` unless a name is given.
 */
export class UNetModel extends tf.Sequential {

  /**
   * @param args the U-Net architecture options plus any `tf.Sequential` options
   */
  constructor(args: UNetModelArgs) {
    // split the U-Net options (with their defaults) from the sequential options
    const {
      filters,
      units,
      depth,
      activation = units == 1 ? "sigmoid" : "softmax",
      residual = false,
      batchNorm = false,
      inputShape = [null, null, 3],
      ...sequentialArgs
    } = args;

    const backbone = createUNet({ filters, units, activation, depth, residual, batchNorm, inputShape });

    super({
      ...sequentialArgs,
      name: sequentialArgs.name ?? "unet_model",
      // calling user should not modify the layers after instantiation
      layers: [backbone]
    });
  }


  /**
   * Prints the summary of this sequential wrapper, followed by the summary of
   * the wrapped U-Net (the first and only layer).
   */
  override summary(lineLength?: number, positions?: number[], printFn?: (message?: any, ...optionalParams: any[]) => void): void {
    super.summary(lineLength, positions, printFn);

    const [backbone] = this.layers;
    (backbone as tf.LayersModel).summary(lineLength, positions, printFn);
  }
}
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
/**
 * Build a functional U-Net: `depth` contraction blocks (each followed by a 2x2
 * max-pool), a bottleneck block, `depth` expansion blocks that consume the
 * contraction outputs in reverse order, and a final 1x1 output convolution.
 *
 * @param args the U-Net options; `activation` defaults to `sigmoid` when
 * `units` is 1 and to `softmax` otherwise
 * @returns a `tf.LayersModel` named `u_net`
 * @throws Error when `units < 1`, when `depth` is not an integer >= 1, or when a
 * non-null input height/width is not divisible by `2^depth`
 */
export function createUNet({ filters, depth, units, activation, residual = false, batchNorm = false, inputShape = [null, null, 3] }: UNetModelArgs) {
  if (units < 1) {
    throw Error(`createUNet: units should be >= 1, got ${units}`);
  }

  // without at least one level there is no filter size to derive the bottleneck from
  if (!Number.isInteger(depth) || depth < 1) {
    throw Error(`createUNet: depth should be an integer >= 1, got ${depth}`);
  }

  const [image_height, image_width] = inputShape;
  const divisible_by = 2 ** depth;

  // every level halves height and width, so known dimensions must survive `depth` halvings
  if ((image_height != null && image_height % divisible_by !== 0) ||
    (image_width != null && image_width % divisible_by !== 0)) {
    throw Error(`createUNet: the input height and width must be divisible by 2^depth (${divisible_by})`)
  }

  const input = tf.input({ shape: inputShape });

  const skip_connection: tf.SymbolicTensor[] = [];

  let x = input;

  // calculate the filter sizes for each level
  const filter_sizes = Array.from({ length: depth }, (_, i) => filters * (2 ** i));

  for (const filter_size of filter_sizes) {
    const contraction = contractionBlock(x, filter_size, residual, batchNorm, `contraction-f${filter_size}`);

    x = tf.layers.maxPooling2d({ poolSize: 2, strides: 2, name: `pool-f${filter_size}` }).apply(contraction) as tf.SymbolicTensor;
    // keep the pre-pool output for the matching expansion block
    skip_connection.push(contraction);
  }

  // the bottleneck uses twice the deepest level's filters: filters * 2^(depth-1) * 2
  x = contractionBlock(x, filters * divisible_by, residual, batchNorm, "bottleneck");

  // expand from the deepest level back up to the first
  for (let i = filter_sizes.length - 1; i >= 0; i--) {
    x = expansionBlock(x, skip_connection[i], filter_sizes[i], residual, batchNorm, `expansion-f${filter_sizes[i]}`);
  }

  const output = tf.layers.conv2d({
    filters: units,
    kernelSize: 1,
    padding: "same",
    activation: activation ?? (units == 1 ? "sigmoid" : "softmax"),
    name: "output-conv"
  }).apply(x) as tf.SymbolicTensor;

  return tf.model({ inputs: input, outputs: output, name: "u_net" });
}
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
/**
 * Load a saved layers model and return it as a model object created by
 * `createUNet`.
 *
 * A placeholder U-Net is built, and every own enumerable property of the loaded
 * model except `name` is copied onto it, so the returned object keeps the
 * placeholder's name but otherwise carries the loaded model's state.
 *
 * @param pathOrIOHandler forwarded to `tf.loadLayersModel`
 * @param options forwarded to `tf.loadLayersModel`
 * @returns the placeholder model object, overwritten with the loaded state
 */
export async function loadUNetModel(pathOrIOHandler: string | tf.io.IOHandler, options?: tf.io.LoadOptions) {
  const model = await tf.loadLayersModel(pathOrIOHandler, options);
  const unet = createUNet({ depth: 1, filters: 4, units: 1 }); // these are dummy args that are overwritten

  // release the placeholder's own weights before the references to its layers
  // are overwritten, otherwise they can never be disposed
  unet.dispose();

  // `name` is pulled out only so that it is excluded from the copied properties
  const { name, ...rest } = model;
  Object.assign(unet, rest);

  return unet;
}
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
/**
 * The contraction block of a U-Net
 *
 * Conv > BN > ReLU > Conv > BN + residual > ReLU
 *
 * With `residual`, the block input is added to the second convolution's output
 * before the last ReLU. The input first goes through a 1x1 convolution when its
 * channel count differs from `filters`, and through batch normalization when
 * `batchNorm` is set.
 *
 * TODO: for residual, change order to (BN > ReLU > Conv)x2 + residual
 *
 * @param x a previous layer's symbolic output
 * @param filters the number of filters of both 3x3 convolutions
 * @param residual includes a residual connection
 * @param batchNorm applies batch normalization before ReLU activation
 * @param name a unique name for the contraction block, used as layer name prefix
 */
function contractionBlock(x: tf.SymbolicTensor, filters: number, residual: boolean, batchNorm: boolean, name: string) {

  // the two 3x3 convolutions differ only by their stage number in the name;
  // the bias is dropped when a batch normalization follows
  const makeConv = (stage: number) => tf.layers.conv2d({
    filters,
    kernelSize: 3,
    padding: "same",
    useBias: !batchNorm,
    kernelInitializer: "heNormal",
    name: `${name}-${stage}-conv2d`
  });

  // batch-normalizes `t` when enabled, otherwise passes it through untouched
  const normalize = (t: tf.SymbolicTensor, suffix: string) => {
    if (!batchNorm) {
      return t;
    }
    return tf.layers.batchNormalization({ name: `${name}-${suffix}` }).apply(t) as tf.SymbolicTensor;
  };

  const first_conv = makeConv(1);
  const first_relu = tf.layers.reLU({ name: `${name}-1-relu` });

  const second_conv = makeConv(2);
  const second_relu = tf.layers.reLU({ name: `${name}-2-relu` });

  let out = normalize(first_conv.apply(x) as tf.SymbolicTensor, "1-batchnorm");
  out = first_relu.apply(out) as tf.SymbolicTensor;
  out = normalize(second_conv.apply(out) as tf.SymbolicTensor, "2-batchnorm");

  if (residual) {
    const input_channels = x.shape[x.shape.length - 1];
    let shortcut = x;

    if (input_channels != filters) {
      // a 1x1 convolution on the input to ensure the residual connection's
      // channels/filters dim matches the convolution output
      shortcut = tf.layers.conv2d({
        filters,
        kernelSize: 1,
        padding: "same",
        useBias: !batchNorm,
        kernelInitializer: "heNormal",
        name: `${name}-residual`
      }).apply(x) as tf.SymbolicTensor;
    }

    shortcut = normalize(shortcut, "residual-batchnorm");

    out = tf.layers.add().apply([shortcut, out]) as tf.SymbolicTensor;
  }

  return second_relu.apply(out) as tf.SymbolicTensor;
}
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
/**
 * The expansion block of a U-Net
 *
 * Upconv + skip > contraction block
 *
 * `x` is upsampled by a 2x2 transposed convolution with stride 2, concatenated
 * with `skip` along the last (channels) axis, then passed through a
 * contraction block.
 *
 * @param x a previous layer's symbolic output
 * @param skip the corresponding contraction block's output (before pool); it is
 * concatenated with the upsampled `x`, so the two must agree on every axis but the last
 * @param filters the number of filters of the up-convolution and of the following contraction block
 * @param residual includes a residual connection
 * @param batchNorm apply batch normalization, should be `false` when batch size is `1`
 * @param name a unique name for the expansion block, used as layer name prefix
 */
function expansionBlock(x: tf.SymbolicTensor, skip: tf.SymbolicTensor, filters: number, residual: boolean, batchNorm: boolean, name: string) {

  const upsampled = tf.layers.conv2dTranspose({
    filters,
    padding: "same",
    kernelSize: 2,
    strides: 2,
    kernelInitializer: "heNormal",
    name: `${name}-upconv`
  }).apply(x) as tf.SymbolicTensor;

  const merged = tf.layers.concatenate({
    axis: -1,
    name: `${name}-concat-upconv-skip`
  }).apply([upsampled, skip]) as tf.SymbolicTensor;

  return contractionBlock(merged, filters, residual, batchNorm, name);
}
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
import * as tf from "@tensorflow/tfjs";
|
|
2
|
+
|
|
3
|
+
/**
 * Generate a self-attention mask that prevents packed sequences from crossing
 * document boundaries and attending to each other. The result is a tensor of
 * diagonally positioned blocks of zeroes (allow attention) surrounded by -1e7
 * values (prevent attention). The latter is scored zero during the scaled dot
 * product attention's softmax operation.
 *
 * @param boundaries an array of ones (denotes start of a new sample or document) and zeroes
 *
 * Example boundary of 3 samples packed into one:
 * `[1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0]`
 *
 * @returns a tensor of shape `[1, n, n]`, where `n` is the length of `boundaries`
 */
export function generatePackingSelfAttentionMask(boundaries: Int32Array) {
  // see images at
  // https://reddit.com/r/LocalLLaMA/comments/197efaz/training_llama_mistral_and_mixtralmoe_faster_with/
  return tf.tidy(() => {
    // the running sum of the boundary markers gives every position the id of
    // the sequence it belongs to
    const sequence_ids = tf.tensor1d(boundaries).cumsum();

    const ids_as_column = sequence_ids.expandDims(1);
    const ids_as_row = sequence_ids.expandDims(0);

    // true wherever two positions belong to the same sequence
    const same_sequence = ids_as_column.equal(ids_as_row);

    // true > 0 and false > -1e7
    const mask = same_sequence.sub(1).mul(1e7);

    // introduce a head dimension so it can be broadcasted
    return mask.expandDims(0);
  });
}
|
package/src/testing.ts
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
// NOTE(review): the only statement of this module; it logs "test" when the
// module is loaded. Looks like a leftover debugging stub — confirm whether it
// is meant to be published.
console.log("test")
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
import type { Tensor } from "@tensorflow/tfjs";
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
/**
 * Type-only declaration of an asynchronous iterator over elements of type `T`.
 *
 * NOTE(review): presumably mirrors the iterator type of TFJS datasets so that
 * it can be referenced without importing it — confirm.
 */
export declare abstract class LazyIterator<T> {
  /**
   * @returns a promise of the next iteration result
   */
  abstract next(): Promise<IteratorResult<T>>;
}
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
/**
 * Type-only declaration of a dataset of elements of type `T`.
 *
 * NOTE(review): presumably mirrors the TFJS dataset type so that it can be
 * referenced without importing it — confirm.
 */
export declare abstract class Dataset<T> {
  /**
   * The size of the dataset.
   */
  size: number;

  /**
   * @returns a promise of an iterator over the dataset's elements
   */
  abstract iterator(): Promise<LazyIterator<T>>;
}
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
/**
 * The signature shared by loss and metric functions: takes the ground truth
 * tensor and the predicted tensor, and returns a tensor.
 */
export type LossOrMetricFn = (yTrue: Tensor, yPred: Tensor) => Tensor;
|
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
import * as tf from "@tensorflow/tfjs";
|
|
2
|
+
import { getScaleShape, getRandomCropStart, generateCausalAttentionMask } from "@/utils";
|
|
3
|
+
|
|
4
|
+
// avoid TFJS node message during Jest testing
tf.env().set('IS_NODE', false);


describe("test custom TFJS utility functions", () => {

  /**
   * Draws many random crop starts and checks that each one is a valid slice
   * start: within `[0, image - target]` on both axes, and 0 on the channel axis.
   */
  function expectCropStartsInRange(img_size: [number, number], target_size: [number, number]) {
    for (let i = 0; i < 100; i++) {
      const [crop_start_h, crop_start_w, crop_start_c] = getRandomCropStart(img_size, target_size);

      expect(crop_start_h).toBeGreaterThanOrEqual(0);
      expect(crop_start_h).toBeLessThanOrEqual(img_size[0] - target_size[0]);

      expect(crop_start_w).toBeGreaterThanOrEqual(0);
      expect(crop_start_w).toBeLessThanOrEqual(img_size[1] - target_size[1]);

      expect(crop_start_c).toEqual(0);
    }
  }


  test("crop an image using the same shape, results in same shape", async () => {
    // cropping an image of the same shape leaves a single possible start
    const img_size: [number, number] = [133, 84];
    const target_size: [number, number] = [133, 84];

    expect(getRandomCropStart(img_size, target_size)).toEqual([0, 0, 0]);
  });


  it("should throw when crop is larger than image", async () => {
    expect(() => getRandomCropStart([128, 128], [1000, 2000])).toThrow();
  })


  test("cropped image shape", async () => {
    // shapes are [height, width]: cropping a wide area out of a tall image
    expectCropStartsInRange([4923, 832], [333, 739]);

    // cropping a tall area out of a wide image
    expectCropStartsInRange([381, 999], [300, 157]);
  });


  test("scale 1:1, results in the same shape", async () => {
    const scale = getScaleShape([256, 256], [256, 256])
    expect(scale).toEqual([256, 256]);
  });


  test("scaled image shape", async () => {
    // scaling squares result in squares
    const scale1 = getScaleShape([256, 256], [128, 128])
    expect(scale1).toEqual([128, 128]);

    const scale2 = getScaleShape([128, 128], [256, 256])
    expect(scale2).toEqual([256, 256]);

    const scale3 = getScaleShape([123, 123], [321, 321])
    expect(scale3).toEqual([321, 321]);

    const scale4 = getScaleShape([321, 321], [123, 123])
    expect(scale4).toEqual([123, 123]);

    // scaling rectangles result in rectangles
    const scale5 = getScaleShape([640, 480], [1280, 960])
    expect(scale5).toEqual([1280, 960]);

    const scale6 = getScaleShape([480, 640], [960, 1280])
    expect(scale6).toEqual([960, 1280]);

    const [scale7_h, scale7_w] = getScaleShape([777, 555], [555, 333])
    expect(scale7_h).toBeGreaterThan(scale7_w);

    const [scale8_h, scale8_w] = getScaleShape([555, 777], [333, 555])
    expect(scale8_h).toBeLessThan(scale8_w);
  });


  test("causal attention map", async () => {
    const seq_len = 4;
    const causal_mask = generateCausalAttentionMask(seq_len, seq_len);

    const _ = -1e7;
    const expected_mask = tf.tensor([
      [0, _, _, _],
      [0, 0, _, _],
      [0, 0, 0, _],
      [0, 0, 0, 0]
    ]);

    // compare by the largest absolute difference: a plain sum of differences
    // lets errors of opposite sign cancel each other out.
    // this might fail due to precision issues on the masked positions,
    // in which case compare against a small tolerance instead of 0
    const max_abs_diff = tf.tidy(() => causal_mask.sub(expected_mask).abs().max());

    expect((await max_abs_diff.data())[0]).toEqual(0);

    tf.dispose([causal_mask, expected_mask, max_abs_diff]);
  });

});
|
package/src/utils.ts
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
import * as tf from "@tensorflow/tfjs";
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
/**
 * Calculate the height and width of an image scaled, with its aspect ratio
 * preserved, so that it covers the target shape. The edge that needs the
 * largest scale factor ends up matching its target edge; the other edge might
 * end up larger than intended.
 *
 * @param image_shape the shape of the image, starting with `[height, width]`
 * @param target_shape the intended `[height, width]` of the final scaled image
 * @returns the rounded `[height, width]` of the scaled image
 * @throws Error when the image height or width is missing, `null` or not positive
 */
export function getScaleShape(image_shape: tf.Shape, target_shape: [number, number]): [scaled_height: number, scaled_width: number] {
  const [img_height, img_width] = image_shape;
  const [target_height, target_width] = target_shape;

  // an unknown or empty edge would turn the scale factor into NaN or Infinity
  if (img_height == null || img_width == null || !(img_height > 0) || !(img_width > 0)) {
    throw Error(`getScaleShape: the image height and width must be positive numbers, got [${img_height}, ${img_width}]`);
  }

  // scale based on whichever target_edge / original_edge is largest,
  // we need the following to be true (1)
  //      height * scale >= target_height
  //      width * scale >= target_width
  // rearranging to get an equivalent requirement (2)
  //      scale >= target_height / height
  //      scale >= target_width / width
  // by picking the scale value that's largest of the two, we satisfy (2), and therefore (1)
  // it may be more intuitive to think of scale as scale_h and scale_w
  const scale_factor = Math.max(target_height / img_height, target_width / img_width);

  return [Math.round(img_height * scale_factor), Math.round(img_width * scale_factor)];
}
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
/**
 * Calculate a random starting point for a crop (slice) operation
 * on an image tensor with the shape `[height, width, channels]`.
 *
 * @param image_shape the `[height, width]` of the image
 * @param target_shape the intended `[height, width]` of the final cropped image
 * @returns the `[height, width, channels]` start of the crop; the channels
 * start is always 0
 * @throws Error when the crop is larger than the image on either axis
 */
export function getRandomCropStart(
  image_shape: [height: number, width: number],
  target_shape: [height: number, width: number]
): [number, number, number] {
  const [img_height, img_width] = image_shape;
  const [crop_height, crop_width] = target_shape;

  if (img_height < crop_height || img_width < crop_width) {
    throw Error(`getRandomCropStart: cannot crop with a size that's bigger than` +
      ` the image. Original [${img_height}, ${img_width}], crop [${crop_height}, ${crop_width}].`);
  }

  // the last valid start on each axis, where the crop touches the far edge
  const max_start_h = img_height - crop_height;
  const max_start_w = img_width - crop_width;

  // there's a +1 because Math.random()'s range is [0, 1), excluding 1:
  // floor(random * (max + 1)) yields every integer from 0 to max inclusive,
  // whereas without it the last valid start could never be picked
  return [
    Math.floor(Math.random() * (max_start_h + 1)),
    Math.floor(Math.random() * (max_start_w + 1)),
    0 // not cropping channels, so it starts at the first index
  ]
}
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
/**
 * Calculate how much padding the height and the width of an image need
 * so that both become divisible by 2^depth.
 *
 * In U-Net image segmentation, the contraction and concatenate
 * operations require the input image's height and width
 * dimensions to be divisible by 2^depth.
 *
 * @param image the image tensor; only the first two entries of its shape are read
 * @param depth the depth of the U-Net
 * @returns the `[height, width]` amounts to add; 0 for an edge that is already divisible
 */
export function getPaddingForSegmentation(image: tf.Tensor3D, depth: number): [height: number, width: number] {
  const multiple = Math.pow(2, depth);

  // distance from `edge` up to the next multiple of 2^depth
  const paddingFor = (edge: number) => (Math.ceil(edge / multiple)) * multiple - edge;

  const image_height = image.shape[0];
  const image_width = image.shape[1];

  return [paddingFor(image_height), paddingFor(image_width)];
}
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
/**
 * Generate a causal attention mask of shape `[query_seq_length, key_seq_length]`.
 * The lower triangle (diagonal included) holds zeroes and every position above
 * the diagonal holds -1e7.
 *
 * @param query_seq_length the number of rows of the mask
 * @param key_seq_length the number of columns of the mask
 */
export function generateCausalAttentionMask(query_seq_length: number, key_seq_length: number) {
  return tf.tidy(() => {
    const all_ones = tf.ones([query_seq_length, key_seq_length]);

    // keep the whole lower triangle (-1) and nothing above the diagonal (0)
    const lower_triangle = tf.linalg.bandPart(all_ones, -1, 0);

    // 1 > 0 and 0 > -1e7
    return lower_triangle.sub(1).mul(1e7);
  });
}
|
package/tsconfig.json
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
{
|
|
2
|
+
// Visit https://aka.ms/tsconfig to read more about this file
|
|
3
|
+
"compilerOptions": {
|
|
4
|
+
// File Layout
|
|
5
|
+
// "rootDir": "./src",
|
|
6
|
+
// "outDir": "./dist",
|
|
7
|
+
// Environment Settings
|
|
8
|
+
// See also https://aka.ms/tsconfig/module
|
|
9
|
+
"module": "es2022",
|
|
10
|
+
"target": "esnext",
|
|
11
|
+
"types": ["jest"],
|
|
12
|
+
// For nodejs:
|
|
13
|
+
// "lib": ["esnext"],
|
|
14
|
+
// "types": ["node"],
|
|
15
|
+
// and npm install -D @types/node
|
|
16
|
+
// Other Outputs
|
|
17
|
+
"sourceMap": true,
|
|
18
|
+
"declaration": true,
|
|
19
|
+
"declarationMap": true,
|
|
20
|
+
// Stricter Typechecking Options
|
|
21
|
+
//"noUncheckedIndexedAccess": true,
|
|
22
|
+
"exactOptionalPropertyTypes": true,
|
|
23
|
+
// Style Options
|
|
24
|
+
// "noImplicitReturns": true,
|
|
25
|
+
// "noImplicitOverride": true,
|
|
26
|
+
// "noUnusedLocals": true,
|
|
27
|
+
// "noUnusedParameters": true,
|
|
28
|
+
// "noFallthroughCasesInSwitch": true,
|
|
29
|
+
// "noPropertyAccessFromIndexSignature": true,
|
|
30
|
+
// Recommended Options
|
|
31
|
+
"strict": true,
|
|
32
|
+
"jsx": "react-jsx",
|
|
33
|
+
//"verbatimModuleSyntax": true,
|
|
34
|
+
"isolatedModules": true,
|
|
35
|
+
"noUncheckedSideEffectImports": true,
|
|
36
|
+
"moduleDetection": "force",
|
|
37
|
+
"skipLibCheck": true,
|
|
38
|
+
"paths": {
|
|
39
|
+
"@/*": [
|
|
40
|
+
"./src/*"
|
|
41
|
+
],
|
|
42
|
+
"e2e/*": [
|
|
43
|
+
"./e2e/*"
|
|
44
|
+
]
|
|
45
|
+
},
|
|
46
|
+
"moduleResolution": "bundler",
|
|
47
|
+
"esModuleInterop": true
|
|
48
|
+
}
|
|
49
|
+
}
|