tensorgrad 0.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +21 -0
- package/README.md +121 -0
- package/SPEC.md +293 -0
- package/dist/adam.d.ts +31 -0
- package/dist/adam.d.ts.map +1 -0
- package/dist/adam.js +66 -0
- package/dist/adam.js.map +1 -0
- package/dist/buffers.d.ts +56 -0
- package/dist/buffers.d.ts.map +1 -0
- package/dist/buffers.js +114 -0
- package/dist/buffers.js.map +1 -0
- package/dist/codegen.d.ts +23 -0
- package/dist/codegen.d.ts.map +1 -0
- package/dist/codegen.js +709 -0
- package/dist/codegen.js.map +1 -0
- package/dist/compile.d.ts +53 -0
- package/dist/compile.d.ts.map +1 -0
- package/dist/compile.js +76 -0
- package/dist/compile.js.map +1 -0
- package/dist/grad.d.ts +8 -0
- package/dist/grad.d.ts.map +1 -0
- package/dist/grad.js +404 -0
- package/dist/grad.js.map +1 -0
- package/dist/index.d.ts +12 -0
- package/dist/index.d.ts.map +1 -0
- package/dist/index.js +37 -0
- package/dist/index.js.map +1 -0
- package/dist/ir.d.ts +204 -0
- package/dist/ir.d.ts.map +1 -0
- package/dist/ir.js +60 -0
- package/dist/ir.js.map +1 -0
- package/dist/module.d.ts +21 -0
- package/dist/module.d.ts.map +1 -0
- package/dist/module.js +113 -0
- package/dist/module.js.map +1 -0
- package/dist/ops.d.ts +35 -0
- package/dist/ops.d.ts.map +1 -0
- package/dist/ops.js +270 -0
- package/dist/ops.js.map +1 -0
- package/dist/runtime.d.ts +26 -0
- package/dist/runtime.d.ts.map +1 -0
- package/dist/runtime.js +190 -0
- package/dist/runtime.js.map +1 -0
- package/dist/shape.d.ts +24 -0
- package/dist/shape.d.ts.map +1 -0
- package/dist/shape.js +259 -0
- package/dist/shape.js.map +1 -0
- package/dist/trace.d.ts +8 -0
- package/dist/trace.d.ts.map +1 -0
- package/dist/trace.js +93 -0
- package/dist/trace.js.map +1 -0
- package/package.json +62 -0
- package/src/adam.ts +95 -0
- package/src/buffers.ts +173 -0
- package/src/codegen.ts +758 -0
- package/src/compile.ts +120 -0
- package/src/grad.ts +459 -0
- package/src/index.ts +40 -0
- package/src/ir.ts +197 -0
- package/src/module.ts +126 -0
- package/src/ops.ts +311 -0
- package/src/runtime.ts +232 -0
- package/src/shape.ts +263 -0
- package/src/trace.ts +101 -0
package/src/ir.ts
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
1
|
+
// Intermediate representation for tensor computations.
|
|
2
|
+
//
|
|
3
|
+
// A `Graph` is a flat array of `OpNode`s in topological (= construction) order.
|
|
4
|
+
// A `Tensor` is an opaque handle: shape + dtype + a pointer back to the OpNode
|
|
5
|
+
// that produced it (or `null` for graph leaves — params and external inputs).
|
|
6
|
+
//
|
|
7
|
+
// This is the data structure everything else operates on:
|
|
8
|
+
// - tracing builds it (src/trace.ts)
|
|
9
|
+
// - autograd walks it in reverse to add backward nodes (src/grad.ts)
|
|
10
|
+
// - codegen reads it to emit WGSL kernels and a dispatch plan (src/codegen.ts)
|
|
11
|
+
//
|
|
12
|
+
// Design intent: keep this file boring. No tracing logic, no shape inference,
|
|
13
|
+
// no codegen — those live in their own modules and consume `Graph` / `OpNode`.
|
|
14
|
+
|
|
15
|
+
/** Element types a tensor can hold. */
export type Dtype =
  | 'f32'
  | 'i32'
  | 'bool'

/** Read-only list of dimension sizes, one entry per axis. */
export type Shape = ReadonlyArray<number>
|
|
17
|
+
|
|
18
|
+
/**
 * Handle to a value in a traced Graph: metadata plus a unique id. No storage
 * exists for it until the graph is compiled and run on a device.
 */
