tinygrad 0.9.1__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +11 -6
- tinygrad/codegen/kernel.py +308 -175
- tinygrad/codegen/linearize.py +95 -0
- tinygrad/codegen/lowerer.py +143 -0
- tinygrad/codegen/transcendental.py +257 -0
- tinygrad/codegen/uopgraph.py +506 -0
- tinygrad/device.py +72 -171
- tinygrad/dtype.py +122 -47
- tinygrad/engine/jit.py +184 -87
- tinygrad/{lazy.py → engine/lazy.py} +74 -66
- tinygrad/engine/memory.py +51 -0
- tinygrad/engine/realize.py +86 -61
- tinygrad/engine/schedule.py +366 -317
- tinygrad/engine/search.py +58 -47
- tinygrad/function.py +59 -58
- tinygrad/helpers.py +120 -102
- tinygrad/multi.py +82 -78
- tinygrad/nn/__init__.py +116 -67
- tinygrad/nn/datasets.py +12 -5
- tinygrad/nn/optim.py +1 -1
- tinygrad/nn/state.py +91 -6
- tinygrad/ops.py +1126 -143
- tinygrad/renderer/__init__.py +47 -23
- tinygrad/renderer/cstyle.py +338 -265
- tinygrad/renderer/llvmir.py +125 -143
- tinygrad/renderer/ptx.py +225 -0
- tinygrad/runtime/autogen/adreno.py +17904 -0
- tinygrad/runtime/autogen/amd_gpu.py +46974 -11993
- tinygrad/runtime/autogen/cuda.py +6 -162
- tinygrad/runtime/autogen/io_uring.py +97 -63
- tinygrad/runtime/autogen/kfd.py +60 -47
- tinygrad/runtime/autogen/kgsl.py +1386 -0
- tinygrad/runtime/autogen/libc.py +5462 -0
- tinygrad/runtime/autogen/nv_gpu.py +1976 -1957
- tinygrad/runtime/autogen/nvrtc.py +579 -0
- tinygrad/runtime/autogen/opencl.py +11 -11
- tinygrad/runtime/autogen/qcom_dsp.py +1739 -0
- tinygrad/runtime/graph/clang.py +3 -3
- tinygrad/runtime/graph/cuda.py +11 -15
- tinygrad/runtime/graph/hcq.py +120 -107
- tinygrad/runtime/graph/metal.py +71 -43
- tinygrad/runtime/ops_amd.py +244 -323
- tinygrad/runtime/ops_clang.py +12 -5
- tinygrad/runtime/ops_cloud.py +220 -0
- tinygrad/runtime/ops_cuda.py +42 -99
- tinygrad/runtime/ops_disk.py +25 -26
- tinygrad/runtime/ops_dsp.py +181 -0
- tinygrad/runtime/ops_gpu.py +29 -16
- tinygrad/runtime/ops_hip.py +68 -0
- tinygrad/runtime/ops_llvm.py +15 -10
- tinygrad/runtime/ops_metal.py +147 -64
- tinygrad/runtime/ops_nv.py +356 -397
- tinygrad/runtime/ops_python.py +78 -79
- tinygrad/runtime/ops_qcom.py +405 -0
- tinygrad/runtime/support/__init__.py +0 -0
- tinygrad/runtime/support/compiler_cuda.py +77 -0
- tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +13 -1
- tinygrad/runtime/support/elf.py +38 -0
- tinygrad/runtime/support/hcq.py +539 -0
- tinygrad/shape/shapetracker.py +40 -50
- tinygrad/shape/view.py +102 -63
- tinygrad/tensor.py +1109 -365
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/METADATA +54 -50
- tinygrad-0.10.0.dist-info/RECORD +77 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/linearizer.py +0 -528
- tinygrad/codegen/uops.py +0 -451
- tinygrad/engine/graph.py +0 -100
- tinygrad/renderer/assembly.py +0 -269
- tinygrad/shape/symbolic.py +0 -327
- tinygrad-0.9.1.dist-info/RECORD +0 -63
- /tinygrad/{runtime/driver/__init__.py → py.typed} +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/LICENSE +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,539 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
from typing import List, Optional, Dict, Tuple, Any, cast, Protocol, Type, Union
|
3
|
+
import contextlib, decimal, statistics, random, json, atexit, time, array, ctypes
|
4
|
+
from tinygrad.helpers import PROFILEPATH, PROFILE, from_mv, getenv
|
5
|
+
from tinygrad.renderer import Renderer
|
6
|
+
from tinygrad.device import BufferOptions, Allocator, Compiler, Compiled, LRUAllocator
|
7
|
+
|
8
|
+
# **************** for HCQ Compatible Devices ****************
|
9
|
+
|
10
|
+
def hcq_command(func):
  """
  Decorator for HWCommandQueue commands. Enables command indexing and stores metadata for command updates.

  Before calling `func`, the wrapper records the current length of `self.q` in `self.cmds_offset`. Afterwards it records how many
  entries `func` appended in `self.cmds_len` and the command's name in `self.cmds_meta`. The wrapper returns `self`, so commands can be
  chained. `functools.wraps` keeps the command's name and docstring on the wrapper, so introspection and generated docs show the real
  command rather than the wrapper.

  For example:
  ```python
  @hcq_command
  def command_method(self, ...): ...
  ```
  """
  import functools

  @functools.wraps(func)
  def __wrapper(self, *args, **kwargs):
    self.cmds_offset.append(len(self.q))
    func(self, *args, **kwargs)
    self.cmds_len.append(len(self.q) - self.cmds_offset[-1])
    self.cmds_meta.append(func.__name__)
    return self
  return __wrapper
