@simulatte/doppler 0.1.4 → 0.1.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +26 -10
- package/package.json +30 -6
- package/src/client/doppler-api.browser.d.ts +1 -0
- package/src/client/doppler-api.browser.js +288 -0
- package/src/client/doppler-api.js +1 -1
- package/src/client/doppler-provider/types.js +1 -1
- package/src/config/execution-contract-check.d.ts +33 -0
- package/src/config/execution-contract-check.js +72 -0
- package/src/config/execution-v0-contract-check.d.ts +94 -0
- package/src/config/execution-v0-contract-check.js +251 -0
- package/src/config/execution-v0-graph-contract-check.d.ts +20 -0
- package/src/config/execution-v0-graph-contract-check.js +64 -0
- package/src/config/kernel-path-contract-check.d.ts +76 -0
- package/src/config/kernel-path-contract-check.js +479 -0
- package/src/config/kernel-path-loader.d.ts +16 -0
- package/src/config/kernel-path-loader.js +54 -0
- package/src/config/kernels/kernel-ref-digests.js +39 -27
- package/src/config/kernels/registry.json +598 -2
- package/src/config/loader.js +81 -48
- package/src/config/merge-contract-check.d.ts +16 -0
- package/src/config/merge-contract-check.js +321 -0
- package/src/config/merge-helpers.d.ts +58 -0
- package/src/config/merge-helpers.js +54 -0
- package/src/config/merge.js +21 -6
- package/src/config/presets/models/janus-text.json +2 -0
- package/src/config/presets/models/qwen3.json +9 -2
- package/src/config/presets/models/transformer.json +5 -0
- package/src/config/quantization-contract-check.d.ts +12 -0
- package/src/config/quantization-contract-check.js +91 -0
- package/src/config/required-inference-fields-contract-check.d.ts +24 -0
- package/src/config/required-inference-fields-contract-check.js +237 -0
- package/src/config/schema/browser-suite-metrics.schema.d.ts +17 -0
- package/src/config/schema/browser-suite-metrics.schema.js +46 -0
- package/src/config/schema/conversion-report.schema.d.ts +40 -0
- package/src/config/schema/conversion-report.schema.js +108 -0
- package/src/config/schema/doppler.schema.js +12 -18
- package/src/config/schema/index.d.ts +22 -0
- package/src/config/schema/index.js +18 -0
- package/src/config/schema/inference-defaults.schema.js +3 -0
- package/src/config/schema/inference.schema.d.ts +9 -0
- package/src/config/schema/kernel-path.schema.d.ts +6 -0
- package/src/config/schema/manifest.schema.d.ts +6 -0
- package/src/config/schema/manifest.schema.js +3 -0
- package/src/converter/core.d.ts +10 -0
- package/src/converter/core.js +27 -2
- package/src/converter/parsers/diffusion.js +63 -3
- package/src/converter/rope-config.js +42 -0
- package/src/gpu/device.js +58 -0
- package/src/gpu/kernels/attention.js +98 -0
- package/src/gpu/kernels/bias_add.wgsl +8 -6
- package/src/gpu/kernels/bias_add_f16.wgsl +8 -5
- package/src/gpu/kernels/conv2d.js +1 -1
- package/src/gpu/kernels/conv2d.wgsl +7 -8
- package/src/gpu/kernels/conv2d_f16.wgsl +7 -8
- package/src/gpu/kernels/depthwise_conv2d.d.ts +29 -0
- package/src/gpu/kernels/depthwise_conv2d.js +99 -0
- package/src/gpu/kernels/depthwise_conv2d.wgsl +55 -0
- package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +59 -0
- package/src/gpu/kernels/grouped_pointwise_conv2d.d.ts +27 -0
- package/src/gpu/kernels/grouped_pointwise_conv2d.js +93 -0
- package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +44 -0
- package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +48 -0
- package/src/gpu/kernels/index.d.ts +30 -0
- package/src/gpu/kernels/index.js +25 -0
- package/src/gpu/kernels/matmul.js +25 -0
- package/src/gpu/kernels/pixel_shuffle.js +1 -1
- package/src/gpu/kernels/pixel_shuffle.wgsl +4 -5
- package/src/gpu/kernels/pixel_shuffle_f16.wgsl +4 -5
- package/src/gpu/kernels/relu.d.ts +18 -0
- package/src/gpu/kernels/relu.js +58 -0
- package/src/gpu/kernels/relu.wgsl +22 -0
- package/src/gpu/kernels/relu_f16.wgsl +24 -0
- package/src/gpu/kernels/repeat_channels.d.ts +21 -0
- package/src/gpu/kernels/repeat_channels.js +60 -0
- package/src/gpu/kernels/repeat_channels.wgsl +28 -0
- package/src/gpu/kernels/repeat_channels_f16.wgsl +30 -0
- package/src/gpu/kernels/residual.js +44 -8
- package/src/gpu/kernels/residual.wgsl +6 -3
- package/src/gpu/kernels/residual_f16.wgsl +2 -1
- package/src/gpu/kernels/residual_f16_vec4.wgsl +2 -1
- package/src/gpu/kernels/residual_vec4.wgsl +2 -1
- package/src/gpu/kernels/rmsnorm.js +58 -6
- package/src/gpu/kernels/rmsnorm.wgsl +14 -6
- package/src/gpu/kernels/rmsnorm_f16.wgsl +10 -2
- package/src/gpu/kernels/rope.d.ts +2 -0
- package/src/gpu/kernels/rope.js +11 -1
- package/src/gpu/kernels/rope.wgsl +56 -40
- package/src/gpu/kernels/sana_linear_attention.d.ts +27 -0
- package/src/gpu/kernels/sana_linear_attention.js +121 -0
- package/src/gpu/kernels/sana_linear_attention_apply.wgsl +43 -0
- package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +46 -0
- package/src/gpu/kernels/sana_linear_attention_summary.wgsl +51 -0
- package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +53 -0
- package/src/gpu/kernels/silu.d.ts +1 -0
- package/src/gpu/kernels/silu.js +32 -14
- package/src/gpu/kernels/silu.wgsl +19 -9
- package/src/gpu/kernels/silu_f16.wgsl +19 -9
- package/src/gpu/kernels/transpose.js +15 -2
- package/src/gpu/kernels/transpose.wgsl +5 -6
- package/src/gpu/kernels/upsample2d.js +2 -1
- package/src/gpu/kernels/upsample2d.wgsl +6 -9
- package/src/gpu/kernels/upsample2d_f16.wgsl +6 -9
- package/src/gpu/kernels/utils.js +16 -1
- package/src/index-browser.d.ts +1 -1
- package/src/index-browser.js +2 -2
- package/src/index.js +1 -1
- package/src/inference/browser-harness.js +109 -23
- package/src/inference/pipelines/diffusion/init.js +14 -0
- package/src/inference/pipelines/diffusion/pipeline.js +215 -77
- package/src/inference/pipelines/diffusion/sana-transformer.d.ts +53 -0
- package/src/inference/pipelines/diffusion/sana-transformer.js +738 -0
- package/src/inference/pipelines/diffusion/scheduler.d.ts +17 -1
- package/src/inference/pipelines/diffusion/scheduler.js +91 -3
- package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +11 -4
- package/src/inference/pipelines/diffusion/text-encoder-gpu.js +282 -0
- package/src/inference/pipelines/diffusion/text-encoder.js +18 -1
- package/src/inference/pipelines/diffusion/types.d.ts +4 -0
- package/src/inference/pipelines/diffusion/vae.js +782 -78
- package/src/inference/pipelines/text/attention/record.js +11 -2
- package/src/inference/pipelines/text/attention/run.js +11 -2
- package/src/inference/pipelines/text/chat-format.js +25 -1
- package/src/inference/pipelines/text/config.d.ts +9 -0
- package/src/inference/pipelines/text/config.js +69 -2
- package/src/inference/pipelines/text/execution-plan.js +23 -31
- package/src/inference/pipelines/text/execution-v0.js +43 -95
- package/src/inference/pipelines/text/ffn/standard.js +3 -0
- package/src/inference/pipelines/text/init.d.ts +4 -0
- package/src/inference/pipelines/text/init.js +56 -9
- package/src/inference/pipelines/text/layer.js +11 -0
- package/src/inference/pipelines/text.js +4 -0
- package/src/inference/tokenizers/bundled.js +156 -33
- package/src/rules/execution-rules-contract-check.d.ts +17 -0
- package/src/rules/execution-rules-contract-check.js +245 -0
- package/src/rules/kernels/depthwise-conv2d.rules.json +6 -0
- package/src/rules/kernels/grouped-pointwise-conv2d.rules.json +6 -0
- package/src/rules/kernels/relu.rules.json +6 -0
- package/src/rules/kernels/repeat-channels.rules.json +6 -0
- package/src/rules/kernels/sana-linear-attention.rules.json +6 -0
- package/src/rules/layer-pattern-contract-check.d.ts +17 -0
- package/src/rules/layer-pattern-contract-check.js +231 -0
- package/src/rules/rule-registry.d.ts +28 -0
- package/src/rules/rule-registry.js +38 -0
- package/src/rules/tooling/command-runtime.rules.json +18 -0
- package/src/tooling/command-api.d.ts +27 -1
- package/src/tooling/command-api.js +142 -3
- package/src/tooling/conversion-config-materializer.d.ts +24 -0
- package/src/tooling/conversion-config-materializer.js +99 -0
- package/src/tooling/lean-execution-contract-runner.d.ts +43 -0
- package/src/tooling/lean-execution-contract-runner.js +158 -0
- package/src/tooling/node-browser-command-runner.d.ts +4 -0
- package/src/tooling/node-browser-command-runner.js +58 -3
- package/src/tooling/node-command-runner.js +15 -0
- package/src/tooling/node-convert.d.ts +10 -0
- package/src/tooling/node-converter.js +59 -0
- package/src/tooling/node-webgpu.js +11 -89
- package/src/training/checkpoint-watch.d.ts +7 -0
- package/src/training/checkpoint-watch.js +106 -0
- package/src/training/checkpoint.d.ts +6 -1
- package/src/training/checkpoint.js +12 -2
- package/src/training/distillation/artifacts.d.ts +71 -0
- package/src/training/distillation/artifacts.js +132 -0
- package/src/training/distillation/checkpoint-watch.d.ts +10 -0
- package/src/training/distillation/checkpoint-watch.js +57 -0
- package/src/training/distillation/dataset.d.ts +59 -0
- package/src/training/distillation/dataset.js +337 -0
- package/src/training/distillation/eval.d.ts +34 -0
- package/src/training/distillation/eval.js +310 -0
- package/src/training/distillation/index.d.ts +29 -0
- package/src/training/distillation/index.js +29 -0
- package/src/training/distillation/runtime.d.ts +20 -0
- package/src/training/distillation/runtime.js +121 -0
- package/src/training/distillation/scoreboard.d.ts +6 -0
- package/src/training/distillation/scoreboard.js +8 -0
- package/src/training/distillation/stage-a.d.ts +45 -0
- package/src/training/distillation/stage-a.js +338 -0
- package/src/training/distillation/stage-b.d.ts +24 -0
- package/src/training/distillation/stage-b.js +20 -0
- package/src/training/index.d.ts +10 -0
- package/src/training/index.js +10 -0
- package/src/training/lora-pipeline.d.ts +40 -0
- package/src/training/lora-pipeline.js +796 -0
- package/src/training/operator-artifacts.d.ts +62 -0
- package/src/training/operator-artifacts.js +140 -0
- package/src/training/operator-command.d.ts +5 -0
- package/src/training/operator-command.js +453 -0
- package/src/training/operator-eval.d.ts +48 -0
- package/src/training/operator-eval.js +230 -0
- package/src/training/operator-scoreboard.d.ts +5 -0
- package/src/training/operator-scoreboard.js +44 -0
- package/src/training/runner.d.ts +52 -0
- package/src/training/runner.js +29 -4
- package/src/training/suite.d.ts +112 -0
- package/src/training/suite.js +9 -9
- package/src/training/workloads.d.ts +164 -0
- package/src/training/workloads.js +539 -0
- package/src/version.d.ts +2 -0
- package/src/version.js +2 -0
- package/tools/convert-safetensors-node.js +47 -0
- package/tools/doppler-cli.js +252 -41
|
@@ -14,10 +14,11 @@ import {
|
|
|
14
14
|
projectContext,
|
|
15
15
|
assertClipHiddenActivationSupported,
|
|
16
16
|
} from './text-encoder-gpu.js';
|
|
17
|
-
import { buildScheduler } from './scheduler.js';
|
|
17
|
+
import { buildScheduler, stepScmScheduler } from './scheduler.js';
|
|
18
18
|
import { decodeLatents } from './vae.js';
|
|
19
19
|
import { createDiffusionWeightLoader } from './weights.js';
|
|
20
20
|
import { runSD3Transformer } from './sd3-transformer.js';
|
|
21
|
+
import { runSanaTransformer, buildSanaTimestepConditioning, projectSanaContext } from './sana-transformer.js';
|
|
21
22
|
import { createSD3WeightResolver } from './sd3-weights.js';
|
|
22
23
|
import { createTensor, dtypeBytes } from '../../../gpu/tensor.js';
|
|
23
24
|
import { acquireBuffer, releaseBuffer, readBuffer } from '../../../memory/buffer-pool.js';
|
|
@@ -27,6 +28,8 @@ import { runResidualAdd, runScale, recordResidualAdd, recordScale } from '../../
|
|
|
27
28
|
import { f16ToF32 } from '../../../loader/dtype-utils.js';
|
|
28
29
|
|
|
29
30
|
const SUPPORTED_DIFFUSION_BACKEND_PIPELINES = new Set(['gpu']);
|
|
31
|
+
const SD3_TEXT_ENCODER_KEYS = ['text_encoder', 'text_encoder_2', 'text_encoder_3'];
|
|
32
|
+
const SANA_TEXT_ENCODER_KEYS = ['text_encoder'];
|
|
30
33
|
|
|
31
34
|
function createRandomSeed() {
|
|
32
35
|
if (typeof crypto !== 'undefined' && typeof crypto.getRandomValues === 'function') {
|
|
@@ -49,6 +52,18 @@ function generateLatents(width, height, channels, latentScale, seed) {
|
|
|
49
52
|
return { latents, latentWidth, latentHeight };
|
|
50
53
|
}
|
|
51
54
|
|
|
55
|
+
function generateNoiseVector(size, seed) {
|
|
56
|
+
if (!Number.isFinite(size) || size <= 0) {
|
|
57
|
+
throw new Error(`generateNoiseVector requires a positive size, got ${size}.`);
|
|
58
|
+
}
|
|
59
|
+
const out = new Float32Array(size);
|
|
60
|
+
const rand = createRng(seed ?? createRandomSeed());
|
|
61
|
+
for (let i = 0; i < size; i++) {
|
|
62
|
+
out[i] = sampleNormal(rand);
|
|
63
|
+
}
|
|
64
|
+
return out;
|
|
65
|
+
}
|
|
66
|
+
|
|
52
67
|
function extractTokenSet(tokensByEncoder, key) {
|
|
53
68
|
const output = {};
|
|
54
69
|
for (const [name, entry] of Object.entries(tokensByEncoder || {})) {
|
|
@@ -58,6 +73,49 @@ function extractTokenSet(tokensByEncoder, key) {
|
|
|
58
73
|
return output;
|
|
59
74
|
}
|
|
60
75
|
|
|
76
|
+
/**
 * Return the diffusion layout declared by the model config.
 * A missing config, or a config whose `layout` is null/undefined, yields 'sd3'.
 */
function resolveDiffusionLayout(modelConfig) {
  const declared = modelConfig == null ? undefined : modelConfig.layout;
  return declared ?? 'sd3';
}
|
|
79
|
+
|
|
80
|
+
/**
 * Map a diffusion layout to the text-encoder component keys it requires.
 * 'sana' uses the single-encoder key list. Every other layout gets the SD3
 * three-encoder key list. The shared module-level arrays are returned, not copies.
 */
function getTextEncoderKeysForLayout(layout) {
  return layout === 'sana' ? SANA_TEXT_ENCODER_KEYS : SD3_TEXT_ENCODER_KEYS;
}
|
|
86
|
+
|
|
87
|
+
function assertLayoutTextEncoderContract(layout, modelConfig, tokenizers) {
|
|
88
|
+
const requiredKeys = getTextEncoderKeysForLayout(layout);
|
|
89
|
+
for (const key of requiredKeys) {
|
|
90
|
+
if (!modelConfig?.components?.[key]) {
|
|
91
|
+
throw new Error(`Diffusion GPU pipeline requires component "${key}" for layout "${layout}".`);
|
|
92
|
+
}
|
|
93
|
+
if (!tokenizers?.[key]) {
|
|
94
|
+
throw new Error(`Diffusion GPU pipeline requires tokenizer "${key}" for layout "${layout}".`);
|
|
95
|
+
}
|
|
96
|
+
}
|
|
97
|
+
}
|
|
98
|
+
|
|
99
|
+
/**
 * Build the per-tokenizer maximum sequence lengths for prompt encoding.
 *
 * - 'sana': only `text_encoder`, using runtime.textEncoder.maxLength.
 * - otherwise: `text_encoder` and `text_encoder_2` use maxLength; `text_encoder_3`
 *   uses runtime.textEncoder.t5MaxLength, falling back to maxLength.
 *
 * @throws {Error} When maxLength (or the resolved T5 length) is not a positive finite number.
 */
function buildTokenizerMaxLengths(layout, runtime) {
  const encoderRuntime = runtime?.textEncoder;
  const isPositiveFinite = (value) => Number.isFinite(value) && value > 0;

  const clipLength = encoderRuntime?.maxLength;
  if (!isPositiveFinite(clipLength)) {
    throw new Error('Diffusion runtime requires runtime.textEncoder.maxLength.');
  }
  if (layout === 'sana') {
    return { text_encoder: clipLength };
  }

  const t5Length = encoderRuntime?.t5MaxLength ?? clipLength;
  if (!isPositiveFinite(t5Length)) {
    throw new Error('Diffusion runtime requires runtime.textEncoder.t5MaxLength (or runtime.textEncoder.maxLength).');
  }
  return {
    text_encoder: clipLength,
    text_encoder_2: clipLength,
    text_encoder_3: t5Length,
  };
}
|
|
118
|
+
|
|
61
119
|
function getTensorSize(shape) {
|
|
62
120
|
if (!Array.isArray(shape)) return 0;
|
|
63
121
|
return shape.reduce((acc, value) => acc * value, 1);
|
|
@@ -120,6 +178,46 @@ async function readTensorToFloat32(tensor) {
|
|
|
120
178
|
return new Float32Array(data);
|
|
121
179
|
}
|
|
122
180
|
|
|
181
|
+
/**
 * Advance the latents by one scheduler step and return a new latents tensor.
 *
 * - 'flowmatch_euler': latents + (sigma[i+1] - sigma[i]) * prediction, computed with
 *   the injectable `scale` / `residualAdd` helpers. After the last step, sigma is 0.
 * - 'scm': reads latents and prediction back to Float32Arrays, calls
 *   stepScmScheduler, and uploads `prevSample` via createLatentTensor. Noise is
 *   generated for every step except the final one.
 *
 * In both branches the incoming latents and prediction buffers are handed to the
 * same release callback. They must not be used after this call.
 *
 * @param latentsTensor - Current latents; its shape is reused for the result.
 * @param scheduler - Object from buildScheduler (`type`, `steps`, `sigmas`, `timesteps`).
 * @param {number} stepIndex - Zero-based step index.
 * @param {number} timestep - Timestep for this step (forwarded to stepScmScheduler).
 * @param predictionTensor - Model prediction for this step.
 * @param runtime - Diffusion runtime config (forwarded to createLatentTensor).
 * @param {object} [options]
 * @param [options.scale] - Replacement for runScale (e.g. a recorder-backed variant).
 * @param [options.residualAdd] - Replacement for runResidualAdd.
 * @param [options.release] - Buffer release callback; defaults to releaseBuffer.
 * @param {number} [options.seedBase] - Base seed for SCM noise. Step i uses
 *   seedBase + i + 1. When omitted, a fresh random seed is drawn per step.
 * @returns New latents tensor.
 * @throws {Error} For an unsupported scheduler.type.
 */
async function applySchedulerStep(latentsTensor, scheduler, stepIndex, timestep, predictionTensor, runtime, options = {}) {
  // Both branches release through the caller's callback. A recorder-aware caller
  // can then defer freeing until recorded GPU work referencing the buffers is submitted.
  const release = options.release ?? releaseBuffer;

  if (scheduler.type === 'flowmatch_euler') {
    const sigma = scheduler.sigmas[stepIndex];
    const sigmaNext = stepIndex + 1 < scheduler.steps ? scheduler.sigmas[stepIndex + 1] : 0;
    const delta = sigmaNext - sigma;
    const latentSize = getTensorSize(latentsTensor.shape);
    const scale = options.scale ?? runScale;
    const residualAdd = options.residualAdd ?? runResidualAdd;

    const scaled = await scale(predictionTensor, delta, { count: latentSize });
    const updated = await residualAdd(latentsTensor, scaled, latentSize, { useVec4: true });

    release(latentsTensor.buffer);
    release(scaled.buffer);
    release(predictionTensor.buffer);

    return createTensor(updated.buffer, updated.dtype, [...latentsTensor.shape], 'diffusion_latents');
  }

  if (scheduler.type === 'scm') {
    // NOTE(review): these readbacks execute immediately. If the prediction was
    // recorded on a CommandRecorder that the caller submits only after this function
    // returns, the data read here may predate that work. Confirm the SCM path is
    // only used without a step recorder, or submit before reading back.
    const sample = await readTensorToFloat32(latentsTensor);
    const modelOutput = await readTensorToFloat32(predictionTensor);
    release(predictionTensor.buffer);
    release(latentsTensor.buffer);

    const isFinalStep = stepIndex + 1 >= scheduler.timesteps.length - 1;
    const noise = isFinalStep
      ? null
      : generateNoiseVector(
        sample.length,
        (options.seedBase ?? createRandomSeed()) + stepIndex + 1
      );
    const step = stepScmScheduler(scheduler, modelOutput, timestep, sample, stepIndex, noise);
    return createLatentTensor(step.prevSample, [...latentsTensor.shape], runtime);
  }

  throw new Error(`Unsupported diffusion scheduler.type "${scheduler.type}".`);
}
|
|
220
|
+
|
|
123
221
|
async function applyGuidance(uncond, cond, guidanceScale, size, options = {}) {
|
|
124
222
|
if (!uncond || !Number.isFinite(guidanceScale) || guidanceScale <= 1) {
|
|
125
223
|
return cond;
|
|
@@ -251,14 +349,17 @@ export class DiffusionPipeline {
|
|
|
251
349
|
});
|
|
252
350
|
}
|
|
253
351
|
|
|
254
|
-
const
|
|
255
|
-
const
|
|
256
|
-
const
|
|
352
|
+
const layout = resolveDiffusionLayout(this.diffusionState?.modelConfig);
|
|
353
|
+
const requiredKeys = getTextEncoderKeysForLayout(layout);
|
|
354
|
+
const weights = {};
|
|
355
|
+
for (const key of requiredKeys) {
|
|
356
|
+
weights[key] = await this.weightLoader.loadComponentWeights(key);
|
|
357
|
+
}
|
|
257
358
|
|
|
258
359
|
this.textEncoderWeights = {
|
|
259
|
-
text_encoder,
|
|
260
|
-
text_encoder_2,
|
|
261
|
-
text_encoder_3,
|
|
360
|
+
text_encoder: weights.text_encoder ?? null,
|
|
361
|
+
text_encoder_2: weights.text_encoder_2 ?? null,
|
|
362
|
+
text_encoder_3: weights.text_encoder_3 ?? null,
|
|
262
363
|
};
|
|
263
364
|
|
|
264
365
|
return this.textEncoderWeights;
|
|
@@ -315,14 +416,9 @@ export class DiffusionPipeline {
|
|
|
315
416
|
async generateGPU(request = {}) {
|
|
316
417
|
const start = performance.now();
|
|
317
418
|
const runtime = this.diffusionState.runtime;
|
|
318
|
-
const
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
}
|
|
322
|
-
const t5MaxLength = runtime.textEncoder?.t5MaxLength ?? clipMaxLength;
|
|
323
|
-
if (!Number.isFinite(t5MaxLength) || t5MaxLength <= 0) {
|
|
324
|
-
throw new Error('Diffusion runtime requires runtime.textEncoder.t5MaxLength (or runtime.textEncoder.maxLength).');
|
|
325
|
-
}
|
|
419
|
+
const modelConfig = this.diffusionState.modelConfig;
|
|
420
|
+
const layout = resolveDiffusionLayout(modelConfig);
|
|
421
|
+
const tokenizerMaxLengths = buildTokenizerMaxLengths(layout, runtime);
|
|
326
422
|
|
|
327
423
|
const defaultWidth = runtime.latent.width;
|
|
328
424
|
const defaultHeight = runtime.latent.height;
|
|
@@ -346,28 +442,20 @@ export class DiffusionPipeline {
|
|
|
346
442
|
throw new Error(`Invalid diffusion steps: ${steps}`);
|
|
347
443
|
}
|
|
348
444
|
|
|
349
|
-
const modelConfig = this.diffusionState.modelConfig;
|
|
350
445
|
if (!modelConfig?.components?.transformer) {
|
|
351
446
|
throw new Error('Diffusion GPU pipeline requires transformer component config.');
|
|
352
447
|
}
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
if (!this.tokenizers?.text_encoder || !this.tokenizers?.text_encoder_2 || !this.tokenizers?.text_encoder_3) {
|
|
357
|
-
throw new Error('Diffusion GPU pipeline requires tokenizers for text_encoder, text_encoder_2, and text_encoder_3.');
|
|
448
|
+
assertLayoutTextEncoderContract(layout, modelConfig, this.tokenizers);
|
|
449
|
+
if (layout === 'sd3') {
|
|
450
|
+
assertClipHiddenActivationSupported(modelConfig?.components?.text_encoder?.config || {});
|
|
358
451
|
}
|
|
359
|
-
assertClipHiddenActivationSupported(modelConfig?.components?.text_encoder?.config || {});
|
|
360
452
|
|
|
361
453
|
const promptStart = performance.now();
|
|
362
454
|
const encoded = encodePrompt(
|
|
363
455
|
{ prompt: request.prompt ?? '', negativePrompt: request.negativePrompt ?? '' },
|
|
364
456
|
this.tokenizers || {},
|
|
365
457
|
{
|
|
366
|
-
maxLengthByTokenizer:
|
|
367
|
-
text_encoder: clipMaxLength,
|
|
368
|
-
text_encoder_2: clipMaxLength,
|
|
369
|
-
text_encoder_3: t5MaxLength,
|
|
370
|
-
},
|
|
458
|
+
maxLengthByTokenizer: tokenizerMaxLengths,
|
|
371
459
|
}
|
|
372
460
|
);
|
|
373
461
|
|
|
@@ -410,13 +498,31 @@ export class DiffusionPipeline {
|
|
|
410
498
|
const prefillRecorder = canProfileGpu
|
|
411
499
|
? new CommandRecorder(getDevice(), 'diffusion_prefill', { profile: true })
|
|
412
500
|
: null;
|
|
413
|
-
const condContext =
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
501
|
+
const condContext = layout === 'sana'
|
|
502
|
+
? await projectSanaContext(
|
|
503
|
+
promptCondition.context,
|
|
504
|
+
promptCondition.attentionMask,
|
|
505
|
+
transformerWeights,
|
|
506
|
+
transformerConfig,
|
|
507
|
+
runtime,
|
|
508
|
+
{ recorder: prefillRecorder }
|
|
509
|
+
)
|
|
510
|
+
: await projectContext(promptCondition.context, transformerWeights, modelConfig, runtime, {
|
|
418
511
|
recorder: prefillRecorder,
|
|
419
|
-
})
|
|
512
|
+
});
|
|
513
|
+
const uncondContext = shouldUseUncond && negativeCondition
|
|
514
|
+
? layout === 'sana'
|
|
515
|
+
? await projectSanaContext(
|
|
516
|
+
negativeCondition.context,
|
|
517
|
+
negativeCondition.attentionMask,
|
|
518
|
+
transformerWeights,
|
|
519
|
+
transformerConfig,
|
|
520
|
+
runtime,
|
|
521
|
+
{ recorder: prefillRecorder }
|
|
522
|
+
)
|
|
523
|
+
: await projectContext(negativeCondition.context, transformerWeights, modelConfig, runtime, {
|
|
524
|
+
recorder: prefillRecorder,
|
|
525
|
+
})
|
|
420
526
|
: null;
|
|
421
527
|
if (prefillRecorder) {
|
|
422
528
|
prefillRecorder.submit();
|
|
@@ -428,11 +534,6 @@ export class DiffusionPipeline {
|
|
|
428
534
|
}
|
|
429
535
|
|
|
430
536
|
const scheduler = buildScheduler(runtime.scheduler, steps);
|
|
431
|
-
if (scheduler.type !== 'flowmatch_euler') {
|
|
432
|
-
throw new Error(
|
|
433
|
-
`Diffusion GPU pipeline requires scheduler.type="flowmatch_euler"; got "${scheduler.type}".`
|
|
434
|
-
);
|
|
435
|
-
}
|
|
436
537
|
const latentScale = this.diffusionState.latentScale;
|
|
437
538
|
const latentChannels = this.diffusionState.latentChannels;
|
|
438
539
|
const { latents, latentWidth, latentHeight } = generateLatents(width, height, latentChannels, latentScale, seed);
|
|
@@ -463,9 +564,6 @@ export class DiffusionPipeline {
|
|
|
463
564
|
const latentSize = latentChannels * latentHeight * latentWidth;
|
|
464
565
|
for (let i = 0; i < scheduler.steps; i++) {
|
|
465
566
|
const timestep = scheduler.timesteps[i];
|
|
466
|
-
const sigma = scheduler.sigmas[i];
|
|
467
|
-
const sigmaNext = i + 1 < scheduler.steps ? scheduler.sigmas[i + 1] : 0;
|
|
468
|
-
const delta = sigmaNext - sigma;
|
|
469
567
|
const stepRecorder = canProfileGpu
|
|
470
568
|
? new CommandRecorder(getDevice(), `diffusion_step_${i}`, { profile: true })
|
|
471
569
|
: null;
|
|
@@ -477,37 +575,71 @@ export class DiffusionPipeline {
|
|
|
477
575
|
? (left, right, count, options) => recordResidualAdd(stepRecorder, left, right, count, options)
|
|
478
576
|
: runResidualAdd;
|
|
479
577
|
|
|
480
|
-
const
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
578
|
+
const condPred = layout === 'sana'
|
|
579
|
+
? await (async () => {
|
|
580
|
+
const timeState = await buildSanaTimestepConditioning(
|
|
581
|
+
timestep * (transformerConfig.timestep_scale ?? 1.0),
|
|
582
|
+
guidanceScale,
|
|
583
|
+
transformerWeights,
|
|
584
|
+
transformerConfig,
|
|
585
|
+
runtime,
|
|
586
|
+
{ recorder: stepRecorder }
|
|
587
|
+
);
|
|
588
|
+
return runSanaTransformer(latentsTensor, condContext, timeState, transformerWeights, modelConfig, runtime, {
|
|
589
|
+
recorder: stepRecorder,
|
|
590
|
+
});
|
|
591
|
+
})()
|
|
592
|
+
: await (async () => {
|
|
593
|
+
const timeCond = await buildTimestepEmbedding(timestep, transformerWeights, modelConfig, runtime, {
|
|
594
|
+
dim: timeEmbedDim,
|
|
595
|
+
recorder: stepRecorder,
|
|
596
|
+
});
|
|
597
|
+
const textCond = await buildTimeTextEmbedding(promptCondition.pooled, transformerWeights, modelConfig, runtime, {
|
|
598
|
+
recorder: stepRecorder,
|
|
599
|
+
});
|
|
600
|
+
const timeTextCond = await combineTimeTextEmbeddings(timeCond, textCond, hiddenSize, {
|
|
601
|
+
recorder: stepRecorder,
|
|
602
|
+
});
|
|
603
|
+
const output = await runSD3Transformer(latentsTensor, condContext, timeTextCond, transformerWeights, modelConfig, runtime, {
|
|
604
|
+
recorder: stepRecorder,
|
|
605
|
+
});
|
|
606
|
+
releaseStep(timeTextCond.buffer);
|
|
607
|
+
return output;
|
|
608
|
+
})();
|
|
494
609
|
|
|
495
610
|
let pred = condPred;
|
|
496
611
|
if (shouldUseUncond && uncondContext && negativeCondition) {
|
|
497
|
-
const
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
612
|
+
const uncondPred = layout === 'sana'
|
|
613
|
+
? await (async () => {
|
|
614
|
+
const timeState = await buildSanaTimestepConditioning(
|
|
615
|
+
timestep * (transformerConfig.timestep_scale ?? 1.0),
|
|
616
|
+
guidanceScale,
|
|
617
|
+
transformerWeights,
|
|
618
|
+
transformerConfig,
|
|
619
|
+
runtime,
|
|
620
|
+
{ recorder: stepRecorder }
|
|
621
|
+
);
|
|
622
|
+
return runSanaTransformer(latentsTensor, uncondContext, timeState, transformerWeights, modelConfig, runtime, {
|
|
623
|
+
recorder: stepRecorder,
|
|
624
|
+
});
|
|
625
|
+
})()
|
|
626
|
+
: await (async () => {
|
|
627
|
+
const timeUncond = await buildTimestepEmbedding(timestep, transformerWeights, modelConfig, runtime, {
|
|
628
|
+
dim: timeEmbedDim,
|
|
629
|
+
recorder: stepRecorder,
|
|
630
|
+
});
|
|
631
|
+
const textUncond = await buildTimeTextEmbedding(negativeCondition.pooled, transformerWeights, modelConfig, runtime, {
|
|
632
|
+
recorder: stepRecorder,
|
|
633
|
+
});
|
|
634
|
+
const timeTextUncond = await combineTimeTextEmbeddings(timeUncond, textUncond, hiddenSize, {
|
|
635
|
+
recorder: stepRecorder,
|
|
636
|
+
});
|
|
637
|
+
const output = await runSD3Transformer(latentsTensor, uncondContext, timeTextUncond, transformerWeights, modelConfig, runtime, {
|
|
638
|
+
recorder: stepRecorder,
|
|
639
|
+
});
|
|
640
|
+
releaseStep(timeTextUncond.buffer);
|
|
641
|
+
return output;
|
|
642
|
+
})();
|
|
511
643
|
pred = await applyGuidance(uncondPred, condPred, guidanceScale, latentSize, {
|
|
512
644
|
recorder: stepRecorder,
|
|
513
645
|
release: releaseStep,
|
|
@@ -516,14 +648,20 @@ export class DiffusionPipeline {
|
|
|
516
648
|
releaseStep(condPred.buffer);
|
|
517
649
|
}
|
|
518
650
|
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
651
|
+
latentsTensor = await applySchedulerStep(
|
|
652
|
+
latentsTensor,
|
|
653
|
+
scheduler,
|
|
654
|
+
i,
|
|
655
|
+
timestep,
|
|
656
|
+
pred,
|
|
657
|
+
runtime,
|
|
658
|
+
{
|
|
659
|
+
scale,
|
|
660
|
+
residualAdd,
|
|
661
|
+
release: releaseStep,
|
|
662
|
+
seedBase: seed,
|
|
663
|
+
}
|
|
664
|
+
);
|
|
527
665
|
|
|
528
666
|
if (stepRecorder) {
|
|
529
667
|
stepRecorder.submit();
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
import type { Tensor } from '../../../gpu/tensor.js';
|
|
2
|
+
import type { CommandRecorder } from '../../../gpu/command-recorder.js';
|
|
3
|
+
|
|
4
|
+
/**
 * Per-timestep conditioning returned by `buildSanaTimestepConditioning` and
 * passed as `timeState` to `runSanaTransformer`.
 */
export interface SanaTimestepState {
  /** Embedded timestep tensor. */
  embeddedTimestep: Tensor;
  /** Modulation tensor derived from the timestep. */
  modulation: Tensor;
}
|
|
8
|
+
|
|
9
|
+
/**
 * Options shared by the Sana transformer entry points.
 */
export interface SanaTransformerOptions {
  /** Command recorder to record GPU work into. Omitted or null means none. */
  recorder?: null | CommandRecorder;
}
|
|
12
|
+
|
|
13
|
+
/**
 * Build the timestep conditioning state for one Sana denoising step.
 *
 * @param timestep - Timestep value. The diffusion pipeline pre-multiplies it by
 *   `transformerConfig.timestep_scale` (default 1.0) before calling.
 * @param guidanceScale - Classifier-free guidance scale for the request.
 * @param weightsEntry - Transformer weights entry (opaque to callers).
 * @param config - Transformer component config.
 * @param runtime - Diffusion runtime config.
 * @param options - Optional command recorder.
 * @returns State consumed as `timeState` by `runSanaTransformer`.
 */
export declare function buildSanaTimestepConditioning(
  timestep: number,
  guidanceScale: number,
  weightsEntry: unknown,
  config: unknown,
  runtime: unknown,
  options?: SanaTransformerOptions
): Promise<SanaTimestepState>;
|
|
21
|
+
|
|
22
|
+
/**
 * Project the text-encoder context into the form `runSanaTransformer` expects
 * as its `context` argument.
 *
 * @param context - Encoded prompt hidden states.
 * @param attentionMask - Per-token mask from prompt encoding, if available.
 *   NOTE(review): presumably nonzero = attend; confirm against the implementation.
 * @param weightsEntry - Transformer weights entry (opaque to callers).
 * @param config - Transformer component config.
 * @param runtime - Diffusion runtime config.
 * @param options - Optional command recorder.
 * @returns Projected context tensor.
 */
export declare function projectSanaContext(
  context: Tensor,
  attentionMask: Uint32Array | null | undefined,
  weightsEntry: unknown,
  config: unknown,
  runtime: unknown,
  options?: SanaTransformerOptions
): Promise<Tensor>;
|
|
30
|
+
|
|
31
|
+
/**
 * Run the Sana transformer on the latents for one denoising step.
 *
 * @param latents - Current latent tensor.
 * @param context - Context tensor produced by `projectSanaContext`.
 * @param timeState - Conditioning produced by `buildSanaTimestepConditioning`.
 * @param weightsEntry - Transformer weights entry (opaque to callers).
 * @param modelConfig - Full diffusion model config.
 * @param runtime - Diffusion runtime config.
 * @param options - Optional command recorder.
 * @returns The model prediction tensor for this step.
 */
export declare function runSanaTransformer(
  latents: Tensor,
  context: Tensor,
  timeState: SanaTimestepState,
  weightsEntry: unknown,
  modelConfig: unknown,
  runtime: unknown,
  options?: SanaTransformerOptions
): Promise<Tensor>;
|
|
40
|
+
|
|
41
|
+
/**
 * Build the projected context and the timestep state in a single call.
 * NOTE(review): presumably equivalent to calling `projectSanaContext` and
 * `buildSanaTimestepConditioning`; confirm against the implementation.
 *
 * @param context - Encoded prompt hidden states.
 * @param attentionMask - Per-token mask from prompt encoding, if available.
 * @param timestep - Timestep value for this step.
 * @param guidanceScale - Classifier-free guidance scale for the request.
 * @param weightsEntry - Transformer weights entry (opaque to callers).
 * @param modelConfig - Diffusion model config.
 * @param runtime - Diffusion runtime config.
 * @param options - Optional command recorder.
 * @returns The projected context and the timestep state.
 */
export declare function buildSanaConditioning(
  context: Tensor,
  attentionMask: Uint32Array | null | undefined,
  timestep: number,
  guidanceScale: number,
  weightsEntry: unknown,
  modelConfig: unknown,
  runtime: unknown,
  options?: SanaTransformerOptions
): Promise<{
  context: Tensor;
  timeState: SanaTimestepState;
}>;
|