export interface Tensor {
  /** Unique id; addTensor assigns it as this tensor's index in `Graph.tensors`. */
  readonly id: number
  readonly dtype: Dtype
  readonly shape: Shape
  /**
   * Index into `Graph.ops` of the op that produced this tensor, or null.
   * The null case is intended for graph leaves (params, external inputs).
   * NOTE(review): leaves are themselves OpNode kinds (`param_input`,
   * `tensor_input`, `state_input`); if they are created through addOp they
   * get a non-null source as well — confirm against src/trace.ts.
   */
  readonly source: number | null
  /**
   * Where the op was invoked. Captured at op-call time so shape errors can
   * blame the user's frame rather than the library's; the raw stack is only
   * formatted when an error is actually reported.
   */
  readonly site: CallSite | null
}
|
|
30
|
+
|
|
31
|
+
/** Record of where an op was invoked, kept for error attribution. */
export interface CallSite {
  /** Surface-level op name (e.g. `matmul`); formatSite prints it as the `[opName]` prefix. */
  readonly opName: string
  /** Unparsed `new Error().stack` captured at invocation; trimmed only when reported. */
  readonly stack: string
}
|
|
36
|
+
|
|
37
|
+
// Discriminated union over every op the IR knows about. Adding an op means:
//   1. add a variant here,
//   2. add a shape rule in src/shape.ts,
//   3. add a transpose rule in src/grad.ts,
//   4. add a kernel template in src/codegen.ts.
// Kinds are the snake_case counterparts of the surface functions in src/ops.ts
// (e.g. `mul_scalar` <-> `mulScalar`). The mapping is not strictly one-to-one:
// `mul(x, 2)` also lowers to `mul_scalar`, and `sub(x, c)` lowers to
// `add_scalar`. NOTE(review): the optimizer-fused and autograd-infrastructure
// kinds below may have no surface function at all — confirm.
export type OpNode =
  // ---- Leaves ----------------------------------------------------------------
  // A trainable parameter, supplied by the caller as a Float32Array at runtime.
  | { kind: 'param_input'; out: number; name: string }
  // A non-trainable input (tokens, targets, constants). Bound at runtime.
  | { kind: 'tensor_input'; out: number; name: string }
  // Persistent state buffer (e.g. Adam's m/v). Allocated at compile time and
  // survives across step() calls; updated via writebacks declared in the
  // compile result. NOTE(review): presumably filled with `initValue` (an
  // earlier comment said "zero-initialized") — confirm in src/runtime.ts.
  | { kind: 'state_input'; out: number; name: string; initValue: number }

  // ---- Element-wise --------------------------------------------------------
  | { kind: 'add'; out: number; a: number; b: number }
  | { kind: 'sub'; out: number; a: number; b: number }
  | { kind: 'mul'; out: number; a: number; b: number }
  | { kind: 'div'; out: number; a: number; b: number }
  | { kind: 'mul_scalar'; out: number; a: number; scalar: number }
  | { kind: 'add_scalar'; out: number; a: number; scalar: number }

  // ---- Unary ---------------------------------------------------------------
  | { kind: 'sqrt'; out: number; a: number }
  | { kind: 'rsqrt'; out: number; a: number }
  | { kind: 'log'; out: number; a: number }
  | { kind: 'exp'; out: number; a: number }
  | { kind: 'relu'; out: number; a: number }

  // ---- Reductions (over last axis only; reshape if you need other axes) ----
  | { kind: 'mean_last'; out: number; a: number } // keepdims=true
  | { kind: 'sum_last'; out: number; a: number } // keepdims=false

  // ---- Shape ---------------------------------------------------------------
  | { kind: 'reshape'; out: number; a: number; newShape: Shape }
  | { kind: 'transpose'; out: number; a: number; perm: readonly number[] }

  // ---- Linear algebra -----------------------------------------------------
  // matmul: a [..., M, K] · b [K, N] -> [..., M, N]. b is unbatched.
  // (Batched-on-both-sides matmul, e.g. for attention scores, is a separate kind
  // to keep autograd transpose rules simple.)
  | { kind: 'matmul'; out: number; a: number; b: number }
  // matmul_batched: a [..., M, K] · b [..., K, N] -> [..., M, N]. Used by attention.
  | { kind: 'matmul_batched'; out: number; a: number; b: number }

  // ---- Indexing / casting --------------------------------------------------
  | { kind: 'one_hot'; out: number; indices: number; depth: number; dtype: Dtype }
  | { kind: 'arange'; out: number; n: number; dtype: Dtype }

  // ---- ML primitives (fused for cleaner autograd) -------------------------
  | { kind: 'softmax_causal_last'; out: number; a: number }
  | { kind: 'log_softmax_last'; out: number; a: number }
  // Causal mask on the last two axes, for masking attention scores *before*
  // softmax. Cells with i >= j (on or below the diagonal) pass through; cells
  // with i < j (strictly above it) become `fillValue` (typically -inf or a
  // large negative number).
  | { kind: 'where_causal'; out: number; a: number; fillValue: number }

  // ---- Comparisons + selection -------------------------------------------
  // Element-wise comparison; result is bool (lowered to u32 in storage).
  // Supports the same trailing-axis broadcast as element-wise binops.
  | { kind: 'less'; out: number; a: number; b: number }
  | { kind: 'greater'; out: number; a: number; b: number }
  // Element-wise select: out[i] = cond[i] ? a[i] : b[i]. cond must be bool.
  // a, b, cond all broadcast-compatible to out's shape.
  | { kind: 'where'; out: number; cond: number; a: number; b: number }

  // ---- Optimizer-fused ops (Adam) ----------------------------------------
  // Each is a single kernel doing the full per-element math, baking in the
  // hyperparameter constant. Used by appendAdam() to avoid decomposing the
  // update into ~12 element-wise dispatches per param.
  | { kind: 'adam_update_m'; out: number; m: number; g: number; b1: number }
  | { kind: 'adam_update_v'; out: number; v: number; g: number; b2: number }
  // adam_update_p: p_new = p - lrt[0] * m_new / (sqrt(v_new) + eps).
  // `lrt` is a scalar tensor (provided as a tensor_input updated per step) that
  // already includes Adam's bias-correction factor: lrt = lr * sqrt(1-b2^t) / (1-b1^t).
  // Only `eps` is baked in.
  | { kind: 'adam_update_p'; out: number; p: number; mNew: number; vNew: number; lrt: number; eps: number }

  // ---- Slicing / broadcasting / autograd infrastructure -------------------
  // Slice [start, end) along the last axis. Output shape: input shape with
  // last axis replaced by (end - start). Used for splitting Q/K/V from a
  // single fused QKV matmul.
  | { kind: 'slice_last_range'; out: number; a: number; start: number; end: number }
  // Broadcast `a` to `targetShape`. Standard right-aligned NumPy broadcast.
  // Used by autograd to expand cotangents back over reduced/broadcast axes.
  | { kind: 'broadcast_to'; out: number; a: number; targetShape: Shape }
  // Inverse of broadcast_to: sum-reduce `a` to `targetShape`. Used by autograd
  // to "un-broadcast" a cotangent back to the smaller operand's shape.
  | { kind: 'sum_to_shape'; out: number; a: number; targetShape: Shape }
  // 0-d tensor with a constant value. Used to seed loss cotangent (1.0).
  | { kind: 'const_scalar'; out: number; value: number; dtype: Dtype }
  // ReLU's backward: passes `dy` through where `x > 0`, else 0. Output shape = x's.
  | { kind: 'relu_grad'; out: number; x: number; dy: number }