|
27
|
+
|
28
|
+
class HWCommandQueue:
  """
  Base class for hardware command queues in the HCQ (Hardware Command Queue) API; compute and copy queues both implement these commands.

  Commands are appended to `self.q`. Every `@hcq_command` method also records three things per command:
  - its first index into `self.q` (`cmds_offset`);
  - its number of entries (`cmds_len`);
  - its method name (`cmds_meta`).
  The `update_*` methods use these to locate and patch a command after it was enqueued. Public methods return `self` for chaining.
  """

  def __init__(self):
    self.q, self.binded_device = [], None
    self.cmds_offset, self.cmds_len, self.cmds_meta = [], [], []

  def __len__(self):
    # Number of commands enqueued so far (not the number of entries in self.q).
    return len(self.cmds_offset)

  def _patch(self, cmd_idx, offset, data):
    # Overwrite len(data) entries of command `cmd_idx`, starting `offset` entries after its first entry, with `data` as unsigned ints.
    words = array.array('I', data)
    start = self.cmds_offset[cmd_idx] + offset
    self.q[start:start + len(data)] = words

  def _cur_cmd_idx(self) -> int:
    """
    Index of the command that is currently being enqueued.

    Only meaningful inside a method decorated with `@hcq_command`. The decorator has already registered the command before the method
    body runs, so it is the last one.
    """
    return len(self) - 1

  @hcq_command
  def signal(self, signal:HCQSignal, value:int):
    """
    Enqueues a command that sets `signal` to `value` once all previously enqueued operations have completed.

    Args:
      signal: Signal to set.
      value: Value written to the signal.
    """
    self._signal(signal, value)
  def _signal(self, signal:HCQSignal, value:int): raise NotImplementedError("backend should overload this function")

  @hcq_command
  def wait(self, signal:HCQSignal, value:int):
    """
    Enqueues a command that stalls the queue until `signal` is greater than or equal to `value`.

    Args:
      signal: Signal to wait on.
      value: Value the signal must reach.
    """
    self._wait(signal, value)
  def _wait(self, signal:HCQSignal, value:int): raise NotImplementedError("backend should overload this function")

  @hcq_command
  def timestamp(self, signal:HCQSignal):
    """
    Enqueues a command that stores the current time in `signal` once all previously enqueued commands have completed.

    Args:
      signal: Signal that receives the timestamp.
    """
    self._timestamp(signal)
  def _timestamp(self, signal:HCQSignal): raise NotImplementedError("backend should overload this function")

  def update_signal(self, cmd_idx:int, signal:Optional[Any]=None, value:Optional[int]=None):
    """
    Patches an already enqueued `signal` command.

    Args:
      cmd_idx: Index of the signal command to patch.
      signal: Replacement signal, or None to keep the current one.
      value: Replacement value, or None to keep the current one.

    Raises:
      RuntimeError: If command `cmd_idx` is not a signal command.
    """
    if self.cmds_meta[cmd_idx] == "signal":
      self._update_signal(cmd_idx, signal, value)
      return self
    raise RuntimeError("called update_signal not on a signal command")
  def _update_signal(self, cmd_idx:int, signal:Optional[Any], value:Optional[int]): raise NotImplementedError("backend should overload this function")

  def update_wait(self, cmd_idx:int, signal:Optional[Any]=None, value:Optional[int]=None):
    """
    Patches an already enqueued `wait` command.

    Args:
      cmd_idx: Index of the wait command to patch.
      signal: Replacement signal to wait on, or None to keep the current one.
      value: Replacement value to wait for, or None to keep the current one.

    Raises:
      RuntimeError: If command `cmd_idx` is not a wait command.
    """
    if self.cmds_meta[cmd_idx] == "wait":
      self._update_wait(cmd_idx, signal, value)
      return self
    raise RuntimeError("called update_wait not on a wait command")
  def _update_wait(self, cmd_idx:int, signal:Optional[Any], value:Optional[int]): raise NotImplementedError("backend should overload this function")

  def bind(self, device:HCQCompiled):
    """
    Optionally associates the queue with `device` so a backend can prepare it for efficient execution there.

    The base implementation does nothing. A backend may override it to tailor the queue to the device, for example to avoid copying the
    queue into the device on every submit.

    Args:
      device: Device the queue will be used on.
    """

  def submit(self, device:HCQCompiled):
    """
    Submits the recorded commands to `device`; an empty queue is not submitted.

    Args:
      device: Device to submit the queue to.
    """
    if not self.q: return self
    self._submit(device)
    return self
  def _submit(self, device:HCQCompiled): raise NotImplementedError("backend should overload this function")
|
131
|
+
|
132
|
+
class HWComputeQueue(HWCommandQueue):
  """Command queue for compute work: adds memory barriers and kernel launches to the common HWCommandQueue commands."""

  @hcq_command
  def memory_barrier(self):
    """
    Enqueues a memory-barrier command that keeps memory coherent between agents.
    """
    self._memory_barrier()

  def _memory_barrier(self):
    # No-op by default; presumably overridden by backends that need an explicit barrier.
    pass

  @hcq_command
  def exec(self, prg:HCQProgram, args_state:HCQArgsState, global_size:Tuple[int,int,int], local_size:Tuple[int,int,int]):
    """
    Enqueues a kernel launch.

    Args:
      prg: Program to run.
      args_state: Kernel-argument state the program is launched with.
      global_size: Global work size.
      local_size: Local work size.
    """
    self._exec(prg, args_state, global_size, local_size)
  def _exec(self, prg, args_state, global_size, local_size): raise NotImplementedError("backend should overload this function")

  def update_exec(self, cmd_idx:int, global_size:Optional[Tuple[int,int,int]]=None, local_size:Optional[Tuple[int,int,int]]=None):
    """
    Patches an already enqueued kernel launch.

    Args:
      cmd_idx: Index of the exec command to patch.
      global_size: Replacement global work size, or None to keep the current one.
      local_size: Replacement local work size, or None to keep the current one.

    Raises:
      RuntimeError: If command `cmd_idx` is not an exec command.
    """
    if self.cmds_meta[cmd_idx] == "exec":
      self._update_exec(cmd_idx, global_size, local_size)
      return self
    raise RuntimeError("called update_exec not on an exec command")
  def _update_exec(self, cmd_idx, global_size, local_size): raise NotImplementedError("backend should overload this function")
|
168
|
+
|
169
|
+
class HWCopyQueue(HWCommandQueue):
  """Command queue for data transfers: adds copy commands to the common HWCommandQueue commands."""

  @hcq_command
  def copy(self, dest:HCQBuffer, src:HCQBuffer, copy_size:int):
    """
    Enqueues a copy command to transfer data.

    Args:
      dest: The destination of the copy.
      src: The source of the copy.
      copy_size: The size of data to copy.
    """
    self._copy(dest, src, copy_size)
  def _copy(self, dest:HCQBuffer, src:HCQBuffer, copy_size:int): raise NotImplementedError("backend should overload this function")

  def update_copy(self, cmd_idx:int, dest:Optional[HCQBuffer]=None, src:Optional[HCQBuffer]=None):
    """
    Updates a previously queued copy command.

    Args:
      cmd_idx: Index of the copy command to update.
      dest: New destination of the copy (if None, keeps the original).
      src: New source of the copy (if None, keeps the original).

    Raises:
      RuntimeError: If command `cmd_idx` is not a copy command.
    """
    if self.cmds_meta[cmd_idx] != "copy": raise RuntimeError("called update_copy not on a copy command")
    self._update_copy(cmd_idx, dest, src)
    return self
  def _update_copy(self, cmd_idx, dest, src): raise NotImplementedError("backend should overload this function")
