@stellarapp/tfjs-stellar 1.0.0 → 1.0.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (244)
  1. package/LICENSE +21 -0
  2. package/README.md +47 -0
  3. package/dist/index.d.ts +7 -0
  4. package/dist/index.d.ts.map +1 -0
  5. package/dist/index.js +7 -0
  6. package/dist/index.js.map +1 -0
  7. package/dist/jest.config.d.ts +8 -0
  8. package/dist/jest.config.d.ts.map +1 -0
  9. package/{jest.config.ts → dist/jest.config.js} +8 -64
  10. package/dist/jest.config.js.map +1 -0
  11. package/dist/kv_cache.d.ts +53 -0
  12. package/dist/kv_cache.d.ts.map +1 -0
  13. package/{src/kv_cache.ts → dist/kv_cache.js} +35 -105
  14. package/dist/kv_cache.js.map +1 -0
  15. package/dist/layers/cached_rope_multihead_attention.d.ts +31 -0
  16. package/dist/layers/cached_rope_multihead_attention.d.ts.map +1 -0
  17. package/dist/layers/cached_rope_multihead_attention.js +76 -0
  18. package/dist/layers/cached_rope_multihead_attention.js.map +1 -0
  19. package/dist/layers/cached_rope_multihead_attention.test.d.ts +2 -0
  20. package/dist/layers/cached_rope_multihead_attention.test.d.ts.map +1 -0
  21. package/{src/layers/cached_rope_multihead_attention.test.ts → dist/layers/cached_rope_multihead_attention.test.js} +14 -30
  22. package/dist/layers/cached_rope_multihead_attention.test.js.map +1 -0
  23. package/dist/layers/gpt_decoder_block.d.ts +34 -0
  24. package/dist/layers/gpt_decoder_block.d.ts.map +1 -0
  25. package/{src/layers/gpt_decoder_block.ts → dist/layers/gpt_decoder_block.js} +10 -36
  26. package/dist/layers/gpt_decoder_block.js.map +1 -0
  27. package/dist/layers/index.d.ts +17 -0
  28. package/dist/layers/index.d.ts.map +1 -0
  29. package/dist/layers/index.js +33 -0
  30. package/dist/layers/index.js.map +1 -0
  31. package/dist/layers/multihead_attention.d.ts +106 -0
  32. package/dist/layers/multihead_attention.d.ts.map +1 -0
  33. package/{src/layers/multihead_attention.ts → dist/layers/multihead_attention.js} +60 -162
  34. package/dist/layers/multihead_attention.js.map +1 -0
  35. package/dist/layers/multihead_attention.test.d.ts +2 -0
  36. package/dist/layers/multihead_attention.test.d.ts.map +1 -0
  37. package/{src/layers/multihead_attention.test.ts → dist/layers/multihead_attention.test.js} +48 -100
  38. package/dist/layers/multihead_attention.test.js.map +1 -0
  39. package/dist/layers/positional_encoding.d.ts +37 -0
  40. package/dist/layers/positional_encoding.d.ts.map +1 -0
  41. package/{src/layers/positional_encoding.ts → dist/layers/positional_encoding.js} +17 -60
  42. package/dist/layers/positional_encoding.js.map +1 -0
  43. package/dist/layers/positional_encoding.test.d.ts +2 -0
  44. package/dist/layers/positional_encoding.test.d.ts.map +1 -0
  45. package/{src/layers/positional_encoding.test.ts → dist/layers/positional_encoding.test.js} +39 -57
  46. package/dist/layers/positional_encoding.test.js.map +1 -0
  47. package/dist/layers/rotary_position_embedding.d.ts +39 -0
  48. package/dist/layers/rotary_position_embedding.d.ts.map +1 -0
  49. package/{src/layers/rotary_position_embedding.ts → dist/layers/rotary_position_embedding.js} +22 -86
  50. package/dist/layers/rotary_position_embedding.js.map +1 -0
  51. package/dist/layers/rotary_position_embedding.test.d.ts +2 -0
  52. package/dist/layers/rotary_position_embedding.test.d.ts.map +1 -0
  53. package/dist/layers/rotary_position_embedding.test.js +88 -0
  54. package/dist/layers/rotary_position_embedding.test.js.map +1 -0
  55. package/dist/layers/token_and_positional_embedding.d.ts +47 -0
  56. package/dist/layers/token_and_positional_embedding.d.ts.map +1 -0
  57. package/{src/layers/token_and_positional_embedding.ts → dist/layers/token_and_positional_embedding.js} +27 -67
  58. package/dist/layers/token_and_positional_embedding.js.map +1 -0
  59. package/dist/layers/token_and_positional_embedding.test.d.ts +2 -0
  60. package/dist/layers/token_and_positional_embedding.test.d.ts.map +1 -0
  61. package/{src/layers/token_and_positional_embedding.test.ts → dist/layers/token_and_positional_embedding.test.js} +7 -30
  62. package/dist/layers/token_and_positional_embedding.test.js.map +1 -0
  63. package/dist/layers/transformer_decoder.d.ts +69 -0
  64. package/dist/layers/transformer_decoder.d.ts.map +1 -0
  65. package/dist/layers/transformer_decoder.js +182 -0
  66. package/dist/layers/transformer_decoder.js.map +1 -0
  67. package/dist/layers/transformer_decoder.test.d.ts +2 -0
  68. package/dist/layers/transformer_decoder.test.d.ts.map +1 -0
  69. package/{src/layers/transformer_decoder.test.ts → dist/layers/transformer_decoder.test.js} +20 -48
  70. package/dist/layers/transformer_decoder.test.js.map +1 -0
  71. package/dist/layers/transformer_encoder.d.ts +55 -0
  72. package/dist/layers/transformer_encoder.d.ts.map +1 -0
  73. package/{src/layers/transformer_encoder.ts → dist/layers/transformer_encoder.js} +41 -90
  74. package/dist/layers/transformer_encoder.js.map +1 -0
  75. package/dist/layers/transformer_encoder.test.d.ts +2 -0
  76. package/dist/layers/transformer_encoder.test.d.ts.map +1 -0
  77. package/{src/layers/transformer_encoder.test.ts → dist/layers/transformer_encoder.test.js} +18 -45
  78. package/dist/layers/transformer_encoder.test.js.map +1 -0
  79. package/dist/losses/dice.d.ts +30 -0
  80. package/dist/losses/dice.d.ts.map +1 -0
  81. package/{src/losses/dice.ts → dist/losses/dice.js} +17 -80
  82. package/dist/losses/dice.js.map +1 -0
  83. package/dist/losses/index.d.ts +2 -0
  84. package/dist/losses/index.d.ts.map +1 -0
  85. package/dist/losses/index.js +2 -0
  86. package/dist/losses/index.js.map +1 -0
  87. package/dist/masks.d.ts +20 -0
  88. package/dist/masks.d.ts.map +1 -0
  89. package/{src/packing_mask.ts → dist/masks.js} +16 -7
  90. package/dist/masks.js.map +1 -0
  91. package/dist/metrics.d.ts +20 -0
  92. package/dist/metrics.d.ts.map +1 -0
  93. package/{src/metrics.ts → dist/metrics.js} +8 -12
  94. package/dist/metrics.js.map +1 -0
  95. package/dist/models/gpt_model.d.ts +94 -0
  96. package/dist/models/gpt_model.d.ts.map +1 -0
  97. package/{src/models/gpt_model.ts → dist/models/gpt_model.js} +41 -119
  98. package/dist/models/gpt_model.js.map +1 -0
  99. package/dist/models/index.d.ts +7 -0
  100. package/dist/models/index.d.ts.map +1 -0
  101. package/dist/models/index.js +13 -0
  102. package/dist/models/index.js.map +1 -0
  103. package/dist/models/llm_model.d.ts +87 -0
  104. package/dist/models/llm_model.d.ts.map +1 -0
  105. package/{src/models/llm_model.ts → dist/models/llm_model.js} +51 -161
  106. package/dist/models/llm_model.js.map +1 -0
  107. package/dist/models/u_net.d.ts +40 -0
  108. package/dist/models/u_net.d.ts.map +1 -0
  109. package/{src/models/u_net.ts → dist/models/u_net.js} +27 -116
  110. package/dist/models/u_net.js.map +1 -0
  111. package/dist/src/index.d.ts +6 -0
  112. package/dist/src/index.d.ts.map +1 -0
  113. package/dist/src/index.js +6 -0
  114. package/dist/src/index.js.map +1 -0
  115. package/dist/src/kv_cache.d.ts +53 -0
  116. package/dist/src/kv_cache.d.ts.map +1 -0
  117. package/dist/src/kv_cache.js +135 -0
  118. package/dist/src/kv_cache.js.map +1 -0
  119. package/dist/src/layers/cached_rope_multihead_attention.d.ts +31 -0
  120. package/dist/src/layers/cached_rope_multihead_attention.d.ts.map +1 -0
  121. package/{src/layers/cached_rope_multihead_attention.ts → dist/src/layers/cached_rope_multihead_attention.js} +25 -62
  122. package/dist/src/layers/cached_rope_multihead_attention.js.map +1 -0
  123. package/dist/src/layers/cached_rope_multihead_attention.test.d.ts +2 -0
  124. package/dist/src/layers/cached_rope_multihead_attention.test.d.ts.map +1 -0
  125. package/dist/src/layers/cached_rope_multihead_attention.test.js +43 -0
  126. package/dist/src/layers/cached_rope_multihead_attention.test.js.map +1 -0
  127. package/dist/src/layers/gpt_decoder_block.d.ts +34 -0
  128. package/dist/src/layers/gpt_decoder_block.d.ts.map +1 -0
  129. package/dist/src/layers/gpt_decoder_block.js +51 -0
  130. package/dist/src/layers/gpt_decoder_block.js.map +1 -0
  131. package/dist/src/layers/index.d.ts +17 -0
  132. package/dist/src/layers/index.d.ts.map +1 -0
  133. package/dist/src/layers/index.js +33 -0
  134. package/dist/src/layers/index.js.map +1 -0
  135. package/dist/src/layers/multihead_attention.d.ts +106 -0
  136. package/dist/src/layers/multihead_attention.d.ts.map +1 -0
  137. package/dist/src/layers/multihead_attention.js +269 -0
  138. package/dist/src/layers/multihead_attention.js.map +1 -0
  139. package/dist/src/layers/multihead_attention.test.d.ts +2 -0
  140. package/dist/src/layers/multihead_attention.test.d.ts.map +1 -0
  141. package/dist/src/layers/multihead_attention.test.js +160 -0
  142. package/dist/src/layers/multihead_attention.test.js.map +1 -0
  143. package/dist/src/layers/positional_encoding.d.ts +37 -0
  144. package/dist/src/layers/positional_encoding.d.ts.map +1 -0
  145. package/dist/src/layers/positional_encoding.js +115 -0
  146. package/dist/src/layers/positional_encoding.js.map +1 -0
  147. package/dist/src/layers/positional_encoding.test.d.ts +2 -0
  148. package/dist/src/layers/positional_encoding.test.d.ts.map +1 -0
  149. package/dist/src/layers/positional_encoding.test.js +95 -0
  150. package/dist/src/layers/positional_encoding.test.js.map +1 -0
  151. package/dist/src/layers/rotary_position_embedding.d.ts +39 -0
  152. package/dist/src/layers/rotary_position_embedding.d.ts.map +1 -0
  153. package/dist/src/layers/rotary_position_embedding.js +99 -0
  154. package/dist/src/layers/rotary_position_embedding.js.map +1 -0
  155. package/dist/src/layers/rotary_position_embedding.test.d.ts +2 -0
  156. package/dist/src/layers/rotary_position_embedding.test.d.ts.map +1 -0
  157. package/dist/src/layers/rotary_position_embedding.test.js +88 -0
  158. package/dist/src/layers/rotary_position_embedding.test.js.map +1 -0
  159. package/dist/src/layers/token_and_positional_embedding.d.ts +47 -0
  160. package/dist/src/layers/token_and_positional_embedding.d.ts.map +1 -0
  161. package/dist/src/layers/token_and_positional_embedding.js +109 -0
  162. package/dist/src/layers/token_and_positional_embedding.js.map +1 -0
  163. package/dist/src/layers/token_and_positional_embedding.test.d.ts +2 -0
  164. package/dist/src/layers/token_and_positional_embedding.test.d.ts.map +1 -0
  165. package/dist/src/layers/token_and_positional_embedding.test.js +58 -0
  166. package/dist/src/layers/token_and_positional_embedding.test.js.map +1 -0
  167. package/dist/src/layers/transformer_decoder.d.ts +69 -0
  168. package/dist/src/layers/transformer_decoder.d.ts.map +1 -0
  169. package/{src/layers/transformer_decoder.ts → dist/src/layers/transformer_decoder.js} +41 -95
  170. package/dist/src/layers/transformer_decoder.js.map +1 -0
  171. package/dist/src/layers/transformer_decoder.test.d.ts +2 -0
  172. package/dist/src/layers/transformer_decoder.test.d.ts.map +1 -0
  173. package/dist/src/layers/transformer_decoder.test.js +72 -0
  174. package/dist/src/layers/transformer_decoder.test.js.map +1 -0
  175. package/dist/src/layers/transformer_encoder.d.ts +55 -0
  176. package/dist/src/layers/transformer_encoder.d.ts.map +1 -0
  177. package/dist/src/layers/transformer_encoder.js +175 -0
  178. package/dist/src/layers/transformer_encoder.js.map +1 -0
  179. package/dist/src/layers/transformer_encoder.test.d.ts +2 -0
  180. package/dist/src/layers/transformer_encoder.test.d.ts.map +1 -0
  181. package/dist/src/layers/transformer_encoder.test.js +58 -0
  182. package/dist/src/layers/transformer_encoder.test.js.map +1 -0
  183. package/dist/src/losses/dice.d.ts +30 -0
  184. package/dist/src/losses/dice.d.ts.map +1 -0
  185. package/dist/src/losses/dice.js +93 -0
  186. package/dist/src/losses/dice.js.map +1 -0
  187. package/dist/src/losses/index.d.ts +2 -0
  188. package/dist/src/losses/index.d.ts.map +1 -0
  189. package/dist/src/losses/index.js +2 -0
  190. package/dist/src/losses/index.js.map +1 -0
  191. package/dist/src/masks.d.ts +20 -0
  192. package/dist/src/masks.d.ts.map +1 -0
  193. package/dist/src/masks.js +37 -0
  194. package/dist/src/masks.js.map +1 -0
  195. package/dist/src/metrics.d.ts +20 -0
  196. package/dist/src/metrics.d.ts.map +1 -0
  197. package/dist/src/metrics.js +28 -0
  198. package/dist/src/metrics.js.map +1 -0
  199. package/dist/src/models/gpt_model.d.ts +94 -0
  200. package/dist/src/models/gpt_model.d.ts.map +1 -0
  201. package/dist/src/models/gpt_model.js +154 -0
  202. package/dist/src/models/gpt_model.js.map +1 -0
  203. package/dist/src/models/index.d.ts +3 -0
  204. package/dist/src/models/index.d.ts.map +1 -0
  205. package/{src/models/index.ts → dist/src/models/index.js} +1 -0
  206. package/dist/src/models/index.js.map +1 -0
  207. package/dist/src/models/llm_model.d.ts +87 -0
  208. package/dist/src/models/llm_model.d.ts.map +1 -0
  209. package/dist/src/models/llm_model.js +245 -0
  210. package/dist/src/models/llm_model.js.map +1 -0
  211. package/dist/src/models/u_net.d.ts +40 -0
  212. package/dist/src/models/u_net.d.ts.map +1 -0
  213. package/dist/src/models/u_net.js +151 -0
  214. package/dist/src/models/u_net.js.map +1 -0
  215. package/{src/tfjs_types.ts → dist/src/tfjs_types.d.ts} +1 -6
  216. package/dist/src/tfjs_types.d.ts.map +1 -0
  217. package/dist/src/tfjs_types.js +2 -0
  218. package/dist/src/tfjs_types.js.map +1 -0
  219. package/dist/src/utils.d.ts +28 -0
  220. package/dist/src/utils.d.ts.map +1 -0
  221. package/{src/utils.ts → dist/src/utils.js} +10 -33
  222. package/dist/src/utils.js.map +1 -0
  223. package/dist/src/utils.test.d.ts +2 -0
  224. package/dist/src/utils.test.d.ts.map +1 -0
  225. package/{src/utils.test.ts → dist/src/utils.test.js} +22 -50
  226. package/dist/src/utils.test.js.map +1 -0
  227. package/dist/tfjs_types.d.ts +10 -0
  228. package/dist/tfjs_types.d.ts.map +1 -0
  229. package/dist/tfjs_types.js +2 -0
  230. package/dist/tfjs_types.js.map +1 -0
  231. package/dist/utils.d.ts +28 -0
  232. package/dist/utils.d.ts.map +1 -0
  233. package/dist/utils.js +63 -0
  234. package/dist/utils.js.map +1 -0
  235. package/dist/utils.test.d.ts +2 -0
  236. package/dist/utils.test.d.ts.map +1 -0
  237. package/dist/utils.test.js +73 -0
  238. package/dist/utils.test.js.map +1 -0
  239. package/package.json +14 -4
  240. package/src/index.ts +0 -93
  241. package/src/layers/rotary_position_embedding.test.ts +0 -107
  242. package/src/losses/index.ts +0 -1
  243. package/src/testing.ts +0 -1
  244. package/tsconfig.json +0 -49
