tinygrad-0.7.0-py3-none-any.whl → tinygrad-0.9.0-py3-none-any.whl

Files changed (77)
  1. tinygrad/__init__.py +6 -0
  2. tinygrad/codegen/kernel.py +572 -83
  3. tinygrad/codegen/linearizer.py +415 -395
  4. tinygrad/codegen/uops.py +415 -0
  5. tinygrad/device.py +183 -0
  6. tinygrad/dtype.py +113 -0
  7. tinygrad/engine/__init__.py +0 -0
  8. tinygrad/engine/graph.py +100 -0
  9. tinygrad/engine/jit.py +195 -0
  10. tinygrad/engine/realize.py +191 -0
  11. tinygrad/engine/schedule.py +362 -0
  12. tinygrad/engine/search.py +196 -0
  13. tinygrad/{mlops.py → function.py} +76 -55
  14. tinygrad/helpers.py +196 -89
  15. tinygrad/lazy.py +210 -371
  16. tinygrad/multi.py +169 -0
  17. tinygrad/nn/__init__.py +202 -22
  18. tinygrad/nn/datasets.py +7 -0
  19. tinygrad/nn/optim.py +112 -32
  20. tinygrad/nn/state.py +136 -39
  21. tinygrad/ops.py +119 -202
  22. tinygrad/renderer/__init__.py +61 -0
  23. tinygrad/renderer/assembly.py +276 -0
  24. tinygrad/renderer/cstyle.py +353 -166
  25. tinygrad/renderer/llvmir.py +150 -138
  26. tinygrad/runtime/autogen/amd_gpu.py +1900 -0
  27. tinygrad/runtime/autogen/comgr.py +865 -0
  28. tinygrad/runtime/autogen/cuda.py +5923 -0
  29. tinygrad/runtime/autogen/hip.py +5909 -0
  30. tinygrad/runtime/autogen/hsa.py +5761 -0
  31. tinygrad/runtime/autogen/kfd.py +812 -0
  32. tinygrad/runtime/autogen/nv_gpu.py +33328 -0
  33. tinygrad/runtime/autogen/opencl.py +1795 -0
  34. tinygrad/runtime/driver/hip_comgr.py +47 -0
  35. tinygrad/runtime/driver/hsa.py +143 -0
  36. tinygrad/runtime/graph/clang.py +38 -0
  37. tinygrad/runtime/graph/cuda.py +81 -0
  38. tinygrad/runtime/graph/hcq.py +143 -0
  39. tinygrad/runtime/graph/hsa.py +171 -0
  40. tinygrad/runtime/graph/metal.py +75 -0
  41. tinygrad/runtime/ops_amd.py +564 -0
  42. tinygrad/runtime/ops_clang.py +24 -77
  43. tinygrad/runtime/ops_cuda.py +175 -89
  44. tinygrad/runtime/ops_disk.py +56 -33
  45. tinygrad/runtime/ops_gpu.py +92 -95
  46. tinygrad/runtime/ops_hsa.py +278 -0
  47. tinygrad/runtime/ops_llvm.py +39 -60
  48. tinygrad/runtime/ops_metal.py +92 -74
  49. tinygrad/runtime/ops_npy.py +9 -0
  50. tinygrad/runtime/ops_nv.py +630 -0
  51. tinygrad/runtime/ops_python.py +204 -0
  52. tinygrad/shape/shapetracker.py +86 -254
  53. tinygrad/shape/symbolic.py +166 -141
  54. tinygrad/shape/view.py +296 -0
  55. tinygrad/tensor.py +2619 -448
  56. {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
  57. tinygrad-0.9.0.dist-info/METADATA +227 -0
  58. tinygrad-0.9.0.dist-info/RECORD +60 -0
  59. {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
  60. tinygrad/codegen/assembly.py +0 -190
  61. tinygrad/codegen/optimizer.py +0 -379
  62. tinygrad/codegen/search.py +0 -72
  63. tinygrad/graph.py +0 -83
  64. tinygrad/jit.py +0 -57
  65. tinygrad/nn/image.py +0 -100
  66. tinygrad/renderer/assembly_arm64.py +0 -169
  67. tinygrad/renderer/assembly_ptx.py +0 -98
  68. tinygrad/renderer/wgsl.py +0 -53
  69. tinygrad/runtime/lib.py +0 -113
  70. tinygrad/runtime/ops_cpu.py +0 -51
  71. tinygrad/runtime/ops_hip.py +0 -82
  72. tinygrad/runtime/ops_shm.py +0 -29
  73. tinygrad/runtime/ops_torch.py +0 -30
  74. tinygrad/runtime/ops_webgpu.py +0 -45
  75. tinygrad-0.7.0.dist-info/METADATA +0 -212
  76. tinygrad-0.7.0.dist-info/RECORD +0 -40
  77. {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
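The list above tells the story of the release: codegen grows a UOp-level IR (uops.py), scheduling, JIT, and kernel search move into the new tinygrad/engine/ package, runtimes gain direct-to-driver backends (ops_amd.py, ops_nv.py) with ctypes bindings under runtime/autogen/, and the interpreted numpy/torch backends listed at the bottom are deleted. The six lines added to tinygrad/__init__.py suggest the user-facing names are re-exported at the top level; a minimal sketch against 0.9.0, assuming those re-exports:

from tinygrad import Tensor, TinyJit

@TinyJit
def step(x: Tensor) -> Tensor:
  return (x @ x.T).relu().realize()  # realize() forces the scheduled kernels to run

print(step(Tensor.rand(4, 4)).numpy())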
tinygrad/nn/image.py DELETED
@@ -1,100 +0,0 @@
- import numpy as np
- from tinygrad.helpers import prod, IMAGE, ImageDType, getenv, dtypes
- from tinygrad.lazy import get_single_root
-
- FLOAT16 = getenv("FLOAT16", 0)
- base_image_type = (100, 2, "imageh", np.float16) if FLOAT16 else (100, 4, "imagef", np.float32)
-
- def image_dot(self, w):
-   # NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
-   n1, n2 = len(self.shape), len(w.shape)
-   assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D"
-   assert self.shape[-1] == w.shape[-min(n2, 2)], f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({self.shape[-1]} != {w.shape[-min(n2, 2)]})"
-   bs, groups = prod(self.shape[0:-2]), prod(w.shape[0:-2])
-   cin, cout = w.shape[-2], w.shape[-1]
-   out_shape_t = self.shape[0:-2] + (cout,-1)
-   if len(self.shape) > 1:
-     order = tuple(range(len(self.shape)-2)) + (len(self.shape)-1, len(self.shape)-2)
-   else:
-     order, out_shape_t = (0,), (cout, )
-   worder = tuple(range(len(w.shape)-2)) + (len(w.shape)-1, len(w.shape)-2)
-
-   # NOTE: with NHWC we can remove the transposes
-   # bs x groups*cin x H x W
-   cx = self.permute(order=order).reshape(shape=(bs//groups, groups*cin, -1, 1))
-   # groups*cout x cin x H, W
-   cw = w.permute(order=worder).reshape(shape=(groups*cout, cin, 1, 1))
-   return cx.conv2d(cw, groups=groups).reshape(shape=out_shape_t).permute(order=order)
-
- def image_conv2d(self, weight, bias=None, groups=1, stride=1, dilation=1, padding=0):
-   (bs,_,iy,ix), (cout,cin,H,W) = self.shape, weight.shape
-   rcout = cout//groups
-   x, w = self, weight.reshape(groups, rcout, cin, H, W)
-
-   # hack for non multiples of 4 on cin
-   if cin % 4 != 0 and not (cin == 1 and groups%4 == 0):
-     x = x.reshape(bs, groups, cin, iy, ix)  # do this always?
-     added_input_channels = 4 - (cin % 4)
-     w = w.pad(tuple((0, added_input_channels) if i == 2 else (0, 0) for i in range(len(w.shape))))
-     x = x.pad(tuple((0, added_input_channels) if i == 2 else (0, 0) for i in range(len(x.shape))))
-     cin = cin + added_input_channels
-     x = x.reshape(bs, groups*cin, iy, ix)
-
-   # hack for non multiples of 4 on rcout
-   added_output_channels = 0
-   if rcout % 4 != 0 and not (rcout == 1 and groups%4 == 0):
-     added_output_channels = 4 - (rcout % 4)
-     rcout += added_output_channels
-     cout = groups * rcout
-     w = w.slice(tuple((0, rcout) if i == 1 else (0, w.shape[i]) for i in range(len(w.shape))))
-
-   # packed (note: flipping bs and iy would make the auto-padding work)
-   x = x.permute(0,2,3,1).reshape(bs * iy, ix * groups * cin//4, 4)
-   cin_last = iy == 1 and ix == 1
-   if cin == 1: w = w.reshape(cout//4,4,H*W).permute(0,2,1)
-   elif cin_last: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,1,3).reshape(cout//4, H*cin//4*W*4, 4)
-   else: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,3,1).reshape(cout//4, H*cin//4*W*4, 4)
-
-   # contiguous creates the image, and early realize static weights (TODO: test for the static weight)
-   if IMAGE >= 2: x,w = x.cast(ImageDType(*base_image_type, shape=x.shape)), w.cast(ImageDType(*base_image_type, shape=w.shape))
-   x, w = x.contiguous(), w.contiguous()
-   if get_single_root(w.lazydata).realized: w.realize()
-
-   # expand out
-   rcin_hi, rcin_lo = cin//4 if cin >= 4 else 1, 4 if cin >= 4 else 1
-   cout_expand = [groups//4 if cin == 1 else groups, 4 if cin == 1 else 1, rcout//4 if rcout >= 4 else 1, 4 if rcout >= 4 else 1]
-   x = x.reshape(bs, iy, ix, groups, rcin_hi, rcin_lo)
-   if cin_last: w = w.reshape(cout//4, H, rcin_hi, W, 4, rcin_lo)
-   else: w = w.reshape(cout//4, H, rcin_hi, W, rcin_lo, 4).permute(0,1,2,3,5,4)
-
-   # padding
-   padding_ = [padding]*4 if isinstance(padding, int) else (padding if len(padding) == 4 else [padding[1], padding[1], padding[0], padding[0]])
-   x = x.slice((None, (-padding_[2], x.shape[1]+padding_[3]), (-padding_[0], x.shape[2]+padding_[1]), None, None, None))
-
-   # prepare input
-   x = x.permute(0,3,4,5,1,2)._pool((H, W), stride, dilation) # -> (bs, groups, rcin_hi, rcin_lo, oy, ox, H, W)
-   oy, ox = x.shape[4:6]
-   x = x.permute(0,4,5,1,2,3,6,7).reshape(bs, oy, ox, *cout_expand[0:2], 1, 1, rcin_hi, rcin_lo, H, W)
-   x = x.expand(bs, oy, ox, *cout_expand, rcin_hi, rcin_lo, H, W)
-
-   # prepare weights
-   w = w.permute(0,4,2,5,1,3)
-   w = w.reshape((1, 1, 1, *cout_expand, rcin_hi, rcin_lo, H, W)).expand(x.shape)
-
-   # the conv! (+ the bias)
-   ret = (x*w).cast(dtypes.float32).sum((-4, -3, -2, -1))
-
-   # reshape to image and cast back to image
-   ret = ret.reshape(bs*oy, ox*cout//4, 4)
-   if IMAGE >= 2: ret = ret.cast(ImageDType(*base_image_type, shape=ret.shape))
-   if IMAGE >= 3: ret = ret.contiguous()
-
-   # undo hack for non multiples of 4 on C.rcout
-   if added_output_channels != 0:
-     ret = ret.reshape(bs, oy, ox, groups, rcout)[:, :, :, :, :-added_output_channels]
-     rcout -= added_output_channels
-     cout = groups * rcout
-
-   # NCHW output
-   ret = ret.reshape(bs, oy, ox, cout).permute(0,3,1,2)
-   return ret if bias is None else ret.add(bias.reshape(1, -1, 1, 1))
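The NOTE at the top of image_dot is easy to sanity-check outside tinygrad: an (m,k) @ (k,n) matmul is exactly a 1x1 convolution once the operands are laid out as NCHW tensors. A plain numpy sketch of that identity (not tinygrad code):

import numpy as np

m, k, n = 3, 4, 5
a, b = np.random.rand(m, k), np.random.rand(k, n)
x = a.T.reshape(1, k, m, 1)   # (1, cin=k, H=m, W=1): one "pixel" per row of a
w = b.T.reshape(n, k, 1, 1)   # (cout=n, cin=k, 1, 1): one filter per column of b
y = np.einsum('bchw,ocij->bohw', x, w)   # a stride-1 conv with a 1x1 kernel
assert np.allclose(y[0, :, :, 0].T, a @ b)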
tinygrad/renderer/assembly_arm64.py DELETED
@@ -1,169 +0,0 @@
- import struct
- from platform import system
- from typing import Tuple, Dict, List, Optional
- from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps
- from tinygrad.codegen.linearizer import UOps, UOp
- from tinygrad.helpers import dtypes, CI
- from tinygrad.codegen.assembly import uops_to_asmstyle, AssemblyLanguage
-
- def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
- def compute_offsets(total):
-   quotient, remainder = divmod(total, 4096)
-   return [4096]*quotient + [remainder] if remainder else [4096]*quotient
-
- #NOTE: Darwin needs names to start with a "_"
- def get_name(name): return ('_' if system() == 'Darwin' else '') + name
-
- class ARM64Language(AssemblyLanguage): pass
-
- def specialize_to_arm64(fn_nm, asm):
-   var_size = 16
-   prev_uop:Optional[UOps] = None
-   ins = []
-   x_regs = ['x' + str(i) for i in reversed(range(12))]
-   s_regs = ['s' + str(i) for i in reversed(range(3,32)) if i <= 7 or i >= 16]
-   type_to_reg = {dtypes.double: "d", dtypes.half: 'h', dtypes.float32: 's', dtypes.bool: 'w', dtypes.int8:'w', dtypes.int32: 'w', dtypes.int64: 'x', dtypes.uint8:'w', dtypes.uint32: 'w', dtypes.uint64: 'x'}
-   alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", BinaryOps.DIV: "div", BinaryOps.MAX: "max",
-          BinaryOps.MOD: "", BinaryOps.CMPLT: "subs",
-          UnaryOps.SIN:'bl ' + get_name('sinf'), UnaryOps.LOG2: 'bl ' + get_name("log2f"), UnaryOps.EXP2: 'bl ' + get_name("exp2f"), UnaryOps.SQRT: 'bl ' + get_name("sqrtf"),
-          TernaryOps.MULACC: "madd", TernaryOps.WHERE: "fcsel"}
-
-   def mov_imm(value, reg):
-     # Manually move value into reg if value can't fit
-     if value.__class__ is not float and abs(value) > abs(65535):
-       ins.append(f"movz w15, #{value & 0xffff}")
-       ins.append(f"movk w15, #{(value >> 16) & 0xffff}, lsl #16")
-       ins.append(f"sxtw {reg}, w15")
-     elif reg[0] == 's':
-       ins.append(f"movz x15, 0x{float_to_hex(value)[4:]}")
-       ins.append(f"movk x15, 0x{float_to_hex(value)[:4]}, lsl #16")
-       ins.append("str x15, [sp, 16]")
-       ins.append(f"ldr {reg}, [sp, 16]")
-     else:
-       ins.append(f"mov {reg}, #{value}")
-
-   # Get variables intervals
-   live_range:Dict[str, List[int]] = {}
-   for i, (uop, out, vin, arg) in enumerate(asm):
-     for var in ([v for v in [out] + vin if v is not None and v.__class__ is not int]):
-       live_range[var.nm] = [i,i] if var.nm not in live_range else [live_range[var.nm][0], i]
-
-   mem_vars:Dict[str, int] = {}
-   rtor:Dict[str, str] = {}
-   def allocate_regs(mvars):
-     nonlocal var_size
-     for v in [v for v in mvars if v is not None and v.__class__ is not int and v.nm not in rtor]:
-       available_regs = s_regs if dtypes.is_float(v[1]) else x_regs
-       #NOTE: Very simple spill, everything that don't fit in regs goes to mem
-       if not available_regs:
-         # ARM needs the stack 16-byte aligned
-         var_size += 16
-         available_regs.append('s0' if dtypes.is_float(out[1]) else 'x12')
-         mem_vars[v.nm] = var_size
-       rtor[v.nm] = available_regs.pop()
-
-   temp_floats = ['s0', 's1', 's2']
-   temp_ints = ['x12', 'x13', 'x16']
-   for i, (uop, out, vin, arg) in enumerate(asm):
-     # Clear regs out of interval
-     for var, reg in list(rtor.items()):
-       available_regs = s_regs if reg[0] == 's' else x_regs
-       if var[1] not in 'B' and var not in mem_vars and i > live_range[var][1]:
-         available_regs.append(rtor.pop(var))
-     # Assign a registers to the variables using live ranges.
-     allocate_regs([out] + vin)
-     # Assign temp regs to vin and load them before direct use
-     for i, v in enumerate([v for v in vin if v.__class__ is not int and v.nm in mem_vars]):
-       rtor[v.nm] = temp_floats[i] if dtypes.is_float(v[1]) else temp_ints[i]
-       # ARM64 addressing constraints https://devblogs.microsoft.com/oldnewthing/20220728-00/?p=106912
-       ins.append(f"mov x15, {mem_vars[v.nm]}")
-       ins.append(f"ldr {rtor[v.nm]}, [sp, x15]")
-
-     if uop == UOps.SPECIAL:
-       if arg.startswith('data'):
-         # data 8 to n into the stack
-         if int(arg[4:]) >= 8:
-           ins.append(f"ldr x15, [x17, #{(int(arg[4:]) - 8) * 8}]")
-           ins.append(f"mov {rtor[out.nm]}, x15")
-       else:
-         ins.append(f"mov {rtor[out.nm]}, #0")
-         ins.append(f"loop_{arg}:")
-     elif uop == UOps.CAST:
-       if arg == BinaryOps.CMPLT:
-         mov_imm(0.0, 's0')
-         mov_imm(1.0, 's1')
-         ins.append(f"fcsel {rtor[out.nm]}, s1, s0, lt")
-       else:
-         ins.append(f"sxtw {rtor[out.nm]}, w{rtor[vin[0].nm][1:]}")
-     elif uop == UOps.ALU:
-       if len(vin)==2 and vin[1].__class__ is int: mov_imm(vin[1], 'x15')
-       if arg == BinaryOps.MUL and out.dtype == dtypes.bool:
-         ins.append(f"ands {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}")
-       elif arg == TernaryOps.WHERE:
-         ins.append(f"fcmp {rtor[vin[0].nm]}, #0.0")
-         ins.append(f"{alu[arg]} {rtor[out.nm]}, {rtor[vin[1].nm]}, {rtor[vin[2].nm]}, ne")
-       elif arg in [UnaryOps.LOG2, UnaryOps.SIN, UnaryOps.EXP2, UnaryOps.SQRT]:
-         #NOTE: Not a real instruction, use to emulate a ext call in unicorn
-         if CI: ins.append(f"{alu[arg]} {rtor[out.nm]} {rtor[vin[0].nm]}")
-         else:
-           save_regs = [k for k in rtor.keys() if k != out.nm and k not in mem_vars]
-           ins.append(f"sub sp, sp, #{(len(save_regs))*16}")
-           # Save the registers before they are cleared by func call
-           for i,k in enumerate(save_regs,1):
-             ins.append(f"str {rtor[k]}, [sp, #{16*i}]")
-           ins.append("stp x29, x30, [sp, #0]!")
-           ins.append("mov x29, sp")
-           ins.append(f"fmov s0, {rtor[vin[0].nm]}")
-           ins.append(alu[arg])
-           ins.append(f"fmov {rtor[out.nm]}, s0")
-           ins.append("mov sp, x29")
-           ins.append("ldp x29, x30, [sp], #0")
-           for i,k in enumerate(save_regs,1):
-             ins.append(f"ldr {rtor[k]}, [sp, #{16*i}]")
-           ins.append(f"add sp, sp, #{len(save_regs)*16}")
-       elif arg == BinaryOps.CMPLT:
-         ins.append(f"{alu[arg]} {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}" if not dtypes.is_float(vin[0][1]) else f"fcmp {rtor[vin[0].nm]}, {rtor[vin[1].nm]}")
-       elif arg == BinaryOps.MOD:
-         ins.append(f"udiv x14, {rtor[vin[0].nm]}, x15")
-         ins.append(f"msub {rtor[out.nm]}, x14, x15, {rtor[vin[0].nm]}")
-       else:
-         ins.append(f"{'f' if dtypes.is_float(vin[0][1]) else 's' if arg == BinaryOps.DIV else ''}{alu[arg]} {', '.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}")
-     elif uop == UOps.LOAD:
-       if arg.__class__ in (int, float):
-         mov_imm(arg, rtor[out.nm])
-       else:
-         #NOTE: if need casting load var in s/h0 or x/w12 temp regs
-         reg_in = type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[out.nm]
-         mov_imm(arg[0], "x15")
-         ins.append(f"add x15, {rtor[vin[0].nm]}, x15")
-         ins.append(f"ldr{'sb' if arg[2] is not None and arg[2] in (dtypes.int8, dtypes.uint8, dtypes.bool) else ''} {reg_in}, [x15]")
-         if arg[2] is not None: ins.append(f"{'fcvt' if arg[2] in [dtypes.half, dtypes.double] else 'scvtf'} {rtor[out.nm]}, {reg_in}")
-     elif uop == UOps.STORE:
-       #NOTE: if need casting load var in s/h0 or x/w12 temp regs
-       reg_out = (type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[vin[1].nm])
-       if arg[2] is not None: ins.append(f"fcvt{'zs' if arg[2] not in [dtypes.half, dtypes.double] else '' } {reg_out}, {rtor[vin[1].nm]}")
-       ins.append(f"mov x15, #{arg[0]}")
-       ins.append(f"str {reg_out}, [{rtor[vin[0].nm]}, x15, lsl #0]")
-     elif uop == UOps.COND_BRANCH:
-       #TODO: this is a hack it shouldn't always be a cmp before a cond branch?
-       if prev_uop == UOps.LOAD:
-         ins.append(f"cmp {rtor[vin[0].nm]}, #0")
-       ins.append(f"b.{'lt' if arg[1] else 'ge'} {arg[0][1:]}")
-     elif uop == UOps.LABEL:
-       ins.append(f"{arg[1:]}:")
-     elif uop == UOps.ENDLOOP:
-       mov_imm(arg[0], "x15")
-       ins.append(f"add {rtor[vin[0].nm]}, {rtor[vin[0].nm]}, #1")
-       ins.append(f"cmp {rtor[vin[0].nm]}, x15")
-       ins.append(f"b.lt loop_{arg[1]}")
-     prev_uop = uop
-     # store regs into memory if needed
-     if out is not None and out.nm in mem_vars:
-       ins.append(f"mov x15, {mem_vars[out.nm]}")
-       ins.append(f"str {rtor[out.nm]}, [sp, x15]")
-   return "\n".join([f"//varsize {var_size}",".arch armv8-a",".text", f".global {get_name(fn_nm)}",".p2align 2", f"{get_name(fn_nm)}:", "mov x17, sp"] + [f"sub sp, sp, #{offset}" for offset in compute_offsets(var_size)]+ ins + [f"add sp, sp, #{offset}" for offset in compute_offsets(var_size)] +["ret", "\n"])
-
- def uops_to_arm64_asm(fn_nm:str, uops:List[UOp]) -> Tuple[str, List[int], List[int], bool]:
-   lang = ARM64Language()
-   global_size, local_size = uops_to_asmstyle(lang, fn_nm, uops)
-   return specialize_to_arm64(fn_nm, lang.ins), global_size[::-1], local_size[::-1], True
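Two helpers at the top of this file reward a closer look. float_to_hex produces the big-endian IEEE-754 hex string that mov_imm splits into movz/movk halves, and compute_offsets breaks a stack adjustment into chunks small enough for ARM64's add/sub immediate encoding. A standalone sketch of both:

import struct

def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f", x)[::-1])

def compute_offsets(total):
  quotient, remainder = divmod(total, 4096)
  return [4096]*quotient + [remainder] if remainder else [4096]*quotient

assert float_to_hex(1.0) == "3F800000"               # movk takes "3F80", movz takes "0000"
assert compute_offsets(10000) == [4096, 4096, 1808]  # three encodable sub sp, sp, #imm steps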
tinygrad/renderer/assembly_ptx.py DELETED
@@ -1,98 +0,0 @@
- from typing import List
- import struct
- from tinygrad.codegen.assembly import uops_to_asmstyle, AssemblyLanguage
- from tinygrad.codegen.linearizer import UOps, UOp
- from tinygrad.helpers import dtypes
- from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps
- from tinygrad.runtime.ops_cuda import arch
-
- dtype_to_nvtype = {dtypes.float32: "f32", dtypes.float16: "f16", dtypes.int64: "s64", dtypes.int32: "s32", dtypes.int8: "s8", dtypes.bool: "pred", dtypes.uint64: "u64", dtypes.uint32: "u32", dtypes.uint16: "u16", dtypes.uint8: "u8", "bits16": "b16", dtypes.float64: "f64"}
- def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
-
- def ptx_needs_cast(dest_dtype, src_dtype): return dtypes.is_float(dest_dtype) and dtypes.is_int(src_dtype) or dtypes.is_int(dest_dtype) and dtypes.is_float(src_dtype) or (dtypes.is_float(src_dtype) and dtypes.is_float(dest_dtype) and dest_dtype.itemsize != src_dtype.itemsize)
-
- def render_cast(ins, inp, out):
-   if inp.dtype == dtypes.bool and (dtypes.is_float(out.dtype) or dtypes.is_int(out.dtype)):
-     ins.append(f"selp.{dtype_to_nvtype[out.dtype]} {out}, {'0f3F800000, 0f00000000' if dtypes.is_float(out.dtype) else '1, 0'}, {inp};")
-   elif out.dtype == dtypes.bool:
-     ins.append(f"setp.ne.{dtype_to_nvtype[inp.dtype]} {out}, {'0f00000000' if dtypes.is_float(inp.dtype) else '0'}, {inp};")
-   else:
-     round_mod = ".rzi" if dtypes.is_int(out.dtype) and dtypes.is_float(inp.dtype) else '.rz' if dtypes.is_float(out.dtype) and (dtypes.is_int(inp.dtype) or dtypes.is_float(inp.dtype) and inp.dtype.itemsize > out.dtype.itemsize) else ''
-     ins.append(f"cvt{round_mod}.{dtype_to_nvtype[out.dtype]}.{dtype_to_nvtype[inp.dtype]} {out}, {inp};")
-
- # https://docs.nvidia.com/cuda/parallel-thread-execution/#
-
- class PTXLanguage(AssemblyLanguage):
-   supports_constant_folding: bool = True
-
- def specialize_to_ptx(lang, function_name):
-   param_cnt = 0
-   ins = []
-   alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", BinaryOps.DIV: "div", BinaryOps.MAX: "max",
-          BinaryOps.MOD: "rem", BinaryOps.CMPLT: "setp.lt", UnaryOps.SQRT: "sqrt.approx",
-          UnaryOps.NOOP: "mov", UnaryOps.SIN: "sin.approx", UnaryOps.LOG2: "lg2.approx", UnaryOps.EXP2: "ex2.approx.ftz",
-          TernaryOps.MULACC: "fma.rn", TernaryOps.WHERE: "selp"}
-   for uop, out, vin, arg in lang.ins:
-     if uop == UOps.ENDLOOP:
-       ins.append("bar.sync 0;")
-     elif uop == UOps.DEFINE_LOCAL:
-       ins.append(f".shared .align 4 .b8 {arg[0]}[{arg[1]*4}];")
-     elif uop == UOps.SPECIAL:
-       if arg.startswith('data'):
-         param_cnt += 1
-         ins.append(f"ld.param.u64 {out}, [{arg}];")
-         # TODO: we sometimes want this to be local, nvcc converts to global most of the time, not sure when we would need to?
-         # ins.append(f"cvta.to.global.u64 {out}, {out};")
-       elif arg.startswith('gid'):
-         ins.append(f"mov.u32 {out}, %ctaid.{'xyz'[int(arg[3:])]};")
-       elif arg.startswith('lid'):
-         ins.append(f"mov.u32 {out}, %tid.{'xyz'[int(arg[3:])]};")
-     elif uop == UOps.ALU:
-       if arg == BinaryOps.MUL and out.dtype == dtypes.bool:
-         ins.append(f"and.pred {out}, {', '.join(str(x) for x in vin)};")
-       else:
-         otype = vin[0].dtype if arg in [BinaryOps.CMPLT] else out.dtype
-         if arg == TernaryOps.WHERE:
-           reg = lang.newreg((vin[0], 'bool'), dtypes.bool)
-           ins.append(f"setp.ne.{dtype_to_nvtype[vin[0].dtype]} {reg}, {'0f00000000' if dtypes.is_float(vin[0].dtype) else '0'}, {vin[0]};")
-           vin = vin[1:] + [reg]
-         ins.append(f"{alu[arg]}{'.lo' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 else ''}{'.rn' if arg == BinaryOps.DIV and out.dtype == dtypes.float32 else ''}.{dtype_to_nvtype[otype]} {out}, {', '.join(str(x) for x in vin)};")
-     elif uop == UOps.LOAD:
-       if arg.__class__ in (int, float):
-         ins.append(f"mov.{dtype_to_nvtype[out.dtype]} {out}, {'0f'+float_to_hex(arg) if dtypes.is_float(out.dtype) else int(arg)};")
-       elif arg[2] is not None and (arg[2] == dtypes.bool or arg[2] != out.dtype):
-         dt = ('u16', dtypes.uint16) if arg[2] == dtypes.bool == out.dtype else ('u8', dtypes.uint8) if arg[2] == dtypes.bool else ('b16', dtypes.float16) if arg[2] == dtypes.half else (dtype_to_nvtype[arg[2]], arg[2])
-         reg = lang.newreg((out, dt[0]), dtype=dt[1])
-         ins.append(f"ld.{arg[1]}.{dt[0]} {reg}, [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}];")
-         render_cast(ins, reg, out)
-       else:
-         ins.append(f"ld.{arg[1]}.{dtype_to_nvtype[dtypes.float if arg[2] is None else arg[2]]} {out}, [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}];")
-     elif uop == UOps.STORE:
-       if ptx_needs_cast(dtypes.float if arg[2] is None else arg[2], vin[1].dtype) or arg[2] == dtypes.bool:
-         if arg[2] == dtypes.bool != vin[1].dtype:
-           prereg = lang.newreg((vin[1],'bool'), dtype=dtypes.bool)
-           render_cast(ins, vin[1], prereg)
-         else: prereg = vin[1]
-         reg = lang.newreg((prereg, dtypes.uint16 if arg[2] == dtypes.bool else arg[2]), dtype=dtypes.uint16 if arg[2] == dtypes.bool else dtypes.float if arg[2] is None else arg[2])
-         render_cast(ins, prereg, reg)
-         ins.append(f"st.{arg[1]}.{dtype_to_nvtype['bits16' if arg[2] == dtypes.float16 else dtypes.uint8 if arg[2] == dtypes.bool else dtypes.float if arg[2] is None else arg[2]]} [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}], {reg};")
-       else:
-         ins.append(f"st.{arg[1]}.{dtype_to_nvtype[dtypes.float if arg[2] is None else arg[2]]} [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}], {vin[1]};")
-     elif uop == UOps.CAST:
-       render_cast(ins, vin[0], out)
-     elif uop == UOps.LABEL:
-       ins.append(f"{arg}:")
-     elif uop == UOps.COND_BRANCH:
-       ins.append(f"@{'!' if not arg[1] else ''}{vin[0]} bra {arg[0]};")
-
-   ins_prefix = [".version 7.8", ".target " + arch(), ".address_size 64",
-                 f".visible .entry {function_name}({', '.join(f'.param .u64 data{i}' for i in range(param_cnt))}) {{"]
-   for arg in [(dtype, lang.type_to_letter(dtype), c) for dtype,c in lang.cnts.items()]: ins_prefix.append(f".reg .{dtype_to_nvtype[arg[0][0]]} %{arg[1]}<{arg[2]}>;",)
-   ins = ins_prefix + ins
-   ins += ["ret;", "}"]
-   return '\n'.join(ins)
-
- def uops_to_ptx_asm(function_name:str, uops:List[UOp]):
-   lang = PTXLanguage()
-   global_size, local_size = uops_to_asmstyle(lang, function_name, uops)
-   return specialize_to_ptx(lang, function_name), global_size[::-1], local_size[::-1], True
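The string-building at the end of specialize_to_ptx is easier to picture with concrete values. A sketch of the prologue it assembles for a hypothetical two-buffer kernel named E_4 on a hypothetical sm_80 target (param_cnt is counted up from the 'data' SPECIAL uops above; both names here are illustrative, not tinygrad output):

param_cnt, function_name = 2, "E_4"   # hypothetical values for illustration
ins_prefix = [".version 7.8", ".target sm_80", ".address_size 64",
              f".visible .entry {function_name}({', '.join(f'.param .u64 data{i}' for i in range(param_cnt))}) {{"]
print("\n".join(ins_prefix))
# .version 7.8
# .target sm_80
# .address_size 64
# .visible .entry E_4(.param .u64 data0, .param .u64 data1) {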
tinygrad/renderer/wgsl.py DELETED
@@ -1,53 +0,0 @@
- from tinygrad.renderer.cstyle import render_cl
- from tinygrad.helpers import dtypes, DType
- from tinygrad.renderer.cstyle import CStyleLanguage
- from typing import List, Union
- from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
- import math
- from typing import Tuple
-
- type_map = {dtypes.float: "f32", dtypes.half: "f16", dtypes.int32: "i32", dtypes.uint32: "u32", dtypes.bool: "bool"}
- class WGSLLanguage(CStyleLanguage):
-   gid = [f"i32(gindex.{'xyz'[x]})" for x in range(3)]
-   lid = [f"i32(lindex.{'xyz'[x]})" for x in range(3)]
-   size_prefix = "let"
-   barrier="workgroupBarrier();"
-   generic_var_prefix = "var "
-   external_local_bufs = True
-   code_for_op = {
-     UnaryOps.EXP2: lambda x: f"exp2({x})", UnaryOps.LOG2: lambda x: f"log2({x})", UnaryOps.SIN: lambda x: f"sin({x})", UnaryOps.SQRT: lambda x: f"sqrt({x})",
-     BinaryOps.ADD: lambda x,y: f"({x}+{y})", BinaryOps.SUB: lambda x,y: f"({x}-{y})", BinaryOps.MUL: lambda x,y: f"({x}*{y})", BinaryOps.DIV: lambda x,y: f"({x}/{y})",
-     BinaryOps.MAX: lambda x,y: f"max({x},{y})", BinaryOps.CMPLT: lambda x,y: f"f32({x}<{y})",
-     TernaryOps.MULACC: lambda x,y,z: f"fma({x},{y},{z})", TernaryOps.WHERE: lambda a,b,c: f"select({c},{b},{a}!=0.)"
-   }
-
-   def render_local(self, name: str, size: int):
-     return f"var<workgroup> {name}: array<f32,{size}>;"
-
-   def render_const(self, x:Union[float,int], var_dtype) -> str:
-     if math.isnan(x): val = "nan()"
-     elif math.isinf(x): val = ("-" if x < 0 else "") + "0x1.fffffep+127f"
-     else: val = f"{x}" + ("" if dtypes.is_int(var_dtype) else "f")
-     return self.render_cast([val]*var_dtype.sz, var_dtype) if var_dtype.sz > 1 else val
-
-   def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,DType]], global_size:List[int], local_size:List[int], prekernel:List[str]) -> Tuple[str, List[int], List[int]]:
-     local_size = local_size[::-1] if local_size else [1]
-     bind_it = iter(range(len(bufs)))
-     prg = "fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }\n"
-     prg += "\n".join(prekernel+[f"@group(0) @binding({next(bind_it)}) var<storage,read_write> {name}: array<{type_map[dtype]}>;" for name,dtype in bufs])
-     prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3<u32>, @builtin(local_invocation_id) lindex: vec3<u32>) {{\n" + "\n".join(kernel) + "\n}"
-     return prg, global_size[::-1] if global_size else [1], local_size
-
-   def render_for(self, expr:str, _min:int, _max:Union[int,str]) -> str:
-     return f"for(var {expr} = {_min}; {expr} <= {_max}; {expr}++) {{"
-
-   def render_conditional(self, cond:str, x:str, y:str) -> str:
-     return f"select(f32({y}), {x}, bool({cond}))"
-
-   def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str:
-     return f"f32({super().render_load(output_dtype, buf_name, buf_dtype, idx, local)})"
-
-   def render_store(self, buf_name:str, buf_dtype:DType, var_name:str, var_dtype:DType, idx, local=False) -> str:
-     if buf_dtype != var_dtype:
-       var_name = f"{type_map[buf_dtype]}({var_name})"
-     return f"{buf_name}[{idx.render(render_cl)}] = {var_name};"
tinygrad/runtime/lib.py DELETED
@@ -1,113 +0,0 @@
1
- import ctypes
2
- import numpy as np
3
- from collections import defaultdict, deque
4
- from typing import TypeVar, Type, Any, Dict, Deque, Tuple
5
- from tinygrad.helpers import DType, dtypes, prod, GlobalCounters, ImageDType
6
-
7
- _T = TypeVar("_T")
8
- class RawBuffer: # pylint: disable=abstract-method
9
- def __init__(self, size:int, dtype:DType, buf:Any=None, allocator:Any=None, **kwargs):
10
- self.size: int = size
11
- self.dtype: DType = dtype
12
- self._buf = buf if buf is not None else (allocator.alloc(size, dtype, **kwargs) if allocator else None) # If buf is provided, use it. Otherwise try to allocate from the allocator.
13
- self._memsz: int = size*dtype.itemsize
14
- self._allocator = allocator
15
- GlobalCounters.mem_used += self._memsz
16
- def __del__(self): # NOTE: if it fails on init (bad dtype), it won't have a _memsz
17
- if hasattr(self, '_memsz'): GlobalCounters.mem_used -= self._memsz
18
- if hasattr(self, '_allocator') and self._allocator: self._allocator.free(self._buf)
19
- def __repr__(self): return f"buffer<{self.size}, {self.dtype}>"
20
- @property
21
- def key(self): return (self.size, self.dtype)
22
-
23
- # NOTE: this interface allows for 0 copy
24
- @classmethod
25
- def fromCPU(cls:Type[_T], x:np.ndarray) -> _T: raise NotImplementedError("must be implemented")
26
- def toCPU(self) -> np.ndarray: raise NotImplementedError("must be implemented")
27
-
28
- class RawConst(RawBuffer): # pylint: disable=abstract-method
29
- def __repr__(self): return f"const<{self._buf}, {self.dtype}>"
30
- @property
31
- def key(self): return (str(self._buf), self.dtype)
32
-
33
- def buf_is_kernel_arg(x) -> bool:
34
- return x.realized is not None and x.realized.__class__ is not RawConst
35
-
36
- # --teenygrad--
37
-
38
- class RawBufferCopyIn(RawBuffer):
39
- def _copyin(self, x:np.ndarray) -> None: raise NotImplementedError("must be implemented")
40
-
41
- @classmethod
42
- def fromCPU(cls, x:np.ndarray, **kwargs):
43
- ret = cls(prod(x.shape), dtypes.from_np(x.dtype), **kwargs)
44
- ret._copyin(x)
45
- return ret
46
-
47
- class RawBufferMapped(RawBufferCopyIn):
48
- def _buffer(self) -> memoryview: raise NotImplementedError("must be implemented")
49
- # NOTE: this metadata prevents the backing buffer from being freed. hack can be removed with PEP688
50
- def toCPU(self) -> np.ndarray: return np.frombuffer(self._buffer(), dtype=np.dtype(self.dtype.np, metadata={"backing": self})) # type: ignore
51
- def _copyin(self, x:np.ndarray) -> None: np.copyto(self.toCPU(), x.reshape(-1))
52
-
53
- # this one is simple enough that i moved it out of the runtimes
54
- class RawMallocBuffer(RawBufferMapped):
55
- def __init__(self, size, dtype: DType): super().__init__(size, dtype, ({dtypes.float64:ctypes.c_double, dtypes.float32: ctypes.c_float, dtypes.float16: ctypes.c_int16, dtypes.bfloat16: ctypes.c_int16, dtypes.int8: ctypes.c_int8, dtypes.uint8: ctypes.c_uint8, dtypes.bool: ctypes.c_uint8, dtypes.int32: ctypes.c_int32, dtypes.uint32: ctypes.c_uint32, dtypes.int64: ctypes.c_int64, dtypes.uint64: ctypes.c_uint64}[dtype] * size)())
56
- def _buffer(self): return memoryview(self._buf)
57
-
58
- class RawBufferCopyInOut(RawBufferCopyIn):
59
- def _copyout(self, x:np.ndarray) -> None: raise NotImplementedError("must be implemented")
60
-
61
- def toCPU(self) -> np.ndarray:
62
- x: np.ndarray = np.empty(self.size, dtype=self.dtype.np)
63
- self._copyout(x)
64
- return x
65
-
66
- class RawBufferTransfer(RawBuffer):
67
- def _transfer(self, x) -> None: raise NotImplementedError("must be implemented")
68
-
69
- @classmethod
70
- def transfer(cls, x, shape, dtype, **kwargs):
71
- ret = cls(prod(shape), dtype, **kwargs)
72
- ret._transfer(x)
73
- return ret
74
-
75
- class LRUAllocator:
76
- def __init__(self, dev_memsz=(4<<30)):
77
- self.epoch = 0
78
- self.free_space: Dict[Any, int] = defaultdict(lambda: dev_memsz)
79
- self.buffer_info: Dict[Any, Tuple[int, DType, str]] = dict()
80
- self.cached_buffers: Dict[Tuple[int, ...], Deque[Tuple[Any, int]]] = defaultdict(deque) # Cached buffer storage, splitted by type and size, newest first.
81
- self.aging_order: Dict[Any, Deque[Tuple[Tuple[int, ...], int]]] = defaultdict(deque) # Keys of cached_buffers, ordered from oldest to newest updates.
82
- def __del__(self):
83
- for v in self.cached_buffers.values():
84
- for buf, _ in v: self._free_buffer(buf)
85
- def _cache_reuse_buffer(self, rawbufs: Deque[Tuple[Any, int]]): # The newest cached buffer is reused.
86
- GlobalCounters.mem_cached -= self._underlying_buf_memsz(rawbufs[0][0])
87
- return rawbufs.popleft()[0]
88
- def _alloc_buffer(self, size, dtype, device, **kwargs):
89
- self.free_space[device] -= size*dtype.itemsize
90
- while len(self.aging_order[device]) and self.free_space[device] < 0: # When OOM removing lru buffers.
91
- bucket, epoch = self.aging_order[device].popleft()
92
- if self.cached_buffers[bucket] and self.cached_buffers[bucket][-1][1] == epoch: self._free_buffer(self.cached_buffers[bucket].pop()[0]) # Free cached buffer if it is still in cache.
93
- newbuf = self._do_alloc(size, dtype, device, **kwargs)
94
- self.buffer_info[newbuf] = (size, dtype, device)
95
- return newbuf
96
- def _free_buffer(self, buf_to_free):
97
- self.free_space[self.buffer_info[buf_to_free][2]] += self._underlying_buf_memsz(buf_to_free)
98
- GlobalCounters.mem_cached -= self._underlying_buf_memsz(buf_to_free)
99
- self.buffer_info.pop(buf_to_free)
100
- self._do_free(buf_to_free)
101
- def alloc(self, size, dtype, device='0', **kwargs):
102
- rawbufs = self.cached_buffers.get(self._cached_bufkey(size, dtype, device), None)
103
- return self._cache_reuse_buffer(rawbufs) if rawbufs else self._alloc_buffer(size, dtype, device, **kwargs)
104
- def free(self, buf): # free() just caches buffer. It might be freed later when OOM during allocation.
105
- self.epoch += 1
106
- size, dtype, device = self.buffer_info[buf]
107
- self.cached_buffers[self._cached_bufkey(size, dtype, device)].appendleft((buf, self.epoch))
108
- self.aging_order[device].append((self._cached_bufkey(size, dtype, device), self.epoch))
109
- GlobalCounters.mem_cached += self._underlying_buf_memsz(buf)
110
- def _underlying_buf_memsz(self, buf): return self.buffer_info[buf][0] * self.buffer_info[buf][1].itemsize
111
- def _cached_bufkey(self, size, dtype, device) -> Tuple[int, ...]: return (device, size, dtype, dtype.shape) if isinstance(dtype, ImageDType) else (device, size, dtype) # Provides a key for reusing device buffers with identical keys.
112
- def _do_alloc(self, size, dtype, device, **kwargs): raise NotImplementedError("must be implemented")
113
- def _do_free(self, buf): pass
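LRUAllocator asks backends for only two hooks, _do_alloc and _do_free; alloc/free handle caching and eviction. A minimal sketch of how a runtime plugged in, run against 0.7.0 (where this file still exists), using a hypothetical ctypes-backed allocator; the real backends wrapped CUDA, OpenCL, and friends:

import ctypes
from tinygrad.runtime.lib import LRUAllocator  # 0.7.0 import path; deleted in 0.9.0
from tinygrad.helpers import dtypes

class MallocAllocator(LRUAllocator):           # hypothetical example backend
  def _do_alloc(self, size, dtype, device, **kwargs): return (ctypes.c_uint8 * (size*dtype.itemsize))()
  def _do_free(self, buf): pass                # ctypes buffers are reclaimed by the GC

allocator = MallocAllocator()
a = allocator.alloc(1024, dtypes.float32)      # miss: calls _do_alloc
allocator.free(a)                              # not freed, parked in the cache
b = allocator.alloc(1024, dtypes.float32)      # same (device, size, dtype) key: cache hit
assert a is b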
tinygrad/runtime/ops_cpu.py DELETED
@@ -1,51 +0,0 @@
- import numpy as np
- import operator
- from typing import Callable, Dict, Tuple, Optional
- from tinygrad.helpers import dtypes, DType
- from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, ReduceOps, TernaryOps, Op, Interpreted
- from tinygrad.runtime.lib import RawBuffer
-
- def shape_to_axis(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]) -> Tuple[int, ...]:
-   assert len(old_shape) == len(new_shape), "reduce shapes must have same dimensions"
-   return tuple(i for i,(a,b) in enumerate(zip(old_shape, new_shape)) if a != b)
-
- base_fxn_for_op: Dict[Op, Callable] = {
-   BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.MUL: operator.mul, BinaryOps.DIV: operator.truediv,
-   ReduceOps.SUM: lambda x, new_shape: x.sum(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
-   ReduceOps.MAX: lambda x, new_shape: (x.amax if hasattr(x, 'amax') else x.max)(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
-   MovementOps.RESHAPE: lambda x, arg: x.reshape(arg), MovementOps.SHRINK: lambda x, arg: x[tuple(slice(p[0], p[1], None) for p in arg)],
- }
-
- def promote_types(x, y): return ret if (ret := np.promote_types(x.dtype, y.dtype)) != np.float64 else np.float32
- def match_types(x, y):
-   up = x.dtype if dtypes.from_np(x.dtype).priority > dtypes.from_np(y.dtype).priority else y.dtype
-   return x.astype(up, copy=False), y.astype(up, copy=False)
-
- def einsum_mulacc(einsum, get_strides, expand):
-   def einscripts(x): return ''.join(["abcdefghijklmnopqrstuvwxyz"[i] for i in x])
-   def axes_slice(strides): return [i for i in range(len(strides)) if strides[i] != 0], tuple(slice(None) if strides[i] != 0 else 0 for i in range(len(strides)))
-   def mulacc(a, b, new_shape):
-     (a_axes, a_slices), (b_axes, b_slices) = axes_slice(get_strides(a)), axes_slice(get_strides(b))
-     out = [i for i in range(len(new_shape)) if a.shape[i] == new_shape[i] and (i in a_axes or i in b_axes)]
-     ret = einsum(f"{einscripts(a_axes)}, {einscripts(b_axes)} -> {einscripts(out)}", a[a_slices], b[b_slices])
-     return expand(ret.reshape([(1 if i not in a_axes and i not in b_axes else s) for i,s in enumerate(new_shape)]), new_shape)
-   return mulacc
-
- numpy_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
-   UnaryOps.NOOP: lambda x: np.require(x, requirements='C'), UnaryOps.EXP2: np.exp2, UnaryOps.LOG2: np.log2, UnaryOps.SIN: np.sin,
-   UnaryOps.CAST: lambda x,y: x.view(y[0].np) if y[1] else x.astype(y[0].np, copy=False),
-   BinaryOps.MAX: np.maximum, BinaryOps.CMPLT: lambda x,y: (x<y).astype(promote_types(x,y)), BinaryOps.ADD: lambda x, y: np.add(*match_types(x, y)),
-   BinaryOps.SUB: lambda x, y: np.subtract(*match_types(x, y)), BinaryOps.MUL: lambda x, y: np.multiply(*match_types(x, y)),
-   BinaryOps.DIV: lambda x, y: np.divide(*match_types(x, y)), UnaryOps.SQRT: np.sqrt,
-   MovementOps.PERMUTE: lambda x, order: x.transpose(order), MovementOps.PAD: np.pad, MovementOps.EXPAND: np.broadcast_to,
-   MovementOps.STRIDE: lambda x, arg: x[tuple(slice(None, None, i) for i in arg)],
-   TernaryOps.MULACC: einsum_mulacc(lambda s,a,b: np.einsum(s, *match_types(a.copy(), b.copy()), optimize=True), lambda x: x.strides, np.broadcast_to),
-   TernaryOps.WHERE: np.where,
- }}
-
- class RawNumpyBuffer(RawBuffer):
-   def __init__(self, size:int, dtype:DType, buf:Optional[np.ndarray]=None): super().__init__(size, dtype, buf if buf is not None else np.empty([size], dtype.np))
-   @classmethod
-   def fromCPU(cls, x): return cls(x.size, dtypes.from_np(x.dtype), x)
-   def toCPU(self): return self._buf
- CPUBuffer = Interpreted(RawNumpyBuffer, numpy_fxn_for_op, from_underlying=RawNumpyBuffer.fromCPU)
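The dispatch tables above lean on shape_to_axis to turn a reduce's output shape back into numpy axes: any dimension where the old and new shapes disagree is an axis being collapsed. A quick standalone check of that contract:

import numpy as np

def shape_to_axis(old_shape, new_shape):   # copied from the deleted file above
  return tuple(i for i, (a, b) in enumerate(zip(old_shape, new_shape)) if a != b)

x = np.arange(24).reshape(2, 3, 4)
assert shape_to_axis(x.shape, (2, 1, 4)) == (1,)
out = x.sum(shape_to_axis(x.shape, (2, 1, 4)), keepdims=True)  # what ReduceOps.SUM did
assert out.shape == (2, 1, 4)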