@simulatte/doppler 0.1.4 → 0.1.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +4 -3
- package/package.json +25 -4
- package/src/client/doppler-api.browser.d.ts +1 -0
- package/src/client/doppler-api.browser.js +288 -0
- package/src/client/doppler-api.js +1 -1
- package/src/client/doppler-provider/types.js +1 -1
- package/src/config/execution-contract-check.d.ts +33 -0
- package/src/config/execution-contract-check.js +72 -0
- package/src/config/execution-v0-contract-check.d.ts +94 -0
- package/src/config/execution-v0-contract-check.js +251 -0
- package/src/config/execution-v0-graph-contract-check.d.ts +20 -0
- package/src/config/execution-v0-graph-contract-check.js +64 -0
- package/src/config/kernel-path-contract-check.d.ts +76 -0
- package/src/config/kernel-path-contract-check.js +479 -0
- package/src/config/kernel-path-loader.d.ts +16 -0
- package/src/config/kernel-path-loader.js +54 -0
- package/src/config/kernels/kernel-ref-digests.js +12 -0
- package/src/config/kernels/registry.json +556 -0
- package/src/config/loader.js +50 -46
- package/src/config/merge-contract-check.d.ts +16 -0
- package/src/config/merge-contract-check.js +321 -0
- package/src/config/merge-helpers.d.ts +58 -0
- package/src/config/merge-helpers.js +54 -0
- package/src/config/merge.js +3 -6
- package/src/config/presets/models/janus-text.json +2 -0
- package/src/config/quantization-contract-check.d.ts +12 -0
- package/src/config/quantization-contract-check.js +91 -0
- package/src/config/required-inference-fields-contract-check.d.ts +24 -0
- package/src/config/required-inference-fields-contract-check.js +231 -0
- package/src/config/schema/browser-suite-metrics.schema.d.ts +17 -0
- package/src/config/schema/browser-suite-metrics.schema.js +46 -0
- package/src/config/schema/conversion-report.schema.d.ts +40 -0
- package/src/config/schema/conversion-report.schema.js +108 -0
- package/src/config/schema/doppler.schema.js +12 -18
- package/src/config/schema/index.d.ts +22 -0
- package/src/config/schema/index.js +18 -0
- package/src/converter/core.d.ts +10 -0
- package/src/converter/core.js +27 -2
- package/src/converter/parsers/diffusion.js +63 -3
- package/src/gpu/kernels/depthwise_conv2d.d.ts +29 -0
- package/src/gpu/kernels/depthwise_conv2d.js +98 -0
- package/src/gpu/kernels/depthwise_conv2d.wgsl +58 -0
- package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +62 -0
- package/src/gpu/kernels/grouped_pointwise_conv2d.d.ts +27 -0
- package/src/gpu/kernels/grouped_pointwise_conv2d.js +92 -0
- package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +47 -0
- package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +51 -0
- package/src/gpu/kernels/index.d.ts +30 -0
- package/src/gpu/kernels/index.js +25 -0
- package/src/gpu/kernels/relu.d.ts +18 -0
- package/src/gpu/kernels/relu.js +45 -0
- package/src/gpu/kernels/relu.wgsl +21 -0
- package/src/gpu/kernels/relu_f16.wgsl +23 -0
- package/src/gpu/kernels/repeat_channels.d.ts +21 -0
- package/src/gpu/kernels/repeat_channels.js +60 -0
- package/src/gpu/kernels/repeat_channels.wgsl +29 -0
- package/src/gpu/kernels/repeat_channels_f16.wgsl +31 -0
- package/src/gpu/kernels/sana_linear_attention.d.ts +27 -0
- package/src/gpu/kernels/sana_linear_attention.js +122 -0
- package/src/gpu/kernels/sana_linear_attention_apply.wgsl +44 -0
- package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +47 -0
- package/src/gpu/kernels/sana_linear_attention_summary.wgsl +47 -0
- package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +49 -0
- package/src/index-browser.d.ts +1 -1
- package/src/index-browser.js +2 -2
- package/src/index.js +1 -1
- package/src/inference/browser-harness.js +62 -22
- package/src/inference/pipelines/diffusion/init.js +14 -0
- package/src/inference/pipelines/diffusion/pipeline.js +206 -77
- package/src/inference/pipelines/diffusion/sana-transformer.d.ts +53 -0
- package/src/inference/pipelines/diffusion/sana-transformer.js +738 -0
- package/src/inference/pipelines/diffusion/scheduler.d.ts +17 -1
- package/src/inference/pipelines/diffusion/scheduler.js +91 -3
- package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +6 -4
- package/src/inference/pipelines/diffusion/text-encoder-gpu.js +270 -0
- package/src/inference/pipelines/diffusion/text-encoder.js +18 -1
- package/src/inference/pipelines/diffusion/types.d.ts +4 -0
- package/src/inference/pipelines/diffusion/vae.js +782 -78
- package/src/inference/pipelines/text/config.d.ts +5 -0
- package/src/inference/pipelines/text/config.js +1 -1
- package/src/inference/pipelines/text/execution-v0.js +14 -93
- package/src/rules/execution-rules-contract-check.d.ts +17 -0
- package/src/rules/execution-rules-contract-check.js +245 -0
- package/src/rules/kernels/depthwise-conv2d.rules.json +6 -0
- package/src/rules/kernels/grouped-pointwise-conv2d.rules.json +6 -0
- package/src/rules/kernels/relu.rules.json +6 -0
- package/src/rules/kernels/repeat-channels.rules.json +6 -0
- package/src/rules/kernels/sana-linear-attention.rules.json +6 -0
- package/src/rules/layer-pattern-contract-check.d.ts +17 -0
- package/src/rules/layer-pattern-contract-check.js +231 -0
- package/src/rules/rule-registry.d.ts +28 -0
- package/src/rules/rule-registry.js +38 -0
- package/src/tooling/conversion-config-materializer.d.ts +24 -0
- package/src/tooling/conversion-config-materializer.js +99 -0
- package/src/tooling/lean-execution-contract-runner.d.ts +43 -0
- package/src/tooling/lean-execution-contract-runner.js +158 -0
- package/src/tooling/node-convert.d.ts +10 -0
- package/src/tooling/node-converter.js +59 -0
- package/src/tooling/node-webgpu.js +9 -9
- package/src/version.d.ts +2 -0
- package/src/version.js +2 -0
- package/tools/convert-safetensors-node.js +47 -0
- package/tools/doppler-cli.js +115 -1
|
@@ -15,10 +15,16 @@ import {
|
|
|
15
15
|
getActiveKernelPathSource,
|
|
16
16
|
getActiveKernelPathPolicy,
|
|
17
17
|
} from '../config/kernel-path-loader.js';
|
|
18
|
-
import {
|
|
18
|
+
import {
|
|
19
|
+
getInferenceLayerPatternContractArtifact,
|
|
20
|
+
selectRuleValue,
|
|
21
|
+
} from '../rules/rule-registry.js';
|
|
19
22
|
import { mergeRuntimeValues } from '../config/runtime-merge.js';
|
|
20
23
|
import { isPlainObject } from '../utils/plain-object.js';
|
|
24
|
+
import { validateBrowserSuiteMetrics } from '../config/schema/browser-suite-metrics.schema.js';
|
|
21
25
|
import { validateTrainingMetricsReport } from '../config/schema/training-metrics.schema.js';
|
|
26
|
+
import { buildExecutionContractArtifact } from '../config/execution-contract-check.js';
|
|
27
|
+
import { buildManifestRequiredInferenceFieldsArtifact } from '../config/required-inference-fields-contract-check.js';
|
|
22
28
|
|
|
23
29
|
const TRAINING_SUITE_MODULE_PATH = '../training/suite.js';
|
|
24
30
|
const NODE_SOURCE_RUNTIME_MODULE_PATH = '../tooling/node-source-runtime.js';
|
|
@@ -41,6 +47,29 @@ async function runTrainingBenchSuite(options = {}) {
|
|
|
41
47
|
return module.runTrainingBenchSuite(options);
|
|
42
48
|
}
|
|
43
49
|
|
|
50
|
+
/**
 * Decorate a suite's metrics with contract artifacts and validate the result.
 *
 * The returned object is whatever `validateBrowserSuiteMetrics` returns for
 * `baseMetrics` overlaid with `schemaVersion`, `source`, `suite` and the
 * contract artifacts computed here.
 *
 * @param suite Suite name recorded in the metrics (e.g. 'bench', 'diffusion').
 * @param baseMetrics Suite-specific metrics; fields set here override same-named keys.
 * @param manifest Model manifest the artifacts are derived from; may be nullish.
 */
function buildSuiteContractMetrics(suite, baseMetrics, manifest) {
  const executionContractArtifact = buildExecutionContractArtifact(manifest);
  const graphArtifact = executionContractArtifact?.executionV0?.graph ?? null;
  const layerPatternContractArtifact = getInferenceLayerPatternContractArtifact();

  // The required-inference-fields artifact is only built for transformer
  // manifests that carry an attention section; every other manifest gets null.
  let requiredInferenceFieldsArtifact = null;
  if (manifest?.modelType === 'transformer' && isPlainObject(manifest?.inference?.attention)) {
    requiredInferenceFieldsArtifact = buildManifestRequiredInferenceFieldsArtifact(
      manifest?.inference ?? null,
      `${manifest?.modelId ?? 'unknown'}.inference`
    );
  }

  // The execution contract artifact key is omitted entirely (not set to null)
  // when no artifact was produced.
  const optionalExecutionFields = executionContractArtifact ? { executionContractArtifact } : {};

  return validateBrowserSuiteMetrics({
    ...baseMetrics,
    schemaVersion: 1,
    source: 'doppler',
    suite,
    ...optionalExecutionFields,
    executionV0GraphContractArtifact: graphArtifact,
    layerPatternContractArtifact,
    requiredInferenceFieldsArtifact,
  });
}
|
|
72
|
+
|
|
44
73
|
function parseReportTimestamp(rawTimestamp, label = 'timestamp') {
|
|
45
74
|
if (rawTimestamp == null) {
|
|
46
75
|
return null;
|
|
@@ -1824,6 +1853,11 @@ async function runInferenceSuite(options = {}) {
|
|
|
1824
1853
|
source: 'doppler',
|
|
1825
1854
|
prefillSemantics: 'internal_prefill_phase',
|
|
1826
1855
|
});
|
|
1856
|
+
const metricsWithContracts = buildSuiteContractMetrics(
|
|
1857
|
+
options.suiteName || 'inference',
|
|
1858
|
+
metrics,
|
|
1859
|
+
harness.manifest
|
|
1860
|
+
);
|
|
1827
1861
|
return {
|
|
1828
1862
|
...summary,
|
|
1829
1863
|
modelId: options.modelId || harness.manifest?.modelId || 'unknown',
|
|
@@ -1841,7 +1875,7 @@ async function runInferenceSuite(options = {}) {
|
|
|
1841
1875
|
timing,
|
|
1842
1876
|
timingDiagnostics,
|
|
1843
1877
|
output,
|
|
1844
|
-
metrics,
|
|
1878
|
+
metrics: metricsWithContracts,
|
|
1845
1879
|
memoryStats,
|
|
1846
1880
|
deviceInfo: resolveDeviceInfo(),
|
|
1847
1881
|
pipeline: options.keepPipeline ? harness.pipeline : null,
|
|
@@ -2218,6 +2252,7 @@ async function runBenchSuite(options = {}) {
|
|
|
2218
2252
|
source: 'doppler',
|
|
2219
2253
|
prefillSemantics: 'internal_prefill_phase',
|
|
2220
2254
|
});
|
|
2255
|
+
const metricsWithContracts = buildSuiteContractMetrics('bench', metrics, harness.manifest);
|
|
2221
2256
|
return {
|
|
2222
2257
|
...summary,
|
|
2223
2258
|
modelId: options.modelId || harness.manifest?.modelId || 'unknown',
|
|
@@ -2235,7 +2270,7 @@ async function runBenchSuite(options = {}) {
|
|
|
2235
2270
|
timing,
|
|
2236
2271
|
timingDiagnostics,
|
|
2237
2272
|
output,
|
|
2238
|
-
metrics,
|
|
2273
|
+
metrics: metricsWithContracts,
|
|
2239
2274
|
memoryStats,
|
|
2240
2275
|
deviceInfo: resolveDeviceInfo(),
|
|
2241
2276
|
pipeline: options.keepPipeline ? harness.pipeline : null,
|
|
@@ -2396,25 +2431,9 @@ async function runDiffusionSuite(options = {}) {
|
|
|
2396
2431
|
source: 'doppler',
|
|
2397
2432
|
prefillSemantics: 'internal_prefill_phase',
|
|
2398
2433
|
});
|
|
2399
|
-
|
|
2400
|
-
|
|
2401
|
-
|
|
2402
|
-
modelId: options.modelId || harness.manifest?.modelId || 'unknown',
|
|
2403
|
-
cacheMode,
|
|
2404
|
-
loadMode,
|
|
2405
|
-
env: {
|
|
2406
|
-
library: 'doppler',
|
|
2407
|
-
runtime: 'browser',
|
|
2408
|
-
device: 'webgpu',
|
|
2409
|
-
browserUserAgent: typeof navigator !== 'undefined' ? (navigator.userAgent || null) : null,
|
|
2410
|
-
browserPlatform: typeof navigator !== 'undefined' ? (navigator.platform || null) : null,
|
|
2411
|
-
browserLanguage: typeof navigator !== 'undefined' ? (navigator.language || null) : null,
|
|
2412
|
-
browserVendor: typeof navigator !== 'undefined' ? (navigator.vendor || null) : null,
|
|
2413
|
-
},
|
|
2414
|
-
timing,
|
|
2415
|
-
timingDiagnostics,
|
|
2416
|
-
output,
|
|
2417
|
-
metrics: {
|
|
2434
|
+
const metricsWithContracts = buildSuiteContractMetrics(
|
|
2435
|
+
'diffusion',
|
|
2436
|
+
{
|
|
2418
2437
|
warmupRuns,
|
|
2419
2438
|
timedRuns,
|
|
2420
2439
|
width,
|
|
@@ -2439,6 +2458,27 @@ async function runDiffusionSuite(options = {}) {
|
|
|
2439
2458
|
gpu: gpuStats,
|
|
2440
2459
|
performanceArtifact: diffusionPerformanceArtifact,
|
|
2441
2460
|
},
|
|
2461
|
+
harness.manifest
|
|
2462
|
+
);
|
|
2463
|
+
|
|
2464
|
+
return {
|
|
2465
|
+
...summary,
|
|
2466
|
+
modelId: options.modelId || harness.manifest?.modelId || 'unknown',
|
|
2467
|
+
cacheMode,
|
|
2468
|
+
loadMode,
|
|
2469
|
+
env: {
|
|
2470
|
+
library: 'doppler',
|
|
2471
|
+
runtime: 'browser',
|
|
2472
|
+
device: 'webgpu',
|
|
2473
|
+
browserUserAgent: typeof navigator !== 'undefined' ? (navigator.userAgent || null) : null,
|
|
2474
|
+
browserPlatform: typeof navigator !== 'undefined' ? (navigator.platform || null) : null,
|
|
2475
|
+
browserLanguage: typeof navigator !== 'undefined' ? (navigator.language || null) : null,
|
|
2476
|
+
browserVendor: typeof navigator !== 'undefined' ? (navigator.vendor || null) : null,
|
|
2477
|
+
},
|
|
2478
|
+
timing,
|
|
2479
|
+
timingDiagnostics,
|
|
2480
|
+
output,
|
|
2481
|
+
metrics: metricsWithContracts,
|
|
2442
2482
|
memoryStats,
|
|
2443
2483
|
deviceInfo: resolveDeviceInfo(),
|
|
2444
2484
|
pipeline: options.keepPipeline ? harness.pipeline : null,
|
|
@@ -1,5 +1,7 @@
|
|
|
1
1
|
import { DEFAULT_DIFFUSION_CONFIG } from '../../../config/schema/index.js';
|
|
2
2
|
|
|
3
|
+
const SUPPORTED_DIFFUSION_RUNTIME_LAYOUTS = new Set(['sd3', 'flux', 'sana']);
|
|
4
|
+
|
|
3
5
|
function mergeSection(base, override) {
|
|
4
6
|
if (!override) return { ...base };
|
|
5
7
|
return { ...base, ...override };
|
|
@@ -38,6 +40,9 @@ function resolveSchedulerType(modelScheduler, runtimeScheduler) {
|
|
|
38
40
|
if (modelClass === 'FlowMatchEulerDiscreteScheduler') {
|
|
39
41
|
return 'flowmatch_euler';
|
|
40
42
|
}
|
|
43
|
+
if (modelClass === 'SCMScheduler') {
|
|
44
|
+
return 'scm';
|
|
45
|
+
}
|
|
41
46
|
if (modelClass === 'EulerDiscreteScheduler') {
|
|
42
47
|
return 'euler';
|
|
43
48
|
}
|
|
@@ -58,6 +63,8 @@ function mergeSchedulerConfig(modelConfig, runtimeScheduler) {
|
|
|
58
63
|
type,
|
|
59
64
|
numTrainTimesteps: modelScheduler.num_train_timesteps ?? runtimeScheduler.numTrainTimesteps,
|
|
60
65
|
shift: modelScheduler.shift ?? runtimeScheduler.shift,
|
|
66
|
+
predictionType: modelScheduler.prediction_type ?? runtimeScheduler.predictionType,
|
|
67
|
+
sigmaData: modelScheduler.sigma_data ?? runtimeScheduler.sigmaData,
|
|
61
68
|
};
|
|
62
69
|
}
|
|
63
70
|
|
|
@@ -95,6 +102,13 @@ export function initializeDiffusion(manifest, runtimeConfig) {
|
|
|
95
102
|
}
|
|
96
103
|
throw new Error('Diffusion manifest missing config.diffusion model contract.');
|
|
97
104
|
}
|
|
105
|
+
const layout = modelConfig.layout;
|
|
106
|
+
if (layout && !SUPPORTED_DIFFUSION_RUNTIME_LAYOUTS.has(layout)) {
|
|
107
|
+
throw new Error(
|
|
108
|
+
`Diffusion layout "${layout}" is recognized in the manifest, but the GPU runtime is not implemented yet. ` +
|
|
109
|
+
'Supported runtime layouts: sd3, flux, sana.'
|
|
110
|
+
);
|
|
111
|
+
}
|
|
98
112
|
|
|
99
113
|
const runtimeBase = mergeDiffusionConfig(DEFAULT_DIFFUSION_CONFIG, runtimeConfig?.inference?.diffusion);
|
|
100
114
|
const runtime = {
|
|
@@ -14,10 +14,11 @@ import {
|
|
|
14
14
|
projectContext,
|
|
15
15
|
assertClipHiddenActivationSupported,
|
|
16
16
|
} from './text-encoder-gpu.js';
|
|
17
|
-
import { buildScheduler } from './scheduler.js';
|
|
17
|
+
import { buildScheduler, stepScmScheduler } from './scheduler.js';
|
|
18
18
|
import { decodeLatents } from './vae.js';
|
|
19
19
|
import { createDiffusionWeightLoader } from './weights.js';
|
|
20
20
|
import { runSD3Transformer } from './sd3-transformer.js';
|
|
21
|
+
import { runSanaTransformer, buildSanaTimestepConditioning, projectSanaContext } from './sana-transformer.js';
|
|
21
22
|
import { createSD3WeightResolver } from './sd3-weights.js';
|
|
22
23
|
import { createTensor, dtypeBytes } from '../../../gpu/tensor.js';
|
|
23
24
|
import { acquireBuffer, releaseBuffer, readBuffer } from '../../../memory/buffer-pool.js';
|
|
@@ -27,6 +28,8 @@ import { runResidualAdd, runScale, recordResidualAdd, recordScale } from '../../
|
|
|
27
28
|
import { f16ToF32 } from '../../../loader/dtype-utils.js';
|
|
28
29
|
|
|
29
30
|
const SUPPORTED_DIFFUSION_BACKEND_PIPELINES = new Set(['gpu']);
|
|
31
|
+
const SD3_TEXT_ENCODER_KEYS = ['text_encoder', 'text_encoder_2', 'text_encoder_3'];
|
|
32
|
+
const SANA_TEXT_ENCODER_KEYS = ['text_encoder'];
|
|
30
33
|
|
|
31
34
|
function createRandomSeed() {
|
|
32
35
|
if (typeof crypto !== 'undefined' && typeof crypto.getRandomValues === 'function') {
|
|
@@ -58,6 +61,49 @@ function extractTokenSet(tokensByEncoder, key) {
|
|
|
58
61
|
return output;
|
|
59
62
|
}
|
|
60
63
|
|
|
64
|
+
/**
 * Return the diffusion layout named by the model config.
 *
 * Falls back to 'sd3' only when the config or its `layout` field is null or
 * undefined; any other value (including an empty string) is returned as-is.
 */
function resolveDiffusionLayout(modelConfig) {
  const declaredLayout = modelConfig == null ? undefined : modelConfig.layout;
  if (declaredLayout == null) {
    return 'sd3';
  }
  return declaredLayout;
}
|
|
67
|
+
|
|
68
|
+
/**
 * Pick the list of text-encoder component keys for a layout.
 *
 * 'sana' gets the Sana key list; every other layout value gets the SD3 list.
 * The shared module-level array is returned directly, not a copy.
 */
function getTextEncoderKeysForLayout(layout) {
  return layout === 'sana' ? SANA_TEXT_ENCODER_KEYS : SD3_TEXT_ENCODER_KEYS;
}
|
|
74
|
+
|
|
75
|
+
/**
 * Verify that every text encoder the layout needs has both a component entry
 * in the model config and a tokenizer.
 *
 * Keys are checked in list order; for each key the component is checked before
 * the tokenizer, and the first missing item throws.
 *
 * @throws {Error} naming the missing component or tokenizer and the layout.
 */
function assertLayoutTextEncoderContract(layout, modelConfig, tokenizers) {
  getTextEncoderKeysForLayout(layout).forEach((encoderKey) => {
    const hasComponent = Boolean(modelConfig?.components?.[encoderKey]);
    if (!hasComponent) {
      throw new Error(`Diffusion GPU pipeline requires component "${encoderKey}" for layout "${layout}".`);
    }
    const hasTokenizer = Boolean(tokenizers?.[encoderKey]);
    if (!hasTokenizer) {
      throw new Error(`Diffusion GPU pipeline requires tokenizer "${encoderKey}" for layout "${layout}".`);
    }
  });
}
|
|
86
|
+
|
|
87
|
+
/**
 * Build the per-tokenizer maximum token lengths for a layout.
 *
 * `runtime.textEncoder.maxLength` must be a finite number greater than zero.
 * For 'sana' only `text_encoder` is returned. For every other layout the
 * result has `text_encoder`, `text_encoder_2` (both `maxLength`) and
 * `text_encoder_3`, which uses `runtime.textEncoder.t5MaxLength` when that is
 * not null/undefined and `maxLength` otherwise.
 *
 * @throws {Error} when the applicable length is missing, non-finite, or <= 0.
 */
function buildTokenizerMaxLengths(layout, runtime) {
  const isUsableLength = (candidate) => Number.isFinite(candidate) && candidate > 0;
  const encoderRuntime = runtime?.textEncoder;

  const maxLength = encoderRuntime?.maxLength;
  if (!isUsableLength(maxLength)) {
    throw new Error('Diffusion runtime requires runtime.textEncoder.maxLength.');
  }

  // Sana has a single text encoder, so the T5 length is never consulted
  // (and never validated) for it.
  if (layout === 'sana') {
    return { text_encoder: maxLength };
  }

  const t5MaxLength = encoderRuntime?.t5MaxLength ?? maxLength;
  if (!isUsableLength(t5MaxLength)) {
    throw new Error('Diffusion runtime requires runtime.textEncoder.t5MaxLength (or runtime.textEncoder.maxLength).');
  }

  return {
    text_encoder: maxLength,
    text_encoder_2: maxLength,
    text_encoder_3: t5MaxLength,
  };
}
|
|
106
|
+
|
|
61
107
|
function getTensorSize(shape) {
|
|
62
108
|
if (!Array.isArray(shape)) return 0;
|
|
63
109
|
return shape.reduce((acc, value) => acc * value, 1);
|
|
@@ -120,6 +166,49 @@ async function readTensorToFloat32(tensor) {
|
|
|
120
166
|
return new Float32Array(data);
|
|
121
167
|
}
|
|
122
168
|
|
|
169
|
+
/**
 * Advance the latents by one scheduler step and return the new latent tensor.
 *
 * Both supported branches release the incoming `latentsTensor` and
 * `predictionTensor` buffers; the caller must use the returned tensor.
 *
 * - 'flowmatch_euler': scales the prediction by `sigmaNext - sigma` (with
 *   `sigmaNext` = 0 on the last step) and combines it with the latents through
 *   the residual-add kernel.
 * - 'scm': reads latents and prediction back as Float32 data, draws fresh
 *   noise for every step except the final one, and delegates to
 *   `stepScmScheduler`; the result is uploaded as a new latent tensor.
 *
 * @param latentsTensor Current latents (buffer is released by this call).
 * @param scheduler Scheduler state; `type`, `sigmas`, `steps`, `timesteps` are read.
 * @param stepIndex Zero-based index of the current step.
 * @param timestep Timestep value forwarded to the SCM scheduler.
 * @param predictionTensor Model output for this step (buffer is released by this call).
 * @param runtime Diffusion runtime config; `runtime.latent` is read for SCM noise.
 * @param options Optional overrides: `scale`, `residualAdd`, `release`, `seedBase`.
 * @throws {Error} when `scheduler.type` is neither 'flowmatch_euler' nor 'scm'.
 */
async function applySchedulerStep(latentsTensor, scheduler, stepIndex, timestep, predictionTensor, runtime, options = {}) {
  // Resolve the release hook once so every branch frees buffers the same way.
  // Previously only the flow-match branch honored `options.release`; the SCM
  // branch always called the global `releaseBuffer`, bypassing a caller-supplied
  // (e.g. recorder-aware) release path.
  const release = options.release ?? releaseBuffer;

  if (scheduler.type === 'flowmatch_euler') {
    const sigma = scheduler.sigmas[stepIndex];
    const sigmaNext = stepIndex + 1 < scheduler.steps ? scheduler.sigmas[stepIndex + 1] : 0;
    const delta = sigmaNext - sigma;
    const latentSize = getTensorSize(latentsTensor.shape);
    const scale = options.scale ?? runScale;
    const residualAdd = options.residualAdd ?? runResidualAdd;

    const scaled = await scale(predictionTensor, delta, { count: latentSize });
    const updated = await residualAdd(latentsTensor, scaled, latentSize, { useVec4: true });

    release(latentsTensor.buffer);
    release(scaled.buffer);
    release(predictionTensor.buffer);

    return createTensor(updated.buffer, updated.dtype, [...latentsTensor.shape], 'diffusion_latents');
  }

  if (scheduler.type === 'scm') {
    // NOTE(review): these readbacks happen inside the step; if the caller is
    // recording into a not-yet-submitted command recorder, confirm the
    // prediction has actually been computed before it is read here.
    const sample = await readTensorToFloat32(latentsTensor);
    const modelOutput = await readTensorToFloat32(predictionTensor);
    release(predictionTensor.buffer);
    release(latentsTensor.buffer);

    const isFinalStep = stepIndex + 1 >= scheduler.timesteps.length - 1;
    // NOTE(review): noise is sized from runtime.latent.* rather than from
    // latentsTensor.shape — confirm these agree when a request overrides
    // width/height.
    const noise = isFinalStep
      ? null
      : generateLatents(
        runtime.latent.width,
        runtime.latent.height,
        runtime.latent.channels,
        runtime.latent.scale,
        // Offset the seed per step so each step draws different noise while
        // staying reproducible for a fixed seedBase.
        (options.seedBase ?? createRandomSeed()) + stepIndex + 1
      ).latents;
    const step = stepScmScheduler(scheduler, modelOutput, timestep, sample, stepIndex, noise);
    return createLatentTensor(step.prevSample, [...latentsTensor.shape], runtime);
  }

  throw new Error(`Unsupported diffusion scheduler.type "${scheduler.type}".`);
}
|
|
211
|
+
|
|
123
212
|
async function applyGuidance(uncond, cond, guidanceScale, size, options = {}) {
|
|
124
213
|
if (!uncond || !Number.isFinite(guidanceScale) || guidanceScale <= 1) {
|
|
125
214
|
return cond;
|
|
@@ -251,14 +340,17 @@ export class DiffusionPipeline {
|
|
|
251
340
|
});
|
|
252
341
|
}
|
|
253
342
|
|
|
254
|
-
const
|
|
255
|
-
const
|
|
256
|
-
const
|
|
343
|
+
const layout = resolveDiffusionLayout(this.diffusionState?.modelConfig);
|
|
344
|
+
const requiredKeys = getTextEncoderKeysForLayout(layout);
|
|
345
|
+
const weights = {};
|
|
346
|
+
for (const key of requiredKeys) {
|
|
347
|
+
weights[key] = await this.weightLoader.loadComponentWeights(key);
|
|
348
|
+
}
|
|
257
349
|
|
|
258
350
|
this.textEncoderWeights = {
|
|
259
|
-
text_encoder,
|
|
260
|
-
text_encoder_2,
|
|
261
|
-
text_encoder_3,
|
|
351
|
+
text_encoder: weights.text_encoder ?? null,
|
|
352
|
+
text_encoder_2: weights.text_encoder_2 ?? null,
|
|
353
|
+
text_encoder_3: weights.text_encoder_3 ?? null,
|
|
262
354
|
};
|
|
263
355
|
|
|
264
356
|
return this.textEncoderWeights;
|
|
@@ -315,14 +407,9 @@ export class DiffusionPipeline {
|
|
|
315
407
|
async generateGPU(request = {}) {
|
|
316
408
|
const start = performance.now();
|
|
317
409
|
const runtime = this.diffusionState.runtime;
|
|
318
|
-
const
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
}
|
|
322
|
-
const t5MaxLength = runtime.textEncoder?.t5MaxLength ?? clipMaxLength;
|
|
323
|
-
if (!Number.isFinite(t5MaxLength) || t5MaxLength <= 0) {
|
|
324
|
-
throw new Error('Diffusion runtime requires runtime.textEncoder.t5MaxLength (or runtime.textEncoder.maxLength).');
|
|
325
|
-
}
|
|
410
|
+
const modelConfig = this.diffusionState.modelConfig;
|
|
411
|
+
const layout = resolveDiffusionLayout(modelConfig);
|
|
412
|
+
const tokenizerMaxLengths = buildTokenizerMaxLengths(layout, runtime);
|
|
326
413
|
|
|
327
414
|
const defaultWidth = runtime.latent.width;
|
|
328
415
|
const defaultHeight = runtime.latent.height;
|
|
@@ -346,28 +433,20 @@ export class DiffusionPipeline {
|
|
|
346
433
|
throw new Error(`Invalid diffusion steps: ${steps}`);
|
|
347
434
|
}
|
|
348
435
|
|
|
349
|
-
const modelConfig = this.diffusionState.modelConfig;
|
|
350
436
|
if (!modelConfig?.components?.transformer) {
|
|
351
437
|
throw new Error('Diffusion GPU pipeline requires transformer component config.');
|
|
352
438
|
}
|
|
353
|
-
|
|
354
|
-
|
|
439
|
+
assertLayoutTextEncoderContract(layout, modelConfig, this.tokenizers);
|
|
440
|
+
if (layout === 'sd3') {
|
|
441
|
+
assertClipHiddenActivationSupported(modelConfig?.components?.text_encoder?.config || {});
|
|
355
442
|
}
|
|
356
|
-
if (!this.tokenizers?.text_encoder || !this.tokenizers?.text_encoder_2 || !this.tokenizers?.text_encoder_3) {
|
|
357
|
-
throw new Error('Diffusion GPU pipeline requires tokenizers for text_encoder, text_encoder_2, and text_encoder_3.');
|
|
358
|
-
}
|
|
359
|
-
assertClipHiddenActivationSupported(modelConfig?.components?.text_encoder?.config || {});
|
|
360
443
|
|
|
361
444
|
const promptStart = performance.now();
|
|
362
445
|
const encoded = encodePrompt(
|
|
363
446
|
{ prompt: request.prompt ?? '', negativePrompt: request.negativePrompt ?? '' },
|
|
364
447
|
this.tokenizers || {},
|
|
365
448
|
{
|
|
366
|
-
maxLengthByTokenizer:
|
|
367
|
-
text_encoder: clipMaxLength,
|
|
368
|
-
text_encoder_2: clipMaxLength,
|
|
369
|
-
text_encoder_3: t5MaxLength,
|
|
370
|
-
},
|
|
449
|
+
maxLengthByTokenizer: tokenizerMaxLengths,
|
|
371
450
|
}
|
|
372
451
|
);
|
|
373
452
|
|
|
@@ -410,13 +489,31 @@ export class DiffusionPipeline {
|
|
|
410
489
|
const prefillRecorder = canProfileGpu
|
|
411
490
|
? new CommandRecorder(getDevice(), 'diffusion_prefill', { profile: true })
|
|
412
491
|
: null;
|
|
413
|
-
const condContext =
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
492
|
+
const condContext = layout === 'sana'
|
|
493
|
+
? await projectSanaContext(
|
|
494
|
+
promptCondition.context,
|
|
495
|
+
promptCondition.attentionMask,
|
|
496
|
+
transformerWeights,
|
|
497
|
+
transformerConfig,
|
|
498
|
+
runtime,
|
|
499
|
+
{ recorder: prefillRecorder }
|
|
500
|
+
)
|
|
501
|
+
: await projectContext(promptCondition.context, transformerWeights, modelConfig, runtime, {
|
|
418
502
|
recorder: prefillRecorder,
|
|
419
|
-
})
|
|
503
|
+
});
|
|
504
|
+
const uncondContext = shouldUseUncond && negativeCondition
|
|
505
|
+
? layout === 'sana'
|
|
506
|
+
? await projectSanaContext(
|
|
507
|
+
negativeCondition.context,
|
|
508
|
+
negativeCondition.attentionMask,
|
|
509
|
+
transformerWeights,
|
|
510
|
+
transformerConfig,
|
|
511
|
+
runtime,
|
|
512
|
+
{ recorder: prefillRecorder }
|
|
513
|
+
)
|
|
514
|
+
: await projectContext(negativeCondition.context, transformerWeights, modelConfig, runtime, {
|
|
515
|
+
recorder: prefillRecorder,
|
|
516
|
+
})
|
|
420
517
|
: null;
|
|
421
518
|
if (prefillRecorder) {
|
|
422
519
|
prefillRecorder.submit();
|
|
@@ -428,11 +525,6 @@ export class DiffusionPipeline {
|
|
|
428
525
|
}
|
|
429
526
|
|
|
430
527
|
const scheduler = buildScheduler(runtime.scheduler, steps);
|
|
431
|
-
if (scheduler.type !== 'flowmatch_euler') {
|
|
432
|
-
throw new Error(
|
|
433
|
-
`Diffusion GPU pipeline requires scheduler.type="flowmatch_euler"; got "${scheduler.type}".`
|
|
434
|
-
);
|
|
435
|
-
}
|
|
436
528
|
const latentScale = this.diffusionState.latentScale;
|
|
437
529
|
const latentChannels = this.diffusionState.latentChannels;
|
|
438
530
|
const { latents, latentWidth, latentHeight } = generateLatents(width, height, latentChannels, latentScale, seed);
|
|
@@ -463,9 +555,6 @@ export class DiffusionPipeline {
|
|
|
463
555
|
const latentSize = latentChannels * latentHeight * latentWidth;
|
|
464
556
|
for (let i = 0; i < scheduler.steps; i++) {
|
|
465
557
|
const timestep = scheduler.timesteps[i];
|
|
466
|
-
const sigma = scheduler.sigmas[i];
|
|
467
|
-
const sigmaNext = i + 1 < scheduler.steps ? scheduler.sigmas[i + 1] : 0;
|
|
468
|
-
const delta = sigmaNext - sigma;
|
|
469
558
|
const stepRecorder = canProfileGpu
|
|
470
559
|
? new CommandRecorder(getDevice(), `diffusion_step_${i}`, { profile: true })
|
|
471
560
|
: null;
|
|
@@ -477,37 +566,71 @@ export class DiffusionPipeline {
|
|
|
477
566
|
? (left, right, count, options) => recordResidualAdd(stepRecorder, left, right, count, options)
|
|
478
567
|
: runResidualAdd;
|
|
479
568
|
|
|
480
|
-
const
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
569
|
+
const condPred = layout === 'sana'
|
|
570
|
+
? await (async () => {
|
|
571
|
+
const timeState = await buildSanaTimestepConditioning(
|
|
572
|
+
timestep * (transformerConfig.timestep_scale ?? 1.0),
|
|
573
|
+
guidanceScale,
|
|
574
|
+
transformerWeights,
|
|
575
|
+
transformerConfig,
|
|
576
|
+
runtime,
|
|
577
|
+
{ recorder: stepRecorder }
|
|
578
|
+
);
|
|
579
|
+
return runSanaTransformer(latentsTensor, condContext, timeState, transformerWeights, modelConfig, runtime, {
|
|
580
|
+
recorder: stepRecorder,
|
|
581
|
+
});
|
|
582
|
+
})()
|
|
583
|
+
: await (async () => {
|
|
584
|
+
const timeCond = await buildTimestepEmbedding(timestep, transformerWeights, modelConfig, runtime, {
|
|
585
|
+
dim: timeEmbedDim,
|
|
586
|
+
recorder: stepRecorder,
|
|
587
|
+
});
|
|
588
|
+
const textCond = await buildTimeTextEmbedding(promptCondition.pooled, transformerWeights, modelConfig, runtime, {
|
|
589
|
+
recorder: stepRecorder,
|
|
590
|
+
});
|
|
591
|
+
const timeTextCond = await combineTimeTextEmbeddings(timeCond, textCond, hiddenSize, {
|
|
592
|
+
recorder: stepRecorder,
|
|
593
|
+
});
|
|
594
|
+
const output = await runSD3Transformer(latentsTensor, condContext, timeTextCond, transformerWeights, modelConfig, runtime, {
|
|
595
|
+
recorder: stepRecorder,
|
|
596
|
+
});
|
|
597
|
+
releaseStep(timeTextCond.buffer);
|
|
598
|
+
return output;
|
|
599
|
+
})();
|
|
494
600
|
|
|
495
601
|
let pred = condPred;
|
|
496
602
|
if (shouldUseUncond && uncondContext && negativeCondition) {
|
|
497
|
-
const
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
603
|
+
const uncondPred = layout === 'sana'
|
|
604
|
+
? await (async () => {
|
|
605
|
+
const timeState = await buildSanaTimestepConditioning(
|
|
606
|
+
timestep * (transformerConfig.timestep_scale ?? 1.0),
|
|
607
|
+
guidanceScale,
|
|
608
|
+
transformerWeights,
|
|
609
|
+
transformerConfig,
|
|
610
|
+
runtime,
|
|
611
|
+
{ recorder: stepRecorder }
|
|
612
|
+
);
|
|
613
|
+
return runSanaTransformer(latentsTensor, uncondContext, timeState, transformerWeights, modelConfig, runtime, {
|
|
614
|
+
recorder: stepRecorder,
|
|
615
|
+
});
|
|
616
|
+
})()
|
|
617
|
+
: await (async () => {
|
|
618
|
+
const timeUncond = await buildTimestepEmbedding(timestep, transformerWeights, modelConfig, runtime, {
|
|
619
|
+
dim: timeEmbedDim,
|
|
620
|
+
recorder: stepRecorder,
|
|
621
|
+
});
|
|
622
|
+
const textUncond = await buildTimeTextEmbedding(negativeCondition.pooled, transformerWeights, modelConfig, runtime, {
|
|
623
|
+
recorder: stepRecorder,
|
|
624
|
+
});
|
|
625
|
+
const timeTextUncond = await combineTimeTextEmbeddings(timeUncond, textUncond, hiddenSize, {
|
|
626
|
+
recorder: stepRecorder,
|
|
627
|
+
});
|
|
628
|
+
const output = await runSD3Transformer(latentsTensor, uncondContext, timeTextUncond, transformerWeights, modelConfig, runtime, {
|
|
629
|
+
recorder: stepRecorder,
|
|
630
|
+
});
|
|
631
|
+
releaseStep(timeTextUncond.buffer);
|
|
632
|
+
return output;
|
|
633
|
+
})();
|
|
511
634
|
pred = await applyGuidance(uncondPred, condPred, guidanceScale, latentSize, {
|
|
512
635
|
recorder: stepRecorder,
|
|
513
636
|
release: releaseStep,
|
|
@@ -516,14 +639,20 @@ export class DiffusionPipeline {
|
|
|
516
639
|
releaseStep(condPred.buffer);
|
|
517
640
|
}
|
|
518
641
|
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
642
|
+
latentsTensor = await applySchedulerStep(
|
|
643
|
+
latentsTensor,
|
|
644
|
+
scheduler,
|
|
645
|
+
i,
|
|
646
|
+
timestep,
|
|
647
|
+
pred,
|
|
648
|
+
runtime,
|
|
649
|
+
{
|
|
650
|
+
scale,
|
|
651
|
+
residualAdd,
|
|
652
|
+
release: releaseStep,
|
|
653
|
+
seedBase: seed,
|
|
654
|
+
}
|
|
655
|
+
);
|
|
527
656
|
|
|
528
657
|
if (stepRecorder) {
|
|
529
658
|
stepRecorder.submit();
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
import type { Tensor } from '../../../gpu/tensor.js';
|
|
2
|
+
import type { CommandRecorder } from '../../../gpu/command-recorder.js';
|
|
3
|
+
|
|
4
|
+
/**
 * Timestep conditioning produced by `buildSanaTimestepConditioning` and
 * consumed by `runSanaTransformer`.
 */
export interface SanaTimestepState {
  /** Modulation tensor (presumably the per-block modulation input — confirm against sana-transformer.js). */
  modulation: Tensor;
  /** Embedded timestep tensor. */
  embeddedTimestep: Tensor;
}
|
|
8
|
+
|
|
9
|
+
/** Options shared by the Sana transformer entry points. */
export interface SanaTransformerOptions {
  /** Optional command recorder; may be omitted or null. */
  recorder?: CommandRecorder | null;
}
|
|
12
|
+
|
|
13
|
+
/**
 * Build the timestep conditioning state for a Sana transformer step.
 *
 * @param timestep Timestep value (callers may pre-scale it — confirm expected units).
 * @param guidanceScale Guidance scale for the current request.
 * @param weightsEntry Transformer weights entry.
 * @param config Transformer component config.
 * @param runtime Diffusion runtime config.
 * @param options Optional recorder settings.
 * @returns The modulation and embedded-timestep tensors.
 */
export declare function buildSanaTimestepConditioning(
  timestep: number,
  guidanceScale: number,
  weightsEntry: any,
  config: any,
  runtime: any,
  options?: SanaTransformerOptions
): Promise<SanaTimestepState>;
|
|
21
|
+
|
|
22
|
+
/**
 * Project a text-encoder context tensor for use by the Sana transformer.
 *
 * @param context Text-encoder context tensor.
 * @param attentionMask Attention mask for the context; may be null or undefined.
 * @param weightsEntry Transformer weights entry.
 * @param config Transformer component config.
 * @param runtime Diffusion runtime config.
 * @param options Optional recorder settings.
 * @returns The projected context tensor.
 */
export declare function projectSanaContext(
  context: Tensor,
  attentionMask: Uint32Array | null | undefined,
  weightsEntry: any,
  config: any,
  runtime: any,
  options?: SanaTransformerOptions
): Promise<Tensor>;
|
|
30
|
+
|
|
31
|
+
/**
 * Run the Sana transformer for one denoising step.
 *
 * @param latents Current latent tensor.
 * @param context Projected text context (see `projectSanaContext`).
 * @param timeState Timestep conditioning (see `buildSanaTimestepConditioning`).
 * @param weightsEntry Transformer weights entry.
 * @param modelConfig Diffusion model config.
 * @param runtime Diffusion runtime config.
 * @param options Optional recorder settings.
 * @returns The transformer output tensor for this step.
 */
export declare function runSanaTransformer(
  latents: Tensor,
  context: Tensor,
  timeState: SanaTimestepState,
  weightsEntry: any,
  modelConfig: any,
  runtime: any,
  options?: SanaTransformerOptions
): Promise<Tensor>;
|
|
40
|
+
|
|
41
|
+
/**
 * Build both conditioning inputs for the Sana transformer in one call.
 *
 * NOTE(review): the parameter list is the union of `projectSanaContext` and
 * `buildSanaTimestepConditioning`, so this presumably wraps both — confirm
 * against the implementation in sana-transformer.js.
 *
 * @param context Text-encoder context tensor.
 * @param attentionMask Attention mask for the context; may be null or undefined.
 * @param timestep Timestep value.
 * @param guidanceScale Guidance scale for the current request.
 * @param weightsEntry Transformer weights entry.
 * @param modelConfig Diffusion model config.
 * @param runtime Diffusion runtime config.
 * @param options Optional recorder settings.
 * @returns The context tensor and the timestep conditioning state.
 */
export declare function buildSanaConditioning(
  context: Tensor,
  attentionMask: Uint32Array | null | undefined,
  timestep: number,
  guidanceScale: number,
  weightsEntry: any,
  modelConfig: any,
  runtime: any,
  options?: SanaTransformerOptions
): Promise<{
  context: Tensor;
  timeState: SanaTimestepState;
}>;
|