@simulatte/doppler 0.1.4 → 0.1.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +4 -3
- package/package.json +25 -4
- package/src/client/doppler-api.browser.d.ts +1 -0
- package/src/client/doppler-api.browser.js +288 -0
- package/src/client/doppler-api.js +1 -1
- package/src/client/doppler-provider/types.js +1 -1
- package/src/config/execution-contract-check.d.ts +33 -0
- package/src/config/execution-contract-check.js +72 -0
- package/src/config/execution-v0-contract-check.d.ts +94 -0
- package/src/config/execution-v0-contract-check.js +251 -0
- package/src/config/execution-v0-graph-contract-check.d.ts +20 -0
- package/src/config/execution-v0-graph-contract-check.js +64 -0
- package/src/config/kernel-path-contract-check.d.ts +76 -0
- package/src/config/kernel-path-contract-check.js +479 -0
- package/src/config/kernel-path-loader.d.ts +16 -0
- package/src/config/kernel-path-loader.js +54 -0
- package/src/config/kernels/kernel-ref-digests.js +12 -0
- package/src/config/kernels/registry.json +556 -0
- package/src/config/loader.js +50 -46
- package/src/config/merge-contract-check.d.ts +16 -0
- package/src/config/merge-contract-check.js +321 -0
- package/src/config/merge-helpers.d.ts +58 -0
- package/src/config/merge-helpers.js +54 -0
- package/src/config/merge.js +3 -6
- package/src/config/presets/models/janus-text.json +2 -0
- package/src/config/quantization-contract-check.d.ts +12 -0
- package/src/config/quantization-contract-check.js +91 -0
- package/src/config/required-inference-fields-contract-check.d.ts +24 -0
- package/src/config/required-inference-fields-contract-check.js +231 -0
- package/src/config/schema/browser-suite-metrics.schema.d.ts +17 -0
- package/src/config/schema/browser-suite-metrics.schema.js +46 -0
- package/src/config/schema/conversion-report.schema.d.ts +40 -0
- package/src/config/schema/conversion-report.schema.js +108 -0
- package/src/config/schema/doppler.schema.js +12 -18
- package/src/config/schema/index.d.ts +22 -0
- package/src/config/schema/index.js +18 -0
- package/src/converter/core.d.ts +10 -0
- package/src/converter/core.js +27 -2
- package/src/converter/parsers/diffusion.js +63 -3
- package/src/gpu/kernels/depthwise_conv2d.d.ts +29 -0
- package/src/gpu/kernels/depthwise_conv2d.js +98 -0
- package/src/gpu/kernels/depthwise_conv2d.wgsl +58 -0
- package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +62 -0
- package/src/gpu/kernels/grouped_pointwise_conv2d.d.ts +27 -0
- package/src/gpu/kernels/grouped_pointwise_conv2d.js +92 -0
- package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +47 -0
- package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +51 -0
- package/src/gpu/kernels/index.d.ts +30 -0
- package/src/gpu/kernels/index.js +25 -0
- package/src/gpu/kernels/relu.d.ts +18 -0
- package/src/gpu/kernels/relu.js +45 -0
- package/src/gpu/kernels/relu.wgsl +21 -0
- package/src/gpu/kernels/relu_f16.wgsl +23 -0
- package/src/gpu/kernels/repeat_channels.d.ts +21 -0
- package/src/gpu/kernels/repeat_channels.js +60 -0
- package/src/gpu/kernels/repeat_channels.wgsl +29 -0
- package/src/gpu/kernels/repeat_channels_f16.wgsl +31 -0
- package/src/gpu/kernels/sana_linear_attention.d.ts +27 -0
- package/src/gpu/kernels/sana_linear_attention.js +122 -0
- package/src/gpu/kernels/sana_linear_attention_apply.wgsl +44 -0
- package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +47 -0
- package/src/gpu/kernels/sana_linear_attention_summary.wgsl +47 -0
- package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +49 -0
- package/src/index-browser.d.ts +1 -1
- package/src/index-browser.js +2 -2
- package/src/index.js +1 -1
- package/src/inference/browser-harness.js +62 -22
- package/src/inference/pipelines/diffusion/init.js +14 -0
- package/src/inference/pipelines/diffusion/pipeline.js +206 -77
- package/src/inference/pipelines/diffusion/sana-transformer.d.ts +53 -0
- package/src/inference/pipelines/diffusion/sana-transformer.js +738 -0
- package/src/inference/pipelines/diffusion/scheduler.d.ts +17 -1
- package/src/inference/pipelines/diffusion/scheduler.js +91 -3
- package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +6 -4
- package/src/inference/pipelines/diffusion/text-encoder-gpu.js +270 -0
- package/src/inference/pipelines/diffusion/text-encoder.js +18 -1
- package/src/inference/pipelines/diffusion/types.d.ts +4 -0
- package/src/inference/pipelines/diffusion/vae.js +782 -78
- package/src/inference/pipelines/text/config.d.ts +5 -0
- package/src/inference/pipelines/text/config.js +1 -1
- package/src/inference/pipelines/text/execution-v0.js +14 -93
- package/src/rules/execution-rules-contract-check.d.ts +17 -0
- package/src/rules/execution-rules-contract-check.js +245 -0
- package/src/rules/kernels/depthwise-conv2d.rules.json +6 -0
- package/src/rules/kernels/grouped-pointwise-conv2d.rules.json +6 -0
- package/src/rules/kernels/relu.rules.json +6 -0
- package/src/rules/kernels/repeat-channels.rules.json +6 -0
- package/src/rules/kernels/sana-linear-attention.rules.json +6 -0
- package/src/rules/layer-pattern-contract-check.d.ts +17 -0
- package/src/rules/layer-pattern-contract-check.js +231 -0
- package/src/rules/rule-registry.d.ts +28 -0
- package/src/rules/rule-registry.js +38 -0
- package/src/tooling/conversion-config-materializer.d.ts +24 -0
- package/src/tooling/conversion-config-materializer.js +99 -0
- package/src/tooling/lean-execution-contract-runner.d.ts +43 -0
- package/src/tooling/lean-execution-contract-runner.js +158 -0
- package/src/tooling/node-convert.d.ts +10 -0
- package/src/tooling/node-converter.js +59 -0
- package/src/tooling/node-webgpu.js +9 -9
- package/src/version.d.ts +2 -0
- package/src/version.js +2 -0
- package/tools/convert-safetensors-node.js +47 -0
- package/tools/doppler-cli.js +115 -1
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
import { acquireBuffer } from '../../memory/buffer-pool.js';
import { createTensor, dtypeBytes } from '../tensor.js';
import { unifiedKernelWrapper } from './utils.js';
import { selectRuleValue } from './rule-registry.js';
import { WORKGROUP_SIZES } from './constants.js';

