@simulatte/doppler 0.1.3 → 0.1.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (114)
  1. package/README.md +11 -5
  2. package/package.json +27 -4
  3. package/src/client/doppler-api.browser.d.ts +1 -0
  4. package/src/client/doppler-api.browser.js +288 -0
  5. package/src/client/doppler-api.d.ts +80 -0
  6. package/src/client/doppler-api.js +298 -0
  7. package/src/client/doppler-provider/types.js +1 -1
  8. package/src/client/doppler-registry.d.ts +23 -0
  9. package/src/client/doppler-registry.js +88 -0
  10. package/src/client/doppler-registry.json +39 -0
  11. package/src/config/execution-contract-check.d.ts +82 -0
  12. package/src/config/execution-contract-check.js +317 -0
  13. package/src/config/execution-v0-contract-check.d.ts +94 -0
  14. package/src/config/execution-v0-contract-check.js +251 -0
  15. package/src/config/execution-v0-graph-contract-check.d.ts +20 -0
  16. package/src/config/execution-v0-graph-contract-check.js +64 -0
  17. package/src/config/kernel-path-contract-check.d.ts +76 -0
  18. package/src/config/kernel-path-contract-check.js +479 -0
  19. package/src/config/kernel-path-loader.d.ts +16 -0
  20. package/src/config/kernel-path-loader.js +54 -0
  21. package/src/config/kernels/kernel-ref-digests.js +12 -0
  22. package/src/config/kernels/registry.json +556 -0
  23. package/src/config/loader.js +90 -67
  24. package/src/config/merge-contract-check.d.ts +16 -0
  25. package/src/config/merge-contract-check.js +321 -0
  26. package/src/config/merge-helpers.d.ts +58 -0
  27. package/src/config/merge-helpers.js +54 -0
  28. package/src/config/merge.js +3 -6
  29. package/src/config/presets/models/janus-text.json +27 -0
  30. package/src/config/quantization-contract-check.d.ts +12 -0
  31. package/src/config/quantization-contract-check.js +91 -0
  32. package/src/config/required-inference-fields-contract-check.d.ts +24 -0
  33. package/src/config/required-inference-fields-contract-check.js +231 -0
  34. package/src/config/schema/browser-suite-metrics.schema.d.ts +17 -0
  35. package/src/config/schema/browser-suite-metrics.schema.js +46 -0
  36. package/src/config/schema/conversion-report.schema.d.ts +40 -0
  37. package/src/config/schema/conversion-report.schema.js +108 -0
  38. package/src/config/schema/doppler.schema.js +12 -18
  39. package/src/config/schema/index.d.ts +22 -0
  40. package/src/config/schema/index.js +18 -0
  41. package/src/converter/core.d.ts +10 -0
  42. package/src/converter/core.js +49 -11
  43. package/src/converter/parsers/diffusion.js +63 -3
  44. package/src/converter/tokenizer-utils.js +17 -3
  45. package/src/formats/rdrr/validation.js +13 -0
  46. package/src/gpu/kernels/depthwise_conv2d.d.ts +29 -0
  47. package/src/gpu/kernels/depthwise_conv2d.js +98 -0
  48. package/src/gpu/kernels/depthwise_conv2d.wgsl +58 -0
  49. package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +62 -0
  50. package/src/gpu/kernels/grouped_pointwise_conv2d.d.ts +27 -0
  51. package/src/gpu/kernels/grouped_pointwise_conv2d.js +92 -0
  52. package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +47 -0
  53. package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +51 -0
  54. package/src/gpu/kernels/index.d.ts +30 -0
  55. package/src/gpu/kernels/index.js +25 -0
  56. package/src/gpu/kernels/relu.d.ts +18 -0
  57. package/src/gpu/kernels/relu.js +45 -0
  58. package/src/gpu/kernels/relu.wgsl +21 -0
  59. package/src/gpu/kernels/relu_f16.wgsl +23 -0
  60. package/src/gpu/kernels/repeat_channels.d.ts +21 -0
  61. package/src/gpu/kernels/repeat_channels.js +60 -0
  62. package/src/gpu/kernels/repeat_channels.wgsl +29 -0
  63. package/src/gpu/kernels/repeat_channels_f16.wgsl +31 -0
  64. package/src/gpu/kernels/sana_linear_attention.d.ts +27 -0
  65. package/src/gpu/kernels/sana_linear_attention.js +122 -0
  66. package/src/gpu/kernels/sana_linear_attention_apply.wgsl +44 -0
  67. package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +47 -0
  68. package/src/gpu/kernels/sana_linear_attention_summary.wgsl +47 -0
  69. package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +49 -0
  70. package/src/index-browser.d.ts +1 -0
  71. package/src/index-browser.js +2 -1
  72. package/src/index.d.ts +1 -0
  73. package/src/index.js +2 -1
  74. package/src/inference/browser-harness.js +164 -38
  75. package/src/inference/pipelines/diffusion/init.js +14 -0
  76. package/src/inference/pipelines/diffusion/pipeline.js +206 -77
  77. package/src/inference/pipelines/diffusion/sana-transformer.d.ts +53 -0
  78. package/src/inference/pipelines/diffusion/sana-transformer.js +738 -0
  79. package/src/inference/pipelines/diffusion/scheduler.d.ts +17 -1
  80. package/src/inference/pipelines/diffusion/scheduler.js +91 -3
  81. package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +6 -4
  82. package/src/inference/pipelines/diffusion/text-encoder-gpu.js +270 -0
  83. package/src/inference/pipelines/diffusion/text-encoder.js +18 -1
  84. package/src/inference/pipelines/diffusion/types.d.ts +4 -0
  85. package/src/inference/pipelines/diffusion/vae.js +782 -78
  86. package/src/inference/pipelines/text/config.d.ts +5 -0
  87. package/src/inference/pipelines/text/config.js +1 -1
  88. package/src/inference/pipelines/text/execution-v0.js +141 -101
  89. package/src/inference/pipelines/text/init.js +41 -10
  90. package/src/inference/pipelines/text.js +7 -1
  91. package/src/rules/execution-rules-contract-check.d.ts +17 -0
  92. package/src/rules/execution-rules-contract-check.js +245 -0
  93. package/src/rules/kernels/depthwise-conv2d.rules.json +6 -0
  94. package/src/rules/kernels/grouped-pointwise-conv2d.rules.json +6 -0
  95. package/src/rules/kernels/relu.rules.json +6 -0
  96. package/src/rules/kernels/repeat-channels.rules.json +6 -0
  97. package/src/rules/kernels/sana-linear-attention.rules.json +6 -0
  98. package/src/rules/layer-pattern-contract-check.d.ts +17 -0
  99. package/src/rules/layer-pattern-contract-check.js +231 -0
  100. package/src/rules/rule-registry.d.ts +28 -0
  101. package/src/rules/rule-registry.js +38 -0
  102. package/src/tooling/conversion-config-materializer.d.ts +24 -0
  103. package/src/tooling/conversion-config-materializer.js +99 -0
  104. package/src/tooling/lean-execution-contract-runner.d.ts +43 -0
  105. package/src/tooling/lean-execution-contract-runner.js +158 -0
  106. package/src/tooling/lean-execution-contract.d.ts +16 -0
  107. package/src/tooling/lean-execution-contract.js +81 -0
  108. package/src/tooling/node-convert.d.ts +10 -0
  109. package/src/tooling/node-converter.js +59 -0
  110. package/src/tooling/node-webgpu.js +30 -9
  111. package/src/version.d.ts +2 -0
  112. package/src/version.js +2 -0
  113. package/tools/convert-safetensors-node.js +47 -0
  114. package/tools/doppler-cli.js +167 -6
