tensorgrad 0.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64) hide show
  1. package/LICENSE +21 -0
  2. package/README.md +121 -0
  3. package/SPEC.md +293 -0
  4. package/dist/adam.d.ts +31 -0
  5. package/dist/adam.d.ts.map +1 -0
  6. package/dist/adam.js +66 -0
  7. package/dist/adam.js.map +1 -0
  8. package/dist/buffers.d.ts +56 -0
  9. package/dist/buffers.d.ts.map +1 -0
  10. package/dist/buffers.js +114 -0
  11. package/dist/buffers.js.map +1 -0
  12. package/dist/codegen.d.ts +23 -0
  13. package/dist/codegen.d.ts.map +1 -0
  14. package/dist/codegen.js +709 -0
  15. package/dist/codegen.js.map +1 -0
  16. package/dist/compile.d.ts +53 -0
  17. package/dist/compile.d.ts.map +1 -0
  18. package/dist/compile.js +76 -0
  19. package/dist/compile.js.map +1 -0
  20. package/dist/grad.d.ts +8 -0
  21. package/dist/grad.d.ts.map +1 -0
  22. package/dist/grad.js +404 -0
  23. package/dist/grad.js.map +1 -0
  24. package/dist/index.d.ts +12 -0
  25. package/dist/index.d.ts.map +1 -0
  26. package/dist/index.js +37 -0
  27. package/dist/index.js.map +1 -0
  28. package/dist/ir.d.ts +204 -0
  29. package/dist/ir.d.ts.map +1 -0
  30. package/dist/ir.js +60 -0
  31. package/dist/ir.js.map +1 -0
  32. package/dist/module.d.ts +21 -0
  33. package/dist/module.d.ts.map +1 -0
  34. package/dist/module.js +113 -0
  35. package/dist/module.js.map +1 -0
  36. package/dist/ops.d.ts +35 -0
  37. package/dist/ops.d.ts.map +1 -0
  38. package/dist/ops.js +270 -0
  39. package/dist/ops.js.map +1 -0
  40. package/dist/runtime.d.ts +26 -0
  41. package/dist/runtime.d.ts.map +1 -0
  42. package/dist/runtime.js +190 -0
  43. package/dist/runtime.js.map +1 -0
  44. package/dist/shape.d.ts +24 -0
  45. package/dist/shape.d.ts.map +1 -0
  46. package/dist/shape.js +259 -0
  47. package/dist/shape.js.map +1 -0
  48. package/dist/trace.d.ts +8 -0
  49. package/dist/trace.d.ts.map +1 -0
  50. package/dist/trace.js +93 -0
  51. package/dist/trace.js.map +1 -0
  52. package/package.json +62 -0
  53. package/src/adam.ts +95 -0
  54. package/src/buffers.ts +173 -0
  55. package/src/codegen.ts +758 -0
  56. package/src/compile.ts +120 -0
  57. package/src/grad.ts +459 -0
  58. package/src/index.ts +40 -0
  59. package/src/ir.ts +197 -0
  60. package/src/module.ts +126 -0
  61. package/src/ops.ts +311 -0
  62. package/src/runtime.ts +232 -0
  63. package/src/shape.ts +263 -0
  64. package/src/trace.ts +101 -0
