@jax-js/jax 0.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.js ADDED
@@ -0,0 +1,3708 @@
1
+ import {
2
+ AluExp,
3
+ AluVar,
4
+ DEBUG,
5
+ DType,
6
+ FpHash,
7
+ Kernel,
8
+ Reduction,
9
+ ShapeTracker,
10
+ __callDispose,
11
+ __export,
12
+ __using,
13
+ accessorAluExp,
14
+ accessorGlobal,
15
+ deepEqual,
16
+ devices,
17
+ getBackend,
18
+ init,
19
+ invertPermutation,
20
+ isFloatDtype,
21
+ isPermutation,
22
+ partitionList,
23
+ prod,
24
+ range,
25
+ recursiveFlatten,
26
+ rep,
27
+ runWithCache,
28
+ setDevice,
29
+ toposort,
30
+ unravelAlu,
31
+ unzip2,
32
+ zip
33
+ } from "./chunk-B2GFURUN.js";
34
+
35
+ // src/pprint.ts
36
/**
 * An immutable block of pretty-printed text.
 *
 * `lines[i]` is rendered after `indents[i]` spaces. Every operation returns a
 * new PPrint rather than modifying the receiver.
 */
var PPrint = class _PPrint {
  constructor(indents, lines) {
    this.indents = indents;
    this.lines = lines;
  }

  /** Shift every line right by `spaces` columns. */
  indent(spaces) {
    const shifted = this.indents.map((width) => width + spaces);
    return new _PPrint(shifted, this.lines);
  }

  /** Place the lines of each block in `items` below this block, in order. */
  concat(...items) {
    const ownIndents = this.indents ?? [];
    const ownLines = this.lines ?? [];
    const extraIndents = items.map((block) => block.indents);
    const extraLines = items.map((block) => block.lines);
    return new _PPrint(
      ownIndents.concat(...extraIndents),
      ownLines.concat(...extraLines)
    );
  }

  /**
   * Attach `other` to the right of this block's last line.
   *
   * The first line of `other` is appended to our last line; the remaining
   * lines of `other` are shifted right by the width of that last line, so
   * they line up with where `other` starts.
   */
  stack(other) {
    if (other.lines.length === 0) return this;
    if (this.lines.length === 0) return other;
    const lastRow = this.lines.length - 1;
    const tail = this.lines[lastRow];
    const tailIndent = this.indents[this.indents.length - 1];
    const aligned = other.indent(tailIndent + tail.length);
    const joined = tail + " ".repeat(other.indents[0]) + other.lines[0];
    return new _PPrint(
      [...this.indents, ...aligned.indents.slice(1)],
      [...this.lines.slice(0, lastRow), joined, ...aligned.lines.slice(1)]
    );
  }

  /** Render the block as a newline-joined string. */
  toString() {
    const rendered = this.lines.map(
      (text, row) => " ".repeat(this.indents[row]) + text
    );
    return rendered.join("\n");
  }

  /** Wrap any value's string form as an unindented block, one entry per line. */
  static pp(s) {
    const lines = s.toString().split("\n");
    return new _PPrint(lines.map(() => 0), lines);
  }
};
79
+
80
+ // src/tree.ts
81
// Namespace object for src/tree.ts. The shared-chunk `__export` helper
// registers one accessor per entry; each arrow defers the lookup, so the
// bindings declared further down the module are only read when accessed.
var tree_exports = {};
__export(tree_exports, {
  JsTreeDef: () => JsTreeDef,
  NodeType: () => NodeType,
  flatten: () => flatten,
  leaves: () => leaves,
  structure: () => structure,
  unflatten: () => unflatten,
});
90
/** Kinds of node in a JS pytree; each value equals its key. */
var NodeType = {
  Array: "Array",
  Object: "Object",
  Leaf: "Leaf",
};
96
/**
 * Structure of a JS pytree with its leaves removed.
 *
 * For "Object" nodes `nodeMetadata` holds the keys, parallel to
 * `childTreedefs`; `_flatten` passes `null` for "Array" and "Leaf" nodes.
 */
var JsTreeDef = class _JsTreeDef {
  constructor(nodeType, nodeMetadata, childTreedefs) {
    this.nodeType = nodeType;
    this.nodeMetadata = nodeMetadata;
    this.childTreedefs = childTreedefs;
  }

  /** Shared definition describing a single leaf. */
  static leaf = new _JsTreeDef("Leaf", null, []);

  /**
   * Human-readable form, e.g. `JsTreeDef({a: *, b: [*, *]})`.
   * `root` controls whether the `JsTreeDef(...)` wrapper is emitted.
   */
  toString(root = true) {
    if (root) return `JsTreeDef(${this.toString(false)})`;
    const kind = this.nodeType;
    if (kind === "Leaf") return "*";
    if (kind === "Array") {
      const items = this.childTreedefs.map((child) => child.toString(false));
      return `[${items.join(", ")}]`;
    }
    if (kind === "Object") {
      const entries = this.childTreedefs.map(
        (child, i) =>
          `${quoteObjectKey(this.nodeMetadata[i])}: ${child.toString(false)}`
      );
      return `{${entries.join(", ")}}`;
    }
    return undefined;
  }

  /** Structural equality: same node kind, same metadata, equal children. */
  equals(other) {
    if (this.nodeType !== other.nodeType) return false;
    if (!deepEqual(this.nodeMetadata, other.nodeMetadata)) return false;
    const mine = this.childTreedefs;
    const theirs = other.childTreedefs;
    return (
      mine.length === theirs.length &&
      mine.every((child, i) => child.equals(theirs[i]))
    );
  }
};
129
/** Render an object key for display: bare if it is a simple identifier, JSON-quoted otherwise. */
function quoteObjectKey(key) {
  const isBareIdentifier = /^[a-zA-Z_$][a-zA-Z0-9_$]*$/.test(key);
  return isBareIdentifier ? key : JSON.stringify(key);
}
135
/** Split a pytree into `[leaves, treedef]`, with leaves in traversal order. */
function flatten(tree) {
  const collected = [];
  return [collected, _flatten(tree, collected)];
}
140
/**
 * Recursive worker for `flatten`: appends leaves to `acc` and returns the
 * JsTreeDef for `tree`. Arrays and plain objects (constructor === Object)
 * are containers; everything else, including null, is a leaf.
 */
function _flatten(tree, acc) {
  if (Array.isArray(tree)) {
    const children = tree.map((child) => _flatten(child, acc));
    return new JsTreeDef("Array", null, children);
  }
  const isPlainObject =
    typeof tree === "object" && tree !== null && tree.constructor === Object;
  if (!isPlainObject) {
    acc.push(tree);
    return JsTreeDef.leaf;
  }
  const [keys, values] = unzip2(Object.entries(tree));
  const children = values.map((value) => _flatten(value, acc));
  return new JsTreeDef("Object", keys, children);
}
153
/** Return just the leaves of a pytree, in traversal order. */
function leaves(tree) {
  const [treeLeaves] = flatten(tree);
  return treeLeaves;
}
156
/** Return just the JsTreeDef describing a pytree's shape. */
function structure(tree) {
  const [, treedef] = flatten(tree);
  return treedef;
}
159
/**
 * Rebuild a pytree from `treedef` and an iterable of leaves, consumed in order.
 *
 * @throws {TypeError} if the iterable yields fewer leaves than `treedef`
 *   needs (raised by `_unflatten`) or has leaves left over afterwards.
 */
function unflatten(treedef, leaves2) {
  const iterator = leaves2[Symbol.iterator]();
  const tree = _unflatten(treedef, iterator);
  // Surplus leaves mean the caller paired the wrong treedef with this data;
  // silently dropping them would hide that mismatch.
  if (!iterator.next().done) {
    throw new TypeError("Too many leaves while unflattening JsTree");
  }
  return tree;
}
162
/**
 * Recursive worker for `unflatten`: builds the subtree for `treedef`, pulling
 * leaves from the iterator `leaves2`.
 *
 * @throws {TypeError} when the iterator is exhausted before all leaves are filled.
 */
function _unflatten(treedef, leaves2) {
  const { nodeType, childTreedefs } = treedef;
  if (nodeType === "Leaf") {
    const next = leaves2.next();
    if (next.done) {
      throw new TypeError("Ran out of leaves while unflattening JsTree");
    }
    return next.value;
  }
  if (nodeType === "Array") {
    return childTreedefs.map((child) => _unflatten(child, leaves2));
  }
  if (nodeType === "Object") {
    const result = {};
    childTreedefs.forEach((child, i) => {
      result[treedef.nodeMetadata[i]] = _unflatten(child, leaves2);
    });
    return result;
  }
  return undefined;
}
185
+
186
+ // src/frontend/core.ts
187
/** Elementwise `x + y`. */
function add(x, y) {
  const operands = [x, y];
  return bind1("add", operands);
}

/** Elementwise `x * y`. */
function mul(x, y) {
  const operands = [x, y];
  return bind1("mul", operands);
}

/** Elementwise integer division of `x` by `y`. */
function idiv(x, y) {
  const operands = [x, y];
  return bind1("idiv", operands);
}

/** Elementwise negation. */
function neg(x) {
  const operand = [x];
  return bind1("neg", operand);
}

/** Elementwise `1 / x`. */
function reciprocal(x) {
  const operand = [x];
  return bind1("reciprocal", operand);
}

/** Elementwise sine. */
function sin(x) {
  const operand = [x];
  return bind1("sin", operand);
}

/** Elementwise cosine. */
function cos(x) {
  const operand = [x];
  return bind1("cos", operand);
}

/** Elementwise exponential. */
function exp(x) {
  const operand = [x];
  return bind1("exp", operand);
}

/** Elementwise logarithm. */
function log(x) {
  const operand = [x];
  return bind1("log", operand);
}

/** Elementwise minimum of `x` and `y`. */
function min(x, y) {
  const operands = [x, y];
  return bind1("min", operands);
}

/** Elementwise maximum of `x` and `y`. */
function max(x, y) {
  const operands = [x, y];
  return bind1("max", operands);
}

/** Elementwise comparison of `x` and `y` using the comparison `op`. */
function compare(x, y, op) {
  const params = { op };
  return bind1("compare", [x, y], params);
}

/** Elementwise `x > y`. */
function greater(x, y) {
  return compare(x, y, "greater");
}

/** Elementwise `x < y`. */
function less(x, y) {
  return compare(x, y, "less");
}

/** Elementwise `x == y`. */
function equal(x, y) {
  return compare(x, y, "equal");
}

/** Elementwise `x != y`. */
function notEqual(x, y) {
  return compare(x, y, "not_equal");
}

/** Elementwise `x >= y`. */
function greaterEqual(x, y) {
  return compare(x, y, "greater_equal");
}

/** Elementwise `x <= y`. */
function lessEqual(x, y) {
  return compare(x, y, "less_equal");
}

/** Elementwise select: `x` where `cond` holds, otherwise `y`. */
function where(cond, x, y) {
  const operands = [cond, x, y];
  return bind1("where", operands);
}
244
/** Permute the axes of `x` by `perm`; when omitted, the axis order is reversed. */
function transpose(x, perm) {
  const axes = perm ?? range(ndim(x)).reverse();
  return bind1("transpose", [x], { perm: axes });
}
248
/**
 * Broadcast `x` to `shape2`.
 * NOTE(review): `axis` appears to list the output axes that are newly
 * inserted (see the broadcast JIT rule) — confirm against the primitive's spec.
 */
function broadcast(x, shape2, axis) {
  const params = { shape: shape2, axis };
  return bind1("broadcast", [x], params);
}
251
+ function reshape(x, shape2) {
252
+ if (typeof shape2 === "number") shape2 = [shape2];
253
+ const originalShape = getShape(x);
254
+ const autoIdx = shape2.indexOf(-1);
255
+ if (autoIdx !== -1) {
256
+ const remaining = prod(originalShape) / -prod(shape2);
257
+ if (!Number.isInteger(remaining) || remaining < 0) {
258
+ throw new TypeError(
259
+ `Invalid reshape: ${JSON.stringify(originalShape)} -> ${JSON.stringify(shape2)}`
260
+ );
261
+ }
262
+ shape2 = shape2.toSpliced(autoIdx, 1, remaining);
263
+ }
264
+ if (prod(originalShape) !== prod(shape2)) {
265
+ throw new TypeError(
266
+ `Invalid reshape: ${JSON.stringify(originalShape)} -> ${JSON.stringify(shape2)}`
267
+ );
268
+ }
269
+ return bind1("reshape" /* Reshape */, [x], { shape: shape2 });
270
+ }
271
/** Reverse `x` along each axis listed in `axis`. */
function flip(x, axis) {
  const params = { axis };
  return bind1("flip", [x], params);
}
274
/**
 * Sum `x` over the given axes.
 *
 * `axis` may be a single axis, a list of axes, or omitted. When omitted, a
 * Tracer is reduced over every axis and a non-Tracer (scalar) over none.
 * Negative axes count from the end, as in NumPy.
 *
 * @throws {TypeError} if an axis is not an integer in `[-rank, rank)`.
 */
function reduceSum(x, axis) {
  const rank = x instanceof Tracer ? x.shape.length : 0;
  let axes;
  if (axis === void 0) {
    axes = x instanceof Tracer ? range(rank) : [];
  } else if (typeof axis === "number") {
    axes = [axis];
  } else {
    axes = axis;
  }
  // The JIT rule matches axes with `axis.includes(i)`, so an unnormalized
  // negative axis would otherwise be silently ignored there.
  axes = axes.map((ax) => {
    if (!Number.isInteger(ax) || ax < -rank || ax >= rank) {
      throw new TypeError(`Invalid axis ${ax} for array of rank ${rank}`);
    }
    return ax < 0 ? ax + rank : ax;
  });
  return bind1("reduce_sum", [x], { axis: axes });
}
287
/** Bind primitive `prim` via `bind` and return its first output. */
function bind1(prim, args, params = {}) {
  return bind(prim, args, params)[0];
}
291
/** Active main traces; `newMain` assigns each one a `level` equal to its index. */
var traceStack = [];
/** Main trace installed by `newDynamic`, or null when none is active. */
var dynamicTrace = null;

/**
 * Push a new main trace record onto `traceStack` and return it.
 *
 * The returned record is disposable: disposing it pops the top of
 * `traceStack`, so records should be disposed in LIFO order.
 */
function newMain(traceType, globalData = null) {
  const main = { level: traceStack.length, traceType, globalData };
  traceStack.push(main);
  main[Symbol.dispose] = () => {
    traceStack.pop();
  };
  return main;
}

/**
 * Make `main` the dynamic trace until the returned handle is disposed, at
 * which point the previous dynamic trace is restored.
 */
function newDynamic(main) {
  const saved = dynamicTrace;
  dynamicTrace = main;
  const restore = () => {
    dynamicTrace = saved;
  };
  return { [Symbol.dispose]: restore };
}
312
/** Base class for traces; remembers the main trace record it was created for. */
var Trace = class {
  constructor(main) {
    this.main = main;
  }
};
317
/**
 * Base class for values that flow through the tracing machinery.
 *
 * Subclasses supply an `aval` getter (with `shape` and `dtype`); the methods
 * here forward to the module-level primitive wrappers, passing `this`.
 */
var Tracer = class {
  _trace;

  constructor(trace) {
    this._trace = trace;
  }

  /** Shape from this tracer's abstract value. */
  get shape() {
    return this.aval.shape;
  }

  /** Dtype from this tracer's abstract value. */
  get dtype() {
    return this.aval.dtype;
  }

  /** Number of dimensions. */
  get ndim() {
    return this.shape.length;
  }

  /** Return this tracer; the base implementation performs no lowering. */
  fullLower() {
    return this;
  }

  // These types aren't technically correct since they don't account for the
  // fact that tracers can be lifted to different levels. But they simplify the
  // API visible to users.
  neg() {
    return neg(this);
  }

  add(other) {
    return add(this, other);
  }

  mul(other) {
    return mul(this, other);
  }

  greater(other) {
    return greater(this, other);
  }

  less(other) {
    return less(this, other);
  }

  equal(other) {
    return equal(this, other);
  }

  notEqual(other) {
    return notEqual(this, other);
  }

  greaterEqual(other) {
    return greaterEqual(this, other);
  }

  lessEqual(other) {
    return lessEqual(this, other);
  }

  sum(axis) {
    return reduceSum(this, axis);
  }

  transpose(perm) {
    return transpose(this, perm);
  }

  reshape(shape2) {
    return reshape(this, shape2);
  }

  // Below this line are composite operations built from primitives.

  /** Subtract `other` from this array. */
  sub(other) {
    return this.add(neg(other));
  }

  /**
   * Divide this array by `other`.
   *
   * When this array's dtype is int32, integer division is used; otherwise
   * this array is multiplied by the reciprocal of `other`.
   */
  div(other) {
    if (this.dtype === "int32") {
      return idiv(this, other);
    }
    return this.mul(reciprocal(other));
  }

  /**
   * Return specified diagonals. See `numpy.diagonal` for full docs.
   *
   * Only argument validation is implemented so far; valid calls throw
   * "diagonal not implemented".
   */
  diagonal(offset = 0, axis1 = 0, axis2 = 1) {
    if (!Number.isInteger(offset)) {
      throw new TypeError(`offset must be an integer, got ${offset}`);
    }
    if (axis1 === axis2) {
      throw new TypeError("axis1 and axis2 must not be equal");
    }
    throw new Error("diagonal not implemented");
  }

  /** Flatten the array to one dimension without changing its data. */
  flatten() {
    return this.reshape(-1);
  }

  /** Flatten the array to one dimension without changing its data (alias of `flatten`). */
  ravel() {
    return this.reshape(-1);
  }
};
402
/** Number of dimensions of `x`; non-Tracer values count as scalars (0). */
function ndim(x) {
  return x instanceof Tracer ? x.shape.length : 0;
}
409
/** Shape of `x`; non-Tracer values are treated as scalars with shape `[]`. */
function getShape(x) {
  if (x instanceof Tracer) {
    return x.shape;
  }
  return [];
}
412
/** Abstract value describing an array only by its shape and dtype. */
var ShapedArray = class _ShapedArray {
  constructor(shape2, dtype) {
    this.shape = shape2;
    this.dtype = dtype;
  }

  /** Copy the shape and dtype of any aval-like object. */
  static fromAval(aval) {
    return new _ShapedArray(aval.shape, aval.dtype);
  }

  /** Number of dimensions. */
  get ndim() {
    return this.shape.length;
  }

  /** Compact form such as `float32[2,3]`. */
  strShort() {
    return `${this.dtype}[${this.shape.join(",")}]`;
  }

  /**
   * Equality by class, dtype and shape.
   *
   * Dtype is part of the comparison: avals with the same shape but different
   * dtypes (e.g. float32[2] vs int32[2]) are not interchangeable.
   */
  equals(other) {
    if (this === other) return true;
    return (
      this.constructor === other.constructor &&
      this.dtype === other.dtype &&
      this.ndim === other.ndim &&
      this.shape.every((d, i) => d === other.shape[i])
    );
  }
};
430
/**
 * Abstract value of `x`: a Tracer's own `aval`, or a scalar ShapedArray for
 * JS booleans ("bool") and numbers ("float32").
 *
 * @throws {TypeError} for any other kind of value.
 */
function getAval(x) {
  if (x instanceof Tracer) return x.aval;
  switch (typeof x) {
    case "boolean":
      return new ShapedArray([], "bool");
    case "number":
      return new ShapedArray([], "float32");
    default:
      throw new TypeError(`Unknown value: ${x}`);
  }
}
442
/**
 * Apply primitive `prim` to `args` on the highest-level applicable trace.
 *
 * Arguments are raised to that trace, the trace's `processPrimitive` runs,
 * and each output is passed through `fullLower()` before being returned.
 */
function bind(prim, args, params = {}) {
  const trace = findTopTrace(args);
  const raised = args.map((arg) => fullRaise(trace, arg));
  const outs = trace.processPrimitive(prim, raised, params);
  if (DEBUG >= 5) {
    const inputsDesc = raised.map((t) => t.toString());
    const outputsDesc = outs.map((t) => t.toString());
    console.info(
      `processing rule for ${prim} on ${inputsDesc} and got ${outputsDesc}`
    );
  }
  return outs.map((out) => out.fullLower());
}
453
/**
 * Build a trace for the highest-level main among the Tracer arguments, the
 * base of `traceStack`, and the current dynamic trace.
 */
function findTopTrace(xs) {
  let topMain = traceStack[0];
  for (const x of xs) {
    if (!(x instanceof Tracer)) continue;
    const candidate = x._trace.main;
    if (candidate.level > topMain.level) topMain = candidate;
  }
  if (dynamicTrace && dynamicTrace.level > topMain.level) {
    topMain = dynamicTrace;
  }
  return new topMain.traceType(topMain);
}
465
/**
 * Bring `val` into `trace`.
 *
 * Non-Tracers go through `trace.pure`; Tracers from a lower level are lifted;
 * Tracers already on this main are returned unchanged.
 *
 * @throws {Error} if `val` belongs to a higher level, or to a different main
 *   at the same level.
 */
function fullRaise(trace, val) {
  if (!(val instanceof Tracer)) return trace.pure(val);
  const targetLevel = trace.main.level;
  const valMain = val._trace.main;
  if (Object.is(valMain, trace.main)) return val;
  if (valMain.level < targetLevel) return trace.lift(val);
  if (valMain.level > targetLevel) {
    throw new Error(
      `Can't lift Tracer level ${valMain.level} to level ${targetLevel}`
    );
  }
  throw new Error(`Different traces at same level: ${val._trace}, ${trace}.`);
}
482
/**
 * Raised when two pytree structures that must agree do not.
 *
 * @param location - description of where the mismatch was detected.
 * @param left - the first structure, interpolated with its string form.
 * @param right - the second structure, interpolated with its string form.
 */
var TreeMismatchError = class extends TypeError {
  constructor(location, left, right) {
    super(`Mismatched tree structures in ${location}: ${left} != ${right}`);
    // Without this, stack traces and `err.name` report a generic "TypeError".
    this.name = "TreeMismatchError";
  }
};
487
/**
 * Adapt `f`, which takes pytree arguments, into a function over flat leaves.
 *
 * Returns `[flatFun, store]`. `flatFun(...leaves)` rebuilds the arguments
 * from `inTree`, calls `f`, records the output's treedef in `store.value`,
 * and returns the flattened outputs. `store.value` is undefined until the
 * first call.
 */
function flattenFun(f, inTree) {
  const store = { value: void 0 };
  const flatFun = (...argsFlat) => {
    const result = f(...unflatten(inTree, argsFlat));
    const [outFlat, outTree] = flatten(result);
    store.value = outTree;
    return outFlat;
  };
  return [flatFun, store];
}
498
/** Raised when a tracer is used after its reference count has dropped to zero. */
var UseAfterFreeError = class extends ReferenceError {
  constructor(tracer) {
    super(
      `Referenced tracer ${tracer.toString()} freed, please use .ref move semantics`
    );
    // Without this, stack traces and `err.name` report a generic "ReferenceError".
    this.name = "UseAfterFreeError";
  }
};
505
+
506
+ // src/frontend/jit.ts
507
/**
 * A compiled sequence of steps ("execute", "const", "malloc", "free") over
 * integer value ids. `inputs` and `outputs` are the ids bound to the call
 * arguments and returned as results.
 */
