@simulatte/doppler 0.1.4 → 0.1.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +26 -10
- package/package.json +30 -6
- package/src/client/doppler-api.browser.d.ts +1 -0
- package/src/client/doppler-api.browser.js +288 -0
- package/src/client/doppler-api.js +1 -1
- package/src/client/doppler-provider/types.js +1 -1
- package/src/config/execution-contract-check.d.ts +33 -0
- package/src/config/execution-contract-check.js +72 -0
- package/src/config/execution-v0-contract-check.d.ts +94 -0
- package/src/config/execution-v0-contract-check.js +251 -0
- package/src/config/execution-v0-graph-contract-check.d.ts +20 -0
- package/src/config/execution-v0-graph-contract-check.js +64 -0
- package/src/config/kernel-path-contract-check.d.ts +76 -0
- package/src/config/kernel-path-contract-check.js +479 -0
- package/src/config/kernel-path-loader.d.ts +16 -0
- package/src/config/kernel-path-loader.js +54 -0
- package/src/config/kernels/kernel-ref-digests.js +39 -27
- package/src/config/kernels/registry.json +598 -2
- package/src/config/loader.js +81 -48
- package/src/config/merge-contract-check.d.ts +16 -0
- package/src/config/merge-contract-check.js +321 -0
- package/src/config/merge-helpers.d.ts +58 -0
- package/src/config/merge-helpers.js +54 -0
- package/src/config/merge.js +21 -6
- package/src/config/presets/models/janus-text.json +2 -0
- package/src/config/presets/models/qwen3.json +9 -2
- package/src/config/presets/models/transformer.json +5 -0
- package/src/config/quantization-contract-check.d.ts +12 -0
- package/src/config/quantization-contract-check.js +91 -0
- package/src/config/required-inference-fields-contract-check.d.ts +24 -0
- package/src/config/required-inference-fields-contract-check.js +237 -0
- package/src/config/schema/browser-suite-metrics.schema.d.ts +17 -0
- package/src/config/schema/browser-suite-metrics.schema.js +46 -0
- package/src/config/schema/conversion-report.schema.d.ts +40 -0
- package/src/config/schema/conversion-report.schema.js +108 -0
- package/src/config/schema/doppler.schema.js +12 -18
- package/src/config/schema/index.d.ts +22 -0
- package/src/config/schema/index.js +18 -0
- package/src/config/schema/inference-defaults.schema.js +3 -0
- package/src/config/schema/inference.schema.d.ts +9 -0
- package/src/config/schema/kernel-path.schema.d.ts +6 -0
- package/src/config/schema/manifest.schema.d.ts +6 -0
- package/src/config/schema/manifest.schema.js +3 -0
- package/src/converter/core.d.ts +10 -0
- package/src/converter/core.js +27 -2
- package/src/converter/parsers/diffusion.js +63 -3
- package/src/converter/rope-config.js +42 -0
- package/src/gpu/device.js +58 -0
- package/src/gpu/kernels/attention.js +98 -0
- package/src/gpu/kernels/bias_add.wgsl +8 -6
- package/src/gpu/kernels/bias_add_f16.wgsl +8 -5
- package/src/gpu/kernels/conv2d.js +1 -1
- package/src/gpu/kernels/conv2d.wgsl +7 -8
- package/src/gpu/kernels/conv2d_f16.wgsl +7 -8
- package/src/gpu/kernels/depthwise_conv2d.d.ts +29 -0
- package/src/gpu/kernels/depthwise_conv2d.js +99 -0
- package/src/gpu/kernels/depthwise_conv2d.wgsl +55 -0
- package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +59 -0
- package/src/gpu/kernels/grouped_pointwise_conv2d.d.ts +27 -0
- package/src/gpu/kernels/grouped_pointwise_conv2d.js +93 -0
- package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +44 -0
- package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +48 -0
- package/src/gpu/kernels/index.d.ts +30 -0
- package/src/gpu/kernels/index.js +25 -0
- package/src/gpu/kernels/matmul.js +25 -0
- package/src/gpu/kernels/pixel_shuffle.js +1 -1
- package/src/gpu/kernels/pixel_shuffle.wgsl +4 -5
- package/src/gpu/kernels/pixel_shuffle_f16.wgsl +4 -5
- package/src/gpu/kernels/relu.d.ts +18 -0
- package/src/gpu/kernels/relu.js +58 -0
- package/src/gpu/kernels/relu.wgsl +22 -0
- package/src/gpu/kernels/relu_f16.wgsl +24 -0
- package/src/gpu/kernels/repeat_channels.d.ts +21 -0
- package/src/gpu/kernels/repeat_channels.js +60 -0
- package/src/gpu/kernels/repeat_channels.wgsl +28 -0
- package/src/gpu/kernels/repeat_channels_f16.wgsl +30 -0
- package/src/gpu/kernels/residual.js +44 -8
- package/src/gpu/kernels/residual.wgsl +6 -3
- package/src/gpu/kernels/residual_f16.wgsl +2 -1
- package/src/gpu/kernels/residual_f16_vec4.wgsl +2 -1
- package/src/gpu/kernels/residual_vec4.wgsl +2 -1
- package/src/gpu/kernels/rmsnorm.js +58 -6
- package/src/gpu/kernels/rmsnorm.wgsl +14 -6
- package/src/gpu/kernels/rmsnorm_f16.wgsl +10 -2
- package/src/gpu/kernels/rope.d.ts +2 -0
- package/src/gpu/kernels/rope.js +11 -1
- package/src/gpu/kernels/rope.wgsl +56 -40
- package/src/gpu/kernels/sana_linear_attention.d.ts +27 -0
- package/src/gpu/kernels/sana_linear_attention.js +121 -0
- package/src/gpu/kernels/sana_linear_attention_apply.wgsl +43 -0
- package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +46 -0
- package/src/gpu/kernels/sana_linear_attention_summary.wgsl +51 -0
- package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +53 -0
- package/src/gpu/kernels/silu.d.ts +1 -0
- package/src/gpu/kernels/silu.js +32 -14
- package/src/gpu/kernels/silu.wgsl +19 -9
- package/src/gpu/kernels/silu_f16.wgsl +19 -9
- package/src/gpu/kernels/transpose.js +15 -2
- package/src/gpu/kernels/transpose.wgsl +5 -6
- package/src/gpu/kernels/upsample2d.js +2 -1
- package/src/gpu/kernels/upsample2d.wgsl +6 -9
- package/src/gpu/kernels/upsample2d_f16.wgsl +6 -9
- package/src/gpu/kernels/utils.js +16 -1
- package/src/index-browser.d.ts +1 -1
- package/src/index-browser.js +2 -2
- package/src/index.js +1 -1
- package/src/inference/browser-harness.js +109 -23
- package/src/inference/pipelines/diffusion/init.js +14 -0
- package/src/inference/pipelines/diffusion/pipeline.js +215 -77
- package/src/inference/pipelines/diffusion/sana-transformer.d.ts +53 -0
- package/src/inference/pipelines/diffusion/sana-transformer.js +738 -0
- package/src/inference/pipelines/diffusion/scheduler.d.ts +17 -1
- package/src/inference/pipelines/diffusion/scheduler.js +91 -3
- package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +11 -4
- package/src/inference/pipelines/diffusion/text-encoder-gpu.js +282 -0
- package/src/inference/pipelines/diffusion/text-encoder.js +18 -1
- package/src/inference/pipelines/diffusion/types.d.ts +4 -0
- package/src/inference/pipelines/diffusion/vae.js +782 -78
- package/src/inference/pipelines/text/attention/record.js +11 -2
- package/src/inference/pipelines/text/attention/run.js +11 -2
- package/src/inference/pipelines/text/chat-format.js +25 -1
- package/src/inference/pipelines/text/config.d.ts +9 -0
- package/src/inference/pipelines/text/config.js +69 -2
- package/src/inference/pipelines/text/execution-plan.js +23 -31
- package/src/inference/pipelines/text/execution-v0.js +43 -95
- package/src/inference/pipelines/text/ffn/standard.js +3 -0
- package/src/inference/pipelines/text/init.d.ts +4 -0
- package/src/inference/pipelines/text/init.js +56 -9
- package/src/inference/pipelines/text/layer.js +11 -0
- package/src/inference/pipelines/text.js +4 -0
- package/src/inference/tokenizers/bundled.js +156 -33
- package/src/rules/execution-rules-contract-check.d.ts +17 -0
- package/src/rules/execution-rules-contract-check.js +245 -0
- package/src/rules/kernels/depthwise-conv2d.rules.json +6 -0
- package/src/rules/kernels/grouped-pointwise-conv2d.rules.json +6 -0
- package/src/rules/kernels/relu.rules.json +6 -0
- package/src/rules/kernels/repeat-channels.rules.json +6 -0
- package/src/rules/kernels/sana-linear-attention.rules.json +6 -0
- package/src/rules/layer-pattern-contract-check.d.ts +17 -0
- package/src/rules/layer-pattern-contract-check.js +231 -0
- package/src/rules/rule-registry.d.ts +28 -0
- package/src/rules/rule-registry.js +38 -0
- package/src/rules/tooling/command-runtime.rules.json +18 -0
- package/src/tooling/command-api.d.ts +27 -1
- package/src/tooling/command-api.js +142 -3
- package/src/tooling/conversion-config-materializer.d.ts +24 -0
- package/src/tooling/conversion-config-materializer.js +99 -0
- package/src/tooling/lean-execution-contract-runner.d.ts +43 -0
- package/src/tooling/lean-execution-contract-runner.js +158 -0
- package/src/tooling/node-browser-command-runner.d.ts +4 -0
- package/src/tooling/node-browser-command-runner.js +58 -3
- package/src/tooling/node-command-runner.js +15 -0
- package/src/tooling/node-convert.d.ts +10 -0
- package/src/tooling/node-converter.js +59 -0
- package/src/tooling/node-webgpu.js +11 -89
- package/src/training/checkpoint-watch.d.ts +7 -0
- package/src/training/checkpoint-watch.js +106 -0
- package/src/training/checkpoint.d.ts +6 -1
- package/src/training/checkpoint.js +12 -2
- package/src/training/distillation/artifacts.d.ts +71 -0
- package/src/training/distillation/artifacts.js +132 -0
- package/src/training/distillation/checkpoint-watch.d.ts +10 -0
- package/src/training/distillation/checkpoint-watch.js +57 -0
- package/src/training/distillation/dataset.d.ts +59 -0
- package/src/training/distillation/dataset.js +337 -0
- package/src/training/distillation/eval.d.ts +34 -0
- package/src/training/distillation/eval.js +310 -0
- package/src/training/distillation/index.d.ts +29 -0
- package/src/training/distillation/index.js +29 -0
- package/src/training/distillation/runtime.d.ts +20 -0
- package/src/training/distillation/runtime.js +121 -0
- package/src/training/distillation/scoreboard.d.ts +6 -0
- package/src/training/distillation/scoreboard.js +8 -0
- package/src/training/distillation/stage-a.d.ts +45 -0
- package/src/training/distillation/stage-a.js +338 -0
- package/src/training/distillation/stage-b.d.ts +24 -0
- package/src/training/distillation/stage-b.js +20 -0
- package/src/training/index.d.ts +10 -0
- package/src/training/index.js +10 -0
- package/src/training/lora-pipeline.d.ts +40 -0
- package/src/training/lora-pipeline.js +796 -0
- package/src/training/operator-artifacts.d.ts +62 -0
- package/src/training/operator-artifacts.js +140 -0
- package/src/training/operator-command.d.ts +5 -0
- package/src/training/operator-command.js +453 -0
- package/src/training/operator-eval.d.ts +48 -0
- package/src/training/operator-eval.js +230 -0
- package/src/training/operator-scoreboard.d.ts +5 -0
- package/src/training/operator-scoreboard.js +44 -0
- package/src/training/runner.d.ts +52 -0
- package/src/training/runner.js +29 -4
- package/src/training/suite.d.ts +112 -0
- package/src/training/suite.js +9 -9
- package/src/training/workloads.d.ts +164 -0
- package/src/training/workloads.js +539 -0
- package/src/version.d.ts +2 -0
- package/src/version.js +2 -0
- package/tools/convert-safetensors-node.js +47 -0
- package/tools/doppler-cli.js +252 -41
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
// Grouped Pointwise Conv2D Kernel (NCHW, f16)
//
// 1x1 convolution split into `groups` channel groups. Tensors are stored as f16;
// the dot product is accumulated in f32 and rounded to f16 once on store.
//   gid.x -> spatial position (y * width + x), gid.y -> output channel
//   weight is indexed as [out_channels, in_channels / groups], bias as [out_channels]

