@seanhogg/builderforce-memory-engine 2026.6.18

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (113) hide show
  1. package/LICENSE +21 -0
  2. package/README.md +393 -0
  3. package/dist/index.d.ts +32 -0
  4. package/dist/index.d.ts.map +1 -0
  5. package/dist/index.js +40 -0
  6. package/dist/index.js.map +1 -0
  7. package/dist/kernels/activations.d.ts +5 -0
  8. package/dist/kernels/activations.d.ts.map +1 -0
  9. package/dist/kernels/activations.js +171 -0
  10. package/dist/kernels/activations.js.map +1 -0
  11. package/dist/kernels/attention.d.ts +19 -0
  12. package/dist/kernels/attention.d.ts.map +1 -0
  13. package/dist/kernels/attention.js +263 -0
  14. package/dist/kernels/attention.js.map +1 -0
  15. package/dist/kernels/complex_ssd.d.ts +33 -0
  16. package/dist/kernels/complex_ssd.d.ts.map +1 -0
  17. package/dist/kernels/complex_ssd.js +305 -0
  18. package/dist/kernels/complex_ssd.js.map +1 -0
  19. package/dist/kernels/conv1d.d.ts +3 -0
  20. package/dist/kernels/conv1d.d.ts.map +1 -0
  21. package/dist/kernels/conv1d.js +158 -0
  22. package/dist/kernels/conv1d.js.map +1 -0
  23. package/dist/kernels/linear_projection.d.ts +3 -0
  24. package/dist/kernels/linear_projection.d.ts.map +1 -0
  25. package/dist/kernels/linear_projection.js +219 -0
  26. package/dist/kernels/linear_projection.js.map +1 -0
  27. package/dist/kernels/selective_scan.d.ts +3 -0
  28. package/dist/kernels/selective_scan.d.ts.map +1 -0
  29. package/dist/kernels/selective_scan.js +348 -0
  30. package/dist/kernels/selective_scan.js.map +1 -0
  31. package/dist/kernels/ssd.d.ts +29 -0
  32. package/dist/kernels/ssd.d.ts.map +1 -0
  33. package/dist/kernels/ssd.js +276 -0
  34. package/dist/kernels/ssd.js.map +1 -0
  35. package/dist/kernels/weight_update.d.ts +3 -0
  36. package/dist/kernels/weight_update.d.ts.map +1 -0
  37. package/dist/kernels/weight_update.js +119 -0
  38. package/dist/kernels/weight_update.js.map +1 -0
  39. package/dist/model/attention_block.d.ts +48 -0
  40. package/dist/model/attention_block.d.ts.map +1 -0
  41. package/dist/model/attention_block.js +262 -0
  42. package/dist/model/attention_block.js.map +1 -0
  43. package/dist/model/mamba1_block.d.ts +70 -0
  44. package/dist/model/mamba1_block.d.ts.map +1 -0
  45. package/dist/model/mamba1_block.js +333 -0
  46. package/dist/model/mamba1_block.js.map +1 -0
  47. package/dist/model/mamba2_block.d.ts +44 -0
  48. package/dist/model/mamba2_block.d.ts.map +1 -0
  49. package/dist/model/mamba2_block.js +252 -0
  50. package/dist/model/mamba2_block.js.map +1 -0
  51. package/dist/model/mamba3_block.d.ts +51 -0
  52. package/dist/model/mamba3_block.d.ts.map +1 -0
  53. package/dist/model/mamba3_block.js +270 -0
  54. package/dist/model/mamba3_block.js.map +1 -0
  55. package/dist/model/mamba_block.d.ts +64 -0
  56. package/dist/model/mamba_block.d.ts.map +1 -0
  57. package/dist/model/mamba_block.js +303 -0
  58. package/dist/model/mamba_block.js.map +1 -0
  59. package/dist/model/mamba_model.d.ts +140 -0
  60. package/dist/model/mamba_model.d.ts.map +1 -0
  61. package/dist/model/mamba_model.js +527 -0
  62. package/dist/model/mamba_model.js.map +1 -0
  63. package/dist/model/sequence_layer.d.ts +25 -0
  64. package/dist/model/sequence_layer.d.ts.map +1 -0
  65. package/dist/model/sequence_layer.js +8 -0
  66. package/dist/model/sequence_layer.js.map +1 -0
  67. package/dist/tokenizer/bpe.d.ts +29 -0
  68. package/dist/tokenizer/bpe.d.ts.map +1 -0
  69. package/dist/tokenizer/bpe.js +164 -0
  70. package/dist/tokenizer/bpe.js.map +1 -0
  71. package/dist/training/autograd.d.ts +27 -0
  72. package/dist/training/autograd.d.ts.map +1 -0
  73. package/dist/training/autograd.js +120 -0
  74. package/dist/training/autograd.js.map +1 -0
  75. package/dist/training/trainer.d.ts +36 -0
  76. package/dist/training/trainer.d.ts.map +1 -0
  77. package/dist/training/trainer.js +183 -0
  78. package/dist/training/trainer.js.map +1 -0
  79. package/dist/utils/gpu_utils.d.ts +21 -0
  80. package/dist/utils/gpu_utils.d.ts.map +1 -0
  81. package/dist/utils/gpu_utils.js +111 -0
  82. package/dist/utils/gpu_utils.js.map +1 -0
  83. package/dist/utils/quantization.d.ts +26 -0
  84. package/dist/utils/quantization.d.ts.map +1 -0
  85. package/dist/utils/quantization.js +116 -0
  86. package/dist/utils/quantization.js.map +1 -0
  87. package/dist/utils/rng.d.ts +36 -0
  88. package/dist/utils/rng.d.ts.map +1 -0
  89. package/dist/utils/rng.js +61 -0
  90. package/dist/utils/rng.js.map +1 -0
  91. package/package.json +99 -0
  92. package/src/index.ts +114 -0
  93. package/src/kernels/activations.ts +174 -0
  94. package/src/kernels/attention.ts +268 -0
  95. package/src/kernels/complex_ssd.ts +307 -0
  96. package/src/kernels/conv1d.ts +159 -0
  97. package/src/kernels/linear_projection.ts +220 -0
  98. package/src/kernels/selective_scan.ts +350 -0
  99. package/src/kernels/ssd.ts +278 -0
  100. package/src/kernels/weight_update.ts +120 -0
  101. package/src/model/attention_block.ts +344 -0
  102. package/src/model/mamba1_block.ts +437 -0
  103. package/src/model/mamba2_block.ts +319 -0
  104. package/src/model/mamba3_block.ts +335 -0
  105. package/src/model/mamba_block.ts +401 -0
  106. package/src/model/mamba_model.ts +678 -0
  107. package/src/model/sequence_layer.ts +29 -0
  108. package/src/tokenizer/bpe.ts +186 -0
  109. package/src/training/autograd.ts +135 -0
  110. package/src/training/trainer.ts +309 -0
  111. package/src/utils/gpu_utils.ts +147 -0
  112. package/src/utils/quantization.ts +154 -0
  113. package/src/utils/rng.ts +65 -0
