catniff 0.8.17 → 0.8.19

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/core.d.ts CHANGED
@@ -44,7 +44,7 @@ export declare class Tensor {
44
44
  static coordsToIndex(coords: number[], strides: number[]): number;
45
45
  static shapeToSize(shape: number[]): number;
46
46
  static getResultDtype(type1: dtype, type2: dtype): dtype;
47
- handleOther(other: Tensor | TensorValue): Tensor;
47
+ handleOther(other: Tensor | TensorValue, forceSameDevice?: boolean): Tensor;
48
48
  static elementWiseAB(tA: Tensor, tB: Tensor, op: (tA: number, tB: number) => number): Tensor;
49
49
  static elementWiseSelf(tA: Tensor, op: (tA: number) => number): Tensor;
50
50
  elementWiseABDAG(other: TensorValue | Tensor, op: (a: number, b: number) => number, thisGrad?: (self: Tensor, other: Tensor, outGrad: Tensor) => Tensor, otherGrad?: (self: Tensor, other: Tensor, outGrad: Tensor) => Tensor): Tensor;
@@ -67,6 +67,7 @@ export declare class Tensor {
67
67
  chunk(chunks: number, dim?: number): Tensor[];
68
68
  expand(newShape: number[]): Tensor;
69
69
  cat(other: Tensor | TensorValue, dim?: number): Tensor;
70
+ stack(others: (Tensor | TensorValue)[], dim?: number): Tensor;
70
71
  squeeze(dims?: number[] | number): Tensor;
71
72
  unsqueeze(dim: number): Tensor;
72
73
  sort(dim?: number, descending?: boolean): Tensor;
@@ -88,6 +89,12 @@ export declare class Tensor {
88
89
  outIndex: number;
89
90
  }) => number;
90
91
  }): Tensor;
92
/**
 * Generic arg-reduction helper backing `argmax` and `argmin`.
 *
 * `dim` may be negative (counted from the end). When `keepDim` is true the reduced
 * dimension is kept with size 1; otherwise it is squeezed away.
 *
 * `config.identity` is the starting "best" value for every output slot.
 * `config.isBetter` is invoked as `isBetter(candidate, best)` and must return true
 * when the candidate element should replace the current best; a strict comparison
 * keeps the earliest position on ties.
 *
 * @returns For inputs of rank >= 1, an int32 tensor of positions along `dim`.
 * @throws Error if `dim` is out of range.
 */
static reduceArg(tensor: Tensor, dim: number, keepDim: boolean, config: {
    identity: number;
    isBetter: (candidate: number, best: number) => boolean;
}): Tensor;
/** Positions of the maximum values along `dim` (int32 for inputs of rank >= 1). */
argmax(dim: number, keepDim?: boolean): Tensor;
/** Positions of the minimum values along `dim` (int32 for inputs of rank >= 1). */
argmin(dim: number, keepDim?: boolean): Tensor;
91
98
  sum(dims?: number[] | number, keepDims?: boolean): Tensor;
92
99
  prod(dims?: number[] | number, keepDims?: boolean): Tensor;
93
100
  mean(dims?: number[] | number, keepDims?: boolean): Tensor;
package/dist/core.js CHANGED
@@ -165,10 +165,10 @@ class Tensor {
165
165
  }
166
166
  return type2;
167
167
  }
