@simulatte/doppler 0.1.3 → 0.1.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (114)
  1. package/README.md +11 -5
  2. package/package.json +27 -4
  3. package/src/client/doppler-api.browser.d.ts +1 -0
  4. package/src/client/doppler-api.browser.js +288 -0
  5. package/src/client/doppler-api.d.ts +80 -0
  6. package/src/client/doppler-api.js +298 -0
  7. package/src/client/doppler-provider/types.js +1 -1
  8. package/src/client/doppler-registry.d.ts +23 -0
  9. package/src/client/doppler-registry.js +88 -0
  10. package/src/client/doppler-registry.json +39 -0
  11. package/src/config/execution-contract-check.d.ts +82 -0
  12. package/src/config/execution-contract-check.js +317 -0
  13. package/src/config/execution-v0-contract-check.d.ts +94 -0
  14. package/src/config/execution-v0-contract-check.js +251 -0
  15. package/src/config/execution-v0-graph-contract-check.d.ts +20 -0
  16. package/src/config/execution-v0-graph-contract-check.js +64 -0
  17. package/src/config/kernel-path-contract-check.d.ts +76 -0
  18. package/src/config/kernel-path-contract-check.js +479 -0
  19. package/src/config/kernel-path-loader.d.ts +16 -0
  20. package/src/config/kernel-path-loader.js +54 -0
  21. package/src/config/kernels/kernel-ref-digests.js +12 -0
  22. package/src/config/kernels/registry.json +556 -0
  23. package/src/config/loader.js +90 -67
  24. package/src/config/merge-contract-check.d.ts +16 -0
  25. package/src/config/merge-contract-check.js +321 -0
  26. package/src/config/merge-helpers.d.ts +58 -0
  27. package/src/config/merge-helpers.js +54 -0
  28. package/src/config/merge.js +3 -6
  29. package/src/config/presets/models/janus-text.json +27 -0
  30. package/src/config/quantization-contract-check.d.ts +12 -0
  31. package/src/config/quantization-contract-check.js +91 -0
  32. package/src/config/required-inference-fields-contract-check.d.ts +24 -0
  33. package/src/config/required-inference-fields-contract-check.js +231 -0
  34. package/src/config/schema/browser-suite-metrics.schema.d.ts +17 -0
  35. package/src/config/schema/browser-suite-metrics.schema.js +46 -0
  36. package/src/config/schema/conversion-report.schema.d.ts +40 -0
  37. package/src/config/schema/conversion-report.schema.js +108 -0
  38. package/src/config/schema/doppler.schema.js +12 -18
  39. package/src/config/schema/index.d.ts +22 -0
  40. package/src/config/schema/index.js +18 -0
  41. package/src/converter/core.d.ts +10 -0
  42. package/src/converter/core.js +49 -11
  43. package/src/converter/parsers/diffusion.js +63 -3
  44. package/src/converter/tokenizer-utils.js +17 -3
  45. package/src/formats/rdrr/validation.js +13 -0
  46. package/src/gpu/kernels/depthwise_conv2d.d.ts +29 -0
  47. package/src/gpu/kernels/depthwise_conv2d.js +98 -0
  48. package/src/gpu/kernels/depthwise_conv2d.wgsl +58 -0
  49. package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +62 -0
  50. package/src/gpu/kernels/grouped_pointwise_conv2d.d.ts +27 -0
  51. package/src/gpu/kernels/grouped_pointwise_conv2d.js +92 -0
  52. package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +47 -0
  53. package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +51 -0
  54. package/src/gpu/kernels/index.d.ts +30 -0
  55. package/src/gpu/kernels/index.js +25 -0
  56. package/src/gpu/kernels/relu.d.ts +18 -0
  57. package/src/gpu/kernels/relu.js +45 -0
  58. package/src/gpu/kernels/relu.wgsl +21 -0
  59. package/src/gpu/kernels/relu_f16.wgsl +23 -0
  60. package/src/gpu/kernels/repeat_channels.d.ts +21 -0
  61. package/src/gpu/kernels/repeat_channels.js +60 -0
  62. package/src/gpu/kernels/repeat_channels.wgsl +29 -0
  63. package/src/gpu/kernels/repeat_channels_f16.wgsl +31 -0
  64. package/src/gpu/kernels/sana_linear_attention.d.ts +27 -0
  65. package/src/gpu/kernels/sana_linear_attention.js +122 -0
  66. package/src/gpu/kernels/sana_linear_attention_apply.wgsl +44 -0
  67. package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +47 -0
  68. package/src/gpu/kernels/sana_linear_attention_summary.wgsl +47 -0
  69. package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +49 -0
  70. package/src/index-browser.d.ts +1 -0
  71. package/src/index-browser.js +2 -1
  72. package/src/index.d.ts +1 -0
  73. package/src/index.js +2 -1
  74. package/src/inference/browser-harness.js +164 -38
  75. package/src/inference/pipelines/diffusion/init.js +14 -0
  76. package/src/inference/pipelines/diffusion/pipeline.js +206 -77
  77. package/src/inference/pipelines/diffusion/sana-transformer.d.ts +53 -0
  78. package/src/inference/pipelines/diffusion/sana-transformer.js +738 -0
  79. package/src/inference/pipelines/diffusion/scheduler.d.ts +17 -1
  80. package/src/inference/pipelines/diffusion/scheduler.js +91 -3
  81. package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +6 -4
  82. package/src/inference/pipelines/diffusion/text-encoder-gpu.js +270 -0
  83. package/src/inference/pipelines/diffusion/text-encoder.js +18 -1
  84. package/src/inference/pipelines/diffusion/types.d.ts +4 -0
  85. package/src/inference/pipelines/diffusion/vae.js +782 -78
  86. package/src/inference/pipelines/text/config.d.ts +5 -0
  87. package/src/inference/pipelines/text/config.js +1 -1
  88. package/src/inference/pipelines/text/execution-v0.js +141 -101
  89. package/src/inference/pipelines/text/init.js +41 -10
  90. package/src/inference/pipelines/text.js +7 -1
  91. package/src/rules/execution-rules-contract-check.d.ts +17 -0
  92. package/src/rules/execution-rules-contract-check.js +245 -0
  93. package/src/rules/kernels/depthwise-conv2d.rules.json +6 -0
  94. package/src/rules/kernels/grouped-pointwise-conv2d.rules.json +6 -0
  95. package/src/rules/kernels/relu.rules.json +6 -0
  96. package/src/rules/kernels/repeat-channels.rules.json +6 -0
  97. package/src/rules/kernels/sana-linear-attention.rules.json +6 -0
  98. package/src/rules/layer-pattern-contract-check.d.ts +17 -0
  99. package/src/rules/layer-pattern-contract-check.js +231 -0
  100. package/src/rules/rule-registry.d.ts +28 -0
  101. package/src/rules/rule-registry.js +38 -0
  102. package/src/tooling/conversion-config-materializer.d.ts +24 -0
  103. package/src/tooling/conversion-config-materializer.js +99 -0
  104. package/src/tooling/lean-execution-contract-runner.d.ts +43 -0
  105. package/src/tooling/lean-execution-contract-runner.js +158 -0
  106. package/src/tooling/lean-execution-contract.d.ts +16 -0
  107. package/src/tooling/lean-execution-contract.js +81 -0
  108. package/src/tooling/node-convert.d.ts +10 -0
  109. package/src/tooling/node-converter.js +59 -0
  110. package/src/tooling/node-webgpu.js +30 -9
  111. package/src/version.d.ts +2 -0
  112. package/src/version.js +2 -0
  113. package/tools/convert-safetensors-node.js +47 -0
  114. package/tools/doppler-cli.js +167 -6
