mambacode.js 1.0.0 → 1.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. package/README.md +198 -76
  2. package/dist/index.d.ts +18 -0
  3. package/dist/index.d.ts.map +1 -0
  4. package/dist/index.js +18 -0
  5. package/dist/index.js.map +1 -0
  6. package/dist/kernels/activations.d.ts +3 -0
  7. package/dist/kernels/activations.d.ts.map +1 -0
  8. package/dist/kernels/activations.js +87 -0
  9. package/dist/kernels/activations.js.map +1 -0
  10. package/dist/kernels/conv1d.d.ts +3 -0
  11. package/dist/kernels/conv1d.d.ts.map +1 -0
  12. package/dist/kernels/conv1d.js +152 -0
  13. package/dist/kernels/conv1d.js.map +1 -0
  14. package/dist/kernels/linear_projection.d.ts +3 -0
  15. package/dist/kernels/linear_projection.d.ts.map +1 -0
  16. package/dist/kernels/linear_projection.js +219 -0
  17. package/dist/kernels/linear_projection.js.map +1 -0
  18. package/dist/kernels/selective_scan.d.ts +3 -0
  19. package/dist/kernels/selective_scan.d.ts.map +1 -0
  20. package/dist/kernels/selective_scan.js +348 -0
  21. package/dist/kernels/selective_scan.js.map +1 -0
  22. package/dist/kernels/weight_update.d.ts +3 -0
  23. package/dist/kernels/weight_update.d.ts.map +1 -0
  24. package/dist/kernels/weight_update.js +119 -0
  25. package/dist/kernels/weight_update.js.map +1 -0
  26. package/dist/model/mamba_block.d.ts +64 -0
  27. package/dist/model/mamba_block.d.ts.map +1 -0
  28. package/dist/model/mamba_block.js +309 -0
  29. package/dist/model/mamba_block.js.map +1 -0
  30. package/dist/model/mamba_model.d.ts +66 -0
  31. package/dist/model/mamba_model.d.ts.map +1 -0
  32. package/dist/model/mamba_model.js +289 -0
  33. package/dist/model/mamba_model.js.map +1 -0
  34. package/dist/tokenizer/bpe.d.ts +29 -0
  35. package/dist/tokenizer/bpe.d.ts.map +1 -0
  36. package/dist/tokenizer/bpe.js +164 -0
  37. package/dist/tokenizer/bpe.js.map +1 -0
  38. package/dist/training/autograd.d.ts +27 -0
  39. package/dist/training/autograd.d.ts.map +1 -0
  40. package/dist/training/autograd.js +120 -0
  41. package/dist/training/autograd.js.map +1 -0
  42. package/dist/training/trainer.d.ts +37 -0
  43. package/dist/training/trainer.d.ts.map +1 -0
  44. package/dist/training/trainer.js +183 -0
  45. package/dist/training/trainer.js.map +1 -0
  46. package/dist/utils/gpu_utils.d.ts +21 -0
  47. package/dist/utils/gpu_utils.d.ts.map +1 -0
  48. package/dist/utils/gpu_utils.js +111 -0
  49. package/dist/utils/gpu_utils.js.map +1 -0
  50. package/dist/utils/quantization.d.ts +26 -0
  51. package/dist/utils/quantization.d.ts.map +1 -0
  52. package/dist/utils/quantization.js +116 -0
  53. package/dist/utils/quantization.js.map +1 -0
  54. package/package.json +43 -18
  55. package/src/index.ts +59 -0
  56. package/src/kernels/{activations.js → activations.ts} +2 -2
  57. package/src/kernels/{linear_projection.js → linear_projection.ts} +2 -2
  58. package/src/kernels/{selective_scan.js → selective_scan.ts} +2 -2
  59. package/src/kernels/{weight_update.js → weight_update.ts} +2 -2
  60. package/src/model/{mamba_block.js → mamba_block.ts} +139 -175
  61. package/src/model/{mamba_model.js → mamba_model.ts} +168 -124
  62. package/src/tokenizer/bpe.ts +186 -0
  63. package/src/training/autograd.ts +135 -0
  64. package/src/training/trainer.ts +312 -0
  65. package/src/utils/gpu_utils.ts +147 -0
  66. package/src/utils/quantization.ts +154 -0
  67. package/src/index.js +0 -89
  68. package/src/tokenizer/bpe.js +0 -256
  69. package/src/training/autograd.js +0 -221
  70. package/src/training/trainer.js +0 -394
  71. package/src/utils/gpu_utils.js +0 -217
  72. package/src/utils/quantization.js +0 -215
  73. package/src/kernels/{conv1d.js → conv1d.ts} +0 -0
