@simulatte/doppler 0.1.4 → 0.1.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (199) hide show
  1. package/README.md +26 -10
  2. package/package.json +30 -6
  3. package/src/client/doppler-api.browser.d.ts +1 -0
  4. package/src/client/doppler-api.browser.js +288 -0
  5. package/src/client/doppler-api.js +1 -1
  6. package/src/client/doppler-provider/types.js +1 -1
  7. package/src/config/execution-contract-check.d.ts +33 -0
  8. package/src/config/execution-contract-check.js +72 -0
  9. package/src/config/execution-v0-contract-check.d.ts +94 -0
  10. package/src/config/execution-v0-contract-check.js +251 -0
  11. package/src/config/execution-v0-graph-contract-check.d.ts +20 -0
  12. package/src/config/execution-v0-graph-contract-check.js +64 -0
  13. package/src/config/kernel-path-contract-check.d.ts +76 -0
  14. package/src/config/kernel-path-contract-check.js +479 -0
  15. package/src/config/kernel-path-loader.d.ts +16 -0
  16. package/src/config/kernel-path-loader.js +54 -0
  17. package/src/config/kernels/kernel-ref-digests.js +39 -27
  18. package/src/config/kernels/registry.json +598 -2
  19. package/src/config/loader.js +81 -48
  20. package/src/config/merge-contract-check.d.ts +16 -0
  21. package/src/config/merge-contract-check.js +321 -0
  22. package/src/config/merge-helpers.d.ts +58 -0
  23. package/src/config/merge-helpers.js +54 -0
  24. package/src/config/merge.js +21 -6
  25. package/src/config/presets/models/janus-text.json +2 -0
  26. package/src/config/presets/models/qwen3.json +9 -2
  27. package/src/config/presets/models/transformer.json +5 -0
  28. package/src/config/quantization-contract-check.d.ts +12 -0
  29. package/src/config/quantization-contract-check.js +91 -0
  30. package/src/config/required-inference-fields-contract-check.d.ts +24 -0
  31. package/src/config/required-inference-fields-contract-check.js +237 -0
  32. package/src/config/schema/browser-suite-metrics.schema.d.ts +17 -0
  33. package/src/config/schema/browser-suite-metrics.schema.js +46 -0
  34. package/src/config/schema/conversion-report.schema.d.ts +40 -0
  35. package/src/config/schema/conversion-report.schema.js +108 -0
  36. package/src/config/schema/doppler.schema.js +12 -18
  37. package/src/config/schema/index.d.ts +22 -0
  38. package/src/config/schema/index.js +18 -0
  39. package/src/config/schema/inference-defaults.schema.js +3 -0
  40. package/src/config/schema/inference.schema.d.ts +9 -0
  41. package/src/config/schema/kernel-path.schema.d.ts +6 -0
  42. package/src/config/schema/manifest.schema.d.ts +6 -0
  43. package/src/config/schema/manifest.schema.js +3 -0
  44. package/src/converter/core.d.ts +10 -0
  45. package/src/converter/core.js +27 -2
  46. package/src/converter/parsers/diffusion.js +63 -3
  47. package/src/converter/rope-config.js +42 -0
  48. package/src/gpu/device.js +58 -0
  49. package/src/gpu/kernels/attention.js +98 -0
  50. package/src/gpu/kernels/bias_add.wgsl +8 -6
  51. package/src/gpu/kernels/bias_add_f16.wgsl +8 -5
  52. package/src/gpu/kernels/conv2d.js +1 -1
  53. package/src/gpu/kernels/conv2d.wgsl +7 -8
  54. package/src/gpu/kernels/conv2d_f16.wgsl +7 -8
  55. package/src/gpu/kernels/depthwise_conv2d.d.ts +29 -0
  56. package/src/gpu/kernels/depthwise_conv2d.js +99 -0
  57. package/src/gpu/kernels/depthwise_conv2d.wgsl +55 -0
  58. package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +59 -0
  59. package/src/gpu/kernels/grouped_pointwise_conv2d.d.ts +27 -0
  60. package/src/gpu/kernels/grouped_pointwise_conv2d.js +93 -0
  61. package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +44 -0
  62. package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +48 -0
  63. package/src/gpu/kernels/index.d.ts +30 -0
  64. package/src/gpu/kernels/index.js +25 -0
  65. package/src/gpu/kernels/matmul.js +25 -0
  66. package/src/gpu/kernels/pixel_shuffle.js +1 -1
  67. package/src/gpu/kernels/pixel_shuffle.wgsl +4 -5
  68. package/src/gpu/kernels/pixel_shuffle_f16.wgsl +4 -5
  69. package/src/gpu/kernels/relu.d.ts +18 -0
  70. package/src/gpu/kernels/relu.js +58 -0
  71. package/src/gpu/kernels/relu.wgsl +22 -0
  72. package/src/gpu/kernels/relu_f16.wgsl +24 -0
  73. package/src/gpu/kernels/repeat_channels.d.ts +21 -0
  74. package/src/gpu/kernels/repeat_channels.js +60 -0
  75. package/src/gpu/kernels/repeat_channels.wgsl +28 -0
  76. package/src/gpu/kernels/repeat_channels_f16.wgsl +30 -0
  77. package/src/gpu/kernels/residual.js +44 -8
  78. package/src/gpu/kernels/residual.wgsl +6 -3
  79. package/src/gpu/kernels/residual_f16.wgsl +2 -1
  80. package/src/gpu/kernels/residual_f16_vec4.wgsl +2 -1
  81. package/src/gpu/kernels/residual_vec4.wgsl +2 -1
  82. package/src/gpu/kernels/rmsnorm.js +58 -6
  83. package/src/gpu/kernels/rmsnorm.wgsl +14 -6
  84. package/src/gpu/kernels/rmsnorm_f16.wgsl +10 -2
  85. package/src/gpu/kernels/rope.d.ts +2 -0
  86. package/src/gpu/kernels/rope.js +11 -1
  87. package/src/gpu/kernels/rope.wgsl +56 -40
  88. package/src/gpu/kernels/sana_linear_attention.d.ts +27 -0
  89. package/src/gpu/kernels/sana_linear_attention.js +121 -0
  90. package/src/gpu/kernels/sana_linear_attention_apply.wgsl +43 -0
  91. package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +46 -0
  92. package/src/gpu/kernels/sana_linear_attention_summary.wgsl +51 -0
  93. package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +53 -0
  94. package/src/gpu/kernels/silu.d.ts +1 -0
  95. package/src/gpu/kernels/silu.js +32 -14
  96. package/src/gpu/kernels/silu.wgsl +19 -9
  97. package/src/gpu/kernels/silu_f16.wgsl +19 -9
  98. package/src/gpu/kernels/transpose.js +15 -2
  99. package/src/gpu/kernels/transpose.wgsl +5 -6
  100. package/src/gpu/kernels/upsample2d.js +2 -1
  101. package/src/gpu/kernels/upsample2d.wgsl +6 -9
  102. package/src/gpu/kernels/upsample2d_f16.wgsl +6 -9
  103. package/src/gpu/kernels/utils.js +16 -1
  104. package/src/index-browser.d.ts +1 -1
  105. package/src/index-browser.js +2 -2
  106. package/src/index.js +1 -1
  107. package/src/inference/browser-harness.js +109 -23
  108. package/src/inference/pipelines/diffusion/init.js +14 -0
  109. package/src/inference/pipelines/diffusion/pipeline.js +215 -77
  110. package/src/inference/pipelines/diffusion/sana-transformer.d.ts +53 -0
  111. package/src/inference/pipelines/diffusion/sana-transformer.js +738 -0
  112. package/src/inference/pipelines/diffusion/scheduler.d.ts +17 -1
  113. package/src/inference/pipelines/diffusion/scheduler.js +91 -3
  114. package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +11 -4
  115. package/src/inference/pipelines/diffusion/text-encoder-gpu.js +282 -0
  116. package/src/inference/pipelines/diffusion/text-encoder.js +18 -1
  117. package/src/inference/pipelines/diffusion/types.d.ts +4 -0
  118. package/src/inference/pipelines/diffusion/vae.js +782 -78
  119. package/src/inference/pipelines/text/attention/record.js +11 -2
  120. package/src/inference/pipelines/text/attention/run.js +11 -2
  121. package/src/inference/pipelines/text/chat-format.js +25 -1
  122. package/src/inference/pipelines/text/config.d.ts +9 -0
  123. package/src/inference/pipelines/text/config.js +69 -2
  124. package/src/inference/pipelines/text/execution-plan.js +23 -31
  125. package/src/inference/pipelines/text/execution-v0.js +43 -95
  126. package/src/inference/pipelines/text/ffn/standard.js +3 -0
  127. package/src/inference/pipelines/text/init.d.ts +4 -0
  128. package/src/inference/pipelines/text/init.js +56 -9
  129. package/src/inference/pipelines/text/layer.js +11 -0
  130. package/src/inference/pipelines/text.js +4 -0
  131. package/src/inference/tokenizers/bundled.js +156 -33
  132. package/src/rules/execution-rules-contract-check.d.ts +17 -0
  133. package/src/rules/execution-rules-contract-check.js +245 -0
  134. package/src/rules/kernels/depthwise-conv2d.rules.json +6 -0
  135. package/src/rules/kernels/grouped-pointwise-conv2d.rules.json +6 -0
  136. package/src/rules/kernels/relu.rules.json +6 -0
  137. package/src/rules/kernels/repeat-channels.rules.json +6 -0
  138. package/src/rules/kernels/sana-linear-attention.rules.json +6 -0
  139. package/src/rules/layer-pattern-contract-check.d.ts +17 -0
  140. package/src/rules/layer-pattern-contract-check.js +231 -0
  141. package/src/rules/rule-registry.d.ts +28 -0
  142. package/src/rules/rule-registry.js +38 -0
  143. package/src/rules/tooling/command-runtime.rules.json +18 -0
  144. package/src/tooling/command-api.d.ts +27 -1
  145. package/src/tooling/command-api.js +142 -3
  146. package/src/tooling/conversion-config-materializer.d.ts +24 -0
  147. package/src/tooling/conversion-config-materializer.js +99 -0
  148. package/src/tooling/lean-execution-contract-runner.d.ts +43 -0
  149. package/src/tooling/lean-execution-contract-runner.js +158 -0
  150. package/src/tooling/node-browser-command-runner.d.ts +4 -0
  151. package/src/tooling/node-browser-command-runner.js +58 -3
  152. package/src/tooling/node-command-runner.js +15 -0
  153. package/src/tooling/node-convert.d.ts +10 -0
  154. package/src/tooling/node-converter.js +59 -0
  155. package/src/tooling/node-webgpu.js +11 -89
  156. package/src/training/checkpoint-watch.d.ts +7 -0
  157. package/src/training/checkpoint-watch.js +106 -0
  158. package/src/training/checkpoint.d.ts +6 -1
  159. package/src/training/checkpoint.js +12 -2
  160. package/src/training/distillation/artifacts.d.ts +71 -0
  161. package/src/training/distillation/artifacts.js +132 -0
  162. package/src/training/distillation/checkpoint-watch.d.ts +10 -0
  163. package/src/training/distillation/checkpoint-watch.js +57 -0
  164. package/src/training/distillation/dataset.d.ts +59 -0
  165. package/src/training/distillation/dataset.js +337 -0
  166. package/src/training/distillation/eval.d.ts +34 -0
  167. package/src/training/distillation/eval.js +310 -0
  168. package/src/training/distillation/index.d.ts +29 -0
  169. package/src/training/distillation/index.js +29 -0
  170. package/src/training/distillation/runtime.d.ts +20 -0
  171. package/src/training/distillation/runtime.js +121 -0
  172. package/src/training/distillation/scoreboard.d.ts +6 -0
  173. package/src/training/distillation/scoreboard.js +8 -0
  174. package/src/training/distillation/stage-a.d.ts +45 -0
  175. package/src/training/distillation/stage-a.js +338 -0
  176. package/src/training/distillation/stage-b.d.ts +24 -0
  177. package/src/training/distillation/stage-b.js +20 -0
  178. package/src/training/index.d.ts +10 -0
  179. package/src/training/index.js +10 -0
  180. package/src/training/lora-pipeline.d.ts +40 -0
  181. package/src/training/lora-pipeline.js +796 -0
  182. package/src/training/operator-artifacts.d.ts +62 -0
  183. package/src/training/operator-artifacts.js +140 -0
  184. package/src/training/operator-command.d.ts +5 -0
  185. package/src/training/operator-command.js +453 -0
  186. package/src/training/operator-eval.d.ts +48 -0
  187. package/src/training/operator-eval.js +230 -0
  188. package/src/training/operator-scoreboard.d.ts +5 -0
  189. package/src/training/operator-scoreboard.js +44 -0
  190. package/src/training/runner.d.ts +52 -0
  191. package/src/training/runner.js +29 -4
  192. package/src/training/suite.d.ts +112 -0
  193. package/src/training/suite.js +9 -9
  194. package/src/training/workloads.d.ts +164 -0
  195. package/src/training/workloads.js +539 -0
  196. package/src/version.d.ts +2 -0
  197. package/src/version.js +2 -0
  198. package/tools/convert-safetensors-node.js +47 -0
  199. package/tools/doppler-cli.js +252 -41