@@ -14,10 +14,11 @@ import {
14
14
  projectContext,
15
15
  assertClipHiddenActivationSupported,
16
16
  } from './text-encoder-gpu.js';
17
- import { buildScheduler } from './scheduler.js';
17
+ import { buildScheduler, stepScmScheduler } from './scheduler.js';
18
18
  import { decodeLatents } from './vae.js';
19
19
  import { createDiffusionWeightLoader } from './weights.js';
20
20
  import { runSD3Transformer } from './sd3-transformer.js';
21
+ import { runSanaTransformer, buildSanaTimestepConditioning, projectSanaContext } from './sana-transformer.js';
21
22
  import { createSD3WeightResolver } from './sd3-weights.js';
22
23
  import { createTensor, dtypeBytes } from '../../../gpu/tensor.js';
23
24
  import { acquireBuffer, releaseBuffer, readBuffer } from '../../../memory/buffer-pool.js';
@@ -27,6 +28,8 @@ import { runResidualAdd, runScale, recordResidualAdd, recordScale } from '../../
27
28
  import { f16ToF32 } from '../../../loader/dtype-utils.js';
28
29
 
29
30
  const SUPPORTED_DIFFUSION_BACKEND_PIPELINES = new Set(['gpu']);
31
+ const SD3_TEXT_ENCODER_KEYS = ['text_encoder', 'text_encoder_2', 'text_encoder_3'];
32
+ const SANA_TEXT_ENCODER_KEYS = ['text_encoder'];
30
33
 
