catniff 0.5.11 → 0.6.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (3)
  1. package/dist/core.d.ts +36 -12
  2. package/dist/core.js +367 -504
  3. package/package.json +1 -1
package/dist/core.d.ts CHANGED
@@ -3,6 +3,8 @@ export type TensorValue = number | TensorValue[];
3
3
  export interface TensorOptions {
4
4
  shape?: readonly number[];
5
5
  strides?: readonly number[];
6
+ offset?: number;
7
+ numel?: number;
6
8
  grad?: Tensor;
7
9
  requiresGrad?: boolean;
8
10
  gradFn?: Function;
@@ -13,6 +15,8 @@ export declare class Tensor {
13
15
  value: number[] | number;
14
16
  readonly shape: readonly number[];
15
17
  readonly strides: readonly number[];
18
+ offset: number;
19
+ numel: number;
16
20
  grad?: Tensor;
17
21
  requiresGrad: boolean;
18
22
  gradFn: Function;
@@ -40,11 +44,36 @@ export declare class Tensor {
40
44
  elementWiseSelfDAG(op: (a: number) => number, thisGrad?: (self: Tensor, outGrad: Tensor) => Tensor): Tensor;
41
45
  handleOther(other: Tensor | TensorValue): Tensor;
42
46
  static addGrad(tensor: Tensor, accumGrad: Tensor): void;
47
+ static normalizeDims(dims: number[], numDims: number): number[];
43
48
  isContiguous(): boolean;
44
49
  contiguous(): Tensor;
45
50
  reshape(newShape: readonly number[]): Tensor;
51
+ transpose(dim1: number, dim2: number): Tensor;
52
+ swapaxes: (dim1: number, dim2: number) => Tensor;
53
+ swapdims: (dim1: number, dim2: number) => Tensor;
54
+ t(): Tensor;
55
+ permute(dims: number[]): Tensor;
56
+ slice(ranges: number[][]): Tensor;
46
57
  squeeze(dims?: number[] | number): Tensor;
47
58
  unsqueeze(dim: number): Tensor;
59
+ static reduce(tensor: Tensor, dims: number[] | number | undefined, keepDims: boolean, config: {
60
+ identity: number;
61
+ operation: (accumulator: number, value: number) => number;
62
+ needsCounters?: boolean;
63
+ postProcess?: (options: {
64
+ values: number[];
65
+ counters?: number[];
66
+ }) => void;
67
+ needsShareCounts?: boolean;
68
+ gradientFn: (options: {
69
+ outputValue: number[];
70
+ originalValue: number[];
71
+ counters: number[];
72
+ shareCounts: number[];
73
+ realIndex: number;
74
+ outIndex: number;
75
+ }) => number;
76
+ }): Tensor;
48
77
  sum(dims?: number[] | number, keepDims?: boolean): Tensor;
49
78
  prod(dims?: number[] | number, keepDims?: boolean): Tensor;
50
79
  mean(dims?: number[] | number, keepDims?: boolean): Tensor;
@@ -54,7 +83,7 @@ export declare class Tensor {
54
83
  any(dims?: number[] | number, keepDims?: boolean): Tensor;
55
84
  var(dims?: number[] | number, keepDims?: boolean): Tensor;
56
85
  std(dims?: number[] | number, keepDims?: boolean): Tensor;
57
- softmax(dims?: number[] | number): Tensor;
86
+ softmax(dim?: number): Tensor;
58
87
  add(other: TensorValue | Tensor): Tensor;
59
88
  sub(other: TensorValue | Tensor): Tensor;
60
89
  subtract: (other: TensorValue | Tensor) => Tensor;
@@ -144,28 +173,23 @@ export declare class Tensor {
144
173
  erf(): Tensor;
145
174
  erfc(): Tensor;
146
175
  erfinv(): Tensor;
147
- transpose(dim1: number, dim2: number): Tensor;
148
- swapaxes: (dim1: number, dim2: number) => Tensor;
149
- swapdims: (dim1: number, dim2: number) => Tensor;
150
- t(): Tensor;
151
- permute(dims: number[]): Tensor;
152
176
  dot(other: TensorValue | Tensor): Tensor;
153
177
  mm(other: TensorValue | Tensor): Tensor;
154
178
  bmm(other: TensorValue | Tensor): Tensor;
155
179
  mv(other: TensorValue | Tensor): Tensor;
156
180
  matmul(other: TensorValue | Tensor): Tensor;
157
181
  dropout(rate: number): Tensor;
158
- static full(shape: number[], num: number, options?: TensorOptions): Tensor;
182
+ static full(shape: readonly number[], num: number, options?: TensorOptions): Tensor;
159
183
  static fullLike(tensor: Tensor, num: number, options?: TensorOptions): Tensor;
160
- static ones(shape?: number[], options?: TensorOptions): Tensor;
184
+ static ones(shape?: readonly number[], options?: TensorOptions): Tensor;
161
185
  static onesLike(tensor: Tensor, options?: TensorOptions): Tensor;
162
- static zeros(shape?: number[], options?: TensorOptions): Tensor;
186
+ static zeros(shape?: readonly number[], options?: TensorOptions): Tensor;
163
187
  static zerosLike(tensor: Tensor, options?: TensorOptions): Tensor;
164
- static rand(shape?: number[], options?: TensorOptions): Tensor;
188
+ static rand(shape?: readonly number[], options?: TensorOptions): Tensor;
165
189
  static randLike(tensor: Tensor, options?: TensorOptions): Tensor;
166
- static randn(shape?: number[], options?: TensorOptions): Tensor;
190
+ static randn(shape?: readonly number[], options?: TensorOptions): Tensor;
167
191
  static randnLike(tensor: Tensor, options?: TensorOptions): Tensor;
168
- static randint(shape: number[], low: number, high: number, options?: TensorOptions): Tensor;
192
+ static randint(shape: readonly number[], low: number, high: number, options?: TensorOptions): Tensor;
169
193
  static randintLike(tensor: Tensor, low: number, high: number, options?: TensorOptions): Tensor;
170
194
  static normal(shape: number[], mean: number, stdDev: number, options?: TensorOptions): Tensor;
171
195
  static uniform(shape: number[], low: number, high: number, options?: TensorOptions): Tensor;
package/dist/core.js CHANGED
@@ -6,6 +6,8 @@ class Tensor {
6
6
  value;
7
7
  shape;
8
8
  strides;
9
+ offset;
10
+ numel;
9
11
  grad;
10
12
  requiresGrad;
11
13
  gradFn;
@@ -13,14 +15,19 @@ class Tensor {
13
15
  device;
14
16
  static training = false;
15
17
  constructor(value, options = {}) {
18
+ // Storage
16
19
  this.value = Tensor.flatten(value);
20
+ // Tensor metadata
17
21
  this.shape = options.shape || Tensor.getShape(value);
18
22
  this.strides = options.strides || Tensor.getStrides(this.shape);
23
+ this.offset = options.offset || 0;
24
+ this.numel = options.numel || Tensor.shapeToSize(this.shape);
25
+ this.device = options.device || "cpu";
26
+ // Autograd data
19
27
  this.grad = options.grad;
20
28
  this.requiresGrad = options.requiresGrad ?? false;
21
29
  this.gradFn = options.gradFn || (() => { });
22
30
  this.children = options.children || [];
23
- this.device = options.device || "cpu";
24
31
  // Move to device in-place
25
32
  this.to_(this.device);
26
33
  }
@@ -164,22 +171,28 @@ class Tensor {
164
171
  // Convert the coordinates to 1D index of flattened B with respect to B's shape
165
172
  const indexB = Tensor.coordsToUnbroadcastedIndex(coordsOutput, paddedBShape, paddedBStrides);
166
173
  // Calculate with op
167
- outputValue[i] = op(tA.value[indexA], tB.value[indexB]);
174
+ outputValue[i] = op(tA.value[indexA + tA.offset], tB.value[indexB + tB.offset]);
168
175
  }
169
176
  return new Tensor(outputValue, {
170
177
  shape: outputShape,
171
- strides: outputStrides
178
+ strides: outputStrides,
179
+ numel: outputSize
172
180
  });
173
181
  }
174
182
  // Utility for self-inflicting element-wise ops
175
183
  static elementWiseSelf(tA, op) {
176
184
  if (typeof tA.value === "number")
177
185
  return new Tensor(op(tA.value));
178
- const newValue = new Array(tA.value.length);
179
- for (let index = 0; index < tA.value.length; index++) {
180
- newValue[index] = op(tA.value[index]);
186
+ const outputShape = tA.shape;
187
+ const outputStrides = Tensor.getStrides(outputShape);
188
+ const outputSize = tA.numel;
189
+ const outputValue = new Array(outputSize);
190
+ for (let index = 0; index < outputSize; index++) {
191
+ const outputCoords = Tensor.indexToCoords(index, outputStrides);
192
+ const originalIndex = tA.offset + Tensor.coordsToIndex(outputCoords, tA.strides);
193
+ outputValue[index] = op(tA.value[originalIndex]);
181
194
  }
182
- return new Tensor(newValue, { shape: tA.shape, strides: tA.strides });
195
+ return new Tensor(outputValue, { shape: outputShape, strides: outputStrides, numel: tA.numel });
183
196
  }
184
197
  // Utility to do element-wise operation and build a dag node with another tensor
185
198
  elementWiseABDAG(other, op, thisGrad = () => new Tensor(0), otherGrad = () => new Tensor(0)) {
@@ -260,6 +273,19 @@ class Tensor {
260
273
  tensor.grad = tensor.grad.add(squeezedGrad);
261
274
  }
262
275
  }
276
+ static normalizeDims(dims, numDims) {
277
+ for (let index = 0; index < dims.length; index++) {
278
+ // Handle negative indices
279
+ if (dims[index] < 0) {
280
+ dims[index] += numDims;
281
+ }
282
+ // If dimension out of bound, throw error
283
+ if (dims[index] >= numDims || dims[index] < 0) {
284
+ throw new Error("Dimensions do not exist");
285
+ }
286
+ }
287
+ return dims;
288
+ }
263
289
  // Contiguity-related ops
264
290
  isContiguous() {
265
291
  const expectedStrides = Tensor.getStrides(this.shape);
@@ -281,14 +307,14 @@ class Tensor {
281
307
  if (this.isContiguous())
282
308
  return this;
283
309
  const outputStrides = Tensor.getStrides(this.shape);
284
- const outputSize = Tensor.shapeToSize(this.shape);
310
+ const outputSize = this.numel;
285
311
  const outputValue = new Array(outputSize);
286
312
  for (let index = 0; index < outputSize; index++) {
287
313
  const outputCoords = Tensor.indexToCoords(index, outputStrides);
288
314
  const originalIndex = Tensor.coordsToIndex(outputCoords, this.strides);
289
- outputValue[index] = this.value[originalIndex];
315
+ outputValue[index] = this.value[this.offset + originalIndex];
290
316
  }
291
- const out = new Tensor(outputValue, { shape: this.shape, strides: outputStrides });
317
+ const out = new Tensor(outputValue, { shape: this.shape, strides: outputStrides, numel: outputSize });
292
318
  // Gradient flow back to the original tensor
293
319
  if (this.requiresGrad) {
294
320
  out.requiresGrad = true;
@@ -301,13 +327,13 @@ class Tensor {
301
327
  }
302
328
  reshape(newShape) {
303
329
  // Verify shape size
304
- const originalSize = Tensor.shapeToSize(this.shape);
330
+ const originalSize = this.numel;
305
331
  const outputSize = Tensor.shapeToSize(newShape);
306
332
  if (originalSize !== outputSize) {
307
333
  throw new Error("Cannot reshape: incompatible sizes");
308
334
  }
309
335
  const outputStrides = Tensor.getStrides(newShape);
310
- const out = new Tensor(this.contiguous().value, { shape: newShape, strides: outputStrides });
336
+ const out = new Tensor(this.contiguous().value, { shape: newShape, strides: outputStrides, numel: outputSize });
311
337
  // Gradient reshaped and flow back to the original tensor
312
338
  if (this.requiresGrad) {
313
339
  out.requiresGrad = true;
@@ -318,6 +344,168 @@ class Tensor {
318
344
  }
319
345
  return out;
320
346
  }
347
+ // Transpose
348
+ transpose(dim1, dim2) {
349
+ // Handle negative indices
350
+ if (dim1 < 0) {
351
+ dim1 += this.shape.length;
352
+ }
353
+ if (dim2 < 0) {
354
+ dim2 += this.shape.length;
355
+ }
356
+ // If dimension out of bound, throw error
357
+ if (dim1 >= this.shape.length || dim2 >= this.shape.length || dim1 < 0 || dim2 < 0) {
358
+ throw new Error("Dimensions do not exist to transpose");
359
+ }
360
+ // If same dimension, return view
361
+ if (dim1 === dim2)
362
+ return this;
363
+ // Create new shape and strides by swapping
364
+ const newShape = [...this.shape];
365
+ const newStrides = [...this.strides];
366
+ [newShape[dim1], newShape[dim2]] = [newShape[dim2], newShape[dim1]];
367
+ [newStrides[dim1], newStrides[dim2]] = [newStrides[dim2], newStrides[dim1]];
368
+ // Create new tensor with same data but swapped shape/strides
369
+ const out = new Tensor(this.value, {
370
+ shape: newShape,
371
+ strides: newStrides,
372
+ offset: this.offset,
373
+ numel: this.numel,
374
+ device: this.device
375
+ });
376
+ out.requiresGrad = this.requiresGrad;
377
+ // Handle gradient if needed
378
+ if (this.requiresGrad) {
379
+ out.children.push(this);
380
+ out.gradFn = () => {
381
+ Tensor.addGrad(this, out.grad.transpose(dim1, dim2));
382
+ };
383
+ }
384
+ return out;
385
+ }
386
+ swapaxes = this.transpose;
387
+ swapdims = this.transpose;
388
+ // Transpose 2D
389
+ t() {
390
+ // Verify matrix shape
391
+ if (this.shape.length !== 2) {
392
+ throw new Error("Input is not a matrix");
393
+ }
394
+ return this.transpose(0, 1);
395
+ }
396
+ // Permute
397
+ permute(dims) {
398
+ dims = Tensor.normalizeDims(dims, this.shape.length);
399
+ if (dims.length !== this.shape.length) {
400
+ throw new Error("Permutation must specify all dimensions");
401
+ }
402
+ // Compute new shape and strides
403
+ const newShape = new Array(dims.length);
404
+ const newStrides = new Array(dims.length);
405
+ for (let index = 0; index < dims.length; index++) {
406
+ const dim = dims[index];
407
+ newShape[index] = this.shape[dim];
408
+ newStrides[index] = this.strides[dim];
409
+ }
410
+ const out = new Tensor(this.value, {
411
+ shape: newShape,
412
+ strides: newStrides,
413
+ offset: this.offset,
414
+ numel: this.numel,
415
+ device: this.device
416
+ });
417
+ if (this.requiresGrad) {
418
+ out.requiresGrad = true;
419
+ out.children.push(this);
420
+ out.gradFn = () => {
421
+ // Compute inverse permutation
422
+ const inverseAxes = new Array(dims.length);
423
+ for (let i = 0; i < dims.length; i++) {
424
+ inverseAxes[dims[i]] = i;
425
+ }
426
+ // Permute gradient back to original order
427
+ const permutedGrad = out.grad.permute(inverseAxes);
428
+ Tensor.addGrad(this, permutedGrad);
429
+ };
430
+ }
431
+ return out;
432
+ }
433
+ // Tensor slicing
434
+ slice(ranges) {
435
+ // Handle scalars
436
+ if (typeof this.value === "number")
437
+ return this;
438
+ const newShape = [];
439
+ const newStrides = [];
440
+ let newOffset = this.offset || 0;
441
+ // Pad ranges to match tensor dimensions
442
+ const paddedRanges = [...ranges];
443
+ while (paddedRanges.length < this.shape.length) {
444
+ paddedRanges.push([]);
445
+ }
446
+ for (let i = 0; i < this.shape.length; i++) {
447
+ const range = paddedRanges[i] || [];
448
+ const dimSize = this.shape[i];
449
+ const stride = this.strides[i];
450
+ // Default values
451
+ let start = range[0] ?? 0;
452
+ let end = range[1] ?? dimSize;
453
+ let step = range[2] ?? 1;
454
+ // Handle negative indices
455
+ if (start < 0)
456
+ start += dimSize;
457
+ if (end < 0)
458
+ end += dimSize;
459
+ // Clamp to valid range
460
+ start = Math.max(0, Math.min(start, dimSize));
461
+ end = Math.max(0, Math.min(end, dimSize));
462
+ // Calculate new dimension size
463
+ const newDimSize = step > 0
464
+ ? Math.max(0, Math.ceil((end - start) / step))
465
+ : Math.max(0, Math.ceil((start - end) / Math.abs(step)));
466
+ newShape.push(newDimSize);
467
+ newStrides.push(stride * step);
468
+ newOffset += start * stride;
469
+ }
470
+ const out = new Tensor(this.value, {
471
+ shape: newShape,
472
+ strides: newStrides,
473
+ offset: newOffset,
474
+ device: this.device
475
+ });
476
+ if (this.requiresGrad) {
477
+ out.requiresGrad = true;
478
+ out.children.push(this);
479
+ out.gradFn = () => {
480
+ // Create zero tensor of original shape
481
+ const zeroGrad = Tensor.zerosLike(this);
482
+ // Upstream grad
483
+ const outGrad = out.grad;
484
+ const totalElements = outGrad.numel;
485
+ for (let i = 0; i < totalElements; i++) {
486
+ // Convert flat index to coordinates in sliced tensor
487
+ const slicedCoords = Tensor.indexToCoords(i, outGrad.strides);
488
+ // Map back to original coordinates
489
+ const originalCoords = new Array(slicedCoords.length);
490
+ for (let dim = 0; dim < slicedCoords.length; dim++) {
491
+ const coord = slicedCoords[dim];
492
+ const range = paddedRanges[dim] || [];
493
+ const start = range[0] ?? 0;
494
+ const step = range[2] ?? 1;
495
+ const normalizedStart = start < 0 ? start + this.shape[dim] : start;
496
+ originalCoords[dim] = normalizedStart + coord * step;
497
+ }
498
+ // Get flat indices with offsets
499
+ const srcIndex = Tensor.coordsToIndex(slicedCoords, outGrad.strides) + outGrad.offset;
500
+ const targetIndex = Tensor.coordsToIndex(originalCoords, zeroGrad.strides) + zeroGrad.offset;
501
+ // Accumulate gradient
502
+ zeroGrad.value[targetIndex] += outGrad.value[srcIndex];
503
+ }
504
+ Tensor.addGrad(this, zeroGrad);
505
+ };
506
+ }
507
+ return out;
508
+ }
321
509
  // Tensor squeeze
322
510
  squeeze(dims) {
323
511
  if (typeof this.value === "number")
@@ -334,6 +522,7 @@ class Tensor {
334
522
  }
335
523
  }
336
524
  }
525
+ dims = Tensor.normalizeDims(dims, this.shape.length);
337
526
  // Remove size-1 dims only
338
527
  const outShape = [], outStrides = [];
339
528
  for (let index = 0; index < this.shape.length; index++) {
@@ -348,10 +537,11 @@ class Tensor {
348
537
  outStrides.push(stride);
349
538
  }
350
539
  }
351
- const outValue = outShape.length === 0 ? this.value[0] : this.value;
540
+ const outValue = outShape.length === 0 ? this.value[this.offset] : this.value;
352
541
  const out = new Tensor(outValue, {
353
542
  shape: outShape,
354
543
  strides: outStrides,
544
+ offset: this.offset,
355
545
  device: this.device
356
546
  });
357
547
  // Set up gradient if needed
@@ -370,6 +560,10 @@ class Tensor {
370
560
  }
371
561
  // Tensor unsqueeze - adds dimension of size 1 at specified position
372
562
  unsqueeze(dim) {
563
+ // Handle negative indices
564
+ if (dim < 0) {
565
+ dim += this.shape.length;
566
+ }
373
567
  let thisValue = this.value;
374
568
  if (typeof thisValue === "number") {
375
569
  thisValue = [thisValue];
@@ -389,7 +583,12 @@ class Tensor {
389
583
  newDimStride = this.strides[dim] * this.shape[dim];
390
584
  }
391
585
  newStrides.splice(dim, 0, newDimStride);
392
- const out = new Tensor(thisValue, { shape: newShape, strides: newStrides, device: this.device });
586
+ const out = new Tensor(thisValue, {
587
+ shape: newShape,
588
+ strides: newStrides,
589
+ offset: this.offset,
590
+ device: this.device
591
+ });
393
592
  // Set up gradient if needed
394
593
  if (this.requiresGrad) {
395
594
  out.requiresGrad = true;
@@ -400,325 +599,138 @@ class Tensor {
400
599
  }
401
600
  return out;
402
601
  }
403
- // Tensor sum reduction
404
- sum(dims, keepDims = false) {
405
- if (typeof this.value === "number")
406
- return this;
602
+ // Generic reduction operation handler
603
+ static reduce(tensor, dims, keepDims, config) {
604
+ if (typeof tensor.value === "number")
605
+ return tensor;
407
606
  if (typeof dims === "undefined") {
408
- dims = Array.from({ length: this.shape.length }, (_, index) => index);
607
+ dims = Array.from({ length: tensor.shape.length }, (_, index) => index);
409
608
  }
410
609
  if (Array.isArray(dims)) {
411
- // Sort in descending order
610
+ dims = Tensor.normalizeDims(dims, tensor.shape.length);
412
611
  const sortedDims = dims.sort((a, b) => b - a);
413
- let reducedThis = this;
612
+ let reducedThis = tensor;
414
613
  for (let i = 0; i < sortedDims.length; i++) {
415
- reducedThis = reducedThis.sum(sortedDims[i], true);
614
+ reducedThis = Tensor.reduce(reducedThis, sortedDims[i], true, config);
416
615
  }
417
616
  return keepDims ? reducedThis : reducedThis.squeeze(dims);
418
617
  }
419
- // Dims that are reduced now have size-1
420
- const outputShape = this.shape.map((dim, i) => dims === i ? 1 : dim);
618
+ const outputShape = tensor.shape.map((dim, i) => dims === i ? 1 : dim);
421
619
  const outputStrides = Tensor.getStrides(outputShape);
422
620
  const outputSize = Tensor.shapeToSize(outputShape);
423
- const outputValue = new Array(outputSize).fill(0);
424
- const originalSize = Tensor.shapeToSize(this.shape);
425
- // Gradient data
426
- let gradShape, gradStrides, gradValue = [];
427
- // Allocate gradient data only when needed
428
- if (this.requiresGrad) {
429
- gradShape = this.shape;
430
- gradStrides = this.strides;
431
- gradValue = new Array(originalSize).fill(0);
432
- }
433
- // Calculate new value after sum
434
- for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
435
- // Force 0 on reduced axes to collapse into size-1 dims
436
- const outCoords = Tensor.indexToCoords(realFlatIndex, this.strides);
437
- outCoords[dims] = 0;
438
- // Convert output coordinates to flat index
439
- const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
440
- // Add into sum
441
- outputValue[outFlatIndex] += this.value[realFlatIndex];
442
- // Mark for gradient if needed
443
- if (this.requiresGrad) {
444
- gradValue[realFlatIndex] = 1;
621
+ const outputValue = new Array(outputSize).fill(config.identity);
622
+ const outputCounters = config.needsCounters ? new Array(outputSize).fill(0) : [];
623
+ const originalSize = tensor.numel;
624
+ const originalValue = tensor.value;
625
+ const linearStrides = Tensor.getStrides(tensor.shape);
626
+ // Forward pass
627
+ for (let flatIndex = 0; flatIndex < originalSize; flatIndex++) {
628
+ // Convert linear index to coordinates using contiguous strides
629
+ const coords = Tensor.indexToCoords(flatIndex, linearStrides);
630
+ // Convert coordinates to actual strided index
631
+ const realFlatIndex = Tensor.coordsToIndex(coords, tensor.strides) + tensor.offset;
632
+ // Convert coords to reduced index
633
+ coords[dims] = 0;
634
+ const outFlatIndex = Tensor.coordsToIndex(coords, outputStrides);
635
+ // Apply op
636
+ outputValue[outFlatIndex] = config.operation(outputValue[outFlatIndex], originalValue[realFlatIndex]);
637
+ // Count el if needed
638
+ if (config.needsCounters) {
639
+ outputCounters[outFlatIndex]++;
445
640
  }
446
641
  }
447
- const out = new Tensor(outputValue, {
448
- shape: outputShape,
449
- strides: outputStrides
450
- });
451
- // Set up gradient if needed
452
- if (this.requiresGrad) {
642
+ // Post-process if needed (e.g., divide by count for mean)
643
+ if (config.postProcess) {
644
+ config.postProcess({ values: outputValue, counters: outputCounters });
645
+ }
646
+ const out = new Tensor(outputValue, { shape: outputShape, strides: outputStrides, numel: outputSize });
647
+ // Gradient setup
648
+ if (tensor.requiresGrad) {
453
649
  out.requiresGrad = true;
454
- out.children.push(this);
650
+ out.children.push(tensor);
455
651
  out.gradFn = () => {
456
- const localGrad = new Tensor(gradValue, { shape: gradShape, strides: gradStrides });
457
- Tensor.addGrad(this, out.grad.mul(localGrad));
652
+ let shareCounts = [];
653
+ if (config.needsShareCounts) {
654
+ shareCounts = new Array(outputSize).fill(0);
655
+ for (let flatIndex = 0; flatIndex < originalSize; flatIndex++) {
656
+ // Convert linear index to coordinates using contiguous strides
657
+ const coords = Tensor.indexToCoords(flatIndex, linearStrides);
658
+ // Convert coordinates to actual strided index
659
+ const realFlatIndex = Tensor.coordsToIndex(coords, tensor.strides) + tensor.offset;
660
+ // Convert coords to reduced index
661
+ coords[dims] = 0;
662
+ const outFlatIndex = Tensor.coordsToIndex(coords, outputStrides);
663
+ // We collect how many elements share the same max value first
664
+ shareCounts[outFlatIndex] += outputValue[outFlatIndex] === originalValue[realFlatIndex] ? 1 : 0;
665
+ }
666
+ }
667
+ const gradValue = new Array(originalSize);
668
+ for (let flatIndex = 0; flatIndex < originalSize; flatIndex++) {
669
+ // Convert linear index to coordinates using contiguous strides
670
+ const coords = Tensor.indexToCoords(flatIndex, linearStrides);
671
+ // Convert coordinates to actual strided index
672
+ const realFlatIndex = Tensor.coordsToIndex(coords, tensor.strides) + tensor.offset;
673
+ // Convert coords to reduced index
674
+ coords[dims] = 0;
675
+ const outFlatIndex = Tensor.coordsToIndex(coords, outputStrides);
676
+ gradValue[flatIndex] = config.gradientFn({
677
+ outputValue,
678
+ originalValue: tensor.value,
679
+ counters: outputCounters,
680
+ shareCounts,
681
+ realIndex: realFlatIndex,
682
+ outIndex: outFlatIndex
683
+ });
684
+ }
685
+ const localGrad = new Tensor(gradValue, { shape: tensor.shape, numel: tensor.numel });
686
+ Tensor.addGrad(tensor, out.grad.mul(localGrad));
458
687
  };
459
688
  }
460
689
  return keepDims ? out : out.squeeze(dims);
461
690
  }
462
- // Tensor product reduction
691
+ // Simplified reduction operations
692
+ sum(dims, keepDims = false) {
693
+ return Tensor.reduce(this, dims, keepDims, {
694
+ identity: 0,
695
+ operation: (a, b) => a + b,
696
+ gradientFn: ({}) => 1
697
+ });
698
+ }
463
699
  prod(dims, keepDims = false) {
464
- if (typeof this.value === "number")
465
- return this;
466
- if (typeof dims === "undefined") {
467
- dims = Array.from({ length: this.shape.length }, (_, index) => index);
468
- }
469
- if (Array.isArray(dims)) {
470
- // Sort in descending order
471
- const sortedDims = dims.sort((a, b) => b - a);
472
- let reducedThis = this;
473
- for (let i = 0; i < sortedDims.length; i++) {
474
- reducedThis = reducedThis.prod(sortedDims[i], true);
475
- }
476
- return keepDims ? reducedThis : reducedThis.squeeze(dims);
477
- }
478
- // Dims that are reduced now have size-1
479
- const outputShape = this.shape.map((dim, i) => dims === i ? 1 : dim);
480
- const outputStrides = Tensor.getStrides(outputShape);
481
- const outputSize = Tensor.shapeToSize(outputShape);
482
- const outputValue = new Array(outputSize).fill(1);
483
- const originalSize = Tensor.shapeToSize(this.shape);
484
- // Calculate new value after multiplying
485
- for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
486
- // Force 0 on reduced axes to collapse into size-1 dims
487
- const outCoords = Tensor.indexToCoords(realFlatIndex, this.strides);
488
- outCoords[dims] = 0;
489
- // Convert output coordinates to flat index
490
- const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
491
- // Multiply into product
492
- outputValue[outFlatIndex] *= this.value[realFlatIndex];
493
- }
494
- const out = new Tensor(outputValue, {
495
- shape: outputShape,
496
- strides: outputStrides
700
+ return Tensor.reduce(this, dims, keepDims, {
701
+ identity: 1,
702
+ operation: (a, b) => a * b,
703
+ gradientFn: ({ outputValue, originalValue, realIndex, outIndex }) => outputValue[outIndex] / originalValue[realIndex]
497
704
  });
498
- // Set up gradient if needed
499
- if (this.requiresGrad) {
500
- out.requiresGrad = true;
501
- out.children.push(this);
502
- out.gradFn = () => {
503
- const gradShape = this.shape, gradStrides = this.strides, gradValue = new Array(originalSize).fill(0);
504
- for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
505
- // Force 0 on reduced axes to collapse into size-1 dims
506
- const outCoords = Tensor.indexToCoords(realFlatIndex, this.strides);
507
- outCoords[dims] = 0;
508
- // Convert output coordinates to flat index
509
- const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
510
- // Grad is the product of other elements of the same axis, which is product of all els divided by the current value
511
- gradValue[realFlatIndex] = outputValue[outFlatIndex] / this.value[realFlatIndex];
512
- }
513
- const localGrad = new Tensor(gradValue, { shape: gradShape, strides: gradStrides });
514
- Tensor.addGrad(this, out.grad.mul(localGrad));
515
- };
516
- }
517
- return keepDims ? out : out.squeeze(dims);
518
705
  }
519
- // Tensor mean reduction
520
706
  mean(dims, keepDims = false) {
521
- if (typeof this.value === "number")
522
- return this;
523
- if (typeof dims === "undefined") {
524
- dims = Array.from({ length: this.shape.length }, (_, index) => index);
525
- }
526
- if (Array.isArray(dims)) {
527
- // Sort in descending order
528
- const sortedDims = dims.sort((a, b) => b - a);
529
- let reducedThis = this;
530
- for (let i = 0; i < sortedDims.length; i++) {
531
- reducedThis = reducedThis.mean(sortedDims[i], true);
532
- }
533
- return keepDims ? reducedThis : reducedThis.squeeze(dims);
534
- }
535
- // Dims that are reduced now have size-1
536
- const outputShape = this.shape.map((dim, i) => dims === i ? 1 : dim);
537
- const outputStrides = Tensor.getStrides(outputShape);
538
- const outputSize = Tensor.shapeToSize(outputShape);
539
- const outputValue = new Array(outputSize).fill(0);
540
- const outputFeeders = new Array(outputSize).fill(0);
541
- const originalSize = Tensor.shapeToSize(this.shape);
542
- // Calculate sums and how many elements contribute to specific positions
543
- for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
544
- // Force 0 on reduced axes to collapse into size-1 dims
545
- const outCoords = Tensor.indexToCoords(realFlatIndex, this.strides);
546
- outCoords[dims] = 0;
547
- // Convert output coordinates to flat index
548
- const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
549
- // Calculate sum and contributors to the sum
550
- outputValue[outFlatIndex] += this.value[realFlatIndex];
551
- outputFeeders[outFlatIndex]++;
552
- }
553
- // Calculate mean by dividing sum by the number of contributors to the position
554
- for (let index = 0; index < outputSize; index++) {
555
- outputValue[index] /= outputFeeders[index];
556
- }
557
- const out = new Tensor(outputValue, {
558
- shape: outputShape,
559
- strides: outputStrides
560
- });
561
- // Set up gradient if needed
562
- if (this.requiresGrad) {
563
- out.requiresGrad = true;
564
- out.children.push(this);
565
- out.gradFn = () => {
566
- const gradShape = this.shape, gradStrides = this.strides, gradValue = new Array(originalSize).fill(0);
567
- // Calculate grad by assigning 1 divided by the number of contributors to the position
568
- for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
569
- // Force 0 on reduced axes to collapse into size-1 dims
570
- const outCoords = Tensor.indexToCoords(realFlatIndex, this.strides);
571
- outCoords[dims] = 0;
572
- // Convert output coordinates to flat index
573
- const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
574
- // Mean = 1/n * (el1 + el2 + ... + eln) so grad = 1/n
575
- gradValue[realFlatIndex] = 1 / outputFeeders[outFlatIndex];
707
+ return Tensor.reduce(this, dims, keepDims, {
708
+ identity: 0,
709
+ operation: (a, b) => a + b,
710
+ needsCounters: true,
711
+ postProcess: ({ values, counters }) => {
712
+ for (let i = 0; i < values.length; i++) {
713
+ values[i] /= counters[i];
576
714
  }
577
- const localGrad = new Tensor(gradValue, { shape: gradShape, strides: gradStrides });
578
- Tensor.addGrad(this, out.grad.mul(localGrad));
579
- };
580
- }
581
- return keepDims ? out : out.squeeze(dims);
715
+ },
716
+ gradientFn: ({ counters, outIndex }) => 1 / counters[outIndex]
717
+ });
582
718
  }
583
- // Tensor maximum reduction
584
719
  max(dims, keepDims = false) {
585
- if (typeof this.value === "number")
586
- return this;
587
- if (typeof dims === "undefined") {
588
- dims = Array.from({ length: this.shape.length }, (_, index) => index);
589
- }
590
- if (Array.isArray(dims)) {
591
- // Sort in descending order
592
- const sortedDims = dims.sort((a, b) => b - a);
593
- let reducedThis = this;
594
- for (let i = 0; i < sortedDims.length; i++) {
595
- reducedThis = reducedThis.max(sortedDims[i], true);
596
- }
597
- return keepDims ? reducedThis : reducedThis.squeeze(dims);
598
- }
599
- // Dims that are reduced now have size-1
600
- const outputShape = this.shape.map((dim, i) => dims === i ? 1 : dim);
601
- const outputStrides = Tensor.getStrides(outputShape);
602
- const outputSize = Tensor.shapeToSize(outputShape);
603
- const outputValue = new Array(outputSize).fill(-Infinity);
604
- const originalSize = Tensor.shapeToSize(this.shape);
605
- // Calculate maximum values of axes
606
- for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
607
- // Force 0 on reduced axes to collapse into size-1 dims
608
- const outCoords = Tensor.indexToCoords(realFlatIndex, this.strides);
609
- outCoords[dims] = 0;
610
- // Convert output coordinates to flat index
611
- const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
612
- // Get max over time
613
- if (this.value[realFlatIndex] > outputValue[outFlatIndex]) {
614
- outputValue[outFlatIndex] = this.value[realFlatIndex];
615
- }
616
- }
617
- const out = new Tensor(outputValue, {
618
- shape: outputShape,
619
- strides: outputStrides
720
+ return Tensor.reduce(this, dims, keepDims, {
721
+ identity: -Infinity,
722
+ operation: (a, b) => Math.max(a, b),
723
+ needsShareCounts: true,
724
+ gradientFn: ({ outputValue, originalValue, shareCounts, realIndex, outIndex }) => outputValue[outIndex] === originalValue[realIndex] ? 1 / shareCounts[outIndex] : 0
620
725
  });
621
- // Set up gradient if needed
622
- if (this.requiresGrad) {
623
- out.requiresGrad = true;
624
- out.children.push(this);
625
- out.gradFn = () => {
626
- const gradShape = this.shape, gradStrides = this.strides, gradValue = new Array(originalSize).fill(0);
627
- const shareCounts = new Array(outputSize).fill(0);
628
- const originalValue = this.value;
629
- for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
630
- // Force 0 on reduced axes to collapse into size-1 dims
631
- const outCoords = Tensor.indexToCoords(realFlatIndex, this.strides);
632
- outCoords[dims] = 0;
633
- // Convert output coordinates to flat index
634
- const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
635
- // We collect how many elements share the same max value first
636
- shareCounts[outFlatIndex] += outputValue[outFlatIndex] === originalValue[realFlatIndex] ? 1 : 0;
637
- }
638
- for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
639
- // Force 0 on reduced axes to collapse into size-1 dims
640
- const outCoords = Tensor.indexToCoords(realFlatIndex, this.strides);
641
- outCoords[dims] = 0;
642
- // Convert output coordinates to flat index
643
- const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
644
- // Here we share the grad between the elements that share the same max value
645
- gradValue[realFlatIndex] = outputValue[outFlatIndex] === originalValue[realFlatIndex] ? 1 / shareCounts[outFlatIndex] : 0;
646
- }
647
- const localGrad = new Tensor(gradValue, { shape: gradShape, strides: gradStrides });
648
- Tensor.addGrad(this, out.grad.mul(localGrad));
649
- };
650
- }
651
- return keepDims ? out : out.squeeze(dims);
652
726
  }
653
- // Tensor minimum reduction
654
727
  min(dims, keepDims = false) {
655
- if (typeof this.value === "number")
656
- return this;
657
- if (typeof dims === "undefined") {
658
- dims = Array.from({ length: this.shape.length }, (_, index) => index);
659
- }
660
- if (Array.isArray(dims)) {
661
- // Sort in descending order
662
- const sortedDims = dims.sort((a, b) => b - a);
663
- let reducedThis = this;
664
- for (let i = 0; i < sortedDims.length; i++) {
665
- reducedThis = reducedThis.min(sortedDims[i], true);
666
- }
667
- return keepDims ? reducedThis : reducedThis.squeeze(dims);
668
- }
669
- // Dims that are reduced now have size-1
670
- const outputShape = this.shape.map((dim, i) => dims === i ? 1 : dim);
671
- const outputStrides = Tensor.getStrides(outputShape);
672
- const outputSize = Tensor.shapeToSize(outputShape);
673
- const outputValue = new Array(outputSize).fill(Infinity);
674
- const originalSize = Tensor.shapeToSize(this.shape);
675
- // Calculate minimum values of axes
676
- for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
677
- // Force 0 on reduced axes to collapse into size-1 dims
678
- const outCoords = Tensor.indexToCoords(realFlatIndex, this.strides);
679
- outCoords[dims] = 0;
680
- // Convert output coordinates to flat index
681
- const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
682
- // Get min over time
683
- if (this.value[realFlatIndex] < outputValue[outFlatIndex]) {
684
- outputValue[outFlatIndex] = this.value[realFlatIndex];
685
- }
686
- }
687
- const out = new Tensor(outputValue, {
688
- shape: outputShape,
689
- strides: outputStrides
728
+ return Tensor.reduce(this, dims, keepDims, {
729
+ identity: Infinity,
730
+ operation: (a, b) => Math.min(a, b),
731
+ needsShareCounts: true,
732
+ gradientFn: ({ outputValue, originalValue, shareCounts, realIndex, outIndex }) => outputValue[outIndex] === originalValue[realIndex] ? 1 / shareCounts[outIndex] : 0
690
733
  });
691
- // Set up gradient if needed
692
- if (this.requiresGrad) {
693
- out.requiresGrad = true;
694
- out.children.push(this);
695
- out.gradFn = () => {
696
- const gradShape = this.shape, gradStrides = this.strides, gradValue = new Array(originalSize).fill(0);
697
- const shareCounts = new Array(outputSize).fill(0);
698
- const originalValue = this.value;
699
- for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
700
- // Force 0 on reduced axes to collapse into size-1 dims
701
- const outCoords = Tensor.indexToCoords(realFlatIndex, this.strides);
702
- outCoords[dims] = 0;
703
- // Convert output coordinates to flat index
704
- const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
705
- // We collect how many elements share the same min value first
706
- shareCounts[outFlatIndex] += outputValue[outFlatIndex] === originalValue[realFlatIndex] ? 1 : 0;
707
- }
708
- for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
709
- // Force 0 on reduced axes to collapse into size-1 dims
710
- const outCoords = Tensor.indexToCoords(realFlatIndex, this.strides);
711
- outCoords[dims] = 0;
712
- // Convert output coordinates to flat index
713
- const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
714
- // Here we share the grad between the elements that share the same min value
715
- gradValue[realFlatIndex] = outputValue[outFlatIndex] === originalValue[realFlatIndex] ? 1 / shareCounts[outFlatIndex] : 0;
716
- }
717
- const localGrad = new Tensor(gradValue, { shape: gradShape, strides: gradStrides });
718
- Tensor.addGrad(this, out.grad.mul(localGrad));
719
- };
720
- }
721
- return keepDims ? out : out.squeeze(dims);
722
734
  }
723
735
  // Tensor all condition reduction
724
736
  all(dims, keepDims = false) {
@@ -738,75 +750,18 @@ class Tensor {
738
750
  std(dims, keepDims = false) {
739
751
  return this.var(dims, keepDims).sqrt();
740
752
  }
741
- // Tensor product reduction
742
- softmax(dims) {
753
+ // Tensor softmax
754
+ softmax(dim = -1) {
743
755
  if (typeof this.value === "number")
744
756
  return this;
745
- if (typeof dims === "undefined") {
746
- dims = Array.from({ length: this.shape.length }, (_, index) => index);
747
- }
748
- if (Array.isArray(dims)) {
749
- // Sort in descending order
750
- const sortedDims = dims.sort((a, b) => b - a);
751
- let reducedThis = this;
752
- for (let i = 0; i < sortedDims.length; i++) {
753
- reducedThis = reducedThis.softmax(sortedDims[i]);
754
- }
755
- return reducedThis;
756
- }
757
- // Dims that are reduced now have size-1
758
- const expSumShape = this.shape.map((dim, i) => dims === i ? 1 : dim);
759
- const expSumStrides = Tensor.getStrides(expSumShape);
760
- const expSumSize = Tensor.shapeToSize(expSumShape);
761
- const expSumValue = new Array(expSumSize).fill(0);
762
- const outputShape = this.shape;
763
- const outputStrides = this.strides;
764
- const outputSize = Tensor.shapeToSize(outputShape);
765
- const outputValue = new Array(outputSize);
766
- // Calculate sums of e^xi over axes
767
- for (let realFlatIndex = 0; realFlatIndex < outputSize; realFlatIndex++) {
768
- const coords = Tensor.indexToCoords(realFlatIndex, outputStrides);
769
- // Force 0 on reduced axes to collapse into size-1 dims
770
- const expSumCoords = coords.map((val, i) => dims === i ? 0 : val);
771
- // Convert exp sum coordinates to flat index
772
- const expSumFlatIndex = Tensor.coordsToIndex(expSumCoords, expSumStrides);
773
- // Add e^x to the sum cache
774
- expSumValue[expSumFlatIndex] += Math.exp(this.value[realFlatIndex]);
775
- }
776
- // Calculate e^xi / sum over axes
777
- for (let realFlatIndex = 0; realFlatIndex < outputSize; realFlatIndex++) {
778
- const coords = Tensor.indexToCoords(realFlatIndex, outputStrides);
779
- // Force 0 on reduced axes to collapse into size-1 dims
780
- const expSumCoords = coords.map((val, i) => dims === i ? 0 : val);
781
- // Convert exp sum coordinates to flat index
782
- const expSumFlatIndex = Tensor.coordsToIndex(expSumCoords, expSumStrides);
783
- // Calculate e^xi / sum
784
- outputValue[realFlatIndex] = Math.exp(this.value[realFlatIndex]) / expSumValue[expSumFlatIndex];
785
- }
786
- const out = new Tensor(outputValue, {
787
- shape: outputShape,
788
- strides: outputStrides
789
- });
790
- // Set up gradient if needed
791
- if (this.requiresGrad) {
792
- out.requiresGrad = true;
793
- out.children.push(this);
794
- out.gradFn = () => {
795
- const upstreamGrad = out.grad;
796
- const softmaxOutput = out.detach();
797
- // Compute element-wise product: ∂L/∂σᵢ × σᵢ
798
- const gradTimesOutput = upstreamGrad.mul(softmaxOutput);
799
- // Sum over softmax dimensions: Σᵢ(∂L/∂σᵢ × σᵢ)
800
- const sumGradOutput = gradTimesOutput.sum(dims, true); // keepDims=true for broadcasting
801
- // Apply softmax gradient formula:
802
- // ∂L/∂zⱼ = (∂L/∂σⱼ × σⱼ) - (σⱼ × Σᵢ(∂L/∂σᵢ × σᵢ))
803
- const term1 = upstreamGrad.mul(softmaxOutput); // ∂L/∂σⱼ × σⱼ
804
- const term2 = softmaxOutput.mul(sumGradOutput); // σⱼ × Σᵢ(∂L/∂σᵢ × σᵢ)
805
- const localGrad = term1.sub(term2);
806
- Tensor.addGrad(this, localGrad);
807
- };
808
- }
809
- return out;
757
+ // Handle negative indexing
758
+ if (dim < 0)
759
+ dim = this.shape.length + dim;
760
+ const maxVals = this.max(dim, true);
761
+ const shifted = this.sub(maxVals);
762
+ const expVals = shifted.exp();
763
+ const sumExp = expVals.sum(dim, true);
764
+ return expVals.div(sumExp);
810
765
  }
811
766
  // Tensor element-wise addition
812
767
  add(other) {
@@ -1139,76 +1094,6 @@ class Tensor {
1139
1094
  erfinv() {
1140
1095
  return this.elementWiseSelfDAG((a) => (0, utils_1.erfinv)(a), (self, outGrad) => outGrad.mul(self.erfinv().square().exp().mul(Math.sqrt(Math.PI) / 2)));
1141
1096
  }
1142
- // Transpose
1143
- transpose(dim1, dim2) {
1144
- // If dimension out of bound, throw error
1145
- if (dim1 >= this.shape.length || dim2 >= this.shape.length || dim1 < 0 || dim2 < 0) {
1146
- throw new Error("Dimensions do not exist to tranpose");
1147
- }
1148
- // If same dimension, return copy
1149
- if (dim1 === dim2) {
1150
- return new Tensor(this.value, { shape: this.shape, strides: this.strides });
1151
- }
1152
- // Create new shape and strides by swapping
1153
- const newShape = [...this.shape];
1154
- const newStrides = [...this.strides];
1155
- [newShape[dim1], newShape[dim2]] = [newShape[dim2], newShape[dim1]];
1156
- [newStrides[dim1], newStrides[dim2]] = [newStrides[dim2], newStrides[dim1]];
1157
- // Create new tensor with same data but swapped shape/strides
1158
- const out = new Tensor(this.value, { shape: newShape, strides: newStrides, device: this.device });
1159
- out.requiresGrad = this.requiresGrad;
1160
- // Handle gradient if needed
1161
- if (this.requiresGrad) {
1162
- out.children.push(this);
1163
- out.gradFn = () => {
1164
- Tensor.addGrad(this, out.grad.transpose(dim1, dim2));
1165
- };
1166
- }
1167
- return out;
1168
- }
1169
- swapaxes = this.transpose;
1170
- swapdims = this.transpose;
1171
- // Transpose 2D
1172
- t() {
1173
- // Verify matrix shape
1174
- if (this.shape.length !== 2) {
1175
- throw new Error("Input is not a matrix");
1176
- }
1177
- return this.transpose(0, 1);
1178
- }
1179
- // Permute
1180
- permute(dims) {
1181
- if (dims.length !== this.shape.length) {
1182
- throw new Error("Permutation must specify all dimensions");
1183
- }
1184
- // Compute new shape and strides
1185
- const newShape = new Array(dims.length);
1186
- const newStrides = new Array(dims.length);
1187
- for (let index = 0; index < dims.length; index++) {
1188
- const dim = dims[index];
1189
- newShape[index] = this.shape[dim];
1190
- newStrides[index] = this.strides[dim];
1191
- }
1192
- const out = new Tensor(this.value, {
1193
- shape: newShape,
1194
- strides: newStrides
1195
- });
1196
- if (this.requiresGrad) {
1197
- out.requiresGrad = true;
1198
- out.children.push(this);
1199
- out.gradFn = () => {
1200
- // Compute inverse permutation
1201
- const inverseAxes = new Array(dims.length);
1202
- for (let i = 0; i < dims.length; i++) {
1203
- inverseAxes[dims[i]] = i;
1204
- }
1205
- // Permute gradient back to original order
1206
- const permutedGrad = out.grad.permute(inverseAxes);
1207
- Tensor.addGrad(this, permutedGrad);
1208
- };
1209
- }
1210
- return out;
1211
- }
1212
1097
  // 1D tensor dot product
1213
1098
  dot(other) {
1214
1099
  other = this.handleOther(other);
@@ -1216,36 +1101,7 @@ class Tensor {
1216
1101
  if (this.shape.length !== 1 || other.shape.length !== 1) {
1217
1102
  throw new Error("Inputs are not 1D tensors");
1218
1103
  }
1219
- // Simple vector dot product
1220
- const vectLen = this.shape[0];
1221
- const vectA = this.value;
1222
- const vectB = other.value;
1223
- let sum = 0;
1224
- for (let index = 0; index < vectLen; index++) {
1225
- sum += vectA[index] * vectB[index];
1226
- }
1227
- const out = new Tensor(sum);
1228
- if (this.requiresGrad) {
1229
- out.requiresGrad = true;
1230
- out.children.push(this);
1231
- }
1232
- if (other.requiresGrad) {
1233
- out.requiresGrad = true;
1234
- out.children.push(other);
1235
- }
1236
- if (out.requiresGrad) {
1237
- out.gradFn = () => {
1238
- // Disable gradient collecting of gradients themselves
1239
- const outGrad = out.grad;
1240
- const selfNoGrad = this.detach();
1241
- const otherNoGrad = other.detach();
1242
- if (this.requiresGrad)
1243
- Tensor.addGrad(this, outGrad.mul(otherNoGrad));
1244
- if (other.requiresGrad)
1245
- Tensor.addGrad(other, outGrad.mul(selfNoGrad));
1246
- };
1247
- }
1248
- return out;
1104
+ return this.mul(other).sum();
1249
1105
  }
1250
1106
  // Matrix multiplication
1251
1107
  mm(other) {
@@ -1274,12 +1130,12 @@ class Tensor {
1274
1130
  for (let k = 0; k < matACols; k++) {
1275
1131
  // Tensor values are 1D arrays so we have to get real index using strides
1276
1132
  matC[i * matCStrides[0] + j * matCStrides[1]] +=
1277
- matA[i * matAStrides[0] + k * matAStrides[1]] *
1278
- matB[k * matBStrides[0] + j * matBStrides[1]];
1133
+ matA[i * matAStrides[0] + k * matAStrides[1] + this.offset] *
1134
+ matB[k * matBStrides[0] + j * matBStrides[1] + other.offset];
1279
1135
  }
1280
1136
  }
1281
1137
  }
1282
- const out = new Tensor(matC, { shape: matCShape, strides: matCStrides });
1138
+ const out = new Tensor(matC, { shape: matCShape, strides: matCStrides, numel: matCSize });
1283
1139
  if (this.requiresGrad) {
1284
1140
  out.requiresGrad = true;
1285
1141
  out.children.push(this);
@@ -1331,13 +1187,13 @@ class Tensor {
1331
1187
  for (let k = 0; k < batchACols; k++) {
1332
1188
  // Tensor values are 1D arrays so we have to get real index using strides
1333
1189
  batchC[q * batchCStrides[0] + i * batchCStrides[1] + j * batchCStrides[2]] +=
1334
- batchA[q * batchAStrides[0] + i * batchAStrides[1] + k * batchAStrides[2]] *
1335
- batchB[q * batchBStrides[0] + k * batchBStrides[1] + j * batchBStrides[2]];
1190
+ batchA[q * batchAStrides[0] + i * batchAStrides[1] + k * batchAStrides[2] + this.offset] *
1191
+ batchB[q * batchBStrides[0] + k * batchBStrides[1] + j * batchBStrides[2] + other.offset];
1336
1192
  }
1337
1193
  }
1338
1194
  }
1339
1195
  }
1340
- const out = new Tensor(batchC, { shape: batchCShape, strides: batchCStrides });
1196
+ const out = new Tensor(batchC, { shape: batchCShape, strides: batchCStrides, numel: batchCSize });
1341
1197
  if (this.requiresGrad) {
1342
1198
  out.requiresGrad = true;
1343
1199
  out.children.push(this);
@@ -1410,7 +1266,7 @@ class Tensor {
1410
1266
  const otherOffsetShape = otherShape.slice(0, -2);
1411
1267
  const selfOffsetStrides = selfStrides.slice(0, -2);
1412
1268
  const otherOffsetStrides = otherStrides.slice(0, -2);
1413
- // The output's offset data
1269
+ // Base offset data
1414
1270
  const offsetShape = Tensor.broadcastShapes(selfOffsetShape, otherOffsetShape);
1415
1271
  const offsetSize = Tensor.shapeToSize(offsetShape);
1416
1272
  const offsetStrides = Tensor.getStrides(offsetShape);
@@ -1419,10 +1275,11 @@ class Tensor {
1419
1275
  const outputStrides = Tensor.getStrides(outputShape);
1420
1276
  const outputSize = Tensor.shapeToSize(outputShape);
1421
1277
  const outputValue = new Array(outputSize).fill(0);
1278
+ const outputOffsetStrides = outputStrides.slice(0, -2);
1422
1279
  // Loop through outer dims and do matmul on two outer-most dims
1423
1280
  for (let index = 0; index < offsetSize; index++) {
1424
1281
  const coords = Tensor.indexToCoords(index, offsetStrides);
1425
- const offset = Tensor.coordsToIndex(coords, outputStrides.slice(0, -2));
1282
+ const offset = Tensor.coordsToIndex(coords, outputOffsetStrides);
1426
1283
  const selfOffset = Tensor.coordsToUnbroadcastedIndex(coords, selfOffsetShape, selfOffsetStrides);
1427
1284
  const otherOffset = Tensor.coordsToUnbroadcastedIndex(coords, otherOffsetShape, otherOffsetStrides);
1428
1285
  for (let i = 0; i < batchARows; i++) {
@@ -1431,12 +1288,12 @@ class Tensor {
1431
1288
  const outputIdx = offset + i * outputStrides[lastDim - 1] + j * outputStrides[lastDim];
1432
1289
  const selfIdx = selfOffset + i * selfStrides[lastDim - 1] + k * selfStrides[lastDim];
1433
1290
  const otherIdx = otherOffset + k * otherStrides[lastDim - 1] + j * otherStrides[lastDim];
1434
- outputValue[outputIdx] += batchA[selfIdx] * batchB[otherIdx];
1291
+ outputValue[outputIdx] += batchA[selfIdx + this.offset] * batchB[otherIdx + other.offset];
1435
1292
  }
1436
1293
  }
1437
1294
  }
1438
1295
  }
1439
- const out = new Tensor(outputValue, { shape: outputShape, strides: outputStrides });
1296
+ const out = new Tensor(outputValue, { shape: outputShape, strides: outputStrides, numel: outputSize });
1440
1297
  if (this.requiresGrad) {
1441
1298
  out.requiresGrad = true;
1442
1299
  out.children.push(this);
@@ -1452,9 +1309,9 @@ class Tensor {
1452
1309
  const selfNoGrad = self.detach();
1453
1310
  const otherNoGrad = other.detach();
1454
1311
  if (this.requiresGrad)
1455
- Tensor.addGrad(this, outGrad.matmul(otherNoGrad.transpose(lastDim - 1, lastDim)));
1312
+ Tensor.addGrad(this, outGrad.matmul(otherNoGrad.transpose(-2, -1)));
1456
1313
  if (other.requiresGrad)
1457
- Tensor.addGrad(other, selfNoGrad.transpose(lastDim - 1, lastDim).matmul(outGrad));
1314
+ Tensor.addGrad(other, selfNoGrad.transpose(-2, -1).matmul(outGrad));
1458
1315
  };
1459
1316
  }
1460
1317
  return out;
@@ -1476,15 +1333,15 @@ class Tensor {
1476
1333
  return new Tensor(num, options);
1477
1334
  const outputSize = Tensor.shapeToSize(shape);
1478
1335
  const outputValue = new Array(outputSize).fill(num);
1479
- return new Tensor(outputValue, { shape, ...options });
1336
+ return new Tensor(outputValue, { shape, numel: outputSize, ...options });
1480
1337
  }
1481
1338
  // Utility to create a new tensor with shape of another tensor, filled with a number
1482
1339
  static fullLike(tensor, num, options = {}) {
1483
1340
  if (typeof tensor.value === "number")
1484
1341
  return new Tensor(num, options);
1485
- return new Tensor(new Array(tensor.value.length).fill(num), {
1342
+ return new Tensor(new Array(tensor.numel).fill(num), {
1486
1343
  shape: tensor.shape,
1487
- strides: tensor.strides,
1344
+ numel: tensor.numel,
1488
1345
  device: tensor.device,
1489
1346
  ...options
1490
1347
  });
@@ -1495,15 +1352,15 @@ class Tensor {
1495
1352
  return new Tensor(1, options);
1496
1353
  const outputSize = Tensor.shapeToSize(shape);
1497
1354
  const outputValue = new Array(outputSize).fill(1);
1498
- return new Tensor(outputValue, { shape, ...options });
1355
+ return new Tensor(outputValue, { shape, numel: outputSize, ...options });
1499
1356
  }
1500
1357
  // Utility to create a new tensor with shape of another tensor, filled with 1
1501
1358
  static onesLike(tensor, options = {}) {
1502
1359
  if (typeof tensor.value === "number")
1503
1360
  return new Tensor(1, options);
1504
- return new Tensor(new Array(tensor.value.length).fill(1), {
1361
+ return new Tensor(new Array(tensor.numel).fill(1), {
1505
1362
  shape: tensor.shape,
1506
- strides: tensor.strides,
1363
+ numel: tensor.numel,
1507
1364
  device: tensor.device,
1508
1365
  ...options
1509
1366
  });
@@ -1514,15 +1371,15 @@ class Tensor {
1514
1371
  return new Tensor(0, options);
1515
1372
  const outputSize = Tensor.shapeToSize(shape);
1516
1373
  const outputValue = new Array(outputSize).fill(0);
1517
- return new Tensor(outputValue, { shape, ...options });
1374
+ return new Tensor(outputValue, { shape, numel: outputSize, ...options });
1518
1375
  }
1519
1376
  // Utility to create a new tensor with shape of another tensor, filled with 0
1520
1377
  static zerosLike(tensor, options = {}) {
1521
1378
  if (typeof tensor.value === "number")
1522
1379
  return new Tensor(0, options);
1523
- return new Tensor(new Array(tensor.value.length).fill(0), {
1380
+ return new Tensor(new Array(tensor.numel).fill(0), {
1524
1381
  shape: tensor.shape,
1525
- strides: tensor.strides,
1382
+ numel: tensor.numel,
1526
1383
  device: tensor.device,
1527
1384
  ...options
1528
1385
  });
@@ -1536,19 +1393,19 @@ class Tensor {
1536
1393
  for (let index = 0; index < outputValue.length; index++) {
1537
1394
  outputValue[index] = (0, utils_1.randUniform)();
1538
1395
  }
1539
- return new Tensor(outputValue, { shape, ...options });
1396
+ return new Tensor(outputValue, { shape, numel: outputSize, ...options });
1540
1397
  }
1541
1398
  // Utility to create a new tensor with shape of another tensor, filled with a random number with uniform distribution from 0 to 1
1542
1399
  static randLike(tensor, options = {}) {
1543
1400
  if (typeof tensor.value === "number")
1544
1401
  return new Tensor((0, utils_1.randUniform)(), options);
1545
- const outputValue = new Array(tensor.value.length);
1402
+ const outputValue = new Array(tensor.numel);
1546
1403
  for (let index = 0; index < outputValue.length; index++) {
1547
1404
  outputValue[index] = (0, utils_1.randUniform)();
1548
1405
  }
1549
1406
  return new Tensor(outputValue, {
1550
1407
  shape: tensor.shape,
1551
- strides: tensor.strides,
1408
+ numel: tensor.numel,
1552
1409
  device: tensor.device,
1553
1410
  ...options
1554
1411
  });
@@ -1562,19 +1419,19 @@ class Tensor {
1562
1419
  for (let index = 0; index < outputValue.length; index++) {
1563
1420
  outputValue[index] = (0, utils_1.randNormal)();
1564
1421
  }
1565
- return new Tensor(outputValue, { shape, ...options });
1422
+ return new Tensor(outputValue, { shape, numel: outputSize, ...options });
1566
1423
  }
1567
1424
  // Utility to create a new tensor with shape of another tensor, filled with a random number with normal distribution of mean=0 and stddev=1
1568
1425
  static randnLike(tensor, options = {}) {
1569
1426
  if (typeof tensor.value === "number")
1570
1427
  return new Tensor((0, utils_1.randNormal)(), options);
1571
- const outputValue = new Array(tensor.value.length);
1428
+ const outputValue = new Array(tensor.numel);
1572
1429
  for (let index = 0; index < outputValue.length; index++) {
1573
1430
  outputValue[index] = (0, utils_1.randNormal)();
1574
1431
  }
1575
1432
  return new Tensor(outputValue, {
1576
1433
  shape: tensor.shape,
1577
- strides: tensor.strides,
1434
+ numel: tensor.numel,
1578
1435
  device: tensor.device,
1579
1436
  ...options
1580
1437
  });
@@ -1588,19 +1445,19 @@ class Tensor {
1588
1445
  for (let index = 0; index < outputValue.length; index++) {
1589
1446
  outputValue[index] = (0, utils_1.randInt)(low, high);
1590
1447
  }
1591
- return new Tensor(outputValue, { shape, ...options });
1448
+ return new Tensor(outputValue, { shape, numel: outputSize, ...options });
1592
1449
  }
1593
1450
  // Utility to create a new tensor with shape of another tensor, filled with a random integer between low and high
1594
1451
  static randintLike(tensor, low, high, options = {}) {
1595
1452
  if (typeof tensor.value === "number")
1596
1453
  return new Tensor((0, utils_1.randInt)(low, high), options);
1597
- const outputValue = new Array(tensor.value.length);
1454
+ const outputValue = new Array(tensor.numel);
1598
1455
  for (let index = 0; index < outputValue.length; index++) {
1599
1456
  outputValue[index] = (0, utils_1.randInt)(low, high);
1600
1457
  }
1601
1458
  return new Tensor(outputValue, {
1602
1459
  shape: tensor.shape,
1603
- strides: tensor.strides,
1460
+ numel: tensor.numel,
1604
1461
  device: tensor.device,
1605
1462
  ...options
1606
1463
  });
@@ -1614,7 +1471,7 @@ class Tensor {
1614
1471
  for (let index = 0; index < outputValue.length; index++) {
1615
1472
  outputValue[index] = (0, utils_1.randNormal)(mean, stdDev);
1616
1473
  }
1617
- return new Tensor(outputValue, { shape, ...options });
1474
+ return new Tensor(outputValue, { shape, numel: outputSize, ...options });
1618
1475
  }
1619
1476
  // Utility to create a new tensor filled with a random number with uniform distribution from low to high
1620
1477
  static uniform(shape, low, high, options = {}) {
@@ -1625,7 +1482,7 @@ class Tensor {
1625
1482
  for (let index = 0; index < outputValue.length; index++) {
1626
1483
  outputValue[index] = (0, utils_1.randUniform)(low, high);
1627
1484
  }
1628
- return new Tensor(outputValue, { shape, ...options });
1485
+ return new Tensor(outputValue, { shape, numel: outputSize, ...options });
1629
1486
  }
1630
1487
  // Reverse-mode autodiff call
1631
1488
  backward(options = {}) {
@@ -1674,13 +1531,15 @@ class Tensor {
1674
1531
  }
1675
1532
  return result;
1676
1533
  }
1677
- return buildNested(this.value, this.shape, this.strides);
1534
+ return buildNested(this.value, this.shape, this.strides, this.offset);
1678
1535
  }
1679
1536
  // Returns a view of the tensor with gradient turned on/off and detaches from autograd
1680
1537
  withGrad(requiresGrad) {
1681
1538
  return new Tensor(this.value, {
1682
1539
  shape: this.shape,
1683
1540
  strides: this.strides,
1541
+ offset: this.offset,
1542
+ numel: this.numel,
1684
1543
  device: this.device,
1685
1544
  requiresGrad
1686
1545
  });
@@ -1690,6 +1549,8 @@ class Tensor {
1690
1549
  return new Tensor(this.value, {
1691
1550
  shape: this.shape,
1692
1551
  strides: this.strides,
1552
+ offset: this.offset,
1553
+ numel: this.numel,
1693
1554
  device: this.device,
1694
1555
  requiresGrad: false
1695
1556
  });
@@ -1699,6 +1560,8 @@ class Tensor {
1699
1560
  return new Tensor(typeof this.value === "number" ? this.value : [...this.value], {
1700
1561
  shape: this.shape,
1701
1562
  strides: this.strides,
1563
+ offset: this.offset,
1564
+ numel: this.numel,
1702
1565
  requiresGrad: this.requiresGrad
1703
1566
  });
1704
1567
  }
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "catniff",
3
- "version": "0.5.11",
3
+ "version": "0.6.0",
4
4
  "description": "A small Torch-like deep learning framework for Javascript",
5
5
  "main": "index.js",
6
6
  "scripts": {