@simulatte/doppler 0.1.3 → 0.1.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (114)
  1. package/README.md +11 -5
  2. package/package.json +27 -4
  3. package/src/client/doppler-api.browser.d.ts +1 -0
  4. package/src/client/doppler-api.browser.js +288 -0
  5. package/src/client/doppler-api.d.ts +80 -0
  6. package/src/client/doppler-api.js +298 -0
  7. package/src/client/doppler-provider/types.js +1 -1
  8. package/src/client/doppler-registry.d.ts +23 -0
  9. package/src/client/doppler-registry.js +88 -0
  10. package/src/client/doppler-registry.json +39 -0
  11. package/src/config/execution-contract-check.d.ts +82 -0
  12. package/src/config/execution-contract-check.js +317 -0
  13. package/src/config/execution-v0-contract-check.d.ts +94 -0
  14. package/src/config/execution-v0-contract-check.js +251 -0
  15. package/src/config/execution-v0-graph-contract-check.d.ts +20 -0
  16. package/src/config/execution-v0-graph-contract-check.js +64 -0
  17. package/src/config/kernel-path-contract-check.d.ts +76 -0
  18. package/src/config/kernel-path-contract-check.js +479 -0
  19. package/src/config/kernel-path-loader.d.ts +16 -0
  20. package/src/config/kernel-path-loader.js +54 -0
  21. package/src/config/kernels/kernel-ref-digests.js +12 -0
  22. package/src/config/kernels/registry.json +556 -0
  23. package/src/config/loader.js +90 -67
  24. package/src/config/merge-contract-check.d.ts +16 -0
  25. package/src/config/merge-contract-check.js +321 -0
  26. package/src/config/merge-helpers.d.ts +58 -0
  27. package/src/config/merge-helpers.js +54 -0
  28. package/src/config/merge.js +3 -6
  29. package/src/config/presets/models/janus-text.json +27 -0
  30. package/src/config/quantization-contract-check.d.ts +12 -0
  31. package/src/config/quantization-contract-check.js +91 -0
  32. package/src/config/required-inference-fields-contract-check.d.ts +24 -0
  33. package/src/config/required-inference-fields-contract-check.js +231 -0
  34. package/src/config/schema/browser-suite-metrics.schema.d.ts +17 -0
  35. package/src/config/schema/browser-suite-metrics.schema.js +46 -0
  36. package/src/config/schema/conversion-report.schema.d.ts +40 -0
  37. package/src/config/schema/conversion-report.schema.js +108 -0
  38. package/src/config/schema/doppler.schema.js +12 -18
  39. package/src/config/schema/index.d.ts +22 -0
  40. package/src/config/schema/index.js +18 -0
  41. package/src/converter/core.d.ts +10 -0
  42. package/src/converter/core.js +49 -11
  43. package/src/converter/parsers/diffusion.js +63 -3
  44. package/src/converter/tokenizer-utils.js +17 -3
  45. package/src/formats/rdrr/validation.js +13 -0
  46. package/src/gpu/kernels/depthwise_conv2d.d.ts +29 -0
  47. package/src/gpu/kernels/depthwise_conv2d.js +98 -0
  48. package/src/gpu/kernels/depthwise_conv2d.wgsl +58 -0
  49. package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +62 -0
  50. package/src/gpu/kernels/grouped_pointwise_conv2d.d.ts +27 -0
  51. package/src/gpu/kernels/grouped_pointwise_conv2d.js +92 -0
  52. package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +47 -0
  53. package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +51 -0
  54. package/src/gpu/kernels/index.d.ts +30 -0
  55. package/src/gpu/kernels/index.js +25 -0
  56. package/src/gpu/kernels/relu.d.ts +18 -0
  57. package/src/gpu/kernels/relu.js +45 -0
  58. package/src/gpu/kernels/relu.wgsl +21 -0
  59. package/src/gpu/kernels/relu_f16.wgsl +23 -0
  60. package/src/gpu/kernels/repeat_channels.d.ts +21 -0
  61. package/src/gpu/kernels/repeat_channels.js +60 -0
  62. package/src/gpu/kernels/repeat_channels.wgsl +29 -0
  63. package/src/gpu/kernels/repeat_channels_f16.wgsl +31 -0
  64. package/src/gpu/kernels/sana_linear_attention.d.ts +27 -0
  65. package/src/gpu/kernels/sana_linear_attention.js +122 -0
  66. package/src/gpu/kernels/sana_linear_attention_apply.wgsl +44 -0
  67. package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +47 -0
  68. package/src/gpu/kernels/sana_linear_attention_summary.wgsl +47 -0
  69. package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +49 -0
  70. package/src/index-browser.d.ts +1 -0
  71. package/src/index-browser.js +2 -1
  72. package/src/index.d.ts +1 -0
  73. package/src/index.js +2 -1
  74. package/src/inference/browser-harness.js +164 -38
  75. package/src/inference/pipelines/diffusion/init.js +14 -0
  76. package/src/inference/pipelines/diffusion/pipeline.js +206 -77
  77. package/src/inference/pipelines/diffusion/sana-transformer.d.ts +53 -0
  78. package/src/inference/pipelines/diffusion/sana-transformer.js +738 -0
  79. package/src/inference/pipelines/diffusion/scheduler.d.ts +17 -1
  80. package/src/inference/pipelines/diffusion/scheduler.js +91 -3
  81. package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +6 -4
  82. package/src/inference/pipelines/diffusion/text-encoder-gpu.js +270 -0
  83. package/src/inference/pipelines/diffusion/text-encoder.js +18 -1
  84. package/src/inference/pipelines/diffusion/types.d.ts +4 -0
  85. package/src/inference/pipelines/diffusion/vae.js +782 -78
  86. package/src/inference/pipelines/text/config.d.ts +5 -0
  87. package/src/inference/pipelines/text/config.js +1 -1
  88. package/src/inference/pipelines/text/execution-v0.js +141 -101
  89. package/src/inference/pipelines/text/init.js +41 -10
  90. package/src/inference/pipelines/text.js +7 -1
  91. package/src/rules/execution-rules-contract-check.d.ts +17 -0
  92. package/src/rules/execution-rules-contract-check.js +245 -0
  93. package/src/rules/kernels/depthwise-conv2d.rules.json +6 -0
  94. package/src/rules/kernels/grouped-pointwise-conv2d.rules.json +6 -0
  95. package/src/rules/kernels/relu.rules.json +6 -0
  96. package/src/rules/kernels/repeat-channels.rules.json +6 -0
  97. package/src/rules/kernels/sana-linear-attention.rules.json +6 -0
  98. package/src/rules/layer-pattern-contract-check.d.ts +17 -0
  99. package/src/rules/layer-pattern-contract-check.js +231 -0
  100. package/src/rules/rule-registry.d.ts +28 -0
  101. package/src/rules/rule-registry.js +38 -0
  102. package/src/tooling/conversion-config-materializer.d.ts +24 -0
  103. package/src/tooling/conversion-config-materializer.js +99 -0
  104. package/src/tooling/lean-execution-contract-runner.d.ts +43 -0
  105. package/src/tooling/lean-execution-contract-runner.js +158 -0
  106. package/src/tooling/lean-execution-contract.d.ts +16 -0
  107. package/src/tooling/lean-execution-contract.js +81 -0
  108. package/src/tooling/node-convert.d.ts +10 -0
  109. package/src/tooling/node-converter.js +59 -0
  110. package/src/tooling/node-webgpu.js +30 -9
  111. package/src/version.d.ts +2 -0
  112. package/src/version.js +2 -0
  113. package/tools/convert-safetensors-node.js +47 -0
  114. package/tools/doppler-cli.js +167 -6