var JitProgram = class {
  constructor(backend, steps, inputs, outputs) {
    this.backend = backend;
    this.steps = steps;
    this.inputs = inputs;
    this.outputs = outputs;
  }

  /**
   * Run the program's steps against the given input slots.
   *
   * Kernel executions are not dispatched here; each is wrapped in a
   * PendingExecute and returned in `pending`. Buffers are allocated with
   * `4 * size` bytes.
   *
   * @returns {{outputs: Array, pending: PendingExecute[]}}
   * @throws {TypeError} if the number of inputs does not match the program.
   */
  execute(inputs) {
    if (inputs.length !== this.inputs.length) {
      throw new TypeError(
        `Expected ${this.inputs.length} inputs, got ${inputs.length}`
      );
    }
    const scope = new Map(this.inputs.map((id, i) => [id, inputs[i]]));
    const pending = [];
    for (const step of this.steps) {
      if (step.type === "execute") {
        const stepInputs = step.inputs.map((id) => scope.get(id));
        const stepOutputs = step.outputs.map((id) => scope.get(id));
        const hasMissing = [...stepInputs, ...stepOutputs].some(
          (slot) => slot === undefined
        );
        if (hasMissing) {
          throw new Error(`internal: JitProgram scope undefined`);
        }
        pending.push(
          new PendingExecute(this.backend, step.kernel, stepInputs, stepOutputs)
        );
      } else if (step.type === "const") {
        scope.set(step.output, step.slot);
      } else if (step.type === "malloc") {
        scope.set(step.output, this.backend.malloc(4 * step.size));
      } else if (step.type === "free") {
        this.backend.decRef(scope.get(step.input));
        scope.delete(step.input);
      }
    }
    const outputs = this.outputs.map((id) => scope.get(id));
    return { outputs, pending };
  }
};
561
/**
 * Incrementally assembles the step list for a JitProgram.
 *
 * Ids `0 .. nargs-1` are reserved for program arguments; every push
 * allocates the next id after those.
 */
var JitProgramBuilder = class {
  backend;
  #nextId;
  steps;

  constructor(backend, nargs) {
    this.backend = backend;
    this.#nextId = nargs;
    this.steps = [];
  }

  /** Reserve the next unused value id. */
  #allocId() {
    const id = this.#nextId;
    this.#nextId += 1;
    return id;
  }

  /** Bind an existing slot to a new id; returns that id. */
  pushConst(slot) {
    const output = this.#allocId();
    this.steps.push({ type: "const", slot, output });
    return output;
  }

  /** Materialize a literal by running a zero-argument constant kernel. */
  pushLit(lit) {
    const size = prod(lit.aval.shape);
    const kernel = new Kernel(0, size, AluExp.const(lit.dtype, lit.value));
    return this.pushKernel(kernel, []);
  }

  /** Allocate an output buffer and schedule `kernel` to fill it; returns the output id. */
  pushKernel(kernel, inputs) {
    const output = this.#allocId();
    this.steps.push(
      { type: "malloc", size: kernel.size, output },
      { type: "execute", kernel, inputs, outputs: [output] }
    );
    return output;
  }

  /**
   * For every malloc'd id that is not a program output, insert a "free"
   * step right after the last step that allocates, reads, or writes it.
   */
  insertFreeSteps(outputIds) {
    const mallocIds = [];
    for (const step of this.steps) {
      if (step.type === "malloc") mallocIds.push(step.output);
    }
    for (const id of mallocIds) {
      if (outputIds.includes(id)) continue;
      const touchesId = (step) =>
        (step.type === "execute" &&
          (step.outputs.includes(id) || step.inputs.includes(id))) ||
        (step.type === "malloc" && step.output === id);
      const lastUsage = this.steps.findLastIndex(touchesId);
      this.steps.splice(lastUsage + 1, 0, { type: "free", input: id });
    }
  }

  /** Append a step releasing `id`. */
  pushFree(id) {
    this.steps.push({ type: "free", input: id });
  }
};
622
/**
 * Cache of compiled programs, keyed by backend type plus a hash of the Jaxpr
 * and the ids of its const arrays.
 */
var jitCompileCache = /* @__PURE__ */ new Map();

/**
 * Compile `jaxpr` into a JitProgram that runs on `backend`.
 *
 * The first `consts.length` input binders are bound to `consts`; the
 * remaining binders become the program's arguments, in order. An equation
 * whose kernel has no reduction and whose output is not a black node (see
 * `splitGraphDataflow`) is kept as a symbolic expression and fused into the
 * kernels that consume it. Every other equation is materialized as its own
 * kernel.
 *
 * @throws {TypeError} if more consts than input binders are given, a const
 *   lives on another device, a Jaxpr input is neither a Var nor a Lit, or a
 *   primitive has no JIT rule.
 * @throws {Error} (internal) if the Jaxpr references an unbound variable or
 *   contains an unexpected value/output kind.
 */
function jitCompile(backend, jaxpr, consts) {
  if (jaxpr.inBinders.length < consts.length) {
    throw new TypeError(
      `Jaxpr has ${jaxpr.inBinders.length} inputs, but ${consts.length} consts were provided`
    );
  }
  for (let i = 0; i < consts.length; i++) {
    if (consts[i].device !== backend.type) {
      throw new TypeError(
        `Const ${i} has device ${consts[i].device}, but expected ${backend.type}`
      );
    }
  }

  const cacheKey = backend.type + FpHash.hash(jaxpr, ...consts.map((c) => c.id));
  const cached = jitCompileCache.get(cacheKey);
  if (cached) return cached;
  if (DEBUG >= 1) {
    console.info("=========== JIT Compile ===========\n" + jaxpr.toString());
  }

  jaxpr = jaxpr.flatten().simplify();
  const nargs = jaxpr.inBinders.length - consts.length;
  const builder = new JitProgramBuilder(backend, nargs);
  const blackNodes = splitGraphDataflow(backend, jaxpr);

  // Each Var maps to {type: "imm", arg} (a materialized program value id) or
  // {type: "exp", exp, args} (a fusible expression whose gids index `args`).
  const ctx = /* @__PURE__ */ new Map();
  const lookup = (v) => {
    const value = ctx.get(v);
    if (value === undefined) {
      throw new Error(`internal: JIT value missing for variable ${v}`);
    }
    return value;
  };

  // Consts receive ids after the arguments (JitProgramBuilder starts at nargs).
  for (let i = 0; i < consts.length; i++) {
    const slot = consts[i]._realizeSource();
    ctx.set(jaxpr.inBinders[i], { type: "imm", arg: builder.pushConst(slot) });
  }
  for (let i = 0; i < nargs; i++) {
    ctx.set(jaxpr.inBinders[consts.length + i], { type: "imm", arg: i });
  }

  for (const eqn of jaxpr.eqns) {
    const inputExps = [];
    const inputAvals = [];
    const inputArgs = [];
    // Kernel-local gid for program value `jitId`, registering it on first use.
    const gidOf = (jitId) => {
      let gid = inputArgs.indexOf(jitId);
      if (gid === -1) {
        gid = inputArgs.length;
        inputArgs.push(jitId);
      }
      return gid;
    };

    for (const input of eqn.inputs) {
      if (input instanceof Var) {
        const jitValue = lookup(input);
        if (jitValue.type === "exp") {
          // Re-number the fused expression's gids into this kernel's arg list.
          const gidMap = /* @__PURE__ */ new Map();
          for (const [gid, jitId] of jitValue.args.entries()) {
            gidMap.set(gid, gidOf(jitId));
          }
          inputExps.push(jitValue.exp.reindexGids(gidMap));
        } else if (jitValue.type === "imm") {
          const gid = gidOf(jitValue.arg);
          const st = ShapeTracker.fromShape(input.aval.shape);
          const indices = unravelAlu(st.shape, AluVar.gidx);
          inputExps.push(AluExp.globalView(input.aval.dtype, gid, st, indices));
        } else {
          throw new Error(`internal: Unknown JIT value type ${jitValue.type}`);
        }
        inputAvals.push(input.aval);
      } else if (input instanceof Lit) {
        inputExps.push(AluExp.const(input.dtype, input.value));
        inputAvals.push(input.aval);
      } else {
        throw new TypeError(`Unexpected input in Jaxpr: ${input}`);
      }
    }

    const rule = jitRules[eqn.primitive];
    if (!rule) {
      throw new TypeError(`JIT not implemented for primitive ${eqn.primitive}`);
    }
    const kernel = rule(inputArgs.length, inputExps, inputAvals, eqn.params);
    // Only the first out binder is tracked; equations are assumed single-output here.
    const outVar = eqn.outBinders[0];
    if (kernel.reduction || blackNodes.has(outVar)) {
      ctx.set(outVar, { type: "imm", arg: builder.pushKernel(kernel, inputArgs) });
    } else {
      ctx.set(outVar, { type: "exp", exp: kernel.exp, args: inputArgs });
    }
  }

  const outputIds = jaxpr.outs.map((out) => {
    if (out instanceof Var) {
      const jitValue = lookup(out);
      if (jitValue.type !== "imm") {
        throw new Error("internal: Expected imm, since outs are black nodes");
      }
      return jitValue.arg;
    }
    if (out instanceof Lit) return builder.pushLit(out);
    throw new Error(`internal: Unexpected output in Jaxpr: ${out}`);
  });

  builder.insertFreeSteps(outputIds);
  const jp = new JitProgram(backend, builder.steps, range(0, nargs), outputIds);
  jitCompileCache.set(cacheKey, jp);
  return jp;
}
726
/**
 * Wrap an elementwise expression builder as a JIT rule.
 *
 * The rule computes the broadcast shape of all input avals, rewrites every
 * GlobalView whose shape differs so that it is broadcast along new leading
 * axes, and builds a Kernel over the broadcast size using `fn(exps, params)`.
 */
function broadcastedJit(fn) {
  return (nargs, exps, avals, params) => {
    const newShape = avals.map((aval) => aval.shape).reduce(generalBroadcast);
    const expandView = (node) => {
      if (node.op !== "GlobalView") return undefined;
      let [gid, st] = node.arg;
      if (deepEqual(st.shape, newShape)) return undefined;
      st = st.broadcast(newShape, range(newShape.length - st.shape.length));
      const indices = unravelAlu(st.shape, AluVar.gidx);
      return AluExp.globalView(node.dtype, gid, st, indices);
    };
    const broadcasted = exps.map((exp) => exp.rewrite(expandView));
    return new Kernel(nargs, prod(newShape), fn(broadcasted, params));
  };
}
748
/**
 * Wrap a ShapeTracker transformation as a single-input JIT rule.
 *
 * Every GlobalView in the input expression has its tracker replaced by
 * `fn(st, params)` and its indices recomputed from the new shape; the kernel
 * size is the element count of the input's aval shape.
 */
function reshapeJit(fn) {
  return (nargs, [a], [as], params) => {
    const remapped = a.rewrite((node) => {
      if (node.op !== "GlobalView") return undefined;
      const [gid, st] = node.arg;
      const newSt = fn(st, params);
      return AluExp.globalView(
        node.dtype,
        gid,
        newSt,
        unravelAlu(newSt.shape, AluVar.gidx)
      );
    });
    return new Kernel(nargs, prod(as.shape), remapped);
  };
}
761
/**
 * JIT lowering rule for each primitive name.
 *
 * Each rule takes `(nargs, inputExps, inputAvals, params)` and returns a
 * Kernel. Elementwise primitives use `broadcastedJit`; pure view changes use
 * `reshapeJit`.
 */
var jitRules = {
  add: broadcastedJit(([a, b]) => AluExp.add(a, b)),
  mul: broadcastedJit(([a, b]) => AluExp.mul(a, b)),
  idiv: broadcastedJit(([a, b]) => AluExp.idiv(a, b)),
  neg: broadcastedJit(([a]) => AluExp.sub(AluExp.const(a.dtype, 0), a)),
  reciprocal: broadcastedJit(([a]) => AluExp.reciprocal(a)),
  sin: broadcastedJit(([a]) => AluExp.sin(a)),
  cos: broadcastedJit(([a]) => AluExp.cos(a)),
  exp: broadcastedJit(([a]) => AluExp.exp(a)),
  log: broadcastedJit(([a]) => AluExp.log(a)),
  min: broadcastedJit(([a, b]) => AluExp.min(a, b)),
  max: broadcastedJit(([a, b]) => AluExp.max(a, b)),

  /**
   * Sum over `axis`: permute the kept axes first, collapse the reduced axes
   * into a single trailing dimension indexed by `ridx`, and emit an "Add"
   * reduction over it.
   */
  reduce_sum(nargs, [a], [as], { axis }) {
    const keptAxes = [];
    const reducedAxes = [];
    const outShape = [];
    for (const [i, dim] of as.shape.entries()) {
      if (axis.includes(i)) {
        reducedAxes.push(i);
      } else {
        keptAxes.push(i);
        outShape.push(dim);
      }
    }
    const size2 = prod(outShape);
    const reductionSize = prod(reducedAxes.map((ax) => as.shape[ax]));
    const viewShape = [...outShape, reductionSize];
    const viewed = a.rewrite((node) => {
      if (node.op !== "GlobalView") return undefined;
      const [gid, st] = node.arg;
      const newSt = st.permute(keptAxes.concat(reducedAxes)).reshape(viewShape);
      const indices = unravelAlu(viewShape.slice(0, -1), AluVar.gidx);
      indices.push(AluVar.ridx);
      return AluExp.globalView(node.dtype, gid, newSt, indices);
    });
    const reduction = new Reduction(viewed.dtype, "Add", reductionSize);
    return new Kernel(nargs, size2, viewed, reduction);
  },

  compare: broadcastedJit(([a, b], { op }) => aluCompare(a, b, op)),
  where: broadcastedJit(([cond, a, b]) => AluExp.where(cond, a, b)),
  transpose: reshapeJit((st, { perm }) => st.permute(perm)),
  broadcast: reshapeJit((st, { shape: shape2, axis }) => st.broadcast(shape2, axis)),
  reshape: reshapeJit((st, { shape: shape2 }) => st.reshape(shape2)),
  flip: reshapeJit((st, { axis }) => {
    const flags = rep(st.shape.length, false);
    for (const ax of axis) flags[ax] = true;
    return st.flip(flags);
  }),
};
828
/**
 * Choose which Jaxpr variables are materialized ("black nodes") instead of
 * being fused into the kernels that consume them.
 *
 * Phase 1 (backward pass): Jaxpr outputs and `reduce_sum` results are black.
 * An equation whose results reach more than one distinct black node is also
 * made black. Presumably this is so its value is computed once rather than
 * fused into several kernels.
 *
 * Phase 2 (forward pass): tracks, for each fused variable, the set of
 * materialized values it reads. When an equation would need more than
 * `backend.maxArgs` distinct values, the input contributing the most
 * dependencies unique to it is made black. The scan then restarts from that
 * input's defining equation.
 *
 * @returns {Set} the black variables.
 * @throws {Error} (internal) if the limit cannot be met by blackening any
 *   remaining input.
 */
function splitGraphDataflow(backend, jaxpr) {
  const eqns = jaxpr.eqns;

  // Index of the equation defining each variable (input binders are absent).
  const varToEqn = /* @__PURE__ */ new Map();
  for (let i = 0; i < eqns.length; i++) {
    for (const v of eqns[i].outBinders) {
      if (v instanceof Var) varToEqn.set(v, i);
    }
  }

  const blackNodes = /* @__PURE__ */ new Set();
  // For each variable, the black node its value ends up feeding into.
  const p1NextBlack = /* @__PURE__ */ new Map();
  const markBlack = (v) => {
    blackNodes.add(v);
    p1NextBlack.set(v, v);
  };

  // Phase 1: walk equations backwards from the outputs.
  for (const v of jaxpr.outs) {
    if (v instanceof Var) markBlack(v);
  }
  for (let i = eqns.length - 1; i >= 0; i--) {
    const eqn = eqns[i];
    if (
      eqn.primitive === "reduce_sum" ||
      eqn.outBinders.some((v) => blackNodes.has(v))
    ) {
      for (const v of eqn.outBinders) markBlack(v);
      continue;
    }
    const reach = /* @__PURE__ */ new Set();
    for (let j = i + 1; j < eqns.length; j++) {
      const consumer = eqns[j];
      for (const v of consumer.inputs) {
        if (!(v instanceof Var) || !eqn.outBinders.includes(v)) continue;
        for (const o of consumer.outBinders) {
          const target = p1NextBlack.get(o);
          if (target) reach.add(target);
        }
      }
    }
    if (reach.size === 1) {
      const [target] = reach;
      for (const v of eqn.outBinders) p1NextBlack.set(v, target);
    } else if (reach.size > 1) {
      for (const v of eqn.outBinders) markBlack(v);
    }
  }

  // Phase 2: enforce the backend's per-kernel argument limit.
  const p2Deps = /* @__PURE__ */ new Map();
  for (const v of jaxpr.inBinders) {
    p2Deps.set(v, /* @__PURE__ */ new Set([v]));
  }
  let p2idx = 0;
  while (p2idx < eqns.length) {
    const eqn = eqns[p2idx++];
    if (eqn.outBinders.some((v) => blackNodes.has(v))) continue;

    const deps = eqn.inputs.map((input) => {
      if (!(input instanceof Var)) return /* @__PURE__ */ new Set();
      return blackNodes.has(input)
        ? /* @__PURE__ */ new Set([input])
        : p2Deps.get(input);
    });
    const depCounter = /* @__PURE__ */ new Map();
    for (const depSet of deps) {
      for (const dep of depSet) {
        depCounter.set(dep, (depCounter.get(dep) ?? 0) + 1);
      }
    }

    if (depCounter.size > backend.maxArgs) {
      let maxUniqueDeps = 0;
      let assocInput = -1;
      for (let k = 0; k < eqn.inputs.length; k++) {
        const input = eqn.inputs[k];
        // Only inputs defined by an equation that is not already black can be
        // split off. Re-marking an already-black input changes nothing, so the
        // rescan would pick it again forever.
        if (
          !(input instanceof Var) ||
          !varToEqn.has(input) ||
          blackNodes.has(input)
        ) {
          continue;
        }
        let uniqueDeps = 0;
        for (const dep of deps[k]) {
          if (depCounter.get(dep) === 1) uniqueDeps++;
        }
        if (uniqueDeps > maxUniqueDeps) {
          maxUniqueDeps = uniqueDeps;
          assocInput = k;
        }
      }
      if (assocInput === -1) {
        throw new Error(
          `internal: maxArgs, no input found to mark as black in Jaxpr equation ${eqn}`
        );
      }
      // Blacken the chosen input's defining equation and rescan from there,
      // since dependency sets after it may now shrink.
      p2idx = varToEqn.get(eqn.inputs[assocInput]);
      for (const out of eqns[p2idx].outBinders) {
        blackNodes.add(out);
      }
    } else {
      const s = new Set(depCounter.keys());
      for (const out of eqn.outBinders) p2Deps.set(out, s);
    }
  }
  return blackNodes;
}
932
+
933
+ // src/frontend/array.ts
934
// The built-in Array constructor under a separate name. NOTE(review):
// presumably because this module defines its own `Array` class — confirm.
var JsArray = globalThis.Array;
// Element-count threshold; its use is outside this view (presumably a
// cut-off for inlining array data) — confirm at the call sites.
var inlineArrayLimit = 1000;
// Alias for `pureArray`, which is defined elsewhere in the module.
var fudgeArray = pureArray;
937
/**
 * A kernel dispatch that has been scheduled but not yet submitted.
 *
 * The constructor takes a backend reference on every input slot. Those
 * references are released when the execute is submitted, or when its own
 * reference count drops to zero before submission.
 */
