@simulatte/doppler 0.1.3 → 0.1.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +11 -5
- package/package.json +27 -4
- package/src/client/doppler-api.browser.d.ts +1 -0
- package/src/client/doppler-api.browser.js +288 -0
- package/src/client/doppler-api.d.ts +80 -0
- package/src/client/doppler-api.js +298 -0
- package/src/client/doppler-provider/types.js +1 -1
- package/src/client/doppler-registry.d.ts +23 -0
- package/src/client/doppler-registry.js +88 -0
- package/src/client/doppler-registry.json +39 -0
- package/src/config/execution-contract-check.d.ts +82 -0
- package/src/config/execution-contract-check.js +317 -0
- package/src/config/execution-v0-contract-check.d.ts +94 -0
- package/src/config/execution-v0-contract-check.js +251 -0
- package/src/config/execution-v0-graph-contract-check.d.ts +20 -0
- package/src/config/execution-v0-graph-contract-check.js +64 -0
- package/src/config/kernel-path-contract-check.d.ts +76 -0
- package/src/config/kernel-path-contract-check.js +479 -0
- package/src/config/kernel-path-loader.d.ts +16 -0
- package/src/config/kernel-path-loader.js +54 -0
- package/src/config/kernels/kernel-ref-digests.js +12 -0
- package/src/config/kernels/registry.json +556 -0
- package/src/config/loader.js +90 -67
- package/src/config/merge-contract-check.d.ts +16 -0
- package/src/config/merge-contract-check.js +321 -0
- package/src/config/merge-helpers.d.ts +58 -0
- package/src/config/merge-helpers.js +54 -0
- package/src/config/merge.js +3 -6
- package/src/config/presets/models/janus-text.json +27 -0
- package/src/config/quantization-contract-check.d.ts +12 -0
- package/src/config/quantization-contract-check.js +91 -0
- package/src/config/required-inference-fields-contract-check.d.ts +24 -0
- package/src/config/required-inference-fields-contract-check.js +231 -0
- package/src/config/schema/browser-suite-metrics.schema.d.ts +17 -0
- package/src/config/schema/browser-suite-metrics.schema.js +46 -0
- package/src/config/schema/conversion-report.schema.d.ts +40 -0
- package/src/config/schema/conversion-report.schema.js +108 -0
- package/src/config/schema/doppler.schema.js +12 -18
- package/src/config/schema/index.d.ts +22 -0
- package/src/config/schema/index.js +18 -0
- package/src/converter/core.d.ts +10 -0
- package/src/converter/core.js +49 -11
- package/src/converter/parsers/diffusion.js +63 -3
- package/src/converter/tokenizer-utils.js +17 -3
- package/src/formats/rdrr/validation.js +13 -0
- package/src/gpu/kernels/depthwise_conv2d.d.ts +29 -0
- package/src/gpu/kernels/depthwise_conv2d.js +98 -0
- package/src/gpu/kernels/depthwise_conv2d.wgsl +58 -0
- package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +62 -0
- package/src/gpu/kernels/grouped_pointwise_conv2d.d.ts +27 -0
- package/src/gpu/kernels/grouped_pointwise_conv2d.js +92 -0
- package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +47 -0
- package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +51 -0
- package/src/gpu/kernels/index.d.ts +30 -0
- package/src/gpu/kernels/index.js +25 -0
- package/src/gpu/kernels/relu.d.ts +18 -0
- package/src/gpu/kernels/relu.js +45 -0
- package/src/gpu/kernels/relu.wgsl +21 -0
- package/src/gpu/kernels/relu_f16.wgsl +23 -0
- package/src/gpu/kernels/repeat_channels.d.ts +21 -0
- package/src/gpu/kernels/repeat_channels.js +60 -0
- package/src/gpu/kernels/repeat_channels.wgsl +29 -0
- package/src/gpu/kernels/repeat_channels_f16.wgsl +31 -0
- package/src/gpu/kernels/sana_linear_attention.d.ts +27 -0
- package/src/gpu/kernels/sana_linear_attention.js +122 -0
- package/src/gpu/kernels/sana_linear_attention_apply.wgsl +44 -0
- package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +47 -0
- package/src/gpu/kernels/sana_linear_attention_summary.wgsl +47 -0
- package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +49 -0
- package/src/index-browser.d.ts +1 -0
- package/src/index-browser.js +2 -1
- package/src/index.d.ts +1 -0
- package/src/index.js +2 -1
- package/src/inference/browser-harness.js +164 -38
- package/src/inference/pipelines/diffusion/init.js +14 -0
- package/src/inference/pipelines/diffusion/pipeline.js +206 -77
- package/src/inference/pipelines/diffusion/sana-transformer.d.ts +53 -0
- package/src/inference/pipelines/diffusion/sana-transformer.js +738 -0
- package/src/inference/pipelines/diffusion/scheduler.d.ts +17 -1
- package/src/inference/pipelines/diffusion/scheduler.js +91 -3
- package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +6 -4
- package/src/inference/pipelines/diffusion/text-encoder-gpu.js +270 -0
- package/src/inference/pipelines/diffusion/text-encoder.js +18 -1
- package/src/inference/pipelines/diffusion/types.d.ts +4 -0
- package/src/inference/pipelines/diffusion/vae.js +782 -78
- package/src/inference/pipelines/text/config.d.ts +5 -0
- package/src/inference/pipelines/text/config.js +1 -1
- package/src/inference/pipelines/text/execution-v0.js +141 -101
- package/src/inference/pipelines/text/init.js +41 -10
- package/src/inference/pipelines/text.js +7 -1
- package/src/rules/execution-rules-contract-check.d.ts +17 -0
- package/src/rules/execution-rules-contract-check.js +245 -0
- package/src/rules/kernels/depthwise-conv2d.rules.json +6 -0
- package/src/rules/kernels/grouped-pointwise-conv2d.rules.json +6 -0
- package/src/rules/kernels/relu.rules.json +6 -0
- package/src/rules/kernels/repeat-channels.rules.json +6 -0
- package/src/rules/kernels/sana-linear-attention.rules.json +6 -0
- package/src/rules/layer-pattern-contract-check.d.ts +17 -0
- package/src/rules/layer-pattern-contract-check.js +231 -0
- package/src/rules/rule-registry.d.ts +28 -0
- package/src/rules/rule-registry.js +38 -0
- package/src/tooling/conversion-config-materializer.d.ts +24 -0
- package/src/tooling/conversion-config-materializer.js +99 -0
- package/src/tooling/lean-execution-contract-runner.d.ts +43 -0
- package/src/tooling/lean-execution-contract-runner.js +158 -0
- package/src/tooling/lean-execution-contract.d.ts +16 -0
- package/src/tooling/lean-execution-contract.js +81 -0
- package/src/tooling/node-convert.d.ts +10 -0
- package/src/tooling/node-converter.js +59 -0
- package/src/tooling/node-webgpu.js +30 -9
- package/src/version.d.ts +2 -0
- package/src/version.js +2 -0
- package/tools/convert-safetensors-node.js +47 -0
- package/tools/doppler-cli.js +167 -6
|
@@ -14,10 +14,11 @@ import {
|
|
|
14
14
|
projectContext,
|
|
15
15
|
assertClipHiddenActivationSupported,
|
|
16
16
|
} from './text-encoder-gpu.js';
|
|
17
|
-
import { buildScheduler } from './scheduler.js';
|
|
17
|
+
import { buildScheduler, stepScmScheduler } from './scheduler.js';
|
|
18
18
|
import { decodeLatents } from './vae.js';
|
|
19
19
|
import { createDiffusionWeightLoader } from './weights.js';
|
|
20
20
|
import { runSD3Transformer } from './sd3-transformer.js';
|
|
21
|
+
import { runSanaTransformer, buildSanaTimestepConditioning, projectSanaContext } from './sana-transformer.js';
|
|
21
22
|
import { createSD3WeightResolver } from './sd3-weights.js';
|
|
22
23
|
import { createTensor, dtypeBytes } from '../../../gpu/tensor.js';
|
|
23
24
|
import { acquireBuffer, releaseBuffer, readBuffer } from '../../../memory/buffer-pool.js';
|
|
@@ -27,6 +28,8 @@ import { runResidualAdd, runScale, recordResidualAdd, recordScale } from '../../
|
|
|
27
28
|
import { f16ToF32 } from '../../../loader/dtype-utils.js';
|
|
28
29
|
|
|
29
30
|
const SUPPORTED_DIFFUSION_BACKEND_PIPELINES = new Set(['gpu']);
|
|
31
|
+
const SD3_TEXT_ENCODER_KEYS = ['text_encoder', 'text_encoder_2', 'text_encoder_3'];
|
|
32
|
+
const SANA_TEXT_ENCODER_KEYS = ['text_encoder'];
|
|
30
33
|
|
|
31
34
|
function createRandomSeed() {
|
|
32
35
|
if (typeof crypto !== 'undefined' && typeof crypto.getRandomValues === 'function') {
|
|
@@ -58,6 +61,49 @@ function extractTokenSet(tokensByEncoder, key) {
|
|
|
58
61
|
return output;
|
|
59
62
|
}
|
|
60
63
|
|
|
64
|
+
function resolveDiffusionLayout(modelConfig) {
|
|
65
|
+
return modelConfig?.layout ?? 'sd3';
|
|
66
|
+
}
|
|
67
|
+
|
|
68
|
+
function getTextEncoderKeysForLayout(layout) {
|
|
69
|
+
if (layout === 'sana') {
|
|
70
|
+
return SANA_TEXT_ENCODER_KEYS;
|
|
71
|
+
}
|
|
72
|
+
return SD3_TEXT_ENCODER_KEYS;
|
|
73
|
+
}
|
|
74
|
+
|
|
75
|
+
function assertLayoutTextEncoderContract(layout, modelConfig, tokenizers) {
|
|
76
|
+
const requiredKeys = getTextEncoderKeysForLayout(layout);
|
|
77
|
+
for (const key of requiredKeys) {
|
|
78
|
+
if (!modelConfig?.components?.[key]) {
|
|
79
|
+
throw new Error(`Diffusion GPU pipeline requires component "${key}" for layout "${layout}".`);
|
|
80
|
+
}
|
|
81
|
+
if (!tokenizers?.[key]) {
|
|
82
|
+
throw new Error(`Diffusion GPU pipeline requires tokenizer "${key}" for layout "${layout}".`);
|
|
83
|
+
}
|
|
84
|
+
}
|
|
85
|
+
}
|
|
86
|
+
|
|
87
|
+
function buildTokenizerMaxLengths(layout, runtime) {
|
|
88
|
+
const maxLength = runtime?.textEncoder?.maxLength;
|
|
89
|
+
if (!Number.isFinite(maxLength) || maxLength <= 0) {
|
|
90
|
+
throw new Error('Diffusion runtime requires runtime.textEncoder.maxLength.');
|
|
91
|
+
}
|
|
92
|
+
if (layout === 'sana') {
|
|
93
|
+
return { text_encoder: maxLength };
|
|
94
|
+
}
|
|
95
|
+
|
|
96
|
+
const t5MaxLength = runtime?.textEncoder?.t5MaxLength ?? maxLength;
|
|
97
|
+
if (!Number.isFinite(t5MaxLength) || t5MaxLength <= 0) {
|
|
98
|
+
throw new Error('Diffusion runtime requires runtime.textEncoder.t5MaxLength (or runtime.textEncoder.maxLength).');
|
|
99
|
+
}
|
|
100
|
+
return {
|
|
101
|
+
text_encoder: maxLength,
|
|
102
|
+
text_encoder_2: maxLength,
|
|
103
|
+
text_encoder_3: t5MaxLength,
|
|
104
|
+
};
|
|
105
|
+
}
|
|
106
|
+
|
|
61
107
|
function getTensorSize(shape) {
|
|
62
108
|
if (!Array.isArray(shape)) return 0;
|
|
63
109
|
return shape.reduce((acc, value) => acc * value, 1);
|
|
@@ -120,6 +166,49 @@ async function readTensorToFloat32(tensor) {
|
|
|
120
166
|
return new Float32Array(data);
|
|
121
167
|
}
|
|
122
168
|
|
|
169
|
+
async function applySchedulerStep(latentsTensor, scheduler, stepIndex, timestep, predictionTensor, runtime, options = {}) {
|
|
170
|
+
if (scheduler.type === 'flowmatch_euler') {
|
|
171
|
+
const sigma = scheduler.sigmas[stepIndex];
|
|
172
|
+
const sigmaNext = stepIndex + 1 < scheduler.steps ? scheduler.sigmas[stepIndex + 1] : 0;
|
|
173
|
+
const delta = sigmaNext - sigma;
|
|
174
|
+
const latentSize = getTensorSize(latentsTensor.shape);
|
|
175
|
+
const scale = options.scale ?? runScale;
|
|
176
|
+
const residualAdd = options.residualAdd ?? runResidualAdd;
|
|
177
|
+
const release = options.release ?? releaseBuffer;
|
|
178
|
+
|
|
179
|
+
const scaled = await scale(predictionTensor, delta, { count: latentSize });
|
|
180
|
+
const updated = await residualAdd(latentsTensor, scaled, latentSize, { useVec4: true });
|
|
181
|
+
|
|
182
|
+
release(latentsTensor.buffer);
|
|
183
|
+
release(scaled.buffer);
|
|
184
|
+
release(predictionTensor.buffer);
|
|
185
|
+
|
|
186
|
+
return createTensor(updated.buffer, updated.dtype, [...latentsTensor.shape], 'diffusion_latents');
|
|
187
|
+
}
|
|
188
|
+
|
|
189
|
+
if (scheduler.type === 'scm') {
|
|
190
|
+
const sample = await readTensorToFloat32(latentsTensor);
|
|
191
|
+
const modelOutput = await readTensorToFloat32(predictionTensor);
|
|
192
|
+
releaseBuffer(predictionTensor.buffer);
|
|
193
|
+
releaseBuffer(latentsTensor.buffer);
|
|
194
|
+
|
|
195
|
+
const isFinalStep = stepIndex + 1 >= scheduler.timesteps.length - 1;
|
|
196
|
+
const noise = isFinalStep
|
|
197
|
+
? null
|
|
198
|
+
: generateLatents(
|
|
199
|
+
runtime.latent.width,
|
|
200
|
+
runtime.latent.height,
|
|
201
|
+
runtime.latent.channels,
|
|
202
|
+
runtime.latent.scale,
|
|
203
|
+
(options.seedBase ?? createRandomSeed()) + stepIndex + 1
|
|
204
|
+
).latents;
|
|
205
|
+
const step = stepScmScheduler(scheduler, modelOutput, timestep, sample, stepIndex, noise);
|
|
206
|
+
return createLatentTensor(step.prevSample, [...latentsTensor.shape], runtime);
|
|
207
|
+
}
|
|
208
|
+
|
|
209
|
+
throw new Error(`Unsupported diffusion scheduler.type "${scheduler.type}".`);
|
|
210
|
+
}
|
|
211
|
+
|
|
123
212
|
async function applyGuidance(uncond, cond, guidanceScale, size, options = {}) {
|
|
124
213
|
if (!uncond || !Number.isFinite(guidanceScale) || guidanceScale <= 1) {
|
|
125
214
|
return cond;
|
|
@@ -251,14 +340,17 @@ export class DiffusionPipeline {
|
|
|
251
340
|
});
|
|
252
341
|
}
|
|
253
342
|
|
|
254
|
-
const
|
|
255
|
-
const
|
|
256
|
-
const
|
|
343
|
+
const layout = resolveDiffusionLayout(this.diffusionState?.modelConfig);
|
|
344
|
+
const requiredKeys = getTextEncoderKeysForLayout(layout);
|
|
345
|
+
const weights = {};
|
|
346
|
+
for (const key of requiredKeys) {
|
|
347
|
+
weights[key] = await this.weightLoader.loadComponentWeights(key);
|
|
348
|
+
}
|
|
257
349
|
|
|
258
350
|
this.textEncoderWeights = {
|
|
259
|
-
text_encoder,
|
|
260
|
-
text_encoder_2,
|
|
261
|
-
text_encoder_3,
|
|
351
|
+
text_encoder: weights.text_encoder ?? null,
|
|
352
|
+
text_encoder_2: weights.text_encoder_2 ?? null,
|
|
353
|
+
text_encoder_3: weights.text_encoder_3 ?? null,
|
|
262
354
|
};
|
|
263
355
|
|
|
264
356
|
return this.textEncoderWeights;
|
|
@@ -315,14 +407,9 @@ export class DiffusionPipeline {
|
|
|
315
407
|
async generateGPU(request = {}) {
|
|
316
408
|
const start = performance.now();
|
|
317
409
|
const runtime = this.diffusionState.runtime;
|
|
318
|
-
const
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
}
|
|
322
|
-
const t5MaxLength = runtime.textEncoder?.t5MaxLength ?? clipMaxLength;
|
|
323
|
-
if (!Number.isFinite(t5MaxLength) || t5MaxLength <= 0) {
|
|
324
|
-
throw new Error('Diffusion runtime requires runtime.textEncoder.t5MaxLength (or runtime.textEncoder.maxLength).');
|
|
325
|
-
}
|
|
410
|
+
const modelConfig = this.diffusionState.modelConfig;
|
|
411
|
+
const layout = resolveDiffusionLayout(modelConfig);
|
|
412
|
+
const tokenizerMaxLengths = buildTokenizerMaxLengths(layout, runtime);
|
|
326
413
|
|
|
327
414
|
const defaultWidth = runtime.latent.width;
|
|
328
415
|
const defaultHeight = runtime.latent.height;
|
|
@@ -346,28 +433,20 @@ export class DiffusionPipeline {
|
|
|
346
433
|
throw new Error(`Invalid diffusion steps: ${steps}`);
|
|
347
434
|
}
|
|
348
435
|
|
|
349
|
-
const modelConfig = this.diffusionState.modelConfig;
|
|
350
436
|
if (!modelConfig?.components?.transformer) {
|
|
351
437
|
throw new Error('Diffusion GPU pipeline requires transformer component config.');
|
|
352
438
|
}
|
|
353
|
-
|
|
354
|
-
|
|
439
|
+
assertLayoutTextEncoderContract(layout, modelConfig, this.tokenizers);
|
|
440
|
+
if (layout === 'sd3') {
|
|
441
|
+
assertClipHiddenActivationSupported(modelConfig?.components?.text_encoder?.config || {});
|
|
355
442
|
}
|
|
356
|
-
if (!this.tokenizers?.text_encoder || !this.tokenizers?.text_encoder_2 || !this.tokenizers?.text_encoder_3) {
|
|
357
|
-
throw new Error('Diffusion GPU pipeline requires tokenizers for text_encoder, text_encoder_2, and text_encoder_3.');
|
|
358
|
-
}
|
|
359
|
-
assertClipHiddenActivationSupported(modelConfig?.components?.text_encoder?.config || {});
|
|
360
443
|
|
|
361
444
|
const promptStart = performance.now();
|
|
362
445
|
const encoded = encodePrompt(
|
|
363
446
|
{ prompt: request.prompt ?? '', negativePrompt: request.negativePrompt ?? '' },
|
|
364
447
|
this.tokenizers || {},
|
|
365
448
|
{
|
|
366
|
-
maxLengthByTokenizer:
|
|
367
|
-
text_encoder: clipMaxLength,
|
|
368
|
-
text_encoder_2: clipMaxLength,
|
|
369
|
-
text_encoder_3: t5MaxLength,
|
|
370
|
-
},
|
|
449
|
+
maxLengthByTokenizer: tokenizerMaxLengths,
|
|
371
450
|
}
|
|
372
451
|
);
|
|
373
452
|
|
|
@@ -410,13 +489,31 @@ export class DiffusionPipeline {
|
|
|
410
489
|
const prefillRecorder = canProfileGpu
|
|
411
490
|
? new CommandRecorder(getDevice(), 'diffusion_prefill', { profile: true })
|
|
412
491
|
: null;
|
|
413
|
-
const condContext =
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
492
|
+
const condContext = layout === 'sana'
|
|
493
|
+
? await projectSanaContext(
|
|
494
|
+
promptCondition.context,
|
|
495
|
+
promptCondition.attentionMask,
|
|
496
|
+
transformerWeights,
|
|
497
|
+
transformerConfig,
|
|
498
|
+
runtime,
|
|
499
|
+
{ recorder: prefillRecorder }
|
|
500
|
+
)
|
|
501
|
+
: await projectContext(promptCondition.context, transformerWeights, modelConfig, runtime, {
|
|
418
502
|
recorder: prefillRecorder,
|
|
419
|
-
})
|
|
503
|
+
});
|
|
504
|
+
const uncondContext = shouldUseUncond && negativeCondition
|
|
505
|
+
? layout === 'sana'
|
|
506
|
+
? await projectSanaContext(
|
|
507
|
+
negativeCondition.context,
|
|
508
|
+
negativeCondition.attentionMask,
|
|
509
|
+
transformerWeights,
|
|
510
|
+
transformerConfig,
|
|
511
|
+
runtime,
|
|
512
|
+
{ recorder: prefillRecorder }
|
|
513
|
+
)
|
|
514
|
+
: await projectContext(negativeCondition.context, transformerWeights, modelConfig, runtime, {
|
|
515
|
+
recorder: prefillRecorder,
|
|
516
|
+
})
|
|
420
517
|
: null;
|
|
421
518
|
if (prefillRecorder) {
|
|
422
519
|
prefillRecorder.submit();
|
|
@@ -428,11 +525,6 @@ export class DiffusionPipeline {
|
|
|
428
525
|
}
|
|
429
526
|
|
|
430
527
|
const scheduler = buildScheduler(runtime.scheduler, steps);
|
|
431
|
-
if (scheduler.type !== 'flowmatch_euler') {
|
|
432
|
-
throw new Error(
|
|
433
|
-
`Diffusion GPU pipeline requires scheduler.type="flowmatch_euler"; got "${scheduler.type}".`
|
|
434
|
-
);
|
|
435
|
-
}
|
|
436
528
|
const latentScale = this.diffusionState.latentScale;
|
|
437
529
|
const latentChannels = this.diffusionState.latentChannels;
|
|
438
530
|
const { latents, latentWidth, latentHeight } = generateLatents(width, height, latentChannels, latentScale, seed);
|
|
@@ -463,9 +555,6 @@ export class DiffusionPipeline {
|
|
|
463
555
|
const latentSize = latentChannels * latentHeight * latentWidth;
|
|
464
556
|
for (let i = 0; i < scheduler.steps; i++) {
|
|
465
557
|
const timestep = scheduler.timesteps[i];
|
|
466
|
-
const sigma = scheduler.sigmas[i];
|
|
467
|
-
const sigmaNext = i + 1 < scheduler.steps ? scheduler.sigmas[i + 1] : 0;
|
|
468
|
-
const delta = sigmaNext - sigma;
|
|
469
558
|
const stepRecorder = canProfileGpu
|
|
470
559
|
? new CommandRecorder(getDevice(), `diffusion_step_${i}`, { profile: true })
|
|
471
560
|
: null;
|
|
@@ -477,37 +566,71 @@ export class DiffusionPipeline {
|
|
|
477
566
|
? (left, right, count, options) => recordResidualAdd(stepRecorder, left, right, count, options)
|
|
478
567
|
: runResidualAdd;
|
|
479
568
|
|
|
480
|
-
const
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
569
|
+
const condPred = layout === 'sana'
|
|
570
|
+
? await (async () => {
|
|
571
|
+
const timeState = await buildSanaTimestepConditioning(
|
|
572
|
+
timestep * (transformerConfig.timestep_scale ?? 1.0),
|
|
573
|
+
guidanceScale,
|
|
574
|
+
transformerWeights,
|
|
575
|
+
transformerConfig,
|
|
576
|
+
runtime,
|
|
577
|
+
{ recorder: stepRecorder }
|
|
578
|
+
);
|
|
579
|
+
return runSanaTransformer(latentsTensor, condContext, timeState, transformerWeights, modelConfig, runtime, {
|
|
580
|
+
recorder: stepRecorder,
|
|
581
|
+
});
|
|
582
|
+
})()
|
|
583
|
+
: await (async () => {
|
|
584
|
+
const timeCond = await buildTimestepEmbedding(timestep, transformerWeights, modelConfig, runtime, {
|
|
585
|
+
dim: timeEmbedDim,
|
|
586
|
+
recorder: stepRecorder,
|
|
587
|
+
});
|
|
588
|
+
const textCond = await buildTimeTextEmbedding(promptCondition.pooled, transformerWeights, modelConfig, runtime, {
|
|
589
|
+
recorder: stepRecorder,
|
|
590
|
+
});
|
|
591
|
+
const timeTextCond = await combineTimeTextEmbeddings(timeCond, textCond, hiddenSize, {
|
|
592
|
+
recorder: stepRecorder,
|
|
593
|
+
});
|
|
594
|
+
const output = await runSD3Transformer(latentsTensor, condContext, timeTextCond, transformerWeights, modelConfig, runtime, {
|
|
595
|
+
recorder: stepRecorder,
|
|
596
|
+
});
|
|
597
|
+
releaseStep(timeTextCond.buffer);
|
|
598
|
+
return output;
|
|
599
|
+
})();
|
|
494
600
|
|
|
495
601
|
let pred = condPred;
|
|
496
602
|
if (shouldUseUncond && uncondContext && negativeCondition) {
|
|
497
|
-
const
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
603
|
+
const uncondPred = layout === 'sana'
|
|
604
|
+
? await (async () => {
|
|
605
|
+
const timeState = await buildSanaTimestepConditioning(
|
|
606
|
+
timestep * (transformerConfig.timestep_scale ?? 1.0),
|
|
607
|
+
guidanceScale,
|
|
608
|
+
transformerWeights,
|
|
609
|
+
transformerConfig,
|
|
610
|
+
runtime,
|
|
611
|
+
{ recorder: stepRecorder }
|
|
612
|
+
);
|
|
613
|
+
return runSanaTransformer(latentsTensor, uncondContext, timeState, transformerWeights, modelConfig, runtime, {
|
|
614
|
+
recorder: stepRecorder,
|
|
615
|
+
});
|
|
616
|
+
})()
|
|
617
|
+
: await (async () => {
|
|
618
|
+
const timeUncond = await buildTimestepEmbedding(timestep, transformerWeights, modelConfig, runtime, {
|
|
619
|
+
dim: timeEmbedDim,
|
|
620
|
+
recorder: stepRecorder,
|
|
621
|
+
});
|
|
622
|
+
const textUncond = await buildTimeTextEmbedding(negativeCondition.pooled, transformerWeights, modelConfig, runtime, {
|
|
623
|
+
recorder: stepRecorder,
|
|
624
|
+
});
|
|
625
|
+
const timeTextUncond = await combineTimeTextEmbeddings(timeUncond, textUncond, hiddenSize, {
|
|
626
|
+
recorder: stepRecorder,
|
|
627
|
+
});
|
|
628
|
+
const output = await runSD3Transformer(latentsTensor, uncondContext, timeTextUncond, transformerWeights, modelConfig, runtime, {
|
|
629
|
+
recorder: stepRecorder,
|
|
630
|
+
});
|
|
631
|
+
releaseStep(timeTextUncond.buffer);
|
|
632
|
+
return output;
|
|
633
|
+
})();
|
|
511
634
|
pred = await applyGuidance(uncondPred, condPred, guidanceScale, latentSize, {
|
|
512
635
|
recorder: stepRecorder,
|
|
513
636
|
release: releaseStep,
|
|
@@ -516,14 +639,20 @@ export class DiffusionPipeline {
|
|
|
516
639
|
releaseStep(condPred.buffer);
|
|
517
640
|
}
|
|
518
641
|
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
642
|
+
latentsTensor = await applySchedulerStep(
|
|
643
|
+
latentsTensor,
|
|
644
|
+
scheduler,
|
|
645
|
+
i,
|
|
646
|
+
timestep,
|
|
647
|
+
pred,
|
|
648
|
+
runtime,
|
|
649
|
+
{
|
|
650
|
+
scale,
|
|
651
|
+
residualAdd,
|
|
652
|
+
release: releaseStep,
|
|
653
|
+
seedBase: seed,
|
|
654
|
+
}
|
|
655
|
+
);
|
|
527
656
|
|
|
528
657
|
if (stepRecorder) {
|
|
529
658
|
stepRecorder.submit();
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
import type { Tensor } from '../../../gpu/tensor.js';
|
|
2
|
+
import type { CommandRecorder } from '../../../gpu/command-recorder.js';
|
|
3
|
+
|
|
4
|
+
export interface SanaTimestepState {
|
|
5
|
+
modulation: Tensor;
|
|
6
|
+
embeddedTimestep: Tensor;
|
|
7
|
+
}
|
|
8
|
+
|
|
9
|
+
export interface SanaTransformerOptions {
|
|
10
|
+
recorder?: CommandRecorder | null;
|
|
11
|
+
}
|
|
12
|
+
|
|
13
|
+
export declare function buildSanaTimestepConditioning(
|
|
14
|
+
timestep: number,
|
|
15
|
+
guidanceScale: number,
|
|
16
|
+
weightsEntry: any,
|
|
17
|
+
config: any,
|
|
18
|
+
runtime: any,
|
|
19
|
+
options?: SanaTransformerOptions
|
|
20
|
+
): Promise<SanaTimestepState>;
|
|
21
|
+
|
|
22
|
+
export declare function projectSanaContext(
|
|
23
|
+
context: Tensor,
|
|
24
|
+
attentionMask: Uint32Array | null | undefined,
|
|
25
|
+
weightsEntry: any,
|
|
26
|
+
config: any,
|
|
27
|
+
runtime: any,
|
|
28
|
+
options?: SanaTransformerOptions
|
|
29
|
+
): Promise<Tensor>;
|
|
30
|
+
|
|
31
|
+
export declare function runSanaTransformer(
|
|
32
|
+
latents: Tensor,
|
|
33
|
+
context: Tensor,
|
|
34
|
+
timeState: SanaTimestepState,
|
|
35
|
+
weightsEntry: any,
|
|
36
|
+
modelConfig: any,
|
|
37
|
+
runtime: any,
|
|
38
|
+
options?: SanaTransformerOptions
|
|
39
|
+
): Promise<Tensor>;
|
|
40
|
+
|
|
41
|
+
export declare function buildSanaConditioning(
|
|
42
|
+
context: Tensor,
|
|
43
|
+
attentionMask: Uint32Array | null | undefined,
|
|
44
|
+
timestep: number,
|
|
45
|
+
guidanceScale: number,
|
|
46
|
+
weightsEntry: any,
|
|
47
|
+
modelConfig: any,
|
|
48
|
+
runtime: any,
|
|
49
|
+
options?: SanaTransformerOptions
|
|
50
|
+
): Promise<{
|
|
51
|
+
context: Tensor;
|
|
52
|
+
timeState: SanaTimestepState;
|
|
53
|
+
}>;
|