|
196
|
+
|
197
|
+
class HCQSignal:
  """
  Base class for HCQ signals: an integer value that queues can set or wait on, plus a timestamp field.

  Backends implement `_get_value`, `_set_value` and `_get_timestamp`. The `is_timeline` flag is accepted but not used by this base class.
  """
  def __init__(self, value:int=0, is_timeline:bool=False): self._set_value(value)

  @property
  def value(self) -> int: return self._get_value()

  @value.setter
  def value(self, new_value:int): self._set_value(new_value)

  def _get_value(self) -> int: raise NotImplementedError("_get_value() method must be implemented")
  def _set_value(self, new_value:int): raise NotImplementedError("_set_value() method must be implemented")

  @property
  def timestamp(self) -> decimal.Decimal:
    """
    Get the timestamp field of the signal.

    This property provides read-only access to the signal's timestamp.

    Returns:
      The timestamp in microseconds.
    """
    return self._get_timestamp()
  def _get_timestamp(self) -> decimal.Decimal: raise NotImplementedError("_get_timestamp() method must be implemented")

  def wait(self, value:int, timeout:int=getenv("HCQDEV_WAIT_TIMEOUT_MS", 30000)):
    """
    Busy-waits until the signal is greater than or equal to a specific value.

    Args:
      value: The value to wait for.
      timeout: Maximum time to wait in milliseconds. The default is getenv("HCQDEV_WAIT_TIMEOUT_MS", 30000), read once when the class
        is defined; that is 30s unless overridden.

    Raises:
      RuntimeError: If the signal has not reached `value` within `timeout` milliseconds.
    """
    # time.monotonic() is not affected by wall-clock adjustments, unlike time.time(), so the timeout cannot fire early or never.
    deadline = time.monotonic() * 1000 + timeout
    while True:
      # Read the signal once per iteration so the check and the error message agree. Checking before the deadline also lets an
      # already-satisfied wait succeed when timeout <= 0.
      cur = self.value
      if cur >= value: return
      if time.monotonic() * 1000 >= deadline:
        raise RuntimeError(f"Wait timeout: {timeout} ms! (the signal is not set to {value}, but {cur})")
|
234
|
+
|
235
|
+
@contextlib.contextmanager
def hcq_profile(dev, enabled, desc, queue_type=None, queue=None):
  """
  Context manager that brackets the enclosed work with start/end timestamp signals when `enabled`.

  If `queue` is given, the timestamp commands are appended to it. Otherwise, for each timestamp a fresh `queue_type()` queue is built
  and submitted on `dev`, ordered on the device timeline. Yields `(start_signal, end_signal)`, both None when disabled. When PROFILE is
  set, the pair is appended to `dev.sig_prof_records` together with `desc` and whether the work ran on the copy queue.

  Args:
    dev: Device whose signal type and timeline are used.
    enabled: Whether to record timestamps at all.
    desc: Description stored with the profile record.
    queue_type: Queue class used to submit standalone timestamp queues when `queue` is None.
    queue: Existing queue to append the timestamp commands to.
  """
  st, en = (dev.signal_t(), dev.signal_t()) if enabled else (None, None)

  if enabled and queue is not None: queue.timestamp(st)
  elif enabled:
    queue_type().wait(dev.timeline_signal, dev.timeline_value - 1).timestamp(st).signal(dev.timeline_signal, dev.timeline_value).submit(dev)
    dev.timeline_value += 1

  try: yield (st, en)
  finally:
    if enabled and queue is not None: queue.timestamp(en)
    elif enabled:
      queue_type().wait(dev.timeline_signal, dev.timeline_value - 1).timestamp(en).signal(dev.timeline_signal, dev.timeline_value).submit(dev)
      dev.timeline_value += 1

    if enabled and PROFILE:
      # Classify by the queue that actually ran the timestamps. In the `queue=` case queue_type is None. Comparing None directly with
      # dev.hw_copy_queue_t would tag compute work as copy work on devices that have no copy queue (hw_copy_queue_t is None).
      prof_q_t = queue_type if queue_type is not None else (type(queue) if queue is not None else None)
      dev.sig_prof_records.append((st, en, desc, prof_q_t is not None and prof_q_t is dev.hw_copy_queue_t))
|
252
|
+
|
253
|
+
class HCQArgsState:
  """
  Kernel-argument state for one launch of `prg`, located at address `ptr`.

  This base class only remembers `ptr` and `prg`. Subclasses presumably lay out `bufs`/`vals` in the argument memory and implement
  `update_buffer` / `update_var` to patch them later.
  """
  def __init__(self, ptr:int, prg:HCQProgram, bufs:Tuple[HCQBuffer, ...], vals:Tuple[int, ...]=()):
    self.ptr = ptr
    self.prg = prg

  def update_buffer(self, index:int, buf:HCQBuffer):
    raise NotImplementedError("need update_buffer")

  def update_var(self, index:int, val:int):
    raise NotImplementedError("need update_var")
|
257
|
+
|
258
|
+
class HCQProgram:
  """
  Base class for a compiled kernel on an HCQ device.

  Args:
    args_state_t: HCQArgsState subclass used to lay out the kernel arguments.
    device: Device the program runs on.
    name: Kernel name; also used as the profiling description.
    kernargs_alloc_size: Amount of kernel-argument memory (bytes) reserved from the device per launch.
  """
  def __init__(self, args_state_t:Type[HCQArgsState], device:HCQCompiled, name:str, kernargs_alloc_size:int):
    self.args_state_t = args_state_t
    self.device = device
    self.name = name
    self.kernargs_alloc_size = kernargs_alloc_size

  def fill_kernargs(self, bufs:Tuple[HCQBuffer, ...], vals:Tuple[int, ...]=(), kernargs_ptr:Optional[int]=None) -> HCQArgsState:
    """
    Builds the argument state for a launch.

    Args:
      bufs: Buffers written into the kernel arguments.
      vals: Integer values written into the kernel arguments.
      kernargs_ptr: Pre-allocated argument memory. If falsy, `kernargs_alloc_size` bytes are taken from the device.
    Returns:
      An `args_state_t` instance for this program with `bufs` and `vals`.
    """
    ptr = kernargs_ptr or self.device._alloc_kernargs(self.kernargs_alloc_size)
    return self.args_state_t(ptr, self, bufs, vals=vals)

  def __call__(self, *bufs:HCQBuffer, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1),
               vals:Tuple[int, ...]=(), wait:bool=False) -> Optional[float]:
    """
    Launches the program on its device's compute queue.

    The launch waits for all previously submitted device work, is followed by a timeline signal, and is timed when `wait` is set or
    profiling is on.

    Args:
      bufs: Buffer arguments.
      global_size: Global work size (like CUDA's grid size).
      local_size: Local work size (like CUDA's block size).
      vals: Integer arguments.
      wait: If True, blocks until the kernel finished.

    Returns:
      Kernel execution time in seconds when `wait` is True, otherwise None.
    """
    dev = self.device
    kernargs = self.fill_kernargs(bufs, vals)
    q = dev.hw_compute_queue_t().wait(dev.timeline_signal, dev.timeline_value - 1).memory_barrier()

    with hcq_profile(dev, queue=q, desc=self.name, enabled=wait or PROFILE) as (sig_st, sig_en):
      q.exec(self, kernargs, global_size, local_size)

    q.signal(dev.timeline_signal, dev.timeline_value).submit(dev)
    dev.timeline_value += 1

    if not wait: return None
    dev.synchronize()
    # Signal timestamps are in microseconds; convert the difference to seconds.
    return float(sig_en.timestamp - sig_st.timestamp) / 1e6