@@ -0,0 +1,151 @@
1
+ import * as tf from "@tensorflow/tfjs";
2
/**
 * A `tf.Sequential` wrapper whose single layer is the functional U-Net built
 * by `createUNet`.
 *
 * U-Net specific options are consumed here; every remaining arg (e.g. `name`)
 * is forwarded to `tf.Sequential`. The model name defaults to "unet_model" and
 * the output activation defaults to "sigmoid" when `units == 1`, otherwise
 * "softmax".
 */
export class UNetModel extends tf.Sequential {
    constructor(args) {
        const {
            filters,
            units,
            // `units` is destructured first so this default can read it
            activation = units == 1 ? "sigmoid" : "softmax",
            depth,
            residual = false,
            batchNorm = false,
            inputShape = [null, null, 3],
            ...sequentialArgs
        } = args;
        // Built before the Sequential container exists (no `this` is touched).
        // Calling code should not modify the layers after instantiation.
        const unet = createUNet({ filters, units, activation, depth, residual, batchNorm, inputShape });
        super({
            ...sequentialArgs,
            name: sequentialArgs.name ?? "unet_model",
            layers: [unet],
        });
    }

    /**
     * Print this wrapper's summary, then the summary of the nested U-Net.
     */
    summary(lineLength, positions, printFn) {
        super.summary(lineLength, positions, printFn);
        const [unet] = this.layers;
        unet.summary(lineLength, positions, printFn);
    }
}
17
/**
 * Build a functional U-Net as a `tf.LayersModel` named "u_net".
 *
 * The network consists of:
 * - `depth` contraction levels, each followed by 2x2 max pooling, with filters
 *   `filters, 2*filters, 4*filters, ...`;
 * - a bottleneck block with `filters * 2^depth` filters;
 * - `depth` mirrored expansion levels, each consuming the matching
 *   contraction output as a skip connection;
 * - a final 1x1 convolution producing `units` channels.
 *
 * @param {object} config
 * @param {number} config.filters filters of the first contraction level (integer >= 1)
 * @param {number} config.depth number of contraction/expansion levels (integer >= 0);
 *   `0` yields only the bottleneck block followed by the output convolution
 * @param {number} config.units number of output channels (integer >= 1)
 * @param {string} [config.activation] output activation; when omitted, "sigmoid"
 *   is used for `units === 1` and "softmax" otherwise
 * @param {boolean} [config.residual=false] add residual connections inside each block
 * @param {boolean} [config.batchNorm=false] apply batch normalization inside each block
 * @param {Array<number|null>} [config.inputShape=[null, null, 3]] `[height, width, channels]`;
 *   `null` height/width are dynamic and are not checked
 * @returns the U-Net model
 * @throws {Error} if `units`, `depth` or `filters` is not a valid integer, or if a
 *   known input height/width is not divisible by `2^depth`
 */