168
- // Utility to handle other tensor if an op needs a second operand
169
- handleOther(other) {
168
+ // Utility to handle other tensor if an op needs other operands
169
+ handleOther(other, forceSameDevice = true) {
170
170
  if (other instanceof Tensor) {
171
- if (this.device !== other.device) {
171
+ if (forceSameDevice && this.device !== other.device) {
172
172
  throw new Error("Can not operate on tensors that are not on the same device");
173
173
  }
174
174
  return other;
@@ -602,7 +602,7 @@ class Tensor {
602
602
  }
603
603
  // Tensor indexing
604
604
  index(indices) {
605
- const tensorIndices = this.handleOther(indices).clone();
605
+ const tensorIndices = this.handleOther(indices, false).clone();
606
606
  if (tensorIndices.shape.length === 0) {
607
607
  return this.indexWithArray([tensorIndices.value[0]]).squeeze(0);
608
608
  }
@@ -843,6 +843,15 @@ class Tensor {
843
843
  }
844
844
  return out;
845
845
  }
846
+ // Tensor stack
847
+ stack(others, dim = 0) {
848
+ let out = this.unsqueeze(dim);
849
+ for (let index = 0; index < others.length; index++) {
850
+ const other = this.handleOther(others[index]).unsqueeze(dim);
851
+ out = out.cat(other, dim);
852
+ }
853
+ return out;
854
+ }
846
855
  // Tensor squeeze
847
856
  squeeze(dims) {
848
857
  if (this.shape.length === 0)
@@ -1146,6 +1155,67 @@ class Tensor {
1146
1155
  }
1147
1156
  return keepDims ? out : out.squeeze(dims);
1148
1157
  }
1158
+ // Generic arg reduction operation handler
1159
+ static reduceArg(tensor, dim, keepDim, config) {
1160
+ if (tensor.shape.length === 0)
1161
+ return tensor;
1162
+ // Handle negative indexing
1163
+ if (dim < 0) {
1164
+ dim += tensor.shape.length;
1165
+ }
1166
+ // If dimension out of bound, throw error
1167
+ if (dim >= tensor.shape.length || dim < 0) {
1168
+ throw new Error("Dimension does not exist to apply arg reduction");
1169
+ }
1170
+ const dimSize = tensor.shape[dim];
1171
+ const outputShape = tensor.shape.map((d, i) => dim === i ? 1 : d);
1172
+ const outputStrides = Tensor.getStrides(outputShape);
1173
+ const outputSize = tensor.numel / dimSize;
1174
+ const bestValues = new dtype_1.TypedArray[tensor.dtype](outputSize).fill(config.identity);
1175
+ const bestIndices = new Int32Array(outputSize).fill(0);
1176
+ const linearStrides = Tensor.getStrides(tensor.shape);
1177
+ // Forward pass
1178
+ for (let flatIndex = 0; flatIndex < tensor.numel; flatIndex++) {
1179
+ // Convert linear index to coordinates using contiguous strides
1180
+ const coords = Tensor.indexToCoords(flatIndex, linearStrides);
1181
+ // Coordinate in current dim
1182
+ const dimCoord = coords[dim];
1183
+ // Convert coordinates to actual strided index
1184
+ const realFlatIndex = Tensor.coordsToIndex(coords, tensor.strides) + tensor.offset;
1185
+ // Convert coords to reduced index
1186
+ coords[dim] = 0;
1187
+ const outFlatIndex = Tensor.coordsToIndex(coords, outputStrides);
1188
+ // Check if current value is better to swap
1189
+ const val = tensor.value[realFlatIndex];
1190
+ if (config.isBetter(val, bestValues[outFlatIndex])) {
1191
+ bestValues[outFlatIndex] = val;
1192
+ bestIndices[outFlatIndex] = dimCoord;
1193
+ }
1194
+ }
1195
+ const out = new Tensor(bestIndices, {
1196
+ shape: outputShape,
1197
+ strides: outputStrides,
1198
+ offset: 0,
1199
+ numel: outputSize,
1200
+ device: tensor.device,
1201
+ dtype: "int32"
1202
+ });
1203
+ return keepDim ? out : out.squeeze(dim);
1204
+ }
1205
+ // Tensor argmax
1206
+ argmax(dim, keepDim = false) {
1207
+ return Tensor.reduceArg(this, dim, keepDim, {
1208
+ identity: -Infinity,
1209
+ isBetter: (a, b) => a > b
1210
+ });
1211
+ }
1212
+ // Tensor argmin
1213
+ argmin(dim, keepDim = false) {
1214
+ return Tensor.reduceArg(this, dim, keepDim, {
1215
+ identity: Infinity,
1216
+ isBetter: (a, b) => a < b
1217
+ });
1218
+ }
1149
1219
  // Simplified reduction operations
1150
1220
  sum(dims, keepDims = false) {
1151
1221
  return Tensor.reduce(this, dims, keepDims, {
@@ -2543,11 +2613,15 @@ class Tensor {
2543
2613
  // Returns the nicely Pytorch-like formatted string form
2544
2614
  toString() {
2545
2615
  const val = this.val();
2546
- // Format a single number (integers get trailing dot)
2616
+ // Format a single number
2547
2617
  const formatNum = (n) => {
2548
- if (Number.isInteger(n) && Math.abs(n) < 1e8) {
2618
+ // For ints with int dtype
2619
+ if (this.dtype.includes("int"))
2620
+ return n.toFixed(0);
2621
+ // For ints with float dtype
2622
+ if (Number.isInteger(n) && Math.abs(n) < 1e8)
2549
2623
  return n.toFixed(0) + ".";
2550
- }
2624
+ // For floats
2551
2625
  return n.toString();
2552
2626
  };
2553
2627
  // Handle scalar
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "catniff",
3
- "version": "0.8.17",
3
+ "version": "0.8.19",
4
4
  "description": "Torch-like deep learning framework for Javascript",
5
5
  "main": "./dist/index.js",
6
6
  "scripts": {