@simulatte/doppler 0.1.4 → 0.1.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (199)
  1. package/README.md +26 -10
  2. package/package.json +30 -6
  3. package/src/client/doppler-api.browser.d.ts +1 -0
  4. package/src/client/doppler-api.browser.js +288 -0
  5. package/src/client/doppler-api.js +1 -1
  6. package/src/client/doppler-provider/types.js +1 -1
  7. package/src/config/execution-contract-check.d.ts +33 -0
  8. package/src/config/execution-contract-check.js +72 -0
  9. package/src/config/execution-v0-contract-check.d.ts +94 -0
  10. package/src/config/execution-v0-contract-check.js +251 -0
  11. package/src/config/execution-v0-graph-contract-check.d.ts +20 -0
  12. package/src/config/execution-v0-graph-contract-check.js +64 -0
  13. package/src/config/kernel-path-contract-check.d.ts +76 -0
  14. package/src/config/kernel-path-contract-check.js +479 -0
  15. package/src/config/kernel-path-loader.d.ts +16 -0
  16. package/src/config/kernel-path-loader.js +54 -0
  17. package/src/config/kernels/kernel-ref-digests.js +39 -27
  18. package/src/config/kernels/registry.json +598 -2
  19. package/src/config/loader.js +81 -48
  20. package/src/config/merge-contract-check.d.ts +16 -0
  21. package/src/config/merge-contract-check.js +321 -0
  22. package/src/config/merge-helpers.d.ts +58 -0
  23. package/src/config/merge-helpers.js +54 -0
  24. package/src/config/merge.js +21 -6
  25. package/src/config/presets/models/janus-text.json +2 -0
  26. package/src/config/presets/models/qwen3.json +9 -2
  27. package/src/config/presets/models/transformer.json +5 -0
  28. package/src/config/quantization-contract-check.d.ts +12 -0
  29. package/src/config/quantization-contract-check.js +91 -0
  30. package/src/config/required-inference-fields-contract-check.d.ts +24 -0
  31. package/src/config/required-inference-fields-contract-check.js +237 -0
  32. package/src/config/schema/browser-suite-metrics.schema.d.ts +17 -0
  33. package/src/config/schema/browser-suite-metrics.schema.js +46 -0
  34. package/src/config/schema/conversion-report.schema.d.ts +40 -0
  35. package/src/config/schema/conversion-report.schema.js +108 -0
  36. package/src/config/schema/doppler.schema.js +12 -18
  37. package/src/config/schema/index.d.ts +22 -0
  38. package/src/config/schema/index.js +18 -0
  39. package/src/config/schema/inference-defaults.schema.js +3 -0
  40. package/src/config/schema/inference.schema.d.ts +9 -0
  41. package/src/config/schema/kernel-path.schema.d.ts +6 -0
  42. package/src/config/schema/manifest.schema.d.ts +6 -0
  43. package/src/config/schema/manifest.schema.js +3 -0
  44. package/src/converter/core.d.ts +10 -0
  45. package/src/converter/core.js +27 -2
  46. package/src/converter/parsers/diffusion.js +63 -3
  47. package/src/converter/rope-config.js +42 -0
  48. package/src/gpu/device.js +58 -0
  49. package/src/gpu/kernels/attention.js +98 -0
  50. package/src/gpu/kernels/bias_add.wgsl +8 -6
  51. package/src/gpu/kernels/bias_add_f16.wgsl +8 -5
  52. package/src/gpu/kernels/conv2d.js +1 -1
  53. package/src/gpu/kernels/conv2d.wgsl +7 -8
  54. package/src/gpu/kernels/conv2d_f16.wgsl +7 -8
  55. package/src/gpu/kernels/depthwise_conv2d.d.ts +29 -0
  56. package/src/gpu/kernels/depthwise_conv2d.js +99 -0
  57. package/src/gpu/kernels/depthwise_conv2d.wgsl +55 -0
  58. package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +59 -0
  59. package/src/gpu/kernels/grouped_pointwise_conv2d.d.ts +27 -0
  60. package/src/gpu/kernels/grouped_pointwise_conv2d.js +93 -0
  61. package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +44 -0
  62. package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +48 -0
  63. package/src/gpu/kernels/index.d.ts +30 -0
  64. package/src/gpu/kernels/index.js +25 -0
  65. package/src/gpu/kernels/matmul.js +25 -0
  66. package/src/gpu/kernels/pixel_shuffle.js +1 -1
  67. package/src/gpu/kernels/pixel_shuffle.wgsl +4 -5
  68. package/src/gpu/kernels/pixel_shuffle_f16.wgsl +4 -5
  69. package/src/gpu/kernels/relu.d.ts +18 -0
  70. package/src/gpu/kernels/relu.js +58 -0
  71. package/src/gpu/kernels/relu.wgsl +22 -0
  72. package/src/gpu/kernels/relu_f16.wgsl +24 -0
  73. package/src/gpu/kernels/repeat_channels.d.ts +21 -0
  74. package/src/gpu/kernels/repeat_channels.js +60 -0
  75. package/src/gpu/kernels/repeat_channels.wgsl +28 -0
  76. package/src/gpu/kernels/repeat_channels_f16.wgsl +30 -0
  77. package/src/gpu/kernels/residual.js +44 -8
  78. package/src/gpu/kernels/residual.wgsl +6 -3
  79. package/src/gpu/kernels/residual_f16.wgsl +2 -1
  80. package/src/gpu/kernels/residual_f16_vec4.wgsl +2 -1
  81. package/src/gpu/kernels/residual_vec4.wgsl +2 -1
  82. package/src/gpu/kernels/rmsnorm.js +58 -6
  83. package/src/gpu/kernels/rmsnorm.wgsl +14 -6
  84. package/src/gpu/kernels/rmsnorm_f16.wgsl +10 -2
  85. package/src/gpu/kernels/rope.d.ts +2 -0
  86. package/src/gpu/kernels/rope.js +11 -1
  87. package/src/gpu/kernels/rope.wgsl +56 -40
  88. package/src/gpu/kernels/sana_linear_attention.d.ts +27 -0
  89. package/src/gpu/kernels/sana_linear_attention.js +121 -0
  90. package/src/gpu/kernels/sana_linear_attention_apply.wgsl +43 -0
  91. package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +46 -0
  92. package/src/gpu/kernels/sana_linear_attention_summary.wgsl +51 -0
  93. package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +53 -0
  94. package/src/gpu/kernels/silu.d.ts +1 -0
  95. package/src/gpu/kernels/silu.js +32 -14
  96. package/src/gpu/kernels/silu.wgsl +19 -9
  97. package/src/gpu/kernels/silu_f16.wgsl +19 -9
  98. package/src/gpu/kernels/transpose.js +15 -2
  99. package/src/gpu/kernels/transpose.wgsl +5 -6
  100. package/src/gpu/kernels/upsample2d.js +2 -1
  101. package/src/gpu/kernels/upsample2d.wgsl +6 -9
  102. package/src/gpu/kernels/upsample2d_f16.wgsl +6 -9
  103. package/src/gpu/kernels/utils.js +16 -1
  104. package/src/index-browser.d.ts +1 -1
  105. package/src/index-browser.js +2 -2
  106. package/src/index.js +1 -1
  107. package/src/inference/browser-harness.js +109 -23
  108. package/src/inference/pipelines/diffusion/init.js +14 -0
  109. package/src/inference/pipelines/diffusion/pipeline.js +215 -77
  110. package/src/inference/pipelines/diffusion/sana-transformer.d.ts +53 -0
  111. package/src/inference/pipelines/diffusion/sana-transformer.js +738 -0
  112. package/src/inference/pipelines/diffusion/scheduler.d.ts +17 -1
  113. package/src/inference/pipelines/diffusion/scheduler.js +91 -3
  114. package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +11 -4
  115. package/src/inference/pipelines/diffusion/text-encoder-gpu.js +282 -0
  116. package/src/inference/pipelines/diffusion/text-encoder.js +18 -1
  117. package/src/inference/pipelines/diffusion/types.d.ts +4 -0
  118. package/src/inference/pipelines/diffusion/vae.js +782 -78
  119. package/src/inference/pipelines/text/attention/record.js +11 -2
  120. package/src/inference/pipelines/text/attention/run.js +11 -2
  121. package/src/inference/pipelines/text/chat-format.js +25 -1
  122. package/src/inference/pipelines/text/config.d.ts +9 -0
  123. package/src/inference/pipelines/text/config.js +69 -2
  124. package/src/inference/pipelines/text/execution-plan.js +23 -31
  125. package/src/inference/pipelines/text/execution-v0.js +43 -95
  126. package/src/inference/pipelines/text/ffn/standard.js +3 -0
  127. package/src/inference/pipelines/text/init.d.ts +4 -0
  128. package/src/inference/pipelines/text/init.js +56 -9
  129. package/src/inference/pipelines/text/layer.js +11 -0
  130. package/src/inference/pipelines/text.js +4 -0
  131. package/src/inference/tokenizers/bundled.js +156 -33
  132. package/src/rules/execution-rules-contract-check.d.ts +17 -0
  133. package/src/rules/execution-rules-contract-check.js +245 -0
  134. package/src/rules/kernels/depthwise-conv2d.rules.json +6 -0
  135. package/src/rules/kernels/grouped-pointwise-conv2d.rules.json +6 -0
  136. package/src/rules/kernels/relu.rules.json +6 -0
  137. package/src/rules/kernels/repeat-channels.rules.json +6 -0
  138. package/src/rules/kernels/sana-linear-attention.rules.json +6 -0
  139. package/src/rules/layer-pattern-contract-check.d.ts +17 -0
  140. package/src/rules/layer-pattern-contract-check.js +231 -0
  141. package/src/rules/rule-registry.d.ts +28 -0
  142. package/src/rules/rule-registry.js +38 -0
  143. package/src/rules/tooling/command-runtime.rules.json +18 -0
  144. package/src/tooling/command-api.d.ts +27 -1
  145. package/src/tooling/command-api.js +142 -3
  146. package/src/tooling/conversion-config-materializer.d.ts +24 -0
  147. package/src/tooling/conversion-config-materializer.js +99 -0
  148. package/src/tooling/lean-execution-contract-runner.d.ts +43 -0
  149. package/src/tooling/lean-execution-contract-runner.js +158 -0
  150. package/src/tooling/node-browser-command-runner.d.ts +4 -0
  151. package/src/tooling/node-browser-command-runner.js +58 -3
  152. package/src/tooling/node-command-runner.js +15 -0
  153. package/src/tooling/node-convert.d.ts +10 -0
  154. package/src/tooling/node-converter.js +59 -0
  155. package/src/tooling/node-webgpu.js +11 -89
  156. package/src/training/checkpoint-watch.d.ts +7 -0
  157. package/src/training/checkpoint-watch.js +106 -0
  158. package/src/training/checkpoint.d.ts +6 -1
  159. package/src/training/checkpoint.js +12 -2
  160. package/src/training/distillation/artifacts.d.ts +71 -0
  161. package/src/training/distillation/artifacts.js +132 -0
  162. package/src/training/distillation/checkpoint-watch.d.ts +10 -0
  163. package/src/training/distillation/checkpoint-watch.js +57 -0
  164. package/src/training/distillation/dataset.d.ts +59 -0
  165. package/src/training/distillation/dataset.js +337 -0
  166. package/src/training/distillation/eval.d.ts +34 -0
  167. package/src/training/distillation/eval.js +310 -0
  168. package/src/training/distillation/index.d.ts +29 -0
  169. package/src/training/distillation/index.js +29 -0
  170. package/src/training/distillation/runtime.d.ts +20 -0
  171. package/src/training/distillation/runtime.js +121 -0
  172. package/src/training/distillation/scoreboard.d.ts +6 -0
  173. package/src/training/distillation/scoreboard.js +8 -0
  174. package/src/training/distillation/stage-a.d.ts +45 -0
  175. package/src/training/distillation/stage-a.js +338 -0
  176. package/src/training/distillation/stage-b.d.ts +24 -0
  177. package/src/training/distillation/stage-b.js +20 -0
  178. package/src/training/index.d.ts +10 -0
  179. package/src/training/index.js +10 -0
  180. package/src/training/lora-pipeline.d.ts +40 -0
  181. package/src/training/lora-pipeline.js +796 -0
  182. package/src/training/operator-artifacts.d.ts +62 -0
  183. package/src/training/operator-artifacts.js +140 -0
  184. package/src/training/operator-command.d.ts +5 -0
  185. package/src/training/operator-command.js +453 -0
  186. package/src/training/operator-eval.d.ts +48 -0
  187. package/src/training/operator-eval.js +230 -0
  188. package/src/training/operator-scoreboard.d.ts +5 -0
  189. package/src/training/operator-scoreboard.js +44 -0
  190. package/src/training/runner.d.ts +52 -0
  191. package/src/training/runner.js +29 -4
  192. package/src/training/suite.d.ts +112 -0
  193. package/src/training/suite.js +9 -9
  194. package/src/training/workloads.d.ts +164 -0
  195. package/src/training/workloads.js +539 -0
  196. package/src/version.d.ts +2 -0
  197. package/src/version.js +2 -0
  198. package/tools/convert-safetensors-node.js +47 -0
  199. package/tools/doppler-cli.js +252 -41
