@simulatte/doppler 0.1.3 → 0.1.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +11 -5
- package/package.json +27 -4
- package/src/client/doppler-api.browser.d.ts +1 -0
- package/src/client/doppler-api.browser.js +288 -0
- package/src/client/doppler-api.d.ts +80 -0
- package/src/client/doppler-api.js +298 -0
- package/src/client/doppler-provider/types.js +1 -1
- package/src/client/doppler-registry.d.ts +23 -0
- package/src/client/doppler-registry.js +88 -0
- package/src/client/doppler-registry.json +39 -0
- package/src/config/execution-contract-check.d.ts +82 -0
- package/src/config/execution-contract-check.js +317 -0
- package/src/config/execution-v0-contract-check.d.ts +94 -0
- package/src/config/execution-v0-contract-check.js +251 -0
- package/src/config/execution-v0-graph-contract-check.d.ts +20 -0
- package/src/config/execution-v0-graph-contract-check.js +64 -0
- package/src/config/kernel-path-contract-check.d.ts +76 -0
- package/src/config/kernel-path-contract-check.js +479 -0
- package/src/config/kernel-path-loader.d.ts +16 -0
- package/src/config/kernel-path-loader.js +54 -0
- package/src/config/kernels/kernel-ref-digests.js +12 -0
- package/src/config/kernels/registry.json +556 -0
- package/src/config/loader.js +90 -67
- package/src/config/merge-contract-check.d.ts +16 -0
- package/src/config/merge-contract-check.js +321 -0
- package/src/config/merge-helpers.d.ts +58 -0
- package/src/config/merge-helpers.js +54 -0
- package/src/config/merge.js +3 -6
- package/src/config/presets/models/janus-text.json +27 -0
- package/src/config/quantization-contract-check.d.ts +12 -0
- package/src/config/quantization-contract-check.js +91 -0
- package/src/config/required-inference-fields-contract-check.d.ts +24 -0
- package/src/config/required-inference-fields-contract-check.js +231 -0
- package/src/config/schema/browser-suite-metrics.schema.d.ts +17 -0
- package/src/config/schema/browser-suite-metrics.schema.js +46 -0
- package/src/config/schema/conversion-report.schema.d.ts +40 -0
- package/src/config/schema/conversion-report.schema.js +108 -0
- package/src/config/schema/doppler.schema.js +12 -18
- package/src/config/schema/index.d.ts +22 -0
- package/src/config/schema/index.js +18 -0
- package/src/converter/core.d.ts +10 -0
- package/src/converter/core.js +49 -11
- package/src/converter/parsers/diffusion.js +63 -3
- package/src/converter/tokenizer-utils.js +17 -3
- package/src/formats/rdrr/validation.js +13 -0
- package/src/gpu/kernels/depthwise_conv2d.d.ts +29 -0
- package/src/gpu/kernels/depthwise_conv2d.js +98 -0
- package/src/gpu/kernels/depthwise_conv2d.wgsl +58 -0
- package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +62 -0
- package/src/gpu/kernels/grouped_pointwise_conv2d.d.ts +27 -0
- package/src/gpu/kernels/grouped_pointwise_conv2d.js +92 -0
- package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +47 -0
- package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +51 -0
- package/src/gpu/kernels/index.d.ts +30 -0
- package/src/gpu/kernels/index.js +25 -0
- package/src/gpu/kernels/relu.d.ts +18 -0
- package/src/gpu/kernels/relu.js +45 -0
- package/src/gpu/kernels/relu.wgsl +21 -0
- package/src/gpu/kernels/relu_f16.wgsl +23 -0
- package/src/gpu/kernels/repeat_channels.d.ts +21 -0
- package/src/gpu/kernels/repeat_channels.js +60 -0
- package/src/gpu/kernels/repeat_channels.wgsl +29 -0
- package/src/gpu/kernels/repeat_channels_f16.wgsl +31 -0
- package/src/gpu/kernels/sana_linear_attention.d.ts +27 -0
- package/src/gpu/kernels/sana_linear_attention.js +122 -0
- package/src/gpu/kernels/sana_linear_attention_apply.wgsl +44 -0
- package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +47 -0
- package/src/gpu/kernels/sana_linear_attention_summary.wgsl +47 -0
- package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +49 -0
- package/src/index-browser.d.ts +1 -0
- package/src/index-browser.js +2 -1
- package/src/index.d.ts +1 -0
- package/src/index.js +2 -1
- package/src/inference/browser-harness.js +164 -38
- package/src/inference/pipelines/diffusion/init.js +14 -0
- package/src/inference/pipelines/diffusion/pipeline.js +206 -77
- package/src/inference/pipelines/diffusion/sana-transformer.d.ts +53 -0
- package/src/inference/pipelines/diffusion/sana-transformer.js +738 -0
- package/src/inference/pipelines/diffusion/scheduler.d.ts +17 -1
- package/src/inference/pipelines/diffusion/scheduler.js +91 -3
- package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +6 -4
- package/src/inference/pipelines/diffusion/text-encoder-gpu.js +270 -0
- package/src/inference/pipelines/diffusion/text-encoder.js +18 -1
- package/src/inference/pipelines/diffusion/types.d.ts +4 -0
- package/src/inference/pipelines/diffusion/vae.js +782 -78
- package/src/inference/pipelines/text/config.d.ts +5 -0
- package/src/inference/pipelines/text/config.js +1 -1
- package/src/inference/pipelines/text/execution-v0.js +141 -101
- package/src/inference/pipelines/text/init.js +41 -10
- package/src/inference/pipelines/text.js +7 -1
- package/src/rules/execution-rules-contract-check.d.ts +17 -0
- package/src/rules/execution-rules-contract-check.js +245 -0
- package/src/rules/kernels/depthwise-conv2d.rules.json +6 -0
- package/src/rules/kernels/grouped-pointwise-conv2d.rules.json +6 -0
- package/src/rules/kernels/relu.rules.json +6 -0
- package/src/rules/kernels/repeat-channels.rules.json +6 -0
- package/src/rules/kernels/sana-linear-attention.rules.json +6 -0
- package/src/rules/layer-pattern-contract-check.d.ts +17 -0
- package/src/rules/layer-pattern-contract-check.js +231 -0
- package/src/rules/rule-registry.d.ts +28 -0
- package/src/rules/rule-registry.js +38 -0
- package/src/tooling/conversion-config-materializer.d.ts +24 -0
- package/src/tooling/conversion-config-materializer.js +99 -0
- package/src/tooling/lean-execution-contract-runner.d.ts +43 -0
- package/src/tooling/lean-execution-contract-runner.js +158 -0
- package/src/tooling/lean-execution-contract.d.ts +16 -0
- package/src/tooling/lean-execution-contract.js +81 -0
- package/src/tooling/node-convert.d.ts +10 -0
- package/src/tooling/node-converter.js +59 -0
- package/src/tooling/node-webgpu.js +30 -9
- package/src/version.d.ts +2 -0
- package/src/version.js +2 -0
- package/tools/convert-safetensors-node.js +47 -0
- package/tools/doppler-cli.js +167 -6
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
import type { Tensor } from '../tensor.js';
|
|
2
|
+
import type { CommandRecorder } from '../command-recorder.js';
|
|
3
|
+
import type { OutputBufferOptions } from './types.js';
|
|
4
|
+
import type { WeightBuffer } from '../weight-buffer.js';
|
|
5
|
+
|
|
6
|
+
/**
 * Options for the grouped pointwise (1x1) Conv2D kernel.
 *
 * Every dimension has to be supplied by the caller; the kernel does not
 * derive any of them from the input tensor.
 */
