@simulatte/doppler 0.1.3 → 0.1.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +11 -5
- package/package.json +27 -4
- package/src/client/doppler-api.browser.d.ts +1 -0
- package/src/client/doppler-api.browser.js +288 -0
- package/src/client/doppler-api.d.ts +80 -0
- package/src/client/doppler-api.js +298 -0
- package/src/client/doppler-provider/types.js +1 -1
- package/src/client/doppler-registry.d.ts +23 -0
- package/src/client/doppler-registry.js +88 -0
- package/src/client/doppler-registry.json +39 -0
- package/src/config/execution-contract-check.d.ts +82 -0
- package/src/config/execution-contract-check.js +317 -0
- package/src/config/execution-v0-contract-check.d.ts +94 -0
- package/src/config/execution-v0-contract-check.js +251 -0
- package/src/config/execution-v0-graph-contract-check.d.ts +20 -0
- package/src/config/execution-v0-graph-contract-check.js +64 -0
- package/src/config/kernel-path-contract-check.d.ts +76 -0
- package/src/config/kernel-path-contract-check.js +479 -0
- package/src/config/kernel-path-loader.d.ts +16 -0
- package/src/config/kernel-path-loader.js +54 -0
- package/src/config/kernels/kernel-ref-digests.js +12 -0
- package/src/config/kernels/registry.json +556 -0
- package/src/config/loader.js +90 -67
- package/src/config/merge-contract-check.d.ts +16 -0
- package/src/config/merge-contract-check.js +321 -0
- package/src/config/merge-helpers.d.ts +58 -0
- package/src/config/merge-helpers.js +54 -0
- package/src/config/merge.js +3 -6
- package/src/config/presets/models/janus-text.json +27 -0
- package/src/config/quantization-contract-check.d.ts +12 -0
- package/src/config/quantization-contract-check.js +91 -0
- package/src/config/required-inference-fields-contract-check.d.ts +24 -0
- package/src/config/required-inference-fields-contract-check.js +231 -0
- package/src/config/schema/browser-suite-metrics.schema.d.ts +17 -0
- package/src/config/schema/browser-suite-metrics.schema.js +46 -0
- package/src/config/schema/conversion-report.schema.d.ts +40 -0
- package/src/config/schema/conversion-report.schema.js +108 -0
- package/src/config/schema/doppler.schema.js +12 -18
- package/src/config/schema/index.d.ts +22 -0
- package/src/config/schema/index.js +18 -0
- package/src/converter/core.d.ts +10 -0
- package/src/converter/core.js +49 -11
- package/src/converter/parsers/diffusion.js +63 -3
- package/src/converter/tokenizer-utils.js +17 -3
- package/src/formats/rdrr/validation.js +13 -0
- package/src/gpu/kernels/depthwise_conv2d.d.ts +29 -0
- package/src/gpu/kernels/depthwise_conv2d.js +98 -0
- package/src/gpu/kernels/depthwise_conv2d.wgsl +58 -0
- package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +62 -0
- package/src/gpu/kernels/grouped_pointwise_conv2d.d.ts +27 -0
- package/src/gpu/kernels/grouped_pointwise_conv2d.js +92 -0
- package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +47 -0
- package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +51 -0
- package/src/gpu/kernels/index.d.ts +30 -0
- package/src/gpu/kernels/index.js +25 -0
- package/src/gpu/kernels/relu.d.ts +18 -0
- package/src/gpu/kernels/relu.js +45 -0
- package/src/gpu/kernels/relu.wgsl +21 -0
- package/src/gpu/kernels/relu_f16.wgsl +23 -0
- package/src/gpu/kernels/repeat_channels.d.ts +21 -0
- package/src/gpu/kernels/repeat_channels.js +60 -0
- package/src/gpu/kernels/repeat_channels.wgsl +29 -0
- package/src/gpu/kernels/repeat_channels_f16.wgsl +31 -0
- package/src/gpu/kernels/sana_linear_attention.d.ts +27 -0
- package/src/gpu/kernels/sana_linear_attention.js +122 -0
- package/src/gpu/kernels/sana_linear_attention_apply.wgsl +44 -0
- package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +47 -0
- package/src/gpu/kernels/sana_linear_attention_summary.wgsl +47 -0
- package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +49 -0
- package/src/index-browser.d.ts +1 -0
- package/src/index-browser.js +2 -1
- package/src/index.d.ts +1 -0
- package/src/index.js +2 -1
- package/src/inference/browser-harness.js +164 -38
- package/src/inference/pipelines/diffusion/init.js +14 -0
- package/src/inference/pipelines/diffusion/pipeline.js +206 -77
- package/src/inference/pipelines/diffusion/sana-transformer.d.ts +53 -0
- package/src/inference/pipelines/diffusion/sana-transformer.js +738 -0
- package/src/inference/pipelines/diffusion/scheduler.d.ts +17 -1
- package/src/inference/pipelines/diffusion/scheduler.js +91 -3
- package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +6 -4
- package/src/inference/pipelines/diffusion/text-encoder-gpu.js +270 -0
- package/src/inference/pipelines/diffusion/text-encoder.js +18 -1
- package/src/inference/pipelines/diffusion/types.d.ts +4 -0
- package/src/inference/pipelines/diffusion/vae.js +782 -78
- package/src/inference/pipelines/text/config.d.ts +5 -0
- package/src/inference/pipelines/text/config.js +1 -1
- package/src/inference/pipelines/text/execution-v0.js +141 -101
- package/src/inference/pipelines/text/init.js +41 -10
- package/src/inference/pipelines/text.js +7 -1
- package/src/rules/execution-rules-contract-check.d.ts +17 -0
- package/src/rules/execution-rules-contract-check.js +245 -0
- package/src/rules/kernels/depthwise-conv2d.rules.json +6 -0
- package/src/rules/kernels/grouped-pointwise-conv2d.rules.json +6 -0
- package/src/rules/kernels/relu.rules.json +6 -0
- package/src/rules/kernels/repeat-channels.rules.json +6 -0
- package/src/rules/kernels/sana-linear-attention.rules.json +6 -0
- package/src/rules/layer-pattern-contract-check.d.ts +17 -0
- package/src/rules/layer-pattern-contract-check.js +231 -0
- package/src/rules/rule-registry.d.ts +28 -0
- package/src/rules/rule-registry.js +38 -0
- package/src/tooling/conversion-config-materializer.d.ts +24 -0
- package/src/tooling/conversion-config-materializer.js +99 -0
- package/src/tooling/lean-execution-contract-runner.d.ts +43 -0
- package/src/tooling/lean-execution-contract-runner.js +158 -0
- package/src/tooling/lean-execution-contract.d.ts +16 -0
- package/src/tooling/lean-execution-contract.js +81 -0
- package/src/tooling/node-convert.d.ts +10 -0
- package/src/tooling/node-converter.js +59 -0
- package/src/tooling/node-webgpu.js +30 -9
- package/src/version.d.ts +2 -0
- package/src/version.js +2 -0
- package/tools/convert-safetensors-node.js +47 -0
- package/tools/doppler-cli.js +167 -6
|
@@ -0,0 +1,738 @@
|
|
|
1
|
+
import { getDevice } from '../../../gpu/device.js';
|
|
2
|
+
import { createTensor } from '../../../gpu/tensor.js';
|
|
3
|
+
import { getBuffer } from '../../../gpu/weight-buffer.js';
|
|
4
|
+
import { acquireBuffer } from '../../../memory/buffer-pool.js';
|
|
5
|
+
import {
|
|
6
|
+
runConv2D,
|
|
7
|
+
runDepthwiseConv2D,
|
|
8
|
+
runGroupedPointwiseConv2D,
|
|
9
|
+
runLayerNorm,
|
|
10
|
+
runRMSNorm,
|
|
11
|
+
runMatmul,
|
|
12
|
+
runAttention,
|
|
13
|
+
runSiLU,
|
|
14
|
+
runSiLURowSplit,
|
|
15
|
+
runResidualAdd,
|
|
16
|
+
runBiasAdd,
|
|
17
|
+
runModulate,
|
|
18
|
+
runSanaLinearAttention,
|
|
19
|
+
recordConv2D,
|
|
20
|
+
recordDepthwiseConv2D,
|
|
21
|
+
recordGroupedPointwiseConv2D,
|
|
22
|
+
recordLayerNorm,
|
|
23
|
+
recordRMSNorm,
|
|
24
|
+
recordMatmul,
|
|
25
|
+
recordAttention,
|
|
26
|
+
recordSiLU,
|
|
27
|
+
recordSiLURowSplit,
|
|
28
|
+
recordResidualAdd,
|
|
29
|
+
recordBiasAdd,
|
|
30
|
+
recordModulate,
|
|
31
|
+
recordSanaLinearAttention,
|
|
32
|
+
} from '../../../gpu/kernels/index.js';
|
|
33
|
+
import { log } from '../../../debug/index.js';
|
|
34
|
+
import {
|
|
35
|
+
resolveDiffusionActivationDtype,
|
|
36
|
+
createDiffusionBufferReleaser,
|
|
37
|
+
createDiffusionBufferDestroyer,
|
|
38
|
+
normalizeDiffusionMatmulLocationDtype,
|
|
39
|
+
inferDiffusionMatmulDtypeFromBuffer,
|
|
40
|
+
expectDiffusionWeight,
|
|
41
|
+
} from './helpers.js';
|
|
42
|
+
|
|
43
|
+
// Re-views an existing GPU buffer under a new shape without copying.
// Keeps the tensor's dtype; the label defaults to the source tensor's label.
function reshapeTensor(tensor, shape, label) {
  const viewLabel = label ?? tensor.label;
  return createTensor(tensor.buffer, tensor.dtype, shape, viewLabel);
}
|
|
46
|
+
|
|
47
|
+
// Looks up a weight under the `transformer.` prefix in the weights entry.
// Returns null when the entry, its map, or the key is missing.
function getWeight(weightsEntry, name) {
  const key = `transformer.${name}`;
  const found = weightsEntry?.weights?.get(key);
  return found ?? null;
}
|
|
50
|
+
|
|
51
|
+
// Looks up a weight's shape under the `transformer.` prefix.
// Returns null when the entry, its shapes map, or the key is missing.
function getWeightShape(weightsEntry, name) {
  const key = `transformer.${name}`;
  const shape = weightsEntry?.shapes?.get(key);
  return shape ?? null;
}
|
|
54
|
+
|
|
55
|
+
// Looks up a weight's dtype under the `transformer.` prefix.
// Returns null when the entry, its dtypes map, or the key is missing.
function getWeightDtype(weightsEntry, name) {
  const key = `transformer.${name}`;
  const dtype = weightsEntry?.dtypes?.get(key);
  return dtype ?? null;
}
|
|
58
|
+
|
|
59
|
+
/**
 * Builds the kernel dispatch table used by the Sana transformer passes.
 *
 * Without a recorder the table points straight at the immediate `run*`
 * kernels. With a recorder, every entry is wrapped so the recorder is
 * forwarded as the first argument of the matching `record*` kernel.
 * Matmul is intentionally absent: it goes through runMatmulResolved so
 * the B-operand dtype can be resolved per weight.
 */