31
34
  function createRandomSeed() {
32
35
  if (typeof crypto !== 'undefined' && typeof crypto.getRandomValues === 'function') {
@@ -58,6 +61,49 @@ function extractTokenSet(tokensByEncoder, key) {
58
61
  return output;
59
62
  }
60
63
 
64
+ function resolveDiffusionLayout(modelConfig) {
65
+ return modelConfig?.layout ?? 'sd3';
66
+ }
67
+
68
+ function getTextEncoderKeysForLayout(layout) {
69
+ if (layout === 'sana') {
70
+ return SANA_TEXT_ENCODER_KEYS;
71
+ }
72
+ return SD3_TEXT_ENCODER_KEYS;
73
+ }
74
+
75
+ function assertLayoutTextEncoderContract(layout, modelConfig, tokenizers) {
76
+ const requiredKeys = getTextEncoderKeysForLayout(layout);
77
+ for (const key of requiredKeys) {
78
+ if (!modelConfig?.components?.[key]) {
79
+ throw new Error(`Diffusion GPU pipeline requires component "${key}" for layout "${layout}".`);
80
+ }
81
+ if (!tokenizers?.[key]) {
82
+ throw new Error(`Diffusion GPU pipeline requires tokenizer "${key}" for layout "${layout}".`);
83
+ }
84
+ }
85
+ }
86
+
87
+ function buildTokenizerMaxLengths(layout, runtime) {
88
+ const maxLength = runtime?.textEncoder?.maxLength;
89
+ if (!Number.isFinite(maxLength) || maxLength <= 0) {
90
+ throw new Error('Diffusion runtime requires runtime.textEncoder.maxLength.');
91
+ }
92
+ if (layout === 'sana') {
93
+ return { text_encoder: maxLength };
94
+ }
95
+
96
+ const t5MaxLength = runtime?.textEncoder?.t5MaxLength ?? maxLength;
97
+ if (!Number.isFinite(t5MaxLength) || t5MaxLength <= 0) {
98
+ throw new Error('Diffusion runtime requires runtime.textEncoder.t5MaxLength (or runtime.textEncoder.maxLength).');
99
+ }
100
+ return {
101
+ text_encoder: maxLength,
102
+ text_encoder_2: maxLength,
103
+ text_encoder_3: t5MaxLength,
104
+ };
105
+ }
106
+
61
107
  function getTensorSize(shape) {
62
108
  if (!Array.isArray(shape)) return 0;
63
109
  return shape.reduce((acc, value) => acc * value, 1);
@@ -120,6 +166,49 @@ async function readTensorToFloat32(tensor) {
120
166
  return new Float32Array(data);
121
167
  }
122
168
 
169
+ async function applySchedulerStep(latentsTensor, scheduler, stepIndex, timestep, predictionTensor, runtime, options = {}) {
170
+ if (scheduler.type === 'flowmatch_euler') {
171
+ const sigma = scheduler.sigmas[stepIndex];
172
+ const sigmaNext = stepIndex + 1 < scheduler.steps ? scheduler.sigmas[stepIndex + 1] : 0;
173
+ const delta = sigmaNext - sigma;
174
+ const latentSize = getTensorSize(latentsTensor.shape);
175
+ const scale = options.scale ?? runScale;
176
+ const residualAdd = options.residualAdd ?? runResidualAdd;
177
+ const release = options.release ?? releaseBuffer;
178
+
179
+ const scaled = await scale(predictionTensor, delta, { count: latentSize });
180
+ const updated = await residualAdd(latentsTensor, scaled, latentSize, { useVec4: true });
181
+
182
+ release(latentsTensor.buffer);
183
+ release(scaled.buffer);
184
+ release(predictionTensor.buffer);
185
+
186
+ return createTensor(updated.buffer, updated.dtype, [...latentsTensor.shape], 'diffusion_latents');
187
+ }
188
+
189
+ if (scheduler.type === 'scm') {
190
+ const sample = await readTensorToFloat32(latentsTensor);
191
+ const modelOutput = await readTensorToFloat32(predictionTensor);
192
+ releaseBuffer(predictionTensor.buffer);
193
+ releaseBuffer(latentsTensor.buffer);
194
+
195
+ const isFinalStep = stepIndex + 1 >= scheduler.timesteps.length - 1;
196
+ const noise = isFinalStep
197
+ ? null
198
+ : generateLatents(
199
+ runtime.latent.width,
200
+ runtime.latent.height,
201
+ runtime.latent.channels,
202
+ runtime.latent.scale,
203
+ (options.seedBase ?? createRandomSeed()) + stepIndex + 1
204
+ ).latents;
205
+ const step = stepScmScheduler(scheduler, modelOutput, timestep, sample, stepIndex, noise);
206
+ return createLatentTensor(step.prevSample, [...latentsTensor.shape], runtime);
207
+ }
208
+
209
+ throw new Error(`Unsupported diffusion scheduler.type "${scheduler.type}".`);
210
+ }
211
+
123
212
  async function applyGuidance(uncond, cond, guidanceScale, size, options = {}) {
124
213
  if (!uncond || !Number.isFinite(guidanceScale) || guidanceScale <= 1) {
125
214
  return cond;
@@ -251,14 +340,17 @@ export class DiffusionPipeline {
251
340
  });
252
341
  }
253
342
 
254
- const text_encoder = await this.weightLoader.loadComponentWeights('text_encoder');
255
- const text_encoder_2 = await this.weightLoader.loadComponentWeights('text_encoder_2');
256
- const text_encoder_3 = await this.weightLoader.loadComponentWeights('text_encoder_3');
343
+ const layout = resolveDiffusionLayout(this.diffusionState?.modelConfig);
344
+ const requiredKeys = getTextEncoderKeysForLayout(layout);
345
+ const weights = {};
346
+ for (const key of requiredKeys) {
347
+ weights[key] = await this.weightLoader.loadComponentWeights(key);
348
+ }
257
349
 
258
350
  this.textEncoderWeights = {
259
- text_encoder,
260
- text_encoder_2,
261
- text_encoder_3,
351
+ text_encoder: weights.text_encoder ?? null,
352
+ text_encoder_2: weights.text_encoder_2 ?? null,
353
+ text_encoder_3: weights.text_encoder_3 ?? null,
262
354
  };
263
355
 
264
356
  return this.textEncoderWeights;
@@ -315,14 +407,9 @@ export class DiffusionPipeline {
315
407
  async generateGPU(request = {}) {
316
408
  const start = performance.now();
317
409
  const runtime = this.diffusionState.runtime;
318
- const clipMaxLength = runtime.textEncoder?.maxLength;
319
- if (!Number.isFinite(clipMaxLength) || clipMaxLength <= 0) {
320
- throw new Error('Diffusion runtime requires runtime.textEncoder.maxLength.');
321
- }
322
- const t5MaxLength = runtime.textEncoder?.t5MaxLength ?? clipMaxLength;
323
- if (!Number.isFinite(t5MaxLength) || t5MaxLength <= 0) {
324
- throw new Error('Diffusion runtime requires runtime.textEncoder.t5MaxLength (or runtime.textEncoder.maxLength).');
325
- }
410
+ const modelConfig = this.diffusionState.modelConfig;
411
+ const layout = resolveDiffusionLayout(modelConfig);
412
+ const tokenizerMaxLengths = buildTokenizerMaxLengths(layout, runtime);
326
413
 
327
414
  const defaultWidth = runtime.latent.width;
328
415
  const defaultHeight = runtime.latent.height;
@@ -346,28 +433,20 @@ export class DiffusionPipeline {
346
433
  throw new Error(`Invalid diffusion steps: ${steps}`);
347
434
  }
348
435
 
349
- const modelConfig = this.diffusionState.modelConfig;
350
436
  if (!modelConfig?.components?.transformer) {
351
437
  throw new Error('Diffusion GPU pipeline requires transformer component config.');
352
438
  }
353
- if (!modelConfig?.components?.text_encoder || !modelConfig?.components?.text_encoder_2 || !modelConfig?.components?.text_encoder_3) {
354
- throw new Error('Diffusion GPU pipeline requires text encoder components (text_encoder, text_encoder_2, text_encoder_3).');
439
+ assertLayoutTextEncoderContract(layout, modelConfig, this.tokenizers);
440
+ if (layout === 'sd3') {
441
+ assertClipHiddenActivationSupported(modelConfig?.components?.text_encoder?.config || {});
355
442
  }
356
- if (!this.tokenizers?.text_encoder || !this.tokenizers?.text_encoder_2 || !this.tokenizers?.text_encoder_3) {
357
- throw new Error('Diffusion GPU pipeline requires tokenizers for text_encoder, text_encoder_2, and text_encoder_3.');
358
- }
359
- assertClipHiddenActivationSupported(modelConfig?.components?.text_encoder?.config || {});
360
443
 
361
444
  const promptStart = performance.now();
362
445
  const encoded = encodePrompt(
363
446
  { prompt: request.prompt ?? '', negativePrompt: request.negativePrompt ?? '' },
364
447
  this.tokenizers || {},
365
448
  {
366
- maxLengthByTokenizer: {
367
- text_encoder: clipMaxLength,
368
- text_encoder_2: clipMaxLength,
369
- text_encoder_3: t5MaxLength,
370
- },
449
+ maxLengthByTokenizer: tokenizerMaxLengths,
371
450
  }
372
451
  );
373
452
 
@@ -410,13 +489,31 @@ export class DiffusionPipeline {
410
489
  const prefillRecorder = canProfileGpu
411
490
  ? new CommandRecorder(getDevice(), 'diffusion_prefill', { profile: true })
412
491
  : null;
413
- const condContext = await projectContext(promptCondition.context, transformerWeights, modelConfig, runtime, {
414
- recorder: prefillRecorder,
415
- });
416
- const uncondContext = shouldUseUncond && negativeCondition
417
- ? await projectContext(negativeCondition.context, transformerWeights, modelConfig, runtime, {
492
+ const condContext = layout === 'sana'
493
+ ? await projectSanaContext(
494
+ promptCondition.context,
495
+ promptCondition.attentionMask,
496
+ transformerWeights,
497
+ transformerConfig,
498
+ runtime,
499
+ { recorder: prefillRecorder }
500
+ )
501
+ : await projectContext(promptCondition.context, transformerWeights, modelConfig, runtime, {
418
502
  recorder: prefillRecorder,
419
- })
503
+ });
504
+ const uncondContext = shouldUseUncond && negativeCondition
505
+ ? layout === 'sana'
506
+ ? await projectSanaContext(
507
+ negativeCondition.context,
508
+ negativeCondition.attentionMask,
509
+ transformerWeights,
510
+ transformerConfig,
511
+ runtime,
512
+ { recorder: prefillRecorder }
513
+ )
514
+ : await projectContext(negativeCondition.context, transformerWeights, modelConfig, runtime, {
515
+ recorder: prefillRecorder,
516
+ })
420
517
  : null;
421
518
  if (prefillRecorder) {
422
519
  prefillRecorder.submit();
@@ -428,11 +525,6 @@ export class DiffusionPipeline {
428
525
  }
429
526
 
430
527
  const scheduler = buildScheduler(runtime.scheduler, steps);
431
- if (scheduler.type !== 'flowmatch_euler') {
432
- throw new Error(
433
- `Diffusion GPU pipeline requires scheduler.type="flowmatch_euler"; got "${scheduler.type}".`
434
- );
435
- }
436
528
  const latentScale = this.diffusionState.latentScale;
437
529
  const latentChannels = this.diffusionState.latentChannels;
438
530
  const { latents, latentWidth, latentHeight } = generateLatents(width, height, latentChannels, latentScale, seed);
@@ -463,9 +555,6 @@ export class DiffusionPipeline {
463
555
  const latentSize = latentChannels * latentHeight * latentWidth;
464
556
  for (let i = 0; i < scheduler.steps; i++) {
465
557
  const timestep = scheduler.timesteps[i];
466
- const sigma = scheduler.sigmas[i];
467
- const sigmaNext = i + 1 < scheduler.steps ? scheduler.sigmas[i + 1] : 0;
468
- const delta = sigmaNext - sigma;
469
558
  const stepRecorder = canProfileGpu
470
559
  ? new CommandRecorder(getDevice(), `diffusion_step_${i}`, { profile: true })
471
560
  : null;
@@ -477,37 +566,71 @@ export class DiffusionPipeline {
477
566
  ? (left, right, count, options) => recordResidualAdd(stepRecorder, left, right, count, options)
478
567
  : runResidualAdd;
479
568
 
480
- const timeCond = await buildTimestepEmbedding(timestep, transformerWeights, modelConfig, runtime, {
481
- dim: timeEmbedDim,
482
- recorder: stepRecorder,
483
- });
484
- const textCond = await buildTimeTextEmbedding(promptCondition.pooled, transformerWeights, modelConfig, runtime, {
485
- recorder: stepRecorder,
486
- });
487
- const timeTextCond = await combineTimeTextEmbeddings(timeCond, textCond, hiddenSize, {
488
- recorder: stepRecorder,
489
- });
490
- const condPred = await runSD3Transformer(latentsTensor, condContext, timeTextCond, transformerWeights, modelConfig, runtime, {
491
- recorder: stepRecorder,
492
- });
493
- releaseStep(timeTextCond.buffer);
569
+ const condPred = layout === 'sana'
570
+ ? await (async () => {
571
+ const timeState = await buildSanaTimestepConditioning(
572
+ timestep * (transformerConfig.timestep_scale ?? 1.0),
573
+ guidanceScale,
574
+ transformerWeights,
575
+ transformerConfig,
576
+ runtime,
577
+ { recorder: stepRecorder }
578
+ );
579
+ return runSanaTransformer(latentsTensor, condContext, timeState, transformerWeights, modelConfig, runtime, {
580
+ recorder: stepRecorder,
581
+ });
582
+ })()
583
+ : await (async () => {
584
+ const timeCond = await buildTimestepEmbedding(timestep, transformerWeights, modelConfig, runtime, {
585
+ dim: timeEmbedDim,
586
+ recorder: stepRecorder,
587
+ });
588
+ const textCond = await buildTimeTextEmbedding(promptCondition.pooled, transformerWeights, modelConfig, runtime, {
589
+ recorder: stepRecorder,
590
+ });
591
+ const timeTextCond = await combineTimeTextEmbeddings(timeCond, textCond, hiddenSize, {
592
+ recorder: stepRecorder,
593
+ });
594
+ const output = await runSD3Transformer(latentsTensor, condContext, timeTextCond, transformerWeights, modelConfig, runtime, {
595
+ recorder: stepRecorder,
596
+ });
597
+ releaseStep(timeTextCond.buffer);
598
+ return output;
599
+ })();
494
600
 
495
601
  let pred = condPred;
496
602
  if (shouldUseUncond && uncondContext && negativeCondition) {
497
- const timeUncond = await buildTimestepEmbedding(timestep, transformerWeights, modelConfig, runtime, {
498
- dim: timeEmbedDim,
499
- recorder: stepRecorder,
500
- });
501
- const textUncond = await buildTimeTextEmbedding(negativeCondition.pooled, transformerWeights, modelConfig, runtime, {
502
- recorder: stepRecorder,
503
- });
504
- const timeTextUncond = await combineTimeTextEmbeddings(timeUncond, textUncond, hiddenSize, {
505
- recorder: stepRecorder,
506
- });
507
- const uncondPred = await runSD3Transformer(latentsTensor, uncondContext, timeTextUncond, transformerWeights, modelConfig, runtime, {
508
- recorder: stepRecorder,
509
- });
510
- releaseStep(timeTextUncond.buffer);
603
+ const uncondPred = layout === 'sana'
604
+ ? await (async () => {
605
+ const timeState = await buildSanaTimestepConditioning(
606
+ timestep * (transformerConfig.timestep_scale ?? 1.0),
607
+ guidanceScale,
608
+ transformerWeights,
609
+ transformerConfig,
610
+ runtime,
611
+ { recorder: stepRecorder }
612
+ );
613
+ return runSanaTransformer(latentsTensor, uncondContext, timeState, transformerWeights, modelConfig, runtime, {
614
+ recorder: stepRecorder,
615
+ });
616
+ })()
617
+ : await (async () => {
618
+ const timeUncond = await buildTimestepEmbedding(timestep, transformerWeights, modelConfig, runtime, {
619
+ dim: timeEmbedDim,
620
+ recorder: stepRecorder,
621
+ });
622
+ const textUncond = await buildTimeTextEmbedding(negativeCondition.pooled, transformerWeights, modelConfig, runtime, {
623
+ recorder: stepRecorder,
624
+ });
625
+ const timeTextUncond = await combineTimeTextEmbeddings(timeUncond, textUncond, hiddenSize, {
626
+ recorder: stepRecorder,
627
+ });
628
+ const output = await runSD3Transformer(latentsTensor, uncondContext, timeTextUncond, transformerWeights, modelConfig, runtime, {
629
+ recorder: stepRecorder,
630
+ });
631
+ releaseStep(timeTextUncond.buffer);
632
+ return output;
633
+ })();
511
634
  pred = await applyGuidance(uncondPred, condPred, guidanceScale, latentSize, {
512
635
  recorder: stepRecorder,
513
636
  release: releaseStep,
@@ -516,14 +639,20 @@ export class DiffusionPipeline {
516
639
  releaseStep(condPred.buffer);
517
640
  }
518
641
 
519
- const scaled = await scale(pred, delta, { count: latentSize });
520
- const updated = await residualAdd(latentsTensor, scaled, latentSize, { useVec4: true });
521
-
522
- releaseStep(latentsTensor.buffer);
523
- releaseStep(scaled.buffer);
524
- releaseStep(pred.buffer);
525
-
526
- latentsTensor = createTensor(updated.buffer, updated.dtype, [latentChannels, latentHeight, latentWidth], 'sd3_latents');
642
+ latentsTensor = await applySchedulerStep(
643
+ latentsTensor,
644
+ scheduler,
645
+ i,
646
+ timestep,
647
+ pred,
648
+ runtime,
649
+ {
650
+ scale,
651
+ residualAdd,
652
+ release: releaseStep,
653
+ seedBase: seed,
654
+ }
655
+ );
527
656
 
528
657
  if (stepRecorder) {
529
658
  stepRecorder.submit();
@@ -0,0 +1,53 @@
1
+ import type { Tensor } from '../../../gpu/tensor.js';
2
+ import type { CommandRecorder } from '../../../gpu/command-recorder.js';
3
+
4
+ export interface SanaTimestepState {
5
+ modulation: Tensor;
6
+ embeddedTimestep: Tensor;
7
+ }
8
+
9
+ export interface SanaTransformerOptions {
10
+ recorder?: CommandRecorder | null;
11
+ }
12
+
13
+ export declare function buildSanaTimestepConditioning(
14
+ timestep: number,
15
+ guidanceScale: number,
16
+ weightsEntry: any,
17
+ config: any,
18
+ runtime: any,
19
+ options?: SanaTransformerOptions
20
+ ): Promise<SanaTimestepState>;
21
+
22
+ export declare function projectSanaContext(
23
+ context: Tensor,
24
+ attentionMask: Uint32Array | null | undefined,
25
+ weightsEntry: any,
26
+ config: any,
27
+ runtime: any,
28
+ options?: SanaTransformerOptions
29
+ ): Promise<Tensor>;
30
+
31
+ export declare function runSanaTransformer(
32
+ latents: Tensor,
33
+ context: Tensor,
34
+ timeState: SanaTimestepState,
35
+ weightsEntry: any,
36
+ modelConfig: any,
37
+ runtime: any,
38
+ options?: SanaTransformerOptions
39
+ ): Promise<Tensor>;
40
+
41
+ export declare function buildSanaConditioning(
42
+ context: Tensor,
43
+ attentionMask: Uint32Array | null | undefined,
44
+ timestep: number,
45
+ guidanceScale: number,
46
+ weightsEntry: any,
47
+ modelConfig: any,
48
+ runtime: any,
49
+ options?: SanaTransformerOptions
50
+ ): Promise<{
51
+ context: Tensor;
52
+ timeState: SanaTimestepState;
53
+ }>;