@@ -0,0 +1,231 @@
1
+ import { selectByRules } from '../gpu/kernels/rule-matcher.js';
2
+ import { computeGlobalLayers } from '../config/schema/inference.schema.js';
3
+
4
/**
 * Reports whether `value` is a non-null object that is not an array.
 *
 * Note that this accepts any non-array object (not only object literals).
 *
 * @param {unknown} value Candidate value.
 * @returns {boolean} `false` for `null`, `undefined`, primitives, functions and arrays.
 */
function isPlainObject(value) {
  return value != null && typeof value === 'object' && !Array.isArray(value);
}

/**
 * Compares one field value against its expected counterpart.
 *
 * Arrays are compared element by element, objects via `matchesExactObject`,
 * and everything else with strict equality.
 *
 * @param {unknown} actualValue Value found in the data under test.
 * @param {unknown} expectedValue Value from the expected table.
 * @returns {boolean} `true` when the two values are structurally equal.
 */
function matchesExactValue(actualValue, expectedValue) {
  if (Array.isArray(expectedValue)) {
    // Two separately constructed arrays are never `===`, so an array-valued
    // field has to be compared structurally or it could never match.
    return (
      actualValue === expectedValue ||
      (Array.isArray(actualValue) &&
        actualValue.length === expectedValue.length &&
        expectedValue.every((entry, index) => matchesExactValue(actualValue[index], entry)))
    );
  }
  if (isPlainObject(expectedValue)) {
    return matchesExactObject(actualValue, expectedValue);
  }
  return actualValue === expectedValue;
}