|
301
|
+
|
302
|
+
class ProfileLogger:
  """
  Collects profiling events and dependency arrows and writes them as a JSON trace (`{"traceEvents": [...]}`).

  Trace entries (`mjson`) and the actor -> pid/tid mapping (`actors`) are shared by all loggers through class attributes. The file is
  written when the last live logger (tracked by `writers`) is destroyed.
  """
  writers: int = 0
  mjson: List[Dict] = []
  actors: Dict[Union[str, Tuple[str, str]], int] = {}

  def __init__(self):
    self.events, self.deps = [], []
    ProfileLogger.writers += 1

  def add_event(self, ev_name, ev_start, ev_end, actor, subactor=None, args=None):
    self.events.append((ev_name, ev_start, ev_end, actor, subactor, args))

  def _ensure_actor(self, actor_name, subactor_name):
    # Assign ids on first sight and emit the metadata records naming the process (actor) and thread (subactor).
    actors = self.actors
    if actor_name not in actors:
      pid = len(actors)
      actors[actor_name] = pid
      self.mjson.append({"name": "process_name", "ph": "M", "pid": pid, "args": {"name": actor_name}})

    subactor_key = (actor_name, subactor_name)
    if subactor_key not in actors:
      tid = len(actors)
      actors[subactor_key] = tid
      self.mjson.append({"name": "thread_name", "ph": "M", "pid": actors[actor_name], "tid": tid, "args": {"name": subactor_name}})

    return actors[actor_name], actors.get(subactor_key, -1)

  def __del__(self):
    # perfetto json docs: https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/preview
    for name, st, et, actor_name, subactor_name, args in self.events:
      pid, tid = self._ensure_actor(actor_name, subactor_name)
      if args is not None:
        # Non-string arg values are callables evaluated with the event duration.
        args = {k: (v if v.__class__ is str else v(et-st)) for k, v in args.items()}
      self.mjson.append({"name": name, "ph": "X", "pid": pid, "tid": tid, "ts": st, "dur": et-st, "args": args})

    for en, st, dep_actor_name, dep_subactor_name, actor_name, subactor_name in self.deps:
      dep_pid, dep_tid = self._ensure_actor(dep_actor_name, dep_subactor_name)
      pid, tid = self._ensure_actor(actor_name, subactor_name)
      # A flow start ("s") and finish ("f") pair shares one id.
      flow_id = len(self.mjson)
      self.mjson.append({"ph": "s", "pid": dep_pid, "tid": dep_tid, "id": flow_id, "ts": en, "bp": "e"})
      self.mjson.append({"ph": "f", "pid": pid, "tid": tid, "id": flow_id, "ts": st, "bp": "e"})

    ProfileLogger.writers -= 1
    if ProfileLogger.writers != 0 or not self.mjson: return
    with open(PROFILEPATH.value, "w") as f: f.write(json.dumps({"traceEvents": self.mjson}))
    print(f"Saved profile to {PROFILEPATH.value}. Use https://ui.perfetto.dev/ to open it.")