@@ -0,0 +1,119 @@
1
// Weight Update WGSL Kernel (AdamW Optimizer)
// Implements fused AdamW parameter update on the GPU.
//
// AdamW update rule:
// m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
// v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
// m_hat = m_t / (1 - beta1^t)
// v_hat = v_t / (1 - beta2^t)
// theta_t = theta_{t-1} * (1 - lr * weight_decay) - lr * m_hat / (sqrt(v_hat) + eps)
//
// Entry point `adamw_update`: one invocation per element (workgroup_size 256),
// dispatched as (ceil(N / 256), 1, 1); invocations with i >= num_elements exit early.
// Bindings (group 0): 0 = AdamParams uniform, 1 = param (updated in place),
// 2 = grad (read-only), 3 = first-moment state m, 4 = second-moment state v.
// m and v are overwritten with the new moments on every call.
//
// The "default" values in the struct comments are documentation only: a uniform
// struct has no defaults, so the host must fill every field.
// beta1_t / beta2_t are supplied by the host as beta1^t / beta2^t. For t = 0 both
// are 1 and the bias-correction divisions divide by zero, so steps are presumably
// counted from t = 1 — NOTE(review): confirm against the trainer.
export const WEIGHT_UPDATE_WGSL = /* wgsl */ `

struct AdamParams {
num_elements : u32,
lr : f32, // learning rate
beta1 : f32, // default 0.9
beta2 : f32, // default 0.999
eps : f32, // default 1e-8
weight_decay : f32, // default 0.01
beta1_t : f32, // beta1^t (precomputed bias correction term)
beta2_t : f32, // beta2^t
};

@group(0) @binding(0) var<uniform> adam : AdamParams;
// param (N,) – weight tensor (read-write: updated in-place)
@group(0) @binding(1) var<storage, read_write> param : array<f32>;
// grad (N,) – gradient
@group(0) @binding(2) var<storage, read> grad : array<f32>;
// m (N,) – first moment
@group(0) @binding(3) var<storage, read_write> m_state : array<f32>;
// v (N,) – second moment
@group(0) @binding(4) var<storage, read_write> v_state : array<f32>;

// Dispatch: (ceil(N / 256), 1, 1)
@compute @workgroup_size(256, 1, 1)
fn adamw_update(
@builtin(global_invocation_id) gid : vec3<u32>,
) {
let i = gid.x;
if (i >= adam.num_elements) { return; }

let g = grad[i];
let p = param[i];

// Moment updates
let m_new = adam.beta1 * m_state[i] + (1.0 - adam.beta1) * g;
let v_new = adam.beta2 * v_state[i] + (1.0 - adam.beta2) * g * g;
m_state[i] = m_new;
v_state[i] = v_new;

// Bias-corrected estimates
let m_hat = m_new / (1.0 - adam.beta1_t);
let v_hat = v_new / (1.0 - adam.beta2_t);

// Weight decay (decoupled) + gradient step
param[i] = p * (1.0 - adam.lr * adam.weight_decay) -
adam.lr * m_hat / (sqrt(v_hat) + adam.eps);
}
`;
59
// Gradient clipping kernels – clip the global gradient L2 norm to max_norm.
// Run before the weight updates. Both passes use workgroup_size 256 and are
// dispatched as (ceil(N / 256), 1, 1):
//   1. grad_norm_reduce – each workgroup tree-reduces the squares of its 256
//      gradient elements in workgroup memory; thread 0 then adds the partial
//      sum into norm_sq[0].
//   2. grad_clip_scale  – every element is multiplied by sqrt(max_norm_sq / norm_sq)
//      when norm_sq exceeds max_norm_sq.
//
// Many workgroups add into the same norm_sq[0], so the add must be atomic. A plain
// read-modify-write would lose partial sums whenever more than one workgroup runs.
// WGSL has no atomic<f32>, so norm_sq is declared array<atomic<u32>> and holds the
// raw f32 bit pattern. Pass 1 adds with a compare-exchange retry loop.
// Pass 1 only accumulates; the host must zero norm_sq[0] beforehand. 0u is the
// bit pattern of 0.0f, so writing Float32Array([0]) or Uint32Array([0]) both work.
export const GRAD_CLIP_WGSL = /* wgsl */ `

struct ClipParams {
    num_elements : u32,
    max_norm_sq  : f32,   // max_norm^2
};

@group(0) @binding(0) var<uniform> clip_p : ClipParams;
@group(0) @binding(1) var<storage, read_write> grad : array<f32>;
// norm_sq[0]: running sum of squares, stored as f32 bits (WGSL has no atomic<f32>).
@group(0) @binding(2) var<storage, read_write> norm_sq : array<atomic<u32>>;

var<workgroup> local_sq : array<f32, 256>;

// Pass 1: reduce sum of squares into norm_sq[0]
@compute @workgroup_size(256, 1, 1)
fn grad_norm_reduce(
    @builtin(global_invocation_id) gid : vec3<u32>,
    @builtin(local_invocation_index) lid : u32,
) {
    let i = gid.x;
    local_sq[lid] = 0.0;
    if (i < clip_p.num_elements) {
        local_sq[lid] = grad[i] * grad[i];
    }
    workgroupBarrier();

    // Parallel tree reduction within the workgroup
    var s: u32 = 128u;
    loop {
        if (s == 0u) { break; }
        if (lid < s) {
            local_sq[lid] = local_sq[lid] + local_sq[lid + s];
        }
        workgroupBarrier();
        s = s >> 1u;
    }

    if (lid == 0u) {
        // Atomic f32 add on the bit pattern: retry until no other workgroup has
        // changed norm_sq[0] between our load and our exchange.
        var expected = atomicLoad(&norm_sq[0]);
        loop {
            let desired = bitcast<u32>(bitcast<f32>(expected) + local_sq[0]);
            let res = atomicCompareExchangeWeak(&norm_sq[0], expected, desired);
            if (res.exchanged) { break; }
            expected = res.old_value;
        }
    }
}

// Pass 2: scale gradients if norm exceeds max_norm
@compute @workgroup_size(256, 1, 1)
fn grad_clip_scale(
    @builtin(global_invocation_id) gid : vec3<u32>,
) {
    let i = gid.x;
    if (i >= clip_p.num_elements) { return; }

    let ns = bitcast<f32>(atomicLoad(&norm_sq[0]));
    if (ns > clip_p.max_norm_sq) {
        let scale = sqrt(clip_p.max_norm_sq / ns);
        grad[i] = grad[i] * scale;
    }
}
`;
119
+ //# sourceMappingURL=weight_update.js.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"weight_update.js","sourceRoot":"","sources":["../../src/kernels/weight_update.ts"],"names":[],"mappings":"AAAA,8CAA8C;AAC9C,sDAAsD;AACtD,EAAE;AACF,qBAAqB;AACrB,8CAA8C;AAC9C,gDAAgD;AAChD,gCAAgC;AAChC,gCAAgC;AAChC,uFAAuF;AAEvF,MAAM,CAAC,MAAM,kBAAkB,GAAW,UAAU,CAAA;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;CAgDnD,CAAC;AAEF,qEAAqE;AACrE,gFAAgF;AAChF,MAAM,CAAC,MAAM,cAAc,GAAW,UAAU,CAAA;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;CAyD/C,CAAC"}
@@ -0,0 +1,48 @@
1
+ /**
2
+ * attention_block.ts – Causal Multi-Head Self-Attention Block.
3
+ *
4
+ * Intentionally simple for WebGPU — naive O(L²) tiled attention,
5
+ * no Flash-Attention dependency. Suitable for hybrid (Jamba/Zamba) schedules
6
+ * where a few attention layers interleave with SSM layers.
7
+ *
8
+ * Data flow:
9
+ * Input (B, L, D_model)
10
+ * └─ RMSNorm
11
+ * └─ wQKV → Q (B,L,H,dh), K (B,L,H,dh), V (B,L,H,dh)
12
+ * └─ causal attention scores = Q·Kᵀ / √dh (masked)
13
+ * └─ softmax
14
+ * └─ weighted V sum
15
+ * └─ concat heads → wO → D_model
16
+ * └─ + residual
17
+ * [optional FFN sublayer]
18
+ *
19
+ * Implements SequenceLayer.
20
+ */
21
+ import type { SequenceLayer, LayerForwardResult, LayerParam } from './sequence_layer.js';
22
/**
 * Construction options for AttentionBlock.
 */
export interface AttentionBlockConfig {
    /** Model (residual stream) width; must be divisible by nHeads. */
    dModel: number;
    /** Number of attention heads. */
    nHeads: number;
    /** Per-head width; defaults to dModel / nHeads when omitted. */
    dHead?: number;
    /** Append the SiLU feed-forward sublayer after attention; defaults to false. */
    hasFfn?: boolean;
    /** FFN hidden-width multiplier (hidden = dModel * ffnMult); defaults to 4. */
    ffnMult?: number;
}
29
/**
 * Per-call cache produced by AttentionBlock.forward.
 */
export interface AttentionCache {
    /**
     * Attention-score buffer holding B * H * L * L f32 values, after the softmax
     * pass (the softmax kernel is bound only to this buffer, so presumably it is
     * normalised in place — confirm against SOFTMAX_WGSL). The block keeps no
     * reference to it; whoever receives the cache is responsible for destroying it.
     */
    scores: GPUBuffer;
}
32
/**
 * Causal multi-head self-attention layer (pre-RMSNorm, residual connection,
 * optional SiLU FFN sublayer) implementing SequenceLayer.
 */