/**
 * Deep-compares `actual` with `expected`, requiring the exact same set of own
 * enumerable keys at every object level (no missing and no extra keys).
 *
 * @param {unknown} actual Object under test.
 * @param {unknown} expected Reference object.
 * @returns {boolean} `true` only when both arguments are non-array objects
 *   with identical key sets and structurally equal values.
 */
function matchesExactObject(actual, expected) {
  if (!isPlainObject(actual) || !isPlainObject(expected)) {
    return false;
  }
  const actualKeys = new Set(Object.keys(actual));
  const expectedKeys = Object.keys(expected);
  // Same number of keys plus "every expected key is present" means the key
  // sets are identical.
  if (actualKeys.size !== expectedKeys.length) {
    return false;
  }
  return expectedKeys.every(
    (key) => actualKeys.has(key) && matchesExactValue(actual[key], expected[key])
  );
}
37
+
38
/**
 * Reference implementation of the `patternKind` decision table.
 *
 * @param {{ patternType?: unknown, globalPattern?: unknown }} context
 *   Inputs the rule set is evaluated against.
 * @returns {'alternating_even' | 'alternating_odd' | 'every_n' | null}
 *   The pattern kind for the context, or `null` when no kind applies.
 */
function expectedPatternKind(context) {
  const { patternType, globalPattern } = context;
  if (patternType === 'every_n') {
    // `every_n` does not depend on globalPattern.
    return 'every_n';
  }
  if (patternType !== 'alternating') {
    return null;
  }
  switch (globalPattern) {
    case 'even':
      return 'alternating_even';
    case 'odd':
      return 'alternating_odd';
    default:
      return null;
  }
}
50
+
51
/**
 * Reference implementation of the `layerType` decision table.
 *
 * @param {{ patternKind?: unknown, isEven?: unknown, isStride?: unknown }} context
 *   Inputs the rule set is evaluated against.
 * @returns {'full_attention' | 'sliding_attention' | null}
 *   The attention type for the layer, or `null` for an unrecognised pattern kind.
 */
function expectedLayerType(context) {
  let usesFullAttention;
  switch (context.patternKind) {
    case 'alternating_even':
      usesFullAttention = context.isEven;
      break;
    case 'alternating_odd':
      usesFullAttention = !context.isEven;
      break;
    case 'every_n':
      usesFullAttention = context.isStride;
      break;
    default:
      return null;
  }
  return usesFullAttention ? 'full_attention' : 'sliding_attention';
}
63
+
64
/**
 * Builds every combination of pattern type and global pattern used to
 * exercise the `patternKind` rules.
 *
 * @returns {Array<{ patternType: string | null, globalPattern: string | null }>}
 *   Sixteen contexts, ordered by pattern type and then by global pattern.
 */
function enumeratePatternKindContexts() {
  const globalPatterns = ['even', 'odd', 'every_n', null];
  return ['alternating', 'every_n', 'custom', null].flatMap((patternType) =>
    globalPatterns.map((globalPattern) => ({ patternType, globalPattern }))
  );
}
75
+
76
/**
 * Builds every combination of pattern kind, `isEven` and `isStride` used to
 * exercise the `layerType` rules.
 *
 * @returns {Array<{ patternKind: string, isEven: boolean, isStride: boolean }>}
 *   Twelve contexts, ordered by pattern kind, then `isEven`, then `isStride`
 *   (with `true` before `false`).
 */
function enumerateLayerTypeContexts() {
  const flags = [true, false];
  return ['alternating_even', 'alternating_odd', 'every_n'].flatMap((patternKind) =>
    flags.flatMap((isEven) => flags.map((isStride) => ({ patternKind, isEven, isStride })))
  );
}
89
+
90
/**
 * Verifies that a rule list is literally the expected decision table: same
 * length, and for each position the same `match` object and `value`.
 *
 * @param {unknown} rules Rule list under test.
 * @param {Array<{ match: object, value: unknown }>} expected Expected table.
 * @param {string} label Rule-set name used in error messages.
 * @returns {{ ok: boolean, errors: string[] }} At most one error is reported:
 *   not-an-array, wrong length, or the first drifted rule index.
 */
function checkRuleShape(rules, expected, label) {
  const failure = (message) => ({ ok: false, errors: [message] });

  if (!Array.isArray(rules)) {
    return failure(`[LayerPatternContract] ${label} must be an array.`);
  }
  if (rules.length !== expected.length) {
    return failure(
      `[LayerPatternContract] ${label} must contain exactly ${expected.length} rules; got ${rules.length}.`
    );
  }

  // Only the first drifted position is reported.
  const driftedIndex = expected.findIndex(
    (entry, index) =>
      !matchesExactObject(rules[index]?.match, entry.match) || rules[index]?.value !== entry.value
  );
  const errors =
    driftedIndex === -1
      ? []
      : [`[LayerPatternContract] ${label} rule[${driftedIndex}] drifted from the expected decision table.`];
  return { ok: errors.length === 0, errors };
}
115
+
116
/**
 * Evaluates `rules` against each context through `selectByRules` and compares
 * the outcome with the reference function.
 *
 * @param {unknown[]} rules Rule list under test.
 * @param {object[]} contexts Contexts to evaluate.
 * @param {(context: object) => unknown} expectedValue Reference implementation.
 * @param {string} label Rule-set name used in error messages.
 * @returns {{ ok: boolean, errors: string[], sampledContexts: number }}
 *   `errors` holds at most the first mismatch; `sampledContexts` is always
 *   `contexts.length`, even when the scan stopped early.
 */
