tinygrad 0.8.0__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as published to their public registry. It is provided for informational purposes only.
- tinygrad/__init__.py +6 -6
- tinygrad/codegen/kernel.py +230 -190
- tinygrad/codegen/linearizer.py +278 -384
- tinygrad/codegen/uops.py +415 -0
- tinygrad/device.py +132 -275
- tinygrad/dtype.py +53 -37
- tinygrad/engine/__init__.py +0 -0
- tinygrad/engine/graph.py +100 -0
- tinygrad/engine/jit.py +195 -0
- tinygrad/engine/realize.py +191 -0
- tinygrad/engine/schedule.py +362 -0
- tinygrad/engine/search.py +196 -0
- tinygrad/{mlops.py → function.py} +28 -14
- tinygrad/helpers.py +72 -43
- tinygrad/lazy.py +141 -240
- tinygrad/multi.py +169 -0
- tinygrad/nn/__init__.py +179 -8
- tinygrad/nn/datasets.py +7 -0
- tinygrad/nn/optim.py +106 -28
- tinygrad/nn/state.py +86 -17
- tinygrad/ops.py +70 -44
- tinygrad/renderer/__init__.py +61 -0
- tinygrad/renderer/assembly.py +276 -0
- tinygrad/renderer/cstyle.py +299 -206
- tinygrad/renderer/llvmir.py +118 -123
- tinygrad/runtime/autogen/amd_gpu.py +1900 -0
- tinygrad/runtime/autogen/comgr.py +865 -0
- tinygrad/runtime/autogen/cuda.py +5923 -0
- tinygrad/runtime/autogen/hip.py +5909 -0
- tinygrad/runtime/autogen/hsa.py +5761 -0
- tinygrad/runtime/autogen/kfd.py +812 -0
- tinygrad/runtime/autogen/nv_gpu.py +33328 -0
- tinygrad/runtime/autogen/opencl.py +1795 -0
- tinygrad/runtime/driver/hip_comgr.py +47 -0
- tinygrad/runtime/driver/hsa.py +143 -0
- tinygrad/runtime/graph/clang.py +38 -0
- tinygrad/runtime/graph/cuda.py +59 -54
- tinygrad/runtime/graph/hcq.py +143 -0
- tinygrad/runtime/graph/hsa.py +171 -0
- tinygrad/runtime/graph/metal.py +37 -41
- tinygrad/runtime/ops_amd.py +564 -0
- tinygrad/runtime/ops_clang.py +16 -14
- tinygrad/runtime/ops_cuda.py +130 -38
- tinygrad/runtime/ops_disk.py +45 -42
- tinygrad/runtime/ops_gpu.py +52 -50
- tinygrad/runtime/ops_hsa.py +278 -0
- tinygrad/runtime/ops_llvm.py +36 -56
- tinygrad/runtime/ops_metal.py +42 -24
- tinygrad/runtime/ops_npy.py +9 -0
- tinygrad/runtime/ops_nv.py +630 -0
- tinygrad/runtime/ops_python.py +204 -0
- tinygrad/shape/shapetracker.py +41 -105
- tinygrad/shape/symbolic.py +98 -95
- tinygrad/shape/view.py +137 -35
- tinygrad/tensor.py +2367 -442
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/METADATA +19 -9
- tinygrad-0.9.0.dist-info/RECORD +60 -0
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
- tinygrad/features/image.py +0 -93
- tinygrad/features/multi.py +0 -103
- tinygrad/features/search.py +0 -160
- tinygrad/graph.py +0 -106
- tinygrad/jit.py +0 -152
- tinygrad/realize.py +0 -50
- tinygrad/runtime/graph/hip.py +0 -24
- tinygrad/runtime/ops_cpu.py +0 -45
- tinygrad/runtime/ops_hip.py +0 -97
- tinygrad/runtime/ops_torch.py +0 -49
- tinygrad-0.8.0.dist-info/RECORD +0 -41
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
tinygrad/jit.py
DELETED
@@ -1,152 +0,0 @@
-from __future__ import annotations
-from typing import Callable, List, Tuple, Dict, cast, Union, Optional, TypeVar, Generic
-import functools, itertools, operator
-from tinygrad.dtype import DType
-from tinygrad.helpers import DEBUG, merge_dicts, getenv, all_int, Context, GRAPH
-from tinygrad.device import Device, JITRunner, CompiledASTRunner, Buffer
-from tinygrad.tensor import Tensor
-from tinygrad.lazy import LazyBuffer
-from tinygrad.shape.shapetracker import ShapeTracker
-from tinygrad.shape.symbolic import Variable, NumNode, Node
-from weakref import ref, WeakKeyDictionary
-from dataclasses import dataclass
-
-@dataclass(frozen=True)
-class JitItem:
-  prg: JITRunner  # or a graph executor like MetalGraph
-  rawbufs: List[Optional[Buffer]]
-
-def get_jit_stats(jit_cache: List[JitItem]) -> Tuple[Node, Node]:
-  return functools.reduce(operator.__add__, [ji.prg.op_estimate for ji in jit_cache], NumNode(0)), functools.reduce(operator.__add__, [ji.prg.mem_estimate for ji in jit_cache], NumNode(0))  # noqa: E501
-def get_input_replace(jit_cache: List[JitItem], input_rawbuffers:List[Buffer]) -> Dict[Tuple[int, int], int]:
-  input_replace: Dict[Tuple[int, int], int] = {}
-  for j,ji in enumerate(jit_cache):
-    for i,a in enumerate(ji.rawbufs):
-      if a in input_rawbuffers:
-        input_replace[(j,i)] = input_rawbuffers.index(a)
-  return input_replace
-def get_jc_idxs_with_updatable_launch_dims(jit_cache: List[JitItem]) -> List[int]:
-  return [j for j,ji in enumerate(jit_cache) if isinstance(ji.prg, CompiledASTRunner) and ((ji.prg.global_size and not all_int(tuple(ji.prg.global_size))) or (ji.prg.local_size and not all_int(tuple(ji.prg.local_size))))]  # noqa: E501
-def get_jc_idxs_with_updatable_var_vals(jit_cache: List[JitItem]) -> List[int]:
-  return [j for j,ji in enumerate(jit_cache) if isinstance(ji.prg, CompiledASTRunner) and ji.prg.vars]
-
-class GraphException(Exception): pass
-
-ReturnType = TypeVar('ReturnType')
-class TinyJit(Generic[ReturnType]):
-  def __init__(self, fxn:Callable[..., ReturnType]):
-    self.fxn = fxn
-    self.reset()
-
-  def reset(self):
-    self.jit_cache: List[JitItem] = []
-    self.input_replace: Dict[Tuple[int, int], int] = {}
-    self.cnt: int = 0
-    self.ret: Optional[ReturnType] = None
-    self.expected_vals: Optional[Tuple[Variable, ...]] = None
-    self.expected_name_sts_dtype: Optional[Tuple[Tuple[Union[int, str], ShapeTracker, DType], ...]] = None
-
-  # add support for instance methods
-  def __get__(self, obj, objtype): return functools.partial(self.__call__, obj)
-
-  def __call__(self, *args, **kwargs) -> ReturnType:
-    # all inputs (except const) are realized
-    input_tensors: Dict[Union[int, str], LazyBuffer] = {cast(Union[int, str], k):v.realize().lazydata for k,v in itertools.chain(enumerate(args), kwargs.items()) if v.__class__ is Tensor}  # noqa: E501
-    assert all(isinstance(x, LazyBuffer) for x in input_tensors.values()), "multilazybuffer JIT isn't supported"
-    expected_name_sts_dtype = tuple([(k, v.st.unbind(), v.dtype) for k,v in input_tensors.items()])
-
-    # get rawbuffers
-    # TODO: why can .realized have Any type?
-    input_rawbuffers: List[Buffer] = [v.base.realized for v in input_tensors.values() if v.base.realized is not None]
-    assert len(set(input_rawbuffers)) == len(input_rawbuffers), "duplicate inputs to JIT"
-
-    # get variables: they can either be in Tensors or passed in as arguments, and all must be bound. these are all global
-    var_vals: Dict[Variable, int] = merge_dicts([arg.st.var_vals for arg in input_tensors.values()] + [dict(x.unbind() for x in itertools.chain(args, kwargs.values()) if isinstance(x, Variable))])  # noqa: E501
-    expected_vals = tuple(var_vals.keys())
-
-    if self.cnt >= 2:
-      # jit exec
-      assert self.expected_vals == expected_vals, "mismatch of var_vals"
-      assert self.expected_name_sts_dtype == expected_name_sts_dtype, f"mismatch of sts, expected {self.expected_name_sts_dtype} got {expected_name_sts_dtype}"  # noqa: E501
-      for (j,i),input_idx in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_idx]
-      for ji in self.jit_cache: ji.prg(cast(List[Buffer], ji.rawbufs), var_vals, wait=DEBUG>=2, jit=True)
-    elif self.cnt == 1:
-      # jit capture
-      self.expected_vals, self.expected_name_sts_dtype = expected_vals, expected_name_sts_dtype
-      CacheCollector.start(var_vals)
-      with Context(GRAPH=getenv("JITGRAPH", GRAPH.value)):
-        self.ret = self.fxn(*args, **kwargs)
-        self.jit_cache = CacheCollector.finish()
-      assert len(self.jit_cache) != 0, "didn't JIT anything!"
-      if DEBUG >= 1 and len(set(get_input_replace(self.jit_cache, input_rawbuffers).values())) != len(input_rawbuffers):
-        print("WARNING: some input tensors not found")
-      if DEBUG >= 1: print(f"JIT captured {len(self.jit_cache)} kernels with {len(input_rawbuffers)} inputs")
-
-      # if your Device supports it, condense the items into a graph executor.
-      if (make_graph := Device[Device.DEFAULT].graph) and getenv("JIT") != 2 and len(self.jit_cache) > 1:
-        # Split JIT cache into batches for faster graph execution.
-        # This allows the accelerator to run some batches while subsequent graphs are still being updated.
-        graphed_jit_cache, current_batch = [], []
-        for i,ji in enumerate(self.jit_cache):
-          # If the jit item can potentially be graphed, put it in a batch.
-          if isinstance(ji.prg, CompiledASTRunner): current_batch.append(ji)
-
-          # The flush is done when (1) ji is the last one, (2) the size of batch exceeds the maximum batch size or
-          # (3) the current jit item cannot be graphed, so the current batch is flushed before such a jit item is added.
-          if len(current_batch) > 0 and (i==len(self.jit_cache)-1 or len(current_batch) >= getenv("JIT_BATCH_SIZE", 64) or not isinstance(ji.prg, CompiledASTRunner)):  # noqa: E501
-            try:
-              graphed_jit_cache.append(JitItem(make_graph(current_batch, input_rawbuffers, var_vals), cast(List[Optional[Buffer]], input_rawbuffers)))
-              if DEBUG >= 2: print(f"\tJIT GRAPHing batch with {len(current_batch)} kernels")
-            except GraphException as e:
-              graphed_jit_cache.extend(current_batch)
-              if DEBUG >= 2: print(f"\tJIT GRAPHing failed batch with {len(current_batch)} kernels: {e}")
-            current_batch = []
-
-          # If the jit item cannot be graphed, put it right into the final cache after the flush.
-          if not isinstance(ji.prg, CompiledASTRunner): graphed_jit_cache.append(ji)
-
-        self.jit_cache = graphed_jit_cache
-
-      self.input_replace = get_input_replace(self.jit_cache, input_rawbuffers)
-    elif self.cnt == 0:
-      # jit ignore
-      self.ret = self.fxn(*args, **kwargs)
-
-    # clear jit inputs
-    for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
-
-    self.cnt += 1
-    return cast(ReturnType, self.ret)
-
-class PlaceHolder:
-  def __init__(self, buf:Buffer): self.size, self.dtype, self.device, self.ref, self.bufid = buf.size, buf.dtype, buf.device, ref(buf), id(buf._buf)
-  def to_tuple(self): return (self.size, self.dtype, self.device, self.bufid)
-  def __hash__(self): return hash(self.to_tuple())
-  def __eq__(self, x): return isinstance(x, PlaceHolder) and self.to_tuple() == x.to_tuple()
-  def alloc_if_needed(self, buffer_cache: Dict[PlaceHolder, Buffer]) -> Buffer:
-    ret = self.ref()
-    if ret: return ret
-    if self not in buffer_cache: buffer_cache[self] = Buffer(self.device, self.size, self.dtype)
-    return buffer_cache[self]
-
-class _CacheCollector:
-  def __init__(self):
-    self.cache: Optional[List[Tuple[JITRunner, List[Union[Buffer, PlaceHolder]]]]] = None
-
-  def start(self, var_vals:Optional[Dict[Variable, int]]=None):
-    self.cache = []
-    self.placeholders: WeakKeyDictionary[Buffer, PlaceHolder] = WeakKeyDictionary()
-    self.var_vals = var_vals if var_vals is not None else {}
-
-  def add(self, prg, rawbufs, var_vals):
-    if self.cache is None: return
-    for k,v in var_vals.items(): assert k in self.var_vals and self.var_vals[k] == v, f"var_vals {k} mismatch {v} != {self.var_vals.get(k)}"
-    self.placeholders[rawbufs[0]] = PlaceHolder(rawbufs[0])  # NOTE: this is making an assumption that 0 is special
-    self.cache.append((prg, [self.placeholders.get(x, x) if isinstance(x, Buffer) else x for x in rawbufs]))
-
-  def finish(self) -> List[JitItem]:
-    if self.cache is None: return []
-    buffer_cache: Dict[PlaceHolder, Buffer] = {}
-    saved_cache, self.cache = self.cache, None
-    return [JitItem(prg, [x.alloc_if_needed(buffer_cache) if isinstance(x, PlaceHolder) else x for x in pl]) for prg, pl in saved_cache]
-CacheCollector = _CacheCollector()
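For context on what this removed module did (the JIT moved to tinygrad/engine/jit.py in 0.9.0): a minimal usage sketch against the 0.8.0 API follows. The function body and tensor shapes are illustrative assumptions, not taken from the package.

```python
from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit  # 0.8.0 import path; 0.9.0 moves this to tinygrad.engine.jit

@TinyJit
def step(x: Tensor) -> Tensor:
  # outputs must be realized inside the jitted function so kernels get captured with concrete buffers
  return (x @ x).relu().realize()

# per the cnt logic above: call 1 runs eagerly (cnt == 0), call 2 captures the kernel cache (cnt == 1),
# later calls replay the cache with the new input buffers swapped in via input_replace (cnt >= 2)
for _ in range(4):
  out = step(Tensor.rand(8, 8))
```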
tinygrad/realize.py
DELETED
@@ -1,50 +0,0 @@
-from typing import List, Dict, Optional, cast
-from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, GlobalCounters
-from tinygrad.device import Device, Buffer, BufferCopy, JITRunner, update_stats, InterpretedASTRunner
-from tinygrad.graph import print_tree, realized_lazybuffer
-from tinygrad.helpers import colored, getenv
-from tinygrad.shape.symbolic import Variable
-
-# *** schedule running ***
-
-class CustomOp(JITRunner):
-  def __init__(self, fxn):
-    self.fxn = fxn
-    super().__init__()
-  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False): self.fxn(*rawbufs)
-
-def lower_schedule_item(si:ScheduleItem) -> Optional[JITRunner]:
-  assert all(si.out.device == x.device for x in si.inputs) or si.ast.op is LoadOps.COPY, f"all devices must be the same, {si.out.device} != {[x.device for x in si.inputs]} {print_tree(si.ast) or ''}"  # noqa: E501
-  if si.ast.op is LoadOps.EMPTY: return None
-  if si.ast.op is LoadOps.COPY: return BufferCopy
-  if si.ast.op is LoadOps.CUSTOM: return CustomOp(si.ast.arg)
-  return Device[si.out.device].get_runner(si.ast)
-
-logops = open(getenv("LOGOPS", ""), "a") if getenv("LOGOPS", "") else None
-def run_schedule(schedule:List[ScheduleItem]):
-  while len(schedule):
-    si = schedule.pop(0)
-    if logops and si.ast.op not in LoadOps: logops.write(str(si.ast)+"\n")
-
-    # get the program
-    prg = lower_schedule_item(si)
-
-    # invalidate the output buffer if there's a non contig usage of it in inputs
-    if si.out.output_buffer is not None:
-      for i,a in enumerate(si.inputs):
-        if a.realized == si.out.output_buffer:
-          if any(not x.arg.st.contiguous for x in si.ast.lazyops if x.op == BufferOps.LOAD and x.arg.idx == i+1):
-            si.out.output_buffer = None
-            break
-
-    # we don't have an output buffer, we have to create it, and create to max size if it has symbolic shape
-    si.out.realized = si.out.output_buffer if si.out.output_buffer is not None else \
-      Buffer(si.out.device, si.out.size, si.out.dtype,
-             "PLACEHOLDER" if isinstance(prg, InterpretedASTRunner) else None)
-    del si.out.srcs
-
-    # run the function (put it in JIT)
-    assert all(x.realized is not None for x in si.inputs), f"can't run, some inputs aren't realized {[x for x in si.inputs if x.realized is None]}"
-    if prg: prg.exec([si.out.realized] + [cast(Buffer, x.realized) for x in si.inputs], si.var_vals)
-    else: update_stats(colored(f"empty {si.out.st.size:10d} {si.out.dtype}", "yellow"), 0, 0, {}, None, 1, device=si.out.device)
-    realized_lazybuffer(si.out, GlobalCounters.kernel_count)
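To show where run_schedule sat in the 0.8.0 execution path (it was superseded by tinygrad/engine/realize.py and tinygrad/engine/schedule.py in 0.9.0), here is a hedged sketch. The create_schedule import location is an assumption about the 0.8.0 layout; this code does not run against 0.9.0.

```python
from tinygrad.tensor import Tensor
from tinygrad.lazy import create_schedule  # assumed 0.8.0 home of schedule creation
from tinygrad.realize import run_schedule

out = (Tensor.rand(8, 8) + 1).sum()
sched = create_schedule([out.lazydata])  # list of ScheduleItem describing the kernels to run
run_schedule(sched)                      # lower_schedule_item picks a JITRunner per item, then executes it
```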
tinygrad/runtime/graph/hip.py
DELETED
@@ -1,24 +0,0 @@
-import ctypes
-from typing import Tuple
-import gpuctypes.hip as hip
-from tinygrad.helpers import init_c_var
-from tinygrad.runtime.ops_hip import check, hip_time_execution
-from tinygrad.runtime.graph.cuda import CUDAGraph
-
-class HIPGraph(CUDAGraph):
-  def __del__(self):
-    check(hip.hipGraphDestroy(self.graph))
-    check(hip.hipGraphExecDestroy(self.instance))
-
-  def encode_args_info(self): return (hip.hipDeviceptr_t, (1,2,3))
-  def graph_create(self): return init_c_var(hip.hipGraph_t(), lambda x: check(hip.hipGraphCreate(ctypes.byref(x), 0)))
-  def graph_instantiate(self, graph):
-    return init_c_var(hip.hipGraphExec_t(), lambda x: check(hip.hipGraphInstantiate(ctypes.byref(x), graph, None, None, 0)))
-  def graph_add_kernel_node(self, graph, c_deps, c_params):
-    return init_c_var(hip.hipGraphNode_t(), lambda x: check(hip.hipGraphAddKernelNode(ctypes.byref(x), graph, c_deps, ctypes.sizeof(c_deps)//8 if c_deps else 0, ctypes.byref(c_params))))  # noqa: E501
-  def graph_launch(self, *args, wait=False): return hip_time_execution(lambda: check(hip.hipGraphLaunch(*args)), enable=wait)
-  def graph_exec_kernel_node_set_params(self, *args): return check(hip.hipGraphExecKernelNodeSetParams(*args))
-  def build_kernel_node_params(self, prg, global_size, local_size, c_config):
-    return hip.hipKernelNodeParams(hip.dim3(*local_size), c_config, ctypes.cast(prg.clprg.prg, ctypes.c_void_p), hip.dim3(*global_size), None, 0)
-  def set_kernel_node_launch_dims(self, node, global_size: Tuple[int, int, int], local_size: Tuple[int, int, int]):
-    node.blockDim.x, node.blockDim.y, node.blockDim.z, node.gridDim.x, node.gridDim.y, node.gridDim.z = *local_size, *global_size
tinygrad/runtime/ops_cpu.py
DELETED
@@ -1,45 +0,0 @@
-import numpy as np
-from typing import Callable, Dict, Tuple
-from tinygrad.helpers import flat_mv
-from tinygrad.ops import BufferOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, Op
-from tinygrad.device import Interpreted, Allocator
-
-def reduce_axis(in_shape:Tuple[int, ...], out_shape:Tuple[int, ...]) -> Tuple[int, ...]:
-  assert len(in_shape) == len(out_shape), "reduce shapes must have same dimensions"
-  return tuple(i for i,(a,b) in enumerate(zip(in_shape, out_shape)) if a != b)
-
-def einsum_mulacc(einsum, get_strides, expand):
-  def einscripts(x): return ''.join(["abcdefghijklmnopqrstuvwxyz"[i] for i in x])
-  def axes_slice(strides): return tuple(i for i,s in enumerate(strides) if s != 0), tuple(slice(None) if s != 0 else 0 for s in strides)
-  def mulacc(a, b, new_shape):
-    (a_axes, a_slices), (b_axes, b_slices) = axes_slice(get_strides(a)), axes_slice(get_strides(b))
-    out = [i for i in range(len(new_shape)) if a.shape[i] == new_shape[i] and (i in a_axes or i in b_axes)]
-    ret = einsum(f"{einscripts(a_axes)}, {einscripts(b_axes)} -> {einscripts(out)}", a[a_slices], b[b_slices])
-    return expand(ret.reshape(tuple(1 if i not in a_axes and i not in b_axes else s for i,s in enumerate(new_shape))), new_shape)
-  return mulacc
-
-def as_strided(x, arg):
-  shape, stride, offset = arg
-  return np.ndarray(shape, x.dtype, buffer=np.require(x, requirements='C'), offset=offset*x.dtype.itemsize,
-                    strides=tuple(y*x.dtype.itemsize for y in stride))
-
-numpy_fxn_for_op: Dict[Op, Callable] = {
-  BufferOps.CONST: lambda val, dtype: np.array(val, dtype=dtype.np),
-  UnaryOps.EXP2: np.exp2, UnaryOps.LOG2: np.log2, UnaryOps.SIN: np.sin, UnaryOps.SQRT: np.sqrt,
-  UnaryOps.CAST: lambda x,y: x.view(y[0].np) if y[1] else x.astype(y[0].np, copy=False),
-  UnaryOps.NEG: lambda x: np.logical_not(x) if x.dtype == np.bool_ else np.negative(x),
-  BinaryOps.MAX: np.maximum, BinaryOps.CMPLT: np.less, BinaryOps.CMPEQ: np.equal, BinaryOps.ADD: np.add, BinaryOps.SUB: np.subtract,
-  BinaryOps.MUL: np.multiply, BinaryOps.DIV: lambda x, y: np.divide(x, y).astype(x.dtype, copy=False), BinaryOps.XOR: np.bitwise_xor,
-  ReduceOps.SUM: lambda x, new_shape: x.sum(reduce_axis(x.shape, new_shape), dtype=x.dtype, keepdims=True) if x.shape != new_shape else x,
-  ReduceOps.MAX: lambda x, new_shape: x.max(reduce_axis(x.shape, new_shape), keepdims=True) if x.shape != new_shape else x,
-  TernaryOps.MULACC: einsum_mulacc(lambda s,a,b: np.einsum(s, a.copy(), b.copy(), optimize=True), lambda x: x.strides, np.broadcast_to),
-  TernaryOps.WHERE: np.where, MovementOps.AS_STRIDED: as_strided, MovementOps.EXPAND: np.broadcast_to, MovementOps.PAD: np.pad
-}
-
-class NumpyAllocator(Allocator):
-  def _alloc(self, size:int): return np.empty(size, dtype=np.uint8)
-  def as_buffer(self, src:np.ndarray) -> memoryview: return flat_mv(np.require(src, requirements='C').data)
-  def copyin(self, dest:np.ndarray, src:memoryview): np.copyto(dest, np.frombuffer(src, dest.dtype).reshape(dest.shape))
-  def copyout(self, dest:memoryview, src:np.ndarray): np.copyto(np.frombuffer(dest, src.dtype).reshape(src.shape), src)
-
-CPUDevice = Interpreted(NumpyAllocator(), numpy_fxn_for_op)
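Since this interpreted CPU backend is just a dict from ops to numpy callables, a couple of entries can be exercised directly. A small sketch against the 0.8.0 module shown above (the whole interpreted path was dropped in 0.9.0, where only ops_npy.py remains for numpy interop):

```python
import numpy as np
from tinygrad.ops import BinaryOps, ReduceOps          # 0.8.0 op enums
from tinygrad.runtime.ops_cpu import numpy_fxn_for_op  # the table defined above

a = np.arange(6, dtype=np.float32).reshape(2, 3)
b = np.ones((2, 3), dtype=np.float32)
print(numpy_fxn_for_op[BinaryOps.ADD](a, b))       # elementwise add via np.add
print(numpy_fxn_for_op[ReduceOps.SUM](a, (2, 1)))  # reduces axis 1 to shape (2, 1) with keepdims=True
```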
tinygrad/runtime/ops_hip.py
DELETED
@@ -1,97 +0,0 @@
-from __future__ import annotations
-import ctypes, functools, subprocess, io
-from typing import Tuple, TypeVar, List
-import gpuctypes.hip as hip
-from tinygrad.helpers import DEBUG, getenv, init_c_var, compile_cuda_style, encode_args_cuda_style, time_execution_cuda_style
-from tinygrad.helpers import from_mv, round_up, to_mv
-from tinygrad.device import Compiled, LRUAllocator, MallocAllocator
-from tinygrad.renderer.cstyle import HIPRenderer
-from tinygrad.codegen.kernel import LinearizerOptions
-
-# The default HIP stream is used for everything.
-MOCKHIP = getenv("MOCKHIP")  # for CI. don't run kernels, only check if they compile
-
-def check(status):
-  if status != 0: raise RuntimeError(f"HIP Error {status}, {ctypes.string_at(hip.hipGetErrorString(status)).decode()}")
-
-# TODO: remove these helpers, they increase complexity
-def hip_time_execution(cb, enable=False): return time_execution_cuda_style(cb, hip.hipEvent_t, hip.hipEventCreate, hip.hipEventRecord, hip.hipEventSynchronize, hip.hipEventDestroy, hip.hipEventElapsedTime, enable=enable)  # noqa: E501
-
-def compile_hip(prg) -> bytes: return compile_cuda_style(prg, [f'--offload-arch={HIPDevice.default_arch_name}', '-I/opt/rocm/include'], hip.hiprtcProgram, hip.hiprtcCreateProgram, hip.hiprtcCompileProgram, hip.hiprtcGetCode, hip.hiprtcGetCodeSize, hip.hiprtcGetProgramLog, hip.hiprtcGetProgramLogSize, check)  # noqa: E501
-
-class HIPProgram:
-  def __init__(self, device:int, name:str, lib:bytes):
-    self.device, self.name, self.lib = device, name, lib
-
-    if DEBUG >= 6:
-      asm = subprocess.check_output(["/opt/rocm/llvm/bin/llvm-objdump", '-d', '-'], input=lib)
-      print('\n'.join([x for x in asm.decode('utf-8').split("\n") if 's_code_end' not in x]))
-
-    if MOCKHIP: return
-    check(hip.hipSetDevice(self.device))
-    self.module = init_c_var(hip.hipModule_t(), lambda x: check(hip.hipModuleLoadData(ctypes.byref(x), lib)))
-    self.prg = init_c_var(hip.hipFunction_t(), lambda x: check(hip.hipModuleGetFunction(ctypes.byref(x), self.module, name.encode("utf-8"))))
-
-  def __del__(self):
-    if not MOCKHIP: check(hip.hipModuleUnload(self.module))
-
-  def __call__(self, *args, global_size:Tuple[int,int,int], local_size:Tuple[int,int,int], vals:Tuple[int, ...]=(), wait=False):
-    if MOCKHIP: return float("inf")
-    check(hip.hipSetDevice(self.device))
-    return hip_time_execution(lambda: check(hip.hipModuleLaunchKernel(self.prg, *global_size, *local_size, 0, None, None, encode_args_cuda_style(args, vals, hip.hipDeviceptr_t, marks=(1,2,3))[0])), enable=wait)  # noqa: E501
-
-T = TypeVar("T")
-CHUNK_SIZE, PAGE_SIZE = 256*1024*1024, 0x1000
-class HIPAllocator(LRUAllocator):
-  def __init__(self, device:HIPDevice):
-    self.device = device
-    super().__init__()
-  def _alloc(self, size:int):
-    check(hip.hipSetDevice(self.device.device))
-    return init_c_var(hip.hipDeviceptr_t(), lambda x: check(hip.hipMalloc(ctypes.byref(x), size)))
-  def _free(self, opaque:T): check(hip.hipFree(opaque))
-  def _hostalloc(self, size:int): return init_c_var(hip.hipDeviceptr_t(), lambda x: check(hip.hipHostMalloc(ctypes.byref(x), size, 0)))
-  def copy_from_fd(self, dest, fd, offset, size):
-    check(hip.hipSetDevice(self.device.device))
-    if not hasattr(self, 'hb'): self.hb = [self._hostalloc(CHUNK_SIZE) for _ in range(2)]
-    fo = io.FileIO(fd, "a+b", closefd=False)
-    fo.seek(offset - (minor_offset:=offset % PAGE_SIZE))
-    copied_in = 0
-    for local_offset in range(0, size+minor_offset, CHUNK_SIZE):
-      local_size = min(round_up(size+minor_offset, PAGE_SIZE)-local_offset, CHUNK_SIZE)
-      fo.readinto(to_mv(self.hb[0], local_size))
-      check(hip.hipDeviceSynchronize())
-      check(hip.hipMemcpyAsync(ctypes.c_void_p(dest.value + copied_in), ctypes.c_void_p(self.hb[0].value + minor_offset),
-                               copy_size:=min(local_size-minor_offset, size-copied_in), hip.hipMemcpyHostToDevice, None))
-      copied_in += copy_size
-      self.hb = self.hb[1:] + [self.hb[0]]
-      minor_offset = 0  # only on the first
-  def copyin(self, dest:T, src: memoryview):
-    check(hip.hipSetDevice(self.device.device))
-    host_mem = self._hostalloc(len(src))
-    self.device.pending_copyin.append(host_mem)
-    ctypes.memmove(host_mem, from_mv(src), len(src))
-    check(hip.hipMemcpyAsync(dest, host_mem, len(src), hip.hipMemcpyHostToDevice, None))
-  def copyout(self, dest:memoryview, src:T):
-    check(hip.hipSetDevice(self.device.device))
-    check(hip.hipMemcpy(from_mv(dest), src, len(dest), hip.hipMemcpyDeviceToHost))
-  def transfer(self, dest:T, src:T, sz:int):
-    check(hip.hipSetDevice(self.device.device))
-    # TODO: hipMemcpyAsync, but you have to track the "src" buffer to not free it
-    check(hip.hipMemcpy(dest, src, sz, hip.hipMemcpyDeviceToDevice))
-
-class HIPDevice(Compiled):
-  default_arch_name = "gfx1100"
-  def __init__(self, device:str=""):
-    self.device = int(device.split(":")[1]) if ":" in device else 0
-    self.pending_copyin: List[hip.hipDeviceptr_t] = []
-    if self.device == 0 and not MOCKHIP: HIPDevice.default_arch_name = init_c_var(hip.hipDeviceProp_t(), lambda x: check(hip.hipGetDeviceProperties(x, self.device))).gcnArchName.decode()  # noqa: E501
-
-    from tinygrad.runtime.graph.hip import HIPGraph
-    super().__init__(MallocAllocator if MOCKHIP else HIPAllocator(self), LinearizerOptions(device="HIP"), HIPRenderer,
-                     compile_hip, functools.partial(HIPProgram, self.device), HIPGraph)
-  def synchronize(self):
-    check(hip.hipSetDevice(self.device))
-    check(hip.hipDeviceSynchronize())
-    for opaque in self.pending_copyin: check(hip.hipFree(opaque))
-    self.pending_copyin.clear()
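The practical effect of this module was that "HIP" was a selectable device string in 0.8.0; in 0.9.0 AMD GPUs are reached through the new ops_amd.py / ops_hsa.py backends instead. A hedged sketch of the old entry point (running it for real needs a ROCm-capable GPU and the gpuctypes HIP bindings):

```python
from tinygrad.tensor import Tensor

# the 0.8.0 device registry resolved "HIP" to HIPDevice from this module;
# kernels were compiled with compile_hip (hiprtc) and launched via HIPProgram
x = Tensor.ones(16, device="HIP")
print((x * 2).sum().numpy())
```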
tinygrad/runtime/ops_torch.py
DELETED
@@ -1,49 +0,0 @@
-import torch
-import numpy as np
-from typing import Dict, Callable
-from tinygrad.ops import BufferOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, Op
-from tinygrad.device import Interpreted, Allocator
-from tinygrad.dtype import dtypes
-from tinygrad.helpers import getenv, flatten
-from tinygrad.runtime.ops_cpu import einsum_mulacc, reduce_axis
-
-device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu"))
-type_map = {torch.bool: dtypes.bool,
-            torch.int8: dtypes.int8, torch.uint8: dtypes.uint8, torch.int16: dtypes.int16, torch.int32: dtypes.int32, torch.int64: dtypes.int64,
-            torch.float16: dtypes.float16, torch.bfloat16: dtypes.bfloat16, torch.float32: dtypes.float32, torch.float64: dtypes.float64}
-inverse_type_map = {v: k for k,v in type_map.items()}
-# TODO: should unsupported types fail instead of implicit conversion?
-inverse_type_map.update({dtypes.uint16: torch.int16, dtypes.uint32: torch.int32, dtypes.uint64: torch.int64})
-def np_type_cvt(t): return {np.uint32: np.int32, np.uint64: np.int64}.get(t, t)
-
-def as_strided(x, arg):
-  shape, stride, offset = arg
-  x = x.contiguous()
-  offset += x.storage_offset()  # NOTE: contiguous can still have a storage_offset, so we adjust for it
-  if any(i < 0 for i in stride):
-    return torch.as_strided(x, shape, tuple(abs(i) for i in stride),
-                            offset + sum((s-1)*a if a < 0 else 0 for (s,a) in zip(shape, stride))).flip([i for i,a in enumerate(stride) if a < 0])
-  return torch.as_strided(x, shape, stride, offset)
-
-torch_fxn_for_op: Dict[Op, Callable] = {
-  # TODO: torch.tensor should work here. it doesn't due to "overflow" in uint8
-  #BufferOps.CONST: lambda val, dtype: torch.tensor(val, device=device, dtype=inverse_type_map[dtype]),
-  BufferOps.CONST: lambda val, dtype: torch.from_numpy(np.array(val, dtype=np_type_cvt(dtype.np))).to(device),
-  UnaryOps.EXP2: torch.exp2, UnaryOps.LOG2: torch.log2, UnaryOps.SIN: torch.sin, UnaryOps.SQRT: torch.sqrt,
-  UnaryOps.CAST: lambda x,y: (x.view if y[1] else x.type)(inverse_type_map[y[0]]),
-  UnaryOps.NEG: lambda x: torch.logical_not(x) if x.dtype is torch.bool else torch.neg(x),
-  BinaryOps.ADD: torch.add, BinaryOps.SUB: torch.sub, BinaryOps.MUL: torch.mul, BinaryOps.DIV: lambda x,y: torch.div(x, y).type(x.dtype),
-  BinaryOps.XOR: torch.bitwise_xor, BinaryOps.MAX: torch.maximum, BinaryOps.CMPLT: torch.lt, BinaryOps.CMPEQ: torch.eq,
-  ReduceOps.SUM: lambda x, new_shape: x.sum(reduce_axis(x.shape, new_shape), dtype=x.dtype, keepdims=True) if x.shape != new_shape else x,
-  ReduceOps.MAX: lambda x, new_shape: x.amax(reduce_axis(x.shape, new_shape), keepdims=True) if x.shape != new_shape else x,
-  TernaryOps.MULACC: einsum_mulacc(lambda s,a,b: torch.einsum(s, a.float(), b.float()).type(a.dtype), lambda x: x.stride(), lambda x,s: x.expand(s)),
-  TernaryOps.WHERE: torch.where, MovementOps.AS_STRIDED: as_strided, MovementOps.EXPAND: lambda x, arg: x.expand(arg),
-  MovementOps.PAD: lambda x, padding: torch.nn.functional.pad(x, flatten(padding[::-1])),
-}
-
-class TorchAllocator(Allocator):
-  def _alloc(self, size:int): return torch.empty([size], device=device, dtype=torch.uint8)
-  def copyin(self, dest:torch.Tensor, src:memoryview): dest.copy_(torch.frombuffer(src, dtype=dest.dtype))
-  def copyout(self, dest:memoryview, src:torch.Tensor): torch.frombuffer(dest, dtype=src.dtype).copy_(src.flatten())
-
-TorchDevice = Interpreted(TorchAllocator(), torch_fxn_for_op)
tinygrad-0.8.0.dist-info/RECORD
DELETED
@@ -1,41 +0,0 @@
-tinygrad/__init__.py,sha256=FnazNjFEkM_gHbdFHnawSXPH2yWh4HzVlBjz9K1foEc,353
-tinygrad/device.py,sha256=BC-KRXOMVQC9ALfZCpC82r9PieQv4wgOEX4dwW9j9XU,18438
-tinygrad/dtype.py,sha256=Hw-GBXcJ1DO24aKxO4fSc1ZIbOi03hHFS0gj4ScVl4w,5555
-tinygrad/graph.py,sha256=qIl7pLl_31bhxRIEBVqVPgdSSSf6sGoGT9PODCdqsto,5393
-tinygrad/helpers.py,sha256=s4fCWZfTRzg_Y66esZhzd_a48qaPgbSecXsrIUG9x4k,11771
-tinygrad/jit.py,sha256=9u8bozQce2ZFCFP2m2MhKTNyJ3fesXQuE8xm79DIGVU,8743
-tinygrad/lazy.py,sha256=MPKzyOQhMFA6-P9uAHMvn_9vhydrTVRkCOI60kuS7zk,16383
-tinygrad/mlops.py,sha256=IYUZ5eZ5xuLWworlPVwplabMMIRpCYLNJc54U2fSxZk,8628
-tinygrad/ops.py,sha256=5r4t0AKDwwXXC5sOhltUOtwUIMwJrWdZjmjHa9gjir4,5498
-tinygrad/realize.py,sha256=WcMxjNbettjj2O3Y3Q0FuzwOhGKeg3XZDGB9jaZ3pog,2611
-tinygrad/tensor.py,sha256=mFljNbp57lGmEe7J8gaRwfVQqdI8ie35L5QsWCgFg1E,58106
-tinygrad/codegen/kernel.py,sha256=1EVZ6lRFJxd2B9wmgEjj4l4vFEpksFDAPiLuTmq75lA,35020
-tinygrad/codegen/linearizer.py,sha256=16ZjlYLGaFz0DP2jtRMgMLbkvt8JDMeNWLr9OtJ8Cl8,33166
-tinygrad/features/image.py,sha256=L2nGvL63WCWXf5hQKWMYs362cEU4jhtZtpnC5g5s0kY,4843
-tinygrad/features/multi.py,sha256=B231xBGUCCPuFgLTj65Lkv0H7LEQN3SaSP1A23jHTns,6212
-tinygrad/features/search.py,sha256=Sof7SjoQvWvPlkjKaW5zQGEUiQwbwmgQvxhvvJCr3XE,9031
-tinygrad/nn/__init__.py,sha256=wWpObBZfHRmdLR6AC7F07xFpskXzyvndsKR3yrjCE_M,7645
-tinygrad/nn/optim.py,sha256=4989TCNNFHb5_HJe7fhwaqMmHXRZb83xdeiNm8wztwo,3682
-tinygrad/nn/state.py,sha256=UgQRNfMyJOpOdPE87QXfzfpJHCZL2kfZ7Us6WzW4BsM,8294
-tinygrad/renderer/cstyle.py,sha256=IojfT4-N3Fi8bJFOBPYUhotR8BU-dz6DQMsGK3azHXA,16951
-tinygrad/renderer/llvmir.py,sha256=p9FPUimWPleYj4bE1VsylMzDwREkX7e6Q4tQJMN2QAw,10652
-tinygrad/runtime/ops_clang.py,sha256=9RKXciiOr9m6nJjZafR3G3ygHJoroq-C9DxcxPHmAG0,1595
-tinygrad/runtime/ops_cpu.py,sha256=0E4umRHMuaUw97kEEFN3f1HKgsBBMisJU9X6BB-8Bjg,3188
-tinygrad/runtime/ops_cuda.py,sha256=f46g7sNq0VqqL-XPqyJSAiZgXsXF2K3aThCKEYRsnO4,6466
-tinygrad/runtime/ops_disk.py,sha256=rBiUqPF1zDrh9jufkqkHk8bzbhrv4zefqlu49XHWs7w,2835
-tinygrad/runtime/ops_gpu.py,sha256=7dH_aR44Z-9UOtZnjwMaLP0n9B6mV7Xf-J6F2xSivK8,7714
-tinygrad/runtime/ops_hip.py,sha256=WcZ81n_0yqgs-A8FVtmX6pBrdVAI5UhQ3xQzY5VOu9M,5645
-tinygrad/runtime/ops_llvm.py,sha256=X1nn9laLZseHhbKSrjSXXn_FDboknuTe3njaPJ6ueks,2787
-tinygrad/runtime/ops_metal.py,sha256=5JGkOrkjGcvGil-eznA0uAt9HR8d9uDx57OlBscdCXQ,5106
-tinygrad/runtime/ops_torch.py,sha256=fEDBRwT-jUicaa9tAABCvZ7meLGQrh7nlsVhV3IZV5c,3470
-tinygrad/runtime/graph/cuda.py,sha256=UTK1d6Q2Cog83KRNDchMy4Qh6QoocDi3WmsnKNnmsss,5148
-tinygrad/runtime/graph/hip.py,sha256=a47xegB9Li6y0IZeAm_B_q357RTnYC5gAYhNbB7r2qY,1680
-tinygrad/runtime/graph/metal.py,sha256=nKKPKjMRckKfdUAQRMFOo9DaW-GZ0CdFLwvJdR3bl3o,5123
-tinygrad/shape/shapetracker.py,sha256=EzU6zouSs9A6-E8KdjTxG3CWlFbpWruVlyxqxhuj1yE,9843
-tinygrad/shape/symbolic.py,sha256=n90f4njHOJlmkZ87sVJn8MauIjsaAJF0DExCIfXVtcg,16565
-tinygrad/shape/view.py,sha256=1ubLD4q2tnzq3_WBqv40ckwajfXLIZsVMfDgv7PRQBk,12060
-tinygrad-0.8.0.dist-info/LICENSE,sha256=6cp1Hqk0v7NMg1j6OXty_1vAZ4EIwZdCySIoHrCS6RI,1055
-tinygrad-0.8.0.dist-info/METADATA,sha256=2MEwTsFJMDgjuxUImazkcp-vwGKNeLtx4RHhHZdovhg,9529
-tinygrad-0.8.0.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
-tinygrad-0.8.0.dist-info/top_level.txt,sha256=vDABMCWBFQnx2kn9Azueu88FP-1klQdePoHikQhHymc,9
-tinygrad-0.8.0.dist-info/RECORD,,