/**
 * Look up the ReLU shader variant for the given element dtype via the
 * rule registry.
 */
function selectReluVariant(dtype) {
  return selectRuleValue('relu', 'variant', { dtype });
}

/**
 * Decide how many elements to process, in priority order:
 * 1. an explicit positive count override (floored to an integer),
 * 2. the product of the tensor's shape dimensions,
 * 3. the buffer's byte size divided by the dtype width.
 */
function resolveCount(input, countOverride) {
  if (Number.isFinite(countOverride) && countOverride > 0) {
    return Math.floor(countOverride);
  }
  const { shape } = input;
  if (Array.isArray(shape) && shape.length > 0) {
    let total = 1;
    for (const dim of shape) {
      total *= dim;
    }
    return total;
  }
  return Math.floor(input.buffer.size / dtypeBytes(input.dtype));
}

/**
 * Shared implementation behind runReLU/recordReLU.
 *
 * @param target - Command recorder to record into, or null for immediate execution.
 * @param input - Input tensor.
 * @param options - Optional `count` override and/or a caller-owned `outputBuffer`.
 * @returns A tensor with the same dtype and shape as the input.
 */
async function _relu(target, input, options = {}) {
  const { count = null, outputBuffer = null } = options;
  const elementCount = resolveCount(input, count);
  const destination =
    outputBuffer ||
    acquireBuffer(elementCount * dtypeBytes(input.dtype), undefined, 'relu_output');

  const uniforms = { size: elementCount, _pad0: 0, _pad1: 0, _pad2: 0 };
  const workgroups = Math.ceil(elementCount / WORKGROUP_SIZES.DEFAULT);
  await unifiedKernelWrapper(
    'relu',
    target,
    selectReluVariant(input.dtype),
    [input, destination],
    uniforms,
    workgroups
  );

  return createTensor(destination, input.dtype, [...input.shape], 'relu_output');
}