@@ -0,0 +1,27 @@
1
+ import type { Tensor } from '../tensor.js';
2
+ import type { CommandRecorder } from '../command-recorder.js';
3
+ import type { OutputBufferOptions } from './types.js';
4
+ import type { WeightBuffer } from '../weight-buffer.js';
5
+
6
/**
 * Options for {@link runGroupedPointwiseConv2D} / {@link recordGroupedPointwiseConv2D}.
 *
 * All dimension fields are required: the implementation throws when any of them is
 * not a finite number, or when `inChannels` / `outChannels` are not divisible by
 * `groups`. An optional caller-provided output buffer comes from
 * {@link OutputBufferOptions}.
 */
export interface GroupedPointwiseConv2DOptions extends OutputBufferOptions {
  /** Channels in the input tensor (channel-major CHW layout, no batch dimension). */
  inChannels: number;
  /** Channels in the output tensor; the result has shape `[outChannels, height, width]`. */
  outChannels: number;
  /** Spatial height, shared by input and output (a 1x1 conv keeps it unchanged). */
  height: number;
  /** Spatial width, shared by input and output. */
  width: number;
  /** Number of channel groups; must evenly divide both `inChannels` and `outChannels`. */
  groups: number;
}

/**
 * Runs a grouped pointwise (1x1) convolution immediately.
 *
 * @param input   Activation tensor; an `f16` dtype selects the f16 kernel variant.
 * @param weight  Weights, indexed by the shader as `[outChannels, inChannels / groups]`.
 * @param bias    Per-output-channel bias, or `null` to use a temporary all-zero bias.
 * @param options Convolution dimensions (see {@link GroupedPointwiseConv2DOptions}).
 * @returns A tensor of shape `[outChannels, height, width]` with the input's dtype.
 */