@@ -9,11 +9,27 @@ import type { DiffusionSchedulerConfig } from './types.js';
9
9
  export interface DiffusionScheduler {
10
10
  type: string;
11
11
  steps: number;
12
- sigmas: Float32Array;
12
+ sigmas: Float32Array | null;
13
13
  timesteps: Float32Array;
14
+ predictionType?: string;
15
+ sigmaData?: number;
16
+ }
17
+
18
/** Output of one `stepScmScheduler` update. */
export interface DiffusionSchedulerStepResult {
  /** Sample to feed into the next step; equals `predOriginalSample` on the final step. */
  prevSample: Float32Array;
  /** Clean-sample estimate computed at this step. */
  predOriginalSample: Float32Array;
}
15
22
 
16
23
  export declare function buildScheduler(
17
24
  config: DiffusionSchedulerConfig,
18
25
  stepsOverride?: number | null
19
26
  ): DiffusionScheduler;
27
+
28
/**
 * Advance an SCM (TrigFlow) schedule by one step.
 *
 * With `s = config.timesteps[stepIndex]` the clean-sample estimate is
 * `cos(s) * sample - sin(s) * modelOutput`. On every step except the last, the
 * next sample is `cos(t) * predOriginalSample + sin(t) * noise * sigmaData`,
 * where `t = config.timesteps[stepIndex + 1]`. On the last step it equals
 * `predOriginalSample`.
 *
 * @param config - Schedule from `buildScheduler` with `type === "scm"`.
 * @param modelOutput - Model prediction; must have the same length as `sample`.
 * @param timestep - Current timestep. NOTE(review): the implementation takes the
 *   step position from `stepIndex`, not from this value.
 * @param sample - Current latent sample.
 * @param stepIndex - Index into `config.timesteps` (default 0); a following entry must exist.
 * @param noise - Same-length noise, required on every step except the last.
 * @throws Error for a non-SCM config, mismatched lengths, an out-of-range
 *   `stepIndex`, an unsupported `predictionType`, or missing noise.
 */
