@simulatte/doppler 0.1.3 → 0.1.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +11 -5
- package/package.json +27 -4
- package/src/client/doppler-api.browser.d.ts +1 -0
- package/src/client/doppler-api.browser.js +288 -0
- package/src/client/doppler-api.d.ts +80 -0
- package/src/client/doppler-api.js +298 -0
- package/src/client/doppler-provider/types.js +1 -1
- package/src/client/doppler-registry.d.ts +23 -0
- package/src/client/doppler-registry.js +88 -0
- package/src/client/doppler-registry.json +39 -0
- package/src/config/execution-contract-check.d.ts +82 -0
- package/src/config/execution-contract-check.js +317 -0
- package/src/config/execution-v0-contract-check.d.ts +94 -0
- package/src/config/execution-v0-contract-check.js +251 -0
- package/src/config/execution-v0-graph-contract-check.d.ts +20 -0
- package/src/config/execution-v0-graph-contract-check.js +64 -0
- package/src/config/kernel-path-contract-check.d.ts +76 -0
- package/src/config/kernel-path-contract-check.js +479 -0
- package/src/config/kernel-path-loader.d.ts +16 -0
- package/src/config/kernel-path-loader.js +54 -0
- package/src/config/kernels/kernel-ref-digests.js +12 -0
- package/src/config/kernels/registry.json +556 -0
- package/src/config/loader.js +90 -67
- package/src/config/merge-contract-check.d.ts +16 -0
- package/src/config/merge-contract-check.js +321 -0
- package/src/config/merge-helpers.d.ts +58 -0
- package/src/config/merge-helpers.js +54 -0
- package/src/config/merge.js +3 -6
- package/src/config/presets/models/janus-text.json +27 -0
- package/src/config/quantization-contract-check.d.ts +12 -0
- package/src/config/quantization-contract-check.js +91 -0
- package/src/config/required-inference-fields-contract-check.d.ts +24 -0
- package/src/config/required-inference-fields-contract-check.js +231 -0
- package/src/config/schema/browser-suite-metrics.schema.d.ts +17 -0
- package/src/config/schema/browser-suite-metrics.schema.js +46 -0
- package/src/config/schema/conversion-report.schema.d.ts +40 -0
- package/src/config/schema/conversion-report.schema.js +108 -0
- package/src/config/schema/doppler.schema.js +12 -18
- package/src/config/schema/index.d.ts +22 -0
- package/src/config/schema/index.js +18 -0
- package/src/converter/core.d.ts +10 -0
- package/src/converter/core.js +49 -11
- package/src/converter/parsers/diffusion.js +63 -3
- package/src/converter/tokenizer-utils.js +17 -3
- package/src/formats/rdrr/validation.js +13 -0
- package/src/gpu/kernels/depthwise_conv2d.d.ts +29 -0
- package/src/gpu/kernels/depthwise_conv2d.js +98 -0
- package/src/gpu/kernels/depthwise_conv2d.wgsl +58 -0
- package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +62 -0
- package/src/gpu/kernels/grouped_pointwise_conv2d.d.ts +27 -0
- package/src/gpu/kernels/grouped_pointwise_conv2d.js +92 -0
- package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +47 -0
- package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +51 -0
- package/src/gpu/kernels/index.d.ts +30 -0
- package/src/gpu/kernels/index.js +25 -0
- package/src/gpu/kernels/relu.d.ts +18 -0
- package/src/gpu/kernels/relu.js +45 -0
- package/src/gpu/kernels/relu.wgsl +21 -0
- package/src/gpu/kernels/relu_f16.wgsl +23 -0
- package/src/gpu/kernels/repeat_channels.d.ts +21 -0
- package/src/gpu/kernels/repeat_channels.js +60 -0
- package/src/gpu/kernels/repeat_channels.wgsl +29 -0
- package/src/gpu/kernels/repeat_channels_f16.wgsl +31 -0
- package/src/gpu/kernels/sana_linear_attention.d.ts +27 -0
- package/src/gpu/kernels/sana_linear_attention.js +122 -0
- package/src/gpu/kernels/sana_linear_attention_apply.wgsl +44 -0
- package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +47 -0
- package/src/gpu/kernels/sana_linear_attention_summary.wgsl +47 -0
- package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +49 -0
- package/src/index-browser.d.ts +1 -0
- package/src/index-browser.js +2 -1
- package/src/index.d.ts +1 -0
- package/src/index.js +2 -1
- package/src/inference/browser-harness.js +164 -38
- package/src/inference/pipelines/diffusion/init.js +14 -0
- package/src/inference/pipelines/diffusion/pipeline.js +206 -77
- package/src/inference/pipelines/diffusion/sana-transformer.d.ts +53 -0
- package/src/inference/pipelines/diffusion/sana-transformer.js +738 -0
- package/src/inference/pipelines/diffusion/scheduler.d.ts +17 -1
- package/src/inference/pipelines/diffusion/scheduler.js +91 -3
- package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +6 -4
- package/src/inference/pipelines/diffusion/text-encoder-gpu.js +270 -0
- package/src/inference/pipelines/diffusion/text-encoder.js +18 -1
- package/src/inference/pipelines/diffusion/types.d.ts +4 -0
- package/src/inference/pipelines/diffusion/vae.js +782 -78
- package/src/inference/pipelines/text/config.d.ts +5 -0
- package/src/inference/pipelines/text/config.js +1 -1
- package/src/inference/pipelines/text/execution-v0.js +141 -101
- package/src/inference/pipelines/text/init.js +41 -10
- package/src/inference/pipelines/text.js +7 -1
- package/src/rules/execution-rules-contract-check.d.ts +17 -0
- package/src/rules/execution-rules-contract-check.js +245 -0
- package/src/rules/kernels/depthwise-conv2d.rules.json +6 -0
- package/src/rules/kernels/grouped-pointwise-conv2d.rules.json +6 -0
- package/src/rules/kernels/relu.rules.json +6 -0
- package/src/rules/kernels/repeat-channels.rules.json +6 -0
- package/src/rules/kernels/sana-linear-attention.rules.json +6 -0
- package/src/rules/layer-pattern-contract-check.d.ts +17 -0
- package/src/rules/layer-pattern-contract-check.js +231 -0
- package/src/rules/rule-registry.d.ts +28 -0
- package/src/rules/rule-registry.js +38 -0
- package/src/tooling/conversion-config-materializer.d.ts +24 -0
- package/src/tooling/conversion-config-materializer.js +99 -0
- package/src/tooling/lean-execution-contract-runner.d.ts +43 -0
- package/src/tooling/lean-execution-contract-runner.js +158 -0
- package/src/tooling/lean-execution-contract.d.ts +16 -0
- package/src/tooling/lean-execution-contract.js +81 -0
- package/src/tooling/node-convert.d.ts +10 -0
- package/src/tooling/node-converter.js +59 -0
- package/src/tooling/node-webgpu.js +30 -9
- package/src/version.d.ts +2 -0
- package/src/version.js +2 -0
- package/tools/convert-safetensors-node.js +47 -0
- package/tools/doppler-cli.js +167 -6
|
@@ -5,12 +5,19 @@ import { getBuffer, getWeightDtype } from '../../../gpu/weight-buffer.js';
|
|
|
5
5
|
import { CommandRecorder } from '../../../gpu/command-recorder.js';
|
|
6
6
|
import { runConv2D, recordConv2D } from '../../../gpu/kernels/conv2d.js';
|
|
7
7
|
import { runGroupNorm, recordGroupNorm } from '../../../gpu/kernels/groupnorm.js';
|
|
8
|
-
import {
|
|
8
|
+
import { runRMSNorm, recordRMSNorm } from '../../../gpu/kernels/rmsnorm.js';
|
|
9
|
+
import { runSiLU, runSiLURowSplit, recordSiLU, recordSiLURowSplit } from '../../../gpu/kernels/silu.js';
|
|
9
10
|
import { runMatmul, recordMatmul } from '../../../gpu/kernels/matmul.js';
|
|
10
11
|
import { runAttention, recordAttention } from '../../../gpu/kernels/attention.js';
|
|
11
12
|
import { runTranspose, recordTranspose } from '../../../gpu/kernels/transpose.js';
|
|
12
13
|
import { runResidualAdd, runBiasAdd, recordResidualAdd, recordBiasAdd } from '../../../gpu/kernels/residual.js';
|
|
13
14
|
import { runUpsample2D, recordUpsample2D } from '../../../gpu/kernels/upsample2d.js';
|
|
15
|
+
import { runDepthwiseConv2D, recordDepthwiseConv2D } from '../../../gpu/kernels/depthwise_conv2d.js';
|
|
16
|
+
import { runGroupedPointwiseConv2D, recordGroupedPointwiseConv2D } from '../../../gpu/kernels/grouped_pointwise_conv2d.js';
|
|
17
|
+
import { runSanaLinearAttention, recordSanaLinearAttention } from '../../../gpu/kernels/sana_linear_attention.js';
|
|
18
|
+
import { runPixelShuffle, recordPixelShuffle } from '../../../gpu/kernels/pixel_shuffle.js';
|
|
19
|
+
import { runRepeatChannels, recordRepeatChannels } from '../../../gpu/kernels/repeat_channels.js';
|
|
20
|
+
import { runReLU, recordReLU } from '../../../gpu/kernels/relu.js';
|
|
14
21
|
import { castF32ToF16, recordCastF32ToF16 } from '../../../gpu/kernels/cast.js';
|
|
15
22
|
import { f16ToF32 } from '../../../loader/dtype-utils.js';
|
|
16
23
|
import { log } from '../../../debug/index.js';
|
|
@@ -146,31 +153,64 @@ function buildIndexList(weights, prefix) {
|
|
|
146
153
|
return Array.from(indices).sort((a, b) => a - b);
|
|
147
154
|
}
|
|
148
155
|
|
|
156
|
+
/**
 * Expand a per-block setting into an array with one entry per block.
 *
 * An array argument is checked to hold exactly `count` entries and is returned
 * unchanged (same reference). Any non-array value is broadcast into a fresh
 * array of length `count`.
 *
 * @param {*} value - Scalar setting, or an array of per-block settings.
 * @param {number} count - Expected number of blocks.
 * @param {string} label - Setting name used in the error message.
 * @returns {Array} One value per block.
 * @throws {Error} When `value` is an array whose length is not `count`.
 */
function normalizePerBlockValue(value, count, label) {
  if (!Array.isArray(value)) {
    return Array.from({ length: count }).fill(value);
  }
  const actual = value.length;
  if (actual === count) {
    return value;
  }
  throw new Error(`${label} must have ${count} entries, got ${actual}.`);
}
|
|
165
|
+
|
|
166
|
+
/**
 * Number of elements in a tensor, computed as the product of its shape.
 *
 * @param {{shape: number[]}} tensor - Tensor-like object with a non-empty `shape` array.
 * @returns {number} Product of all shape entries.
 * @throws {Error} When `tensor` is nullish or has no non-empty `shape` array.
 */
function tensorElementCount(tensor) {
  const shape = tensor?.shape;
  const hasShape = Array.isArray(shape) && shape.length > 0;
  if (!hasShape) {
    throw new Error('Tensor shape is required.');
  }
  let total = 1;
  shape.forEach((dim) => {
    total *= dim;
  });
  return total;
}
|
|
172
|
+
|
|
149
173
|
/**
 * Build the table of kernel operations used by the VAE decode path.
 *
 * Without a recorder, every entry is the immediate `run*` kernel function.
 * With a recorder, every entry forwards to the matching `record*` function,
 * passing the recorder first and then the caller's arguments unchanged.
 *
 * @param {CommandRecorder|null|undefined} recorder - Optional command recorder.
 * @returns {Object<string, Function>} Kernel operations keyed by op name.
 */
function createKernelOps(recorder) {
  if (recorder) {
    // Adapt a record* function to the same call shape as its run* counterpart.
    const withRecorder = (recordFn) => (...args) => recordFn(recorder, ...args);
    return {
      conv2d: withRecorder(recordConv2D),
      groupNorm: withRecorder(recordGroupNorm),
      rmsNorm: withRecorder(recordRMSNorm),
      silu: withRecorder(recordSiLU),
      siluRowSplit: withRecorder(recordSiLURowSplit),
      matmul: withRecorder(recordMatmul),
      attention: withRecorder(recordAttention),
      transpose: withRecorder(recordTranspose),
      residualAdd: withRecorder(recordResidualAdd),
      biasAdd: withRecorder(recordBiasAdd),
      upsample2d: withRecorder(recordUpsample2D),
      depthwiseConv2d: withRecorder(recordDepthwiseConv2D),
      groupedPointwiseConv2d: withRecorder(recordGroupedPointwiseConv2D),
      sanaLinearAttention: withRecorder(recordSanaLinearAttention),
      pixelShuffle: withRecorder(recordPixelShuffle),
      repeatChannels: withRecorder(recordRepeatChannels),
      relu: withRecorder(recordReLU),
      castF32ToF16: withRecorder(recordCastF32ToF16),
    };
  }
  return {
    conv2d: runConv2D,
    groupNorm: runGroupNorm,
    rmsNorm: runRMSNorm,
    silu: runSiLU,
    siluRowSplit: runSiLURowSplit,
    matmul: runMatmul,
    attention: runAttention,
    transpose: runTranspose,
    residualAdd: runResidualAdd,
    biasAdd: runBiasAdd,
    upsample2d: runUpsample2D,
    depthwiseConv2d: runDepthwiseConv2D,
    groupedPointwiseConv2d: runGroupedPointwiseConv2D,
    sanaLinearAttention: runSanaLinearAttention,
    pixelShuffle: runPixelShuffle,
    repeatChannels: runRepeatChannels,
    relu: runReLU,
    castF32ToF16,
  };
}
|
|
@@ -197,7 +237,7 @@ async function applyConv2D(state, weights, shapes, namePrefix, options = {}, ops
|
|
|
197
237
|
const weightName = `${namePrefix}.weight`;
|
|
198
238
|
const biasName = `${namePrefix}.bias`;
|
|
199
239
|
const weight = getWeight(weights, shapes, weightName);
|
|
200
|
-
const bias =
|
|
240
|
+
const bias = getWeightOptional(weights, shapes, biasName);
|
|
201
241
|
const { outChannels, inChannels, kernelH, kernelW } = getConvShape(weight.shape);
|
|
202
242
|
|
|
203
243
|
if (inChannels !== state.channels) {
|
|
@@ -207,7 +247,7 @@ async function applyConv2D(state, weights, shapes, namePrefix, options = {}, ops
|
|
|
207
247
|
const output = await ops.conv2d(
|
|
208
248
|
state.tensor,
|
|
209
249
|
weight.value,
|
|
210
|
-
bias
|
|
250
|
+
bias?.value ?? null,
|
|
211
251
|
{
|
|
212
252
|
inChannels,
|
|
213
253
|
outChannels,
|
|
@@ -230,6 +270,117 @@ async function applyConv2D(state, weights, shapes, namePrefix, options = {}, ops
|
|
|
230
270
|
};
|
|
231
271
|
}
|
|
232
272
|
|
|
273
|
+
/**
 * Submit copy work encoded on a standalone encoder.
 *
 * When `recorder` is truthy nothing is done here; submission is left to the
 * owner of the recorder. Otherwise `encoder` is finished and its command
 * buffer is submitted to `device.queue`.
 *
 * @param {GPUDevice} device - Device whose queue receives the command buffer.
 * @param {CommandRecorder|null|undefined} recorder - Optional recorder; suppresses submission when set.
 * @param {GPUCommandEncoder} encoder - Encoder holding the copy commands.
 * @returns {Promise<void>}
 */
async function submitCopyWork(device, recorder, encoder) {
  const submitNow = !recorder;
  if (submitNow) {
    const commandBuffer = encoder.finish();
    device.queue.submit([commandBuffer]);
  }
}
|
|
279
|
+
|
|
280
|
+
/**
 * Concatenate channel-major tensors along the channel axis (dim 0).
 *
 * Each input is read as [channels, height, width]. The output is a new
 * [totalChannels, height, width] tensor filled with buffer-to-buffer copies.
 * With a recorder, the copies are encoded on `recorder.getEncoder()`.
 * Without one, a dedicated command encoder is created and submitted through
 * submitCopyWork.
 *
 * @param {Array<{buffer: GPUBuffer, dtype: string, shape: number[]}>} tensors - Inputs sharing dtype, height and width.
 * @param {number} height - Spatial height every input must match (shape[1]).
 * @param {number} width - Spatial width every input must match (shape[2]).
 * @param {CommandRecorder|null|undefined} recorder - Optional command recorder.
 * @returns {Promise<object>} Tensor labelled 'vae_concat_channels' with shape [totalChannels, height, width].
 * @throws {Error} On empty input, a missing device, mismatched dtype or spatial
 *   dims, or a copy that would violate WebGPU's 4-byte alignment rule.
 */
async function concatChannelTensors(tensors, height, width, recorder) {
  if (!Array.isArray(tensors) || tensors.length === 0) {
    throw new Error('concatChannelTensors requires at least one tensor.');
  }
  const device = getDevice();
  if (!device) {
    throw new Error('Channel tensor concatenation requires a WebGPU device.');
  }
  const dtype = tensors[0].dtype;
  const bytesPerElement = dtypeBytes(dtype);
  const planeBytes = height * width * bytesPerElement;

  // Validate every input (and its destination range) before allocating, so a
  // rejected call does not leave an acquired buffer behind.
  let totalChannels = 0;
  for (const tensor of tensors) {
    if (tensor.dtype !== dtype) {
      throw new Error('concatChannelTensors requires matching dtypes.');
    }
    if (tensor.shape[1] !== height || tensor.shape[2] !== width) {
      throw new Error('concatChannelTensors requires matching spatial dimensions.');
    }
    const dstOffset = totalChannels * planeBytes;
    const byteLength = tensor.shape[0] * planeBytes;
    // copyBufferToBuffer requires 4-byte aligned offsets and sizes. A violation
    // only surfaces as an asynchronous validation error that invalidates the
    // encoder (and, with a recorder, all other recorded work), so fail loudly
    // here instead. Example: 2-byte elements with an odd height*width.
    if (dstOffset % 4 !== 0 || byteLength % 4 !== 0) {
      throw new Error(
        `concatChannelTensors requires 4-byte aligned copies; got offset ${dstOffset} and size ${byteLength} bytes.`
      );
    }
    totalChannels += tensor.shape[0];
  }

  const output = acquireBuffer(totalChannels * planeBytes, undefined, 'vae_concat_channels');
  const encoder = recorder ? recorder.getEncoder() : device.createCommandEncoder({ label: 'vae_concat_channels' });
  let channelOffset = 0;
  for (const tensor of tensors) {
    encoder.copyBufferToBuffer(
      tensor.buffer,
      0,
      output,
      channelOffset * planeBytes,
      tensor.shape[0] * planeBytes
    );
    channelOffset += tensor.shape[0];
  }
  await submitCopyWork(device, recorder, encoder);
  return createTensor(output, dtype, [totalChannels, height, width], 'vae_concat_channels');
}
|
|
318
|
+
|
|
319
|
+
/**
 * Copy channels [startChannel, startChannel + channelCount) of a
 * channel-major tensor into a new [channelCount, height, width] tensor.
 *
 * With a recorder, the copy is encoded on `recorder.getEncoder()`. Without
 * one, a dedicated command encoder is created and submitted through
 * submitCopyWork.
 *
 * @param {{buffer: GPUBuffer, dtype: string, shape?: number[]}} tensor - Source tensor.
 * @param {number} startChannel - First channel to copy.
 * @param {number} channelCount - Number of channels to copy.
 * @param {number} height - Spatial height of each channel plane.
 * @param {number} width - Spatial width of each channel plane.
 * @param {CommandRecorder|null|undefined} recorder - Optional command recorder.
 * @returns {Promise<object>} Tensor labelled 'vae_slice_channels' with shape [channelCount, height, width].
 * @throws {Error} On a missing device, an out-of-range slice, or a copy that
 *   would violate WebGPU's 4-byte alignment rule.
 */
async function sliceChannelTensor(tensor, startChannel, channelCount, height, width, recorder) {
  const device = getDevice();
  if (!device) {
    throw new Error('Channel tensor slicing requires a WebGPU device.');
  }
  const bytesPerElement = dtypeBytes(tensor.dtype);
  const channelSize = height * width * bytesPerElement;
  const endChannel = startChannel + channelCount;

  // The backing GPUBuffer can be larger than the logical tensor (for example,
  // if acquireBuffer rounds allocations up). WebGPU would then not flag an
  // out-of-range read, so check against the tensor's element count when its
  // shape is known.
  if (Array.isArray(tensor.shape) && tensor.shape.length > 0) {
    const available = tensor.shape.reduce((acc, dim) => acc * dim, 1);
    if (startChannel < 0 || channelCount < 0 || endChannel * height * width > available) {
      throw new Error(
        `sliceChannelTensor range [${startChannel}, ${endChannel}) exceeds a tensor of ${available} elements at ${height}x${width}.`
      );
    }
  }

  const srcOffset = startChannel * channelSize;
  const byteLength = channelCount * channelSize;
  // copyBufferToBuffer requires 4-byte aligned offsets and sizes. A violation
  // only surfaces as an asynchronous validation error that invalidates the
  // encoder, so reject it here.
  if (srcOffset % 4 !== 0 || byteLength % 4 !== 0) {
    throw new Error(
      `sliceChannelTensor requires 4-byte aligned copies; got offset ${srcOffset} and size ${byteLength} bytes.`
    );
  }

  const output = acquireBuffer(byteLength, undefined, 'vae_slice_channels');
  const encoder = recorder ? recorder.getEncoder() : device.createCommandEncoder({ label: 'vae_slice_channels' });
  encoder.copyBufferToBuffer(tensor.buffer, srcOffset, output, 0, byteLength);
  await submitCopyWork(device, recorder, encoder);
  return createTensor(output, tensor.dtype, [channelCount, height, width], 'vae_slice_channels');
}
|
|
338
|
+
|
|
339
|
+
/**
 * Apply RMSNorm across the channel axis of a channel-major feature map.
 *
 * Steps:
 *  1. View `state.tensor` as [channels, height*width] and transpose it so
 *     each spatial position becomes one row.
 *  2. RMS-normalize each row over its `channels` values with `normWeight`.
 *  3. Optionally add `normBias`.
 *  4. Transpose back and reshape to [channels, height, width].
 *
 * Intermediates created here are released via `release`; `state.tensor`
 * itself is not released.
 *
 * @param {{tensor: object, channels: number, height: number, width: number}} state
 * @param {{value: *}} normWeight - Norm scale weight; its `.value` is passed to ops.rmsNorm.
 * @param {object|null|undefined} normBias - Optional bias weight object, passed to createBiasTensor.
 * @param {number} eps - RMSNorm epsilon.
 * @param {object} ops - Kernel ops table (as built by createKernelOps).
 * @param {(buffer: GPUBuffer) => void} release - Buffer release callback.
 * @returns {Promise<{tensor: object, channels: number, height: number, width: number}>}
 */
async function runChannelwiseRmsNorm(state, normWeight, normBias, eps, ops, release) {
  const { channels, height, width } = state;
  const pixels = height * width;

  const asMatrix = reshapeTensor(state.tensor, [channels, pixels], 'vae_rmsnorm_channels_spatial');
  const rows = await ops.transpose(asMatrix, channels, pixels);
  const normalized = await ops.rmsNorm(rows, normWeight.value, eps, {
    batchSize: pixels,
    hiddenSize: channels,
  });
  release(rows.buffer);

  let result = normalized;
  if (normBias) {
    const bias = createBiasTensor(normBias, `${normBias.name ?? 'vae_rmsnorm_bias'}`, normalized.dtype);
    result = await ops.biasAdd(normalized, bias, pixels, channels);
    release(normalized.buffer);
  }

  const backToChannels = await ops.transpose(result, pixels, channels);
  release(result.buffer);

  const tensor = reshapeTensor(backToChannels, [channels, height, width], 'vae_rmsnorm_output');
  return { tensor, channels, height, width };
}
|
|
363
|
+
|
|
364
|
+
/**
 * Convert a channel-major feature map into one row per spatial position.
 *
 * `state.tensor` is viewed as [channels, height*width] and transposed with
 * `ops.transpose`.
 *
 * @param {{tensor: object, channels: number, height: number, width: number}} state
 * @param {object} ops - Kernel ops table providing `transpose`.
 * @returns {Promise<{tensor: object, numTokens: number}>} Transposed tensor and the number of spatial positions.
 */
async function channelsToTokens(state, ops) {
  const numTokens = state.height * state.width;
  const matrix = reshapeTensor(state.tensor, [state.channels, numTokens], 'vae_channels_spatial');
  return { tensor: await ops.transpose(matrix, state.channels, numTokens), numTokens };
}
|
|
373
|
+
|
|
374
|
+
/**
 * Convert token-major rows back into a channel-major feature map.
 *
 * `tokens` is transposed from [height*width, channels] to
 * [channels, height*width] and then reshaped to [channels, height, width].
 *
 * @param {object} tokens - Token-major tensor.
 * @param {number} channels - Channels per token.
 * @param {number} height - Output spatial height.
 * @param {number} width - Output spatial width.
 * @param {object} ops - Kernel ops table providing `transpose`.
 * @returns {Promise<{tensor: object, channels: number, height: number, width: number}>}
 */
async function tokensToChannels(tokens, channels, height, width, ops) {
  const transposed = await ops.transpose(tokens, height * width, channels);
  const tensor = reshapeTensor(transposed, [channels, height, width], 'vae_tokens_channels');
  return { tensor, channels, height, width };
}
|
|
383
|
+
|
|
233
384
|
async function runResnetBlock(state, weights, shapes, prefix, config, ops, release) {
|
|
234
385
|
const numGroups = config.numGroups;
|
|
235
386
|
const eps = config.eps;
|
|
@@ -464,6 +615,553 @@ async function runMidBlockAttention(state, weights, shapes, prefix, config, ops,
|
|
|
464
615
|
};
|
|
465
616
|
}
|
|
466
617
|
|
|
618
|
+
/**
 * AutoencoderDC decoder input stage: `vae.decoder.conv_in` plus a
 * channel-repeat shortcut.
 *
 * The latent's channels are repeated until they match the last entry of
 * `config.decoder_block_out_channels`. That repeated tensor is added
 * elementwise to the conv_in output (computed via applyConv2D with pad 1).
 *
 * @param {{tensor: object, channels: number, height: number, width: number}} state
 * @param {*} weights - Weight store forwarded to applyConv2D.
 * @param {*} shapes - Shape store forwarded to applyConv2D.
 * @param {{decoder_block_out_channels: number[]}} config
 * @param {object} ops - Kernel ops table.
 * @param {(buffer: GPUBuffer) => void} release - Buffer release callback.
 * @returns {Promise<{tensor: object, channels: number, height: number, width: number}>}
 * @throws {Error} When decoder_block_out_channels is missing/empty or not an
 *   integer multiple (>= 1) of the input channel count.
 */
async function runAutoencoderDCInputProjection(state, weights, shapes, config, ops, release) {
  const blockOutChannels = config.decoder_block_out_channels;
  const hasBlocks = Array.isArray(blockOutChannels) && blockOutChannels.length > 0;
  if (!hasBlocks) {
    throw new Error('AutoencoderDC decode requires decoder_block_out_channels.');
  }
  const [targetChannels] = blockOutChannels.slice(-1);
  const repeats = targetChannels / state.channels;
  const validRepeats = Number.isInteger(repeats) && repeats >= 1;
  if (!validRepeats) {
    throw new Error(
      `AutoencoderDC input shortcut requires an integer repeat factor; got ${targetChannels}/${state.channels}.`
    );
  }

  // The shortcut is taken before conv_in runs, from the same input tensor.
  const shortcut = await ops.repeatChannels(state.tensor, {
    inChannels: state.channels,
    height: state.height,
    width: state.width,
    repeats,
  });
  const projected = await applyConv2D(state, weights, shapes, 'vae.decoder.conv_in', { pad: 1 }, ops, release);

  const { channels, height, width } = projected;
  const flatSize = channels * height * width;
  const summed = await ops.residualAdd(
    reshapeTensor(projected.tensor, [flatSize], 'vae_dc_conv_in'),
    reshapeTensor(shortcut, [flatSize], 'vae_dc_conv_in_shortcut'),
    flatSize,
    { useVec4: true }
  );
  release(projected.tensor.buffer);
  release(shortcut.buffer);

  return {
    tensor: reshapeTensor(summed, [channels, height, width], 'vae_dc_conv_in_out'),
    channels,
    height,
    width,
  };
}
|
|
654
|
+
|
|
655
|
+
/**
 * Run one AutoencoderDC decoder up block: 2x spatial upsampling with a
 * parameter-free shortcut, on a channel-major [channels, height, width] state.
 *
 * Main path, selected by `config.upsample_block_type`:
 *  - 'interpolate': ops.upsample2d by 2, then a conv whose weight produces
 *    the block's output channels directly.
 *  - 'pixel_shuffle': a conv at the input resolution whose weight produces
 *    outChannels * 4 channels, then ops.pixelShuffle by 2 down to outChannels.
 * Shortcut path: ops.repeatChannels to outChannels * 4 channels, then
 * ops.pixelShuffle by 2. The two paths are added elementwise.
 *
 * `state.tensor` and every intermediate are released via `release`.
 *
 * @param {{tensor: object, channels: number, height: number, width: number}} state
 * @param {*} weights - Weight store for getWeight/getWeightOptional.
 * @param {*} shapes - Shape store for getWeight/getWeightOptional.
 * @param {string} prefix - Block weight prefix (`${prefix}.conv.weight`, optional `${prefix}.conv.bias`).
 * @param {{upsample_block_type: string}} config
 * @param {object} ops - Kernel ops table.
 * @param {(buffer: GPUBuffer) => void} release - Buffer release callback.
 * @param {*} recorder - Not used by this block; kept for signature compatibility.
 * @returns {Promise<{tensor: object, channels: number, height: number, width: number}>}
 * @throws {Error} On an input channel mismatch, an unsupported upsample_block_type,
 *   an invalid conv output channel count, or non-integer shortcut repeats.
 */
async function runAutoencoderDCUpBlock(state, weights, shapes, prefix, config, ops, release, recorder) {
  const convWeight = getWeight(weights, shapes, `${prefix}.conv.weight`);
  const convBias = getWeightOptional(weights, shapes, `${prefix}.conv.bias`);
  const { outChannels: convOutChannels, inChannels, kernelH, kernelW } = getConvShape(convWeight.shape);
  if (inChannels !== state.channels) {
    throw new Error(
      `AutoencoderDC up block "${prefix}" expected ${inChannels} input channels, got ${state.channels}.`
    );
  }
  const upsampleType = config.upsample_block_type;
  if (upsampleType !== 'interpolate' && upsampleType !== 'pixel_shuffle') {
    throw new Error(
      `Unsupported AutoencoderDC upsample_block_type "${config.upsample_block_type}".`
    );
  }
  const isPixelShuffle = upsampleType === 'pixel_shuffle';

  const factor = 2;
  const patchArea = factor * factor;
  // In pixel_shuffle mode the conv weight already carries the factor^2
  // expansion (it is shuffled back into space afterwards), so the block's
  // output channel count is the weight's output dimension divided by factor^2.
  const outChannels = isPixelShuffle ? convOutChannels / patchArea : convOutChannels;
  if (!Number.isInteger(outChannels) || outChannels < 1) {
    throw new Error(
      `AutoencoderDC up block "${prefix}" has an invalid conv output channel count (${convOutChannels}) for upsample_block_type "${upsampleType}".`
    );
  }
  const outHeight = state.height * factor;
  const outWidth = state.width * factor;
  // "Same" padding for odd kernels; equals the previous fixed pad of 1 for 3x3.
  const pad = Math.floor(kernelH / 2);
  const shortcutChannels = outChannels * patchArea;
  const shortcutRepeats = shortcutChannels / state.channels;
  if (!Number.isInteger(shortcutRepeats) || shortcutRepeats < 1) {
    throw new Error(
      `AutoencoderDC up block "${prefix}" requires integer shortcut repeats; got ${shortcutChannels}/${state.channels}.`
    );
  }

  let projected;
  if (isPixelShuffle) {
    const conv = await ops.conv2d(state.tensor, convWeight.value, convBias?.value ?? null, {
      inChannels: state.channels,
      outChannels: convOutChannels,
      height: state.height,
      width: state.width,
      kernelH,
      kernelW,
      stride: 1,
      pad,
    });
    projected = await ops.pixelShuffle(conv, {
      outChannels,
      outHeight,
      outWidth,
      gridWidth: state.width,
      gridHeight: state.height,
      patchSize: factor,
      patchChannels: convOutChannels,
    });
    release(conv.buffer);
  } else {
    const upsampled = await ops.upsample2d(state.tensor, {
      channels: state.channels,
      height: state.height,
      width: state.width,
      scale: factor,
    });
    projected = await ops.conv2d(
      reshapeTensor(upsampled, [state.channels, outHeight, outWidth], 'vae_dc_upsample'),
      convWeight.value,
      convBias?.value ?? null,
      {
        inChannels: state.channels,
        outChannels,
        height: outHeight,
        width: outWidth,
        kernelH,
        kernelW,
        stride: 1,
        pad,
      }
    );
    release(upsampled.buffer);
  }

  // NOTE(review): assumes ops.pixelShuffle on a channel-major input matches
  // torch.nn.functional.pixel_shuffle channel ordering, and that
  // ops.repeatChannels matches the reference shortcut (repeat_interleave) — confirm.
  const repeated = await ops.repeatChannels(state.tensor, {
    inChannels: state.channels,
    height: state.height,
    width: state.width,
    repeats: shortcutRepeats,
  });
  const shortcut = await ops.pixelShuffle(repeated, {
    outChannels,
    outHeight,
    outWidth,
    gridWidth: state.width,
    gridHeight: state.height,
    patchSize: factor,
    patchChannels: shortcutChannels,
  });
  release(repeated.buffer);
  release(state.tensor.buffer);

  const size = outChannels * outHeight * outWidth;
  const combined = await ops.residualAdd(
    reshapeTensor(projected, [size], 'vae_dc_up_main'),
    reshapeTensor(shortcut, [size], 'vae_dc_up_shortcut'),
    size,
    { useVec4: true }
  );
  release(projected.buffer);
  release(shortcut.buffer);
  return {
    tensor: reshapeTensor(combined, [outChannels, outHeight, outWidth], 'vae_dc_up_out'),
    channels: outChannels,
    height: outHeight,
    width: outWidth,
  };
}
|
|
759
|
+
|
|
760
|
+
/**
 * AutoencoderDC residual block on a channel-major [channels, height, width]
 * state.
 *
 * Computes conv1 (optional bias) -> SiLU -> conv2 (no bias) -> channel-wise
 * RMSNorm (optional bias), then adds the block input.
 *
 * All weight shapes are validated before any kernel runs, so a malformed
 * checkpoint fails without leaking intermediates. `state.tensor` and every
 * intermediate are released via `release`.
 *
 * @param {{tensor: object, channels: number, height: number, width: number}} state
 * @param {*} weights - Weight store for getWeight/getWeightOptional.
 * @param {*} shapes - Shape store for getWeight/getWeightOptional.
 * @param {string} prefix - Block weight prefix (`conv1`, `conv2`, `norm` children).
 * @param {number} eps - RMSNorm epsilon.
 * @param {object} ops - Kernel ops table.
 * @param {(buffer: GPUBuffer) => void} release - Buffer release callback.
 * @returns {Promise<{tensor: object, channels: number, height: number, width: number}>}
 * @throws {Error} When conv channel counts are inconsistent with each other or
 *   with the residual input.
 */
async function runAutoencoderDCResBlock(state, weights, shapes, prefix, eps, ops, release) {
  const conv1Weight = getWeight(weights, shapes, `${prefix}.conv1.weight`);
  const conv1Bias = getWeightOptional(weights, shapes, `${prefix}.conv1.bias`);
  const conv1Shape = getConvShape(conv1Weight.shape);
  const conv2Weight = getWeight(weights, shapes, `${prefix}.conv2.weight`);
  const conv2Shape = getConvShape(conv2Weight.shape);

  if (conv1Shape.inChannels !== state.channels) {
    throw new Error(
      `AutoencoderDC res block "${prefix}" expected ${conv1Shape.inChannels} input channels, got ${state.channels}.`
    );
  }
  if (conv2Shape.inChannels !== conv1Shape.outChannels) {
    throw new Error(
      `AutoencoderDC res block "${prefix}" conv2 expects ${conv2Shape.inChannels} input channels, but conv1 produces ${conv1Shape.outChannels}.`
    );
  }
  // The residual add is elementwise, so the block must preserve channel count.
  if (conv2Shape.outChannels !== state.channels) {
    throw new Error(
      `AutoencoderDC res block "${prefix}" must preserve channel count for its residual; conv2 outputs ${conv2Shape.outChannels}, input has ${state.channels}.`
    );
  }

  const { height, width } = state;
  const conv1Tensor = await ops.conv2d(state.tensor, conv1Weight.value, conv1Bias?.value ?? null, {
    inChannels: conv1Shape.inChannels,
    outChannels: conv1Shape.outChannels,
    height,
    width,
    kernelH: conv1Shape.kernelH,
    kernelW: conv1Shape.kernelW,
    stride: 1,
    pad: Math.floor(conv1Shape.kernelH / 2),
  });
  const conv1 = reshapeTensor(conv1Tensor, [conv1Shape.outChannels, height, width], 'vae_dc_resblock_conv1');
  const activated = await ops.silu(conv1, {
    size: conv1Shape.outChannels * height * width,
    swigluLimit: null,
  });
  release(conv1.buffer);

  // conv2 uses its own kernel size from the weight shape (not a fixed 3x3).
  const conv2 = await ops.conv2d(
    reshapeTensor(activated, [conv1Shape.outChannels, height, width], 'vae_dc_resblock_act'),
    conv2Weight.value,
    null,
    {
      inChannels: conv1Shape.outChannels,
      outChannels: conv2Shape.outChannels,
      height,
      width,
      kernelH: conv2Shape.kernelH,
      kernelW: conv2Shape.kernelW,
      stride: 1,
      pad: Math.floor(conv2Shape.kernelH / 2),
    }
  );
  release(activated.buffer);

  const normed = await runChannelwiseRmsNorm(
    {
      tensor: reshapeTensor(conv2, [conv2Shape.outChannels, height, width], 'vae_dc_resblock_conv2'),
      channels: conv2Shape.outChannels,
      height,
      width,
    },
    getWeight(weights, shapes, `${prefix}.norm.weight`),
    getWeightOptional(weights, shapes, `${prefix}.norm.bias`),
    eps,
    ops,
    release
  );
  release(conv2.buffer);

  const size = normed.channels * normed.height * normed.width;
  const combined = await ops.residualAdd(
    reshapeTensor(normed.tensor, [size], 'vae_dc_resblock_main'),
    reshapeTensor(state.tensor, [size], 'vae_dc_resblock_residual'),
    size,
    { useVec4: true }
  );
  release(normed.tensor.buffer);
  release(state.tensor.buffer);
  return {
    tensor: reshapeTensor(combined, [normed.channels, normed.height, normed.width], 'vae_dc_resblock_out'),
    channels: normed.channels,
    height: normed.height,
    width: normed.width,
  };
}
|
|
835
|
+
|
|
836
|
+
async function runAutoencoderDCAttention(state, weights, shapes, prefix, attentionHeadDim, qkvMultiscales, eps, ops, release, recorder) {
|
|
837
|
+
const qWeight = getWeight(weights, shapes, `${prefix}.attn.to_q.weight`);
|
|
838
|
+
const kWeight = getWeight(weights, shapes, `${prefix}.attn.to_k.weight`);
|
|
839
|
+
const vWeight = getWeight(weights, shapes, `${prefix}.attn.to_v.weight`);
|
|
840
|
+
const qShape = getLinearShape(qWeight.shape, `${prefix}.attn.to_q.weight`);
|
|
841
|
+
const innerDim = qShape.outFeatures;
|
|
842
|
+
if (qShape.inFeatures !== state.channels || innerDim !== getLinearShape(kWeight.shape, `${prefix}.attn.to_k.weight`).outFeatures || innerDim !== getLinearShape(vWeight.shape, `${prefix}.attn.to_v.weight`).outFeatures) {
|
|
843
|
+
throw new Error(`AutoencoderDC attention "${prefix}" has incompatible q/k/v projection shapes.`);
|
|
844
|
+
}
|
|
845
|
+
if (!Number.isFinite(attentionHeadDim) || attentionHeadDim <= 0 || innerDim % attentionHeadDim !== 0) {
|
|
846
|
+
throw new Error(`AutoencoderDC attention "${prefix}" requires innerDim divisible by attentionHeadDim.`);
|
|
847
|
+
}
|
|
848
|
+
const numHeads = innerDim / attentionHeadDim;
|
|
849
|
+
const baseOptions = {
|
|
850
|
+
inChannels: state.channels,
|
|
851
|
+
outChannels: innerDim,
|
|
852
|
+
height: state.height,
|
|
853
|
+
width: state.width,
|
|
854
|
+
groups: 1,
|
|
855
|
+
};
|
|
856
|
+
const qBase = await ops.groupedPointwiseConv2d(state.tensor, qWeight.value, null, baseOptions);
|
|
857
|
+
const kBase = await ops.groupedPointwiseConv2d(state.tensor, kWeight.value, null, baseOptions);
|
|
858
|
+
const vBase = await ops.groupedPointwiseConv2d(state.tensor, vWeight.value, null, baseOptions);
|
|
859
|
+
const qVariants = [qBase];
|
|
860
|
+
const kVariants = [kBase];
|
|
861
|
+
const vVariants = [vBase];
|
|
862
|
+
|
|
863
|
+
if (Array.isArray(qkvMultiscales)) {
|
|
864
|
+
const qkvBase = await concatChannelTensors([qBase, kBase, vBase], state.height, state.width, recorder);
|
|
865
|
+
for (let scaleIdx = 0; scaleIdx < qkvMultiscales.length; scaleIdx++) {
|
|
866
|
+
const depthWeight = getWeight(weights, shapes, `${prefix}.attn.to_qkv_multiscale.${scaleIdx}.proj_in.weight`);
|
|
867
|
+
const pointWeight = getWeight(weights, shapes, `${prefix}.attn.to_qkv_multiscale.${scaleIdx}.proj_out.weight`);
|
|
868
|
+
const depthShape = getConvShape(depthWeight.shape);
|
|
869
|
+
const pointShape = getConvShape(pointWeight.shape);
|
|
870
|
+
const groups = pointShape.outChannels / pointShape.inChannels;
|
|
871
|
+
const depth = await ops.depthwiseConv2d(qkvBase, depthWeight.value, null, {
|
|
872
|
+
channels: qkvBase.shape[0],
|
|
873
|
+
height: state.height,
|
|
874
|
+
width: state.width,
|
|
875
|
+
kernelH: depthShape.kernelH,
|
|
876
|
+
kernelW: depthShape.kernelW,
|
|
877
|
+
stride: 1,
|
|
878
|
+
pad: Math.floor(depthShape.kernelH / 2),
|
|
879
|
+
});
|
|
880
|
+
const projected = await ops.groupedPointwiseConv2d(depth, pointWeight.value, null, {
|
|
881
|
+
inChannels: qkvBase.shape[0],
|
|
882
|
+
outChannels: pointShape.outChannels,
|
|
883
|
+
height: state.height,
|
|
884
|
+
width: state.width,
|
|
885
|
+
groups,
|
|
886
|
+
});
|
|
887
|
+
release(depth.buffer);
|
|
888
|
+
|
|
889
|
+
const qScale = await sliceChannelTensor(projected, 0, innerDim, state.height, state.width, recorder);
|
|
890
|
+
const kScale = await sliceChannelTensor(projected, innerDim, innerDim, state.height, state.width, recorder);
|
|
891
|
+
const vScale = await sliceChannelTensor(projected, innerDim * 2, innerDim, state.height, state.width, recorder);
|
|
892
|
+
release(projected.buffer);
|
|
893
|
+
qVariants.push(qScale);
|
|
894
|
+
kVariants.push(kScale);
|
|
895
|
+
vVariants.push(vScale);
|
|
896
|
+
}
|
|
897
|
+
release(qkvBase.buffer);
|
|
898
|
+
}
|
|
899
|
+
|
|
900
|
+
const qAll = await concatChannelTensors(qVariants, state.height, state.width, recorder);
|
|
901
|
+
const kAll = await concatChannelTensors(kVariants, state.height, state.width, recorder);
|
|
902
|
+
const vAll = await concatChannelTensors(vVariants, state.height, state.width, recorder);
|
|
903
|
+
for (const tensor of qVariants) release(tensor.buffer);
|
|
904
|
+
for (const tensor of kVariants) release(tensor.buffer);
|
|
905
|
+
for (const tensor of vVariants) release(tensor.buffer);
|
|
906
|
+
|
|
907
|
+
const qTokens = await channelsToTokens({ tensor: qAll, channels: qAll.shape[0], height: state.height, width: state.width }, ops);
|
|
908
|
+
const kTokens = await channelsToTokens({ tensor: kAll, channels: kAll.shape[0], height: state.height, width: state.width }, ops);
|
|
909
|
+
const vTokens = await channelsToTokens({ tensor: vAll, channels: vAll.shape[0], height: state.height, width: state.width }, ops);
|
|
910
|
+
release(qAll.buffer);
|
|
911
|
+
release(kAll.buffer);
|
|
912
|
+
release(vAll.buffer);
|
|
913
|
+
|
|
914
|
+
const qRelu = await ops.relu(qTokens.tensor, { count: tensorElementCount(qTokens.tensor) });
|
|
915
|
+
const kRelu = await ops.relu(kTokens.tensor, { count: tensorElementCount(kTokens.tensor) });
|
|
916
|
+
release(qTokens.tensor.buffer);
|
|
917
|
+
release(kTokens.tensor.buffer);
|
|
918
|
+
|
|
919
|
+
const allHeads = numHeads * qVariants.length;
|
|
920
|
+
const attention = await ops.sanaLinearAttention(qRelu, kRelu, vTokens.tensor, {
|
|
921
|
+
numHeads: allHeads,
|
|
922
|
+
headDim: attentionHeadDim,
|
|
923
|
+
numTokens: qTokens.numTokens,
|
|
924
|
+
hiddenSize: allHeads * attentionHeadDim,
|
|
925
|
+
eps,
|
|
926
|
+
});
|
|
927
|
+
release(qRelu.buffer);
|
|
928
|
+
release(kRelu.buffer);
|
|
929
|
+
release(vTokens.tensor.buffer);
|
|
930
|
+
|
|
931
|
+
const attended = await tokensToChannels(attention, allHeads * attentionHeadDim, state.height, state.width, ops);
|
|
932
|
+
release(attention.buffer);
|
|
933
|
+
const outWeight = getWeight(weights, shapes, `${prefix}.attn.to_out.weight`);
|
|
934
|
+
const outShape = getLinearShape(outWeight.shape, `${prefix}.attn.to_out.weight`);
|
|
935
|
+
const projected = await ops.groupedPointwiseConv2d(attended.tensor, outWeight.value, null, {
|
|
936
|
+
inChannels: attended.channels,
|
|
937
|
+
outChannels: outShape.outFeatures,
|
|
938
|
+
height: state.height,
|
|
939
|
+
width: state.width,
|
|
940
|
+
groups: 1,
|
|
941
|
+
});
|
|
942
|
+
release(attended.tensor.buffer);
|
|
943
|
+
const normed = await runChannelwiseRmsNorm(
|
|
944
|
+
{
|
|
945
|
+
tensor: reshapeTensor(projected, [outShape.outFeatures, state.height, state.width], 'vae_dc_attn_projected'),
|
|
946
|
+
channels: outShape.outFeatures,
|
|
947
|
+
height: state.height,
|
|
948
|
+
width: state.width,
|
|
949
|
+
},
|
|
950
|
+
getWeight(weights, shapes, `${prefix}.attn.norm_out.weight`),
|
|
951
|
+
getWeightOptional(weights, shapes, `${prefix}.attn.norm_out.bias`),
|
|
952
|
+
1e-5,
|
|
953
|
+
ops,
|
|
954
|
+
release
|
|
955
|
+
);
|
|
956
|
+
release(projected.buffer);
|
|
957
|
+
|
|
958
|
+
const size = normed.channels * normed.height * normed.width;
|
|
959
|
+
const combined = await ops.residualAdd(
|
|
960
|
+
reshapeTensor(normed.tensor, [size], 'vae_dc_attn_main'),
|
|
961
|
+
reshapeTensor(state.tensor, [size], 'vae_dc_attn_residual'),
|
|
962
|
+
size,
|
|
963
|
+
{ useVec4: true }
|
|
964
|
+
);
|
|
965
|
+
release(normed.tensor.buffer);
|
|
966
|
+
release(state.tensor.buffer);
|
|
967
|
+
return {
|
|
968
|
+
tensor: reshapeTensor(combined, [normed.channels, normed.height, normed.width], 'vae_dc_attn_out'),
|
|
969
|
+
channels: normed.channels,
|
|
970
|
+
height: normed.height,
|
|
971
|
+
width: normed.width,
|
|
972
|
+
};
|
|
973
|
+
}
|
|
974
|
+
|
|
975
|
+
/**
 * Run an AutoencoderDC GLUMB conv (gated inverted-bottleneck conv) with a residual connection.
 *
 * Pipeline over a channels-first feature map:
 *   1x1 expand (`conv_inverted`) -> SiLU -> depthwise conv (`conv_depth`)
 *   -> ops.siluRowSplit, which gates the expanded channels down to half their count
 *   -> 1x1 project (`conv_point`) -> channel-wise RMSNorm (`conv_out.norm`) -> + block input.
 *
 * @param {{tensor: object, channels: number, height: number, width: number}} state - Block
 *   input. Its tensor buffer is released after the residual add.
 * @param {Map} weights - Weight store; tensors are read under `${prefix}.conv_out.*`.
 * @param {*} shapes - Shape lookup forwarded to getWeight/getWeightOptional.
 * @param {string} prefix - Layer name prefix (e.g. `vae.decoder.up_blocks.3.1`).
 * @param {number} eps - Epsilon for the output RMSNorm.
 * @param {object} ops - GPU op table.
 * @param {(buffer: object) => void} release - Invoked on each intermediate buffer once consumed.
 * @returns {Promise<{tensor: object, channels: number, height: number, width: number}>}
 *   Output feature map with `conv_point`'s output channel count and the input's spatial size.
 * @throws {Error} If `conv_inverted` yields an odd channel count, which cannot be split into
 *   equal value/gate halves.
 */
async function runAutoencoderDCGlumbConv(state, weights, shapes, prefix, eps, ops, release) {
  const { height, width } = state;

  const invertedWeight = getWeight(weights, shapes, `${prefix}.conv_out.conv_inverted.weight`);
  const invertedBias = getWeightOptional(weights, shapes, `${prefix}.conv_out.conv_inverted.bias`);
  const invertedShape = getLinearShape(invertedWeight.shape, `${prefix}.conv_out.conv_inverted.weight`);
  const expandedChannels = invertedShape.outFeatures;
  // The gate splits the expanded channels into two equal halves. Reject an odd count up front,
  // before any GPU work, instead of silently truncating it.
  if (expandedChannels % 2 !== 0) {
    throw new Error(
      `AutoencoderDC GLUMB conv "${prefix}" expects an even conv_inverted output channel count, got ${expandedChannels}.`
    );
  }
  const hiddenChannels = expandedChannels / 2;

  const inverted = await ops.groupedPointwiseConv2d(state.tensor, invertedWeight.value, invertedBias?.value ?? null, {
    inChannels: state.channels,
    outChannels: expandedChannels,
    height,
    width,
    groups: 1,
  });
  const activated = await ops.silu(inverted, {
    size: expandedChannels * height * width,
    swigluLimit: null,
  });
  release(inverted.buffer);

  const depthWeight = getWeight(weights, shapes, `${prefix}.conv_out.conv_depth.weight`);
  const depthBias = getWeightOptional(weights, shapes, `${prefix}.conv_out.conv_depth.bias`);
  const depthShape = getConvShape(depthWeight.shape);
  const depth = await ops.depthwiseConv2d(
    reshapeTensor(activated, [expandedChannels, height, width], 'vae_dc_glumb_act'),
    depthWeight.value,
    depthBias?.value ?? null,
    {
      channels: expandedChannels,
      height,
      width,
      kernelH: depthShape.kernelH,
      kernelW: depthShape.kernelW,
      stride: 1,
      // "Same" padding derived from the kernel size (1 for the usual 3x3). This matches the
      // depthwise conv in runAutoencoderDCAttention and keeps height/width unchanged for any
      // odd kernel, which the rest of this function relies on.
      pad: Math.floor(depthShape.kernelH / 2),
    }
  );
  release(activated.buffer);

  const depthTokens = await channelsToTokens({ tensor: depth, channels: expandedChannels, height, width }, ops);
  release(depth.buffer);
  const gated = await ops.siluRowSplit(depthTokens.tensor, {
    numTokens: depthTokens.numTokens,
    dim: hiddenChannels,
    activation: 'silu',
    swigluLimit: null,
  });
  release(depthTokens.tensor.buffer);
  const gatedChannels = await tokensToChannels(gated, hiddenChannels, height, width, ops);
  release(gated.buffer);

  const pointWeight = getWeight(weights, shapes, `${prefix}.conv_out.conv_point.weight`);
  const pointShape = getLinearShape(pointWeight.shape, `${prefix}.conv_out.conv_point.weight`);
  const projected = await ops.groupedPointwiseConv2d(gatedChannels.tensor, pointWeight.value, null, {
    inChannels: hiddenChannels,
    outChannels: pointShape.outFeatures,
    height,
    width,
    groups: 1,
  });
  release(gatedChannels.tensor.buffer);

  const normed = await runChannelwiseRmsNorm(
    {
      tensor: reshapeTensor(projected, [pointShape.outFeatures, height, width], 'vae_dc_glumb_projected'),
      channels: pointShape.outFeatures,
      height,
      width,
    },
    getWeight(weights, shapes, `${prefix}.conv_out.norm.weight`),
    getWeightOptional(weights, shapes, `${prefix}.conv_out.norm.bias`),
    eps,
    ops,
    release
  );
  release(projected.buffer);

  // Residual connection: flatten both operands to 1-D for the elementwise add.
  const size = normed.channels * normed.height * normed.width;
  const combined = await ops.residualAdd(
    reshapeTensor(normed.tensor, [size], 'vae_dc_glumb_main'),
    reshapeTensor(state.tensor, [size], 'vae_dc_glumb_residual'),
    size,
    { useVec4: true }
  );
  release(normed.tensor.buffer);
  release(state.tensor.buffer);
  return {
    tensor: reshapeTensor(combined, [normed.channels, normed.height, normed.width], 'vae_dc_glumb_out'),
    channels: normed.channels,
    height: normed.height,
    width: normed.width,
  };
}
|
|
1063
|
+
|
|
1064
|
+
/**
 * Run one AutoencoderDC EfficientViT block: multiscale linear attention, then a GLUMB conv.
 *
 * Both stages perform their own residual add, so this function only composes them.
 *
 * @param {{tensor: object, channels: number, height: number, width: number}} state - Block input.
 * @param {Map} weights - Weight store.
 * @param {*} shapes - Shape lookup forwarded to the stages.
 * @param {string} prefix - Layer name prefix shared by both stages.
 * @param {number} attentionHeadDim - Per-head dimension for the attention stage.
 * @param {*} qkvMultiscales - Multiscale kernel sizes for the attention stage's QKV projections.
 * @param {number} eps - Epsilon for the GLUMB conv's output RMSNorm.
 * @param {object} ops - GPU op table.
 * @param {(buffer: object) => void} release - Invoked on intermediate buffers once consumed.
 * @param {*} recorder - Forwarded to the attention stage only.
 * @returns {Promise<object>} The GLUMB conv's output state.
 */
async function runAutoencoderDCEfficientVitBlock(state, weights, shapes, prefix, attentionHeadDim, qkvMultiscales, eps, ops, release, recorder) {
  // The attention stage always gets this fixed epsilon. The caller's `eps` is used only by the
  // GLUMB conv's RMSNorm.
  const attentionEps = 1e-15;
  const afterAttention = await runAutoencoderDCAttention(
    state, weights, shapes, prefix, attentionHeadDim, qkvMultiscales, attentionEps, ops, release, recorder
  );
  return runAutoencoderDCGlumbConv(afterAttention, weights, shapes, prefix, eps, ops, release);
}
|
|
1079
|
+
|
|
1080
|
+
/**
 * Check that every AutoencoderDC decoder block uses a configuration this decoder implements.
 *
 * Blocks are checked in the order the decoder executes them (highest index first, norm type
 * before activation before block type). The first error reported is therefore the one the
 * decode loop would hit first.
 *
 * @param {Array} blockTypes - Per-block layer type ('ResBlock' or 'EfficientViTBlock').
 * @param {Array} layersPerBlock - Per-block layer counts.
 * @param {Array} normTypes - Per-block norm types; only 'rms_norm' is supported.
 * @param {Array} actFns - Per-block activations; only 'silu' is supported.
 * @throws {Error} On an unsupported norm type, an unsupported activation, or an unsupported block
 *   type in a block that has at least one layer.
 */
function assertAutoencoderDCDecoderBlocksSupported(blockTypes, layersPerBlock, normTypes, actFns) {
  for (let blockIdx = blockTypes.length - 1; blockIdx >= 0; blockIdx--) {
    if (normTypes[blockIdx] !== 'rms_norm') {
      throw new Error(
        `Unsupported AutoencoderDC norm type "${normTypes[blockIdx]}" in block ${blockIdx}.`
      );
    }
    if (actFns[blockIdx] !== 'silu') {
      throw new Error(
        `Unsupported AutoencoderDC activation "${actFns[blockIdx]}" in block ${blockIdx}.`
      );
    }
    const blockType = blockTypes[blockIdx];
    // A block with no layers never dispatches on its type, so its type is not checked.
    const hasLayers = layersPerBlock[blockIdx] > 0;
    if (hasLayers && blockType !== 'ResBlock' && blockType !== 'EfficientViTBlock') {
      throw new Error(`Unsupported AutoencoderDC block type "${blockType}" in block ${blockIdx}.`);
    }
  }
}

/**
 * Decode AutoencoderDC latents on the GPU.
 *
 * Steps:
 *   1. Input projection.
 *   2. The decoder up-blocks, from the highest index down to 0. Each block may start with an
 *      upsample layer at sub-index 0, followed by `decoder_layers_per_block[i]` ResBlock or
 *      EfficientViTBlock layers.
 *   3. The output head: RMSNorm (`norm_out`) -> ReLU -> `conv_out`.
 *
 * @param {{tensor: object, channels: number, height: number, width: number}} state - Input latent
 *   feature map.
 * @param {object} config - AutoencoderDC config (`decoder_*` fields and `attention_head_dim`).
 * @param {Map} weights - Weight store keyed by tensor name (`vae.decoder.*`).
 * @param {*} shapes - Shape lookup forwarded to the weight helpers.
 * @param {object} ops - GPU op table.
 * @param {(buffer: object) => void} release - Invoked on intermediate buffers once consumed.
 * @param {*} recorder - Forwarded to the upsample and EfficientViT layers.
 * @returns {Promise<object>} The state produced by `vae.decoder.conv_out`.
 * @throws {Error} If `decoder_block_out_channels` is missing, or if any block uses an unsupported
 *   norm type, activation, or block type. All of these are detected before any GPU work is issued.
 */
async function decodeLatentsAutoencoderDC(state, config, weights, shapes, ops, release, recorder) {
  if (!Array.isArray(config.decoder_block_out_channels)) {
    throw new Error('AutoencoderDC decode requires decoder_block_out_channels in config.');
  }
  const numBlocks = config.decoder_block_out_channels.length;
  const blockTypes = normalizePerBlockValue(config.decoder_block_types, numBlocks, 'decoder_block_types');
  const layersPerBlock = normalizePerBlockValue(config.decoder_layers_per_block, numBlocks, 'decoder_layers_per_block');
  const qkvMultiscales = normalizePerBlockValue(config.decoder_qkv_multiscales, numBlocks, 'decoder_qkv_multiscales');
  const normTypes = normalizePerBlockValue(config.decoder_norm_types, numBlocks, 'decoder_norm_types');
  const actFns = normalizePerBlockValue(config.decoder_act_fns, numBlocks, 'decoder_act_fns');
  // RMSNorm epsilon passed to the res blocks, the EfficientViT blocks and norm_out.
  const rmsNormEps = 1e-5;

  // Validate every block up front so an unsupported config is rejected before the input
  // projection or any earlier block has run on the GPU.
  assertAutoencoderDCDecoderBlocksSupported(blockTypes, layersPerBlock, normTypes, actFns);

  state = await runAutoencoderDCInputProjection(state, weights, shapes, config, ops, release);

  for (let blockIdx = blockTypes.length - 1; blockIdx >= 0; blockIdx--) {
    const prefix = `vae.decoder.up_blocks.${blockIdx}`;
    // When present, the upsample layer occupies sub-index 0 and the regular layers follow it.
    const hasUpsample = weights.has(`${prefix}.0.conv.weight`);
    if (hasUpsample) {
      state = await runAutoencoderDCUpBlock(state, weights, shapes, `${prefix}.0`, config, ops, release, recorder);
    }

    const startIndex = hasUpsample ? 1 : 0;
    const blockType = blockTypes[blockIdx];
    const numLayers = layersPerBlock[blockIdx];
    for (let layerOffset = 0; layerOffset < numLayers; layerOffset++) {
      const layerPrefix = `${prefix}.${startIndex + layerOffset}`;
      // blockType was validated above, so anything other than ResBlock is EfficientViTBlock.
      state = blockType === 'ResBlock'
        ? await runAutoencoderDCResBlock(state, weights, shapes, layerPrefix, rmsNormEps, ops, release)
        : await runAutoencoderDCEfficientVitBlock(
          state,
          weights,
          shapes,
          layerPrefix,
          config.attention_head_dim,
          qkvMultiscales[blockIdx],
          rmsNormEps,
          ops,
          release,
          recorder
        );
    }
  }

  const normed = await runChannelwiseRmsNorm(
    state,
    getWeight(weights, shapes, 'vae.decoder.norm_out.weight'),
    getWeightOptional(weights, shapes, 'vae.decoder.norm_out.bias'),
    rmsNormEps,
    ops,
    release
  );
  release(state.tensor.buffer);
  const activated = await ops.relu(normed.tensor, {
    count: normed.channels * normed.height * normed.width,
  });
  release(normed.tensor.buffer);
  return applyConv2D(
    {
      tensor: reshapeTensor(activated, [normed.channels, normed.height, normed.width], 'vae_dc_norm_out'),
      channels: normed.channels,
      height: normed.height,
      width: normed.width,
    },
    weights,
    shapes,
    'vae.decoder.conv_out',
    { pad: 1 },
    ops,
    release
  );
}
|
|
1164
|
+
|
|
467
1165
|
async function decodeLatentsGPU(latents, options) {
|
|
468
1166
|
const device = getDevice();
|
|
469
1167
|
if (!device) {
|
|
@@ -495,14 +1193,7 @@ async function decodeLatentsGPU(latents, options) {
|
|
|
495
1193
|
throw new Error('VAE decode requires a valid scaling_factor in config.');
|
|
496
1194
|
}
|
|
497
1195
|
const shiftFactor = Number.isFinite(config.shift_factor) ? config.shift_factor : 0.0;
|
|
498
|
-
const
|
|
499
|
-
if (!Number.isFinite(numGroups) || numGroups <= 0) {
|
|
500
|
-
throw new Error('VAE decode requires norm_num_groups in config.');
|
|
501
|
-
}
|
|
502
|
-
const eps = runtime.decode?.groupNormEps;
|
|
503
|
-
if (!Number.isFinite(eps)) {
|
|
504
|
-
throw new Error('VAE decode requires runtime.decode.groupNormEps.');
|
|
505
|
-
}
|
|
1196
|
+
const isAutoencoderDC = config._class_name === 'AutoencoderDC' || Array.isArray(config.decoder_block_types);
|
|
506
1197
|
|
|
507
1198
|
const scaledLatents = new Float32Array(latents.length);
|
|
508
1199
|
for (let i = 0; i < latents.length; i++) {
|
|
@@ -537,82 +1228,95 @@ async function decodeLatentsGPU(latents, options) {
|
|
|
537
1228
|
width: state.width,
|
|
538
1229
|
};
|
|
539
1230
|
|
|
540
|
-
|
|
1231
|
+
if (isAutoencoderDC) {
|
|
1232
|
+
state = await decodeLatentsAutoencoderDC(state, config, weights, shapes, ops, release, recorder);
|
|
1233
|
+
} else {
|
|
1234
|
+
const numGroups = config.norm_num_groups;
|
|
1235
|
+
if (!Number.isFinite(numGroups) || numGroups <= 0) {
|
|
1236
|
+
throw new Error('VAE decode requires norm_num_groups in config.');
|
|
1237
|
+
}
|
|
1238
|
+
const eps = runtime.decode?.groupNormEps;
|
|
1239
|
+
if (!Number.isFinite(eps)) {
|
|
1240
|
+
throw new Error('VAE decode requires runtime.decode.groupNormEps.');
|
|
1241
|
+
}
|
|
541
1242
|
|
|
542
|
-
|
|
543
|
-
const midResnetIds = buildIndexList(weights, midResnetPrefix);
|
|
544
|
-
for (const idx of midResnetIds) {
|
|
545
|
-
state = await runResnetBlock(state, weights, shapes, `${midResnetPrefix}${idx}`, { numGroups, eps }, ops, release);
|
|
546
|
-
}
|
|
1243
|
+
state = await applyConv2D(state, weights, shapes, 'vae.decoder.conv_in', { pad: 1 }, ops, release);
|
|
547
1244
|
|
|
548
|
-
|
|
549
|
-
|
|
550
|
-
|
|
551
|
-
|
|
552
|
-
|
|
553
|
-
weights,
|
|
554
|
-
shapes,
|
|
555
|
-
`${midAttentionPrefix}${idx}`,
|
|
556
|
-
{
|
|
557
|
-
numGroups,
|
|
558
|
-
eps,
|
|
559
|
-
modelConfig: config,
|
|
560
|
-
},
|
|
561
|
-
ops,
|
|
562
|
-
release
|
|
563
|
-
);
|
|
564
|
-
}
|
|
1245
|
+
const midResnetPrefix = 'vae.decoder.mid_block.resnets.';
|
|
1246
|
+
const midResnetIds = buildIndexList(weights, midResnetPrefix);
|
|
1247
|
+
for (const idx of midResnetIds) {
|
|
1248
|
+
state = await runResnetBlock(state, weights, shapes, `${midResnetPrefix}${idx}`, { numGroups, eps }, ops, release);
|
|
1249
|
+
}
|
|
565
1250
|
|
|
566
|
-
|
|
567
|
-
|
|
568
|
-
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
|
|
1251
|
+
const midAttentionPrefix = 'vae.decoder.mid_block.attentions.';
|
|
1252
|
+
const midAttentionIds = buildIndexList(weights, midAttentionPrefix);
|
|
1253
|
+
for (const idx of midAttentionIds) {
|
|
1254
|
+
state = await runMidBlockAttention(
|
|
1255
|
+
state,
|
|
1256
|
+
weights,
|
|
1257
|
+
shapes,
|
|
1258
|
+
`${midAttentionPrefix}${idx}`,
|
|
1259
|
+
{
|
|
1260
|
+
numGroups,
|
|
1261
|
+
eps,
|
|
1262
|
+
modelConfig: config,
|
|
1263
|
+
},
|
|
1264
|
+
ops,
|
|
1265
|
+
release
|
|
1266
|
+
);
|
|
573
1267
|
}
|
|
574
1268
|
|
|
575
|
-
const
|
|
576
|
-
|
|
577
|
-
|
|
578
|
-
|
|
579
|
-
|
|
580
|
-
|
|
581
|
-
|
|
582
|
-
}
|
|
583
|
-
|
|
584
|
-
|
|
585
|
-
|
|
586
|
-
|
|
587
|
-
|
|
588
|
-
|
|
589
|
-
|
|
590
|
-
|
|
591
|
-
|
|
1269
|
+
const upBlockPrefix = 'vae.decoder.up_blocks.';
|
|
1270
|
+
const upBlocks = buildIndexList(weights, upBlockPrefix);
|
|
1271
|
+
for (const blockIdx of upBlocks) {
|
|
1272
|
+
const resnetPrefix = `${upBlockPrefix}${blockIdx}.resnets.`;
|
|
1273
|
+
const resnetIds = buildIndexList(weights, resnetPrefix);
|
|
1274
|
+
for (const idx of resnetIds) {
|
|
1275
|
+
state = await runResnetBlock(state, weights, shapes, `${resnetPrefix}${idx}`, { numGroups, eps }, ops, release);
|
|
1276
|
+
}
|
|
1277
|
+
|
|
1278
|
+
const upsampleWeightName = `${upBlockPrefix}${blockIdx}.upsamplers.0.conv.weight`;
|
|
1279
|
+
if (weights.has(upsampleWeightName)) {
|
|
1280
|
+
const upsample = await ops.upsample2d(state.tensor, {
|
|
1281
|
+
channels: state.channels,
|
|
1282
|
+
height: state.height,
|
|
1283
|
+
width: state.width,
|
|
1284
|
+
scale: 2,
|
|
1285
|
+
});
|
|
1286
|
+
release(state.tensor.buffer);
|
|
1287
|
+
state = {
|
|
1288
|
+
tensor: reshapeTensor(upsample, [state.channels, state.height * 2, state.width * 2], 'vae_upsample'),
|
|
1289
|
+
channels: state.channels,
|
|
1290
|
+
height: state.height * 2,
|
|
1291
|
+
width: state.width * 2,
|
|
1292
|
+
};
|
|
1293
|
+
|
|
1294
|
+
state = await applyConv2D(state, weights, shapes, `${upBlockPrefix}${blockIdx}.upsamplers.0.conv`, { pad: 1 }, ops, release);
|
|
1295
|
+
}
|
|
592
1296
|
}
|
|
593
|
-
}
|
|
594
1297
|
|
|
595
|
-
|
|
596
|
-
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
|
|
603
|
-
|
|
604
|
-
|
|
1298
|
+
const normOut = getWeight(weights, shapes, 'vae.decoder.conv_norm_out.weight');
|
|
1299
|
+
const normOutBias = getWeight(weights, shapes, 'vae.decoder.conv_norm_out.bias');
|
|
1300
|
+
const normed = await ops.groupNorm(state.tensor, normOut.value, normOutBias.value, {
|
|
1301
|
+
channels: state.channels,
|
|
1302
|
+
height: state.height,
|
|
1303
|
+
width: state.width,
|
|
1304
|
+
numGroups,
|
|
1305
|
+
eps,
|
|
1306
|
+
});
|
|
1307
|
+
release(state.tensor.buffer);
|
|
605
1308
|
|
|
606
|
-
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
|
|
613
|
-
|
|
1309
|
+
const siluOut = await ops.silu(normed, { size: state.channels * state.height * state.width, swigluLimit: null });
|
|
1310
|
+
release(normed.buffer);
|
|
1311
|
+
state = {
|
|
1312
|
+
tensor: reshapeTensor(siluOut, [state.channels, state.height, state.width], 'vae_norm_out'),
|
|
1313
|
+
channels: state.channels,
|
|
1314
|
+
height: state.height,
|
|
1315
|
+
width: state.width,
|
|
1316
|
+
};
|
|
614
1317
|
|
|
615
|
-
|
|
1318
|
+
state = await applyConv2D(state, weights, shapes, 'vae.decoder.conv_out', { pad: 1 }, ops, release);
|
|
1319
|
+
}
|
|
616
1320
|
|
|
617
1321
|
const outputSize = state.channels * state.height * state.width * dtypeBytes(state.tensor.dtype);
|
|
618
1322
|
if (localRecorder) {
|