tensorgrad 0.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +21 -0
- package/README.md +121 -0
- package/SPEC.md +293 -0
- package/dist/adam.d.ts +31 -0
- package/dist/adam.d.ts.map +1 -0
- package/dist/adam.js +66 -0
- package/dist/adam.js.map +1 -0
- package/dist/buffers.d.ts +56 -0
- package/dist/buffers.d.ts.map +1 -0
- package/dist/buffers.js +114 -0
- package/dist/buffers.js.map +1 -0
- package/dist/codegen.d.ts +23 -0
- package/dist/codegen.d.ts.map +1 -0
- package/dist/codegen.js +709 -0
- package/dist/codegen.js.map +1 -0
- package/dist/compile.d.ts +53 -0
- package/dist/compile.d.ts.map +1 -0
- package/dist/compile.js +76 -0
- package/dist/compile.js.map +1 -0
- package/dist/grad.d.ts +8 -0
- package/dist/grad.d.ts.map +1 -0
- package/dist/grad.js +404 -0
- package/dist/grad.js.map +1 -0
- package/dist/index.d.ts +12 -0
- package/dist/index.d.ts.map +1 -0
- package/dist/index.js +37 -0
- package/dist/index.js.map +1 -0
- package/dist/ir.d.ts +204 -0
- package/dist/ir.d.ts.map +1 -0
- package/dist/ir.js +60 -0
- package/dist/ir.js.map +1 -0
- package/dist/module.d.ts +21 -0
- package/dist/module.d.ts.map +1 -0
- package/dist/module.js +113 -0
- package/dist/module.js.map +1 -0
- package/dist/ops.d.ts +35 -0
- package/dist/ops.d.ts.map +1 -0
- package/dist/ops.js +270 -0
- package/dist/ops.js.map +1 -0
- package/dist/runtime.d.ts +26 -0
- package/dist/runtime.d.ts.map +1 -0
- package/dist/runtime.js +190 -0
- package/dist/runtime.js.map +1 -0
- package/dist/shape.d.ts +24 -0
- package/dist/shape.d.ts.map +1 -0
- package/dist/shape.js +259 -0
- package/dist/shape.js.map +1 -0
- package/dist/trace.d.ts +8 -0
- package/dist/trace.d.ts.map +1 -0
- package/dist/trace.js +93 -0
- package/dist/trace.js.map +1 -0
- package/package.json +62 -0
- package/src/adam.ts +95 -0
- package/src/buffers.ts +173 -0
- package/src/codegen.ts +758 -0
- package/src/compile.ts +120 -0
- package/src/grad.ts +459 -0
- package/src/index.ts +40 -0
- package/src/ir.ts +197 -0
- package/src/module.ts +126 -0
- package/src/ops.ts +311 -0
- package/src/runtime.ts +232 -0
- package/src/shape.ts +263 -0
- package/src/trace.ts +101 -0
package/src/runtime.ts
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
1
|
+
// WebGPU runtime. Reads a BufferPlan + KernelSpec[] (produced by codegen),
|
|
2
|
+
// allocates real GPU buffers and pipelines, and provides a `step()` method
|
|
3
|
+
// that uploads inputs, dispatches all kernels, and reads back outputs.
|
|
4
|
+
//
|
|
5
|
+
// Browser-only: this module needs `navigator.gpu` at runtime.
|
|
6
|
+
|
|
7
|
+
import type { BufferPlan } from './buffers.js'
|
|
8
|
+
import type { KernelSpec } from './codegen.js'
|
|
9
|
+
|
|
10
|
+
// The WebGPU map-mode flags are a browser-provided global. Per the typings in
// use here, lib.dom carries the WebGPU types but not this runtime value, so
// declare just the members this module reads. Nothing is emitted for it.
declare const GPUMapMode: Readonly<{ READ: number; WRITE: number }>
|
|
13
|
+
|
|
14
|
+
/** Handle returned by `createRuntime()` for driving a compiled graph on the GPU. */
export interface CompiledRuntime {
  /**
   * Run one full forward+backward step:
   *   1. write each entry of `inputs` (tokens, targets, masks) into its input buffer,
   *   2. dispatch every kernel in order,
   *   3. read the loss scalar back from the GPU.
   * Resolves to the loss as a plain JS number.
   */
  step(inputs: Record<string, Int32Array | Float32Array>): Promise<number>
  /** Write one or more parameter arrays into their GPU buffers. */
  uploadParams(params: Record<string, Float32Array>): void
  /** Copy every parameter back to the CPU as a Float32Array (used by UI panels). */
  downloadParams(): Promise<Record<string, Float32Array>>
  /** Copy every parameter gradient back to the CPU; mostly for verification and debugging. */
  downloadParamGrads(): Promise<Record<string, Float32Array>>
  /** Release the GPU resources held by this runtime. */
  destroy(): void
}
|
|
32
|
+
|
|
33
|
+
/** Optional settings accepted by `createRuntime()`. */
export interface RuntimeOpts {
  /**
   * An already-acquired GPUDevice to run on. When this is absent, the runtime
   * requests an adapter and a device on its own.
   */
  device?: GPUDevice
}
|
|
37
|
+
|
|
38
|
+
// GPUBufferUsage bit values copied from the WebGPU spec as literals. This keeps
// the module importable under Node for codegen-only use: the browser exposes
// GPUBufferUsage as a global, but touching it at module scope would throw
// before any browser-only code gets a chance to run.