function checkRuleSemantics(rules, contexts, expectedValue, label) {
  let firstMismatch = null;
  for (const context of contexts) {
    const actual = selectByRules(rules, context);
    const expected = expectedValue(context);
    if (actual === expected) {
      continue;
    }
    firstMismatch =
      `[LayerPatternContract] ${label} mismatched context ${JSON.stringify(context)}: ` +
      `expected ${JSON.stringify(expected)}, got ${JSON.stringify(actual)}.`;
    break;
  }
  const errors = firstMismatch === null ? [] : [firstMismatch];
  return {
    ok: errors.length === 0,
    errors,
    sampledContexts: contexts.length,
  };
}
135
+
136
/**
 * Spot-checks `computeGlobalLayers` against fixed expectations: alternating
 * even/odd over 6 layers, and `every_n` with period 6 over 12 layers using
 * offset 5 and offset -1 (both expected to yield layers 5 and 11).
 *
 * @returns {{ checks: Array<{ id: string, ok: boolean }>, errors: string[] }}
 *   One check entry per case, plus one error message per failing case.
 */
function checkGlobalLayerSemantics() {
  const cases = [
    {
      id: 'inference.layerPattern.computeGlobalLayers.even',
      pattern: { type: 'alternating', globalPattern: 'even' },
      layerCount: 6,
      expected: [0, 2, 4],
    },
    {
      id: 'inference.layerPattern.computeGlobalLayers.odd',
      pattern: { type: 'alternating', globalPattern: 'odd' },
      layerCount: 6,
      expected: [1, 3, 5],
    },
    {
      id: 'inference.layerPattern.computeGlobalLayers.every_n_offset',
      pattern: { type: 'every_n', period: 6, offset: 5 },
      layerCount: 12,
      expected: [5, 11],
    },
    {
      id: 'inference.layerPattern.computeGlobalLayers.every_n_negative_offset',
      pattern: { type: 'every_n', period: 6, offset: -1 },
      layerCount: 12,
      expected: [5, 11],
    },
  ];

  // Run every case before comparing anything.
  const evaluated = cases.map(({ id, pattern, layerCount, expected }) => ({
    id,
    expected,
    actual: computeGlobalLayers(pattern, layerCount),
  }));

  const checks = [];
  const errors = [];
  for (const { id, actual, expected } of evaluated) {
    const actualJson = JSON.stringify(actual);
    const expectedJson = JSON.stringify(expected);
    const ok = actualJson === expectedJson;
    checks.push({ id, ok });
    if (!ok) {
      errors.push(`[LayerPatternContract] ${id} expected ${expectedJson}, got ${actualJson}.`);
    }
  }
  return { checks, errors };
}
175
+
176
/**
 * Builds the layer-pattern contract artifact for a rule group.
 *
 * Runs, in order: the `patternKind` shape check, the `patternKind` semantic
 * check, the `layerType` shape check, the `layerType` semantic check, and the
 * `computeGlobalLayers` spot checks. Errors and check entries are collected in
 * that same order.
 *
 * @param {{ patternKind?: unknown, layerType?: unknown } | null | undefined} ruleGroup
 *   Rule group providing the `patternKind` and `layerType` rule lists.
 * @returns {{
 *   schemaVersion: 1,
 *   source: 'doppler',
 *   ok: boolean,
 *   checks: Array<{ id: string, ok: boolean }>,
 *   errors: string[],
 *   stats: {
 *     patternKindRules: number,
 *     layerTypeRules: number,
 *     patternKindContexts: number,
 *     layerTypeContexts: number,
 *   },
 * }} Artifact whose `ok` is `true` only when no check produced an error.
 */
export function buildLayerPatternContractArtifact(ruleGroup) {
  const patternKindRules = ruleGroup?.patternKind;
  const layerTypeRules = ruleGroup?.layerType;
  const hasPatternKindRules = Array.isArray(patternKindRules);
  const hasLayerTypeRules = Array.isArray(layerTypeRules);

  const expectedPatternKindRules = [
    { match: { patternType: 'alternating', globalPattern: 'even' }, value: 'alternating_even' },
    { match: { patternType: 'alternating', globalPattern: 'odd' }, value: 'alternating_odd' },
    { match: { patternType: 'every_n' }, value: 'every_n' },
    { match: {}, value: null },
  ];
  const expectedLayerTypeRules = [
    { match: { patternKind: 'alternating_even', isEven: true }, value: 'full_attention' },
    { match: { patternKind: 'alternating_even' }, value: 'sliding_attention' },
    { match: { patternKind: 'alternating_odd', isEven: false }, value: 'full_attention' },
    { match: { patternKind: 'alternating_odd' }, value: 'sliding_attention' },
    { match: { patternKind: 'every_n', isStride: true }, value: 'full_attention' },
    { match: { patternKind: 'every_n' }, value: 'sliding_attention' },
  ];

  const errors = [];
  const checks = [];
  // Folds one check outcome into the running error list and check list.
  const record = (id, outcome) => {
    errors.push(...outcome.errors);
    checks.push({ id, ok: outcome.ok });
    return outcome;
  };

  record(
    'inference.layerPattern.patternKind.shape',
    checkRuleShape(patternKindRules, expectedPatternKindRules, 'patternKind')
  );

  // Semantic checks need an actual rule list; otherwise record a failure instead.
  const patternKindSemantics = record(
    'inference.layerPattern.patternKind.semantics',
    hasPatternKindRules
      ? checkRuleSemantics(patternKindRules, enumeratePatternKindContexts(), expectedPatternKind, 'patternKind')
      : {
          ok: false,
          errors: ['[LayerPatternContract] patternKind is unavailable for semantic check.'],
          sampledContexts: 0,
        }
  );

  record(
    'inference.layerPattern.layerType.shape',
    checkRuleShape(layerTypeRules, expectedLayerTypeRules, 'layerType')
  );

  const layerTypeSemantics = record(
    'inference.layerPattern.layerType.semantics',
    hasLayerTypeRules
      ? checkRuleSemantics(layerTypeRules, enumerateLayerTypeContexts(), expectedLayerType, 'layerType')
      : {
          ok: false,
          errors: ['[LayerPatternContract] layerType is unavailable for semantic check.'],
          sampledContexts: 0,
        }
  );

  const globalLayerSemantics = checkGlobalLayerSemantics();
  errors.push(...globalLayerSemantics.errors);
  checks.push(...globalLayerSemantics.checks);

  return {
    schemaVersion: 1,
    source: 'doppler',
    ok: errors.length === 0,
    checks,
    errors,
    stats: {
      patternKindRules: hasPatternKindRules ? patternKindRules.length : 0,
      layerTypeRules: hasLayerTypeRules ? layerTypeRules.length : 0,
      patternKindContexts: patternKindSemantics.sampledContexts,
      layerTypeContexts: layerTypeSemantics.sampledContexts,
    },
  };
}
@@ -46,3 +46,31 @@ export declare function registerRuleGroup(
46
46
  group: RuleGroup,
47
47
  rules: Record<string, RuleSet>
48
48
  ): void;