|
339
|
+
|
340
|
+
class HCQCompiled(Compiled):
  """
  A base class for devices compatible with the HCQ (Hardware Command Queue) API.

  Each device keeps a timeline signal and a timeline value used to order submitted queues. It also owns a 16 MiB page that kernel
  arguments are bump-allocated from. When PROFILE is set, it collects profiling records that are converted to CPU time and handed to a
  ProfileLogger at exit.
  """
  # Every HCQCompiled instance (shared across subclasses); used when estimating clock offsets between devices.
  devices: List[HCQCompiled] = []
  # Offsets (microseconds) added to copy/compute queue timestamps to get CPU time; NaN until _ensure_shared_time_base() measures them.
  gpu2cpu_copy_time_diff: decimal.Decimal = decimal.Decimal('nan')
  gpu2cpu_compute_time_diff: decimal.Decimal = decimal.Decimal('nan')

  def __init__(self, device:str, allocator:Allocator, renderer:Renderer, compiler:Compiler, runtime, signal_t:Type[HCQSignal],
               comp_queue_t:Type[HWComputeQueue], copy_queue_t:Optional[Type[HWCopyQueue]]):
    self.signal_t, self.hw_compute_queue_t, self.hw_copy_queue_t = signal_t, comp_queue_t, copy_queue_t
    # Next value to signal on the timeline. Submitted work waits for timeline_value - 1 and signals timeline_value.
    self.timeline_value:int = 1
    # The shadow signal is swapped in by _wrap_timeline_signal() once timeline_value grows past 2**31.
    self.timeline_signal, self._shadow_timeline_signal = self.signal_t(0, is_timeline=True), self.signal_t(0, is_timeline=True)
    # (start signal, end signal, description, is_copy_queue) appended by hcq_profile; read back in synchronize().
    self.sig_prof_records:List[Tuple[HCQSignal, HCQSignal, str, bool]] = []
    # (start ts, end ts, description, is_copy_queue, args) with timestamps already read from the signals.
    self.raw_prof_records:List[Tuple[decimal.Decimal, decimal.Decimal, str, bool, Optional[Dict]]] = []
    # Pairs of (start, end, device, is_copy) events; _prof_finalize() turns each into a dependency arrow from the first to the second.
    self.dep_prof_records:List[Tuple[decimal.Decimal, decimal.Decimal, HCQCompiled, bool, decimal.Decimal, decimal.Decimal, HCQCompiled, bool]] = []
    if PROFILE: self._prof_setup()

    from tinygrad.runtime.graph.hcq import HCQGraph
    super().__init__(device, allocator, renderer, compiler, runtime, HCQGraph)

    # 16 MiB CPU-accessible page used as a ring for kernel arguments (see _alloc_kernargs).
    self.kernargs_page:HCQBuffer = self.allocator.alloc(16 << 20, BufferOptions(cpu_access=True))
    self.kernargs_ptr:int = self.kernargs_page.va_addr
    self.devices.append(self)

  def synchronize(self):
    """
    Blocks until all submitted work on this device has completed (or calls the backend's `_syncdev` hook if it defines one).

    A RuntimeError while waiting (e.g. a signal wait timeout) goes to `on_device_hang` if the backend defines it; otherwise it is
    re-raised. Afterwards the timeline is wrapped if needed, and pending profiling signals are converted to raw timestamps.
    """
    try: self.timeline_signal.wait(self.timeline_value - 1) if not hasattr(self, '_syncdev') else self._syncdev()
    except RuntimeError as e:
      if hasattr(self, 'on_device_hang'): self.on_device_hang()
      else: raise e

    # Restart the timeline once values exceed 2**31; presumably to stay within the signal's storage width — confirm per backend.
    if self.timeline_value > (1 << 31): self._wrap_timeline_signal()
    if PROFILE:
      self.raw_prof_records += [(st.timestamp, en.timestamp, name, is_cp, None) for st, en, name, is_cp in self.sig_prof_records]
      self.sig_prof_records = []

  def _alloc_kernargs(self, alloc_size:int) -> int:
    """
    Allocates space for arguments passed to the kernel.

    Bump-allocates `alloc_size` bytes from kernargs_page and returns their address. When the remaining space is insufficient it wraps to
    the start of the page. Nothing here checks that kernels using the reused memory have finished.
    """
    if self.kernargs_ptr >= (self.kernargs_page.va_addr + self.kernargs_page.size - alloc_size): self.kernargs_ptr = self.kernargs_page.va_addr
    self.kernargs_ptr = (res:=self.kernargs_ptr) + alloc_size
    return res

  def _ensure_shared_time_base(self):
    """
    Measures, once, the offsets between every device's queue timestamps and CPU time (perf_counter_ns, in microseconds).

    It then prints the residual pairwise GPU<->GPU clock jitter as a sanity check. This is a no-op once the compute offset is known.
    """
    if not self.gpu2cpu_compute_time_diff.is_nan(): return

    def _sync_cpu_queue(d, q_t):
      # Timestamp on the GPU, then compare with the CPU time around the wait for that timestamp.
      q_t().timestamp(d.timeline_signal).signal(d.timeline_signal, d.timeline_value).submit(d)
      d.timeline_value += 1
      st = time.perf_counter_ns()
      d.timeline_signal.wait(d.timeline_value - 1)  # CPU time is taken as the average of the two perf_counter_ns() readings
      et = time.perf_counter_ns()
      # (et+st)/2 ns -> /1000 -> microseconds, minus the GPU timestamp.
      return (decimal.Decimal(et+st) / 2000) - d.timeline_signal.timestamp

    # randomly sample the timing from GPU to CPU
    choices: List = [(d, d.hw_compute_queue_t, []) for d in self.devices]
    choices += [(d, d.hw_copy_queue_t, []) for d in self.devices if d.hw_copy_queue_t is not None]
    for _ in range(100*len(self.devices)):
      d,q,l = random.choice(choices)
      l.append(_sync_cpu_queue(d,q))
    for d,q,l in choices:
      if q == d.hw_compute_queue_t: d.gpu2cpu_compute_time_diff = statistics.median(l)
      if q == d.hw_copy_queue_t: d.gpu2cpu_copy_time_diff = statistics.median(l)

    def _sync_gpu_to_gpu_queue(d1, d2, q1_t, q2_t):
      # Each queue signals, waits for the other device's signal, then timestamps; returns the d2 - d1 timestamp difference.
      q1_t().signal(d1.timeline_signal, d1.timeline_value).wait(d2.timeline_signal, d2.timeline_value) \
            .timestamp(d1.timeline_signal).signal(d1.timeline_signal, d1.timeline_value+1).submit(d1)
      q2_t().signal(d2.timeline_signal, d2.timeline_value).wait(d1.timeline_signal, d1.timeline_value) \
            .timestamp(d2.timeline_signal).signal(d2.timeline_signal, d2.timeline_value+1).submit(d2)
      d1.timeline_value += 2
      d2.timeline_value += 2
      d1.timeline_signal.wait(d1.timeline_value - 1)
      d2.timeline_signal.wait(d2.timeline_value - 1)
      return d2.timeline_signal.timestamp - d1.timeline_signal.timestamp

    # then test it by timing the GPU to GPU times
    jitter_matrix = [[float('nan')]*len(self.devices) for _ in range(len(self.devices))]
    for i1, d1 in enumerate(self.devices):
      for i2, d2 in enumerate(self.devices):
        if d1 == d2: continue
        d1_to_d2 = statistics.median(_sync_gpu_to_gpu_queue(d1, d2, d1.hw_compute_queue_t, d2.hw_compute_queue_t) - \
                                     _sync_gpu_to_gpu_queue(d2, d1, d2.hw_compute_queue_t, d1.hw_compute_queue_t) for _ in range(20)) / 2
        jitter_matrix[i1][i2] = d1_to_d2 - (d1.gpu2cpu_compute_time_diff - d2.gpu2cpu_compute_time_diff)
    print("pairwise clock jitter matrix (us):\n" + '\n'.join([''.join([f'{float(item):8.3f}' for item in row]) for row in jitter_matrix]))

  def _gpu2cpu_time(self, gpu_time:decimal.Decimal, is_copy:bool) -> float:
    """
    Translates local gpu time (timestamp) into global cpu time.

    Uses the copy-queue offset when `is_copy` is set, else the compute-queue offset. The offsets are measured on first use.
    """
    self._ensure_shared_time_base()
    return float(gpu_time + (self.gpu2cpu_copy_time_diff if is_copy else self.gpu2cpu_compute_time_diff))

  def _prof_setup(self):
    # Create this device's ProfileLogger once and flush it at interpreter exit.
    if hasattr(self, 'profile_logger'): return
    atexit.register(self._prof_finalize)
    self.profile_logger = ProfileLogger()

  def _prof_finalize(self):
    # Trace thread names, indexed by the is_copy flag (False -> COMPUTE, True -> DMA).
    qname = ["COMPUTE", "DMA"]

    # Sync to be sure all events on the device are recorded.
    self.synchronize()

    for st, en, name, is_cp, args in self.raw_prof_records:
      self.profile_logger.events += [(name, self._gpu2cpu_time(st, is_cp), self._gpu2cpu_time(en, is_cp), self.dname, qname[is_cp], args)]
    for a_st, a_en, a_dev, a_is_copy, b_st, b_en, b_dev, b_is_copy in self.dep_prof_records:
      # Perfetto connects nodes based on timing data, ensuring every choice is valid by averaging times to a midpoint.
      a_tm, b_tm = a_dev._gpu2cpu_time((a_st+a_en)/decimal.Decimal(2), a_is_copy), b_dev._gpu2cpu_time((b_st+b_en)/decimal.Decimal(2), b_is_copy)
      self.profile_logger.deps += [(a_tm, b_tm, a_dev.dname, qname[a_is_copy], b_dev.dname, qname[b_is_copy])]
    self.raw_prof_records, self.dep_prof_records = [], []

    # Remove the logger, this flushes all data written by the device.
    del self.profile_logger

  def _wrap_timeline_signal(self):
    # Swap in the shadow signal, reset it to 0 and restart the timeline at 1. Staging-buffer reuse points (b_timeline) referred to the
    # old signal's values, so they are reset too. Called from synchronize() after all submitted work completed.
    self.timeline_signal, self._shadow_timeline_signal, self.timeline_value = self._shadow_timeline_signal, self.timeline_signal, 1
    self.timeline_signal.value = 0
    cast(HCQAllocator, self.allocator).b_timeline = [0] * len(cast(HCQAllocator, self.allocator).b)
