@simulatte/doppler 0.1.3 → 0.1.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +11 -5
- package/package.json +27 -4
- package/src/client/doppler-api.browser.d.ts +1 -0
- package/src/client/doppler-api.browser.js +288 -0
- package/src/client/doppler-api.d.ts +80 -0
- package/src/client/doppler-api.js +298 -0
- package/src/client/doppler-provider/types.js +1 -1
- package/src/client/doppler-registry.d.ts +23 -0
- package/src/client/doppler-registry.js +88 -0
- package/src/client/doppler-registry.json +39 -0
- package/src/config/execution-contract-check.d.ts +82 -0
- package/src/config/execution-contract-check.js +317 -0
- package/src/config/execution-v0-contract-check.d.ts +94 -0
- package/src/config/execution-v0-contract-check.js +251 -0
- package/src/config/execution-v0-graph-contract-check.d.ts +20 -0
- package/src/config/execution-v0-graph-contract-check.js +64 -0
- package/src/config/kernel-path-contract-check.d.ts +76 -0
- package/src/config/kernel-path-contract-check.js +479 -0
- package/src/config/kernel-path-loader.d.ts +16 -0
- package/src/config/kernel-path-loader.js +54 -0
- package/src/config/kernels/kernel-ref-digests.js +12 -0
- package/src/config/kernels/registry.json +556 -0
- package/src/config/loader.js +90 -67
- package/src/config/merge-contract-check.d.ts +16 -0
- package/src/config/merge-contract-check.js +321 -0
- package/src/config/merge-helpers.d.ts +58 -0
- package/src/config/merge-helpers.js +54 -0
- package/src/config/merge.js +3 -6
- package/src/config/presets/models/janus-text.json +27 -0
- package/src/config/quantization-contract-check.d.ts +12 -0
- package/src/config/quantization-contract-check.js +91 -0
- package/src/config/required-inference-fields-contract-check.d.ts +24 -0
- package/src/config/required-inference-fields-contract-check.js +231 -0
- package/src/config/schema/browser-suite-metrics.schema.d.ts +17 -0
- package/src/config/schema/browser-suite-metrics.schema.js +46 -0
- package/src/config/schema/conversion-report.schema.d.ts +40 -0
- package/src/config/schema/conversion-report.schema.js +108 -0
- package/src/config/schema/doppler.schema.js +12 -18
- package/src/config/schema/index.d.ts +22 -0
- package/src/config/schema/index.js +18 -0
- package/src/converter/core.d.ts +10 -0
- package/src/converter/core.js +49 -11
- package/src/converter/parsers/diffusion.js +63 -3
- package/src/converter/tokenizer-utils.js +17 -3
- package/src/formats/rdrr/validation.js +13 -0
- package/src/gpu/kernels/depthwise_conv2d.d.ts +29 -0
- package/src/gpu/kernels/depthwise_conv2d.js +98 -0
- package/src/gpu/kernels/depthwise_conv2d.wgsl +58 -0
- package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +62 -0
- package/src/gpu/kernels/grouped_pointwise_conv2d.d.ts +27 -0
- package/src/gpu/kernels/grouped_pointwise_conv2d.js +92 -0
- package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +47 -0
- package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +51 -0
- package/src/gpu/kernels/index.d.ts +30 -0
- package/src/gpu/kernels/index.js +25 -0
- package/src/gpu/kernels/relu.d.ts +18 -0
- package/src/gpu/kernels/relu.js +45 -0
- package/src/gpu/kernels/relu.wgsl +21 -0
- package/src/gpu/kernels/relu_f16.wgsl +23 -0
- package/src/gpu/kernels/repeat_channels.d.ts +21 -0
- package/src/gpu/kernels/repeat_channels.js +60 -0
- package/src/gpu/kernels/repeat_channels.wgsl +29 -0
- package/src/gpu/kernels/repeat_channels_f16.wgsl +31 -0
- package/src/gpu/kernels/sana_linear_attention.d.ts +27 -0
- package/src/gpu/kernels/sana_linear_attention.js +122 -0
- package/src/gpu/kernels/sana_linear_attention_apply.wgsl +44 -0
- package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +47 -0
- package/src/gpu/kernels/sana_linear_attention_summary.wgsl +47 -0
- package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +49 -0
- package/src/index-browser.d.ts +1 -0
- package/src/index-browser.js +2 -1
- package/src/index.d.ts +1 -0
- package/src/index.js +2 -1
- package/src/inference/browser-harness.js +164 -38
- package/src/inference/pipelines/diffusion/init.js +14 -0
- package/src/inference/pipelines/diffusion/pipeline.js +206 -77
- package/src/inference/pipelines/diffusion/sana-transformer.d.ts +53 -0
- package/src/inference/pipelines/diffusion/sana-transformer.js +738 -0
- package/src/inference/pipelines/diffusion/scheduler.d.ts +17 -1
- package/src/inference/pipelines/diffusion/scheduler.js +91 -3
- package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +6 -4
- package/src/inference/pipelines/diffusion/text-encoder-gpu.js +270 -0
- package/src/inference/pipelines/diffusion/text-encoder.js +18 -1
- package/src/inference/pipelines/diffusion/types.d.ts +4 -0
- package/src/inference/pipelines/diffusion/vae.js +782 -78
- package/src/inference/pipelines/text/config.d.ts +5 -0
- package/src/inference/pipelines/text/config.js +1 -1
- package/src/inference/pipelines/text/execution-v0.js +141 -101
- package/src/inference/pipelines/text/init.js +41 -10
- package/src/inference/pipelines/text.js +7 -1
- package/src/rules/execution-rules-contract-check.d.ts +17 -0
- package/src/rules/execution-rules-contract-check.js +245 -0
- package/src/rules/kernels/depthwise-conv2d.rules.json +6 -0
- package/src/rules/kernels/grouped-pointwise-conv2d.rules.json +6 -0
- package/src/rules/kernels/relu.rules.json +6 -0
- package/src/rules/kernels/repeat-channels.rules.json +6 -0
- package/src/rules/kernels/sana-linear-attention.rules.json +6 -0
- package/src/rules/layer-pattern-contract-check.d.ts +17 -0
- package/src/rules/layer-pattern-contract-check.js +231 -0
- package/src/rules/rule-registry.d.ts +28 -0
- package/src/rules/rule-registry.js +38 -0
- package/src/tooling/conversion-config-materializer.d.ts +24 -0
- package/src/tooling/conversion-config-materializer.js +99 -0
- package/src/tooling/lean-execution-contract-runner.d.ts +43 -0
- package/src/tooling/lean-execution-contract-runner.js +158 -0
- package/src/tooling/lean-execution-contract.d.ts +16 -0
- package/src/tooling/lean-execution-contract.js +81 -0
- package/src/tooling/node-convert.d.ts +10 -0
- package/src/tooling/node-converter.js +59 -0
- package/src/tooling/node-webgpu.js +30 -9
- package/src/version.d.ts +2 -0
- package/src/version.js +2 -0
- package/tools/convert-safetensors-node.js +47 -0
- package/tools/doppler-cli.js +167 -6
|
@@ -9,11 +9,27 @@ import type { DiffusionSchedulerConfig } from './types.js';
|
|
|
9
9
|
export interface DiffusionScheduler {
|
|
10
10
|
type: string;
|
|
11
11
|
steps: number;
|
|
12
|
-
sigmas: Float32Array;
|
|
12
|
+
sigmas: Float32Array | null;
|
|
13
13
|
timesteps: Float32Array;
|
|
14
|
+
predictionType?: string;
|
|
15
|
+
sigmaData?: number;
|
|
16
|
+
}
|
|
17
|
+
|
|
18
|
+
/**
 * Result of one scheduler step; this is the declared return type of
 * `stepScmScheduler`.
 */