49
+
50
/**
 * Returns the contract artifact for the `inference.execution` rules.
 *
 * @returns The artifact: overall `ok` flag, per-check results, error
 *   messages, and rule/context counts for the decode-recorder and
 *   batch-decode rule sets.
 */
export declare function getInferenceExecutionRulesContractArtifact(): {
  schemaVersion: 1;
  source: 'doppler';
  ok: boolean;
  checks: { id: string; ok: boolean }[];
  errors: Array<string>;
  stats: {
    decodeRecorderRules: number;
    batchDecodeRules: number;
    decodeRecorderContexts: number;
    batchDecodeContexts: number;
  };
};
63
+
64
/**
 * Returns the contract artifact for the `inference.layerPattern` rules.
 *
 * @returns The artifact: overall `ok` flag, per-check results, error
 *   messages, and rule/context counts for the `patternKind` and `layerType`
 *   rule sets.
 */
export declare function getInferenceLayerPatternContractArtifact(): {
  schemaVersion: 1;
  source: 'doppler';
  ok: boolean;
  checks: { id: string; ok: boolean }[];
  errors: Array<string>;
  stats: {
    patternKindRules: number;
    layerTypeRules: number;
    patternKindContexts: number;
    layerTypeContexts: number;
  };
};
@@ -1,8 +1,11 @@
1
1
  import { selectByRules } from '../gpu/kernels/rule-matcher.js';
2
+ import { buildInferenceExecutionRulesContractArtifact } from './execution-rules-contract-check.js';
3
+ import { buildLayerPatternContractArtifact } from './layer-pattern-contract-check.js';
2
4
  import { loadJson } from '../utils/load-json.js';
3
5
 
4
6
  const attentionRules = await loadJson('./kernels/attention.rules.json', import.meta.url, 'Failed to load rules');
5
7
  const conv2dRules = await loadJson('./kernels/conv2d.rules.json', import.meta.url, 'Failed to load rules');
8
+ const depthwiseConv2dRules = await loadJson('./kernels/depthwise-conv2d.rules.json', import.meta.url, 'Failed to load rules');
6
9
  const dequantRules = await loadJson('./kernels/dequant.rules.json', import.meta.url, 'Failed to load rules');
7
10
  const energyRules = await loadJson('./kernels/energy.rules.json', import.meta.url, 'Failed to load rules');
8
11
  const fusedFfnRules = await loadJson('./kernels/fused-ffn.rules.json', import.meta.url, 'Failed to load rules');
