catniff 0.8.16 → 0.8.18
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- package/dist/core.d.ts +2 -1
- package/dist/core.js +13 -4
- package/dist/lrscheduler.d.ts +12 -0
- package/dist/lrscheduler.js +27 -2
- package/package.json +1 -1
package/dist/core.d.ts
CHANGED
|
@@ -44,7 +44,7 @@ export declare class Tensor {
|
|
|
44
44
|
static coordsToIndex(coords: number[], strides: number[]): number;
|
|
45
45
|
static shapeToSize(shape: number[]): number;
|
|
46
46
|
static getResultDtype(type1: dtype, type2: dtype): dtype;
|
|
47
|
-
handleOther(other: Tensor | TensorValue): Tensor;
|
|
47
|
+
handleOther(other: Tensor | TensorValue, forceSameDevice?: boolean): Tensor;
|
|
48
48
|
static elementWiseAB(tA: Tensor, tB: Tensor, op: (tA: number, tB: number) => number): Tensor;
|
|
49
49
|
static elementWiseSelf(tA: Tensor, op: (tA: number) => number): Tensor;
|
|
50
50
|
elementWiseABDAG(other: TensorValue | Tensor, op: (a: number, b: number) => number, thisGrad?: (self: Tensor, other: Tensor, outGrad: Tensor) => Tensor, otherGrad?: (self: Tensor, other: Tensor, outGrad: Tensor) => Tensor): Tensor;
|
|
@@ -67,6 +67,7 @@ export declare class Tensor {
|
|
|
67
67
|
chunk(chunks: number, dim?: number): Tensor[];
|
|
68
68
|
expand(newShape: number[]): Tensor;
|
|
69
69
|
cat(other: Tensor | TensorValue, dim?: number): Tensor;
|
|
70
|
+
stack(others: (Tensor | TensorValue)[], dim?: number): Tensor;
|
|
70
71
|
squeeze(dims?: number[] | number): Tensor;
|
|
71
72
|
unsqueeze(dim: number): Tensor;
|
|
72
73
|
sort(dim?: number, descending?: boolean): Tensor;
|
package/dist/core.js
CHANGED
|
@@ -165,10 +165,10 @@ class Tensor {
|
|
|
165
165
|
}
|
|
166
166
|
return type2;
|
|
167
167
|
}
|
|
168
|
-
// Utility to handle other tensor if an op needs
|
|
169
|
-
handleOther(other) {
|
|
168
|
+
// Utility to handle other tensor if an op needs other operands
|
|
169
|
+
handleOther(other, forceSameDevice = true) {
|
|
170
170
|
if (other instanceof Tensor) {
|
|
171
|
-
if (this.device !== other.device) {
|
|
171
|
+
if (forceSameDevice && this.device !== other.device) {
|
|
172
172
|
throw new Error("Can not operate on tensors that are not on the same device");
|
|
173
173
|
}
|
|
174
174
|
return other;
|
|
@@ -602,7 +602,7 @@ class Tensor {
|
|
|
602
602
|
}
|
|
603
603
|
// Tensor indexing
|
|
604
604
|
index(indices) {
|
|
605
|
-
const tensorIndices = this.handleOther(indices).clone();
|
|
605
|
+
const tensorIndices = this.handleOther(indices, false).clone();
|
|
606
606
|
if (tensorIndices.shape.length === 0) {
|
|
607
607
|
return this.indexWithArray([tensorIndices.value[0]]).squeeze(0);
|
|
608
608
|
}
|
|
@@ -843,6 +843,15 @@ class Tensor {
|
|
|
843
843
|
}
|
|
844
844
|
return out;
|
|
845
845
|
}
|
|
846
|
+
// Tensor stack
|
|
847
|
+
stack(others, dim = 0) {
|
|
848
|
+
let out = this.unsqueeze(dim);
|
|
849
|
+
for (let index = 0; index < others.length; index++) {
|
|
850
|
+
const other = this.handleOther(others[index]).unsqueeze(dim);
|
|
851
|
+
out = out.cat(other, dim);
|
|
852
|
+
}
|
|
853
|
+
return out;
|
|
854
|
+
}
|
|
846
855
|
// Tensor squeeze
|
|
847
856
|
squeeze(dims) {
|
|
848
857
|
if (this.shape.length === 0)
|
package/dist/lrscheduler.d.ts
CHANGED
|
@@ -30,8 +30,20 @@ export declare class CosineAnnealingLR {
|
|
|
30
30
|
constructor(optimizer: OptimizerWithLR, TMax: number, etaMin?: number, lastEpoch?: number);
|
|
31
31
|
step(): void;
|
|
32
32
|
}
|
|
33
|
+
export interface Scheduler {
|
|
34
|
+
step: Function;
|
|
35
|
+
}
|
|
36
|
+
export declare class SequentialLR {
|
|
37
|
+
optimizer: OptimizerWithLR;
|
|
38
|
+
schedulers: Scheduler[];
|
|
39
|
+
milestones: number[];
|
|
40
|
+
lastEpoch: number;
|
|
41
|
+
constructor(optimizer: OptimizerWithLR, schedulers: Scheduler[], milestones: number[], lastEpoch?: number);
|
|
42
|
+
step(): void;
|
|
43
|
+
}
|
|
33
44
|
export declare const LRScheduler: {
|
|
34
45
|
StepLR: typeof StepLR;
|
|
35
46
|
LinearLR: typeof LinearLR;
|
|
36
47
|
CosineAnnealingLR: typeof CosineAnnealingLR;
|
|
48
|
+
SequentialLR: typeof SequentialLR;
|
|
37
49
|
};
|
package/dist/lrscheduler.js
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
"use strict";
|
|
2
2
|
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
-
exports.LRScheduler = exports.CosineAnnealingLR = exports.LinearLR = exports.StepLR = void 0;
|
|
3
|
+
exports.LRScheduler = exports.SequentialLR = exports.CosineAnnealingLR = exports.LinearLR = exports.StepLR = void 0;
|
|
4
4
|
class StepLR {
|
|
5
5
|
optimizer;
|
|
6
6
|
stepSize;
|
|
@@ -84,8 +84,33 @@ class CosineAnnealingLR {
|
|
|
84
84
|
}
|
|
85
85
|
}
|
|
86
86
|
exports.CosineAnnealingLR = CosineAnnealingLR;
|
|
87
|
+
class SequentialLR {
|
|
88
|
+
optimizer;
|
|
89
|
+
schedulers;
|
|
90
|
+
milestones;
|
|
91
|
+
lastEpoch;
|
|
92
|
+
constructor(optimizer, schedulers, milestones, lastEpoch = -1) {
|
|
93
|
+
this.optimizer = optimizer;
|
|
94
|
+
this.schedulers = schedulers;
|
|
95
|
+
this.milestones = milestones;
|
|
96
|
+
this.lastEpoch = lastEpoch;
|
|
97
|
+
}
|
|
98
|
+
step() {
|
|
99
|
+
this.lastEpoch++;
|
|
100
|
+
let schedulerIndex = this.schedulers.length - 1; // default to last
|
|
101
|
+
for (let index = 0; index < this.milestones.length; index++) {
|
|
102
|
+
if (this.lastEpoch < this.milestones[index]) {
|
|
103
|
+
schedulerIndex = index;
|
|
104
|
+
break;
|
|
105
|
+
}
|
|
106
|
+
}
|
|
107
|
+
this.schedulers[schedulerIndex].step();
|
|
108
|
+
}
|
|
109
|
+
}
|
|
110
|
+
exports.SequentialLR = SequentialLR;
|
|
87
111
|
exports.LRScheduler = {
|
|
88
112
|
StepLR,
|
|
89
113
|
LinearLR,
|
|
90
|
-
CosineAnnealingLR
|
|
114
|
+
CosineAnnealingLR,
|
|
115
|
+
SequentialLR
|
|
91
116
|
};
|