/** Execute ReLU immediately. */
export async function runReLU(input, options = {}) {
  return _relu(null, input, options);
}

/** Record ReLU into an existing command recorder. */
export async function recordReLU(recorder, input, options = {}) {
  return _relu(recorder, input, options);
}
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
// Element-wise ReLU over a flat f32 buffer: output[i] = max(input[i], 0).
// One invocation handles one element; u.size bounds the valid range.
override WORKGROUP_SIZE: u32 = 256u;

struct Uniforms {
  size: u32,
  _pad0: u32,
  _pad1: u32,
  _pad2: u32,
}

@group(0) @binding(0) var<uniform> u: Uniforms;
@group(0) @binding(1) var<storage, read> input: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let i = gid.x;
  // Trailing invocations of the last workgroup fall outside the buffer.
  if (i < u.size) {
    output[i] = max(input[i], 0.0);
  }
}
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
// f16 variant of the element-wise ReLU kernel: output[i] = max(input[i], 0).
enable f16;

override WORKGROUP_SIZE: u32 = 256u;

struct Uniforms {
  size: u32,
  _pad0: u32,
  _pad1: u32,
  _pad2: u32,
}

@group(0) @binding(0) var<uniform> u: Uniforms;
@group(0) @binding(1) var<storage, read> input: array<f16>;
@group(0) @binding(2) var<storage, read_write> output: array<f16>;

@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let i = gid.x;
  // Trailing invocations of the last workgroup fall outside the buffer.
  if (i < u.size) {
    output[i] = max(input[i], 0.0h);
  }
}
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
import type { Tensor } from '../tensor.js';
|
|
2
|
+
import type { CommandRecorder } from '../command-recorder.js';
|
|
3
|
+
import type { OutputBufferOptions } from './types.js';
|
|
4
|
+
|
|
5
|
+
export interface RepeatChannelsOptions extends OutputBufferOptions {
|
|
6
|
+
inChannels: number;
|
|
7
|
+
height: number;
|
|
8
|
+
width: number;
|
|
9
|
+
repeats: number;
|
|
10
|
+
}
|
|
11
|
+
|
|
12
|
+
export declare function runRepeatChannels(
|
|
13
|
+
input: Tensor,
|
|
14
|
+
options: RepeatChannelsOptions
|
|
15
|
+
): Promise<Tensor>;
|
|
16
|
+
|
|
17
|
+
export declare function recordRepeatChannels(
|
|
18
|
+
recorder: CommandRecorder,
|
|
19
|
+
input: Tensor,
|
|
20
|
+
options: RepeatChannelsOptions
|
|
21
|
+
): Promise<Tensor>;
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
import { acquireBuffer } from '../../memory/buffer-pool.js';
import { createTensor, dtypeBytes } from '../tensor.js';
import { unifiedKernelWrapper } from './utils.js';
import { selectRuleValue } from './rule-registry.js';
import { WORKGROUP_SIZES } from './constants.js';

/** Look up the repeat-channels shader variant for the given dtype. */
function selectRepeatChannelsVariant(dtype) {
  return selectRuleValue('repeatChannels', 'variant', { dtype });
}

/**
 * Shared implementation behind runRepeatChannels/recordRepeatChannels.
 * Expands a [inChannels, height, width] tensor to
 * [inChannels * repeats, height, width] by repeating each channel.
 *
 * @param target - Command recorder to record into, or null for immediate execution.
 * @throws When any of inChannels/height/width/repeats is missing, non-finite,
 *         or repeats is below 1.
 */
async function _repeatChannels(target, input, options = {}) {
  const { inChannels, height, width, repeats, outputBuffer = null } = options;

  const dimensionsValid =
    Number.isFinite(inChannels) &&
    Number.isFinite(height) &&
    Number.isFinite(width) &&
    Number.isFinite(repeats) &&
    repeats >= 1;
  if (!dimensionsValid) {
    throw new Error('RepeatChannels requires inChannels, height, width, and repeats.');
  }

  const outChannels = inChannels * repeats;
  const totalElements = outChannels * height * width;
  const destination =
    outputBuffer ||
    acquireBuffer(totalElements * dtypeBytes(input.dtype), undefined, 'repeat_channels_output');

  const uniforms = {
    in_channels: inChannels,
    height,
    width,
    repeats,
    _pad0: 0,
  };
  await unifiedKernelWrapper(
    'repeat_channels',
    target,
    selectRepeatChannelsVariant(input.dtype),
    [input, destination],
    uniforms,
    Math.ceil(totalElements / WORKGROUP_SIZES.DEFAULT)
  );

  return createTensor(destination, input.dtype, [outChannels, height, width], 'repeat_channels_output');
}