export declare function runGroupedPointwiseConv2D(
  input: Tensor,
  weight: GPUBuffer | WeightBuffer,
  bias: GPUBuffer | WeightBuffer | null,
  options: GroupedPointwiseConv2DOptions
): Promise<Tensor>;

/**
 * Records the same grouped pointwise convolution into `recorder`. A temporary
 * zero-bias buffer (used when `bias` is `null`) is handed to
 * `recorder.trackTemporaryBuffer` instead of being released right away.
 *
 * @param recorder Command recorder that receives the dispatch.
 * @param input    Activation tensor; an `f16` dtype selects the f16 kernel variant.
 * @param weight   Weights, indexed as `[outChannels, inChannels / groups]`.
 * @param bias     Per-output-channel bias, or `null` for an all-zero bias.
 * @param options  Convolution dimensions (see {@link GroupedPointwiseConv2DOptions}).
 * @returns A tensor of shape `[outChannels, height, width]` with the input's dtype.
 */
export declare function recordGroupedPointwiseConv2D(
  recorder: CommandRecorder,
  input: Tensor,
  weight: GPUBuffer | WeightBuffer,
  bias: GPUBuffer | WeightBuffer | null,
  options: GroupedPointwiseConv2DOptions
): Promise<Tensor>;
@@ -0,0 +1,92 @@
1
+ import { getDevice } from '../device.js';
2
+ import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
3
+ import { createTensor, dtypeBytes } from '../tensor.js';
4
+ import { getBuffer } from '../weight-buffer.js';
5
+ import { unifiedKernelWrapper } from './utils.js';
6
+ import { selectRuleValue } from './rule-registry.js';
7
+ import { WORKGROUP_SIZES } from './constants.js';
8
+
9
+ function selectGroupedPointwiseConv2DVariant(isF16) {
10
+ return selectRuleValue('groupedPointwiseConv2d', 'variant', { isF16 });
11
+ }
12
+
13
+ async function _groupedPointwiseConv2D(target, input, weight, bias, options = {}) {
14
+ const recorder = target && typeof target.beginComputePass === 'function' ? target : null;
15
+ const device = target?.device || getDevice();
16
+ const {
17
+ inChannels,
18
+ outChannels,
19
+ height,
20
+ width,
21
+ groups,
22
+ outputBuffer = null,
23
+ } = options;
24
+
25
+ if (
26
+ !Number.isFinite(inChannels) ||
27
+ !Number.isFinite(outChannels) ||
28
+ !Number.isFinite(height) ||
29
+ !Number.isFinite(width) ||
30
+ !Number.isFinite(groups)
31
+ ) {
32
+ throw new Error('GroupedPointwiseConv2D requires explicit dimensions.');
33
+ }
34
+ if (groups <= 0 || inChannels % groups !== 0 || outChannels % groups !== 0) {
35
+ throw new Error(
36
+ `GroupedPointwiseConv2D requires inChannels/outChannels divisible by groups. Got ${inChannels}/${outChannels}/${groups}.`
37
+ );
38
+ }
39
+
40
+ const isF16 = input.dtype === 'f16';
41
+ const variant = selectGroupedPointwiseConv2DVariant(isF16);
42
+ const bytesPerElement = dtypeBytes(input.dtype);
43
+ const outputSize = outChannels * height * width * bytesPerElement;
44
+ const output = outputBuffer || acquireBuffer(outputSize, undefined, 'grouped_pointwise_conv2d_output');
45
+
46
+ const weightBuffer = getBuffer(weight);
47
+ let biasBuffer = getBuffer(bias);
48
+ let tempBias = null;
49
+ if (!biasBuffer) {
50
+ const biasSize = outChannels * bytesPerElement;
51
+ tempBias = acquireBuffer(biasSize, undefined, 'grouped_pointwise_conv2d_bias_zero');
52
+ biasBuffer = tempBias;
53
+ const paddedSize = Math.ceil(biasSize / 4) * 4;
54
+ device.queue.writeBuffer(biasBuffer, 0, new Uint8Array(paddedSize));
55
+ }
56
+
57
+ await unifiedKernelWrapper(
58
+ 'grouped_pointwise_conv2d',
59
+ target,
60
+ variant,
61
+ [input, weightBuffer, biasBuffer, output],
62
+ {
63
+ in_channels: inChannels,
64
+ out_channels: outChannels,
65
+ height,
66
+ width,
67
+ groups,
68
+ _pad0: 0,
69
+ _pad1: 0,
70
+ _pad2: 0,
71
+ },
72
+ Math.ceil((outChannels * height * width) / WORKGROUP_SIZES.DEFAULT)
73
+ );
74
+
75
+ if (tempBias) {
76
+ if (recorder) {
77
+ recorder.trackTemporaryBuffer(tempBias);
78
+ } else {
79
+ releaseBuffer(tempBias);
80
+ }
81
+ }
82
+
83
+ return createTensor(output, input.dtype, [outChannels, height, width], 'grouped_pointwise_conv2d_output');
84
+ }
85
+
86
+ export async function runGroupedPointwiseConv2D(input, weight, bias, options = {}) {
87
+ return _groupedPointwiseConv2D(null, input, weight, bias, options);
88
+ }
89
+
90
+ export async function recordGroupedPointwiseConv2D(recorder, input, weight, bias, options = {}) {
91
+ return _groupedPointwiseConv2D(recorder, input, weight, bias, options);
92
+ }
@@ -0,0 +1,47 @@
1
// Grouped Pointwise (1x1) Conv2D Kernel (CHW layout, no batch dimension, f32)
//
// One invocation computes one output element:
//   output[oc, y, x] = bias[oc] + sum_{i < in_per_group} input[g * in_per_group + i, y, x] * weight[oc, i]
// where g = oc / out_per_group is the channel group that output channel oc belongs to.
//
// NOTE(review): indexing uses gid.x only, so the host must dispatch a 1-D grid with
// ceil(out_channels * height * width / WORKGROUP_SIZE) workgroups. WebGPU's default
// maxComputeWorkgroupsPerDimension (65535) limits this to ~16.7M outputs at a
// workgroup size of 256. Confirm how unifiedKernelWrapper handles larger dispatches,
// and that WORKGROUP_SIZES.DEFAULT matches WORKGROUP_SIZE.

