@stellarapp/tfjs-stellar 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,240 @@
1
+ import * as tf from "@tensorflow/tfjs";
2
+ import { type ActivationIdentifier } from "@tensorflow/tfjs-layers/dist/keras_format/activation_config";
3
+
4
+
/**
 * Architecture options consumed by `createUNet` (and, through it, by `UNetModel`).
 */
export interface UNetArgs {
    /**
     * The starting number of filters. `createUNet` doubles it at every deeper level.
     */
    filters: number;
    /**
     * The number of categories, used as the filter count of the output convolution.
     * For binary segmentation, `units=1`.
     */
    units: number;
    /**
     * The activation of the final output convolution layer. Defaults to `sigmoid` if `units=1`, else `softmax`.
     */
    activation?: ActivationIdentifier;
    /**
     * The depth of the U-Net, i.e. the number of contractions and the (equal) number of expansions.
     */
    depth: number;
    /**
     * Adds residual connections to transform the model into a ResUNet. Defaults to `false`.
     */
    residual?: boolean;
    /**
     * Adds batch normalization to convolutions. Best used for batched inputs. Defaults to `false`.
     */
    batchNorm?: boolean;
    /**
     * The unbatched input shape of the U-Net in the format `[height, width, channels]`.
     * Defaults to `[null, null, 3]`. Height and width may be `null`; channels is mandatory.
     * When height or width is given, `createUNet` requires it to be divisible by `2^depth`.
     */
    inputShape?: [number | null, number | null, number];
}


/**
 * Constructor arguments of `UNetModel`: the U-Net options combined with every
 * `tf.Sequential` option except `layers`.
 */
export type UNetModelArgs = UNetArgs & Omit<tf.SequentialArgs, "layers">;
38
+
39
+
/**
 * A `tf.Sequential` whose only layer is the functional U-Net returned by `createUNet`.
 */
export class UNetModel extends tf.Sequential {

    /**
     * @param args the U-Net options plus any `tf.Sequential` option except `layers`;
     * `name` falls back to `"unet_model"` when not provided
     */
    constructor(args: UNetModelArgs) {
        const {
            filters,
            units,
            depth,
            // `units` is destructured above, so it can drive this default
            activation = units == 1 ? "sigmoid" : "softmax",
            residual = false,
            batchNorm = false,
            inputShape = [null, null, 3],
            ...sequentialArgs
        } = args;

        const unet = createUNet({ filters, units, activation, depth, residual, batchNorm, inputShape });

        super({
            ...sequentialArgs,
            name: sequentialArgs.name ?? "unet_model",
            // calling user should not modify the layers after instantiation
            layers: [unet]
        });
    }


    /**
     * Print the summary of this wrapper, followed by the summary of the wrapped U-Net.
     * All arguments are forwarded unchanged to both `summary` calls.
     */
    override summary(lineLength?: number, positions?: number[], printFn?: (message?: any, ...optionalParams: any[]) => void): void {
        const wrapped = this.layers[0] as tf.LayersModel;

        super.summary(lineLength, positions, printFn);
        wrapped.summary(lineLength, positions, printFn);
    }
}
69
+
70
+
/**
 * Build a functional U-Net: `depth` contraction blocks (each followed by a 2x2 max pool),
 * a bottleneck block, `depth` expansion blocks fed by the matching skip connections,
 * and a final 1x1 output convolution with `units` filters.
 *
 * @returns a `tf.LayersModel` named `"u_net"`
 * @throws Error if `units < 1`, if `depth` is not an integer >= 1, or if a given
 * input height/width is not divisible by `2^depth`
 */