export interface GroupedPointwiseConv2DOptions extends OutputBufferOptions {
  /** Number of input channels; must be divisible by `groups`. */
  inChannels: number;
  /** Number of output channels; must be divisible by `groups`. */
  outChannels: number;
  /** Number of channel groups; must be greater than zero. */
  groups: number;
  /** Spatial height of the input and of the output. */
  height: number;
  /** Spatial width of the input and of the output. */
  width: number;
}
|
|
13
|
+
|
|
14
|
+
/**
 * Runs the grouped pointwise Conv2D kernel immediately.
 *
 * @param input Input activations; the dtype selects the shader variant.
 * @param weight Convolution weights.
 * @param bias Per-output-channel bias, or `null` to use a zero bias.
 * @param options Explicit dimensions and an optional pre-allocated output buffer.
 * @returns Tensor of shape `[outChannels, height, width]` with the input dtype.
 */
export declare function runGroupedPointwiseConv2D(
  input: Tensor,
  weight: WeightBuffer | GPUBuffer,
  bias: WeightBuffer | GPUBuffer | null,
  options: GroupedPointwiseConv2DOptions,
): Promise<Tensor>;
|
|
20
|
+
|
|
21
|
+
/**
 * Records the grouped pointwise Conv2D kernel through a command recorder.
 *
 * Takes the same arguments as `runGroupedPointwiseConv2D`, preceded by the
 * recorder the dispatch is issued through.
 *
 * @param recorder Command recorder that receives the dispatch.
 * @param input Input activations; the dtype selects the shader variant.
 * @param weight Convolution weights.
 * @param bias Per-output-channel bias, or `null` to use a zero bias.
 * @param options Explicit dimensions and an optional pre-allocated output buffer.
 * @returns Tensor of shape `[outChannels, height, width]` with the input dtype.
 */