export interface DiffusionSchedulerStepResult {
  /** Sample produced by the step (the value to carry into the next step). */
  prevSample: Float32Array;
  /** The step's estimate of the original (fully denoised) sample. */
  predOriginalSample: Float32Array;
}
|
|
15
22
|
|
|
16
23
|
/**
 * Builds a scheduler description from a scheduler config.
 *
 * @param config - Scheduler configuration.
 * @param stepsOverride - Optional step count; may be omitted or `null`.
 *   NOTE(review): presumably replaces the step count from `config` when given —
 *   confirm against the implementation.
 * @returns The constructed scheduler.
 */
export declare function buildScheduler(
  config: DiffusionSchedulerConfig,
  stepsOverride?: number | null
): DiffusionScheduler;
|
|
27
|
+
|
|
28
|
+
/**
 * Performs one SCM scheduler step.
 *
 * @param config - Scheduler object (note: a built `DiffusionScheduler`, not the
 *   raw `DiffusionSchedulerConfig`).
 * @param modelOutput - Model prediction for the current step.
 * @param timestep - Current timestep value.
 * @param sample - Current sample.
 * @param stepIndex - Optional index of the current step within the schedule.
 * @param noise - Optional noise tensor; may be `null`.
 * @returns The next sample together with the predicted original sample.
 */
export declare function stepScmScheduler(
  config: DiffusionScheduler,
  modelOutput: Float32Array,
  timestep: number,
  sample: Float32Array,
  stepIndex?: number,
  noise?: Float32Array | null
): DiffusionSchedulerStepResult;
|
|
@@ -34,6 +34,84 @@ function buildFlowMatchSchedule(config, steps) {
|
|
|
34
34
|
return sigmas;
|
|
35
35
|
}
|
|
36
36
|
|
|
37
|
+
/**
 * Builds the descending timestep boundaries used by the SCM scheduler.
 *
 * The result holds one more entry than the number of steps, so step `i` can read
 * its start boundary at `[i]` and its end boundary at `[i + 1]`. The final entry
 * is always 0.
 *
 * @param {number} steps - Requested number of denoising steps. Values below 1 are
 *   treated as 1; fractional values are rounded down so that e.g. 1.5 takes the
 *   single-step path instead of reaching `linspace` with a non-integer count.
 * @param {object} config - Scheduler config. `maxTimesteps` and
 *   `intermediateTimesteps` are used when they are finite numbers.
 * @returns {Float32Array} Timestep boundaries of length `steps + 1`.
 */