/** STORAGE (0x80) | COPY_DST (0x08) | COPY_SRC (0x04): storage buffer that can be copied into and out of. */
const STORAGE_RW = 0x0080 | 0x0008 | 0x0004

/** MAP_READ (0x01) | COPY_DST (0x08): staging buffer the CPU maps to read results back. */
const READBACK = 0x0001 | 0x0008
|
|
43
|
+
|
|
44
|
+
/**
 * Allocate GPU buffers and compute pipelines for a compiled graph and return a
 * handle that runs training steps on them.
 *
 * @param plan - Buffer layout from codegen. Buffer specs are looked up as
 *   `plan.buffers[id]` by buffer id.
 * @param kernels - Kernels in dispatch order. Kernels without WGSL, or with
 *   zero threads, are never dispatched.
 * @param lossBufferId - Id of the buffer whose first f32 element `step()` returns.
 * @param opts - Optional pre-acquired device.
 * @throws Error if `lossBufferId` is not in the plan, if any WGSL fails to
 *   compile (a full report is logged to the console), or if a kernel needs
 *   more workgroups than a 2D dispatch can express. On a setup failure, the
 *   buffers allocated so far are destroyed, and so is the device when this
 *   call acquired it.
 */
export async function createRuntime(
  plan: BufferPlan,
  kernels: KernelSpec[],
  lossBufferId: number,
  opts: RuntimeOpts = {},
): Promise<CompiledRuntime> {
  // Validate before touching the GPU so a bad id cannot leak anything.
  const lossSpec = plan.buffers[lossBufferId]
  if (!lossSpec) {
    throw new Error(`tensorgrad: loss buffer id ${lossBufferId} is not in the buffer plan`)
  }
  const lossBytes = lossSpec.byteSize

  // A caller-supplied device may be shared with other work, so only a device
  // acquired here is destroyed on teardown.
  const ownsDevice = opts.device === undefined
  const device = opts.device ?? await acquireDevice()
  const queue = device.queue

  const buffers = new Map<number, GPUBuffer>()
  const bufferFor = (id: number): GPUBuffer => {
    const buf = buffers.get(id)
    if (!buf) throw new Error(`tensorgrad: no GPU buffer allocated for buffer id ${id}`)
    return buf
  }
  const specBytes = (id: number): number => {
    const spec = plan.buffers[id]
    if (!spec) throw new Error(`tensorgrad: buffer id ${id} is not in the buffer plan`)
    return spec.byteSize
  }

  let dispatches: Dispatch[]
  let lossReadback: GPUBuffer
  try {
    allocateBuffers(device, plan, buffers)
    dispatches = await compileDispatches(device, kernels, bufferFor)
    lossReadback = device.createBuffer({ size: lossBytes, usage: READBACK, label: 'tensorgrad-loss-readback' })
  } catch (err) {
    // Don't leak GPU memory when setup fails part-way through.
    for (const buf of buffers.values()) buf.destroy()
    if (ownsDevice) device.destroy()
    throw err
  }

  // NOTE: the loss is read through a single staging buffer. Calling step()
  // again before the previous promise settles makes mapAsync fail, so callers
  // should await each step.
  async function step(inputs: Record<string, Int32Array | Float32Array>): Promise<number> {
    for (const [name, bufId] of plan.inputsByName) {
      const data = inputs[name]
      if (!data) throw new Error(`tensorgrad: missing input '${name}'`)
      const expectedBytes = specBytes(bufId)
      if (data.byteLength !== expectedBytes) {
        throw new Error(`tensorgrad: input '${name}' has ${data.byteLength} bytes, expected ${expectedBytes}`)
      }
      // Cast to BufferSource: typed arrays are accepted by writeBuffer at runtime
      // but TS may infer ArrayBufferLike (vs ArrayBuffer) under strict configs.
      queue.writeBuffer(bufferFor(bufId), 0, data as unknown as BufferSource)
    }

    const encoder = device.createCommandEncoder({ label: 'tensorgrad-step' })
    for (const d of dispatches) {
      const pass = encoder.beginComputePass({ label: d.label })
      pass.setPipeline(d.pipeline)
      pass.setBindGroup(0, d.bindGroup)
      pass.dispatchWorkgroups(d.workgroupsX, d.workgroupsY, 1)
      pass.end()
    }
    // Writebacks (Adam state, updated params) go on the same encoder, so they
    // are ordered after every kernel dispatch.
    for (const wb of plan.writebacks) {
      encoder.copyBufferToBuffer(bufferFor(wb.source), 0, bufferFor(wb.dest), 0, wb.bytes)
    }
    encoder.copyBufferToBuffer(bufferFor(lossBufferId), 0, lossReadback, 0, lossBytes)
    queue.submit([encoder.finish()])

    await lossReadback.mapAsync(GPUMapMode.READ)
    try {
      return new Float32Array(lossReadback.getMappedRange().slice(0))[0]!
    } finally {
      lossReadback.unmap()
    }
  }

  function uploadParams(params: Record<string, Float32Array>): void {
    // Validate every provided param before writing any, so a size mismatch
    // leaves GPU state untouched instead of half-updated. Params missing from
    // `params` are skipped, and unknown names are ignored.
    const writes: [GPUBuffer, Float32Array][] = []
    for (const [name, bufId] of plan.paramsByName) {
      const data = params[name]
      if (!data) continue
      const expectedBytes = specBytes(bufId)
      if (data.byteLength !== expectedBytes) {
        throw new Error(`tensorgrad: param '${name}' has ${data.byteLength} bytes, expected ${expectedBytes}`)
      }
      writes.push([bufferFor(bufId), data])
    }
    for (const [buf, data] of writes) {
      queue.writeBuffer(buf, 0, data as unknown as BufferSource)
    }
  }

  // Copy each named buffer into its own staging buffer with one submit, map
  // them all concurrently, and always free the staging buffers.
  async function downloadFromMap(map: Map<string, number>): Promise<Record<string, Float32Array>> {
    const stagings: { name: string; buf: GPUBuffer }[] = []
    try {
      const encoder = device.createCommandEncoder({ label: 'tensorgrad-download' })
      for (const [name, bufId] of map) {
        const bytes = specBytes(bufId)
        const staging = device.createBuffer({ size: bytes, usage: READBACK })
        stagings.push({ name, buf: staging })
        encoder.copyBufferToBuffer(bufferFor(bufId), 0, staging, 0, bytes)
      }
      queue.submit([encoder.finish()])
      await Promise.all(stagings.map(s => s.buf.mapAsync(GPUMapMode.READ)))
      const out: Record<string, Float32Array> = {}
      for (const s of stagings) {
        out[s.name] = new Float32Array(s.buf.getMappedRange().slice(0))
        s.buf.unmap()
      }
      return out
    } finally {
      for (const s of stagings) s.buf.destroy()
    }
  }

  return {
    uploadParams,
    downloadParams: () => downloadFromMap(plan.paramsByName),
    downloadParamGrads: () => downloadFromMap(plan.paramGradsByName),
    step,
    destroy: () => {
      for (const buf of buffers.values()) buf.destroy()
      lossReadback.destroy()
      if (ownsDevice) device.destroy()
    },
  }
}

