tensorgrad 0.0.7 → 0.0.8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/nn.d.ts CHANGED
@@ -22,6 +22,14 @@ export declare function layerNormFwd(p: LayerNorm, x: Tensor): Tensor;
22
22
  export declare function splitHeads(x: Tensor, nHeads: number): Tensor;
23
23
  /** Inverse of `splitHeads`: [..., H, T, d] → [..., T, H*d]. */
24
24
  export declare function mergeHeads(x: Tensor): Tensor;
25
/**
 * Slice a flat capture readback into one `Float32Array` per head.
 *
 * The leading axis of `shape` is treated as the head axis (`[H, ...rest]`);
 * pass the shape from `compiled.captureShapes[name]`. Returns `H` arrays,
 * each holding the row-major data for one head (length = product of `rest`).
 *
 * A leading unit batch axis from a B=1 capture (e.g. `[1, H, T, d]`) is
 * dropped first. B>1 captures are not split per head: the batch axis would
 * be read as the head axis, so slice out a single batch element first.
 *
 * @param flat  Row-major readback; its length must equal the product of `shape`.
 * @param shape Capture shape with at least 2 dims.
 * @throws Error if `shape` has fewer than 2 dims or its product differs from `flat.length`.
 */
export declare function unsplitHeads(flat: Float32Array, shape: readonly number[]): Float32Array[];
25
33
  /** Per-position cross-entropy along the last (vocab) axis: returns
26
34
  * `-log p(target)` at each position. `logits` is `[..., V]`; `targets` is
27
35
  * `[...]` of i32; result is `[...]` (one rank less than logits). The user
package/dist/nn.d.ts.map CHANGED
@@ -1 +1 @@
1
- {"version":3,"file":"nn.d.ts","sourceRoot":"","sources":["../src/nn.ts"],"names":[],"mappings":"AAeA,OAAO,EAAE,MAAM,EAAE,MAAM,aAAa,CAAA;AACpC,OAAO,KAAK,EAAE,MAAM,EAAE,MAAM,SAAS,CAAA;AASrC,qBAAa,MAAO,SAAQ,MAAM;aAGJ,KAAK,EAAE,MAAM;aAAkB,MAAM,EAAE,MAAM;IAFzE,CAAC,EAAE,MAAM,CAAA;IACT,CAAC,EAAE,MAAM,GAAG,IAAI,CAAA;gBACY,KAAK,EAAE,MAAM,EAAkB,MAAM,EAAE,MAAM,EAAE,QAAQ,UAAO;CAK3F;AAED,wBAAgB,SAAS,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,CAGtD;AAMD,qBAAa,SAAU,SAAQ,MAAM;aAGP,CAAC,EAAE,MAAM;aAAkB,GAAG,EAAE,MAAM;IAFlE,CAAC,EAAE,MAAM,CAAA;IACT,CAAC,EAAE,MAAM,CAAA;gBACmB,CAAC,EAAE,MAAM,EAAkB,GAAG,GAAE,MAAa;CAK1E;AAED,wBAAgB,YAAY,CAAC,CAAC,EAAE,SAAS,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,CAM5D;AAOD;;4DAE4D;AAC5D,wBAAgB,UAAU,CAAC,CAAC,EAAE,MAAM,EAAE,MAAM,EAAE,MAAM,GAAG,MAAM,CAa5D;AAED,+DAA+D;AAC/D,wBAAgB,UAAU,CAAC,CAAC,EAAE,MAAM,GAAG,MAAM,CAW5C;AAMD;;;;+EAI+E;AAC/E,wBAAgB,gBAAgB,CAAC,MAAM,EAAE,MAAM,EAAE,OAAO,EAAE,MAAM,GAAG,MAAM,CASxE"}
1
+ {"version":3,"file":"nn.d.ts","sourceRoot":"","sources":["../src/nn.ts"],"names":[],"mappings":"AAeA,OAAO,EAAE,MAAM,EAAE,MAAM,aAAa,CAAA;AACpC,OAAO,KAAK,EAAE,MAAM,EAAE,MAAM,SAAS,CAAA;AASrC,qBAAa,MAAO,SAAQ,MAAM;aAGJ,KAAK,EAAE,MAAM;aAAkB,MAAM,EAAE,MAAM;IAFzE,CAAC,EAAE,MAAM,CAAA;IACT,CAAC,EAAE,MAAM,GAAG,IAAI,CAAA;gBACY,KAAK,EAAE,MAAM,EAAkB,MAAM,EAAE,MAAM,EAAE,QAAQ,UAAO;CAK3F;AAED,wBAAgB,SAAS,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,CAGtD;AAMD,qBAAa,SAAU,SAAQ,MAAM;aAGP,CAAC,EAAE,MAAM;aAAkB,GAAG,EAAE,MAAM;IAFlE,CAAC,EAAE,MAAM,CAAA;IACT,CAAC,EAAE,MAAM,CAAA;gBACmB,CAAC,EAAE,MAAM,EAAkB,GAAG,GAAE,MAAa;CAK1E;AAED,wBAAgB,YAAY,CAAC,CAAC,EAAE,SAAS,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,CAM5D;AAOD;;4DAE4D;AAC5D,wBAAgB,UAAU,CAAC,CAAC,EAAE,MAAM,EAAE,MAAM,EAAE,MAAM,GAAG,MAAM,CAa5D;AAED,+DAA+D;AAC/D,wBAAgB,UAAU,CAAC,CAAC,EAAE,MAAM,GAAG,MAAM,CAW5C;AAED;;;;;;gEAMgE;AAChE,wBAAgB,YAAY,CAAC,IAAI,EAAE,YAAY,EAAE,KAAK,EAAE,SAAS,MAAM,EAAE,GAAG,YAAY,EAAE,CAezF;AAMD;;;;+EAI+E;AAC/E,wBAAgB,gBAAgB,CAAC,MAAM,EAAE,MAAM,EAAE,OAAO,EAAE,MAAM,GAAG,MAAM,CASxE"}
package/dist/nn.js CHANGED
@@ -95,6 +95,30 @@ export function mergeHeads(x) {
95
95
  const swapped = swapAxes(x, r - 3, r - 2);
96
96
  return reshape(swapped, [...lead, T, H * d]);
97
97
  }
98
/**
 * Slice a flat capture readback into one `Float32Array` per attention head.
 *
 * The leading axis of `shape` is the head axis, i.e. the expected layout is
 * `[H, ...rest]`; pass the shape from `compiled.captureShapes[name]`. Each
 * returned array is an independent copy of one head's row-major data, of
 * length `product(rest)`.
 *
 * Captures from a B=1 graph carry a unit batch axis (`[1, H, T, d]` after
 * `splitHeads`, or `[1, H, T, T]` for attention scores). For shapes of rank
 * 4 or more whose leading axis is 1, that axis is dropped before splitting.
 * Lower-rank shapes are never stripped, so a single-head capture such as
 * `[1, T, d]` yields one array rather than `T` arrays.
 *
 * B>1 is not supported: a `[B, H, ...]` shape with B>1 would be split along
 * the batch axis. Slice out one batch element (and its shape) first.
 *
 * @param flat  Row-major readback whose length equals the product of `shape`.
 * @param shape Capture shape with at least 2 dims.
 * @returns `H` arrays, one per head, in head order.
 * @throws Error if `shape` has fewer than 2 dims, contains a dim that is not
 *   a non-negative integer, or its product differs from `flat.length`.
 */
