deepbox 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +21 -0
- package/README.md +344 -0
- package/dist/CSRMatrix-CwGwQRea.d.cts +219 -0
- package/dist/CSRMatrix-KzNt6QpS.d.ts +219 -0
- package/dist/Tensor-BQLk1ltW.d.cts +147 -0
- package/dist/Tensor-g8mUClel.d.ts +147 -0
- package/dist/chunk-4S73VUBD.js +677 -0
- package/dist/chunk-4S73VUBD.js.map +1 -0
- package/dist/chunk-5R4S63PF.js +2925 -0
- package/dist/chunk-5R4S63PF.js.map +1 -0
- package/dist/chunk-6AE5FKKQ.cjs +9264 -0
- package/dist/chunk-6AE5FKKQ.cjs.map +1 -0
- package/dist/chunk-AD436M45.js +3854 -0
- package/dist/chunk-AD436M45.js.map +1 -0
- package/dist/chunk-ALS7ETWZ.cjs +4263 -0
- package/dist/chunk-ALS7ETWZ.cjs.map +1 -0
- package/dist/chunk-AU7XHGKJ.js +2092 -0
- package/dist/chunk-AU7XHGKJ.js.map +1 -0
- package/dist/chunk-B5TNKUEY.js +1481 -0
- package/dist/chunk-B5TNKUEY.js.map +1 -0
- package/dist/chunk-BCR7G3A6.js +9136 -0
- package/dist/chunk-BCR7G3A6.js.map +1 -0
- package/dist/chunk-C4PKXY74.cjs +1917 -0
- package/dist/chunk-C4PKXY74.cjs.map +1 -0
- package/dist/chunk-DWZY6PIP.cjs +6400 -0
- package/dist/chunk-DWZY6PIP.cjs.map +1 -0
- package/dist/chunk-E3EU5FZO.cjs +2113 -0
- package/dist/chunk-E3EU5FZO.cjs.map +1 -0
- package/dist/chunk-F3JWBINJ.js +1054 -0
- package/dist/chunk-F3JWBINJ.js.map +1 -0
- package/dist/chunk-FJYLIGJX.js +1940 -0
- package/dist/chunk-FJYLIGJX.js.map +1 -0
- package/dist/chunk-JSCDE774.cjs +729 -0
- package/dist/chunk-JSCDE774.cjs.map +1 -0
- package/dist/chunk-LWECRCW2.cjs +2412 -0
- package/dist/chunk-LWECRCW2.cjs.map +1 -0
- package/dist/chunk-MLBMYKCG.js +6379 -0
- package/dist/chunk-MLBMYKCG.js.map +1 -0
- package/dist/chunk-OX6QXFMV.cjs +3874 -0
- package/dist/chunk-OX6QXFMV.cjs.map +1 -0
- package/dist/chunk-PHV2DKRS.cjs +1072 -0
- package/dist/chunk-PHV2DKRS.cjs.map +1 -0
- package/dist/chunk-PL7TAYKI.js +4056 -0
- package/dist/chunk-PL7TAYKI.js.map +1 -0
- package/dist/chunk-PR647I7R.js +1898 -0
- package/dist/chunk-PR647I7R.js.map +1 -0
- package/dist/chunk-QERHVCHC.cjs +2960 -0
- package/dist/chunk-QERHVCHC.cjs.map +1 -0
- package/dist/chunk-XEG44RF6.cjs +1514 -0
- package/dist/chunk-XEG44RF6.cjs.map +1 -0
- package/dist/chunk-XMWVME2W.js +2377 -0
- package/dist/chunk-XMWVME2W.js.map +1 -0
- package/dist/chunk-ZB75FESB.cjs +1979 -0
- package/dist/chunk-ZB75FESB.cjs.map +1 -0
- package/dist/chunk-ZLW62TJG.cjs +4061 -0
- package/dist/chunk-ZLW62TJG.cjs.map +1 -0
- package/dist/chunk-ZXKBDFP3.js +4235 -0
- package/dist/chunk-ZXKBDFP3.js.map +1 -0
- package/dist/core/index.cjs +204 -0
- package/dist/core/index.cjs.map +1 -0
- package/dist/core/index.d.cts +2 -0
- package/dist/core/index.d.ts +2 -0
- package/dist/core/index.js +3 -0
- package/dist/core/index.js.map +1 -0
- package/dist/dataframe/index.cjs +22 -0
- package/dist/dataframe/index.cjs.map +1 -0
- package/dist/dataframe/index.d.cts +3 -0
- package/dist/dataframe/index.d.ts +3 -0
- package/dist/dataframe/index.js +5 -0
- package/dist/dataframe/index.js.map +1 -0
- package/dist/datasets/index.cjs +134 -0
- package/dist/datasets/index.cjs.map +1 -0
- package/dist/datasets/index.d.cts +3 -0
- package/dist/datasets/index.d.ts +3 -0
- package/dist/datasets/index.js +5 -0
- package/dist/datasets/index.js.map +1 -0
- package/dist/index-74AB8Cyh.d.cts +1126 -0
- package/dist/index-9oQx1HgV.d.cts +1180 -0
- package/dist/index-BJY2SI4i.d.ts +483 -0
- package/dist/index-BWGhrDlr.d.ts +733 -0
- package/dist/index-B_DK4FKY.d.cts +242 -0
- package/dist/index-BbA2Gxfl.d.ts +456 -0
- package/dist/index-BgHYAoSS.d.cts +837 -0
- package/dist/index-BndMbqsM.d.ts +1439 -0
- package/dist/index-C1mfVYoo.d.ts +2517 -0
- package/dist/index-CCvlwAmL.d.cts +809 -0
- package/dist/index-CDw5CnOU.d.ts +785 -0
- package/dist/index-Cn3SdB0O.d.ts +1126 -0
- package/dist/index-CrqLlS-a.d.ts +776 -0
- package/dist/index-D61yaSMY.d.cts +483 -0
- package/dist/index-D9Loo1_A.d.cts +2517 -0
- package/dist/index-DIT_OO9C.d.cts +785 -0
- package/dist/index-DIp_RrRt.d.ts +242 -0
- package/dist/index-DbultU6X.d.cts +1427 -0
- package/dist/index-DmEg_LCm.d.cts +776 -0
- package/dist/index-DoPWVxPo.d.cts +1439 -0
- package/dist/index-DuCxd-8d.d.ts +837 -0
- package/dist/index-Dx42TZaY.d.ts +809 -0
- package/dist/index-DyZ4QQf5.d.cts +456 -0
- package/dist/index-GFAVyOWO.d.ts +1427 -0
- package/dist/index-WHQLn0e8.d.cts +733 -0
- package/dist/index-ZtI1Iy4L.d.ts +1180 -0
- package/dist/index-eJgeni9c.d.cts +1911 -0
- package/dist/index-tk4lSYod.d.ts +1911 -0
- package/dist/index.cjs +72 -0
- package/dist/index.cjs.map +1 -0
- package/dist/index.d.cts +17 -0
- package/dist/index.d.ts +17 -0
- package/dist/index.js +15 -0
- package/dist/index.js.map +1 -0
- package/dist/linalg/index.cjs +86 -0
- package/dist/linalg/index.cjs.map +1 -0
- package/dist/linalg/index.d.cts +3 -0
- package/dist/linalg/index.d.ts +3 -0
- package/dist/linalg/index.js +5 -0
- package/dist/linalg/index.js.map +1 -0
- package/dist/metrics/index.cjs +158 -0
- package/dist/metrics/index.cjs.map +1 -0
- package/dist/metrics/index.d.cts +3 -0
- package/dist/metrics/index.d.ts +3 -0
- package/dist/metrics/index.js +5 -0
- package/dist/metrics/index.js.map +1 -0
- package/dist/ml/index.cjs +87 -0
- package/dist/ml/index.cjs.map +1 -0
- package/dist/ml/index.d.cts +3 -0
- package/dist/ml/index.d.ts +3 -0
- package/dist/ml/index.js +6 -0
- package/dist/ml/index.js.map +1 -0
- package/dist/ndarray/index.cjs +501 -0
- package/dist/ndarray/index.cjs.map +1 -0
- package/dist/ndarray/index.d.cts +5 -0
- package/dist/ndarray/index.d.ts +5 -0
- package/dist/ndarray/index.js +4 -0
- package/dist/ndarray/index.js.map +1 -0
- package/dist/nn/index.cjs +142 -0
- package/dist/nn/index.cjs.map +1 -0
- package/dist/nn/index.d.cts +6 -0
- package/dist/nn/index.d.ts +6 -0
- package/dist/nn/index.js +5 -0
- package/dist/nn/index.js.map +1 -0
- package/dist/optim/index.cjs +77 -0
- package/dist/optim/index.cjs.map +1 -0
- package/dist/optim/index.d.cts +4 -0
- package/dist/optim/index.d.ts +4 -0
- package/dist/optim/index.js +4 -0
- package/dist/optim/index.js.map +1 -0
- package/dist/plot/index.cjs +114 -0
- package/dist/plot/index.cjs.map +1 -0
- package/dist/plot/index.d.cts +6 -0
- package/dist/plot/index.d.ts +6 -0
- package/dist/plot/index.js +5 -0
- package/dist/plot/index.js.map +1 -0
- package/dist/preprocess/index.cjs +82 -0
- package/dist/preprocess/index.cjs.map +1 -0
- package/dist/preprocess/index.d.cts +4 -0
- package/dist/preprocess/index.d.ts +4 -0
- package/dist/preprocess/index.js +5 -0
- package/dist/preprocess/index.js.map +1 -0
- package/dist/random/index.cjs +74 -0
- package/dist/random/index.cjs.map +1 -0
- package/dist/random/index.d.cts +3 -0
- package/dist/random/index.d.ts +3 -0
- package/dist/random/index.js +5 -0
- package/dist/random/index.js.map +1 -0
- package/dist/stats/index.cjs +142 -0
- package/dist/stats/index.cjs.map +1 -0
- package/dist/stats/index.d.cts +3 -0
- package/dist/stats/index.d.ts +3 -0
- package/dist/stats/index.js +5 -0
- package/dist/stats/index.js.map +1 -0
- package/dist/tensor-B96jjJLQ.d.cts +205 -0
- package/dist/tensor-B96jjJLQ.d.ts +205 -0
- package/package.json +226 -0
|
@@ -0,0 +1,2960 @@
|
|
|
1
|
+
'use strict';
|
|
2
|
+
|
|
3
|
+
var chunk6AE5FKKQ_cjs = require('./chunk-6AE5FKKQ.cjs');
|
|
4
|
+
var chunkJSCDE774_cjs = require('./chunk-JSCDE774.cjs');
|
|
5
|
+
|
|
6
|
+
// src/nn/index.ts
// Namespace export map for the nn module. `__export` (from the shared CJS
// helper chunk) installs lazy getters on `nn_exports`, so each arrow is only
// evaluated when the corresponding export is first accessed — this lets the
// map be declared before the classes/functions below are defined.
var nn_exports = {};
chunkJSCDE774_cjs.__export(nn_exports, {
  AvgPool2d: () => AvgPool2d,
  BatchNorm1d: () => BatchNorm1d,
  Conv1d: () => Conv1d,
  Conv2d: () => Conv2d,
  Dropout: () => Dropout,
  ELU: () => ELU,
  GELU: () => GELU,
  GRU: () => GRU,
  LSTM: () => LSTM,
  LayerNorm: () => LayerNorm,
  LeakyReLU: () => LeakyReLU,
  Linear: () => Linear,
  LogSoftmax: () => LogSoftmax,
  MaxPool2d: () => MaxPool2d,
  Mish: () => Mish,
  Module: () => Module,
  MultiheadAttention: () => MultiheadAttention,
  RNN: () => RNN,
  ReLU: () => ReLU,
  Sequential: () => Sequential,
  Sigmoid: () => Sigmoid,
  Softmax: () => Softmax,
  Softplus: () => Softplus,
  Swish: () => Swish,
  Tanh: () => Tanh,
  TransformerEncoderLayer: () => TransformerEncoderLayer,
  binaryCrossEntropyLoss: () => binaryCrossEntropyLoss,
  binaryCrossEntropyWithLogitsLoss: () => binaryCrossEntropyWithLogitsLoss,
  crossEntropyLoss: () => crossEntropyLoss,
  huberLoss: () => huberLoss,
  maeLoss: () => maeLoss,
  mseLoss: () => mseLoss,
  rmseLoss: () => rmseLoss
});
|
|
43
|
+
|
|
44
|
+
// src/nn/module/Module.ts
|
|
45
|
+
/**
 * Element-wise comparison of two shape arrays.
 *
 * A missing entry (`undefined`) is treated as 0 on either side, matching the
 * `?? 0` fallback the bundler emitted from noUncheckedIndexedAccess code.
 *
 * @param a - First shape (array of dimension sizes)
 * @param b - Second shape (array of dimension sizes)
 * @returns true when both shapes have the same length and equal dimensions
 */
function shapesEqual(a, b) {
  if (a.length !== b.length) return false;
  // Order of comparison does not matter for equality, so walk backwards.
  for (let axis = a.length - 1; axis >= 0; axis--) {
    const lhs = a[axis] ?? 0;
    const rhs = b[axis] ?? 0;
    if (lhs !== rhs) return false;
  }
  return true;
}
|
|
52
|
+
/**
 * Compute the total number of elements implied by a shape.
 *
 * Each dimension must be a non-negative integer; the first offending
 * dimension aborts the computation with a ShapeError.
 *
 * @param shape - Array of dimension sizes
 * @param context - Human-readable label used in the error message
 * @returns Product of all dimensions (1 for an empty shape)
 * @throws {ShapeError} If any dimension is negative or non-integer
 */
function sizeFromShape(shape, context) {
  return shape.reduce((product, dim) => {
    if (!Number.isInteger(dim) || dim < 0) {
      throw new chunkJSCDE774_cjs.ShapeError(`${context} contains invalid dimension ${String(dim)}`);
    }
    return product * dim;
  }, 1);
}
|
|
62
|
+
/**
 * Copy a tensor's backing storage into a fresh plain array.
 *
 * Plain arrays are shallow-copied with slice(); BigInt64Array is converted
 * via Array.from (preserving bigint values); other typed arrays are copied
 * element by element into a plain array of numbers.
 *
 * @param t - Object with a `data` field (Array, BigInt64Array, or TypedArray)
 * @returns New array containing a copy of the elements
 * @throws {DeepboxError} If an element read unexpectedly yields undefined
 */
function cloneTensorData(t) {
  const { data } = t;
  if (Array.isArray(data)) return data.slice();
  if (data instanceof BigInt64Array) return Array.from(data);
  // Remaining case: a numeric typed array — copy into a plain array.
  const copy = new Array(data.length);
  let idx = 0;
  while (idx < data.length) {
    const element = data[idx];
    if (element === void 0) {
      throw new chunkJSCDE774_cjs.DeepboxError("Internal error: tensor data access out of bounds");
    }
    copy[idx] = element;
    idx += 1;
  }
  return copy;
}
|
|
80
|
+
/**
 * Verify that a serialized state entry's flat data length matches the
 * element count implied by its shape.
 *
 * @param name - Parameter/buffer name (for error messages)
 * @param kind - Either "parameter" or "buffer" (for error messages)
 * @param entry - State entry with `data` and `shape` fields
 * @throws {ShapeError} If data length and shape size disagree
 */
function validateStateEntryShape(name, kind, entry) {
  const expected = sizeFromShape(entry.shape, `${kind} ${name} shape`);
  const actual = entry.data.length;
  if (actual === expected) return;
  throw new chunkJSCDE774_cjs.ShapeError(
    `${kind} ${name} data length ${actual} does not match shape size ${expected}`
  );
}
|
|
88
|
+
/**
 * Copy a serialized state entry's flat data into an existing (possibly
 * strided / offset) tensor in place.
 *
 * Validation order is observable and deliberate: shape mismatch (ShapeError)
 * is reported before dtype mismatch (DTypeError), which is reported before
 * any per-element type check.
 *
 * The flat index `i` of the source data is mapped through the target's
 * logical (contiguous) strides to a physical storage offset, so views with
 * non-default strides/offset are written correctly.
 * NOTE(review): assumes `computeStrides` returns row-major (C-contiguous)
 * strides for `target.shape` — confirm against the chunk helper.
 *
 * @param name - Parameter/buffer name (for error messages)
 * @param kind - Either "parameter" or "buffer" (for error messages)
 * @param target - Destination tensor (has shape, dtype, strides, offset, data)
 * @param entry - Source state entry (flat data, shape, dtype)
 * @throws {ShapeError} If shapes differ
 * @throws {DTypeError} If dtypes differ or elements have the wrong JS type
 */
function copyStateEntryIntoTensor(name, kind, target, entry) {
  if (!shapesEqual(target.shape, entry.shape)) {
    throw new chunkJSCDE774_cjs.ShapeError(
      `${kind} ${name} shape mismatch: expected [${target.shape.join(", ")}], got [${entry.shape.join(", ")}]`
    );
  }
  if (target.dtype !== entry.dtype) {
    throw new chunkJSCDE774_cjs.DTypeError(
      `${kind} ${name} dtype mismatch: expected ${target.dtype}, got ${entry.dtype}`
    );
  }
  const size = sizeFromShape(entry.shape, `${kind} ${name} shape`);
  // Strides of a hypothetical contiguous tensor of the same shape; used to
  // translate flat source indices into multi-dimensional coordinates.
  const logicalStrides = chunk6AE5FKKQ_cjs.computeStrides(target.shape);
  const data = target.data;
  // String tensors are always backed by plain arrays.
  if (target.dtype === "string") {
    if (!Array.isArray(data)) {
      throw new chunkJSCDE774_cjs.DTypeError(`${kind} ${name} expected string data`);
    }
    for (let i = 0; i < size; i++) {
      const value = entry.data[i];
      if (typeof value !== "string") {
        throw new chunkJSCDE774_cjs.DTypeError(`${kind} ${name} expects string data`);
      }
      const offset = chunk6AE5FKKQ_cjs.offsetFromFlatIndex(i, logicalStrides, target.strides, target.offset);
      data[offset] = value;
    }
    return;
  }
  // int64 tensors are backed by BigInt64Array and require bigint elements.
  if (data instanceof BigInt64Array) {
    for (let i = 0; i < size; i++) {
      const value = entry.data[i];
      if (typeof value !== "bigint") {
        throw new chunkJSCDE774_cjs.DTypeError(`${kind} ${name} expects bigint data`);
      }
      const offset = chunk6AE5FKKQ_cjs.offsetFromFlatIndex(i, logicalStrides, target.strides, target.offset);
      data[offset] = value;
    }
    return;
  }
  // At this point a plain array would indicate a non-string dtype backed by
  // the wrong storage — reject it.
  if (Array.isArray(data)) {
    throw new chunkJSCDE774_cjs.DTypeError(`${kind} ${name} expected numeric data`);
  }
  // Remaining case: numeric typed array.
  for (let i = 0; i < size; i++) {
    const value = entry.data[i];
    if (typeof value !== "number") {
      throw new chunkJSCDE774_cjs.DTypeError(`${kind} ${name} expects numeric data`);
    }
    const offset = chunk6AE5FKKQ_cjs.offsetFromFlatIndex(i, logicalStrides, target.strides, target.offset);
    data[offset] = value;
  }
}
|
|
139
|
+
/**
 * Base class for all neural-network modules (PyTorch-style API).
 *
 * A module owns named parameters (trainable GradTensors), named buffers
 * (non-trainable tensors), and named child modules; iteration methods walk
 * this tree recursively. Subclasses implement `forward(...)` (not defined
 * here — presumably on subclasses; verify against callers of `call`).
 */
