tensorgrad 0.0.9 → 0.0.12

package/src/index.ts CHANGED
@@ -1,42 +1,47 @@
- // Public surface. Bulb code imports from here.
- //
- // Phase 1 exports: IR types, op surface, trace driver. Autograd (Phase 2) and
- // codegen / compile() (Phase 3+) come later.
-
- export type { Tensor, Shape, Dtype, OpNode, Graph, CallSite } from './ir.js'
- export { ShapeError } from './shape.js'
- export { trace, traceInto, paramInput, tensorInput, stateInput } from './trace.js'
- export { capture } from './capture.js'
- export {
-   // Element-wise arithmetic. The binops accept Tensor or JS-number for the second arg.
-   add, sub, mul, div,
-   // Element-wise unary
-   sqrt, rsqrt, log, exp, relu,
-   // Comparisons + select
-   less, greater, where,
-   // Reductions over the last axis (other axes via reshape/transpose first)
-   meanLast, sumLast, sumAll,
-   // Shape ops
-   reshape, transpose, swapAxes,
-   // Linear algebra
-   matmul, matmulBatched,
-   // Indexing / casting
-   oneHot, arange, embedding,
-   // ML primitives — fused for the transformer
-   softmaxCausalLast, logSoftmaxLast, whereCausal,
-   // Slicing
-   sliceLastRange,
- } from './ops.js'
-
- // Note: addScalar/mulScalar/broadcastTo/sumToShape/constScalar/reluGrad/adam_update_*
- // are autograd/optimizer building blocks. They live in ops.ts (so grad.ts and
- // adam.ts can import them) but aren't part of the public API — `add`/`mul`
- // overload on JS numbers, `where` subsumes the rest.
- export { appendGrad, type GradResult } from './grad.js'
- export { appendAdam, type AdamConfig, type AdamResult } from './adam.js'
- export { planBuffers, type BufferPlan, type BufferSpec, type Writeback, type WritebackDecl } from './buffers.js'
- export { emitKernels, type KernelSpec } from './codegen.js'
- export { createRuntime, createForwardRuntime, type CompiledRuntime, type CompiledForward, type RuntimeOpts, type StepOptions, type StepWithCaptures, type RunOptions, type RunWithCaptures } from './runtime.js'
- export { compile, compileToIR, compileModule, compileForward, type CompiledIR, type CompileModuleOptions, type CompileForwardOptions, type InputDecl } from './compile.js'
- export { Module, materializeParams, type InitSpec, type ParamOptions, type MaterializedParams } from './module.js'
- export * as nn from './nn.js'
+ // Public surface. Bulb code imports from here.
+ //
+ // Phase 1 exports: IR types, op surface, trace driver. Autograd (Phase 2) and
+ // codegen / compile() (Phase 3+) come later.
+
+ export type { Tensor, Shape, Dtype, OpNode, Graph, CallSite } from './ir.js'
+ export { ShapeError } from './shape.js'
+ export { trace, traceInto, paramInput, tensorInput, stateInput } from './trace.js'
+ export { capture } from './capture.js'
+ export {
+   // Element-wise arithmetic. The binops accept Tensor or JS-number for the second arg.
+   add, sub, mul, div,
+   // Element-wise unary
+   sqrt, rsqrt, log, exp, relu,
+   // Comparisons + select
+   less, greater, where,
+   // Reductions over the last axis (other axes via reshape/transpose first)
+   meanLast, sumLast, sumAll,
+   // Shape ops
+   reshape, transpose, swapAxes,
+   // Linear algebra
+   matmul, matmulBatched,
+   // Indexing / casting
+   oneHot, arange, embedding,
+   // ML primitives — fused for the transformer
+   softmaxCausalLast, logSoftmaxLast, whereCausal,
+   // Slicing
+   sliceLastRange,
+ } from './ops.js'
+
+ // Note: addScalar/mulScalar/broadcastTo/sumToShape/constScalar/reluGrad/adam_update_*
+ // are autograd/optimizer building blocks. They live in ops.ts (so grad.ts and
+ // adam.ts can import them) but aren't part of the public API — `add`/`mul`
+ // overload on JS numbers, `where` subsumes the rest.
+ export { appendGrad, type GradResult } from './grad.js'
+ export { appendAdam, type AdamConfig, type AdamResult } from './adam.js'
+ export { planBuffers, type BufferPlan, type BufferSpec, type Writeback, type WritebackDecl } from './buffers.js'
+ export { emitKernels, type KernelSpec } from './codegen.js'
+ export { createRuntime, createForwardRuntime, Captures, type CompiledRuntime, type CompiledForward, type RuntimeOpts, type RunOptions, type StepResult, type RunResult } from './runtime.js'
+ export {
+   compile, compileToIR, compileModule, compileForward,
+   type CompiledIR, type CompileModuleOptions, type CompileForwardOptions, type CompileForwardMethodOptions,
+   type CompiledModule, type CompiledForwardModule,
+   type InputDecl, type InputDecls, type InputsTensors, type ForwardFn,
+ } from './compile.js'
+ export { Module, materializeParams, type InitSpec, type ParamOptions, type MaterializedParams } from './module.js'
+ export * as nn from './nn.js'
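For orientation, a minimal sketch of the op surface re-exported above. It leans on two things the source comments state: the arithmetic binops accept a JS number for the second argument, and `where` subsumes the private building blocks. The exact `tensorInput` signature and the `where(cond, onTrue, onFalse)` argument order are not shown in this diff, so both are assumptions.

    import { tensorInput, add, mul, where, less, relu } from 'tensorgrad'

    // Assumption: tensorInput takes a shape and returns a traced placeholder.
    const x = tensorInput([4, 8])

    // Arithmetic binops take a JS number for the second arg.
    const scaled = mul(add(x, 1), 0.5)

    // `where` can express the private building blocks: a select against zero
    // reproduces relu. mul(x, 0) stands in for a zero tensor, since
    // constScalar is not part of the public API.
    const zero = mul(x, 0)
    const viaSelect = where(less(x, zero), zero, x) // assumed (cond, onTrue, onFalse)
    const viaOp = relu(x)                           // the exported fused op, same result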
package/src/nn.ts CHANGED
@@ -1,41 +1,44 @@
  // Standard "batteries-included" Module subclasses for the most common layers.
  //
- // JAX-style: each class declares its params (and their init); the forward is a
- // plain function the user calls with `(module, x)`. No subclassing, no method
- // dispatch keeps the autograd-traced computation visible at the call site.
- //
- // Import as a namespace:
+ // Each class declares its params and a `.fwd(x)` method that runs the forward
+ // computation. Forward methods are pure tensorgrad ops; autograd traces
+ // through them just like any other call.
  //
  //   import { nn } from 'tensorgrad'
  //   class Block extends Module {
  //     ln = new nn.LayerNorm(D)
  //     ffn = new nn.Linear(D, 4 * D)
  //   }
- //   const y = nn.linearFwd(p.ffn, nn.layerNormFwd(p.ln, x))
+ //   const y = p.ffn.fwd(p.ln.fwd(x))

  import { Module } from './module.js'
  import type { Tensor } from './ir.js'
  import { add, matmul, sub, mul, div, sqrt, meanLast, sumLast, reshape, swapAxes, oneHot, logSoftmaxLast } from './ops.js'
  import { ShapeError } from './shape.js'
  import { captureSite } from './ir.js'
+ import type { Captures } from './runtime.js'

  // ----------------------------------------------------------------------------
  // Linear: y = x @ W (+ b)
  // ----------------------------------------------------------------------------

+ export interface LinearOptions {
+   /** Include a bias term (default true). */
+   bias?: boolean
+ }
+
  export class Linear extends Module {
    W: Tensor
    b: Tensor | null
-   constructor(public readonly inDim: number, public readonly outDim: number, withBias = true) {
+   constructor(public readonly inDim: number, public readonly outDim: number, opts: LinearOptions = {}) {
      super()
      this.W = this.param([inDim, outDim]) // randn, scale 0.02
-     this.b = withBias ? this.param([outDim], { init: 'zeros' }) : null
+     this.b = opts.bias === false ? null : this.param([outDim], { init: 'zeros' })
+   }
+   fwd(x: Tensor): Tensor {
+     const out = matmul(x, this.W)
+     return this.b ? add(out, this.b) : out
    }
- }
-
- export function linearFwd(p: Linear, x: Tensor): Tensor {
-   const out = matmul(x, p.W)
-   return p.b ? add(out, p.b) : out
  }

  // ----------------------------------------------------------------------------
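The options object replaces the old positional `withBias` flag, and the forward now lives on the class. A usage sketch in the style of the header comment's Block example; giving the user's own Module a `fwd` method is a convention assumed here, not something this diff requires.

    import { Module, nn } from 'tensorgrad'
    import type { Tensor } from 'tensorgrad'

    class Block extends Module {
      ln = new nn.LayerNorm(64)
      ffn = new nn.Linear(64, 256)
      proj = new nn.Linear(256, 64, { bias: false }) // 0.0.9: new nn.Linear(256, 64, false)
      fwd(x: Tensor): Tensor {
        // Layer fwds are plain tensorgrad ops, so autograd traces through them.
        return this.proj.fwd(this.ffn.fwd(this.ln.fwd(x)))
      }
    }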
@@ -50,14 +53,13 @@ export class LayerNorm extends Module {
      this.g = this.param([d], { init: 'ones' })
      this.b = this.param([d], { init: 'zeros' })
    }
- }
-
- export function layerNormFwd(p: LayerNorm, x: Tensor): Tensor {
-   const m = meanLast(x)
-   const c = sub(x, m)
-   const v = meanLast(mul(c, c))
-   const stdev = sqrt(add(v, p.eps))
-   return add(mul(div(c, stdev), p.g), p.b)
+   fwd(x: Tensor): Tensor {
+     const m = meanLast(x)
+     const c = sub(x, m)
+     const v = meanLast(mul(c, c))
+     const stdev = sqrt(add(v, this.eps))
+     return add(mul(div(c, stdev), this.g), this.b)
+   }
  }

  // ----------------------------------------------------------------------------
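Written out, the `fwd` body above computes the standard affine layer norm over the last axis (the variance is the biased mean of squared deviations, and eps sits inside the sqrt):

    y = g * (x - m) / sqrt(meanLast((x - m)^2) + eps) + b,   where m = meanLast(x)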
@@ -97,26 +99,26 @@ export function mergeHeads(x: Tensor): Tensor {
    return reshape(swapped, [...lead, T, H * d])
  }

- /** Slice a flat capture readback of shape `[H, ..., ...]` into one
-  * Float32Array per head. The leading axis is treated as the head axis;
-  * pass the shape from `compiled.captureShapes[name]`. Result: `H` arrays,
-  * each holding the row-major data for that head (size = product of trailing
-  * axes). For B>1 graphs, prefix the result by the batch — this helper
-  * assumes the leading axis is heads, which matches how `splitHeads` lays
-  * out captures at B=1 (the typical capture-readback shape). */
- export function unsplitHeads(flat: Float32Array, shape: readonly number[]): Float32Array[] {
+ /** Slice a captured tensor named `name` into one Float32Array per head, using
+  * the static shape registered at compile time. The leading axis is treated as
+  * heads (matching `splitHeads` layout at B=1); a leading singleton batch is
+  * stripped if present so callers can pass capture names directly. Throws if
+  * the capture isn't registered or wasn't read back this call. */
+ export function unsplitHeads(captures: Captures, name: string): Float32Array[] {
+   const flat = captures.get(name)
+   const shape = captures.shapeOf(name)
    if (shape.length < 2) {
-     throw new Error(`unsplitHeads: shape needs >= 2 dims, got [${shape.join(', ')}]`)
+     throw new Error(`unsplitHeads: '${name}' shape needs >= 2 dims, got [${shape.join(', ')}]`)
    }
    // For inference graphs at B=1, captures have shape [1, H, ..., ...]. Strip
-   // the leading 1 if present so callers can pass captureShapes[name] directly.
+   // the leading 1 if present so the next axis is heads.
    const s = shape[0] === 1 ? shape.slice(1) : shape
    const H = s[0]!
    let stride = 1
    for (let i = 1; i < s.length; i++) stride *= s[i]!
    const expected = H * stride
    if (flat.length !== expected) {
-     throw new Error(`unsplitHeads: flat length ${flat.length} doesn't match shape product ${expected}`)
+     throw new Error(`unsplitHeads: '${name}' length ${flat.length} doesn't match shape product ${expected}`)
    }
    return Array.from({ length: H }, (_, h) => flat.slice(h * stride, (h + 1) * stride))
  }
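Call sites now hand `unsplitHeads` the `Captures` object and a name, replacing the old `unsplitHeads(flat, compiled.captureShapes[name])` pattern. A hedged sketch of the new call shape; the capture name 'attn', its [1, H, T, T] layout, and how the `Captures` value reaches the caller are all assumptions, since this diff shows only the export list for `runtime.js`.

    import { nn } from 'tensorgrad'
    import type { Captures } from 'tensorgrad'

    // Assumes a capture registered as 'attn' with B=1 shape [1, H, T, T].
    // unsplitHeads reads data + shape off the Captures handle, strips the
    // leading batch-1 axis, and returns one row-major [T, T] array per head.
    function attentionMaps(captures: Captures): Float32Array[] {
      return nn.unsplitHeads(captures, 'attn')
    }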