@simulatte/doppler 0.1.4 → 0.1.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +26 -10
- package/package.json +30 -6
- package/src/client/doppler-api.browser.d.ts +1 -0
- package/src/client/doppler-api.browser.js +288 -0
- package/src/client/doppler-api.js +1 -1
- package/src/client/doppler-provider/types.js +1 -1
- package/src/config/execution-contract-check.d.ts +33 -0
- package/src/config/execution-contract-check.js +72 -0
- package/src/config/execution-v0-contract-check.d.ts +94 -0
- package/src/config/execution-v0-contract-check.js +251 -0
- package/src/config/execution-v0-graph-contract-check.d.ts +20 -0
- package/src/config/execution-v0-graph-contract-check.js +64 -0
- package/src/config/kernel-path-contract-check.d.ts +76 -0
- package/src/config/kernel-path-contract-check.js +479 -0
- package/src/config/kernel-path-loader.d.ts +16 -0
- package/src/config/kernel-path-loader.js +54 -0
- package/src/config/kernels/kernel-ref-digests.js +39 -27
- package/src/config/kernels/registry.json +598 -2
- package/src/config/loader.js +81 -48
- package/src/config/merge-contract-check.d.ts +16 -0
- package/src/config/merge-contract-check.js +321 -0
- package/src/config/merge-helpers.d.ts +58 -0
- package/src/config/merge-helpers.js +54 -0
- package/src/config/merge.js +21 -6
- package/src/config/presets/models/janus-text.json +2 -0
- package/src/config/presets/models/qwen3.json +9 -2
- package/src/config/presets/models/transformer.json +5 -0
- package/src/config/quantization-contract-check.d.ts +12 -0
- package/src/config/quantization-contract-check.js +91 -0
- package/src/config/required-inference-fields-contract-check.d.ts +24 -0
- package/src/config/required-inference-fields-contract-check.js +237 -0
- package/src/config/schema/browser-suite-metrics.schema.d.ts +17 -0
- package/src/config/schema/browser-suite-metrics.schema.js +46 -0
- package/src/config/schema/conversion-report.schema.d.ts +40 -0
- package/src/config/schema/conversion-report.schema.js +108 -0
- package/src/config/schema/doppler.schema.js +12 -18
- package/src/config/schema/index.d.ts +22 -0
- package/src/config/schema/index.js +18 -0
- package/src/config/schema/inference-defaults.schema.js +3 -0
- package/src/config/schema/inference.schema.d.ts +9 -0
- package/src/config/schema/kernel-path.schema.d.ts +6 -0
- package/src/config/schema/manifest.schema.d.ts +6 -0
- package/src/config/schema/manifest.schema.js +3 -0
- package/src/converter/core.d.ts +10 -0
- package/src/converter/core.js +27 -2
- package/src/converter/parsers/diffusion.js +63 -3
- package/src/converter/rope-config.js +42 -0
- package/src/gpu/device.js +58 -0
- package/src/gpu/kernels/attention.js +98 -0
- package/src/gpu/kernels/bias_add.wgsl +8 -6
- package/src/gpu/kernels/bias_add_f16.wgsl +8 -5
- package/src/gpu/kernels/conv2d.js +1 -1
- package/src/gpu/kernels/conv2d.wgsl +7 -8
- package/src/gpu/kernels/conv2d_f16.wgsl +7 -8
- package/src/gpu/kernels/depthwise_conv2d.d.ts +29 -0
- package/src/gpu/kernels/depthwise_conv2d.js +99 -0
- package/src/gpu/kernels/depthwise_conv2d.wgsl +55 -0
- package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +59 -0
- package/src/gpu/kernels/grouped_pointwise_conv2d.d.ts +27 -0
- package/src/gpu/kernels/grouped_pointwise_conv2d.js +93 -0
- package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +44 -0
- package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +48 -0
- package/src/gpu/kernels/index.d.ts +30 -0
- package/src/gpu/kernels/index.js +25 -0
- package/src/gpu/kernels/matmul.js +25 -0
- package/src/gpu/kernels/pixel_shuffle.js +1 -1
- package/src/gpu/kernels/pixel_shuffle.wgsl +4 -5
- package/src/gpu/kernels/pixel_shuffle_f16.wgsl +4 -5
- package/src/gpu/kernels/relu.d.ts +18 -0
- package/src/gpu/kernels/relu.js +58 -0
- package/src/gpu/kernels/relu.wgsl +22 -0
- package/src/gpu/kernels/relu_f16.wgsl +24 -0
- package/src/gpu/kernels/repeat_channels.d.ts +21 -0
- package/src/gpu/kernels/repeat_channels.js +60 -0
- package/src/gpu/kernels/repeat_channels.wgsl +28 -0
- package/src/gpu/kernels/repeat_channels_f16.wgsl +30 -0
- package/src/gpu/kernels/residual.js +44 -8
- package/src/gpu/kernels/residual.wgsl +6 -3
- package/src/gpu/kernels/residual_f16.wgsl +2 -1
- package/src/gpu/kernels/residual_f16_vec4.wgsl +2 -1
- package/src/gpu/kernels/residual_vec4.wgsl +2 -1
- package/src/gpu/kernels/rmsnorm.js +58 -6
- package/src/gpu/kernels/rmsnorm.wgsl +14 -6
- package/src/gpu/kernels/rmsnorm_f16.wgsl +10 -2
- package/src/gpu/kernels/rope.d.ts +2 -0
- package/src/gpu/kernels/rope.js +11 -1
- package/src/gpu/kernels/rope.wgsl +56 -40
- package/src/gpu/kernels/sana_linear_attention.d.ts +27 -0
- package/src/gpu/kernels/sana_linear_attention.js +121 -0
- package/src/gpu/kernels/sana_linear_attention_apply.wgsl +43 -0
- package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +46 -0
- package/src/gpu/kernels/sana_linear_attention_summary.wgsl +51 -0
- package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +53 -0
- package/src/gpu/kernels/silu.d.ts +1 -0
- package/src/gpu/kernels/silu.js +32 -14
- package/src/gpu/kernels/silu.wgsl +19 -9
- package/src/gpu/kernels/silu_f16.wgsl +19 -9
- package/src/gpu/kernels/transpose.js +15 -2
- package/src/gpu/kernels/transpose.wgsl +5 -6
- package/src/gpu/kernels/upsample2d.js +2 -1
- package/src/gpu/kernels/upsample2d.wgsl +6 -9
- package/src/gpu/kernels/upsample2d_f16.wgsl +6 -9
- package/src/gpu/kernels/utils.js +16 -1
- package/src/index-browser.d.ts +1 -1
- package/src/index-browser.js +2 -2
- package/src/index.js +1 -1
- package/src/inference/browser-harness.js +109 -23
- package/src/inference/pipelines/diffusion/init.js +14 -0
- package/src/inference/pipelines/diffusion/pipeline.js +215 -77
- package/src/inference/pipelines/diffusion/sana-transformer.d.ts +53 -0
- package/src/inference/pipelines/diffusion/sana-transformer.js +738 -0
- package/src/inference/pipelines/diffusion/scheduler.d.ts +17 -1
- package/src/inference/pipelines/diffusion/scheduler.js +91 -3
- package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +11 -4
- package/src/inference/pipelines/diffusion/text-encoder-gpu.js +282 -0
- package/src/inference/pipelines/diffusion/text-encoder.js +18 -1
- package/src/inference/pipelines/diffusion/types.d.ts +4 -0
- package/src/inference/pipelines/diffusion/vae.js +782 -78
- package/src/inference/pipelines/text/attention/record.js +11 -2
- package/src/inference/pipelines/text/attention/run.js +11 -2
- package/src/inference/pipelines/text/chat-format.js +25 -1
- package/src/inference/pipelines/text/config.d.ts +9 -0
- package/src/inference/pipelines/text/config.js +69 -2
- package/src/inference/pipelines/text/execution-plan.js +23 -31
- package/src/inference/pipelines/text/execution-v0.js +43 -95
- package/src/inference/pipelines/text/ffn/standard.js +3 -0
- package/src/inference/pipelines/text/init.d.ts +4 -0
- package/src/inference/pipelines/text/init.js +56 -9
- package/src/inference/pipelines/text/layer.js +11 -0
- package/src/inference/pipelines/text.js +4 -0
- package/src/inference/tokenizers/bundled.js +156 -33
- package/src/rules/execution-rules-contract-check.d.ts +17 -0
- package/src/rules/execution-rules-contract-check.js +245 -0
- package/src/rules/kernels/depthwise-conv2d.rules.json +6 -0
- package/src/rules/kernels/grouped-pointwise-conv2d.rules.json +6 -0
- package/src/rules/kernels/relu.rules.json +6 -0
- package/src/rules/kernels/repeat-channels.rules.json +6 -0
- package/src/rules/kernels/sana-linear-attention.rules.json +6 -0
- package/src/rules/layer-pattern-contract-check.d.ts +17 -0
- package/src/rules/layer-pattern-contract-check.js +231 -0
- package/src/rules/rule-registry.d.ts +28 -0
- package/src/rules/rule-registry.js +38 -0
- package/src/rules/tooling/command-runtime.rules.json +18 -0
- package/src/tooling/command-api.d.ts +27 -1
- package/src/tooling/command-api.js +142 -3
- package/src/tooling/conversion-config-materializer.d.ts +24 -0
- package/src/tooling/conversion-config-materializer.js +99 -0
- package/src/tooling/lean-execution-contract-runner.d.ts +43 -0
- package/src/tooling/lean-execution-contract-runner.js +158 -0
- package/src/tooling/node-browser-command-runner.d.ts +4 -0
- package/src/tooling/node-browser-command-runner.js +58 -3
- package/src/tooling/node-command-runner.js +15 -0
- package/src/tooling/node-convert.d.ts +10 -0
- package/src/tooling/node-converter.js +59 -0
- package/src/tooling/node-webgpu.js +11 -89
- package/src/training/checkpoint-watch.d.ts +7 -0
- package/src/training/checkpoint-watch.js +106 -0
- package/src/training/checkpoint.d.ts +6 -1
- package/src/training/checkpoint.js +12 -2
- package/src/training/distillation/artifacts.d.ts +71 -0
- package/src/training/distillation/artifacts.js +132 -0
- package/src/training/distillation/checkpoint-watch.d.ts +10 -0
- package/src/training/distillation/checkpoint-watch.js +57 -0
- package/src/training/distillation/dataset.d.ts +59 -0
- package/src/training/distillation/dataset.js +337 -0
- package/src/training/distillation/eval.d.ts +34 -0
- package/src/training/distillation/eval.js +310 -0
- package/src/training/distillation/index.d.ts +29 -0
- package/src/training/distillation/index.js +29 -0
- package/src/training/distillation/runtime.d.ts +20 -0
- package/src/training/distillation/runtime.js +121 -0
- package/src/training/distillation/scoreboard.d.ts +6 -0
- package/src/training/distillation/scoreboard.js +8 -0
- package/src/training/distillation/stage-a.d.ts +45 -0
- package/src/training/distillation/stage-a.js +338 -0
- package/src/training/distillation/stage-b.d.ts +24 -0
- package/src/training/distillation/stage-b.js +20 -0
- package/src/training/index.d.ts +10 -0
- package/src/training/index.js +10 -0
- package/src/training/lora-pipeline.d.ts +40 -0
- package/src/training/lora-pipeline.js +796 -0
- package/src/training/operator-artifacts.d.ts +62 -0
- package/src/training/operator-artifacts.js +140 -0
- package/src/training/operator-command.d.ts +5 -0
- package/src/training/operator-command.js +453 -0
- package/src/training/operator-eval.d.ts +48 -0
- package/src/training/operator-eval.js +230 -0
- package/src/training/operator-scoreboard.d.ts +5 -0
- package/src/training/operator-scoreboard.js +44 -0
- package/src/training/runner.d.ts +52 -0
- package/src/training/runner.js +29 -4
- package/src/training/suite.d.ts +112 -0
- package/src/training/suite.js +9 -9
- package/src/training/workloads.d.ts +164 -0
- package/src/training/workloads.js +539 -0
- package/src/version.d.ts +2 -0
- package/src/version.js +2 -0
- package/tools/convert-safetensors-node.js +47 -0
- package/tools/doppler-cli.js +252 -41
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
override WORKGROUP_SIZE: u32 = 256u;
|
|
2
|
+
|
|
3
|
+
struct Uniforms {
|
|
4
|
+
num_heads: u32,
|
|
5
|
+
head_dim: u32,
|
|
6
|
+
num_tokens: u32,
|
|
7
|
+
hidden_size: u32,
|
|
8
|
+
eps: f32,
|
|
9
|
+
_pad0: u32,
|
|
10
|
+
_pad1: u32,
|
|
11
|
+
_pad2: u32,
|
|
12
|
+
}
|
|
13
|
+
|
|
14
|
+
@group(0) @binding(0) var<uniform> u: Uniforms;
|
|
15
|
+
@group(0) @binding(1) var<storage, read> query: array<f32>;
|
|
16
|
+
@group(0) @binding(2) var<storage, read> summary: array<f32>;
|
|
17
|
+
@group(0) @binding(3) var<storage, read_write> output: array<f32>;
|
|
18
|
+
|
|
19
|
+
@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
|
|
20
|
+
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
21
|
+
let hidden = gid.x;
|
|
22
|
+
let token = gid.y;
|
|
23
|
+
if (token >= u.num_tokens || hidden >= u.hidden_size) {
|
|
24
|
+
return;
|
|
25
|
+
}
|
|
26
|
+
|
|
27
|
+
let idx = token * u.hidden_size + hidden;
|
|
28
|
+
let head = hidden / u.head_dim;
|
|
29
|
+
let dim = hidden - head * u.head_dim;
|
|
30
|
+
let rows_per_head = u.head_dim + 1u;
|
|
31
|
+
let head_offset = head * rows_per_head * u.head_dim;
|
|
32
|
+
let hidden_base = head * u.head_dim;
|
|
33
|
+
|
|
34
|
+
var numerator: f32 = 0.0;
|
|
35
|
+
var denominator: f32 = 0.0;
|
|
36
|
+
for (var i: u32 = 0u; i < u.head_dim; i = i + 1u) {
|
|
37
|
+
let q_value = max(query[token * u.hidden_size + hidden_base + i], 0.0);
|
|
38
|
+
numerator = numerator + summary[head_offset + dim * u.head_dim + i] * q_value;
|
|
39
|
+
denominator = denominator + summary[head_offset + u.head_dim * u.head_dim + i] * q_value;
|
|
40
|
+
}
|
|
41
|
+
|
|
42
|
+
output[idx] = numerator / (denominator + u.eps);
|
|
43
|
+
}
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
enable f16;
|
|
2
|
+
|
|
3
|
+
override WORKGROUP_SIZE: u32 = 256u;
|
|
4
|
+
|
|
5
|
+
struct Uniforms {
|
|
6
|
+
num_heads: u32,
|
|
7
|
+
head_dim: u32,
|
|
8
|
+
num_tokens: u32,
|
|
9
|
+
hidden_size: u32,
|
|
10
|
+
eps: f32,
|
|
11
|
+
_pad0: u32,
|
|
12
|
+
_pad1: u32,
|
|
13
|
+
_pad2: u32,
|
|
14
|
+
}
|
|
15
|
+
|
|
16
|
+
@group(0) @binding(0) var<uniform> u: Uniforms;
|
|
17
|
+
@group(0) @binding(1) var<storage, read> query: array<f16>;
|
|
18
|
+
@group(0) @binding(2) var<storage, read> summary: array<f32>;
|
|
19
|
+
@group(0) @binding(3) var<storage, read_write> output: array<f16>;
|
|
20
|
+
|
|
21
|
+
@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
|
|
22
|
+
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
23
|
+
let hidden = gid.x;
|
|
24
|
+
let token = gid.y;
|
|
25
|
+
if (token >= u.num_tokens || hidden >= u.hidden_size) {
|
|
26
|
+
return;
|
|
27
|
+
}
|
|
28
|
+
|
|
29
|
+
let idx = token * u.hidden_size + hidden;
|
|
30
|
+
let head = hidden / u.head_dim;
|
|
31
|
+
let dim = hidden - head * u.head_dim;
|
|
32
|
+
let rows_per_head = u.head_dim + 1u;
|
|
33
|
+
let head_offset = head * rows_per_head * u.head_dim;
|
|
34
|
+
let hidden_base = head * u.head_dim;
|
|
35
|
+
|
|
36
|
+
var numerator: f32 = 0.0;
|
|
37
|
+
var denominator: f32 = 0.0;
|
|
38
|
+
for (var i: u32 = 0u; i < u.head_dim; i = i + 1u) {
|
|
39
|
+
let q_value = max(f32(query[token * u.hidden_size + hidden_base + i]), 0.0);
|
|
40
|
+
numerator = numerator + summary[head_offset + dim * u.head_dim + i] * q_value;
|
|
41
|
+
denominator = denominator + summary[head_offset + u.head_dim * u.head_dim + i] * q_value;
|
|
42
|
+
}
|
|
43
|
+
|
|
44
|
+
let result = numerator / (denominator + u.eps);
|
|
45
|
+
output[idx] = f16(clamp(result, -65504.0, 65504.0));
|
|
46
|
+
}
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
override WORKGROUP_SIZE: u32 = 256u;
|
|
2
|
+
|
|
3
|
+
struct Uniforms {
|
|
4
|
+
num_heads: u32,
|
|
5
|
+
head_dim: u32,
|
|
6
|
+
num_tokens: u32,
|
|
7
|
+
hidden_size: u32,
|
|
8
|
+
_pad0: u32,
|
|
9
|
+
_pad1: u32,
|
|
10
|
+
}
|
|
11
|
+
|
|
12
|
+
@group(0) @binding(0) var<uniform> u: Uniforms;
|
|
13
|
+
@group(0) @binding(1) var<storage, read> query: array<f32>;
|
|
14
|
+
@group(0) @binding(2) var<storage, read> key: array<f32>;
|
|
15
|
+
@group(0) @binding(3) var<storage, read> value: array<f32>;
|
|
16
|
+
@group(0) @binding(4) var<storage, read_write> summary: array<f32>;
|
|
17
|
+
|
|
18
|
+
@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
|
|
19
|
+
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
20
|
+
let idx = gid.x;
|
|
21
|
+
let rows_per_head = u.head_dim + 1u;
|
|
22
|
+
let head_span = rows_per_head * u.head_dim;
|
|
23
|
+
let total = u.num_heads * head_span;
|
|
24
|
+
if (idx >= total) {
|
|
25
|
+
return;
|
|
26
|
+
}
|
|
27
|
+
|
|
28
|
+
let head = idx / head_span;
|
|
29
|
+
let rem = idx - head * head_span;
|
|
30
|
+
let row = rem / u.head_dim;
|
|
31
|
+
let col = rem - row * u.head_dim;
|
|
32
|
+
let hidden_base = head * u.head_dim;
|
|
33
|
+
|
|
34
|
+
var acc: f32 = 0.0;
|
|
35
|
+
for (var token: u32 = 0u; token < u.num_tokens; token = token + 1u) {
|
|
36
|
+
let query_value = query[token * u.hidden_size + hidden_base + col];
|
|
37
|
+
let key_idx = token * u.hidden_size + hidden_base + col;
|
|
38
|
+
let key_value = max(key[key_idx], 0.0);
|
|
39
|
+
let value_value = select(
|
|
40
|
+
value[token * u.hidden_size + hidden_base + row],
|
|
41
|
+
1.0,
|
|
42
|
+
row == u.head_dim
|
|
43
|
+
);
|
|
44
|
+
if (u.hidden_size == 0u) {
|
|
45
|
+
acc = acc + query_value;
|
|
46
|
+
}
|
|
47
|
+
acc = acc + value_value * key_value;
|
|
48
|
+
}
|
|
49
|
+
|
|
50
|
+
summary[idx] = acc;
|
|
51
|
+
}
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
enable f16;
|
|
2
|
+
|
|
3
|
+
override WORKGROUP_SIZE: u32 = 256u;
|
|
4
|
+
|
|
5
|
+
struct Uniforms {
|
|
6
|
+
num_heads: u32,
|
|
7
|
+
head_dim: u32,
|
|
8
|
+
num_tokens: u32,
|
|
9
|
+
hidden_size: u32,
|
|
10
|
+
_pad0: u32,
|
|
11
|
+
_pad1: u32,
|
|
12
|
+
}
|
|
13
|
+
|
|
14
|
+
@group(0) @binding(0) var<uniform> u: Uniforms;
|
|
15
|
+
@group(0) @binding(1) var<storage, read> query: array<f16>;
|
|
16
|
+
@group(0) @binding(2) var<storage, read> key: array<f16>;
|
|
17
|
+
@group(0) @binding(3) var<storage, read> value: array<f16>;
|
|
18
|
+
@group(0) @binding(4) var<storage, read_write> summary: array<f32>;
|
|
19
|
+
|
|
20
|
+
@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
|
|
21
|
+
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
22
|
+
let idx = gid.x;
|
|
23
|
+
let rows_per_head = u.head_dim + 1u;
|
|
24
|
+
let head_span = rows_per_head * u.head_dim;
|
|
25
|
+
let total = u.num_heads * head_span;
|
|
26
|
+
if (idx >= total) {
|
|
27
|
+
return;
|
|
28
|
+
}
|
|
29
|
+
|
|
30
|
+
let head = idx / head_span;
|
|
31
|
+
let rem = idx - head * head_span;
|
|
32
|
+
let row = rem / u.head_dim;
|
|
33
|
+
let col = rem - row * u.head_dim;
|
|
34
|
+
let hidden_base = head * u.head_dim;
|
|
35
|
+
|
|
36
|
+
var acc: f32 = 0.0;
|
|
37
|
+
for (var token: u32 = 0u; token < u.num_tokens; token = token + 1u) {
|
|
38
|
+
let query_value = f32(query[token * u.hidden_size + hidden_base + col]);
|
|
39
|
+
let key_idx = token * u.hidden_size + hidden_base + col;
|
|
40
|
+
let key_value = max(f32(key[key_idx]), 0.0);
|
|
41
|
+
let value_value = select(
|
|
42
|
+
f32(value[token * u.hidden_size + hidden_base + row]),
|
|
43
|
+
1.0,
|
|
44
|
+
row == u.head_dim
|
|
45
|
+
);
|
|
46
|
+
if (u.hidden_size == 0u) {
|
|
47
|
+
acc = acc + query_value;
|
|
48
|
+
}
|
|
49
|
+
acc = acc + value_value * key_value;
|
|
50
|
+
}
|
|
51
|
+
|
|
52
|
+
summary[idx] = acc;
|
|
53
|
+
}
|
|
@@ -16,6 +16,7 @@ export interface SiLUOptions extends OutputBufferOptions {
|
|
|
16
16
|
size?: number | null;
|
|
17
17
|
gate?: Tensor | null;
|
|
18
18
|
gateActivation?: 'silu' | 'sigmoid';
|
|
19
|
+
inputActivation?: 'silu' | 'identity';
|
|
19
20
|
useVec4?: boolean;
|
|
20
21
|
biasOffset?: number;
|
|
21
22
|
swigluLimit: number | null;
|
package/src/gpu/kernels/silu.js
CHANGED
|
@@ -47,6 +47,18 @@ function createSiLUBindGroupEntries(uniformBuffer, input, output, gate) {
|
|
|
47
47
|
];
|
|
48
48
|
}
|
|
49
49
|
|
|
50
|
+
function planSiLUDispatch(device, size, useVec4) {
|
|
51
|
+
const maxPerDim = Number.isFinite(device?.limits?.maxComputeWorkgroupsPerDimension)
|
|
52
|
+
? device.limits.maxComputeWorkgroupsPerDimension
|
|
53
|
+
: 65535;
|
|
54
|
+
const laneWidth = useVec4 ? 4 : 1;
|
|
55
|
+
const chunkSize = maxPerDim * WORKGROUP_SIZES.DEFAULT * laneWidth;
|
|
56
|
+
const dispatchStride = Math.min(size, chunkSize);
|
|
57
|
+
const x = Math.min(maxPerDim, Math.ceil(dispatchStride / (WORKGROUP_SIZES.DEFAULT * laneWidth)));
|
|
58
|
+
const y = Math.max(1, Math.ceil(size / chunkSize));
|
|
59
|
+
return { dispatchStride, workgroups: [x, y, 1] };
|
|
60
|
+
}
|
|
61
|
+
|
|
50
62
|
|
|
51
63
|
export async function runSiLU(
|
|
52
64
|
input,
|
|
@@ -60,6 +72,7 @@ export async function runSiLU(
|
|
|
60
72
|
useVec4 = false,
|
|
61
73
|
swigluLimit,
|
|
62
74
|
gateActivation = 'silu',
|
|
75
|
+
inputActivation = 'silu',
|
|
63
76
|
} = options;
|
|
64
77
|
const resolvedSwigluLimit = resolveSwigluLimit(swigluLimit, 'SiLU');
|
|
65
78
|
|
|
@@ -74,14 +87,17 @@ export async function runSiLU(
|
|
|
74
87
|
useSplit: false,
|
|
75
88
|
useRowsplit: false,
|
|
76
89
|
});
|
|
77
|
-
const constants =
|
|
78
|
-
|
|
79
|
-
:
|
|
90
|
+
const constants = {
|
|
91
|
+
...(overrides || {}),
|
|
92
|
+
...(gate && gateActivation === 'sigmoid' ? { GATE_USE_SIGMOID: true } : {}),
|
|
93
|
+
...(inputActivation === 'identity' ? { INPUT_USE_IDENTITY: true } : {}),
|
|
94
|
+
};
|
|
80
95
|
const pipeline = await getPipelineFast('silu', variant, null, constants);
|
|
81
96
|
|
|
82
97
|
const inferredSize = size || (input.buffer.size / bytesPerElement);
|
|
83
98
|
const outputSize = inferredSize * bytesPerElement;
|
|
84
99
|
const output = outputBuffer || acquireBuffer(outputSize, undefined, 'silu_output');
|
|
100
|
+
const dispatchPlan = planSiLUDispatch(device, inferredSize, useVec4);
|
|
85
101
|
|
|
86
102
|
// Create uniform buffer
|
|
87
103
|
const uniformBuffer = createUniformBufferWithView(
|
|
@@ -89,7 +105,7 @@ export async function runSiLU(
|
|
|
89
105
|
16,
|
|
90
106
|
(view) => {
|
|
91
107
|
view.setUint32(0, inferredSize, true);
|
|
92
|
-
view.setUint32(4,
|
|
108
|
+
view.setUint32(4, dispatchPlan.dispatchStride, true);
|
|
93
109
|
view.setFloat32(8, gate ? resolvedSwigluLimit : 0, true);
|
|
94
110
|
view.setFloat32(12, 0, true);
|
|
95
111
|
},
|
|
@@ -106,8 +122,7 @@ export async function runSiLU(
|
|
|
106
122
|
entries,
|
|
107
123
|
});
|
|
108
124
|
|
|
109
|
-
|
|
110
|
-
dispatch(device, pipeline, bindGroup, workgroups, 'silu');
|
|
125
|
+
dispatch(device, pipeline, bindGroup, dispatchPlan.workgroups, 'silu');
|
|
111
126
|
|
|
112
127
|
uniformBuffer.destroy();
|
|
113
128
|
|
|
@@ -215,7 +230,7 @@ export async function runSiLURowSplit(
|
|
|
215
230
|
],
|
|
216
231
|
});
|
|
217
232
|
|
|
218
|
-
const workgroups = Math.ceil(
|
|
233
|
+
const workgroups = [Math.ceil(dim / WORKGROUP_SIZES.DEFAULT), numTokens, 1];
|
|
219
234
|
dispatch(device, pipeline, bindGroup, workgroups, 'silu_rowsplit');
|
|
220
235
|
|
|
221
236
|
uniformBuffer.destroy();
|
|
@@ -269,7 +284,7 @@ export async function recordSiLURowSplit(
|
|
|
269
284
|
],
|
|
270
285
|
});
|
|
271
286
|
|
|
272
|
-
const workgroups = Math.ceil(
|
|
287
|
+
const workgroups = [Math.ceil(dim / WORKGROUP_SIZES.DEFAULT), numTokens, 1];
|
|
273
288
|
recordDispatch(recorder, pipeline, bindGroup, workgroups, 'silu_rowsplit');
|
|
274
289
|
|
|
275
290
|
return createTensor(output, input.dtype, [numTokens, dim], 'silu_rowsplit_output');
|
|
@@ -288,6 +303,7 @@ export async function recordSiLU(
|
|
|
288
303
|
outputBuffer = null,
|
|
289
304
|
swigluLimit,
|
|
290
305
|
gateActivation = 'silu',
|
|
306
|
+
inputActivation = 'silu',
|
|
291
307
|
} = options;
|
|
292
308
|
const resolvedSwigluLimit = resolveSwigluLimit(swigluLimit, 'SiLU');
|
|
293
309
|
|
|
@@ -302,14 +318,17 @@ export async function recordSiLU(
|
|
|
302
318
|
useSplit: false,
|
|
303
319
|
useRowsplit: false,
|
|
304
320
|
});
|
|
305
|
-
const constants =
|
|
306
|
-
|
|
307
|
-
:
|
|
321
|
+
const constants = {
|
|
322
|
+
...(overrides || {}),
|
|
323
|
+
...(gate && gateActivation === 'sigmoid' ? { GATE_USE_SIGMOID: true } : {}),
|
|
324
|
+
...(inputActivation === 'identity' ? { INPUT_USE_IDENTITY: true } : {}),
|
|
325
|
+
};
|
|
308
326
|
const pipeline = await getPipelineFast('silu', variant, null, constants);
|
|
309
327
|
|
|
310
328
|
const inferredSize = size || (input.buffer.size / bytesPerElement);
|
|
311
329
|
const outputSize = inferredSize * bytesPerElement;
|
|
312
330
|
const output = outputBuffer || acquireBuffer(outputSize, undefined, 'silu_output');
|
|
331
|
+
const dispatchPlan = planSiLUDispatch(device, inferredSize, false);
|
|
313
332
|
|
|
314
333
|
// Uniform buffer
|
|
315
334
|
const uniformBuffer = createUniformBufferWithView(
|
|
@@ -317,7 +336,7 @@ export async function recordSiLU(
|
|
|
317
336
|
16,
|
|
318
337
|
(view) => {
|
|
319
338
|
view.setUint32(0, inferredSize, true);
|
|
320
|
-
view.setUint32(4,
|
|
339
|
+
view.setUint32(4, dispatchPlan.dispatchStride, true);
|
|
321
340
|
view.setFloat32(8, gate ? resolvedSwigluLimit : 0, true);
|
|
322
341
|
view.setFloat32(12, 0, true);
|
|
323
342
|
},
|
|
@@ -333,8 +352,7 @@ export async function recordSiLU(
|
|
|
333
352
|
entries,
|
|
334
353
|
});
|
|
335
354
|
|
|
336
|
-
|
|
337
|
-
recordDispatch(recorder, pipeline, bindGroup, workgroups, 'silu');
|
|
355
|
+
recordDispatch(recorder, pipeline, bindGroup, dispatchPlan.workgroups, 'silu');
|
|
338
356
|
|
|
339
357
|
return createTensor(output, input.dtype, [inferredSize], 'silu_output');
|
|
340
358
|
}
|
|
@@ -10,13 +10,14 @@
|
|
|
10
10
|
override WORKGROUP_SIZE: u32 = 256u;
|
|
11
11
|
override HAS_GATE: bool = false;
|
|
12
12
|
override GATE_USE_SIGMOID: bool = false;
|
|
13
|
+
override INPUT_USE_IDENTITY: bool = false;
|
|
13
14
|
override USE_SPLIT: bool = false;
|
|
14
15
|
override USE_VEC4: bool = false;
|
|
15
16
|
override USE_ROWSPLIT: bool = false;
|
|
16
17
|
|
|
17
18
|
struct Uniforms {
|
|
18
19
|
size: u32, // Total output elements
|
|
19
|
-
rowsplit_dim: u32, //
|
|
20
|
+
rowsplit_dim: u32, // Row-split dim or dispatch stride for non-row-split variants
|
|
20
21
|
clamp_max: f32, // SwiGLU clamp (0 = disabled)
|
|
21
22
|
_pad1: f32,
|
|
22
23
|
}
|
|
@@ -35,6 +36,10 @@ fn silu(x: f32) -> f32 {
|
|
|
35
36
|
return x * sigmoid(x);
|
|
36
37
|
}
|
|
37
38
|
|
|
39
|
+
fn apply_input_activation(x: f32) -> f32 {
|
|
40
|
+
return select(silu(x), x, INPUT_USE_IDENTITY);
|
|
41
|
+
}
|
|
42
|
+
|
|
38
43
|
fn clamp_swiglu(x: f32) -> f32 {
|
|
39
44
|
if (u.clamp_max <= 0.0) {
|
|
40
45
|
return x;
|
|
@@ -46,8 +51,9 @@ fn clamp_swiglu(x: f32) -> f32 {
|
|
|
46
51
|
fn main(
|
|
47
52
|
@builtin(global_invocation_id) global_id: vec3<u32>
|
|
48
53
|
) {
|
|
54
|
+
let dispatch_stride = max(u.rowsplit_dim, 1u);
|
|
49
55
|
if (USE_VEC4) {
|
|
50
|
-
let base_idx = global_id.x * 4u;
|
|
56
|
+
let base_idx = global_id.y * dispatch_stride + global_id.x * 4u;
|
|
51
57
|
if (base_idx >= u.size) {
|
|
52
58
|
return;
|
|
53
59
|
}
|
|
@@ -55,12 +61,12 @@ fn main(
|
|
|
55
61
|
let remaining = min(4u, u.size - base_idx);
|
|
56
62
|
for (var i: u32 = 0u; i < remaining; i = i + 1u) {
|
|
57
63
|
let x = input[base_idx + i];
|
|
58
|
-
output[base_idx + i] =
|
|
64
|
+
output[base_idx + i] = apply_input_activation(x);
|
|
59
65
|
}
|
|
60
66
|
return;
|
|
61
67
|
}
|
|
62
68
|
|
|
63
|
-
let idx = global_id.x;
|
|
69
|
+
let idx = global_id.y * dispatch_stride + global_id.x;
|
|
64
70
|
if (idx >= u.size) {
|
|
65
71
|
return;
|
|
66
72
|
}
|
|
@@ -70,12 +76,16 @@ fn main(
|
|
|
70
76
|
return;
|
|
71
77
|
}
|
|
72
78
|
let dim = u.rowsplit_dim;
|
|
73
|
-
let
|
|
74
|
-
let
|
|
79
|
+
let num_tokens = u.size / dim;
|
|
80
|
+
let token_idx = global_id.y;
|
|
81
|
+
let dim_idx = global_id.x;
|
|
82
|
+
if (token_idx >= num_tokens || dim_idx >= dim) {
|
|
83
|
+
return;
|
|
84
|
+
}
|
|
75
85
|
let row_base = token_idx * dim * 2u;
|
|
76
86
|
let g = input[row_base + dim_idx];
|
|
77
87
|
let up = input[row_base + dim + dim_idx];
|
|
78
|
-
output[
|
|
88
|
+
output[token_idx * dim + dim_idx] = clamp_swiglu(silu(g) * up);
|
|
79
89
|
return;
|
|
80
90
|
}
|
|
81
91
|
|
|
@@ -83,7 +93,7 @@ fn main(
|
|
|
83
93
|
let up = input[idx];
|
|
84
94
|
let g = gate[idx];
|
|
85
95
|
let gateAct = select(silu(g), sigmoid(g), GATE_USE_SIGMOID);
|
|
86
|
-
output[idx] = clamp_swiglu(gateAct * up);
|
|
96
|
+
output[idx] = clamp_swiglu(gateAct * apply_input_activation(up));
|
|
87
97
|
return;
|
|
88
98
|
}
|
|
89
99
|
|
|
@@ -95,5 +105,5 @@ fn main(
|
|
|
95
105
|
}
|
|
96
106
|
|
|
97
107
|
let x = input[idx];
|
|
98
|
-
output[idx] =
|
|
108
|
+
output[idx] = apply_input_activation(x);
|
|
99
109
|
}
|
|
@@ -9,13 +9,14 @@ enable f16;
|
|
|
9
9
|
override WORKGROUP_SIZE: u32 = 256u;
|
|
10
10
|
override HAS_GATE: bool = false;
|
|
11
11
|
override GATE_USE_SIGMOID: bool = false;
|
|
12
|
+
override INPUT_USE_IDENTITY: bool = false;
|
|
12
13
|
override USE_SPLIT: bool = false;
|
|
13
14
|
override USE_VEC4: bool = false;
|
|
14
15
|
override USE_ROWSPLIT: bool = false;
|
|
15
16
|
|
|
16
17
|
struct Uniforms {
|
|
17
18
|
size: u32, // Total output elements
|
|
18
|
-
rowsplit_dim: u32, //
|
|
19
|
+
rowsplit_dim: u32, // Row-split dim or dispatch stride for non-row-split variants
|
|
19
20
|
clamp_max: f32, // SwiGLU clamp (0 = disabled)
|
|
20
21
|
_pad1: f32,
|
|
21
22
|
}
|
|
@@ -34,6 +35,10 @@ fn silu(x: f32) -> f32 {
|
|
|
34
35
|
return x * sigmoid(x);
|
|
35
36
|
}
|
|
36
37
|
|
|
38
|
+
fn apply_input_activation(x: f32) -> f32 {
|
|
39
|
+
return select(silu(x), x, INPUT_USE_IDENTITY);
|
|
40
|
+
}
|
|
41
|
+
|
|
37
42
|
fn clamp_swiglu(x: f32) -> f32 {
|
|
38
43
|
if (u.clamp_max <= 0.0) {
|
|
39
44
|
return x;
|
|
@@ -45,8 +50,9 @@ fn clamp_swiglu(x: f32) -> f32 {
|
|
|
45
50
|
fn main(
|
|
46
51
|
@builtin(global_invocation_id) global_id: vec3<u32>
|
|
47
52
|
) {
|
|
53
|
+
let dispatch_stride = max(u.rowsplit_dim, 1u);
|
|
48
54
|
if (USE_VEC4) {
|
|
49
|
-
let base_idx = global_id.x * 4u;
|
|
55
|
+
let base_idx = global_id.y * dispatch_stride + global_id.x * 4u;
|
|
50
56
|
if (base_idx >= u.size) {
|
|
51
57
|
return;
|
|
52
58
|
}
|
|
@@ -54,12 +60,12 @@ fn main(
|
|
|
54
60
|
let remaining = min(4u, u.size - base_idx);
|
|
55
61
|
for (var i: u32 = 0u; i < remaining; i = i + 1u) {
|
|
56
62
|
let x = f32(input[base_idx + i]);
|
|
57
|
-
output[base_idx + i] = f16(
|
|
63
|
+
output[base_idx + i] = f16(apply_input_activation(x));
|
|
58
64
|
}
|
|
59
65
|
return;
|
|
60
66
|
}
|
|
61
67
|
|
|
62
|
-
let idx = global_id.x;
|
|
68
|
+
let idx = global_id.y * dispatch_stride + global_id.x;
|
|
63
69
|
if (idx >= u.size) {
|
|
64
70
|
return;
|
|
65
71
|
}
|
|
@@ -69,12 +75,16 @@ fn main(
|
|
|
69
75
|
return;
|
|
70
76
|
}
|
|
71
77
|
let dim = u.rowsplit_dim;
|
|
72
|
-
let
|
|
73
|
-
let
|
|
78
|
+
let num_tokens = u.size / dim;
|
|
79
|
+
let token_idx = global_id.y;
|
|
80
|
+
let dim_idx = global_id.x;
|
|
81
|
+
if (token_idx >= num_tokens || dim_idx >= dim) {
|
|
82
|
+
return;
|
|
83
|
+
}
|
|
74
84
|
let row_base = token_idx * dim * 2u;
|
|
75
85
|
let g = f32(input[row_base + dim_idx]);
|
|
76
86
|
let up = f32(input[row_base + dim + dim_idx]);
|
|
77
|
-
output[
|
|
87
|
+
output[token_idx * dim + dim_idx] = f16(clamp_swiglu(silu(g) * up));
|
|
78
88
|
return;
|
|
79
89
|
}
|
|
80
90
|
|
|
@@ -82,7 +92,7 @@ fn main(
|
|
|
82
92
|
let up = f32(input[idx]);
|
|
83
93
|
let g = f32(gate[idx]);
|
|
84
94
|
let gateAct = select(silu(g), sigmoid(g), GATE_USE_SIGMOID);
|
|
85
|
-
output[idx] = f16(clamp_swiglu(gateAct * up));
|
|
95
|
+
output[idx] = f16(clamp_swiglu(gateAct * apply_input_activation(up)));
|
|
86
96
|
return;
|
|
87
97
|
}
|
|
88
98
|
|
|
@@ -94,5 +104,5 @@ fn main(
|
|
|
94
104
|
}
|
|
95
105
|
|
|
96
106
|
let x = f32(input[idx]);
|
|
97
|
-
output[idx] = f16(
|
|
107
|
+
output[idx] = f16(apply_input_activation(x));
|
|
98
108
|
}
|
|
@@ -3,19 +3,32 @@ import { createTensor, dtypeBytes } from '../tensor.js';
|
|
|
3
3
|
import { WORKGROUP_SIZES } from './constants.js';
|
|
4
4
|
import { unifiedKernelWrapper } from './utils.js';
|
|
5
5
|
|
|
6
|
+
function planTransposeDispatch(target, cols) {
|
|
7
|
+
const device = target?.device;
|
|
8
|
+
const maxPerDim = Number.isFinite(device?.limits?.maxComputeWorkgroupsPerDimension)
|
|
9
|
+
? device.limits.maxComputeWorkgroupsPerDimension
|
|
10
|
+
: 65535;
|
|
11
|
+
const dispatchStride = Math.min(cols, maxPerDim * WORKGROUP_SIZES.DEFAULT);
|
|
12
|
+
return {
|
|
13
|
+
dispatchStride,
|
|
14
|
+
workgroups: [Math.ceil(dispatchStride / WORKGROUP_SIZES.DEFAULT), 1, 1],
|
|
15
|
+
};
|
|
16
|
+
}
|
|
17
|
+
|
|
6
18
|
async function _transpose(target, input, rows, cols, options = {}) {
|
|
7
19
|
const { outputBuffer = null } = options;
|
|
8
20
|
const bytesPerElement = dtypeBytes(input.dtype);
|
|
9
21
|
const outputSize = rows * cols * bytesPerElement;
|
|
10
22
|
const outputBuf = outputBuffer || acquireBuffer(outputSize, undefined, 'transpose_output');
|
|
23
|
+
const dispatchPlan = planTransposeDispatch(target, cols);
|
|
11
24
|
|
|
12
25
|
await unifiedKernelWrapper(
|
|
13
26
|
'transpose',
|
|
14
27
|
target,
|
|
15
28
|
'default',
|
|
16
29
|
[input, outputBuf],
|
|
17
|
-
{ rows, cols },
|
|
18
|
-
|
|
30
|
+
{ rows, cols, _pad0: dispatchPlan.dispatchStride, _pad1: 0 },
|
|
31
|
+
[dispatchPlan.workgroups[0], rows, 1]
|
|
19
32
|
);
|
|
20
33
|
|
|
21
34
|
return createTensor(outputBuf, input.dtype, [cols, rows], 'transpose_output');
|
|
@@ -19,14 +19,13 @@ struct Uniforms {
|
|
|
19
19
|
|
|
20
20
|
@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
|
|
21
21
|
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
22
|
-
let
|
|
23
|
-
let
|
|
24
|
-
|
|
22
|
+
let dispatch_stride = max(u._pad0, 1u);
|
|
23
|
+
let row = gid.y;
|
|
24
|
+
let col = gid.x + row * dispatch_stride;
|
|
25
|
+
if (row >= u.rows || col >= u.cols) {
|
|
25
26
|
return;
|
|
26
27
|
}
|
|
27
|
-
|
|
28
|
-
let row = idx / u.cols;
|
|
29
|
-
let col = idx % u.cols;
|
|
28
|
+
let idx = row * u.cols + col;
|
|
30
29
|
let out_idx = col * u.rows + row;
|
|
31
30
|
output[out_idx] = input[idx];
|
|
32
31
|
}
|
|
@@ -31,6 +31,7 @@ async function _upsample2d(target, input, options = {}) {
|
|
|
31
31
|
|
|
32
32
|
const outHeight = resolvedHeight * scale;
|
|
33
33
|
const outWidth = resolvedWidth * scale;
|
|
34
|
+
const outSpatial = outHeight * outWidth;
|
|
34
35
|
const bytesPerElement = dtypeBytes(input.dtype);
|
|
35
36
|
const outputSize = channels * outHeight * outWidth * bytesPerElement;
|
|
36
37
|
const output = outputBuffer || acquireBuffer(outputSize, undefined, 'upsample2d_output');
|
|
@@ -43,7 +44,7 @@ async function _upsample2d(target, input, options = {}) {
|
|
|
43
44
|
out_height: outHeight, out_width: outWidth, scale,
|
|
44
45
|
_pad0: 0, _pad1: 0,
|
|
45
46
|
},
|
|
46
|
-
Math.ceil(
|
|
47
|
+
[Math.ceil(outSpatial / WORKGROUP_SIZES.DEFAULT), channels, 1]
|
|
47
48
|
);
|
|
48
49
|
|
|
49
50
|
return createTensor(output, input.dtype, [channels, outHeight, outWidth], 'upsample2d_output');
|
|
@@ -19,19 +19,16 @@ struct Uniforms {
|
|
|
19
19
|
|
|
20
20
|
@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
|
|
21
21
|
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
22
|
-
let idx = gid.x;
|
|
23
22
|
let out_spatial = u.out_height * u.out_width;
|
|
24
|
-
let
|
|
25
|
-
|
|
23
|
+
let spatial_idx = gid.x;
|
|
24
|
+
let channel = gid.y;
|
|
25
|
+
if (spatial_idx >= out_spatial || channel >= u.channels) {
|
|
26
26
|
return;
|
|
27
27
|
}
|
|
28
|
-
|
|
29
|
-
let
|
|
30
|
-
let rem = idx - channel * out_spatial;
|
|
31
|
-
let out_y = rem / u.out_width;
|
|
32
|
-
let out_x = rem - out_y * u.out_width;
|
|
28
|
+
let out_y = spatial_idx / u.out_width;
|
|
29
|
+
let out_x = spatial_idx - out_y * u.out_width;
|
|
33
30
|
let in_y = out_y / u.scale;
|
|
34
31
|
let in_x = out_x / u.scale;
|
|
35
32
|
let in_idx = (channel * u.in_height + in_y) * u.in_width + in_x;
|
|
36
|
-
output[
|
|
33
|
+
output[channel * out_spatial + spatial_idx] = input[in_idx];
|
|
37
34
|
}
|
|
@@ -23,19 +23,16 @@ struct Uniforms {
|
|
|
23
23
|
|
|
24
24
|
@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
|
|
25
25
|
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
26
|
-
let idx = gid.x;
|
|
27
26
|
let out_spatial = u.out_height * u.out_width;
|
|
28
|
-
let
|
|
29
|
-
|
|
27
|
+
let spatial_idx = gid.x;
|
|
28
|
+
let channel = gid.y;
|
|
29
|
+
if (spatial_idx >= out_spatial || channel >= u.channels) {
|
|
30
30
|
return;
|
|
31
31
|
}
|
|
32
|
-
|
|
33
|
-
let
|
|
34
|
-
let rem = idx - channel * out_spatial;
|
|
35
|
-
let out_y = rem / u.out_width;
|
|
36
|
-
let out_x = rem - out_y * u.out_width;
|
|
32
|
+
let out_y = spatial_idx / u.out_width;
|
|
33
|
+
let out_x = spatial_idx - out_y * u.out_width;
|
|
37
34
|
let in_y = out_y / u.scale;
|
|
38
35
|
let in_x = out_x / u.scale;
|
|
39
36
|
let in_idx = (channel * u.in_height + in_y) * u.in_width + in_x;
|
|
40
|
-
output[
|
|
37
|
+
output[channel * out_spatial + spatial_idx] = input[in_idx];
|
|
41
38
|
}
|
package/src/gpu/kernels/utils.js
CHANGED
|
@@ -116,9 +116,24 @@ export async function unifiedKernelWrapper(opName, target, variant, bindings, un
|
|
|
116
116
|
index = config.variantMetadata.outputBinding;
|
|
117
117
|
}
|
|
118
118
|
|
|
119
|
+
const buffer = binding?.buffer || binding;
|
|
120
|
+
const isGpuBuffer = buffer && (
|
|
121
|
+
typeof GPUBuffer === 'undefined'
|
|
122
|
+
? true
|
|
123
|
+
: buffer instanceof GPUBuffer
|
|
124
|
+
);
|
|
125
|
+
if (!isGpuBuffer) {
|
|
126
|
+
const bindingLabel = binding?.label ?? buffer?.label ?? 'unknown';
|
|
127
|
+
const bufferType = buffer === null ? 'null' : buffer === undefined ? 'undefined' : buffer.constructor?.name || typeof buffer;
|
|
128
|
+
throw new Error(
|
|
129
|
+
`Kernel "${opName}/${variant}" binding "${bindingConfig.name}" (index ${index}) requires a GPUBuffer ` +
|
|
130
|
+
`(label=${bindingLabel}, type=${bufferType}).`
|
|
131
|
+
);
|
|
132
|
+
}
|
|
133
|
+
|
|
119
134
|
bindGroupEntries.push({
|
|
120
135
|
binding: index,
|
|
121
|
-
resource: { buffer
|
|
136
|
+
resource: { buffer }
|
|
122
137
|
});
|
|
123
138
|
}
|
|
124
139
|
|
package/src/index-browser.d.ts
CHANGED