@simulatte/doppler 0.1.4 → 0.1.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (199)
  1. package/README.md +26 -10
  2. package/package.json +30 -6
  3. package/src/client/doppler-api.browser.d.ts +1 -0
  4. package/src/client/doppler-api.browser.js +288 -0
  5. package/src/client/doppler-api.js +1 -1
  6. package/src/client/doppler-provider/types.js +1 -1
  7. package/src/config/execution-contract-check.d.ts +33 -0
  8. package/src/config/execution-contract-check.js +72 -0
  9. package/src/config/execution-v0-contract-check.d.ts +94 -0
  10. package/src/config/execution-v0-contract-check.js +251 -0
  11. package/src/config/execution-v0-graph-contract-check.d.ts +20 -0
  12. package/src/config/execution-v0-graph-contract-check.js +64 -0
  13. package/src/config/kernel-path-contract-check.d.ts +76 -0
  14. package/src/config/kernel-path-contract-check.js +479 -0
  15. package/src/config/kernel-path-loader.d.ts +16 -0
  16. package/src/config/kernel-path-loader.js +54 -0
  17. package/src/config/kernels/kernel-ref-digests.js +39 -27
  18. package/src/config/kernels/registry.json +598 -2
  19. package/src/config/loader.js +81 -48
  20. package/src/config/merge-contract-check.d.ts +16 -0
  21. package/src/config/merge-contract-check.js +321 -0
  22. package/src/config/merge-helpers.d.ts +58 -0
  23. package/src/config/merge-helpers.js +54 -0
  24. package/src/config/merge.js +21 -6
  25. package/src/config/presets/models/janus-text.json +2 -0
  26. package/src/config/presets/models/qwen3.json +9 -2
  27. package/src/config/presets/models/transformer.json +5 -0
  28. package/src/config/quantization-contract-check.d.ts +12 -0
  29. package/src/config/quantization-contract-check.js +91 -0
  30. package/src/config/required-inference-fields-contract-check.d.ts +24 -0
  31. package/src/config/required-inference-fields-contract-check.js +237 -0
  32. package/src/config/schema/browser-suite-metrics.schema.d.ts +17 -0
  33. package/src/config/schema/browser-suite-metrics.schema.js +46 -0
  34. package/src/config/schema/conversion-report.schema.d.ts +40 -0
  35. package/src/config/schema/conversion-report.schema.js +108 -0
  36. package/src/config/schema/doppler.schema.js +12 -18
  37. package/src/config/schema/index.d.ts +22 -0
  38. package/src/config/schema/index.js +18 -0
  39. package/src/config/schema/inference-defaults.schema.js +3 -0
  40. package/src/config/schema/inference.schema.d.ts +9 -0
  41. package/src/config/schema/kernel-path.schema.d.ts +6 -0
  42. package/src/config/schema/manifest.schema.d.ts +6 -0
  43. package/src/config/schema/manifest.schema.js +3 -0
  44. package/src/converter/core.d.ts +10 -0
  45. package/src/converter/core.js +27 -2
  46. package/src/converter/parsers/diffusion.js +63 -3
  47. package/src/converter/rope-config.js +42 -0
  48. package/src/gpu/device.js +58 -0
  49. package/src/gpu/kernels/attention.js +98 -0
  50. package/src/gpu/kernels/bias_add.wgsl +8 -6
  51. package/src/gpu/kernels/bias_add_f16.wgsl +8 -5
  52. package/src/gpu/kernels/conv2d.js +1 -1
  53. package/src/gpu/kernels/conv2d.wgsl +7 -8
  54. package/src/gpu/kernels/conv2d_f16.wgsl +7 -8
  55. package/src/gpu/kernels/depthwise_conv2d.d.ts +29 -0
  56. package/src/gpu/kernels/depthwise_conv2d.js +99 -0
  57. package/src/gpu/kernels/depthwise_conv2d.wgsl +55 -0
  58. package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +59 -0
  59. package/src/gpu/kernels/grouped_pointwise_conv2d.d.ts +27 -0
  60. package/src/gpu/kernels/grouped_pointwise_conv2d.js +93 -0
  61. package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +44 -0
  62. package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +48 -0
  63. package/src/gpu/kernels/index.d.ts +30 -0
  64. package/src/gpu/kernels/index.js +25 -0
  65. package/src/gpu/kernels/matmul.js +25 -0
  66. package/src/gpu/kernels/pixel_shuffle.js +1 -1
  67. package/src/gpu/kernels/pixel_shuffle.wgsl +4 -5
  68. package/src/gpu/kernels/pixel_shuffle_f16.wgsl +4 -5
  69. package/src/gpu/kernels/relu.d.ts +18 -0
  70. package/src/gpu/kernels/relu.js +58 -0
  71. package/src/gpu/kernels/relu.wgsl +22 -0
  72. package/src/gpu/kernels/relu_f16.wgsl +24 -0
  73. package/src/gpu/kernels/repeat_channels.d.ts +21 -0
  74. package/src/gpu/kernels/repeat_channels.js +60 -0
  75. package/src/gpu/kernels/repeat_channels.wgsl +28 -0
  76. package/src/gpu/kernels/repeat_channels_f16.wgsl +30 -0
  77. package/src/gpu/kernels/residual.js +44 -8
  78. package/src/gpu/kernels/residual.wgsl +6 -3
  79. package/src/gpu/kernels/residual_f16.wgsl +2 -1
  80. package/src/gpu/kernels/residual_f16_vec4.wgsl +2 -1
  81. package/src/gpu/kernels/residual_vec4.wgsl +2 -1
  82. package/src/gpu/kernels/rmsnorm.js +58 -6
  83. package/src/gpu/kernels/rmsnorm.wgsl +14 -6
  84. package/src/gpu/kernels/rmsnorm_f16.wgsl +10 -2
  85. package/src/gpu/kernels/rope.d.ts +2 -0
  86. package/src/gpu/kernels/rope.js +11 -1
  87. package/src/gpu/kernels/rope.wgsl +56 -40
  88. package/src/gpu/kernels/sana_linear_attention.d.ts +27 -0
  89. package/src/gpu/kernels/sana_linear_attention.js +121 -0
  90. package/src/gpu/kernels/sana_linear_attention_apply.wgsl +43 -0
  91. package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +46 -0
  92. package/src/gpu/kernels/sana_linear_attention_summary.wgsl +51 -0
  93. package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +53 -0
  94. package/src/gpu/kernels/silu.d.ts +1 -0
  95. package/src/gpu/kernels/silu.js +32 -14
  96. package/src/gpu/kernels/silu.wgsl +19 -9
  97. package/src/gpu/kernels/silu_f16.wgsl +19 -9
  98. package/src/gpu/kernels/transpose.js +15 -2
  99. package/src/gpu/kernels/transpose.wgsl +5 -6
  100. package/src/gpu/kernels/upsample2d.js +2 -1
  101. package/src/gpu/kernels/upsample2d.wgsl +6 -9
  102. package/src/gpu/kernels/upsample2d_f16.wgsl +6 -9
  103. package/src/gpu/kernels/utils.js +16 -1
  104. package/src/index-browser.d.ts +1 -1
  105. package/src/index-browser.js +2 -2
  106. package/src/index.js +1 -1
  107. package/src/inference/browser-harness.js +109 -23
  108. package/src/inference/pipelines/diffusion/init.js +14 -0
  109. package/src/inference/pipelines/diffusion/pipeline.js +215 -77
  110. package/src/inference/pipelines/diffusion/sana-transformer.d.ts +53 -0
  111. package/src/inference/pipelines/diffusion/sana-transformer.js +738 -0
  112. package/src/inference/pipelines/diffusion/scheduler.d.ts +17 -1
  113. package/src/inference/pipelines/diffusion/scheduler.js +91 -3
  114. package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +11 -4
  115. package/src/inference/pipelines/diffusion/text-encoder-gpu.js +282 -0
  116. package/src/inference/pipelines/diffusion/text-encoder.js +18 -1
  117. package/src/inference/pipelines/diffusion/types.d.ts +4 -0
  118. package/src/inference/pipelines/diffusion/vae.js +782 -78
  119. package/src/inference/pipelines/text/attention/record.js +11 -2
  120. package/src/inference/pipelines/text/attention/run.js +11 -2
  121. package/src/inference/pipelines/text/chat-format.js +25 -1
  122. package/src/inference/pipelines/text/config.d.ts +9 -0
  123. package/src/inference/pipelines/text/config.js +69 -2
  124. package/src/inference/pipelines/text/execution-plan.js +23 -31
  125. package/src/inference/pipelines/text/execution-v0.js +43 -95
  126. package/src/inference/pipelines/text/ffn/standard.js +3 -0
  127. package/src/inference/pipelines/text/init.d.ts +4 -0
  128. package/src/inference/pipelines/text/init.js +56 -9
  129. package/src/inference/pipelines/text/layer.js +11 -0
  130. package/src/inference/pipelines/text.js +4 -0
  131. package/src/inference/tokenizers/bundled.js +156 -33
  132. package/src/rules/execution-rules-contract-check.d.ts +17 -0
  133. package/src/rules/execution-rules-contract-check.js +245 -0
  134. package/src/rules/kernels/depthwise-conv2d.rules.json +6 -0
  135. package/src/rules/kernels/grouped-pointwise-conv2d.rules.json +6 -0
  136. package/src/rules/kernels/relu.rules.json +6 -0
  137. package/src/rules/kernels/repeat-channels.rules.json +6 -0
  138. package/src/rules/kernels/sana-linear-attention.rules.json +6 -0
  139. package/src/rules/layer-pattern-contract-check.d.ts +17 -0
  140. package/src/rules/layer-pattern-contract-check.js +231 -0
  141. package/src/rules/rule-registry.d.ts +28 -0
  142. package/src/rules/rule-registry.js +38 -0
  143. package/src/rules/tooling/command-runtime.rules.json +18 -0
  144. package/src/tooling/command-api.d.ts +27 -1
  145. package/src/tooling/command-api.js +142 -3
  146. package/src/tooling/conversion-config-materializer.d.ts +24 -0
  147. package/src/tooling/conversion-config-materializer.js +99 -0
  148. package/src/tooling/lean-execution-contract-runner.d.ts +43 -0
  149. package/src/tooling/lean-execution-contract-runner.js +158 -0
  150. package/src/tooling/node-browser-command-runner.d.ts +4 -0
  151. package/src/tooling/node-browser-command-runner.js +58 -3
  152. package/src/tooling/node-command-runner.js +15 -0
  153. package/src/tooling/node-convert.d.ts +10 -0
  154. package/src/tooling/node-converter.js +59 -0
  155. package/src/tooling/node-webgpu.js +11 -89
  156. package/src/training/checkpoint-watch.d.ts +7 -0
  157. package/src/training/checkpoint-watch.js +106 -0
  158. package/src/training/checkpoint.d.ts +6 -1
  159. package/src/training/checkpoint.js +12 -2
  160. package/src/training/distillation/artifacts.d.ts +71 -0
  161. package/src/training/distillation/artifacts.js +132 -0
  162. package/src/training/distillation/checkpoint-watch.d.ts +10 -0
  163. package/src/training/distillation/checkpoint-watch.js +57 -0
  164. package/src/training/distillation/dataset.d.ts +59 -0
  165. package/src/training/distillation/dataset.js +337 -0
  166. package/src/training/distillation/eval.d.ts +34 -0
  167. package/src/training/distillation/eval.js +310 -0
  168. package/src/training/distillation/index.d.ts +29 -0
  169. package/src/training/distillation/index.js +29 -0
  170. package/src/training/distillation/runtime.d.ts +20 -0
  171. package/src/training/distillation/runtime.js +121 -0
  172. package/src/training/distillation/scoreboard.d.ts +6 -0
  173. package/src/training/distillation/scoreboard.js +8 -0
  174. package/src/training/distillation/stage-a.d.ts +45 -0
  175. package/src/training/distillation/stage-a.js +338 -0
  176. package/src/training/distillation/stage-b.d.ts +24 -0
  177. package/src/training/distillation/stage-b.js +20 -0
  178. package/src/training/index.d.ts +10 -0
  179. package/src/training/index.js +10 -0
  180. package/src/training/lora-pipeline.d.ts +40 -0
  181. package/src/training/lora-pipeline.js +796 -0
  182. package/src/training/operator-artifacts.d.ts +62 -0
  183. package/src/training/operator-artifacts.js +140 -0
  184. package/src/training/operator-command.d.ts +5 -0
  185. package/src/training/operator-command.js +453 -0
  186. package/src/training/operator-eval.d.ts +48 -0
  187. package/src/training/operator-eval.js +230 -0
  188. package/src/training/operator-scoreboard.d.ts +5 -0
  189. package/src/training/operator-scoreboard.js +44 -0
  190. package/src/training/runner.d.ts +52 -0
  191. package/src/training/runner.js +29 -4
  192. package/src/training/suite.d.ts +112 -0
  193. package/src/training/suite.js +9 -9
  194. package/src/training/workloads.d.ts +164 -0
  195. package/src/training/workloads.js +539 -0
  196. package/src/version.d.ts +2 -0
  197. package/src/version.js +2 -0
  198. package/tools/convert-safetensors-node.js +47 -0
  199. package/tools/doppler-cli.js +252 -41