var PendingExecute = class {
  /** Result of `backend.prepare`/`prepareSync`, or null until prepared. */
  prepared = null;
  /** Whether `submit()` has dispatched this execute. */
  submitted = false;
  /** In-flight async prepare, shared by concurrent `prepare()` calls. */
  #promise = null;
  /** Number of holders; the object may be shared by several arrays. */
  #rc = 1;

  constructor(backend, kernel, inputs, outputs) {
    this.backend = backend;
    this.kernel = kernel;
    this.inputs = inputs;
    this.outputs = outputs;
    for (const slot of inputs) this.backend.incRef(slot);
  }

  /** Drop the references taken on the input slots. */
  #releaseInputs() {
    for (const slot of this.inputs) this.backend.decRef(slot);
  }

  /**
   * Adjust the holder count by `delta` (e.g. +1 when copied to a new array,
   * -1 when an array is disposed). Reaching zero before submission releases
   * the inputs.
   */
  updateRc(delta) {
    if (this.#rc <= 0) throw new Error("internal: PendingExecute used rc<=0");
    this.#rc += delta;
    if (this.#rc <= 0 && !this.submitted) this.#releaseInputs();
  }

  /** Asynchronously prepare the kernel, at most once. */
  async prepare() {
    if (this.prepared) return;
    if (!this.#promise) {
      this.#promise = (async () => {
        this.prepared = await this.backend.prepare(this.kernel);
      })();
    }
    await this.#promise;
  }

  /** Synchronously prepare the kernel if it has not been prepared yet. */
  prepareSync() {
    if (this.prepared) return;
    this.prepared = this.backend.prepareSync(this.kernel);
  }

  /** Dispatch the prepared kernel once, then release the input references. */
  submit() {
    if (this.submitted) return;
    if (this.#rc <= 0) throw new Error("internal: PendingExecute used rc<=0");
    if (!this.prepared) throw new Error("internal: Not prepared yet");
    this.submitted = true;
    this.backend.dispatch(this.prepared, this.inputs, this.outputs);
    this.#releaseInputs();
  }
};
984
var Array3 = class _Array extends Tracer {
  static #nextId = 1001;
  // For unique hashing where needed.
  id;
  #dtype;
  #source;
  #st;
  #backend;
  // Reference count for this specific Array object.
  #rc;
  // Pending executions producing this array's data (only if source is a `Slot`).
  #pendingSet;
  constructor(source, st, dtype, backend, pending = null) {
    super(baseArrayTrace);
    this.id = _Array.#nextId++;
    this.#dtype = dtype;
    this.#source = source;
    this.#st = st;
    this.#backend = backend;
    this.#rc = 1;
    this.#pendingSet = new Set(pending);
    // Slot-backed arrays take one backend reference on their slot, released in dispose().
    if (!(source instanceof AluExp)) {
      backend.incRef(source);
    }
  }
  get aval() {
    return new ShapedArray(this.#st.shape, this.#dtype);
  }
  /** Return a simple string representation of the array's dimensions. */
  toString() {
    return `Array:${this.#dtype}[${this.shape.join(",")}]`;
  }
  get device() {
    return this.#backend.type;
  }
  #check() {
    if (this.#rc <= 0) throw new UseAfterFreeError(this);
  }
  /** Take one more reference to this array and return it. */
  get ref() {
    this.#check();
    this.#rc++;
    return this;
  }
  /**
   * Drop one reference to this array. When the last reference goes away, the
   * pending executions and the backend slot held by this array are released.
   */
  dispose() {
    this.#check();
    // Pre-decrement: #check guarantees the count was positive, so the *new*
    // value must be compared with 0. A post-decrement comparison could never
    // be true and would leak every slot and pending execution.
    if (--this.#rc === 0) {
      for (const exe of this.#pending) exe.updateRc(-1);
      if (typeof this.#source === "number") {
        this.#backend.decRef(this.#source);
      }
    }
  }
  /** Get the pending executes as a list, trimming if already submitted. */
  get #pending() {
    if (!this.#pendingSet) return [];
    for (const p of this.#pendingSet) {
      if (p.submitted) this.#pendingSet.delete(p);
    }
    if (this.#pendingSet.size === 0) {
      this.#pendingSet = null;
      return [];
    } else {
      return [...this.#pendingSet];
    }
  }
  /**
   * Convert this array into a primitive value.
   *
   * This only works for scalars (0-dimensional arrays). It lets you get values
   * "out" of the JAX system. For instance, if `x = np.array(5)`, then you can
   * evaluate `x + 1` and `x ** 2` to get `6` and `25`, respectively.
   *
   * This method is also called for `==` equality.
   */
  [Symbol.toPrimitive]() {
    if (this.ndim === 0) {
      return this.dataSync()[0];
    } else {
      throw new Error(
        `Cannot convert non-scalar array to primitive: ${this.toString()}`
      );
    }
  }
  /** Create a view with a new ShapeTracker over the same source (consumes this). */
  #reshape(st) {
    this.#check();
    const pending = this.#pending;
    for (const exe of pending) exe.updateRc(1);
    const ar = new _Array(this.#source, st, this.#dtype, this.#backend, pending);
    this.dispose();
    return ar;
  }
  /** Move axes to the rightmost dimension of the shape. */
  #moveAxesDown(axis) {
    this.#check();
    if (axis.length === 0) return this.reshape(this.shape.concat(1));
    const newShape = [];
    const keptAxes = [];
    const shiftedAxes = [];
    for (let i = 0; i < this.#st.shape.length; i++) {
      if (axis.includes(i)) {
        shiftedAxes.push(i);
      } else {
        keptAxes.push(i);
        newShape.push(this.#st.shape[i]);
      }
    }
    newShape.push(-1);
    return this.#transpose(keptAxes.concat(shiftedAxes)).reshape(newShape);
  }
  #transpose(perm) {
    this.#check();
    if (!isPermutation(perm, this.ndim))
      throw new Error(`Invalid perm for transpose: ${JSON.stringify(perm)}`);
    return this.#reshape(this.#st.permute(perm));
  }
  #unary(op) {
    this.#check();
    if (this.#source instanceof AluExp) {
      const exp4 = new AluExp(op, this.#dtype, [this.#source]);
      return new _Array(exp4, this.#st, this.#dtype, this.#backend);
    }
    const indices = unravelAlu(this.#st.shape, AluVar.gidx);
    const exp3 = new AluExp(op, this.#dtype, [
      AluExp.globalView(this.#dtype, 0, this.#st, indices)
    ]);
    const kernel = new Kernel(1, this.#st.size, exp3);
    const output = this.#backend.malloc(kernel.size * 4);
    const pending = [...this.#pending];
    for (const exe of pending) exe.updateRc(1);
    pending.push(
      new PendingExecute(this.#backend, kernel, [this.#source], [output])
    );
    this.dispose();
    return new _Array(
      output,
      ShapeTracker.fromShape(this.shape),
      this.#dtype,
      this.#backend,
      pending
    );
  }
  #binary(op, other) {
    const custom = (src) => new AluExp(op, this.#dtype, src);
    return _Array.#naryCustom(op, custom, [this, other]);
  }
  static #naryCustom(name, custom, arrays, dtypeOverride, dtypeOutput) {
    const n = arrays.length;
    // Validate before touching arrays[0], so empty input reports the intended error.
    if (n === 0) throw new TypeError(`No inputs for ${name}`);
    const backend = arrays[0].#backend;
    for (const ar of arrays) ar.#check();
    let dtype;
    for (let i = 0; i < n; i++) {
      if (dtypeOverride?.[i]) {
        if (arrays[i].#dtype !== dtypeOverride[i]) {
          throw new TypeError(
            `Wrong dtype in ${name}: expected ${dtypeOverride[i]}, got ${arrays[i].#dtype}`
          );
        }
      } else {
        if (!dtype) dtype = arrays[i].#dtype;
        else if (arrays[i].#dtype !== dtype) {
          throw new TypeError(
            `Dtype mismatch in ${name}: ${dtype} vs ${arrays[i].#dtype}`
          );
        }
      }
      if (arrays[i].#backend !== backend) {
        throw new TypeError(
          `Backend mismatch in ${name}: ${backend.type} vs ${arrays[i].#backend.type}`
        );
      }
    }
    dtypeOutput ??= dtype;
    if (!dtypeOutput) throw new TypeError("nary operation with no dtype");
    const newShape = arrays.map((a) => a.shape).reduce(generalBroadcast);
    arrays = arrays.map((ar) => {
      if (deepEqual(ar.shape, newShape)) return ar;
      return ar.#reshape(
        ar.#st.broadcast(newShape, range(newShape.length - ar.ndim))
      );
    });
    if (arrays.every((ar) => ar.#source instanceof AluExp)) {
      if (arrays.every((ar) => deepEqual(ar.#st, arrays[0].#st))) {
        const exp5 = custom(arrays.map((ar) => ar.#source));
        return new _Array(exp5, arrays[0].#st, exp5.dtype, backend);
      }
      const exp4 = custom(
        arrays.map((ar) => {
          const src2 = ar.#source;
          if (ar.#st.contiguous) return src2;
          return accessorAluExp(src2, ar.#st, unravelAlu(newShape, AluVar.idx));
        })
      );
      const st = ShapeTracker.fromShape(newShape);
      return new _Array(exp4, st, exp4.dtype, backend);
    }
    const inputs = [];
    const src = [];
    for (const ar of arrays) {
      const indices = unravelAlu(newShape, AluVar.gidx);
      if (ar.#source instanceof AluExp) {
        src.push(accessorAluExp(ar.#source, ar.#st, indices));
      } else {
        let gid = inputs.indexOf(ar.#source);
        if (gid === -1) {
          gid = inputs.length;
          inputs.push(ar.#source);
        }
        src.push(AluExp.globalView(ar.#dtype, gid, ar.#st, indices));
      }
    }
    const exp3 = custom(src);
    const kernel = new Kernel(inputs.length, arrays[0].#st.size, exp3);
    const output = backend.malloc(kernel.size * 4);
    // Deduplicate: the new array is a single holder of each pending execute,
    // even when several inputs (e.g. `x.ref` and `x`) share it.
    const pending = [...new Set(arrays.flatMap((ar) => ar.#pending))];
    for (const exe of pending) exe.updateRc(1);
    pending.push(new PendingExecute(backend, kernel, inputs, [output]));
    for (const ar of arrays) ar.dispose();
    return new _Array(
      output,
      ShapeTracker.fromShape(newShape),
      dtypeOutput,
      backend,
      pending
    );
  }
  /** Reduce the last dimension of the array by an operation. */
  #reduce(op) {
    this.#check();
    if (this.ndim === 0) throw new Error("Cannot reduce a scalar");
    const shape2 = this.shape;
    const reduction = new Reduction(this.#dtype, op, shape2[shape2.length - 1]);
    const newShape = shape2.slice(0, -1);
    const newSize = prod(newShape);
    const indices = [...unravelAlu(newShape, AluVar.gidx), AluVar.ridx];
    const [index, valid] = this.#st.toAluExp(indices);
    // Out-of-view elements contribute zero, in the array's own dtype
    // (previously this was hard-coded to float32, which broke int32 reductions).
    const zero = AluExp.const(this.#dtype, 0);
    let exp3;
    const inputs = [];
    if (this.#source instanceof AluExp) {
      exp3 = AluExp.where(valid, this.#source.substitute({ idx: index }), zero);
    } else {
      inputs.push(this.#source);
      exp3 = AluExp.where(
        valid,
        AluExp.globalIndex(this.#dtype, 0, index),
        zero
      );
    }
    const kernel = new Kernel(inputs.length, newSize, exp3, reduction);
    const output = this.#backend.malloc(kernel.size * 4);
    const pending = [...this.#pending];
    for (const exe of pending) exe.updateRc(1);
    pending.push(new PendingExecute(this.#backend, kernel, inputs, [output]));
    this.dispose();
    return new _Array(
      output,
      ShapeTracker.fromShape(newShape),
      this.#dtype,
      this.#backend,
      pending
    );
  }
  /**
   * Normalizes this array into one backed by a `Slot`.
   *
   * This mutates the array in-place, turning it into an equivalent array whose
   * source is actual, contiguous data on device.
   *
   * Calling this twice is a no-op.
   */
  #realize() {
    this.#check();
    const indices = unravelAlu(this.#st.shape, AluVar.gidx);
    if (this.#source instanceof AluExp) {
      const exp3 = accessorAluExp(this.#source, this.#st, indices);
      const kernel = new Kernel(0, this.#st.size, exp3);
      const output = this.#backend.malloc(kernel.size * 4);
      const pendingItem = new PendingExecute(
        this.#backend,
        kernel,
        [],
        [output]
      );
      this.#source = output;
      this.#st = ShapeTracker.fromShape(this.shape);
      this.#pendingSet = /* @__PURE__ */ new Set([pendingItem]);
    } else {
      if (this.#st.contiguous) return;
      const previous = this.#source;
      const exp3 = accessorGlobal(this.#dtype, 0, this.#st, indices);
      const kernel = new Kernel(1, this.#st.size, exp3);
      const output = this.#backend.malloc(kernel.size * 4);
      // The PendingExecute takes its own reference on `previous`...
      const pendingItem = new PendingExecute(
        this.#backend,
        kernel,
        [previous],
        [output]
      );
      // ...so this array drops the reference its constructor took, since it
      // now refers to `output` instead.
      this.#backend.decRef(previous);
      // NOTE(review): `output` is not incRef'd here, unlike the malloc'd slots
      // passed to the constructor elsewhere. Confirm malloc's initial refcount
      // semantics.
      this.#source = output;
      this.#st = ShapeTracker.fromShape(this.shape);
      this.#pendingSet ??= /* @__PURE__ */ new Set();
      this.#pendingSet.add(pendingItem);
    }
  }
  #dataInline() {
    this.#check();
    const exp3 = this.#source;
    const ar = new _Array(exp3, this.#st, this.dtype, getBackend("cpu"));
    this.dispose();
    return ar.dataSync();
  }
  /** Realize the array and return it as data (consumes this reference). */
  async data() {
    if (this.#source instanceof AluExp && prod(this.shape) < inlineArrayLimit && this.device !== "cpu") {
      return this.#dataInline();
    }
    this.#realize();
    const pending = this.#pending;
    await Promise.all(pending.map((p) => p.prepare()));
    for (const p of pending) p.submit();
    const buf = await this.#backend.read(this.#source);
    this.dispose();
    return this.dtype === "float32" /* Float32 */ ? new Float32Array(buf) : new Int32Array(buf);
  }
  /** Wait for this array to be placed on the backend, if needed. */
  async wait() {
    this.#check();
    // NOTE(review): this early return does not dispose, unlike the slot path
    // below. Confirm which ownership semantics `wait()` should have.
    if (this.#source instanceof AluExp) return;
    const pending = this.#pending;
    await Promise.all(pending.map((p) => p.prepare()));
    for (const p of pending) p.submit();
    await this.#backend.read(this.#source, 0, 0);
    this.dispose();
  }
  /**
   * Realize the array and return it as data. This is a sync variant and not
   * recommended for performance reasons, as it will block rendering.
   */
  dataSync() {
    if (this.#source instanceof AluExp && prod(this.shape) < inlineArrayLimit && this.device !== "cpu") {
      return this.#dataInline();
    }
    this.#realize();
    for (const p of this.#pending) {
      p.prepareSync();
      p.submit();
    }
    const buf = this.#backend.readSync(this.#source);
    this.dispose();
    return this.dtype === "float32" /* Float32 */ ? new Float32Array(buf) : new Int32Array(buf);
  }
  /** Convert this array into a JavaScript object (blocking). */
  js() {
    return dataToJs(this.dtype, this.dataSync(), this.shape);
  }
  /** Convert this array into a JavaScript object, asynchronously. */
  async jsAsync() {
    return dataToJs(this.dtype, await this.data(), this.shape);
  }
  /** @private Internal plumbing method for Array / Tracer ops. */
  static _implRules() {
    return {
      ["add" /* Add */]([x, y]) {
        return [x.#binary("Add" /* Add */, y)];
      },
      ["mul" /* Mul */]([x, y]) {
        return [x.#binary("Mul" /* Mul */, y)];
      },
      ["idiv" /* Idiv */]([x, y]) {
        return [x.#binary("Idiv" /* Idiv */, y)];
      },
      ["neg" /* Neg */]([x]) {
        return [zerosLike(x).#binary("Sub" /* Sub */, x)];
      },
      ["reciprocal" /* Reciprocal */]([x]) {
        return [x.#unary("Reciprocal" /* Reciprocal */)];
      },
      ["sin" /* Sin */]([x]) {
        return [x.#unary("Sin" /* Sin */)];
      },
      ["cos" /* Cos */]([x]) {
        return [x.#unary("Cos" /* Cos */)];
      },
      ["exp" /* Exp */]([x]) {
        return [x.#unary("Exp" /* Exp */)];
      },
      ["log" /* Log */]([x]) {
        return [x.#unary("Log" /* Log */)];
      },
      ["min" /* Min */]([x, y]) {
        return [x.#binary("Min" /* Min */, y)];
      },
      ["max" /* Max */]([x, y]) {
        return [x.#binary("Max" /* Max */, y)];
      },
      ["reduce_sum" /* ReduceSum */]([x], { axis }) {
        if (axis.length === 0) return [x];
        return [x.#moveAxesDown(axis).#reduce("Add" /* Add */)];
      },
      ["compare" /* Compare */]([x, y], { op }) {
        const custom = ([x2, y2]) => aluCompare(x2, y2, op);
        return [_Array.#naryCustom("compare", custom, [x, y], [], "bool" /* Bool */)];
      },
      ["where" /* Where */]([cond, x, y]) {
        const custom = ([cond2, x2, y2]) => AluExp.where(cond2, x2, y2);
        return [_Array.#naryCustom("where", custom, [cond, x, y], ["bool" /* Bool */])];
      },
      ["transpose" /* Transpose */]([x], { perm }) {
        return [x.#transpose(perm)];
      },
      ["broadcast" /* Broadcast */]([x], { shape: shape2, axis }) {
        return [x.#reshape(x.#st.broadcast(shape2, axis))];
      },
      ["reshape" /* Reshape */]([x], { shape: shape2 }) {
        return [x.#reshape(x.#st.reshape(shape2))];
      },
      ["flip" /* Flip */]([x], { axis }) {
        const arg = rep(x.ndim, false);
        for (const ax of axis) arg[ax] = true;
        return [x.#reshape(x.#st.flip(arg))];
      },
      ["jit_call" /* JitCall */](args, { jaxpr, numConsts }) {
        if (jaxpr.inBinders.length !== args.length) {
          throw new Error(
            `jit_call expects ${jaxpr.inBinders.length} args, got ${args.length}`
          );
        }
        const backend = getBackend();
        const consts = args.slice(0, numConsts);
        const tracers = args.slice(numConsts);
        const jp = jitCompile(backend, jaxpr, consts);
        const { outputs, pending } = jp.execute(
          tracers.map((x) => x._realizeSource())
        );
        // Deduplicate so each distinct execute gets exactly one new holder.
        const prevPending = [...new Set(args.flatMap((x) => x.#pending))];
        for (const exe of prevPending) exe.updateRc(1);
        pending.splice(0, 0, ...prevPending);
        args.forEach((x) => x.dispose());
        return outputs.map((source, i) => {
          // Every output array holds (and later disposes) its own reference on
          // each pending execute; the first output uses the existing one.
          if (i > 0) for (const exe of pending) exe.updateRc(1);
          return new _Array(
            source,
            ShapeTracker.fromShape(jaxpr.outs[i].aval.shape),
            jaxpr.outs[i].aval.dtype,
            backend,
            pending
          );
        });
      }
    };
  }
  // Internal methods, not public API. Do not use.
  _realizeSource() {
    this.#realize();
    return this.#source;
  }
};
1447
/**
 * Create a 0-dimensional array holding a single number or boolean.
 *
 * @param {number|boolean} value - The scalar value.
 * @param {{dtype?: string, device?: string}} [options] - Numbers default to
 *   float32 and booleans to bool.
 * @returns {Array3} A scalar array.
 * @throws {TypeError} on an unsupported value type or dtype combination.
 */
function scalar(value, { dtype, device } = {}) {
  if (typeof value === "number") {
    dtype ??= "float32" /* Float32 */;
    if (!["float32" /* Float32 */, "int32" /* Int32 */].includes(dtype))
      throw new TypeError(`Mismatched dtype for scalar ${value}`);
  } else if (typeof value === "boolean") {
    dtype ??= "bool" /* Bool */;
    if (!["float32" /* Float32 */, "int32" /* Int32 */, "bool" /* Bool */].includes(dtype))
      throw new TypeError(`Mismatched dtype for scalar ${value}`);
  } else {
    throw new TypeError(`Invalid type for scalar ${value}`);
  }
  // Booleans are encoded numerically (1/0), the same way `full()` does, so
  // AluExp constants always carry numbers.
  const numeric = typeof value === "boolean" ? (value ? 1 : 0) : value;
  return new Array3(
    AluExp.const(dtype, numeric),
    ShapeTracker.fromShape([]),
    dtype,
    getBackend(device)
  );
}
1466
/**
 * Create an array from an existing Array, a typed array, or nested JS arrays.
 *
 * @param values - Array3, Float32Array/Int32Array, or (nested) JS arrays of
 *   numbers or booleans.
 * @param {{shape?: number[], dtype?: string, device?: string}} [options]
 * @returns {Array3}
 * @throws {Error} on jagged nested input or unsupported conversions.
 */
function array(values, { shape: shape2, dtype, device } = {}) {
  if (values instanceof Array3) {
    if (shape2 && !deepEqual(values.shape, shape2)) {
      values = values.reshape(shape2);
    }
    if (dtype && values.dtype !== dtype) {
      throw new Error("array astype not implemented yet");
    }
    return values;
  } else if (values instanceof Float32Array || values instanceof Int32Array) {
    return arrayFromData(values, shape2 ?? [values.length], {
      dtype,
      device
    });
  } else {
    if (!shape2) {
      shape2 = [];
      let cur = values;
      while (JsArray.isArray(cur)) {
        shape2.push(cur.length);
        cur = cur[0];
      }
    }
    const size2 = prod(shape2);
    const flat = recursiveFlatten(values);
    if (flat.length !== size2) {
      throw new Error(
        `Jagged shape: ${JSON.stringify(shape2)} vs ${flat.length}`
      );
    }
    if (size2 === 0) return zeros(shape2, { dtype, device });
    const isBool = typeof flat[0] === "boolean";
    dtype ??= isBool ? "bool" /* Bool */ : "float32" /* Float32 */;
    // Encode booleans (or anything requested as bool) as 1/0.
    const numbers = isBool || dtype === "bool" /* Bool */ ? flat.map((x) => x ? 1 : 0) : flat;
    // Choose the container that matches the requested dtype, so that
    // arrayFromData accepts it (e.g. `array([1, 2], { dtype: "int32" })`).
    const data = dtype === "float32" /* Float32 */ ? new Float32Array(numbers) : new Int32Array(numbers);
    return arrayFromData(data, shape2, { dtype, device });
  }
}
1508
/**
 * Wrap a typed array as an Array3, uploading it to the backend.
 *
 * Small arrays whose elements are all equal are represented inline as a
 * constant expression rather than uploaded.
 *
 * @param {Float32Array|Int32Array} data - Flat row-major data.
 * @param {number[]} shape2 - Target shape.
 * @param {{dtype?: string, device?: string}} [options] - Float32Array data
 *   requires float32; Int32Array data accepts int32 or bool (default int32).
 * @returns {Array3}
 * @throws {Error} on a dtype/container mismatch or an unsupported container.
 */
function arrayFromData(data, shape2, { dtype, device } = {}) {
  // The dtype implied by the container when none is requested. It is used on
  // the inline path, which previously always fell back to float32.
  const impliedDtype = dtype ?? (data instanceof Int32Array ? "int32" /* Int32 */ : "float32" /* Float32 */);
  if (data.length === 0) {
    // `full` cannot take `data[0]` of an empty array, so build zeros directly.
    return zeros(shape2, { dtype: impliedDtype, device });
  }
  if (data.length < inlineArrayLimit) {
    let allEqual = true;
    for (let i = 1; i < data.length; i++) {
      if (data[i] !== data[0]) {
        allEqual = false;
        break;
      }
    }
    if (allEqual) {
      return full(shape2, data[0], { dtype: impliedDtype, device });
    }
  }
  const backend = getBackend(device);
  if (data instanceof Float32Array) {
    if (dtype && dtype !== "float32" /* Float32 */) {
      throw new Error("Float32Array must have float32 type");
    }
    const slot = backend.malloc(data.byteLength, data.buffer);
    return new Array3(
      slot,
      ShapeTracker.fromShape(shape2),
      "float32" /* Float32 */,
      backend
    );
  } else if (data instanceof Int32Array) {
    if (dtype && dtype !== "int32" /* Int32 */ && dtype !== "bool" /* Bool */) {
      throw new Error("Int32Array must have int32 or bool type");
    }
    const slot = backend.malloc(data.byteLength, data.buffer);
    return new Array3(
      slot,
      ShapeTracker.fromShape(shape2),
      dtype ?? "int32" /* Int32 */,
      backend
    );
  } else {
    throw new Error("Unsupported data type");
  }
}
1548
/**
 * Convert flat row-major data into nested JavaScript arrays of the given shape.
 *
 * @param {string} dtype - Element dtype. "bool" values are converted with Boolean().
 * @param data - Flat indexable data supporting `.slice`.
 * @param {number[]} shape2 - Target shape; `[]` yields a bare scalar.
 * @returns A scalar or nested arrays.
 */
function dataToJs(dtype, data, shape2) {
  if (shape2.length === 0) {
    const value = data[0];
    return dtype === "bool" /* Bool */ ? Boolean(value) : value;
  }
  const [outer, ...inner] = shape2;
  const chunk = prod(inner);
  const rows = [];
  for (let row = 0, offset = 0; row < outer; row++, offset += chunk) {
    rows.push(dataToJs(dtype, data.slice(offset, offset + chunk), inner));
  }
  return rows;
}
1561
/** Pass tracers through unchanged; wrap plain JS values as scalar arrays. */
function pureArray(x) {
  return x instanceof Tracer ? x : scalar(x);
}
1568
/**
 * Eager evaluation trace: primitives are executed immediately by looking up
 * their implementation in `implRules`.
 */
var EvalTrace = class extends Trace {
  // Values here are already concrete arrays, so no boxing in Tracers is needed.
  pure = (x) => pureArray(x);
  lift = (x) => x;
  processPrimitive(primitive, tracers, params) {
    const result = implRules[primitive](tracers, params);
    return result;
  }
};
// Root trace that every concrete Array3 is attached to.
var baseArrayTrace = new EvalTrace(newMain(EvalTrace, null));
// Primitive name -> eager implementation on Array3.
var implRules = Array3._implRules();
1578
/**
 * Create a zero-filled array with the same shape and dtype as `val`.
 *
 * When `val` is a concrete Array3, the result is placed on the same device.
 * Otherwise, a binary op such as `neg` (computed as `0 - x`) would mix
 * backends and fail with "Backend mismatch".
 *
 * @param val - An array, tracer, or value accepted by `getAval`.
 * @returns {Array3}
 */
function zerosLike(val) {
  const aval = getAval(val);
  const device = val instanceof Array3 ? val.device : void 0;
  return zeros(aval.shape, { dtype: aval.dtype, device });
}
1582
/** Create an array of the given shape filled with 0 (float32 by default). */
function zeros(shape2, { dtype, device } = {}) {
  const options = { dtype, device };
  return full(shape2, 0, options);
}
/** Create an array of the given shape filled with 1 (float32 by default). */
function ones(shape2, { dtype, device } = {}) {
  const options = { dtype, device };
  return full(shape2, 1, options);
}
1588
/**
 * Create an array of the given shape where every element is `fillValue`.
 *
 * Numbers default to float32, and booleans default to bool (stored as 1/0).
 *
 * @throws {Error} for array fill values (not implemented yet).
 * @throws {TypeError} for any other unsupported fill value.
 */
function full(shape2, fillValue, { dtype, device } = {}) {
  let resolvedDtype;
  let constant;
  switch (typeof fillValue) {
    case "number":
      resolvedDtype = dtype ?? "float32" /* Float32 */;
      constant = fillValue;
      break;
    case "boolean":
      resolvedDtype = dtype ?? "bool" /* Bool */;
      constant = fillValue ? 1 : 0;
      break;
    default:
      if (fillValue instanceof Array3) {
        throw new Error("numpy.full() with array argument not implemented yet");
      }
      throw new TypeError(`Invalid type for full: ${fillValue}`);
  }
  return new Array3(
    AluExp.const(resolvedDtype, constant),
    ShapeTracker.fromShape(shape2),
    resolvedDtype,
    getBackend(device)
  );
}
1608
/**
 * Create a 2-D array with ones on the main diagonal and zeros elsewhere.
 *
 * @param {number} numRows - Number of rows.
 * @param {number} [numCols] - Number of columns (defaults to `numRows`).
 * @param {{dtype?: string, device?: string}} [options] - dtype defaults to float32.
 */
function eye(numRows, numCols, { dtype, device } = {}) {
  const cols = numCols ?? numRows;
  const dt = dtype ?? "float32" /* Float32 */;
  // A tall matrix is built as the transpose of the corresponding wide one.
  if (cols < numRows) return eye(cols, numRows, { dtype: dt, device }).transpose();
  if (numRows === 0) return zeros([0, cols], { dtype: dt, device });
  // Row-major flat index r*cols + c sits on the diagonal exactly when it is a
  // multiple of (cols + 1).
  const onDiagonal = AluExp.cmplt(
    AluExp.mod(AluVar.idx, AluExp.i32(cols + 1)),
    AluExp.i32(1)
  );
  return new Array3(
    AluExp.cast(dt, onDiagonal),
    ShapeTracker.fromShape([numRows, cols]),
    dt,
    getBackend(device)
  );
}
1629
/** Create an `n` x `n` identity matrix. */
function identity(n, { dtype, device } = {}) {
  const options = { dtype, device };
  return eye(n, n, options);
}
1632
/**
 * Return evenly spaced values in the half-open interval [start, stop).
 *
 * Called with a single argument, `arange(n)` yields 0..n-1. When no dtype is
 * given, the result is int32 if start, stop and step are all integers, and
 * float32 otherwise (e.g. `arange(0, 1, 0.25)`), like NumPy/JAX. Previously
 * fractional inputs were silently encoded as int32 constants.
 *
 * @throws {RangeError} if `step` is 0.
 */
function arange(start, stop, step = 1, { dtype, device } = {}) {
  if (stop === void 0) {
    stop = start;
    start = 0;
  }
  dtype ??= [start, stop, step].every((v) => Number.isInteger(v)) ? "int32" /* Int32 */ : "float32" /* Float32 */;
  if (step === 0) {
    throw new RangeError(
      `Invalid step for arange: ${step}. Step must be non-zero.`
    );
  }
  const size2 = Math.max(0, Math.ceil((stop - start) / step));
  if (size2 === 0) {
    return zeros([0], { dtype, device });
  }
  // value[i] = start + i * step
  const exp3 = AluExp.add(
    AluExp.const(dtype, start),
    AluExp.mul(AluExp.cast(dtype, AluVar.idx), AluExp.const(dtype, step))
  );
  const st = ShapeTracker.fromShape([size2]);
  return new Array3(exp3, st, dtype, getBackend(device));
}
1654
/**
 * Return `num` evenly spaced samples from `start` to `stop`.
 *
 * The samples are computed in float32 and then cast to `dtype` (default
 * float32). With `endpoint` false, `stop` is excluded.
 *
 * @throws {RangeError} if `num` is not a non-negative integer.
 */
function linspace(start, stop, num = 50, endpoint = true, { dtype, device } = {}) {
  const dt = dtype ?? "float32" /* Float32 */;
  if (num < 0 || !Number.isInteger(num)) {
    throw new RangeError(
      `Invalid num for linspace: ${num}. Must be non-negative integer.`
    );
  }
  if (num === 0) return zeros([0], { dtype: dt, device });
  if (num === 1) return scalar(start, { dtype: dt, device }).reshape([1]);
  if (start === stop) return full([num], start, { dtype: dt, device });
  const stepSize = (stop - start) / (endpoint ? num - 1 : num);
  // value[i] = start + i * stepSize
  const value = AluExp.add(
    AluExp.f32(start),
    AluExp.mul(AluExp.f32(stepSize), AluExp.cast("float32" /* Float32 */, AluVar.idx))
  );
  return new Array3(
    AluExp.cast(dt, value),
    ShapeTracker.fromShape([num]),
    dt,
    getBackend(device)
  );
}
1682
/**
 * Build a boolean AluExp comparing `a` and `b` with the given comparison op.
 *
 * Every op is expressed with `cmplt`/`cmpne`. `greater` and `greater_equal`
 * are written as `b < a` (plus equality) rather than `!(a < b)`, so that a NaN
 * operand yields false, as it does for `less`/`less_equal`.
 *
 * @throws {TypeError} for an unknown comparison op.
 */
function aluCompare(a, b, op) {
  switch (op) {
    case "greater" /* Greater */:
      return AluExp.cmplt(b, a);
    case "less" /* Less */:
      return AluExp.cmplt(a, b);
    case "equal" /* Equal */:
      return AluExp.cmpne(a, b).not();
    case "not_equal" /* NotEqual */:
      return AluExp.cmpne(a, b);
    case "greater_equal" /* GreaterEqual */:
      return AluExp.add(AluExp.cmplt(b, a), AluExp.cmpne(a, b).not());
    case "less_equal" /* LessEqual */:
      return AluExp.add(AluExp.cmplt(a, b), AluExp.cmpne(a, b).not());
    default:
      throw new TypeError(`Unknown comparison op: ${op}`);
  }
}
1698
/**
 * Compute the NumPy-style broadcast of two shapes.
 *
 * Shapes are aligned on their trailing axes. Missing leading axes act like
 * size 1, and a size-1 axis stretches to match the other shape.
 *
 * @param {number[]} a - First shape.
 * @param {number[]} b - Second shape.
 * @returns {number[]} The broadcast shape.
 * @throws {TypeError} if the shapes are incompatible.
 */
function generalBroadcast(a, b) {
  const rank = Math.max(a.length, b.length);
  const padA = rank - a.length;
  const padB = rank - b.length;
  const result = [];
  for (let k = 0; k < rank; k++) {
    const x = k < padA ? 1 : a[k - padA];
    const y = k < padB ? 1 : b[k - padB];
    if (x === y || y === 1) {
      result.push(x);
    } else if (x === 1) {
      result.push(y);
    } else {
      throw new TypeError(`Incompatible array broadcast shapes: ${a} vs ${b}`);
    }
  }
  return result;
}
1723
+
1724
+ // src/frontend/jaxpr.ts
1725
/** A Jaxpr variable with a unique numeric id and an abstract value (`aval`). */
var Var = class _Var {
  // Monotonic counter, for debugging (JavaScript has no id() like Python).
  static #nextId = 1;
  id;
  aval;
  constructor(aval) {
    this.aval = aval;
    this.id = _Var.#nextId++;
  }
  toString() {
    const typeStr = this.aval.strShort();
    return `Var(${this.id}):${typeStr}`;
  }
};
1738
/** A scalar literal in a Jaxpr, with its dtype and a 0-dimensional aval. */
var Lit = class {
  dtype;
  value;
  aval;
  constructor(dtype, value) {
    Object.assign(this, { dtype, value, aval: new ShapedArray([], dtype) });
  }
};
1748
/**
 * Check whether a Jaxpr atom is a literal and, if `literal` is given, whether
 * its value equals `literal`.
 */
function atomIsLit(atom, literal) {
  if (!(atom instanceof Lit)) return false;
  return literal === void 0 || atom.value === literal;
}
1751
/**
 * Assigns short, stable names to variables for pretty-printing.
 *
 * Names are handed out in the order a, b, ..., z, aa, ab, ..., az, ba, ...,
 * zz, aaa, ... (bijective base-26).
 */
var VarPrinter = class {
  names = /* @__PURE__ */ new Map();
  #count = 0;
  #advance() {
    let n = this.#count++ + 1;
    let label = "";
    while (n > 0) {
      n -= 1;
      label = String.fromCharCode(97 + n % 26) + label;
      n = Math.floor(n / 26);
    }
    return label;
  }
  name(v) {
    let label = this.names.get(v);
    if (label === void 0) {
      label = this.#advance();
      this.names.set(v, label);
    }
    return label;
  }
  nameType(v) {
    const label = this.name(v);
    return `${label}:${v.aval.strShort()}`;
  }
};
1783
/** One equation in a Jaxpr: `outBinders = primitive[params](inputs)`. */
var JaxprEqn = class {
  constructor(primitive, inputs, params, outBinders) {
    this.primitive = primitive;
    this.inputs = inputs;
    this.params = params;
    this.outBinders = outBinders;
  }
  /**
   * Pretty-print this equation.
   *
   * @param {Set} [usedVars] - If given, out binders not in this set print as "_".
   * @param {VarPrinter} [vp] - Shared name assigner; out binders are named
   *   before inputs.
   */
  pprint(usedVars, vp = new VarPrinter()) {
    const binderNames = this.outBinders.map(
      (v) => usedVars && !usedVars.has(v) ? "_" : vp.nameType(v)
    );
    const lhs = PPrint.pp(binderNames.join(" "));
    let rhs = PPrint.pp(this.primitive);
    const paramDocs = Object.entries(this.params).map(
      ([key, value]) => PPrint.pp(`${key}=${value}`)
    );
    rhs = paramDocs.length > 0 ? rhs.stack(PPrint.pp(" [ ")).stack(PPrint.prototype.concat(...paramDocs)).stack(PPrint.pp(" ] ")) : rhs.stack(PPrint.pp(" "));
    const argText = this.inputs.map((x) => x instanceof Var ? vp.name(x) : JSON.stringify(x.value)).join(" ");
    rhs = rhs.stack(PPrint.pp(argText));
    return lhs.stack(PPrint.pp(" = ")).stack(rhs);
  }
  toString() {
    const doc = this.pprint();
    return doc.toString();
  }
};
1814
/** A Jaxpr: input binders, a list of equations, and output atoms. */
var Jaxpr2 = class _Jaxpr {
  constructor(inBinders, eqns, outs) {
    this.inBinders = inBinders;
    this.eqns = eqns;
    this.outs = outs;
  }
  #hash;
  pprint() {
    const vp = new VarPrinter();
    const usedVars = new Set(
      [...this.outs, ...this.eqns.flatMap((eqn) => eqn.inputs)].filter(
        (x) => x instanceof Var
      )
    );
    const inBinders = this.inBinders.map((v) => vp.nameType(v)).join(", ");
    const eqns = PPrint.prototype.concat(
      ...this.eqns.map((e2) => e2.pprint(usedVars, vp))
    );
    const outs = this.outs.map((x) => x instanceof Var ? vp.name(x) : x.value).join(", ");
    return PPrint.pp(`{ lambda ${inBinders} .`).concat(
      (this.eqns.length ? PPrint.pp("let ").stack(eqns).concat(PPrint.pp(`in ( ${outs} ) }`)) : PPrint.pp(`( ${outs} ) }`)).indent(2)
    );
  }
  toString() {
    return this.pprint().toString();
  }
  /**
   * Gets a hash of this Jaxpr.
   *
   * Var identity is not considered in the hash, so two Jaxprs with the same
   * order of assignments and operators but different variable IDs will resolve
   * to the same hash (and toString representation).
   */
  getHash() {
    if (this.#hash !== void 0) return this.#hash;
    const hasher = new FpHash();
    const varIds = /* @__PURE__ */ new Map();
    // Each Var maps to a hash of (first-appearance index, dtype, shape). Every
    // occurrence, including the first, contributes that same value. Previously
    // the first occurrence contributed only the bare index, so e.g. the avals
    // of unused inputs never reached the hash.
    const vi = (v) => {
      if (!varIds.has(v)) {
        varIds.set(v, FpHash.hash(varIds.size + 1, v.aval.dtype, ...v.aval.shape));
      }
      return varIds.get(v);
    };
    hasher.update(this.inBinders.length, ...this.inBinders.map(vi));
    hasher.update(
      this.eqns.length,
      ...this.eqns.flatMap((eqn) => [
        eqn.primitive,
        eqn.inputs.length,
        ...eqn.inputs.map((x) => x instanceof Var ? vi(x) : x.value),
        JSON.stringify(eqn.params),
        eqn.outBinders.length,
        ...eqn.outBinders.map(vi)
      ])
    );
    hasher.update(
      this.outs.length,
      ...this.outs.map((x) => x instanceof Var ? vi(x) : x.value)
    );
    return this.#hash = hasher.value;
  }
  hash(state) {
    state.update(this.getHash());
  }
  /**
   * Produce a simplified Jaxpr with basic optimizations applied.
   * - Trim away unused variables.
   * - Fold away `+0`, `*1` and integer division by 1 against literals.
   * - Fold `add`/`mul` of two literals into a single literal.
   */
  simplify() {
    const context = /* @__PURE__ */ new Map();
    const newEqns = [];
    for (const e2 of this.eqns) {
      const inputs = e2.inputs.map(
        (x) => x instanceof Var ? context.get(x) ?? x : x
      );
      const eqn = new JaxprEqn(e2.primitive, inputs, e2.params, e2.outBinders);
      if (eqn.primitive === "add" /* Add */) {
        const [a, b] = inputs;
        const c = eqn.outBinders[0];
        if (atomIsLit(a, 0)) {
          context.set(c, b);
        } else if (atomIsLit(b, 0)) {
          context.set(c, a);
        } else if (atomIsLit(a) && atomIsLit(b)) {
          context.set(
            c,
            new Lit(
              a.dtype,
              a.dtype === "bool" /* Bool */ ? Math.min(a.value + b.value, 1) : a.value + b.value
            )
          );
        } else {
          newEqns.push(eqn);
        }
      } else if (eqn.primitive === "mul" /* Mul */) {
        const [a, b] = inputs;
        const c = eqn.outBinders[0];
        if (atomIsLit(a, 1)) {
          context.set(c, b);
        } else if (atomIsLit(b, 1)) {
          context.set(c, a);
        } else if (atomIsLit(a) && atomIsLit(b)) {
          context.set(c, new Lit(a.dtype, a.value * b.value));
        } else {
          newEqns.push(eqn);
        }
      } else if (eqn.primitive === "idiv" /* Idiv */) {
        const [a, b] = inputs;
        const c = eqn.outBinders[0];
        if (atomIsLit(b, 1)) {
          context.set(c, a);
        } else {
          // Keep the equation: previously it was silently dropped, which left
          // its output binder undefined.
          newEqns.push(eqn);
        }
      } else {
        newEqns.push(eqn);
      }
    }
    const outs = this.outs.map(
      (x) => x instanceof Var ? context.get(x) ?? x : x
    );
    const usedVars = new Set(outs.filter((x) => x instanceof Var));
    const liveEqns = [];
    for (let i = newEqns.length - 1; i >= 0; i--) {
      const eqn = newEqns[i];
      if (eqn.outBinders.some((v) => usedVars.has(v))) {
        liveEqns.push(eqn);
        for (const v of eqn.inputs) {
          if (v instanceof Var) {
            usedVars.add(v);
          }
        }
      }
    }
    return new _Jaxpr(this.inBinders, liveEqns.reverse(), outs);
  }
  /** Flattens nested JitCall in a Jaxpr. Useful for handling jit-of-jit. */
  flatten() {
    if (!this.eqns.some((eqn) => eqn.primitive === "jit_call" /* JitCall */)) {
      return this;
    }
    const newEqns = [];
    const varMap = /* @__PURE__ */ new Map();
    const varMapF = (x) => x instanceof Var ? varMap.get(x) ?? x : x;
    for (const eqn of this.eqns) {
      if (eqn.primitive === "jit_call" /* JitCall */) {
        const jaxpr = eqn.params.jaxpr.flatten();
        // Maps the inner jaxpr's Vars to atoms in the outer jaxpr.
        const translation = /* @__PURE__ */ new Map();
        const translationF = (x) => x instanceof Var ? translation.get(x) : x;
        for (const [v, x] of zip(jaxpr.inBinders, eqn.inputs)) {
          translation.set(v, varMapF(x));
        }
        for (const ieqn of jaxpr.eqns) {
          const inputs = ieqn.inputs.map(translationF);
          const outBinders = [];
          for (const v of ieqn.outBinders) {
            const u = new Var(v.aval);
            outBinders.push(u);
            translation.set(v, u);
          }
          newEqns.push(
            new JaxprEqn(ieqn.primitive, inputs, ieqn.params, outBinders)
          );
        }
        for (const [v, x] of zip(eqn.outBinders, jaxpr.outs)) {
          varMap.set(v, translationF(x));
        }
      } else {
        if (eqn.inputs.some((x) => x instanceof Var && varMap.has(x))) {
          newEqns.push(
            new JaxprEqn(
              eqn.primitive,
              eqn.inputs.map(varMapF),
              eqn.params,
              eqn.outBinders
            )
          );
        } else {
          newEqns.push(eqn);
        }
      }
    }
    const newOuts = this.outs.map(varMapF);
    return new _Jaxpr(this.inBinders, newEqns, newOuts);
  }
};
1999
var JaxprType = class {
  /**
   * Type signature of a Jaxpr.
   * @param inTypes - avals of the input binders
   * @param outTypes - avals of the outputs
   */
  constructor(inTypes, outTypes) {
    this.inTypes = inTypes;
    this.outTypes = outTypes;
  }
  /** Render as `(in, ...) -> (out, ...)` using each aval's short form. */
  toString() {
    const render = (types) => types.map((t) => t.strShort()).join(", ");
    return `(${render(this.inTypes)}) -> (${render(this.outTypes)})`;
  }
};
2010
/**
 * Typecheck a Jaxpr and return its `JaxprType`.
 *
 * Checks that:
 * - every variable is bound exactly once;
 * - every primitive has an abstract-eval rule;
 * - each rule produces exactly one aval per output binder, matching the
 *   binder's aval.
 *
 * @throws {TypeError} on any violation (unknown atoms are reported by
 *   `typecheckAtom`).
 */
function typecheckJaxpr(jaxpr) {
  const env = /* @__PURE__ */ new Set();
  const bindVar = (v) => {
    if (env.has(v)) {
      throw new TypeError(`Duplicate variable binding: ${v}`);
    }
    env.add(v);
  };
  for (const v of jaxpr.inBinders) bindVar(v);
  for (const eqn of jaxpr.eqns) {
    const inTypes2 = eqn.inputs.map((x) => typecheckAtom(env, x));
    const rule = abstractEvalRules[eqn.primitive];
    if (rule === void 0) {
      throw new TypeError(`No abstract eval rule for: ${eqn.primitive}`);
    }
    const outTypes2 = rule(inTypes2, eqn.params);
    // Without this check, a count mismatch could be silently accepted
    // (depending on how zip() handles unequal lengths).
    if (outTypes2.length !== eqn.outBinders.length) {
      throw new TypeError(
        `Output arity mismatch in ${eqn.primitive}: expected ${eqn.outBinders.length} outputs, got ${outTypes2.length}`
      );
    }
    for (const [outBinder, outType] of zip(eqn.outBinders, outTypes2)) {
      if (!outType.equals(outBinder.aval)) {
        throw new TypeError(
          `Output binder type mismatch in ${eqn.primitive}: ${outBinder} vs ${outType}`
        );
      }
      bindVar(outBinder);
    }
  }
  const inTypes = jaxpr.inBinders.map((v) => v.aval);
  const outTypes = jaxpr.outs.map((x) => typecheckAtom(env, x));
  return new JaxprType(inTypes, outTypes);
}
2037
/**
 * Return the aval of a Jaxpr atom (a `Var` or a `Lit`).
 *
 * @param env - set of Vars bound so far
 * @param x - the atom to check
 * @throws {TypeError} if `x` is an unbound Var or not an atom at all. This
 *   matches the `TypeError`s thrown elsewhere during typechecking; it was
 *   previously a plain `Error` for unknown variables.
 */
function typecheckAtom(env, x) {
  if (x instanceof Var) {
    if (!env.has(x)) {
      throw new TypeError(`Unknown variable: ${x}`);
    }
    return x.aval;
  }
  if (x instanceof Lit) {
    return x.aval;
  }
  throw new TypeError(`Invalid atom type: ${x}`);
}
2049
/**
 * Evaluate `jaxpr` eagerly on `args` by binding each primitive in order.
 *
 * Reference counting: each Var's value is `.ref`'d so that it holds exactly
 * one reference per use (eqn input or jaxpr output). Every read hands one of
 * those references to the consumer. Values of Vars that are never used are
 * disposed as soon as they are written. If evaluation throws, the
 * references still held for not-yet-read uses are disposed before the error
 * is rethrown.
 */
function evalJaxpr(jaxpr, args) {
  // Number of times each Var is consumed (eqn inputs, then jaxpr outputs).
  const uses = /* @__PURE__ */ new Map();
  const countUse = (atom) => {
    if (atom instanceof Var) uses.set(atom, (uses.get(atom) ?? 0) + 1);
  };
  for (const eqn of jaxpr.eqns) eqn.inputs.forEach(countUse);
  jaxpr.outs.forEach(countUse);

  const bound = /* @__PURE__ */ new Map();
  // References each Var still owns, used for cleanup on error.
  const pending = /* @__PURE__ */ new Map();

  const readAtom = (atom) => {
    if (!(atom instanceof Var)) {
      return scalar(atom.value, { dtype: atom.dtype });
    }
    pending.set(atom, (pending.get(atom) ?? 0) - 1);
    return bound.get(atom);
  };

  const bindVar = (v, val) => {
    if (bound.has(v)) throw new Error(`Variable already bound: ${v}`);
    const total = uses.get(v) ?? 0;
    if (total === 0) {
      val.dispose();
      return;
    }
    bound.set(v, val);
    pending.set(v, total);
    // `val` arrives with one reference; take one more per additional use.
    for (let extra = 1; extra < total; extra++) val.ref;
  };

  try {
    for (const [v, arg] of zip(jaxpr.inBinders, args)) bindVar(v, arg);
    for (const eqn of jaxpr.eqns) {
      const inputVals = eqn.inputs.map(readAtom);
      const results = bind(eqn.primitive, inputVals, eqn.params);
      for (const [v, val] of zip(eqn.outBinders, results)) bindVar(v, val);
    }
    return jaxpr.outs.map(readAtom);
  } catch (err) {
    for (const [v, count] of pending) {
      if (count <= 0) continue;
      const held = bound.get(v);
      for (let k = 0; k < count; k++) held.dispose();
    }
    throw err;
  }
}
2093
/**
 * Wrap a Jaxpr as a variadic function that evaluates it eagerly.
 * @param jaxpr - the Jaxpr to interpret
 * @returns `(...args) => evalJaxpr(jaxpr, args)`
 */
function jaxprAsFun(jaxpr) {
  const interpret = (...args) => evalJaxpr(jaxpr, args);
  return interpret;
}
2096
var JaxprTracer = class extends Tracer {
  /**
   * Symbolic placeholder recorded while building a Jaxpr.
   * @param trace - the JaxprTrace that owns this tracer
   * @param aval - abstract value this tracer stands for
   */
  constructor(trace, aval) {
    super(trace);
    this.aval = aval;
  }
  toString() {
    const short = this.aval.strShort();
    return `JaxprTracer(${short})`;
  }
  /** Holds no resources, so reference counting is a no-op. */
  get ref() {
    return this;
  }
  /** Nothing to release. */
  dispose() {}
};
2111
var JaxprTrace = class extends Trace {
  /** Register a Jaxpr argument with a given shape and return the tracer. */
  newArg(aval) {
    const shaped = ShapedArray.fromAval(aval);
    const tracer = this.builder.newTracer(this, shaped);
    this.builder.addVar(tracer);
    return tracer;
  }
  /**
   * Register a constant / literal in this Jaxpr.
   *
   * A previously registered value (looked up by identity in
   * `builder.constTracers`) returns its existing tracer. Tracer values are
   * stored with an extra `.ref`; plain values are wrapped with `scalar()`.
   */
  getOrMakeConstTracer(val) {
    const cached = this.builder.constTracers.get(val);
    if (cached !== void 0) return cached;
    const tracer = this.builder.newTracer(this, ShapedArray.fromAval(getAval(val)));
    const stored = val instanceof Tracer ? val.ref : scalar(val);
    this.builder.addConst(tracer, stored);
    return tracer;
  }
  pure = this.getOrMakeConstTracer;
  lift = this.getOrMakeConstTracer;
  /** Record one primitive application as a JaxprEqn and return its output tracers. */
  processPrimitive(primitive, tracers, params) {
    const inAvals = tracers.map((t) => t.aval);
    const outAvals = abstractEvalRules[primitive](inAvals, params);
    const outTracers = outAvals.map((aval) => this.builder.newTracer(this, aval));
    const builder = this.builder;
    const inVars = tracers.map((t) => builder.getVar(t));
    const outVars = outTracers.map((t) => builder.addVar(t));
    builder.addEqn(new JaxprEqn(primitive, inVars, params, outVars));
    return outTracers;
  }
  /** The JaxprBuilder stored as this trace's main global data. */
  get builder() {
    return this.main.globalData;
  }
};
2153
var JaxprBuilder = class {
  /** Recorded equations, in program order. */
  eqns = [];
  /** Tracer -> the Var it was registered as. */
  tracerToVar = /* @__PURE__ */ new Map();
  /** Already-seen constant value -> tracer. */
  constTracers = /* @__PURE__ */ new Map();
  /** Var -> constant value. */
  constVals = /* @__PURE__ */ new Map();
  /** Every tracer created through `newTracer`. */
  tracers = [];
  newTracer(trace, aval) {
    const created = new JaxprTracer(trace, aval);
    this.tracers.push(created);
    return created;
  }
  addEqn(eqn) {
    this.eqns.push(eqn);
  }
  /** Create a fresh Var for `tracer`; each tracer may be registered only once. */
  addVar(tracer) {
    if (this.tracerToVar.has(tracer)) {
      throw new Error(`Tracer was added as variable twice: ${tracer}`);
    }
    const fresh = new Var(tracer.aval);
    this.tracerToVar.set(tracer, fresh);
    return fresh;
  }
  getVar(tracer) {
    const found = this.tracerToVar.get(tracer);
    if (found === void 0) {
      throw new Error(`Could not find variable for tracer: ${tracer}`);
    }
    return found;
  }
  /** Register `tracer` as a constant input whose value is `val`. */
  addConst(tracer, val) {
    const constVar = this.addVar(tracer);
    this.constTracers.set(val, tracer);
    this.constVals.set(constVar, val);
    return constVar;
  }
  /**
   * Assemble the Jaxpr. Constant Vars come first among the input binders.
   * The result is typechecked, then scalar constants are inlined as literals.
   */
  build(inTracers, outTracers) {
    const [constVars, rawConsts] = unzip2(this.constVals.entries());
    const toVar = (t) => this.getVar(t);
    const traced = new Jaxpr2(
      [...constVars, ...inTracers.map(toVar)],
      this.eqns,
      outTracers.map(toVar)
    );
    typecheckJaxpr(traced);
    const [jaxpr, consts] = _inlineLiterals(traced, rawConsts);
    return { jaxpr, consts };
  }
};
2201
/**
 * Replace rank-0 `Array3` constants with `Lit` atoms.
 *
 * `consts[i]` is the value of `jaxpr.inBinders[i]`. Scalar array constants
 * are read with `dataSync()` and substituted as literals everywhere they
 * appear (eqn inputs and outputs). All other constants stay as input
 * binders. Returns `[newJaxpr, newConsts]`, and the new Jaxpr is typechecked.
 * NOTE(review): inlined scalar arrays are dropped from the const list without
 * an explicit dispose here; confirm whether `dataSync()` releases them.
 */
function _inlineLiterals(jaxpr, consts) {
  const literalFor = /* @__PURE__ */ new Map();
  const keptBinders = [];
  const keptConsts = [];
  for (const [i, value] of consts.entries()) {
    const binder = jaxpr.inBinders[i];
    if (ndim(value) === 0 && value instanceof Array3) {
      literalFor.set(binder, new Lit(value.dtype, value.dataSync()[0]));
    } else {
      keptBinders.push(binder);
      keptConsts.push(value);
    }
  }
  const substitute = (x) => literalFor.get(x) ?? x;
  const rewrittenEqns = jaxpr.eqns.map(
    (eqn) => new JaxprEqn(eqn.primitive, eqn.inputs.map(substitute), eqn.params, eqn.outBinders)
  );
  const rewrittenOuts = jaxpr.outs.map(substitute);
  const argBinders = jaxpr.inBinders.slice(consts.length);
  const result = new Jaxpr2([...keptBinders, ...argBinders], rewrittenEqns, rewrittenOuts);
  typecheckJaxpr(result);
  return [result, keptConsts];
}
2231
/**
 * Abstract eval for elementwise binary ops.
 * Both operands must be ShapedArrays of the same dtype; the result has the
 * broadcast shape and that dtype.
 * @throws {TypeError} on non-ShapedArray inputs or mismatched dtypes.
 */
function binopAbstractEval([x, y]) {
  const bothShaped = x instanceof ShapedArray && y instanceof ShapedArray;
  if (!bothShaped) {
    throw new TypeError("binopAbstractEval expects ShapedArray inputs");
  }
  if (x.dtype !== y.dtype) {
    throw new TypeError(`Mismatched dtypes: ${x.dtype} vs ${y.dtype}`);
  }
  const outShape = generalBroadcast(x.shape, y.shape);
  return [new ShapedArray(outShape, x.dtype)];
}
2240
/**
 * Abstract eval for comparisons.
 * Both operands must be ShapedArrays of the same dtype; the result has the
 * broadcast shape and dtype bool.
 * @throws {TypeError} on non-ShapedArray inputs or mismatched dtypes.
 */
function compareAbstractEval([x, y]) {
  const bothShaped = x instanceof ShapedArray && y instanceof ShapedArray;
  if (!bothShaped) {
    throw new TypeError("compareAbstractEval expects ShapedArray inputs");
  }
  if (x.dtype !== y.dtype) {
    throw new TypeError(`Mismatched dtypes: ${x.dtype} vs ${y.dtype}`);
  }
  const outShape = generalBroadcast(x.shape, y.shape);
  return [new ShapedArray(outShape, "bool" /* Bool */)];
}
2249
/** Abstract eval for elementwise unary ops: the output aval equals the input's. */
function vectorizedUnopAbstractEval([x]) {
  const outAval = ShapedArray.fromAval(x);
  return [outAval];
}
2252
/** Shape/dtype inference rule for every primitive, keyed by primitive name. */
var abstractEvalRules = {
  ["add" /* Add */]: binopAbstractEval,
  ["mul" /* Mul */]: binopAbstractEval,
  ["idiv" /* Idiv */]: binopAbstractEval,
  ["neg" /* Neg */]: vectorizedUnopAbstractEval,
  ["reciprocal" /* Reciprocal */]: vectorizedUnopAbstractEval,
  ["sin" /* Sin */]: vectorizedUnopAbstractEval,
  ["cos" /* Cos */]: vectorizedUnopAbstractEval,
  ["exp" /* Exp */]: vectorizedUnopAbstractEval,
  ["log" /* Log */]: vectorizedUnopAbstractEval,
  ["min" /* Min */]: binopAbstractEval,
  ["max" /* Max */]: binopAbstractEval,
  ["reduce_sum" /* ReduceSum */]([x], { axis }) {
    // Keep only the dimensions that are not reduced over.
    const reduced = new Set(axis);
    const kept = [];
    x.shape.forEach((dim, i) => {
      if (!reduced.has(i)) kept.push(dim);
    });
    return [new ShapedArray(kept, x.dtype)];
  },
  ["compare" /* Compare */]: compareAbstractEval,
  ["where" /* Where */]([cond, x, y]) {
    if (cond.dtype !== "bool" /* Bool */) {
      throw new TypeError(`Condition must be boolean, got ${cond.dtype}`);
    }
    if (x.dtype !== y.dtype) {
      throw new TypeError(`Mismatched dtypes: ${x.dtype} vs ${y.dtype}`);
    }
    const valueShape = generalBroadcast(x.shape, y.shape);
    return [new ShapedArray(generalBroadcast(cond.shape, valueShape), x.dtype)];
  },
  ["transpose" /* Transpose */]([x], { perm }) {
    const permuted = perm.map((src) => x.shape[src]);
    return [new ShapedArray(permuted, x.dtype)];
  },
  ["broadcast" /* Broadcast */]([x], params) {
    return [new ShapedArray(params.shape, x.dtype)];
  },
  ["reshape" /* Reshape */]([x], params) {
    return [new ShapedArray(params.shape, x.dtype)];
  },
  ["flip" /* Flip */]([x], _params) {
    return [new ShapedArray(x.shape, x.dtype)];
  },
  ["jit_call" /* JitCall */](args, { jaxpr }) {
    // The argument avals must match the inner Jaxpr's input types exactly.
    const { inTypes, outTypes } = typecheckJaxpr(jaxpr);
    if (args.length !== inTypes.length) {
      throw new TypeError(
        `jit_call expected ${inTypes.length} arguments, got ${args.length}`
      );
    }
    inTypes.forEach((expected, i) => {
      if (!args[i].equals(expected)) {
        throw new TypeError(
          `jit_call argument ${i} has type ${args[i]}, expected ${expected}`
        );
      }
    });
    return outTypes;
  }
};
2315
/**
 * Trace `f` into a Jaxpr.
 *
 * Returns a function that takes avals (or values, whose avals are used),
 * traces `f` under a fresh JaxprTrace and returns
 * `{ jaxpr, consts, treedef }`. The Jaxpr is simplified, and `treedef`
 * describes `f`'s output pytree. The main and dynamic traces are disposed on
 * exit, including when tracing throws.
 */
function makeJaxpr(f) {
  return (...argsIn) => {
    var _stack = [];
    try {
      const [flatAvals, inTreedef] = flatten(argsIn);
      const [flatFn, outTreeBox] = flattenFun(f, inTreedef);
      const builder = new JaxprBuilder();
      const main = __using(_stack, newMain(JaxprTrace, builder));
      __using(_stack, newDynamic(main));
      const trace = new JaxprTrace(main);
      // Non-object leaves (e.g. JS numbers) are converted to arrays to get an aval.
      const inTracers = flatAvals.map((aval) => {
        const shaped = typeof aval === "object" ? aval : pureArray(aval);
        return trace.newArg(shaped);
      });
      const outTracers = flatFn(...inTracers).map((out) => fullRaise(trace, out));
      const { jaxpr, consts } = builder.build(inTracers, outTracers);
      const treedef = outTreeBox.value;
      if (treedef === void 0) {
        throw new Error("outTree was not set in makeJaxpr");
      }
      return { jaxpr: jaxpr.simplify(), consts, treedef };
    } catch (_) {
      var _error = _, _hasError = true;
    } finally {
      __callDispose(_stack, _error, _hasError);
    }
  };
}
2344
/**
 * Just-in-time wrap `f`: trace it once per distinct input-aval signature
 * (keyed by the JSON of the aval pytree) and dispatch through `jit_call`.
 * Cached consts get a fresh `.ref` on every call, because `bind` consumes
 * its inputs.
 */
function jit(f) {
  const cache = /* @__PURE__ */ new Map();
  return (...args) => {
    const [flatArgs, inTree] = flatten(args);
    const flatAvals = flatArgs.map((arg) => ShapedArray.fromAval(getAval(arg)));
    const avalTree = unflatten(inTree, flatAvals);
    const traced = runWithCache(
      cache,
      JSON.stringify(avalTree),
      () => makeJaxpr(f)(...avalTree)
    );
    const { jaxpr, consts, treedef } = traced;
    const constRefs = consts.map((c) => c.ref);
    const outs = bind("jit_call" /* JitCall */, [...constRefs, ...flatArgs], {
      jaxpr,
      numConsts: consts.length
    });
    return unflatten(treedef, outs);
  };
}
2371
+
2372
+ // src/frontend/jvp.ts
2373
var JVPTracer = class extends Tracer {
  /**
   * Dual number used by forward-mode differentiation.
   * @param trace - owning JVPTrace
   * @param primal - the value
   * @param tangent - its tangent (same aval as `primal`)
   */
  constructor(trace, primal, tangent) {
    super(trace);
    this.primal = primal;
    this.tangent = tangent;
  }
  /** A JVP tracer has the aval of its primal. */
  get aval() {
    return this.primal.aval;
  }
  toString() {
    return `JVPTracer(${this.primal.toString()}, ${this.tangent.toString()})`;
  }
  /** Take one extra reference on both the primal and the tangent. */
  get ref() {
    void this.primal.ref;
    void this.tangent.ref;
    return this;
  }
  /** Release one reference on both the primal and the tangent. */
  dispose() {
    this.primal.dispose();
    this.tangent.dispose();
  }
};
2394
var JVPTrace = class extends Trace {
  /** Convert a raw value to an array and lift it with a zero tangent. */
  pure(val) {
    return this.lift(pureArray(val));
  }
  /** Treat a lower-level value as a constant: its tangent is all zeros. */
  lift(val) {
    const zeroTangent = zerosLike(val);
    return new JVPTracer(this, val, zeroTangent);
  }
  /** Apply the primitive's JVP rule to the primals and tangents of `tracers`. */
  processPrimitive(primitive, tracers, params) {
    const [primalsIn, tangentsIn] = unzip2(
      tracers.map(({ primal, tangent }) => [primal, tangent])
    );
    const rule = jvpRules[primitive];
    if (rule === void 0) {
      throw new Error(`No JVP rule for: ${primitive}`);
    }
    const [primalsOut, tangentsOut] = rule(primalsIn, tangentsIn, params);
    return zip(primalsOut, tangentsOut).map(
      ([primal, tangent]) => new JVPTracer(this, primal, tangent)
    );
  }
};
2415
/**
 * Forward-mode (JVP) rules keyed by primitive name.
 *
 * Each rule receives `(primals, tangents, params)` and returns
 * `[primalsOut, tangentsOut]`. Operations consume their array arguments, so
 * every input must be consumed exactly once. A value used twice needs one
 * extra `.ref`.
 */
var jvpRules = {
  ["add" /* Add */]([x, y], [dx, dy]) {
    return [[x.add(y)], [dx.add(dy)]];
  },
  ["mul" /* Mul */]([x, y], [dx, dy]) {
    // d(xy) = x dy + dx y; x and y are each used twice.
    return [[x.ref.mul(y.ref)], [x.mul(dy).add(dx.mul(y))]];
  },
  ["idiv" /* Idiv */]([x, y], [dx, dy]) {
    // Integer division is piecewise constant: zero tangent.
    dx.dispose(), dy.dispose();
    const z = idiv(x, y);
    const dz = zerosLike(z);
    return [[z], [dz]];
  },
  ["neg" /* Neg */]([x], [dx]) {
    return [[x.neg()], [dx.neg()]];
  },
  ["reciprocal" /* Reciprocal */]([x], [dx]) {
    // d(1/x) = -(1/x)^2 dx. `x` is only needed once, so it is consumed
    // directly. Taking `x.ref` here previously leaked one reference per call.
    const xRecip = reciprocal(x);
    return [[xRecip.ref], [neg(xRecip.ref.mul(xRecip)).mul(dx)]];
  },
  ["sin" /* Sin */]([x], [dx]) {
    return [[sin(x.ref)], [cos(x).mul(dx)]];
  },
  ["cos" /* Cos */]([x], [dx]) {
    return [[cos(x.ref)], [neg(sin(x)).mul(dx)]];
  },
  ["exp" /* Exp */]([x], [dx]) {
    const z = exp(x);
    return [[z.ref], [z.mul(dx)]];
  },
  ["log" /* Log */]([x], [dx]) {
    return [[log(x.ref)], [reciprocal(x).mul(dx)]];
  },
  ["min" /* Min */]([x, y], [dx, dy]) {
    // Where y < x the minimum is y, so its tangent is dy.
    return [[min(x.ref, y.ref)], [where(less(y, x), dy, dx)]];
  },
  ["max" /* Max */]([x, y], [dx, dy]) {
    // Where y < x the maximum is x, so its tangent is dx.
    return [[max(x.ref, y.ref)], [where(less(y, x), dx, dy)]];
  },
  ["reduce_sum" /* ReduceSum */]([x], [dx], { axis }) {
    return [[reduceSum(x, axis)], [reduceSum(dx, axis)]];
  },
  ["compare" /* Compare */]([x, y], tangents, { op }) {
    for (const t of tangents) t.dispose();
    const primal = compare(x, y, op);
    return [[primal], [zerosLike(primal)]];
  },
  ["where" /* Where */]([cond, x, y], [dcond, dx, dy]) {
    dcond.dispose();
    return [[where(cond.ref, x, y)], [where(cond, dx, dy)]];
  },
  ["transpose" /* Transpose */]([x], [dx], { perm }) {
    return [[transpose(x, perm)], [transpose(dx, perm)]];
  },
  ["broadcast" /* Broadcast */]([x], [dx], { shape: shape2, axis }) {
    return [[broadcast(x, shape2, axis)], [broadcast(dx, shape2, axis)]];
  },
  ["reshape" /* Reshape */]([x], [dx], { shape: shape2 }) {
    return [[reshape(x, shape2)], [reshape(dx, shape2)]];
  },
  ["flip" /* Flip */]([x], [dx], { axis }) {
    return [[flip(x, axis)], [flip(dx, axis)]];
  },
  ["jit_call" /* JitCall */](primals, tangents, { jaxpr }) {
    const { newJaxpr, newConsts } = jvpJaxpr(jaxpr);
    // `newConsts` is cached by jvpJaxpr and reused on every call, while bind
    // consumes its inputs, so pass fresh references (as `jit` does).
    const outs = bind(
      "jit_call" /* JitCall */,
      [...newConsts.map((c) => c.ref), ...primals, ...tangents],
      {
        jaxpr: newJaxpr,
        numConsts: newConsts.length
      }
    );
    const n = outs.length / 2;
    if (!Number.isInteger(n))
      throw new Error("internal: JVP Jaxpr output length is not even");
    const [primalsOut, tangentsOut] = [outs.slice(0, n), outs.slice(n)];
    return [primalsOut, tangentsOut];
  }
};
2495
/** Jaxpr -> its JVP transform `{ newJaxpr, newConsts }`, keyed by object identity. */
var jvpJaxprCache = /* @__PURE__ */ new Map();
/**
 * Build (or fetch from cache) the JVP of `jaxpr`. The resulting Jaxpr takes
 * `[...primals, ...tangents]` with the same avals and returns
 * `[...primalsOut, ...tangentsOut]`.
 */
function jvpJaxpr(jaxpr) {
  const cached = jvpJaxprCache.get(jaxpr);
  if (cached !== void 0) return cached;
  const inAvals = jaxpr.inBinders.map((binder) => binder.aval);
  const traced = makeJaxpr(
    (primals, tangents) => jvpFlat(jaxprAsFun(jaxpr), primals, tangents)
  )(inAvals, inAvals);
  const entry = { newJaxpr: traced.jaxpr, newConsts: traced.consts };
  jvpJaxprCache.set(jaxpr, entry);
  return entry;
}
2508
/**
 * Run flat function `f` under a fresh JVP trace.
 * @param f - function taking and returning flat lists of values
 * @param primals - primal inputs
 * @param tangents - tangent inputs, paired positionally with `primals`
 * @returns `[primalsOut, tangentsOut]`
 * The JVP main trace is disposed on exit, including when `f` throws.
 */
function jvpFlat(f, primals, tangents) {
  var _stack = [];
  try {
    const main = __using(_stack, newMain(JVPTrace));
    const trace = new JVPTrace(main);
    const inputs = [];
    for (const [primal, tangent] of zip(primals, tangents)) {
      inputs.push(new JVPTracer(trace, pureArray(primal), pureArray(tangent)));
    }
    const raised = f(...inputs).map((out) => fullRaise(trace, out));
    return unzip2(raised.map(({ primal, tangent }) => [primal, tangent]));
  } catch (_) {
    var _error = _, _hasError = true;
  } finally {
    __callDispose(_stack, _error, _hasError);
  }
}
2525
/**
 * Forward-mode Jacobian-vector product of `f` at `primals` along `tangents`.
 * `primals` and `tangents` must have the same pytree structure.
 * @returns `[primalsOut, tangentsOut]`, both shaped like `f`'s output pytree
 * @throws {TreeMismatchError} if the input structures differ
 */
function jvp(f, primals, tangents) {
  const [primalsFlat, inTree] = flatten(primals);
  const [tangentsFlat, tangentTree] = flatten(tangents);
  if (!inTree.equals(tangentTree)) {
    throw new TreeMismatchError("jvp", inTree, tangentTree);
  }
  const [flatFun, outTreeBox] = flattenFun(f, inTree);
  const [primalsOutFlat, tangentsOutFlat] = jvpFlat(flatFun, primalsFlat, tangentsFlat);
  const outTree = outTreeBox.value;
  if (outTree === void 0) {
    throw new Error("outTree was not set in jvp");
  }
  return [unflatten(outTree, primalsOutFlat), unflatten(outTree, tangentsOutFlat)];
}
2544
+
2545
+ // src/frontend/vmap.ts
2546
/** Aval of a single batch element: `aval` with axis `batchDim` removed. */
function mappedAval(batchDim, aval) {
  const unbatchedShape = aval.shape.slice();
  unbatchedShape.splice(batchDim, 1);
  return new ShapedArray(unbatchedShape, aval.dtype);
}
2551
/** Transpose `x` so that axis `src` ends up at position `dst`. */
function moveaxis(x, src, dst) {
  const arr = pureArray(x);
  const order = range(arr.shape.length);
  // Pull `src` out of the identity permutation, then re-insert it at `dst`.
  order.splice(src, 1);
  order.splice(dst, 0, src);
  return transpose(arr, order);
}
2558
/**
 * Put the batch axis of `x` at position `dst`.
 * If `x` is unbatched (`src === null`), a new axis of length `axisSize` is
 * broadcast in at `dst`. Otherwise the existing axis is moved, unless it is
 * already at `dst`.
 */
function moveBatchAxis(axisSize, src, dst, x) {
  if (src === dst) return x;
  if (src !== null) return moveaxis(x, src, dst);
  const withBatch = [...x.shape];
  withBatch.splice(dst, 0, axisSize);
  return broadcast(x, withBatch, [dst]);
}
2569
var BatchTracer = class extends Tracer {
  /**
   * @param trace - owning BatchTrace
   * @param val - underlying value (carries the batch axis, if any)
   * @param batchDim - index of the batch axis in `val`, or null if unbatched
   */
  constructor(trace, val, batchDim) {
    super(trace);
    this.val = val;
    this.batchDim = batchDim;
  }
  /** Per-example aval: `val`'s aval with the batch axis (if any) removed. */
  get aval() {
    return this.batchDim === null ? this.val.aval : mappedAval(this.batchDim, this.val.aval);
  }
  toString() {
    return `BatchTracer(${this.val.toString()}, ${this.batchDim})`;
  }
  get ref() {
    void this.val.ref;
    return this;
  }
  dispose() {
    this.val.dispose();
  }
  /** An unbatched tracer adds nothing, so lower to the inner value. */
  fullLower() {
    if (this.batchDim !== null) return this;
    return this.val.fullLower();
  }
};
2600
var BatchTrace = class extends Trace {
  /** Convert a raw value to an array and lift it as unbatched. */
  pure(val) {
    return this.lift(pureArray(val));
  }
  /** Lower-level values enter the batch trace without a batch axis. */
  lift(val) {
    return new BatchTracer(this, val, null);
  }
  /** Apply the primitive's vmap rule to the values and batch dims of `tracers`. */
  processPrimitive(primitive, tracers, params) {
    const [valsIn, bdimsIn] = unzip2(
      tracers.map(({ val, batchDim }) => [val, batchDim])
    );
    const rule = vmapRules[primitive];
    if (rule === void 0) {
      throw new Error(`No vmap rule for: ${primitive}`);
    }
    const [valOuts, bdimOuts] = rule(this.axisSize, valsIn, bdimsIn, params);
    return zip(valOuts, bdimOuts).map(
      ([val, bdim]) => new BatchTracer(this, val, bdim)
    );
  }
  /** Batch size, stored as the main trace's global data. */
  get axisSize() {
    return this.main.globalData;
  }
};
2627
/**
 * Pad a batched operand `x` with trailing size-1 axes up to rank `nd`.
 * Unbatched operands (`d === null`) and operands already of rank `nd` are
 * returned unchanged. Appending at the end only right-aligns correctly when
 * the per-example value is a scalar.
 */
function handleScalarBroadcasting(nd, x, d) {
  if (d === null) return x;
  const rank = ndim(x);
  if (rank === nd) return x;
  const newAxes = range(rank, nd);
  const paddedShape = [...x.shape, ...newAxes.map(() => 1)];
  return broadcast(x, paddedShape, newAxes);
}
2636
/**
 * Build a vmap rule for an n-ary elementwise op with broadcasting semantics.
 *
 * The returned rule `(axisSize, args, dims) => [[out], [outDim]]` works as follows:
 * - no batched operand: apply `op` directly (output unbatched);
 * - every operand is a scalar or shares the shape and batch dim of the first
 *   batched operand: apply `op` directly and keep that batch dim;
 * - otherwise: move every non-scalar operand's batch axis to 0 (broadcasting
 *   unbatched ones to `axisSize`), then insert size-1 axes right AFTER the
 *   batch axis so the per-example axes stay right-aligned for broadcasting.
 *   Padding at the end (as before) mis-aligned e.g. batched `[3]` against
 *   batched `[2, 3]` (`[B,3,1]` vs `[B,2,3]`), and unbatched non-scalar
 *   operands were not padded at all.
 *
 * @throws {Error} if `args` is empty
 */
function broadcastBatcher(op) {
  return (axisSize, args, dims) => {
    if (args.length === 0) {
      throw new Error("Empty list in broadcastBatcher");
    }
    const idx = dims.findIndex((d) => d !== null);
    if (idx === -1) {
      return [[op(...args)], [null]];
    }
    if (
      // If only agreeing batch dims, or scalars, just call the primitive.
      zip(args, dims).every(
        ([x, d]) => ndim(x) === 0 || deepEqual(x.shape, args[idx].shape) && d === dims[idx]
      )
    ) {
      return [[op(...args)], [dims[idx]]];
    }
    const moved = args.map(
      (x, i) => ndim(x) > 0 ? moveBatchAxis(axisSize, dims[i], 0, x) : x
    );
    const nd = Math.max(...moved.map(ndim));
    const aligned = moved.map((x) => {
      const rank = ndim(x);
      if (rank === 0 || rank === nd) return x;
      const numNew = nd - rank;
      const paddedShape = [x.shape[0], ...rep(numNew, 1), ...x.shape.slice(1)];
      return broadcast(x, paddedShape, range(1, 1 + numNew));
    });
    return [[op(...aligned)], [0]];
  };
}
2661
/**
 * vmap rule for an elementwise unary op: apply `op` to the batched value
 * as-is and keep its batch dim unchanged.
 */
function vectorizedUnopBatchingRule(op) {
  return function unopRule(_axisSize, vals, dims) {
    return [[op(vals[0])], [dims[0]]];
  };
}
2666
/**
 * Batching (vmap) rules keyed by primitive name.
 *
 * Each rule receives `(axisSize, vals, batchDims, params)` and returns
 * `[valsOut, batchDimsOut]`. A batch dim of `null` means "not batched".
 */
var vmapRules = {
  ["add" /* Add */]: broadcastBatcher(add),
  ["mul" /* Mul */]: broadcastBatcher(mul),
  ["idiv" /* Idiv */]: broadcastBatcher(idiv),
  ["neg" /* Neg */]: vectorizedUnopBatchingRule(neg),
  ["reciprocal" /* Reciprocal */]: vectorizedUnopBatchingRule(reciprocal),
  ["sin" /* Sin */]: vectorizedUnopBatchingRule(sin),
  ["cos" /* Cos */]: vectorizedUnopBatchingRule(cos),
  ["exp" /* Exp */]: vectorizedUnopBatchingRule(exp),
  ["log" /* Log */]: vectorizedUnopBatchingRule(log),
  ["min" /* Min */]: broadcastBatcher(min),
  ["max" /* Max */]: broadcastBatcher(max),
  ["reduce_sum" /* ReduceSum */](axisSize, [x], [xBdim], { axis }) {
    if (xBdim === null) {
      return [[reduceSum(x, axis)], [null]];
    }
    // Shift reduced axes past the batch axis; the batch axis then moves left
    // by the number of reduced axes that precede it.
    const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
    const outBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
    return [[reduceSum(x, newAxis)], [outBdim]];
  },
  ["compare" /* Compare */](axisSize, args, dims, { op }) {
    return broadcastBatcher((x, y) => compare(x, y, op))(axisSize, args, dims);
  },
  // TODO: where, transpose, broadcast, reshape, flip
  ["jit_call" /* JitCall */](axisSize, args, dims, { jaxpr }) {
    const { newJaxpr, newConsts } = vmapJaxpr(jaxpr, axisSize, dims);
    // `newConsts` is cached by vmapJaxpr and reused on every call, while bind
    // consumes its inputs, so pass fresh references (as `jit` does).
    const outs = bind(
      "jit_call" /* JitCall */,
      [...newConsts.map((c) => c.ref), ...args],
      {
        jaxpr: newJaxpr,
        numConsts: newConsts.length
      }
    );
    // vmapFlat moves every output's batch axis to position 0.
    return [outs, rep(outs.length, 0)];
  }
};
2699
/** Jaxpr -> (JSON of [axisSize, dims]) -> `{ newJaxpr, newConsts }`. */
var vmapJaxprCache = /* @__PURE__ */ new Map();
/**
 * Build (or fetch from cache) the batched version of `jaxpr`.
 * Input `i` gets an axis of length `axisSize` inserted at `dims[i]`, unless
 * `dims[i]` is null. The outputs of the new Jaxpr are batched along axis 0.
 */
function vmapJaxpr(jaxpr, axisSize, dims) {
  const cacheKey = JSON.stringify([axisSize, dims]);
  const hit = vmapJaxprCache.get(jaxpr)?.get(cacheKey);
  if (hit) return hit;
  const inAvals = jaxpr.inBinders.map(({ aval }, i) => {
    const dim = dims[i];
    if (dim === null) return aval;
    const batchedShape = [...aval.shape];
    batchedShape.splice(dim, 0, axisSize);
    return new ShapedArray(batchedShape, aval.dtype);
  });
  const traced = makeJaxpr(
    (args) => vmapFlat(jaxprAsFun(jaxpr), dims, args)
  )(inAvals);
  const result = { newJaxpr: traced.jaxpr, newConsts: traced.consts };
  // Re-read the bucket: tracing above may have created it.
  const bucket = vmapJaxprCache.get(jaxpr) ?? /* @__PURE__ */ new Map();
  vmapJaxprCache.set(jaxpr, bucket);
  bucket.set(cacheKey, result);
  return result;
}
2718
/**
 * Run flat function `f` under a fresh batch trace.
 *
 * `inAxes[i]` is the batch axis of `args[i]`, or null if that argument is
 * unbatched. All mapped axes must have the same length, and there must be
 * at least one. Outputs are returned with their batch axis moved to 0 (or
 * broadcast there if an output is unbatched).
 * @throws {TypeError} on invalid mapped arguments or inconsistent sizes
 */
function vmapFlat(f, inAxes, args) {
  let axisSize = void 0;
  for (const [i, arg] of args.entries()) {
    const axis = inAxes[i];
    if (axis === null) continue;
    if (!(arg instanceof Tracer)) {
      throw new TypeError("vmap requires Tracer argument for mapped axes");
    }
    const size = arg.shape[axis];
    if (axisSize === void 0) {
      axisSize = size;
    } else if (axisSize !== size) {
      throw new TypeError(
        "vmap requires all mapped axes to have the same size"
      );
    }
  }
  if (axisSize === void 0) {
    throw new TypeError("vmap requires at least one mapped axis");
  }
  let outVals, outDims;
  {
    var _stack = [];
    try {
      const main = __using(_stack, newMain(BatchTrace, axisSize));
      const trace = new BatchTrace(main);
      const tracersIn = args.map((x, i) => {
        const val = pureArray(x);
        return inAxes[i] === null ? val : new BatchTracer(trace, val, inAxes[i]);
      });
      const tracersOut = f(...tracersIn).map((out) => fullRaise(trace, out));
      [outVals, outDims] = unzip2(tracersOut.map((t) => [t.val, t.batchDim]));
    } catch (_) {
      var _error = _, _hasError = true;
    } finally {
      __callDispose(_stack, _error, _hasError);
    }
  }
  return zip(outVals, outDims).map(
    ([val, dim]) => moveBatchAxis(axisSize, dim, 0, val)
  );
}
2761
/**
 * Vectorizing map: returns a function that maps `f` over the batch axes of
 * its arguments.
 * @param f - function to batch
 * @param inAxes - one axis for every argument (number), or a pytree of axes
 *   (or nulls) matching the argument structure
 * @throws {TreeMismatchError} if an `inAxes` pytree does not match the arguments
 */
function vmap(f, inAxes = 0) {
  return (...args) => {
    const [argsFlat, inTree] = flatten(args);
    let axesFlat;
    if (typeof inAxes !== "number") {
      const [flatAxes, axesTree] = flatten(inAxes);
      if (!inTree.equals(axesTree)) {
        throw new TreeMismatchError("vmap", inTree, axesTree);
      }
      axesFlat = flatAxes;
    } else {
      axesFlat = rep(argsFlat.length, inAxes);
    }
    const [fFlat, outTreeBox] = flattenFun(f, inTree);
    const outsFlat = vmapFlat(fFlat, axesFlat, argsFlat);
    if (outTreeBox.value === void 0) {
      throw new Error("outTree was not set in vmap");
    }
    return unflatten(outTreeBox.value, outsFlat);
  };
}
2782
/**
 * Forward-mode Jacobian of `f` at a 1-D input `x`.
 *
 * Each standard basis vector e_j is pushed through `jvp`, batched with
 * `vmap`, which yields the column J[:, j]. `vmap` stacks those columns along
 * a new leading axis, i.e. it produces Jᵀ. The stacked axis is therefore
 * moved to the end of every output leaf. The result follows the JAX
 * convention `out.shape + x.shape`, with J[i, j] = d f_i / d x_j.
 * Leaves of rank <= 1 (scalar outputs, i.e. gradients) need no transpose.
 *
 * @throws {TypeError} if `x` is not 1-D
 */
function jacfwd(f) {
  return function jacobianForward(x) {
    if (x.shape.length !== 1) {
      throw new TypeError("jacfwd only supports 1D inputs");
    }
    const [size2] = x.shape;
    const pushfwd = (v) => jvp(f, [x], [v])[1];
    const columns = vmap(pushfwd, [0])(eye(size2, void 0, { dtype: x.dtype }));
    const [leaves, treedef] = flatten(columns);
    const jac = leaves.map((y) => {
      const rank = ndim(y);
      return rank > 1 ? moveaxis(y, 0, rank - 1) : y;
    });
    return unflatten(treedef, jac);
  };
}
2792
+
2793
+ // src/frontend/linearize.ts
2794
var PartialVal = class _PartialVal {
  /**
   * Value that is either known (`val` set) or unknown (only `aval` known).
   * @param val - concrete value, or null when unknown
   * @param aval - abstract value (ShapedArray)
   */
  constructor(val, aval) {
    this.val = val;
    this.aval = aval;
  }
  // Static factories reference `_PartialVal` (not `this`) so they can be
  // passed unbound, e.g. `xs.map(PartialVal.known)`.
  static known(val) {
    const aval = ShapedArray.fromAval(val.aval);
    return new _PartialVal(val, aval);
  }
  static unknown(aval) {
    const shaped = ShapedArray.fromAval(aval);
    return new _PartialVal(null, shaped);
  }
  get isKnown() {
    return this.val !== null;
  }
  toString() {
    if (this.val) return this.val.toString();
    return this.aval.strShort();
  }
};
2812
/**
 * Partially evaluate flat function `f` on partially-known inputs.
 *
 * Known values are computed eagerly. Computations that depend on unknown
 * inputs are recorded and turned into a Jaxpr over the unknown inputs.
 * @param f - flat function of tracers
 * @param pvalsIn - one PartialVal per argument
 * @returns `{ jaxpr, pvalsOut, consts }`
 *
 * The PartialEvalTrace main is registered with `__using` so that it is
 * disposed (popped) on exit, including when `f` throws. Previously it was
 * never disposed, unlike the mains created by `jvpFlat`, `vmapFlat` and
 * `makeJaxpr`.
 */
function partialEvalFlat(f, pvalsIn) {
  var _stack = [];
  try {
    const main = __using(_stack, newMain(PartialEvalTrace));
    const trace = new PartialEvalTrace(main);
    const tracersIn = pvalsIn.map((pval) => trace.newArg(pval));
    // Take an extra ref on unknown inputs: `f` consumes tracersIn, but these
    // are still needed to build the Jaxpr below.
    const unknownTracersIn = tracersIn.filter((t) => !t.pval.isKnown).map((t) => t.ref);
    const outs = f(...tracersIn);
    const tracersOut = outs.map(
      (out) => fullRaise(trace, out)
    );
    const pvalsOut = tracersOut.map((t) => t.pval);
    const unknownTracersOut = tracersOut.filter((t) => !t.pval.isKnown);
    const { jaxpr, consts } = partialEvalGraphToJaxpr(
      unknownTracersIn,
      unknownTracersOut
    );
    return { jaxpr, pvalsOut, consts };
  } catch (_) {
    var _error = _, _hasError = true;
  } finally {
    __callDispose(_stack, _error, _hasError);
  }
}
2829
/**
 * Partially evaluate the JVP of `f`, with known primals and unknown tangents.
 * @returns `{ primalsOut, jaxpr, consts }`, where `jaxpr` maps
 *   `[...consts, ...tangents]` to the output tangents
 * @throws {TypeError} if any primal output is not known after partial evaluation
 */
function linearizeFlatUtil(f, primalsIn) {
  const knownPvals = primalsIn.map(PartialVal.known);
  const unknownPvals = primalsIn.map((p) => PartialVal.unknown(p.aval));
  // Flat JVP: first half of the arguments are primals, second half tangents.
  const fJvp = (...xs) => {
    const half = xs.length / 2;
    const [ps, ts] = jvp(f, xs.slice(0, half), xs.slice(half, 2 * half));
    return [...ps, ...ts];
  };
  const { jaxpr, pvalsOut, consts } = partialEvalFlat(fJvp, [...knownPvals, ...unknownPvals]);
  const primalPvals = pvalsOut.slice(0, pvalsOut.length / 2);
  if (primalPvals.some((pv) => !pv.isKnown)) {
    throw new TypeError(
      "Not all primal values are known after partial evaluation"
    );
  }
  return { primalsOut: primalPvals.map((pv) => pv.val), jaxpr, consts };
}
2849
/**
 * Linearize flat function `f` at `primalsIn`.
 * @returns `[primalsOut, fLin]`. `fLin(...tangents)` evaluates the linear
 *   Jaxpr, handing it a fresh `.ref` of each const on every call.
 */
function linearizeFlat(f, primalsIn) {
  const { primalsOut: outs, jaxpr: linJaxpr, consts: linConsts } = linearizeFlatUtil(f, primalsIn);
  const fLin = (...tangents) => {
    const constRefs = linConsts.map((c) => c.ref);
    return evalJaxpr(linJaxpr, [...constRefs, ...tangents]);
  };
  return [outs, fLin];
}
2854
/**
 * Linearize `f` at the pytree arguments `primalsIn`.
 *
 * @returns `[primalsOut, fLin]`, where `fLin` maps tangents with the same tree
 *   structure as `primalsIn` to output tangents structured like `primalsOut`.
 * @throws TreeMismatchError when `fLin` receives a tree whose structure differs
 *   from the primal inputs.
 */
function linearize(f, ...primalsIn) {
  const [flatPrimals, inTree] = flatten(primalsIn);
  const [flatF, outTree] = flattenFun(f, inTree);
  const [flatPrimalsOut, flatLin] = linearizeFlat(
    flatF,
    flatPrimals.map(pureArray)
  );
  if (outTree.value === void 0) {
    throw new Error("outTree was not set in linearize");
  }
  const primalsOut = unflatten(outTree.value, flatPrimalsOut);
  const fLin = (...tangentsIn) => {
    const [flatTangents, tangentTree] = flatten(tangentsIn);
    if (!inTree.equals(tangentTree)) {
      throw new TreeMismatchError("linearize", inTree, tangentTree);
    }
    const flatTangentsOut = flatLin(...flatTangents.map(pureArray));
    return unflatten(outTree.value, flatTangentsOut);
  };
  return [primalsOut, fLin];
}
2875
/**
 * Tracer used during partial evaluation, with manual reference counting.
 *
 * Invariant: either `pval` is known and `recipe` is null, or `pval` is
 * unknown and `recipe` describes how the value is produced:
 * "LambdaBinding" (an unknown input), "Const" (a known value promoted to
 * unknown; holds `val`), or "JaxprEqn" (an output of a staged-out primitive;
 * holds `tracersIn`, shared by all outputs of that equation).
 */
var PartialEvalTracer = class extends Tracer {
  constructor(trace, pval, recipe) {
    super(trace);
    this.pval = pval;
    this.recipe = recipe;
    // The creator owns the first reference.
    this.#rc = 1;
  }
  // Reference count of this tracer; resources are released when it hits 0.
  #rc;
  get aval() {
    return this.pval.aval;
  }
  toString() {
    if (!this.recipe) {
      return `PartialEvalTracer(${this.pval.toString()})`;
    } else {
      return `PartialEvalTracer<${this.recipe.type}>(${this.pval.toString()})`;
    }
  }
  /**
   * Take an additional reference and return this same tracer.
   * @throws UseAfterFreeError if the tracer has already been freed.
   */
  get ref() {
    if (this.#rc <= 0) {
      throw new UseAfterFreeError(this);
    }
    this.#rc++;
    return this;
  }
  /**
   * Drop one reference. On reaching zero, release what this tracer owns: its
   * known value, a Const recipe's value, or one reference to every input
   * tracer of a JaxprEqn recipe.
   * @throws UseAfterFreeError if the tracer has already been freed.
   */
  dispose() {
    if (this.#rc <= 0) {
      throw new UseAfterFreeError(this);
    }
    if (--this.#rc === 0) {
      if (this.pval.isKnown) {
        this.pval.val.dispose();
      } else if (this.recipe) {
        if (this.recipe.type === "Const") {
          this.recipe.val.dispose();
        } else if (this.recipe.type === "JaxprEqn") {
          // Each output tracer of an equation releases one reference to
          // each input when it is freed.
          this.recipe.tracersIn.forEach((t) => t.dispose());
        }
      }
    }
  }
  /**
   * If the value is known, consume this tracer and return a new reference to
   * the concrete value; otherwise return the tracer itself (not consumed).
   */
  fullLower() {
    if (this.pval.isKnown) {
      const val = this.pval.val.ref;
      this.dispose();
      return val;
    }
    return this;
  }
};
2928
/**
 * Trace implementing partial evaluation.
 *
 * A primitive whose inputs are all known is executed immediately. Otherwise
 * a "JaxprEqn" recipe is recorded and shared by the unknown output tracers,
 * to be staged out later by partialEvalGraphToJaxpr.
 *
 * Ownership: when a tracer with a JaxprEqn recipe is freed, it releases one
 * reference to every tracer in `recipe.tracersIn`. A recipe with k output
 * tracers must therefore hold k references to each of its inputs.
 */
var PartialEvalTrace = class extends Trace {
  /** Create an input tracer: known values need no recipe; unknown ones are LambdaBindings. */
  newArg(pval) {
    if (pval.isKnown) return new PartialEvalTracer(this, pval, null);
    return new PartialEvalTracer(this, pval, { type: "LambdaBinding" });
  }
  /** Wrap a concrete value as a known tracer. */
  pure(val) {
    return new PartialEvalTracer(this, PartialVal.known(pureArray(val)), null);
  }
  lift = this.pure;
  /**
   * Convert a known tracer into an unknown one whose "Const" recipe holds its
   * value. Consumes `tracer` in that case; unknown tracers are returned as-is.
   */
  instantiateConst(tracer) {
    if (!tracer.pval.isKnown) {
      return tracer;
    } else {
      const pval = PartialVal.unknown(ShapedArray.fromAval(tracer.aval));
      const val = tracer.pval.val.ref;
      tracer.dispose();
      return new PartialEvalTracer(this, pval, { type: "Const", val });
    }
  }
  /**
   * Apply `primitive` to `tracers` (consumed). If every input is known, the
   * primitive is bound on the concrete values. Otherwise, unknown output
   * tracers sharing one JaxprEqn recipe are returned.
   */
  processPrimitive(primitive, tracers, params) {
    if (tracers.every((t) => t.pval.isKnown)) {
      return bind(
        primitive,
        tracers.map((t) => t.fullLower()),
        params
      );
    }
    if (primitive === "jit_call" /* JitCall */) {
      return this.#partialEvalJaxpr(params.jaxpr, params.numConsts, tracers);
    }
    const tracersIn = tracers.map((t) => this.instantiateConst(t));
    const avalsIn = tracersIn.map((t) => t.pval.aval);
    const avalsOut = abstractEvalRules[primitive](avalsIn, params);
    const recipe = {
      type: "JaxprEqn",
      prim: primitive,
      tracersIn,
      params,
      avalsOut,
      tracerRefsOut: []
      // Populated later on
    };
    const tracersOut = avalsOut.map((aval, i) => {
      // The first output inherits the references passed in; every further
      // output needs its own reference to each input (see class docs).
      if (i > 0) {
        tracersIn.forEach((t) => t.ref);
      }
      return new PartialEvalTracer(this, PartialVal.unknown(aval), recipe);
    });
    // Weak references: recording the outputs on the recipe must not keep them alive.
    recipe.tracerRefsOut = tracersOut.map((t) => new WeakRef(t));
    return tracersOut;
  }
  /**
   * Evaluate a Jaxpr on a set of PartialEvalTracers, computing as many known
   * values as possible (with JIT) and forwarding the unknown ones.
   *
   * Used when encountering a JitCall rule during the trace.
   * NOTE(review): `numConsts` is not read here — confirm that consts need no
   * special handling after `jaxpr.flatten()`.
   */
  #partialEvalJaxpr(jaxpr, numConsts, tracers) {
    jaxpr = jaxpr.flatten();
    const inUnknowns = tracers.map((t) => !t.pval.isKnown);
    const { jaxpr1, jaxpr2, outUnknowns, numRes } = partialEvalJaxpr(
      jaxpr,
      inUnknowns
    );
    const [knownTracers, unknownTracers] = partitionList(inUnknowns, tracers);
    // NOTE(review): `.ref` before fullLower() leaves each known tracer holding
    // its original reference, unlike processPrimitive's all-known path, which
    // consumes the tracers — confirm this is not a leak.
    const outs1Res = bind(
      "jit_call" /* JitCall */,
      knownTracers.map((t) => t.ref.fullLower()),
      { jaxpr: jaxpr1, numConsts: 0 }
    );
    // jaxpr1 returns the known outputs followed by the residuals.
    const outs1 = outs1Res.slice(0, jaxpr1.outs.length - numRes);
    const res = outs1Res.slice(jaxpr1.outs.length - numRes);
    const resTracers = res.map(
      (x) => this.instantiateConst(fullRaise(this, x))
    );
    const recipe = {
      type: "JaxprEqn",
      prim: "jit_call" /* JitCall */,
      tracersIn: resTracers.concat(unknownTracers),
      params: { jaxpr: jaxpr2, numConsts: 0 },
      avalsOut: jaxpr2.outs.map((x) => x.aval),
      tracerRefsOut: []
      // populated later
    };
    const outs2 = jaxpr2.outs.map((x, k) => {
      // Every output tracer releases the shared inputs when freed, so each
      // output beyond the first needs its own references (as in
      // processPrimitive). Without this, freeing two outputs double-disposes
      // the inputs.
      if (k > 0) {
        recipe.tracersIn.forEach((t) => t.ref);
      }
      return new PartialEvalTracer(this, PartialVal.unknown(x.aval), recipe);
    });
    recipe.tracerRefsOut = outs2.map((t) => new WeakRef(t));
    // Reassemble outputs in their original order.
    let i = 0;
    let j = 0;
    return outUnknowns.map((unk) => unk ? outs2[j++] : outs1[i++]);
  }
};
3021
/**
 * Split a Jaxpr into a "known" part and an "unknown" part.
 *
 * Equations whose inputs are all known go into `jaxpr1`; all others go into
 * `jaxpr2`. Known variables consumed by `jaxpr2` become residuals: they are
 * extra outputs of `jaxpr1`, fed in as the leading inputs of `jaxpr2`.
 *
 * @param jaxpr - Jaxpr to split. It is flattened first and must then contain
 *   no jit_call equations.
 * @param inUnknowns - One boolean per input binder; true if that input is unknown.
 * @param instantiate - Optional per-output booleans. A known Var output flagged
 *   here is forced to the unknown side and passed through as a residual.
 * @returns `jaxpr1` (known inputs -> known outputs ++ residuals), `jaxpr2`
 *   (residuals ++ unknown inputs -> unknown outputs), `outUnknowns` (one flag
 *   per output), and `numRes` (number of residuals).
 * @throws TypeError if a jit_call equation remains after flattening.
 */
function partialEvalJaxpr(jaxpr, inUnknowns, instantiate) {
  jaxpr = jaxpr.flatten();
  const knownIns = jaxpr.inBinders.filter((_, i) => !inUnknowns[i]);
  const knownVars = new Set(knownIns);
  // Insertion-ordered Set: the residual order (and thus both Jaxpr
  // signatures) follows the order in which residuals are discovered.
  const residuals = /* @__PURE__ */ new Set();
  const eqns1 = [];
  const eqns2 = [];
  for (const eqn of jaxpr.eqns) {
    if (eqn.primitive === "jit_call" /* JitCall */) {
      throw new TypeError("partialEvalJaxpr requires flattened Jaxpr");
    }
    // Non-Var inputs (literals) always count as known.
    const hasUnknowns = eqn.inputs.some(
      (x) => x instanceof Var && !knownVars.has(x)
    );
    if (hasUnknowns) {
      // Known Vars used by an unknown equation must be shipped across.
      for (const x of eqn.inputs) {
        if (x instanceof Var && knownVars.has(x)) {
          residuals.add(x);
        }
      }
      eqns2.push(eqn);
    } else {
      eqns1.push(eqn);
      for (const v of eqn.outBinders) {
        knownVars.add(v);
      }
    }
  }
  const outUnknowns = jaxpr.outs.map(
    (x) => x instanceof Var && !knownVars.has(x)
  );
  if (instantiate !== void 0) {
    for (let i = 0; i < jaxpr.outs.length; i++) {
      const x = jaxpr.outs[i];
      if (instantiate[i] && !outUnknowns[i] && x instanceof Var) {
        residuals.add(x);
        outUnknowns[i] = true;
      }
    }
  }
  const residualsL = Array.from(residuals);
  const [ins1, ins2] = partitionList(inUnknowns, jaxpr.inBinders);
  const [outs1, outs2] = partitionList(outUnknowns, jaxpr.outs);
  const jaxpr1 = new Jaxpr2(ins1, eqns1, outs1.concat(residualsL));
  const jaxpr2 = new Jaxpr2(residualsL.concat(ins2), eqns2, outs2);
  return { jaxpr1, jaxpr2, outUnknowns, numRes: residualsL.length };
}
3068
/**
 * Build a Jaxpr from the recipe graph reachable from `tracersOut`.
 *
 * The resulting Jaxpr's inputs are one binder per distinct Const value
 * (returned as `consts`), followed by one binder per tracer in `tracersIn`.
 * One reference to each tracer in `tracersIn` and `tracersOut` is consumed.
 *
 * @returns The simplified Jaxpr and the list of const values it expects.
 * @throws TypeError if a reachable tracer has no recipe, or a LambdaBinding
 *   tracer is not among `tracersIn`.
 */
function partialEvalGraphToJaxpr(tracersIn, tracersOut) {
  const tracerToVar = /* @__PURE__ */ new Map();
  // Deduplicates Const recipes that share the same value object.
  const constToVar = /* @__PURE__ */ new Map();
  // Recipes are shared by all outputs of an equation; emit each only once.
  const processedEqns = /* @__PURE__ */ new Set();
  const eqns = [];
  for (const t of tracersIn) {
    tracerToVar.set(t, new Var(ShapedArray.fromAval(t.aval)));
  }
  // Visit tracers so that equation inputs are processed before their users.
  for (const t of toposort(
    tracersOut,
    (t2) => t2.recipe?.type === "JaxprEqn" ? t2.recipe.tracersIn : []
  )) {
    if (!t.recipe) {
      throw new TypeError("Tracer is missing a recipe, cannot construct Jaxpr");
    }
    if (t.recipe.type === "LambdaBinding") {
      if (!tracersIn.includes(t)) {
        throw new TypeError("LambdaBinding tracer not in input list");
      }
    } else if (t.recipe.type === "Const") {
      const val = t.recipe.val;
      let binder = constToVar.get(val);
      if (!binder) {
        binder = new Var(ShapedArray.fromAval(val.aval));
        constToVar.set(val, binder);
      }
      tracerToVar.set(t, binder);
    } else if (t.recipe.type === "JaxprEqn") {
      if (!processedEqns.has(t.recipe)) {
        processedEqns.add(t.recipe);
        const tracersIn2 = t.recipe.tracersIn.map((t2) => tracerToVar.get(t2));
        const outBinders = t.recipe.avalsOut.map((aval) => new Var(aval));
        for (let i = 0; i < outBinders.length; i++) {
          // Output tracers are held weakly; one that was already collected
          // cannot be reached from tracersOut, so it needs no mapping.
          const tracerOut = t.recipe.tracerRefsOut[i].deref();
          if (tracerOut) {
            tracerToVar.set(tracerOut, outBinders[i]);
          }
        }
        eqns.push(
          new JaxprEqn(t.recipe.prim, tracersIn2, t.recipe.params, outBinders)
        );
      }
    }
  }
  const [consts, constvars] = unzip2(constToVar.entries());
  const inBinders = [
    ...constvars,
    ...tracersIn.map((t) => tracerToVar.get(t))
  ];
  const outVars = tracersOut.map((t) => tracerToVar.get(t));
  const jaxpr = new Jaxpr2(inBinders, eqns, outVars);
  typecheckJaxpr(jaxpr);
  // Keep the const values alive for the caller: disposing the tracers below
  // can cascade into Const recipes, which dispose their values.
  for (const t of consts) t.ref;
  for (const t of tracersIn) t.dispose();
  for (const t of tracersOut) t.dispose();
  return { jaxpr: jaxpr.simplify(), consts };
}
3125
/**
 * Placeholder for a primal input whose value is not available during
 * transposition: the linear input whose cotangent is being computed.
 */
var UndefPrimal = class {
  /** Shaped abstract value of the missing primal. */
  aval;
  constructor(aval) {
    const shaped = ShapedArray.fromAval(aval);
    this.aval = shaped;
  }
};
3131
/**
 * Evaluate the transpose of a linear Jaxpr (the backward pass).
 *
 * @param jaxpr - Jaxpr to transpose.
 * @param args - One entry per input binder: a concrete value for a known
 *   input, or an UndefPrimal for a linear input to transpose. One reference to
 *   each concrete entry is consumed.
 * @param cotangents - One cotangent per output of `jaxpr`; `null` entries are
 *   ignored.
 * @returns Cotangents for the UndefPrimal inputs, in input order. An input
 *   that received no cotangent gets zeros.
 * @throws TypeError if an equation's primitive has no transpose rule.
 */
function evalJaxprTransposed(jaxpr, args, cotangents) {
  const knownPrimals = /* @__PURE__ */ new Map();
  for (let i = 0; i < jaxpr.inBinders.length; i++) {
    if (!(args[i] instanceof UndefPrimal)) {
      knownPrimals.set(jaxpr.inBinders[i], args[i]);
    }
  }
  // Accumulated cotangent per variable; entries are removed when read.
  const ctStore = /* @__PURE__ */ new Map();
  const readCotangent = (v) => {
    const ct = ctStore.get(v);
    if (ct) {
      ctStore.delete(v);
      return ct;
    } else {
      return zeros(v.aval.shape, { dtype: v.aval.dtype });
    }
  };
  // Sum contributions when a variable is used more than once.
  const writeCotangent = (v, ct) => {
    if (ct !== null) {
      const oldCt = ctStore.get(v);
      if (oldCt) ctStore.set(v, add(oldCt, ct));
      else ctStore.set(v, ct);
    }
  };
  for (let i = 0; i < jaxpr.outs.length; i++) {
    const v = jaxpr.outs[i];
    if (v instanceof Var) writeCotangent(v, cotangents[i]);
  }
  // Walk equations in reverse, pushing cotangents from outputs to inputs.
  for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
    const eqn = jaxpr.eqns[i];
    // Each rule receives its own references to known primals.
    const primalsIn = eqn.inputs.map(
      (v) => v instanceof Lit ? scalar(v.value, { dtype: v.dtype }) : knownPrimals.has(v) ? knownPrimals.get(v).ref : new UndefPrimal(v.aval)
    );
    const cotangentsOut = eqn.outBinders.map(readCotangent);
    const rule = transposeRules[eqn.primitive];
    if (!rule) {
      throw new TypeError(`Backward pass not implemented for ${eqn.primitive}`);
    }
    const cotangentsIn = rule(cotangentsOut, primalsIn, eqn.params);
    for (let j = 0; j < eqn.inputs.length; j++) {
      const v = eqn.inputs[j];
      if (v instanceof Var && !knownPrimals.has(v)) {
        writeCotangent(v, cotangentsIn[j]);
      } else if (cotangentsIn[j] !== null) {
        // Literals and known primals must not receive cotangents.
        throw new Error("invariant violation: cotangent should be null");
      }
    }
  }
  // Release the references to the known inputs passed in via `args`.
  for (const t of knownPrimals.values()) t.dispose();
  const results = [];
  for (let i = 0; i < jaxpr.inBinders.length; i++) {
    if (args[i] instanceof UndefPrimal) {
      results.push(readCotangent(jaxpr.inBinders[i]));
    }
  }
  return results;
}
3188
/**
 * Raised by a transpose rule when the operation is not linear in its
 * undefined (linear) inputs, so it cannot be transposed.
 */
var NonlinearError = class extends TypeError {
  constructor(primitive) {
    const message = `Nonlinear operation in backward pass for ${primitive}`;
    super(message);
  }
};
3193
/**
 * Transpose rules, keyed by primitive.
 *
 * Each rule receives `(cotangentsOut, primalsIn, params)`. Entries of
 * `primalsIn` are either concrete values (owned by the rule) or UndefPrimal
 * placeholders for the linear inputs. The rule returns one cotangent per input,
 * or `null` for non-linear inputs. A rule throws NonlinearError if the
 * operation is not linear in its undefined inputs.
 */
var transposeRules = {
  ["mul" /* Mul */]([ct], [x, y]) {
    // Exactly one factor must be the linear (undefined) input.
    if (x instanceof UndefPrimal === y instanceof UndefPrimal)
      throw new NonlinearError("mul" /* Mul */);
    return [
      x instanceof UndefPrimal ? mul(ct, y) : null,
      y instanceof UndefPrimal ? mul(x, ct) : null
    ];
  },
  ["neg" /* Neg */]([ct], [x]) {
    if (!(x instanceof UndefPrimal)) throw new NonlinearError("neg" /* Neg */);
    return [neg(ct)];
  },
  ["add" /* Add */]([ct], [x, y]) {
    if (!(x instanceof UndefPrimal || y instanceof UndefPrimal))
      throw new NonlinearError("add" /* Add */);
    if (x instanceof UndefPrimal && y instanceof UndefPrimal)
      return [ct.ref, ct];
    return x instanceof UndefPrimal ? (y.dispose(), [ct, null]) : (x.dispose(), [null, ct]);
  },
  ["reduce_sum" /* ReduceSum */]([ct], [x], { axis }) {
    if (!(x instanceof UndefPrimal))
      throw new NonlinearError("reduce_sum" /* ReduceSum */);
    return [broadcast(ct, x.aval.shape, axis)];
  },
  // BUG: Doesn't handle broadcasting.
  ["where" /* Where */]([ct], [cond, x, y]) {
    const cts = [null, null, null];
    if (cond instanceof UndefPrimal) throw new NonlinearError("where" /* Where */);
    if (x instanceof UndefPrimal) {
      const zerosX = zeros(x.aval.shape, { dtype: x.aval.dtype });
      cts[1] = where(cond.ref, ct.ref, zerosX);
    } else {
      x.dispose();
    }
    if (y instanceof UndefPrimal) {
      // Size the zeros from y's own aval. x may be a concrete array that was
      // disposed just above, and it may have a different shape than y.
      const zerosY = zeros(y.aval.shape, { dtype: y.aval.dtype });
      cts[2] = where(cond.ref, zerosY, ct.ref);
    } else {
      y.dispose();
    }
    ct.dispose();
    cond.dispose();
    return cts;
  },
  ["transpose" /* Transpose */]([ct], [x], { perm }) {
    if (!(x instanceof UndefPrimal))
      throw new NonlinearError("transpose" /* Transpose */);
    return [transpose(ct, invertPermutation(perm))];
  },
  ["broadcast" /* Broadcast */]([ct], [x], { axis }) {
    if (!(x instanceof UndefPrimal))
      throw new NonlinearError("broadcast" /* Broadcast */);
    return [reduceSum(ct, axis)];
  },
  ["reshape" /* Reshape */]([ct], [x], _) {
    if (!(x instanceof UndefPrimal))
      throw new NonlinearError("reshape" /* Reshape */);
    return [reshape(ct, x.aval.shape)];
  },
  ["flip" /* Flip */]([ct], [x], { axis }) {
    if (!(x instanceof UndefPrimal)) throw new NonlinearError("flip" /* Flip */);
    return [flip(ct, axis)];
  },
  ["jit_call" /* JitCall */](cts, args, { jaxpr }) {
    const undefPrimals = args.map((x) => x instanceof UndefPrimal);
    const { newJaxpr, newConsts } = transposeJaxpr(jaxpr, undefPrimals);
    const residuals = args.filter((x, i2) => !undefPrimals[i2]);
    // `newConsts` lives in transposeJaxprCache and is reused across calls,
    // while bind consumes its arguments. Pass fresh references so the cached
    // consts are not freed on first use.
    const outs = bind(
      "jit_call" /* JitCall */,
      [...newConsts.map((c) => c.ref), ...residuals, ...cts],
      {
        jaxpr: newJaxpr,
        numConsts: newConsts.length
      }
    );
    let i = 0;
    return undefPrimals.map((isUndef) => isUndef ? outs[i++] : null);
  }
};
3269
/**
 * Cache of transposed Jaxprs:
 * forward Jaxpr -> (JSON-encoded undefPrimals mask -> { newJaxpr, newConsts }).
 *
 * The outer level is a WeakMap keyed by the Jaxpr object. A forward Jaxpr that
 * is no longer referenced can then be garbage-collected together with its
 * cached transposes and consts, instead of being retained forever as with a
 * Map.
 */
var transposeJaxprCache = /* @__PURE__ */ new WeakMap();
/**
 * Build, or fetch from cache, the transpose of `jaxpr` with respect to the
 * inputs flagged in `undefPrimals`.
 *
 * `newJaxpr` takes `newConsts`, then the forward (non-linear) inputs in order,
 * then one cotangent per output of `jaxpr`. It returns the cotangents of the
 * linear inputs.
 *
 * @param jaxpr - Forward Jaxpr; it must typecheck.
 * @param undefPrimals - One boolean per input; true marks a linear input.
 * @returns `{ newJaxpr, newConsts }`. Cache hits return the same object, so
 *   callers must pass fresh references to `newConsts` rather than consuming it.
 */
function transposeJaxpr(jaxpr, undefPrimals) {
  const cacheKey = JSON.stringify(undefPrimals);
  const prevResult = transposeJaxprCache.get(jaxpr)?.get(cacheKey);
  if (prevResult) return prevResult;
  const { inTypes, outTypes } = typecheckJaxpr(jaxpr);
  const forwardInTypes = inTypes.filter((_, i) => !undefPrimals[i]);
  const { jaxpr: newJaxpr, consts: newConsts } = makeJaxpr(
    (forwardIn, cotangents) => {
      // Re-interleave forward inputs with placeholders for the linear inputs.
      const args = [];
      let forwardInIdx = 0;
      for (let i = 0; i < undefPrimals.length; i++) {
        if (undefPrimals[i]) args.push(new UndefPrimal(inTypes[i]));
        else args.push(forwardIn[forwardInIdx++]);
      }
      return evalJaxprTransposed(jaxpr, args, cotangents);
    }
  )(forwardInTypes, outTypes);
  typecheckJaxpr(newJaxpr);
  const result = { newJaxpr, newConsts };
  let byMask = transposeJaxprCache.get(jaxpr);
  if (!byMask) {
    byMask = /* @__PURE__ */ new Map();
    transposeJaxprCache.set(jaxpr, byMask);
  }
  byMask.set(cacheKey, result);
  return result;
}
3294
/**
 * Flat reverse-mode differentiation of `f` at `primalsIn`.
 *
 * @returns `[primalsOut, fVjp]`, where `fVjp(...cotangents)` takes one
 *   cotangent per output and returns one cotangent per entry of `primalsIn`.
 *   `fVjp` may be called any number of times.
 */
function vjpFlat(f, primalsIn) {
  const { primalsOut, jaxpr, consts } = linearizeFlatUtil(f, primalsIn);
  // Explicitly list which arguments should be transposed (the tangent inputs).
  const undefPrimals = primalsIn.map((t) => new UndefPrimal(t.aval));
  const fVjp = (...cotangents) => {
    // evalJaxprTransposed disposes one reference to every concrete argument,
    // so take fresh references to the consts on each call, as linearizeFlat
    // does. Taking them only once made repeated calls hit freed consts.
    const transposeInputs = [...consts.map((c) => c.ref), ...undefPrimals];
    return evalJaxprTransposed(jaxpr, transposeInputs, cotangents);
  };
  return [primalsOut, fVjp];
}
3304
/**
 * Reverse-mode differentiation of `f` at the pytree arguments `primalsIn`.
 *
 * @returns `[primalsOut, fVjp]`, where `fVjp(cotangentsOut)` accepts a tree
 *   shaped like `primalsOut` and returns cotangents shaped like the argument list.
 * @throws TreeMismatchError when the cotangent tree's structure differs from
 *   the outputs.
 */
function vjp(f, ...primalsIn) {
  const [flatPrimals, inTree] = flatten(primalsIn);
  const [flatF, outTree] = flattenFun(f, inTree);
  const [flatPrimalsOut, flatPullback] = vjpFlat(
    flatF,
    flatPrimals.map(pureArray)
  );
  if (outTree.value === void 0) {
    throw new Error("outTree was not set in vjp");
  }
  const primalsOut = unflatten(outTree.value, flatPrimalsOut);
  const fVjp = (cotangentsOut) => {
    const [flatCts, ctTree] = flatten(cotangentsOut);
    if (!outTree.value.equals(ctTree)) {
      throw new TreeMismatchError("vjp", outTree.value, ctTree);
    }
    const flatCtsIn = flatPullback(...flatCts.map(pureArray));
    return unflatten(inTree, flatCtsIn);
  };
  return [primalsOut, fVjp];
}
3325
/**
 * Gradient of a scalar-valued function with respect to its first argument.
 *
 * @returns A function with the same arguments as `f` that returns the
 *   cotangent of the first argument.
 * @throws TypeError if the output is not a 0-D tracer, or is not float32.
 */
function grad(f) {
  return (...x) => {
    const [out, pullback] = vjp(f, ...x);
    const isScalarTracer = out instanceof Tracer && ndim(out) === 0;
    if (!isScalarTracer) {
      throw new TypeError("grad requires a scalar output");
    }
    if (out.dtype !== "float32" /* Float32 */) {
      throw new TypeError("grad currently only supports float32");
    }
    // Seed the backward pass with d(out)/d(out) = 1.
    const seed = pureArray(1);
    return pullback(seed)[0];
  };
}
3337
/**
 * Reverse-mode Jacobian of `f` for a 1-D input. Pulls back each standard basis
 * vector (the rows of an identity matrix), vectorized with vmap.
 *
 * @throws TypeError if the input is not 1-D.
 */
function jacrev(f) {
  return function jacobianReverse(x) {
    if (x.shape.length !== 1) {
      throw new TypeError("jacrev only supports 1D inputs");
    }
    const n = x.shape[0];
    const pullBasisVector = (ct) => {
      const [, fVjp] = vjp(f, x);
      return fVjp(ct)[0];
    };
    return vmap(pullBasisVector, [1])(eye(n, void 0, { dtype: x.dtype }));
  };
}
3347
+
3348
+ // src/nn.ts
3349
// Namespace object exported as `nn`. Each `() => name` thunk refers to a
// function defined later in this module; __export presumably installs them as
// lazy getters. Note that `identity` maps to the local `identity2`.
var nn_exports = {};
__export(nn_exports, {
  identity: () => identity2,
  logSigmoid: () => logSigmoid,
  relu: () => relu,
  relu6: () => relu6,
  sigmoid: () => sigmoid,
  silu: () => silu,
  softSign: () => softSign,
  softplus: () => softplus,
  swish: () => swish
});
3361
+
3362
+ // src/numpy.ts
3363
// Namespace object exported as `numpy`. Public names map to local bindings,
// several of which carry numeric suffixes to avoid clashing with core ops.
// For example, `exp` maps to `exp2` (natural exp) and `exp2` maps to `exp22`
// (base-2 exp).
var numpy_exports = {};
__export(numpy_exports, {
  Array: () => Array3,
  DType: () => DType,
  abs: () => abs,
  absolute: () => absolute,
  add: () => add2,
  allclose: () => allclose,
  arange: () => arange,
  array: () => array,
  bool: () => bool,
  clip: () => clip,
  complex64: () => complex64,
  cos: () => cos2,
  diag: () => diag,
  diagonal: () => diagonal,
  divide: () => divide,
  dot: () => dot,
  e: () => e,
  equal: () => equal2,
  eulerGamma: () => eulerGamma,
  exp: () => exp2,
  exp2: () => exp22,
  eye: () => eye,
  flip: () => flip2,
  fliplr: () => fliplr,
  flipud: () => flipud,
  float32: () => float32,
  full: () => full,
  greater: () => greater2,
  greaterEqual: () => greaterEqual2,
  identity: () => identity,
  inf: () => inf,
  int32: () => int32,
  less: () => less2,
  lessEqual: () => lessEqual2,
  linspace: () => linspace,
  log: () => log2,
  log10: () => log10,
  log2: () => log22,
  matmul: () => matmul,
  matrixTranspose: () => matrixTranspose,
  maximum: () => maximum,
  meshgrid: () => meshgrid,
  minimum: () => minimum,
  moveaxis: () => moveaxis2,
  multiply: () => multiply,
  nan: () => nan,
  ndim: () => ndim2,
  negative: () => negative,
  notEqual: () => notEqual2,
  ones: () => ones,
  permuteDims: () => permuteDims,
  pi: () => pi,
  ravel: () => ravel,
  reciprocal: () => reciprocal2,
  reshape: () => reshape2,
  scalar: () => scalar,
  shape: () => shape,
  sin: () => sin2,
  size: () => size,
  square: () => square,
  sum: () => sum,
  tan: () => tan,
  transpose: () => transpose2,
  trueDivide: () => trueDivide,
  trunc: () => trunc,
  vdot: () => vdot,
  vecdot: () => vecdot,
  where: () => where2,
  zeros: () => zeros
});
3435
// Dtype names, re-exported for numpy-style access (e.g. `numpy.float32`).
var float32 = "float32" /* Float32 */;
var int32 = "int32" /* Int32 */;
var bool = "bool" /* Bool */;
var complex64 = "complex64" /* Complex64 */;
// Mathematical constants exposed on the numpy namespace.
var e = Math.E;
var eulerGamma = 0.5772156649015329;
var inf = Number.POSITIVE_INFINITY;
var nan = Number.NaN;
var pi = Math.PI;
// numpy-style aliases for core operations. Names ending in "2" appear to be
// bundler renames; note that `exp2`/`log2` here alias natural `exp`/`log`,
// not the base-2 functions.
var add2 = add;
var multiply = mul;
var negative = neg;
var reciprocal2 = reciprocal;
var sin2 = sin;
var cos2 = cos;
var exp2 = exp;
var log2 = log;
var minimum = min;
var maximum = max;
var greater2 = greater;
var less2 = less;
var equal2 = equal;
var notEqual2 = notEqual;
var greaterEqual2 = greaterEqual;
var lessEqual2 = lessEqual;
var where2 = where;
var transpose2 = transpose;
var reshape2 = reshape;
var sum = reduceSum;
var moveaxis2 = moveaxis;
var ndim2 = ndim;
var shape = getShape;
3467
/**
 * Number of elements in `a`, or the length of one axis.
 *
 * @param a - Array-like input.
 * @param axis - Optional axis index. Negative values count from the end, as in
 *   numpy.
 * @returns The total element count when `axis` is omitted; otherwise the
 *   extent of that axis (undefined if the axis is out of range).
 */
function size(a, axis) {
  const s = shape(a);
  if (axis === void 0) return prod(s);
  const normalized = axis < 0 ? axis + s.length : axis;
  return s[normalized];
}
3471
/**
 * Reverse the order of elements along the given axes (numpy.flip).
 *
 * @param x - Input array.
 * @param axis - An axis, a list of axes, or undefined for all axes. Negative
 *   axes count from the end. A caller-provided list is not modified.
 * @throws TypeError if an axis is out of bounds or listed twice.
 */
function flip2(x, axis) {
  const nd = ndim2(x);
  let requested;
  if (axis === void 0) {
    requested = range(nd);
  } else if (typeof axis === "number") {
    requested = [axis];
  } else {
    // Copy so that normalizing negative axes never mutates the caller's array.
    requested = [...axis];
  }
  const axes = [];
  const seen = /* @__PURE__ */ new Set();
  for (const ax of requested) {
    if (ax >= nd || ax < -nd) {
      throw new TypeError(
        `flip: axis ${ax} out of bounds for array of ${nd} dimensions`
      );
    }
    const normalized = ax < 0 ? ax + nd : ax;
    if (seen.has(normalized)) {
      throw new TypeError(`flip: duplicate axis ${normalized} in axis list`);
    }
    seen.add(normalized);
    axes.push(normalized);
  }
  return flip(x, axes);
}
3493
/** Reverse `x` along axis 0 (up/down). */
function flipud(x) {
  const rowAxis = 0;
  return flip2(x, rowAxis);
}
/** Reverse `x` along axis 1 (left/right). */
function fliplr(x) {
  const columnAxis = 1;
  return flip2(x, columnAxis);
}
/** Array-API name for transpose. */
var permuteDims = transpose2;
3500
/** Flatten `a` (converted via fudgeArray) into a 1-D array. */
function ravel(a) {
  const arr = fudgeArray(a);
  return arr.ravel();
}
/** Extract a diagonal of `a`; delegates to the array's `diagonal` method. */
function diagonal(a, offset, axis1, axis2) {
  const arr = fudgeArray(a);
  return arr.diagonal(offset, axis1, axis2);
}
3506
/**
 * Swap the last two axes of `x`.
 * @throws TypeError if `x` has fewer than two dimensions.
 */
function matrixTranspose(x) {
  const arr = fudgeArray(x);
  const nd = arr.ndim;
  if (nd < 2)
    throw new TypeError("matrixTranspose only supports 2D+ arrays");
  const perm = [...range(nd - 2), nd - 1, nd - 2];
  return arr.transpose(perm);
}
3512
/**
 * numpy.diag: build a diagonal matrix from a 1-D array, or extract the k-th
 * diagonal from a 2-D array.
 *
 * @param v - 1-D or 2-D array-like.
 * @param k - Diagonal offset (integer). Only 0 is supported for 1-D input.
 * @throws TypeError for a non-integer `k` or an input that is not 1-D/2-D.
 * @throws Error for 1-D input with `k !== 0`.
 */
function diag(v, k = 0) {
  const a = fudgeArray(v);
  if (!Number.isInteger(k))
    throw new TypeError(`k must be an integer, got ${k}`);
  if (a.ndim === 1) {
    // Validate before building the n×n result, so the error path computes
    // (and leaves undisposed) nothing.
    if (k !== 0) throw new Error("diag() for 1D arrays only for k=0");
    const n = a.shape[0];
    // Broadcast `a` along rows and keep only the diagonal entries.
    return where2(eye(n).equal(1), a, 0);
  }
  if (a.ndim === 2) {
    return diagonal(a, k);
  }
  throw new TypeError("numpy.diag only supports 1D and 2D arrays");
}
3527
/**
 * Elementwise tolerance comparison, following numpy.allclose.
 *
 * Two values are close when `|a - b| <= atol + rtol * |b|`. Infinities are close
 * only to the same infinity. NaN is never close to anything unless
 * `equalNan` is set, in which case two NaNs compare close.
 *
 * @param actual - Array-like to test.
 * @param expected - Reference array-like.
 * @param options - `{ rtol = 1e-5, atol = 1e-8, equalNan = false }`.
 * @returns false when the shapes differ or any element is not close.
 */
function allclose(actual, expected, options) {
  const { rtol = 1e-5, atol = 1e-8, equalNan = false } = options ?? {};
  const x = array(actual);
  const y = array(expected);
  if (!deepEqual(x.shape, y.shape)) {
    return false;
  }
  const xData = x.dataSync();
  const yData = y.dataSync();
  for (let i = 0; i < xData.length; i++) {
    const a = xData[i];
    const b = yData[i];
    // Exact equality also covers matching infinities.
    if (a === b) continue;
    if (Number.isNaN(a) || Number.isNaN(b)) {
      if (equalNan && Number.isNaN(a) && Number.isNaN(b)) continue;
      return false;
    }
    // Non-finite values that were not exactly equal can never be close.
    // (The tolerance test alone would accept NaN, or 1 vs Infinity.)
    if (!Number.isFinite(a) || !Number.isFinite(b)) return false;
    if (Math.abs(a - b) > atol + rtol * Math.abs(b)) {
      return false;
    }
  }
  return true;
}
3543
/**
 * Matrix product following numpy.matmul semantics for the supported cases:
 * - 1-D y: contract the last axis of x with y.
 * - 1-D x (with y ≥ 2-D): treat x as a row vector and drop the prepended axis,
 *   giving shape `[...batch, m]`.
 * - Otherwise: batched matrix product over the last two axes, broadcasting the
 *   leading (batch) axes.
 *
 * @throws TypeError if either operand is 0-D.
 */
var matmul = jit(function matmul2(x, y) {
  if (x.ndim === 0 || y.ndim === 0) {
    throw new TypeError("matmul: x and y must be at least 1D");
  }
  if (y.ndim === 1) {
    return x.mul(y).sum(x.ndim - 1);
  }
  if (x.ndim === 1) {
    // Contract x with the second-to-last axis of y. The general path below
    // would leave a spurious leading axis of size 1.
    const nd = y.ndim;
    return y.transpose([...range(nd - 2), nd - 1, nd - 2]).mul(x).sum(nd - 1);
  }
  // x: [..., n, k] -> [..., n, 1, k]
  x = x.reshape(x.shape.toSpliced(-1, 0, 1));
  // y: [..., k, m] -> [..., 1, k, m] -> [..., 1, m, k]
  y = y.reshape(y.shape.toSpliced(-2, 0, 1)).transpose([
    ...range(y.shape.length - 1),
    y.shape.length,
    y.shape.length - 1
  ]);
  // Broadcast to [..., n, m, k] and sum over k.
  return x.mul(y).sum(Math.max(x.ndim, y.ndim) - 1);
});
3558
/**
 * numpy.dot: product of a scalar and an array, inner product with a 1-D y,
 * and otherwise a sum over the last axis of x and the second-to-last axis of y.
 */
var dot = jit(function dot2(x, y) {
  const hasScalar = x.ndim === 0 || y.ndim === 0;
  if (hasScalar) return multiply(x, y);
  if (y.ndim === 1) return x.mul(y).sum(x.ndim - 1);
  const yRank = y.shape.length;
  // Insert (y.ndim - 1) singleton axes before x's last axis, so that x's
  // leading axes broadcast against all of y's non-contracted axes.
  const xExpanded = x.reshape(x.shape.toSpliced(-1, 0, ...rep(y.ndim - 1, 1)));
  // Move y's contracted (second-to-last) axis to the end.
  const ySwapped = y.transpose([...range(yRank - 2), yRank - 1, yRank - 2]);
  return xExpanded.mul(ySwapped).sum(xExpanded.ndim - 1);
});
3573
/** Vector dot product over the last axis, broadcasting the remaining axes. */
var vecdot = jit(function vecdot2(x, y) {
  const product = x.mul(y);
  const lastAxis = Math.max(x.ndim, y.ndim) - 1;
  return product.sum(lastAxis);
});
/** Dot product of the two inputs after flattening both to 1-D. */
function vdot(x, y) {
  const xFlat = ravel(x);
  const yFlat = ravel(y);
  return vecdot(xFlat, yFlat);
}
3579
/**
 * Coordinate grids from 1-D coordinate arrays (numpy.meshgrid).
 *
 * @param xs - List of 1-D arrays.
 * @param options.indexing - "xy" (default, Cartesian: the first two output
 *   axes are swapped) or "ij" (matrix indexing).
 * @returns One grid per input; with fewer than two inputs, `xs` is returned as-is.
 * @throws TypeError if any input is not 1-D.
 */
function meshgrid(xs, { indexing } = {}) {
  const mode = indexing ?? "xy";
  const offending = xs.find((x) => x.ndim !== 1);
  if (offending !== void 0) {
    throw new TypeError(
      `meshgrid: all inputs must be 1D arrays, got ${offending.ndim}D array`
    );
  }
  if (xs.length <= 1) return xs;
  if (mode === "xy") {
    // Cartesian indexing is matrix indexing with the first two inputs swapped.
    const [first, second, ...others] = xs;
    const [gSecond, gFirst, ...gOthers] = meshgrid([second, first, ...others], {
      indexing: "ij"
    });
    return [gFirst, gSecond, ...gOthers];
  }
  const gridShape = xs.map((x) => x.shape[0]);
  return xs.map((x, axisIdx) => {
    const newAxes = range(xs.length).filter((d) => d !== axisIdx);
    return broadcast(x, gridShape, newAxes);
  });
}
3602
/**
 * Limit values to the range [min2, max2] (numpy/jax clip).
 *
 * The lower bound is applied first and the upper bound last, matching
 * numpy/jax: when `min2 > max2`, every element becomes `max2`.
 *
 * @param a - Array-like input.
 * @param min2 - Optional lower bound.
 * @param max2 - Optional upper bound.
 */
function clip(a, min2, max2) {
  let out = fudgeArray(a);
  if (min2 !== void 0) {
    out = maximum(out, min2);
  }
  if (max2 !== void 0) {
    out = minimum(out, max2);
  }
  return out;
}
3612
/** Elementwise absolute value: negative entries are multiplied by -1. */
function absolute(x) {
  const arr = fudgeArray(x);
  const isNegative = less2(arr.ref, 0);
  const negated = arr.ref.mul(-1);
  return where2(isNegative, negated, arr);
}
var abs = absolute;
3617
/** Elementwise x * x. */
function square(x) {
  const arr = fudgeArray(x);
  const left = arr.ref;
  return left.mul(arr);
}
/** Elementwise tangent, computed as sin(x) / cos(x). */
function tan(x) {
  const arr = fudgeArray(x);
  const numerator = sin2(arr.ref);
  const denominator = cos2(arr);
  return numerator.div(denominator);
}
3625
/**
 * Elementwise x / y for floating-point inputs.
 * @throws TypeError if either operand is not a floating-point dtype.
 */
function trueDivide(x, y) {
  const numerator = fudgeArray(x);
  const denominator = fudgeArray(y);
  const bothFloat = isFloatDtype(numerator.dtype) && isFloatDtype(denominator.dtype);
  if (!bothFloat) {
    throw new TypeError(
      `trueDivide: x and y must be floating-point arrays, got ${numerator.dtype} and ${denominator.dtype}`
    );
  }
  return numerator.div(denominator);
}
var divide = trueDivide;
3636
/** Elementwise truncation, implemented as integer division by 1 via `idiv`. */
function trunc(x) {
  const divisor = 1;
  return idiv(x, divisor);
}
/** Base-2 exponential: 2**p == exp(p * ln 2). `exp2` here is the natural exp alias. */
function exp22(p) {
  const scaled = multiply(p, Math.LN2);
  return exp2(scaled);
}
/** Base-2 logarithm: ln(x) * log2(e). `log2` here is the natural log alias. */
function log22(x) {
  const natural = log2(x);
  return natural.mul(Math.LOG2E);
}
/** Base-10 logarithm: ln(x) * log10(e). */
function log10(x) {
  const natural = log2(x);
  return natural.mul(Math.LOG10E);
}
3648
+
3649
+ // src/nn.ts
3650
/** Rectified linear unit: max(x, 0). */
function relu(x) {
  const zero = 0;
  return maximum(x, zero);
}
/** ReLU capped at 6: clip(x, 0, 6). */
function relu6(x) {
  const lower = 0;
  const upper = 6;
  return clip(x, lower, upper);
}
/** Logistic sigmoid: 1 / (1 + exp(-x)). */
function sigmoid(x) {
  const denominator = exp2(negative(x)).add(1);
  return reciprocal2(denominator);
}
/**
 * Softplus: log(1 + exp(x)). `exp2`/`log2` here are the natural exp/log aliases.
 * NOTE(review): exp(x) overflows for large x (about 88 in float32), which makes
 * the result inf; a stable form is max(x, 0) + log(1 + exp(-|x|)).
 */
function softplus(x) {
  const onePlusExp = exp2(x).add(1);
  return log2(onePlusExp);
}
/** Soft sign: x / (1 + |x|). */
function softSign(x) {
  const arr = fudgeArray(x);
  const numerator = arr.ref;
  return numerator.div(absolute(arr).add(1));
}
/** Sigmoid-weighted linear unit: x * sigmoid(x). */
function silu(x) {
  const arr = fudgeArray(x);
  const factor = arr.ref;
  return factor.mul(sigmoid(arr));
}
var swish = silu;
/** Log of the sigmoid: -softplus(-x). */
function logSigmoid(x) {
  const flipped = negative(x);
  return negative(softplus(flipped));
}
var identity2 = fudgeArray;
3675
+
3676
+ // src/polyfills.ts
3677
// Polyfill the explicit-resource-management well-known symbols on runtimes
// that lack them. Symbol.for() returns a registry symbol, so every polyfilled
// copy agrees on the same identity.
Symbol.dispose ??= Symbol.for("Symbol.dispose");
Symbol.asyncDispose ??= Symbol.for("Symbol.asyncDispose");
3679
+
3680
+ // src/index.ts
3681
// Public entry points. Local names suffixed with "2" are re-exported below
// under their unsuffixed names.
var jvp2 = jvp;
var vmap2 = vmap;
var jacfwd2 = jacfwd;
var makeJaxpr2 = makeJaxpr;
var jit2 = jit;
var linearize2 = linearize;
var vjp2 = vjp;
var grad2 = grad;
var jacrev2 = jacrev;
// `jacobian` is the reverse-mode Jacobian (same function as jacrev).
var jacobian = jacrev2;
export {
  devices,
  grad2 as grad,
  init,
  jacfwd2 as jacfwd,
  jacobian,
  jacrev2 as jacrev,
  jit2 as jit,
  jvp2 as jvp,
  linearize2 as linearize,
  makeJaxpr2 as makeJaxpr,
  nn_exports as nn,
  numpy_exports as numpy,
  setDevice,
  tree_exports as tree,
  vjp2 as vjp,
  vmap2 as vmap
};