function buildScmTimesteps(steps, config) {
  // Defaults applied when the config does not provide finite values.
  const DEFAULT_MAX_TIMESTEPS = 1.5708;
  const DEFAULT_INTERMEDIATE_TIMESTEPS = 1.3;

  const maxTimesteps = Number.isFinite(config.maxTimesteps)
    ? config.maxTimesteps
    : DEFAULT_MAX_TIMESTEPS;
  const intermediateTimesteps = Number.isFinite(config.intermediateTimesteps)
    ? config.intermediateTimesteps
    : DEFAULT_INTERMEDIATE_TIMESTEPS;

  // Floor before clamping: the special cases below compare with ===, so a
  // fractional count would otherwise skip them.
  const count = Math.max(1, Math.floor(steps));
  if (count === 1) {
    return new Float32Array([maxTimesteps, 0.0]);
  }
  if (count === 2) {
    // Two-step sampling inserts the configured intermediate timestep.
    return new Float32Array([maxTimesteps, intermediateTimesteps, 0.0]);
  }
  return linspace(maxTimesteps, 0.0, count + 1);
}
|
|
51
|
+
|
|
52
|
+
/**
 * Advances an SCM ("trigflow" parameterization) sampler by one step.
 *
 * With s = timesteps[stepIndex] and t = timesteps[stepIndex + 1]:
 *   predOriginalSample = cos(s) * sample - sin(s) * modelOutput
 *   prevSample         = cos(t) * predOriginalSample + sin(t) * noise * sigmaData
 * On the last step of the schedule no noise is mixed in and `prevSample` is a
 * copy of `predOriginalSample`.
 *
 * @param {object} config - Scheduler with `type === 'scm'`, a Float32Array
 *   `timesteps` of length >= 2, and optional `predictionType` / `sigmaData`.
 * @param {Float32Array} modelOutput - Model prediction for the current step.
 * @param {number} timestep - Not read by this function; the step boundaries come
 *   from `config.timesteps` and `stepIndex`.
 * @param {Float32Array} sample - Current sample, same length as `modelOutput`.
 * @param {number} [stepIndex=0] - Index of the current step in the schedule.
 * @param {Float32Array|null} [noise=null] - Noise with the same length as
 *   `sample`; required on every step except the last.
 * @returns {{prevSample: Float32Array, predOriginalSample: Float32Array}}
 * @throws {Error} On a non-SCM scheduler, bad tensor types or sizes, an invalid
 *   `stepIndex`, an unsupported prediction type, or missing noise.
 */