@@ -0,0 +1,116 @@
1
/**
 * quantization.ts – FP16 and Int8 quantization utilities.
 */
/**
 * Encode a JS number as a 16-bit IEEE-754 half-precision bit pattern.
 *
 * Handles all fp16 ranges: NaN/Infinity, overflow to signed infinity,
 * subnormals (values below 2^-14), and flush-to-zero for magnitudes
 * smaller than the smallest subnormal. Mantissa bits are truncated,
 * not rounded.
 *
 * @param {number} val - The value to encode.
 * @returns {number} Unsigned 16-bit half-float bit pattern.
 */
export function floatToFp16(val) {
  // Reinterpret the float32 bit pattern through a shared buffer.
  const scratch = new ArrayBuffer(4);
  const asFloat = new Float32Array(scratch);
  const asBits = new Uint32Array(scratch);
  asFloat[0] = val;
  const bits = asBits[0];

  const sign = (bits >>> 31) & 0x1;
  const exponent = (bits >>> 23) & 0xFF;
  const mantissa = bits & 0x7FFFFF;

  // fp32 exponent all ones: NaN (nonzero mantissa) or Infinity.
  // 0x200 keeps a quiet-NaN payload bit so NaN stays NaN in fp16.
  if (exponent === 255) {
    return (sign << 15) | 0x7C00 | (mantissa ? 0x200 : 0);
  }

  // Rebias the exponent from fp32 (bias 127) to fp16 (bias 15).
  const expAdj = exponent - 127 + 15;

  // Too large for fp16 → signed infinity.
  if (expAdj >= 31) {
    return (sign << 15) | 0x7C00;
  }

  // fp16 subnormal range.
  if (expAdj <= 0) {
    // Below the smallest subnormal (2^-24): flush to signed zero.
    if (expAdj < -10) {
      return sign << 15;
    }
    // Shift the mantissa (implicit leading 1 restored) into
    // subnormal position.
    return (sign << 15) | ((mantissa | 0x800000) >> (14 - expAdj));
  }

  // Normal case: truncate the mantissa from 23 bits down to 10.
  return (sign << 15) | (expAdj << 10) | (mantissa >> 13);
}
29
/**
 * Decode a 16-bit IEEE-754 half-precision bit pattern into a JS number.
 *
 * Fixes vs. the previous version:
 *  - subnormals were missing the 2^-14 scale factor, so the smallest
 *    subnormal 0x0001 decoded as ~0.000977 instead of 2^-24, breaking
 *    floatToFp16 → fp16ToFloat round-trips for tiny magnitudes;
 *  - negative NaN encodings (sign bit set, exponent all ones, nonzero
 *    mantissa, e.g. 0xFE00) decoded as -Infinity instead of NaN.
 *
 * @param {number} val - Unsigned 16-bit half-float bit pattern.
 * @returns {number} The decoded floating-point value.
 */
export function fp16ToFloat(val) {
  const sign = (val >>> 15) & 0x1;
  const exponent = (val >>> 10) & 0x1F;
  const mantissa = val & 0x3FF;

  if (exponent === 0) {
    // Zero or subnormal: value = mantissa * 2^-24 (no implicit leading 1).
    const f = (mantissa / 1024.0) * Math.pow(2, -14);
    return sign ? -f : f;
  }

  if (exponent === 31) {
    // Exponent all ones: NaN when mantissa is nonzero (either sign),
    // otherwise signed infinity.
    if (mantissa) {
      return NaN;
    }
    return sign ? -Infinity : Infinity;
  }

  // Normal value: implicit leading 1, exponent rebased from bias 15.
  const f = (1 + mantissa / 1024.0) * Math.pow(2, exponent - 15);
  return sign ? -f : f;
}
44
/**
 * Quantize a Float32Array to packed fp16 bit patterns.
 *
 * @param {Float32Array} f32 - Source values.
 * @returns {Uint16Array} One half-precision encoding per input element.
 */