|
459
|
+
|
460
|
+
class HCQBuffer(Protocol):
  """Structural type for buffers returned by HCQ-compatible allocators: each exposes its virtual address and its size."""
  va_addr: int
  size: int
|
462
|
+
|
463
|
+
class HCQAllocator(LRUAllocator): # pylint: disable=abstract-method
|
464
|
+
"""
|
465
|
+
A base allocator class compatible with the HCQ (Hardware Command Queue) API.
|
466
|
+
|
467
|
+
This class implements basic copy operations following the HCQ API, utilizing both `HWComputeQueue` and `HWCopyQueue`.
|
468
|
+
"""
|
469
|
+
|
470
|
+
def __init__(self, device:HCQCompiled, batch_size:int=(2 << 20), batch_cnt:int=32):
|
471
|
+
self.device:Any = device
|
472
|
+
self.b = [self._alloc(batch_size, BufferOptions(host=True)) for _ in range(batch_cnt)]
|
473
|
+
self.b_timeline, self.b_next = [0] * len(self.b), 0
|
474
|
+
super().__init__()
|
475
|
+
|
476
|
+
def _alloc(self, size:int, options:BufferOptions) -> HCQBuffer: raise NotImplementedError("need hcq compat alloc")
|
477
|
+
|
478
|
+
def copyin(self, dest:HCQBuffer, src:memoryview):
|
479
|
+
with hcq_profile(self.device, queue_type=self.device.hw_copy_queue_t, desc=f"CPU -> {self.device.dname}", enabled=PROFILE):
|
480
|
+
for i in range(0, src.nbytes, self.b[0].size):
|
481
|
+
self.b_next = (self.b_next + 1) % len(self.b)
|
482
|
+
self.device.timeline_signal.wait(self.b_timeline[self.b_next])
|
483
|
+
ctypes.memmove(self.b[self.b_next].va_addr, from_mv(src[i:]), lsize:=min(self.b[self.b_next].size, src.nbytes-i))
|
484
|
+
self.device.hw_copy_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1) \
|
485
|
+
.copy(dest.va_addr+i, self.b[self.b_next].va_addr, lsize) \
|
486
|
+
.signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
|
487
|
+
self.b_timeline[self.b_next] = self.device.timeline_value
|
488
|
+
self.device.timeline_value += 1
|
489
|
+
|
490
|
+
def copy_from_disk(self, dest:HCQBuffer, src, size):
|
491
|
+
def _get_temp_buf():
|
492
|
+
# Check if the next buffer is safe to be used (its signal has passed) and reserve it.
|
493
|
+
if self.b_timeline[(self.b_next + 1) % len(self.b)] <= self.device.timeline_signal.value:
|
494
|
+
self.b_timeline[(self.b_next + 1) % len(self.b)], self.b_next = (1 << 64), (self.b_next + 1) % len(self.b)
|
495
|
+
return (self.b[self.b_next].va_addr, self.b_next)
|
496
|
+
return None
|
497
|
+
|
498
|
+
with hcq_profile(self.device, queue_type=self.device.hw_copy_queue_t, desc=f"DISK -> {self.device.dname}", enabled=PROFILE):
|
499
|
+
for (batch_info, dst_off, src_off, copy_size) in src.device.allocator._copyout_sharded(src, size, _get_temp_buf, seg_len=self.b[0].size):
|
500
|
+
self.device.hw_copy_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1) \
|
501
|
+
.copy(dest.va_addr + dst_off, batch_info[0] + src_off, copy_size) \
|
502
|
+
.signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
|
503
|
+
self.b_timeline[batch_info[1]] = self.device.timeline_value
|
504
|
+
self.device.timeline_value += 1
|
505
|
+
|
506
|
+
def copyout(self, dest:memoryview, src:HCQBuffer):
|
507
|
+
self.device.synchronize()
|
508
|
+
|
509
|
+
with hcq_profile(self.device, queue_type=self.device.hw_copy_queue_t, desc=f"{self.device.dname} -> CPU", enabled=PROFILE):
|
510
|
+
for i in range(0, dest.nbytes, self.b[0].size):
|
511
|
+
self.device.hw_copy_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1) \
|
512
|
+
.copy(self.b[0].va_addr, src.va_addr+i, lsize:=min(self.b[0].size, dest.nbytes-i)) \
|
513
|
+
.signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
|
514
|
+
self.device.timeline_signal.wait(self.device.timeline_value)
|
515
|
+
self.device.timeline_value += 1
|
516
|
+
|
517
|
+
ctypes.memmove(from_mv(dest[i:]), self.b[0].va_addr, lsize)
|
518
|
+
|
519
|
+
def transfer(self, dest:HCQBuffer, src:HCQBuffer, sz:int, src_dev, dest_dev):
  """
  Copy `sz` bytes from `src` (on `src_dev`) to `dest` (on `dest_dev`) using `src_dev`'s copy queue.

  `dest` is first passed to `src_dev.allocator.map`. The copy waits on the latest timeline value of both devices and
  signals the next value of `src_dev`'s timeline. When the two devices differ, `dest_dev`'s compute queue waits on
  that signal and on its own timeline, then signals `dest_dev`'s next timeline value. Presumably this orders later
  `dest_dev` work after the copy.
  """
  src_dev.allocator.map(dest)

  with hcq_profile(src_dev, queue_type=src_dev.hw_copy_queue_t, desc=f"{src_dev.dname} -> {dest_dev.dname}", enabled=PROFILE):
    copy_queue = src_dev.hw_copy_queue_t()
    (copy_queue.wait(src_dev.timeline_signal, src_dev.timeline_value - 1)
               .wait(dest_dev.timeline_signal, dest_dev.timeline_value - 1)
               .copy(dest.va_addr, src.va_addr, sz)
               .signal(src_dev.timeline_signal, src_dev.timeline_value)
               .submit(src_dev))
    src_dev.timeline_value += 1

  if src_dev != dest_dev:
    # After the increment above, src_dev.timeline_value - 1 is exactly the value the copy signals.
    sync_queue = dest_dev.hw_compute_queue_t()
    (sync_queue.wait(src_dev.timeline_signal, src_dev.timeline_value - 1)
               .wait(dest_dev.timeline_signal, dest_dev.timeline_value - 1)
               .signal(dest_dev.timeline_signal, dest_dev.timeline_value)
               .submit(dest_dev))
    dest_dev.timeline_value += 1