/** WebGPU caps each dispatch dimension at 65535 workgroups. */
const MAX_WORKGROUPS_PER_DIM = 65535

/** A kernel ready to encode: its pipeline, bind group and precomputed workgroup counts. */
interface Dispatch {
  label: string
  pipeline: GPUComputePipeline
  bindGroup: GPUBindGroup
  workgroupsX: number
  workgroupsY: number
}

/** A kernel whose shader module or pipeline failed validation. */
interface CompileFailure {
  k: KernelSpec
  module: GPUShaderModule
  err: GPUError
}

/**
 * Create one storage buffer per BufferSpec, filling 'state' buffers with their
 * initValue. Buffers are added to `out` as they are created, so the caller can
 * free them if a later step fails.
 */
function allocateBuffers(device: GPUDevice, plan: BufferPlan, out: Map<number, GPUBuffer>): void {
  for (const spec of plan.buffers) {
    const buf = device.createBuffer({
      size: spec.byteSize,
      usage: STORAGE_RW,
      label: spec.name ?? `t${spec.id}-${spec.kind}`,
    })
    out.set(spec.id, buf)
    if (spec.kind === 'state') {
      // Fill with initValue (typically 0). f32 and i32 are both 4 bytes per element.
      const elements = spec.byteSize / 4
      const init = spec.dtype === 'f32'
        ? new Float32Array(elements).fill(spec.initValue ?? 0)
        : new Int32Array(elements).fill(Math.trunc(spec.initValue ?? 0))
      device.queue.writeBuffer(buf, 0, init as unknown as BufferSource)
    }
  }
}

/**
 * Compile every kernel's WGSL, reusing one shader module per distinct source.
 * Then build a bind group and 2D workgroup counts for each kernel that
 * actually dispatches.
 *
 * Shader-module and pipeline creation both run inside a validation error
 * scope. The real compile error therefore surfaces here, instead of as the
 * cryptic "previous error" that comes from using an invalid pipeline at
 * dispatch time.
 *
 * @throws Error if any pipeline fails validation (a report is logged to the
 *   console) or a kernel needs more than 65535 x 65535 workgroups.
 */