export function quantizeFp16(f32) {
  const encoded = new Uint16Array(f32.length);
  f32.forEach((value, idx) => {
    encoded[idx] = floatToFp16(value);
  });
  return encoded;
}
51
/**
 * Expand packed fp16 bit patterns back to a Float32Array.
 *
 * @param {Uint16Array} fp16 - Half-precision encodings.
 * @returns {Float32Array} Decoded values, one per input element.
 */
export function dequantizeFp16(fp16) {
  const restored = new Float32Array(fp16.length);
  fp16.forEach((pattern, idx) => {
    restored[idx] = fp16ToFloat(pattern);
  });
  return restored;
}
58
/**
 * Symmetric per-tensor int8 quantization.
 *
 * The largest absolute value maps to ±127; every element is divided by
 * that scale, rounded, and clamped to the int8 range.
 *
 * @param {Float32Array} f32 - Source values.
 * @returns {{data: Int8Array, scale: number}} Quantized data and the
 *   scale needed to dequantize (`value ≈ data[i] * scale`).
 */
export function quantizeInt8(f32) {
  let maxAbs = 0;
  for (const value of f32) {
    const mag = Math.abs(value);
    if (mag > maxAbs) {
      maxAbs = mag;
    }
  }

  // `|| 1.0` guards an all-zero tensor, where the scale would be 0.
  const scale = maxAbs / 127.0 || 1.0;

  const data = new Int8Array(f32.length);
  for (let i = 0; i < f32.length; i++) {
    const q = Math.round(f32[i] / scale);
    data[i] = Math.min(127, Math.max(-128, q));
  }

  return { data, scale };
}
72
/**
 * Reconstruct approximate float values from symmetric int8 quantization.
 *
 * @param {Int8Array} int8 - Quantized values.
 * @param {number} scale - Per-tensor scale used at quantization time.
 * @returns {Float32Array} Dequantized values (`int8[i] * scale`).
 */
export function dequantizeInt8(int8, scale) {
  const restored = new Float32Array(int8.length);
  int8.forEach((q, idx) => {
    restored[idx] = q * scale;
  });
  return restored;
}
79
/**
 * Symmetric int8 quantization with one scale per channel.
 *
 * The input is treated as `numChannels` contiguous slices of equal
 * length; each slice gets its own scale so channels with small dynamic
 * range keep more precision than a single per-tensor scale would allow.
 *
 * @param {Float32Array} f32 - Source values, length divisible by numChannels.
 * @param {number} numChannels - Number of contiguous channels.
 * @returns {{data: Int8Array, scales: Float32Array}} Quantized data plus
 *   one dequantization scale per channel.
 */
export function quantizeInt8PerChannel(f32, numChannels) {
  const channelSize = f32.length / numChannels;
  const scales = new Float32Array(numChannels);
  const data = new Int8Array(f32.length);

  for (let c = 0; c < numChannels; c++) {
    const base = c * channelSize;

    let peak = 0;
    for (let j = 0; j < channelSize; j++) {
      peak = Math.max(peak, Math.abs(f32[base + j]));
    }

    // `|| 1.0` guards an all-zero channel (scale would be 0).
    scales[c] = peak / 127.0 || 1.0;

    // Divide by the stored (fp32-rounded) scale so quantized values
    // match exactly what dequantization with `scales[c]` expects.
    const s = scales[c];
    for (let j = 0; j < channelSize; j++) {
      const q = Math.round(f32[base + j] / s);
      data[base + j] = Math.min(127, Math.max(-128, q));
    }
  }

  return { data, scales };
}
98
/**
 * Reconstruct approximate float values from per-channel int8 quantization.
 *
 * @param {Int8Array} int8 - Quantized values, length divisible by numChannels.
 * @param {Float32Array} scales - One scale per channel.
 * @param {number} numChannels - Number of contiguous channels.
 * @returns {Float32Array} Dequantized values.
 */
