tensorgrad 0.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64) hide show
  1. package/LICENSE +21 -0
  2. package/README.md +121 -0
  3. package/SPEC.md +293 -0
  4. package/dist/adam.d.ts +31 -0
  5. package/dist/adam.d.ts.map +1 -0
  6. package/dist/adam.js +66 -0
  7. package/dist/adam.js.map +1 -0
  8. package/dist/buffers.d.ts +56 -0
  9. package/dist/buffers.d.ts.map +1 -0
  10. package/dist/buffers.js +114 -0
  11. package/dist/buffers.js.map +1 -0
  12. package/dist/codegen.d.ts +23 -0
  13. package/dist/codegen.d.ts.map +1 -0
  14. package/dist/codegen.js +709 -0
  15. package/dist/codegen.js.map +1 -0
  16. package/dist/compile.d.ts +53 -0
  17. package/dist/compile.d.ts.map +1 -0
  18. package/dist/compile.js +76 -0
  19. package/dist/compile.js.map +1 -0
  20. package/dist/grad.d.ts +8 -0
  21. package/dist/grad.d.ts.map +1 -0
  22. package/dist/grad.js +404 -0
  23. package/dist/grad.js.map +1 -0
  24. package/dist/index.d.ts +12 -0
  25. package/dist/index.d.ts.map +1 -0
  26. package/dist/index.js +37 -0
  27. package/dist/index.js.map +1 -0
  28. package/dist/ir.d.ts +204 -0
  29. package/dist/ir.d.ts.map +1 -0
  30. package/dist/ir.js +60 -0
  31. package/dist/ir.js.map +1 -0
  32. package/dist/module.d.ts +21 -0
  33. package/dist/module.d.ts.map +1 -0
  34. package/dist/module.js +113 -0
  35. package/dist/module.js.map +1 -0
  36. package/dist/ops.d.ts +35 -0
  37. package/dist/ops.d.ts.map +1 -0
  38. package/dist/ops.js +270 -0
  39. package/dist/ops.js.map +1 -0
  40. package/dist/runtime.d.ts +26 -0
  41. package/dist/runtime.d.ts.map +1 -0
  42. package/dist/runtime.js +190 -0
  43. package/dist/runtime.js.map +1 -0
  44. package/dist/shape.d.ts +24 -0
  45. package/dist/shape.d.ts.map +1 -0
  46. package/dist/shape.js +259 -0
  47. package/dist/shape.js.map +1 -0
  48. package/dist/trace.d.ts +8 -0
  49. package/dist/trace.d.ts.map +1 -0
  50. package/dist/trace.js +93 -0
  51. package/dist/trace.js.map +1 -0
  52. package/package.json +62 -0
  53. package/src/adam.ts +95 -0
  54. package/src/buffers.ts +173 -0
  55. package/src/codegen.ts +758 -0
  56. package/src/compile.ts +120 -0
  57. package/src/grad.ts +459 -0
  58. package/src/index.ts +40 -0
  59. package/src/ir.ts +197 -0
  60. package/src/module.ts +126 -0
  61. package/src/ops.ts +311 -0
  62. package/src/runtime.ts +232 -0
  63. package/src/shape.ts +263 -0
  64. package/src/trace.ts +101 -0
package/src/runtime.ts ADDED
@@ -0,0 +1,232 @@
1
+ // WebGPU runtime. Reads a BufferPlan + KernelSpec[] (produced by codegen),
2
+ // allocates real GPU buffers and pipelines, and provides a `step()` method
3
+ // that uploads inputs, dispatches all kernels, and reads back outputs.
4
+ //
5
+ // Browser-only: this module needs `navigator.gpu` at runtime.
6
+
7
+ import type { BufferPlan } from './buffers.js'
8
+ import type { KernelSpec } from './codegen.js'
9
+
10
+ // TS lib.dom defines WebGPU types but not the GPUMapMode runtime constant.
11
+ // Provided by the browser per WebGPU spec; declare just what we use.
12
+ declare const GPUMapMode: { readonly READ: number; readonly WRITE: number }
13
+
14
/**
 * Handle returned by `createRuntime`. Owns every GPU buffer and pipeline for
 * one compiled graph and exposes upload / step / download operations. All
 * methods assume single-threaded, non-overlapping use (see `step`'s readback
 * buffer reuse in the implementation).
 */
