catniff 0.8.17 → 0.8.19
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/core.d.ts +8 -1
- package/dist/core.js +81 -7
- package/package.json +1 -1
package/dist/core.d.ts
CHANGED
|
@@ -44,7 +44,7 @@ export declare class Tensor {
|
|
|
44
44
|
static coordsToIndex(coords: number[], strides: number[]): number;
|
|
45
45
|
static shapeToSize(shape: number[]): number;
|
|
46
46
|
static getResultDtype(type1: dtype, type2: dtype): dtype;
|
|
47
|
-
handleOther(other: Tensor | TensorValue): Tensor;
|
|
47
|
+
handleOther(other: Tensor | TensorValue, forceSameDevice?: boolean): Tensor;
|
|
48
48
|
static elementWiseAB(tA: Tensor, tB: Tensor, op: (tA: number, tB: number) => number): Tensor;
|
|
49
49
|
static elementWiseSelf(tA: Tensor, op: (tA: number) => number): Tensor;
|
|
50
50
|
elementWiseABDAG(other: TensorValue | Tensor, op: (a: number, b: number) => number, thisGrad?: (self: Tensor, other: Tensor, outGrad: Tensor) => Tensor, otherGrad?: (self: Tensor, other: Tensor, outGrad: Tensor) => Tensor): Tensor;
|
|
@@ -67,6 +67,7 @@ export declare class Tensor {
|
|
|
67
67
|
chunk(chunks: number, dim?: number): Tensor[];
|
|
68
68
|
expand(newShape: number[]): Tensor;
|
|
69
69
|
cat(other: Tensor | TensorValue, dim?: number): Tensor;
|
|
70
|
+
stack(others: (Tensor | TensorValue)[], dim?: number): Tensor;
|
|
70
71
|
squeeze(dims?: number[] | number): Tensor;
|
|
71
72
|
unsqueeze(dim: number): Tensor;
|
|
72
73
|
sort(dim?: number, descending?: boolean): Tensor;
|
|
@@ -88,6 +89,12 @@ export declare class Tensor {
|
|
|
88
89
|
outIndex: number;
|
|
89
90
|
}) => number;
|
|
90
91
|
}): Tensor;
|
|
92
|
+
static reduceArg(tensor: Tensor, dim: number, keepDim: boolean, config: {
|
|
93
|
+
identity: number;
|
|
94
|
+
isBetter: (accumulator: number, value: number) => boolean;
|
|
95
|
+
}): Tensor;
|
|
96
|
+
argmax(dim: number, keepDim?: boolean): Tensor;
|
|
97
|
+
argmin(dim: number, keepDim?: boolean): Tensor;
|
|
91
98
|
sum(dims?: number[] | number, keepDims?: boolean): Tensor;
|
|
92
99
|
prod(dims?: number[] | number, keepDims?: boolean): Tensor;
|
|
93
100
|
mean(dims?: number[] | number, keepDims?: boolean): Tensor;
|
package/dist/core.js
CHANGED
|
@@ -165,10 +165,10 @@ class Tensor {
|
|
|
165
165
|
}
|
|
166
166
|
return type2;
|
|
167
167
|
}
|
|
168
|
-
// Utility to handle other tensor if an op needs
|
|
169
|
-
handleOther(other) {
|
|
168
|
+
// Utility to handle other tensor if an op needs other operands
|
|
169
|
+
handleOther(other, forceSameDevice = true) {
|
|
170
170
|
if (other instanceof Tensor) {
|
|
171
|
-
if (this.device !== other.device) {
|
|
171
|
+
if (forceSameDevice && this.device !== other.device) {
|
|
172
172
|
throw new Error("Can not operate on tensors that are not on the same device");
|
|
173
173
|
}
|
|
174
174
|
return other;
|
|
@@ -602,7 +602,7 @@ class Tensor {
|
|
|
602
602
|
}
|
|
603
603
|
// Tensor indexing
|
|
604
604
|
index(indices) {
|
|
605
|
-
const tensorIndices = this.handleOther(indices).clone();
|
|
605
|
+
const tensorIndices = this.handleOther(indices, false).clone();
|
|
606
606
|
if (tensorIndices.shape.length === 0) {
|
|
607
607
|
return this.indexWithArray([tensorIndices.value[0]]).squeeze(0);
|
|
608
608
|
}
|
|
@@ -843,6 +843,15 @@ class Tensor {
|
|
|
843
843
|
}
|
|
844
844
|
return out;
|
|
845
845
|
}
|
|
846
|
+
// Tensor stack
|
|
847
|
+
stack(others, dim = 0) {
|
|
848
|
+
let out = this.unsqueeze(dim);
|
|
849
|
+
for (let index = 0; index < others.length; index++) {
|
|
850
|
+
const other = this.handleOther(others[index]).unsqueeze(dim);
|
|
851
|
+
out = out.cat(other, dim);
|
|
852
|
+
}
|
|
853
|
+
return out;
|
|
854
|
+
}
|
|
846
855
|
// Tensor squeeze
|
|
847
856
|
squeeze(dims) {
|
|
848
857
|
if (this.shape.length === 0)
|
|
@@ -1146,6 +1155,67 @@ class Tensor {
|
|
|
1146
1155
|
}
|
|
1147
1156
|
return keepDims ? out : out.squeeze(dims);
|
|
1148
1157
|
}
|
|
1158
|
+
// Generic arg reduction operation handler
|
|
1159
|
+
static reduceArg(tensor, dim, keepDim, config) {
|
|
1160
|
+
if (tensor.shape.length === 0)
|
|
1161
|
+
return tensor;
|
|
1162
|
+
// Handle negative indexing
|
|
1163
|
+
if (dim < 0) {
|
|
1164
|
+
dim += tensor.shape.length;
|
|
1165
|
+
}
|
|
1166
|
+
// If dimension out of bound, throw error
|
|
1167
|
+
if (dim >= tensor.shape.length || dim < 0) {
|
|
1168
|
+
throw new Error("Dimension does not exist to apply arg reduction");
|
|
1169
|
+
}
|
|
1170
|
+
const dimSize = tensor.shape[dim];
|
|
1171
|
+
const outputShape = tensor.shape.map((d, i) => dim === i ? 1 : d);
|
|
1172
|
+
const outputStrides = Tensor.getStrides(outputShape);
|
|
1173
|
+
const outputSize = tensor.numel / dimSize;
|
|
1174
|
+
const bestValues = new dtype_1.TypedArray[tensor.dtype](outputSize).fill(config.identity);
|
|
1175
|
+
const bestIndices = new Int32Array(outputSize).fill(0);
|
|
1176
|
+
const linearStrides = Tensor.getStrides(tensor.shape);
|
|
1177
|
+
// Forward pass
|
|
1178
|
+
for (let flatIndex = 0; flatIndex < tensor.numel; flatIndex++) {
|
|
1179
|
+
// Convert linear index to coordinates using contiguous strides
|
|
1180
|
+
const coords = Tensor.indexToCoords(flatIndex, linearStrides);
|
|
1181
|
+
// Coordinate in current dim
|
|
1182
|
+
const dimCoord = coords[dim];
|
|
1183
|
+
// Convert coordinates to actual strided index
|
|
1184
|
+
const realFlatIndex = Tensor.coordsToIndex(coords, tensor.strides) + tensor.offset;
|
|
1185
|
+
// Convert coords to reduced index
|
|
1186
|
+
coords[dim] = 0;
|
|
1187
|
+
const outFlatIndex = Tensor.coordsToIndex(coords, outputStrides);
|
|
1188
|
+
// Check if current value is better to swap
|
|
1189
|
+
const val = tensor.value[realFlatIndex];
|
|
1190
|
+
if (config.isBetter(val, bestValues[outFlatIndex])) {
|
|
1191
|
+
bestValues[outFlatIndex] = val;
|
|
1192
|
+
bestIndices[outFlatIndex] = dimCoord;
|
|
1193
|
+
}
|
|
1194
|
+
}
|
|
1195
|
+
const out = new Tensor(bestIndices, {
|
|
1196
|
+
shape: outputShape,
|
|
1197
|
+
strides: outputStrides,
|
|
1198
|
+
offset: 0,
|
|
1199
|
+
numel: outputSize,
|
|
1200
|
+
device: tensor.device,
|
|
1201
|
+
dtype: "int32"
|
|
1202
|
+
});
|
|
1203
|
+
return keepDim ? out : out.squeeze(dim);
|
|
1204
|
+
}
|
|
1205
|
+
// Tensor argmax
|
|
1206
|
+
argmax(dim, keepDim = false) {
|
|
1207
|
+
return Tensor.reduceArg(this, dim, keepDim, {
|
|
1208
|
+
identity: -Infinity,
|
|
1209
|
+
isBetter: (a, b) => a > b
|
|
1210
|
+
});
|
|
1211
|
+
}
|
|
1212
|
+
// Tensor argmin
|
|
1213
|
+
argmin(dim, keepDim = false) {
|
|
1214
|
+
return Tensor.reduceArg(this, dim, keepDim, {
|
|
1215
|
+
identity: Infinity,
|
|
1216
|
+
isBetter: (a, b) => a < b
|
|
1217
|
+
});
|
|
1218
|
+
}
|
|
1149
1219
|
// Simplified reduction operations
|
|
1150
1220
|
sum(dims, keepDims = false) {
|
|
1151
1221
|
return Tensor.reduce(this, dims, keepDims, {
|
|
@@ -2543,11 +2613,15 @@ class Tensor {
|
|
|
2543
2613
|
// Returns the nicely Pytorch-like formatted string form
|
|
2544
2614
|
toString() {
|
|
2545
2615
|
const val = this.val();
|
|
2546
|
-
// Format a single number
|
|
2616
|
+
// Format a single number
|
|
2547
2617
|
const formatNum = (n) => {
|
|
2548
|
-
|
|
2618
|
+
// For ints with int dtype
|
|
2619
|
+
if (this.dtype.includes("int"))
|
|
2620
|
+
return n.toFixed(0);
|
|
2621
|
+
// For ints with float dtype
|
|
2622
|
+
if (Number.isInteger(n) && Math.abs(n) < 1e8)
|
|
2549
2623
|
return n.toFixed(0) + ".";
|
|
2550
|
-
|
|
2624
|
+
// For floats
|
|
2551
2625
|
return n.toString();
|
|
2552
2626
|
};
|
|
2553
2627
|
// Handle scalar
|