export function unsplitHeads(flat, shape) {
    if (shape.length < 2) {
        throw new Error(`unsplitHeads: shape needs >= 2 dims, got [${shape.join(', ')}]`);
    }
    if (!shape.every((d) => Number.isInteger(d) && d >= 0)) {
        throw new Error(`unsplitHeads: shape dims must be non-negative integers, got [${shape.join(', ')}]`);
    }
    // Drop a unit batch axis only when the remainder still has the
    // [H, ..., ...] rank; stripping [1, T, d] would misread T as the head count.
    const s = shape.length >= 4 && shape[0] === 1 ? shape.slice(1) : shape;
    const H = s[0];
    let stride = 1;
    for (let i = 1; i < s.length; i++)
        stride *= s[i];
    const expected = H * stride;
    if (flat.length !== expected) {
        throw new Error(`unsplitHeads: flat length ${flat.length} doesn't match shape product ${expected}`);
    }
    // slice() copies, so callers may mutate per-head arrays without touching `flat`.
    return Array.from({ length: H }, (_, h) => flat.slice(h * stride, (h + 1) * stride));
}
98
122
  // ----------------------------------------------------------------------------
99
123
  // Loss helpers
100
124
  // ----------------------------------------------------------------------------
package/dist/nn.js.map CHANGED
@@ -1 +1 @@
1
- {"version":3,"file":"nn.js","sourceRoot":"","sources":["../src/nn.ts"],"names":[],"mappings":"AAAA,8EAA8E;AAC9E,EAAE;AACF,+EAA+E;AAC/E,8EAA8E;AAC9E,6EAA6E;AAC7E,EAAE;AACF,yBAAyB;AACzB,EAAE;AACF,oCAAoC;AACpC,iCAAiC;AACjC,gCAAgC;AAChC,oCAAoC;AACpC,MAAM;AACN,4DAA4D;AAE5D,OAAO,EAAE,MAAM,EAAE,MAAM,aAAa,CAAA;AAEpC,OAAO,EAAE,GAAG,EAAE,MAAM,EAAE,GAAG,EAAE,GAAG,EAAE,GAAG,EAAE,IAAI,EAAE,QAAQ,EAAE,OAAO,EAAE,OAAO,EAAE,QAAQ,EAAE,MAAM,EAAE,cAAc,EAAE,MAAM,UAAU,CAAA;AACzH,OAAO,EAAE,UAAU,EAAE,MAAM,YAAY,CAAA;AACvC,OAAO,EAAE,WAAW,EAAE,MAAM,SAAS,CAAA;AAErC,+EAA+E;AAC/E,0BAA0B;AAC1B,+EAA+E;AAE/E,MAAM,OAAO,MAAO,SAAQ,MAAM;IAGJ;IAA+B;IAF3D,CAAC,CAAQ;IACT,CAAC,CAAe;IAChB,YAA4B,KAAa,EAAkB,MAAc,EAAE,QAAQ,GAAG,IAAI;QACxF,KAAK,EAAE,CAAA;QADmB,UAAK,GAAL,KAAK,CAAQ;QAAkB,WAAM,GAAN,MAAM,CAAQ;QAEvE,IAAI,CAAC,CAAC,GAAG,IAAI,CAAC,KAAK,CAAC,CAAC,KAAK,EAAE,MAAM,CAAC,CAAC,CAAA,CAAsB,oBAAoB;QAC9E,IAAI,CAAC,CAAC,GAAG,QAAQ,CAAC,CAAC,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC,MAAM,CAAC,EAAE,EAAE,IAAI,EAAE,OAAO,EAAE,CAAC,CAAC,CAAC,CAAC,IAAI,CAAA;IACpE,CAAC;CACF;AAED,MAAM,UAAU,SAAS,CAAC,CAAS,EAAE,CAAS;IAC5C,MAAM,GAAG,GAAG,MAAM,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,CAAA;IAC1B,OAAO,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,GAAG,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,GAAG,CAAA;AAClC,CAAC;AAED,+EAA+E;AAC/E,mEAAmE;AACnE,+EAA+E;AAE/E,MAAM,OAAO,SAAU,SAAQ,MAAM;IAGP;IAA2B;IAFvD,CAAC,CAAQ;IACT,CAAC,CAAQ;IACT,YAA4B,CAAS,EAAkB,MAAc,IAAI;QACvE,KAAK,EAAE,CAAA;QADmB,MAAC,GAAD,CAAC,CAAQ;QAAkB,QAAG,GAAH,GAAG,CAAe;QAEvE,IAAI,CAAC,CAAC,GAAG,IAAI,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,EAAE,EAAE,IAAI,EAAE,MAAM,EAAE,CAAC,CAAA;QAC1C,IAAI,CAAC,CAAC,GAAG,IAAI,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,EAAE,EAAE,IAAI,EAAE,OAAO,EAAE,CAAC,CAAA;IAC7C,CAAC;CACF;AAED,MAAM,UAAU,YAAY,CAAC,CAAY,EAAE,CAAS;IAClD,MAAM,CAAC,GAAG,QAAQ,CAAC,CAAC,CAAC,CAAA;IACrB,MAAM,CAAC,GAAG,GAAG,CAAC,CAAC,EAAE,CAAC,CAAC,CAAA;IACnB,MAAM,CAAC,GAAG,QAAQ,CAAC,GAAG,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAA;IAC7B,MAAM,KAAK,GAAG,IAAI,CAAC,GAAG,CAAC,CAAC,EAAE,CAAC,CAAC,GAAG,CAAC,CAAC,CAAA;IACjC,OAAO,GAAG,CAAC,GAAG,
CAAC,GAAG,CAAC,CAAC,EAAE,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,CAAA;AAC1C,CAAC;AAED,+EAA+E;AAC/E,wEAAwE;AACxE,gEAAgE;AAChE,+EAA+E;AAE/E;;4DAE4D;AAC5D,MAAM,UAAU,UAAU,CAAC,CAAS,EAAE,MAAc;IAClD,MAAM,IAAI,GAAG,WAAW,CAAC,YAAY,CAAC,CAAA;IACtC,MAAM,CAAC,GAAG,CAAC,CAAC,KAAK,CAAC,MAAM,CAAA;IACxB,IAAI,CAAC,GAAG,CAAC;QAAE,MAAM,IAAI,UAAU,CAAC,uCAAuC,CAAC,EAAE,EAAE,IAAI,CAAC,CAAA;IACjF,MAAM,CAAC,GAAG,CAAC,CAAC,KAAK,CAAC,CAAC,GAAG,CAAC,CAAE,CAAA;IACzB,MAAM,CAAC,GAAG,CAAC,CAAC,KAAK,CAAC,CAAC,GAAG,CAAC,CAAE,CAAA;IACzB,IAAI,CAAC,GAAG,MAAM,KAAK,CAAC,EAAE,CAAC;QACrB,MAAM,IAAI,UAAU,CAAC,wBAAwB,CAAC,4BAA4B,MAAM,EAAE,EAAE,IAAI,CAAC,CAAA;IAC3F,CAAC;IACD,MAAM,IAAI,GAAG,CAAC,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,EAAE,CAAC,GAAG,CAAC,CAAC,CAAA;IACpC,MAAM,QAAQ,GAAG,OAAO,CAAC,CAAC,EAAE,CAAC,GAAG,IAAI,EAAE,CAAC,EAAE,MAAM,EAAE,CAAC,GAAG,MAAM,CAAC,CAAC,CAAA;IAC7D,2DAA2D;IAC3D,OAAO,QAAQ,CAAC,QAAQ,EAAE,IAAI,CAAC,MAAM,EAAE,IAAI,CAAC,MAAM,GAAG,CAAC,CAAC,CAAA;AACzD,CAAC;AAED,+DAA+D;AAC/D,MAAM,UAAU,UAAU,CAAC,CAAS;IAClC,MAAM,IAAI,GAAG,WAAW,CAAC,YAAY,CAAC,CAAA;IACtC,MAAM,CAAC,GAAG,CAAC,CAAC,KAAK,CAAC,MAAM,CAAA;IACxB,IAAI,CAAC,GAAG,CAAC;QAAE,MAAM,IAAI,UAAU,CAAC,uCAAuC,CAAC,EAAE,EAAE,IAAI,CAAC,CAAA;IACjF,MAAM,CAAC,GAAG,CAAC,CAAC,KAAK,CAAC,CAAC,GAAG,CAAC,CAAE,CAAA;IACzB,MAAM,CAAC,GAAG,CAAC,CAAC,KAAK,CAAC,CAAC,GAAG,CAAC,CAAE,CAAA;IACzB,MAAM,CAAC,GAAG,CAAC,CAAC,KAAK,CAAC,CAAC,GAAG,CAAC,CAAE,CAAA;IACzB,MAAM,IAAI,GAAG,CAAC,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,EAAE,CAAC,GAAG,CAAC,CAAC,CAAA;IACpC,sEAAsE;IACtE,MAAM,OAAO,GAAG,QAAQ,CAAC,CAAC,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,CAAC,CAAA;IACzC,OAAO,OAAO,CAAC,OAAO,EAAE,CAAC,GAAG,IAAI,EAAE,CAAC,EAAE,CAAC,GAAG,CAAC,CAAC,CAAC,CAAA;AAC9C,CAAC;AAED,+EAA+E;AAC/E,eAAe;AACf,+EAA+E;AAE/E;;;;+EAI+E;AAC/E,MAAM,UAAU,gBAAgB,CAAC,MAAc,EAAE,OAAe;IAC9D,MAAM,IAAI,GAAG,WAAW,CAAC,kBAAkB,CAAC,CAAA;IAC5C,IAAI,OAAO,CAAC,KAAK,KAAK,KAAK,EAAE,CAAC;QAC5B,MAAM,IAAI,UAAU,CAAC,8CAA8C,OAAO,CAAC,KAAK,EAAE,EAAE,IAAI,CAAC,CAAA;IAC3F,CAAC;IACD,MAAM,KAAK,GAAG,MAAM,CAAC,KAAK,CAAC,MAAM,CAAC,KAA
K,CAAC,MAAM,GAAG,CAAC,CAAE,CAAA;IACpD,MAAM,EAAE,GAAG,cAAc,CAAC,MAAM,CAAC,CAAA,CAAmC,WAAW;IAC/E,MAAM,QAAQ,GAAG,OAAO,CAAC,GAAG,CAAC,EAAE,EAAE,MAAM,CAAC,OAAO,EAAE,KAAK,EAAE,KAAK,CAAC,CAAC,CAAC,CAAA,CAAI,QAAQ;IAC5E,OAAO,GAAG,CAAC,QAAQ,EAAE,CAAC,CAAC,CAAC,CAAA;AAC1B,CAAC"}
1
+ {"version":3,"file":"nn.js","sourceRoot":"","sources":["../src/nn.ts"],"names":[],"mappings":"AAAA,8EAA8E;AAC9E,EAAE;AACF,+EAA+E;AAC/E,8EAA8E;AAC9E,6EAA6E;AAC7E,EAAE;AACF,yBAAyB;AACzB,EAAE;AACF,oCAAoC;AACpC,iCAAiC;AACjC,gCAAgC;AAChC,oCAAoC;AACpC,MAAM;AACN,4DAA4D;AAE5D,OAAO,EAAE,MAAM,EAAE,MAAM,aAAa,CAAA;AAEpC,OAAO,EAAE,GAAG,EAAE,MAAM,EAAE,GAAG,EAAE,GAAG,EAAE,GAAG,EAAE,IAAI,EAAE,QAAQ,EAAE,OAAO,EAAE,OAAO,EAAE,QAAQ,EAAE,MAAM,EAAE,cAAc,EAAE,MAAM,UAAU,CAAA;AACzH,OAAO,EAAE,UAAU,EAAE,MAAM,YAAY,CAAA;AACvC,OAAO,EAAE,WAAW,EAAE,MAAM,SAAS,CAAA;AAErC,+EAA+E;AAC/E,0BAA0B;AAC1B,+EAA+E;AAE/E,MAAM,OAAO,MAAO,SAAQ,MAAM;IAGJ;IAA+B;IAF3D,CAAC,CAAQ;IACT,CAAC,CAAe;IAChB,YAA4B,KAAa,EAAkB,MAAc,EAAE,QAAQ,GAAG,IAAI;QACxF,KAAK,EAAE,CAAA;QADmB,UAAK,GAAL,KAAK,CAAQ;QAAkB,WAAM,GAAN,MAAM,CAAQ;QAEvE,IAAI,CAAC,CAAC,GAAG,IAAI,CAAC,KAAK,CAAC,CAAC,KAAK,EAAE,MAAM,CAAC,CAAC,CAAA,CAAsB,oBAAoB;QAC9E,IAAI,CAAC,CAAC,GAAG,QAAQ,CAAC,CAAC,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC,MAAM,CAAC,EAAE,EAAE,IAAI,EAAE,OAAO,EAAE,CAAC,CAAC,CAAC,CAAC,IAAI,CAAA;IACpE,CAAC;CACF;AAED,MAAM,UAAU,SAAS,CAAC,CAAS,EAAE,CAAS;IAC5C,MAAM,GAAG,GAAG,MAAM,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,CAAA;IAC1B,OAAO,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,GAAG,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,GAAG,CAAA;AAClC,CAAC;AAED,+EAA+E;AAC/E,mEAAmE;AACnE,+EAA+E;AAE/E,MAAM,OAAO,SAAU,SAAQ,MAAM;IAGP;IAA2B;IAFvD,CAAC,CAAQ;IACT,CAAC,CAAQ;IACT,YAA4B,CAAS,EAAkB,MAAc,IAAI;QACvE,KAAK,EAAE,CAAA;QADmB,MAAC,GAAD,CAAC,CAAQ;QAAkB,QAAG,GAAH,GAAG,CAAe;QAEvE,IAAI,CAAC,CAAC,GAAG,IAAI,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,EAAE,EAAE,IAAI,EAAE,MAAM,EAAE,CAAC,CAAA;QAC1C,IAAI,CAAC,CAAC,GAAG,IAAI,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,EAAE,EAAE,IAAI,EAAE,OAAO,EAAE,CAAC,CAAA;IAC7C,CAAC;CACF;AAED,MAAM,UAAU,YAAY,CAAC,CAAY,EAAE,CAAS;IAClD,MAAM,CAAC,GAAG,QAAQ,CAAC,CAAC,CAAC,CAAA;IACrB,MAAM,CAAC,GAAG,GAAG,CAAC,CAAC,EAAE,CAAC,CAAC,CAAA;IACnB,MAAM,CAAC,GAAG,QAAQ,CAAC,GAAG,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAA;IAC7B,MAAM,KAAK,GAAG,IAAI,CAAC,GAAG,CAAC,CAAC,EAAE,CAAC,CAAC,GAAG,CAAC,CAAC,CAAA;IACjC,OAAO,GAAG,CAAC,GAAG,
CAAC,GAAG,CAAC,CAAC,EAAE,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,CAAA;AAC1C,CAAC;AAED,+EAA+E;AAC/E,wEAAwE;AACxE,gEAAgE;AAChE,+EAA+E;AAE/E;;4DAE4D;AAC5D,MAAM,UAAU,UAAU,CAAC,CAAS,EAAE,MAAc;IAClD,MAAM,IAAI,GAAG,WAAW,CAAC,YAAY,CAAC,CAAA;IACtC,MAAM,CAAC,GAAG,CAAC,CAAC,KAAK,CAAC,MAAM,CAAA;IACxB,IAAI,CAAC,GAAG,CAAC;QAAE,MAAM,IAAI,UAAU,CAAC,uCAAuC,CAAC,EAAE,EAAE,IAAI,CAAC,CAAA;IACjF,MAAM,CAAC,GAAG,CAAC,CAAC,KAAK,CAAC,CAAC,GAAG,CAAC,CAAE,CAAA;IACzB,MAAM,CAAC,GAAG,CAAC,CAAC,KAAK,CAAC,CAAC,GAAG,CAAC,CAAE,CAAA;IACzB,IAAI,CAAC,GAAG,MAAM,KAAK,CAAC,EAAE,CAAC;QACrB,MAAM,IAAI,UAAU,CAAC,wBAAwB,CAAC,4BAA4B,MAAM,EAAE,EAAE,IAAI,CAAC,CAAA;IAC3F,CAAC;IACD,MAAM,IAAI,GAAG,CAAC,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,EAAE,CAAC,GAAG,CAAC,CAAC,CAAA;IACpC,MAAM,QAAQ,GAAG,OAAO,CAAC,CAAC,EAAE,CAAC,GAAG,IAAI,EAAE,CAAC,EAAE,MAAM,EAAE,CAAC,GAAG,MAAM,CAAC,CAAC,CAAA;IAC7D,2DAA2D;IAC3D,OAAO,QAAQ,CAAC,QAAQ,EAAE,IAAI,CAAC,MAAM,EAAE,IAAI,CAAC,MAAM,GAAG,CAAC,CAAC,CAAA;AACzD,CAAC;AAED,+DAA+D;AAC/D,MAAM,UAAU,UAAU,CAAC,CAAS;IAClC,MAAM,IAAI,GAAG,WAAW,CAAC,YAAY,CAAC,CAAA;IACtC,MAAM,CAAC,GAAG,CAAC,CAAC,KAAK,CAAC,MAAM,CAAA;IACxB,IAAI,CAAC,GAAG,CAAC;QAAE,MAAM,IAAI,UAAU,CAAC,uCAAuC,CAAC,EAAE,EAAE,IAAI,CAAC,CAAA;IACjF,MAAM,CAAC,GAAG,CAAC,CAAC,KAAK,CAAC,CAAC,GAAG,CAAC,CAAE,CAAA;IACzB,MAAM,CAAC,GAAG,CAAC,CAAC,KAAK,CAAC,CAAC,GAAG,CAAC,CAAE,CAAA;IACzB,MAAM,CAAC,GAAG,CAAC,CAAC,KAAK,CAAC,CAAC,GAAG,CAAC,CAAE,CAAA;IACzB,MAAM,IAAI,GAAG,CAAC,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,EAAE,CAAC,GAAG,CAAC,CAAC,CAAA;IACpC,sEAAsE;IACtE,MAAM,OAAO,GAAG,QAAQ,CAAC,CAAC,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,CAAC,CAAA;IACzC,OAAO,OAAO,CAAC,OAAO,EAAE,CAAC,GAAG,IAAI,EAAE,CAAC,EAAE,CAAC,GAAG,CAAC,CAAC,CAAC,CAAA;AAC9C,CAAC;AAED;;;;;;gEAMgE;AAChE,MAAM,UAAU,YAAY,CAAC,IAAkB,EAAE,KAAwB;IACvE,IAAI,KAAK,CAAC,MAAM,GAAG,CAAC,EAAE,CAAC;QACrB,MAAM,IAAI,KAAK,CAAC,6CAA6C,KAAK,CAAC,IAAI,CAAC,IAAI,CAAC,GAAG,CAAC,CAAA;IACnF,CAAC;IACD,2EAA2E;IAC3E,6EAA6E;IAC7E,MAAM,CAAC,GAAG,KAAK,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC
,CAAC,KAAK,CAAA;IACjD,MAAM,CAAC,GAAG,CAAC,CAAC,CAAC,CAAE,CAAA;IACf,IAAI,MAAM,GAAG,CAAC,CAAA;IACd,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,CAAC,MAAM,EAAE,CAAC,EAAE;QAAE,MAAM,IAAI,CAAC,CAAC,CAAC,CAAE,CAAA;IAClD,MAAM,QAAQ,GAAG,CAAC,GAAG,MAAM,CAAA;IAC3B,IAAI,IAAI,CAAC,MAAM,KAAK,QAAQ,EAAE,CAAC;QAC7B,MAAM,IAAI,KAAK,CAAC,6BAA6B,IAAI,CAAC,MAAM,gCAAgC,QAAQ,EAAE,CAAC,CAAA;IACrG,CAAC;IACD,OAAO,KAAK,CAAC,IAAI,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC,GAAG,MAAM,EAAE,CAAC,CAAC,GAAG,CAAC,CAAC,GAAG,MAAM,CAAC,CAAC,CAAA;AACtF,CAAC;AAED,+EAA+E;AAC/E,eAAe;AACf,+EAA+E;AAE/E;;;;+EAI+E;AAC/E,MAAM,UAAU,gBAAgB,CAAC,MAAc,EAAE,OAAe;IAC9D,MAAM,IAAI,GAAG,WAAW,CAAC,kBAAkB,CAAC,CAAA;IAC5C,IAAI,OAAO,CAAC,KAAK,KAAK,KAAK,EAAE,CAAC;QAC5B,MAAM,IAAI,UAAU,CAAC,8CAA8C,OAAO,CAAC,KAAK,EAAE,EAAE,IAAI,CAAC,CAAA;IAC3F,CAAC;IACD,MAAM,KAAK,GAAG,MAAM,CAAC,KAAK,CAAC,MAAM,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,CAAE,CAAA;IACpD,MAAM,EAAE,GAAG,cAAc,CAAC,MAAM,CAAC,CAAA,CAAmC,WAAW;IAC/E,MAAM,QAAQ,GAAG,OAAO,CAAC,GAAG,CAAC,EAAE,EAAE,MAAM,CAAC,OAAO,EAAE,KAAK,EAAE,KAAK,CAAC,CAAC,CAAC,CAAA,CAAI,QAAQ;IAC5E,OAAO,GAAG,CAAC,QAAQ,EAAE,CAAC,CAAC,CAAC,CAAA;AAC1B,CAAC"}
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "tensorgrad",
3
- "version": "0.0.7",
3
+ "version": "0.0.8",
4
4
  "description": "Tiny TypeScript-native tensor library with autograd, compiling to WebGPU. Train small models in the browser without hand-writing kernels.",
5
5
  "license": "MIT",
6
6
  "author": "Ben Albahari",
package/src/nn.ts CHANGED
@@ -97,6 +97,30 @@ export function mergeHeads(x: Tensor): Tensor {
97
97
  return reshape(swapped, [...lead, T, H * d])
98
98
  }
99
99
 
100
/**
 * Slice a flat capture readback into one `Float32Array` per attention head.
 *
 * The leading axis of `shape` is the head axis, i.e. the expected layout is
 * `[H, ...rest]`; pass the shape from `compiled.captureShapes[name]`. Each
 * returned array is an independent copy of one head's row-major data, of
 * length `product(rest)`.
 *
 * Captures from a B=1 graph carry a unit batch axis (`[1, H, T, d]` after
 * `splitHeads`, or `[1, H, T, T]` for attention scores). For shapes of rank
 * 4 or more whose leading axis is 1, that axis is dropped before splitting.
 * Lower-rank shapes are never stripped, so a single-head capture such as
 * `[1, T, d]` yields one array rather than `T` arrays.
 *
 * B>1 is not supported: a `[B, H, ...]` shape with B>1 would be split along
 * the batch axis. Slice out one batch element (and its shape) first.
 *
 * @param flat  Row-major readback whose length equals the product of `shape`.
 * @param shape Capture shape with at least 2 dims.
 * @returns `H` arrays, one per head, in head order.
 * @throws Error if `shape` has fewer than 2 dims, contains a dim that is not
 *   a non-negative integer, or its product differs from `flat.length`.
 */
export function unsplitHeads(flat: Float32Array, shape: readonly number[]): Float32Array[] {
  if (shape.length < 2) {
    throw new Error(`unsplitHeads: shape needs >= 2 dims, got [${shape.join(', ')}]`)
  }
  if (!shape.every((d) => Number.isInteger(d) && d >= 0)) {
    throw new Error(`unsplitHeads: shape dims must be non-negative integers, got [${shape.join(', ')}]`)
  }
  // Drop a unit batch axis only when the remainder still has the
  // [H, ..., ...] rank; stripping [1, T, d] would misread T as the head count.
  const s = shape.length >= 4 && shape[0] === 1 ? shape.slice(1) : shape
  const H = s[0]!
  let stride = 1
  for (let i = 1; i < s.length; i++) stride *= s[i]!
  const expected = H * stride
  if (flat.length !== expected) {
    throw new Error(`unsplitHeads: flat length ${flat.length} doesn't match shape product ${expected}`)
  }
  // slice() copies, so callers may mutate per-head arrays without touching `flat`.
  return Array.from({ length: H }, (_, h) => flat.slice(h * stride, (h + 1) * stride))
}
123
+
100
124
  // ----------------------------------------------------------------------------
101
125
  // Loss helpers
102
126
  // ----------------------------------------------------------------------------