async function compileDispatches(
  device: GPUDevice,
  kernels: readonly KernelSpec[],
  bufferFor: (id: number) => GPUBuffer,
): Promise<Dispatch[]> {
  const moduleCache = new Map<string, GPUShaderModule>()
  const pipelines: (GPUComputePipeline | null)[] = []
  const probes: Promise<CompileFailure | null>[] = []
  for (const k of kernels) {
    if (!k.wgsl) {
      pipelines.push(null)
      continue
    }
    device.pushErrorScope('validation')
    let module = moduleCache.get(k.wgsl)
    if (!module) {
      module = device.createShaderModule({ code: k.wgsl, label: k.opKind })
      moduleCache.set(k.wgsl, module)
    }
    pipelines.push(device.createComputePipeline({
      layout: 'auto',
      compute: { module, entryPoint: 'main' },
      label: k.opKind,
    }))
    const shaderModule = module
    probes.push(device.popErrorScope().then(err => (err ? { k, module: shaderModule, err } : null)))
  }

  const failures = (await Promise.all(probes)).filter((p): p is CompileFailure => p !== null)
  if (failures.length > 0) {
    const reports = await Promise.all(failures.map(describeCompileFailure))
    // eslint-disable-next-line no-console
    console.error(reports.join('\n\n'))
    throw new Error(`tensorgrad: ${failures.length} shader(s) failed to compile (see console).`)
  }

  const dispatches: Dispatch[] = []
  for (let i = 0; i < kernels.length; i++) {
    const k = kernels[i]!
    const pipeline = pipelines[i]
    if (!pipeline || k.threads === 0) continue
    // Kernels compute their global index as `gid.x + gid.y * (65535 * workgroup_size)`,
    // so any workgroups beyond 65535 on X spill into additional Y rows. For
    // dispatches that fit in one row, Y is 1.
    const wgCount = Math.max(1, Math.ceil(k.threads / k.workgroupSize))
    const workgroupsX = Math.min(wgCount, MAX_WORKGROUPS_PER_DIM)
    const workgroupsY = Math.ceil(wgCount / MAX_WORKGROUPS_PER_DIM)
    if (workgroupsY > MAX_WORKGROUPS_PER_DIM) {
      throw new Error(
        `tensorgrad: ${k.opKind} (op #${k.opIndex}) needs ${wgCount} workgroups, ` +
        `more than a ${MAX_WORKGROUPS_PER_DIM}x${MAX_WORKGROUPS_PER_DIM} dispatch allows`,
      )
    }
    const bindGroup = device.createBindGroup({
      layout: pipeline.getBindGroupLayout(0),
      entries: k.bindings.map((bufId, binding) => ({
        binding,
        resource: { buffer: bufferFor(bufId) },
      })),
    })
    dispatches.push({ label: k.opKind, pipeline, bindGroup, workgroupsX, workgroupsY })
  }
  return dispatches
}

/** Format one compile failure: the validation error, per-line compiler messages, and the offending WGSL. */
async function describeCompileFailure({ k, module, err }: CompileFailure): Promise<string> {
  const info = await module.getCompilationInfo()
  const messages = info.messages
    .map(m => ` L${m.lineNum}:${m.linePos} [${m.type}] ${m.message}`)
    .join('\n')
  return (
    `[shader compile error] ${k.opKind} (op #${k.opIndex}): ${err.message}\n` +
    (messages || ' (no compilation messages)') +
    `\n--- WGSL ---\n${k.wgsl}\n-----------`
  )
}
|
|
224
|
+
|
|
225
|
+
/**
 * Obtain a GPUDevice from the browser's default WebGPU adapter.
 *
 * @throws Error when `navigator.gpu` is unavailable (e.g. under Node) or the
 *   browser hands back no adapter.
 */
async function acquireDevice(): Promise<GPUDevice> {
  const hasWebGpu = typeof navigator !== 'undefined' && Boolean(navigator.gpu)
  if (!hasWebGpu) {
    throw new Error('tensorgrad: WebGPU not available in this environment')
  }
  const adapter = await navigator.gpu.requestAdapter()
  if (!adapter) {
    throw new Error('tensorgrad: no WebGPU adapter')
  }
  const device = await adapter.requestDevice()
  return device
}
|
package/src/shape.ts
ADDED
|
@@ -0,0 +1,263 @@
|
|
|
1
|
+
// Shape inference and validation for each op kind.
|
|
2
|
+
//
|
|
3
|
+
// Every op in src/ops.ts validates its inputs and computes its output shape
|
|
4
|
+
// through helpers here. Errors throw with the captured call-site so the
|
|
5
|
+
// stack trace points at the user's line, not into the library.
|
|
6
|
+
//
|
|
7
|
+
// Broadcasting rules:
//   * Element-wise binops (add/sub/mul/div) and where() use standard
//     right-aligned NumPy broadcasting (see broadcastTrailing). The shorter
//     shape is padded with leading 1s. Each axis pair must then be equal, or
//     one of the two must be 1, which broadcasts to the other's size.
//     Examples ALLOWED:   [B, T, D] op [D]       → [B, T, D]
//                         [B, T, D] op [1, D]    → [B, T, D]
//                         [B, T, D] op [T, D]    → [B, T, D]
//                         [B, T, D] op [B, T, D] → [B, T, D]
//     Examples REJECTED:  [B, T, D] op [B]       (when B != D)
//                         [B, T, D] op [B, D]    (when B != T)
//   * broadcast_to / sum_to_shape validate against the same right-aligned
//     rules. This covers the broadcast patterns in our transformer (biases,
//     layernorm gain/bias, masks).
|
|
18
|
+
|
|
19
|
+
import type { Shape, CallSite } from './ir.js'
|
|
20
|
+
import { formatSite } from './ir.js'
|
|
21
|
+
|
|
22
|
+
// ============================================================================
|
|
23
|
+
// Errors
|
|
24
|
+
// ============================================================================
|
|
25
|
+
|
|
26
|
+
/**
 * Error thrown by every shape rule in this module. When a call site is known,
 * it is appended on its own line so the message points at the user's code.
 */