export function dequantizeInt8PerChannel(int8, scales, numChannels) {
  const channelSize = int8.length / numChannels;
  const restored = new Float32Array(int8.length);
  for (let i = 0; i < int8.length; i++) {
    // Each contiguous slice of `channelSize` elements shares one scale.
    restored[i] = int8[i] * scales[Math.floor(i / channelSize)];
  }
  return restored;
}
109
/**
 * Estimate the storage cost of a tensor at each supported precision.
 *
 * @param {number} numElements - Element count of the tensor.
 * @returns {{fp32: number, fp16: number, int8: number}} Byte totals for
 *   4-byte fp32, 2-byte fp16, and 1-byte int8 storage.
 */
export function estimateMemory(numElements) {
  const bytesPerElement = { fp32: 4, fp16: 2, int8: 1 };
  const report = {};
  for (const [format, width] of Object.entries(bytesPerElement)) {
    report[format] = numElements * width;
  }
  return report;
}
116
+ //# sourceMappingURL=quantization.js.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"quantization.js","sourceRoot":"","sources":["../../src/utils/quantization.ts"],"names":[],"mappings":"AAAA;;GAEG;AAkBH,MAAM,UAAU,WAAW,CAAC,GAAW;IACnC,MAAM,GAAG,GAAG,IAAI,WAAW,CAAC,CAAC,CAAC,CAAC;IAC/B,MAAM,GAAG,GAAG,IAAI,YAAY,CAAC,GAAG,CAAC,CAAC;IAClC,MAAM,GAAG,GAAG,IAAI,WAAW,CAAC,GAAG,CAAC,CAAC;IACjC,GAAG,CAAC,CAAC,CAAC,GAAG,GAAG,CAAC;IACb,MAAM,IAAI,GAAG,GAAG,CAAC,CAAC,CAAE,CAAC;IAErB,MAAM,IAAI,GAAO,CAAC,IAAI,KAAK,EAAE,CAAC,GAAG,GAAG,CAAC;IACrC,MAAM,QAAQ,GAAG,CAAC,IAAI,KAAK,EAAE,CAAC,GAAG,IAAI,CAAC;IACtC,MAAM,QAAQ,GAAI,IAAI,GAAW,QAAQ,CAAC;IAE1C,IAAI,QAAQ,KAAK,GAAG,EAAE,CAAC;QACnB,OAAO,CAAC,IAAI,IAAI,EAAE,CAAC,GAAG,MAAM,GAAG,CAAC,QAAQ,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;IAC1D,CAAC;IAED,MAAM,MAAM,GAAG,QAAQ,GAAG,GAAG,GAAG,EAAE,CAAC;IAEnC,IAAI,MAAM,IAAI,EAAE,EAAE,CAAC;QACf,OAAO,CAAC,IAAI,IAAI,EAAE,CAAC,GAAG,MAAM,CAAC;IACjC,CAAC;IAED,IAAI,MAAM,IAAI,CAAC,EAAE,CAAC;QACd,IAAI,MAAM,GAAG,CAAC,EAAE,EAAE,CAAC;YAAC,OAAO,IAAI,IAAI,EAAE,CAAC;QAAC,CAAC;QACxC,MAAM,KAAK,GAAG,EAAE,GAAG,MAAM,CAAC;QAC1B,OAAO,CAAC,IAAI,IAAI,EAAE,CAAC,GAAG,CAAC,CAAC,QAAQ,GAAG,QAAQ,CAAC,IAAI,KAAK,CAAC,CAAC;IAC3D,CAAC;IAED,OAAO,CAAC,IAAI,IAAI,EAAE,CAAC,GAAG,CAAC,MAAM,IAAI,EAAE,CAAC,GAAG,CAAC,QAAQ,IAAI,EAAE,CAAC,CAAC;AAC5D,CAAC;AAED,MAAM,UAAU,WAAW,CAAC,GAAW;IACnC,MAAM,IAAI,GAAO,CAAC,GAAG,KAAK,EAAE,CAAC,GAAG,GAAG,CAAC;IACpC,MAAM,QAAQ,GAAG,CAAC,GAAG,KAAK,EAAE,CAAC,GAAG,IAAI,CAAC;IACrC,MAAM,QAAQ,GAAI,GAAG,GAAW,KAAK,CAAC;IAEtC,IAAI,QAAQ,KAAK,CAAC,EAAE,CAAC;QACjB,MAAM,CAAC,GAAG,QAAQ,GAAG,MAAM,CAAC;QAC5B,OAAO,IAAI,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;IACzB,CAAC;IAED,IAAI,QAAQ,KAAK,EAAE,EAAE,CAAC;QAClB,OAAO,IAAI,CAAC,CAAC,CAAC,CAAC,QAAQ,CAAC,CAAC,CAAC,CAAC,QAAQ,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,QAAQ,CAAC,CAAC;IAC1D,CAAC;IAED,MAAM,WAAW,GAAG,QAAQ,GAAG,EAAE,CAAC;IAClC,MAAM,CAAC,GAAG,CAAC,CAAC,GAAG,QAAQ,GAAG,MAAM,CAAC,GAAG,IAAI,CAAC,GAAG,CAAC,CAAC,EAAE,WAAW,CAAC,CAAC;IAC7D,OAAO,IAAI,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;AACzB,CAAC;AAED,MAAM,UAAU,YAAY,CAAC,
GAAiB;IAC1C,MAAM,GAAG,GAAG,IAAI,WAAW,CAAC,GAAG,CAAC,MAAM,CAAC,CAAC;IACxC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,GAAG,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;QAClC,GAAG,CAAC,CAAC,CAAC,GAAG,WAAW,CAAC,GAAG,CAAC,CAAC,CAAE,CAAC,CAAC;IAClC,CAAC;IACD,OAAO,GAAG,CAAC;AACf,CAAC;AAED,MAAM,UAAU,cAAc,CAAC,IAAiB;IAC5C,MAAM,GAAG,GAAG,IAAI,YAAY,CAAC,IAAI,CAAC,MAAM,CAAC,CAAC;IAC1C,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;QACnC,GAAG,CAAC,CAAC,CAAC,GAAG,WAAW,CAAC,IAAI,CAAC,CAAC,CAAE,CAAC,CAAC;IACnC,CAAC;IACD,OAAO,GAAG,CAAC;AACf,CAAC;AAED,MAAM,UAAU,YAAY,CAAC,GAAiB;IAC1C,IAAI,MAAM,GAAG,CAAC,CAAC;IACf,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,GAAG,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;QAClC,MAAM,CAAC,GAAG,IAAI,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,CAAE,CAAC,CAAC;QAC5B,IAAI,CAAC,GAAG,MAAM;YAAE,MAAM,GAAG,CAAC,CAAC;IAC/B,CAAC;IAED,MAAM,KAAK,GAAG,MAAM,GAAG,KAAK,IAAI,GAAG,CAAC;IACpC,MAAM,IAAI,GAAI,IAAI,SAAS,CAAC,GAAG,CAAC,MAAM,CAAC,CAAC;IAExC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,GAAG,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;QAClC,IAAI,CAAC,CAAC,CAAC,GAAG,IAAI,CAAC,GAAG,CAAC,CAAC,GAAG,EAAE,IAAI,CAAC,GAAG,CAAC,GAAG,EAAE,IAAI,CAAC,KAAK,CAAC,GAAG,CAAC,CAAC,CAAE,GAAG,KAAK,CAAC,CAAC,CAAC,CAAC;IACzE,CAAC;IAED,OAAO,EAAE,IAAI,EAAE,KAAK,EAAE,CAAC;AAC3B,CAAC;AAED,MAAM,UAAU,cAAc,CAAC,IAAe,EAAE,KAAa;IACzD,MAAM,GAAG,GAAG,IAAI,YAAY,CAAC,IAAI,CAAC,MAAM,CAAC,CAAC;IAC1C,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;QACnC,GAAG,CAAC,CAAC,CAAC,GAAG,IAAI,CAAC,CAAC,CAAE,GAAG,KAAK,CAAC;IAC9B,CAAC;IACD,OAAO,GAAG,CAAC;AACf,CAAC;AAED,MAAM,UAAU,sBAAsB,CAAC,GAAiB,EAAE,WAAmB;IACzE,MAAM,WAAW,GAAG,GAAG,CAAC,MAAM,GAAG,WAAW,CAAC;IAC7C,MAAM,MAAM,GAAG,IAAI,YAAY,CAAC,WAAW,CAAC,CAAC;IAC7C,MAAM,IAAI,GAAK,IAAI,SAAS,CAAC,GAAG,CAAC,MAAM,CAAC,CAAC;IAEzC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,WAAW,EAAE,CAAC,EAAE,EAAE,CAAC;QACnC,IAAI,MAAM,GAAG,CAAC,CAAC;QACf,MAAM,IAAI,GAAG,CAAC,GAAG,WAAW,CAAC;QAC7B,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,WAAW,EAAE,CAAC,EAAE,EAAE,CAAC;YACnC,MAAM,CAAC,
GAAG,IAAI,CAAC,GAAG,CAAC,GAAG,CAAC,IAAI,GAAG,CAAC,CAAE,CAAC,CAAC;YACnC,IAAI,CAAC,GAAG,MAAM;gBAAE,MAAM,GAAG,CAAC,CAAC;QAC/B,CAAC;QACD,MAAM,CAAC,CAAC,CAAC,GAAG,MAAM,GAAG,KAAK,IAAI,GAAG,CAAC;QAClC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,WAAW,EAAE,CAAC,EAAE,EAAE,CAAC;YACnC,IAAI,CAAC,IAAI,GAAG,CAAC,CAAC,GAAG,IAAI,CAAC,GAAG,CAAC,CAAC,GAAG,EAAE,IAAI,CAAC,GAAG,CAAC,GAAG,EACxC,IAAI,CAAC,KAAK,CAAC,GAAG,CAAC,IAAI,GAAG,CAAC,CAAE,GAAG,MAAM,CAAC,CAAC,CAAE,CAAC,CAC1C,CAAC,CAAC;QACP,CAAC;IACL,CAAC;IAED,OAAO,EAAE,IAAI,EAAE,MAAM,EAAE,CAAC;AAC5B,CAAC;AAED,MAAM,UAAU,wBAAwB,CAAC,IAAe,EAAE,MAAoB,EAAE,WAAmB;IAC/F,MAAM,WAAW,GAAG,IAAI,CAAC,MAAM,GAAG,WAAW,CAAC;IAC9C,MAAM,GAAG,GAAG,IAAI,YAAY,CAAC,IAAI,CAAC,MAAM,CAAC,CAAC;IAE1C,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,WAAW,EAAE,CAAC,EAAE,EAAE,CAAC;QACnC,MAAM,IAAI,GAAG,CAAC,GAAG,WAAW,CAAC;QAC7B,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,WAAW,EAAE,CAAC,EAAE,EAAE,CAAC;YACnC,GAAG,CAAC,IAAI,GAAG,CAAC,CAAC,GAAG,IAAI,CAAC,IAAI,GAAG,CAAC,CAAE,GAAG,MAAM,CAAC,CAAC,CAAE,CAAC;QACjD,CAAC;IACL,CAAC;IAED,OAAO,GAAG,CAAC;AACf,CAAC;AAED,MAAM,UAAU,cAAc,CAAC,WAAmB;IAC9C,OAAO;QACH,IAAI,EAAE,WAAW,GAAG,CAAC;QACrB,IAAI,EAAE,WAAW,GAAG,CAAC;QACrB,IAAI,EAAE,WAAW,GAAG,CAAC;KACxB,CAAC;AACN,CAAC"}
package/package.json CHANGED
@@ -1,29 +1,30 @@
1
1
  {
2
2
  "name": "mambacode.js",
3
- "version": "1.0.0",
4
- "description": "High-performance JavaScript/WGSL Mamba SSM library for browser-based code model training and inference",
5
- "main": "src/index.js",
3
+ "version": "1.0.1",
4
+ "description": "High-performance TypeScript/WGSL Mamba SSM library for browser-based code model training and inference",
5
+ "main": "dist/index.js",
6
+ "module": "dist/index.js",
7
+ "types": "dist/index.d.ts",
8
+ "exports": {
9
+ ".": {
10
+ "types": "./dist/index.d.ts",
11
+ "import": "./dist/index.js"
12
+ }
13
+ },
6
14
  "type": "module",
7
15
  "files": [
16
+ "dist",
8
17
  "src",
9
18
  "README.md",
10
19
  "LICENSE"
11
20
  ],
12
21
  "scripts": {
22
+ "build": "tsc",
13
23
  "test": "node --experimental-vm-modules node_modules/.bin/jest",
14
- "lint": "eslint src/ tests/"
15
- },
16
- "keywords": [
17
- "mamba",
18
- "ssm",
19
- "state-space-model",
20
- "webgpu",
21
- "wgsl",
22
- "machine-learning",
23
- "code-model",
24
- "bpe",
25
- "transformer-alternative"
26
- ],
24
+ "lint": "eslint src/ tests/",
25
+ "prepublishOnly": "npm run build"
26
+ },
27
+ "keywords": ["mamba", "ssm", "state-space-model", "webgpu", "wgsl", "machine-learning", "code-model", "bpe", "transformer-alternative", "typescript"],
27
28
  "author": {
28
29
  "name": "Sean Hogg",
29
30
  "email": "seanhogg@gmail.com",
@@ -45,10 +46,34 @@
45
46
  "access": "public"
46
47
  },
