catniff 0.8.0 → 0.8.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/core.d.ts +4 -3
- package/dist/core.js +122 -16
- package/package.json +1 -1
package/dist/core.d.ts
CHANGED
@@ -69,19 +69,20 @@ export declare class Tensor {
     cat(other: Tensor | TensorValue, dim?: number): Tensor;
     squeeze(dims?: number[] | number): Tensor;
     unsqueeze(dim: number): Tensor;
+    sort(dim?: number, descending?: boolean): Tensor;
+    topk(k: number, dim?: number, largest?: boolean): Tensor;
     static reduce(tensor: Tensor, dims: number[] | number | undefined, keepDims: boolean, config: {
         identity: number;
         operation: (accumulator: number, value: number) => number;
-        needsCounters?: boolean;
         postProcess?: (options: {
             values: MemoryBuffer;
-
+            dimSize: number;
         }) => void;
         needsShareCounts?: boolean;
         gradientFn: (options: {
             outputValue: MemoryBuffer;
             originalValue: MemoryBuffer;
-
+            dimSize: number;
             shareCounts: MemoryBuffer;
             realIndex: number;
             outIndex: number;
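
Together these declarations add PyTorch-style sorting to the public API. A minimal usage sketch (hypothetical: the import path, constructor form, and sample values are assumptions inferred from the signatures above):

    import { Tensor } from "catniff";  // assumed entry point

    const t = new Tensor([[3, 1, 2], [9, 7, 8]]);
    const sorted = t.sort(-1);         // ascending along the last dim: [[1, 2, 3], [7, 8, 9]]
    const top2 = t.topk(2, -1, true);  // two largest per row: [[3, 2], [9, 8]]
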
package/dist/core.js
CHANGED
@@ -208,13 +208,16 @@ class Tensor {
         const outputStrides = Tensor.getStrides(outputShape);
         const outputSize = Tensor.shapeToSize(outputShape);
         const outputValue = new dtype_1.TypedArray[outputDtype](outputSize);
+        // Check fast path conditions of two tensors
+        const aFastPath = tA.isContiguous() && tA.numel === outputSize;
+        const bFastPath = tB.isContiguous() && tB.numel === outputSize;
         for (let i = 0; i < outputSize; i++) {
             // Get coordinates from 1D index
-            const coordsOutput = Tensor.indexToCoords(i, outputStrides);
+            const coordsOutput = aFastPath && bFastPath ? [] : Tensor.indexToCoords(i, outputStrides);
             // Convert the coordinates to 1D index of flattened A with respect to A's shape
-            const indexA = Tensor.coordsToUnbroadcastedIndex(coordsOutput, paddedAShape, paddedAStrides);
+            const indexA = aFastPath ? i : Tensor.coordsToUnbroadcastedIndex(coordsOutput, paddedAShape, paddedAStrides);
             // Convert the coordinates to 1D index of flattened B with respect to B's shape
-            const indexB = Tensor.coordsToUnbroadcastedIndex(coordsOutput, paddedBShape, paddedBStrides);
+            const indexB = bFastPath ? i : Tensor.coordsToUnbroadcastedIndex(coordsOutput, paddedBShape, paddedBStrides);
             // Calculate with op
             outputValue[i] = op(tA.value[indexA + tA.offset], tB.value[indexB + tB.offset]);
         }
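
The fast path skips per-element coordinate math whenever an operand is contiguous and un-broadcast (its numel equals the output size), since its flat index then equals the output index i. A standalone sketch of the same idea (hypothetical helper, not part of the package):

    // When both operands are contiguous and un-broadcast, index i addresses
    // the same logical element in a, b, and out, so no coordinates are needed.
    function addContiguous(a: Float64Array, b: Float64Array): Float64Array {
        const out = new Float64Array(a.length);
        for (let i = 0; i < a.length; i++) out[i] = a[i] + b[i];
        return out;
    }
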
@@ -633,7 +636,7 @@ class Tensor {
             let start = range[0] ?? 0;
             let end = range[1] ?? dimSize;
             let step = range[2] ?? 1;
-            // Handle negative
+            // Handle negative indices
             if (start < 0)
                 start += dimSize;
             if (end < 0)
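
For reference, this normalization mirrors Python-style slicing: a negative bound counts back from the end of the dimension. A toy illustration with invented values (hypothetical helper, not the package's API):

    // dimSize = 5: start -3 → 2, end -1 → 4, so the slice covers indices 2..3
    function normalizeRange(start: number, end: number, dimSize: number): [number, number] {
        if (start < 0) start += dimSize;
        if (end < 0) end += dimSize;
        return [start, end];
    }
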
@@ -934,6 +937,114 @@ class Tensor {
         }
         return out;
     }
+    // Tensor sort
+    sort(dim = -1, descending = false) {
+        if (dim < 0) {
+            dim += this.shape.length;
+        }
+        // If dimension out of bound, throw error
+        if (dim >= this.shape.length || dim < 0) {
+            throw new Error("Dimension do not exist to sort");
+        }
+        // Copy if not contiguous
+        const outputSize = this.numel;
+        const outputShape = this.shape;
+        let outputValue, outputStrides;
+        if (this.isContiguous()) {
+            outputValue = [...this.value];
+            outputStrides = this.strides;
+        }
+        else {
+            outputValue = new dtype_1.TypedArray[this.dtype](outputSize);
+            outputStrides = Tensor.getStrides(outputShape);
+            for (let flatIndex = 0; flatIndex < outputSize; flatIndex++) {
+                const coords = Tensor.indexToCoords(flatIndex, outputStrides);
+                const originalIndex = Tensor.coordsToIndex(coords, this.strides);
+                outputValue[flatIndex] = this.value[originalIndex + this.offset];
+            }
+        }
+        // Calculate dimensions for gather-scatter
+        const dimSize = outputShape[dim];
+        const outerSize = outputShape.slice(0, dim).reduce((a, b) => a * b, 1);
+        const innerSize = outputShape.slice(dim + 1).reduce((a, b) => a * b, 1);
+        // Store permutation indices for gradient
+        const permutation = new Array(outputSize);
+        // Sort each group independently
+        for (let outer = 0; outer < outerSize; outer++) {
+            for (let inner = 0; inner < innerSize; inner++) {
+                const group = [];
+                for (let i = 0; i < dimSize; i++) {
+                    const flatIdx = outer * (dimSize * innerSize) + i * innerSize + inner;
+                    group.push({
+                        value: outputValue[flatIdx],
+                        dimIdx: i
+                    });
+                }
+                // Sort this group by value
+                group.sort((a, b) => descending ? b.value - a.value : a.value - b.value);
+                // Scatter: write back sorted values and record permutation
+                for (let i = 0; i < dimSize; i++) {
+                    const flatIdx = outer * (dimSize * innerSize) + i * innerSize + inner;
+                    outputValue[flatIdx] = group[i].value;
+                    // Record where this element came from (for gradient)
+                    const originalFlatIdx = outer * (dimSize * innerSize) + group[i].dimIdx * innerSize + inner;
+                    permutation[flatIdx] = originalFlatIdx;
+                }
+            }
+        }
+        const out = new Tensor(outputValue, {
+            shape: outputShape,
+            strides: outputStrides,
+            offset: 0,
+            numel: outputSize,
+            device: this.device,
+            dtype: this.dtype
+        });
+        // Gradient setup
+        if (this.requiresGrad) {
+            out.requiresGrad = true;
+            out.children.push(this);
+            out.gradFn = () => {
+                const outGrad = out.grad;
+                // Scatter output gradients back to original positions
+                const inputGradValue = new dtype_1.TypedArray[this.dtype](outputSize);
+                for (let sortedIdx = 0; sortedIdx < outputSize; sortedIdx++) {
+                    const originalIdx = permutation[sortedIdx];
+                    inputGradValue[originalIdx] = outGrad.value[sortedIdx];
+                }
+                const inputGrad = new Tensor(inputGradValue, {
+                    shape: outputShape,
+                    strides: outputStrides,
+                    offset: 0,
+                    numel: outputSize,
+                    device: this.device,
+                    dtype: this.dtype
+                });
+                Tensor.addGrad(this, inputGrad);
+            };
+        }
+        return out;
+    }
+    // Top-k sampling
+    topk(k, dim = -1, largest = true) {
+        if (dim < 0) {
+            dim += this.shape.length;
+        }
+        // If dimension out of bound, throw error
+        if (dim >= this.shape.length || dim < 0) {
+            throw new Error("Dimension do not exist to get topk");
+        }
+        const dimRanges = new Array(this.shape.length);
+        for (let index = 0; index < dimRanges.length; index++) {
+            if (index === dim) {
+                dimRanges[index] = [0, k];
+            }
+            else {
+                dimRanges[index] = [];
+            }
+        }
+        return this.sort(dim, largest).slice(dimRanges);
+    }
     // Generic reduction operation handler
     static reduce(tensor, dims, keepDims, config) {
         if (tensor.shape.length === 0)
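
The new sort gathers each run of dimSize elements (one per outer/inner pair), sorts it, scatters the values back, and records a flat-index permutation so gradFn can route out.grad back to each value's original slot; topk is then just a sort followed by a [0, k] slice along dim. A worked example of the index arithmetic (invented values, standalone sketch):

    // Shape [2, 2], sorting along dim 0: outerSize = 1, dimSize = 2, innerSize = 2.
    // flatIdx = outer * (dimSize * innerSize) + i * innerSize + inner
    const values = [3, 1, 2, 4];                 // [[3, 1], [2, 4]] in row-major order
    const dimSize = 2, innerSize = 2;
    for (let inner = 0; inner < innerSize; inner++) {
        // Gather one column, sort ascending, as sort(0) would
        const column = [values[inner], values[innerSize + inner]].sort((a, b) => a - b);
        console.log(`column ${inner}:`, column); // column 0: [2, 3], column 1: [1, 4]
    }
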
@@ -953,11 +1064,11 @@ class Tensor {
             }
             return keepDims ? reducedThis : reducedThis.squeeze(dims);
         }
+        const dimSize = tensor.shape[dims];
         const outputShape = tensor.shape.map((dim, i) => dims === i ? 1 : dim);
         const outputStrides = Tensor.getStrides(outputShape);
-        const outputSize =
+        const outputSize = tensor.numel / dimSize;
         const outputValue = new dtype_1.TypedArray[tensor.dtype](outputSize).fill(config.identity);
-        const outputCounters = config.needsCounters ? new dtype_1.TypedArray[tensor.dtype](outputSize).fill(0) : new dtype_1.TypedArray[tensor.dtype]();
         const originalSize = tensor.numel;
         const originalValue = tensor.value;
         const linearStrides = Tensor.getStrides(tensor.shape);
@@ -972,14 +1083,10 @@ class Tensor {
             const outFlatIndex = Tensor.coordsToIndex(coords, outputStrides);
             // Apply op
             outputValue[outFlatIndex] = config.operation(outputValue[outFlatIndex], originalValue[realFlatIndex]);
-            // Count el if needed
-            if (config.needsCounters) {
-                outputCounters[outFlatIndex]++;
-            }
         }
         // Post-process if needed (e.g., divide by count for mean)
         if (config.postProcess) {
-            config.postProcess({ values: outputValue,
+            config.postProcess({ values: outputValue, dimSize });
         }
         const out = new Tensor(outputValue, {
             shape: outputShape,
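
Dropping the counters is sound because reduce here collapses a single dimension: every output element aggregates exactly dimSize inputs, so the count is a constant known up front rather than something to tally per element. A quick sketch of the identity (invented values):

    // Reducing shape [2, 3] over dim 1: each output sums exactly dimSize = 3 inputs.
    const xs = [1, 2, 3, 4, 5, 6];
    const dimSize = 3;
    const sums = [xs[0] + xs[1] + xs[2], xs[3] + xs[4] + xs[5]]; // [6, 15]
    const means = sums.map(s => s / dimSize);                    // [2, 5]
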
@@ -1021,7 +1128,7 @@ class Tensor {
                 gradValue[flatIndex] = config.gradientFn({
                     outputValue,
                     originalValue: tensor.value,
-
+                    dimSize,
                     shareCounts,
                     realIndex: realFlatIndex,
                     outIndex: outFlatIndex
@@ -1058,13 +1165,12 @@ class Tensor {
         return Tensor.reduce(this, dims, keepDims, {
             identity: 0,
             operation: (a, b) => a + b,
-
-            postProcess: ({ values, counters }) => {
+            postProcess: ({ values, dimSize }) => {
                 for (let i = 0; i < values.length; i++) {
-                    values[i] /=
+                    values[i] /= dimSize;
                 }
             },
-            gradientFn: ({
+            gradientFn: ({ dimSize }) => 1 / dimSize
         });
     }
     max(dims, keepDims = false) {
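
This last hunk is the mean reduction: sums are post-processed by dividing by dimSize, and the gradient closure returns 1 / dimSize, which is exactly d/dx_i of (1/n) Σ_j x_j with n = dimSize. A quick numeric check (invented values):

    // mean([1, 2, 3]) = 2; nudging one input by ε moves the mean by ε / 3.
    const mean = (v: number[]) => v.reduce((a, b) => a + b, 0) / v.length;
    const eps = 1e-6;
    const grad = (mean([1 + eps, 2, 3]) - mean([1, 2, 3])) / eps; // ≈ 1 / 3
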