tensorgrad 0.0.9 → 0.0.12
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +119 -119
- package/dist/compile.d.ts +77 -28
- package/dist/compile.d.ts.map +1 -1
- package/dist/compile.js +132 -81
- package/dist/compile.js.map +1 -1
- package/dist/index.d.ts +2 -2
- package/dist/index.d.ts.map +1 -1
- package/dist/index.js +2 -2
- package/dist/index.js.map +1 -1
- package/dist/nn.d.ts +14 -11
- package/dist/nn.d.ts.map +1 -1
- package/dist/nn.js +28 -33
- package/dist/nn.js.map +1 -1
- package/dist/runtime.d.ts +35 -27
- package/dist/runtime.d.ts.map +1 -1
- package/dist/runtime.js +45 -10
- package/dist/runtime.js.map +1 -1
- package/package.json +61 -61
- package/src/compile.ts +358 -227
- package/src/index.ts +47 -42
- package/src/nn.ts +34 -32
- package/src/runtime.ts +523 -497
package/src/index.ts
CHANGED
|
@@ -1,42 +1,47 @@
|
|
|
1
|
-
// Public surface. Bulb code imports from here.
|
|
2
|
-
//
|
|
3
|
-
// Phase 1 exports: IR types, op surface, trace driver. Autograd (Phase 2) and
|
|
4
|
-
// codegen / compile() (Phase 3+) come later.
|
|
5
|
-
|
|
6
|
-
export type { Tensor, Shape, Dtype, OpNode, Graph, CallSite } from './ir.js'
|
|
7
|
-
export { ShapeError } from './shape.js'
|
|
8
|
-
export { trace, traceInto, paramInput, tensorInput, stateInput } from './trace.js'
|
|
9
|
-
export { capture } from './capture.js'
|
|
10
|
-
export {
|
|
11
|
-
// Element-wise arithmetic. The binops accept Tensor or JS-number for the second arg.
|
|
12
|
-
add, sub, mul, div,
|
|
13
|
-
// Element-wise unary
|
|
14
|
-
sqrt, rsqrt, log, exp, relu,
|
|
15
|
-
// Comparisons + select
|
|
16
|
-
less, greater, where,
|
|
17
|
-
// Reductions over the last axis (other axes via reshape/transpose first)
|
|
18
|
-
meanLast, sumLast, sumAll,
|
|
19
|
-
// Shape ops
|
|
20
|
-
reshape, transpose, swapAxes,
|
|
21
|
-
// Linear algebra
|
|
22
|
-
matmul, matmulBatched,
|
|
23
|
-
// Indexing / casting
|
|
24
|
-
oneHot, arange, embedding,
|
|
25
|
-
// ML primitives — fused for the transformer
|
|
26
|
-
softmaxCausalLast, logSoftmaxLast, whereCausal,
|
|
27
|
-
// Slicing
|
|
28
|
-
sliceLastRange,
|
|
29
|
-
} from './ops.js'
|
|
30
|
-
|
|
31
|
-
// Note: addScalar/mulScalar/broadcastTo/sumToShape/constScalar/reluGrad/adam_update_*
|
|
32
|
-
// are autograd/optimizer building blocks. They live in ops.ts (so grad.ts and
|
|
33
|
-
// adam.ts can import them) but aren't part of the public API — `add`/`mul`
|
|
34
|
-
// overload on JS numbers, `where` subsumes the rest.
|
|
35
|
-
export { appendGrad, type GradResult } from './grad.js'
|
|
36
|
-
export { appendAdam, type AdamConfig, type AdamResult } from './adam.js'
|
|
37
|
-
export { planBuffers, type BufferPlan, type BufferSpec, type Writeback, type WritebackDecl } from './buffers.js'
|
|
38
|
-
export { emitKernels, type KernelSpec } from './codegen.js'
|
|
39
|
-
export { createRuntime, createForwardRuntime, type CompiledRuntime, type CompiledForward, type RuntimeOpts, type
|
|
40
|
-
export {
|
|
41
|
-
|
|
42
|
-
|
|
1
|
+
// Public surface. Bulb code imports from here.
|
|
2
|
+
//
|
|
3
|
+
// Phase 1 exports: IR types, op surface, trace driver. Autograd (Phase 2) and
|
|
4
|
+
// codegen / compile() (Phase 3+) come later.
|
|
5
|
+
|
|
6
|
+
export type { Tensor, Shape, Dtype, OpNode, Graph, CallSite } from './ir.js'
|
|
7
|
+
export { ShapeError } from './shape.js'
|
|
8
|
+
export { trace, traceInto, paramInput, tensorInput, stateInput } from './trace.js'
|
|
9
|
+
export { capture } from './capture.js'
|
|
10
|
+
export {
|
|
11
|
+
// Element-wise arithmetic. The binops accept Tensor or JS-number for the second arg.
|
|
12
|
+
add, sub, mul, div,
|
|
13
|
+
// Element-wise unary
|
|
14
|
+
sqrt, rsqrt, log, exp, relu,
|
|
15
|
+
// Comparisons + select
|
|
16
|
+
less, greater, where,
|
|
17
|
+
// Reductions over the last axis (other axes via reshape/transpose first)
|
|
18
|
+
meanLast, sumLast, sumAll,
|
|
19
|
+
// Shape ops
|
|
20
|
+
reshape, transpose, swapAxes,
|
|
21
|
+
// Linear algebra
|
|
22
|
+
matmul, matmulBatched,
|
|
23
|
+
// Indexing / casting
|
|
24
|
+
oneHot, arange, embedding,
|
|
25
|
+
// ML primitives — fused for the transformer
|
|
26
|
+
softmaxCausalLast, logSoftmaxLast, whereCausal,
|
|
27
|
+
// Slicing
|
|
28
|
+
sliceLastRange,
|
|
29
|
+
} from './ops.js'
|
|
30
|
+
|
|
31
|
+
// Note: addScalar/mulScalar/broadcastTo/sumToShape/constScalar/reluGrad/adam_update_*
|
|
32
|
+
// are autograd/optimizer building blocks. They live in ops.ts (so grad.ts and
|
|
33
|
+
// adam.ts can import them) but aren't part of the public API — `add`/`mul`
|
|
34
|
+
// overload on JS numbers, `where` subsumes the rest.
|
|
35
|
+
export { appendGrad, type GradResult } from './grad.js'
|
|
36
|
+
export { appendAdam, type AdamConfig, type AdamResult } from './adam.js'
|
|
37
|
+
export { planBuffers, type BufferPlan, type BufferSpec, type Writeback, type WritebackDecl } from './buffers.js'
|
|
38
|
+
export { emitKernels, type KernelSpec } from './codegen.js'
|
|
39
|
+
export { createRuntime, createForwardRuntime, Captures, type CompiledRuntime, type CompiledForward, type RuntimeOpts, type RunOptions, type StepResult, type RunResult } from './runtime.js'
|
|
40
|
+
export {
|
|
41
|
+
compile, compileToIR, compileModule, compileForward,
|
|
42
|
+
type CompiledIR, type CompileModuleOptions, type CompileForwardOptions, type CompileForwardMethodOptions,
|
|
43
|
+
type CompiledModule, type CompiledForwardModule,
|
|
44
|
+
type InputDecl, type InputDecls, type InputsTensors, type ForwardFn,
|
|
45
|
+
} from './compile.js'
|
|
46
|
+
export { Module, materializeParams, type InitSpec, type ParamOptions, type MaterializedParams } from './module.js'
|
|
47
|
+
export * as nn from './nn.js'
|
package/src/nn.ts
CHANGED
|
@@ -1,41 +1,44 @@
|
|
|
1
1
|
// Standard "batteries-included" Module subclasses for the most common layers.
|
|
2
2
|
//
|
|
3
|
-
//
|
|
4
|
-
//
|
|
5
|
-
//
|
|
6
|
-
//
|
|
7
|
-
// Import as a namespace:
|
|
3
|
+
// Each class declares its params and a `.fwd(x)` method that runs the forward
|
|
4
|
+
// computation. Forward methods are pure tensorgrad ops — autograd traces
|
|
5
|
+
// through them just like any other call.
|
|
8
6
|
//
|
|
9
7
|
// import { nn } from 'tensorgrad'
|
|
10
8
|
// class Block extends Module {
|
|
11
9
|
// ln = new nn.LayerNorm(D)
|
|
12
10
|
// ffn = new nn.Linear(D, 4 * D)
|
|
13
11
|
// }
|
|
14
|
-
// const y =
|
|
12
|
+
// const y = p.ffn.fwd(p.ln.fwd(x))
|
|
15
13
|
|
|
16
14
|
import { Module } from './module.js'
|
|
17
15
|
import type { Tensor } from './ir.js'
|
|
18
16
|
import { add, matmul, sub, mul, div, sqrt, meanLast, sumLast, reshape, swapAxes, oneHot, logSoftmaxLast } from './ops.js'
|
|
19
17
|
import { ShapeError } from './shape.js'
|
|
20
18
|
import { captureSite } from './ir.js'
|
|
19
|
+
import type { Captures } from './runtime.js'
|
|
21
20
|
|
|
22
21
|
// ----------------------------------------------------------------------------
|
|
23
22
|
// Linear: y = x @ W (+ b)
|
|
24
23
|
// ----------------------------------------------------------------------------
|
|
25
24
|
|
|
25
|
+
export interface LinearOptions {
|
|
26
|
+
/** Include a bias term (default true). */
|
|
27
|
+
bias?: boolean
|
|
28
|
+
}
|
|
29
|
+
|
|
26
30
|
export class Linear extends Module {
|
|
27
31
|
W: Tensor
|
|
28
32
|
b: Tensor | null
|
|
29
|
-
constructor(public readonly inDim: number, public readonly outDim: number,
|
|
33
|
+
constructor(public readonly inDim: number, public readonly outDim: number, opts: LinearOptions = {}) {
|
|
30
34
|
super()
|
|
31
35
|
this.W = this.param([inDim, outDim]) // randn, scale 0.02
|
|
32
|
-
this.b =
|
|
36
|
+
this.b = opts.bias === false ? null : this.param([outDim], { init: 'zeros' })
|
|
37
|
+
}
|
|
38
|
+
fwd(x: Tensor): Tensor {
|
|
39
|
+
const out = matmul(x, this.W)
|
|
40
|
+
return this.b ? add(out, this.b) : out
|
|
33
41
|
}
|
|
34
|
-
}
|
|
35
|
-
|
|
36
|
-
export function linearFwd(p: Linear, x: Tensor): Tensor {
|
|
37
|
-
const out = matmul(x, p.W)
|
|
38
|
-
return p.b ? add(out, p.b) : out
|
|
39
42
|
}
|
|
40
43
|
|
|
41
44
|
// ----------------------------------------------------------------------------
|
|
@@ -50,14 +53,13 @@ export class LayerNorm extends Module {
|
|
|
50
53
|
this.g = this.param([d], { init: 'ones' })
|
|
51
54
|
this.b = this.param([d], { init: 'zeros' })
|
|
52
55
|
}
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
return add(mul(div(c, stdev), p.g), p.b)
|
|
56
|
+
fwd(x: Tensor): Tensor {
|
|
57
|
+
const m = meanLast(x)
|
|
58
|
+
const c = sub(x, m)
|
|
59
|
+
const v = meanLast(mul(c, c))
|
|
60
|
+
const stdev = sqrt(add(v, this.eps))
|
|
61
|
+
return add(mul(div(c, stdev), this.g), this.b)
|
|
62
|
+
}
|
|
61
63
|
}
|
|
62
64
|
|
|
63
65
|
// ----------------------------------------------------------------------------
|
|
@@ -97,26 +99,26 @@ export function mergeHeads(x: Tensor): Tensor {
|
|
|
97
99
|
return reshape(swapped, [...lead, T, H * d])
|
|
98
100
|
}
|
|
99
101
|
|
|
100
|
-
/** Slice a
|
|
101
|
-
*
|
|
102
|
-
*
|
|
103
|
-
*
|
|
104
|
-
*
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
102
|
+
/** Slice a captured tensor named `name` into one Float32Array per head, using
|
|
103
|
+
* the static shape registered at compile time. The leading axis is treated as
|
|
104
|
+
* heads (matching `splitHeads` layout at B=1); a leading singleton batch is
|
|
105
|
+
* stripped if present so callers can pass capture names directly. Throws if
|
|
106
|
+
* the capture isn't registered or wasn't read back this call. */
|
|
107
|
+
export function unsplitHeads(captures: Captures, name: string): Float32Array[] {
|
|
108
|
+
const flat = captures.get(name)
|
|
109
|
+
const shape = captures.shapeOf(name)
|
|
108
110
|
if (shape.length < 2) {
|
|
109
|
-
throw new Error(`unsplitHeads: shape needs >= 2 dims, got [${shape.join(', ')}]`)
|
|
111
|
+
throw new Error(`unsplitHeads: '${name}' shape needs >= 2 dims, got [${shape.join(', ')}]`)
|
|
110
112
|
}
|
|
111
113
|
// For inference graphs at B=1, captures have shape [1, H, ..., ...]. Strip
|
|
112
|
-
// the leading 1 if present so
|
|
114
|
+
// the leading 1 if present so the next axis is heads.
|
|
113
115
|
const s = shape[0] === 1 ? shape.slice(1) : shape
|
|
114
116
|
const H = s[0]!
|
|
115
117
|
let stride = 1
|
|
116
118
|
for (let i = 1; i < s.length; i++) stride *= s[i]!
|
|
117
119
|
const expected = H * stride
|
|
118
120
|
if (flat.length !== expected) {
|
|
119
|
-
throw new Error(`unsplitHeads:
|
|
121
|
+
throw new Error(`unsplitHeads: '${name}' length ${flat.length} doesn't match shape product ${expected}`)
|
|
120
122
|
}
|
|
121
123
|
return Array.from({ length: H }, (_, h) => flat.slice(h * stride, (h + 1) * stride))
|
|
122
124
|
}
|