export function createUNet({ filters, depth, units, activation, residual = false, batchNorm = false, inputShape = [null, null, 3] }: UNetModelArgs) {
    if (units < 1) {
        throw Error(`createUNet: units should be >= 1, got ${units}`);
    }

    // without at least one level there is no filter size to derive the bottleneck from,
    // and a fractional depth would make the level count disagree with 2^depth
    if (!Number.isInteger(depth) || depth < 1) {
        throw Error(`createUNet: depth should be an integer >= 1, got ${depth}`);
    }

    const [image_height, image_width] = inputShape;
    const divisible_by = 2 ** depth;

    // every level halves height and width, so known dimensions must survive `depth` halvings
    if ((image_height != null && image_height % divisible_by !== 0) ||
        (image_width != null && image_width % divisible_by !== 0)) {
        throw Error(`createUNet: the input height and width must be divisible by 2^depth (${divisible_by})`);
    }

    const input = tf.input({ shape: inputShape });

    const skip_connection: tf.SymbolicTensor[] = [];

    let x = input;

    // calculate the filter sizes for each level
    const filter_sizes = Array.from({ length: depth }, (_, i) => filters * (2 ** i));

    for (const filter_size of filter_sizes) {
        const contraction = contractionBlock(x, filter_size, residual, batchNorm, `contraction-f${filter_size}`);

        x = tf.layers.maxPooling2d({ poolSize: 2, strides: 2, name: `pool-f${filter_size}` }).apply(contraction) as tf.SymbolicTensor;
        // the un-pooled output is what the matching expansion block concatenates with
        skip_connection.push(contraction);
    }

    // the bottleneck uses twice the filters of the deepest contraction level
    x = contractionBlock(x, filters * divisible_by, residual, batchNorm, "bottleneck");

    // walk the levels back up, deepest first
    for (let i = filter_sizes.length - 1; i >= 0; i--) {
        x = expansionBlock(x, skip_connection[i], filter_sizes[i], residual, batchNorm, `expansion-f${filter_sizes[i]}`);
    }

    const output = tf.layers.conv2d({
        filters: units,
        kernelSize: 1,
        padding: "same",
        activation: activation ?? (units === 1 ? "sigmoid" : "softmax"),
        name: "output-conv"
    }).apply(x) as tf.SymbolicTensor;

    return tf.model({ inputs: input, outputs: output, name: "u_net" });
}
116
+
117
+
/**
 * Load a layers model with `tf.loadLayersModel` and copy its state onto a freshly
 * created U-Net, which is then returned.
 *
 * NOTE(review): the copy is a plain `Object.assign`, so only the loaded model's own
 * enumerable properties are transferred, and `name` is deliberately left out.
 * Confirm this is sufficient for every field the loaded model carries.
 *
 * @param pathOrIOHandler forwarded to `tf.loadLayersModel`
 * @param options forwarded to `tf.loadLayersModel`
 */
export async function loadUNetModel(pathOrIOHandler: string | tf.io.IOHandler, options?: tf.io.LoadOptions) {
    const loaded = await tf.loadLayersModel(pathOrIOHandler, options);

    // these are dummy args that are overwritten
    const placeholder = createUNet({ depth: 1, filters: 4, units: 1 });

    // pull `name` out so that it is not copied over
    const { name, ...loaded_state } = loaded;
    Object.assign(placeholder, loaded_state);

    return placeholder;
}
126
+
127
+
/**
 * The contraction block of a U-Net
 *
 * Conv > BN > ReLU > Conv > BN + residual > ReLU
 *
 * TODO: for residual, change order to (BN > ReLU > Conv)x2 + residual
 *
 * @param x a previous layer's symbolic output
 * @param filters the number of filters of both 3x3 convolutions
 * @param residual adds the (optionally projected) input to the second convolution's output
 * @param batchNorm applies batch normalization before ReLU activation
 * @param name a unique name for the block, used as the prefix of every named layer
 */