var Module = class _Module {
  /** Child modules registered to this module - stores nested layers/modules */
  _modules = /* @__PURE__ */ new Map();
  /** Parameters of this module - trainable tensors (weights, biases) wrapped as GradTensor */
  _parameters = /* @__PURE__ */ new Map();
  /** Buffers (non-trainable tensors) of this module - e.g., running stats in BatchNorm */
  _buffers = /* @__PURE__ */ new Map();
  /** Training mode flag - affects behavior of layers like Dropout and BatchNorm */
  _training = true;
  /** Forward pre-hooks registered on this module */
  _forwardPreHooks = /* @__PURE__ */ new Map();
  /** Forward hooks registered on this module */
  _forwardHooks = /* @__PURE__ */ new Map();
  /** Incrementing hook id */
  _nextHookId = 0;
  /**
   * Makes the module callable (allows using `module(x)` instead of `module.forward(x)`).
   *
   * Pre-hooks may replace the inputs (by returning an array); forward hooks
   * may replace the output (by returning a non-undefined value). Hooks run
   * in registration order.
   *
   * @param inputs - Input tensors (Tensor or GradTensor)
   * @returns Output tensor
   */
  call(...inputs) {
    let curInputs = inputs;
    for (const hook of this._forwardPreHooks.values()) {
      const result = hook(this, curInputs);
      if (Array.isArray(result)) {
        curInputs = result;
      }
    }
    let output = this.forward(...curInputs);
    for (const hook of this._forwardHooks.values()) {
      const result = hook(this, curInputs, output);
      if (result !== void 0) {
        output = result;
      }
    }
    return output;
  }
  /**
   * Register a child module.
   *
   * Re-registering an existing name replaces the previous child.
   *
   * @param name - Name of the module
   * @param module - The module to register
   */
  registerModule(name, module) {
    this._modules.set(name, module);
  }
  /**
   * Register a parameter (trainable tensor).
   *
   * Parameters must be GradTensor instances with requiresGrad=true for
   * proper gradient computation during backpropagation.
   *
   * @param name - Name of the parameter
   * @param param - The parameter tensor (must be GradTensor)
   */
  registerParameter(name, param) {
    this._parameters.set(name, param);
  }
  /**
   * Register a buffer (non-trainable tensor).
   *
   * Buffers are typically used for running statistics in batch normalization.
   *
   * @param name - Name of the buffer
   * @param buffer - The buffer tensor
   */
  registerBuffer(name, buffer) {
    this._buffers.set(name, buffer);
  }
  /**
   * Get all parameters of this module and its children.
   *
   * Returns GradTensor instances that are compatible with optimizers.
   * This enables direct usage with optimizer constructors:
   * ```ts
   * const optimizer = new Adam(model.parameters());
   * ```
   *
   * @param recurse - Whether to include parameters of child modules
   * @returns Iterator of GradTensor parameters
   */
  *parameters(recurse = true) {
    for (const param of this._parameters.values()) {
      yield param;
    }
    if (recurse) {
      for (const module of this._modules.values()) {
        yield* module.parameters(true);
      }
    }
  }
  /**
   * Get all named parameters of this module and its children.
   *
   * Names are dot-joined paths, e.g. "fc1.weight".
   *
   * @param prefix - Prefix for parameter names
   * @param recurse - Whether to include parameters of child modules
   * @returns Iterator of [name, parameter] pairs
   */
  *namedParameters(prefix = "", recurse = true) {
    for (const [name, param] of this._parameters.entries()) {
      const fullName = prefix ? `${prefix}.${name}` : name;
      yield [fullName, param];
    }
    if (recurse) {
      for (const [moduleName, module] of this._modules.entries()) {
        const fullPrefix = prefix ? `${prefix}.${moduleName}` : moduleName;
        yield* module.namedParameters(fullPrefix, true);
      }
    }
  }
  /**
   * Iterate over this module itself followed by its child modules.
   *
   * Note: the first yielded value is `this`, matching PyTorch's
   * `Module.modules()` semantics.
   *
   * @param recurse - Whether to include nested child modules
   * @returns Iterator of modules
   */
  *modules(recurse = true) {
    yield this;
    if (recurse) {
      for (const module of this._modules.values()) {
        yield* module.modules(true);
      }
    }
  }
  /**
   * Iterate over [name, module] pairs, starting with this module itself
   * (whose name is the bare prefix, "" at the top level).
   *
   * @param prefix - Prefix for module names
   * @param recurse - Whether to include nested child modules
   * @returns Iterator of [name, module] pairs
   */
  *namedModules(prefix = "", recurse = true) {
    yield [prefix, this];
    if (recurse) {
      for (const [name, module] of this._modules.entries()) {
        const fullName = prefix ? `${prefix}.${name}` : name;
        yield* module.namedModules(fullName, true);
      }
    }
  }
  /**
   * Set the module in training mode.
   *
   * This affects certain layers like Dropout and BatchNorm.
   * The mode is propagated recursively to all child modules.
   *
   * @param mode - Training mode (true) or evaluation mode (false)
   * @returns this
   */
  train(mode = true) {
    this._training = mode;
    for (const module of this._modules.values()) {
      module.train(mode);
    }
    return this;
  }
  /**
   * Set the module in evaluation mode.
   *
   * This is equivalent to calling `train(false)`.
   *
   * @returns this
   */
  eval() {
    return this.train(false);
  }
  /**
   * Check if the module is in training mode.
   *
   * @returns true if in training mode
   */
  get training() {
    return this._training;
  }
  /**
   * Zero out the gradients of all parameters.
   *
   * Call this before each training iteration to prevent gradient accumulation
   * from previous iterations.
   *
   * For parameters wrapped in GradTensor, this calls zeroGrad() on each.
   * For regular Tensors, this is a no-op until they are converted to GradTensor.
   *
   * @example
   * ```ts
   * model.zeroGrad();
   * const output = model.forward(input);
   * // ... compute loss and backward
   * optimizer.step();
   * ```
   */
  zeroGrad() {
    for (const param of this.parameters()) {
      param.zeroGrad();
    }
  }
  /**
   * Get all buffers of this module and its children.
   */
  *buffers(recurse = true) {
    for (const buffer of this._buffers.values()) {
      yield buffer;
    }
    if (recurse) {
      for (const module of this._modules.values()) {
        yield* module.buffers(true);
      }
    }
  }
  /**
   * Get all named buffers of this module and its children.
   * Names are dot-joined paths like those of namedParameters().
   */
  *namedBuffers(prefix = "", recurse = true) {
    for (const [name, buffer] of this._buffers.entries()) {
      const fullName = prefix ? `${prefix}.${name}` : name;
      yield [fullName, buffer];
    }
    if (recurse) {
      for (const [moduleName, module] of this._modules.entries()) {
        const fullPrefix = prefix ? `${prefix}.${moduleName}` : moduleName;
        yield* module.namedBuffers(fullPrefix, true);
      }
    }
  }
  /**
   * Freeze specific parameters by name (or all if none provided).
   *
   * **⚠️ IMPORTANT**: This method creates new GradTensor instances with updated
   * `requiresGrad` flags. Any external references to the old parameter objects
   * will become stale. If you're using an optimizer that holds parameter references,
   * you should recreate the optimizer after freezing/unfreezing parameters.
   *
   * @param names - Array of parameter names to freeze (e.g., ['fc1.weight']). If undefined, freezes all parameters.
   * @param recurse - Whether to include parameters from child modules (default: true)
   *
   * @example
   * ```ts
   * const model = new MyModel();
   * // Freeze only the first layer's weights
   * model.freezeParameters(['fc1.weight']);
   * // Note: Recreate optimizer after freezing
   * const optimizer = new Adam(model.parameters());
   * ```
   */
  freezeParameters(names, recurse = true) {
    this.setRequiresGradForNames(names, false, recurse);
  }
  /**
   * Unfreeze specific parameters by name (or all if none provided).
   *
   * **⚠️ IMPORTANT**: This method creates new GradTensor instances with updated
   * `requiresGrad` flags. Any external references to the old parameter objects
   * will become stale. If you're using an optimizer that holds parameter references,
   * you should recreate the optimizer after freezing/unfreezing parameters.
   *
   * @param names - Array of parameter names to unfreeze (e.g., ['fc1.weight']). If undefined, unfreezes all parameters.
   * @param recurse - Whether to include parameters from child modules (default: true)
   *
   * @example
   * ```ts
   * const model = new MyModel();
   * model.freezeParameters(); // Freeze all
   * model.unfreezeParameters(['fc2.weight']); // Unfreeze only fc2 weights
   * // Note: Recreate optimizer after unfreezing
   * const optimizer = new Adam(model.parameters());
   * ```
   */
  unfreezeParameters(names, recurse = true) {
    this.setRequiresGradForNames(names, true, recurse);
  }
  /**
   * Shared implementation of freeze/unfreeze: rebuild each targeted
   * parameter as a new GradTensor with the given requiresGrad flag.
   *
   * When explicit names were provided, an unknown name throws; when names
   * were omitted (all parameters targeted), unresolvable names are skipped.
   * Any own enumerable property of the owning module that referenced the old
   * GradTensor (e.g. `this.weight`) is updated to point at the replacement.
   *
   * @param names - Dot-joined parameter names, or undefined for all
   * @param requiresGrad - New flag value for the rebuilt parameters
   * @param recurse - Whether the "all parameters" default includes children
   */
  setRequiresGradForNames(names, requiresGrad, recurse) {
    const providedNames = names !== void 0;
    const targetNames = names ?? Array.from(this.namedParameters("", recurse)).map(([name]) => name);
    for (const name of targetNames) {
      const resolved = this.resolveModuleAndName(name);
      if (!resolved) {
        if (providedNames) {
          throw new chunkJSCDE774_cjs.InvalidParameterError(`Unknown parameter name: ${name}`, "names", name);
        }
        continue;
      }
      const { module, localName } = resolved;
      const param = module._parameters.get(localName);
      if (!param) {
        if (providedNames) {
          throw new chunkJSCDE774_cjs.InvalidParameterError(`Unknown parameter name: ${name}`, "names", name);
        }
        continue;
      }
      const nextParam = chunk6AE5FKKQ_cjs.GradTensor.fromTensor(param.tensor, { requiresGrad });
      module._parameters.set(localName, nextParam);
      // Keep direct instance fields (e.g. `this.weight`) in sync with the
      // replaced Map entry.
      for (const [key, value] of Object.entries(module)) {
        if (value === param) {
          Reflect.set(module, key, nextParam);
        }
      }
    }
  }
  /**
   * Walk a dot-joined path (e.g. "fc1.weight") down the child-module tree.
   *
   * @param fullName - Dot-joined parameter/buffer path
   * @returns The owning module and the final path segment, or null if any
   *          intermediate module is missing
   */
  resolveModuleAndName(fullName) {
    const parts = fullName.split(".");
    let module = this;
    for (let i = 0; i < parts.length - 1; i++) {
      const part = parts[i] ?? "";
      const child = module._modules.get(part);
      if (!child) return null;
      module = child;
    }
    const localName = parts[parts.length - 1] ?? "";
    return { module, localName };
  }
  /**
   * Overwrite the `device` metadata field on a tensor.
   * Uses Reflect.set so a failed write (e.g. a non-writable property)
   * surfaces as an explicit error instead of silently doing nothing.
   */
  static setTensorDeviceMetadata(target, device) {
    if (!Reflect.set(target, "device", device)) {
      throw new chunkJSCDE774_cjs.DeepboxError("Failed to update tensor device metadata");
    }
  }
  /**
   * Get the state dictionary of the module.
   *
   * Returns deep copies of all parameter/buffer data keyed by dot-joined
   * name, so mutating the result does not affect the module.
   */
  stateDict() {
    const parameters = {};
    const buffers = {};
    for (const [name, param] of this.namedParameters()) {
      const t = param.tensor;
      const data = cloneTensorData(t);
      parameters[name] = {
        data,
        shape: [...t.shape],
        dtype: t.dtype
      };
    }
    for (const [name, buffer] of this.namedBuffers()) {
      const data = cloneTensorData(buffer);
      buffers[name] = {
        data,
        shape: [...buffer.shape],
        dtype: buffer.dtype
      };
    }
    return { parameters, buffers };
  }
  /**
   * Load state dictionary into the module.
   *
   * Strict matching: every current parameter/buffer must be present in the
   * dict and the dict may not contain unknown names; all four checks run
   * before any data is copied, so a failing load leaves the module untouched.
   */
  loadStateDict(stateDict) {
    const parameters = stateDict.parameters ?? {};
    const buffers = stateDict.buffers ?? {};
    const namedParams = new Map(this.namedParameters());
    const namedBuffs = new Map(this.namedBuffers());
    for (const name of namedParams.keys()) {
      if (!(name in parameters)) {
        throw new chunkJSCDE774_cjs.InvalidParameterError(`missing parameter: ${name}`, "stateDict.parameters", name);
      }
    }
    for (const name of namedBuffs.keys()) {
      if (!(name in buffers)) {
        throw new chunkJSCDE774_cjs.InvalidParameterError(`missing buffer: ${name}`, "stateDict.buffers", name);
      }
    }
    for (const name of Object.keys(parameters)) {
      if (!namedParams.has(name)) {
        throw new chunkJSCDE774_cjs.InvalidParameterError(
          `unexpected parameter: ${name}`,
          "stateDict.parameters",
          name
        );
      }
    }
    for (const name of Object.keys(buffers)) {
      if (!namedBuffs.has(name)) {
        throw new chunkJSCDE774_cjs.InvalidParameterError(`unexpected buffer: ${name}`, "stateDict.buffers", name);
      }
    }
    for (const [name, entry] of Object.entries(parameters)) {
      const param = namedParams.get(name);
      if (!param) continue;
      validateStateEntryShape(name, "parameter", entry);
      copyStateEntryIntoTensor(name, "parameter", param.tensor, entry);
    }
    for (const [name, entry] of Object.entries(buffers)) {
      const buffer = namedBuffs.get(name);
      if (!buffer) continue;
      validateStateEntryShape(name, "buffer", entry);
      copyStateEntryIntoTensor(name, "buffer", buffer, entry);
    }
  }
  /**
   * Move module to a specific device.
   *
   * **⚠️ WARNING**: This is a metadata-only operation. It updates the device
   * property on parameters and buffers but does NOT actually transfer data
   * between devices. Actual device data transfer requires device-specific
   * memory management which is not yet implemented.
   *
   * This method is provided for API compatibility and future extensibility.
   * Currently, it only updates the `device` metadata field.
   *
   * @param device - Target device identifier (e.g., 'cpu', 'webgpu', 'wasm')
   * @returns this module for method chaining
   *
   * @example
   * ```ts
   * const model = new Linear(10, 5);
   * model.to('webgpu'); // Updates device metadata only
   * ```
   */
  to(device) {
    if (!chunkJSCDE774_cjs.isDevice(device)) {
      throw new chunkJSCDE774_cjs.InvalidParameterError("device must be one of: cpu, webgpu, wasm", "device", device);
    }
    for (const param of this.parameters()) {
      _Module.setTensorDeviceMetadata(param.tensor, device);
    }
    for (const buffer of this.buffers()) {
      _Module.setTensorDeviceMetadata(buffer, device);
    }
    return this;
  }
  /**
   * Apply a function to all modules recursively (including this one).
   */
  apply(fn) {
    for (const module of this.modules()) {
      fn(module);
    }
    return this;
  }
  /**
   * Register a forward pre-hook.
   * @returns A disposer function that removes the hook when called.
   */
  registerForwardPreHook(hook) {
    const hookId = this._nextHookId++;
    this._forwardPreHooks.set(hookId, hook);
    return () => {
      this._forwardPreHooks.delete(hookId);
    };
  }
  /**
   * Register a forward hook.
   * @returns A disposer function that removes the hook when called.
   */
  registerForwardHook(hook) {
    const hookId = this._nextHookId++;
    this._forwardHooks.set(hookId, hook);
    return () => {
      this._forwardHooks.delete(hookId);
    };
  }
  /**
   * Get string representation of the module.
   *
   * @returns Hierarchical string representation showing module structure
   */
  toString() {
    const lines = [`${this.constructor.name}(`];
    for (const [name, module] of this._modules.entries()) {
      const childLines = module.toString().split("\n");
      const moduleStr = childLines.map((line, i) => i === 0 ? line : `  ${line}`).join("\n");
      lines.push(`  (${name}): ${moduleStr}`);
    }
    lines.push(")");
    return lines.join("\n");
  }
};
|
|
601
|
+
|
|
602
|
+
// src/nn/containers/Sequential.ts
|
|
603
|
+
var Sequential = class extends Module {
  /** Array of layers in sequential order */
  layers;
  /**
   * Create a new Sequential container.
   *
   * @param layers - Variable number of Module instances to stack sequentially
   * @throws {InvalidParameterError} If no layers are provided
   * @throws {DeepboxError} If a layer is undefined
   */
  constructor(...layers) {
    super();
    if (layers.length === 0) {
      throw new chunkJSCDE774_cjs.InvalidParameterError(
        "Sequential requires at least one layer",
        "layers",
        layers.length
      );
    }
    this.layers = layers;
    // Register each layer as a child module keyed by its position.
    layers.forEach((layer, i) => {
      if (!layer) {
        throw new chunkJSCDE774_cjs.DeepboxError(`Layer at index ${i} is undefined`);
      }
      this.registerModule(String(i), layer);
    });
  }
  /**
   * Forward pass: sequentially apply all layers.
   *
   * The output of each layer becomes the input to the next layer.
   *
   * @param inputs - Exactly one input tensor (Tensor or GradTensor)
   * @returns Output tensor after passing through all layers
   * @throws {InvalidParameterError} If the input count is invalid or a layer returns multiple outputs
   * @throws {DeepboxError} If a layer is undefined
   */
  forward(...inputs) {
    if (inputs.length !== 1) {
      throw new chunkJSCDE774_cjs.InvalidParameterError(
        "Sequential.forward expects a single input tensor",
        "inputs",
        inputs.length
      );
    }
    const input = inputs[0];
    if (!input) {
      throw new chunkJSCDE774_cjs.InvalidParameterError(
        "Sequential.forward expects a single input tensor",
        "input",
        input
      );
    }
    let output = input;
    for (const [i, layer] of this.layers.entries()) {
      if (!layer) {
        throw new chunkJSCDE774_cjs.DeepboxError(`Layer at index ${i} is undefined`);
      }
      const result = layer.call(output);
      // A layer returning an array would break the single-tensor chaining contract.
      if (Array.isArray(result)) {
        throw new chunkJSCDE774_cjs.InvalidParameterError(
          `Sequential does not support layers that return multiple tensors (layer ${i})`,
          "layer",
          i
        );
      }
      output = result;
    }
    return output;
  }
  /**
   * Get a layer by index.
   *
   * @param index - Zero-based index of the layer
   * @returns The layer at the specified index
   * @throws {IndexError} If index is out of bounds
   * @throws {DeepboxError} If a layer is undefined
   */
  getLayer(index) {
    const count = this.layers.length;
    if (index < 0 || index >= count) {
      throw new chunkJSCDE774_cjs.IndexError(`Layer index ${index} out of bounds [0, ${count})`, {
        index,
        validRange: [0, count - 1]
      });
    }
    const layer = this.layers[index];
    if (!layer) {
      throw new chunkJSCDE774_cjs.DeepboxError(`Layer at index ${index} is undefined`);
    }
    return layer;
  }
  /**
   * Get the number of layers in the sequential container.
   */
  get length() {
    return this.layers.length;
  }
  /**
   * Get string representation showing all layers.
   *
   * @returns Multi-line string with each layer on a separate line
   */
  toString() {
    const rendered = ["Sequential("];
    this.layers.forEach((layer, i) => {
      if (!layer) return;
      const [first, ...rest] = layer.toString().split("\n");
      const layerStr = [first, ...rest.map((line) => ` ${line}`)].join("\n");
      rendered.push(` (${i}): ${layerStr}`);
    });
    rendered.push(")");
    return rendered.join("\n");
  }
  /**
   * Iterate over all layers.
   *
   * @returns Iterator of layers
   */
  *[Symbol.iterator]() {
    yield* this.layers;
  }
};
|
|
730
|
+
|
|
731
|
+
// src/nn/layers/activations.ts
|
|
732
|
+
// Rectified Linear Unit: element-wise max(0, x).
var ReLU = class extends Module {
  /** Routes GradTensor inputs through the autograd op; plain tensors use the functional form. */
  forward(input) {
    if (!(input instanceof chunk6AE5FKKQ_cjs.GradTensor)) {
      return chunk6AE5FKKQ_cjs.relu(input);
    }
    return input.relu();
  }
  toString() {
    return "ReLU()";
  }
};
|
|
741
|
+
// Logistic sigmoid activation, applied element-wise.
var Sigmoid = class extends Module {
  /** Routes GradTensor inputs through the autograd op; plain tensors use the functional form. */
  forward(input) {
    if (!(input instanceof chunk6AE5FKKQ_cjs.GradTensor)) {
      return chunk6AE5FKKQ_cjs.sigmoid(input);
    }
    return input.sigmoid();
  }
  toString() {
    return "Sigmoid()";
  }
};
|
|
750
|
+
// Hyperbolic tangent activation, applied element-wise.
var Tanh = class extends Module {
  /** Routes GradTensor inputs through the autograd op; plain tensors use the functional form. */
  forward(input) {
    if (!(input instanceof chunk6AE5FKKQ_cjs.GradTensor)) {
      return chunk6AE5FKKQ_cjs.tanh(input);
    }
    return input.tanh();
  }
  toString() {
    return "Tanh()";
  }
};
|
|
759
|
+
// Leaky ReLU: x for x > 0, alpha * x otherwise.
var LeakyReLU = class extends Module {
  /** Negative-slope coefficient. */
  alpha;
  constructor(alpha = 0.01) {
    super();
    this.alpha = alpha;
  }
  /** Routes GradTensor inputs through the autograd op; plain tensors use the functional form. */
  forward(input) {
    if (!(input instanceof chunk6AE5FKKQ_cjs.GradTensor)) {
      return chunk6AE5FKKQ_cjs.leakyRelu(input, this.alpha);
    }
    return input.leakyRelu(this.alpha);
  }
  toString() {
    return `LeakyReLU(alpha=${this.alpha})`;
  }
};
|
|
773
|
+
// Exponential Linear Unit activation.
var ELU = class extends Module {
  /** Saturation coefficient for the negative branch. */
  alpha;
  constructor(alpha = 1) {
    super();
    this.alpha = alpha;
  }
  /** Routes GradTensor inputs through the autograd op; plain tensors use the functional form. */
  forward(input) {
    if (!(input instanceof chunk6AE5FKKQ_cjs.GradTensor)) {
      return chunk6AE5FKKQ_cjs.elu(input, this.alpha);
    }
    return input.elu(this.alpha);
  }
  toString() {
    return `ELU(alpha=${this.alpha})`;
  }
};
|
|
787
|
+
// Gaussian Error Linear Unit activation.
var GELU = class extends Module {
  /** Routes GradTensor inputs through the autograd op; plain tensors use the functional form. */
  forward(input) {
    if (!(input instanceof chunk6AE5FKKQ_cjs.GradTensor)) {
      return chunk6AE5FKKQ_cjs.gelu(input);
    }
    return input.gelu();
  }
  toString() {
    return "GELU()";
  }
};
|
|
796
|
+
// Softmax over a configurable axis (default: last).
var Softmax = class extends Module {
  /** Axis along which probabilities are computed; may be negative. */
  axis;
  constructor(axis = -1) {
    super();
    this.axis = axis;
  }
  forward(input) {
    if (input instanceof chunk6AE5FKKQ_cjs.GradTensor) {
      // The autograd variant needs a non-negative axis.
      const axis = chunkJSCDE774_cjs.normalizeAxis(this.axis, input.tensor.ndim);
      return chunk6AE5FKKQ_cjs.softmax2(input, axis);
    }
    return chunk6AE5FKKQ_cjs.softmax(input, this.axis);
  }
  toString() {
    return `Softmax(axis=${this.axis})`;
  }
};
|
|
812
|
+
// Log-softmax over a configurable axis (default: last).
var LogSoftmax = class extends Module {
  /** Axis along which log-probabilities are computed; may be negative. */
  axis;
  constructor(axis = -1) {
    super();
    this.axis = axis;
  }
  forward(input) {
    if (input instanceof chunk6AE5FKKQ_cjs.GradTensor) {
      // The autograd variant needs a non-negative axis.
      const axis = chunkJSCDE774_cjs.normalizeAxis(this.axis, input.tensor.ndim);
      return chunk6AE5FKKQ_cjs.logSoftmax2(input, axis);
    }
    return chunk6AE5FKKQ_cjs.logSoftmax(input, this.axis);
  }
  toString() {
    return `LogSoftmax(axis=${this.axis})`;
  }
};
|
|
828
|
+
// Softplus activation: log(1 + exp(x)).
var Softplus = class extends Module {
  forward(input) {
    if (!(input instanceof chunk6AE5FKKQ_cjs.GradTensor)) {
      return chunk6AE5FKKQ_cjs.softplus(input);
    }
    // NOTE(review): no autograd op exists for softplus here — the result is
    // detached (requiresGrad: false), so gradients do not flow through it.
    const result = chunk6AE5FKKQ_cjs.softplus(input.tensor);
    return chunk6AE5FKKQ_cjs.GradTensor.fromTensor(result, {
      requiresGrad: false
    });
  }
  toString() {
    return "Softplus()";
  }
};
|
|
841
|
+
// Swish (SiLU) activation: x * sigmoid(x).
var Swish = class extends Module {
  forward(input) {
    if (!(input instanceof chunk6AE5FKKQ_cjs.GradTensor)) {
      return chunk6AE5FKKQ_cjs.swish(input);
    }
    // NOTE(review): no autograd op exists for swish here — the result is
    // detached (requiresGrad: false), so gradients do not flow through it.
    const result = chunk6AE5FKKQ_cjs.swish(input.tensor);
    return chunk6AE5FKKQ_cjs.GradTensor.fromTensor(result, {
      requiresGrad: false
    });
  }
  toString() {
    return "Swish()";
  }
};
|
|
854
|
+
// Mish activation: x * tanh(softplus(x)).
var Mish = class extends Module {
  forward(input) {
    if (!(input instanceof chunk6AE5FKKQ_cjs.GradTensor)) {
      return chunk6AE5FKKQ_cjs.mish(input);
    }
    // NOTE(review): no autograd op exists for mish here — the result is
    // detached (requiresGrad: false), so gradients do not flow through it.
    const result = chunk6AE5FKKQ_cjs.mish(input.tensor);
    return chunk6AE5FKKQ_cjs.GradTensor.fromTensor(result, {
      requiresGrad: false
    });
  }
  toString() {
    return "Mish()";
  }
};
|
|
867
|
+
|
|
868
|
+
// src/nn/layers/dropout.ts
|
|
869
|
+
var Dropout = class extends Module {
  /** Probability of an element being zeroed (dropout rate) */
  p;
  /**
   * Create a new Dropout layer.
   *
   * @param p - Probability of an element being zeroed (0 <= p < 1)
   * @throws {InvalidParameterError} If p is not in valid range [0, 1)
   */
  constructor(p = 0.5) {
    super();
    const isValidRate = Number.isFinite(p) && p >= 0 && p < 1;
    if (!isValidRate) {
      throw new chunkJSCDE774_cjs.InvalidParameterError(`Dropout probability must be in [0, 1), got ${p}`, "p", p);
    }
    this.p = p;
  }
  /**
   * Forward pass: apply dropout during training, identity during evaluation.
   *
   * Plain tensors are wrapped in a GradTensor before delegating to the
   * functional dropout, which also receives the current training flag.
   *
   * @param input - Input tensor of any shape (Tensor or GradTensor)
   * @returns Output tensor with same shape as input
   * @throws {DTypeError} For string-dtype inputs
   */
  forward(input) {
    const wrapped = input instanceof chunk6AE5FKKQ_cjs.GradTensor ? input : chunk6AE5FKKQ_cjs.GradTensor.fromTensor(input);
    if (wrapped.dtype === "string") {
      throw new chunkJSCDE774_cjs.DTypeError("Dropout does not support string dtype");
    }
    return chunk6AE5FKKQ_cjs.dropout(wrapped, this.p, this.training);
  }
  /**
   * Get string representation of the layer.
   *
   * @returns String representation with dropout probability
   */
  toString() {
    return `Dropout(p=${this.p})`;
  }
  /**
   * Get the dropout probability.
   */
  get dropoutRate() {
    return this.p;
  }
};
|
|
913
|
+
|
|
914
|
+
// src/nn/layers/linear.ts
|
|
915
|
+
// Fully-connected affine layer: y = x @ W^T + b.
var Linear = class extends Module {
  /** Weight matrix of shape (out_features, in_features) */
  weight;
  /** Trainable wrapper around `weight` registered with the module. */
  weightParam;
  /** Bias vector of shape (out_features,) */
  bias;
  /** Trainable wrapper around `bias`; undefined when bias is disabled. */
  biasParam;
  /** Number of input features */
  inFeatures;
  /** Number of output features */
  outFeatures;
  /** Whether this layer has a bias */
  useBias;
  /**
   * Create a new Linear layer.
   *
   * Weights are drawn from randn and scaled by sqrt(2 / inFeatures)
   * (He-style initialization); the bias, when enabled, starts at zero.
   *
   * @param inFeatures - Size of each input sample
   * @param outFeatures - Size of each output sample
   * @param options - Configuration options
   * @param options.bias - If true, add learnable bias (default: true)
   * @param options.dtype - Data type for weights (default: 'float32')
   * @param options.device - Device to place tensors on (default: 'cpu')
   * @throws {InvalidParameterError} If inFeatures or outFeatures is not a positive integer
   */
  constructor(inFeatures, outFeatures, options = {}) {
    super();
    if (inFeatures <= 0 || !Number.isInteger(inFeatures)) {
      throw new chunkJSCDE774_cjs.InvalidParameterError(
        "inFeatures must be a positive integer",
        "inFeatures",
        inFeatures
      );
    }
    if (outFeatures <= 0 || !Number.isInteger(outFeatures)) {
      throw new chunkJSCDE774_cjs.InvalidParameterError(
        "outFeatures must be a positive integer",
        "outFeatures",
        outFeatures
      );
    }
    this.inFeatures = inFeatures;
    this.outFeatures = outFeatures;
    this.useBias = options.bias ?? true;
    // He-style scaling keeps activation variance stable for ReLU-family nets.
    const stdDev = Math.sqrt(2 / inFeatures);
    const weightTensor = chunk6AE5FKKQ_cjs.randn([outFeatures, inFeatures], {
      dtype: options.dtype ?? "float32",
      device: options.device ?? "cpu"
    });
    const scaledWeight = chunk6AE5FKKQ_cjs.mulScalar(weightTensor, stdDev);
    this.weightParam = chunk6AE5FKKQ_cjs.parameter(scaledWeight);
    this.weight = this.weightParam.tensor;
    this.registerParameter("weight", this.weightParam);
    if (this.useBias) {
      const biasTensor = chunk6AE5FKKQ_cjs.zeros([outFeatures], {
        dtype: options.dtype ?? "float32",
        device: options.device ?? "cpu"
      });
      this.biasParam = chunk6AE5FKKQ_cjs.parameter(biasTensor);
      this.bias = this.biasParam.tensor;
      this.registerParameter("bias", this.biasParam);
    }
  }
  /**
   * Forward pass: flatten leading dims, matmul with W^T, restore shape, add bias.
   *
   * Accepts 1D inputs (treated as a single sample) or N-D inputs whose last
   * dimension equals inFeatures. GradTensor inputs take the autograd path so
   * gradients flow to the parameters; plain tensors use the functional ops.
   *
   * @param input - Input tensor (Tensor or GradTensor), last dim = inFeatures
   * @returns Output with the last dimension replaced by outFeatures
   * @throws {DTypeError} For string-dtype inputs
   * @throws {ShapeError} For 0-D input or a feature-dimension mismatch
   */
  forward(input) {
    const inputTensor = input instanceof chunk6AE5FKKQ_cjs.GradTensor ? input.tensor : input;
    if (inputTensor.dtype === "string") {
      throw new chunkJSCDE774_cjs.DTypeError("Linear layer does not support string dtype");
    }
    if (inputTensor.ndim < 1) {
      throw new chunkJSCDE774_cjs.ShapeError(`Linear layer expects at least 1D input; got ndim=${inputTensor.ndim}`);
    }
    const inputFeatures = inputTensor.shape[inputTensor.shape.length - 1] ?? 0;
    if (inputFeatures !== this.inFeatures) {
      throw new chunkJSCDE774_cjs.ShapeError(
        `Linear layer expects ${this.inFeatures} input features; got ${inputFeatures}`
      );
    }
    const isVectorInput = inputTensor.ndim === 1;
    // All leading dims are collapsed into one batch dimension for the matmul.
    const batchSize = inputTensor.size / this.inFeatures;
    const outputShape = isVectorInput ? [this.outFeatures] : [...inputTensor.shape.slice(0, -1), this.outFeatures];
    if (input instanceof chunk6AE5FKKQ_cjs.GradTensor) {
      // Autograd path: operate on GradTensors so backward() reaches W and b.
      const input2d2 = input.reshape([batchSize, this.inFeatures]);
      const output2d2 = input2d2.matmul(this.weightParam.transpose());
      let output2 = output2d2.reshape(outputShape);
      if (this.useBias && this.biasParam) {
        output2 = output2.add(this.biasParam);
      }
      return output2;
    }
    // Plain-tensor path: functional ops, no gradient tracking.
    const input2d = chunk6AE5FKKQ_cjs.reshape(inputTensor, [batchSize, this.inFeatures]);
    const output2d = chunk6AE5FKKQ_cjs.dot(input2d, chunk6AE5FKKQ_cjs.transpose(this.weight));
    const output = chunk6AE5FKKQ_cjs.reshape(output2d, outputShape);
    if (this.useBias && this.bias) {
      return chunk6AE5FKKQ_cjs.add(output, this.bias);
    }
    return output;
  }
  /**
   * Get extra representation string for this layer.
   *
   * @returns String representation of layer parameters
   */
  toString() {
    const biasStr = this.useBias ? "bias=true" : "bias=false";
    return `Linear(in_features=${this.inFeatures}, out_features=${this.outFeatures}, ${biasStr})`;
  }
  /**
   * Get the weight matrix.
   *
   * @returns Weight tensor of shape (out_features, in_features)
   */
  getWeight() {
    return this.weight;
  }
  /**
   * Get the bias vector.
   *
   * @returns Bias tensor of shape (out_features,) or undefined if no bias
   */
  getBias() {
    return this.bias;
  }
  /**
   * Get the number of input features.
   */
  get inputSize() {
    return this.inFeatures;
  }
  /**
   * Get the number of output features.
   */
  get outputSize() {
    return this.outFeatures;
  }
};
|
|
1048
|
+
|
|
1049
|
+
// src/nn/layers/normalization.ts
|
|
1050
|
+
/**
 * Materialize a tensor into contiguous (row-major) storage.
 *
 * Returns the input unchanged when it is already contiguous; otherwise copies
 * element-by-element into a freshly allocated typed array, honoring the
 * tensor's strides and offset (e.g. for transposed or sliced views).
 *
 * @param t - Tensor to densify
 * @returns A contiguous tensor with the same shape, dtype and device
 * @throws {DTypeError} For string-dtype tensors or mismatched int64 buffers
 */