export class ShapeError extends Error {
  constructor(message: string, site: CallSite | null) {
    const suffix = site ? `\n at ${formatSite(site)}` : ''
    super(message + suffix)
    this.name = 'ShapeError'
  }
}
|
|
33
|
+
|
|
34
|
+
/** Throw a ShapeError attributed to `site`. Typed `never` so callers narrow after a failed check. */
function fail(message: string, site: CallSite | null): never {
  const error = new ShapeError(message, site)
  throw error
}
|
|
37
|
+
|
|
38
|
+
// ============================================================================
|
|
39
|
+
// Shape utilities
|
|
40
|
+
// ============================================================================
|
|
41
|
+
|
|
42
|
+
/** True when both shapes have the same rank and the same extent on every axis. */
export function shapesEqual(a: Shape, b: Shape): boolean {
  return a.length === b.length && a.every((dim, axis) => dim === b[axis])
}
|
|
47
|
+
|
|
48
|
+
/** Total number of elements: the product of all extents (1 for a 0-d shape). */
export function shapeSize(shape: Shape): number {
  return shape.reduce((product, dim) => product * dim, 1)
}
|
|
53
|
+
|
|
54
|
+
/** Render a shape for error messages, e.g. `[2, 3]`. */
export function showShape(shape: Shape): string {
  return '[' + shape.join(', ') + ']'
}
|
|
57
|
+
|
|
58
|
+
// Standard right-aligned NumPy broadcasting. The shorter shape is treated as if
// padded with leading 1s. For each axis, equal extents unify and an extent of 1
// on either side stretches to the other; any other pair is incompatible.
// Returns the broadcast shape, or null when the shapes cannot be combined.
export function broadcastTrailing(a: Shape, b: Shape): Shape | null {
  const rank = Math.max(a.length, b.length)
  const padA = rank - a.length
  const padB = rank - b.length
  const result: number[] = []
  for (let axis = 0; axis < rank; axis++) {
    const x = axis < padA ? 1 : a[axis - padA]!
    const y = axis < padB ? 1 : b[axis - padB]!
    if (x !== y && x !== 1 && y !== 1) return null
    result.push(x === 1 ? y : x)
  }
  return result
}
|
|
76
|
+
|
|
77
|
+
// ============================================================================
|
|
78
|
+
// Per-op shape rules
|
|
79
|
+
// ============================================================================
|
|
80
|
+
//
|
|
81
|
+
// Each rule takes the input shapes and returns the output shape, or throws.
|
|
82
|
+
// All rules accept a `site` for error attribution.
|
|
83
|
+
|
|
84
|
+
/**
 * Output shape of an element-wise binary op (add/sub/mul/div).
 *
 * Uses broadcastTrailing, i.e. right-aligned NumPy broadcasting: the shorter
 * shape is padded with leading 1s, and each axis pair must be equal or contain a 1.
 *
 * @throws ShapeError when some axis pair differs and neither extent is 1.
 */
export function inferElementwiseBinop(
  opName: string, aShape: Shape, bShape: Shape, site: CallSite | null,
): Shape {
  const result = broadcastTrailing(aShape, bShape)
  if (!result) {
    // The old message described suffix-only broadcasting, which does not match
    // what broadcastTrailing actually accepts.
    fail(
      `${opName}: incompatible shapes ${showShape(aShape)} and ${showShape(bShape)}. ` +
      `Shapes are right-aligned (NumPy rules); each axis pair must be equal or contain a 1.`,
      site,
    )
  }
  return result
}
|
|
98
|
+
|
|
99
|
+
/** Element-wise unary ops: the output has exactly the input's shape (the same array is returned). */
export function inferUnary(_opName: string, aShape: Shape, _site: CallSite | null): Shape {
  const outShape: Shape = aShape
  return outShape
}
|
|
102
|
+
|
|
103
|
+
/** Mean over the last axis with keepdims=true: that axis becomes size 1. */
export function inferMeanLast(opName: string, aShape: Shape, site: CallSite | null): Shape {
  if (aShape.length === 0) fail(`${opName}: cannot reduce a 0-d tensor`, site)
  return aShape.slice(0, aShape.length - 1).concat(1)
}
|
|
108
|
+
|
|
109
|
+
/** Sum over the last axis with keepdims=false: that axis is dropped. */
export function inferSumLast(opName: string, aShape: Shape, site: CallSite | null): Shape {
  if (aShape.length === 0) fail(`${opName}: cannot reduce a 0-d tensor`, site)
  const keptAxes = aShape.length - 1
  return aShape.slice(0, keptAxes)
}
|
|
114
|
+
|
|
115
|
+
/**
 * Output shape of reshape(a, newShape).
 *
 * `newShape` may contain at most one -1, which is inferred from the remaining
 * dims. Every other entry must be a positive integer, and the total element
 * count must match the input's.
 *
 * @throws ShapeError on more than one -1, a non-positive or non-integer dim,
 *   or an element-count mismatch.
 */
