tensorgrad 0.0.12 → 0.0.14
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/buffers.js +1 -6
- package/dist/buffers.js.map +1 -1
- package/dist/codegen.js +30 -28
- package/dist/codegen.js.map +1 -1
- package/dist/compile.js +39 -68
- package/dist/compile.js.map +1 -1
- package/dist/grad.js +1 -14
- package/dist/grad.js.map +1 -1
- package/dist/index.d.ts +740 -14
- package/dist/runtime.js +36 -36
- package/dist/runtime.js.map +1 -1
- package/dist/trace.js +8 -13
- package/dist/trace.js.map +1 -1
- package/package.json +15 -18
- package/src/buffers.ts +1 -6
- package/src/codegen.ts +31 -28
- package/src/compile.ts +312 -358
- package/src/grad.ts +1 -11
- package/src/runtime.ts +45 -33
- package/src/trace.ts +12 -9
- package/dist/adam.d.ts +0 -65
- package/dist/adam.d.ts.map +0 -1
- package/dist/buffers.d.ts +0 -57
- package/dist/buffers.d.ts.map +0 -1
- package/dist/capture.d.ts +0 -3
- package/dist/capture.d.ts.map +0 -1
- package/dist/codegen.d.ts +0 -23
- package/dist/codegen.d.ts.map +0 -1
- package/dist/compile.d.ts +0 -130
- package/dist/compile.d.ts.map +0 -1
- package/dist/grad.d.ts +0 -8
- package/dist/grad.d.ts.map +0 -1
- package/dist/index.d.ts.map +0 -1
- package/dist/ir.d.ts +0 -207
- package/dist/ir.d.ts.map +0 -1
- package/dist/module.d.ts +0 -55
- package/dist/module.d.ts.map +0 -1
- package/dist/nn.d.ts +0 -42
- package/dist/nn.d.ts.map +0 -1
- package/dist/ops.d.ts +0 -48
- package/dist/ops.d.ts.map +0 -1
- package/dist/runtime.d.ts +0 -115
- package/dist/runtime.d.ts.map +0 -1
- package/dist/shape.d.ts +0 -24
- package/dist/shape.d.ts.map +0 -1
- package/dist/trace.d.ts +0 -9
- package/dist/trace.d.ts.map +0 -1
package/dist/index.d.ts
CHANGED
|
@@ -1,14 +1,740 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
1
|
+
type Dtype = 'f32' | 'i32' | 'bool';
|
|
2
|
+
type Shape = readonly number[];
|
|
3
|
+
interface Tensor {
|
|
4
|
+
readonly id: number;
|
|
5
|
+
readonly shape: Shape;
|
|
6
|
+
readonly dtype: Dtype;
|
|
7
|
+
readonly source: number | null;
|
|
8
|
+
readonly site: CallSite | null;
|
|
9
|
+
}
|
|
10
|
+
interface CallSite {
|
|
11
|
+
readonly opName: string;
|
|
12
|
+
readonly stack: string;
|
|
13
|
+
}
|
|
14
|
+
type OpNode = {
|
|
15
|
+
kind: 'param_input';
|
|
16
|
+
out: number;
|
|
17
|
+
name: string;
|
|
18
|
+
} | {
|
|
19
|
+
kind: 'tensor_input';
|
|
20
|
+
out: number;
|
|
21
|
+
name: string;
|
|
22
|
+
} | {
|
|
23
|
+
kind: 'state_input';
|
|
24
|
+
out: number;
|
|
25
|
+
name: string;
|
|
26
|
+
initValue: number;
|
|
27
|
+
} | {
|
|
28
|
+
kind: 'add';
|
|
29
|
+
out: number;
|
|
30
|
+
a: number;
|
|
31
|
+
b: number;
|
|
32
|
+
} | {
|
|
33
|
+
kind: 'sub';
|
|
34
|
+
out: number;
|
|
35
|
+
a: number;
|
|
36
|
+
b: number;
|
|
37
|
+
} | {
|
|
38
|
+
kind: 'mul';
|
|
39
|
+
out: number;
|
|
40
|
+
a: number;
|
|
41
|
+
b: number;
|
|
42
|
+
} | {
|
|
43
|
+
kind: 'div';
|
|
44
|
+
out: number;
|
|
45
|
+
a: number;
|
|
46
|
+
b: number;
|
|
47
|
+
} | {
|
|
48
|
+
kind: 'mul_scalar';
|
|
49
|
+
out: number;
|
|
50
|
+
a: number;
|
|
51
|
+
scalar: number;
|
|
52
|
+
} | {
|
|
53
|
+
kind: 'add_scalar';
|
|
54
|
+
out: number;
|
|
55
|
+
a: number;
|
|
56
|
+
scalar: number;
|
|
57
|
+
} | {
|
|
58
|
+
kind: 'sqrt';
|
|
59
|
+
out: number;
|
|
60
|
+
a: number;
|
|
61
|
+
} | {
|
|
62
|
+
kind: 'rsqrt';
|
|
63
|
+
out: number;
|
|
64
|
+
a: number;
|
|
65
|
+
} | {
|
|
66
|
+
kind: 'log';
|
|
67
|
+
out: number;
|
|
68
|
+
a: number;
|
|
69
|
+
} | {
|
|
70
|
+
kind: 'exp';
|
|
71
|
+
out: number;
|
|
72
|
+
a: number;
|
|
73
|
+
} | {
|
|
74
|
+
kind: 'relu';
|
|
75
|
+
out: number;
|
|
76
|
+
a: number;
|
|
77
|
+
} | {
|
|
78
|
+
kind: 'mean_last';
|
|
79
|
+
out: number;
|
|
80
|
+
a: number;
|
|
81
|
+
} | {
|
|
82
|
+
kind: 'sum_last';
|
|
83
|
+
out: number;
|
|
84
|
+
a: number;
|
|
85
|
+
} | {
|
|
86
|
+
kind: 'reshape';
|
|
87
|
+
out: number;
|
|
88
|
+
a: number;
|
|
89
|
+
newShape: Shape;
|
|
90
|
+
} | {
|
|
91
|
+
kind: 'transpose';
|
|
92
|
+
out: number;
|
|
93
|
+
a: number;
|
|
94
|
+
perm: readonly number[];
|
|
95
|
+
} | {
|
|
96
|
+
kind: 'matmul';
|
|
97
|
+
out: number;
|
|
98
|
+
a: number;
|
|
99
|
+
b: number;
|
|
100
|
+
} | {
|
|
101
|
+
kind: 'matmul_batched';
|
|
102
|
+
out: number;
|
|
103
|
+
a: number;
|
|
104
|
+
b: number;
|
|
105
|
+
} | {
|
|
106
|
+
kind: 'one_hot';
|
|
107
|
+
out: number;
|
|
108
|
+
indices: number;
|
|
109
|
+
depth: number;
|
|
110
|
+
dtype: Dtype;
|
|
111
|
+
} | {
|
|
112
|
+
kind: 'arange';
|
|
113
|
+
out: number;
|
|
114
|
+
n: number;
|
|
115
|
+
dtype: Dtype;
|
|
116
|
+
} | {
|
|
117
|
+
kind: 'softmax_causal_last';
|
|
118
|
+
out: number;
|
|
119
|
+
a: number;
|
|
120
|
+
} | {
|
|
121
|
+
kind: 'log_softmax_last';
|
|
122
|
+
out: number;
|
|
123
|
+
a: number;
|
|
124
|
+
} | {
|
|
125
|
+
kind: 'where_causal';
|
|
126
|
+
out: number;
|
|
127
|
+
a: number;
|
|
128
|
+
fillValue: number;
|
|
129
|
+
} | {
|
|
130
|
+
kind: 'less';
|
|
131
|
+
out: number;
|
|
132
|
+
a: number;
|
|
133
|
+
b: number;
|
|
134
|
+
} | {
|
|
135
|
+
kind: 'greater';
|
|
136
|
+
out: number;
|
|
137
|
+
a: number;
|
|
138
|
+
b: number;
|
|
139
|
+
} | {
|
|
140
|
+
kind: 'where';
|
|
141
|
+
out: number;
|
|
142
|
+
cond: number;
|
|
143
|
+
a: number;
|
|
144
|
+
b: number;
|
|
145
|
+
} | {
|
|
146
|
+
kind: 'adam_update_m';
|
|
147
|
+
out: number;
|
|
148
|
+
m: number;
|
|
149
|
+
g: number;
|
|
150
|
+
b1: number;
|
|
151
|
+
} | {
|
|
152
|
+
kind: 'adam_update_v';
|
|
153
|
+
out: number;
|
|
154
|
+
v: number;
|
|
155
|
+
g: number;
|
|
156
|
+
b2: number;
|
|
157
|
+
} | {
|
|
158
|
+
kind: 'adam_update_p';
|
|
159
|
+
out: number;
|
|
160
|
+
p: number;
|
|
161
|
+
mNew: number;
|
|
162
|
+
vNew: number;
|
|
163
|
+
lrt: number;
|
|
164
|
+
eps: number;
|
|
165
|
+
decayShrink: number;
|
|
166
|
+
decayShrinkTensor: number | null;
|
|
167
|
+
} | {
|
|
168
|
+
kind: 'slice_last_range';
|
|
169
|
+
out: number;
|
|
170
|
+
a: number;
|
|
171
|
+
start: number;
|
|
172
|
+
end: number;
|
|
173
|
+
} | {
|
|
174
|
+
kind: 'broadcast_to';
|
|
175
|
+
out: number;
|
|
176
|
+
a: number;
|
|
177
|
+
targetShape: Shape;
|
|
178
|
+
} | {
|
|
179
|
+
kind: 'sum_to_shape';
|
|
180
|
+
out: number;
|
|
181
|
+
a: number;
|
|
182
|
+
targetShape: Shape;
|
|
183
|
+
} | {
|
|
184
|
+
kind: 'const_scalar';
|
|
185
|
+
out: number;
|
|
186
|
+
value: number;
|
|
187
|
+
dtype: Dtype;
|
|
188
|
+
} | {
|
|
189
|
+
kind: 'relu_grad';
|
|
190
|
+
out: number;
|
|
191
|
+
x: number;
|
|
192
|
+
dy: number;
|
|
193
|
+
};
|
|
194
|
+
interface Graph {
|
|
195
|
+
readonly ops: OpNode[];
|
|
196
|
+
readonly tensors: Tensor[];
|
|
197
|
+
readonly outputs: number[];
|
|
198
|
+
readonly captures: Map<string, number>;
|
|
199
|
+
}
|
|
200
|
+
|
|
201
|
+
declare class ShapeError extends Error {
|
|
202
|
+
constructor(message: string, site: CallSite | null);
|
|
203
|
+
}
|
|
204
|
+
|
|
205
|
+
declare function trace(fn: () => Tensor | Tensor[]): Graph;
|
|
206
|
+
declare function traceInto<T>(g: Graph, fn: () => T): T;
|
|
207
|
+
declare function paramInput(name: string, shape: Shape, dtype?: Dtype): Tensor;
|
|
208
|
+
declare function tensorInput(name: string, shape: Shape, dtype?: Dtype): Tensor;
|
|
209
|
+
declare function stateInput(name: string, shape: Shape, dtype?: Dtype, initValue?: number): Tensor;
|
|
210
|
+
|
|
211
|
+
declare function capture<T extends Tensor>(name: string, t: T): T;
|
|
212
|
+
|
|
213
|
+
declare function add(a: Tensor, b: Tensor | number): Tensor;
|
|
214
|
+
declare function sub(a: Tensor, b: Tensor | number): Tensor;
|
|
215
|
+
declare function mul(a: Tensor, b: Tensor | number): Tensor;
|
|
216
|
+
declare function div(a: Tensor, b: Tensor | number): Tensor;
|
|
217
|
+
declare const sqrt: (a: Tensor) => Tensor;
|
|
218
|
+
declare const rsqrt: (a: Tensor) => Tensor;
|
|
219
|
+
declare const log: (a: Tensor) => Tensor;
|
|
220
|
+
declare const exp: (a: Tensor) => Tensor;
|
|
221
|
+
declare const relu: (a: Tensor) => Tensor;
|
|
222
|
+
declare function meanLast(a: Tensor): Tensor;
|
|
223
|
+
declare function sumLast(a: Tensor): Tensor;
|
|
224
|
+
/** Reduce all elements to a 0-d scalar. Composes `reshape` + `sumLast`. */
|
|
225
|
+
declare function sumAll(a: Tensor): Tensor;
|
|
226
|
+
declare function reshape(a: Tensor, newShape: Shape): Tensor;
|
|
227
|
+
declare function transpose(a: Tensor, perm: readonly number[]): Tensor;
|
|
228
|
+
/** Swap two axes of a tensor. Negative indices count from the end (so
|
|
229
|
+
* `swapAxes(x, -1, -2)` swaps the last two — the common attention pattern).
|
|
230
|
+
* All other axes keep their position. Implemented as `transpose` with the
|
|
231
|
+
* permutation `[0, 1, ..., axis2, ..., axis1, ..., n-1]`. */
|
|
232
|
+
declare function swapAxes(a: Tensor, axis1: number, axis2: number): Tensor;
|
|
233
|
+
declare function matmul(a: Tensor, b: Tensor): Tensor;
|
|
234
|
+
declare function matmulBatched(a: Tensor, b: Tensor): Tensor;
|
|
235
|
+
declare function oneHot(indices: Tensor, depth: number, dtype?: Dtype): Tensor;
|
|
236
|
+
/** Embedding lookup: pull rows from `table` indexed by `indices`. Decomposes
|
|
237
|
+
* to `oneHot(indices, vocab) @ table` so autograd works without a dedicated
|
|
238
|
+
* scatter-with-atomic-add backward — the matmul transpose rule handles it.
|
|
239
|
+
* `table` is `[vocab, dim]`; `indices` is any shape `[...]` of i32; result
|
|
240
|
+
* is `[..., dim]`. The vocab size is taken from `table.shape[0]`. */
|
|
241
|
+
declare function embedding(table: Tensor, indices: Tensor): Tensor;
|
|
242
|
+
declare function arange(n: number, dtype?: Dtype): Tensor;
|
|
243
|
+
declare function softmaxCausalLast(a: Tensor): Tensor;
|
|
244
|
+
declare function logSoftmaxLast(a: Tensor): Tensor;
|
|
245
|
+
declare function whereCausal(a: Tensor, fillValue: number): Tensor;
|
|
246
|
+
declare function sliceLastRange(a: Tensor, start: number, end: number): Tensor;
|
|
247
|
+
declare const less: (a: Tensor, b: Tensor) => Tensor;
|
|
248
|
+
declare const greater: (a: Tensor, b: Tensor) => Tensor;
|
|
249
|
+
declare function where(cond: Tensor, a: Tensor, b: Tensor): Tensor;
|
|
250
|
+
|
|
251
|
+
interface GradResult {
|
|
252
|
+
readonly graph: Graph;
|
|
253
|
+
readonly paramGrads: Record<string, Tensor>;
|
|
254
|
+
readonly loss: Tensor;
|
|
255
|
+
}
|
|
256
|
+
declare function appendGrad(graph: Graph): GradResult;
|
|
257
|
+
|
|
258
|
+
interface BufferSpec {
|
|
259
|
+
/** Matches tensor.id. */
|
|
260
|
+
id: number;
|
|
261
|
+
byteSize: number;
|
|
262
|
+
dtype: Dtype;
|
|
263
|
+
shape: Shape;
|
|
264
|
+
kind: 'param' | 'param_grad' | 'tensor_input' | 'state' | 'intermediate' | 'output';
|
|
265
|
+
/** External name for param/param_grad/tensor_input/state bindings. null otherwise. */
|
|
266
|
+
name: string | null;
|
|
267
|
+
/** For state buffers: the value to fill on initial allocation. 0 by default. */
|
|
268
|
+
initValue?: number;
|
|
269
|
+
}
|
|
270
|
+
/**
|
|
271
|
+
* After step(), copy `source`'s buffer into `dest`'s buffer.
|
|
272
|
+
* Used to write back updated optimizer state and updated parameters into
|
|
273
|
+
* their persistent home buffers.
|
|
274
|
+
*/
|
|
275
|
+
interface Writeback {
|
|
276
|
+
source: number;
|
|
277
|
+
dest: number;
|
|
278
|
+
bytes: number;
|
|
279
|
+
}
|
|
280
|
+
interface BufferPlan {
|
|
281
|
+
buffers: BufferSpec[];
|
|
282
|
+
/** Tensor id -> buffer id (currently 1:1 but kept opaque for future pooling). */
|
|
283
|
+
tensorToBuffer: Map<number, number>;
|
|
284
|
+
/** Easy lookup tables for the runtime. */
|
|
285
|
+
paramsByName: Map<string, number>;
|
|
286
|
+
inputsByName: Map<string, number>;
|
|
287
|
+
paramGradsByName: Map<string, number>;
|
|
288
|
+
statesByName: Map<string, number>;
|
|
289
|
+
capturesByName: Map<string, number>;
|
|
290
|
+
outputBufferIds: number[];
|
|
291
|
+
/** End-of-step writebacks (Adam updates for params, m, v, etc.) */
|
|
292
|
+
writebacks: Writeback[];
|
|
293
|
+
}
|
|
294
|
+
/**
|
|
295
|
+
* Caller-supplied writeback declarations: "after each step, copy this Tensor's
|
|
296
|
+
* buffer into the persistent home of this param/state."
|
|
297
|
+
*/
|
|
298
|
+
interface WritebackDecl {
|
|
299
|
+
/** The Tensor (output of some op) holding the new value to write back. */
|
|
300
|
+
source: Tensor;
|
|
301
|
+
/** Either a param name (writes to that param's home buffer) or a state name. */
|
|
302
|
+
destName: string;
|
|
303
|
+
destKind: 'param' | 'state';
|
|
304
|
+
}
|
|
305
|
+
/**
|
|
306
|
+
* Build a BufferPlan from a graph + the param-grad map produced by appendGrad.
|
|
307
|
+
* @param graph the full graph (forward + backward + any optimizer ops)
|
|
308
|
+
* @param paramGrads map from param name -> the Tensor that holds its gradient
|
|
309
|
+
* @param writebackDecls list of end-of-step writebacks (e.g. from appendAdam).
|
|
310
|
+
* Empty when there's no optimizer in the graph.
|
|
311
|
+
*/
|
|
312
|
+
declare function planBuffers(graph: Graph, paramGrads: Record<string, Tensor>, writebackDecls?: WritebackDecl[]): BufferPlan;
|
|
313
|
+
|
|
314
|
+
interface AdamConfig {
|
|
315
|
+
/** Constant scalar (e.g., `0.005`) or a per-step schedule function
|
|
316
|
+
* `(step) => lr`. Schedule fn lets the user implement linear/cosine decay
|
|
317
|
+
* or warmup; first call passes `step=1`. Decay-shrink (AdamW) updates
|
|
318
|
+
* per-step automatically when this is a function. */
|
|
319
|
+
lr: number | ((step: number) => number);
|
|
320
|
+
b1?: number;
|
|
321
|
+
b2?: number;
|
|
322
|
+
eps?: number;
|
|
323
|
+
/** AdamW: decoupled weight decay coefficient. Default 0 (plain Adam).
|
|
324
|
+
* When non-zero, every step shrinks each decayed param by a factor of
|
|
325
|
+
* `1 - lr * weightDecay` before the gradient update. */
|
|
326
|
+
weightDecay?: number;
|
|
327
|
+
/** Filter deciding which params get weight decay. Only consulted when
|
|
328
|
+
* weightDecay > 0. Default: decay every param. Override for the standard
|
|
329
|
+
* transformer convention (decay weights/embeddings, skip biases + LN gains).
|
|
330
|
+
* Example: `(name) => name.includes('.W') || name.endsWith('_emb')`. */
|
|
331
|
+
decayFilter?: (paramName: string) => boolean;
|
|
332
|
+
}
|
|
333
|
+
/** Resolved hyperparameters: lr is the schedule fn (constants are wrapped). */
|
|
334
|
+
interface AdamResolvedConfig {
|
|
335
|
+
lr: (step: number) => number;
|
|
336
|
+
b1: number;
|
|
337
|
+
b2: number;
|
|
338
|
+
eps: number;
|
|
339
|
+
weightDecay: number;
|
|
340
|
+
decayFilter: (name: string) => boolean;
|
|
341
|
+
/** True iff the user supplied an lr function (vs a constant). When false,
|
|
342
|
+
* decayShrink is baked at compile time and never updated. */
|
|
343
|
+
lrIsScheduled: boolean;
|
|
344
|
+
}
|
|
345
|
+
interface AdamResult {
|
|
346
|
+
/** Writebacks the buffer planner should wire into the runtime. */
|
|
347
|
+
writebacks: WritebackDecl[];
|
|
348
|
+
/** Name of the per-step scalar tensor_input. The runtime fills this each call
|
|
349
|
+
* with `lr * sqrt(1-b2^t)/(1-b1^t)` (Adam's bias-corrected effective LR). */
|
|
350
|
+
lrtInputName: string;
|
|
351
|
+
/** Name of the per-step decayShrink scalar tensor_input, or null when lr is
|
|
352
|
+
* static (decayShrink baked into the kernel) or no params are decayed. */
|
|
353
|
+
decayShrinkInputName: string | null;
|
|
354
|
+
/** Hyperparameters as captured (so the runtime can compute lrt and decayShrink). */
|
|
355
|
+
config: AdamResolvedConfig;
|
|
356
|
+
}
|
|
357
|
+
/**
|
|
358
|
+
* Append Adam update ops to `graph`. Must be called inside an active trace
|
|
359
|
+
* context (or after a trace, since traceInto re-enters the graph).
|
|
360
|
+
*
|
|
361
|
+
* @param graph the graph (already containing forward + backward)
|
|
362
|
+
* @param paramGrads param name -> gradient tensor (output of `appendGrad`)
|
|
363
|
+
* @param paramTensors param name -> the param's leaf Tensor (the param_input).
|
|
364
|
+
* Needed because the param_input lives in the graph but we
|
|
365
|
+
* don't have a direct map by name in `Graph` — caller passes it.
|
|
366
|
+
* @param config Adam hyperparameters. Set `weightDecay > 0` for AdamW; an
|
|
367
|
+
* optional `decayFilter` selects which params receive decay.
|
|
368
|
+
*/
|
|
369
|
+
declare function appendAdam(graph: Graph, paramGrads: Record<string, Tensor>, paramTensors: Record<string, Tensor>, config: AdamConfig,
|
|
370
|
+
/** Per-param decay flags from `materializeParams`. When supplied, overrides
|
|
371
|
+
* `config.decayFilter` for any name in the map; falls back to `decayFilter`
|
|
372
|
+
* for names not present (e.g., for low-level callers using `compile()`
|
|
373
|
+
* directly without a Module). */
|
|
374
|
+
decayFlags?: Record<string, boolean>): AdamResult;
|
|
375
|
+
|
|
376
|
+
interface KernelSpec {
|
|
377
|
+
/** Index into graph.ops. */
|
|
378
|
+
opIndex: number;
|
|
379
|
+
/** Op kind (for debugging / pipeline cache key). */
|
|
380
|
+
opKind: OpNode['kind'];
|
|
381
|
+
/** Generated WGSL source. Empty string for "logical" ops with no kernel. */
|
|
382
|
+
wgsl: string;
|
|
383
|
+
/**
|
|
384
|
+
* Buffer ids in binding-index order. The runtime creates a bind group with
|
|
385
|
+
* these in @binding(0..N) on @group(0). Inputs come first (read), output last
|
|
386
|
+
* (read_write).
|
|
387
|
+
*/
|
|
388
|
+
bindings: number[];
|
|
389
|
+
/** Number of threads to dispatch (1-D). 0 means "skip" (e.g. reshape no-op). */
|
|
390
|
+
threads: number;
|
|
391
|
+
/** Workgroup size; usually WG_SIZE. */
|
|
392
|
+
workgroupSize: number;
|
|
393
|
+
}
|
|
394
|
+
/** Generate a KernelSpec per compute op in graph.ops (in dispatch order). */
|
|
395
|
+
declare function emitKernels(graph: Graph, plan: BufferPlan): KernelSpec[];
|
|
396
|
+
|
|
397
|
+
interface UploadParamsOptions {
|
|
398
|
+
/** Skip the "missing param" check, allowing the caller to update only some
|
|
399
|
+
* params and leave the rest at their current GPU values. Extra (unknown)
|
|
400
|
+
* keys are still rejected — that's always a typo. Default: false. */
|
|
401
|
+
partial?: boolean;
|
|
402
|
+
}
|
|
403
|
+
/**
|
|
404
|
+
* Activation readbacks for one `step()`/`run()` call. Keyed by the names
|
|
405
|
+
* passed to `capture(name, t)` during the trace. `get(name)` throws if the
|
|
406
|
+
* name isn't registered or wasn't read back this call (i.e., the call was
|
|
407
|
+
* made without `{ withCaptures: true }`); use `has(name)` if you need to
|
|
408
|
+
* branch. `shapeOf(name)` returns the static-after-compile shape and works
|
|
409
|
+
* regardless of whether captures were read back.
|
|
410
|
+
*/
|
|
411
|
+
declare class Captures {
|
|
412
|
+
private readonly shapes;
|
|
413
|
+
private readonly data;
|
|
414
|
+
constructor(shapes: Record<string, readonly number[]>, data: Map<string, Float32Array>);
|
|
415
|
+
get(name: string): Float32Array;
|
|
416
|
+
shapeOf(name: string): readonly number[];
|
|
417
|
+
has(name: string): boolean;
|
|
418
|
+
names(): string[];
|
|
419
|
+
}
|
|
420
|
+
interface RunResult {
|
|
421
|
+
output: Float32Array;
|
|
422
|
+
captures: Captures;
|
|
423
|
+
}
|
|
424
|
+
interface StepResult {
|
|
425
|
+
loss: number;
|
|
426
|
+
captures: Captures;
|
|
427
|
+
}
|
|
428
|
+
interface RunOptions {
|
|
429
|
+
/** Read back tensors registered via `capture(name, t)` during the trace.
|
|
430
|
+
* Default false. When false, the returned `captures` is empty (calling
|
|
431
|
+
* `.get` throws); when true, captures are read back and accessible. */
|
|
432
|
+
withCaptures?: boolean;
|
|
433
|
+
}
|
|
434
|
+
/** Common surface for both training and forward-only compiled runtimes. */
|
|
435
|
+
interface CompiledBase {
|
|
436
|
+
/** The GPUDevice this runtime is bound to. Pass to sibling compiles to
|
|
437
|
+
* share the device, or use directly for other GPU work. */
|
|
438
|
+
device: GPUDevice;
|
|
439
|
+
/** Param name -> the underlying GPUBuffer. Pass to a sibling compile via
|
|
440
|
+
* `sharedParams` to share without copies. */
|
|
441
|
+
params: Map<string, GPUBuffer>;
|
|
442
|
+
/** Shape of the graph's output (loss scalar `[]` for training; the user's
|
|
443
|
+
* returned tensor for forward-only compiles). */
|
|
444
|
+
outputShape: number[];
|
|
445
|
+
/** Upload parameter Float32Arrays to their GPU buffers. By default, requires
|
|
446
|
+
* *all* params to be present; throws on any unknown or missing key. Pass
|
|
447
|
+
* `{ partial: true }` to skip the missing-key check. */
|
|
448
|
+
uploadParams(params: Record<string, Float32Array>, opts?: UploadParamsOptions): void;
|
|
449
|
+
/** Read all parameters back as Float32Arrays — used for UI panels. */
|
|
450
|
+
downloadParams(): Promise<Record<string, Float32Array>>;
|
|
451
|
+
/** Free GPU resources. */
|
|
452
|
+
destroy(): void;
|
|
453
|
+
}
|
|
454
|
+
/** Run a dispatch and read back the full output tensor. Default returns the
|
|
455
|
+
* output as a `Float32Array`; with `{ withCaptures: true }` returns
|
|
456
|
+
* `{ output, captures }`. Same shape as `step()`'s overloads. */
|
|
457
|
+
interface RunFn {
|
|
458
|
+
(inputs: Record<string, Int32Array | Float32Array>): Promise<Float32Array>;
|
|
459
|
+
(inputs: Record<string, Int32Array | Float32Array>, opts: {
|
|
460
|
+
withCaptures: true;
|
|
461
|
+
}): Promise<RunResult>;
|
|
462
|
+
(inputs: Record<string, Int32Array | Float32Array>, opts: RunOptions): Promise<Float32Array | RunResult>;
|
|
463
|
+
}
|
|
464
|
+
interface CompiledRuntime extends CompiledBase {
|
|
465
|
+
/** Read all parameter gradients back. Mostly for verification / debugging. */
|
|
466
|
+
downloadParamGrads(): Promise<Record<string, Float32Array>>;
|
|
467
|
+
/**
|
|
468
|
+
* One full forward+backward step.
|
|
469
|
+
* 1. Uploads `inputs` (tokens, targets, masks) to input buffers.
|
|
470
|
+
* 2. Dispatches every kernel in order.
|
|
471
|
+
* 3. Reads back the loss scalar (and any registered captures, if requested).
|
|
472
|
+
* Default returns the loss as a JS number; with `{ withCaptures: true }`
|
|
473
|
+
* returns `{ loss, captures }`.
|
|
474
|
+
*/
|
|
475
|
+
step(inputs: Record<string, Int32Array | Float32Array>): Promise<number>;
|
|
476
|
+
step(inputs: Record<string, Int32Array | Float32Array>, opts: {
|
|
477
|
+
withCaptures: true;
|
|
478
|
+
}): Promise<StepResult>;
|
|
479
|
+
step(inputs: Record<string, Int32Array | Float32Array>, opts: RunOptions): Promise<number | StepResult>;
|
|
480
|
+
/** Same dispatch as step() but returns the full output Float32Array — for
|
|
481
|
+
* training graphs the output is a scalar loss, so step() is usually more
|
|
482
|
+
* convenient. Provided for parity with `compileForward`. */
|
|
483
|
+
run: RunFn;
|
|
484
|
+
/** Re-zero all optimizer state buffers (Adam's m/v) in place. Pair with
|
|
485
|
+
* `uploadInitialParams()` for a full training reset without recompile. */
|
|
486
|
+
resetOptimizerState(): void;
|
|
487
|
+
}
|
|
488
|
+
/** Forward-only compiled runtime — produced by `compileForward`. No optimizer,
|
|
489
|
+
* no backward. Returns the output tensor (not just a scalar) per `run()` call. */
|
|
490
|
+
interface CompiledForward extends CompiledBase {
|
|
491
|
+
run: RunFn;
|
|
492
|
+
}
|
|
493
|
+
interface RuntimeOpts {
|
|
494
|
+
/** Pre-acquired GPUDevice. If omitted, runtime requests its own. */
|
|
495
|
+
device?: GPUDevice;
|
|
496
|
+
/** External param buffers to bind in place of allocating fresh ones, keyed
|
|
497
|
+
* by param name. Used to share params between a training compile and a
|
|
498
|
+
* sibling forward-only compile (e.g., a B=1 inference graph). When a name
|
|
499
|
+
* is in this map, the runtime reuses the provided GPUBuffer; otherwise it
|
|
500
|
+
* allocates as usual. */
|
|
501
|
+
sharedParams?: Map<string, GPUBuffer>;
|
|
502
|
+
}
|
|
503
|
+
declare function createRuntime(plan: BufferPlan, kernels: KernelSpec[], lossBufferId: number, opts?: RuntimeOpts): Promise<CompiledRuntime>;
|
|
504
|
+
/** Same machinery as `createRuntime`, narrower public type: a forward-only
|
|
505
|
+
* graph exposes `run()` instead of `step()` (no optimizer state, no scalar-
|
|
506
|
+
* loss readback). The full runtime object is built once and projected by
|
|
507
|
+
* `compileForward` to the public shape. */
|
|
508
|
+
declare function createForwardRuntime(plan: BufferPlan, kernels: KernelSpec[], outputBufferId: number, opts?: RuntimeOpts): Promise<CompiledForward>;
|
|
509
|
+
|
|
510
|
+
/** How a parameter's initial values are produced.
|
|
511
|
+
* - `'randn'` — Gaussian, with `scale` (default 0.02). The common case for
|
|
512
|
+
* weight matrices and embeddings.
|
|
513
|
+
* - `'zeros'` — fill with 0. Common for biases and LayerNorm beta.
|
|
514
|
+
* - `'ones'` — fill with 1. Common for LayerNorm gain.
|
|
515
|
+
* - Custom function — receives total element count and shape, returns the
|
|
516
|
+
* Float32Array. Use for fan-in scaling or any non-standard scheme.
|
|
517
|
+
*/
|
|
518
|
+
type InitSpec = 'randn' | 'zeros' | 'ones' | ((size: number, shape: readonly number[]) => Float32Array);
|
|
519
|
+
interface ParamOptions {
|
|
520
|
+
dtype?: Dtype;
|
|
521
|
+
/** Init kind. Default: `'randn'`. */
|
|
522
|
+
init?: InitSpec;
|
|
523
|
+
/** Std dev for `'randn'`. Default 0.02. Ignored for non-randn init. */
|
|
524
|
+
scale?: number;
|
|
525
|
+
/** Whether AdamW (when `weightDecay > 0`) should apply decoupled weight
|
|
526
|
+
* decay to this param. Default: `true` for `'randn'` init (weight matrices,
|
|
527
|
+
* embeddings), `false` for `'zeros'` / `'ones'` (biases, LN gains). Override
|
|
528
|
+
* to force or skip. Replaces `adam.decayFilter` for the common case. */
|
|
529
|
+
decay?: boolean;
|
|
530
|
+
}
|
|
531
|
+
type InitFn = (size: number, shape: readonly number[]) => Float32Array;
|
|
532
|
+
declare abstract class Module {
|
|
533
|
+
/**
|
|
534
|
+
* Declare a learnable parameter at this module. Must be called from inside
|
|
535
|
+
* the constructor (typically as a field assignment). Returns a placeholder
|
|
536
|
+
* that gets replaced with a real Tensor at compile time.
|
|
537
|
+
*
|
|
538
|
+
* The parameter's name is auto-derived from its property path in the model
|
|
539
|
+
* tree (e.g. `layers.0.attn.W_q`). Init metadata travels with the param;
|
|
540
|
+
* call `compiled.uploadInitialParams()` to apply it after compile.
|
|
541
|
+
*/
|
|
542
|
+
protected param(shape: Shape, opts?: ParamOptions): Tensor;
|
|
543
|
+
}
|
|
544
|
+
interface MaterializedParams {
|
|
545
|
+
/** Map from auto-derived path (e.g. `layers.0.attn.W_q`) to its Tensor. */
|
|
546
|
+
tensors: Record<string, Tensor>;
|
|
547
|
+
/** Init function per param path. Used by `uploadInitialParams`. */
|
|
548
|
+
initFns: Record<string, InitFn>;
|
|
549
|
+
/** Whether this param should receive AdamW weight decay. Resolved at
|
|
550
|
+
* `param()` time from `ParamOptions.decay` (with init-based default). */
|
|
551
|
+
decayFlags: Record<string, boolean>;
|
|
552
|
+
}
|
|
553
|
+
/**
|
|
554
|
+
* Walk the module tree and replace every ParamSentinel with a real Tensor
|
|
555
|
+
* created via `paramInput(autoName, ...)`. Must be called inside an active
|
|
556
|
+
* trace context (paramInput appends to the current graph).
|
|
557
|
+
*
|
|
558
|
+
* Returns the param tensors keyed by path, plus init functions for use by
|
|
559
|
+
* `uploadInitialParams`.
|
|
560
|
+
*/
|
|
561
|
+
declare function materializeParams(root: Module): MaterializedParams;
|
|
562
|
+
|
|
563
|
+
/** Declares one input tensor of the model's forward function. The name is the
|
|
564
|
+
* key in the `inputs:` Record at compile time and the key on the `step()`/
|
|
565
|
+
* `run()` data object at runtime. */
|
|
566
|
+
interface InputDecl {
|
|
567
|
+
shape: Shape;
|
|
568
|
+
dtype?: Dtype;
|
|
569
|
+
}
|
|
570
|
+
/** Inputs declaration: a Record from input name to its shape/dtype. The name
|
|
571
|
+
* doubles as the key the forward fn destructures and the key the runtime
|
|
572
|
+
* expects in `step({...})` / `run({...})`. */
|
|
573
|
+
type InputDecls = Record<string, InputDecl>;
|
|
574
|
+
/** Maps an `InputDecls` Record to its forward-time tensor counterpart —
|
|
575
|
+
* same keys, each value is a Tensor. Used to type the forward function's
|
|
576
|
+
* `inputs` argument from the declared shape Record. */
|
|
577
|
+
type InputsTensors<I extends InputDecls> = {
|
|
578
|
+
[K in keyof I]: Tensor;
|
|
579
|
+
};
|
|
580
|
+
/** Forward function shape: takes the materialized model and a Record of
|
|
581
|
+
* named input tensors (matching the declared `inputs:` keys), returns the
|
|
582
|
+
* output tensor (loss for compileModule; logits/etc. for compileForward).
|
|
583
|
+
* The second generic flows from the inputs declaration so destructuring
|
|
584
|
+
* the input record stays typed. */
|
|
585
|
+
type ForwardFn<M extends Module, I extends InputDecls = InputDecls> = (m: M, inputs: InputsTensors<I>) => Tensor;
|
|
586
|
+
interface CompiledIR {
|
|
587
|
+
graph: GradResult['graph'];
|
|
588
|
+
paramGrads: GradResult['paramGrads'];
|
|
589
|
+
loss: Tensor;
|
|
590
|
+
plan: BufferPlan;
|
|
591
|
+
kernels: KernelSpec[];
|
|
592
|
+
}
|
|
593
|
+
/** Trace + autograd + buffer-plan + codegen, without touching WebGPU. */
|
|
594
|
+
declare function compileToIR(traceFn: () => Tensor): CompiledIR;
|
|
595
|
+
/** Full compile pipeline. Browser-only because it creates a GPUDevice. */
|
|
596
|
+
declare function compile(traceFn: () => Tensor, opts?: RuntimeOpts): Promise<CompiledRuntime & {
|
|
597
|
+
ir: CompiledIR;
|
|
598
|
+
}>;
|
|
599
|
+
interface CompileModuleOptions<I extends InputDecls = InputDecls> extends RuntimeOpts {
|
|
600
|
+
/** Per-step data inputs to the forward function, keyed by name. The forward
|
|
601
|
+
* fn destructures these out of its second argument; runtime calls to
|
|
602
|
+
* `step()` / `run()` pass typed arrays under the same keys. */
|
|
603
|
+
inputs?: I;
|
|
604
|
+
/** Adam hyperparameters. If omitted, no optimizer is appended (forward-only). */
|
|
605
|
+
adam?: AdamConfig;
|
|
606
|
+
}
|
|
607
|
+
interface CompileForwardOptions<I extends InputDecls = InputDecls> extends RuntimeOpts {
|
|
608
|
+
/** Per-step data inputs to the forward function, keyed by name. */
|
|
609
|
+
inputs?: I;
|
|
610
|
+
}
|
|
611
|
+
/** Forward-only compile options as taken by the `compileForward` *method* on
|
|
612
|
+
* a training runtime — no `device` (inherited) and no `sharedParams`
|
|
613
|
+
* (auto-supplied from the train graph's params). */
|
|
614
|
+
interface CompileForwardMethodOptions<I extends InputDecls = InputDecls> {
|
|
615
|
+
inputs?: I;
|
|
616
|
+
}
|
|
617
|
+
/** Returned by `compileModule`. Adds training-graph extras (auto-init, reset,
|
|
618
|
+
* sibling-graph compile) on top of the base runtime. */
|
|
619
|
+
interface CompiledModule<M extends Module> extends CompiledRuntime {
|
|
620
|
+
ir: CompiledIR;
|
|
621
|
+
/** Number of dispatchable kernels (excludes leaf no-ops). */
|
|
622
|
+
kernelCount: number;
|
|
623
|
+
/** Re-initialize all params from their declared init specs and zero the
|
|
624
|
+
* optimizer state. Use to start training over without recompiling. */
|
|
625
|
+
reset(): void;
|
|
626
|
+
/** Compile a sibling forward-only graph (e.g., a B=1 inference graph or a
|
|
627
|
+
* B=N held-out eval graph) that shares this runtime's device and param
|
|
628
|
+
* buffers. Pass the forward fn (typically distinct from your loss fn —
|
|
629
|
+
* it returns logits, not a scalar) and any shape changes via `inputs`.
|
|
630
|
+
* Auto-initialization is a no-op since params are shared. */
|
|
631
|
+
compileForward<I extends InputDecls>(forward: ForwardFn<M, I>, opts?: CompileForwardMethodOptions<I>): Promise<CompiledForwardModule>;
|
|
632
|
+
}
|
|
633
|
+
/** Returned by `compileForward` (and by the `compileForward` method). */
|
|
634
|
+
interface CompiledForwardModule extends CompiledForward {
|
|
635
|
+
ir: CompiledIR;
|
|
636
|
+
/** Number of dispatchable kernels (excludes leaf no-ops). */
|
|
637
|
+
kernelCount: number;
|
|
638
|
+
}
|
|
639
|
+
/**
|
|
640
|
+
* Compile a Module-based model. Pass a *factory* `() => new Model()`, not the
|
|
641
|
+
* model instance itself: compilation mutates the tree (every `ParamSentinel`
|
|
642
|
+
* field becomes a real `Tensor`), so the instance is consumed and shouldn't be
|
|
643
|
+
* referenced afterwards. Re-call the factory if you need a fresh tree.
|
|
644
|
+
*
|
|
645
|
+
* The forward function takes the materialized model and a Record of named
|
|
646
|
+
* input tensors, returns the loss tensor. Inputs are matched by name with the
|
|
647
|
+
* `inputs:` declaration:
|
|
648
|
+
*
|
|
649
|
+
* inputs: {
|
|
650
|
+
* tokens: { shape: [B, T], dtype: 'i32' },
|
|
651
|
+
* targets: { shape: [B, T], dtype: 'i32' },
|
|
652
|
+
* }
|
|
653
|
+
* forward: (m, { tokens, targets }) => …
|
|
654
|
+
*
|
|
655
|
+
* Walks the module tree to materialize params with auto-derived names, then
|
|
656
|
+
* runs trace → grad → adam → buffer plan → codegen → runtime. Initial
|
|
657
|
+
* parameter values are uploaded automatically before this function returns;
|
|
658
|
+
* call `reset()` later to re-randomize.
|
|
659
|
+
*
|
|
660
|
+
* If `opts.adam` is set, the runtime's `step()` automatically tracks an
|
|
661
|
+
* internal step count and injects the bias-corrected `lrt` scalar each call;
|
|
662
|
+
* users don't need to provide it themselves.
|
|
663
|
+
*/
|
|
664
|
+
declare function compileModule<M extends Module, I extends InputDecls = InputDecls>(modelFactory: () => M, forward: ForwardFn<M, I>, opts?: CompileModuleOptions<I>): Promise<CompiledModule<M>>;
|
|
665
|
+
/**
|
|
666
|
+
* Compile a Module-based model in forward-only mode (no autograd, no Adam).
|
|
667
|
+
* The forward function returns the output tensor (e.g., logits) instead of a
|
|
668
|
+
* scalar loss; runtime exposes `run(inputs)` returning the full output as a
|
|
669
|
+
* `Float32Array`.
|
|
670
|
+
*
|
|
671
|
+
* **Prefer the `compileForward` method on a training runtime** when both
|
|
672
|
+
* graphs use the same Module class — it auto-supplies `device` and
|
|
673
|
+
* `sharedParams`. This standalone form is for forward-only models with no
|
|
674
|
+
* training graph at all, or for sharing params across a different model.
|
|
675
|
+
*
|
|
676
|
+
* **Sharing params with a training compile.** Pass `opts.sharedParams =
|
|
677
|
+
* trainCompiled.params` to bind this graph's param buffers to an existing
|
|
678
|
+
* training runtime's GPU buffers — every train step is then immediately
|
|
679
|
+
* visible to `run()` calls here, no copies.
|
|
680
|
+
*
|
|
681
|
+
* Initial param values are uploaded automatically for params *not* covered
|
|
682
|
+
* by `sharedParams` (those are owned by the sibling compile).
|
|
683
|
+
*/
|
|
684
|
+
declare function compileForward<M extends Module, I extends InputDecls = InputDecls>(modelFactory: () => M, forward: ForwardFn<M, I>, opts?: CompileForwardOptions<I>): Promise<CompiledForwardModule>;
|
|
685
|
+
|
|
686
|
+
interface LinearOptions {
|
|
687
|
+
/** Include a bias term (default true). */
|
|
688
|
+
bias?: boolean;
|
|
689
|
+
}
|
|
690
|
+
declare class Linear extends Module {
|
|
691
|
+
readonly inDim: number;
|
|
692
|
+
readonly outDim: number;
|
|
693
|
+
W: Tensor;
|
|
694
|
+
b: Tensor | null;
|
|
695
|
+
constructor(inDim: number, outDim: number, opts?: LinearOptions);
|
|
696
|
+
fwd(x: Tensor): Tensor;
|
|
697
|
+
}
|
|
698
|
+
declare class LayerNorm extends Module {
|
|
699
|
+
readonly d: number;
|
|
700
|
+
readonly eps: number;
|
|
701
|
+
g: Tensor;
|
|
702
|
+
b: Tensor;
|
|
703
|
+
constructor(d: number, eps?: number);
|
|
704
|
+
fwd(x: Tensor): Tensor;
|
|
705
|
+
}
|
|
706
|
+
/** [..., T, D] → [..., H, T, D/H]. Folds the standard
|
|
707
|
+
* `transpose(reshape(x, [..., T, H, d]), [..., H, T, d])` pattern into one
|
|
708
|
+
* call. Last dim of `x` must divide evenly by `nHeads`. */
|
|
709
|
+
declare function splitHeads(x: Tensor, nHeads: number): Tensor;
|
|
710
|
+
/** Inverse of `splitHeads`: [..., H, T, d] → [..., T, H*d]. */
|
|
711
|
+
declare function mergeHeads(x: Tensor): Tensor;
|
|
712
|
+
/** Slice a captured tensor named `name` into one Float32Array per head, using
|
|
713
|
+
* the static shape registered at compile time. The leading axis is treated as
|
|
714
|
+
* heads (matching `splitHeads` layout at B=1); a leading singleton batch is
|
|
715
|
+
* stripped if present so callers can pass capture names directly. Throws if
|
|
716
|
+
* the capture isn't registered or wasn't read back this call. */
|
|
717
|
+
declare function unsplitHeads(captures: Captures, name: string): Float32Array[];
|
|
718
|
+
/** Per-position cross-entropy along the last (vocab) axis: returns
|
|
719
|
+
* `-log p(target)` at each position. `logits` is `[..., V]`; `targets` is
|
|
720
|
+
* `[...]` of i32; result is `[...]` (one rank less than logits). The user
|
|
721
|
+
* applies their own masking + reduction downstream — useful when only some
|
|
722
|
+
* positions contribute (e.g. result-digit masking) or for label smoothing. */
|
|
723
|
+
declare function crossEntropyLast(logits: Tensor, targets: Tensor): Tensor;
|
|
724
|
+
|
|
725
|
+
type nn_d_LayerNorm = LayerNorm;
|
|
726
|
+
declare const nn_d_LayerNorm: typeof LayerNorm;
|
|
727
|
+
type nn_d_Linear = Linear;
|
|
728
|
+
declare const nn_d_Linear: typeof Linear;
|
|
729
|
+
type nn_d_LinearOptions = LinearOptions;
|
|
730
|
+
declare const nn_d_crossEntropyLast: typeof crossEntropyLast;
|
|
731
|
+
declare const nn_d_mergeHeads: typeof mergeHeads;
|
|
732
|
+
declare const nn_d_splitHeads: typeof splitHeads;
|
|
733
|
+
declare const nn_d_unsplitHeads: typeof unsplitHeads;
|
|
734
|
+
declare namespace nn_d {
|
|
735
|
+
export { nn_d_LayerNorm as LayerNorm, nn_d_Linear as Linear, nn_d_crossEntropyLast as crossEntropyLast, nn_d_mergeHeads as mergeHeads, nn_d_splitHeads as splitHeads, nn_d_unsplitHeads as unsplitHeads };
|
|
736
|
+
export type { nn_d_LinearOptions as LinearOptions };
|
|
737
|
+
}
|
|
738
|
+
|
|
739
|
+
export { Captures, Module, ShapeError, add, appendAdam, appendGrad, arange, capture, compile, compileForward, compileModule, compileToIR, createForwardRuntime, createRuntime, div, embedding, emitKernels, exp, greater, less, log, logSoftmaxLast, materializeParams, matmul, matmulBatched, meanLast, mul, nn_d as nn, oneHot, paramInput, planBuffers, relu, reshape, rsqrt, sliceLastRange, softmaxCausalLast, sqrt, stateInput, sub, sumAll, sumLast, swapAxes, tensorInput, trace, traceInto, transpose, where, whereCausal };
|
|
740
|
+
export type { AdamConfig, AdamResult, BufferPlan, BufferSpec, CallSite, CompileForwardMethodOptions, CompileForwardOptions, CompileModuleOptions, CompiledForward, CompiledForwardModule, CompiledIR, CompiledModule, CompiledRuntime, Dtype, ForwardFn, GradResult, Graph, InitSpec, InputDecl, InputDecls, InputsTensors, KernelSpec, MaterializedParams, OpNode, ParamOptions, RunOptions, RunResult, RuntimeOpts, Shape, StepResult, Tensor, Writeback, WritebackDecl };
|