catniff 0.6.0 → 0.6.1
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- package/dist/core.d.ts +3 -0
- package/dist/core.js +96 -6
- package/dist/nn.d.ts +6 -0
- package/dist/nn.js +10 -0
- package/package.json +1 -1
package/dist/core.d.ts
CHANGED
|
@@ -47,12 +47,15 @@ export declare class Tensor {
|
|
|
47
47
|
static normalizeDims(dims: number[], numDims: number): number[];
|
|
48
48
|
isContiguous(): boolean;
|
|
49
49
|
contiguous(): Tensor;
|
|
50
|
+
view(newShape: readonly number[]): Tensor;
|
|
50
51
|
reshape(newShape: readonly number[]): Tensor;
|
|
51
52
|
transpose(dim1: number, dim2: number): Tensor;
|
|
52
53
|
swapaxes: (dim1: number, dim2: number) => Tensor;
|
|
53
54
|
swapdims: (dim1: number, dim2: number) => Tensor;
|
|
54
55
|
t(): Tensor;
|
|
55
56
|
permute(dims: number[]): Tensor;
|
|
57
|
+
indexWithArray(indices: number[]): Tensor;
|
|
58
|
+
index(indices: Tensor | TensorValue): Tensor;
|
|
56
59
|
slice(ranges: number[][]): Tensor;
|
|
57
60
|
squeeze(dims?: number[] | number): Tensor;
|
|
58
61
|
unsqueeze(dim: number): Tensor;
|
package/dist/core.js
CHANGED
|
@@ -103,7 +103,7 @@ class Tensor {
|
|
|
103
103
|
newShape[index] = shapeA[index];
|
|
104
104
|
}
|
|
105
105
|
else {
|
|
106
|
-
throw new Error(`
|
|
106
|
+
throw new Error(`Can not broadcast shapes: ${shapeA} and ${shapeB}`);
|
|
107
107
|
}
|
|
108
108
|
}
|
|
109
109
|
return newShape;
|
|
@@ -325,13 +325,37 @@ class Tensor {
|
|
|
325
325
|
}
|
|
326
326
|
return out;
|
|
327
327
|
}
|
|
328
|
+
view(newShape) {
|
|
329
|
+
// Verify shape size
|
|
330
|
+
const originalSize = this.numel;
|
|
331
|
+
const outputSize = Tensor.shapeToSize(newShape);
|
|
332
|
+
if (originalSize !== outputSize) {
|
|
333
|
+
throw new Error("Can not create view: incompatible sizes");
|
|
334
|
+
}
|
|
335
|
+
// Verify compatibility (only contiguity for now)
|
|
336
|
+
if (!this.isContiguous()) {
|
|
337
|
+
throw new Error("Can not create view: incompatible metadata");
|
|
338
|
+
}
|
|
339
|
+
const outputStrides = Tensor.getStrides(newShape);
|
|
340
|
+
const out = new Tensor(this.value, { shape: newShape, strides: outputStrides, numel: outputSize });
|
|
341
|
+
// Gradient reshaped and flow back to the original tensor
|
|
342
|
+
if (this.requiresGrad) {
|
|
343
|
+
out.requiresGrad = true;
|
|
344
|
+
out.children.push(this);
|
|
345
|
+
out.gradFn = () => {
|
|
346
|
+
Tensor.addGrad(this, out.grad.reshape(this.shape));
|
|
347
|
+
};
|
|
348
|
+
}
|
|
349
|
+
return out;
|
|
350
|
+
}
|
|
328
351
|
reshape(newShape) {
|
|
329
352
|
// Verify shape size
|
|
330
353
|
const originalSize = this.numel;
|
|
331
354
|
const outputSize = Tensor.shapeToSize(newShape);
|
|
332
355
|
if (originalSize !== outputSize) {
|
|
333
|
-
throw new Error("
|
|
356
|
+
throw new Error("Can not reshape: incompatible sizes");
|
|
334
357
|
}
|
|
358
|
+
// Create new tensor with forced compatibility (only contiguity for now)
|
|
335
359
|
const outputStrides = Tensor.getStrides(newShape);
|
|
336
360
|
const out = new Tensor(this.contiguous().value, { shape: newShape, strides: outputStrides, numel: outputSize });
|
|
337
361
|
// Gradient reshaped and flow back to the original tensor
|
|
@@ -430,6 +454,72 @@ class Tensor {
|
|
|
430
454
|
}
|
|
431
455
|
return out;
|
|
432
456
|
}
|
|
457
|
+
// Utility for indexing with array of indices
|
|
458
|
+
indexWithArray(indices) {
|
|
459
|
+
if (typeof this.value === "number")
|
|
460
|
+
return this;
|
|
461
|
+
indices = Tensor.normalizeDims(indices, this.shape[0]);
|
|
462
|
+
// Init necessary stuff for indexing
|
|
463
|
+
const reducedShape = this.shape.slice(1);
|
|
464
|
+
const reducedStrides = this.strides.slice(1);
|
|
465
|
+
const elementsPerIndex = Tensor.shapeToSize(reducedShape);
|
|
466
|
+
// Init output data
|
|
467
|
+
const outputShape = [indices.length, ...reducedShape];
|
|
468
|
+
const outputSize = Tensor.shapeToSize(outputShape);
|
|
469
|
+
const outputValue = new Array(outputSize);
|
|
470
|
+
for (let i = 0; i < indices.length; i++) {
|
|
471
|
+
const sourceRowIndex = indices[i];
|
|
472
|
+
const targetStart = i * elementsPerIndex;
|
|
473
|
+
for (let j = 0; j < elementsPerIndex; j++) {
|
|
474
|
+
const fullCoords = Tensor.indexToCoords(j, reducedStrides);
|
|
475
|
+
fullCoords.unshift(sourceRowIndex);
|
|
476
|
+
const sourceIndex = Tensor.coordsToIndex(fullCoords, this.strides);
|
|
477
|
+
outputValue[targetStart + j] = this.value[this.offset + sourceIndex];
|
|
478
|
+
}
|
|
479
|
+
}
|
|
480
|
+
const out = new Tensor(outputValue, {
|
|
481
|
+
shape: outputShape,
|
|
482
|
+
numel: outputSize
|
|
483
|
+
});
|
|
484
|
+
// Handle gradient
|
|
485
|
+
if (this.requiresGrad) {
|
|
486
|
+
out.requiresGrad = true;
|
|
487
|
+
out.children.push(this);
|
|
488
|
+
out.gradFn = () => {
|
|
489
|
+
const outGrad = out.grad;
|
|
490
|
+
// Create zero gradient tensor with original shape
|
|
491
|
+
const grad = Tensor.zerosLike(this);
|
|
492
|
+
// Scatter gradients back to original positions
|
|
493
|
+
for (let i = 0; i < indices.length; i++) {
|
|
494
|
+
const originalRowIndex = indices[i];
|
|
495
|
+
const sourceStart = i * elementsPerIndex;
|
|
496
|
+
for (let j = 0; j < elementsPerIndex; j++) {
|
|
497
|
+
const fullCoords = Tensor.indexToCoords(j, reducedStrides);
|
|
498
|
+
fullCoords.unshift(originalRowIndex);
|
|
499
|
+
const targetIndex = Tensor.coordsToIndex(fullCoords, this.strides);
|
|
500
|
+
grad.value[targetIndex] += outGrad.value[sourceStart + j];
|
|
501
|
+
}
|
|
502
|
+
}
|
|
503
|
+
Tensor.addGrad(this, grad);
|
|
504
|
+
};
|
|
505
|
+
}
|
|
506
|
+
return out;
|
|
507
|
+
}
|
|
508
|
+
// Tensor indexing
|
|
509
|
+
index(indices) {
|
|
510
|
+
if (typeof indices === "number") {
|
|
511
|
+
return this.indexWithArray([indices]).squeeze(0);
|
|
512
|
+
}
|
|
513
|
+
else {
|
|
514
|
+
const tensorIndices = this.handleOther(indices).contiguous();
|
|
515
|
+
const originalShape = tensorIndices.shape;
|
|
516
|
+
const flatIndices = tensorIndices.value;
|
|
517
|
+
const result = this.indexWithArray(flatIndices);
|
|
518
|
+
// Reshape to preserve input shape
|
|
519
|
+
const outputShape = [...originalShape, ...this.shape.slice(1)];
|
|
520
|
+
return result.reshape(outputShape);
|
|
521
|
+
}
|
|
522
|
+
}
|
|
433
523
|
// Tensor slicing
|
|
434
524
|
slice(ranges) {
|
|
435
525
|
// Handle scalars
|
|
@@ -478,7 +568,7 @@ class Tensor {
|
|
|
478
568
|
out.children.push(this);
|
|
479
569
|
out.gradFn = () => {
|
|
480
570
|
// Create zero tensor of original shape
|
|
481
|
-
const
|
|
571
|
+
const grad = Tensor.zerosLike(this);
|
|
482
572
|
// Upstream grad
|
|
483
573
|
const outGrad = out.grad;
|
|
484
574
|
const totalElements = outGrad.numel;
|
|
@@ -497,11 +587,11 @@ class Tensor {
|
|
|
497
587
|
}
|
|
498
588
|
// Get flat indices with offsets
|
|
499
589
|
const srcIndex = Tensor.coordsToIndex(slicedCoords, outGrad.strides) + outGrad.offset;
|
|
500
|
-
const targetIndex = Tensor.coordsToIndex(originalCoords,
|
|
590
|
+
const targetIndex = Tensor.coordsToIndex(originalCoords, grad.strides) + grad.offset;
|
|
501
591
|
// Accumulate gradient
|
|
502
|
-
|
|
592
|
+
grad.value[targetIndex] += outGrad.value[srcIndex];
|
|
503
593
|
}
|
|
504
|
-
Tensor.addGrad(this,
|
|
594
|
+
Tensor.addGrad(this, grad);
|
|
505
595
|
};
|
|
506
596
|
}
|
|
507
597
|
return out;
|
package/dist/nn.d.ts
CHANGED
|
@@ -57,6 +57,11 @@ declare class LayerNorm {
|
|
|
57
57
|
constructor(normalizedShape: number | number[], eps?: number, elementwiseAffine?: boolean, bias?: boolean, device?: string);
|
|
58
58
|
forward(input: Tensor): Tensor;
|
|
59
59
|
}
|
|
60
|
+
declare class Embedding {
|
|
61
|
+
weight: Tensor;
|
|
62
|
+
constructor(numEmbeddings: number, embeddingDim: number, device: string);
|
|
63
|
+
forward(input: Tensor | TensorValue): Tensor;
|
|
64
|
+
}
|
|
60
65
|
export interface StateDict {
|
|
61
66
|
[key: string]: any;
|
|
62
67
|
}
|
|
@@ -66,6 +71,7 @@ export declare const nn: {
|
|
|
66
71
|
GRUCell: typeof GRUCell;
|
|
67
72
|
LSTMCell: typeof LSTMCell;
|
|
68
73
|
LayerNorm: typeof LayerNorm;
|
|
74
|
+
Embedding: typeof Embedding;
|
|
69
75
|
state: {
|
|
70
76
|
getParameters(model: any, visited?: WeakSet<object>): Tensor[];
|
|
71
77
|
moveParameters(model: any, device: string): void;
|
package/dist/nn.js
CHANGED
|
@@ -188,6 +188,15 @@ class LayerNorm {
|
|
|
188
188
|
return normalized;
|
|
189
189
|
}
|
|
190
190
|
}
|
|
191
|
+
class Embedding {
|
|
192
|
+
weight;
|
|
193
|
+
constructor(numEmbeddings, embeddingDim, device) {
|
|
194
|
+
this.weight = core_1.Tensor.randn([numEmbeddings, embeddingDim], { device });
|
|
195
|
+
}
|
|
196
|
+
forward(input) {
|
|
197
|
+
return this.weight.index(input);
|
|
198
|
+
}
|
|
199
|
+
}
|
|
191
200
|
const state = {
|
|
192
201
|
getParameters(model, visited = new WeakSet()) {
|
|
193
202
|
if (visited.has(model))
|
|
@@ -256,5 +265,6 @@ exports.nn = {
|
|
|
256
265
|
GRUCell,
|
|
257
266
|
LSTMCell,
|
|
258
267
|
LayerNorm,
|
|
268
|
+
Embedding,
|
|
259
269
|
state
|
|
260
270
|
};
|