tinygrad 0.9.1__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. tinygrad/__init__.py +11 -6
  2. tinygrad/codegen/kernel.py +308 -175
  3. tinygrad/codegen/linearize.py +95 -0
  4. tinygrad/codegen/lowerer.py +143 -0
  5. tinygrad/codegen/transcendental.py +257 -0
  6. tinygrad/codegen/uopgraph.py +506 -0
  7. tinygrad/device.py +72 -171
  8. tinygrad/dtype.py +122 -47
  9. tinygrad/engine/jit.py +184 -87
  10. tinygrad/{lazy.py → engine/lazy.py} +74 -66
  11. tinygrad/engine/memory.py +51 -0
  12. tinygrad/engine/realize.py +86 -61
  13. tinygrad/engine/schedule.py +366 -317
  14. tinygrad/engine/search.py +58 -47
  15. tinygrad/function.py +59 -58
  16. tinygrad/helpers.py +120 -102
  17. tinygrad/multi.py +82 -78
  18. tinygrad/nn/__init__.py +116 -67
  19. tinygrad/nn/datasets.py +12 -5
  20. tinygrad/nn/optim.py +1 -1
  21. tinygrad/nn/state.py +91 -6
  22. tinygrad/ops.py +1126 -143
  23. tinygrad/renderer/__init__.py +47 -23
  24. tinygrad/renderer/cstyle.py +338 -265
  25. tinygrad/renderer/llvmir.py +125 -143
  26. tinygrad/renderer/ptx.py +225 -0
  27. tinygrad/runtime/autogen/adreno.py +17904 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +46974 -11993
  29. tinygrad/runtime/autogen/cuda.py +6 -162
  30. tinygrad/runtime/autogen/io_uring.py +97 -63
  31. tinygrad/runtime/autogen/kfd.py +60 -47
  32. tinygrad/runtime/autogen/kgsl.py +1386 -0
  33. tinygrad/runtime/autogen/libc.py +5462 -0
  34. tinygrad/runtime/autogen/nv_gpu.py +1976 -1957
  35. tinygrad/runtime/autogen/nvrtc.py +579 -0
  36. tinygrad/runtime/autogen/opencl.py +11 -11
  37. tinygrad/runtime/autogen/qcom_dsp.py +1739 -0
  38. tinygrad/runtime/graph/clang.py +3 -3
  39. tinygrad/runtime/graph/cuda.py +11 -15
  40. tinygrad/runtime/graph/hcq.py +120 -107
  41. tinygrad/runtime/graph/metal.py +71 -43
  42. tinygrad/runtime/ops_amd.py +244 -323
  43. tinygrad/runtime/ops_clang.py +12 -5
  44. tinygrad/runtime/ops_cloud.py +220 -0
  45. tinygrad/runtime/ops_cuda.py +42 -99
  46. tinygrad/runtime/ops_disk.py +25 -26
  47. tinygrad/runtime/ops_dsp.py +181 -0
  48. tinygrad/runtime/ops_gpu.py +29 -16
  49. tinygrad/runtime/ops_hip.py +68 -0
  50. tinygrad/runtime/ops_llvm.py +15 -10
  51. tinygrad/runtime/ops_metal.py +147 -64
  52. tinygrad/runtime/ops_nv.py +356 -397
  53. tinygrad/runtime/ops_python.py +78 -79
  54. tinygrad/runtime/ops_qcom.py +405 -0
  55. tinygrad/runtime/support/__init__.py +0 -0
  56. tinygrad/runtime/support/compiler_cuda.py +77 -0
  57. tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +13 -1
  58. tinygrad/runtime/support/elf.py +38 -0
  59. tinygrad/runtime/support/hcq.py +539 -0
  60. tinygrad/shape/shapetracker.py +40 -50
  61. tinygrad/shape/view.py +102 -63
  62. tinygrad/tensor.py +1109 -365
  63. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/METADATA +54 -50
  64. tinygrad-0.10.0.dist-info/RECORD +77 -0
  65. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/WHEEL +1 -1
  66. tinygrad/codegen/linearizer.py +0 -528
  67. tinygrad/codegen/uops.py +0 -451
  68. tinygrad/engine/graph.py +0 -100
  69. tinygrad/renderer/assembly.py +0 -269
  70. tinygrad/shape/symbolic.py +0 -327
  71. tinygrad-0.9.1.dist-info/RECORD +0 -63
  72. /tinygrad/{runtime/driver/__init__.py → py.typed} +0 -0
  73. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/LICENSE +0 -0
  74. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/top_level.txt +0 -0
tinygrad/renderer/assembly.py
@@ -1,269 +0,0 @@
- from typing import DefaultDict, Dict, List, Union, Optional, cast, Callable
- import struct, math
- from collections import defaultdict
- from tinygrad.helpers import DEBUG
- from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op
- from tinygrad.dtype import dtypes, DType, PtrDType, ConstType
- from tinygrad.codegen.uops import UOps, UOp, UOpGraph, PatternMatcher, UPat
- from tinygrad.renderer import Renderer, TensorCore
-
- def render_val(x, dtype):
-   if dtypes.is_float(dtype):
-     if dtype == dtypes.double: return "0d%02X%02X%02X%02X%02X%02X%02X%02X" % tuple(struct.pack("d",x)[::-1])
-     if dtype == dtypes.half: return "0x%02X%02X" % tuple(struct.pack("e",x)[::-1])
-     return "0f%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
-   return str(int(x)) + ("U" if dtypes.is_unsigned(dtype) else "")
-
- class PTXRenderer(Renderer):
-   device = "CUDA"
-   suffix = "PTX"
-   global_max = (2147483647, 65535, 65535)
-   local_max = (1024, 1024, 64)
-   shared_max = 49152
-   tensor_cores = [TensorCore(dims=(8,16,16), threads=[(0,2),(0,2),(1,2),(1,2),(0,2)], thread_local_sizes=[[2,2,2],[2,2],[2,2]], thread_local_aliases=[ [[0],[0],[5],[-2],[0],[-1,1,2,-3],[3,4]], [[3],[4],[0],[0],[5],[-1,1,2,-2],[0]], [[-1],[1],[5],[-2],[2],[0],[3,4]] ], dtype_in=di, dtype_out=do) for (di, do) in ([(dtypes.half, dtypes.float)])] # noqa: E501
-   def __init__(self, arch:str): self.tensor_cores = PTXRenderer.tensor_cores if int(arch[3:]) >= 80 else []
-
-   # language options
-   kernel_prefix = """.version VERSION
- .target TARGET
- .address_size 64
- .visible .entry"""
-   barrier = "bar.sync\t0;"
-   gid = [f'%ctaid.{chr(120+i)}' for i in range(3)]
-   gdim = [f'%nctaid.{chr(120+i)}' for i in range(3)]
-   lid = [f'%tid.{chr(120+i)}' for i in range(3)]
-   asm_for_op: Dict[Op, Callable] = {
-     UnaryOps.NEG: lambda d,a,dt,name: f"not.pred {d}, {a};" if name == "pred" else f"sub.{name} {d}, 0, {a};" if dtypes.is_unsigned(dt) \
-       else f"neg.{name} {d}, {a};",
-     UnaryOps.RECIP: lambda d,a,dt,name: f"rcp{'.approx' if dtypes.is_float(dt) else ''}.{name} {d}, {a};",
-     UnaryOps.EXP2: lambda d,a,dt,name: f"ex2.approx.{name} {d}, {a};", UnaryOps.LOG2: lambda d,a,dt,name: f"lg2.approx.{name} {d}, {a};",
-     UnaryOps.SIN: lambda d,a,dt,name: f"sin.approx.{name} {d}, {a};", UnaryOps.SQRT: lambda d,a,dt,name: f"sqrt.approx.{name} {d}, {a};",
-     BinaryOps.SHR: lambda d,a,b,dt,name: f"shr.{name} {d}, {a}, {b};", BinaryOps.SHL: lambda d,a,b,dt,name: f"shl.b{name[1:]} {d}, {a}, {b};",
-     BinaryOps.ADD: lambda d,a,b,dt,name: f"{'or' if name == 'pred' else 'add'}.{name} {d}, {a}, {b};",
-     BinaryOps.MUL: lambda d,a,b,dt,name: ('and' if dt == dtypes.bool else 'mul') + f"{'.lo' if dtypes.is_int(dt) else ''}.{name} {d}, {a}, {b};",
-     BinaryOps.XOR: lambda d,a,b,dt,name: f"xor.pred {d}, {a}, {b};" if name == "pred" else f"xor.b{name[1:]} {d}, {a}, {b};",
-     BinaryOps.IDIV: lambda d,a,b,dt,name: f"div.{name} {d}, {a}, {b};",
-     BinaryOps.MAX: lambda d,a,b,dt,name: f"max.{name} {d}, {a}, {b};", BinaryOps.MOD: lambda d,a,b,dt,name: f"rem.{name} {d}, {a}, {b};",
-     BinaryOps.CMPLT: lambda d,a,b,dt,name: f"setp.lt.{name} {d}, {a}, {b};",
-     BinaryOps.CMPNE: lambda d,a,b,dt,name: f"setp.ne.{name} {d}, {a}, {b};",
-     TernaryOps.MULACC: lambda d,a,b,c,dt,name: f"{'fma.rn' if dtypes.is_float(dt) else 'mad.lo'}.{name} {d}, {a}, {b}, {c};",
-     TernaryOps.WHERE: lambda d,a,b,c,dt,name:
-       f"@{a} mov.{name} {d}, {b};\n@!{a} mov.{name} {d}, {c};" if name == "pred" else f"selp.{'b16' if name == 'f16' else name} {d}, {b}, {c}, {a};"
-   }
-   supports_half: List[Op] = [UnaryOps.NEG, UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT,
-                              TernaryOps.WHERE]
-   # HACK: Use s16 and u16 for int8 and uint8 buffers. This can be wrong in cast.
-   types: Dict[DType, str] = { dtypes.int8: "s16", dtypes.int16: "s16", dtypes.int32: "s32", dtypes.int64: "s64",
-                               dtypes.uint8: "u16", dtypes.uint16: "u16", dtypes.uint32: "u32", dtypes.uint64: "u64",
-                               dtypes.float16: "f16", dtypes.float32: "f32", dtypes.float64: "f64", dtypes.bool: "pred" }
-
-   mem_types: Dict[DType, str] = types.copy()
-   mem_types.update({dtypes.int8: "s8", dtypes.uint8: "u8", dtypes.bool: "u8", dtypes.float16: "b16"})
-
-   const_requires_mov: List[DType] = [dtypes.half, dtypes.bool]
-
-   def render_const(self, x:ConstType, dtype:DType, mov=None) -> Union[List[str], str]:
-     val = render_val(x, dtype)
-     if dtype == dtypes.bool: return [f"setp.ne.s16 {mov}, {val}, 0;"]
-     return [f"mov.b{self.types[dtype][1:]} {mov}, {val};"] if mov else val
-
-   def render_local(self, dest, name, size, dtype) -> List[str]:
-     return [f".shared .align 4 .b8 {name}[{size*dtype.itemsize}];", f"mov.u64 {dest}, {name}[0];"]
-
-   def render_loop(self, idx, start, label, acc=None) -> List[str]: return [f"mov.u32 {idx}, {start};", f"{label}:"]
-
-   def render_bra(self, b1, pred=None) -> List[str]: return [f"@{pred} bra {b1};"] if pred else [f"bra {b1};"]
-
-   def render_load(self, loc, dest, dtype, gate=None, alt=None, ss="", offset=0) -> List[str]:
-     assert dtype != dtypes.bool
-     if gate: return [f"@{gate} ld{ss}.{self.mem_types[dtype]} {dest}, [{loc}+{offset}];", f"@!{gate} mov.b{self.types[dtype][1:]} {dest}, {alt};"]
-     return [f"ld{ss}.{self.mem_types[dtype]} {dest}, [{loc}+{offset}];"]
-
-   def render_store(self, loc, val, dtype, gate=None, ss="", offset=0) -> List[str]:
-     return [(f"@{gate} " if gate else "") + f"st{ss}.{self.mem_types[dtype]} [{loc}+{offset}], {val};"]
-
-   def render_cast(self, d:str, a:str, dtype:DType, atype:DType, bitcast=False, pred=False) -> List[str]:
-     if bitcast: return [f"mov.b{self.types[dtype][1:]} {d}, {a};"]
-     if atype == dtypes.bool: return[f"selp.b{self.types[dtype][1:]} {d}, {render_val(1, dtype)}, {render_val(0, dtype)}, {a};"]
-     if dtype == dtypes.bool: return [f"setp.ne.b{self.types[atype][1:]} {d}, {a}, {self.render_const(0, atype)};"]
-     rnd = ('.rzi' if dtypes.is_int(dtype) and dtypes.is_float(atype) else
-            '.rn' if dtypes.is_float(dtype) and (dtype.itemsize < atype.itemsize or dtypes.is_int(atype) or atype == dtypes.bool) else '')
-     return [f"cvt{rnd}.{self.types[dtype]}.{self.types[atype]} {d}, {a};"]
-
-   def render_kernel(self, kernel, function_name, bufs, regs) -> str:
-     kernel = [f".reg .{reg.split('_')[-2]} %{reg}<{cnt}>;" for reg,cnt in regs] + kernel + ["ret;"]
-     def fmt(line): return line if line[0]=="$" else "\t" + line.replace(" ", "\t" if len(line.split(" ")[0]) > 7 else "\t\t", 1)
-     return (f"{self.kernel_prefix} {function_name}(\n\t" +
-             ',\n\t'.join([f".param .{'u64' if dtype.__class__ == PtrDType else self.types[dtype]} {name}" for name,dtype in bufs]) + "\n)\n{\n" +
-             '\n'.join([fmt(line) for op in kernel for line in op.splitlines()]) +
-             "\n}")
-
-   def render(self, name:str, uops:UOpGraph) -> str:
-     kernel:List[str] = []
-     bufs = []
-
-     uops.linearize(ptx_matcher)
-     if DEBUG >= 4: uops.print()
-
-     def kk(*s: str): kernel.append("\n".join(s))
-
-     c: DefaultDict[str, int] = defaultdict(int)
-     r: Dict[UOp, Union[List[str], str]] = {}
-     def ssa(prefix:str, u:Optional[UOp]=None, dtype:Optional[str]=None) -> str:
-       nonlocal c, r
-       prefix += f"_{dtype if dtype is not None else self.types[cast(DType, cast(UOp, u).dtype)]}_"
-       c[prefix] += 1
-       if u is not None: r[u] = f"%{prefix}{c[prefix]-1}"
-       return f"%{prefix}{c[prefix]-1}"
-
-     def const(x:ConstType, dtype:DType, mov=False):
-       if mov or dtype in self.const_requires_mov:
-         kk(*self.render_const(x, dtype, mov=(out:=ssa('const', dtype=self.types[dtype]))))
-         return out
-       return self.render_const(x, dtype)
-
-     def _cast(a, dtype:DType, atype:DType, bitcast=False, u=None, pred=False):
-       if atype == dtype or isinstance(atype, PtrDType):
-         if u: r[u] = a
-         return a
-       kk(*self.render_cast((ret:=ssa('cast', u, self.types[dtype])), a, dtype, atype, bitcast))
-       return ret
-
-     for u in uops:
-       uop,dtype,src,args = u.op,u.dtype,u.src,u.arg
-       if uop is UOps.IF:
-         assert src[0].dtype is not None
-         kk(*self.render_bra(f"IF_{r[src[0]][1:]}_{cast(List, uops._uops).index(u)}", _cast(r[src[0]], dtypes.bool, src[0].dtype, u=u, pred=True)))
-       elif uop is UOps.BARRIER and self.barrier: kk(self.barrier)
-       elif uop is UOps.ENDRANGE:
-         kk(self.asm_for_op[BinaryOps.ADD](r[src[0]], r[src[0]], "1", dtypes.int, self.types[dtypes.int]),
-            self.asm_for_op[BinaryOps.CMPLT](pred:=ssa("pred", dtype="pred"), r[src[0]], r[src[0].src[1]], dtypes.int, self.types[dtypes.int]))
-         kk(*self.render_bra(f"LOOP_{r[src[0]][1:]}", pred))
-       elif uop is UOps.ENDIF:
-         kk(f"IF_{r[src[0].src[0]][1:]}_{cast(List, uops._uops).index(src[0])}:")
-       elif uop is UOps.STORE:
-         assert src[0].dtype is not None and src[2].dtype is not None
-         assert src[0].dtype == dtypes.int64, "store isn't int64"
-         assert src[1].op is UOps.CONST, f"store isn't const {u}"
-         mem_type = '.shared' if src[0].op is UOps.DEFINE_LOCAL or any(x.op is UOps.DEFINE_LOCAL for x in src[0].parents) else '.global'
-         if src[2].dtype.count > 1:
-           kk((f"@{r[src[3]]} " if len(src)>3 else "") + \
-              f"st{mem_type}.v{src[2].dtype.count}.{self.mem_types[src[2].dtype.scalar()]} [{r[src[0]]}+{src[1].arg}], {{{', '.join(r[src[2]])}}};")
-         else:
-           kk(*self.render_store(r[src[0]], r[src[2]], src[2].dtype, gate=r[src[3]] if len(src)>3 else None, ss=mem_type, offset=src[1].arg))
-       else:
-         assert dtype is not None, f"None dtype for uop {uop}"
-         if uop is UOps.RANGE: kk(*self.render_loop(loop:=ssa('ridx', u), r[src[0]], "LOOP_"+loop[1:]))
-         elif uop is UOps.ALU:
-           assert src[0].dtype is not None
-           if args is BinaryOps.CMPLT or args is BinaryOps.CMPNE:
-             # pass in the other dtype here
-             kk(self.asm_for_op[args](ssa("alu", u), *[r[x] for x in src], src[0].dtype, self.types[src[0].dtype]))
-           else:
-             kk(self.asm_for_op[args](ssa("alu", u), *[r[x] for x in src], dtype, self.types[dtype]))
-         elif uop is UOps.DEFINE_ACC:
-           if dtype.count > 1:
-             r[u] = [ssa('acc', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
-             for uu in r[u]: kk(f"mov.b{self.types[dtype.scalar()][1:]} {uu}, {const(src[0].arg, dtype.scalar())};")
-           else: kk(f"mov.b{self.types[dtype][1:]} {ssa('acc', u)}, {const(src[0].arg, dtype)};")
-         elif uop is UOps.SPECIAL:
-           assert args[1][0] != "i", "idx not supported"
-           kk(f"mov.u32 %{args[1]}, {(self.gid if args[1][0] == 'g' else self.lid)[args[0]]};")
-           r[u] = "%" + args[1]
-           kernel = [f".reg .u32 %{args[1]};"] + kernel
-         elif uop is UOps.CONST:
-           if dtype.count > 1: r[u] = [const(args, dtype.scalar(), mov=True) for _ in range(dtype.count)]
-           else: r[u] = const(args, dtype, mov=True)
-         elif uop is UOps.GEP: r[u] = r[src[0]][u.arg]
-         elif uop is UOps.LOAD:
-           assert src[0].dtype == dtypes.int64, "load isn't int64"
-           assert src[1].op is UOps.CONST, f"load isn't const {u}"
-           mem_type = '.shared' if src[0].op is UOps.DEFINE_LOCAL or any(x.op is UOps.DEFINE_LOCAL for x in src[0].parents) else '.global'
-           if dtype.count > 1:
-             r[u] = [ssa('val', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
-             if(len(src)>3):
-               for v in r[u]: kk(f"mov.{self.mem_types[dtype.scalar()]} {v}, {render_val(0, dtype.scalar())};")
-             kk((f"@{r[src[2]]}"if len(src) > 3 else "")
-                + f" ld{mem_type}.v{dtype.count}.{self.mem_types[dtype.scalar()]} {{{', '.join(r[u])}}}, [{r[src[0]]}+{src[1].arg}];")
-           else:
-             kk(*self.render_load(r[src[0]], ssa('val', u), dtype, gate=r[src[2]] if len(src) > 3 else None,
-                                  alt=r[src[3]] if len(src) > 3 else None, ss=mem_type, offset=src[1].arg))
-         elif uop is UOps.PHI:
-           if dtype.count > 1:
-             for x0, x1 in zip(r[src[0]], r[src[1]]): kk(f"mov.b{self.types[dtype.scalar()][1:]} {x0}, {x1};")
-           else:
-             kk(f"mov.b{self.types[dtype][1:]} {r[src[0]]}, {r[src[1]]};")
-           r[u] = r[src[0]]
-         elif uop in {UOps.CAST, UOps.BITCAST}:
-           assert src[0].dtype is not None
-           if dtype.count>1: r[u] = [r[x] for x in src] # type: ignore
-           else: _cast(r[src[0]], dtype, src[0].dtype, bitcast=uop is UOps.BITCAST, u=u)
-         elif uop is UOps.DEFINE_LOCAL:
-           # TODO: we should sum these, and fetch 0xC000 from somewhere
-           assert args[1]*dtype.itemsize <= 0xC000, "too large local"
-           kk(*self.render_local(ssa('local', u, self.types[dtypes.ulong]), args[0], args[1], dtype))
-         elif uop is UOps.DEFINE_VAR:
-           bufs.append((args.expr, dtype))
-           r[u] = f"%{args.expr}"
-           kk(*self.render_load(args.expr, ssa('dat', u, self.types[dtype]), dtype, ss=".param"))
-         elif uop is UOps.DEFINE_GLOBAL:
-           bufs.append((nm:=f"data{args[0]}", dtype))
-           r[u] = f"%{nm}"
-           dt = dtypes.ulong if dtype.__class__ == PtrDType else dtype
-           kk(*self.render_load(nm, ssa('dat', u, self.types[dt]), dt, ss=".param"))
-         elif uop is UOps.WMMA:
-           wmma = []
-           for vv in src[:2]:
-             for i in range(0, len(r[vv]), 2):
-               wmma.append(ssa("wmma", dtype="b32"))
-               kk(f'mov.b32 {wmma[-1]}, {{{", ".join(r[vv][i:i+2])}}};')
-           r[u] = [ssa("wmma", dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
-           kk(f'mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32\
-             {{{", ".join(r[u])}}}, {{{", ".join(wmma[:4])}}}, {{{", ".join(wmma[4:])}}}, {{{", ".join(r[src[2]])}}};')
-         else: raise NotImplementedError(f"no code for {uop}")
-
-     return self.render_kernel(kernel, name, bufs, c.items())
-
- ptx_matcher = PatternMatcher([
-   (UPat(UOps.ALU, BinaryOps.MUL, name="root", dtype=set([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]),
-         src=[UPat(UOps.CONST, set([2**i for i in range(64)]), name="const"), UPat(name="mul")]),
-    lambda root, mul, const: UOp(UOps.ALU, root.dtype, (mul, UOp.const(dtypes.int, int(math.log2(const.arg)))), BinaryOps.SHL)),
-   (UPat(UOps.ALU, BinaryOps.IDIV, name="root", dtype=set([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]),
-         src=[UPat(UOps.CONST, set([2**i for i in range(64)]), name="const"), UPat(name="div")]),
-    lambda root, div, const: UOp(UOps.ALU, root.dtype, (div, UOp.const(dtypes.int, int(math.log2(const.arg)))), BinaryOps.SHR)),
-   (UPat(UOps.ALU, BinaryOps.CMPNE, (UPat(dtype=dtypes.bool),UPat()), "root"), lambda root: UOp(root.op, root.dtype, root.src, BinaryOps.XOR)),
-   (UPat(UOps.ALU, BinaryOps.CMPLT, (UPat(name="x", dtype=dtypes.bool),UPat(name="y")), "root"),
-    lambda root,x,y: UOp(root.op, root.dtype, (UOp(UOps.ALU, dtypes.bool, (x,), UnaryOps.NEG), y), BinaryOps.MUL)),
-   (UPat(UOps.ALU, BinaryOps.ADD,
-         [UPat(name="non_muls"), UPat(UOps.ALU, BinaryOps.MUL, name="muls")], "root"),
-    lambda root, muls, non_muls: UOp(UOps.ALU, root.dtype, muls.src + (non_muls,), TernaryOps.MULACC)),
-   *[(UPat(UOps.ALU, op, dtype=dtypes.half, name="x"),
-      lambda x: UOp(UOps.CAST, dtypes.half, (UOp(x.op, dtypes.float32, tuple([UOp(UOps.CAST, dtypes.float32, (vv,)) for vv in x.src]), x.arg),)))
-     for op in PTXRenderer.asm_for_op.keys() if op not in PTXRenderer.supports_half],
-   (UPat(UOps.LOAD, name="root", dtype=dtypes.bool, src=(UPat(name="x"),UPat(name="y"),UPat(name="z"),UPat(name="k"))),
-    lambda root,x,y,z,k: UOp(UOps.CAST, dtypes.bool, (UOp(root.op, dtypes.int8, (x,y,z,UOp(UOps.CAST, dtypes.uint8, (k,)))),), root.arg)),
-   (UPat(UOps.LOAD, name="root", dtype=dtypes.bool, src=(UPat(),UPat())),
-    lambda root: UOp(UOps.CAST, dtypes.bool, (UOp(root.op, dtypes.uint8, root.src, root.arg),))),
-   (UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat(name="z",dtype=dtypes.bool), UPat())),
-    lambda root,z: UOp(root.op, root.dtype, root.src[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,)),), root.arg)),
-   (UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat(name="z",dtype=dtypes.bool))),
-    lambda root,z: UOp(root.op, root.dtype, root.src[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,)),), root.arg)),
-   (UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat(),UPat(name="g", dtype=dtypes.int))),
-    lambda root,g: UOp(root.op, root.dtype, root.src[:3] + (UOp(UOps.CAST, dtypes.bool, (g,)),), root.arg)),
-   # ptr_ar (load/store)
-   (UPat({UOps.LOAD, UOps.STORE}, name="root", allow_len={2,3,4,5}, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
-         UPat(UOps.ALU, BinaryOps.ADD, src=[UPat(name="alu"), UPat(UOps.CONST, name="const")]))),
-    lambda root, alu, const: UOp(root.op, root.dtype,
-        (alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.src[0].dtype.itemsize)+root.src[0].cast(dtypes.int64),
-         UOp.const(const.dtype, root.src[0].dtype.itemsize)*const)+root.src[2:])),
-   (UPat({UOps.LOAD, UOps.STORE}, name="root", allow_len={2,3,4,5}, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
-         UPat(UOps.CONST, name="const"))),
-    lambda root, const: UOp(root.op, root.dtype, (root.src[0].cast(dtypes.int64),
-        UOp.const(dtypes.int64, const.arg * root.src[0].dtype.itemsize),
-        )+root.src[2:])),
-   (UPat({UOps.LOAD, UOps.STORE}, name="root", allow_len={2,3,4,5}, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
-         UPat(name="alu"))), # no const here
-    lambda root, alu: UOp(root.op, root.dtype,
-        (alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.src[0].dtype.itemsize)+root.src[0].cast(dtypes.int64),
-         UOp.const(dtypes.int64, 0))+root.src[2:])),
-   ])
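
For reference, the removed render_val helper above encodes Python floats as PTX immediate literals: the bytes from struct.pack are reversed and hex-printed behind a 0f/0d/0x prefix (this relies on a little-endian host, as the original code does). The sketch below is illustrative only and not part of the package diff; the helper name render_float_val is chosen just for this example.

# Illustrative sketch (not part of the diff): the PTX float-literal encoding used by render_val.
import struct

def render_float_val(x: float, fmt: str) -> str:
    # fmt is the struct format char: "f" (float32), "e" (float16), "d" (float64)
    prefix = {"f": "0f", "e": "0x", "d": "0d"}[fmt]
    # pack in native (little-endian) order, reverse the bytes, print as uppercase hex
    return prefix + "".join(f"{b:02X}" for b in struct.pack(fmt, x)[::-1])

assert render_float_val(1.0, "f") == "0f3F800000"          # float32 1.0
assert render_float_val(1.0, "e") == "0x3C00"              # float16 1.0
assert render_float_val(1.0, "d") == "0d3FF0000000000000"  # float64 1.0
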
tinygrad/shape/symbolic.py
@@ -1,327 +0,0 @@
- from __future__ import annotations
- import functools
- from math import gcd
- from tinygrad.helpers import partition
- from typing import List, Dict, Callable, Tuple, Type, Union, Optional, Any, Set, Mapping
-
- # NOTE: Python has different behavior for negative mod and floor div than c
- # symbolic matches the Python behavior, but the code output is agnostic, and will never have negative numbers in div or mod
-
- class Node:
-   b: Union[Node, int]
-   min: int
-   max: sint
-   def render(self, ops=None, ctx=None) -> Any:
-     if ops is None: ops = render_python
-     assert self.__class__ in (Variable, NumNode) or self.min != self.max
-     return ops[type(self)](self, ops, ctx)
-   def vars(self) -> Set[Variable]: return set()
-   # substitute Variables with the values in var_vals
-   def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node: raise RuntimeError(self.__class__.__name__)
-   def unbind(self) -> Tuple[Node, Optional[int]]: return self.substitute({v: v.unbind()[0] for v in self.vars() if v.val is not None}), None
-
-   @functools.cached_property
-   def key(self) -> str: return self.render(ctx="DEBUG")
-   def __repr__(self): return self.render(ctx="REPR")
-   def __str__(self): return "<"+self.key+">"
-   def __hash__(self): return hash(self.key)
-   def __bool__(self): return not (self.max == self.min == 0)
-   def __eq__(self, other:object) -> bool:
-     if not isinstance(other, Node): return NotImplemented
-     return self.key == other.key
-   def __neg__(self): return self*-1
-   def __add__(self, b:Union[Node,int]): return Node.sum([self, NumNode(b) if isinstance(b, int) else b])
-   def __radd__(self, b:int): return self+b
-   def __sub__(self, b:Union[Node,int]): return self+-b
-   def __rsub__(self, b:int): return -self+b
-   def __le__(self, b:Union[Node,int]): return self < (b+1)
-   def __gt__(self, b:Union[Node,int]): return (-self) < (-b)
-   def __ge__(self, b:Union[Node,int]): return (-self) < (-b+1)
-   def __lt__(self, b:Union[Node,int]): return create_node(LtNode(self, b))
-   def __mul__(self, b:Union[Node, int]):
-     if b == 0: return NumNode(0)
-     if b == 1: return self
-     return create_node(MulNode(self, b.b)) if isinstance(b, NumNode) else create_node(MulNode(self, b))
-   def __rmul__(self, b:int): return self*b
-
-   # *** complex ops ***
-
-   def __rfloordiv__(self, b:int): return NumNode(b) // self
-   def __floordiv__(self, b:Union[Node,int], factoring_allowed=True):
-     if isinstance(b, Node):
-       if b.__class__ is NumNode: return self.__floordiv__(b.b, factoring_allowed)
-       if self == b: return NumNode(1)
-       if (b - self).min > 0 and self.min >= 0: return NumNode(0) # b - self simplifies the node
-       raise RuntimeError(f"not supported: {self} // {b}")
-     assert b != 0
-     if b < 0: return (self*-1).__floordiv__(-b, factoring_allowed)
-     if b == 1: return self
-
-     # the numerator of div is not allowed to be negative
-     if self.min < 0:
-       offset = self.min//b
-       # factor out an "offset" to make the numerator positive. don't allowing factoring again
-       return (self + -offset*b).__floordiv__(b, factoring_allowed=False) + offset
-     return create_node(DivNode(self, b))
-
-   def __rmod__(self, b:int): return NumNode(b) % self
-   def __mod__(self, b:Union[Node,int]):
-     if isinstance(b, Node):
-       if b.__class__ is NumNode: return self % b.b
-       if self == b: return NumNode(0)
-       if (b - self).min > 0 and self.min >= 0: return self # b - self simplifies the node
-       raise RuntimeError(f"not supported: {self} % {b}")
-     assert b > 0
-     if b == 1: return NumNode(0)
-     if isinstance(self.max, int) and isinstance(self.min, int):
-       if self.min >= 0 and self.max < b: return self
-       if (self.min//b) == (self.max//b): return self - (b*(self.min//b))
-       if self.min < 0: return (self - ((self.min//b)*b)) % b
-     return create_node(ModNode(self, b))
-
-   @staticmethod
-   def sum(nodes:List[Node]) -> Node:
-     nodes = [x for x in nodes if x.max or x.min]
-     if not nodes: return NumNode(0)
-     if len(nodes) == 1: return nodes[0]
-
-     mul_groups: Dict[Node, int] = {}
-     num_node_sum = 0
-     for node in SumNode(nodes).flat_components:
-       if node.__class__ is NumNode: num_node_sum += node.b
-       elif node.__class__ is MulNode: mul_groups[node.a] = mul_groups.get(node.a, 0) + node.b
-       else: mul_groups[node] = mul_groups.get(node, 0) + 1
-     new_nodes = [MulNode(a, b_sum) if b_sum != 1 else a for a, b_sum in mul_groups.items() if b_sum != 0]
-     if num_node_sum: new_nodes.append(NumNode(num_node_sum))
-     return create_node(SumNode(new_nodes)) if len(new_nodes) > 1 else new_nodes[0] if len(new_nodes) == 1 else NumNode(0)
-
-   @staticmethod
-   def ands(nodes:List[Node]) -> Node:
-     if not nodes: return NumNode(1)
-     if len(nodes) == 1: return nodes[0]
-     if any(not x for x in nodes): return NumNode(0)
-
-     # filter 1s
-     nodes = [x for x in nodes if x.min != x.max]
-     return create_node(AndNode(nodes)) if len(nodes) > 1 else (nodes[0] if len(nodes) == 1 else NumNode(1))
-
- # 4 basic node types
-
- class Variable(Node):
-   def __new__(cls, *args):
-     expr, nmin, nmax = args
-     assert nmin >= 0 and nmin <= nmax, f"invalid Variable {expr=} {nmin=} {nmax=}"
-     if nmin == nmax: return NumNode(nmin)
-     return super().__new__(cls)
-
-   def __getnewargs__(self): return (self.expr, self.min, self.max) # args passed to __new__ when unpickling
-
-   def __init__(self, expr:str, nmin:int, nmax:sint):
-     self.expr, self.min, self.max = expr, nmin, nmax
-     self._val: Optional[int] = None
-   @property
-   def val(self):
-     assert self._val is not None, f"Variable isn't bound, can't access val of {self}"
-     return self._val
-   def bind(self, val):
-     assert self._val is None and self.min<=val<=self.max, f"cannot bind {val} to {self}"
-     self._val = val
-     return self
-   def unbind(self) -> Tuple[Variable, int]:
-     assert self.val is not None, f"cannot unbind {self}"
-     return Variable(self.expr, self.min, self.max), self.val
-   def vars(self): return {self}
-   def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node: return var_vals.get(self, self)
-
- class NumNode(Node):
-   def __init__(self, num:int):
-     assert isinstance(num, int), f"{num} is not an int"
-     self.b:int = num
-     self.min, self.max = num, num
-   def bind(self, val):
-     assert self.b == val, f"cannot bind {val} to {self}"
-     return self
-   def __mul__(self, b:Union[Node,int]): return NumNode(self.b*b) if isinstance(b, int) else b*self.b
-   def __eq__(self, other): return self.b == other
-   def __hash__(self): return hash(self.b) # needed with __eq__ override
-   def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node: return self
-
- def create_node(ret:Node):
-   assert ret.min <= ret.max, f"min greater than max! {ret.min} {ret.max} when creating {type(ret)} {ret}"
-   if ret.min == ret.max: return NumNode(ret.min)
-   return ret
-
- def create_lt_node(lhs:Node, b:Union[Node, int]):
-   if isinstance(lhs, SumNode):
-     if isinstance(b, int):
-       new_sum = []
-       for x in lhs.nodes:
-         # TODO: should we just force the last one to always be the number
-         if isinstance(x, NumNode): b -= x.b
-         else: new_sum.append(x)
-       lhs = Node.sum(new_sum)
-       nodes = lhs.nodes if isinstance(lhs, SumNode) else [lhs]
-       assert all(not isinstance(node, MulNode) or isinstance(node.b, int) for node in nodes), "not supported"
-       muls, others = partition(nodes, lambda x: isinstance(x, MulNode) and x.b > 0 and x.max >= b)
-       if muls:
-         # NOTE: gcd in python 3.8 takes exactly 2 args
-         mul_gcd = b
-         for x in muls: mul_gcd = gcd(mul_gcd, x.b) # type: ignore # mypy cannot tell that x.b is int here due to assert above
-         all_others = Node.sum(others)
-         if all_others.min >= 0 and all_others.max < mul_gcd:
-           lhs, b = Node.sum([mul//mul_gcd for mul in muls]), b//mul_gcd
-     return create_node(LtNode(lhs, b)) if isinstance(lhs, SumNode) else create_lt_node(lhs, b)
-   if isinstance(lhs, MulNode):
-     if isinstance(b, Node) or isinstance(lhs.b, Node) or lhs.b == -1: return create_node(LtNode(lhs, b))
-     sgn = 1 if lhs.b > 0 else -1
-     return create_node(LtNode(lhs.a*sgn, (b + abs(lhs.b) - 1)//abs(lhs.b)))
-   return create_node(LtNode(lhs, b))
-
- def create_ge_node(lhs:Node, b:Union[Node, int]): return create_lt_node(-lhs, -b+1)
-
- class OpNode(Node):
-   def __init__(self, a:Node, b:Union[Node, int]):
-     self.a, self.b = a, b
-     self.min, self.max = self.get_bounds()
-   def vars(self): return self.a.vars() | (self.b.vars() if isinstance(self.b, Node) else set())
-   def get_bounds(self) -> Tuple[int, sint]: raise NotImplementedError("must be implemented")
-
- class LtNode(OpNode):
-   def get_bounds(self) -> Tuple[int, int]:
-     if self.a == self.b: return (0, 0)
-     if isinstance(self.b, int): return (1, 1) if self.a.max < self.b else (0, 0) if self.a.min >= self.b else (0, 1)
-     return (1, 1) if self.a.max < self.b.min else (0, 0) if self.a.min >= self.b.max else (0, 1)
-   def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node:
-     return create_lt_node(self.a.substitute(var_vals), (self.b if isinstance(self.b, int) else self.b.substitute(var_vals)))
-
- class MulNode(OpNode):
-   def __mul__(self, b: Union[Node, int]): return self.a*(self.b*b) # two muls in one mul
-   def __floordiv__(self, b: Union[Node, int], factoring_allowed=False): # NOTE: mod negative isn't handled right
-     if self.b % b == 0: return self.a*(self.b//b)
-     if b % self.b == 0 and self.b > 0: return self.a//(b//self.b)
-     return Node.__floordiv__(self, b, factoring_allowed)
-   def __mod__(self, b: Union[Node, int]): return Node.__mod__(self.a * (self.b%b), b)
-   def get_bounds(self) -> Tuple[int, sint]:
-     assert self.a.min >= 0
-     if isinstance(self.b, int): return (self.a.min*self.b, self.a.max*self.b) if self.b >= 0 else (self.a.max*self.b, self.a.min*self.b)
-     return (self.a.min*self.b.min, self.a.max*self.b.max) if self.b.min >= 0 else (self.a.max*self.b.min, self.a.min*self.b.max)
-   def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node:
-     return self.a.substitute(var_vals) * (self.b if isinstance(self.b, int) else self.b.substitute(var_vals))
-
- class DivNode(OpNode):
-   def __floordiv__(self, b: Union[Node, int], _=False): return self.a//(self.b*b) # two divs is one div
-   def get_bounds(self) -> Tuple[int, sint]:
-     assert self.a.min >= 0 and isinstance(self.b, int)
-     return self.a.min//self.b, self.a.max//self.b
-   def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node: return self.a.substitute(var_vals) // self.b
-
- class ModNode(OpNode):
-   def __mod__(self, b: Union[Node, int]):
-     if isinstance(b, int) and isinstance(self.b, int) and self.b % b == 0: return self.a % b
-     return Node.__mod__(self, b)
-   def __floordiv__(self, b: Union[Node, int], factoring_allowed=True):
-     return (self.a//b) % (self.b//b) if self.b % b == 0 else Node.__floordiv__(self, b, factoring_allowed)
-   def get_bounds(self) -> Tuple[int, sint]:
-     assert self.a.min >= 0 and isinstance(self.b, int)
-     if self.a.max - self.a.min >= self.b or (self.a.min != self.a.max and self.a.min%self.b >= self.a.max%self.b): return (0, self.b-1)
-     return (self.a.min%self.b, self.a.max%self.b)
-   def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node: return self.a.substitute(var_vals) % self.b
-
- class RedNode(Node):
-   def __init__(self, nodes:List[Node]):
-     self.nodes = nodes
-     self.min, self.max = self.get_bounds()
-   def vars(self) -> Set[Variable]: return set.union(*[x.vars() for x in self.nodes], set())
-   def get_bounds(self) -> Tuple[int, sint]: raise NotImplementedError("must be implemented")
-
- class SumNode(RedNode):
-   def get_bounds(self) -> Tuple[int, sint]: return sum([x.min for x in self.nodes]), sum([x.max for x in self.nodes])
-   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
-   def __mul__(self, b: Union[Node, int]): return Node.sum([x*b for x in self.nodes]) # distribute mul into sum
-   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
-   def __floordiv__(self, b: Union[Node, sint], factoring_allowed=True):
-     if self == b: return NumNode(1)
-     fully_divided: List[Node] = []
-     rest: List[Node] = []
-     if isinstance(b, Node):
-       for x in self.flat_components:
-         if x % b == 0: fully_divided.append(x // b)
-         else: rest.append(x)
-       if (sum_fully_divided:=create_node(SumNode(fully_divided))) != 0: return sum_fully_divided + create_node(SumNode(rest)) // b
-       return Node.__floordiv__(self, b, False)
-     if b == 1: return self
-     if not factoring_allowed: return Node.__floordiv__(self, b, factoring_allowed)
-     _gcd = b
-     divisor = 1
-     for x in self.flat_components:
-       if x.__class__ in (NumNode, MulNode):
-         if x.b % b == 0: fully_divided.append(x // b)
-         else:
-           if x.__class__ is NumNode and (div := x.b // b):
-             fully_divided.append(NumNode(div))
-             x = NumNode(x.b - b * div)
-           rest.append(x)
-           if isinstance(x.b, int):
-             _gcd = gcd(_gcd, x.b)
-             if x.__class__ == MulNode and divisor == 1 and b % x.b == 0: divisor = x.b
-           else:
-             _gcd = 1
-       else:
-         rest.append(x)
-         _gcd = 1
-     if _gcd > 1: return Node.sum(fully_divided) + Node.sum(rest).__floordiv__(_gcd) // (b//_gcd)
-     if divisor > 1: return Node.sum(fully_divided) + Node.sum(rest).__floordiv__(divisor) // (b//divisor)
-     return Node.sum(fully_divided) + Node.__floordiv__(Node.sum(rest), b)
-
-   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
-   def __mod__(self, b: Union[Node, int]):
-     if self == b: return NumNode(0)
-     if isinstance(b, Node) and (b - self).min > 0: return self # b - self simplifies the node
-     new_sum = Node.sum([node%b if node.__class__ in (NumNode, MulNode) else node for node in self.nodes])
-     return Node.__mod__(new_sum, b)
-
-   def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node:
-     return Node.sum([node.substitute(var_vals) for node in self.nodes])
-
-   # recursively expand sumnode components
-   # TODO: can remove this if there's no SumNode inside SumNode
-   @property
-   def flat_components(self): return [y for x in self.nodes for y in (x.flat_components if isinstance(x, SumNode) else [x])]
-
- class AndNode(RedNode):
-   def get_bounds(self) -> Tuple[int, sint]: return min([x.min for x in self.nodes]), max([x.max for x in self.nodes])
-   def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node:
-     subed = []
-     for node in self.nodes:
-       if not (sub:=node.substitute(var_vals)): return NumNode(0)
-       subed.append(sub)
-     return Node.ands(subed)
-
- def sym_render(a: Union[Node, int], ops=None, ctx=None) -> str: return str(a) if isinstance(a, int) else a.render(ops, ctx)
- def sym_infer(a: Union[Node, int], var_vals: Optional[Dict[Variable, int]]) -> int:
-   if isinstance(a, (int, float)): return a
-   ret = a.substitute({k:NumNode(v) for k, v in var_vals.items()}) if var_vals is not None else a
-   assert isinstance(ret, NumNode), f"sym_infer didn't produce NumNode from {a} with {var_vals}"
-   return ret.b
-
- # symbolic int, these are allowed in a Tensor shape
- sint = Union[int, Variable, MulNode, SumNode]
-
- def render_mulnode(node:MulNode, ops, ctx):
-   # TODO: add ProdNode and remove this case
-   if isinstance(node.a,Variable) and isinstance(node.b,Variable) and node.a.expr and node.b.expr and node.b.expr < node.a.expr:
-     return f"({sym_render(node.b,ops,ctx)}*{node.a.render(ops,ctx)})"
-   return f"({node.a.render(ops,ctx)}*{sym_render(node.b,ops,ctx)})"
-
- render_python: Dict[Type, Callable[..., str]] = {
-   Variable: lambda self,ops,ctx: f"{self.expr}[{self.min}-{self.max}{'='+str(self.val) if self._val is not None else ''}]" if ctx == "DEBUG" \
-     else (f"Variable('{self.expr}', {self.min}, {self.max})"+(f".bind({self.val})" if self._val is not None else '') if ctx == "REPR" \
-     else f"{self.expr}"),
-   NumNode: lambda self,ops,ctx: f"NumNode({self.b})" if ctx == "REPR" else f"{self.b}",
-   MulNode: render_mulnode,
-   DivNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}//{self.b})",
-   ModNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}%{self.b})",
-   LtNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}<{sym_render(self.b,ops,ctx)})",
-   SumNode: lambda self,ops,ctx: f"({'+'.join(sorted([x.render(ops,ctx) for x in self.nodes]))})",
-   AndNode: lambda self,ops,ctx: f"({' and '.join(sorted([x.render(ops,ctx) for x in self.nodes]))})",
- }
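
For reference, the removed module above implemented a small integer-expression algebra (Variable/NumNode with simplifying //, %, and comparisons) used for shape indexing; tinygrad/shape/symbolic.py no longer exists in 0.10.0. The sketch below is illustrative only and must be run against tinygrad 0.9.1 (the old version in this diff); the expected outputs in the comments follow from the code above.

# Illustrative sketch (not part of the diff): using the 0.9.1 symbolic module shown above.
from tinygrad.shape.symbolic import Variable, sym_infer

i = Variable("i", 0, 9)           # symbolic integer in [0, 9]
expr = (i * 4 + 2) // 4           # SumNode.__floordiv__ factors this back down to i
print(expr)                       # expected: <i[0-9]>
print((i * 4 + 2) % 4)            # expected: <2> (the remainder folds to a NumNode)
print(sym_infer(expr, {i: 7}))    # expected: 7 (substitute a concrete value and evaluate)
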