catniff 0.6.0 → 0.6.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/core.d.ts CHANGED
@@ -47,12 +47,15 @@ export declare class Tensor {
  static normalizeDims(dims: number[], numDims: number): number[];
  isContiguous(): boolean;
  contiguous(): Tensor;
+ view(newShape: readonly number[]): Tensor;
  reshape(newShape: readonly number[]): Tensor;
  transpose(dim1: number, dim2: number): Tensor;
  swapaxes: (dim1: number, dim2: number) => Tensor;
  swapdims: (dim1: number, dim2: number) => Tensor;
  t(): Tensor;
  permute(dims: number[]): Tensor;
+ indexWithArray(indices: number[]): Tensor;
+ index(indices: Tensor | TensorValue): Tensor;
  slice(ranges: number[][]): Tensor;
  squeeze(dims?: number[] | number): Tensor;
  unsqueeze(dim: number): Tensor;
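
The new declarations add view() for zero-copy reshaping plus index()/indexWithArray() for gathering rows along the first dimension. A rough usage sketch of the 0.6.2 API follows; the require path and the construction of a Tensor from a plain nested array are assumptions, only the method signatures come from the .d.ts above.

```js
// Hypothetical usage of the new Tensor methods; require path and nested-array
// construction are assumed, method signatures are from the declarations above.
const { Tensor } = require("catniff");

const t = new Tensor([[1, 2, 3], [4, 5, 6]]);  // shape [2, 3]
const v = t.view([3, 2]);                      // same data viewed as [3, 2]
const picked = t.index([1, 0, 1]);             // rows 1, 0, 1 -> shape [3, 3]
const first = t.indexWithArray([0]);           // rows [0] -> shape [1, 3]
```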
package/dist/core.js CHANGED
@@ -103,7 +103,7 @@ class Tensor {
  newShape[index] = shapeA[index];
  }
  else {
- throw new Error(`Cannot broadcast shapes: ${shapeA} and ${shapeB}`);
+ throw new Error(`Can not broadcast shapes: ${shapeA} and ${shapeB}`);
  }
  }
  return newShape;
@@ -325,13 +325,37 @@ class Tensor {
  }
  return out;
  }
+ view(newShape) {
+ // Verify shape size
+ const originalSize = this.numel;
+ const outputSize = Tensor.shapeToSize(newShape);
+ if (originalSize !== outputSize) {
+ throw new Error("Can not create view: incompatible sizes");
+ }
+ // Verify compatibility (only contiguity for now)
+ if (!this.isContiguous()) {
+ throw new Error("Can not create view: incompatible metadata");
+ }
+ const outputStrides = Tensor.getStrides(newShape);
+ const out = new Tensor(this.value, { shape: newShape, strides: outputStrides, numel: outputSize });
+ // Gradient reshaped and flow back to the original tensor
+ if (this.requiresGrad) {
+ out.requiresGrad = true;
+ out.children.push(this);
+ out.gradFn = () => {
+ Tensor.addGrad(this, out.grad.reshape(this.shape));
+ };
+ }
+ return out;
+ }
  reshape(newShape) {
  // Verify shape size
  const originalSize = this.numel;
  const outputSize = Tensor.shapeToSize(newShape);
  if (originalSize !== outputSize) {
- throw new Error("Cannot reshape: incompatible sizes");
+ throw new Error("Can not reshape: incompatible sizes");
  }
+ // Create new tensor with forced compatibility (only contiguity for now)
  const outputStrides = Tensor.getStrides(newShape);
  const out = new Tensor(this.contiguous().value, { shape: newShape, strides: outputStrides, numel: outputSize });
  // Gradient reshaped and flow back to the original tensor
@@ -430,6 +454,72 @@ class Tensor {
  }
  return out;
  }
+ // Utility for indexing with array of indices
+ indexWithArray(indices) {
+ if (typeof this.value === "number")
+ return this;
+ indices = Tensor.normalizeDims(indices, this.shape[0]);
+ // Init necessary stuff for indexing
+ const reducedShape = this.shape.slice(1);
+ const reducedStrides = this.strides.slice(1);
+ const elementsPerIndex = Tensor.shapeToSize(reducedShape);
+ // Init output data
+ const outputShape = [indices.length, ...reducedShape];
+ const outputSize = Tensor.shapeToSize(outputShape);
+ const outputValue = new Array(outputSize);
+ for (let i = 0; i < indices.length; i++) {
+ const sourceRowIndex = indices[i];
+ const targetStart = i * elementsPerIndex;
+ for (let j = 0; j < elementsPerIndex; j++) {
+ const fullCoords = Tensor.indexToCoords(j, reducedStrides);
+ fullCoords.unshift(sourceRowIndex);
+ const sourceIndex = Tensor.coordsToIndex(fullCoords, this.strides);
+ outputValue[targetStart + j] = this.value[this.offset + sourceIndex];
+ }
+ }
+ const out = new Tensor(outputValue, {
+ shape: outputShape,
+ numel: outputSize
+ });
+ // Handle gradient
+ if (this.requiresGrad) {
+ out.requiresGrad = true;
+ out.children.push(this);
+ out.gradFn = () => {
+ const outGrad = out.grad;
+ // Create zero gradient tensor with original shape
+ const grad = Tensor.zerosLike(this);
+ // Scatter gradients back to original positions
+ for (let i = 0; i < indices.length; i++) {
+ const originalRowIndex = indices[i];
+ const sourceStart = i * elementsPerIndex;
+ for (let j = 0; j < elementsPerIndex; j++) {
+ const fullCoords = Tensor.indexToCoords(j, reducedStrides);
+ fullCoords.unshift(originalRowIndex);
+ const targetIndex = Tensor.coordsToIndex(fullCoords, this.strides);
+ grad.value[targetIndex] += outGrad.value[sourceStart + j];
+ }
+ }
+ Tensor.addGrad(this, grad);
+ };
+ }
+ return out;
+ }
+ // Tensor indexing
+ index(indices) {
+ const tensorIndices = this.handleOther(indices).contiguous();
+ if (typeof tensorIndices.value === "number") {
+ return this.indexWithArray([tensorIndices.value]).squeeze(0);
+ }
+ else {
+ const originalShape = tensorIndices.shape;
+ const flatIndices = tensorIndices.value;
+ const result = this.indexWithArray(flatIndices);
+ // Reshape to preserve input shape
+ const outputShape = [...originalShape, ...this.shape.slice(1)];
+ return result.reshape(outputShape);
+ }
+ }
  // Tensor slicing
  slice(ranges) {
  // Handle scalars
@@ -478,7 +568,7 @@ class Tensor {
  out.children.push(this);
  out.gradFn = () => {
  // Create zero tensor of original shape
- const zeroGrad = Tensor.zerosLike(this);
+ const grad = Tensor.zerosLike(this);
  // Upstream grad
  const outGrad = out.grad;
  const totalElements = outGrad.numel;
@@ -497,11 +587,11 @@ class Tensor {
  }
  // Get flat indices with offsets
  const srcIndex = Tensor.coordsToIndex(slicedCoords, outGrad.strides) + outGrad.offset;
- const targetIndex = Tensor.coordsToIndex(originalCoords, zeroGrad.strides) + zeroGrad.offset;
+ const targetIndex = Tensor.coordsToIndex(originalCoords, grad.strides) + grad.offset;
  // Accumulate gradient
- zeroGrad.value[targetIndex] += outGrad.value[srcIndex];
+ grad.value[targetIndex] += outGrad.value[srcIndex];
  }
- Tensor.addGrad(this, zeroGrad);
+ Tensor.addGrad(this, grad);
  };
  }
  return out;
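
In the implementation above, view() shares the underlying buffer and rejects non-contiguous tensors, while reshape() always copies through contiguous(); index() gathers rows along dimension 0, and its gradFn scatter-adds the upstream gradient into a zerosLike tensor, so duplicated indices accumulate. A minimal sketch of the view/reshape distinction, assuming a Tensor can be built directly from a flat array:

```js
// Sketch only: the Tensor constructor call is an assumption; the methods are from the diff.
const a = new Tensor([1, 2, 3, 4, 5, 6]);   // contiguous, shape [6]
const b = a.view([2, 3]);                   // same buffer, new shape/strides
const c = a.reshape([2, 3]);                // copies via contiguous().value

const tr = b.transpose(0, 1);               // shape [3, 2], non-contiguous strides
// tr.view([6])    -> throws "Can not create view: incompatible metadata"
// tr.reshape([6]) -> fine, reshape forces a contiguous copy first
```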
package/dist/nn.d.ts CHANGED
@@ -57,6 +57,11 @@ declare class LayerNorm {
  constructor(normalizedShape: number | number[], eps?: number, elementwiseAffine?: boolean, bias?: boolean, device?: string);
  forward(input: Tensor): Tensor;
  }
+ declare class Embedding {
+ weight: Tensor;
+ constructor(numEmbeddings: number, embeddingDim: number, device: string);
+ forward(input: Tensor | TensorValue): Tensor;
+ }
  export interface StateDict {
  [key: string]: any;
  }
@@ -66,6 +71,7 @@ export declare const nn: {
  GRUCell: typeof GRUCell;
  LSTMCell: typeof LSTMCell;
  LayerNorm: typeof LayerNorm;
+ Embedding: typeof Embedding;
  state: {
  getParameters(model: any, visited?: WeakSet<object>): Tensor[];
  moveParameters(model: any, device: string): void;
package/dist/nn.js CHANGED
@@ -188,6 +188,15 @@ class LayerNorm {
  return normalized;
  }
  }
+ class Embedding {
+ weight;
+ constructor(numEmbeddings, embeddingDim, device) {
+ this.weight = core_1.Tensor.randn([numEmbeddings, embeddingDim], { device });
+ }
+ forward(input) {
+ return this.weight.index(input);
+ }
+ }
  const state = {
  getParameters(model, visited = new WeakSet()) {
  if (visited.has(model))
@@ -256,5 +265,6 @@ exports.nn = {
  GRUCell,
  LSTMCell,
  LayerNorm,
+ Embedding,
  state
  };
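
nn.Embedding is a thin lookup table in the torch.nn.Embedding style: a randn-initialized [numEmbeddings, embeddingDim] weight whose forward() is simply weight.index(input), so gradients reach the selected rows through the index() backward pass added in core.js. A hedged usage sketch; the require path and the "cpu" device string are assumptions, only nn.Embedding, forward(), and Tensor.index come from the diff.

```js
// Hypothetical nn.Embedding usage; require path and device string are assumed.
const { Tensor, nn } = require("catniff");

const emb = new nn.Embedding(10, 4, "cpu"); // 10 rows, 4-dim embeddings
const ids = new Tensor([3, 1, 3]);          // token ids (duplicates allowed)
const vecs = emb.forward(ids);              // shape [3, 4]: rows 3, 1, 3 of emb.weight
```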
package/package.json CHANGED
@@ -1,6 +1,6 @@
  {
  "name": "catniff",
- "version": "0.6.0",
+ "version": "0.6.2",
  "description": "A small Torch-like deep learning framework for Javascript",
  "main": "index.js",
  "scripts": {