export function createUNet({ filters, depth, units, activation, residual = false, batchNorm = false, inputShape = [null, null, 3] }) {
    // `undefined < 1` is false, so a plain comparison would let a missing value through
    if (!Number.isInteger(units) || units < 1) {
        throw new Error(`createUNet: units should be an integer >= 1, got ${units}`);
    }
    if (!Number.isInteger(depth) || depth < 0) {
        throw new Error(`createUNet: depth should be an integer >= 0, got ${depth}`);
    }
    if (!Number.isInteger(filters) || filters < 1) {
        throw new Error(`createUNet: filters should be an integer >= 1, got ${filters}`);
    }

    // each pooling halves height and width, so both must survive `depth` halvings;
    // concatenating skip connections requires the spatial sizes to match exactly
    const [imageHeight, imageWidth] = inputShape;
    const divisibleBy = 2 ** depth;
    const isIncompatible = (dim) => dim != null && dim % divisibleBy !== 0;
    if (isIncompatible(imageHeight) || isIncompatible(imageWidth)) {
        throw new Error(`createUNet: the input height and width must be divisible by 2^depth (${divisibleBy})`);
    }

    const input = tf.input({ shape: inputShape });

    // filter count of each contraction level: filters, 2*filters, 4*filters, ...
    const filterSizes = Array.from({ length: depth }, (_, i) => filters * 2 ** i);
    const skipConnections = [];

    let x = input;
    for (const filterSize of filterSizes) {
        const contraction = contractionBlock(x, filterSize, residual, batchNorm, `contraction-f${filterSize}`);
        x = tf.layers.maxPooling2d({ poolSize: 2, strides: 2, name: `pool-f${filterSize}` }).apply(contraction);
        skipConnections.push(contraction);
    }

    // doubles the deepest level's filters; equals `filters` when depth is 0
    x = contractionBlock(x, filters * 2 ** depth, residual, batchNorm, "bottleneck");

    // walk back up, pairing each level with its contraction output
    for (let i = filterSizes.length - 1; i >= 0; i--) {
        x = expansionBlock(x, skipConnections[i], filterSizes[i], residual, batchNorm, `expansion-f${filterSizes[i]}`);
    }

    const output = tf.layers.conv2d({
        filters: units,
        kernelSize: 1,
        padding: "same",
        activation: activation ?? (units === 1 ? "sigmoid" : "softmax"),
        name: "output-conv"
    }).apply(x);

    return tf.model({ inputs: input, outputs: output, name: "u_net" });
}
50
/**
 * Load a saved U-Net with `tf.loadLayersModel` and return it as a U-Net model.
 *
 * A placeholder U-Net is built and then overwritten with the own enumerable
 * properties of the loaded model, except `name`. The returned object therefore
 * keeps the placeholder's prototype and its name ("u_net").
 * NOTE(review): this relies on tfjs-layers keeping all model state in own
 * enumerable properties — confirm against the tfjs version in use.
 *
 * @param pathOrIOHandler path/URL or IO handler forwarded to `tf.loadLayersModel`
 * @param options load options forwarded to `tf.loadLayersModel`
 * @returns the placeholder object carrying the loaded model's state
 */
export async function loadUNetModel(pathOrIOHandler, options) {
    const model = await tf.loadLayersModel(pathOrIOHandler, options);
    const unet = createUNet({ depth: 1, filters: 4, units: 1 }); // these are dummy args that are overwritten
    // Building the placeholder initialised weight variables. They would be
    // orphaned (never disposed) once its layers are replaced below, so release
    // them first. The container's own ref count is restored by the copy.
    unet.dispose();
    const { name, ...rest } = model;
    Object.assign(unet, rest);
    return unet;
}
57
/**
 * The contraction block of a U-Net: two 3x3 "same"-padded convolutions, each
 * optionally batch normalized, with a ReLU after each.
 *
 * Conv > [BN] > ReLU > Conv > [BN] (+ residual) > ReLU
 *
 * TODO: for residual, change order to (BN > ReLU > Conv)x2 + residual
 *
 * @param x a previous layer's symbolic output
 * @param filters the number of filters of both 3x3 convolutions (and of the
 *   residual projection, when one is needed)
 * @param residual when true, the block input is added to the second
 *   convolution's output before the final ReLU. A 1x1 convolution first projects
 *   the input when its channel count differs from `filters`.
 * @param batchNorm when true, batch normalization follows each convolution
 *   (and the residual path), and the convolutions are built without bias
 * @param name a unique prefix for the names of the block's layers
 * @returns the symbolic output of the final ReLU
 */
function contractionBlock(x, filters, residual, batchNorm, name) {
    // settings shared by every convolution of the block; a bias is redundant
    // when batch normalization immediately follows
    const convDefaults = {
        filters,
        padding: "same",
        useBias: !batchNorm,
        kernelInitializer: "heNormal",
    };
    // creates and applies a batch-normalization layer only when enabled
    const normalize = (tensor, suffix) => batchNorm
        ? tf.layers.batchNormalization({ name: `${name}-${suffix}` }).apply(tensor)
        : tensor;

    const firstConv = tf.layers.conv2d({ ...convDefaults, kernelSize: 3, name: `${name}-1-conv2d` });
    const firstRelu = tf.layers.reLU({ name: `${name}-1-relu` });
    const secondConv = tf.layers.conv2d({ ...convDefaults, kernelSize: 3, name: `${name}-2-conv2d` });
    const secondRelu = tf.layers.reLU({ name: `${name}-2-relu` });

    const hidden = firstRelu.apply(normalize(firstConv.apply(x), "1-batchnorm"));
    let out = normalize(secondConv.apply(hidden), "2-batchnorm");

    if (residual) {
        const inputChannels = x.shape[x.shape.length - 1];
        // a 1x1 convolution makes the input's channel dim match the block's so
        // the two branches can be summed
        const projected = inputChannels != filters
            ? tf.layers.conv2d({ ...convDefaults, kernelSize: 1, name: `${name}-residual` }).apply(x)
            : x;
        // normalized before the add layer is created, keeping layer creation order
        const skip = normalize(projected, "residual-batchnorm");
        out = tf.layers.add().apply([skip, out]);
    }

    return secondRelu.apply(out);
}
125
/**
 * The expansion block of a U-Net: upsample, merge with the skip connection,
 * then refine with a contraction block.
 *
 * 2x2 stride-2 transposed conv > concat(upsampled, skip) > contraction block
 *
 * @param x a previous (deeper) layer's symbolic output
 * @param skip the corresponding contraction block's output (before pooling);
 *   its height and width must match `x` after the stride-2 upsampling
 * @param filters the number of filters of the transposed convolution and of
 *   the following contraction block
 * @param residual forwarded to the contraction block
 * @param batchNorm forwarded to the contraction block; should be `false` when batch size is `1`
 * @param name a unique prefix for the names of the expansion block's layers
 * @returns the symbolic output of the contraction block
 */