@@ -0,0 +1,12 @@
1
+ export type { Tensor, Shape, Dtype, OpNode, Graph, CallSite } from './ir.js';
2
+ export { ShapeError } from './shape.js';
3
+ export { trace, traceInto, paramInput, tensorInput, stateInput } from './trace.js';
4
+ export { add, sub, mul, div, sqrt, rsqrt, log, exp, relu, less, greater, where, meanLast, sumLast, reshape, transpose, matmul, matmulBatched, oneHot, arange, softmaxCausalLast, logSoftmaxLast, whereCausal, sliceLastRange, } from './ops.js';
5
+ export { appendGrad, type GradResult } from './grad.js';
6
+ export { appendAdam, type AdamConfig, type AdamResult } from './adam.js';
7
+ export { planBuffers, type BufferPlan, type BufferSpec, type Writeback, type WritebackDecl } from './buffers.js';
8
+ export { emitKernels, type KernelSpec } from './codegen.js';
9
+ export { createRuntime, type CompiledRuntime, type RuntimeOpts } from './runtime.js';
10
+ export { compile, compileToIR, compileModule, type CompiledIR, type CompileModuleOptions, type InputDecl } from './compile.js';
11
+ export { Module, materializeParams } from './module.js';
12
+ //# sourceMappingURL=index.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"index.d.ts","sourceRoot":"","sources":["../src/index.ts"],"names":[],"mappings":"AAKA,YAAY,EAAE,MAAM,EAAE,KAAK,EAAE,KAAK,EAAE,MAAM,EAAE,KAAK,EAAE,QAAQ,EAAE,MAAM,SAAS,CAAA;AAC5E,OAAO,EAAE,UAAU,EAAE,MAAM,YAAY,CAAA;AACvC,OAAO,EAAE,KAAK,EAAE,SAAS,EAAE,UAAU,EAAE,WAAW,EAAE,UAAU,EAAE,MAAM,YAAY,CAAA;AAClF,OAAO,EAEL,GAAG,EAAE,GAAG,EAAE,GAAG,EAAE,GAAG,EAElB,IAAI,EAAE,KAAK,EAAE,GAAG,EAAE,GAAG,EAAE,IAAI,EAE3B,IAAI,EAAE,OAAO,EAAE,KAAK,EAEpB,QAAQ,EAAE,OAAO,EAEjB,OAAO,EAAE,SAAS,EAElB,MAAM,EAAE,aAAa,EAErB,MAAM,EAAE,MAAM,EAEd,iBAAiB,EAAE,cAAc,EAAE,WAAW,EAE9C,cAAc,GACf,MAAM,UAAU,CAAA;AAMjB,OAAO,EAAE,UAAU,EAAE,KAAK,UAAU,EAAE,MAAM,WAAW,CAAA;AACvD,OAAO,EAAE,UAAU,EAAE,KAAK,UAAU,EAAE,KAAK,UAAU,EAAE,MAAM,WAAW,CAAA;AACxE,OAAO,EAAE,WAAW,EAAE,KAAK,UAAU,EAAE,KAAK,UAAU,EAAE,KAAK,SAAS,EAAE,KAAK,aAAa,EAAE,MAAM,cAAc,CAAA;AAChH,OAAO,EAAE,WAAW,EAAE,KAAK,UAAU,EAAE,MAAM,cAAc,CAAA;AAC3D,OAAO,EAAE,aAAa,EAAE,KAAK,eAAe,EAAE,KAAK,WAAW,EAAE,MAAM,cAAc,CAAA;AACpF,OAAO,EAAE,OAAO,EAAE,WAAW,EAAE,aAAa,EAAE,KAAK,UAAU,EAAE,KAAK,oBAAoB,EAAE,KAAK,SAAS,EAAE,MAAM,cAAc,CAAA;AAC9H,OAAO,EAAE,MAAM,EAAE,iBAAiB,EAAE,MAAM,aAAa,CAAA"}
package/dist/index.js ADDED
@@ -0,0 +1,37 @@
1
+ // Public surface. Bulb code imports from here.
2
+ //
3
+ // Exports: IR types, op surface, trace driver, autograd (appendGrad), Adam,
4
+ // buffer planning, codegen, runtime, compile(), and the Module layer.
5
+ export { ShapeError } from './shape.js';
6
+ export { trace, traceInto, paramInput, tensorInput, stateInput } from './trace.js';
7
+ export {
8
+ // Element-wise arithmetic. The binops accept Tensor or JS-number for the second arg.
9
+ add, sub, mul, div,
10
+ // Element-wise unary
11
+ sqrt, rsqrt, log, exp, relu,
12
+ // Comparisons + select
13
+ less, greater, where,
14
+ // Reductions over the last axis (other axes via reshape/transpose first)
15
+ meanLast, sumLast,
16
+ // Shape ops
17
+ reshape, transpose,
18
+ // Linear algebra
19
+ matmul, matmulBatched,
20
+ // Indexing / casting
21
+ oneHot, arange,
22
+ // ML primitives — fused for the transformer
23
+ softmaxCausalLast, logSoftmaxLast, whereCausal,
24
+ // Slicing
25
+ sliceLastRange, } from './ops.js';
26
+ // Note: addScalar/mulScalar/broadcastTo/sumToShape/constScalar/reluGrad/adam_update_*
27
+ // are autograd/optimizer building blocks. They live in ops.ts (so grad.ts and
28
+ // adam.ts can import them) but aren't part of the public API — `add`/`mul`
29
+ // overload on JS numbers, `where` subsumes the rest.
30
+ export { appendGrad } from './grad.js';
31
+ export { appendAdam } from './adam.js';
32
+ export { planBuffers } from './buffers.js';
33
+ export { emitKernels } from './codegen.js';
34
+ export { createRuntime } from './runtime.js';
35
+ export { compile, compileToIR, compileModule } from './compile.js';
36
+ export { Module, materializeParams } from './module.js';
37
+ //# sourceMappingURL=index.js.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"index.js","sourceRoot":"","sources":["../src/index.ts"],"names":[],"mappings":"AAAA,+CAA+C;AAC/C,EAAE;AACF,8EAA8E;AAC9E,6CAA6C;AAG7C,OAAO,EAAE,UAAU,EAAE,MAAM,YAAY,CAAA;AACvC,OAAO,EAAE,KAAK,EAAE,SAAS,EAAE,UAAU,EAAE,WAAW,EAAE,UAAU,EAAE,MAAM,YAAY,CAAA;AAClF,OAAO;AACL,qFAAqF;AACrF,GAAG,EAAE,GAAG,EAAE,GAAG,EAAE,GAAG;AAClB,qBAAqB;AACrB,IAAI,EAAE,KAAK,EAAE,GAAG,EAAE,GAAG,EAAE,IAAI;AAC3B,uBAAuB;AACvB,IAAI,EAAE,OAAO,EAAE,KAAK;AACpB,yEAAyE;AACzE,QAAQ,EAAE,OAAO;AACjB,YAAY;AACZ,OAAO,EAAE,SAAS;AAClB,iBAAiB;AACjB,MAAM,EAAE,aAAa;AACrB,qBAAqB;AACrB,MAAM,EAAE,MAAM;AACd,4CAA4C;AAC5C,iBAAiB,EAAE,cAAc,EAAE,WAAW;AAC9C,UAAU;AACV,cAAc,GACf,MAAM,UAAU,CAAA;AAEjB,sFAAsF;AACtF,8EAA8E;AAC9E,2EAA2E;AAC3E,qDAAqD;AACrD,OAAO,EAAE,UAAU,EAAmB,MAAM,WAAW,CAAA;AACvD,OAAO,EAAE,UAAU,EAAoC,MAAM,WAAW,CAAA;AACxE,OAAO,EAAE,WAAW,EAAwE,MAAM,cAAc,CAAA;AAChH,OAAO,EAAE,WAAW,EAAmB,MAAM,cAAc,CAAA;AAC3D,OAAO,EAAE,aAAa,EAA0C,MAAM,cAAc,CAAA;AACpF,OAAO,EAAE,OAAO,EAAE,WAAW,EAAE,aAAa,EAA8D,MAAM,cAAc,CAAA;AAC9H,OAAO,EAAE,MAAM,EAAE,iBAAiB,EAAE,MAAM,aAAa,CAAA"}
package/dist/ir.d.ts ADDED
@@ -0,0 +1,204 @@
1
+ export type Dtype = 'f32' | 'i32' | 'bool';
2
+ export type Shape = readonly number[];
3
+ export interface Tensor {
4
+ readonly id: number;
5
+ readonly shape: Shape;
6
+ readonly dtype: Dtype;
7
+ readonly source: number | null;
8
+ readonly site: CallSite | null;
9
+ }
10
+ export interface CallSite {
11
+ readonly opName: string;
12
+ readonly stack: string;
13
+ }
14
+ export type OpNode = {
15
+ kind: 'param_input';
16
+ out: number;
17
+ name: string;
18
+ } | {
19
+ kind: 'tensor_input';
20
+ out: number;
21
+ name: string;
22
+ } | {
23
+ kind: 'state_input';
24
+ out: number;
25
+ name: string;
26
+ initValue: number;
27
+ } | {
28
+ kind: 'add';
29
+ out: number;
30
+ a: number;
31
+ b: number;
32
+ } | {
33
+ kind: 'sub';
34
+ out: number;
35
+ a: number;
36
+ b: number;
37
+ } | {
38
+ kind: 'mul';
39
+ out: number;
40
+ a: number;
41
+ b: number;
42
+ } | {
43
+ kind: 'div';
44
+ out: number;
45
+ a: number;
46
+ b: number;
47
+ } | {
48
+ kind: 'mul_scalar';
49
+ out: number;
50
+ a: number;
51
+ scalar: number;
52
+ } | {
53
+ kind: 'add_scalar';
54
+ out: number;
55
+ a: number;
56
+ scalar: number;
57
+ } | {
58
+ kind: 'sqrt';
59
+ out: number;
60
+ a: number;
61
+ } | {
62
+ kind: 'rsqrt';
63
+ out: number;
64
+ a: number;
65
+ } | {
66
+ kind: 'log';
67
+ out: number;
68
+ a: number;
69
+ } | {
70
+ kind: 'exp';
71
+ out: number;
72
+ a: number;
73
+ } | {
74
+ kind: 'relu';
75
+ out: number;
76
+ a: number;
77
+ } | {
78
+ kind: 'mean_last';
79
+ out: number;
80
+ a: number;
81
+ } | {
82
+ kind: 'sum_last';
83
+ out: number;
84
+ a: number;
85
+ } | {
86
+ kind: 'reshape';
87
+ out: number;
88
+ a: number;
89
+ newShape: Shape;
90
+ } | {
91
+ kind: 'transpose';
92
+ out: number;
93
+ a: number;
94
+ perm: readonly number[];
95
+ } | {
96
+ kind: 'matmul';
97
+ out: number;
98
+ a: number;
99
+ b: number;
100
+ } | {
101
+ kind: 'matmul_batched';
102
+ out: number;
103
+ a: number;
104
+ b: number;
105
+ } | {
106
+ kind: 'one_hot';
107
+ out: number;
108
+ indices: number;
109
+ depth: number;
110
+ dtype: Dtype;
111
+ } | {
112
+ kind: 'arange';
113
+ out: number;
114
+ n: number;
115
+ dtype: Dtype;
116
+ } | {
117
+ kind: 'softmax_causal_last';
118
+ out: number;
119
+ a: number;
120
+ } | {
121
+ kind: 'log_softmax_last';
122
+ out: number;
123
+ a: number;
124
+ } | {
125
+ kind: 'where_causal';
126
+ out: number;
127
+ a: number;
128
+ fillValue: number;
129
+ } | {
130
+ kind: 'less';
131
+ out: number;
132
+ a: number;
133
+ b: number;
134
+ } | {
135
+ kind: 'greater';
136
+ out: number;
137
+ a: number;
138
+ b: number;
139
+ } | {
140
+ kind: 'where';
141
+ out: number;
142
+ cond: number;
143
+ a: number;
144
+ b: number;
145
+ } | {
146
+ kind: 'adam_update_m';
147
+ out: number;
148
+ m: number;
149
+ g: number;
150
+ b1: number;
151
+ } | {
152
+ kind: 'adam_update_v';
153
+ out: number;
154
+ v: number;
155
+ g: number;
156
+ b2: number;
157
+ } | {
158
+ kind: 'adam_update_p';
159
+ out: number;
160
+ p: number;
161
+ mNew: number;
162
+ vNew: number;
163
+ lrt: number;
164
+ eps: number;
165
+ } | {
166
+ kind: 'slice_last_range';
167
+ out: number;
168
+ a: number;
169
+ start: number;
170
+ end: number;
171
+ } | {
172
+ kind: 'broadcast_to';
173
+ out: number;
174
+ a: number;
175
+ targetShape: Shape;
176
+ } | {
177
+ kind: 'sum_to_shape';
178
+ out: number;
179
+ a: number;
180
+ targetShape: Shape;
181
+ } | {
182
+ kind: 'const_scalar';
183
+ out: number;
184
+ value: number;
185
+ dtype: Dtype;
186
+ } | {
187
+ kind: 'relu_grad';
188
+ out: number;
189
+ x: number;
190
+ dy: number;
191
+ };
192
+ export interface Graph {
193
+ readonly ops: OpNode[];
194
+ readonly tensors: Tensor[];
195
+ readonly outputs: number[];
196
+ }
197
+ export declare function makeGraph(): Graph;
198
+ export declare function addTensor(g: Graph, shape: Shape, dtype: Dtype, source: number | null, site: CallSite | null): Tensor;
199
+ export declare function addOp<K extends OpNode['kind']>(g: Graph, kind: K, shape: Shape, dtype: Dtype, site: CallSite | null, fields: Omit<Extract<OpNode, {
200
+ kind: K;
201
+ }>, 'kind' | 'out'>): Tensor;
202
+ export declare function captureSite(opName: string): CallSite;
203
+ export declare function formatSite(site: CallSite): string;
204
+ //# sourceMappingURL=ir.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"ir.d.ts","sourceRoot":"","sources":["../src/ir.ts"],"names":[],"mappings":"AAcA,MAAM,MAAM,KAAK,GAAG,KAAK,GAAG,KAAK,GAAG,MAAM,CAAA;AAC1C,MAAM,MAAM,KAAK,GAAG,SAAS,MAAM,EAAE,CAAA;AAIrC,MAAM,WAAW,MAAM;IACrB,QAAQ,CAAC,EAAE,EAAE,MAAM,CAAA;IACnB,QAAQ,CAAC,KAAK,EAAE,KAAK,CAAA;IACrB,QAAQ,CAAC,KAAK,EAAE,KAAK,CAAA;IAErB,QAAQ,CAAC,MAAM,EAAE,MAAM,GAAG,IAAI,CAAA;IAG9B,QAAQ,CAAC,IAAI,EAAE,QAAQ,GAAG,IAAI,CAAA;CAC/B;AAED,MAAM,WAAW,QAAQ;IACvB,QAAQ,CAAC,MAAM,EAAE,MAAM,CAAA;IAEvB,QAAQ,CAAC,KAAK,EAAE,MAAM,CAAA;CACvB;AAQD,MAAM,MAAM,MAAM,GAGd;IAAE,IAAI,EAAE,aAAa,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,IAAI,EAAE,MAAM,CAAA;CAAE,GAElD;IAAE,IAAI,EAAE,cAAc,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,IAAI,EAAE,MAAM,CAAA;CAAE,GAInD;IAAE,IAAI,EAAE,aAAa,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,IAAI,EAAE,MAAM,CAAC;IAAC,SAAS,EAAE,MAAM,CAAA;CAAE,GAGrE;IAAE,IAAI,EAAE,KAAK,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAClD;IAAE,IAAI,EAAE,KAAK,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAClD;IAAE,IAAI,EAAE,KAAK,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAClD;IAAE,IAAI,EAAE,KAAK,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAClD;IAAE,IAAI,EAAE,YAAY,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,MAAM,EAAE,MAAM,CAAA;CAAE,GAC9D;IAAE,IAAI,EAAE,YAAY,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,MAAM,EAAE,MAAM,CAAA;CAAE,GAG9D;IAAE,IAAI,EAAE,MAAM,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GACxC;IAAE,IAAI,EAAE,OAAO,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GACzC;IAAE,IAAI,EAAE,KAAK,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GACvC;IAAE,IAAI,EAAE,KAAK,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GACvC;IAAE,IAAI,EAAE,MAAM,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAGxC;IAAE,IAAI,EAAE,WAAW,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CA
AA;CAAE,GAC7C;IAAE,IAAI,EAAE,UAAU,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAG5C;IAAE,IAAI,EAAE,SAAS,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,QAAQ,EAAE,KAAK,CAAA;CAAE,GAC5D;IAAE,IAAI,EAAE,WAAW,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,IAAI,EAAE,SAAS,MAAM,EAAE,CAAA;CAAE,GAMtE;IAAE,IAAI,EAAE,QAAQ,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAErD;IAAE,IAAI,EAAE,gBAAgB,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAG7D;IAAE,IAAI,EAAE,SAAS,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,OAAO,EAAE,MAAM,CAAC;IAAC,KAAK,EAAE,MAAM,CAAC;IAAC,KAAK,EAAE,KAAK,CAAA;CAAE,GAC9E;IAAE,IAAI,EAAE,QAAQ,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,KAAK,EAAE,KAAK,CAAA;CAAE,GAGxD;IAAE,IAAI,EAAE,qBAAqB,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GACvD;IAAE,IAAI,EAAE,kBAAkB,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAIpD;IAAE,IAAI,EAAE,cAAc,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,SAAS,EAAE,MAAM,CAAA;CAAE,GAKnE;IAAE,IAAI,EAAE,MAAM,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GACnD;IAAE,IAAI,EAAE,SAAS,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAGtD;IAAE,IAAI,EAAE,OAAO,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,IAAI,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAMlE;IAAE,IAAI,EAAE,eAAe,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,EAAE,EAAE,MAAM,CAAA;CAAE,GACxE;IAAE,IAAI,EAAE,eAAe,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,EAAE,EAAE,MAAM,CAAA;CAAE,GAKxE;IAAE,IAAI,EAAE,eAAe,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,IAAI,EAAE,MAAM,CAAC;IAAC,IAAI,EAAE,MAAM,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,GAAG,EAAE,MAAM,CAAA;CAAE,GAMvG;IAAE,IAAI,EAAE,kBAAkB,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,KAAK,EAAE,MA
AM,CAAC;IAAC,GAAG,EAAE,MAAM,CAAA;CAAE,GAGhF;IAAE,IAAI,EAAE,cAAc,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,WAAW,EAAE,KAAK,CAAA;CAAE,GAGpE;IAAE,IAAI,EAAE,cAAc,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,WAAW,EAAE,KAAK,CAAA;CAAE,GAEpE;IAAE,IAAI,EAAE,cAAc,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,KAAK,EAAE,MAAM,CAAC;IAAC,KAAK,EAAE,KAAK,CAAA;CAAE,GAElE;IAAE,IAAI,EAAE,WAAW,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,EAAE,EAAE,MAAM,CAAA;CAAE,CAAA;AAI7D,MAAM,WAAW,KAAK;IACpB,QAAQ,CAAC,GAAG,EAAE,MAAM,EAAE,CAAA;IACtB,QAAQ,CAAC,OAAO,EAAE,MAAM,EAAE,CAAA;IAG1B,QAAQ,CAAC,OAAO,EAAE,MAAM,EAAE,CAAA;CAC3B;AAED,wBAAgB,SAAS,IAAI,KAAK,CAEjC;AAGD,wBAAgB,SAAS,CAAC,CAAC,EAAE,KAAK,EAAE,KAAK,EAAE,KAAK,EAAE,KAAK,EAAE,KAAK,EAAE,MAAM,EAAE,MAAM,GAAG,IAAI,EAAE,IAAI,EAAE,QAAQ,GAAG,IAAI,GAAG,MAAM,CAKpH;AAMD,wBAAgB,KAAK,CAAC,CAAC,SAAS,MAAM,CAAC,MAAM,CAAC,EAC5C,CAAC,EAAE,KAAK,EACR,IAAI,EAAE,CAAC,EACP,KAAK,EAAE,KAAK,EACZ,KAAK,EAAE,KAAK,EACZ,IAAI,EAAE,QAAQ,GAAG,IAAI,EACrB,MAAM,EAAE,IAAI,CAAC,OAAO,CAAC,MAAM,EAAE;IAAE,IAAI,EAAE,CAAC,CAAA;CAAE,CAAC,EAAE,MAAM,GAAG,KAAK,CAAC,GACzD,MAAM,CAMR;AAID,wBAAgB,WAAW,CAAC,MAAM,EAAE,MAAM,GAAG,QAAQ,CAIpD;AAID,wBAAgB,UAAU,CAAC,IAAI,EAAE,QAAQ,GAAG,MAAM,CAYjD"}
package/dist/ir.js ADDED
@@ -0,0 +1,60 @@
1
+ // Intermediate representation for tensor computations.
2
+ //
3
+ // A `Graph` is a flat array of `OpNode`s in topological (= construction) order.
4
+ // A `Tensor` is an opaque handle: shape + dtype + a pointer back to the OpNode
5
+ // that produced it (or `null` for graph leaves — params and external inputs).
6
+ //
7
+ // This is the data structure everything else operates on:
8
+ // - tracing builds it (src/trace.ts)
9
+ // - autograd walks it in reverse to add backward nodes (src/grad.ts)
10
+ // - codegen reads it to emit WGSL kernels and a dispatch plan (src/codegen.ts)
11
+ //
12
+ // Design intent: keep this file boring. No tracing logic, no shape inference,
13
+ // no codegen — those live in their own modules and consume `Graph` / `OpNode`.
14
/**
 * Create an empty computation graph.
 *
 * @returns {{ ops: Array, tensors: Array, outputs: Array }} A graph whose op,
 *   tensor and output lists are fresh (unshared) empty arrays.
 */