export interface CompiledRuntime {
  /** Upload one or more parameter Float32Arrays to their GPU buffers. */
  uploadParams(params: Record<string, Float32Array>): void
  /** Read all parameters back as Float32Arrays — used for UI panels. */
  downloadParams(): Promise<Record<string, Float32Array>>
  /** Read all parameter gradients back. Mostly for verification / debugging. */
  downloadParamGrads(): Promise<Record<string, Float32Array>>
  /**
   * One full forward+backward step.
   * 1. Uploads `inputs` (tokens, targets, masks) to input buffers.
   * 2. Dispatches every kernel in order.
   * 3. Reads back the loss scalar.
   * Returns the loss as a JS number.
   */
  step(inputs: Record<string, Int32Array | Float32Array>): Promise<number>
  /** Free GPU resources. */
  destroy(): void
}
32
+
33
/** Options accepted by `createRuntime`. */
export interface RuntimeOpts {
  /** Pre-acquired GPUDevice. If omitted, runtime requests its own. */
  device?: GPUDevice
}
37
+
38
+ // Inlined numeric values (per WebGPU spec) so this module is importable in Node
39
+ // for codegen-only usage. The browser provides GPUBufferUsage as a global, but
40
+ // referencing it at module scope would crash before any browser code runs.
41
+ const STORAGE_RW = 0x80 /*STORAGE*/ | 0x8 /*COPY_DST*/ | 0x4 /*COPY_SRC*/
42
+ const READBACK = 0x1 /*MAP_READ*/ | 0x8 /*COPY_DST*/
43
+
44
/**
 * Build a live WebGPU runtime from a BufferPlan and the codegen'd kernels.
 *
 * Allocates one GPUBuffer per BufferSpec (state buffers pre-filled with their
 * initValue), compiles every distinct WGSL module once, builds static bind
 * groups, and returns a `CompiledRuntime` whose `step()` uploads inputs,
 * dispatches all kernels in order, applies writebacks, and reads back the loss.
 *
 * @param plan         Buffer layout + input/param/writeback metadata from codegen.
 * @param kernels      Ordered kernel list; entries without WGSL are skipped.
 * @param lossBufferId Id of the buffer holding the scalar loss to read back.
 * @param opts         Optional pre-acquired device.
 * @throws Error if any shader fails validation (details logged to console).
 */