enable f16;

override WORKGROUP_SIZE: u32 = 256u;

struct Uniforms {
    in_channels: u32,
    out_channels: u32,
    height: u32,
    width: u32,
    groups: u32,
    _pad0: u32,
    _pad1: u32,
    _pad2: u32,
}

@group(0) @binding(0) var<uniform> u: Uniforms;
@group(0) @binding(1) var<storage, read> input: array<f16>;
@group(0) @binding(2) var<storage, read> weight: array<f16>;
@group(0) @binding(3) var<storage, read> bias: array<f16>;
@group(0) @binding(4) var<storage, read_write> output: array<f16>;

@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let plane = u.height * u.width;
    let pos = gid.x;
    let oc = gid.y;
    if (pos >= plane || oc >= u.out_channels) {
        return;
    }

    let cin_per_group = u.in_channels / u.groups;
    let cout_per_group = u.out_channels / u.groups;
    // First input channel of the group that output channel `oc` belongs to.
    let group_base = (oc / cout_per_group) * cin_per_group;
    let weight_base = oc * cin_per_group;

    var acc = f32(bias[oc]);
    for (var k = 0u; k < cin_per_group; k++) {
        // ((c * height + y) * width + x) == c * plane + pos, because pos = y * width + x.
        let a = f32(input[(group_base + k) * plane + pos]);
        let w = f32(weight[weight_base + k]);
        acc += a * w;
    }

    output[oc * plane + pos] = f16(acc);
}
|
|
@@ -174,6 +174,18 @@ export {
|
|
|
174
174
|
type Conv2DOptions,
|
|
175
175
|
} from './conv2d.js';
|
|
176
176
|
|
|
177
|
+
export {
|
|
178
|
+
runDepthwiseConv2D,
|
|
179
|
+
recordDepthwiseConv2D,
|
|
180
|
+
type DepthwiseConv2DOptions,
|
|
181
|
+
} from './depthwise_conv2d.js';
|
|
182
|
+
|
|
183
|
+
export {
|
|
184
|
+
runGroupedPointwiseConv2D,
|
|
185
|
+
recordGroupedPointwiseConv2D,
|
|
186
|
+
type GroupedPointwiseConv2DOptions,
|
|
187
|
+
} from './grouped_pointwise_conv2d.js';
|
|
188
|
+
|
|
177
189
|
// Gather (Embedding Lookup)
|
|
178
190
|
export {
|
|
179
191
|
runGather,
|
|
@@ -250,6 +262,24 @@ export {
|
|
|
250
262
|
type SampleResult,
|
|
251
263
|
} from './sample.js';
|
|
252
264
|
|
|
265
|
+
export {
|
|
266
|
+
runSanaLinearAttention,
|
|
267
|
+
recordSanaLinearAttention,
|
|
268
|
+
type SanaLinearAttentionOptions,
|
|
269
|
+
} from './sana_linear_attention.js';
|
|
270
|
+
|
|
271
|
+
export {
|
|
272
|
+
runRepeatChannels,
|
|
273
|
+
recordRepeatChannels,
|
|
274
|
+
type RepeatChannelsOptions,
|
|
275
|
+
} from './repeat_channels.js';
|
|
276
|
+
|
|
277
|
+
export {
|
|
278
|
+
runReLU,
|
|
279
|
+
recordReLU,
|
|
280
|
+
type ReLUOptions,
|
|
281
|
+
} from './relu.js';
|
|
282
|
+
|
|
253
283
|
// Fused FFN (Tier 2 P0)
|
|
254
284
|
export {
|
|
255
285
|
runFusedFFN,
|
package/src/gpu/kernels/index.js
CHANGED
|
@@ -139,6 +139,16 @@ export {
|
|
|
139
139
|
recordConv2D,
|
|
140
140
|
} from './conv2d.js';
|
|
141
141
|
|
|
142
|
+
export {
|
|
143
|
+
runDepthwiseConv2D,
|
|
144
|
+
recordDepthwiseConv2D,
|
|
145
|
+
} from './depthwise_conv2d.js';
|
|
146
|
+
|
|
147
|
+
export {
|
|
148
|
+
runGroupedPointwiseConv2D,
|
|
149
|
+
recordGroupedPointwiseConv2D,
|
|
150
|
+
} from './grouped_pointwise_conv2d.js';
|
|
151
|
+
|
|
142
152
|
// Gather (Embedding Lookup)
|
|
143
153
|
export {
|
|
144
154
|
runGather,
|
|
@@ -205,6 +215,21 @@ export {
|
|
|
205
215
|
isGPUSamplingAvailable,
|
|
206
216
|
} from './sample.js';
|
|
207
217
|
|
|
218
|
+
export {
|
|
219
|
+
runSanaLinearAttention,
|
|
220
|
+
recordSanaLinearAttention,
|
|
221
|
+
} from './sana_linear_attention.js';
|
|
222
|
+
|
|
223
|
+
export {
|
|
224
|
+
runRepeatChannels,
|
|
225
|
+
recordRepeatChannels,
|
|
226
|
+
} from './repeat_channels.js';
|
|
227
|
+
|
|
228
|
+
export {
|
|
229
|
+
runReLU,
|
|
230
|
+
recordReLU,
|
|
231
|
+
} from './relu.js';
|
|
232
|
+
|
|
208
233
|
// Fused FFN (Tier 2 P0)
|
|
209
234
|
export {
|
|
210
235
|
runFusedFFN,
|
|
@@ -52,6 +52,23 @@ function buildProfileLabel(options = {}) {
|
|
|
52
52
|
return `matmul${roleLabel}${layerLabel}`;
|
|
53
53
|
}
|
|
54
54
|
|
|
55
|
+
/**
 * Guard a bind-group resource: throw a descriptive error unless `buffer` can be bound.
 *
 * When a global `GPUBuffer` constructor exists, `buffer` must be an instance of it.
 * When no such global exists, any truthy value is accepted.
 *
 * @param {string} kernelName - Kernel name used as the message prefix.
 * @param {string} variant - Kernel variant being bound.
 * @param {number} bindingIndex - Binding slot number reported in the message.
 * @param {string} bindingLabel - Human-readable name of the binding.
 * @param {*} buffer - Candidate buffer resource.
 * @param {Array<*>} [details=[]] - Extra context; falsy entries are dropped and the rest joined with ", ".
 * @throws {Error} When `buffer` is missing or is not a GPUBuffer.
 */
function assertBindGroupBuffer(kernelName, variant, bindingIndex, bindingLabel, buffer, details = []) {
  if (buffer) {
    const gpuBufferGlobalMissing = typeof GPUBuffer === 'undefined';
    if (gpuBufferGlobalMissing || buffer instanceof GPUBuffer) {
      return;
    }
  }

  const joinedDetails = details.filter(Boolean).join(', ');
  const suffix = joinedDetails ? ` (${joinedDetails})` : '';
  throw new Error(
    `[${kernelName}] variant="${variant}" binding ${bindingIndex} "${bindingLabel}" requires a GPUBuffer${suffix}.`
  );
}
|
|
71
|
+
|
|
55
72
|
function createMatmulBindGroupEntries(variant, uniformBuffer, matmulInput, bBuffer, outputBuffer, offsets, bindingSizes) {
|
|
56
73
|
const isQ4KF16 = variant === 'q4_fused_multicol_f16'
|
|
57
74
|
|| variant === 'q4_fused_f16a'
|
|
@@ -59,6 +76,14 @@ function createMatmulBindGroupEntries(variant, uniformBuffer, matmulInput, bBuff
|
|
|
59
76
|
|| variant === 'q4_fused_multicol_f16a'
|
|
60
77
|
|| variant === 'q4_fused_batched_f16a';
|
|
61
78
|
|
|
79
|
+
assertBindGroupBuffer('matmul', variant, 0, 'uniforms', uniformBuffer);
|
|
80
|
+
assertBindGroupBuffer('matmul', variant, 1, 'input', matmulInput?.buffer, [
|
|
81
|
+
`inputLabel=${matmulInput?.label ?? 'unknown'}`,
|
|
82
|
+
`inputDtype=${matmulInput?.dtype ?? 'unknown'}`,
|
|
83
|
+
]);
|
|
84
|
+
assertBindGroupBuffer('matmul', variant, 2, 'weights', bBuffer);
|
|
85
|
+
assertBindGroupBuffer('matmul', variant, isQ4KF16 ? 4 : 3, 'output', outputBuffer);
|
|
86
|
+
|
|
62
87
|
const entries = [
|
|
63
88
|
{ binding: 0, resource: { buffer: uniformBuffer } },
|
|
64
89
|
{ binding: 1, resource: { buffer: matmulInput.buffer, offset: offsets.aOffset, size: bindingSizes.aBindingSize } },
|
|
@@ -34,7 +34,7 @@ async function _pixelShuffle(target, input, options = {}) {
|
|
|
34
34
|
grid_width: gridWidth, grid_height: gridHeight, patch_size: patchSize,
|
|
35
35
|
patch_channels: inferredPatchChannels, _pad0: 0,
|
|
36
36
|
},
|
|
37
|
-
Math.ceil((
|
|
37
|
+
[Math.ceil((outHeight * outWidth) / WORKGROUP_SIZES.DEFAULT), outChannels, 1]
|
|
38
38
|
);
|
|
39
39
|
|
|
40
40
|
return createTensor(output, input.dtype, [outChannels, outHeight, outWidth], 'pixel_shuffle_output');
|
|
@@ -19,17 +19,16 @@ struct Uniforms {
|
|
|
19
19
|
|
|
20
20
|
@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
|
|
21
21
|
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
22
|
-
let idx = gid.x;
|
|
23
22
|
let spatial_size = u.out_height * u.out_width;
|
|
24
|
-
let
|
|
25
|
-
|
|
23
|
+
let spatial = gid.x;
|
|
24
|
+
let c = gid.y;
|
|
25
|
+
if (c >= u.out_channels || spatial >= spatial_size) {
|
|
26
26
|
return;
|
|
27
27
|
}
|
|
28
28
|
|
|
29
|
-
let c = idx / spatial_size;
|
|
30
|
-
let spatial = idx - c * spatial_size;
|
|
31
29
|
let y = spatial / u.out_width;
|
|
32
30
|
let x = spatial - y * u.out_width;
|
|
31
|
+
let idx = c * spatial_size + spatial;
|
|
33
32
|
|
|
34
33
|
let grid_y = y / u.patch_size;
|
|
35
34
|
let grid_x = x / u.patch_size;
|
|
@@ -22,17 +22,16 @@ struct Uniforms {
|
|
|
22
22
|
|
|
23
23
|
@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
|
|
24
24
|
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
25
|
-
let idx = gid.x;
|
|
26
25
|
let spatial_size = u.out_height * u.out_width;
|
|
27
|
-
let
|
|
28
|
-
|
|
26
|
+
let spatial = gid.x;
|
|
27
|
+
let c = gid.y;
|
|
28
|
+
if (c >= u.out_channels || spatial >= spatial_size) {
|
|
29
29
|
return;
|
|
30
30
|
}
|
|
31
31
|
|
|
32
|
-
let c = idx / spatial_size;
|
|
33
|
-
let spatial = idx - c * spatial_size;
|
|
34
32
|
let y = spatial / u.out_width;
|
|
35
33
|
let x = spatial - y * u.out_width;
|
|
34
|
+
let idx = c * spatial_size + spatial;
|
|
36
35
|
|
|
37
36
|
let grid_y = y / u.patch_size;
|
|
38
37
|
let grid_x = x / u.patch_size;
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
import type { Tensor } from '../tensor.js';
|
|
2
|
+
import type { CommandRecorder } from '../command-recorder.js';
|
|
3
|
+
import type { OutputBufferOptions } from './types.js';
|
|
4
|
+
|
|
5
|
+
/** Options shared by {@link runReLU} and {@link recordReLU}. */
export interface ReLUOptions extends OutputBufferOptions {
  /**
   * Number of elements to process. When omitted, null, non-finite, or not positive,
   * the count is derived from the input tensor instead.
   */
  count?: number | null;
}

/** Apply element-wise ReLU (`max(x, 0)`) to `input` without a command recorder. */
export declare function runReLU(input: Tensor, options?: ReLUOptions): Promise<Tensor>;

/** Apply element-wise ReLU (`max(x, 0)`) to `input`, recording into `recorder`. */
export declare function recordReLU(
  recorder: CommandRecorder,
  input: Tensor,
  options?: ReLUOptions,
): Promise<Tensor>;
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
import { acquireBuffer } from '../../memory/buffer-pool.js';
|
|
2
|
+
import { createTensor, dtypeBytes } from '../tensor.js';
|
|
3
|
+
import { unifiedKernelWrapper } from './utils.js';
|
|
4
|
+
import { selectRuleValue } from './rule-registry.js';
|
|
5
|
+
import { WORKGROUP_SIZES } from './constants.js';
|
|
6
|
+
|
|
7
|
+
/**
 * Resolve the ReLU kernel variant for `dtype` through the 'relu' rule set.
 * @param {string} dtype - Tensor dtype used as the rule-matching context.
 * @returns {*} Variant value selected by the rule registry.
 */
function selectReluVariant(dtype) {
  const ruleContext = { dtype };
  return selectRuleValue('relu', 'variant', ruleContext);
}
|
|
10
|
+
|
|
11
|
+
/**
 * Decide how many elements the ReLU dispatch should cover.
 *
 * Priority order:
 * 1. A finite, positive `countOverride`, floored.
 * 2. The product of `input.shape`, when it is a non-empty array.
 * 3. The number of whole elements that fit in `input.buffer`.
 *
 * @param {*} input - Tensor-like object with `shape`, `dtype`, and `buffer`.
 * @param {number|null|undefined} countOverride - Optional explicit element count.
 * @returns {number} Element count.
 */
function resolveCount(input, countOverride) {
  const overrideUsable = Number.isFinite(countOverride) && countOverride > 0;
  if (overrideUsable) {
    return Math.floor(countOverride);
  }

  const hasShape = Array.isArray(input.shape) && input.shape.length > 0;
  if (!hasShape) {
    return Math.floor(input.buffer.size / dtypeBytes(input.dtype));
  }
  return input.shape.reduce((product, dim) => product * dim, 1);
}
|
|
20
|
+
|
|
21
|
+
/**
 * Plan a 2-D dispatch for the ReLU kernel.
 *
 * The ReLU shader computes `idx = gid.y * dispatchStride + gid.x`, so one row of
 * workgroups (along X) covers `dispatchStride` elements. The row length is capped by
 * the device's `maxComputeWorkgroupsPerDimension` (65535 when unavailable), and any
 * remaining elements are covered by additional rows along Y. Without those Y rows,
 * every element past `maxPerDim * workgroupSize` would never be written.
 *
 * @param {{device?: *}|null} target - Recorder/target whose `device.limits` may be read.
 * @param {number} size - Number of elements to process.
 * @param {number} [workgroupSize=WORKGROUP_SIZES.DEFAULT] - Invocations per workgroup;
 *   must match the shader's WORKGROUP_SIZE.
 * @returns {{dispatchStride: number, workgroups: number[]}} Row stride and [x, y, z] counts.
 */
function planReluDispatch(target, size, workgroupSize = WORKGROUP_SIZES.DEFAULT) {
  const device = target?.device;
  const maxPerDim = Number.isFinite(device?.limits?.maxComputeWorkgroupsPerDimension)
    ? device.limits.maxComputeWorkgroupsPerDimension
    : 65535;
  const dispatchStride = Math.min(size, maxPerDim * workgroupSize);
  // When size is 0 the stride is 0; avoid 0 / 0 = NaN and keep a valid Y count.
  const rows = dispatchStride > 0 ? Math.ceil(size / dispatchStride) : 1;
  return {
    dispatchStride,
    workgroups: [Math.ceil(dispatchStride / workgroupSize), rows, 1],
  };
}
|
|
32
|
+
|
|
33
|
+
/**
 * Shared implementation behind runReLU / recordReLU.
 *
 * Computes max(x, 0) over the resolved element count. The result is written into
 * `outputBuffer`, or into a freshly acquired 'relu_output' buffer, and returned as a
 * tensor with the input's dtype and a copy of its shape.
 *
 * @param {*} target - Command recorder, or null when invoked through runReLU.
 * @param {*} input - Source tensor.
 * @param {{count?: number|null, outputBuffer?: *}} [options]
 * @returns {Promise<*>} Output tensor.
 */
async function _relu(target, input, options = {}) {
  const { count = null, outputBuffer = null } = options;
  const elementCount = resolveCount(input, count);
  const kernelVariant = selectReluVariant(input.dtype);
  const destination = outputBuffer
    || acquireBuffer(elementCount * dtypeBytes(input.dtype), undefined, 'relu_output');
  const plan = planReluDispatch(target, elementCount);

  // The shader reads the 2-D dispatch row stride from the `_pad0` uniform slot.
  const uniforms = {
    size: elementCount,
    _pad0: plan.dispatchStride,
    _pad1: 0,
    _pad2: 0,
  };
  await unifiedKernelWrapper('relu', target, kernelVariant, [input, destination], uniforms, plan.workgroups);

  return createTensor(destination, input.dtype, [...input.shape], 'relu_output');
}
|
|
51
|
+
|
|
52
|
+
/**
 * Apply element-wise ReLU to `input` without a command recorder.
 * @param {*} input - Source tensor.
 * @param {object} [options] - `count` and/or `outputBuffer`.
 * @returns {Promise<*>} Output tensor.
 */
export async function runReLU(input, options = {}) {
  const noRecorder = null;
  return _relu(noRecorder, input, options);
}
|
|
55
|
+
|
|
56
|
+
/**
 * Record element-wise ReLU of `input` into `recorder`.
 * @param {*} recorder - Command recorder used as the kernel target.
 * @param {*} input - Source tensor.
 * @param {object} [options] - `count` and/or `outputBuffer`.
 * @returns {Promise<*>} Output tensor.
 */
export async function recordReLU(recorder, input, options = {}) {
  const kernelTarget = recorder;
  return _relu(kernelTarget, input, options);
}
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
// Element-wise ReLU, f32: output[i] = max(input[i], 0).
// Elements are spread over a 2-D dispatch: i = gid.y * stride + gid.x, where the
// stride comes from the `_pad0` uniform slot (clamped to at least 1).

override WORKGROUP_SIZE: u32 = 256u;

struct Uniforms {
    size: u32,   // number of elements to process
    _pad0: u32,  // row stride (in elements) of the 2-D dispatch
    _pad1: u32,
    _pad2: u32,
}

@group(0) @binding(0) var<uniform> u: Uniforms;
@group(0) @binding(1) var<storage, read> input: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let row_stride = max(u._pad0, 1u);
    let i = gid.x + gid.y * row_stride;
    if (i < u.size) {
        output[i] = max(input[i], 0.0);
    }
}
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
// Element-wise ReLU, f16: output[i] = max(input[i], 0).
// Elements are spread over a 2-D dispatch: i = gid.y * stride + gid.x, where the
// stride comes from the `_pad0` uniform slot (clamped to at least 1).
enable f16;

override WORKGROUP_SIZE: u32 = 256u;

struct Uniforms {
    size: u32,   // number of elements to process
    _pad0: u32,  // row stride (in elements) of the 2-D dispatch
    _pad1: u32,
    _pad2: u32,
}

@group(0) @binding(0) var<uniform> u: Uniforms;
@group(0) @binding(1) var<storage, read> input: array<f16>;
@group(0) @binding(2) var<storage, read_write> output: array<f16>;

@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let row_stride = max(u._pad0, 1u);
    let i = gid.x + gid.y * row_stride;
    if (i < u.size) {
        output[i] = max(input[i], 0.0h);
    }
}
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
import type { Tensor } from '../tensor.js';
|
|
2
|
+
import type { CommandRecorder } from '../command-recorder.js';
|
|
3
|
+
import type { OutputBufferOptions } from './types.js';
|
|
4
|
+
|
|
5
|
+
/**
 * Options for {@link runRepeatChannels} / {@link recordRepeatChannels}.
 * The input is read as [inChannels, height, width]. The output has
 * inChannels * repeats channels, and output channel c copies input channel
 * floor(c / repeats).
 */
export interface RepeatChannelsOptions extends OutputBufferOptions {
  /** Channel count of the input. */
  inChannels: number;
  /** Spatial height. */
  height: number;
  /** Spatial width. */
  width: number;
  /** Number of consecutive copies of each input channel (at least 1). */
  repeats: number;
}

/** Repeat input channels without a command recorder. */
export declare function runRepeatChannels(input: Tensor, options: RepeatChannelsOptions): Promise<Tensor>;

/** Repeat input channels, recording the dispatch into `recorder`. */
export declare function recordRepeatChannels(
  recorder: CommandRecorder,
  input: Tensor,
  options: RepeatChannelsOptions,
): Promise<Tensor>;
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
import { acquireBuffer } from '../../memory/buffer-pool.js';
|
|
2
|
+
import { createTensor, dtypeBytes } from '../tensor.js';
|
|
3
|
+
import { unifiedKernelWrapper } from './utils.js';
|
|
4
|
+
import { selectRuleValue } from './rule-registry.js';
|
|
5
|
+
import { WORKGROUP_SIZES } from './constants.js';
|
|
6
|
+
|
|
7
|
+
/**
 * Resolve the repeat-channels kernel variant for `dtype` through the 'repeatChannels' rule set.
 * @param {string} dtype - Tensor dtype used as the rule-matching context.
 * @returns {*} Variant value selected by the rule registry.
 */
function selectRepeatChannelsVariant(dtype) {
  const ruleContext = { dtype };
  return selectRuleValue('repeatChannels', 'variant', ruleContext);
}
|
|
10
|
+
|
|
11
|
+
/**
 * Shared implementation behind runRepeatChannels / recordRepeatChannels.
 *
 * Reads `input` as [inChannels, height, width] and writes
 * [inChannels * repeats, height, width]. Output channel `c` is a copy of input
 * channel `floor(c / repeats)`, so each input channel appears `repeats` times in a row.
 *
 * @param {*} target - Command recorder, or null when invoked through runRepeatChannels.
 * @param {*} input - Source tensor.
 * @param {{inChannels: number, height: number, width: number, repeats: number, outputBuffer?: *}} options
 * @returns {Promise<*>} Tensor of shape [inChannels * repeats, height, width].
 * @throws {Error} If a dimension is not a non-negative integer or `repeats` is not an integer >= 1.
 */
async function _repeatChannels(target, input, options = {}) {
  const {
    inChannels,
    height,
    width,
    repeats,
    outputBuffer = null,
  } = options;

  // The shader declares these as u32 uniforms, while the output buffer and tensor
  // shape are sized from the JS values. Fractional or negative values would make
  // the two disagree (e.g. repeats = 2.5 leaves trailing channels unwritten), so
  // require exact integers.
  const isCount = (value) => Number.isInteger(value) && value >= 0;
  if (
    !isCount(inChannels) ||
    !isCount(height) ||
    !isCount(width) ||
    !Number.isInteger(repeats) ||
    repeats < 1
  ) {
    throw new Error('RepeatChannels requires inChannels, height, width, and repeats.');
  }

  const outChannels = inChannels * repeats;
  const variant = selectRepeatChannelsVariant(input.dtype);
  const bytesPerElement = dtypeBytes(input.dtype);
  const outputSize = outChannels * height * width * bytesPerElement;
  const output = outputBuffer || acquireBuffer(outputSize, undefined, 'repeat_channels_output');

  await unifiedKernelWrapper(
    'repeat_channels',
    target,
    variant,
    [input, output],
    {
      in_channels: inChannels,
      height,
      width,
      repeats,
      _pad0: 0,
    },
    // X covers the height*width plane; Y selects the output channel.
    [Math.ceil((height * width) / WORKGROUP_SIZES.DEFAULT), outChannels, 1]
  );

  return createTensor(output, input.dtype, [outChannels, height, width], 'repeat_channels_output');
}
|
|
53
|
+
|
|
54
|
+
/**
 * Repeat the channels of `input` without a command recorder.
 * @param {*} input - Source tensor.
 * @param {object} [options] - inChannels, height, width, repeats, optional outputBuffer.
 * @returns {Promise<*>} Output tensor.
 */
export async function runRepeatChannels(input, options = {}) {
  const noRecorder = null;
  return _repeatChannels(noRecorder, input, options);
}
|
|
57
|
+
|
|
58
|
+
/**
 * Record a channel-repeat of `input` into `recorder`.
 * @param {*} recorder - Command recorder used as the kernel target.
 * @param {*} input - Source tensor.
 * @param {object} [options] - inChannels, height, width, repeats, optional outputBuffer.
 * @returns {Promise<*>} Output tensor.
 */
export async function recordRepeatChannels(recorder, input, options = {}) {
  const kernelTarget = recorder;
  return _repeatChannels(kernelTarget, input, options);
}
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
// Repeat channels (f32, [C, H, W] layout): output channel oc copies input
// channel oc / repeats, so with repeats = 2 the order is c0, c0, c1, c1, ...
// Dispatch: gid.x = spatial position in the H*W plane, gid.y = output channel.

override WORKGROUP_SIZE: u32 = 256u;

struct Uniforms {
    in_channels: u32,
    height: u32,
    width: u32,
    repeats: u32,
    _pad0: u32,
}

@group(0) @binding(0) var<uniform> u: Uniforms;
@group(0) @binding(1) var<storage, read> input: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let plane = u.height * u.width;
    let pos = gid.x;
    let oc = gid.y;
    if (pos >= plane || oc >= u.in_channels * u.repeats) {
        return;
    }

    let src_channel = oc / u.repeats;
    output[oc * plane + pos] = input[src_channel * plane + pos];
}
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
// Repeat channels (f16, [C, H, W] layout): output channel oc copies input
// channel oc / repeats, so with repeats = 2 the order is c0, c0, c1, c1, ...
// Dispatch: gid.x = spatial position in the H*W plane, gid.y = output channel.
enable f16;

override WORKGROUP_SIZE: u32 = 256u;

struct Uniforms {
    in_channels: u32,
    height: u32,
    width: u32,
    repeats: u32,
    _pad0: u32,
}

@group(0) @binding(0) var<uniform> u: Uniforms;
@group(0) @binding(1) var<storage, read> input: array<f16>;
@group(0) @binding(2) var<storage, read_write> output: array<f16>;

@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let plane = u.height * u.width;
    let pos = gid.x;
    let oc = gid.y;
    if (pos >= plane || oc >= u.in_channels * u.repeats) {
        return;
    }

    let src_channel = oc / u.repeats;
    output[oc * plane + pos] = input[src_channel * plane + pos];
}
|
|
@@ -63,6 +63,22 @@ function cleanupTemps(temps, recorder) {
|
|
|
63
63
|
}
|
|
64
64
|
}
|
|
65
65
|
|
|
66
|
+
/**
 * Plan a 2-D dispatch for the residual kernels.
 *
 * The residual shaders index elements as `gid.y * dispatchStride + gid.x` (scalar
 * variants) or `gid.y * dispatchStride + gid.x * 4` (vec4 variants). Each row of
 * workgroups therefore covers `dispatchStride` elements. Rows are capped by
 * `maxComputeWorkgroupsPerDimension` (65535 when the device limit is unavailable),
 * and the remainder is spread across Y.
 *
 * @param {{device?: *}|null} target - Recorder/target whose `device.limits` may be read.
 * @param {number} size - Number of elements.
 * @param {number} elementsPerWorkgroup - Elements covered by one workgroup.
 * @returns {{dispatchStride: number, workgroups: number[]}} Row stride and [x, y, z] counts.
 */
function planResidualDispatch(target, size, elementsPerWorkgroup) {
  const device = target?.device;
  const maxPerDim = Number.isFinite(device?.limits?.maxComputeWorkgroupsPerDimension)
    ? device.limits.maxComputeWorkgroupsPerDimension
    : 65535;
  const dispatchStride = Math.min(size, maxPerDim * elementsPerWorkgroup);
  // For size 0 the stride is 0, and 0 / 0 would yield NaN, which is not a valid
  // workgroup count. Use a single (empty) row instead.
  const rows = dispatchStride > 0 ? Math.ceil(size / dispatchStride) : 1;
  return {
    dispatchStride,
    workgroups: [
      Math.ceil(dispatchStride / elementsPerWorkgroup),
      rows,
      1,
    ],
  };
}
|
|
81
|
+
|
|
66
82
|
async function _residualAdd(target, a, b, size, options = {}) {
|
|
67
83
|
const recorder = target && typeof target.beginComputePass === 'function' ? target : null;
|
|
68
84
|
const { useVec4 = true, outputBuffer = null } = options;
|
|
@@ -75,15 +91,17 @@ async function _residualAdd(target, a, b, size, options = {}) {
|
|
|
75
91
|
const outputSize = size * bytesPerElement;
|
|
76
92
|
const output = outputBuffer || acquireBuffer(outputSize, undefined, 'residual_output');
|
|
77
93
|
|
|
78
|
-
const
|
|
79
|
-
|
|
80
|
-
|
|
94
|
+
const dispatchPlan = planResidualDispatch(
|
|
95
|
+
target,
|
|
96
|
+
size,
|
|
97
|
+
useVec4 ? VEC4_ELEMENTS_PER_WG : WORKGROUP_SIZES.DEFAULT
|
|
98
|
+
);
|
|
81
99
|
|
|
82
100
|
await unifiedKernelWrapper(
|
|
83
101
|
'residual', target, variant,
|
|
84
102
|
[aAligned, bAligned, output],
|
|
85
|
-
{ size },
|
|
86
|
-
workgroups
|
|
103
|
+
{ size, scale: 1, _pad1: dispatchPlan.dispatchStride, _pad2: 0 },
|
|
104
|
+
dispatchPlan.workgroups
|
|
87
105
|
);
|
|
88
106
|
|
|
89
107
|
cleanupTemps(temps, recorder);
|
|
@@ -96,13 +114,31 @@ async function _biasAdd(target, data, bias, numTokens, dim, options = {}) {
|
|
|
96
114
|
|
|
97
115
|
const { bias: biasAligned, temps } = await alignBiasTensor(data, bias, recorder);
|
|
98
116
|
const variant = selectBiasAddVariant(data.dtype, biasAligned.dtype);
|
|
99
|
-
|
|
100
|
-
const
|
|
117
|
+
const device = target?.device;
|
|
118
|
+
const maxPerDim = Number.isFinite(device?.limits?.maxComputeWorkgroupsPerDimension)
|
|
119
|
+
? device.limits.maxComputeWorkgroupsPerDimension
|
|
120
|
+
: 65535;
|
|
121
|
+
const tokenStride = Math.min(numTokens, maxPerDim);
|
|
122
|
+
|
|
123
|
+
const workgroups = [
|
|
124
|
+
Math.ceil(dim / WORKGROUP_SIZES.DEFAULT),
|
|
125
|
+
tokenStride,
|
|
126
|
+
Math.ceil(numTokens / tokenStride),
|
|
127
|
+
];
|
|
101
128
|
|
|
102
129
|
await unifiedKernelWrapper(
|
|
103
130
|
'bias_add', target, variant,
|
|
104
131
|
[data, biasAligned],
|
|
105
|
-
{
|
|
132
|
+
{
|
|
133
|
+
num_tokens: numTokens,
|
|
134
|
+
dim,
|
|
135
|
+
data_offset: dataOffset,
|
|
136
|
+
bias_offset: biasOffset,
|
|
137
|
+
token_stride: tokenStride,
|
|
138
|
+
_pad0: 0,
|
|
139
|
+
_pad1: 0,
|
|
140
|
+
_pad2: 0,
|
|
141
|
+
},
|
|
106
142
|
workgroups
|
|
107
143
|
);
|
|
108
144
|
|
|
@@ -23,7 +23,8 @@ override WORKGROUP_SIZE: u32 = 256u;
|
|
|
23
23
|
|
|
24
24
|
@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
|
|
25
25
|
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
26
|
-
let
|
|
26
|
+
let dispatch_stride = max(u._pad1, 1u);
|
|
27
|
+
let idx = gid.y * dispatch_stride + gid.x;
|
|
27
28
|
if (idx >= u.size) {
|
|
28
29
|
return;
|
|
29
30
|
}
|
|
@@ -35,7 +36,8 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
|
35
36
|
// This avoids requiring a different bind group layout with read_write on 'a'
|
|
36
37
|
@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
|
|
37
38
|
fn add_inplace(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
38
|
-
let
|
|
39
|
+
let dispatch_stride = max(u._pad1, 1u);
|
|
40
|
+
let idx = gid.y * dispatch_stride + gid.x;
|
|
39
41
|
if (idx >= u.size) {
|
|
40
42
|
return;
|
|
41
43
|
}
|
|
@@ -45,7 +47,8 @@ fn add_inplace(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
|
45
47
|
// Fused residual + scale: output = a + scale * b
|
|
46
48
|
@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
|
|
47
49
|
fn add_scaled(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
48
|
-
let
|
|
50
|
+
let dispatch_stride = max(u._pad1, 1u);
|
|
51
|
+
let idx = gid.y * dispatch_stride + gid.x;
|
|
49
52
|
if (idx >= u.size) {
|
|
50
53
|
return;
|
|
51
54
|
}
|
|
@@ -27,7 +27,8 @@ override WORKGROUP_SIZE: u32 = 256u;
|
|
|
27
27
|
|
|
28
28
|
@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
|
|
29
29
|
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
30
|
-
let
|
|
30
|
+
let dispatch_stride = max(u._pad1, 1u);
|
|
31
|
+
let idx = gid.y * dispatch_stride + gid.x;
|
|
31
32
|
if (idx >= u.size) {
|
|
32
33
|
return;
|
|
33
34
|
}
|
|
@@ -25,7 +25,8 @@ override WORKGROUP_SIZE_VEC4: u32 = 64u;
|
|
|
25
25
|
// Vectorized version for better throughput
|
|
26
26
|
@compute @workgroup_size(WORKGROUP_SIZE_VEC4, 1, 1)
|
|
27
27
|
fn add_vec4(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
28
|
-
let
|
|
28
|
+
let dispatch_stride = max(u._pad1, 4u);
|
|
29
|
+
let idx = gid.y * dispatch_stride + gid.x * 4u;
|
|
29
30
|
let size = u.size;
|
|
30
31
|
|
|
31
32
|
if (idx >= size) {
|
|
@@ -23,7 +23,8 @@ override WORKGROUP_SIZE_VEC4: u32 = 64u;
|
|
|
23
23
|
// Vectorized version for better throughput
|
|
24
24
|
@compute @workgroup_size(WORKGROUP_SIZE_VEC4, 1, 1)
|
|
25
25
|
fn add_vec4(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
26
|
-
let
|
|
26
|
+
let dispatch_stride = max(u._pad1, 4u);
|
|
27
|
+
let idx = gid.y * dispatch_stride + gid.x * 4u;
|
|
27
28
|
let size = u.size;
|
|
28
29
|
|
|
29
30
|
if (idx >= size) {
|