mambacode.js 1.0.0 → 1.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +198 -76
- package/dist/index.d.ts +18 -0
- package/dist/index.d.ts.map +1 -0
- package/dist/index.js +18 -0
- package/dist/index.js.map +1 -0
- package/dist/kernels/activations.d.ts +3 -0
- package/dist/kernels/activations.d.ts.map +1 -0
- package/dist/kernels/activations.js +87 -0
- package/dist/kernels/activations.js.map +1 -0
- package/dist/kernels/conv1d.d.ts +3 -0
- package/dist/kernels/conv1d.d.ts.map +1 -0
- package/dist/kernels/conv1d.js +152 -0
- package/dist/kernels/conv1d.js.map +1 -0
- package/dist/kernels/linear_projection.d.ts +3 -0
- package/dist/kernels/linear_projection.d.ts.map +1 -0
- package/dist/kernels/linear_projection.js +219 -0
- package/dist/kernels/linear_projection.js.map +1 -0
- package/dist/kernels/selective_scan.d.ts +3 -0
- package/dist/kernels/selective_scan.d.ts.map +1 -0
- package/dist/kernels/selective_scan.js +348 -0
- package/dist/kernels/selective_scan.js.map +1 -0
- package/dist/kernels/weight_update.d.ts +3 -0
- package/dist/kernels/weight_update.d.ts.map +1 -0
- package/dist/kernels/weight_update.js +119 -0
- package/dist/kernels/weight_update.js.map +1 -0
- package/dist/model/mamba_block.d.ts +64 -0
- package/dist/model/mamba_block.d.ts.map +1 -0
- package/dist/model/mamba_block.js +309 -0
- package/dist/model/mamba_block.js.map +1 -0
- package/dist/model/mamba_model.d.ts +66 -0
- package/dist/model/mamba_model.d.ts.map +1 -0
- package/dist/model/mamba_model.js +289 -0
- package/dist/model/mamba_model.js.map +1 -0
- package/dist/tokenizer/bpe.d.ts +29 -0
- package/dist/tokenizer/bpe.d.ts.map +1 -0
- package/dist/tokenizer/bpe.js +164 -0
- package/dist/tokenizer/bpe.js.map +1 -0
- package/dist/training/autograd.d.ts +27 -0
- package/dist/training/autograd.d.ts.map +1 -0
- package/dist/training/autograd.js +120 -0
- package/dist/training/autograd.js.map +1 -0
- package/dist/training/trainer.d.ts +37 -0
- package/dist/training/trainer.d.ts.map +1 -0
- package/dist/training/trainer.js +183 -0
- package/dist/training/trainer.js.map +1 -0
- package/dist/utils/gpu_utils.d.ts +21 -0
- package/dist/utils/gpu_utils.d.ts.map +1 -0
- package/dist/utils/gpu_utils.js +111 -0
- package/dist/utils/gpu_utils.js.map +1 -0
- package/dist/utils/quantization.d.ts +26 -0
- package/dist/utils/quantization.d.ts.map +1 -0
- package/dist/utils/quantization.js +116 -0
- package/dist/utils/quantization.js.map +1 -0
- package/package.json +43 -18
- package/src/index.ts +59 -0
- package/src/kernels/{activations.js → activations.ts} +2 -2
- package/src/kernels/{linear_projection.js → linear_projection.ts} +2 -2
- package/src/kernels/{selective_scan.js → selective_scan.ts} +2 -2
- package/src/kernels/{weight_update.js → weight_update.ts} +2 -2
- package/src/model/{mamba_block.js → mamba_block.ts} +139 -175
- package/src/model/{mamba_model.js → mamba_model.ts} +168 -124
- package/src/tokenizer/bpe.ts +186 -0
- package/src/training/autograd.ts +135 -0
- package/src/training/trainer.ts +312 -0
- package/src/utils/gpu_utils.ts +147 -0
- package/src/utils/quantization.ts +154 -0
- package/src/index.js +0 -89
- package/src/tokenizer/bpe.js +0 -256
- package/src/training/autograd.js +0 -221
- package/src/training/trainer.js +0 -394
- package/src/utils/gpu_utils.js +0 -217
- package/src/utils/quantization.js +0 -215
- package/src/kernels/{conv1d.js → conv1d.ts} +0 -0
|
@@ -0,0 +1,152 @@
|
|
|
1
|
+
// 1D Causal Convolution WGSL Kernel
// Implements a depthwise 1D causal convolution over the sequence dimension.
// "Causal" means the output at position t only depends on positions <= t,
// which is enforced by left-padding with (kernel_size - 1) zeros.
//
// Forward: y[b, t, d] = sum_{k=0}^{K-1} weight[d, k] * x[b, t-k, d]
// where x[b, t', d] = 0 for t' < 0 (causal padding)
//
// Bind group layout (group 0):
//   binding 0: params (uniform, ConvParams)
//   binding 1: x      (storage, read)        — input,  flattened (B, L, D)
//   binding 2: weight (storage, read)        — depthwise weights, flattened (D, K)
//   binding 3: bias   (storage, read)        — per-channel bias (D,); pass zeros if unused
//   binding 4: y      (storage, read_write)  — output, flattened (B, L, D)
//
// One invocation computes one output element y[b, t, d]. Per the kernel's own
// dispatch comment, the host launches (ceil(L/16), ceil(D/16), B) workgroups
// of size (16, 16, 1); out-of-range invocations return early (safe here —
// this entry point uses no workgroup barriers).
export const CONV1D_FORWARD_WGSL = /* wgsl */ `

struct ConvParams {
  seq_len : u32, // L
  d_channels : u32, // D (number of depthwise channels)
  kernel_size : u32, // K (typically 4)
  batch : u32, // B
};

@group(0) @binding(0) var<uniform> params : ConvParams;
// x (B, L, D) – input
@group(0) @binding(1) var<storage, read> x : array<f32>;
// weight (D, K) – depthwise conv weights
@group(0) @binding(2) var<storage, read> weight : array<f32>;
// bias (D,) – optional bias (zeros if unused)
@group(0) @binding(3) var<storage, read> bias : array<f32>;
// y (B, L, D) – output
@group(0) @binding(4) var<storage, read_write> y : array<f32>;

// Dispatch: (ceil(L/16), ceil(D/16), B)
@compute @workgroup_size(16, 16, 1)
fn conv1d_forward(
  @builtin(global_invocation_id) gid : vec3<u32>,
) {
  let L = params.seq_len;
  let D = params.d_channels;
  let K = params.kernel_size;
  let B = params.batch;

  let t = gid.x; // time position
  let d = gid.y; // channel
  let b = gid.z; // batch

  if (t >= L || d >= D || b >= B) { return; }

  var acc: f32 = 0.0;

  // Causal: convolve over k = 0..K-1, reading position (t - k)
  for (var k: u32 = 0u; k < K; k = k + 1u) {
    let w_idx = d * K + k;
    let w_val = weight[w_idx];

    // t - k: use causal zero-padding for t < k
    if (t >= k) {
      let src = b * L * D + (t - k) * D + d;
      acc = acc + w_val * x[src];
    }
    // else: zero-padding contributes 0
  }

  acc = acc + bias[d];

  let out = b * L * D + t * D + d;
  y[out] = acc;
}
`;
|
|
64
|
+
// ---- Backward kernel for 1D convolution ----
//
// Gradients for the depthwise causal convolution above. Two entry points:
//   conv1d_backward_dx — dx[b, t, d] = sum_{k=0}^{K-1} dy[b, t+k, d] * weight[d, k]
//                        (transpose of the forward causal stencil; one thread
//                        per input element, dispatch (ceil(L/16), ceil(D/16), B))
//   conv1d_backward_dw — accumulates dweight[d, k] and dbias[d] by looping
//                        serially over all (b, t); one thread per (k, d) pair,
//                        dispatch (K, D, 1) with workgroup_size (1, 1, 1).
//                        The k == 0 thread for each channel also writes dbias[d],
//                        so dbias is computed exactly once per channel.
//
// Bind group layout (group 0): 0 params, 1 x, 2 weight, 3 dy (all read),
// 4 dx, 5 dweight, 6 dbias (read_write). Outputs are overwritten, not
// accumulated — callers need not zero-initialize dx/dweight/dbias.
export const CONV1D_BACKWARD_WGSL = /* wgsl */ `

struct ConvParams {
  seq_len : u32,
  d_channels : u32,
  kernel_size : u32,
  batch : u32,
};

@group(0) @binding(0) var<uniform> params : ConvParams;
@group(0) @binding(1) var<storage, read> x : array<f32>;
@group(0) @binding(2) var<storage, read> weight : array<f32>;
@group(0) @binding(3) var<storage, read> dy : array<f32>;
@group(0) @binding(4) var<storage, read_write> dx : array<f32>;
@group(0) @binding(5) var<storage, read_write> dweight : array<f32>;
@group(0) @binding(6) var<storage, read_write> dbias : array<f32>;

// Dispatch: (ceil(L/16), ceil(D/16), B) – computes dx
@compute @workgroup_size(16, 16, 1)
fn conv1d_backward_dx(
  @builtin(global_invocation_id) gid : vec3<u32>,
) {
  let L = params.seq_len;
  let D = params.d_channels;
  let K = params.kernel_size;
  let B = params.batch;

  let t = gid.x;
  let d = gid.y;
  let b = gid.z;

  if (t >= L || d >= D || b >= B) { return; }

  var grad: f32 = 0.0;

  // dx[b, t, d] = sum_{k=0}^{K-1} dy[b, t+k, d] * weight[d, k]
  for (var k: u32 = 0u; k < K; k = k + 1u) {
    let tp = t + k;
    if (tp < L) {
      let dy_idx = b * L * D + tp * D + d;
      let w_idx = d * K + k;
      grad = grad + dy[dy_idx] * weight[w_idx];
    }
  }

  let dx_idx = b * L * D + t * D + d;
  dx[dx_idx] = grad;
}

// Dispatch: (K, D, 1) – accumulates dweight over (B, L)
@compute @workgroup_size(1, 1, 1)
fn conv1d_backward_dw(
  @builtin(global_invocation_id) gid : vec3<u32>,
) {
  let L = params.seq_len;
  let D = params.d_channels;
  let K = params.kernel_size;
  let B = params.batch;

  let k = gid.x;
  let d = gid.y;

  if (k >= K || d >= D) { return; }

  var grad_w: f32 = 0.0;
  var grad_b: f32 = 0.0;

  for (var b: u32 = 0u; b < B; b = b + 1u) {
    for (var t: u32 = 0u; t < L; t = t + 1u) {
      let dy_idx = b * L * D + t * D + d;
      let dy_val = dy[dy_idx];
      if (t >= k) {
        let x_idx = b * L * D + (t - k) * D + d;
        grad_w = grad_w + dy_val * x[x_idx];
      }
      if (k == 0u) {
        grad_b = grad_b + dy_val;
      }
    }
  }

  dweight[d * K + k] = grad_w;
  if (k == 0u) {
    dbias[d] = grad_b;
  }
}
`;
|
|
152
|
+
//# sourceMappingURL=conv1d.js.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"conv1d.js","sourceRoot":"","sources":["../../src/kernels/conv1d.ts"],"names":[],"mappings":"AAAA,oCAAoC;AACpC,4EAA4E;AAC5E,0EAA0E;AAC1E,kEAAkE;AAClE,EAAE;AACF,qEAAqE;AACrE,+DAA+D;AAE/D,MAAM,CAAC,MAAM,mBAAmB,GAAG,UAAU,CAAA;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;CAuD5C,CAAC;AAEF,+CAA+C;AAC/C,MAAM,CAAC,MAAM,oBAAoB,GAAG,UAAU,CAAA;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;CAsF7C,CAAC"}
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"linear_projection.d.ts","sourceRoot":"","sources":["../../src/kernels/linear_projection.ts"],"names":[],"mappings":"AAUA,eAAO,MAAM,mBAAmB,EAAE,MAmEjC,CAAC;AAGF,eAAO,MAAM,oBAAoB,EAAE,MA2IlC,CAAC"}
|
|
@@ -0,0 +1,219 @@
|
|
|
1
|
+
// Linear Projection WGSL Kernel
// General-purpose matrix multiplication: Y = X @ W^T + b
// Supports the up-projection and down-projection linear layers in the Mamba block.
//
// Shapes:
//   X : (batch * seq_len, in_features)  – input (rows)
//   W : (out_features, in_features)     – weight matrix (row-major)
//   b : (out_features,)                 – bias
//   Y : (batch * seq_len, out_features) – output
//
// Implementation: classic shared-memory tiled matmul with 16x16 tiles.
// Each workgroup stages a tile of X and a tile of W in var<workgroup>
// arrays, synchronizes with workgroupBarrier(), then accumulates the
// partial dot product. Out-of-range threads load zeros (rather than
// returning early) so every thread in the workgroup reaches both
// barriers — required for well-defined barrier semantics in WGSL.
// The final store is bounds-guarded. Host dispatch implied by the
// thread mapping: (ceil(M/16), ceil(N/16), 1) workgroups.
export const LINEAR_FORWARD_WGSL = /* wgsl */ `

struct LinearParams {
  M : u32, // number of rows (batch * seq_len)
  K : u32, // in_features
  N : u32, // out_features
};

@group(0) @binding(0) var<uniform> params : LinearParams;
@group(0) @binding(1) var<storage, read> X : array<f32>; // (M, K)
@group(0) @binding(2) var<storage, read> W : array<f32>; // (N, K)
@group(0) @binding(3) var<storage, read> bias : array<f32>; // (N,)
@group(0) @binding(4) var<storage, read_write> Y : array<f32>; // (M, N)

// Tiled matmul using workgroup shared memory (16x16 tiles)
var<workgroup> tile_X : array<f32, 256>; // 16 * 16
var<workgroup> tile_W : array<f32, 256>;

@compute @workgroup_size(16, 16, 1)
fn linear_forward(
  @builtin(global_invocation_id) gid : vec3<u32>,
  @builtin(local_invocation_id) lid : vec3<u32>,
  @builtin(workgroup_id) wid : vec3<u32>,
) {
  let M = params.M;
  let K = params.K;
  let N = params.N;

  let row = gid.x; // output row (M dimension)
  let col = gid.y; // output col (N dimension)

  var acc: f32 = 0.0;
  let TILE: u32 = 16u;
  let num_tiles = (K + TILE - 1u) / TILE;

  for (var tile_idx: u32 = 0u; tile_idx < num_tiles; tile_idx = tile_idx + 1u) {
    // Load X tile: shape (TILE_M, TILE_K)
    let x_col = tile_idx * TILE + lid.y;
    let x_row = wid.x * TILE + lid.x;
    if (x_row < M && x_col < K) {
      tile_X[lid.x * TILE + lid.y] = X[x_row * K + x_col];
    } else {
      tile_X[lid.x * TILE + lid.y] = 0.0;
    }

    // Load W tile: shape (TILE_N, TILE_K) — W is (N, K)
    let w_col = tile_idx * TILE + lid.x; // K dimension
    let w_row = wid.y * TILE + lid.y; // N dimension
    if (w_row < N && w_col < K) {
      tile_W[lid.y * TILE + lid.x] = W[w_row * K + w_col];
    } else {
      tile_W[lid.y * TILE + lid.x] = 0.0;
    }

    workgroupBarrier();

    // Dot product within tile
    for (var k: u32 = 0u; k < TILE; k = k + 1u) {
      acc = acc + tile_X[lid.x * TILE + k] * tile_W[lid.y * TILE + k];
    }
    workgroupBarrier();
  }

  if (row < M && col < N) {
    Y[row * N + col] = acc + bias[col];
  }
}
`;
|
|
78
|
+
// ---- Backward pass for linear projection ----
//
// Gradients for Y = X @ W^T + b, split into three entry points that share
// one bind group (0: params, 1: X, 2: W, 3: dY read; 4: dX, 5: dW, 6: db
// read_write):
//   linear_backward_dX — dX = dY @ W         (M, K); tiled over N
//   linear_backward_dW — dW = dY^T @ X       (N, K); tiled over M
//   linear_backward_db — db[n] = sum_m dY[m, n]; one thread per column,
//                        serial reduction over M
// The two tiled kernels use the same shared-memory scheme as the forward
// pass: out-of-range threads stage zeros so all threads reach both
// workgroupBarrier() calls, and the final store is bounds-guarded.
// Outputs are overwritten, not accumulated.
//
// NOTE(review): the db dispatch comment inside the shader says (N, 1, 1),
// but the entry point uses workgroup_size(64, 1, 1) with an n >= N guard —
// the host presumably dispatches ceil(N/64) workgroups; confirm against
// the host-side dispatch code.
export const LINEAR_BACKWARD_WGSL = /* wgsl */ `

struct LinearParams {
  M : u32,
  K : u32,
  N : u32,
};

@group(0) @binding(0) var<uniform> params : LinearParams;
@group(0) @binding(1) var<storage, read> X : array<f32>; // (M, K)
@group(0) @binding(2) var<storage, read> W : array<f32>; // (N, K)
@group(0) @binding(3) var<storage, read> dY : array<f32>; // (M, N)
@group(0) @binding(4) var<storage, read_write> dX : array<f32>; // (M, K)
@group(0) @binding(5) var<storage, read_write> dW : array<f32>; // (N, K)
@group(0) @binding(6) var<storage, read_write> db : array<f32>; // (N,)

// Dispatch: (ceil(M/16), ceil(K/16), 1) – computes dX = dY @ W
var<workgroup> tile_dY : array<f32, 256>;
var<workgroup> tile_W : array<f32, 256>;

@compute @workgroup_size(16, 16, 1)
fn linear_backward_dX(
  @builtin(global_invocation_id) gid : vec3<u32>,
  @builtin(local_invocation_id) lid : vec3<u32>,
  @builtin(workgroup_id) wid : vec3<u32>,
) {
  let M = params.M;
  let K = params.K;
  let N = params.N;

  let row = gid.x; // M
  let col = gid.y; // K

  var acc: f32 = 0.0;
  let TILE: u32 = 16u;
  let num_tiles = (N + TILE - 1u) / TILE;

  for (var tile_idx: u32 = 0u; tile_idx < num_tiles; tile_idx = tile_idx + 1u) {
    // tile_dY: (M, TILE_N) slice
    let dy_col = tile_idx * TILE + lid.y;
    let dy_row = wid.x * TILE + lid.x;
    if (dy_row < M && dy_col < N) {
      tile_dY[lid.x * TILE + lid.y] = dY[dy_row * N + dy_col];
    } else {
      tile_dY[lid.x * TILE + lid.y] = 0.0;
    }

    // tile_W: (TILE_N, K) slice — W[n, k]
    let w_row = tile_idx * TILE + lid.x; // N
    let w_col = wid.y * TILE + lid.y; // K
    if (w_row < N && w_col < K) {
      tile_W[lid.x * TILE + lid.y] = W[w_row * K + w_col];
    } else {
      tile_W[lid.x * TILE + lid.y] = 0.0;
    }

    workgroupBarrier();

    for (var n: u32 = 0u; n < TILE; n = n + 1u) {
      acc = acc + tile_dY[lid.x * TILE + n] * tile_W[n * TILE + lid.y];
    }
    workgroupBarrier();
  }

  if (row < M && col < K) {
    dX[row * K + col] = acc;
  }
}

// Dispatch: (ceil(N/16), ceil(K/16), 1) – computes dW = dY^T @ X
var<workgroup> tile_dY2 : array<f32, 256>;
var<workgroup> tile_X2 : array<f32, 256>;

@compute @workgroup_size(16, 16, 1)
fn linear_backward_dW(
  @builtin(global_invocation_id) gid : vec3<u32>,
  @builtin(local_invocation_id) lid : vec3<u32>,
  @builtin(workgroup_id) wid : vec3<u32>,
) {
  let M = params.M;
  let K = params.K;
  let N = params.N;

  let row = gid.x; // N
  let col = gid.y; // K

  var acc: f32 = 0.0;
  let TILE: u32 = 16u;
  let num_tiles = (M + TILE - 1u) / TILE;

  for (var tile_idx: u32 = 0u; tile_idx < num_tiles; tile_idx = tile_idx + 1u) {
    // dY^T tile: [N, M] accessed as dY[m, n]
    let m_idx = tile_idx * TILE + lid.y;
    let n_idx = wid.x * TILE + lid.x;
    if (n_idx < N && m_idx < M) {
      tile_dY2[lid.x * TILE + lid.y] = dY[m_idx * N + n_idx];
    } else {
      tile_dY2[lid.x * TILE + lid.y] = 0.0;
    }

    // X tile: [M, K]
    let xm = tile_idx * TILE + lid.x;
    let xk = wid.y * TILE + lid.y;
    if (xm < M && xk < K) {
      tile_X2[lid.x * TILE + lid.y] = X[xm * K + xk];
    } else {
      tile_X2[lid.x * TILE + lid.y] = 0.0;
    }

    workgroupBarrier();

    for (var m: u32 = 0u; m < TILE; m = m + 1u) {
      acc = acc + tile_dY2[lid.x * TILE + m] * tile_X2[m * TILE + lid.y];
    }
    workgroupBarrier();
  }

  if (row < N && col < K) {
    dW[row * K + col] = acc;
  }
}

// Dispatch: (N, 1, 1) – accumulates db = sum_M dY
@compute @workgroup_size(64, 1, 1)
fn linear_backward_db(
  @builtin(global_invocation_id) gid : vec3<u32>,
) {
  let M = params.M;
  let N = params.N;

  let n = gid.x;
  if (n >= N) { return; }

  var acc: f32 = 0.0;
  for (var m: u32 = 0u; m < M; m = m + 1u) {
    acc = acc + dY[m * N + n];
  }
  db[n] = acc;
}
`;
|
|
219
|
+
//# sourceMappingURL=linear_projection.js.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"linear_projection.js","sourceRoot":"","sources":["../../src/kernels/linear_projection.ts"],"names":[],"mappings":"AAAA,gCAAgC;AAChC,yDAAyD;AACzD,mFAAmF;AACnF,EAAE;AACF,UAAU;AACV,wDAAwD;AACxD,qEAAqE;AACrE,gDAAgD;AAChD,kDAAkD;AAElD,MAAM,CAAC,MAAM,mBAAmB,GAAW,UAAU,CAAA;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;CAmEpD,CAAC;AAEF,gDAAgD;AAChD,MAAM,CAAC,MAAM,oBAAoB,GAAW,UAAU,CAAA;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;CA2IrD,CAAC"}
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"selective_scan.d.ts","sourceRoot":"","sources":["../../src/kernels/selective_scan.ts"],"names":[],"mappings":"AAUA,eAAO,MAAM,2BAA2B,EAAE,MAsNzC,CAAC;AAKF,eAAO,MAAM,4BAA4B,EAAE,MAwH1C,CAAC"}
|