@simulatte/doppler 0.1.4 → 0.1.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (103)
  1. package/README.md +4 -3
  2. package/package.json +25 -4
  3. package/src/client/doppler-api.browser.d.ts +1 -0
  4. package/src/client/doppler-api.browser.js +288 -0
  5. package/src/client/doppler-api.js +1 -1
  6. package/src/client/doppler-provider/types.js +1 -1
  7. package/src/config/execution-contract-check.d.ts +33 -0
  8. package/src/config/execution-contract-check.js +72 -0
  9. package/src/config/execution-v0-contract-check.d.ts +94 -0
  10. package/src/config/execution-v0-contract-check.js +251 -0
  11. package/src/config/execution-v0-graph-contract-check.d.ts +20 -0
  12. package/src/config/execution-v0-graph-contract-check.js +64 -0
  13. package/src/config/kernel-path-contract-check.d.ts +76 -0
  14. package/src/config/kernel-path-contract-check.js +479 -0
  15. package/src/config/kernel-path-loader.d.ts +16 -0
  16. package/src/config/kernel-path-loader.js +54 -0
  17. package/src/config/kernels/kernel-ref-digests.js +12 -0
  18. package/src/config/kernels/registry.json +556 -0
  19. package/src/config/loader.js +50 -46
  20. package/src/config/merge-contract-check.d.ts +16 -0
  21. package/src/config/merge-contract-check.js +321 -0
  22. package/src/config/merge-helpers.d.ts +58 -0
  23. package/src/config/merge-helpers.js +54 -0
  24. package/src/config/merge.js +3 -6
  25. package/src/config/presets/models/janus-text.json +2 -0
  26. package/src/config/quantization-contract-check.d.ts +12 -0
  27. package/src/config/quantization-contract-check.js +91 -0
  28. package/src/config/required-inference-fields-contract-check.d.ts +24 -0
  29. package/src/config/required-inference-fields-contract-check.js +231 -0
  30. package/src/config/schema/browser-suite-metrics.schema.d.ts +17 -0
  31. package/src/config/schema/browser-suite-metrics.schema.js +46 -0
  32. package/src/config/schema/conversion-report.schema.d.ts +40 -0
  33. package/src/config/schema/conversion-report.schema.js +108 -0
  34. package/src/config/schema/doppler.schema.js +12 -18
  35. package/src/config/schema/index.d.ts +22 -0
  36. package/src/config/schema/index.js +18 -0
  37. package/src/converter/core.d.ts +10 -0
  38. package/src/converter/core.js +27 -2
  39. package/src/converter/parsers/diffusion.js +63 -3
  40. package/src/gpu/kernels/depthwise_conv2d.d.ts +29 -0
  41. package/src/gpu/kernels/depthwise_conv2d.js +98 -0
  42. package/src/gpu/kernels/depthwise_conv2d.wgsl +58 -0
  43. package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +62 -0
  44. package/src/gpu/kernels/grouped_pointwise_conv2d.d.ts +27 -0
  45. package/src/gpu/kernels/grouped_pointwise_conv2d.js +92 -0
  46. package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +47 -0
  47. package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +51 -0
  48. package/src/gpu/kernels/index.d.ts +30 -0
  49. package/src/gpu/kernels/index.js +25 -0
  50. package/src/gpu/kernels/relu.d.ts +18 -0
  51. package/src/gpu/kernels/relu.js +45 -0
  52. package/src/gpu/kernels/relu.wgsl +21 -0
  53. package/src/gpu/kernels/relu_f16.wgsl +23 -0
  54. package/src/gpu/kernels/repeat_channels.d.ts +21 -0
  55. package/src/gpu/kernels/repeat_channels.js +60 -0
  56. package/src/gpu/kernels/repeat_channels.wgsl +29 -0
  57. package/src/gpu/kernels/repeat_channels_f16.wgsl +31 -0
  58. package/src/gpu/kernels/sana_linear_attention.d.ts +27 -0
  59. package/src/gpu/kernels/sana_linear_attention.js +122 -0
  60. package/src/gpu/kernels/sana_linear_attention_apply.wgsl +44 -0
  61. package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +47 -0
  62. package/src/gpu/kernels/sana_linear_attention_summary.wgsl +47 -0
  63. package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +49 -0
  64. package/src/index-browser.d.ts +1 -1
  65. package/src/index-browser.js +2 -2
  66. package/src/index.js +1 -1
  67. package/src/inference/browser-harness.js +62 -22
  68. package/src/inference/pipelines/diffusion/init.js +14 -0
  69. package/src/inference/pipelines/diffusion/pipeline.js +206 -77
  70. package/src/inference/pipelines/diffusion/sana-transformer.d.ts +53 -0
  71. package/src/inference/pipelines/diffusion/sana-transformer.js +738 -0
  72. package/src/inference/pipelines/diffusion/scheduler.d.ts +17 -1
  73. package/src/inference/pipelines/diffusion/scheduler.js +91 -3
  74. package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +6 -4
  75. package/src/inference/pipelines/diffusion/text-encoder-gpu.js +270 -0
  76. package/src/inference/pipelines/diffusion/text-encoder.js +18 -1
  77. package/src/inference/pipelines/diffusion/types.d.ts +4 -0
  78. package/src/inference/pipelines/diffusion/vae.js +782 -78
  79. package/src/inference/pipelines/text/config.d.ts +5 -0
  80. package/src/inference/pipelines/text/config.js +1 -1
  81. package/src/inference/pipelines/text/execution-v0.js +14 -93
  82. package/src/rules/execution-rules-contract-check.d.ts +17 -0
  83. package/src/rules/execution-rules-contract-check.js +245 -0
  84. package/src/rules/kernels/depthwise-conv2d.rules.json +6 -0
  85. package/src/rules/kernels/grouped-pointwise-conv2d.rules.json +6 -0
  86. package/src/rules/kernels/relu.rules.json +6 -0
  87. package/src/rules/kernels/repeat-channels.rules.json +6 -0
  88. package/src/rules/kernels/sana-linear-attention.rules.json +6 -0
  89. package/src/rules/layer-pattern-contract-check.d.ts +17 -0
  90. package/src/rules/layer-pattern-contract-check.js +231 -0
  91. package/src/rules/rule-registry.d.ts +28 -0
  92. package/src/rules/rule-registry.js +38 -0
  93. package/src/tooling/conversion-config-materializer.d.ts +24 -0
  94. package/src/tooling/conversion-config-materializer.js +99 -0
  95. package/src/tooling/lean-execution-contract-runner.d.ts +43 -0
  96. package/src/tooling/lean-execution-contract-runner.js +158 -0
  97. package/src/tooling/node-convert.d.ts +10 -0
  98. package/src/tooling/node-converter.js +59 -0
  99. package/src/tooling/node-webgpu.js +9 -9
  100. package/src/version.d.ts +2 -0
  101. package/src/version.js +2 -0
  102. package/tools/convert-safetensors-node.js +47 -0
  103. package/tools/doppler-cli.js +115 -1