|
|
133
|
+
|
|
134
|
+
// A Graph collects ops and tensors during tracing, then becomes the input to
// autograd and codegen. Once tracing is done it should be treated as immutable.
// That is a convention only: the arrays are mutable so addOp/addTensor can
// append to them.
export interface Graph {
  // Every op, in topological (= construction) order. A tensor's `source` is an
  // index into this array.
  readonly ops: OpNode[]
  // Every tensor; addTensor keeps `tensors[i].id === i`.
  readonly tensors: Tensor[]
  // Ids (not names) of the tensors exposed as outputs of the compiled function.
  // Set by the trace driver; for a loss function, this is `[lossTensor.id]`.
  readonly outputs: number[]
}
|
|
143
|
+
|
|
144
|
+
/** Create a fresh, empty graph for tracing to append ops and tensors to. */
export function makeGraph(): Graph {
  const fresh: Graph = {
    ops: [],
    tensors: [],
    outputs: [],
  }
  return fresh
}
|
|
147
|
+
|
|
148
|
+
// Internal: register a fresh tensor in the graph and return the Tensor handle
// itself (not just its id). Ids are dense: a new tensor's id is its index in
// `g.tensors`, so `g.tensors[t.id] === t` holds as long as tensors are only
// added through this function.
export function addTensor(g: Graph, shape: Shape, dtype: Dtype, source: number | null, site: CallSite | null): Tensor {
  const id = g.tensors.length
  const t: Tensor = { id, shape, dtype, source, site }
  g.tensors.push(t)
  return t
}
|
|
155
|
+
|
|
156
|
+
// Internal: append an op node together with the tensor it produces, and return
// that tensor. Generic over the op kind so call sites get field checking
// without `as any`: `Extract<OpNode, { kind: K }>` selects the variant(s) for
// K, and `Omit` removes the two fields addOp fills in itself (the `kind` tag
// and the `out` tensor id).
export function addOp<K extends OpNode['kind']>(
  g: Graph,
  kind: K,
  shape: Shape,
  dtype: Dtype,
  site: CallSite | null,
  fields: Omit<Extract<OpNode, { kind: K }>, 'kind' | 'out'>,
): Tensor {
  // The op's future index in g.ops becomes the produced tensor's `source`;
  // it is read before the node is pushed so the two agree.
  const produced = addTensor(g, shape, dtype, g.ops.length, site)
  // TS cannot prove that re-attaching `kind`/`out` to an Omit<> of a generic
  // union member rebuilds that member, hence the assertion.
  g.ops.push({ kind, out: produced.id, ...fields } as Extract<OpNode, { kind: K }>)
  return produced
}
|
|
174
|
+
|
|
175
|
+
// Capture the call site of an op for later error attribution.
//
// Reading `.stack` makes the engine format the trace immediately (V8 builds
// the string lazily, on first access), so formatting is *not* deferred here.
// Only filtering and trimming are deferred; formatSite() does them when an
// error is actually reported. Nothing is skipped at capture time either: the
// stored trace still contains this function and the op wrapper that called it.
export function captureSite(opName: string): CallSite {
  const stack = new Error().stack ?? ''
  return { opName, stack }
}
|
|
182
|
+
|
|
183
|
+
// Format a CallSite for inclusion in a thrown error: the op name, followed by
// up to three of the innermost stack frames that are not inside tensorgrad
// itself, so the user sees their own code first.
//
// Library frames are recognised by path, both when running from source
// (`.../tensorgrad/src/...`) and from the published build
// (`.../node_modules/tensorgrad/dist/...`), with `/` or `\` separators.
// Bundles that erase these paths are not detected.
export function formatSite(site: CallSite): string {
  const libraryFrame = /[\\/]tensorgrad[\\/](?:src|dist)[\\/]/
  const userFrames: string[] = []
  // The first line is the "Error" header that precedes the frames; drop it.
  for (const line of site.stack.split('\n').slice(1)) {
    const frame = line.trim()
    // Skip blank lines (e.g. a trailing newline) as well as library frames.
    if (frame === '' || libraryFrame.test(frame)) continue
    userFrames.push(frame)
    if (userFrames.length >= 3) break
  }
  if (userFrames.length === 0) return `[${site.opName}] (no user frame found)`
  return `[${site.opName}]\n ${userFrames.join('\n ')}`
}
|
package/src/module.ts
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
// Module abstraction — a Domeleon-style component layer for parameter trees.
|
|
2
|
+
//
|
|
3
|
+
// User code defines a model as nested classes:
|
|
4
|
+
//
|
|
5
|
+
// class Linear extends Module {
|
|
6
|
+
// W: Tensor; b: Tensor
|
|
7
|
+
// constructor(inDim: number, outDim: number) {
|
|
8
|
+
// super()
|
|
9
|
+
// this.W = this.param([inDim, outDim])
|
|
10
|
+
// this.b = this.param([outDim])
|
|
11
|
+
// }
|
|
12
|
+
// }
|
|
13
|
+
// class Block extends Module {
|
|
14
|
+
// attn = new Attention(D)
|
|
15
|
+
// mlp = new MLP(D, 4 * D)
|
|
16
|
+
// }
|
|
17
|
+
// class Model extends Module {
|
|
18
|
+
// embed = new Linear(VOCAB, D)
|
|
19
|
+
// layers = range(N).map(() => new Block())
|
|
20
|
+
// }
|
|
21
|
+
//
|
|
22
|
+
// The param tree is discovered automatically at compile time by walking
|
|
23
|
+
// enumerable instance properties. Each parameter gets a name auto-derived
|
|
24
|
+
// from its path (`layers.0.attn.W_q`); names are used for upload/download
|
|
25
|
+
// and writeback wiring. Forward functions are pure and stateless — they
|
|
26
|
+
// take the materialized model and inputs, return a Tensor.
|
|
27
|
+
|
|
28
|
+
import type { Tensor, Shape, Dtype } from './ir.js'
|
|
29
|
+
import { paramInput } from './trace.js'
|
|
30
|
+
|
|
31
|
+
// ============================================================================
|
|
32
|
+
// Internals: param sentinel
|
|
33
|
+
// ============================================================================
|
|
34
|
+
//
// `this.param(shape)` does not create a Tensor: it returns a ParamSentinel
// placeholder, which `materializeParams` later swaps for a real Tensor from the
// active trace. `param` is nevertheless declared to return `Tensor` so user
// code can write `this.W` and type-check; that declared type is only truthful
// after materialization (which is always before forward runs).

