@seanhogg/builderforce-memory-engine 2026.6.18
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +21 -0
- package/README.md +393 -0
- package/dist/index.d.ts +32 -0
- package/dist/index.d.ts.map +1 -0
- package/dist/index.js +40 -0
- package/dist/index.js.map +1 -0
- package/dist/kernels/activations.d.ts +5 -0
- package/dist/kernels/activations.d.ts.map +1 -0
- package/dist/kernels/activations.js +171 -0
- package/dist/kernels/activations.js.map +1 -0
- package/dist/kernels/attention.d.ts +19 -0
- package/dist/kernels/attention.d.ts.map +1 -0
- package/dist/kernels/attention.js +263 -0
- package/dist/kernels/attention.js.map +1 -0
- package/dist/kernels/complex_ssd.d.ts +33 -0
- package/dist/kernels/complex_ssd.d.ts.map +1 -0
- package/dist/kernels/complex_ssd.js +305 -0
- package/dist/kernels/complex_ssd.js.map +1 -0
- package/dist/kernels/conv1d.d.ts +3 -0
- package/dist/kernels/conv1d.d.ts.map +1 -0
- package/dist/kernels/conv1d.js +158 -0
- package/dist/kernels/conv1d.js.map +1 -0
- package/dist/kernels/linear_projection.d.ts +3 -0
- package/dist/kernels/linear_projection.d.ts.map +1 -0
- package/dist/kernels/linear_projection.js +219 -0
- package/dist/kernels/linear_projection.js.map +1 -0
- package/dist/kernels/selective_scan.d.ts +3 -0
- package/dist/kernels/selective_scan.d.ts.map +1 -0
- package/dist/kernels/selective_scan.js +348 -0
- package/dist/kernels/selective_scan.js.map +1 -0
- package/dist/kernels/ssd.d.ts +29 -0
- package/dist/kernels/ssd.d.ts.map +1 -0
- package/dist/kernels/ssd.js +276 -0
- package/dist/kernels/ssd.js.map +1 -0
- package/dist/kernels/weight_update.d.ts +3 -0
- package/dist/kernels/weight_update.d.ts.map +1 -0
- package/dist/kernels/weight_update.js +119 -0
- package/dist/kernels/weight_update.js.map +1 -0
- package/dist/model/attention_block.d.ts +48 -0
- package/dist/model/attention_block.d.ts.map +1 -0
- package/dist/model/attention_block.js +262 -0
- package/dist/model/attention_block.js.map +1 -0
- package/dist/model/mamba1_block.d.ts +70 -0
- package/dist/model/mamba1_block.d.ts.map +1 -0
- package/dist/model/mamba1_block.js +333 -0
- package/dist/model/mamba1_block.js.map +1 -0
- package/dist/model/mamba2_block.d.ts +44 -0
- package/dist/model/mamba2_block.d.ts.map +1 -0
- package/dist/model/mamba2_block.js +252 -0
- package/dist/model/mamba2_block.js.map +1 -0
- package/dist/model/mamba3_block.d.ts +51 -0
- package/dist/model/mamba3_block.d.ts.map +1 -0
- package/dist/model/mamba3_block.js +270 -0
- package/dist/model/mamba3_block.js.map +1 -0
- package/dist/model/mamba_block.d.ts +64 -0
- package/dist/model/mamba_block.d.ts.map +1 -0
- package/dist/model/mamba_block.js +303 -0
- package/dist/model/mamba_block.js.map +1 -0
- package/dist/model/mamba_model.d.ts +140 -0
- package/dist/model/mamba_model.d.ts.map +1 -0
- package/dist/model/mamba_model.js +527 -0
- package/dist/model/mamba_model.js.map +1 -0
- package/dist/model/sequence_layer.d.ts +25 -0
- package/dist/model/sequence_layer.d.ts.map +1 -0
- package/dist/model/sequence_layer.js +8 -0
- package/dist/model/sequence_layer.js.map +1 -0
- package/dist/tokenizer/bpe.d.ts +29 -0
- package/dist/tokenizer/bpe.d.ts.map +1 -0
- package/dist/tokenizer/bpe.js +164 -0
- package/dist/tokenizer/bpe.js.map +1 -0
- package/dist/training/autograd.d.ts +27 -0
- package/dist/training/autograd.d.ts.map +1 -0
- package/dist/training/autograd.js +120 -0
- package/dist/training/autograd.js.map +1 -0
- package/dist/training/trainer.d.ts +36 -0
- package/dist/training/trainer.d.ts.map +1 -0
- package/dist/training/trainer.js +183 -0
- package/dist/training/trainer.js.map +1 -0
- package/dist/utils/gpu_utils.d.ts +21 -0
- package/dist/utils/gpu_utils.d.ts.map +1 -0
- package/dist/utils/gpu_utils.js +111 -0
- package/dist/utils/gpu_utils.js.map +1 -0
- package/dist/utils/quantization.d.ts +26 -0
- package/dist/utils/quantization.d.ts.map +1 -0
- package/dist/utils/quantization.js +116 -0
- package/dist/utils/quantization.js.map +1 -0
- package/dist/utils/rng.d.ts +36 -0
- package/dist/utils/rng.d.ts.map +1 -0
- package/dist/utils/rng.js +61 -0
- package/dist/utils/rng.js.map +1 -0
- package/package.json +99 -0
- package/src/index.ts +114 -0
- package/src/kernels/activations.ts +174 -0
- package/src/kernels/attention.ts +268 -0
- package/src/kernels/complex_ssd.ts +307 -0
- package/src/kernels/conv1d.ts +159 -0
- package/src/kernels/linear_projection.ts +220 -0
- package/src/kernels/selective_scan.ts +350 -0
- package/src/kernels/ssd.ts +278 -0
- package/src/kernels/weight_update.ts +120 -0
- package/src/model/attention_block.ts +344 -0
- package/src/model/mamba1_block.ts +437 -0
- package/src/model/mamba2_block.ts +319 -0
- package/src/model/mamba3_block.ts +335 -0
- package/src/model/mamba_block.ts +401 -0
- package/src/model/mamba_model.ts +678 -0
- package/src/model/sequence_layer.ts +29 -0
- package/src/tokenizer/bpe.ts +186 -0
- package/src/training/autograd.ts +135 -0
- package/src/training/trainer.ts +309 -0
- package/src/utils/gpu_utils.ts +147 -0
- package/src/utils/quantization.ts +154 -0
- package/src/utils/rng.ts +65 -0
|
@@ -0,0 +1,348 @@
|
|
|
1
|
+
// Parallel Selective Scan WGSL Kernel
|
|
2
|
+
// Implements the S6 (Selective Scan) core of the Mamba architecture.
|
|
3
|
+
// Uses a Kogge-Stone parallel prefix scan inside 64-step tiles (log2(64) combine
// steps per tile); tiles are processed in order, carrying the state between them.
|
|
4
|
+
//
|
|
5
|
+
// Forward pass recurrence:
|
|
6
|
+
// h_t = A_t * h_{t-1} + B_t * x_t
|
|
7
|
+
// y_t = C_t * h_t + D * x_t
|
|
8
|
+
//
|
|
9
|
+
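// where A_t = exp(Δ_t · A) and B_t = (A_t - 1) / A · B are the zero-order-hold
// discretisations of a learned diagonal A and an input-dependent B, Δ_t is an
// input-dependent time step, and C_t is input-dependent (selective).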
|
|
10
|
+
export const SELECTIVE_SCAN_FORWARD_WGSL = /* wgsl */ `

// ---- Binding layout ----
// All buffers live in bind group 0.

struct ScanParams {
  seq_len : u32, // L – sequence length
  d_state : u32, // N – state dimension
  d_inner : u32, // D – inner (expanded) channel dimension
  batch   : u32, // B – batch size
};

// Tile length of the in-workgroup scan; also the forward_scan workgroup size.
const TILE : u32 = 64u;

@group(0) @binding(0) var<uniform> params : ScanParams;
// u       (B, L, D) – projected input after conv
@group(0) @binding(1) var<storage, read> u : array<f32>;
// delta   (B, L, D) – time-step (Δ) before softplus; softplus is applied here
@group(0) @binding(2) var<storage, read> delta : array<f32>;
// A       (D, N)    – log-space diagonal state matrix; continuous A = -exp(A)
@group(0) @binding(3) var<storage, read> A : array<f32>;
// B       (B, L, N) – input projection (selective)
@group(0) @binding(4) var<storage, read> B : array<f32>;
// C       (B, L, N) – output projection (selective)
@group(0) @binding(5) var<storage, read> C : array<f32>;
// D_vec   (D,)      – skip-connection scale
@group(0) @binding(6) var<storage, read> D_vec : array<f32>;
// y       (B, L, D) – output, written by forward_reduce
@group(0) @binding(7) var<storage, read_write> y : array<f32>;
// h_cache – must hold 2·B·L·D·N floats:
//   [0, B·L·D·N)          hidden states h, layout (B, L, D, N), kept for backward
//   [B·L·D·N, 2·B·L·D·N)  per-state output partials C·h, layout (B, L, D, N)
@group(0) @binding(8) var<storage, read_write> h_cache : array<f32>;

// ---- Workgroup shared memory ----
// Affine-map pair (a_bar, bu) for each time step of the current tile.
var<workgroup> wg_a : array<f32, TILE>;
var<workgroup> wg_bu : array<f32, TILE>;

// ---- Helpers ----

// Softplus log(1 + exp(x)), rewritten so exp never overflows for large x.
fn softplus(x: f32) -> f32 {
  return max(x, 0.0) + log(1.0 + exp(-abs(x)));
}

// Zero-order-hold discretisation of the continuous A = -exp(a_log) < 0:
//   A_bar = exp(Δ · A), which lies in (0, 1] for Δ >= 0.
fn discretise_A(delta_val: f32, a_log: f32) -> f32 {
  let a_cont = -exp(a_log);
  return exp(delta_val * a_cont);
}

// Exact zero-order-hold input term: B_bar = (A_bar - 1) / A · B.
fn discretise_B(delta_val: f32, a_log: f32, b_val: f32) -> f32 {
  let a_cont = -exp(a_log);
  let a_bar = exp(delta_val * a_cont);
  return (a_bar - 1.0) / a_cont * b_val;
}

// ---- Main kernel ----
// Dispatch: (D, N, B) workgroups of TILE invocations.
// Workgroup (d, n, b) scans the whole sequence for one state element,
// TILE time steps at a time, carrying h across tiles.
@compute @workgroup_size(TILE, 1, 1)
fn forward_scan(
  @builtin(local_invocation_index) lid : u32,
  @builtin(workgroup_id) wgid : vec3<u32>,
) {
  let L = params.seq_len;
  let N = params.d_state;
  let D = params.d_inner;
  // Not named B: that would shadow the B storage buffer read below.
  let n_batch = params.batch;

  // All three indices come from workgroup_id, so this early return is
  // workgroup-uniform and the barriers below stay in uniform control flow.
  let d = wgid.x;
  let n = wgid.y;
  let b = wgid.z;
  if (d >= D || n >= N || b >= n_batch) { return; }

  let a_log = A[d * N + n];
  let partial_base = n_batch * L * D * N;

  // Carried state h at the end of the previous tile.
  var h : f32 = 0.0;

  for (var tile_start : u32 = 0u; tile_start < L; tile_start = tile_start + TILE) {
    let t = tile_start + lid;

    // Lanes past the end of the sequence hold the identity map (1, 0).
    var a_bar : f32 = 1.0;
    var bu : f32 = 0.0;
    if (t < L) {
      let ud_idx = (b * L + t) * D + d;
      let bc_idx = (b * L + t) * N + n;
      let dv = softplus(delta[ud_idx]);
      a_bar = discretise_A(dv, a_log);
      bu = discretise_B(dv, a_log, B[bc_idx]) * u[ud_idx];
    }
    wg_a[lid] = a_bar;
    wg_bu[lid] = bu;
    workgroupBarrier();

    // Kogge-Stone inclusive scan. Step k is the affine map h -> a_k·h + bu_k;
    // applying an earlier prefix (a_p, bu_p) and then (a_k, bu_k) gives
    //   h -> (a_k·a_p)·h + (a_k·bu_p + bu_k).
    // Every lane reads, all lanes sync, every lane writes, all lanes sync:
    // barriers must be reached by the whole workgroup.
    for (var stride : u32 = 1u; stride < TILE; stride = stride << 1u) {
      var cur_a = wg_a[lid];
      var cur_bu = wg_bu[lid];
      if (lid >= stride) {
        let prev_a = wg_a[lid - stride];
        let prev_bu = wg_bu[lid - stride];
        cur_bu = cur_a * prev_bu + cur_bu;
        cur_a = cur_a * prev_a;
      }
      workgroupBarrier();
      wg_a[lid] = cur_a;
      wg_bu[lid] = cur_bu;
      workgroupBarrier();
    }

    // Slot lid now composes steps tile_start..t, so applying it to the
    // carried state yields h_t.
    let h_t = wg_a[lid] * h + wg_bu[lid];

    if (t < L) {
      let state_idx = ((b * L + t) * D + d) * N + n;
      h_cache[state_idx] = h_t;
      // Partial output for this state dim; forward_reduce sums over N.
      h_cache[partial_base + state_idx] = C[(b * L + t) * N + n] * h_t;
    }

    // The last valid lane's prefix composes the whole tile: advance the carry.
    let last = min(TILE, L - tile_start) - 1u;
    h = wg_a[last] * h + wg_bu[last];

    // Keep the next tile's shared-memory writes from racing with the reads above.
    workgroupBarrier();
  }
}

// ---- Reduction kernel ----
// Collapses the N (d_state) axis of the output partials into y and adds the
// skip term: y[b, t, d] = sum_n partial[b, t, d, n] + D_vec[d] · u[b, t, d].
// Dispatch: (ceil(L/64), D, B)
@compute @workgroup_size(64, 1, 1)
fn forward_reduce(
  @builtin(global_invocation_id) gid : vec3<u32>,
) {
  let L = params.seq_len;
  let N = params.d_state;
  let D = params.d_inner;
  let n_batch = params.batch;

  let t = gid.x;
  let d = gid.y;
  let b = gid.z;
  if (t >= L || d >= D || b >= n_batch) { return; }

  let row = (b * L + t) * D + d;
  let partials = n_batch * L * D * N + row * N;
  var sum : f32 = 0.0;
  for (var n : u32 = 0u; n < N; n = n + 1u) {
    sum = sum + h_cache[partials + n];
  }

  y[row] = sum + D_vec[d] * u[row];
}
`;
|
|
225
|
+
// ---- Backward scan kernel (for autograd) ----
|
|
226
|
+
// Computes gradients w.r.t. Δ (pre-softplus), A (log-space), B, C and u using
// the hidden states cached by forward_scan.
|
|
227
|
+
export const SELECTIVE_SCAN_BACKWARD_WGSL = /* wgsl */ `

struct ScanParams {
  seq_len : u32,
  d_state : u32,
  d_inner : u32,
  batch   : u32,
};

@group(0) @binding(0) var<uniform> params : ScanParams;
@group(0) @binding(1) var<storage, read> u : array<f32>;             // (B, L, D)
@group(0) @binding(2) var<storage, read> delta : array<f32>;         // (B, L, D), pre-softplus
@group(0) @binding(3) var<storage, read> A : array<f32>;             // (D, N), log-space
@group(0) @binding(4) var<storage, read> B : array<f32>;             // (B, L, N)
@group(0) @binding(5) var<storage, read> C : array<f32>;             // (B, L, N)
@group(0) @binding(6) var<storage, read> h_cache : array<f32>;       // hidden states (B, L, D, N)
@group(0) @binding(7) var<storage, read> dy : array<f32>;            // upstream gradient (B, L, D)
@group(0) @binding(8) var<storage, read_write> dA : array<f32>;      // (D, N)
@group(0) @binding(9) var<storage, read_write> dB : array<f32>;      // (B, L, N)
@group(0) @binding(10) var<storage, read_write> dC : array<f32>;     // (B, L, N)
@group(0) @binding(11) var<storage, read_write> dDelta : array<f32>; // (B, L, D), w.r.t. pre-softplus delta
@group(0) @binding(12) var<storage, read_write> du : array<f32>;     // (B, L, D)

// Softplus log(1 + exp(x)), rewritten so exp never overflows for large x.
fn softplus(x: f32) -> f32 {
  return max(x, 0.0) + log(1.0 + exp(-abs(x)));
}

// d/dx softplus(x) = sigmoid(x)
fn softplus_grad(x: f32) -> f32 {
  return 1.0 / (1.0 + exp(-x));
}

// Reverse scan over time, t = L-1 down to 0, for one (d, n, b) element.
// Dispatch: (D, N, B)
//
// Gradients are accumulated with read-modify-write (+=), so the gradient
// buffers are assumed to be zeroed before dispatch.
// NOTE(review): several invocations update the same gradient slot -
//   du and dDelta  [b, t, d] across n,
//   dB and dC      [b, t, n] across d,
//   dA             [d, n]    across b -
// and nothing here orders those updates, so concurrent invocations can lose
// contributions. Per-invocation partial buffers plus a reduction pass, or
// atomics, are needed for exact gradients - TODO.
// NOTE(review): the skip path y += D_vec · u is not differentiated here
// (D_vec is not bound), so du lacks the D_vec · dy term and no dD is produced.
@compute @workgroup_size(1, 1, 1)
fn backward_scan(
  @builtin(global_invocation_id) gid : vec3<u32>,
) {
  let L = params.seq_len;
  let N = params.d_state;
  let D = params.d_inner;
  // Not named B: that would shadow the B storage buffer read below.
  let n_batch = params.batch;

  let d = gid.x;
  let n = gid.y;
  let b = gid.z;
  if (d >= D || n >= N || b >= n_batch) { return; }

  // The continuous A is the same for every time step of this (d, n) element.
  let A_idx = d * N + n;
  let a_cont = -exp(A[A_idx]);

  // dL/dh_t, carried backwards through h_t = a_bar_t · h_{t-1} + b_bar_t · u_t.
  var dh : f32 = 0.0;

  var t : u32 = L;
  loop {
    if (t == 0u) { break; }
    t = t - 1u;

    let ud_idx = (b * L + t) * D + d; // element of a (B, L, D) buffer
    let bc_idx = (b * L + t) * N + n; // element of a (B, L, N) buffer
    let h_idx = ud_idx * N + n;       // element of the (B, L, D, N) state cache

    let delta_raw = delta[ud_idx];
    let dv = softplus(delta_raw);
    let a_bar = exp(dv * a_cont);
    let b_val = B[bc_idx];
    let c_val = C[bc_idx];
    let u_val = u[ud_idx];
    let h_t = h_cache[h_idx];

    // State before step t (zero at t = 0). WGSL has no conditional operator,
    // and select() would evaluate the underflowing t - 1 index, so branch.
    var h_prev : f32 = 0.0;
    if (t > 0u) {
      h_prev = h_cache[h_idx - D * N];
    }

    // y_t[d] = sum_n C_t[n] · h_t[n] + D · u_t[d]
    //   => dh_t += C_t[n] · dy_t[d],  dC_t[n] += dy_t[d] · h_t
    let dy_val = dy[ud_idx];
    dh = dh + c_val * dy_val;
    dC[bc_idx] = dC[bc_idx] + dy_val * h_t;

    // b_bar = (a_bar - 1) / a_cont · b_val; b_coef = d(b_bar)/d(b_val).
    let b_coef = (a_bar - 1.0) / a_cont;
    let b_bar = b_coef * b_val;

    // dA w.r.t. the log-space entry (d(a_cont)/d(a_log) = a_cont), through
    // both a_bar (h_prev term) and b_bar (input term):
    //   d(a_bar)/d(a_log) = a_bar · a_cont · dv
    //   d(b_bar)/d(a_log) = b_val · (dv · a_bar - (a_bar - 1) / a_cont)
    let da_bar_dalog = a_bar * a_cont * dv;
    let db_bar_dalog = b_val * (dv * a_bar - b_coef);
    dA[A_idx] = dA[A_idx] + dh * (da_bar_dalog * h_prev + db_bar_dalog * u_val);

    dB[bc_idx] = dB[bc_idx] + dh * b_coef * u_val;
    du[ud_idx] = du[ud_idx] + dh * b_bar;

    // Δ: d(a_bar)/d(dv) = a_bar · a_cont, d(b_bar)/d(dv) = a_bar · b_val,
    // then chain through softplus.
    let dloss_ddv = dh * (a_bar * a_cont * h_prev + a_bar * b_val * u_val);
    dDelta[ud_idx] = dDelta[ud_idx] + dloss_ddv * softplus_grad(delta_raw);

    // dh_{t-1} = a_bar_t · dh_t
    dh = a_bar * dh;
  }
}
`;
|
|
348
|
+
//# sourceMappingURL=selective_scan.js.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"selective_scan.js","sourceRoot":"","sources":["../../src/kernels/selective_scan.ts"],"names":[],"mappings":"AAAA,sCAAsC;AACtC,qEAAqE;AACrE,4EAA4E;AAC5E,EAAE;AACF,2BAA2B;AAC3B,oCAAoC;AACpC,8BAA8B;AAC9B,EAAE;AACF,qEAAqE;AAErE,MAAM,CAAC,MAAM,2BAA2B,GAAW,UAAU,CAAA;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;CAsN5D,CAAC;AAEF,gDAAgD;AAChD,uEAAuE;AAEvE,MAAM,CAAC,MAAM,4BAA4B,GAAW,UAAU,CAAA;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;CAwH7D,CAAC"}
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* ssd.ts – Structured State Space Duality (SSD) kernels for Mamba-2.
|
|
3
|
+
*
|
|
4
|
+
* Implements a chunked SSD algorithm:
|
|
5
|
+
* A_bar_t = exp(-softplus(A_h) · softplus(dt_t + dt_bias_h)) [scalar per head]
|
|
6
|
+
* h_t = A_bar_t · h_{t-1} + B_t · x_t [MIMO per head]
|
|
7
|
+
* y_t = C_t · h_t
|
|
8
|
+
*
|
|
9
|
+
* The sequence is split into chunks of `chunk_len` time steps.
|
|
10
|
+
* Within each chunk the recurrence is run sequentially; the carry-over
|
|
11
|
+
* state `h` is passed forward between chunks via the state_carry buffer.
|
|
12
|
+
*
|
|
13
|
+
* Dispatch for ssd_chunk_forward: (num_chunks, H, B)
|
|
14
|
+
* Dispatch for ssd_chunk_backward: (num_chunks, H, B)
|
|
15
|
+
*
|
|
16
|
+
* Buffer layout (all f32, row-major):
|
|
17
|
+
* x : [B, L, D_inner] where D_inner = H * d_head
|
|
18
|
+
* B_proj : [B, L, n_groups, N]
|
|
19
|
+
* C_proj : [B, L, n_groups, N]
|
|
20
|
+
* dt : [B, L, H]
|
|
21
|
+
* A_log : [H] per-head decay parameter; the kernels use -A = softplus(A_log) > 0
|
|
22
|
+
* dt_bias : [H]
|
|
23
|
+
* D_vec : [H] skip connection per head
|
|
24
|
+
* out : [B, L, D_inner] scan output (written by kernel)
|
|
25
|
+
* state_carry : [num_chunks+1, B, H, N, d_head] inter-chunk states
|
|
26
|
+
*/
|
|
27
|
+
/** WGSL module source whose compute entry point is `ssd_chunk_forward`. */
export declare const SSD_FORWARD_WGSL: string;
|
|
28
|
+
/** WGSL module source whose compute entry point is `ssd_chunk_backward`. */
export declare const SSD_BACKWARD_WGSL: string;
|
|
29
|
+
//# sourceMappingURL=ssd.d.ts.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"ssd.d.ts","sourceRoot":"","sources":["../../src/kernels/ssd.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;;;;;;;;;;;GAyBG;AAEH,eAAO,MAAM,gBAAgB,EAAE,MA+H9B,CAAC;AAIF,eAAO,MAAM,iBAAiB,EAAE,MAuH/B,CAAC"}
|
|
@@ -0,0 +1,276 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* ssd.ts – Structured State Space Duality (SSD) kernels for Mamba-2.
|
|
3
|
+
*
|
|
4
|
+
* Implements a chunked SSD algorithm:
|
|
5
|
+
* A_bar_t = exp(-softplus(A_h) · softplus(dt_t + dt_bias_h)) [scalar per head]
|
|
6
|
+
* h_t = A_bar_t · h_{t-1} + B_t · x_t [MIMO per head]
|
|
7
|
+
* y_t = C_t · h_t
|
|
8
|
+
*
|
|
9
|
+
* The sequence is split into chunks of `chunk_len` time steps.
|
|
10
|
+
* Within each chunk the recurrence is run sequentially; the carry-over
|
|
11
|
+
* state `h` is passed forward between chunks via the state_carry buffer.
|
|
12
|
+
*
|
|
13
|
+
* Dispatch for ssd_chunk_forward: (num_chunks, H, B)
|
|
14
|
+
* Dispatch for ssd_chunk_backward: (num_chunks, H, B)
|
|
15
|
+
*
|
|
16
|
+
* Buffer layout (all f32, row-major):
|
|
17
|
+
* x : [B, L, D_inner] where D_inner = H * d_head
|
|
18
|
+
* B_proj : [B, L, n_groups, N]
|
|
19
|
+
* C_proj : [B, L, n_groups, N]
|
|
20
|
+
* dt : [B, L, H]
|
|
21
|
+
* A_log : [H] per-head decay parameter; the kernels use -A = softplus(A_log) > 0
|
|
22
|
+
* dt_bias : [H]
|
|
23
|
+
* D_vec : [H] skip connection per head
|
|
24
|
+
* out : [B, L, D_inner] scan output (written by kernel)
|
|
25
|
+
* state_carry : [num_chunks+1, B, H, N, d_head] inter-chunk states
|
|
26
|
+
*/
|
|
27
|
+
export const SSD_FORWARD_WGSL = /* wgsl */ `
struct SsdParams {
  seq_len   : u32,
  d_inner   : u32,
  n_heads   : u32,
  d_head    : u32, // d_inner / n_heads
  n_groups  : u32,
  d_state   : u32, // N
  chunk_len : u32,
  n_chunks  : u32,
  batch     : u32,
};

@group(0) @binding(0) var<uniform> params : SsdParams;
@group(0) @binding(1) var<storage, read> x_in : array<f32>;              // [B,L,D_inner]
@group(0) @binding(2) var<storage, read> B_proj : array<f32>;            // [B,L,n_groups,N]
@group(0) @binding(3) var<storage, read> C_proj : array<f32>;            // [B,L,n_groups,N]
@group(0) @binding(4) var<storage, read> dt_in : array<f32>;             // [B,L,H]
@group(0) @binding(5) var<storage, read> A_log : array<f32>;             // [H]
@group(0) @binding(6) var<storage, read> dt_bias : array<f32>;           // [H]
@group(0) @binding(7) var<storage, read> D_vec : array<f32>;             // [H]
@group(0) @binding(8) var<storage, read_write> out_buf : array<f32>;     // [B,L,D_inner]
@group(0) @binding(9) var<storage, read_write> state_carry : array<f32>; // [n_chunks+1,B,H,N,d_head]

// Softplus log(1 + exp(x)), rewritten so exp never overflows for large x.
fn softplus(x: f32) -> f32 {
  return max(x, 0.0) + log(1.0 + exp(-abs(x)));
}

// One invocation per (chunk, head, batch item). Dispatch: (n_chunks, H, B).
//
// Per head the state is an N x d_head matrix h, updated as
//   h_t[n, i] = A_bar_t · h_{t-1}[n, i] + B_t[n] · x_t[i]
//   y_t[i]    = sum_n C_t[n] · h_t[n, i] + D · x_t[i]
// state_carry slot c holds the state before chunk c; this invocation copies
// slot chunk_id into slot chunk_id + 1 and advances it in place.
// NOTE(review): slot chunk_id is only final once chunk chunk_id - 1 has
// finished, but invocations of one dispatch are unordered, so results are
// only well-defined when chunks run in order (e.g. n_chunks == 1) - TODO
// confirm how the host schedules this kernel.
@compute @workgroup_size(1, 1, 1)
fn ssd_chunk_forward(@builtin(global_invocation_id) gid: vec3<u32>) {
  let chunk_id = gid.x;
  let head_id = gid.y;
  let batch_id = gid.z;

  let L = params.seq_len;
  let D = params.d_inner;
  let H = params.n_heads;
  let dh = params.d_head;
  let G = params.n_groups;
  let N = params.d_state;
  let CL = params.chunk_len;
  let NC = params.n_chunks;
  let B = params.batch;

  if (chunk_id >= NC || head_id >= H || batch_id >= B) { return; }

  let t_start = chunk_id * CL;
  let t_end = min(t_start + CL, L);

  // Heads are partitioned evenly across B/C groups.
  let group_id = head_id * G / H;

  let neg_A = softplus(A_log[head_id]);
  let db = dt_bias[head_id];
  let d_skip = D_vec[head_id];

  // state_carry layout: [NC+1, B, H, N*dh]
  let state_size = N * dh;
  let chunk_stride = B * H * state_size;
  let head_offset = batch_id * H * state_size + head_id * state_size;
  let state_base_in = chunk_id * chunk_stride + head_offset;
  let state_base_out = state_base_in + chunk_stride;

  // Initialise the working state (slot chunk_id + 1) from the carry-in.
  for (var s: u32 = 0u; s < state_size; s = s + 1u) {
    state_carry[state_base_out + s] = state_carry[state_base_in + s];
  }

  // Sequential scan over the chunk.
  for (var t: u32 = t_start; t < t_end; t = t + 1u) {
    let dt_val = softplus(dt_in[(batch_id * L + t) * H + head_id] + db);
    let a_bar = exp(-neg_A * dt_val);

    // Head slice of x / out: [batch, t, head*dh .. (head+1)*dh]
    let x_base = (batch_id * L + t) * D + head_id * dh;
    // B and C rows at this step: [batch, t, group_id, 0..N]
    let bc_base = ((batch_id * L + t) * G + group_id) * N;

    // Each output channel i gets its own contraction over the state dim n.
    for (var i: u32 = 0u; i < dh; i = i + 1u) {
      let x_val = x_in[x_base + i];
      var y_acc: f32 = 0.0;
      for (var n: u32 = 0u; n < N; n = n + 1u) {
        let s_idx = state_base_out + n * dh + i;
        let h_new = a_bar * state_carry[s_idx] + B_proj[bc_base + n] * x_val;
        state_carry[s_idx] = h_new;
        y_acc = y_acc + C_proj[bc_base + n] * h_new;
      }
      out_buf[x_base + i] = y_acc + d_skip * x_val;
    }
  }
}
`;
|
|
155
|
+
// ── Backward ──────────────────────────────────────────────────────────────────
|
|
156
|
+
/**
 * WGSL source for the chunked SSD backward pass (entry point `ssd_chunk_backward`).
 *
 * Dispatch: one invocation per (chunk, head, batch), with gid.x = chunk,
 * gid.y = head, gid.z = batch and @workgroup_size(1, 1, 1). Invocations outside
 * (n_chunks, n_heads, batch), or whose chunk starts past seq_len, return early.
 *
 * Every output (bindings 9-14) is accumulated into, never overwritten.
 * NOTE(review): the host presumably zero-fills them before dispatch; confirm in the caller.
 *
 * dB, dC, dA_log and dD_vec are updated at the same index by several invocations:
 * heads sharing a group write the same dB/dC entries, and chunks/batches sharing
 * a head write the same dA_log/dD_vec entry. WGSL has no atomic<f32>, so these
 * buffers are declared as array<atomic<u32>> holding raw f32 bits and are updated
 * with a compare-exchange loop. The element stride is still 4 bytes, so the host
 * can keep reading them as Float32Array.
 *
 * The following simplifications are carried over from the original kernel:
 *  - the state gradient is not propagated backwards through time (dh = C * dy only);
 *  - dC uses the chunk-end state, and ddt/dA_log use the chunk-start state from
 *    state_carry rather than per-step states;
 *  - the skip coefficient D is not bound here, so dx receives dy * 1.
 */
export const SSD_BACKWARD_WGSL = /* wgsl */ `
struct SsdParams {
  seq_len   : u32,
  d_inner   : u32,
  n_heads   : u32,
  d_head    : u32,
  n_groups  : u32,
  d_state   : u32,
  chunk_len : u32,
  n_chunks  : u32,
  batch     : u32,
};

@group(0) @binding(0) var<uniform> params : SsdParams;
@group(0) @binding(1) var<storage, read> x_in : array<f32>;
@group(0) @binding(2) var<storage, read> B_proj : array<f32>;
@group(0) @binding(3) var<storage, read> C_proj : array<f32>;
@group(0) @binding(4) var<storage, read> dt_in : array<f32>;
@group(0) @binding(5) var<storage, read> A_log : array<f32>;
@group(0) @binding(6) var<storage, read> dt_bias : array<f32>;
@group(0) @binding(7) var<storage, read> state_carry : array<f32>; // forward states
@group(0) @binding(8) var<storage, read> dy : array<f32>; // upstream grad
@group(0) @binding(9) var<storage, read_write> dx : array<f32>;
// Shared accumulators: f32 bit patterns stored in atomic<u32> and updated via CAS.
@group(0) @binding(10) var<storage, read_write> dB : array<atomic<u32>>;
@group(0) @binding(11) var<storage, read_write> dC : array<atomic<u32>>;
@group(0) @binding(12) var<storage, read_write> ddt : array<f32>;
@group(0) @binding(13) var<storage, read_write> dA_log : array<atomic<u32>>;
@group(0) @binding(14) var<storage, read_write> dD_vec : array<atomic<u32>>;

// Numerically stable softplus: the naive log(1 + exp(x)) overflows for large x.
fn softplus(x: f32) -> f32 {
  return max(x, 0.0) + log(1.0 + exp(-abs(x)));
}

// Derivative of softplus, i.e. the logistic sigmoid.
fn d_softplus(x: f32) -> f32 {
  return 1.0 / (1.0 + exp(-x));
}

// Float atomic adds emulated with compare-exchange on the raw bits.
// One helper per buffer because storage pointers cannot be passed as
// function parameters without an optional language feature.
fn atomic_add_dB(idx: u32, v: f32) {
  var old_bits = atomicLoad(&dB[idx]);
  loop {
    let new_bits = bitcast<u32>(bitcast<f32>(old_bits) + v);
    let res = atomicCompareExchangeWeak(&dB[idx], old_bits, new_bits);
    if (res.exchanged) { break; }
    old_bits = res.old_value;
  }
}

fn atomic_add_dC(idx: u32, v: f32) {
  var old_bits = atomicLoad(&dC[idx]);
  loop {
    let new_bits = bitcast<u32>(bitcast<f32>(old_bits) + v);
    let res = atomicCompareExchangeWeak(&dC[idx], old_bits, new_bits);
    if (res.exchanged) { break; }
    old_bits = res.old_value;
  }
}

fn atomic_add_dA_log(idx: u32, v: f32) {
  var old_bits = atomicLoad(&dA_log[idx]);
  loop {
    let new_bits = bitcast<u32>(bitcast<f32>(old_bits) + v);
    let res = atomicCompareExchangeWeak(&dA_log[idx], old_bits, new_bits);
    if (res.exchanged) { break; }
    old_bits = res.old_value;
  }
}

fn atomic_add_dD_vec(idx: u32, v: f32) {
  var old_bits = atomicLoad(&dD_vec[idx]);
  loop {
    let new_bits = bitcast<u32>(bitcast<f32>(old_bits) + v);
    let res = atomicCompareExchangeWeak(&dD_vec[idx], old_bits, new_bits);
    if (res.exchanged) { break; }
    old_bits = res.old_value;
  }
}

@compute @workgroup_size(1, 1, 1)
fn ssd_chunk_backward(@builtin(global_invocation_id) gid: vec3<u32>) {
  let chunk_id = gid.x;
  let head_id = gid.y;
  let batch_id = gid.z;

  let L = params.seq_len;
  let D = params.d_inner;
  let H = params.n_heads;
  let dh = params.d_head;
  let G = params.n_groups;
  let N = params.d_state;
  let CL = params.chunk_len;
  let NC = params.n_chunks;
  let B = params.batch;

  let t_start = chunk_id * CL;
  // Without the t_start check, t_end - t_start below wraps around as u32
  // and the time loop runs far out of bounds.
  if (chunk_id >= NC || head_id >= H || batch_id >= B || t_start >= L) {
    return;
  }
  let t_end = min(t_start + CL, L);
  let group_id = head_id * G / H;

  let neg_A = softplus(A_log[head_id]);
  let db = dt_bias[head_id];
  let dsp_A = d_softplus(A_log[head_id]);

  let state_stride = B * H * N * dh;
  let head_state = batch_id * H * N * dh + head_id * N * dh;
  let prev_base = chunk_id * state_stride + head_state;        // state entering this chunk
  let next_base = (chunk_id + 1u) * state_stride + head_state; // state leaving this chunk

  // Per-invocation partial sums for the head-indexed gradients; each is
  // published with a single atomic add at the end.
  var dD_acc: f32 = 0.0;
  var dA_acc: f32 = 0.0;

  // Iterate time steps in reverse within the chunk. dh_next is not carried
  // across steps or chunks (simplified).
  for (var t_rev: u32 = 0u; t_rev < t_end - t_start; t_rev = t_rev + 1u) {
    let t = t_end - 1u - t_rev;

    let dt_idx = batch_id * L * H + t * H + head_id;
    let dt_raw = dt_in[dt_idx] + db;
    let dt_val = softplus(dt_raw);
    let a_bar = exp(-neg_A * dt_val);

    // x, dy and dx share one layout; B_proj and C_proj share another.
    let x_base = batch_id * L * D + t * D + head_id * dh;
    let bc_base = batch_id * L * G * N + t * G * N + group_id * N;

    // Skip path: dD += dy * x, dx += dy * D, where D is treated as 1 (not bound).
    for (var i: u32 = 0u; i < dh; i = i + 1u) {
      let dy_val = dy[x_base + i];
      dD_acc = dD_acc + dy_val * x_in[x_base + i];
      dx[x_base + i] = dx[x_base + i] + dy_val * /* D */ 1.0;
    }

    // Sum over n and i of dh * h_prev: the factor shared by ddt and dA_log.
    var dh_hprev: f32 = 0.0;
    for (var n: u32 = 0u; n < N; n = n + 1u) {
      let c_val = C_proj[bc_base + n];
      let b_val = B_proj[bc_base + n];
      var dC_acc: f32 = 0.0;
      var dB_acc: f32 = 0.0;
      for (var i: u32 = 0u; i < dh; i = i + 1u) {
        let dy_val = dy[x_base + i];
        let x_val = x_in[x_base + i];
        let h_val = state_carry[next_base + n * dh + i];
        let h_prev = state_carry[prev_base + n * dh + i];
        let dh_val = c_val * dy_val;                         // dh = C * dy
        dC_acc = dC_acc + dy_val * h_val;                    // dC += dy * h
        dB_acc = dB_acc + dh_val * x_val;                    // dB += dh * x
        dx[x_base + i] = dx[x_base + i] + dh_val * b_val;   // dx += dh * B
        dh_hprev = dh_hprev + dh_val * h_prev;
      }
      // Heads in the same group hit the same B/C entry, so publish atomically.
      atomic_add_dC(bc_base + n, dC_acc);
      atomic_add_dB(bc_base + n, dB_acc);
    }

    // d a_bar / d dt_val = (-neg_A) * a_bar, and d dt_val / d dt_raw = sigmoid(dt_raw).
    ddt[dt_idx] = ddt[dt_idx] + dh_hprev * (-neg_A) * a_bar * d_softplus(dt_raw);
    // d a_bar / d A_log = a_bar * (-dt_val) * sigmoid(A_log).
    dA_acc = dA_acc + dh_hprev * a_bar * (-dt_val) * dsp_A;
  }

  atomic_add_dD_vec(head_id, dD_acc);
  atomic_add_dA_log(head_id, dA_acc);
}
`;
|
|
276
|
+
//# sourceMappingURL=ssd.js.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"ssd.js","sourceRoot":"","sources":["../../src/kernels/ssd.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;;;;;;;;;;;GAyBG;AAEH,MAAM,CAAC,MAAM,gBAAgB,GAAW,UAAU,CAAA;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;CA+HjD,CAAC;AAEF,iFAAiF;AAEjF,MAAM,CAAC,MAAM,iBAAiB,GAAW,UAAU,CAAA;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;CAuHlD,CAAC"}
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"weight_update.d.ts","sourceRoot":"","sources":["../../src/kernels/weight_update.ts"],"names":[],"mappings":"AAUA,eAAO,MAAM,kBAAkB,EAAE,MAgDhC,CAAC;AAIF,eAAO,MAAM,cAAc,EAAE,MAyD5B,CAAC"}
|