@@ -14,10 +14,11 @@ import {
14
14
  projectContext,
15
15
  assertClipHiddenActivationSupported,
16
16
  } from './text-encoder-gpu.js';
17
- import { buildScheduler } from './scheduler.js';
17
+ import { buildScheduler, stepScmScheduler } from './scheduler.js';
18
18
  import { decodeLatents } from './vae.js';
19
19
  import { createDiffusionWeightLoader } from './weights.js';
20
20
  import { runSD3Transformer } from './sd3-transformer.js';
21
+ import { runSanaTransformer, buildSanaTimestepConditioning, projectSanaContext } from './sana-transformer.js';
21
22
  import { createSD3WeightResolver } from './sd3-weights.js';
22
23
  import { createTensor, dtypeBytes } from '../../../gpu/tensor.js';
23
24
  import { acquireBuffer, releaseBuffer, readBuffer } from '../../../memory/buffer-pool.js';
@@ -27,6 +28,8 @@ import { runResidualAdd, runScale, recordResidualAdd, recordScale } from '../../
27
28
  import { f16ToF32 } from '../../../loader/dtype-utils.js';
28
29
 
29
30
  const SUPPORTED_DIFFUSION_BACKEND_PIPELINES = new Set(['gpu']);
31
+ const SD3_TEXT_ENCODER_KEYS = ['text_encoder', 'text_encoder_2', 'text_encoder_3'];
32
+ const SANA_TEXT_ENCODER_KEYS = ['text_encoder'];
30
33
 
