catniff 0.5.4 → 0.5.6
This diff covers the content of publicly released versions of the package as published to its public registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the registry.
- package/README.md +4 -3
- package/dist/core.d.ts +5 -0
- package/dist/core.js +148 -30
- package/package.json +1 -1
package/README.md (CHANGED)

@@ -120,10 +120,11 @@ All available APIs are in [`./src/`](./src/) if you want to dig deeper.
 
 ## Todos
 
-*
-* More tensor ops.
-* GPU acceleration.
+* More general tensor ops.
 * More general neural net APIs.
+* GPU acceleration.
+* Comprehensive caching.
+* Bug fixes.
 * More detailed documentation.
 * Code refactoring.
 * Proper tests.
package/dist/core.d.ts (CHANGED)

@@ -40,6 +40,9 @@ export declare class Tensor {
     elementWiseSelfDAG(op: (a: number) => number, thisGrad?: (self: Tensor, outGrad: Tensor) => Tensor): Tensor;
     static forceTensor(value: TensorValue | Tensor): Tensor;
     static addGrad(tensor: Tensor, accumGrad: Tensor): void;
+    isContiguous(): boolean;
+    contiguous(): Tensor;
+    reshape(newShape: readonly number[]): Tensor;
     squeeze(dims?: number[] | number): Tensor;
     unsqueeze(dim: number): Tensor;
     sum(dims?: number[] | number, keepDims?: boolean): Tensor;
@@ -125,6 +128,7 @@ export declare class Tensor {
     softsign(): Tensor;
     silu(): Tensor;
     mish(): Tensor;
+    gelu(approximate?: string): Tensor;
     maximum(other: TensorValue | Tensor): Tensor;
     minimum(other: TensorValue | Tensor): Tensor;
     round(): Tensor;
@@ -142,6 +146,7 @@ export declare class Tensor {
     swapaxes: (dim1: number, dim2: number) => Tensor;
     swapdims: (dim1: number, dim2: number) => Tensor;
     t(): Tensor;
+    permute(dims: number[]): Tensor;
     dot(other: TensorValue | Tensor): Tensor;
     mm(other: TensorValue | Tensor): Tensor;
     bmm(other: TensorValue | Tensor): Tensor;
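Taken together, these declaration changes add a contiguity/reshape group (isContiguous, contiguous, reshape), a gelu activation with an optional approximation mode, and permute. Below is a rough usage sketch, not taken from the package docs: the import name, array-based construction, and the view behaviour of t() are assumptions based on the compiled code further down.

```ts
// Hedged usage sketch; the "catniff" import and nested-array construction
// are assumptions, not shown in this diff.
import { Tensor } from "catniff";

const x = new Tensor([[1, 2, 3], [4, 5, 6]]); // assumed 2×3 construction

const flat = x.reshape([6]);       // same 6 elements, new shape
const viewT = x.t();               // transpose; typically a strided, non-contiguous view
console.log(viewT.isContiguous()); // expected: false for such a view
const packed = viewT.contiguous(); // repacks the view into row-major order

const g = x.gelu();                // exact, erf-based GELU
const gTanh = x.gelu("tanh");      // tanh approximation

const p = x.unsqueeze(0).permute([2, 0, 1]); // 1×2×3 → 3×1×2
```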
package/dist/core.js (CHANGED)

@@ -201,9 +201,9 @@ class Tensor {
         if (out.requiresGrad) {
             out.gradFn = () => {
                 // Disable gradient collecting of gradients themselves
-                const outGrad = out.grad
-                const selfNoGrad = this.
-                const otherNoGrad = other.
+                const outGrad = out.grad;
+                const selfNoGrad = this.detach();
+                const otherNoGrad = other.detach();
                 if (this.requiresGrad)
                     Tensor.addGrad(this, thisGrad(selfNoGrad, otherNoGrad, outGrad));
                 if (other.requiresGrad)
@@ -222,8 +222,8 @@
         if (out.requiresGrad) {
             out.gradFn = () => {
                 // Disable gradient collecting of gradients themselves
-                const outGrad = out.grad
-                const selfNoGrad = this.
+                const outGrad = out.grad;
+                const selfNoGrad = this.detach();
                 if (this.requiresGrad)
                     Tensor.addGrad(this, thisGrad(selfNoGrad, outGrad));
             };
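The rewritten closures above now terminate their statements with semicolons and detach() the operands before computing gradients, so the gradient math itself is not recorded in the autograd graph (the surrounding comment's "disable gradient collecting of gradients themselves"). The toy snippet below, plain TypeScript rather than catniff code, illustrates the idea behind that detach() call: a detached node carries the same value but registers no children when used in further ops.

```ts
// Minimal toy autograd node; not catniff's implementation.
class ToyNode {
  children: ToyNode[] = [];
  grad = 0;
  constructor(public value: number, public requiresGrad = false) {}
  detach(): ToyNode {
    return new ToyNode(this.value, false); // same value, no tracking
  }
  mul(other: ToyNode): ToyNode {
    const out = new ToyNode(this.value * other.value, this.requiresGrad || other.requiresGrad);
    if (out.requiresGrad) out.children.push(this, other); // graph grows here
    return out;
  }
}

const x = new ToyNode(3, true);
const tracked = x.mul(x);                          // tracked: records children
const gradTerm = x.detach().mul(new ToyNode(2));   // untracked: no children recorded
console.log(tracked.children.length, gradTerm.children.length); // 2 0
```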
@@ -261,6 +261,64 @@
             tensor.grad = tensor.grad.add(squeezedGrad);
         }
     }
+    // Contiguity-related ops
+    isContiguous() {
+        const expectedStrides = Tensor.getStrides(this.shape);
+        if (expectedStrides.length !== this.strides.length) {
+            return false;
+        }
+        for (let i = 0; i < this.strides.length; i++) {
+            if (this.strides[i] !== expectedStrides[i]) {
+                return false;
+            }
+        }
+        return true;
+    }
+    contiguous() {
+        // Check if scalar
+        if (typeof this.value === "number")
+            return this;
+        // Check if already contiguous
+        if (this.isContiguous())
+            return this;
+        const outputStrides = Tensor.getStrides(this.shape);
+        const outputSize = Tensor.shapeToSize(this.shape);
+        const outputValue = new Array(outputSize);
+        for (let index = 0; index < outputSize; index++) {
+            const outputCoords = Tensor.indexToCoords(index, outputStrides);
+            const originalIndex = Tensor.coordsToIndex(outputCoords, this.strides);
+            outputValue[index] = this.value[originalIndex];
+        }
+        const out = new Tensor(outputValue, { shape: this.shape, strides: outputStrides });
+        // Gradient flow back to the original tensor
+        if (this.requiresGrad) {
+            out.requiresGrad = true;
+            out.children.push(this);
+            out.gradFn = () => {
+                Tensor.addGrad(this, out.grad);
+            };
+        }
+        return out;
+    }
+    reshape(newShape) {
+        // Verify shape size
+        const originalSize = Tensor.shapeToSize(this.shape);
+        const outputSize = Tensor.shapeToSize(newShape);
+        if (originalSize !== outputSize) {
+            throw new Error("Cannot reshape: incompatible sizes");
+        }
+        const outputStrides = Tensor.getStrides(newShape);
+        const out = new Tensor(this.contiguous().value, { shape: newShape, strides: outputStrides });
+        // Gradient reshaped and flow back to the original tensor
+        if (this.requiresGrad) {
+            out.requiresGrad = true;
+            out.children.push(this);
+            out.gradFn = () => {
+                Tensor.addGrad(this, out.grad.reshape(this.shape));
+            };
+        }
+        return out;
+    }
     // Tensor squeeze
     squeeze(dims) {
         if (typeof this.value === "number")
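The new contiguous() walks every output index, converts it to coordinates using the row-major strides of the target layout, and reads the source element through the tensor's own (possibly permuted) strides; reshape() then operates on that packed buffer. The standalone sketch below reimplements that stride arithmetic for a transposed 2×3 view, using hypothetical helper names rather than catniff's static methods.

```ts
// Standalone sketch of the stride math contiguous() relies on (not catniff code).
function getStrides(shape: number[]): number[] {
  const strides = new Array(shape.length).fill(1);
  for (let i = shape.length - 2; i >= 0; i--) strides[i] = strides[i + 1] * shape[i + 1];
  return strides;
}
function indexToCoords(index: number, strides: number[]): number[] {
  return strides.map((s) => { const c = Math.floor(index / s); index %= s; return c; });
}
function coordsToIndex(coords: number[], strides: number[]): number {
  return coords.reduce((acc, c, i) => acc + c * strides[i], 0);
}

// A row-major 2×3 buffer viewed as its 3×2 transpose via swapped strides:
const data = [1, 2, 3, 4, 5, 6];
const viewShape = [3, 2];
const viewStrides = [1, 3]; // transposed view, non-contiguous
const packed = new Array(6);
for (let i = 0; i < 6; i++) {
  const coords = indexToCoords(i, getStrides(viewShape));
  packed[i] = data[coordsToIndex(coords, viewStrides)];
}
console.log(packed); // [1, 4, 2, 5, 3, 6], the transpose, now contiguous
```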
@@ -302,7 +360,7 @@
             out.requiresGrad = true;
             out.children.push(this);
             out.gradFn = () => {
-                let restoredGrad = out.grad
+                let restoredGrad = out.grad;
                 for (let i = dims.length - 1; i >= 0; i--) {
                     restoredGrad = restoredGrad.unsqueeze(dims[i]);
                 }
@@ -338,7 +396,7 @@
             out.requiresGrad = true;
             out.children.push(this);
             out.gradFn = () => {
-                Tensor.addGrad(this, out.grad.
+                Tensor.addGrad(this, out.grad.squeeze(dim));
             };
         }
         return out;
@@ -397,7 +455,7 @@
             out.children.push(this);
             out.gradFn = () => {
                 const localGrad = new Tensor(gradValue, { shape: gradShape, strides: gradStrides });
-                Tensor.addGrad(this, out.grad.
+                Tensor.addGrad(this, out.grad.mul(localGrad));
             };
         }
         return keepDims ? out : out.squeeze(dims);
@@ -454,7 +512,7 @@
                     gradValue[realFlatIndex] = outputValue[outFlatIndex] / this.value[realFlatIndex];
                 }
                 const localGrad = new Tensor(gradValue, { shape: gradShape, strides: gradStrides });
-                Tensor.addGrad(this, out.grad.
+                Tensor.addGrad(this, out.grad.mul(localGrad));
             };
         }
         return keepDims ? out : out.squeeze(dims);
@@ -518,7 +576,7 @@
                     gradValue[realFlatIndex] = 1 / outputFeeders[outFlatIndex];
                 }
                 const localGrad = new Tensor(gradValue, { shape: gradShape, strides: gradStrides });
-                Tensor.addGrad(this, out.grad.
+                Tensor.addGrad(this, out.grad.mul(localGrad));
             };
         }
         return keepDims ? out : out.squeeze(dims);
@@ -588,7 +646,7 @@
                     gradValue[realFlatIndex] = outputValue[outFlatIndex] === originalValue[realFlatIndex] ? 1 / shareCounts[outFlatIndex] : 0;
                 }
                 const localGrad = new Tensor(gradValue, { shape: gradShape, strides: gradStrides });
-                Tensor.addGrad(this, out.grad.
+                Tensor.addGrad(this, out.grad.mul(localGrad));
             };
         }
         return keepDims ? out : out.squeeze(dims);
@@ -658,7 +716,7 @@
                     gradValue[realFlatIndex] = outputValue[outFlatIndex] === originalValue[realFlatIndex] ? 1 / shareCounts[outFlatIndex] : 0;
                 }
                 const localGrad = new Tensor(gradValue, { shape: gradShape, strides: gradStrides });
-                Tensor.addGrad(this, out.grad.
+                Tensor.addGrad(this, out.grad.mul(localGrad));
             };
         }
         return keepDims ? out : out.squeeze(dims);
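The reduction hunks above all restore the same truncated line: the upstream gradient out.grad is multiplied element-wise by a precomputed local gradient (output divided by each input for the product-style reduction, 1 over the feeder count for the averaging one, and a 1/shareCount split between elements that attain an extremum). A tiny numeric sketch of that chain rule for the product case, in plain TypeScript rather than catniff code:

```ts
// Chain rule these reductions apply: upstream grad times a precomputed local
// gradient. Shown for a product over a flat array, where d(prod)/dx_i = prod / x_i.
const x = [2, 3, 4];
const prod = x.reduce((a, b) => a * b, 1);   // 24
const upstream = 1;                          // dL/d(prod)
const localGrad = x.map((xi) => prod / xi);  // [12, 8, 6]
const dLdx = localGrad.map((g) => g * upstream);
console.log(dLdx); // [12, 8, 6]
```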
@@ -727,8 +785,8 @@
             out.requiresGrad = true;
             out.children.push(this);
             out.gradFn = () => {
-                const upstreamGrad = out.grad
-                const softmaxOutput = out.
+                const upstreamGrad = out.grad;
+                const softmaxOutput = out.detach();
                 // Compute element-wise product: ∂L/∂σᵢ × σᵢ
                 const gradTimesOutput = upstreamGrad.mul(softmaxOutput);
                 // Sum over softmax dimensions: Σᵢ(∂L/∂σᵢ × σᵢ)
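These restored lines feed the standard softmax Jacobian-vector product that the surrounding comments describe: with upstream gradient g and softmax output σ, the input gradient is σᵢ · (gᵢ − Σⱼ gⱼσⱼ). A standalone numeric check of that rule (not catniff code):

```ts
// Softmax backward rule: dL/dx_i = sigma_i * (g_i - sum_j g_j * sigma_j)
const logits = [1, 2, 3];
const exps = logits.map(Math.exp);
const Z = exps.reduce((a, b) => a + b, 0);
const sigma = exps.map((e) => e / Z);

const g = [0.1, -0.2, 0.3]; // upstream gradient
const dot = g.reduce((acc, gi, i) => acc + gi * sigma[i], 0); // Σ g⊙σ
const dx = sigma.map((si, i) => si * (g[i] - dot));
console.log(dx);
```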
@@ -1000,6 +1058,34 @@
             return outGrad.mul(derivative);
         });
     }
+    // Tensor element-wise gelu
+    gelu(approximate = "none") {
+        if (approximate === "none") {
+            return this.elementWiseSelfDAG((a) => 0.5 * a * (1 + (0, utils_1.erf)(a / Math.sqrt(2))), (self, outGrad) => {
+                const sqrt2 = Math.sqrt(2);
+                const sqrt2OverPi = Math.sqrt(2 / Math.PI);
+                const xOverSqrt2 = self.div(sqrt2);
+                const erfVal = xOverSqrt2.erf();
+                const phi = xOverSqrt2.square().neg().exp().div(sqrt2OverPi);
+                const derivative = erfVal.add(1).mul(0.5).add(self.mul(phi));
+                return outGrad.mul(derivative);
+            });
+        }
+        else if (approximate === "tanh") {
+            return this.elementWiseSelfDAG((a) => 0.5 * a * (1 + Math.tanh(Math.sqrt(2 / Math.PI) * (a + 0.044715 * a * a * a))), (self, outGrad) => {
+                const sqrt2OverPi = Math.sqrt(2 / Math.PI);
+                const c = 0.044715;
+                const tanhArg = self.add(self.pow(3).mul(c)).mul(sqrt2OverPi);
+                const tanhVal = tanhArg.tanh();
+                const sechSquared = tanhVal.square().neg().add(1);
+                const term1 = tanhVal.add(1).mul(0.5);
+                const term2 = self.mul(sechSquared).mul(sqrt2OverPi).mul(self.square().mul(c * 3).add(1)).mul(0.5);
+                const derivative = term1.add(term2);
+                return outGrad.mul(derivative);
+            });
+        }
+        throw new Error("Specified approximation does not exist");
+    }
     // Tensor element-wise maximum
     maximum(other) {
         return this.elementWiseABDAG(other, (a, b) => Math.max(a, b), (self, other, outGrad) => outGrad.mul(self.gt(other).add(self.eq(other).mul(0.5))), (self, other, outGrad) => outGrad.mul(other.gt(self).add(other.eq(self).mul(0.5))));
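The new gelu() dispatches between the exact erf form, 0.5·x·(1 + erf(x/√2)), and the tanh approximation, 0.5·x·(1 + tanh(√(2/π)·(x + 0.044715·x³))). The sketch below compares the two forward formulas using a generic numeric erf approximation; it does not use catniff's own utils erf, and the error bound is a property of the approximation, not of the package.

```ts
// Compare the two forward GELU forms the new method selects between.
function erf(x: number): number {
  // Abramowitz & Stegun formula 7.1.26 (absolute error around 1.5e-7)
  const sign = Math.sign(x);
  const ax = Math.abs(x);
  const t = 1 / (1 + 0.3275911 * ax);
  const poly =
    ((((1.061405429 * t - 1.453152027) * t + 1.421413741) * t - 0.284496736) * t + 0.254829592) * t;
  return sign * (1 - poly * Math.exp(-ax * ax));
}

const geluExact = (x: number) => 0.5 * x * (1 + erf(x / Math.sqrt(2)));
const geluTanh = (x: number) =>
  0.5 * x * (1 + Math.tanh(Math.sqrt(2 / Math.PI) * (x + 0.044715 * x ** 3)));

for (const x of [-2, -0.5, 0, 0.5, 2]) {
  console.log(x, geluExact(x).toFixed(5), geluTanh(x).toFixed(5));
}
```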
@@ -1068,7 +1154,7 @@
         if (this.requiresGrad) {
             out.children.push(this);
             out.gradFn = () => {
-                Tensor.addGrad(this, out.grad.
+                Tensor.addGrad(this, out.grad.transpose(dim1, dim2));
             };
         }
         return out;
@@ -1083,6 +1169,39 @@
         }
         return this.transpose(0, 1);
     }
+    // Permute
+    permute(dims) {
+        if (dims.length !== this.shape.length) {
+            throw new Error("Permutation must specify all dimensions");
+        }
+        // Compute new shape and strides
+        const newShape = new Array(dims.length);
+        const newStrides = new Array(dims.length);
+        for (let index = 0; index < dims.length; index++) {
+            const dim = dims[index];
+            newShape[index] = this.shape[dim];
+            newStrides[index] = this.strides[dim];
+        }
+        const out = new Tensor(this.value, {
+            shape: newShape,
+            strides: newStrides
+        });
+        if (this.requiresGrad) {
+            out.requiresGrad = true;
+            out.children.push(this);
+            out.gradFn = () => {
+                // Compute inverse permutation
+                const inverseAxes = new Array(dims.length);
+                for (let i = 0; i < dims.length; i++) {
+                    inverseAxes[dims[i]] = i;
+                }
+                // Permute gradient back to original order
+                const permutedGrad = out.grad.permute(inverseAxes);
+                Tensor.addGrad(this, permutedGrad);
+            };
+        }
+        return out;
+    }
     // 1D tensor dot product
     dot(other) {
         other = Tensor.forceTensor(other);
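permute() reorders shape and strides without copying data, and its gradFn sends out.grad back through the inverse permutation (inverseAxes[dims[i]] = i). A small standalone sketch of that inverse-permutation bookkeeping, in plain TypeScript rather than through the Tensor API:

```ts
// If dims maps output axis i to input axis dims[i], then inverse[dims[i]] = i
// maps the permuted axes back to their original order.
function invertPermutation(dims: number[]): number[] {
  const inverse = new Array<number>(dims.length);
  for (let i = 0; i < dims.length; i++) inverse[dims[i]] = i;
  return inverse;
}

const dims = [2, 0, 1];                   // e.g. x.permute([2, 0, 1])
const inverse = invertPermutation(dims);  // [1, 2, 0]

// Applying dims and then inverse restores the identity ordering:
const axes = [0, 1, 2];
const permuted = dims.map((d) => axes[d]);        // [2, 0, 1]
const restored = inverse.map((d) => permuted[d]); // [0, 1, 2]
console.log(inverse, restored);
```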
@@ -1110,9 +1229,9 @@
         if (out.requiresGrad) {
             out.gradFn = () => {
                 // Disable gradient collecting of gradients themselves
-                const outGrad = out.grad
-                const selfNoGrad = this.
-                const otherNoGrad = other.
+                const outGrad = out.grad;
+                const selfNoGrad = this.detach();
+                const otherNoGrad = other.detach();
                 if (this.requiresGrad)
                     Tensor.addGrad(this, outGrad.mul(otherNoGrad));
                 if (other.requiresGrad)
@@ -1165,9 +1284,9 @@
         if (out.requiresGrad) {
             out.gradFn = () => {
                 // Disable gradient collecting of gradients themselves
-                const outGrad = out.grad
-                const selfNoGrad = this.
-                const otherNoGrad = other.
+                const outGrad = out.grad;
+                const selfNoGrad = this.detach();
+                const otherNoGrad = other.detach();
                 if (this.requiresGrad)
                     Tensor.addGrad(this, outGrad.mm(otherNoGrad.t()));
                 if (other.requiresGrad)
@@ -1223,9 +1342,9 @@
         if (out.requiresGrad) {
             out.gradFn = () => {
                 // Disable gradient collecting of gradients themselves
-                const outGrad = out.grad
-                const selfNoGrad = this.
-                const otherNoGrad = other.
+                const outGrad = out.grad;
+                const selfNoGrad = this.detach();
+                const otherNoGrad = other.detach();
                 if (this.requiresGrad)
                     Tensor.addGrad(this, outGrad.bmm(otherNoGrad.transpose(1, 2)));
                 if (other.requiresGrad)
@@ -1260,9 +1379,8 @@
         else if (this.shape.length === 2 && other.shape.length === 2) {
             return this.mm(other);
         }
-        else if ((
-            (
-            (other.shape.length > 2 && this.shape.length > 2)) {
+        else if ((this.shape.length > 0 && other.shape.length >= 2) ||
+            (this.shape.length >= 2 && other.shape.length > 0)) {
             // Append/prepend dims if needed
             const self = isThis1D ? this.unsqueeze(0) : this;
             other = isOther1D ? other.unsqueeze(1) : other;
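The rewritten condition routes any pairing where one operand is at least 2-D and the other is at least 1-D into the broadcasting path, after unsqueezing 1-D operands (a row vector on the left, a column vector on the right). The sketch below mirrors that branch structure in plain TypeScript; the 1-D·1-D dot branch is assumed from the surrounding methods rather than shown in this hunk.

```ts
// Dispatch sketch mirroring the matmul branches, by operand rank only.
function matmulRoute(aDims: number, bDims: number): string {
  if (aDims === 1 && bDims === 1) return "dot";   // assumed earlier branch, not in this hunk
  if (aDims === 2 && bDims === 2) return "mm";    // plain matrix product
  if ((aDims > 0 && bDims >= 2) || (aDims >= 2 && bDims > 0))
    return "broadcast path (1-D sides get unsqueezed)";
  return "unsupported";
}

console.log(matmulRoute(1, 3)); // vector × batched matrices → broadcast path
console.log(matmulRoute(3, 1)); // batched matrices × vector → broadcast path
console.log(matmulRoute(2, 2)); // mm
```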
@@ -1323,9 +1441,9 @@
         if (out.requiresGrad) {
             out.gradFn = () => {
                 other = other;
-                const outGrad = out.grad
-                const selfNoGrad = self.
-                const otherNoGrad = other.
+                const outGrad = out.grad;
+                const selfNoGrad = self.detach();
+                const otherNoGrad = other.detach();
                 if (this.requiresGrad)
                     Tensor.addGrad(this, outGrad.matmul(otherNoGrad.transpose(lastDim - 1, lastDim)));
                 if (other.requiresGrad)
|