tensorgrad 0.0.7 → 0.0.8
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/nn.d.ts +8 -0
- package/dist/nn.d.ts.map +1 -1
- package/dist/nn.js +24 -0
- package/dist/nn.js.map +1 -1
- package/package.json +1 -1
- package/src/nn.ts +24 -0
package/dist/nn.d.ts
CHANGED
|
@@ -22,6 +22,14 @@ export declare function layerNormFwd(p: LayerNorm, x: Tensor): Tensor;
|
|
|
22
22
|
export declare function splitHeads(
  x: Tensor,
  nHeads: number,
): Tensor;
/**
 * Inverse of `splitHeads`: reshapes `[..., H, T, d]` back to `[..., T, H*d]`.
 */
export declare function mergeHeads(
  x: Tensor,
): Tensor;
|
|
25
|
+
/** Slice a flat capture readback into one Float32Array per head.
 *
 * The leading axis of `shape` is treated as the head axis; pass the shape
 * from `compiled.captureShapes[name]`. Result: `H` arrays, each holding the
 * row-major data for that head (length = product of the trailing axes).
 *
 * A leading axis of size 1 on a capture such as `[1, H, T, d]` (from a B=1
 * graph) is treated as a batch axis and dropped first. As a consequence a
 * single-head capture written as `[1, T, d]` is read as `T` heads — pass
 * `[1, 1, T, d]` instead. Batched captures (B > 1) are NOT split per batch:
 * the batch axis would be read as heads, so slice by batch before calling.
 *
 * @param flat Row-major readback; its length must equal the product of `shape`.
 * @param shape Capture shape with at least 2 dims.
 * @returns One copied (non-aliasing) Float32Array per head.
 * @throws Error if `shape` is malformed or does not match `flat.length`.
 */