function contractionBlock(x: tf.SymbolicTensor, filters: number, residual: boolean, batchNorm: boolean, name: string) {
    // the convolutions drop their bias whenever batch normalization is enabled
    const use_bias = !batchNorm;

    const make_conv3x3 = (index: number) => tf.layers.conv2d({
        filters,
        kernelSize: 3,
        padding: "same",
        useBias: use_bias,
        kernelInitializer: "heNormal",
        name: `${name}-${index}-conv2d`
    });
    const make_batch_norm = (suffix: string) => tf.layers.batchNormalization({ name: `${name}-${suffix}` });

    const first_conv = make_conv3x3(1);
    const first_relu = tf.layers.reLU({ name: `${name}-1-relu` });
    const second_conv = make_conv3x3(2);
    const second_relu = tf.layers.reLU({ name: `${name}-2-relu` });

    // first half: Conv > BN > ReLU
    let out = first_conv.apply(x) as tf.SymbolicTensor;
    if (batchNorm) {
        out = make_batch_norm("1-batchnorm").apply(out) as tf.SymbolicTensor;
    }
    out = first_relu.apply(out) as tf.SymbolicTensor;

    // second half: Conv > BN, the ReLU comes after the optional residual addition
    out = second_conv.apply(out) as tf.SymbolicTensor;
    if (batchNorm) {
        out = make_batch_norm("2-batchnorm").apply(out) as tf.SymbolicTensor;
    }

    if (residual) {
        const input_channels = x.shape[x.shape.length - 1];
        let shortcut = x;

        if (input_channels != filters) {
            // a 1x1 convolution on the input to ensure the residual connection's
            // channels/filters dim matches the convolution output
            shortcut = tf.layers.conv2d({
                filters,
                kernelSize: 1,
                padding: "same",
                useBias: use_bias,
                kernelInitializer: "heNormal",
                name: `${name}-residual`
            }).apply(x) as tf.SymbolicTensor;
        }

        if (batchNorm) {
            shortcut = make_batch_norm("residual-batchnorm").apply(shortcut) as tf.SymbolicTensor;
        }

        out = tf.layers.add().apply([shortcut, out]) as tf.SymbolicTensor;
    }

    return second_relu.apply(out) as tf.SymbolicTensor;
}
209
+
210
+
/**
 * The expansion block of a U-Net
 *
 * Upconv + skip > contraction block
 *
 * @param x a previous layer's symbolic output
 * @param skip the corresponding contraction block's output (before pool); it is
 * concatenated with the up-convolved `x` along the last axis
 * @param filters the number of filters of the up-convolution and of the following convolutions
 * @param residual includes a residual connection
 * @param batchNorm apply batch normalization, should be `false` when batch size is `1`
 * @param name a unique name for the expansion block
 */
function expansionBlock(x: tf.SymbolicTensor, skip: tf.SymbolicTensor, filters: number, residual: boolean, batchNorm: boolean, name: string) {
    // a 2x2 transposed convolution with stride 2
    const up_layer = tf.layers.conv2dTranspose({
        name: `${name}-upconv`,
        filters,
        kernelSize: 2,
        strides: 2,
        padding: "same",
        kernelInitializer: "heNormal"
    });
    const merge_layer = tf.layers.concatenate({ axis: -1, name: `${name}-concat-upconv-skip` });

    const upsampled = up_layer.apply(x) as tf.SymbolicTensor;
    const merged = merge_layer.apply([upsampled, skip]) as tf.SymbolicTensor;

    // the merged tensor goes through the same double convolution as a contraction block
    return contractionBlock(merged, filters, residual, batchNorm, name);
}
@@ -0,0 +1,28 @@
1
+ import * as tf from "@tensorflow/tfjs";
2
+
/**
 * Generate a self-attention mask that prevents packed sequences from crossing document
 * boundaries and attending to each other. The result is a tensor of diagonally
 * positioned blocks of zeroes (allow attention) and -1e7 values (prevent attention).
 * The latter is scored zero during the scaled dot product attention's softmax operation.
 *
 * @param boundaries an array of ones (denotes start of a new sample or document) and zeroes
 *
 * Example boundary of 3 samples packed into one:
 * `[1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0]`
 *
 * @returns the mask with one extra leading axis, i.e. `[1, length, length]`
 */