export async function createRuntime(
  plan: BufferPlan,
  kernels: KernelSpec[],
  lossBufferId: number,
  opts: RuntimeOpts = {},
): Promise<CompiledRuntime> {
  const device = opts.device ?? await acquireDevice()
  const queue = device.queue

  // ---- Allocate one GPUBuffer per BufferSpec --------------------------------
  // State buffers also get filled with their initValue at allocation time.
  const buffers = new Map<number, GPUBuffer>()
  for (const spec of plan.buffers) {
    const buf = device.createBuffer({
      size: spec.byteSize,
      usage: STORAGE_RW,
      label: spec.name ?? `t${spec.id}-${spec.kind}`,
    })
    buffers.set(spec.id, buf)
    if (spec.kind === 'state') {
      // Fill with initValue (typically 0). Float and int both 4 bytes per element.
      const elements = spec.byteSize / 4
      const init = spec.dtype === 'f32'
        ? new Float32Array(elements).fill(spec.initValue ?? 0)
        : new Int32Array(elements).fill(Math.trunc(spec.initValue ?? 0))
      queue.writeBuffer(buf, 0, init as unknown as BufferSource)
    }
  }

  // ---- Compile pipelines per kernel; cache by WGSL source -------------------
  // Push an error scope around each shader+pipeline creation so we can surface
  // the actual compile error rather than the cryptic "previous error" that
  // comes from using an invalid pipeline at dispatch time.
  const moduleCache = new Map<string, GPUShaderModule>()
  const pipelines: (GPUComputePipeline | null)[] = []
  type ErrorProbe = Promise<{ k: KernelSpec; module: GPUShaderModule; err: GPUError } | null>
  const probes: ErrorProbe[] = []
  for (const k of kernels) {
    // Kernels without WGSL (e.g. no-op entries) get a null pipeline slot so
    // indices stay aligned with `kernels`.
    if (!k.wgsl) { pipelines.push(null); continue }
    let module = moduleCache.get(k.wgsl)
    if (!module) {
      module = device.createShaderModule({ code: k.wgsl, label: k.opKind })
      moduleCache.set(k.wgsl, module)
    }
    device.pushErrorScope('validation')
    const pipeline = device.createComputePipeline({
      layout: 'auto',
      compute: { module, entryPoint: 'main' },
      label: k.opKind,
    })
    pipelines.push(pipeline)
    // popErrorScope resolves asynchronously; collect the probes and await them
    // all at once below instead of serializing on each pipeline.
    probes.push(device.popErrorScope().then(err => err ? { k, module: module!, err } : null))
  }
  const probeResults = await Promise.all(probes)
  const failures = probeResults.filter((p): p is { k: KernelSpec; module: GPUShaderModule; err: GPUError } => p != null)
  if (failures.length > 0) {
    const reports: string[] = []
    for (const { k, module, err } of failures) {
      const info = await module.getCompilationInfo()
      const messages = info.messages
        .map(m => ` L${m.lineNum}:${m.linePos} [${m.type}] ${m.message}`)
        .join('\n')
      reports.push(
        `[shader compile error] ${k.opKind} (op #${k.opIndex}): ${err.message}\n` +
        (messages || ' (no compilation messages)') +
        `\n--- WGSL ---\n${k.wgsl}\n-----------`,
      )
    }
    // eslint-disable-next-line no-console
    console.error(reports.join('\n\n'))
    throw new Error(`tensorgrad: ${failures.length} shader(s) failed to compile (see console).`)
  }

  // ---- Pre-build bind groups (static — buffer ids don't change per step) ---
  const bindGroups: (GPUBindGroup | null)[] = kernels.map((k, i) => {
    const pipeline = pipelines[i]
    if (!pipeline) return null
    return device.createBindGroup({
      layout: pipeline.getBindGroupLayout(0),
      entries: k.bindings.map((bufId, idx) => ({
        binding: idx,
        resource: { buffer: buffers.get(bufId)! },
      })),
    })
  })

  // ---- Loss readback staging buffer -----------------------------------------
  // NOTE(review): indexes plan.buffers by buffer id — assumes spec ids equal
  // array positions (step() below relies on the same convention for inputs);
  // confirm that invariant in buffers.ts.
  const lossSpec = plan.buffers[lossBufferId]!
  const lossReadback = device.createBuffer({ size: lossSpec.byteSize, usage: READBACK })

  // ---- step() ---------------------------------------------------------------
  // NOTE(review): reuses a single `lossReadback` staging buffer, so overlapping
  // step() calls would race on mapAsync — callers must await each step.
  async function step(inputs: Record<string, Int32Array | Float32Array>): Promise<number> {
    for (const [name, bufId] of plan.inputsByName) {
      const data = inputs[name]
      if (!data) throw new Error(`tensorgrad: missing input '${name}'`)
      const expectedBytes = plan.buffers[bufId]!.byteSize
      if (data.byteLength !== expectedBytes) {
        throw new Error(`tensorgrad: input '${name}' has ${data.byteLength} bytes, expected ${expectedBytes}`)
      }
      // Cast to BufferSource: typed arrays are accepted by writeBuffer at runtime
      // but TS may infer ArrayBufferLike (vs ArrayBuffer) under strict configs.
      queue.writeBuffer(buffers.get(bufId)!, 0, data as unknown as BufferSource)
    }

    const encoder = device.createCommandEncoder({ label: 'tensorgrad-step' })
    for (let i = 0; i < kernels.length; i++) {
      const k = kernels[i]!
      if (!k.wgsl || k.threads === 0) continue
      const pipeline = pipelines[i]!
      const bindGroup = bindGroups[i]!
      const pass = encoder.beginComputePass({ label: k.opKind })
      pass.setPipeline(pipeline)
      pass.setBindGroup(0, bindGroup)
      // WebGPU caps each dispatch dimension at 65535 workgroups. Split into 2D
      // when a kernel needs more than that on the X axis. Kernels compute their
      // global index as `gid.x + gid.y * (65535 * workgroup_size)`, matching the
      // stride we set here. For dispatches that fit in one row, gid.y is 0.
      const wgCount = Math.max(1, Math.ceil(k.threads / k.workgroupSize))
      const MAX_X = 65535
      const wgX = Math.min(wgCount, MAX_X)
      const wgY = Math.ceil(wgCount / MAX_X)
      pass.dispatchWorkgroups(wgX, wgY, 1)
      pass.end()
    }
    // After all dispatches: writebacks (Adam state, updated params).
    // copyBufferToBuffer is queued onto the same encoder so it's ordered after
    // all kernel dispatches.
    for (const wb of plan.writebacks) {
      encoder.copyBufferToBuffer(buffers.get(wb.source)!, 0, buffers.get(wb.dest)!, 0, wb.bytes)
    }
    encoder.copyBufferToBuffer(buffers.get(lossBufferId)!, 0, lossReadback, 0, lossSpec.byteSize)
    queue.submit([encoder.finish()])

    // slice(0) copies out of the mapped range so the view stays valid after unmap.
    await lossReadback.mapAsync(GPUMapMode.READ)
    const view = new Float32Array(lossReadback.getMappedRange().slice(0))
    lossReadback.unmap()
    return view[0]!
  }

  // ---- uploadParams ---------------------------------------------------------
  // Params absent from the record are silently skipped (partial upload is allowed).
  function uploadParams(params: Record<string, Float32Array>) {
    for (const [name, bufId] of plan.paramsByName) {
      const data = params[name]
      if (!data) continue
      queue.writeBuffer(buffers.get(bufId)!, 0, data as unknown as BufferSource)
    }
  }

  // ---- download helpers -----------------------------------------------------
  // One throwaway staging buffer per entry; all copies go through one encoder,
  // then each staging buffer is mapped, copied out, and destroyed.
  async function downloadFromMap(map: Map<string, number>): Promise<Record<string, Float32Array>> {
    const stagings: { name: string; buf: GPUBuffer; bytes: number }[] = []
    const encoder = device.createCommandEncoder({ label: 'tensorgrad-download' })
    for (const [name, bufId] of map) {
      const spec = plan.buffers[bufId]!
      const staging = device.createBuffer({ size: spec.byteSize, usage: READBACK })
      encoder.copyBufferToBuffer(buffers.get(bufId)!, 0, staging, 0, spec.byteSize)
      stagings.push({ name, buf: staging, bytes: spec.byteSize })
    }
    queue.submit([encoder.finish()])
    const out: Record<string, Float32Array> = {}
    for (const s of stagings) {
      await s.buf.mapAsync(GPUMapMode.READ)
      out[s.name] = new Float32Array(s.buf.getMappedRange().slice(0))
      s.buf.unmap()
      s.buf.destroy()
    }
    return out
  }

  return {
    uploadParams,
    downloadParams: () => downloadFromMap(plan.paramsByName),
    downloadParamGrads: () => downloadFromMap(plan.paramGradsByName),
    step,
    destroy: () => {
      for (const b of buffers.values()) b.destroy()
      lossReadback.destroy()
    },
  }
}
224
+
225
+ async function acquireDevice(): Promise<GPUDevice> {
226
+ if (typeof navigator === 'undefined' || !navigator.gpu) {
227
+ throw new Error('tensorgrad: WebGPU not available in this environment')
228
+ }
229
+ const adapter = await navigator.gpu.requestAdapter()
230
+ if (!adapter) throw new Error('tensorgrad: no WebGPU adapter')
231
+ return await adapter.requestDevice()
232
+ }
package/src/shape.ts ADDED
@@ -0,0 +1,263 @@
1
+ // Shape inference and validation for each op kind.
2
+ //
3
+ // Every op in src/ops.ts validates its inputs and computes its output shape
4
+ // through helpers here. Errors throw with the captured call-site so the
5
+ // stack trace points at the user's line, not into the library.
6
+ //
7
+ // Broadcasting rules (deliberately limited):
8
+ //   * For element-wise binops (add/sub/mul/div), `broadcastTrailing` applies
9
+ //     standard right-aligned (NumPy-style) broadcasting: shapes are aligned
10
+ //     from the trailing axis, and per axis the dims must match or one must
11
+ //     be 1. Examples ALLOWED:   [B, T, D] op [D]       → [B, T, D]
12
+ //                               [B, T, D] op [1, D]    → [B, T, D]
13
+ //                               [B, T, D] op [B, T, D] → [B, T, D]
14
+ //                               [B, T, D] op [T, D]    → [B, T, D] (T, D match)
15
+ //     Example REJECTED:         [B, T, D] op [B]       (trailing axes differ)
16
+ // The restriction makes codegen and autograd much simpler and covers every
17
+ // broadcast pattern in our transformer (biases, layernorm gain/bias, masks).
18
+
19
+ import type { Shape, CallSite } from './ir.js'
20
+ import { formatSite } from './ir.js'
21
+
22
+ // ============================================================================
23
+ // Errors
24
+ // ============================================================================
25
+
26
+ export class ShapeError extends Error {
27
+ constructor(message: string, site: CallSite | null) {
28
+ const formatted = site ? `${message}\n at ${formatSite(site)}` : message
29
+ super(formatted)
30
+ this.name = 'ShapeError'
31
+ }
32
+ }
33
+
34
// Convenience thrower. Annotated `never` so TypeScript treats calls as
// terminating control flow (callers rely on narrowing after `fail`).
function fail(message: string, site: CallSite | null): never {
  throw new ShapeError(message, site)
}
37
+
38
+ // ============================================================================
39
+ // Shape utilities
40
+ // ============================================================================
41
+
42
+ export function shapesEqual(a: Shape, b: Shape): boolean {
43
+ if (a.length !== b.length) return false
44
+ for (let i = 0; i < a.length; i++) if (a[i] !== b[i]) return false
45
+ return true
46
+ }
47
+
48
+ export function shapeSize(shape: Shape): number {
49
+ let n = 1
50
+ for (const d of shape) n *= d
51
+ return n
52
+ }
53
+
54
+ export function showShape(shape: Shape): string {
55
+ return `[${shape.join(', ')}]`
56
+ }
57
+
58
+ // Standard right-aligned NumPy-style broadcasting. Pad the shorter shape with
59
+ // leading 1s, then per-axis: equal dims unify, size-1 dims broadcast on either
60
+ // side, otherwise incompatible. Returns the resulting shape or null.
61
+ export function broadcastTrailing(a: Shape, b: Shape): Shape | null {
62
+ const rank = Math.max(a.length, b.length)
63
+ const out: number[] = new Array(rank)
64
+ for (let i = 0; i < rank; i++) {
65
+ const ai = i - (rank - a.length)
66
+ const bi = i - (rank - b.length)
67
+ const av = ai < 0 ? 1 : a[ai]!
68
+ const bv = bi < 0 ? 1 : b[bi]!
69
+ if (av === bv) out[i] = av
70
+ else if (av === 1) out[i] = bv
71
+ else if (bv === 1) out[i] = av
72
+ else return null
73
+ }
74
+ return out
75
+ }
76
+
77
+ // ============================================================================
78
+ // Per-op shape rules
79
+ // ============================================================================
80
+ //
81
+ // Each rule takes the input shapes and returns the output shape, or throws.
82
+ // All rules accept a `site` for error attribution.
83
+
84
+ export function inferElementwiseBinop(
85
+ opName: string, aShape: Shape, bShape: Shape, site: CallSite | null,
86
+ ): Shape {
87
+ const result = broadcastTrailing(aShape, bShape)
88
+ if (!result) {
89
+ fail(
90
+ `${opName}: incompatible shapes ${showShape(aShape)} and ${showShape(bShape)}. ` +
91
+ `Trailing-suffix broadcasting only — the smaller shape must be a suffix of the larger, ` +
92
+ `with size-1 axes broadcasting to any size.`,
93
+ site,
94
+ )
95
+ }
96
+ return result
97
+ }
98
+
99
// Unary element-wise ops preserve the input shape exactly. opName/site are
// unused because shape inference for a unary op cannot fail; they are kept
// so all rules share one call signature.
export function inferUnary(_opName: string, aShape: Shape, _site: CallSite | null): Shape {
  return aShape
}
102
+
103
+ export function inferMeanLast(opName: string, aShape: Shape, site: CallSite | null): Shape {
104
+ if (aShape.length === 0) fail(`${opName}: cannot reduce a 0-d tensor`, site)
105
+ // keepdims=true: replace last axis with 1.
106
+ return [...aShape.slice(0, -1), 1]
107
+ }
108
+
109
+ export function inferSumLast(opName: string, aShape: Shape, site: CallSite | null): Shape {
110
+ if (aShape.length === 0) fail(`${opName}: cannot reduce a 0-d tensor`, site)
111
+ // keepdims=false: drop the last axis.
112
+ return aShape.slice(0, -1)
113
+ }
114
+
115
// reshape: validate `newShape` against the input's element count and resolve
// an optional single -1 placeholder to the inferred dimension.
export function inferReshape(opName: string, aShape: Shape, newShape: Shape, site: CallSite | null): Shape {
  // Validate -1 placeholder (at most one allowed) and total size match.
  let inferIdx = -1   // position of the -1 placeholder, if any
  let knownSize = 1   // product of all explicitly-given dims
  for (let i = 0; i < newShape.length; i++) {
    const d = newShape[i]!
    if (d === -1) {
      if (inferIdx !== -1) fail(`${opName}: at most one -1 dim allowed in newShape ${showShape(newShape)}`, site)
      inferIdx = i
    } else if (d <= 0) {
      // Zero-sized dims are rejected too, not just negatives.
      fail(`${opName}: invalid dim ${d} in newShape ${showShape(newShape)}`, site)
    } else {
      knownSize *= d
    }
  }
  const totalIn = shapeSize(aShape)
  const out = [...newShape]
  if (inferIdx !== -1) {
    // The inferred dim must divide the input size evenly.
    if (totalIn % knownSize !== 0) {
      fail(`${opName}: cannot reshape ${showShape(aShape)} (size ${totalIn}) to ${showShape(newShape)} — known dims multiply to ${knownSize}`, site)
    }
    out[inferIdx] = totalIn / knownSize
  } else if (knownSize !== totalIn) {
    fail(`${opName}: size mismatch — input ${showShape(aShape)} has ${totalIn} elements but newShape ${showShape(newShape)} has ${knownSize}`, site)
  }
  return out
}
142
+
143
+ export function inferTranspose(opName: string, aShape: Shape, perm: readonly number[], site: CallSite | null): Shape {
144
+ if (perm.length !== aShape.length) {
145
+ fail(`${opName}: perm length ${perm.length} must equal input rank ${aShape.length}`, site)
146
+ }
147
+ const seen = new Set<number>()
148
+ for (const p of perm) {
149
+ if (p < 0 || p >= aShape.length) fail(`${opName}: perm index ${p} out of range for rank ${aShape.length}`, site)
150
+ if (seen.has(p)) fail(`${opName}: perm has duplicate index ${p}`, site)
151
+ seen.add(p)
152
+ }
153
+ return perm.map(p => aShape[p]!)
154
+ }
155
+
156
+ // matmul: a [..., M, K] · b [K, N] → [..., M, N]. b is unbatched.
157
+ export function inferMatmul(opName: string, aShape: Shape, bShape: Shape, site: CallSite | null): Shape {
158
+ if (aShape.length < 2) fail(`${opName}: lhs must have rank >= 2, got ${showShape(aShape)}`, site)
159
+ if (bShape.length !== 2) fail(`${opName}: rhs must have rank 2, got ${showShape(bShape)} — use matmulBatched for batched rhs`, site)
160
+ const M = aShape[aShape.length - 2]!
161
+ const Ka = aShape[aShape.length - 1]!
162
+ const Kb = bShape[0]!
163
+ const N = bShape[1]!
164
+ if (Ka !== Kb) fail(`${opName}: inner dims don't match — ${showShape(aShape)} · ${showShape(bShape)} (last axis of lhs = ${Ka}, first axis of rhs = ${Kb})`, site)
165
+ return [...aShape.slice(0, -2), M, N]
166
+ }
167
+
168
// matmul_batched: a [..., M, K] · b [..., K, N] → [..., M, N]. Both have leading batch dims.
// Ranks must be equal and batch dims must match exactly — no batch broadcasting.
export function inferMatmulBatched(opName: string, aShape: Shape, bShape: Shape, site: CallSite | null): Shape {
  if (aShape.length < 2 || bShape.length < 2) {
    fail(`${opName}: both inputs must have rank >= 2, got ${showShape(aShape)} and ${showShape(bShape)}`, site)
  }
  if (aShape.length !== bShape.length) {
    fail(`${opName}: ranks must match (got ${aShape.length} vs ${bShape.length}). Reshape if you need different batch dims.`, site)
  }
  // Batch dims = everything except the last two axes; compared element-wise.
  const aBatch = aShape.slice(0, -2)
  const bBatch = bShape.slice(0, -2)
  for (let i = 0; i < aBatch.length; i++) {
    if (aBatch[i] !== bBatch[i]) {
      fail(`${opName}: batch dims must match — ${showShape(aShape)} vs ${showShape(bShape)}`, site)
    }
  }
  const M = aShape[aShape.length - 2]!
  const Ka = aShape[aShape.length - 1]!
  const Kb = bShape[bShape.length - 2]!
  const N = bShape[bShape.length - 1]!
  if (Ka !== Kb) fail(`${opName}: inner dims don't match — last axis of lhs = ${Ka}, second-to-last of rhs = ${Kb}`, site)
  return [...aBatch, M, N]
}
190
+
191
+ export function inferOneHot(opName: string, indicesShape: Shape, depth: number, site: CallSite | null): Shape {
192
+ if (depth <= 0) fail(`${opName}: depth must be positive, got ${depth}`, site)
193
+ return [...indicesShape, depth]
194
+ }
195
+
196
+ // where_causal preserves shape but requires the last two axes to be square.
197
+ export function inferWhereCausal(opName: string, aShape: Shape, site: CallSite | null): Shape {
198
+ if (aShape.length < 2) fail(`${opName}: requires rank >= 2, got ${showShape(aShape)}`, site)
199
+ const m = aShape[aShape.length - 2]!
200
+ const n = aShape[aShape.length - 1]!
201
+ if (m !== n) fail(`${opName}: last two axes must be equal (square mask), got ${showShape(aShape)}`, site)
202
+ return aShape
203
+ }
204
+
205
+ export function inferSliceLastRange(opName: string, aShape: Shape, start: number, end: number, site: CallSite | null): Shape {
206
+ if (aShape.length === 0) fail(`${opName}: cannot slice 0-d tensor`, site)
207
+ const last = aShape[aShape.length - 1]!
208
+ if (start < 0 || end > last || start >= end) {
209
+ fail(`${opName}: invalid range [${start}, ${end}) for last axis of size ${last}`, site)
210
+ }
211
+ return [...aShape.slice(0, -1), end - start]
212
+ }
213
+
214
+ // broadcast_to: validate that `aShape` can broadcast to `targetShape` under
215
+ // right-aligned NumPy rules. Returns targetShape on success.
216
+ export function inferBroadcastTo(opName: string, aShape: Shape, targetShape: Shape, site: CallSite | null): Shape {
217
+ if (aShape.length > targetShape.length) {
218
+ fail(`${opName}: source rank ${aShape.length} > target rank ${targetShape.length}`, site)
219
+ }
220
+ const offset = targetShape.length - aShape.length
221
+ for (let i = 0; i < aShape.length; i++) {
222
+ const av = aShape[i]!
223
+ const tv = targetShape[offset + i]!
224
+ if (av !== tv && av !== 1) {
225
+ fail(`${opName}: cannot broadcast ${showShape(aShape)} to ${showShape(targetShape)} — axis ${i} (size ${av}) doesn't match target axis ${offset + i} (size ${tv}) and isn't 1`, site)
226
+ }
227
+ }
228
+ return targetShape
229
+ }
230
+
231
+ // sum_to_shape: validate that `targetShape` is a valid right-aligned reduction
232
+ // of `aShape` (i.e., aShape can have been produced by broadcasting targetShape).
233
+ export function inferSumToShape(opName: string, aShape: Shape, targetShape: Shape, site: CallSite | null): Shape {
234
+ if (targetShape.length > aShape.length) {
235
+ fail(`${opName}: target rank ${targetShape.length} > source rank ${aShape.length}`, site)
236
+ }
237
+ const offset = aShape.length - targetShape.length
238
+ for (let i = 0; i < targetShape.length; i++) {
239
+ const av = aShape[offset + i]!
240
+ const tv = targetShape[i]!
241
+ if (av !== tv && tv !== 1) {
242
+ fail(`${opName}: cannot sum-reduce ${showShape(aShape)} to ${showShape(targetShape)} — target axis ${i} (size ${tv}) must be 1 or match source`, site)
243
+ }
244
+ }
245
+ return targetShape
246
+ }
247
+
248
+ // Three-way broadcast for `where(cond, a, b)`. All three shapes must broadcast
249
+ // to a common shape under standard NumPy rules.
250
+ export function inferWhere(opName: string, condShape: Shape, aShape: Shape, bShape: Shape, site: CallSite | null): Shape {
251
+ const ab = broadcastTrailing(aShape, bShape)
252
+ if (!ab) fail(`${opName}: a/b incompatible: ${showShape(aShape)} vs ${showShape(bShape)}`, site)
253
+ const result = broadcastTrailing(condShape, ab)
254
+ if (!result) fail(`${opName}: cond ${showShape(condShape)} incompatible with broadcast(a, b) ${showShape(ab)}`, site)
255
+ return result
256
+ }
257
+
258
+ export function inferReluGrad(opName: string, xShape: Shape, dyShape: Shape, site: CallSite | null): Shape {
259
+ if (!shapesEqual(xShape, dyShape)) {
260
+ fail(`${opName}: x and dy must have matching shapes, got ${showShape(xShape)} and ${showShape(dyShape)}`, site)
261
+ }
262
+ return xShape
263
+ }
package/src/trace.ts ADDED
@@ -0,0 +1,101 @@
1
+ // Trace driver. Holds the "current graph" in module-local state so user code
2
+ // can call ops without threading a graph parameter through every function.
3
+ //
4
+ // Usage:
5
+ //
6
+ // const graph = trace(() => {
7
+ // const x = tensorInput('x', [B, T], 'i32')
8
+ // const w = paramInput('w', [V, D], 'f32')
9
+ // // ... user computation building tensors ...
10
+ // return finalLossTensor
11
+ // })
12
+ //
13
+ // `trace` is single-threaded and re-entrant only via nested calls (which share
14
+ // the outer graph — but we don't currently have a use for nesting). Calling an
15
+ // op outside a `trace(...)` block is an error.
16
+
17
+ import type { Graph, Tensor, Shape, Dtype } from './ir.js'
18
+ import { makeGraph, addOp, captureSite } from './ir.js'
19
+
20
+ // Module-local: the graph being built right now, or null if no trace is active.
21
+ let _current: Graph | null = null
22
+
23
+ export function currentGraph(): Graph {
24
+ if (!_current) {
25
+ throw new Error(
26
+ 'tensorgrad: ops can only be called inside trace(). ' +
27
+ 'Did you forget to wrap your forward pass?',
28
+ )
29
+ }
30
+ return _current
31
+ }
32
+
33
+ // Run `fn` with a fresh graph as the current one; capture and return the graph.
34
+ // `fn` must return the tensor (or array of tensors) to mark as graph outputs.
35
+ export function trace(fn: () => Tensor | Tensor[]): Graph {
36
+ if (_current) {
37
+ throw new Error('tensorgrad: nested trace() is not supported')
38
+ }
39
+ const g = makeGraph()
40
+ _current = g
41
+ try {
42
+ const result = fn()
43
+ const outputs = Array.isArray(result) ? result : [result]
44
+ for (const t of outputs) {
45
+ ;(g.outputs as number[]).push(t.id)
46
+ }
47
+ } finally {
48
+ _current = null
49
+ }
50
+ return g
51
+ }
52
+
53
// Re-enter an existing graph to append more ops. Used by autograd to add
// backward ops to a graph that's already been traced. `fn` runs with the
// supplied graph as the current one; any ops it calls append to that graph.
// Returns whatever `fn` returns. The current-graph slot is cleared even
// when `fn` throws.
export function traceInto<T>(g: Graph, fn: () => T): T {
  if (_current) {
    throw new Error('tensorgrad: traceInto() called while another trace is active')
  }
  _current = g
  try {
    return fn()
  } finally {
    _current = null
  }
}
68
+
69
+ // ---- Leaf tensor builders --------------------------------------------------
70
+ // Inputs are added to the graph as `param_input` or `tensor_input` op nodes.
71
+ // Their .source on the Tensor points at that node so codegen knows where to
72
+ // bind external data.
73
+
74
+ export function paramInput(name: string, shape: Shape, dtype: Dtype = 'f32'): Tensor {
75
+ const g = currentGraph()
76
+ if (g.ops.some(op => (op.kind === 'param_input' || op.kind === 'tensor_input') && op.name === name)) {
77
+ throw new Error(`tensorgrad: input name '${name}' already used in this trace`)
78
+ }
79
+ const site = captureSite('paramInput')
80
+ return addOp(g, 'param_input', shape, dtype, site, { name } as any)
81
+ }
82
+
83
+ export function tensorInput(name: string, shape: Shape, dtype: Dtype = 'f32'): Tensor {
84
+ const g = currentGraph()
85
+ if (g.ops.some(op => (op.kind === 'param_input' || op.kind === 'tensor_input') && op.name === name)) {
86
+ throw new Error(`tensorgrad: input name '${name}' already used in this trace`)
87
+ }
88
+ const site = captureSite('tensorInput')
89
+ return addOp(g, 'tensor_input', shape, dtype, site, { name } as any)
90
+ }
91
+
92
+ // Persistent state buffer. Allocated at compile time, zero-(or initValue-)initialized,
93
+ // and updated across step() calls via writebacks declared by the optimizer helper.
94
+ export function stateInput(name: string, shape: Shape, dtype: Dtype = 'f32', initValue = 0): Tensor {
95
+ const g = currentGraph()
96
+ if (g.ops.some(op => op.kind === 'state_input' && op.name === name)) {
97
+ throw new Error(`tensorgrad: state name '${name}' already used in this trace`)
98
+ }
99
+ const site = captureSite('stateInput')
100
+ return addOp(g, 'state_input', shape, dtype, site, { name, initValue } as any)
101
+ }