@@ -4,6 +4,13 @@ const SD3_LAYOUT = {
4
4
  id: 'sd3',
5
5
  requiredComponents: ['transformer', 'text_encoder', 'text_encoder_2', 'text_encoder_3', 'vae', 'scheduler'],
6
6
  weightedComponents: ['transformer', 'text_encoder', 'text_encoder_2', 'text_encoder_3', 'vae'],
7
+ matches(modelIndex, components) {
8
+ return (
9
+ components.has('text_encoder_2') &&
10
+ components.has('text_encoder_3') &&
11
+ getComponentClassName(modelIndex?.transformer) === 'SD3Transformer2DModel'
12
+ );
13
+ },
7
14
  tokenizerSpecs: [
8
15
  {
9
16
  modelIndexKey: 'tokenizer',
@@ -66,6 +73,10 @@ const FLUX_LAYOUT = {
66
73
  id: 'flux',
67
74
  requiredComponents: ['transformer', 'text_encoder', 'vae', 'scheduler'],
68
75
  weightedComponents: ['transformer', 'text_encoder', 'vae'],
76
+ matches(modelIndex) {
77
+ const transformerClass = getComponentClassName(modelIndex?.transformer);
78
+ return typeof transformerClass === 'string' && /^Flux/i.test(transformerClass);
79
+ },
69
80
  tokenizerSpecs: [
70
81
  {
71
82
  modelIndexKey: 'tokenizer',
@@ -91,7 +102,39 @@ const FLUX_LAYOUT = {
91
102
  ],
92
103
  };
93
104
 
94
- const LAYOUTS = [SD3_LAYOUT, FLUX_LAYOUT];
105
// Output names for the Sana tokenizer assets. They are referenced both by the
// asset copy list and by the bundled tokenizer config, so they live in one place.
const SANA_TOKENIZER_FILES = {
  tokenizer: 'tokenizer_tokenizer.json',
  config: 'tokenizer_config.json',
  specialTokens: 'tokenizer_special_tokens_map.json',
  sentencePiece: 'tokenizer_tokenizer.model',
};

/**
 * Sana diffusers layout: a SanaTransformer2DModel transformer paired with a
 * Gemma2Model text encoder. Unlike the SD3/Flux layouts, `tokenizer` is listed
 * as a required component; parseDiffusionModel skips it when reading
 * per-component configs.
 */
const SANA_LAYOUT = {
  id: 'sana',
  requiredComponents: ['transformer', 'text_encoder', 'tokenizer', 'vae', 'scheduler'],
  weightedComponents: ['transformer', 'text_encoder', 'vae'],
  matches(modelIndex) {
    const transformerClass = getComponentClassName(modelIndex?.transformer);
    if (transformerClass !== 'SanaTransformer2DModel') {
      return false;
    }
    return getComponentClassName(modelIndex?.text_encoder) === 'Gemma2Model';
  },
  tokenizerSpecs: [
    {
      modelIndexKey: 'tokenizer',
      componentId: 'text_encoder',
      type: 'bundled',
      assets: [
        { suffix: 'tokenizer/tokenizer.json', targetName: SANA_TOKENIZER_FILES.tokenizer, kind: 'text', required: true },
        { suffix: 'tokenizer/tokenizer_config.json', targetName: SANA_TOKENIZER_FILES.config, kind: 'text', required: false },
        { suffix: 'tokenizer/special_tokens_map.json', targetName: SANA_TOKENIZER_FILES.specialTokens, kind: 'text', required: false },
        { suffix: 'tokenizer/tokenizer.model', targetName: SANA_TOKENIZER_FILES.sentencePiece, kind: 'binary', required: false },
      ],
      config: {
        type: 'bundled',
        tokenizerFile: SANA_TOKENIZER_FILES.tokenizer,
        configFile: SANA_TOKENIZER_FILES.config,
        specialTokensFile: SANA_TOKENIZER_FILES.specialTokens,
        sentencePieceFile: SANA_TOKENIZER_FILES.sentencePiece,
      },
    },
  ],
};

// Candidate layouts, tried in order: detectDiffusionLayout returns the first one
// whose required components are all present and whose matches() accepts.
const LAYOUTS = [SD3_LAYOUT, FLUX_LAYOUT, SANA_LAYOUT];
95
138
 
96
139
  function toAbortError(message = 'Cancelled') {
97
140
  if (typeof DOMException === 'function') {
@@ -112,12 +155,26 @@ function listModelComponents(modelIndex) {
112
155
  return Object.keys(modelIndex || {}).filter((key) => !key.startsWith('_'));
113
156
  }
114
157
 
158
/**
 * Extract the declared class name of a model_index.json component entry.
 *
 * Two shapes are recognised: a `[library, className]` array (only the second
 * element is used, and only when it is a string), or an object carrying a
 * string `_class_name`. Anything else yields null.
 *
 * @param componentEntry - Raw model_index.json value for one component.
 * @returns The class name string, or null when none can be determined.
 */
function getComponentClassName(componentEntry) {
  if (Array.isArray(componentEntry)) {
    const candidate = componentEntry.length >= 2 ? componentEntry[1] : undefined;
    if (typeof candidate === 'string') {
      return candidate;
    }
  }
  if (!componentEntry || typeof componentEntry !== 'object') {
    return null;
  }
  const declaredName = componentEntry._class_name;
  return typeof declaredName === 'string' ? declaredName : null;
}
167
+
115
168
  export function detectDiffusionLayout(modelIndex) {
116
169
  const components = new Set(listModelComponents(modelIndex));
117
170
  for (const layout of LAYOUTS) {
118
- if (layout.requiredComponents.every((component) => components.has(component))) {
119
- return layout;
171
+ if (!layout.requiredComponents.every((component) => components.has(component))) {
172
+ continue;
120
173
  }
174
+ if (typeof layout.matches === 'function' && !layout.matches(modelIndex, components)) {
175
+ continue;
176
+ }
177
+ return layout;
121
178
  }
122
179
  const listed = [...components].sort().join(', ') || '(none)';
123
180
  const expected = LAYOUTS
@@ -199,6 +256,9 @@ export async function parseDiffusionModel(adapter) {
199
256
  const tensors = [];
200
257
 
201
258
  for (const componentId of layout.requiredComponents) {
259
+ if (componentId === 'tokenizer') {
260
+ continue;
261
+ }
202
262
  const configSuffix = defaultConfigPath(componentId);
203
263
  const config = await readJson(configSuffix, `${componentId} config`);
204
264
  if (componentId === 'transformer' && config && !config.weight_format) {
@@ -0,0 +1,29 @@
1
+ import type { Tensor } from '../tensor.js';
2
+ import type { CommandRecorder } from '../command-recorder.js';
3
+ import type { OutputBufferOptions } from './types.js';
4
+ import type { WeightBuffer } from '../weight-buffer.js';
5
+
6
/**
 * Dimensions for a depthwise 2-D convolution over a single CHW tensor.
 *
 * Output spatial size is
 * `floor((height + 2*pad - kernelH) / stride) + 1` by
 * `floor((width + 2*pad - kernelW) / stride) + 1`.
 */
export interface DepthwiseConv2DOptions extends OutputBufferOptions {
  /** Channel count; channel c is filtered only by kernel c. */
  channels: number;
  /** Input height in elements. */
  height: number;
  /** Input width in elements. */
  width: number;
  /** Kernel height in elements. */
  kernelH: number;
  /** Kernel width in elements. */
  kernelW: number;
  /** Stride shared by both spatial axes (the implementation defaults it to 1). */
  stride?: number;
  /** Padding applied to each side of both spatial axes (the implementation defaults it to 0). */
  pad?: number;
}

/**
 * Dispatch a depthwise conv2d immediately.
 *
 * @param input - Input tensor laid out [channels, height, width]; whether its dtype is f16
 *   is passed to the kernel-variant selection rule.
 * @param weight - Per-channel kernels laid out [channels, kernelH, kernelW].
 * @param bias - Per-channel bias of length `channels`, or null for an all-zero bias.
 * @param options - Convolution dimensions and an optional pre-allocated output buffer.
 * @returns Tensor of shape [channels, outHeight, outWidth] with the input's dtype.
 */
export declare function runDepthwiseConv2D(
  input: Tensor,
  weight: GPUBuffer | WeightBuffer,
  bias: GPUBuffer | WeightBuffer | null,
  options: DepthwiseConv2DOptions,
): Promise<Tensor>;

/**
 * Record a depthwise conv2d into `recorder` instead of dispatching it immediately.
 * Parameters and result are the same as {@link runDepthwiseConv2D}.
 */
export declare function recordDepthwiseConv2D(
  recorder: CommandRecorder,
  input: Tensor,
  weight: GPUBuffer | WeightBuffer,
  bias: GPUBuffer | WeightBuffer | null,
  options: DepthwiseConv2DOptions,
): Promise<Tensor>;
@@ -0,0 +1,98 @@
1
+ import { getDevice } from '../device.js';
2
+ import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
3
+ import { createTensor, dtypeBytes } from '../tensor.js';
4
+ import { getBuffer } from '../weight-buffer.js';
5
+ import { unifiedKernelWrapper } from './utils.js';
6
+ import { selectRuleValue } from './rule-registry.js';
7
+ import { WORKGROUP_SIZES } from './constants.js';
8
+
9
// Resolve the shader variant (f16 vs f32) from the `depthwiseConv2d` kernel rules.
function selectDepthwiseConv2DVariant(isF16) {
  const ruleContext = { isF16 };
  return selectRuleValue('depthwiseConv2d', 'variant', ruleContext);
}
12
+
13
/**
 * Shared implementation behind runDepthwiseConv2D / recordDepthwiseConv2D.
 *
 * Convolves each channel of a single CHW input with its own kernelH x kernelW
 * filter and adds a per-channel bias. Out-of-bounds taps are skipped by the
 * shader, which is equivalent to zero padding.
 *
 * @param target - A command recorder (anything with beginComputePass) or null for immediate dispatch.
 * @param input - Input tensor laid out [channels, height, width].
 * @param weight - Kernel buffer laid out [channels, kernelH, kernelW].
 * @param bias - Per-channel bias buffer, or null to use a temporary all-zero bias.
 * @param options - { channels, height, width, kernelH, kernelW, stride = 1, pad = 0, outputBuffer = null }.
 * @returns Tensor of shape [channels, outHeight, outWidth] with the input's dtype.
 * @throws Error when dimensions are missing or invalid, or the output would be empty.
 */
async function _depthwiseConv2D(target, input, weight, bias, options = {}) {
  const recorder = target && typeof target.beginComputePass === 'function' ? target : null;
  const device = target?.device || getDevice();
  const {
    channels,
    height,
    width,
    kernelH,
    kernelW,
    stride = 1,
    pad = 0,
    outputBuffer = null,
  } = options;

  if (
    !Number.isFinite(channels) ||
    !Number.isFinite(height) ||
    !Number.isFinite(width) ||
    !Number.isFinite(kernelH) ||
    !Number.isFinite(kernelW)
  ) {
    throw new Error('DepthwiseConv2D requires explicit dimensions.');
  }
  // The shader reads every dimension as u32, so fractional or non-positive values
  // would silently produce garbage (or a zero-sized buffer).
  const dims = { channels, height, width, kernelH, kernelW };
  for (const [name, value] of Object.entries(dims)) {
    if (!Number.isInteger(value) || value <= 0) {
      throw new Error(`DepthwiseConv2D requires ${name} to be a positive integer. Got ${value}.`);
    }
  }
  // stride = 0 would make the output size Infinity and slip past the <= 0 check below.
  if (!Number.isInteger(stride) || stride < 1 || !Number.isInteger(pad) || pad < 0) {
    throw new Error(
      `DepthwiseConv2D requires a positive integer stride and a non-negative integer pad. Got stride=${stride}, pad=${pad}.`
    );
  }

  const outHeight = Math.floor((height + pad * 2 - kernelH) / stride) + 1;
  const outWidth = Math.floor((width + pad * 2 - kernelW) / stride) + 1;
  if (outHeight <= 0 || outWidth <= 0) {
    throw new Error(`DepthwiseConv2D invalid output size: ${outHeight}x${outWidth}`);
  }

  const isF16 = input.dtype === 'f16';
  const variant = selectDepthwiseConv2DVariant(isF16);
  // NOTE(review): weight/bias are assumed to share the input's dtype — confirm with callers.
  const bytesPerElement = dtypeBytes(input.dtype);
  const outputSize = channels * outHeight * outWidth * bytesPerElement;
  const output = outputBuffer || acquireBuffer(outputSize, undefined, 'depthwise_conv2d_output');

  const weightBuffer = getBuffer(weight);
  let biasBuffer = getBuffer(bias);
  let tempBias = null;
  if (!biasBuffer) {
    // GPUQueue.writeBuffer needs a 4-byte-multiple size that fits in the buffer.
    // With f16 and an odd channel count the raw bias size is not 4-aligned, so
    // allocate the padded size rather than writing past an unpadded allocation.
    const paddedBiasSize = Math.ceil((channels * bytesPerElement) / 4) * 4;
    tempBias = acquireBuffer(paddedBiasSize, undefined, 'depthwise_conv2d_bias_zero');
    biasBuffer = tempBias;
    device.queue.writeBuffer(biasBuffer, 0, new Uint8Array(paddedBiasSize));
  }

  try {
    await unifiedKernelWrapper(
      'depthwise_conv2d',
      target,
      variant,
      [input, weightBuffer, biasBuffer, output],
      {
        channels,
        height,
        width,
        out_height: outHeight,
        out_width: outWidth,
        kernel_h: kernelH,
        kernel_w: kernelW,
        stride,
        pad,
        // The shader's Uniforms struct declares _pad0.._pad2 (12 u32 = 48 bytes);
        // supply all three so the uniform data matches the struct size.
        _pad0: 0,
        _pad1: 0,
        _pad2: 0,
      },
      // NOTE(review): 1-D dispatch; large outputs can exceed WebGPU's
      // maxComputeWorkgroupsPerDimension (65535 by default) unless the wrapper splits it — confirm.
      Math.ceil((channels * outHeight * outWidth) / WORKGROUP_SIZES.DEFAULT)
    );
  } finally {
    // Release the zero bias even if the dispatch throws; when recording, the
    // recorder owns its lifetime until the recorded work is done.
    if (tempBias) {
      if (recorder) {
        recorder.trackTemporaryBuffer(tempBias);
      } else {
        releaseBuffer(tempBias);
      }
    }
  }

  return createTensor(output, input.dtype, [channels, outHeight, outWidth], 'depthwise_conv2d_output');
}
91
+
92
/** Dispatch a depthwise conv2d immediately (no command recorder). */
export async function runDepthwiseConv2D(input, weight, bias, options = {}) {
  const immediateTarget = null;
  return _depthwiseConv2D(immediateTarget, input, weight, bias, options);
}

/** Encode a depthwise conv2d into `recorder` for later submission. */
export async function recordDepthwiseConv2D(recorder, input, weight, bias, options = {}) {
  const recordingTarget = recorder;
  return _depthwiseConv2D(recordingTarget, input, weight, bias, options);
}
@@ -0,0 +1,58 @@
1
// Depthwise Conv2D Kernel (CHW, i.e. NCHW with N = 1, f32)
//
// One invocation computes one output element. Output channel c reads only
// input channel c and kernel c. Taps that fall outside the input because of
// `pad` are skipped, which is equivalent to zero padding.

override WORKGROUP_SIZE: u32 = 256u;

// NOTE(review): presumably packed from the host's uniforms object — keep the
// field names/order in sync with depthwise_conv2d.js.
struct Uniforms {
  channels: u32,
  height: u32,
  width: u32,
  out_height: u32,
  out_width: u32,
  kernel_h: u32,
  kernel_w: u32,
  stride: u32,
  pad: u32,
  _pad0: u32,
  _pad1: u32,
  _pad2: u32,
}

@group(0) @binding(0) var<uniform> u: Uniforms;
@group(0) @binding(1) var<storage, read> input: array<f32>;          // [channels, height, width]
@group(0) @binding(2) var<storage, read> weight: array<f32>;         // [channels, kernel_h, kernel_w]
@group(0) @binding(3) var<storage, read> bias: array<f32>;           // [channels]
@group(0) @binding(4) var<storage, read_write> output: array<f32>;   // [channels, out_height, out_width]

@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let out_idx = gid.x;
  let out_plane = u.out_height * u.out_width;
  if (out_idx >= u.channels * out_plane) {
    return;
  }

  // Split the flat output index into (channel, oy, ox).
  let ch = out_idx / out_plane;
  let within = out_idx % out_plane;
  let oy = within / u.out_width;
  let ox = within % u.out_width;

  let pad_i = i32(u.pad);
  let in_h = i32(u.height);
  let in_w = i32(u.width);
  let in_row_base = ch * u.height;
  let w_row_base = ch * u.kernel_h;

  var acc: f32 = bias[ch];
  for (var ky = 0u; ky < u.kernel_h; ky++) {
    let iy = i32(oy * u.stride + ky) - pad_i;
    if (iy < 0 || iy >= in_h) {
      continue;
    }
    let in_row = (in_row_base + u32(iy)) * u.width;
    let w_row = (w_row_base + ky) * u.kernel_w;
    for (var kx = 0u; kx < u.kernel_w; kx++) {
      let ix = i32(ox * u.stride + kx) - pad_i;
      if (ix < 0 || ix >= in_w) {
        continue;
      }
      acc += input[in_row + u32(ix)] * weight[w_row + kx];
    }
  }

  output[out_idx] = acc;
}
@@ -0,0 +1,62 @@
1
// Depthwise Conv2D Kernel (CHW, i.e. NCHW with N = 1, f16 storage, f32 accumulation)
//
// One invocation computes one output element. Output channel c reads only
// input channel c and kernel c. Out-of-bounds taps are skipped (zero padding).
// Products are accumulated in f32 and rounded to f16 once at the end.

enable f16;

override WORKGROUP_SIZE: u32 = 256u;

// NOTE(review): presumably packed from the host's uniforms object — keep the
// field names/order in sync with depthwise_conv2d.js.
struct Uniforms {
  channels: u32,
  height: u32,
  width: u32,
  out_height: u32,
  out_width: u32,
  kernel_h: u32,
  kernel_w: u32,
  stride: u32,
  pad: u32,
  _pad0: u32,
  _pad1: u32,
  _pad2: u32,
}

@group(0) @binding(0) var<uniform> u: Uniforms;
@group(0) @binding(1) var<storage, read> input: array<f16>;          // [channels, height, width]
@group(0) @binding(2) var<storage, read> weight: array<f16>;         // [channels, kernel_h, kernel_w]
@group(0) @binding(3) var<storage, read> bias: array<f16>;           // [channels]
@group(0) @binding(4) var<storage, read_write> output: array<f16>;   // [channels, out_height, out_width]

@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let out_idx = gid.x;
  let out_plane = u.out_height * u.out_width;
  if (out_idx >= u.channels * out_plane) {
    return;
  }

  // Split the flat output index into (channel, oy, ox).
  let ch = out_idx / out_plane;
  let within = out_idx % out_plane;
  let oy = within / u.out_width;
  let ox = within % u.out_width;

  let pad_i = i32(u.pad);
  let in_h = i32(u.height);
  let in_w = i32(u.width);
  let in_row_base = ch * u.height;
  let w_row_base = ch * u.kernel_h;

  var acc: f32 = f32(bias[ch]);
  for (var ky = 0u; ky < u.kernel_h; ky++) {
    let iy = i32(oy * u.stride + ky) - pad_i;
    if (iy < 0 || iy >= in_h) {
      continue;
    }
    let in_row = (in_row_base + u32(iy)) * u.width;
    let w_row = (w_row_base + ky) * u.kernel_w;
    for (var kx = 0u; kx < u.kernel_w; kx++) {
      let ix = i32(ox * u.stride + kx) - pad_i;
      if (ix < 0 || ix >= in_w) {
        continue;
      }
      acc += f32(input[in_row + u32(ix)]) * f32(weight[w_row + kx]);
    }
  }

  output[out_idx] = f16(acc);
}
@@ -0,0 +1,27 @@
1
+ import type { Tensor } from '../tensor.js';
2
+ import type { CommandRecorder } from '../command-recorder.js';
3
+ import type { OutputBufferOptions } from './types.js';
4
+ import type { WeightBuffer } from '../weight-buffer.js';
5
+
6
/**
 * Dimensions for a grouped 1x1 ("pointwise") convolution over a single CHW tensor.
 * Both `inChannels` and `outChannels` must be divisible by `groups`; the
 * implementation throws otherwise. Spatial size is preserved.
 */
export interface GroupedPointwiseConv2DOptions extends OutputBufferOptions {
  /** Input channel count. */
  inChannels: number;
  /** Output channel count. */
  outChannels: number;
  /** Spatial height (unchanged by the op). */
  height: number;
  /** Spatial width (unchanged by the op). */
  width: number;
  /** Number of channel groups; each output channel mixes only the input channels of its group. */
  groups: number;
}

/**
 * Dispatch a grouped pointwise conv2d immediately.
 *
 * @param input - Input tensor laid out [inChannels, height, width]; whether its dtype is f16
 *   is passed to the kernel-variant selection rule.
 * @param weight - Weights laid out [outChannels, inChannels / groups].
 * @param bias - Per-output-channel bias of length `outChannels`, or null for an all-zero bias.
 * @param options - Convolution dimensions and an optional pre-allocated output buffer.
 * @returns Tensor of shape [outChannels, height, width] with the input's dtype.
 */
export declare function runGroupedPointwiseConv2D(
  input: Tensor,
  weight: GPUBuffer | WeightBuffer,
  bias: GPUBuffer | WeightBuffer | null,
  options: GroupedPointwiseConv2DOptions,
): Promise<Tensor>;

/**
 * Record a grouped pointwise conv2d into `recorder` instead of dispatching it immediately.
 * Parameters and result are the same as {@link runGroupedPointwiseConv2D}.
 */
export declare function recordGroupedPointwiseConv2D(
  recorder: CommandRecorder,
  input: Tensor,
  weight: GPUBuffer | WeightBuffer,
  bias: GPUBuffer | WeightBuffer | null,
  options: GroupedPointwiseConv2DOptions,
): Promise<Tensor>;
@@ -0,0 +1,92 @@
1
+ import { getDevice } from '../device.js';
2
+ import { acquireBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
3
+ import { createTensor, dtypeBytes } from '../tensor.js';
4
+ import { getBuffer } from '../weight-buffer.js';
5
+ import { unifiedKernelWrapper } from './utils.js';
6
+ import { selectRuleValue } from './rule-registry.js';
7
+ import { WORKGROUP_SIZES } from './constants.js';
8
+
9
// Resolve the shader variant (f16 vs f32) from the `groupedPointwiseConv2d` kernel rules.
function selectGroupedPointwiseConv2DVariant(isF16) {
  const ruleContext = { isF16 };
  return selectRuleValue('groupedPointwiseConv2d', 'variant', ruleContext);
}
12
+
13
/**
 * Shared implementation behind runGroupedPointwiseConv2D / recordGroupedPointwiseConv2D.
 *
 * Splits the input channels into `groups` equal slices. Each output channel is
 * the dot product, at the same pixel, of its weight row with the input channels
 * of its group, plus a per-channel bias.
 *
 * @param target - A command recorder (anything with beginComputePass) or null for immediate dispatch.
 * @param input - Input tensor laid out [inChannels, height, width].
 * @param weight - Weight buffer laid out [outChannels, inChannels / groups].
 * @param bias - Per-output-channel bias buffer, or null to use a temporary all-zero bias.
 * @param options - { inChannels, outChannels, height, width, groups, outputBuffer = null }.
 * @returns Tensor of shape [outChannels, height, width] with the input's dtype.
 * @throws Error when dimensions are missing or invalid, or channels are not divisible by groups.
 */
async function _groupedPointwiseConv2D(target, input, weight, bias, options = {}) {
  const recorder = target && typeof target.beginComputePass === 'function' ? target : null;
  const device = target?.device || getDevice();
  const {
    inChannels,
    outChannels,
    height,
    width,
    groups,
    outputBuffer = null,
  } = options;

  if (
    !Number.isFinite(inChannels) ||
    !Number.isFinite(outChannels) ||
    !Number.isFinite(height) ||
    !Number.isFinite(width) ||
    !Number.isFinite(groups)
  ) {
    throw new Error('GroupedPointwiseConv2D requires explicit dimensions.');
  }
  // The shader reads every dimension as u32; fractional or zero values would
  // produce garbage or a zero-sized output buffer.
  const dims = { inChannels, outChannels, height, width };
  for (const [name, value] of Object.entries(dims)) {
    if (!Number.isInteger(value) || value <= 0) {
      throw new Error(`GroupedPointwiseConv2D requires ${name} to be a positive integer. Got ${value}.`);
    }
  }
  // A fractional `groups` (e.g. 1.5) can satisfy the % checks yet break the shader's u32 division.
  if (!Number.isInteger(groups) || groups <= 0 || inChannels % groups !== 0 || outChannels % groups !== 0) {
    throw new Error(
      `GroupedPointwiseConv2D requires inChannels/outChannels divisible by groups. Got ${inChannels}/${outChannels}/${groups}.`
    );
  }

  const isF16 = input.dtype === 'f16';
  const variant = selectGroupedPointwiseConv2DVariant(isF16);
  // NOTE(review): weight/bias are assumed to share the input's dtype — confirm with callers.
  const bytesPerElement = dtypeBytes(input.dtype);
  const outputSize = outChannels * height * width * bytesPerElement;
  const output = outputBuffer || acquireBuffer(outputSize, undefined, 'grouped_pointwise_conv2d_output');

  const weightBuffer = getBuffer(weight);
  let biasBuffer = getBuffer(bias);
  let tempBias = null;
  if (!biasBuffer) {
    // GPUQueue.writeBuffer needs a 4-byte-multiple size that fits in the buffer.
    // With f16 and an odd channel count the raw bias size is not 4-aligned, so
    // allocate the padded size rather than writing past an unpadded allocation.
    const paddedBiasSize = Math.ceil((outChannels * bytesPerElement) / 4) * 4;
    tempBias = acquireBuffer(paddedBiasSize, undefined, 'grouped_pointwise_conv2d_bias_zero');
    biasBuffer = tempBias;
    device.queue.writeBuffer(biasBuffer, 0, new Uint8Array(paddedBiasSize));
  }

  try {
    await unifiedKernelWrapper(
      'grouped_pointwise_conv2d',
      target,
      variant,
      [input, weightBuffer, biasBuffer, output],
      {
        in_channels: inChannels,
        out_channels: outChannels,
        height,
        width,
        groups,
        _pad0: 0,
        _pad1: 0,
        _pad2: 0,
      },
      // NOTE(review): 1-D dispatch; large outputs can exceed WebGPU's
      // maxComputeWorkgroupsPerDimension (65535 by default) unless the wrapper splits it — confirm.
      Math.ceil((outChannels * height * width) / WORKGROUP_SIZES.DEFAULT)
    );
  } finally {
    // Release the zero bias even if the dispatch throws; when recording, the
    // recorder owns its lifetime until the recorded work is done.
    if (tempBias) {
      if (recorder) {
        recorder.trackTemporaryBuffer(tempBias);
      } else {
        releaseBuffer(tempBias);
      }
    }
  }

  return createTensor(output, input.dtype, [outChannels, height, width], 'grouped_pointwise_conv2d_output');
}
85
+
86
/** Dispatch a grouped pointwise conv2d immediately (no command recorder). */
export async function runGroupedPointwiseConv2D(input, weight, bias, options = {}) {
  const immediateTarget = null;
  return _groupedPointwiseConv2D(immediateTarget, input, weight, bias, options);
}

/** Encode a grouped pointwise conv2d into `recorder` for later submission. */
export async function recordGroupedPointwiseConv2D(recorder, input, weight, bias, options = {}) {
  const recordingTarget = recorder;
  return _groupedPointwiseConv2D(recordingTarget, input, weight, bias, options);
}
@@ -0,0 +1,47 @@
1
// Grouped Pointwise (1x1) Conv2D Kernel (CHW, i.e. NCHW with N = 1, f32)
//
// One invocation computes one output element. Output channel oc belongs to
// group oc / (out_channels / groups) and reads only that group's
// in_channels / groups input channels at the same pixel.

override WORKGROUP_SIZE: u32 = 256u;

// NOTE(review): presumably packed from the host's uniforms object — keep the
// field names/order in sync with grouped_pointwise_conv2d.js.
struct Uniforms {
  in_channels: u32,
  out_channels: u32,
  height: u32,
  width: u32,
  groups: u32,
  _pad0: u32,
  _pad1: u32,
  _pad2: u32,
}

@group(0) @binding(0) var<uniform> u: Uniforms;
@group(0) @binding(1) var<storage, read> input: array<f32>;          // [in_channels, height, width]
@group(0) @binding(2) var<storage, read> weight: array<f32>;         // [out_channels, in_channels / groups]
@group(0) @binding(3) var<storage, read> bias: array<f32>;           // [out_channels]
@group(0) @binding(4) var<storage, read_write> output: array<f32>;   // [out_channels, height, width]

@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let out_idx = gid.x;
  let plane = u.height * u.width;
  if (out_idx >= u.out_channels * plane) {
    return;
  }

  let oc = out_idx / plane;
  let pixel = out_idx % plane;   // y * width + x

  let in_per_group = u.in_channels / u.groups;
  let out_per_group = u.out_channels / u.groups;
  let first_in_channel = (oc / out_per_group) * in_per_group;
  let w_row = oc * in_per_group;

  var acc: f32 = bias[oc];
  for (var i = 0u; i < in_per_group; i++) {
    acc += input[(first_in_channel + i) * plane + pixel] * weight[w_row + i];
  }

  output[out_idx] = acc;
}
@@ -0,0 +1,51 @@
1
// Grouped Pointwise (1x1) Conv2D Kernel (CHW, i.e. NCHW with N = 1, f16 storage, f32 accumulation)
//
// One invocation computes one output element. Output channel oc belongs to
// group oc / (out_channels / groups) and reads only that group's
// in_channels / groups input channels at the same pixel. Products are
// accumulated in f32 and rounded to f16 once at the end.

enable f16;

override WORKGROUP_SIZE: u32 = 256u;

// NOTE(review): presumably packed from the host's uniforms object — keep the
// field names/order in sync with grouped_pointwise_conv2d.js.
struct Uniforms {
  in_channels: u32,
  out_channels: u32,
  height: u32,
  width: u32,
  groups: u32,
  _pad0: u32,
  _pad1: u32,
  _pad2: u32,
}

@group(0) @binding(0) var<uniform> u: Uniforms;
@group(0) @binding(1) var<storage, read> input: array<f16>;          // [in_channels, height, width]
@group(0) @binding(2) var<storage, read> weight: array<f16>;         // [out_channels, in_channels / groups]
@group(0) @binding(3) var<storage, read> bias: array<f16>;           // [out_channels]
@group(0) @binding(4) var<storage, read_write> output: array<f16>;   // [out_channels, height, width]

@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let out_idx = gid.x;
  let plane = u.height * u.width;
  if (out_idx >= u.out_channels * plane) {
    return;
  }

  let oc = out_idx / plane;
  let pixel = out_idx % plane;   // y * width + x

  let in_per_group = u.in_channels / u.groups;
  let out_per_group = u.out_channels / u.groups;
  let first_in_channel = (oc / out_per_group) * in_per_group;
  let w_row = oc * in_per_group;

  var acc: f32 = f32(bias[oc]);
  for (var i = 0u; i < in_per_group; i++) {
    acc += f32(input[(first_in_channel + i) * plane + pixel]) * f32(weight[w_row + i]);
  }

  output[out_idx] = f16(acc);
}
@@ -174,6 +174,18 @@ export {
174
174
  type Conv2DOptions,
175
175
  } from './conv2d.js';
176
176
 
177
+ export {
178
+ runDepthwiseConv2D,
179
+ recordDepthwiseConv2D,
180
+ type DepthwiseConv2DOptions,
181
+ } from './depthwise_conv2d.js';
182
+
183
+ export {
184
+ runGroupedPointwiseConv2D,
185
+ recordGroupedPointwiseConv2D,
186
+ type GroupedPointwiseConv2DOptions,
187
+ } from './grouped_pointwise_conv2d.js';
188
+
177
189
  // Gather (Embedding Lookup)
178
190
  export {
179
191
  runGather,
@@ -250,6 +262,24 @@ export {
250
262
  type SampleResult,
251
263
  } from './sample.js';
252
264
 
265
+ export {
266
+ runSanaLinearAttention,
267
+ recordSanaLinearAttention,
268
+ type SanaLinearAttentionOptions,
269
+ } from './sana_linear_attention.js';
270
+
271
+ export {
272
+ runRepeatChannels,
273
+ recordRepeatChannels,
274
+ type RepeatChannelsOptions,
275
+ } from './repeat_channels.js';
276
+
277
+ export {
278
+ runReLU,
279
+ recordReLU,
280
+ type ReLUOptions,
281
+ } from './relu.js';
282
+
253
283
  // Fused FFN (Tier 2 P0)
254
284
  export {
255
285
  runFusedFFN,
@@ -139,6 +139,16 @@ export {
139
139
  recordConv2D,
140
140
  } from './conv2d.js';
141
141
 
142
+ export {
143
+ runDepthwiseConv2D,
144
+ recordDepthwiseConv2D,
145
+ } from './depthwise_conv2d.js';
146
+
147
+ export {
148
+ runGroupedPointwiseConv2D,
149
+ recordGroupedPointwiseConv2D,
150
+ } from './grouped_pointwise_conv2d.js';
151
+
142
152
  // Gather (Embedding Lookup)
143
153
  export {
144
154
  runGather,
@@ -205,6 +215,21 @@ export {
205
215
  isGPUSamplingAvailable,
206
216
  } from './sample.js';
207
217
 
218
+ export {
219
+ runSanaLinearAttention,
220
+ recordSanaLinearAttention,
221
+ } from './sana_linear_attention.js';
222
+
223
+ export {
224
+ runRepeatChannels,
225
+ recordRepeatChannels,
226
+ } from './repeat_channels.js';
227
+
228
+ export {
229
+ runReLU,
230
+ recordReLU,
231
+ } from './relu.js';
232
+
208
233
  // Fused FFN (Tier 2 P0)
209
234
  export {
210
235
  runFusedFFN,
@@ -0,0 +1,18 @@
1
+ import type { Tensor } from '../tensor.js';
2
+ import type { CommandRecorder } from '../command-recorder.js';
3
+ import type { OutputBufferOptions } from './types.js';
4
+
5
/** Options for the ReLU (max(x, 0)) kernel. */
export interface ReLUOptions extends OutputBufferOptions {
  /**
   * Optional element-count override.
   * NOTE(review): semantics live in relu.js — presumably the number of elements
   * to process, with null/undefined meaning "derive from the input tensor"; confirm.
   */
  count?: number | null;
}

/**
 * Apply ReLU to `input`, dispatching immediately.
 *
 * @param input - Tensor to activate.
 * @param options - Optional element count and pre-allocated output buffer.
 * @returns The activated tensor.
 */
export declare function runReLU(input: Tensor, options?: ReLUOptions): Promise<Tensor>;

/**
 * Record a ReLU into `recorder` instead of dispatching it immediately.
 * Parameters and result are the same as {@link runReLU}.
 */
export declare function recordReLU(
  recorder: CommandRecorder,
  input: Tensor,
  options?: ReLUOptions,
): Promise<Tensor>;