catniff 0.8.18 → 0.8.19

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/core.d.ts CHANGED
@@ -89,6 +89,12 @@ export declare class Tensor {
89
89
  outIndex: number;
90
90
  }) => number;
91
91
  }): Tensor;
92
/**
 * Generic arg-reduction along a single dimension: for every position of the
 * remaining dimensions, yields the coordinate along `dim` of the winning element.
 *
 * @param tensor - Tensor to reduce.
 * @param dim - Dimension to reduce over; negative values count from the last dimension.
 * @param keepDim - When true the reduced dimension is kept with size 1, otherwise it is squeezed away.
 * @param config - `identity` is the initial "best" value every element is compared against;
 *   `isBetter(candidate, currentBest)` must return true when `candidate` should replace `currentBest`.
 * @returns An int32 tensor of indices.
 * @throws Error when `dim` does not exist on `tensor`.
 */
static reduceArg(tensor: Tensor, dim: number, keepDim: boolean, config: {
    identity: number;
    isBetter: (candidate: number, currentBest: number) => boolean;
}): Tensor;
/** Indices of the largest elements along `dim` (reduced dimension is dropped unless `keepDim` is true). */
argmax(dim: number, keepDim?: boolean): Tensor;
/** Indices of the smallest elements along `dim` (reduced dimension is dropped unless `keepDim` is true). */
argmin(dim: number, keepDim?: boolean): Tensor;
92
98
  sum(dims?: number[] | number, keepDims?: boolean): Tensor;
93
99
  prod(dims?: number[] | number, keepDims?: boolean): Tensor;
94
100
  mean(dims?: number[] | number, keepDims?: boolean): Tensor;
package/dist/core.js CHANGED
@@ -1155,6 +1155,67 @@ class Tensor {
1155
1155
  }
1156
1156
  return keepDims ? out : out.squeeze(dims);
1157
1157
  }
1158
+ // Generic arg reduction operation handler
1159
+ static reduceArg(tensor, dim, keepDim, config) {
1160
+ if (tensor.shape.length === 0)
1161
+ return tensor;
1162
+ // Handle negative indexing
1163
+ if (dim < 0) {
1164
+ dim += tensor.shape.length;
1165
+ }
1166
+ // If dimension out of bound, throw error
1167
+ if (dim >= tensor.shape.length || dim < 0) {
1168
+ throw new Error("Dimension does not exist to apply arg reduction");
1169
+ }
1170
+ const dimSize = tensor.shape[dim];
1171
+ const outputShape = tensor.shape.map((d, i) => dim === i ? 1 : d);
1172
+ const outputStrides = Tensor.getStrides(outputShape);
1173
+ const outputSize = tensor.numel / dimSize;
1174
+ const bestValues = new dtype_1.TypedArray[tensor.dtype](outputSize).fill(config.identity);
1175
+ const bestIndices = new Int32Array(outputSize).fill(0);
1176
+ const linearStrides = Tensor.getStrides(tensor.shape);
1177
+ // Forward pass
1178
+ for (let flatIndex = 0; flatIndex < tensor.numel; flatIndex++) {
1179
+ // Convert linear index to coordinates using contiguous strides
1180
+ const coords = Tensor.indexToCoords(flatIndex, linearStrides);
1181
+ // Coordinate in current dim
1182
+ const dimCoord = coords[dim];
1183
+ // Convert coordinates to actual strided index
1184
+ const realFlatIndex = Tensor.coordsToIndex(coords, tensor.strides) + tensor.offset;
1185
+ // Convert coords to reduced index
1186
+ coords[dim] = 0;
1187
+ const outFlatIndex = Tensor.coordsToIndex(coords, outputStrides);
1188
+ // Check if current value is better to swap
1189
+ const val = tensor.value[realFlatIndex];
1190
+ if (config.isBetter(val, bestValues[outFlatIndex])) {
1191
+ bestValues[outFlatIndex] = val;
1192
+ bestIndices[outFlatIndex] = dimCoord;
1193
+ }
1194
+ }
1195
+ const out = new Tensor(bestIndices, {
1196
+ shape: outputShape,
1197
+ strides: outputStrides,
1198
+ offset: 0,
1199
+ numel: outputSize,
1200
+ device: tensor.device,
1201
+ dtype: "int32"
1202
+ });
1203
+ return keepDim ? out : out.squeeze(dim);
1204
+ }
1205
+ // Tensor argmax
1206
+ argmax(dim, keepDim = false) {
1207
+ return Tensor.reduceArg(this, dim, keepDim, {
1208
+ identity: -Infinity,
1209
+ isBetter: (a, b) => a > b
1210
+ });
1211
+ }
1212
+ // Tensor argmin
1213
+ argmin(dim, keepDim = false) {
1214
+ return Tensor.reduceArg(this, dim, keepDim, {
1215
+ identity: Infinity,
1216
+ isBetter: (a, b) => a < b
1217
+ });
1218
+ }
1158
1219
  // Simplified reduction operations
1159
1220
  sum(dims, keepDims = false) {
1160
1221
  return Tensor.reduce(this, dims, keepDims, {
@@ -2552,11 +2613,15 @@ class Tensor {
2552
2613
  // Returns the nicely Pytorch-like formatted string form
2553
2614
  toString() {
2554
2615
  const val = this.val();
2555
- // Format a single number (integers get trailing dot)
2616
+ // Format a single number
2556
2617
  const formatNum = (n) => {
2557
- if (Number.isInteger(n) && Math.abs(n) < 1e8) {
2618
+ // For ints with int dtype
2619
+ if (this.dtype.includes("int"))
2620
+ return n.toFixed(0);
2621
+ // For ints with float dtype
2622
+ if (Number.isInteger(n) && Math.abs(n) < 1e8)
2558
2623
  return n.toFixed(0) + ".";
2559
- }
2624
+ // For floats
2560
2625
  return n.toString();
2561
2626
  };
2562
2627
  // Handle scalar
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "catniff",
3
- "version": "0.8.18",
3
+ "version": "0.8.19",
4
4
  "description": "Torch-like deep learning framework for Javascript",
5
5
  "main": "./dist/index.js",
6
6
  "scripts": {