catniff 0.7.4 → 0.8.0
This diff shows the content of publicly available package versions as released to their public registries, and is provided for informational purposes only.
- package/README.md +14 -0
- package/dist/core.d.ts +16 -11
- package/dist/core.js +362 -115
- package/dist/dtype.d.ts +5 -0
- package/dist/dtype.js +25 -0
- package/dist/nn.d.ts +9 -8
- package/dist/nn.js +50 -50
- package/index.d.ts +1 -0
- package/index.js +2 -1
- package/package.json +1 -1
package/dist/dtype.d.ts
ADDED
@@ -0,0 +1,5 @@
+export type dtype = "float64" | "float32" | "float16" | "int32" | "int16" | "int8" | "uint32" | "uint16" | "uint8";
+export declare const dtypeHiearchy: Record<dtype, number>;
+export type MemoryBuffer = Float64Array | Float32Array | Float16Array | Int32Array | Int16Array | Int8Array | Uint32Array | Uint16Array | Uint8Array;
+export type TypedArrayConstructor = Float64ArrayConstructor | Float32ArrayConstructor | Float16ArrayConstructor | Int32ArrayConstructor | Int16ArrayConstructor | Int8ArrayConstructor | Uint32ArrayConstructor | Uint16ArrayConstructor | Uint8ArrayConstructor;
+export declare const TypedArray: Record<dtype, TypedArrayConstructor>;
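Taken together, these declarations map every supported dtype name to a matching typed-array constructor. A minimal sketch of what that enables, assuming the module resolves as catniff/dist/dtype (allocate is a hypothetical helper for illustration, not package API):

import { dtype, MemoryBuffer, TypedArray } from "catniff/dist/dtype";

// Hypothetical helper: look up the constructor for a dtype and allocate a
// zero-filled buffer of `length` elements.
function allocate(dt: dtype, length: number): MemoryBuffer {
    return new TypedArray[dt](length);
}

const buf = allocate("float32", 4); // Float32Array(4) [0, 0, 0, 0]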
package/dist/dtype.js
ADDED
@@ -0,0 +1,25 @@
+"use strict";
+Object.defineProperty(exports, "__esModule", { value: true });
+exports.TypedArray = exports.dtypeHiearchy = void 0;
+exports.dtypeHiearchy = {
+    "float64": 8,
+    "float32": 7,
+    "float16": 6,
+    "int32": 5,
+    "int16": 4,
+    "int8": 3,
+    "uint32": 2,
+    "uint16": 1,
+    "uint8": 0
+};
+exports.TypedArray = {
+    "float64": Float64Array,
+    "float32": Float32Array,
+    "float16": Float16Array,
+    "int32": Int32Array,
+    "int16": Int16Array,
+    "int8": Int8Array,
+    "uint32": Uint32Array,
+    "uint16": Uint16Array,
+    "uint8": Uint8Array
+};
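dtypeHiearchy (the misspelling is the exported identifier itself) ranks the nine dtypes from uint8 (0) up to float64 (8). This diff does not show how core.js consumes the ranking, but a rank map like this is typically used for type promotion in mixed-dtype operations; a hedged sketch, with promote as a hypothetical helper and the import path assumed:

import { dtype, dtypeHiearchy } from "catniff/dist/dtype";

// Hypothetical promotion rule: the higher-ranked dtype wins, so mixed-dtype
// arithmetic would widen rather than truncate. Not confirmed by this diff.
function promote(a: dtype, b: dtype): dtype {
    return dtypeHiearchy[a] >= dtypeHiearchy[b] ? a : b;
}

promote("int16", "float32"); // "float32" (rank 7 beats rank 4)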
package/dist/nn.d.ts
CHANGED
@@ -1,8 +1,9 @@
 import { Tensor, TensorValue } from "./core";
+import { dtype } from "./dtype";
 export declare class Linear {
     weight: Tensor;
     bias?: Tensor;
-    constructor(inFeatures: number, outFeatures: number, bias?: boolean, device?: string);
+    constructor(inFeatures: number, outFeatures: number, bias?: boolean, device?: string, dtype?: dtype);
     forward(input: Tensor | TensorValue): Tensor;
 }
 export declare class RNNCell {
@@ -10,7 +11,7 @@ export declare class RNNCell {
     weightHH: Tensor;
     biasIH?: Tensor;
     biasHH?: Tensor;
-    constructor(inputSize: number, hiddenSize: number, bias?: boolean, device?: string);
+    constructor(inputSize: number, hiddenSize: number, bias?: boolean, device?: string, dtype?: dtype);
     forward(input: Tensor | TensorValue, hidden: Tensor | TensorValue): Tensor;
 }
 export declare class GRUCell {
@@ -26,7 +27,7 @@ export declare class GRUCell {
     biasHR?: Tensor;
     biasHZ?: Tensor;
     biasHN?: Tensor;
-    constructor(inputSize: number, hiddenSize: number, bias?: boolean, device?: string);
+    constructor(inputSize: number, hiddenSize: number, bias?: boolean, device?: string, dtype?: dtype);
     forward(input: Tensor | TensorValue, hidden: Tensor | TensorValue): Tensor;
 }
 export declare class LSTMCell {
@@ -46,7 +47,7 @@ export declare class LSTMCell {
     biasHF?: Tensor;
     biasHG?: Tensor;
     biasHO?: Tensor;
-    constructor(inputSize: number, hiddenSize: number, bias?: boolean, device?: string);
+    constructor(inputSize: number, hiddenSize: number, bias?: boolean, device?: string, dtype?: dtype);
     forward(input: Tensor | TensorValue, hidden: Tensor | TensorValue, cell: Tensor | TensorValue): [Tensor, Tensor];
 }
 export declare class LayerNorm {
@@ -54,19 +55,19 @@ export declare class LayerNorm {
     bias?: Tensor;
     eps: number;
     normalizedShape: number[];
-    constructor(normalizedShape: number | number[], eps?: number, elementwiseAffine?: boolean, bias?: boolean, device?: string);
+    constructor(normalizedShape: number | number[], eps?: number, elementwiseAffine?: boolean, bias?: boolean, device?: string, dtype?: dtype);
     forward(input: Tensor): Tensor;
 }
 export declare class RMSNorm {
     weight?: Tensor;
     eps: number;
     normalizedShape: number[];
-    constructor(normalizedShape: number | number[], eps?: number, elementwiseAffine?: boolean, device?: string);
+    constructor(normalizedShape: number | number[], eps?: number, elementwiseAffine?: boolean, device?: string, dtype?: dtype);
     forward(input: Tensor): Tensor;
 }
 export declare class Embedding {
     weight: Tensor;
-    constructor(numEmbeddings: number, embeddingDim: number, device?: string);
+    constructor(numEmbeddings: number, embeddingDim: number, device?: string, dtype?: dtype);
     forward(input: Tensor | TensorValue): Tensor;
 }
 export declare class MultiheadAttention {
@@ -78,7 +79,7 @@ export declare class MultiheadAttention {
     numHeads: number;
     headDim: number;
     dropout: number;
-    constructor(embedDim: number, numHeads: number, dropout?: number, bias?: boolean, device?: string);
+    constructor(embedDim: number, numHeads: number, dropout?: number, bias?: boolean, device?: string, dtype?: dtype);
     forward(query: Tensor, key: Tensor, value: Tensor, needWeights?: boolean, attnMask?: Tensor, averageAttnWeights?: boolean): [Tensor, Tensor | undefined];
 }
 export interface StateDict {
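The pattern across every class is the same: a trailing optional dtype?: dtype parameter after device, so existing call sites keep compiling. A sketch of the new signatures in use, assuming the classes resolve from catniff/dist/nn:

import { Linear, LayerNorm } from "catniff/dist/nn";

// device is passed as undefined only to reach the new trailing dtype argument.
const fc = new Linear(128, 64, true, undefined, "float32");
const ln = new LayerNorm(64, 1e-5, true, true, undefined, "float16");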
package/dist/nn.js
CHANGED
@@ -12,11 +12,11 @@ function linearTransform(input, weight, bias) {
 class Linear {
     weight;
     bias;
-    constructor(inFeatures, outFeatures, bias = true, device) {
+    constructor(inFeatures, outFeatures, bias = true, device, dtype) {
         const bound = 1 / Math.sqrt(inFeatures);
-        this.weight = core_1.Tensor.uniform([outFeatures, inFeatures], -bound, bound, { requiresGrad: true, device });
+        this.weight = core_1.Tensor.uniform([outFeatures, inFeatures], -bound, bound, { requiresGrad: true, device, dtype });
         if (bias) {
-            this.bias = core_1.Tensor.uniform([outFeatures], -bound, bound, { requiresGrad: true, device });
+            this.bias = core_1.Tensor.uniform([outFeatures], -bound, bound, { requiresGrad: true, device, dtype });
         }
     }
     forward(input) {
@@ -40,13 +40,13 @@ class RNNCell {
     weightHH;
     biasIH;
     biasHH;
-    constructor(inputSize, hiddenSize, bias = true, device) {
+    constructor(inputSize, hiddenSize, bias = true, device, dtype) {
         const bound = 1 / Math.sqrt(hiddenSize);
-        this.weightIH = core_1.Tensor.uniform([hiddenSize, inputSize], -bound, bound, { requiresGrad: true, device });
-        this.weightHH = core_1.Tensor.uniform([hiddenSize, hiddenSize], -bound, bound, { requiresGrad: true, device });
+        this.weightIH = core_1.Tensor.uniform([hiddenSize, inputSize], -bound, bound, { requiresGrad: true, device, dtype });
+        this.weightHH = core_1.Tensor.uniform([hiddenSize, hiddenSize], -bound, bound, { requiresGrad: true, device, dtype });
         if (bias) {
-            this.biasIH = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device });
-            this.biasHH = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device });
+            this.biasIH = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device, dtype });
+            this.biasHH = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device, dtype });
         }
     }
     forward(input, hidden) {
@@ -69,21 +69,21 @@ class GRUCell {
     biasHR;
     biasHZ;
     biasHN;
-    constructor(inputSize, hiddenSize, bias = true, device) {
+    constructor(inputSize, hiddenSize, bias = true, device, dtype) {
         const bound = 1 / Math.sqrt(hiddenSize);
-        this.weightIR = core_1.Tensor.uniform([hiddenSize, inputSize], -bound, bound, { requiresGrad: true, device });
-        this.weightIZ = core_1.Tensor.uniform([hiddenSize, inputSize], -bound, bound, { requiresGrad: true, device });
-        this.weightIN = core_1.Tensor.uniform([hiddenSize, inputSize], -bound, bound, { requiresGrad: true, device });
-        this.weightHR = core_1.Tensor.uniform([hiddenSize, hiddenSize], -bound, bound, { requiresGrad: true, device });
-        this.weightHZ = core_1.Tensor.uniform([hiddenSize, hiddenSize], -bound, bound, { requiresGrad: true, device });
-        this.weightHN = core_1.Tensor.uniform([hiddenSize, hiddenSize], -bound, bound, { requiresGrad: true, device });
+        this.weightIR = core_1.Tensor.uniform([hiddenSize, inputSize], -bound, bound, { requiresGrad: true, device, dtype });
+        this.weightIZ = core_1.Tensor.uniform([hiddenSize, inputSize], -bound, bound, { requiresGrad: true, device, dtype });
+        this.weightIN = core_1.Tensor.uniform([hiddenSize, inputSize], -bound, bound, { requiresGrad: true, device, dtype });
+        this.weightHR = core_1.Tensor.uniform([hiddenSize, hiddenSize], -bound, bound, { requiresGrad: true, device, dtype });
+        this.weightHZ = core_1.Tensor.uniform([hiddenSize, hiddenSize], -bound, bound, { requiresGrad: true, device, dtype });
+        this.weightHN = core_1.Tensor.uniform([hiddenSize, hiddenSize], -bound, bound, { requiresGrad: true, device, dtype });
         if (bias) {
-            this.biasIR = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device });
-            this.biasIZ = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device });
-            this.biasIN = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device });
-            this.biasHR = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device });
-            this.biasHZ = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device });
-            this.biasHN = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device });
+            this.biasIR = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device, dtype });
+            this.biasIZ = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device, dtype });
+            this.biasIN = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device, dtype });
+            this.biasHR = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device, dtype });
+            this.biasHZ = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device, dtype });
+            this.biasHN = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device, dtype });
         }
     }
     forward(input, hidden) {
@@ -113,25 +113,25 @@ class LSTMCell {
     biasHF;
     biasHG;
     biasHO;
-    constructor(inputSize, hiddenSize, bias = true, device) {
+    constructor(inputSize, hiddenSize, bias = true, device, dtype) {
         const bound = 1 / Math.sqrt(hiddenSize);
-        this.weightII = core_1.Tensor.uniform([hiddenSize, inputSize], -bound, bound, { requiresGrad: true, device });
-        this.weightIF = core_1.Tensor.uniform([hiddenSize, inputSize], -bound, bound, { requiresGrad: true, device });
-        this.weightIG = core_1.Tensor.uniform([hiddenSize, inputSize], -bound, bound, { requiresGrad: true, device });
-        this.weightIO = core_1.Tensor.uniform([hiddenSize, inputSize], -bound, bound, { requiresGrad: true, device });
-        this.weightHI = core_1.Tensor.uniform([hiddenSize, hiddenSize], -bound, bound, { requiresGrad: true, device });
-        this.weightHF = core_1.Tensor.uniform([hiddenSize, hiddenSize], -bound, bound, { requiresGrad: true, device });
-        this.weightHG = core_1.Tensor.uniform([hiddenSize, hiddenSize], -bound, bound, { requiresGrad: true, device });
-        this.weightHO = core_1.Tensor.uniform([hiddenSize, hiddenSize], -bound, bound, { requiresGrad: true, device });
+        this.weightII = core_1.Tensor.uniform([hiddenSize, inputSize], -bound, bound, { requiresGrad: true, device, dtype });
+        this.weightIF = core_1.Tensor.uniform([hiddenSize, inputSize], -bound, bound, { requiresGrad: true, device, dtype });
+        this.weightIG = core_1.Tensor.uniform([hiddenSize, inputSize], -bound, bound, { requiresGrad: true, device, dtype });
+        this.weightIO = core_1.Tensor.uniform([hiddenSize, inputSize], -bound, bound, { requiresGrad: true, device, dtype });
+        this.weightHI = core_1.Tensor.uniform([hiddenSize, hiddenSize], -bound, bound, { requiresGrad: true, device, dtype });
+        this.weightHF = core_1.Tensor.uniform([hiddenSize, hiddenSize], -bound, bound, { requiresGrad: true, device, dtype });
+        this.weightHG = core_1.Tensor.uniform([hiddenSize, hiddenSize], -bound, bound, { requiresGrad: true, device, dtype });
+        this.weightHO = core_1.Tensor.uniform([hiddenSize, hiddenSize], -bound, bound, { requiresGrad: true, device, dtype });
         if (bias) {
-            this.biasII = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device });
-            this.biasIF = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device });
-            this.biasIG = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device });
-            this.biasIO = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device });
-            this.biasHI = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device });
-            this.biasHF = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device });
-            this.biasHG = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device });
-            this.biasHO = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device });
+            this.biasII = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device, dtype });
+            this.biasIF = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device, dtype });
+            this.biasIG = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device, dtype });
+            this.biasIO = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device, dtype });
+            this.biasHI = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device, dtype });
+            this.biasHF = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device, dtype });
+            this.biasHG = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device, dtype });
+            this.biasHO = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device, dtype });
         }
     }
     forward(input, hidden, cell) {
@@ -153,16 +153,16 @@ class LayerNorm {
     bias;
     eps;
     normalizedShape;
-    constructor(normalizedShape, eps = 1e-5, elementwiseAffine = true, bias = true, device) {
+    constructor(normalizedShape, eps = 1e-5, elementwiseAffine = true, bias = true, device, dtype) {
         this.eps = eps;
         this.normalizedShape = Array.isArray(normalizedShape) ? normalizedShape : [normalizedShape];
         if (this.normalizedShape.length === 0) {
             throw new Error("Normalized shape cannot be empty");
         }
         if (elementwiseAffine) {
-            this.weight = core_1.Tensor.ones(this.normalizedShape, { requiresGrad: true, device });
+            this.weight = core_1.Tensor.ones(this.normalizedShape, { requiresGrad: true, device, dtype });
             if (bias) {
-                this.bias = core_1.Tensor.zeros(this.normalizedShape, { requiresGrad: true, device });
+                this.bias = core_1.Tensor.zeros(this.normalizedShape, { requiresGrad: true, device, dtype });
             }
         }
     }
@@ -197,14 +197,14 @@ class RMSNorm {
     weight;
     eps;
     normalizedShape;
-    constructor(normalizedShape, eps = 1e-5, elementwiseAffine = true, device) {
+    constructor(normalizedShape, eps = 1e-5, elementwiseAffine = true, device, dtype) {
         this.eps = eps;
         this.normalizedShape = Array.isArray(normalizedShape) ? normalizedShape : [normalizedShape];
         if (this.normalizedShape.length === 0) {
             throw new Error("Normalized shape cannot be empty");
         }
         if (elementwiseAffine) {
-            this.weight = core_1.Tensor.ones(this.normalizedShape, { requiresGrad: true, device });
+            this.weight = core_1.Tensor.ones(this.normalizedShape, { requiresGrad: true, device, dtype });
         }
     }
     forward(input) {
@@ -232,8 +232,8 @@ class RMSNorm {
 exports.RMSNorm = RMSNorm;
 class Embedding {
     weight;
-    constructor(numEmbeddings, embeddingDim, device) {
-        this.weight = core_1.Tensor.randn([numEmbeddings, embeddingDim], { requiresGrad: true, device });
+    constructor(numEmbeddings, embeddingDim, device, dtype) {
+        this.weight = core_1.Tensor.randn([numEmbeddings, embeddingDim], { requiresGrad: true, device, dtype });
     }
     forward(input) {
         return this.weight.index(input);
@@ -249,11 +249,11 @@ class MultiheadAttention {
     numHeads;
     headDim;
     dropout;
-    constructor(embedDim, numHeads, dropout = 0, bias = true, device) {
-        this.qProjection = new Linear(embedDim, embedDim, bias, device);
-        this.kProjection = new Linear(embedDim, embedDim, bias, device);
-        this.vProjection = new Linear(embedDim, embedDim, bias, device);
-        this.oProjection = new Linear(embedDim, embedDim, bias, device);
+    constructor(embedDim, numHeads, dropout = 0, bias = true, device, dtype) {
+        this.qProjection = new Linear(embedDim, embedDim, bias, device, dtype);
+        this.kProjection = new Linear(embedDim, embedDim, bias, device, dtype);
+        this.vProjection = new Linear(embedDim, embedDim, bias, device, dtype);
+        this.oProjection = new Linear(embedDim, embedDim, bias, device, dtype);
         this.embedDim = embedDim;
         this.numHeads = numHeads;
         this.headDim = Math.floor(embedDim / numHeads);
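On the implementation side, dtype is only forwarded into the options bag of each Tensor.uniform/ones/zeros/randn call (and, in MultiheadAttention, into the four inner Linear projections), so all parameters of a layer share one dtype and omitting it preserves the 0.7.4 behavior. A sketch under the same import assumption as above:

import { GRUCell, MultiheadAttention } from "catniff/dist/nn";

const half = new GRUCell(32, 64, true, undefined, "float16"); // all 12 parameters created as float16
const legacy = new GRUCell(32, 64);                           // dtype omitted: same behavior as 0.7.4
const attn = new MultiheadAttention(256, 8, 0, true, undefined, "float32"); // q/k/v/o Linears inherit the dtype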
package/index.d.ts
CHANGED
package/index.js
CHANGED