@simulatte/doppler 0.1.4 → 0.1.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (199)
  1. package/README.md +26 -10
  2. package/package.json +30 -6
  3. package/src/client/doppler-api.browser.d.ts +1 -0
  4. package/src/client/doppler-api.browser.js +288 -0
  5. package/src/client/doppler-api.js +1 -1
  6. package/src/client/doppler-provider/types.js +1 -1
  7. package/src/config/execution-contract-check.d.ts +33 -0
  8. package/src/config/execution-contract-check.js +72 -0
  9. package/src/config/execution-v0-contract-check.d.ts +94 -0
  10. package/src/config/execution-v0-contract-check.js +251 -0
  11. package/src/config/execution-v0-graph-contract-check.d.ts +20 -0
  12. package/src/config/execution-v0-graph-contract-check.js +64 -0
  13. package/src/config/kernel-path-contract-check.d.ts +76 -0
  14. package/src/config/kernel-path-contract-check.js +479 -0
  15. package/src/config/kernel-path-loader.d.ts +16 -0
  16. package/src/config/kernel-path-loader.js +54 -0
  17. package/src/config/kernels/kernel-ref-digests.js +39 -27
  18. package/src/config/kernels/registry.json +598 -2
  19. package/src/config/loader.js +81 -48
  20. package/src/config/merge-contract-check.d.ts +16 -0
  21. package/src/config/merge-contract-check.js +321 -0
  22. package/src/config/merge-helpers.d.ts +58 -0
  23. package/src/config/merge-helpers.js +54 -0
  24. package/src/config/merge.js +21 -6
  25. package/src/config/presets/models/janus-text.json +2 -0
  26. package/src/config/presets/models/qwen3.json +9 -2
  27. package/src/config/presets/models/transformer.json +5 -0
  28. package/src/config/quantization-contract-check.d.ts +12 -0
  29. package/src/config/quantization-contract-check.js +91 -0
  30. package/src/config/required-inference-fields-contract-check.d.ts +24 -0
  31. package/src/config/required-inference-fields-contract-check.js +237 -0
  32. package/src/config/schema/browser-suite-metrics.schema.d.ts +17 -0
  33. package/src/config/schema/browser-suite-metrics.schema.js +46 -0
  34. package/src/config/schema/conversion-report.schema.d.ts +40 -0
  35. package/src/config/schema/conversion-report.schema.js +108 -0
  36. package/src/config/schema/doppler.schema.js +12 -18
  37. package/src/config/schema/index.d.ts +22 -0
  38. package/src/config/schema/index.js +18 -0
  39. package/src/config/schema/inference-defaults.schema.js +3 -0
  40. package/src/config/schema/inference.schema.d.ts +9 -0
  41. package/src/config/schema/kernel-path.schema.d.ts +6 -0
  42. package/src/config/schema/manifest.schema.d.ts +6 -0
  43. package/src/config/schema/manifest.schema.js +3 -0
  44. package/src/converter/core.d.ts +10 -0
  45. package/src/converter/core.js +27 -2
  46. package/src/converter/parsers/diffusion.js +63 -3
  47. package/src/converter/rope-config.js +42 -0
  48. package/src/gpu/device.js +58 -0
  49. package/src/gpu/kernels/attention.js +98 -0
  50. package/src/gpu/kernels/bias_add.wgsl +8 -6
  51. package/src/gpu/kernels/bias_add_f16.wgsl +8 -5
  52. package/src/gpu/kernels/conv2d.js +1 -1
  53. package/src/gpu/kernels/conv2d.wgsl +7 -8
  54. package/src/gpu/kernels/conv2d_f16.wgsl +7 -8
  55. package/src/gpu/kernels/depthwise_conv2d.d.ts +29 -0
  56. package/src/gpu/kernels/depthwise_conv2d.js +99 -0
  57. package/src/gpu/kernels/depthwise_conv2d.wgsl +55 -0
  58. package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +59 -0
  59. package/src/gpu/kernels/grouped_pointwise_conv2d.d.ts +27 -0
  60. package/src/gpu/kernels/grouped_pointwise_conv2d.js +93 -0
  61. package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +44 -0
  62. package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +48 -0
  63. package/src/gpu/kernels/index.d.ts +30 -0
  64. package/src/gpu/kernels/index.js +25 -0
  65. package/src/gpu/kernels/matmul.js +25 -0
  66. package/src/gpu/kernels/pixel_shuffle.js +1 -1
  67. package/src/gpu/kernels/pixel_shuffle.wgsl +4 -5
  68. package/src/gpu/kernels/pixel_shuffle_f16.wgsl +4 -5
  69. package/src/gpu/kernels/relu.d.ts +18 -0
  70. package/src/gpu/kernels/relu.js +58 -0
  71. package/src/gpu/kernels/relu.wgsl +22 -0
  72. package/src/gpu/kernels/relu_f16.wgsl +24 -0
  73. package/src/gpu/kernels/repeat_channels.d.ts +21 -0
  74. package/src/gpu/kernels/repeat_channels.js +60 -0
  75. package/src/gpu/kernels/repeat_channels.wgsl +28 -0
  76. package/src/gpu/kernels/repeat_channels_f16.wgsl +30 -0
  77. package/src/gpu/kernels/residual.js +44 -8
  78. package/src/gpu/kernels/residual.wgsl +6 -3
  79. package/src/gpu/kernels/residual_f16.wgsl +2 -1
  80. package/src/gpu/kernels/residual_f16_vec4.wgsl +2 -1
  81. package/src/gpu/kernels/residual_vec4.wgsl +2 -1
  82. package/src/gpu/kernels/rmsnorm.js +58 -6
  83. package/src/gpu/kernels/rmsnorm.wgsl +14 -6
  84. package/src/gpu/kernels/rmsnorm_f16.wgsl +10 -2
  85. package/src/gpu/kernels/rope.d.ts +2 -0
  86. package/src/gpu/kernels/rope.js +11 -1
  87. package/src/gpu/kernels/rope.wgsl +56 -40
  88. package/src/gpu/kernels/sana_linear_attention.d.ts +27 -0
  89. package/src/gpu/kernels/sana_linear_attention.js +121 -0
  90. package/src/gpu/kernels/sana_linear_attention_apply.wgsl +43 -0
  91. package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +46 -0
  92. package/src/gpu/kernels/sana_linear_attention_summary.wgsl +51 -0
  93. package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +53 -0
  94. package/src/gpu/kernels/silu.d.ts +1 -0
  95. package/src/gpu/kernels/silu.js +32 -14
  96. package/src/gpu/kernels/silu.wgsl +19 -9
  97. package/src/gpu/kernels/silu_f16.wgsl +19 -9
  98. package/src/gpu/kernels/transpose.js +15 -2
  99. package/src/gpu/kernels/transpose.wgsl +5 -6
  100. package/src/gpu/kernels/upsample2d.js +2 -1
  101. package/src/gpu/kernels/upsample2d.wgsl +6 -9
  102. package/src/gpu/kernels/upsample2d_f16.wgsl +6 -9
  103. package/src/gpu/kernels/utils.js +16 -1
  104. package/src/index-browser.d.ts +1 -1
  105. package/src/index-browser.js +2 -2
  106. package/src/index.js +1 -1
  107. package/src/inference/browser-harness.js +109 -23
  108. package/src/inference/pipelines/diffusion/init.js +14 -0
  109. package/src/inference/pipelines/diffusion/pipeline.js +215 -77
  110. package/src/inference/pipelines/diffusion/sana-transformer.d.ts +53 -0
  111. package/src/inference/pipelines/diffusion/sana-transformer.js +738 -0
  112. package/src/inference/pipelines/diffusion/scheduler.d.ts +17 -1
  113. package/src/inference/pipelines/diffusion/scheduler.js +91 -3
  114. package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +11 -4
  115. package/src/inference/pipelines/diffusion/text-encoder-gpu.js +282 -0
  116. package/src/inference/pipelines/diffusion/text-encoder.js +18 -1
  117. package/src/inference/pipelines/diffusion/types.d.ts +4 -0
  118. package/src/inference/pipelines/diffusion/vae.js +782 -78
  119. package/src/inference/pipelines/text/attention/record.js +11 -2
  120. package/src/inference/pipelines/text/attention/run.js +11 -2
  121. package/src/inference/pipelines/text/chat-format.js +25 -1
  122. package/src/inference/pipelines/text/config.d.ts +9 -0
  123. package/src/inference/pipelines/text/config.js +69 -2
  124. package/src/inference/pipelines/text/execution-plan.js +23 -31
  125. package/src/inference/pipelines/text/execution-v0.js +43 -95
  126. package/src/inference/pipelines/text/ffn/standard.js +3 -0
  127. package/src/inference/pipelines/text/init.d.ts +4 -0
  128. package/src/inference/pipelines/text/init.js +56 -9
  129. package/src/inference/pipelines/text/layer.js +11 -0
  130. package/src/inference/pipelines/text.js +4 -0
  131. package/src/inference/tokenizers/bundled.js +156 -33
  132. package/src/rules/execution-rules-contract-check.d.ts +17 -0
  133. package/src/rules/execution-rules-contract-check.js +245 -0
  134. package/src/rules/kernels/depthwise-conv2d.rules.json +6 -0
  135. package/src/rules/kernels/grouped-pointwise-conv2d.rules.json +6 -0
  136. package/src/rules/kernels/relu.rules.json +6 -0
  137. package/src/rules/kernels/repeat-channels.rules.json +6 -0
  138. package/src/rules/kernels/sana-linear-attention.rules.json +6 -0
  139. package/src/rules/layer-pattern-contract-check.d.ts +17 -0
  140. package/src/rules/layer-pattern-contract-check.js +231 -0
  141. package/src/rules/rule-registry.d.ts +28 -0
  142. package/src/rules/rule-registry.js +38 -0
  143. package/src/rules/tooling/command-runtime.rules.json +18 -0
  144. package/src/tooling/command-api.d.ts +27 -1
  145. package/src/tooling/command-api.js +142 -3
  146. package/src/tooling/conversion-config-materializer.d.ts +24 -0
  147. package/src/tooling/conversion-config-materializer.js +99 -0
  148. package/src/tooling/lean-execution-contract-runner.d.ts +43 -0
  149. package/src/tooling/lean-execution-contract-runner.js +158 -0
  150. package/src/tooling/node-browser-command-runner.d.ts +4 -0
  151. package/src/tooling/node-browser-command-runner.js +58 -3
  152. package/src/tooling/node-command-runner.js +15 -0
  153. package/src/tooling/node-convert.d.ts +10 -0
  154. package/src/tooling/node-converter.js +59 -0
  155. package/src/tooling/node-webgpu.js +11 -89
  156. package/src/training/checkpoint-watch.d.ts +7 -0
  157. package/src/training/checkpoint-watch.js +106 -0
  158. package/src/training/checkpoint.d.ts +6 -1
  159. package/src/training/checkpoint.js +12 -2
  160. package/src/training/distillation/artifacts.d.ts +71 -0
  161. package/src/training/distillation/artifacts.js +132 -0
  162. package/src/training/distillation/checkpoint-watch.d.ts +10 -0
  163. package/src/training/distillation/checkpoint-watch.js +57 -0
  164. package/src/training/distillation/dataset.d.ts +59 -0
  165. package/src/training/distillation/dataset.js +337 -0
  166. package/src/training/distillation/eval.d.ts +34 -0
  167. package/src/training/distillation/eval.js +310 -0
  168. package/src/training/distillation/index.d.ts +29 -0
  169. package/src/training/distillation/index.js +29 -0
  170. package/src/training/distillation/runtime.d.ts +20 -0
  171. package/src/training/distillation/runtime.js +121 -0
  172. package/src/training/distillation/scoreboard.d.ts +6 -0
  173. package/src/training/distillation/scoreboard.js +8 -0
  174. package/src/training/distillation/stage-a.d.ts +45 -0
  175. package/src/training/distillation/stage-a.js +338 -0
  176. package/src/training/distillation/stage-b.d.ts +24 -0
  177. package/src/training/distillation/stage-b.js +20 -0
  178. package/src/training/index.d.ts +10 -0
  179. package/src/training/index.js +10 -0
  180. package/src/training/lora-pipeline.d.ts +40 -0
  181. package/src/training/lora-pipeline.js +796 -0
  182. package/src/training/operator-artifacts.d.ts +62 -0
  183. package/src/training/operator-artifacts.js +140 -0
  184. package/src/training/operator-command.d.ts +5 -0
  185. package/src/training/operator-command.js +453 -0
  186. package/src/training/operator-eval.d.ts +48 -0
  187. package/src/training/operator-eval.js +230 -0
  188. package/src/training/operator-scoreboard.d.ts +5 -0
  189. package/src/training/operator-scoreboard.js +44 -0
  190. package/src/training/runner.d.ts +52 -0
  191. package/src/training/runner.js +29 -4
  192. package/src/training/suite.d.ts +112 -0
  193. package/src/training/suite.js +9 -9
  194. package/src/training/workloads.d.ts +164 -0
  195. package/src/training/workloads.js +539 -0
  196. package/src/version.d.ts +2 -0
  197. package/src/version.js +2 -0
  198. package/tools/convert-safetensors-node.js +47 -0
  199. package/tools/doppler-cli.js +252 -41
