@simulatte/doppler 0.1.4 → 0.1.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (103)
  1. package/README.md +4 -3
  2. package/package.json +25 -4
  3. package/src/client/doppler-api.browser.d.ts +1 -0
  4. package/src/client/doppler-api.browser.js +288 -0
  5. package/src/client/doppler-api.js +1 -1
  6. package/src/client/doppler-provider/types.js +1 -1
  7. package/src/config/execution-contract-check.d.ts +33 -0
  8. package/src/config/execution-contract-check.js +72 -0
  9. package/src/config/execution-v0-contract-check.d.ts +94 -0
  10. package/src/config/execution-v0-contract-check.js +251 -0
  11. package/src/config/execution-v0-graph-contract-check.d.ts +20 -0
  12. package/src/config/execution-v0-graph-contract-check.js +64 -0
  13. package/src/config/kernel-path-contract-check.d.ts +76 -0
  14. package/src/config/kernel-path-contract-check.js +479 -0
  15. package/src/config/kernel-path-loader.d.ts +16 -0
  16. package/src/config/kernel-path-loader.js +54 -0
  17. package/src/config/kernels/kernel-ref-digests.js +12 -0
  18. package/src/config/kernels/registry.json +556 -0
  19. package/src/config/loader.js +50 -46
  20. package/src/config/merge-contract-check.d.ts +16 -0
  21. package/src/config/merge-contract-check.js +321 -0
  22. package/src/config/merge-helpers.d.ts +58 -0
  23. package/src/config/merge-helpers.js +54 -0
  24. package/src/config/merge.js +3 -6
  25. package/src/config/presets/models/janus-text.json +2 -0
  26. package/src/config/quantization-contract-check.d.ts +12 -0
  27. package/src/config/quantization-contract-check.js +91 -0
  28. package/src/config/required-inference-fields-contract-check.d.ts +24 -0
  29. package/src/config/required-inference-fields-contract-check.js +231 -0
  30. package/src/config/schema/browser-suite-metrics.schema.d.ts +17 -0
  31. package/src/config/schema/browser-suite-metrics.schema.js +46 -0
  32. package/src/config/schema/conversion-report.schema.d.ts +40 -0
  33. package/src/config/schema/conversion-report.schema.js +108 -0
  34. package/src/config/schema/doppler.schema.js +12 -18
  35. package/src/config/schema/index.d.ts +22 -0
  36. package/src/config/schema/index.js +18 -0
  37. package/src/converter/core.d.ts +10 -0
  38. package/src/converter/core.js +27 -2
  39. package/src/converter/parsers/diffusion.js +63 -3
  40. package/src/gpu/kernels/depthwise_conv2d.d.ts +29 -0
  41. package/src/gpu/kernels/depthwise_conv2d.js +98 -0
  42. package/src/gpu/kernels/depthwise_conv2d.wgsl +58 -0
  43. package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +62 -0
  44. package/src/gpu/kernels/grouped_pointwise_conv2d.d.ts +27 -0
  45. package/src/gpu/kernels/grouped_pointwise_conv2d.js +92 -0
  46. package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +47 -0
  47. package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +51 -0
  48. package/src/gpu/kernels/index.d.ts +30 -0
  49. package/src/gpu/kernels/index.js +25 -0
  50. package/src/gpu/kernels/relu.d.ts +18 -0
  51. package/src/gpu/kernels/relu.js +45 -0
  52. package/src/gpu/kernels/relu.wgsl +21 -0
  53. package/src/gpu/kernels/relu_f16.wgsl +23 -0
  54. package/src/gpu/kernels/repeat_channels.d.ts +21 -0
  55. package/src/gpu/kernels/repeat_channels.js +60 -0
  56. package/src/gpu/kernels/repeat_channels.wgsl +29 -0
  57. package/src/gpu/kernels/repeat_channels_f16.wgsl +31 -0
  58. package/src/gpu/kernels/sana_linear_attention.d.ts +27 -0
  59. package/src/gpu/kernels/sana_linear_attention.js +122 -0
  60. package/src/gpu/kernels/sana_linear_attention_apply.wgsl +44 -0
  61. package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +47 -0
  62. package/src/gpu/kernels/sana_linear_attention_summary.wgsl +47 -0
  63. package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +49 -0
  64. package/src/index-browser.d.ts +1 -1
  65. package/src/index-browser.js +2 -2
  66. package/src/index.js +1 -1
  67. package/src/inference/browser-harness.js +62 -22
  68. package/src/inference/pipelines/diffusion/init.js +14 -0
  69. package/src/inference/pipelines/diffusion/pipeline.js +206 -77
  70. package/src/inference/pipelines/diffusion/sana-transformer.d.ts +53 -0
  71. package/src/inference/pipelines/diffusion/sana-transformer.js +738 -0
  72. package/src/inference/pipelines/diffusion/scheduler.d.ts +17 -1
  73. package/src/inference/pipelines/diffusion/scheduler.js +91 -3
  74. package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +6 -4
  75. package/src/inference/pipelines/diffusion/text-encoder-gpu.js +270 -0
  76. package/src/inference/pipelines/diffusion/text-encoder.js +18 -1
  77. package/src/inference/pipelines/diffusion/types.d.ts +4 -0
  78. package/src/inference/pipelines/diffusion/vae.js +782 -78
  79. package/src/inference/pipelines/text/config.d.ts +5 -0
  80. package/src/inference/pipelines/text/config.js +1 -1
  81. package/src/inference/pipelines/text/execution-v0.js +14 -93
  82. package/src/rules/execution-rules-contract-check.d.ts +17 -0
  83. package/src/rules/execution-rules-contract-check.js +245 -0
  84. package/src/rules/kernels/depthwise-conv2d.rules.json +6 -0
  85. package/src/rules/kernels/grouped-pointwise-conv2d.rules.json +6 -0
  86. package/src/rules/kernels/relu.rules.json +6 -0
  87. package/src/rules/kernels/repeat-channels.rules.json +6 -0
  88. package/src/rules/kernels/sana-linear-attention.rules.json +6 -0
  89. package/src/rules/layer-pattern-contract-check.d.ts +17 -0
  90. package/src/rules/layer-pattern-contract-check.js +231 -0
  91. package/src/rules/rule-registry.d.ts +28 -0
  92. package/src/rules/rule-registry.js +38 -0
  93. package/src/tooling/conversion-config-materializer.d.ts +24 -0
  94. package/src/tooling/conversion-config-materializer.js +99 -0
  95. package/src/tooling/lean-execution-contract-runner.d.ts +43 -0
  96. package/src/tooling/lean-execution-contract-runner.js +158 -0
  97. package/src/tooling/node-convert.d.ts +10 -0
  98. package/src/tooling/node-converter.js +59 -0
  99. package/src/tooling/node-webgpu.js +9 -9
  100. package/src/version.d.ts +2 -0
  101. package/src/version.js +2 -0
  102. package/tools/convert-safetensors-node.js +47 -0
  103. package/tools/doppler-cli.js +115 -1