override WORKGROUP_SIZE: u32 = 256u;

// Field order must match the uniform object written by grouped_pointwise_conv2d.js
// (5 dimension words + 3 pad words = 32 bytes).
struct Uniforms {
  in_channels: u32,
  out_channels: u32,
  height: u32,
  width: u32,
  groups: u32,
  _pad0: u32,
  _pad1: u32,
  _pad2: u32,
}

@group(0) @binding(0) var<uniform> u: Uniforms;
@group(0) @binding(1) var<storage, read> input: array<f32>;
@group(0) @binding(2) var<storage, read> weight: array<f32>;
@group(0) @binding(3) var<storage, read> bias: array<f32>;
@group(0) @binding(4) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let idx = gid.x;
  let spatial = u.height * u.width;
  let out_size = u.out_channels * spatial;
  // Bounds guard for the partially-filled last workgroup.
  if (idx >= out_size) {
    return;
  }

  // Decompose the flat output index into (out_channel, y, x).
  let out_channel = idx / spatial;
  let rem = idx - out_channel * spatial;
  let y = rem / u.width;
  let x = rem - y * u.width;

  // Locate the slice of input channels that feeds this output channel's group.
  let in_per_group = u.in_channels / u.groups;
  let out_per_group = u.out_channels / u.groups;
  let group_idx = out_channel / out_per_group;
  let in_offset = group_idx * in_per_group;

  var sum: f32 = bias[out_channel];
  for (var i: u32 = 0u; i < in_per_group; i = i + 1u) {
    let input_idx = ((in_offset + i) * u.height + y) * u.width + x;
    // Weights are laid out as [out_channels, in_per_group].
    let weight_idx = out_channel * in_per_group + i;
    sum = sum + input[input_idx] * weight[weight_idx];
  }

  output[idx] = sum;
}
@@ -0,0 +1,51 @@
1
// Grouped Pointwise (1x1) Conv2D Kernel (CHW layout, no batch dimension, f16)
//
// Same math as the f32 variant: output[oc, y, x] = bias[oc] + sum_i input[g * in_per_group + i, y, x] * weight[oc, i].
// Storage is f16, but the dot product is accumulated in f32 and converted back to
// f16 only once, at the final store.
//
// NOTE(review): indexing uses gid.x only, so this relies on a 1-D dispatch of
// ceil(out_channels * height * width / WORKGROUP_SIZE) workgroups. Confirm it stays
// under maxComputeWorkgroupsPerDimension for large feature maps.

enable f16;

override WORKGROUP_SIZE: u32 = 256u;