function toContiguousTensor(t) {
  if (chunk6AE5FKKQ_cjs.isContiguous(t.shape, t.strides)) {
    return t;
  }
  if (t.dtype === "string") {
    throw new chunkJSCDE774_cjs.DTypeError("Normalization does not support string dtype");
  }
  const Ctor = chunkJSCDE774_cjs.dtypeToTypedArrayCtor(t.dtype);
  const out = new Ctor(t.size);
  // Strides the tensor WOULD have if contiguous; used to map each flat output
  // index back to a multi-dimensional coordinate, then into the source buffer.
  const logicalStrides = chunk6AE5FKKQ_cjs.computeStrides(t.shape);
  const data = t.data;
  if (Array.isArray(data)) {
    // A plain JS array backing store implies string data, which is unsupported.
    throw new chunkJSCDE774_cjs.DTypeError("Normalization does not support string dtype");
  }
  if (data instanceof BigInt64Array) {
    // int64 path: source and destination buffers must both be BigInt64Array.
    if (!(out instanceof BigInt64Array)) {
      throw new chunkJSCDE774_cjs.DTypeError("Expected int64 output buffer for int64 tensor");
    }
    for (let i = 0; i < t.size; i++) {
      const offset = chunk6AE5FKKQ_cjs.offsetFromFlatIndex(i, logicalStrides, t.strides, t.offset);
      out[i] = chunkJSCDE774_cjs.getBigIntElement(data, offset);
    }
  } else {
    if (out instanceof BigInt64Array) {
      throw new chunkJSCDE774_cjs.DTypeError("Unexpected int64 output buffer for numeric tensor");
    }
    for (let i = 0; i < t.size; i++) {
      const offset = chunk6AE5FKKQ_cjs.offsetFromFlatIndex(i, logicalStrides, t.strides, t.offset);
      out[i] = chunkJSCDE774_cjs.getNumericElement(data, offset);
    }
  }
  return chunk6AE5FKKQ_cjs.Tensor.fromTypedArray({
    data: out,
    shape: t.shape,
    dtype: t.dtype,
    device: t.device
  });
}
|
|
1088
|
+
// Batch normalization over the feature dimension of 2D (N, C) or 3D (N, C, L) input.
var BatchNorm1d = class extends Module {
  /** Number of features / channels (C) this layer normalizes. */
  numFeatures;
  /** Small constant added to the variance for numerical stability. */
  eps;
  /** Exponential-moving-average factor for the running statistics. */
  momentum;
  /** Whether learnable per-feature scale (gamma) and shift (beta) are used. */
  affine;
  /** Whether running mean/var are maintained and used in eval mode. */
  trackRunningStats;
  /** Learnable scale parameter; undefined when affine is false. */
  gamma;
  /** Learnable shift parameter; undefined when affine is false. */
  beta;
  /** Running mean (non-trainable GradTensor), updated during training. */
  runningMean;
  /** Running variance (non-trainable GradTensor), updated during training. */
  runningVar;
  /**
   * Create a BatchNorm1d layer.
   *
   * @param numFeatures - Number of features C (positive integer)
   * @param options - { eps, momentum, affine, trackRunningStats }
   * @throws {InvalidParameterError} For invalid numFeatures, eps, or momentum
   */
  constructor(numFeatures, options = {}) {
    super();
    if (!Number.isFinite(numFeatures) || numFeatures <= 0 || Math.trunc(numFeatures) !== numFeatures) {
      throw new chunkJSCDE774_cjs.InvalidParameterError(
        "numFeatures must be a positive integer",
        "numFeatures",
        numFeatures
      );
    }
    this.numFeatures = numFeatures;
    this.eps = options.eps ?? 1e-5;
    if (!Number.isFinite(this.eps) || this.eps <= 0) {
      throw new chunkJSCDE774_cjs.InvalidParameterError("eps must be a positive number", "eps", this.eps);
    }
    this.momentum = options.momentum ?? 0.1;
    if (!Number.isFinite(this.momentum) || this.momentum < 0 || this.momentum > 1) {
      throw new chunkJSCDE774_cjs.InvalidParameterError(
        "momentum must be in range [0, 1]",
        "momentum",
        this.momentum
      );
    }
    this.affine = options.affine ?? true;
    this.trackRunningStats = options.trackRunningStats ?? true;
    if (this.affine) {
      // gamma starts at 1 and beta at 0 so the layer is initially an identity
      // (after normalization).
      const gamma = chunk6AE5FKKQ_cjs.ones([numFeatures]);
      const beta = chunk6AE5FKKQ_cjs.zeros([numFeatures]);
      this.gamma = chunk6AE5FKKQ_cjs.parameter(gamma);
      this.beta = chunk6AE5FKKQ_cjs.parameter(beta);
      this.registerParameter("weight", this.gamma);
      this.registerParameter("bias", this.beta);
    }
    // Running stats exist even when not tracked, so eval-mode code paths have
    // tensors to read; they are only registered as buffers when tracking.
    this.runningMean = chunk6AE5FKKQ_cjs.GradTensor.fromTensor(chunk6AE5FKKQ_cjs.zeros([numFeatures]), {
      requiresGrad: false
    });
    this.runningVar = chunk6AE5FKKQ_cjs.GradTensor.fromTensor(chunk6AE5FKKQ_cjs.ones([numFeatures]), {
      requiresGrad: false
    });
    if (this.trackRunningStats) {
      this.registerBuffer("running_mean", this.runningMean.tensor);
      this.registerBuffer("running_var", this.runningVar.tensor);
    }
  }
  /**
   * Normalize the input over the batch (and, for 3D input, length) dimensions.
   *
   * Training (or untracked-stats) mode uses batch statistics and updates the
   * running averages; eval mode with tracked stats uses the running averages.
   *
   * @param x - (N, C) or (N, C, L) tensor or GradTensor
   * @returns Normalized tensor of the same shape
   * @throws {DTypeError} For string dtype
   * @throws {ShapeError} For wrong rank or feature-count mismatch
   * @throws {InvalidParameterError} When batch statistics are requested on empty input
   */
  forward(x) {
    const input = x instanceof chunk6AE5FKKQ_cjs.GradTensor ? x : chunk6AE5FKKQ_cjs.GradTensor.fromTensor(x);
    const inputDtype = input.dtype;
    if (inputDtype === "string") {
      throw new chunkJSCDE774_cjs.DTypeError("BatchNorm1d does not support string dtype");
    }
    if (input.ndim !== 2 && input.ndim !== 3) {
      throw new chunkJSCDE774_cjs.ShapeError(`BatchNorm1d expects 2D or 3D input; got ndim=${input.ndim}`);
    }
    const nFeatures = input.shape[1] ?? 0;
    if (nFeatures !== this.numFeatures) {
      throw new chunkJSCDE774_cjs.ShapeError(`Expected ${this.numFeatures} features, got ${nFeatures}`);
    }
    const useBatchStats = this.training || !this.trackRunningStats;
    let mean2;
    let varVal;
    let inputReshaped = input;
    if (input.ndim === 3) {
      // Collapse (N, C, L) to (N*L, C) so statistics are computed per feature.
      const batch = input.shape[0] ?? 0;
      const length = input.shape[2] ?? 0;
      const flat = batch * length;
      const numericInputDtype = chunkJSCDE774_cjs.ensureNumericDType(inputDtype, "BatchNorm1d");
      // The mul-by-1 forces a materialized (contiguous) tensor so the reshape
      // after transpose is valid — presumably reshape requires contiguity; TODO confirm.
      inputReshaped = input.transpose([0, 2, 1]).mul(chunk6AE5FKKQ_cjs.GradTensor.scalar(1, { dtype: numericInputDtype })).reshape([flat, nFeatures]);
    }
    if (useBatchStats) {
      if (inputReshaped.shape[0] === 0) {
        throw new chunkJSCDE774_cjs.InvalidParameterError(
          "BatchNorm requires at least one element",
          "input",
          input.shape
        );
      }
      // Biased (population) variance is used for normalization itself.
      mean2 = inputReshaped.mean(0);
      varVal = chunk6AE5FKKQ_cjs.variance2(inputReshaped, 0, 0);
      if (this.trackRunningStats) {
        // Running stats are updated outside the autograd graph.
        chunk6AE5FKKQ_cjs.noGrad(() => {
          const n = inputReshaped.shape[0] ?? 0;
          // Unbiased variance feeds the running average (falls back to biased
          // when there is only a single row, where n-1 would be 0).
          const unbiasedVar = n > 1 ? chunk6AE5FKKQ_cjs.variance2(inputReshaped, 0, 1) : chunk6AE5FKKQ_cjs.variance2(inputReshaped, 0, 0);
          const m = this.momentum;
          const statsDtype = this.runningMean.dtype;
          if (statsDtype === "string") {
            throw new chunkJSCDE774_cjs.DTypeError("BatchNorm running statistics must be numeric");
          }
          // EMA update: new = (1 - m) * old + m * batch_stat
          const oneMinusM = chunk6AE5FKKQ_cjs.GradTensor.scalar(1 - m, { dtype: statsDtype });
          const mScalar = chunk6AE5FKKQ_cjs.GradTensor.scalar(m, { dtype: statsDtype });
          const newMean = this.runningMean.mul(oneMinusM).add(mean2.mul(mScalar));
          const newVar = this.runningVar.mul(oneMinusM).add(unbiasedVar.mul(mScalar));
          this.runningMean = chunk6AE5FKKQ_cjs.GradTensor.fromTensor(newMean.tensor, {
            requiresGrad: false
          });
          this.runningVar = chunk6AE5FKKQ_cjs.GradTensor.fromTensor(newVar.tensor, {
            requiresGrad: false
          });
          // Re-register so state_dict/buffers see the fresh tensors.
          this.registerBuffer("running_mean", this.runningMean.tensor);
          this.registerBuffer("running_var", this.runningVar.tensor);
        });
      }
    } else {
      // Eval mode with tracked stats: use the stored running averages.
      mean2 = this.runningMean;
      varVal = this.runningVar;
    }
    // Reshape stats so they broadcast over the feature axis of the ORIGINAL input.
    let meanBroadcast = mean2;
    let varBroadcast = varVal;
    if (input.ndim === 3) {
      meanBroadcast = mean2.reshape([1, nFeatures, 1]);
      varBroadcast = varVal.reshape([1, nFeatures, 1]);
    } else {
      meanBroadcast = mean2.reshape([1, nFeatures]);
      varBroadcast = varVal.reshape([1, nFeatures]);
    }
    const epsTensor = chunk6AE5FKKQ_cjs.GradTensor.scalar(this.eps, { dtype: inputDtype });
    const denom = varBroadcast.add(epsTensor).sqrt();
    let out = input.sub(meanBroadcast).div(denom);
    if (this.affine && this.gamma && this.beta) {
      // Apply learnable per-feature scale and shift.
      let gammaB = this.gamma;
      let betaB = this.beta;
      if (input.ndim === 3) {
        gammaB = this.gamma.reshape([1, nFeatures, 1]);
        betaB = this.beta.reshape([1, nFeatures, 1]);
      } else {
        gammaB = this.gamma.reshape([1, nFeatures]);
        betaB = this.beta.reshape([1, nFeatures]);
      }
      out = out.mul(gammaB).add(betaB);
    }
    return out;
  }
  /** @returns Human-readable summary of the layer configuration */
  toString() {
    return `BatchNorm1d(${this.numFeatures}, eps=${this.eps}, momentum=${this.momentum}, affine=${this.affine})`;
  }
};
|
|
1232
|
+
// Layer normalization over the trailing `normalizedShape` dimensions of the input.
var LayerNorm = class extends Module {
  /** Trailing dimensions over which mean/variance are computed. */
  normalizedShape;
  /** Small constant added to the variance for numerical stability. */
  eps;
  /** Whether learnable per-element scale (gamma) and shift (beta) are used. */
  elementwiseAffine;
  /** Learnable scale of shape normalizedShape; undefined when affine is off. */
  gamma;
  /** Learnable shift of shape normalizedShape; undefined when affine is off. */
  beta;
  /**
   * Create a LayerNorm layer.
   *
   * @param normalizedShape - A single dimension or list of trailing dimensions
   * @param options - { eps, elementwiseAffine }
   * @throws {InvalidParameterError} For an empty shape, non-positive dims, or bad eps
   */
  constructor(normalizedShape, options = {}) {
    super();
    // Accept a bare number as a one-dimensional normalized shape.
    this.normalizedShape = typeof normalizedShape === "number" ? [normalizedShape] : Array.from(normalizedShape);
    if (this.normalizedShape.length === 0) {
      throw new chunkJSCDE774_cjs.InvalidParameterError(
        "normalizedShape must contain at least one dimension",
        "normalizedShape",
        normalizedShape
      );
    }
    for (const dim of this.normalizedShape) {
      if (!Number.isFinite(dim) || dim <= 0 || Math.trunc(dim) !== dim) {
        throw new chunkJSCDE774_cjs.InvalidParameterError(
          "All dimensions in normalizedShape must be positive integers",
          "normalizedShape",
          normalizedShape
        );
      }
    }
    this.eps = options.eps ?? 1e-5;
    if (!Number.isFinite(this.eps) || this.eps <= 0) {
      throw new chunkJSCDE774_cjs.InvalidParameterError("eps must be a positive number", "eps", this.eps);
    }
    this.elementwiseAffine = options.elementwiseAffine ?? true;
    if (this.elementwiseAffine) {
      // gamma = 1, beta = 0: the layer starts as plain normalization.
      this.gamma = chunk6AE5FKKQ_cjs.parameter(chunk6AE5FKKQ_cjs.ones(this.normalizedShape));
      this.beta = chunk6AE5FKKQ_cjs.parameter(chunk6AE5FKKQ_cjs.zeros(this.normalizedShape));
      this.registerParameter("weight", this.gamma);
      this.registerParameter("bias", this.beta);
    }
  }
  /**
   * Normalize the trailing dimensions of the input to zero mean / unit variance.
   *
   * @param x - Tensor or GradTensor whose shape ends with normalizedShape
   * @returns Normalized tensor of the same shape
   * @throws {DTypeError} For string dtype
   * @throws {ShapeError} When the input shape does not end with normalizedShape
   */
  forward(x) {
    const input = x instanceof chunk6AE5FKKQ_cjs.GradTensor ? x : chunk6AE5FKKQ_cjs.GradTensor.fromTensor(x);
    const inputDtype = input.dtype;
    if (inputDtype === "string") {
      throw new chunkJSCDE774_cjs.DTypeError("LayerNorm does not support string dtype");
    }
    // Densify non-contiguous views first so the reshape below is valid.
    let workingInput = input;
    if (!chunk6AE5FKKQ_cjs.isContiguous(input.tensor.shape, input.tensor.strides)) {
      const contiguous = toContiguousTensor(input.tensor);
      workingInput = chunk6AE5FKKQ_cjs.GradTensor.fromTensor(contiguous, {
        requiresGrad: input.requiresGrad
      });
    }
    const inputShape = workingInput.shape;
    const normShape = this.normalizedShape;
    if (normShape.length > inputShape.length) {
      throw new chunkJSCDE774_cjs.ShapeError(`Input shape ${inputShape} too small for normalizedShape ${normShape}`);
    }
    // The input's trailing dimensions must match normalizedShape exactly.
    const suffixStart = inputShape.length - normShape.length;
    for (let i = 0; i < normShape.length; i++) {
      if (inputShape[suffixStart + i] !== normShape[i]) {
        throw new chunkJSCDE774_cjs.ShapeError(
          `Input shape ${inputShape} does not end with normalizedShape ${normShape}`
        );
      }
    }
    // Flatten the normalized dims into one axis: (outer..., normSize).
    const outerDims = inputShape.slice(0, suffixStart);
    const normSize = normShape.reduce((a, b) => a * b, 1);
    const flattenedShape = [...outerDims, normSize];
    const inputReshaped = workingInput.reshape(flattenedShape);
    // Per-sample statistics over the flattened last axis; biased variance.
    const mean2 = inputReshaped.mean(-1, true);
    const varVal = chunk6AE5FKKQ_cjs.variance2(inputReshaped, -1, 0);
    const varReshaped = varVal.reshape(mean2.shape);
    const epsTensor = chunk6AE5FKKQ_cjs.GradTensor.scalar(this.eps, { dtype: inputDtype });
    const denom = varReshaped.add(epsTensor).sqrt();
    const normalizedReshaped = inputReshaped.sub(mean2).div(denom);
    let out = normalizedReshaped.reshape(inputShape);
    if (this.elementwiseAffine && this.gamma && this.beta) {
      // gamma/beta have shape normalizedShape and broadcast over leading dims.
      out = out.mul(this.gamma).add(this.beta);
    }
    return out;
  }
  /** @returns Human-readable summary of the layer configuration */
  toString() {
    return `LayerNorm(${this.normalizedShape}, eps=${this.eps}, elementwise_affine=${this.elementwiseAffine})`;
  }
};
|
|
1315
|
+
|
|
1316
|
+
// src/nn/layers/attention.ts
|
|
1317
|
+
var MultiheadAttention = class extends Module {
|
|
1318
|
+
/** Embedding dimension */
|
|
1319
|
+
embedDim;
|
|
1320
|
+
/** Number of attention heads */
|
|
1321
|
+
numHeads;
|
|
1322
|
+
/** Dimension of each head */
|
|
1323
|
+
headDim;
|
|
1324
|
+
/** Scaling factor for dot product attention */
|
|
1325
|
+
scale;
|
|
1326
|
+
/** Whether to add bias to projections */
|
|
1327
|
+
useBias;
|
|
1328
|
+
/** Dropout probability applied to attention weights */
|
|
1329
|
+
dropout;
|
|
1330
|
+
/** Query projection weights (embedDim, embedDim) */
|
|
1331
|
+
wQ;
|
|
1332
|
+
bQ;
|
|
1333
|
+
/** Key projection weights (embedDim, embedDim) */
|
|
1334
|
+
wK;
|
|
1335
|
+
bK;
|
|
1336
|
+
/** Value projection weights (embedDim, embedDim) */
|
|
1337
|
+
wV;
|
|
1338
|
+
bV;
|
|
1339
|
+
/** Output projection weights (embedDim, embedDim) */
|
|
1340
|
+
wO;
|
|
1341
|
+
bO;
|
|
1342
|
+
/**
|
|
1343
|
+
* Create a new MultiheadAttention layer.
|
|
1344
|
+
*
|
|
1345
|
+
* @param embedDim - Total dimension of the model (must be divisible by numHeads)
|
|
1346
|
+
* @param numHeads - Number of parallel attention heads
|
|
1347
|
+
* @param options - Configuration options
|
|
1348
|
+
* @param options.bias - Whether to add bias to projections (default: true)
|
|
1349
|
+
* @param options.dropout - Dropout probability applied to attention weights (default: 0.0)
|
|
1350
|
+
*/
|
|
1351
|
+
constructor(embedDim, numHeads, options = {}) {
|
|
1352
|
+
super();
|
|
1353
|
+
if (!Number.isInteger(embedDim) || embedDim <= 0) {
|
|
1354
|
+
throw new chunkJSCDE774_cjs.InvalidParameterError("embedDim must be a positive integer", "embedDim", embedDim);
|
|
1355
|
+
}
|
|
1356
|
+
if (!Number.isInteger(numHeads) || numHeads <= 0) {
|
|
1357
|
+
throw new chunkJSCDE774_cjs.InvalidParameterError("numHeads must be a positive integer", "numHeads", numHeads);
|
|
1358
|
+
}
|
|
1359
|
+
if (embedDim % numHeads !== 0) {
|
|
1360
|
+
throw new chunkJSCDE774_cjs.InvalidParameterError(
|
|
1361
|
+
`embedDim (${embedDim}) must be divisible by numHeads (${numHeads})`,
|
|
1362
|
+
"embedDim",
|
|
1363
|
+
embedDim
|
|
1364
|
+
);
|
|
1365
|
+
}
|
|
1366
|
+
const dropout2 = options.dropout ?? 0;
|
|
1367
|
+
if (!Number.isFinite(dropout2) || dropout2 < 0 || dropout2 >= 1) {
|
|
1368
|
+
throw new chunkJSCDE774_cjs.InvalidParameterError("dropout must be in [0, 1)", "dropout", dropout2);
|
|
1369
|
+
}
|
|
1370
|
+
this.embedDim = embedDim;
|
|
1371
|
+
this.numHeads = numHeads;
|
|
1372
|
+
this.headDim = embedDim / numHeads;
|
|
1373
|
+
this.scale = Math.sqrt(this.headDim);
|
|
1374
|
+
this.useBias = options.bias ?? true;
|
|
1375
|
+
this.dropout = dropout2;
|
|
1376
|
+
const stdDev = Math.sqrt(2 / (embedDim + embedDim));
|
|
1377
|
+
this.wQ = chunk6AE5FKKQ_cjs.parameter(chunk6AE5FKKQ_cjs.mulScalar(chunk6AE5FKKQ_cjs.randn([embedDim, embedDim]), stdDev));
|
|
1378
|
+
this.wK = chunk6AE5FKKQ_cjs.parameter(chunk6AE5FKKQ_cjs.mulScalar(chunk6AE5FKKQ_cjs.randn([embedDim, embedDim]), stdDev));
|
|
1379
|
+
this.wV = chunk6AE5FKKQ_cjs.parameter(chunk6AE5FKKQ_cjs.mulScalar(chunk6AE5FKKQ_cjs.randn([embedDim, embedDim]), stdDev));
|
|
1380
|
+
this.wO = chunk6AE5FKKQ_cjs.parameter(chunk6AE5FKKQ_cjs.mulScalar(chunk6AE5FKKQ_cjs.randn([embedDim, embedDim]), stdDev));
|
|
1381
|
+
this.registerParameter("in_proj_weight_q", this.wQ);
|
|
1382
|
+
this.registerParameter("in_proj_weight_k", this.wK);
|
|
1383
|
+
this.registerParameter("in_proj_weight_v", this.wV);
|
|
1384
|
+
this.registerParameter("out_proj_weight", this.wO);
|
|
1385
|
+
if (this.useBias) {
|
|
1386
|
+
this.bQ = chunk6AE5FKKQ_cjs.parameter(chunk6AE5FKKQ_cjs.zeros([embedDim]));
|
|
1387
|
+
this.bK = chunk6AE5FKKQ_cjs.parameter(chunk6AE5FKKQ_cjs.zeros([embedDim]));
|
|
1388
|
+
this.bV = chunk6AE5FKKQ_cjs.parameter(chunk6AE5FKKQ_cjs.zeros([embedDim]));
|
|
1389
|
+
this.bO = chunk6AE5FKKQ_cjs.parameter(chunk6AE5FKKQ_cjs.zeros([embedDim]));
|
|
1390
|
+
this.registerParameter("in_proj_bias_q", this.bQ);
|
|
1391
|
+
this.registerParameter("in_proj_bias_k", this.bK);
|
|
1392
|
+
this.registerParameter("in_proj_bias_v", this.bV);
|
|
1393
|
+
this.registerParameter("out_proj_bias", this.bO);
|
|
1394
|
+
}
|
|
1395
|
+
}
|
|
1396
|
+
/**
 * Forward pass of multi-head attention.
 *
 * Accepts 1 to 3 inputs; a missing key and/or value defaults to the query
 * input (self-attention). 2D inputs of shape (seqLen, embedDim) are promoted
 * to a batch of one and the batch axis is removed from the output again.
 *
 * @param query - Query tensor of shape (batch, seqLen, embedDim)
 * @param key - Key tensor of shape (batch, seqLen, embedDim)
 * @param value - Value tensor of shape (batch, seqLen, embedDim)
 * @returns Output tensor of same shape as query
 */
forward(...inputs) {
  if (inputs.length < 1 || inputs.length > 3) {
    throw new chunkJSCDE774_cjs.InvalidParameterError(
      "MultiheadAttention.forward expects 1 to 3 input tensors",
      "inputs",
      inputs.length
    );
  }
  const queryInput = inputs[0];
  if (queryInput === void 0) {
    throw new chunkJSCDE774_cjs.InvalidParameterError("Query tensor is required", "query", queryInput);
  }
  // Wrap plain tensors as GradTensor so the whole computation is differentiable.
  const query = queryInput instanceof chunk6AE5FKKQ_cjs.GradTensor ? queryInput : chunk6AE5FKKQ_cjs.GradTensor.fromTensor(queryInput);
  // Key/value fall back to the query input (self-attention).
  const keyInput = inputs[1] ?? queryInput;
  const key = keyInput instanceof chunk6AE5FKKQ_cjs.GradTensor ? keyInput : chunk6AE5FKKQ_cjs.GradTensor.fromTensor(keyInput);
  const valueInput = inputs[2] ?? queryInput;
  const value = valueInput instanceof chunk6AE5FKKQ_cjs.GradTensor ? valueInput : chunk6AE5FKKQ_cjs.GradTensor.fromTensor(valueInput);
  if (query.dtype === "string") throw new chunkJSCDE774_cjs.DTypeError("String tensors are not supported");
  if (query.ndim !== key.ndim || query.ndim !== value.ndim) {
    throw new chunkJSCDE774_cjs.ShapeError("query, key, and value must have same rank");
  }
  if (query.ndim !== 2 && query.ndim !== 3) {
    throw new chunkJSCDE774_cjs.ShapeError(`Query must be 2D or 3D; got ndim=${query.ndim}`);
  }
  if (key.ndim !== 2 && key.ndim !== 3) {
    throw new chunkJSCDE774_cjs.ShapeError(`Key must be 2D or 3D; got ndim=${key.ndim}`);
  }
  if (value.ndim !== 2 && value.ndim !== 3) {
    throw new chunkJSCDE774_cjs.ShapeError(`Value must be 2D or 3D; got ndim=${value.ndim}`);
  }
  let q = query;
  let k = key;
  let v = value;
  // Promote 2D (seqLen, embedDim) inputs to a batch of one.
  if (q.ndim === 2) q = q.reshape([1, q.shape[0] ?? 0, q.shape[1] ?? 0]);
  if (k.ndim === 2) k = k.reshape([1, k.shape[0] ?? 0, k.shape[1] ?? 0]);
  if (v.ndim === 2) v = v.reshape([1, v.shape[0] ?? 0, v.shape[1] ?? 0]);
  const batchSize = q.shape[0] ?? 0;
  const seqLenQ = q.shape[1] ?? 0;
  const seqLenK = k.shape[1] ?? 0;
  const seqLenV = v.shape[1] ?? 0;
  const embedDim = q.shape[2] ?? 0;
  if (embedDim !== this.embedDim) {
    throw new chunkJSCDE774_cjs.ShapeError(`Query embedDim mismatch: expected ${this.embedDim}, got ${embedDim}`);
  }
  if (k.shape[2] !== this.embedDim) {
    throw new chunkJSCDE774_cjs.ShapeError(`Key embedDim mismatch: expected ${this.embedDim}, got ${k.shape[2]}`);
  }
  if (v.shape[2] !== this.embedDim) {
    throw new chunkJSCDE774_cjs.ShapeError(`Value embedDim mismatch: expected ${this.embedDim}, got ${v.shape[2]}`);
  }
  if (k.shape[0] !== batchSize || v.shape[0] !== batchSize) {
    throw new chunkJSCDE774_cjs.ShapeError(
      `batch size mismatch: query=${batchSize}, key=${k.shape[0]}, value=${v.shape[0]}`
    );
  }
  // Keys and values must align position-for-position; queries may differ in length.
  if (seqLenK !== seqLenV) {
    throw new chunkJSCDE774_cjs.ShapeError(`Key/value sequence length mismatch: key=${seqLenK}, value=${seqLenV}`);
  }
  // Input projections: x @ W^T (+ bias when enabled).
  let Q = q.matmul(this.wQ.transpose());
  if (this.bQ) Q = Q.add(this.bQ);
  let K = k.matmul(this.wK.transpose());
  if (this.bK) K = K.add(this.bK);
  let V = v.matmul(this.wV.transpose());
  if (this.bV) V = V.add(this.bV);
  const H = this.numHeads;
  const D = this.headDim;
  // Split embedDim into (numHeads, headDim) and move heads before sequence:
  // (batch, seq, H, D) -> (batch, H, seq, D).
  Q = Q.reshape([batchSize, seqLenQ, H, D]).transpose([0, 2, 1, 3]);
  K = K.reshape([batchSize, seqLenK, H, D]).transpose([0, 2, 1, 3]);
  V = V.reshape([batchSize, seqLenV, H, D]).transpose([0, 2, 1, 3]);
  // Scaled dot-product attention: (Q K^T) / sqrt(headDim), softmax over keys.
  let scores = Q.matmul(K.transpose([0, 1, 3, 2]));
  scores = scores.div(chunk6AE5FKKQ_cjs.GradTensor.scalar(this.scale));
  let attn = chunk6AE5FKKQ_cjs.softmax2(scores, -1);
  // Dropout on the attention weights is only active while this.training is set.
  attn = chunk6AE5FKKQ_cjs.dropout(attn, this.dropout, this.training);
  const context = attn.matmul(V);
  const contextDtype = chunkJSCDE774_cjs.ensureNumericDType(context.dtype, "MultiheadAttention");
  // NOTE(review): the multiply-by-one scalar appears to force materialization of
  // the transposed view before reshape (reshape on a non-contiguous transpose) —
  // confirm against GradTensor.reshape semantics before removing it.
  const contextReshaped = context.transpose([0, 2, 1, 3]).mul(chunk6AE5FKKQ_cjs.GradTensor.scalar(1, { dtype: contextDtype })).reshape([batchSize, seqLenQ, this.embedDim]);
  // Final output projection.
  let output = contextReshaped.matmul(this.wO.transpose());
  if (this.bO) output = output.add(this.bO);
  // Strip the synthetic batch axis if the caller passed a 2D query.
  if (query.ndim === 2) {
    output = output.reshape([seqLenQ, this.embedDim]);
  }
  return output;
}
|
|
1487
|
+
toString() {
|
|
1488
|
+
return `MultiheadAttention(embed_dim=${this.embedDim}, num_heads=${this.numHeads})`;
|
|
1489
|
+
}
|
|
1490
|
+
};
|
|
1491
|
+
var TransformerEncoderLayer = class extends Module {
  /** Model (embedding) dimension */
  dModel;
  /** Number of attention heads */
  nHead;
  /** Feed-forward inner dimension */
  dFF;
  selfAttn;
  linear1;
  linear2;
  norm1;
  norm2;
  /** Dropout probability shared by the attention layer and all Dropout modules */
  dropout;
  dropout1;
  dropout2;
  dropout3;
  /**
   * Create a Transformer encoder layer (post-norm residual blocks).
   *
   * @param dModel - Model dimension (must be divisible by nHead)
   * @param nHead - Number of attention heads
   * @param dFF - Feed-forward inner dimension
   * @param options.dropout - Dropout probability (default: 0.1)
   * @param options.eps - LayerNorm epsilon (default: 1e-5)
   */
  constructor(dModel, nHead, dFF, options = {}) {
    super();
    if (!Number.isInteger(dModel) || dModel <= 0) {
      throw new chunkJSCDE774_cjs.InvalidParameterError("dModel must be a positive integer", "dModel", dModel);
    }
    if (!Number.isInteger(nHead) || nHead <= 0) {
      throw new chunkJSCDE774_cjs.InvalidParameterError("nHead must be a positive integer", "nHead", nHead);
    }
    if (dModel % nHead !== 0) {
      throw new chunkJSCDE774_cjs.InvalidParameterError(
        `dModel (${dModel}) must be divisible by nHead (${nHead})`,
        "dModel",
        dModel
      );
    }
    if (!Number.isInteger(dFF) || dFF <= 0) {
      throw new chunkJSCDE774_cjs.InvalidParameterError("dFF must be a positive integer", "dFF", dFF);
    }
    const rate = options.dropout ?? 0.1;
    const eps = options.eps ?? 1e-5;
    this.dModel = dModel;
    this.nHead = nHead;
    this.dFF = dFF;
    this.dropout = rate;
    this.selfAttn = new MultiheadAttention(dModel, nHead, { dropout: rate });
    this.linear1 = new Linear(dModel, dFF);
    this.linear2 = new Linear(dFF, dModel);
    this.norm1 = new LayerNorm(dModel, { eps });
    this.norm2 = new LayerNorm(dModel, { eps });
    this.dropout1 = new Dropout(rate);
    this.dropout2 = new Dropout(rate);
    this.dropout3 = new Dropout(rate);
    // Register in construction order so parameter/module iteration stays stable.
    const submodules = [
      ["self_attn", this.selfAttn],
      ["linear1", this.linear1],
      ["linear2", this.linear2],
      ["norm1", this.norm1],
      ["norm2", this.norm2],
      ["dropout1", this.dropout1],
      ["dropout2", this.dropout2],
      ["dropout3", this.dropout3]
    ];
    for (const [name, mod] of submodules) {
      this.registerModule(name, mod);
    }
  }
  /**
   * Forward pass of the Transformer encoder layer.
   *
   * Applies self-attention and a position-wise feed-forward network, each
   * followed by a residual connection and LayerNorm (post-norm ordering).
   *
   * @param src - Source sequence of shape (batch, seqLen, dModel)
   * @returns Output of same shape as input
   */
  forward(src) {
    const x = src instanceof chunk6AE5FKKQ_cjs.GradTensor ? src : chunk6AE5FKKQ_cjs.GradTensor.fromTensor(src);
    if (x.dtype === "string") {
      throw new chunkJSCDE774_cjs.DTypeError("TransformerEncoderLayer does not support string dtype");
    }
    // Self-attention sub-layer: attn -> dropout -> residual add -> norm.
    const attnOut = this.dropout1.forward(this.selfAttn.forward(x, x, x));
    const normed = this.norm1.forward(x.add(attnOut));
    // Feed-forward sub-layer: linear -> relu -> dropout -> linear -> dropout.
    let hidden = this.linear1.forward(normed).relu();
    hidden = this.dropout2.forward(hidden);
    hidden = this.dropout3.forward(this.linear2.forward(hidden));
    // Residual add and final norm.
    return this.norm2.forward(normed.add(hidden));
  }
  /** Human-readable summary of the layer configuration. */
  toString() {
    return `TransformerEncoderLayer(d_model=${this.dModel}, nhead=${this.nHead}, dim_feedforward=${this.dFF}, dropout=${this.dropout})`;
  }
};
|
|
1575
|
+
|
|
1576
|
+
// src/nn/layers/conv.ts
|
|
1577
|
+
/**
 * Normalize a scalar-or-pair option into a two-element integer tuple.
 *
 * A plain number is duplicated into [n, n]; a two-element array is validated
 * and returned as-is. Throws InvalidParameterError with the supplied
 * description when the value is not a valid (non-negative, when allowZero)
 * integer pair.
 */
