mambacode.js 1.0.0 → 1.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. package/README.md +198 -76
  2. package/dist/index.d.ts +18 -0
  3. package/dist/index.d.ts.map +1 -0
  4. package/dist/index.js +18 -0
  5. package/dist/index.js.map +1 -0
  6. package/dist/kernels/activations.d.ts +3 -0
  7. package/dist/kernels/activations.d.ts.map +1 -0
  8. package/dist/kernels/activations.js +87 -0
  9. package/dist/kernels/activations.js.map +1 -0
  10. package/dist/kernels/conv1d.d.ts +3 -0
  11. package/dist/kernels/conv1d.d.ts.map +1 -0
  12. package/dist/kernels/conv1d.js +152 -0
  13. package/dist/kernels/conv1d.js.map +1 -0
  14. package/dist/kernels/linear_projection.d.ts +3 -0
  15. package/dist/kernels/linear_projection.d.ts.map +1 -0
  16. package/dist/kernels/linear_projection.js +219 -0
  17. package/dist/kernels/linear_projection.js.map +1 -0
  18. package/dist/kernels/selective_scan.d.ts +3 -0
  19. package/dist/kernels/selective_scan.d.ts.map +1 -0
  20. package/dist/kernels/selective_scan.js +348 -0
  21. package/dist/kernels/selective_scan.js.map +1 -0
  22. package/dist/kernels/weight_update.d.ts +3 -0
  23. package/dist/kernels/weight_update.d.ts.map +1 -0
  24. package/dist/kernels/weight_update.js +119 -0
  25. package/dist/kernels/weight_update.js.map +1 -0
  26. package/dist/model/mamba_block.d.ts +64 -0
  27. package/dist/model/mamba_block.d.ts.map +1 -0
  28. package/dist/model/mamba_block.js +309 -0
  29. package/dist/model/mamba_block.js.map +1 -0
  30. package/dist/model/mamba_model.d.ts +66 -0
  31. package/dist/model/mamba_model.d.ts.map +1 -0
  32. package/dist/model/mamba_model.js +289 -0
  33. package/dist/model/mamba_model.js.map +1 -0
  34. package/dist/tokenizer/bpe.d.ts +29 -0
  35. package/dist/tokenizer/bpe.d.ts.map +1 -0
  36. package/dist/tokenizer/bpe.js +164 -0
  37. package/dist/tokenizer/bpe.js.map +1 -0
  38. package/dist/training/autograd.d.ts +27 -0
  39. package/dist/training/autograd.d.ts.map +1 -0
  40. package/dist/training/autograd.js +120 -0
  41. package/dist/training/autograd.js.map +1 -0
  42. package/dist/training/trainer.d.ts +37 -0
  43. package/dist/training/trainer.d.ts.map +1 -0
  44. package/dist/training/trainer.js +183 -0
  45. package/dist/training/trainer.js.map +1 -0
  46. package/dist/utils/gpu_utils.d.ts +21 -0
  47. package/dist/utils/gpu_utils.d.ts.map +1 -0
  48. package/dist/utils/gpu_utils.js +111 -0
  49. package/dist/utils/gpu_utils.js.map +1 -0
  50. package/dist/utils/quantization.d.ts +26 -0
  51. package/dist/utils/quantization.d.ts.map +1 -0
  52. package/dist/utils/quantization.js +116 -0
  53. package/dist/utils/quantization.js.map +1 -0
  54. package/package.json +43 -18
  55. package/src/index.ts +59 -0
  56. package/src/kernels/{activations.js → activations.ts} +2 -2
  57. package/src/kernels/{linear_projection.js → linear_projection.ts} +2 -2
  58. package/src/kernels/{selective_scan.js → selective_scan.ts} +2 -2
  59. package/src/kernels/{weight_update.js → weight_update.ts} +2 -2
  60. package/src/model/{mamba_block.js → mamba_block.ts} +139 -175
  61. package/src/model/{mamba_model.js → mamba_model.ts} +168 -124
  62. package/src/tokenizer/bpe.ts +186 -0
  63. package/src/training/autograd.ts +135 -0
  64. package/src/training/trainer.ts +312 -0
  65. package/src/utils/gpu_utils.ts +147 -0
  66. package/src/utils/quantization.ts +154 -0
  67. package/src/index.js +0 -89
  68. package/src/tokenizer/bpe.js +0 -256
  69. package/src/training/autograd.js +0 -221
  70. package/src/training/trainer.js +0 -394
  71. package/src/utils/gpu_utils.js +0 -217
  72. package/src/utils/quantization.js +0 -215
  73. package/src/kernels/{conv1d.js → conv1d.ts} +0 -0