export declare class AttentionBlock implements SequenceLayer {
    /** Layer-kind tag; always "attention". */
    readonly layerType: "attention";
    /** Device that owns every weight buffer and pipeline of this block. */
    device: GPUDevice;
    /** Construction config with the dHead, hasFfn and ffnMult defaults filled in. */
    config: Required<AttentionBlockConfig>;
    /** Per-head width (mirrors config.dHead). */
    dHead: number;
    /**
     * Weight buffers keyed by name: wQKV, bQKV, wO, bO, normWeight, plus
     * wFfn1, bFfn1, wFfn2, bFfn2 when hasFfn is enabled.
     */
    gpuWeights: Record<string, GPUBuffer>;
    /** Compute pipelines keyed by kernel name. */
    pipelines: Record<string, GPUComputePipeline>;
    /**
     * Allocates the weights and compiles the pipelines.
     * @throws Error when dModel is not divisible by nHeads.
     */
    constructor(device: GPUDevice, config: AttentionBlockConfig);
    private _initWeights;
    private _buildPipelines;
    /**
     * Runs the block on a batch * seqLen * dModel activation buffer and returns a
     * newly allocated output buffer of the same size plus the attention-score
     * cache. The input buffer is neither modified nor destroyed.
     */
    forward(xBuf: GPUBuffer, batch: number, seqLen: number): LayerForwardResult;
    /** Every weight buffer with its element count and name. */
    parameters(): LayerParam[];
    /** Same list as parameters(): attention layers have no WSLA-restricted subset. */
    getTrainableParams(): LayerParam[];
    /** No-op: WSLA mode does not apply to attention layers. */
    setWSLAMode(_enabled: boolean): void;
    /** Destroys all weight buffers and clears gpuWeights. */
    destroy(): void;
}
48
+ //# sourceMappingURL=attention_block.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"attention_block.d.ts","sourceRoot":"","sources":["../../src/model/attention_block.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;;;;;GAmBG;AAoBH,OAAO,KAAK,EAAE,aAAa,EAAE,kBAAkB,EAAE,UAAU,EAAE,MAAM,qBAAqB,CAAC;AAEzF,MAAM,WAAW,oBAAoB;IACjC,MAAM,EAAI,MAAM,CAAC;IACjB,MAAM,EAAI,MAAM,CAAC;IACjB,KAAK,CAAC,EAAI,MAAM,CAAC;IACjB,MAAM,CAAC,EAAG,OAAO,CAAC;IAClB,OAAO,CAAC,EAAE,MAAM,CAAC;CACpB;AAED,MAAM,WAAW,cAAc;IAC3B,MAAM,EAAE,SAAS,CAAC;CACrB;AA6BD,qBAAa,cAAe,YAAW,aAAa;IAChD,QAAQ,CAAC,SAAS,EAAG,WAAW,CAAU;IAE1C,MAAM,EAAG,SAAS,CAAC;IACnB,MAAM,EAAG,QAAQ,CAAC,oBAAoB,CAAC,CAAC;IACxC,KAAK,EAAI,MAAM,CAAC;IAEhB,UAAU,EAAE,MAAM,CAAC,MAAM,EAAE,SAAS,CAAC,CAAC;IACtC,SAAS,EAAG,MAAM,CAAC,MAAM,EAAE,kBAAkB,CAAC,CAAC;gBAEnC,MAAM,EAAE,SAAS,EAAE,MAAM,EAAE,oBAAoB;IAyB3D,OAAO,CAAC,YAAY;IA0BpB,OAAO,CAAC,eAAe;IAgBvB,OAAO,CAAC,IAAI,EAAE,SAAS,EAAE,KAAK,EAAE,MAAM,EAAE,MAAM,EAAE,MAAM,GAAG,kBAAkB;IAsJ3E,UAAU,IAAI,UAAU,EAAE;IAuB1B,kBAAkB,IAAI,UAAU,EAAE;IAKlC,WAAW,CAAC,QAAQ,EAAE,OAAO,GAAG,IAAI;IAIpC,OAAO,IAAI,IAAI;CAIlB"}
@@ -0,0 +1,262 @@
1
+ /**
2
+ * attention_block.ts – Causal Multi-Head Self-Attention Block.
3
+ *
4
+ * Intentionally simple for WebGPU — naive O(L²) tiled attention,
5
+ * no Flash-Attention dependency. Suitable for hybrid (Jamba/Zamba) schedules
6
+ * where a few attention layers interleave with SSM layers.
7
+ *
8
+ * Data flow:
9
+ * Input (B, L, D_model)
10
+ * └─ RMSNorm
11
+ * └─ wQKV → Q (B,L,H,dh), K (B,L,H,dh), V (B,L,H,dh)
12
+ * └─ causal attention scores = Q·Kᵀ / √dh (masked)
13
+ * └─ softmax
14
+ * └─ weighted V sum
15
+ * └─ concat heads → wO → D_model
16
+ * └─ + residual
17
+ * [optional FFN sublayer]
18
+ *
19
+ * Implements SequenceLayer.
20
+ */
21
+ import { createComputePipeline, createBindGroup, createStorageBuffer, createEmptyStorageBuffer, createUniformBuffer, dispatchKernel, cdiv, } from '../utils/gpu_utils.js';
22
+ import { ATTENTION_FORWARD_WGSL, SOFTMAX_WGSL, } from '../kernels/attention.js';
23
+ import { LINEAR_FORWARD_WGSL } from '../kernels/linear_projection.js';
24
+ import { gaussianArray } from '../utils/rng.js';
25
+ import { ACTIVATIONS_WGSL } from '../kernels/activations.js';
26
// Element-wise addition kernel: c[i] = a[i] + b[i] for every i < n.
// Bindings (group 0): 0 = a, 1 = b (read-only), 2 = c (output),
// 3 = n (uniform u32 element count) — note the uniform comes last here.
// Entry point `main` with workgroup_size 256, so callers dispatch cdiv(n, 256)
// workgroups. AttentionBlock uses it for both residual connections.
const ADD_SHADER = /* wgsl */ `
@group(0) @binding(0) var<storage, read> a : array<f32>;
@group(0) @binding(1) var<storage, read> b : array<f32>;
@group(0) @binding(2) var<storage, read_write> c : array<f32>;
@group(0) @binding(3) var<uniform> n : u32;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let i = gid.x;
if (i < n) { c[i] = a[i] + b[i]; }
}
`;
37
// SiLU (x * sigmoid(x)) kernel for the optional FFN sublayer.
// Bindings (group 0): 0 = ActParams uniform, 1 = x (read-only), 2 = y (output).
// Entry point `silu_forward`, workgroup_size 256: dispatch cdiv(num_elements, 256).
// Struct members are comma-separated. Current WGSL no longer accepts the `;`
// member separator from early drafts; with it, the module fails to compile and
// pipeline creation fails whenever hasFfn is enabled.
const SILU_SHADER = /* wgsl */ `
struct ActParams { num_elements: u32, };
@group(0) @binding(0) var<uniform> p : ActParams;
@group(0) @binding(1) var<storage, read> x : array<f32>;
@group(0) @binding(2) var<storage, read_write> y : array<f32>;
@compute @workgroup_size(256, 1, 1)
fn silu_forward(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    if (i >= p.num_elements) { return; }
    let v = x[i];
    // v / (1 + e^-v) == v * sigmoid(v)
    y[i] = v / (1.0 + exp(-v));
}
`;
51
/**
 * Causal multi-head self-attention block with an optional SiLU FFN sublayer.
 * Implements the SequenceLayer interface (layerType 'attention').
 *
 * Weights are GPU storage buffers allocated once at construction. Each forward()
 * call allocates fresh intermediates and destroys each one (including the
 * per-dispatch uniform buffers) right after the kernel that consumes it has been
 * dispatched.
 */