function normalizePair(name, value, allowZero, description) {
  const pair = typeof value === "number" ? [value, value] : value;
  const a = pair[0];
  const b = pair[1];
  const boundOk = allowZero ? (n) => n >= 0 : (n) => n > 0;
  const valid =
    pair.length === 2 &&
    a !== void 0 &&
    b !== void 0 &&
    Number.isInteger(a) &&
    Number.isInteger(b) &&
    boundOk(a) &&
    boundOk(b);
  if (!valid) {
    throw new chunkJSCDE774_cjs.InvalidParameterError(`${name} must be ${description}`, name, value);
  }
  return [a, b];
}
|
|
1586
|
+
var Conv1d = class extends Module {
  /** Number of input channels */
  inChannels;
  /** Number of output channels */
  outChannels;
  /** Kernel length (single integer for 1D convolution) */
  kernelSize;
  /** Stride along the length axis */
  stride;
  /** Zero-padding along the length axis */
  padding;
  /** Whether a learnable bias is used */
  bias;
  /** Weight parameter of shape (outChannels, inChannels, kernelSize) */
  weight_;
  /** Bias parameter of shape (outChannels); undefined when bias is disabled */
  bias_;
  /**
   * 1D convolution layer.
   *
   * @param inChannels - Number of input channels (positive integer)
   * @param outChannels - Number of output channels (positive integer)
   * @param kernelSize - Kernel length (positive integer)
   * @param options.stride - Stride (default: 1)
   * @param options.padding - Zero padding (default: 0)
   * @param options.bias - Whether to include bias (default: true)
   */
  constructor(inChannels, outChannels, kernelSize, options = {}) {
    super();
    if (inChannels <= 0 || !Number.isInteger(inChannels)) {
      throw new chunkJSCDE774_cjs.InvalidParameterError(
        "inChannels must be a positive integer",
        "inChannels",
        inChannels
      );
    }
    if (outChannels <= 0 || !Number.isInteger(outChannels)) {
      throw new chunkJSCDE774_cjs.InvalidParameterError(
        "outChannels must be a positive integer",
        "outChannels",
        outChannels
      );
    }
    if (kernelSize <= 0 || !Number.isInteger(kernelSize)) {
      throw new chunkJSCDE774_cjs.InvalidParameterError(
        "kernelSize must be a positive integer",
        "kernelSize",
        kernelSize
      );
    }
    const stride = options.stride ?? 1;
    if (stride <= 0 || !Number.isInteger(stride)) {
      throw new chunkJSCDE774_cjs.InvalidParameterError("stride must be a positive integer", "stride", stride);
    }
    const padding = options.padding ?? 0;
    if (padding < 0 || !Number.isInteger(padding)) {
      throw new chunkJSCDE774_cjs.InvalidParameterError("padding must be a non-negative integer", "padding", padding);
    }
    this.inChannels = inChannels;
    this.outChannels = outChannels;
    this.kernelSize = kernelSize;
    this.stride = stride;
    this.padding = padding;
    this.bias = options.bias ?? true;
    this.initializeParameters();
  }
  /**
   * Initialize weight (and optional bias) parameters.
   * Both are drawn from randn scaled by k = 1/sqrt(inChannels * kernelSize);
   * i.e. a Gaussian with std k (NOTE(review): PyTorch uses uniform(-k, k) here —
   * confirm which init the library intends).
   */
  initializeParameters() {
    const k = 1 / Math.sqrt(this.inChannels * this.kernelSize);
    const weight = chunk6AE5FKKQ_cjs.randn([this.outChannels, this.inChannels, this.kernelSize]);
    this.weight_ = chunk6AE5FKKQ_cjs.parameter(chunk6AE5FKKQ_cjs.mulScalar(weight, k));
    this.registerParameter("weight", this.weight_);
    if (this.bias) {
      const biasInit = chunk6AE5FKKQ_cjs.randn([this.outChannels]);
      this.bias_ = chunk6AE5FKKQ_cjs.parameter(chunk6AE5FKKQ_cjs.mulScalar(biasInit, k));
      this.registerParameter("bias", this.bias_);
    }
  }
  /**
   * Apply the convolution to a 3D input (batch, channels, length).
   *
   * Implemented by lifting the input to 4D with a unit height and reusing the
   * 2D im2col + matmul path, so gradients flow through standard ops.
   *
   * @returns Output of shape (batch, outChannels, outLength)
   */
  forward(x) {
    const input = x instanceof chunk6AE5FKKQ_cjs.GradTensor ? x : chunk6AE5FKKQ_cjs.GradTensor.fromTensor(x);
    if (input.dtype === "string") {
      throw new chunkJSCDE774_cjs.DTypeError("String tensors are not supported");
    }
    if (input.ndim !== 3) {
      throw new chunkJSCDE774_cjs.ShapeError(`Conv1d expects 3D input (batch, channels, length), got ${input.ndim}D`);
    }
    const batch = input.shape[0] ?? 0;
    const inC = input.shape[1] ?? 0;
    const inL = input.shape[2] ?? 0;
    if (inC !== this.inChannels) {
      throw new chunkJSCDE774_cjs.ShapeError(`Expected ${this.inChannels} input channels, got ${inC}`);
    }
    const weight = this.weight_;
    if (!weight) throw new chunkJSCDE774_cjs.NotFittedError("Weight not initialized");
    // Lift to (batch, channels, 1, length) so im2col2 can extract 1 x kernelSize patches.
    const input2d = input.reshape([batch, inC, 1, inL]);
    const kernelSize = [1, this.kernelSize];
    const stride = [1, this.stride];
    const padding = [0, this.padding];
    // cols: assumed layout (batch, outPositions, inChannels * kernelSize) — per im2col2.
    const cols = chunk6AE5FKKQ_cjs.im2col2(input2d, kernelSize, stride, padding);
    const weightFlat = weight.reshape([this.outChannels, this.inChannels * this.kernelSize]);
    // (batch, positions, outChannels) then swap to (batch, outChannels, positions).
    const out = cols.matmul(weightFlat.transpose());
    const outTransposed = out.transpose([0, 2, 1]);
    if (this.bias && this.bias_) {
      // Broadcast bias over batch and length.
      const biasReshaped = this.bias_.reshape([1, this.outChannels, 1]);
      return outTransposed.add(biasReshaped);
    }
    return outTransposed;
  }
  /** Weight parameter accessor; throws if parameters were never initialized. */
  get weight() {
    if (!this.weight_) {
      throw new chunkJSCDE774_cjs.NotFittedError("Weight not initialized");
    }
    return this.weight_;
  }
};
|
|
1682
|
+
var Conv2d = class extends Module {
  /** Number of input channels */
  inChannels;
  /** Number of output channels */
  outChannels;
  /** Kernel size as [kH, kW] (normalized from scalar or pair) */
  kernelSize;
  /** Stride as [sH, sW] */
  stride;
  /** Zero-padding as [pH, pW] */
  padding;
  /** Whether a learnable bias is used */
  useBias;
  /** Weight parameter of shape (outChannels, inChannels, kH, kW) */
  weight_;
  /** Bias parameter of shape (outChannels); undefined when bias is disabled */
  bias_;
  /**
   * 2D convolution layer.
   *
   * @param inChannels - Number of input channels (positive integer)
   * @param outChannels - Number of output channels (positive integer)
   * @param kernelSize - Kernel size (int or [kH, kW])
   * @param options.stride - Stride (int or pair; default: 1)
   * @param options.padding - Zero padding (int or pair; default: 0)
   * @param options.bias - Whether to include bias (default: true)
   */
  constructor(inChannels, outChannels, kernelSize, options = {}) {
    super();
    if (inChannels <= 0 || !Number.isInteger(inChannels)) {
      throw new chunkJSCDE774_cjs.InvalidParameterError(
        "inChannels must be a positive integer",
        "inChannels",
        inChannels
      );
    }
    if (outChannels <= 0 || !Number.isInteger(outChannels)) {
      throw new chunkJSCDE774_cjs.InvalidParameterError(
        "outChannels must be a positive integer",
        "outChannels",
        outChannels
      );
    }
    // normalizePair turns scalars into [n, n] and validates pairs.
    const kernelArr = normalizePair(
      "kernelSize",
      kernelSize,
      false,
      "a positive integer or a tuple of two positive integers"
    );
    const stride = options.stride ?? 1;
    const strideArr = normalizePair(
      "stride",
      stride,
      false,
      "a positive integer or a tuple of two positive integers"
    );
    const padding = options.padding ?? 0;
    const paddingArr = normalizePair(
      "padding",
      padding,
      true,
      "a non-negative integer or a tuple of two non-negative integers"
    );
    this.inChannels = inChannels;
    this.outChannels = outChannels;
    this.kernelSize = kernelArr;
    this.stride = strideArr;
    this.padding = paddingArr;
    this.useBias = options.bias ?? true;
    this.initializeParameters();
  }
  /**
   * Initialize weight (and optional bias) parameters.
   * Both are drawn from randn scaled by k = 1/sqrt(inChannels * kH * kW);
   * i.e. a Gaussian with std k (NOTE(review): PyTorch uses uniform(-k, k) —
   * confirm which init the library intends).
   */
  initializeParameters() {
    const kH = this.kernelSize[0] ?? 1;
    const kW = this.kernelSize[1] ?? 1;
    const k = 1 / Math.sqrt(this.inChannels * kH * kW);
    const weight = chunk6AE5FKKQ_cjs.randn([this.outChannels, this.inChannels, kH, kW]);
    this.weight_ = chunk6AE5FKKQ_cjs.parameter(chunk6AE5FKKQ_cjs.mulScalar(weight, k));
    this.registerParameter("weight", this.weight_);
    if (this.useBias) {
      const biasInit = chunk6AE5FKKQ_cjs.randn([this.outChannels]);
      this.bias_ = chunk6AE5FKKQ_cjs.parameter(chunk6AE5FKKQ_cjs.mulScalar(biasInit, k));
      this.registerParameter("bias", this.bias_);
    }
  }
  /**
   * Apply the convolution to a 4D input (batch, channels, height, width).
   *
   * Uses im2col + matmul so gradients flow through standard ops.
   *
   * @returns Output of shape (batch, outChannels, outH, outW) with
   *          outH = floor((inH + 2*pH - kH) / sH) + 1 (and likewise outW)
   */
  forward(x) {
    const input = x instanceof chunk6AE5FKKQ_cjs.GradTensor ? x : chunk6AE5FKKQ_cjs.GradTensor.fromTensor(x);
    if (input.dtype === "string") {
      throw new chunkJSCDE774_cjs.DTypeError("String tensors are not supported");
    }
    if (input.ndim !== 4) {
      throw new chunkJSCDE774_cjs.ShapeError(
        `Conv2d expects 4D input (batch, channels, height, width), got ${input.ndim}D`
      );
    }
    const batch = input.shape[0] ?? 0;
    const inC = input.shape[1] ?? 0;
    const inH = input.shape[2] ?? 0;
    const inW = input.shape[3] ?? 0;
    if (inC !== this.inChannels) {
      throw new chunkJSCDE774_cjs.ShapeError(`Expected ${this.inChannels} input channels, got ${inC}`);
    }
    const weight = this.weight_;
    if (!weight) throw new chunkJSCDE774_cjs.NotFittedError("Weight not initialized");
    const [kH, kW] = this.kernelSize;
    const [sH, sW] = this.stride;
    const [pH, pW] = this.padding;
    // cols: assumed layout (batch, outH*outW, inChannels*kH*kW) — per im2col2.
    const cols = chunk6AE5FKKQ_cjs.im2col2(input, [kH, kW], [sH, sW], [pH, pW]);
    // Standard convolution output-size formula.
    const outH = Math.floor((inH + 2 * pH - kH) / sH) + 1;
    const outW = Math.floor((inW + 2 * pW - kW) / sW) + 1;
    const weightFlat = weight.reshape([this.outChannels, this.inChannels * kH * kW]);
    // (batch, positions, outChannels) -> (batch, outChannels, positions) -> 4D.
    const out = cols.matmul(weightFlat.transpose());
    const outTransposed = out.transpose([0, 2, 1]);
    const outReshaped = outTransposed.reshape([batch, this.outChannels, outH, outW]);
    if (this.useBias && this.bias_) {
      // Broadcast bias over batch and spatial axes.
      const biasReshaped = this.bias_.reshape([1, this.outChannels, 1, 1]);
      return outReshaped.add(biasReshaped);
    }
    return outReshaped;
  }
  /** Weight parameter accessor; throws if parameters were never initialized. */
  get weight() {
    if (!this.weight_) {
      throw new chunkJSCDE774_cjs.NotFittedError("Weight not initialized");
    }
    return this.weight_;
  }
};
|
|
1790
|
+
var MaxPool2d = class extends Module {
  /** Pooling window as [kH, kW] */
  kernelSizeValue;
  /** Stride as [sH, sW]; defaults to the kernel size (non-overlapping windows) */
  stride;
  /** Zero-padding as [pH, pW] */
  padding;
  /**
   * 2D max pooling layer.
   *
   * @param kernelSize - Pooling window (int or [kH, kW])
   * @param options.stride - Stride (int or pair; default: kernelSize)
   * @param options.padding - Zero padding (int or pair; default: 0)
   */
  constructor(kernelSize, options = {}) {
    super();
    const kernelArr = normalizePair(
      "kernelSize",
      kernelSize,
      false,
      "a positive integer or a tuple of two positive integers"
    );
    this.kernelSizeValue = kernelArr;
    const strideArr = normalizePair(
      "stride",
      options.stride ?? kernelSize,
      false,
      "a positive integer or a tuple of two positive integers"
    );
    this.stride = strideArr;
    const paddingArr = normalizePair(
      "padding",
      options.padding ?? 0,
      true,
      "a non-negative integer or a tuple of two non-negative integers"
    );
    this.padding = paddingArr;
  }
  /**
   * Apply max pooling to a 4D input (batch, channels, height, width).
   *
   * Each (batch, channel) plane is treated as an independent single-channel
   * image so im2col produces one window per column; the max over axis 2
   * reduces each window to its maximum.
   *
   * NOTE(review): with padding > 0 the padded entries participate in the max;
   * that is only harmless for non-negative inputs if padding inserts zeros —
   * confirm im2col2's padding value.
   *
   * @returns Output of shape (batch, channels, outH, outW)
   */
  forward(x) {
    const input = x instanceof chunk6AE5FKKQ_cjs.GradTensor ? x : chunk6AE5FKKQ_cjs.GradTensor.fromTensor(x);
    if (input.dtype === "string") {
      throw new chunkJSCDE774_cjs.DTypeError("String tensors are not supported");
    }
    if (input.ndim !== 4) {
      throw new chunkJSCDE774_cjs.ShapeError(
        `MaxPool2d expects 4D input (batch, channels, height, width), got ${input.ndim}D`
      );
    }
    const batch = input.shape[0] ?? 0;
    const channels = input.shape[1] ?? 0;
    const inH = input.shape[2] ?? 0;
    const inW = input.shape[3] ?? 0;
    const [kH, kW] = this.kernelSizeValue;
    const [sH, sW] = this.stride;
    const [pH, pW] = this.padding;
    // Fold channels into the batch axis so each plane is pooled independently.
    const inputReshaped = input.reshape([batch * channels, 1, inH, inW]);
    const cols = chunk6AE5FKKQ_cjs.im2col2(inputReshaped, [kH, kW], [sH, sW], [pH, pW]);
    // Reduce each extracted window (axis 2) to its maximum.
    const maxVals = cols.max(2);
    const outH = Math.floor((inH + 2 * pH - kH) / sH) + 1;
    const outW = Math.floor((inW + 2 * pW - kW) / sW) + 1;
    return maxVals.reshape([batch, channels, outH, outW]);
  }
};
|
|
1843
|
+
var AvgPool2d = class extends Module {
  /** Pooling window as [kH, kW] */
  kernelSizeValue;
  /** Stride as [sH, sW]; defaults to the kernel size (non-overlapping windows) */
  stride;
  /** Zero-padding as [pH, pW] */
  padding;
  /**
   * 2D average pooling layer.
   *
   * @param kernelSize - Pooling window (int or [kH, kW])
   * @param options.stride - Stride (int or pair; default: kernelSize)
   * @param options.padding - Zero padding (int or pair; default: 0)
   */
  constructor(kernelSize, options = {}) {
    super();
    const kernelArr = normalizePair(
      "kernelSize",
      kernelSize,
      false,
      "a positive integer or a tuple of two positive integers"
    );
    this.kernelSizeValue = kernelArr;
    const strideArr = normalizePair(
      "stride",
      options.stride ?? kernelSize,
      false,
      "a positive integer or a tuple of two positive integers"
    );
    this.stride = strideArr;
    const paddingArr = normalizePair(
      "padding",
      options.padding ?? 0,
      true,
      "a non-negative integer or a tuple of two non-negative integers"
    );
    this.padding = paddingArr;
  }
  /**
   * Apply average pooling to a 4D input (batch, channels, height, width).
   *
   * Each (batch, channel) plane is pooled independently via im2col; the mean
   * over axis 2 averages every element of each window.
   *
   * NOTE(review): padded entries are included in the average (the divisor is
   * always kH*kW), i.e. count_include_pad=true semantics — confirm this is
   * intended when padding > 0.
   *
   * @returns Output of shape (batch, channels, outH, outW)
   */
  forward(x) {
    const input = x instanceof chunk6AE5FKKQ_cjs.GradTensor ? x : chunk6AE5FKKQ_cjs.GradTensor.fromTensor(x);
    if (input.dtype === "string") {
      throw new chunkJSCDE774_cjs.DTypeError("String tensors are not supported");
    }
    if (input.ndim !== 4) {
      throw new chunkJSCDE774_cjs.ShapeError(
        `AvgPool2d expects 4D input (batch, channels, height, width), got ${input.ndim}D`
      );
    }
    const batch = input.shape[0] ?? 0;
    const channels = input.shape[1] ?? 0;
    const inH = input.shape[2] ?? 0;
    const inW = input.shape[3] ?? 0;
    const [kH, kW] = this.kernelSizeValue;
    const [sH, sW] = this.stride;
    const [pH, pW] = this.padding;
    // Fold channels into the batch axis so each plane is pooled independently.
    const inputReshaped = input.reshape([batch * channels, 1, inH, inW]);
    const cols = chunk6AE5FKKQ_cjs.im2col2(inputReshaped, [kH, kW], [sH, sW], [pH, pW]);
    // Reduce each extracted window (axis 2) to its mean.
    const meanVals = cols.mean(2);
    const outH = Math.floor((inH + 2 * pH - kH) / sH) + 1;
    const outW = Math.floor((inW + 2 * pW - kW) / sW) + 1;
    return meanVals.reshape([batch, channels, outH, outW]);
  }
};
|
|
1896
|
+
|
|
1897
|
+
// src/nn/layers/recurrent.ts
|
|
1898
|
+
/**
 * Assert that a tensor has a floating-point dtype.
 *
 * Throws DTypeError (with `context` in the message) for string tensors and
 * for any dtype other than float32/float64; returns undefined otherwise.
 */
