deepbox 0.1.0 → 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +1 -1
- package/README.md +39 -37
- package/dist/{CSRMatrix-KzNt6QpS.d.ts → CSRMatrix-B7XtUAZO.d.cts} +3 -3
- package/dist/{CSRMatrix-CwGwQRea.d.cts → CSRMatrix-CtD23fRM.d.ts} +3 -3
- package/dist/{Tensor-BQLk1ltW.d.cts → Tensor-BORFp_zt.d.ts} +27 -7
- package/dist/{Tensor-g8mUClel.d.ts → Tensor-fxBg-TFZ.d.cts} +27 -7
- package/dist/{chunk-FJYLIGJX.js → chunk-3AX37GPK.js} +33 -7
- package/dist/chunk-3AX37GPK.js.map +1 -0
- package/dist/{chunk-PR647I7R.js → chunk-3YFEYDHN.js} +21 -4
- package/dist/chunk-3YFEYDHN.js.map +1 -0
- package/dist/{chunk-XMWVME2W.js → chunk-6SX26MAJ.js} +4 -4
- package/dist/{chunk-XMWVME2W.js.map → chunk-6SX26MAJ.js.map} +1 -1
- package/dist/{chunk-C4PKXY74.cjs → chunk-6X7XFNDO.cjs} +94 -77
- package/dist/chunk-6X7XFNDO.cjs.map +1 -0
- package/dist/{chunk-6AE5FKKQ.cjs → chunk-724CXHFH.cjs} +1211 -919
- package/dist/chunk-724CXHFH.cjs.map +1 -0
- package/dist/{chunk-AU7XHGKJ.js → chunk-AJTKVBY5.js} +4 -4
- package/dist/{chunk-AU7XHGKJ.js.map → chunk-AJTKVBY5.js.map} +1 -1
- package/dist/{chunk-ZB75FESB.cjs → chunk-AV6WGSYX.cjs} +130 -104
- package/dist/chunk-AV6WGSYX.cjs.map +1 -0
- package/dist/{chunk-ZLW62TJG.cjs → chunk-BWOSU234.cjs} +142 -141
- package/dist/chunk-BWOSU234.cjs.map +1 -0
- package/dist/{chunk-4S73VUBD.js → chunk-CZOMBUI7.js} +3 -3
- package/dist/chunk-CZOMBUI7.js.map +1 -0
- package/dist/{chunk-QERHVCHC.cjs → chunk-EUZHJDZ6.cjs} +419 -364
- package/dist/chunk-EUZHJDZ6.cjs.map +1 -0
- package/dist/{chunk-AD436M45.js → chunk-G2G55ATL.js} +120 -58
- package/dist/chunk-G2G55ATL.js.map +1 -0
- package/dist/{chunk-5R4S63PF.js → chunk-G3WNLNYS.js} +119 -64
- package/dist/chunk-G3WNLNYS.js.map +1 -0
- package/dist/{chunk-XEG44RF6.cjs → chunk-G7KXZHG6.cjs} +105 -95
- package/dist/chunk-G7KXZHG6.cjs.map +1 -0
- package/dist/{chunk-MLBMYKCG.js → chunk-H3JR7SV2.js} +255 -113
- package/dist/chunk-H3JR7SV2.js.map +1 -0
- package/dist/{chunk-PHV2DKRS.cjs → chunk-HDKMIG6E.cjs} +107 -107
- package/dist/{chunk-PHV2DKRS.cjs.map → chunk-HDKMIG6E.cjs.map} +1 -1
- package/dist/{chunk-ALS7ETWZ.cjs → chunk-HI2EZHCJ.cjs} +111 -102
- package/dist/chunk-HI2EZHCJ.cjs.map +1 -0
- package/dist/{chunk-OX6QXFMV.cjs → chunk-IT4BZUYE.cjs} +490 -428
- package/dist/chunk-IT4BZUYE.cjs.map +1 -0
- package/dist/{chunk-E3EU5FZO.cjs → chunk-JTZPRV6E.cjs} +123 -123
- package/dist/{chunk-E3EU5FZO.cjs.map → chunk-JTZPRV6E.cjs.map} +1 -1
- package/dist/{chunk-PL7TAYKI.js → chunk-K2L5C5YH.js} +8 -7
- package/dist/chunk-K2L5C5YH.js.map +1 -0
- package/dist/{chunk-BCR7G3A6.js → chunk-KCF6P34A.js} +356 -64
- package/dist/chunk-KCF6P34A.js.map +1 -0
- package/dist/{chunk-ZXKBDFP3.js → chunk-LZHVHD62.js} +15 -6
- package/dist/chunk-LZHVHD62.js.map +1 -0
- package/dist/{chunk-LWECRCW2.cjs → chunk-MTJF52AJ.cjs} +141 -141
- package/dist/{chunk-LWECRCW2.cjs.map → chunk-MTJF52AJ.cjs.map} +1 -1
- package/dist/{chunk-B5TNKUEY.js → chunk-NDDTUFKK.js} +16 -6
- package/dist/chunk-NDDTUFKK.js.map +1 -0
- package/dist/{chunk-DWZY6PIP.cjs → chunk-NOQI6OFL.cjs} +615 -473
- package/dist/chunk-NOQI6OFL.cjs.map +1 -0
- package/dist/{chunk-F3JWBINJ.js → chunk-OEXDJFHA.js} +4 -4
- package/dist/{chunk-F3JWBINJ.js.map → chunk-OEXDJFHA.js.map} +1 -1
- package/dist/{chunk-JSCDE774.cjs → chunk-Z6BGACIH.cjs} +3 -3
- package/dist/chunk-Z6BGACIH.cjs.map +1 -0
- package/dist/core/index.cjs +50 -50
- package/dist/core/index.d.cts +2 -2
- package/dist/core/index.d.ts +2 -2
- package/dist/core/index.js +1 -1
- package/dist/dataframe/index.cjs +6 -6
- package/dist/dataframe/index.d.cts +3 -3
- package/dist/dataframe/index.d.ts +3 -3
- package/dist/dataframe/index.js +3 -3
- package/dist/datasets/index.cjs +34 -34
- package/dist/datasets/index.d.cts +3 -3
- package/dist/datasets/index.d.ts +3 -3
- package/dist/datasets/index.js +3 -3
- package/dist/{index-C1mfVYoo.d.ts → index-B18dHc8q.d.ts} +81 -46
- package/dist/{index-GFAVyOWO.d.ts → index-BHHX0qTY.d.cts} +14 -12
- package/dist/{index-tk4lSYod.d.ts → index-BI6QOUvV.d.ts} +106 -80
- package/dist/{index-DIp_RrRt.d.ts → index-BKvK21lf.d.ts} +13 -35
- package/dist/{index-BJY2SI4i.d.ts → index-BL8jLf3K.d.cts} +12 -11
- package/dist/{index-Cn3SdB0O.d.ts → index-BNbX167d.d.cts} +16 -10
- package/dist/{index-BWGhrDlr.d.ts → index-BT2ofL7Z.d.cts} +35 -35
- package/dist/{index-BbA2Gxfl.d.ts → index-BqcfIcL4.d.ts} +15 -15
- package/dist/{index-ZtI1Iy4L.d.ts → index-BrgrECM2.d.ts} +41 -38
- package/dist/{index-CDw5CnOU.d.ts → index-BtYKI9yJ.d.ts} +10 -8
- package/dist/{index-DIT_OO9C.d.cts → index-C7nLsAOC.d.cts} +10 -8
- package/dist/{index-D9Loo1_A.d.cts → index-CNj2Mxwf.d.cts} +81 -46
- package/dist/{index-DmEg_LCm.d.cts → index-CYlxeNW1.d.cts} +5 -3
- package/dist/{index-D61yaSMY.d.cts → index-CiTd61a5.d.ts} +12 -11
- package/dist/{index-BndMbqsM.d.ts → index-Cjnn0KeN.d.cts} +35 -21
- package/dist/{index-9oQx1HgV.d.cts → index-CkGGAn69.d.cts} +41 -38
- package/dist/{index-74AB8Cyh.d.cts → index-D4URSgqA.d.ts} +16 -10
- package/dist/{index-DoPWVxPo.d.cts → index-D4pn5zLT.d.ts} +35 -21
- package/dist/{index-DuCxd-8d.d.ts → index-D9ztTlDr.d.ts} +60 -42
- package/dist/{index-BgHYAoSS.d.cts → index-DF28ZPB5.d.cts} +60 -42
- package/dist/{index-eJgeni9c.d.cts → index-DLdiQzf0.d.cts} +106 -80
- package/dist/{index-WHQLn0e8.d.cts → index-DN4omPQw.d.ts} +35 -35
- package/dist/{index-CrqLlS-a.d.ts → index-DUnFq1WV.d.ts} +5 -3
- package/dist/{index-DbultU6X.d.cts → index-DgaYshkF.d.ts} +14 -12
- package/dist/{index-B_DK4FKY.d.cts → index-GUHYEhxs.d.cts} +13 -35
- package/dist/{index-CCvlwAmL.d.cts → index-TP--4irE.d.cts} +16 -14
- package/dist/{index-Dx42TZaY.d.ts → index-x0z_sanT.d.ts} +16 -14
- package/dist/{index-DyZ4QQf5.d.cts → index-xWH7ujWa.d.cts} +15 -15
- package/dist/index.cjs +26 -26
- package/dist/index.d.cts +17 -17
- package/dist/index.d.ts +17 -17
- package/dist/index.js +13 -13
- package/dist/linalg/index.cjs +22 -22
- package/dist/linalg/index.d.cts +3 -3
- package/dist/linalg/index.d.ts +3 -3
- package/dist/linalg/index.js +3 -3
- package/dist/metrics/index.cjs +40 -40
- package/dist/metrics/index.d.cts +3 -3
- package/dist/metrics/index.d.ts +3 -3
- package/dist/metrics/index.js +3 -3
- package/dist/ml/index.cjs +23 -23
- package/dist/ml/index.d.cts +3 -3
- package/dist/ml/index.d.ts +3 -3
- package/dist/ml/index.js +4 -4
- package/dist/ndarray/index.cjs +125 -125
- package/dist/ndarray/index.d.cts +5 -5
- package/dist/ndarray/index.d.ts +5 -5
- package/dist/ndarray/index.js +2 -2
- package/dist/nn/index.cjs +36 -36
- package/dist/nn/index.d.cts +6 -6
- package/dist/nn/index.d.ts +6 -6
- package/dist/nn/index.js +3 -3
- package/dist/optim/index.cjs +19 -19
- package/dist/optim/index.d.cts +4 -4
- package/dist/optim/index.d.ts +4 -4
- package/dist/optim/index.js +2 -2
- package/dist/plot/index.cjs +29 -29
- package/dist/plot/index.d.cts +6 -6
- package/dist/plot/index.d.ts +6 -6
- package/dist/plot/index.js +3 -3
- package/dist/preprocess/index.cjs +21 -21
- package/dist/preprocess/index.d.cts +4 -4
- package/dist/preprocess/index.d.ts +4 -4
- package/dist/preprocess/index.js +3 -3
- package/dist/random/index.cjs +19 -19
- package/dist/random/index.d.cts +3 -3
- package/dist/random/index.d.ts +3 -3
- package/dist/random/index.js +3 -3
- package/dist/stats/index.cjs +36 -36
- package/dist/stats/index.d.cts +3 -3
- package/dist/stats/index.d.ts +3 -3
- package/dist/stats/index.js +3 -3
- package/dist/{tensor-B96jjJLQ.d.cts → tensor-IlVTF0bz.d.cts} +16 -3
- package/dist/{tensor-B96jjJLQ.d.ts → tensor-IlVTF0bz.d.ts} +16 -3
- package/package.json +3 -2
- package/dist/chunk-4S73VUBD.js.map +0 -1
- package/dist/chunk-5R4S63PF.js.map +0 -1
- package/dist/chunk-6AE5FKKQ.cjs.map +0 -1
- package/dist/chunk-AD436M45.js.map +0 -1
- package/dist/chunk-ALS7ETWZ.cjs.map +0 -1
- package/dist/chunk-B5TNKUEY.js.map +0 -1
- package/dist/chunk-BCR7G3A6.js.map +0 -1
- package/dist/chunk-C4PKXY74.cjs.map +0 -1
- package/dist/chunk-DWZY6PIP.cjs.map +0 -1
- package/dist/chunk-FJYLIGJX.js.map +0 -1
- package/dist/chunk-JSCDE774.cjs.map +0 -1
- package/dist/chunk-MLBMYKCG.js.map +0 -1
- package/dist/chunk-OX6QXFMV.cjs.map +0 -1
- package/dist/chunk-PL7TAYKI.js.map +0 -1
- package/dist/chunk-PR647I7R.js.map +0 -1
- package/dist/chunk-QERHVCHC.cjs.map +0 -1
- package/dist/chunk-XEG44RF6.cjs.map +0 -1
- package/dist/chunk-ZB75FESB.cjs.map +0 -1
- package/dist/chunk-ZLW62TJG.cjs.map +0 -1
- package/dist/chunk-ZXKBDFP3.js.map +0 -1
|
@@ -1,11 +1,11 @@
|
|
|
1
1
|
'use strict';
|
|
2
2
|
|
|
3
|
-
var
|
|
4
|
-
var
|
|
3
|
+
var chunk724CXHFH_cjs = require('./chunk-724CXHFH.cjs');
|
|
4
|
+
var chunkZ6BGACIH_cjs = require('./chunk-Z6BGACIH.cjs');
|
|
5
5
|
|
|
6
6
|
// src/nn/index.ts
|
|
7
7
|
var nn_exports = {};
|
|
8
|
-
|
|
8
|
+
chunkZ6BGACIH_cjs.__export(nn_exports, {
|
|
9
9
|
AvgPool2d: () => AvgPool2d,
|
|
10
10
|
BatchNorm1d: () => BatchNorm1d,
|
|
11
11
|
Conv1d: () => Conv1d,
|
|
@@ -53,7 +53,7 @@ function sizeFromShape(shape, context) {
|
|
|
53
53
|
let size = 1;
|
|
54
54
|
for (const dim of shape) {
|
|
55
55
|
if (!Number.isInteger(dim) || dim < 0) {
|
|
56
|
-
throw new
|
|
56
|
+
throw new chunkZ6BGACIH_cjs.ShapeError(`${context} contains invalid dimension ${String(dim)}`);
|
|
57
57
|
}
|
|
58
58
|
size *= dim;
|
|
59
59
|
}
|
|
@@ -71,7 +71,7 @@ function cloneTensorData(t) {
|
|
|
71
71
|
for (let i = 0; i < data.length; i++) {
|
|
72
72
|
const value = data[i];
|
|
73
73
|
if (value === void 0) {
|
|
74
|
-
throw new
|
|
74
|
+
throw new chunkZ6BGACIH_cjs.DeepboxError("Internal error: tensor data access out of bounds");
|
|
75
75
|
}
|
|
76
76
|
out[i] = value;
|
|
77
77
|
}
|
|
@@ -80,35 +80,35 @@ function cloneTensorData(t) {
|
|
|
80
80
|
function validateStateEntryShape(name, kind, entry) {
|
|
81
81
|
const size = sizeFromShape(entry.shape, `${kind} ${name} shape`);
|
|
82
82
|
if (entry.data.length !== size) {
|
|
83
|
-
throw new
|
|
83
|
+
throw new chunkZ6BGACIH_cjs.ShapeError(
|
|
84
84
|
`${kind} ${name} data length ${entry.data.length} does not match shape size ${size}`
|
|
85
85
|
);
|
|
86
86
|
}
|
|
87
87
|
}
|
|
88
88
|
function copyStateEntryIntoTensor(name, kind, target, entry) {
|
|
89
89
|
if (!shapesEqual(target.shape, entry.shape)) {
|
|
90
|
-
throw new
|
|
90
|
+
throw new chunkZ6BGACIH_cjs.ShapeError(
|
|
91
91
|
`${kind} ${name} shape mismatch: expected [${target.shape.join(", ")}], got [${entry.shape.join(", ")}]`
|
|
92
92
|
);
|
|
93
93
|
}
|
|
94
94
|
if (target.dtype !== entry.dtype) {
|
|
95
|
-
throw new
|
|
95
|
+
throw new chunkZ6BGACIH_cjs.DTypeError(
|
|
96
96
|
`${kind} ${name} dtype mismatch: expected ${target.dtype}, got ${entry.dtype}`
|
|
97
97
|
);
|
|
98
98
|
}
|
|
99
99
|
const size = sizeFromShape(entry.shape, `${kind} ${name} shape`);
|
|
100
|
-
const logicalStrides =
|
|
100
|
+
const logicalStrides = chunk724CXHFH_cjs.computeStrides(target.shape);
|
|
101
101
|
const data = target.data;
|
|
102
102
|
if (target.dtype === "string") {
|
|
103
103
|
if (!Array.isArray(data)) {
|
|
104
|
-
throw new
|
|
104
|
+
throw new chunkZ6BGACIH_cjs.DTypeError(`${kind} ${name} expected string data`);
|
|
105
105
|
}
|
|
106
106
|
for (let i = 0; i < size; i++) {
|
|
107
107
|
const value = entry.data[i];
|
|
108
108
|
if (typeof value !== "string") {
|
|
109
|
-
throw new
|
|
109
|
+
throw new chunkZ6BGACIH_cjs.DTypeError(`${kind} ${name} expects string data`);
|
|
110
110
|
}
|
|
111
|
-
const offset =
|
|
111
|
+
const offset = chunk724CXHFH_cjs.offsetFromFlatIndex(i, logicalStrides, target.strides, target.offset);
|
|
112
112
|
data[offset] = value;
|
|
113
113
|
}
|
|
114
114
|
return;
|
|
@@ -117,22 +117,22 @@ function copyStateEntryIntoTensor(name, kind, target, entry) {
|
|
|
117
117
|
for (let i = 0; i < size; i++) {
|
|
118
118
|
const value = entry.data[i];
|
|
119
119
|
if (typeof value !== "bigint") {
|
|
120
|
-
throw new
|
|
120
|
+
throw new chunkZ6BGACIH_cjs.DTypeError(`${kind} ${name} expects bigint data`);
|
|
121
121
|
}
|
|
122
|
-
const offset =
|
|
122
|
+
const offset = chunk724CXHFH_cjs.offsetFromFlatIndex(i, logicalStrides, target.strides, target.offset);
|
|
123
123
|
data[offset] = value;
|
|
124
124
|
}
|
|
125
125
|
return;
|
|
126
126
|
}
|
|
127
127
|
if (Array.isArray(data)) {
|
|
128
|
-
throw new
|
|
128
|
+
throw new chunkZ6BGACIH_cjs.DTypeError(`${kind} ${name} expected numeric data`);
|
|
129
129
|
}
|
|
130
130
|
for (let i = 0; i < size; i++) {
|
|
131
131
|
const value = entry.data[i];
|
|
132
132
|
if (typeof value !== "number") {
|
|
133
|
-
throw new
|
|
133
|
+
throw new chunkZ6BGACIH_cjs.DTypeError(`${kind} ${name} expects numeric data`);
|
|
134
134
|
}
|
|
135
|
-
const offset =
|
|
135
|
+
const offset = chunk724CXHFH_cjs.offsetFromFlatIndex(i, logicalStrides, target.strides, target.offset);
|
|
136
136
|
data[offset] = value;
|
|
137
137
|
}
|
|
138
138
|
}
|
|
@@ -413,7 +413,7 @@ var Module = class _Module {
|
|
|
413
413
|
const resolved = this.resolveModuleAndName(name);
|
|
414
414
|
if (!resolved) {
|
|
415
415
|
if (providedNames) {
|
|
416
|
-
throw new
|
|
416
|
+
throw new chunkZ6BGACIH_cjs.InvalidParameterError(`Unknown parameter name: ${name}`, "names", name);
|
|
417
417
|
}
|
|
418
418
|
continue;
|
|
419
419
|
}
|
|
@@ -421,11 +421,11 @@ var Module = class _Module {
|
|
|
421
421
|
const param = module._parameters.get(localName);
|
|
422
422
|
if (!param) {
|
|
423
423
|
if (providedNames) {
|
|
424
|
-
throw new
|
|
424
|
+
throw new chunkZ6BGACIH_cjs.InvalidParameterError(`Unknown parameter name: ${name}`, "names", name);
|
|
425
425
|
}
|
|
426
426
|
continue;
|
|
427
427
|
}
|
|
428
|
-
const nextParam =
|
|
428
|
+
const nextParam = chunk724CXHFH_cjs.GradTensor.fromTensor(param.tensor, { requiresGrad });
|
|
429
429
|
module._parameters.set(localName, nextParam);
|
|
430
430
|
for (const [key, value] of Object.entries(module)) {
|
|
431
431
|
if (value === param) {
|
|
@@ -448,7 +448,7 @@ var Module = class _Module {
|
|
|
448
448
|
}
|
|
449
449
|
static setTensorDeviceMetadata(target, device) {
|
|
450
450
|
if (!Reflect.set(target, "device", device)) {
|
|
451
|
-
throw new
|
|
451
|
+
throw new chunkZ6BGACIH_cjs.DeepboxError("Failed to update tensor device metadata");
|
|
452
452
|
}
|
|
453
453
|
}
|
|
454
454
|
/**
|
|
@@ -486,17 +486,17 @@ var Module = class _Module {
|
|
|
486
486
|
const namedBuffs = new Map(this.namedBuffers());
|
|
487
487
|
for (const name of namedParams.keys()) {
|
|
488
488
|
if (!(name in parameters)) {
|
|
489
|
-
throw new
|
|
489
|
+
throw new chunkZ6BGACIH_cjs.InvalidParameterError(`missing parameter: ${name}`, "stateDict.parameters", name);
|
|
490
490
|
}
|
|
491
491
|
}
|
|
492
492
|
for (const name of namedBuffs.keys()) {
|
|
493
493
|
if (!(name in buffers)) {
|
|
494
|
-
throw new
|
|
494
|
+
throw new chunkZ6BGACIH_cjs.InvalidParameterError(`missing buffer: ${name}`, "stateDict.buffers", name);
|
|
495
495
|
}
|
|
496
496
|
}
|
|
497
497
|
for (const name of Object.keys(parameters)) {
|
|
498
498
|
if (!namedParams.has(name)) {
|
|
499
|
-
throw new
|
|
499
|
+
throw new chunkZ6BGACIH_cjs.InvalidParameterError(
|
|
500
500
|
`unexpected parameter: ${name}`,
|
|
501
501
|
"stateDict.parameters",
|
|
502
502
|
name
|
|
@@ -505,7 +505,7 @@ var Module = class _Module {
|
|
|
505
505
|
}
|
|
506
506
|
for (const name of Object.keys(buffers)) {
|
|
507
507
|
if (!namedBuffs.has(name)) {
|
|
508
|
-
throw new
|
|
508
|
+
throw new chunkZ6BGACIH_cjs.InvalidParameterError(`unexpected buffer: ${name}`, "stateDict.buffers", name);
|
|
509
509
|
}
|
|
510
510
|
}
|
|
511
511
|
for (const [name, entry] of Object.entries(parameters)) {
|
|
@@ -542,8 +542,8 @@ var Module = class _Module {
|
|
|
542
542
|
* ```
|
|
543
543
|
*/
|
|
544
544
|
to(device) {
|
|
545
|
-
if (!
|
|
546
|
-
throw new
|
|
545
|
+
if (!chunkZ6BGACIH_cjs.isDevice(device)) {
|
|
546
|
+
throw new chunkZ6BGACIH_cjs.InvalidParameterError("device must be one of: cpu, webgpu, wasm", "device", device);
|
|
547
547
|
}
|
|
548
548
|
for (const param of this.parameters()) {
|
|
549
549
|
_Module.setTensorDeviceMetadata(param.tensor, device);
|
|
@@ -613,7 +613,7 @@ var Sequential = class extends Module {
|
|
|
613
613
|
constructor(...layers) {
|
|
614
614
|
super();
|
|
615
615
|
if (layers.length === 0) {
|
|
616
|
-
throw new
|
|
616
|
+
throw new chunkZ6BGACIH_cjs.InvalidParameterError(
|
|
617
617
|
"Sequential requires at least one layer",
|
|
618
618
|
"layers",
|
|
619
619
|
layers.length
|
|
@@ -623,7 +623,7 @@ var Sequential = class extends Module {
|
|
|
623
623
|
for (let i = 0; i < layers.length; i++) {
|
|
624
624
|
const layer = layers[i];
|
|
625
625
|
if (!layer) {
|
|
626
|
-
throw new
|
|
626
|
+
throw new chunkZ6BGACIH_cjs.DeepboxError(`Layer at index ${i} is undefined`);
|
|
627
627
|
}
|
|
628
628
|
this.registerModule(String(i), layer);
|
|
629
629
|
}
|
|
@@ -640,7 +640,7 @@ var Sequential = class extends Module {
|
|
|
640
640
|
*/
|
|
641
641
|
forward(...inputs) {
|
|
642
642
|
if (inputs.length !== 1) {
|
|
643
|
-
throw new
|
|
643
|
+
throw new chunkZ6BGACIH_cjs.InvalidParameterError(
|
|
644
644
|
"Sequential.forward expects a single input tensor",
|
|
645
645
|
"inputs",
|
|
646
646
|
inputs.length
|
|
@@ -648,7 +648,7 @@ var Sequential = class extends Module {
|
|
|
648
648
|
}
|
|
649
649
|
const input = inputs[0];
|
|
650
650
|
if (!input) {
|
|
651
|
-
throw new
|
|
651
|
+
throw new chunkZ6BGACIH_cjs.InvalidParameterError(
|
|
652
652
|
"Sequential.forward expects a single input tensor",
|
|
653
653
|
"input",
|
|
654
654
|
input
|
|
@@ -658,11 +658,11 @@ var Sequential = class extends Module {
|
|
|
658
658
|
for (let i = 0; i < this.layers.length; i++) {
|
|
659
659
|
const layer = this.layers[i];
|
|
660
660
|
if (!layer) {
|
|
661
|
-
throw new
|
|
661
|
+
throw new chunkZ6BGACIH_cjs.DeepboxError(`Layer at index ${i} is undefined`);
|
|
662
662
|
}
|
|
663
663
|
const result = layer.call(output);
|
|
664
664
|
if (Array.isArray(result)) {
|
|
665
|
-
throw new
|
|
665
|
+
throw new chunkZ6BGACIH_cjs.InvalidParameterError(
|
|
666
666
|
`Sequential does not support layers that return multiple tensors (layer ${i})`,
|
|
667
667
|
"layer",
|
|
668
668
|
i
|
|
@@ -682,14 +682,14 @@ var Sequential = class extends Module {
|
|
|
682
682
|
*/
|
|
683
683
|
getLayer(index) {
|
|
684
684
|
if (index < 0 || index >= this.layers.length) {
|
|
685
|
-
throw new
|
|
685
|
+
throw new chunkZ6BGACIH_cjs.IndexError(`Layer index ${index} out of bounds [0, ${this.layers.length})`, {
|
|
686
686
|
index,
|
|
687
687
|
validRange: [0, this.layers.length - 1]
|
|
688
688
|
});
|
|
689
689
|
}
|
|
690
690
|
const layer = this.layers[index];
|
|
691
691
|
if (!layer) {
|
|
692
|
-
throw new
|
|
692
|
+
throw new chunkZ6BGACIH_cjs.DeepboxError(`Layer at index ${index} is undefined`);
|
|
693
693
|
}
|
|
694
694
|
return layer;
|
|
695
695
|
}
|
|
@@ -731,8 +731,8 @@ var Sequential = class extends Module {
|
|
|
731
731
|
// src/nn/layers/activations.ts
|
|
732
732
|
var ReLU = class extends Module {
|
|
733
733
|
forward(input) {
|
|
734
|
-
if (
|
|
735
|
-
return
|
|
734
|
+
if (chunk724CXHFH_cjs.GradTensor.isGradTensor(input)) return input.relu();
|
|
735
|
+
return chunk724CXHFH_cjs.relu(input);
|
|
736
736
|
}
|
|
737
737
|
toString() {
|
|
738
738
|
return "ReLU()";
|
|
@@ -740,8 +740,8 @@ var ReLU = class extends Module {
|
|
|
740
740
|
};
|
|
741
741
|
var Sigmoid = class extends Module {
|
|
742
742
|
forward(input) {
|
|
743
|
-
if (
|
|
744
|
-
return
|
|
743
|
+
if (chunk724CXHFH_cjs.GradTensor.isGradTensor(input)) return input.sigmoid();
|
|
744
|
+
return chunk724CXHFH_cjs.sigmoid(input);
|
|
745
745
|
}
|
|
746
746
|
toString() {
|
|
747
747
|
return "Sigmoid()";
|
|
@@ -749,8 +749,8 @@ var Sigmoid = class extends Module {
|
|
|
749
749
|
};
|
|
750
750
|
var Tanh = class extends Module {
|
|
751
751
|
forward(input) {
|
|
752
|
-
if (
|
|
753
|
-
return
|
|
752
|
+
if (chunk724CXHFH_cjs.GradTensor.isGradTensor(input)) return input.tanh();
|
|
753
|
+
return chunk724CXHFH_cjs.tanh(input);
|
|
754
754
|
}
|
|
755
755
|
toString() {
|
|
756
756
|
return "Tanh()";
|
|
@@ -763,8 +763,8 @@ var LeakyReLU = class extends Module {
|
|
|
763
763
|
this.alpha = alpha;
|
|
764
764
|
}
|
|
765
765
|
forward(input) {
|
|
766
|
-
if (
|
|
767
|
-
return
|
|
766
|
+
if (chunk724CXHFH_cjs.GradTensor.isGradTensor(input)) return input.leakyRelu(this.alpha);
|
|
767
|
+
return chunk724CXHFH_cjs.leakyRelu(input, this.alpha);
|
|
768
768
|
}
|
|
769
769
|
toString() {
|
|
770
770
|
return `LeakyReLU(alpha=${this.alpha})`;
|
|
@@ -777,8 +777,8 @@ var ELU = class extends Module {
|
|
|
777
777
|
this.alpha = alpha;
|
|
778
778
|
}
|
|
779
779
|
forward(input) {
|
|
780
|
-
if (
|
|
781
|
-
return
|
|
780
|
+
if (chunk724CXHFH_cjs.GradTensor.isGradTensor(input)) return input.elu(this.alpha);
|
|
781
|
+
return chunk724CXHFH_cjs.elu(input, this.alpha);
|
|
782
782
|
}
|
|
783
783
|
toString() {
|
|
784
784
|
return `ELU(alpha=${this.alpha})`;
|
|
@@ -786,8 +786,8 @@ var ELU = class extends Module {
|
|
|
786
786
|
};
|
|
787
787
|
var GELU = class extends Module {
|
|
788
788
|
forward(input) {
|
|
789
|
-
if (
|
|
790
|
-
return
|
|
789
|
+
if (chunk724CXHFH_cjs.GradTensor.isGradTensor(input)) return input.gelu();
|
|
790
|
+
return chunk724CXHFH_cjs.gelu(input);
|
|
791
791
|
}
|
|
792
792
|
toString() {
|
|
793
793
|
return "GELU()";
|
|
@@ -800,10 +800,10 @@ var Softmax = class extends Module {
|
|
|
800
800
|
this.axis = axis;
|
|
801
801
|
}
|
|
802
802
|
forward(input) {
|
|
803
|
-
if (
|
|
804
|
-
return
|
|
803
|
+
if (chunk724CXHFH_cjs.GradTensor.isGradTensor(input)) {
|
|
804
|
+
return chunk724CXHFH_cjs.softmax2(input, chunkZ6BGACIH_cjs.normalizeAxis(this.axis, input.tensor.ndim));
|
|
805
805
|
}
|
|
806
|
-
return
|
|
806
|
+
return chunk724CXHFH_cjs.softmax(input, this.axis);
|
|
807
807
|
}
|
|
808
808
|
toString() {
|
|
809
809
|
return `Softmax(axis=${this.axis})`;
|
|
@@ -816,10 +816,10 @@ var LogSoftmax = class extends Module {
|
|
|
816
816
|
this.axis = axis;
|
|
817
817
|
}
|
|
818
818
|
forward(input) {
|
|
819
|
-
if (
|
|
820
|
-
return
|
|
819
|
+
if (chunk724CXHFH_cjs.GradTensor.isGradTensor(input)) {
|
|
820
|
+
return chunk724CXHFH_cjs.logSoftmax2(input, chunkZ6BGACIH_cjs.normalizeAxis(this.axis, input.tensor.ndim));
|
|
821
821
|
}
|
|
822
|
-
return
|
|
822
|
+
return chunk724CXHFH_cjs.logSoftmax(input, this.axis);
|
|
823
823
|
}
|
|
824
824
|
toString() {
|
|
825
825
|
return `LogSoftmax(axis=${this.axis})`;
|
|
@@ -827,12 +827,13 @@ var LogSoftmax = class extends Module {
|
|
|
827
827
|
};
|
|
828
828
|
var Softplus = class extends Module {
|
|
829
829
|
forward(input) {
|
|
830
|
-
if (
|
|
831
|
-
|
|
832
|
-
|
|
830
|
+
if (chunk724CXHFH_cjs.GradTensor.isGradTensor(input)) {
|
|
831
|
+
const one = chunk724CXHFH_cjs.GradTensor.scalar(1, {
|
|
832
|
+
dtype: input.dtype === "float64" ? "float64" : "float32"
|
|
833
833
|
});
|
|
834
|
+
return one.add(input.exp()).log();
|
|
834
835
|
}
|
|
835
|
-
return
|
|
836
|
+
return chunk724CXHFH_cjs.softplus(input);
|
|
836
837
|
}
|
|
837
838
|
toString() {
|
|
838
839
|
return "Softplus()";
|
|
@@ -840,12 +841,10 @@ var Softplus = class extends Module {
|
|
|
840
841
|
};
|
|
841
842
|
var Swish = class extends Module {
|
|
842
843
|
forward(input) {
|
|
843
|
-
if (
|
|
844
|
-
return
|
|
845
|
-
requiresGrad: false
|
|
846
|
-
});
|
|
844
|
+
if (chunk724CXHFH_cjs.GradTensor.isGradTensor(input)) {
|
|
845
|
+
return input.mul(input.sigmoid());
|
|
847
846
|
}
|
|
848
|
-
return
|
|
847
|
+
return chunk724CXHFH_cjs.swish(input);
|
|
849
848
|
}
|
|
850
849
|
toString() {
|
|
851
850
|
return "Swish()";
|
|
@@ -853,12 +852,14 @@ var Swish = class extends Module {
|
|
|
853
852
|
};
|
|
854
853
|
var Mish = class extends Module {
|
|
855
854
|
forward(input) {
|
|
856
|
-
if (
|
|
857
|
-
|
|
858
|
-
|
|
855
|
+
if (chunk724CXHFH_cjs.GradTensor.isGradTensor(input)) {
|
|
856
|
+
const one = chunk724CXHFH_cjs.GradTensor.scalar(1, {
|
|
857
|
+
dtype: input.dtype === "float64" ? "float64" : "float32"
|
|
859
858
|
});
|
|
859
|
+
const sp = one.add(input.exp()).log();
|
|
860
|
+
return input.mul(sp.tanh());
|
|
860
861
|
}
|
|
861
|
-
return
|
|
862
|
+
return chunk724CXHFH_cjs.mish(input);
|
|
862
863
|
}
|
|
863
864
|
toString() {
|
|
864
865
|
return "Mish()";
|
|
@@ -878,7 +879,7 @@ var Dropout = class extends Module {
|
|
|
878
879
|
constructor(p = 0.5) {
|
|
879
880
|
super();
|
|
880
881
|
if (!Number.isFinite(p) || p < 0 || p >= 1) {
|
|
881
|
-
throw new
|
|
882
|
+
throw new chunkZ6BGACIH_cjs.InvalidParameterError(`Dropout probability must be in [0, 1), got ${p}`, "p", p);
|
|
882
883
|
}
|
|
883
884
|
this.p = p;
|
|
884
885
|
}
|
|
@@ -889,11 +890,11 @@ var Dropout = class extends Module {
|
|
|
889
890
|
* @returns Output tensor with same shape as input
|
|
890
891
|
*/
|
|
891
892
|
forward(input) {
|
|
892
|
-
const inputTensor =
|
|
893
|
+
const inputTensor = chunk724CXHFH_cjs.GradTensor.isGradTensor(input) ? input : chunk724CXHFH_cjs.GradTensor.fromTensor(input);
|
|
893
894
|
if (inputTensor.dtype === "string") {
|
|
894
|
-
throw new
|
|
895
|
+
throw new chunkZ6BGACIH_cjs.DTypeError("Dropout does not support string dtype");
|
|
895
896
|
}
|
|
896
|
-
return
|
|
897
|
+
return chunk724CXHFH_cjs.dropout(inputTensor, this.p, this.training);
|
|
897
898
|
}
|
|
898
899
|
/**
|
|
899
900
|
* Get string representation of the layer.
|
|
@@ -938,14 +939,14 @@ var Linear = class extends Module {
|
|
|
938
939
|
constructor(inFeatures, outFeatures, options = {}) {
|
|
939
940
|
super();
|
|
940
941
|
if (inFeatures <= 0 || !Number.isInteger(inFeatures)) {
|
|
941
|
-
throw new
|
|
942
|
+
throw new chunkZ6BGACIH_cjs.InvalidParameterError(
|
|
942
943
|
"inFeatures must be a positive integer",
|
|
943
944
|
"inFeatures",
|
|
944
945
|
inFeatures
|
|
945
946
|
);
|
|
946
947
|
}
|
|
947
948
|
if (outFeatures <= 0 || !Number.isInteger(outFeatures)) {
|
|
948
|
-
throw new
|
|
949
|
+
throw new chunkZ6BGACIH_cjs.InvalidParameterError(
|
|
949
950
|
"outFeatures must be a positive integer",
|
|
950
951
|
"outFeatures",
|
|
951
952
|
outFeatures
|
|
@@ -955,42 +956,52 @@ var Linear = class extends Module {
|
|
|
955
956
|
this.outFeatures = outFeatures;
|
|
956
957
|
this.useBias = options.bias ?? true;
|
|
957
958
|
const stdDev = Math.sqrt(2 / inFeatures);
|
|
958
|
-
const weightTensor =
|
|
959
|
+
const weightTensor = chunk724CXHFH_cjs.randn([outFeatures, inFeatures], {
|
|
959
960
|
dtype: options.dtype ?? "float32",
|
|
960
961
|
device: options.device ?? "cpu"
|
|
961
962
|
});
|
|
962
|
-
const scaledWeight =
|
|
963
|
-
this.weightParam =
|
|
963
|
+
const scaledWeight = chunk724CXHFH_cjs.mulScalar(weightTensor, stdDev);
|
|
964
|
+
this.weightParam = chunk724CXHFH_cjs.parameter(scaledWeight);
|
|
964
965
|
this.weight = this.weightParam.tensor;
|
|
965
966
|
this.registerParameter("weight", this.weightParam);
|
|
966
967
|
if (this.useBias) {
|
|
967
|
-
const biasTensor =
|
|
968
|
+
const biasTensor = chunk724CXHFH_cjs.zeros([outFeatures], {
|
|
968
969
|
dtype: options.dtype ?? "float32",
|
|
969
970
|
device: options.device ?? "cpu"
|
|
970
971
|
});
|
|
971
|
-
this.biasParam =
|
|
972
|
+
this.biasParam = chunk724CXHFH_cjs.parameter(biasTensor);
|
|
972
973
|
this.bias = this.biasParam.tensor;
|
|
973
974
|
this.registerParameter("bias", this.biasParam);
|
|
974
975
|
}
|
|
975
976
|
}
|
|
976
977
|
forward(input) {
|
|
977
|
-
|
|
978
|
+
let inputTensor = chunk724CXHFH_cjs.GradTensor.isGradTensor(input) ? input.tensor : input;
|
|
978
979
|
if (inputTensor.dtype === "string") {
|
|
979
|
-
throw new
|
|
980
|
+
throw new chunkZ6BGACIH_cjs.DTypeError("Linear layer does not support string dtype");
|
|
981
|
+
}
|
|
982
|
+
if (inputTensor.dtype !== this.weight.dtype && inputTensor.dtype !== "int64") {
|
|
983
|
+
const castData = new Float32Array(
|
|
984
|
+
inputTensor.data
|
|
985
|
+
);
|
|
986
|
+
const castTensor = chunk724CXHFH_cjs.reshape(chunk724CXHFH_cjs.tensor(castData), inputTensor.shape);
|
|
987
|
+
inputTensor = castTensor;
|
|
988
|
+
if (chunk724CXHFH_cjs.GradTensor.isGradTensor(input)) {
|
|
989
|
+
input = chunk724CXHFH_cjs.parameter(castTensor);
|
|
990
|
+
}
|
|
980
991
|
}
|
|
981
992
|
if (inputTensor.ndim < 1) {
|
|
982
|
-
throw new
|
|
993
|
+
throw new chunkZ6BGACIH_cjs.ShapeError(`Linear layer expects at least 1D input; got ndim=${inputTensor.ndim}`);
|
|
983
994
|
}
|
|
984
995
|
const inputFeatures = inputTensor.shape[inputTensor.shape.length - 1] ?? 0;
|
|
985
996
|
if (inputFeatures !== this.inFeatures) {
|
|
986
|
-
throw new
|
|
997
|
+
throw new chunkZ6BGACIH_cjs.ShapeError(
|
|
987
998
|
`Linear layer expects ${this.inFeatures} input features; got ${inputFeatures}`
|
|
988
999
|
);
|
|
989
1000
|
}
|
|
990
1001
|
const isVectorInput = inputTensor.ndim === 1;
|
|
991
1002
|
const batchSize = inputTensor.size / this.inFeatures;
|
|
992
1003
|
const outputShape = isVectorInput ? [this.outFeatures] : [...inputTensor.shape.slice(0, -1), this.outFeatures];
|
|
993
|
-
if (
|
|
1004
|
+
if (chunk724CXHFH_cjs.GradTensor.isGradTensor(input)) {
|
|
994
1005
|
const input2d2 = input.reshape([batchSize, this.inFeatures]);
|
|
995
1006
|
const output2d2 = input2d2.matmul(this.weightParam.transpose());
|
|
996
1007
|
let output2 = output2d2.reshape(outputShape);
|
|
@@ -999,11 +1010,11 @@ var Linear = class extends Module {
|
|
|
999
1010
|
}
|
|
1000
1011
|
return output2;
|
|
1001
1012
|
}
|
|
1002
|
-
const input2d =
|
|
1003
|
-
const output2d =
|
|
1004
|
-
const output =
|
|
1013
|
+
const input2d = chunk724CXHFH_cjs.reshape(inputTensor, [batchSize, this.inFeatures]);
|
|
1014
|
+
const output2d = chunk724CXHFH_cjs.dot(input2d, chunk724CXHFH_cjs.transpose(this.weight));
|
|
1015
|
+
const output = chunk724CXHFH_cjs.reshape(output2d, outputShape);
|
|
1005
1016
|
if (this.useBias && this.bias) {
|
|
1006
|
-
return
|
|
1017
|
+
return chunk724CXHFH_cjs.add(output, this.bias);
|
|
1007
1018
|
}
|
|
1008
1019
|
return output;
|
|
1009
1020
|
}
|
|
@@ -1048,37 +1059,37 @@ var Linear = class extends Module {
|
|
|
1048
1059
|
|
|
1049
1060
|
// src/nn/layers/normalization.ts
|
|
1050
1061
|
function toContiguousTensor(t) {
|
|
1051
|
-
if (
|
|
1062
|
+
if (chunk724CXHFH_cjs.isContiguous(t.shape, t.strides)) {
|
|
1052
1063
|
return t;
|
|
1053
1064
|
}
|
|
1054
1065
|
if (t.dtype === "string") {
|
|
1055
|
-
throw new
|
|
1066
|
+
throw new chunkZ6BGACIH_cjs.DTypeError("Normalization does not support string dtype");
|
|
1056
1067
|
}
|
|
1057
|
-
const Ctor =
|
|
1068
|
+
const Ctor = chunkZ6BGACIH_cjs.dtypeToTypedArrayCtor(t.dtype);
|
|
1058
1069
|
const out = new Ctor(t.size);
|
|
1059
|
-
const logicalStrides =
|
|
1070
|
+
const logicalStrides = chunk724CXHFH_cjs.computeStrides(t.shape);
|
|
1060
1071
|
const data = t.data;
|
|
1061
1072
|
if (Array.isArray(data)) {
|
|
1062
|
-
throw new
|
|
1073
|
+
throw new chunkZ6BGACIH_cjs.DTypeError("Normalization does not support string dtype");
|
|
1063
1074
|
}
|
|
1064
1075
|
if (data instanceof BigInt64Array) {
|
|
1065
1076
|
if (!(out instanceof BigInt64Array)) {
|
|
1066
|
-
throw new
|
|
1077
|
+
throw new chunkZ6BGACIH_cjs.DTypeError("Expected int64 output buffer for int64 tensor");
|
|
1067
1078
|
}
|
|
1068
1079
|
for (let i = 0; i < t.size; i++) {
|
|
1069
|
-
const offset =
|
|
1070
|
-
out[i] =
|
|
1080
|
+
const offset = chunk724CXHFH_cjs.offsetFromFlatIndex(i, logicalStrides, t.strides, t.offset);
|
|
1081
|
+
out[i] = chunkZ6BGACIH_cjs.getBigIntElement(data, offset);
|
|
1071
1082
|
}
|
|
1072
1083
|
} else {
|
|
1073
1084
|
if (out instanceof BigInt64Array) {
|
|
1074
|
-
throw new
|
|
1085
|
+
throw new chunkZ6BGACIH_cjs.DTypeError("Unexpected int64 output buffer for numeric tensor");
|
|
1075
1086
|
}
|
|
1076
1087
|
for (let i = 0; i < t.size; i++) {
|
|
1077
|
-
const offset =
|
|
1078
|
-
out[i] =
|
|
1088
|
+
const offset = chunk724CXHFH_cjs.offsetFromFlatIndex(i, logicalStrides, t.strides, t.offset);
|
|
1089
|
+
out[i] = chunkZ6BGACIH_cjs.getNumericElement(data, offset);
|
|
1079
1090
|
}
|
|
1080
1091
|
}
|
|
1081
|
-
return
|
|
1092
|
+
return chunk724CXHFH_cjs.Tensor.fromTypedArray({
|
|
1082
1093
|
data: out,
|
|
1083
1094
|
shape: t.shape,
|
|
1084
1095
|
dtype: t.dtype,
|
|
@@ -1098,7 +1109,7 @@ var BatchNorm1d = class extends Module {
|
|
|
1098
1109
|
constructor(numFeatures, options = {}) {
|
|
1099
1110
|
super();
|
|
1100
1111
|
if (!Number.isFinite(numFeatures) || numFeatures <= 0 || Math.trunc(numFeatures) !== numFeatures) {
|
|
1101
|
-
throw new
|
|
1112
|
+
throw new chunkZ6BGACIH_cjs.InvalidParameterError(
|
|
1102
1113
|
"numFeatures must be a positive integer",
|
|
1103
1114
|
"numFeatures",
|
|
1104
1115
|
numFeatures
|
|
@@ -1107,11 +1118,11 @@ var BatchNorm1d = class extends Module {
|
|
|
1107
1118
|
this.numFeatures = numFeatures;
|
|
1108
1119
|
this.eps = options.eps ?? 1e-5;
|
|
1109
1120
|
if (!Number.isFinite(this.eps) || this.eps <= 0) {
|
|
1110
|
-
throw new
|
|
1121
|
+
throw new chunkZ6BGACIH_cjs.InvalidParameterError("eps must be a positive number", "eps", this.eps);
|
|
1111
1122
|
}
|
|
1112
1123
|
this.momentum = options.momentum ?? 0.1;
|
|
1113
1124
|
if (!Number.isFinite(this.momentum) || this.momentum < 0 || this.momentum > 1) {
|
|
1114
|
-
throw new
|
|
1125
|
+
throw new chunkZ6BGACIH_cjs.InvalidParameterError(
|
|
1115
1126
|
"momentum must be in range [0, 1]",
|
|
1116
1127
|
"momentum",
|
|
1117
1128
|
this.momentum
|
|
@@ -1120,17 +1131,17 @@ var BatchNorm1d = class extends Module {
|
|
|
1120
1131
|
this.affine = options.affine ?? true;
|
|
1121
1132
|
this.trackRunningStats = options.trackRunningStats ?? true;
|
|
1122
1133
|
if (this.affine) {
|
|
1123
|
-
const gamma =
|
|
1124
|
-
const beta =
|
|
1125
|
-
this.gamma =
|
|
1126
|
-
this.beta =
|
|
1134
|
+
const gamma = chunk724CXHFH_cjs.ones([numFeatures]);
|
|
1135
|
+
const beta = chunk724CXHFH_cjs.zeros([numFeatures]);
|
|
1136
|
+
this.gamma = chunk724CXHFH_cjs.parameter(gamma);
|
|
1137
|
+
this.beta = chunk724CXHFH_cjs.parameter(beta);
|
|
1127
1138
|
this.registerParameter("weight", this.gamma);
|
|
1128
1139
|
this.registerParameter("bias", this.beta);
|
|
1129
1140
|
}
|
|
1130
|
-
this.runningMean =
|
|
1141
|
+
this.runningMean = chunk724CXHFH_cjs.GradTensor.fromTensor(chunk724CXHFH_cjs.zeros([numFeatures]), {
|
|
1131
1142
|
requiresGrad: false
|
|
1132
1143
|
});
|
|
1133
|
-
this.runningVar =
|
|
1144
|
+
this.runningVar = chunk724CXHFH_cjs.GradTensor.fromTensor(chunk724CXHFH_cjs.ones([numFeatures]), {
|
|
1134
1145
|
requiresGrad: false
|
|
1135
1146
|
});
|
|
1136
1147
|
if (this.trackRunningStats) {
|
|
@@ -1139,17 +1150,17 @@ var BatchNorm1d = class extends Module {
|
|
|
1139
1150
|
}
|
|
1140
1151
|
}
|
|
1141
1152
|
forward(x) {
|
|
1142
|
-
const input =
|
|
1153
|
+
const input = chunk724CXHFH_cjs.GradTensor.isGradTensor(x) ? x : chunk724CXHFH_cjs.GradTensor.fromTensor(x);
|
|
1143
1154
|
const inputDtype = input.dtype;
|
|
1144
1155
|
if (inputDtype === "string") {
|
|
1145
|
-
throw new
|
|
1156
|
+
throw new chunkZ6BGACIH_cjs.DTypeError("BatchNorm1d does not support string dtype");
|
|
1146
1157
|
}
|
|
1147
1158
|
if (input.ndim !== 2 && input.ndim !== 3) {
|
|
1148
|
-
throw new
|
|
1159
|
+
throw new chunkZ6BGACIH_cjs.ShapeError(`BatchNorm1d expects 2D or 3D input; got ndim=${input.ndim}`);
|
|
1149
1160
|
}
|
|
1150
1161
|
const nFeatures = input.shape[1] ?? 0;
|
|
1151
1162
|
if (nFeatures !== this.numFeatures) {
|
|
1152
|
-
throw new
|
|
1163
|
+
throw new chunkZ6BGACIH_cjs.ShapeError(`Expected ${this.numFeatures} features, got ${nFeatures}`);
|
|
1153
1164
|
}
|
|
1154
1165
|
const useBatchStats = this.training || !this.trackRunningStats;
|
|
1155
1166
|
let mean2;
|
|
@@ -1159,36 +1170,36 @@ var BatchNorm1d = class extends Module {
|
|
|
1159
1170
|
const batch = input.shape[0] ?? 0;
|
|
1160
1171
|
const length = input.shape[2] ?? 0;
|
|
1161
1172
|
const flat = batch * length;
|
|
1162
|
-
const numericInputDtype =
|
|
1163
|
-
inputReshaped = input.transpose([0, 2, 1]).mul(
|
|
1173
|
+
const numericInputDtype = chunkZ6BGACIH_cjs.ensureNumericDType(inputDtype, "BatchNorm1d");
|
|
1174
|
+
inputReshaped = input.transpose([0, 2, 1]).mul(chunk724CXHFH_cjs.GradTensor.scalar(1, { dtype: numericInputDtype })).reshape([flat, nFeatures]);
|
|
1164
1175
|
}
|
|
1165
1176
|
if (useBatchStats) {
|
|
1166
1177
|
if (inputReshaped.shape[0] === 0) {
|
|
1167
|
-
throw new
|
|
1178
|
+
throw new chunkZ6BGACIH_cjs.InvalidParameterError(
|
|
1168
1179
|
"BatchNorm requires at least one element",
|
|
1169
1180
|
"input",
|
|
1170
1181
|
input.shape
|
|
1171
1182
|
);
|
|
1172
1183
|
}
|
|
1173
1184
|
mean2 = inputReshaped.mean(0);
|
|
1174
|
-
varVal =
|
|
1185
|
+
varVal = chunk724CXHFH_cjs.variance2(inputReshaped, 0, 0);
|
|
1175
1186
|
if (this.trackRunningStats) {
|
|
1176
|
-
|
|
1187
|
+
chunk724CXHFH_cjs.noGrad(() => {
|
|
1177
1188
|
const n = inputReshaped.shape[0] ?? 0;
|
|
1178
|
-
const unbiasedVar = n > 1 ?
|
|
1189
|
+
const unbiasedVar = n > 1 ? chunk724CXHFH_cjs.variance2(inputReshaped, 0, 1) : chunk724CXHFH_cjs.variance2(inputReshaped, 0, 0);
|
|
1179
1190
|
const m = this.momentum;
|
|
1180
1191
|
const statsDtype = this.runningMean.dtype;
|
|
1181
1192
|
if (statsDtype === "string") {
|
|
1182
|
-
throw new
|
|
1193
|
+
throw new chunkZ6BGACIH_cjs.DTypeError("BatchNorm running statistics must be numeric");
|
|
1183
1194
|
}
|
|
1184
|
-
const oneMinusM =
|
|
1185
|
-
const mScalar =
|
|
1195
|
+
const oneMinusM = chunk724CXHFH_cjs.GradTensor.scalar(1 - m, { dtype: statsDtype });
|
|
1196
|
+
const mScalar = chunk724CXHFH_cjs.GradTensor.scalar(m, { dtype: statsDtype });
|
|
1186
1197
|
const newMean = this.runningMean.mul(oneMinusM).add(mean2.mul(mScalar));
|
|
1187
1198
|
const newVar = this.runningVar.mul(oneMinusM).add(unbiasedVar.mul(mScalar));
|
|
1188
|
-
this.runningMean =
|
|
1199
|
+
this.runningMean = chunk724CXHFH_cjs.GradTensor.fromTensor(newMean.tensor, {
|
|
1189
1200
|
requiresGrad: false
|
|
1190
1201
|
});
|
|
1191
|
-
this.runningVar =
|
|
1202
|
+
this.runningVar = chunk724CXHFH_cjs.GradTensor.fromTensor(newVar.tensor, {
|
|
1192
1203
|
requiresGrad: false
|
|
1193
1204
|
});
|
|
1194
1205
|
this.registerBuffer("running_mean", this.runningMean.tensor);
|
|
@@ -1208,7 +1219,7 @@ var BatchNorm1d = class extends Module {
|
|
|
1208
1219
|
meanBroadcast = mean2.reshape([1, nFeatures]);
|
|
1209
1220
|
varBroadcast = varVal.reshape([1, nFeatures]);
|
|
1210
1221
|
}
|
|
1211
|
-
const epsTensor =
|
|
1222
|
+
const epsTensor = chunk724CXHFH_cjs.GradTensor.scalar(this.eps, { dtype: inputDtype });
|
|
1212
1223
|
const denom = varBroadcast.add(epsTensor).sqrt();
|
|
1213
1224
|
let out = input.sub(meanBroadcast).div(denom);
|
|
1214
1225
|
if (this.affine && this.gamma && this.beta) {
|
|
@@ -1239,7 +1250,7 @@ var LayerNorm = class extends Module {
|
|
|
1239
1250
|
super();
|
|
1240
1251
|
this.normalizedShape = typeof normalizedShape === "number" ? [normalizedShape] : Array.from(normalizedShape);
|
|
1241
1252
|
if (this.normalizedShape.length === 0) {
|
|
1242
|
-
throw new
|
|
1253
|
+
throw new chunkZ6BGACIH_cjs.InvalidParameterError(
|
|
1243
1254
|
"normalizedShape must contain at least one dimension",
|
|
1244
1255
|
"normalizedShape",
|
|
1245
1256
|
normalizedShape
|
|
@@ -1247,7 +1258,7 @@ var LayerNorm = class extends Module {
|
|
|
1247
1258
|
}
|
|
1248
1259
|
for (const dim of this.normalizedShape) {
|
|
1249
1260
|
if (!Number.isFinite(dim) || dim <= 0 || Math.trunc(dim) !== dim) {
|
|
1250
|
-
throw new
|
|
1261
|
+
throw new chunkZ6BGACIH_cjs.InvalidParameterError(
|
|
1251
1262
|
"All dimensions in normalizedShape must be positive integers",
|
|
1252
1263
|
"normalizedShape",
|
|
1253
1264
|
normalizedShape
|
|
@@ -1256,38 +1267,38 @@ var LayerNorm = class extends Module {
|
|
|
1256
1267
|
}
|
|
1257
1268
|
this.eps = options.eps ?? 1e-5;
|
|
1258
1269
|
if (!Number.isFinite(this.eps) || this.eps <= 0) {
|
|
1259
|
-
throw new
|
|
1270
|
+
throw new chunkZ6BGACIH_cjs.InvalidParameterError("eps must be a positive number", "eps", this.eps);
|
|
1260
1271
|
}
|
|
1261
1272
|
this.elementwiseAffine = options.elementwiseAffine ?? true;
|
|
1262
1273
|
if (this.elementwiseAffine) {
|
|
1263
|
-
this.gamma =
|
|
1264
|
-
this.beta =
|
|
1274
|
+
this.gamma = chunk724CXHFH_cjs.parameter(chunk724CXHFH_cjs.ones(this.normalizedShape));
|
|
1275
|
+
this.beta = chunk724CXHFH_cjs.parameter(chunk724CXHFH_cjs.zeros(this.normalizedShape));
|
|
1265
1276
|
this.registerParameter("weight", this.gamma);
|
|
1266
1277
|
this.registerParameter("bias", this.beta);
|
|
1267
1278
|
}
|
|
1268
1279
|
}
|
|
1269
1280
|
forward(x) {
|
|
1270
|
-
const input =
|
|
1281
|
+
const input = chunk724CXHFH_cjs.GradTensor.isGradTensor(x) ? x : chunk724CXHFH_cjs.GradTensor.fromTensor(x);
|
|
1271
1282
|
const inputDtype = input.dtype;
|
|
1272
1283
|
if (inputDtype === "string") {
|
|
1273
|
-
throw new
|
|
1284
|
+
throw new chunkZ6BGACIH_cjs.DTypeError("LayerNorm does not support string dtype");
|
|
1274
1285
|
}
|
|
1275
1286
|
let workingInput = input;
|
|
1276
|
-
if (!
|
|
1287
|
+
if (!chunk724CXHFH_cjs.isContiguous(input.tensor.shape, input.tensor.strides)) {
|
|
1277
1288
|
const contiguous = toContiguousTensor(input.tensor);
|
|
1278
|
-
workingInput =
|
|
1289
|
+
workingInput = chunk724CXHFH_cjs.GradTensor.fromTensor(contiguous, {
|
|
1279
1290
|
requiresGrad: input.requiresGrad
|
|
1280
1291
|
});
|
|
1281
1292
|
}
|
|
1282
1293
|
const inputShape = workingInput.shape;
|
|
1283
1294
|
const normShape = this.normalizedShape;
|
|
1284
1295
|
if (normShape.length > inputShape.length) {
|
|
1285
|
-
throw new
|
|
1296
|
+
throw new chunkZ6BGACIH_cjs.ShapeError(`Input shape ${inputShape} too small for normalizedShape ${normShape}`);
|
|
1286
1297
|
}
|
|
1287
1298
|
const suffixStart = inputShape.length - normShape.length;
|
|
1288
1299
|
for (let i = 0; i < normShape.length; i++) {
|
|
1289
1300
|
if (inputShape[suffixStart + i] !== normShape[i]) {
|
|
1290
|
-
throw new
|
|
1301
|
+
throw new chunkZ6BGACIH_cjs.ShapeError(
|
|
1291
1302
|
`Input shape ${inputShape} does not end with normalizedShape ${normShape}`
|
|
1292
1303
|
);
|
|
1293
1304
|
}
|
|
@@ -1297,9 +1308,9 @@ var LayerNorm = class extends Module {
|
|
|
1297
1308
|
const flattenedShape = [...outerDims, normSize];
|
|
1298
1309
|
const inputReshaped = workingInput.reshape(flattenedShape);
|
|
1299
1310
|
const mean2 = inputReshaped.mean(-1, true);
|
|
1300
|
-
const varVal =
|
|
1311
|
+
const varVal = chunk724CXHFH_cjs.variance2(inputReshaped, -1, 0);
|
|
1301
1312
|
const varReshaped = varVal.reshape(mean2.shape);
|
|
1302
|
-
const epsTensor =
|
|
1313
|
+
const epsTensor = chunk724CXHFH_cjs.GradTensor.scalar(this.eps, { dtype: inputDtype });
|
|
1303
1314
|
const denom = varReshaped.add(epsTensor).sqrt();
|
|
1304
1315
|
const normalizedReshaped = inputReshaped.sub(mean2).div(denom);
|
|
1305
1316
|
let out = normalizedReshaped.reshape(inputShape);
|
|
@@ -1351,13 +1362,13 @@ var MultiheadAttention = class extends Module {
|
|
|
1351
1362
|
constructor(embedDim, numHeads, options = {}) {
|
|
1352
1363
|
super();
|
|
1353
1364
|
if (!Number.isInteger(embedDim) || embedDim <= 0) {
|
|
1354
|
-
throw new
|
|
1365
|
+
throw new chunkZ6BGACIH_cjs.InvalidParameterError("embedDim must be a positive integer", "embedDim", embedDim);
|
|
1355
1366
|
}
|
|
1356
1367
|
if (!Number.isInteger(numHeads) || numHeads <= 0) {
|
|
1357
|
-
throw new
|
|
1368
|
+
throw new chunkZ6BGACIH_cjs.InvalidParameterError("numHeads must be a positive integer", "numHeads", numHeads);
|
|
1358
1369
|
}
|
|
1359
1370
|
if (embedDim % numHeads !== 0) {
|
|
1360
|
-
throw new
|
|
1371
|
+
throw new chunkZ6BGACIH_cjs.InvalidParameterError(
|
|
1361
1372
|
`embedDim (${embedDim}) must be divisible by numHeads (${numHeads})`,
|
|
1362
1373
|
"embedDim",
|
|
1363
1374
|
embedDim
|
|
@@ -1365,7 +1376,7 @@ var MultiheadAttention = class extends Module {
|
|
|
1365
1376
|
}
|
|
1366
1377
|
const dropout2 = options.dropout ?? 0;
|
|
1367
1378
|
if (!Number.isFinite(dropout2) || dropout2 < 0 || dropout2 >= 1) {
|
|
1368
|
-
throw new
|
|
1379
|
+
throw new chunkZ6BGACIH_cjs.InvalidParameterError("dropout must be in [0, 1)", "dropout", dropout2);
|
|
1369
1380
|
}
|
|
1370
1381
|
this.embedDim = embedDim;
|
|
1371
1382
|
this.numHeads = numHeads;
|
|
@@ -1374,19 +1385,19 @@ var MultiheadAttention = class extends Module {
|
|
|
1374
1385
|
this.useBias = options.bias ?? true;
|
|
1375
1386
|
this.dropout = dropout2;
|
|
1376
1387
|
const stdDev = Math.sqrt(2 / (embedDim + embedDim));
|
|
1377
|
-
this.wQ =
|
|
1378
|
-
this.wK =
|
|
1379
|
-
this.wV =
|
|
1380
|
-
this.wO =
|
|
1388
|
+
this.wQ = chunk724CXHFH_cjs.parameter(chunk724CXHFH_cjs.mulScalar(chunk724CXHFH_cjs.randn([embedDim, embedDim]), stdDev));
|
|
1389
|
+
this.wK = chunk724CXHFH_cjs.parameter(chunk724CXHFH_cjs.mulScalar(chunk724CXHFH_cjs.randn([embedDim, embedDim]), stdDev));
|
|
1390
|
+
this.wV = chunk724CXHFH_cjs.parameter(chunk724CXHFH_cjs.mulScalar(chunk724CXHFH_cjs.randn([embedDim, embedDim]), stdDev));
|
|
1391
|
+
this.wO = chunk724CXHFH_cjs.parameter(chunk724CXHFH_cjs.mulScalar(chunk724CXHFH_cjs.randn([embedDim, embedDim]), stdDev));
|
|
1381
1392
|
this.registerParameter("in_proj_weight_q", this.wQ);
|
|
1382
1393
|
this.registerParameter("in_proj_weight_k", this.wK);
|
|
1383
1394
|
this.registerParameter("in_proj_weight_v", this.wV);
|
|
1384
1395
|
this.registerParameter("out_proj_weight", this.wO);
|
|
1385
1396
|
if (this.useBias) {
|
|
1386
|
-
this.bQ =
|
|
1387
|
-
this.bK =
|
|
1388
|
-
this.bV =
|
|
1389
|
-
this.bO =
|
|
1397
|
+
this.bQ = chunk724CXHFH_cjs.parameter(chunk724CXHFH_cjs.zeros([embedDim]));
|
|
1398
|
+
this.bK = chunk724CXHFH_cjs.parameter(chunk724CXHFH_cjs.zeros([embedDim]));
|
|
1399
|
+
this.bV = chunk724CXHFH_cjs.parameter(chunk724CXHFH_cjs.zeros([embedDim]));
|
|
1400
|
+
this.bO = chunk724CXHFH_cjs.parameter(chunk724CXHFH_cjs.zeros([embedDim]));
|
|
1390
1401
|
this.registerParameter("in_proj_bias_q", this.bQ);
|
|
1391
1402
|
this.registerParameter("in_proj_bias_k", this.bK);
|
|
1392
1403
|
this.registerParameter("in_proj_bias_v", this.bV);
|
|
@@ -1403,7 +1414,7 @@ var MultiheadAttention = class extends Module {
|
|
|
1403
1414
|
*/
|
|
1404
1415
|
forward(...inputs) {
|
|
1405
1416
|
if (inputs.length < 1 || inputs.length > 3) {
|
|
1406
|
-
throw new
|
|
1417
|
+
throw new chunkZ6BGACIH_cjs.InvalidParameterError(
|
|
1407
1418
|
"MultiheadAttention.forward expects 1 to 3 input tensors",
|
|
1408
1419
|
"inputs",
|
|
1409
1420
|
inputs.length
|
|
@@ -1411,25 +1422,25 @@ var MultiheadAttention = class extends Module {
|
|
|
1411
1422
|
}
|
|
1412
1423
|
const queryInput = inputs[0];
|
|
1413
1424
|
if (queryInput === void 0) {
|
|
1414
|
-
throw new
|
|
1425
|
+
throw new chunkZ6BGACIH_cjs.InvalidParameterError("Query tensor is required", "query", queryInput);
|
|
1415
1426
|
}
|
|
1416
|
-
const query =
|
|
1427
|
+
const query = chunk724CXHFH_cjs.GradTensor.isGradTensor(queryInput) ? queryInput : chunk724CXHFH_cjs.GradTensor.fromTensor(queryInput);
|
|
1417
1428
|
const keyInput = inputs[1] ?? queryInput;
|
|
1418
|
-
const key =
|
|
1429
|
+
const key = chunk724CXHFH_cjs.GradTensor.isGradTensor(keyInput) ? keyInput : chunk724CXHFH_cjs.GradTensor.fromTensor(keyInput);
|
|
1419
1430
|
const valueInput = inputs[2] ?? queryInput;
|
|
1420
|
-
const value =
|
|
1421
|
-
if (query.dtype === "string") throw new
|
|
1431
|
+
const value = chunk724CXHFH_cjs.GradTensor.isGradTensor(valueInput) ? valueInput : chunk724CXHFH_cjs.GradTensor.fromTensor(valueInput);
|
|
1432
|
+
if (query.dtype === "string") throw new chunkZ6BGACIH_cjs.DTypeError("String tensors are not supported");
|
|
1422
1433
|
if (query.ndim !== key.ndim || query.ndim !== value.ndim) {
|
|
1423
|
-
throw new
|
|
1434
|
+
throw new chunkZ6BGACIH_cjs.ShapeError("query, key, and value must have same rank");
|
|
1424
1435
|
}
|
|
1425
1436
|
if (query.ndim !== 2 && query.ndim !== 3) {
|
|
1426
|
-
throw new
|
|
1437
|
+
throw new chunkZ6BGACIH_cjs.ShapeError(`Query must be 2D or 3D; got ndim=${query.ndim}`);
|
|
1427
1438
|
}
|
|
1428
1439
|
if (key.ndim !== 2 && key.ndim !== 3) {
|
|
1429
|
-
throw new
|
|
1440
|
+
throw new chunkZ6BGACIH_cjs.ShapeError(`Key must be 2D or 3D; got ndim=${key.ndim}`);
|
|
1430
1441
|
}
|
|
1431
1442
|
if (value.ndim !== 2 && value.ndim !== 3) {
|
|
1432
|
-
throw new
|
|
1443
|
+
throw new chunkZ6BGACIH_cjs.ShapeError(`Value must be 2D or 3D; got ndim=${value.ndim}`);
|
|
1433
1444
|
}
|
|
1434
1445
|
let q = query;
|
|
1435
1446
|
let k = key;
|
|
@@ -1443,21 +1454,21 @@ var MultiheadAttention = class extends Module {
|
|
|
1443
1454
|
const seqLenV = v.shape[1] ?? 0;
|
|
1444
1455
|
const embedDim = q.shape[2] ?? 0;
|
|
1445
1456
|
if (embedDim !== this.embedDim) {
|
|
1446
|
-
throw new
|
|
1457
|
+
throw new chunkZ6BGACIH_cjs.ShapeError(`Query embedDim mismatch: expected ${this.embedDim}, got ${embedDim}`);
|
|
1447
1458
|
}
|
|
1448
1459
|
if (k.shape[2] !== this.embedDim) {
|
|
1449
|
-
throw new
|
|
1460
|
+
throw new chunkZ6BGACIH_cjs.ShapeError(`Key embedDim mismatch: expected ${this.embedDim}, got ${k.shape[2]}`);
|
|
1450
1461
|
}
|
|
1451
1462
|
if (v.shape[2] !== this.embedDim) {
|
|
1452
|
-
throw new
|
|
1463
|
+
throw new chunkZ6BGACIH_cjs.ShapeError(`Value embedDim mismatch: expected ${this.embedDim}, got ${v.shape[2]}`);
|
|
1453
1464
|
}
|
|
1454
1465
|
if (k.shape[0] !== batchSize || v.shape[0] !== batchSize) {
|
|
1455
|
-
throw new
|
|
1466
|
+
throw new chunkZ6BGACIH_cjs.ShapeError(
|
|
1456
1467
|
`batch size mismatch: query=${batchSize}, key=${k.shape[0]}, value=${v.shape[0]}`
|
|
1457
1468
|
);
|
|
1458
1469
|
}
|
|
1459
1470
|
if (seqLenK !== seqLenV) {
|
|
1460
|
-
throw new
|
|
1471
|
+
throw new chunkZ6BGACIH_cjs.ShapeError(`Key/value sequence length mismatch: key=${seqLenK}, value=${seqLenV}`);
|
|
1461
1472
|
}
|
|
1462
1473
|
let Q = q.matmul(this.wQ.transpose());
|
|
1463
1474
|
if (this.bQ) Q = Q.add(this.bQ);
|
|
@@ -1471,12 +1482,12 @@ var MultiheadAttention = class extends Module {
|
|
|
1471
1482
|
K = K.reshape([batchSize, seqLenK, H, D]).transpose([0, 2, 1, 3]);
|
|
1472
1483
|
V = V.reshape([batchSize, seqLenV, H, D]).transpose([0, 2, 1, 3]);
|
|
1473
1484
|
let scores = Q.matmul(K.transpose([0, 1, 3, 2]));
|
|
1474
|
-
scores = scores.div(
|
|
1475
|
-
let attn =
|
|
1476
|
-
attn =
|
|
1485
|
+
scores = scores.div(chunk724CXHFH_cjs.GradTensor.scalar(this.scale));
|
|
1486
|
+
let attn = chunk724CXHFH_cjs.softmax2(scores, -1);
|
|
1487
|
+
attn = chunk724CXHFH_cjs.dropout(attn, this.dropout, this.training);
|
|
1477
1488
|
const context = attn.matmul(V);
|
|
1478
|
-
const contextDtype =
|
|
1479
|
-
const contextReshaped = context.transpose([0, 2, 1, 3]).mul(
|
|
1489
|
+
const contextDtype = chunkZ6BGACIH_cjs.ensureNumericDType(context.dtype, "MultiheadAttention");
|
|
1490
|
+
const contextReshaped = context.transpose([0, 2, 1, 3]).mul(chunk724CXHFH_cjs.GradTensor.scalar(1, { dtype: contextDtype })).reshape([batchSize, seqLenQ, this.embedDim]);
|
|
1480
1491
|
let output = contextReshaped.matmul(this.wO.transpose());
|
|
1481
1492
|
if (this.bO) output = output.add(this.bO);
|
|
1482
1493
|
if (query.ndim === 2) {
|
|
@@ -1503,33 +1514,58 @@ var TransformerEncoderLayer = class extends Module {
|
|
|
1503
1514
|
dropout1;
|
|
1504
1515
|
dropout2;
|
|
1505
1516
|
dropout3;
|
|
1506
|
-
constructor(
|
|
1517
|
+
constructor(dModelOrOpts, nHead, dFFOrOptions, options = {}) {
|
|
1507
1518
|
super();
|
|
1519
|
+
let resolvedDModel;
|
|
1520
|
+
let resolvedNHead;
|
|
1521
|
+
let resolvedDFF;
|
|
1522
|
+
let resolvedDropout;
|
|
1523
|
+
let resolvedEps;
|
|
1524
|
+
if (typeof dModelOrOpts === "object") {
|
|
1525
|
+
resolvedDModel = dModelOrOpts.dModel;
|
|
1526
|
+
resolvedNHead = dModelOrOpts.nHead;
|
|
1527
|
+
resolvedDFF = dModelOrOpts.dFF ?? dModelOrOpts.dimFeedforward ?? 2048;
|
|
1528
|
+
resolvedDropout = dModelOrOpts.dropout;
|
|
1529
|
+
resolvedEps = dModelOrOpts.eps;
|
|
1530
|
+
} else if (typeof dFFOrOptions === "object") {
|
|
1531
|
+
resolvedDModel = dModelOrOpts;
|
|
1532
|
+
resolvedNHead = nHead ?? 1;
|
|
1533
|
+
resolvedDFF = dFFOrOptions.dFF ?? dFFOrOptions.dimFeedforward ?? 2048;
|
|
1534
|
+
resolvedDropout = dFFOrOptions.dropout;
|
|
1535
|
+
resolvedEps = dFFOrOptions.eps;
|
|
1536
|
+
} else {
|
|
1537
|
+
resolvedDModel = dModelOrOpts;
|
|
1538
|
+
resolvedNHead = nHead ?? 1;
|
|
1539
|
+
resolvedDFF = dFFOrOptions ?? 2048;
|
|
1540
|
+
resolvedDropout = options.dropout;
|
|
1541
|
+
resolvedEps = options.eps;
|
|
1542
|
+
}
|
|
1543
|
+
const dModel = resolvedDModel;
|
|
1508
1544
|
if (!Number.isInteger(dModel) || dModel <= 0) {
|
|
1509
|
-
throw new
|
|
1545
|
+
throw new chunkZ6BGACIH_cjs.InvalidParameterError("dModel must be a positive integer", "dModel", dModel);
|
|
1510
1546
|
}
|
|
1511
|
-
if (!Number.isInteger(
|
|
1512
|
-
throw new
|
|
1547
|
+
if (!Number.isInteger(resolvedNHead) || resolvedNHead <= 0) {
|
|
1548
|
+
throw new chunkZ6BGACIH_cjs.InvalidParameterError("nHead must be a positive integer", "nHead", resolvedNHead);
|
|
1513
1549
|
}
|
|
1514
|
-
if (dModel %
|
|
1515
|
-
throw new
|
|
1516
|
-
`dModel (${dModel}) must be divisible by nHead (${
|
|
1550
|
+
if (dModel % resolvedNHead !== 0) {
|
|
1551
|
+
throw new chunkZ6BGACIH_cjs.InvalidParameterError(
|
|
1552
|
+
`dModel (${dModel}) must be divisible by nHead (${resolvedNHead})`,
|
|
1517
1553
|
"dModel",
|
|
1518
1554
|
dModel
|
|
1519
1555
|
);
|
|
1520
1556
|
}
|
|
1521
|
-
if (!Number.isInteger(
|
|
1522
|
-
throw new
|
|
1557
|
+
if (!Number.isInteger(resolvedDFF) || resolvedDFF <= 0) {
|
|
1558
|
+
throw new chunkZ6BGACIH_cjs.InvalidParameterError("dFF must be a positive integer", "dFF", resolvedDFF);
|
|
1523
1559
|
}
|
|
1524
|
-
const dropout2 =
|
|
1525
|
-
const eps =
|
|
1560
|
+
const dropout2 = resolvedDropout ?? 0.1;
|
|
1561
|
+
const eps = resolvedEps ?? 1e-5;
|
|
1526
1562
|
this.dModel = dModel;
|
|
1527
|
-
this.nHead =
|
|
1528
|
-
this.dFF =
|
|
1563
|
+
this.nHead = resolvedNHead;
|
|
1564
|
+
this.dFF = resolvedDFF;
|
|
1529
1565
|
this.dropout = dropout2;
|
|
1530
|
-
this.selfAttn = new MultiheadAttention(dModel,
|
|
1531
|
-
this.linear1 = new Linear(dModel,
|
|
1532
|
-
this.linear2 = new Linear(
|
|
1566
|
+
this.selfAttn = new MultiheadAttention(dModel, resolvedNHead, { dropout: dropout2 });
|
|
1567
|
+
this.linear1 = new Linear(dModel, resolvedDFF);
|
|
1568
|
+
this.linear2 = new Linear(resolvedDFF, dModel);
|
|
1533
1569
|
this.norm1 = new LayerNorm(dModel, { eps });
|
|
1534
1570
|
this.norm2 = new LayerNorm(dModel, { eps });
|
|
1535
1571
|
this.dropout1 = new Dropout(dropout2);
|
|
@@ -1551,9 +1587,9 @@ var TransformerEncoderLayer = class extends Module {
|
|
|
1551
1587
|
* @returns Output of same shape as input
|
|
1552
1588
|
*/
|
|
1553
1589
|
forward(src) {
|
|
1554
|
-
const input =
|
|
1590
|
+
const input = chunk724CXHFH_cjs.GradTensor.isGradTensor(src) ? src : chunk724CXHFH_cjs.GradTensor.fromTensor(src);
|
|
1555
1591
|
if (input.dtype === "string") {
|
|
1556
|
-
throw new
|
|
1592
|
+
throw new chunkZ6BGACIH_cjs.DTypeError("TransformerEncoderLayer does not support string dtype");
|
|
1557
1593
|
}
|
|
1558
1594
|
let src2 = this.selfAttn.forward(input, input, input);
|
|
1559
1595
|
src2 = this.dropout1.forward(src2);
|
|
@@ -1579,7 +1615,7 @@ function normalizePair(name, value, allowZero, description) {
|
|
|
1579
1615
|
const first = arr[0];
|
|
1580
1616
|
const second = arr[1];
|
|
1581
1617
|
if (arr.length !== 2 || first === void 0 || second === void 0 || !Number.isInteger(first) || !Number.isInteger(second) || (allowZero ? first < 0 || second < 0 : first <= 0 || second <= 0)) {
|
|
1582
|
-
throw new
|
|
1618
|
+
throw new chunkZ6BGACIH_cjs.InvalidParameterError(`${name} must be ${description}`, name, value);
|
|
1583
1619
|
}
|
|
1584
1620
|
return [first, second];
|
|
1585
1621
|
}
|
|
@@ -1595,21 +1631,21 @@ var Conv1d = class extends Module {
|
|
|
1595
1631
|
constructor(inChannels, outChannels, kernelSize, options = {}) {
|
|
1596
1632
|
super();
|
|
1597
1633
|
if (inChannels <= 0 || !Number.isInteger(inChannels)) {
|
|
1598
|
-
throw new
|
|
1634
|
+
throw new chunkZ6BGACIH_cjs.InvalidParameterError(
|
|
1599
1635
|
"inChannels must be a positive integer",
|
|
1600
1636
|
"inChannels",
|
|
1601
1637
|
inChannels
|
|
1602
1638
|
);
|
|
1603
1639
|
}
|
|
1604
1640
|
if (outChannels <= 0 || !Number.isInteger(outChannels)) {
|
|
1605
|
-
throw new
|
|
1641
|
+
throw new chunkZ6BGACIH_cjs.InvalidParameterError(
|
|
1606
1642
|
"outChannels must be a positive integer",
|
|
1607
1643
|
"outChannels",
|
|
1608
1644
|
outChannels
|
|
1609
1645
|
);
|
|
1610
1646
|
}
|
|
1611
1647
|
if (kernelSize <= 0 || !Number.isInteger(kernelSize)) {
|
|
1612
|
-
throw new
|
|
1648
|
+
throw new chunkZ6BGACIH_cjs.InvalidParameterError(
|
|
1613
1649
|
"kernelSize must be a positive integer",
|
|
1614
1650
|
"kernelSize",
|
|
1615
1651
|
kernelSize
|
|
@@ -1617,11 +1653,11 @@ var Conv1d = class extends Module {
|
|
|
1617
1653
|
}
|
|
1618
1654
|
const stride = options.stride ?? 1;
|
|
1619
1655
|
if (stride <= 0 || !Number.isInteger(stride)) {
|
|
1620
|
-
throw new
|
|
1656
|
+
throw new chunkZ6BGACIH_cjs.InvalidParameterError("stride must be a positive integer", "stride", stride);
|
|
1621
1657
|
}
|
|
1622
1658
|
const padding = options.padding ?? 0;
|
|
1623
1659
|
if (padding < 0 || !Number.isInteger(padding)) {
|
|
1624
|
-
throw new
|
|
1660
|
+
throw new chunkZ6BGACIH_cjs.InvalidParameterError("padding must be a non-negative integer", "padding", padding);
|
|
1625
1661
|
}
|
|
1626
1662
|
this.inChannels = inChannels;
|
|
1627
1663
|
this.outChannels = outChannels;
|
|
@@ -1633,36 +1669,36 @@ var Conv1d = class extends Module {
|
|
|
1633
1669
|
}
|
|
1634
1670
|
initializeParameters() {
|
|
1635
1671
|
const k = 1 / Math.sqrt(this.inChannels * this.kernelSize);
|
|
1636
|
-
const weight =
|
|
1637
|
-
this.weight_ =
|
|
1672
|
+
const weight = chunk724CXHFH_cjs.randn([this.outChannels, this.inChannels, this.kernelSize]);
|
|
1673
|
+
this.weight_ = chunk724CXHFH_cjs.parameter(chunk724CXHFH_cjs.mulScalar(weight, k));
|
|
1638
1674
|
this.registerParameter("weight", this.weight_);
|
|
1639
1675
|
if (this.bias) {
|
|
1640
|
-
const biasInit =
|
|
1641
|
-
this.bias_ =
|
|
1676
|
+
const biasInit = chunk724CXHFH_cjs.randn([this.outChannels]);
|
|
1677
|
+
this.bias_ = chunk724CXHFH_cjs.parameter(chunk724CXHFH_cjs.mulScalar(biasInit, k));
|
|
1642
1678
|
this.registerParameter("bias", this.bias_);
|
|
1643
1679
|
}
|
|
1644
1680
|
}
|
|
1645
1681
|
forward(x) {
|
|
1646
|
-
const input =
|
|
1682
|
+
const input = chunk724CXHFH_cjs.GradTensor.isGradTensor(x) ? x : chunk724CXHFH_cjs.GradTensor.fromTensor(x);
|
|
1647
1683
|
if (input.dtype === "string") {
|
|
1648
|
-
throw new
|
|
1684
|
+
throw new chunkZ6BGACIH_cjs.DTypeError("String tensors are not supported");
|
|
1649
1685
|
}
|
|
1650
1686
|
if (input.ndim !== 3) {
|
|
1651
|
-
throw new
|
|
1687
|
+
throw new chunkZ6BGACIH_cjs.ShapeError(`Conv1d expects 3D input (batch, channels, length), got ${input.ndim}D`);
|
|
1652
1688
|
}
|
|
1653
1689
|
const batch = input.shape[0] ?? 0;
|
|
1654
1690
|
const inC = input.shape[1] ?? 0;
|
|
1655
1691
|
const inL = input.shape[2] ?? 0;
|
|
1656
1692
|
if (inC !== this.inChannels) {
|
|
1657
|
-
throw new
|
|
1693
|
+
throw new chunkZ6BGACIH_cjs.ShapeError(`Expected ${this.inChannels} input channels, got ${inC}`);
|
|
1658
1694
|
}
|
|
1659
1695
|
const weight = this.weight_;
|
|
1660
|
-
if (!weight) throw new
|
|
1696
|
+
if (!weight) throw new chunkZ6BGACIH_cjs.NotFittedError("Weight not initialized");
|
|
1661
1697
|
const input2d = input.reshape([batch, inC, 1, inL]);
|
|
1662
1698
|
const kernelSize = [1, this.kernelSize];
|
|
1663
1699
|
const stride = [1, this.stride];
|
|
1664
1700
|
const padding = [0, this.padding];
|
|
1665
|
-
const cols =
|
|
1701
|
+
const cols = chunk724CXHFH_cjs.im2col2(input2d, kernelSize, stride, padding);
|
|
1666
1702
|
const weightFlat = weight.reshape([this.outChannels, this.inChannels * this.kernelSize]);
|
|
1667
1703
|
const out = cols.matmul(weightFlat.transpose());
|
|
1668
1704
|
const outTransposed = out.transpose([0, 2, 1]);
|
|
@@ -1674,7 +1710,7 @@ var Conv1d = class extends Module {
|
|
|
1674
1710
|
}
|
|
1675
1711
|
get weight() {
|
|
1676
1712
|
if (!this.weight_) {
|
|
1677
|
-
throw new
|
|
1713
|
+
throw new chunkZ6BGACIH_cjs.NotFittedError("Weight not initialized");
|
|
1678
1714
|
}
|
|
1679
1715
|
return this.weight_;
|
|
1680
1716
|
}
|
|
@@ -1691,14 +1727,14 @@ var Conv2d = class extends Module {
|
|
|
1691
1727
|
constructor(inChannels, outChannels, kernelSize, options = {}) {
|
|
1692
1728
|
super();
|
|
1693
1729
|
if (inChannels <= 0 || !Number.isInteger(inChannels)) {
|
|
1694
|
-
throw new
|
|
1730
|
+
throw new chunkZ6BGACIH_cjs.InvalidParameterError(
|
|
1695
1731
|
"inChannels must be a positive integer",
|
|
1696
1732
|
"inChannels",
|
|
1697
1733
|
inChannels
|
|
1698
1734
|
);
|
|
1699
1735
|
}
|
|
1700
1736
|
if (outChannels <= 0 || !Number.isInteger(outChannels)) {
|
|
1701
|
-
throw new
|
|
1737
|
+
throw new chunkZ6BGACIH_cjs.InvalidParameterError(
|
|
1702
1738
|
"outChannels must be a positive integer",
|
|
1703
1739
|
"outChannels",
|
|
1704
1740
|
outChannels
|
|
@@ -1736,22 +1772,22 @@ var Conv2d = class extends Module {
|
|
|
1736
1772
|
const kH = this.kernelSize[0] ?? 1;
|
|
1737
1773
|
const kW = this.kernelSize[1] ?? 1;
|
|
1738
1774
|
const k = 1 / Math.sqrt(this.inChannels * kH * kW);
|
|
1739
|
-
const weight =
|
|
1740
|
-
this.weight_ =
|
|
1775
|
+
const weight = chunk724CXHFH_cjs.randn([this.outChannels, this.inChannels, kH, kW]);
|
|
1776
|
+
this.weight_ = chunk724CXHFH_cjs.parameter(chunk724CXHFH_cjs.mulScalar(weight, k));
|
|
1741
1777
|
this.registerParameter("weight", this.weight_);
|
|
1742
1778
|
if (this.useBias) {
|
|
1743
|
-
const biasInit =
|
|
1744
|
-
this.bias_ =
|
|
1779
|
+
const biasInit = chunk724CXHFH_cjs.randn([this.outChannels]);
|
|
1780
|
+
this.bias_ = chunk724CXHFH_cjs.parameter(chunk724CXHFH_cjs.mulScalar(biasInit, k));
|
|
1745
1781
|
this.registerParameter("bias", this.bias_);
|
|
1746
1782
|
}
|
|
1747
1783
|
}
|
|
1748
1784
|
forward(x) {
|
|
1749
|
-
const input =
|
|
1785
|
+
const input = chunk724CXHFH_cjs.GradTensor.isGradTensor(x) ? x : chunk724CXHFH_cjs.GradTensor.fromTensor(x);
|
|
1750
1786
|
if (input.dtype === "string") {
|
|
1751
|
-
throw new
|
|
1787
|
+
throw new chunkZ6BGACIH_cjs.DTypeError("String tensors are not supported");
|
|
1752
1788
|
}
|
|
1753
1789
|
if (input.ndim !== 4) {
|
|
1754
|
-
throw new
|
|
1790
|
+
throw new chunkZ6BGACIH_cjs.ShapeError(
|
|
1755
1791
|
`Conv2d expects 4D input (batch, channels, height, width), got ${input.ndim}D`
|
|
1756
1792
|
);
|
|
1757
1793
|
}
|
|
@@ -1760,14 +1796,14 @@ var Conv2d = class extends Module {
|
|
|
1760
1796
|
const inH = input.shape[2] ?? 0;
|
|
1761
1797
|
const inW = input.shape[3] ?? 0;
|
|
1762
1798
|
if (inC !== this.inChannels) {
|
|
1763
|
-
throw new
|
|
1799
|
+
throw new chunkZ6BGACIH_cjs.ShapeError(`Expected ${this.inChannels} input channels, got ${inC}`);
|
|
1764
1800
|
}
|
|
1765
1801
|
const weight = this.weight_;
|
|
1766
|
-
if (!weight) throw new
|
|
1802
|
+
if (!weight) throw new chunkZ6BGACIH_cjs.NotFittedError("Weight not initialized");
|
|
1767
1803
|
const [kH, kW] = this.kernelSize;
|
|
1768
1804
|
const [sH, sW] = this.stride;
|
|
1769
1805
|
const [pH, pW] = this.padding;
|
|
1770
|
-
const cols =
|
|
1806
|
+
const cols = chunk724CXHFH_cjs.im2col2(input, [kH, kW], [sH, sW], [pH, pW]);
|
|
1771
1807
|
const outH = Math.floor((inH + 2 * pH - kH) / sH) + 1;
|
|
1772
1808
|
const outW = Math.floor((inW + 2 * pW - kW) / sW) + 1;
|
|
1773
1809
|
const weightFlat = weight.reshape([this.outChannels, this.inChannels * kH * kW]);
|
|
@@ -1782,7 +1818,7 @@ var Conv2d = class extends Module {
|
|
|
1782
1818
|
}
|
|
1783
1819
|
get weight() {
|
|
1784
1820
|
if (!this.weight_) {
|
|
1785
|
-
throw new
|
|
1821
|
+
throw new chunkZ6BGACIH_cjs.NotFittedError("Weight not initialized");
|
|
1786
1822
|
}
|
|
1787
1823
|
return this.weight_;
|
|
1788
1824
|
}
|
|
@@ -1816,12 +1852,12 @@ var MaxPool2d = class extends Module {
|
|
|
1816
1852
|
this.padding = paddingArr;
|
|
1817
1853
|
}
|
|
1818
1854
|
forward(x) {
|
|
1819
|
-
const input =
|
|
1855
|
+
const input = chunk724CXHFH_cjs.GradTensor.isGradTensor(x) ? x : chunk724CXHFH_cjs.GradTensor.fromTensor(x);
|
|
1820
1856
|
if (input.dtype === "string") {
|
|
1821
|
-
throw new
|
|
1857
|
+
throw new chunkZ6BGACIH_cjs.DTypeError("String tensors are not supported");
|
|
1822
1858
|
}
|
|
1823
1859
|
if (input.ndim !== 4) {
|
|
1824
|
-
throw new
|
|
1860
|
+
throw new chunkZ6BGACIH_cjs.ShapeError(
|
|
1825
1861
|
`MaxPool2d expects 4D input (batch, channels, height, width), got ${input.ndim}D`
|
|
1826
1862
|
);
|
|
1827
1863
|
}
|
|
@@ -1833,7 +1869,7 @@ var MaxPool2d = class extends Module {
|
|
|
1833
1869
|
const [sH, sW] = this.stride;
|
|
1834
1870
|
const [pH, pW] = this.padding;
|
|
1835
1871
|
const inputReshaped = input.reshape([batch * channels, 1, inH, inW]);
|
|
1836
|
-
const cols =
|
|
1872
|
+
const cols = chunk724CXHFH_cjs.im2col2(inputReshaped, [kH, kW], [sH, sW], [pH, pW]);
|
|
1837
1873
|
const maxVals = cols.max(2);
|
|
1838
1874
|
const outH = Math.floor((inH + 2 * pH - kH) / sH) + 1;
|
|
1839
1875
|
const outW = Math.floor((inW + 2 * pW - kW) / sW) + 1;
|
|
@@ -1869,12 +1905,12 @@ var AvgPool2d = class extends Module {
|
|
|
1869
1905
|
this.padding = paddingArr;
|
|
1870
1906
|
}
|
|
1871
1907
|
forward(x) {
|
|
1872
|
-
const input =
|
|
1908
|
+
const input = chunk724CXHFH_cjs.GradTensor.isGradTensor(x) ? x : chunk724CXHFH_cjs.GradTensor.fromTensor(x);
|
|
1873
1909
|
if (input.dtype === "string") {
|
|
1874
|
-
throw new
|
|
1910
|
+
throw new chunkZ6BGACIH_cjs.DTypeError("String tensors are not supported");
|
|
1875
1911
|
}
|
|
1876
1912
|
if (input.ndim !== 4) {
|
|
1877
|
-
throw new
|
|
1913
|
+
throw new chunkZ6BGACIH_cjs.ShapeError(
|
|
1878
1914
|
`AvgPool2d expects 4D input (batch, channels, height, width), got ${input.ndim}D`
|
|
1879
1915
|
);
|
|
1880
1916
|
}
|
|
@@ -1886,7 +1922,7 @@ var AvgPool2d = class extends Module {
|
|
|
1886
1922
|
const [sH, sW] = this.stride;
|
|
1887
1923
|
const [pH, pW] = this.padding;
|
|
1888
1924
|
const inputReshaped = input.reshape([batch * channels, 1, inH, inW]);
|
|
1889
|
-
const cols =
|
|
1925
|
+
const cols = chunk724CXHFH_cjs.im2col2(inputReshaped, [kH, kW], [sH, sW], [pH, pW]);
|
|
1890
1926
|
const meanVals = cols.mean(2);
|
|
1891
1927
|
const outH = Math.floor((inH + 2 * pH - kH) / sH) + 1;
|
|
1892
1928
|
const outW = Math.floor((inW + 2 * pW - kW) / sW) + 1;
|
|
@@ -1897,25 +1933,25 @@ var AvgPool2d = class extends Module {
|
|
|
1897
1933
|
// src/nn/layers/recurrent.ts
|
|
1898
1934
|
function ensureFloatTensor(t, context) {
|
|
1899
1935
|
if (t.dtype === "string") {
|
|
1900
|
-
throw new
|
|
1936
|
+
throw new chunkZ6BGACIH_cjs.DTypeError(`${context} does not support string dtype`);
|
|
1901
1937
|
}
|
|
1902
1938
|
if (t.dtype !== "float32" && t.dtype !== "float64") {
|
|
1903
|
-
throw new
|
|
1939
|
+
throw new chunkZ6BGACIH_cjs.DTypeError(`${context} expects float32 or float64 dtype`);
|
|
1904
1940
|
}
|
|
1905
1941
|
}
|
|
1906
1942
|
function readNumeric(t, offset) {
|
|
1907
1943
|
const data = t.data;
|
|
1908
1944
|
if (Array.isArray(data)) {
|
|
1909
|
-
throw new
|
|
1945
|
+
throw new chunkZ6BGACIH_cjs.DTypeError("String tensors are not supported");
|
|
1910
1946
|
}
|
|
1911
|
-
return
|
|
1947
|
+
return chunkZ6BGACIH_cjs.getElementAsNumber(data, offset);
|
|
1912
1948
|
}
|
|
1913
1949
|
function createFloatBuffer(size, dtype) {
|
|
1914
1950
|
return dtype === "float64" ? new Float64Array(size) : new Float32Array(size);
|
|
1915
1951
|
}
|
|
1916
1952
|
function validatePositiveInt(name, value) {
|
|
1917
1953
|
if (!Number.isInteger(value) || value <= 0) {
|
|
1918
|
-
throw new
|
|
1954
|
+
throw new chunkZ6BGACIH_cjs.InvalidParameterError(`${name} must be a positive integer`, name, value);
|
|
1919
1955
|
}
|
|
1920
1956
|
}
|
|
1921
1957
|
function parseInput(input, batchFirst) {
|
|
@@ -1933,7 +1969,7 @@ function parseInput(input, batchFirst) {
|
|
|
1933
1969
|
};
|
|
1934
1970
|
}
|
|
1935
1971
|
if (input.ndim !== 3) {
|
|
1936
|
-
throw new
|
|
1972
|
+
throw new chunkZ6BGACIH_cjs.ShapeError(`Recurrent layers expect 2D or 3D input; got ndim=${input.ndim}`);
|
|
1937
1973
|
}
|
|
1938
1974
|
if (batchFirst) {
|
|
1939
1975
|
return {
|
|
@@ -1966,7 +2002,7 @@ function outputIndex(batchFirst, isUnbatched, batch, seqLen, hiddenSize, b, t, j
|
|
|
1966
2002
|
return t * (batch * hiddenSize) + b * hiddenSize + j;
|
|
1967
2003
|
}
|
|
1968
2004
|
function extractTensor(arg, _name) {
|
|
1969
|
-
if (
|
|
2005
|
+
if (chunk724CXHFH_cjs.GradTensor.isGradTensor(arg)) {
|
|
1970
2006
|
return arg.tensor;
|
|
1971
2007
|
}
|
|
1972
2008
|
return arg;
|
|
@@ -1982,10 +2018,10 @@ function buildState(state, numLayers, batch, hiddenSize, isUnbatched, name) {
|
|
|
1982
2018
|
ensureFloatTensor(state, name);
|
|
1983
2019
|
if (state.ndim === 2) {
|
|
1984
2020
|
if (!isUnbatched) {
|
|
1985
|
-
throw new
|
|
2021
|
+
throw new chunkZ6BGACIH_cjs.ShapeError(`Expected ${name} with 3 dimensions for batched input`);
|
|
1986
2022
|
}
|
|
1987
2023
|
if ((state.shape[0] ?? 0) !== numLayers || (state.shape[1] ?? 0) !== hiddenSize) {
|
|
1988
|
-
throw new
|
|
2024
|
+
throw new chunkZ6BGACIH_cjs.ShapeError(
|
|
1989
2025
|
`Expected ${name} shape [${numLayers}, ${hiddenSize}], got [${state.shape.join(", ")}]`
|
|
1990
2026
|
);
|
|
1991
2027
|
}
|
|
@@ -1994,7 +2030,7 @@ function buildState(state, numLayers, batch, hiddenSize, isUnbatched, name) {
|
|
|
1994
2030
|
for (let l = 0; l < numLayers; l++) {
|
|
1995
2031
|
const layerState = result[l];
|
|
1996
2032
|
if (!layerState) {
|
|
1997
|
-
throw new
|
|
2033
|
+
throw new chunkZ6BGACIH_cjs.ShapeError(`Internal error: missing ${name} layer state`);
|
|
1998
2034
|
}
|
|
1999
2035
|
const base = state.offset + l * stride02;
|
|
2000
2036
|
for (let j = 0; j < hiddenSize; j++) {
|
|
@@ -2004,12 +2040,12 @@ function buildState(state, numLayers, batch, hiddenSize, isUnbatched, name) {
|
|
|
2004
2040
|
return result;
|
|
2005
2041
|
}
|
|
2006
2042
|
if (state.ndim !== 3) {
|
|
2007
|
-
throw new
|
|
2043
|
+
throw new chunkZ6BGACIH_cjs.ShapeError(`Expected ${name} with 2 or 3 dimensions; got ndim=${state.ndim}`);
|
|
2008
2044
|
}
|
|
2009
2045
|
const expectedBatch = isUnbatched ? 1 : batch;
|
|
2010
2046
|
if ((state.shape[0] ?? 0) !== numLayers || (state.shape[1] ?? 0) !== expectedBatch || (state.shape[2] ?? 0) !== hiddenSize) {
|
|
2011
2047
|
const expected = isUnbatched ? [numLayers, 1, hiddenSize] : [numLayers, batch, hiddenSize];
|
|
2012
|
-
throw new
|
|
2048
|
+
throw new chunkZ6BGACIH_cjs.ShapeError(
|
|
2013
2049
|
`Expected ${name} shape [${expected.join(", ")}], got [${state.shape.join(", ")}]`
|
|
2014
2050
|
);
|
|
2015
2051
|
}
|
|
@@ -2019,7 +2055,7 @@ function buildState(state, numLayers, batch, hiddenSize, isUnbatched, name) {
|
|
|
2019
2055
|
for (let l = 0; l < numLayers; l++) {
|
|
2020
2056
|
const layerState = result[l];
|
|
2021
2057
|
if (!layerState) {
|
|
2022
|
-
throw new
|
|
2058
|
+
throw new chunkZ6BGACIH_cjs.ShapeError(`Internal error: missing ${name} layer state`);
|
|
2023
2059
|
}
|
|
2024
2060
|
const baseLayer = state.offset + l * stride0;
|
|
2025
2061
|
for (let b = 0; b < batch; b++) {
|
|
@@ -2038,13 +2074,13 @@ function packState(state, numLayers, batch, hiddenSize, dtype, device, isUnbatch
|
|
|
2038
2074
|
for (let l = 0; l < numLayers; l++) {
|
|
2039
2075
|
const layer = state[l];
|
|
2040
2076
|
if (!layer) {
|
|
2041
|
-
throw new
|
|
2077
|
+
throw new chunkZ6BGACIH_cjs.ShapeError("Internal error: missing packed state layer");
|
|
2042
2078
|
}
|
|
2043
2079
|
for (let j = 0; j < hiddenSize; j++) {
|
|
2044
2080
|
data[l * hiddenSize + j] = layer[j] ?? 0;
|
|
2045
2081
|
}
|
|
2046
2082
|
}
|
|
2047
|
-
return
|
|
2083
|
+
return chunk724CXHFH_cjs.Tensor.fromTypedArray({
|
|
2048
2084
|
data,
|
|
2049
2085
|
shape: [numLayers, hiddenSize],
|
|
2050
2086
|
dtype,
|
|
@@ -2054,7 +2090,7 @@ function packState(state, numLayers, batch, hiddenSize, dtype, device, isUnbatch
|
|
|
2054
2090
|
for (let l = 0; l < numLayers; l++) {
|
|
2055
2091
|
const layer = state[l];
|
|
2056
2092
|
if (!layer) {
|
|
2057
|
-
throw new
|
|
2093
|
+
throw new chunkZ6BGACIH_cjs.ShapeError("Internal error: missing packed state layer");
|
|
2058
2094
|
}
|
|
2059
2095
|
const layerOffset = l * batch * hiddenSize;
|
|
2060
2096
|
for (let b = 0; b < batch; b++) {
|
|
@@ -2064,7 +2100,7 @@ function packState(state, numLayers, batch, hiddenSize, dtype, device, isUnbatch
|
|
|
2064
2100
|
}
|
|
2065
2101
|
}
|
|
2066
2102
|
}
|
|
2067
|
-
return
|
|
2103
|
+
return chunk724CXHFH_cjs.Tensor.fromTypedArray({
|
|
2068
2104
|
data,
|
|
2069
2105
|
shape: [numLayers, batch, hiddenSize],
|
|
2070
2106
|
dtype,
|
|
@@ -2101,19 +2137,19 @@ var RNN = class extends Module {
|
|
|
2101
2137
|
this.biasHh = [];
|
|
2102
2138
|
for (let layer = 0; layer < this.numLayers; layer++) {
|
|
2103
2139
|
const inputDim = layer === 0 ? inputSize : hiddenSize;
|
|
2104
|
-
const wIh =
|
|
2105
|
-
const wHh =
|
|
2140
|
+
const wIh = chunk724CXHFH_cjs.mulScalar(chunk724CXHFH_cjs.randn([hiddenSize, inputDim]), stdv);
|
|
2141
|
+
const wHh = chunk724CXHFH_cjs.mulScalar(chunk724CXHFH_cjs.randn([hiddenSize, hiddenSize]), stdv);
|
|
2106
2142
|
this.weightsIh.push(wIh);
|
|
2107
2143
|
this.weightsHh.push(wHh);
|
|
2108
|
-
this.registerParameter(`weight_ih_l${layer}`,
|
|
2109
|
-
this.registerParameter(`weight_hh_l${layer}`,
|
|
2144
|
+
this.registerParameter(`weight_ih_l${layer}`, chunk724CXHFH_cjs.parameter(wIh));
|
|
2145
|
+
this.registerParameter(`weight_hh_l${layer}`, chunk724CXHFH_cjs.parameter(wHh));
|
|
2110
2146
|
if (this.bias) {
|
|
2111
|
-
const bIh =
|
|
2112
|
-
const bHh =
|
|
2147
|
+
const bIh = chunk724CXHFH_cjs.zeros([hiddenSize]);
|
|
2148
|
+
const bHh = chunk724CXHFH_cjs.zeros([hiddenSize]);
|
|
2113
2149
|
this.biasIh.push(bIh);
|
|
2114
2150
|
this.biasHh.push(bHh);
|
|
2115
|
-
this.registerParameter(`bias_ih_l${layer}`,
|
|
2116
|
-
this.registerParameter(`bias_hh_l${layer}`,
|
|
2151
|
+
this.registerParameter(`bias_ih_l${layer}`, chunk724CXHFH_cjs.parameter(bIh));
|
|
2152
|
+
this.registerParameter(`bias_hh_l${layer}`, chunk724CXHFH_cjs.parameter(bHh));
|
|
2117
2153
|
}
|
|
2118
2154
|
}
|
|
2119
2155
|
}
|
|
@@ -2125,13 +2161,13 @@ var RNN = class extends Module {
|
|
|
2125
2161
|
const parsed = parseInput(input, this.batchFirst);
|
|
2126
2162
|
const { batch, seqLen, inputDim, isUnbatched, batchStride, seqStride, featStride } = parsed;
|
|
2127
2163
|
if (inputDim !== this.inputSize) {
|
|
2128
|
-
throw new
|
|
2164
|
+
throw new chunkZ6BGACIH_cjs.ShapeError(`Expected input size ${this.inputSize}, got ${inputDim}`);
|
|
2129
2165
|
}
|
|
2130
2166
|
if (seqLen <= 0) {
|
|
2131
|
-
throw new
|
|
2167
|
+
throw new chunkZ6BGACIH_cjs.InvalidParameterError("Sequence length must be positive", "seqLen", seqLen);
|
|
2132
2168
|
}
|
|
2133
2169
|
if (!isUnbatched && batch <= 0) {
|
|
2134
|
-
throw new
|
|
2170
|
+
throw new chunkZ6BGACIH_cjs.InvalidParameterError("Batch size must be positive", "batch", batch);
|
|
2135
2171
|
}
|
|
2136
2172
|
const h = buildState(hx, this.numLayers, batch, this.hiddenSize, isUnbatched, "hx");
|
|
2137
2173
|
const outSize = (isUnbatched ? seqLen : batch * seqLen) * this.hiddenSize;
|
|
@@ -2148,13 +2184,13 @@ var RNN = class extends Module {
|
|
|
2148
2184
|
const wIh = this.weightsIh[l];
|
|
2149
2185
|
const wHh = this.weightsHh[l];
|
|
2150
2186
|
if (!wIh || !wHh) {
|
|
2151
|
-
throw new
|
|
2187
|
+
throw new chunkZ6BGACIH_cjs.ShapeError("Internal error: missing RNN weights");
|
|
2152
2188
|
}
|
|
2153
2189
|
const curInputSize = l === 0 ? this.inputSize : this.hiddenSize;
|
|
2154
2190
|
const newH = new Float64Array(this.hiddenSize);
|
|
2155
2191
|
const hLayer = h[l];
|
|
2156
2192
|
if (!hLayer) {
|
|
2157
|
-
throw new
|
|
2193
|
+
throw new chunkZ6BGACIH_cjs.ShapeError("Internal error: missing RNN hidden state");
|
|
2158
2194
|
}
|
|
2159
2195
|
const wIhStride0 = wIh.strides[0] ?? 0;
|
|
2160
2196
|
const wIhStride1 = wIh.strides[1] ?? 0;
|
|
@@ -2202,7 +2238,7 @@ var RNN = class extends Module {
|
|
|
2202
2238
|
}
|
|
2203
2239
|
const outShape = isUnbatched ? [seqLen, this.hiddenSize] : this.batchFirst ? [batch, seqLen, this.hiddenSize] : [seqLen, batch, this.hiddenSize];
|
|
2204
2240
|
return {
|
|
2205
|
-
output:
|
|
2241
|
+
output: chunk724CXHFH_cjs.Tensor.fromTypedArray({
|
|
2206
2242
|
data: out,
|
|
2207
2243
|
shape: outShape,
|
|
2208
2244
|
dtype: input.dtype,
|
|
@@ -2221,11 +2257,11 @@ var RNN = class extends Module {
|
|
|
2221
2257
|
}
|
|
2222
2258
|
forward(...inputs) {
|
|
2223
2259
|
if (inputs.length < 1 || inputs.length > 2) {
|
|
2224
|
-
throw new
|
|
2260
|
+
throw new chunkZ6BGACIH_cjs.InvalidParameterError("RNN.forward expects 1 or 2 inputs", "inputs", inputs.length);
|
|
2225
2261
|
}
|
|
2226
2262
|
const inputArg = inputs[0];
|
|
2227
2263
|
if (inputArg === void 0) {
|
|
2228
|
-
throw new
|
|
2264
|
+
throw new chunkZ6BGACIH_cjs.InvalidParameterError("RNN.forward requires an input tensor", "input", inputArg);
|
|
2229
2265
|
}
|
|
2230
2266
|
const input = extractTensor(inputArg);
|
|
2231
2267
|
const hxArg = inputs.length === 2 ? inputs[1] : void 0;
|
|
@@ -2274,19 +2310,19 @@ var LSTM = class extends Module {
|
|
|
2274
2310
|
this.biasHh = [];
|
|
2275
2311
|
for (let layer = 0; layer < this.numLayers; layer++) {
|
|
2276
2312
|
const inputDim = layer === 0 ? inputSize : hiddenSize;
|
|
2277
|
-
const wIh =
|
|
2278
|
-
const wHh =
|
|
2313
|
+
const wIh = chunk724CXHFH_cjs.mulScalar(chunk724CXHFH_cjs.randn([4 * hiddenSize, inputDim]), stdv);
|
|
2314
|
+
const wHh = chunk724CXHFH_cjs.mulScalar(chunk724CXHFH_cjs.randn([4 * hiddenSize, hiddenSize]), stdv);
|
|
2279
2315
|
this.weightsIh.push(wIh);
|
|
2280
2316
|
this.weightsHh.push(wHh);
|
|
2281
|
-
this.registerParameter(`weight_ih_l${layer}`,
|
|
2282
|
-
this.registerParameter(`weight_hh_l${layer}`,
|
|
2317
|
+
this.registerParameter(`weight_ih_l${layer}`, chunk724CXHFH_cjs.parameter(wIh));
|
|
2318
|
+
this.registerParameter(`weight_hh_l${layer}`, chunk724CXHFH_cjs.parameter(wHh));
|
|
2283
2319
|
if (this.bias) {
|
|
2284
|
-
const bIh =
|
|
2285
|
-
const bHh =
|
|
2320
|
+
const bIh = chunk724CXHFH_cjs.zeros([4 * hiddenSize]);
|
|
2321
|
+
const bHh = chunk724CXHFH_cjs.zeros([4 * hiddenSize]);
|
|
2286
2322
|
this.biasIh.push(bIh);
|
|
2287
2323
|
this.biasHh.push(bHh);
|
|
2288
|
-
this.registerParameter(`bias_ih_l${layer}`,
|
|
2289
|
-
this.registerParameter(`bias_hh_l${layer}`,
|
|
2324
|
+
this.registerParameter(`bias_ih_l${layer}`, chunk724CXHFH_cjs.parameter(bIh));
|
|
2325
|
+
this.registerParameter(`bias_hh_l${layer}`, chunk724CXHFH_cjs.parameter(bHh));
|
|
2290
2326
|
}
|
|
2291
2327
|
}
|
|
2292
2328
|
}
|
|
@@ -2298,13 +2334,13 @@ var LSTM = class extends Module {
|
|
|
2298
2334
|
const parsed = parseInput(input, this.batchFirst);
|
|
2299
2335
|
const { batch, seqLen, inputDim, isUnbatched, batchStride, seqStride, featStride } = parsed;
|
|
2300
2336
|
if (inputDim !== this.inputSize) {
|
|
2301
|
-
throw new
|
|
2337
|
+
throw new chunkZ6BGACIH_cjs.ShapeError(`Expected input size ${this.inputSize}, got ${inputDim}`);
|
|
2302
2338
|
}
|
|
2303
2339
|
if (seqLen <= 0) {
|
|
2304
|
-
throw new
|
|
2340
|
+
throw new chunkZ6BGACIH_cjs.InvalidParameterError("Sequence length must be positive", "seqLen", seqLen);
|
|
2305
2341
|
}
|
|
2306
2342
|
if (!isUnbatched && batch <= 0) {
|
|
2307
|
-
throw new
|
|
2343
|
+
throw new chunkZ6BGACIH_cjs.InvalidParameterError("Batch size must be positive", "batch", batch);
|
|
2308
2344
|
}
|
|
2309
2345
|
const h = buildState(hx, this.numLayers, batch, this.hiddenSize, isUnbatched, "hx");
|
|
2310
2346
|
const c = buildState(cx, this.numLayers, batch, this.hiddenSize, isUnbatched, "cx");
|
|
@@ -2323,13 +2359,13 @@ var LSTM = class extends Module {
|
|
|
2323
2359
|
const wIh = this.weightsIh[l];
|
|
2324
2360
|
const wHh = this.weightsHh[l];
|
|
2325
2361
|
if (!wIh || !wHh) {
|
|
2326
|
-
throw new
|
|
2362
|
+
throw new chunkZ6BGACIH_cjs.ShapeError("Internal error: missing LSTM weights");
|
|
2327
2363
|
}
|
|
2328
2364
|
const curInputSize = l === 0 ? this.inputSize : this.hiddenSize;
|
|
2329
2365
|
const hLayer = h[l];
|
|
2330
2366
|
const cLayer = c[l];
|
|
2331
2367
|
if (!hLayer || !cLayer) {
|
|
2332
|
-
throw new
|
|
2368
|
+
throw new chunkZ6BGACIH_cjs.ShapeError("Internal error: missing LSTM state");
|
|
2333
2369
|
}
|
|
2334
2370
|
const wIhStride0 = wIh.strides[0] ?? 0;
|
|
2335
2371
|
const wIhStride1 = wIh.strides[1] ?? 0;
|
|
@@ -2391,7 +2427,7 @@ var LSTM = class extends Module {
|
|
|
2391
2427
|
}
|
|
2392
2428
|
const outShape = isUnbatched ? [seqLen, this.hiddenSize] : this.batchFirst ? [batch, seqLen, this.hiddenSize] : [seqLen, batch, this.hiddenSize];
|
|
2393
2429
|
return {
|
|
2394
|
-
output:
|
|
2430
|
+
output: chunk724CXHFH_cjs.Tensor.fromTypedArray({
|
|
2395
2431
|
data: out,
|
|
2396
2432
|
shape: outShape,
|
|
2397
2433
|
dtype: input.dtype,
|
|
@@ -2419,7 +2455,7 @@ var LSTM = class extends Module {
|
|
|
2419
2455
|
}
|
|
2420
2456
|
forward(...inputs) {
|
|
2421
2457
|
if (inputs.length < 1 || inputs.length > 3) {
|
|
2422
|
-
throw new
|
|
2458
|
+
throw new chunkZ6BGACIH_cjs.InvalidParameterError(
|
|
2423
2459
|
"LSTM.forward expects 1 to 3 inputs",
|
|
2424
2460
|
"inputs",
|
|
2425
2461
|
inputs.length
|
|
@@ -2427,7 +2463,7 @@ var LSTM = class extends Module {
|
|
|
2427
2463
|
}
|
|
2428
2464
|
const inputArg = inputs[0];
|
|
2429
2465
|
if (inputArg === void 0) {
|
|
2430
|
-
throw new
|
|
2466
|
+
throw new chunkZ6BGACIH_cjs.InvalidParameterError("LSTM.forward requires an input tensor", "input", inputArg);
|
|
2431
2467
|
}
|
|
2432
2468
|
const input = extractTensor(inputArg);
|
|
2433
2469
|
const hxArg = inputs.length >= 2 ? inputs[1] : void 0;
|
|
@@ -2479,19 +2515,19 @@ var GRU = class extends Module {
|
|
|
2479
2515
|
this.biasHh = [];
|
|
2480
2516
|
for (let layer = 0; layer < this.numLayers; layer++) {
|
|
2481
2517
|
const inputDim = layer === 0 ? inputSize : hiddenSize;
|
|
2482
|
-
const wIh =
|
|
2483
|
-
const wHh =
|
|
2518
|
+
const wIh = chunk724CXHFH_cjs.mulScalar(chunk724CXHFH_cjs.randn([3 * hiddenSize, inputDim]), stdv);
|
|
2519
|
+
const wHh = chunk724CXHFH_cjs.mulScalar(chunk724CXHFH_cjs.randn([3 * hiddenSize, hiddenSize]), stdv);
|
|
2484
2520
|
this.weightsIh.push(wIh);
|
|
2485
2521
|
this.weightsHh.push(wHh);
|
|
2486
|
-
this.registerParameter(`weight_ih_l${layer}`,
|
|
2487
|
-
this.registerParameter(`weight_hh_l${layer}`,
|
|
2522
|
+
this.registerParameter(`weight_ih_l${layer}`, chunk724CXHFH_cjs.parameter(wIh));
|
|
2523
|
+
this.registerParameter(`weight_hh_l${layer}`, chunk724CXHFH_cjs.parameter(wHh));
|
|
2488
2524
|
if (this.bias) {
|
|
2489
|
-
const bIh =
|
|
2490
|
-
const bHh =
|
|
2525
|
+
const bIh = chunk724CXHFH_cjs.zeros([3 * hiddenSize]);
|
|
2526
|
+
const bHh = chunk724CXHFH_cjs.zeros([3 * hiddenSize]);
|
|
2491
2527
|
this.biasIh.push(bIh);
|
|
2492
2528
|
this.biasHh.push(bHh);
|
|
2493
|
-
this.registerParameter(`bias_ih_l${layer}`,
|
|
2494
|
-
this.registerParameter(`bias_hh_l${layer}`,
|
|
2529
|
+
this.registerParameter(`bias_ih_l${layer}`, chunk724CXHFH_cjs.parameter(bIh));
|
|
2530
|
+
this.registerParameter(`bias_hh_l${layer}`, chunk724CXHFH_cjs.parameter(bHh));
|
|
2495
2531
|
}
|
|
2496
2532
|
}
|
|
2497
2533
|
}
|
|
@@ -2503,13 +2539,13 @@ var GRU = class extends Module {
|
|
|
2503
2539
|
const parsed = parseInput(input, this.batchFirst);
|
|
2504
2540
|
const { batch, seqLen, inputDim, isUnbatched, batchStride, seqStride, featStride } = parsed;
|
|
2505
2541
|
if (inputDim !== this.inputSize) {
|
|
2506
|
-
throw new
|
|
2542
|
+
throw new chunkZ6BGACIH_cjs.ShapeError(`Expected input size ${this.inputSize}, got ${inputDim}`);
|
|
2507
2543
|
}
|
|
2508
2544
|
if (seqLen <= 0) {
|
|
2509
|
-
throw new
|
|
2545
|
+
throw new chunkZ6BGACIH_cjs.InvalidParameterError("Sequence length must be positive", "seqLen", seqLen);
|
|
2510
2546
|
}
|
|
2511
2547
|
if (!isUnbatched && batch <= 0) {
|
|
2512
|
-
throw new
|
|
2548
|
+
throw new chunkZ6BGACIH_cjs.InvalidParameterError("Batch size must be positive", "batch", batch);
|
|
2513
2549
|
}
|
|
2514
2550
|
const h = buildState(hx, this.numLayers, batch, this.hiddenSize, isUnbatched, "hx");
|
|
2515
2551
|
const outSize = (isUnbatched ? seqLen : batch * seqLen) * this.hiddenSize;
|
|
@@ -2528,12 +2564,12 @@ var GRU = class extends Module {
|
|
|
2528
2564
|
const wIh = this.weightsIh[l];
|
|
2529
2565
|
const wHh = this.weightsHh[l];
|
|
2530
2566
|
if (!wIh || !wHh) {
|
|
2531
|
-
throw new
|
|
2567
|
+
throw new chunkZ6BGACIH_cjs.ShapeError("Internal error: missing GRU weights");
|
|
2532
2568
|
}
|
|
2533
2569
|
const curInputSize = l === 0 ? this.inputSize : this.hiddenSize;
|
|
2534
2570
|
const hLayer = h[l];
|
|
2535
2571
|
if (!hLayer) {
|
|
2536
|
-
throw new
|
|
2572
|
+
throw new chunkZ6BGACIH_cjs.ShapeError("Internal error: missing GRU hidden state");
|
|
2537
2573
|
}
|
|
2538
2574
|
const wIhStride0 = wIh.strides[0] ?? 0;
|
|
2539
2575
|
const wIhStride1 = wIh.strides[1] ?? 0;
|
|
@@ -2594,7 +2630,7 @@ var GRU = class extends Module {
|
|
|
2594
2630
|
}
|
|
2595
2631
|
const outShape = isUnbatched ? [seqLen, this.hiddenSize] : this.batchFirst ? [batch, seqLen, this.hiddenSize] : [seqLen, batch, this.hiddenSize];
|
|
2596
2632
|
return {
|
|
2597
|
-
output:
|
|
2633
|
+
output: chunk724CXHFH_cjs.Tensor.fromTypedArray({
|
|
2598
2634
|
data: out,
|
|
2599
2635
|
shape: outShape,
|
|
2600
2636
|
dtype: input.dtype,
|
|
@@ -2613,11 +2649,11 @@ var GRU = class extends Module {
|
|
|
2613
2649
|
}
|
|
2614
2650
|
forward(...inputs) {
|
|
2615
2651
|
if (inputs.length < 1 || inputs.length > 2) {
|
|
2616
|
-
throw new
|
|
2652
|
+
throw new chunkZ6BGACIH_cjs.InvalidParameterError("GRU.forward expects 1 or 2 inputs", "inputs", inputs.length);
|
|
2617
2653
|
}
|
|
2618
2654
|
const inputArg = inputs[0];
|
|
2619
2655
|
if (inputArg === void 0) {
|
|
2620
|
-
throw new
|
|
2656
|
+
throw new chunkZ6BGACIH_cjs.InvalidParameterError("GRU.forward requires an input tensor", "input", inputArg);
|
|
2621
2657
|
}
|
|
2622
2658
|
const input = extractTensor(inputArg);
|
|
2623
2659
|
const hxArg = inputs.length === 2 ? inputs[1] : void 0;
|
|
@@ -2645,7 +2681,7 @@ function toOneHot(indices, numClasses) {
|
|
|
2645
2681
|
const outData = new Float32Array(nSamples * numClasses);
|
|
2646
2682
|
const data = indices.data;
|
|
2647
2683
|
if (Array.isArray(data)) {
|
|
2648
|
-
throw new
|
|
2684
|
+
throw new chunkZ6BGACIH_cjs.DTypeError("crossEntropyLoss target indices must be numeric");
|
|
2649
2685
|
}
|
|
2650
2686
|
const stride0 = indices.strides[0] ?? 0;
|
|
2651
2687
|
const base = indices.offset;
|
|
@@ -2653,10 +2689,10 @@ function toOneHot(indices, numClasses) {
|
|
|
2653
2689
|
const offset = base + i * stride0;
|
|
2654
2690
|
let idx;
|
|
2655
2691
|
if (data instanceof BigInt64Array) {
|
|
2656
|
-
const raw =
|
|
2692
|
+
const raw = chunkZ6BGACIH_cjs.getBigIntElement(data, offset);
|
|
2657
2693
|
const asNumber = Number(raw);
|
|
2658
2694
|
if (!Number.isSafeInteger(asNumber)) {
|
|
2659
|
-
throw new
|
|
2695
|
+
throw new chunkZ6BGACIH_cjs.InvalidParameterError(
|
|
2660
2696
|
`Class index ${raw.toString()} exceeds safe integer range`,
|
|
2661
2697
|
"target",
|
|
2662
2698
|
raw.toString()
|
|
@@ -2664,13 +2700,13 @@ function toOneHot(indices, numClasses) {
|
|
|
2664
2700
|
}
|
|
2665
2701
|
idx = asNumber;
|
|
2666
2702
|
} else {
|
|
2667
|
-
idx = Number(
|
|
2703
|
+
idx = Number(chunkZ6BGACIH_cjs.getNumericElement(data, offset));
|
|
2668
2704
|
}
|
|
2669
2705
|
if (!Number.isFinite(idx) || !Number.isInteger(idx)) {
|
|
2670
|
-
throw new
|
|
2706
|
+
throw new chunkZ6BGACIH_cjs.InvalidParameterError(`Class index ${idx} is not a valid integer`, "target", idx);
|
|
2671
2707
|
}
|
|
2672
2708
|
if (idx < 0 || idx >= numClasses) {
|
|
2673
|
-
throw new
|
|
2709
|
+
throw new chunkZ6BGACIH_cjs.InvalidParameterError(
|
|
2674
2710
|
`Class index ${idx} out of range [0, ${numClasses})`,
|
|
2675
2711
|
"target",
|
|
2676
2712
|
idx
|
|
@@ -2678,7 +2714,7 @@ function toOneHot(indices, numClasses) {
|
|
|
2678
2714
|
}
|
|
2679
2715
|
outData[i * numClasses + idx] = 1;
|
|
2680
2716
|
}
|
|
2681
|
-
return
|
|
2717
|
+
return chunk724CXHFH_cjs.Tensor.fromTypedArray({
|
|
2682
2718
|
data: outData,
|
|
2683
2719
|
shape: [nSamples, numClasses],
|
|
2684
2720
|
dtype: "float32",
|
|
@@ -2686,62 +2722,59 @@ function toOneHot(indices, numClasses) {
|
|
|
2686
2722
|
});
|
|
2687
2723
|
}
|
|
2688
2724
|
function crossEntropyLoss(input, target) {
|
|
2689
|
-
const yPred =
|
|
2690
|
-
const targetIsGrad =
|
|
2691
|
-
const yTrue =
|
|
2725
|
+
const yPred = chunk724CXHFH_cjs.GradTensor.isGradTensor(input) ? input : chunk724CXHFH_cjs.GradTensor.fromTensor(input);
|
|
2726
|
+
const targetIsGrad = chunk724CXHFH_cjs.GradTensor.isGradTensor(target);
|
|
2727
|
+
const yTrue = chunk724CXHFH_cjs.GradTensor.isGradTensor(target) ? target : chunk724CXHFH_cjs.GradTensor.fromTensor(target, { requiresGrad: false });
|
|
2692
2728
|
if (yPred.ndim !== 2) {
|
|
2693
|
-
throw new
|
|
2729
|
+
throw new chunkZ6BGACIH_cjs.ShapeError(`Input must be 2-dimensional (batch, classes); got ${yPred.ndim}`);
|
|
2694
2730
|
}
|
|
2695
2731
|
const nSamples = yPred.shape[0] ?? 0;
|
|
2696
2732
|
const nClasses = yPred.shape[1] ?? 0;
|
|
2697
2733
|
let targetTensor = yTrue;
|
|
2698
2734
|
if (yTrue.ndim === 1) {
|
|
2699
|
-
if (targetIsGrad) {
|
|
2700
|
-
throw new chunkJSCDE774_cjs.ShapeError("Target must be 2-dimensional when provided as GradTensor");
|
|
2701
|
-
}
|
|
2702
2735
|
if (yTrue.shape[0] !== nSamples) {
|
|
2703
|
-
throw new
|
|
2736
|
+
throw new chunkZ6BGACIH_cjs.ShapeError(
|
|
2704
2737
|
`Target must have same number of samples as input; got ${yTrue.shape[0]} and ${nSamples}`
|
|
2705
2738
|
);
|
|
2706
2739
|
}
|
|
2707
2740
|
const oneHot = toOneHot(yTrue.tensor, nClasses);
|
|
2708
|
-
targetTensor =
|
|
2741
|
+
targetTensor = chunk724CXHFH_cjs.GradTensor.fromTensor(oneHot, { requiresGrad: false });
|
|
2709
2742
|
} else if (yTrue.ndim === 2) {
|
|
2710
2743
|
if (yTrue.shape[0] !== nSamples || yTrue.shape[1] !== nClasses) {
|
|
2711
|
-
throw new
|
|
2744
|
+
throw new chunkZ6BGACIH_cjs.ShapeError(
|
|
2712
2745
|
"Target must be 1-dimensional class indices or have the same shape as input"
|
|
2713
2746
|
);
|
|
2714
2747
|
}
|
|
2715
2748
|
} else {
|
|
2716
|
-
throw new
|
|
2749
|
+
throw new chunkZ6BGACIH_cjs.ShapeError(`Target must be 1D (indices) or 2D (probs); got ${yTrue.ndim}D`);
|
|
2717
2750
|
}
|
|
2718
|
-
const logProbs =
|
|
2751
|
+
const logProbs = chunk724CXHFH_cjs.logSoftmax2(yPred, 1);
|
|
2719
2752
|
const weighted = logProbs.mul(targetTensor);
|
|
2720
2753
|
const sampleLoss = weighted.sum(1);
|
|
2721
2754
|
const meanLoss = sampleLoss.mean().neg();
|
|
2722
|
-
if (!(input
|
|
2755
|
+
if (!chunk724CXHFH_cjs.GradTensor.isGradTensor(input) && !targetIsGrad) {
|
|
2723
2756
|
const data = meanLoss.tensor.data;
|
|
2724
2757
|
if (Array.isArray(data)) {
|
|
2725
|
-
throw new
|
|
2758
|
+
throw new chunkZ6BGACIH_cjs.DTypeError("crossEntropyLoss does not support string dtype");
|
|
2726
2759
|
}
|
|
2727
2760
|
if (data instanceof BigInt64Array) {
|
|
2728
|
-
const raw =
|
|
2761
|
+
const raw = chunkZ6BGACIH_cjs.getBigIntElement(data, meanLoss.tensor.offset);
|
|
2729
2762
|
return Number(raw);
|
|
2730
2763
|
}
|
|
2731
|
-
return
|
|
2764
|
+
return chunkZ6BGACIH_cjs.getNumericElement(data, meanLoss.tensor.offset);
|
|
2732
2765
|
}
|
|
2733
2766
|
return meanLoss;
|
|
2734
2767
|
}
|
|
2735
2768
|
function binaryCrossEntropyWithLogitsLoss(input, target) {
|
|
2736
|
-
const yPred =
|
|
2737
|
-
const yTrue =
|
|
2769
|
+
const yPred = chunk724CXHFH_cjs.GradTensor.isGradTensor(input) ? input : chunk724CXHFH_cjs.GradTensor.fromTensor(input);
|
|
2770
|
+
const yTrue = chunk724CXHFH_cjs.GradTensor.isGradTensor(target) ? target : chunk724CXHFH_cjs.GradTensor.fromTensor(target, { requiresGrad: false });
|
|
2738
2771
|
let pred = yPred;
|
|
2739
2772
|
let truth = yTrue;
|
|
2740
2773
|
if (pred.ndim !== 1 && pred.ndim !== 2) {
|
|
2741
|
-
throw new
|
|
2774
|
+
throw new chunkZ6BGACIH_cjs.ShapeError("Input must be 1 or 2-dimensional");
|
|
2742
2775
|
}
|
|
2743
2776
|
if (truth.ndim !== 1 && truth.ndim !== 2) {
|
|
2744
|
-
throw new
|
|
2777
|
+
throw new chunkZ6BGACIH_cjs.ShapeError("Target must be 1 or 2-dimensional");
|
|
2745
2778
|
}
|
|
2746
2779
|
if (pred.ndim === 1) {
|
|
2747
2780
|
pred = pred.reshape([pred.shape[0] ?? 0, 1]);
|
|
@@ -2750,17 +2783,17 @@ function binaryCrossEntropyWithLogitsLoss(input, target) {
|
|
|
2750
2783
|
truth = truth.reshape([truth.shape[0] ?? 0, 1]);
|
|
2751
2784
|
}
|
|
2752
2785
|
if (pred.ndim !== 2 || pred.shape[1] !== 1) {
|
|
2753
|
-
throw new
|
|
2786
|
+
throw new chunkZ6BGACIH_cjs.ShapeError(`Input must have shape (N,) or (N, 1)`);
|
|
2754
2787
|
}
|
|
2755
2788
|
if (truth.ndim !== 2 || truth.shape[1] !== 1) {
|
|
2756
|
-
throw new
|
|
2789
|
+
throw new chunkZ6BGACIH_cjs.ShapeError(`Target must be 1-dimensional or have shape (N, 1)`);
|
|
2757
2790
|
}
|
|
2758
2791
|
if ((pred.shape[0] ?? 0) !== (truth.shape[0] ?? 0)) {
|
|
2759
|
-
throw new
|
|
2792
|
+
throw new chunkZ6BGACIH_cjs.ShapeError(`Batch size mismatch`);
|
|
2760
2793
|
}
|
|
2761
2794
|
const predDtype = pred.dtype;
|
|
2762
2795
|
if (predDtype === "string") {
|
|
2763
|
-
throw new
|
|
2796
|
+
throw new chunkZ6BGACIH_cjs.DTypeError("Binary cross entropy does not support string dtype");
|
|
2764
2797
|
}
|
|
2765
2798
|
const term1 = pred.relu();
|
|
2766
2799
|
const term2 = pred.mul(truth);
|
|
@@ -2769,21 +2802,21 @@ function binaryCrossEntropyWithLogitsLoss(input, target) {
|
|
|
2769
2802
|
const expNegAbs = absPred.neg().exp();
|
|
2770
2803
|
const scalarDtype = expNegAbs.dtype;
|
|
2771
2804
|
if (scalarDtype === "string") {
|
|
2772
|
-
throw new
|
|
2805
|
+
throw new chunkZ6BGACIH_cjs.DTypeError("binaryCrossEntropyWithLogitsLoss does not support string dtype");
|
|
2773
2806
|
}
|
|
2774
|
-
const one =
|
|
2807
|
+
const one = chunk724CXHFH_cjs.GradTensor.scalar(1, { dtype: scalarDtype });
|
|
2775
2808
|
const term3 = one.add(expNegAbs).log();
|
|
2776
2809
|
const loss = term1.sub(term2).add(term3).mean();
|
|
2777
|
-
if (!(input
|
|
2810
|
+
if (!chunk724CXHFH_cjs.GradTensor.isGradTensor(input) && !chunk724CXHFH_cjs.GradTensor.isGradTensor(target)) {
|
|
2778
2811
|
const data = loss.tensor.data;
|
|
2779
2812
|
if (Array.isArray(data)) {
|
|
2780
|
-
throw new
|
|
2813
|
+
throw new chunkZ6BGACIH_cjs.DTypeError("binaryCrossEntropyWithLogitsLoss does not support string dtype");
|
|
2781
2814
|
}
|
|
2782
2815
|
if (data instanceof BigInt64Array) {
|
|
2783
|
-
const raw =
|
|
2816
|
+
const raw = chunkZ6BGACIH_cjs.getBigIntElement(data, loss.tensor.offset);
|
|
2784
2817
|
return Number(raw);
|
|
2785
2818
|
}
|
|
2786
|
-
return
|
|
2819
|
+
return chunkZ6BGACIH_cjs.getNumericElement(data, loss.tensor.offset);
|
|
2787
2820
|
}
|
|
2788
2821
|
return loss;
|
|
2789
2822
|
}
|
|
@@ -2798,17 +2831,25 @@ function shapesEqual2(a, b) {
|
|
|
2798
2831
|
}
|
|
2799
2832
|
function ensureSameShape(a, b, context) {
|
|
2800
2833
|
if (!shapesEqual2(a.shape, b.shape)) {
|
|
2801
|
-
throw new
|
|
2834
|
+
throw new chunkZ6BGACIH_cjs.ShapeError(`Shape mismatch in ${context}: [${a.shape}] vs [${b.shape}]`);
|
|
2835
|
+
}
|
|
2836
|
+
}
|
|
2837
|
+
function alignShapes(a, b) {
|
|
2838
|
+
if (shapesEqual2(a.shape, b.shape)) return [a, b];
|
|
2839
|
+
if (a.size === b.size) {
|
|
2840
|
+
if (a.ndim > b.ndim) return [chunk724CXHFH_cjs.reshape(a, b.shape), b];
|
|
2841
|
+
if (b.ndim > a.ndim) return [a, chunk724CXHFH_cjs.reshape(b, a.shape)];
|
|
2802
2842
|
}
|
|
2843
|
+
return [a, b];
|
|
2803
2844
|
}
|
|
2804
2845
|
function ensureNumeric(t, context) {
|
|
2805
2846
|
if (t.dtype === "string") {
|
|
2806
|
-
throw new
|
|
2847
|
+
throw new chunkZ6BGACIH_cjs.DTypeError(`${context} does not support string dtype`);
|
|
2807
2848
|
}
|
|
2808
2849
|
}
|
|
2809
2850
|
function validateReduction(reduction, context) {
|
|
2810
2851
|
if (reduction !== "mean" && reduction !== "sum" && reduction !== "none") {
|
|
2811
|
-
throw new
|
|
2852
|
+
throw new chunkZ6BGACIH_cjs.InvalidParameterError(
|
|
2812
2853
|
`${context} reduction must be 'mean', 'sum', or 'none'`,
|
|
2813
2854
|
"reduction",
|
|
2814
2855
|
reduction
|
|
@@ -2816,38 +2857,51 @@ function validateReduction(reduction, context) {
|
|
|
2816
2857
|
}
|
|
2817
2858
|
}
|
|
2818
2859
|
function readNumericFlat(data, flat, logicalStrides, strides, offset) {
|
|
2819
|
-
const dataOffset =
|
|
2820
|
-
return
|
|
2860
|
+
const dataOffset = chunk724CXHFH_cjs.offsetFromFlatIndex(flat, logicalStrides, strides, offset);
|
|
2861
|
+
return chunkZ6BGACIH_cjs.getElementAsNumber(data, dataOffset);
|
|
2821
2862
|
}
|
|
2822
2863
|
function mseLoss(predictions, targets, reduction = "mean") {
|
|
2823
2864
|
validateReduction(reduction, "mseLoss");
|
|
2824
|
-
|
|
2825
|
-
|
|
2826
|
-
|
|
2827
|
-
|
|
2828
|
-
|
|
2865
|
+
if (chunk724CXHFH_cjs.GradTensor.isGradTensor(predictions)) {
|
|
2866
|
+
const pred = predictions;
|
|
2867
|
+
const tgt = chunk724CXHFH_cjs.GradTensor.isGradTensor(targets) ? targets : chunk724CXHFH_cjs.GradTensor.fromTensor(targets, { requiresGrad: false });
|
|
2868
|
+
const diff2 = pred.sub(tgt);
|
|
2869
|
+
const squared = diff2.mul(diff2);
|
|
2870
|
+
if (reduction === "none") return squared;
|
|
2871
|
+
if (reduction === "sum") return squared.sum();
|
|
2872
|
+
return squared.mean();
|
|
2873
|
+
}
|
|
2874
|
+
let preds = predictions;
|
|
2875
|
+
let tgts = chunk724CXHFH_cjs.GradTensor.isGradTensor(targets) ? targets.tensor : targets;
|
|
2876
|
+
ensureNumeric(preds, "mseLoss");
|
|
2877
|
+
ensureNumeric(tgts, "mseLoss");
|
|
2878
|
+
[preds, tgts] = alignShapes(preds, tgts);
|
|
2879
|
+
ensureSameShape(preds, tgts, "mseLoss");
|
|
2880
|
+
const diff = chunk724CXHFH_cjs.sub(preds, tgts);
|
|
2881
|
+
const squaredDiff = chunk724CXHFH_cjs.pow(diff, chunk724CXHFH_cjs.tensor(2, { dtype: diff.dtype, device: diff.device }));
|
|
2829
2882
|
if (reduction === "none") {
|
|
2830
2883
|
return squaredDiff;
|
|
2831
2884
|
}
|
|
2832
2885
|
if (reduction === "sum") {
|
|
2833
|
-
return
|
|
2886
|
+
return chunk724CXHFH_cjs.sum(squaredDiff);
|
|
2834
2887
|
}
|
|
2835
|
-
return
|
|
2888
|
+
return chunk724CXHFH_cjs.mean(squaredDiff);
|
|
2836
2889
|
}
|
|
2837
2890
|
function maeLoss(predictions, targets, reduction = "mean") {
|
|
2838
2891
|
validateReduction(reduction, "maeLoss");
|
|
2839
2892
|
ensureNumeric(predictions, "maeLoss");
|
|
2840
2893
|
ensureNumeric(targets, "maeLoss");
|
|
2894
|
+
[predictions, targets] = alignShapes(predictions, targets);
|
|
2841
2895
|
ensureSameShape(predictions, targets, "maeLoss");
|
|
2842
|
-
const diff =
|
|
2843
|
-
const absDiff =
|
|
2896
|
+
const diff = chunk724CXHFH_cjs.sub(predictions, targets);
|
|
2897
|
+
const absDiff = chunk724CXHFH_cjs.abs(diff);
|
|
2844
2898
|
if (reduction === "none") {
|
|
2845
2899
|
return absDiff;
|
|
2846
2900
|
}
|
|
2847
2901
|
if (reduction === "sum") {
|
|
2848
|
-
return
|
|
2902
|
+
return chunk724CXHFH_cjs.sum(absDiff);
|
|
2849
2903
|
}
|
|
2850
|
-
return
|
|
2904
|
+
return chunk724CXHFH_cjs.mean(absDiff);
|
|
2851
2905
|
}
|
|
2852
2906
|
function binaryCrossEntropyLoss(predictions, targets, reduction = "mean") {
|
|
2853
2907
|
validateReduction(reduction, "binaryCrossEntropyLoss");
|
|
@@ -2855,50 +2909,51 @@ function binaryCrossEntropyLoss(predictions, targets, reduction = "mean") {
|
|
|
2855
2909
|
ensureNumeric(targets, "binaryCrossEntropyLoss");
|
|
2856
2910
|
ensureSameShape(predictions, targets, "binaryCrossEntropyLoss");
|
|
2857
2911
|
const epsilon = 1e-7;
|
|
2858
|
-
const predClamped =
|
|
2859
|
-
const logPred =
|
|
2860
|
-
const term1 =
|
|
2861
|
-
const one =
|
|
2912
|
+
const predClamped = chunk724CXHFH_cjs.clip(predictions, epsilon, 1 - epsilon);
|
|
2913
|
+
const logPred = chunk724CXHFH_cjs.log(predClamped);
|
|
2914
|
+
const term1 = chunk724CXHFH_cjs.mul(targets, logPred);
|
|
2915
|
+
const one = chunk724CXHFH_cjs.tensor(1, {
|
|
2862
2916
|
dtype: predictions.dtype === "float64" ? "float64" : "float32",
|
|
2863
2917
|
device: predictions.device
|
|
2864
2918
|
});
|
|
2865
|
-
const oneMinusTargets =
|
|
2866
|
-
const oneMinusPred =
|
|
2867
|
-
const logOneMinusPred =
|
|
2868
|
-
const term2 =
|
|
2869
|
-
const loss =
|
|
2919
|
+
const oneMinusTargets = chunk724CXHFH_cjs.sub(one, targets);
|
|
2920
|
+
const oneMinusPred = chunk724CXHFH_cjs.sub(one, predClamped);
|
|
2921
|
+
const logOneMinusPred = chunk724CXHFH_cjs.log(oneMinusPred);
|
|
2922
|
+
const term2 = chunk724CXHFH_cjs.mul(oneMinusTargets, logOneMinusPred);
|
|
2923
|
+
const loss = chunk724CXHFH_cjs.neg(chunk724CXHFH_cjs.add(term1, term2));
|
|
2870
2924
|
if (reduction === "none") {
|
|
2871
2925
|
return loss;
|
|
2872
2926
|
}
|
|
2873
2927
|
if (reduction === "sum") {
|
|
2874
|
-
return
|
|
2928
|
+
return chunk724CXHFH_cjs.sum(loss);
|
|
2875
2929
|
}
|
|
2876
|
-
return
|
|
2930
|
+
return chunk724CXHFH_cjs.mean(loss);
|
|
2877
2931
|
}
|
|
2878
2932
|
function rmseLoss(predictions, targets) {
|
|
2879
2933
|
ensureNumeric(predictions, "rmseLoss");
|
|
2880
2934
|
ensureNumeric(targets, "rmseLoss");
|
|
2881
2935
|
ensureSameShape(predictions, targets, "rmseLoss");
|
|
2882
2936
|
const mse = mseLoss(predictions, targets, "mean");
|
|
2883
|
-
return
|
|
2937
|
+
return chunk724CXHFH_cjs.sqrt(mse);
|
|
2884
2938
|
}
|
|
2885
2939
|
function huberLoss(predictions, targets, delta = 1, reduction = "mean") {
|
|
2886
2940
|
validateReduction(reduction, "huberLoss");
|
|
2887
2941
|
ensureNumeric(predictions, "huberLoss");
|
|
2888
2942
|
ensureNumeric(targets, "huberLoss");
|
|
2943
|
+
[predictions, targets] = alignShapes(predictions, targets);
|
|
2889
2944
|
ensureSameShape(predictions, targets, "huberLoss");
|
|
2890
2945
|
if (!Number.isFinite(delta) || delta <= 0) {
|
|
2891
|
-
throw new
|
|
2946
|
+
throw new chunkZ6BGACIH_cjs.InvalidParameterError(`delta must be positive; got ${delta}`, "delta", delta);
|
|
2892
2947
|
}
|
|
2893
|
-
const diff =
|
|
2894
|
-
const absDiff =
|
|
2948
|
+
const diff = chunk724CXHFH_cjs.sub(predictions, targets);
|
|
2949
|
+
const absDiff = chunk724CXHFH_cjs.abs(diff);
|
|
2895
2950
|
const absData = absDiff.data;
|
|
2896
2951
|
if (Array.isArray(absData)) {
|
|
2897
|
-
throw new
|
|
2952
|
+
throw new chunkZ6BGACIH_cjs.DTypeError("huberLoss does not support string dtype");
|
|
2898
2953
|
}
|
|
2899
2954
|
const dtype = predictions.dtype === "float64" ? "float64" : "float32";
|
|
2900
2955
|
const lossData = dtype === "float64" ? new Float64Array(diff.size) : new Float32Array(diff.size);
|
|
2901
|
-
const logicalStrides =
|
|
2956
|
+
const logicalStrides = chunk724CXHFH_cjs.computeStrides(absDiff.shape);
|
|
2902
2957
|
for (let i = 0; i < diff.size; i++) {
|
|
2903
2958
|
const absVal = readNumericFlat(absData, i, logicalStrides, absDiff.strides, absDiff.offset);
|
|
2904
2959
|
if (absVal <= delta) {
|
|
@@ -2907,7 +2962,7 @@ function huberLoss(predictions, targets, delta = 1, reduction = "mean") {
|
|
|
2907
2962
|
lossData[i] = delta * (absVal - 0.5 * delta);
|
|
2908
2963
|
}
|
|
2909
2964
|
}
|
|
2910
|
-
const loss =
|
|
2965
|
+
const loss = chunk724CXHFH_cjs.Tensor.fromTypedArray({
|
|
2911
2966
|
data: lossData,
|
|
2912
2967
|
shape: predictions.shape,
|
|
2913
2968
|
dtype,
|
|
@@ -2917,9 +2972,9 @@ function huberLoss(predictions, targets, delta = 1, reduction = "mean") {
|
|
|
2917
2972
|
return loss;
|
|
2918
2973
|
}
|
|
2919
2974
|
if (reduction === "sum") {
|
|
2920
|
-
return
|
|
2975
|
+
return chunk724CXHFH_cjs.sum(loss);
|
|
2921
2976
|
}
|
|
2922
|
-
return
|
|
2977
|
+
return chunk724CXHFH_cjs.mean(loss);
|
|
2923
2978
|
}
|
|
2924
2979
|
|
|
2925
2980
|
exports.AvgPool2d = AvgPool2d;
|
|
@@ -2956,5 +3011,5 @@ exports.maeLoss = maeLoss;
|
|
|
2956
3011
|
exports.mseLoss = mseLoss;
|
|
2957
3012
|
exports.nn_exports = nn_exports;
|
|
2958
3013
|
exports.rmseLoss = rmseLoss;
|
|
2959
|
-
//# sourceMappingURL=chunk-
|
|
2960
|
-
//# sourceMappingURL=chunk-
|
|
3014
|
+
//# sourceMappingURL=chunk-EUZHJDZ6.cjs.map
|
|
3015
|
+
//# sourceMappingURL=chunk-EUZHJDZ6.cjs.map
|