|
534
|
+
|
535
|
+
def map(self, buf:HCQBuffer):
  """
  Hook that `transfer` calls on the source device's allocator before that device copies into `buf`.

  This default implementation does nothing and returns None.
  """
  return None
|
536
|
+
|
537
|
+
def offset(self, buf, size:int, offset:int) -> HCQBuffer:
  """
  Return a view of `buf` covering `size` bytes starting `offset` bytes into it.

  The view is a new instance of `type(buf)` with these attributes:
  - `va_addr` shifted by `offset` and `size` replaced;
  - every other attribute copied from `buf`, taken from both its instance `__dict__` and its ctypes-style `_fields_`
    (if any);
  - `_base` set to `buf`.
  """
  # Gather the carried-over attributes into ONE dict. Passing them as two separate `**` expansions next to an explicit
  # `_base=` raised "TypeError: got multiple values for keyword argument" in two cases: when `buf` is itself a view
  # (its __dict__ already holds `_base`), or when a name appears in both __dict__ and _fields_.
  skip = ('va_addr', 'size', '_base')
  attrs = {k:v for k,v in buf.__dict__.items() if k not in skip}
  attrs.update({x[0]:getattr(buf, x[0]) for x in getattr(buf, '_fields_', []) if x[0] not in skip})
  return type(buf)(va_addr=buf.va_addr + offset, size=size, **attrs, _base=buf)
|
tinygrad/shape/shapetracker.py
CHANGED
@@ -1,21 +1,13 @@
|
|
1
1
|
# ShapeTracker allows movement operations to a buffer that don't require a copy to be made.
|
2
2
|
from __future__ import annotations
|
3
3
|
from dataclasses import dataclass
|
4
|
-
from typing import Tuple, List, Optional, Dict, Set
|
4
|
+
from typing import Tuple, List, Optional, Dict, Set
|
5
5
|
from tinygrad.helpers import merge_dicts, getenv
|
6
|
-
from tinygrad.shape.symbolic import Variable, MulNode, Node, SumNode, NumNode, create_lt_node, create_ge_node, sint
|
7
6
|
from tinygrad.shape.view import View, strides_for_shape
|
7
|
+
from tinygrad.dtype import dtypes
|
8
|
+
from tinygrad.ops import UOp, Ops, graph_rewrite, split_uop, symbolic_flat, Variable, sint, uop_given_valid, simplify_valid
|
8
9
|
|
9
|
-
|
10
|
-
assert len(idxs) == len(view.shape), f"need an idx for all dimensions {idxs} vs {view.shape}"
|
11
|
-
iexpr: List[Node] = [NumNode(view.offset) if isinstance(view.offset, int) else view.offset]
|
12
|
-
vexpr: List[Node] = [valid] if valid is not None else []
|
13
|
-
for idx,sh,st,m in zip(idxs, view.shape, view.strides, view.mask if view.mask is not None else [None]*len(view.shape)):
|
14
|
-
if sh != 1 and st != 0: iexpr.append(idx*st)
|
15
|
-
if m is not None: vexpr += [create_ge_node(idx, m[0]), create_lt_node(idx, m[1])] # idx >= m[0], idx < m[1]
|
16
|
-
return Node.sum(iexpr), Node.ands(vexpr)
|
17
|
-
|
18
|
-
@dataclass(frozen=True)
|
10
|
+
@dataclass(frozen=True, order=True)
|
19
11
|
class ShapeTracker:
|
20
12
|
views: Tuple[View, ...]
|
21
13
|
|
@@ -32,7 +24,7 @@ class ShapeTracker:
|
|
32
24
|
return ShapeTracker(tuple(inverted_views)).reshape(out_shape)
|
33
25
|
|
34
26
|
@staticmethod
def from_shape(shape:Tuple[sint, ...]) -> ShapeTracker:
  """Build a ShapeTracker whose only view is `View.create(shape)`."""
  base_view = View.create(shape)
  return ShapeTracker((base_view,))
|
36
28
|
|
37
29
|
@property
def contiguous(self) -> bool:
  """True only when there is exactly one view and that view reports itself `contiguous`."""
  if len(self.views) != 1: return False
  return self.views[0].contiguous
|
@@ -46,17 +38,29 @@ class ShapeTracker:
|
|
46
38
|
@property
def size(self) -> int:
  """Size reported by the last (outermost) view, i.e. `self.views[-1].size()`."""
  last_view = self.views[-1]
  return last_view.size()
|
48
40
|
|
41
|
+
def reduce(self, axis:Tuple[int, ...]) -> Tuple[sint, ...]:
  """Return `self.shape` with every dimension whose index is listed in `axis` replaced by 1."""
  reduced = []
  for dim_idx, dim in enumerate(self.shape):
    reduced.append(1 if dim_idx in axis else dim)
  return tuple(reduced)
|
42
|
+
|
43
|
+
def to_uop(self) -> UOp:
  """Wrap this tracker in a VIEW UOp with void dtype, no sources, and the tracker itself as its arg."""
  no_sources = ()
  return UOp(Ops.VIEW, dtypes.void, no_sources, self)
