tinygrad 0.10.0__py3-none-any.whl → 0.10.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. tinygrad/codegen/devectorizer.py +247 -0
  2. tinygrad/codegen/expander.py +121 -0
  3. tinygrad/codegen/kernel.py +141 -201
  4. tinygrad/codegen/linearize.py +223 -84
  5. tinygrad/codegen/lowerer.py +60 -42
  6. tinygrad/codegen/symbolic.py +476 -0
  7. tinygrad/codegen/transcendental.py +22 -13
  8. tinygrad/device.py +187 -47
  9. tinygrad/dtype.py +39 -28
  10. tinygrad/engine/jit.py +83 -65
  11. tinygrad/engine/memory.py +4 -5
  12. tinygrad/engine/multi.py +161 -0
  13. tinygrad/engine/realize.py +62 -108
  14. tinygrad/engine/schedule.py +396 -357
  15. tinygrad/engine/search.py +55 -66
  16. tinygrad/gradient.py +73 -0
  17. tinygrad/helpers.py +81 -59
  18. tinygrad/nn/__init__.py +30 -32
  19. tinygrad/nn/datasets.py +1 -2
  20. tinygrad/nn/optim.py +22 -26
  21. tinygrad/nn/state.py +91 -66
  22. tinygrad/ops.py +492 -641
  23. tinygrad/renderer/__init__.py +95 -36
  24. tinygrad/renderer/cstyle.py +99 -92
  25. tinygrad/renderer/llvmir.py +83 -34
  26. tinygrad/renderer/ptx.py +83 -99
  27. tinygrad/renderer/wgsl.py +95 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +39507 -12
  29. tinygrad/runtime/autogen/comgr.py +2 -0
  30. tinygrad/runtime/autogen/kfd.py +4 -3
  31. tinygrad/runtime/autogen/kgsl.py +1 -1
  32. tinygrad/runtime/autogen/libc.py +404 -71
  33. tinygrad/runtime/autogen/llvm.py +11379 -0
  34. tinygrad/runtime/autogen/pci.py +1333 -0
  35. tinygrad/runtime/autogen/vfio.py +891 -0
  36. tinygrad/runtime/autogen/webgpu.py +6985 -0
  37. tinygrad/runtime/graph/cuda.py +8 -9
  38. tinygrad/runtime/graph/hcq.py +84 -79
  39. tinygrad/runtime/graph/metal.py +40 -43
  40. tinygrad/runtime/ops_amd.py +498 -334
  41. tinygrad/runtime/ops_cloud.py +34 -34
  42. tinygrad/runtime/ops_cpu.py +24 -0
  43. tinygrad/runtime/ops_cuda.py +30 -27
  44. tinygrad/runtime/ops_disk.py +62 -63
  45. tinygrad/runtime/ops_dsp.py +159 -42
  46. tinygrad/runtime/ops_gpu.py +30 -30
  47. tinygrad/runtime/ops_hip.py +29 -31
  48. tinygrad/runtime/ops_llvm.py +48 -41
  49. tinygrad/runtime/ops_metal.py +149 -113
  50. tinygrad/runtime/ops_npy.py +2 -2
  51. tinygrad/runtime/ops_nv.py +238 -273
  52. tinygrad/runtime/ops_python.py +55 -50
  53. tinygrad/runtime/ops_qcom.py +129 -157
  54. tinygrad/runtime/ops_webgpu.py +225 -0
  55. tinygrad/runtime/support/allocator.py +94 -0
  56. tinygrad/runtime/support/am/__init__.py +0 -0
  57. tinygrad/runtime/support/am/amdev.py +396 -0
  58. tinygrad/runtime/support/am/ip.py +463 -0
  59. tinygrad/runtime/support/compiler_cuda.py +4 -2
  60. tinygrad/runtime/support/elf.py +28 -4
  61. tinygrad/runtime/support/hcq.py +256 -324
  62. tinygrad/runtime/support/llvm.py +26 -0
  63. tinygrad/shape/shapetracker.py +85 -53
  64. tinygrad/shape/view.py +104 -140
  65. tinygrad/spec.py +155 -0
  66. tinygrad/tensor.py +835 -527
  67. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
  68. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
  69. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
  70. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
  71. tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
  72. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
  73. tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
  74. tinygrad/viz/index.html +544 -0
  75. tinygrad/viz/perfetto.html +178 -0
  76. tinygrad/viz/serve.py +205 -0
  77. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/METADATA +48 -25
  78. tinygrad-0.10.2.dist-info/RECORD +99 -0
  79. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +1 -1
  80. tinygrad/codegen/uopgraph.py +0 -506
  81. tinygrad/engine/lazy.py +0 -228
  82. tinygrad/function.py +0 -212
  83. tinygrad/multi.py +0 -177
  84. tinygrad/runtime/graph/clang.py +0 -39
  85. tinygrad/runtime/ops_clang.py +0 -35
  86. tinygrad-0.10.0.dist-info/RECORD +0 -77
  87. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
  88. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
tinygrad/engine/lazy.py DELETED
@@ -1,228 +0,0 @@
- from __future__ import annotations
- from typing import Optional, Any, Tuple, List, get_args
- from tinygrad.dtype import dtypes, DType, ConstType, to_dtype, ImageDType
- from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG, _METADATA, Metadata, SPLIT_REDUCEOP, LAZYCACHE
- from tinygrad.ops import exec_alu, python_alu
- from tinygrad.ops import identity_element, MathTrait, resolve, UOp, sint, GroupOp, Ops
- from tinygrad.shape.shapetracker import ShapeTracker
- from tinygrad.device import Buffer
- from weakref import ref, ReferenceType, WeakValueDictionary
-
- lazycache: WeakValueDictionary[Any, LazyBuffer] = WeakValueDictionary()
- def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType, op:Optional[Ops]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
-                       base:Optional[LazyBuffer]=None, enable_cache=bool(LAZYCACHE)):
-   if st.size == 0: op, arg, srcs, base = Ops.CONST, 0, (), None
-   dtype = to_dtype(dtype)
-   if op is Ops.CONST: arg, enable_cache = dtypes.as_const(arg, dtype) if not isinstance(arg, UOp) else arg, True
-
-   cache_key = (device, st, dtype, op, arg, tuple(ref(x) for x in srcs)) if base is None else (st, ref(base))
-   if enable_cache and (rret := lazycache.get(cache_key, None)) is not None: return rret
-
-   ret = LazyBuffer(device, st, dtype, op, arg, srcs, base=base, metadata=_METADATA.get())
-   if enable_cache: lazycache[cache_key] = ret
-   return ret
-
- view_supported_devices = {"LLVM", "CLANG", "CUDA", "NV", "AMD", "METAL", "QCOM", "DSP", "DISK"}
- class LazyBuffer(MathTrait):
-   def __init__(self, device:str, st:ShapeTracker, dtype:DType,
-                op:Optional[Ops]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
-                base:Optional[LazyBuffer]=None, metadata:Optional[Metadata]=None):
-     self.device, self.st, self.dtype, self.shape, self.size, self.metadata = device, st, to_dtype(dtype), st.shape, st.size, metadata
-     self._base: Optional[LazyBuffer] = None
-     if base is None:
-       # properties on base
-       self.op, self.arg, self.srcs = op, arg, srcs  # this is a UOp, except the src is LazyBuffers and not UOps
-       assert self.op is not Ops.ASSIGN or srcs[0].base.realized is not None, "assign target must be realized"
-
-       if self.op is Ops.BUFFER_VIEW:
-         # some LazyBuffers can be processed with only a view, no AST required
-         self.buffer: Buffer = srcs[0].base.buffer.view(st.size, self.dtype, srcs[0].st.views[0].offset * srcs[0].dtype.itemsize)
-       else:
-         self.buffer = srcs[0].base.buffer if self.op is Ops.ASSIGN else Buffer(device, self.size, self.dtype)
-       self.buffer.ref(1)
-       self.contiguous_child: Optional[Tuple[ReferenceType[LazyBuffer], ShapeTracker]] = None
-       self.forced_realize = False
-     else:
-       # properties on view
-       assert base.base == base, "base must be a base itself"
-       self._base = base
-
-   def __del__(self):
-     if hasattr(self, 'buffer'): self.buffer.ref(-1)
-
-   def __repr__(self) -> str:
-     return f"<LB {self.device} {self.shape} {str(self.dtype)[7:]} {self.st if self.base is not self else (self.op, self.realized)}>"
-
-   @property
-   def realized(self) -> Optional[Buffer]:
-     # NOTE: we check for a lack of srcs instead of an allocated buffer to make unrealized assigns return None here
-     return self.buffer if self._base is None and not hasattr(self, 'srcs') else None
-
-   # NOTE: this has to be a function to prevent self reference
-   @property
-   def base(self) -> LazyBuffer: return self._base if self._base is not None else self
-
-   # same API as multi
-   @property
-   def lbs(self) -> List[LazyBuffer]: return [self]
-
-   @staticmethod
-   def metaop(op, shape:Tuple[sint,...], dtype:DType, device:str, arg=None, src:Tuple[LazyBuffer, ...]=(), enable_cache=False) -> LazyBuffer:
-     assert isinstance(src, tuple)
-     return create_lazybuffer(device, ShapeTracker.from_shape(shape), dtype, op, arg, src, enable_cache=enable_cache)
-
-   def const_like(self, b): return self.const_with_shape(b, self.shape)
-   def const_with_shape(self, val:ConstType, shape:Tuple[sint,...]) -> LazyBuffer:
-     assert isinstance(val, get_args(ConstType)), f"{val=} has {type(val)=}, not a ConstType"
-     return LazyBuffer.metaop(Ops.CONST, tuple(), self.dtype, self.device, arg=val).reshape((1,)*len(shape)).expand(shape)
-
-   @property
-   def is_realized(self) -> bool: return self.base.realized is not None
-
-   def assign(self, x:LazyBuffer) -> LazyBuffer:
-     assert x.size == self.size, f"assign target must have same size {self.size=} != {x.size=}"
-     assert self.is_realized, f"assign target must be realized {self}"
-     return LazyBuffer.metaop(Ops.ASSIGN, self.shape, self.dtype, self.device, arg=() if self.st.contiguous else (self.st,),
-                              src=(self.base, x), enable_cache=True)
-
-   def can_view(self):
-     return (self.st.consecutive and not self.is_unrealized_const() and not isinstance(self.dtype, ImageDType) and
-             self.device.split(":")[0] in view_supported_devices)
-
-   def contiguous(self, allow_buffer_view=True):
-     if not self.st.contiguous or self.size != self.base.size or self.is_unrealized_const():
-       ret = self.alu(Ops.BUFFER_VIEW) if allow_buffer_view and self.can_view() else self.alu(Ops.CONTIGUOUS)
-       if (sti := self.st.invert(self.base.shape)) is not None: self.base.contiguous_child = ref(ret), sti
-       return ret
-     self.base.forced_realize = True
-     return self
-
-   def bitcast(self, dtype:DType) -> LazyBuffer: return self.cast(dtype, bitcast=True)
-   def cast(self, dtype:DType, bitcast:bool=False, allow_buffer_view=True) -> LazyBuffer:
-     if self.dtype == dtype: return self
-     if self.device.startswith("DISK") and not bitcast: raise RuntimeError("attempted to cast disk buffer (bitcast only)")
-     if self.is_unrealized_unmasked_const() and not bitcast:
-       return create_lazybuffer(self.device, self.st, dtype, Ops.CONST, dtypes.as_const(self.base.arg, dtype))
-     new_shape = self.shape
-     if bitcast and self.dtype.itemsize != dtype.itemsize:
-       if not self.device.startswith("DISK"): raise RuntimeError("shape changing bitcast only supported on DISK right now")
-       if not all_int(new_shape): raise RuntimeError("shape changing bitcast with symbolic shape isn't supported yet")
-       # https://pytorch.org/docs/stable/generated/torch.Tensor.view.html
-       if not (new_shape[-1]*self.dtype.itemsize) % dtype.itemsize == 0: raise RuntimeError("unsupported size in bitcast")
-       new_shape = new_shape[:-1] + ((new_shape[-1]*self.dtype.itemsize) // dtype.itemsize,)
-     elif getenv("CAST_BEFORE_VIEW", 1) and dtype.itemsize <= self.dtype.itemsize and self is not self.base:
-       # TODO: applying this makes gpt2 slower
-       return self.base.cast(dtype, bitcast)._view(self.st)
-     cast_op: Ops = (Ops.BUFFER_VIEW if self.can_view() and allow_buffer_view else Ops.BITCAST) if bitcast else Ops.CAST
-     return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), dtype, cast_op, dtype, (self,))
-
-   def is_unrealized_const(self): return self.base.realized is None and self.base.op is Ops.CONST and not isinstance(self.base.arg, UOp)
-   def is_unrealized_unmasked_const(self): return self.is_unrealized_const() and all(v.mask is None for v in self.st.views)
-
-   def _copy(self, device:str) -> LazyBuffer:
-     assert self.st.contiguous and self.size == self.base.size, f"can only copy contig {self} {self.base}"
-     return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, Ops.COPY, self.buffer.nbytes, (self,), enable_cache=False)
-
-   def copy_to_device(self, device:str, force:bool=False, clone:bool=False) -> LazyBuffer:
-     # no COPY
-     if self.device == device and not clone: return self
-
-     # double COPY = one COPY
-     if not force and self.st.contiguous and self.size == self.base.size and not self.base.realized and self.base.op is Ops.COPY:
-       return self.base.srcs[0].copy_to_device(device).reshape(self.st.shape)
-
-     # const doesn't have to be copied (issues with disk tensor)
-     if self.is_unrealized_const():
-       return LazyBuffer.metaop(Ops.CONST, tuple(), self.dtype, device, arg=self.base.arg)._view(self.st)
-
-     # if it's a shrink, do the shrink before the copy with CONTIGUOUS
-     if prod(self.st.shape) < prod(self.base.st.shape): return self.contiguous()._copy(device)
-
-     # copy the base and apply the shapetracker on the new device
-     return self.base._copy(device)._view(self.st)
-
-   def clone(self) -> LazyBuffer: return self.copy_to_device(self.device, clone=True)
-
-   def alu(self, op:Ops, *in_srcs:LazyBuffer) -> LazyBuffer:
-     srcs: List[LazyBuffer] = []
-     for s in (self,)+in_srcs:
-       if s == s.base and s.base.contiguous_child and (root:=s.base.contiguous_child[0]()) is not None:
-         srcs.append(root._view(s.base.contiguous_child[1]))
-       else:
-         srcs.append(s)
-     if not all_same(dts:=[x.dtype.base for x in (srcs[1:] if op is Ops.WHERE else srcs)]):
-       raise AssertionError(f"all dtypes must match {dts} on {op}")
-     assert all_same([x.shape for x in srcs]), f"all shapes must be the same {[x.shape for x in srcs]}"
-     if op is Ops.WHERE: assert srcs[0].dtype == dtypes.bool, "Ops.WHERE must have the first arg be bool"
-
-     out_dtype = dtypes.bool if op in (Ops.CMPLT, Ops.CMPNE) else srcs[-1].dtype
-
-     # const folding
-     if op in python_alu and all(s.is_unrealized_unmasked_const() for s in srcs):
-       return self.cast(out_dtype).const_like(exec_alu(op, out_dtype, [s.base.arg for s in srcs]))
-     if op in GroupOp.Binary:
-       x, y = self, in_srcs[0]
-       if op is Ops.ADD:
-         if y.is_unrealized_unmasked_const() and y.base.arg == 0: return x
-         if x.is_unrealized_unmasked_const() and x.base.arg == 0: return y
-       if op is Ops.MUL:
-         if x.is_unrealized_unmasked_const() and (val := x.base.arg) in (1, 0): return y if val == 1 else y.const_like(0)
-         if y.is_unrealized_unmasked_const() and (val := y.base.arg) in (1, 0): return x if val == 1 else x.const_like(0)
-       if op is Ops.IDIV and y.is_unrealized_unmasked_const() and y.base.arg == 1: return x
-
-     return create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), out_dtype, op, None, tuple(srcs))
-
-   # *** reduce ops ***
-
-   def _reduce_op(self, op:Ops, axis:Tuple[int, ...]) -> LazyBuffer:
-     assert all(0 <= x < len(self.shape) for x in axis), f"axis args {axis} out of range for shape {self.shape}"
-     axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)]))
-     if len(axis) == 0: return self
-     return create_lazybuffer(self.device, ShapeTracker.from_shape(self.st.reduce(axis)), self.dtype, Ops.REDUCE_AXIS, (op, axis), (self,))
-
-   def r(self, op:Ops, axis:Tuple[int, ...]) -> LazyBuffer:
-     new_shape = self.st.reduce(axis)
-     # TODO: this logic should move to the scheduler
-     if 0 in self.shape and 0 not in new_shape: return self.const_with_shape(identity_element(op, self.dtype), new_shape)
-
-     # const folding
-     # TODO: fold this for symbolic?
-     if self.is_unrealized_unmasked_const() and all_int(self.shape):
-       if op is Ops.ADD: return self.const_with_shape(self.base.arg * prod(self.shape[i] for i in axis), new_shape)
-       if op is Ops.MUL: return self.const_with_shape(self.base.arg ** prod(self.shape[i] for i in axis), new_shape)
-       if op is Ops.MAX: return self.const_with_shape(self.base.arg, new_shape)
-
-     # TODO: can we split symbolic shape if the reduce axis is not symbolic?
-     if not SPLIT_REDUCEOP or not all_int(self.shape) or (0 in self.shape) or \
-         prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768):
-       return self._reduce_op(op, axis)
-
-     # if there are few globals, make some reduces into globals by splitting into two kernels
-     # cap output buffer to 2**22: heuristic number of global outputs to achieve max occupancy with enough locals+upcasts for gemm
-     # ~2**10 should be enough if GROUP is used
-     # 256 split maximum should be "negligible reduce" for low prod(new_shape), 8 split minimum.
-     # split is moved to the end to provide maximum locality for the second phase reduce.
-     self_real_strides = self.st.real_strides(ignore_valid=True)
-     split_candidates = [(i, x) for i in axis for x in range(min(256,2**getenv("REDUCEOP_SPLIT_SIZE",22)//prod(new_shape)),8-1,-1)
-                         if self.shape[i] % x == 0 and self_real_strides[i] != 0]
-     if not split_candidates: return self._reduce_op(op, axis)
-     dim_to_split, divisor = split_candidates[0]
-     splitted_shape = self.shape[:dim_to_split] + (divisor,) + (self.shape[dim_to_split]//divisor,) + self.shape[dim_to_split+1:]
-     splitted = self.reshape(splitted_shape).permute(tuple([x for x in range(len(splitted_shape)) if x != dim_to_split]+[dim_to_split]))
-     if DEBUG >= 3: print(f"split {divisor}: {self.shape} -> {splitted.shape} -> {new_shape}")
-     return splitted._reduce_op(op, axis)._reduce_op(op, (len(new_shape),)).reshape(new_shape)  # reduce original axes, then split
-
-   # *** movement ops ***
-
-   def _view(self, new_st:ShapeTracker) -> LazyBuffer:
-     if self.st.size == 0 or (new_st.views[-1].mask is not None and any((x[1]-x[0]) == 0 for x in new_st.views[-1].mask)):
-       return self.const_with_shape(0, new_st.shape)
-     if new_st.contiguous and self.base.shape == new_st.shape: return self.base
-     return create_lazybuffer(self.device, new_st, self.dtype, base=self.base)
-
-   def reshape(self, arg:Tuple[sint, ...]): return self._view(self.st.reshape(arg))
-   def pad(self, arg:Tuple[Tuple[sint, sint], ...]): return self._view(self.st.pad(arg))
-   def expand(self, arg:Tuple[sint, ...]): return self._view(self.st.expand(arg))
-   def permute(self, arg:Tuple[int, ...]): return self._view(self.st.permute(arg))
-   def shrink(self, arg:Tuple[Tuple[sint, sint], ...]): return self._view(self.st.shrink(arg))
-   def stride(self, arg:Tuple[int, ...]): return self._view(self.st.stride(arg))
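For context on the caching at the top of this deleted file: create_lazybuffer dedups identical buffers through a WeakValueDictionary, so equal keys return the same live object while anything still references it, and entries vanish on their own once the last strong reference dies. A minimal standalone sketch of that pattern (plain objects, not the tinygrad API; Node and create_node are hypothetical names):

from weakref import WeakValueDictionary

class Node:
  def __init__(self, key): self.key = key

_cache: WeakValueDictionary = WeakValueDictionary()

def create_node(key):
  if (hit := _cache.get(key, None)) is not None: return hit  # cache hit: reuse the live object
  _cache[key] = (ret := Node(key))
  return ret

a = create_node(("CPU", (4, 4)))
b = create_node(("CPU", (4, 4)))
assert a is b                           # deduped while still referenced
del a, b
assert ("CPU", (4, 4)) not in _cache    # weak entry dropped after collection (CPython refcounting frees immediately)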
tinygrad/function.py DELETED
@@ -1,212 +0,0 @@
- """This is where the forwards and backwards passes live."""
- import math
- from typing import Tuple, Optional
- from tinygrad.helpers import argsort
- from tinygrad.dtype import dtypes, DType, sum_acc_dtype
- from tinygrad.ops import Ops, resolve, sint
- from tinygrad.tensor import Function
- from tinygrad.engine.lazy import LazyBuffer
-
- class Contiguous(Function):
-   def forward(self, x:LazyBuffer) -> LazyBuffer: return x.contiguous()
-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output
-
- class ContiguousBackward(Function):
-   def forward(self, x:LazyBuffer) -> LazyBuffer: return x
-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.contiguous()
-
- class Cast(Function):
-   def forward(self, x:LazyBuffer, dtype:DType, bitcast:bool=False) -> LazyBuffer:
-     self.input_dtype, self.bitcast = x.dtype, bitcast
-     return x.bitcast(dtype) if self.bitcast else x.cast(dtype)
-
-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-     if self.bitcast: raise RuntimeError("bitcast cannot backward")
-     return grad_output.cast(self.input_dtype)
-
- # ************* unary ops *************
-
- class Reciprocal(Function):
-   def forward(self, x:LazyBuffer) -> LazyBuffer:
-     self.ret = x.reciprocal()
-     return self.ret
-
-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return -grad_output * self.ret * self.ret
-
- class Sin(Function):
-   def forward(self, x:LazyBuffer) -> LazyBuffer:
-     self.x = x
-     return x.sin()
-
-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return (math.pi/2 - self.x).sin() * grad_output
-
- class Relu(Function):
-   def forward(self, x:LazyBuffer) -> LazyBuffer:
-     self.ret = x.maximum(0)
-     return self.ret
-
-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return self.ret.gt(0).cast(grad_output.dtype) * grad_output
-
- class Log(Function):
-   def forward(self, x:LazyBuffer) -> LazyBuffer:
-     self.x = x
-     return x.log2() * math.log(2)
-
-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output / self.x
-
- class Exp(Function):
-   def forward(self, x:LazyBuffer) -> LazyBuffer:
-     self.ret = (x * (1/math.log(2))).exp2()
-     return self.ret
-
-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return self.ret * grad_output
-
- class Sqrt(Function):
-   def forward(self, x:LazyBuffer) -> LazyBuffer:
-     self.ret = x.sqrt()
-     return self.ret
-
-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output / (self.ret*2)
-
- # NOTE: the implicit derivative of sigmoid is not stable
- # https://towardsdatascience.com/derivative-of-the-sigmoid-function-536880cf918e
- # TODO: have the backend automatically find this
- class Sigmoid(Function):
-   def forward(self, x:LazyBuffer) -> LazyBuffer:
-     self.ret = (1 + (x * (-1/math.log(2))).exp2()).reciprocal()
-     return self.ret
-
-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-     return (self.ret * (1 - self.ret)) * grad_output
-
- class Sign(Function):
-   def forward(self, x:LazyBuffer) -> LazyBuffer: return x.ne(0).where(x.lt(0).where(x.const_like(-1), x.const_like(1)), x.const_like(0))
-   # backward always return 0 to match torch
-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.const_like(0)
-
- # ************* binary ops *************
-
- class Less(Function):
-   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.lt(y)
-   def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: return None, None
-
- class Neq(Function):
-   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.ne(y)
-   def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: return None, None
-
- class Xor(Function):
-   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x^y
-
- class BitwiseAnd(Function):
-   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x&y
-
- class BitwiseOr(Function):
-   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x|y
-
- class Threefry(Function):
-   def forward(self, x:LazyBuffer, seed:LazyBuffer) -> LazyBuffer: return x.threefry(seed)
-
- class Add(Function):
-   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x+y
-
-   def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
-     return grad_output if self.needs_input_grad[0] else None, \
-            grad_output if self.needs_input_grad[1] else None
-
- class Mul(Function):
-   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
-     self.x, self.y = x, y
-     return x * y
-
-   def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
-     return (self.y * grad_output) if self.needs_input_grad[0] else None, \
-            (self.x * grad_output) if self.needs_input_grad[1] else None
-
- class IDiv(Function):
-   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x // y
-
- # ************* ternary ops *************
-
- class Where(Function):
-   def forward(self, x:LazyBuffer, y:LazyBuffer, z:LazyBuffer) -> LazyBuffer:
-     self.x = x
-     return self.x.where(y, z)
-
-   def backward(self, grad_output:LazyBuffer) -> Tuple[None, Optional[LazyBuffer], Optional[LazyBuffer]]:
-     return None, \
-            self.x.where(grad_output, grad_output.const_like(0)) if self.needs_input_grad[1] else None, \
-            self.x.where(grad_output.const_like(0), grad_output) if self.needs_input_grad[2] else None
-
- # ************* reduce ops *************
-
- class Sum(Function):
-   def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
-     self.input_shape = x.shape
-     return x.r(Ops.ADD, axis)
-
-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.expand(self.input_shape)
-
- class Prod(Function):
-   def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
-     self.x, self.ret = x, x.r(Ops.MUL, axis)
-     return self.ret
-
-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-     return (grad_output * self.ret).expand(self.x.shape) / self.x
-
- class Max(Function):
-   def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
-     self.x, self.ret, self.axis = x, x.r(Ops.MAX, axis), axis
-     return self.ret
-
-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-     # 1s in locations where the max was chosen (can be two locations)
-     max_is_1s = self.x.ne(self.ret.expand(self.x.shape)).ne(self.x.const_like(1).cast(dtypes.bool)).cast(grad_output.dtype)
-     div = max_is_1s.r(Ops.ADD, self.axis).expand(self.x.shape)
-     return (max_is_1s/div) * grad_output.expand(self.x.shape)
-
- # ************* movement ops *************
-
- # NOTE: this is sum in reverse
- class Expand(Function):
-   def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer:
-     self.expanded_axis = tuple(i for i, (si, so) in enumerate(zip(x.shape, shape)) if resolve(si != so))
-     return x.expand(shape)
-
-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-     return grad_output.cast(sum_acc_dtype(grad_output.dtype)).r(Ops.ADD, self.expanded_axis).cast(grad_output.dtype)
-
- class Reshape(Function):
-   def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer:
-     self.input_shape = x.shape
-     return x.reshape(shape)
-
-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.reshape(self.input_shape)
-
- class Permute(Function):
-   def forward(self, x:LazyBuffer, order:Tuple[int, ...]) -> LazyBuffer:
-     self.input_order = order
-     return x.permute(order)
-
-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.permute(argsort(self.input_order))
-
- class Pad(Function):
-   def forward(self, x:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
-     self.narg = tuple([(p[0], s+p[0]) for s,p in zip(x.shape, arg)])
-     return x.pad(arg)
-
-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.shrink(self.narg)
-
- class Shrink(Function):
-   def forward(self, x:LazyBuffer, arg:Tuple[Tuple[sint, sint], ...]) -> LazyBuffer:
-     self.narg = tuple([(p[0], s-p[1]) for s,p in zip(x.shape, arg)])
-     return x.shrink(arg)
-
-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.pad(self.narg)
-
- class Flip(Function):
-   def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
-     self.arg = tuple([-1 if i in axis else 1 for i in range(len(x.shape))])
-     return x.stride(self.arg)
-
-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.stride(self.arg)
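Every class in this deleted file follows one pattern: forward() stashes whatever the derivative will need on self, and backward() maps the output gradient to one gradient per input (None where no gradient is required). A minimal sketch of that pattern with plain floats (not the tinygrad Function base class; the stripped-down Function here is hypothetical), using Mul because its saved inputs make the pattern explicit:

class Function:
  def __init__(self): self.needs_input_grad = (True, True)

class Mul(Function):
  def forward(self, x: float, y: float) -> float:
    self.x, self.y = x, y  # save the inputs the backward pass needs
    return x * y
  def backward(self, grad_output: float):
    # d(x*y)/dx = y and d(x*y)/dy = x, each scaled by the incoming gradient
    return (self.y * grad_output if self.needs_input_grad[0] else None,
            self.x * grad_output if self.needs_input_grad[1] else None)

f = Mul()
out = f.forward(3.0, 4.0)   # 12.0
dx, dy = f.backward(1.0)    # (4.0, 3.0)
assert (out, dx, dy) == (12.0, 4.0, 3.0)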
tinygrad/multi.py DELETED
@@ -1,177 +0,0 @@
- from __future__ import annotations
- from typing import Optional, Tuple, List, Dict
- import functools, itertools, operator
- from tinygrad.helpers import all_same, all_int, dedup, prod, DEBUG, RING, getenv
- from tinygrad.dtype import DType
- from tinygrad.ops import Ops, MathTrait
- from tinygrad.engine.lazy import LazyBuffer
- from tinygrad.shape.shapetracker import sint
-
- def all_reduce(bop: Ops, lbs: List[LazyBuffer]) -> List[LazyBuffer]:
-   assert all_int(lbs[0].shape), f"does not support symbolic shape {lbs[0].shape}"
-   assert all_same([lb.shape[0] for lb in lbs]), "allreduce with uneven shards is undefined"
-   n_lbs, shape, numel = len(lbs), lbs[0].shape, prod(lbs[0].shape)
-   # ring allreduce doesn't provide a benefit with only 2 nodes or where number of elements is less than 256k (empirically)
-   # fallback to naive allreduce to save on kernel dispatch, chunking and reassembling chunks.
-   use_ring = (RING >= 2 or (n_lbs > 2 and numel > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and RING >= 1))
-   if DEBUG >= 2: print(f"{'RING ALLREDUCE' if use_ring else 'NAIVE ALLREDUCE'} {n_lbs}x{numel} | {lbs[0].dtype}")
-   if not use_ring: return [functools.reduce(lambda x,y: x.alu(bop, y), [x.copy_to_device(lb.device) for x in lbs]) for lb in lbs]
-
-   factor = next(f for f in [32, 16, 8, 4, 2, 1] if numel % f == 0)
-   base, left = (numel // factor) // n_lbs, (numel // factor) % n_lbs
-   chunk_sizes = [(base + 1) * factor] * left + [base * factor] * (n_lbs - left)
-   acc = 0
-   chunks = [(acc, (acc := acc + i)) for i in chunk_sizes if i > 0]
-   chunked = [[lb.reshape((numel,)).shrink(((s,e),)) for s,e in chunks] for lb in lbs]
-
-   # scatter-reduce
-   for step in range(n_lbs-1):
-     for i in range(len(chunks)):
-       src, dest = (i+step)%n_lbs, (i+step+1)%n_lbs
-       chunked[dest][i] = chunked[dest][i].alu(bop, chunked[src][i].copy_to_device(chunked[dest][i].device, force=True))
-
-   # allgather
-   for step in range(n_lbs-1):
-     for i in range(len(chunks)):
-       src, dest = (i+step-1)%n_lbs, (i+step)%n_lbs
-       chunked[dest][i] = chunked[src][i].copy_to_device(chunked[dest][i].device, force=True)
-
-   # assemble chunks back
-   pads = [((s,numel-e),) for s,e in chunks]
-   return [functools.reduce(operator.add, [c.pad(pad) for pad,c in zip(pads,lb_c)]).reshape(shape) for lb_c in chunked]
-
- def to_sharded(lbs:List[LazyBuffer], axis:int, bounds: Tuple[Tuple[int, int], ...]) -> List[LazyBuffer]:
-   if DEBUG >= 3 and lbs[0].shape[axis] % len(lbs) != 0: print(f"multi axis uneven: {lbs[0].shape=} {axis=} {len(lbs)=}, bounds={bounds}")
-   return [lb.shrink(tuple((0,s) if a != axis else bound for a,s in enumerate(lb.shape))) for i, (bound, lb) in enumerate(zip(bounds, lbs))]
-
- class MultiLazyBuffer(MathTrait):
-   def __init__(self, lbs:List[LazyBuffer], axis:Optional[int], real:Optional[List[bool]]=None):
-     assert all(isinstance(x, LazyBuffer) for x in lbs) and len(lbs), "all lbs must be LazyBuffers, and we need at least one of them"
-     assert all_same([x.dtype for x in lbs]), f"all multilazybuffer needs same dtype, getting {[x.dtype for x in lbs]}"
-     self.lbs, self.axis, self.dtype, self.device, self.real = lbs, axis, lbs[0].dtype, tuple(x.device for x in lbs), real or [True]*len(lbs)
-     if axis is not None:
-       splits = list(itertools.accumulate([lb.shape[axis] for lb in lbs], initial=0))
-       self.bounds = tuple(zip(splits, splits[1:]))
-
-   @property
-   def shape(self): return tuple(sum(y.shape[a] for y in self.real_lbs) if a == self.axis else s for a,s in enumerate(self.real_lbs[0].shape))
-
-   @property
-   def size(self): return sum(x.size for x in self.real_lbs)
-
-   @property
-   def real_lbs(self): return [lb for lb,r in zip(self.lbs, self.real) if r]
-
-   def __repr__(self): return f"<MLB {self.axis=} {self.real=} {chr(10)}{chr(10).join([f'{x.device} {x.st}' for x in self.lbs])}>"
-
-   @staticmethod
-   def from_sharded(lb:LazyBuffer, devices:Tuple[str, ...], axis:Optional[int], bounds:Optional[Tuple[Tuple[int, int], ...]]):
-     assert (axis is None) == (bounds is None), "must specify bounds iff axis is specified"
-     lbs = [lb] * len(devices)
-     sharded_lbs = [lb.copy_to_device(d) for lb,d in zip(to_sharded(lbs, axis, bounds) if axis is not None and bounds is not None else lbs, devices)]
-     return MultiLazyBuffer([lb if lb.is_unrealized_unmasked_const() else lb.contiguous(allow_buffer_view=False) for lb in sharded_lbs], axis)
-
-   def copy_to_device(self, device:str) -> LazyBuffer:
-     if self.axis is None:
-       # if we already have a copy on the device, return that
-       return next((lb for lb in self.real_lbs if lb.device == device), self.real_lbs[0].copy_to_device(device))
-     # copy lbs to device, pad to final shape, and sum
-     llbs:List[LazyBuffer] = []
-     for lb,real,(start,end) in zip(self.lbs, self.real, self.bounds):
-       if not real: continue
-       pad_arg = tuple((0,0) if a != self.axis else (start, self.bounds[-1][1]-end) for a in range(len(lb.shape)))
-       llbs.append(lb.copy_to_device(device).pad(pad_arg))
-     return functools.reduce(operator.add, llbs)
-
-   # passthroughs
-   @property
-   def is_realized(self) -> bool: return all(lb.base.realized is not None for lb in self.real_lbs)
-   def cast(self, dtype:DType, bitcast:bool=False, allow_buffer_view=True):
-     return MultiLazyBuffer([x.cast(dtype, bitcast, allow_buffer_view) for x in self.lbs], self.axis, self.real)
-   def const_like(self, b) -> MultiLazyBuffer: return MultiLazyBuffer([x.const_like(b) for x in self.lbs], self.axis, self.real)
-   def assign(self, x:MultiLazyBuffer): return MultiLazyBuffer([s.assign(d) for s,d in zip(self.lbs, x.lbs)], self.axis, self.real)
-   def contiguous(self): return MultiLazyBuffer([x.contiguous() for x in self.lbs], self.axis, self.real)
-   def clone(self) -> MultiLazyBuffer: return MultiLazyBuffer([lb.clone() for lb in self.lbs], self.axis, self.real)
-
-   # elementwise is simple
-   def alu(self, op:Ops, *in_srcs:MultiLazyBuffer) -> MultiLazyBuffer:
-     msrcs = (self,)+in_srcs
-     assert all(isinstance(x, MultiLazyBuffer) for x in msrcs), f"all buffers must be MultiLazyBuffer {msrcs}"
-     assert all_same([x.device for x in msrcs]), f"all buffers must have the same device {[x.device for x in msrcs]}"
-
-     # NOTE: they all have to share an axis, we always choose [-1]
-     axis, bounds = axes[-1] if len(axes := dedup([(x.axis, x.bounds) for x in msrcs if x.axis is not None])) else (None, None)
-     srcs:List[List[LazyBuffer]] = []
-     not_all_real = not all(all(mlb.real) for mlb in msrcs)
-     new_real = [all(transposed) for transposed in zip(*[mlb.real for mlb in msrcs])] if not_all_real else self.real
-     assert any(new_real), "output contains no real lb"
-     for mlb in msrcs:
-       if (mlb.axis == axis and (mlb.axis is None or mlb.bounds == bounds)) or not_all_real: srcs.append(mlb.lbs)
-       elif mlb.axis is None and axis is not None: srcs.append(to_sharded(mlb.lbs, axis, bounds))
-       else: srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.lbs], axis, bounds))
-     new_real_lbs:Dict[int,LazyBuffer] = {i:lsrcs[0].alu(op, *lsrcs[1:]) for i,(lsrcs,r) in enumerate(zip(zip(*srcs), new_real)) if r}
-     # NOTE: const dtype should match real
-     new_dtype = next(iter(new_real_lbs.values())).dtype
-     return MultiLazyBuffer([new_real_lbs.get(i, lsrcs[0].const_like(0).cast(new_dtype)) for i,lsrcs in enumerate(zip(*srcs))], axis, new_real)
-
-   def r(self, op:Ops, axis:Tuple[int, ...]) -> MultiLazyBuffer:
-     if self.axis is not None and self.axis in axis:
-       # all-reduce on sharded axes
-       reduced_parts = [(x if r else x.const_like(0)).r(op, axis) for x,r in zip(self.lbs, self.real)]
-       # if all partitions are real, do all_reduce
-       if all(self.real): return MultiLazyBuffer(all_reduce(op, reduced_parts), None)
-       # only one partition is real, keep it
-       return MultiLazyBuffer(reduced_parts, None, self.real)
-     # reduce on non sharded axes, piecewise is fine. if axis is None this is also correct
-     return MultiLazyBuffer([x.r(op, axis) for x in self.lbs], self.axis, self.real)
-
-   # *** movement ops ***
-
-   def _shape_to_single_shard(self, shape:Tuple[sint, ...], lb:LazyBuffer) -> Tuple[sint, ...]:
-     return tuple(lb.shape[self.axis] if a == self.axis else s for a,s in enumerate(shape))
-
-   def reshape(self, arg:Tuple[sint, ...]):
-     if self.axis is None: return MultiLazyBuffer([x.reshape(arg) for x in self.lbs], None, self.real)
-     assert prod(self.shape) == prod(arg), "reshape must maintain prod(shape)"
-     arg_acc:List[sint] = list(itertools.accumulate(arg, operator.mul, initial=1))
-     # new_axis is the last one that preserves prod(prior to new_axis) and must not move items between shards
-     # todo: what to do about shrinking to self.shape[self.axis]==1 len(self.real_lbs)==1?
-     new_axis = len(arg_acc) - arg_acc[::-1].index(prod(self.shape[:self.axis])) - 1
-     assert all(prod(lb.shape[self.axis:])%prod(arg[new_axis+1:])==0 for lb in self.lbs), f"reshape cannot move items between shards {self=} {arg=}"
-     lbs = [x.reshape(tuple(s if a!=new_axis else prod(x.shape[self.axis:])//prod(arg[new_axis+1:]) for a,s in enumerate(arg))) for x in self.lbs]
-     return MultiLazyBuffer(lbs, new_axis, self.real)
-
-   def pad(self, arg:Tuple[Tuple[sint, sint], ...]):
-     assert self.axis is None or arg[self.axis] == (0,0) or not all(self.real), f"padding not supported for {arg=}"
-     # pad on shard axis -> fill others with zeros and set real to all True
-     if self.axis is not None and arg[self.axis] != (0,0):
-       # pad back to whole axis, remove real mask
-       assert all(arg[i] == (0, 0) for i in range(len(self.shape)) if i != self.axis), "cannot pad sharded and non-sharded axis at the same time"
-       dim, bound = sum(lb.shape[self.axis] for lb in self.lbs), self.bounds[self.real.index(True)]
-       assert arg[self.axis] == (bound[0], dim-bound[1]), "can only pad to whole axis"
-       return MultiLazyBuffer([x if r else x.const_like(0) for x,r in zip(self.lbs, self.real)], self.axis)
-     return MultiLazyBuffer([x.pad(arg) for x in self.lbs], self.axis, self.real)
-
-   def expand(self, arg:Tuple[sint, ...]):
-     # NOTE: this assert isn't needed, sharded axis can have dim 1
-     assert self.axis is None or arg[self.axis] == self.shape[self.axis], f"expand not supported on sharded axis {arg=}"
-     return MultiLazyBuffer([x.expand(self._shape_to_single_shard(arg, x)) for x in self.lbs], self.axis, self.real)
-
-   def permute(self, arg:Tuple[int, ...]):
-     # all permutes supported!
-     return MultiLazyBuffer([x.permute(arg) for x in self.lbs], arg.index(self.axis) if self.axis is not None else None, self.real)
-
-   def shrink(self, arg:Tuple[Tuple[sint, sint], ...]):
-     assert self.axis is None or arg[self.axis] == (0, self.shape[self.axis]) or arg[self.axis] in self.bounds, f"shrinking not supported for {arg=}"
-     if self.axis is not None and arg[self.axis] in self.bounds and arg[self.axis] != (0, self.shape[self.axis]):
-       assert all(arg[i] == (0, s) or i == self.axis for i,s in enumerate(self.shape)), "cannot shrink sharded and non-sharded axis at the same time"
-       # NOTE: shrink on the shard axis is only allowed when result is a single partition, denoted by the new real
-       idx = self.bounds.index(arg[self.axis])
-       # zero out other lbs to not create lb reference
-       return MultiLazyBuffer([lb if i==idx else lb.const_like(0) for i,lb in enumerate(self.lbs)], self.axis, [i==idx for i in range(len(self.lbs))])
-     return MultiLazyBuffer([x.shrink(tuple((0, x.shape[self.axis]) if a == self.axis else s for a,s in enumerate(arg))) for x in self.lbs],
-                            self.axis, self.real)
-
-   def stride(self, arg:Tuple[int, ...]):
-     assert self.axis is None or arg[self.axis] == 1, "flipping not supported on sharded axis"
-     return MultiLazyBuffer([x.stride(arg) for x in self.lbs], self.axis, self.real)
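The chunking arithmetic at the top of all_reduce is easy to check in isolation: numel elements are cut into n_lbs contiguous spans whose sizes are multiples of factor and differ by at most one factor, so the ring steps stay balanced. A standalone sketch with plain ints (no LazyBuffers; ring_chunks is a hypothetical name wrapping the same expressions):

def ring_chunks(numel: int, n_lbs: int) -> list:
  # same arithmetic as all_reduce above, returning (start, end) spans
  factor = next(f for f in [32, 16, 8, 4, 2, 1] if numel % f == 0)
  base, left = (numel // factor) // n_lbs, (numel // factor) % n_lbs
  chunk_sizes = [(base + 1) * factor] * left + [base * factor] * (n_lbs - left)
  acc = 0
  return [(acc, (acc := acc + i)) for i in chunk_sizes if i > 0]

# 1000 elements over 4 devices: factor=8, 125 groups of 8 -> sizes 256, 248, 248, 248
assert ring_chunks(1000, 4) == [(0, 256), (256, 504), (504, 752), (752, 1000)]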
tinygrad/runtime/graph/clang.py DELETED
@@ -1,39 +0,0 @@
- from typing import List, Dict, cast
- import ctypes
- from tinygrad.helpers import dedup, cpu_time_execution, DEBUG
- from tinygrad.engine.jit import GraphRunner, GraphException
- from tinygrad.device import Buffer, Device
- from tinygrad.engine.realize import ExecItem, CompiledRunner
- from tinygrad.ops import Variable
- from tinygrad.runtime.ops_clang import ClangProgram
- from tinygrad.renderer.cstyle import ClangRenderer
- render_dtype = ClangRenderer().render_dtype
-
- class ClangGraph(GraphRunner):
-   def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
-     super().__init__(jit_cache, input_rawbuffers, var_vals)
-     if not all(isinstance(ji.prg, CompiledRunner) for ji in jit_cache): raise GraphException
-
-     prgs = '\n'.join(dedup([cast(CompiledRunner, ji.prg).p.src for ji in jit_cache]))
-     args = [f"{render_dtype(x.dtype)}* arg{i}" for i,x in enumerate(input_rawbuffers)]
-     args += sorted([f"int {v.expr}" for v in var_vals])
-     code = ["void batched("+','.join(args)+") {"]
-     for ji in jit_cache:
-       args = []
-       for buf in ji.bufs:
-         assert buf is not None
-         if buf in input_rawbuffers:
-           args.append(f"arg{input_rawbuffers.index(buf)}")
-         else:
-           args.append(f"({render_dtype(buf.dtype)}*)0x{ctypes.addressof(buf._buf):X}")
-       args += [x.expr for x in cast(CompiledRunner, ji.prg).p.vars]
-       code.append(f"  {cast(CompiledRunner, ji.prg).p.function_name}({','.join(args)});")
-     code.append("}")
-     if DEBUG >= 4: print("\n".join(code))
-     compiler = Device["CLANG"].compiler
-     assert compiler is not None
-     self.clprg = ClangProgram("batched", compiler.compile(prgs+"\n"+"\n".join(code)))  # no point in caching the pointers
-
-   def __call__(self, rawbufs: List[Buffer], var_vals: Dict[Variable, int], wait=False):
-     return cpu_time_execution(
-       lambda: self.clprg(*[x._buf for x in rawbufs], *[x[1] for x in sorted(var_vals.items(), key=lambda x: x[0].expr)]), enable=wait)
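What the constructor above emits is ordinary C: every kernel's source concatenated, followed by one batched() entry point that replays the whole JIT cache as direct calls, with non-input buffer addresses baked in as pointer literals, so a graph launch costs a single FFI crossing instead of one per kernel. A hedged sketch of just the generation step (the kernel names and the hex address are made up for illustration):

# build one batched() C function whose body calls each cached kernel in order,
# mirroring the string assembly in ClangGraph.__init__ above
calls = [("E_4", ["arg0", "(float*)0x7F3A80000000"]),
         ("r_16", ["(float*)0x7F3A80000000", "arg0"])]
code = ["void batched(float* arg0) {"]
for name, args in calls:
  code.append(f"  {name}({','.join(args)});")
code.append("}")
print("\n".join(code))
# void batched(float* arg0) {
#   E_4(arg0,(float*)0x7F3A80000000);
#   r_16((float*)0x7F3A80000000,arg0);
# }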