export function generatePackingSelfAttentionMask(boundaries: Int32Array) {
    // see images at
    // https://reddit.com/r/LocalLLaMA/comments/197efaz/training_llama_mistral_and_mixtralmoe_faster_with/
    return tf.tidy(() => {
        // cumsum transforms the tensor such that each sequence in the pack gets its own id
        const sequence_ids = tf.tensor1d(boundaries).cumsum();

        // comparing the ids as a column against the ids as a row marks
        // every pair of positions that belongs to the same sequence
        const ids_as_column = sequence_ids.expandDims(1);
        const ids_as_row = sequence_ids.expandDims(0);
        const same_sequence = ids_as_column.equal(ids_as_row);

        // same sequence -> 0, different sequence -> -1e7
        return same_sequence
            .sub(1)
            .mul(1e7)
            // introduce a head dimension so it can be broadcasted
            .expandDims(0);
    });
}
package/src/testing.ts ADDED
@@ -0,0 +1 @@
1
+ console.log("test")
@@ -0,0 +1,15 @@
1
+ import type { Tensor } from "@tensorflow/tfjs";
2
+
3
+
/**
 * Ambient declaration of an asynchronous iterator: `next()` resolves to a standard
 * `IteratorResult<T>`.
 *
 * NOTE(review): presumably mirrors the class of the same name in `@tensorflow/tfjs-data` — confirm.
 */
export declare abstract class LazyIterator<T> {
    /** Resolve to the next iteration result. */
    abstract next(): Promise<IteratorResult<T>>;
}
7
+
8
+
/**
 * Ambient declaration of a dataset that hands out `LazyIterator`s over its elements.
 *
 * NOTE(review): presumably mirrors the class of the same name in `@tensorflow/tfjs-data` — confirm.
 */
export declare abstract class Dataset<T> {
    /** NOTE(review): presumably the number of elements in the dataset — confirm. */
    size: number;

    /** Resolve to an iterator over the dataset. */
    abstract iterator(): Promise<LazyIterator<T>>;
}
13
+
14
+
/**
 * Signature shared by loss and metric functions: takes the true and the predicted
 * tensors and returns a tensor.
 */
export type LossOrMetricFn = (yTrue: Tensor, yPred: Tensor) => Tensor;
@@ -0,0 +1,101 @@
1
+ import * as tf from "@tensorflow/tfjs";
2
+ import { getScaleShape, getRandomCropStart, generateCausalAttentionMask } from "@/utils";
3
+
4
+ // avoid TFJS node message during Jest testing
5
+ tf.env().set('IS_NODE', false);
6
+
7
+
// Jest suite for the helpers exported by "@/utils".
describe("test custom TFJS utility functions", () => {

    test("crop an image using the same shape, results in same shape", () => {
        // cropping an image of the same shape leaves only one possible start
        const img_size: [number, number] = [133, 84];
        const target_size: [number, number] = [133, 84];

        expect(getRandomCropStart(img_size, target_size)).toEqual([0, 0, 0]);
    });


    it("should throw when crop is larger than image", () => {
        expect(() => getRandomCropStart([128, 128], [1000, 2000])).toThrow();
    });


    test("cropped image shape", () => {
        const cases: { img_size: [number, number]; target_size: [number, number] }[] = [
            { img_size: [4923, 832], target_size: [333, 739] },
            { img_size: [381, 999], target_size: [300, 157] },
        ];

        for (const { img_size, target_size } of cases) {
            // the start is random, so sample it repeatedly
            for (let i = 0; i < 100; i++) {
                const [crop_start_h, crop_start_w, crop_start_c] = getRandomCropStart(img_size, target_size);

                // the crop has to fit inside the image on both ends
                expect(crop_start_h).toBeGreaterThanOrEqual(0);
                expect(crop_start_w).toBeGreaterThanOrEqual(0);
                expect(crop_start_h).toBeLessThanOrEqual(img_size[0] - target_size[0]);
                expect(crop_start_w).toBeLessThanOrEqual(img_size[1] - target_size[1]);

                // channels are never cropped
                expect(crop_start_c).toBe(0);
            }
        }
    });


    test("scale 1:1, results in the same shape", () => {
        const scale = getScaleShape([256, 256], [256, 256]);
        expect(scale).toEqual([256, 256]);
    });


    test("scaled image shape", () => {
        // scaling squares result in squares
        const scale1 = getScaleShape([256, 256], [128, 128]);
        expect(scale1).toEqual([128, 128]);

        const scale2 = getScaleShape([128, 128], [256, 256]);
        expect(scale2).toEqual([256, 256]);

        const scale3 = getScaleShape([123, 123], [321, 321]);
        expect(scale3).toEqual([321, 321]);

        const scale4 = getScaleShape([321, 321], [123, 123]);
        expect(scale4).toEqual([123, 123]);

        // scaling rectangles result in rectangles
        const scale5 = getScaleShape([640, 480], [1280, 960]);
        expect(scale5).toEqual([1280, 960]);

        const scale6 = getScaleShape([480, 640], [960, 1280]);
        expect(scale6).toEqual([960, 1280]);

        const [scale7_h, scale7_w] = getScaleShape([777, 555], [555, 333]);
        expect(scale7_h).toBeGreaterThan(scale7_w);

        const [scale8_h, scale8_w] = getScaleShape([555, 777], [333, 555]);
        expect(scale8_h).toBeLessThan(scale8_w);
    });


    test("causal attention map", async () => {
        const seq_len = 4;
        const causal_mask = generateCausalAttentionMask(seq_len, seq_len);

        const _ = -1e7;
        const expected_mask = tf.tensor([
            [0, _, _, _],
            [0, 0, _, _],
            [0, 0, 0, _],
            [0, 0, 0, 0]
        ]);

        // sum the absolute differences: signed differences could cancel each other
        // out and hide a wrong mask behind a total of zero
        const total_abs_diff = (await causal_mask.sub(expected_mask).abs().sum().data())[0];
        expect(total_abs_diff).toEqual(0);
    });

});
package/src/utils.ts ADDED
@@ -0,0 +1,86 @@
1
+ import * as tf from "@tensorflow/tfjs";
2
+
3
+
/**
 * Calculate the desired scaled image's height and width while keeping the aspect
 * ratio. One edge is scaled to match its corresponding target edge; the other
 * edge might end up larger than intended.
 *
 * @param image_shape the shape of the image, only the leading `[height, width]` is read
 * @param target_shape the intended `[height, width]` of the final scaled image
 * @returns the rounded `[height, width]` after scaling
 * @throws Error when the image height or width is not a positive, finite number
 */