export function inferReshape(opName: string, aShape: Shape, newShape: Shape, site: CallSite | null): Shape {
  let inferIdx = -1
  let knownSize = 1
  for (let i = 0; i < newShape.length; i++) {
    const d = newShape[i]!
    if (d === -1) {
      if (inferIdx !== -1) fail(`${opName}: at most one -1 dim allowed in newShape ${showShape(newShape)}`, site)
      inferIdx = i
    } else if (!Number.isInteger(d) || d <= 0) {
      // Rejects fractional dims (e.g. 2.5) and NaN. These previously slipped
      // through whenever the element counts happened to match.
      fail(`${opName}: invalid dim ${d} in newShape ${showShape(newShape)}`, site)
    } else {
      knownSize *= d
    }
  }
  const totalIn = shapeSize(aShape)
  const out = [...newShape]
  if (inferIdx !== -1) {
    if (totalIn % knownSize !== 0) {
      fail(`${opName}: cannot reshape ${showShape(aShape)} (size ${totalIn}) to ${showShape(newShape)} — known dims multiply to ${knownSize}`, site)
    }
    out[inferIdx] = totalIn / knownSize
  } else if (knownSize !== totalIn) {
    fail(`${opName}: size mismatch — input ${showShape(aShape)} has ${totalIn} elements but newShape ${showShape(newShape)} has ${knownSize}`, site)
  }
  return out
}
|
|
142
|
+
|
|
143
|
+
/**
 * Output shape of transpose(a, perm): output axis i has the extent of input
 * axis perm[i].
 *
 * @throws ShapeError when perm's length differs from the input rank, or an
 *   entry is non-integer, out of range, or repeated.
 */
export function inferTranspose(opName: string, aShape: Shape, perm: readonly number[], site: CallSite | null): Shape {
  const rank = aShape.length
  if (perm.length !== rank) {
    fail(`${opName}: perm length ${perm.length} must equal input rank ${rank}`, site)
  }
  const seen = new Set<number>()
  for (const p of perm) {
    // A fractional index (e.g. 0.5) passes the range check but indexes no
    // axis, which used to yield `undefined` in the output shape.
    if (!Number.isInteger(p)) fail(`${opName}: perm index ${p} is not an integer`, site)
    if (p < 0 || p >= rank) fail(`${opName}: perm index ${p} out of range for rank ${rank}`, site)
    if (seen.has(p)) fail(`${opName}: perm has duplicate index ${p}`, site)
    seen.add(p)
  }
  return perm.map(p => aShape[p]!)
}
|
|
155
|
+
|
|
156
|
+
// matmul: a [..., M, K] · b [K, N] → [..., M, N]. The rhs is an unbatched
// matrix; any leading axes of the lhs carry straight through to the output.
export function inferMatmul(opName: string, aShape: Shape, bShape: Shape, site: CallSite | null): Shape {
  if (aShape.length < 2) fail(`${opName}: lhs must have rank >= 2, got ${showShape(aShape)}`, site)
  if (bShape.length !== 2) fail(`${opName}: rhs must have rank 2, got ${showShape(bShape)} — use matmulBatched for batched rhs`, site)
  const rank = aShape.length
  const rows = aShape[rank - 2]!
  const innerLhs = aShape[rank - 1]!
  const innerRhs = bShape[0]!
  const cols = bShape[1]!
  if (innerLhs !== innerRhs) {
    fail(`${opName}: inner dims don't match — ${showShape(aShape)} · ${showShape(bShape)} (last axis of lhs = ${innerLhs}, first axis of rhs = ${innerRhs})`, site)
  }
  return aShape.slice(0, rank - 2).concat([rows, cols])
}
|
|
167
|
+
|
|
168
|
+
// matmul_batched: a [..., M, K] · b [..., K, N] → [..., M, N]. Both operands
// carry leading batch axes, which must have the same rank and extents (no
// broadcasting).
export function inferMatmulBatched(opName: string, aShape: Shape, bShape: Shape, site: CallSite | null): Shape {
  const rankA = aShape.length
  const rankB = bShape.length
  if (rankA < 2 || rankB < 2) {
    fail(`${opName}: both inputs must have rank >= 2, got ${showShape(aShape)} and ${showShape(bShape)}`, site)
  }
  if (rankA !== rankB) {
    fail(`${opName}: ranks must match (got ${rankA} vs ${rankB}). Reshape if you need different batch dims.`, site)
  }
  const batch = aShape.slice(0, rankA - 2)
  const batchMismatch = batch.some((dim, axis) => dim !== bShape[axis])
  if (batchMismatch) {
    fail(`${opName}: batch dims must match — ${showShape(aShape)} vs ${showShape(bShape)}`, site)
  }
  const innerLhs = aShape[rankA - 1]!
  const innerRhs = bShape[rankB - 2]!
  if (innerLhs !== innerRhs) {
    fail(`${opName}: inner dims don't match — last axis of lhs = ${innerLhs}, second-to-last of rhs = ${innerRhs}`, site)
  }
  return [...batch, aShape[rankA - 2]!, bShape[rankB - 1]!]
}
|
|
190
|
+
|
|
191
|
+
/**
 * Output shape of oneHot(indices, depth): a new trailing axis of size `depth`.
 *
 * @throws ShapeError unless `depth` is a positive integer. NaN and fractional
 *   depths previously passed the `depth <= 0` check.
 */