// Field order must match the uniform object written by grouped_pointwise_conv2d.js (32 bytes).
struct Uniforms {
  in_channels: u32,
  out_channels: u32,
  height: u32,
  width: u32,
  groups: u32,
  _pad0: u32,
  _pad1: u32,
  _pad2: u32,
}

@group(0) @binding(0) var<uniform> u: Uniforms;
@group(0) @binding(1) var<storage, read> input: array<f16>;
@group(0) @binding(2) var<storage, read> weight: array<f16>;
@group(0) @binding(3) var<storage, read> bias: array<f16>;
@group(0) @binding(4) var<storage, read_write> output: array<f16>;

@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let idx = gid.x;
  let spatial = u.height * u.width;
  let out_size = u.out_channels * spatial;
  // Bounds guard for the partially-filled last workgroup.
  if (idx >= out_size) {
    return;
  }

  // Decompose the flat output index into (out_channel, y, x).
  let out_channel = idx / spatial;
  let rem = idx - out_channel * spatial;
  let y = rem / u.width;
  let x = rem - y * u.width;

  // Locate the slice of input channels that feeds this output channel's group.
  let in_per_group = u.in_channels / u.groups;
  let out_per_group = u.out_channels / u.groups;
  let group_idx = out_channel / out_per_group;
  let in_offset = group_idx * in_per_group;

  // Accumulate in f32 to limit rounding error over long channel groups.
  var sum: f32 = f32(bias[out_channel]);
  for (var i: u32 = 0u; i < in_per_group; i = i + 1u) {
    let input_idx = ((in_offset + i) * u.height + y) * u.width + x;
    // Weights are laid out as [out_channels, in_per_group].
    let weight_idx = out_channel * in_per_group + i;
    sum = sum + f32(input[input_idx]) * f32(weight[weight_idx]);
  }

  output[idx] = f16(sum);
}
@@ -174,6 +174,18 @@ export {
174
174
  type Conv2DOptions,
175
175
  } from './conv2d.js';
176
176
 
177
// Depthwise convolution kernel: run/record entry points and options type.
export {
  runDepthwiseConv2D,
  recordDepthwiseConv2D,
  type DepthwiseConv2DOptions,
} from './depthwise_conv2d.js';

// Grouped pointwise (1x1) convolution kernel: run/record entry points and options type.
export {
  runGroupedPointwiseConv2D,
  recordGroupedPointwiseConv2D,
  type GroupedPointwiseConv2DOptions,
} from './grouped_pointwise_conv2d.js';
188
+
177
189
  // Gather (Embedding Lookup)