function ensureFloatTensor(t, context) {
  const { dtype } = t;
  if (dtype === "string") {
    throw new chunkJSCDE774_cjs.DTypeError(`${context} does not support string dtype`);
  }
  const isFloat = dtype === "float32" || dtype === "float64";
  if (!isFloat) {
    throw new chunkJSCDE774_cjs.DTypeError(`${context} expects float32 or float64 dtype`);
  }
}
|
|
1906
|
+
/**
 * Read a single numeric element from a tensor's backing buffer at a raw
 * (flat) offset. Rejects string tensors, whose data is a plain array.
 */
function readNumeric(t, offset) {
  const { data } = t;
  if (Array.isArray(data)) {
    // Plain arrays back string tensors only; numeric tensors use typed arrays.
    throw new chunkJSCDE774_cjs.DTypeError("String tensors are not supported");
  }
  return chunkJSCDE774_cjs.getElementAsNumber(data, offset);
}
|
|
1913
|
+
/** Allocate a zero-filled float buffer of the requested size and precision. */
function createFloatBuffer(size, dtype) {
  if (dtype === "float64") {
    return new Float64Array(size);
  }
  return new Float32Array(size);
}
|
|
1916
|
+
/** Throw InvalidParameterError unless `value` is a strictly positive integer. */
function validatePositiveInt(name, value) {
  const ok = Number.isInteger(value) && value > 0;
  if (!ok) {
    throw new chunkJSCDE774_cjs.InvalidParameterError(`${name} must be a positive integer`, name, value);
  }
}
|
|
1921
|
+
/**
 * Decode a recurrent-layer input tensor into batch/sequence/feature sizes and
 * the raw strides needed to walk it.
 *
 * 2D input (seqLen, features) is treated as an unbatched batch of one;
 * 3D input is interpreted as (batch, seq, feat) when batchFirst, otherwise
 * (seq, batch, feat). Any other rank raises ShapeError.
 */