class ParamSentinel {
  readonly shape: Shape
  readonly dtype: Dtype

  constructor(shape: Shape, dtype: Dtype) {
    this.shape = shape
    this.dtype = dtype
  }
}
|
|
43
|
+
|
|
44
|
+
// ============================================================================
|
|
45
|
+
// Module base class
|
|
46
|
+
// ============================================================================
|
|
47
|
+
|
|
48
|
+
export abstract class Module {
  /**
   * Declare a learnable parameter owned by this module. Call it from the
   * constructor, typically in a field initializer (`W = this.param([D, D])`).
   *
   * The returned value is a placeholder, not a usable Tensor. At compile time
   * `materializeParams` replaces it with a real Tensor, named after its
   * property path in the model tree (e.g. `layers.0.attn.W_q`).
   *
   * @param shape parameter shape
   * @param dtype element type; defaults to `'f32'`
   */
  protected param(shape: Shape, dtype: Dtype = 'f32'): Tensor {
    const placeholder: unknown = new ParamSentinel(shape, dtype)
    // Deliberate type cheat: only truthful once materializeParams has run.
    return placeholder as Tensor
  }
}
|
|
62
|
+
|
|
63
|
+
// ============================================================================
|
|
64
|
+
// Tree walking
|
|
65
|
+
// ============================================================================
|
|
66
|
+
|
|
67
|
+
/**
 * Replace every ParamSentinel in the module tree with a real Tensor created by
 * `paramInput(path, shape, dtype)`, where `path` is the sentinel's dotted
 * property path (e.g. `layers.0.attn.W_q`). The tree is mutated in place.
 *
 * Must run inside an active trace context: `paramInput` appends to the
 * current graph.
 *
 * @returns a flat `{ path: tensor }` record of the params created by this call.
 *
 * NOTE(review): values that are already Tensors (e.g. from an earlier call on
 * the same model instance) are left untouched and are not included in the
 * result — confirm callers never materialize the same instance twice.
 */
export function materializeParams(root: Module): Record<string, Tensor> {
  const created: Record<string, Tensor> = {}
  const replaceSentinel: Visitor = (path, val, owner, key) => {
    if (!(val instanceof ParamSentinel)) return
    const tensor = paramInput(path, val.shape, val.dtype)
    const slots = owner as Record<string | number, unknown>
    slots[key] = tensor
    created[path] = tensor
  }
  visit(root, '', replaceSentinel)
  return created
}
|
|
85
|
+
|
|
86
|
+
// ----------------------------------------------------------------------------
|
|
87
|
+
// Visitor
|
|
88
|
+
// ----------------------------------------------------------------------------
|
|
89
|
+
//
// Walks own enumerable properties (Object.keys) recursively, building a dotted
// path string. Recurses into nested Modules and arrays (arrays of Modules,
// arrays of arrays, etc.). Every other value found in those containers is
// handed to `visitor`, including ParamSentinels (pre-materialize) and real
// Tensors (post-materialize).
//
// Each Module/array is entered at most once per walk. This keeps the walk
// finite when the tree contains a cycle (e.g. a child holding a back-reference
// to its parent). It also means a container reachable through several paths
// (e.g. a shared/tied submodule) is only walked under the first path reached.

type Visitor = (path: string, val: unknown, owner: object, key: string | number) => void

function visit(node: unknown, path: string, visitor: Visitor, seen: Set<object> = new Set()): void {
  if (node === null || node === undefined) return
  if (typeof node !== 'object') return

  if (node instanceof Module) {
    if (seen.has(node)) return
    seen.add(node)
    for (const key of Object.keys(node)) {
      const childPath = path ? `${path}.${key}` : key
      visitChild(Reflect.get(node, key), childPath, node, key, visitor, seen)
    }
    return
  }
  if (Array.isArray(node)) {
    if (seen.has(node)) return
    seen.add(node)
    node.forEach((item, i) => {
      const childPath = path ? `${path}.${i}` : String(i)
      visitChild(item, childPath, node, i, visitor, seen)
    })
    return
  }
  // Plain leaf object (sentinel / tensor / something else): the visitor has
  // already seen it via visitChild. No deeper recursion.
}