export declare function unsplitHeads(flat: Float32Array, shape: readonly number[]): Float32Array[];
|
|
25
33
|
/** Per-position cross-entropy along the last (vocab) axis: returns
|
|
26
34
|
* `-log p(target)` at each position. `logits` is `[..., V]`; `targets` is
|
|
27
35
|
* `[...]` of i32; result is `[...]` (one rank less than logits). The user
|
package/dist/nn.d.ts.map
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
{"version":3,"file":"nn.d.ts","sourceRoot":"","sources":["../src/nn.ts"],"names":[],"mappings":"AAeA,OAAO,EAAE,MAAM,EAAE,MAAM,aAAa,CAAA;AACpC,OAAO,KAAK,EAAE,MAAM,EAAE,MAAM,SAAS,CAAA;AASrC,qBAAa,MAAO,SAAQ,MAAM;aAGJ,KAAK,EAAE,MAAM;aAAkB,MAAM,EAAE,MAAM;IAFzE,CAAC,EAAE,MAAM,CAAA;IACT,CAAC,EAAE,MAAM,GAAG,IAAI,CAAA;gBACY,KAAK,EAAE,MAAM,EAAkB,MAAM,EAAE,MAAM,EAAE,QAAQ,UAAO;CAK3F;AAED,wBAAgB,SAAS,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,CAGtD;AAMD,qBAAa,SAAU,SAAQ,MAAM;aAGP,CAAC,EAAE,MAAM;aAAkB,GAAG,EAAE,MAAM;IAFlE,CAAC,EAAE,MAAM,CAAA;IACT,CAAC,EAAE,MAAM,CAAA;gBACmB,CAAC,EAAE,MAAM,EAAkB,GAAG,GAAE,MAAa;CAK1E;AAED,wBAAgB,YAAY,CAAC,CAAC,EAAE,SAAS,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,CAM5D;AAOD;;4DAE4D;AAC5D,wBAAgB,UAAU,CAAC,CAAC,EAAE,MAAM,EAAE,MAAM,EAAE,MAAM,GAAG,MAAM,CAa5D;AAED,+DAA+D;AAC/D,wBAAgB,UAAU,CAAC,CAAC,EAAE,MAAM,GAAG,MAAM,CAW5C;AAMD;;;;+EAI+E;AAC/E,wBAAgB,gBAAgB,CAAC,MAAM,EAAE,MAAM,EAAE,OAAO,EAAE,MAAM,GAAG,MAAM,CASxE"}
|
|
1
|
+
{"version":3,"file":"nn.d.ts","sourceRoot":"","sources":["../src/nn.ts"],"names":[],"mappings":"AAeA,OAAO,EAAE,MAAM,EAAE,MAAM,aAAa,CAAA;AACpC,OAAO,KAAK,EAAE,MAAM,EAAE,MAAM,SAAS,CAAA;AASrC,qBAAa,MAAO,SAAQ,MAAM;aAGJ,KAAK,EAAE,MAAM;aAAkB,MAAM,EAAE,MAAM;IAFzE,CAAC,EAAE,MAAM,CAAA;IACT,CAAC,EAAE,MAAM,GAAG,IAAI,CAAA;gBACY,KAAK,EAAE,MAAM,EAAkB,MAAM,EAAE,MAAM,EAAE,QAAQ,UAAO;CAK3F;AAED,wBAAgB,SAAS,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,CAGtD;AAMD,qBAAa,SAAU,SAAQ,MAAM;aAGP,CAAC,EAAE,MAAM;aAAkB,GAAG,EAAE,MAAM;IAFlE,CAAC,EAAE,MAAM,CAAA;IACT,CAAC,EAAE,MAAM,CAAA;gBACmB,CAAC,EAAE,MAAM,EAAkB,GAAG,GAAE,MAAa;CAK1E;AAED,wBAAgB,YAAY,CAAC,CAAC,EAAE,SAAS,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,CAM5D;AAOD;;4DAE4D;AAC5D,wBAAgB,UAAU,CAAC,CAAC,EAAE,MAAM,EAAE,MAAM,EAAE,MAAM,GAAG,MAAM,CAa5D;AAED,+DAA+D;AAC/D,wBAAgB,UAAU,CAAC,CAAC,EAAE,MAAM,GAAG,MAAM,CAW5C;AAED;;;;;;gEAMgE;AAChE,wBAAgB,YAAY,CAAC,IAAI,EAAE,YAAY,EAAE,KAAK,EAAE,SAAS,MAAM,EAAE,GAAG,YAAY,EAAE,CAezF;AAMD;;;;+EAI+E;AAC/E,wBAAgB,gBAAgB,CAAC,MAAM,EAAE,MAAM,EAAE,OAAO,EAAE,MAAM,GAAG,MAAM,CASxE"}
|
package/dist/nn.js
CHANGED
|
@@ -95,6 +95,30 @@ export function mergeHeads(x) {
|
|
|
95
95
|
const swapped = swapAxes(x, r - 3, r - 2);
|
|
96
96
|
return reshape(swapped, [...lead, T, H * d]);
|
|
97
97
|
}
|
|
98
|
+
/** Slice a flat capture readback into one Float32Array per head.
 *
 * The leading axis of `shape` is treated as the head axis; pass the shape
 * from `compiled.captureShapes[name]`. Result: `H` arrays, each holding the
 * row-major data for that head (length = product of the trailing axes).
 *
 * On a shape of rank >= 3, a leading axis of size 1 (e.g. `[1, H, T, d]` from
 * a B=1 graph) is treated as a batch axis and dropped first. A single-head
 * capture written as `[1, T, d]` is therefore read as `T` heads — pass
 * `[1, 1, T, d]` instead. Batched captures (B > 1) are NOT split per batch:
 * the batch axis would be read as heads, so slice by batch before calling.
 *
 * @param flat Row-major readback; its length must equal the product of `shape`.
 * @param shape Capture shape, rank >= 2, every dim a non-negative integer.
 * @returns One copied (non-aliasing) Float32Array per head.
 * @throws Error if the rank is < 2, a dim is not a non-negative integer, or
 *   the shape product does not match `flat.length`.
 */
