@jax-js/jax 0.1.10 → 0.1.11
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +4 -1
- package/dist/{backend-Ctqs8la1.js → backend-DZvR7mZV.js} +730 -21
- package/dist/{backend-DMauYnfl.cjs → backend-DlYlOYqN.cjs} +735 -20
- package/dist/index.cjs +140 -3
- package/dist/index.d.cts +66 -3
- package/dist/index.d.ts +66 -3
- package/dist/index.js +140 -4
- package/dist/{webgl-CvQ1QBX1.js → webgl-D8-14NzA.js} +7 -1
- package/dist/{webgl-kvVt7-T7.cjs → webgl-Ovaaa-Qx.cjs} +7 -1
- package/dist/{webgpu-v_W_-oKw.js → webgpu-Dg8FpYrH.js} +6 -1
- package/dist/{webgpu-DMSx7a6M.cjs → webgpu-uU9nnttc.cjs} +6 -1
- package/package.json +1 -1
package/dist/index.cjs
CHANGED
|
@@ -30,7 +30,7 @@ var __toESM = (mod$1, isNodeMode, target) => (target = mod$1 != null ? __create(
|
|
|
30
30
|
}) : target, mod$1));
|
|
31
31
|
|
|
32
32
|
//#endregion
|
|
33
|
-
const require_backend = require('./backend-DMauYnfl.cjs');
|
|
33
|
+
const require_backend = require('./backend-DlYlOYqN.cjs');
|
|
34
34
|
|
|
35
35
|
//#region src/frontend/convolution.ts
|
|
36
36
|
/**
|
|
@@ -364,6 +364,8 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
|
|
|
364
364
|
Primitive$1["Mod"] = "mod";
|
|
365
365
|
Primitive$1["Min"] = "min";
|
|
366
366
|
Primitive$1["Max"] = "max";
|
|
367
|
+
Primitive$1["BitCombine"] = "bit_combine";
|
|
368
|
+
Primitive$1["BitShift"] = "bit_shift";
|
|
367
369
|
Primitive$1["Neg"] = "neg";
|
|
368
370
|
Primitive$1["Reciprocal"] = "reciprocal";
|
|
369
371
|
Primitive$1["Floor"] = "floor";
|
|
@@ -437,6 +439,12 @@ function min$1(x, y) {
|
|
|
437
439
|
function max$1(x, y) {
|
|
438
440
|
return bind1(Primitive.Max, [x, y]);
|
|
439
441
|
}
|
|
442
|
+
function bitCombine(x, y, op) {
|
|
443
|
+
return bind1(Primitive.BitCombine, [x, y], { op });
|
|
444
|
+
}
|
|
445
|
+
function bitShift(x, y, op) {
|
|
446
|
+
return bind1(Primitive.BitShift, [x, y], { op });
|
|
447
|
+
}
|
|
440
448
|
function neg(x) {
|
|
441
449
|
return bind1(Primitive.Neg, [x]);
|
|
442
450
|
}
|
|
@@ -1655,6 +1663,16 @@ const abstractEvalRules = {
|
|
|
1655
1663
|
[Primitive.Mod]: binopAbstractEval,
|
|
1656
1664
|
[Primitive.Min]: binopAbstractEval,
|
|
1657
1665
|
[Primitive.Max]: binopAbstractEval,
|
|
1666
|
+
[Primitive.BitCombine]([x, y]) {
|
|
1667
|
+
const aval = promoteAvals(x, y);
|
|
1668
|
+
if (require_backend.isFloatDtype(aval.dtype)) throw new TypeError(`bitwise operations require integer or boolean inputs, got ${aval.dtype}`);
|
|
1669
|
+
return [aval];
|
|
1670
|
+
},
|
|
1671
|
+
[Primitive.BitShift]([x, y]) {
|
|
1672
|
+
const shape$1 = require_backend.generalBroadcast(x.shape, y.shape);
|
|
1673
|
+
if (require_backend.isFloatDtype(x.dtype) || require_backend.isFloatDtype(y.dtype) || x.dtype === require_backend.DType.Bool || y.dtype === require_backend.DType.Bool) throw new TypeError(`bit shift operations require integer inputs, got ${x} and ${y}`);
|
|
1674
|
+
return [new ShapedArray(shape$1, x.dtype, x.weakType)];
|
|
1675
|
+
},
|
|
1658
1676
|
[Primitive.Neg]: vectorizedUnopAbstractEval,
|
|
1659
1677
|
[Primitive.Reciprocal]: vectorizedUnopAbstractEval,
|
|
1660
1678
|
[Primitive.Floor]: vectorizedUnopAbstractEval,
|
|
@@ -2190,6 +2208,8 @@ const jitRules = {
|
|
|
2190
2208
|
[Primitive.Mod]: broadcastedJit(([a, b]) => require_backend.AluExp.mod(a, b)),
|
|
2191
2209
|
[Primitive.Min]: broadcastedJit(([a, b]) => require_backend.AluExp.min(a, b)),
|
|
2192
2210
|
[Primitive.Max]: broadcastedJit(([a, b]) => require_backend.AluExp.max(a, b)),
|
|
2211
|
+
[Primitive.BitCombine]: broadcastedJit(([a, b], { op }) => require_backend.AluExp.bitCombine(a, b, op)),
|
|
2212
|
+
[Primitive.BitShift]: broadcastedJit(([a, b], { op }) => require_backend.AluExp.bitShift(a, b, op)),
|
|
2193
2213
|
[Primitive.Neg]: unopJit((a) => require_backend.AluExp.sub(require_backend.AluExp.const(a.dtype, 0), a)),
|
|
2194
2214
|
[Primitive.Reciprocal]: unopJit(require_backend.AluExp.reciprocal),
|
|
2195
2215
|
[Primitive.Floor]: unopJit(require_backend.AluExp.floor),
|
|
@@ -2382,7 +2402,9 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
2382
2402
|
case Primitive.Idiv:
|
|
2383
2403
|
case Primitive.Mod:
|
|
2384
2404
|
case Primitive.Min:
|
|
2385
|
-
case Primitive.Max: {
|
|
2405
|
+
case Primitive.Max:
|
|
2406
|
+
case Primitive.BitCombine:
|
|
2407
|
+
case Primitive.BitShift: {
|
|
2386
2408
|
const otherInput = nextEqn.inputs.find((v) => v !== outVar);
|
|
2387
2409
|
if (otherInput instanceof Lit || require_backend.deepEqual(require_backend.generalBroadcast(otherInput.aval.shape, outVar.aval.shape), outVar.aval.shape)) {
|
|
2388
2410
|
head = usages[0];
|
|
@@ -3021,6 +3043,42 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
3021
3043
|
return require_backend.dtypedArray(this.dtype, buf);
|
|
3022
3044
|
}
|
|
3023
3045
|
/**
|
|
3046
|
+
* Return this array as a WebGPU buffer (with `STORAGE | COPY_SRC`).
|
|
3047
|
+
*
|
|
3048
|
+
* Only available on the WebGPU backend. The array's memory is still managed
|
|
3049
|
+
* by jax-js, and it will be freed when the buffer is no longer in use. You
|
|
3050
|
+
* _should not_ mutate the buffer's contents.
|
|
3051
|
+
*
|
|
3052
|
+
* Note that the GPU buffer may be slightly larger than the array's size; it
|
|
3053
|
+
* will always be aligned to 4 bytes.
|
|
3054
|
+
*/
|
|
3055
|
+
async gpuBuffer() {
|
|
3056
|
+
if (this.device !== "webgpu") throw new Error(`gpuBuffer() is only available on WebGPU backend`);
|
|
3057
|
+
this.#realize();
|
|
3058
|
+
const pending = this.#pending;
|
|
3059
|
+
if (pending) {
|
|
3060
|
+
await Promise.all(pending.map((p) => p.prepare()));
|
|
3061
|
+
for (const p of pending) p.submit();
|
|
3062
|
+
}
|
|
3063
|
+
const backend = this.#backend;
|
|
3064
|
+
const { buffer } = backend.buffers.get(this.#source);
|
|
3065
|
+
this.dispose();
|
|
3066
|
+
return buffer;
|
|
3067
|
+
}
|
|
3068
|
+
/** Synchronous version of `Array.gpuBuffer()`. */
|
|
3069
|
+
gpuBufferSync() {
|
|
3070
|
+
if (this.device !== "webgpu") throw new Error(`gpuBufferSync() is only available on WebGPU backend`);
|
|
3071
|
+
this.#realize();
|
|
3072
|
+
for (const p of this.#pending) {
|
|
3073
|
+
p.prepareSync();
|
|
3074
|
+
p.submit();
|
|
3075
|
+
}
|
|
3076
|
+
const backend = this.#backend;
|
|
3077
|
+
const { buffer } = backend.buffers.get(this.#source);
|
|
3078
|
+
this.dispose();
|
|
3079
|
+
return buffer;
|
|
3080
|
+
}
|
|
3081
|
+
/**
|
|
3024
3082
|
* Convert this array into a JavaScript object.
|
|
3025
3083
|
*
|
|
3026
3084
|
* This is a blocking operation that will compile all of the shaders and wait
|
|
@@ -3067,6 +3125,14 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
3067
3125
|
[Primitive.Max]([x, y]) {
|
|
3068
3126
|
return [x.#binary(require_backend.AluOp.Max, y)];
|
|
3069
3127
|
},
|
|
3128
|
+
[Primitive.BitCombine]([x, y], { op }) {
|
|
3129
|
+
const custom = (src) => require_backend.AluExp.bitCombine(src[0], src[1], op);
|
|
3130
|
+
return [Array$1.#naryCustom("bit_combine", custom, [x, y])];
|
|
3131
|
+
},
|
|
3132
|
+
[Primitive.BitShift]([x, y], { op }) {
|
|
3133
|
+
const custom = (src) => require_backend.AluExp.bitShift(src[0], src[1], op);
|
|
3134
|
+
return [Array$1.#naryCustom("bit_shift", custom, [x, y], { dtypeOverride: [void 0, y.dtype] })];
|
|
3135
|
+
},
|
|
3070
3136
|
[Primitive.Neg]([x]) {
|
|
3071
3137
|
return [zerosLike$1(x.ref).#binary(require_backend.AluOp.Sub, x)];
|
|
3072
3138
|
},
|
|
@@ -3759,6 +3825,8 @@ const vmapRules = {
|
|
|
3759
3825
|
[Primitive.Mod]: broadcastBatcher(Primitive.Mod),
|
|
3760
3826
|
[Primitive.Min]: broadcastBatcher(Primitive.Min),
|
|
3761
3827
|
[Primitive.Max]: broadcastBatcher(Primitive.Max),
|
|
3828
|
+
[Primitive.BitCombine]: broadcastBatcher(Primitive.BitCombine),
|
|
3829
|
+
[Primitive.BitShift]: broadcastBatcher(Primitive.BitShift),
|
|
3762
3830
|
[Primitive.Neg]: unopBatcher(Primitive.Neg),
|
|
3763
3831
|
[Primitive.Reciprocal]: unopBatcher(Primitive.Reciprocal),
|
|
3764
3832
|
[Primitive.Floor]: unopBatcher(Primitive.Floor),
|
|
@@ -4082,6 +4150,8 @@ const jvpRules = {
|
|
|
4082
4150
|
[Primitive.Max]([x, y], [dx, dy]) {
|
|
4083
4151
|
return [[max$1(x.ref, y.ref)], [where$1(less$1(x, y), dy, dx)]];
|
|
4084
4152
|
},
|
|
4153
|
+
[Primitive.BitCombine]: zeroTangentsJvp(Primitive.BitCombine),
|
|
4154
|
+
[Primitive.BitShift]: zeroTangentsJvp(Primitive.BitShift),
|
|
4085
4155
|
[Primitive.Neg]: linearTangentsJvp(Primitive.Neg),
|
|
4086
4156
|
[Primitive.Reciprocal]([x], [dx]) {
|
|
4087
4157
|
const xRecip = reciprocal$1(x.ref);
|
|
@@ -5273,7 +5343,8 @@ __export(numpy_linalg_exports, {
|
|
|
5273
5343
|
solve: () => solve,
|
|
5274
5344
|
tensordot: () => tensordot,
|
|
5275
5345
|
trace: () => trace,
|
|
5276
|
-
vecdot: () => vecdot
|
|
5346
|
+
vecdot: () => vecdot,
|
|
5347
|
+
vectorNorm: () => vectorNorm
|
|
5277
5348
|
});
|
|
5278
5349
|
function checkSquare(name, a) {
|
|
5279
5350
|
if (a.ndim < 2 || a.shape[a.ndim - 1] !== a.shape[a.ndim - 2]) throw new Error(`${name}: input must be at least 2D square matrix, got ${a.aval}`);
|
|
@@ -5452,6 +5523,23 @@ function solve(a, b) {
|
|
|
5452
5523
|
if (bIs1d) x = squeeze(x, -1);
|
|
5453
5524
|
return x;
|
|
5454
5525
|
}
|
|
5526
|
+
/**
|
|
5527
|
+
* Compute the vector norm of an array.
|
|
5528
|
+
*
|
|
5529
|
+
* @param x - Input array.
|
|
5530
|
+
* @param ord - Order of the norm (default 2). Supports `Infinity`, `-Infinity`, `0`, or any real number.
|
|
5531
|
+
* @param axis - Axis/axes to reduce over (default: all axes).
|
|
5532
|
+
* @param keepdims - Whether to keep reduced dimensions as size 1.
|
|
5533
|
+
* @returns The norm of `x`, reduced over the given axes.
|
|
5534
|
+
*/
|
|
5535
|
+
function vectorNorm(x, { ord = 2, axis = null, keepdims = false } = {}) {
|
|
5536
|
+
x = fudgeArray(x);
|
|
5537
|
+
const ax = axis ?? null;
|
|
5538
|
+
if (ord === Infinity) return max(absolute(x), ax, { keepdims });
|
|
5539
|
+
else if (ord === -Infinity) return min(absolute(x), ax, { keepdims });
|
|
5540
|
+
else if (ord === 0) return x.notEqual(0).astype(x.dtype).sum(ax, { keepdims });
|
|
5541
|
+
else return power(power(absolute(x), ord).sum(ax, { keepdims }), 1 / ord);
|
|
5542
|
+
}
|
|
5455
5543
|
|
|
5456
5544
|
//#endregion
|
|
5457
5545
|
//#region src/library/numpy/dtype-info.ts
|
|
@@ -5571,6 +5659,13 @@ __export(numpy_exports, {
|
|
|
5571
5659
|
atan2: () => atan2,
|
|
5572
5660
|
atanh: () => arctanh,
|
|
5573
5661
|
average: () => average,
|
|
5662
|
+
bitwiseAnd: () => bitwiseAnd,
|
|
5663
|
+
bitwiseInvert: () => invert,
|
|
5664
|
+
bitwiseLeftShift: () => leftShift,
|
|
5665
|
+
bitwiseNot: () => invert,
|
|
5666
|
+
bitwiseOr: () => bitwiseOr,
|
|
5667
|
+
bitwiseRightShift: () => rightShift,
|
|
5668
|
+
bitwiseXor: () => bitwiseXor,
|
|
5574
5669
|
bool: () => bool,
|
|
5575
5670
|
broadcastArrays: () => broadcastArrays,
|
|
5576
5671
|
broadcastShapes: () => broadcastShapes,
|
|
@@ -5632,12 +5727,14 @@ __export(numpy_exports, {
|
|
|
5632
5727
|
inf: () => inf,
|
|
5633
5728
|
inner: () => inner,
|
|
5634
5729
|
int32: () => int32,
|
|
5730
|
+
invert: () => invert,
|
|
5635
5731
|
isfinite: () => isfinite,
|
|
5636
5732
|
isinf: () => isinf,
|
|
5637
5733
|
isnan: () => isnan,
|
|
5638
5734
|
isneginf: () => isneginf,
|
|
5639
5735
|
isposinf: () => isposinf,
|
|
5640
5736
|
ldexp: () => ldexp,
|
|
5737
|
+
leftShift: () => leftShift,
|
|
5641
5738
|
less: () => less,
|
|
5642
5739
|
lessEqual: () => lessEqual,
|
|
5643
5740
|
linalg: () => numpy_linalg_exports,
|
|
@@ -5686,6 +5783,7 @@ __export(numpy_exports, {
|
|
|
5686
5783
|
remainder: () => remainder,
|
|
5687
5784
|
repeat: () => repeat,
|
|
5688
5785
|
reshape: () => reshape,
|
|
5786
|
+
rightShift: () => rightShift,
|
|
5689
5787
|
rint: () => rint,
|
|
5690
5788
|
round: () => round,
|
|
5691
5789
|
shape: () => shape,
|
|
@@ -5800,6 +5898,44 @@ function logicalXor(x, y) {
|
|
|
5800
5898
|
function logicalNot(x) {
|
|
5801
5899
|
return notEqual(astype(x, require_backend.DType.Bool), true);
|
|
5802
5900
|
}
|
|
5901
|
+
/** Compute element-wise bitwise AND. */
|
|
5902
|
+
function bitwiseAnd(x, y) {
|
|
5903
|
+
return bitCombine(x, y, "and");
|
|
5904
|
+
}
|
|
5905
|
+
/** Compute element-wise bitwise OR. */
|
|
5906
|
+
function bitwiseOr(x, y) {
|
|
5907
|
+
return bitCombine(x, y, "or");
|
|
5908
|
+
}
|
|
5909
|
+
/** Compute element-wise bitwise XOR. */
|
|
5910
|
+
function bitwiseXor(x, y) {
|
|
5911
|
+
return bitCombine(x, y, "xor");
|
|
5912
|
+
}
|
|
5913
|
+
/** Compute element-wise bitwise NOT (inversion). */
|
|
5914
|
+
function invert(x) {
|
|
5915
|
+
const arr = fudgeArray(x);
|
|
5916
|
+
let allOnes;
|
|
5917
|
+
switch (arr.dtype) {
|
|
5918
|
+
case require_backend.DType.Bool:
|
|
5919
|
+
allOnes = true;
|
|
5920
|
+
break;
|
|
5921
|
+
case require_backend.DType.Uint32:
|
|
5922
|
+
allOnes = 4294967295;
|
|
5923
|
+
break;
|
|
5924
|
+
case require_backend.DType.Int32:
|
|
5925
|
+
allOnes = -1;
|
|
5926
|
+
break;
|
|
5927
|
+
default: throw new TypeError(`invert: unsupported dtype ${arr.dtype}`);
|
|
5928
|
+
}
|
|
5929
|
+
return bitCombine(arr, allOnes, "xor");
|
|
5930
|
+
}
|
|
5931
|
+
/** Compute element-wise left bit shift. */
|
|
5932
|
+
function leftShift(x, y) {
|
|
5933
|
+
return bitShift(x, y, "shl");
|
|
5934
|
+
}
|
|
5935
|
+
/** Compute element-wise right bit shift. */
|
|
5936
|
+
function rightShift(x, y) {
|
|
5937
|
+
return bitShift(x, y, "shr");
|
|
5938
|
+
}
|
|
5803
5939
|
/** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
|
|
5804
5940
|
const where = where$1;
|
|
5805
5941
|
/**
|
|
@@ -8336,6 +8472,7 @@ exports.blockUntilReady = blockUntilReady;
|
|
|
8336
8472
|
exports.defaultDevice = require_backend.defaultDevice;
|
|
8337
8473
|
exports.devicePut = devicePut;
|
|
8338
8474
|
exports.devices = require_backend.devices;
|
|
8475
|
+
exports.getWebGPUDevice = require_backend.getWebGPUDevice;
|
|
8339
8476
|
exports.grad = grad;
|
|
8340
8477
|
exports.hessian = hessian;
|
|
8341
8478
|
exports.init = require_backend.init;
|
package/dist/index.d.cts
CHANGED
|
@@ -232,6 +232,8 @@ declare class AluExp implements FpHashable {
|
|
|
232
232
|
static cast(dtype: DType, a: AluExp): AluExp;
|
|
233
233
|
static bitcast(dtype: DType, a: AluExp): AluExp;
|
|
234
234
|
static threefry2x32(k0: AluExp, k1: AluExp, c0: AluExp, c1: AluExp, mode?: "xor" | 0 | 1): AluExp;
|
|
235
|
+
static bitCombine(a: AluExp, b: AluExp, mode: "and" | "or" | "xor"): AluExp;
|
|
236
|
+
static bitShift(a: AluExp, b: AluExp, mode: "shl" | "shr"): AluExp;
|
|
235
237
|
static cmplt(a: AluExp, b: AluExp): AluExp;
|
|
236
238
|
static cmpne(a: AluExp, b: AluExp): AluExp;
|
|
237
239
|
static where(cond: AluExp, a: AluExp, b: AluExp): AluExp;
|
|
@@ -323,6 +325,11 @@ declare enum AluOp {
|
|
|
323
325
|
Reciprocal = "Reciprocal",
|
|
324
326
|
Cast = "Cast",
|
|
325
327
|
Bitcast = "Bitcast",
|
|
328
|
+
BitCombine = "BitCombine",
|
|
329
|
+
// arg = 'or' | 'and' | 'xor'
|
|
330
|
+
BitInvert = "BitInvert",
|
|
331
|
+
BitShift = "BitShift",
|
|
332
|
+
// arg = 'shl' | 'shr'
|
|
326
333
|
Cmplt = "Cmplt",
|
|
327
334
|
Cmpne = "Cmpne",
|
|
328
335
|
Where = "Where",
|
|
@@ -546,6 +553,11 @@ declare class Executable<T = any> {
|
|
|
546
553
|
source: Kernel | Routine, /** Extra data specific to the backend running this executable. */
|
|
547
554
|
data: T);
|
|
548
555
|
}
|
|
556
|
+
/**
|
|
557
|
+
* If the WebGPU backend has been initialized, return the `GPUDevice` that this
|
|
558
|
+
* backend runs on. This is useful for sharing buffers.
|
|
559
|
+
*/
|
|
560
|
+
declare function getWebGPUDevice(): GPUDevice;
|
|
549
561
|
declare namespace tree_d_exports {
|
|
550
562
|
export { JsTree, JsTreeDef, MapJsTree, NodeType, dispose, flatten, leaves, map, ref, structure, unflatten };
|
|
551
563
|
}
|
|
@@ -719,6 +731,8 @@ declare enum Primitive {
|
|
|
719
731
|
// uses sign of numerator, C-style, matches JS but not Python
|
|
720
732
|
Min = "min",
|
|
721
733
|
Max = "max",
|
|
734
|
+
BitCombine = "bit_combine",
|
|
735
|
+
BitShift = "bit_shift",
|
|
722
736
|
Neg = "neg",
|
|
723
737
|
Reciprocal = "reciprocal",
|
|
724
738
|
Floor = "floor",
|
|
@@ -767,6 +781,12 @@ declare enum Primitive {
|
|
|
767
781
|
Jit = "jit",
|
|
768
782
|
}
|
|
769
783
|
interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
|
|
784
|
+
[Primitive.BitCombine]: {
|
|
785
|
+
op: "and" | "or" | "xor";
|
|
786
|
+
};
|
|
787
|
+
[Primitive.BitShift]: {
|
|
788
|
+
op: "shl" | "shr";
|
|
789
|
+
};
|
|
770
790
|
[Primitive.Cast]: {
|
|
771
791
|
dtype: DType;
|
|
772
792
|
};
|
|
@@ -1194,6 +1214,19 @@ declare class Array extends Tracer {
|
|
|
1194
1214
|
* recommended for performance reasons, as it will block rendering.
|
|
1195
1215
|
*/
|
|
1196
1216
|
dataSync(): DataArray;
|
|
1217
|
+
/**
|
|
1218
|
+
* Return this array as a WebGPU buffer (with `STORAGE | COPY_SRC`).
|
|
1219
|
+
*
|
|
1220
|
+
* Only available on the WebGPU backend. The array's memory is still managed
|
|
1221
|
+
* by jax-js, and it will be freed when the buffer is no longer in use. You
|
|
1222
|
+
* _should not_ mutate the buffer's contents.
|
|
1223
|
+
*
|
|
1224
|
+
* Note that the GPU buffer may be slightly larger than the array's size; it
|
|
1225
|
+
* will always be aligned to 4 bytes.
|
|
1226
|
+
*/
|
|
1227
|
+
gpuBuffer(): Promise<GPUBuffer>;
|
|
1228
|
+
/** Synchronous version of `Array.gpuBuffer()`. */
|
|
1229
|
+
gpuBufferSync(): GPUBuffer;
|
|
1197
1230
|
/**
|
|
1198
1231
|
* Convert this array into a JavaScript object.
|
|
1199
1232
|
*
|
|
@@ -1571,7 +1604,7 @@ declare function fft(a: ComplexPair, axis?: number): ComplexPair;
|
|
|
1571
1604
|
*/
|
|
1572
1605
|
declare function ifft(a: ComplexPair, axis?: number): ComplexPair;
|
|
1573
1606
|
declare namespace numpy_linalg_d_exports {
|
|
1574
|
-
export { cholesky, cross$1 as cross, det, diagonal, inv, lstsq, matmul, matrixPower, matrixTranspose, outer, slogdet, solve, tensordot, trace, vecdot };
|
|
1607
|
+
export { cholesky, cross$1 as cross, det, diagonal, inv, lstsq, matmul, matrixPower, matrixTranspose, outer, slogdet, solve, tensordot, trace, vecdot, vectorNorm };
|
|
1575
1608
|
}
|
|
1576
1609
|
/**
|
|
1577
1610
|
* Compute the Cholesky decomposition of a (batched) positive-definite matrix.
|
|
@@ -1626,6 +1659,24 @@ declare function slogdet(a: ArrayLike): [Array, Array];
|
|
|
1626
1659
|
* @returns Solution `x` of shape `(..., N)` or `(..., N, M)`.
|
|
1627
1660
|
*/
|
|
1628
1661
|
declare function solve(a: ArrayLike, b: ArrayLike): Array;
|
|
1662
|
+
/**
|
|
1663
|
+
* Compute the vector norm of an array.
|
|
1664
|
+
*
|
|
1665
|
+
* @param x - Input array.
|
|
1666
|
+
* @param ord - Order of the norm (default 2). Supports `Infinity`, `-Infinity`, `0`, or any real number.
|
|
1667
|
+
* @param axis - Axis/axes to reduce over (default: all axes).
|
|
1668
|
+
* @param keepdims - Whether to keep reduced dimensions as size 1.
|
|
1669
|
+
* @returns The norm of `x`, reduced over the given axes.
|
|
1670
|
+
*/
|
|
1671
|
+
declare function vectorNorm(x: ArrayLike, {
|
|
1672
|
+
ord,
|
|
1673
|
+
axis,
|
|
1674
|
+
keepdims
|
|
1675
|
+
}?: {
|
|
1676
|
+
ord?: number;
|
|
1677
|
+
axis?: number | number[] | null;
|
|
1678
|
+
keepdims?: boolean;
|
|
1679
|
+
}): Array;
|
|
1629
1680
|
//#endregion
|
|
1630
1681
|
//#region src/library/numpy/dtype-info.d.ts
|
|
1631
1682
|
/** @inline */
|
|
@@ -1679,7 +1730,7 @@ type IInfo = Readonly<{
|
|
|
1679
1730
|
/** Machine limits for integer types. */
|
|
1680
1731
|
declare function iinfo(dtype: DType): IInfo;
|
|
1681
1732
|
declare namespace numpy_d_exports {
|
|
1682
|
-
export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, round as around, array, arrayEqual, arrayEquiv, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, average, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, copysign, corrcoef, correlate, cos, cosh, cov, cross, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, divmod, dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, finfo, flip, fliplr, flipud, float16, float32, float64, floor, floorDivide, fmod, frexp, full, fullLike, greater, greaterEqual, hann, heaviside, hstack, hypot, identity$1 as identity, iinfo, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, ldexp, less, lessEqual, numpy_linalg_d_exports as linalg, linspace, log, log10, log1p, log2, logicalAnd, logicalNot, logicalOr, logicalXor, logspace, matmul, matrixTranspose, matvec, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, nanToNum, ndim, negative, notEqual, ones, onesLike, outer, pad, transpose as permuteDims, pi, positive, power as pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, rint, round, shape$1 as shape, sign, sin, sinc, sinh, size, sort, split$1 as split, sqrt, square, squeeze, stack, std, subtract, sum, swapaxes, take, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vecmat, vstack, where, zeros, zerosLike };
|
|
1733
|
+
export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, round as around, array, arrayEqual, arrayEquiv, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, average, bitwiseAnd, invert as bitwiseInvert, leftShift as bitwiseLeftShift, invert as bitwiseNot, bitwiseOr, rightShift as bitwiseRightShift, bitwiseXor, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, copysign, corrcoef, correlate, cos, cosh, cov, cross, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, divmod, dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, finfo, flip, fliplr, flipud, float16, float32, float64, floor, floorDivide, fmod, frexp, full, fullLike, greater, greaterEqual, hann, heaviside, hstack, hypot, identity$1 as identity, iinfo, inf, inner, int32, invert, isfinite, isinf, isnan, isneginf, isposinf, ldexp, leftShift, less, lessEqual, numpy_linalg_d_exports as linalg, linspace, log, log10, log1p, log2, logicalAnd, logicalNot, logicalOr, logicalXor, logspace, matmul, matrixTranspose, matvec, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, nanToNum, ndim, negative, notEqual, ones, onesLike, outer, pad, transpose as permuteDims, pi, positive, power as pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, rightShift, rint, round, shape$1 as shape, sign, sin, sinc, sinh, size, sort, split$1 as split, sqrt, square, squeeze, stack, std, subtract, sum, swapaxes, take, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vecmat, vstack, where, zeros, zerosLike };
|
|
1683
1734
|
}
|
|
1684
1735
|
declare const float32 = DType.Float32;
|
|
1685
1736
|
declare const int32 = DType.Int32;
|
|
@@ -1747,6 +1798,18 @@ declare function logicalOr(x: ArrayLike, y: ArrayLike): Array;
|
|
|
1747
1798
|
declare function logicalXor(x: ArrayLike, y: ArrayLike): Array;
|
|
1748
1799
|
/** Compute element-wise logical NOT. */
|
|
1749
1800
|
declare function logicalNot(x: ArrayLike): Array;
|
|
1801
|
+
/** Compute element-wise bitwise AND. */
|
|
1802
|
+
declare function bitwiseAnd(x: ArrayLike, y: ArrayLike): Array;
|
|
1803
|
+
/** Compute element-wise bitwise OR. */
|
|
1804
|
+
declare function bitwiseOr(x: ArrayLike, y: ArrayLike): Array;
|
|
1805
|
+
/** Compute element-wise bitwise XOR. */
|
|
1806
|
+
declare function bitwiseXor(x: ArrayLike, y: ArrayLike): Array;
|
|
1807
|
+
/** Compute element-wise bitwise NOT (inversion). */
|
|
1808
|
+
declare function invert(x: ArrayLike): Array;
|
|
1809
|
+
/** Compute element-wise left bit shift. */
|
|
1810
|
+
declare function leftShift(x: ArrayLike, y: ArrayLike): Array;
|
|
1811
|
+
/** Compute element-wise right bit shift. */
|
|
1812
|
+
declare function rightShift(x: ArrayLike, y: ArrayLike): Array;
|
|
1750
1813
|
/** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
|
|
1751
1814
|
declare const where: (cond: ArrayLike, x: ArrayLike, y: ArrayLike) => Array;
|
|
1752
1815
|
/**
|
|
@@ -2958,4 +3021,4 @@ declare function blockUntilReady<T extends JsTree<any>>(x: T): Promise<T>;
|
|
|
2958
3021
|
*/
|
|
2959
3022
|
declare function devicePut<T extends JsTree<any>>(x: T, device?: Device): Promise<MapJsTree<T, number | boolean, Array>>;
|
|
2960
3023
|
//#endregion
|
|
2961
|
-
export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, profiler, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
|
|
3024
|
+
export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, getWebGPUDevice, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, profiler, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
|
package/dist/index.d.ts
CHANGED
|
@@ -229,6 +229,8 @@ declare class AluExp implements FpHashable {
|
|
|
229
229
|
static cast(dtype: DType, a: AluExp): AluExp;
|
|
230
230
|
static bitcast(dtype: DType, a: AluExp): AluExp;
|
|
231
231
|
static threefry2x32(k0: AluExp, k1: AluExp, c0: AluExp, c1: AluExp, mode?: "xor" | 0 | 1): AluExp;
|
|
232
|
+
static bitCombine(a: AluExp, b: AluExp, mode: "and" | "or" | "xor"): AluExp;
|
|
233
|
+
static bitShift(a: AluExp, b: AluExp, mode: "shl" | "shr"): AluExp;
|
|
232
234
|
static cmplt(a: AluExp, b: AluExp): AluExp;
|
|
233
235
|
static cmpne(a: AluExp, b: AluExp): AluExp;
|
|
234
236
|
static where(cond: AluExp, a: AluExp, b: AluExp): AluExp;
|
|
@@ -320,6 +322,11 @@ declare enum AluOp {
|
|
|
320
322
|
Reciprocal = "Reciprocal",
|
|
321
323
|
Cast = "Cast",
|
|
322
324
|
Bitcast = "Bitcast",
|
|
325
|
+
BitCombine = "BitCombine",
|
|
326
|
+
// arg = 'or' | 'and' | 'xor'
|
|
327
|
+
BitInvert = "BitInvert",
|
|
328
|
+
BitShift = "BitShift",
|
|
329
|
+
// arg = 'shl' | 'shr'
|
|
323
330
|
Cmplt = "Cmplt",
|
|
324
331
|
Cmpne = "Cmpne",
|
|
325
332
|
Where = "Where",
|
|
@@ -543,6 +550,11 @@ declare class Executable<T = any> {
|
|
|
543
550
|
source: Kernel | Routine, /** Extra data specific to the backend running this executable. */
|
|
544
551
|
data: T);
|
|
545
552
|
}
|
|
553
|
+
/**
|
|
554
|
+
* If the WebGPU backend has been initialized, return the `GPUDevice` that this
|
|
555
|
+
* backend runs on. This is useful for sharing buffers.
|
|
556
|
+
*/
|
|
557
|
+
declare function getWebGPUDevice(): GPUDevice;
|
|
546
558
|
declare namespace tree_d_exports {
|
|
547
559
|
export { JsTree, JsTreeDef, MapJsTree, NodeType, dispose, flatten, leaves, map, ref, structure, unflatten };
|
|
548
560
|
}
|
|
@@ -716,6 +728,8 @@ declare enum Primitive {
|
|
|
716
728
|
// uses sign of numerator, C-style, matches JS but not Python
|
|
717
729
|
Min = "min",
|
|
718
730
|
Max = "max",
|
|
731
|
+
BitCombine = "bit_combine",
|
|
732
|
+
BitShift = "bit_shift",
|
|
719
733
|
Neg = "neg",
|
|
720
734
|
Reciprocal = "reciprocal",
|
|
721
735
|
Floor = "floor",
|
|
@@ -764,6 +778,12 @@ declare enum Primitive {
|
|
|
764
778
|
Jit = "jit",
|
|
765
779
|
}
|
|
766
780
|
interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
|
|
781
|
+
[Primitive.BitCombine]: {
|
|
782
|
+
op: "and" | "or" | "xor";
|
|
783
|
+
};
|
|
784
|
+
[Primitive.BitShift]: {
|
|
785
|
+
op: "shl" | "shr";
|
|
786
|
+
};
|
|
767
787
|
[Primitive.Cast]: {
|
|
768
788
|
dtype: DType;
|
|
769
789
|
};
|
|
@@ -1191,6 +1211,19 @@ declare class Array extends Tracer {
|
|
|
1191
1211
|
* recommended for performance reasons, as it will block rendering.
|
|
1192
1212
|
*/
|
|
1193
1213
|
dataSync(): DataArray;
|
|
1214
|
+
/**
|
|
1215
|
+
* Return this array as a WebGPU buffer (with `STORAGE | COPY_SRC`).
|
|
1216
|
+
*
|
|
1217
|
+
* Only available on the WebGPU backend. The array's memory is still managed
|
|
1218
|
+
* by jax-js, and it will be freed when the buffer is no longer in use. You
|
|
1219
|
+
* _should not_ mutate the buffer's contents.
|
|
1220
|
+
*
|
|
1221
|
+
* Note that the GPU buffer may be slightly larger than the array's size; it
|
|
1222
|
+
* will always be aligned to 4 bytes.
|
|
1223
|
+
*/
|
|
1224
|
+
gpuBuffer(): Promise<GPUBuffer>;
|
|
1225
|
+
/** Synchronous version of `Array.gpuBuffer()`. */
|
|
1226
|
+
gpuBufferSync(): GPUBuffer;
|
|
1194
1227
|
/**
|
|
1195
1228
|
* Convert this array into a JavaScript object.
|
|
1196
1229
|
*
|
|
@@ -1568,7 +1601,7 @@ declare function fft(a: ComplexPair, axis?: number): ComplexPair;
|
|
|
1568
1601
|
*/
|
|
1569
1602
|
declare function ifft(a: ComplexPair, axis?: number): ComplexPair;
|
|
1570
1603
|
declare namespace numpy_linalg_d_exports {
|
|
1571
|
-
export { cholesky, cross$1 as cross, det, diagonal, inv, lstsq, matmul, matrixPower, matrixTranspose, outer, slogdet, solve, tensordot, trace, vecdot };
|
|
1604
|
+
export { cholesky, cross$1 as cross, det, diagonal, inv, lstsq, matmul, matrixPower, matrixTranspose, outer, slogdet, solve, tensordot, trace, vecdot, vectorNorm };
|
|
1572
1605
|
}
|
|
1573
1606
|
/**
|
|
1574
1607
|
* Compute the Cholesky decomposition of a (batched) positive-definite matrix.
|
|
@@ -1623,6 +1656,24 @@ declare function slogdet(a: ArrayLike): [Array, Array];
|
|
|
1623
1656
|
* @returns Solution `x` of shape `(..., N)` or `(..., N, M)`.
|
|
1624
1657
|
*/
|
|
1625
1658
|
declare function solve(a: ArrayLike, b: ArrayLike): Array;
|
|
1659
|
+
/**
|
|
1660
|
+
* Compute the vector norm of an array.
|
|
1661
|
+
*
|
|
1662
|
+
* @param x - Input array.
|
|
1663
|
+
* @param ord - Order of the norm (default 2). Supports `Infinity`, `-Infinity`, `0`, or any real number.
|
|
1664
|
+
* @param axis - Axis/axes to reduce over (default: all axes).
|
|
1665
|
+
* @param keepdims - Whether to keep reduced dimensions as size 1.
|
|
1666
|
+
* @returns The norm of `x`, reduced over the given axes.
|
|
1667
|
+
*/
|
|
1668
|
+
declare function vectorNorm(x: ArrayLike, {
|
|
1669
|
+
ord,
|
|
1670
|
+
axis,
|
|
1671
|
+
keepdims
|
|
1672
|
+
}?: {
|
|
1673
|
+
ord?: number;
|
|
1674
|
+
axis?: number | number[] | null;
|
|
1675
|
+
keepdims?: boolean;
|
|
1676
|
+
}): Array;
|
|
1626
1677
|
//#endregion
|
|
1627
1678
|
//#region src/library/numpy/dtype-info.d.ts
|
|
1628
1679
|
/** @inline */
|
|
@@ -1676,7 +1727,7 @@ type IInfo = Readonly<{
|
|
|
1676
1727
|
/** Machine limits for integer types. */
|
|
1677
1728
|
declare function iinfo(dtype: DType): IInfo;
|
|
1678
1729
|
declare namespace numpy_d_exports {
|
|
1679
|
-
export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, round as around, array, arrayEqual, arrayEquiv, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, average, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, copysign, corrcoef, correlate, cos, cosh, cov, cross, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, divmod, dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, finfo, flip, fliplr, flipud, float16, float32, float64, floor, floorDivide, fmod, frexp, full, fullLike, greater, greaterEqual, hann, heaviside, hstack, hypot, identity$1 as identity, iinfo, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, ldexp, less, lessEqual, numpy_linalg_d_exports as linalg, linspace, log, log10, log1p, log2, logicalAnd, logicalNot, logicalOr, logicalXor, logspace, matmul, matrixTranspose, matvec, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, nanToNum, ndim, negative, notEqual, ones, onesLike, outer, pad, transpose as permuteDims, pi, positive, power as pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, rint, round, shape$1 as shape, sign, sin, sinc, sinh, size, sort, split$1 as split, sqrt, square, squeeze, stack, std, subtract, sum, swapaxes, take, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vecmat, vstack, where, zeros, zerosLike };
|
|
1730
|
+
export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, round as around, array, arrayEqual, arrayEquiv, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, average, bitwiseAnd, invert as bitwiseInvert, leftShift as bitwiseLeftShift, invert as bitwiseNot, bitwiseOr, rightShift as bitwiseRightShift, bitwiseXor, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, copysign, corrcoef, correlate, cos, cosh, cov, cross, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, divmod, dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, finfo, flip, fliplr, flipud, float16, float32, float64, floor, floorDivide, fmod, frexp, full, fullLike, greater, greaterEqual, hann, heaviside, hstack, hypot, identity$1 as identity, iinfo, inf, inner, int32, invert, isfinite, isinf, isnan, isneginf, isposinf, ldexp, leftShift, less, lessEqual, numpy_linalg_d_exports as linalg, linspace, log, log10, log1p, log2, logicalAnd, logicalNot, logicalOr, logicalXor, logspace, matmul, matrixTranspose, matvec, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, nanToNum, ndim, negative, notEqual, ones, onesLike, outer, pad, transpose as permuteDims, pi, positive, power as pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, rightShift, rint, round, shape$1 as shape, sign, sin, sinc, sinh, size, sort, split$1 as split, sqrt, square, squeeze, stack, std, subtract, sum, swapaxes, take, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vecmat, vstack, where, zeros, zerosLike };
|
|
1680
1731
|
}
|
|
1681
1732
|
declare const float32 = DType.Float32;
|
|
1682
1733
|
declare const int32 = DType.Int32;
|
|
@@ -1744,6 +1795,18 @@ declare function logicalOr(x: ArrayLike, y: ArrayLike): Array;
|
|
|
1744
1795
|
declare function logicalXor(x: ArrayLike, y: ArrayLike): Array;
|
|
1745
1796
|
/** Compute element-wise logical NOT. */
|
|
1746
1797
|
declare function logicalNot(x: ArrayLike): Array;
|
|
1798
|
+
/** Compute element-wise bitwise AND. */
|
|
1799
|
+
declare function bitwiseAnd(x: ArrayLike, y: ArrayLike): Array;
|
|
1800
|
+
/** Compute element-wise bitwise OR. */
|
|
1801
|
+
declare function bitwiseOr(x: ArrayLike, y: ArrayLike): Array;
|
|
1802
|
+
/** Compute element-wise bitwise XOR. */
|
|
1803
|
+
declare function bitwiseXor(x: ArrayLike, y: ArrayLike): Array;
|
|
1804
|
+
/** Compute element-wise bitwise NOT (inversion). */
|
|
1805
|
+
declare function invert(x: ArrayLike): Array;
|
|
1806
|
+
/** Compute element-wise left bit shift. */
|
|
1807
|
+
declare function leftShift(x: ArrayLike, y: ArrayLike): Array;
|
|
1808
|
+
/** Compute element-wise right bit shift. */
|
|
1809
|
+
declare function rightShift(x: ArrayLike, y: ArrayLike): Array;
|
|
1747
1810
|
/** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
|
|
1748
1811
|
declare const where: (cond: ArrayLike, x: ArrayLike, y: ArrayLike) => Array;
|
|
1749
1812
|
/**
|
|
@@ -2955,4 +3018,4 @@ declare function blockUntilReady<T extends JsTree<any>>(x: T): Promise<T>;
|
|
|
2955
3018
|
*/
|
|
2956
3019
|
declare function devicePut<T extends JsTree<any>>(x: T, device?: Device): Promise<MapJsTree<T, number | boolean, Array>>;
|
|
2957
3020
|
//#endregion
|
|
2958
|
-
export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, profiler, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
|
|
3021
|
+
export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, getWebGPUDevice, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, profiler, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
|