tinygrad 0.9.1__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. tinygrad/__init__.py +11 -6
  2. tinygrad/codegen/kernel.py +308 -175
  3. tinygrad/codegen/linearize.py +95 -0
  4. tinygrad/codegen/lowerer.py +143 -0
  5. tinygrad/codegen/transcendental.py +257 -0
  6. tinygrad/codegen/uopgraph.py +506 -0
  7. tinygrad/device.py +72 -171
  8. tinygrad/dtype.py +122 -47
  9. tinygrad/engine/jit.py +184 -87
  10. tinygrad/{lazy.py → engine/lazy.py} +74 -66
  11. tinygrad/engine/memory.py +51 -0
  12. tinygrad/engine/realize.py +86 -61
  13. tinygrad/engine/schedule.py +366 -317
  14. tinygrad/engine/search.py +58 -47
  15. tinygrad/function.py +59 -58
  16. tinygrad/helpers.py +120 -102
  17. tinygrad/multi.py +82 -78
  18. tinygrad/nn/__init__.py +116 -67
  19. tinygrad/nn/datasets.py +12 -5
  20. tinygrad/nn/optim.py +1 -1
  21. tinygrad/nn/state.py +91 -6
  22. tinygrad/ops.py +1126 -143
  23. tinygrad/renderer/__init__.py +47 -23
  24. tinygrad/renderer/cstyle.py +338 -265
  25. tinygrad/renderer/llvmir.py +125 -143
  26. tinygrad/renderer/ptx.py +225 -0
  27. tinygrad/runtime/autogen/adreno.py +17904 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +46974 -11993
  29. tinygrad/runtime/autogen/cuda.py +6 -162
  30. tinygrad/runtime/autogen/io_uring.py +97 -63
  31. tinygrad/runtime/autogen/kfd.py +60 -47
  32. tinygrad/runtime/autogen/kgsl.py +1386 -0
  33. tinygrad/runtime/autogen/libc.py +5462 -0
  34. tinygrad/runtime/autogen/nv_gpu.py +1976 -1957
  35. tinygrad/runtime/autogen/nvrtc.py +579 -0
  36. tinygrad/runtime/autogen/opencl.py +11 -11
  37. tinygrad/runtime/autogen/qcom_dsp.py +1739 -0
  38. tinygrad/runtime/graph/clang.py +3 -3
  39. tinygrad/runtime/graph/cuda.py +11 -15
  40. tinygrad/runtime/graph/hcq.py +120 -107
  41. tinygrad/runtime/graph/metal.py +71 -43
  42. tinygrad/runtime/ops_amd.py +244 -323
  43. tinygrad/runtime/ops_clang.py +12 -5
  44. tinygrad/runtime/ops_cloud.py +220 -0
  45. tinygrad/runtime/ops_cuda.py +42 -99
  46. tinygrad/runtime/ops_disk.py +25 -26
  47. tinygrad/runtime/ops_dsp.py +181 -0
  48. tinygrad/runtime/ops_gpu.py +29 -16
  49. tinygrad/runtime/ops_hip.py +68 -0
  50. tinygrad/runtime/ops_llvm.py +15 -10
  51. tinygrad/runtime/ops_metal.py +147 -64
  52. tinygrad/runtime/ops_nv.py +356 -397
  53. tinygrad/runtime/ops_python.py +78 -79
  54. tinygrad/runtime/ops_qcom.py +405 -0
  55. tinygrad/runtime/support/__init__.py +0 -0
  56. tinygrad/runtime/support/compiler_cuda.py +77 -0
  57. tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +13 -1
  58. tinygrad/runtime/support/elf.py +38 -0
  59. tinygrad/runtime/support/hcq.py +539 -0
  60. tinygrad/shape/shapetracker.py +40 -50
  61. tinygrad/shape/view.py +102 -63
  62. tinygrad/tensor.py +1109 -365
  63. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/METADATA +54 -50
  64. tinygrad-0.10.0.dist-info/RECORD +77 -0
  65. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/WHEEL +1 -1
  66. tinygrad/codegen/linearizer.py +0 -528
  67. tinygrad/codegen/uops.py +0 -451
  68. tinygrad/engine/graph.py +0 -100
  69. tinygrad/renderer/assembly.py +0 -269
  70. tinygrad/shape/symbolic.py +0 -327
  71. tinygrad-0.9.1.dist-info/RECORD +0 -63
  72. /tinygrad/{runtime/driver/__init__.py → py.typed} +0 -0
  73. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/LICENSE +0 -0
  74. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/top_level.txt +0 -0
tinygrad/runtime/support/hcq.py (new file)
@@ -0,0 +1,539 @@
1
+ from __future__ import annotations
2
+ from typing import List, Optional, Dict, Tuple, Any, cast, Protocol, Type, Union
3
+ import contextlib, decimal, statistics, random, json, atexit, time, array, ctypes
4
+ from tinygrad.helpers import PROFILEPATH, PROFILE, from_mv, getenv
5
+ from tinygrad.renderer import Renderer
6
+ from tinygrad.device import BufferOptions, Allocator, Compiler, Compiled, LRUAllocator
7
+
8
+ # **************** for HCQ Compatible Devices ****************
9
+
10
+ def hcq_command(func):
11
+ """
12
+ Decorator for HWCommandQueue commands. Enables command indexing and stores metadata for command updates.
13
+
14
+ For example:
15
+ ```python
16
+ @hcq_command
17
+ def command_method(self, ...): ...
18
+ ```
19
+ """
20
+ def __wrapper(self, *args, **kwargs):
21
+ self.cmds_offset.append(len(self.q))
22
+ func(self, *args, **kwargs)
23
+ self.cmds_len.append(len(self.q) - self.cmds_offset[-1])
24
+ self.cmds_meta.append(func.__name__)
25
+ return self
26
+ return __wrapper
27
+
28
+ class HWCommandQueue:
29
+ """
30
+ A base class for hardware command queues in the HCQ (Hardware Command Queue) API.
31
+ Both compute and copy queues should have the following commands implemented.
32
+ """
33
+
34
+ def __init__(self): self.q, self.binded_device, self.cmds_offset, self.cmds_len, self.cmds_meta = [], None, [], [], []
35
+ def __len__(self): return len(self.cmds_offset)
36
+ def _patch(self, cmd_idx, offset, data): self.q[(st:=self.cmds_offset[cmd_idx]+offset):st+len(data)] = array.array('I', data)
37
+ def _cur_cmd_idx(self) -> int:
38
+ """
39
+ Returns the index of the command currently being enqueued.
40
+ Should be called only within functions that enqueue commands and are decorated with `@hcq_command`.
41
+ """
42
+ return len(self) - 1
43
+
44
+ @hcq_command
45
+ def signal(self, signal:HCQSignal, value:int):
46
+ """
47
+ Enqueues a signal command which sets the signal to the given value, ensuring all previous operations are completed.
48
+
49
+ Args:
50
+ signal: The signal to set
51
+ value: The value to set the signal to
52
+ """
53
+ self._signal(signal, value)
54
+ def _signal(self, signal:HCQSignal, value:int): raise NotImplementedError("backend should overload this function")
55
+
56
+ @hcq_command
57
+ def wait(self, signal:HCQSignal, value:int):
58
+ """
59
+ Enqueues a wait command which halts execution until the signal is greater than or equal to a specific value.
60
+
61
+ Args:
62
+ signal: The signal to wait on
63
+ value: The value to wait for
64
+ """
65
+ self._wait(signal, value)
66
+ def _wait(self, signal, value): raise NotImplementedError("backend should overload this function")
67
+
68
+ @hcq_command
69
+ def timestamp(self, signal:HCQSignal):
70
+ """
71
+ Enqueues a timestamp command which records the current time in a signal after all previously enqueued commands are completed.
72
+
73
+ Args:
74
+ signal: The signal to store the timestamp
75
+ """
76
+ self._timestamp(signal)
77
+ def _timestamp(self, signal): raise NotImplementedError("backend should overload this function")
78
+
79
+ def update_signal(self, cmd_idx:int, signal:Optional[Any]=None, value:Optional[int]=None):
80
+ """
81
+ Updates a previously queued signal command.
82
+
83
+ Args:
84
+ cmd_idx: Index of the signal command to update
85
+ signal: New signal to set (if None, keeps the original)
86
+ value: New value to set (if None, keeps the original)
87
+ """
88
+ if self.cmds_meta[cmd_idx] != "signal": raise RuntimeError("called update_signal not on a signal command")
89
+ self._update_signal(cmd_idx, signal, value)
90
+ return self
91
+ def _update_signal(self, cmd_idx:int, signal:Optional[Any], value:Optional[int]): raise NotImplementedError("backend should overload this function")
92
+
93
+ def update_wait(self, cmd_idx:int, signal:Optional[Any]=None, value:Optional[int]=None):
94
+ """
95
+ Updates a previously queued wait command.
96
+
97
+ Args:
98
+ cmd_idx: Index of the wait command to update
99
+ signal: New signal to wait on (if None, keeps the original)
100
+ value: New value to wait for (if None, keeps the original)
101
+ """
102
+ if self.cmds_meta[cmd_idx] != "wait": raise RuntimeError("called update_wait not on a wait command")
103
+ self._update_wait(cmd_idx, signal, value)
104
+ return self
105
+ def _update_wait(self, cmd_idx:int, signal:Optional[Any], value:Optional[int]): raise NotImplementedError("backend should overload this function")
106
+
107
+ def bind(self, device:HCQCompiled):
108
+ """
109
+ Associates the queue with a specific device for optimized execution.
110
+
111
+ This optional method allows backend implementations to tailor the queue for efficient use on the given device. When implemented, it can eliminate
112
+ the need to copy queues into the device, thereby enhancing performance.
113
+
114
+ Args:
115
+ device: The target device for queue optimization.
116
+
117
+ Note:
118
+ Implementing this method is optional but recommended for performance gains.
119
+ """
120
+
121
+ def submit(self, device:HCQCompiled):
122
+ """
123
+ Submits the command queue to a specific device for execution.
124
+
125
+ Args:
126
+ device: The device to submit the queue to
127
+ """
128
+ if self.q: self._submit(device)
129
+ return self
130
+ def _submit(self, device:HCQCompiled): raise NotImplementedError("backend should overload this function")
131
+
132
+ class HWComputeQueue(HWCommandQueue):
133
+ @hcq_command
134
+ def memory_barrier(self):
135
+ """
136
+ Enqueues a memory barrier command to ensure memory coherence between agents.
137
+ """
138
+ self._memory_barrier()
139
+ def _memory_barrier(self): pass
140
+
141
+ @hcq_command
142
+ def exec(self, prg:HCQProgram, args_state:HCQArgsState, global_size:Tuple[int,int,int], local_size:Tuple[int,int,int]):
143
+ """
144
+ Enqueues an execution command for a kernel program.
145
+
146
+ Args:
147
+ prg: The program to execute
148
+ args_state: The args state to execute program with
149
+ global_size: The global work size
150
+ local_size: The local work size
151
+ """
152
+ self._exec(prg, args_state, global_size, local_size)
153
+ def _exec(self, prg, args_state, global_size, local_size): raise NotImplementedError("backend should overload this function")
154
+
155
+ def update_exec(self, cmd_idx:int, global_size:Optional[Tuple[int,int,int]]=None, local_size:Optional[Tuple[int,int,int]]=None):
156
+ """
157
+ Updates a previously queued execution command.
158
+
159
+ Args:
160
+ cmd_idx: Index of the execution command to update
161
+ global_size: New global work size (if None, keeps the original)
162
+ local_size: New local work size (if None, keeps the original)
163
+ """
164
+ if self.cmds_meta[cmd_idx] != "exec": raise RuntimeError("called update_exec not on an exec command")
165
+ self._update_exec(cmd_idx, global_size, local_size)
166
+ return self
167
+ def _update_exec(self, cmd_idx, global_size, local_size): raise NotImplementedError("backend should overload this function")
168
+
169
+ class HWCopyQueue(HWCommandQueue):
170
+ @hcq_command
171
+ def copy(self, dest:HCQBuffer, src:HCQBuffer, copy_size:int):
172
+ """
173
+ Enqueues a copy command to transfer data.
174
+
175
+ Args:
176
+ dest: The destination of the copy
177
+ src: The source of the copy
178
+ copy_size: The size of data to copy
179
+ """
180
+ self._copy(dest, src, copy_size)
181
+ def _copy(self, dest:HCQBuffer, src:HCQBuffer, copy_size:int): raise NotImplementedError("backend should overload this function")
182
+
183
+ def update_copy(self, cmd_idx:int, dest:Optional[HCQBuffer]=None, src:Optional[HCQBuffer]=None):
184
+ """
185
+ Updates a previously queued copy command.
186
+
187
+ Args:
188
+ cmd_idx: Index of the copy command to update
189
+ dest: New destination of the copy (if None, keeps the original)
190
+ src: New source of the copy (if None, keeps the original)
191
+ """
192
+ if self.cmds_meta[cmd_idx] != "copy": raise RuntimeError("called update_copy not on an copy command")
193
+ self._update_copy(cmd_idx, dest, src)
194
+ return self
195
+ def _update_copy(self, cmd_idx, dest, src): raise NotImplementedError("backend should overload this function")
196
+
197
+ class HCQSignal:
198
+ def __init__(self, value:int=0, is_timeline:bool=False): self._set_value(value)
199
+
200
+ @property
201
+ def value(self) -> int: return self._get_value()
202
+
203
+ @value.setter
204
+ def value(self, new_value:int): self._set_value(new_value)
205
+
206
+ def _get_value(self) -> int: raise NotImplementedError("_get_value() method must be implemented")
207
+ def _set_value(self, new_value:int): raise NotImplementedError("_set_value() method must be implemented")
208
+
209
+ @property
210
+ def timestamp(self) -> decimal.Decimal:
211
+ """
212
+ Get the timestamp field of the signal.
213
+
214
+ This property provides read-only access to the signal's timestamp.
215
+
216
+ Returns:
217
+ The timestamp in microseconds.
218
+ """
219
+ return self._get_timestamp()
220
+ def _get_timestamp(self) -> decimal.Decimal: raise NotImplementedError("_get_timestamp() method must be implemented")
221
+
222
+ def wait(self, value:int, timeout:int=getenv("HCQDEV_WAIT_TIMEOUT_MS", 30000)):
223
+ """
224
+ Waits until the signal is greater than or equal to a specific value.
225
+
226
+ Args:
227
+ value: The value to wait for.
228
+ timeout: Maximum time to wait in milliseconds. Defaults to 30s (HCQDEV_WAIT_TIMEOUT_MS).
229
+ """
230
+ start_time = time.time() * 1000
231
+ while time.time() * 1000 - start_time < timeout:
232
+ if self.value >= value: return
233
+ raise RuntimeError(f"Wait timeout: {timeout} ms! (the signal is not set to {value}, but {self.value})")
234
+
235
+ @contextlib.contextmanager
236
+ def hcq_profile(dev, enabled, desc, queue_type=None, queue=None):
237
+ st, en = (dev.signal_t(), dev.signal_t()) if enabled else (None, None)
238
+
239
+ if enabled and queue is not None: queue.timestamp(st)
240
+ elif enabled:
241
+ queue_type().wait(dev.timeline_signal, dev.timeline_value - 1).timestamp(st).signal(dev.timeline_signal, dev.timeline_value).submit(dev)
242
+ dev.timeline_value += 1
243
+
244
+ try: yield (st, en)
245
+ finally:
246
+ if enabled and queue is not None: queue.timestamp(en)
247
+ elif enabled:
248
+ queue_type().wait(dev.timeline_signal, dev.timeline_value - 1).timestamp(en).signal(dev.timeline_signal, dev.timeline_value).submit(dev)
249
+ dev.timeline_value += 1
250
+
251
+ if enabled and PROFILE: dev.sig_prof_records.append((st, en, desc, queue_type is dev.hw_copy_queue_t))
252
+
253
+ class HCQArgsState:
254
+ def __init__(self, ptr:int, prg:HCQProgram, bufs:Tuple[HCQBuffer, ...], vals:Tuple[int, ...]=()): self.ptr, self.prg = ptr, prg
255
+ def update_buffer(self, index:int, buf:HCQBuffer): raise NotImplementedError("need update_buffer")
256
+ def update_var(self, index:int, val:int): raise NotImplementedError("need update_var")
257
+
258
+ class HCQProgram:
259
+ def __init__(self, args_state_t:Type[HCQArgsState], device:HCQCompiled, name:str, kernargs_alloc_size:int):
260
+ self.args_state_t, self.device, self.name, self.kernargs_alloc_size = args_state_t, device, name, kernargs_alloc_size
261
+
262
+ def fill_kernargs(self, bufs:Tuple[HCQBuffer, ...], vals:Tuple[int, ...]=(), kernargs_ptr:Optional[int]=None) -> HCQArgsState:
263
+ """
264
+ Fills arguments for the kernel, optionally allocating space from the device if `kernargs_ptr` is not provided.
265
+ Args:
266
+ bufs: Buffers to be written to kernel arguments.
267
+ vals: Values to be written to kernel arguments.
268
+ kernargs_ptr: Optional pointer to pre-allocated kernel arguments memory.
269
+ Returns:
270
+ Arguments state with the given buffers and values set for the program.
271
+ """
272
+ return self.args_state_t(kernargs_ptr or self.device._alloc_kernargs(self.kernargs_alloc_size), self, bufs, vals=vals)
273
+
274
+ def __call__(self, *bufs:HCQBuffer, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1),
275
+ vals:Tuple[int, ...]=(), wait:bool=False) -> Optional[float]:
276
+ """
277
+ Enqueues the program for execution with the given arguments and dimensions.
278
+
279
+ Args:
280
+ bufs: Buffer arguments to execute the kernel with.
281
+ global_size: Specifies the global work size for kernel execution (equivalent to CUDA's grid size).
282
+ local_size: Specifies the local work size for kernel execution (equivalent to CUDA's block size).
283
+ vals: Value arguments to execute the kernel with.
284
+ wait: If True, waits for the kernel to complete execution.
285
+
286
+ Returns:
287
+ Execution time of the kernel if 'wait' is True, otherwise None.
288
+ """
289
+
290
+ kernargs = self.fill_kernargs(bufs, vals)
291
+ q = self.device.hw_compute_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1).memory_barrier()
292
+
293
+ with hcq_profile(self.device, queue=q, desc=self.name, enabled=wait or PROFILE) as (sig_st, sig_en):
294
+ q.exec(self, kernargs, global_size, local_size)
295
+
296
+ q.signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
297
+ self.device.timeline_value += 1
298
+
299
+ if wait: self.device.synchronize()
300
+ return (float(sig_en.timestamp - sig_st.timestamp) / 1e6) if wait else None
301
+
302
+ class ProfileLogger:
303
+ writers: int = 0
304
+ mjson: List[Dict] = []
305
+ actors: Dict[Union[str, Tuple[str, str]], int] = {}
306
+
307
+ def __init__(self): self.events, self.deps, ProfileLogger.writers = [], [], ProfileLogger.writers + 1
308
+
309
+ def add_event(self, ev_name, ev_start, ev_end, actor, subactor=None, args=None): self.events += [(ev_name, ev_start, ev_end, actor, subactor, args)]
310
+
311
+ def _ensure_actor(self, actor_name, subactor_name):
312
+ if actor_name not in self.actors:
313
+ self.actors[actor_name] = (pid:=len(self.actors))
314
+ self.mjson.append({"name": "process_name", "ph": "M", "pid": pid, "args": {"name": actor_name}})
315
+
316
+ if (subactor_key:=(actor_name,subactor_name)) not in self.actors:
317
+ self.actors[subactor_key] = (tid:=len(self.actors))
318
+ self.mjson.append({"name": "thread_name", "ph": "M", "pid": self.actors[actor_name], "tid":tid, "args": {"name": subactor_name}})
319
+
320
+ return self.actors[actor_name], self.actors.get(subactor_key, -1)
321
+
322
+ def __del__(self):
323
+ # perfetto json docs: https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/preview
324
+ for name, st, et, actor_name, subactor_name, args in self.events:
325
+ pid, tid = self._ensure_actor(actor_name,subactor_name)
326
+ args = {k: (v if v.__class__ is str else v(et-st)) for k, v in args.items()} if args is not None else None
327
+ self.mjson.append({"name": name, "ph": "X", "pid": pid, "tid": tid, "ts": st, "dur": et-st, "args": args})
328
+
329
+ for en,st,dep_actor_name,dep_subactor_name,actor_name,subactor_name in self.deps:
330
+ dep_pid, dep_tid = self._ensure_actor(dep_actor_name,dep_subactor_name)
331
+ pid, tid = self._ensure_actor(actor_name,subactor_name)
332
+ self.mjson.append({"ph": "s", "pid": dep_pid, "tid": dep_tid, "id": len(self.mjson), "ts": en, "bp": "e"})
333
+ self.mjson.append({"ph": "f", "pid": pid, "tid": tid, "id": len(self.mjson)-1, "ts": st, "bp": "e"})
334
+
335
+ ProfileLogger.writers -= 1
336
+ if ProfileLogger.writers == 0 and len(self.mjson) > 0:
337
+ with open(PROFILEPATH.value, "w") as f: f.write(json.dumps({"traceEvents": self.mjson}))
338
+ print(f"Saved profile to {PROFILEPATH.value}. Use https://ui.perfetto.dev/ to open it.")
339
+
340
+ class HCQCompiled(Compiled):
341
+ """
342
+ A base class for devices compatible with the HCQ (Hardware Command Queue) API.
343
+ """
344
+ devices: List[HCQCompiled] = []
345
+ gpu2cpu_copy_time_diff: decimal.Decimal = decimal.Decimal('nan')
346
+ gpu2cpu_compute_time_diff: decimal.Decimal = decimal.Decimal('nan')
347
+
348
+ def __init__(self, device:str, allocator:Allocator, renderer:Renderer, compiler:Compiler, runtime, signal_t:Type[HCQSignal],
349
+ comp_queue_t:Type[HWComputeQueue], copy_queue_t:Optional[Type[HWCopyQueue]]):
350
+ self.signal_t, self.hw_compute_queue_t, self.hw_copy_queue_t = signal_t, comp_queue_t, copy_queue_t
351
+ self.timeline_value:int = 1
352
+ self.timeline_signal, self._shadow_timeline_signal = self.signal_t(0, is_timeline=True), self.signal_t(0, is_timeline=True)
353
+ self.sig_prof_records:List[Tuple[HCQSignal, HCQSignal, str, bool]] = []
354
+ self.raw_prof_records:List[Tuple[decimal.Decimal, decimal.Decimal, str, bool, Optional[Dict]]] = []
355
+ self.dep_prof_records:List[Tuple[decimal.Decimal, decimal.Decimal, HCQCompiled, bool, decimal.Decimal, decimal.Decimal, HCQCompiled, bool]] = []
356
+ if PROFILE: self._prof_setup()
357
+
358
+ from tinygrad.runtime.graph.hcq import HCQGraph
359
+ super().__init__(device, allocator, renderer, compiler, runtime, HCQGraph)
360
+
361
+ self.kernargs_page:HCQBuffer = self.allocator.alloc(16 << 20, BufferOptions(cpu_access=True))
362
+ self.kernargs_ptr:int = self.kernargs_page.va_addr
363
+ self.devices.append(self)
364
+
365
+ def synchronize(self):
366
+ try: self.timeline_signal.wait(self.timeline_value - 1) if not hasattr(self, '_syncdev') else self._syncdev()
367
+ except RuntimeError as e:
368
+ if hasattr(self, 'on_device_hang'): self.on_device_hang()
369
+ else: raise e
370
+
371
+ if self.timeline_value > (1 << 31): self._wrap_timeline_signal()
372
+ if PROFILE:
373
+ self.raw_prof_records += [(st.timestamp, en.timestamp, name, is_cp, None) for st, en, name, is_cp in self.sig_prof_records]
374
+ self.sig_prof_records = []
375
+
376
+ def _alloc_kernargs(self, alloc_size:int) -> int:
377
+ """
378
+ Allocates space for arguments passed to the kernel.
379
+ """
380
+ if self.kernargs_ptr >= (self.kernargs_page.va_addr + self.kernargs_page.size - alloc_size): self.kernargs_ptr = self.kernargs_page.va_addr
381
+ self.kernargs_ptr = (res:=self.kernargs_ptr) + alloc_size
382
+ return res
383
+
384
+ def _ensure_shared_time_base(self):
385
+ if not self.gpu2cpu_compute_time_diff.is_nan(): return
386
+
387
+ def _sync_cpu_queue(d, q_t):
388
+ q_t().timestamp(d.timeline_signal).signal(d.timeline_signal, d.timeline_value).submit(d)
389
+ d.timeline_value += 1
390
+ st = time.perf_counter_ns()
391
+ d.timeline_signal.wait(d.timeline_value - 1) # average of the two
392
+ et = time.perf_counter_ns()
393
+ return (decimal.Decimal(et+st) / 2000) - d.timeline_signal.timestamp
394
+
395
+ # randomly sample the timing from GPU to CPU
396
+ choices: List = [(d, d.hw_compute_queue_t, []) for d in self.devices]
397
+ choices += [(d, d.hw_copy_queue_t, []) for d in self.devices if d.hw_copy_queue_t is not None]
398
+ for _ in range(100*len(self.devices)):
399
+ d,q,l = random.choice(choices)
400
+ l.append(_sync_cpu_queue(d,q))
401
+ for d,q,l in choices:
402
+ if q == d.hw_compute_queue_t: d.gpu2cpu_compute_time_diff = statistics.median(l)
403
+ if q == d.hw_copy_queue_t: d.gpu2cpu_copy_time_diff = statistics.median(l)
404
+
405
+ def _sync_gpu_to_gpu_queue(d1, d2, q1_t, q2_t):
406
+ q1_t().signal(d1.timeline_signal, d1.timeline_value).wait(d2.timeline_signal, d2.timeline_value) \
407
+ .timestamp(d1.timeline_signal).signal(d1.timeline_signal, d1.timeline_value+1).submit(d1)
408
+ q2_t().signal(d2.timeline_signal, d2.timeline_value).wait(d1.timeline_signal, d1.timeline_value) \
409
+ .timestamp(d2.timeline_signal).signal(d2.timeline_signal, d2.timeline_value+1).submit(d2)
410
+ d1.timeline_value += 2
411
+ d2.timeline_value += 2
412
+ d1.timeline_signal.wait(d1.timeline_value - 1)
413
+ d2.timeline_signal.wait(d2.timeline_value - 1)
414
+ return d2.timeline_signal.timestamp - d1.timeline_signal.timestamp
415
+
416
+ # then test it by timing the GPU to GPU times
417
+ jitter_matrix = [[float('nan')]*len(self.devices) for _ in range(len(self.devices))]
418
+ for i1, d1 in enumerate(self.devices):
419
+ for i2, d2 in enumerate(self.devices):
420
+ if d1 == d2: continue
421
+ d1_to_d2 = statistics.median(_sync_gpu_to_gpu_queue(d1, d2, d1.hw_compute_queue_t, d2.hw_compute_queue_t) - \
422
+ _sync_gpu_to_gpu_queue(d2, d1, d2.hw_compute_queue_t, d1.hw_compute_queue_t) for _ in range(20)) / 2
423
+ jitter_matrix[i1][i2] = d1_to_d2 - (d1.gpu2cpu_compute_time_diff - d2.gpu2cpu_compute_time_diff)
424
+ print("pairwise clock jitter matrix (us):\n" + '\n'.join([''.join([f'{float(item):8.3f}' for item in row]) for row in jitter_matrix]))
425
+
426
+ def _gpu2cpu_time(self, gpu_time:decimal.Decimal, is_copy:bool) -> float:
427
+ """
428
+ Translates local gpu time (timestamp) into global cpu time.
429
+ """
430
+ self._ensure_shared_time_base()
431
+ return float(gpu_time + (self.gpu2cpu_copy_time_diff if is_copy else self.gpu2cpu_compute_time_diff))
432
+
433
+ def _prof_setup(self):
434
+ if hasattr(self, 'profile_logger'): return
435
+ atexit.register(self._prof_finalize)
436
+ self.profile_logger = ProfileLogger()
437
+
438
+ def _prof_finalize(self):
439
+ qname = ["COMPUTE", "DMA"]
440
+
441
+ # Sync to be sure all events on the device are recorded.
442
+ self.synchronize()
443
+
444
+ for st, en, name, is_cp, args in self.raw_prof_records:
445
+ self.profile_logger.events += [(name, self._gpu2cpu_time(st, is_cp), self._gpu2cpu_time(en, is_cp), self.dname, qname[is_cp], args)]
446
+ for a_st, a_en, a_dev, a_is_copy, b_st, b_en, b_dev, b_is_copy in self.dep_prof_records:
447
+ # Perfetto connects dependency arrows based on event timing; anchoring each end at the midpoint of its event keeps the connection valid.
448
+ a_tm, b_tm = a_dev._gpu2cpu_time((a_st+a_en)/decimal.Decimal(2), a_is_copy), b_dev._gpu2cpu_time((b_st+b_en)/decimal.Decimal(2), b_is_copy)
449
+ self.profile_logger.deps += [(a_tm, b_tm, a_dev.dname, qname[a_is_copy], b_dev.dname, qname[b_is_copy])]
450
+ self.raw_prof_records, self.dep_prof_records = [], []
451
+
452
+ # Remove the logger, this flushes all data written by the device.
453
+ del self.profile_logger
454
+
455
+ def _wrap_timeline_signal(self):
456
+ self.timeline_signal, self._shadow_timeline_signal, self.timeline_value = self._shadow_timeline_signal, self.timeline_signal, 1
457
+ self.timeline_signal.value = 0
458
+ cast(HCQAllocator, self.allocator).b_timeline = [0] * len(cast(HCQAllocator, self.allocator).b)
459
+
460
+ # Protocol for HCQ-compatible allocators: each allocated buffer must expose its VA address and size.
461
+ class HCQBuffer(Protocol): va_addr:int; size:int # noqa: E702
462
+
463
+ class HCQAllocator(LRUAllocator): # pylint: disable=abstract-method
464
+ """
465
+ A base allocator class compatible with the HCQ (Hardware Command Queue) API.
466
+
467
+ This class implements basic copy operations following the HCQ API, utilizing both `HWComputeQueue` and `HWCopyQueue`.
468
+ """
469
+
470
+ def __init__(self, device:HCQCompiled, batch_size:int=(2 << 20), batch_cnt:int=32):
471
+ self.device:Any = device
472
+ self.b = [self._alloc(batch_size, BufferOptions(host=True)) for _ in range(batch_cnt)]
473
+ self.b_timeline, self.b_next = [0] * len(self.b), 0
474
+ super().__init__()
475
+
476
+ def _alloc(self, size:int, options:BufferOptions) -> HCQBuffer: raise NotImplementedError("need hcq compat alloc")
477
+
478
+ def copyin(self, dest:HCQBuffer, src:memoryview):
479
+ with hcq_profile(self.device, queue_type=self.device.hw_copy_queue_t, desc=f"CPU -> {self.device.dname}", enabled=PROFILE):
480
+ for i in range(0, src.nbytes, self.b[0].size):
481
+ self.b_next = (self.b_next + 1) % len(self.b)
482
+ self.device.timeline_signal.wait(self.b_timeline[self.b_next])
483
+ ctypes.memmove(self.b[self.b_next].va_addr, from_mv(src[i:]), lsize:=min(self.b[self.b_next].size, src.nbytes-i))
484
+ self.device.hw_copy_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1) \
485
+ .copy(dest.va_addr+i, self.b[self.b_next].va_addr, lsize) \
486
+ .signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
487
+ self.b_timeline[self.b_next] = self.device.timeline_value
488
+ self.device.timeline_value += 1
489
+
490
+ def copy_from_disk(self, dest:HCQBuffer, src, size):
491
+ def _get_temp_buf():
492
+ # Check if the next buffer is safe to be used (its signal has passed) and reserve it.
493
+ if self.b_timeline[(self.b_next + 1) % len(self.b)] <= self.device.timeline_signal.value:
494
+ self.b_timeline[(self.b_next + 1) % len(self.b)], self.b_next = (1 << 64), (self.b_next + 1) % len(self.b)
495
+ return (self.b[self.b_next].va_addr, self.b_next)
496
+ return None
497
+
498
+ with hcq_profile(self.device, queue_type=self.device.hw_copy_queue_t, desc=f"DISK -> {self.device.dname}", enabled=PROFILE):
499
+ for (batch_info, dst_off, src_off, copy_size) in src.device.allocator._copyout_sharded(src, size, _get_temp_buf, seg_len=self.b[0].size):
500
+ self.device.hw_copy_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1) \
501
+ .copy(dest.va_addr + dst_off, batch_info[0] + src_off, copy_size) \
502
+ .signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
503
+ self.b_timeline[batch_info[1]] = self.device.timeline_value
504
+ self.device.timeline_value += 1
505
+
506
+ def copyout(self, dest:memoryview, src:HCQBuffer):
507
+ self.device.synchronize()
508
+
509
+ with hcq_profile(self.device, queue_type=self.device.hw_copy_queue_t, desc=f"{self.device.dname} -> CPU", enabled=PROFILE):
510
+ for i in range(0, dest.nbytes, self.b[0].size):
511
+ self.device.hw_copy_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1) \
512
+ .copy(self.b[0].va_addr, src.va_addr+i, lsize:=min(self.b[0].size, dest.nbytes-i)) \
513
+ .signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
514
+ self.device.timeline_signal.wait(self.device.timeline_value)
515
+ self.device.timeline_value += 1
516
+
517
+ ctypes.memmove(from_mv(dest[i:]), self.b[0].va_addr, lsize)
518
+
519
+ def transfer(self, dest:HCQBuffer, src:HCQBuffer, sz:int, src_dev, dest_dev):
520
+ src_dev.allocator.map(dest)
521
+
522
+ with hcq_profile(src_dev, queue_type=src_dev.hw_copy_queue_t, desc=f"{src_dev.dname} -> {dest_dev.dname}", enabled=PROFILE):
523
+ src_dev.hw_copy_queue_t().wait(src_dev.timeline_signal, src_dev.timeline_value - 1) \
524
+ .wait(dest_dev.timeline_signal, dest_dev.timeline_value - 1) \
525
+ .copy(dest.va_addr, src.va_addr, sz) \
526
+ .signal(src_dev.timeline_signal, src_dev.timeline_value).submit(src_dev)
527
+ src_dev.timeline_value += 1
528
+
529
+ if src_dev != dest_dev:
530
+ dest_dev.hw_compute_queue_t().wait(src_dev.timeline_signal, src_dev.timeline_value - 1) \
531
+ .wait(dest_dev.timeline_signal, dest_dev.timeline_value - 1) \
532
+ .signal(dest_dev.timeline_signal, dest_dev.timeline_value).submit(dest_dev)
533
+ dest_dev.timeline_value += 1
534
+
535
+ def map(self, buf:HCQBuffer): pass
536
+
537
+ def offset(self, buf, size:int, offset:int) -> HCQBuffer:
538
+ return type(buf)(va_addr=buf.va_addr + offset, size=size, **{k:v for k,v in buf.__dict__.items() if k not in ['va_addr', 'size']},
539
+ **{x[0]:getattr(buf, x[0]) for x in getattr(buf, '_fields_', []) if x[0] not in ['va_addr', 'size']}, _base=buf)
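The new tinygrad/runtime/support/hcq.py above only defines the abstract HCQ pieces, so here is a minimal usage sketch of the timeline-signal pattern the file uses throughout. It assumes `dev` is an instance of an HCQCompiled backend (for example the AMD or NV runtimes) and `prg` is an HCQProgram already compiled for it; the buffers and work sizes are placeholders.

```python
# High level: HCQProgram.__call__ builds the queue internally and returns the kernel time
# in seconds when wait=True.
elapsed_s = prg(buf_a, buf_b, global_size=(16, 1, 1), local_size=(32, 1, 1), wait=True)

# Low level: the same timeline-signal pattern used throughout hcq.py.
q = dev.hw_compute_queue_t()
q.wait(dev.timeline_signal, dev.timeline_value - 1)                      # cmd 0: order after prior work
q.exec(prg, prg.fill_kernargs((buf_a, buf_b)), (16, 1, 1), (32, 1, 1))   # cmd 1: launch the kernel
q.signal(dev.timeline_signal, dev.timeline_value)                        # cmd 2: publish completion
q.submit(dev)
dev.timeline_value += 1
dev.timeline_signal.wait(dev.timeline_value - 1)                         # host blocks until the GPU signals

# Because @hcq_command records each command's offset, a built queue can be patched
# in place and resubmitted (this is what runtime/graph/hcq.py relies on).
q.update_wait(cmd_idx=0, value=dev.timeline_value - 1)
q.update_exec(cmd_idx=1, global_size=(32, 1, 1))
q.update_signal(cmd_idx=2, value=dev.timeline_value).submit(dev)
dev.timeline_value += 1
```

Backends only implement the underscore hooks (`_signal`, `_wait`, `_exec`, `_submit`, ...); the public methods record command offsets and metadata, which is what lets the `update_*` calls patch the packet stream in place.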
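ProfileLogger above writes Chrome/Perfetto trace-event JSON. A small sketch of what a single `add_event` call becomes; the kernel name, device name, and timestamps here are made up, and the output lands wherever PROFILEPATH points (or its default if unset).

```python
from tinygrad.runtime.support.hcq import ProfileLogger

logger = ProfileLogger()
# one kernel execution on a hypothetical device "NV:0", COMPUTE queue, 100us..250us
logger.add_event("r_16_32", 100, 250, "NV:0", subactor="COMPUTE")
del logger  # last writer: __del__ dumps {"traceEvents": [...]} to PROFILEPATH

# besides the process_name/thread_name metadata rows, the event itself becomes roughly:
# {"name": "r_16_32", "ph": "X", "pid": 0, "tid": 1, "ts": 100, "dur": 150, "args": null}
```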
tinygrad/shape/shapetracker.py
@@ -1,21 +1,13 @@
1
1
  # ShapeTracker allows movement operations to a buffer that don't require a copy to be made.
2
2
  from __future__ import annotations
3
3
  from dataclasses import dataclass
4
- from typing import Tuple, List, Optional, Dict, Set, Iterable, cast
4
+ from typing import Tuple, List, Optional, Dict, Set
5
5
  from tinygrad.helpers import merge_dicts, getenv
6
- from tinygrad.shape.symbolic import Variable, MulNode, Node, SumNode, NumNode, create_lt_node, create_ge_node, sint
7
6
  from tinygrad.shape.view import View, strides_for_shape
7
+ from tinygrad.dtype import dtypes
8
+ from tinygrad.ops import UOp, Ops, graph_rewrite, split_uop, symbolic_flat, Variable, sint, uop_given_valid, simplify_valid
8
9
 
9
- def _expr_view(view:View, idxs:List[Node], valid:Optional[Node]=None) -> Tuple[Node, Node]:
10
- assert len(idxs) == len(view.shape), f"need an idx for all dimensions {idxs} vs {view.shape}"
11
- iexpr: List[Node] = [NumNode(view.offset) if isinstance(view.offset, int) else view.offset]
12
- vexpr: List[Node] = [valid] if valid is not None else []
13
- for idx,sh,st,m in zip(idxs, view.shape, view.strides, view.mask if view.mask is not None else [None]*len(view.shape)):
14
- if sh != 1 and st != 0: iexpr.append(idx*st)
15
- if m is not None: vexpr += [create_ge_node(idx, m[0]), create_lt_node(idx, m[1])] # idx >= m[0], idx < m[1]
16
- return Node.sum(iexpr), Node.ands(vexpr)
17
-
18
- @dataclass(frozen=True)
10
+ @dataclass(frozen=True, order=True)
19
11
  class ShapeTracker:
20
12
  views: Tuple[View, ...]
21
13
 
@@ -32,7 +24,7 @@ class ShapeTracker:
32
24
  return ShapeTracker(tuple(inverted_views)).reshape(out_shape)
33
25
 
34
26
  @staticmethod
35
- def from_shape(shape:Tuple[sint, ...]): return ShapeTracker((View.create(shape),))
27
+ def from_shape(shape:Tuple[sint, ...]) -> ShapeTracker: return ShapeTracker((View.create(shape),))
36
28
 
37
29
  @property
38
30
  def contiguous(self) -> bool: return len(self.views) == 1 and self.views[0].contiguous
@@ -46,17 +38,29 @@ class ShapeTracker:
46
38
  @property
47
39
  def size(self) -> int: return self.views[-1].size()
48
40
 
41
+ def reduce(self, axis:Tuple[int, ...]) -> Tuple[sint, ...]: return tuple(1 if i in axis else s for i,s in enumerate(self.shape))
42
+
43
+ def to_uop(self) -> UOp: return UOp(Ops.VIEW, dtypes.void, (), self)
44
+
45
+ def to_indexed_uops(self, _idxs:Optional[List[UOp]]=None) -> Tuple[UOp, UOp]:
46
+ idx, valid = self.views[-1].to_indexed_uops(_idxs)
47
+ for view in reversed(self.views[0:-1]):
48
+ view = view.minify()
49
+ acc, idxs = 1, []
50
+ for d in reversed(view.shape):
51
+ idxs.append((idx//acc)%d)
52
+ acc *= d
53
+ idx, valid = view.to_indexed_uops(idxs[::-1], valid)
54
+ return idx, valid
55
+
49
56
  def real_size(self) -> int:
50
57
  if 0 in self.shape: return 0
51
- idx, valid = self.expr_idxs()
52
- if not valid: return 0
53
- # TODO: it's possible that the real_size is smaller condition on valid being true
54
- ret = idx.max
55
- if not isinstance(ret, int): ret = ret.max # might be represent by symbolic shape, take one more max for int max
56
- assert isinstance(ret, int), f"ret must be integer, {ret=} isn't"
57
- return ret+1
58
+ idx, valid = self.to_indexed_uops()
59
+ if not valid.vmax: return 0
60
+ assert idx.vmax < 1e12, f"real_size broken for {self}"
61
+ return int(idx.vmax+1)
58
62
 
59
- def vars(self) -> Set[Variable]: return set.union(*[v.vars() for v in self.views], set())
63
+ def vars(self) -> Set[Variable]: return set().union(*[v.vars() for v in self.views])
60
64
 
61
65
  @property
62
66
  def var_vals(self) -> Dict[Variable, int]: return merge_dicts([dict([v.unbind()]) for v in self.vars()])
@@ -68,40 +72,26 @@ class ShapeTracker:
68
72
  # NOTE: if a stride is not always valid, it will be None
69
73
  def real_strides(self, ignore_valid=False) -> Tuple[Optional[sint], ...]:
70
74
  if len(self.views) == 1 and self.views[-1].mask is None: return self.views[-1].strides
71
- idxs: List[Node] = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)]
72
- idx, valid = self.expr_idxs(idxs)
73
- ret: List[Optional[sint]] = [None] * len(self.views[-1].shape)
74
- bad_idx_vars: Set[Variable] = set()
75
- for this_dim in (idx.nodes if isinstance(idx, SumNode) else [idx]):
76
- idx_maybe, stride_maybe = (this_dim.a, this_dim.b) if isinstance(this_dim, MulNode) else (this_dim, 1)
77
- try: ret[idxs.index(idx_maybe)] = cast(sint, stride_maybe)
78
- except ValueError: bad_idx_vars = bad_idx_vars.union(idx_maybe.vars())
79
- idx_vars, valid_vars = idx.vars(), valid.vars()
80
- for i,tidx in enumerate(idxs):
81
- if tidx in bad_idx_vars or (tidx in valid_vars and not ignore_valid): ret[i] = None
82
- elif tidx not in idx_vars: ret[i] = 0
75
+ ret: List[Optional[sint]] = [None] * len(self.shape)
76
+ idx, valid = (graph_rewrite(u, symbolic_flat) for u in self.to_indexed_uops())
77
+ # TODO: always apply these in to_indexed_uops?
78
+ if (newvalid:=simplify_valid(valid)) is not None: valid = newvalid
79
+ if (newidx:=uop_given_valid(valid, idx)) is not None: idx = graph_rewrite(newidx, symbolic_flat)
80
+ for c in split_uop(idx, Ops.ADD):
81
+ if c.op is Ops.RANGE: ret[c.arg[0]] = 1
82
+ if c.op is Ops.MUL and c.src[0].op is Ops.RANGE and c.src[1].op is Ops.CONST: ret[c.src[0].arg[0]] = c.src[1].arg
83
+ if c.op is Ops.MUL and c.src[1].op is Ops.RANGE and c.src[0].op is Ops.CONST: ret[c.src[1].arg[0]] = c.src[0].arg
84
+ used_ranges = [x.arg[0] for x in idx.sparents if x.op is Ops.RANGE]
85
+ ret = [x if i in used_ranges else 0 for i,x in enumerate(ret)]
86
+ if not ignore_valid:
87
+ for masked_axis in [x.arg[0] for x in valid.sparents if x.op is Ops.RANGE]: ret[masked_axis] = None
83
88
  return tuple(ret)
84
89
 
85
90
  def unit_stride_axes(self, ignore_valid=False) -> List[int]: return [i for i,st in enumerate(self.real_strides(ignore_valid)) if st == 1]
86
91
 
87
- def expr_idxs(self, idxs:Optional[Iterable[Node]]=None) -> Tuple[Node, Node]:
88
- idxs = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)] if idxs is None else list(idxs)
89
- idx, valid = _expr_view(self.views[-1], idxs)
90
- for view in reversed(self.views[0:-1]):
91
- if valid.max == 0: return NumNode(-1), valid
92
- view = view.minify()
93
- acc, idxs = 1, []
94
- for d in reversed(view.shape):
95
- idxs.append((idx//acc)%d)
96
- acc *= d
97
- idx, valid = _expr_view(view, idxs[::-1], valid)
98
- assert not isinstance(idx.min, int) or idx.min >= -2**31, f"idx.min too small. {idx=}, {idx.min=}"
99
- assert not isinstance(idx.max, int) or idx.max < 2**31, f"idx.max too big. {idx=}, {idx.max=}"
100
- return idx, valid
101
-
102
92
  def axis_is_masked(self, axis:int) -> bool:
103
- _, valid = self.expr_idxs()
104
- return f'idx{axis}' in [v.expr for v in valid.vars()]
93
+ _, valid = self.to_indexed_uops()
94
+ return axis in [x.arg[0] for x in graph_rewrite(valid, symbolic_flat).sparents if x.op is Ops.RANGE]
105
95
 
106
96
  def simplify(self) -> ShapeTracker:
107
97
  if len(self.views) >= 2 and (new_view := self.views[-2] + self.views[-1]) is not None:
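The shapetracker.py hunks above replace the deleted symbolic Node machinery (`expr_idxs`, tinygrad/shape/symbolic.py) with UOp graphs. A small sketch of the new entry points; it assumes `View.create` keeps its usual `(shape, strides, offset, mask)` keyword signature, which is not shown in this diff.

```python
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View

# a 4x4 buffer viewed with swapped strides, i.e. a transposed view
st = ShapeTracker((View.create((4, 4), strides=(1, 4)),))

print(st.real_strides())     # expected (1, 4): per-axis strides of the view
print(st.real_size())        # expected 16: max linear index + 1, now taken from idx.vmax

# the index and validity expressions are now UOp graphs instead of symbolic Nodes
idx, valid = st.to_indexed_uops()
print(st.axis_is_masked(0))  # expected False: this view has no mask
```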