|
44
|
+
|
45
|
+
def to_indexed_uops(self, _idxs:Optional[List[UOp]]=None) -> Tuple[UOp, UOp]:
  """
  Lower the whole view stack to an (index, valid) pair.

  The last view is lowered first, using `_idxs` when given. Then, walking the earlier views from last to first, the
  current flat index is unflattened into that view's (minified) shape, innermost dimension first, using `//` and
  `%`. That view is lowered on top of those per-dimension indices, threading the accumulated `valid` through.
  """
  idx, valid = self.views[-1].to_indexed_uops(_idxs)
  for prev_view in reversed(self.views[:-1]):
    view = prev_view.minify()
    # unflatten idx into view.shape, starting from the innermost dimension
    stride, per_dim = 1, []
    for dim in reversed(view.shape):
      per_dim.append((idx // stride) % dim)
      stride *= dim
    per_dim.reverse()
    idx, valid = view.to_indexed_uops(per_dim, valid)
  return idx, valid
|
55
|
+
|
49
56
|
def real_size(self) -> int:
  """
  One more than the upper bound (`vmax`) of the flat index, i.e. how many buffer elements the tracker can reach.

  Returns 0 when the shape has a zero dimension or when `valid.vmax` is falsy. Asserts that the index bound is below
  1e12.
  """
  if 0 in self.shape: return 0
  idx, valid = self.to_indexed_uops()
  if valid.vmax:
    assert idx.vmax < 1e12, f"real_size broken for {self}"
    return int(idx.vmax+1)
  return 0
|
58
62
|
|
59
|
-
def vars(self) -> Set[Variable]:
  """Union of `view.vars()` over every view; an empty set when there are no views."""
  found = set()
  for view in self.views:
    found.update(view.vars())
  return found
|
60
64
|
|
61
65
|
@property
def var_vals(self) -> Dict[Variable, int]:
  """Merge the `unbind()` pair of each variable in `self.vars()` into a single dict via `merge_dicts`."""
  per_var = [dict([var.unbind()]) for var in self.vars()]
  return merge_dicts(per_var)
|
@@ -68,40 +72,26 @@ class ShapeTracker:
|
|
68
72
|
# NOTE: if a stride is not always valid, it will be None
|
69
73
|
def real_strides(self, ignore_valid=False) -> Tuple[Optional[sint], ...]:
  """
  Per-axis stride of the flat index with respect to each axis of `self.shape`.

  Fast path: a single view without a mask returns that view's `strides`. Otherwise the (index, valid) UOps are built
  and simplified, and the index is split into ADD terms:
  - a bare RANGE for axis i gives stride 1;
  - RANGE*CONST or CONST*RANGE gives that constant;
  - an axis whose RANGE never appears in the index gets 0;
  - an axis the index depends on in any other way stays None.

  Unless `ignore_valid` is set, every axis whose RANGE appears in the valid expression is also reported as None.
  """
  if len(self.views) == 1 and self.views[-1].mask is None: return self.views[-1].strides
  strides: List[Optional[sint]] = [None] * len(self.shape)
  raw_idx, raw_valid = self.to_indexed_uops()
  idx, valid = graph_rewrite(raw_idx, symbolic_flat), graph_rewrite(raw_valid, symbolic_flat)
  # TODO: always apply these in to_indexed_uops?
  simpler_valid = simplify_valid(valid)
  if simpler_valid is not None: valid = simpler_valid
  bounded_idx = uop_given_valid(valid, idx)
  if bounded_idx is not None: idx = graph_rewrite(bounded_idx, symbolic_flat)
  for term in split_uop(idx, Ops.ADD):
    if term.op is Ops.RANGE: strides[term.arg[0]] = 1
    if term.op is Ops.MUL:
      lhs, rhs = term.src[0], term.src[1]
      if lhs.op is Ops.RANGE and rhs.op is Ops.CONST: strides[lhs.arg[0]] = rhs.arg
      if rhs.op is Ops.RANGE and lhs.op is Ops.CONST: strides[rhs.arg[0]] = lhs.arg
  # an axis the index never references does not move through memory at all
  used_axes = {u.arg[0] for u in idx.sparents if u.op is Ops.RANGE}
  strides = [st if axis in used_axes else 0 for axis, st in enumerate(strides)]
  if not ignore_valid:
    for u in valid.sparents:
      if u.op is Ops.RANGE: strides[u.arg[0]] = None
  return tuple(strides)
|
84
89
|
|
85
90
|
def unit_stride_axes(self, ignore_valid=False) -> List[int]:
  """Indices of the axes whose entry in `self.real_strides(ignore_valid)` equals 1."""
  strides = self.real_strides(ignore_valid)
  return [axis for axis, stride in enumerate(strides) if stride == 1]
|
86
91
|
|
87
|
-
def expr_idxs(self, idxs:Optional[Iterable[Node]]=None) -> Tuple[Node, Node]:
|
88
|
-
idxs = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)] if idxs is None else list(idxs)
|
89
|
-
idx, valid = _expr_view(self.views[-1], idxs)
|
90
|
-
for view in reversed(self.views[0:-1]):
|
91
|
-
if valid.max == 0: return NumNode(-1), valid
|
92
|
-
view = view.minify()
|
93
|
-
acc, idxs = 1, []
|
94
|
-
for d in reversed(view.shape):
|
95
|
-
idxs.append((idx//acc)%d)
|
96
|
-
acc *= d
|
97
|
-
idx, valid = _expr_view(view, idxs[::-1], valid)
|
98
|
-
assert not isinstance(idx.min, int) or idx.min >= -2**31, f"idx.min too small. {idx=}, {idx.min=}"
|
99
|
-
assert not isinstance(idx.max, int) or idx.max < 2**31, f"idx.max too big. {idx=}, {idx.max=}"
|
100
|
-
return idx, valid
|
101
|
-
|
102
92
|
def axis_is_masked(self, axis:int) -> bool:
  """True when the RANGE for `axis` appears among the parents of the symbolic_flat-rewritten valid expression."""
  _idx, valid = self.to_indexed_uops()
  flat_valid = graph_rewrite(valid, symbolic_flat)
  return any(u.op is Ops.RANGE and u.arg[0] == axis for u in flat_valid.sparents)
|
105
95
|
|
106
96
|
def simplify(self) -> ShapeTracker:
|
107
97
|
if len(self.views) >= 2 and (new_view := self.views[-2] + self.views[-1]) is not None:
|