// Recurse into containers; hand everything else to the visitor together with
// the object/array that holds it and its key, so the visitor can replace it.
function visitChild(
  child: unknown,
  path: string,
  owner: object,
  key: string | number,
  visitor: Visitor,
  seen: Set<object> = new Set(),
): void {
  if (child instanceof Module || Array.isArray(child)) {
    visit(child, path, visitor, seen)
  } else {
    visitor(path, child, owner, key)
  }
}
|
package/src/ops.ts
ADDED
|
@@ -0,0 +1,311 @@
|
|
|
1
|
+
// User-facing op surface.
|
|
2
|
+
//
|
|
3
|
+
// Each function here is a thin wrapper:
|
|
4
|
+
// 1. capture the call site (for error attribution)
|
|
5
|
+
// 2. validate input shapes via src/shape.ts (which throws on mismatch)
|
|
6
|
+
// 3. compute the output shape and dtype
|
|
7
|
+
// 4. append the op to the current Graph (held in module state by src/trace.ts)
|
|
8
|
+
// 5. return the produced Tensor handle
|
|
9
|
+
//
|
|
10
|
+
// No actual numeric work happens here. These calls just build the IR.
|
|
11
|
+
|
|
12
|
+
import type { Tensor, Shape, Dtype, OpNode } from './ir.js'
|
|
13
|
+
import { addOp, captureSite } from './ir.js'
|
|
14
|
+
import { currentGraph } from './trace.js'
|
|
15
|
+
import {
|
|
16
|
+
inferElementwiseBinop, inferUnary, inferMeanLast, inferSumLast,
|
|
17
|
+
inferReshape, inferTranspose, inferMatmul, inferMatmulBatched,
|
|
18
|
+
inferOneHot, inferWhereCausal, inferSliceLastRange,
|
|
19
|
+
inferBroadcastTo, inferSumToShape, inferReluGrad, inferWhere,
|
|
20
|
+
ShapeError,
|
|
21
|
+
} from './shape.js'
|
|
22
|
+
|
|
23
|
+
// ----------------------------------------------------------------------------
|
|
24
|
+
// Element-wise binops (add/sub/mul/div). Trailing-suffix broadcast.
|
|
25
|
+
// ----------------------------------------------------------------------------
|
|
26
|
+
|
|
27
|
+
/**
 * Shared builder for element-wise binary ops. Appends one node to the current
 * graph and returns the produced tensor; no numeric work happens here.
 *
 * Used by the arithmetic ops (add/sub/mul/div, output dtype = input dtype) and
 * by the comparisons (less/greater, which pass `outDtype = 'bool'`).
 *
 * `kind` is restricted to the IR kinds whose node payload is exactly `{ a, b }`.
 * With the unrestricted `OpNode['kind']`, addOp's `fields` type collapsed to
 * `{}` and the payload went unchecked.
 *
 * @param name     user-facing op name, used for call-site capture and errors
 * @param kind     IR kind to emit
 * @param outDtype dtype of the result; defaults to the (shared) input dtype
 * @throws ShapeError if the dtypes differ, or if src/shape.ts rejects the shapes
 */
function binopOp(
  name: string,
  kind: Extract<OpNode['kind'], 'add' | 'sub' | 'mul' | 'div' | 'less' | 'greater'>,
  a: Tensor, b: Tensor,
  outDtype: Dtype = a.dtype,
): Tensor {
  const site = captureSite(name)
  if (a.dtype !== b.dtype) throw new ShapeError(`${name}: dtype mismatch (${a.dtype} vs ${b.dtype})`, site)
  const outShape = inferElementwiseBinop(name, a.shape, b.shape, site)
  return addOp(currentGraph(), kind, outShape, outDtype, site, { a: a.id, b: b.id })
}
|
|
43
|
+
|
|
44
|
+
// Element-wise binops. Second arg can be a Tensor or a JS number; the latter
|
|
45
|
+
// dispatches to scalar-fused IR ops internally. `mul(x, 2)` and `mul(x, y)`
|
|
46
|
+
// both work — matches every NumPy-shaped library.
|
|
47
|
+
/** Element-wise a + b; `b` may be a Tensor or a JS number (scalar-fused path). */
export function add(a: Tensor, b: Tensor | number): Tensor {
  if (typeof b === 'number') {
    return addScalar(a, b)
  }
  return binopOp('add', 'add', a, b)
}
|
|
50
|
+
/** Element-wise a - b; a numeric `b` is lowered to `addScalar(a, -b)`. */
export function sub(a: Tensor, b: Tensor | number): Tensor {
  if (typeof b !== 'number') return binopOp('sub', 'sub', a, b)
  // There is no dedicated sub_scalar kind: x - c is emitted as x + (-c).
  return addScalar(a, -b)
}
|
|
53
|
+
/** Element-wise a * b; a numeric `b` becomes a single `mul_scalar` node. */
export function mul(a: Tensor, b: Tensor | number): Tensor {
  switch (typeof b) {
    case 'number':
      return mulScalar(a, b)
    default:
      return binopOp('mul', 'mul', a, b)
  }
}
|
|
56
|
+
/**
 * Element-wise a / b. A numeric `b` is lowered to multiplication by its
 * reciprocal (`mulScalar(a, 1 / b)`), which can differ from true division by
 * one rounding step.
 *
 * @throws ShapeError if `b` is the number 0
 */
export function div(a: Tensor, b: Tensor | number): Tensor {
  if (typeof b !== 'number') return binopOp('div', 'div', a, b)
  // 1 / 0 would silently bake Infinity into the graph, so reject it up front.
  if (b === 0) throw new ShapeError(`div: scalar divisor cannot be zero`, captureSite('div'))
  const reciprocal = 1 / b
  return mulScalar(a, reciprocal)
}
|
|
63
|
+
|
|
64
|
+
// ----------------------------------------------------------------------------
|
|
65
|
+
// Element-wise scalar binops (mul/add by JS number). Used for things like
|
|
66
|
+
// `scores * (1/sqrt(d))` and `logits + 1e-5` where allocating a 0-d tensor
|
|
67
|
+
// for the scalar is wasteful.
|
|
68
|
+
// ----------------------------------------------------------------------------
|
|
69
|
+
|
|
70
|
+
/** Multiply every element of `a` by a JS number. Shape and dtype are preserved. */
export function mulScalar(a: Tensor, scalar: number): Tensor {
  const site = captureSite('mulScalar')
  const graph = currentGraph()
  return addOp(graph, 'mul_scalar', a.shape, a.dtype, site, { a: a.id, scalar })
}
|
|
74
|
+
|
|
75
|
+
/** Add a JS number to every element of `a`. Shape and dtype are preserved. */
export function addScalar(a: Tensor, scalar: number): Tensor {
  const site = captureSite('addScalar')
  const fields = { a: a.id, scalar }
  return addOp(currentGraph(), 'add_scalar', a.shape, a.dtype, site, fields)
}
|
|
79
|
+
|
|
80
|
+
// ----------------------------------------------------------------------------
|
|
81
|
+
// Unary ops.
|
|
82
|
+
// ----------------------------------------------------------------------------
|
|
83
|
+
|
|
84
|
+
// Shared builder for the f32-only unary ops; the output shape comes from
// inferUnary and the output dtype is always f32.
function unary(name: 'sqrt' | 'rsqrt' | 'log' | 'exp' | 'relu', a: Tensor): Tensor {
  const site = captureSite(name)
  if (a.dtype !== 'f32') {
    throw new ShapeError(`${name}: requires f32, got ${a.dtype}`, site)
  }
  const graph = currentGraph()
  const outShape = inferUnary(name, a.shape, site)
  return addOp(graph, name, outShape, 'f32', site, { a: a.id })
}
|
|
89
|
+
|
|
90
|
+
/** Element-wise square root (f32 only). */
export function sqrt(a: Tensor): Tensor {
  return unary('sqrt', a)
}

