catniff 0.2.14 → 0.2.15
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/core.js +79 -4
- package/package.json +1 -1
package/dist/core.js
CHANGED
|
@@ -1144,19 +1144,94 @@ class Tensor {
|
|
|
1144
1144
|
// General matrix multiplication with different shapes
|
|
1145
1145
|
matmul(other) {
|
|
1146
1146
|
other = Tensor.forceTensor(other);
|
|
1147
|
-
|
|
1147
|
+
const isThis1D = this.shape.length === 1;
|
|
1148
|
+
const isOther1D = other.shape.length === 1;
|
|
1149
|
+
if (isThis1D && isOther1D) {
|
|
1148
1150
|
return this.dot(other);
|
|
1149
1151
|
}
|
|
1150
|
-
else if (
|
|
1152
|
+
else if (isThis1D && other.shape.length === 2) {
|
|
1151
1153
|
return this.unsqueeze(0).mm(other).squeeze(0);
|
|
1152
1154
|
}
|
|
1153
|
-
else if (this.shape.length === 2 &&
|
|
1155
|
+
else if (this.shape.length === 2 && isOther1D) {
|
|
1154
1156
|
return this.mv(other);
|
|
1155
1157
|
}
|
|
1156
1158
|
else if (this.shape.length === 2 && other.shape.length === 2) {
|
|
1157
1159
|
return this.mm(other);
|
|
1158
1160
|
}
|
|
1159
|
-
|
|
1161
|
+
else if ((isThis1D && other.shape.length > 2) ||
|
|
1162
|
+
(isOther1D && this.shape.length > 2) ||
|
|
1163
|
+
(other.shape.length > 2 && this.shape.length > 2)) {
|
|
1164
|
+
// Append/prepend dims if needed
|
|
1165
|
+
const self = isThis1D ? this.unsqueeze(0) : this;
|
|
1166
|
+
other = isOther1D ? other.unsqueeze(1) : other;
|
|
1167
|
+
// Padding
|
|
1168
|
+
const [selfStrides, otherStrides, selfShape, otherShape] = Tensor.padShape(self.strides, other.strides, self.shape, other.shape);
|
|
1169
|
+
const lastDim = selfShape.length - 1;
|
|
1170
|
+
// Prepare data for broadcasting
|
|
1171
|
+
const batchA = self.value;
|
|
1172
|
+
const batchB = other.value;
|
|
1173
|
+
const batchARows = selfShape[lastDim - 1];
|
|
1174
|
+
const batchACols = selfShape[lastDim];
|
|
1175
|
+
const batchBRows = otherShape[lastDim - 1];
|
|
1176
|
+
const batchBCols = otherShape[lastDim];
|
|
1177
|
+
// Verify if can do matmul
|
|
1178
|
+
if (batchACols !== batchBRows)
|
|
1179
|
+
throw new Error("Invalid matrices shape for multiplication");
|
|
1180
|
+
// Prepare shape, strides, size info, but more importantly the offset-related data to loop through the outer, non-matrix dims
|
|
1181
|
+
// Self and other's offset data
|
|
1182
|
+
const selfOffsetShape = selfShape.slice(0, -2);
|
|
1183
|
+
const otherOffsetShape = otherShape.slice(0, -2);
|
|
1184
|
+
const selfOffsetStrides = selfStrides.slice(0, -2);
|
|
1185
|
+
const otherOffsetStrides = otherStrides.slice(0, -2);
|
|
1186
|
+
// The output's offset data
|
|
1187
|
+
const offsetShape = Tensor.broadcastShapes(selfOffsetShape, otherOffsetShape);
|
|
1188
|
+
const offsetSize = Tensor.shapeToSize(offsetShape);
|
|
1189
|
+
const offsetStrides = Tensor.getStrides(offsetShape);
|
|
1190
|
+
// Output shape, strides, size, value
|
|
1191
|
+
const outputShape = [...offsetShape, batchARows, batchBCols];
|
|
1192
|
+
const outputStrides = Tensor.getStrides(outputShape);
|
|
1193
|
+
const outputSize = Tensor.shapeToSize(outputShape);
|
|
1194
|
+
const outputValue = new Array(outputSize).fill(0);
|
|
1195
|
+
// Loop through outer dims and do matmul on two outer-most dims
|
|
1196
|
+
for (let index = 0; index < offsetSize; index++) {
|
|
1197
|
+
const coords = Tensor.indexToCoords(index, offsetStrides);
|
|
1198
|
+
const offset = Tensor.coordsToIndex(coords, outputStrides.slice(0, -2));
|
|
1199
|
+
const selfOffset = Tensor.coordsToUnbroadcastedIndex(coords, selfOffsetShape, selfOffsetStrides);
|
|
1200
|
+
const otherOffset = Tensor.coordsToUnbroadcastedIndex(coords, otherOffsetShape, otherOffsetStrides);
|
|
1201
|
+
for (let i = 0; i < batchARows; i++) {
|
|
1202
|
+
for (let j = 0; j < batchBCols; j++) {
|
|
1203
|
+
for (let k = 0; k < batchACols; k++) {
|
|
1204
|
+
const outputIdx = offset + i * outputStrides[lastDim - 1] + j * outputStrides[lastDim];
|
|
1205
|
+
const selfIdx = selfOffset + i * selfStrides[lastDim - 1] + k * selfStrides[lastDim];
|
|
1206
|
+
const otherIdx = otherOffset + k * otherStrides[lastDim - 1] + j * otherStrides[lastDim];
|
|
1207
|
+
outputValue[outputIdx] += batchA[selfIdx] * batchB[otherIdx];
|
|
1208
|
+
}
|
|
1209
|
+
}
|
|
1210
|
+
}
|
|
1211
|
+
}
|
|
1212
|
+
const out = new Tensor(outputValue, { shape: outputShape, strides: outputStrides });
|
|
1213
|
+
if (this.requiresGrad) {
|
|
1214
|
+
out.requiresGrad = true;
|
|
1215
|
+
out.children.push(this);
|
|
1216
|
+
}
|
|
1217
|
+
if (other.requiresGrad) {
|
|
1218
|
+
out.requiresGrad = true;
|
|
1219
|
+
out.children.push(other);
|
|
1220
|
+
}
|
|
1221
|
+
if (out.requiresGrad) {
|
|
1222
|
+
out.gradFn = () => {
|
|
1223
|
+
other = other;
|
|
1224
|
+
const outGrad = out.grad.withGrad(false);
|
|
1225
|
+
const selfNoGrad = self.withGrad(false);
|
|
1226
|
+
const otherNoGrad = other.withGrad(false);
|
|
1227
|
+
if (this.requiresGrad)
|
|
1228
|
+
Tensor.addGrad(this, outGrad.matmul(otherNoGrad.transpose(lastDim - 1, lastDim)));
|
|
1229
|
+
if (other.requiresGrad)
|
|
1230
|
+
Tensor.addGrad(other, selfNoGrad.transpose(lastDim - 1, lastDim).matmul(outGrad));
|
|
1231
|
+
};
|
|
1232
|
+
}
|
|
1233
|
+
return out;
|
|
1234
|
+
}
|
|
1160
1235
|
throw new Error(`Shapes [${this.shape}] and [${other.shape}] are not supported`);
|
|
1161
1236
|
}
|
|
1162
1237
|
// Utility to create a new tensor filled with a number
|