package/dist/kernels/conv1d.js
@@ -0,0 +1,152 @@
+ // 1D Causal Convolution WGSL Kernel
+ // Implements a depthwise 1D causal convolution over the sequence dimension.
+ // "Causal" means the output at position t only depends on positions <= t,
+ // which is enforced by left-padding with (kernel_size - 1) zeros.
+ //
+ // Forward: y[b, t, d] = sum_{k=0}^{K-1} weight[d, k] * x[b, t-k, d]
+ //          where x[b, t', d] = 0 for t' < 0 (causal padding)
+ export const CONV1D_FORWARD_WGSL = /* wgsl */ `
+
+ struct ConvParams {
+   seq_len     : u32, // L
+   d_channels  : u32, // D (number of depthwise channels)
+   kernel_size : u32, // K (typically 4)
+   batch       : u32, // B
+ };
+
+ @group(0) @binding(0) var<uniform> params : ConvParams;
+ // x (B, L, D) – input
+ @group(0) @binding(1) var<storage, read> x : array<f32>;
+ // weight (D, K) – depthwise conv weights
+ @group(0) @binding(2) var<storage, read> weight : array<f32>;
+ // bias (D,) – optional bias (zeros if unused)
+ @group(0) @binding(3) var<storage, read> bias : array<f32>;
+ // y (B, L, D) – output
+ @group(0) @binding(4) var<storage, read_write> y : array<f32>;
+
+ // Dispatch: (ceil(L/16), ceil(D/16), B)
+ @compute @workgroup_size(16, 16, 1)
+ fn conv1d_forward(
+   @builtin(global_invocation_id) gid : vec3<u32>,
+ ) {
+   let L = params.seq_len;
+   let D = params.d_channels;
+   let K = params.kernel_size;
+   let B = params.batch;
+
+   let t = gid.x; // time position
+   let d = gid.y; // channel
+   let b = gid.z; // batch
+
+   if (t >= L || d >= D || b >= B) { return; }
+
+   var acc: f32 = 0.0;
+
+   // Causal: convolve over k = 0..K-1, reading position (t - k)
+   for (var k: u32 = 0u; k < K; k = k + 1u) {
+     let w_idx = d * K + k;
+     let w_val = weight[w_idx];
+
+     // t - k: use causal zero-padding for t < k
+     if (t >= k) {
+       let src = b * L * D + (t - k) * D + d;
+       acc = acc + w_val * x[src];
+     }
+     // else: zero-padding contributes 0
+   }
+
+   acc = acc + bias[d];
+
+   let out = b * L * D + t * D + d;
+   y[out] = acc;
+ }
+ `;
+ // ---- Backward kernel for 1D convolution ----
+ export const CONV1D_BACKWARD_WGSL = /* wgsl */ `
+
+ struct ConvParams {
+   seq_len     : u32,
+   d_channels  : u32,
+   kernel_size : u32,
+   batch       : u32,
+ };
+
+ @group(0) @binding(0) var<uniform> params : ConvParams;
+ @group(0) @binding(1) var<storage, read> x : array<f32>;
+ @group(0) @binding(2) var<storage, read> weight : array<f32>;
+ @group(0) @binding(3) var<storage, read> dy : array<f32>;
+ @group(0) @binding(4) var<storage, read_write> dx : array<f32>;
+ @group(0) @binding(5) var<storage, read_write> dweight : array<f32>;
+ @group(0) @binding(6) var<storage, read_write> dbias : array<f32>;
+
+ // Dispatch: (ceil(L/16), ceil(D/16), B) – computes dx
+ @compute @workgroup_size(16, 16, 1)
+ fn conv1d_backward_dx(
+   @builtin(global_invocation_id) gid : vec3<u32>,
+ ) {
+   let L = params.seq_len;
+   let D = params.d_channels;
+   let K = params.kernel_size;
+   let B = params.batch;
+
+   let t = gid.x;
+   let d = gid.y;
+   let b = gid.z;
+
+   if (t >= L || d >= D || b >= B) { return; }
+
+   var grad: f32 = 0.0;
+
+   // dx[b, t, d] = sum_{k=0}^{K-1} dy[b, t+k, d] * weight[d, k]
+   for (var k: u32 = 0u; k < K; k = k + 1u) {
+     let tp = t + k;
+     if (tp < L) {
+       let dy_idx = b * L * D + tp * D + d;
+       let w_idx = d * K + k;
+       grad = grad + dy[dy_idx] * weight[w_idx];
+     }
+   }
+
+   let dx_idx = b * L * D + t * D + d;
+   dx[dx_idx] = grad;
+ }
+
+ // Dispatch: (K, D, 1) – accumulates dweight over (B, L)
+ @compute @workgroup_size(1, 1, 1)
+ fn conv1d_backward_dw(
+   @builtin(global_invocation_id) gid : vec3<u32>,
+ ) {
+   let L = params.seq_len;
+   let D = params.d_channels;
+   let K = params.kernel_size;
+   let B = params.batch;
+
+   let k = gid.x;
+   let d = gid.y;
+
+   if (k >= K || d >= D) { return; }
+
+   var grad_w: f32 = 0.0;
+   var grad_b: f32 = 0.0;
+
+   for (var b: u32 = 0u; b < B; b = b + 1u) {
+     for (var t: u32 = 0u; t < L; t = t + 1u) {
+       let dy_idx = b * L * D + t * D + d;
+       let dy_val = dy[dy_idx];
+       if (t >= k) {
+         let x_idx = b * L * D + (t - k) * D + d;
+         grad_w = grad_w + dy_val * x[x_idx];
+       }
+       if (k == 0u) {
+         grad_b = grad_b + dy_val;
+       }
+     }
+   }
+
+   dweight[d * K + k] = grad_w;
+   if (k == 0u) {
+     dbias[d] = grad_b;
+   }
+ }
+ `;
+ //# sourceMappingURL=conv1d.js.map
package/dist/kernels/conv1d.js.map
@@ -0,0 +1 @@
+ {"version":3,"file":"conv1d.js","sourceRoot":"","sources":["../../src/kernels/conv1d.ts"],"names":[],"mappings":"AAAA,oCAAoC;AACpC,4EAA4E;AAC5E,0EAA0E;AAC1E,kEAAkE;AAClE,EAAE;AACF,qEAAqE;AACrE,+DAA+D;AAE/D,MAAM,CAAC,MAAM,mBAAmB,GAAG,UAAU,CAAA;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;CAuD5C,CAAC;AAEF,+CAA+C;AAC/C,MAAM,CAAC,MAAM,oBAAoB,GAAG,UAAU,CAAA;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;CAsF7C,CAAC"}
package/dist/kernels/linear_projection.d.ts
@@ -0,0 +1,3 @@
+ export declare const LINEAR_FORWARD_WGSL: string;
+ export declare const LINEAR_BACKWARD_WGSL: string;
+ //# sourceMappingURL=linear_projection.d.ts.map
package/dist/kernels/linear_projection.d.ts.map
@@ -0,0 +1 @@
+ {"version":3,"file":"linear_projection.d.ts","sourceRoot":"","sources":["../../src/kernels/linear_projection.ts"],"names":[],"mappings":"AAUA,eAAO,MAAM,mBAAmB,EAAE,MAmEjC,CAAC;AAGF,eAAO,MAAM,oBAAoB,EAAE,MA2IlC,CAAC"}
package/dist/kernels/linear_projection.js
@@ -0,0 +1,219 @@
+ // Linear Projection WGSL Kernel
+ // General-purpose matrix multiplication: Y = X @ W^T + b
+ // Supports the up-projection and down-projection linear layers in the Mamba block.
+ //
+ // Shapes:
+ //   X : (batch * seq_len, in_features)  – input (rows)
+ //   W : (out_features, in_features)     – weight matrix (row-major)
+ //   b : (out_features,)                 – bias
+ //   Y : (batch * seq_len, out_features) – output
+ export const LINEAR_FORWARD_WGSL = /* wgsl */ `
+
+ struct LinearParams {
+   M : u32, // number of rows (batch * seq_len)
+   K : u32, // in_features
+   N : u32, // out_features
+ };
+
+ @group(0) @binding(0) var<uniform> params : LinearParams;
+ @group(0) @binding(1) var<storage, read> X : array<f32>;        // (M, K)
+ @group(0) @binding(2) var<storage, read> W : array<f32>;        // (N, K)
+ @group(0) @binding(3) var<storage, read> bias : array<f32>;     // (N,)
+ @group(0) @binding(4) var<storage, read_write> Y : array<f32>;  // (M, N)
+
+ // Tiled matmul using workgroup shared memory (16x16 tiles)
+ var<workgroup> tile_X : array<f32, 256>; // 16 * 16
+ var<workgroup> tile_W : array<f32, 256>;
+
+ @compute @workgroup_size(16, 16, 1)
+ fn linear_forward(
+   @builtin(global_invocation_id) gid : vec3<u32>,
+   @builtin(local_invocation_id) lid : vec3<u32>,
+   @builtin(workgroup_id) wid : vec3<u32>,
+ ) {
+   let M = params.M;
+   let K = params.K;
+   let N = params.N;
+
+   let row = gid.x; // output row (M dimension)
+   let col = gid.y; // output col (N dimension)
+
+   var acc: f32 = 0.0;
+   let TILE: u32 = 16u;
+   let num_tiles = (K + TILE - 1u) / TILE;
+
+   for (var tile_idx: u32 = 0u; tile_idx < num_tiles; tile_idx = tile_idx + 1u) {
+     // Load X tile: shape (TILE_M, TILE_K)
+     let x_col = tile_idx * TILE + lid.y;
+     let x_row = wid.x * TILE + lid.x;
+     if (x_row < M && x_col < K) {
+       tile_X[lid.x * TILE + lid.y] = X[x_row * K + x_col];
+     } else {
+       tile_X[lid.x * TILE + lid.y] = 0.0;
+     }
+
+     // Load W tile: shape (TILE_N, TILE_K) — W is (N, K)
+     let w_col = tile_idx * TILE + lid.x; // K dimension
+     let w_row = wid.y * TILE + lid.y;    // N dimension
+     if (w_row < N && w_col < K) {
+       tile_W[lid.y * TILE + lid.x] = W[w_row * K + w_col];
+     } else {
+       tile_W[lid.y * TILE + lid.x] = 0.0;
+     }
+
+     workgroupBarrier();
+
+     // Dot product within tile
+     for (var k: u32 = 0u; k < TILE; k = k + 1u) {
+       acc = acc + tile_X[lid.x * TILE + k] * tile_W[lid.y * TILE + k];
+     }
+     workgroupBarrier();
+   }
+
+   if (row < M && col < N) {
+     Y[row * N + col] = acc + bias[col];
+   }
+ }
+ `;
+ // ---- Backward pass for linear projection ----
+ export const LINEAR_BACKWARD_WGSL = /* wgsl */ `
+
+ struct LinearParams {
+   M : u32,
+   K : u32,
+   N : u32,
+ };
+
+ @group(0) @binding(0) var<uniform> params : LinearParams;
+ @group(0) @binding(1) var<storage, read> X : array<f32>;         // (M, K)
+ @group(0) @binding(2) var<storage, read> W : array<f32>;         // (N, K)
+ @group(0) @binding(3) var<storage, read> dY : array<f32>;        // (M, N)
+ @group(0) @binding(4) var<storage, read_write> dX : array<f32>;  // (M, K)
+ @group(0) @binding(5) var<storage, read_write> dW : array<f32>;  // (N, K)
+ @group(0) @binding(6) var<storage, read_write> db : array<f32>;  // (N,)
+
+ // Dispatch: (ceil(M/16), ceil(K/16), 1) – computes dX = dY @ W
+ var<workgroup> tile_dY : array<f32, 256>;
+ var<workgroup> tile_W : array<f32, 256>;
+
+ @compute @workgroup_size(16, 16, 1)
+ fn linear_backward_dX(
+   @builtin(global_invocation_id) gid : vec3<u32>,
+   @builtin(local_invocation_id) lid : vec3<u32>,
+   @builtin(workgroup_id) wid : vec3<u32>,
+ ) {
+   let M = params.M;
+   let K = params.K;
+   let N = params.N;
+
+   let row = gid.x; // M
+   let col = gid.y; // K
+
+   var acc: f32 = 0.0;
+   let TILE: u32 = 16u;
+   let num_tiles = (N + TILE - 1u) / TILE;
+
+   for (var tile_idx: u32 = 0u; tile_idx < num_tiles; tile_idx = tile_idx + 1u) {
+     // tile_dY: (M, TILE_N) slice
+     let dy_col = tile_idx * TILE + lid.y;
+     let dy_row = wid.x * TILE + lid.x;
+     if (dy_row < M && dy_col < N) {
+       tile_dY[lid.x * TILE + lid.y] = dY[dy_row * N + dy_col];
+     } else {
+       tile_dY[lid.x * TILE + lid.y] = 0.0;
+     }
+
+     // tile_W: (TILE_N, K) slice — W[n, k]
+     let w_row = tile_idx * TILE + lid.x; // N
+     let w_col = wid.y * TILE + lid.y;    // K
+     if (w_row < N && w_col < K) {
+       tile_W[lid.x * TILE + lid.y] = W[w_row * K + w_col];
+     } else {
+       tile_W[lid.x * TILE + lid.y] = 0.0;
+     }
+
+     workgroupBarrier();
+
+     for (var n: u32 = 0u; n < TILE; n = n + 1u) {
+       acc = acc + tile_dY[lid.x * TILE + n] * tile_W[n * TILE + lid.y];
+     }
+     workgroupBarrier();
+   }
+
+   if (row < M && col < K) {
+     dX[row * K + col] = acc;
+   }
+ }
+
+ // Dispatch: (ceil(N/16), ceil(K/16), 1) – computes dW = dY^T @ X
+ var<workgroup> tile_dY2 : array<f32, 256>;
+ var<workgroup> tile_X2 : array<f32, 256>;
+
+ @compute @workgroup_size(16, 16, 1)
+ fn linear_backward_dW(
+   @builtin(global_invocation_id) gid : vec3<u32>,
+   @builtin(local_invocation_id) lid : vec3<u32>,
+   @builtin(workgroup_id) wid : vec3<u32>,
+ ) {
+   let M = params.M;
+   let K = params.K;
+   let N = params.N;
+
+   let row = gid.x; // N
+   let col = gid.y; // K
+
+   var acc: f32 = 0.0;
+   let TILE: u32 = 16u;
+   let num_tiles = (M + TILE - 1u) / TILE;
+
+   for (var tile_idx: u32 = 0u; tile_idx < num_tiles; tile_idx = tile_idx + 1u) {
+     // dY^T tile: [N, M] accessed as dY[m, n]
+     let m_idx = tile_idx * TILE + lid.y;
+     let n_idx = wid.x * TILE + lid.x;
+     if (n_idx < N && m_idx < M) {
+       tile_dY2[lid.x * TILE + lid.y] = dY[m_idx * N + n_idx];
+     } else {
+       tile_dY2[lid.x * TILE + lid.y] = 0.0;
+     }
+
+     // X tile: [M, K]
+     let xm = tile_idx * TILE + lid.x;
+     let xk = wid.y * TILE + lid.y;
+     if (xm < M && xk < K) {
+       tile_X2[lid.x * TILE + lid.y] = X[xm * K + xk];
+     } else {
+       tile_X2[lid.x * TILE + lid.y] = 0.0;
+     }
+
+     workgroupBarrier();
+
+     for (var m: u32 = 0u; m < TILE; m = m + 1u) {
+       acc = acc + tile_dY2[lid.x * TILE + m] * tile_X2[m * TILE + lid.y];
+     }
+     workgroupBarrier();
+   }
+
+   if (row < N && col < K) {
+     dW[row * K + col] = acc;
+   }
+ }
+
+ // Dispatch: (ceil(N/64), 1, 1) – accumulates db = sum_M dY
+ @compute @workgroup_size(64, 1, 1)
+ fn linear_backward_db(
+   @builtin(global_invocation_id) gid : vec3<u32>,
+ ) {
+   let M = params.M;
+   let N = params.N;
+
+   let n = gid.x;
+   if (n >= N) { return; }
+
+   var acc: f32 = 0.0;
+   for (var m: u32 = 0u; m < M; m = m + 1u) {
+     acc = acc + dY[m * N + n];
+   }
+   db[n] = acc;
+ }
+ `;
+ //# sourceMappingURL=linear_projection.js.map
package/dist/kernels/linear_projection.js.map
@@ -0,0 +1 @@
+ {"version":3,"file":"linear_projection.js","sourceRoot":"","sources":["../../src/kernels/linear_projection.ts"],"names":[],"mappings":"AAAA,gCAAgC;AAChC,yDAAyD;AACzD,mFAAmF;AACnF,EAAE;AACF,UAAU;AACV,wDAAwD;AACxD,qEAAqE;AACrE,gDAAgD;AAChD,kDAAkD;AAElD,MAAM,CAAC,MAAM,mBAAmB,GAAW,UAAU,CAAA;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;CAmEpD,CAAC;AAEF,gDAAgD;AAChD,MAAM,CAAC,MAAM,oBAAoB,GAAW,UAAU,CAAA;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;CA2IrD,CAAC"}
package/dist/kernels/selective_scan.d.ts
@@ -0,0 +1,3 @@
+ export declare const SELECTIVE_SCAN_FORWARD_WGSL: string;
+ export declare const SELECTIVE_SCAN_BACKWARD_WGSL: string;
+ //# sourceMappingURL=selective_scan.d.ts.map
package/dist/kernels/selective_scan.d.ts.map
@@ -0,0 +1 @@
+ {"version":3,"file":"selective_scan.d.ts","sourceRoot":"","sources":["../../src/kernels/selective_scan.ts"],"names":[],"mappings":"AAUA,eAAO,MAAM,2BAA2B,EAAE,MAsNzC,CAAC;AAKF,eAAO,MAAM,4BAA4B,EAAE,MAwH1C,CAAC"}