/** Execute the repeat-channels kernel immediately. */
export async function runRepeatChannels(input, options = {}) {
  return _repeatChannels(null, input, options);
}

/** Record the repeat-channels kernel into an existing command recorder. */
export async function recordRepeatChannels(recorder, input, options = {}) {
  return _repeatChannels(recorder, input, options);
}
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
// Channel repeat for a CHW-flattened f32 buffer.
// Output channel c maps back to input channel c / repeats, so each input
// channel's plane is duplicated `repeats` times consecutively.
override WORKGROUP_SIZE: u32 = 256u;

struct Uniforms {
  in_channels: u32,
  height: u32,
  width: u32,
  repeats: u32,
  _pad0: u32,
}

@group(0) @binding(0) var<uniform> u: Uniforms;
@group(0) @binding(1) var<storage, read> input: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let idx = gid.x;
  let plane = u.height * u.width;
  let total = u.in_channels * u.repeats * plane;
  if (idx >= total) {
    return;
  }

  let dst_channel = idx / plane;
  let src_channel = dst_channel / u.repeats;
  let offset = idx % plane;
  output[idx] = input[src_channel * plane + offset];
}
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
// f16 variant of the channel-repeat kernel for a CHW-flattened buffer.
// Output channel c maps back to input channel c / repeats.
enable f16;

override WORKGROUP_SIZE: u32 = 256u;

struct Uniforms {
  in_channels: u32,
  height: u32,
  width: u32,
  repeats: u32,
  _pad0: u32,
}

@group(0) @binding(0) var<uniform> u: Uniforms;
@group(0) @binding(1) var<storage, read> input: array<f16>;
@group(0) @binding(2) var<storage, read_write> output: array<f16>;

@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let idx = gid.x;
  let plane = u.height * u.width;
  let total = u.in_channels * u.repeats * plane;
  if (idx >= total) {
    return;
  }

  let dst_channel = idx / plane;
  let src_channel = dst_channel / u.repeats;
  let offset = idx % plane;
  output[idx] = input[src_channel * plane + offset];
}
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
import type { Tensor } from '../tensor.js';
|
|
2
|
+
import type { CommandRecorder } from '../command-recorder.js';
|
|
3
|
+
import type { OutputBufferOptions } from './types.js';
|
|
4
|
+
|
|
5
|
+
export interface SanaLinearAttentionOptions extends OutputBufferOptions {
|
|
6
|
+
numHeads: number;
|
|
7
|
+
headDim: number;
|
|
8
|
+
numTokens?: number;
|
|
9
|
+
hiddenSize?: number;
|
|
10
|
+
eps?: number;
|
|
11
|
+
summaryBuffer?: GPUBuffer | null;
|
|
12
|
+
}
|
|
13
|
+
|
|
14
|
+
export declare function runSanaLinearAttention(
|
|
15
|
+
query: Tensor,
|
|
16
|
+
key: Tensor,
|
|
17
|
+
value: Tensor,
|
|
18
|
+
options: SanaLinearAttentionOptions
|
|
19
|
+
): Promise<Tensor>;
|
|
20
|
+
|
|
21
|
+
export declare function recordSanaLinearAttention(
|
|
22
|
+
recorder: CommandRecorder,
|
|
23
|
+
query: Tensor,
|
|
24
|
+
key: Tensor,
|
|
25
|
+
value: Tensor,
|
|
26
|
+
options: SanaLinearAttentionOptions
|
|
27
|
+
): Promise<Tensor>;
|
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
import { getDevice } from '../device.js';
import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
import { createTensor, dtypeBytes } from '../tensor.js';
import { unifiedKernelWrapper } from './utils.js';
import { selectRuleValue } from './rule-registry.js';
import { WORKGROUP_SIZES } from './constants.js';

/** Look up the shader variant (f16 vs f32 inputs) for both attention passes. */
function selectSanaLinearAttentionVariant(isF16) {
  return selectRuleValue('sanaLinearAttention', 'variant', { isF16 });
}

