catniff 0.5.3 → 0.5.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/core.d.ts CHANGED
@@ -18,6 +18,7 @@ export declare class Tensor {
18
18
  gradFn: Function;
19
19
  children: Tensor[];
20
20
  device: string;
21
+ static training: boolean;
21
22
  constructor(value: TensorValue, options?: TensorOptions);
22
23
  static flatten(tensor: TensorValue): number[] | number;
23
24
  static getShape(tensor: TensorValue): readonly number[];
@@ -146,6 +147,7 @@ export declare class Tensor {
146
147
  bmm(other: TensorValue | Tensor): Tensor;
147
148
  mv(other: TensorValue | Tensor): Tensor;
148
149
  matmul(other: TensorValue | Tensor): Tensor;
150
+ dropout(rate: number): Tensor;
149
151
  static full(shape: number[], num: number, options?: TensorOptions): Tensor;
150
152
  static fullLike(tensor: Tensor, num: number, options?: TensorOptions): Tensor;
151
153
  static ones(shape?: number[], options?: TensorOptions): Tensor;
@@ -160,7 +162,9 @@ export declare class Tensor {
160
162
  static randintLike(tensor: Tensor, low: number, high: number, options?: TensorOptions): Tensor;
161
163
  static normal(shape: number[], mean: number, stdDev: number, options?: TensorOptions): Tensor;
162
164
  static uniform(shape: number[], low: number, high: number, options?: TensorOptions): Tensor;
163
- backward(): void;
165
+ backward(options?: {
166
+ zeroGrad?: boolean;
167
+ }): void;
164
168
  val(): TensorValue;
165
169
  withGrad(requiresGrad: boolean): Tensor;
166
170
  detach(): Tensor;
package/dist/core.js CHANGED
@@ -11,6 +11,7 @@ class Tensor {
11
11
  gradFn;
12
12
  children;
13
13
  device;
14
+ static training = false;
14
15
  constructor(value, options = {}) {
15
16
  this.value = Tensor.flatten(value);
16
17
  this.shape = options.shape || Tensor.getShape(value);
@@ -1335,6 +1336,15 @@ class Tensor {
1335
1336
  }
1336
1337
  throw new Error(`Shapes [${this.shape}] and [${other.shape}] are not supported`);
1337
1338
  }
1339
+ // Dropout
1340
+ dropout(rate) {
1341
+ if (!Tensor.training || rate === 0)
1342
+ return this;
1343
+ const keepRate = 1 - rate;
1344
+ const uniform = Tensor.randLike(this);
1345
+ const mask = uniform.lt(keepRate);
1346
+ return this.mul(mask).div(keepRate);
1347
+ }
1338
1348
  // Utility to create a new tensor filled with a number
1339
1349
  static full(shape, num, options = {}) {
1340
1350
  if (shape.length === 0)
@@ -1347,7 +1357,12 @@ class Tensor {
1347
1357
  static fullLike(tensor, num, options = {}) {
1348
1358
  if (typeof tensor.value === "number")
1349
1359
  return new Tensor(num, options);
1350
- return new Tensor(new Array(tensor.value.length).fill(num), { shape: tensor.shape, strides: tensor.strides, ...options });
1360
+ return new Tensor(new Array(tensor.value.length).fill(num), {
1361
+ shape: tensor.shape,
1362
+ strides: tensor.strides,
1363
+ device: tensor.device,
1364
+ ...options
1365
+ });
1351
1366
  }
1352
1367
  // Utility to create a new tensor filled with 1
1353
1368
  static ones(shape, options = {}) {
@@ -1361,7 +1376,12 @@ class Tensor {
1361
1376
  static onesLike(tensor, options = {}) {
1362
1377
  if (typeof tensor.value === "number")
1363
1378
  return new Tensor(1, options);
1364
- return new Tensor(new Array(tensor.value.length).fill(1), { shape: tensor.shape, strides: tensor.strides, ...options });
1379
+ return new Tensor(new Array(tensor.value.length).fill(1), {
1380
+ shape: tensor.shape,
1381
+ strides: tensor.strides,
1382
+ device: tensor.device,
1383
+ ...options
1384
+ });
1365
1385
  }
1366
1386
  // Utility to create a new tensor filled with 0
1367
1387
  static zeros(shape, options = {}) {
@@ -1375,7 +1395,12 @@ class Tensor {
1375
1395
  static zerosLike(tensor, options = {}) {
1376
1396
  if (typeof tensor.value === "number")
1377
1397
  return new Tensor(0, options);
1378
- return new Tensor(new Array(tensor.value.length).fill(0), { shape: tensor.shape, strides: tensor.strides, ...options });
1398
+ return new Tensor(new Array(tensor.value.length).fill(0), {
1399
+ shape: tensor.shape,
1400
+ strides: tensor.strides,
1401
+ device: tensor.device,
1402
+ ...options
1403
+ });
1379
1404
  }
1380
1405
  // Utility to create a new tensor filled with a random number with uniform distribution from 0 to 1
1381
1406
  static rand(shape, options = {}) {
@@ -1397,7 +1422,10 @@ class Tensor {
1397
1422
  outputValue[index] = (0, utils_1.randUniform)();
1398
1423
  }
1399
1424
  return new Tensor(outputValue, {
1400
- shape: tensor.shape, strides: tensor.strides, ...options
1425
+ shape: tensor.shape,
1426
+ strides: tensor.strides,
1427
+ device: tensor.device,
1428
+ ...options
1401
1429
  });
1402
1430
  }
1403
1431
  // Utility to create a new tensor filled with a random number with normal distribution of mean=0 and stddev=1
@@ -1420,7 +1448,10 @@ class Tensor {
1420
1448
  outputValue[index] = (0, utils_1.randNormal)();
1421
1449
  }
1422
1450
  return new Tensor(outputValue, {
1423
- shape: tensor.shape, strides: tensor.strides, ...options
1451
+ shape: tensor.shape,
1452
+ strides: tensor.strides,
1453
+ device: tensor.device,
1454
+ ...options
1424
1455
  });
1425
1456
  }
1426
1457
  // Utility to create a new tensor filled with a random integer between low and high
@@ -1443,7 +1474,10 @@ class Tensor {
1443
1474
  outputValue[index] = (0, utils_1.randInt)(low, high);
1444
1475
  }
1445
1476
  return new Tensor(outputValue, {
1446
- shape: tensor.shape, strides: tensor.strides, ...options
1477
+ shape: tensor.shape,
1478
+ strides: tensor.strides,
1479
+ device: tensor.device,
1480
+ ...options
1447
1481
  });
1448
1482
  }
1449
1483
  // Utility to create a new tensor filled with a random number with normal distribution of custom mean and stddev
@@ -1469,14 +1503,20 @@ class Tensor {
1469
1503
  return new Tensor(outputValue, { shape, ...options });
1470
1504
  }
1471
1505
  // Reverse-mode autodiff call
1472
- backward() {
1506
+ backward(options = {}) {
1507
+ // Init
1508
+ const zeroGrad = options.zeroGrad ?? true;
1473
1509
  // Build topological order
1474
1510
  const topo = [];
1475
1511
  const visited = new Set();
1476
1512
  function build(node) {
1513
+ // Only collects unvisited node and node that requires gradient
1477
1514
  if (!visited.has(node) && node.requiresGrad) {
1478
1515
  visited.add(node);
1479
- node.grad = Tensor.zerosLike(node); // Reset grad with 0
1516
+ // Reset grad to zeros if specified
1517
+ if (zeroGrad) {
1518
+ node.grad = Tensor.zerosLike(node);
1519
+ }
1480
1520
  for (let child of node.children)
1481
1521
  build(child);
1482
1522
  topo.push(node);
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "catniff",
3
- "version": "0.5.3",
3
+ "version": "0.5.4",
4
4
  "description": "A small Torch-like deep learning framework for Javascript",
5
5
  "main": "index.js",
6
6
  "scripts": {