function expansionBlock(x, skip, filters, residual, batchNorm, name) {
    const upsample = tf.layers.conv2dTranspose({
        name: `${name}-upconv`,
        filters,
        kernelSize: 2,
        strides: 2,
        padding: "same",
        kernelInitializer: "heNormal",
    });
    const merge = tf.layers.concatenate({ name: `${name}-concat-upconv-skip`, axis: -1 });

    // concatenated along the last axis: upsampled features first, then the skip
    const merged = merge.apply([upsample.apply(x), skip]);
    return contractionBlock(merged, filters, residual, batchNorm, name);
}
151
+ //# sourceMappingURL=u_net.js.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"u_net.js","sourceRoot":"","sources":["../../../src/models/u_net.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAuCvC,MAAM,OAAO,SAAU,SAAQ,EAAE,CAAC,UAAU;IAExC,YAAY,IAAmB;QAC3B,MAAM,EACF,OAAO,EACP,KAAK,EACL,UAAU,GAAG,KAAK,IAAI,CAAC,CAAC,CAAC,CAAC,SAAS,CAAC,CAAC,CAAC,SAAS,EAC/C,KAAK,EACL,QAAQ,GAAG,KAAK,EAChB,SAAS,GAAG,KAAK,EACjB,UAAU,GAAG,CAAC,IAAI,EAAE,IAAI,EAAE,CAAC,CAAC,EAC5B,GAAG,cAAc,EACpB,GAAG,IAAI,CAAC;QAET,cAAc,CAAC,IAAI,GAAG,cAAc,CAAC,IAAI,IAAI,YAAY,CAAC;QAE1D,KAAK,CAAC;YACF,GAAG,cAAc;YACjB,gEAAgE;YAChE,MAAM,EAAE,CAAC,UAAU,CAAC,EAAE,OAAO,EAAE,KAAK,EAAE,UAAU,EAAE,KAAK,EAAE,QAAQ,EAAE,SAAS,EAAE,UAAU,EAAE,CAAC,CAAC;SAC/F,CAAC,CAAC;IACP,CAAC;IAGQ,OAAO,CAAC,UAAmB,EAAE,SAAoB,EAAE,OAA2D;QACnH,KAAK,CAAC,OAAO,CAAC,UAAU,EAAE,SAAS,EAAE,OAAO,CAAC,CAAC;QAC7C,IAAI,CAAC,MAAM,CAAC,CAAC,CAAoB,CAAC,OAAO,CAAC,UAAU,EAAE,SAAS,EAAE,OAAO,CAAC,CAAC;IAC/E,CAAC;CACJ;AAGD,MAAM,UAAU,UAAU,CAAC,EAAE,OAAO,EAAE,KAAK,EAAE,KAAK,EAAE,UAAU,EAAE,QAAQ,GAAG,KAAK,EAAE,SAAS,GAAG,KAAK,EAAE,UAAU,GAAG,CAAC,IAAI,EAAE,IAAI,EAAE,CAAC,CAAC,EAAiB;IAC9I,IAAI,KAAK,GAAG,CAAC,EAAE,CAAC;QACZ,MAAM,KAAK,CAAC,yCAAyC,KAAK,EAAE,CAAC,CAAC;IAClE,CAAC;IAED,MAAM,CAAC,YAAY,EAAE,WAAW,CAAC,GAAG,UAAU,CAAC;IAC/C,MAAM,WAAW,GAAG,CAAC,IAAI,KAAK,CAAC;IAE/B,IAAI,CAAC,YAAY,IAAI,IAAI,IAAI,YAAY,GAAG,WAAW,IAAI,CAAC,CAAC;QACzD,WAAW,IAAI,IAAI,IAAI,WAAW,GAAG,WAAW,IAAI,CAAC,EAAE,CAAC;QACxD,MAAM,KAAK,CAAC,wEAAwE,WAAW,GAAG,CAAC,CAAA;IACvG,CAAC;IAED,MAAM,KAAK,GAAG,EAAE,CAAC,KAAK,CAAC,EAAE,KAAK,EAAE,UAAU,EAAE,CAAC,CAAC;IAE9C,MAAM,eAAe,GAAwB,EAAE,CAAC;IAEhD,IAAI,CAAC,GAAG,KAAK,CAAC;IAEd,4CAA4C;IAC5C,MAAM,YAAY,GAAG,KAAK,CAAC,IAAI,CAAC,EAAE,MAAM,EAAE,KAAK,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,OAAO,GAAG,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;IAEjF,KAAK,MAAM,WAAW,IAAI,YAAY,EAAE,CAAC;QACrC,MAAM,WAAW,GAAG,gBAAgB,CAAC,CAAC,EAAE,WAAW,EAAE,QAAQ,EAAE,SAAS,EAAE,gBAAgB,WAAW,EAAE,CAAC,CAAC;QAEzG,CAAC,GAAG,EAAE,CAAC,MAAM,CAAC,YAAY,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,OAAO,EAAE,CAAC,EAAE,IAAI,EAAE,SAAS,WAAW,EAAE,EAAE,CAAC,
CAAC,KAAK,CAAC,WAAW,CAAsB,CAAC;QAC9H,eAAe,CAAC,IAAI,CAAC,WAAW,CAAC,CAAC;IACtC,CAAC;IAED,CAAC,GAAG,gBAAgB,CAAC,CAAC,EAAE,YAAY,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,GAAG,CAAC,EAAE,QAAQ,EAAE,SAAS,EAAE,YAAY,CAAC,CAAC;IAErF,KAAK,IAAI,CAAC,GAAG,YAAY,CAAC,MAAM,GAAG,CAAC,EAAE,CAAC,IAAI,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC;QAChD,CAAC,GAAG,cAAc,CAAC,CAAC,EAAE,eAAe,CAAC,CAAC,CAAC,EAAE,YAAY,CAAC,CAAC,CAAC,EAAE,QAAQ,EAAE,SAAS,EAAE,cAAc,YAAY,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC;IACrH,CAAC;IAED,MAAM,MAAM,GAAG,EAAE,CAAC,MAAM,CAAC,MAAM,CAAC;QAC5B,OAAO,EAAE,KAAK;QACd,UAAU,EAAE,CAAC;QACb,OAAO,EAAE,MAAM;QACf,UAAU,EAAE,UAAU,IAAI,CAAC,KAAK,IAAI,CAAC,CAAC,CAAC,CAAC,SAAS,CAAC,CAAC,CAAC,SAAS,CAAC;QAC9D,IAAI,EAAE,aAAa;KACtB,CAAC,CAAC,KAAK,CAAC,CAAC,CAAsB,CAAC;IAEjC,OAAO,EAAE,CAAC,KAAK,CAAC,EAAE,MAAM,EAAE,KAAK,EAAE,OAAO,EAAE,MAAM,EAAE,IAAI,EAAE,OAAO,EAAE,CAAC,CAAC;AACvE,CAAC;AAGD,MAAM,CAAC,KAAK,UAAU,aAAa,CAAC,eAAyC,EAAE,OAA2B;IACtG,MAAM,KAAK,GAAG,MAAM,EAAE,CAAC,eAAe,CAAC,eAAe,EAAE,OAAO,CAAC,CAAC;IACjE,MAAM,IAAI,GAAG,UAAU,CAAC,EAAE,KAAK,EAAE,CAAC,EAAE,OAAO,EAAE,CAAC,EAAE,KAAK,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,4CAA4C;IACzG,MAAM,EAAE,IAAI,EAAE,GAAG,IAAI,EAAE,GAAG,KAAK,CAAC;IAChC,MAAM,CAAC,MAAM,CAAC,IAAI,EAAE,IAAI,CAAC,CAAC;IAE1B,OAAO,IAAI,CAAC;AAChB,CAAC;AAGD;;;;;;;;;;;;GAYG;AACH,SAAS,gBAAgB,CAAC,CAAoB,EAAE,OAAe,EAAE,QAAiB,EAAE,SAAkB,EAAE,IAAY;IAEhH,MAAM,KAAK,GAAG,EAAE,CAAC,MAAM,CAAC,MAAM,CAAC;QAC3B,OAAO;QACP,UAAU,EAAE,CAAC;QACb,OAAO,EAAE,MAAM;QACf,OAAO,EAAE,CAAC,SAAS;QACnB,iBAAiB,EAAE,UAAU;QAC7B,IAAI,EAAE,GAAG,IAAI,WAAW;KAC3B,CAAC,CAAC;IACH,MAAM,KAAK,GAAG,EAAE,CAAC,MAAM,CAAC,IAAI,CAAC,EAAE,IAAI,EAAE,GAAG,IAAI,SAAS,EAAE,CAAC,CAAC;IAEzD,MAAM,KAAK,GAAG,EAAE,CAAC,MAAM,CAAC,MAAM,CAAC;QAC3B,OAAO;QACP,UAAU,EAAE,CAAC;QACb,OAAO,EAAE,MAAM;QACf,OAAO,EAAE,CAAC,SAAS;QACnB,iBAAiB,EAAE,UAAU;QAC7B,IAAI,EAAE,GAAG,IAAI,WAAW;KAC3B,CAAC,CAAC;IACH,MAAM,KAAK,GAAG,EAAE,CAAC,MAAM,CAAC,IAAI,CAAC,EAAE,IAAI,EAAE,GAAG,IAAI,SAAS,EAAE,CAAC,CAAC;IAEzD,IAAI,OAAO,GAAG,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;IAE7B,IAAI,SAAS,EAAE,CAAC;QACZ,OAAO,GAAG,EAAE,CAAC
,MAAM,CAAC,kBAAkB,CAAC,EAAE,IAAI,EAAE,GAAG,IAAI,cAAc,EAAE,CAAC,CAAC,KAAK,CAAC,OAAO,CAAC,CAAC;IAC3F,CAAC;IAED,OAAO,GAAG,KAAK,CAAC,KAAK,CAAC,OAAO,CAAC,CAAC;IAE/B,OAAO,GAAG,KAAK,CAAC,KAAK,CAAC,OAAO,CAAC,CAAC;IAE/B,IAAI,SAAS,EAAE,CAAC;QACZ,OAAO,GAAG,EAAE,CAAC,MAAM,CAAC,kBAAkB,CAAC,EAAE,IAAI,EAAE,GAAG,IAAI,cAAc,EAAE,CAAC,CAAC,KAAK,CAAC,OAAO,CAAC,CAAC;IAC3F,CAAC;IAED,IAAI,QAAQ,EAAE,CAAC;QACX,IAAI,aAAa,GAAG,CAAC,CAAC;QAEtB,IAAI,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC,IAAI,OAAO,EAAE,CAAC;YACzC,qEAAqE;YACrE,sDAAsD;YACtD,aAAa,GAAG,EAAE,CAAC,MAAM,CAAC,MAAM,CAAC;gBAC7B,OAAO;gBACP,UAAU,EAAE,CAAC;gBACb,OAAO,EAAE,MAAM;gBACf,OAAO,EAAE,CAAC,SAAS;gBACnB,iBAAiB,EAAE,UAAU;gBAC7B,IAAI,EAAE,GAAG,IAAI,WAAW;aAC3B,CAAC,CAAC,KAAK,CAAC,CAAC,CAAsB,CAAC;QACrC,CAAC;QAED,IAAI,SAAS,EAAE,CAAC;YACZ,aAAa,GAAG,EAAE,CAAC,MAAM,CAAC,kBAAkB,CAAC;gBACzC,IAAI,EAAE,GAAG,IAAI,qBAAqB;aACrC,CAAC,CAAC,KAAK,CAAC,aAAa,CAAsB,CAAC;QACjD,CAAC;QAED,OAAO,GAAG,EAAE,CAAC,MAAM,CAAC,GAAG,EAAE,CAAC,KAAK,CAAC;YAC5B,aAAkC;YAClC,OAA4B;SAC/B,CAAC,CAAA;IACN,CAAC;IAED,OAAO,GAAG,KAAK,CAAC,KAAK,CAAC,OAAO,CAAC,CAAC;IAE/B,OAAO,OAA4B,CAAC;AACxC,CAAC;AAGD;;;;;;;;;;;GAWG;AACH,SAAS,cAAc,CAAC,CAAoB,EAAE,IAAuB,EAAE,OAAe,EAAE,QAAiB,EAAE,SAAkB,EAAE,IAAY;IAEvI,MAAM,MAAM,GAAG,EAAE,CAAC,MAAM,CAAC,eAAe,CAAC;QACrC,OAAO;QACP,OAAO,EAAE,MAAM;QACf,UAAU,EAAE,CAAC;QACb,OAAO,EAAE,CAAC;QACV,iBAAiB,EAAE,UAAU;QAC7B,IAAI,EAAE,GAAG,IAAI,SAAS;KACzB,CAAC,CAAC;IAEH,MAAM,MAAM,GAAG,EAAE,CAAC,MAAM,CAAC,WAAW,CAAC,EAAE,IAAI,EAAE,CAAC,CAAC,EAAE,IAAI,EAAE,GAAG,IAAI,qBAAqB,EAAE,CAAC,CAAC;IAEvF,IAAI,OAAO,GAAG,MAAM,CAAC,KAAK,CAAC,CAAC,CAAsB,CAAC;IACnD,OAAO,GAAG,MAAM,CAAC,KAAK,CAAC,CAAC,OAAO,EAAE,IAAI,CAAC,CAAsB,CAAC;IAE7D,OAAO,gBAAgB,CAAC,OAAO,EAAE,OAAO,EAAE,QAAQ,EAAE,SAAS,EAAE,IAAI,CAAC,CAAC;AACzE,CAAC"}
@@ -1,15 +1,10 @@
1
1
  import type { Tensor } from "@tensorflow/tfjs";
