tensorgrad 0.0.14 → 0.0.16
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/index.d.ts +154 -170
- package/dist/index.js +2208 -39
- package/dist/index.js.map +7 -1
- package/dist/worker.debug.js +553 -0
- package/package.json +60 -58
- package/src/adam.ts +69 -15
- package/src/compile.ts +334 -154
- package/src/index.ts +8 -4
- package/src/module.ts +72 -34
- package/src/runtime.ts +64 -11
- package/src/worker-protocol.ts +183 -0
- package/src/worker-proxy.ts +76 -0
- package/src/worker.ts +281 -0
- package/dist/adam.js +0 -111
- package/dist/adam.js.map +0 -1
- package/dist/buffers.js +0 -120
- package/dist/buffers.js.map +0 -1
- package/dist/capture.js +0 -33
- package/dist/capture.js.map +0 -1
- package/dist/codegen.js +0 -724
- package/dist/codegen.js.map +0 -1
- package/dist/compile.js +0 -180
- package/dist/compile.js.map +0 -1
- package/dist/grad.js +0 -380
- package/dist/grad.js.map +0 -1
- package/dist/ir.js +0 -60
- package/dist/ir.js.map +0 -1
- package/dist/module.js +0 -155
- package/dist/module.js.map +0 -1
- package/dist/nn.js +0 -135
- package/dist/nn.js.map +0 -1
- package/dist/ops.js +0 -326
- package/dist/ops.js.map +0 -1
- package/dist/runtime.js +0 -375
- package/dist/runtime.js.map +0 -1
- package/dist/shape.js +0 -259
- package/dist/shape.js.map +0 -1
- package/dist/trace.js +0 -100
- package/dist/trace.js.map +0 -1
package/dist/runtime.js
DELETED
@@ -1,375 +0,0 @@
-// WebGPU runtime. Reads a BufferPlan + KernelSpec[] (produced by codegen),
-// allocates real GPU buffers and pipelines, and provides a `step()` method
-// that uploads inputs, dispatches all kernels, and reads back outputs.
-//
-// Browser-only: this module needs `navigator.gpu` at runtime.
-/**
- * Activation readbacks for one `step()`/`run()` call. Keyed by the names
- * passed to `capture(name, t)` during the trace. `get(name)` throws if the
- * name isn't registered or wasn't read back this call (i.e., the call was
- * made without `{ withCaptures: true }`); use `has(name)` if you need to
- * branch. `shapeOf(name)` returns the static-after-compile shape and works
- * regardless of whether captures were read back.
- */
-export class Captures {
-    shapes;
-    data;
-    constructor(shapes, data) {
-        this.shapes = shapes;
-        this.data = data;
-    }
-    get(name) {
-        const d = this.data.get(name);
-        if (!d) {
-            const known = [...this.data.keys()].sort().join(', ');
-            const detail = known ? `Known this call: ${known}` : `(call run/step with { withCaptures: true } to populate)`;
-            throw new Error(`Captures.get: '${name}' not present. ${detail}`);
-        }
-        return d;
-    }
-    shapeOf(name) {
-        const s = this.shapes[name];
-        if (!s) {
-            const known = Object.keys(this.shapes).sort().join(', ') || '(none registered)';
-            throw new Error(`Captures.shapeOf: '${name}' not registered. Known: ${known}`);
-        }
-        return s;
-    }
-    has(name) { return this.data.has(name); }
-    names() { return [...this.data.keys()].sort(); }
-}
-// Inlined numeric values (per WebGPU spec) so this module is importable in Node
-// for codegen-only usage. The browser provides GPUBufferUsage as a global, but
-// referencing it at module scope would crash before any browser code runs.
-const STORAGE_RW = 0x80 /*STORAGE*/ | 0x8 /*COPY_DST*/ | 0x4; /*COPY_SRC*/
-const READBACK = 0x1 /*MAP_READ*/ | 0x8; /*COPY_DST*/
-export async function createRuntime(plan, kernels, lossBufferId, opts = {}) {
-    const device = opts.device ?? await acquireDevice();
-    const queue = device.queue;
-    // ---- Allocate one GPUBuffer per BufferSpec --------------------------------
-    // State buffers also get filled with their initValue at allocation time.
-    // Param buffers may be supplied externally via opts.sharedParams; in that
-    // case we reuse the provided GPUBuffer instead of allocating, and the
-    // sibling compile that owns it is responsible for upload + lifetime.
-    // ownedBufferIds tracks which buffers we allocated ourselves (and so must
-    // destroy on .destroy()) vs which were handed in by a sibling compile.
-    const buffers = new Map();
-    const ownedBufferIds = new Set();
-    const sharedParams = opts.sharedParams;
-    for (const spec of plan.buffers) {
-        const shared = spec.kind === 'param' ? sharedParams?.get(spec.name) : undefined;
-        if (shared) {
-            if (shared.size !== spec.byteSize) {
-                throw new Error(`sharedParams: size mismatch for '${spec.name}' — supplied ${shared.size} bytes, ` +
-                    `compiled graph expects ${spec.byteSize}.`);
-            }
-            buffers.set(spec.id, shared);
-            continue;
-        }
-        const buf = device.createBuffer({
-            size: spec.byteSize,
-            usage: STORAGE_RW,
-            label: spec.name ?? `t${spec.id}-${spec.kind}`,
-        });
-        buffers.set(spec.id, buf);
-        ownedBufferIds.add(spec.id);
-        if (spec.kind === 'state')
-            fillStateBuffer(spec, buf);
-    }
-    // ---- Compile pipelines per kernel; cache by WGSL source -------------------
-    // Push an error scope around each shader+pipeline creation so we can surface
-    // the actual compile error rather than the cryptic "previous error" that
-    // comes from using an invalid pipeline at dispatch time.
-    const moduleCache = new Map();
-    const pipelines = [];
-    const probes = [];
-    for (const k of kernels) {
-        if (!k.wgsl) {
-            pipelines.push(null);
-            continue;
-        }
-        let module = moduleCache.get(k.wgsl);
-        if (!module) {
-            module = device.createShaderModule({ code: k.wgsl, label: k.opKind });
-            moduleCache.set(k.wgsl, module);
-        }
-        device.pushErrorScope('validation');
-        const pipeline = device.createComputePipeline({
-            layout: 'auto',
-            compute: { module, entryPoint: 'main' },
-            label: k.opKind,
-        });
-        pipelines.push(pipeline);
-        probes.push(device.popErrorScope().then(err => err ? { k, module: module, err } : null));
-    }
-    const probeResults = await Promise.all(probes);
-    const failures = probeResults.filter((p) => p != null);
-    if (failures.length > 0) {
-        const reports = [];
-        for (const { k, module, err } of failures) {
-            const info = await module.getCompilationInfo();
-            const messages = info.messages
-                .map(m => `  L${m.lineNum}:${m.linePos} [${m.type}] ${m.message}`)
-                .join('\n');
-            reports.push(`[shader compile error] ${k.opKind} (op #${k.opIndex}): ${err.message}\n` +
-                (messages || '  (no compilation messages)') +
-                `\n--- WGSL ---\n${k.wgsl}\n-----------`);
-        }
-        // eslint-disable-next-line no-console
-        console.error(reports.join('\n\n'));
-        throw new Error(`tensorgrad: ${failures.length} shader(s) failed to compile (see console).`);
-    }
-    // ---- Pre-build bind groups (static — buffer ids don't change per step) ---
-    const bindGroups = kernels.map((k, i) => {
-        const pipeline = pipelines[i];
-        if (!pipeline)
-            return null;
-        return device.createBindGroup({
-            layout: pipeline.getBindGroupLayout(0),
-            entries: k.bindings.map((bufId, idx) => ({
-                binding: idx,
-                resource: { buffer: buffers.get(bufId) },
-            })),
-        });
-    });
-    // ---- Output readback staging buffer ---------------------------------------
-    // `outputBufferId` is the graph's main output (loss for training, the user's
-    // returned tensor for forward-only). step() reads back its first element;
-    // run() reads back the full Float32Array.
-    const outputSpec = plan.buffers[lossBufferId];
-    const outputReadback = device.createBuffer({ size: outputSpec.byteSize, usage: READBACK });
-    let captureStaging = null;
-    function ensureCaptureStaging() {
-        if (captureStaging)
-            return captureStaging;
-        let totalBytes = 0;
-        const slices = [];
-        for (const [name, bufId] of plan.capturesByName) {
-            const spec = plan.buffers[bufId];
-            // copyBufferToBuffer offsets must be 4-aligned. Capture byteSizes are
-            // always shape-product × 4 (f32/i32/bool all 4 bytes), so cumulative
-            // offsets stay aligned.
-            slices.push({ name, bufId, offset: totalBytes, byteSize: spec.byteSize });
-            totalBytes += spec.byteSize;
-        }
-        const buffer = device.createBuffer({ size: totalBytes, usage: READBACK, label: 'captures-staging' });
-        captureStaging = { buffer, slices };
-        return captureStaging;
-    }
-    // ---- dispatch() — shared core for step() and run() -----------------------
-    // Uploads inputs, dispatches all kernels (in order), queues writebacks, copies
-    // the output buffer into its staging, optionally copies captures into theirs,
-    // submits, and reads back. Returns the full output Float32Array; step() takes
-    // [0] for scalar loss, run() returns it whole.
-    //
-    // **Concurrent calls auto-serialize.** Two `step()`/`run()` calls on the same
-    // runtime would otherwise both try to `mapAsync` the shared output staging
-    // buffer at the same time and trip "Buffer already has an outstanding map
-    // pending." We chain each new dispatch onto the prior one's promise so they
-    // run sequentially even when fired from independent async paths (e.g., a
-    // training loop's auxiliary `refreshPrediction()` + `writeDiagnostic()`).
-    let pending = Promise.resolve();
-    async function dispatch(inputs, wantCaptures) {
-        const turn = pending.catch(() => { }).then(() => dispatchUnsynchronized(inputs, wantCaptures));
-        pending = turn;
-        return turn;
-    }
-    async function dispatchUnsynchronized(inputs, wantCaptures) {
-        if (wantCaptures && plan.capturesByName.size === 0) {
-            throw new Error(`withCaptures=true but no capture(...) calls were registered during ` +
-                `the trace. Add capture('name', tensor) inside your forward pass for ` +
-                `the intermediates you want read back.`);
-        }
-        for (const [name, bufId] of plan.inputsByName) {
-            const data = inputs[name];
-            if (!data)
-                throw new Error(`tensorgrad: missing input '${name}'`);
-            const expectedBytes = plan.buffers[bufId].byteSize;
-            if (data.byteLength !== expectedBytes) {
-                throw new Error(`tensorgrad: input '${name}' has ${data.byteLength} bytes, expected ${expectedBytes}`);
-            }
-            // Cast to BufferSource: typed arrays are accepted by writeBuffer at runtime
-            // but TS may infer ArrayBufferLike (vs ArrayBuffer) under strict configs.
-            queue.writeBuffer(buffers.get(bufId), 0, data);
-        }
-        const encoder = device.createCommandEncoder({ label: 'tensorgrad-step' });
-        for (let i = 0; i < kernels.length; i++) {
-            const k = kernels[i];
-            if (!k.wgsl || k.threads === 0)
-                continue;
-            const pipeline = pipelines[i];
-            const bindGroup = bindGroups[i];
-            const pass = encoder.beginComputePass({ label: k.opKind });
-            pass.setPipeline(pipeline);
-            pass.setBindGroup(0, bindGroup);
-            // WebGPU caps each dispatch dimension at 65535 workgroups. Split into 2D
-            // when a kernel needs more than that on the X axis. Kernels compute their
-            // global index as `gid.x + gid.y * (65535 * workgroup_size)`, matching the
-            // stride we set here. For dispatches that fit in one row, gid.y is 0.
-            const wgCount = Math.max(1, Math.ceil(k.threads / k.workgroupSize));
-            const MAX_X = 65535;
-            const wgX = Math.min(wgCount, MAX_X);
-            const wgY = Math.ceil(wgCount / MAX_X);
-            pass.dispatchWorkgroups(wgX, wgY, 1);
-            pass.end();
-        }
-        // After all dispatches: writebacks (Adam state, updated params). Empty for
-        // forward-only compiles.
-        for (const wb of plan.writebacks) {
-            encoder.copyBufferToBuffer(buffers.get(wb.source), 0, buffers.get(wb.dest), 0, wb.bytes);
-        }
-        encoder.copyBufferToBuffer(buffers.get(lossBufferId), 0, outputReadback, 0, outputSpec.byteSize);
-        // Capture readbacks (only when opted in). All captures concatenate into
-        // a single staging buffer so we mapAsync once instead of N times.
-        let layout = null;
-        if (wantCaptures) {
-            layout = ensureCaptureStaging();
-            for (const s of layout.slices) {
-                encoder.copyBufferToBuffer(buffers.get(s.bufId), 0, layout.buffer, s.offset, s.byteSize);
-            }
-        }
-        queue.submit([encoder.finish()]);
-        await outputReadback.mapAsync(GPUMapMode.READ);
-        const output = new Float32Array(outputReadback.getMappedRange().slice(0));
-        outputReadback.unmap();
-        const captures = new Map();
-        if (layout) {
-            await layout.buffer.mapAsync(GPUMapMode.READ);
-            const range = layout.buffer.getMappedRange();
-            for (const s of layout.slices) {
-                // Copy out (slice) before unmap — the underlying ArrayBuffer is
-                // detached when the buffer unmaps.
-                captures.set(s.name, new Float32Array(range, s.offset, s.byteSize / 4).slice());
-            }
-            layout.buffer.unmap();
-        }
-        return { output, captures };
-    }
-    async function step(inputs, opts) {
-        const r = await dispatch(inputs, opts?.withCaptures === true);
-        if (opts?.withCaptures)
-            return { loss: r.output[0], captures: new Captures(captureShapes, r.captures) };
-        return r.output[0];
-    }
-    async function run(inputs, opts) {
-        const r = await dispatch(inputs, opts?.withCaptures === true);
-        if (opts?.withCaptures)
-            return { output: r.output, captures: new Captures(captureShapes, r.captures) };
-        return r.output;
-    }
-    // ---- uploadParams ---------------------------------------------------------
-    function uploadParams(params, opts) {
-        const partial = opts?.partial ?? false;
-        for (const name of Object.keys(params)) {
-            if (!plan.paramsByName.has(name)) {
-                throw new Error(`uploadParams: unknown param '${name}'. ` +
-                    `Known: ${[...plan.paramsByName.keys()].sort().join(', ')}`);
-            }
-        }
-        if (!partial) {
-            for (const name of plan.paramsByName.keys()) {
-                if (!(name in params)) {
-                    throw new Error(`uploadParams: missing param '${name}'. ` +
-                        `Pass { partial: true } if you mean to update only some params.`);
-                }
-            }
-        }
-        for (const [name, bufId] of plan.paramsByName) {
-            const data = params[name];
-            if (!data)
-                continue;
-            const expected = plan.buffers[bufId].byteSize / 4;
-            if (data.length !== expected) {
-                throw new Error(`uploadParams: '${name}' has ${data.length} elements, expected ${expected}`);
-            }
-            queue.writeBuffer(buffers.get(bufId), 0, data);
-        }
-    }
-    // ---- download helpers -----------------------------------------------------
-    async function downloadFromMap(map) {
-        const stagings = [];
-        const encoder = device.createCommandEncoder({ label: 'tensorgrad-download' });
-        for (const [name, bufId] of map) {
-            const spec = plan.buffers[bufId];
-            const staging = device.createBuffer({ size: spec.byteSize, usage: READBACK });
-            encoder.copyBufferToBuffer(buffers.get(bufId), 0, staging, 0, spec.byteSize);
-            stagings.push({ name, buf: staging, bytes: spec.byteSize });
-        }
-        queue.submit([encoder.finish()]);
-        const out = {};
-        for (const s of stagings) {
-            await s.buf.mapAsync(GPUMapMode.READ);
-            out[s.name] = new Float32Array(s.buf.getMappedRange().slice(0));
-            s.buf.unmap();
-            s.buf.destroy();
-        }
-        return out;
-    }
-    // Fill a state buffer with its declared initValue (typically 0). Float and
-    // int both serialize to 4 bytes per element. Used at allocation time and on
-    // resetOptimizerState() — same logic, two callers.
-    function fillStateBuffer(spec, target) {
-        const elements = spec.byteSize / 4;
-        const init = spec.dtype === 'f32'
-            ? new Float32Array(elements).fill(spec.initValue ?? 0)
-            : new Int32Array(elements).fill(Math.trunc(spec.initValue ?? 0));
-        queue.writeBuffer(target, 0, init);
-    }
-    function resetOptimizerState() {
-        for (const spec of plan.buffers) {
-            if (spec.kind === 'state')
-                fillStateBuffer(spec, buffers.get(spec.id));
-        }
-    }
-    // Build the params map AFTER buffer allocation so it points at the actual
-    // GPUBuffers (shared or freshly allocated).
-    const params = new Map();
-    for (const [name, bufId] of plan.paramsByName) {
-        params.set(name, buffers.get(bufId));
-    }
-    // Static-after-compile shape metadata so users don't have to recompute
-    // strides to interpret a flat capture readback.
-    const captureShapes = {};
-    for (const [name, bufId] of plan.capturesByName) {
-        captureShapes[name] = [...plan.buffers[bufId].shape];
-    }
-    const outputShape = [...plan.buffers[lossBufferId].shape];
-    const destroy = () => {
-        for (const [id, b] of buffers) {
-            if (ownedBufferIds.has(id))
-                b.destroy();
-        }
-        outputReadback.destroy();
-        if (captureStaging)
-            captureStaging.buffer.destroy();
-    };
-    return {
-        device,
-        params,
-        outputShape,
-        uploadParams,
-        downloadParams: () => downloadFromMap(plan.paramsByName),
-        downloadParamGrads: () => downloadFromMap(plan.paramGradsByName),
-        step,
-        run,
-        resetOptimizerState,
-        destroy,
-    };
-}
-/** Same machinery as `createRuntime`, narrower public type: a forward-only
- * graph exposes `run()` instead of `step()` (no optimizer state, no scalar-
- * loss readback). The full runtime object is built once and projected by
- * `compileForward` to the public shape. */
-export async function createForwardRuntime(plan, kernels, outputBufferId, opts = {}) {
-    return await createRuntime(plan, kernels, outputBufferId, opts);
-}
-async function acquireDevice() {
-    if (typeof navigator === 'undefined' || !navigator.gpu) {
-        throw new Error('tensorgrad: WebGPU not available in this environment');
-    }
-    const adapter = await navigator.gpu.requestAdapter();
-    if (!adapter)
-        throw new Error('tensorgrad: no WebGPU adapter');
-    return await adapter.requestDevice();
-}
-//# sourceMappingURL=runtime.js.map
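
The `Captures` class and the `step()`/`run()` pair above define the runtime's whole readback surface. A minimal usage sketch: the `plan`/`kernels`/`lossBufferId` arguments normally come from the package's compile step, and the input and capture names here (`x`, `y`, `attn`) are hypothetical, standing in for whatever your trace registered.

```ts
const rt = await createRuntime(plan, kernels, lossBufferId);

// Plain training step: resolves to the scalar loss (output[0]).
const loss = await rt.step({ x: xBatch, y: yBatch });

// Opt in to capture readbacks for this call only.
const r = await rt.step({ x: xBatch, y: yBatch }, { withCaptures: true });
console.log(r.loss);
console.log(r.captures.names());         // names read back this call
console.log(r.captures.shapeOf('attn')); // static shape, e.g. [B, H, T, T]
const attn = r.captures.get('attn');     // flat Float32Array, row-major

rt.destroy(); // frees owned GPU buffers and staging
```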
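
The auto-serializing `dispatch()` wrapper is a small reusable pattern: a promise chain acting as an async mutex. A generic sketch of the same idea (`serialize` is a hypothetical helper, not part of the package):

```ts
// Promise-chain mutex: each call starts only after the previous one settles.
// The .catch(() => {}) keeps one rejected call from wedging the whole chain,
// while each caller still sees its own call's rejection via `turn`.
function serialize<A extends unknown[], R>(
  fn: (...args: A) => Promise<R>,
): (...args: A) => Promise<R> {
  let pending: Promise<unknown> = Promise.resolve();
  return (...args: A) => {
    const turn = pending.catch(() => {}).then(() => fn(...args));
    pending = turn;
    return turn;
  };
}
```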
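
The 65535-workgroup cap handling is worth spelling out, since the host-side split and the WGSL index computation have to agree. A sketch of the arithmetic used in the dispatch loop above, with a worked example (`splitDispatch` is a hypothetical standalone extraction):

```ts
const MAX_X = 65535; // WebGPU cap per dispatchWorkgroups() dimension

function splitDispatch(threads: number, workgroupSize: number) {
  const wgCount = Math.max(1, Math.ceil(threads / workgroupSize));
  return { wgX: Math.min(wgCount, MAX_X), wgY: Math.ceil(wgCount / MAX_X) };
}

// 10M threads at workgroup size 64: ceil(10_000_000 / 64) = 156_250
// workgroups overflows one row, so the grid becomes 65535 x 3.
splitDispatch(10_000_000, 64); // { wgX: 65535, wgY: 3 }

// The WGSL side then recovers the flat thread index as
//   let index = gid.x + gid.y * (65535u * WORKGROUP_SIZE);
// and bounds-checks against `threads`, since wgX * wgY can overshoot wgCount.
```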
package/dist/runtime.js.map
DELETED
@@ -1 +0,0 @@
-{"version":3,"file":"runtime.js","sourceRoot":"","sources":["../src/runtime.ts"],"names":[],"mappings":"..."}
package/dist/shape.js
DELETED
@@ -1,259 +0,0 @@
-// Shape inference and validation for each op kind.
-//
-// Every op in src/ops.ts validates its inputs and computes its output shape
-// through helpers here. Errors throw with the captured call-site so the
-// stack trace points at the user's line, not into the library.
-//
-// Broadcasting rules (deliberately limited):
-//   * For element-wise binops (add/sub/mul/div), we support trailing-axis
-//     broadcasting: the smaller operand's shape must be a suffix of the
-//     larger's, with axes of size 1 broadcasting to any size. Examples
-//     ALLOWED:  [B, T, D] op [D]       → [B, T, D]
-//               [B, T, D] op [1, D]    → [B, T, D]
-//               [B, T, D] op [B, T, D] → [B, T, D]
-//     Examples REJECTED: [B, T, D] op [B] (suffix mismatch)
-//               [B, T, D] op [T, D] when T != B (legal numpy, banned here)
-// The restriction makes codegen and autograd much simpler and covers every
-// broadcast pattern in our transformer (biases, layernorm gain/bias, masks).
-import { formatSite } from './ir.js';
-// ============================================================================
-// Errors
-// ============================================================================
-export class ShapeError extends Error {
-    constructor(message, site) {
-        const formatted = site ? `${message}\n  at ${formatSite(site)}` : message;
-        super(formatted);
-        this.name = 'ShapeError';
-    }
-}
-function fail(message, site) {
-    throw new ShapeError(message, site);
-}
-// ============================================================================
-// Shape utilities
-// ============================================================================
-export function shapesEqual(a, b) {
-    if (a.length !== b.length)
-        return false;
-    for (let i = 0; i < a.length; i++)
-        if (a[i] !== b[i])
-            return false;
-    return true;
-}
-export function shapeSize(shape) {
-    let n = 1;
-    for (const d of shape)
-        n *= d;
-    return n;
-}
-export function showShape(shape) {
-    return `[${shape.join(', ')}]`;
-}
-// Standard right-aligned NumPy-style broadcasting. Pad the shorter shape with
-// leading 1s, then per-axis: equal dims unify, size-1 dims broadcast on either
-// side, otherwise incompatible. Returns the resulting shape or null.
-export function broadcastTrailing(a, b) {
-    const rank = Math.max(a.length, b.length);
-    const out = new Array(rank);
-    for (let i = 0; i < rank; i++) {
-        const ai = i - (rank - a.length);
-        const bi = i - (rank - b.length);
-        const av = ai < 0 ? 1 : a[ai];
-        const bv = bi < 0 ? 1 : b[bi];
-        if (av === bv)
-            out[i] = av;
-        else if (av === 1)
-            out[i] = bv;
-        else if (bv === 1)
-            out[i] = av;
-        else
-            return null;
-    }
-    return out;
-}
-// ============================================================================
-// Per-op shape rules
-// ============================================================================
-//
-// Each rule takes the input shapes and returns the output shape, or throws.
-// All rules accept a `site` for error attribution.
-export function inferElementwiseBinop(opName, aShape, bShape, site) {
-    const result = broadcastTrailing(aShape, bShape);
-    if (!result) {
-        fail(`${opName}: incompatible shapes ${showShape(aShape)} and ${showShape(bShape)}. ` +
-            `Trailing-suffix broadcasting only — the smaller shape must be a suffix of the larger, ` +
-            `with size-1 axes broadcasting to any size.`, site);
-    }
-    return result;
-}
-export function inferUnary(_opName, aShape, _site) {
-    return aShape;
-}
-export function inferMeanLast(opName, aShape, site) {
-    if (aShape.length === 0)
-        fail(`${opName}: cannot reduce a 0-d tensor`, site);
-    // keepdims=true: replace last axis with 1.
-    return [...aShape.slice(0, -1), 1];
-}
-export function inferSumLast(opName, aShape, site) {
-    if (aShape.length === 0)
-        fail(`${opName}: cannot reduce a 0-d tensor`, site);
-    // keepdims=false: drop the last axis.
-    return aShape.slice(0, -1);
-}
-export function inferReshape(opName, aShape, newShape, site) {
-    // Validate -1 placeholder (at most one allowed) and total size match.
-    let inferIdx = -1;
-    let knownSize = 1;
-    for (let i = 0; i < newShape.length; i++) {
-        const d = newShape[i];
-        if (d === -1) {
-            if (inferIdx !== -1)
-                fail(`${opName}: at most one -1 dim allowed in newShape ${showShape(newShape)}`, site);
-            inferIdx = i;
-        }
-        else if (d <= 0) {
-            fail(`${opName}: invalid dim ${d} in newShape ${showShape(newShape)}`, site);
-        }
-        else {
-            knownSize *= d;
-        }
-    }
-    const totalIn = shapeSize(aShape);
-    const out = [...newShape];
-    if (inferIdx !== -1) {
-        if (totalIn % knownSize !== 0) {
-            fail(`${opName}: cannot reshape ${showShape(aShape)} (size ${totalIn}) to ${showShape(newShape)} — known dims multiply to ${knownSize}`, site);
-        }
-        out[inferIdx] = totalIn / knownSize;
-    }
-    else if (knownSize !== totalIn) {
-        fail(`${opName}: size mismatch — input ${showShape(aShape)} has ${totalIn} elements but newShape ${showShape(newShape)} has ${knownSize}`, site);
-    }
-    return out;
-}
-export function inferTranspose(opName, aShape, perm, site) {
-    if (perm.length !== aShape.length) {
-        fail(`${opName}: perm length ${perm.length} must equal input rank ${aShape.length}`, site);
-    }
-    const seen = new Set();
-    for (const p of perm) {
-        if (p < 0 || p >= aShape.length)
-            fail(`${opName}: perm index ${p} out of range for rank ${aShape.length}`, site);
-        if (seen.has(p))
-            fail(`${opName}: perm has duplicate index ${p}`, site);
-        seen.add(p);
-    }
-    return perm.map(p => aShape[p]);
-}
-// matmul: a [..., M, K] · b [K, N] → [..., M, N]. b is unbatched.
-export function inferMatmul(opName, aShape, bShape, site) {
-    if (aShape.length < 2)
-        fail(`${opName}: lhs must have rank >= 2, got ${showShape(aShape)}`, site);
-    if (bShape.length !== 2)
-        fail(`${opName}: rhs must have rank 2, got ${showShape(bShape)} — use matmulBatched for batched rhs`, site);
-    const M = aShape[aShape.length - 2];
-    const Ka = aShape[aShape.length - 1];
-    const Kb = bShape[0];
-    const N = bShape[1];
-    if (Ka !== Kb)
-        fail(`${opName}: inner dims don't match — ${showShape(aShape)} · ${showShape(bShape)} (last axis of lhs = ${Ka}, first axis of rhs = ${Kb})`, site);
-    return [...aShape.slice(0, -2), M, N];
-}
-// matmul_batched: a [..., M, K] · b [..., K, N] → [..., M, N]. Both have leading batch dims.
-export function inferMatmulBatched(opName, aShape, bShape, site) {
-    if (aShape.length < 2 || bShape.length < 2) {
-        fail(`${opName}: both inputs must have rank >= 2, got ${showShape(aShape)} and ${showShape(bShape)}`, site);
-    }
-    if (aShape.length !== bShape.length) {
-        fail(`${opName}: ranks must match (got ${aShape.length} vs ${bShape.length}). Reshape if you need different batch dims.`, site);
-    }
-    const aBatch = aShape.slice(0, -2);
-    const bBatch = bShape.slice(0, -2);
-    for (let i = 0; i < aBatch.length; i++) {
-        if (aBatch[i] !== bBatch[i]) {
-            fail(`${opName}: batch dims must match — ${showShape(aShape)} vs ${showShape(bShape)}`, site);
-        }
-    }
-    const M = aShape[aShape.length - 2];
-    const Ka = aShape[aShape.length - 1];
-    const Kb = bShape[bShape.length - 2];
-    const N = bShape[bShape.length - 1];
-    if (Ka !== Kb)
-        fail(`${opName}: inner dims don't match — last axis of lhs = ${Ka}, second-to-last of rhs = ${Kb}`, site);
-    return [...aBatch, M, N];
-}
-export function inferOneHot(opName, indicesShape, depth, site) {
-    if (depth <= 0)
-        fail(`${opName}: depth must be positive, got ${depth}`, site);
-    return [...indicesShape, depth];
-}
-// where_causal preserves shape but requires the last two axes to be square.
-export function inferWhereCausal(opName, aShape, site) {
-    if (aShape.length < 2)
-        fail(`${opName}: requires rank >= 2, got ${showShape(aShape)}`, site);
-    const m = aShape[aShape.length - 2];
-    const n = aShape[aShape.length - 1];
-    if (m !== n)
-        fail(`${opName}: last two axes must be equal (square mask), got ${showShape(aShape)}`, site);
-    return aShape;
-}
-export function inferSliceLastRange(opName, aShape, start, end, site) {
-    if (aShape.length === 0)
-        fail(`${opName}: cannot slice 0-d tensor`, site);
-    const last = aShape[aShape.length - 1];
-    if (start < 0 || end > last || start >= end) {
-        fail(`${opName}: invalid range [${start}, ${end}) for last axis of size ${last}`, site);
-    }
-    return [...aShape.slice(0, -1), end - start];
-}
-// broadcast_to: validate that `aShape` can broadcast to `targetShape` under
-// right-aligned NumPy rules. Returns targetShape on success.
-export function inferBroadcastTo(opName, aShape, targetShape, site) {
-    if (aShape.length > targetShape.length) {
-        fail(`${opName}: source rank ${aShape.length} > target rank ${targetShape.length}`, site);
-    }
-    const offset = targetShape.length - aShape.length;
-    for (let i = 0; i < aShape.length; i++) {
-        const av = aShape[i];
-        const tv = targetShape[offset + i];
-        if (av !== tv && av !== 1) {
-            fail(`${opName}: cannot broadcast ${showShape(aShape)} to ${showShape(targetShape)} — axis ${i} (size ${av}) doesn't match target axis ${offset + i} (size ${tv}) and isn't 1`, site);
-        }
-    }
-    return targetShape;
-}
-// sum_to_shape: validate that `targetShape` is a valid right-aligned reduction
-// of `aShape` (i.e., aShape can have been produced by broadcasting targetShape).
-export function inferSumToShape(opName, aShape, targetShape, site) {
-    if (targetShape.length > aShape.length) {
-        fail(`${opName}: target rank ${targetShape.length} > source rank ${aShape.length}`, site);
-    }
-    const offset = aShape.length - targetShape.length;
-    for (let i = 0; i < targetShape.length; i++) {
-        const av = aShape[offset + i];
-        const tv = targetShape[i];
-        if (av !== tv && tv !== 1) {
-            fail(`${opName}: cannot sum-reduce ${showShape(aShape)} to ${showShape(targetShape)} — target axis ${i} (size ${tv}) must be 1 or match source`, site);
-        }
-    }
-    return targetShape;
-}
-// Three-way broadcast for `where(cond, a, b)`. All three shapes must broadcast
-// to a common shape under standard NumPy rules.
-export function inferWhere(opName, condShape, aShape, bShape, site) {
-    const ab = broadcastTrailing(aShape, bShape);
-    if (!ab)
-        fail(`${opName}: a/b incompatible: ${showShape(aShape)} vs ${showShape(bShape)}`, site);
-    const result = broadcastTrailing(condShape, ab);
-    if (!result)
-        fail(`${opName}: cond ${showShape(condShape)} incompatible with broadcast(a, b) ${showShape(ab)}`, site);
-    return result;
-}
-export function inferReluGrad(opName, xShape, dyShape, site) {
-    if (!shapesEqual(xShape, dyShape)) {
-        fail(`${opName}: x and dy must have matching shapes, got ${showShape(xShape)} and ${showShape(dyShape)}`, site);
-    }
-    return xShape;
-}
-//# sourceMappingURL=shape.js.map
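
Since `broadcastTrailing()` drives `inferElementwiseBinop` and `inferWhere`, a few worked calls make the accepted and rejected shapes concrete. A sketch assuming you import the helper from the package source (`src/shape.ts`; the compiled `dist/shape.js` shown above is gone as of 0.0.16):

```ts
import { broadcastTrailing } from './shape.js';

const B = 8, T = 16, D = 64;
broadcastTrailing([B, T, D], [D]);        // [8, 16, 64]: [D] pads to [1, 1, D]
broadcastTrailing([B, T, D], [1, D]);     // [8, 16, 64]: size-1 axis broadcasts
broadcastTrailing([B, T, D], [B, T, D]);  // [8, 16, 64]: exact match
broadcastTrailing([B, T, D], [B]);        // null: 8 vs 64 on the last axis
```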