catniff 0.2.14 → 0.2.15

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (2)
  1. package/dist/core.js +79 -4
  2. package/package.json +1 -1
package/dist/core.js CHANGED
@@ -1144,19 +1144,94 @@ class Tensor {
1144
1144
  // General matrix multiplication with different shapes
  // `other` is first coerced with Tensor.forceTensor, then the call is dispatched on the
  // number of dimensions of both operands:
  //   1-D x 1-D                              -> this.dot(other)
  //   1-D x 2-D                              -> this.unsqueeze(0).mm(other).squeeze(0)
  //   2-D x 1-D                              -> this.mv(other)
  //   2-D x 2-D                              -> this.mm(other)
  //   1-D x >2-D, >2-D x 1-D, >2-D x >2-D    -> batched path implemented inline below
  // Every other combination reaches the final `throw`.
  // NOTE(review): 2-D x >2-D and >2-D x 2-D satisfy none of the conditions below and
  // therefore throw "not supported" - confirm whether that is intended.
1145
1145
  matmul(other) {
1146
1146
  other = Tensor.forceTensor(other);
1147
- if (this.shape.length === 1 && other.shape.length === 1) {
1147
+ const isThis1D = this.shape.length === 1;
1148
+ const isOther1D = other.shape.length === 1;
1149
+ if (isThis1D && isOther1D) {
1148
1150
  return this.dot(other);
1149
1151
  }
1150
- else if (this.shape.length === 1 && other.shape.length === 2) {
1152
+ else if (isThis1D && other.shape.length === 2) {
1151
1153
  return this.unsqueeze(0).mm(other).squeeze(0);
1152
1154
  }
1153
- else if (this.shape.length === 2 && other.shape.length === 1) {
1155
+ else if (this.shape.length === 2 && isOther1D) {
1154
1156
  return this.mv(other);
1155
1157
  }
1156
1158
  else if (this.shape.length === 2 && other.shape.length === 2) {
1157
1159
  return this.mm(other);
1158
1160
  }
1159
- // Too lazy for batched matmul
  // Batched path: taken when at least one operand has more than 2 dims and the other
  // is either 1-D or also has more than 2 dims.
1161
+ else if ((isThis1D && other.shape.length > 2) ||
1162
+ (isOther1D && this.shape.length > 2) ||
1163
+ (other.shape.length > 2 && this.shape.length > 2)) {
1164
+ // Append/prepend dims if needed
  // A 1-D `this` gets a leading dim (unsqueeze(0)); a 1-D `other` gets a trailing dim
  // (unsqueeze(1)). `other` is reassigned here, so every later use of `other` in this
  // branch (children, gradFn) refers to the possibly-unsqueezed tensor.
  // NOTE(review): the inserted dim is never squeezed out of the result, unlike the
  // 1-D x 2-D branch above - confirm the intended output shape for 1-D operands.
1165
+ const self = isThis1D ? this.unsqueeze(0) : this;
1166
+ other = isOther1D ? other.unsqueeze(1) : other;
1167
+ // Padding
  // Tensor.padShape is defined outside this view; presumably it pads the shorter
  // shape/strides pair so both operands have the same number of dims - verify.
1168
+ const [selfStrides, otherStrides, selfShape, otherShape] = Tensor.padShape(self.strides, other.strides, self.shape, other.shape);
  // Index of the last dim in the padded shapes; lastDim - 1 and lastDim are used below
  // as the row and column dims of each matrix.
1169
+ const lastDim = selfShape.length - 1;
1170
+ // Prepare data for broadcasting
  // batchA / batchB are read below with flat indices built from offsets and strides.
1171
+ const batchA = self.value;
1172
+ const batchB = other.value;
1173
+ const batchARows = selfShape[lastDim - 1];
1174
+ const batchACols = selfShape[lastDim];
1175
+ const batchBRows = otherShape[lastDim - 1];
1176
+ const batchBCols = otherShape[lastDim];
1177
+ // Verify if can do matmul
  // Inner dims must agree: columns of the left matrix vs rows of the right matrix.
1178
+ if (batchACols !== batchBRows)
1179
+ throw new Error("Invalid matrices shape for multiplication");
1180
+ // Prepare shape, strides, size info, but more importantly the offset-related data to loop through the outer, non-matrix dims
1181
+ // Self and other's offset data
  // slice(0, -2) drops the trailing two (matrix) dims, leaving only the outer dims.
1182
+ const selfOffsetShape = selfShape.slice(0, -2);
1183
+ const otherOffsetShape = otherShape.slice(0, -2);
1184
+ const selfOffsetStrides = selfStrides.slice(0, -2);
1185
+ const otherOffsetStrides = otherStrides.slice(0, -2);
1186
+ // The output's offset data
1187
+ const offsetShape = Tensor.broadcastShapes(selfOffsetShape, otherOffsetShape);
1188
+ const offsetSize = Tensor.shapeToSize(offsetShape);
1189
+ const offsetStrides = Tensor.getStrides(offsetShape);
1190
+ // Output shape, strides, size, value
  // The output buffer is zero-filled because the inner loop accumulates into it with +=.
1191
+ const outputShape = [...offsetShape, batchARows, batchBCols];
1192
+ const outputStrides = Tensor.getStrides(outputShape);
1193
+ const outputSize = Tensor.shapeToSize(outputShape);
1194
+ const outputValue = new Array(outputSize).fill(0);
1195
+ // Loop through outer dims and do matmul on two outer-most dims
  // For each position in the broadcast outer shape: `offset` is the start of that
  // matrix in the output, selfOffset / otherOffset are the starts of the source
  // matrices. Tensor.coordsToUnbroadcastedIndex is defined outside this view;
  // presumably it maps coordinates on broadcast (size-1) dims back to index 0 - verify.
1196
+ for (let index = 0; index < offsetSize; index++) {
1197
+ const coords = Tensor.indexToCoords(index, offsetStrides);
1198
+ const offset = Tensor.coordsToIndex(coords, outputStrides.slice(0, -2));
1199
+ const selfOffset = Tensor.coordsToUnbroadcastedIndex(coords, selfOffsetShape, selfOffsetStrides);
1200
+ const otherOffset = Tensor.coordsToUnbroadcastedIndex(coords, otherOffsetShape, otherOffsetStrides);
  // Triple loop: out[i][j] += A[i][k] * B[k][j], addressed through strides.
  // NOTE(review): outputIdx does not depend on k, so it is recomputed needlessly
  // on every iteration of the innermost loop.
1201
+ for (let i = 0; i < batchARows; i++) {
1202
+ for (let j = 0; j < batchBCols; j++) {
1203
+ for (let k = 0; k < batchACols; k++) {
1204
+ const outputIdx = offset + i * outputStrides[lastDim - 1] + j * outputStrides[lastDim];
1205
+ const selfIdx = selfOffset + i * selfStrides[lastDim - 1] + k * selfStrides[lastDim];
1206
+ const otherIdx = otherOffset + k * otherStrides[lastDim - 1] + j * otherStrides[lastDim];
1207
+ outputValue[outputIdx] += batchA[selfIdx] * batchB[otherIdx];
1208
+ }
1209
+ }
1210
+ }
1211
+ }
  // Wrap the result and register the operands as children when they require grad.
  // The original `this` is registered (not the unsqueezed `self`), while `other` is
  // the possibly-unsqueezed tensor because it was reassigned above.
1212
+ const out = new Tensor(outputValue, { shape: outputShape, strides: outputStrides });
1213
+ if (this.requiresGrad) {
1214
+ out.requiresGrad = true;
1215
+ out.children.push(this);
1216
+ }
1217
+ if (other.requiresGrad) {
1218
+ out.requiresGrad = true;
1219
+ out.children.push(other);
1220
+ }
  // Backward: grad(this) += outGrad @ other^T, grad(other) += self^T @ outGrad, where
  // ^T swaps the last two dims (per the padded lastDim) and all inputs are taken with
  // withGrad(false).
  // NOTE(review): `other = other;` below is a self-assignment with no effect.
  // NOTE(review): lastDim comes from the padded shapes, but self / other here are the
  // unpadded tensors. When one operand has fewer dims than the padded length (e.g. an
  // unsqueezed 1-D operand, which is 2-D), transpose(lastDim - 1, lastDim) addresses
  // dims beyond that tensor's rank, and a 2-D operand against a >2-D outGrad is a
  // combination this matmul does not support - verify backward for these cases.
  // NOTE(review): the products keep any broadcast batch dims, so their shape can
  // differ from this.shape / other.shape; presumably Tensor.addGrad reduces them - verify.
1221
+ if (out.requiresGrad) {
1222
+ out.gradFn = () => {
1223
+ other = other;
1224
+ const outGrad = out.grad.withGrad(false);
1225
+ const selfNoGrad = self.withGrad(false);
1226
+ const otherNoGrad = other.withGrad(false);
1227
+ if (this.requiresGrad)
1228
+ Tensor.addGrad(this, outGrad.matmul(otherNoGrad.transpose(lastDim - 1, lastDim)));
1229
+ if (other.requiresGrad)
1230
+ Tensor.addGrad(other, selfNoGrad.transpose(lastDim - 1, lastDim).matmul(outGrad));
1231
+ };
1232
+ }
1233
+ return out;
1234
+ }
  // Fallthrough for every shape combination not handled above.
1160
1235
  throw new Error(`Shapes [${this.shape}] and [${other.shape}] are not supported`);
1161
1236
  }
1162
1237
  // Utility to create a new tensor filled with a number
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "catniff",
3
- "version": "0.2.14",
3
+ "version": "0.2.15",
4
4
  "description": "A small Torch-like deep learning framework for Javascript with tensor and autograd support",
5
5
  "main": "index.js",
6
6
  "scripts": {