@@ -9,11 +9,27 @@ import type { DiffusionSchedulerConfig } from './types.js';
/**
 * Scheduler state returned by `buildScheduler`.
 *
 * Timestep-only schedulers (SCM) leave `sigmas` as `null`; sigma-based
 * schedulers populate both `sigmas` and `timesteps`.
 */
export interface DiffusionScheduler {
  /** Scheduler kind, e.g. `'flowmatch_euler'` or `'scm'`. */
  type: string;
  /** Number of inference steps. */
  steps: number;
  /** Per-step noise levels, or `null` for timestep-only (SCM) schedulers. */
  sigmas: Float32Array | null;
  /** Timestep values consumed by the denoiser. */
  timesteps: Float32Array;
  /** SCM parameterization; `buildScheduler` defaults it to `'trigflow'`. */
  predictionType?: string;
  /** SCM noise scale; `buildScheduler` defaults it to `0.5`. */
  sigmaData?: number;
}

/** Output of one `stepScmScheduler` update. */
export interface DiffusionSchedulerStepResult {
  /** Sample to feed into the next step. */
  prevSample: Float32Array;
  /** Estimated clean sample (x0) at this step. */
  predOriginalSample: Float32Array;
}
15
22
 
/**
 * Build scheduler state from the diffusion scheduler config.
 *
 * @param config - Scheduler section of the diffusion config.
 * @param stepsOverride - Optional step count; presumably replaces the configured
 *   step count when provided (resolution happens in the implementation).
 * @throws Error when the config is missing, has no `type`, or lacks a valid
 *   `numTrainTimesteps`.
 */
