@isidorus/cpu 0.0.0-alpha.1 → 0.0.0-alpha.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/graph.d.ts +25 -1
- package/dist/graph.d.ts.map +1 -1
- package/dist/graph.js +30 -2
- package/dist/graph.js.map +1 -1
- package/dist/index.d.ts +3 -0
- package/dist/index.d.ts.map +1 -1
- package/dist/index.js +3 -0
- package/dist/index.js.map +1 -1
- package/dist/model/index.d.ts +5 -0
- package/dist/model/index.d.ts.map +1 -0
- package/dist/model/index.js +3 -0
- package/dist/model/index.js.map +1 -0
- package/dist/model/layer.d.ts +25 -0
- package/dist/model/layer.d.ts.map +1 -0
- package/dist/model/layer.js +2 -0
- package/dist/model/layer.js.map +1 -0
- package/dist/model/layers.d.ts +47 -0
- package/dist/model/layers.d.ts.map +1 -0
- package/dist/model/layers.js +191 -0
- package/dist/model/layers.js.map +1 -0
- package/dist/model/sequential.d.ts +91 -0
- package/dist/model/sequential.d.ts.map +1 -0
- package/dist/model/sequential.js +248 -0
- package/dist/model/sequential.js.map +1 -0
- package/dist/ops/math_ops.js +1 -1
- package/dist/ops/math_ops.js.map +1 -1
- package/dist/ops/nn_ops.js +9 -9
- package/dist/ops/nn_ops.js.map +1 -1
- package/dist/ops/variable_ops.d.ts.map +1 -1
- package/dist/ops/variable_ops.js +7 -9
- package/dist/ops/variable_ops.js.map +1 -1
- package/dist/optimizers/adam.d.ts +26 -0
- package/dist/optimizers/adam.d.ts.map +1 -0
- package/dist/optimizers/adam.js +97 -0
- package/dist/optimizers/adam.js.map +1 -0
- package/dist/optimizers/index.d.ts +5 -0
- package/dist/optimizers/index.d.ts.map +1 -0
- package/dist/optimizers/index.js +4 -0
- package/dist/optimizers/index.js.map +1 -0
- package/dist/optimizers/rmsprop.d.ts +22 -0
- package/dist/optimizers/rmsprop.d.ts.map +1 -0
- package/dist/optimizers/rmsprop.js +65 -0
- package/dist/optimizers/rmsprop.js.map +1 -0
- package/dist/optimizers/sgd.d.ts +53 -0
- package/dist/optimizers/sgd.d.ts.map +1 -0
- package/dist/optimizers/sgd.js +76 -0
- package/dist/optimizers/sgd.js.map +1 -0
- package/dist/tsconfig.tsbuildinfo +1 -1
- package/package.json +1 -1
- package/src/native/graph.cc +136 -1
- package/src/native/graph.h +1 -11
package/dist/graph.d.ts
CHANGED
|
@@ -21,6 +21,9 @@ export type AttrValue = {
|
|
|
21
21
|
} | {
|
|
22
22
|
kind: "list_int";
|
|
23
23
|
value: number[];
|
|
24
|
+
} | {
|
|
25
|
+
kind: "string";
|
|
26
|
+
value: string;
|
|
24
27
|
} | {
|
|
25
28
|
kind: "tensor";
|
|
26
29
|
value: InlineTensor;
|
|
@@ -58,9 +61,30 @@ export declare class Graph {
|
|
|
58
61
|
* @param inputs Output references from prior ops
|
|
59
62
|
* @param attrs Op attributes
|
|
60
63
|
* @param name Optional explicit op name (auto-generated if omitted)
|
|
64
|
+
* @param controlInputs Op names that must complete before this op runs.
|
|
65
|
+
* Used by globalVariablesInitializer to sequence init
|
|
66
|
+
* ops before the NoOp target that callers wait on.
|
|
61
67
|
* @returns Array of output Tensors (one per op output)
|
|
62
68
|
*/
|
|
63
|
-
addOp(type: string, inputs: TFOutput[], attrs?: Record<string, AttrValue>, name?: string): Tensor[];
|
|
69
|
+
addOp(type: string, inputs: TFOutput[], attrs?: Record<string, AttrValue>, name?: string, controlInputs?: string[]): Tensor[];
|
|
70
|
+
/**
|
|
71
|
+
* addGradients — compute symbolic gradients via TF_AddGradients.
|
|
72
|
+
*
|
|
73
|
+
* Injects gradient ops into the graph. Returns one gradient Tensor per
|
|
74
|
+
* entry in `x` — the partial derivative dSum(y)/dx_i.
|
|
75
|
+
*
|
|
76
|
+
* @param y Loss outputs to differentiate (typically a scalar loss tensor)
|
|
77
|
+
* @param x Parameters to differentiate with respect to
|
|
78
|
+
* @param dx Initial upstream gradients (default: ones, i.e. dL/dy = 1)
|
|
79
|
+
*
|
|
80
|
+
* @example
|
|
81
|
+
* const loss = ops.sparseSoftmaxCrossEntropyWithLogits(g, labels, logits);
|
|
82
|
+
* const wVal = ops.readVariable(g, wHandle, DType.FLOAT32);
|
|
83
|
+
* const [dw] = g.addGradients([loss], [wVal]);
|
|
84
|
+
* // dw is now a Tensor representing dLoss/dW
|
|
85
|
+
* // pass it to applyGradientDescent or applyAdam
|
|
86
|
+
*/
|
|
87
|
+
addGradients(y: TFOutput[], x: TFOutput[], dx?: TFOutput[]): Tensor[];
|
|
64
88
|
/** Whether an op with the given name exists in this graph. */
|
|
65
89
|
hasOp(name: string): boolean;
|
|
66
90
|
/** Total number of ops in the graph. */
|
package/dist/graph.d.ts.map
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
{"version":3,"file":"graph.d.ts","sourceRoot":"","sources":["../src/ts/graph.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,EAAE,MAAM,gBAAgB,CAAC;AAC7C,OAAO,EAAE,KAAK,EAA2B,MAAM,gBAAgB,CAAC;AAGhE,MAAM,MAAM,SAAS,GACjB;IAAE,IAAI,EAAE,KAAK,CAAC;IAAC,KAAK,EAAE,MAAM,CAAA;CAAE,GAC9B;IAAE,IAAI,EAAE,OAAO,CAAC;IAAC,KAAK,EAAE,MAAM,CAAA;CAAE,GAChC;IAAE,IAAI,EAAE,MAAM,CAAC;IAAC,KAAK,EAAE,OAAO,CAAA;CAAE,GAChC;IAAE,IAAI,EAAE,MAAM,CAAC;IAAC,KAAK,EAAE,KAAK,CAAA;CAAE,GAC9B;IAAE,IAAI,EAAE,OAAO,CAAC;IAAC,KAAK,EAAE,MAAM,EAAE,CAAA;CAAE,GAClC;IAAE,IAAI,EAAE,WAAW,CAAC;IAAC,KAAK,EAAE,KAAK,EAAE,CAAA;CAAE,GACrC;IAAE,IAAI,EAAE,UAAU,CAAC;IAAC,KAAK,EAAE,MAAM,EAAE,CAAA;CAAE,GACrC;IAAE,IAAI,EAAE,QAAQ,CAAC;IAAC,KAAK,EAAE,YAAY,CAAA;CAAE,CAAC;AAE5C,MAAM,WAAW,YAAY;IAC3B,KAAK,EAAE,KAAK,CAAC;IACb,KAAK,EAAE,MAAM,EAAE,CAAC;IAChB,IAAI,EAAE,MAAM,CAAC;CACd;AAED,MAAM,MAAM,QAAQ,GAAG;IAAE,MAAM,EAAE,MAAM,CAAC;IAAC,KAAK,EAAE,MAAM,CAAA;CAAE,CAAC;AAEzD;;;;;;;;;;;;GAYG;AACH,qBAAa,KAAK;IAChB,gBAAgB;IAChB,QAAQ,CAAC,OAAO,EAAE,GAAG,CAAC;gBAEV,MAAM,EAAE,GAAG;IAIvB
|
|
1
|
+
{"version":3,"file":"graph.d.ts","sourceRoot":"","sources":["../src/ts/graph.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,EAAE,MAAM,gBAAgB,CAAC;AAC7C,OAAO,EAAE,KAAK,EAA2B,MAAM,gBAAgB,CAAC;AAGhE,MAAM,MAAM,SAAS,GACjB;IAAE,IAAI,EAAE,KAAK,CAAC;IAAC,KAAK,EAAE,MAAM,CAAA;CAAE,GAC9B;IAAE,IAAI,EAAE,OAAO,CAAC;IAAC,KAAK,EAAE,MAAM,CAAA;CAAE,GAChC;IAAE,IAAI,EAAE,MAAM,CAAC;IAAC,KAAK,EAAE,OAAO,CAAA;CAAE,GAChC;IAAE,IAAI,EAAE,MAAM,CAAC;IAAC,KAAK,EAAE,KAAK,CAAA;CAAE,GAC9B;IAAE,IAAI,EAAE,OAAO,CAAC;IAAC,KAAK,EAAE,MAAM,EAAE,CAAA;CAAE,GAClC;IAAE,IAAI,EAAE,WAAW,CAAC;IAAC,KAAK,EAAE,KAAK,EAAE,CAAA;CAAE,GACrC;IAAE,IAAI,EAAE,UAAU,CAAC;IAAC,KAAK,EAAE,MAAM,EAAE,CAAA;CAAE,GACrC;IAAE,IAAI,EAAE,QAAQ,CAAC;IAAC,KAAK,EAAE,MAAM,CAAA;CAAE,GACjC;IAAE,IAAI,EAAE,QAAQ,CAAC;IAAC,KAAK,EAAE,YAAY,CAAA;CAAE,CAAC;AAE5C,MAAM,WAAW,YAAY;IAC3B,KAAK,EAAE,KAAK,CAAC;IACb,KAAK,EAAE,MAAM,EAAE,CAAC;IAChB,IAAI,EAAE,MAAM,CAAC;CACd;AAED,MAAM,MAAM,QAAQ,GAAG;IAAE,MAAM,EAAE,MAAM,CAAC;IAAC,KAAK,EAAE,MAAM,CAAA;CAAE,CAAC;AAEzD;;;;;;;;;;;;GAYG;AACH,qBAAa,KAAK;IAChB,gBAAgB;IAChB,QAAQ,CAAC,OAAO,EAAE,GAAG,CAAC;gBAEV,MAAM,EAAE,GAAG;IAIvB;;;;;;;;;;;OAWG;IACH,KAAK,CACH,IAAI,EAAE,MAAM,EACZ,MAAM,EAAE,QAAQ,EAAE,EAClB,KAAK,GAAE,MAAM,CAAC,MAAM,EAAE,SAAS,CAAM,EACrC,IAAI,CAAC,EAAE,MAAM,EACb,aAAa,GAAE,MAAM,EAAO,GAC3B,MAAM,EAAE;IAwCX;;;;;;;;;;;;;;;;OAgBG;IACH,YAAY,CAAC,CAAC,EAAE,QAAQ,EAAE,EAAE,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,CAAC,EAAE,QAAQ,EAAE,GAAG,MAAM,EAAE;IAgBrE,8DAA8D;IAC9D,KAAK,CAAC,IAAI,EAAE,MAAM,GAAG,OAAO;IAI5B,wCAAwC;IACxC,IAAI,MAAM,IAAI,MAAM,CAEnB;IAED;;;OAGG;IACH,UAAU,IAAI,MAAM;IAIpB;;;;;;;;;;;;;;OAcG;IACH,cAAc,CAAC,MAAM,EAAE,MAAM,GAAG,IAAI;IAKpC,KAAK,CAAC,IAAI,EAAE,MAAM,EAAE,KAAK,SAAI,GAAG,MAAM,GAAG,IAAI;CAW9C"}
|
package/dist/graph.js
CHANGED
|
@@ -25,9 +25,12 @@ export class Graph {
|
|
|
25
25
|
* @param inputs Output references from prior ops
|
|
26
26
|
* @param attrs Op attributes
|
|
27
27
|
* @param name Optional explicit op name (auto-generated if omitted)
|
|
28
|
+
* @param controlInputs Op names that must complete before this op runs.
|
|
29
|
+
* Used by globalVariablesInitializer to sequence init
|
|
30
|
+
* ops before the NoOp target that callers wait on.
|
|
28
31
|
* @returns Array of output Tensors (one per op output)
|
|
29
32
|
*/
|
|
30
|
-
addOp(type, inputs, attrs = {}, name) {
|
|
33
|
+
addOp(type, inputs, attrs = {}, name, controlInputs = []) {
|
|
31
34
|
const nativeAttrs = {};
|
|
32
35
|
for (const [k, v] of Object.entries(attrs)) {
|
|
33
36
|
if (v.kind === "list_type") {
|
|
@@ -40,7 +43,7 @@ export class Graph {
|
|
|
40
43
|
nativeAttrs[k] = v;
|
|
41
44
|
}
|
|
42
45
|
}
|
|
43
|
-
const result = this._native.addOp(type, inputs, nativeAttrs, name);
|
|
46
|
+
const result = this._native.addOp(type, inputs, nativeAttrs, name, controlInputs);
|
|
44
47
|
const { opName, numOutputs } = result;
|
|
45
48
|
const tensors = [];
|
|
46
49
|
for (let i = 0; i < numOutputs; i++) {
|
|
@@ -50,6 +53,31 @@ export class Graph {
|
|
|
50
53
|
}
|
|
51
54
|
return tensors;
|
|
52
55
|
}
|
|
56
|
+
/**
|
|
57
|
+
* addGradients — compute symbolic gradients via TF_AddGradients.
|
|
58
|
+
*
|
|
59
|
+
* Injects gradient ops into the graph. Returns one gradient Tensor per
|
|
60
|
+
* entry in `x` — the partial derivative dSum(y)/dx_i.
|
|
61
|
+
*
|
|
62
|
+
* @param y Loss outputs to differentiate (typically a scalar loss tensor)
|
|
63
|
+
* @param x Parameters to differentiate with respect to
|
|
64
|
+
* @param dx Initial upstream gradients (default: ones, i.e. dL/dy = 1)
|
|
65
|
+
*
|
|
66
|
+
* @example
|
|
67
|
+
* const loss = ops.sparseSoftmaxCrossEntropyWithLogits(g, labels, logits);
|
|
68
|
+
* const wVal = ops.readVariable(g, wHandle, DType.FLOAT32);
|
|
69
|
+
* const [dw] = g.addGradients([loss], [wVal]);
|
|
70
|
+
* // dw is now a Tensor representing dLoss/dW
|
|
71
|
+
* // pass it to applyGradientDescent or applyAdam
|
|
72
|
+
*/
|
|
73
|
+
addGradients(y, x, dx) {
|
|
74
|
+
const raw = this._native.addGradients(y, x, dx ?? null);
|
|
75
|
+
return raw.map(({ opName, index }) => {
|
|
76
|
+
const tfDtype = this._native.opOutputType(opName, index);
|
|
77
|
+
const tfShape = this._native.opOutputShape(opName, index);
|
|
78
|
+
return makeTensor(opName, index, tfDtype != null ? tfDtype : null, tfShape != null ? ShapeFromTF(tfShape) : null);
|
|
79
|
+
});
|
|
80
|
+
}
|
|
53
81
|
/** Whether an op with the given name exists in this graph. */
|
|
54
82
|
hasOp(name) {
|
|
55
83
|
return this._native.hasOp(name);
|
package/dist/graph.js.map
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
{"version":3,"file":"graph.js","sourceRoot":"","sources":["../src/ts/graph.ts"],"names":[],"mappings":"AACA,OAAO,EAAS,UAAU,EAAE,WAAW,EAAE,MAAM,gBAAgB,CAAC;
|
|
1
|
+
{"version":3,"file":"graph.js","sourceRoot":"","sources":["../src/ts/graph.ts"],"names":[],"mappings":"AACA,OAAO,EAAS,UAAU,EAAE,WAAW,EAAE,MAAM,gBAAgB,CAAC;AAsBhE;;;;;;;;;;;;GAYG;AACH,MAAM,OAAO,KAAK;IAChB,gBAAgB;IACP,OAAO,CAAM;IAEtB,YAAY,MAAW;QACrB,IAAI,CAAC,OAAO,GAAG,MAAM,CAAC;IACxB,CAAC;IAED;;;;;;;;;;;OAWG;IACH,KAAK,CACH,IAAY,EACZ,MAAkB,EAClB,QAAmC,EAAE,EACrC,IAAa,EACb,gBAA0B,EAAE;QAE5B,MAAM,WAAW,GAAwB,EAAE,CAAC;QAC5C,KAAK,MAAM,CAAC,CAAC,EAAE,CAAC,CAAC,IAAI,MAAM,CAAC,OAAO,CAAC,KAAK,CAAC,EAAE,CAAC;YAC3C,IAAI,CAAC,CAAC,IAAI,KAAK,WAAW,EAAE,CAAC;gBAC3B,WAAW,CAAC,CAAC,CAAC,GAAG,EAAE,IAAI,EAAE,WAAW,EAAE,KAAK,EAAE,CAAC,CAAC,KAAK,CAAC,GAAG,CAAC,MAAM,CAAC,EAAE,CAAC;YACrE,CAAC;iBAAM,IAAI,CAAC,CAAC,IAAI,KAAK,MAAM,EAAE,CAAC;gBAC7B,WAAW,CAAC,CAAC,CAAC,GAAG,EAAE,IAAI,EAAE,MAAM,EAAE,KAAK,EAAE,MAAM,CAAC,CAAC,CAAC,KAAK,CAAC,EAAE,CAAC;YAC5D,CAAC;iBAAM,CAAC;gBACN,WAAW,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC;YACrB,CAAC;QACH,CAAC;QAED,MAAM,MAAM,GAAG,IAAI,CAAC,OAAO,CAAC,KAAK,CAC/B,IAAI,EACJ,MAAM,EACN,WAAW,EACX,IAAI,EACJ,aAAa,CACd,CAAC;QACF,MAAM,EAAE,MAAM,EAAE,UAAU,EAAE,GAAG,MAG9B,CAAC;QAEF,MAAM,OAAO,GAAa,EAAE,CAAC;QAC7B,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,UAAU,EAAE,CAAC,EAAE,EAAE,CAAC;YACpC,MAAM,OAAO,GAAG,IAAI,CAAC,OAAO,CAAC,YAAY,CAAC,MAAM,EAAE,CAAC,CAAkB,CAAC;YACtE,MAAM,OAAO,GAAG,IAAI,CAAC,OAAO,CAAC,aAAa,CAAC,MAAM,EAAE,CAAC,CAAoB,CAAC;YACzE,OAAO,CAAC,IAAI,CACV,UAAU,CACR,MAAM,EACN,CAAC,EACD,OAAO,IAAI,IAAI,CAAC,CAAC,CAAE,OAAiB,CAAC,CAAC,CAAC,IAAI,EAC3C,OAAO,IAAI,IAAI,CAAC,CAAC,CAAC,WAAW,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,IAAI,CAC9C,CACF,CAAC;QACJ,CAAC;QACD,OAAO,OAAO,CAAC;IACjB,CAAC;IAED;;;;;;;;;;;;;;;;OAgBG;IACH,YAAY,CAAC,CAAa,EAAE,CAAa,EAAE,EAAe;QACxD,MAAM,GAAG,GAAG,IAAI,CAAC,OAAO,CAAC,YAAY,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,IAAI,IAAI,CAAe,CAAC;QACtE,OAAO,GAAG,CAAC,GAAG,CAAC,CAAC,EAAE,MAAM,EAAE,KAAK,EAAE,EAAE,EAAE;YACnC,MAAM,OAAO,GAAG,IAAI,CAAC,OAAO,CAAC,YAAY,CAAC,MAAM,EAAE,KAAK,CAAkB,CAAC;YAC1E,MAAM,OAAO,GAAG,IAAI,CAAC,OAAO,CAAC,aAAa,CAAC,MAAM,EAAE,KAAK,CAEhD,CAAC;YACT,OAAO,UAAU,CACf,MAAM,EACN,KAAK,EA
CL,OAAO,IAAI,IAAI,CAAC,CAAC,CAAE,OAAiB,CAAC,CAAC,CAAC,IAAI,EAC3C,OAAO,IAAI,IAAI,CAAC,CAAC,CAAC,WAAW,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,IAAI,CAC9C,CAAC;QACJ,CAAC,CAAC,CAAC;IACL,CAAC;IAED,8DAA8D;IAC9D,KAAK,CAAC,IAAY;QAChB,OAAO,IAAI,CAAC,OAAO,CAAC,KAAK,CAAC,IAAI,CAAY,CAAC;IAC7C,CAAC;IAED,wCAAwC;IACxC,IAAI,MAAM;QACR,OAAO,IAAI,CAAC,OAAO,CAAC,MAAM,EAAY,CAAC;IACzC,CAAC;IAED;;;OAGG;IACH,UAAU;QACR,OAAO,IAAI,CAAC,OAAO,CAAC,UAAU,EAAY,CAAC;IAC7C,CAAC;IAED;;;;;;;;;;;;;;OAcG;IACH,cAAc,CAAC,MAAc;QAC3B,IAAI,CAAC,OAAO,CAAC,cAAc,CAAC,MAAM,CAAC,CAAC;IACtC,CAAC;IAED,WAAW;IACX,KAAK,CAAC,IAAY,EAAE,KAAK,GAAG,CAAC;QAC3B,IAAI,CAAC,IAAI,CAAC,OAAO,CAAC,KAAK,CAAC,IAAI,CAAC;YAAE,OAAO,IAAI,CAAC;QAC3C,MAAM,OAAO,GAAG,IAAI,CAAC,OAAO,CAAC,YAAY,CAAC,IAAI,EAAE,KAAK,CAAkB,CAAC;QACxE,MAAM,OAAO,GAAG,IAAI,CAAC,OAAO,CAAC,aAAa,CAAC,IAAI,EAAE,KAAK,CAAoB,CAAC;QAC3E,OAAO,UAAU,CACf,IAAI,EACJ,KAAK,EACL,OAAO,IAAI,IAAI,CAAC,CAAC,CAAE,OAAiB,CAAC,CAAC,CAAC,IAAI,EAC3C,OAAO,IAAI,IAAI,CAAC,CAAC,CAAC,WAAW,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,IAAI,CAC9C,CAAC;IACJ,CAAC;CACF"}
|
package/dist/index.d.ts
CHANGED
|
@@ -44,4 +44,7 @@ export declare function session(g: Graph, options?: {
|
|
|
44
44
|
export * as ops from "./ops/index.js";
|
|
45
45
|
export type { PoolOptions, PoolResult, ExecutionStrategy, } from "./inference-pool.js";
|
|
46
46
|
export { InferencePool } from "./inference-pool.js";
|
|
47
|
+
export * as optimizers from "./optimizers/index.js";
|
|
48
|
+
export type { ActivationFn, Layer, LayerParam, LossFn, TrainStepResult, } from "./model/index.js";
|
|
49
|
+
export { Dense, Flatten, Conv2D, Sequential } from "./model/index.js";
|
|
47
50
|
//# sourceMappingURL=index.d.ts.map
|
package/dist/index.d.ts.map
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
{"version":3,"file":"index.d.ts","sourceRoot":"","sources":["../src/ts/index.ts"],"names":[],"mappings":"AAAA;;;;;;GAMG;AAsBH,YAAY,EAAE,MAAM,EAAE,KAAK,EAAE,MAAM,gBAAgB,CAAC;AACpD,OAAO,EACL,KAAK,EACL,aAAa,EACb,SAAS,EACT,UAAU,EACV,QAAQ,GACT,MAAM,gBAAgB,CAAC;AAGxB,YAAY,EAAE,SAAS,EAAE,YAAY,EAAE,QAAQ,EAAE,MAAM,YAAY,CAAC;AACpE,OAAO,EAAE,KAAK,EAAE,MAAM,YAAY,CAAC;AACnC,YAAY,EAAE,SAAS,EAAE,WAAW,EAAE,MAAM,cAAc,CAAC;AAC3D,OAAO,EAAE,OAAO,EAAE,MAAM,cAAc,CAAC;AAIvC,OAAO,EAAE,KAAK,EAAE,MAAM,YAAY,CAAC;AACnC,OAAO,EAAE,OAAO,EAAE,MAAM,cAAc,CAAC;AAEvC;;;;;;GAMG;AACH,wBAAgB,KAAK,IAAI,KAAK,CAE7B;AAED;;;;;;;;;;;;;GAaG;AACH,wBAAgB,OAAO,CACrB,CAAC,EAAE,KAAK,EACR,OAAO,CAAC,EAAE;IACR,QAAQ,CAAC,EAAE,aAAa,GAAG,aAAa,CAAC;IACzC,cAAc,CAAC,EAAE,MAAM,CAAC;IACxB,cAAc,CAAC,EAAE,MAAM,CAAC;IACxB,YAAY,CAAC,EAAE,MAAM,CAAC;CACvB,GACA,OAAO,CAET;AAGD,OAAO,KAAK,GAAG,MAAM,gBAAgB,CAAC;AAGtC,YAAY,EACV,WAAW,EACX,UAAU,EACV,iBAAiB,GAClB,MAAM,qBAAqB,CAAC;AAC7B,OAAO,EAAE,aAAa,EAAE,MAAM,qBAAqB,CAAC"}
|
|
1
|
+
{"version":3,"file":"index.d.ts","sourceRoot":"","sources":["../src/ts/index.ts"],"names":[],"mappings":"AAAA;;;;;;GAMG;AAsBH,YAAY,EAAE,MAAM,EAAE,KAAK,EAAE,MAAM,gBAAgB,CAAC;AACpD,OAAO,EACL,KAAK,EACL,aAAa,EACb,SAAS,EACT,UAAU,EACV,QAAQ,GACT,MAAM,gBAAgB,CAAC;AAGxB,YAAY,EAAE,SAAS,EAAE,YAAY,EAAE,QAAQ,EAAE,MAAM,YAAY,CAAC;AACpE,OAAO,EAAE,KAAK,EAAE,MAAM,YAAY,CAAC;AACnC,YAAY,EAAE,SAAS,EAAE,WAAW,EAAE,MAAM,cAAc,CAAC;AAC3D,OAAO,EAAE,OAAO,EAAE,MAAM,cAAc,CAAC;AAIvC,OAAO,EAAE,KAAK,EAAE,MAAM,YAAY,CAAC;AACnC,OAAO,EAAE,OAAO,EAAE,MAAM,cAAc,CAAC;AAEvC;;;;;;GAMG;AACH,wBAAgB,KAAK,IAAI,KAAK,CAE7B;AAED;;;;;;;;;;;;;GAaG;AACH,wBAAgB,OAAO,CACrB,CAAC,EAAE,KAAK,EACR,OAAO,CAAC,EAAE;IACR,QAAQ,CAAC,EAAE,aAAa,GAAG,aAAa,CAAC;IACzC,cAAc,CAAC,EAAE,MAAM,CAAC;IACxB,cAAc,CAAC,EAAE,MAAM,CAAC;IACxB,YAAY,CAAC,EAAE,MAAM,CAAC;CACvB,GACA,OAAO,CAET;AAGD,OAAO,KAAK,GAAG,MAAM,gBAAgB,CAAC;AAGtC,YAAY,EACV,WAAW,EACX,UAAU,EACV,iBAAiB,GAClB,MAAM,qBAAqB,CAAC;AAC7B,OAAO,EAAE,aAAa,EAAE,MAAM,qBAAqB,CAAC;AAGpD,OAAO,KAAK,UAAU,MAAM,uBAAuB,CAAC;AACpD,YAAY,EACV,YAAY,EACZ,KAAK,EACL,UAAU,EACV,MAAM,EACN,eAAe,GAChB,MAAM,kBAAkB,CAAC;AAC1B,OAAO,EAAE,KAAK,EAAE,OAAO,EAAE,MAAM,EAAE,UAAU,EAAE,MAAM,kBAAkB,CAAC"}
|
package/dist/index.js
CHANGED
|
@@ -55,4 +55,7 @@ export function session(g, options) {
|
|
|
55
55
|
// ── Ops namespace ─────────────────────────────────────────────────────────────
|
|
56
56
|
export * as ops from "./ops/index.js";
|
|
57
57
|
export { InferencePool } from "./inference-pool.js";
|
|
58
|
+
// ── Model layers and optimizers ───────────────────────────────────────────────
|
|
59
|
+
export * as optimizers from "./optimizers/index.js";
|
|
60
|
+
export { Dense, Flatten, Conv2D, Sequential } from "./model/index.js";
|
|
58
61
|
//# sourceMappingURL=index.js.map
|
package/dist/index.js.map
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
{"version":3,"file":"index.js","sourceRoot":"","sources":["../src/ts/index.ts"],"names":[],"mappings":"AAAA;;;;;;GAMG;AAEH,OAAO,EAAE,QAAQ,EAAE,MAAM,4BAA4B,CAAC;AACtD,OAAO,YAAY,MAAM,gBAAgB,CAAC;AAC1C,OAAO,EAAE,aAAa,EAAE,MAAM,KAAK,CAAC;AACpC,OAAO,EAAE,OAAO,EAAE,IAAI,EAAE,MAAM,MAAM,CAAC;AAErC,gFAAgF;AAChF,yEAAyE;AACzE,qEAAqE;AACrE,sDAAsD;AACtD,MAAM,QAAQ,EAAE,CAAC;AAEjB,MAAM,UAAU,GAAG,aAAa,CAAC,MAAM,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC;AAClD,MAAM,SAAS,GAAG,OAAO,CAAC,UAAU,CAAC,CAAC;AAEtC,MAAM,KAAK,GAAG,YAAY,CAAC,IAAI,CAAC,SAAS,EAAE,IAAI,EAAE,IAAI,CAAC,CAAQ,CAAC;AAE/D,OAAO,EAAE,QAAQ,EAAE,MAAM,cAAc,CAAC;AACxC,QAAQ,CAAC,KAAK,CAAC,CAAC;AAIhB,OAAO,EACL,KAAK,EACL,aAAa,EACb,SAAS,EACT,UAAU,EACV,QAAQ,GACT,MAAM,gBAAgB,CAAC;AAIxB,OAAO,EAAE,KAAK,EAAE,MAAM,YAAY,CAAC;AAEnC,OAAO,EAAE,OAAO,EAAE,MAAM,cAAc,CAAC;AAEvC,iFAAiF;AAEjF,OAAO,EAAE,KAAK,EAAE,MAAM,YAAY,CAAC;AACnC,OAAO,EAAE,OAAO,EAAE,MAAM,cAAc,CAAC;AAEvC;;;;;;GAMG;AACH,MAAM,UAAU,KAAK;IACnB,OAAO,IAAI,KAAK,CAAC,IAAI,KAAK,CAAC,KAAK,EAAE,CAAC,CAAC;AACtC,CAAC;AAED;;;;;;;;;;;;;GAaG;AACH,MAAM,UAAU,OAAO,CACrB,CAAQ,EACR,OAKC;IAED,OAAO,IAAI,OAAO,CAAC,IAAI,KAAK,CAAC,OAAO,CAAC,CAAC,CAAC,OAAO,EAAE,OAAO,CAAC,CAAC,CAAC;AAC5D,CAAC;AAED,iFAAiF;AACjF,OAAO,KAAK,GAAG,MAAM,gBAAgB,CAAC;AAQtC,OAAO,EAAE,aAAa,EAAE,MAAM,qBAAqB,CAAC"}
|
|
1
|
+
{"version":3,"file":"index.js","sourceRoot":"","sources":["../src/ts/index.ts"],"names":[],"mappings":"AAAA;;;;;;GAMG;AAEH,OAAO,EAAE,QAAQ,EAAE,MAAM,4BAA4B,CAAC;AACtD,OAAO,YAAY,MAAM,gBAAgB,CAAC;AAC1C,OAAO,EAAE,aAAa,EAAE,MAAM,KAAK,CAAC;AACpC,OAAO,EAAE,OAAO,EAAE,IAAI,EAAE,MAAM,MAAM,CAAC;AAErC,gFAAgF;AAChF,yEAAyE;AACzE,qEAAqE;AACrE,sDAAsD;AACtD,MAAM,QAAQ,EAAE,CAAC;AAEjB,MAAM,UAAU,GAAG,aAAa,CAAC,MAAM,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC;AAClD,MAAM,SAAS,GAAG,OAAO,CAAC,UAAU,CAAC,CAAC;AAEtC,MAAM,KAAK,GAAG,YAAY,CAAC,IAAI,CAAC,SAAS,EAAE,IAAI,EAAE,IAAI,CAAC,CAAQ,CAAC;AAE/D,OAAO,EAAE,QAAQ,EAAE,MAAM,cAAc,CAAC;AACxC,QAAQ,CAAC,KAAK,CAAC,CAAC;AAIhB,OAAO,EACL,KAAK,EACL,aAAa,EACb,SAAS,EACT,UAAU,EACV,QAAQ,GACT,MAAM,gBAAgB,CAAC;AAIxB,OAAO,EAAE,KAAK,EAAE,MAAM,YAAY,CAAC;AAEnC,OAAO,EAAE,OAAO,EAAE,MAAM,cAAc,CAAC;AAEvC,iFAAiF;AAEjF,OAAO,EAAE,KAAK,EAAE,MAAM,YAAY,CAAC;AACnC,OAAO,EAAE,OAAO,EAAE,MAAM,cAAc,CAAC;AAEvC;;;;;;GAMG;AACH,MAAM,UAAU,KAAK;IACnB,OAAO,IAAI,KAAK,CAAC,IAAI,KAAK,CAAC,KAAK,EAAE,CAAC,CAAC;AACtC,CAAC;AAED;;;;;;;;;;;;;GAaG;AACH,MAAM,UAAU,OAAO,CACrB,CAAQ,EACR,OAKC;IAED,OAAO,IAAI,OAAO,CAAC,IAAI,KAAK,CAAC,OAAO,CAAC,CAAC,CAAC,OAAO,EAAE,OAAO,CAAC,CAAC,CAAC;AAC5D,CAAC;AAED,iFAAiF;AACjF,OAAO,KAAK,GAAG,MAAM,gBAAgB,CAAC;AAQtC,OAAO,EAAE,aAAa,EAAE,MAAM,qBAAqB,CAAC;AAEpD,iFAAiF;AACjF,OAAO,KAAK,UAAU,MAAM,uBAAuB,CAAC;AAQpD,OAAO,EAAE,KAAK,EAAE,OAAO,EAAE,MAAM,EAAE,UAAU,EAAE,MAAM,kBAAkB,CAAC"}
|
|
@@ -0,0 +1,5 @@
|
|
|
1
|
+
export type { ActivationFn, Layer, LayerParam } from "./layer.js";
|
|
2
|
+
export { Dense, Flatten, Conv2D } from "./layers.js";
|
|
3
|
+
export type { LossFn, TrainStepResult } from "./sequential.js";
|
|
4
|
+
export { Sequential } from "./sequential.js";
|
|
5
|
+
//# sourceMappingURL=index.d.ts.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"index.d.ts","sourceRoot":"","sources":["../../src/ts/model/index.ts"],"names":[],"mappings":"AAAA,YAAY,EAAE,YAAY,EAAE,KAAK,EAAE,UAAU,EAAE,MAAM,YAAY,CAAC;AAClE,OAAO,EAAE,KAAK,EAAE,OAAO,EAAE,MAAM,EAAE,MAAM,aAAa,CAAC;AACrD,YAAY,EAAE,MAAM,EAAE,eAAe,EAAE,MAAM,iBAAiB,CAAC;AAC/D,OAAO,EAAE,UAAU,EAAE,MAAM,iBAAiB,CAAC"}
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"index.js","sourceRoot":"","sources":["../../src/ts/model/index.ts"],"names":[],"mappings":"AACA,OAAO,EAAE,KAAK,EAAE,OAAO,EAAE,MAAM,EAAE,MAAM,aAAa,CAAC;AAErD,OAAO,EAAE,UAAU,EAAE,MAAM,iBAAiB,CAAC"}
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
import type { Tensor } from "@isidorus/core";
|
|
2
|
+
import { DType } from "@isidorus/core";
|
|
3
|
+
import type { Graph } from "../graph.js";
|
|
4
|
+
/** Activation identifiers accepted by the built-in layers' `activation` option. */
export type ActivationFn =
    | "relu"
    | "leaky_relu"
    | "relu6"
    | "sigmoid"
    | "tanh"
    | "softmax"
    | "log_softmax"
    | "elu"
    | "selu"
    | "swish"
    | "gelu"
    | "linear";
/** A trainable variable owned by a layer and exposed to optimizers. */
export interface LayerParam {
    /** Handle of the underlying variable. */
    handle: Tensor;
    /** Tensor obtained by reading the variable; used in the forward pass. */
    read: Tensor;
    /** Element type of the variable. */
    dtype: DType;
    /** Variable name (built-in layers use `${layerName}/w`, `${layerName}/b`). */
    name: string;
    /** Identifier of the variable's init op (presumably its op name — confirm against variableWithInit). */
    initOp: string;
}
/** Contract every layer added to a Sequential model must satisfy. */
export interface Layer {
    readonly name: string;
    /**
     * Build — add ops to the graph.
     * Called once by Sequential in order.
     * @param g          graph to add ops to
     * @param input      symbolic input tensor
     * @param inputShape static input shape (null dims = dynamic)
     * @returns output shape after this layer (null dims = dynamic)
     */
    build(g: Graph, input: Tensor, inputShape: Array<number | null>): Array<number | null>;
    /** Symbolic output tensor. Set during build(). */
    readonly output: Tensor;
    /** Parameters exposed to the optimizer. Set during build(). */
    readonly layerParams: LayerParam[];
}
|
|
25
|
+
//# sourceMappingURL=layer.d.ts.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"layer.d.ts","sourceRoot":"","sources":["../../src/ts/model/layer.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,EAAE,MAAM,gBAAgB,CAAC;AAC7C,OAAO,EAAE,KAAK,EAAE,MAAM,gBAAgB,CAAC;AACvC,OAAO,KAAK,EAAE,KAAK,EAAE,MAAM,aAAa,CAAC;AAkBzC,MAAM,MAAM,YAAY,GACpB,MAAM,GACN,YAAY,GACZ,OAAO,GACP,SAAS,GACT,MAAM,GACN,SAAS,GACT,aAAa,GACb,KAAK,GACL,MAAM,GACN,OAAO,GACP,MAAM,GACN,QAAQ,CAAC;AAEb,MAAM,WAAW,UAAU;IACzB,MAAM,EAAE,MAAM,CAAC;IACf,IAAI,EAAE,MAAM,CAAC;IACb,KAAK,EAAE,KAAK,CAAC;IACb,IAAI,EAAE,MAAM,CAAC;IACb,MAAM,EAAE,MAAM,CAAC;CAChB;AAED,MAAM,WAAW,KAAK;IACpB,QAAQ,CAAC,IAAI,EAAE,MAAM,CAAC;IAEtB;;;;OAIG;IACH,KAAK,CACH,CAAC,EAAE,KAAK,EACR,KAAK,EAAE,MAAM,EACb,UAAU,EAAE,CAAC,MAAM,GAAG,IAAI,CAAC,EAAE,GAC5B,CAAC,MAAM,GAAG,IAAI,CAAC,EAAE,CAAC;IAErB,kDAAkD;IAClD,QAAQ,CAAC,MAAM,EAAE,MAAM,CAAC;IAExB,+DAA+D;IAC/D,QAAQ,CAAC,WAAW,EAAE,UAAU,EAAE,CAAC;CACpC"}
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"layer.js","sourceRoot":"","sources":["../../src/ts/model/layer.ts"],"names":[],"mappings":""}
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
import type { Tensor } from "@isidorus/core";
|
|
2
|
+
import type { Graph } from "../graph.js";
|
|
3
|
+
import type { Layer, LayerParam, ActivationFn } from "./layer.js";
|
|
4
|
+
/** Fully-connected layer: activation(input @ W + b). */
export declare class Dense implements Layer {
    readonly name: string;
    output: Tensor;
    readonly layerParams: LayerParam[];
    private readonly units;
    private readonly activation;
    private readonly useBias;
    /**
     * @param units   number of output features
     * @param options activation (default "linear"), useBias (default true),
     *                name (default `dense_${units}`)
     */
    constructor(units: number, options?: {
        activation?: ActivationFn;
        useBias?: boolean;
        name?: string;
    });
    build(g: Graph, input: Tensor, inputShape: Array<number | null>): Array<number | null>;
}
/** Reshapes [batch, d1, d2, ...] into [batch, d1*d2*...]; has no parameters. */
export declare class Flatten implements Layer {
    readonly name: string;
    output: Tensor;
    readonly layerParams: LayerParam[];
    /** @param options name (default "flatten") */
    constructor(options?: {
        name?: string;
    });
    build(g: Graph, input: Tensor, inputShape: Array<number | null>): Array<number | null>;
}
/** 2D convolution over NHWC input: activation(conv2d(input, W) + b). */
export declare class Conv2D implements Layer {
    readonly name: string;
    output: Tensor;
    readonly layerParams: LayerParam[];
    private readonly filters;
    private readonly kernelSize;
    private readonly strides;
    private readonly padding;
    private readonly activation;
    private readonly useBias;
    /**
     * @param filters number of output channels
     * @param options kernelSize (default 3), strides (default 1),
     *                padding (default "SAME"), activation (default "linear"),
     *                useBias (default true), name (default `conv2d_${filters}f`)
     */
    constructor(filters: number, options?: {
        kernelSize?: number | [number, number];
        strides?: number | [number, number];
        padding?: "SAME" | "VALID";
        activation?: ActivationFn;
        useBias?: boolean;
        name?: string;
    });
    build(g: Graph, input: Tensor, inputShape: Array<number | null>): Array<number | null>;
}
|
|
47
|
+
//# sourceMappingURL=layers.d.ts.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"layers.d.ts","sourceRoot":"","sources":["../../src/ts/model/layers.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,EAAE,MAAM,gBAAgB,CAAC;AAE7C,OAAO,KAAK,EAAE,KAAK,EAAE,MAAM,aAAa,CAAC;AACzC,OAAO,KAAK,EAAE,KAAK,EAAE,UAAU,EAAE,YAAY,EAAE,MAAM,YAAY,CAAC;AA8DlE,qBAAa,KAAM,YAAW,KAAK;IACjC,QAAQ,CAAC,IAAI,EAAE,MAAM,CAAC;IACtB,MAAM,EAAG,MAAM,CAAC;IAChB,QAAQ,CAAC,WAAW,EAAE,UAAU,EAAE,CAAM;IAExC,OAAO,CAAC,QAAQ,CAAC,KAAK,CAAS;IAC/B,OAAO,CAAC,QAAQ,CAAC,UAAU,CAAe;IAC1C,OAAO,CAAC,QAAQ,CAAC,OAAO,CAAU;gBAGhC,KAAK,EAAE,MAAM,EACb,OAAO,GAAE;QACP,UAAU,CAAC,EAAE,YAAY,CAAC;QAC1B,OAAO,CAAC,EAAE,OAAO,CAAC;QAClB,IAAI,CAAC,EAAE,MAAM,CAAC;KACV;IAQR,KAAK,CACH,CAAC,EAAE,KAAK,EACR,KAAK,EAAE,MAAM,EACb,UAAU,EAAE,CAAC,MAAM,GAAG,IAAI,CAAC,EAAE,GAC5B,CAAC,MAAM,GAAG,IAAI,CAAC,EAAE;CAoErB;AAKD,qBAAa,OAAQ,YAAW,KAAK;IACnC,QAAQ,CAAC,IAAI,EAAE,MAAM,CAAC;IACtB,MAAM,EAAG,MAAM,CAAC;IAChB,QAAQ,CAAC,WAAW,EAAE,UAAU,EAAE,CAAM;gBAE5B,OAAO,GAAE;QAAE,IAAI,CAAC,EAAE,MAAM,CAAA;KAAO;IAI3C,KAAK,CACH,CAAC,EAAE,KAAK,EACR,KAAK,EAAE,MAAM,EACb,UAAU,EAAE,CAAC,MAAM,GAAG,IAAI,CAAC,EAAE,GAC5B,CAAC,MAAM,GAAG,IAAI,CAAC,EAAE;CA2BrB;AAKD,qBAAa,MAAO,YAAW,KAAK;IAClC,QAAQ,CAAC,IAAI,EAAE,MAAM,CAAC;IACtB,MAAM,EAAG,MAAM,CAAC;IAChB,QAAQ,CAAC,WAAW,EAAE,UAAU,EAAE,CAAM;IAExC,OAAO,CAAC,QAAQ,CAAC,OAAO,CAAS;IACjC,OAAO,CAAC,QAAQ,CAAC,UAAU,CAAmB;IAC9C,OAAO,CAAC,QAAQ,CAAC,OAAO,CAAmC;IAC3D,OAAO,CAAC,QAAQ,CAAC,OAAO,CAAmB;IAC3C,OAAO,CAAC,QAAQ,CAAC,UAAU,CAAe;IAC1C,OAAO,CAAC,QAAQ,CAAC,OAAO,CAAU;gBAGhC,OAAO,EAAE,MAAM,EACf,OAAO,GAAE;QACP,UAAU,CAAC,EAAE,MAAM,GAAG,CAAC,MAAM,EAAE,MAAM,CAAC,CAAC;QACvC,OAAO,CAAC,EAAE,MAAM,GAAG,CAAC,MAAM,EAAE,MAAM,CAAC,CAAC;QACpC,OAAO,CAAC,EAAE,MAAM,GAAG,OAAO,CAAC;QAC3B,UAAU,CAAC,EAAE,YAAY,CAAC;QAC1B,OAAO,CAAC,EAAE,OAAO,CAAC;QAClB,IAAI,CAAC,EAAE,MAAM,CAAC;KACV;IAgBR,KAAK,CACH,CAAC,EAAE,KAAK,EACR,KAAK,EAAE,MAAM,EACb,UAAU,EAAE,CAAC,MAAM,GAAG,IAAI,CAAC,EAAE,GAC5B,CAAC,MAAM,GAAG,IAAI,CAAC,EAAE;CAqGrB"}
|
|
@@ -0,0 +1,191 @@
|
|
|
1
|
+
import { DType } from "@isidorus/core";
|
|
2
|
+
import { variableWithInit, readVariable, zerosInitializer, glorotUniformInitializer, truncatedNormalInitializer, } from "../ops/variable_ops.js";
|
|
3
|
+
import { constant } from "../ops/array_ops.js";
|
|
4
|
+
import { matmul, biasAdd } from "../ops/math_ops.js";
|
|
5
|
+
import { relu, leakyRelu, relu6, sigmoid, conv2d as conv2dOp, tanh, softmax, logSoftmax, elu, selu, swish, gelu, } from "../ops/nn_ops.js";
|
|
6
|
+
// ---------------------------------------------------------------------------
|
|
7
|
+
// Activation helper
|
|
8
|
+
// ---------------------------------------------------------------------------
|
|
9
|
+
/**
 * Applies the named activation to `x`.
 *
 * Every non-linear activation is built with `${name}/${fn}` as its name
 * argument; "linear" is the identity and adds nothing to the graph.
 *
 * @param g    graph the activation op is added to
 * @param x    tensor to activate
 * @param fn   activation identifier (see ActivationFn)
 * @param name layer name used as the op-name prefix
 * @returns the activated tensor, or `x` itself for "linear"
 * @throws Error for an unrecognised activation identifier
 */
function activate(g, x, fn, name) {
    if (fn === "linear") {
        return x;
    }
    // Each op name is the layer prefix followed by the activation id verbatim.
    const opName = `${name}/${fn}`;
    switch (fn) {
        case "relu": return relu(g, x, opName);
        case "leaky_relu": return leakyRelu(g, x, 0.2, opName); // fixed alpha
        case "relu6": return relu6(g, x, opName);
        case "sigmoid": return sigmoid(g, x, opName);
        case "tanh": return tanh(g, x, opName);
        case "softmax": return softmax(g, x, opName);
        case "log_softmax": return logSoftmax(g, x, opName);
        case "elu": return elu(g, x, opName);
        case "selu": return selu(g, x, opName);
        case "swish": return swish(g, x, opName);
        case "gelu": return gelu(g, x, opName);
    }
    throw new Error(`Unknown activation: ${fn}`);
}
|
|
39
|
+
// ---------------------------------------------------------------------------
|
|
40
|
+
// Dense — fully-connected layer: output = activation(input @ W + b)
|
|
41
|
+
// ---------------------------------------------------------------------------
|
|
42
|
+
/**
 * Creates a FLOAT32 resource variable initialised from `initVal`, reads it,
 * and packages both as the LayerParam record Dense exposes to optimizers.
 */
function createFloatParam(g, shape, varName, readName, initVal) {
    const { handle, initOp } = variableWithInit(g, shape, DType.FLOAT32, varName, initVal);
    const read = readVariable(g, handle, DType.FLOAT32, readName);
    return { handle, read, dtype: DType.FLOAT32, name: varName, initOp };
}
export class Dense {
    name;
    output;
    layerParams = [];
    units;
    activation;
    useBias;
    /**
     * @param units   number of output features
     * @param options activation (default "linear"), useBias (default true),
     *                name (default `dense_${units}`)
     */
    constructor(units, options = {}) {
        this.units = units;
        this.activation = options.activation ?? "linear";
        this.useBias = options.useBias ?? true;
        this.name = options.name ?? `dense_${units}`;
    }
    /**
     * Adds W (Glorot-uniform), optional b (zeros), matmul, bias-add and the
     * activation to `g`, recording each variable in `layerParams`.
     *
     * @returns input shape with its last dim replaced by `units`
     * @throws Error when the last input dim is unknown or < 1
     */
    build(g, input, inputShape) {
        const inFeatures = inputShape[inputShape.length - 1];
        if (!inFeatures || inFeatures < 1) {
            throw new Error(`Dense "${this.name}": last input dim must be known, got ${JSON.stringify(inputShape)}`);
        }
        const prefix = this.name;
        // Kernel W: [inFeatures, units]
        const kernelInit = glorotUniformInitializer(g, [inFeatures, this.units], DType.FLOAT32, `${prefix}/w_glorot`);
        const kernel = createFloatParam(g, [inFeatures, this.units], `${prefix}/w`, `${prefix}/w_read`, kernelInit);
        let result = matmul(g, input, kernel.read, {}, `${prefix}/matmul`);
        this.layerParams.push(kernel);
        // Bias b: [units]
        if (this.useBias) {
            const biasInit = zerosInitializer(g, [this.units], DType.FLOAT32);
            const bias = createFloatParam(g, [this.units], `${prefix}/b`, `${prefix}/b_read`, biasInit);
            result = biasAdd(g, result, bias.read, `${prefix}/bias_add`);
            this.layerParams.push(bias);
        }
        this.output = activate(g, result, this.activation, prefix);
        const leadingDims = inputShape.slice(0, -1);
        return [...leadingDims, this.units];
    }
}
|
|
89
|
+
// ---------------------------------------------------------------------------
|
|
90
|
+
// Flatten — reshapes [batch, d1, d2, ...] → [batch, d1*d2*...]
|
|
91
|
+
// ---------------------------------------------------------------------------
|
|
92
|
+
export class Flatten {
    name;
    output;
    layerParams = [];
    /**
     * @param options name (default "flatten")
     */
    constructor(options = {}) {
        this.name = options.name ?? "flatten";
    }
    /**
     * Adds a Reshape collapsing every dimension after the first:
     * [batch, d1, d2, ...] → [batch, d1*d2*...].
     *
     * Reshape accepts at most one -1 in its target shape, so the constant
     * target is chosen from what is statically known:
     *   - all non-batch dims known                    → [-1, d1*d2*...]
     *   - some non-batch dim dynamic, batch dim known → [batch, -1]
     *   - both dynamic → no valid constant target exists; throw rather than
     *     emit a Reshape whose target is [-1, -1].
     *
     * @returns static output shape: [null, flatSize] or [batch, null]
     * @throws Error when the batch dim and a non-batch dim are both dynamic
     */
    build(g, input, inputShape) {
        const spatialDims = inputShape.slice(1);
        const hasUnknown = spatialDims.some((d) => d === null);
        const flatSize = hasUnknown
            ? null
            : spatialDims.reduce((a, b) => a * b, 1);
        const batch = inputShape[0];
        if (flatSize === null && (batch === null || batch === undefined)) {
            throw new Error(`Flatten "${this.name}": cannot flatten with both a dynamic batch dim and a dynamic non-batch dim, got ${JSON.stringify(inputShape)}`);
        }
        // Exactly one -1 in the target: whichever side is not statically known.
        const [targetBatch, targetFlat] = flatSize === null ? [batch, -1] : [-1, flatSize];
        // Int32 shape constant [targetBatch, targetFlat] fed to Reshape.
        const shapeBuf = Buffer.allocUnsafe(8);
        shapeBuf.writeInt32LE(targetBatch, 0);
        shapeBuf.writeInt32LE(targetFlat, 4);
        const shapeConst = constant(g, shapeBuf, [2], DType.INT32, `${this.name}/shape`);
        const [out] = g.addOp("Reshape", [input, shapeConst], {}, `${this.name}/reshape`);
        this.output = out;
        return flatSize === null ? [batch, null] : [null, flatSize];
    }
}
|
|
115
|
+
// ---------------------------------------------------------------------------
|
|
116
|
+
// Conv2D — 2D convolution (NHWC): output = activation(conv2d(input, W) + b)
|
|
117
|
+
// ---------------------------------------------------------------------------
|
|
118
|
+
export class Conv2D {
    name;
    output;
    layerParams = [];
    filters;
    kernelSize;
    strides;
    padding;
    activation;
    useBias;
    /**
     * @param filters Number of output channels (last dimension of the kernel).
     * @param options Optional settings:
     *   - padding:    "SAME" (default) or "VALID"; matched case-insensitively.
     *   - activation: activation name passed to `activate` (default "linear").
     *   - useBias:    add a learned per-filter bias (default true).
     *   - name:       op-name prefix (default `conv2d_<filters>f`).
     *   - kernelSize: number or [kH, kW] (default 3).
     *   - strides:    number or [sH, sW] (default 1).
     * @throws Error for a padding mode other than SAME / VALID.
     */
    constructor(filters, options = {}) {
        this.filters = filters;
        this.name = options.name ?? `conv2d_${filters}f`;
        // The TF op only accepts upper-case padding names; normalise so the
        // output-shape formula and the op attribute always agree.
        const padding = String(options.padding ?? "SAME").toUpperCase();
        if (padding !== "SAME" && padding !== "VALID")
            throw new Error(`Conv2D "${this.name}": unsupported padding ${JSON.stringify(options.padding)}; expected "SAME" or "VALID"`);
        this.padding = padding;
        this.activation = options.activation ?? "linear";
        this.useBias = options.useBias ?? true;
        const ks = options.kernelSize ?? 3;
        this.kernelSize = Array.isArray(ks) ? ks : [ks, ks];
        const st = options.strides ?? 1;
        const [sH, sW] = Array.isArray(st) ? st : [st, st];
        this.strides = [1, sH, sW, 1];
    }
    /**
     * Add the convolution (NHWC), optional bias and activation to `g`.
     *
     * @param g          Graph the ops are added to.
     * @param input      Input tensor.
     * @param inputShape Static shape [batch, H, W, C]; `null` marks unknown dims.
     *                   C (in_channels) must be known.
     * @returns Static output shape [null, outH, outW, filters].
     * @throws Error on non-4D input, unknown in_channels, or a VALID
     *         convolution whose kernel is larger than the input.
     */
    build(g, input, inputShape) {
        if (inputShape.length !== 4)
            throw new Error(`Conv2D "${this.name}": expects 4D input [batch,H,W,C], got rank ${inputShape.length}`);
        const inChannels = inputShape[3];
        if (!inChannels || inChannels < 1)
            throw new Error(`Conv2D "${this.name}": in_channels must be known, got ${JSON.stringify(inputShape)}`);
        const [kH, kW] = this.kernelSize;
        const [, sH, sW] = this.strides;
        const [, H, W] = inputShape;
        // Static output size along one spatial axis (TF SAME/VALID rules).
        const outputSize = (size, kernel, stride, axis) => {
            if (size === null || size === undefined)
                return null;
            if (this.padding === "SAME")
                return Math.ceil(size / stride);
            // VALID: only positions where the whole kernel fits inside the input.
            if (size < kernel)
                throw new Error(`Conv2D "${this.name}": VALID padding needs input ${axis} >= kernel ${axis}, got ${size} < ${kernel}`);
            return Math.ceil((size - kernel + 1) / stride);
        };
        // Validate the geometry before adding ops, so a bad configuration
        // leaves no orphaned ops in the graph.
        const outH = outputSize(H, kH, sH, "height");
        const outW = outputSize(W, kW, sW, "width");
        const wShape = [kH, kW, inChannels, this.filters];
        // He normal init for conv: stddev = sqrt(2 / (kH * kW * inChannels))
        const stddev = Math.sqrt(2 / (kH * kW * inChannels));
        const wInitVal = truncatedNormalInitializer(g, wShape, DType.FLOAT32, { stddev }, `${this.name}/w_init`);
        const { handle: wHandle, initOp: wInitOp } = variableWithInit(g, wShape, DType.FLOAT32, `${this.name}/w`, wInitVal);
        const wRead = readVariable(g, wHandle, DType.FLOAT32, `${this.name}/w_read`);
        let out = conv2dOp(g, input, wRead, this.strides, this.padding, `${this.name}/conv`);
        this.layerParams.push({
            handle: wHandle,
            read: wRead,
            dtype: DType.FLOAT32,
            name: `${this.name}/w`,
            initOp: wInitOp,
        });
        if (this.useBias) {
            const bInitVal = zerosInitializer(g, [this.filters], DType.FLOAT32);
            const { handle: bHandle, initOp: bInitOp } = variableWithInit(g, [this.filters], DType.FLOAT32, `${this.name}/b`, bInitVal);
            const bRead = readVariable(g, bHandle, DType.FLOAT32, `${this.name}/b_read`);
            out = biasAdd(g, out, bRead, `${this.name}/bias_add`);
            this.layerParams.push({
                handle: bHandle,
                read: bRead,
                dtype: DType.FLOAT32,
                name: `${this.name}/b`,
                initOp: bInitOp,
            });
        }
        this.output = activate(g, out, this.activation, this.name);
        return [null, outH, outW, this.filters];
    }
}
|
|
191
|
+
//# sourceMappingURL=layers.js.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"layers.js","sourceRoot":"","sources":["../../src/ts/model/layers.ts"],"names":[],"mappings":"AACA,OAAO,EAAE,KAAK,EAAE,MAAM,gBAAgB,CAAC;AAGvC,OAAO,EACL,gBAAgB,EAChB,YAAY,EACZ,gBAAgB,EAChB,wBAAwB,EACxB,0BAA0B,GAC3B,MAAM,wBAAwB,CAAC;AAChC,OAAO,EAAE,QAAQ,EAAE,MAAM,qBAAqB,CAAC;AAC/C,OAAO,EAAE,MAAM,EAAE,OAAO,EAAE,MAAM,oBAAoB,CAAC;AACrD,OAAO,EACL,IAAI,EACJ,SAAS,EACT,KAAK,EACL,OAAO,EACP,MAAM,IAAI,QAAQ,EAClB,IAAI,EACJ,OAAO,EACP,UAAU,EACV,GAAG,EACH,IAAI,EACJ,KAAK,EACL,IAAI,GACL,MAAM,kBAAkB,CAAC;AAE1B,8EAA8E;AAC9E,oBAAoB;AACpB,8EAA8E;AAC9E,SAAS,QAAQ,CAAC,CAAQ,EAAE,CAAS,EAAE,EAAgB,EAAE,IAAY;IACnE,QAAQ,EAAE,EAAE,CAAC;QACX,KAAK,MAAM;YACT,OAAO,IAAI,CAAC,CAAC,EAAE,CAAC,EAAE,GAAG,IAAI,OAAO,CAAC,CAAC;QACpC,KAAK,YAAY;YACf,OAAO,SAAS,CAAC,CAAC,EAAE,CAAC,EAAE,GAAG,EAAE,GAAG,IAAI,aAAa,CAAC,CAAC;QACpD,KAAK,OAAO;YACV,OAAO,KAAK,CAAC,CAAC,EAAE,CAAC,EAAE,GAAG,IAAI,QAAQ,CAAC,CAAC;QACtC,KAAK,SAAS;YACZ,OAAO,OAAO,CAAC,CAAC,EAAE,CAAC,EAAE,GAAG,IAAI,UAAU,CAAC,CAAC;QAC1C,KAAK,MAAM;YACT,OAAO,IAAI,CAAC,CAAC,EAAE,CAAC,EAAE,GAAG,IAAI,OAAO,CAAC,CAAC;QACpC,KAAK,SAAS;YACZ,OAAO,OAAO,CAAC,CAAC,EAAE,CAAC,EAAE,GAAG,IAAI,UAAU,CAAC,CAAC;QAC1C,KAAK,aAAa;YAChB,OAAO,UAAU,CAAC,CAAC,EAAE,CAAC,EAAE,GAAG,IAAI,cAAc,CAAC,CAAC;QACjD,KAAK,KAAK;YACR,OAAO,GAAG,CAAC,CAAC,EAAE,CAAC,EAAE,GAAG,IAAI,MAAM,CAAC,CAAC;QAClC,KAAK,MAAM;YACT,OAAO,IAAI,CAAC,CAAC,EAAE,CAAC,EAAE,GAAG,IAAI,OAAO,CAAC,CAAC;QACpC,KAAK,OAAO;YACV,OAAO,KAAK,CAAC,CAAC,EAAE,CAAC,EAAE,GAAG,IAAI,QAAQ,CAAC,CAAC;QACtC,KAAK,MAAM;YACT,OAAO,IAAI,CAAC,CAAC,EAAE,CAAC,EAAE,GAAG,IAAI,OAAO,CAAC,CAAC;QACpC,KAAK,QAAQ;YACX,OAAO,CAAC,CAAC;QACX;YACE,MAAM,IAAI,KAAK,CAAC,uBAAuB,EAAE,EAAE,CAAC,CAAC;IACjD,CAAC;AACH,CAAC;AAED,8EAA8E;AAC9E,oEAAoE;AACpE,8EAA8E;AAC9E,MAAM,OAAO,KAAK;IACP,IAAI,CAAS;IACtB,MAAM,CAAU;IACP,WAAW,GAAiB,EAAE,CAAC;IAEvB,KAAK,CAAS;IACd,UAAU,CAAe;IACzB,OAAO,CAAU;IAElC,YACE,KAAa,EACb,UAII,EAAE;QAEN,IAAI,CAAC,KAAK,GAAG,KAAK,CAAC;QACnB,IAAI,CAAC,UAAU,GAAG,OAAO,CAAC,UAAU,IAAI,QAAQ,CAAC;QACjD,IAAI,CAAC,OAAO,GAAG,OAAO,CAAC,OAAO,IAAI,IAAI,CAAC;QACvC,IAAI,
CAAC,IAAI,GAAG,OAAO,CAAC,IAAI,IAAI,SAAS,KAAK,EAAE,CAAC;IAC/C,CAAC;IAED,KAAK,CACH,CAAQ,EACR,KAAa,EACb,UAA6B;QAE7B,MAAM,UAAU,GAAG,UAAU,CAAC,UAAU,CAAC,MAAM,GAAG,CAAC,CAAW,CAAC;QAC/D,IAAI,CAAC,UAAU,IAAI,UAAU,GAAG,CAAC;YAC/B,MAAM,IAAI,KAAK,CACb,UACE,IAAI,CAAC,IACP,wCAAwC,IAAI,CAAC,SAAS,CAAC,UAAU,CAAC,EAAE,CACrE,CAAC;QAEJ,wEAAwE;QACxE,MAAM,QAAQ,GAAG,wBAAwB,CACvC,CAAC,EACD,CAAC,UAAU,EAAE,IAAI,CAAC,KAAK,CAAC,EACxB,KAAK,CAAC,OAAO,EACb,GAAG,IAAI,CAAC,IAAI,WAAW,CACxB,CAAC;QACF,MAAM,EAAE,MAAM,EAAE,OAAO,EAAE,MAAM,EAAE,OAAO,EAAE,GAAG,gBAAgB,CAC3D,CAAC,EACD,CAAC,UAAU,EAAE,IAAI,CAAC,KAAK,CAAC,EACxB,KAAK,CAAC,OAAO,EACb,GAAG,IAAI,CAAC,IAAI,IAAI,EAChB,QAAQ,CACT,CAAC;QACF,MAAM,KAAK,GAAG,YAAY,CACxB,CAAC,EACD,OAAO,EACP,KAAK,CAAC,OAAO,EACb,GAAG,IAAI,CAAC,IAAI,SAAS,CACtB,CAAC;QAEF,IAAI,GAAG,GAAG,MAAM,CAAC,CAAC,EAAE,KAAK,EAAE,KAAK,EAAE,EAAE,EAAE,GAAG,IAAI,CAAC,IAAI,SAAS,CAAC,CAAC;QAC7D,IAAI,CAAC,WAAW,CAAC,IAAI,CAAC;YACpB,MAAM,EAAE,OAAO;YACf,IAAI,EAAE,KAAK;YACX,KAAK,EAAE,KAAK,CAAC,OAAO;YACpB,IAAI,EAAE,GAAG,IAAI,CAAC,IAAI,IAAI;YACtB,MAAM,EAAE,OAAO;SAChB,CAAC,CAAC;QAEH,yEAAyE;QACzE,IAAI,IAAI,CAAC,OAAO,EAAE,CAAC;YACjB,MAAM,QAAQ,GAAG,gBAAgB,CAAC,CAAC,EAAE,CAAC,IAAI,CAAC,KAAK,CAAC,EAAE,KAAK,CAAC,OAAO,CAAC,CAAC;YAClE,MAAM,EAAE,MAAM,EAAE,OAAO,EAAE,MAAM,EAAE,OAAO,EAAE,GAAG,gBAAgB,CAC3D,CAAC,EACD,CAAC,IAAI,CAAC,KAAK,CAAC,EACZ,KAAK,CAAC,OAAO,EACb,GAAG,IAAI,CAAC,IAAI,IAAI,EAChB,QAAQ,CACT,CAAC;YACF,MAAM,KAAK,GAAG,YAAY,CACxB,CAAC,EACD,OAAO,EACP,KAAK,CAAC,OAAO,EACb,GAAG,IAAI,CAAC,IAAI,SAAS,CACtB,CAAC;YACF,GAAG,GAAG,OAAO,CAAC,CAAC,EAAE,GAAG,EAAE,KAAK,EAAE,GAAG,IAAI,CAAC,IAAI,WAAW,CAAC,CAAC;YACtD,IAAI,CAAC,WAAW,CAAC,IAAI,CAAC;gBACpB,MAAM,EAAE,OAAO;gBACf,IAAI,EAAE,KAAK;gBACX,KAAK,EAAE,KAAK,CAAC,OAAO;gBACpB,IAAI,EAAE,GAAG,IAAI,CAAC,IAAI,IAAI;gBACtB,MAAM,EAAE,OAAO;aAChB,CAAC,CAAC;QACL,CAAC;QAED,IAAI,CAAC,MAAM,GAAG,QAAQ,CAAC,CAAC,EAAE,GAAG,EAAE,IAAI,CAAC,UAAU,EAAE,IAAI,CAAC,IAAI,CAAC,CAAC;QAC3D,OAAO,CAAC,GAAG,UAAU,CAAC,KAAK,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,EAAE,IAAI,CAAC,KAAK,CAAC,CAAC;IAClD,CAAC;CACF;AAED,8EAA8E;
AAC9E,+DAA+D;AAC/D,8EAA8E;AAC9E,MAAM,OAAO,OAAO;IACT,IAAI,CAAS;IACtB,MAAM,CAAU;IACP,WAAW,GAAiB,EAAE,CAAC;IAExC,YAAY,UAA6B,EAAE;QACzC,IAAI,CAAC,IAAI,GAAG,OAAO,CAAC,IAAI,IAAI,SAAS,CAAC;IACxC,CAAC;IAED,KAAK,CACH,CAAQ,EACR,KAAa,EACb,UAA6B;QAE7B,MAAM,WAAW,GAAG,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;QACxC,MAAM,UAAU,GAAG,WAAW,CAAC,IAAI,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,CAAC,KAAK,IAAI,CAAC,CAAC;QACvD,MAAM,QAAQ,GAAG,UAAU;YACzB,CAAC,CAAC,IAAI;YACN,CAAC,CAAC,WAAW,CAAC,MAAM,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAE,GAAG,CAAE,EAAE,CAAC,CAAC,CAAC;QAC7C,MAAM,OAAO,GAAG,QAAQ,IAAI,CAAC,CAAC,CAAC;QAE/B,MAAM,QAAQ,GAAG,MAAM,CAAC,WAAW,CAAC,CAAC,CAAC,CAAC;QACvC,QAAQ,CAAC,YAAY,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC;QAC7B,QAAQ,CAAC,YAAY,CAAC,OAAO,EAAE,CAAC,CAAC,CAAC;QAClC,MAAM,UAAU,GAAG,QAAQ,CACzB,CAAC,EACD,QAAQ,EACR,CAAC,CAAC,CAAC,EACH,KAAK,CAAC,KAAK,EACX,GAAG,IAAI,CAAC,IAAI,QAAQ,CACrB,CAAC;QACF,MAAM,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,KAAK,CACnB,SAAS,EACT,CAAC,KAAK,EAAE,UAAU,CAAC,EACnB,EAAE,EACF,GAAG,IAAI,CAAC,IAAI,UAAU,CACvB,CAAC;QACF,IAAI,CAAC,MAAM,GAAG,GAAG,CAAC;QAClB,OAAO,CAAC,IAAI,EAAE,QAAQ,CAAC,CAAC;IAC1B,CAAC;CACF;AAED,8EAA8E;AAC9E,4EAA4E;AAC5E,8EAA8E;AAC9E,MAAM,OAAO,MAAM;IACR,IAAI,CAAS;IACtB,MAAM,CAAU;IACP,WAAW,GAAiB,EAAE,CAAC;IAEvB,OAAO,CAAS;IAChB,UAAU,CAAmB;IAC7B,OAAO,CAAmC;IAC1C,OAAO,CAAmB;IAC1B,UAAU,CAAe;IACzB,OAAO,CAAU;IAElC,YACE,OAAe,EACf,UAOI,EAAE;QAEN,IAAI,CAAC,OAAO,GAAG,OAAO,CAAC;QACvB,IAAI,CAAC,OAAO,GAAG,OAAO,CAAC,OAAO,IAAI,MAAM,CAAC;QACzC,IAAI,CAAC,UAAU,GAAG,OAAO,CAAC,UAAU,IAAI,QAAQ,CAAC;QACjD,IAAI,CAAC,OAAO,GAAG,OAAO,CAAC,OAAO,IAAI,IAAI,CAAC;QACvC,IAAI,CAAC,IAAI,GAAG,OAAO,CAAC,IAAI,IAAI,UAAU,OAAO,GAAG,CAAC;QAEjD,MAAM,EAAE,GAAG,OAAO,CAAC,UAAU,IAAI,CAAC,CAAC;QACnC,IAAI,CAAC,UAAU,GAAG,KAAK,CAAC,OAAO,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAuB,CAAC,CAAC,CAAC,CAAC,EAAE,EAAE,EAAE,CAAC,CAAC;QAE1E,MAAM,EAAE,GAAG,OAAO,CAAC,OAAO,IAAI,CAAC,CAAC;QAChC,MAAM,CAAC,EAAE,EAAE,EAAE,CAAC,GAAG,KAAK,CAAC,OAAO,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAuB,CAAC,CAAC,CAAC,CAAC,EAAE,EAAE,EAAE,CAAC,CAAC;QACzE,IAAI,CAAC,OAAO,GAA
G,CAAC,CAAC,EAAE,EAAE,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC;IAChC,CAAC;IAED,KAAK,CACH,CAAQ,EACR,KAAa,EACb,UAA6B;QAE7B,IAAI,UAAU,CAAC,MAAM,KAAK,CAAC;YACzB,MAAM,IAAI,KAAK,CACb,WAAW,IAAI,CAAC,IAAI,+CAA+C,UAAU,CAAC,MAAM,EAAE,CACvF,CAAC;QAEJ,MAAM,UAAU,GAAG,UAAU,CAAC,CAAC,CAAW,CAAC;QAC3C,IAAI,CAAC,UAAU,IAAI,UAAU,GAAG,CAAC;YAC/B,MAAM,IAAI,KAAK,CACb,WAAW,IAAI,CAAC,IAAI,qCAAqC,IAAI,CAAC,SAAS,CACrE,UAAU,CACX,EAAE,CACJ,CAAC;QAEJ,MAAM,CAAC,EAAE,EAAE,EAAE,CAAC,GAAG,IAAI,CAAC,UAAU,CAAC;QACjC,MAAM,MAAM,GAAG,CAAC,EAAE,EAAE,EAAE,EAAE,UAAU,EAAE,IAAI,CAAC,OAAO,CAAC,CAAC;QAElD,qEAAqE;QACrE,MAAM,MAAM,GAAG,IAAI,CAAC,IAAI,CAAC,CAAC,GAAG,CAAC,EAAE,GAAG,EAAE,GAAG,UAAU,CAAC,CAAC,CAAC;QACrD,MAAM,QAAQ,GAAG,0BAA0B,CACzC,CAAC,EACD,MAAM,EACN,KAAK,CAAC,OAAO,EACb,EAAE,MAAM,EAAE,EACV,GAAG,IAAI,CAAC,IAAI,SAAS,CACtB,CAAC;QACF,MAAM,EAAE,MAAM,EAAE,OAAO,EAAE,MAAM,EAAE,OAAO,EAAE,GAAG,gBAAgB,CAC3D,CAAC,EACD,MAAM,EACN,KAAK,CAAC,OAAO,EACb,GAAG,IAAI,CAAC,IAAI,IAAI,EAChB,QAAQ,CACT,CAAC;QACF,MAAM,KAAK,GAAG,YAAY,CACxB,CAAC,EACD,OAAO,EACP,KAAK,CAAC,OAAO,EACb,GAAG,IAAI,CAAC,IAAI,SAAS,CACtB,CAAC;QAEF,IAAI,GAAG,GAAG,QAAQ,CAChB,CAAC,EACD,KAAK,EACL,KAAK,EACL,IAAI,CAAC,OAAO,EACZ,IAAI,CAAC,OAAO,EACZ,GAAG,IAAI,CAAC,IAAI,OAAO,CACpB,CAAC;QACF,IAAI,CAAC,WAAW,CAAC,IAAI,CAAC;YACpB,MAAM,EAAE,OAAO;YACf,IAAI,EAAE,KAAK;YACX,KAAK,EAAE,KAAK,CAAC,OAAO;YACpB,IAAI,EAAE,GAAG,IAAI,CAAC,IAAI,IAAI;YACtB,MAAM,EAAE,OAAO;SAChB,CAAC,CAAC;QAEH,IAAI,IAAI,CAAC,OAAO,EAAE,CAAC;YACjB,MAAM,QAAQ,GAAG,gBAAgB,CAAC,CAAC,EAAE,CAAC,IAAI,CAAC,OAAO,CAAC,EAAE,KAAK,CAAC,OAAO,CAAC,CAAC;YACpE,MAAM,EAAE,MAAM,EAAE,OAAO,EAAE,MAAM,EAAE,OAAO,EAAE,GAAG,gBAAgB,CAC3D,CAAC,EACD,CAAC,IAAI,CAAC,OAAO,CAAC,EACd,KAAK,CAAC,OAAO,EACb,GAAG,IAAI,CAAC,IAAI,IAAI,EAChB,QAAQ,CACT,CAAC;YACF,MAAM,KAAK,GAAG,YAAY,CACxB,CAAC,EACD,OAAO,EACP,KAAK,CAAC,OAAO,EACb,GAAG,IAAI,CAAC,IAAI,SAAS,CACtB,CAAC;YACF,GAAG,GAAG,OAAO,CAAC,CAAC,EAAE,GAAG,EAAE,KAAK,EAAE,GAAG,IAAI,CAAC,IAAI,WAAW,CAAC,CAAC;YACtD,IAAI,CAAC,WAAW,CAAC,IAAI,CAAC;gBACpB,MAAM,EAAE,OAAO;gBACf,IAAI,EAAE,KAAK;gBACX,KAAK,EAAE,KAAK,CAAC,OAAO;gBAC
pB,IAAI,EAAE,GAAG,IAAI,CAAC,IAAI,IAAI;gBACtB,MAAM,EAAE,OAAO;aAChB,CAAC,CAAC;QACL,CAAC;QAED,IAAI,CAAC,MAAM,GAAG,QAAQ,CAAC,CAAC,EAAE,GAAG,EAAE,IAAI,CAAC,UAAU,EAAE,IAAI,CAAC,IAAI,CAAC,CAAC;QAE3D,4BAA4B;QAC5B,MAAM,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,GAAG,UAAU,CAAC;QAC5B,MAAM,CAAC,EAAE,EAAE,EAAE,EAAE,CAAC,GAAG,IAAI,CAAC,OAAO,CAAC;QAChC,MAAM,IAAI,GACR,CAAC,KAAK,IAAI;YACR,CAAC,CAAC,IAAI;YACN,CAAC,CAAC,IAAI,CAAC,OAAO,KAAK,MAAM;gBACzB,CAAC,CAAC,IAAI,CAAC,IAAI,CAAC,CAAC,GAAG,EAAE,CAAC;gBACnB,CAAC,CAAC,IAAI,CAAC,IAAI,CAAC,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,GAAG,EAAE,CAAC,CAAC;QACnC,MAAM,IAAI,GACR,CAAC,KAAK,IAAI;YACR,CAAC,CAAC,IAAI;YACN,CAAC,CAAC,IAAI,CAAC,OAAO,KAAK,MAAM;gBACzB,CAAC,CAAC,IAAI,CAAC,IAAI,CAAC,CAAC,GAAG,EAAE,CAAC;gBACnB,CAAC,CAAC,IAAI,CAAC,IAAI,CAAC,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,GAAG,EAAE,CAAC,CAAC;QAEnC,OAAO,CAAC,IAAI,EAAE,IAAI,EAAE,IAAI,EAAE,IAAI,CAAC,OAAO,CAAC,CAAC;IAC1C,CAAC;CACF"}
|
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
import type { Tensor } from "@isidorus/core";
|
|
2
|
+
import { DType } from "@isidorus/core";
|
|
3
|
+
import type { Graph } from "../graph.js";
|
|
4
|
+
import type { Session } from "../session.js";
|
|
5
|
+
import type { Layer } from "./layer.js";
|
|
6
|
+
import type { ParamSpec, FeedEntry } from "../optimizers/sgd.js";
|
|
7
|
+
/** Loss functions accepted by `Sequential.compile`. */
export type LossFn =
    | "sparse_categorical_crossentropy"
    | "binary_crossentropy"
    | "mse";
|
|
8
|
+
/**
 * Minimal optimizer contract consumed by `Sequential.init` / `Sequential.trainStep`.
 */
export interface Optimizer {
    /** Initialise the optimizer's own state on `sess` (called from `Sequential.init`). */
    init(sess: Session): Promise<void>;
    /**
     * Run one weight update on `sess`.
     * `feeds` are the placeholder feeds the update run needs — presumably the
     * current batch's inputs and labels; confirm against the caller.
     */
    applyGradients(sess: Session, feeds: FeedEntry[]): Promise<void>;
}
|
|
12
|
+
/** Outcome of a single `Sequential.trainStep` call. */
export interface TrainStepResult {
    /** Loss value fetched for the batch. */
    loss: number;
}
|
|
15
|
+
export declare class Sequential {
    private readonly g;
    private readonly layers;
    private _xPlaceholder;
    private _yPlaceholder;
    private _outputTensor;
    private _lossTensor;
    private _allParams;
    private _allInitOp;
    private _labelDtype;
    private compiled;
    /**
     * @param g      Graph the model's ops are added to.
     * @param layers Layers that make up the model, in order.
     */
    constructor(g: Graph, layers: Layer[]);
    /**
     * compile — wire the full computation graph.
     *
     * After this returns, model.params contains ParamSpec entries with
     * real gradient tensors from g.addGradients(). Use those to construct
     * your optimizer before calling init().
     *
     * @param opts.loss       Loss function attached to the model output.
     * @param opts.inputShape Per-example input shape, excluding the batch
     *                        dimension (trainStep's xShape is [batchSize, ...inputShape]).
     * @param opts.labelDtype Label dtype; trainStep uses it when its own
     *                        `labelDtype` argument is omitted.
     */
    compile(opts: {
        loss: LossFn;
        inputShape: number[];
        labelDtype?: DType;
    }): void;
    /**
     * init — run all variable initialisations and the optimizer's state init.
     *
     * @param sess The session to run on
     * @param opt  The optimizer (must be constructed from model.params)
     */
    init(sess: Session, opt: Optimizer): Promise<void>;
    /**
     * trainStep — one forward pass, gradient computation, and weight update.
     *
     * Two sequential TF_SessionRun calls:
     *   1. Fetch the loss (runs the forward pass).
     *   2. Run the optimizer update ops (gradients + weight update).
     *
     * Keeping them separate avoids the optimizer needing to expose its
     * internal step op name. The cost is not just one extra C++ call: the
     * two runs are independent, so step 2 presumably has to recompute the
     * forward activations that its gradients depend on. Expect roughly one
     * extra forward pass per step, which is significant for conv-heavy models.
     * The returned loss is computed with the weights from *before* this
     * step's update.
     *
     * @param sess       Session to run on
     * @param opt        Optimizer (same instance used in init)
     * @param xBuf       Float32 input bytes [batchSize, ...inputShape]
     * @param yBuf       Label bytes (INT32 class indices or FLOAT32 values)
     * @param xShape     [batchSize, ...inputShape]
     * @param yShape     [batchSize] for classification, [batchSize, units] for mse
     * @param labelDtype DType for labels — defaults to what was set in compile()
     */
    trainStep(sess: Session, opt: Optimizer, xBuf: Buffer, yBuf: Buffer, xShape: number[], yShape: number[], labelDtype?: DType): Promise<TrainStepResult>;
    /**
     * predict — forward pass only, no gradient computation or update.
     *
     * @param sess   Session to run on
     * @param xBuf   Float32 input bytes [batchSize, ...inputShape]
     * @param xShape [batchSize, ...inputShape]
     */
    predict(sess: Session, xBuf: Buffer, xShape: number[]): Promise<{
        data: Buffer;
        shape: number[];
        dtype: DType;
    }>;
    /** Input placeholder tensor — available after compile(). */
    get xPlaceholder(): Tensor;
    /** Label placeholder tensor — available after compile(). */
    get yPlaceholder(): Tensor;
    /** Final layer output tensor — available after compile(). */
    get output(): Tensor;
    /** Scalar mean loss tensor — available after compile(). */
    get loss(): Tensor;
    /**
     * All parameter specs with real gradient tensors.
     * Available after compile(). Use these to construct your optimizer.
     */
    get params(): ParamSpec[];
    /** Label dtype resolved during compile(). */
    get labelDtype(): DType;
    private assertCompiled;
}
|
|
91
|
+
//# sourceMappingURL=sequential.d.ts.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"sequential.d.ts","sourceRoot":"","sources":["../../src/ts/model/sequential.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,EAAE,MAAM,gBAAgB,CAAC;AAC7C,OAAO,EAAE,KAAK,EAAE,MAAM,gBAAgB,CAAC;AACvC,OAAO,KAAK,EAAE,KAAK,EAAE,MAAM,aAAa,CAAC;AACzC,OAAO,KAAK,EAAE,OAAO,EAAE,MAAM,eAAe,CAAC;AAC7C,OAAO,KAAK,EAAE,KAAK,EAAE,MAAM,YAAY,CAAC;AACxC,OAAO,KAAK,EAAE,SAAS,EAAE,SAAS,EAAE,MAAM,sBAAsB,CAAC;AASjE,MAAM,MAAM,MAAM,GACd,iCAAiC,GACjC,qBAAqB,GACrB,KAAK,CAAC;AAKV,MAAM,WAAW,SAAS;IACxB,IAAI,CAAC,IAAI,EAAE,OAAO,GAAG,OAAO,CAAC,IAAI,CAAC,CAAC;IACnC,cAAc,CAAC,IAAI,EAAE,OAAO,EAAE,KAAK,EAAE,SAAS,EAAE,GAAG,OAAO,CAAC,IAAI,CAAC,CAAC;CAClE;AAED,MAAM,WAAW,eAAe;IAC9B,IAAI,EAAE,MAAM,CAAC;CACd;AAqCD,qBAAa,UAAU;IACrB,OAAO,CAAC,QAAQ,CAAC,CAAC,CAAQ;IAC1B,OAAO,CAAC,QAAQ,CAAC,MAAM,CAAU;IAGjC,OAAO,CAAC,aAAa,CAAU;IAC/B,OAAO,CAAC,aAAa,CAAU;IAC/B,OAAO,CAAC,aAAa,CAAU;IAC/B,OAAO,CAAC,WAAW,CAAU;IAC7B,OAAO,CAAC,UAAU,CAAe;IACjC,OAAO,CAAC,UAAU,CAAU;IAC5B,OAAO,CAAC,WAAW,CAAS;IAC5B,OAAO,CAAC,QAAQ,CAAS;gBAEb,CAAC,EAAE,KAAK,EAAE,MAAM,EAAE,KAAK,EAAE;IAKrC;;;;;;OAMG;IACH,OAAO,CAAC,IAAI,EAAE;QACZ,IAAI,EAAE,MAAM,CAAC;QACb,UAAU,EAAE,MAAM,EAAE,CAAC;QACrB,UAAU,CAAC,EAAE,KAAK,CAAC;KACpB,GAAG,IAAI;IAmJR;;;;;OAKG;IACG,IAAI,CAAC,IAAI,EAAE,OAAO,EAAE,GAAG,EAAE,SAAS,GAAG,OAAO,CAAC,IAAI,CAAC;IAMxD;;;;;;;;;;;;;;;;;;OAkBG;IACG,SAAS,CACb,IAAI,EAAE,OAAO,EACb,GAAG,EAAE,SAAS,EACd,IAAI,EAAE,MAAM,EACZ,IAAI,EAAE,MAAM,EACZ,MAAM,EAAE,MAAM,EAAE,EAChB,MAAM,EAAE,MAAM,EAAE,EAChB,UAAU,CAAC,EAAE,KAAK,GACjB,OAAO,CAAC,eAAe,CAAC;IAuB3B;;OAEG;IACG,OAAO,CACX,IAAI,EAAE,OAAO,EACb,IAAI,EAAE,MAAM,EACZ,MAAM,EAAE,MAAM,EAAE,GACf,OAAO,CAAC;QAAE,IAAI,EAAE,MAAM,CAAC;QAAC,KAAK,EAAE,MAAM,EAAE,CAAC;QAAC,KAAK,EAAE,KAAK,CAAA;KAAE,CAAC;IAgB3D,4DAA4D;IAC5D,IAAI,YAAY,IAAI,MAAM,CAGzB;IAED,4DAA4D;IAC5D,IAAI,YAAY,IAAI,MAAM,CAGzB;IAED,6DAA6D;IAC7D,IAAI,MAAM,IAAI,MAAM,CAGnB;IAED,2DAA2D;IAC3D,IAAI,IAAI,IAAI,MAAM,CAGjB;IAED;;;OAGG;IACH,IAAI,MAAM,IAAI,SAAS,EAAE,CAGxB;IAED,6CAA6C;IAC7C,IAAI,UAAU,IAAI,KAAK,CAGtB;IAED,OAAO,CAAC,cAAc;CAIvB"}
|