function parseInput(input, batchFirst) {
  const { ndim, shape, strides } = input;
  if (ndim === 2) {
    // Unbatched: synthesize a batch of one with a zero batch stride.
    return {
      batch: 1,
      seqLen: shape[0] ?? 0,
      inputDim: shape[1] ?? 0,
      isUnbatched: true,
      batchStride: 0,
      seqStride: strides[0] ?? 0,
      featStride: strides[1] ?? 0
    };
  }
  if (ndim !== 3) {
    throw new chunkJSCDE774_cjs.ShapeError(`Recurrent layers expect 2D or 3D input; got ndim=${ndim}`);
  }
  // Batch and sequence axes swap depending on the layout convention.
  const batchAxis = batchFirst ? 0 : 1;
  const seqAxis = batchFirst ? 1 : 0;
  return {
    batch: shape[batchAxis] ?? 0,
    seqLen: shape[seqAxis] ?? 0,
    inputDim: shape[2] ?? 0,
    isUnbatched: false,
    batchStride: strides[batchAxis] ?? 0,
    seqStride: strides[seqAxis] ?? 0,
    featStride: strides[2] ?? 0
  };
}
|
|
1959
|
+
/**
 * Flat index into a recurrent layer's output buffer for batch b, timestep t,
 * hidden unit j, given the output layout:
 * - unbatched:   (seqLen, hiddenSize)
 * - batch-first: (batch, seqLen, hiddenSize)
 * - seq-first:   (seqLen, batch, hiddenSize)
 */
function outputIndex(batchFirst, isUnbatched, batch, seqLen, hiddenSize, b, t, j) {
  if (isUnbatched) {
    return t * hiddenSize + j;
  }
  const row = batchFirst ? b * seqLen + t : t * batch + b;
  return row * hiddenSize + j;
}
|
|
1968
|
+
/** Unwrap a GradTensor to its underlying Tensor; pass plain tensors through. */
function extractTensor(arg, _name) {
  const isGrad = arg instanceof chunk6AE5FKKQ_cjs.GradTensor;
  return isGrad ? arg.tensor : arg;
}
|
|
1974
|
+
/**
 * Materialize an initial hidden/cell state as one Float64Array per layer,
 * each of length batch * hiddenSize.
 *
 * When `state` is absent, all layers start at zero. A 2D state
 * (numLayers, hiddenSize) is only accepted for unbatched input; a 3D state
 * must be (numLayers, expectedBatch, hiddenSize). Values are copied out
 * element-by-element using the tensor's offset and strides, so non-contiguous
 * views are handled correctly.
 *
 * @param state - Optional initial state tensor (float32/float64)
 * @param numLayers - Number of stacked recurrent layers
 * @param batch - Batch size of the current input
 * @param hiddenSize - Hidden units per layer
 * @param isUnbatched - Whether the input was 2D (batch of one)
 * @param name - Parameter name used in error messages (e.g. "h0")
 * @returns Array of per-layer Float64Array buffers, laid out [b * hiddenSize + j]
 */