export declare function recordGroupedPointwiseConv2D(
  recorder: CommandRecorder,
  input: Tensor,
  weight: WeightBuffer | GPUBuffer,
  bias: WeightBuffer | GPUBuffer | null,
  options: GroupedPointwiseConv2DOptions,
): Promise<Tensor>;
|
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
import { getDevice } from '../device.js';
|
|
2
|
+
import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
|
|
3
|
+
import { createTensor, dtypeBytes } from '../tensor.js';
|
|
4
|
+
import { getBuffer } from '../weight-buffer.js';
|
|
5
|
+
import { unifiedKernelWrapper } from './utils.js';
|
|
6
|
+
import { selectRuleValue } from './rule-registry.js';
|
|
7
|
+
import { WORKGROUP_SIZES } from './constants.js';
|
|
8
|
+
|
|
9
|
+
function selectGroupedPointwiseConv2DVariant(isF16) {
|
|
10
|
+
return selectRuleValue('groupedPointwiseConv2d', 'variant', { isF16 });
|
|
11
|
+
}
|
|
12
|
+
|
|
13
|
+
async function _groupedPointwiseConv2D(target, input, weight, bias, options = {}) {
|
|
14
|
+
const recorder = target && typeof target.beginComputePass === 'function' ? target : null;
|
|
15
|
+
const device = target?.device || getDevice();
|
|
16
|
+
const {
|
|
17
|
+
inChannels,
|
|
18
|
+
outChannels,
|
|
19
|
+
height,
|
|
20
|
+
width,
|
|
21
|
+
groups,
|
|
22
|
+
outputBuffer = null,
|
|
23
|
+
} = options;
|
|
24
|
+
|
|
25
|
+
if (
|
|
26
|
+
!Number.isFinite(inChannels) ||
|
|
27
|
+
!Number.isFinite(outChannels) ||
|
|
28
|
+
!Number.isFinite(height) ||
|
|
29
|
+
!Number.isFinite(width) ||
|
|
30
|
+
!Number.isFinite(groups)
|
|
31
|
+
) {
|
|
32
|
+
throw new Error('GroupedPointwiseConv2D requires explicit dimensions.');
|
|
33
|
+
}
|
|
34
|
+
if (groups <= 0 || inChannels % groups !== 0 || outChannels % groups !== 0) {
|
|
35
|
+
throw new Error(
|
|
36
|
+
`GroupedPointwiseConv2D requires inChannels/outChannels divisible by groups. Got ${inChannels}/${outChannels}/${groups}.`
|
|
37
|
+
);
|
|
38
|
+
}
|
|
39
|
+
|
|
40
|
+
const isF16 = input.dtype === 'f16';
|
|
41
|
+
const variant = selectGroupedPointwiseConv2DVariant(isF16);
|
|
42
|
+
const bytesPerElement = dtypeBytes(input.dtype);
|
|
43
|
+
const outputSize = outChannels * height * width * bytesPerElement;
|
|
44
|
+
const output = outputBuffer || acquireBuffer(outputSize, undefined, 'grouped_pointwise_conv2d_output');
|
|
45
|
+
|
|
46
|
+
const weightBuffer = getBuffer(weight);
|
|
47
|
+
let biasBuffer = getBuffer(bias);
|
|
48
|
+
let tempBias = null;
|
|
49
|
+
if (!biasBuffer) {
|
|
50
|
+
const biasSize = outChannels * bytesPerElement;
|
|
51
|
+
tempBias = acquireBuffer(biasSize, undefined, 'grouped_pointwise_conv2d_bias_zero');
|
|
52
|
+
biasBuffer = tempBias;
|
|
53
|
+
const paddedSize = Math.ceil(biasSize / 4) * 4;
|
|
54
|
+
device.queue.writeBuffer(biasBuffer, 0, new Uint8Array(paddedSize));
|
|
55
|
+
}
|
|
56
|
+
|
|
57
|
+
await unifiedKernelWrapper(
|
|
58
|
+
'grouped_pointwise_conv2d',
|
|
59
|
+
target,
|
|
60
|
+
variant,
|
|
61
|
+
[input, weightBuffer, biasBuffer, output],
|
|
62
|
+
{
|
|
63
|
+
in_channels: inChannels,
|
|
64
|
+
out_channels: outChannels,
|
|
65
|
+
height,
|
|
66
|
+
width,
|
|
67
|
+
groups,
|
|
68
|
+
_pad0: 0,
|
|
69
|
+
_pad1: 0,
|
|
70
|
+
_pad2: 0,
|
|
71
|
+
},
|
|
72
|
+
Math.ceil((outChannels * height * width) / WORKGROUP_SIZES.DEFAULT)
|
|
73
|
+
);
|
|
74
|
+
|
|
75
|
+
if (tempBias) {
|
|
76
|
+
if (recorder) {
|
|
77
|
+
recorder.trackTemporaryBuffer(tempBias);
|
|
78
|
+
} else {
|
|
79
|
+
releaseBuffer(tempBias);
|
|
80
|
+
}
|
|
81
|
+
}
|
|
82
|
+
|
|
83
|
+
return createTensor(output, input.dtype, [outChannels, height, width], 'grouped_pointwise_conv2d_output');
|
|
84
|
+
}
|
|
85
|
+
|
|
86
|
+
export async function runGroupedPointwiseConv2D(input, weight, bias, options = {}) {
|
|
87
|
+
return _groupedPointwiseConv2D(null, input, weight, bias, options);
|
|
88
|
+
}
|
|
89
|
+
|
|
90
|
+
export async function recordGroupedPointwiseConv2D(recorder, input, weight, bias, options = {}) {
|
|
91
|
+
return _groupedPointwiseConv2D(recorder, input, weight, bias, options);
|
|
92
|
+
}
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
// Grouped pointwise (1x1) Conv2D, f32 storage.
// One invocation computes one output element. Data is channel-major:
// element (c, y, x) lives at (c * height + y) * width + x.