export declare function buildScheduler(config: DiffusionSchedulerConfig, stepsOverride?: number | null): DiffusionScheduler;

/**
 * Advance an SCM (trigflow) scheduler by one step.
 *
 * With `s = timesteps[stepIndex]`, computes `cos(s) * sample - sin(s) * modelOutput`
 * as the clean-sample estimate; on every step except the last it re-noises toward
 * `timesteps[stepIndex + 1]` using `noise * sigmaData`.
 *
 * @param config - Scheduler state with `type === 'scm'`.
 * @param modelOutput - Denoiser output, same length as `sample`.
 * @param timestep - Not read by the implementation; times come from `config.timesteps`.
 * @param sample - Current latent sample.
 * @param stepIndex - Index into `config.timesteps`; defaults to 0.
 * @param noise - Required (same length as `sample`) for every non-final step.
 * @throws Error on a non-SCM scheduler, mismatched sizes, an invalid `stepIndex`,
 *   missing noise, or an unsupported `predictionType`.
 */
export declare function stepScmScheduler(
  config: DiffusionScheduler,
  modelOutput: Float32Array,
  timestep: number,
  sample: Float32Array,
  stepIndex?: number,
  noise?: Float32Array | null
): DiffusionSchedulerStepResult;
@@ -34,6 +34,84 @@ function buildFlowMatchSchedule(config, steps) {
34
34
  return sigmas;
35
35
  }
36
36
 
/**
 * Build the descending timestep schedule for SCM (trigflow) sampling.
 *
 * The schedule starts at `config.maxTimesteps` (default 1.5708) and ends at 0:
 *   - 1 step:   [max, 0]
 *   - 2 steps:  [max, config.intermediateTimesteps (default 1.3), 0]
 *   - n steps:  linspace(max, 0, n + 1) via the `linspace` helper
 *
 * @param steps - Requested inference step count; values below 1 are treated as 1.
 * @param config - Scheduler config with optional `maxTimesteps` / `intermediateTimesteps`.
 * @returns The timestep array (a Float32Array for 1 or 2 steps).
 */
function buildScmTimesteps(steps, config) {
  const startTime = Number.isFinite(config.maxTimesteps) ? config.maxTimesteps : 1.5708;
  const midTime = Number.isFinite(config.intermediateTimesteps)
    ? config.intermediateTimesteps
    : 1.3;
  const stepCount = Math.max(1, steps);
  switch (stepCount) {
    case 1:
      return new Float32Array([startTime, 0.0]);
    case 2:
      return new Float32Array([startTime, midTime, 0.0]);
    default:
      return linspace(startTime, 0.0, stepCount + 1);
  }
}
51
+
/**
 * Advance an SCM scheduler (trigflow parameterization) by one step.
 *
 * With `s = timesteps[stepIndex]` and `t = timesteps[stepIndex + 1]`:
 *   predOriginalSample = cos(s) * sample - sin(s) * modelOutput
 *   prevSample         = cos(t) * predOriginalSample + sin(t) * noise * sigmaData
 * On the final transition (into the last timestep) no noise is injected and
 * `prevSample` equals `predOriginalSample`.
 *
 * @param config - Scheduler state from `buildScheduler` with `type === 'scm'`.
 * @param modelOutput - Denoiser output; Float32Array with the same length as `sample`.
 * @param timestep - Not read here; the step's times come from `config.timesteps`.
 * @param sample - Current latent sample (Float32Array).
 * @param stepIndex - Integer index of `s` in `config.timesteps`; `stepIndex + 1` must be valid too.
 * @param noise - Float32Array the size of `sample`; required for every non-final step.
 * @returns `{ prevSample, predOriginalSample }`, both freshly allocated.
 * @throws Error on invalid scheduler/inputs, an unsupported `predictionType`, or missing noise.
 */