/** Element-wise reciprocal square root (f32 only). */
export function rsqrt(a: Tensor): Tensor {
  return unary('rsqrt', a)
}

/** Element-wise logarithm (f32 only). */
export function log(a: Tensor): Tensor {
  return unary('log', a)
}

/** Element-wise exponential (f32 only). */
export function exp(a: Tensor): Tensor {
  return unary('exp', a)
}

/** Element-wise ReLU (f32 only). */
export function relu(a: Tensor): Tensor {
  return unary('relu', a)
}
|
|
95
|
+
|
|
96
|
+
// ----------------------------------------------------------------------------
|
|
97
|
+
// Reductions over the last axis. To reduce along other axes, transpose first.
|
|
98
|
+
// (This is intentional — keeps codegen and autograd small.)
|
|
99
|
+
// ----------------------------------------------------------------------------
|
|
100
|
+
|
|
101
|
+
/** Mean over the last axis (the IR documents `mean_last` as keepdims=true). f32 only. */
export function meanLast(a: Tensor): Tensor {
  const site = captureSite('meanLast')
  if (a.dtype === 'f32') {
    const outShape = inferMeanLast('meanLast', a.shape, site)
    return addOp(currentGraph(), 'mean_last', outShape, a.dtype, site, { a: a.id })
  }
  throw new ShapeError(`meanLast: requires f32, got ${a.dtype}`, site)
}
|
|
107
|
+
|
|
108
|
+
/** Sum over the last axis (the IR documents `sum_last` as keepdims=false). f32 only. */
export function sumLast(a: Tensor): Tensor {
  const site = captureSite('sumLast')
  const isFloat = a.dtype === 'f32'
  if (!isFloat) throw new ShapeError(`sumLast: requires f32, got ${a.dtype}`, site)
  const reducedShape = inferSumLast('sumLast', a.shape, site)
  return addOp(currentGraph(), 'sum_last', reducedShape, a.dtype, site, { a: a.id })
}
|
|
114
|
+
|
|
115
|
+
// ----------------------------------------------------------------------------
|
|
116
|
+
// Shape ops.
|
|
117
|
+
// ----------------------------------------------------------------------------
|
|
118
|
+
|
|
119
|
+
// Reinterpret `a` with a new shape, validated by inferReshape. The node records
// the shape inferReshape returned rather than the caller's raw `newShape`.
// NOTE(review): presumably inferReshape resolves any wildcard dims — confirm
// in src/shape.ts.
export function reshape(a: Tensor, newShape: Shape): Tensor {
  const site = captureSite('reshape')
  const resolved = inferReshape('reshape', a.shape, newShape, site)
  const fields = { a: a.id, newShape: resolved }
  return addOp(currentGraph(), 'reshape', resolved, a.dtype, site, fields)
}
|
|
124
|
+
|
|
125
|
+
// Permute the axes of `a` according to `perm`, validated by inferTranspose.
// The node stores a private copy of `perm`. The caller's array may be reused
// or mutated after this call, and the recorded IR must not change underneath
// the graph.
export function transpose(a: Tensor, perm: readonly number[]): Tensor {
  const site = captureSite('transpose')
  const outShape = inferTranspose('transpose', a.shape, perm, site)
  return addOp(currentGraph(), 'transpose', outShape, a.dtype, site, { a: a.id, perm: [...perm] })
}
|
|
130
|
+
|
|
131
|
+
// ----------------------------------------------------------------------------
|
|
132
|
+
// Linear algebra.
|
|
133
|
+
// ----------------------------------------------------------------------------
|
|
134
|
+
|
|
135
|
+
// a [..., M, K] · b [K, N] -> [..., M, N], with `b` unbatched (see matmulBatched
// for the batched-on-both-sides form). Both operands must be f32.
export function matmul(a: Tensor, b: Tensor): Tensor {
  const site = captureSite('matmul')
  const bothF32 = a.dtype === 'f32' && b.dtype === 'f32'
  if (!bothF32) {
    throw new ShapeError(`matmul: requires f32, got ${a.dtype} and ${b.dtype}`, site)
  }
  const outShape = inferMatmul('matmul', a.shape, b.shape, site)
  return addOp(currentGraph(), 'matmul', outShape, 'f32', site, { a: a.id, b: b.id })
}
|
|
143
|
+
|
|
144
|
+
// a [..., M, K] · b [..., K, N] -> [..., M, N]; both sides batched (used by
// attention). Both operands must be f32.
export function matmulBatched(a: Tensor, b: Tensor): Tensor {
  const site = captureSite('matmulBatched')
  if ([a, b].some((t) => t.dtype !== 'f32')) {
    throw new ShapeError(`matmulBatched: requires f32, got ${a.dtype} and ${b.dtype}`, site)
  }
  const batchedShape = inferMatmulBatched('matmulBatched', a.shape, b.shape, site)
  return addOp(currentGraph(), 'matmul_batched', batchedShape, 'f32', site, { a: a.id, b: b.id })
}
|
|
152
|
+
|
|
153
|
+
// ----------------------------------------------------------------------------
|
|
154
|
+
// Indexing / casting.
|
|
155
|
+
// ----------------------------------------------------------------------------
|
|
156
|
+
|
|
157
|
+
// One-hot encode i32 `indices` with `depth` classes; the output shape comes
// from inferOneHot and the output dtype is `dtype` (f32 by default).
export function oneHot(indices: Tensor, depth: number, dtype: Dtype = 'f32'): Tensor {
  const site = captureSite('oneHot')
  if (indices.dtype === 'i32') {
    const outShape = inferOneHot('oneHot', indices.shape, depth, site)
    return addOp(currentGraph(), 'one_hot', outShape, dtype, site, { indices: indices.id, depth, dtype })
  }
  throw new ShapeError(`oneHot: indices must be i32, got ${indices.dtype}`, site)
}
|
|
165
|
+
|
|
166
|
+
// arange(n) -> a [n] tensor holding 0, 1, ..., n-1. Used for position
// embeddings. `n` must be a positive integer; dtype defaults to i32.
export function arange(n: number, dtype: Dtype = 'i32'): Tensor {
  const site = captureSite('arange')
  const valid = Number.isInteger(n) && n > 0
  if (!valid) {
    throw new ShapeError(`arange: n must be a positive integer, got ${n}`, site)
  }
  return addOp(currentGraph(), 'arange', [n], dtype, site, { n, dtype })
}
|
|
174
|
+
|
|
175
|
+
// ----------------------------------------------------------------------------
|
|
176
|
+
// ML primitives. Fused so autograd's transpose rule is straightforward and the
|
|
177
|
+
// kernels can be hand-tuned for our specific shapes.
|
|
178
|
+
// ----------------------------------------------------------------------------
|
|
179
|
+
|
|
180
|
+
// Causal-masked softmax over the last axis, emitted as one fused op. The last
// two axes must be square (T x T attention scores); the output shape equals
// the input shape. f32 only.
export function softmaxCausalLast(a: Tensor): Tensor {
  const site = captureSite('softmaxCausalLast')
  if (a.dtype !== 'f32') {
    throw new ShapeError(`softmaxCausalLast: requires f32, got ${a.dtype}`, site)
  }
  // whereCausal's shape rule is used purely as a validator (square last two
  // axes); its result is unused because the shape is unchanged.
  void inferWhereCausal('softmaxCausalLast', a.shape, site)
  const graph = currentGraph()
  return addOp(graph, 'softmax_causal_last', a.shape, 'f32', site, { a: a.id })
}
|
|
188
|
+
|
|
189
|
+
// Numerically-stable log-softmax over the last axis; output shape equals input
// shape. f32 only.
// NOTE(review): unlike the other last-axis ops, no shape rule is consulted here
// (e.g. a 0-d input is not rejected) — confirm that is intended.
export function logSoftmaxLast(a: Tensor): Tensor {
  const site = captureSite('logSoftmaxLast')
  if (a.dtype === 'f32') {
    return addOp(currentGraph(), 'log_softmax_last', a.shape, 'f32', site, { a: a.id })
  }
  throw new ShapeError(`logSoftmaxLast: requires f32, got ${a.dtype}`, site)
}
|
|
195
|
+
|
|
196
|
+
// Pre-softmax causal mask on the last two axes (which must be square): cells
// with i < j are replaced by `fillValue` (typically -1e30); cells with i >= j
// (the lower triangle) pass through. Use this when you want the masked scores
// explicitly (e.g. for capture); otherwise prefer softmaxCausalLast, which
// fuses mask and softmax.
export function whereCausal(a: Tensor, fillValue: number): Tensor {
  const site = captureSite('whereCausal')
  if (a.dtype === 'f32') {
    inferWhereCausal('whereCausal', a.shape, site)
    return addOp(currentGraph(), 'where_causal', a.shape, 'f32', site, { a: a.id, fillValue })
  }
  throw new ShapeError(`whereCausal: requires f32, got ${a.dtype}`, site)
}
|
|
206
|
+
|
|
207
|
+
// ----------------------------------------------------------------------------
|
|
208
|
+
// Slicing.
|
|
209
|
+
// ----------------------------------------------------------------------------
|
|
210
|
+
|
|
211
|
+
// sliceLastRange(a, start, end): slice [start, end) along the last axis.
|
|
212
|
+
// Used for splitting Q/K/V from a fused QKV matmul.
|
|
213
|
+
export function sliceLastRange(a: Tensor, start: number, end: number): Tensor {
|
|
214
|
+
const site = captureSite('sliceLastRange')
|
|
215
|
+
const outShape = inferSliceLastRange('sliceLastRange', a.shape, start, end, site)
|
|
216
|
+
return addOp(currentGraph(), 'slice_last_range', outShape, a.dtype, site, { a: a.id, start, end })
|
|
217
|
+
}
|
|
218
|
+
|
|
219
|
+
// ----------------------------------------------------------------------------
|
|
220
|
+
// Broadcast / un-broadcast. Mostly used by autograd, but exposed in case user
|
|
221
|
+
// code needs them (e.g. explicit broadcasting for clarity).
|
|
222
|
+
// ----------------------------------------------------------------------------
|
|
223
|
+
|
|
224
|
+
export function broadcastTo(a: Tensor, targetShape: Shape): Tensor {
|
|
225
|
+
const site = captureSite('broadcastTo')
|
|
226
|
+
inferBroadcastTo('broadcastTo', a.shape, targetShape, site)
|
|
227
|
+
return addOp(currentGraph(), 'broadcast_to', targetShape, a.dtype, site, { a: a.id, targetShape })
|
|
228
|
+
}
|
|
229
|
+
|
|
230
|
+
export function sumToShape(a: Tensor, targetShape: Shape): Tensor {
|
|
231
|
+
const site = captureSite('sumToShape')
|
|
232
|
+
inferSumToShape('sumToShape', a.shape, targetShape, site)
|
|
233
|
+
return addOp(currentGraph(), 'sum_to_shape', targetShape, a.dtype, site, { a: a.id, targetShape })
|
|
234
|
+
}
|
|
235
|
+
|
|
236
|
+
// ----------------------------------------------------------------------------
|
|
237
|
+
// Constants.
|
|
238
|
+
// ----------------------------------------------------------------------------
|
|
239
|
+
|
|
240
|
+
// 0-d tensor with a constant value. Used by autograd to seed the loss cotangent.
|
|
241
|
+
export function constScalar(value: number, dtype: Dtype = 'f32'): Tensor {
|
|
242
|
+
const site = captureSite('constScalar')
|
|
243
|
+
return addOp(currentGraph(), 'const_scalar', [], dtype, site, { value, dtype })
|
|
244
|
+
}
|
|
245
|
+
|
|
246
|
+
// ----------------------------------------------------------------------------
|
|
247
|
+
// Autograd-internal helpers (exposed for users writing custom transpose rules).
|
|
248
|
+
// ----------------------------------------------------------------------------
|
|
249
|
+
|
|
250
|
+
// ----------------------------------------------------------------------------
|
|
251
|
+
// Comparisons and selection.
|
|
252
|
+
// ----------------------------------------------------------------------------
|
|
253
|
+
|
|
254
|
+
// Comparisons reuse the binop helper but return bool.
|
|
255
|
+
export const less = (a: Tensor, b: Tensor): Tensor => binopOp('less', 'less', a, b, 'bool')
|
|
256
|
+
export const greater = (a: Tensor, b: Tensor): Tensor => binopOp('greater', 'greater', a, b, 'bool')
|
|
257
|
+
|
|
258
|
+
// where(cond, a, b): elementwise select. cond is bool; a and b can be any matching dtype.
|
|
259
|
+
export function where(cond: Tensor, a: Tensor, b: Tensor): Tensor {
|
|
260
|
+
const site = captureSite('where')
|
|
261
|
+
if (cond.dtype !== 'bool') throw new ShapeError(`where: cond must be bool, got ${cond.dtype}`, site)
|
|
262
|
+
if (a.dtype !== b.dtype) throw new ShapeError(`where: a/b dtype mismatch (${a.dtype} vs ${b.dtype})`, site)
|
|
263
|
+
const outShape = inferWhere('where', cond.shape, a.shape, b.shape, site)
|
|
264
|
+
return addOp(currentGraph(), 'where', outShape, a.dtype, site, { cond: cond.id, a: a.id, b: b.id })
|
|
265
|
+
}
|
|
266
|
+
|
|
267
|
+
// reluGrad(x, dy) = dy where x > 0, else 0. Same shape as x. This is the
|
|
268
|
+
// transpose rule for relu, exposed as an op so codegen can emit it.
|
|
269
|
+
export function reluGrad(x: Tensor, dy: Tensor): Tensor {
|
|
270
|
+
const site = captureSite('reluGrad')
|
|
271
|
+
if (x.dtype !== 'f32' || dy.dtype !== 'f32') {
|
|
272
|
+
throw new ShapeError(`reluGrad: requires f32, got ${x.dtype} and ${dy.dtype}`, site)
|
|
273
|
+
}
|
|
274
|
+
const outShape = inferReluGrad('reluGrad', x.shape, dy.shape, site)
|
|
275
|
+
return addOp(currentGraph(), 'relu_grad', outShape, 'f32', site, { x: x.id, dy: dy.id })
|
|
276
|
+
}
|
|
277
|
+
|
|
278
|
+
// ----------------------------------------------------------------------------
|
|
279
|
+
// Adam-fused ops. Each does its full per-element update in one kernel.
|
|
280
|
+
// ----------------------------------------------------------------------------
|
|
281
|
+
|
|
282
|
+
export function adamUpdateM(m: Tensor, g: Tensor, b1: number): Tensor {
|
|
283
|
+
const site = captureSite('adamUpdateM')
|
|
284
|
+
if (m.dtype !== 'f32' || g.dtype !== 'f32') throw new ShapeError(`adamUpdateM: requires f32`, site)
|
|
285
|
+
if (m.shape.length !== g.shape.length || m.shape.some((d, i) => d !== g.shape[i])) {
|
|
286
|
+
throw new ShapeError(`adamUpdateM: shape mismatch`, site)
|
|
287
|
+
}
|
|
288
|
+
return addOp(currentGraph(), 'adam_update_m', m.shape, 'f32', site, { m: m.id, g: g.id, b1 })
|
|
289
|
+
}
|
|
290
|
+
|
|
291
|
+
export function adamUpdateV(v: Tensor, g: Tensor, b2: number): Tensor {
|
|
292
|
+
const site = captureSite('adamUpdateV')
|
|
293
|
+
if (v.dtype !== 'f32' || g.dtype !== 'f32') throw new ShapeError(`adamUpdateV: requires f32`, site)
|
|
294
|
+
if (v.shape.length !== g.shape.length || v.shape.some((d, i) => d !== g.shape[i])) {
|
|
295
|
+
throw new ShapeError(`adamUpdateV: shape mismatch`, site)
|
|
296
|
+
}
|
|
297
|
+
return addOp(currentGraph(), 'adam_update_v', v.shape, 'f32', site, { v: v.id, g: g.id, b2 })
|
|
298
|
+
}
|
|
299
|
+
|
|
300
|
+
export function adamUpdateP(p: Tensor, mNew: Tensor, vNew: Tensor, lrt: Tensor, eps: number): Tensor {
|
|
301
|
+
const site = captureSite('adamUpdateP')
|
|
302
|
+
if (p.dtype !== 'f32') throw new ShapeError(`adamUpdateP: requires f32`, site)
|
|
303
|
+
if (lrt.dtype !== 'f32' || lrt.shape.length !== 0) {
|
|
304
|
+
throw new ShapeError(`adamUpdateP: lrt must be a 0-d f32 scalar`, site)
|
|
305
|
+
}
|
|
306
|
+
if (p.shape.length !== mNew.shape.length || p.shape.some((d, i) => d !== mNew.shape[i])) {
|
|
307
|
+
throw new ShapeError(`adamUpdateP: p/mNew shape mismatch`, site)
|
|
308
|
+
}
|
|
309
|
+
return addOp(currentGraph(), 'adam_update_p', p.shape, 'f32', site,
|
|
310
|
+
{ p: p.id, mNew: mNew.id, vNew: vNew.id, lrt: lrt.id, eps })
|
|
311
|
+
}
|