178
190
  export {
179
191
  runGather,
@@ -250,6 +262,24 @@ export {
250
262
  type SampleResult,
251
263
  } from './sample.js';
252
264
 
265
// Sana linear attention kernels: run/record entry points and options type.
export {
  runSanaLinearAttention,
  recordSanaLinearAttention,
  type SanaLinearAttentionOptions,
} from './sana_linear_attention.js';

// Channel repetition (each input channel emitted `repeats` times consecutively).
export {
  runRepeatChannels,
  recordRepeatChannels,
  type RepeatChannelsOptions,
} from './repeat_channels.js';

// Element-wise ReLU.
export {
  runReLU,
  recordReLU,
  type ReLUOptions,
} from './relu.js';
282
+
253
283
  // Fused FFN (Tier 2 P0)
254
284
  export {
255
285
  runFusedFFN,
@@ -139,6 +139,16 @@ export {
139
139
  recordConv2D,
140
140
  } from './conv2d.js';
141
141
 
142
// Depthwise convolution kernel (run/record entry points).
export {
  runDepthwiseConv2D,
  recordDepthwiseConv2D,
} from './depthwise_conv2d.js';

// Grouped pointwise (1x1) convolution kernel (run/record entry points).
export {
  runGroupedPointwiseConv2D,
  recordGroupedPointwiseConv2D,
} from './grouped_pointwise_conv2d.js';
151
+
142
152
  // Gather (Embedding Lookup)
143
153
  export {
144
154
  runGather,
@@ -205,6 +215,21 @@ export {
205
215
  isGPUSamplingAvailable,
206
216
  } from './sample.js';
207
217
 
218
// Sana linear attention kernels (run/record entry points).
export {
  runSanaLinearAttention,
  recordSanaLinearAttention,
} from './sana_linear_attention.js';

// Channel repetition (each input channel emitted `repeats` times consecutively).
export {
  runRepeatChannels,
  recordRepeatChannels,
} from './repeat_channels.js';

// Element-wise ReLU.
export {
  runReLU,
  recordReLU,
} from './relu.js';
232
+
208
233
  // Fused FFN (Tier 2 P0)
209
234
  export {
210
235
  runFusedFFN,
@@ -0,0 +1,18 @@
1
+ import type { Tensor } from '../tensor.js';
2
+ import type { CommandRecorder } from '../command-recorder.js';
3
+ import type { OutputBufferOptions } from './types.js';
4
+
5
/** Options for {@link runReLU} / {@link recordReLU}. */
export interface ReLUOptions extends OutputBufferOptions {
  /**
   * Number of elements to process. When absent, `null`, or not a positive finite
   * number, the count is the product of `input.shape`. If the input has no shape,
   * the count is the number of whole elements that fit in the input buffer.
   * Fractional overrides are floored.
   */
  count?: number | null;
}

/**
 * Applies `max(x, 0)` element-wise, executing immediately.
 *
 * @param input   Tensor to rectify; its dtype selects the kernel variant.
 * @param options Optional element-count override and output buffer.
 * @returns A tensor with the input's dtype and shape.
 */
export declare function runReLU(
  input: Tensor,
  options?: ReLUOptions
): Promise<Tensor>;

/**
 * Records the same element-wise ReLU into `recorder`.
 *
 * @param recorder Command recorder that receives the dispatch.
 * @param input    Tensor to rectify; its dtype selects the kernel variant.
 * @param options  Optional element-count override and output buffer.
 * @returns A tensor with the input's dtype and shape.
 */
export declare function recordReLU(
  recorder: CommandRecorder,
  input: Tensor,
  options?: ReLUOptions
): Promise<Tensor>;
@@ -0,0 +1,45 @@
1
+ import { acquireBuffer } from '../../memory/buffer-pool.js';
2
+ import { createTensor, dtypeBytes } from '../tensor.js';
3
+ import { unifiedKernelWrapper } from './utils.js';
4
+ import { selectRuleValue } from './rule-registry.js';
5
+ import { WORKGROUP_SIZES } from './constants.js';
6
+
7
+ function selectReluVariant(dtype) {
8
+ return selectRuleValue('relu', 'variant', { dtype });
9
+ }
10
+
11
+ function resolveCount(input, countOverride) {
12
+ if (Number.isFinite(countOverride) && countOverride > 0) {
13
+ return Math.floor(countOverride);
14
+ }
15
+ if (Array.isArray(input.shape) && input.shape.length > 0) {
16
+ return input.shape.reduce((acc, value) => acc * value, 1);
17
+ }
18
+ return Math.floor(input.buffer.size / dtypeBytes(input.dtype));
19
+ }
20
+
21
+ async function _relu(target, input, options = {}) {
22
+ const { count = null, outputBuffer = null } = options;
23
+ const size = resolveCount(input, count);
24
+ const variant = selectReluVariant(input.dtype);
25
+ const output = outputBuffer || acquireBuffer(size * dtypeBytes(input.dtype), undefined, 'relu_output');
26
+
27
+ await unifiedKernelWrapper(
28
+ 'relu',
29
+ target,
30
+ variant,
31
+ [input, output],
32
+ { size, _pad0: 0, _pad1: 0, _pad2: 0 },
33
+ Math.ceil(size / WORKGROUP_SIZES.DEFAULT)
34
+ );
35
+
36
+ return createTensor(output, input.dtype, [...input.shape], 'relu_output');
37
+ }
38
+
39
+ export async function runReLU(input, options = {}) {
40
+ return _relu(null, input, options);
41
+ }
42
+
43
+ export async function recordReLU(recorder, input, options = {}) {
44
+ return _relu(recorder, input, options);
45
+ }
@@ -0,0 +1,21 @@
1
// Element-wise ReLU kernel (f32): output[i] = max(input[i], 0).
//
// One invocation per element. Invocations with idx >= size exit early, so the host
// may round the dispatch up to whole workgroups.

override WORKGROUP_SIZE: u32 = 256u;

// Field order must match the uniform object written by relu.js
// (element count + 3 pad words = 16 bytes).
struct Uniforms {
  size: u32,
  _pad0: u32,
  _pad1: u32,
  _pad2: u32,
}

@group(0) @binding(0) var<uniform> u: Uniforms;
@group(0) @binding(1) var<storage, read> input: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let idx = gid.x;
  // Bounds guard for the partially-filled last workgroup.
  if (idx >= u.size) {
    return;
  }
  output[idx] = max(input[idx], 0.0);
}
@@ -0,0 +1,23 @@
1
// Element-wise ReLU kernel (f16): output[i] = max(input[i], 0).

enable f16;

override WORKGROUP_SIZE: u32 = 256u;

// Field order must match the uniform object written by relu.js (16 bytes).
struct Uniforms {
  size: u32,
  _pad0: u32,
  _pad1: u32,
  _pad2: u32,
}

@group(0) @binding(0) var<uniform> u: Uniforms;
@group(0) @binding(1) var<storage, read> input: array<f16>;
@group(0) @binding(2) var<storage, read_write> output: array<f16>;

@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let idx = gid.x;
  // Bounds guard for the partially-filled last workgroup.
  if (idx >= u.size) {
    return;
  }
  // 0.0h keeps the comparison in f16 (no f32 round trip).
  output[idx] = max(input[idx], 0.0h);
}
@@ -0,0 +1,21 @@
1
+ import type { Tensor } from '../tensor.js';
2
+ import type { CommandRecorder } from '../command-recorder.js';
3
+ import type { OutputBufferOptions } from './types.js';
4
+
5
/**
 * Options for {@link runRepeatChannels} / {@link recordRepeatChannels}.
 * `inChannels`, `height`, `width` and `repeats` are all required.
 */
export interface RepeatChannelsOptions extends OutputBufferOptions {
  /** Channels in the input tensor (channel-major CHW layout, no batch dimension). */
  inChannels: number;
  /** Spatial height, unchanged by the operation. */
  height: number;
  /** Spatial width, unchanged by the operation. */
  width: number;
  /**
   * How many times each input channel is emitted, consecutively (c0, c0, c1, c1, ...
   * for `repeats = 2`). The output has `inChannels * repeats` channels. Must be >= 1.
   */
  repeats: number;
}

/**
 * Repeats every input channel `repeats` times, executing immediately.
 *
 * @param input   CHW tensor; its dtype selects the kernel variant.
 * @param options Dimensions, repeat factor and optional output buffer.
 * @returns A tensor of shape `[inChannels * repeats, height, width]`.
 */
export declare function runRepeatChannels(
  input: Tensor,
  options: RepeatChannelsOptions
): Promise<Tensor>;

/**
 * Records the same channel repetition into `recorder`.
 *
 * @param recorder Command recorder that receives the dispatch.
 * @param input    CHW tensor; its dtype selects the kernel variant.
 * @param options  Dimensions, repeat factor and optional output buffer.
 * @returns A tensor of shape `[inChannels * repeats, height, width]`.
 */
export declare function recordRepeatChannels(
  recorder: CommandRecorder,
  input: Tensor,
  options: RepeatChannelsOptions
): Promise<Tensor>;
@@ -0,0 +1,60 @@
1
+ import { acquireBuffer } from '../../memory/buffer-pool.js';
2
+ import { createTensor, dtypeBytes } from '../tensor.js';
3
+ import { unifiedKernelWrapper } from './utils.js';
4
+ import { selectRuleValue } from './rule-registry.js';
5
+ import { WORKGROUP_SIZES } from './constants.js';
6
+
7
+ function selectRepeatChannelsVariant(dtype) {
8
+ return selectRuleValue('repeatChannels', 'variant', { dtype });
9
+ }
10
+
11
+ async function _repeatChannels(target, input, options = {}) {
12
+ const {
13
+ inChannels,
14
+ height,
15
+ width,
16
+ repeats,
17
+ outputBuffer = null,
18
+ } = options;
19
+
20
+ if (
21
+ !Number.isFinite(inChannels) ||
22
+ !Number.isFinite(height) ||
23
+ !Number.isFinite(width) ||
24
+ !Number.isFinite(repeats) ||
25
+ repeats < 1
26
+ ) {
27
+ throw new Error('RepeatChannels requires inChannels, height, width, and repeats.');
28
+ }
29
+
30
+ const outChannels = inChannels * repeats;
31
+ const variant = selectRepeatChannelsVariant(input.dtype);
32
+ const bytesPerElement = dtypeBytes(input.dtype);
33
+ const outputSize = outChannels * height * width * bytesPerElement;
34
+ const output = outputBuffer || acquireBuffer(outputSize, undefined, 'repeat_channels_output');
35
+
36
+ await unifiedKernelWrapper(
37
+ 'repeat_channels',
38
+ target,
39
+ variant,
40
+ [input, output],
41
+ {
42
+ in_channels: inChannels,
43
+ height,
44
+ width,
45
+ repeats,
46
+ _pad0: 0,
47
+ },
48
+ Math.ceil((outChannels * height * width) / WORKGROUP_SIZES.DEFAULT)
49
+ );
50
+
51
+ return createTensor(output, input.dtype, [outChannels, height, width], 'repeat_channels_output');
52
+ }
53
+
54
+ export async function runRepeatChannels(input, options = {}) {
55
+ return _repeatChannels(null, input, options);
56
+ }
57
+
58
+ export async function recordRepeatChannels(recorder, input, options = {}) {
59
+ return _repeatChannels(recorder, input, options);
60
+ }
@@ -0,0 +1,29 @@
1
// Repeat-channels kernel (CHW layout, no batch dimension, f32).
//
// Output channel oc copies input channel oc / repeats, so each input channel is
// emitted `repeats` times consecutively (repeats = 2 gives c0, c0, c1, c1, ...).
// One invocation per output element.

override WORKGROUP_SIZE: u32 = 256u;

// Field order must match the uniform object written by repeat_channels.js
// (4 dimension words + 1 pad word = 20 bytes).
struct Uniforms {
  in_channels: u32,
  height: u32,
  width: u32,
  repeats: u32,
  _pad0: u32,
}

@group(0) @binding(0) var<uniform> u: Uniforms;
@group(0) @binding(1) var<storage, read> input: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let idx = gid.x;
  let spatial = u.height * u.width;
  let out_channels = u.in_channels * u.repeats;
  let total = out_channels * spatial;
  // Bounds guard for the partially-filled last workgroup.
  if (idx >= total) {
    return;
  }

  let out_channel = idx / spatial;
  // Integer division maps `repeats` consecutive output channels onto one input channel.
  let channel = out_channel / u.repeats;
  let spatial_idx = idx - out_channel * spatial;
  output[idx] = input[channel * spatial + spatial_idx];
}
@@ -0,0 +1,31 @@
1
// Repeat-channels kernel (CHW layout, no batch dimension, f16).
//
// Output channel oc copies input channel oc / repeats (each input channel emitted
// `repeats` times consecutively). This is a pure copy, so no precision change occurs.

enable f16;

override WORKGROUP_SIZE: u32 = 256u;

// Field order must match the uniform object written by repeat_channels.js (20 bytes).
struct Uniforms {
  in_channels: u32,
  height: u32,
  width: u32,
  repeats: u32,
  _pad0: u32,
}

@group(0) @binding(0) var<uniform> u: Uniforms;
@group(0) @binding(1) var<storage, read> input: array<f16>;
@group(0) @binding(2) var<storage, read_write> output: array<f16>;

@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let idx = gid.x;
  let spatial = u.height * u.width;
  let out_channels = u.in_channels * u.repeats;
  let total = out_channels * spatial;
  // Bounds guard for the partially-filled last workgroup.
  if (idx >= total) {
    return;
  }

  let out_channel = idx / spatial;
  // Integer division maps `repeats` consecutive output channels onto one input channel.
  let channel = out_channel / u.repeats;
  let spatial_idx = idx - out_channel * spatial;
  output[idx] = input[channel * spatial + spatial_idx];
}
@@ -0,0 +1,27 @@
1
+ import type { Tensor } from '../tensor.js';
2
+ import type { CommandRecorder } from '../command-recorder.js';
3
+ import type { OutputBufferOptions } from './types.js';
4
+
5
/**
 * Options for {@link runSanaLinearAttention} / {@link recordSanaLinearAttention}.
 *
 * NOTE(review): sana_linear_attention.js is not part of this view. The field notes
 * below are read from names only and should be confirmed against it.
 */