export function makeGraph() {
  const ops = [];
  const tensors = [];
  const outputs = [];
  return { ops, tensors, outputs };
}
17
/**
 * Internal: append a new tensor record to `g.tensors`.
 *
 * The tensor's id is its index in `g.tensors`, so ids are dense and assigned in
 * creation order.
 *
 * @param g - Graph to register the tensor in (its `tensors` array is mutated).
 * @param shape - Tensor shape.
 * @param dtype - Element dtype.
 * @param source - Index of the producing op in `g.ops`, or null for leaves.
 * @param site - Call site that created the tensor, or null.
 * @returns The newly created tensor record (also stored in `g.tensors`).
 */
export function addTensor(g, shape, dtype, source, site) {
  const tensor = { id: g.tensors.length, shape, dtype, source, site };
  g.tensors.push(tensor);
  return tensor;
}
24
/**
 * Internal: append an op node plus the tensor it produces, and return that tensor.
 *
 * In the TypeScript source this is generic over the op kind:
 * `Extract<OpNode, { kind: K }>` picks the matching union variant and `Omit`
 * removes `kind` and `out`, which this function fills in itself. Callers
 * therefore pass only the variant-specific fields and need no `as any` casts.
 *
 * @param g - Graph to append to (both `g.tensors` and `g.ops` are mutated).
 * @param kind - Op kind tag.
 * @param shape - Shape of the produced tensor.
 * @param dtype - Dtype of the produced tensor.
 * @param site - Call site recorded on the produced tensor, or null.
 * @param fields - Remaining op-specific fields, spread after `kind`/`out`.
 * @returns The produced tensor.
 */