override WORKGROUP_SIZE: u32 = 256u;

struct Uniforms {
    in_channels: u32,
    out_channels: u32,
    height: u32,
    width: u32,
    groups: u32,
    _pad0: u32,
    _pad1: u32,
    _pad2: u32,
}

@group(0) @binding(0) var<uniform> u: Uniforms;
@group(0) @binding(1) var<storage, read> input: array<f32>;
@group(0) @binding(2) var<storage, read> weight: array<f32>;
@group(0) @binding(3) var<storage, read> bias: array<f32>;
@group(0) @binding(4) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let flat = gid.x;
    let plane = u.height * u.width;
    // Invocations beyond the last output element have nothing to do.
    if (flat >= u.out_channels * plane) {
        return;
    }

    let oc = flat / plane;
    // Position inside one channel plane, i.e. y * width + x.
    let pixel = flat % plane;

    // Each output channel only reads the input channels of its own group.
    let group_in = u.in_channels / u.groups;
    let group_out = u.out_channels / u.groups;
    let first_in = (oc / group_out) * group_in;
    let weight_base = oc * group_in;

    var acc: f32 = bias[oc];
    for (var k: u32 = 0u; k < group_in; k += 1u) {
        acc += input[(first_in + k) * plane + pixel] * weight[weight_base + k];
    }

    output[flat] = acc;
}
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
// Grouped Pointwise Conv2D Kernel (NCHW, f16)
// One invocation computes one output element. Inputs, weights and bias are
// stored as f16; the dot product is accumulated in f32 and narrowed on store.

enable f16;

override WORKGROUP_SIZE: u32 = 256u;

struct Uniforms {
    in_channels: u32,
    out_channels: u32,
    height: u32,
    width: u32,
    groups: u32,
    _pad0: u32,
    _pad1: u32,
    _pad2: u32,
}

@group(0) @binding(0) var<uniform> u: Uniforms;
@group(0) @binding(1) var<storage, read> input: array<f16>;
@group(0) @binding(2) var<storage, read> weight: array<f16>;
@group(0) @binding(3) var<storage, read> bias: array<f16>;
@group(0) @binding(4) var<storage, read_write> output: array<f16>;