function createKernelOps(recorder) {
  if (!recorder) {
    return {
      conv2d: runConv2D,
      depthwiseConv2d: runDepthwiseConv2D,
      groupedPointwiseConv2d: runGroupedPointwiseConv2D,
      layerNorm: runLayerNorm,
      rmsNorm: runRMSNorm,
      attention: runAttention,
      silu: runSiLU,
      siluRowSplit: runSiLURowSplit,
      residualAdd: runResidualAdd,
      biasAdd: runBiasAdd,
      modulate: runModulate,
      sanaLinearAttention: runSanaLinearAttention,
    };
  }
  // Curry the recorder into each record-mode kernel.
  const withRecorder = (kernel) => (...args) => kernel(recorder, ...args);
  return {
    conv2d: withRecorder(recordConv2D),
    depthwiseConv2d: withRecorder(recordDepthwiseConv2D),
    groupedPointwiseConv2d: withRecorder(recordGroupedPointwiseConv2D),
    layerNorm: withRecorder(recordLayerNorm),
    rmsNorm: withRecorder(recordRMSNorm),
    attention: withRecorder(recordAttention),
    silu: withRecorder(recordSiLU),
    siluRowSplit: withRecorder(recordSiLURowSplit),
    residualAdd: withRecorder(recordResidualAdd),
    biasAdd: withRecorder(recordBiasAdd),
    modulate: withRecorder(recordModulate),
    sanaLinearAttention: withRecorder(recordSanaLinearAttention),
  };
}
|
|
91
|
+
|
|
92
|
+
// Allocates a pooled GPU buffer sized to `values` and uploads the data
// through the device queue. `values` must be a typed array (byteLength is read).
function createVectorBuffer(device, values, label) {
  const target = acquireBuffer(values.byteLength, undefined, label);
  device.queue.writeBuffer(target, 0, values);
  return target;
}
|
|
97
|
+
|
|
98
|
+
// Wraps a bias weight's GPU buffer as a 1-D tensor of length `size`.
// Returns null when the weight is absent so callers can skip the bias add.
// NOTE(review): dtype is hard-coded to 'f32' here while activations may be
// resolved to another dtype elsewhere — confirm bias weights are always f32.
function createBiasTensor(weight, size, label) {
  if (!weight) {
    return null;
  }
  return createTensor(getBuffer(weight), 'f32', [size], label);
}
|
|
102
|
+
|
|
103
|
+
// Counts truthy entries in an iterable attention mask.
// Returns null when no mask is provided (callers treat null as "don't trim").
function countMask(mask) {
  if (!mask) return null;
  let total = 0;
  for (const entry of mask) {
    total += entry ? 1 : 0;
  }
  return total;
}
|
|
111
|
+
|
|
112
|
+
// Drops padded tokens from a [tokens, dim] context tensor using the
// attention mask. When the mask is absent, yields no valid tokens, or
// covers every token, the original tensor is returned untouched; otherwise
// the same buffer is re-viewed with a shorter first dimension (no copy).
function trimContext(context, attentionMask) {
  const validTokens = countMask(attentionMask);
  const keepAsIs =
    !Number.isFinite(validTokens) || validTokens <= 0 || validTokens >= context.shape[0];
  if (keepAsIs) {
    return context;
  }
  return createTensor(context.buffer, context.dtype, [validTokens, context.shape[1]], 'sana_trimmed_context');
}
|
|
119
|
+
|
|
120
|
+
// Builds a sinusoidal embedding of `value` with interleaved cos/sin pairs:
// out[2i] = cos(value * freq_i), out[2i+1] = sin(value * freq_i), with
// frequencies decaying geometrically from 1 down toward 1/maxPeriod.
// For odd `dim` the final slot is left at 0.
function buildSinusoidalEmbedding(value, dim = 256) {
  const embedding = new Float32Array(dim);
  const halfDim = Math.floor(dim / 2);
  const maxPeriod = 10000;
  for (let idx = 0; idx < halfDim; idx++) {
    const freq = Math.exp(-Math.log(maxPeriod) * idx / halfDim);
    const angle = value * freq;
    embedding[2 * idx] = Math.cos(angle);
    embedding[2 * idx + 1] = Math.sin(angle);
  }
  return embedding;
}
|
|
132
|
+
|
|
133
|
+
// Resolves the B-operand dtype for a matmul against a stored weight:
// the recorded dtype (normalized) is used as a hint, then reconciled
// against the actual buffer size for an N x K operand.
function resolveMatmulDtype(weightsEntry, name, N, K) {
  const preferred = normalizeDiffusionMatmulLocationDtype(getWeightDtype(weightsEntry, name));
  const weight = getWeight(weightsEntry, name);
  return inferDiffusionMatmulDtypeFromBuffer(weight, N, K, preferred);
}
|
|
138
|
+
|
|
139
|
+
// Runs (or records) an M x K @ K x N matmul against a named stored weight.
// The weight must exist (expectDiffusionWeight throws otherwise); its dtype
// is resolved from the registry/buffer, and transposeB is left to 'auto'.
async function runMatmulResolved(input, weightsEntry, name, M, N, K, recorder, options = {}) {
  const weight = expectDiffusionWeight(getWeight(weightsEntry, name), name);
  const resolvedOptions = {
    ...options,
    bDtype: resolveMatmulDtype(weightsEntry, name, N, K),
    transposeB: 'auto',
  };
  return recorder
    ? recordMatmul(recorder, input, weight, M, N, K, resolvedOptions)
    : runMatmul(input, weight, M, N, K, resolvedOptions);
}
|
|
147
|
+
|
|
148
|
+
// Runs a two-layer MLP embedding head: linear_1 -> (optional bias) -> SiLU
// -> linear_2 -> (optional bias). `prefix` selects the weight subtree
// (e.g. 'time_embed.timestep_embedder'); the input is treated as one row
// (M = 1). Intermediate buffers are released via `release` as soon as the
// next stage has consumed them; bias tensors wrap weight buffers and are
// deliberately not released here.
async function runTwoLayerEmbedding(inputTensor, weightsEntry, prefix, outDim, recorder, runtime, ops, release) {
  const activationDtype = resolveDiffusionActivationDtype(runtime);

  // First projection: output width/K come from the stored linear_1 shape.
  let output = await runMatmulResolved(
    inputTensor,
    weightsEntry,
    `${prefix}.linear_1.weight`,
    1,
    getWeightShape(weightsEntry, `${prefix}.linear_1.weight`)[0],
    getWeightShape(weightsEntry, `${prefix}.linear_1.weight`)[1],
    recorder,
    { outputDtype: activationDtype }
  );
  const bias1 = createBiasTensor(getWeight(weightsEntry, `${prefix}.linear_1.bias`), output.shape[1], `${prefix}_bias1`);
  if (bias1) {
    output = await ops.biasAdd(output, bias1, 1, output.shape[1]);
  }

  // SiLU activation over the single row; the pre-activation buffer is
  // released once the activation has been produced.
  const act = await ops.silu(output, { size: output.shape[1], swigluLimit: null });
  release(output.buffer);

  // Second projection down/up to `outDim`; K comes from the stored shape.
  let projected = await runMatmulResolved(
    act,
    weightsEntry,
    `${prefix}.linear_2.weight`,
    1,
    outDim,
    getWeightShape(weightsEntry, `${prefix}.linear_2.weight`)[1],
    recorder,
    { outputDtype: activationDtype }
  );
  const bias2 = createBiasTensor(getWeight(weightsEntry, `${prefix}.linear_2.bias`), outDim, `${prefix}_bias2`);
  if (bias2) {
    projected = await ops.biasAdd(projected, bias2, 1, outDim);
  }
  release(act.buffer);
  return projected;
}
|
|
186
|
+
|
|
187
|
+
// Builds the timestep (and optional guidance) conditioning for the Sana
// transformer: sinusoidal embedding -> two-layer MLP, optionally summed with
// a guidance embedding, then SiLU + a linear producing the 6*hidden
// modulation vector consumed per layer. Returns { modulation,
// embeddedTimestep }; embeddedTimestep is the (possibly combined)
// conditioning before the final SiLU/linear. Throws without a WebGPU device.
export async function buildSanaTimestepConditioning(timestep, guidanceScale, weightsEntry, config, runtime, options = {}) {
  const device = getDevice();
  if (!device) {
    throw new Error('Sana timestep conditioning requires a WebGPU device.');
  }
  const recorder = options.recorder ?? null;
  const ops = createKernelOps(recorder);
  const release = createDiffusionBufferReleaser(recorder);
  const activationDtype = resolveDiffusionActivationDtype(runtime);
  const hiddenSize = config.num_attention_heads * config.attention_head_dim;

  // 256-wide sinusoidal timestep embedding uploaded as a single row.
  // NOTE(review): the buffer holds Float32Array data but the tensor is
  // tagged with activationDtype — if that resolves to f16 the tag and the
  // bytes disagree; confirm downstream kernels expect this.
  const timeTensor = createTensor(
    createVectorBuffer(device, buildSinusoidalEmbedding(timestep, 256), 'sana_timestep'),
    activationDtype,
    [1, 256],
    'sana_timestep'
  );
  const timeEmbedding = await runTwoLayerEmbedding(
    timeTensor,
    weightsEntry,
    'time_embed.timestep_embedder',
    hiddenSize,
    recorder,
    runtime,
    ops,
    release
  );
  release(timeTensor.buffer);

  // Guidance-distilled checkpoints add a guidance embedding to the
  // timestep embedding (only when the config explicitly opts in).
  let conditioning = timeEmbedding;
  if (config.guidance_embeds === true) {
    const guidanceTensor = createTensor(
      createVectorBuffer(device, buildSinusoidalEmbedding(guidanceScale, 256), 'sana_guidance'),
      activationDtype,
      [1, 256],
      'sana_guidance'
    );
    const guidanceEmbedding = await runTwoLayerEmbedding(
      guidanceTensor,
      weightsEntry,
      'time_embed.guidance_embedder',
      hiddenSize,
      recorder,
      runtime,
      ops,
      release
    );
    release(guidanceTensor.buffer);
    conditioning = await ops.residualAdd(timeEmbedding, guidanceEmbedding, hiddenSize, { useVec4: true });
    release(timeEmbedding.buffer);
    release(guidanceEmbedding.buffer);
  }

  // SiLU then a linear projecting to 6 * hiddenSize — the per-layer
  // shift/scale/gate modulation segments.
  const conditioningAct = await ops.silu(conditioning, { size: hiddenSize, swigluLimit: null });
  let modulation = await runMatmulResolved(
    conditioningAct,
    weightsEntry,
    'time_embed.linear.weight',
    1,
    hiddenSize * 6,
    hiddenSize,
    recorder,
    { outputDtype: activationDtype }
  );
  const modulationBias = createBiasTensor(getWeight(weightsEntry, 'time_embed.linear.bias'), hiddenSize * 6, 'sana_time_linear_bias');
  if (modulationBias) {
    modulation = await ops.biasAdd(modulation, modulationBias, 1, hiddenSize * 6);
  }
  release(conditioningAct.buffer);

  // `conditioning` is intentionally NOT released: it is returned to the
  // caller as embeddedTimestep.
  return {
    modulation,
    embeddedTimestep: conditioning,
  };
}
|
|
262
|
+
|
|
263
|
+
// Projects the text-encoder context into the transformer's hidden space:
// trims padded tokens via the attention mask, applies
// linear_1 -> GELU(tanh) -> linear_2 (PixArtAlphaTextProjection shape),
// then RMS-norms the result with caption_norm. Returns the normalized
// [tokens, hiddenSize] tensor; intermediates are released as consumed.
export async function projectSanaContext(context, attentionMask, weightsEntry, config, runtime, options = {}) {
  const recorder = options.recorder ?? null;
  const ops = createKernelOps(recorder);
  const release = createDiffusionBufferReleaser(recorder);
  const trimmed = trimContext(context, attentionMask);
  const tokenCount = trimmed.shape[0];
  const inputDim = trimmed.shape[1];
  const hiddenSize = config.num_attention_heads * config.attention_head_dim;
  const activationDtype = resolveDiffusionActivationDtype(runtime);

  // First caption projection: [tokens, inputDim] -> [tokens, hiddenSize].
  let hidden = await runMatmulResolved(
    trimmed,
    weightsEntry,
    'caption_projection.linear_1.weight',
    tokenCount,
    hiddenSize,
    inputDim,
    recorder,
    { outputDtype: activationDtype }
  );
  const bias1 = createBiasTensor(getWeight(weightsEntry, 'caption_projection.linear_1.bias'), hiddenSize, 'sana_caption_bias1');
  if (bias1) {
    hidden = await ops.biasAdd(hidden, bias1, tokenCount, hiddenSize);
  }

  // PixArtAlphaTextProjection uses GELU(tanh). Reuse the existing GeLU kernel here.
  // (Lazy import keeps the GeLU kernel out of the default bundle path.)
  const { runGeLU, recordGeLU } = await import('../../../gpu/kernels/gelu.js');
  const gelu = recorder
    ? (input, options = {}) => recordGeLU(recorder, input, options)
    : runGeLU;
  const activated = await gelu(hidden, { size: tokenCount * hiddenSize });
  release(hidden.buffer);

  // Second caption projection: [tokens, hiddenSize] -> [tokens, hiddenSize].
  let projected = await runMatmulResolved(
    activated,
    weightsEntry,
    'caption_projection.linear_2.weight',
    tokenCount,
    hiddenSize,
    hiddenSize,
    recorder,
    { outputDtype: activationDtype }
  );
  const bias2 = createBiasTensor(getWeight(weightsEntry, 'caption_projection.linear_2.bias'), hiddenSize, 'sana_caption_bias2');
  if (bias2) {
    projected = await ops.biasAdd(projected, bias2, tokenCount, hiddenSize);
  }
  release(activated.buffer);

  // Final RMS norm with the (required) caption_norm weight; eps is fixed
  // at 1e-5 here rather than config.norm_eps.
  const normWeight = expectDiffusionWeight(getWeight(weightsEntry, 'caption_norm.weight'), 'caption_norm.weight');
  const normed = await ops.rmsNorm(projected, getBuffer(normWeight), 1e-5, {
    batchSize: tokenCount,
    hiddenSize,
  });
  release(projected.buffer);
  return normed;
}
|
|
320
|
+
|
|
321
|
+
// Tiles a [1, n] vector tensor `times` times into one contiguous buffer via
// GPU buffer-to-buffer copies, returning a [1, n * times] view. When a
// recorder is supplied the copies are appended to its encoder and NOT
// submitted here; otherwise a one-off encoder is submitted immediately.
async function duplicateVectorTensor(tensor, times, recorder) {
  const device = getDevice();
  const chunkBytes = tensor.buffer.size;
  const tiled = acquireBuffer(chunkBytes * times, undefined, 'sana_duplicate_vector');
  const encoder = recorder ? recorder.getEncoder() : device.createCommandEncoder();
  for (let copyIdx = 0; copyIdx < times; copyIdx++) {
    encoder.copyBufferToBuffer(tensor.buffer, 0, tiled, copyIdx * chunkBytes, chunkBytes);
  }
  if (!recorder) {
    device.queue.submit([encoder.finish()]);
  }
  return createTensor(tiled, tensor.dtype, [1, tensor.shape[1] * times], 'sana_duplicate_vector');
}
|
|
333
|
+
|
|
334
|
+
// Adds a layer's scale_shift_table (flattened to `segments` elements, the
// product of its dims) onto the shared base modulation vector via biasAdd.
// `recorder` and `release` are accepted for signature parity with the other
// per-layer helpers; biasAdd already routes through the recorder-aware ops.
async function buildLayerModulation(baseModulation, tableWeight, tableShape, recorder, ops, release) {
  const segments = Array.isArray(tableShape)
    ? tableShape.reduce((product, dim) => product * dim, 1)
    : 0;
  const tableTensor = createTensor(getBuffer(tableWeight), baseModulation.dtype, [segments], 'sana_layer_table');
  const combined = await ops.biasAdd(baseModulation, tableTensor, 1, segments);
  return combined;
}
|
|
344
|
+
|
|
345
|
+
// Runs one block's self-attention (attn1) over [numTokens, hiddenSize]:
// q/k/v projections, optional per-projection RMS norms, Sana linear
// attention, then the output projection with optional bias. Buffers are
// released as soon as a replacement tensor exists; the returned tensor is
// owned by the caller.
async function runSelfAttention(hiddenStates, layerIdx, weightsEntry, numTokens, hiddenSize, numHeads, headDim, eps, recorder, runtime, ops, release) {
  let query = await runMatmulResolved(
    hiddenStates,
    weightsEntry,
    `transformer_blocks.${layerIdx}.attn1.to_q.weight`,
    numTokens,
    hiddenSize,
    hiddenSize,
    recorder,
    { outputDtype: hiddenStates.dtype }
  );
  let key = await runMatmulResolved(
    hiddenStates,
    weightsEntry,
    `transformer_blocks.${layerIdx}.attn1.to_k.weight`,
    numTokens,
    hiddenSize,
    hiddenSize,
    recorder,
    { outputDtype: hiddenStates.dtype }
  );
  let value = await runMatmulResolved(
    hiddenStates,
    weightsEntry,
    `transformer_blocks.${layerIdx}.attn1.to_v.weight`,
    numTokens,
    hiddenSize,
    hiddenSize,
    recorder,
    { outputDtype: hiddenStates.dtype }
  );

  // Optional q/k RMS norms (present in some checkpoints); each replaces
  // its input tensor and releases the pre-norm buffer.
  const normQ = getWeight(weightsEntry, `transformer_blocks.${layerIdx}.attn1.norm_q.weight`);
  const normK = getWeight(weightsEntry, `transformer_blocks.${layerIdx}.attn1.norm_k.weight`);
  if (normQ) {
    const next = await ops.rmsNorm(query, getBuffer(normQ), eps, { batchSize: numTokens, hiddenSize });
    release(query.buffer);
    query = next;
  }
  if (normK) {
    const next = await ops.rmsNorm(key, getBuffer(normK), eps, { batchSize: numTokens, hiddenSize });
    release(key.buffer);
    key = next;
  }

  // Linear attention with its own small epsilon (1e-15), independent of
  // the norm epsilon passed in.
  let output = await ops.sanaLinearAttention(query, key, value, {
    numHeads,
    headDim,
    numTokens,
    hiddenSize,
    eps: 1e-15,
  });
  release(query.buffer);
  release(key.buffer);
  release(value.buffer);

  // Output projection (to_out.0) plus its optional bias.
  let projected = await runMatmulResolved(
    output,
    weightsEntry,
    `transformer_blocks.${layerIdx}.attn1.to_out.0.weight`,
    numTokens,
    hiddenSize,
    hiddenSize,
    recorder,
    { outputDtype: hiddenStates.dtype }
  );
  release(output.buffer);
  const outBias = createBiasTensor(getWeight(weightsEntry, `transformer_blocks.${layerIdx}.attn1.to_out.0.bias`), hiddenSize, 'sana_self_attn_bias');
  if (outBias) {
    projected = await ops.biasAdd(projected, outBias, numTokens, hiddenSize);
  }
  return projected;
}
|
|
418
|
+
|
|
419
|
+
// Runs one block's cross-attention (attn2) between image tokens
// ([numTokens, hiddenSize]) and text context ([contextTokens, ctxDim]).
// Projection output widths are read from the stored weight shapes, so q/k/v
// may be narrower than hiddenSize when the checkpoint uses fewer cross
// heads. Uses standard softmax attention (non-causal), then projects back
// to hiddenSize.
async function runCrossAttention(hiddenStates, context, layerIdx, weightsEntry, numTokens, hiddenSize, config, recorder, ops, release) {
  // Output dims of the projections, taken from the weight shapes.
  const qHeads = getWeightShape(weightsEntry, `transformer_blocks.${layerIdx}.attn2.to_q.weight`)[0];
  const kHeads = getWeightShape(weightsEntry, `transformer_blocks.${layerIdx}.attn2.to_k.weight`)[0];
  const vHeads = getWeightShape(weightsEntry, `transformer_blocks.${layerIdx}.attn2.to_v.weight`)[0];
  const contextTokens = context.shape[0];
  let query = await runMatmulResolved(
    hiddenStates,
    weightsEntry,
    `transformer_blocks.${layerIdx}.attn2.to_q.weight`,
    numTokens,
    qHeads,
    hiddenSize,
    recorder,
    { outputDtype: hiddenStates.dtype }
  );
  let key = await runMatmulResolved(
    context,
    weightsEntry,
    `transformer_blocks.${layerIdx}.attn2.to_k.weight`,
    contextTokens,
    kHeads,
    context.shape[1],
    recorder,
    { outputDtype: hiddenStates.dtype }
  );
  let value = await runMatmulResolved(
    context,
    weightsEntry,
    `transformer_blocks.${layerIdx}.attn2.to_v.weight`,
    contextTokens,
    vHeads,
    context.shape[1],
    recorder,
    { outputDtype: hiddenStates.dtype }
  );

  const crossHeads = config.num_cross_attention_heads;
  const headDim = config.cross_attention_head_dim;
  // Optional q/k RMS norms, applied over the projection widths (not
  // hiddenSize) since cross projections can be narrower.
  const normQ = getWeight(weightsEntry, `transformer_blocks.${layerIdx}.attn2.norm_q.weight`);
  const normK = getWeight(weightsEntry, `transformer_blocks.${layerIdx}.attn2.norm_k.weight`);
  if (normQ) {
    const next = await ops.rmsNorm(query, getBuffer(normQ), Number(config.norm_eps ?? 1e-6), {
      batchSize: numTokens,
      hiddenSize: qHeads,
    });
    release(query.buffer);
    query = next;
  }
  if (normK) {
    const next = await ops.rmsNorm(key, getBuffer(normK), Number(config.norm_eps ?? 1e-6), {
      batchSize: contextTokens,
      hiddenSize: kHeads,
    });
    release(key.buffer);
    key = next;
  }

  // Non-causal softmax attention; KV head count equals query head count
  // here (no grouped-query sharing).
  const attention = await ops.attention(query, key, value, null, crossHeads, headDim, {
    seqLen: numTokens,
    kvLen: contextTokens,
    numKVHeads: crossHeads,
    causal: false,
  });
  release(query.buffer);
  release(key.buffer);
  release(value.buffer);

  // Output projection (to_out.0) back to hiddenSize, plus optional bias.
  let projected = await runMatmulResolved(
    attention,
    weightsEntry,
    `transformer_blocks.${layerIdx}.attn2.to_out.0.weight`,
    numTokens,
    hiddenSize,
    hiddenSize,
    recorder,
    { outputDtype: hiddenStates.dtype }
  );
  release(attention.buffer);
  const outBias = createBiasTensor(getWeight(weightsEntry, `transformer_blocks.${layerIdx}.attn2.to_out.0.bias`), hiddenSize, 'sana_cross_attn_bias');
  if (outBias) {
    projected = await ops.biasAdd(projected, outBias, numTokens, hiddenSize);
  }
  return projected;
}
|
|
503
|
+
|
|
504
|
+
// Runs one block's GLU-MBConv feed-forward (Sana's "ff"): inverted
// pointwise expansion to 2 * hiddenChannels, SiLU, a 3x3 depthwise conv
// over the token grid, a gated SiLU row-split halving channels back to
// hiddenChannels, then a pointwise projection to hiddenSize. The expansion
// ratio is derived from the stored conv_inverted weight shape.
async function runGlumbConv(hiddenStates, layerIdx, weightsEntry, gridHeight, gridWidth, hiddenSize, recorder, runtime, ops, release) {
  // conv_inverted outputs 2 * expandRatio * hiddenSize channels, so divide
  // its output dim by hiddenSize and by 2 to recover the ratio.
  const expandRatio = Number(getWeightShape(weightsEntry, `transformer_blocks.${layerIdx}.ff.conv_inverted.weight`)[0]) / hiddenSize / 2;
  const hiddenChannels = Math.floor(hiddenSize * expandRatio);
  let inverted = await runMatmulResolved(
    hiddenStates,
    weightsEntry,
    `transformer_blocks.${layerIdx}.ff.conv_inverted.weight`,
    hiddenStates.shape[0],
    hiddenChannels * 2,
    hiddenSize,
    recorder,
    { outputDtype: hiddenStates.dtype }
  );
  const invertedBias = createBiasTensor(getWeight(weightsEntry, `transformer_blocks.${layerIdx}.ff.conv_inverted.bias`), hiddenChannels * 2, 'sana_ff_inverted_bias');
  if (invertedBias) {
    inverted = await ops.biasAdd(inverted, invertedBias, hiddenStates.shape[0], hiddenChannels * 2);
  }
  const invertedAct = await ops.silu(inverted, { size: hiddenStates.shape[0] * hiddenChannels * 2, swigluLimit: null });
  release(inverted.buffer);

  // Re-view tokens as a [C, H, W] image for the depthwise 3x3 conv
  // (stride 1, pad 1 — spatial size preserved).
  const convInput = reshapeTensor(invertedAct, [hiddenChannels * 2, gridHeight, gridWidth], 'sana_ff_conv_input');
  const depthWeight = expectDiffusionWeight(getWeight(weightsEntry, `transformer_blocks.${layerIdx}.ff.conv_depth.weight`), `transformer_blocks.${layerIdx}.ff.conv_depth.weight`);
  const depthBias = getWeight(weightsEntry, `transformer_blocks.${layerIdx}.ff.conv_depth.bias`);
  const depth = await ops.depthwiseConv2d(convInput, depthWeight, depthBias, {
    channels: hiddenChannels * 2,
    height: gridHeight,
    width: gridWidth,
    kernelH: 3,
    kernelW: 3,
    stride: 1,
    pad: 1,
  });
  release(invertedAct.buffer);

  // Back to token layout, then gate: SiLU on one half times the other half,
  // reducing 2 * hiddenChannels to hiddenChannels per token.
  const depthTokens = reshapeTensor(depth, [hiddenStates.shape[0], hiddenChannels * 2], 'sana_ff_depth_tokens');
  const gated = await ops.siluRowSplit(depthTokens, {
    numTokens: hiddenStates.shape[0],
    dim: hiddenChannels,
    activation: 'silu',
    swigluLimit: null,
  });
  release(depth.buffer);

  // Final pointwise projection back to hiddenSize (no bias looked up here).
  let projected = await runMatmulResolved(
    gated,
    weightsEntry,
    `transformer_blocks.${layerIdx}.ff.conv_point.weight`,
    hiddenStates.shape[0],
    hiddenSize,
    hiddenChannels,
    recorder,
    { outputDtype: hiddenStates.dtype }
  );
  release(gated.buffer);
  return projected;
}
|
|
560
|
+
|
|
561
|
+
/**
 * Runs the Sana diffusion transformer forward pass on the GPU.
 *
 * Pipeline: patch-embed (conv2d) -> transpose to token-major layout ->
 * `config.num_layers` transformer blocks (modulated self-attention,
 * cross-attention, modulated GLU-MB-conv feed-forward, each with a residual
 * add) -> final adaptive layer-norm -> output projection -> transpose back
 * to channels-first and reshape to the latent grid.
 *
 * Buffer lifecycle is explicit: every intermediate GPU buffer is released via
 * `release(...)` as soon as its last consumer has run, and the two reusable
 * norm-parameter buffers are destroyed at the end. The statement order is
 * therefore load-bearing — do not reorder dispatches or releases.
 *
 * @param {object} latents - Input latent tensor; indexed as shape[0]=channels,
 *   shape[1]=height, shape[2]=width (channels-first — see conv2d options).
 * @param {object} context - Text-conditioning tensor consumed by cross-attention.
 * @param {object} timeState - Timestep conditioning; `.modulation` and
 *   `.embeddedTimestep` buffers are CONSUMED (released) by this function.
 * @param {object} weightsEntry - Weight store queried via getWeight/getWeightShape.
 * @param {object} modelConfig - Model config; transformer config read from
 *   `components.transformer.config`.
 * @param {object} runtime - Opaque runtime handle forwarded to attention/FF helpers.
 * @param {object} [options] - `options.recorder` switches every kernel from
 *   immediate execution to command recording.
 * @returns {Promise<object>} Channels-first output tensor of shape
 *   [out_channels, gridHeight, gridWidth].
 * @throws {Error} If no WebGPU device is available.
 */