@@ -15,10 +15,16 @@ import {
15
15
  getActiveKernelPathSource,
16
16
  getActiveKernelPathPolicy,
17
17
  } from '../config/kernel-path-loader.js';
18
- import { selectRuleValue } from '../rules/rule-registry.js';
18
+ import {
19
+ getInferenceLayerPatternContractArtifact,
20
+ selectRuleValue,
21
+ } from '../rules/rule-registry.js';
19
22
  import { mergeRuntimeValues } from '../config/runtime-merge.js';
20
23
  import { isPlainObject } from '../utils/plain-object.js';
24
+ import { validateBrowserSuiteMetrics } from '../config/schema/browser-suite-metrics.schema.js';
21
25
  import { validateTrainingMetricsReport } from '../config/schema/training-metrics.schema.js';
26
+ import { buildExecutionContractArtifact } from '../config/execution-contract-check.js';
27
+ import { buildManifestRequiredInferenceFieldsArtifact } from '../config/required-inference-fields-contract-check.js';
22
28
 
23
29
  const TRAINING_SUITE_MODULE_PATH = '../training/suite.js';
24
30
  const NODE_SOURCE_RUNTIME_MODULE_PATH = '../tooling/node-source-runtime.js';
@@ -41,6 +47,29 @@ async function runTrainingBenchSuite(options = {}) {
41
47
  return module.runTrainingBenchSuite(options);
42
48
  }
43
49
 
