tinygrad-0.7.0-py3-none-any.whl → tinygrad-0.9.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +6 -0
- tinygrad/codegen/kernel.py +572 -83
- tinygrad/codegen/linearizer.py +415 -395
- tinygrad/codegen/uops.py +415 -0
- tinygrad/device.py +183 -0
- tinygrad/dtype.py +113 -0
- tinygrad/engine/__init__.py +0 -0
- tinygrad/engine/graph.py +100 -0
- tinygrad/engine/jit.py +195 -0
- tinygrad/engine/realize.py +191 -0
- tinygrad/engine/schedule.py +362 -0
- tinygrad/engine/search.py +196 -0
- tinygrad/{mlops.py → function.py} +76 -55
- tinygrad/helpers.py +196 -89
- tinygrad/lazy.py +210 -371
- tinygrad/multi.py +169 -0
- tinygrad/nn/__init__.py +202 -22
- tinygrad/nn/datasets.py +7 -0
- tinygrad/nn/optim.py +112 -32
- tinygrad/nn/state.py +136 -39
- tinygrad/ops.py +119 -202
- tinygrad/renderer/__init__.py +61 -0
- tinygrad/renderer/assembly.py +276 -0
- tinygrad/renderer/cstyle.py +353 -166
- tinygrad/renderer/llvmir.py +150 -138
- tinygrad/runtime/autogen/amd_gpu.py +1900 -0
- tinygrad/runtime/autogen/comgr.py +865 -0
- tinygrad/runtime/autogen/cuda.py +5923 -0
- tinygrad/runtime/autogen/hip.py +5909 -0
- tinygrad/runtime/autogen/hsa.py +5761 -0
- tinygrad/runtime/autogen/kfd.py +812 -0
- tinygrad/runtime/autogen/nv_gpu.py +33328 -0
- tinygrad/runtime/autogen/opencl.py +1795 -0
- tinygrad/runtime/driver/hip_comgr.py +47 -0
- tinygrad/runtime/driver/hsa.py +143 -0
- tinygrad/runtime/graph/clang.py +38 -0
- tinygrad/runtime/graph/cuda.py +81 -0
- tinygrad/runtime/graph/hcq.py +143 -0
- tinygrad/runtime/graph/hsa.py +171 -0
- tinygrad/runtime/graph/metal.py +75 -0
- tinygrad/runtime/ops_amd.py +564 -0
- tinygrad/runtime/ops_clang.py +24 -77
- tinygrad/runtime/ops_cuda.py +175 -89
- tinygrad/runtime/ops_disk.py +56 -33
- tinygrad/runtime/ops_gpu.py +92 -95
- tinygrad/runtime/ops_hsa.py +278 -0
- tinygrad/runtime/ops_llvm.py +39 -60
- tinygrad/runtime/ops_metal.py +92 -74
- tinygrad/runtime/ops_npy.py +9 -0
- tinygrad/runtime/ops_nv.py +630 -0
- tinygrad/runtime/ops_python.py +204 -0
- tinygrad/shape/shapetracker.py +86 -254
- tinygrad/shape/symbolic.py +166 -141
- tinygrad/shape/view.py +296 -0
- tinygrad/tensor.py +2619 -448
- {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
- tinygrad-0.9.0.dist-info/METADATA +227 -0
- tinygrad-0.9.0.dist-info/RECORD +60 -0
- {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/assembly.py +0 -190
- tinygrad/codegen/optimizer.py +0 -379
- tinygrad/codegen/search.py +0 -72
- tinygrad/graph.py +0 -83
- tinygrad/jit.py +0 -57
- tinygrad/nn/image.py +0 -100
- tinygrad/renderer/assembly_arm64.py +0 -169
- tinygrad/renderer/assembly_ptx.py +0 -98
- tinygrad/renderer/wgsl.py +0 -53
- tinygrad/runtime/lib.py +0 -113
- tinygrad/runtime/ops_cpu.py +0 -51
- tinygrad/runtime/ops_hip.py +0 -82
- tinygrad/runtime/ops_shm.py +0 -29
- tinygrad/runtime/ops_torch.py +0 -30
- tinygrad/runtime/ops_webgpu.py +0 -45
- tinygrad-0.7.0.dist-info/METADATA +0 -212
- tinygrad-0.7.0.dist-info/RECORD +0 -40
- {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
tinygrad/nn/image.py
DELETED
@@ -1,100 +0,0 @@
|
|
1
|
-
import numpy as np
|
2
|
-
from tinygrad.helpers import prod, IMAGE, ImageDType, getenv, dtypes
|
3
|
-
from tinygrad.lazy import get_single_root
|
4
|
-
|
5
|
-
FLOAT16 = getenv("FLOAT16", 0)
# Positional arguments splatted into ImageDType(..., shape=...) by image_conv2d; the half-precision
# variant ("imageh", np.float16) is chosen when the FLOAT16 env var is set.
# NOTE(review): the meaning of the leading 100 and of the 2/4 entries is defined by ImageDType in
# tinygrad.helpers — confirm there before relying on it.
if FLOAT16: base_image_type = (100, 2, "imageh", np.float16)
else: base_image_type = (100, 4, "imagef", np.float32)
|
7
|
-
|
8
|
-
def image_dot(self, w):
  """Compute `self @ w` by lowering the matmul to a 1x1 grouped conv2d.

  mxk @ kxn is expressed as (1,k,m,1).conv2d(n,k,1,1). Leading axes of `w` become
  conv groups and leading axes of `self` the batch.
  Raises AssertionError for 0-D operands or mismatched inner dimensions.
  """
  n1, n2 = len(self.shape), len(w.shape)
  assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D"
  assert self.shape[-1] == w.shape[-min(n2, 2)], f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({self.shape[-1]} != {w.shape[-min(n2, 2)]})"
  bs, groups = prod(self.shape[0:-2]), prod(w.shape[0:-2])
  cin, cout = w.shape[-2], w.shape[-1]
  out_shape_t = self.shape[0:-2] + (cout, -1)
  # permutation swapping the last two axes; a 1-D input keeps its single axis
  if n1 > 1: order = (*range(n1-2), n1-1, n1-2)
  else: order, out_shape_t = (0,), (cout,)
  worder = (*range(n2-2), n2-1, n2-2)

  # NOTE: with NHWC we can remove the transposes
  # input viewed as bs x groups*cin x H x W
  cx = self.permute(order=order).reshape(shape=(bs//groups, groups*cin, -1, 1))
  # weights viewed as groups*cout x cin x 1 x 1
  cw = w.permute(order=worder).reshape(shape=(groups*cout, cin, 1, 1))
  conv_out = cx.conv2d(cw, groups=groups)
  return conv_out.reshape(shape=out_shape_t).permute(order=order)
|
28
|
-
|
29
|
-
def image_conv2d(self, weight, bias=None, groups=1, stride=1, dilation=1, padding=0):
  """2D convolution rewritten as an elementwise multiply + sum over 4-wide packed channels.

  The unpacking below requires 4-D `self` (bs, _, iy, ix) and 4-D `weight` (cout, cin, H, W);
  the input channel axis is presumably groups*cin (it is reshaped that way) — confirm with callers.
  Per-group channel counts that are not multiples of 4 are padded up to one; the extra
  output channels are removed again before returning. With IMAGE >= 2 the packed input,
  weights and output are cast to an ImageDType built from base_image_type.
  Returns an NCHW tensor (bs, cout, oy, ox), with `bias` added per output channel if given.
  """
  (bs,_,iy,ix), (cout,cin,H,W) = self.shape, weight.shape
  rcout = cout//groups  # output channels per group
  x, w = self, weight.reshape(groups, rcout, cin, H, W)

  # hack for non multiples of 4 on cin
  # pad the per-group input channel axis (axis 2) of both x and w up to a multiple of 4
  if cin % 4 != 0 and not (cin == 1 and groups%4 == 0):
    x = x.reshape(bs, groups, cin, iy, ix) # do this always?
    added_input_channels = 4 - (cin % 4)
    w = w.pad(tuple((0, added_input_channels) if i == 2 else (0, 0) for i in range(len(w.shape))))
    x = x.pad(tuple((0, added_input_channels) if i == 2 else (0, 0) for i in range(len(x.shape))))
    cin = cin + added_input_channels
    x = x.reshape(bs, groups*cin, iy, ix)

  # hack for non multiples of 4 on rcout
  # grow w's per-group output axis (axis 1) to the rounded-up rcout via slice past the end
  # (presumably zero-filled — confirm Tensor.slice semantics); undone after the conv
  added_output_channels = 0
  if rcout % 4 != 0 and not (rcout == 1 and groups%4 == 0):
    added_output_channels = 4 - (rcout % 4)
    rcout += added_output_channels
    cout = groups * rcout
    w = w.slice(tuple((0, rcout) if i == 1 else (0, w.shape[i]) for i in range(len(w.shape))))

  # packed (note: flipping bs and iy would make the auto-padding work)
  # x -> (bs*iy, ix*groups*cin//4, 4): channels moved last and packed 4 per element
  x = x.permute(0,2,3,1).reshape(bs * iy, ix * groups * cin//4, 4)
  cin_last = iy == 1 and ix == 1
  # w -> (cout//4, ..., 4); the cin_last layout puts the 4 output channels innermost for 1x1 spatial inputs
  if cin == 1: w = w.reshape(cout//4,4,H*W).permute(0,2,1)
  elif cin_last: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,1,3).reshape(cout//4, H*cin//4*W*4, 4)
  else: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,3,1).reshape(cout//4, H*cin//4*W*4, 4)

  # contiguous creates the image, and early realize static weights (TODO: test for the static weight)
  if IMAGE >= 2: x,w = x.cast(ImageDType(*base_image_type, shape=x.shape)), w.cast(ImageDType(*base_image_type, shape=w.shape))
  x, w = x.contiguous(), w.contiguous()
  if get_single_root(w.lazydata).realized: w.realize()

  # expand out
  # split cin into (hi, lo=4) and describe how cout maps onto (groups, 4, rcout//4, 4)
  rcin_hi, rcin_lo = cin//4 if cin >= 4 else 1, 4 if cin >= 4 else 1
  cout_expand = [groups//4 if cin == 1 else groups, 4 if cin == 1 else 1, rcout//4 if rcout >= 4 else 1, 4 if rcout >= 4 else 1]
  x = x.reshape(bs, iy, ix, groups, rcin_hi, rcin_lo)
  if cin_last: w = w.reshape(cout//4, H, rcin_hi, W, 4, rcin_lo)
  else: w = w.reshape(cout//4, H, rcin_hi, W, rcin_lo, 4).permute(0,1,2,3,5,4)

  # padding
  # padding_ entries 0/1 pad the ix axis and 2/3 pad the iy axis (see the slice below);
  # a 2-element padding is read as (y, x)
  padding_ = [padding]*4 if isinstance(padding, int) else (padding if len(padding) == 4 else [padding[1], padding[1], padding[0], padding[0]])
  x = x.slice((None, (-padding_[2], x.shape[1]+padding_[3]), (-padding_[0], x.shape[2]+padding_[1]), None, None, None))

  # prepare input
  x = x.permute(0,3,4,5,1,2)._pool((H, W), stride, dilation) # -> (bs, groups, rcin_hi, rcin_lo, oy, ox, H, W)
  oy, ox = x.shape[4:6]
  x = x.permute(0,4,5,1,2,3,6,7).reshape(bs, oy, ox, *cout_expand[0:2], 1, 1, rcin_hi, rcin_lo, H, W)
  x = x.expand(bs, oy, ox, *cout_expand, rcin_hi, rcin_lo, H, W)

  # prepare weights
  w = w.permute(0,4,2,5,1,3)
  w = w.reshape((1, 1, 1, *cout_expand, rcin_hi, rcin_lo, H, W)).expand(x.shape)

  # the conv! (+ the bias)
  # reduce over the last four axes: rcin_hi, rcin_lo, H, W
  ret = (x*w).cast(dtypes.float32).sum((-4, -3, -2, -1))

  # reshape to image and cast back to image
  ret = ret.reshape(bs*oy, ox*cout//4, 4)
  if IMAGE >= 2: ret = ret.cast(ImageDType(*base_image_type, shape=ret.shape))
  if IMAGE >= 3: ret = ret.contiguous()

  # undo hack for non multiples of 4 on C.rcout
  if added_output_channels != 0:
    ret = ret.reshape(bs, oy, ox, groups, rcout)[:, :, :, :, :-added_output_channels]
    rcout -= added_output_channels
    cout = groups * rcout

  # NCHW output
  ret = ret.reshape(bs, oy, ox, cout).permute(0,3,1,2)
  return ret if bias is None else ret.add(bias.reshape(1, -1, 1, 1))
|
tinygrad/renderer/assembly_arm64.py
DELETED
@@ -1,169 +0,0 @@
|
|
1
|
-
import struct
|
2
|
-
from platform import system
|
3
|
-
from typing import Tuple, Dict, List, Optional
|
4
|
-
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps
|
5
|
-
from tinygrad.codegen.linearizer import UOps, UOp
|
6
|
-
from tinygrad.helpers import dtypes, CI
|
7
|
-
from tinygrad.codegen.assembly import uops_to_asmstyle, AssemblyLanguage
|
8
|
-
|
9
|
-
def float_to_hex(x):
  """Return the IEEE-754 single-precision bit pattern of `x` as 8 uppercase hex digits.

  The bytes are packed explicitly big-endian (most significant byte first), so the
  string is the same on every host; callers slice [:4] / [4:] for the high/low halves.
  Raises OverflowError if `x` is outside the float32 range and struct.error if it is not a number.
  """
  return struct.pack(">f", x).hex().upper()
|
10
|
-
def compute_offsets(total):
  """Split `total` into chunks of at most 4096 for successive sp adjustments.

  Full 4096 chunks come first, followed by the remainder when it is non-zero,
  e.g. 8200 -> [4096, 4096, 8] and 0 -> [].
  """
  full_chunks, leftover = divmod(total, 4096)
  chunks = [4096] * full_chunks
  if leftover: chunks.append(leftover)
  return chunks
|
13
|
-
|
14
|
-
def get_name(name):
  """Return the assembler symbol for `name`: Darwin needs names to start with a "_", other systems use it unchanged."""
  prefix = '_' if system() == 'Darwin' else ''
  return prefix + name
|
16
|
-
|
17
|
-
class ARM64Language(AssemblyLanguage):
  """AArch64 assembly language; inherits all of AssemblyLanguage's behavior unchanged."""
|
18
|
-
|
19
|
-
def specialize_to_arm64(fn_nm, asm):
  """Lower the generic (uop, out, vin, arg) instruction list into AArch64 assembly text.

  Params:
    fn_nm: kernel name, emitted as the global symbol (through get_name).
    asm: list of (uop, out, vin, arg) tuples. `out` and the non-int entries of `vin` are
      register objects exposing `.nm`, `.dtype` and a dtype at index 1; int entries of
      `vin` are immediates.
  Returns the assembly source as one string.

  Register allocation: a variable takes a register from s_regs (float dtypes) or x_regs
  (everything else) on first use and gives it back after its last use in `live_range`.
  When a pool is empty the variable is spilled to a 16-byte stack slot (growing
  var_size) and loaded through the temps s0-s2 / x12, x13, x16 before each use.
  x15 is used as scratch throughout; x17 keeps the entry sp so that arguments
  data8 and above can be read relative to it.
  """
  var_size = 16
  prev_uop:Optional[UOps] = None
  ins = []
  x_regs = ['x' + str(i) for i in reversed(range(12))]
  s_regs = ['s' + str(i) for i in reversed(range(3,32)) if i <= 7 or i >= 16]
  type_to_reg = {dtypes.double: "d", dtypes.half: 'h', dtypes.float32: 's', dtypes.bool: 'w', dtypes.int8:'w', dtypes.int32: 'w', dtypes.int64: 'x', dtypes.uint8:'w', dtypes.uint32: 'w', dtypes.uint64: 'x'}
  alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", BinaryOps.DIV: "div", BinaryOps.MAX: "max",
         BinaryOps.MOD: "", BinaryOps.CMPLT: "subs",
         UnaryOps.SIN:'bl ' + get_name('sinf'), UnaryOps.LOG2: 'bl ' + get_name("log2f"), UnaryOps.EXP2: 'bl ' + get_name("exp2f"), UnaryOps.SQRT: 'bl ' + get_name("sqrtf"),
         TernaryOps.MULACC: "madd", TernaryOps.WHERE: "fcsel"}

  def mov_imm(value, reg):
    """Append instructions that load the immediate `value` into `reg`."""
    # Manually move value into reg if value can't fit: build it in w15 with movz/movk, then sign-extend
    if value.__class__ is not float and abs(value) > abs(65535):
      ins.append(f"movz w15, #{value & 0xffff}")
      ins.append(f"movk w15, #{(value >> 16) & 0xffff}, lsl #16")
      ins.append(f"sxtw {reg}, w15")
    elif reg[0] == 's':
      # float constant: assemble its bit pattern in x15 and bounce it through [sp, 16] into the s register
      ins.append(f"movz x15, 0x{float_to_hex(value)[4:]}")
      ins.append(f"movk x15, 0x{float_to_hex(value)[:4]}, lsl #16")
      ins.append("str x15, [sp, 16]")
      ins.append(f"ldr {reg}, [sp, 16]")
    else:
      ins.append(f"mov {reg}, #{value}")

  # Get variables intervals: first and last instruction index at which each variable name appears
  live_range:Dict[str, List[int]] = {}
  for i, (uop, out, vin, arg) in enumerate(asm):
    for var in ([v for v in [out] + vin if v is not None and v.__class__ is not int]):
      live_range[var.nm] = [i,i] if var.nm not in live_range else [live_range[var.nm][0], i]

  mem_vars:Dict[str, int] = {}  # spilled variable name -> stack offset
  rtor:Dict[str, str] = {}      # variable name -> register currently holding it
  def allocate_regs(mvars):
    """Give every not-yet-allocated variable in `mvars` a register, spilling when the pool is empty."""
    nonlocal var_size
    for v in [v for v in mvars if v is not None and v.__class__ is not int and v.nm not in rtor]:
      available_regs = s_regs if dtypes.is_float(v[1]) else x_regs
      #NOTE: Very simple spill, everything that don't fit in regs goes to mem
      if not available_regs:
        # ARM needs the stack 16-byte aligned
        var_size += 16
        # the placeholder register must match the spilled variable's own dtype (v), like the pool above,
        # not the dtype of the current instruction's output
        available_regs.append('s0' if dtypes.is_float(v[1]) else 'x12')
        mem_vars[v.nm] = var_size
      rtor[v.nm] = available_regs.pop()

  temp_floats = ['s0', 's1', 's2']
  temp_ints = ['x12', 'x13', 'x16']
  for i, (uop, out, vin, arg) in enumerate(asm):
    # Clear regs out of interval
    for var, reg in list(rtor.items()):
      available_regs = s_regs if reg[0] == 's' else x_regs
      if var[1] not in 'B' and var not in mem_vars and i > live_range[var][1]:
        available_regs.append(rtor.pop(var))
    # Assign a registers to the variables using live ranges.
    allocate_regs([out] + vin)
    # Assign temp regs to vin and load them before direct use
    for t, v in enumerate([v for v in vin if v.__class__ is not int and v.nm in mem_vars]):
      rtor[v.nm] = temp_floats[t] if dtypes.is_float(v[1]) else temp_ints[t]
      # ARM64 addressing constraints https://devblogs.microsoft.com/oldnewthing/20220728-00/?p=106912
      ins.append(f"mov x15, {mem_vars[v.nm]}")
      ins.append(f"ldr {rtor[v.nm]}, [sp, x15]")

    if uop == UOps.SPECIAL:
      if arg.startswith('data'):
        # data 8 to n into the stack
        if int(arg[4:]) >= 8:
          ins.append(f"ldr x15, [x17, #{(int(arg[4:]) - 8) * 8}]")
          ins.append(f"mov {rtor[out.nm]}, x15")
      else:
        ins.append(f"mov {rtor[out.nm]}, #0")
        ins.append(f"loop_{arg}:")
    elif uop == UOps.CAST:
      if arg == BinaryOps.CMPLT:
        mov_imm(0.0, 's0')
        mov_imm(1.0, 's1')
        ins.append(f"fcsel {rtor[out.nm]}, s1, s0, lt")
      else:
        ins.append(f"sxtw {rtor[out.nm]}, w{rtor[vin[0].nm][1:]}")
    elif uop == UOps.ALU:
      if len(vin)==2 and vin[1].__class__ is int: mov_imm(vin[1], 'x15')
      if arg == BinaryOps.MUL and out.dtype == dtypes.bool:
        ins.append(f"ands {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}")
      elif arg == TernaryOps.WHERE:
        ins.append(f"fcmp {rtor[vin[0].nm]}, #0.0")
        ins.append(f"{alu[arg]} {rtor[out.nm]}, {rtor[vin[1].nm]}, {rtor[vin[2].nm]}, ne")
      elif arg in [UnaryOps.LOG2, UnaryOps.SIN, UnaryOps.EXP2, UnaryOps.SQRT]:
        #NOTE: Not a real instruction, use to emulate a ext call in unicorn
        if CI: ins.append(f"{alu[arg]} {rtor[out.nm]} {rtor[vin[0].nm]}")
        else:
          save_regs = [k for k in rtor.keys() if k != out.nm and k not in mem_vars]
          ins.append(f"sub sp, sp, #{(len(save_regs))*16}")
          # Save the registers before they are cleared by func call
          for slot,k in enumerate(save_regs,1):
            ins.append(f"str {rtor[k]}, [sp, #{16*slot}]")
          ins.append("stp x29, x30, [sp, #0]!")
          ins.append("mov x29, sp")
          ins.append(f"fmov s0, {rtor[vin[0].nm]}")
          ins.append(alu[arg])
          ins.append(f"fmov {rtor[out.nm]}, s0")
          ins.append("mov sp, x29")
          ins.append("ldp x29, x30, [sp], #0")
          for slot,k in enumerate(save_regs,1):
            ins.append(f"ldr {rtor[k]}, [sp, #{16*slot}]")
          ins.append(f"add sp, sp, #{len(save_regs)*16}")
      elif arg == BinaryOps.CMPLT:
        ins.append(f"{alu[arg]} {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}" if not dtypes.is_float(vin[0][1]) else f"fcmp {rtor[vin[0].nm]}, {rtor[vin[1].nm]}")
      elif arg == BinaryOps.MOD:
        # a % b == a - (a / b) * b, with the divisor in x15
        ins.append(f"udiv x14, {rtor[vin[0].nm]}, x15")
        ins.append(f"msub {rtor[out.nm]}, x14, x15, {rtor[vin[0].nm]}")
      else:
        ins.append(f"{'f' if dtypes.is_float(vin[0][1]) else 's' if arg == BinaryOps.DIV else ''}{alu[arg]} {', '.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}")
    elif uop == UOps.LOAD:
      if arg.__class__ in (int, float):
        mov_imm(arg, rtor[out.nm])
      else:
        #NOTE: if need casting load var in s/h0 or x/w12 temp regs
        reg_in = type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[out.nm]
        mov_imm(arg[0], "x15")
        ins.append(f"add x15, {rtor[vin[0].nm]}, x15")
        ins.append(f"ldr{'sb' if arg[2] is not None and arg[2] in (dtypes.int8, dtypes.uint8, dtypes.bool) else ''} {reg_in}, [x15]")
        if arg[2] is not None: ins.append(f"{'fcvt' if arg[2] in [dtypes.half, dtypes.double] else 'scvtf'} {rtor[out.nm]}, {reg_in}")
    elif uop == UOps.STORE:
      #NOTE: if need casting load var in s/h0 or x/w12 temp regs
      reg_out = (type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[vin[1].nm])
      if arg[2] is not None: ins.append(f"fcvt{'zs' if arg[2] not in [dtypes.half, dtypes.double] else '' } {reg_out}, {rtor[vin[1].nm]}")
      # load the offset exactly like LOAD does, so offsets wider than 16 bits still assemble
      mov_imm(arg[0], "x15")
      ins.append(f"str {reg_out}, [{rtor[vin[0].nm]}, x15, lsl #0]")
    elif uop == UOps.COND_BRANCH:
      #TODO: this is a hack it shouldn't always be a cmp before a cond branch?
      if prev_uop == UOps.LOAD:
        ins.append(f"cmp {rtor[vin[0].nm]}, #0")
      ins.append(f"b.{'lt' if arg[1] else 'ge'} {arg[0][1:]}")
    elif uop == UOps.LABEL:
      ins.append(f"{arg[1:]}:")
    elif uop == UOps.ENDLOOP:
      mov_imm(arg[0], "x15")
      ins.append(f"add {rtor[vin[0].nm]}, {rtor[vin[0].nm]}, #1")
      ins.append(f"cmp {rtor[vin[0].nm]}, x15")
      ins.append(f"b.lt loop_{arg[1]}")
    prev_uop = uop
    # store regs into memory if needed
    if out is not None and out.nm in mem_vars:
      ins.append(f"mov x15, {mem_vars[out.nm]}")
      ins.append(f"str {rtor[out.nm]}, [sp, x15]")

  # the frame is allocated and released in chunks of at most 4096 bytes
  frame = compute_offsets(var_size)
  header = [f"//varsize {var_size}", ".arch armv8-a", ".text", f".global {get_name(fn_nm)}", ".p2align 2", f"{get_name(fn_nm)}:", "mov x17, sp"]
  return "\n".join(header + [f"sub sp, sp, #{offset}" for offset in frame] + ins + [f"add sp, sp, #{offset}" for offset in frame] + ["ret", "\n"])
|
165
|
-
|
166
|
-
def uops_to_arm64_asm(fn_nm:str, uops:List[UOp]) -> Tuple[str, List[int], List[int], bool]:
  """Render `uops` as AArch64 assembly for the kernel named `fn_nm`.

  Returns (assembly source, reversed global size, reversed local size, True).
  """
  language = ARM64Language()
  gsz, lsz = uops_to_asmstyle(language, fn_nm, uops)
  source = specialize_to_arm64(fn_nm, language.ins)
  return source, gsz[::-1], lsz[::-1], True
|
tinygrad/renderer/assembly_ptx.py
DELETED
@@ -1,98 +0,0 @@
|
|
1
|
-
from typing import List
|
2
|
-
import struct
|
3
|
-
from tinygrad.codegen.assembly import uops_to_asmstyle, AssemblyLanguage
|
4
|
-
from tinygrad.codegen.linearizer import UOps, UOp
|
5
|
-
from tinygrad.helpers import dtypes
|
6
|
-
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps
|
7
|
-
from tinygrad.runtime.ops_cuda import arch
|
8
|
-
|
9
|
-
# tinygrad dtype -> PTX type suffix; the "bits16" key gives the raw 16-bit type used when storing float16
dtype_to_nvtype = {
  dtypes.float32: "f32", dtypes.float16: "f16",
  dtypes.int64: "s64", dtypes.int32: "s32", dtypes.int8: "s8",
  dtypes.bool: "pred",
  dtypes.uint64: "u64", dtypes.uint32: "u32", dtypes.uint16: "u16", dtypes.uint8: "u8",
  "bits16": "b16",
  dtypes.float64: "f64",
}
|
10
|
-
def float_to_hex(x):
  """Return the IEEE-754 single-precision bit pattern of `x` as 8 uppercase hex digits (used for PTX 0f literals).

  The bytes are packed explicitly big-endian (most significant byte first), so the
  string is the same on every host.
  Raises OverflowError if `x` is outside the float32 range and struct.error if it is not a number.
  """
  return struct.pack(">f", x).hex().upper()
|
11
|
-
|
12
|
-
def ptx_needs_cast(dest_dtype, src_dtype):
  """Return whether moving a src_dtype value into dest_dtype needs a conversion.

  True for int->float, float->int, and float->float between different itemsizes.
  """
  if dtypes.is_float(dest_dtype) and dtypes.is_int(src_dtype): return True
  if dtypes.is_int(dest_dtype) and dtypes.is_float(src_dtype): return True
  return dtypes.is_float(src_dtype) and dtypes.is_float(dest_dtype) and dest_dtype.itemsize != src_dtype.itemsize
|
13
|
-
|
14
|
-
def render_cast(ins, inp, out):
  """Append to `ins` the PTX that converts register `inp` into register `out`.

  bool -> number: selp between 1 and 0 (1.0f / 0.0f for floats).
  number -> bool: setp.ne against zero.
  otherwise: cvt with .rzi for float->int, .rz for int->float or narrowing float->float,
  and no rounding modifier in the remaining cases.
  """
  src, dst = inp.dtype, out.dtype
  if src == dtypes.bool and (dtypes.is_float(dst) or dtypes.is_int(dst)):
    choices = '0f3F800000, 0f00000000' if dtypes.is_float(dst) else '1, 0'
    ins.append(f"selp.{dtype_to_nvtype[dst]} {out}, {choices}, {inp};")
    return
  if dst == dtypes.bool:
    zero = '0f00000000' if dtypes.is_float(src) else '0'
    ins.append(f"setp.ne.{dtype_to_nvtype[src]} {out}, {zero}, {inp};")
    return
  if dtypes.is_int(dst) and dtypes.is_float(src): round_mod = ".rzi"
  elif dtypes.is_float(dst) and (dtypes.is_int(src) or dtypes.is_float(src) and src.itemsize > dst.itemsize): round_mod = '.rz'
  else: round_mod = ''
  ins.append(f"cvt{round_mod}.{dtype_to_nvtype[dst]}.{dtype_to_nvtype[src]} {out}, {inp};")
|
22
|
-
|
23
|
-
# PTX ISA reference: https://docs.nvidia.com/cuda/parallel-thread-execution/
|
24
|
-
|
25
|
-
class PTXLanguage(AssemblyLanguage):
  """AssemblyLanguage settings for emitting NVIDIA PTX."""
  # constant folding is enabled for this target
  supports_constant_folding: bool = True
|
27
|
-
|
28
|
-
def specialize_to_ptx(lang, function_name):
  """Render the instructions collected in `lang.ins` as a PTX `.entry` named `function_name`.

  Each element of lang.ins is a (uop, out, vin, arg) tuple; registers render through str().
  One `.param .u64 dataN` is declared per SPECIAL "data" uop, and a `.reg` declaration is
  emitted for each register class counted in `lang.cnts`.
  Returns the PTX source as a single string.
  """
  param_cnt = 0
  ins = []
  alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", BinaryOps.DIV: "div", BinaryOps.MAX: "max",
         BinaryOps.MOD: "rem", BinaryOps.CMPLT: "setp.lt", UnaryOps.SQRT: "sqrt.approx",
         UnaryOps.NOOP: "mov", UnaryOps.SIN: "sin.approx", UnaryOps.LOG2: "lg2.approx", UnaryOps.EXP2: "ex2.approx.ftz",
         TernaryOps.MULACC: "fma.rn", TernaryOps.WHERE: "selp"}
  for uop, out, vin, arg in lang.ins:
    if uop == UOps.ENDLOOP:
      ins.append("bar.sync 0;")
    elif uop == UOps.DEFINE_LOCAL:
      ins.append(f".shared .align 4 .b8 {arg[0]}[{arg[1]*4}];")
    elif uop == UOps.SPECIAL:
      if arg.startswith('data'):
        param_cnt += 1
        ins.append(f"ld.param.u64 {out}, [{arg}];")
        # TODO: we sometimes want this to be local, nvcc converts to global most of the time, not sure when we would need to?
        # ins.append(f"cvta.to.global.u64 {out}, {out};")
      elif arg.startswith('gid'):
        ins.append(f"mov.u32 {out}, %ctaid.{'xyz'[int(arg[3:])]};")
      elif arg.startswith('lid'):
        ins.append(f"mov.u32 {out}, %tid.{'xyz'[int(arg[3:])]};")
    elif uop == UOps.ALU:
      if arg == BinaryOps.MUL and out.dtype == dtypes.bool:
        ins.append(f"and.pred {out}, {', '.join(str(x) for x in vin)};")
      else:
        otype = vin[0].dtype if arg in [BinaryOps.CMPLT] else out.dtype
        if arg == TernaryOps.WHERE:
          # selp needs a predicate: materialize (cond != 0) and pass it as the last operand
          reg = lang.newreg((vin[0], 'bool'), dtypes.bool)
          ins.append(f"setp.ne.{dtype_to_nvtype[vin[0].dtype]} {reg}, {'0f00000000' if dtypes.is_float(vin[0].dtype) else '0'}, {vin[0]};")
          vin = vin[1:] + [reg]
        # PTX defines the .lo mode only for integer mul (mul.lo.f16/.f64 is invalid), and requires
        # a rounding modifier on f32 and f64 div
        lo = '.lo' if arg == BinaryOps.MUL and not dtypes.is_float(out.dtype) else ''
        rn = '.rn' if arg == BinaryOps.DIV and out.dtype in (dtypes.float32, dtypes.float64) else ''
        ins.append(f"{alu[arg]}{lo}{rn}.{dtype_to_nvtype[otype]} {out}, {', '.join(str(x) for x in vin)};")
    elif uop == UOps.LOAD:
      if arg.__class__ in (int, float):
        ins.append(f"mov.{dtype_to_nvtype[out.dtype]} {out}, {'0f'+float_to_hex(arg) if dtypes.is_float(out.dtype) else int(arg)};")
      elif arg[2] is not None and (arg[2] == dtypes.bool or arg[2] != out.dtype):
        # load into a temp register of the memory type, then convert into `out`
        dt = ('u16', dtypes.uint16) if arg[2] == dtypes.bool == out.dtype else ('u8', dtypes.uint8) if arg[2] == dtypes.bool else ('b16', dtypes.float16) if arg[2] == dtypes.half else (dtype_to_nvtype[arg[2]], arg[2])
        reg = lang.newreg((out, dt[0]), dtype=dt[1])
        ins.append(f"ld.{arg[1]}.{dt[0]} {reg}, [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}];")
        render_cast(ins, reg, out)
      else:
        ins.append(f"ld.{arg[1]}.{dtype_to_nvtype[dtypes.float if arg[2] is None else arg[2]]} {out}, [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}];")
    elif uop == UOps.STORE:
      if ptx_needs_cast(dtypes.float if arg[2] is None else arg[2], vin[1].dtype) or arg[2] == dtypes.bool:
        # convert into a temp register of the memory type (via a predicate first for bool) before storing
        if arg[2] == dtypes.bool != vin[1].dtype:
          prereg = lang.newreg((vin[1],'bool'), dtype=dtypes.bool)
          render_cast(ins, vin[1], prereg)
        else: prereg = vin[1]
        reg = lang.newreg((prereg, dtypes.uint16 if arg[2] == dtypes.bool else arg[2]), dtype=dtypes.uint16 if arg[2] == dtypes.bool else dtypes.float if arg[2] is None else arg[2])
        render_cast(ins, prereg, reg)
        ins.append(f"st.{arg[1]}.{dtype_to_nvtype['bits16' if arg[2] == dtypes.float16 else dtypes.uint8 if arg[2] == dtypes.bool else dtypes.float if arg[2] is None else arg[2]]} [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}], {reg};")
      else:
        ins.append(f"st.{arg[1]}.{dtype_to_nvtype[dtypes.float if arg[2] is None else arg[2]]} [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}], {vin[1]};")
    elif uop == UOps.CAST:
      render_cast(ins, vin[0], out)
    elif uop == UOps.LABEL:
      ins.append(f"{arg}:")
    elif uop == UOps.COND_BRANCH:
      ins.append(f"@{'!' if not arg[1] else ''}{vin[0]} bra {arg[0]};")

  ins_prefix = [".version 7.8", ".target " + arch(), ".address_size 64",
                f".visible .entry {function_name}({', '.join(f'.param .u64 data{i}' for i in range(param_cnt))}) {{"]
  for decl in [(dtype, lang.type_to_letter(dtype), c) for dtype,c in lang.cnts.items()]: ins_prefix.append(f".reg .{dtype_to_nvtype[decl[0][0]]} %{decl[1]}<{decl[2]}>;",)
  ins = ins_prefix + ins
  ins += ["ret;", "}"]
  return '\n'.join(ins)
|
94
|
-
|
95
|
-
def uops_to_ptx_asm(function_name:str, uops:List[UOp]):
  """Render `uops` as a PTX kernel named `function_name`.

  Returns (PTX source, reversed global size, reversed local size, True).
  """
  language = PTXLanguage()
  gsz, lsz = uops_to_asmstyle(language, function_name, uops)
  source = specialize_to_ptx(language, function_name)
  return source, gsz[::-1], lsz[::-1], True
|
tinygrad/renderer/wgsl.py
DELETED
@@ -1,53 +0,0 @@
|
|
1
|
-
from tinygrad.renderer.cstyle import render_cl
|
2
|
-
from tinygrad.helpers import dtypes, DType
|
3
|
-
from tinygrad.renderer.cstyle import CStyleLanguage
|
4
|
-
from typing import List, Union
|
5
|
-
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
|
6
|
-
import math
|
7
|
-
from typing import Tuple
|
8
|
-
|
9
|
-
# tinygrad dtype -> WGSL scalar type name (used for storage buffer element types and store casts)
type_map = {
  dtypes.float: "f32",
  dtypes.half: "f16",
  dtypes.int32: "i32",
  dtypes.uint32: "u32",
  dtypes.bool: "bool",
}
|
10
|
-
class WGSLLanguage(CStyleLanguage):
  """CStyleLanguage settings and overrides that render kernels as WGSL compute shaders."""
  # thread indices come from the gindex/lindex builtins declared in render_kernel's signature
  gid = [f"i32(gindex.{axis})" for axis in "xyz"]
  lid = [f"i32(lindex.{axis})" for axis in "xyz"]
  size_prefix = "let"
  barrier = "workgroupBarrier();"
  generic_var_prefix = "var "
  external_local_bufs = True
  code_for_op = {
    UnaryOps.EXP2: lambda x: f"exp2({x})", UnaryOps.LOG2: lambda x: f"log2({x})",
    UnaryOps.SIN: lambda x: f"sin({x})", UnaryOps.SQRT: lambda x: f"sqrt({x})",
    BinaryOps.ADD: lambda x,y: f"({x}+{y})", BinaryOps.SUB: lambda x,y: f"({x}-{y})",
    BinaryOps.MUL: lambda x,y: f"({x}*{y})", BinaryOps.DIV: lambda x,y: f"({x}/{y})",
    BinaryOps.MAX: lambda x,y: f"max({x},{y})", BinaryOps.CMPLT: lambda x,y: f"f32({x}<{y})",
    TernaryOps.MULACC: lambda x,y,z: f"fma({x},{y},{z})", TernaryOps.WHERE: lambda a,b,c: f"select({c},{b},{a}!=0.)",
  }

  def render_local(self, name: str, size: int):
    """Declare a workgroup-scope f32 array of `size` elements."""
    return "var<workgroup> {}: array<f32,{}>;".format(name, size)

  def render_const(self, x:Union[float,int], var_dtype) -> str:
    """Render constant `x`, broadcasting it through render_cast when var_dtype has more than one lane."""
    # nan goes through the nan() helper that render_kernel prepends; infinities are emitted as
    # +/-0x1.fffffep+127f, the largest finite f32
    if math.isnan(x): literal = "nan()"
    elif math.isinf(x): literal = f"{'-' if x < 0 else ''}0x1.fffffep+127f"
    else: literal = f"{x}{'' if dtypes.is_int(var_dtype) else 'f'}"
    if var_dtype.sz > 1: return self.render_cast([literal]*var_dtype.sz, var_dtype)
    return literal

  def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,DType]], global_size:List[int], local_size:List[int], prekernel:List[str]) -> Tuple[str, List[int], List[int]]:
    """Assemble the full shader source; returns (source, reversed global size or [1], reversed local size or [1])."""
    local_size = local_size[::-1] if local_size else [1]
    # one read_write storage binding per buffer, numbered in order
    bindings = [f"@group(0) @binding({binding}) var<storage,read_write> {name}: array<{type_map[dtype]}>;" for binding, (name, dtype) in enumerate(bufs)]
    prg = "fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }\n"
    prg += "\n".join(prekernel + bindings)
    workgroup = ','.join(str(dim) for dim in local_size)
    prg += f"\n@compute @workgroup_size({workgroup}) fn {function_name}(@builtin(workgroup_id) gindex: vec3<u32>, @builtin(local_invocation_id) lindex: vec3<u32>) {{\n" + "\n".join(kernel) + "\n}"
    return prg, global_size[::-1] if global_size else [1], local_size

  def render_for(self, expr:str, _min:int, _max:Union[int,str]) -> str:
    """Open a for loop over expr in [_min, _max] (inclusive upper bound)."""
    return "for(var {0} = {1}; {0} <= {2}; {0}++) {{".format(expr, _min, _max)

  def render_conditional(self, cond:str, x:str, y:str) -> str:
    """Select x when cond is true, else y (y is cast to f32)."""
    return "select(f32({}), {}, bool({}))".format(y, x, cond)

  def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str:
    """Render the base-class load wrapped in an f32 conversion."""
    inner = super().render_load(output_dtype, buf_name, buf_dtype, idx, local)
    return f"f32({inner})"

  def render_store(self, buf_name:str, buf_dtype:DType, var_name:str, var_dtype:DType, idx, local=False) -> str:
    """Render `buf[idx] = value`, converting the value to the buffer's WGSL type when the dtypes differ."""
    value = f"{type_map[buf_dtype]}({var_name})" if buf_dtype != var_dtype else var_name
    return f"{buf_name}[{idx.render(render_cl)}] = {value};"
|
tinygrad/runtime/lib.py
DELETED
@@ -1,113 +0,0 @@
|
|
1
|
-
import ctypes
|
2
|
-
import numpy as np
|
3
|
-
from collections import defaultdict, deque
|
4
|
-
from typing import TypeVar, Type, Any, Dict, Deque, Tuple
|
5
|
-
from tinygrad.helpers import DType, dtypes, prod, GlobalCounters, ImageDType
|
6
|
-
|
7
|
-
# Type variable tying RawBuffer.fromCPU's return type to the class it is invoked on.
_T = TypeVar("_T")
|
8
|
-
class RawBuffer: # pylint: disable=abstract-method
  """Buffer of `size` elements of `dtype` whose byte size is tracked in GlobalCounters.mem_used.

  Storage is `buf` when given; otherwise allocator.alloc(size, dtype, **kwargs) when an
  allocator is given; otherwise None. If an allocator was passed, `_buf` is handed to
  allocator.free when the object is deleted.
  """
  def __init__(self, size:int, dtype:DType, buf:Any=None, allocator:Any=None, **kwargs):
    self.size: int = size
    self.dtype: DType = dtype
    if buf is not None: self._buf = buf
    elif allocator: self._buf = allocator.alloc(size, dtype, **kwargs)
    else: self._buf = None
    self._memsz: int = size*dtype.itemsize
    self._allocator = allocator
    GlobalCounters.mem_used += self._memsz

  def __del__(self):
    # NOTE: if __init__ failed part-way (bad dtype) the attributes may be missing
    if hasattr(self, '_memsz'): GlobalCounters.mem_used -= self._memsz
    if getattr(self, '_allocator', None): self._allocator.free(self._buf)

  def __repr__(self): return "buffer<{}, {}>".format(self.size, self.dtype)

  @property
  def key(self): return self.size, self.dtype

  # NOTE: this interface allows for 0 copy
  @classmethod
  def fromCPU(cls:Type[_T], x:np.ndarray) -> _T: raise NotImplementedError("must be implemented")
  def toCPU(self) -> np.ndarray: raise NotImplementedError("must be implemented")
|
27
|
-
|
28
|
-
class RawConst(RawBuffer): # pylint: disable=abstract-method
  """RawBuffer standing for a constant; its `_buf` (presumably the constant value) drives repr and key."""
  def __repr__(self): return "const<{}, {}>".format(self._buf, self.dtype)

  @property
  def key(self): return str(self._buf), self.dtype
|
32
|
-
|
33
|
-
def buf_is_kernel_arg(x) -> bool:
  """Return True when `x` has been realized into a buffer that is not a RawConst."""
  realized = x.realized
  if realized is None: return False
  return realized.__class__ is not RawConst
|
35
|
-
|
36
|
-
# --teenygrad--
|
37
|
-
|
38
|
-
class RawBufferCopyIn(RawBuffer):
  """RawBuffer that can be filled from a host numpy array through `_copyin`."""
  def _copyin(self, x:np.ndarray) -> None: raise NotImplementedError("must be implemented")

  @classmethod
  def fromCPU(cls, x:np.ndarray, **kwargs):
    # one element per array entry, dtype derived from the numpy dtype
    buffer = cls(prod(x.shape), dtypes.from_np(x.dtype), **kwargs)
    buffer._copyin(x)
    return buffer
|
46
|
-
|
47
|
-
class RawBufferMapped(RawBufferCopyIn):
  """RawBufferCopyIn whose storage is exposed as a memoryview; toCPU returns a numpy view of it."""
  def _buffer(self) -> memoryview: raise NotImplementedError("must be implemented")

  # NOTE: this metadata prevents the backing buffer from being freed. hack can be removed with PEP688
  def toCPU(self) -> np.ndarray:
    mem = self._buffer()
    return np.frombuffer(mem, dtype=np.dtype(self.dtype.np, metadata={"backing": self})) # type: ignore

  def _copyin(self, x:np.ndarray) -> None:
    # write through the view returned by toCPU
    np.copyto(self.toCPU(), x.reshape(-1))
|
52
|
-
|
53
|
-
# this one is simple enough that i moved it out of the runtimes
|
54
|
-
class RawMallocBuffer(RawBufferMapped):
|
55
|
-
def __init__(self, size, dtype: DType): super().__init__(size, dtype, ({dtypes.float64:ctypes.c_double, dtypes.float32: ctypes.c_float, dtypes.float16: ctypes.c_int16, dtypes.bfloat16: ctypes.c_int16, dtypes.int8: ctypes.c_int8, dtypes.uint8: ctypes.c_uint8, dtypes.bool: ctypes.c_uint8, dtypes.int32: ctypes.c_int32, dtypes.uint32: ctypes.c_uint32, dtypes.int64: ctypes.c_int64, dtypes.uint64: ctypes.c_uint64}[dtype] * size)())
|
56
|
-
def _buffer(self): return memoryview(self._buf)
|
57
|
-
|
58
|
-
class RawBufferCopyInOut(RawBufferCopyIn):
|
59
|
-
def _copyout(self, x:np.ndarray) -> None: raise NotImplementedError("must be implemented")
|
60
|
-
|
61
|
-
def toCPU(self) -> np.ndarray:
|
62
|
-
x: np.ndarray = np.empty(self.size, dtype=self.dtype.np)
|
63
|
-
self._copyout(x)
|
64
|
-
return x
|
65
|
-
|
66
|
-
class RawBufferTransfer(RawBuffer):
|
67
|
-
def _transfer(self, x) -> None: raise NotImplementedError("must be implemented")
|
68
|
-
|
69
|
-
@classmethod
|
70
|
-
def transfer(cls, x, shape, dtype, **kwargs):
|
71
|
-
ret = cls(prod(shape), dtype, **kwargs)
|
72
|
-
ret._transfer(x)
|
73
|
-
return ret
|
74
|
-
|
75
|
-
class LRUAllocator:
|
76
|
-
def __init__(self, dev_memsz=(4<<30)):
|
77
|
-
self.epoch = 0
|
78
|
-
self.free_space: Dict[Any, int] = defaultdict(lambda: dev_memsz)
|
79
|
-
self.buffer_info: Dict[Any, Tuple[int, DType, str]] = dict()
|
80
|
-
self.cached_buffers: Dict[Tuple[int, ...], Deque[Tuple[Any, int]]] = defaultdict(deque) # Cached buffer storage, splitted by type and size, newest first.
|
81
|
-
self.aging_order: Dict[Any, Deque[Tuple[Tuple[int, ...], int]]] = defaultdict(deque) # Keys of cached_buffers, ordered from oldest to newest updates.
|
82
|
-
def __del__(self):
|
83
|
-
for v in self.cached_buffers.values():
|
84
|
-
for buf, _ in v: self._free_buffer(buf)
|
85
|
-
def _cache_reuse_buffer(self, rawbufs: Deque[Tuple[Any, int]]): # The newest cached buffer is reused.
|
86
|
-
GlobalCounters.mem_cached -= self._underlying_buf_memsz(rawbufs[0][0])
|
87
|
-
return rawbufs.popleft()[0]
|
88
|
-
def _alloc_buffer(self, size, dtype, device, **kwargs):
|
89
|
-
self.free_space[device] -= size*dtype.itemsize
|
90
|
-
while len(self.aging_order[device]) and self.free_space[device] < 0: # When OOM removing lru buffers.
|
91
|
-
bucket, epoch = self.aging_order[device].popleft()
|
92
|
-
if self.cached_buffers[bucket] and self.cached_buffers[bucket][-1][1] == epoch: self._free_buffer(self.cached_buffers[bucket].pop()[0]) # Free cached buffer if it is still in cache.
|
93
|
-
newbuf = self._do_alloc(size, dtype, device, **kwargs)
|
94
|
-
self.buffer_info[newbuf] = (size, dtype, device)
|
95
|
-
return newbuf
|
96
|
-
def _free_buffer(self, buf_to_free):
|
97
|
-
self.free_space[self.buffer_info[buf_to_free][2]] += self._underlying_buf_memsz(buf_to_free)
|
98
|
-
GlobalCounters.mem_cached -= self._underlying_buf_memsz(buf_to_free)
|
99
|
-
self.buffer_info.pop(buf_to_free)
|
100
|
-
self._do_free(buf_to_free)
|
101
|
-
def alloc(self, size, dtype, device='0', **kwargs):
|
102
|
-
rawbufs = self.cached_buffers.get(self._cached_bufkey(size, dtype, device), None)
|
103
|
-
return self._cache_reuse_buffer(rawbufs) if rawbufs else self._alloc_buffer(size, dtype, device, **kwargs)
|
104
|
-
def free(self, buf): # free() just caches buffer. It might be freed later when OOM during allocation.
|
105
|
-
self.epoch += 1
|
106
|
-
size, dtype, device = self.buffer_info[buf]
|
107
|
-
self.cached_buffers[self._cached_bufkey(size, dtype, device)].appendleft((buf, self.epoch))
|
108
|
-
self.aging_order[device].append((self._cached_bufkey(size, dtype, device), self.epoch))
|
109
|
-
GlobalCounters.mem_cached += self._underlying_buf_memsz(buf)
|
110
|
-
def _underlying_buf_memsz(self, buf): return self.buffer_info[buf][0] * self.buffer_info[buf][1].itemsize
|
111
|
-
def _cached_bufkey(self, size, dtype, device) -> Tuple[int, ...]: return (device, size, dtype, dtype.shape) if isinstance(dtype, ImageDType) else (device, size, dtype) # Provides a key for reusing device buffers with identical keys.
|
112
|
-
def _do_alloc(self, size, dtype, device, **kwargs): raise NotImplementedError("must be implemented")
|
113
|
-
def _do_free(self, buf): pass
|
tinygrad/runtime/ops_cpu.py
DELETED
@@ -1,51 +0,0 @@
|
|
1
|
-
import numpy as np
|
2
|
-
import operator
|
3
|
-
from typing import Callable, Dict, Tuple, Optional
|
4
|
-
from tinygrad.helpers import dtypes, DType
|
5
|
-
from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, ReduceOps, TernaryOps, Op, Interpreted
|
6
|
-
from tinygrad.runtime.lib import RawBuffer
|
7
|
-
|
8
|
-
def shape_to_axis(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]) -> Tuple[int, ...]:
|
9
|
-
assert len(old_shape) == len(new_shape), "reduce shapes must have same dimensions"
|
10
|
-
return tuple(i for i,(a,b) in enumerate(zip(old_shape, new_shape)) if a != b)
|
11
|
-
|
12
|
-
base_fxn_for_op: Dict[Op, Callable] = {
|
13
|
-
BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.MUL: operator.mul, BinaryOps.DIV: operator.truediv,
|
14
|
-
ReduceOps.SUM: lambda x, new_shape: x.sum(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
|
15
|
-
ReduceOps.MAX: lambda x, new_shape: (x.amax if hasattr(x, 'amax') else x.max)(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
|
16
|
-
MovementOps.RESHAPE: lambda x, arg: x.reshape(arg), MovementOps.SHRINK: lambda x, arg: x[tuple(slice(p[0], p[1], None) for p in arg)],
|
17
|
-
}
|
18
|
-
|
19
|
-
def promote_types(x, y): return ret if (ret := np.promote_types(x.dtype, y.dtype)) != np.float64 else np.float32
|
20
|
-
def match_types(x, y):
|
21
|
-
up = x.dtype if dtypes.from_np(x.dtype).priority > dtypes.from_np(y.dtype).priority else y.dtype
|
22
|
-
return x.astype(up, copy=False), y.astype(up, copy=False)
|
23
|
-
|
24
|
-
def einsum_mulacc(einsum, get_strides, expand):
|
25
|
-
def einscripts(x): return ''.join(["abcdefghijklmnopqrstuvwxyz"[i] for i in x])
|
26
|
-
def axes_slice(strides): return [i for i in range(len(strides)) if strides[i] != 0], tuple(slice(None) if strides[i] != 0 else 0 for i in range(len(strides)))
|
27
|
-
def mulacc(a, b, new_shape):
|
28
|
-
(a_axes, a_slices), (b_axes, b_slices) = axes_slice(get_strides(a)), axes_slice(get_strides(b))
|
29
|
-
out = [i for i in range(len(new_shape)) if a.shape[i] == new_shape[i] and (i in a_axes or i in b_axes)]
|
30
|
-
ret = einsum(f"{einscripts(a_axes)}, {einscripts(b_axes)} -> {einscripts(out)}", a[a_slices], b[b_slices])
|
31
|
-
return expand(ret.reshape([(1 if i not in a_axes and i not in b_axes else s) for i,s in enumerate(new_shape)]), new_shape)
|
32
|
-
return mulacc
|
33
|
-
|
34
|
-
numpy_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
|
35
|
-
UnaryOps.NOOP: lambda x: np.require(x, requirements='C'), UnaryOps.EXP2: np.exp2, UnaryOps.LOG2: np.log2, UnaryOps.SIN: np.sin,
|
36
|
-
UnaryOps.CAST: lambda x,y: x.view(y[0].np) if y[1] else x.astype(y[0].np, copy=False),
|
37
|
-
BinaryOps.MAX: np.maximum, BinaryOps.CMPLT: lambda x,y: (x<y).astype(promote_types(x,y)), BinaryOps.ADD: lambda x, y: np.add(*match_types(x, y)),
|
38
|
-
BinaryOps.SUB: lambda x, y: np.subtract(*match_types(x, y)), BinaryOps.MUL: lambda x, y: np.multiply(*match_types(x, y)),
|
39
|
-
BinaryOps.DIV: lambda x, y: np.divide(*match_types(x, y)), UnaryOps.SQRT: np.sqrt,
|
40
|
-
MovementOps.PERMUTE: lambda x, order: x.transpose(order), MovementOps.PAD: np.pad, MovementOps.EXPAND: np.broadcast_to,
|
41
|
-
MovementOps.STRIDE: lambda x, arg: x[tuple(slice(None, None, i) for i in arg)],
|
42
|
-
TernaryOps.MULACC: einsum_mulacc(lambda s,a,b: np.einsum(s, *match_types(a.copy(), b.copy()), optimize=True), lambda x: x.strides, np.broadcast_to),
|
43
|
-
TernaryOps.WHERE: np.where,
|
44
|
-
}}
|
45
|
-
|
46
|
-
class RawNumpyBuffer(RawBuffer):
|
47
|
-
def __init__(self, size:int, dtype:DType, buf:Optional[np.ndarray]=None): super().__init__(size, dtype, buf if buf is not None else np.empty([size], dtype.np))
|
48
|
-
@classmethod
|
49
|
-
def fromCPU(cls, x): return cls(x.size, dtypes.from_np(x.dtype), x)
|
50
|
-
def toCPU(self): return self._buf
|
51
|
-
CPUBuffer = Interpreted(RawNumpyBuffer, numpy_fxn_for_op, from_underlying=RawNumpyBuffer.fromCPU)
|