@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let flat = gid.x;
    let plane = u.height * u.width;
    // Invocations beyond the last output element have nothing to do.
    if (flat >= u.out_channels * plane) {
        return;
    }

    let oc = flat / plane;
    // Position inside one channel plane, i.e. y * width + x.
    let pixel = flat % plane;

    // Each output channel only reads the input channels of its own group.
    let group_in = u.in_channels / u.groups;
    let group_out = u.out_channels / u.groups;
    let first_in = (oc / group_out) * group_in;
    let weight_base = oc * group_in;

    var acc: f32 = f32(bias[oc]);
    for (var k: u32 = 0u; k < group_in; k += 1u) {
        acc += f32(input[(first_in + k) * plane + pixel]) * f32(weight[weight_base + k]);
    }

    output[flat] = f16(acc);
}
|
|
@@ -174,6 +174,18 @@ export {
|
|
|
174
174
|
type Conv2DOptions,
|
|
175
175
|
} from './conv2d.js';
|
|
176
176
|
|
|
177
|
+
export {
|
|
178
|
+
runDepthwiseConv2D,
|
|
179
|
+
recordDepthwiseConv2D,
|
|
180
|
+
type DepthwiseConv2DOptions,
|
|
181
|
+
} from './depthwise_conv2d.js';
|
|
182
|
+
|
|
183
|
+
export {
|
|
184
|
+
runGroupedPointwiseConv2D,
|
|
185
|
+
recordGroupedPointwiseConv2D,
|
|
186
|
+
type GroupedPointwiseConv2DOptions,
|
|
187
|
+
} from './grouped_pointwise_conv2d.js';
|
|
188
|
+
|
|
177
189
|
// Gather (Embedding Lookup)
|
|
178
190
|
export {
|
|
179
191
|
runGather,
|
|
@@ -250,6 +262,24 @@ export {
|
|
|
250
262
|
type SampleResult,
|
|
251
263
|
} from './sample.js';
|
|
252
264
|
|
|
265
|
+
export {
|
|
266
|
+
runSanaLinearAttention,
|
|
267
|
+
recordSanaLinearAttention,
|
|
268
|
+
type SanaLinearAttentionOptions,
|
|
269
|
+
} from './sana_linear_attention.js';
|
|
270
|
+
|
|
271
|
+
export {
|
|
272
|
+
runRepeatChannels,
|
|
273
|
+
recordRepeatChannels,
|
|
274
|
+
type RepeatChannelsOptions,
|
|
275
|
+
} from './repeat_channels.js';
|
|
276
|
+
|
|
277
|
+
export {
|
|
278
|
+
runReLU,
|
|
279
|
+
recordReLU,
|
|
280
|
+
type ReLUOptions,
|
|
281
|
+
} from './relu.js';
|
|
282
|
+
|
|
253
283
|
// Fused FFN (Tier 2 P0)
|
|
254
284
|
export {
|
|
255
285
|
runFusedFFN,
|
package/src/gpu/kernels/index.js
CHANGED
|
@@ -139,6 +139,16 @@ export {
|
|
|
139
139
|
recordConv2D,
|
|
140
140
|
} from './conv2d.js';
|
|
141
141
|
|
|
142
|
+
export {
|
|
143
|
+
runDepthwiseConv2D,
|
|
144
|
+
recordDepthwiseConv2D,
|
|
145
|
+
} from './depthwise_conv2d.js';
|
|
146
|
+
|
|
147
|
+
export {
|
|
148
|
+
runGroupedPointwiseConv2D,
|
|
149
|
+
recordGroupedPointwiseConv2D,
|
|
150
|
+
} from './grouped_pointwise_conv2d.js';
|
|
151
|
+
|
|
142
152
|
// Gather (Embedding Lookup)
|
|
143
153
|
export {
|
|
144
154
|
runGather,
|
|
@@ -205,6 +215,21 @@ export {
|
|
|
205
215
|
isGPUSamplingAvailable,
|
|
206
216
|
} from './sample.js';
|
|
207
217
|
|
|
218
|
+
export {
|
|
219
|
+
runSanaLinearAttention,
|
|
220
|
+
recordSanaLinearAttention,
|
|
221
|
+
} from './sana_linear_attention.js';
|
|
222
|
+
|
|
223
|
+
export {
|
|
224
|
+
runRepeatChannels,
|
|
225
|
+
recordRepeatChannels,
|
|
226
|
+
} from './repeat_channels.js';
|
|
227
|
+
|
|
228
|
+
export {
|
|
229
|
+
runReLU,
|
|
230
|
+
recordReLU,
|
|
231
|
+
} from './relu.js';
|
|
232
|
+
|
|
208
233
|
// Fused FFN (Tier 2 P0)
|
|
209
234
|
export {
|
|
210
235
|
runFusedFFN,
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
import type { Tensor } from '../tensor.js';
|
|
2
|
+
import type { CommandRecorder } from '../command-recorder.js';
|
|
3
|
+
import type { OutputBufferOptions } from './types.js';
|
|
4
|
+
|
|
5
|
+
/** Options for the ReLU kernel. */
export interface ReLUOptions extends OutputBufferOptions {
  /**
   * Optional element-count override (floored). When omitted or `null`, the
   * count is derived from the tensor shape, falling back to the buffer size.
   */
  count?: null | number;
}
|
|
8
|
+
|
|
9
|
+
/**
 * Applies ReLU (`max(x, 0)`) element-wise.
 *
 * @param input Tensor to rectify.
 * @param options Optional element count and pre-allocated output buffer.
 * @returns Tensor with the input's dtype and a copy of its shape.
 */
export declare function runReLU(
  input: Tensor,
  options?: ReLUOptions,
): Promise<Tensor>;
|
|
13
|
+
|
|
14
|
+
/**
 * Applies ReLU (`max(x, 0)`) element-wise through a command recorder.
 *
 * @param recorder Command recorder that receives the dispatch.
 * @param input Tensor to rectify.
 * @param options Optional element count and pre-allocated output buffer.
 * @returns Tensor with the input's dtype and a copy of its shape.
 */
export declare function recordReLU(
  recorder: CommandRecorder,
  input: Tensor,
  options?: ReLUOptions,
): Promise<Tensor>;
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
import { acquireBuffer } from '../../memory/buffer-pool.js';
|
|
2
|
+
import { createTensor, dtypeBytes } from '../tensor.js';
|
|
3
|
+
import { unifiedKernelWrapper } from './utils.js';
|
|
4
|
+
import { selectRuleValue } from './rule-registry.js';
|
|
5
|
+
import { WORKGROUP_SIZES } from './constants.js';
|
|
6
|
+
|
|
7
|
+
function selectReluVariant(dtype) {
|
|
8
|
+
return selectRuleValue('relu', 'variant', { dtype });
|
|
9
|
+
}
|
|
10
|
+
|
|
11
|
+
function resolveCount(input, countOverride) {
|
|
12
|
+
if (Number.isFinite(countOverride) && countOverride > 0) {
|
|
13
|
+
return Math.floor(countOverride);
|
|
14
|
+
}
|
|
15
|
+
if (Array.isArray(input.shape) && input.shape.length > 0) {
|
|
16
|
+
return input.shape.reduce((acc, value) => acc * value, 1);
|
|
17
|
+
}
|
|
18
|
+
return Math.floor(input.buffer.size / dtypeBytes(input.dtype));
|
|
19
|
+
}
|
|
20
|
+
|
|
21
|
+
async function _relu(target, input, options = {}) {
|
|
22
|
+
const { count = null, outputBuffer = null } = options;
|
|
23
|
+
const size = resolveCount(input, count);
|
|
24
|
+
const variant = selectReluVariant(input.dtype);
|
|
25
|
+
const output = outputBuffer || acquireBuffer(size * dtypeBytes(input.dtype), undefined, 'relu_output');
|
|
26
|
+
|
|
27
|
+
await unifiedKernelWrapper(
|
|
28
|
+
'relu',
|
|
29
|
+
target,
|
|
30
|
+
variant,
|
|
31
|
+
[input, output],
|
|
32
|
+
{ size, _pad0: 0, _pad1: 0, _pad2: 0 },
|
|
33
|
+
Math.ceil(size / WORKGROUP_SIZES.DEFAULT)
|
|
34
|
+
);
|
|
35
|
+
|
|
36
|
+
return createTensor(output, input.dtype, [...input.shape], 'relu_output');
|
|
37
|
+
}
|
|
38
|
+
|
|
39
|
+
export async function runReLU(input, options = {}) {
|
|
40
|
+
return _relu(null, input, options);
|
|
41
|
+
}
|
|
42
|
+
|
|
43
|
+
export async function recordReLU(recorder, input, options = {}) {
|
|
44
|
+
return _relu(recorder, input, options);
|
|
45
|
+
}
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
// Element-wise ReLU over f32 data: output[i] = max(input[i], 0).

override WORKGROUP_SIZE: u32 = 256u;

struct Uniforms {
    size: u32,
    _pad0: u32,
    _pad1: u32,
    _pad2: u32,
}

@group(0) @binding(0) var<uniform> u: Uniforms;
@group(0) @binding(1) var<storage, read> input: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    // Only the first `size` invocations write; the rest fall through.
    if (i < u.size) {
        output[i] = max(input[i], 0.0);
    }
}
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
// Element-wise ReLU over f16 data: output[i] = max(input[i], 0).

enable f16;

override WORKGROUP_SIZE: u32 = 256u;

struct Uniforms {
    size: u32,
    _pad0: u32,
    _pad1: u32,
    _pad2: u32,
}

@group(0) @binding(0) var<uniform> u: Uniforms;
@group(0) @binding(1) var<storage, read> input: array<f16>;
@group(0) @binding(2) var<storage, read_write> output: array<f16>;

@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    // Only the first `size` invocations write; the rest fall through.
    if (i < u.size) {
        output[i] = max(input[i], 0.0h);
    }
}
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
import type { Tensor } from '../tensor.js';
|
|
2
|
+
import type { CommandRecorder } from '../command-recorder.js';
|
|
3
|
+
import type { OutputBufferOptions } from './types.js';
|
|
4
|
+
|
|
5
|
+
/**
 * Options for the channel-repeat kernel.
 *
 * The output has `inChannels * repeats` channels; output channel `c` is a
 * copy of input channel `floor(c / repeats)`.
 */
export interface RepeatChannelsOptions extends OutputBufferOptions {
  /** Number of channels in the input. */
  inChannels: number;
  /** How many consecutive output channels each input channel fills; at least 1. */
  repeats: number;
  /** Spatial height of the input and of the output. */
  height: number;
  /** Spatial width of the input and of the output. */
  width: number;
}
|
|
11
|
+
|
|
12
|
+
/**
 * Repeats every input channel `repeats` times along the channel axis.
 *
 * @param input Tensor whose channels are repeated.
 * @param options Explicit dimensions, repeat factor and optional output buffer.
 * @returns Tensor of shape `[inChannels * repeats, height, width]` with the input dtype.
 */
export declare function runRepeatChannels(
  input: Tensor,
  options: RepeatChannelsOptions,
): Promise<Tensor>;
|
|
16
|
+
|
|
17
|
+
/**
 * Repeats every input channel `repeats` times along the channel axis,
 * dispatching through a command recorder.
 *
 * @param recorder Command recorder that receives the dispatch.
 * @param input Tensor whose channels are repeated.
 * @param options Explicit dimensions, repeat factor and optional output buffer.
 * @returns Tensor of shape `[inChannels * repeats, height, width]` with the input dtype.
 */
export declare function recordRepeatChannels(
  recorder: CommandRecorder,
  input: Tensor,
  options: RepeatChannelsOptions,
): Promise<Tensor>;
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
import { acquireBuffer } from '../../memory/buffer-pool.js';
|
|
2
|
+
import { createTensor, dtypeBytes } from '../tensor.js';
|
|
3
|
+
import { unifiedKernelWrapper } from './utils.js';
|
|
4
|
+
import { selectRuleValue } from './rule-registry.js';
|
|
5
|
+
import { WORKGROUP_SIZES } from './constants.js';
|
|
6
|
+
|
|
7
|
+
function selectRepeatChannelsVariant(dtype) {
|
|
8
|
+
return selectRuleValue('repeatChannels', 'variant', { dtype });
|
|
9
|
+
}
|
|
10
|
+
|
|
11
|
+
async function _repeatChannels(target, input, options = {}) {
|
|
12
|
+
const {
|
|
13
|
+
inChannels,
|
|
14
|
+
height,
|
|
15
|
+
width,
|
|
16
|
+
repeats,
|
|
17
|
+
outputBuffer = null,
|
|
18
|
+
} = options;
|
|
19
|
+
|
|
20
|
+
if (
|
|
21
|
+
!Number.isFinite(inChannels) ||
|
|
22
|
+
!Number.isFinite(height) ||
|
|
23
|
+
!Number.isFinite(width) ||
|
|
24
|
+
!Number.isFinite(repeats) ||
|
|
25
|
+
repeats < 1
|
|
26
|
+
) {
|
|
27
|
+
throw new Error('RepeatChannels requires inChannels, height, width, and repeats.');
|
|
28
|
+
}
|
|
29
|
+
|
|
30
|
+
const outChannels = inChannels * repeats;
|
|
31
|
+
const variant = selectRepeatChannelsVariant(input.dtype);
|
|
32
|
+
const bytesPerElement = dtypeBytes(input.dtype);
|
|
33
|
+
const outputSize = outChannels * height * width * bytesPerElement;
|
|
34
|
+
const output = outputBuffer || acquireBuffer(outputSize, undefined, 'repeat_channels_output');
|
|
35
|
+
|
|
36
|
+
await unifiedKernelWrapper(
|
|
37
|
+
'repeat_channels',
|
|
38
|
+
target,
|
|
39
|
+
variant,
|
|
40
|
+
[input, output],
|
|
41
|
+
{
|
|
42
|
+
in_channels: inChannels,
|
|
43
|
+
height,
|
|
44
|
+
width,
|
|
45
|
+
repeats,
|
|
46
|
+
_pad0: 0,
|
|
47
|
+
},
|
|
48
|
+
Math.ceil((outChannels * height * width) / WORKGROUP_SIZES.DEFAULT)
|
|
49
|
+
);
|
|
50
|
+
|
|
51
|
+
return createTensor(output, input.dtype, [outChannels, height, width], 'repeat_channels_output');
|
|
52
|
+
}
|
|
53
|
+
|
|
54
|
+
export async function runRepeatChannels(input, options = {}) {
|
|
55
|
+
return _repeatChannels(null, input, options);
|
|
56
|
+
}
|
|
57
|
+
|
|
58
|
+
export async function recordRepeatChannels(recorder, input, options = {}) {
|
|
59
|
+
return _repeatChannels(recorder, input, options);
|
|
60
|
+
}
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
// Channel repeat over f32 data.
// The output has in_channels * repeats channels; output channel c copies
// input channel c / repeats, so each input channel appears `repeats` times
// in a row.

override WORKGROUP_SIZE: u32 = 256u;

struct Uniforms {
    in_channels: u32,
    height: u32,
    width: u32,
    repeats: u32,
    _pad0: u32,
}

@group(0) @binding(0) var<uniform> u: Uniforms;
@group(0) @binding(1) var<storage, read> input: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let flat = gid.x;
    let plane = u.height * u.width;
    // Invocations beyond the last output element have nothing to do.
    if (flat >= u.in_channels * u.repeats * plane) {
        return;
    }

    let dst_channel = flat / plane;
    let src_channel = dst_channel / u.repeats;
    // Position inside one channel plane.
    let pixel = flat % plane;
    output[flat] = input[src_channel * plane + pixel];
}
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
// Channel repeat over f16 data.
// The output has in_channels * repeats channels; output channel c copies
// input channel c / repeats, so each input channel appears `repeats` times
// in a row.

enable f16;

override WORKGROUP_SIZE: u32 = 256u;

struct Uniforms {
    in_channels: u32,
    height: u32,
    width: u32,
    repeats: u32,
    _pad0: u32,
}

@group(0) @binding(0) var<uniform> u: Uniforms;
@group(0) @binding(1) var<storage, read> input: array<f16>;
@group(0) @binding(2) var<storage, read_write> output: array<f16>;

@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let flat = gid.x;
    let plane = u.height * u.width;
    // Invocations beyond the last output element have nothing to do.
    if (flat >= u.in_channels * u.repeats * plane) {
        return;
    }

    let dst_channel = flat / plane;
    let src_channel = dst_channel / u.repeats;
    // Position inside one channel plane.
    let pixel = flat % plane;
    output[flat] = input[src_channel * plane + pixel];
}
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
import type { Tensor } from '../tensor.js';
|
|
2
|
+
import type { CommandRecorder } from '../command-recorder.js';
|
|
3
|
+
import type { OutputBufferOptions } from './types.js';
|
|
4
|
+
|
|
5
|
+
/**
 * Options for the Sana linear-attention kernels.
 *
 * NOTE(review): the implementation is not visible from this declaration;
 * the field notes below are read off the names and types only — confirm
 * against sana_linear_attention.js.
 */