2
-
3
-
4
2
  export declare abstract class LazyIterator<T> {
5
3
  abstract next(): Promise<IteratorResult<T>>;
6
4
  }
7
-
8
-
9
5
  export declare abstract class Dataset<T> {
10
6
  abstract iterator(): Promise<LazyIterator<T>>;
11
7
  size: number;
12
8
  }
13
-
14
-
15
9
  export type LossOrMetricFn = (yTrue: Tensor, yPred: Tensor) => Tensor;
10
+ //# sourceMappingURL=tfjs_types.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"tfjs_types.d.ts","sourceRoot":"","sources":["../../src/tfjs_types.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,EAAE,MAAM,kBAAkB,CAAC;AAG/C,MAAM,CAAC,OAAO,CAAC,QAAQ,OAAO,YAAY,CAAC,CAAC;IACxC,QAAQ,CAAC,IAAI,IAAI,OAAO,CAAC,cAAc,CAAC,CAAC,CAAC,CAAC;CAC9C;AAGD,MAAM,CAAC,OAAO,CAAC,QAAQ,OAAO,OAAO,CAAC,CAAC;IACnC,QAAQ,CAAC,QAAQ,IAAI,OAAO,CAAC,YAAY,CAAC,CAAC,CAAC,CAAC;IAC7C,IAAI,EAAE,MAAM,CAAC;CAChB;AAGD,MAAM,MAAM,cAAc,GAAG,CAAC,KAAK,EAAE,MAAM,EAAE,KAAK,EAAE,MAAM,KAAK,MAAM,CAAC"}
@@ -0,0 +1,2 @@
1
+ export {};
2
+ //# sourceMappingURL=tfjs_types.js.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"tfjs_types.js","sourceRoot":"","sources":["../../src/tfjs_types.ts"],"names":[],"mappings":""}
@@ -0,0 +1,28 @@
1
+ import * as tf from "@tensorflow/tfjs";
2
+ /**
3
+ * Calculate the desired scaled image's height and width. The shortest edge will
4
+ * be scaled to match its corresponding target shape's edge. The longer
5
+ * edge might end up larger than intended.
6
+ *
7
+ * @param image_shape the `[height, width]` of the image
8
+ * @param target_shape the intended `[height, width]` of the final scaled image
9
+ */
10
+ export declare function getScaleShape(image_shape: tf.Shape, target_shape: [number, number]): [scaled_height: number, scaled_width: number];
11
+ /**
12
+ * Calculate the starting point for a crop (slice) operation
13
+ * on an image tensor with the shape `[height, width, channels]`.
14
+ *
15
+ * @param image_shape the `[height, width]` of the image
16
+ * @param target_shape the intended `[height, width]` of the final cropped image
17
+ */
18
+ export declare function getRandomCropStart(image_shape: [height: number, width: number], target_shape: [height: number, width: number]): [number, number, number];
19
+ /**
20
+ * Calculate the height and width padding such that the image is
21
+ * divisible by 2^depth.
22
+ *
23
+ * In U-Net image segmentation, the contraction and concatenate
24
+ * operations requires the input image's height and width
25
+ * dimensions to be divisible by 2^depth.
26
+ */
27
+ export declare function getPaddingForSegmentation(image: tf.Tensor3D, depth: number): [height: number, width: number];
28
+ //# sourceMappingURL=utils.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"utils.d.ts","sourceRoot":"","sources":["../../src/utils.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAGvC;;;;;;;GAOG;AACH,wBAAgB,aAAa,CAAC,WAAW,EAAE,EAAE,CAAC,KAAK,EAAE,YAAY,EAAE,CAAC,MAAM,EAAE,MAAM,CAAC,GAAG,CAAC,aAAa,EAAE,MAAM,EAAE,YAAY,EAAE,MAAM,CAAC,CAelI;AAGD;;;;;;GAMG;AACH,wBAAgB,kBAAkB,CAC9B,WAAW,EAAE,CAAC,MAAM,EAAE,MAAM,EAAE,KAAK,EAAE,MAAM,CAAC,EAC5C,YAAY,EAAE,CAAC,MAAM,EAAE,MAAM,EAAE,KAAK,EAAE,MAAM,CAAC,GAC9C,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,CAAC,CAiB1B;AAGD;;;;;;;GAOG;AACH,wBAAgB,yBAAyB,CAAC,KAAK,EAAE,EAAE,CAAC,QAAQ,EAAE,KAAK,EAAE,MAAM,GAAG,CAAC,MAAM,EAAE,MAAM,EAAE,KAAK,EAAE,MAAM,CAAC,CAS5G"}
@@ -1,18 +1,14 @@
1
- import * as tf from "@tensorflow/tfjs";
2
-
3
-
4
1
  /**
5
2
  * Calculate the desired scaled image's height and width. The shortest edge will
6
3
  * be scaled to match its corresponding target shape's edge. The longer
7
4
  * edge might end up larger than intended.
8
- *
5
+ *
9
6
  * @param image_shape the `[height, width]` of the image
10
7
  * @param target_shape the intended `[height, width]` of the final scaled image
11
8
  */