export function getScaleShape(image_shape: tf.Shape, target_shape: [number, number]): [scaled_height: number, scaled_width: number] {
    const [img_height, img_width] = image_shape;
    const [target_height, target_width] = target_shape;

    // a missing, zero or non-finite edge would turn the scale factor into Infinity/NaN
    const is_usable_edge = (edge: number | null | undefined): edge is number =>
        typeof edge === "number" && Number.isFinite(edge) && edge > 0;

    if (!is_usable_edge(img_height) || !is_usable_edge(img_width)) {
        throw Error(`getScaleShape: image height and width must be positive numbers, got [${img_height}, ${img_width}]`);
    }

    // scale based on whichever target_edge / original_edge is largest,
    // we need the following to be true (1)
    //    height * scale >= target_height
    //    width * scale >= target_width
    // rearranging to get an equivalent requirement (2)
    //    scale >= target_height / height
    //    scale >= target_width / width
    // by picking the scale value that's largest of the two, we satisfy (2), and therefore (1)
    // it may be more intuitive to think of scale as scale_h and scale_w
    const scale_factor = Math.max(target_height / img_height, target_width / img_width);
    return [Math.round(img_height * scale_factor), Math.round(img_width * scale_factor)];
}
28
+
29
+
/**
 * Calculate a random starting point for a crop (slice) operation
 * on an image tensor with the shape `[height, width, channels]`.
 *
 * @param image_shape the `[height, width]` of the image
 * @param target_shape the intended `[height, width]` of the final cropped image
 * @returns the `[height, width, channels]` start indices; the channels start is always `0`
 * @throws Error when the crop is larger than the image in either dimension
 */