31
34
  function createRandomSeed() {
32
35
  if (typeof crypto !== 'undefined' && typeof crypto.getRandomValues === 'function') {
@@ -49,6 +52,18 @@ function generateLatents(width, height, channels, latentScale, seed) {
49
52
  return { latents, latentWidth, latentHeight };
50
53
  }
51
54
 
55
+ function generateNoiseVector(size, seed) {
56
+ if (!Number.isFinite(size) || size <= 0) {
57
+ throw new Error(`generateNoiseVector requires a positive size, got ${size}.`);
58
+ }
59
+ const out = new Float32Array(size);
60
+ const rand = createRng(seed ?? createRandomSeed());
61
+ for (let i = 0; i < size; i++) {
62
+ out[i] = sampleNormal(rand);
63
+ }
64
+ return out;
65
+ }
66
+
52
67
  function extractTokenSet(tokensByEncoder, key) {
53
68
  const output = {};
54
69
  for (const [name, entry] of Object.entries(tokensByEncoder || {})) {
@@ -58,6 +73,49 @@ function extractTokenSet(tokensByEncoder, key) {
58
73
  return output;
59
74
  }
60
75
 
76
+ function resolveDiffusionLayout(modelConfig) {
77
+ return modelConfig?.layout ?? 'sd3';
78
+ }
79
+
80
+ function getTextEncoderKeysForLayout(layout) {
81
+ if (layout === 'sana') {
82
+ return SANA_TEXT_ENCODER_KEYS;
83
+ }
84
+ return SD3_TEXT_ENCODER_KEYS;
85
+ }
86
+
87
+ function assertLayoutTextEncoderContract(layout, modelConfig, tokenizers) {
88
+ const requiredKeys = getTextEncoderKeysForLayout(layout);
89
+ for (const key of requiredKeys) {
90
+ if (!modelConfig?.components?.[key]) {
91
+ throw new Error(`Diffusion GPU pipeline requires component "${key}" for layout "${layout}".`);
92
+ }
93
+ if (!tokenizers?.[key]) {
94
+ throw new Error(`Diffusion GPU pipeline requires tokenizer "${key}" for layout "${layout}".`);
95
+ }
96
+ }
97
+ }
98
+
99
+ function buildTokenizerMaxLengths(layout, runtime) {
100
+ const maxLength = runtime?.textEncoder?.maxLength;
101
+ if (!Number.isFinite(maxLength) || maxLength <= 0) {
102
+ throw new Error('Diffusion runtime requires runtime.textEncoder.maxLength.');
103
+ }
104
+ if (layout === 'sana') {
105
+ return { text_encoder: maxLength };
106
+ }
107
+
108
+ const t5MaxLength = runtime?.textEncoder?.t5MaxLength ?? maxLength;
109
+ if (!Number.isFinite(t5MaxLength) || t5MaxLength <= 0) {
110
+ throw new Error('Diffusion runtime requires runtime.textEncoder.t5MaxLength (or runtime.textEncoder.maxLength).');
111
+ }
112
+ return {
113
+ text_encoder: maxLength,
114
+ text_encoder_2: maxLength,
115
+ text_encoder_3: t5MaxLength,
116
+ };
117
+ }
118
+
61
119
  function getTensorSize(shape) {
62
120
  if (!Array.isArray(shape)) return 0;
63
121
  return shape.reduce((acc, value) => acc * value, 1);
@@ -120,6 +178,46 @@ async function readTensorToFloat32(tensor) {
120
178
  return new Float32Array(data);
121
179
  }
122
180
 
181
+ async function applySchedulerStep(latentsTensor, scheduler, stepIndex, timestep, predictionTensor, runtime, options = {}) {
182
+ if (scheduler.type === 'flowmatch_euler') {
183
+ const sigma = scheduler.sigmas[stepIndex];
184
+ const sigmaNext = stepIndex + 1 < scheduler.steps ? scheduler.sigmas[stepIndex + 1] : 0;
185
+ const delta = sigmaNext - sigma;
186
+ const latentSize = getTensorSize(latentsTensor.shape);
187
+ const scale = options.scale ?? runScale;
188
+ const residualAdd = options.residualAdd ?? runResidualAdd;
189
+ const release = options.release ?? releaseBuffer;
190
+
191
+ const scaled = await scale(predictionTensor, delta, { count: latentSize });
192
+ const updated = await residualAdd(latentsTensor, scaled, latentSize, { useVec4: true });
193
+
194
+ release(latentsTensor.buffer);
195
+ release(scaled.buffer);
196
+ release(predictionTensor.buffer);
197
+
198
+ return createTensor(updated.buffer, updated.dtype, [...latentsTensor.shape], 'diffusion_latents');
199
+ }
200
+
201
+ if (scheduler.type === 'scm') {
202
+ const sample = await readTensorToFloat32(latentsTensor);
203
+ const modelOutput = await readTensorToFloat32(predictionTensor);
204
+ releaseBuffer(predictionTensor.buffer);
205
+ releaseBuffer(latentsTensor.buffer);
206
+
207
+ const isFinalStep = stepIndex + 1 >= scheduler.timesteps.length - 1;
208
+ const noise = isFinalStep
209
+ ? null
210
+ : generateNoiseVector(
211
+ sample.length,
212
+ (options.seedBase ?? createRandomSeed()) + stepIndex + 1
213
+ );
214
+ const step = stepScmScheduler(scheduler, modelOutput, timestep, sample, stepIndex, noise);
215
+ return createLatentTensor(step.prevSample, [...latentsTensor.shape], runtime);
216
+ }
217
+
218
+ throw new Error(`Unsupported diffusion scheduler.type "${scheduler.type}".`);
219
+ }
220
+
123
221
  async function applyGuidance(uncond, cond, guidanceScale, size, options = {}) {
124
222
  if (!uncond || !Number.isFinite(guidanceScale) || guidanceScale <= 1) {
125
223
  return cond;
@@ -251,14 +349,17 @@ export class DiffusionPipeline {
251
349
  });
252
350
  }
253
351
 
254
- const text_encoder = await this.weightLoader.loadComponentWeights('text_encoder');
255
- const text_encoder_2 = await this.weightLoader.loadComponentWeights('text_encoder_2');
256
- const text_encoder_3 = await this.weightLoader.loadComponentWeights('text_encoder_3');
352
+ const layout = resolveDiffusionLayout(this.diffusionState?.modelConfig);
353
+ const requiredKeys = getTextEncoderKeysForLayout(layout);
354
+ const weights = {};
355
+ for (const key of requiredKeys) {
356
+ weights[key] = await this.weightLoader.loadComponentWeights(key);
357
+ }
257
358
 
258
359
  this.textEncoderWeights = {
259
- text_encoder,
260
- text_encoder_2,
261
- text_encoder_3,
360
+ text_encoder: weights.text_encoder ?? null,
361
+ text_encoder_2: weights.text_encoder_2 ?? null,
362
+ text_encoder_3: weights.text_encoder_3 ?? null,
262
363
  };
263
364
 
264
365
  return this.textEncoderWeights;
@@ -315,14 +416,9 @@ export class DiffusionPipeline {
315
416
  async generateGPU(request = {}) {
316
417
  const start = performance.now();
317
418
  const runtime = this.diffusionState.runtime;
318
- const clipMaxLength = runtime.textEncoder?.maxLength;
319
- if (!Number.isFinite(clipMaxLength) || clipMaxLength <= 0) {
320
- throw new Error('Diffusion runtime requires runtime.textEncoder.maxLength.');
321
- }
322
- const t5MaxLength = runtime.textEncoder?.t5MaxLength ?? clipMaxLength;
323
- if (!Number.isFinite(t5MaxLength) || t5MaxLength <= 0) {
324
- throw new Error('Diffusion runtime requires runtime.textEncoder.t5MaxLength (or runtime.textEncoder.maxLength).');
325
- }
419
+ const modelConfig = this.diffusionState.modelConfig;
420
+ const layout = resolveDiffusionLayout(modelConfig);
421
+ const tokenizerMaxLengths = buildTokenizerMaxLengths(layout, runtime);
326
422
 
327
423
  const defaultWidth = runtime.latent.width;
328
424
  const defaultHeight = runtime.latent.height;
@@ -346,28 +442,20 @@ export class DiffusionPipeline {
346
442
  throw new Error(`Invalid diffusion steps: ${steps}`);
347
443
  }
348
444
 
349
- const modelConfig = this.diffusionState.modelConfig;
350
445
  if (!modelConfig?.components?.transformer) {
351
446
  throw new Error('Diffusion GPU pipeline requires transformer component config.');
352
447
  }
353
- if (!modelConfig?.components?.text_encoder || !modelConfig?.components?.text_encoder_2 || !modelConfig?.components?.text_encoder_3) {
354
- throw new Error('Diffusion GPU pipeline requires text encoder components (text_encoder, text_encoder_2, text_encoder_3).');
355
- }
356
- if (!this.tokenizers?.text_encoder || !this.tokenizers?.text_encoder_2 || !this.tokenizers?.text_encoder_3) {
357
- throw new Error('Diffusion GPU pipeline requires tokenizers for text_encoder, text_encoder_2, and text_encoder_3.');
448
+ assertLayoutTextEncoderContract(layout, modelConfig, this.tokenizers);
449
+ if (layout === 'sd3') {
450
+ assertClipHiddenActivationSupported(modelConfig?.components?.text_encoder?.config || {});
358
451
  }
359
- assertClipHiddenActivationSupported(modelConfig?.components?.text_encoder?.config || {});
360
452
 
361
453
  const promptStart = performance.now();
362
454
  const encoded = encodePrompt(
363
455
  { prompt: request.prompt ?? '', negativePrompt: request.negativePrompt ?? '' },
364
456
  this.tokenizers || {},
365
457
  {
366
- maxLengthByTokenizer: {
367
- text_encoder: clipMaxLength,
368
- text_encoder_2: clipMaxLength,
369
- text_encoder_3: t5MaxLength,
370
- },
458
+ maxLengthByTokenizer: tokenizerMaxLengths,
371
459
  }
372
460
  );
373
461
 
@@ -410,13 +498,31 @@ export class DiffusionPipeline {
410
498
  const prefillRecorder = canProfileGpu
411
499
  ? new CommandRecorder(getDevice(), 'diffusion_prefill', { profile: true })
412
500
  : null;
413
- const condContext = await projectContext(promptCondition.context, transformerWeights, modelConfig, runtime, {
414
- recorder: prefillRecorder,
415
- });
416
- const uncondContext = shouldUseUncond && negativeCondition
417
- ? await projectContext(negativeCondition.context, transformerWeights, modelConfig, runtime, {
501
+ const condContext = layout === 'sana'
502
+ ? await projectSanaContext(
503
+ promptCondition.context,
504
+ promptCondition.attentionMask,
505
+ transformerWeights,
506
+ transformerConfig,
507
+ runtime,
508
+ { recorder: prefillRecorder }
509
+ )
510
+ : await projectContext(promptCondition.context, transformerWeights, modelConfig, runtime, {
418
511
  recorder: prefillRecorder,
419
- })
512
+ });
513
+ const uncondContext = shouldUseUncond && negativeCondition
514
+ ? layout === 'sana'
515
+ ? await projectSanaContext(
516
+ negativeCondition.context,
517
+ negativeCondition.attentionMask,
518
+ transformerWeights,
519
+ transformerConfig,
520
+ runtime,
521
+ { recorder: prefillRecorder }
522
+ )
523
+ : await projectContext(negativeCondition.context, transformerWeights, modelConfig, runtime, {
524
+ recorder: prefillRecorder,
525
+ })
420
526
  : null;
421
527
  if (prefillRecorder) {
422
528
  prefillRecorder.submit();
@@ -428,11 +534,6 @@ export class DiffusionPipeline {
428
534
  }
429
535
 
430
536
  const scheduler = buildScheduler(runtime.scheduler, steps);
431
- if (scheduler.type !== 'flowmatch_euler') {
432
- throw new Error(
433
- `Diffusion GPU pipeline requires scheduler.type="flowmatch_euler"; got "${scheduler.type}".`
434
- );
435
- }
436
537
  const latentScale = this.diffusionState.latentScale;
437
538
  const latentChannels = this.diffusionState.latentChannels;
438
539
  const { latents, latentWidth, latentHeight } = generateLatents(width, height, latentChannels, latentScale, seed);
@@ -463,9 +564,6 @@ export class DiffusionPipeline {
463
564
  const latentSize = latentChannels * latentHeight * latentWidth;
464
565
  for (let i = 0; i < scheduler.steps; i++) {
465
566
  const timestep = scheduler.timesteps[i];
466
- const sigma = scheduler.sigmas[i];
467
- const sigmaNext = i + 1 < scheduler.steps ? scheduler.sigmas[i + 1] : 0;
468
- const delta = sigmaNext - sigma;
469
567
  const stepRecorder = canProfileGpu
470
568
  ? new CommandRecorder(getDevice(), `diffusion_step_${i}`, { profile: true })
471
569
  : null;
@@ -477,37 +575,71 @@ export class DiffusionPipeline {
477
575
  ? (left, right, count, options) => recordResidualAdd(stepRecorder, left, right, count, options)
478
576
  : runResidualAdd;
479
577
 
480
- const timeCond = await buildTimestepEmbedding(timestep, transformerWeights, modelConfig, runtime, {
481
- dim: timeEmbedDim,
482
- recorder: stepRecorder,
483
- });
484
- const textCond = await buildTimeTextEmbedding(promptCondition.pooled, transformerWeights, modelConfig, runtime, {
485
- recorder: stepRecorder,
486
- });
487
- const timeTextCond = await combineTimeTextEmbeddings(timeCond, textCond, hiddenSize, {
488
- recorder: stepRecorder,
489
- });
490
- const condPred = await runSD3Transformer(latentsTensor, condContext, timeTextCond, transformerWeights, modelConfig, runtime, {
491
- recorder: stepRecorder,
492
- });
493
- releaseStep(timeTextCond.buffer);
578
+ const condPred = layout === 'sana'
579
+ ? await (async () => {
580
+ const timeState = await buildSanaTimestepConditioning(
581
+ timestep * (transformerConfig.timestep_scale ?? 1.0),
582
+ guidanceScale,
583
+ transformerWeights,
584
+ transformerConfig,
585
+ runtime,
586
+ { recorder: stepRecorder }
587
+ );
588
+ return runSanaTransformer(latentsTensor, condContext, timeState, transformerWeights, modelConfig, runtime, {
589
+ recorder: stepRecorder,
590
+ });
591
+ })()
592
+ : await (async () => {
593
+ const timeCond = await buildTimestepEmbedding(timestep, transformerWeights, modelConfig, runtime, {
594
+ dim: timeEmbedDim,
595
+ recorder: stepRecorder,
596
+ });
597
+ const textCond = await buildTimeTextEmbedding(promptCondition.pooled, transformerWeights, modelConfig, runtime, {
598
+ recorder: stepRecorder,
599
+ });
600
+ const timeTextCond = await combineTimeTextEmbeddings(timeCond, textCond, hiddenSize, {
601
+ recorder: stepRecorder,
602
+ });
603
+ const output = await runSD3Transformer(latentsTensor, condContext, timeTextCond, transformerWeights, modelConfig, runtime, {
604
+ recorder: stepRecorder,
605
+ });
606
+ releaseStep(timeTextCond.buffer);
607
+ return output;
608
+ })();
494
609
 
495
610
  let pred = condPred;
496
611
  if (shouldUseUncond && uncondContext && negativeCondition) {
497
- const timeUncond = await buildTimestepEmbedding(timestep, transformerWeights, modelConfig, runtime, {
498
- dim: timeEmbedDim,
499
- recorder: stepRecorder,
500
- });
501
- const textUncond = await buildTimeTextEmbedding(negativeCondition.pooled, transformerWeights, modelConfig, runtime, {
502
- recorder: stepRecorder,
503
- });
504
- const timeTextUncond = await combineTimeTextEmbeddings(timeUncond, textUncond, hiddenSize, {
505
- recorder: stepRecorder,
506
- });
507
- const uncondPred = await runSD3Transformer(latentsTensor, uncondContext, timeTextUncond, transformerWeights, modelConfig, runtime, {
508
- recorder: stepRecorder,
509
- });
510
- releaseStep(timeTextUncond.buffer);
612
+ const uncondPred = layout === 'sana'
613
+ ? await (async () => {
614
+ const timeState = await buildSanaTimestepConditioning(
615
+ timestep * (transformerConfig.timestep_scale ?? 1.0),
616
+ guidanceScale,
617
+ transformerWeights,
618
+ transformerConfig,
619
+ runtime,
620
+ { recorder: stepRecorder }
621
+ );
622
+ return runSanaTransformer(latentsTensor, uncondContext, timeState, transformerWeights, modelConfig, runtime, {
623
+ recorder: stepRecorder,
624
+ });
625
+ })()
626
+ : await (async () => {
627
+ const timeUncond = await buildTimestepEmbedding(timestep, transformerWeights, modelConfig, runtime, {
628
+ dim: timeEmbedDim,
629
+ recorder: stepRecorder,
630
+ });
631
+ const textUncond = await buildTimeTextEmbedding(negativeCondition.pooled, transformerWeights, modelConfig, runtime, {
632
+ recorder: stepRecorder,
633
+ });
634
+ const timeTextUncond = await combineTimeTextEmbeddings(timeUncond, textUncond, hiddenSize, {
635
+ recorder: stepRecorder,
636
+ });
637
+ const output = await runSD3Transformer(latentsTensor, uncondContext, timeTextUncond, transformerWeights, modelConfig, runtime, {
638
+ recorder: stepRecorder,
639
+ });
640
+ releaseStep(timeTextUncond.buffer);
641
+ return output;
642
+ })();
511
643
  pred = await applyGuidance(uncondPred, condPred, guidanceScale, latentSize, {
512
644
  recorder: stepRecorder,
513
645
  release: releaseStep,
@@ -516,14 +648,20 @@ export class DiffusionPipeline {
516
648
  releaseStep(condPred.buffer);
517
649
  }
518
650
 
519
- const scaled = await scale(pred, delta, { count: latentSize });
520
- const updated = await residualAdd(latentsTensor, scaled, latentSize, { useVec4: true });
521
-
522
- releaseStep(latentsTensor.buffer);
523
- releaseStep(scaled.buffer);
524
- releaseStep(pred.buffer);
525
-
526
- latentsTensor = createTensor(updated.buffer, updated.dtype, [latentChannels, latentHeight, latentWidth], 'sd3_latents');
651
+ latentsTensor = await applySchedulerStep(
652
+ latentsTensor,
653
+ scheduler,
654
+ i,
655
+ timestep,
656
+ pred,
657
+ runtime,
658
+ {
659
+ scale,
660
+ residualAdd,
661
+ release: releaseStep,
662
+ seedBase: seed,
663
+ }
664
+ );
527
665
 
528
666
  if (stepRecorder) {
529
667
  stepRecorder.submit();
@@ -0,0 +1,53 @@
1
+ import type { Tensor } from '../../../gpu/tensor.js';
2
+ import type { CommandRecorder } from '../../../gpu/command-recorder.js';
3
+
4
+ export interface SanaTimestepState {
5
+ modulation: Tensor;
6
+ embeddedTimestep: Tensor;
7
+ }
8
+
9
+ export interface SanaTransformerOptions {
10
+ recorder?: CommandRecorder | null;
11
+ }
12
+
13
+ export declare function buildSanaTimestepConditioning(
14
+ timestep: number,
15
+ guidanceScale: number,
16
+ weightsEntry: any,
17
+ config: any,
18
+ runtime: any,
19
+ options?: SanaTransformerOptions
20
+ ): Promise<SanaTimestepState>;
21
+
22
+ export declare function projectSanaContext(
23
+ context: Tensor,
24
+ attentionMask: Uint32Array | null | undefined,
25
+ weightsEntry: any,
26
+ config: any,
27
+ runtime: any,
28
+ options?: SanaTransformerOptions
29
+ ): Promise<Tensor>;
30
+
31
+ export declare function runSanaTransformer(
32
+ latents: Tensor,
33
+ context: Tensor,
34
+ timeState: SanaTimestepState,
35
+ weightsEntry: any,
36
+ modelConfig: any,
37
+ runtime: any,
38
+ options?: SanaTransformerOptions
39
+ ): Promise<Tensor>;
40
+
41
+ export declare function buildSanaConditioning(
42
+ context: Tensor,
43
+ attentionMask: Uint32Array | null | undefined,
44
+ timestep: number,
45
+ guidanceScale: number,
46
+ weightsEntry: any,
47
+ modelConfig: any,
48
+ runtime: any,
49
+ options?: SanaTransformerOptions
50
+ ): Promise<{
51
+ context: Tensor;
52
+ timeState: SanaTimestepState;
53
+ }>;