@@ -3,6 +3,13 @@ import { DEFAULT_INFERENCE_DEFAULTS_CONFIG } from './inference-defaults.schema.j
3
3
  import { DEFAULT_SHARED_RUNTIME_CONFIG } from './shared-runtime.schema.js';
4
4
  import { DEFAULT_EMULATION_CONFIG, createEmulationConfig } from './emulation.schema.js';
5
5
  import { mergeEcosystemConfig } from './ecosystem.schema.js';
6
+ import {
7
+ chooseNullish,
8
+ mergeExecutionPatchLists,
9
+ mergeKernelPathPolicy,
10
+ mergeShallowObject,
11
+ replaceSubtree,
12
+ } from '../merge-helpers.js';
6
13
 
7
14
  // =============================================================================
8
15
  // Runtime Config (all non-model-specific settings)
@@ -172,8 +179,6 @@ function mergeInferenceConfig(
172
179
  const overrideExecutionPatch = overrides.executionPatch ?? {};
173
180
  const baseKernelPathPolicy = base.kernelPathPolicy ?? {};
174
181
  const overrideKernelPathPolicy = overrides.kernelPathPolicy ?? {};
175
- const baseKernelPathSourceScope = baseKernelPathPolicy.sourceScope ?? baseKernelPathPolicy.allowSources;
176
- const overrideKernelPathSourceScope = overrideKernelPathPolicy.sourceScope ?? overrideKernelPathPolicy.allowSources;
177
182
  const hasRuntimeKernelProfiles = Object.prototype.hasOwnProperty.call(
178
183
  overrideSessionCompute,
179
184
  'kernelProfiles'
@@ -236,15 +241,8 @@ function mergeInferenceConfig(
236
241
  pipeline: overrides.pipeline ?? base.pipeline,
237
242
  kernelPath: overrides.kernelPath ?? base.kernelPath,
238
243
  kernelPathSource: overrides.kernelPathSource ?? base.kernelPathSource,
239
- kernelPathPolicy: {
240
- mode: overrideKernelPathPolicy.mode ?? baseKernelPathPolicy.mode,
241
- sourceScope: overrideKernelPathSourceScope ?? baseKernelPathSourceScope,
242
- allowSources: overrideKernelPathSourceScope ?? baseKernelPathSourceScope,
243
- onIncompatible: overrideKernelPathPolicy.onIncompatible ?? baseKernelPathPolicy.onIncompatible,
244
- },
245
- chatTemplate: overrides.chatTemplate
246
- ? { ...base.chatTemplate, ...overrides.chatTemplate }
247
- : base.chatTemplate,
244
+ kernelPathPolicy: mergeKernelPathPolicy(baseKernelPathPolicy, overrideKernelPathPolicy),
245
+ chatTemplate: mergeShallowObject(base.chatTemplate, overrides.chatTemplate),
248
246
  session: {
249
247
  ...baseSession,
250
248
  ...overrideSession,
@@ -259,14 +257,10 @@ function mergeInferenceConfig(
259
257
  ? { kernelProfiles: overrideSessionCompute.kernelProfiles }
260
258
  : { kernelProfiles: baseSessionCompute.kernelProfiles }),
261
259
  },
262
- kvcache: overrideSession.kvcache ?? baseSession.kvcache,
263
- decodeLoop: overrideSession.decodeLoop ?? baseSession.decodeLoop,
264
- },
265
- executionPatch: {
266
- set: overrideExecutionPatch.set ?? baseExecutionPatch.set ?? [],
267
- remove: overrideExecutionPatch.remove ?? baseExecutionPatch.remove ?? [],
268
- add: overrideExecutionPatch.add ?? baseExecutionPatch.add ?? [],
260
+ kvcache: replaceSubtree(overrideSession.kvcache, baseSession.kvcache),
261
+ decodeLoop: replaceSubtree(overrideSession.decodeLoop, baseSession.decodeLoop),
269
262
  },
263
+ executionPatch: mergeExecutionPatchLists(baseExecutionPatch, overrideExecutionPatch),
270
264
  // Model-specific inference overrides (merged with manifest.inference at load time)
271
265
  modelOverrides: overrides.modelOverrides ?? base.modelOverrides,
272
266
  };
@@ -225,6 +225,28 @@ export {
225
225
  type ConversionIOSchema,
226
226
  } from './conversion.schema.js';
227
227
 
228
+ // =============================================================================
229
+ // Browser Suite Metrics Schema
230
+ // =============================================================================
231
+ export {
232
+ type BrowserSuiteMetricsSchema,
233
+ BROWSER_SUITE_METRICS_SCHEMA_VERSION,
234
+ DEFAULT_BROWSER_SUITE_METRICS,
235
+ validateBrowserSuiteMetrics,
236
+ } from './browser-suite-metrics.schema.js';
237
+
238
+ // =============================================================================
239
+ // Conversion Report Schema
240
+ // =============================================================================
241
+ export {
242
+ type ConversionReportResultSchema,
243
+ type ConversionReportManifestSchema,
244
+ type ConversionReportSchema,
245
+ CONVERSION_REPORT_SCHEMA_VERSION,
246
+ DEFAULT_CONVERSION_REPORT,
247
+ validateConversionReport,
248
+ } from './conversion-report.schema.js';
249
+
228
250
  // =============================================================================
229
251
  // Converter Schema
230
252
  // =============================================================================
@@ -55,6 +55,24 @@ export {
55
55
  ConversionStage,
56
56
  } from './conversion.schema.js';
57
57
 
58
+ // =============================================================================
59
+ // Browser Suite Metrics Schema
60
+ // =============================================================================
61
+ export {
62
+ BROWSER_SUITE_METRICS_SCHEMA_VERSION,
63
+ DEFAULT_BROWSER_SUITE_METRICS,
64
+ validateBrowserSuiteMetrics,
65
+ } from './browser-suite-metrics.schema.js';
66
+
67
+ // =============================================================================
68
+ // Conversion Report Schema
69
+ // =============================================================================
70
+ export {
71
+ CONVERSION_REPORT_SCHEMA_VERSION,
72
+ DEFAULT_CONVERSION_REPORT,
73
+ validateConversionReport,
74
+ } from './conversion-report.schema.js';
75
+
58
76
  // =============================================================================
59
77
  // Converter Schema
60
78
  // =============================================================================
@@ -165,6 +165,9 @@ export const DEFAULT_PRESET_INFERENCE_CONFIG = {
165
165
  rope: {
166
166
  ropeTheta: 10000,
167
167
  ropeLocalTheta: null,
168
+ mropeInterleaved: false,
169
+ mropeSection: null,
170
+ partialRotaryFactor: null,
168
171
  ropeScalingType: null,
169
172
  ropeScalingFactor: 1.0,
170
173
  ropeLocalScalingType: null,
@@ -18,6 +18,15 @@ export interface RoPEConfigSchema {
18
18
  /** Local RoPE theta for sliding window layers (Gemma 3 uses 10000) */
19
19
  ropeLocalTheta?: number;
20
20
 
21
+ /** Apply adjacent-pair rotary layout instead of rotate-half layout. */
22
+ mropeInterleaved?: boolean;
23
+
24
+ /** mRoPE section sizes before the Qwen doubling step. */
25
+ mropeSection?: number[] | null;
26
+
27
+ /** Fraction of the head dimension that participates in rotary embedding. */
28
+ partialRotaryFactor?: number | null;
29
+
21
30
  /** RoPE scaling type */
22
31
  ropeScalingType?: 'linear' | 'dynamic' | 'yarn' | null;
23
32
 
@@ -105,6 +105,12 @@ export interface KernelPathSchema {
105
105
  /** KV cache dtype for this path; defaults to activationDtype when omitted. */
106
106
  kvDtype?: string;
107
107
 
108
+ /**
109
+ * Explicit widening target used by the finiteness fallback execution plan.
110
+ * Required for inline/generated kernel paths that do not have a stable registry id.
111
+ */
112
+ finitenessFallbackKernelPathId?: string;
113
+
108
114
  /**
109
115
  * Prefill phase kernel sequence (M > 1).
110
116
  * If not specified, uses decode with batched variants.
@@ -217,6 +217,12 @@ export interface ManifestRoPESchema {
217
217
  ropeTheta: number;
218
218
  /** Local theta for sliding window layers (null = same as ropeTheta) */
219
219
  ropeLocalTheta: number | null;
220
+ /** Use adjacent-pair rotary layout instead of rotate-half layout. */
221
+ mropeInterleaved: boolean;
222
+ /** mRoPE section sizes before the Qwen doubling step. */
223
+ mropeSection: number[] | null;
224
+ /** Fraction of the head dimension that participates in rotary embedding. */
225
+ partialRotaryFactor: number | null;
220
226
  /** RoPE scaling type (null = no scaling, 'linear', 'dynamic', 'yarn') */
221
227
  ropeScalingType: string | null;
222
228
  /** RoPE scaling factor (1.0 if no scaling) */
@@ -62,6 +62,9 @@ export const DEFAULT_MANIFEST_INFERENCE = {
62
62
  rope: {
63
63
  ropeTheta: 10000,
64
64
  ropeLocalTheta: null, // Same as ropeTheta (null = use ropeTheta)
65
+ mropeInterleaved: false,
66
+ mropeSection: null,
67
+ partialRotaryFactor: null,
65
68
  ropeScalingType: null, // No scaling (null = disabled)
66
69
  ropeScalingFactor: 1.0,
67
70
  ropeLocalScalingType: null, // Local scaling policy (null = no scaling)
@@ -27,6 +27,12 @@ import type {
27
27
  MoEConfigSchema,
28
28
  ConversionInfoSchema,
29
29
  } from '../config/schema/index.js';
30
+ import type { ExecutionContractArtifact } from '../config/execution-contract-check.js';
31
+ import type { ExecutionV0GraphContractArtifact } from '../config/execution-v0-graph-contract-check.js';
32
+ import type {
33
+ ManifestRequiredInferenceFieldsArtifact,
34
+ RequiredInferenceFieldsContractArtifact,
35
+ } from '../config/required-inference-fields-contract-check.js';
30
36
 
31
37
  export { generateShardFilename } from '../formats/rdrr/index.js';
32
38
 
@@ -144,6 +150,10 @@ export interface ConvertResult {
144
150
  shardCount: number;
145
151
  tensorCount: number;
146
152
  totalSize: number;
153
+ executionContractArtifact: ExecutionContractArtifact | null;
154
+ executionV0GraphContractArtifact: ExecutionV0GraphContractArtifact | null;
155
+ layerPatternContractArtifact: Record<string, unknown> | null;
156
+ requiredInferenceFieldsArtifact: ManifestRequiredInferenceFieldsArtifact | RequiredInferenceFieldsContractArtifact | null;
147
157
  }
148
158
 
149
159
  /** @deprecated Use ConversionIOSchema from config/schema */
@@ -9,15 +9,20 @@ import {
9
9
  formatBytes,
10
10
  } from '../config/schema/index.js';
11
11
 
12
- import { classifyTensorRole, generateShardFilename } from '../formats/rdrr/index.js';
12
+ import { classifyTensor, classifyTensorRole, generateShardFilename } from '../formats/rdrr/index.js';
13
13
  import { log } from '../debug/index.js';
14
- import { selectRuleValue } from '../rules/rule-registry.js';
14
+ import {
15
+ getInferenceLayerPatternContractArtifact,
16
+ selectRuleValue,
17
+ } from '../rules/rule-registry.js';
15
18
  import {
16
19
  createConverterConfig,
17
20
  detectPreset,
18
21
  listPresets,
19
22
  resolvePreset,
20
23
  } from '../config/index.js';
24
+ import { buildExecutionContractArtifact } from '../config/execution-contract-check.js';
25
+ import { buildManifestRequiredInferenceFieldsArtifact } from '../config/required-inference-fields-contract-check.js';
21
26
  import { buildManifestInference, inferEmbeddingOutputConfig } from './manifest-inference.js';
22
27
  import { resolveEosTokenId } from './tokenizer-utils.js';
23
28
  import {
@@ -1128,6 +1133,7 @@ export async function convertModel(model, io, options = {}) {
1128
1133
  }
1129
1134
  const totalTensors = tensors.length;
1130
1135
  const targetQuant = String(options.quantization ?? model.quantization ?? '').trim().toLowerCase();
1136
+ const tensorGroupModelType = String(options.modelType ?? model.modelType ?? 'transformer');
1131
1137
  const q4kLayout = normalizeQ4KLayout(options.quantizationInfo?.layout);
1132
1138
  const quantizeEmbeddings = resolveQuantizeEmbeddings(
1133
1139
  options.quantizationInfo ?? null,
@@ -1251,6 +1257,7 @@ export async function convertModel(model, io, options = {}) {
1251
1257
 
1252
1258
  // Record tensor location
1253
1259
  const role = classifyTensorRole(tensor.name);
1260
+ const group = classifyTensor(tensor.name, tensorGroupModelType);
1254
1261
 
1255
1262
  if (tensorSpans.length === 1) {
1256
1263
  tensorLocations[tensor.name] = {
@@ -1260,6 +1267,7 @@ export async function convertModel(model, io, options = {}) {
1260
1267
  shape: tensor.shape,
1261
1268
  dtype: outDtype,
1262
1269
  role,
1270
+ group,
1263
1271
  ...(outLayout ? { layout: outLayout } : {}),
1264
1272
  };
1265
1273
  } else {
@@ -1269,6 +1277,7 @@ export async function convertModel(model, io, options = {}) {
1269
1277
  shape: tensor.shape,
1270
1278
  dtype: outDtype,
1271
1279
  role,
1280
+ group,
1272
1281
  ...(outLayout ? { layout: outLayout } : {}),
1273
1282
  };
1274
1283
  }
@@ -1327,11 +1336,27 @@ export async function convertModel(model, io, options = {}) {
1327
1336
  totalSize: formatBytes(totalSize),
1328
1337
  });
1329
1338
 
1339
+ const executionContractArtifact = buildExecutionContractArtifact(manifest);
1340
+ const layerPatternContractArtifact = getInferenceLayerPatternContractArtifact();
1341
+ const requiredInferenceFieldsArtifact = manifest?.modelType === 'transformer'
1342
+ && manifest?.inference
1343
+ && typeof manifest.inference === 'object'
1344
+ && manifest.inference.attention
1345
+ && typeof manifest.inference.attention === 'object'
1346
+ ? buildManifestRequiredInferenceFieldsArtifact(
1347
+ manifest?.inference ?? null,
1348
+ `${manifest?.modelId ?? modelId}.inference`
1349
+ )
1350
+ : null;
1330
1351
  return {
1331
1352
  manifest,
1332
1353
  shardCount: shards.length,
1333
1354
  tensorCount: tensors.length,
1334
1355
  totalSize,
1356
+ executionContractArtifact,
1357
+ executionV0GraphContractArtifact: executionContractArtifact?.executionV0?.graph ?? null,
1358
+ layerPatternContractArtifact,
1359
+ requiredInferenceFieldsArtifact,
1335
1360
  };
1336
1361
  }
1337
1362
 
@@ -4,6 +4,13 @@ const SD3_LAYOUT = {
4
4
  id: 'sd3',
5
5
  requiredComponents: ['transformer', 'text_encoder', 'text_encoder_2', 'text_encoder_3', 'vae', 'scheduler'],
6
6
  weightedComponents: ['transformer', 'text_encoder', 'text_encoder_2', 'text_encoder_3', 'vae'],
7
+ matches(modelIndex, components) {
8
+ return (
9
+ components.has('text_encoder_2') &&
10
+ components.has('text_encoder_3') &&
11
+ getComponentClassName(modelIndex?.transformer) === 'SD3Transformer2DModel'
12
+ );
13
+ },
7
14
  tokenizerSpecs: [
8
15
  {
9
16
  modelIndexKey: 'tokenizer',
@@ -66,6 +73,10 @@ const FLUX_LAYOUT = {
66
73
  id: 'flux',
67
74
  requiredComponents: ['transformer', 'text_encoder', 'vae', 'scheduler'],
68
75
  weightedComponents: ['transformer', 'text_encoder', 'vae'],
76
+ matches(modelIndex) {
77
+ const transformerClass = getComponentClassName(modelIndex?.transformer);
78
+ return typeof transformerClass === 'string' && /^Flux/i.test(transformerClass);
79
+ },
69
80
  tokenizerSpecs: [
70
81
  {
71
82
  modelIndexKey: 'tokenizer',
@@ -91,7 +102,39 @@ const FLUX_LAYOUT = {
91
102
  ],
92
103
  };
93
104
 
94
- const LAYOUTS = [SD3_LAYOUT, FLUX_LAYOUT];
105
/**
 * Diffusers layout for Sana checkpoints.
 *
 * Selected only when model_index.json lists every required component AND the
 * transformer is a `SanaTransformer2DModel` paired with a `Gemma2Model` text
 * encoder. The tokenizer is shipped as a bundled tokenizer whose files are
 * copied under `tokenizer_`-prefixed target names.
 */
const SANA_LAYOUT = {
  id: 'sana',
  requiredComponents: ['transformer', 'text_encoder', 'tokenizer', 'vae', 'scheduler'],
  weightedComponents: ['transformer', 'text_encoder', 'vae'],
  matches(modelIndex) {
    const transformerClass = getComponentClassName(modelIndex?.transformer);
    if (transformerClass !== 'SanaTransformer2DModel') {
      return false;
    }
    // Only consulted once the transformer class has already matched.
    return getComponentClassName(modelIndex?.text_encoder) === 'Gemma2Model';
  },
  tokenizerSpecs: [
    {
      modelIndexKey: 'tokenizer',
      componentId: 'text_encoder',
      type: 'bundled',
      assets: [
        {
          suffix: 'tokenizer/tokenizer.json',
          targetName: 'tokenizer_tokenizer.json',
          kind: 'text',
          required: true,
        },
        {
          suffix: 'tokenizer/tokenizer_config.json',
          targetName: 'tokenizer_config.json',
          kind: 'text',
          required: false,
        },
        {
          suffix: 'tokenizer/special_tokens_map.json',
          targetName: 'tokenizer_special_tokens_map.json',
          kind: 'text',
          required: false,
        },
        {
          suffix: 'tokenizer/tokenizer.model',
          targetName: 'tokenizer_tokenizer.model',
          kind: 'binary',
          required: false,
        },
      ],
      config: {
        type: 'bundled',
        tokenizerFile: 'tokenizer_tokenizer.json',
        configFile: 'tokenizer_config.json',
        specialTokensFile: 'tokenizer_special_tokens_map.json',
        sentencePieceFile: 'tokenizer_tokenizer.model',
      },
    },
  ],
};
136
+
137
// Candidate layouts, tried in this order by detectDiffusionLayout: the first
// layout whose required components are present and whose `matches` predicate
// (when defined) accepts the model index is returned, so order matters.
const LAYOUTS = [
  SD3_LAYOUT,
  FLUX_LAYOUT,
  SANA_LAYOUT,
];
95
138
 
96
139
  function toAbortError(message = 'Cancelled') {
97
140
  if (typeof DOMException === 'function') {
@@ -112,12 +155,26 @@ function listModelComponents(modelIndex) {
112
155
  return Object.keys(modelIndex || {}).filter((key) => !key.startsWith('_'));
113
156
  }
114
157
 
158
/**
 * Read the class name recorded for a component entry in model_index.json.
 *
 * Two shapes are recognized:
 *   - a tuple whose second element is the class name, e.g. `['diffusers', 'X']`;
 *   - an object carrying a string `_class_name`.
 * Any other value (including a tuple without a string class) yields `null`.
 *
 * @param {unknown} componentEntry
 * @returns {string | null}
 */
function getComponentClassName(componentEntry) {
  const tupleClassName = Array.isArray(componentEntry) && componentEntry.length >= 2
    ? componentEntry[1]
    : undefined;
  if (typeof tupleClassName === 'string') {
    return tupleClassName;
  }

  const objectClassName = componentEntry != null && typeof componentEntry === 'object'
    ? componentEntry._class_name
    : undefined;
  return typeof objectClassName === 'string' ? objectClassName : null;
}
167
+
115
168
  export function detectDiffusionLayout(modelIndex) {
116
169
  const components = new Set(listModelComponents(modelIndex));
117
170
  for (const layout of LAYOUTS) {
118
- if (layout.requiredComponents.every((component) => components.has(component))) {
119
- return layout;
171
+ if (!layout.requiredComponents.every((component) => components.has(component))) {
172
+ continue;
120
173
  }
174
+ if (typeof layout.matches === 'function' && !layout.matches(modelIndex, components)) {
175
+ continue;
176
+ }
177
+ return layout;
121
178
  }
122
179
  const listed = [...components].sort().join(', ') || '(none)';
123
180
  const expected = LAYOUTS
@@ -199,6 +256,9 @@ export async function parseDiffusionModel(adapter) {
199
256
  const tensors = [];
200
257
 
201
258
  for (const componentId of layout.requiredComponents) {
259
+ if (componentId === 'tokenizer') {
260
+ continue;
261
+ }
202
262
  const configSuffix = defaultConfigPath(componentId);
203
263
  const config = await readJson(configSuffix, `${componentId} config`);
204
264
  if (componentId === 'transformer' && config && !config.weight_format) {
@@ -6,10 +6,26 @@ function asObject(value) {
6
6
  }
7
7
 
8
8
  function asFiniteNumber(value) {
9
+ if (value == null || value === '') {
10
+ return null;
11
+ }
9
12
  const parsed = Number(value);
10
13
  return Number.isFinite(parsed) ? parsed : null;
11
14
  }
12
15
 
16
+ function asBoolean(value) {
17
+ return typeof value === 'boolean' ? value : null;
18
+ }
19
+
20
+ function asNumberArray(value) {
21
+ if (!Array.isArray(value)) return null;
22
+ const normalized = value.map((entry) => asFiniteNumber(entry));
23
+ if (normalized.some((entry) => entry == null || entry <= 0)) {
24
+ return null;
25
+ }
26
+ return normalized.map((entry) => Math.trunc(entry));
27
+ }
28
+
13
29
  function normalizeRoPEType(value) {
14
30
  if (typeof value !== 'string') return null;
15
31
  const normalized = value.trim().toLowerCase();
@@ -125,6 +141,13 @@ function failOnConflictingScaling(sourceLabel, canonicalScaling, candidateScalin
125
141
  export function buildRoPEConfig(presetInference, config) {
126
142
  const ropeScaling = asObject(config.rope_scaling);
127
143
  const ropeParameters = asObject(config.rope_parameters);
144
+ const flatRoPEParameters = (
145
+ ropeParameters
146
+ && !asObject(ropeParameters.full_attention)
147
+ && !asObject(ropeParameters.sliding_attention)
148
+ )
149
+ ? ropeParameters
150
+ : null;
128
151
  const fullAttentionRoPE = asObject(ropeParameters?.full_attention);
129
152
  const slidingAttentionRoPE = asObject(ropeParameters?.sliding_attention);
130
153
  const presetRoPE = presetInference.rope ?? {};
@@ -164,6 +187,11 @@ export function buildRoPEConfig(presetInference, config) {
164
187
  strictMissingTypeAndFactor: false,
165
188
  sourceLabel: 'HF config rope_parameters.full_attention',
166
189
  });
190
+ } else if (flatRoPEParameters) {
191
+ globalScaling = resolveScalingConfig(flatRoPEParameters, {
192
+ strictMissingTypeAndFactor: false,
193
+ sourceLabel: 'HF config rope_parameters',
194
+ });
167
195
  }
168
196
 
169
197
  const hasPresetLocalScaling = presetRoPE.ropeLocalScalingType !== undefined
@@ -192,6 +220,7 @@ export function buildRoPEConfig(presetInference, config) {
192
220
  // HF config is source of truth for ropeTheta when provided:
193
221
  // prefer rope_parameters.full_attention.rope_theta, then rope_theta.
194
222
  const ropeTheta = asFiniteNumber(fullAttentionRoPE?.rope_theta)
223
+ ?? asFiniteNumber(flatRoPEParameters?.rope_theta)
195
224
  ?? asFiniteNumber(config.rope_theta)
196
225
  ?? presetInference.rope?.ropeTheta
197
226
  ?? 10000;
@@ -201,9 +230,22 @@ export function buildRoPEConfig(presetInference, config) {
201
230
  ?? presetInference.rope?.ropeLocalTheta
202
231
  ?? null;
203
232
 
233
+ const mropeInterleaved = asBoolean(flatRoPEParameters?.mrope_interleaved)
234
+ ?? presetInference.rope?.mropeInterleaved
235
+ ?? false;
236
+ const mropeSection = asNumberArray(flatRoPEParameters?.mrope_section)
237
+ ?? presetInference.rope?.mropeSection
238
+ ?? null;
239
+ const partialRotaryFactor = asFiniteNumber(flatRoPEParameters?.partial_rotary_factor)
240
+ ?? asFiniteNumber(presetInference.rope?.partialRotaryFactor)
241
+ ?? null;
242
+
204
243
  return {
205
244
  ropeTheta,
206
245
  ropeLocalTheta,
246
+ mropeInterleaved,
247
+ mropeSection,
248
+ partialRotaryFactor,
207
249
  ropeScalingType: globalScaling.ropeScalingType,
208
250
  ropeScalingFactor: globalScaling.ropeScalingFactor,
209
251
  yarnBetaFast: globalScaling.yarnBetaFast,
package/src/gpu/device.js CHANGED
@@ -28,6 +28,62 @@ function advanceDeviceEpoch() {
28
28
  deviceEpoch += 1;
29
29
  }
30
30
 
31
/**
 * Decide whether a bind-group `buffer` slot holds something usable as a GPUBuffer.
 *
 * Falsy values are always rejected. When the runtime exposes no global
 * `GPUBuffer` constructor (so `instanceof` cannot be checked), any truthy
 * value is accepted.
 *
 * @param {unknown} value
 * @returns {boolean}
 */
function isValidGPUBuffer(value) {
  if (value) {
    return typeof GPUBuffer === 'undefined' || value instanceof GPUBuffer;
  }
  return false;
}
40
+
41
/**
 * Produce a short human-readable type label for a bind-group `buffer` value,
 * used in validation error messages.
 *
 * Returns 'null' / 'undefined' for missing values, 'GPUBuffer' for real
 * buffers (when the global constructor exists), the constructor name for other
 * objects (or 'object' if there is none), and `typeof` for everything else.
 *
 * @param {unknown} value
 * @returns {string}
 */
function describeBindGroupBufferValue(value) {
  if (value == null) {
    // String(null) === 'null', String(undefined) === 'undefined'.
    return String(value);
  }
  const isGpuBuffer = typeof GPUBuffer !== 'undefined' && value instanceof GPUBuffer;
  if (isGpuBuffer) {
    return 'GPUBuffer';
  }
  if (typeof value !== 'object') {
    return typeof value;
  }
  return value.constructor?.name || 'object';
}
50
+
51
/**
 * Fail fast when a bind-group entry declares a buffer binding whose `buffer`
 * is not a valid GPUBuffer (for example `null` or `undefined`).
 *
 * Only entries whose `resource` is an object with a `buffer` key are checked;
 * other resources (samplers, texture views, ...) are skipped. A descriptor whose
 * `entries` is not an array is treated as having no entries. The first offending
 * entry, in array order, is reported.
 *
 * @param {{ label?: string, entries?: unknown } | null | undefined} descriptor
 * @throws {Error} `[<label>] binding <n> requires a GPUBuffer; got <type>.`
 */
function validateBindGroupDescriptor(descriptor) {
  const label = descriptor?.label || 'unlabeled_bind_group';
  const entryList = Array.isArray(descriptor?.entries) ? descriptor.entries : [];

  for (const { binding, resource } of entryList.map((entry) => ({
    binding: entry?.binding,
    resource: entry?.resource,
  }))) {
    const declaresBuffer = Boolean(resource) && typeof resource === 'object' && 'buffer' in resource;
    if (!declaresBuffer || isValidGPUBuffer(resource.buffer)) {
      continue;
    }
    const received = describeBindGroupBufferValue(resource.buffer);
    throw new Error(
      `[${label}] binding ${binding} requires a GPUBuffer; ` +
      `got ${received}.`
    );
  }
}
68
+
69
/**
 * Install bind-group validation on a device's `createBindGroup`.
 *
 * Replaces `device.createBindGroup` with a wrapper that runs
 * `validateBindGroupDescriptor` before delegating to the original method
 * (bound to the device). The call is idempotent: a non-enumerable
 * `__dopplerBindGroupValidationWrapped` marker prevents double wrapping.
 *
 * Falsy devices, and devices without a callable `createBindGroup` (such as
 * hand-built test doubles), are returned untouched instead of crashing on
 * `.bind` of an undefined method.
 *
 * @template T
 * @param {T} device
 * @returns {T} The same device object that was passed in.
 */
function wrapDeviceCreateBindGroup(device) {
  if (!device || device.__dopplerBindGroupValidationWrapped) {
    return device;
  }
  if (typeof device.createBindGroup !== 'function') {
    return device;
  }
  const originalCreateBindGroup = device.createBindGroup.bind(device);
  device.createBindGroup = (descriptor) => {
    validateBindGroupDescriptor(descriptor);
    return originalCreateBindGroup(descriptor);
  };
  Object.defineProperty(device, '__dopplerBindGroupValidationWrapped', {
    value: true,
    configurable: true,
    enumerable: false,
    writable: false,
  });
  return device;
}
86
+
31
87
 
32
88
  export const FEATURES = ({
33
89
  SHADER_F16: 'shader-f16',
@@ -201,6 +257,7 @@ export async function initDevice() {
201
257
  if (!gpuDevice) {
202
258
  throw createDopplerError(ERROR_CODES.GPU_DEVICE_FAILED, 'Failed to create WebGPU device');
203
259
  }
260
+ wrapDeviceCreateBindGroup(gpuDevice);
204
261
  advanceDeviceEpoch();
205
262
 
206
263
  // Set up device lost handler
@@ -253,6 +310,7 @@ export function setDevice(device, options = {}) {
253
310
  }
254
311
 
255
312
  gpuDevice = device;
313
+ wrapDeviceCreateBindGroup(gpuDevice);
256
314
  advanceDeviceEpoch();
257
315
  wrapQueueForTracking(gpuDevice.queue);
258
316