export class AttentionBlock {
    layerType = 'attention';
    device;
    config;
    dHead;
    gpuWeights;
    pipelines;
    /**
     * @param {GPUDevice} device - device that owns every buffer and pipeline.
     * @param {object} config - AttentionBlockConfig: dModel, nHeads and the optional
     *   dHead (default dModel / nHeads), hasFfn (default false), ffnMult (default 4).
     * @throws {Error} if dModel is not divisible by nHeads, or if dHead * nHeads
     *   differs from dModel (every activation buffer is sized M * dModel and wO is
     *   dModel x dModel, so any other head width does not fit them).
     */
    constructor(device, config) {
        this.device = device;
        const { dModel, nHeads } = config;
        if (dModel % nHeads !== 0) {
            throw new Error(`AttentionBlock: dModel (${config.dModel}) must be divisible by nHeads (${config.nHeads}).`);
        }
        // `??` instead of spreading defaults under `...config`, so an explicit
        // `undefined` in the caller's config cannot clobber a default.
        const dHead = config.dHead ?? dModel / nHeads;
        if (dHead * nHeads !== dModel) {
            throw new Error(`AttentionBlock: dHead (${dHead}) * nHeads (${nHeads}) must equal dModel (${dModel}).`);
        }
        this.config = {
            ...config,
            dHead,
            hasFfn: config.hasFfn ?? false,
            ffnMult: config.ffnMult ?? 4,
        };
        this.dHead = dHead;
        this.gpuWeights = {};
        this.pipelines = {};
        this._initWeights();
        this._buildPipelines();
    }
    /** Allocate the weight buffers: gaussianArray(std 0.02) matrices, zero biases, unit norm gain. */
    _initWeights() {
        const { dModel, hasFfn, ffnMult } = this.config;
        const randn = (n, std = 0.02) => gaussianArray(n, std);
        const zeros = (n) => new Float32Array(n);
        const ones = (n) => new Float32Array(n).fill(1.0);
        const mk = (arr) => createStorageBuffer(this.device, arr, true);
        this.gpuWeights = {
            wQKV: mk(randn(3 * dModel * dModel)),
            bQKV: mk(zeros(3 * dModel)),
            wO: mk(randn(dModel * dModel)),
            bO: mk(zeros(dModel)),
            normWeight: mk(ones(dModel)),
        };
        if (hasFfn) {
            const ffnDim = dModel * ffnMult;
            this.gpuWeights['wFfn1'] = mk(randn(ffnDim * dModel));
            this.gpuWeights['bFfn1'] = mk(zeros(ffnDim));
            this.gpuWeights['wFfn2'] = mk(randn(dModel * ffnDim));
            this.gpuWeights['bFfn2'] = mk(zeros(dModel));
        }
    }
    /** Compile every compute pipeline forward() needs (SiLU only when hasFfn). */
    _buildPipelines() {
        const d = this.device;
        this.pipelines = {
            linear: createComputePipeline(d, LINEAR_FORWARD_WGSL, 'linear_forward'),
            rmsnorm: createComputePipeline(d, ACTIVATIONS_WGSL, 'rmsnorm_forward'),
            attn_fwd: createComputePipeline(d, ATTENTION_FORWARD_WGSL, 'attention_forward'),
            attn_val: createComputePipeline(d, ATTENTION_FORWARD_WGSL, 'attention_value'),
            softmax: createComputePipeline(d, SOFTMAX_WGSL, 'softmax_forward'),
            elAdd: createComputePipeline(d, ADD_SHADER, 'main'),
        };
        if (this.config.hasFfn) {
            this.pipelines['silu'] = createComputePipeline(d, SILU_SHADER, 'silu_forward');
        }
    }
    /**
     * Upload `uniformData` into a fresh uniform buffer, bind it together with the
     * buffers returned by `entries(uniformBuf)` (in binding order), dispatch
     * `workgroups` of pipeline `key`, then destroy the uniform buffer.
     * Destroying right after dispatch relies on dispatchKernel submitting its work
     * immediately — the same assumption forward() already makes for intermediates.
     */
    _dispatch(key, uniformData, entries, workgroups) {
        const pipeline = this.pipelines[key];
        const uniformBuf = createUniformBuffer(this.device, uniformData);
        const bindGroup = createBindGroup(this.device, pipeline, entries(uniformBuf));
        dispatchKernel(this.device, pipeline, bindGroup, workgroups);
        uniformBuf.destroy();
    }
    /** Run the `linear` pipeline with weight `wKey` / bias `bKey`; returns a new [rows, outDim] f32 buffer owned by the caller. */
    _linear(inBuf, rows, inDim, outDim, wKey, bKey) {
        const out = createEmptyStorageBuffer(this.device, rows * outDim * 4, true);
        this._dispatch('linear', new Uint32Array([rows, inDim, outDim]).buffer, (u) => [u, inBuf, this.gpuWeights[wKey], this.gpuWeights[bKey], out], [cdiv(rows, 16), cdiv(outDim, 16), 1]);
        return out;
    }
    /** Element-wise a + b over `numel` floats into a new buffer owned by the caller. */
    _add(a, b, numel) {
        const out = createEmptyStorageBuffer(this.device, numel * 4, true);
        // ADD_SHADER takes its element count at binding 3, after the three arrays.
        this._dispatch('elAdd', new Uint32Array([numel]).buffer, (u) => [a, b, out, u], [cdiv(numel, 256), 1, 1]);
        return out;
    }
    /**
     * Run the block on a (batch, seqLen, dModel) activation buffer.
     *
     * @param {GPUBuffer} xBuf - input activations, batch * seqLen * dModel f32; not modified or destroyed.
     * @param {number} batch
     * @param {number} seqLen
     * @returns {{output: GPUBuffer, cache: {scores: GPUBuffer}}} `output` is a new
     *   batch * seqLen * dModel buffer; `cache.scores` is the (B, H, L, L) score
     *   buffer after the softmax pass. The caller owns both.
     */
    forward(xBuf, batch, seqLen) {
        const d = this.device;
        const { dModel, nHeads, hasFfn, ffnMult } = this.config;
        const dh = this.dHead;
        const B = batch;
        const L = seqLen;
        const M = B * L;
        const H = nHeads;
        const rowBytes = dModel * 4;
        const actBytes = M * rowBytes;
        // 1. Pre-block RMSNorm. normInv is a side output of the kernel that this
        //    forward path does not use.
        const normOut = createEmptyStorageBuffer(d, actBytes, true);
        const normInv = createEmptyStorageBuffer(d, M * 4, true);
        {
            const params = new ArrayBuffer(16);
            new Uint32Array(params, 0, 2).set([M, dModel]);
            new Float32Array(params, 8, 1).set([1e-6]); // epsilon
            this._dispatch('rmsnorm', params, (u) => [u, xBuf, this.gpuWeights['normWeight'], normOut, normInv], [cdiv(M, 64), 1, 1]);
        }
        normInv.destroy();
        // 2. Fused QKV projection -> row-major [B, L, 3*D]: row m is q_m | k_m | v_m.
        const qkvOut = this._linear(normOut, M, dModel, 3 * dModel, 'wQKV', 'bQKV');
        normOut.destroy();
        // De-interleave into Q, K, V, each [B, L, H, dh] = [B, L, D]. Because the
        // three vectors of a token sit back to back in its row, each row needs three
        // copies. Copying three contiguous M*D slabs would mix Q, K and V rows
        // whenever M > 1.
        // TODO: a dedicated split kernel would avoid issuing 3*M copy commands.
        const QBuf = createEmptyStorageBuffer(d, actBytes, true);
        const KBuf = createEmptyStorageBuffer(d, actBytes, true);
        const VBuf = createEmptyStorageBuffer(d, actBytes, true);
        {
            const enc = d.createCommandEncoder();
            for (let m = 0; m < M; m++) {
                const src = 3 * m * rowBytes;
                const dst = m * rowBytes;
                enc.copyBufferToBuffer(qkvOut, src, QBuf, dst, rowBytes);
                enc.copyBufferToBuffer(qkvOut, src + rowBytes, KBuf, dst, rowBytes);
                enc.copyBufferToBuffer(qkvOut, src + 2 * rowBytes, VBuf, dst, rowBytes);
            }
            d.queue.submit([enc.finish()]);
        }
        qkvOut.destroy();
        // 3. Attention scores: [B, H, L, L]
        const scores = createEmptyStorageBuffer(d, B * H * L * L * 4, true);
        {
            // attention_forward also binds an out_buf slot; for this pass it only
            // receives a throwaway placeholder, released right after the dispatch.
            const placeholder = createEmptyStorageBuffer(d, actBytes, true);
            this._dispatch('attn_fwd', new Uint32Array([B, L, dModel, H, dh]).buffer, (u) => [u, QBuf, KBuf, VBuf, scores, placeholder], [cdiv(L, 16), H, B]);
            placeholder.destroy();
        }
        // 4. Softmax (causal) per row: dispatch (L, H, B)
        this._dispatch('softmax', new Uint32Array([L, L, 1]).buffer, (u) => [u, scores], [L, H, B]);
        // 5. Weighted V sum -> attn output [B, L, H, dh]
        const attnOut = createEmptyStorageBuffer(d, actBytes, true);
        this._dispatch('attn_val', new Uint32Array([B, L, dModel, H, dh]).buffer, (u) => [u, QBuf, KBuf, VBuf, scores, attnOut], [cdiv(L, 16), H, B]);
        QBuf.destroy();
        KBuf.destroy();
        VBuf.destroy();
        // 6. Output projection: [B, L, D] -> [B, L, D]
        const outProjOut = this._linear(attnOut, M, dModel, dModel, 'wO', 'bO');
        attnOut.destroy();
        // 7. Residual add
        let current = this._add(outProjOut, xBuf, M * dModel);
        outProjOut.destroy();
        // 8. Optional FFN sublayer: current + wFfn2(silu(wFfn1(current)))
        if (hasFfn) {
            const ffnDim = dModel * ffnMult;
            const ffn1Out = this._linear(current, M, dModel, ffnDim, 'wFfn1', 'bFfn1');
            const siluOut = createEmptyStorageBuffer(d, M * ffnDim * 4, true);
            this._dispatch('silu', new Uint32Array([M * ffnDim]).buffer, (u) => [u, ffn1Out, siluOut], [cdiv(M * ffnDim, 256), 1, 1]);
            ffn1Out.destroy();
            const ffn2Out = this._linear(siluOut, M, ffnDim, dModel, 'wFfn2', 'bFfn2');
            siluOut.destroy();
            const residual2 = this._add(ffn2Out, current, M * dModel);
            ffn2Out.destroy();
            current.destroy();
            current = residual2;
        }
        const cache = { scores };
        return { output: current, cache };
    }
    /** Every weight buffer with its element count and name (FFN weights only when hasFfn). */
    parameters() {
        const { dModel, hasFfn, ffnMult } = this.config;
        const params = [
            { buf: this.gpuWeights['wQKV'], numel: 3 * dModel * dModel, name: 'wQKV' },
            { buf: this.gpuWeights['bQKV'], numel: 3 * dModel, name: 'bQKV' },
            { buf: this.gpuWeights['wO'], numel: dModel * dModel, name: 'wO' },
            { buf: this.gpuWeights['bO'], numel: dModel, name: 'bO' },
            { buf: this.gpuWeights['normWeight'], numel: dModel, name: 'normWeight' },
        ];
        if (hasFfn) {
            const ffnDim = dModel * ffnMult;
            params.push({ buf: this.gpuWeights['wFfn1'], numel: ffnDim * dModel, name: 'wFfn1' }, { buf: this.gpuWeights['bFfn1'], numel: ffnDim, name: 'bFfn1' }, { buf: this.gpuWeights['wFfn2'], numel: dModel * ffnDim, name: 'wFfn2' }, { buf: this.gpuWeights['bFfn2'], numel: dModel, name: 'bFfn2' });
        }
        return params;
    }
    getTrainableParams() {
        // Attention layers are always fully trained — no WSLA subset
        return this.parameters();
    }
    setWSLAMode(_enabled) {
        // No-op for attention: WSLA does not apply
    }
    /** Destroy every weight buffer and clear gpuWeights. */
    destroy() {
        for (const buf of Object.values(this.gpuWeights))
            buf.destroy();
        this.gpuWeights = {};
    }
}
262
+ //# sourceMappingURL=attention_block.js.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"attention_block.js","sourceRoot":"","sources":["../../src/model/attention_block.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;;;;;GAmBG;AAEH,OAAO,EACH,qBAAqB,EACrB,eAAe,EACf,mBAAmB,EACnB,wBAAwB,EACxB,mBAAmB,EACnB,cAAc,EACd,IAAI,GACP,MAAM,uBAAuB,CAAC;AAE/B,OAAO,EACH,sBAAsB,EACtB,YAAY,GACf,MAAM,yBAAyB,CAAC;AACjC,OAAO,EAAE,mBAAmB,EAAE,MAAM,iCAAiC,CAAC;AACtE,OAAO,EAAE,aAAa,EAAE,MAAM,iBAAiB,CAAC;AAChD,OAAO,EAAE,gBAAgB,EAAE,MAAS,2BAA2B,CAAC;AAgBhE,MAAM,UAAU,GAAG,UAAU,CAAA;;;;;;;;;;CAU5B,CAAC;AAEF,eAAe;AACf,MAAM,WAAW,GAAG,UAAU,CAAA;;;;;;;;;;;;CAY7B,CAAC;AAEF,MAAM,OAAO,cAAc;IACd,SAAS,GAAG,WAAoB,CAAC;IAE1C,MAAM,CAAa;IACnB,MAAM,CAAkC;IACxC,KAAK,CAAW;IAEhB,UAAU,CAA4B;IACtC,SAAS,CAAsC;IAE/C,YAAY,MAAiB,EAAE,MAA4B;QACvD,IAAI,CAAC,MAAM,GAAG,MAAM,CAAC;QAErB,IAAI,MAAM,CAAC,MAAM,GAAG,MAAM,CAAC,MAAM,KAAK,CAAC,EAAE,CAAC;YACtC,MAAM,IAAI,KAAK,CACX,2BAA2B,MAAM,CAAC,MAAM,kCAAkC,MAAM,CAAC,MAAM,IAAI,CAC9F,CAAC;QACN,CAAC;QAED,IAAI,CAAC,MAAM,GAAG;YACV,KAAK,EAAI,MAAM,CAAC,MAAM,GAAG,MAAM,CAAC,MAAM;YACtC,MAAM,EAAG,KAAK;YACd,OAAO,EAAE,CAAC;YACV,GAAG,MAAM;SACsB,CAAC;QAEpC,IAAI,CAAC,KAAK,GAAG,IAAI,CAAC,MAAM,CAAC,KAAK,CAAC;QAE/B,IAAI,CAAC,UAAU,GAAG,EAAE,CAAC;QACrB,IAAI,CAAC,SAAS,GAAI,EAAE,CAAC;QAErB,IAAI,CAAC,YAAY,EAAE,CAAC;QACpB,IAAI,CAAC,eAAe,EAAE,CAAC;IAC3B,CAAC;IAEO,YAAY;QAChB,MAAM,EAAE,MAAM,EAAE,MAAM,EAAE,MAAM,EAAE,OAAO,EAAE,GAAG,IAAI,CAAC,MAAM,CAAC;QAExD,MAAM,KAAK,GAAG,CAAC,CAAS,EAAE,GAAG,GAAG,IAAI,EAAgB,EAAE,CAAC,aAAa,CAAC,CAAC,EAAE,GAAG,CAAC,CAAC;QAE7E,MAAM,KAAK,GAAG,CAAC,CAAS,EAAE,EAAE,CAAC,IAAI,YAAY,CAAC,CAAC,CAAC,CAAC;QACjD,MAAM,IAAI,GAAI,CAAC,CAAS,EAAE,EAAE,CAAC,IAAI,YAAY,CAAC,CAAC,CAAC,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC;QAC3D,MAAM,EAAE,GAAM,CAAC,GAAiB,EAAE,EAAE,CAAC,mBAAmB,CAAC,IAAI,CAAC,MAAM,EAAE,GAAG,EAAE,IAAI,CAAC,CAAC;QAEjF,IAAI,CAAC,UAAU,GAAG;YACd,IAAI,EAAQ,EAAE,CAAC,KAAK,CAAC,CAAC,GAAG,MAAM,GAAG,MAAM,CAAC,CAAC;YAC1C,IAAI,EAAQ,EAAE,CAAC,KAAK,CAAC,CAAC,GAAG,MAAM,CAAC,CAAC;YACjC,EAAE,EAAU,EAAE,CAAC,KAAK,CAAC,MAAM,GAAG,MAAM,CAAC,CAAC;YACtC,EAAE,EAAU,EAAE,CAAC,KAAK,CA
AC,MAAM,CAAC,CAAC;YAC7B,UAAU,EAAE,EAAE,CAAC,IAAI,CAAC,MAAM,CAAC,CAAC;SAC/B,CAAC;QAEF,IAAI,MAAM,EAAE,CAAC;YACT,MAAM,MAAM,GAAG,MAAM,GAAG,OAAO,CAAC;YAChC,IAAI,CAAC,UAAU,CAAC,OAAO,CAAC,GAAG,EAAE,CAAC,KAAK,CAAC,MAAM,GAAG,MAAM,CAAC,CAAC,CAAC;YACtD,IAAI,CAAC,UAAU,CAAC,OAAO,CAAC,GAAG,EAAE,CAAC,KAAK,CAAC,MAAM,CAAC,CAAC,CAAC;YAC7C,IAAI,CAAC,UAAU,CAAC,OAAO,CAAC,GAAG,EAAE,CAAC,KAAK,CAAC,MAAM,GAAG,MAAM,CAAC,CAAC,CAAC;YACtD,IAAI,CAAC,UAAU,CAAC,OAAO,CAAC,GAAG,EAAE,CAAC,KAAK,CAAC,MAAM,CAAC,CAAC,CAAC;QACjD,CAAC;IACL,CAAC;IAEO,eAAe;QACnB,MAAM,CAAC,GAAG,IAAI,CAAC,MAAM,CAAC;QACtB,IAAI,CAAC,SAAS,GAAG;YACb,MAAM,EAAI,qBAAqB,CAAC,CAAC,EAAE,mBAAmB,EAAM,gBAAgB,CAAC;YAC7E,OAAO,EAAG,qBAAqB,CAAC,CAAC,EAAE,gBAAgB,EAAS,iBAAiB,CAAC;YAC9E,QAAQ,EAAE,qBAAqB,CAAC,CAAC,EAAE,sBAAsB,EAAG,mBAAmB,CAAC;YAChF,QAAQ,EAAE,qBAAqB,CAAC,CAAC,EAAE,sBAAsB,EAAG,iBAAiB,CAAC;YAC9E,OAAO,EAAG,qBAAqB,CAAC,CAAC,EAAE,YAAY,EAAa,iBAAiB,CAAC;YAC9E,KAAK,EAAK,qBAAqB,CAAC,CAAC,EAAE,UAAU,EAAe,MAAM,CAAC;SACtE,CAAC;QAEF,IAAI,IAAI,CAAC,MAAM,CAAC,MAAM,EAAE,CAAC;YACrB,IAAI,CAAC,SAAS,CAAC,MAAM,CAAC,GAAG,qBAAqB,CAAC,CAAC,EAAE,WAAW,EAAE,cAAc,CAAC,CAAC;QACnF,CAAC;IACL,CAAC;IAED,OAAO,CAAC,IAAe,EAAE,KAAa,EAAE,MAAc;QAClD,MAAM,CAAC,GAAG,IAAI,CAAC,MAAM,CAAC;QACtB,MAAM,EAAE,MAAM,EAAE,MAAM,EAAE,MAAM,EAAE,GAAG,IAAI,CAAC,MAAM,CAAC;QAC/C,MAAM,EAAE,GAAG,IAAI,CAAC,KAAK,CAAC;QACtB,MAAM,CAAC,GAAI,KAAK,CAAC;QACjB,MAAM,CAAC,GAAI,MAAM,CAAC;QAClB,MAAM,CAAC,GAAI,CAAC,GAAG,CAAC,CAAC;QACjB,MAAM,CAAC,GAAI,MAAM,CAAC;QAElB,uBAAuB;QACvB,MAAM,OAAO,GAAG,wBAAwB,CAAC,CAAC,EAAE,CAAC,GAAG,MAAM,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;QAClE,MAAM,OAAO,GAAG,wBAAwB,CAAC,CAAC,EAAE,CAAC,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;QACzD,CAAC;YACG,MAAM,MAAM,GAAG,IAAI,WAAW,CAAC,EAAE,CAAC,CAAC;YACnC,IAAI,WAAW,CAAC,MAAM,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC,CAAC;YAC/C,IAAI,YAAY,CAAC,MAAM,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC;YAC3C,MAAM,IAAI,GAAG,mBAAmB,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC;YAC5C,MAAM,EAAE,GAAG,eAAe,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,SAAS,
CAAE,EACpD,CAAC,IAAI,EAAE,IAAI,EAAE,IAAI,CAAC,UAAU,CAAC,YAAY,CAAE,EAAE,OAAO,EAAE,OAAO,CAAC,CAAC,CAAC;YACpE,cAAc,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,SAAS,CAAE,EAAE,EAAE,EAAE,CAAC,IAAI,CAAC,CAAC,EAAE,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC3E,CAAC;QACD,OAAO,CAAC,OAAO,EAAE,CAAC;QAElB,iCAAiC;QACjC,MAAM,MAAM,GAAG,wBAAwB,CAAC,CAAC,EAAE,CAAC,GAAG,CAAC,GAAG,MAAM,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;QACrE,CAAC;YACG,MAAM,MAAM,GAAG,IAAI,WAAW,CAAC,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,GAAG,MAAM,CAAC,CAAC,CAAC,MAAM,CAAC;YAC/D,MAAM,IAAI,GAAK,mBAAmB,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC;YAC9C,MAAM,EAAE,GAAG,eAAe,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,QAAQ,CAAE,EACnD,CAAC,IAAI,EAAE,OAAO,EAAE,IAAI,CAAC,UAAU,CAAC,MAAM,CAAE,EAAE,IAAI,CAAC,UAAU,CAAC,MAAM,CAAE,EAAE,MAAM,CAAC,CAAC,CAAC;YACjF,cAAc,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,QAAQ,CAAE,EAAE,EAAE,EAAE,CAAC,IAAI,CAAC,CAAC,EAAE,EAAE,CAAC,EAAE,IAAI,CAAC,CAAC,GAAG,MAAM,EAAE,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC7F,CAAC;QACD,OAAO,CAAC,OAAO,EAAE,CAAC;QAElB,yDAAyD;QACzD,MAAM,IAAI,GAAG,wBAAwB,CAAC,CAAC,EAAE,CAAC,GAAG,MAAM,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;QAC/D,MAAM,IAAI,GAAG,wBAAwB,CAAC,CAAC,EAAE,CAAC,GAAG,MAAM,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;QAC/D,MAAM,IAAI,GAAG,wBAAwB,CAAC,CAAC,EAAE,CAAC,GAAG,MAAM,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;QAC/D,CAAC;YACG,MAAM,GAAG,GAAG,CAAC,CAAC,oBAAoB,EAAE,CAAC;YACrC,GAAG,CAAC,kBAAkB,CAAC,MAAM,EAAE,CAAC,EAAgB,IAAI,EAAE,CAAC,EAAE,CAAC,GAAG,MAAM,GAAG,CAAC,CAAC,CAAC;YACzE,GAAG,CAAC,kBAAkB,CAAC,MAAM,EAAE,CAAC,GAAG,MAAM,GAAG,CAAC,EAAI,IAAI,EAAE,CAAC,EAAE,CAAC,GAAG,MAAM,GAAG,CAAC,CAAC,CAAC;YAC1E,GAAG,CAAC,kBAAkB,CAAC,MAAM,EAAE,CAAC,GAAG,CAAC,GAAG,MAAM,GAAG,CAAC,EAAE,IAAI,EAAE,CAAC,EAAE,CAAC,GAAG,MAAM,GAAG,CAAC,CAAC,CAAC;YAC5E,CAAC,CAAC,KAAK,CAAC,MAAM,CAAC,CAAC,GAAG,CAAC,MAAM,EAAE,CAAC,CAAC,CAAC;QACnC,CAAC;QACD,MAAM,CAAC,OAAO,EAAE,CAAC;QAEjB,oCAAoC;QACpC,MAAM,MAAM,GAAG,wBAAwB,CAAC,CAAC,EAAE,CAAC,GAAG,CAAC,GAAG,CAAC,GAAG,CAAC,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;QACpE,CAAC;YACG,MAAM,UAAU,GAAG,IAAI,WAAW,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,EAAE
,CAAC,CAAC,CAAC,MAAM,CAAC;YACjE,MAAM,IAAI,GAAG,mBAAmB,CAAC,CAAC,EAAE,UAAU,CAAC,CAAC;YAChD,MAAM,EAAE,GAAG,eAAe,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,UAAU,CAAE,EACrD,CAAC,IAAI,EAAE,IAAI,EAAE,IAAI,EAAE,IAAI,EAAE,MAAM;gBAC9B,wBAAwB,CAAC,CAAC,EAAE,CAAC,GAAG,MAAM,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC,CAAC,CAAC,CAAE,sBAAsB;YACjF,cAAc,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,UAAU,CAAE,EAAE,EAAE,EAAE,CAAC,IAAI,CAAC,CAAC,EAAE,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC5E,CAAC;QAED,kDAAkD;QAClD,CAAC;YACG,MAAM,QAAQ,GAAG,IAAI,WAAW,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC;YACnD,MAAM,IAAI,GAAG,mBAAmB,CAAC,CAAC,EAAE,QAAQ,CAAC,CAAC;YAC9C,MAAM,EAAE,GAAG,eAAe,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,SAAS,CAAE,EACpD,CAAC,IAAI,EAAE,MAAM,CAAC,CAAC,CAAC;YACpB,cAAc,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,SAAS,CAAE,EAAE,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACjE,CAAC;QAED,gDAAgD;QAChD,MAAM,OAAO,GAAG,wBAAwB,CAAC,CAAC,EAAE,CAAC,GAAG,MAAM,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;QAClE,CAAC;YACG,MAAM,UAAU,GAAG,IAAI,WAAW,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC,MAAM,CAAC;YACjE,MAAM,IAAI,GAAG,mBAAmB,CAAC,CAAC,EAAE,UAAU,CAAC,CAAC;YAChD,MAAM,EAAE,GAAG,eAAe,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,UAAU,CAAE,EACrD,CAAC,IAAI,EAAE,IAAI,EAAE,IAAI,EAAE,IAAI,EAAE,MAAM,EAAE,OAAO,CAAC,CAAC,CAAC;YAC/C,cAAc,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,UAAU,CAAE,EAAE,EAAE,EAAE,CAAC,IAAI,CAAC,CAAC,EAAE,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC5E,CAAC;QACD,IAAI,CAAC,OAAO,EAAE,CAAC;QACf,IAAI,CAAC,OAAO,EAAE,CAAC;QACf,IAAI,CAAC,OAAO,EAAE,CAAC;QAEf,8CAA8C;QAC9C,MAAM,UAAU,GAAG,wBAAwB,CAAC,CAAC,EAAE,CAAC,GAAG,MAAM,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;QACrE,CAAC;YACG,MAAM,MAAM,GAAG,IAAI,WAAW,CAAC,CAAC,CAAC,EAAE,MAAM,EAAE,MAAM,CAAC,CAAC,CAAC,MAAM,CAAC;YAC3D,MAAM,IAAI,GAAK,mBAAmB,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC;YAC9C,MAAM,EAAE,GAAG,eAAe,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,QAAQ,CAAE,EACnD,CAAC,IAAI,EAAE,OAAO,EAAE,IAAI,CAAC,UAAU,CAAC,IAAI,CAAE,EAAE,IAAI,CAAC,UAAU,CAAC,IAAI,CAAE,EAAE,UAAU,CAAC,CAAC,CAAC;YACjF,c
AAc,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,QAAQ,CAAE,EAAE,EAAE,EAAE,CAAC,IAAI,CAAC,CAAC,EAAE,EAAE,CAAC,EAAE,IAAI,CAAC,MAAM,EAAE,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACzF,CAAC;QACD,OAAO,CAAC,OAAO,EAAE,CAAC;QAElB,kBAAkB;QAClB,IAAI,OAAO,GAAG,wBAAwB,CAAC,CAAC,EAAE,CAAC,GAAG,MAAM,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;QAChE,CAAC;YACG,MAAM,IAAI,GAAG,mBAAmB,CAAC,CAAC,EAAE,IAAI,WAAW,CAAC,CAAC,CAAC,GAAG,MAAM,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC;YAC1E,MAAM,EAAE,GAAK,eAAe,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,OAAO,CAAE,EACpD,CAAC,UAAU,EAAE,IAAI,EAAE,OAAO,EAAE,IAAI,CAAC,CAAC,CAAC;YACvC,cAAc,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,OAAO,CAAE,EAAE,EAAE,EAAE,CAAC,IAAI,CAAC,CAAC,GAAG,MAAM,EAAE,GAAG,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACnF,CAAC;QACD,UAAU,CAAC,OAAO,EAAE,CAAC;QAErB,2BAA2B;QAC3B,IAAI,MAAM,EAAE,CAAC;YACT,MAAM,EAAE,OAAO,EAAE,GAAG,IAAI,CAAC,MAAM,CAAC;YAChC,MAAM,MAAM,GAAG,MAAM,GAAG,OAAO,CAAC;YAEhC,MAAM,OAAO,GAAG,wBAAwB,CAAC,CAAC,EAAE,CAAC,GAAG,MAAM,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;YAClE,CAAC;gBACG,MAAM,MAAM,GAAG,IAAI,WAAW,CAAC,CAAC,CAAC,EAAE,MAAM,EAAE,MAAM,CAAC,CAAC,CAAC,MAAM,CAAC;gBAC3D,MAAM,IAAI,GAAK,mBAAmB,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC;gBAC9C,MAAM,EAAE,GAAG,eAAe,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,QAAQ,CAAE,EACnD,CAAC,IAAI,EAAE,OAAO,EAAE,IAAI,CAAC,UAAU,CAAC,OAAO,CAAE,EAAE,IAAI,CAAC,UAAU,CAAC,OAAO,CAAE,EAAE,OAAO,CAAC,CAAC,CAAC;gBACpF,cAAc,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,QAAQ,CAAE,EAAE,EAAE,EAAE,CAAC,IAAI,CAAC,CAAC,EAAE,EAAE,CAAC,EAAE,IAAI,CAAC,MAAM,EAAE,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;YACzF,CAAC;YAED,MAAM,OAAO,GAAG,wBAAwB,CAAC,CAAC,EAAE,CAAC,GAAG,MAAM,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;YAClE,CAAC;gBACG,MAAM,IAAI,GAAG,mBAAmB,CAAC,CAAC,EAAE,IAAI,WAAW,CAAC,CAAC,CAAC,GAAG,MAAM,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC;gBAC1E,MAAM,EAAE,GAAK,eAAe,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,MAAM,CAAE,EACnD,CAAC,IAAI,EAAE,OAAO,EAAE,OAAO,CAAC,CAAC,CAAC;gBAC9B,cAAc,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,MAAM,CAAE,EAAE,EAAE,EAAE,CAAC,IAAI,CAAC,CAAC,GAAG,MAAM,EAAE,GAAG,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;YAClF,CAAC;YACD,OAAO,CAAC,
OAAO,EAAE,CAAC;YAElB,MAAM,OAAO,GAAG,wBAAwB,CAAC,CAAC,EAAE,CAAC,GAAG,MAAM,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;YAClE,CAAC;gBACG,MAAM,MAAM,GAAG,IAAI,WAAW,CAAC,CAAC,CAAC,EAAE,MAAM,EAAE,MAAM,CAAC,CAAC,CAAC,MAAM,CAAC;gBAC3D,MAAM,IAAI,GAAK,mBAAmB,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC;gBAC9C,MAAM,EAAE,GAAG,eAAe,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,QAAQ,CAAE,EACnD,CAAC,IAAI,EAAE,OAAO,EAAE,IAAI,CAAC,UAAU,CAAC,OAAO,CAAE,EAAE,IAAI,CAAC,UAAU,CAAC,OAAO,CAAE,EAAE,OAAO,CAAC,CAAC,CAAC;gBACpF,cAAc,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,QAAQ,CAAE,EAAE,EAAE,EAAE,CAAC,IAAI,CAAC,CAAC,EAAE,EAAE,CAAC,EAAE,IAAI,CAAC,MAAM,EAAE,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;YACzF,CAAC;YACD,OAAO,CAAC,OAAO,EAAE,CAAC;YAElB,MAAM,SAAS,GAAG,wBAAwB,CAAC,CAAC,EAAE,CAAC,GAAG,MAAM,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;YACpE,CAAC;gBACG,MAAM,IAAI,GAAG,mBAAmB,CAAC,CAAC,EAAE,IAAI,WAAW,CAAC,CAAC,CAAC,GAAG,MAAM,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC;gBAC1E,MAAM,EAAE,GAAK,eAAe,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,OAAO,CAAE,EACpD,CAAC,OAAO,EAAE,OAAO,EAAE,SAAS,EAAE,IAAI,CAAC,CAAC,CAAC;gBACzC,cAAc,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,OAAO,CAAE,EAAE,EAAE,EAAE,CAAC,IAAI,CAAC,CAAC,GAAG,MAAM,EAAE,GAAG,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;YACnF,CAAC;YACD,OAAO,CAAC,OAAO,EAAE,CAAC;YAClB,OAAO,CAAC,OAAO,EAAE,CAAC;YAClB,OAAO,GAAG,SAAS,CAAC;QACxB,CAAC;QAED,MAAM,KAAK,GAAmB,EAAE,MAAM,EAAE,CAAC;QACzC,OAAO,EAAE,MAAM,EAAE,OAAO,EAAE,KAAK,EAAE,CAAC;IACtC,CAAC;IAED,UAAU;QACN,MAAM,EAAE,MAAM,EAAE,MAAM,EAAE,OAAO,EAAE,GAAG,IAAI,CAAC,MAAM,CAAC;QAChD,MAAM,MAAM,GAAiB;YACzB,EAAE,GAAG,EAAE,IAAI,CAAC,UAAU,CAAC,MAAM,CAAE,EAAO,KAAK,EAAE,CAAC,GAAG,MAAM,GAAG,MAAM,EAAE,IAAI,EAAE,MAAM,EAAO;YACrF,EAAE,GAAG,EAAE,IAAI,CAAC,UAAU,CAAC,MAAM,CAAE,EAAO,KAAK,EAAE,CAAC,GAAG,MAAM,EAAW,IAAI,EAAE,MAAM,EAAO;YACrF,EAAE,GAAG,EAAE,IAAI,CAAC,UAAU,CAAC,IAAI,CAAE,EAAU,KAAK,EAAE,MAAM,GAAG,MAAM,EAAM,IAAI,EAAE,IAAI,EAAS;YACtF,EAAE,GAAG,EAAE,IAAI,CAAC,UAAU,CAAC,IAAI,CAAE,EAAU,KAAK,EAAE,MAAM,EAAe,IAAI,EAAE,IAAI,EAAS;YACtF,EAAE,GAAG,EAAE,IAAI,CAAC,UAAU,CAAC,YAAY,CAAE,EAAE,KAAK,EAAE,MAAM,EAAe,IAAI,EAAE,YAAY,EAAC;SACzF,CAA
C;QAEF,IAAI,MAAM,EAAE,CAAC;YACT,MAAM,MAAM,GAAG,MAAM,GAAG,OAAO,CAAC;YAChC,MAAM,CAAC,IAAI,CACP,EAAE,GAAG,EAAE,IAAI,CAAC,UAAU,CAAC,OAAO,CAAE,EAAE,KAAK,EAAE,MAAM,GAAG,MAAM,EAAE,IAAI,EAAE,OAAO,EAAE,EACzE,EAAE,GAAG,EAAE,IAAI,CAAC,UAAU,CAAC,OAAO,CAAE,EAAE,KAAK,EAAE,MAAM,EAAW,IAAI,EAAE,OAAO,EAAE,EACzE,EAAE,GAAG,EAAE,IAAI,CAAC,UAAU,CAAC,OAAO,CAAE,EAAE,KAAK,EAAE,MAAM,GAAG,MAAM,EAAE,IAAI,EAAE,OAAO,EAAE,EACzE,EAAE,GAAG,EAAE,IAAI,CAAC,UAAU,CAAC,OAAO,CAAE,EAAE,KAAK,EAAE,MAAM,EAAW,IAAI,EAAE,OAAO,EAAE,CAC5E,CAAC;QACN,CAAC;QAED,OAAO,MAAM,CAAC;IAClB,CAAC;IAED,kBAAkB;QACd,6DAA6D;QAC7D,OAAO,IAAI,CAAC,UAAU,EAAE,CAAC;IAC7B,CAAC;IAED,WAAW,CAAC,QAAiB;QACzB,2CAA2C;IAC/C,CAAC;IAED,OAAO;QACH,KAAK,MAAM,GAAG,IAAI,MAAM,CAAC,MAAM,CAAC,IAAI,CAAC,UAAU,CAAC;YAAE,GAAG,CAAC,OAAO,EAAE,CAAC;QAChE,IAAI,CAAC,UAAU,GAAG,EAAE,CAAC;IACzB,CAAC;CACJ"}
@@ -0,0 +1,70 @@
1
+ /**
2
+ * mamba1_block.ts – Mamba-1 Mixer Block (S6 selective scan).
3
+ *
4
+ * Renamed from mamba_block.ts; MambaBlock is kept as a deprecated alias.
5
+ * Implements SequenceLayer so HybridMambaModel can iterate blocks generically.
6
+ */
7
+ import type { SequenceLayer, LayerForwardResult, LayerParam } from './sequence_layer.js';
8
+ export interface Mamba1BlockConfig {
9
+ dModel: number;
10
+ dState?: number;
11
+ dConv?: number;
12
+ expand?: number;
13
+ dtRank?: number;
14
+ biasConv?: boolean;
15
+ }
16
+ /** @deprecated Use LayerParam */
17
+ export type BlockParam = LayerParam;
18
+ export interface BlockCache {
19
+ normInv: GPUBuffer;
20
+ normIn: GPUBuffer;
21
+ normOut: GPUBuffer;
22
+ zBuf: GPUBuffer;
23
+ xConvIn: GPUBuffer;
24
+ convOut: GPUBuffer;
25
+ siluOut: GPUBuffer;
26
+ deltaFull: GPUBuffer;
27
+ B_raw: GPUBuffer;
28
+ C_raw: GPUBuffer;
29
+ hCache: GPUBuffer;
30
+ }
31
+ export interface BlockForwardResult extends LayerForwardResult {
32
+ output: GPUBuffer;
33
+ cache: BlockCache;
34
+ }
35
+ export declare class Mamba1Block implements SequenceLayer {
36
+ readonly layerType: "mamba1";
37
+ device: GPUDevice;
38
+ config: Required<Mamba1BlockConfig>;
39
+ dInner: number;
40
+ dtRank: number;
41
+ wInProj: Float32Array;
42
+ bInProj: Float32Array;
43
+ wConv: Float32Array;
44
+ bConv: Float32Array;
45
+ wXProj: Float32Array;
46
+ bXProj: Float32Array;
47
+ wDtProj: Float32Array;
48
+ bDtProj: Float32Array;
49
+ A_log: Float32Array;
50
+ D_vec: Float32Array;
51
+ wOutProj: Float32Array;
52
+ bOutProj: Float32Array;
53
+ normWeight: Float32Array;
54
+ gpuWeights: Record<string, GPUBuffer>;
55
+ pipelines: Record<string, GPUComputePipeline>;
56
+ private _wslaMode;
57
+ constructor(device: GPUDevice, config: Mamba1BlockConfig);
58
+ private _initWeights;
59
+ private _uploadWeightsToGPU;
60
+ private _buildPipelines;
61
+ forward(xBuf: GPUBuffer, batch: number, seqLen: number): BlockForwardResult;
62
+ parameters(): LayerParam[];
63
+ getTrainableParams(): LayerParam[];
64
+ setWSLAMode(enabled: boolean): void;
65
+ destroy(): void;
66
+ }
67
+ export { Mamba1Block as MambaBlock };
68
+ /** @deprecated Use Mamba1BlockConfig */
69
+ export type MambaBlockConfig = Mamba1BlockConfig;
70
+ //# sourceMappingURL=mamba1_block.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"mamba1_block.d.ts","sourceRoot":"","sources":["../../src/model/mamba1_block.ts"],"names":[],"mappings":"AAAA;;;;;GAKG;AAkBH,OAAO,KAAK,EAAE,aAAa,EAAE,kBAAkB,EAAE,UAAU,EAAE,MAAM,qBAAqB,CAAC;AAEzF,MAAM,WAAW,iBAAiB;IAC9B,MAAM,EAAK,MAAM,CAAC;IAClB,MAAM,CAAC,EAAI,MAAM,CAAC;IAClB,KAAK,CAAC,EAAK,MAAM,CAAC;IAClB,MAAM,CAAC,EAAI,MAAM,CAAC;IAClB,MAAM,CAAC,EAAI,MAAM,CAAC;IAClB,QAAQ,CAAC,EAAE,OAAO,CAAC;CACtB;AAED,iCAAiC;AACjC,MAAM,MAAM,UAAU,GAAG,UAAU,CAAC;AAEpC,MAAM,WAAW,UAAU;IACvB,OAAO,EAAK,SAAS,CAAC;IACtB,MAAM,EAAM,SAAS,CAAC;IACtB,OAAO,EAAK,SAAS,CAAC;IACtB,IAAI,EAAQ,SAAS,CAAC;IACtB,OAAO,EAAK,SAAS,CAAC;IACtB,OAAO,EAAK,SAAS,CAAC;IACtB,OAAO,EAAK,SAAS,CAAC;IACtB,SAAS,EAAG,SAAS,CAAC;IACtB,KAAK,EAAO,SAAS,CAAC;IACtB,KAAK,EAAO,SAAS,CAAC;IACtB,MAAM,EAAM,SAAS,CAAC;CACzB;AAED,MAAM,WAAW,kBAAmB,SAAQ,kBAAkB;IAC1D,MAAM,EAAG,SAAS,CAAC;IACnB,KAAK,EAAI,UAAU,CAAC;CACvB;AA8BD,qBAAa,WAAY,YAAW,aAAa;IAC7C,QAAQ,CAAC,SAAS,EAAG,QAAQ,CAAU;IAEvC,MAAM,EAAI,SAAS,CAAC;IACpB,MAAM,EAAI,QAAQ,CAAC,iBAAiB,CAAC,CAAC;IACtC,MAAM,EAAI,MAAM,CAAC;IACjB,MAAM,EAAI,MAAM,CAAC;IAEjB,OAAO,EAAK,YAAY,CAAC;IACzB,OAAO,EAAK,YAAY,CAAC;IACzB,KAAK,EAAO,YAAY,CAAC;IACzB,KAAK,EAAO,YAAY,CAAC;IACzB,MAAM,EAAM,YAAY,CAAC;IACzB,MAAM,EAAM,YAAY,CAAC;IACzB,OAAO,EAAK,YAAY,CAAC;IACzB,OAAO,EAAK,YAAY,CAAC;IACzB,KAAK,EAAO,YAAY,CAAC;IACzB,KAAK,EAAO,YAAY,CAAC;IACzB,QAAQ,EAAI,YAAY,CAAC;IACzB,QAAQ,EAAI,YAAY,CAAC;IACzB,UAAU,EAAE,YAAY,CAAC;IAEzB,UAAU,EAAG,MAAM,CAAC,MAAM,EAAE,SAAS,CAAC,CAAC;IACvC,SAAS,EAAI,MAAM,CAAC,MAAM,EAAE,kBAAkB,CAAC,CAAC;IAEhD,OAAO,CAAC,SAAS,CAAS;gBAEd,MAAM,EAAE,SAAS,EAAE,MAAM,EAAE,iBAAiB;IAmCxD,OAAO,CAAC,YAAY;IAoCpB,OAAO,CAAC,mBAAmB;IAqB3B,OAAO,CAAC,eAAe;IAcvB,OAAO,CAAC,IAAI,EAAE,SAAS,EAAE,KAAK,EAAE,MAAM,EAAE,MAAM,EAAE,MAAM,GAAG,kBAAkB;IAyK3E,UAAU,IAAI,UAAU,EAAE;IAwB1B,kBAAkB,IAAI,UAAU,EAAE;IAUlC,WAAW,CAAC,OAAO,EAAE,OAAO,GAAG,IAAI;IAInC,OAAO,IAAI,IAAI;CAMlB;AAGD,OAAO,EAAE,WAAW,IAAI,UAAU,EAAE,CAAC;AAErC,wCAAwC;AACxC,MAAM,MAAM,gBAAgB,GAAG,iBAAiB,CAAC"}