tensorgrad 0.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. package/LICENSE +21 -0
  2. package/README.md +121 -0
  3. package/SPEC.md +293 -0
  4. package/dist/adam.d.ts +31 -0
  5. package/dist/adam.d.ts.map +1 -0
  6. package/dist/adam.js +66 -0
  7. package/dist/adam.js.map +1 -0
  8. package/dist/buffers.d.ts +56 -0
  9. package/dist/buffers.d.ts.map +1 -0
  10. package/dist/buffers.js +114 -0
  11. package/dist/buffers.js.map +1 -0
  12. package/dist/codegen.d.ts +23 -0
  13. package/dist/codegen.d.ts.map +1 -0
  14. package/dist/codegen.js +709 -0
  15. package/dist/codegen.js.map +1 -0
  16. package/dist/compile.d.ts +53 -0
  17. package/dist/compile.d.ts.map +1 -0
  18. package/dist/compile.js +76 -0
  19. package/dist/compile.js.map +1 -0
  20. package/dist/grad.d.ts +8 -0
  21. package/dist/grad.d.ts.map +1 -0
  22. package/dist/grad.js +404 -0
  23. package/dist/grad.js.map +1 -0
  24. package/dist/index.d.ts +12 -0
  25. package/dist/index.d.ts.map +1 -0
  26. package/dist/index.js +37 -0
  27. package/dist/index.js.map +1 -0
  28. package/dist/ir.d.ts +204 -0
  29. package/dist/ir.d.ts.map +1 -0
  30. package/dist/ir.js +60 -0
  31. package/dist/ir.js.map +1 -0
  32. package/dist/module.d.ts +21 -0
  33. package/dist/module.d.ts.map +1 -0
  34. package/dist/module.js +113 -0
  35. package/dist/module.js.map +1 -0
  36. package/dist/ops.d.ts +35 -0
  37. package/dist/ops.d.ts.map +1 -0
  38. package/dist/ops.js +270 -0
  39. package/dist/ops.js.map +1 -0
  40. package/dist/runtime.d.ts +26 -0
  41. package/dist/runtime.d.ts.map +1 -0
  42. package/dist/runtime.js +190 -0
  43. package/dist/runtime.js.map +1 -0
  44. package/dist/shape.d.ts +24 -0
  45. package/dist/shape.d.ts.map +1 -0
  46. package/dist/shape.js +259 -0
  47. package/dist/shape.js.map +1 -0
  48. package/dist/trace.d.ts +8 -0
  49. package/dist/trace.d.ts.map +1 -0
  50. package/dist/trace.js +93 -0
  51. package/dist/trace.js.map +1 -0
  52. package/package.json +62 -0
  53. package/src/adam.ts +95 -0
  54. package/src/buffers.ts +173 -0
  55. package/src/codegen.ts +758 -0
  56. package/src/compile.ts +120 -0
  57. package/src/grad.ts +459 -0
  58. package/src/index.ts +40 -0
  59. package/src/ir.ts +197 -0
  60. package/src/module.ts +126 -0
  61. package/src/ops.ts +311 -0
  62. package/src/runtime.ts +232 -0
  63. package/src/shape.ts +263 -0
  64. package/src/trace.ts +101 -0
package/src/compile.ts ADDED
@@ -0,0 +1,120 @@
+ // Top-level compile(): trace → autograd → buffer plan → codegen → runtime.
+ //
+ // Two entry points:
+ //  * `compile(traceFn)`        — low-level. User declares params via
+ //                                paramInput() inside the trace.
+ //  * `compileModule(model, …)` — high-level. User defines the model as a
+ //                                Module tree; the library auto-discovers
+ //                                params, traces the forward, appends grad
+ //                                and Adam, and returns a runtime.
+
+ import type { Tensor, Shape, Dtype } from './ir.js'
+ import { trace, tensorInput } from './trace.js'
+ import { appendGrad, type GradResult } from './grad.js'
+ import { appendAdam, type AdamConfig } from './adam.js'
+ import { planBuffers, type BufferPlan } from './buffers.js'
+ import { emitKernels, type KernelSpec } from './codegen.js'
+ import { createRuntime, type CompiledRuntime, type RuntimeOpts } from './runtime.js'
+ import { Module, materializeParams } from './module.js'
+
+ /** Declares one input tensor of the model's forward function. Order matches
+  * the function's parameter list (after `model`). The `name` is used at
+  * runtime to upload data via `step({ [name]: data })`. */
+ export interface InputDecl {
+   name: string
+   shape: Shape
+   dtype?: Dtype
+ }
+
+ export interface CompiledIR {
+   graph: GradResult['graph']
+   paramGrads: GradResult['paramGrads']
+   loss: Tensor
+   plan: BufferPlan
+   kernels: KernelSpec[]
+ }
+
+ /** Trace + autograd + buffer-plan + codegen, without touching WebGPU. */
+ export function compileToIR(traceFn: () => Tensor): CompiledIR {
+   const graph = trace(traceFn)
+   const { paramGrads, loss } = appendGrad(graph)
+   const plan = planBuffers(graph, paramGrads)
+   const kernels = emitKernels(graph, plan)
+   return { graph, paramGrads, loss, plan, kernels }
+ }
+
+ /** Full compile pipeline. Browser-only because it creates a GPUDevice. */
+ export async function compile(traceFn: () => Tensor, opts: RuntimeOpts = {}): Promise<CompiledRuntime & { ir: CompiledIR }> {
+   const ir = compileToIR(traceFn)
+   const lossBufferId = ir.plan.tensorToBuffer.get(ir.loss.id)!
+   const runtime = await createRuntime(ir.plan, ir.kernels, lossBufferId, opts)
+   return Object.assign(runtime, { ir })
+ }
+
+ // ============================================================================
+ // Module-aware compile
+ // ============================================================================
+
+ export interface CompileModuleOptions extends RuntimeOpts {
+   /** Per-step data inputs to the forward function. Order matches the forward
+    * function's parameters (after the model). e.g. for
+    * `(model, tokens, targets, mask) => loss`, inputs is
+    * `[{name:'tokens',...}, {name:'targets',...}, {name:'mask',...}]`. */
+   inputs?: InputDecl[]
+   /** Adam hyperparameters. If omitted, no optimizer is appended (forward-only). */
+   adam?: AdamConfig
+ }
+
+ /**
+  * Compile a Module-based model. The forward function takes the materialized
+  * model and returns the loss tensor (typically by also calling tensorInput
+  * for tokens/targets/masks inside).
+  *
+  * Walks the module tree to materialize params with auto-derived names, then
+  * runs trace → grad → adam → buffer plan → codegen → runtime.
+  *
+  * If `opts.adam` is set, the runtime's `step()` automatically tracks an
+  * internal step count and injects the bias-corrected `lrt` scalar each call;
+  * users don't need to provide it themselves.
+  */
+ export async function compileModule<M extends Module>(
+   model: M,
+   forward: (m: M, ...inputs: Tensor[]) => Tensor,
+   opts: CompileModuleOptions = {},
+ ): Promise<CompiledRuntime & { ir: CompiledIR }> {
+   const inputDecls = opts.inputs ?? []
+   let paramTensors: Record<string, Tensor> = {}
+   const graph = trace(() => {
+     paramTensors = materializeParams(model)
+     const inputTensors = inputDecls.map(d => tensorInput(d.name, d.shape, d.dtype ?? 'f32'))
+     return forward(model, ...inputTensors)
+   })
+
+   const { paramGrads, loss } = appendGrad(graph)
+
+   let adamResult: ReturnType<typeof appendAdam> | undefined
+   if (opts.adam) {
+     adamResult = appendAdam(graph, paramGrads, paramTensors, opts.adam)
+   }
+
+   const plan = planBuffers(graph, paramGrads, adamResult?.writebacks ?? [])
+   const kernels = emitKernels(graph, plan)
+   const lossBufferId = plan.tensorToBuffer.get(loss.id)!
+   const runtime = await createRuntime(plan, kernels, lossBufferId, opts)
+
+   // If Adam is enabled, wrap step() to track the step count and supply lrt.
+   if (adamResult) {
+     const { lrtInputName, config } = adamResult
+     let t = 0
+     const lrtBuf = new Float32Array(1)
+     const innerStep = runtime.step.bind(runtime)
+     runtime.step = async (inputs) => {
+       t++
+       lrtBuf[0] = config.lr * Math.sqrt(1 - Math.pow(config.b2, t)) / (1 - Math.pow(config.b1, t))
+       return innerStep({ ...inputs, [lrtInputName]: lrtBuf })
+     }
+   }
+
+   const ir: CompiledIR = { graph, paramGrads, loss, plan, kernels }
+   return Object.assign(runtime, { ir })
+ }
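
A minimal usage sketch of the pipeline above, for orientation. Everything model-specific here is hypothetical: `TinyNet` stands in for any user-defined Module subclass, the input names and shapes are invented, and `AdamConfig` is assumed to carry at least the `lr`/`b1`/`b2` fields the lrt wrapper reads.

    import { compileModule } from 'tensorgrad'

    const model = new TinyNet()                      // hypothetical Module subclass
    const runtime = await compileModule(
      model,
      (m, x, y) => m.loss(x, y),                     // hypothetical forward; must return a rank-0 loss
      {
        inputs: [
          { name: 'x', shape: [32, 64] },            // dtype defaults to 'f32'
          { name: 'y', shape: [32, 64] },
        ],
        adam: { lr: 1e-3, b1: 0.9, b2: 0.999 },
      },
    )

    const xData = new Float32Array(32 * 64)
    const yData = new Float32Array(32 * 64)
    for (let step = 0; step < 100; step++) {
      // With `adam` set, the wrapped step() injects the bias-corrected lrt itself.
      await runtime.step({ x: xData, y: yData })
    }
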
package/src/grad.ts ADDED
@@ -0,0 +1,459 @@
+ // Reverse-mode autograd over a traced Graph.
+ //
+ // Given a graph that ends in a scalar loss tensor, this module walks the ops
+ // in reverse and appends backward ops to the same graph, computing dL/dT for
+ // every Tensor T that descends from a `param_input`. The final cotangents on
+ // the param_input tensors are the parameter gradients.
+ //
+ // Cotangent accumulation: a tensor with multiple consumers ends up with
+ // contributions from each. We add them as we encounter them, so by the time
+ // reverse iteration reaches a tensor's producer op, its cotangent is complete.
+ //
+ // Why this works as "more graph nodes": the transpose rule for an op like
+ // mul(a, b)→c is `da += dc * b; db += dc * a`. The right-hand sides are
+ // expressible in terms of existing forward ops (mul) plus accumulation (add).
+ // We just call those op functions, which append nodes to the current graph
+ // because we run inside an active trace context.
+
+ import type { Graph, OpNode, Tensor, Shape } from './ir.js'
+ import {
+   add, sub, mul, div, mulScalar,
+   matmul, matmulBatched, transpose, reshape,
+   exp,
+   broadcastTo, sumToShape,
+   constScalar, reluGrad,
+   sumLast, where,
+ } from './ops.js'
+ import { traceInto } from './trace.js'
+
+ // ============================================================================
+ // Public API
+ // ============================================================================
+
+ export interface GradResult {
+   // The graph, augmented with backward ops.
+   readonly graph: Graph
+   // Cotangents (gradients) for each param_input, keyed by param name.
+   readonly paramGrads: Record<string, Tensor>
+   // The loss output (unchanged from input).
+   readonly loss: Tensor
+ }
+
+ // `appendGrad(graph)` augments `graph` (which must have already been built by
+ // `trace(...)` and must have a single scalar output = the loss) with backward
+ // ops. Returns gradients for every param_input.
+ //
+ // Internally re-enters the graph as the active trace context, so backward ops
+ // emitted by transpose rules append to it. The caller doesn't need to manage
+ // trace state.
+ export function appendGrad(graph: Graph): GradResult {
+   if (graph.outputs.length !== 1) {
+     throw new Error(`autograd: expected graph with exactly 1 output (the loss); got ${graph.outputs.length}`)
+   }
+   const lossId = graph.outputs[0]!
+   const lossTensor = graph.tensors[lossId]!
+   if (lossTensor.shape.length !== 0) {
+     throw new Error(
+       `autograd: loss must be a rank-0 scalar; got shape [${lossTensor.shape.join(', ')}]. ` +
+       `Reduce with sumLast / mulScalar to a scalar before calling appendGrad.`,
+     )
+   }
+
+   // Snapshot the forward portion of the graph before we start emitting backward
+   // ops, so the reverse walk only iterates over forward ops.
+   const forwardOpCount = graph.ops.length
+   const forwardOps = graph.ops.slice(0, forwardOpCount)
+
+   // cotangents: tensorId -> the Tensor representing dL/dTensor in the graph.
+   const cotangents = new Map<number, Tensor>()
+
+   return traceInto(graph, () => {
+     // Seed: dL/dLoss = 1.0
+     cotangents.set(lossId, constScalar(1.0, 'f32'))
+
+     // Reverse walk.
+     for (let i = forwardOpCount - 1; i >= 0; i--) {
+       const op = forwardOps[i]!
+       const outCotan = cotangents.get(op.out)
+       if (!outCotan) continue
+       runTransposeRule(op, outCotan, graph, cotangents)
+     }
+
+     // Collect param gradients by name. Skip non-param leaves.
+     const paramGrads: Record<string, Tensor> = {}
+     for (const op of forwardOps) {
+       if (op.kind !== 'param_input') continue
+       // (state_input and tensor_input don't produce gradients we hand back.)
+       const cotan = cotangents.get(op.out)
+       if (!cotan) {
+         // No path from this param to the loss — emit explicit zeros so the
+         // caller gets a tensor with the right shape.
+         const t = graph.tensors[op.out]!
+         paramGrads[op.name] = broadcastTo(constScalar(0.0, t.dtype), t.shape)
+       } else {
+         paramGrads[op.name] = cotan
+       }
+     }
+
+     return { graph, paramGrads, loss: lossTensor }
+   })
+ }
+
+ // ============================================================================
+ // Cotangent accumulation
+ // ============================================================================
+
+ // Add `contribution` to the cotangent of tensor `inputId`. If a cotangent
+ // already exists, sum them (multiple consumers); otherwise initialize.
+ function accumulate(cotangents: Map<number, Tensor>, inputId: number, contribution: Tensor): void {
+   const existing = cotangents.get(inputId)
+   if (existing) {
+     cotangents.set(inputId, add(existing, contribution))
+   } else {
+     cotangents.set(inputId, contribution)
+   }
+ }
+
+ // Reduce a cotangent to match the input's shape, undoing any broadcast that
+ // occurred during forward. If `fromShape == toShape`, no-op.
+ function unbroadcast(cotan: Tensor, toShape: Shape): Tensor {
+   if (shapesEqual(cotan.shape, toShape)) return cotan
+   return sumToShape(cotan, toShape)
+ }
+
+ function shapesEqual(a: Shape, b: Shape): boolean {
+   if (a.length !== b.length) return false
+   for (let i = 0; i < a.length; i++) if (a[i] !== b[i]) return false
+   return true
+ }
+
+ // ============================================================================
+ // Transpose rules
+ // ============================================================================
+ //
+ // One per OpNode kind. Each rule:
+ //  * receives the forward op + its output cotangent
+ //  * builds the backward expression(s) in graph terms (calling ops.ts functions)
+ //  * accumulates cotangent contributions onto each input tensor
+
+ function runTransposeRule(
+   op: OpNode,
+   outCotan: Tensor,
+   graph: Graph,
+   cotangents: Map<number, Tensor>,
+ ): void {
+   const tensorOf = (id: number) => graph.tensors[id]!
+
+   switch (op.kind) {
+     // ---- Leaves: no inputs to accumulate into. -----------------------------
+     case 'param_input':
+     case 'tensor_input':
+     case 'state_input':
+     case 'arange':
+     case 'const_scalar':
+       return
+
+     // ---- Element-wise binops (with broadcast) ------------------------------
+     // c = a op b; reduce cotan back to each operand's shape.
+     case 'add': {
+       const a = tensorOf(op.a), b = tensorOf(op.b)
+       accumulate(cotangents, op.a, unbroadcast(outCotan, a.shape))
+       accumulate(cotangents, op.b, unbroadcast(outCotan, b.shape))
+       return
+     }
+     case 'sub': {
+       const a = tensorOf(op.a), b = tensorOf(op.b)
+       accumulate(cotangents, op.a, unbroadcast(outCotan, a.shape))
+       accumulate(cotangents, op.b, unbroadcast(mulScalar(outCotan, -1), b.shape))
+       return
+     }
+     case 'mul': {
+       const a = tensorOf(op.a), b = tensorOf(op.b)
+       // dC/dA = b ; dC/dB = a. Both are forward tensors still alive in the graph.
+       // We must NOT consume the forward tensors — they're referenced by id.
+       // The mul() helper allocates fresh tensors, so referencing a/b multiple
+       // times in different mul() calls is fine: we just emit fresh ops.
+       accumulate(cotangents, op.a, unbroadcast(mul(outCotan, b), a.shape))
+       accumulate(cotangents, op.b, unbroadcast(mul(outCotan, a), b.shape))
+       return
+     }
+     case 'div': {
+       // c = a/b. dc/da = 1/b. dc/db = -a/b^2.
+       const a = tensorOf(op.a), b = tensorOf(op.b)
+       accumulate(cotangents, op.a, unbroadcast(div(outCotan, b), a.shape))
+       // -outCotan * a / (b*b)
+       const numer = mul(outCotan, a)
+       const bSq = mul(b, b)
+       accumulate(cotangents, op.b, unbroadcast(mulScalar(div(numer, bSq), -1), b.shape))
+       return
+     }
+
+     // ---- Element-wise scalar binops (scalar is a JS number, not a tensor) -
+     case 'mul_scalar': {
+       // c = a * s. dc/da = s.
+       accumulate(cotangents, op.a, mulScalar(outCotan, op.scalar))
+       return
+     }
+     case 'add_scalar': {
+       // c = a + s. dc/da = 1.
+       accumulate(cotangents, op.a, outCotan)
+       return
+     }
+
+     // ---- Unary -------------------------------------------------------------
+     case 'sqrt': {
+       // c = sqrt(a). dc/da = 1/(2*sqrt(a)) = 1/(2*c).
+       const c = tensorOf(op.out)
+       accumulate(cotangents, op.a, mulScalar(div(outCotan, c), 0.5))
+       return
+     }
+     case 'rsqrt': {
+       // c = a^(-0.5). dc/da = -0.5 * a^(-1.5) = -0.5 * c^3.
+       const c = tensorOf(op.out)
+       const c3 = mul(mul(c, c), c)
+       accumulate(cotangents, op.a, mulScalar(mul(outCotan, c3), -0.5))
+       return
+     }
+     case 'log': {
+       // c = log(a). dc/da = 1/a.
+       const a = tensorOf(op.a)
+       accumulate(cotangents, op.a, div(outCotan, a))
+       return
+     }
+     case 'exp': {
+       // c = exp(a). dc/da = exp(a) = c.
+       const c = tensorOf(op.out)
+       accumulate(cotangents, op.a, mul(outCotan, c))
+       return
+     }
+     case 'relu': {
+       // c = relu(a). dc/da = (a > 0 ? 1 : 0). Use the fused relu_grad op.
+       const a = tensorOf(op.a)
+       accumulate(cotangents, op.a, reluGrad(a, outCotan))
+       return
+     }
+
+     // ---- Reductions over last axis ---------------------------------------
+     case 'mean_last': {
+       // c[..., 1] = mean over last axis of a[..., D]. da[..., d] = dc[..., 0] / D.
+       // outCotan has shape [..., 1]; broadcast to a's shape and divide by D.
+       const a = tensorOf(op.a)
+       const D = a.shape[a.shape.length - 1]!
+       const expanded = broadcastTo(outCotan, a.shape)
+       accumulate(cotangents, op.a, mulScalar(expanded, 1 / D))
+       return
+     }
+     case 'sum_last': {
+       // c[...] = sum over last axis (keepdims=false). da[..., d] = dc[...].
+       // outCotan has rank one less than a; broadcast to a's shape (which inserts
+       // back the last axis with a's last-axis size).
+       const a = tensorOf(op.a)
+       // First reshape outCotan to add a trailing 1, then broadcast to a's shape.
+       const withKeep = reshape(outCotan, [...outCotan.shape, 1])
+       accumulate(cotangents, op.a, broadcastTo(withKeep, a.shape))
+       return
+     }
+
+     // ---- Shape ------------------------------------------------------------
+     case 'reshape': {
+       // c = reshape(a, ...). Backward: reshape outCotan back to a's shape.
+       const a = tensorOf(op.a)
+       accumulate(cotangents, op.a, reshape(outCotan, a.shape))
+       return
+     }
+     case 'transpose': {
+       // c = transpose(a, perm). Backward: transpose outCotan with inverse perm.
+       const inv = invertPerm(op.perm)
+       accumulate(cotangents, op.a, transpose(outCotan, inv))
+       return
+     }
+
+     // ---- Linear algebra ---------------------------------------------------
+     case 'matmul': {
+       // c = a @ b, where a: [..., M, K], b: [K, N], c: [..., M, N].
+       //   dA = dC @ B^T                    (matmul, since b is unbatched)
+       //   dB = sum_over_batch( A^T @ dC )
+       //
+       // Implementation note: dA uses the same `matmul` (a [...,M,N] · b [N,K])
+       // because b is rank-2. dB needs A^T which has shape [..., K, M], then
+       // matmul with dC ([..., M, N]) gives [..., K, N], which we sum over
+       // leading batch dims to get [K, N].
+       const a = tensorOf(op.a), b = tensorOf(op.b)
+       // dA = dC @ B^T
+       const bT = transpose(b, [1, 0])
+       accumulate(cotangents, op.a, matmul(outCotan, bT))
+       // dB: per-batch A^T @ dC, then sum over batch dims.
+       // A is [..., M, K]; transpose last two axes.
+       const aTPerm = identityPerm(a.shape.length)
+       ;[aTPerm[a.shape.length - 1], aTPerm[a.shape.length - 2]] =
+         [aTPerm[a.shape.length - 2]!, aTPerm[a.shape.length - 1]!]
+       const aT = transpose(a, aTPerm) // [..., K, M]
+       // matmul_batched needs same rank on both sides. dC has rank `a.rank`;
+       // aT has rank `a.rank`; use matmul_batched if rank > 2, else matmul.
+       let perBatchDb: Tensor
+       if (a.shape.length > 2) {
+         perBatchDb = matmulBatched(aT, outCotan) // [..., K, N]
+       } else {
+         perBatchDb = matmul(aT, outCotan) // [K, N]
+       }
+       // Sum over leading batch dims to collapse to b's shape [K, N].
+       accumulate(cotangents, op.b, sumToShape(perBatchDb, b.shape))
+       return
+     }
+     case 'matmul_batched': {
+       // c = a @ b, both [..., M, K] · [..., K, N] -> [..., M, N].
+       //   dA = dC @ B^T   (per-batch, all batch dims preserved)
+       //   dB = A^T @ dC   (per-batch)
+       const a = tensorOf(op.a), b = tensorOf(op.b)
+       const lastTwoSwap = (rank: number) => {
+         const p = identityPerm(rank)
+         ;[p[rank - 1], p[rank - 2]] = [p[rank - 2]!, p[rank - 1]!]
+         return p
+       }
+       const bT = transpose(b, lastTwoSwap(b.shape.length))
+       const aT = transpose(a, lastTwoSwap(a.shape.length))
+       accumulate(cotangents, op.a, matmulBatched(outCotan, bT))
+       accumulate(cotangents, op.b, matmulBatched(aT, outCotan))
+       return
+     }
+
+     // ---- Indexing / casting (no gradient through integer indices) --------
+     case 'one_hot':
+       // The output is float, but the input (indices) is integer-valued — no
+       // continuous gradient flows through it. Stop here.
+       return
+
+     // ---- Slicing ---------------------------------------------------------
+     case 'slice_last_range': {
+       // c = a[..., start:end]. Backward: pad outCotan with zeros back out to
+       // a's shape (zeros left of the slice, outCotan in the middle, zeros to
+       // the right, concatenated along the last axis). We have no concat,
+       // generic pad, or additive-scatter (index_put) op to express that, so
+       // slice's backward needs a dedicated op kind that scatters the
+       // cotangent. For this pass, throw until we extend the IR.
+       const a = tensorOf(op.a)
+       throw new Error(
+         `autograd: slice_last_range backward not implemented yet ` +
+         `(would need a scatter-style op or a Concat op). ` +
+         `Workaround for now: avoid taking gradients through slices by using ` +
+         `separate matmuls for Q/K/V instead of a fused W_qkv. ` +
+         `Tensor: ${a.shape} -> ${tensorOf(op.out).shape}`,
+       )
+     }
+
+     // ---- Broadcast / un-broadcast (autograd infrastructure) ---------------
+     case 'broadcast_to': {
+       // c = broadcast(a, target). da = sum_to_shape(dc, a.shape).
+       const a = tensorOf(op.a)
+       accumulate(cotangents, op.a, sumToShape(outCotan, a.shape))
+       return
+     }
+     case 'sum_to_shape': {
+       // c = sum_to_shape(a, target). da = broadcast_to(dc, a.shape).
+       const a = tensorOf(op.a)
+       accumulate(cotangents, op.a, broadcastTo(outCotan, a.shape))
+       return
+     }
+
+     // ---- ML primitives ---------------------------------------------------
+     case 'log_softmax_last': {
+       // c = log_softmax(a, axis=-1). softmax(a) = exp(c).
+       // dL/dA = dL/dC - softmax(a) * sum_last_keepdims(dL/dC)
+       const c = tensorOf(op.out)
+       const sm = exp(c) // softmax(a)
+       // sum_last with keepdims via reshape: sum_last drops the dim, then
+       // reshape to add a trailing 1 back, then broadcast multiplies.
+       const sumDc = sumLast(outCotan) // rank one less than outCotan
+       const sumDcKeep = reshape(sumDc, [...sumDc.shape, 1])
+       const term = mul(sm, broadcastTo(sumDcKeep, c.shape))
+       accumulate(cotangents, op.a, sub(outCotan, term))
+       return
+     }
+     case 'softmax_causal_last': {
+       // c = softmax_causal(a, axis=-1). The causal mask zeros the upper triangle
+       // of c; for the backward, the same mask zeros out dx_upper because both
+       // paths through softmax depend on c-values that are 0 there.
+       // dL/dA = (dL/dC - sum_last_keep(dL/dC * c)) * c
+       const c = tensorOf(op.out)
+       const dcXc = mul(outCotan, c)
+       const s = sumLast(dcXc)
+       const sKeep = reshape(s, [...s.shape, 1])
+       const inner = sub(outCotan, broadcastTo(sKeep, c.shape))
+       accumulate(cotangents, op.a, mul(inner, c))
+       return
+     }
+
+     // ---- Comparisons + select ---------------------------------------------
+     case 'less':
+     case 'greater':
+       // No gradient flows through bool comparisons. Stop here.
+       return
+
+     case 'where': {
+       // c = where(cond, a, b).
+       // dC flows to a where cond is true, to b where cond is false.
+       // Need broadcast-aware unreduction back to a's and b's original shapes.
+       const cond = tensorOf(op.cond)
+       const a = tensorOf(op.a)
+       const b = tensorOf(op.b)
+       // Build zero tensors via broadcasting a 0-d const scalar.
+       const zeroA = broadcastTo(constScalar(0, a.dtype), outCotan.shape)
+       const zeroB = broadcastTo(constScalar(0, b.dtype), outCotan.shape)
+       accumulate(cotangents, op.a, unbroadcast(where(cond, outCotan, zeroA), a.shape))
+       accumulate(cotangents, op.b, unbroadcast(where(cond, zeroB, outCotan), b.shape))
+       return
+     }
+
+     case 'where_causal': {
+       // c = where(causal_mask, a, fillValue). Upper triangle becomes constant
+       // (no gradient); lower triangle passes a through. So da_lower = dc_lower,
+       // da_upper = 0. We can't easily express this with current ops; punt.
+       throw new Error(
+         `autograd: where_causal backward not yet implemented. ` +
+         `Use softmax_causal_last (which fuses the mask + softmax) instead.`,
+       )
+     }
+
+     // ---- Adam ops are post-autograd; no backward through them. ----------
+     case 'adam_update_m':
+     case 'adam_update_v':
+     case 'adam_update_p':
+       throw new Error(`autograd: cannot differentiate through ${op.kind}`)
+
+     // ---- relu_grad has no further backward (autograd-internal) ----------
+     case 'relu_grad': {
+       // We don't double-differentiate. If someone tries, this will blow up —
+       // intentional. Phase 2 doesn't need 2nd-order gradients.
+       throw new Error(
+         `autograd: cannot take second-order gradient through relu_grad. ` +
+         `Phase 2 does not support higher-order autodiff.`,
+       )
+     }
+
+     default: {
+       // Exhaustiveness check at type level.
+       const _exhaustive: never = op
+       void _exhaustive
+       throw new Error(`autograd: unhandled op kind ${(op as OpNode).kind}`)
+     }
+   }
+ }
+
+ // ============================================================================
+ // Helpers
+ // ============================================================================
+
+ function identityPerm(rank: number): number[] {
+   const p: number[] = new Array(rank)
+   for (let i = 0; i < rank; i++) p[i] = i
+   return p
+ }
+
+ function invertPerm(perm: readonly number[]): number[] {
+   const inv: number[] = new Array(perm.length)
+   for (let i = 0; i < perm.length; i++) inv[perm[i]!] = i
+   return inv
+ }
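
The accumulate-then-walk-backward scheme above is easiest to see on plain numbers. A scalar illustration (ordinary JS arithmetic, not the library's Tensor graph) for y = a*b + a, where `a` has two consumers and therefore receives two accumulated contributions:

    // Forward "trace": c = mul(a, b); y = add(c, a).
    const a = 2, b = 3
    const c = a * b            // c === 6
    const y = c + a            // y === 8 (the "loss")

    // Reverse walk, seeded with dy = 1 (cf. constScalar(1.0) in appendGrad).
    let dy = 1, dc = 0, da = 0, db = 0
    dc += dy                   // rule for add: both inputs receive dy unchanged
    da += dy                   //   ...first contribution to da
    da += dc * b               // rule for mul: da += dc * b
    db += dc * a               //               db += dc * a
    console.log(y, da, db)     // 8, 4, 2: matches d(ab + a)/da = b + 1 and d/db = a
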
package/src/index.ts ADDED
@@ -0,0 +1,40 @@
+ // Public surface. Bulb code imports from here.
+ //
+ // Phase 1 exports come first: IR types, op surface, trace driver. The
+ // autograd (Phase 2) and buffer-plan / codegen / compile() (Phase 3+)
+ // exports follow at the bottom of this file.
+
+ export type { Tensor, Shape, Dtype, OpNode, Graph, CallSite } from './ir.js'
+ export { ShapeError } from './shape.js'
+ export { trace, traceInto, paramInput, tensorInput, stateInput } from './trace.js'
+ export {
+   // Element-wise arithmetic. The binops accept Tensor or JS-number for the second arg.
+   add, sub, mul, div,
+   // Element-wise unary
+   sqrt, rsqrt, log, exp, relu,
+   // Comparisons + select
+   less, greater, where,
+   // Reductions over the last axis (other axes via reshape/transpose first)
+   meanLast, sumLast,
+   // Shape ops
+   reshape, transpose,
+   // Linear algebra
+   matmul, matmulBatched,
+   // Indexing / casting
+   oneHot, arange,
+   // ML primitives — fused for the transformer
+   softmaxCausalLast, logSoftmaxLast, whereCausal,
+   // Slicing
+   sliceLastRange,
+ } from './ops.js'
+
+ // Note: addScalar/mulScalar/broadcastTo/sumToShape/constScalar/reluGrad/adam_update_*
+ // are autograd/optimizer building blocks. They live in ops.ts (so grad.ts and
+ // adam.ts can import them) but aren't part of the public API — `add`/`mul`
+ // overload on JS numbers, `where` subsumes the rest.
+ export { appendGrad, type GradResult } from './grad.js'
+ export { appendAdam, type AdamConfig, type AdamResult } from './adam.js'
+ export { planBuffers, type BufferPlan, type BufferSpec, type Writeback, type WritebackDecl } from './buffers.js'
+ export { emitKernels, type KernelSpec } from './codegen.js'
+ export { createRuntime, type CompiledRuntime, type RuntimeOpts } from './runtime.js'
+ export { compile, compileToIR, compileModule, type CompiledIR, type CompileModuleOptions, type InputDecl } from './compile.js'
+ export { Module, materializeParams } from './module.js'
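
To poke at this surface without a GPU, `compileToIR` can be driven directly from the exports above. A sketch under one assumption this diff doesn't confirm: `paramInput` is taken to mirror `tensorInput`'s (name, shape, dtype) signature.

    import { compileToIR, paramInput, tensorInput, mul, sumLast } from 'tensorgrad'

    const ir = compileToIR(() => {
      const w = paramInput('w', [8], 'f32')   // assumed signature, mirroring tensorInput
      const x = tensorInput('x', [8], 'f32')
      return sumLast(mul(x, w))               // rank-1 -> rank-0 scalar loss
    })

    console.log(ir.kernels.length)            // kernels produced by emitKernels
    console.log(ir.paramGrads['w']?.shape)    // gradient tensor for 'w'; expect [8]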