50
+ function buildSuiteContractMetrics(suite, baseMetrics, manifest) {
51
+ const executionContractArtifact = buildExecutionContractArtifact(manifest);
52
+ const executionV0GraphContractArtifact = executionContractArtifact?.executionV0?.graph ?? null;
53
+ const layerPatternContractArtifact = getInferenceLayerPatternContractArtifact();
54
+ const requiredInferenceFieldsArtifact = manifest?.modelType === 'transformer'
55
+ && isPlainObject(manifest?.inference?.attention)
56
+ ? buildManifestRequiredInferenceFieldsArtifact(
57
+ manifest?.inference ?? null,
58
+ `${manifest?.modelId ?? 'unknown'}.inference`
59
+ )
60
+ : null;
61
+ return validateBrowserSuiteMetrics({
62
+ ...baseMetrics,
63
+ schemaVersion: 1,
64
+ source: 'doppler',
65
+ suite,
66
+ ...(executionContractArtifact ? { executionContractArtifact } : {}),
67
+ executionV0GraphContractArtifact,
68
+ layerPatternContractArtifact,
69
+ requiredInferenceFieldsArtifact,
70
+ });
71
+ }
72
+
44
73
  function parseReportTimestamp(rawTimestamp, label = 'timestamp') {
45
74
  if (rawTimestamp == null) {
46
75
  return null;
@@ -1824,6 +1853,11 @@ async function runInferenceSuite(options = {}) {
1824
1853
  source: 'doppler',
1825
1854
  prefillSemantics: 'internal_prefill_phase',
1826
1855
  });
1856
+ const metricsWithContracts = buildSuiteContractMetrics(
1857
+ options.suiteName || 'inference',
1858
+ metrics,
1859
+ harness.manifest
1860
+ );
1827
1861
  return {
1828
1862
  ...summary,
1829
1863
  modelId: options.modelId || harness.manifest?.modelId || 'unknown',
@@ -1841,7 +1875,7 @@ async function runInferenceSuite(options = {}) {
1841
1875
  timing,
1842
1876
  timingDiagnostics,
1843
1877
  output,
1844
- metrics,
1878
+ metrics: metricsWithContracts,
1845
1879
  memoryStats,
1846
1880
  deviceInfo: resolveDeviceInfo(),
1847
1881
  pipeline: options.keepPipeline ? harness.pipeline : null,
@@ -2218,6 +2252,7 @@ async function runBenchSuite(options = {}) {
2218
2252
  source: 'doppler',
2219
2253
  prefillSemantics: 'internal_prefill_phase',
2220
2254
  });
2255
+ const metricsWithContracts = buildSuiteContractMetrics('bench', metrics, harness.manifest);
2221
2256
  return {
2222
2257
  ...summary,
2223
2258
  modelId: options.modelId || harness.manifest?.modelId || 'unknown',
@@ -2235,7 +2270,7 @@ async function runBenchSuite(options = {}) {
2235
2270
  timing,
2236
2271
  timingDiagnostics,
2237
2272
  output,
2238
- metrics,
2273
+ metrics: metricsWithContracts,
2239
2274
  memoryStats,
2240
2275
  deviceInfo: resolveDeviceInfo(),
2241
2276
  pipeline: options.keepPipeline ? harness.pipeline : null,
@@ -2396,25 +2431,9 @@ async function runDiffusionSuite(options = {}) {
2396
2431
  source: 'doppler',
2397
2432
  prefillSemantics: 'internal_prefill_phase',
2398
2433
  });
2399
-
2400
- return {
2401
- ...summary,
2402
- modelId: options.modelId || harness.manifest?.modelId || 'unknown',
2403
- cacheMode,
2404
- loadMode,
2405
- env: {
2406
- library: 'doppler',
2407
- runtime: 'browser',
2408
- device: 'webgpu',
2409
- browserUserAgent: typeof navigator !== 'undefined' ? (navigator.userAgent || null) : null,
2410
- browserPlatform: typeof navigator !== 'undefined' ? (navigator.platform || null) : null,
2411
- browserLanguage: typeof navigator !== 'undefined' ? (navigator.language || null) : null,
2412
- browserVendor: typeof navigator !== 'undefined' ? (navigator.vendor || null) : null,
2413
- },
2414
- timing,
2415
- timingDiagnostics,
2416
- output,
2417
- metrics: {
2434
+ const metricsWithContracts = buildSuiteContractMetrics(
2435
+ 'diffusion',
2436
+ {
2418
2437
  warmupRuns,
2419
2438
  timedRuns,
2420
2439
  width,
@@ -2439,6 +2458,27 @@ async function runDiffusionSuite(options = {}) {
2439
2458
  gpu: gpuStats,
2440
2459
  performanceArtifact: diffusionPerformanceArtifact,
2441
2460
  },
2461
+ harness.manifest
2462
+ );
2463
+
2464
+ return {
2465
+ ...summary,
2466
+ modelId: options.modelId || harness.manifest?.modelId || 'unknown',
2467
+ cacheMode,
2468
+ loadMode,
2469
+ env: {
2470
+ library: 'doppler',
2471
+ runtime: 'browser',
2472
+ device: 'webgpu',
2473
+ browserUserAgent: typeof navigator !== 'undefined' ? (navigator.userAgent || null) : null,
2474
+ browserPlatform: typeof navigator !== 'undefined' ? (navigator.platform || null) : null,
2475
+ browserLanguage: typeof navigator !== 'undefined' ? (navigator.language || null) : null,
2476
+ browserVendor: typeof navigator !== 'undefined' ? (navigator.vendor || null) : null,
2477
+ },
2478
+ timing,
2479
+ timingDiagnostics,
2480
+ output,
2481
+ metrics: metricsWithContracts,
2442
2482
  memoryStats,
2443
2483
  deviceInfo: resolveDeviceInfo(),
2444
2484
  pipeline: options.keepPipeline ? harness.pipeline : null,
@@ -1,5 +1,7 @@
1
1
  import { DEFAULT_DIFFUSION_CONFIG } from '../../../config/schema/index.js';
2
2
 
3
// Diffusion layouts that have a GPU runtime implementation. A manifest may
// declare other layouts, but initializeDiffusion rejects them.
const DIFFUSION_RUNTIME_LAYOUT_NAMES = ['sd3', 'flux', 'sana'];
const SUPPORTED_DIFFUSION_RUNTIME_LAYOUTS = new Set(DIFFUSION_RUNTIME_LAYOUT_NAMES);
4
+
3
5
  function mergeSection(base, override) {
4
6
  if (!override) return { ...base };
5
7
  return { ...base, ...override };
@@ -38,6 +40,9 @@ function resolveSchedulerType(modelScheduler, runtimeScheduler) {
38
40
  if (modelClass === 'FlowMatchEulerDiscreteScheduler') {
39
41
  return 'flowmatch_euler';
40
42
  }
43
+ if (modelClass === 'SCMScheduler') {
44
+ return 'scm';
45
+ }
41
46
  if (modelClass === 'EulerDiscreteScheduler') {
42
47
  return 'euler';
43
48
  }
@@ -58,6 +63,8 @@ function mergeSchedulerConfig(modelConfig, runtimeScheduler) {
58
63
  type,
59
64
  numTrainTimesteps: modelScheduler.num_train_timesteps ?? runtimeScheduler.numTrainTimesteps,
60
65
  shift: modelScheduler.shift ?? runtimeScheduler.shift,
66
+ predictionType: modelScheduler.prediction_type ?? runtimeScheduler.predictionType,
67
+ sigmaData: modelScheduler.sigma_data ?? runtimeScheduler.sigmaData,
61
68
  };
62
69
  }
63
70
 
@@ -95,6 +102,13 @@ export function initializeDiffusion(manifest, runtimeConfig) {
95
102
  }
96
103
  throw new Error('Diffusion manifest missing config.diffusion model contract.');
97
104
  }
105
+ const layout = modelConfig.layout;
106
+ if (layout && !SUPPORTED_DIFFUSION_RUNTIME_LAYOUTS.has(layout)) {
107
+ throw new Error(
108
+ `Diffusion layout "${layout}" is recognized in the manifest, but the GPU runtime is not implemented yet. ` +
109
+ 'Supported runtime layouts: sd3, flux, sana.'
110
+ );
111
+ }
98
112
 
99
113
  const runtimeBase = mergeDiffusionConfig(DEFAULT_DIFFUSION_CONFIG, runtimeConfig?.inference?.diffusion);
100
114
  const runtime = {
@@ -14,10 +14,11 @@ import {
14
14
  projectContext,
15
15
  assertClipHiddenActivationSupported,
16
16
  } from './text-encoder-gpu.js';
17
- import { buildScheduler } from './scheduler.js';
17
+ import { buildScheduler, stepScmScheduler } from './scheduler.js';
18
18
  import { decodeLatents } from './vae.js';
19
19
  import { createDiffusionWeightLoader } from './weights.js';
20
20
  import { runSD3Transformer } from './sd3-transformer.js';
21
+ import { runSanaTransformer, buildSanaTimestepConditioning, projectSanaContext } from './sana-transformer.js';
21
22
  import { createSD3WeightResolver } from './sd3-weights.js';
22
23
  import { createTensor, dtypeBytes } from '../../../gpu/tensor.js';
23
24
  import { acquireBuffer, releaseBuffer, readBuffer } from '../../../memory/buffer-pool.js';
@@ -27,6 +28,8 @@ import { runResidualAdd, runScale, recordResidualAdd, recordScale } from '../../
27
28
  import { f16ToF32 } from '../../../loader/dtype-utils.js';
28
29
 
29
30
  const SUPPORTED_DIFFUSION_BACKEND_PIPELINES = new Set(['gpu']);
31
+ const SD3_TEXT_ENCODER_KEYS = ['text_encoder', 'text_encoder_2', 'text_encoder_3'];
32
+ const SANA_TEXT_ENCODER_KEYS = ['text_encoder'];
30
33
 
31
34
  function createRandomSeed() {
32
35
  if (typeof crypto !== 'undefined' && typeof crypto.getRandomValues === 'function') {
@@ -58,6 +61,49 @@ function extractTokenSet(tokensByEncoder, key) {
58
61
  return output;
59
62
  }
60
63
 
64
/**
 * Resolve the diffusion layout declared by a model config.
 *
 * Any falsy layout (missing, null, or empty string) falls back to 'sd3'.
 * initializeDiffusion only validates truthy layout values, so treating `''`
 * as "no layout" here keeps the two consistent. Otherwise an empty layout would
 * pass init unchecked and then take the non-SD3 path, which skips the SD3 CLIP
 * activation check.
 *
 * @param {object|null|undefined} modelConfig Diffusion model config (may lack `layout`).
 * @returns {string} The declared layout, or 'sd3' when none is set.
 */
function resolveDiffusionLayout(modelConfig) {
  return modelConfig?.layout || 'sd3';
}
67
+
68
/**
 * Return the text-encoder component keys a layout requires.
 *
 * 'sana' maps to the single-encoder key list; every other layout maps to the
 * three-encoder SD3 key list. The module-level arrays are returned as-is, not
 * copied.
 *
 * @param {string} layout Resolved diffusion layout.
 * @returns {string[]} Required text-encoder component keys.
 */
function getTextEncoderKeysForLayout(layout) {
  return layout === 'sana' ? SANA_TEXT_ENCODER_KEYS : SD3_TEXT_ENCODER_KEYS;
}
74
+
75
+ function assertLayoutTextEncoderContract(layout, modelConfig, tokenizers) {
76
+ const requiredKeys = getTextEncoderKeysForLayout(layout);
77
+ for (const key of requiredKeys) {
78
+ if (!modelConfig?.components?.[key]) {
79
+ throw new Error(`Diffusion GPU pipeline requires component "${key}" for layout "${layout}".`);
80
+ }
81
+ if (!tokenizers?.[key]) {
82
+ throw new Error(`Diffusion GPU pipeline requires tokenizer "${key}" for layout "${layout}".`);
83
+ }
84
+ }
85
+ }
86
+
87
/**
 * Build the per-tokenizer maximum sequence lengths for a layout.
 *
 * `runtime.textEncoder.maxLength` is always required and must be a finite
 * positive number. Sana uses it for its single encoder. Other layouts also use
 * it for `text_encoder` and `text_encoder_2`. Their `text_encoder_3` uses
 * `runtime.textEncoder.t5MaxLength`, falling back to `maxLength` when that is
 * null or undefined.
 *
 * @param {string} layout Resolved diffusion layout.
 * @param {object} runtime Diffusion runtime config.
 * @returns {Record<string, number>} Max length keyed by tokenizer/component name.
 * @throws {Error} When a required length is missing, non-finite, or not positive.
 */
function buildTokenizerMaxLengths(layout, runtime) {
  const textEncoder = runtime?.textEncoder;
  const isValidLength = (value) => Number.isFinite(value) && value > 0;

  const clipLength = textEncoder?.maxLength;
  if (!isValidLength(clipLength)) {
    throw new Error('Diffusion runtime requires runtime.textEncoder.maxLength.');
  }
  if (layout === 'sana') {
    return { text_encoder: clipLength };
  }

  const t5Length = textEncoder?.t5MaxLength ?? clipLength;
  if (!isValidLength(t5Length)) {
    throw new Error('Diffusion runtime requires runtime.textEncoder.t5MaxLength (or runtime.textEncoder.maxLength).');
  }
  return {
    text_encoder: clipLength,
    text_encoder_2: clipLength,
    text_encoder_3: t5Length,
  };
}
106
+
61
107
  function getTensorSize(shape) {
62
108
  if (!Array.isArray(shape)) return 0;
63
109
  return shape.reduce((acc, value) => acc * value, 1);
@@ -120,6 +166,49 @@ async function readTensorToFloat32(tensor) {
120
166
  return new Float32Array(data);
121
167
  }
122
168
 
169
/**
 * Advance the latents by one scheduler step and return the new latents tensor.
 *
 * Ownership: `latentsTensor` and `predictionTensor` are consumed. Their buffers
 * are handed to `options.release` (default releaseBuffer), so the caller must
 * not use them afterwards.
 *
 * - 'flowmatch_euler': latents + (sigmas[i+1] - sigmas[i]) * prediction, computed
 *   with `options.scale` / `options.residualAdd`. The sigma after the last step
 *   is taken as 0.
 * - 'scm': latents and prediction are read back into Float32Arrays and stepped
 *   with stepScmScheduler; the result is re-uploaded via createLatentTensor.
 *   Fresh noise is generated for every step except the final one, seeded with
 *   `(options.seedBase ?? random) + stepIndex + 1`.
 *
 * @param {object} latentsTensor Current latents (consumed).
 * @param {object} scheduler Scheduler built by buildScheduler.
 * @param {number} stepIndex Zero-based step index.
 * @param {number} timestep Timestep for this step.
 * @param {object} predictionTensor Model prediction for this step (consumed).
 * @param {object} runtime Diffusion runtime config (runtime.latent.* drives scm noise).
 * @param {object} [options] Optional hooks: scale, residualAdd, release, seedBase.
 * @returns {Promise<object>} The updated latents tensor.
 * @throws {Error} For an unsupported scheduler type, or when the scm noise size
 *   does not match the latents.
 */
async function applySchedulerStep(latentsTensor, scheduler, stepIndex, timestep, predictionTensor, runtime, options = {}) {
  // Both branches free consumed buffers through the same hook, so a
  // recorder-aware release supplied by the caller applies to every scheduler.
  const release = options.release ?? releaseBuffer;
  const latentShape = [...latentsTensor.shape];

  if (scheduler.type === 'flowmatch_euler') {
    const sigma = scheduler.sigmas[stepIndex];
    const sigmaNext = stepIndex + 1 < scheduler.steps ? scheduler.sigmas[stepIndex + 1] : 0;
    const delta = sigmaNext - sigma;
    const latentSize = getTensorSize(latentsTensor.shape);
    const scale = options.scale ?? runScale;
    const residualAdd = options.residualAdd ?? runResidualAdd;

    const scaled = await scale(predictionTensor, delta, { count: latentSize });
    const updated = await residualAdd(latentsTensor, scaled, latentSize, { useVec4: true });

    release(latentsTensor.buffer);
    release(scaled.buffer);
    release(predictionTensor.buffer);

    return createTensor(updated.buffer, updated.dtype, latentShape, 'diffusion_latents');
  }

  if (scheduler.type === 'scm') {
    // NOTE(review): these read-backs happen before the caller submits its
    // per-step recorder (when profiling). Confirm readTensorToFloat32 observes
    // the recorded transformer output in that case.
    const sample = await readTensorToFloat32(latentsTensor);
    const modelOutput = await readTensorToFloat32(predictionTensor);
    release(predictionTensor.buffer);
    release(latentsTensor.buffer);

    // NOTE(review): assumes scm `timesteps` has one more entry than the number
    // of steps (the terminal timestep) — confirm against buildScheduler.
    const isFinalStep = stepIndex + 1 >= scheduler.timesteps.length - 1;
    let noise = null;
    if (!isFinalStep) {
      noise = generateLatents(
        runtime.latent.width,
        runtime.latent.height,
        runtime.latent.channels,
        runtime.latent.scale,
        (options.seedBase ?? createRandomSeed()) + stepIndex + 1
      ).latents;
      // runtime.latent.width/height are the default image size, not
      // necessarily the size of the latents being denoised (a request may
      // override width/height). Fail loudly rather than let stepScmScheduler
      // combine arrays of different lengths.
      if (noise.length !== sample.length) {
        throw new Error(
          `SCM scheduler noise has ${noise.length} elements but the latents have ${sample.length}; ` +
          'runtime.latent dimensions do not match the latents being denoised.'
        );
      }
    }
    const step = stepScmScheduler(scheduler, modelOutput, timestep, sample, stepIndex, noise);
    return createLatentTensor(step.prevSample, latentShape, runtime);
  }

  throw new Error(`Unsupported diffusion scheduler.type "${scheduler.type}".`);
}
211
+
123
212
  async function applyGuidance(uncond, cond, guidanceScale, size, options = {}) {
124
213
  if (!uncond || !Number.isFinite(guidanceScale) || guidanceScale <= 1) {
125
214
  return cond;
@@ -251,14 +340,17 @@ export class DiffusionPipeline {
251
340
  });
252
341
  }
253
342
 
254
- const text_encoder = await this.weightLoader.loadComponentWeights('text_encoder');
255
- const text_encoder_2 = await this.weightLoader.loadComponentWeights('text_encoder_2');
256
- const text_encoder_3 = await this.weightLoader.loadComponentWeights('text_encoder_3');
343
+ const layout = resolveDiffusionLayout(this.diffusionState?.modelConfig);
344
+ const requiredKeys = getTextEncoderKeysForLayout(layout);
345
+ const weights = {};
346
+ for (const key of requiredKeys) {
347
+ weights[key] = await this.weightLoader.loadComponentWeights(key);
348
+ }
257
349
 
258
350
  this.textEncoderWeights = {
259
- text_encoder,
260
- text_encoder_2,
261
- text_encoder_3,
351
+ text_encoder: weights.text_encoder ?? null,
352
+ text_encoder_2: weights.text_encoder_2 ?? null,
353
+ text_encoder_3: weights.text_encoder_3 ?? null,
262
354
  };
263
355
 
264
356
  return this.textEncoderWeights;
@@ -315,14 +407,9 @@ export class DiffusionPipeline {
315
407
  async generateGPU(request = {}) {
316
408
  const start = performance.now();
317
409
  const runtime = this.diffusionState.runtime;
318
- const clipMaxLength = runtime.textEncoder?.maxLength;
319
- if (!Number.isFinite(clipMaxLength) || clipMaxLength <= 0) {
320
- throw new Error('Diffusion runtime requires runtime.textEncoder.maxLength.');
321
- }
322
- const t5MaxLength = runtime.textEncoder?.t5MaxLength ?? clipMaxLength;
323
- if (!Number.isFinite(t5MaxLength) || t5MaxLength <= 0) {
324
- throw new Error('Diffusion runtime requires runtime.textEncoder.t5MaxLength (or runtime.textEncoder.maxLength).');
325
- }
410
+ const modelConfig = this.diffusionState.modelConfig;
411
+ const layout = resolveDiffusionLayout(modelConfig);
412
+ const tokenizerMaxLengths = buildTokenizerMaxLengths(layout, runtime);
326
413
 
327
414
  const defaultWidth = runtime.latent.width;
328
415
  const defaultHeight = runtime.latent.height;
@@ -346,28 +433,20 @@ export class DiffusionPipeline {
346
433
  throw new Error(`Invalid diffusion steps: ${steps}`);
347
434
  }
348
435
 
349
- const modelConfig = this.diffusionState.modelConfig;
350
436
  if (!modelConfig?.components?.transformer) {
351
437
  throw new Error('Diffusion GPU pipeline requires transformer component config.');
352
438
  }
353
- if (!modelConfig?.components?.text_encoder || !modelConfig?.components?.text_encoder_2 || !modelConfig?.components?.text_encoder_3) {
354
- throw new Error('Diffusion GPU pipeline requires text encoder components (text_encoder, text_encoder_2, text_encoder_3).');
439
+ assertLayoutTextEncoderContract(layout, modelConfig, this.tokenizers);
440
+ if (layout === 'sd3') {
441
+ assertClipHiddenActivationSupported(modelConfig?.components?.text_encoder?.config || {});
355
442
  }
356
- if (!this.tokenizers?.text_encoder || !this.tokenizers?.text_encoder_2 || !this.tokenizers?.text_encoder_3) {
357
- throw new Error('Diffusion GPU pipeline requires tokenizers for text_encoder, text_encoder_2, and text_encoder_3.');
358
- }
359
- assertClipHiddenActivationSupported(modelConfig?.components?.text_encoder?.config || {});
360
443
 
361
444
  const promptStart = performance.now();
362
445
  const encoded = encodePrompt(
363
446
  { prompt: request.prompt ?? '', negativePrompt: request.negativePrompt ?? '' },
364
447
  this.tokenizers || {},
365
448
  {
366
- maxLengthByTokenizer: {
367
- text_encoder: clipMaxLength,
368
- text_encoder_2: clipMaxLength,
369
- text_encoder_3: t5MaxLength,
370
- },
449
+ maxLengthByTokenizer: tokenizerMaxLengths,
371
450
  }
372
451
  );
373
452
 
@@ -410,13 +489,31 @@ export class DiffusionPipeline {
410
489
  const prefillRecorder = canProfileGpu
411
490
  ? new CommandRecorder(getDevice(), 'diffusion_prefill', { profile: true })
412
491
  : null;
413
- const condContext = await projectContext(promptCondition.context, transformerWeights, modelConfig, runtime, {
414
- recorder: prefillRecorder,
415
- });
416
- const uncondContext = shouldUseUncond && negativeCondition
417
- ? await projectContext(negativeCondition.context, transformerWeights, modelConfig, runtime, {
492
+ const condContext = layout === 'sana'
493
+ ? await projectSanaContext(
494
+ promptCondition.context,
495
+ promptCondition.attentionMask,
496
+ transformerWeights,
497
+ transformerConfig,
498
+ runtime,
499
+ { recorder: prefillRecorder }
500
+ )
501
+ : await projectContext(promptCondition.context, transformerWeights, modelConfig, runtime, {
418
502
  recorder: prefillRecorder,
419
- })
503
+ });
504
+ const uncondContext = shouldUseUncond && negativeCondition
505
+ ? layout === 'sana'
506
+ ? await projectSanaContext(
507
+ negativeCondition.context,
508
+ negativeCondition.attentionMask,
509
+ transformerWeights,
510
+ transformerConfig,
511
+ runtime,
512
+ { recorder: prefillRecorder }
513
+ )
514
+ : await projectContext(negativeCondition.context, transformerWeights, modelConfig, runtime, {
515
+ recorder: prefillRecorder,
516
+ })
420
517
  : null;
421
518
  if (prefillRecorder) {
422
519
  prefillRecorder.submit();
@@ -428,11 +525,6 @@ export class DiffusionPipeline {
428
525
  }
429
526
 
430
527
  const scheduler = buildScheduler(runtime.scheduler, steps);
431
- if (scheduler.type !== 'flowmatch_euler') {
432
- throw new Error(
433
- `Diffusion GPU pipeline requires scheduler.type="flowmatch_euler"; got "${scheduler.type}".`
434
- );
435
- }
436
528
  const latentScale = this.diffusionState.latentScale;
437
529
  const latentChannels = this.diffusionState.latentChannels;
438
530
  const { latents, latentWidth, latentHeight } = generateLatents(width, height, latentChannels, latentScale, seed);
@@ -463,9 +555,6 @@ export class DiffusionPipeline {
463
555
  const latentSize = latentChannels * latentHeight * latentWidth;
464
556
  for (let i = 0; i < scheduler.steps; i++) {
465
557
  const timestep = scheduler.timesteps[i];
466
- const sigma = scheduler.sigmas[i];
467
- const sigmaNext = i + 1 < scheduler.steps ? scheduler.sigmas[i + 1] : 0;
468
- const delta = sigmaNext - sigma;
469
558
  const stepRecorder = canProfileGpu
470
559
  ? new CommandRecorder(getDevice(), `diffusion_step_${i}`, { profile: true })
471
560
  : null;
@@ -477,37 +566,71 @@ export class DiffusionPipeline {
477
566
  ? (left, right, count, options) => recordResidualAdd(stepRecorder, left, right, count, options)
478
567
  : runResidualAdd;
479
568
 
480
- const timeCond = await buildTimestepEmbedding(timestep, transformerWeights, modelConfig, runtime, {
481
- dim: timeEmbedDim,
482
- recorder: stepRecorder,
483
- });
484
- const textCond = await buildTimeTextEmbedding(promptCondition.pooled, transformerWeights, modelConfig, runtime, {
485
- recorder: stepRecorder,
486
- });
487
- const timeTextCond = await combineTimeTextEmbeddings(timeCond, textCond, hiddenSize, {
488
- recorder: stepRecorder,
489
- });
490
- const condPred = await runSD3Transformer(latentsTensor, condContext, timeTextCond, transformerWeights, modelConfig, runtime, {
491
- recorder: stepRecorder,
492
- });
493
- releaseStep(timeTextCond.buffer);
569
+ const condPred = layout === 'sana'
570
+ ? await (async () => {
571
+ const timeState = await buildSanaTimestepConditioning(
572
+ timestep * (transformerConfig.timestep_scale ?? 1.0),
573
+ guidanceScale,
574
+ transformerWeights,
575
+ transformerConfig,
576
+ runtime,
577
+ { recorder: stepRecorder }
578
+ );
579
+ return runSanaTransformer(latentsTensor, condContext, timeState, transformerWeights, modelConfig, runtime, {
580
+ recorder: stepRecorder,
581
+ });
582
+ })()
583
+ : await (async () => {
584
+ const timeCond = await buildTimestepEmbedding(timestep, transformerWeights, modelConfig, runtime, {
585
+ dim: timeEmbedDim,
586
+ recorder: stepRecorder,
587
+ });
588
+ const textCond = await buildTimeTextEmbedding(promptCondition.pooled, transformerWeights, modelConfig, runtime, {
589
+ recorder: stepRecorder,
590
+ });
591
+ const timeTextCond = await combineTimeTextEmbeddings(timeCond, textCond, hiddenSize, {
592
+ recorder: stepRecorder,
593
+ });
594
+ const output = await runSD3Transformer(latentsTensor, condContext, timeTextCond, transformerWeights, modelConfig, runtime, {
595
+ recorder: stepRecorder,
596
+ });
597
+ releaseStep(timeTextCond.buffer);
598
+ return output;
599
+ })();
494
600
 
495
601
  let pred = condPred;
496
602
  if (shouldUseUncond && uncondContext && negativeCondition) {
497
- const timeUncond = await buildTimestepEmbedding(timestep, transformerWeights, modelConfig, runtime, {
498
- dim: timeEmbedDim,
499
- recorder: stepRecorder,
500
- });
501
- const textUncond = await buildTimeTextEmbedding(negativeCondition.pooled, transformerWeights, modelConfig, runtime, {
502
- recorder: stepRecorder,
503
- });
504
- const timeTextUncond = await combineTimeTextEmbeddings(timeUncond, textUncond, hiddenSize, {
505
- recorder: stepRecorder,
506
- });
507
- const uncondPred = await runSD3Transformer(latentsTensor, uncondContext, timeTextUncond, transformerWeights, modelConfig, runtime, {
508
- recorder: stepRecorder,
509
- });
510
- releaseStep(timeTextUncond.buffer);
603
+ const uncondPred = layout === 'sana'
604
+ ? await (async () => {
605
+ const timeState = await buildSanaTimestepConditioning(
606
+ timestep * (transformerConfig.timestep_scale ?? 1.0),
607
+ guidanceScale,
608
+ transformerWeights,
609
+ transformerConfig,
610
+ runtime,
611
+ { recorder: stepRecorder }
612
+ );
613
+ return runSanaTransformer(latentsTensor, uncondContext, timeState, transformerWeights, modelConfig, runtime, {
614
+ recorder: stepRecorder,
615
+ });
616
+ })()
617
+ : await (async () => {
618
+ const timeUncond = await buildTimestepEmbedding(timestep, transformerWeights, modelConfig, runtime, {
619
+ dim: timeEmbedDim,
620
+ recorder: stepRecorder,
621
+ });
622
+ const textUncond = await buildTimeTextEmbedding(negativeCondition.pooled, transformerWeights, modelConfig, runtime, {
623
+ recorder: stepRecorder,
624
+ });
625
+ const timeTextUncond = await combineTimeTextEmbeddings(timeUncond, textUncond, hiddenSize, {
626
+ recorder: stepRecorder,
627
+ });
628
+ const output = await runSD3Transformer(latentsTensor, uncondContext, timeTextUncond, transformerWeights, modelConfig, runtime, {
629
+ recorder: stepRecorder,
630
+ });
631
+ releaseStep(timeTextUncond.buffer);
632
+ return output;
633
+ })();
511
634
  pred = await applyGuidance(uncondPred, condPred, guidanceScale, latentSize, {
512
635
  recorder: stepRecorder,
513
636
  release: releaseStep,
@@ -516,14 +639,20 @@ export class DiffusionPipeline {
516
639
  releaseStep(condPred.buffer);
517
640
  }
518
641
 
519
- const scaled = await scale(pred, delta, { count: latentSize });
520
- const updated = await residualAdd(latentsTensor, scaled, latentSize, { useVec4: true });
521
-
522
- releaseStep(latentsTensor.buffer);
523
- releaseStep(scaled.buffer);
524
- releaseStep(pred.buffer);
525
-
526
- latentsTensor = createTensor(updated.buffer, updated.dtype, [latentChannels, latentHeight, latentWidth], 'sd3_latents');
642
+ latentsTensor = await applySchedulerStep(
643
+ latentsTensor,
644
+ scheduler,
645
+ i,
646
+ timestep,
647
+ pred,
648
+ runtime,
649
+ {
650
+ scale,
651
+ residualAdd,
652
+ release: releaseStep,
653
+ seedBase: seed,
654
+ }
655
+ );
527
656
 
528
657
  if (stepRecorder) {
529
658
  stepRecorder.submit();
@@ -0,0 +1,53 @@
1
+ import type { Tensor } from '../../../gpu/tensor.js';
2
+ import type { CommandRecorder } from '../../../gpu/command-recorder.js';
3
+
4
/**
 * Per-step timestep conditioning returned by buildSanaTimestepConditioning and
 * passed to runSanaTransformer.
 */
export interface SanaTimestepState {
  /** NOTE(review): presumably the per-block modulation tensor — confirm in sana-transformer.js. */
  modulation: Tensor;
  /** NOTE(review): presumably the embedded timestep used by the final layer — confirm. */
  embeddedTimestep: Tensor;
}

/** Options shared by the Sana transformer helpers. */
export interface SanaTransformerOptions {
  /** Optional command recorder; the diffusion pipeline passes its profiling recorder or null. */
  recorder?: CommandRecorder | null;
}

/**
 * Build timestep conditioning for one denoising step.
 *
 * The diffusion pipeline passes `timestep` already multiplied by the
 * transformer config's `timestep_scale` (default 1.0).
 *
 * @param timestep Scaled scheduler timestep.
 * @param guidanceScale Classifier-free guidance scale for the request.
 * @param weightsEntry Loaded transformer weights.
 * @param config Transformer component config.
 * @param runtime Diffusion runtime config.
 * @param options Optional recorder.
 */
export declare function buildSanaTimestepConditioning(
  timestep: number,
  guidanceScale: number,
  weightsEntry: unknown,
  config: unknown,
  runtime: unknown,
  options?: SanaTransformerOptions
): Promise<SanaTimestepState>;

/**
 * Project text-encoder hidden states into the transformer's context space.
 *
 * @param context Text-encoder output for one prompt.
 * @param attentionMask Token attention mask from prompt encoding, if any.
 * @param weightsEntry Loaded transformer weights.
 * @param config Transformer component config.
 * @param runtime Diffusion runtime config.
 * @param options Optional recorder.
 */
export declare function projectSanaContext(
  context: Tensor,
  attentionMask: Uint32Array | null | undefined,
  weightsEntry: unknown,
  config: unknown,
  runtime: unknown,
  options?: SanaTransformerOptions
): Promise<Tensor>;

/**
 * Run the Sana transformer for one step and return its prediction.
 *
 * Unlike the conditioning helpers, this takes the full diffusion model config
 * rather than the transformer component config.
 *
 * @param latents Current latents.
 * @param context Projected context from projectSanaContext.
 * @param timeState Conditioning from buildSanaTimestepConditioning.
 * @param weightsEntry Loaded transformer weights.
 * @param modelConfig Diffusion model config.
 * @param runtime Diffusion runtime config.
 * @param options Optional recorder.
 */
export declare function runSanaTransformer(
  latents: Tensor,
  context: Tensor,
  timeState: SanaTimestepState,
  weightsEntry: unknown,
  modelConfig: unknown,
  runtime: unknown,
  options?: SanaTransformerOptions
): Promise<Tensor>;

/**
 * Build both the projected context and the timestep conditioning in one call.
 *
 * @param context Text-encoder output for one prompt.
 * @param attentionMask Token attention mask from prompt encoding, if any.
 * @param timestep Scheduler timestep.
 * @param guidanceScale Classifier-free guidance scale.
 * @param weightsEntry Loaded transformer weights.
 * @param modelConfig Diffusion model config.
 * @param runtime Diffusion runtime config.
 * @param options Optional recorder.
 */
export declare function buildSanaConditioning(
  context: Tensor,
  attentionMask: Uint32Array | null | undefined,
  timestep: number,
  guidanceScale: number,
  weightsEntry: unknown,
  modelConfig: unknown,
  runtime: unknown,
  options?: SanaTransformerOptions
): Promise<{
  context: Tensor;
  timeState: SanaTimestepState;
}>;
+ }>;