/**
 * Dispatch the summary pass, which accumulates a per-head
 * (head_dim + 1) x head_dim statistics matrix into summaryBuffer.
 */
async function runSummary(target, query, key, value, summaryBuffer, uniforms, variant) {
  const elementCount = uniforms.num_heads * (uniforms.head_dim + 1) * uniforms.head_dim;
  const passUniforms = {
    num_heads: uniforms.num_heads,
    head_dim: uniforms.head_dim,
    num_tokens: uniforms.num_tokens,
    hidden_size: uniforms.hidden_size,
    _pad0: 0,
    _pad1: 0,
  };
  await unifiedKernelWrapper(
    'sana_linear_attention_summary',
    target,
    variant,
    [query, key, value, summaryBuffer],
    passUniforms,
    Math.ceil(elementCount / WORKGROUP_SIZES.DEFAULT)
  );
}

/**
 * Dispatch the apply pass, which combines the query with the summary matrix
 * and writes the normalized attention output.
 */
async function runApply(target, query, summaryBuffer, outputBuffer, uniforms, variant) {
  const elementCount = uniforms.num_tokens * uniforms.hidden_size;
  const passUniforms = {
    num_heads: uniforms.num_heads,
    head_dim: uniforms.head_dim,
    num_tokens: uniforms.num_tokens,
    hidden_size: uniforms.hidden_size,
    eps: uniforms.eps,
    _pad0: 0,
    _pad1: 0,
    _pad2: 0,
  };
  await unifiedKernelWrapper(
    'sana_linear_attention_apply',
    target,
    variant,
    [query, summaryBuffer, outputBuffer],
    passUniforms,
    Math.ceil(elementCount / WORKGROUP_SIZES.DEFAULT)
  );
}

/**
 * Shared implementation behind run/record entry points. Runs the two-pass
 * SANA linear attention (summary, then apply) and returns a
 * [numTokens, hiddenSize] tensor with the query's dtype.
 *
 * @param target - Command recorder to record into, or null for immediate execution.
 * @throws When no WebGPU device is available, when any dimension is missing
 *         or non-finite, or when hiddenSize !== numHeads * headDim.
 */
async function _sanaLinearAttention(target, query, key, value, options = {}) {
  const recorder = target && typeof target.beginComputePass === 'function' ? target : null;
  const device = target?.device || getDevice();
  if (!device) {
    throw new Error('SanaLinearAttention requires a WebGPU device.');
  }

  const {
    numHeads,
    headDim,
    numTokens = query.shape?.[0],
    hiddenSize = query.shape?.[1],
    eps = 1e-15,
    outputBuffer = null,
    summaryBuffer = null,
  } = options;

  const dimensionsValid =
    Number.isFinite(numHeads) &&
    Number.isFinite(headDim) &&
    Number.isFinite(numTokens) &&
    Number.isFinite(hiddenSize);
  if (!dimensionsValid) {
    throw new Error('SanaLinearAttention requires numHeads, headDim, numTokens, and hiddenSize.');
  }
  if (hiddenSize !== numHeads * headDim) {
    throw new Error(`SanaLinearAttention hiddenSize mismatch: ${hiddenSize} != ${numHeads} * ${headDim}`);
  }

  const variant = selectSanaLinearAttentionVariant(query.dtype === 'f16');
  // The summary matrix is accumulated in f32 regardless of the input dtype.
  const scratchSummary =
    summaryBuffer ||
    acquireBuffer(
      numHeads * (headDim + 1) * headDim * Float32Array.BYTES_PER_ELEMENT,
      undefined,
      'sana_linear_attention_summary'
    );
  const output =
    outputBuffer ||
    acquireBuffer(
      numTokens * hiddenSize * dtypeBytes(query.dtype),
      undefined,
      'sana_linear_attention_output'
    );

  const uniforms = {
    num_heads: numHeads,
    head_dim: headDim,
    num_tokens: numTokens,
    hidden_size: hiddenSize,
    eps,
  };

  await runSummary(target, query, key, value, scratchSummary, uniforms, variant);
  await runApply(target, query, scratchSummary, output, uniforms, variant);

  // Reclaim the scratch summary only when we allocated it ourselves.
  if (!summaryBuffer) {
    if (recorder) {
      // Recorded work has not executed yet; let the recorder release it later.
      recorder.trackTemporaryBuffer(scratchSummary);
    } else {
      releaseBuffer(scratchSummary);
    }
  }

  return createTensor(output, query.dtype, [numTokens, hiddenSize], 'sana_linear_attention_output');
}