export interface SanaLinearAttentionOptions extends OutputBufferOptions {
  /** Number of attention heads. */
  numHeads: number;
  /** Size of each head. */
  headDim: number;
  /** Optional token count — presumably derived from the inputs when omitted (TODO confirm). */
  numTokens?: number;
  /** Optional hidden size — presumably `numHeads * headDim` when omitted (TODO confirm). */
  hiddenSize?: number;
  /** Optional epsilon — presumably a numerical-stability term (TODO confirm default). */
  eps?: number;
  /** Optional caller-provided buffer for the intermediate summary, or `null`. */
  summaryBuffer?: GPUBuffer | null;
}
|
|
13
|
+
|
|
14
|
+
/**
 * Runs Sana linear attention over the given query, key and value tensors.
 *
 * @param query Query tensor.
 * @param key Key tensor.
 * @param value Value tensor.
 * @param options Head configuration and optional buffers.
 * @returns Promise resolving to the attention output tensor.
 */
export declare function runSanaLinearAttention(
  query: Tensor,
  key: Tensor,
  value: Tensor,
  options: SanaLinearAttentionOptions,
): Promise<Tensor>;
|
|
20
|
+
|
|
21
|
+
/**
 * Runs Sana linear attention through a command recorder.
 *
 * @param recorder Command recorder that receives the dispatches.
 * @param query Query tensor.
 * @param key Key tensor.
 * @param value Value tensor.
 * @param options Head configuration and optional buffers.
 * @returns Promise resolving to the attention output tensor.
 */
export declare function recordSanaLinearAttention(
  recorder: CommandRecorder,
  query: Tensor,
  key: Tensor,
  value: Tensor,
  options: SanaLinearAttentionOptions,
): Promise<Tensor>;
|