export function addOp(g, kind, shape, dtype, site, fields) {
  // The op will be pushed at the current end of g.ops; the produced tensor
  // records that index as its `source`.
  const source = g.ops.length;
  const produced = addTensor(g, shape, dtype, source, site);
  g.ops.push({ kind, out: produced.id, ...fields });
  return produced;
}
35
/**
 * Record where an op was called from, for later use in error messages.
 *
 * The raw stack string is stored as-is. No frames are removed here; trimming
 * and filtering of library frames is left to formatSite. In V8, reading
 * `.stack` is what triggers serialization of the trace, so that cost is paid
 * here. Only the parsing is deferred.
 *
 * @param {string} opName - Name of the op being traced.
 * @returns {{ opName: string, stack: string }} The call site; `stack` is ''
 *   when the engine provides no stack.
 */
export function captureSite(opName) {
  const probe = new Error();
  return { opName, stack: probe.stack ?? '' };
}
42
// Path fragments that identify Tensorgrad's own modules in a stack frame.
// Two locations are covered:
// - the TypeScript sources (src/), for local development;
// - the compiled build (dist/), which is what an installed package executes.
// Each location is listed with both POSIX and Windows path separators.
const LIBRARY_FRAME_MARKERS = [
  '/tensorgrad/src/',
  '/tensorgrad/dist/',
  '\\tensorgrad\\src\\',
  '\\tensorgrad\\dist\\',
];
// Maximum number of user frames included in a formatted call site.
const MAX_USER_FRAMES = 3;