/** Execute SANA linear attention immediately. */
export async function runSanaLinearAttention(query, key, value, options = {}) {
  return _sanaLinearAttention(null, query, key, value, options);
}

/** Record SANA linear attention into an existing command recorder. */
export async function recordSanaLinearAttention(recorder, query, key, value, options = {}) {
  return _sanaLinearAttention(recorder, query, key, value, options);
}
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
// Apply pass of SANA linear attention (f32).
// Each invocation produces one output element (token, head, dim) as
//   sum_i summary[dim][i] * relu(q_i)  /  (sum_i summary[head_dim][i] * relu(q_i) + eps)
// where summary is the per-head (head_dim + 1) x head_dim matrix from the
// summary pass and q is the token's query slice for that head.
override WORKGROUP_SIZE: u32 = 256u;

struct Uniforms {
  num_heads: u32,
  head_dim: u32,
  num_tokens: u32,
  hidden_size: u32,
  eps: f32,
  _pad0: u32,
  _pad1: u32,
  _pad2: u32,
}

@group(0) @binding(0) var<uniform> u: Uniforms;
@group(0) @binding(1) var<storage, read> query: array<f32>;
@group(0) @binding(2) var<storage, read> summary: array<f32>;
@group(0) @binding(3) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let idx = gid.x;
  let total = u.num_tokens * u.hidden_size;
  if (idx >= total) {
    return;
  }

  // Decompose the flat index into (token, head, dim).
  let token = idx / u.hidden_size;
  let hidden = idx % u.hidden_size;
  let head = hidden / u.head_dim;
  let dim = hidden % u.head_dim;

  let rows_per_head = u.head_dim + 1u;
  let head_offset = head * rows_per_head * u.head_dim;
  let hidden_base = head * u.head_dim;

  var numerator: f32 = 0.0;
  var denominator: f32 = 0.0;
  for (var i: u32 = 0u; i < u.head_dim; i = i + 1u) {
    let q_value = max(query[token * u.hidden_size + hidden_base + i], 0.0);
    numerator = numerator + summary[head_offset + dim * u.head_dim + i] * q_value;
    // Row head_dim of the summary holds the normalizer sums.
    denominator = denominator + summary[head_offset + u.head_dim * u.head_dim + i] * q_value;
  }

  output[idx] = numerator / (denominator + u.eps);
}
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
// Apply pass of SANA linear attention, f16 inputs/outputs.
// Math is carried out in f32 (the summary is f32) and the result is clamped
// to the finite f16 range before the final cast.
enable f16;

override WORKGROUP_SIZE: u32 = 256u;

struct Uniforms {
  num_heads: u32,
  head_dim: u32,
  num_tokens: u32,
  hidden_size: u32,
  eps: f32,
  _pad0: u32,
  _pad1: u32,
  _pad2: u32,
}

@group(0) @binding(0) var<uniform> u: Uniforms;
@group(0) @binding(1) var<storage, read> query: array<f16>;
@group(0) @binding(2) var<storage, read> summary: array<f32>;
@group(0) @binding(3) var<storage, read_write> output: array<f16>;

@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let idx = gid.x;
  let total = u.num_tokens * u.hidden_size;
  if (idx >= total) {
    return;
  }

  // Decompose the flat index into (token, head, dim).
  let token = idx / u.hidden_size;
  let hidden = idx % u.hidden_size;
  let head = hidden / u.head_dim;
  let dim = hidden % u.head_dim;

  let rows_per_head = u.head_dim + 1u;
  let head_offset = head * rows_per_head * u.head_dim;
  let hidden_base = head * u.head_dim;

  var numerator: f32 = 0.0;
  var denominator: f32 = 0.0;
  for (var i: u32 = 0u; i < u.head_dim; i = i + 1u) {
    let q_value = max(f32(query[token * u.hidden_size + hidden_base + i]), 0.0);
    numerator = numerator + summary[head_offset + dim * u.head_dim + i] * q_value;
    // Row head_dim of the summary holds the normalizer sums.
    denominator = denominator + summary[head_offset + u.head_dim * u.head_dim + i] * q_value;
  }

  let result = numerator / (denominator + u.eps);
  // 65504 is the largest finite f16 value.
  output[idx] = f16(clamp(result, -65504.0, 65504.0));
}
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
// Summary pass of SANA linear attention (f32).
// For each head, fills a (head_dim + 1) x head_dim matrix:
//   rows 0..head_dim-1 : sum over tokens of value[row] * max(key[col], 0)
//   row  head_dim      : sum over tokens of max(key[col], 0)   (normalizer)
// One invocation computes one matrix element. The query buffer is bound but
// not read in this pass.
override WORKGROUP_SIZE: u32 = 256u;