export function stepScmScheduler(config, modelOutput, timestep, sample, stepIndex = 0, noise = null) {
  if (config?.type !== 'scm') {
    throw new Error('stepScmScheduler requires scheduler.type="scm".');
  }
  if (!(modelOutput instanceof Float32Array)) {
    throw new Error('stepScmScheduler requires modelOutput as Float32Array.');
  }
  if (!(sample instanceof Float32Array)) {
    throw new Error('stepScmScheduler requires sample as Float32Array.');
  }
  const size = sample.length;
  if (modelOutput.length !== size) {
    throw new Error(
      `stepScmScheduler requires modelOutput and sample with matching sizes; got ${modelOutput.length} and ${size}.`
    );
  }

  const schedule = config.timesteps;
  if (!(schedule instanceof Float32Array) || schedule.length < 2) {
    throw new Error('stepScmScheduler requires scheduler.timesteps with length >= 2.');
  }
  const lastIndex = schedule.length - 1;
  const indexIsValid = Number.isInteger(stepIndex) && stepIndex >= 0 && stepIndex < lastIndex;
  if (!indexIsValid) {
    throw new Error(
      `stepScmScheduler received invalid stepIndex=${stepIndex} for ${schedule.length} timesteps.`
    );
  }

  const parameterization = config.predictionType ?? 'trigflow';
  if (parameterization !== 'trigflow') {
    throw new Error(`Unsupported SCM predictionType "${parameterization}".`);
  }

  // Noise is only injected while the next timestep is not the terminal one.
  const injectNoise = stepIndex + 1 < lastIndex;
  if (injectNoise && !(noise instanceof Float32Array && noise.length === size)) {
    throw new Error(
      'stepScmScheduler requires a Float32Array noise tensor for multi-step SCM updates.'
    );
  }
  const sigmaData = Number.isFinite(config.sigmaData) ? config.sigmaData : 0.5;

  const s = schedule[stepIndex];
  const t = schedule[stepIndex + 1];
  const [cosS, sinS] = [Math.cos(s), Math.sin(s)];
  const [cosT, sinT] = [Math.cos(t), Math.sin(t)];

  const predOriginalSample = new Float32Array(size);
  const prevSample = new Float32Array(size);
  for (let i = 0; i < size; i++) {
    predOriginalSample[i] = cosS * sample[i] - sinS * modelOutput[i];
    // Read back the stored (float32-rounded) x0 before re-noising.
    prevSample[i] = injectNoise
      ? cosT * predOriginalSample[i] + sinT * noise[i] * sigmaData
      : predOriginalSample[i];
  }

  return { prevSample, predOriginalSample };
}
114
+
37
115
  export function buildScheduler(config, stepsOverride = null) {
38
116
  if (!config) {
39
117
  throw new Error('Scheduler config is required');
@@ -43,15 +121,25 @@ export function buildScheduler(config, stepsOverride = null) {
43
121
  if (typeof type !== 'string' || !type) {
44
122
  throw new Error('Diffusion scheduler requires a scheduler type.');
45
123
  }
46
- const sigmas = type === 'flowmatch_euler'
47
- ? buildFlowMatchSchedule(config, steps)
48
- : buildLinearSigmaSchedule(steps);
49
124
  const trainSteps = Number.isFinite(config.numTrainTimesteps)
50
125
  ? config.numTrainTimesteps
51
126
  : null;
52
127
  if (!Number.isFinite(trainSteps) || trainSteps <= 0) {
53
128
  throw new Error('Diffusion scheduler requires valid numTrainTimesteps.');
54
129
  }
130
+ if (type === 'scm') {
131
+ return {
132
+ type,
133
+ steps,
134
+ sigmas: null,
135
+ timesteps: buildScmTimesteps(steps, config),
136
+ predictionType: config.predictionType ?? 'trigflow',
137
+ sigmaData: Number.isFinite(config.sigmaData) ? config.sigmaData : 0.5,
138
+ };
139
+ }
140
+ const sigmas = type === 'flowmatch_euler'
141
+ ? buildFlowMatchSchedule(config, steps)
142
+ : buildLinearSigmaSchedule(steps);
55
143
  const timesteps = new Float32Array(steps);
56
144
  for (let i = 0; i < steps; i++) {
57
145
  timesteps[i] = sigmas[i] * trainSteps;
@@ -16,25 +16,27 @@ export interface DiffusionTextEncoderWeightsEntry {
16
16
 
/**
 * Loaded weights for each diffusion text-encoder component.
 *
 * Only `text_encoder` is required. The Sana layout runs a single Gemma2 encoder
 * from it; the default SD3 layout also reads `text_encoder_2` (second CLIP) and
 * `text_encoder_3` (T5).
 */
export interface DiffusionTextEncoderWeights {
  text_encoder: DiffusionTextEncoderWeightsEntry;
  text_encoder_2?: DiffusionTextEncoderWeightsEntry | null;
  text_encoder_3?: DiffusionTextEncoderWeightsEntry | null;
  transformer?: DiffusionTextEncoderWeightsEntry;
}

/** Tokenized prompt per encoder; keys mirror `DiffusionTextEncoderWeights`. */
export interface DiffusionTextTokens {
  text_encoder: number[];
  text_encoder_2?: number[];
  text_encoder_3?: number[];
}

/** Conditioning produced by `runTextEncodersForPrompt`. */
export interface DiffusionTextConditioning {
  /** Pooled prompt embedding; empty (length 0) for the Sana layout. */
  pooled: Float32Array;
  /** Per-token context (T5 output for SD3, Gemma2 output for Sana). */
  context: Tensor;
  /** Token mask from the Gemma2 path (all ones); `null` on the SD3 path. */
  attentionMask?: Uint32Array | null;
  /** Optional encoder timings (the `*Ms` fields); `null` on the Sana path when profiling is off. */
  profile?: {
    totalMs?: number | null;
    clipMs?: number | null;
    clip2Ms?: number | null;
    t5Ms?: number | null;
    gemmaMs?: number | null;
  } | null;
}
40
42
 
@@ -78,3 +80,8 @@ export declare function projectContext(
78
80
  ): Promise<Tensor>;
79
81
 
/**
 * Validates `config.hidden_act` for the CLIP text encoder.
 * NOTE(review): implementation not shown here; presumably throws unless the
 * activation is one of the supported CLIP activations ('gelu', 'quick_gelu').
 */
export declare function assertClipHiddenActivationSupported(config: { hidden_act?: string }): void;

/**
 * Resolve the key prefix under which Gemma2 text-encoder weights are stored.
 *
 * Returns `${prefix}.model` when `${prefix}.model.embed_tokens.weight` exists,
 * otherwise `prefix` when `${prefix}.embed_tokens.weight` exists, and falls back
 * to `${prefix}.model`.
 *
 * @param weights - Weight map keyed by tensor name.
 * @param prefix - Component prefix; defaults to `'text_encoder'`.
 */
export declare function resolveGemma2WeightRoot(weights: Map<string, any>, prefix?: string): string;
@@ -40,6 +40,8 @@ import {
40
40
  inferDiffusionMatmulDtypeFromBuffer,
41
41
  sumDiffusionProfileTimings,
42
42
  } from './helpers.js';
43
+ import { initRoPEFrequencies } from '../text/init.js';
44
+ import { processLayerGPU } from '../text/layer.js';
43
45
 
44
46
  const QUICK_GELU_ALPHA = 1.702;
45
47
  const SUPPORTED_CLIP_HIDDEN_ACTIVATIONS = new Set(['gelu', 'quick_gelu']);
@@ -56,6 +58,16 @@ function padTokens(tokens, maxLength, padTokenId) {
56
58
  return out;
57
59
  }
58
60
 
/**
 * Convert a token list into a non-empty Uint32Array of at most `maxLength` ids.
 *
 * Non-array input is treated as empty. `maxLength` is floored and only applied
 * when it is a positive finite number. If no ids remain, a single
 * `fallbackTokenId` (coerced to uint32) is returned so the result is never empty.
 */
function normalizeTokens(tokens, maxLength, fallbackTokenId) {
  const ids = Array.isArray(tokens) ? tokens : [];
  const hasLimit = Number.isFinite(maxLength) && maxLength > 0;
  const kept = hasLimit ? ids.slice(0, Math.floor(maxLength)) : ids.slice();
  return kept.length === 0
    ? new Uint32Array([fallbackTokenId >>> 0])
    : Uint32Array.from(kept);
}
70
+
59
71
  function findEosIndex(tokens, eosTokenId) {
60
72
  if (eosTokenId == null) return tokens.length - 1;
61
73
  for (let i = 0; i < tokens.length; i++) {
@@ -702,7 +714,276 @@ async function runT5Encoder(tokens, weightsEntry, config, runtime, options = {})
702
714
  };
703
715
  }
704
716
 
/**
 * Per-layer attention kinds for a Gemma2 stack.
 *
 * Without a positive finite sliding window every layer is 'full_attention'.
 * Otherwise even-indexed layers are 'sliding_attention' and odd-indexed layers
 * are 'full_attention'.
 */
function buildGemma2LayerTypes(layerCount, slidingWindow) {
  const windowed = Number.isFinite(slidingWindow) && slidingWindow > 0;
  return Array.from({ length: layerCount }, (_, layer) => {
    if (windowed && layer % 2 === 0) {
      return 'sliding_attention';
    }
    return 'full_attention';
  });
}
725
+
/**
 * Resolve the key prefix under which Gemma2 text-encoder weights are stored.
 *
 * Checks for `<root>.embed_tokens.weight` under `${prefix}.model` first, then
 * under `prefix` itself, and falls back to `${prefix}.model` when neither exists
 * (or when `weights` is nullish).
 *
 * @param weights - Map-like object exposing `has(key)`.
 * @param prefix - Component prefix; defaults to 'text_encoder'.
 * @returns The weight root to prepend to per-tensor suffixes.
 */
export function resolveGemma2WeightRoot(weights, prefix = 'text_encoder') {
  const nestedRoot = `${prefix}.model`;
  const candidates = [nestedRoot, prefix];
  const match = candidates.find((root) => weights?.has(`${root}.embed_tokens.weight`));
  return match ?? nestedRoot;
}
736
+
/**
 * Look up one per-layer Gemma2 weight by `<weightRoot>.layers.<layerIdx>.<suffix>`.
 *
 * @param weights - Map-like object exposing `get(key)`.
 * @param weightRoot - Root returned by `resolveGemma2WeightRoot`.
 * @param layerIdx - Decoder layer index.
 * @param suffix - Tensor name within the layer, e.g. 'mlp.up_proj.weight'.
 * @param required - When true (default), a missing/falsy weight throws.
 * @returns The stored weight, or null when it is optional and absent.
 * @throws Error naming the missing key when `required` is true.
 */
function getGemma2LayerWeight(weights, weightRoot, layerIdx, suffix, required = true) {
  const key = `${weightRoot}.layers.${layerIdx}.${suffix}`;
  const found = weights.get(key);
  if (found) {
    return found;
  }
  if (required) {
    throw new Error(`Missing Gemma2 diffusion weight "${key}".`);
  }
  return null;
}
745
+
/**
 * Validate a Hugging Face-style Gemma2 config and map it to camelCase fields.
 *
 * `num_key_value_heads` falls back to `num_attention_heads`, and `head_dim` falls
 * back to `floor(hidden_size / num_attention_heads)`. Optional fields default to
 * rms_norm_eps = 1e-6, rope_theta = 10000 and sliding_window = 4096;
 * `scaleEmbeddings` is true unless `scale_embeddings` is exactly `false`.
 *
 * @param config - Text-encoder config object.
 * @returns Resolved dimensions and hyper-parameters.
 * @throws Error naming the first required field (in declaration order) that is
 *   missing, non-finite, or <= 0.
 */
function resolveGemma2TextConfig(config) {
  const hiddenSize = config.hidden_size;
  const numHeads = config.num_attention_heads;
  const numKVHeads = config.num_key_value_heads ?? numHeads;
  const derivedHeadDim = Number.isFinite(hiddenSize) && Number.isFinite(numHeads) && numHeads > 0
    ? Math.floor(hiddenSize / numHeads)
    : null;
  const headDim = config.head_dim ?? derivedHeadDim;
  const numLayers = config.num_hidden_layers;
  const intermediateSize = config.intermediate_size;
  const maxPositionEmbeddings = config.max_position_embeddings;

  // Checked in this order so the first offending field is the one reported.
  const requiredFields = [
    [hiddenSize, 'Gemma2 diffusion text encoder requires hidden_size.'],
    [numHeads, 'Gemma2 diffusion text encoder requires num_attention_heads.'],
    [numKVHeads, 'Gemma2 diffusion text encoder requires num_key_value_heads.'],
    [headDim, 'Gemma2 diffusion text encoder requires head_dim or hidden_size/num_attention_heads.'],
    [numLayers, 'Gemma2 diffusion text encoder requires num_hidden_layers.'],
    [intermediateSize, 'Gemma2 diffusion text encoder requires intermediate_size.'],
    [maxPositionEmbeddings, 'Gemma2 diffusion text encoder requires max_position_embeddings.'],
  ];
  for (const [value, message] of requiredFields) {
    if (!(Number.isFinite(value) && value > 0)) {
      throw new Error(message);
    }
  }

  return {
    hiddenSize,
    numHeads,
    numKVHeads,
    headDim,
    numLayers,
    intermediateSize,
    maxPositionEmbeddings,
    rmsNormEps: config.rms_norm_eps ?? 1e-6,
    ropeTheta: config.rope_theta ?? 10000,
    slidingWindow: config.sliding_window ?? 4096,
    scaleEmbeddings: config.scale_embeddings !== false,
  };
}
796
+
/**
 * Encode a prompt with a Gemma2 decoder stack on the GPU (Sana text encoder).
 *
 * Steps:
 *   1. Gather token embeddings from `embed_tokens`.
 *   2. Optionally scale them by sqrt(hidden_size).
 *   3. Run `num_hidden_layers` decoder layers through the text pipeline's
 *      `processLayerGPU` (causal attention, alternating sliding/full layers).
 *   4. Apply the final RMSNorm with the Gemma weight offset.
 *
 * @param tokens - Prompt token ids. Truncated to the effective max length; an empty
 *   list is replaced by a single pad_token_id / bos_token_id / 0 token.
 * @param weightsEntry - Loaded component `{ weights, shapes }`; `weights` is a
 *   Map-like object keyed by tensor name.
 * @param config - Gemma2 text config (required fields are checked by
 *   `resolveGemma2TextConfig`). `attn_logit_softcapping` and `query_pre_attn_scalar`
 *   are honoured when they are positive finite numbers; otherwise 50.0 and `head_dim`
 *   are used.
 * @param runtime - Diffusion runtime config, used to resolve activation/embedding dtypes.
 * @param options - Supported keys:
 *   - `prefix`: weight/label prefix, default 'text_encoder'.
 *   - `maxLength`: token cap, clamped to `max_position_embeddings`.
 *   - `recorder`: external command recorder.
 *   - `profile`: create, submit and time a local recorder when no external
 *     recorder is given.
 * @returns `{ hidden, attentionMask, maxLength, hiddenSize, profile }`:
 *   - `hidden`: the final-norm output (presumably [numTokens, hiddenSize]).
 *   - `attentionMask`: all ones, one entry per token.
 *   - `maxLength`: the token count.
 *   - `hiddenSize`: the model hidden size.
 *   - `profile`: timing info, or null unless a local recorder was used.
 * @throws Error when no WebGPU device is available, weights are not loaded, a
 *   required weight is missing, or the config lacks required dimensions.
 */
async function runGemma2TextEncoder(tokens, weightsEntry, config, runtime, options = {}) {
  const device = getDevice();
  if (!device) throw new Error('Gemma2 diffusion text encoder requires a WebGPU device.');
  if (!weightsEntry?.weights || !weightsEntry?.shapes) {
    throw new Error('Gemma2 diffusion text encoder requires loaded weights.');
  }

  const prefix = options.prefix ?? 'text_encoder';
  // Only create (and later submit) a recorder of our own when profiling without a caller-supplied one.
  const localRecorder = options.recorder
    ? null
    : (options.profile ? new CommandRecorder(device, `${prefix}_gemma2_encoder`, { profile: true }) : null);
  const recorder = options.recorder ?? localRecorder;
  const ops = createKernelOps(recorder);
  const release = createDiffusionBufferReleaser(recorder);
  const destroy = createDiffusionBufferDestroyer(recorder);
  const weights = weightsEntry.weights;
  const activationDtype = resolveDiffusionActivationDtype(runtime);
  const resolved = resolveGemma2TextConfig(config);
  const padTokenId = config.pad_token_id ?? config.bos_token_id ?? 0;

  // The RoPE tables below only cover max_position_embeddings positions, so a larger
  // (or non-positive / non-finite) caller maxLength must not admit more tokens than that.
  const tokenLimit = Number.isFinite(options.maxLength) && options.maxLength > 0
    ? Math.min(options.maxLength, resolved.maxPositionEmbeddings)
    : resolved.maxPositionEmbeddings;
  const tokenIds = normalizeTokens(tokens, tokenLimit, padTokenId);
  const numTokens = tokenIds.length;
  const tokenBuffer = createDiffusionIndexBuffer(device, tokenIds, `${prefix}_tokens`);
  const weightRoot = resolveGemma2WeightRoot(weights, prefix);

  // Prefer the checkpoint's own attention hyper-parameters; the fallbacks are the
  // values this encoder used unconditionally before.
  const attnLogitSoftcapping = Number.isFinite(config.attn_logit_softcapping) && config.attn_logit_softcapping > 0
    ? config.attn_logit_softcapping
    : 50.0;
  const queryPreAttnScalar = Number.isFinite(config.query_pre_attn_scalar) && config.query_pre_attn_scalar > 0
    ? config.query_pre_attn_scalar
    : resolved.headDim;

  const embedKey = `${weightRoot}.embed_tokens.weight`;
  const embedWeight = expectDiffusionWeight(
    weights.get(embedKey),
    embedKey
  );
  const embedDtype = resolveEmbeddingDtype(embedWeight, weightsEntry, embedKey, runtime);
  let hidden = await ops.gather(
    tokenBuffer,
    getBuffer(embedWeight),
    numTokens,
    resolved.hiddenSize,
    config.vocab_size,
    {
      embeddingDtype: embedDtype,
      outputDtype: activationDtype,
      transpose: false,
    }
  );
  destroy(tokenBuffer);

  if (resolved.scaleEmbeddings) {
    const scaled = await ops.scale(hidden, Math.sqrt(resolved.hiddenSize), {
      count: numTokens * resolved.hiddenSize,
    });
    release(hidden.buffer);
    hidden = createTensor(scaled.buffer, scaled.dtype, [numTokens, resolved.hiddenSize], 'gemma2_embed');
  }

  // NOTE(review): Hugging Face Gemma2 decoder layers also carry
  // `post_feedforward_layernorm.weight`. It is neither loaded here nor enabled
  // (`postFeedforwardNorm: false` below); confirm this is intentional for
  // processLayerGPU before relying on parity with the reference encoder.
  const layerWeights = new Map();
  for (let layerIdx = 0; layerIdx < resolved.numLayers; layerIdx++) {
    layerWeights.set(`layer_${layerIdx}`, {
      inputNorm: getGemma2LayerWeight(weights, weightRoot, layerIdx, 'input_layernorm.weight'),
      qProj: getGemma2LayerWeight(weights, weightRoot, layerIdx, 'self_attn.q_proj.weight'),
      kProj: getGemma2LayerWeight(weights, weightRoot, layerIdx, 'self_attn.k_proj.weight'),
      vProj: getGemma2LayerWeight(weights, weightRoot, layerIdx, 'self_attn.v_proj.weight'),
      oProj: getGemma2LayerWeight(weights, weightRoot, layerIdx, 'self_attn.o_proj.weight'),
      postAttentionNorm: getGemma2LayerWeight(weights, weightRoot, layerIdx, 'post_attention_layernorm.weight'),
      preFeedforwardNorm: getGemma2LayerWeight(weights, weightRoot, layerIdx, 'pre_feedforward_layernorm.weight'),
      gate: getGemma2LayerWeight(weights, weightRoot, layerIdx, 'mlp.gate_proj.weight'),
      up: getGemma2LayerWeight(weights, weightRoot, layerIdx, 'mlp.up_proj.weight'),
      down: getGemma2LayerWeight(weights, weightRoot, layerIdx, 'mlp.down_proj.weight'),
    });
  }

  // NOTE(review): the returned RoPE buffers are never released in this function;
  // confirm whether initRoPEFrequencies caches them or whether they should be freed
  // after encoding.
  const ropeFreqs = await initRoPEFrequencies({
    headDim: resolved.headDim,
    maxSeqLen: resolved.maxPositionEmbeddings,
    ropeTheta: resolved.ropeTheta,
    ropeLocalTheta: null,
    ropeScale: 1,
    ropeLocalScale: null,
    ropeScalingType: null,
    ropeLocalScalingType: null,
    ropeScaling: null,
    ropeLocalScaling: null,
  }, true);

  const context = {
    useGPU: true,
    activationDtype,
    recorder,
    currentSeqLen: 0,
    kvCache: null,
    weightConfig: {
      rmsNormWeightOffset: true,
    },
    debugFlags: {},
    weights: layerWeights,
    ropeFreqsCos: ropeFreqs.cos,
    ropeFreqsSin: ropeFreqs.sin,
    ropeLocalCos: ropeFreqs.localCos,
    ropeLocalSin: ropeFreqs.localSin,
    config: {
      hiddenSize: resolved.hiddenSize,
      intermediateSize: resolved.intermediateSize,
      numHeads: resolved.numHeads,
      numKVHeads: resolved.numKVHeads,
      headDim: resolved.headDim,
      rmsNormEps: resolved.rmsNormEps,
      slidingWindow: resolved.slidingWindow,
      attnLogitSoftcapping,
      queryPreAttnScalar,
      queryKeyNorm: false,
      attentionOutputGate: false,
      causalAttention: true,
      hiddenActivation: 'gelu',
      swigluLimit: null,
      useMoE: false,
      layerTypes: buildGemma2LayerTypes(resolved.numLayers, resolved.slidingWindow),
      preFeedforwardNorm: true,
      postFeedforwardNorm: false,
      postAttentionNorm: true,
    },
  };

  for (let layerIdx = 0; layerIdx < resolved.numLayers; layerIdx++) {
    // NOTE(review): the previous `hidden` buffer is not released here; presumably
    // processLayerGPU owns/recycles its input buffer. Verify to rule out a leak.
    const output = await processLayerGPU(
      layerIdx,
      hidden.buffer,
      numTokens,
      true,
      numTokens * resolved.hiddenSize,
      context
    );
    hidden = createTensor(output, activationDtype, [numTokens, resolved.hiddenSize], `gemma2_layer_${layerIdx}`);
  }

  const finalNormKey = `${weightRoot}.norm.weight`;
  const finalNorm = expectDiffusionWeight(weights.get(finalNormKey), finalNormKey);
  const final = await ops.rmsNorm(hidden, getBuffer(finalNorm), resolved.rmsNormEps, {
    batchSize: numTokens,
    hiddenSize: resolved.hiddenSize,
    rmsNormWeightOffset: true,
  });
  release(hidden.buffer);

  let profile = null;
  if (localRecorder) {
    localRecorder.submit();
    const timings = await localRecorder.resolveProfileTimings();
    profile = timings ? { totalMs: sumDiffusionProfileTimings(timings) ?? 0, timings } : { totalMs: null };
  }

  return {
    hidden: final,
    // normalizeTokens never pads, so every position is a real token.
    attentionMask: new Uint32Array(numTokens).fill(1),
    maxLength: numTokens,
    hiddenSize: resolved.hiddenSize,
    profile,
  };
}
952
+
705
953
  export async function runTextEncodersForPrompt(tokensByEncoder, weightsByComponent, modelConfig, runtime, options = {}) {
954
+ const layout = modelConfig?.layout ?? 'sd3';
955
+ if (layout === 'sana') {
956
+ const gemmaConfig = modelConfig?.components?.text_encoder?.config || {};
957
+ const gemmaMaxLength = runtime?.textEncoder?.maxLength;
958
+ if (!Number.isFinite(gemmaMaxLength) || gemmaMaxLength <= 0) {
959
+ throw new Error('Sana Gemma2 encoder requires runtime.textEncoder.maxLength.');
960
+ }
961
+ const profileEnabled = options.profile === true;
962
+ const gemma = await runGemma2TextEncoder(
963
+ tokensByEncoder.text_encoder,
964
+ weightsByComponent.text_encoder,
965
+ gemmaConfig,
966
+ runtime,
967
+ {
968
+ prefix: 'text_encoder',
969
+ maxLength: gemmaMaxLength,
970
+ profile: profileEnabled,
971
+ }
972
+ );
973
+
974
+ return {
975
+ pooled: new Float32Array(0),
976
+ context: gemma.hidden,
977
+ attentionMask: gemma.attentionMask,
978
+ profile: profileEnabled
979
+ ? {
980
+ totalMs: gemma.profile?.totalMs ?? null,
981
+ gemmaMs: gemma.profile?.totalMs ?? null,
982
+ }
983
+ : null,
984
+ };
985
+ }
986
+
706
987
  const clipConfig = modelConfig?.components?.text_encoder?.config || {};
707
988
  const clip2Config = modelConfig?.components?.text_encoder_2?.config || {};
708
989
  const t5Config = modelConfig?.components?.text_encoder_3?.config || {};
@@ -749,6 +1030,7 @@ export async function runTextEncodersForPrompt(tokensByEncoder, weightsByCompone
749
1030
  return {
750
1031
  pooled,
751
1032
  context: t5.hidden,
1033
+ attentionMask: null,
752
1034
  profile,
753
1035
  };
754
1036
  }
@@ -1,5 +1,6 @@
1
1
  import { BPETokenizer } from '../../tokenizers/bpe.js';
2
2
  import { SentencePieceTokenizer } from '../../tokenizers/sentencepiece.js';
3
+ import { BundledTokenizer } from '../../tokenizers/bundled.js';
3
4
  import { loadAuxText, loadAuxFile } from '../../../storage/shard-manager.js';
4
5
 
5
6
  function parseMerges(text) {
@@ -136,11 +137,27 @@ async function loadSentencePieceTokenizer(tokenizerConfig, options = {}) {
136
137
  return tokenizer;
137
138
  }
138
139
 
/**
 * Load a diffusion tokenizer whose data ships as a single JSON asset.
 *
 * @param tokenizerConfig - Tokenizer entry; must provide `tokenizerFile`, which is
 *   resolved by `loadTextAsset` relative to `options.baseUrl`.
 * @param options - `baseUrl` forwarded to `loadTextAsset`.
 * @returns A `BundledTokenizer` populated via `load()` with the parsed JSON.
 * @throws Error when `tokenizerFile` is missing or the asset is not valid JSON.
 */
async function loadBundledTokenizer(tokenizerConfig, options = {}) {
  const tokenizerFile = tokenizerConfig?.tokenizerFile;
  if (!tokenizerFile) {
    // Fail fast with a clear message instead of an opaque asset-loading error.
    throw new Error('Bundled diffusion tokenizer requires tokenizerFile.');
  }
  const { baseUrl } = options;
  const tokenizerJsonText = await loadTextAsset(tokenizerFile, baseUrl);
  let tokenizerJson;
  try {
    tokenizerJson = JSON.parse(tokenizerJsonText);
  } catch (error) {
    const reason = error instanceof Error ? error.message : String(error);
    throw new Error(`Bundled diffusion tokenizer "${tokenizerFile}" is not valid JSON: ${reason}`);
  }
  // NOTE(review): vocabSize 0 / deferSpecialTokens true presumably let load() take the
  // vocabulary size and special tokens from the JSON; confirm against BundledTokenizer.
  const tokenizer = new BundledTokenizer({
    vocabSize: 0,
    deferSpecialTokens: true,
  });
  tokenizer.load(tokenizerJson);
  return tokenizer;
}
151
+
139
152
  export async function loadDiffusionTokenizers(diffusionConfig, options = {}) {
140
153
  const tokenizers = {};
141
154
  const config = diffusionConfig?.tokenizers || {};
142
155
  if (config.text_encoder) {
143
- tokenizers.text_encoder = await loadBpeTokenizer(config.text_encoder, options);
156
+ if (config.text_encoder.type === 'bundled') {
157
+ tokenizers.text_encoder = await loadBundledTokenizer(config.text_encoder, options);
158
+ } else {
159
+ tokenizers.text_encoder = await loadBpeTokenizer(config.text_encoder, options);
160
+ }
144
161
  }
145
162
  if (config.text_encoder_2) {
146
163
  tokenizers.text_encoder_2 = await loadBpeTokenizer(config.text_encoder_2, options);
@@ -27,6 +27,10 @@ export interface DiffusionSchedulerConfig {
27
27
  eta: number;
28
28
  numTrainTimesteps: number;
29
29
  shift: number;
30
+ predictionType?: string;
31
+ sigmaData?: number;
32
+ maxTimesteps?: number;
33
+ intermediateTimesteps?: number;
30
34
  }
31
35
 
32
36
  export interface DiffusionLatentConfig {