export function inferOneHot(opName: string, indicesShape: Shape, depth: number, site: CallSite | null): Shape {
  if (!Number.isInteger(depth) || depth <= 0) {
    fail(`${opName}: depth must be a positive integer, got ${depth}`, site)
  }
  return [...indicesShape, depth]
}
|
|
195
|
+
|
|
196
|
+
// where_causal leaves the shape unchanged, but the trailing two axes must form
// a square (the causal mask is M x M).
export function inferWhereCausal(opName: string, aShape: Shape, site: CallSite | null): Shape {
  const rank = aShape.length
  if (rank < 2) fail(`${opName}: requires rank >= 2, got ${showShape(aShape)}`, site)
  const rows = aShape[rank - 2]!
  const cols = aShape[rank - 1]!
  const isSquare = rows === cols
  if (!isSquare) fail(`${opName}: last two axes must be equal (square mask), got ${showShape(aShape)}`, site)
  return aShape
}
|
|
204
|
+
|
|
205
|
+
/**
 * Output shape of slicing the last axis to the half-open range [start, end).
 *
 * @throws ShapeError for a 0-d input, or unless 0 <= start < end <= last-axis
 *   size with integer bounds. NaN and fractional bounds previously slipped
 *   past the comparison checks.
 */
export function inferSliceLastRange(opName: string, aShape: Shape, start: number, end: number, site: CallSite | null): Shape {
  if (aShape.length === 0) fail(`${opName}: cannot slice 0-d tensor`, site)
  const last = aShape[aShape.length - 1]!
  const integral = Number.isInteger(start) && Number.isInteger(end)
  if (!integral || start < 0 || end > last || start >= end) {
    fail(`${opName}: invalid range [${start}, ${end}) for last axis of size ${last}`, site)
  }
  return [...aShape.slice(0, -1), end - start]
}
|
|
213
|
+
|
|
214
|
+
// broadcast_to: check that `aShape` can be stretched to `targetShape` under
// right-aligned NumPy rules (each source axis equals its target axis or is 1).
// On success the target shape itself is returned.
export function inferBroadcastTo(opName: string, aShape: Shape, targetShape: Shape, site: CallSite | null): Shape {
  if (aShape.length > targetShape.length) {
    fail(`${opName}: source rank ${aShape.length} > target rank ${targetShape.length}`, site)
  }
  const offset = targetShape.length - aShape.length
  const badAxis = aShape.findIndex((dim, axis) => dim !== 1 && dim !== targetShape[offset + axis])
  if (badAxis !== -1) {
    const av = aShape[badAxis]!
    const tv = targetShape[offset + badAxis]!
    fail(`${opName}: cannot broadcast ${showShape(aShape)} to ${showShape(targetShape)} — axis ${badAxis} (size ${av}) doesn't match target axis ${offset + badAxis} (size ${tv}) and isn't 1`, site)
  }
  return targetShape
}
|
|
230
|
+
|
|
231
|
+
// sum_to_shape: check that `targetShape` is a right-aligned reduction of
// `aShape`, i.e. that broadcasting targetShape could have produced aShape.
// Each target axis must be 1 or equal to the aligned source axis. Leading
// source axes beyond the target's rank are summed away. Returns targetShape.
export function inferSumToShape(opName: string, aShape: Shape, targetShape: Shape, site: CallSite | null): Shape {
  if (targetShape.length > aShape.length) {
    fail(`${opName}: target rank ${targetShape.length} > source rank ${aShape.length}`, site)
  }
  const offset = aShape.length - targetShape.length
  const badAxis = targetShape.findIndex((dim, axis) => dim !== 1 && dim !== aShape[offset + axis])
  if (badAxis !== -1) {
    fail(`${opName}: cannot sum-reduce ${showShape(aShape)} to ${showShape(targetShape)} — target axis ${badAxis} (size ${targetShape[badAxis]}) must be 1 or match source`, site)
  }
  return targetShape
}
|
|
247
|
+
|
|
248
|
+
// Output shape of `where(cond, a, b)`. The value operands a and b are
// broadcast together first, then cond is broadcast against that result, all
// under standard NumPy rules.
export function inferWhere(opName: string, condShape: Shape, aShape: Shape, bShape: Shape, site: CallSite | null): Shape {
  const valueShape = broadcastTrailing(aShape, bShape)
  if (valueShape === null) {
    fail(`${opName}: a/b incompatible: ${showShape(aShape)} vs ${showShape(bShape)}`, site)
  }
  const outShape = broadcastTrailing(condShape, valueShape)
  if (outShape === null) {
    fail(`${opName}: cond ${showShape(condShape)} incompatible with broadcast(a, b) ${showShape(valueShape)}`, site)
  }
  return outShape
}
|
|
257
|
+
|
|
258
|
+
/** ReLU backward: the forward input `x` and upstream gradient `dy` must have identical shapes. */
export function inferReluGrad(opName: string, xShape: Shape, dyShape: Shape, site: CallSite | null): Shape {
  const sameShape = shapesEqual(xShape, dyShape)
  if (!sameShape) {
    fail(`${opName}: x and dy must have matching shapes, got ${showShape(xShape)} and ${showShape(dyShape)}`, site)
  }
  return xShape
}
|
package/src/trace.ts
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
// Trace driver. Holds the "current graph" in module-local state so user code
|
|
2
|
+
// can call ops without threading a graph parameter through every function.
|
|
3
|
+
//
|
|
4
|
+
// Usage:
|
|
5
|
+
//
|
|
6
|
+
// const graph = trace(() => {
|
|
7
|
+
// const x = tensorInput('x', [B, T], 'i32')
|
|
8
|
+
// const w = paramInput('w', [V, D], 'f32')
|
|
9
|
+
// // ... user computation building tensors ...
|
|
10
|
+
// return finalLossTensor
|
|
11
|
+
// })
|
|
12
|
+
//
|
|
13
|
+
// `trace` is single-threaded and not re-entrant: calling trace() or
// traceInto() while another trace is active throws. Calling an op outside a
// `trace(...)` block is also an error.
|
|
16
|
+
|
|
17
|
+
import type { Graph, Tensor, Shape, Dtype } from './ir.js'
|
|
18
|
+
import { makeGraph, addOp, captureSite } from './ir.js'
|
|
19
|
+
|
|
20
|
+
// The graph that trace()/traceInto() is currently building, or null when no
// trace is active. It is module-local so ops can reach it without being passed
// a graph parameter.
let _current: null | Graph = null
|
|
22
|
+
|
|
23
|
+
/**
 * Return the graph of the active trace.
 *
 * @throws Error when called outside trace()/traceInto().
 */