/**
 * Format a CallSite for inclusion in a thrown error.
 *
 * Processing steps:
 * - Drop the leading "Error" header line.
 * - Drop blank lines.
 * - Drop every frame that points into Tensorgrad itself (src/ or dist/).
 *
 * The user sees their own code first, capped at MAX_USER_FRAMES frames.
 *
 * @param {{ opName: string, stack: string }} site - Site from captureSite.
 * @returns {string} `[opName]` followed by up to three user frames, one per
 *   line. If no user frame survives, it returns `[opName] (no user frame found)`.
 */
export function formatSite(site) {
  const lines = site.stack.split('\n');
  const userFrames = [];
  // The first line is the "Error" header line (as in V8 stacks); skip it.
  for (const line of lines.slice(1)) {
    const frame = line.trim();
    // A trailing newline in the stack yields an empty line, which is not a frame.
    if (frame === '') continue;
    // Frames inside Tensorgrad are skipped. This includes dist/, which is what
    // runs when the package is installed from npm.
    if (LIBRARY_FRAME_MARKERS.some((marker) => line.includes(marker))) continue;
    userFrames.push(frame);
    if (userFrames.length >= MAX_USER_FRAMES) break;
  }
  if (userFrames.length === 0) return `[${site.opName}] (no user frame found)`;
  return `[${site.opName}]\n ${userFrames.join('\n ')}`;
}
60
+ //# sourceMappingURL=ir.js.map
package/dist/ir.js.map ADDED
@@ -0,0 +1 @@
1
+ {"version":3,"file":"ir.js","sourceRoot":"","sources":["../src/ir.ts"],"names":[],"mappings":"AAAA,uDAAuD;AACvD,EAAE;AACF,gFAAgF;AAChF,+EAA+E;AAC/E,8EAA8E;AAC9E,EAAE;AACF,0DAA0D;AAC1D,uCAAuC;AACvC,8EAA8E;AAC9E,wFAAwF;AACxF,EAAE;AACF,8EAA8E;AAC9E,+EAA+E;AAmI/E,MAAM,UAAU,SAAS;IACvB,OAAO,EAAE,GAAG,EAAE,EAAE,EAAE,OAAO,EAAE,EAAE,EAAE,OAAO,EAAE,EAAE,EAAE,CAAA;AAC9C,CAAC;AAED,oEAAoE;AACpE,MAAM,UAAU,SAAS,CAAC,CAAQ,EAAE,KAAY,EAAE,KAAY,EAAE,MAAqB,EAAE,IAAqB;IAC1G,MAAM,EAAE,GAAG,CAAC,CAAC,OAAO,CAAC,MAAM,CAAA;IAC3B,MAAM,CAAC,GAAW,EAAE,EAAE,EAAE,KAAK,EAAE,KAAK,EAAE,MAAM,EAAE,IAAI,EAAE,CAAA;IACpD,CAAC,CAAC,OAAO,CAAC,IAAI,CAAC,CAAC,CAAC,CAAA;IACjB,OAAO,CAAC,CAAA;AACV,CAAC;AAED,kFAAkF;AAClF,0EAA0E;AAC1E,+EAA+E;AAC/E,kFAAkF;AAClF,MAAM,UAAU,KAAK,CACnB,CAAQ,EACR,IAAO,EACP,KAAY,EACZ,KAAY,EACZ,IAAqB,EACrB,MAA0D;IAE1D,MAAM,OAAO,GAAG,CAAC,CAAC,GAAG,CAAC,MAAM,CAAA;IAC5B,MAAM,GAAG,GAAG,SAAS,CAAC,CAAC,EAAE,KAAK,EAAE,KAAK,EAAE,OAAO,EAAE,IAAI,CAAC,CAAA;IACrD,MAAM,IAAI,GAAG,EAAE,IAAI,EAAE,GAAG,EAAE,GAAG,CAAC,EAAE,EAAE,GAAG,MAAM,EAAkC,CAAA;IAC7E,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,IAAI,CAAC,CAAA;IAChB,OAAO,GAAG,CAAA;AACZ,CAAC;AAED,0EAA0E;AAC1E,iFAAiF;AACjF,MAAM,UAAU,WAAW,CAAC,MAAc;IACxC,+EAA+E;IAC/E,MAAM,KAAK,GAAG,CAAC,IAAI,KAAK,EAAE,CAAC,CAAC,KAAK,IAAI,EAAE,CAAA;IACvC,OAAO,EAAE,MAAM,EAAE,KAAK,EAAE,CAAA;AAC1B,CAAC;AAED,8EAA8E;AAC9E,2DAA2D;AAC3D,MAAM,UAAU,UAAU,CAAC,IAAc;IACvC,MAAM,KAAK,GAAG,IAAI,CAAC,KAAK,CAAC,KAAK,CAAC,IAAI,CAAC,CAAA;IACpC,2EAA2E;IAC3E,iEAAiE;IACjE,MAAM,UAAU,GAAa,EAAE,CAAA;IAC/B,KAAK,MAAM,IAAI,IAAI,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC;QAClC,IAAI,IAAI,CAAC,QAAQ,CAAC,kBAAkB,CAAC,IAAI,IAAI,CAAC,QAAQ,CAAC,qBAAqB,CAAC;YAAE,SAAQ;QACvF,UAAU,CAAC,IAAI,CAAC,IAAI,CAAC,IAAI,EAAE,CAAC,CAAA;QAC5B,IAAI,UAAU,CAAC,MAAM,IAAI,CAAC;YAAE,MAAK;IACnC,CAAC;IACD,IAAI,UAAU,CAAC,MAAM,KAAK,CAAC;QAAE,OAAO,IAAI,IAAI,CAAC,MAAM,yBAAyB,CAAA;IAC5E,OAAO,IAAI,IAAI,CAAC,MAAM,QAAQ,UAAU,CAAC,IAAI,CAAC,MAAM,CAAC,EAAE,CAAA;AACzD,CAAC"}
@@ -0,0 +1,21 @@
1
+ import type { Tensor, Shape, Dtype } from './ir.js';
2
+ export declare abstract class Module {
3
+ /**
4
+ * Declare a learnable parameter at this module. Must be called from inside
5
+ * the constructor (typically as a field assignment). Returns a placeholder
6
+ * that gets replaced with a real Tensor at compile time.
7
+ *
8
+ * The parameter's name is auto-derived from its property path in the model
9
+ * tree (e.g. `layers.0.attn.W_q`).
10
+ */
11
+ protected param(shape: Shape, dtype?: Dtype): Tensor;
12
+ }
13
+ /**
14
+ * Walk the module tree and replace every ParamSentinel with a real Tensor
15
+ * created via `paramInput(autoName, ...)`. Must be called inside an active
16
+ * trace context (paramInput appends to the current graph).
17
+ *
18
+ * Returns a flat record of `{ path: tensor }` for every materialized param.
19
+ */
20
+ export declare function materializeParams(root: Module): Record<string, Tensor>;
21
+ //# sourceMappingURL=module.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"module.d.ts","sourceRoot":"","sources":["../src/module.ts"],"names":[],"mappings":"AA2BA,OAAO,KAAK,EAAE,MAAM,EAAE,KAAK,EAAE,KAAK,EAAE,MAAM,SAAS,CAAA;AAoBnD,8BAAsB,MAAM;IAC1B;;;;;;;OAOG;IACH,SAAS,CAAC,KAAK,CAAC,KAAK,EAAE,KAAK,EAAE,KAAK,GAAE,KAAa,GAAG,MAAM;CAI5D;AAMD;;;;;;GAMG;AACH,wBAAgB,iBAAiB,CAAC,IAAI,EAAE,MAAM,GAAG,MAAM,CAAC,MAAM,EAAE,MAAM,CAAC,CAUtE"}
package/dist/module.js ADDED
@@ -0,0 +1,113 @@
1
+ // Module abstraction — a Domeleon-style component layer for parameter trees.
2
+ //
3
+ // User code defines a model as nested classes:
4
+ //
5
+ // class Linear extends Module {
6
+ // W: Tensor; b: Tensor
7
+ // constructor(inDim: number, outDim: number) {
8
+ // super()
9
+ // this.W = this.param([inDim, outDim])
10
+ // this.b = this.param([outDim])
11
+ // }
12
+ // }
13
+ // class Block extends Module {
14
+ // attn = new Attention(D)
15
+ // mlp = new MLP(D, 4 * D)
16
+ // }
17
+ // class Model extends Module {
18
+ // embed = new Linear(VOCAB, D)
19
+ // layers = range(N).map(() => new Block())
20
+ // }
21
+ //
22
+ // The param tree is discovered automatically at compile time by walking
23
+ // enumerable instance properties. Each parameter gets a name auto-derived
24
+ // from its path (`layers.0.attn.W_q`); names are used for upload/download
25
+ // and writeback wiring. Forward functions are pure and stateless — they
26
+ // take the materialized model and inputs, return a Tensor.
27
+ import { paramInput } from './trace.js';
28
+ // ============================================================================
29
+ // Internals: param sentinel
30
+ // ============================================================================
31
+ //
32
+ // `this.param(shape)` returns a placeholder that's replaced by a real Tensor
33
+ // during `materializeParams`. We type-cheat by declaring the return type as
34
+ // `Tensor` so user code can write `this.W` and have TS happy; the cheat is
35
+ // only valid post-materialization (which is always before forward runs).
36
/**
 * Placeholder returned by `Module.param(...)`. It carries only the declared
 * shape and dtype. `materializeParams` later replaces it with a real Tensor.
 */