12
- export function getScaleShape(image_shape: tf.Shape, target_shape: [number, number]): [scaled_height: number, scaled_width: number] {
13
- const [img_height, img_width] = image_shape as [number, number, number];
9
+ export function getScaleShape(image_shape, target_shape) {
10
+ const [img_height, img_width] = image_shape;
14
11
  const [target_height, target_width] = target_shape;
15
-
16
12
  // scale based on whichever target_edge / original_edge is largest,
17
13
  // we need the following to be true (1)
18
14
  // height * scale >= target_height
@@ -25,27 +21,20 @@ export function getScaleShape(image_shape: tf.Shape, target_shape: [number, numb
25
21
  const scale_factor = Math.max(target_height / img_height, target_width / img_width);
26
22
  return [Math.round(img_height * scale_factor), Math.round(img_width * scale_factor)];
27
23
  }
28
-
29
-
30
24
  /**
31
25
  * Calculate the starting point for a crop (slice) operation
32
26
  * on an image tensor with the shape `[height, width, channels]`.
33
- *
27
+ *
34
28
  * @param image_shape the `[height, width]` of the image
35
29
  * @param target_shape the intended `[height, width]` of the final cropped image
36
30
  */
37
- export function getRandomCropStart(
38
- image_shape: [height: number, width: number],
39
- target_shape: [height: number, width: number]
40
- ): [number, number, number] {
31
+ export function getRandomCropStart(image_shape, target_shape) {
41
32
  const [img_height, img_width] = image_shape;
42
33
  const [crop_x, crop_y] = target_shape;
43
-
44
34
  if (img_height < crop_x || img_width < crop_y) {
45
35
  throw Error(`getRandomCropShape: cannot crop with a size that's bigger than,` +
46
36
  ` the image. Original [${img_height}, ${img_width}], crop [${crop_x}, ${crop_y}].`);
47
37
  }
48
-
49
38
  // there's a +1 because Math.random()'s range is [0, 1), excluding 1,
50
39
  // hence +1 to ensure the full range of possible crop starting points
51
40
  return [
@@ -53,34 +42,22 @@ export function getRandomCropStart(
53
42
  Math.floor(Math.random() * (img_height - crop_x + 1)),
54
43
  Math.floor(Math.random() * (img_width - crop_y + 1)),
55
44
  0 // not cropping channels, so it starts at the first index
56
- ]
45
+ ];
57
46
  }
58
-
59
-
60
47
  /**
61
48
  * Calculate the height and width padding such that the image is
62
49
  * divisible by 2^depth.
63
- *
50
+ *
64
51
  * In U-Net image segmentation, the contraction and concatenate
65
52
  * operations requires the input image's height and width
66
53
  * dimensions to be divisible by 2^depth.
67
54
  */
68
- export function getPaddingForSegmentation(image: tf.Tensor3D, depth: number): [height: number, width: number] {
55
+ export function getPaddingForSegmentation(image, depth) {
69
56
  const divisible = Math.pow(2, depth);
70
-
71
57
  const [height, width] = image.shape;
72
-
73
58
  return [
74
59
  (Math.ceil(height / divisible)) * divisible - height,
75
60
  (Math.ceil(width / divisible)) * divisible - width,
76
- ]
77
- }
78
-
79
-
80
- export function generateCausalAttentionMask(query_seq_length: number, key_seq_length: number) {
81
- return tf.tidy(() => {
82
- return tf.linalg.bandPart(tf.ones([query_seq_length, key_seq_length]), -1, 0)
83
- .sub(1)
84
- .mul(1e7);
85
- })
61
+ ];
86
62
  }
63
+ //# sourceMappingURL=utils.js.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"utils.js","sourceRoot":"","sources":["../../src/utils.ts"],"names":[],"mappings":"AAGA;;;;;;;GAOG;AACH,MAAM,UAAU,aAAa,CAAC,WAAqB,EAAE,YAA8B;IAC/E,MAAM,CAAC,UAAU,EAAE,SAAS,CAAC,GAAG,WAAuC,CAAC;IACxE,MAAM,CAAC,aAAa,EAAE,YAAY,CAAC,GAAG,YAAY,CAAC;IAEnD,mEAAmE;IACnE,uCAAuC;IACvC,kCAAkC;IAClC,iCAAiC;IACjC,mDAAmD;IACnD,kCAAkC;IAClC,gCAAgC;IAChC,0FAA0F;IAC1F,oEAAoE;IACpE,MAAM,YAAY,GAAG,IAAI,CAAC,GAAG,CAAC,aAAa,GAAG,UAAU,EAAE,YAAY,GAAG,SAAS,CAAC,CAAC;IACpF,OAAO,CAAC,IAAI,CAAC,KAAK,CAAC,UAAU,GAAG,YAAY,CAAC,EAAE,IAAI,CAAC,KAAK,CAAC,SAAS,GAAG,YAAY,CAAC,CAAC,CAAC;AACzF,CAAC;AAGD;;;;;;GAMG;AACH,MAAM,UAAU,kBAAkB,CAC9B,WAA4C,EAC5C,YAA6C;IAE7C,MAAM,CAAC,UAAU,EAAE,SAAS,CAAC,GAAG,WAAW,CAAC;IAC5C,MAAM,CAAC,MAAM,EAAE,MAAM,CAAC,GAAG,YAAY,CAAC;IAEtC,IAAI,UAAU,GAAG,MAAM,IAAI,SAAS,GAAG,MAAM,EAAE,CAAC;QAC5C,MAAM,KAAK,CAAC,iEAAiE;YACzE,yBAAyB,UAAU,KAAK,SAAS,YAAY,MAAM,KAAK,MAAM,IAAI,CAAC,CAAC;IAC5F,CAAC;IAED,qEAAqE;IACrE,qEAAqE;IACrE,OAAO;QACH,uBAAuB;QACvB,IAAI,CAAC,KAAK,CAAC,IAAI,CAAC,MAAM,EAAE,GAAG,CAAC,UAAU,GAAG,MAAM,GAAG,CAAC,CAAC,CAAC;QACrD,IAAI,CAAC,KAAK,CAAC,IAAI,CAAC,MAAM,EAAE,GAAG,CAAC,SAAS,GAAG,MAAM,GAAG,CAAC,CAAC,CAAC;QACpD,CAAC,CAAC,yDAAyD;KAC9D,CAAA;AACL,CAAC;AAGD;;;;;;;GAOG;AACH,MAAM,UAAU,yBAAyB,CAAC,KAAkB,EAAE,KAAa;IACvE,MAAM,SAAS,GAAG,IAAI,CAAC,GAAG,CAAC,CAAC,EAAE,KAAK,CAAC,CAAC;IAErC,MAAM,CAAC,MAAM,EAAE,KAAK,CAAC,GAAG,KAAK,CAAC,KAAK,CAAC;IAEpC,OAAO;QACH,CAAC,IAAI,CAAC,IAAI,CAAC,MAAM,GAAG,SAAS,CAAC,CAAC,GAAG,SAAS,GAAG,MAAM;QACpD,CAAC,IAAI,CAAC,IAAI,CAAC,KAAK,GAAG,SAAS,CAAC,CAAC,GAAG,SAAS,GAAG,KAAK;KACrD,CAAA;AACL,CAAC"}
@@ -0,0 +1,2 @@
1
+ export {};
2
+ //# sourceMappingURL=utils.test.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"utils.test.d.ts","sourceRoot":"","sources":["../../src/utils.test.ts"],"names":[],"mappings":""}
@@ -1,90 +1,63 @@
1
1
  import * as tf from "@tensorflow/tfjs";
2
- import { getScaleShape, getRandomCropStart, generateCausalAttentionMask } from "@/utils";
3
-
2
+ import { getScaleShape, getRandomCropStart } from "@/utils";
3
+ import { causal } from "@/masks";
4
4
  // avoid TFJS node message during Jest testing
5
5
  tf.env().set('IS_NODE', false);
6
-
7
-
8
6
  describe("test custom TFJS utility functions", () => {
9
-
10
7
  test("crop an image using the same shape, results in same shape", async () => {
11
8
  // cropping an image of the same shape
12
- const img_size = [133, 84] as [number, number];
13
- const target_size = [133, 84] as [number, number];
14
-
9
+ const img_size = [133, 84];
10
+ const target_size = [133, 84];
15
11
  expect(getRandomCropStart(img_size, target_size)).toEqual([0, 0, 0]);
16
12
  });
17
-
18
-
19
13
  it("should throw when crop is larger than image", async () => {
20
14
  expect(() => getRandomCropStart([128, 128], [1000, 2000])).toThrow();
21
- })
22
-
23
-
15
+ });
24
16
  test("cropped image shape", async () => {
25
17
  // cropping from wide to tall image
26
18
  for (let i = 0; i < 100; i++) {
27
- const img_size = [4923, 832] as [number, number];
28
- const target_size = [333, 739] as [number, number];
29
-
30
- const [crop_start_h, crop_start_w, channels] = getRandomCropStart(img_size, target_size)
31
-
19
+ const img_size = [4923, 832];
20
+ const target_size = [333, 739];
21
+ const [crop_start_h, crop_start_w, channels] = getRandomCropStart(img_size, target_size);
32
22
  expect(crop_start_h).toBeLessThanOrEqual(img_size[0] - target_size[0]);
33
23
  expect(crop_start_w).toBeLessThanOrEqual(img_size[1] - target_size[1]);
34
24
  }
35
-
36
25
  // cropping from tall to wide image
37
26
  for (let i = 0; i < 100; i++) {
38
- const img_size = [381, 999] as [number, number];
39
- const target_size = [300, 157] as [number, number];
40
-
41
- const [crop_start_h, crop_start_w, channels] = getRandomCropStart(img_size, target_size)
42
-
27
+ const img_size = [381, 999];
28
+ const target_size = [300, 157];
29
+ const [crop_start_h, crop_start_w, channels] = getRandomCropStart(img_size, target_size);
43
30
  expect(crop_start_h).toBeLessThanOrEqual(img_size[0] - target_size[0]);
44
31
  expect(crop_start_w).toBeLessThanOrEqual(img_size[1] - target_size[1]);
45
32
  }
46
33
  });
47
-
48
-
49
34
  test("scale 1:1, results in the same shape", async () => {
50
- const scale = getScaleShape([256, 256], [256, 256])
35
+ const scale = getScaleShape([256, 256], [256, 256]);
51
36
  expect(scale).toEqual([256, 256]);
52
37
  });
53
-
54
-
55
38
  test("scaled image shape", async () => {
56
39
  // scaling squares result in squares
57
- const scale1 = getScaleShape([256, 256], [128, 128])
40
+ const scale1 = getScaleShape([256, 256], [128, 128]);
58
41
  expect(scale1).toEqual([128, 128]);
59
-
60
- const scale2 = getScaleShape([128, 128], [256, 256])
42
+ const scale2 = getScaleShape([128, 128], [256, 256]);
61
43
  expect(scale2).toEqual([256, 256]);
62
-
63
- const scale3 = getScaleShape([123, 123], [321, 321])
44
+ const scale3 = getScaleShape([123, 123], [321, 321]);
64
45
  expect(scale3).toEqual([321, 321]);
65
-
66
- const scale4 = getScaleShape([321, 321], [123, 123])
46
+ const scale4 = getScaleShape([321, 321], [123, 123]);
67
47
  expect(scale4).toEqual([123, 123]);
68
-
69
48
  // scaling rectangles result in rectangles
70
- const scale5 = getScaleShape([640, 480], [1280, 960])
49
+ const scale5 = getScaleShape([640, 480], [1280, 960]);
71
50
  expect(scale5).toEqual([1280, 960]);
72
-
73
- const scale6 = getScaleShape([480, 640], [960, 1280])
51
+ const scale6 = getScaleShape([480, 640], [960, 1280]);
74
52
  expect(scale6).toEqual([960, 1280]);
75
-
76
- const [scale7_h, scale7_w] = getScaleShape([777, 555], [555, 333])
53
+ const [scale7_h, scale7_w] = getScaleShape([777, 555], [555, 333]);
77
54
  expect(scale7_h).toBeGreaterThan(scale7_w);
78
-
79
- const [scale8_h, scale8_w] = getScaleShape([555, 777], [333, 555])
55
+ const [scale8_h, scale8_w] = getScaleShape([555, 777], [333, 555]);
80
56
  expect(scale8_h).toBeLessThan(scale8_w);
81
57
  });
82
-
83
-
84
58
  test("causal attention map", async () => {
85
59
  const seq_len = 4;
86
- const causal_mask = generateCausalAttentionMask(seq_len, seq_len);
87
-
60
+ const causal_mask = causal(seq_len, seq_len);
88
61
  const _ = -1e7;
89
62
  const expected_mask = tf.tensor([
90
63
  [0, _, _, _],
@@ -92,10 +65,9 @@ describe("test custom TFJS utility functions", () => {
92
65
  [0, 0, 0, _],
93
66
  [0, 0, 0, 0]
94
67
  ]);
95
-
96
68
  // this might fail due to precision issues on the masked positions,
97
69
  // in which case use less <= to 6 or 12 (number of masked positions x2)
98
70
  expect((await causal_mask.sub(expected_mask).sum().data())[0]).toEqual(0);
99
71
  });
100
-
101
72
  });
73
+ //# sourceMappingURL=utils.test.js.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"utils.test.js","sourceRoot":"","sources":["../../src/utils.test.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AACvC,OAAO,EAAE,aAAa,EAAE,kBAAkB,EAAE,MAAM,SAAS,CAAC;AAC5D,OAAO,EAAE,MAAM,EAAE,MAAM,SAAS,CAAC;AAEjC,8CAA8C;AAC9C,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC,SAAS,EAAE,KAAK,CAAC,CAAC;AAG/B,QAAQ,CAAC,oCAAoC,EAAE,GAAG,EAAE;IAEhD,IAAI,CAAC,2DAA2D,EAAE,KAAK,IAAI,EAAE;QACzE,sCAAsC;QACtC,MAAM,QAAQ,GAAG,CAAC,GAAG,EAAE,EAAE,CAAqB,CAAC;QAC/C,MAAM,WAAW,GAAG,CAAC,GAAG,EAAE,EAAE,CAAqB,CAAC;QAElD,MAAM,CAAC,kBAAkB,CAAC,QAAQ,EAAE,WAAW,CAAC,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;IACzE,CAAC,CAAC,CAAC;IAGH,EAAE,CAAC,6CAA6C,EAAE,KAAK,IAAI,EAAE;QACzD,MAAM,CAAC,GAAG,EAAE,CAAC,kBAAkB,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,IAAI,EAAE,IAAI,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;IACzE,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,qBAAqB,EAAE,KAAK,IAAI,EAAE;QACnC,mCAAmC;QACnC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,GAAG,EAAE,CAAC,EAAE,EAAE,CAAC;YAC3B,MAAM,QAAQ,GAAG,CAAC,IAAI,EAAE,GAAG,CAAqB,CAAC;YACjD,MAAM,WAAW,GAAG,CAAC,GAAG,EAAE,GAAG,CAAqB,CAAC;YAEnD,MAAM,CAAC,YAAY,EAAE,YAAY,EAAE,QAAQ,CAAC,GAAG,kBAAkB,CAAC,QAAQ,EAAE,WAAW,CAAC,CAAA;YAExF,MAAM,CAAC,YAAY,CAAC,CAAC,mBAAmB,CAAC,QAAQ,CAAC,CAAC,CAAC,GAAG,WAAW,CAAC,CAAC,CAAC,CAAC,CAAC;YACvE,MAAM,CAAC,YAAY,CAAC,CAAC,mBAAmB,CAAC,QAAQ,CAAC,CAAC,CAAC,GAAG,WAAW,CAAC,CAAC,CAAC,CAAC,CAAC;QAC3E,CAAC;QAED,mCAAmC;QACnC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,GAAG,EAAE,CAAC,EAAE,EAAE,CAAC;YAC3B,MAAM,QAAQ,GAAG,CAAC,GAAG,EAAE,GAAG,CAAqB,CAAC;YAChD,MAAM,WAAW,GAAG,CAAC,GAAG,EAAE,GAAG,CAAqB,CAAC;YAEnD,MAAM,CAAC,YAAY,EAAE,YAAY,EAAE,QAAQ,CAAC,GAAG,kBAAkB,CAAC,QAAQ,EAAE,WAAW,CAAC,CAAA;YAExF,MAAM,CAAC,YAAY,CAAC,CAAC,mBAAmB,CAAC,QAAQ,CAAC,CAAC,CAAC,GAAG,WAAW,CAAC,CAAC,CAAC,CAAC,CAAC;YACvE,MAAM,CAAC,YAAY,CAAC,CAAC,mBAAmB,CAAC,QAAQ,CAAC,CAAC,CAAC,GAAG,WAAW,CAAC,CAAC,CAAC,CAAC,CAAC;QAC3E,CAAC;IACL,CAAC,CAAC,CAAC;IAGH,IAAI,CAAC,sCAAsC,EAAE,KAAK,IAAI,EAAE;QACpD,MAAM,KAAK,GAAG,aAAa,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,GAAG,EA
AE,GAAG,CAAC,CAAC,CAAA;QACnD,MAAM,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAC;IACtC,CAAC,CAAC,CAAC;IAGH,IAAI,CAAC,oBAAoB,EAAE,KAAK,IAAI,EAAE;QAClC,oCAAoC;QACpC,MAAM,MAAM,GAAG,aAAa,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAA;QACpD,MAAM,CAAC,MAAM,CAAC,CAAC,OAAO,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAC;QAEnC,MAAM,MAAM,GAAG,aAAa,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAA;QACpD,MAAM,CAAC,MAAM,CAAC,CAAC,OAAO,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAC;QAEnC,MAAM,MAAM,GAAG,aAAa,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAA;QACpD,MAAM,CAAC,MAAM,CAAC,CAAC,OAAO,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAC;QAEnC,MAAM,MAAM,GAAG,aAAa,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAA;QACpD,MAAM,CAAC,MAAM,CAAC,CAAC,OAAO,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAC;QAEnC,0CAA0C;QAC1C,MAAM,MAAM,GAAG,aAAa,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,IAAI,EAAE,GAAG,CAAC,CAAC,CAAA;QACrD,MAAM,CAAC,MAAM,CAAC,CAAC,OAAO,CAAC,CAAC,IAAI,EAAE,GAAG,CAAC,CAAC,CAAC;QAEpC,MAAM,MAAM,GAAG,aAAa,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,IAAI,CAAC,CAAC,CAAA;QACrD,MAAM,CAAC,MAAM,CAAC,CAAC,OAAO,CAAC,CAAC,GAAG,EAAE,IAAI,CAAC,CAAC,CAAC;QAEpC,MAAM,CAAC,QAAQ,EAAE,QAAQ,CAAC,GAAG,aAAa,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAA;QAClE,MAAM,CAAC,QAAQ,CAAC,CAAC,eAAe,CAAC,QAAQ,CAAC,CAAC;QAE3C,MAAM,CAAC,QAAQ,EAAE,QAAQ,CAAC,GAAG,aAAa,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAA;QAClE,MAAM,CAAC,QAAQ,CAAC,CAAC,YAAY,CAAC,QAAQ,CAAC,CAAC;IAC5C,CAAC,CAAC,CAAC;IAGH,IAAI,CAAC,sBAAsB,EAAE,KAAK,IAAI,EAAE;QACpC,MAAM,OAAO,GAAG,CAAC,CAAC;QAClB,MAAM,WAAW,GAAG,MAAM,CAAC,OAAO,EAAE,OAAO,CAAC,CAAC;QAE7C,MAAM,CAAC,GAAG,CAAC,GAAG,CAAC;QACf,MAAM,aAAa,GAAG,EAAE,CAAC,MAAM,CAAC;YAC5B,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC;YACZ,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC;YACZ,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC;YACZ,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC;SACf,CAAC
,CAAC;QAEH,mEAAmE;QACnE,uEAAuE;QACvE,MAAM,CAAC,CAAC,MAAM,WAAW,CAAC,GAAG,CAAC,aAAa,CAAC,CAAC,GAAG,EAAE,CAAC,IAAI,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC;IAC9E,CAAC,CAAC,CAAC;AAEP,CAAC,CAAC,CAAC"}
@@ -0,0 +1,10 @@
1
+ import type { Tensor } from "@tensorflow/tfjs";
2
+ export declare abstract class LazyIterator<T> {
3
+ abstract next(): Promise<IteratorResult<T>>;
4
+ }
5
+ export declare abstract class Dataset<T> {
6
+ abstract iterator(): Promise<LazyIterator<T>>;
7
+ size: number;
8
+ }
9
+ export type LossOrMetricFn = (yTrue: Tensor, yPred: Tensor) => Tensor;
10
+ //# sourceMappingURL=tfjs_types.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"tfjs_types.d.ts","sourceRoot":"","sources":["../src/tfjs_types.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,EAAE,MAAM,kBAAkB,CAAC;AAG/C,MAAM,CAAC,OAAO,CAAC,QAAQ,OAAO,YAAY,CAAC,CAAC;IACxC,QAAQ,CAAC,IAAI,IAAI,OAAO,CAAC,cAAc,CAAC,CAAC,CAAC,CAAC;CAC9C;AAGD,MAAM,CAAC,OAAO,CAAC,QAAQ,OAAO,OAAO,CAAC,CAAC;IACnC,QAAQ,CAAC,QAAQ,IAAI,OAAO,CAAC,YAAY,CAAC,CAAC,CAAC,CAAC;IAC7C,IAAI,EAAE,MAAM,CAAC;CAChB;AAGD,MAAM,MAAM,cAAc,GAAG,CAAC,KAAK,EAAE,MAAM,EAAE,KAAK,EAAE,MAAM,KAAK,MAAM,CAAC"}
@@ -0,0 +1,2 @@
1
+ export {};
2
+ //# sourceMappingURL=tfjs_types.js.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"tfjs_types.js","sourceRoot":"","sources":["../src/tfjs_types.ts"],"names":[],"mappings":""}
@@ -0,0 +1,28 @@
1
+ import * as tf from "@tensorflow/tfjs";
2
+ /**
3
+ * Calculate the desired scaled image's height and width. The shortest edge will
4
+ * be scaled to match its corresponding target shape's edge. The longer
5
+ * edge might end up larger than intended.
6
+ *
7
+ * @param image_shape the `[height, width]` of the image
8
+ * @param target_shape the intended `[height, width]` of the final scaled image
9
+ */
10
+ export declare function getScaleShape(image_shape: tf.Shape, target_shape: [number, number]): [scaled_height: number, scaled_width: number];
11
+ /**
12
+ * Calculate the starting point for a crop (slice) operation
13
+ * on an image tensor with the shape `[height, width, channels]`.
14
+ *
15
+ * @param image_shape the `[height, width]` of the image
16
+ * @param target_shape the intended `[height, width]` of the final cropped image
17
+ */
18
+ export declare function getRandomCropStart(image_shape: [height: number, width: number], target_shape: [height: number, width: number]): [number, number, number];
19
+ /**
20
+ * Calculate the height and width padding such that the image is
21
+ * divisible by 2^depth.
22
+ *
23
+ * In U-Net image segmentation, the contraction and concatenate
24
+ * operations requires the input image's height and width
25
+ * dimensions to be divisible by 2^depth.
26
+ */
27
+ export declare function getPaddingForSegmentation(image: tf.Tensor3D, depth: number): [height: number, width: number];
28
+ //# sourceMappingURL=utils.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"utils.d.ts","sourceRoot":"","sources":["../src/utils.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAGvC;;;;;;;GAOG;AACH,wBAAgB,aAAa,CAAC,WAAW,EAAE,EAAE,CAAC,KAAK,EAAE,YAAY,EAAE,CAAC,MAAM,EAAE,MAAM,CAAC,GAAG,CAAC,aAAa,EAAE,MAAM,EAAE,YAAY,EAAE,MAAM,CAAC,CAelI;AAGD;;;;;;GAMG;AACH,wBAAgB,kBAAkB,CAC9B,WAAW,EAAE,CAAC,MAAM,EAAE,MAAM,EAAE,KAAK,EAAE,MAAM,CAAC,EAC5C,YAAY,EAAE,CAAC,MAAM,EAAE,MAAM,EAAE,KAAK,EAAE,MAAM,CAAC,GAC9C,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,CAAC,CAiB1B;AAGD;;;;;;;GAOG;AACH,wBAAgB,yBAAyB,CAAC,KAAK,EAAE,EAAE,CAAC,QAAQ,EAAE,KAAK,EAAE,MAAM,GAAG,CAAC,MAAM,EAAE,MAAM,EAAE,KAAK,EAAE,MAAM,CAAC,CAS5G"}
package/dist/utils.js ADDED
@@ -0,0 +1,63 @@
1
/**
 * Calculate the desired scaled image's height and width. The shortest edge will
 * be scaled to match its corresponding target shape's edge. The longer
 * edge might end up larger than intended.
 *
 * @param image_shape the `[height, width]` of the image (a trailing channels
 *   entry, as in `[height, width, channels]`, is accepted and ignored)
 * @param target_shape the intended `[height, width]` of the final scaled image
 * @returns `[scaled_height, scaled_width]`, each rounded to the nearest integer
 * @throws Error if any of the image or target height/width is not a positive
 *   finite number (e.g. `0`, a negative value, or an unknown `null` dimension)
 */
export function getScaleShape(image_shape, target_shape) {
    const [img_height, img_width] = image_shape;
    const [target_height, target_width] = target_shape;
    // Without this guard a zero/negative/null dimension silently produces
    // NaN or Infinity (e.g. target_height / 0), which only surfaces much
    // later when the result is used as a resize size.
    for (const dim of [img_height, img_width, target_height, target_width]) {
        if (!(Number.isFinite(dim) && dim > 0)) {
            throw new Error(`getScaleShape: image and target dimensions must be positive finite numbers.` +
                ` Original [${img_height}, ${img_width}], target [${target_height}, ${target_width}].`);
        }
    }
    // scale based on whichever target_edge / original_edge is largest,
    // we need the following to be true (1)
    //      height * scale >= target_height
    //      width  * scale >= target_width
    // rearranging to get an equivalent requirement (2)
    //      scale >= target_height / height
    //      scale >= target_width / width
    // by picking the scale value that's largest of the two, we satisfy (2), and therefore (1)
    // it may be more intuitive to think of scale as scale_h and scale_w
    const scale_factor = Math.max(target_height / img_height, target_width / img_width);
    return [Math.round(img_height * scale_factor), Math.round(img_width * scale_factor)];
}
24
/**
 * Calculate a random starting point for a crop (slice) operation
 * on an image tensor with the shape `[height, width, channels]`.
 *
 * Every valid start is equally likely. With a slack of
 * `img_height - crop_height` rows there are `slack + 1` possible starting
 * rows (0..slack inclusive). `Math.random()` returns a value in [0, 1), so
 * `Math.floor(Math.random() * (slack + 1))` is an integer in [0, slack];
 * the same holds for the width.
 *
 * @param image_shape the `[height, width]` of the image
 * @param target_shape the intended `[height, width]` of the final cropped image
 * @returns `[start_height, start_width, 0]`; the channel start is always 0
 *   because channels are not cropped
 * @throws Error if the crop is larger than the image in either dimension
 */
export function getRandomCropStart(image_shape, target_shape) {
    const [img_height, img_width] = image_shape;
    const [crop_height, crop_width] = target_shape;
    if (img_height < crop_height || img_width < crop_width) {
        throw new Error(`getRandomCropStart: cannot crop with a size that's bigger than` +
            ` the image. Original [${img_height}, ${img_width}], crop [${crop_height}, ${crop_width}].`);
    }
    return [
        Math.floor(Math.random() * (img_height - crop_height + 1)),
        Math.floor(Math.random() * (img_width - crop_width + 1)),
        0, // not cropping channels, so it starts at the first index
    ];
}
47
/**
 * Compute how much padding the height and width of an image need so that
 * both become divisible by 2^depth.
 *
 * In U-Net image segmentation, the contraction and concatenate
 * operations require the input image's height and width
 * dimensions to be divisible by 2^depth.
 *
 * @param image a tensor whose `shape` begins with `[height, width, ...]`
 * @param depth exponent of the required multiple (padding targets 2^depth)
 * @returns `[height_padding, width_padding]`: the amount to add to each edge
 *   length to reach the next multiple of 2^depth (0 if already a multiple)
 */
export function getPaddingForSegmentation(image, depth) {
    const multiple = 2 ** depth;
    // distance from `size` up to the next multiple of `multiple`
    const padUpTo = (size) => multiple * Math.ceil(size / multiple) - size;
    const [height, width] = image.shape;
    return [padUpTo(height), padUpTo(width)];
}
63
+ //# sourceMappingURL=utils.js.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"utils.js","sourceRoot":"","sources":["../src/utils.ts"],"names":[],"mappings":"AAGA;;;;;;;GAOG;AACH,MAAM,UAAU,aAAa,CAAC,WAAqB,EAAE,YAA8B;IAC/E,MAAM,CAAC,UAAU,EAAE,SAAS,CAAC,GAAG,WAAuC,CAAC;IACxE,MAAM,CAAC,aAAa,EAAE,YAAY,CAAC,GAAG,YAAY,CAAC;IAEnD,mEAAmE;IACnE,uCAAuC;IACvC,kCAAkC;IAClC,iCAAiC;IACjC,mDAAmD;IACnD,kCAAkC;IAClC,gCAAgC;IAChC,0FAA0F;IAC1F,oEAAoE;IACpE,MAAM,YAAY,GAAG,IAAI,CAAC,GAAG,CAAC,aAAa,GAAG,UAAU,EAAE,YAAY,GAAG,SAAS,CAAC,CAAC;IACpF,OAAO,CAAC,IAAI,CAAC,KAAK,CAAC,UAAU,GAAG,YAAY,CAAC,EAAE,IAAI,CAAC,KAAK,CAAC,SAAS,GAAG,YAAY,CAAC,CAAC,CAAC;AACzF,CAAC;AAGD;;;;;;GAMG;AACH,MAAM,UAAU,kBAAkB,CAC9B,WAA4C,EAC5C,YAA6C;IAE7C,MAAM,CAAC,UAAU,EAAE,SAAS,CAAC,GAAG,WAAW,CAAC;IAC5C,MAAM,CAAC,MAAM,EAAE,MAAM,CAAC,GAAG,YAAY,CAAC;IAEtC,IAAI,UAAU,GAAG,MAAM,IAAI,SAAS,GAAG,MAAM,EAAE,CAAC;QAC5C,MAAM,KAAK,CAAC,iEAAiE;YACzE,yBAAyB,UAAU,KAAK,SAAS,YAAY,MAAM,KAAK,MAAM,IAAI,CAAC,CAAC;IAC5F,CAAC;IAED,qEAAqE;IACrE,qEAAqE;IACrE,OAAO;QACH,uBAAuB;QACvB,IAAI,CAAC,KAAK,CAAC,IAAI,CAAC,MAAM,EAAE,GAAG,CAAC,UAAU,GAAG,MAAM,GAAG,CAAC,CAAC,CAAC;QACrD,IAAI,CAAC,KAAK,CAAC,IAAI,CAAC,MAAM,EAAE,GAAG,CAAC,SAAS,GAAG,MAAM,GAAG,CAAC,CAAC,CAAC;QACpD,CAAC,CAAC,yDAAyD;KAC9D,CAAA;AACL,CAAC;AAGD;;;;;;;GAOG;AACH,MAAM,UAAU,yBAAyB,CAAC,KAAkB,EAAE,KAAa;IACvE,MAAM,SAAS,GAAG,IAAI,CAAC,GAAG,CAAC,CAAC,EAAE,KAAK,CAAC,CAAC;IAErC,MAAM,CAAC,MAAM,EAAE,KAAK,CAAC,GAAG,KAAK,CAAC,KAAK,CAAC;IAEpC,OAAO;QACH,CAAC,IAAI,CAAC,IAAI,CAAC,MAAM,GAAG,SAAS,CAAC,CAAC,GAAG,SAAS,GAAG,MAAM;QACpD,CAAC,IAAI,CAAC,IAAI,CAAC,KAAK,GAAG,SAAS,CAAC,CAAC,GAAG,SAAS,GAAG,KAAK;KACrD,CAAA;AACL,CAAC"}
@@ -0,0 +1,2 @@
1
+ export {};
2
+ //# sourceMappingURL=utils.test.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"utils.test.d.ts","sourceRoot":"","sources":["../src/utils.test.ts"],"names":[],"mappings":""}