tinygrad-0.8.0-py3-none-any.whl → tinygrad-0.9.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. tinygrad/__init__.py +6 -6
  2. tinygrad/codegen/kernel.py +230 -190
  3. tinygrad/codegen/linearizer.py +278 -384
  4. tinygrad/codegen/uops.py +415 -0
  5. tinygrad/device.py +132 -275
  6. tinygrad/dtype.py +53 -37
  7. tinygrad/engine/__init__.py +0 -0
  8. tinygrad/engine/graph.py +100 -0
  9. tinygrad/engine/jit.py +195 -0
  10. tinygrad/engine/realize.py +191 -0
  11. tinygrad/engine/schedule.py +362 -0
  12. tinygrad/engine/search.py +196 -0
  13. tinygrad/{mlops.py → function.py} +28 -14
  14. tinygrad/helpers.py +72 -43
  15. tinygrad/lazy.py +141 -240
  16. tinygrad/multi.py +169 -0
  17. tinygrad/nn/__init__.py +179 -8
  18. tinygrad/nn/datasets.py +7 -0
  19. tinygrad/nn/optim.py +106 -28
  20. tinygrad/nn/state.py +86 -17
  21. tinygrad/ops.py +70 -44
  22. tinygrad/renderer/__init__.py +61 -0
  23. tinygrad/renderer/assembly.py +276 -0
  24. tinygrad/renderer/cstyle.py +299 -206
  25. tinygrad/renderer/llvmir.py +118 -123
  26. tinygrad/runtime/autogen/amd_gpu.py +1900 -0
  27. tinygrad/runtime/autogen/comgr.py +865 -0
  28. tinygrad/runtime/autogen/cuda.py +5923 -0
  29. tinygrad/runtime/autogen/hip.py +5909 -0
  30. tinygrad/runtime/autogen/hsa.py +5761 -0
  31. tinygrad/runtime/autogen/kfd.py +812 -0
  32. tinygrad/runtime/autogen/nv_gpu.py +33328 -0
  33. tinygrad/runtime/autogen/opencl.py +1795 -0
  34. tinygrad/runtime/driver/hip_comgr.py +47 -0
  35. tinygrad/runtime/driver/hsa.py +143 -0
  36. tinygrad/runtime/graph/clang.py +38 -0
  37. tinygrad/runtime/graph/cuda.py +59 -54
  38. tinygrad/runtime/graph/hcq.py +143 -0
  39. tinygrad/runtime/graph/hsa.py +171 -0
  40. tinygrad/runtime/graph/metal.py +37 -41
  41. tinygrad/runtime/ops_amd.py +564 -0
  42. tinygrad/runtime/ops_clang.py +16 -14
  43. tinygrad/runtime/ops_cuda.py +130 -38
  44. tinygrad/runtime/ops_disk.py +45 -42
  45. tinygrad/runtime/ops_gpu.py +52 -50
  46. tinygrad/runtime/ops_hsa.py +278 -0
  47. tinygrad/runtime/ops_llvm.py +36 -56
  48. tinygrad/runtime/ops_metal.py +42 -24
  49. tinygrad/runtime/ops_npy.py +9 -0
  50. tinygrad/runtime/ops_nv.py +630 -0
  51. tinygrad/runtime/ops_python.py +204 -0
  52. tinygrad/shape/shapetracker.py +41 -105
  53. tinygrad/shape/symbolic.py +98 -95
  54. tinygrad/shape/view.py +137 -35
  55. tinygrad/tensor.py +2367 -442
  56. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
  57. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/METADATA +19 -9
  58. tinygrad-0.9.0.dist-info/RECORD +60 -0
  59. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
  60. tinygrad/features/image.py +0 -93
  61. tinygrad/features/multi.py +0 -103
  62. tinygrad/features/search.py +0 -160
  63. tinygrad/graph.py +0 -106
  64. tinygrad/jit.py +0 -152
  65. tinygrad/realize.py +0 -50
  66. tinygrad/runtime/graph/hip.py +0 -24
  67. tinygrad/runtime/ops_cpu.py +0 -45
  68. tinygrad/runtime/ops_hip.py +0 -97
  69. tinygrad/runtime/ops_torch.py +0 -49
  70. tinygrad-0.8.0.dist-info/RECORD +0 -41
  71. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
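
The list above is mostly a module reorganization: the top-level jit.py, realize.py, graph.py and the features/ package are removed, a new engine/ package and a top-level multi.py appear with comparable line counts, and mlops.py is renamed to function.py. A minimal sketch of the implied old-to-new module paths follows; only the mlops.py → function.py rename is stated explicitly by this diff, so the rest of the mapping is an inference from matching deletions against additions, not something the listing confirms.

# Hedged sketch: probable 0.8.0 -> 0.9.0 module moves inferred from the file list above.
# Only mlops -> function is shown as a rename; the other pairs are assumptions.
MODULE_MOVES = {
  "tinygrad.mlops":           "tinygrad.function",
  "tinygrad.jit":             "tinygrad.engine.jit",
  "tinygrad.realize":         "tinygrad.engine.realize",
  "tinygrad.graph":           "tinygrad.engine.graph",
  "tinygrad.features.search": "tinygrad.engine.search",
  "tinygrad.features.multi":  "tinygrad.multi",
}
for old, new in MODULE_MOVES.items():
  print(f"{old:26} -> {new}")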
tinygrad/jit.py DELETED
@@ -1,152 +0,0 @@
- from __future__ import annotations
- from typing import Callable, List, Tuple, Dict, cast, Union, Optional, TypeVar, Generic
- import functools, itertools, operator
- from tinygrad.dtype import DType
- from tinygrad.helpers import DEBUG, merge_dicts, getenv, all_int, Context, GRAPH
- from tinygrad.device import Device, JITRunner, CompiledASTRunner, Buffer
- from tinygrad.tensor import Tensor
- from tinygrad.lazy import LazyBuffer
- from tinygrad.shape.shapetracker import ShapeTracker
- from tinygrad.shape.symbolic import Variable, NumNode, Node
- from weakref import ref, WeakKeyDictionary
- from dataclasses import dataclass
-
- @dataclass(frozen=True)
- class JitItem:
-   prg: JITRunner # or a graph executor like MetalGraph
-   rawbufs: List[Optional[Buffer]]
-
- def get_jit_stats(jit_cache: List[JitItem]) -> Tuple[Node, Node]:
-   return functools.reduce(operator.__add__, [ji.prg.op_estimate for ji in jit_cache], NumNode(0)), functools.reduce(operator.__add__, [ji.prg.mem_estimate for ji in jit_cache], NumNode(0)) # noqa: E501
- def get_input_replace(jit_cache: List[JitItem], input_rawbuffers:List[Buffer]) -> Dict[Tuple[int, int], int]:
-   input_replace: Dict[Tuple[int, int], int] = {}
-   for j,ji in enumerate(jit_cache):
-     for i,a in enumerate(ji.rawbufs):
-       if a in input_rawbuffers:
-         input_replace[(j,i)] = input_rawbuffers.index(a)
-   return input_replace
- def get_jc_idxs_with_updatable_launch_dims(jit_cache: List[JitItem]) -> List[int]:
-   return [j for j,ji in enumerate(jit_cache) if isinstance(ji.prg, CompiledASTRunner) and ((ji.prg.global_size and not all_int(tuple(ji.prg.global_size))) or (ji.prg.local_size and not all_int(tuple(ji.prg.local_size))))] # noqa: E501
- def get_jc_idxs_with_updatable_var_vals(jit_cache: List[JitItem]) -> List[int]:
-   return [j for j,ji in enumerate(jit_cache) if isinstance(ji.prg, CompiledASTRunner) and ji.prg.vars]
-
- class GraphException(Exception): pass
-
- ReturnType = TypeVar('ReturnType')
- class TinyJit(Generic[ReturnType]):
-   def __init__(self, fxn:Callable[..., ReturnType]):
-     self.fxn = fxn
-     self.reset()
-
-   def reset(self):
-     self.jit_cache: List[JitItem] = []
-     self.input_replace: Dict[Tuple[int, int], int] = {}
-     self.cnt: int = 0
-     self.ret: Optional[ReturnType] = None
-     self.expected_vals: Optional[Tuple[Variable, ...]] = None
-     self.expected_name_sts_dtype: Optional[Tuple[Tuple[Union[int, str], ShapeTracker, DType], ...]] = None
-
-   # add support for instance methods
-   def __get__(self, obj, objtype): return functools.partial(self.__call__, obj)
-
-   def __call__(self, *args, **kwargs) -> ReturnType:
-     # all inputs (except const) are realized
-     input_tensors: Dict[Union[int, str], LazyBuffer] = {cast(Union[int, str], k):v.realize().lazydata for k,v in itertools.chain(enumerate(args), kwargs.items()) if v.__class__ is Tensor} # noqa: E501
-     assert all(isinstance(x, LazyBuffer) for x in input_tensors.values()), "multilazybuffer JIT isn't supported"
-     expected_name_sts_dtype = tuple([(k, v.st.unbind(), v.dtype) for k,v in input_tensors.items()])
-
-     # get rawbuffers
-     # TODO: why can .realized have Any type?
-     input_rawbuffers: List[Buffer] = [v.base.realized for v in input_tensors.values() if v.base.realized is not None]
-     assert len(set(input_rawbuffers)) == len(input_rawbuffers), "duplicate inputs to JIT"
-
-     # get variables: they can either be in Tensors or passed in as arguments, and all must be bound. these are all global
-     var_vals: Dict[Variable, int] = merge_dicts([arg.st.var_vals for arg in input_tensors.values()] + [dict(x.unbind() for x in itertools.chain(args, kwargs.values()) if isinstance(x, Variable))]) # noqa: E501
-     expected_vals = tuple(var_vals.keys())
-
-     if self.cnt >= 2:
-       # jit exec
-       assert self.expected_vals == expected_vals, "mismatch of var_vals"
-       assert self.expected_name_sts_dtype == expected_name_sts_dtype, f"mismatch of sts, expected {self.expected_name_sts_dtype} got {expected_name_sts_dtype}" # noqa: E501
-       for (j,i),input_idx in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_idx]
-       for ji in self.jit_cache: ji.prg(cast(List[Buffer], ji.rawbufs), var_vals, wait=DEBUG>=2, jit=True)
-     elif self.cnt == 1:
-       # jit capture
-       self.expected_vals, self.expected_name_sts_dtype = expected_vals, expected_name_sts_dtype
-       CacheCollector.start(var_vals)
-       with Context(GRAPH=getenv("JITGRAPH", GRAPH.value)):
-         self.ret = self.fxn(*args, **kwargs)
-       self.jit_cache = CacheCollector.finish()
-       assert len(self.jit_cache) != 0, "didn't JIT anything!"
-       if DEBUG >= 1 and len(set(get_input_replace(self.jit_cache, input_rawbuffers).values())) != len(input_rawbuffers):
-         print("WARNING: some input tensors not found")
-       if DEBUG >= 1: print(f"JIT captured {len(self.jit_cache)} kernels with {len(input_rawbuffers)} inputs")
-
-       # if your Device supports it, condense the items into a graph executor.
-       if (make_graph := Device[Device.DEFAULT].graph) and getenv("JIT") != 2 and len(self.jit_cache) > 1:
-         # Split JIT cache into batches for faster graph execution.
-         # This allows the accelerator to run some batches while subsequent graphs are still being updated.
-         graphed_jit_cache, current_batch = [], []
-         for i,ji in enumerate(self.jit_cache):
-           # If the jit item can potentially be graphed, put it in a batch.
-           if isinstance(ji.prg, CompiledASTRunner): current_batch.append(ji)
-
-           # The flush is done when (1) ji is the last one, (2) the size of batch exceeds the maximum batch size or
-           # (3) the current jit item cannot be graphed, so the current batch is flushed before such a jit item is added.
-           if len(current_batch) > 0 and (i==len(self.jit_cache)-1 or len(current_batch) >= getenv("JIT_BATCH_SIZE", 64) or not isinstance(ji.prg, CompiledASTRunner)): # noqa: E501
-             try:
-               graphed_jit_cache.append(JitItem(make_graph(current_batch, input_rawbuffers, var_vals), cast(List[Optional[Buffer]], input_rawbuffers)))
-               if DEBUG >= 2: print(f"\tJIT GRAPHing batch with {len(current_batch)} kernels")
-             except GraphException as e:
-               graphed_jit_cache.extend(current_batch)
-               if DEBUG >= 2: print(f"\tJIT GRAPHing failed batch with {len(current_batch)} kernels: {e}")
-             current_batch = []
-
-           # If the jit item cannot be graphed, put it right into the final cache after the flush.
-           if not isinstance(ji.prg, CompiledASTRunner): graphed_jit_cache.append(ji)
-
-         self.jit_cache = graphed_jit_cache
-
-       self.input_replace = get_input_replace(self.jit_cache, input_rawbuffers)
-     elif self.cnt == 0:
-       # jit ignore
-       self.ret = self.fxn(*args, **kwargs)
-
-     # clear jit inputs
-     for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
-
-     self.cnt += 1
-     return cast(ReturnType, self.ret)
-
- class PlaceHolder:
-   def __init__(self, buf:Buffer): self.size, self.dtype, self.device, self.ref, self.bufid = buf.size, buf.dtype, buf.device, ref(buf), id(buf._buf)
-   def to_tuple(self): return (self.size, self.dtype, self.device, self.bufid)
-   def __hash__(self): return hash(self.to_tuple())
-   def __eq__(self, x): return isinstance(x, PlaceHolder) and self.to_tuple() == x.to_tuple()
-   def alloc_if_needed(self, buffer_cache: Dict[PlaceHolder, Buffer]) -> Buffer:
-     ret = self.ref()
-     if ret: return ret
-     if self not in buffer_cache: buffer_cache[self] = Buffer(self.device, self.size, self.dtype)
-     return buffer_cache[self]
-
- class _CacheCollector:
-   def __init__(self):
-     self.cache: Optional[List[Tuple[JITRunner, List[Union[Buffer, PlaceHolder]]]]] = None
-
-   def start(self, var_vals:Optional[Dict[Variable, int]]=None):
-     self.cache = []
-     self.placeholders: WeakKeyDictionary[Buffer, PlaceHolder] = WeakKeyDictionary()
-     self.var_vals = var_vals if var_vals is not None else {}
-
-   def add(self, prg, rawbufs, var_vals):
-     if self.cache is None: return
-     for k,v in var_vals.items(): assert k in self.var_vals and self.var_vals[k] == v, f"var_vals {k} mismatch {v} != {self.var_vals.get(k)}"
-     self.placeholders[rawbufs[0]] = PlaceHolder(rawbufs[0]) # NOTE: this is making an assumption that 0 is special
-     self.cache.append((prg, [self.placeholders.get(x, x) if isinstance(x, Buffer) else x for x in rawbufs]))
-
-   def finish(self) -> List[JitItem]:
-     if self.cache is None: return []
-     buffer_cache: Dict[PlaceHolder, Buffer] = {}
-     saved_cache, self.cache = self.cache, None
-     return [JitItem(prg, [x.alloc_if_needed(buffer_cache) if isinstance(x, PlaceHolder) else x for x in pl]) for prg, pl in saved_cache]
- CacheCollector = _CacheCollector()
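
For reference, a small usage sketch of the TinyJit decorator removed above, written against this 0.8.0 implementation (not the new tinygrad/engine/jit.py): the first call runs eagerly (cnt == 0), the second captures the launched kernels via CacheCollector (cnt == 1), and every later call replays the cached kernels with the new input buffers swapped in. Input shapes, dtypes and bound variables must stay constant across calls, as the asserts above require.

from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit  # the 0.8.0 import path deleted by this diff

@TinyJit
def step(x: Tensor, w: Tensor) -> Tensor:
  # outputs must be realized inside the jitted function so their kernels get captured
  return (x @ w).relu().realize()

w = Tensor.rand(16, 16)
for _ in range(4):  # call 1: eager, call 2: capture, calls 3+: replay from the cache
  out = step(Tensor.rand(4, 16), w)
print(out.numpy().shape)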
tinygrad/realize.py DELETED
@@ -1,50 +0,0 @@
- from typing import List, Dict, Optional, cast
- from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, GlobalCounters
- from tinygrad.device import Device, Buffer, BufferCopy, JITRunner, update_stats, InterpretedASTRunner
- from tinygrad.graph import print_tree, realized_lazybuffer
- from tinygrad.helpers import colored, getenv
- from tinygrad.shape.symbolic import Variable
-
- # *** schedule running ***
-
- class CustomOp(JITRunner):
-   def __init__(self, fxn):
-     self.fxn = fxn
-     super().__init__()
-   def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False): self.fxn(*rawbufs)
-
- def lower_schedule_item(si:ScheduleItem) -> Optional[JITRunner]:
-   assert all(si.out.device == x.device for x in si.inputs) or si.ast.op is LoadOps.COPY, f"all devices must be the same, {si.out.device} != {[x.device for x in si.inputs]} {print_tree(si.ast) or ''}" # noqa: E501
-   if si.ast.op is LoadOps.EMPTY: return None
-   if si.ast.op is LoadOps.COPY: return BufferCopy
-   if si.ast.op is LoadOps.CUSTOM: return CustomOp(si.ast.arg)
-   return Device[si.out.device].get_runner(si.ast)
-
- logops = open(getenv("LOGOPS", ""), "a") if getenv("LOGOPS", "") else None
- def run_schedule(schedule:List[ScheduleItem]):
-   while len(schedule):
-     si = schedule.pop(0)
-     if logops and si.ast.op not in LoadOps: logops.write(str(si.ast)+"\n")
-
-     # get the program
-     prg = lower_schedule_item(si)
-
-     # invalidate the output buffer if there's a non contig usage of it in inputs
-     if si.out.output_buffer is not None:
-       for i,a in enumerate(si.inputs):
-         if a.realized == si.out.output_buffer:
-           if any(not x.arg.st.contiguous for x in si.ast.lazyops if x.op == BufferOps.LOAD and x.arg.idx == i+1):
-             si.out.output_buffer = None
-             break
-
-     # we don't have an output buffer, we have to create it, and create to max size if it has symbolic shape
-     si.out.realized = si.out.output_buffer if si.out.output_buffer is not None else \
-       Buffer(si.out.device, si.out.size, si.out.dtype,
-              "PLACEHOLDER" if isinstance(prg, InterpretedASTRunner) else None)
-     del si.out.srcs
-
-     # run the function (put it in JIT)
-     assert all(x.realized is not None for x in si.inputs), f"can't run, some inputs aren't realized {[x for x in si.inputs if x.realized is None]}"
-     if prg: prg.exec([si.out.realized] + [cast(Buffer, x.realized) for x in si.inputs], si.var_vals)
-     else: update_stats(colored(f"empty {si.out.st.size:10d} {si.out.dtype}", "yellow"), 0, 0, {}, None, 1, device=si.out.device)
-     realized_lazybuffer(si.out, GlobalCounters.kernel_count)
tinygrad/runtime/graph/hip.py DELETED
@@ -1,24 +0,0 @@
- import ctypes
- from typing import Tuple
- import gpuctypes.hip as hip
- from tinygrad.helpers import init_c_var
- from tinygrad.runtime.ops_hip import check, hip_time_execution
- from tinygrad.runtime.graph.cuda import CUDAGraph
-
- class HIPGraph(CUDAGraph):
-   def __del__(self):
-     check(hip.hipGraphDestroy(self.graph))
-     check(hip.hipGraphExecDestroy(self.instance))
-
-   def encode_args_info(self): return (hip.hipDeviceptr_t, (1,2,3))
-   def graph_create(self): return init_c_var(hip.hipGraph_t(), lambda x: check(hip.hipGraphCreate(ctypes.byref(x), 0)))
-   def graph_instantiate(self, graph):
-     return init_c_var(hip.hipGraphExec_t(), lambda x: check(hip.hipGraphInstantiate(ctypes.byref(x), graph, None, None, 0)))
-   def graph_add_kernel_node(self, graph, c_deps, c_params):
-     return init_c_var(hip.hipGraphNode_t(), lambda x: check(hip.hipGraphAddKernelNode(ctypes.byref(x), graph, c_deps, ctypes.sizeof(c_deps)//8 if c_deps else 0, ctypes.byref(c_params)))) # noqa: E501
-   def graph_launch(self, *args, wait=False): return hip_time_execution(lambda: check(hip.hipGraphLaunch(*args)), enable=wait)
-   def graph_exec_kernel_node_set_params(self, *args): return check(hip.hipGraphExecKernelNodeSetParams(*args))
-   def build_kernel_node_params(self, prg, global_size, local_size, c_config):
-     return hip.hipKernelNodeParams(hip.dim3(*local_size), c_config, ctypes.cast(prg.clprg.prg, ctypes.c_void_p), hip.dim3(*global_size), None, 0)
-   def set_kernel_node_launch_dims(self, node, global_size: Tuple[int, int, int], local_size: Tuple[int, int, int]):
-     node.blockDim.x, node.blockDim.y, node.blockDim.z, node.gridDim.x, node.gridDim.y, node.gridDim.z = *local_size, *global_size
tinygrad/runtime/ops_cpu.py DELETED
@@ -1,45 +0,0 @@
- import numpy as np
- from typing import Callable, Dict, Tuple
- from tinygrad.helpers import flat_mv
- from tinygrad.ops import BufferOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, Op
- from tinygrad.device import Interpreted, Allocator
-
- def reduce_axis(in_shape:Tuple[int, ...], out_shape:Tuple[int, ...]) -> Tuple[int, ...]:
-   assert len(in_shape) == len(out_shape), "reduce shapes must have same dimensions"
-   return tuple(i for i,(a,b) in enumerate(zip(in_shape, out_shape)) if a != b)
-
- def einsum_mulacc(einsum, get_strides, expand):
-   def einscripts(x): return ''.join(["abcdefghijklmnopqrstuvwxyz"[i] for i in x])
-   def axes_slice(strides): return tuple(i for i,s in enumerate(strides) if s != 0), tuple(slice(None) if s != 0 else 0 for s in strides)
-   def mulacc(a, b, new_shape):
-     (a_axes, a_slices), (b_axes, b_slices) = axes_slice(get_strides(a)), axes_slice(get_strides(b))
-     out = [i for i in range(len(new_shape)) if a.shape[i] == new_shape[i] and (i in a_axes or i in b_axes)]
-     ret = einsum(f"{einscripts(a_axes)}, {einscripts(b_axes)} -> {einscripts(out)}", a[a_slices], b[b_slices])
-     return expand(ret.reshape(tuple(1 if i not in a_axes and i not in b_axes else s for i,s in enumerate(new_shape))), new_shape)
-   return mulacc
-
- def as_strided(x, arg):
-   shape, stride, offset = arg
-   return np.ndarray(shape, x.dtype, buffer=np.require(x, requirements='C'), offset=offset*x.dtype.itemsize,
-                     strides=tuple(y*x.dtype.itemsize for y in stride))
-
- numpy_fxn_for_op: Dict[Op, Callable] = {
-   BufferOps.CONST: lambda val, dtype: np.array(val, dtype=dtype.np),
-   UnaryOps.EXP2: np.exp2, UnaryOps.LOG2: np.log2, UnaryOps.SIN: np.sin, UnaryOps.SQRT: np.sqrt,
-   UnaryOps.CAST: lambda x,y: x.view(y[0].np) if y[1] else x.astype(y[0].np, copy=False),
-   UnaryOps.NEG: lambda x: np.logical_not(x) if x.dtype == np.bool_ else np.negative(x),
-   BinaryOps.MAX: np.maximum, BinaryOps.CMPLT: np.less, BinaryOps.CMPEQ: np.equal, BinaryOps.ADD: np.add, BinaryOps.SUB: np.subtract,
-   BinaryOps.MUL: np.multiply, BinaryOps.DIV: lambda x, y: np.divide(x, y).astype(x.dtype, copy=False), BinaryOps.XOR: np.bitwise_xor,
-   ReduceOps.SUM: lambda x, new_shape: x.sum(reduce_axis(x.shape, new_shape), dtype=x.dtype, keepdims=True) if x.shape != new_shape else x,
-   ReduceOps.MAX: lambda x, new_shape: x.max(reduce_axis(x.shape, new_shape), keepdims=True) if x.shape != new_shape else x,
-   TernaryOps.MULACC: einsum_mulacc(lambda s,a,b: np.einsum(s, a.copy(), b.copy(), optimize=True), lambda x: x.strides, np.broadcast_to),
-   TernaryOps.WHERE: np.where, MovementOps.AS_STRIDED: as_strided, MovementOps.EXPAND: np.broadcast_to, MovementOps.PAD: np.pad
- }
-
- class NumpyAllocator(Allocator):
-   def _alloc(self, size:int): return np.empty(size, dtype=np.uint8)
-   def as_buffer(self, src:np.ndarray) -> memoryview: return flat_mv(np.require(src, requirements='C').data)
-   def copyin(self, dest:np.ndarray, src:memoryview): np.copyto(dest, np.frombuffer(src, dest.dtype).reshape(dest.shape))
-   def copyout(self, dest:memoryview, src:np.ndarray): np.copyto(np.frombuffer(dest, src.dtype).reshape(src.shape), src)
-
- CPUDevice = Interpreted(NumpyAllocator(), numpy_fxn_for_op)
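
A note on the helper above: reduce_axis translates tinygrad's keepdims-style reduce (input and output shapes have the same rank) into the axis tuple that numpy's sum/max expect, which is exactly how the ReduceOps lambdas use it. A self-contained demonstration, with the helper copied verbatim from the deleted file:

import numpy as np

def reduce_axis(in_shape, out_shape):  # copied from the deleted ops_cpu.py above
  assert len(in_shape) == len(out_shape), "reduce shapes must have same dimensions"
  return tuple(i for i,(a,b) in enumerate(zip(in_shape, out_shape)) if a != b)

x = np.arange(24, dtype=np.float32).reshape(2, 3, 4)
axes = reduce_axis(x.shape, (2, 1, 4))                  # -> (1,): only the middle dim shrinks
print(x.sum(axes, dtype=x.dtype, keepdims=True).shape)  # (2, 1, 4), matching the requested out_shape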
tinygrad/runtime/ops_hip.py DELETED
@@ -1,97 +0,0 @@
- from __future__ import annotations
- import ctypes, functools, subprocess, io
- from typing import Tuple, TypeVar, List
- import gpuctypes.hip as hip
- from tinygrad.helpers import DEBUG, getenv, init_c_var, compile_cuda_style, encode_args_cuda_style, time_execution_cuda_style
- from tinygrad.helpers import from_mv, round_up, to_mv
- from tinygrad.device import Compiled, LRUAllocator, MallocAllocator
- from tinygrad.renderer.cstyle import HIPRenderer
- from tinygrad.codegen.kernel import LinearizerOptions
-
- # The default HIP stream is used for everything.
- MOCKHIP = getenv("MOCKHIP") # for CI. don't run kernels, only check if they compile
-
- def check(status):
-   if status != 0: raise RuntimeError(f"HIP Error {status}, {ctypes.string_at(hip.hipGetErrorString(status)).decode()}")
-
- # TODO: remove these helpers, they increase complexity
- def hip_time_execution(cb, enable=False): return time_execution_cuda_style(cb, hip.hipEvent_t, hip.hipEventCreate, hip.hipEventRecord, hip.hipEventSynchronize, hip.hipEventDestroy, hip.hipEventElapsedTime, enable=enable) # noqa: E501
-
- def compile_hip(prg) -> bytes: return compile_cuda_style(prg, [f'--offload-arch={HIPDevice.default_arch_name}', '-I/opt/rocm/include'], hip.hiprtcProgram, hip.hiprtcCreateProgram, hip.hiprtcCompileProgram, hip.hiprtcGetCode, hip.hiprtcGetCodeSize, hip.hiprtcGetProgramLog, hip.hiprtcGetProgramLogSize, check) # noqa: E501
-
- class HIPProgram:
-   def __init__(self, device:int, name:str, lib:bytes):
-     self.device, self.name, self.lib = device, name, lib
-
-     if DEBUG >= 6:
-       asm = subprocess.check_output(["/opt/rocm/llvm/bin/llvm-objdump", '-d', '-'], input=lib)
-       print('\n'.join([x for x in asm.decode('utf-8').split("\n") if 's_code_end' not in x]))
-
-     if MOCKHIP: return
-     check(hip.hipSetDevice(self.device))
-     self.module = init_c_var(hip.hipModule_t(), lambda x: check(hip.hipModuleLoadData(ctypes.byref(x), lib)))
-     self.prg = init_c_var(hip.hipFunction_t(), lambda x: check(hip.hipModuleGetFunction(ctypes.byref(x), self.module, name.encode("utf-8"))))
-
-   def __del__(self):
-     if not MOCKHIP: check(hip.hipModuleUnload(self.module))
-
-   def __call__(self, *args, global_size:Tuple[int,int,int], local_size:Tuple[int,int,int], vals:Tuple[int, ...]=(), wait=False):
-     if MOCKHIP: return float("inf")
-     check(hip.hipSetDevice(self.device))
-     return hip_time_execution(lambda: check(hip.hipModuleLaunchKernel(self.prg, *global_size, *local_size, 0, None, None, encode_args_cuda_style(args, vals, hip.hipDeviceptr_t, marks=(1,2,3))[0])), enable=wait) # noqa: E501
-
- T = TypeVar("T")
- CHUNK_SIZE, PAGE_SIZE = 256*1024*1024, 0x1000
- class HIPAllocator(LRUAllocator):
-   def __init__(self, device:HIPDevice):
-     self.device = device
-     super().__init__()
-   def _alloc(self, size:int):
-     check(hip.hipSetDevice(self.device.device))
-     return init_c_var(hip.hipDeviceptr_t(), lambda x: check(hip.hipMalloc(ctypes.byref(x), size)))
-   def _free(self, opaque:T): check(hip.hipFree(opaque))
-   def _hostalloc(self, size:int): return init_c_var(hip.hipDeviceptr_t(), lambda x: check(hip.hipHostMalloc(ctypes.byref(x), size, 0)))
-   def copy_from_fd(self, dest, fd, offset, size):
-     check(hip.hipSetDevice(self.device.device))
-     if not hasattr(self, 'hb'): self.hb = [self._hostalloc(CHUNK_SIZE) for _ in range(2)]
-     fo = io.FileIO(fd, "a+b", closefd=False)
-     fo.seek(offset - (minor_offset:=offset % PAGE_SIZE))
-     copied_in = 0
-     for local_offset in range(0, size+minor_offset, CHUNK_SIZE):
-       local_size = min(round_up(size+minor_offset, PAGE_SIZE)-local_offset, CHUNK_SIZE)
-       fo.readinto(to_mv(self.hb[0], local_size))
-       check(hip.hipDeviceSynchronize())
-       check(hip.hipMemcpyAsync(ctypes.c_void_p(dest.value + copied_in), ctypes.c_void_p(self.hb[0].value + minor_offset),
-                                copy_size:=min(local_size-minor_offset, size-copied_in), hip.hipMemcpyHostToDevice, None))
-       copied_in += copy_size
-       self.hb = self.hb[1:] + [self.hb[0]]
-       minor_offset = 0 # only on the first
-   def copyin(self, dest:T, src: memoryview):
-     check(hip.hipSetDevice(self.device.device))
-     host_mem = self._hostalloc(len(src))
-     self.device.pending_copyin.append(host_mem)
-     ctypes.memmove(host_mem, from_mv(src), len(src))
-     check(hip.hipMemcpyAsync(dest, host_mem, len(src), hip.hipMemcpyHostToDevice, None))
-   def copyout(self, dest:memoryview, src:T):
-     check(hip.hipSetDevice(self.device.device))
-     check(hip.hipMemcpy(from_mv(dest), src, len(dest), hip.hipMemcpyDeviceToHost))
-   def transfer(self, dest:T, src:T, sz:int):
-     check(hip.hipSetDevice(self.device.device))
-     # TODO: hipMemcpyAsync, but you have to track the "src" buffer to not free it
-     check(hip.hipMemcpy(dest, src, sz, hip.hipMemcpyDeviceToDevice))
-
- class HIPDevice(Compiled):
-   default_arch_name = "gfx1100"
-   def __init__(self, device:str=""):
-     self.device = int(device.split(":")[1]) if ":" in device else 0
-     self.pending_copyin: List[hip.hipDeviceptr_t] = []
-     if self.device == 0 and not MOCKHIP: HIPDevice.default_arch_name = init_c_var(hip.hipDeviceProp_t(), lambda x: check(hip.hipGetDeviceProperties(x, self.device))).gcnArchName.decode() # noqa: E501
-
-     from tinygrad.runtime.graph.hip import HIPGraph
-     super().__init__(MallocAllocator if MOCKHIP else HIPAllocator(self), LinearizerOptions(device="HIP"), HIPRenderer,
-                      compile_hip, functools.partial(HIPProgram, self.device), HIPGraph)
-   def synchronize(self):
-     check(hip.hipSetDevice(self.device))
-     check(hip.hipDeviceSynchronize())
-     for opaque in self.pending_copyin: check(hip.hipFree(opaque))
-     self.pending_copyin.clear()
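
The trickiest part of the file above is copy_from_fd: disk reads are page-aligned and pushed through two pinned host buffers in CHUNK_SIZE pieces, with only the first chunk carrying the sub-page minor_offset. A pure-Python sketch of just that offset arithmetic (no HIP calls; plan_copy is a name invented for this example):

CHUNK_SIZE, PAGE_SIZE = 256*1024*1024, 0x1000

def plan_copy(offset, size):
  round_up = lambda x, a: (x + a - 1) // a * a
  minor_offset = offset % PAGE_SIZE          # how far into the first page the data starts
  copied_in, plan = 0, []
  for local_offset in range(0, size + minor_offset, CHUNK_SIZE):
    local_size = min(round_up(size + minor_offset, PAGE_SIZE) - local_offset, CHUNK_SIZE)
    copy_size = min(local_size - minor_offset, size - copied_in)
    plan.append((local_offset, local_size, minor_offset, copy_size))
    copied_in += copy_size
    minor_offset = 0                         # only the first chunk is misaligned
  return plan

print(plan_copy(offset=0x1234, size=3*CHUNK_SIZE))  # 4 chunks, copy sizes summing to size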
tinygrad/runtime/ops_torch.py DELETED
@@ -1,49 +0,0 @@
- import torch
- import numpy as np
- from typing import Dict, Callable
- from tinygrad.ops import BufferOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, Op
- from tinygrad.device import Interpreted, Allocator
- from tinygrad.dtype import dtypes
- from tinygrad.helpers import getenv, flatten
- from tinygrad.runtime.ops_cpu import einsum_mulacc, reduce_axis
-
- device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu"))
- type_map = {torch.bool: dtypes.bool,
-             torch.int8: dtypes.int8, torch.uint8: dtypes.uint8, torch.int16: dtypes.int16, torch.int32: dtypes.int32, torch.int64: dtypes.int64,
-             torch.float16: dtypes.float16, torch.bfloat16: dtypes.bfloat16, torch.float32: dtypes.float32, torch.float64: dtypes.float64}
- inverse_type_map = {v: k for k,v in type_map.items()}
- # TODO: should unsupported types fail instead of implicit conversion?
- inverse_type_map.update({dtypes.uint16: torch.int16, dtypes.uint32: torch.int32, dtypes.uint64: torch.int64})
- def np_type_cvt(t): return {np.uint32: np.int32, np.uint64: np.int64}.get(t, t)
-
- def as_strided(x, arg):
-   shape, stride, offset = arg
-   x = x.contiguous()
-   offset += x.storage_offset() # NOTE: contiguous can still have a storage_offset, so we adjust for it
-   if any(i < 0 for i in stride):
-     return torch.as_strided(x, shape, tuple(abs(i) for i in stride),
-                             offset + sum((s-1)*a if a < 0 else 0 for (s,a) in zip(shape, stride))).flip([i for i,a in enumerate(stride) if a < 0])
-   return torch.as_strided(x, shape, stride, offset)
-
- torch_fxn_for_op: Dict[Op, Callable] = {
-   # TODO: torch.tensor should work here. it doesn't due to "overflow" in uint8
-   #BufferOps.CONST: lambda val, dtype: torch.tensor(val, device=device, dtype=inverse_type_map[dtype]),
-   BufferOps.CONST: lambda val, dtype: torch.from_numpy(np.array(val, dtype=np_type_cvt(dtype.np))).to(device),
-   UnaryOps.EXP2: torch.exp2, UnaryOps.LOG2: torch.log2, UnaryOps.SIN: torch.sin, UnaryOps.SQRT: torch.sqrt,
-   UnaryOps.CAST: lambda x,y: (x.view if y[1] else x.type)(inverse_type_map[y[0]]),
-   UnaryOps.NEG: lambda x: torch.logical_not(x) if x.dtype is torch.bool else torch.neg(x),
-   BinaryOps.ADD: torch.add, BinaryOps.SUB: torch.sub, BinaryOps.MUL: torch.mul, BinaryOps.DIV: lambda x,y: torch.div(x, y).type(x.dtype),
-   BinaryOps.XOR: torch.bitwise_xor, BinaryOps.MAX: torch.maximum, BinaryOps.CMPLT: torch.lt, BinaryOps.CMPEQ: torch.eq,
-   ReduceOps.SUM: lambda x, new_shape: x.sum(reduce_axis(x.shape, new_shape), dtype=x.dtype, keepdims=True) if x.shape != new_shape else x,
-   ReduceOps.MAX: lambda x, new_shape: x.amax(reduce_axis(x.shape, new_shape), keepdims=True) if x.shape != new_shape else x,
-   TernaryOps.MULACC: einsum_mulacc(lambda s,a,b: torch.einsum(s, a.float(), b.float()).type(a.dtype), lambda x: x.stride(), lambda x,s: x.expand(s)),
-   TernaryOps.WHERE: torch.where, MovementOps.AS_STRIDED: as_strided, MovementOps.EXPAND: lambda x, arg: x.expand(arg),
-   MovementOps.PAD: lambda x, padding: torch.nn.functional.pad(x, flatten(padding[::-1])),
- }
-
- class TorchAllocator(Allocator):
-   def _alloc(self, size:int): return torch.empty([size], device=device, dtype=torch.uint8)
-   def copyin(self, dest:torch.Tensor, src:memoryview): dest.copy_(torch.frombuffer(src, dtype=dest.dtype))
-   def copyout(self, dest:memoryview, src:torch.Tensor): torch.frombuffer(dest, dtype=src.dtype).copy_(src.flatten())
-
- TorchDevice = Interpreted(TorchAllocator(), torch_fxn_for_op)
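
The as_strided wrapper above exists because torch.as_strided rejects negative strides: the wrapper moves the storage offset to the far end of each negatively-strided axis, views with the absolute strides, and flips those axes afterwards. A toy illustration of that rewrite (the shape/stride/offset triple is made up for the example):

import torch

x = torch.arange(6.)
shape, stride, offset = (6,), (-1,), 5                                     # a reversed view, numpy-style
adj = offset + sum((s-1)*a if a < 0 else 0 for s,a in zip(shape, stride))  # 5 + 5*(-1) = 0
y = torch.as_strided(x, shape, tuple(abs(a) for a in stride), adj).flip([0])
print(y)                                                                   # tensor([5., 4., 3., 2., 1., 0.])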
tinygrad-0.8.0.dist-info/RECORD DELETED
@@ -1,41 +0,0 @@
- tinygrad/__init__.py,sha256=FnazNjFEkM_gHbdFHnawSXPH2yWh4HzVlBjz9K1foEc,353
- tinygrad/device.py,sha256=BC-KRXOMVQC9ALfZCpC82r9PieQv4wgOEX4dwW9j9XU,18438
- tinygrad/dtype.py,sha256=Hw-GBXcJ1DO24aKxO4fSc1ZIbOi03hHFS0gj4ScVl4w,5555
- tinygrad/graph.py,sha256=qIl7pLl_31bhxRIEBVqVPgdSSSf6sGoGT9PODCdqsto,5393
- tinygrad/helpers.py,sha256=s4fCWZfTRzg_Y66esZhzd_a48qaPgbSecXsrIUG9x4k,11771
- tinygrad/jit.py,sha256=9u8bozQce2ZFCFP2m2MhKTNyJ3fesXQuE8xm79DIGVU,8743
- tinygrad/lazy.py,sha256=MPKzyOQhMFA6-P9uAHMvn_9vhydrTVRkCOI60kuS7zk,16383
- tinygrad/mlops.py,sha256=IYUZ5eZ5xuLWworlPVwplabMMIRpCYLNJc54U2fSxZk,8628
- tinygrad/ops.py,sha256=5r4t0AKDwwXXC5sOhltUOtwUIMwJrWdZjmjHa9gjir4,5498
- tinygrad/realize.py,sha256=WcMxjNbettjj2O3Y3Q0FuzwOhGKeg3XZDGB9jaZ3pog,2611
- tinygrad/tensor.py,sha256=mFljNbp57lGmEe7J8gaRwfVQqdI8ie35L5QsWCgFg1E,58106
- tinygrad/codegen/kernel.py,sha256=1EVZ6lRFJxd2B9wmgEjj4l4vFEpksFDAPiLuTmq75lA,35020
- tinygrad/codegen/linearizer.py,sha256=16ZjlYLGaFz0DP2jtRMgMLbkvt8JDMeNWLr9OtJ8Cl8,33166
- tinygrad/features/image.py,sha256=L2nGvL63WCWXf5hQKWMYs362cEU4jhtZtpnC5g5s0kY,4843
- tinygrad/features/multi.py,sha256=B231xBGUCCPuFgLTj65Lkv0H7LEQN3SaSP1A23jHTns,6212
- tinygrad/features/search.py,sha256=Sof7SjoQvWvPlkjKaW5zQGEUiQwbwmgQvxhvvJCr3XE,9031
- tinygrad/nn/__init__.py,sha256=wWpObBZfHRmdLR6AC7F07xFpskXzyvndsKR3yrjCE_M,7645
- tinygrad/nn/optim.py,sha256=4989TCNNFHb5_HJe7fhwaqMmHXRZb83xdeiNm8wztwo,3682
- tinygrad/nn/state.py,sha256=UgQRNfMyJOpOdPE87QXfzfpJHCZL2kfZ7Us6WzW4BsM,8294
- tinygrad/renderer/cstyle.py,sha256=IojfT4-N3Fi8bJFOBPYUhotR8BU-dz6DQMsGK3azHXA,16951
- tinygrad/renderer/llvmir.py,sha256=p9FPUimWPleYj4bE1VsylMzDwREkX7e6Q4tQJMN2QAw,10652
- tinygrad/runtime/ops_clang.py,sha256=9RKXciiOr9m6nJjZafR3G3ygHJoroq-C9DxcxPHmAG0,1595
- tinygrad/runtime/ops_cpu.py,sha256=0E4umRHMuaUw97kEEFN3f1HKgsBBMisJU9X6BB-8Bjg,3188
- tinygrad/runtime/ops_cuda.py,sha256=f46g7sNq0VqqL-XPqyJSAiZgXsXF2K3aThCKEYRsnO4,6466
- tinygrad/runtime/ops_disk.py,sha256=rBiUqPF1zDrh9jufkqkHk8bzbhrv4zefqlu49XHWs7w,2835
- tinygrad/runtime/ops_gpu.py,sha256=7dH_aR44Z-9UOtZnjwMaLP0n9B6mV7Xf-J6F2xSivK8,7714
- tinygrad/runtime/ops_hip.py,sha256=WcZ81n_0yqgs-A8FVtmX6pBrdVAI5UhQ3xQzY5VOu9M,5645
- tinygrad/runtime/ops_llvm.py,sha256=X1nn9laLZseHhbKSrjSXXn_FDboknuTe3njaPJ6ueks,2787
- tinygrad/runtime/ops_metal.py,sha256=5JGkOrkjGcvGil-eznA0uAt9HR8d9uDx57OlBscdCXQ,5106
- tinygrad/runtime/ops_torch.py,sha256=fEDBRwT-jUicaa9tAABCvZ7meLGQrh7nlsVhV3IZV5c,3470
- tinygrad/runtime/graph/cuda.py,sha256=UTK1d6Q2Cog83KRNDchMy4Qh6QoocDi3WmsnKNnmsss,5148
- tinygrad/runtime/graph/hip.py,sha256=a47xegB9Li6y0IZeAm_B_q357RTnYC5gAYhNbB7r2qY,1680
- tinygrad/runtime/graph/metal.py,sha256=nKKPKjMRckKfdUAQRMFOo9DaW-GZ0CdFLwvJdR3bl3o,5123
- tinygrad/shape/shapetracker.py,sha256=EzU6zouSs9A6-E8KdjTxG3CWlFbpWruVlyxqxhuj1yE,9843
- tinygrad/shape/symbolic.py,sha256=n90f4njHOJlmkZ87sVJn8MauIjsaAJF0DExCIfXVtcg,16565
- tinygrad/shape/view.py,sha256=1ubLD4q2tnzq3_WBqv40ckwajfXLIZsVMfDgv7PRQBk,12060
- tinygrad-0.8.0.dist-info/LICENSE,sha256=6cp1Hqk0v7NMg1j6OXty_1vAZ4EIwZdCySIoHrCS6RI,1055
- tinygrad-0.8.0.dist-info/METADATA,sha256=2MEwTsFJMDgjuxUImazkcp-vwGKNeLtx4RHhHZdovhg,9529
- tinygrad-0.8.0.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
- tinygrad-0.8.0.dist-info/top_level.txt,sha256=vDABMCWBFQnx2kn9Azueu88FP-1klQdePoHikQhHymc,9
- tinygrad-0.8.0.dist-info/RECORD,,