catniff 0.5.11 → 0.6.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/core.d.ts CHANGED
@@ -3,6 +3,8 @@ export type TensorValue = number | TensorValue[];
3
3
  export interface TensorOptions {
4
4
  shape?: readonly number[];
5
5
  strides?: readonly number[];
6
+ offset?: number;
7
+ numel?: number;
6
8
  grad?: Tensor;
7
9
  requiresGrad?: boolean;
8
10
  gradFn?: Function;
@@ -13,6 +15,8 @@ export declare class Tensor {
13
15
  value: number[] | number;
14
16
  readonly shape: readonly number[];
15
17
  readonly strides: readonly number[];
18
+ offset: number;
19
+ numel: number;
16
20
  grad?: Tensor;
17
21
  requiresGrad: boolean;
18
22
  gradFn: Function;
@@ -40,11 +44,39 @@ export declare class Tensor {
40
44
  elementWiseSelfDAG(op: (a: number) => number, thisGrad?: (self: Tensor, outGrad: Tensor) => Tensor): Tensor;
41
45
  handleOther(other: Tensor | TensorValue): Tensor;
42
46
  static addGrad(tensor: Tensor, accumGrad: Tensor): void;
47
+ static normalizeDims(dims: number[], numDims: number): number[];
43
48
  isContiguous(): boolean;
44
49
  contiguous(): Tensor;
50
+ view(newShape: readonly number[]): Tensor;
45
51
  reshape(newShape: readonly number[]): Tensor;
52
+ transpose(dim1: number, dim2: number): Tensor;
53
+ swapaxes: (dim1: number, dim2: number) => Tensor;
54
+ swapdims: (dim1: number, dim2: number) => Tensor;
55
+ t(): Tensor;
56
+ permute(dims: number[]): Tensor;
57
+ indexWithArray(indices: number[]): Tensor;
58
+ index(indices: Tensor | TensorValue): Tensor;
59
+ slice(ranges: number[][]): Tensor;
46
60
  squeeze(dims?: number[] | number): Tensor;
47
61
  unsqueeze(dim: number): Tensor;
62
+ static reduce(tensor: Tensor, dims: number[] | number | undefined, keepDims: boolean, config: {
63
+ identity: number;
64
+ operation: (accumulator: number, value: number) => number;
65
+ needsCounters?: boolean;
66
+ postProcess?: (options: {
67
+ values: number[];
68
+ counters?: number[];
69
+ }) => void;
70
+ needsShareCounts?: boolean;
71
+ gradientFn: (options: {
72
+ outputValue: number[];
73
+ originalValue: number[];
74
+ counters: number[];
75
+ shareCounts: number[];
76
+ realIndex: number;
77
+ outIndex: number;
78
+ }) => number;
79
+ }): Tensor;
48
80
  sum(dims?: number[] | number, keepDims?: boolean): Tensor;
49
81
  prod(dims?: number[] | number, keepDims?: boolean): Tensor;
50
82
  mean(dims?: number[] | number, keepDims?: boolean): Tensor;
@@ -54,7 +86,7 @@ export declare class Tensor {
54
86
  any(dims?: number[] | number, keepDims?: boolean): Tensor;
55
87
  var(dims?: number[] | number, keepDims?: boolean): Tensor;
56
88
  std(dims?: number[] | number, keepDims?: boolean): Tensor;
57
- softmax(dims?: number[] | number): Tensor;
89
+ softmax(dim?: number): Tensor;
58
90
  add(other: TensorValue | Tensor): Tensor;
59
91
  sub(other: TensorValue | Tensor): Tensor;
60
92
  subtract: (other: TensorValue | Tensor) => Tensor;
@@ -144,28 +176,23 @@ export declare class Tensor {
144
176
  erf(): Tensor;
145
177
  erfc(): Tensor;
146
178
  erfinv(): Tensor;
147
- transpose(dim1: number, dim2: number): Tensor;
148
- swapaxes: (dim1: number, dim2: number) => Tensor;
149
- swapdims: (dim1: number, dim2: number) => Tensor;
150
- t(): Tensor;
151
- permute(dims: number[]): Tensor;
152
179
  dot(other: TensorValue | Tensor): Tensor;
153
180
  mm(other: TensorValue | Tensor): Tensor;
154
181
  bmm(other: TensorValue | Tensor): Tensor;
155
182
  mv(other: TensorValue | Tensor): Tensor;
156
183
  matmul(other: TensorValue | Tensor): Tensor;
157
184
  dropout(rate: number): Tensor;
158
- static full(shape: number[], num: number, options?: TensorOptions): Tensor;
185
+ static full(shape: readonly number[], num: number, options?: TensorOptions): Tensor;
159
186
  static fullLike(tensor: Tensor, num: number, options?: TensorOptions): Tensor;
160
- static ones(shape?: number[], options?: TensorOptions): Tensor;
187
+ static ones(shape?: readonly number[], options?: TensorOptions): Tensor;
161
188
  static onesLike(tensor: Tensor, options?: TensorOptions): Tensor;
162
- static zeros(shape?: number[], options?: TensorOptions): Tensor;
189
+ static zeros(shape?: readonly number[], options?: TensorOptions): Tensor;
163
190
  static zerosLike(tensor: Tensor, options?: TensorOptions): Tensor;
164
- static rand(shape?: number[], options?: TensorOptions): Tensor;
191
+ static rand(shape?: readonly number[], options?: TensorOptions): Tensor;
165
192
  static randLike(tensor: Tensor, options?: TensorOptions): Tensor;
166
- static randn(shape?: number[], options?: TensorOptions): Tensor;
193
+ static randn(shape?: readonly number[], options?: TensorOptions): Tensor;
167
194
  static randnLike(tensor: Tensor, options?: TensorOptions): Tensor;
168
- static randint(shape: number[], low: number, high: number, options?: TensorOptions): Tensor;
195
+ static randint(shape: readonly number[], low: number, high: number, options?: TensorOptions): Tensor;
169
196
  static randintLike(tensor: Tensor, low: number, high: number, options?: TensorOptions): Tensor;
170
197
  static normal(shape: number[], mean: number, stdDev: number, options?: TensorOptions): Tensor;
171
198
  static uniform(shape: number[], low: number, high: number, options?: TensorOptions): Tensor;