function buildState(state, numLayers, batch, hiddenSize, isUnbatched, name) {
  // Allocate zeroed per-layer buffers up front; they double as the default state.
  const result = new Array(numLayers);
  for (let l = 0; l < numLayers; l++) {
    result[l] = new Float64Array(batch * hiddenSize);
  }
  if (!state) {
    return result;
  }
  ensureFloatTensor(state, name);
  if (state.ndim === 2) {
    // 2D state only makes sense for unbatched (single-sequence) input.
    if (!isUnbatched) {
      throw new chunkJSCDE774_cjs.ShapeError(`Expected ${name} with 3 dimensions for batched input`);
    }
    if ((state.shape[0] ?? 0) !== numLayers || (state.shape[1] ?? 0) !== hiddenSize) {
      throw new chunkJSCDE774_cjs.ShapeError(
        `Expected ${name} shape [${numLayers}, ${hiddenSize}], got [${state.shape.join(", ")}]`
      );
    }
    const stride02 = state.strides[0] ?? 0;
    const stride12 = state.strides[1] ?? 0;
    for (let l = 0; l < numLayers; l++) {
      const layerState = result[l];
      if (!layerState) {
        throw new chunkJSCDE774_cjs.ShapeError(`Internal error: missing ${name} layer state`);
      }
      // Strided copy of layer l's hidden vector (batch index is implicitly 0).
      const base = state.offset + l * stride02;
      for (let j = 0; j < hiddenSize; j++) {
        layerState[j] = readNumeric(state, base + j * stride12);
      }
    }
    return result;
  }
  if (state.ndim !== 3) {
    throw new chunkJSCDE774_cjs.ShapeError(`Expected ${name} with 2 or 3 dimensions; got ndim=${state.ndim}`);
  }
  // For unbatched input a 3D state must still carry a batch axis of size 1.
  const expectedBatch = isUnbatched ? 1 : batch;
  if ((state.shape[0] ?? 0) !== numLayers || (state.shape[1] ?? 0) !== expectedBatch || (state.shape[2] ?? 0) !== hiddenSize) {
    const expected = isUnbatched ? [numLayers, 1, hiddenSize] : [numLayers, batch, hiddenSize];
    throw new chunkJSCDE774_cjs.ShapeError(
      `Expected ${name} shape [${expected.join(", ")}], got [${state.shape.join(", ")}]`
    );
  }
  const stride0 = state.strides[0] ?? 0;
  const stride1 = state.strides[1] ?? 0;
  const stride2 = state.strides[2] ?? 0;
  for (let l = 0; l < numLayers; l++) {
    const layerState = result[l];
    if (!layerState) {
      throw new chunkJSCDE774_cjs.ShapeError(`Internal error: missing ${name} layer state`);
    }
    // Strided copy: walk (layer, batch, hidden) via the tensor's own strides.
    const baseLayer = state.offset + l * stride0;
    for (let b = 0; b < batch; b++) {
      const baseBatch = baseLayer + b * stride1;
      for (let j = 0; j < hiddenSize; j++) {
        layerState[b * hiddenSize + j] = readNumeric(state, baseBatch + j * stride2);
      }
    }
  }
  return result;
}
|
|
2034
|
+
// Packs per-layer flat state arrays back into a single Tensor.
// Unbatched result shape is [numLayers, hiddenSize]; batched result shape is
// [numLayers, batch, hiddenSize]. Each state[l] is assumed to already be laid
// out row-major as (batch, hidden), so every layer is one contiguous copy.
function packState(state, numLayers, batch, hiddenSize, dtype, device, isUnbatched) {
  const perLayer = isUnbatched ? hiddenSize : batch * hiddenSize;
  const data = createFloatBuffer(numLayers * perLayer, dtype);
  for (let layerIdx = 0; layerIdx < numLayers; layerIdx++) {
    const layer = state[layerIdx];
    if (!layer) {
      throw new chunkJSCDE774_cjs.ShapeError("Internal error: missing packed state layer");
    }
    const dst = layerIdx * perLayer;
    for (let k = 0; k < perLayer; k++) {
      // Missing entries default to 0 to keep the output buffer fully defined.
      data[dst + k] = layer[k] ?? 0;
    }
  }
  const shape = isUnbatched ? [numLayers, hiddenSize] : [numLayers, batch, hiddenSize];
  return chunk6AE5FKKQ_cjs.Tensor.fromTypedArray({
    data,
    shape,
    dtype,
    device
  });
}
|
|
2074
|
+
// Multi-layer Elman RNN. Supports tanh (default) or ReLU nonlinearity,
// optional biases, batch-first or seq-first layouts, and unbatched 2-D input.
// Weights are initialized as randn scaled by 1/sqrt(hiddenSize); biases start at zero.
var RNN = class extends Module {
  // Hyper-parameters fixed at construction.
  inputSize;
  hiddenSize;
  numLayers;
  // "tanh" selects Math.tanh; any other value selects ReLU (see activation()).
  nonlinearity;
  // Whether additive bias terms are created and registered.
  bias;
  // true: batched input is (batch, seq, feature); false: (seq, batch, feature).
  batchFirst;
  // Per-layer parameter tensors, indexed by layer number.
  weightsIh;
  weightsHh;
  biasIh;
  biasHh;
  /**
   * @param inputSize  feature size of the input at layer 0 (positive integer)
   * @param hiddenSize hidden state size per layer (positive integer)
   * @param options    { numLayers?, nonlinearity?, bias?, batchFirst? }
   */
  constructor(inputSize, hiddenSize, options = {}) {
    super();
    validatePositiveInt("inputSize", inputSize);
    validatePositiveInt("hiddenSize", hiddenSize);
    const numLayers = options.numLayers ?? 1;
    validatePositiveInt("numLayers", numLayers);
    this.inputSize = inputSize;
    this.hiddenSize = hiddenSize;
    this.numLayers = numLayers;
    this.nonlinearity = options.nonlinearity ?? "tanh";
    this.bias = options.bias ?? true;
    this.batchFirst = options.batchFirst ?? true;
    // Scale factor for random weight init.
    const stdv = 1 / Math.sqrt(hiddenSize);
    this.weightsIh = [];
    this.weightsHh = [];
    this.biasIh = [];
    this.biasHh = [];
    for (let layer = 0; layer < this.numLayers; layer++) {
      // Layers above the first consume the previous layer's hidden output.
      const inputDim = layer === 0 ? inputSize : hiddenSize;
      const wIh = chunk6AE5FKKQ_cjs.mulScalar(chunk6AE5FKKQ_cjs.randn([hiddenSize, inputDim]), stdv);
      const wHh = chunk6AE5FKKQ_cjs.mulScalar(chunk6AE5FKKQ_cjs.randn([hiddenSize, hiddenSize]), stdv);
      this.weightsIh.push(wIh);
      this.weightsHh.push(wHh);
      this.registerParameter(`weight_ih_l${layer}`, chunk6AE5FKKQ_cjs.parameter(wIh));
      this.registerParameter(`weight_hh_l${layer}`, chunk6AE5FKKQ_cjs.parameter(wHh));
      if (this.bias) {
        const bIh = chunk6AE5FKKQ_cjs.zeros([hiddenSize]);
        const bHh = chunk6AE5FKKQ_cjs.zeros([hiddenSize]);
        this.biasIh.push(bIh);
        this.biasHh.push(bHh);
        this.registerParameter(`bias_ih_l${layer}`, chunk6AE5FKKQ_cjs.parameter(bIh));
        this.registerParameter(`bias_hh_l${layer}`, chunk6AE5FKKQ_cjs.parameter(bHh));
      }
    }
  }
  // Elementwise nonlinearity: tanh when configured, otherwise ReLU.
  activation(x) {
    return this.nonlinearity === "tanh" ? Math.tanh(x) : Math.max(0, x);
  }
  /**
   * Full forward pass over the sequence.
   * @param input 2-D (seq, feature) unbatched or 3-D batched tensor (layout per batchFirst)
   * @param hx    optional initial hidden state; defaults handled by buildState
   * @returns { output, h } where h is the packed final hidden state
   */
  run(input, hx) {
    ensureFloatTensor(input, "RNN");
    // parseInput resolves layout to strides so we can read any memory order.
    const parsed = parseInput(input, this.batchFirst);
    const { batch, seqLen, inputDim, isUnbatched, batchStride, seqStride, featStride } = parsed;
    if (inputDim !== this.inputSize) {
      throw new chunkJSCDE774_cjs.ShapeError(`Expected input size ${this.inputSize}, got ${inputDim}`);
    }
    if (seqLen <= 0) {
      throw new chunkJSCDE774_cjs.InvalidParameterError("Sequence length must be positive", "seqLen", seqLen);
    }
    if (!isUnbatched && batch <= 0) {
      throw new chunkJSCDE774_cjs.InvalidParameterError("Batch size must be positive", "batch", batch);
    }
    // h[l] is a flat (batch, hidden) array per layer, mutated in place each step.
    const h = buildState(hx, this.numLayers, batch, this.hiddenSize, isUnbatched, "hx");
    const outSize = (isUnbatched ? seqLen : batch * seqLen) * this.hiddenSize;
    const out = createFloatBuffer(outSize, input.dtype);
    // Scratch buffer for the current timestep's input features (reused).
    const inputVec = new Float64Array(inputDim);
    for (let t = 0; t < seqLen; t++) {
      for (let b = 0; b < batch; b++) {
        // Gather the (b, t) feature vector via strided reads.
        const baseOffset = input.offset + b * batchStride + t * seqStride;
        for (let i = 0; i < inputDim; i++) {
          inputVec[i] = readNumeric(input, baseOffset + i * featStride);
        }
        // layerInput flows upward: layer 0 sees the raw features,
        // each deeper layer sees the previous layer's new hidden vector.
        let layerInput = inputVec;
        for (let l = 0; l < this.numLayers; l++) {
          const wIh = this.weightsIh[l];
          const wHh = this.weightsHh[l];
          if (!wIh || !wHh) {
            throw new chunkJSCDE774_cjs.ShapeError("Internal error: missing RNN weights");
          }
          const curInputSize = l === 0 ? this.inputSize : this.hiddenSize;
          const newH = new Float64Array(this.hiddenSize);
          const hLayer = h[l];
          if (!hLayer) {
            throw new chunkJSCDE774_cjs.ShapeError("Internal error: missing RNN hidden state");
          }
          // Hoist strides out of the inner loops.
          const wIhStride0 = wIh.strides[0] ?? 0;
          const wIhStride1 = wIh.strides[1] ?? 0;
          const wHhStride0 = wHh.strides[0] ?? 0;
          const wHhStride1 = wHh.strides[1] ?? 0;
          const biasIh = this.biasIh[l];
          const biasHh = this.biasHh[l];
          const biasIhStride = biasIh ? biasIh.strides[0] ?? 0 : 0;
          const biasHhStride = biasHh ? biasHh.strides[0] ?? 0 : 0;
          for (let j = 0; j < this.hiddenSize; j++) {
            // h_t[j] = act(W_ih[j,:]·x + W_hh[j,:]·h_{t-1} + b_ih[j] + b_hh[j])
            let sum2 = 0;
            const wIhBase = wIh.offset + j * wIhStride0;
            for (let k = 0; k < curInputSize; k++) {
              sum2 += (layerInput[k] ?? 0) * readNumeric(wIh, wIhBase + k * wIhStride1);
            }
            const wHhBase = wHh.offset + j * wHhStride0;
            for (let k = 0; k < this.hiddenSize; k++) {
              sum2 += (hLayer[b * this.hiddenSize + k] ?? 0) * readNumeric(wHh, wHhBase + k * wHhStride1);
            }
            if (this.bias && biasIh && biasHh) {
              sum2 += readNumeric(biasIh, biasIh.offset + j * biasIhStride);
              sum2 += readNumeric(biasHh, biasHh.offset + j * biasHhStride);
            }
            newH[j] = this.activation(sum2);
          }
          // Commit the new hidden vector for this (layer, batch) slot only
          // after all of newH is computed (the old values feed the dot products above).
          for (let j = 0; j < this.hiddenSize; j++) {
            hLayer[b * this.hiddenSize + j] = newH[j] ?? 0;
          }
          layerInput = newH;
        }
        // The top layer's hidden vector is this timestep's output.
        for (let j = 0; j < this.hiddenSize; j++) {
          const idx = outputIndex(
            this.batchFirst,
            isUnbatched,
            batch,
            seqLen,
            this.hiddenSize,
            b,
            t,
            j
          );
          out[idx] = layerInput[j] ?? 0;
        }
      }
    }
    const outShape = isUnbatched ? [seqLen, this.hiddenSize] : this.batchFirst ? [batch, seqLen, this.hiddenSize] : [seqLen, batch, this.hiddenSize];
    return {
      output: chunk6AE5FKKQ_cjs.Tensor.fromTypedArray({
        data: out,
        shape: outShape,
        dtype: input.dtype,
        device: input.device
      }),
      h: packState(
        h,
        this.numLayers,
        batch,
        this.hiddenSize,
        input.dtype,
        input.device,
        isUnbatched
      )
    };
  }
  // Module-style entry point: forward(input[, hx]) -> output tensor only.
  forward(...inputs) {
    if (inputs.length < 1 || inputs.length > 2) {
      throw new chunkJSCDE774_cjs.InvalidParameterError("RNN.forward expects 1 or 2 inputs", "inputs", inputs.length);
    }
    const inputArg = inputs[0];
    if (inputArg === void 0) {
      throw new chunkJSCDE774_cjs.InvalidParameterError("RNN.forward requires an input tensor", "input", inputArg);
    }
    const input = extractTensor(inputArg);
    const hxArg = inputs.length === 2 ? inputs[1] : void 0;
    const hx = hxArg === void 0 ? void 0 : extractTensor(hxArg);
    return this.run(input, hx).output;
  }
  /**
   * Forward pass returning both output and hidden state.
   * Use this method when you need the hidden state.
   */
  forwardWithState(input, hx) {
    const inputTensor = extractTensor(input);
    const hxTensor = hx === void 0 ? void 0 : extractTensor(hx);
    const { output, h } = this.run(inputTensor, hxTensor);
    return [output, h];
  }
  toString() {
    return `RNN(${this.inputSize}, ${this.hiddenSize}, num_layers=${this.numLayers})`;
  }
};
|
|
2249
|
+
// Multi-layer LSTM. Per layer, the 4*hiddenSize gate rows are ordered
// [input i | forget f | cell candidate g | output o] (see the indexing in run()).
// Weights are initialized as randn scaled by 1/sqrt(hiddenSize); biases start at zero.
var LSTM = class extends Module {
  // Hyper-parameters fixed at construction.
  inputSize;
  hiddenSize;
  numLayers;
  // Whether additive bias terms are created and registered.
  bias;
  // true: batched input is (batch, seq, feature); false: (seq, batch, feature).
  batchFirst;
  // Per-layer parameter tensors; weight shapes are (4*hidden, in) and (4*hidden, hidden).
  weightsIh;
  weightsHh;
  biasIh;
  biasHh;
  /**
   * @param inputSize  feature size of the input at layer 0 (positive integer)
   * @param hiddenSize hidden/cell state size per layer (positive integer)
   * @param options    { numLayers?, bias?, batchFirst? }
   */
  constructor(inputSize, hiddenSize, options = {}) {
    super();
    validatePositiveInt("inputSize", inputSize);
    validatePositiveInt("hiddenSize", hiddenSize);
    const numLayers = options.numLayers ?? 1;
    validatePositiveInt("numLayers", numLayers);
    this.inputSize = inputSize;
    this.hiddenSize = hiddenSize;
    this.numLayers = numLayers;
    this.bias = options.bias ?? true;
    this.batchFirst = options.batchFirst ?? true;
    // Scale factor for random weight init.
    const stdv = 1 / Math.sqrt(hiddenSize);
    this.weightsIh = [];
    this.weightsHh = [];
    this.biasIh = [];
    this.biasHh = [];
    for (let layer = 0; layer < this.numLayers; layer++) {
      // Layers above the first consume the previous layer's hidden output.
      const inputDim = layer === 0 ? inputSize : hiddenSize;
      const wIh = chunk6AE5FKKQ_cjs.mulScalar(chunk6AE5FKKQ_cjs.randn([4 * hiddenSize, inputDim]), stdv);
      const wHh = chunk6AE5FKKQ_cjs.mulScalar(chunk6AE5FKKQ_cjs.randn([4 * hiddenSize, hiddenSize]), stdv);
      this.weightsIh.push(wIh);
      this.weightsHh.push(wHh);
      this.registerParameter(`weight_ih_l${layer}`, chunk6AE5FKKQ_cjs.parameter(wIh));
      this.registerParameter(`weight_hh_l${layer}`, chunk6AE5FKKQ_cjs.parameter(wHh));
      if (this.bias) {
        const bIh = chunk6AE5FKKQ_cjs.zeros([4 * hiddenSize]);
        const bHh = chunk6AE5FKKQ_cjs.zeros([4 * hiddenSize]);
        this.biasIh.push(bIh);
        this.biasHh.push(bHh);
        this.registerParameter(`bias_ih_l${layer}`, chunk6AE5FKKQ_cjs.parameter(bIh));
        this.registerParameter(`bias_hh_l${layer}`, chunk6AE5FKKQ_cjs.parameter(bHh));
      }
    }
  }
  // Logistic sigmoid used for the i/f/o gates.
  sigmoid(x) {
    return 1 / (1 + Math.exp(-x));
  }
  /**
   * Full forward pass over the sequence.
   * @param input 2-D (seq, feature) unbatched or 3-D batched tensor (layout per batchFirst)
   * @param hx    optional initial hidden state
   * @param cx    optional initial cell state
   * @returns { output, h, c } with h/c packed as tensors
   */
  run(input, hx, cx) {
    ensureFloatTensor(input, "LSTM");
    // parseInput resolves layout to strides so we can read any memory order.
    const parsed = parseInput(input, this.batchFirst);
    const { batch, seqLen, inputDim, isUnbatched, batchStride, seqStride, featStride } = parsed;
    if (inputDim !== this.inputSize) {
      throw new chunkJSCDE774_cjs.ShapeError(`Expected input size ${this.inputSize}, got ${inputDim}`);
    }
    if (seqLen <= 0) {
      throw new chunkJSCDE774_cjs.InvalidParameterError("Sequence length must be positive", "seqLen", seqLen);
    }
    if (!isUnbatched && batch <= 0) {
      throw new chunkJSCDE774_cjs.InvalidParameterError("Batch size must be positive", "batch", batch);
    }
    // h[l] / c[l] are flat (batch, hidden) arrays per layer, mutated in place.
    const h = buildState(hx, this.numLayers, batch, this.hiddenSize, isUnbatched, "hx");
    const c = buildState(cx, this.numLayers, batch, this.hiddenSize, isUnbatched, "cx");
    const outSize = (isUnbatched ? seqLen : batch * seqLen) * this.hiddenSize;
    const out = createFloatBuffer(outSize, input.dtype);
    // Scratch buffers reused across every (t, b, layer) step.
    const inputVec = new Float64Array(inputDim);
    const gates = new Float64Array(4 * this.hiddenSize);
    for (let t = 0; t < seqLen; t++) {
      for (let b = 0; b < batch; b++) {
        // Gather the (b, t) feature vector via strided reads.
        const baseOffset = input.offset + b * batchStride + t * seqStride;
        for (let i = 0; i < inputDim; i++) {
          inputVec[i] = readNumeric(input, baseOffset + i * featStride);
        }
        // layerInput flows upward through the stacked layers.
        let layerInput = inputVec;
        for (let l = 0; l < this.numLayers; l++) {
          const wIh = this.weightsIh[l];
          const wHh = this.weightsHh[l];
          if (!wIh || !wHh) {
            throw new chunkJSCDE774_cjs.ShapeError("Internal error: missing LSTM weights");
          }
          const curInputSize = l === 0 ? this.inputSize : this.hiddenSize;
          const hLayer = h[l];
          const cLayer = c[l];
          if (!hLayer || !cLayer) {
            throw new chunkJSCDE774_cjs.ShapeError("Internal error: missing LSTM state");
          }
          // Hoist strides out of the inner loops.
          const wIhStride0 = wIh.strides[0] ?? 0;
          const wIhStride1 = wIh.strides[1] ?? 0;
          const wHhStride0 = wHh.strides[0] ?? 0;
          const wHhStride1 = wHh.strides[1] ?? 0;
          const biasIh = this.biasIh[l];
          const biasHh = this.biasHh[l];
          const biasIhStride = biasIh ? biasIh.strides[0] ?? 0 : 0;
          const biasHhStride = biasHh ? biasHh.strides[0] ?? 0 : 0;
          // Pre-activation for all four gates:
          // gates[g] = W_ih[g,:]·x + W_hh[g,:]·h_{t-1} (+ b_ih[g] + b_hh[g])
          for (let g = 0; g < 4 * this.hiddenSize; g++) {
            let sum2 = 0;
            const wIhBase = wIh.offset + g * wIhStride0;
            for (let k = 0; k < curInputSize; k++) {
              sum2 += (layerInput[k] ?? 0) * readNumeric(wIh, wIhBase + k * wIhStride1);
            }
            const wHhBase = wHh.offset + g * wHhStride0;
            for (let k = 0; k < this.hiddenSize; k++) {
              sum2 += (hLayer[b * this.hiddenSize + k] ?? 0) * readNumeric(wHh, wHhBase + k * wHhStride1);
            }
            if (this.bias && biasIh && biasHh) {
              sum2 += readNumeric(biasIh, biasIh.offset + g * biasIhStride);
              sum2 += readNumeric(biasHh, biasHh.offset + g * biasHhStride);
            }
            gates[g] = sum2;
          }
          const newH = new Float64Array(this.hiddenSize);
          const newC = new Float64Array(this.hiddenSize);
          for (let j = 0; j < this.hiddenSize; j++) {
            // Gate rows: [0,H) = i, [H,2H) = f, [2H,3H) = g, [3H,4H) = o.
            const iGate = this.sigmoid(gates[j] ?? 0);
            const fGate = this.sigmoid(gates[this.hiddenSize + j] ?? 0);
            const gGate = Math.tanh(gates[2 * this.hiddenSize + j] ?? 0);
            const oGate = this.sigmoid(gates[3 * this.hiddenSize + j] ?? 0);
            const prevC = cLayer[b * this.hiddenSize + j] ?? 0;
            // c_t = f*c_{t-1} + i*g ; h_t = o*tanh(c_t)
            const nextC = fGate * prevC + iGate * gGate;
            const nextH = oGate * Math.tanh(nextC);
            newC[j] = nextC;
            newH[j] = nextH;
          }
          // Commit h/c for this (layer, batch) slot only after the whole
          // update is computed (old h feeds the gate dot products above).
          for (let j = 0; j < this.hiddenSize; j++) {
            hLayer[b * this.hiddenSize + j] = newH[j] ?? 0;
            cLayer[b * this.hiddenSize + j] = newC[j] ?? 0;
          }
          layerInput = newH;
        }
        // The top layer's hidden vector is this timestep's output.
        for (let j = 0; j < this.hiddenSize; j++) {
          const idx = outputIndex(
            this.batchFirst,
            isUnbatched,
            batch,
            seqLen,
            this.hiddenSize,
            b,
            t,
            j
          );
          out[idx] = layerInput[j] ?? 0;
        }
      }
    }
    const outShape = isUnbatched ? [seqLen, this.hiddenSize] : this.batchFirst ? [batch, seqLen, this.hiddenSize] : [seqLen, batch, this.hiddenSize];
    return {
      output: chunk6AE5FKKQ_cjs.Tensor.fromTypedArray({
        data: out,
        shape: outShape,
        dtype: input.dtype,
        device: input.device
      }),
      h: packState(
        h,
        this.numLayers,
        batch,
        this.hiddenSize,
        input.dtype,
        input.device,
        isUnbatched
      ),
      c: packState(
        c,
        this.numLayers,
        batch,
        this.hiddenSize,
        input.dtype,
        input.device,
        isUnbatched
      )
    };
  }
  // Module-style entry point: forward(input[, hx[, cx]]) -> output tensor only.
  forward(...inputs) {
    if (inputs.length < 1 || inputs.length > 3) {
      throw new chunkJSCDE774_cjs.InvalidParameterError(
        "LSTM.forward expects 1 to 3 inputs",
        "inputs",
        inputs.length
      );
    }
    const inputArg = inputs[0];
    if (inputArg === void 0) {
      throw new chunkJSCDE774_cjs.InvalidParameterError("LSTM.forward requires an input tensor", "input", inputArg);
    }
    const input = extractTensor(inputArg);
    const hxArg = inputs.length >= 2 ? inputs[1] : void 0;
    const cxArg = inputs.length >= 3 ? inputs[2] : void 0;
    const hx = hxArg === void 0 ? void 0 : extractTensor(hxArg);
    const cx = cxArg === void 0 ? void 0 : extractTensor(cxArg);
    return this.run(input, hx, cx).output;
  }
  /**
   * Forward pass returning output, hidden state, and cell state.
   * Use this method when you need the hidden/cell states.
   * @returns [output, [h, c]]
   */
  forwardWithState(input, hx, cx) {
    const inputTensor = extractTensor(input);
    const hxTensor = hx === void 0 ? void 0 : extractTensor(hx);
    const cxTensor = cx === void 0 ? void 0 : extractTensor(cx);
    const { output, h, c } = this.run(inputTensor, hxTensor, cxTensor);
    return [output, [h, c]];
  }
  toString() {
    return `LSTM(${this.inputSize}, ${this.hiddenSize}, num_layers=${this.numLayers})`;
  }
};
|
|
2454
|
+
// Multi-layer GRU. Per layer, the 3*hiddenSize gate rows are ordered
// [reset r | update z | candidate n] (see the indexing in run()).
// Input-to-hidden and hidden-to-hidden gate sums are accumulated separately
// because the reset gate scales only the hidden-side candidate contribution.
var GRU = class extends Module {
  // Hyper-parameters fixed at construction.
  inputSize;
  hiddenSize;
  numLayers;
  // Whether additive bias terms are created and registered.
  bias;
  // true: batched input is (batch, seq, feature); false: (seq, batch, feature).
  batchFirst;
  // Per-layer parameter tensors; weight shapes are (3*hidden, in) and (3*hidden, hidden).
  weightsIh;
  weightsHh;
  biasIh;
  biasHh;
  /**
   * @param inputSize  feature size of the input at layer 0 (positive integer)
   * @param hiddenSize hidden state size per layer (positive integer)
   * @param options    { numLayers?, bias?, batchFirst? }
   */
  constructor(inputSize, hiddenSize, options = {}) {
    super();
    validatePositiveInt("inputSize", inputSize);
    validatePositiveInt("hiddenSize", hiddenSize);
    const numLayers = options.numLayers ?? 1;
    validatePositiveInt("numLayers", numLayers);
    this.inputSize = inputSize;
    this.hiddenSize = hiddenSize;
    this.numLayers = numLayers;
    this.bias = options.bias ?? true;
    this.batchFirst = options.batchFirst ?? true;
    // Scale factor for random weight init.
    const stdv = 1 / Math.sqrt(hiddenSize);
    this.weightsIh = [];
    this.weightsHh = [];
    this.biasIh = [];
    this.biasHh = [];
    for (let layer = 0; layer < this.numLayers; layer++) {
      // Layers above the first consume the previous layer's hidden output.
      const inputDim = layer === 0 ? inputSize : hiddenSize;
      const wIh = chunk6AE5FKKQ_cjs.mulScalar(chunk6AE5FKKQ_cjs.randn([3 * hiddenSize, inputDim]), stdv);
      const wHh = chunk6AE5FKKQ_cjs.mulScalar(chunk6AE5FKKQ_cjs.randn([3 * hiddenSize, hiddenSize]), stdv);
      this.weightsIh.push(wIh);
      this.weightsHh.push(wHh);
      this.registerParameter(`weight_ih_l${layer}`, chunk6AE5FKKQ_cjs.parameter(wIh));
      this.registerParameter(`weight_hh_l${layer}`, chunk6AE5FKKQ_cjs.parameter(wHh));
      if (this.bias) {
        const bIh = chunk6AE5FKKQ_cjs.zeros([3 * hiddenSize]);
        const bHh = chunk6AE5FKKQ_cjs.zeros([3 * hiddenSize]);
        this.biasIh.push(bIh);
        this.biasHh.push(bHh);
        this.registerParameter(`bias_ih_l${layer}`, chunk6AE5FKKQ_cjs.parameter(bIh));
        this.registerParameter(`bias_hh_l${layer}`, chunk6AE5FKKQ_cjs.parameter(bHh));
      }
    }
  }
  // Logistic sigmoid used for the r/z gates.
  sigmoid(x) {
    return 1 / (1 + Math.exp(-x));
  }
  /**
   * Full forward pass over the sequence.
   * @param input 2-D (seq, feature) unbatched or 3-D batched tensor (layout per batchFirst)
   * @param hx    optional initial hidden state
   * @returns { output, h } where h is the packed final hidden state
   */
  run(input, hx) {
    ensureFloatTensor(input, "GRU");
    // parseInput resolves layout to strides so we can read any memory order.
    const parsed = parseInput(input, this.batchFirst);
    const { batch, seqLen, inputDim, isUnbatched, batchStride, seqStride, featStride } = parsed;
    if (inputDim !== this.inputSize) {
      throw new chunkJSCDE774_cjs.ShapeError(`Expected input size ${this.inputSize}, got ${inputDim}`);
    }
    if (seqLen <= 0) {
      throw new chunkJSCDE774_cjs.InvalidParameterError("Sequence length must be positive", "seqLen", seqLen);
    }
    if (!isUnbatched && batch <= 0) {
      throw new chunkJSCDE774_cjs.InvalidParameterError("Batch size must be positive", "batch", batch);
    }
    // h[l] is a flat (batch, hidden) array per layer, mutated in place each step.
    const h = buildState(hx, this.numLayers, batch, this.hiddenSize, isUnbatched, "hx");
    const outSize = (isUnbatched ? seqLen : batch * seqLen) * this.hiddenSize;
    const out = createFloatBuffer(outSize, input.dtype);
    // Scratch buffers reused across every (t, b, layer) step. Input-side and
    // hidden-side gate sums are kept separate for the reset-gate coupling.
    const inputVec = new Float64Array(inputDim);
    const gatesIh = new Float64Array(3 * this.hiddenSize);
    const gatesHh = new Float64Array(3 * this.hiddenSize);
    for (let t = 0; t < seqLen; t++) {
      for (let b = 0; b < batch; b++) {
        // Gather the (b, t) feature vector via strided reads.
        const baseOffset = input.offset + b * batchStride + t * seqStride;
        for (let i = 0; i < inputDim; i++) {
          inputVec[i] = readNumeric(input, baseOffset + i * featStride);
        }
        // layerInput flows upward through the stacked layers.
        let layerInput = inputVec;
        for (let l = 0; l < this.numLayers; l++) {
          const wIh = this.weightsIh[l];
          const wHh = this.weightsHh[l];
          if (!wIh || !wHh) {
            throw new chunkJSCDE774_cjs.ShapeError("Internal error: missing GRU weights");
          }
          const curInputSize = l === 0 ? this.inputSize : this.hiddenSize;
          const hLayer = h[l];
          if (!hLayer) {
            throw new chunkJSCDE774_cjs.ShapeError("Internal error: missing GRU hidden state");
          }
          // Hoist strides out of the inner loops.
          const wIhStride0 = wIh.strides[0] ?? 0;
          const wIhStride1 = wIh.strides[1] ?? 0;
          const wHhStride0 = wHh.strides[0] ?? 0;
          const wHhStride1 = wHh.strides[1] ?? 0;
          const biasIh = this.biasIh[l];
          const biasHh = this.biasHh[l];
          const biasIhStride = biasIh ? biasIh.strides[0] ?? 0 : 0;
          const biasHhStride = biasHh ? biasHh.strides[0] ?? 0 : 0;
          // Pre-activations:
          // gatesIh[g] = W_ih[g,:]·x (+ b_ih[g]); gatesHh[g] = W_hh[g,:]·h_{t-1} (+ b_hh[g])
          for (let g = 0; g < 3 * this.hiddenSize; g++) {
            let sumIh = 0;
            let sumHh = 0;
            const wIhBase = wIh.offset + g * wIhStride0;
            for (let k = 0; k < curInputSize; k++) {
              sumIh += (layerInput[k] ?? 0) * readNumeric(wIh, wIhBase + k * wIhStride1);
            }
            const wHhBase = wHh.offset + g * wHhStride0;
            for (let k = 0; k < this.hiddenSize; k++) {
              sumHh += (hLayer[b * this.hiddenSize + k] ?? 0) * readNumeric(wHh, wHhBase + k * wHhStride1);
            }
            if (this.bias && biasIh && biasHh) {
              sumIh += readNumeric(biasIh, biasIh.offset + g * biasIhStride);
              sumHh += readNumeric(biasHh, biasHh.offset + g * biasHhStride);
            }
            gatesIh[g] = sumIh;
            gatesHh[g] = sumHh;
          }
          const newH = new Float64Array(this.hiddenSize);
          for (let j = 0; j < this.hiddenSize; j++) {
            // Gate rows: [0,H) = r (reset), [H,2H) = z (update), [2H,3H) = n (candidate).
            const r = this.sigmoid((gatesIh[j] ?? 0) + (gatesHh[j] ?? 0));
            const z = this.sigmoid(
              (gatesIh[this.hiddenSize + j] ?? 0) + (gatesHh[this.hiddenSize + j] ?? 0)
            );
            // Candidate: n = tanh(ih_n + r * hh_n) — reset gate scales only the hidden side.
            const n = Math.tanh(
              (gatesIh[2 * this.hiddenSize + j] ?? 0) + r * (gatesHh[2 * this.hiddenSize + j] ?? 0)
            );
            // h_t = (1 - z)*n + z*h_{t-1}
            newH[j] = (1 - z) * n + z * (hLayer[b * this.hiddenSize + j] ?? 0);
          }
          // Commit h for this (layer, batch) slot only after the whole update
          // is computed (old h feeds the gate dot products above).
          for (let j = 0; j < this.hiddenSize; j++) {
            hLayer[b * this.hiddenSize + j] = newH[j] ?? 0;
          }
          layerInput = newH;
        }
        // The top layer's hidden vector is this timestep's output.
        for (let j = 0; j < this.hiddenSize; j++) {
          const idx = outputIndex(
            this.batchFirst,
            isUnbatched,
            batch,
            seqLen,
            this.hiddenSize,
            b,
            t,
            j
          );
          out[idx] = layerInput[j] ?? 0;
        }
      }
    }
    const outShape = isUnbatched ? [seqLen, this.hiddenSize] : this.batchFirst ? [batch, seqLen, this.hiddenSize] : [seqLen, batch, this.hiddenSize];
    return {
      output: chunk6AE5FKKQ_cjs.Tensor.fromTypedArray({
        data: out,
        shape: outShape,
        dtype: input.dtype,
        device: input.device
      }),
      h: packState(
        h,
        this.numLayers,
        batch,
        this.hiddenSize,
        input.dtype,
        input.device,
        isUnbatched
      )
    };
  }
  // Module-style entry point: forward(input[, hx]) -> output tensor only.
  forward(...inputs) {
    if (inputs.length < 1 || inputs.length > 2) {
      throw new chunkJSCDE774_cjs.InvalidParameterError("GRU.forward expects 1 or 2 inputs", "inputs", inputs.length);
    }
    const inputArg = inputs[0];
    if (inputArg === void 0) {
      throw new chunkJSCDE774_cjs.InvalidParameterError("GRU.forward requires an input tensor", "input", inputArg);
    }
    const input = extractTensor(inputArg);
    const hxArg = inputs.length === 2 ? inputs[1] : void 0;
    const hx = hxArg === void 0 ? void 0 : extractTensor(hxArg);
    return this.run(input, hx).output;
  }
  /**
   * Forward pass returning both output and hidden state.
   * Use this method when you need the hidden state.
   */
  forwardWithState(input, hx) {
    const inputTensor = extractTensor(input);
    const hxTensor = hx === void 0 ? void 0 : extractTensor(hx);
    const { output, h } = this.run(inputTensor, hxTensor);
    return [output, h];
  }
  toString() {
    return `GRU(${this.inputSize}, ${this.hiddenSize}, num_layers=${this.numLayers})`;
  }
};
|
|
2641
|
+
|
|
2642
|
+
// src/nn/losses/crossEntropy.ts
|
|
2643
|
+
// Expands a 1-D tensor of integer class indices into a dense float32 one-hot
// matrix of shape (nSamples, numClasses) on the same device as `indices`.
// Rejects string-backed tensors, non-integer values, and out-of-range indices.
function toOneHot(indices, numClasses) {
  const sampleCount = indices.size;
  const buffer = indices.data;
  if (Array.isArray(buffer)) {
    throw new chunkJSCDE774_cjs.DTypeError("crossEntropyLoss target indices must be numeric");
  }
  const rowStride = indices.strides[0] ?? 0;
  const hot = new Float32Array(sampleCount * numClasses);
  // Decode the class index stored at a flat offset, handling int64 storage
  // (which must round-trip safely through Number).
  const decodeIndex = (offset) => {
    if (!(buffer instanceof BigInt64Array)) {
      return Number(chunkJSCDE774_cjs.getNumericElement(buffer, offset));
    }
    const raw = chunkJSCDE774_cjs.getBigIntElement(buffer, offset);
    const asNumber = Number(raw);
    if (!Number.isSafeInteger(asNumber)) {
      throw new chunkJSCDE774_cjs.InvalidParameterError(
        `Class index ${raw.toString()} exceeds safe integer range`,
        "target",
        raw.toString()
      );
    }
    return asNumber;
  };
  for (let row = 0; row < sampleCount; row++) {
    const idx = decodeIndex(indices.offset + row * rowStride);
    if (!Number.isFinite(idx) || !Number.isInteger(idx)) {
      throw new chunkJSCDE774_cjs.InvalidParameterError(`Class index ${idx} is not a valid integer`, "target", idx);
    }
    if (idx < 0 || idx >= numClasses) {
      throw new chunkJSCDE774_cjs.InvalidParameterError(
        `Class index ${idx} out of range [0, ${numClasses})`,
        "target",
        idx
      );
    }
    hot[row * numClasses + idx] = 1;
  }
  return chunk6AE5FKKQ_cjs.Tensor.fromTypedArray({
    data: hot,
    shape: [sampleCount, numClasses],
    dtype: "float32",
    device: indices.device
  });
}
|
|
2688
|
+
// Cross-entropy loss over 2-D logits (batch, classes).
// Target may be 1-D class indices (converted to one-hot) or a 2-D tensor with
// the same shape as the input. When both arguments are plain tensors the
// scalar loss is returned as a plain number; otherwise a GradTensor is
// returned so gradients can flow.
function crossEntropyLoss(input, target) {
  // Wrap plain tensors so the op chain below works uniformly under autograd.
  const targetIsGrad = target instanceof chunk6AE5FKKQ_cjs.GradTensor;
  const predictions = input instanceof chunk6AE5FKKQ_cjs.GradTensor ? input : chunk6AE5FKKQ_cjs.GradTensor.fromTensor(input);
  const labels = targetIsGrad ? target : chunk6AE5FKKQ_cjs.GradTensor.fromTensor(target, { requiresGrad: false });
  if (predictions.ndim !== 2) {
    throw new chunkJSCDE774_cjs.ShapeError(`Input must be 2-dimensional (batch, classes); got ${predictions.ndim}`);
  }
  const sampleCount = predictions.shape[0] ?? 0;
  const classCount = predictions.shape[1] ?? 0;
  let probsTarget = labels;
  switch (labels.ndim) {
    case 1: {
      // Index targets: expand to a one-hot probability matrix.
      if (targetIsGrad) {
        throw new chunkJSCDE774_cjs.ShapeError("Target must be 2-dimensional when provided as GradTensor");
      }
      if (labels.shape[0] !== sampleCount) {
        throw new chunkJSCDE774_cjs.ShapeError(
          `Target must have same number of samples as input; got ${labels.shape[0]} and ${sampleCount}`
        );
      }
      const oneHot = toOneHot(labels.tensor, classCount);
      probsTarget = chunk6AE5FKKQ_cjs.GradTensor.fromTensor(oneHot, { requiresGrad: false });
      break;
    }
    case 2: {
      // Probability targets must match the logits exactly.
      if (labels.shape[0] !== sampleCount || labels.shape[1] !== classCount) {
        throw new chunkJSCDE774_cjs.ShapeError(
          "Target must be 1-dimensional class indices or have the same shape as input"
        );
      }
      break;
    }
    default:
      throw new chunkJSCDE774_cjs.ShapeError(`Target must be 1D (indices) or 2D (probs); got ${labels.ndim}D`);
  }
  // loss = -mean_over_samples( sum_over_classes( target * logSoftmax(pred) ) )
  const meanLoss = chunk6AE5FKKQ_cjs.logSoftmax2(predictions, 1).mul(probsTarget).sum(1).mean().neg();
  if (input instanceof chunk6AE5FKKQ_cjs.GradTensor || targetIsGrad) {
    return meanLoss;
  }
  // Both arguments were plain tensors: unwrap the scalar loss to a number.
  const data = meanLoss.tensor.data;
  if (Array.isArray(data)) {
    throw new chunkJSCDE774_cjs.DTypeError("crossEntropyLoss does not support string dtype");
  }
  if (data instanceof BigInt64Array) {
    const raw = chunkJSCDE774_cjs.getBigIntElement(data, meanLoss.tensor.offset);
    return Number(raw);
  }
  return chunkJSCDE774_cjs.getNumericElement(data, meanLoss.tensor.offset);
}
|
|
2735
|
+
/**
 * Binary cross-entropy computed directly on logits using the numerically
 * stable decomposition: relu(x) - x*y + log(1 + exp(-|x|)), averaged over
 * the batch.
 *
 * Accepts plain tensors or GradTensors shaped (N,) or (N, 1). Returns a
 * plain JS number when BOTH arguments came in as plain tensors; otherwise
 * returns a GradTensor so gradients can flow back through the loss.
 */