export interface SanaLinearAttentionOptions extends OutputBufferOptions {
  /** Presumably the number of attention heads — verify. */
  numHeads: number;
  /** Presumably the per-head channel dimension — verify. */
  headDim: number;
  /** Optional token count; presumably derived from the inputs when omitted — verify. */
  numTokens?: number;
  /** Optional hidden size; presumably numHeads * headDim when omitted — verify. */
  hiddenSize?: number;
  /** Optional epsilon; presumably guards a normalizing division — verify. */
  eps?: number;
  /**
   * Optional caller-provided buffer, presumably for the intermediate result of the
   * summary pass (sana_linear_attention_summary*.wgsl) — verify.
   */
  summaryBuffer?: GPUBuffer | null;
}

/**
 * Runs Sana linear attention over `query`, `key` and `value`.
 * NOTE(review): whether this executes immediately is inferred from the run/record
 * naming used by the sibling kernels — confirm.
 */
export declare function runSanaLinearAttention(
  query: Tensor,
  key: Tensor,
  value: Tensor,
  options: SanaLinearAttentionOptions
): Promise<Tensor>;

/**
 * Records Sana linear attention into `recorder`.
 */
export declare function recordSanaLinearAttention(
  recorder: CommandRecorder,
  query: Tensor,
  key: Tensor,
  value: Tensor,
  options: SanaLinearAttentionOptions
): Promise<Tensor>;