@@ -10,6 +13,7 @@ const fusedMatmulResidualRules = await loadJson('./kernels/fused-matmul-residual
10
13
  const fusedMatmulRmsnormRules = await loadJson('./kernels/fused-matmul-rmsnorm.rules.json', import.meta.url, 'Failed to load rules');
11
14
  const gatherRules = await loadJson('./kernels/gather.rules.json', import.meta.url, 'Failed to load rules');
12
15
  const geluRules = await loadJson('./kernels/gelu.rules.json', import.meta.url, 'Failed to load rules');
16
+ const groupedPointwiseConv2dRules = await loadJson('./kernels/grouped-pointwise-conv2d.rules.json', import.meta.url, 'Failed to load rules');
13
17
  const groupnormRules = await loadJson('./kernels/groupnorm.rules.json', import.meta.url, 'Failed to load rules');
14
18
  const kvQuantizeRules = await loadJson('./kernels/kv_quantize.rules.json', import.meta.url, 'Failed to load rules');
15
19
  const layernormRules = await loadJson('./kernels/layernorm.rules.json', import.meta.url, 'Failed to load rules');
@@ -18,9 +22,12 @@ const kernelMoeRules = await loadJson('./kernels/moe.rules.json', import.meta.ur
18
22
  const kernelMoeGptOssRules = await loadJson('./kernels/moe.rules.gptoss.json', import.meta.url, 'Failed to load rules');
19
23
  const modulateRules = await loadJson('./kernels/modulate.rules.json', import.meta.url, 'Failed to load rules');
20
24
  const pixelShuffleRules = await loadJson('./kernels/pixel_shuffle.rules.json', import.meta.url, 'Failed to load rules');
25
+ const repeatChannelsRules = await loadJson('./kernels/repeat-channels.rules.json', import.meta.url, 'Failed to load rules');
26
+ const reluRules = await loadJson('./kernels/relu.rules.json', import.meta.url, 'Failed to load rules');
21
27
  const residualRules = await loadJson('./kernels/residual.rules.json', import.meta.url, 'Failed to load rules');
22
28
  const rmsnormRules = await loadJson('./kernels/rmsnorm.rules.json', import.meta.url, 'Failed to load rules');
23
29
  const ropeRules = await loadJson('./kernels/rope.rules.json', import.meta.url, 'Failed to load rules');
30
+ const sanaLinearAttentionRules = await loadJson('./kernels/sana-linear-attention.rules.json', import.meta.url, 'Failed to load rules');
24
31
  const sampleRules = await loadJson('./kernels/sample.rules.json', import.meta.url, 'Failed to load rules');
25
32
  const scaleRules = await loadJson('./kernels/scale.rules.json', import.meta.url, 'Failed to load rules');
26
33
  const siluRules = await loadJson('./kernels/silu.rules.json', import.meta.url, 'Failed to load rules');
@@ -46,6 +53,24 @@ const toolingCommandRuntimeRules = await loadJson(
46
53
  import.meta.url,
47
54
  'Failed to load rules'
48
55
  );
56
// Contract gates, evaluated at module top level: each artifact is built once,
// and a failing contract aborts with an error that lists every recorded
// contract error joined by " | ".
const INFERENCE_EXECUTION_RULES_CONTRACT_ARTIFACT =
  buildInferenceExecutionRulesContractArtifact(inferenceExecutionRules);
if (!INFERENCE_EXECUTION_RULES_CONTRACT_ARTIFACT.ok) {
  const executionContractErrors = INFERENCE_EXECUTION_RULES_CONTRACT_ARTIFACT.errors.join(' | ');
  throw new Error(`RuleRegistry: inference.execution rules contract failed: ${executionContractErrors}`);
}

const INFERENCE_LAYER_PATTERN_CONTRACT_ARTIFACT =
  buildLayerPatternContractArtifact(layerPatternRules);
if (!INFERENCE_LAYER_PATTERN_CONTRACT_ARTIFACT.ok) {
  const layerPatternContractErrors = INFERENCE_LAYER_PATTERN_CONTRACT_ARTIFACT.errors.join(' | ');
  throw new Error(`RuleRegistry: inference.layerPattern rules contract failed: ${layerPatternContractErrors}`);
}
49
74
 
50
75
  const RULE_SETS = {
51
76
  shared: {
@@ -54,6 +79,7 @@ const RULE_SETS = {
54
79
  kernels: {
55
80
  attention: attentionRules,
56
81
  conv2d: conv2dRules,
82
+ depthwiseConv2d: depthwiseConv2dRules,
57
83
  dequant: dequantRules,
58
84
  energy: energyRules,
59
85
  fusedFfn: fusedFfnRules,
@@ -61,6 +87,7 @@ const RULE_SETS = {
61
87
  fusedMatmulRmsnorm: fusedMatmulRmsnormRules,
62
88
  gather: gatherRules,
63
89
  gelu: geluRules,
90
+ groupedPointwiseConv2d: groupedPointwiseConv2dRules,
64
91
  groupnorm: groupnormRules,
65
92
  kv_quantize: kvQuantizeRules,
66
93
  layernorm: layernormRules,
@@ -69,9 +96,12 @@ const RULE_SETS = {
69
96
  moeGptoss: kernelMoeGptOssRules,
70
97
  modulate: modulateRules,
71
98
  pixel_shuffle: pixelShuffleRules,
99
+ repeatChannels: repeatChannelsRules,
100
+ relu: reluRules,
72
101
  residual: residualRules,
73
102
  rmsnorm: rmsnormRules,
74
103
  rope: ropeRules,
104
+ sanaLinearAttention: sanaLinearAttentionRules,
75
105
  sample: sampleRules,
76
106
  scale: scaleRules,
77
107
  silu: siluRules,
@@ -133,6 +163,14 @@ export function registerRuleGroup(domain, group, rules) {
133
163
  RULE_SETS[domain][group] = rules;
134
164
  }
135
165
 
166
/**
 * Returns the `inference.execution` rules contract artifact held in the
 * module-level constant. The same object is returned on every call.
 *
 * @returns {object} The stored contract artifact.
 */
export function getInferenceExecutionRulesContractArtifact() {
  const artifact = INFERENCE_EXECUTION_RULES_CONTRACT_ARTIFACT;
  return artifact;
}
169
+
170
/**
 * Returns the `inference.layerPattern` rules contract artifact held in the
 * module-level constant. The same object is returned on every call.
 *
 * @returns {object} The stored contract artifact.
 */
export function getInferenceLayerPatternContractArtifact() {
  const artifact = INFERENCE_LAYER_PATTERN_CONTRACT_ARTIFACT;
  return artifact;
}
173
+
136
174
  function resolveRuleValue(value, context) {
137
175
  if (Array.isArray(value)) {
138
176
  return value.map((entry) => resolveRuleValue(entry, context));
@@ -27,6 +27,24 @@
27
27
  "intent": "verify"
28
28
  }
29
29
  },
30
+ {
31
+ "match": {
32
+ "command": "lora"
33
+ },
34
+ "value": {
35
+ "suite": null,
36
+ "intent": null
37
+ }
38
+ },
39
+ {
40
+ "match": {
41
+ "command": "distill"
42
+ },
43
+ "value": {
44
+ "suite": null,
45
+ "intent": null
46
+ }
47
+ },
30
48
  {
31
49
  "match": {},
32
50
  "value": {
@@ -1,10 +1,12 @@
1
1
  import type { ConverterConfigSchema } from '../config/schema/converter.schema.js';
2
2
 
3
- export type ToolingCommand = 'convert' | 'debug' | 'bench' | 'verify';
3
+ export type ToolingCommand = 'convert' | 'debug' | 'bench' | 'verify' | 'lora' | 'distill';
4
4
  export type ToolingSurface = 'browser' | 'node';
5
5
  export type ToolingSuite = 'kernels' | 'inference' | 'training' | 'bench' | 'debug' | 'diffusion' | 'energy';
6
6
  export type ToolingIntent = 'verify' | 'investigate' | 'calibrate' | null;
7
7
  export type ToolingTrainingStage = 'stage1_joint' | 'stage2_base' | 'stage_a' | 'stage_b';
8
+ export type ToolingDistillAction = 'run' | 'stage-a' | 'stage-b' | 'eval' | 'watch' | 'compare' | 'quality-gate' | 'subsets';
9
+ export type ToolingLoraAction = 'run' | 'eval' | 'watch' | 'export' | 'compare' | 'quality-gate' | 'activate';
8
10
 
9
11
  export interface ToolingConvertExecutionPayload {
10
12
  workers?: number | null;
@@ -25,6 +27,7 @@ export interface ToolingConvertPayload {
25
27
 
26
28
  export interface ToolingCommandRequestInput {
27
29
  command: ToolingCommand;
30
+ action?: ToolingDistillAction | ToolingLoraAction;
28
31
  suite?: ToolingSuite;
29
32
  modelId?: string;
30
33
  trainingTests?: string[];
@@ -65,6 +68,17 @@ export interface ToolingCommandRequestInput {
65
68
  inputDir?: string;
66
69
  outputDir?: string;
67
70
  convertPayload?: ToolingConvertPayload;
71
+ workloadPath?: string;
72
+ runRoot?: string;
73
+ checkpointPath?: string;
74
+ checkpointId?: string;
75
+ checkpointStep?: number;
76
+ stageId?: string;
77
+ stageArtifact?: string;
78
+ subsetManifest?: string;
79
+ evalDatasetId?: string;
80
+ pollIntervalMs?: number;
81
+ stopWhenIdle?: boolean;
68
82
  captureOutput?: boolean;
69
83
  keepPipeline?: boolean;
70
84
  report?: Record<string, unknown> | null;
@@ -76,6 +90,7 @@ export interface ToolingCommandRequest {
76
90
  command: ToolingCommand;
77
91
  suite: ToolingSuite | null;
78
92
  intent: ToolingIntent;
93
+ action: ToolingDistillAction | ToolingLoraAction | null;
79
94
  modelId: string | null;
80
95
  trainingTests: string[] | null;
81
96
  trainingStage: ToolingTrainingStage | null;
@@ -115,6 +130,17 @@ export interface ToolingCommandRequest {
115
130
  inputDir: string | null;
116
131
  outputDir: string | null;
117
132
  convertPayload: ToolingConvertPayload | null;
133
+ workloadPath: string | null;
134
+ runRoot: string | null;
135
+ checkpointPath: string | null;
136
+ checkpointId: string | null;
137
+ checkpointStep: number | null;
138
+ stageId: string | null;
139
+ stageArtifact: string | null;
140
+ subsetManifest: string | null;
141
+ evalDatasetId: string | null;
142
+ pollIntervalMs: number | null;
143
+ stopWhenIdle: boolean | null;
118
144
  captureOutput: boolean;
119
145
  keepPipeline: boolean;
120
146
  report: Record<string, unknown> | null;
@@ -1,12 +1,14 @@
1
1
  import { isPlainObject } from '../utils/plain-object.js';
2
2
  import { selectRuleValue } from '../rules/rule-registry.js';
3
3
 
4
- const TOOLING_COMMAND_SET = ['convert', 'debug', 'bench', 'verify'];
4
+ const TOOLING_COMMAND_SET = ['convert', 'debug', 'bench', 'verify', 'lora', 'distill'];
5
5
  const TOOLING_SURFACE_SET = ['browser', 'node'];
6
6
  const TOOLING_SUITE_SET = ['kernels', 'inference', 'training', 'bench', 'debug', 'diffusion', 'energy'];
7
7
  const TOOLING_INTENT_SET = ['verify', 'investigate', 'calibrate'];
8
8
  const VERIFY_SUITES = ['kernels', 'inference', 'training', 'diffusion', 'energy'];
9
9
  const TRAINING_STAGE_SET = ['stage1_joint', 'stage2_base', 'stage_a', 'stage_b'];
10
+ const DISTILL_ACTION_SET = ['run', 'stage-a', 'stage-b', 'eval', 'watch', 'compare', 'quality-gate', 'subsets'];
11
+ const LORA_ACTION_SET = ['run', 'eval', 'watch', 'export', 'compare', 'quality-gate', 'activate'];
10
12
  const TRAINING_COMMAND_SCHEMA_VERSION = 1;
11
13
 
12
14
  export const TOOLING_COMMANDS = Object.freeze([...TOOLING_COMMAND_SET]);
@@ -82,6 +84,15 @@ function asOptionalForceResumeReason(value, label) {
82
84
  return reason;
83
85
  }
84
86
 
87
/**
 * Validates an optional action name against an allow-list.
 *
 * @param value   Raw input; passed through asOptionalString (defined elsewhere in this file).
 * @param label   Field name forwarded to asOptionalString and used in the error message.
 * @param allowed Collection of accepted action strings; must support `.includes` and `.join`
 *                (call sites in this chunk pass DISTILL_ACTION_SET or LORA_ACTION_SET).
 * @returns The action string when it is in `allowed`, or null when asOptionalString
 *          yields a falsy result.
 * @throws Error when a non-empty action is not contained in `allowed`.
 */
+ function asOptionalAction(value, label, allowed) {
88
+ const action = asOptionalString(value, label);
89
// Absent action is not an error here; callers decide whether it is required.
+ if (!action) return null;
90
+ if (!allowed.includes(action)) {
91
+ throw new Error(`tooling command: ${label} must be one of ${allowed.join(', ')}.`);
92
+ }
93
+ return action;
94
+ }
95
+
85
96
  function assertCommand(value) {
86
97
  const command = asOptionalString(value, 'command');
87
98
  if (!command) {
@@ -246,6 +257,7 @@ function normalizeConvert(raw) {
246
257
  command: 'convert',
247
258
  suite: null,
248
259
  intent: null,
260
+ action: null,
249
261
  modelId: null,
250
262
  trainingTests: null,
251
263
  trainingStage: null,
@@ -285,6 +297,113 @@ function normalizeConvert(raw) {
285
297
  inputDir,
286
298
  outputDir,
287
299
  convertPayload: payload,
300
+ workloadPath: null,
301
+ runRoot: null,
302
+ checkpointPath: null,
303
+ checkpointId: null,
304
+ checkpointStep: null,
305
+ stageId: null,
306
+ stageArtifact: null,
307
+ subsetManifest: null,
308
+ evalDatasetId: null,
309
+ pollIntervalMs: null,
310
+ stopWhenIdle: null,
311
+ captureOutput: false,
312
+ keepPipeline: false,
313
+ report: asOptionalObject(raw.report, 'report'),
314
+ timestamp: raw.timestamp ?? null,
315
+ searchParams: raw.searchParams ?? null,
316
+ };
317
+ }
318
+
319
+ function normalizeTrainingOperatorCommand(raw, command) {
320
+ const allowedActions = command === 'distill' ? DISTILL_ACTION_SET : LORA_ACTION_SET;
321
+ const action = asOptionalAction(raw.action, 'action', allowedActions);
322
+ if (!action) {
323
+ throw new Error(`tooling command: ${command} requires action.`);
324
+ }
325
+ const workloadPath = asOptionalString(raw.workloadPath, 'workloadPath');
326
+ const runRoot = asOptionalString(raw.runRoot, 'runRoot');
327
+ const checkpointPath = asOptionalString(raw.checkpointPath, 'checkpointPath');
328
+ const checkpointId = asOptionalString(raw.checkpointId, 'checkpointId');
329
+ const checkpointStep = asOptionalPositiveInteger(raw.checkpointStep, 'checkpointStep');
330
+ const stageId = asOptionalString(raw.stageId, 'stageId');
331
+ const stageArtifact = asOptionalString(raw.stageArtifact, 'stageArtifact');
332
+ const subsetManifest = asOptionalString(raw.subsetManifest, 'subsetManifest');
333
+ const evalDatasetId = asOptionalString(raw.evalDatasetId, 'evalDatasetId');
334
+ const pollIntervalMs = asOptionalPositiveInteger(raw.pollIntervalMs, 'pollIntervalMs');
335
+ const stopWhenIdle = asOptionalBoolean(raw.stopWhenIdle, 'stopWhenIdle');
336
+ if (!workloadPath && !runRoot) {
337
+ throw new Error(`tooling command: ${command} requires workloadPath or runRoot.`);
338
+ }
339
+ if ((action === 'eval' || action === 'export') && !checkpointPath && !runRoot) {
340
+ throw new Error(`tooling command: ${command} ${action} requires checkpointPath or runRoot.`);
341
+ }
342
+ if (action === 'watch' && !runRoot) {
343
+ throw new Error(`tooling command: ${command} watch requires runRoot.`);
344
+ }
345
+ if ((action === 'compare' || action === 'quality-gate') && !runRoot) {
346
+ throw new Error(`tooling command: ${command} ${action} requires runRoot.`);
347
+ }
348
+ if (command === 'distill' && action === 'stage-b' && !stageArtifact && !runRoot) {
349
+ throw new Error('tooling command: distill stage-b requires stageArtifact or runRoot.');
350
+ }
351
+
352
+ return {
353
+ command,
354
+ suite: null,
355
+ intent: null,
356
+ action,
357
+ modelId: null,
358
+ trainingTests: null,
359
+ trainingStage: null,
360
+ trainingConfig: null,
361
+ stage1Artifact: null,
362
+ stage1ArtifactHash: null,
363
+ ulArtifactDir: null,
364
+ stageAArtifact: null,
365
+ stageAArtifactHash: null,
366
+ distillArtifactDir: null,
367
+ teacherModelId: null,
368
+ studentModelId: null,
369
+ distillDatasetId: null,
370
+ distillDatasetPath: null,
371
+ distillLanguagePair: null,
372
+ distillSourceLangs: null,
373
+ distillTargetLangs: null,
374
+ distillPairAllowlist: null,
375
+ strictPairContract: null,
376
+ distillShardIndex: null,
377
+ distillShardCount: null,
378
+ resumeFrom: null,
379
+ forceResume: null,
380
+ forceResumeReason: null,
381
+ forceResumeSource: null,
382
+ checkpointOperator: null,
383
+ trainingSchemaVersion: null,
384
+ trainingBenchSteps: null,
385
+ checkpointEvery: null,
386
+ workloadType: 'training',
387
+ modelUrl: null,
388
+ cacheMode: asOptionalCacheMode(raw.cacheMode, 'cacheMode'),
389
+ loadMode: asOptionalLoadMode(raw.loadMode, 'loadMode'),
390
+ runtimePreset: asOptionalString(raw.runtimePreset, 'runtimePreset'),
391
+ runtimeConfigUrl: asOptionalString(raw.runtimeConfigUrl, 'runtimeConfigUrl'),
392
+ runtimeConfig: asOptionalObject(raw.runtimeConfig, 'runtimeConfig'),
393
+ inputDir: null,
394
+ outputDir: null,
395
+ convertPayload: null,
396
+ workloadPath,
397
+ runRoot,
398
+ checkpointPath,
399
+ checkpointId,
400
+ checkpointStep,
401
+ stageId,
402
+ stageArtifact,
403
+ subsetManifest,
404
+ evalDatasetId,
405
+ pollIntervalMs,
406
+ stopWhenIdle,
288
407
  captureOutput: false,
289
408
  keepPipeline: false,
290
409
  report: asOptionalObject(raw.report, 'report'),
@@ -428,6 +547,7 @@ function normalizeSuiteCommand(raw, command) {
428
547
  command,
429
548
  suite,
430
549
  intent: runtimeContract.intent,
550
+ action: null,
431
551
  modelId,
432
552
  trainingTests,
433
553
  trainingStage,
@@ -469,6 +589,17 @@ function normalizeSuiteCommand(raw, command) {
469
589
  inputDir: null,
470
590
  outputDir: null,
471
591
  convertPayload: null,
592
+ workloadPath: null,
593
+ runRoot: null,
594
+ checkpointPath: null,
595
+ checkpointId: null,
596
+ checkpointStep: null,
597
+ stageId: null,
598
+ stageArtifact: null,
599
+ subsetManifest: null,
600
+ evalDatasetId: null,
601
+ pollIntervalMs: null,
602
+ stopWhenIdle: null,
472
603
  captureOutput: asOptionalBoolean(raw.captureOutput, 'captureOutput') ?? false,
473
604
  keepPipeline: asOptionalBoolean(raw.keepPipeline, 'keepPipeline') ?? false,
474
605
  report: asOptionalObject(raw.report, 'report'),
@@ -485,6 +616,9 @@ export function normalizeToolingCommandRequest(input) {
485
616
  if (command === 'convert') {
486
617
  return normalizeConvert(input);
487
618
  }
619
+ if (command === 'lora' || command === 'distill') {
620
+ return normalizeTrainingOperatorCommand(input, command);
621
+ }
488
622
  return normalizeSuiteCommand(input, command);
489
623
  }
490
624
 
@@ -514,8 +648,13 @@ export function ensureCommandSupportedOnSurface(commandRequest, surface) {
514
648
  throw new Error(`tooling command: unsupported surface "${surface}".`);
515
649
  }
516
650
 
517
- // All commands are contractually available on both surfaces.
518
- // Surface-specific capability checks happen in the runners.
651
+ if (
652
+ normalizedSurface === 'browser'
653
+ && (request.command === 'lora' || request.command === 'distill')
654
+ ) {
655
+ throw new Error(`tooling command: ${request.command} is currently Node-only and must fail closed on browser.`);
656
+ }
657
+
519
658
  return {
520
659
  request,
521
660
  surface: normalizedSurface,
@@ -0,0 +1,24 @@
1
/**
 * Declared to take a manifest object and return an array of tensor entry records.
 *
 * Each returned entry has a string `name`; `dtype`, `shape`, `role` and `layout`
 * are typed `unknown`, so callers must narrow them before use.
 *
 * NOTE(review): the implementation is not visible in this chunk; which manifest
 * field the entries are read from should be confirmed against the .js source.
 *
 * @param manifest Manifest object with string keys and unvalidated values.
 * @returns Array of tensor entry records as described above.
 */
+ export declare function extractTensorEntriesFromManifest(
2
+ manifest: Record<string, unknown>
3
+ ): Array<{
4
+ name: string;
5
+ dtype: unknown;
6
+ shape: unknown;
7
+ role: unknown;
8
+ layout: unknown;
9
+ }>;
10
+
11
/**
 * Declared to take a conversion config object and a manifest object and return
 * a record with `modelId` and `modelType` strings plus nullable `architecture`
 * and `inference` objects.
 *
 * NOTE(review): how the two inputs are combined (precedence, defaults) is not
 * visible in this chunk — confirm against the implementation.
 *
 * @param conversionConfigInput Conversion config with string keys and unvalidated values.
 * @param manifest              Manifest object with string keys and unvalidated values.
 * @returns Object with `modelId`, `modelType`, `architecture` (or null) and `inference` (or null).
 */
+ export declare function resolveMaterializedManifestFromConversionConfig(
12
+ conversionConfigInput: Record<string, unknown>,
13
+ manifest: Record<string, unknown>
14
+ ): {
15
+ modelId: string;
16
+ modelType: string;
17
+ architecture: Record<string, unknown> | null;
18
+ inference: Record<string, unknown> | null;
19
+ };
20
+
21
/**
 * Declared to return a model id string given a config path and a conversion
 * config object.
 *
 * NOTE(review): presumably derives the id from the config contents and/or the
 * path; the derivation rule is not visible in this chunk — confirm.
 *
 * @param configPath            Path string of the conversion config.
 * @param conversionConfigInput Conversion config with string keys and unvalidated values.
 * @returns The model id as a string.
 */
+ export declare function inferConversionConfigModelId(
22
+ configPath: string,
23
+ conversionConfigInput: Record<string, unknown>
24
+ ): string;