function binaryCrossEntropyWithLogitsLoss(input, target) {
  const inputIsGrad = input instanceof chunk6AE5FKKQ_cjs.GradTensor;
  const targetIsGrad = target instanceof chunk6AE5FKKQ_cjs.GradTensor;
  let logits = inputIsGrad ? input : chunk6AE5FKKQ_cjs.GradTensor.fromTensor(input);
  let labels = targetIsGrad ? target : chunk6AE5FKKQ_cjs.GradTensor.fromTensor(target, { requiresGrad: false });
  // Shape validation: each operand must be (N,) or (N, 1); 1-D inputs are
  // promoted to a column vector before the batch sizes are compared.
  if (logits.ndim !== 1 && logits.ndim !== 2) {
    throw new chunkJSCDE774_cjs.ShapeError("Input must be 1 or 2-dimensional");
  }
  if (labels.ndim !== 1 && labels.ndim !== 2) {
    throw new chunkJSCDE774_cjs.ShapeError("Target must be 1 or 2-dimensional");
  }
  if (logits.ndim === 1) {
    logits = logits.reshape([logits.shape[0] ?? 0, 1]);
  }
  if (labels.ndim === 1) {
    labels = labels.reshape([labels.shape[0] ?? 0, 1]);
  }
  if (logits.ndim !== 2 || logits.shape[1] !== 1) {
    throw new chunkJSCDE774_cjs.ShapeError(`Input must have shape (N,) or (N, 1)`);
  }
  if (labels.ndim !== 2 || labels.shape[1] !== 1) {
    throw new chunkJSCDE774_cjs.ShapeError(`Target must be 1-dimensional or have shape (N, 1)`);
  }
  if ((logits.shape[0] ?? 0) !== (labels.shape[0] ?? 0)) {
    throw new chunkJSCDE774_cjs.ShapeError(`Batch size mismatch`);
  }
  const logitsDtype = logits.dtype;
  if (logitsDtype === "string") {
    throw new chunkJSCDE774_cjs.DTypeError("Binary cross entropy does not support string dtype");
  }
  // Stable formulation: relu(x) - x*y + log(1 + exp(-|x|)).
  const positivePart = logits.relu();
  const crossTerm = logits.mul(labels);
  const negLogits = logits.neg();
  // |x| built from relu(x) + relu(-x) so the autograd graph stays in
  // terms of existing differentiable primitives.
  const magnitude = logits.relu().add(negLogits.relu());
  const decay = magnitude.neg().exp();
  const decayDtype = decay.dtype;
  if (decayDtype === "string") {
    throw new chunkJSCDE774_cjs.DTypeError("binaryCrossEntropyWithLogitsLoss does not support string dtype");
  }
  const unit = chunk6AE5FKKQ_cjs.GradTensor.scalar(1, { dtype: decayDtype });
  const softplusTerm = unit.add(decay).log();
  const loss = positivePart.sub(crossTerm).add(softplusTerm).mean();
  // Unwrap the scalar to a plain number only when no caller asked for
  // gradient tracking.
  if (!inputIsGrad && !targetIsGrad) {
    const data = loss.tensor.data;
    if (Array.isArray(data)) {
      throw new chunkJSCDE774_cjs.DTypeError("binaryCrossEntropyWithLogitsLoss does not support string dtype");
    }
    if (data instanceof BigInt64Array) {
      const raw = chunkJSCDE774_cjs.getBigIntElement(data, loss.tensor.offset);
      return Number(raw);
    }
    return chunkJSCDE774_cjs.getNumericElement(data, loss.tensor.offset);
  }
  return loss;
}
|
|
2790
|
+
|
|
2791
|
+
// src/nn/losses/index.ts
|
|
2792
|
+
/**
 * Structural equality for shape arrays. Dimensions are compared with a
 * nullish fallback of 0 so a hole/undefined entry matches an explicit 0.
 */
function shapesEqual2(a, b) {
  return a.length === b.length && a.every((dim, i) => (dim ?? 0) === (b[i] ?? 0));
}
|
|
2799
|
+
/**
 * Guard: throw a ShapeError naming `context` unless the two tensors have
 * identical shapes.
 */
function ensureSameShape(a, b, context) {
  if (shapesEqual2(a.shape, b.shape)) {
    return;
  }
  throw new chunkJSCDE774_cjs.ShapeError(`Shape mismatch in ${context}: [${a.shape}] vs [${b.shape}]`);
}
|
|
2804
|
+
/**
 * Guard: reject string-dtype tensors, which have no numeric loss semantics.
 */
function ensureNumeric(t, context) {
  if (t.dtype !== "string") {
    return;
  }
  throw new chunkJSCDE774_cjs.DTypeError(`${context} does not support string dtype`);
}
|
|
2809
|
+
/**
 * Guard: `reduction` must be one of 'mean' | 'sum' | 'none'; anything else
 * raises InvalidParameterError tagged with the calling context.
 */
function validateReduction(reduction, context) {
  switch (reduction) {
    case "mean":
    case "sum":
    case "none":
      return;
    default:
      throw new chunkJSCDE774_cjs.InvalidParameterError(
        `${context} reduction must be 'mean', 'sum', or 'none'`,
        "reduction",
        reduction
      );
  }
}
|
|
2818
|
+
/**
 * Read element `flat` (logical/row-major index) from a possibly strided
 * buffer as a JS number: first translate the flat index to a physical data
 * offset, then decode the element numerically.
 */
function readNumericFlat(data, flat, logicalStrides, strides, offset) {
  return chunkJSCDE774_cjs.getElementAsNumber(
    data,
    chunk6AE5FKKQ_cjs.offsetFromFlatIndex(flat, logicalStrides, strides, offset)
  );
}
|
|
2822
|
+
/**
 * Mean squared error: (predictions - targets)^2 with the requested
 * reduction ('mean' default, or 'sum' / 'none' for per-element output).
 * Throws on shape mismatch, string dtype, or an unknown reduction.
 */
function mseLoss(predictions, targets, reduction = "mean") {
  validateReduction(reduction, "mseLoss");
  ensureNumeric(predictions, "mseLoss");
  ensureNumeric(targets, "mseLoss");
  ensureSameShape(predictions, targets, "mseLoss");
  const residual = chunk6AE5FKKQ_cjs.sub(predictions, targets);
  // Square via pow with an exponent tensor matching dtype/device.
  const exponent = chunk6AE5FKKQ_cjs.tensor(2, { dtype: residual.dtype, device: residual.device });
  const squared = chunk6AE5FKKQ_cjs.pow(residual, exponent);
  switch (reduction) {
    case "none":
      return squared;
    case "sum":
      return chunk6AE5FKKQ_cjs.sum(squared);
    default:
      return chunk6AE5FKKQ_cjs.mean(squared);
  }
}
|
|
2837
|
+
/**
 * Mean absolute error: |predictions - targets| with the requested
 * reduction ('mean' default, or 'sum' / 'none' for per-element output).
 * Throws on shape mismatch, string dtype, or an unknown reduction.
 */
function maeLoss(predictions, targets, reduction = "mean") {
  validateReduction(reduction, "maeLoss");
  ensureNumeric(predictions, "maeLoss");
  ensureNumeric(targets, "maeLoss");
  ensureSameShape(predictions, targets, "maeLoss");
  const residual = chunk6AE5FKKQ_cjs.sub(predictions, targets);
  const magnitude = chunk6AE5FKKQ_cjs.abs(residual);
  switch (reduction) {
    case "none":
      return magnitude;
    case "sum":
      return chunk6AE5FKKQ_cjs.sum(magnitude);
    default:
      return chunk6AE5FKKQ_cjs.mean(magnitude);
  }
}
|
|
2852
|
+
/**
 * Binary cross-entropy on probabilities:
 *   -(y*log(p) + (1-y)*log(1-p))
 * Predictions are clamped to [eps, 1-eps] (eps = 1e-7) so the logarithms
 * stay finite at p = 0 or p = 1. Supports 'mean' (default), 'sum', and
 * 'none' reductions; throws on shape mismatch or string dtype.
 */
function binaryCrossEntropyLoss(predictions, targets, reduction = "mean") {
  validateReduction(reduction, "binaryCrossEntropyLoss");
  ensureNumeric(predictions, "binaryCrossEntropyLoss");
  ensureNumeric(targets, "binaryCrossEntropyLoss");
  ensureSameShape(predictions, targets, "binaryCrossEntropyLoss");
  const epsilon = 1e-7;
  const clamped = chunk6AE5FKKQ_cjs.clip(predictions, epsilon, 1 - epsilon);
  const logP = chunk6AE5FKKQ_cjs.log(clamped);
  const posTerm = chunk6AE5FKKQ_cjs.mul(targets, logP);
  // Scalar 1 promoted to float so integer-dtype inputs still subtract cleanly.
  const one = chunk6AE5FKKQ_cjs.tensor(1, {
    dtype: predictions.dtype === "float64" ? "float64" : "float32",
    device: predictions.device
  });
  const invTargets = chunk6AE5FKKQ_cjs.sub(one, targets);
  const invClamped = chunk6AE5FKKQ_cjs.sub(one, clamped);
  const logInvP = chunk6AE5FKKQ_cjs.log(invClamped);
  const negTerm = chunk6AE5FKKQ_cjs.mul(invTargets, logInvP);
  const loss = chunk6AE5FKKQ_cjs.neg(chunk6AE5FKKQ_cjs.add(posTerm, negTerm));
  switch (reduction) {
    case "none":
      return loss;
    case "sum":
      return chunk6AE5FKKQ_cjs.sum(loss);
    default:
      return chunk6AE5FKKQ_cjs.mean(loss);
  }
}
|
|
2878
|
+
/**
 * Root mean squared error: sqrt(mseLoss(predictions, targets)).
 * Always mean-reduced; throws on shape mismatch or string dtype.
 */
function rmseLoss(predictions, targets) {
  ensureNumeric(predictions, "rmseLoss");
  ensureNumeric(targets, "rmseLoss");
  ensureSameShape(predictions, targets, "rmseLoss");
  return chunk6AE5FKKQ_cjs.sqrt(mseLoss(predictions, targets, "mean"));
}
|
|
2885
|
+
/**
 * Huber loss with threshold `delta` (> 0, finite; default 1):
 *   0.5 * r^2                 for |r| <= delta
 *   delta * (|r| - delta/2)   otherwise
 * Computed element-wise into a fresh typed array (float64 if predictions
 * are float64, else float32), honoring the source tensor's strides/offset,
 * then reduced per `reduction` ('mean' default, 'sum', or 'none').
 * Throws on bad reduction, string dtype, shape mismatch, or invalid delta.
 */
function huberLoss(predictions, targets, delta = 1, reduction = "mean") {
  // Validation order is part of the observable contract (which error fires
  // first), so it mirrors the other loss functions exactly.
  validateReduction(reduction, "huberLoss");
  ensureNumeric(predictions, "huberLoss");
  ensureNumeric(targets, "huberLoss");
  ensureSameShape(predictions, targets, "huberLoss");
  if (!Number.isFinite(delta) || delta <= 0) {
    throw new chunkJSCDE774_cjs.InvalidParameterError(`delta must be positive; got ${delta}`, "delta", delta);
  }
  const residual = chunk6AE5FKKQ_cjs.sub(predictions, targets);
  const absResidual = chunk6AE5FKKQ_cjs.abs(residual);
  const rawData = absResidual.data;
  if (Array.isArray(rawData)) {
    throw new chunkJSCDE774_cjs.DTypeError("huberLoss does not support string dtype");
  }
  const dtype = predictions.dtype === "float64" ? "float64" : "float32";
  const out = dtype === "float64" ? new Float64Array(residual.size) : new Float32Array(residual.size);
  const logicalStrides = chunk6AE5FKKQ_cjs.computeStrides(absResidual.shape);
  for (let flat = 0; flat < residual.size; flat++) {
    const a = readNumericFlat(rawData, flat, logicalStrides, absResidual.strides, absResidual.offset);
    // Quadratic inside the delta band, linear outside.
    out[flat] = a <= delta ? 0.5 * a * a : delta * (a - 0.5 * delta);
  }
  const loss = chunk6AE5FKKQ_cjs.Tensor.fromTypedArray({
    data: out,
    shape: predictions.shape,
    dtype,
    device: predictions.device
  });
  switch (reduction) {
    case "none":
      return loss;
    case "sum":
      return chunk6AE5FKKQ_cjs.sum(loss);
    default:
      return chunk6AE5FKKQ_cjs.mean(loss);
  }
}
|
|
2924
|
+
|
|
2925
|
+
// Public CommonJS surface of this chunk. The one-assignment-per-line
// `exports.X = X;` form is kept deliberately: Node's cjs-module-lexer
// relies on this exact pattern to expose named exports to ESM importers.
exports.AvgPool2d = AvgPool2d;
exports.BatchNorm1d = BatchNorm1d;
exports.Conv1d = Conv1d;
exports.Conv2d = Conv2d;
exports.Dropout = Dropout;
exports.ELU = ELU;
exports.GELU = GELU;
exports.GRU = GRU;
exports.LSTM = LSTM;
exports.LayerNorm = LayerNorm;
exports.LeakyReLU = LeakyReLU;
exports.Linear = Linear;
exports.LogSoftmax = LogSoftmax;
exports.MaxPool2d = MaxPool2d;
exports.Mish = Mish;
exports.Module = Module;
exports.MultiheadAttention = MultiheadAttention;
exports.RNN = RNN;
exports.ReLU = ReLU;
exports.Sequential = Sequential;
exports.Sigmoid = Sigmoid;
exports.Softmax = Softmax;
exports.Softplus = Softplus;
exports.Swish = Swish;
exports.Tanh = Tanh;
exports.TransformerEncoderLayer = TransformerEncoderLayer;
exports.binaryCrossEntropyLoss = binaryCrossEntropyLoss;
exports.binaryCrossEntropyWithLogitsLoss = binaryCrossEntropyWithLogitsLoss;
exports.crossEntropyLoss = crossEntropyLoss;
exports.huberLoss = huberLoss;
exports.maeLoss = maeLoss;
exports.mseLoss = mseLoss;
exports.nn_exports = nn_exports;
exports.rmseLoss = rmseLoss;
//# sourceMappingURL=chunk-QERHVCHC.cjs.map
|