export function getRandomCropStart(
    image_shape: [height: number, width: number],
    target_shape: [height: number, width: number]
): [number, number, number] {
    const [img_height, img_width] = image_shape;
    const [crop_height, crop_width] = target_shape;

    if (img_height < crop_height || img_width < crop_width) {
        throw Error(`getRandomCropStart: cannot crop with a size that's bigger than` +
            ` the image. Original [${img_height}, ${img_width}], crop [${crop_height}, ${crop_width}].`);
    }

    // the largest start that still keeps the whole crop inside the image
    const max_start_h = img_height - crop_height;
    const max_start_w = img_width - crop_width;

    // Math.random()'s range is [0, 1), excluding 1, so the multiplier has to be
    // max_start + 1 for floor() to be able to reach max_start itself
    return [
        Math.floor(Math.random() * (max_start_h + 1)),
        Math.floor(Math.random() * (max_start_w + 1)),
        0 // not cropping channels, so it starts at the first index
    ];
}
58
+
59
+
/**
 * Calculate how much padding the height and the width of an image need
 * to become divisible by 2^depth.
 *
 * In U-Net image segmentation, the contraction and concatenate
 * operations requires the input image's height and width
 * dimensions to be divisible by 2^depth.
 *
 * @param image the image tensor; only the first two entries of its shape are read
 * @param depth the depth of the U-Net
 * @returns the `[height, width]` padding, `0` for an edge that is already divisible
 */
export function getPaddingForSegmentation(image: tf.Tensor3D, depth: number): [height: number, width: number] {
    const multiple = 2 ** depth;

    // distance from an edge to the next multiple of 2^depth
    const padding_for = (edge: number) => Math.ceil(edge / multiple) * multiple - edge;

    const [height, width] = image.shape;

    return [padding_for(height), padding_for(width)];
}
78
+
79
+
/**
 * Generate a causal attention mask of shape `[query_seq_length, key_seq_length]`:
 * `0` on and below the main diagonal, `-1e7` above it.
 *
 * @param query_seq_length the number of rows of the mask
 * @param key_seq_length the number of columns of the mask
 */
export function generateCausalAttentionMask(query_seq_length: number, key_seq_length: number) {
    return tf.tidy(() => {
        const all_ones = tf.ones([query_seq_length, key_seq_length]);

        // keep the ones on and below the diagonal, zero out everything above it
        const lower_triangle = tf.linalg.bandPart(all_ones, -1, 0);

        // 1 -> 0 and 0 -> -1e7
        return lower_triangle
            .sub(1)
            .mul(1e7);
    });
}
package/tsconfig.json ADDED
@@ -0,0 +1,49 @@
1
+ {
2
+ // Visit https://aka.ms/tsconfig to read more about this file
3
+ "compilerOptions": {
4
+ // File Layout
5
+ // "rootDir": "./src",
6
+ // "outDir": "./dist",
7
+ // Environment Settings
8
+ // See also https://aka.ms/tsconfig/module
9
+ "module": "es2022",
10
+ "target": "esnext",
11
+ "types": ["jest"],
12
+ // For nodejs:
13
+ // "lib": ["esnext"],
14
+ // "types": ["node"],
15
+ // and npm install -D @types/node
16
+ // Other Outputs
17
+ "sourceMap": true,
18
+ "declaration": true,
19
+ "declarationMap": true,
20
+ // Stricter Typechecking Options
21
+ //"noUncheckedIndexedAccess": true,
22
+ "exactOptionalPropertyTypes": true,
23
+ // Style Options
24
+ // "noImplicitReturns": true,
25
+ // "noImplicitOverride": true,
26
+ // "noUnusedLocals": true,
27
+ // "noUnusedParameters": true,
28
+ // "noFallthroughCasesInSwitch": true,
29
+ // "noPropertyAccessFromIndexSignature": true,
30
+ // Recommended Options
31
+ "strict": true,
32
+ "jsx": "react-jsx",
33
+ //"verbatimModuleSyntax": true,
34
+ "isolatedModules": true,
35
+ "noUncheckedSideEffectImports": true,
36
+ "moduleDetection": "force",
37
+ "skipLibCheck": true,
38
+ "paths": {
39
+ "@/*": [
40
+ "./src/*"
41
+ ],
42
+ "e2e/*": [
43
+ "./e2e/*"
44
+ ]
45
+ },
46
+ "moduleResolution": "bundler",
47
+ "esModuleInterop": true
48
+ }
49
+ }