class ParamSentinel {
  shape;
  dtype;
  constructor(shape, dtype) {
    Object.assign(this, { shape, dtype });
  }
}
44
+ // ============================================================================
45
+ // Module base class
46
+ // ============================================================================
47
export class Module {
  /**
   * Declare a learnable parameter on this module. Call it from the
   * constructor, typically as a field assignment (`this.W = this.param([a, b])`).
   *
   * The value returned is a placeholder. It is swapped for a real Tensor when
   * the module tree is materialized at compile time. The parameter's name comes
   * from its property path in the model tree (e.g. `layers.0.attn.W_q`).
   *
   * @param shape - Parameter shape.
   * @param dtype - Element dtype; defaults to 'f32'.
   * @returns A placeholder, typed as Tensor in the declarations.
   */
  param(shape, dtype = 'f32') {
    // The declared return type is Tensor; the placeholder only becomes one
    // once materialization replaces it.
    const placeholder = new ParamSentinel(shape, dtype);
    return placeholder;
  }
}
61
+ // ============================================================================
62
+ // Tree walking
63
+ // ============================================================================
64
/**
 * Walk the module tree and replace every ParamSentinel with a real Tensor
 * created via `paramInput(autoName, ...)`. Must be called inside an active
 * trace context (paramInput appends to the current graph).
 *
 * Shared sentinels: one sentinel object may be reachable from several fields,
 * e.g. a single `this.param(...)` result assigned to two fields for weight
 * tying. Such a sentinel is materialized exactly once. Every owner receives
 * the same Tensor, named after the first path at which it was encountered.
 *
 * The tree is mutated in place: sentinel fields are overwritten with Tensors.
 * NOTE(review): as a consequence, a second call on the same tree finds no
 * sentinels and returns {}; confirm callers never re-materialize a model.
 *
 * @param {Module} root - Root of the module tree.
 * @returns {Record<string, Tensor>} A flat record of `{ path: tensor }` with one
 *   entry per materialized param.
 */
export function materializeParams(root) {
  const out = {};
  // Sentinel -> Tensor. Without this cache, a shared sentinel would produce
  // two independent paramInputs and silently break the tie.
  const materialized = new Map();
  visit(root, '', (path, val, owner, key) => {
    if (!(val instanceof ParamSentinel)) return;
    let t = materialized.get(val);
    if (t === undefined) {
      t = paramInput(path, val.shape, val.dtype);
      materialized.set(val, t);
      out[path] = t;
    }
    owner[key] = t;
  });
  return out;
}
82
/**
 * Recursively walk a parameter container, handing every leaf to `visitor`.
 *
 * Only two kinds of node are descended into:
 * - Module instances: own enumerable properties, in Object.keys order.
 * - Arrays: via forEach, so holes in sparse arrays are skipped.
 *
 * Any other value is ignored here: null, primitives, sentinels, tensors and
 * plain objects. When such values sit inside a container, visitChild passes
 * them to the visitor as leaves.
 *
 * @param node - Value to walk.
 * @param path - Dotted path of `node` ('' for the root).
 * @param visitor - Called as visitor(path, value, owner, key) for each leaf.
 */
function visit(node, path, visitor) {
  // Append one segment to the dotted path; the root contributes no prefix.
  const extend = (segment) => (path ? `${path}.${segment}` : String(segment));
  if (node instanceof Module) {
    Object.keys(node).forEach((key) => {
      visitChild(node[key], extend(key), node, key, visitor);
    });
  } else if (Array.isArray(node)) {
    node.forEach((item, index) => {
      visitChild(item, extend(index), node, index, visitor);
    });
  }
}
105
/**
 * Dispatch one child of a container. Nested containers (Modules and arrays)
 * are walked further; any other value is reported to `visitor` as a leaf,
 * together with its owner and key so the visitor can overwrite it in place.
 */