export function stepScmScheduler(config, modelOutput, timestep, sample, stepIndex = 0, noise = null) {
  if (config?.type !== 'scm') {
    throw new Error('stepScmScheduler requires scheduler.type="scm".');
  }
  if (!(modelOutput instanceof Float32Array)) {
    throw new Error('stepScmScheduler requires modelOutput as Float32Array.');
  }
  if (!(sample instanceof Float32Array)) {
    throw new Error('stepScmScheduler requires sample as Float32Array.');
  }
  const elementCount = sample.length;
  if (modelOutput.length !== elementCount) {
    throw new Error(
      `stepScmScheduler requires modelOutput and sample with matching sizes; got ${modelOutput.length} and ${sample.length}.`
    );
  }

  const schedule = config.timesteps;
  if (!(schedule instanceof Float32Array) || schedule.length < 2) {
    throw new Error('stepScmScheduler requires scheduler.timesteps with length >= 2.');
  }
  const stepIsValid = Number.isInteger(stepIndex)
    && stepIndex >= 0
    && stepIndex + 1 < schedule.length;
  if (!stepIsValid) {
    throw new Error(
      `stepScmScheduler received invalid stepIndex=${stepIndex} for ${schedule.length} timesteps.`
    );
  }

  const parameterization = config.predictionType ?? 'trigflow';
  if (parameterization !== 'trigflow') {
    throw new Error(`Unsupported SCM predictionType "${parameterization}".`);
  }

  const from = schedule[stepIndex];
  const to = schedule[stepIndex + 1];
  const cosFrom = Math.cos(from);
  const sinFrom = Math.sin(from);

  // Stored as float32 first: the re-noising below reads these rounded values.
  const predOriginalSample = Float32Array.from(
    sample,
    (value, idx) => cosFrom * value - sinFrom * modelOutput[idx]
  );

  const hasFollowingStep = stepIndex + 1 < schedule.length - 1;
  if (!hasFollowingStep) {
    // Final step: the denoised estimate is the output as-is.
    return { prevSample: predOriginalSample.slice(), predOriginalSample };
  }

  const noiseIsUsable = noise instanceof Float32Array && noise.length === elementCount;
  if (!noiseIsUsable) {
    throw new Error(
      'stepScmScheduler requires a Float32Array noise tensor for multi-step SCM updates.'
    );
  }

  const sigmaData = Number.isFinite(config.sigmaData) ? config.sigmaData : 0.5;
  const cosTo = Math.cos(to);
  const sinTo = Math.sin(to);
  const prevSample = predOriginalSample.map(
    (x0, idx) => cosTo * x0 + sinTo * noise[idx] * sigmaData
  );

  return { prevSample, predOriginalSample };
}
|
|
114
|
+
|
|
37
115
|
export function buildScheduler(config, stepsOverride = null) {
|
|
38
116
|
if (!config) {
|
|
39
117
|
throw new Error('Scheduler config is required');
|
|
@@ -43,15 +121,25 @@ export function buildScheduler(config, stepsOverride = null) {
|
|
|
43
121
|
if (typeof type !== 'string' || !type) {
|
|
44
122
|
throw new Error('Diffusion scheduler requires a scheduler type.');
|
|
45
123
|
}
|
|
46
|
-
const sigmas = type === 'flowmatch_euler'
|
|
47
|
-
? buildFlowMatchSchedule(config, steps)
|
|
48
|
-
: buildLinearSigmaSchedule(steps);
|
|
49
124
|
const trainSteps = Number.isFinite(config.numTrainTimesteps)
|
|
50
125
|
? config.numTrainTimesteps
|
|
51
126
|
: null;
|
|
52
127
|
if (!Number.isFinite(trainSteps) || trainSteps <= 0) {
|
|
53
128
|
throw new Error('Diffusion scheduler requires valid numTrainTimesteps.');
|
|
54
129
|
}
|
|
130
|
+
if (type === 'scm') {
|
|
131
|
+
return {
|
|
132
|
+
type,
|
|
133
|
+
steps,
|
|
134
|
+
sigmas: null,
|
|
135
|
+
timesteps: buildScmTimesteps(steps, config),
|
|
136
|
+
predictionType: config.predictionType ?? 'trigflow',
|
|
137
|
+
sigmaData: Number.isFinite(config.sigmaData) ? config.sigmaData : 0.5,
|
|
138
|
+
};
|
|
139
|
+
}
|
|
140
|
+
const sigmas = type === 'flowmatch_euler'
|
|
141
|
+
? buildFlowMatchSchedule(config, steps)
|
|
142
|
+
: buildLinearSigmaSchedule(steps);
|
|
55
143
|
const timesteps = new Float32Array(steps);
|
|
56
144
|
for (let i = 0; i < steps; i++) {
|
|
57
145
|
timesteps[i] = sigmas[i] * trainSteps;
|
|
@@ -16,25 +16,27 @@ export interface DiffusionTextEncoderWeightsEntry {
|
|
|
16
16
|
|
|
17
17
|
export interface DiffusionTextEncoderWeights {
|
|
18
18
|
text_encoder: DiffusionTextEncoderWeightsEntry;
|
|
19
|
-
text_encoder_2
|
|
20
|
-
text_encoder_3
|
|
19
|
+
text_encoder_2?: DiffusionTextEncoderWeightsEntry | null;
|
|
20
|
+
text_encoder_3?: DiffusionTextEncoderWeightsEntry | null;
|
|
21
21
|
transformer?: DiffusionTextEncoderWeightsEntry;
|
|
22
22
|
}
|
|
23
23
|
|
|
24
24
|
export interface DiffusionTextTokens {
|
|
25
25
|
text_encoder: number[];
|
|
26
|
-
text_encoder_2
|
|
27
|
-
text_encoder_3
|
|
26
|
+
text_encoder_2?: number[];
|
|
27
|
+
text_encoder_3?: number[];
|
|
28
28
|
}
|
|
29
29
|
|
|
30
30
|
/**
 * Text conditioning produced by the diffusion text encoders.
 * NOTE(review): presumably the result shape of `runTextEncodersForPrompt` — confirm.
 */
export interface DiffusionTextConditioning {
  /** Pooled embedding vector. */
  pooled: Float32Array;
  /** Per-token context tensor. */
  context: Tensor;
  /** Optional per-token attention mask; may be `null` or omitted. */
  attentionMask?: Uint32Array | null;
  /** Optional timing breakdown in milliseconds; individual entries may be `null`. */
  profile?: {
    totalMs?: number | null;
    clipMs?: number | null;
    clip2Ms?: number | null;
    t5Ms?: number | null;
    gemmaMs?: number | null;
  } | null;
}
|
|
40
42
|
|
|
@@ -40,6 +40,8 @@ import {
|
|
|
40
40
|
inferDiffusionMatmulDtypeFromBuffer,
|
|
41
41
|
sumDiffusionProfileTimings,
|
|
42
42
|
} from './helpers.js';
|
|
43
|
+
import { initRoPEFrequencies } from '../text/init.js';
|
|
44
|
+
import { processLayerGPU } from '../text/layer.js';
|
|
43
45
|
|
|
44
46
|
const QUICK_GELU_ALPHA = 1.702;
|
|
45
47
|
const SUPPORTED_CLIP_HIDDEN_ACTIVATIONS = new Set(['gelu', 'quick_gelu']);
|
|
@@ -56,6 +58,16 @@ function padTokens(tokens, maxLength, padTokenId) {
|
|
|
56
58
|
return out;
|
|
57
59
|
}
|
|
58
60
|
|
|
61
|
+
/**
 * Converts a token-id list into a non-empty `Uint32Array`, truncated to `maxLength`.
 *
 * Accepts plain arrays and numeric typed arrays (e.g. `Uint32Array`, `Int32Array`).
 * Typed arrays fail `Array.isArray`, so without the extra check they would be
 * treated as "no tokens" and silently replaced by the fallback token.
 *
 * @param {ArrayLike<number>|null|undefined} tokens - Token ids; anything that is
 *   not an array or typed array is treated as empty.
 * @param {number} maxLength - Maximum number of tokens to keep. Non-finite or
 *   non-positive values mean "no limit"; fractional values are rounded down.
 * @param {number} fallbackTokenId - Token used when no tokens remain; coerced to
 *   an unsigned 32-bit integer (`undefined` becomes 0).
 * @returns {Uint32Array} At least one token id.
 */
function normalizeTokens(tokens, maxLength, fallbackTokenId) {
  // DataView is an ArrayBuffer view too, but it is not an indexable token list.
  const isTypedTokenArray = ArrayBuffer.isView(tokens) && !(tokens instanceof DataView);
  const source = Array.isArray(tokens) || isTypedTokenArray ? tokens : [];
  const limit = Number.isFinite(maxLength) && maxLength > 0 ? Math.floor(maxLength) : source.length;
  const trimmed = source.slice(0, limit);
  if (trimmed.length > 0) {
    return Uint32Array.from(trimmed);
  }
  // Never return an empty sequence: downstream code sizes buffers from the length.
  return new Uint32Array([fallbackTokenId >>> 0]);
}
|
|
70
|
+
|
|
59
71
|
function findEosIndex(tokens, eosTokenId) {
|
|
60
72
|
if (eosTokenId == null) return tokens.length - 1;
|
|
61
73
|
for (let i = 0; i < tokens.length; i++) {
|
|
@@ -702,7 +714,264 @@ async function runT5Encoder(tokens, weightsEntry, config, runtime, options = {})
|
|
|
702
714
|
};
|
|
703
715
|
}
|
|
704
716
|
|
|
717
|
+
/**
 * Produces the per-layer attention type list for a Gemma2 stack.
 *
 * When `slidingWindow` is a positive finite number, layers alternate starting
 * with sliding attention: even indices are `'sliding_attention'`, odd indices are
 * `'full_attention'`. Otherwise every layer is `'full_attention'`.
 *
 * @param {number} layerCount - Number of layers.
 * @param {number} slidingWindow - Sliding-window size, or a non-positive /
 *   non-finite value to disable sliding attention.
 * @returns {string[]} One attention type per layer.
 */
function buildGemma2LayerTypes(layerCount, slidingWindow) {
  const slidingEnabled = Number.isFinite(slidingWindow) && slidingWindow > 0;
  const attentionTypeFor = slidingEnabled
    ? (layerIdx) => (layerIdx % 2 === 0 ? 'sliding_attention' : 'full_attention')
    : () => 'full_attention';
  return Array.from({ length: layerCount }, (_unused, layerIdx) => attentionTypeFor(layerIdx));
}
|
|
725
|
+
|
|
726
|
+
/**
 * Looks up one per-layer Gemma2 weight by its `<prefix>.model.layers.<i>.<suffix>` name.
 *
 * @param {{get(key: string): *}} weights - Weight lookup (anything with `get`).
 * @param {string} prefix - Component prefix, e.g. `text_encoder`.
 * @param {number} layerIdx - Layer index.
 * @param {string} suffix - Weight name within the layer.
 * @param {boolean} [required=true] - Whether a missing weight is an error.
 * @returns {*} The weight, or `null` when it is absent and not required.
 * @throws {Error} When the weight is absent (or falsy) and `required` is truthy.
 */
function getGemma2LayerWeight(weights, prefix, layerIdx, suffix, required = true) {
  const weightName = `${prefix}.model.layers.${layerIdx}.${suffix}`;
  const found = weights.get(weightName);
  if (found) {
    return found;
  }
  if (required) {
    throw new Error(`Missing Gemma2 diffusion weight "${weightName}".`);
  }
  return null;
}
|
|
734
|
+
|
|
735
|
+
/**
 * Validates a Gemma2 text-encoder config and maps it to camelCase settings.
 *
 * Defaults: `num_key_value_heads` falls back to `num_attention_heads`; `head_dim`
 * falls back to `floor(hidden_size / num_attention_heads)`; `rms_norm_eps` to
 * 1e-6; `rope_theta` to 10000; `sliding_window` to 4096. Embedding scaling is
 * enabled unless `scale_embeddings` is exactly `false`.
 *
 * @param {object} config - Encoder config using snake_case field names.
 * @returns {object} Resolved settings (hiddenSize, numHeads, numKVHeads, headDim,
 *   numLayers, intermediateSize, maxPositionEmbeddings, rmsNormEps, ropeTheta,
 *   slidingWindow, scaleEmbeddings).
 * @throws {Error} When a required dimension is missing, non-finite, or not positive.
 */
function resolveGemma2TextConfig(config) {
  const isPositiveFinite = (value) => Number.isFinite(value) && value > 0;

  const hiddenSize = config.hidden_size;
  const numHeads = config.num_attention_heads;
  const numKVHeads = config.num_key_value_heads ?? numHeads;

  let headDim = config.head_dim;
  if (headDim == null) {
    // Derive the head size only when both operands are usable.
    const canDerive = Number.isFinite(hiddenSize) && Number.isFinite(numHeads) && numHeads > 0;
    headDim = canDerive ? Math.floor(hiddenSize / numHeads) : null;
  }

  const numLayers = config.num_hidden_layers;
  const intermediateSize = config.intermediate_size;
  const maxPositionEmbeddings = config.max_position_embeddings;
  const rmsNormEps = config.rms_norm_eps ?? 1e-6;

  // Checked in this order; the first failing entry decides the error message.
  const requiredDimensions = [
    [hiddenSize, 'hidden_size'],
    [numHeads, 'num_attention_heads'],
    [numKVHeads, 'num_key_value_heads'],
    [headDim, 'head_dim or hidden_size/num_attention_heads'],
    [numLayers, 'num_hidden_layers'],
    [intermediateSize, 'intermediate_size'],
    [maxPositionEmbeddings, 'max_position_embeddings'],
  ];
  for (const [value, requirement] of requiredDimensions) {
    if (!isPositiveFinite(value)) {
      throw new Error(`Gemma2 diffusion text encoder requires ${requirement}.`);
    }
  }

  return {
    hiddenSize,
    numHeads,
    numKVHeads,
    headDim,
    numLayers,
    intermediateSize,
    maxPositionEmbeddings,
    rmsNormEps,
    ropeTheta: config.rope_theta ?? 10000,
    slidingWindow: config.sliding_window ?? 4096,
    scaleEmbeddings: config.scale_embeddings !== false,
  };
}
|
|
785
|
+
|
|
786
|
+
/**
 * Runs a Gemma2 text encoder for diffusion conditioning.
 *
 * Steps, in order: gather token embeddings, optionally scale them by
 * sqrt(hiddenSize), run every layer through `processLayerGPU`, apply the final
 * RMSNorm, and (when this call owns a profiling recorder) submit it and collect
 * timings.
 *
 * @param {number[]} tokens - Prompt token ids; normalized/truncated by `normalizeTokens`.
 * @param {object} weightsEntry - Must expose `weights` (queried via `.get(key)`) and `shapes`.
 * @param {object} config - Encoder config with snake_case fields (see `resolveGemma2TextConfig`);
 *   `pad_token_id`, `bos_token_id` and `vocab_size` are also read here.
 * @param {object} runtime - Runtime settings passed to the dtype resolvers.
 * @param {object} [options]
 * @param {string} [options.prefix='text_encoder'] - Weight-name prefix.
 * @param {object} [options.recorder] - External command recorder; when given, no local one is created.
 * @param {boolean} [options.profile] - Create a local profiling recorder (only if no external recorder).
 * @param {number} [options.maxLength] - Token limit; defaults to `max_position_embeddings`.
 * @returns {Promise<{hidden: *, attentionMask: Uint32Array, maxLength: number, hiddenSize: number, profile: object|null}>}
 * @throws {Error} When no WebGPU device is available, weights are not loaded, or a
 *   required weight is missing.
 */
async function runGemma2TextEncoder(tokens, weightsEntry, config, runtime, options = {}) {
  const device = getDevice();
  if (!device) throw new Error('Gemma2 diffusion text encoder requires a WebGPU device.');
  if (!weightsEntry?.weights || !weightsEntry?.shapes) {
    throw new Error('Gemma2 diffusion text encoder requires loaded weights.');
  }

  const prefix = options.prefix ?? 'text_encoder';
  // A local recorder exists only when the caller supplied none AND asked for profiling;
  // otherwise `recorder` is the caller's recorder or null.
  const localRecorder = options.recorder
    ? null
    : (options.profile ? new CommandRecorder(device, `${prefix}_gemma2_encoder`, { profile: true }) : null);
  const recorder = options.recorder ?? localRecorder;
  const ops = createKernelOps(recorder);
  const release = createDiffusionBufferReleaser(recorder);
  const destroy = createDiffusionBufferDestroyer(recorder);
  const weights = weightsEntry.weights;
  const activationDtype = resolveDiffusionActivationDtype(runtime);
  const resolved = resolveGemma2TextConfig(config);
  // Used only as the fallback token when the prompt yields no tokens.
  const padTokenId = config.pad_token_id ?? config.bos_token_id ?? 0;
  // NOTE(review): options.maxLength is not clamped to maxPositionEmbeddings, while the
  // RoPE table below is built for maxPositionEmbeddings — confirm callers never exceed it.
  const tokenIds = normalizeTokens(tokens, options.maxLength ?? resolved.maxPositionEmbeddings, padTokenId);
  const numTokens = tokenIds.length;
  const tokenBuffer = createDiffusionIndexBuffer(device, tokenIds, `${prefix}_tokens`);

  const embedKey = `${prefix}.model.embed_tokens.weight`;
  const embedWeight = expectDiffusionWeight(
    weights.get(embedKey),
    embedKey
  );
  const embedDtype = resolveEmbeddingDtype(embedWeight, weightsEntry, embedKey, runtime);
  // NOTE(review): config.vocab_size is passed through without validation.
  let hidden = await ops.gather(
    tokenBuffer,
    getBuffer(embedWeight),
    numTokens,
    resolved.hiddenSize,
    config.vocab_size,
    {
      embeddingDtype: embedDtype,
      outputDtype: activationDtype,
      transpose: false,
    }
  );
  // The index buffer is no longer needed once the embeddings are gathered.
  destroy(tokenBuffer);

  if (resolved.scaleEmbeddings) {
    // Scale embeddings by sqrt(hiddenSize); the unscaled buffer is released afterwards.
    const scaled = await ops.scale(hidden, Math.sqrt(resolved.hiddenSize), {
      count: numTokens * resolved.hiddenSize,
    });
    release(hidden.buffer);
    hidden = createTensor(scaled.buffer, scaled.dtype, [numTokens, resolved.hiddenSize], 'gemma2_embed');
  }

  // Collect every layer's weights up front, keyed `layer_<i>`; a missing weight throws here.
  const layerWeights = new Map();
  for (let layerIdx = 0; layerIdx < resolved.numLayers; layerIdx++) {
    layerWeights.set(`layer_${layerIdx}`, {
      inputNorm: getGemma2LayerWeight(weights, prefix, layerIdx, 'input_layernorm.weight'),
      qProj: getGemma2LayerWeight(weights, prefix, layerIdx, 'self_attn.q_proj.weight'),
      kProj: getGemma2LayerWeight(weights, prefix, layerIdx, 'self_attn.k_proj.weight'),
      vProj: getGemma2LayerWeight(weights, prefix, layerIdx, 'self_attn.v_proj.weight'),
      oProj: getGemma2LayerWeight(weights, prefix, layerIdx, 'self_attn.o_proj.weight'),
      postAttentionNorm: getGemma2LayerWeight(weights, prefix, layerIdx, 'post_attention_layernorm.weight'),
      preFeedforwardNorm: getGemma2LayerWeight(weights, prefix, layerIdx, 'pre_feedforward_layernorm.weight'),
      gate: getGemma2LayerWeight(weights, prefix, layerIdx, 'mlp.gate_proj.weight'),
      up: getGemma2LayerWeight(weights, prefix, layerIdx, 'mlp.up_proj.weight'),
      down: getGemma2LayerWeight(weights, prefix, layerIdx, 'mlp.down_proj.weight'),
    });
  }

  // Only the global RoPE settings are provided; all "local" variants and scaling are null.
  // NOTE(review): the meaning of the second argument (`true`) and the ownership of the
  // returned frequency tables are defined by initRoPEFrequencies — confirm they are not leaked.
  const ropeFreqs = await initRoPEFrequencies({
    headDim: resolved.headDim,
    maxSeqLen: resolved.maxPositionEmbeddings,
    ropeTheta: resolved.ropeTheta,
    ropeLocalTheta: null,
    ropeScale: 1,
    ropeLocalScale: null,
    ropeScalingType: null,
    ropeLocalScalingType: null,
    ropeScaling: null,
    ropeLocalScaling: null,
  }, true);

  // Layer-processing context for the shared text-pipeline layer runner:
  // no KV cache, sequence position 0.
  const context = {
    useGPU: true,
    activationDtype,
    recorder,
    currentSeqLen: 0,
    kvCache: null,
    weightConfig: {
      rmsNormWeightOffset: true,
    },
    debugFlags: {},
    weights: layerWeights,
    ropeFreqsCos: ropeFreqs.cos,
    ropeFreqsSin: ropeFreqs.sin,
    ropeLocalCos: ropeFreqs.localCos,
    ropeLocalSin: ropeFreqs.localSin,
    config: {
      hiddenSize: resolved.hiddenSize,
      intermediateSize: resolved.intermediateSize,
      numHeads: resolved.numHeads,
      numKVHeads: resolved.numKVHeads,
      headDim: resolved.headDim,
      rmsNormEps: resolved.rmsNormEps,
      slidingWindow: resolved.slidingWindow,
      // NOTE(review): soft-capping and the query scalar are fixed here (50.0 and headDim)
      // rather than read from `config` — confirm this matches the converted checkpoint.
      attnLogitSoftcapping: 50.0,
      queryPreAttnScalar: resolved.headDim,
      queryKeyNorm: false,
      attentionOutputGate: false,
      causalAttention: true,
      hiddenActivation: 'gelu',
      swigluLimit: null,
      useMoE: false,
      layerTypes: buildGemma2LayerTypes(resolved.numLayers, resolved.slidingWindow),
      preFeedforwardNorm: true,
      postFeedforwardNorm: false,
      postAttentionNorm: true,
    },
  };

  // NOTE(review): the previous `hidden.buffer` is not released in this loop — presumably
  // processLayerGPU takes ownership of its input buffer; confirm.
  for (let layerIdx = 0; layerIdx < resolved.numLayers; layerIdx++) {
    const output = await processLayerGPU(
      layerIdx,
      hidden.buffer,
      numTokens,
      true,
      numTokens * resolved.hiddenSize,
      context
    );
    hidden = createTensor(output.buffer, output.dtype, [numTokens, resolved.hiddenSize], `gemma2_layer_${layerIdx}`);
  }

  const finalNormKey = `${prefix}.model.norm.weight`;
  const finalNorm = expectDiffusionWeight(weights.get(finalNormKey), finalNormKey);
  const final = await ops.rmsNorm(hidden, getBuffer(finalNorm), resolved.rmsNormEps, {
    batchSize: numTokens,
    hiddenSize: resolved.hiddenSize,
    rmsNormWeightOffset: true,
  });
  release(hidden.buffer);

  // Timings are collected only when this call created its own recorder; an external
  // recorder is left for the caller to submit.
  let profile = null;
  if (localRecorder) {
    localRecorder.submit();
    const timings = await localRecorder.resolveProfileTimings();
    profile = timings ? { totalMs: sumDiffusionProfileTimings(timings) ?? 0, timings } : { totalMs: null };
  }

  return {
    hidden: final,
    // Every kept token is marked as attended (no padding is added by this function).
    attentionMask: Uint32Array.from({ length: numTokens }, () => 1),
    maxLength: numTokens,
    hiddenSize: resolved.hiddenSize,
    profile,
  };
}
|
|
940
|
+
|
|
705
941
|
export async function runTextEncodersForPrompt(tokensByEncoder, weightsByComponent, modelConfig, runtime, options = {}) {
|
|
942
|
+
const layout = modelConfig?.layout ?? 'sd3';
|
|
943
|
+
if (layout === 'sana') {
|
|
944
|
+
const gemmaConfig = modelConfig?.components?.text_encoder?.config || {};
|
|
945
|
+
const gemmaMaxLength = runtime?.textEncoder?.maxLength;
|
|
946
|
+
if (!Number.isFinite(gemmaMaxLength) || gemmaMaxLength <= 0) {
|
|
947
|
+
throw new Error('Sana Gemma2 encoder requires runtime.textEncoder.maxLength.');
|
|
948
|
+
}
|
|
949
|
+
const profileEnabled = options.profile === true;
|
|
950
|
+
const gemma = await runGemma2TextEncoder(
|
|
951
|
+
tokensByEncoder.text_encoder,
|
|
952
|
+
weightsByComponent.text_encoder,
|
|
953
|
+
gemmaConfig,
|
|
954
|
+
runtime,
|
|
955
|
+
{
|
|
956
|
+
prefix: 'text_encoder',
|
|
957
|
+
maxLength: gemmaMaxLength,
|
|
958
|
+
profile: profileEnabled,
|
|
959
|
+
}
|
|
960
|
+
);
|
|
961
|
+
|
|
962
|
+
return {
|
|
963
|
+
pooled: new Float32Array(0),
|
|
964
|
+
context: gemma.hidden,
|
|
965
|
+
attentionMask: gemma.attentionMask,
|
|
966
|
+
profile: profileEnabled
|
|
967
|
+
? {
|
|
968
|
+
totalMs: gemma.profile?.totalMs ?? null,
|
|
969
|
+
gemmaMs: gemma.profile?.totalMs ?? null,
|
|
970
|
+
}
|
|
971
|
+
: null,
|
|
972
|
+
};
|
|
973
|
+
}
|
|
974
|
+
|
|
706
975
|
const clipConfig = modelConfig?.components?.text_encoder?.config || {};
|
|
707
976
|
const clip2Config = modelConfig?.components?.text_encoder_2?.config || {};
|
|
708
977
|
const t5Config = modelConfig?.components?.text_encoder_3?.config || {};
|
|
@@ -749,6 +1018,7 @@ export async function runTextEncodersForPrompt(tokensByEncoder, weightsByCompone
|
|
|
749
1018
|
return {
|
|
750
1019
|
pooled,
|
|
751
1020
|
context: t5.hidden,
|
|
1021
|
+
attentionMask: null,
|
|
752
1022
|
profile,
|
|
753
1023
|
};
|
|
754
1024
|
}
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import { BPETokenizer } from '../../tokenizers/bpe.js';
|
|
2
2
|
import { SentencePieceTokenizer } from '../../tokenizers/sentencepiece.js';
|
|
3
|
+
import { BundledTokenizer } from '../../tokenizers/bundled.js';
|
|
3
4
|
import { loadAuxText, loadAuxFile } from '../../../storage/shard-manager.js';
|
|
4
5
|
|
|
5
6
|
function parseMerges(text) {
|
|
@@ -136,11 +137,27 @@ async function loadSentencePieceTokenizer(tokenizerConfig, options = {}) {
|
|
|
136
137
|
return tokenizer;
|
|
137
138
|
}
|
|
138
139
|
|
|
140
|
+
/**
 * Loads a bundled tokenizer from the JSON file named by
 * `tokenizerConfig.tokenizerFile`.
 *
 * @param {object} tokenizerConfig - Tokenizer entry; `tokenizerFile` is read.
 * @param {object} [options] - `baseUrl` is forwarded to the text-asset loader.
 * @returns {Promise<BundledTokenizer>} Tokenizer after `load()` was called with the parsed JSON.
 */
async function loadBundledTokenizer(tokenizerConfig, options = {}) {
  const baseUrl = options.baseUrl;
  const rawJson = await loadTextAsset(tokenizerConfig.tokenizerFile, baseUrl);
  const bundled = new BundledTokenizer({ vocabSize: 0, deferSpecialTokens: true });
  bundled.load(JSON.parse(rawJson));
  return bundled;
}
|
|
151
|
+
|
|
139
152
|
export async function loadDiffusionTokenizers(diffusionConfig, options = {}) {
|
|
140
153
|
const tokenizers = {};
|
|
141
154
|
const config = diffusionConfig?.tokenizers || {};
|
|
142
155
|
if (config.text_encoder) {
|
|
143
|
-
|
|
156
|
+
if (config.text_encoder.type === 'bundled') {
|
|
157
|
+
tokenizers.text_encoder = await loadBundledTokenizer(config.text_encoder, options);
|
|
158
|
+
} else {
|
|
159
|
+
tokenizers.text_encoder = await loadBpeTokenizer(config.text_encoder, options);
|
|
160
|
+
}
|
|
144
161
|
}
|
|
145
162
|
if (config.text_encoder_2) {
|
|
146
163
|
tokenizers.text_encoder_2 = await loadBpeTokenizer(config.text_encoder_2, options);
|
|
@@ -27,6 +27,10 @@ export interface DiffusionSchedulerConfig {
|
|
|
27
27
|
eta: number;
|
|
28
28
|
numTrainTimesteps: number;
|
|
29
29
|
shift: number;
|
|
30
|
+
predictionType?: string;
|
|
31
|
+
sigmaData?: number;
|
|
32
|
+
maxTimesteps?: number;
|
|
33
|
+
intermediateTimesteps?: number;
|
|
30
34
|
}
|
|
31
35
|
|
|
32
36
|
export interface DiffusionLatentConfig {
|