struct Uniforms {
  num_heads: u32,
  head_dim: u32,
  num_tokens: u32,
  hidden_size: u32,
  _pad0: u32,
  _pad1: u32,
}

@group(0) @binding(0) var<uniform> u: Uniforms;
@group(0) @binding(1) var<storage, read> query: array<f32>;
@group(0) @binding(2) var<storage, read> key: array<f32>;
@group(0) @binding(3) var<storage, read> value: array<f32>;
@group(0) @binding(4) var<storage, read_write> summary: array<f32>;

@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let idx = gid.x;
  let rows_per_head = u.head_dim + 1u;
  let head_span = rows_per_head * u.head_dim;
  let total = u.num_heads * head_span;
  if (idx >= total) {
    return;
  }

  let head = idx / head_span;
  let rem = idx - head * head_span;
  let row = rem / u.head_dim;
  let col = rem - row * u.head_dim;
  let hidden_base = head * u.head_dim;

  // The normalizer row uses 1.0 in place of a value element. A select()
  // here would evaluate value[...] eagerly and index out of bounds when
  // row == head_dim, so branch instead of using select.
  let is_normalizer_row = row == u.head_dim;

  var acc: f32 = 0.0;
  for (var token: u32 = 0u; token < u.num_tokens; token = token + 1u) {
    let key_value = max(key[token * u.hidden_size + hidden_base + col], 0.0);
    var value_value: f32 = 1.0;
    if (!is_normalizer_row) {
      value_value = value[token * u.hidden_size + hidden_base + row];
    }
    acc = acc + value_value * key_value;
  }

  summary[idx] = acc;
}
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
// Summary pass of SANA linear attention, f16 inputs with f32 accumulation.
// For each head, fills a (head_dim + 1) x head_dim matrix:
//   rows 0..head_dim-1 : sum over tokens of value[row] * max(key[col], 0)
//   row  head_dim      : sum over tokens of max(key[col], 0)   (normalizer)
// One invocation computes one matrix element. The query buffer is bound but
// not read in this pass.
enable f16;

override WORKGROUP_SIZE: u32 = 256u;

struct Uniforms {
  num_heads: u32,
  head_dim: u32,
  num_tokens: u32,
  hidden_size: u32,
  _pad0: u32,
  _pad1: u32,
}

@group(0) @binding(0) var<uniform> u: Uniforms;
@group(0) @binding(1) var<storage, read> query: array<f16>;
@group(0) @binding(2) var<storage, read> key: array<f16>;
@group(0) @binding(3) var<storage, read> value: array<f16>;
@group(0) @binding(4) var<storage, read_write> summary: array<f32>;

@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let idx = gid.x;
  let rows_per_head = u.head_dim + 1u;
  let head_span = rows_per_head * u.head_dim;
  let total = u.num_heads * head_span;
  if (idx >= total) {
    return;
  }

  let head = idx / head_span;
  let rem = idx - head * head_span;
  let row = rem / u.head_dim;
  let col = rem - row * u.head_dim;
  let hidden_base = head * u.head_dim;

  // The normalizer row uses 1.0 in place of a value element. A select()
  // here would evaluate value[...] eagerly and index out of bounds when
  // row == head_dim, so branch instead of using select.
  let is_normalizer_row = row == u.head_dim;

  var acc: f32 = 0.0;
  for (var token: u32 = 0u; token < u.num_tokens; token = token + 1u) {
    let key_value = max(f32(key[token * u.hidden_size + hidden_base + col]), 0.0);
    var value_value: f32 = 1.0;
    if (!is_normalizer_row) {
      value_value = f32(value[token * u.hidden_size + hidden_base + row]);
    }
    acc = acc + value_value * key_value;
  }

  summary[idx] = acc;
}
|
package/src/index-browser.d.ts
CHANGED
package/src/index-browser.js
CHANGED
package/src/index.js
CHANGED