function visitChild(child, path, owner, key, visitor) {
  const isContainer = child instanceof Module || Array.isArray(child);
  if (!isContainer) {
    visitor(path, child, owner, key);
    return;
  }
  visit(child, path, visitor);
}
113
+ //# sourceMappingURL=module.js.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"module.js","sourceRoot":"","sources":["../src/module.ts"],"names":[],"mappings":"AAAA,6EAA6E;AAC7E,EAAE;AACF,+CAA+C;AAC/C,EAAE;AACF,kCAAkC;AAClC,2BAA2B;AAC3B,mDAAmD;AACnD,gBAAgB;AAChB,6CAA6C;AAC7C,sCAAsC;AACtC,QAAQ;AACR,MAAM;AACN,iCAAiC;AACjC,8BAA8B;AAC9B,+BAA+B;AAC/B,MAAM;AACN,iCAAiC;AACjC,mCAAmC;AACnC,+CAA+C;AAC/C,MAAM;AACN,EAAE;AACF,wEAAwE;AACxE,0EAA0E;AAC1E,0EAA0E;AAC1E,wEAAwE;AACxE,2DAA2D;AAG3D,OAAO,EAAE,UAAU,EAAE,MAAM,YAAY,CAAA;AAEvC,+EAA+E;AAC/E,4BAA4B;AAC5B,+EAA+E;AAC/E,EAAE;AACF,6EAA6E;AAC7E,4EAA4E;AAC5E,2EAA2E;AAC3E,yEAAyE;AAEzE,MAAM,aAAa;IACW;IAA8B;IAA1D,YAA4B,KAAY,EAAkB,KAAY;QAA1C,UAAK,GAAL,KAAK,CAAO;QAAkB,UAAK,GAAL,KAAK,CAAO;IAAG,CAAC;CAC3E;AAED,+EAA+E;AAC/E,oBAAoB;AACpB,+EAA+E;AAE/E,MAAM,OAAgB,MAAM;IAC1B;;;;;;;OAOG;IACO,KAAK,CAAC,KAAY,EAAE,QAAe,KAAK;QAChD,wEAAwE;QACxE,OAAO,IAAI,aAAa,CAAC,KAAK,EAAE,KAAK,CAAsB,CAAA;IAC7D,CAAC;CACF;AAED,+EAA+E;AAC/E,eAAe;AACf,+EAA+E;AAE/E;;;;;;GAMG;AACH,MAAM,UAAU,iBAAiB,CAAC,IAAY;IAC5C,MAAM,GAAG,GAA2B,EAAE,CAAA;IACtC,KAAK,CAAC,IAAI,EAAE,EAAE,EAAE,CAAC,IAAI,EAAE,GAAG,EAAE,KAAK,EAAE,GAAG,EAAE,EAAE;QACxC,IAAI,GAAG,YAAY,aAAa,EAAE,CAAC;YACjC,MAAM,CAAC,GAAG,UAAU,CAAC,IAAI,EAAE,GAAG,CAAC,KAAK,EAAE,GAAG,CAAC,KAAK,CAAC,CAC/C;YAAC,KAAa,CAAC,GAAG,CAAC,GAAG,CAAC,CAAA;YACxB,GAAG,CAAC,IAAI,CAAC,GAAG,CAAC,CAAA;QACf,CAAC;IACH,CAAC,CAAC,CAAA;IACF,OAAO,GAAG,CAAA;AACZ,CAAC;AAaD,SAAS,KAAK,CAAC,IAAa,EAAE,IAAY,EAAE,OAAgB;IAC1D,IAAI,IAAI,KAAK,IAAI,IAAI,IAAI,KAAK,SAAS;QAAE,OAAM;IAC/C,IAAI,OAAO,IAAI,KAAK,QAAQ;QAAE,OAAM;IAEpC,IAAI,IAAI,YAAY,MAAM,EAAE,CAAC;QAC3B,KAAK,MAAM,GAAG,IAAI,MAAM,CAAC,IAAI,CAAC,IAAc,CAAC,EAAE,CAAC;YAC9C,MAAM,KAAK,GAAI,IAAY,CAAC,GAAG,CAAC,CAAA;YAChC,MAAM,SAAS,GAAG,IAAI,CAAC,CAAC,CAAC,GAAG,IAAI,IAAI,GAAG,EAAE,CAAC,CAAC,CAAC,GAAG,CAAA;YAC/C,UAAU,CAAC,KAAK,EAAE,SAAS,EAAE,IAAI,EAAE,GAAG,EAAE,OAAO,CAAC,CAAA;QAClD,CAAC;QACD,OAAM;IACR,CAAC;IACD,IAAI,KAAK,CAAC,OAAO,CAAC,IAAI,CAAC,EAAE,CAAC;QACxB,IAAI,CAAC,OAAO,CAAC,CAAC,IAAI,EAAE,CAAC,EAAE,EAAE;YACvB,MAAM,SAAS,GAAG,IAAI,CAAC,CAAC,CAAC,GAAG,IAAI,IAAI,CAAC,EAAE,C
AAC,CAAC,CAAC,MAAM,CAAC,CAAC,CAAC,CAAA;YACnD,UAAU,CAAC,IAAI,EAAE,SAAS,EAAE,IAAyB,EAAE,CAAC,EAAE,OAAO,CAAC,CAAA;QACpE,CAAC,CAAC,CAAA;QACF,OAAM;IACR,CAAC;IACD,2EAA2E;IAC3E,uBAAuB;AACzB,CAAC;AAED,SAAS,UAAU,CAAC,KAAc,EAAE,IAAY,EAAE,KAAa,EAAE,GAAoB,EAAE,OAAgB;IACrG,IAAI,KAAK,YAAY,MAAM,IAAI,KAAK,CAAC,OAAO,CAAC,KAAK,CAAC,EAAE,CAAC;QACpD,KAAK,CAAC,KAAK,EAAE,IAAI,EAAE,OAAO,CAAC,CAAA;IAC7B,CAAC;SAAM,CAAC;QACN,OAAO,CAAC,IAAI,EAAE,KAAK,EAAE,KAAK,EAAE,GAAG,CAAC,CAAA;IAClC,CAAC;AACH,CAAC"}
package/dist/ops.d.ts ADDED
@@ -0,0 +1,35 @@
1
+ import type { Tensor, Shape, Dtype } from './ir.js';
2
+ export declare function add(a: Tensor, b: Tensor | number): Tensor;
3
+ export declare function sub(a: Tensor, b: Tensor | number): Tensor;
4
+ export declare function mul(a: Tensor, b: Tensor | number): Tensor;
5
+ export declare function div(a: Tensor, b: Tensor | number): Tensor;
6
+ export declare function mulScalar(a: Tensor, scalar: number): Tensor;
7
+ export declare function addScalar(a: Tensor, scalar: number): Tensor;
8
+ export declare const sqrt: (a: Tensor) => Tensor;
9
+ export declare const rsqrt: (a: Tensor) => Tensor;
10
+ export declare const log: (a: Tensor) => Tensor;
11
+ export declare const exp: (a: Tensor) => Tensor;
12
+ export declare const relu: (a: Tensor) => Tensor;
13
+ export declare function meanLast(a: Tensor): Tensor;
14
+ export declare function sumLast(a: Tensor): Tensor;
15
+ export declare function reshape(a: Tensor, newShape: Shape): Tensor;
16
+ export declare function transpose(a: Tensor, perm: readonly number[]): Tensor;
17
+ export declare function matmul(a: Tensor, b: Tensor): Tensor;
18
+ export declare function matmulBatched(a: Tensor, b: Tensor): Tensor;
19
+ export declare function oneHot(indices: Tensor, depth: number, dtype?: Dtype): Tensor;
20
+ export declare function arange(n: number, dtype?: Dtype): Tensor;
21
+ export declare function softmaxCausalLast(a: Tensor): Tensor;
22
+ export declare function logSoftmaxLast(a: Tensor): Tensor;
23
+ export declare function whereCausal(a: Tensor, fillValue: number): Tensor;
24
+ export declare function sliceLastRange(a: Tensor, start: number, end: number): Tensor;
25
+ export declare function broadcastTo(a: Tensor, targetShape: Shape): Tensor;
26
+ export declare function sumToShape(a: Tensor, targetShape: Shape): Tensor;
27
+ export declare function constScalar(value: number, dtype?: Dtype): Tensor;
28
+ export declare const less: (a: Tensor, b: Tensor) => Tensor;
29
+ export declare const greater: (a: Tensor, b: Tensor) => Tensor;
30
+ export declare function where(cond: Tensor, a: Tensor, b: Tensor): Tensor;
31
+ export declare function reluGrad(x: Tensor, dy: Tensor): Tensor;
32
+ export declare function adamUpdateM(m: Tensor, g: Tensor, b1: number): Tensor;
33
+ export declare function adamUpdateV(v: Tensor, g: Tensor, b2: number): Tensor;
34
+ export declare function adamUpdateP(p: Tensor, mNew: Tensor, vNew: Tensor, lrt: Tensor, eps: number): Tensor;
35
+ //# sourceMappingURL=ops.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"ops.d.ts","sourceRoot":"","sources":["../src/ops.ts"],"names":[],"mappings":"AAWA,OAAO,KAAK,EAAE,MAAM,EAAE,KAAK,EAAE,KAAK,EAAU,MAAM,SAAS,CAAA;AAmC3D,wBAAgB,GAAG,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,GAAG,MAAM,CAEzD;AACD,wBAAgB,GAAG,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,GAAG,MAAM,CAEzD;AACD,wBAAgB,GAAG,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,GAAG,MAAM,CAEzD;AACD,wBAAgB,GAAG,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,GAAG,MAAM,CAMzD;AAQD,wBAAgB,SAAS,CAAC,CAAC,EAAE,MAAM,EAAE,MAAM,EAAE,MAAM,GAAG,MAAM,CAG3D;AAED,wBAAgB,SAAS,CAAC,CAAC,EAAE,MAAM,EAAE,MAAM,EAAE,MAAM,GAAG,MAAM,CAG3D;AAYD,eAAO,MAAM,IAAI,GAAK,GAAG,MAAM,KAAG,MAA2B,CAAA;AAC7D,eAAO,MAAM,KAAK,GAAI,GAAG,MAAM,KAAG,MAA2B,CAAA;AAC7D,eAAO,MAAM,GAAG,GAAM,GAAG,MAAM,KAAG,MAA2B,CAAA;AAC7D,eAAO,MAAM,GAAG,GAAM,GAAG,MAAM,KAAG,MAA2B,CAAA;AAC7D,eAAO,MAAM,IAAI,GAAK,GAAG,MAAM,KAAG,MAA2B,CAAA;AAO7D,wBAAgB,QAAQ,CAAC,CAAC,EAAE,MAAM,GAAG,MAAM,CAK1C;AAED,wBAAgB,OAAO,CAAC,CAAC,EAAE,MAAM,GAAG,MAAM,CAKzC;AAMD,wBAAgB,OAAO,CAAC,CAAC,EAAE,MAAM,EAAE,QAAQ,EAAE,KAAK,GAAG,MAAM,CAI1D;AAED,wBAAgB,SAAS,CAAC,CAAC,EAAE,MAAM,EAAE,IAAI,EAAE,SAAS,MAAM,EAAE,GAAG,MAAM,CAIpE;AAMD,wBAAgB,MAAM,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,CAOnD;AAED,wBAAgB,aAAa,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,CAO1D;AAMD,wBAAgB,MAAM,CAAC,OAAO,EAAE,MAAM,EAAE,KAAK,EAAE,MAAM,EAAE,KAAK,GAAE,KAAa,GAAG,MAAM,CAOnF;AAGD,wBAAgB,MAAM,CAAC,CAAC,EAAE,MAAM,EAAE,KAAK,GAAE,KAAa,GAAG,MAAM,CAM9D;AASD,wBAAgB,iBAAiB,CAAC,CAAC,EAAE,MAAM,GAAG,MAAM,CAKnD;AAGD,wBAAgB,cAAc,CAAC,CAAC,EAAE,MAAM,GAAG,MAAM,CAIhD;AAMD,wBAAgB,WAAW,CAAC,CAAC,EAAE,MAAM,EAAE,SAAS,EAAE,MAAM,GAAG,MAAM,CAKhE;AAQD,wBAAgB,cAAc,CAAC,CAAC,EAAE,MAAM,EAAE,KAAK,EAAE,MAAM,EAAE,GAAG,EAAE,MAAM,GAAG,MAAM,CAI5E;AAOD,wBAAgB,WAAW,CAAC,CAAC,EAAE,MAAM,EAAE,WAAW,EAAE,KAAK,GAAG,MAAM,CAIjE;AAED,wBAAgB,UAAU,CAAC,CAAC,EAAE,MAAM,EAAE,WAAW,EAAE,KAAK,GAAG,MAAM,CAIhE;AAOD,wBAAgB,WAAW,CAAC,KAAK,EAAE,MAAM,EAAE,KAAK,GAAE,KAAa,GAAG,MAAM,CAGvE;AAWD,eAAO,MAAM,IAAI,GAAO,GAA
G,MAAM,EAAE,GAAG,MAAM,KAAG,MAAqD,CAAA;AACpG,eAAO,MAAM,OAAO,GAAI,GAAG,MAAM,EAAE,GAAG,MAAM,KAAG,MAAqD,CAAA;AAGpG,wBAAgB,KAAK,CAAC,IAAI,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,CAMhE;AAID,wBAAgB,QAAQ,CAAC,CAAC,EAAE,MAAM,EAAE,EAAE,EAAE,MAAM,GAAG,MAAM,CAOtD;AAMD,wBAAgB,WAAW,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,EAAE,EAAE,EAAE,MAAM,GAAG,MAAM,CAOpE;AAED,wBAAgB,WAAW,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,EAAE,EAAE,EAAE,MAAM,GAAG,MAAM,CAOpE;AAED,wBAAgB,WAAW,CAAC,CAAC,EAAE,MAAM,EAAE,IAAI,EAAE,MAAM,EAAE,IAAI,EAAE,MAAM,EAAE,GAAG,EAAE,MAAM,EAAE,GAAG,EAAE,MAAM,GAAG,MAAM,CAWnG"}