export function currentGraph(): Graph {
  if (_current !== null) return _current
  throw new Error(
    'tensorgrad: ops can only be called inside trace(). ' +
    'Did you forget to wrap your forward pass?',
  )
}
|
|
32
|
+
|
|
33
|
+
// Build a new graph by running `fn` with it installed as the current graph.
// Whatever `fn` returns (one tensor or an array of them) becomes the graph's
// outputs, in order. The current-graph slot is cleared even if `fn` throws.
export function trace(fn: () => Tensor | Tensor[]): Graph {
  if (_current !== null) throw new Error('tensorgrad: nested trace() is not supported')
  const graph = makeGraph()
  _current = graph
  try {
    const returned = fn()
    const outputIds = (Array.isArray(returned) ? returned : [returned]).map(t => t.id)
    ;(graph.outputs as number[]).push(...outputIds)
  } finally {
    _current = null
  }
  return graph
}
|
|
52
|
+
|
|
53
|
+
// Re-open an already-traced graph so further op calls append to it (autograd
// uses this to add backward ops). `g` is the current graph only while `fn`
// runs; `fn`'s return value is passed straight back to the caller.
export function traceInto<T>(g: Graph, fn: () => T): T {
  if (_current !== null) {
    throw new Error('tensorgrad: traceInto() called while another trace is active')
  }
  _current = g
  let result: T
  try {
    result = fn()
  } finally {
    _current = null
  }
  return result
}
|
|
68
|
+
|
|
69
|
+
// ---- Leaf tensor builders --------------------------------------------------
|
|
70
|
+
// Inputs are added to the graph as `param_input` or `tensor_input` op nodes.
|
|
71
|
+
// Their .source on the Tensor points at that node so codegen knows where to
|
|
72
|
+
// bind external data.
|
|
73
|
+
|
|
74
|
+
/**
 * Declare a trainable parameter leaf by adding a `param_input` node to the
 * current graph.
 *
 * @throws Error if `name` is already used by a param or tensor input in this trace.
 */
export function paramInput(name: string, shape: Shape, dtype: Dtype = 'f32'): Tensor {
  const graph = currentGraph()
  assertInputNameFree(graph, name)
  return addOp(graph, 'param_input', shape, dtype, captureSite('paramInput'), { name } as any)
}

/**
 * Declare an externally fed data leaf by adding a `tensor_input` node to the
 * current graph.
 *
 * @throws Error if `name` is already used by a param or tensor input in this trace.
 */
export function tensorInput(name: string, shape: Shape, dtype: Dtype = 'f32'): Tensor {
  const graph = currentGraph()
  assertInputNameFree(graph, name)
  return addOp(graph, 'tensor_input', shape, dtype, captureSite('tensorInput'), { name } as any)
}

// Param and tensor inputs share one namespace: reject a second input of either
// kind with the same name.
function assertInputNameFree(graph: Graph, name: string): void {
  const taken = graph.ops.some(op => (op.kind === 'param_input' || op.kind === 'tensor_input') && op.name === name)
  if (taken) {
    throw new Error(`tensorgrad: input name '${name}' already used in this trace`)
  }
}
|
|
91
|
+
|
|
92
|
+
// Persistent state buffer (e.g. optimizer moments). It is allocated at compile
// time, filled with `initValue` (0 by default), and carried across step() calls
// through writebacks declared by the optimizer helper. State names are checked
// only against other state inputs, not against param/tensor input names.
export function stateInput(name: string, shape: Shape, dtype: Dtype = 'f32', initValue = 0): Tensor {
  const graph = currentGraph()
  const duplicate = graph.ops.some(op => op.kind === 'state_input' && op.name === name)
  if (duplicate) {
    throw new Error(`tensorgrad: state name '${name}' already used in this trace`)
  }
  return addOp(graph, 'state_input', shape, dtype, captureSite('stateInput'), { name, initValue } as any)
}
|