export declare function stepScmScheduler(
  config: DiffusionScheduler,
  modelOutput: Float32Array,
  timestep: number,
  sample: Float32Array,
  stepIndex?: number,
  noise?: Float32Array | null
): DiffusionSchedulerStepResult;
@@ -34,6 +34,84 @@ function buildFlowMatchSchedule(config, steps) {
34
34
  return sigmas;
35
35
  }
36
36
 
37
/**
 * Build the SCM timestep ladder, running from `maxTimesteps` down to 0.
 *
 * - One step yields `[max, 0]`.
 * - Two steps yield `[max, intermediate, 0]`.
 * - Any other count delegates to `linspace(max, 0, count + 1)`.
 *
 * @param steps - Requested step count; values below 1 are clamped to 1.
 * @param config - Scheduler config. Reads `maxTimesteps` (default 1.5708) and
 *   `intermediateTimesteps` (default 1.3); non-finite values fall back to the defaults.
 * @returns Float32Array of timesteps.
 */
function buildScmTimesteps(steps, config) {
  const finiteOr = (value, fallback) => (Number.isFinite(value) ? value : fallback);
  const start = finiteOr(config.maxTimesteps, 1.5708);
  const middle = finiteOr(config.intermediateTimesteps, 1.3);
  const stepCount = Math.max(1, steps);

  switch (stepCount) {
    case 1:
      return new Float32Array([start, 0.0]);
    case 2:
      return new Float32Array([start, middle, 0.0]);
    default:
      return linspace(start, 0.0, stepCount + 1);
  }
}
51
+
52
/**
 * Run one SCM sampler update using the TrigFlow parameterization.
 *
 * Let `s = timesteps[stepIndex]` and `t = timesteps[stepIndex + 1]`.
 * - The clean-sample estimate is `x0 = cos(s) * sample - sin(s) * modelOutput`.
 * - Every step except the last re-noises it as
 *   `cos(t) * x0 + sin(t) * noise * sigmaData`.
 * - The last step returns `x0` unchanged.
 *
 * NOTE(review): `timestep` is accepted but not read; the step position comes
 * from `stepIndex`.
 *
 * @param config - SCM schedule (`type === 'scm'`, Float32Array `timesteps` of length >= 2).
 * @param modelOutput - Float32Array model prediction, same length as `sample`.
 * @param timestep - Unused (see note above).
 * @param sample - Float32Array current sample.
 * @param stepIndex - Integer index with a following timestep available (default 0).
 * @param noise - Float32Array of `sample.length`, required unless this is the last step.
 * @returns `{ prevSample, predOriginalSample }`, both freshly allocated.
 * @throws Error on any invalid argument or an unsupported `predictionType`.
 */