47
48
  "devDependencies": {
49
+ "@types/jest": "^29.5.0",
50
+ "@webgpu/types": "^0.1.49",
51
+ "eslint": "^8.57.0",
48
52
  "jest": "^29.7.0",
49
- "eslint": "^8.57.0"
53
+ "ts-jest": "^29.2.0",
54
+ "typescript": "^5.0.0"
50
55
  },
51
56
  "jest": {
52
- "transform": {}
57
+ "preset": "ts-jest/presets/default-esm",
58
+ "extensionsToTreatAsEsm": [".ts"],
59
+ "transform": {
60
+ "^.+\\.tsx?$": [
61
+ "ts-jest",
62
+ {
63
+ "useESM": true,
64
+ "tsconfig": {
65
+ "module": "ES2022",
66
+ "moduleResolution": "bundler",
67
+ "strict": true,
68
+ "esModuleInterop": true,
69
+ "skipLibCheck": true,
70
+ "types": ["@webgpu/types", "jest"]
71
+ }
72
+ }
73
+ ]
74
+ },
75
+ "moduleNameMapper": {
76
+ "^(\\.{1,2}/.*)\\.js$": "$1"
77
+ }
53
78
  }
54
79
  }
package/src/index.ts ADDED
@@ -0,0 +1,59 @@
1
/**
 * MambaCode.js – Entry Point
 *
 * Barrel module: re-exports the public API from the model, training,
 * tokenizer, utility, and WGSL-kernel modules.
 */