export function unsplitHeads(flat, shape) {
    if (shape.length < 2) {
        throw new Error(`unsplitHeads: shape needs >= 2 dims, got [${shape.join(', ')}]`);
    }
    // Reject NaN / fractional / negative dims up front: otherwise Array.from
    // truncates a fractional head count (silently dropping data) or throws an
    // opaque RangeError for a negative one.
    for (const dim of shape) {
        if (!Number.isSafeInteger(dim) || dim < 0) {
            throw new Error(`unsplitHeads: dims must be non-negative integers, got [${shape.join(', ')}]`);
        }
    }
    // B=1 inference captures arrive as [1, H, ...]. Drop that leading batch
    // axis only when a head axis plus at least one trailing axis remains;
    // otherwise [1, N] would collapse to [N] and be read as N heads of size 1.
    const s = shape.length > 2 && shape[0] === 1 ? shape.slice(1) : shape;
    const H = s[0];
    let stride = 1;
    for (let i = 1; i < s.length; i++)
        stride *= s[i];
    const expected = H * stride;
    if (flat.length !== expected) {
        throw new Error(`unsplitHeads: flat length ${flat.length} doesn't match shape product ${expected}`);
    }
    // slice() copies, so each head's array can be mutated independently.
    return Array.from({ length: H }, (_, h) => flat.slice(h * stride, (h + 1) * stride));
}
|
|
98
122
|
// ----------------------------------------------------------------------------
|
|
99
123
|
// Loss helpers
|
|
100
124
|
// ----------------------------------------------------------------------------
|
package/dist/nn.js.map
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
{"version":3,"file":"nn.js","sourceRoot":"","sources":["../src/nn.ts"],"names":[],"mappings":"AAAA,8EAA8E;AAC9E,EAAE;AACF,+EAA+E;AAC/E,8EAA8E;AAC9E,6EAA6E;AAC7E,EAAE;AACF,yBAAyB;AACzB,EAAE;AACF,oCAAoC;AACpC,iCAAiC;AACjC,gCAAgC;AAChC,oCAAoC;AACpC,MAAM;AACN,4DAA4D;AAE5D,OAAO,EAAE,MAAM,EAAE,MAAM,aAAa,CAAA;AAEpC,OAAO,EAAE,GAAG,EAAE,MAAM,EAAE,GAAG,EAAE,GAAG,EAAE,GAAG,EAAE,IAAI,EAAE,QAAQ,EAAE,OAAO,EAAE,OAAO,EAAE,QAAQ,EAAE,MAAM,EAAE,cAAc,EAAE,MAAM,UAAU,CAAA;AACzH,OAAO,EAAE,UAAU,EAAE,MAAM,YAAY,CAAA;AACvC,OAAO,EAAE,WAAW,EAAE,MAAM,SAAS,CAAA;AAErC,+EAA+E;AAC/E,0BAA0B;AAC1B,+EAA+E;AAE/E,MAAM,OAAO,MAAO,SAAQ,MAAM;IAGJ;IAA+B;IAF3D,CAAC,CAAQ;IACT,CAAC,CAAe;IAChB,YAA4B,KAAa,EAAkB,MAAc,EAAE,QAAQ,GAAG,IAAI;QACxF,KAAK,EAAE,CAAA;QADmB,UAAK,GAAL,KAAK,CAAQ;QAAkB,WAAM,GAAN,MAAM,CAAQ;QAEvE,IAAI,CAAC,CAAC,GAAG,IAAI,CAAC,KAAK,CAAC,CAAC,KAAK,EAAE,MAAM,CAAC,CAAC,CAAA,CAAsB,oBAAoB;QAC9E,IAAI,CAAC,CAAC,GAAG,QAAQ,CAAC,CAAC,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC,MAAM,CAAC,EAAE,EAAE,IAAI,EAAE,OAAO,EAAE,CAAC,CAAC,CAAC,CAAC,IAAI,CAAA;IACpE,CAAC;CACF;AAED,MAAM,UAAU,SAAS,CAAC,CAAS,EAAE,CAAS;IAC5C,MAAM,GAAG,GAAG,MAAM,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,CAAA;IAC1B,OAAO,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,GAAG,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,GAAG,CAAA;AAClC,CAAC;AAED,+EAA+E;AAC/E,mEAAmE;AACnE,+EAA+E;AAE/E,MAAM,OAAO,SAAU,SAAQ,MAAM;IAGP;IAA2B;IAFvD,CAAC,CAAQ;IACT,CAAC,CAAQ;IACT,YAA4B,CAAS,EAAkB,MAAc,IAAI;QACvE,KAAK,EAAE,CAAA;QADmB,MAAC,GAAD,CAAC,CAAQ;QAAkB,QAAG,GAAH,GAAG,CAAe;QAEvE,IAAI,CAAC,CAAC,GAAG,IAAI,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,EAAE,EAAE,IAAI,EAAE,MAAM,EAAE,CAAC,CAAA;QAC1C,IAAI,CAAC,CAAC,GAAG,IAAI,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,EAAE,EAAE,IAAI,EAAE,OAAO,EAAE,CAAC,CAAA;IAC7C,CAAC;CACF;AAED,MAAM,UAAU,YAAY,CAAC,CAAY,EAAE,CAAS;IAClD,MAAM,CAAC,GAAG,QAAQ,CAAC,CAAC,CAAC,CAAA;IACrB,MAAM,CAAC,GAAG,GAAG,CAAC,CAAC,EAAE,CAAC,CAAC,CAAA;IACnB,MAAM,CAAC,GAAG,QAAQ,CAAC,GAAG,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAA;IAC7B,MAAM,KAAK,GAAG,IAAI,CAAC,GAAG,CAAC,CAAC,EAAE,CAAC,CAAC,GAAG,CAAC,CAAC,CAAA;IACjC,OAAO,GAAG,CAAC,GAAG,CA
AC,GAAG,CAAC,CAAC,EAAE,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,CAAA;AAC1C,CAAC;AAED,+EAA+E;AAC/E,wEAAwE;AACxE,gEAAgE;AAChE,+EAA+E;AAE/E;;4DAE4D;AAC5D,MAAM,UAAU,UAAU,CAAC,CAAS,EAAE,MAAc;IAClD,MAAM,IAAI,GAAG,WAAW,CAAC,YAAY,CAAC,CAAA;IACtC,MAAM,CAAC,GAAG,CAAC,CAAC,KAAK,CAAC,MAAM,CAAA;IACxB,IAAI,CAAC,GAAG,CAAC;QAAE,MAAM,IAAI,UAAU,CAAC,uCAAuC,CAAC,EAAE,EAAE,IAAI,CAAC,CAAA;IACjF,MAAM,CAAC,GAAG,CAAC,CAAC,KAAK,CAAC,CAAC,GAAG,CAAC,CAAE,CAAA;IACzB,MAAM,CAAC,GAAG,CAAC,CAAC,KAAK,CAAC,CAAC,GAAG,CAAC,CAAE,CAAA;IACzB,IAAI,CAAC,GAAG,MAAM,KAAK,CAAC,EAAE,CAAC;QACrB,MAAM,IAAI,UAAU,CAAC,wBAAwB,CAAC,4BAA4B,MAAM,EAAE,EAAE,IAAI,CAAC,CAAA;IAC3F,CAAC;IACD,MAAM,IAAI,GAAG,CAAC,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,EAAE,CAAC,GAAG,CAAC,CAAC,CAAA;IACpC,MAAM,QAAQ,GAAG,OAAO,CAAC,CAAC,EAAE,CAAC,GAAG,IAAI,EAAE,CAAC,EAAE,MAAM,EAAE,CAAC,GAAG,MAAM,CAAC,CAAC,CAAA;IAC7D,2DAA2D;IAC3D,OAAO,QAAQ,CAAC,QAAQ,EAAE,IAAI,CAAC,MAAM,EAAE,IAAI,CAAC,MAAM,GAAG,CAAC,CAAC,CAAA;AACzD,CAAC;AAED,+DAA+D;AAC/D,MAAM,UAAU,UAAU,CAAC,CAAS;IAClC,MAAM,IAAI,GAAG,WAAW,CAAC,YAAY,CAAC,CAAA;IACtC,MAAM,CAAC,GAAG,CAAC,CAAC,KAAK,CAAC,MAAM,CAAA;IACxB,IAAI,CAAC,GAAG,CAAC;QAAE,MAAM,IAAI,UAAU,CAAC,uCAAuC,CAAC,EAAE,EAAE,IAAI,CAAC,CAAA;IACjF,MAAM,CAAC,GAAG,CAAC,CAAC,KAAK,CAAC,CAAC,GAAG,CAAC,CAAE,CAAA;IACzB,MAAM,CAAC,GAAG,CAAC,CAAC,KAAK,CAAC,CAAC,GAAG,CAAC,CAAE,CAAA;IACzB,MAAM,CAAC,GAAG,CAAC,CAAC,KAAK,CAAC,CAAC,GAAG,CAAC,CAAE,CAAA;IACzB,MAAM,IAAI,GAAG,CAAC,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,EAAE,CAAC,GAAG,CAAC,CAAC,CAAA;IACpC,sEAAsE;IACtE,MAAM,OAAO,GAAG,QAAQ,CAAC,CAAC,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,CAAC,CAAA;IACzC,OAAO,OAAO,CAAC,OAAO,EAAE,CAAC,GAAG,IAAI,EAAE,CAAC,EAAE,CAAC,GAAG,CAAC,CAAC,CAAC,CAAA;AAC9C,CAAC;AAED,+EAA+E;AAC/E,eAAe;AACf,+EAA+E;AAE/E;;;;+EAI+E;AAC/E,MAAM,UAAU,gBAAgB,CAAC,MAAc,EAAE,OAAe;IAC9D,MAAM,IAAI,GAAG,WAAW,CAAC,kBAAkB,CAAC,CAAA;IAC5C,IAAI,OAAO,CAAC,KAAK,KAAK,KAAK,EAAE,CAAC;QAC5B,MAAM,IAAI,UAAU,CAAC,8CAA8C,OAAO,CAAC,KAAK,EAAE,EAAE,IAAI,CAAC,CAAA;IAC3F,CAAC;IACD,MAAM,KAAK,GAAG,MAAM,CAAC,KAAK,CAAC,MAAM,CAAC,KAAK,
CAAC,MAAM,GAAG,CAAC,CAAE,CAAA;IACpD,MAAM,EAAE,GAAG,cAAc,CAAC,MAAM,CAAC,CAAA,CAAmC,WAAW;IAC/E,MAAM,QAAQ,GAAG,OAAO,CAAC,GAAG,CAAC,EAAE,EAAE,MAAM,CAAC,OAAO,EAAE,KAAK,EAAE,KAAK,CAAC,CAAC,CAAC,CAAA,CAAI,QAAQ;IAC5E,OAAO,GAAG,CAAC,QAAQ,EAAE,CAAC,CAAC,CAAC,CAAA;AAC1B,CAAC"}
|
|
1
|
+
{"version":3,"file":"nn.js","sourceRoot":"","sources":["../src/nn.ts"],"names":[],"mappings":"AAAA,8EAA8E;AAC9E,EAAE;AACF,+EAA+E;AAC/E,8EAA8E;AAC9E,6EAA6E;AAC7E,EAAE;AACF,yBAAyB;AACzB,EAAE;AACF,oCAAoC;AACpC,iCAAiC;AACjC,gCAAgC;AAChC,oCAAoC;AACpC,MAAM;AACN,4DAA4D;AAE5D,OAAO,EAAE,MAAM,EAAE,MAAM,aAAa,CAAA;AAEpC,OAAO,EAAE,GAAG,EAAE,MAAM,EAAE,GAAG,EAAE,GAAG,EAAE,GAAG,EAAE,IAAI,EAAE,QAAQ,EAAE,OAAO,EAAE,OAAO,EAAE,QAAQ,EAAE,MAAM,EAAE,cAAc,EAAE,MAAM,UAAU,CAAA;AACzH,OAAO,EAAE,UAAU,EAAE,MAAM,YAAY,CAAA;AACvC,OAAO,EAAE,WAAW,EAAE,MAAM,SAAS,CAAA;AAErC,+EAA+E;AAC/E,0BAA0B;AAC1B,+EAA+E;AAE/E,MAAM,OAAO,MAAO,SAAQ,MAAM;IAGJ;IAA+B;IAF3D,CAAC,CAAQ;IACT,CAAC,CAAe;IAChB,YAA4B,KAAa,EAAkB,MAAc,EAAE,QAAQ,GAAG,IAAI;QACxF,KAAK,EAAE,CAAA;QADmB,UAAK,GAAL,KAAK,CAAQ;QAAkB,WAAM,GAAN,MAAM,CAAQ;QAEvE,IAAI,CAAC,CAAC,GAAG,IAAI,CAAC,KAAK,CAAC,CAAC,KAAK,EAAE,MAAM,CAAC,CAAC,CAAA,CAAsB,oBAAoB;QAC9E,IAAI,CAAC,CAAC,GAAG,QAAQ,CAAC,CAAC,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC,MAAM,CAAC,EAAE,EAAE,IAAI,EAAE,OAAO,EAAE,CAAC,CAAC,CAAC,CAAC,IAAI,CAAA;IACpE,CAAC;CACF;AAED,MAAM,UAAU,SAAS,CAAC,CAAS,EAAE,CAAS;IAC5C,MAAM,GAAG,GAAG,MAAM,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,CAAA;IAC1B,OAAO,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,GAAG,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,GAAG,CAAA;AAClC,CAAC;AAED,+EAA+E;AAC/E,mEAAmE;AACnE,+EAA+E;AAE/E,MAAM,OAAO,SAAU,SAAQ,MAAM;IAGP;IAA2B;IAFvD,CAAC,CAAQ;IACT,CAAC,CAAQ;IACT,YAA4B,CAAS,EAAkB,MAAc,IAAI;QACvE,KAAK,EAAE,CAAA;QADmB,MAAC,GAAD,CAAC,CAAQ;QAAkB,QAAG,GAAH,GAAG,CAAe;QAEvE,IAAI,CAAC,CAAC,GAAG,IAAI,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,EAAE,EAAE,IAAI,EAAE,MAAM,EAAE,CAAC,CAAA;QAC1C,IAAI,CAAC,CAAC,GAAG,IAAI,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,EAAE,EAAE,IAAI,EAAE,OAAO,EAAE,CAAC,CAAA;IAC7C,CAAC;CACF;AAED,MAAM,UAAU,YAAY,CAAC,CAAY,EAAE,CAAS;IAClD,MAAM,CAAC,GAAG,QAAQ,CAAC,CAAC,CAAC,CAAA;IACrB,MAAM,CAAC,GAAG,GAAG,CAAC,CAAC,EAAE,CAAC,CAAC,CAAA;IACnB,MAAM,CAAC,GAAG,QAAQ,CAAC,GAAG,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAA;IAC7B,MAAM,KAAK,GAAG,IAAI,CAAC,GAAG,CAAC,CAAC,EAAE,CAAC,CAAC,GAAG,CAAC,CAAC,CAAA;IACjC,OAAO,GAAG,CAAC,GAAG,CA
AC,GAAG,CAAC,CAAC,EAAE,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,CAAA;AAC1C,CAAC;AAED,+EAA+E;AAC/E,wEAAwE;AACxE,gEAAgE;AAChE,+EAA+E;AAE/E;;4DAE4D;AAC5D,MAAM,UAAU,UAAU,CAAC,CAAS,EAAE,MAAc;IAClD,MAAM,IAAI,GAAG,WAAW,CAAC,YAAY,CAAC,CAAA;IACtC,MAAM,CAAC,GAAG,CAAC,CAAC,KAAK,CAAC,MAAM,CAAA;IACxB,IAAI,CAAC,GAAG,CAAC;QAAE,MAAM,IAAI,UAAU,CAAC,uCAAuC,CAAC,EAAE,EAAE,IAAI,CAAC,CAAA;IACjF,MAAM,CAAC,GAAG,CAAC,CAAC,KAAK,CAAC,CAAC,GAAG,CAAC,CAAE,CAAA;IACzB,MAAM,CAAC,GAAG,CAAC,CAAC,KAAK,CAAC,CAAC,GAAG,CAAC,CAAE,CAAA;IACzB,IAAI,CAAC,GAAG,MAAM,KAAK,CAAC,EAAE,CAAC;QACrB,MAAM,IAAI,UAAU,CAAC,wBAAwB,CAAC,4BAA4B,MAAM,EAAE,EAAE,IAAI,CAAC,CAAA;IAC3F,CAAC;IACD,MAAM,IAAI,GAAG,CAAC,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,EAAE,CAAC,GAAG,CAAC,CAAC,CAAA;IACpC,MAAM,QAAQ,GAAG,OAAO,CAAC,CAAC,EAAE,CAAC,GAAG,IAAI,EAAE,CAAC,EAAE,MAAM,EAAE,CAAC,GAAG,MAAM,CAAC,CAAC,CAAA;IAC7D,2DAA2D;IAC3D,OAAO,QAAQ,CAAC,QAAQ,EAAE,IAAI,CAAC,MAAM,EAAE,IAAI,CAAC,MAAM,GAAG,CAAC,CAAC,CAAA;AACzD,CAAC;AAED,+DAA+D;AAC/D,MAAM,UAAU,UAAU,CAAC,CAAS;IAClC,MAAM,IAAI,GAAG,WAAW,CAAC,YAAY,CAAC,CAAA;IACtC,MAAM,CAAC,GAAG,CAAC,CAAC,KAAK,CAAC,MAAM,CAAA;IACxB,IAAI,CAAC,GAAG,CAAC;QAAE,MAAM,IAAI,UAAU,CAAC,uCAAuC,CAAC,EAAE,EAAE,IAAI,CAAC,CAAA;IACjF,MAAM,CAAC,GAAG,CAAC,CAAC,KAAK,CAAC,CAAC,GAAG,CAAC,CAAE,CAAA;IACzB,MAAM,CAAC,GAAG,CAAC,CAAC,KAAK,CAAC,CAAC,GAAG,CAAC,CAAE,CAAA;IACzB,MAAM,CAAC,GAAG,CAAC,CAAC,KAAK,CAAC,CAAC,GAAG,CAAC,CAAE,CAAA;IACzB,MAAM,IAAI,GAAG,CAAC,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,EAAE,CAAC,GAAG,CAAC,CAAC,CAAA;IACpC,sEAAsE;IACtE,MAAM,OAAO,GAAG,QAAQ,CAAC,CAAC,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,CAAC,CAAA;IACzC,OAAO,OAAO,CAAC,OAAO,EAAE,CAAC,GAAG,IAAI,EAAE,CAAC,EAAE,CAAC,GAAG,CAAC,CAAC,CAAC,CAAA;AAC9C,CAAC;AAED;;;;;;gEAMgE;AAChE,MAAM,UAAU,YAAY,CAAC,IAAkB,EAAE,KAAwB;IACvE,IAAI,KAAK,CAAC,MAAM,GAAG,CAAC,EAAE,CAAC;QACrB,MAAM,IAAI,KAAK,CAAC,6CAA6C,KAAK,CAAC,IAAI,CAAC,IAAI,CAAC,GAAG,CAAC,CAAA;IACnF,CAAC;IACD,2EAA2E;IAC3E,6EAA6E;IAC7E,MAAM,CAAC,GAAG,KAAK,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,C
AAC,KAAK,CAAA;IACjD,MAAM,CAAC,GAAG,CAAC,CAAC,CAAC,CAAE,CAAA;IACf,IAAI,MAAM,GAAG,CAAC,CAAA;IACd,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,CAAC,MAAM,EAAE,CAAC,EAAE;QAAE,MAAM,IAAI,CAAC,CAAC,CAAC,CAAE,CAAA;IAClD,MAAM,QAAQ,GAAG,CAAC,GAAG,MAAM,CAAA;IAC3B,IAAI,IAAI,CAAC,MAAM,KAAK,QAAQ,EAAE,CAAC;QAC7B,MAAM,IAAI,KAAK,CAAC,6BAA6B,IAAI,CAAC,MAAM,gCAAgC,QAAQ,EAAE,CAAC,CAAA;IACrG,CAAC;IACD,OAAO,KAAK,CAAC,IAAI,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC,GAAG,MAAM,EAAE,CAAC,CAAC,GAAG,CAAC,CAAC,GAAG,MAAM,CAAC,CAAC,CAAA;AACtF,CAAC;AAED,+EAA+E;AAC/E,eAAe;AACf,+EAA+E;AAE/E;;;;+EAI+E;AAC/E,MAAM,UAAU,gBAAgB,CAAC,MAAc,EAAE,OAAe;IAC9D,MAAM,IAAI,GAAG,WAAW,CAAC,kBAAkB,CAAC,CAAA;IAC5C,IAAI,OAAO,CAAC,KAAK,KAAK,KAAK,EAAE,CAAC;QAC5B,MAAM,IAAI,UAAU,CAAC,8CAA8C,OAAO,CAAC,KAAK,EAAE,EAAE,IAAI,CAAC,CAAA;IAC3F,CAAC;IACD,MAAM,KAAK,GAAG,MAAM,CAAC,KAAK,CAAC,MAAM,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,CAAE,CAAA;IACpD,MAAM,EAAE,GAAG,cAAc,CAAC,MAAM,CAAC,CAAA,CAAmC,WAAW;IAC/E,MAAM,QAAQ,GAAG,OAAO,CAAC,GAAG,CAAC,EAAE,EAAE,MAAM,CAAC,OAAO,EAAE,KAAK,EAAE,KAAK,CAAC,CAAC,CAAC,CAAA,CAAI,QAAQ;IAC5E,OAAO,GAAG,CAAC,QAAQ,EAAE,CAAC,CAAC,CAAC,CAAA;AAC1B,CAAC"}
|
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "tensorgrad",
|
|
3
|
-
"version": "0.0.7",
|
|
3
|
+
"version": "0.0.8",
|
|
4
4
|
"description": "Tiny TypeScript-native tensor library with autograd, compiling to WebGPU. Train small models in the browser without hand-writing kernels.",
|
|
5
5
|
"license": "MIT",
|
|
6
6
|
"author": "Ben Albahari",
|
package/src/nn.ts
CHANGED
|
@@ -97,6 +97,30 @@ export function mergeHeads(x: Tensor): Tensor {
|
|
|
97
97
|
return reshape(swapped, [...lead, T, H * d])
|
|
98
98
|
}
|
|
99
99
|
|
|
100
|
+
/** Slice a flat capture readback into one Float32Array per head.
 *
 * The leading axis of `shape` is treated as the head axis; pass the shape
 * from `compiled.captureShapes[name]`. Result: `H` arrays, each holding the
 * row-major data for that head (length = product of the trailing axes).
 *
 * On a shape of rank >= 3, a leading axis of size 1 (e.g. `[1, H, T, d]` from
 * a B=1 graph) is treated as a batch axis and dropped first. A single-head
 * capture written as `[1, T, d]` is therefore read as `T` heads — pass
 * `[1, 1, T, d]` instead. Batched captures (B > 1) are NOT split per batch:
 * the batch axis would be read as heads, so slice by batch before calling.
 *
 * @param flat Row-major readback; its length must equal the product of `shape`.
 * @param shape Capture shape, rank >= 2, every dim a non-negative integer.
 * @returns One copied (non-aliasing) Float32Array per head.
 * @throws Error if the rank is < 2, a dim is not a non-negative integer, or
 *   the shape product does not match `flat.length`.
 */
export function unsplitHeads(flat: Float32Array, shape: readonly number[]): Float32Array[] {
  if (shape.length < 2) {
    throw new Error(`unsplitHeads: shape needs >= 2 dims, got [${shape.join(', ')}]`)
  }
  // Reject NaN / fractional / negative dims up front: otherwise Array.from
  // truncates a fractional head count (silently dropping data) or throws an
  // opaque RangeError for a negative one.
  for (const dim of shape) {
    if (!Number.isSafeInteger(dim) || dim < 0) {
      throw new Error(`unsplitHeads: dims must be non-negative integers, got [${shape.join(', ')}]`)
    }
  }
  // B=1 inference captures arrive as [1, H, ...]. Drop that leading batch
  // axis only when a head axis plus at least one trailing axis remains;
  // otherwise [1, N] would collapse to [N] and be read as N heads of size 1.
  const s = shape.length > 2 && shape[0] === 1 ? shape.slice(1) : shape
  const H = s[0]!
  let stride = 1
  for (let i = 1; i < s.length; i++) stride *= s[i]!
  const expected = H * stride
  if (flat.length !== expected) {
    throw new Error(`unsplitHeads: flat length ${flat.length} doesn't match shape product ${expected}`)
  }
  // slice() copies, so each head's array can be mutated independently.
  return Array.from({ length: H }, (_, h) => flat.slice(h * stride, (h + 1) * stride))
}
|
|
123
|
+
|
|
100
124
|
// ----------------------------------------------------------------------------
|
|
101
125
|
// Loss helpers
|
|
102
126
|
// ----------------------------------------------------------------------------
|