tensorgrad 0.0.5 → 0.0.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/adam.d.ts +6 -1
- package/dist/adam.d.ts.map +1 -1
- package/dist/adam.js +21 -19
- package/dist/adam.js.map +1 -1
- package/dist/compile.d.ts.map +1 -1
- package/dist/compile.js +20 -30
- package/dist/compile.js.map +1 -1
- package/dist/grad.js +5 -16
- package/dist/grad.js.map +1 -1
- package/dist/index.d.ts +1 -1
- package/dist/index.d.ts.map +1 -1
- package/dist/index.js +3 -3
- package/dist/index.js.map +1 -1
- package/dist/module.d.ts +8 -0
- package/dist/module.d.ts.map +1 -1
- package/dist/module.js +19 -3
- package/dist/module.js.map +1 -1
- package/dist/nn.d.ts +12 -0
- package/dist/nn.d.ts.map +1 -1
- package/dist/nn.js +57 -1
- package/dist/nn.js.map +1 -1
- package/dist/ops.d.ts +13 -0
- package/dist/ops.d.ts.map +1 -1
- package/dist/ops.js +40 -1
- package/dist/ops.js.map +1 -1
- package/dist/runtime.d.ts +35 -29
- package/dist/runtime.d.ts.map +1 -1
- package/dist/runtime.js +28 -25
- package/dist/runtime.js.map +1 -1
- package/package.json +1 -1
- package/src/adam.ts +25 -16
- package/src/compile.ts +28 -27
- package/src/grad.ts +5 -16
- package/src/index.ts +3 -3
- package/src/module.ts +24 -2
- package/src/nn.ts +60 -1
- package/src/ops.ts +42 -1
- package/src/runtime.ts +58 -47
package/src/compile.ts
CHANGED

```diff
@@ -93,7 +93,7 @@ export async function compileModule<M extends Module>(
 ): Promise<CompiledRuntime & { ir: CompiledIR; uploadInitialParams: () => void }> {
   const inputDecls = opts.inputs ?? []
   const model = modelFactory()
-  let materialized: ReturnType<typeof materializeParams> = { tensors: {}, initFns: {} }
+  let materialized: ReturnType<typeof materializeParams> = { tensors: {}, initFns: {}, decayFlags: {} }
   const graph = trace(() => {
     materialized = materializeParams(model)
     const inputTensors = inputDecls.map(d => tensorInput(d.name, d.shape, d.dtype ?? 'f32'))
@@ -104,7 +104,7 @@ export async function compileModule<M extends Module>(
 
   let adamResult: ReturnType<typeof appendAdam> | undefined
   if (opts.adam) {
-    adamResult = appendAdam(graph, paramGrads, materialized.tensors, opts.adam)
+    adamResult = appendAdam(graph, paramGrads, materialized.tensors, opts.adam, materialized.decayFlags)
   }
 
   const plan = planBuffers(graph, paramGrads, adamResult?.writebacks ?? [])
@@ -144,16 +144,8 @@ export async function compileModule<M extends Module>(
     }
   }
 
-  const { initFns } = materialized
   const uploadInitialParams = () => {
-    const out: Record<string, Float32Array> = {}
-    for (const [name, bufId] of plan.paramsByName) {
-      const shape = plan.buffers[bufId]!.shape
-      const size = shape.reduce((a, b) => a * b, 1)
-      const initFn = initFns[name]
-      if (!initFn) throw new Error(`uploadInitialParams: no init for param '${name}'`)
-      out[name] = initFn(size, shape)
-    }
+    const out = buildInitialParamUploads(plan, materialized.initFns)
     runtime.uploadParams(out)
   }
 
@@ -161,6 +153,28 @@ export async function compileModule<M extends Module>(
   return Object.assign(runtime, { ir, uploadInitialParams })
 }
 
+// Build a Record<paramName, Float32Array> by running each param's init
+// function against its shape. Shared by compileModule and compileForward.
+// `sharedParams`, when supplied, skips any name it covers (those are owned
+// by the sibling compile and already initialized there).
+type InitFn = (size: number, shape: readonly number[]) => Float32Array
+function buildInitialParamUploads(
+  plan: BufferPlan,
+  initFns: Record<string, InitFn>,
+  sharedParams?: Map<string, GPUBuffer>,
+): Record<string, Float32Array> {
+  const out: Record<string, Float32Array> = {}
+  for (const [name, bufId] of plan.paramsByName) {
+    if (sharedParams?.has(name)) continue
+    const shape = plan.buffers[bufId]!.shape
+    const size = shape.reduce((a, b) => a * b, 1)
+    const initFn = initFns[name]
+    if (!initFn) throw new Error(`uploadInitialParams: no init for param '${name}'`)
+    out[name] = initFn(size, shape)
+  }
+  return out
+}
+
 // ============================================================================
 // Forward-only compile
 // ============================================================================
@@ -187,7 +201,7 @@ export async function compileForward<M extends Module>(
 ): Promise<CompiledForward & { ir: CompiledIR; uploadInitialParams: () => void }> {
   const inputDecls = opts.inputs ?? []
   const model = modelFactory()
-  let materialized: ReturnType<typeof materializeParams> = { tensors: {}, initFns: {} }
+  let materialized: ReturnType<typeof materializeParams> = { tensors: {}, initFns: {}, decayFlags: {} }
   const graph = trace(() => {
     materialized = materializeParams(model)
     const inputTensors = inputDecls.map(d => tensorInput(d.name, d.shape, d.dtype ?? 'f32'))
@@ -201,22 +215,9 @@
   const runtime = await createForwardRuntime(plan, kernels, outputBufferId, opts)
 
   const sharedParams = opts.sharedParams
-  const { initFns } = materialized
   const uploadInitialParams = () => {
-    const out: Record<string, Float32Array> = {}
-    let needsUpload = false
-    for (const [name, bufId] of plan.paramsByName) {
-      // Skip params covered by sharedParams — those are owned by the providing
-      // compile and already initialized there.
-      if (sharedParams?.has(name)) continue
-      const shape = plan.buffers[bufId]!.shape
-      const size = shape.reduce((a, b) => a * b, 1)
-      const initFn = initFns[name]
-      if (!initFn) throw new Error(`uploadInitialParams: no init for param '${name}'`)
-      out[name] = initFn(size, shape)
-      needsUpload = true
-    }
-    if (needsUpload) runtime.uploadParams(out, { partial: !!sharedParams })
+    const out = buildInitialParamUploads(plan, materialized.initFns, sharedParams)
+    if (Object.keys(out).length > 0) runtime.uploadParams(out, { partial: !!sharedParams })
  }
 
   // CompiledIR.loss is the field name; for forward-only, it carries the user's
```
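Taken together, these changes route initial-parameter uploads through one shared helper and thread per-param decay flags into the optimizer. A minimal sketch of the resulting call flow — `MyModel` and the `adam` option values (`lr`, `weightDecay`) are illustrative, with the option shapes inferred from this diff rather than from documentation:

```ts
import { compileModule, compileForward } from 'tensorgrad'
import { MyModel } from './my-model.js' // hypothetical Module subclass

// Training compile: decayFlags from materializeParams now reach appendAdam,
// so randn-initialized params get decoupled weight decay while zeros/ones
// params (biases, LN gains) are skipped automatically.
const train = await compileModule(() => new MyModel(), {
  inputs: [{ name: 'tokens', shape: [64, 128], dtype: 'i32' }],
  adam: { lr: 3e-4, weightDecay: 0.1 }, // illustrative option values
})
train.uploadInitialParams() // buildInitialParamUploads over every param

// A sibling forward-only compile can borrow the training buffers: its
// uploadInitialParams only fills params NOT covered by sharedParams, and
// skips the GPU upload entirely when nothing is left to initialize.
const infer = await compileForward(() => new MyModel(), {
  inputs: [{ name: 'tokens', shape: [1, 128], dtype: 'i32' }],
  sharedParams: train.params,
})
infer.uploadInitialParams()
```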
package/src/grad.ts
CHANGED

```diff
@@ -18,7 +18,7 @@
 import type { Graph, OpNode, Tensor, Shape } from './ir.js'
 import {
   add, sub, mul, div, mulScalar,
-  matmul, matmulBatched, transpose, reshape,
+  matmul, matmulBatched, transpose, swapAxes, reshape,
   exp,
   broadcastTo, sumToShape,
   constScalar, reluGrad,
@@ -280,14 +280,10 @@ function runTransposeRule(
   // leading batch dims to get [K, N].
   const a = tensorOf(op.a), b = tensorOf(op.b)
   // dA = dC @ B^T
-  const bT = transpose(b, [1, 0])
-  accumulate(cotangents, op.a, matmul(outCotan, bT))
+  accumulate(cotangents, op.a, matmul(outCotan, swapAxes(b, -1, -2)))
   // dB: per-batch A^T @ dC, then sum over batch dims.
   // A is [..., M, K]; transpose last two axes.
-  const aTPerm = Array.from({ length: a.shape.length }, (_, i) => i)
-  ;[aTPerm[a.shape.length - 1], aTPerm[a.shape.length - 2]] =
-    [aTPerm[a.shape.length - 2]!, aTPerm[a.shape.length - 1]!]
-  const aT = transpose(a, aTPerm) // [..., K, M]
+  const aT = swapAxes(a, -1, -2) // [..., K, M]
   // matmul_batched needs same rank on both sides. dC has rank `a.rank`;
   // aT has rank `a.rank`; use matmul_batched if rank > 2, else matmul.
   let perBatchDb: Tensor
@@ -305,15 +301,8 @@ function runTransposeRule(
   // dA = dC @ B^T (per-batch, all batch dims preserved)
   // dB = A^T @ dC (per-batch)
   const a = tensorOf(op.a), b = tensorOf(op.b)
-  const lastTwoSwap = (rank: number) => {
-    const p = Array.from({ length: rank }, (_, i) => i)
-    ;[p[rank - 1], p[rank - 2]] = [p[rank - 2]!, p[rank - 1]!]
-    return p
-  }
-  const bT = transpose(b, lastTwoSwap(b.shape.length))
-  const aT = transpose(a, lastTwoSwap(a.shape.length))
-  accumulate(cotangents, op.a, matmulBatched(outCotan, bT))
-  accumulate(cotangents, op.b, matmulBatched(aT, outCotan))
+  accumulate(cotangents, op.a, matmulBatched(outCotan, swapAxes(b, -1, -2)))
+  accumulate(cotangents, op.b, matmulBatched(swapAxes(a, -1, -2), outCotan))
   return
 }
 
```
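Both rewritten rules implement the standard matmul transpose identities — for `C = A @ B`, `dA = dC @ Bᵀ` and `dB = Aᵀ @ dC`, transposing only the last two axes. A small sketch of what `swapAxes(x, -1, -2)` folds away, reconstructed from the inline permutation-building the removed code did by hand:

```ts
// The removed code built an identity permutation, swapped its last two
// entries, and passed it to transpose. For rank 4:
//   lastTwoSwapPerm(4) -> [0, 1, 3, 2]
// so transpose(x, [0, 1, 3, 2]) maps [..., M, K] to [..., K, M] — exactly
// what swapAxes(x, -1, -2) now does in one call, for any rank >= 2.
function lastTwoSwapPerm(rank: number): number[] {
  const p = Array.from({ length: rank }, (_, i) => i)
  ;[p[rank - 1], p[rank - 2]] = [p[rank - 2]!, p[rank - 1]!]
  return p
}
```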
package/src/index.ts
CHANGED

```diff
@@ -15,13 +15,13 @@ export {
   // Comparisons + select
   less, greater, where,
   // Reductions over the last axis (other axes via reshape/transpose first)
-  meanLast, sumLast,
+  meanLast, sumLast, sumAll,
   // Shape ops
-  reshape, transpose,
+  reshape, transpose, swapAxes,
   // Linear algebra
   matmul, matmulBatched,
   // Indexing / casting
-  oneHot, arange,
+  oneHot, arange, embedding,
   // ML primitives — fused for the transformer
   softmaxCausalLast, logSoftmaxLast, whereCausal,
   // Slicing
```
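A quick sketch of the three new root exports in use — `wte`, `tokens`, `k`, and `w` are assumed to be tensors already in scope:

```ts
import { sumAll, swapAxes, embedding, mul } from 'tensorgrad'

const h  = embedding(wte, tokens) // [B, T] i32 ids + [V, D] table -> [B, T, D]
const kT = swapAxes(k, -1, -2)    // [B, H, T, d] -> [B, H, d, T]
const l2 = sumAll(mul(w, w))      // any shape -> 0-d scalar
```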
package/src/module.ts
CHANGED

```diff
@@ -52,6 +52,11 @@ export interface ParamOptions {
   init?: InitSpec
   /** Std dev for `'randn'`. Default 0.02. Ignored for non-randn init. */
   scale?: number
+  /** Whether AdamW (when `weightDecay > 0`) should apply decoupled weight
+   * decay to this param. Default: `true` for `'randn'` init (weight matrices,
+   * embeddings), `false` for `'zeros'` / `'ones'` (biases, LN gains). Override
+   * to force or skip. Replaces `adam.decayFilter` for the common case. */
+  decay?: boolean
 }
 
 type InitFn = (size: number, shape: readonly number[]) => Float32Array
@@ -76,6 +81,17 @@ function resolveInit(opts: ParamOptions | undefined): InitFn {
   throw new Error(`Unknown init: ${String(init)}`)
 }
 
+/** Resolve the decay default for a param. Decay weight matrices and
+ * embedding tables (randn-initialized); skip biases (zeros) and LN gains
+ * (ones). Custom init functions default to "decay" — most user-supplied
+ * inits are weight-shaped (Kaiming etc.). Explicit `decay: false` overrides. */
+function resolveDecay(opts: ParamOptions | undefined): boolean {
+  if (opts?.decay !== undefined) return opts.decay
+  const init = opts?.init ?? 'randn'
+  if (init === 'zeros' || init === 'ones') return false
+  return true // 'randn' or function
+}
+
 // ============================================================================
 // Internals: param sentinel
 // ============================================================================
@@ -90,6 +106,7 @@ class ParamSentinel {
     public readonly shape: Shape,
     public readonly dtype: Dtype,
     public readonly initFn: InitFn,
+    public readonly decay: boolean,
   ) {}
 }
 
@@ -110,7 +127,7 @@ export abstract class Module {
   protected param(shape: Shape, opts?: ParamOptions): Tensor {
     const dtype = opts?.dtype ?? 'f32'
     // Lie to TypeScript: the sentinel becomes a Tensor at materialize time.
-    return new ParamSentinel(shape, dtype, resolveInit(opts)) as unknown as Tensor
+    return new ParamSentinel(shape, dtype, resolveInit(opts), resolveDecay(opts)) as unknown as Tensor
   }
 }
 
@@ -123,6 +140,9 @@ export interface MaterializedParams {
   tensors: Record<string, Tensor>
   /** Init function per param path. Used by `uploadInitialParams`. */
   initFns: Record<string, InitFn>
+  /** Whether this param should receive AdamW weight decay. Resolved at
+   * `param()` time from `ParamOptions.decay` (with init-based default). */
+  decayFlags: Record<string, boolean>
 }
 
 /**
@@ -136,15 +156,17 @@ export interface MaterializedParams {
 export function materializeParams(root: Module): MaterializedParams {
   const tensors: Record<string, Tensor> = {}
   const initFns: Record<string, InitFn> = {}
+  const decayFlags: Record<string, boolean> = {}
   visit(root, '', (path, val, owner, key) => {
     if (val instanceof ParamSentinel) {
       const t = paramInput(path, val.shape, val.dtype)
       ;(owner as any)[key] = t
       tensors[path] = t
       initFns[path] = val.initFn
+      decayFlags[path] = val.decay
     }
   })
-  return { tensors, initFns }
+  return { tensors, initFns, decayFlags }
 }
 
 // ----------------------------------------------------------------------------
```
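Because the default keys off the init spec, existing modules pick up sensible AdamW behavior without edits, and `decay` covers the exceptions. A sketch — the `TinyBlock` module is illustrative:

```ts
import { Module } from 'tensorgrad'

class TinyBlock extends Module {
  w = this.param([512, 512])                     // randn -> decays by default
  b = this.param([512], { init: 'zeros' })       // zeros -> no decay (bias)
  g = this.param([512], { init: 'ones' })        // ones  -> no decay (LN gain)
  pos = this.param([128, 512], { decay: false }) // randn, explicitly exempted
}
```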
package/src/nn.ts
CHANGED

```diff
@@ -15,7 +15,9 @@
 
 import { Module } from './module.js'
 import type { Tensor } from './ir.js'
-import { add, matmul, sub, mul, div, sqrt, meanLast } from './ops.js'
+import { add, matmul, sub, mul, div, sqrt, meanLast, sumLast, reshape, swapAxes, oneHot, logSoftmaxLast } from './ops.js'
+import { ShapeError } from './shape.js'
+import { captureSite } from './ir.js'
 
 // ----------------------------------------------------------------------------
 // Linear: y = x @ W (+ b)
@@ -57,3 +59,60 @@ export function layerNormFwd(p: LayerNorm, x: Tensor): Tensor {
   const stdev = sqrt(add(v, p.eps))
   return add(mul(div(c, stdev), p.g), p.b)
 }
+
+// ----------------------------------------------------------------------------
+// Multi-head attention shape helpers — split the last (model) axis into
+// [nHeads, headDim] and bring heads ahead of the sequence axis.
+// ----------------------------------------------------------------------------
+
+/** [..., T, D] → [..., H, T, D/H]. Folds the standard
+ * `transpose(reshape(x, [..., T, H, d]), [..., H, T, d])` pattern into one
+ * call. Last dim of `x` must divide evenly by `nHeads`. */
+export function splitHeads(x: Tensor, nHeads: number): Tensor {
+  const site = captureSite('splitHeads')
+  const r = x.shape.length
+  if (r < 2) throw new ShapeError(`splitHeads: requires rank >= 2, got ${r}`, site)
+  const T = x.shape[r - 2]!
+  const D = x.shape[r - 1]!
+  if (D % nHeads !== 0) {
+    throw new ShapeError(`splitHeads: last dim ${D} not divisible by nHeads ${nHeads}`, site)
+  }
+  const lead = x.shape.slice(0, r - 2)
+  const reshaped = reshape(x, [...lead, T, nHeads, D / nHeads])
+  // Swap T (axis lead.length) with H (axis lead.length + 1).
+  return swapAxes(reshaped, lead.length, lead.length + 1)
+}
+
+/** Inverse of `splitHeads`: [..., H, T, d] → [..., T, H*d]. */
+export function mergeHeads(x: Tensor): Tensor {
+  const site = captureSite('mergeHeads')
+  const r = x.shape.length
+  if (r < 3) throw new ShapeError(`mergeHeads: requires rank >= 3, got ${r}`, site)
+  const H = x.shape[r - 3]!
+  const T = x.shape[r - 2]!
+  const d = x.shape[r - 1]!
+  const lead = x.shape.slice(0, r - 3)
+  // Swap H (axis r-3) and T (axis r-2): [..., H, T, d] → [..., T, H, d]
+  const swapped = swapAxes(x, r - 3, r - 2)
+  return reshape(swapped, [...lead, T, H * d])
+}
+
+// ----------------------------------------------------------------------------
+// Loss helpers
+// ----------------------------------------------------------------------------
+
+/** Per-position cross-entropy along the last (vocab) axis: returns
+ * `-log p(target)` at each position. `logits` is `[..., V]`; `targets` is
+ * `[...]` of i32; result is `[...]` (one rank less than logits). The user
+ * applies their own masking + reduction downstream — useful when only some
+ * positions contribute (e.g. result-digit masking) or for label smoothing. */
+export function crossEntropyLast(logits: Tensor, targets: Tensor): Tensor {
+  const site = captureSite('crossEntropyLast')
+  if (targets.dtype !== 'i32') {
+    throw new ShapeError(`crossEntropyLast: targets must be i32, got ${targets.dtype}`, site)
+  }
+  const vocab = logits.shape[logits.shape.length - 1]!
+  const lp = logSoftmaxLast(logits) // [..., V]
+  const targetLp = sumLast(mul(lp, oneHot(targets, vocab, 'f32'))) // [...]
+  return mul(targetLp, -1)
+}
```
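A sketch of how the new helpers compose into a causal attention forward, following the shape annotations above. One assumption: that `matmul` accepts a rank-3 left operand against a rank-2 weight, which the batch-dim comments in grad.ts suggest but this diff does not state outright.

```ts
import type { Tensor } from './ir.js'
import { matmul, matmulBatched, mul, swapAxes, softmaxCausalLast } from './ops.js'
import { splitHeads, mergeHeads } from './nn.js'

// x: [B, T, D]; wq/wk/wv/wo: [D, D]; nHeads must divide D.
function attention(x: Tensor, wq: Tensor, wk: Tensor, wv: Tensor, wo: Tensor, nHeads: number): Tensor {
  const q = splitHeads(matmul(x, wq), nHeads) // [B, H, T, d]
  const k = splitHeads(matmul(x, wk), nHeads)
  const v = splitHeads(matmul(x, wv), nHeads)
  const d = q.shape[q.shape.length - 1]!
  const scores = matmulBatched(q, swapAxes(k, -1, -2))         // [B, H, T, T]
  const att = softmaxCausalLast(mul(scores, 1 / Math.sqrt(d))) // causal softmax
  return matmul(mergeHeads(matmulBatched(att, v)), wo)         // [B, T, D]
  // Downstream: crossEntropyLast(logits /*[B,T,V]*/, targets /*[B,T] i32*/)
  // yields per-position losses [B, T] for caller-side masking + reduction.
}
```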
package/src/ops.ts
CHANGED

```diff
@@ -17,7 +17,7 @@ import {
   inferReshape, inferTranspose, inferMatmul, inferMatmulBatched,
   inferOneHot, inferWhereCausal, inferSliceLastRange,
   inferBroadcastTo, inferSumToShape, inferReluGrad, inferWhere,
-  ShapeError,
+  ShapeError, showShape,
 } from './shape.js'
 
 // ----------------------------------------------------------------------------
@@ -112,6 +112,11 @@ export function sumLast(a: Tensor): Tensor {
   return addOp(currentGraph(), 'sum_last', outShape, a.dtype, site, { a: a.id })
 }
 
+/** Reduce all elements to a 0-d scalar. Composes `reshape` + `sumLast`. */
+export function sumAll(a: Tensor): Tensor {
+  return sumLast(reshape(a, [-1]))
+}
+
 // ----------------------------------------------------------------------------
 // Shape ops.
 // ----------------------------------------------------------------------------
@@ -128,6 +133,26 @@ export function transpose(a: Tensor, perm: readonly number[]): Tensor {
   return addOp(currentGraph(), 'transpose', outShape, a.dtype, site, { a: a.id, perm })
 }
 
+/** Swap two axes of a tensor. Negative indices count from the end (so
+ * `swapAxes(x, -1, -2)` swaps the last two — the common attention pattern).
+ * All other axes keep their position. Implemented as `transpose` with the
+ * permutation `[0, 1, ..., axis2, ..., axis1, ..., n-1]`. */
+export function swapAxes(a: Tensor, axis1: number, axis2: number): Tensor {
+  const r = a.shape.length
+  const norm = (axis: number): number => axis < 0 ? r + axis : axis
+  const i1 = norm(axis1)
+  const i2 = norm(axis2)
+  const site = captureSite('swapAxes')
+  if (i1 < 0 || i1 >= r || i2 < 0 || i2 >= r) {
+    throw new ShapeError(`swapAxes: axis out of range — got (${axis1}, ${axis2}) for rank-${r} tensor`, site)
+  }
+  if (i1 === i2) return a
+  const perm = Array.from({ length: r }, (_, k) => k)
+  perm[i1] = i2
+  perm[i2] = i1
+  return transpose(a, perm)
+}
+
 // ----------------------------------------------------------------------------
 // Linear algebra.
 // ----------------------------------------------------------------------------
@@ -163,6 +188,22 @@ export function oneHot(indices: Tensor, depth: number, dtype: Dtype = 'f32'): Tensor {
   return addOp(currentGraph(), 'one_hot', outShape, dtype, site, { indices: indices.id, depth, dtype })
 }
 
+/** Embedding lookup: pull rows from `table` indexed by `indices`. Decomposes
+ * to `oneHot(indices, vocab) @ table` so autograd works without a dedicated
+ * scatter-with-atomic-add backward — the matmul transpose rule handles it.
+ * `table` is `[vocab, dim]`; `indices` is any shape `[...]` of i32; result
+ * is `[..., dim]`. The vocab size is taken from `table.shape[0]`. */
+export function embedding(table: Tensor, indices: Tensor): Tensor {
+  const site = captureSite('embedding')
+  if (table.shape.length !== 2) {
+    throw new ShapeError(`embedding: table must be 2-d [vocab, dim], got ${showShape(table.shape)}`, site)
+  }
+  if (indices.dtype !== 'i32') {
+    throw new ShapeError(`embedding: indices must be i32, got ${indices.dtype}`, site)
+  }
+  return matmul(oneHot(indices, table.shape[0]!, 'f32'), table)
+}
+
 // arange(n) → [n] of values [0, 1, ..., n-1]. Used for position embeddings.
 export function arange(n: number, dtype: Dtype = 'i32'): Tensor {
   const site = captureSite('arange')
```
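Concretely, for a rank-4 input `swapAxes(x, -1, -2)` normalizes the negative axes to 2 and 3 and emits `transpose(x, [0, 1, 3, 2])`, while `embedding` and `sumAll` are pure sugar over existing ops. A sketch of the equivalences, assuming `k`, `table`, `ids`, and `a` are tensors in scope:

```ts
import { swapAxes, transpose, embedding, oneHot, matmul, sumAll, sumLast, reshape } from './ops.js'

const kT   = swapAxes(k, -1, -2)          // [B, H, T, d] -> [B, H, d, T]
const same = transpose(k, [0, 1, 3, 2])   // the permutation swapAxes builds for rank 4

// embedding(table, ids) builds the same graph as its expansion:
const e1 = embedding(table, ids)                              // [B, T, D]
const e2 = matmul(oneHot(ids, table.shape[0]!, 'f32'), table) // equivalent

const total = sumAll(a) // == sumLast(reshape(a, [-1])), a 0-d scalar
```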
package/src/runtime.ts
CHANGED

```diff
@@ -43,17 +43,38 @@ export interface RunWithCaptures {
   captures: Record<string, Float32Array>
 }
 
-
-/**
- *
-export interface CompiledRuntime {
+/** Common surface for both training and forward-only compiled runtimes. */
+export interface CompiledBase {
+  /** Param name -> the underlying GPUBuffer. Pass to a sibling compile via
+   * `sharedParams` to share without copies. */
   params: Map<string, GPUBuffer>
+  /** Shape of each tensor registered via `capture(name, t)`. Static after
+   * compile — reshape readbacks without recomputing strides. */
+  captureShapes: Record<string, number[]>
+  /** Shape of the graph's output (loss scalar `[]` for training; the user's
+   * returned tensor for forward-only compiles). */
+  outputShape: number[]
   /** Upload parameter Float32Arrays to their GPU buffers. By default, requires
    * *all* params to be present; throws on any unknown or missing key. Pass
    * `{ partial: true }` to skip the missing-key check. */
   uploadParams(params: Record<string, Float32Array>, opts?: UploadParamsOptions): void
   /** Read all parameters back as Float32Arrays — used for UI panels. */
   downloadParams(): Promise<Record<string, Float32Array>>
+  /** Free GPU resources. */
+  destroy(): void
+}
+
+/** Run a dispatch and read back the full output tensor (and any registered
+ * captures if requested). Forward-only compiles use this as their primary
+ * surface; training compiles also expose it but `step()` is more convenient
+ * there because the output is a scalar loss. */
+export interface RunFn {
+  (inputs: Record<string, Int32Array | Float32Array>): Promise<Float32Array>
+  (inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<RunWithCaptures>
+  (inputs: Record<string, Int32Array | Float32Array>, opts: RunOptions): Promise<Float32Array | RunWithCaptures>
+}
+
+export interface CompiledRuntime extends CompiledBase {
   /** Read all parameter gradients back. Mostly for verification / debugging. */
   downloadParamGrads(): Promise<Record<string, Float32Array>>
   /**
@@ -68,32 +89,19 @@ export interface CompiledRuntime {
   step(inputs: Record<string, Int32Array | Float32Array>): Promise<number>
   step(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<StepWithCaptures>
   step(inputs: Record<string, Int32Array | Float32Array>, opts: StepOptions): Promise<number | StepWithCaptures>
-  /**
-   *
-   *
-  run(inputs: Record<string, Int32Array | Float32Array>): Promise<Float32Array>
-  run(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<RunWithCaptures>
-  run(inputs: Record<string, Int32Array | Float32Array>, opts: RunOptions): Promise<Float32Array | RunWithCaptures>
+  /** Same dispatch as step() but returns the full output Float32Array — for
+   * training graphs the output is a scalar loss, so step() is usually more
+   * convenient. Provided for parity with `compileForward`. */
+  run: RunFn
   /** Re-zero all optimizer state buffers (Adam's m/v) in place. Pair with
    * `uploadInitialParams()` for a full training reset without recompile. */
   resetOptimizerState(): void
-  /** Free GPU resources. */
-  destroy(): void
 }
 
 /** Forward-only compiled runtime — produced by `compileForward`. No optimizer,
  * no backward. Returns the output tensor (not just a scalar) per `run()` call. */
-export interface CompiledForward {
-  params: Map<string, GPUBuffer>
-  uploadParams(params: Record<string, Float32Array>, opts?: UploadParamsOptions): void
-  downloadParams(): Promise<Record<string, Float32Array>>
-  /** Forward-only dispatch. Returns the graph's output tensor as a Float32Array
-   * (the user's returned tensor from the forward function, in row-major order).
-   * With `{ withCaptures: true }`, returns `{ output, captures }`. */
-  run(inputs: Record<string, Int32Array | Float32Array>): Promise<Float32Array>
-  run(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<RunWithCaptures>
-  run(inputs: Record<string, Int32Array | Float32Array>, opts: RunOptions): Promise<Float32Array | RunWithCaptures>
-  destroy(): void
+export interface CompiledForward extends CompiledBase {
+  run: RunFn
 }
 
 export interface RuntimeOpts {
@@ -147,14 +155,7 @@ export async function createRuntime(
       label: spec.name ?? `t${spec.id}-${spec.kind}`,
     })
     buffers.set(spec.id, buf)
-    if (spec.kind === 'state') {
-      // Fill with initValue (typically 0). Float and int both 4 bytes per element.
-      const elements = spec.byteSize / 4
-      const init = spec.dtype === 'f32'
-        ? new Float32Array(elements).fill(spec.initValue ?? 0)
-        : new Int32Array(elements).fill(Math.trunc(spec.initValue ?? 0))
-      queue.writeBuffer(buf, 0, init as unknown as BufferSource)
-    }
+    if (spec.kind === 'state') fillStateBuffer(spec, buf)
   }
   // Track which params are externally owned — those are skipped on destroy().
   const ownedBufferIds = new Set<number>()
@@ -404,14 +405,20 @@ export async function createRuntime(
     return out
   }
 
+  // Fill a state buffer with its declared initValue (typically 0). Float and
+  // int both serialize to 4 bytes per element. Used at allocation time and on
+  // resetOptimizerState() — same logic, two callers.
+  function fillStateBuffer(spec: { byteSize: number; dtype: 'f32' | 'i32' | 'bool'; initValue?: number }, target: GPUBuffer): void {
+    const elements = spec.byteSize / 4
+    const init = spec.dtype === 'f32'
+      ? new Float32Array(elements).fill(spec.initValue ?? 0)
+      : new Int32Array(elements).fill(Math.trunc(spec.initValue ?? 0))
+    queue.writeBuffer(target, 0, init as unknown as BufferSource)
+  }
+
   function resetOptimizerState() {
     for (const spec of plan.buffers) {
-      if (spec.kind !== 'state') continue
-      const elements = spec.byteSize / 4
-      const init = spec.dtype === 'f32'
-        ? new Float32Array(elements).fill(spec.initValue ?? 0)
-        : new Int32Array(elements).fill(Math.trunc(spec.initValue ?? 0))
-      queue.writeBuffer(buffers.get(spec.id)!, 0, init as unknown as BufferSource)
+      if (spec.kind === 'state') fillStateBuffer(spec, buffers.get(spec.id)!)
     }
   }
 
@@ -421,6 +428,13 @@ export async function createRuntime(
   for (const [name, bufId] of plan.paramsByName) {
     params.set(name, buffers.get(bufId)!)
   }
+  // Static-after-compile shape metadata so users don't have to recompute
+  // strides to interpret a flat capture readback.
+  const captureShapes: Record<string, number[]> = {}
+  for (const [name, bufId] of plan.capturesByName) {
+    captureShapes[name] = [...plan.buffers[bufId]!.shape]
+  }
+  const outputShape = [...plan.buffers[lossBufferId]!.shape]
 
   const destroy = () => {
     for (const [id, b] of buffers) {
@@ -432,6 +446,8 @@ export async function createRuntime(
 
   return {
     params,
+    captureShapes,
+    outputShape,
     uploadParams,
     downloadParams: () => downloadFromMap(plan.paramsByName),
     downloadParamGrads: () => downloadFromMap(plan.paramGradsByName),
@@ -442,22 +458,17 @@ export async function createRuntime(
   }
 }
 
-/** Same machinery as `createRuntime`, narrower public
- *
+/** Same machinery as `createRuntime`, narrower public type: a forward-only
+ * graph exposes `run()` instead of `step()` (no optimizer state, no scalar-
+ * loss readback). The full runtime object is built once and projected by
+ * `compileForward` to the public shape. */
 export async function createForwardRuntime(
   plan: BufferPlan,
   kernels: KernelSpec[],
   outputBufferId: number,
   opts: RuntimeOpts = {},
): Promise<CompiledForward> {
-  const full = await createRuntime(plan, kernels, outputBufferId, opts)
-  return {
-    params: full.params,
-    uploadParams: full.uploadParams,
-    downloadParams: full.downloadParams,
-    run: full.run,
-    destroy: full.destroy,
-  }
+  return await createRuntime(plan, kernels, outputBufferId, opts)
 }
 
 async function acquireDevice(): Promise<GPUDevice> {
```
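With `CompiledBase`, code that only uploads, runs, and reads back can accept either runtime, and the new shape metadata removes the need to re-derive strides for capture readbacks. A sketch — whether `CompiledBase` and `RunFn` are re-exported from the package root is assumed, and the helper itself is illustrative:

```ts
import type { CompiledBase, RunFn } from 'tensorgrad'

// Accepts either a CompiledRuntime or a CompiledForward.
async function dumpCaptures(
  rt: CompiledBase & { run: RunFn },
  inputs: Record<string, Int32Array | Float32Array>,
): Promise<void> {
  const { output, captures } = await rt.run(inputs, { withCaptures: true })
  console.log('output', rt.outputShape, `${output.length} elements`)
  for (const [name, flat] of Object.entries(captures)) {
    console.log(name, rt.captureShapes[name], flat.slice(0, 4)) // shapes are static after compile
  }
}
```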