// Core Mamba SSM model and its building block.
export { MambaModel } from './model/mamba_model';
export { MambaBlock } from './model/mamba_block';

// Training loop plus the lightweight autograd tape and loss helpers.
export { MambaTrainer } from './training/trainer';
export {
  Tensor,
  backward,
  enableGrad,
  noGrad,
  clearTape,
  recordOperation,
  crossEntropyLoss,
  crossEntropyGrad,
} from './training/autograd';

// Byte-pair-encoding tokenizer.
export { BPETokenizer } from './tokenizer/bpe';

// WebGPU device/buffer/pipeline helpers.
export {
  initWebGPU,
  createStorageBuffer,
  createEmptyStorageBuffer,
  createUniformBuffer,
  createComputePipeline,
  createBindGroup,
  dispatchKernel,
  readBuffer,
  uploadBuffer,
  cdiv,
} from './utils/gpu_utils';

// FP16 and int8 (per-tensor and per-channel) quantization utilities.
export {
  quantizeFp16,
  dequantizeFp16,
  floatToFp16,
  fp16ToFloat,
  quantizeInt8,
  dequantizeInt8,
  quantizeInt8PerChannel,
  dequantizeInt8PerChannel,
  estimateMemory,
} from './utils/quantization';

// Raw WGSL kernel sources, exported for custom pipeline assembly.
export { SELECTIVE_SCAN_FORWARD_WGSL, SELECTIVE_SCAN_BACKWARD_WGSL }
  from './kernels/selective_scan';