export function stepScmScheduler(config, modelOutput, timestep, sample, stepIndex = 0, noise = null) {
  if (config?.type !== 'scm') {
    throw new Error('stepScmScheduler requires scheduler.type="scm".');
  }
  if (!(modelOutput instanceof Float32Array)) {
    throw new Error('stepScmScheduler requires modelOutput as Float32Array.');
  }
  if (!(sample instanceof Float32Array)) {
    throw new Error('stepScmScheduler requires sample as Float32Array.');
  }
  const size = sample.length;
  if (modelOutput.length !== size) {
    throw new Error(
      `stepScmScheduler requires modelOutput and sample with matching sizes; got ${modelOutput.length} and ${sample.length}.`
    );
  }

  const { timesteps } = config;
  if (!(timesteps instanceof Float32Array) || timesteps.length < 2) {
    throw new Error('stepScmScheduler requires scheduler.timesteps with length >= 2.');
  }
  const lastIndex = timesteps.length - 1;
  if (!Number.isInteger(stepIndex) || stepIndex < 0 || stepIndex >= lastIndex) {
    throw new Error(
      `stepScmScheduler received invalid stepIndex=${stepIndex} for ${timesteps.length} timesteps.`
    );
  }

  const parameterization = config.predictionType ?? 'trigflow';
  if (parameterization !== 'trigflow') {
    throw new Error(`Unsupported SCM predictionType "${parameterization}".`);
  }

  // Only intermediate steps inject noise; the final step returns the clean estimate.
  const injectNoise = stepIndex + 1 < lastIndex;
  if (injectNoise && (!(noise instanceof Float32Array) || noise.length !== size)) {
    throw new Error(
      'stepScmScheduler requires a Float32Array noise tensor for multi-step SCM updates.'
    );
  }

  const s = timesteps[stepIndex];
  const t = timesteps[stepIndex + 1];
  const cosS = Math.cos(s);
  const sinS = Math.sin(s);
  const cosT = Math.cos(t);
  const sinT = Math.sin(t);
  const sigmaData = Number.isFinite(config.sigmaData) ? config.sigmaData : 0.5;

  const predOriginalSample = new Float32Array(size);
  const prevSample = new Float32Array(size);
  for (let idx = 0; idx < size; idx++) {
    predOriginalSample[idx] = cosS * sample[idx] - sinS * modelOutput[idx];
    // Read back the float32-rounded estimate so the re-noised value is computed from the stored x0.
    const x0 = predOriginalSample[idx];
    prevSample[idx] = injectNoise ? cosT * x0 + sinT * noise[idx] * sigmaData : x0;
  }

  return { prevSample, predOriginalSample };
}
114
+
37
115
  export function buildScheduler(config, stepsOverride = null) {
38
116
  if (!config) {
39
117
  throw new Error('Scheduler config is required');
@@ -43,15 +121,25 @@ export function buildScheduler(config, stepsOverride = null) {
43
121
  if (typeof type !== 'string' || !type) {
44
122
  throw new Error('Diffusion scheduler requires a scheduler type.');
45
123
  }
46
- const sigmas = type === 'flowmatch_euler'
47
- ? buildFlowMatchSchedule(config, steps)
48
- : buildLinearSigmaSchedule(steps);
49
124
  const trainSteps = Number.isFinite(config.numTrainTimesteps)
50
125
  ? config.numTrainTimesteps
51
126
  : null;
52
127
  if (!Number.isFinite(trainSteps) || trainSteps <= 0) {
53
128
  throw new Error('Diffusion scheduler requires valid numTrainTimesteps.');
54
129
  }
130
+ if (type === 'scm') {
131
+ return {
132
+ type,
133
+ steps,
134
+ sigmas: null,
135
+ timesteps: buildScmTimesteps(steps, config),
136
+ predictionType: config.predictionType ?? 'trigflow',
137
+ sigmaData: Number.isFinite(config.sigmaData) ? config.sigmaData : 0.5,
138
+ };
139
+ }
140
+ const sigmas = type === 'flowmatch_euler'
141
+ ? buildFlowMatchSchedule(config, steps)
142
+ : buildLinearSigmaSchedule(steps);
55
143
  const timesteps = new Float32Array(steps);
56
144
  for (let i = 0; i < steps; i++) {
57
145
  timesteps[i] = sigmas[i] * trainSteps;
@@ -16,25 +16,27 @@ export interface DiffusionTextEncoderWeightsEntry {
16
16
 
17
17
  export interface DiffusionTextEncoderWeights {
18
18
  text_encoder: DiffusionTextEncoderWeightsEntry;
19
- text_encoder_2: DiffusionTextEncoderWeightsEntry;
20
- text_encoder_3: DiffusionTextEncoderWeightsEntry;
19
+ text_encoder_2?: DiffusionTextEncoderWeightsEntry | null;
20
+ text_encoder_3?: DiffusionTextEncoderWeightsEntry | null;
21
21
  transformer?: DiffusionTextEncoderWeightsEntry;
22
22
  }
23
23
 
24
24
  export interface DiffusionTextTokens {
25
25
  text_encoder: number[];
26
- text_encoder_2: number[];
27
- text_encoder_3: number[];
26
+ text_encoder_2?: number[];
27
+ text_encoder_3?: number[];
28
28
  }
29
29
 
30
30
  export interface DiffusionTextConditioning {
31
31
  pooled: Float32Array;
32
32
  context: Tensor;
33
+ attentionMask?: Uint32Array | null;
33
34
  profile?: {
34
35
  totalMs?: number | null;
35
36
  clipMs?: number | null;
36
37
  clip2Ms?: number | null;
37
38
  t5Ms?: number | null;
39
+ gemmaMs?: number | null;
38
40
  } | null;
39
41
  }
40
42
 
@@ -40,6 +40,8 @@ import {
40
40
  inferDiffusionMatmulDtypeFromBuffer,
41
41
  sumDiffusionProfileTimings,
42
42
  } from './helpers.js';
43
+ import { initRoPEFrequencies } from '../text/init.js';
44
+ import { processLayerGPU } from '../text/layer.js';
43
45
 
44
46
  const QUICK_GELU_ALPHA = 1.702;
45
47
  const SUPPORTED_CLIP_HIDDEN_ACTIVATIONS = new Set(['gelu', 'quick_gelu']);
@@ -56,6 +58,16 @@ function padTokens(tokens, maxLength, padTokenId) {
56
58
  return out;
57
59
  }
58
60
 
61
/**
 * Convert a token list to a Uint32Array, truncated to `maxLength`.
 *
 * - A non-array `tokens` is treated as empty.
 * - A non-finite or non-positive `maxLength` disables truncation; fractional
 *   limits are floored.
 * - If no tokens remain, a single `fallbackTokenId` (coerced to uint32) is
 *   returned, presumably so the encoder never receives an empty sequence.
 *
 * No padding is added.
 */
function normalizeTokens(tokens, maxLength, fallbackTokenId) {
  const list = Array.isArray(tokens) ? tokens : [];
  const hasLimit = Number.isFinite(maxLength) && maxLength > 0;
  const kept = hasLimit ? list.slice(0, Math.floor(maxLength)) : list.slice();
  if (kept.length === 0) {
    return new Uint32Array([fallbackTokenId >>> 0]);
  }
  return Uint32Array.from(kept);
}
70
+
59
71
  function findEosIndex(tokens, eosTokenId) {
60
72
  if (eosTokenId == null) return tokens.length - 1;
61
73
  for (let i = 0; i < tokens.length; i++) {
@@ -702,7 +714,264 @@ async function runT5Encoder(tokens, weightsEntry, config, runtime, options = {})
702
714
  };
703
715
  }
704
716
 
717
/**
 * Build the per-layer attention kinds for a Gemma2 stack.
 *
 * With a positive finite `slidingWindow`, layers alternate: even indices use
 * 'sliding_attention' and odd indices use 'full_attention'. Otherwise every
 * layer is 'full_attention'.
 */
function buildGemma2LayerTypes(layerCount, slidingWindow) {
  const windowed = Number.isFinite(slidingWindow) && slidingWindow > 0;
  return Array.from({ length: layerCount }, (_unused, layer) => (
    windowed && layer % 2 === 0 ? 'sliding_attention' : 'full_attention'
  ));
}
725
+
726
/**
 * Look up the weight `${prefix}.model.layers.${layerIdx}.${suffix}`.
 *
 * Returns the stored weight when present (and truthy). Otherwise it throws
 * if `required` is true (the default) or returns `null`.
 */
function getGemma2LayerWeight(weights, prefix, layerIdx, suffix, required = true) {
  const key = `${prefix}.model.layers.${layerIdx}.${suffix}`;
  const found = weights.get(key);
  if (found) {
    return found;
  }
  if (required) {
    throw new Error(`Missing Gemma2 diffusion weight "${key}".`);
  }
  return null;
}
734
+
735
/**
 * Normalize an HF-style Gemma2 config (snake_case keys) into the camelCase
 * fields the diffusion text encoder uses.
 *
 * Defaults:
 * - `num_key_value_heads` → `num_attention_heads`
 * - `head_dim` → `floor(hidden_size / num_attention_heads)`
 * - `rms_norm_eps` → 1e-6
 * - `rope_theta` → 10000
 * - `sliding_window` → 4096
 * - `scale_embeddings` is enabled unless it is explicitly `false`.
 *
 * @param config - Gemma2 config object.
 * @returns Resolved dimensions and hyperparameters.
 * @throws Error when a required dimension is missing or non-positive, or when
 *   `num_attention_heads` is not a multiple of `num_key_value_heads`.
 */
function resolveGemma2TextConfig(config) {
  const hiddenSize = config.hidden_size;
  const numHeads = config.num_attention_heads;
  const numKVHeads = config.num_key_value_heads ?? numHeads;
  const headDim = config.head_dim ?? (
    Number.isFinite(hiddenSize) && Number.isFinite(numHeads) && numHeads > 0
      ? Math.floor(hiddenSize / numHeads)
      : null
  );
  const numLayers = config.num_hidden_layers;
  const intermediateSize = config.intermediate_size;
  const maxPositionEmbeddings = config.max_position_embeddings;
  const rmsNormEps = config.rms_norm_eps ?? 1e-6;

  if (!Number.isFinite(hiddenSize) || hiddenSize <= 0) {
    throw new Error('Gemma2 diffusion text encoder requires hidden_size.');
  }
  if (!Number.isFinite(numHeads) || numHeads <= 0) {
    throw new Error('Gemma2 diffusion text encoder requires num_attention_heads.');
  }
  if (!Number.isFinite(numKVHeads) || numKVHeads <= 0) {
    throw new Error('Gemma2 diffusion text encoder requires num_key_value_heads.');
  }
  // Each key/value head must be shared by a whole number of query heads;
  // otherwise the head grouping handed to the attention layers is inconsistent.
  if (numHeads % numKVHeads !== 0) {
    throw new Error(
      'Gemma2 diffusion text encoder requires num_attention_heads to be a multiple of num_key_value_heads.'
    );
  }
  if (!Number.isFinite(headDim) || headDim <= 0) {
    throw new Error('Gemma2 diffusion text encoder requires head_dim or hidden_size/num_attention_heads.');
  }
  if (!Number.isFinite(numLayers) || numLayers <= 0) {
    throw new Error('Gemma2 diffusion text encoder requires num_hidden_layers.');
  }
  if (!Number.isFinite(intermediateSize) || intermediateSize <= 0) {
    throw new Error('Gemma2 diffusion text encoder requires intermediate_size.');
  }
  if (!Number.isFinite(maxPositionEmbeddings) || maxPositionEmbeddings <= 0) {
    throw new Error('Gemma2 diffusion text encoder requires max_position_embeddings.');
  }

  return {
    hiddenSize,
    numHeads,
    numKVHeads,
    headDim,
    numLayers,
    intermediateSize,
    maxPositionEmbeddings,
    rmsNormEps,
    ropeTheta: config.rope_theta ?? 10000,
    slidingWindow: config.sliding_window ?? 4096,
    scaleEmbeddings: config.scale_embeddings !== false,
  };
}
785
+
786
/**
 * Encode prompt tokens with a Gemma2 decoder stack and return the final hidden states.
 *
 * Pipeline:
 * 1. Embedding gather.
 * 2. Optional sqrt(hiddenSize) embedding scale.
 * 3. `numLayers` causal passes through the text pipeline's `processLayerGPU`,
 *    alternating sliding and full attention.
 * 4. Final RMSNorm.
 *
 * @param tokens - Token ids; truncated to `options.maxLength` (or
 *   `max_position_embeddings`) and never padded.
 * @param weightsEntry - Loaded component with `weights` (Map) and `shapes`.
 * @param config - HF-style Gemma2 config (snake_case keys). `attn_logit_softcapping`
 *   and `query_pre_attn_scalar` are honored when present, defaulting to 50.0
 *   and `headDim` respectively.
 * @param runtime - Runtime config used to resolve activation and embedding dtypes.
 * @param options - Optional settings:
 *   - `prefix`: weight-key prefix (default 'text_encoder').
 *   - `maxLength`: token limit.
 *   - `recorder`: external command recorder.
 *   - `profile`: when no recorder is supplied, create, submit and time a local one.
 * @returns `{ hidden, attentionMask, maxLength, hiddenSize, profile }`. The mask
 *   is all ones because no padding is added.
 * @throws Error when no WebGPU device is available, weights are missing, or the
 *   config is incomplete.
 */
async function runGemma2TextEncoder(tokens, weightsEntry, config, runtime, options = {}) {
  const device = getDevice();
  if (!device) throw new Error('Gemma2 diffusion text encoder requires a WebGPU device.');
  if (!weightsEntry?.weights || !weightsEntry?.shapes) {
    throw new Error('Gemma2 diffusion text encoder requires loaded weights.');
  }

  const prefix = options.prefix ?? 'text_encoder';
  // A local recorder is only created for profiling when the caller did not supply one.
  const localRecorder = options.recorder
    ? null
    : (options.profile ? new CommandRecorder(device, `${prefix}_gemma2_encoder`, { profile: true }) : null);
  const recorder = options.recorder ?? localRecorder;
  const ops = createKernelOps(recorder);
  const release = createDiffusionBufferReleaser(recorder);
  const destroy = createDiffusionBufferDestroyer(recorder);
  const weights = weightsEntry.weights;
  const activationDtype = resolveDiffusionActivationDtype(runtime);
  const resolved = resolveGemma2TextConfig(config);
  const padTokenId = config.pad_token_id ?? config.bos_token_id ?? 0;
  const tokenIds = normalizeTokens(tokens, options.maxLength ?? resolved.maxPositionEmbeddings, padTokenId);
  const numTokens = tokenIds.length;
  const tokenBuffer = createDiffusionIndexBuffer(device, tokenIds, `${prefix}_tokens`);

  const embedKey = `${prefix}.model.embed_tokens.weight`;
  const embedWeight = expectDiffusionWeight(weights.get(embedKey), embedKey);
  const embedDtype = resolveEmbeddingDtype(embedWeight, weightsEntry, embedKey, runtime);
  let hidden = await ops.gather(
    tokenBuffer,
    getBuffer(embedWeight),
    numTokens,
    resolved.hiddenSize,
    config.vocab_size,
    {
      embeddingDtype: embedDtype,
      outputDtype: activationDtype,
      transpose: false,
    }
  );
  destroy(tokenBuffer);

  if (resolved.scaleEmbeddings) {
    // Scale token embeddings by sqrt(hiddenSize) before the first layer.
    const scaled = await ops.scale(hidden, Math.sqrt(resolved.hiddenSize), {
      count: numTokens * resolved.hiddenSize,
    });
    release(hidden.buffer);
    hidden = createTensor(scaled.buffer, scaled.dtype, [numTokens, resolved.hiddenSize], 'gemma2_embed');
  }

  const layerWeights = new Map();
  for (let layerIdx = 0; layerIdx < resolved.numLayers; layerIdx++) {
    layerWeights.set(`layer_${layerIdx}`, {
      inputNorm: getGemma2LayerWeight(weights, prefix, layerIdx, 'input_layernorm.weight'),
      qProj: getGemma2LayerWeight(weights, prefix, layerIdx, 'self_attn.q_proj.weight'),
      kProj: getGemma2LayerWeight(weights, prefix, layerIdx, 'self_attn.k_proj.weight'),
      vProj: getGemma2LayerWeight(weights, prefix, layerIdx, 'self_attn.v_proj.weight'),
      oProj: getGemma2LayerWeight(weights, prefix, layerIdx, 'self_attn.o_proj.weight'),
      postAttentionNorm: getGemma2LayerWeight(weights, prefix, layerIdx, 'post_attention_layernorm.weight'),
      preFeedforwardNorm: getGemma2LayerWeight(weights, prefix, layerIdx, 'pre_feedforward_layernorm.weight'),
      gate: getGemma2LayerWeight(weights, prefix, layerIdx, 'mlp.gate_proj.weight'),
      up: getGemma2LayerWeight(weights, prefix, layerIdx, 'mlp.up_proj.weight'),
      down: getGemma2LayerWeight(weights, prefix, layerIdx, 'mlp.down_proj.weight'),
    });
  }

  // NOTE(review): the RoPE tables are never released in this function; confirm whether
  // initRoPEFrequencies caches them or whether they should be freed after encoding.
  const ropeFreqs = await initRoPEFrequencies({
    headDim: resolved.headDim,
    maxSeqLen: resolved.maxPositionEmbeddings,
    ropeTheta: resolved.ropeTheta,
    ropeLocalTheta: null,
    ropeScale: 1,
    ropeLocalScale: null,
    ropeScalingType: null,
    ropeLocalScalingType: null,
    ropeScaling: null,
    ropeLocalScaling: null,
  }, true);

  // Read these from the config when it provides positive finite values; otherwise
  // keep the previously hard-coded 50.0 and headDim.
  const attnLogitSoftcapping = Number.isFinite(config.attn_logit_softcapping) && config.attn_logit_softcapping > 0
    ? config.attn_logit_softcapping
    : 50.0;
  const queryPreAttnScalar = Number.isFinite(config.query_pre_attn_scalar) && config.query_pre_attn_scalar > 0
    ? config.query_pre_attn_scalar
    : resolved.headDim;

  const context = {
    useGPU: true,
    activationDtype,
    recorder,
    currentSeqLen: 0,
    kvCache: null,
    weightConfig: {
      rmsNormWeightOffset: true,
    },
    debugFlags: {},
    weights: layerWeights,
    ropeFreqsCos: ropeFreqs.cos,
    ropeFreqsSin: ropeFreqs.sin,
    ropeLocalCos: ropeFreqs.localCos,
    ropeLocalSin: ropeFreqs.localSin,
    config: {
      hiddenSize: resolved.hiddenSize,
      intermediateSize: resolved.intermediateSize,
      numHeads: resolved.numHeads,
      numKVHeads: resolved.numKVHeads,
      headDim: resolved.headDim,
      rmsNormEps: resolved.rmsNormEps,
      slidingWindow: resolved.slidingWindow,
      attnLogitSoftcapping,
      queryPreAttnScalar,
      queryKeyNorm: false,
      attentionOutputGate: false,
      causalAttention: true,
      hiddenActivation: 'gelu',
      swigluLimit: null,
      useMoE: false,
      layerTypes: buildGemma2LayerTypes(resolved.numLayers, resolved.slidingWindow),
      preFeedforwardNorm: true,
      postFeedforwardNorm: false,
      postAttentionNorm: true,
    },
  };

  for (let layerIdx = 0; layerIdx < resolved.numLayers; layerIdx++) {
    const layerInput = hidden;
    const output = await processLayerGPU(
      layerIdx,
      layerInput.buffer,
      numTokens,
      true,
      numTokens * resolved.hiddenSize,
      context
    );
    hidden = createTensor(output.buffer, output.dtype, [numTokens, resolved.hiddenSize], `gemma2_layer_${layerIdx}`);
    // Free the consumed layer input, as the embedding-scale and final-norm steps do,
    // so intermediate activations do not accumulate across layers.
    if (output.buffer !== layerInput.buffer) {
      release(layerInput.buffer);
    }
  }

  const finalNormKey = `${prefix}.model.norm.weight`;
  const finalNorm = expectDiffusionWeight(weights.get(finalNormKey), finalNormKey);
  const normed = await ops.rmsNorm(hidden, getBuffer(finalNorm), resolved.rmsNormEps, {
    batchSize: numTokens,
    hiddenSize: resolved.hiddenSize,
    rmsNormWeightOffset: true,
  });
  release(hidden.buffer);

  let profile = null;
  if (localRecorder) {
    localRecorder.submit();
    const timings = await localRecorder.resolveProfileTimings();
    profile = timings ? { totalMs: sumDiffusionProfileTimings(timings) ?? 0, timings } : { totalMs: null };
  }

  return {
    hidden: normed,
    // normalizeTokens never pads, so every position is a real token.
    attentionMask: new Uint32Array(numTokens).fill(1),
    maxLength: numTokens,
    hiddenSize: resolved.hiddenSize,
    profile,
  };
}
940
+
705
941
  export async function runTextEncodersForPrompt(tokensByEncoder, weightsByComponent, modelConfig, runtime, options = {}) {
942
+ const layout = modelConfig?.layout ?? 'sd3';
943
+ if (layout === 'sana') {
944
+ const gemmaConfig = modelConfig?.components?.text_encoder?.config || {};
945
+ const gemmaMaxLength = runtime?.textEncoder?.maxLength;
946
+ if (!Number.isFinite(gemmaMaxLength) || gemmaMaxLength <= 0) {
947
+ throw new Error('Sana Gemma2 encoder requires runtime.textEncoder.maxLength.');
948
+ }
949
+ const profileEnabled = options.profile === true;
950
+ const gemma = await runGemma2TextEncoder(
951
+ tokensByEncoder.text_encoder,
952
+ weightsByComponent.text_encoder,
953
+ gemmaConfig,
954
+ runtime,
955
+ {
956
+ prefix: 'text_encoder',
957
+ maxLength: gemmaMaxLength,
958
+ profile: profileEnabled,
959
+ }
960
+ );
961
+
962
+ return {
963
+ pooled: new Float32Array(0),
964
+ context: gemma.hidden,
965
+ attentionMask: gemma.attentionMask,
966
+ profile: profileEnabled
967
+ ? {
968
+ totalMs: gemma.profile?.totalMs ?? null,
969
+ gemmaMs: gemma.profile?.totalMs ?? null,
970
+ }
971
+ : null,
972
+ };
973
+ }
974
+
706
975
  const clipConfig = modelConfig?.components?.text_encoder?.config || {};
707
976
  const clip2Config = modelConfig?.components?.text_encoder_2?.config || {};
708
977
  const t5Config = modelConfig?.components?.text_encoder_3?.config || {};
@@ -749,6 +1018,7 @@ export async function runTextEncodersForPrompt(tokensByEncoder, weightsByCompone
749
1018
  return {
750
1019
  pooled,
751
1020
  context: t5.hidden,
1021
+ attentionMask: null,
752
1022
  profile,
753
1023
  };
754
1024
  }
@@ -1,5 +1,6 @@
1
1
  import { BPETokenizer } from '../../tokenizers/bpe.js';
2
2
  import { SentencePieceTokenizer } from '../../tokenizers/sentencepiece.js';
3
+ import { BundledTokenizer } from '../../tokenizers/bundled.js';
3
4
  import { loadAuxText, loadAuxFile } from '../../../storage/shard-manager.js';
4
5
 
5
6
  function parseMerges(text) {
@@ -136,11 +137,27 @@ async function loadSentencePieceTokenizer(tokenizerConfig, options = {}) {
136
137
  return tokenizer;
137
138
  }
138
139
 
140
/**
 * Load a tokenizer described by a single bundled JSON asset.
 *
 * Fetches `tokenizerConfig.tokenizerFile` via `loadTextAsset` (passing
 * `options.baseUrl`) and parses it as JSON. It then loads the result into a
 * `BundledTokenizer` constructed with `vocabSize: 0` and
 * `deferSpecialTokens: true`.
 * NOTE(review): presumably `load` derives the real vocabulary size from the JSON; confirm.
 */
async function loadBundledTokenizer(tokenizerConfig, options = {}) {
  const baseUrl = options.baseUrl;
  const parsed = JSON.parse(await loadTextAsset(tokenizerConfig.tokenizerFile, baseUrl));
  const bundled = new BundledTokenizer({ vocabSize: 0, deferSpecialTokens: true });
  bundled.load(parsed);
  return bundled;
}
151
+
139
152
  export async function loadDiffusionTokenizers(diffusionConfig, options = {}) {
140
153
  const tokenizers = {};
141
154
  const config = diffusionConfig?.tokenizers || {};
142
155
  if (config.text_encoder) {
143
- tokenizers.text_encoder = await loadBpeTokenizer(config.text_encoder, options);
156
+ if (config.text_encoder.type === 'bundled') {
157
+ tokenizers.text_encoder = await loadBundledTokenizer(config.text_encoder, options);
158
+ } else {
159
+ tokenizers.text_encoder = await loadBpeTokenizer(config.text_encoder, options);
160
+ }
144
161
  }
145
162
  if (config.text_encoder_2) {
146
163
  tokenizers.text_encoder_2 = await loadBpeTokenizer(config.text_encoder_2, options);
@@ -27,6 +27,10 @@ export interface DiffusionSchedulerConfig {
27
27
  eta: number;
28
28
  numTrainTimesteps: number;
29
29
  shift: number;
30
+ predictionType?: string;
31
+ sigmaData?: number;
32
+ maxTimesteps?: number;
33
+ intermediateTimesteps?: number;
30
34
  }
31
35
 
32
36
  export interface DiffusionLatentConfig {