export async function runSanaTransformer(latents, context, timeState, weightsEntry, modelConfig, runtime, options = {}) {
  const device = getDevice();
  if (!device) {
    throw new Error('Sana transformer requires a WebGPU device.');
  }

  // When a recorder is supplied, ops/release/destroy record commands instead
  // of executing/freeing immediately (deferred until the recording is replayed).
  const recorder = options.recorder ?? null;
  const ops = createKernelOps(recorder);
  const release = createDiffusionBufferReleaser(recorder);
  const destroy = createDiffusionBufferDestroyer(recorder);
  const config = modelConfig?.components?.transformer?.config || {};
  const hiddenSize = config.num_attention_heads * config.attention_head_dim;
  const numHeads = config.num_attention_heads;
  const headDim = config.attention_head_dim;
  const patchSize = config.patch_size ?? 1;
  const normEps = Number(config.norm_eps ?? 1e-6);
  const latentHeight = latents.shape[1];
  const latentWidth = latents.shape[2];
  // Patchify: each patchSize x patchSize tile becomes one token.
  const gridHeight = Math.floor(latentHeight / patchSize);
  const gridWidth = Math.floor(latentWidth / patchSize);
  const numTokens = gridHeight * gridWidth;

  // Patch embedding: strided conv projects latent channels to hiddenSize.
  const patchWeight = expectDiffusionWeight(getWeight(weightsEntry, 'patch_embed.proj.weight'), 'patch_embed.proj.weight');
  const patchBias = getWeight(weightsEntry, 'patch_embed.proj.bias');
  const conv = await ops.conv2d(latents, patchWeight, patchBias, {
    inChannels: latents.shape[0],
    outChannels: hiddenSize,
    height: latentHeight,
    width: latentWidth,
    kernelH: patchSize,
    kernelW: patchSize,
    stride: patchSize,
    pad: 0,
  });
  // Transpose conv output [hiddenSize, numTokens] -> token-major
  // [numTokens, hiddenSize] for the transformer blocks.
  let hidden = await import('../../../gpu/kernels/transpose.js').then(({ runTranspose, recordTranspose }) => {
    const transpose = recorder ? (input, rows, cols) => recordTranspose(recorder, input, rows, cols) : runTranspose;
    return transpose(conv, hiddenSize, numTokens);
  });
  release(conv.buffer);

  // Identity layer-norm parameters (gamma=1, beta=0): the real scale/shift is
  // applied afterwards by ops.modulate using the per-layer modulation vector.
  const ones = new Float32Array(hiddenSize).fill(1.0);
  const zeros = new Float32Array(hiddenSize);
  const onesBuf = createVectorBuffer(device, ones, 'sana_norm_ones');
  const zerosBuf = createVectorBuffer(device, zeros, 'sana_norm_zeros');

  for (let layerIdx = 0; layerIdx < config.num_layers; layerIdx++) {
    // Per-layer modulation = time embedding combined with the layer's
    // scale_shift_table. Offsets below index chunks of this vector; the
    // chunk layout (shift/scale/gate per sub-block) appears to follow the
    // 6-way AdaLN convention — confirm against buildLayerModulation.
    const layerTable = expectDiffusionWeight(getWeight(weightsEntry, `transformer_blocks.${layerIdx}.scale_shift_table`), `transformer_blocks.${layerIdx}.scale_shift_table`);
    const modulation = await buildLayerModulation(
      timeState.modulation,
      layerTable,
      getWeightShape(weightsEntry, `transformer_blocks.${layerIdx}.scale_shift_table`),
      recorder,
      ops,
      release
    );

    // --- Self-attention sub-block: norm -> modulate -> attn -> gate -> residual ---
    let norm1 = await ops.layerNorm(hidden, onesBuf, zerosBuf, normEps, { batchSize: numTokens, hiddenSize });
    // addOne: scale is applied as (1 + scale), shift from chunk 0.
    norm1 = await ops.modulate(norm1, modulation, {
      numTokens,
      hiddenSize,
      scaleOffset: hiddenSize,
      shiftOffset: 0,
      gateOffset: hiddenSize * 2,
      hasGate: false,
      addOne: true,
    });
    const selfAttn = await runSelfAttention(norm1, layerIdx, weightsEntry, numTokens, hiddenSize, numHeads, headDim, normEps, recorder, runtime, ops, release);
    release(norm1.buffer);
    // Gate the attention output: scale from chunk 2, no +1, shift chunk 6
    // (with hasGate:false the gateOffset is presumably ignored — TODO confirm
    // against ops.modulate).
    const gatedSelf = await ops.modulate(selfAttn, modulation, {
      numTokens,
      hiddenSize,
      scaleOffset: hiddenSize * 2,
      shiftOffset: hiddenSize * 6,
      gateOffset: hiddenSize * 2,
      hasGate: false,
      addOne: false,
    });
    release(selfAttn.buffer);
    let nextHidden = await ops.residualAdd(hidden, gatedSelf, numTokens * hiddenSize, { useVec4: true });
    release(hidden.buffer);
    release(gatedSelf.buffer);
    // Re-wrap the residual buffer as the new hidden state (no copy).
    hidden = createTensor(nextHidden.buffer, nextHidden.dtype, [numTokens, hiddenSize], 'sana_hidden_after_self');

    // --- Cross-attention sub-block: norm -> attn(context) -> residual ---
    // (no modulation here; cross-attention is unconditioned by the timestep)
    let norm2 = await ops.layerNorm(hidden, onesBuf, zerosBuf, normEps, { batchSize: numTokens, hiddenSize });
    const crossAttn = await runCrossAttention(norm2, context, layerIdx, weightsEntry, numTokens, hiddenSize, config, recorder, ops, release);
    release(norm2.buffer);
    nextHidden = await ops.residualAdd(hidden, crossAttn, numTokens * hiddenSize, { useVec4: true });
    release(hidden.buffer);
    release(crossAttn.buffer);
    hidden = createTensor(nextHidden.buffer, nextHidden.dtype, [numTokens, hiddenSize], 'sana_hidden_after_cross');

    // --- Feed-forward sub-block: norm -> modulate -> GLU-MB-conv -> gate -> residual ---
    let normFf = await ops.layerNorm(hidden, onesBuf, zerosBuf, normEps, { batchSize: numTokens, hiddenSize });
    normFf = await ops.modulate(normFf, modulation, {
      numTokens,
      hiddenSize,
      scaleOffset: hiddenSize * 4,
      shiftOffset: hiddenSize * 3,
      gateOffset: hiddenSize * 5,
      hasGate: false,
      addOne: true,
    });
    const ff = await runGlumbConv(normFf, layerIdx, weightsEntry, gridHeight, gridWidth, hiddenSize, recorder, runtime, ops, release);
    release(normFf.buffer);
    const gatedFf = await ops.modulate(ff, modulation, {
      numTokens,
      hiddenSize,
      scaleOffset: hiddenSize * 5,
      shiftOffset: hiddenSize * 6,
      gateOffset: hiddenSize * 5,
      hasGate: false,
      addOne: false,
    });
    release(ff.buffer);
    // Last consumer of this layer's modulation vector.
    release(modulation.buffer);
    nextHidden = await ops.residualAdd(hidden, gatedFf, numTokens * hiddenSize, { useVec4: true });
    release(hidden.buffer);
    release(gatedFf.buffer);
    hidden = createTensor(nextHidden.buffer, nextHidden.dtype, [numTokens, hiddenSize], 'sana_hidden_after_ff');
  }
  // timeState.modulation is consumed: released once all layers are done.
  release(timeState.modulation.buffer);

  // --- Final adaptive layer-norm: shift/scale = embedded timestep (duplicated
  // to a [2*hiddenSize] vector) + the top-level scale_shift_table. ---
  const finalTable = expectDiffusionWeight(getWeight(weightsEntry, 'scale_shift_table'), 'scale_shift_table');
  const duplicated = await duplicateVectorTensor(timeState.embeddedTimestep, 2, recorder);
  release(timeState.embeddedTimestep.buffer);
  let finalMod = await ops.biasAdd(
    duplicated,
    createTensor(getBuffer(finalTable), duplicated.dtype, [hiddenSize * 2], 'sana_final_table'),
    1,
    hiddenSize * 2
  );
  release(duplicated.buffer);
  // NOTE(review): eps is hard-coded to 1e-6 here rather than normEps — looks
  // intentional (matches the reference implementation's final norm) but verify.
  let normed = await ops.layerNorm(hidden, onesBuf, zerosBuf, 1e-6, { batchSize: numTokens, hiddenSize });
  normed = await ops.modulate(normed, finalMod, {
    numTokens,
    hiddenSize,
    scaleOffset: hiddenSize,
    shiftOffset: 0,
    gateOffset: hiddenSize,
    hasGate: false,
    addOne: true,
  });
  release(hidden.buffer);
  release(finalMod.buffer);

  // Output projection: [numTokens, hiddenSize] -> [numTokens, out_channels].
  let projected = await runMatmulResolved(
    normed,
    weightsEntry,
    'proj_out.weight',
    numTokens,
    config.out_channels ?? latents.shape[0],
    hiddenSize,
    recorder,
    { outputDtype: normed.dtype }
  );
  release(normed.buffer);
  // Bias is optional; createBiasTensor presumably returns a falsy value when
  // the weight is absent.
  const projBias = createBiasTensor(getWeight(weightsEntry, 'proj_out.bias'), projected.shape[1], 'sana_proj_out_bias');
  if (projBias) {
    projected = await ops.biasAdd(projected, projBias, numTokens, projected.shape[1]);
  }

  // Back to channels-first [out_channels, numTokens], then reshape to the grid.
  const { runTranspose, recordTranspose } = await import('../../../gpu/kernels/transpose.js');
  const transpose = recorder ? (input, rows, cols) => recordTranspose(recorder, input, rows, cols) : runTranspose;
  const channelsFirst = await transpose(projected, numTokens, projected.shape[1]);
  release(projected.buffer);
  destroy(onesBuf);
  destroy(zerosBuf);
  return reshapeTensor(channelsFirst, [config.out_channels ?? latents.shape[0], gridHeight, gridWidth], 'sana_output');
}
|
|
729
|
+
|
|
730
|
+
/**
 * Prepares the conditioning inputs consumed by runSanaTransformer: projects
 * the text-encoder context and builds the timestep/guidance conditioning
 * state from the transformer's config and weights.
 *
 * The two conditioning steps are awaited sequentially on purpose: both may
 * record GPU work through `options.recorder`, so their order is preserved.
 *
 * @param {object} context - Raw text-encoder output tensor.
 * @param {object} attentionMask - Mask paired with `context`.
 * @param {*} timestep - Diffusion timestep value.
 * @param {*} guidanceScale - Classifier-free guidance scale.
 * @param {object} weightsEntry - Transformer weight store.
 * @param {object} modelConfig - Model config; transformer config read from
 *   `components.transformer.config` (falls back to `{}`).
 * @param {object} runtime - Opaque runtime handle forwarded to both helpers.
 * @param {object} [options] - Forwarded to both helpers (may carry a recorder).
 * @returns {Promise<{context: object, timeState: object}>} Projected context
 *   and timestep conditioning state.
 */
export async function buildSanaConditioning(context, attentionMask, timestep, guidanceScale, weightsEntry, modelConfig, runtime, options = {}) {
  const transformerConfig = modelConfig?.components?.transformer?.config || {};
  const conditionedContext = await projectSanaContext(
    context,
    attentionMask,
    weightsEntry,
    transformerConfig,
    runtime,
    options
  );
  const conditionedTime = await buildSanaTimestepConditioning(
    timestep,
    guidanceScale,
    weightsEntry,
    transformerConfig,
    runtime,
    options
  );
  return { context: conditionedContext, timeState: conditionedTime };
}
|