export { CONV1D_FORWARD_WGSL, CONV1D_BACKWARD_WGSL }
  from './kernels/conv1d';
export { LINEAR_FORWARD_WGSL, LINEAR_BACKWARD_WGSL }
  from './kernels/linear_projection';
export { WEIGHT_UPDATE_WGSL, GRAD_CLIP_WGSL }
  from './kernels/weight_update';
export { ACTIVATIONS_WGSL, ACTIVATIONS_BACKWARD_WGSL }
  from './kernels/activations';

// Package metadata; VERSION should track package.json's "version".
export const VERSION = '1.0.1';
export const DESCRIPTION = 'MambaCode.js: WebGPU-accelerated Mamba SSM for browser code models';
@@ -1,7 +1,7 @@
1
1
  // Activation function WGSL kernels: SiLU (Swish) and its backward pass.
2
2
  // Used in the gating mechanism of the Mamba Mixer Block.
3
3
 
4
- export const ACTIVATIONS_WGSL = /* wgsl */`
4
+ export const ACTIVATIONS_WGSL: string = /* wgsl */`
5
5
 
6
6
  struct ActParams {
7
7
  num_elements : u32,
@@ -61,7 +61,7 @@ fn rmsnorm_forward(
61
61
  `;
62
62
 
63
63
  // ---- Backward for SiLU ----
64
- export const ACTIVATIONS_BACKWARD_WGSL = /* wgsl */`
64
+ export const ACTIVATIONS_BACKWARD_WGSL: string = /* wgsl */`
65
65
 
66
66
  struct ActParams {
67
67
  num_elements : u32,
@@ -8,7 +8,7 @@
8
8
  // b : (out_features,) – bias
9
9
  // Y : (batch * seq_len, out_features) – output
10
10
 
11
- export const LINEAR_FORWARD_WGSL = /* wgsl */`
11
+ export const LINEAR_FORWARD_WGSL: string = /* wgsl */`
12
12
 
13
13
  struct LinearParams {
14
14
  M : u32, // number of rows (batch * seq_len)
@@ -78,7 +78,7 @@ fn linear_forward(
78
78
  `;
79
79
 
80
80
  // ---- Backward pass for linear projection ----
81
- export const LINEAR_BACKWARD_WGSL = /* wgsl */`
81
+ export const LINEAR_BACKWARD_WGSL: string = /* wgsl */`
82
82
 
83
83
  struct LinearParams {
84
84
  M : u32,
@@ -8,7 +8,7 @@
8
8
  //
9
9
  // where A_t, B_t, C_t are input-dependent (selective) gate matrices.
10
10
 
11
- export const SELECTIVE_SCAN_FORWARD_WGSL = /* wgsl */`
11
+ export const SELECTIVE_SCAN_FORWARD_WGSL: string = /* wgsl */`
12
12
 
13
13
  // ---- Binding layout ----
14
14
  // group 0: sequence data
@@ -227,7 +227,7 @@ fn forward_reduce(
227
227
  // ---- Backward scan kernel (for autograd) ----
228
228
  // Computes gradients w.r.t. Δ, A, B, C using the cached hidden states.
229
229
 
230
- export const SELECTIVE_SCAN_BACKWARD_WGSL = /* wgsl */`
230
+ export const SELECTIVE_SCAN_BACKWARD_WGSL: string = /* wgsl */`
231
231
 
232
232
  struct ScanParams {
233
233
  seq_len : u32,
@@ -8,7 +8,7 @@
8
8
  // v_hat = v_t / (1 - beta2^t)
9
9
  // theta_t = theta_{t-1} * (1 - lr * weight_decay) - lr * m_hat / (sqrt(v_hat) + eps)
10
10
 
11
- export const WEIGHT_UPDATE_WGSL = /* wgsl */`
11
+ export const WEIGHT_UPDATE_WGSL: string = /* wgsl */`
12
12
 
13
13
  struct AdamParams {
14
14
  num_elements : u32,
@@ -60,7 +60,7 @@ fn adamw_update(
60
60
 
61
61
  // Gradient clipping kernel – clips global gradient norm to max_norm.
62
62
  // Run before weight updates. Two-pass: first compute squared norm, then scale.
63
- export const GRAD_CLIP_WGSL = /* wgsl */`
63
+ export const GRAD_CLIP_WGSL: string = /* wgsl */`
64
64
 
65
65
  struct ClipParams {
66
66
  num_elements : u32,