tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131)
  1. tinygrad/__init__.py +1 -1
  2. tinygrad/apps/llm.py +206 -0
  3. tinygrad/codegen/__init__.py +116 -0
  4. tinygrad/codegen/devectorizer.py +315 -172
  5. tinygrad/codegen/expander.py +8 -16
  6. tinygrad/codegen/gpudims.py +89 -0
  7. tinygrad/codegen/linearize.py +205 -203
  8. tinygrad/codegen/lowerer.py +92 -139
  9. tinygrad/codegen/opt/__init__.py +38 -0
  10. tinygrad/codegen/opt/heuristic.py +125 -0
  11. tinygrad/codegen/opt/kernel.py +510 -0
  12. tinygrad/{engine → codegen/opt}/search.py +51 -35
  13. tinygrad/codegen/opt/swizzler.py +134 -0
  14. tinygrad/codegen/opt/tc.py +127 -0
  15. tinygrad/codegen/quantize.py +67 -0
  16. tinygrad/device.py +122 -132
  17. tinygrad/dtype.py +152 -35
  18. tinygrad/engine/jit.py +81 -54
  19. tinygrad/engine/memory.py +46 -27
  20. tinygrad/engine/realize.py +82 -41
  21. tinygrad/engine/schedule.py +70 -445
  22. tinygrad/frontend/__init__.py +0 -0
  23. tinygrad/frontend/onnx.py +1253 -0
  24. tinygrad/frontend/torch.py +5 -0
  25. tinygrad/gradient.py +19 -27
  26. tinygrad/helpers.py +95 -47
  27. tinygrad/nn/__init__.py +7 -8
  28. tinygrad/nn/optim.py +72 -41
  29. tinygrad/nn/state.py +37 -23
  30. tinygrad/renderer/__init__.py +40 -60
  31. tinygrad/renderer/cstyle.py +143 -128
  32. tinygrad/renderer/llvmir.py +113 -62
  33. tinygrad/renderer/ptx.py +50 -32
  34. tinygrad/renderer/wgsl.py +27 -23
  35. tinygrad/runtime/autogen/am/am.py +5861 -0
  36. tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
  37. tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
  38. tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
  39. tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
  40. tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
  41. tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
  42. tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
  43. tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
  44. tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
  45. tinygrad/runtime/autogen/comgr.py +35 -9
  46. tinygrad/runtime/autogen/comgr_3.py +906 -0
  47. tinygrad/runtime/autogen/cuda.py +2419 -494
  48. tinygrad/runtime/autogen/hsa.py +57 -16
  49. tinygrad/runtime/autogen/ib.py +7171 -0
  50. tinygrad/runtime/autogen/io_uring.py +917 -118
  51. tinygrad/runtime/autogen/kfd.py +748 -26
  52. tinygrad/runtime/autogen/libc.py +613 -218
  53. tinygrad/runtime/autogen/libusb.py +1643 -0
  54. tinygrad/runtime/autogen/nv/nv.py +8602 -0
  55. tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
  56. tinygrad/runtime/autogen/opencl.py +2 -4
  57. tinygrad/runtime/autogen/sqtt.py +1789 -0
  58. tinygrad/runtime/autogen/vfio.py +3 -3
  59. tinygrad/runtime/autogen/webgpu.py +273 -264
  60. tinygrad/runtime/graph/cuda.py +3 -3
  61. tinygrad/runtime/graph/hcq.py +68 -29
  62. tinygrad/runtime/graph/metal.py +29 -13
  63. tinygrad/runtime/graph/remote.py +114 -0
  64. tinygrad/runtime/ops_amd.py +537 -320
  65. tinygrad/runtime/ops_cpu.py +108 -7
  66. tinygrad/runtime/ops_cuda.py +12 -14
  67. tinygrad/runtime/ops_disk.py +13 -10
  68. tinygrad/runtime/ops_dsp.py +47 -40
  69. tinygrad/runtime/ops_gpu.py +13 -11
  70. tinygrad/runtime/ops_hip.py +6 -9
  71. tinygrad/runtime/ops_llvm.py +35 -15
  72. tinygrad/runtime/ops_metal.py +29 -19
  73. tinygrad/runtime/ops_npy.py +5 -3
  74. tinygrad/runtime/ops_null.py +28 -0
  75. tinygrad/runtime/ops_nv.py +306 -234
  76. tinygrad/runtime/ops_python.py +62 -52
  77. tinygrad/runtime/ops_qcom.py +28 -39
  78. tinygrad/runtime/ops_remote.py +482 -0
  79. tinygrad/runtime/ops_webgpu.py +28 -28
  80. tinygrad/runtime/support/am/amdev.py +114 -249
  81. tinygrad/runtime/support/am/ip.py +211 -172
  82. tinygrad/runtime/support/amd.py +138 -0
  83. tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
  84. tinygrad/runtime/support/compiler_cuda.py +8 -11
  85. tinygrad/runtime/support/elf.py +2 -1
  86. tinygrad/runtime/support/hcq.py +184 -97
  87. tinygrad/runtime/support/ib.py +172 -0
  88. tinygrad/runtime/support/llvm.py +3 -4
  89. tinygrad/runtime/support/memory.py +251 -0
  90. tinygrad/runtime/support/nv/__init__.py +0 -0
  91. tinygrad/runtime/support/nv/ip.py +581 -0
  92. tinygrad/runtime/support/nv/nvdev.py +183 -0
  93. tinygrad/runtime/support/system.py +170 -0
  94. tinygrad/runtime/support/usb.py +268 -0
  95. tinygrad/runtime/support/webgpu.py +18 -0
  96. tinygrad/schedule/__init__.py +0 -0
  97. tinygrad/schedule/grouper.py +119 -0
  98. tinygrad/schedule/kernelize.py +368 -0
  99. tinygrad/schedule/multi.py +231 -0
  100. tinygrad/shape/shapetracker.py +40 -46
  101. tinygrad/shape/view.py +88 -52
  102. tinygrad/tensor.py +968 -542
  103. tinygrad/uop/__init__.py +117 -0
  104. tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
  105. tinygrad/uop/mathtraits.py +169 -0
  106. tinygrad/uop/ops.py +1021 -0
  107. tinygrad/uop/spec.py +228 -0
  108. tinygrad/{codegen → uop}/symbolic.py +239 -216
  109. tinygrad/uop/upat.py +163 -0
  110. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
  111. tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
  112. tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
  113. tinygrad/viz/index.html +203 -403
  114. tinygrad/viz/js/index.js +718 -0
  115. tinygrad/viz/js/worker.js +29 -0
  116. tinygrad/viz/serve.py +224 -102
  117. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
  118. tinygrad-0.11.0.dist-info/RECORD +141 -0
  119. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
  120. tinygrad/codegen/kernel.py +0 -693
  121. tinygrad/engine/multi.py +0 -161
  122. tinygrad/ops.py +0 -1003
  123. tinygrad/runtime/ops_cloud.py +0 -220
  124. tinygrad/runtime/support/allocator.py +0 -94
  125. tinygrad/spec.py +0 -155
  126. tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
  127. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
  128. tinygrad/viz/perfetto.html +0 -178
  129. tinygrad-0.10.2.dist-info/RECORD +0 -99
  130. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
  131. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
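A minimal sketch of user code against the 0.11.0 API, assuming only the renames visible in the tensor.py hunks below (tinygrad.ops moved to tinygrad.uop.ops, Tensor.lazydata renamed to Tensor.uop, and the new Tensor.kernelize entry point); this is illustrative, not an upstream migration guide:

    from tinygrad import Tensor
    from tinygrad.uop.ops import UOp   # 0.10.2: from tinygrad.ops import UOp

    t = Tensor([1.0, 2.0, 3.0]) * 2
    assert isinstance(t.uop, UOp)      # 0.10.2: t.lazydata
    t.kernelize()                      # new in 0.11.0: build kernels/buffers without realizing
    print(t.tolist())                  # [2.0, 4.0, 6.0]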
tinygrad/tensor.py CHANGED
@@ -2,65 +2,55 @@
  from __future__ import annotations
  import time, math, itertools, functools, struct, sys, inspect, pathlib, string, hashlib, weakref
  from contextlib import ContextDecorator
- from typing import Callable, Optional, ClassVar, Union, Sequence, cast, get_args, Literal, TYPE_CHECKING, SupportsIndex
+ from typing import Callable, ClassVar, Sequence, cast, get_args, Literal, SupportsIndex, ParamSpec, TypeVar, Generic
  from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
+ from tinygrad.dtype import _from_np_dtype, _to_np_dtype
  from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
- from tinygrad.helpers import IMAGE, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap
- from tinygrad.engine.multi import get_multi_map
+ from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap, DEBUG, is_numpy_ndarray
  from tinygrad.gradient import compute_gradient
- from tinygrad.ops import smax, smin, resolve, UOp, Ops, sint, Variable, SimpleMathTrait, identity_element
- from tinygrad.spec import tensor_uop_spec, type_verify
- from tinygrad.device import Device, BufferSpec
+ from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, Variable, MathTrait, identity_element, all_metadata
+ from tinygrad.uop.spec import tensor_uop_spec, type_verify
+ from tinygrad.device import Device, Buffer
  from tinygrad.engine.realize import run_schedule
  from tinygrad.engine.memory import memory_planner
  from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars
+ from tinygrad.schedule.kernelize import get_kernelize_map

  # *** all in scope Tensors are here. this gets relevant UOps ***

- all_tensors: set[weakref.ref[Tensor]] = set()
+ all_tensors: dict[weakref.ref[Tensor], None] = {}
+ def _find_all_tensors_for_uops(all_uops: set[UOp]) -> list[Tensor]:
+ return [t for tref in all_tensors if (t:=tref()) is not None and t.uop in all_uops]

- def _apply_map_to_tensors(applied_map:dict[UOp, UOp]) -> None:
+ def _apply_map_to_tensors(applied_map:dict[UOp, UOp], name:str|None=None) -> None:
  # get all children of keys in applied_map
  all_uops: set[UOp] = set()
  search_uops = list(applied_map)
  while len(search_uops):
- x = search_uops.pop(0)
+ x = search_uops.pop()
  if x in all_uops: continue
  all_uops.add(x)
  search_uops.extend([u for c in x.children if (u:=c()) is not None])

  # link the found UOps back to Tensors. exit early if there's no Tensors to realize
  # NOTE: this uses all_tensors, but it's fast
- fixed_tensors: list[Tensor] = [t for tref in all_tensors if (t:=tref()) is not None and t.lazydata in all_uops]
-
- if len(fixed_tensors):
+ if len(fixed_tensors := _find_all_tensors_for_uops(all_uops)):
  # potentially rewrite all the discovered Tensors
- sink = UOp.sink(*[t.lazydata for t in fixed_tensors])
- new_sink = sink.substitute(applied_map)
+ sink = UOp.sink(*[t.uop for t in fixed_tensors])
+ new_sink = sink.substitute(applied_map, name=name)

- # set the relevant lazydata to the realized UOps
+ # set the relevant uop to the realized UOps
  for t,s,ns in zip(fixed_tensors, sink.src, new_sink.src):
  if s is ns: continue
- t.lazydata = ns
+ t.uop = ns

  # **** Tensor helper functions ****

- def _metaop(op, shape:tuple[sint,...], dtype:DType, device:Union[str, tuple[str, ...]], arg=None):
- if isinstance(device, str): return UOp.metaop(op, shape, dtype, device, arg)
- return UOp.multi(*[UOp.metaop(op, shape, dtype, d, arg) for d in device], axis=None)
-
- def _from_np_dtype(npdtype:'np.dtype') -> DType: # type: ignore [name-defined] # noqa: F821
- import numpy as np
- return dtypes.fields()[np.dtype(npdtype).name]
- def _to_np_dtype(dtype:DType) -> Optional[type]:
- import numpy as np
- return np.dtype(dtype.fmt).type if dtype.fmt is not None else None
-
  def _fromnp(x: 'np.ndarray') -> UOp: # type: ignore [name-defined] # noqa: F821
- ret = UOp.metaop(Ops.EMPTY, x.shape, _from_np_dtype(x.dtype), "NPY")
+ ret = UOp.new_buffer("NPY", x.size, _from_np_dtype(x.dtype))
  # fake realize
  ret.buffer.allocate(x)
- return ret
+ return ret.reshape(x.shape)

  def get_shape(x) -> tuple[int, ...]:
  # NOTE: str is special because __getitem__ on a str is still a str
@@ -68,10 +58,10 @@ def get_shape(x) -> tuple[int, ...]:
  if not all_same(subs:=[get_shape(xi) for xi in x]): raise ValueError(f"inhomogeneous shape from {x}")
  return (len(subs),) + (subs[0] if subs else ())

- def _frompy(x:Union[list, tuple, bytes], dtype:DType) -> UOp:
- if isinstance(x, bytes): ret, data = UOp.metaop(Ops.EMPTY, (len(x)//dtype.itemsize,), dtype, "PYTHON"), x
+ def _frompy(x:list|tuple|bytes, dtype:DType) -> UOp:
+ if isinstance(x, bytes): ret, data = UOp.new_buffer("PYTHON", len(x)//dtype.itemsize, dtype), x
  else:
- ret = UOp.metaop(Ops.EMPTY, get_shape(x), dtype, "PYTHON")
+ ret = UOp.new_buffer("PYTHON", prod(shape:=get_shape(x)), dtype).reshape(shape)
  assert dtype.fmt is not None, f"{dtype=} has None fmt"
  truncate_function = truncate[dtype]
  data = struct.pack(f"@{ret.size}{dtype.fmt}", *[truncate_function(xi) for xi in fully_flatten(x)])
@@ -79,7 +69,7 @@ def _frompy(x:Union[list, tuple, bytes], dtype:DType) -> UOp:
  ret.buffer.allocate(memoryview(data if Device.DEFAULT != "PYTHON" else bytearray(data)))
  return ret

- def _get_winograd_matcols(mat, dims:int, shp:tuple[sint, ...], device:Union[str, tuple[str, ...]], dtype:DType) -> list[list[Tensor]]:
+ def _get_winograd_matcols(mat, dims:int, shp:tuple[sint, ...], device:str|tuple[str, ...], dtype:DType) -> list[list[Tensor]]:
  return [[Tensor.cat(*[Tensor.full(shp[:dim] + (1,) + shp[dim+1:], float(m[k]), device=device, dtype=dtype) for m in mat], dim=dim)
  for k in range(len(mat[0]))] for dim in range(dims)]

@@ -102,13 +92,12 @@ def _align_left(*shapes:tuple[sint, ...]) -> tuple[tuple[sint, ...], ...]:
  def _broadcast_shape(*shapes:tuple[sint, ...]) -> tuple[sint, ...]:
  return tuple(0 if 0 in nth_dim_sizes else smax(nth_dim_sizes) for nth_dim_sizes in zip(*_align_left(*shapes)))

- def _masked_setitem(target:Tensor, values:Tensor, mask:Tensor, axes:tuple[int, ...]):
- # apply mask to values (already broadcasted) and reduce such that if mask contains repeated indices the last one remains
- values = values * mask
+ def _masked_setitem(target:Tensor, values:Tensor, mask:Tensor, axes:tuple[int, ...]) -> Tensor:
+ # reduce such that if mask contains repeated indices the last one remains
  for dim in axes: mask, values = functools.reduce(lambda x,y: (x[0]|y[0], y[0].where(y[1], x[1])), zip(mask.split(1, dim), values.split(1, dim)))
  # remove extra dims from reduce
  for dim in reversed(axes): mask, values = mask.squeeze(dim), values.squeeze(dim)
- # select from values for each True element in mask else select from self
+ # select from values for each True element in mask else select from target
  return mask.where(values, target)

  # `(padding_left, padding_right, padding_top, padding_bottom, ...)` -> `(..., (padding_top, padding_bottom), (padding_left, padding_right))`
@@ -116,7 +105,7 @@ def _flat_to_grouped(padding:Sequence[sint]) -> tuple[tuple[sint, sint], ...]: r

  ReductionStr = Literal["mean", "sum", "none"]

- class Tensor(SimpleMathTrait):
+ class Tensor(MathTrait):
  """
  A `Tensor` is a multi-dimensional matrix containing elements of a single data type.

@@ -127,70 +116,76 @@ class Tensor(SimpleMathTrait):
  np.set_printoptions(precision=4)
  ```
  """
- __slots__ = "lazydata", "requires_grad", "grad"
+ __slots__ = "uop", "requires_grad", "grad"
  training: ClassVar[bool] = False
- no_grad: ClassVar[bool] = False

- def __init__(self, data:Union[None, ConstType, bytes, list, tuple, UOp, 'np.ndarray', pathlib.Path], # type: ignore [name-defined] # noqa: F821
- device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DTypeLike]=None, requires_grad:Optional[bool]=None):
+ def __init__(self, data:ConstType|bytes|list|tuple|UOp|'np.ndarray'|pathlib.Path|None, # type: ignore [name-defined] # noqa: F821
+ device:str|tuple|list|None=None, dtype:DTypeLike|None=None, requires_grad:bool|None=None):
  if dtype is not None: dtype = to_dtype(dtype)
  if device is None and isinstance(data, pathlib.Path): device = f"DISK:{data.resolve()}" # keep it on the disk if device is None
  device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device)

  # tensors can have gradients if you have called .backward
- self.grad: Optional[Tensor] = None
+ self.grad:Tensor|None = None

  # NOTE: this can be in three states. False and None: no gradient, True: gradient
  # None (the default) will be updated to True if it's put in an optimizer
- self.requires_grad: Optional[bool] = requires_grad
+ self.requires_grad:bool|None = requires_grad

- # create a LazyBuffer from the different types of inputs
+ # create a UOp from the different types of inputs
  if isinstance(data, UOp):
  assert dtype is None or dtype==data.dtype, "dtype doesn't match, and casting isn't supported"
- # NOTE: this is here because LazyBuffer = UOp
- if isinstance(data, UOp) and data.op is Ops.BIND: data = _metaop(Ops.BIND, tuple(), dtype or data.dtype, device, data)
- elif data is None: data = _metaop(Ops.EMPTY, (0,), dtype or dtypes.default_float, device)
- elif isinstance(data, get_args(ConstType)): data = _metaop(Ops.CONST, tuple(), dtype or dtypes.from_py(data), device, data)
+ if data.op is Ops.BIND:
+ var, val = data.unbind()
+ # give the bound constant a device
+ const = UOp.const(var.dtype, val, device, ())
+ data = data.replace(src=(var.replace(src=const.src), const))
+ elif data is None: data = UOp.const(dtype or dtypes.default_float, 0, device, ())
+ elif isinstance(data, get_args(ConstType)): data = UOp.const(dtype or dtypes.from_py(data), data, device, ())
  elif isinstance(data, bytes): data = _frompy(data, dtypes.uint8 if dtype is None else dtype)
  elif isinstance(data, (list, tuple)):
  if dtype is None:
  if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtypes.bool
  else: dtype = dtypes.default_int if d and all_int(d) else dtypes.default_float # NOTE: this works because all_int([True, False]) is True
- if dtype == dtypes.bfloat16: data = Tensor(_frompy(data, dtypes.float32), device=device).cast(dtypes.bfloat16).lazydata
+ if dtype in [dtypes.bfloat16, *dtypes.fp8s]: data = Tensor(_frompy(data, dtypes.float32), device=device).cast(dtype).uop
  else: data = _frompy(data, dtype)
- elif str(type(data)) == "<class 'numpy.ndarray'>":
+ elif is_numpy_ndarray(data):
  import numpy as np
  assert isinstance(data, np.ndarray), f"expected np.ndarray, got {data}"
- if data.shape == (): data = _metaop(Ops.CONST, tuple(), dtype or _from_np_dtype(data.dtype), device, data.item())
+ if data.shape == (): data = UOp.const(dtype or _from_np_dtype(data.dtype), data.item(), device, ())
  else: data = _fromnp(data.astype(npdtype) if dtype is not None and (npdtype:=_to_np_dtype(dtype)) is not None else data) # type: ignore [name-defined]
  elif isinstance(data, pathlib.Path):
  dtype = dtype or dtypes.uint8
- data = _metaop(Ops.EMPTY, (data.stat().st_size // dtype.itemsize,), dtype, f"DISK:{data.resolve()}")
+ data = UOp.new_buffer(f"DISK:{data.resolve()}", data.stat().st_size // dtype.itemsize, dtype)

  # by this point, it has to be a UOp
  if not isinstance(data, UOp): raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}")

  # data might be on a different device
- if isinstance(device, str): self.lazydata:UOp = data if data.device == device else data.copy_to_device(device)
+ if isinstance(device, str): self.uop:UOp = data if data.device == device else data.copy_to_device(device)
  # if device is a tuple, we should have/construct a MultiLazyBuffer
- elif isinstance(data, UOp) and isinstance(data.device, str): self.lazydata = Tensor(data).shard(device).lazydata
+ elif isinstance(data.device, str): self.uop = Tensor(data).shard(device).uop
  else:
  assert data.device == device, f"MultiLazyBuffer device mismatch, {data.device} != {device}"
- self.lazydata = data
+ self.uop = data

  # add to all_tensors after construction succeeds
- all_tensors.add(weakref.ref(self))
- def __del__(self): all_tensors.discard(weakref.ref(self))
+ all_tensors[weakref.ref(self)] = None
+ def __del__(self): all_tensors.pop(weakref.ref(self), None)

  def _apply_uop(self, fxn:Callable, *x:Tensor, **kwargs) -> Tensor:
- new_uop: UOp = fxn(*[t.lazydata for t in (self,)+x], **kwargs)
+ new_uop: UOp = fxn(*[t.uop for t in (self,)+x], **kwargs)
+ if (metadata:=_METADATA.get()) is not None: all_metadata[new_uop] = (metadata,)
  needs_input_grad = [t.requires_grad for t in (self,)+x]
  return Tensor(new_uop, device=new_uop.device, requires_grad=True if any(needs_input_grad) else None if None in needs_input_grad else False)

- def _apply_broadcasted_uop(self, fxn:Callable, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
+ def _apply_broadcasted_uop(self, fxn:Callable, x:Tensor|ConstType, reverse=False) -> Tensor:
  lhs,rhs = self._broadcasted(x, reverse)
  return lhs._apply_uop(fxn, rhs)

+ # _binop is used by MathTrait
+ def _binop(self, op, x, reverse): return self._apply_broadcasted_uop(lambda *u: UOp.alu(u[0], op, *u[1:]), x, reverse)
+
  def requires_grad_(self, requires_grad=True) -> Tensor:
  self.requires_grad = requires_grad
  return self
@@ -200,15 +195,10 @@ class Tensor(SimpleMathTrait):
  def __enter__(self): self.prev, Tensor.training = Tensor.training, self.mode
  def __exit__(self, exc_type, exc_value, traceback): Tensor.training = self.prev

- class test(ContextDecorator):
- def __init__(self, mode:bool = True): self.mode = mode
- def __enter__(self): self.prev, Tensor.no_grad = Tensor.no_grad, self.mode
- def __exit__(self, exc_type, exc_value, traceback): Tensor.no_grad = self.prev
-
  def __repr__(self):
- ld = self.lazydata
+ ld = self.uop
  ld_repr = f"<UOp {ld.device} {ld.shape} {str(ld.dtype)[7:]} {ld.st if ld.base is not ld else (ld.op, ld.realized)}>"
- return f"<Tensor {ld_repr} on {self.device} with grad {(self.grad.lazydata if self.grad is not None else None)!r}>"
+ return f"<Tensor {ld_repr} on {self.device} with grad {(self.grad.uop if self.grad is not None else None)!r}>"

  # Python has a non moving GC, so this should be okay
  def __hash__(self): return id(self)
@@ -220,36 +210,50 @@ class Tensor(SimpleMathTrait):
  return self.shape[0]

  @property
- def device(self) -> Union[str, tuple[str, ...]]: return self.lazydata.device
+ def device(self) -> str|tuple[str, ...]: return self.uop.device

  @property
- def shape(self) -> tuple[sint, ...]: return self.lazydata.shape
+ def shape(self) -> tuple[sint, ...]: return self.uop.shape

  @property
- def dtype(self) -> DType: return self.lazydata.dtype
+ def dtype(self) -> DType: return self.uop.dtype

  # ***** data handlers ****

+ def kernelize(self, *lst:Tensor) -> Tensor:
+ """
+ Creates the kernels and buffers needed to realize these Tensor(s).
+
+ NOTE: Kernelize can be called multiple times on a Tensor
+ """
+ big_sink = UOp.sink(*[x.uop for x in (self,)+lst])
+
+ # verify Tensors match the spec
+ if __debug__: type_verify(list(big_sink.toposort()), tensor_uop_spec)
+
+ becomes_map = get_kernelize_map(big_sink)
+ _apply_map_to_tensors(becomes_map, name="Apply Kernelize Map")
+ return self
+
  def schedule_with_vars(self, *lst:Tensor) -> tuple[list[ScheduleItem], dict[Variable, int]]:
  """
  Creates the schedule needed to realize these Tensor(s), with Variables.

  NOTE: A Tensor can only be scheduled once.
  """
- big_sink = UOp.sink(*[x.lazydata for x in (self,)+lst])
+ st = time.perf_counter()
+ self.kernelize(*lst)
+ sink = UOp.sink(*[x.uop for x in (self,)+lst])

- # TODO: move this to scheduler tensor_map pass
- if any(x.op is Ops.MULTI for x in big_sink.toposort):
- # multi fixup
- _apply_map_to_tensors(get_multi_map(big_sink))
- big_sink = UOp.sink(*flatten([x.lazydata.src if x.lazydata.op is Ops.MULTI else [x.lazydata] for x in (self,)+lst]))
-
- # verify Tensors match the spec
- if __debug__: type_verify(list(big_sink.toposort), tensor_uop_spec)
+ # remove all ASSIGNs, after scheduling, the tensors are just buffers
+ remove_assign_map = {u:u.buf_uop for u in sink.toposort() if u.op is Ops.ASSIGN}
+ _apply_map_to_tensors(remove_assign_map, name="Remove Assigns")

- schedule, var_vals, becomes_map = create_schedule_with_vars(big_sink)
- _apply_map_to_tensors(becomes_map)
- return memory_planner(schedule), var_vals
+ # create the schedule
+ schedule, var_vals = create_schedule_with_vars(sink)
+ schedule = memory_planner(schedule)
+ if DEBUG >= 1 and len(schedule) > 1: print(f"scheduled {len(schedule)} kernels in {(time.perf_counter()-st)*1000:.2f} ms")
+ return schedule, var_vals

  def schedule(self, *lst:Tensor) -> list[ScheduleItem]:
  """Creates the schedule needed to realize these Tensor(s)."""
@@ -262,46 +266,43 @@ class Tensor(SimpleMathTrait):
  run_schedule(*self.schedule_with_vars(*lst), do_update_stats=do_update_stats)
  return self

- def replace(self, x:Tensor) -> Tensor:
+ def replace(self, x:Tensor, allow_shape_mismatch=False) -> Tensor:
  """
  Replaces the data of this tensor with the data of another tensor. Only the shape of the tensors must match.
  """
  # used for replacing a Tensor with a new version of it (potentially with a different device and dtype)
- assert self.shape == x.shape, f"replace shape mismatch {self.shape} != {x.shape}"
- self.lazydata = x.lazydata
+ assert self.shape == x.shape or allow_shape_mismatch, f"replace shape mismatch {self.shape} != {x.shape}"
+ self.uop = x.uop
  return self

  def assign(self, x) -> Tensor:
  # TODO: this is a hack for writing to DISK. remove with working assign
  if isinstance(self.device, str) and self.device.startswith("DISK"):
  if x.__class__ is not Tensor: x = Tensor(x, device="CPU", dtype=self.dtype)
- self.contiguous().realize().lazydata.base.realized.ensure_allocated().copyin(x._data())
+ self._buffer().copyin(x._data())
  return self
  if x.__class__ is not Tensor: x = Tensor(x, device=self.device, dtype=self.dtype)
- if self.lazydata is x.lazydata: return self # a self assign is a NOOP
+ if self.uop is x.uop: return self # a self assign is a NOOP
  # NOTE: we allow cross device assign
+ # broadcast x
+ if least_upper_dtype(self.dtype, x.dtype) == self.dtype: x = x._broadcast_to(self.shape).cast(self.dtype)
  assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}"
  assert self.device == x.device, f"assign device mismatch {self.device} != {x.device}"
  assert self.dtype == x.dtype, f"assign dtype mismatch {self.dtype} != {x.dtype}"
- assert not x.requires_grad # self requires_grad is okay?
- if not self.lazydata.is_realized: return self.replace(x)
- self.lazydata = self.lazydata.assign(x.lazydata)
+ self.uop = self.uop.assign(x.uop)
  return self

  def detach(self) -> Tensor:
  """
  Returns a new tensor with the same data as this tensor, but detached from the autograd graph.
  """
- return Tensor(self.lazydata.detach(), device=self.device, requires_grad=False)
+ return Tensor(self.uop.detach(), device=self.device, requires_grad=False)

- def _data(self) -> memoryview:
- if 0 in self.shape: return memoryview(bytearray(0))
- # NOTE: this realizes on the object from as_buffer being a Python object
- cpu = self.cast(self.dtype.base).contiguous().to("CPU").realize()
- buf = cast(UOp, cpu.lazydata).base.realized
- assert buf is not None, f"{cast(UOp, cpu.lazydata).base} was not realized"
- if self.device != "CPU": buf.options = BufferSpec(nolru=True)
- return buf.as_buffer(allow_zero_copy=True if self.device != "CPU" else False)
+ def _buffer(self) -> Buffer:
+ x = self.cast(self.dtype.base).contiguous()
+ if isinstance(self.device, tuple): x = x.to("CPU")
+ return cast(Buffer, x.realize().uop.base.buffer).ensure_allocated()
+ def _data(self) -> memoryview: return self._buffer().as_buffer()

  def data(self) -> memoryview:
  """
@@ -312,10 +313,9 @@ class Tensor(SimpleMathTrait):
  print(np.frombuffer(t.data(), dtype=np.int32))
  ```
  """
- assert self.dtype.base.fmt is not None, f"no fmt dtype for {self.dtype.base}"
+ if 0 in self.shape: return memoryview(bytearray(0)).cast(self.dtype.base.fmt)
  assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
- if TYPE_CHECKING or sys.version_info < (3, 12): assert self.dtype.base.fmt != "e"
- return cast(memoryview, self._data().cast(self.dtype.base.fmt) if 0 in self.shape else self._data().cast(self.dtype.base.fmt, self.shape))
+ return self._buffer().as_typed_buffer(self.shape)

  def item(self) -> ConstType:
  """
@@ -331,7 +331,7 @@ class Tensor(SimpleMathTrait):

  # TODO: should be Tensor.tolist() -> Union[list[ConstType], ConstType]. The list is Sequence because mypy expects memoryview.tolist() -> list[int]
  # src: https://github.com/python/mypy/blob/release-1.6/mypy/typeshed/stdlib/builtins.pyi#L803
- def tolist(self) -> Union[Sequence[ConstType], ConstType]:
+ def tolist(self) -> Sequence[ConstType]|ConstType:
  """
  Returns the value of this tensor as a nested list.
  Returns single value for const tensor.
@@ -345,6 +345,7 @@ class Tensor(SimpleMathTrait):
  print(t.tolist())
  ```
  """
+ if self.dtype in (dtypes.bfloat16, *dtypes.fp8s): return self.cast(dtypes.float32).tolist()
  return self.data().tolist()

  def numpy(self) -> 'np.ndarray': # type: ignore [name-defined] # noqa: F821
@@ -356,32 +357,32 @@ class Tensor(SimpleMathTrait):
  print(repr(t.numpy()))
  ```
  """
+ assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
  import numpy as np
  if self.dtype.base == dtypes.bfloat16: return self.float().numpy()
- assert _to_np_dtype(self.dtype.base) is not None, f"no np dtype for {self.dtype.base}"
- assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
- return np.frombuffer(self._data(), dtype=_to_np_dtype(self.dtype.base)).reshape(self.shape)
+ if 0 in self.shape: return np.empty(self.shape, dtype=_to_np_dtype(self.dtype.base))
+ return self._buffer().numpy().reshape(self.shape)

  def clone(self) -> Tensor:
  """
  Creates a clone of this tensor allocating a separate buffer for the data.
  """
- ret = Tensor(self.lazydata.clone(), self.device, requires_grad=self.requires_grad)
+ ret = Tensor.empty(self.shape, device=self.device, dtype=self.dtype)
  if self.grad is not None: ret.grad = self.grad.clone()
- return ret
+ return ret.assign(self)

- def to(self, device:Optional[Union[str, tuple[str, ...]]]) -> Tensor:
+ def to(self, device:str|tuple[str, ...]|None) -> Tensor:
  """
  Moves the tensor to the given device.
  """
  device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device)
  if device == self.device: return self
  if not isinstance(device, str): return self.shard(device)
- ret = Tensor(self.lazydata, device, requires_grad=self.requires_grad)
+ ret = Tensor(self.uop, device, requires_grad=self.requires_grad)
  if self.grad is not None: ret.grad = self.grad.to(device)
  return ret

- def to_(self, device:Optional[Union[str, tuple[str, ...]]]):
+ def to_(self, device:str|tuple[str, ...]|None) -> Tensor:
  """
  Moves the tensor to the given device in place.
  """
@@ -389,21 +390,21 @@ class Tensor(SimpleMathTrait):
  if self.grad is not None and real.grad is not None: self.grad.replace(real.grad)
  return self.replace(real)

- def shard(self, devices:tuple[str, ...], axis:Optional[int]=None) -> Tensor:
+ def shard(self, devices:tuple[str, ...], axis:int|None=None) -> Tensor:
  """
  Shards the tensor across the given devices. Optionally specify which axis to shard on.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.empty(2, 4)
- print(t.shard((t.device, t.device), axis=1).lazydata)
+ print(t.shard((t.device, t.device), axis=1).uop)
  ```
  """
  assert isinstance(self.device, str), "can't shard a MultiLazyBuffer"
  devices = tuple(Device.canonicalize(x) for x in devices)
- mlb = self.lazydata.shard(devices, self._resolve_dim(axis) if axis is not None else None)
+ mlb = self.uop.shard(devices, self._resolve_dim(axis)) if axis is not None else self.uop.copy_to_device(devices)
  return Tensor(mlb, device=devices, requires_grad=self.requires_grad)

- def shard_(self, devices:tuple[str, ...], axis:Optional[int]=None):
+ def shard_(self, devices:tuple[str, ...], axis:int|None=None) -> Tensor:
  """
  Shards the tensor across the given devices in place.
  """
@@ -411,7 +412,7 @@ class Tensor(SimpleMathTrait):

  @staticmethod
  def from_uop(y:UOp, **kwargs) -> Tensor:
- if y.op is Ops.BIND: return Tensor(y, **kwargs, requires_grad=False) # this is the only UOp allowed in Tensor
+ if y.op is Ops.BIND: return Tensor(y, **kwargs, requires_grad=False)
  if y.op is Ops.CONST: return Tensor(y.arg, **kwargs, requires_grad=False)
  if y.op is Ops.MUL: return Tensor.from_uop(y.src[0]) * Tensor.from_uop(y.src[1])
  if y.op is Ops.ADD: return Tensor.from_uop(y.src[0]) + Tensor.from_uop(y.src[1])
@@ -420,15 +421,7 @@ class Tensor(SimpleMathTrait):
  # ***** creation entrypoint *****

  @staticmethod
- def _metaop(op, shape, device:Optional[Union[tuple[str, ...], str]]=None, dtype:Optional[DTypeLike]=None, arg=None, **kwargs):
- dtype = to_dtype(dtype) if dtype is not None else dtypes.default_float
- if isinstance(device, tuple):
- return Tensor(UOp.multi(*[UOp.metaop(op, shape, dtype, Device.canonicalize(d), arg) for d in device], axis=None),
- device, dtype, **kwargs)
- return Tensor(UOp.metaop(op, shape, dtype, Device.canonicalize(device), arg), device, dtype, **kwargs)
-
- @staticmethod
- def empty(*shape, **kwargs):
+ def empty(*shape, device:str|tuple[str, ...]|None=None, dtype:DTypeLike|None=None, **kwargs) -> Tensor:
  """
  Creates an empty tensor with the given shape.

@@ -440,7 +433,11 @@ class Tensor(SimpleMathTrait):
  print(t.shape)
  ```
  """
- return Tensor._metaop(Ops.EMPTY, argfix(*shape), **kwargs)
+ dtype, shape = to_dtype(dtype) if dtype is not None else dtypes.default_float, argfix(*shape)
+ if not isinstance(size:=prod([x.vmax if isinstance(x, UOp) else x for x in shape]), int): raise ValueError(f"size must be int {size}")
+ # TODO: add test for multidevice tensor
+ device = tuple(Device.canonicalize(d) for d in device) if isinstance(device, tuple) else Device.canonicalize(device)
+ return Tensor(UOp.new_buffer(device, size, dtype), device, dtype, **kwargs).reshape(shape)

  @staticmethod
  def from_blob(ptr:int, shape:tuple[int, ...], **kwargs) -> Tensor:
@@ -451,21 +448,21 @@ class Tensor(SimpleMathTrait):
  You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
  Additionally, all other keyword arguments are passed to the constructor of the tensor.
  """
-
- r = Tensor._metaop(Ops.EMPTY, shape, **kwargs)
- r.lazydata.buffer.allocate(external_ptr=ptr)
+ r = Tensor.empty(*shape, **kwargs)
+ assert isinstance(r.device, str)
+ cast(Buffer, r.uop.buffer).allocate(external_ptr=ptr)
  return r

  @staticmethod
  def from_url(url:str, gunzip:bool=False, **kwargs) -> Tensor:
  """
- Create a Tensor from a URL.
+ Creates a Tensor from a URL.

  This is the preferred way to access Internet resources.
  It currently returns a DISK Tensor, but in the future it may return an HTTP Tensor.
  This also will soon become lazy (when possible) and not print progress without DEBUG.

- THe `gunzip` flag will gzip extract the resource and return an extracted Tensor.
+ The `gunzip` flag will gzip extract the resource and return an extracted Tensor.
  """
  return Tensor(fetch(url, gunzip=gunzip), **kwargs)

@@ -473,7 +470,7 @@ class Tensor(SimpleMathTrait):
  _device_seeds: dict[str, Tensor] = {}
  _device_rng_counters: dict[str, Tensor] = {}
  @staticmethod
- def manual_seed(seed=0):
+ def manual_seed(seed=0) -> None:
  """
  Sets the seed for random operations.

@@ -491,14 +488,14 @@ class Tensor(SimpleMathTrait):
  Tensor._seed, Tensor._device_seeds, Tensor._device_rng_counters = seed, {}, {}

  @staticmethod
- def _threefry_random_bits(key:Tensor, counts0:Tensor, counts1:Tensor):
+ def _threefry_random_bits(key:Tensor, counts0:Tensor, counts1:Tensor) -> Tensor:
  x = (counts1.cast(dtypes.uint64) << 32) | counts0.cast(dtypes.uint64)
  x = x._apply_uop(UOp.threefry, (key[1]._broadcast_to(x.shape).cast(dtypes.uint64) << 32) | key[0]._broadcast_to(x.shape).cast(dtypes.uint64))
  counts0, counts1 = (x & 0xffffffff).cast(dtypes.uint32), ((x >> 32) & 0xffffffff).cast(dtypes.uint32)
  return counts0.cat(counts1)

  @staticmethod
- def rand(*shape, device:Optional[str]=None, dtype:Optional[DTypeLike]=None, contiguous:bool=True, **kwargs) -> Tensor:
+ def rand(*shape, device:str|None=None, dtype:DTypeLike|None=None, contiguous:bool=True, **kwargs) -> Tensor:
  """
  Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[0, 1)`.

@@ -514,26 +511,24 @@ class Tensor(SimpleMathTrait):
  if not dtypes.is_float(dtype := to_dtype(dtype or dtypes.default_float)): raise ValueError(f"rand only supports float dtypes, got {dtype}")
  if not all_int(shape:=argfix(*shape)) or not all(s >= 0 for s in shape): raise ValueError(f"invalid input {shape=}")
  if device is not None and not isinstance(device, str): raise ValueError(f"rand only supports single device, got {device=}")
- _device = device = Device.canonicalize(device)
+ device = Device.canonicalize(device)

  # if shape has 0, return zero tensor
- if (numel := prod(shape)) == 0: return Tensor.zeros(shape, device=_device, dtype=dtype, **kwargs)
+ if (numel := prod(shape)) == 0: return Tensor.zeros(shape, device=device, dtype=dtype, **kwargs)
  num = ceildiv(numel * dtype.itemsize, 4)

- # when using MOCKGPU and NV generate rand on CPU
- if getenv("MOCKGPU") and device.startswith("NV"): device = "CPU"
-
  # generate per device seeds and rng counter if we haven't seen this device yet
  if device not in Tensor._device_seeds:
  Tensor._device_seeds[device] = Tensor(
  [int.from_bytes(hashlib.sha256(len(Tensor._device_seeds).to_bytes(4, "big")).digest(), "big"), Tensor._seed],
  device=device, dtype=dtypes.uint32, requires_grad=False)
- Tensor._device_rng_counters[device] = Tensor([0], device=device, dtype=dtypes.uint32, requires_grad=False)
+ Tensor._device_rng_counters[device] = Tensor([num], device=device, dtype=dtypes.uint32, requires_grad=False)
  # increment rng counter for devices
  else: Tensor._device_rng_counters[device].assign(Tensor._device_rng_counters[device] + num).contiguous()

  # threefry random bits
- counts0 = (Tensor.arange(ceildiv(num, 2), device=device, dtype=dtypes.uint32, requires_grad=False)+Tensor._device_rng_counters[device])
+ bits_count = Tensor._device_rng_counters[device] - num
+ counts0 = (Tensor.arange(ceildiv(num, 2), device=device, dtype=dtypes.uint32, requires_grad=False)+bits_count)
  counts1 = counts0 + ceildiv(num, 2)
  bits = Tensor._threefry_random_bits(Tensor._device_seeds[device], counts0, counts1)[:num]

@@ -545,12 +540,7 @@ class Tensor(SimpleMathTrait):
  one = Tensor.ones_like(bits, device=bits.device, dtype=dtype).bitcast(uint_dtype)
  bits = bits.rshift((dtype.itemsize * 8) - nmant).bitwise_or(one)
  # bitcast back to the original dtype and reshape
- out = bits.bitcast(dtype)[:numel].sub(1).reshape(shape)
-
- # move back to the original device if we were using MOCKGPU
- if getenv("MOCKGPU") and _device: out = out.to(_device)
-
- out.requires_grad = kwargs.get("requires_grad")
+ out = bits.bitcast(dtype)[:numel].sub(1).reshape(shape).requires_grad_(kwargs.get("requires_grad"))
  return out.contiguous() if contiguous else out

  # ***** creation helper functions *****
@@ -638,7 +628,7 @@ class Tensor(SimpleMathTrait):
  return (Tensor.full((output_len,), step, dtype=dtype, **kwargs)._cumalu(0, Ops.ADD) + (start - step)).cast(dtype)

  @staticmethod
- def linspace(start:Union[int, float], stop:Union[int, float], steps:int, **kwargs) -> Tensor:
+ def linspace(start:int|float, stop:int|float, steps:int, **kwargs) -> Tensor:
  """
  Returns a 1-D tensor of `steps` evenly spaced values from `start` to `stop`, inclusive.

@@ -658,7 +648,7 @@ class Tensor(SimpleMathTrait):
  return (start + Tensor.arange(steps, **kwargs) * ((stop - start) / (steps - 1))).cast(dtype)

  @staticmethod
- def eye(n:int, m:Optional[int]=None, **kwargs) -> Tensor:
+ def eye(n:int, m:int|None=None, **kwargs) -> Tensor:
  """
  Returns a 2-D tensor with `n` rows and `m` columns, with ones on the diagonal and zeros elsewhere.

@@ -674,7 +664,7 @@ class Tensor(SimpleMathTrait):
  ```
  """
  if n < 0 or (m is not None and m < 0): raise ValueError(f"cannot have negative {n=}, {m=}")
- x = Tensor.ones((n,1),**kwargs).pad((None,(0,n))).flatten().shrink(((0,n*n),)).reshape(n,n)
+ x = Tensor.ones(n, **kwargs).diag()
  return x if m is None else x.pad((None, (0, m-n))) if m > n else x.shrink((None, (0, m)))

  def full_like(self, fill_value:ConstType, **kwargs) -> Tensor:
@@ -735,17 +725,34 @@ class Tensor(SimpleMathTrait):
  dtype = kwargs.pop("dtype", self.dtype)
  if isinstance(self.device, tuple):
  if kwargs.get("device") is not None: raise RuntimeError("cannot specify `device` on `rand_like` of a multi device tensor")
- if self.lazydata.axis is None: return Tensor.rand(*self.shape, dtype=dtype, **kwargs).shard(self.device)
+ if self.uop.axis is None: return Tensor.rand(*self.shape, dtype=dtype, **kwargs).shard(self.device)
  contiguous = kwargs.pop("contiguous", True)
- sharded_shape = tuple(s//len(self.device) if a==self.lazydata.axis else s for a,s in enumerate(self.shape))
- rands = [Tensor.rand(sharded_shape, device=d, dtype=dtype, contiguous=contiguous, **kwargs).lazydata for d in self.device]
- return Tensor(UOp.multi(*rands, axis=self.lazydata.axis), device=self.device, dtype=dtype, **kwargs)
+ sharded_shape = tuple(s//len(self.device) if a==self.uop.axis else s for a,s in enumerate(self.shape))
+ rands = UOp(Ops.MSTACK, dtype=dtype,
+ src=tuple([Tensor.rand(sharded_shape, device=d, dtype=dtype, contiguous=contiguous, **kwargs).uop for d in self.device]))
+ return Tensor(UOp.multi(rands, axis=self.uop.axis), device=self.device, dtype=dtype, **kwargs)
  return Tensor.rand(*self.shape, device=kwargs.pop("device", self.device), dtype=dtype, **kwargs)

  # ***** rng hlops *****

+ def randn_like(self, dtype:DTypeLike|None=None, requires_grad:bool|None=None, **kwargs) -> Tensor:
+ """
+ Creates a tensor with the same shape and sharding as `self`, filled with random values from a normal distribution with mean 0 and variance 1.
+
+ You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+ ```python exec="true" source="above" session="tensor" result="python"
+ t = Tensor.ones(2, 3)
+ print(Tensor.randn_like(t).numpy())
+ ```
+ """
+ src = self.stack(self).rand_like(**{**kwargs, "dtype": dtypes.float32})
+ # https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
+ return (src[0].mul(2*math.pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(dtype or self.dtype)).requires_grad_(requires_grad)
+
  @staticmethod
- def randn(*shape, dtype:Optional[DTypeLike]=None, requires_grad:Optional[bool]=None, **kwargs) -> Tensor:
+ def randn(*shape, dtype:DTypeLike|None=None, requires_grad:bool|None=None, **kwargs) -> Tensor:
  """
  Creates a tensor with the given shape, filled with random values from a normal distribution with mean `0` and standard deviation `1`.
  If `dtype` is not specified, the default type is used.
@@ -758,9 +765,7 @@ class Tensor(SimpleMathTrait):
  print(Tensor.randn(2, 3).numpy())
  ```
  """
- # https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
- src = Tensor.rand((2, *argfix(*shape)), **{**kwargs, "dtype": dtypes.float32})
- return (src[0].mul(2*math.pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(dtype or dtypes.default_float)).requires_grad_(requires_grad)
+ return Tensor.empty(*shape, **kwargs).randn_like(dtype=dtype, requires_grad=requires_grad)

  @staticmethod
  def randint(*shape, low=0, high=10, dtype=dtypes.int32, **kwargs) -> Tensor:
@@ -782,7 +787,7 @@ class Tensor(SimpleMathTrait):
  return Tensor.uniform(*shape, low=low, high=high, dtype=dtype, **kwargs)

  @staticmethod
- def normal(*shape, mean=0.0, std=1.0, requires_grad:Optional[bool]=None, **kwargs) -> Tensor:
+ def normal(*shape, mean=0.0, std=1.0, requires_grad:bool|None=None, **kwargs) -> Tensor:
  """
  Creates a tensor with the given shape, filled with random values from a normal distribution with the given `mean` and standard deviation `std`.

@@ -797,7 +802,7 @@ class Tensor(SimpleMathTrait):
  return ((std * Tensor.randn(*shape, **kwargs)) + mean).requires_grad_(requires_grad)

  @staticmethod
- def uniform(*shape, low=0.0, high=1.0, dtype:Optional[DTypeLike]=None, requires_grad:Optional[bool]=None, **kwargs) -> Tensor:
+ def uniform(*shape, low=0.0, high=1.0, dtype:DTypeLike|None=None, requires_grad:bool|None=None, **kwargs) -> Tensor:
  """
  Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[low, high)`.

@@ -877,7 +882,29 @@ class Tensor(SimpleMathTrait):
  std = math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(argfix(*shape)[1:]))
  return Tensor.normal(*shape, mean=0.0, std=std, **kwargs)

+ @staticmethod
+ def randperm(n:int, device=None, dtype=dtypes.int32, **kwargs) -> Tensor:
+ """
+ Returns a tensor with a random permutation of integers from `0` to `n-1`.
+
+ ```python exec="true" source="above" session="tensor" result="python"
+ Tensor.manual_seed(42)
+ print(Tensor.randperm(6).numpy())
+ ```
+ """
+ return Tensor.rand(n, device=device, **kwargs).argsort().cast(dtype)
+
  def multinomial(self:Tensor, num_samples:int = 1, replacement:bool = False) -> Tensor:
+ """
+ Returns a tensor with `num_samples` indices sampled from a multinomial distribution weighted by `self`.
+
+ NOTE: `replacement=False` for `num_samples > 1` is not supported yet.
+ ```python exec="true" source="above" session="tensor" result="python"
+ Tensor.manual_seed(42)
+ t = Tensor([1, 2, 3, 4])
+ print(t.multinomial(20, replacement=True).numpy())
+ ```
+ """
  assert 1 <= self.ndim <= 2 and num_samples > 0, f"{self.ndim=} must be 1 or 2 dim, {num_samples=} must be positive"
  assert replacement or num_samples == 1, "no replacement only supports num_samples = 1"
  weight = self.unsqueeze(0) if self.ndim == 1 else self
@@ -888,9 +915,9 @@ class Tensor(SimpleMathTrait):

  # ***** toposort and backward pass *****

- def gradient(self, *targets:Tensor, gradient:Optional[Tensor]=None, materialize_grads=False) -> list[Tensor]:
+ def gradient(self, *targets:Tensor, gradient:Tensor|None=None, materialize_grads=False) -> list[Tensor]:
  """
- Compute the gradient of the targets with respect to self.
+ Computes the gradient of the targets with respect to self.

  ```python exec="true" source="above" session="tensor" result="python"
  x = Tensor.eye(3)
@@ -903,21 +930,20 @@ class Tensor(SimpleMathTrait):
  ```
  """
  assert gradient is not None or self.shape == tuple(), "when no gradient is provided, backward must be called on a scalar tensor"
+ if not (self.is_floating_point() and all(t.is_floating_point() for t in targets)): raise RuntimeError("only float Tensors have gradient")
  if gradient is None: gradient = Tensor(1.0, dtype=self.dtype, device=self.device, requires_grad=False)
- rets = []
- target_uops = [x.lazydata for x in targets]
- grads = compute_gradient(self.lazydata, gradient.lazydata, set(target_uops))
+ target_uops = [x.uop for x in targets]
+ grads = compute_gradient(self.uop, gradient.uop, set(target_uops))
  ret = []
  for x in target_uops:
  if (y:=grads.get(x)) is None:
  if materialize_grads: y = x.const_like(0)
- else: raise RuntimeError(f"{x}\n\nnot found in\n\n{self.lazydata}")
+ else: raise RuntimeError(f"{x}\n\nnot found in\n\n{self.uop}")
  ret.append(y)
- rets.append(ret)
  # create returned Tensors
- return [Tensor(u, device=t.device) for t,u in zip(targets, rets[0])]
+ return [Tensor(u, device=t.device) for t,u in zip(targets, ret)]

- def backward(self, gradient:Optional[Tensor]=None) -> Tensor:
+ def backward(self, gradient:Tensor|None=None) -> Tensor:
  """
  Propagates the gradient of a tensor backwards through the computation graph.
  If the 'gradient' argument is not provided, the tensor must be a scalar, and the gradient is implicitly set to 1.0.
@@ -927,9 +953,9 @@ class Tensor(SimpleMathTrait):
  print(t.grad.numpy())
  ```
  """
- all_uops = self.lazydata.toposort
+ all_uops = self.uop.toposort()
  tensors_need_grad: list[Tensor] = [t for tref in all_tensors if (t:=tref()) is not None and \
- t.lazydata in all_uops and t.requires_grad and not Tensor.no_grad]
+ t.uop in all_uops and t.requires_grad]
  # clear contexts
  for t,g in zip(tensors_need_grad, self.gradient(*tensors_need_grad, gradient=gradient, materialize_grads=True)):
  assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"
@@ -938,9 +964,9 @@ class Tensor(SimpleMathTrait):

  # ***** movement low level ops *****

- def view(self, *shape) -> Tensor:
+ def view(self, shape:tuple[sint, ...], *args) -> Tensor:
  """`.view` is an alias for `.reshape`."""
- return self.reshape(shape)
+ return self.reshape(shape, *args)

  def reshape(self, shape, *args) -> Tensor:
  """
@@ -981,11 +1007,11 @@ class Tensor(SimpleMathTrait):
  `order` can be passed as a tuple or as separate arguments.

  ```python exec="true" source="above" session="tensor" result="python"
- t = Tensor.arange(6).reshape(2, 3)
- print(t.numpy())
+ t = Tensor.empty(2, 3, 5)
+ print(t.shape)
  ```
  ```python exec="true" source="above" session="tensor" result="python"
- print(t.permute(1, 0).numpy())
+ print(t.permute(2, 0, 1).shape)
  ```
  """
  order_arg = tuple(self._resolve_dim(x) for x in argfix(order, *args))
@@ -1012,7 +1038,7 @@ class Tensor(SimpleMathTrait):
  if len(axis_arg) != len(dedup(axis_arg)): raise RuntimeError(f"dim can appear at most once, getting {axis_arg}")
  return self._apply_uop(UOp.flip, arg=tuple([i in axis_arg for i in range(len(self.shape))]))

- def shrink(self, arg:tuple[Optional[tuple[sint, sint]], ...]) -> Tensor:
+ def shrink(self, arg:tuple[tuple[sint, sint]|None, ...]) -> Tensor:
  """
  Returns a tensor that shrinks the each axis based on input arg.
  `arg` must have the same length as `self.ndim`.
@@ -1032,7 +1058,7 @@ class Tensor(SimpleMathTrait):
  if (shrink_arg:=[x if x is not None else (0,s) for x,s in zip(arg, self.shape)]) == [(0,s) for s in self.shape]: return self
  return self._apply_uop(UOp.shrink, arg=tuple(shrink_arg))

- def pad(self, padding:Union[Sequence[sint], Sequence[Optional[tuple[sint, sint]]]], mode:str="constant", value:float=0.0) -> Tensor:
+ def pad(self, padding:Sequence[sint]|Sequence[tuple[sint, sint]|None], mode:str="constant", value:float=0.0) -> Tensor:
  """
  Returns a tensor with padding applied based on the input `padding`.

@@ -1070,11 +1096,11 @@ class Tensor(SimpleMathTrait):
  if len(padding)%2 != 0: raise ValueError("Flat padding must have even number of pads")
  pX = _flat_to_grouped(tuple(cast(Sequence[sint], padding)) + (0,0)*(self.ndim - len(padding)//2))
  # group padding
- else: pX = tuple((0,0) if p is None else p for p in cast(Sequence[Optional[tuple[sint, sint]]], padding))
+ else: pX = tuple((0,0) if p is None else p for p in cast(Sequence[tuple[sint, sint]|None], padding))
  if len(pX) != self.ndim: raise ValueError(f"padding length is improper, {padding=} {self.ndim=}")
  X, pads = self, tuple((smax(pB,0), smax(pA,0)) for pB,pA in pX)
  if mode == "constant":
- def _constant(x:Tensor,px,v):
+ def _constant(x:Tensor,px,v) -> Tensor:
  return x._apply_uop(UOp.pad, arg=px) if v == 0 else (x._apply_uop(UOp.pad, arg=px)+Tensor.ones_like(x)._apply_uop(UOp.pad, arg=px).where(0,v))
  return _constant(X, pX, value) if all(resolve(p >= 0) for p in flatten(pX)) else \
  _constant(X.shrink(tuple((-smin(pB,0),smin(pA+s,s)) for (pB,pA),s in zip(pX, X.shape))), pads, value)
@@ -1097,16 +1123,17 @@ class Tensor(SimpleMathTrait):

  # ***** movement high level ops *****

- def _getitem(self, indices, v: Optional[Tensor] = None) -> Tensor:
+ def _getitem(self, indices, v: Tensor|None = None) -> Tensor:
  # wrap single index into a list
  if (isinstance(indices, list) and all_int(indices)) or not isinstance(indices, (tuple, list)): indices = [indices]
  x, indices = self, list(indices)

- # filter ellipsis and fill with slice(None) or fill rest of indices with slice(None)
+ # fill ellipsis or rest of indices with slice(None)
  if len(ellipsis_idx := [dim for dim, i in enumerate(indices) if i is Ellipsis]) > 1: raise IndexError("indices can only have a single ellipsis")
- fill_idx = ellipsis_idx[0] if ellipsis_idx else len(indices)
+ # NOTE: None adds a dim later
  num_indices = len(indices) - len(ellipsis_idx) - sum(1 for i in indices if i is None)
  if num_indices > self.ndim: raise IndexError(f"too many {num_indices=} for {self.ndim=}")
+ fill_idx = ellipsis_idx[0] if ellipsis_idx else len(indices)
  indices[fill_idx:fill_idx+1] = [slice(None)] * (self.ndim - num_indices)

  indices_parsed, dim = [], 0
@@ -1114,22 +1141,32 @@ class Tensor(SimpleMathTrait):
1114
1141
  size = 1 if index is None else self.shape[dim]
1115
1142
  boundary, stride = [0, size], 1 # defaults
1116
1143
  match index:
1117
- case list() | tuple() | Tensor():
1118
- if not isinstance(index, Tensor): index = Tensor(index, self.device, requires_grad=False)
1144
+ case Tensor():
1119
1145
  if not dtypes.is_int(index.dtype): raise IndexError(f"index dtype {index.dtype} is not supported")
1120
- index = (index.to(self.device) < 0).where(index+size, index) # treat negative index values
1146
+ index = (index < 0).where(index+size, index).to(self.device) # treat negative index values
1147
+ case list() | tuple():
1148
+ if not dtypes.is_int((ti:=Tensor(index)).dtype): raise IndexError(f"{index=} contains non-int element")
1149
+ index = Tensor([i+size if i<0 else i for i in fully_flatten(index)], self.device, requires_grad=False).reshape(ti.shape)
1121
1150
  case int() | UOp(): # sint
1122
1151
  if index >= size or index < -size: raise IndexError(f"{index=} is out of bounds with {size=}")
1152
+ # TODO: is this right for (negative) symbolic?
1123
1153
  boundary = [index, index+1] if index >= 0 else [index+size, index+size+1]
1124
1154
  case slice():
1125
1155
  if index.step == 0: raise ValueError(f"{index=} cannot have 0 as step")
1126
- if not all(isinstance(s,int) or s is None for s in (index.start,index.stop,index.step)): raise TypeError("only int slicing is supported")
1127
- # handle int slicing
1128
- *boundary, stride = index.indices(cast(SupportsIndex, size))
1129
- if stride * (boundary[1] - boundary[0]) < 0: boundary = [0, 0]
1130
- elif stride < 0: boundary = [boundary[1] + 1, boundary[0] + 1]
1131
- # update size for slice
1132
- size = ceildiv((boundary[1] - boundary[0]), abs(stride))
1156
+ start, stop = 0 if index.start is None else index.start, size if index.stop is None else index.stop
1157
+ step = 1 if index.step is None else index.step
1158
+ boundary, stride = [start, stop], step
1159
+ if all(isinstance(s, int) for s in (start,stop,step)):
1160
+ # handle int slicing
1161
+ *boundary, stride = index.indices(cast(SupportsIndex, size))
1162
+ if stride * (boundary[1] - boundary[0]) < 0: boundary = [0, 0]
1163
+ elif stride < 0: boundary = [boundary[1] + 1, boundary[0] + 1]
1164
+ # update size for slice
1165
+ size = ceildiv((boundary[1] - boundary[0]), abs(stride))
1166
+ elif resolve(step == 1, False) and all(isinstance(s,sint) for s in (start, stop)) and resolve((stop-start) > 0, False):
1167
+ # simple symbolic slice
1168
+ size = cast(sint, cast(UOp, (stop - start)).ssimplify())
1169
+ else: raise TypeError(f"slice {index=} is not supported")
1133
1170
  case None: pass # do nothing
1134
1171
  case _: raise IndexError(f"{type(index).__name__} indexing is not supported")
1135
1172
  indices_parsed.append({"index":index, "size":size, "boundary":tuple(boundary), "stride":stride})
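The integer branch of the slice handling above leans on Python's own `slice.indices`. A self-contained sketch of just that normalization, mirroring the boundary/stride logic in the diff (the function name `parse_int_slice` is made up for illustration):

```python
from math import ceil

def parse_int_slice(index: slice, size: int):
    # normalize start/stop/step against the axis size
    *boundary, stride = index.indices(size)
    if stride * (boundary[1] - boundary[0]) < 0: boundary = [0, 0]  # empty selection
    elif stride < 0: boundary = [boundary[1] + 1, boundary[0] + 1]  # flip so a later flip() can undo the sign
    new_size = ceil((boundary[1] - boundary[0]) / abs(stride))
    return boundary, stride, new_size

print(parse_int_slice(slice(None, None, 2), 5))  # ([0, 5], 2, 3)
print(parse_int_slice(slice(4, 0, -1), 5))       # ([1, 5], -1, 4)
```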
@@ -1140,9 +1177,9 @@ class Tensor(SimpleMathTrait):
1140
1177
  # flip negative strides
1141
1178
  shrinks, strides = zip(*((i['boundary'], i['stride']) for i in mops))
1142
1179
  x = x.shrink(shrinks).flip(tuple(i for i,st in enumerate(strides) if st < 0))
1143
- # handle stride != 1 or -1
1144
- if any(abs(st) != 1 for st in strides):
1145
- strides = tuple(abs(s) for s in strides)
1180
+ strides = tuple(map(abs, strides))
1181
+ # apply stride
1182
+ if any(st != 1 for st in strides):
1146
1183
  # pad shape to multiple of stride
1147
1184
  if not all_int(x.shape): raise RuntimeError("symbolic shape not supported")
1148
1185
  x = x.pad(tuple((0, round_up(s, st) - s) for s, st in zip(x.shape, strides)))
@@ -1150,7 +1187,7 @@ class Tensor(SimpleMathTrait):
1150
1187
  x = x.shrink(tuple(flatten(((0, s), (0, 1)) for s in x.shape[::2]))).reshape(x.shape[::2])
1151
1188
 
1152
1189
  # dim injection from None by including None dim size (which is 1) and dim collapse by skipping int dim size
1153
- x = x.reshape(tuple(index['size'] for index in indices_parsed if not isinstance(index['index'], int)))
1190
+ x = x.reshape(tuple(index['size'] for index in indices_parsed if not isinstance(index['index'], (int, UOp))))
1154
1191
 
1155
1192
  # tensor indexing
1156
1193
  if tops := [(d,i) for d,i in enumerate(i_ for i_ in indices_parsed if not isinstance(i_['index'], int)) if isinstance(i['index'], Tensor)]:
@@ -1170,7 +1207,7 @@ class Tensor(SimpleMathTrait):
1170
1207
  # inject 1's for the extra dims added in create masks
1171
1208
  reshape_arg = x.shape[:dims[0]] + (1,) * len(big_shape) + x.shape[dims[0]:]
1172
1209
  # sum reduce the extra dims introduced in create masks
1173
- x = (x.reshape(reshape_arg) * mask).sum(sum_axis:=tuple(d + len(big_shape) for d in dims), acc_dtype=x.dtype)
1210
+ x = (x.reshape(reshape_arg) * mask).sum(sum_axis:=tuple(d + len(big_shape) for d in dims), dtype=x.dtype)
1174
1211
 
1175
1212
  # special permute case
1176
1213
  if dims[0] != 0 and len(dims) != 1 and tuple(dims) != tuple(range(dims[0], dims[-1]+1)):
@@ -1188,7 +1225,7 @@ class Tensor(SimpleMathTrait):
1188
1225
 
1189
1226
  def __getitem__(self, indices) -> Tensor:
1190
1227
  """
1191
- Retrieve a sub-tensor using indexing.
1228
+ Retrieves a sub-tensor using indexing.
1192
1229
 
1193
1230
  Supported Index Types: `int | slice | Tensor | None | list | tuple | Ellipsis`
1194
1231
 
@@ -1226,19 +1263,19 @@ class Tensor(SimpleMathTrait):
1226
1263
  """
1227
1264
  return self._getitem(indices)
1228
1265
 
1229
- def __setitem__(self, indices, v:Union[Tensor, ConstType]) -> None:
1266
+ def __setitem__(self, indices, v:Tensor|ConstType) -> None:
1230
1267
  if isinstance(self.device, str) and self.device.startswith("DISK"):
1231
- self._getitem(indices).assign(v)
1268
+ self.realize()._getitem(indices).assign(v)
1232
1269
  return
1233
1270
  # NOTE: check that setitem target is valid first
1234
- if not unwrap(self.lazydata.st).contiguous: raise RuntimeError("setitem target needs to be contiguous")
1271
+ if not unwrap(self.uop.st).contiguous: raise RuntimeError("setitem target needs to be contiguous")
1235
1272
  if isinstance(v, get_args(ConstType)): v = Tensor(v, device=self.device, dtype=self.dtype)
1236
1273
  if not isinstance(v, Tensor): raise TypeError(f"can't set a {type(v).__name__} to a Tensor")
1237
1274
  if self.requires_grad or v.requires_grad: raise NotImplementedError("setitem with requires_grad is not supported")
1238
1275
 
1239
1276
  res = self.realize()._getitem(indices, v)
1240
1277
  # if shapes match and data is not shared it's a copy and we assign to self
1241
- if res.shape == self.shape and res.lazydata is not self.lazydata:
1278
+ if res.shape == self.shape and res.uop is not self.uop:
1242
1279
  self.assign(res).realize()
1243
1280
  else: # no copy, basic setitem
1244
1281
  v = v.cast(res.dtype)._broadcast_to(_broadcast_shape(res.shape, v.shape)).contiguous()
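As a rough analogue of the broadcast-then-assign path in `__setitem__`, plain numpy shows the same user-visible behaviour; this only illustrates the semantics, not the tinygrad code path:

```python
import numpy as np

t = np.zeros((3, 4), dtype=np.float32)
t[1:3, :2] = 7.0                       # a scalar is broadcast to the (2, 2) selection
t[0] = np.arange(4, dtype=np.float32)  # a 1-D value is broadcast across the selected row
print(t)
```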
@@ -1261,7 +1298,7 @@ class Tensor(SimpleMathTrait):
1261
1298
  assert all(s >= i for d,(s,i) in enumerate(zip(self.shape, index.shape)) if d != dim), "requires self.shape[d] >= index.shape[d] for all d != dim"
1262
1299
  index = index.to(self.device)
1263
1300
  x = self.shrink(tuple((0, i) if d != dim else None for d,i in enumerate(index.shape))).unsqueeze(-1).transpose(-1, dim)
1264
- return (x * index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim])).sum(-1, acc_dtype=self.dtype)
1301
+ return (x * index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim])).sum(-1, dtype=self.dtype)
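The `gather` change above keeps the one-hot formulation and only renames the accumulation argument. A small numpy sketch of that formulation in 1-D (not tinygrad code):

```python
import numpy as np

x = np.array([10., 20., 30., 40.])
index = np.array([2, 0, 3])
one_hot = (index[:, None] == np.arange(x.shape[0])[None, :]).astype(x.dtype)  # shape (3, 4)
print((one_hot * x[None, :]).sum(axis=-1))  # [30. 10. 40.] == x[index]
```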
1265
1302
 
1266
1303
  def cat(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
1267
1304
  """
@@ -1296,11 +1333,11 @@ class Tensor(SimpleMathTrait):
1296
1333
  ```
1297
1334
  """
1298
1335
  # checks for shapes and number of dimensions delegated to cat
1299
- return Tensor.cat(*[t.unsqueeze(dim) for t in [self, *args]], dim=dim)
1336
+ return Tensor.cat(*[t.unsqueeze(dim) for t in argfix(self, *args)], dim=dim)
1300
1337
 
1301
- def repeat_interleave(self, repeats:int, dim:Optional[int]=None) -> Tensor:
1338
+ def repeat_interleave(self, repeats:int, dim:int|None=None) -> Tensor:
1302
1339
  """
1303
- Repeat elements of a tensor.
1340
+ Repeats elements of a tensor.
1304
1341
 
1305
1342
  ```python exec="true" source="above" session="tensor" result="python"
1306
1343
  t = Tensor([1, 2, 3])
@@ -1336,7 +1373,7 @@ class Tensor(SimpleMathTrait):
1336
1373
  if not -max(1, total) <= dim <= max(1, total)-1: raise IndexError(f"{dim=} out of range {[-max(1, total), max(1, total)-1]}")
1337
1374
  return dim + total if dim < 0 else dim
1338
1375
 
1339
- def split(self, sizes:Union[int, list[int]], dim:int=0) -> tuple[Tensor, ...]:
1376
+ def split(self, sizes:int|Sequence[int], dim:int=0) -> tuple[Tensor, ...]:
1340
1377
  """
1341
1378
  Splits the tensor into chunks along the dimension specified by `dim`.
1342
1379
  If `sizes` is an integer, it splits into equally sized chunks if possible, otherwise the last chunk will be smaller.
@@ -1385,7 +1422,31 @@ class Tensor(SimpleMathTrait):
1385
1422
  dim = self._resolve_dim(dim)
1386
1423
  return list(self.split(ceildiv(self.shape[dim], chunks) if self.shape[dim] else [0]*chunks, dim=dim))
1387
1424
 
1388
- def meshgrid(self:Tensor, *args:Tensor, indexing:Union[Literal["ij"], Literal["xy"]]="ij") -> tuple[Tensor, ...]:
1425
+ def unfold(self, dim:int, size:sint, step:int) -> Tensor:
1426
+ """
1427
+ Unfolds the tensor along dimension `dim` into overlapping windows.
1428
+
1429
+ Each window has length `size` and begins every `step` elements of `self`.
1430
+ Returns the input tensor with dimension `dim` replaced by dims `(n_windows, size)`
1431
+ where `n_windows = (self.shape[dim] - size) // step + 1`.
1432
+
1433
+ ```python exec="true" source="above" session="tensor" result="python"
1434
+ unfolded = Tensor.arange(8).unfold(0,2,2)
1435
+ print("\\n".join([repr(x.numpy()) for x in unfolded]))
1436
+ ```
1437
+ ```python exec="true" source="above" session="tensor" result="python"
1438
+ unfolded = Tensor.arange(27).reshape(3,3,3).unfold(-1,2,3)
1439
+ print("\\n".join([repr(x.numpy()) for x in unfolded]))
1440
+ ```
1441
+ """
1442
+ if size < 0: raise RuntimeError(f'size must be >= 0 but got {size=}')
1443
+ if step <= 0: raise RuntimeError(f'step must be > 0 but got {step=}')
1444
+ if size > self.shape[dim]: raise RuntimeError(f'maximum size for tensor at dimension {dim} is {self.shape[dim]} but size is {size}')
1445
+ dim = self._resolve_dim(dim)
1446
+ perm_to_last = tuple(i for i in range(self.ndim) if i != dim) + (dim,)
1447
+ return self.permute(perm_to_last)._pool((size,), step).permute(argsort(perm_to_last) + (self.ndim,))
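A pure-Python sketch of the windowing that `unfold` performs along a single axis, assuming `size <= len(xs)` and `step > 0`; the helper `windows` is illustrative only:

```python
# n_windows = (len(xs) - size) // step + 1, each window starting every `step` elements
def windows(xs, size, step):
    return [xs[i:i+size] for i in range(0, len(xs) - size + 1, step)]

print(windows(list(range(8)), size=2, step=2))  # [[0, 1], [2, 3], [4, 5], [6, 7]]
print(windows(list(range(8)), size=3, step=2))  # [[0, 1, 2], [2, 3, 4], [4, 5, 6]]
```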
1448
+
1449
+ def meshgrid(self:Tensor, *args:Tensor, indexing:Literal["ij", "xy"]="ij") -> tuple[Tensor, ...]:
1389
1450
  """
1390
1451
  Generates coordinate matrices from coordinate vectors.
1391
1452
  Input tensors can be scalars or 1D tensors.
@@ -1412,7 +1473,7 @@ class Tensor(SimpleMathTrait):
1412
1473
  output_shape = _broadcast_shape(*(t.shape for t in tensors))
1413
1474
  return tuple(t._broadcast_to(output_shape) for t in tensors)
1414
1475
 
1415
- def squeeze(self, dim:Optional[int]=None) -> Tensor:
1476
+ def squeeze(self, dim:int|None=None) -> Tensor:
1416
1477
  """
1417
1478
  Returns a tensor with specified dimensions of input of size 1 removed.
1418
1479
  If `dim` is not specified, all dimensions with size 1 are removed.
@@ -1469,7 +1530,7 @@ class Tensor(SimpleMathTrait):
1469
1530
  order[dim0], order[dim1] = order[dim1], order[dim0]
1470
1531
  return self.permute(order)
1471
1532
 
1472
- def flatten(self, start_dim=0, end_dim=-1):
1533
+ def flatten(self, start_dim=0, end_dim=-1) -> Tensor:
1473
1534
  """
1474
1535
  Flattens the tensor by reshaping it into a one-dimensional tensor.
1475
1536
  If `start_dim` or `end_dim` are passed, only dimensions starting with `start_dim` and ending with `end_dim` are flattened.
@@ -1485,7 +1546,7 @@ class Tensor(SimpleMathTrait):
1485
1546
  start_dim, end_dim = self._resolve_dim(start_dim), self._resolve_dim(end_dim)
1486
1547
  return self.reshape(self.shape[:start_dim] + (prod(self.shape[start_dim:end_dim+1]), ) + self.shape[end_dim+1:])
1487
1548
 
1488
- def unflatten(self, dim:int, sizes:tuple[int,...]):
1549
+ def unflatten(self, dim:int, sizes:tuple[int,...]) -> Tensor:
1489
1550
  """
1490
1551
  Unflattens dimension `dim` of the tensor into multiple dimensions specified by `sizes`. `Tensor.flatten()` is the inverse of this function.
1491
1552
 
@@ -1502,7 +1563,33 @@ class Tensor(SimpleMathTrait):
1502
1563
  dim = self._resolve_dim(dim)
1503
1564
  return self.reshape(self.shape[:dim] + sizes + self.shape[dim+1:])
1504
1565
 
1505
- def roll(self, shifts:Union[int, tuple[int, ...]], dims:Union[int, tuple[int, ...]]) -> Tensor:
1566
+ def diag(self) -> Tensor:
1567
+ """
1568
+ Returns a 2-D square tensor with the elements of the input as the main diagonal.
1569
+
1570
+ ```python exec="true" source="above" session="tensor" result="python"
1571
+ print(Tensor([1, 2, 3]).diag().numpy())
1572
+ ```
1573
+ """
1574
+ if self.ndim != 1: raise ValueError(f"expect input to be 1-D, getting {self.ndim}-D")
1575
+ return self.unsqueeze(-1).pad((None,(0,n:=self.shape[0]))).flatten().shrink(((0,n*n),)).reshape(n,n)
1576
+
1577
+ def diagonal(self) -> Tensor:
1578
+ """
1579
+ Returns a view of the input tensor containing its main diagonal elements.
1580
+
1581
+ ```python exec="true" source="above" session="tensor" result="python"
1582
+ t = Tensor.arange(9).reshape(3, 3)
1583
+ print(t.numpy())
1584
+ ```
1585
+ ```python exec="true" source="above" session="tensor" result="python"
1586
+ print(t.diagonal().numpy())
1587
+ ```
1588
+ """
1589
+ if self.ndim != 2 or (n:=self.shape[0]) != self.shape[1]: raise ValueError(f"only 2-D square tensor is supported, getting {self.shape=}")
1590
+ return self.flatten().pad(((0, n))).reshape(n, n+1)[:, 0]
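The pad/reshape trick behind `diag` and `diagonal` can be checked in plain numpy: appending `n` elements to the flattened `n x n` matrix and viewing it as `(n, n+1)` shifts each row by one extra slot, so the main diagonal lines up in column 0. A standalone sketch (not tinygrad code):

```python
import numpy as np

n = 3
t = np.arange(n * n).reshape(n, n)
padded = np.concatenate([t.flatten(), np.zeros(n, dtype=t.dtype)])
print(padded.reshape(n, n + 1)[:, 0])  # [0 4 8], same as np.diag(t)
```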
1591
+
1592
+ def roll(self, shifts:int|tuple[int, ...], dims:int|tuple[int, ...]|None=None) -> Tensor:
1506
1593
  """
1507
1594
  Rolls the tensor along specified dimension(s).
1508
1595
  The rolling operation is circular, meaning that elements that go beyond the edge are wrapped around to the beginning of the dimension.
@@ -1515,12 +1602,11 @@ class Tensor(SimpleMathTrait):
1515
1602
  print(t.roll(shifts=-1, dims=0).numpy())
1516
1603
  ```
1517
1604
  """
1518
- dims, rolled = tuple(self._resolve_dim(d) for d in make_tuple(dims, 1)), self
1519
- for dim, shift in zip(dims, make_tuple(shifts, 1)):
1520
- shift = shift % self.shape[dim]
1521
- rolled = Tensor.cat(rolled[tuple(slice(None) if i != dim else slice(-shift, None) for i in range(rolled.ndim))],
1522
- rolled[tuple(slice(None) if i != dim else slice(None, -shift) for i in range(rolled.ndim))], dim=dim)
1523
- return rolled
1605
+ if dims is None: return self.flatten().roll(shifts, 0).reshape(self.shape)
1606
+ dims, shifts, slices = tuple(self._resolve_dim(d) for d in make_tuple(dims, 1)), make_tuple(shifts, 1), [slice(None)] * self.ndim
1607
+ if len(dims) != len(shifts): raise RuntimeError(f"{len(dims)=} != {len(shifts)=}")
1608
+ for dim, shift in zip(dims, shifts): slices[dim] = slice(delta:=self.shape[dim]-shift%self.shape[dim], delta+self.shape[dim])
1609
+ return self.repeat(*tuple(2 if i in dims else 1 for i in range(self.ndim)))[slices]
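The new `roll` builds a doubled copy along each rolled axis and slices a window of the original length out of it. A pure-Python 1-D sketch of that strategy (the helper `roll1d` is illustrative):

```python
def roll1d(xs, shift):
    n = len(xs)
    doubled = xs + xs     # the repeat(2, ...) step
    start = n - shift % n # window offset, handles negative shifts too
    return doubled[start:start + n]

print(roll1d([1, 2, 3, 4, 5], 2))   # [4, 5, 1, 2, 3]
print(roll1d([1, 2, 3, 4, 5], -1))  # [2, 3, 4, 5, 1]
```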
1524
1610
 
1525
1611
  def rearrange(self, formula:str, **sizes) -> Tensor:
1526
1612
  """
@@ -1562,22 +1648,61 @@ class Tensor(SimpleMathTrait):
1562
1648
  t = t.permute([lhs.index(name) for name in rhs])
1563
1649
  return functools.reduce(lambda x, dims: x.flatten(dims[0], dims[1] - 1) if dims[0]<dims[1] else x.unsqueeze(dims[0]), reversed(flatten_dims), t)
1564
1650
 
1651
+ def masked_select(self, mask):
1652
+ """
1653
+ Selects elements from `self` based on the boolean `mask`.
1654
+
1655
+ ```python exec="true" source="above" session="tensor" result="python"
1656
+ t = Tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
1657
+ mask = Tensor([[True, False, True], [False, True, False], [False, False, True]])
1658
+ print(t.numpy())
1659
+ print(mask.numpy())
1660
+ ```
1661
+ ```python exec="true" source="above" session="tensor" result="python"
1662
+ print(t.masked_select(mask).numpy())
1663
+ ```
1664
+ """
1665
+ if not dtypes.is_bool(mask.dtype): raise RuntimeError(f"masked_select expects bool mask tensor, got {mask.dtype}")
1666
+ x, mask = self.flatten(), mask._broadcast_to(self.shape).flatten()
1667
+ mask_cumsum = mask.cumsum()
1668
+ counts = Tensor.zeros(mask_cumsum[-1].item(), dtype=dtypes.int32)
1669
+ idxs = counts.scatter(0, mask_cumsum, 1, reduce='add').cumsum()
1670
+ return x[idxs]
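Conceptually, `masked_select` flattens both tensors and compacts the elements at `True` positions, with a cumulative sum of the mask providing the output positions. A rough numpy sketch of that idea, not the tinygrad implementation:

```python
import numpy as np

t = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
mask = np.array([[True, False, True], [False, True, False], [False, False, True]])
flat_t = t.reshape(-1)
flat_m = np.broadcast_to(mask, t.shape).reshape(-1)
idxs = np.cumsum(flat_m) - 1  # running count of True values gives each element's output slot
print(flat_t[flat_m])         # [0 2 4 8], the selected elements
print(idxs[flat_m])           # [0 1 2 3], their positions in the output
```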
1671
+
1672
+ def masked_fill(self:Tensor, mask:Tensor, value:Tensor|ConstType) -> Tensor:
1673
+ """
1674
+ Replaces `self` with `value` wherever the elements of `mask` are True.
1675
+
1676
+ ```python exec="true" source="above" session="tensor" result="python"
1677
+ t = Tensor([1, 2, 3, 4, 5])
1678
+ mask = Tensor([True, False, True, False, False])
1679
+ print(t.masked_fill(mask, -12).numpy())
1680
+ ```
1681
+ ```python exec="true" source="above" session="tensor" result="python"
1682
+ t = Tensor([1, 2, 3, 4, 5])
1683
+ mask = Tensor([True, False, True, False, False])
1684
+ value = Tensor([-1, -2, -3, -4, -5])
1685
+ print(t.masked_fill(mask, value).numpy())
1686
+ ```
1687
+ """
1688
+ return mask.where(value, self)
1689
+
1565
1690
  # ***** reduce ops *****
1566
1691
 
1567
- def _reduce(self, op:Ops, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False) -> Tensor:
1692
+ def _reduce(self, op:Ops, axis:int|Sequence[int]|None=None, keepdim=False) -> Tensor:
1568
1693
  axis = tuple(self._resolve_dim(x) for x in (range(self.ndim) if axis is None else make_tuple(axis, 1)))
1569
1694
  if self.ndim == 0: axis = ()
1570
1695
  ret = self._apply_uop(UOp.r, op=op, axis=axis)
1571
1696
  return ret if keepdim else ret.reshape(tuple(s for i,s in enumerate(self.shape) if i not in axis))
1572
1697
 
1573
- def sum(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, acc_dtype:Optional[DTypeLike]=None):
1698
+ def sum(self, axis:int|Sequence[int]|None=None, keepdim=False, dtype:DTypeLike|None=None) -> Tensor:
1574
1699
  """
1575
1700
  Returns the sum of the elements of the tensor along the specified axis or axes.
1576
1701
 
1577
1702
  You can pass in `axis` and `keepdim` keyword arguments to control the axis along
1578
1703
  which the sum is computed and whether the reduced dimensions are retained.
1579
1704
 
1580
- You can pass in `acc_dtype` keyword argument to control the data type of the accumulation.
1705
+ You can pass in `dtype` keyword argument to control the data type of the accumulation.
1581
1706
  If not specified, the accumulation data type is chosen based on the input tensor's data type.
1582
1707
 
1583
1708
  ```python exec="true" source="above" session="tensor" result="python"
@@ -1594,17 +1719,17 @@ class Tensor(SimpleMathTrait):
1594
1719
  print(t.sum(axis=1).numpy())
1595
1720
  ```
1596
1721
  """
1597
- ret = self.cast(sum_acc_dtype(self.dtype) if acc_dtype is None else acc_dtype)._reduce(Ops.ADD, axis, keepdim)
1598
- return ret.cast(self.dtype) if acc_dtype is None and self.dtype in (dtypes.float16, dtypes.bfloat16) else ret
1722
+ ret = self.cast(sum_acc_dtype(self.dtype) if dtype is None else dtype)._reduce(Ops.ADD, axis, keepdim)
1723
+ return ret.cast(self.dtype) if dtype is None and self.dtype in (dtypes.float16, dtypes.bfloat16) else ret
1599
1724
 
1600
- def prod(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, acc_dtype:Optional[DTypeLike]=None):
1725
+ def prod(self, axis:int|Sequence[int]|None=None, keepdim=False, dtype:DTypeLike|None=None) -> Tensor:
1601
1726
  """
1602
1727
  Returns the product of the elements of the tensor along the specified axis or axes.
1603
1728
 
1604
1729
  You can pass in `axis` and `keepdim` keyword arguments to control the axis along
1605
1730
  which the product is computed and whether the reduced dimensions are retained.
1606
1731
 
1607
- You can pass in `acc_dtype` keyword argument to control the data type of the accumulation.
1732
+ You can pass in `dtype` keyword argument to control the data type of the accumulation.
1608
1733
  If not specified, the accumulation data type is chosen based on the input tensor's data type.
1609
1734
 
1610
1735
  ```python exec="true" source="above" session="tensor" result="python"
@@ -1621,9 +1746,9 @@ class Tensor(SimpleMathTrait):
1621
1746
  print(t.prod(axis=1).numpy())
1622
1747
  ```
1623
1748
  """
1624
- return self.cast(acc_dtype if acc_dtype is not None else self.dtype)._reduce(Ops.MUL, axis, keepdim)
1749
+ return self.cast(dtype if dtype is not None else self.dtype)._reduce(Ops.MUL, axis, keepdim)
1625
1750
 
1626
- def max(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
1751
+ def max(self, axis:int|Sequence[int]|None=None, keepdim=False) -> Tensor:
1627
1752
  """
1628
1753
  Returns the maximum value of the tensor along the specified axis or axes.
1629
1754
 
@@ -1646,9 +1771,9 @@ class Tensor(SimpleMathTrait):
1646
1771
  """
1647
1772
  return self._reduce(Ops.MAX, axis, keepdim)
1648
1773
 
1649
- def _inverse(self): return -self if self.is_floating_point() else ~self if dtypes.is_int(self.dtype) else self.logical_not()
1774
+ def _inverse(self) -> Tensor: return -self if self.is_floating_point() else ~self if dtypes.is_int(self.dtype) else self.logical_not()
1650
1775
 
1651
- def min(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
1776
+ def min(self, axis:int|Sequence[int]|None=None, keepdim=False) -> Tensor:
1652
1777
  """
1653
1778
  Returns the minimum value of the tensor along the specified axis or axes.
1654
1779
 
@@ -1671,7 +1796,7 @@ class Tensor(SimpleMathTrait):
1671
1796
  """
1672
1797
  return self._inverse().max(axis=axis, keepdim=keepdim)._inverse()
1673
1798
 
1674
- def any(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
1799
+ def any(self, axis:int|Sequence[int]|None=None, keepdim=False) -> Tensor:
1675
1800
  """
1676
1801
  Tests if any element evaluates to `True` along the specified axis or axes.
1677
1802
 
@@ -1693,7 +1818,7 @@ class Tensor(SimpleMathTrait):
1693
1818
  """
1694
1819
  return self.bool().max(axis, keepdim)
1695
1820
 
1696
- def all(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
1821
+ def all(self, axis:int|Sequence[int]|None=None, keepdim=False) -> Tensor:
1697
1822
  """
1698
1823
  Tests if all elements evaluate to `True` along the specified axis or axes.
1699
1824
 
@@ -1730,14 +1855,12 @@ class Tensor(SimpleMathTrait):
1730
1855
  print(Tensor([float('nan')]).isclose(Tensor([float('nan')]), equal_nan=True).numpy())
1731
1856
  ```
1732
1857
  """
1733
- # TODO: Tensor.isfinite
1734
- def isfinite(t): return (t.isinf()|t.isnan()).logical_not()
1735
- is_finite_close = isfinite(self) & isfinite(other) & ((self - other).abs() <= atol + rtol * other.abs())
1858
+ is_finite_close = self.isfinite() & other.isfinite() & ((self - other).abs() <= atol + rtol * other.abs())
1736
1859
  is_infinite_close = (self.isinf() | other.isinf()) & (self == other)
1737
1860
  is_nan_close = (self.isnan() & other.isnan()) & equal_nan
1738
1861
  return is_finite_close | is_infinite_close | is_nan_close
1739
1862
 
1740
- def mean(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
1863
+ def mean(self, axis:int|Sequence[int]|None=None, keepdim=False) -> Tensor:
1741
1864
  """
1742
1865
  Returns the mean value of the tensor along the specified axis or axes.
1743
1866
 
@@ -1761,9 +1884,10 @@ class Tensor(SimpleMathTrait):
1761
1884
  """
1762
1885
  output_dtype = self.dtype if dtypes.is_float(self.dtype) else dtypes.float32
1763
1886
  numerator = self.cast(sum_acc_dtype(self.dtype)).sum(axis=axis, keepdim=keepdim)
1764
- return numerator.div(prod([si for si, so in zip(self.shape, self.sum(axis=axis, keepdim=True).shape) if resolve(si != so)])).cast(output_dtype)
1887
+ return numerator.div(prod([cast(int, si) for si, so in zip(self.shape, self.sum(axis=axis, keepdim=True).shape) if resolve(si != so)])) \
1888
+ .cast(output_dtype)
1765
1889
 
1766
- def var(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, correction=1):
1890
+ def var(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1) -> Tensor:
1767
1891
  """
1768
1892
  Returns the variance of the tensor along the specified axis or axes.
1769
1893
 
@@ -1789,7 +1913,24 @@ class Tensor(SimpleMathTrait):
1789
1913
  n = prod([si for si, so in zip(self.shape, squares.sum(axis=axis, keepdim=True).shape) if resolve(si != so)])
1790
1914
  return squares.sum(axis=axis, keepdim=keepdim).div(smax([0, n-correction]))
1791
1915
 
1792
- def std(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, correction=1):
1916
+ def var_mean(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1) -> tuple[Tensor, Tensor]:
1917
+ """
1918
+ Calculates the variance and mean over the dimensions specified by dim.
1919
+ Syntactic sugar around `Tensor.var` and `Tensor.mean` to match `torch.var_mean`.
1920
+
1921
+ ```python exec="true" source="above" session="tensor" result="python"
1922
+ Tensor.manual_seed(42)
1923
+ t = Tensor.normal(2, 3, mean=2.5, std=0.5)
1924
+ print(t.numpy())
1925
+ ```
1926
+ ```python exec="true" source="above" session="tensor" result="python"
1927
+ var, mean = t.var_mean()
1928
+ print(var.numpy(), mean.numpy())
1929
+ ```
1930
+ """
1931
+ return self.var(axis, keepdim, correction), self.mean(axis, keepdim)
1932
+
1933
+ def std(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1) -> Tensor:
1793
1934
  """
1794
1935
  Returns the standard deviation of the tensor along the specified axis or axes.
1795
1936
 
@@ -1813,7 +1954,7 @@ class Tensor(SimpleMathTrait):
1813
1954
  """
1814
1955
  return self.var(axis, keepdim, correction).sqrt()
1815
1956
 
1816
- def std_mean(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, correction=1):
1957
+ def std_mean(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1) -> tuple[Tensor, Tensor]:
1817
1958
  """
1818
1959
  Calculates the standard deviation and mean over the dimensions specified by dim.
1819
1960
  Syntactic sugar around `Tensor.std` and `Tensor.mean` to match `torch.std_mean`.
@@ -1830,13 +1971,100 @@ class Tensor(SimpleMathTrait):
1830
1971
  """
1831
1972
  return self.std(axis, keepdim, correction), self.mean(axis, keepdim)
1832
1973
 
1833
- def _softmax(self, axis, dtype:Optional[DTypeLike]=None):
1974
+ def keccak(self, cfg:str|tuple[int, int]="sha3_256"):
1975
+ """
1976
+ Calculates a Keccak hash over the last dimension. Uses "sha3_256" by default.
1977
+
1978
+ ```python exec="false" source="above" session="tensor" result="python"
1979
+ t = Tensor(b"Hello World!").keccak()
1980
+ print(t.data().hex())
1981
+ ```
1982
+ """
1983
+
1984
+ # https://keccak.team/keccak_specs_summary.html
1985
+
1986
+ def ctensor(l: Sequence[ConstType], dtype: DType = dtypes.uint64):
1987
+ # TODO: contiguous is here for compile speed
1988
+ return Tensor.stack(*(Tensor(v, dtype=dtype, device=self.device) for v in l)).contiguous()
1989
+ rot_offsets = [44, 43, 21, 14, 28, 20, 3, 45, 61, 1, 6, 25, 8, 18, 27, 36, 10, 15, 56, 62, 55, 39, 41, 2]
1990
+ rot_offsets_v0, rot_offsets_v1 = ctensor([0] + [1 << v for v in rot_offsets]), ctensor([1] + [1 << (64 - v) for v in rot_offsets])
1991
+
1992
+ # calculated from π step
1993
+ reorder_indexes = ctensor([0,6,12,18,24,3,9,10,16,22,1,7,13,19,20,4,5,11,17,23,2,8,14,15,21], dtype=dtypes.int32)
1994
+ rnd_const_masks = [ctensor([v]).pad((0, 24)) for v in (1, 0x8082, 0x800000000000808a, 0x8000000080008000, 0x808b, 0x80000001, 0x8000000080008081,
1995
+ 0x8000000000008009, 0x8a, 0x88, 0x80008009, 0x8000000a, 0x8000808b, 0x800000000000008b, 0x8000000000008089, 0x8000000000008003,
1996
+ 0x8000000000008002, 0x8000000000000080, 0x800a, 0x800000008000000a, 0x8000000080008081, 0x8000000000008080, 0x80000001, 0x8000000080008008)]
1997
+
1998
+ rate, dsbyte = {"sha3_224": (144, 6), "sha3_256": (136, 6), "shake_128": (168, 31)}[cfg] if isinstance(cfg, str) else cfg
1999
+ data, data_pad = self.bitcast(dtypes.uint8).reshape(prod(self.shape[:-1]), self.shape[-1]), rate - (self.shape[-1] * self.dtype.itemsize % rate)
2000
+ # pad batches then pad blocks
2001
+ data = data.pad((None, (0, data_pad))).reshape(bs := data.shape[0], -1, rate).pad((None, None, (0, 200 - rate)))
2002
+
2003
+ # create pad mask
2004
+ lbe = prod(data.shape[1:]) + rate - data_pad - 200
2005
+ if data_pad == 1: mb = [(lbe, 0), (1, dsbyte ^ 0x80), (200 - rate, 0)]
2006
+ else: mb = [(lbe, 0), (1, dsbyte), (data_pad - 2, 0), (1, 0x80), (200 - rate, 0)]
2007
+ pad_mask = Tensor.cat(*(Tensor(v, dtype=dtypes.uint8, device=data.device).expand(l) for l, v in mb if l > 0)).unsqueeze(0)
2008
+
2009
+ data = (data.flatten(1) ^ pad_mask).reshape(*data.shape[:2], 200).bitcast(dtypes.uint64)
2010
+
2011
+ state = Tensor.zeros(bs, 25, device=self.device, dtype=dtypes.uint64)
2012
+ for k in range(int(data.shape[1])):
2013
+ state = state.bitwise_xor(data[:,k].reshape(bs, 25))
2014
+ for i in range(24): # f1600
2015
+ # θ step
2016
+ p = state.reshape(bs, 5, 5).transpose(2, 1)
2017
+ t1 = (p[:,:,0] ^ p[:,:,1] ^ p[:,:,2] ^ p[:,:,3] ^ p[:,:,4]).roll(-1, 1) # xor reduce
2018
+ state = state ^ (t1.roll(2, 1).bitwise_xor((t1 << 1) ^ (t1 >> 63)).unsqueeze(2).expand(bs, 5, 5).transpose(2, 1).flatten(1))
2019
+ # ρ and π steps
2020
+ state = state[:, reorder_indexes]
2021
+ state = (state * rot_offsets_v0).bitwise_or(state // rot_offsets_v1).reshape(bs, 5, 5)
2022
+ # χ and ι step
2023
+ state = state.bitwise_xor(~state.roll(shifts=-1, dims=2) & state.roll(shifts=-2, dims=2))
2024
+ state = state.flatten(1) ^ rnd_const_masks[i]
2025
+ # NOTE: kernelize here to prevent internal stack from growing proportionally to data size
2026
+ state = state.kernelize()
2027
+ return state.bitcast(dtypes.uint8)[:,:(obytes:=(200 - rate) // 2)].reshape(*self.shape[:-1], obytes)
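The default "sha3_256" configuration above (rate 136, domain byte 6, 32-byte output) uses the standard SHA3-256 parameters, and "shake_128" uses SHAKE128 with a fixed 16-byte output. Assuming the implementation matches those standards, a hashlib cross-check for the docstring input could look like this (illustrative only):

```python
import hashlib

# what keccak() should produce for b"Hello World!" if the default config matches SHA3-256
print(hashlib.sha3_256(b"Hello World!").hexdigest())
# and the 16-byte output expected from cfg="shake_128"
print(hashlib.shake_128(b"Hello World!").hexdigest(16))
```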
2028
+
2029
+ def _hash_1mb(self) -> Tensor:
2030
+ assert self.dtype == dtypes.uint8, "only support uint8 tensors for hashing"
2031
+ assert self.ndim == 2, "only support batched 1d tensors"
2032
+ assert self.shape[1] == 1024 * 1024, "only support messages of 1mb"
2033
+
2034
+ blocks = self.shape[0] * self.shape[1] // 4096
2035
+ data = self.reshape(blocks, 4096)
2036
+ block_hashes = data.keccak("shake_128").reshape(self.shape[0], 4096)
2037
+ return block_hashes.keccak("shake_128").reshape(self.shape[0], 16)
2038
+
2039
+ def hash(self) -> Tensor:
2040
+ """
2041
+ Calculates a 16-byte hash of the tensor.
2042
+ ```python exec="false" source="above" session="tensor" result="python"
2043
+ t = Tensor(b"Hello World!").hash()
2044
+ print(t.data().hex())
2045
+ ```
2046
+ """
2047
+
2048
+ data = self.flatten().bitcast(dtypes.uint8)
2049
+ if (tsize := data.shape[0]) % 2**20 != 0: data = data.pad((0, 2**20 - tsize % 2**20))
2050
+ base_chunks = ceildiv(data.shape[0], 2**20)
2051
+ tree_depth = math.ceil(math.log(base_chunks, 65536)) if base_chunks > 1 else 0
2052
+
2053
+ level_chunks = base_chunks
2054
+ for _ in range(tree_depth + 1):
2055
+ data = data.reshape(level_chunks, 2**20)._hash_1mb().flatten()
2056
+ if (tsize := data.shape[0]) % 2**20 != 0: data = data.pad((0, 2**20 - tsize % 2**20))
2057
+ level_chunks = ceildiv(data.shape[0], 2**20)
2058
+
2059
+ return data[:16]
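A small sketch of the tree arithmetic used by `hash`: the input is padded to whole 1 MiB chunks, each chunk reduces to a 16-byte digest via `_hash_1mb`, and levels repeat until a single chunk remains. The helper below only reproduces the chunk/level bookkeeping, not the hashing itself:

```python
import math

def tree_levels(nbytes: int, chunk=2**20, fanout=65536):
    base_chunks = max(1, math.ceil(nbytes / chunk))
    depth = math.ceil(math.log(base_chunks, fanout)) if base_chunks > 1 else 0
    return base_chunks, depth + 1  # +1 for the base pass over the padded input

print(tree_levels(10))         # (1, 1): tiny input, one pass
print(tree_levels(3 * 2**30))  # (3072, 2): 3 GiB needs one extra reduction level
```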
2060
+
2061
+ def _softmax(self, axis, dtype:DTypeLike|None=None) -> tuple[Tensor, Tensor, Tensor]:
1834
2062
  m = self - self.max(axis=axis, keepdim=True).detach()
1835
2063
  if dtype is not None: m = m.cast(dtype)
1836
2064
  e = m.exp()
1837
2065
  return m, e, e.sum(axis=axis, keepdim=True)
1838
2066
 
1839
- def softmax(self, axis=-1, dtype:Optional[DTypeLike]=None):
2067
+ def softmax(self, axis=-1, dtype:DTypeLike|None=None, _single_kernel=getenv("SINGLE_KERNEL_SOFTMAX")) -> Tensor:
1840
2068
  """
1841
2069
  Applies the softmax function to the tensor along the specified axis.
1842
2070
 
@@ -1856,10 +2084,13 @@ class Tensor(SimpleMathTrait):
1856
2084
  print(t.softmax(axis=0).numpy())
1857
2085
  ```
1858
2086
  """
2087
+ if _single_kernel:
2088
+ _, e, ss = self.contiguous()._softmax(axis, dtype)
2089
+ return e.div(ss).fuse()
1859
2090
  _, e, ss = self._softmax(axis, dtype)
1860
2091
  return e.div(ss)
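For reference, the max-subtraction in `_softmax` is the usual numerical-stability trick; a standalone numpy sketch (not tinygrad code) shows why the shift is needed:

```python
import numpy as np

def stable_softmax(x, axis=-1):
    m = x - x.max(axis=axis, keepdims=True)  # shift so the largest exponent is 0
    e = np.exp(m)
    return e / e.sum(axis=axis, keepdims=True)

x = np.array([[1000.0, 1001.0, 1002.0]])
print(stable_softmax(x))            # [[0.09003057 0.24472847 0.66524096]]
print(np.exp(x) / np.exp(x).sum())  # the unshifted version overflows to nan
```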
1861
2092
 
1862
- def log_softmax(self, axis=-1, dtype:Optional[DTypeLike]=None):
2093
+ def log_softmax(self, axis=-1, dtype:DTypeLike|None=None) -> Tensor:
1863
2094
  """
1864
2095
  Applies the log-softmax function to the tensor along the specified axis.
1865
2096
 
@@ -1882,7 +2113,7 @@ class Tensor(SimpleMathTrait):
1882
2113
  m, _, ss = self._softmax(axis, dtype)
1883
2114
  return m - ss.log()
1884
2115
 
1885
- def logsumexp(self, axis=None, keepdim=False):
2116
+ def logsumexp(self, axis=None, keepdim=False) -> Tensor:
1886
2117
  """
1887
2118
  Computes the log-sum-exp of the tensor along the specified axis or axes.
1888
2119
 
@@ -1909,14 +2140,14 @@ class Tensor(SimpleMathTrait):
1909
2140
  m = self.max(axis=axis, keepdim=True)
1910
2141
  return (self - m).exp().sum(axis=axis, keepdim=keepdim).log() + m.squeeze(axis)
1911
2142
 
1912
- def logcumsumexp(self, axis=0):
2143
+ def logcumsumexp(self, axis=0) -> Tensor:
1913
2144
  """
1914
2145
  Computes the log-cumsum-exp of the tensor along the specified axis or axes.
1915
2146
 
1916
2147
  The log-cumsum-exp function is a numerically stable way to compute the logarithm of the cumulative sum of exponentials.
1917
2148
 
1918
2149
  You can pass in the `axis` keyword argument to control the axis along which
1919
- the log-cum-sum-exp is computed.
2150
+ the log-cumsum-exp is computed.
1920
2151
 
1921
2152
  ```python exec="true" source="above" session="tensor" result="python"
1922
2153
  Tensor.manual_seed(42)
@@ -1934,17 +2165,15 @@ class Tensor(SimpleMathTrait):
1934
2165
  ```
1935
2166
  """
1936
2167
  if self.ndim == 0: return self
1937
- axis = self._resolve_dim(axis)
1938
2168
  x = self.transpose(axis, -1)
1939
2169
  last_dim_size = x.shape[-1]
1940
- x_reshaped = x.reshape(-1, last_dim_size)
1941
- x_cummax = x_reshaped.cummax(-1).unsqueeze(-1)
1942
- x_expand = x_reshaped.unsqueeze(1).expand(*x_reshaped.shape, last_dim_size)
1943
- mask = Tensor.ones(last_dim_size, last_dim_size, requires_grad=False, device=self.device).tril().unsqueeze(0)
1944
- ret = ((x_expand - x_cummax).exp() * mask).sum(-1).log() + x_cummax.squeeze(-1)
1945
- return ret.reshape(*x.shape).transpose(-1, axis)
2170
+ x_unsqueezed = x.unsqueeze(-2).expand((None,)*(self.ndim-1)+(last_dim_size, None))
2171
+ x_cummax = x.cummax(-1)
2172
+ mask = Tensor.ones(last_dim_size, last_dim_size, requires_grad=False, device=self.device).tril()
2173
+ ret = mask.where(x_unsqueezed - x_cummax.unsqueeze(-1), dtypes.min(self.dtype)).exp().sum(-1).log() + x_cummax
2174
+ return ret.transpose(-1, axis)
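A rough numpy reference for the reworked `logcumsumexp`: stabilize with the running (cumulative) maximum, sum the masked exponentials, then add the running maximum back, which is what the `tril` mask and `where` above implement. Illustrative only:

```python
import numpy as np

def logcumsumexp(x):
    run_max = np.maximum.accumulate(x, axis=-1)
    diffs = x[..., None, :] - run_max[..., :, None]                  # entry (i, j) = x[j] - running_max[i]
    mask = np.tril(np.ones((x.shape[-1], x.shape[-1]), dtype=bool))  # keep only j <= i
    return np.log(np.where(mask, np.exp(diffs), 0.0).sum(-1)) + run_max

x = np.array([0.0, 1.0, 2.0])
print(logcumsumexp(x))               # [0.         1.31326169 2.40760596]
print(np.log(np.cumsum(np.exp(x))))  # same values, computed naively
```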
1946
2175
 
1947
- def argmax(self, axis=None, keepdim=False):
2176
+ def argmax(self, axis=None, keepdim=False) -> Tensor:
1948
2177
  """
1949
2178
  Returns the indices of the maximum value of the tensor along the specified axis.
1950
2179
 
@@ -1971,7 +2200,7 @@ class Tensor(SimpleMathTrait):
1971
2200
  idx = m * Tensor.arange(self.shape[axis],0,-1, requires_grad=False, device=self.device).reshape(self.shape[axis], *[1]*(self.ndim-axis-1))
1972
2201
  return (self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)).cast(dtypes.int32)
1973
2202
 
1974
- def argmin(self, axis=None, keepdim=False):
2203
+ def argmin(self, axis=None, keepdim=False) -> Tensor:
1975
2204
  """
1976
2205
  Returns the indices of the minimum value of the tensor along the specified axis.
1977
2206
 
@@ -1995,7 +2224,7 @@ class Tensor(SimpleMathTrait):
1995
2224
  return self._inverse().argmax(axis=axis, keepdim=keepdim)
1996
2225
 
1997
2226
  @staticmethod
1998
- def einsum(formula:str, *operands:Tensor|Sequence[Tensor], acc_dtype:Optional[DTypeLike]=None) -> Tensor:
2227
+ def einsum(formula:str, *operands:Tensor|Sequence[Tensor], dtype:DTypeLike|None=None) -> Tensor:
1999
2228
  """
2000
2229
  Sums the product of the elements of the input tensors according to a formula based on the Einstein summation convention.
2001
2230
 
@@ -2009,7 +2238,7 @@ class Tensor(SimpleMathTrait):
2009
2238
  """
2010
2239
  def parse_formula(formula:str, *operands:Tensor):
2011
2240
  if "..." in (formula := formula.replace(" ", "")):
2012
- ell_chars, ell_longest = "".join(set(string.ascii_letters) - set(formula)), 0
2241
+ ell_chars, ell_longest = "".join(c for c in string.ascii_letters if c not in formula), 0
2013
2242
  for i, inp in enumerate(filter(lambda x: "..." in x, inputs := formula.split("->")[0].split(","))):
2014
2243
  if (ell_count := max(operands[i].ndim, 1) - (len(inp) - len("..."))) > ell_longest: ell_longest = ell_count
2015
2244
  inputs[i] = inp.replace("...", ell_chars[-ell_count:])
@@ -2037,11 +2266,11 @@ class Tensor(SimpleMathTrait):
2037
2266
 
2038
2267
  # sum over all axes that's not in the output, then permute to the output order
2039
2268
  return functools.reduce(lambda a,b:a*b, xs_) \
2040
- .sum(axis=[axis for axis,(letter,_) in enumerate(letter_val) if letter not in output], acc_dtype=acc_dtype).permute(rhs_order)
2269
+ .sum(axis=[axis for axis,(letter,_) in enumerate(letter_val) if letter not in output], dtype=dtype).permute(rhs_order)
2041
2270
 
2042
2271
  # ***** processing ops *****
2043
2272
 
2044
- def _pool(self, k_:tuple[sint, ...], stride:Union[tuple[int, ...], int]=1, dilation:Union[tuple[int, ...], int]=1) -> Tensor:
2273
+ def _pool(self, k_:tuple[sint, ...], stride:int|tuple[int, ...]=1, dilation:int|tuple[int, ...]=1) -> Tensor:
2045
2274
  assert len(self.shape) >= len(k_), f"can't pool {self.shape} with {k_}"
2046
2275
  s_, d_ = make_tuple(stride, len(k_)), make_tuple(dilation, len(k_))
2047
2276
  assert len(k_) == len(s_) == len(d_), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}"
@@ -2066,12 +2295,12 @@ class Tensor(SimpleMathTrait):
2066
2295
  x = x.shrink(tuple(noop + flatten(((0,o), (0,k)) for o,k in zip(o_,k_))))
2067
2296
  return x.permute(*range(len(noop)), *[len(noop)+i*2 for i in range(len(i_))], *[len(noop)+i*2+1 for i in range(len(i_))])
2068
2297
 
2069
- def _resolve_pool_pads(self, padding:Union[int, Sequence[int]], dims:int) -> Sequence[int]:
2298
+ def _resolve_pool_pads(self, padding:int|Sequence[int], dims:int) -> Sequence[int]:
2070
2299
  if not isinstance(padding, int) and not (len(padding) == 2*dims or len(padding) == dims):
2071
2300
  raise ValueError(f"Padding must be an int or a sequence of length {dims} or {2*dims}, but got {padding=} for {self.shape=} with {dims=}.")
2072
2301
  return [padding]*2*dims if isinstance(padding, int) else (padding if len(padding) == 2*dims else [p for p in padding for _ in range(2)][::-1])
2073
2302
 
2074
- def _apply_ceil_mode(self, pads:Sequence[int], k_:tuple[sint, ...], s_:Union[tuple[int, ...], int], d_:Union[tuple[int, ...], int]) -> list[int]:
2303
+ def _apply_ceil_mode(self, pads:Sequence[int], k_:tuple[sint, ...], s_:int|tuple[int, ...], d_:int|tuple[int, ...]) -> list[int]:
2075
2304
  (d_,s_), i_ = (make_tuple(x, len(k_)) for x in (d_,s_)), self.shape[-len(k_):]
2076
2305
  pads, grouped_pads = list(pads), _flat_to_grouped(pads)
2077
2306
  # https://arxiv.org/pdf/1603.07285 section 5.1, relationship 15.
@@ -2085,7 +2314,8 @@ class Tensor(SimpleMathTrait):
2085
2314
  return pads
2086
2315
 
2087
2316
  # NOTE: these work for more than 2D
2088
- def avg_pool2d(self, kernel_size=(2,2), stride=None, dilation=1, padding=0, ceil_mode=False, count_include_pad=True):
2317
+ def avg_pool2d(self, kernel_size:tuple[int, ...]=(2,2), stride=None, dilation=1, padding:int|tuple[int, ...]=0,
2318
+ ceil_mode=False, count_include_pad=True) -> Tensor:
2089
2319
  """
2090
2320
  Applies average pooling over a tensor.
2091
2321
 
@@ -2106,8 +2336,6 @@ class Tensor(SimpleMathTrait):
2106
2336
 
2107
2337
  NOTE: unlike PyTorch, this implementation is not limited to only 2d pooling and instead works for any number of dimensions.
2108
2338
 
2109
- See: https://paperswithcode.com/method/average-pooling
2110
-
2111
2339
  ```python exec="true" source="above" session="tensor" result="python"
2112
2340
  t = Tensor.arange(25).reshape(1, 1, 5, 5)
2113
2341
  print(t.avg_pool2d().numpy())
@@ -2132,7 +2360,8 @@ class Tensor(SimpleMathTrait):
2132
2360
  if not ceil_mode: return pool(self, reg_pads).mean(axis)
2133
2361
  return pool(self, ceil_pads).sum(axis) / pool(self.pad(reg_pads).ones_like(), tuple(cp-rp for cp,rp in zip(ceil_pads, reg_pads))).sum(axis)
2134
2362
 
2135
- def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1, padding=0, ceil_mode=False):
2363
+ def max_pool2d(self, kernel_size:tuple[int, ...]=(2,2), stride=None, dilation=1, padding:int|tuple[int, ...]=0,
2364
+ ceil_mode=False, return_indices=False) -> Tensor | tuple[Tensor, Tensor]:
2136
2365
  """
2137
2366
  Applies max pooling over a tensor.
2138
2367
 
@@ -2149,11 +2378,10 @@ class Tensor(SimpleMathTrait):
2149
2378
  `(padding_left, padding_right, padding_top, padding_bottom, ...)`.
2150
2379
 
2151
2380
  When `ceil_mode` is set to `True`, output shape will be determined using ceil division.
2381
+ When `return_indices` is set to `True`, the argmax will be returned along with the max values.
2152
2382
 
2153
2383
  NOTE: unlike PyTorch, this implementation is not limited to only 2d pooling and instead works for any number of dimensions.
2154
2384
 
2155
- See: https://paperswithcode.com/method/max-pooling
2156
-
2157
2385
  ```python exec="true" source="above" session="tensor" result="python"
2158
2386
  t = Tensor.arange(25).reshape(1, 1, 5, 5)
2159
2387
  print(t.max_pool2d().numpy())
@@ -2165,12 +2393,50 @@ class Tensor(SimpleMathTrait):
2165
2393
  print(t.max_pool2d(padding=1).numpy())
2166
2394
  ```
2167
2395
  """
2168
- pads = self._resolve_pool_pads(padding, len(k_ := make_tuple(kernel_size, 2)))
2396
+ axis = tuple(range(-len(k_ := make_tuple(kernel_size, 2)), 0))
2397
+ pads = self._resolve_pool_pads(padding, len(k_))
2169
2398
  if ceil_mode: pads = self._apply_ceil_mode(pads, k_, stride if stride is not None else k_, dilation)
2170
- return self.pad(pads, value=dtypes.min(self.dtype))._pool(k_, stride if stride is not None else k_, dilation).max(tuple(range(-len(k_), 0)))
2399
+ pooled = self.pad(pads, value=dtypes.min(self.dtype))._pool(k_, stride if stride is not None else k_, dilation)
2400
+ if not return_indices: return pooled.max(axis)
2401
+ spatial_sz = math.prod(spatial_shape := self.shape[-len(k_):])
2402
+ idx = Tensor.arange(spatial_sz,0,-1, requires_grad=False, device=self.device).reshape(spatial_shape)
2403
+ m = pooled == pooled.max(axis, keepdim=True)
2404
+ idx = m * idx.pad(pads, value=dtypes.min(idx.dtype))._pool(k_, stride if stride is not None else k_, dilation)
2405
+ return pooled.max(axis), spatial_sz - idx.max(axis)
2406
+
2407
+ def max_unpool2d(self, indices:Tensor, kernel_size:tuple[int, ...]=(2,2), stride=None, dilation=1, padding:int|tuple[int, ...]=0, output_size=None):
2408
+ """
2409
+ Performs a partial inverse of `max_pool2d` using the indices from the argmax.
2410
+
2411
+ When `output_size` is provided, it is used to disambiguate the output shape.
2412
+
2413
+ NOTE: unlike PyTorch, this implementation is not limited to only 2d pooling and instead works for any number of dimensions.
2414
+
2415
+ ```python exec="true" source="above" session="tensor" result="python"
2416
+ t = Tensor.arange(1, 17).reshape(1, 1, 4, 4)
2417
+ print(t.numpy())
2418
+ ```
2419
+ ```python exec="true" source="above" session="tensor" result="python"
2420
+ output, indices = Tensor.max_pool2d(t, return_indices=True)
2421
+ print(output.numpy())
2422
+ print(indices.numpy())
2423
+ ```
2424
+ ```python exec="true" source="above" session="tensor" result="python"
2425
+ print(Tensor.max_unpool2d(output, indices).numpy())
2426
+ ```
2427
+ """
2428
+ bs,c,*spatial_shape = self.shape
2429
+ if output_size is None:
2430
+ k_,d_,s_ = (make_tuple(x, len(spatial_shape)) for x in (kernel_size, dilation, stride if stride is not None else kernel_size))
2431
+ p_ = _flat_to_grouped(self._resolve_pool_pads(padding, len(spatial_shape)))
2432
+ # https://arxiv.org/pdf/1603.07285 inverse of relationship 15 in section 5.1.
2433
+ output_size = tuple((i-1)*s - (pB+pA) + (d*(k-1)+1) for i,k,d,s,(pA,pB) in zip(spatial_shape,k_,d_,s_,p_))
2434
+ else: output_size = output_size[-len(spatial_shape):]
2435
+ ret = (indices.reshape(bs,c,1,-1)._one_hot_along_dim(prod(output_size), 2) * self.reshape(bs,c,1,-1)).sum(3)
2436
+ return ret.reshape(bs,c,*output_size)
2171
2437
 
2172
- def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding:int|tuple[int, ...]=0,
2173
- acc_dtype:Optional[DTypeLike]=None) -> Tensor:
2438
+ def conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding:int|tuple[int, ...]=0,
2439
+ dtype:DTypeLike|None=None) -> Tensor:
2174
2440
  """
2175
2441
  Applies a convolution over a tensor with a given `weight` and optional `bias`.
2176
2442
 
@@ -2196,7 +2462,7 @@ class Tensor(SimpleMathTrait):
2196
2462
  print(t.conv2d(w).numpy())
2197
2463
  ```
2198
2464
  """
2199
- if IMAGE: return self.image_conv2d(weight, bias, groups, stride, dilation, padding, acc_dtype)
2465
+ if IMAGE: return self.image_conv2d(weight, bias, groups, stride, dilation, padding, dtype)
2200
2466
  (bs,cin_), (cout,cin), HW = self.shape[:2], weight.shape[:2], weight.shape[2:]
2201
2467
  padding_ = self._resolve_pool_pads(padding, len(HW))
2202
2468
  assert groups*cin == cin_ and len(self.shape) == len(weight.shape), f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})" # noqa: E501
@@ -2209,7 +2475,7 @@ class Tensor(SimpleMathTrait):
2209
2475
  x = x.reshape(bs, groups, cin, 1, *oyx, *HW).expand(bs, groups, cin, rcout, *oyx, *HW).permute(0,1,3,*[4+i for i in range(len(oyx))],2,*[4+len(oyx)+i for i in range(len(HW))]) # noqa: E501
2210
2476
 
2211
2477
  # conv! broadcasted to (bs, groups, rcout, *oyx, cin, *HW)
2212
- ret = (x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW)).sum([-1-i for i in range(1+len(oyx))], keepdim=True, acc_dtype=acc_dtype).reshape(bs, cout, *oyx) # noqa: E501
2478
+ ret = (x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW)).sum([-1-i for i in range(1+len(oyx))], keepdim=True, dtype=dtype).reshape(bs, cout, *oyx) # noqa: E501
2213
2479
  return ret if bias is None else ret.add(bias.reshape(1, -1, *[1] * len(HW)))
2214
2480
 
2215
2481
  HWI, HWO = (6,) * len(HW), (4,) * len(HW) # F(4x4,3x3) winograd tiles
@@ -2217,7 +2483,7 @@ class Tensor(SimpleMathTrait):
2217
2483
  winograd_Bt = [[4, 0, -5, 0, 1, 0], [0, -4, -4, 1, 1, 0], [0, 4, -4, -1, 1, 0], [0, -2, -1, 2, 1, 0], [0, 2, -1, -2, 1, 0], [0, 4, 0, -5, 0, 1]]
2218
2484
  winograd_At = [[1, 1, 1, 1, 1, 0], [0, 1, -1, 2, -2, 0], [0, 1, 1, 4, 4, 0], [0, 1, -1, 8, -8, 1]] # applying At in pre-order doubles compile time
2219
2485
 
2220
- # todo: stride == dilation
2486
+ # TODO: stride == dilation
2221
2487
  # use padding to round up to 4x4 output tiles
2222
2488
  # (bs, cin_, tyx, HWI)
2223
2489
  d = self.pad(sum([[padding_[i*2], padding_[i*2+1] + (-(dim + sum(padding_[i * 2:(i + 1) * 2]) - 2) % 4)] for i, dim in enumerate(self.shape[-len(HW):])], []))._pool(HWI, HWO) # noqa: E501
@@ -2234,7 +2500,7 @@ class Tensor(SimpleMathTrait):
2234
2500
  dfactors = _apply_winograd_matrix(winograd_Bt, d, len(HW)).reshape(*HWI, bs, groups, 1, cin, *tyx)
2235
2501
 
2236
2502
  # matmul; sum across cin: (HWI, bs, groups, rcout, *tyx); then HWI -> HWO: (HWO, bs, groups, rcout, *tyx)
2237
- ret = _apply_winograd_matrix(winograd_At, (gfactors * dfactors).sum(axis=-1-len(HW), acc_dtype=acc_dtype), len(HW))
2503
+ ret = _apply_winograd_matrix(winograd_At, (gfactors * dfactors).sum(axis=-1-len(HW), dtype=dtype), len(HW))
2238
2504
 
2239
2505
  # interleave tyx and HWO: (bs, groups, rcout, oy, HO, ox, WO)
2240
2506
  ret = ret.permute([*range(len(HW), len(ret.shape)-len(HW)), *[i+o for i in range(len(HW)) for o in [len(ret.shape)-len(HW),0]]])
@@ -2243,7 +2509,7 @@ class Tensor(SimpleMathTrait):
2243
2509
 
2244
2510
  return (ret if bias is None else ret.add(bias.reshape(1, -1, *[1 for _ in range(len(HW))]))).contiguous().contiguous_backward()
2245
2511
 
2246
- def conv_transpose2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, output_padding=0) -> Tensor:
2512
+ def conv_transpose2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding=0, output_padding=0) -> Tensor:
2247
2513
  """
2248
2514
  Applies a transposed convolution over a tensor with a given `weight` and optional `bias`.
2249
2515
 
@@ -2282,14 +2548,14 @@ class Tensor(SimpleMathTrait):
2282
2548
  padding = flatten((((k-1)*d-pB,(k-1)*d-pA+op) for k,d,(pB,pA),op in reversed(list(zip(HW, dilation, padding, output_padding)))))
2283
2549
  return x.conv2d(w.flatten(end_dim=1), groups=groups, bias=bias, dilation=dilation, padding=padding)
2284
2550
 
2285
- def dot(self, w:Tensor, acc_dtype:Optional[DTypeLike]=None) -> Tensor:
2551
+ def dot(self, w:Tensor, dtype:DTypeLike|None=None) -> Tensor:
2286
2552
 
2287
2553
  """
2288
2554
  Performs dot product between two tensors.
2289
2555
  If `w` is 1-D, it's a sum product over the last axis of `self` and `w`.
2290
2556
  If `w` is N-D with N>=2, it's a sum product over the last axis of `self` and the second-to-last axis of `w`.
2291
2557
 
2292
- You can pass in the optional `acc_dtype` keyword argument to control the data type of the accumulation.
2558
+ You can pass in the optional `dtype` keyword argument to control the data type of the accumulation.
2293
2559
 
2294
2560
  ```python exec="true" source="above" session="tensor" result="python"
2295
2561
  a = Tensor([1, 2, 3])
@@ -2302,20 +2568,20 @@ class Tensor(SimpleMathTrait):
2302
2568
  print(a.dot(b).numpy())
2303
2569
  ```
2304
2570
  """
2305
- if IMAGE: return self.image_dot(w, acc_dtype)
2571
+ if IMAGE: return self.image_dot(w, dtype)
2306
2572
  x, dx, dw = self, self.ndim, w.ndim
2307
2573
  if not (dx > 0 and dw > 0): raise RuntimeError(f"both tensors need to be at least 1D, got {dx}D and {dw}D")
2308
2574
  if x.shape[-1] != w.shape[axis_w:=-min(w.ndim,2)]: raise RuntimeError(f"cannot dot {x.shape} and {w.shape}")
2309
2575
  x = x.reshape(*x.shape[0:-1], *[1]*min(dx-1, dw-1, 1), x.shape[-1])
2310
2576
  w = w.reshape(*w.shape[0:-2], *[1]*min(dx-1, dw-1, 1), *w.shape[axis_w:]).transpose(-1, axis_w)
2311
- return (x*w).sum(-1, acc_dtype=acc_dtype).cast(least_upper_dtype(x.dtype, w.dtype) if acc_dtype is None else acc_dtype)
2577
+ return (x*w).sum(-1, dtype=dtype).cast(least_upper_dtype(x.dtype, w.dtype) if dtype is None else dtype)
2312
2578
 
2313
- def matmul(self, x:Tensor, reverse=False, acc_dtype:Optional[DTypeLike]=None) -> Tensor:
2579
+ def matmul(self, x:Tensor, reverse=False, dtype:DTypeLike|None=None) -> Tensor:
2314
2580
  """
2315
2581
  Performs matrix multiplication between two tensors.
2316
2582
 
2317
2583
  You can pass in the `reverse` keyword argument to control the order of the matrix multiplication.
2318
- You can pass in the optional `acc_dtype` keyword argument to control the data type of the accumulation.
2584
+ You can pass in the optional `dtype` keyword argument to control the data type of the accumulation.
2319
2585
 
2320
2586
  ```python exec="true" source="above" session="tensor" result="python"
2321
2587
  a = Tensor([[1, 2], [3, 4]])
@@ -2323,26 +2589,26 @@ class Tensor(SimpleMathTrait):
2323
2589
  print(a.matmul(b).numpy())
2324
2590
  ```
2325
2591
  """
2326
- return x.dot(self, acc_dtype=acc_dtype) if reverse else self.dot(x, acc_dtype=acc_dtype)
2592
+ return x.dot(self, dtype=dtype) if reverse else self.dot(x, dtype=dtype)
2327
2593
 
2328
2594
  def _cumalu(self, axis:int, op:Ops, _include_initial=False) -> Tensor:
2329
- assert self.shape[axis] != 0 and op in (Ops.ADD, Ops.MAX)
2595
+ assert self.shape[axis] != 0 and op in (Ops.ADD, Ops.MAX, Ops.MUL)
2330
2596
  pl_sz = self.shape[axis] - int(not _include_initial)
2331
2597
  pooled = self.transpose(axis,-1).pad((pl_sz, -int(_include_initial)), value=identity_element(op, self.dtype))._pool((self.shape[axis],))
2332
- return (pooled.sum(-1) if op is Ops.ADD else pooled.max(-1)).transpose(axis,-1)
2598
+ return {Ops.ADD: pooled.sum(-1), Ops.MAX: pooled.max(-1), Ops.MUL: pooled.prod(-1)}[op].transpose(axis, -1)
2333
2599
 
2334
2600
  def _split_cumalu(self, axis:int, op:Ops) -> Tensor:
2335
2601
  axis = self._resolve_dim(axis)
2336
2602
  if self.ndim == 0 or 0 in self.shape: return self
2337
- # TODO: someday the optimizer will find this on it's own
2603
+ # TODO: someday the optimizer will find this on its own
2338
2604
  # for now this is a two stage cumsum
2339
2605
  SPLIT = 256
2340
2606
  if not isinstance(s:=self.shape[axis], int) or s <= SPLIT*2: return self._cumalu(axis, op)
2341
2607
  ret = self.transpose(axis,-1).pad((round_up(s, SPLIT)-s, 0), value=identity_element(op, self.dtype)).unflatten(-1, (-1, SPLIT))._cumalu(-1, op)
2342
2608
  base = ret[..., -1]._cumalu(-1, op, _include_initial=True)
2343
2609
  base = base.unsqueeze(-1).expand(*base.shape, ret.shape[-1])
2344
- def fix(x:Tensor): return x.flatten(start_dim=-2)[..., -s:].transpose(axis,-1)
2345
- return fix(ret) + fix(base) if op is Ops.ADD else fix(ret).maximum(fix(base))
2610
+ def fix(x: Tensor) -> Tensor: return x.flatten(start_dim=-2)[..., -s:].transpose(axis,-1)
2611
+ return {Ops.ADD: Tensor.__add__, Ops.MAX: Tensor.maximum, Ops.MUL: Tensor.__mul__}[op](fix(ret), fix(base))
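`_split_cumalu` computes cumulative ops in two stages: a cumulation inside fixed-size blocks, then folding each block's running prefix into the following blocks. A numpy sketch of the additive case with a tiny block size (the helper `split_cumsum` and `split=4` are illustrative; the code above uses SPLIT=256):

```python
import numpy as np

def split_cumsum(x, split=4):
    pad = (-len(x)) % split                                   # left-pad with the identity element (0 for ADD)
    blocks = np.concatenate([np.zeros(pad, x.dtype), x]).reshape(-1, split)
    local = np.cumsum(blocks, axis=1)                         # stage 1: per-block cumsum
    carry = np.concatenate([[0], np.cumsum(local[:-1, -1])])  # stage 2: running total of previous blocks
    return (local + carry[:, None]).reshape(-1)[pad:]

x = np.arange(1, 11)
print(split_cumsum(x))  # [ 1  3  6 10 15 21 28 36 45 55]
print(np.cumsum(x))     # identical single-pass result
```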
2346
2612
 
2347
2613
  def cumsum(self, axis:int=0) -> Tensor:
2348
2614
  """
@@ -2358,6 +2624,20 @@ class Tensor(SimpleMathTrait):
2358
2624
  """
2359
2625
  return self._split_cumalu(axis, Ops.ADD)
2360
2626
 
2627
+ def cumprod(self, axis:int) -> Tensor:
2628
+ """
2629
+ Computes the cumulative product of the elements of the tensor along the specified `axis`.
2630
+
2631
+ ```python exec="true" source="above" session="tensor" result="python"
2632
+ t = Tensor.arange(1, 7).reshape(2, 3)
2633
+ print(t.numpy())
2634
+ ```
2635
+ ```python exec="true" source="above" session="tensor" result="python"
2636
+ print(t.cumprod(axis=0).numpy())
2637
+ ```
2638
+ """
2639
+ return self._split_cumalu(axis, Ops.MUL)
2640
+
2361
2641
  def cummax(self, axis:int=0) -> Tensor:
2362
2642
  """
2363
2643
  Computes the cumulative max of the tensor along the specified `axis`.
@@ -2403,7 +2683,7 @@ class Tensor(SimpleMathTrait):
2403
2683
  print(t.triu(diagonal=-1).numpy())
2404
2684
  ```
2405
2685
  """
2406
- return Tensor._tri(self.shape[-2], self.shape[-1], diagonal=diagonal, device=self.device, dtype=dtypes.bool).where(self, 0).cast(self.dtype)
2686
+ return Tensor._tri(self.shape[-2], self.shape[-1], diagonal=diagonal, device=self.device, dtype=dtypes.bool).where(self, self.zeros_like())
2407
2687
 
2408
2688
  def tril(self, diagonal:int=0) -> Tensor:
2409
2689
  """
@@ -2426,7 +2706,7 @@ class Tensor(SimpleMathTrait):
2426
2706
  print(t.tril(diagonal=-1).numpy())
2427
2707
  ```
2428
2708
  """
2429
- return Tensor._tri(self.shape[-2], self.shape[-1], diagonal=diagonal+1, device=self.device, dtype=dtypes.bool).where(0, self).cast(self.dtype)
2709
+ return Tensor._tri(self.shape[-2], self.shape[-1], diagonal=diagonal+1, device=self.device, dtype=dtypes.bool).where(self.zeros_like(), self)
2430
2710
 
2431
2711
  def interpolate(self, size:tuple[int, ...], mode:str="linear", align_corners:bool=False) -> Tensor:
2432
2712
  """
@@ -2462,7 +2742,7 @@ class Tensor(SimpleMathTrait):
2462
2742
 
2463
2743
  def _pre_scatter(self, dim:int, index:Tensor, src:Tensor) -> tuple[Tensor, Tensor]:
2464
2744
  index, dim = index.to(self.device), self._resolve_dim(dim)
2465
- assert index.ndim == self.ndim == src.ndim, f"self.ndim, index.ndim and src.dim must all equal, {self.ndim=} {index.ndim=} {src.ndim=}"
2745
+ assert index.ndim == self.ndim == src.ndim, f"self.ndim, index.ndim and src.ndim must all equal, {self.ndim=} {index.ndim=} {src.ndim=}"
2466
2746
  assert all((d == dim or self_ >= index_) and src_ >= index_ for d,(self_,index_,src_) in enumerate(zip(self.shape, index.shape, src.shape))), \
2467
2747
  f"All dimensions of {index.shape=} should be <= to all dimensions of {src.shape=} and all dimensions except dimension {dim} of {self.shape=}"
2468
2748
  if self.dtype != src.dtype: raise RuntimeError(f"expect {self.dtype=} to be equal to {src.dtype=}")
@@ -2475,7 +2755,7 @@ class Tensor(SimpleMathTrait):
2475
2755
  src, mask = (x.pad(tuple((0, self.shape[i] - x.shape[i]) if i != dim else None for i in range(self.ndim)) + (None,)) for x in (src, mask))
2476
2756
  return src, mask
2477
2757
 
2478
- def scatter(self, dim:int, index:Tensor, src:Union[Tensor, ConstType], reduce:Union[None, Literal['multiply'], Literal['add']]=None) -> Tensor:
2758
+ def scatter(self, dim:int, index:Tensor, src:Tensor|ConstType, reduce:Literal['multiply', 'add']|None=None) -> Tensor:
2479
2759
  """
2480
2760
  Scatters `src` values along an axis specified by `dim`.
2481
2761
  Apply `add` or `multiply` reduction operation with `reduce`.
@@ -2540,20 +2820,103 @@ class Tensor(SimpleMathTrait):
2540
2820
  ```
2541
2821
  """
2542
2822
  src, mask = self._pre_scatter(dim, index, src)
2543
- def _inv_mask(a:Union[Tensor, ConstType], b:Union[Tensor, ConstType]) -> Tensor: return mask.any(-1).logical_not().where(a, b)
2544
- # TODO: should not overwrite acc_dtype here?
2545
- if reduce == "sum": return mask.where(src, 0).sum(-1, acc_dtype=self.dtype).add(self if include_self else _inv_mask(self, 0))
2546
- if reduce == "prod": return mask.where(src, 1).prod(-1, acc_dtype=self.dtype).mul(self if include_self else _inv_mask(self, 1))
2823
+ def _inv_mask(a:Tensor|ConstType, b:Tensor|ConstType) -> Tensor: return mask.any(-1).logical_not().where(a, b)
2824
+ if reduce == "sum": return mask.where(src, 0).sum(-1).add(self if include_self else _inv_mask(self, 0))
2825
+ if reduce == "prod": return mask.where(src, 1).prod(-1).mul(self if include_self else _inv_mask(self, 1))
2547
2826
  if reduce == "amax": return mask.where(src, m := dtypes.min(src.dtype)).max(-1).maximum(self if include_self else _inv_mask(self, m))
2548
2827
  if reduce == "amin": return mask.where(src, m := dtypes.max(src.dtype)).min(-1).minimum(self if include_self else _inv_mask(self, m))
2549
2828
  if reduce == "mean":
2550
- count = mask.where(1, 0).sum(-1, acc_dtype=self.dtype).add(1 if include_self else _inv_mask(1, 0))
2551
- return mask.where(src, 0).sum(-1, acc_dtype=self.dtype).add(self if include_self else _inv_mask(self, 0)).div(count)
2829
+ count = mask.where(1, 0).sum(-1).add(1 if include_self else _inv_mask(1, 0))
2830
+ return mask.where(src, 0).sum(-1).add(self if include_self else _inv_mask(self, 0)).div(count)
2552
2831
  raise RuntimeError(f"{reduce=} must be one of 'sum', 'prod', 'mean', 'amax', 'amin'")
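As a reference for the reduce modes handled above, the sketch below spells out the same semantics as an explicit Python loop over a 1-D target (`dim=0`). It illustrates the intended behaviour only; `scatter_reduce_1d` is a hypothetical helper, not the tinygrad implementation.

```python
# Reference semantics for the reduce modes above, 1-D case: values scattered to the
# same index are combined with `reduce`, optionally together with the existing value.
def scatter_reduce_1d(target, index, src, reduce, include_self=True):
    buckets = {i: [] for i in range(len(target))}
    for i, v in zip(index, src): buckets[i].append(v)
    out = []
    for i, t in enumerate(target):
        vals = buckets[i] + ([t] if include_self else [])
        if not vals: out.append(t); continue               # untouched positions keep self
        if reduce == "sum": out.append(sum(vals))
        elif reduce == "prod":
            p = 1.0
            for v in vals: p *= v
            out.append(p)
        elif reduce == "amax": out.append(max(vals))
        elif reduce == "amin": out.append(min(vals))
        elif reduce == "mean": out.append(sum(vals) / len(vals))
        else: raise RuntimeError(f"{reduce=} must be one of 'sum', 'prod', 'mean', 'amax', 'amin'")
    return out

print(scatter_reduce_1d([1.0, 2.0, 3.0], [0, 0, 2], [10.0, 20.0, 5.0], "sum"))                      # [31.0, 2.0, 8.0]
print(scatter_reduce_1d([1.0, 2.0, 3.0], [0, 0, 2], [10.0, 20.0, 5.0], "mean", include_self=False)) # [15.0, 2.0, 5.0]
```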
2553
2832
 
2833
+ def sort(self, dim:int=-1, descending:bool=False) -> tuple[Tensor, Tensor]:
2834
+ """
2835
+ Performs a bitonic sort on the tensor along the specified dimension.
2836
+
2837
+ Order of indices for equivalent elements is always preserved.
2838
+
2839
+ See: https://en.wikipedia.org/wiki/Bitonic_sorter
2840
+
2841
+ ```python exec="true" source="above" session="tensor" result="python"
2842
+ t = Tensor([[0.1, 0.5, 1.2, 3.4, 2.1], [2.2, 1.9, 0.3, 4.5, 0.8]])
2843
+ print(t.numpy())
2844
+ ```
2845
+ ```python exec="true" source="above" session="tensor" result="python"
2846
+ sorted_values, indices = t.sort(dim=1, descending=True)
2847
+ print(sorted_values.numpy())
2848
+ print(indices.numpy())
2849
+ ```
2850
+ """
2851
+ x, dim = self, self._resolve_dim(dim)
2852
+ if (orig_len:= x.shape[dim]) <= 1: return x, x.zeros_like(dtype=dtypes.default_int)
2853
+ # pad to power of 2
2854
+ n_stages = (orig_len-1).bit_length()
2855
+ pads = tuple((0, 2**n_stages - orig_len) if i == dim else None for i in range(x.ndim))
2856
+ x = x.pad(pads, value=dtypes.min(x.dtype) if descending else dtypes.max(x.dtype)).unflatten(dim, (2,)*n_stages)
2857
+ # https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort1.svg
2858
+ for stage in range(1, n_stages+1):
2859
+ if stage != n_stages:
2860
+ # flip so arrows of green boxes point the same way as blue boxes
2861
+ crossover_dim = dim + n_stages - stage - 1
2862
+ blue_box, green_box = x.split(1, crossover_dim)
2863
+ flip_dims = tuple(-i for i in range(1, stage+1+(self.ndim-dim)))
2864
+ x = (blue_box.cat(green_box.flip(flip_dims), dim=crossover_dim)).contiguous()
2865
+ for substage in range(stage-1, -1, -1):
2866
+ partner_dim = dim + n_stages - substage - 1
2867
+ x_top, x_bottom = x.split(1, partner_dim)
2868
+ x_larger, x_smaller = x_top.maximum(x_bottom), x_top.minimum(x_bottom)
2869
+ x = (x_larger.cat(x_smaller, dim=partner_dim) if descending else x_smaller.cat(x_larger, dim=partner_dim)).contiguous()
2870
+ if stage != n_stages:
2871
+ # flip wires back to undo the crossover
2872
+ blue_box, flipped_green_box = x.split(1, crossover_dim)
2873
+ x = blue_box.cat(flipped_green_box.flip(flip_dims), dim=crossover_dim)
2874
+ x = x.flatten(dim, dim+n_stages-1).shrink(tuple((0, s) for s in self.shape))
2875
+ # compute indices for sorted values
2876
+ mask = Tensor.ones(orig_len, orig_len, dtype=dtypes.bool, device=self.device).tril().reshape((None, None) + (1,)*(self.ndim-dim-1))
2877
+ def compute_counts(t:Tensor): return (mask & (t.unsqueeze(dim) == t.unsqueeze(dim+1))).sum(dim+1)
2878
+ count_orig, count_sorted = compute_counts(self), compute_counts(x)
2879
+ cond = (self.unsqueeze(dim+1) == x.unsqueeze(dim)) & (count_orig.unsqueeze(dim+1) == count_sorted.unsqueeze(dim))
2880
+ idx = Tensor.arange(orig_len, device=self.device).reshape(tuple(orig_len if i == dim else 1 for i in range(x.ndim)))
2881
+ idx = (cond * idx.unsqueeze(dim+1)).sum(dim)
2882
+ return x, idx
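For readers unfamiliar with the network being tensorized above, here is a compact pure-Python bitonic sorting network for a power-of-two length list (ascending), following the same stage/substage structure as the linked diagram. It compares single elements rather than whole tensor slices, so it is a sketch of the algorithm, not of the implementation above.

```python
# Pure-Python bitonic sorting network (ascending) for a power-of-two length list.
def bitonic_sort(xs):
    n = len(xs)
    assert n and n & (n - 1) == 0, "length must be a power of two"
    xs = list(xs)
    k = 2
    while k <= n:                     # stage: size of the bitonic sequences being merged
        j = k // 2
        while j >= 1:                 # substage: compare-exchange distance
            for i in range(n):
                partner = i ^ j
                if partner > i:
                    ascending = (i & k) == 0
                    if (xs[i] > xs[partner]) == ascending:
                        xs[i], xs[partner] = xs[partner], xs[i]
            j //= 2
        k *= 2
    return xs

print(bitonic_sort([3.4, 0.1, 2.1, 1.2, 0.5, 4.5, 0.8, 2.2]))
```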
2883
+
2884
+ def argsort(self, dim:int=-1, descending:bool=False) -> Tensor:
2885
+ """
2886
+ Returns the indices that sort the input tensor along the given `dim` in the given `descending` order by value.
2887
+
2888
+ ```python exec="true" source="above" session="tensor" result="python"
2889
+ t = Tensor([[2, 3, 4, 1], [1, 4, 3, 2]])
2890
+ print(t.argsort().numpy())
2891
+ ```
2892
+ """
2893
+ return self.sort(dim, descending)[1]
2894
+
2895
+ def topk(self, k:int, dim:int=-1, largest:bool=True, sorted_:bool=True) -> tuple[Tensor, Tensor]:
2896
+ """
2897
+ Computes the top-k elements of the tensor along the specified `dim`.
2898
+
2899
+ Order of indices for equivalent elements is always preserved.
2900
+
2901
+ ```python exec="true" source="above" session="tensor" result="python"
2902
+ t = Tensor([[0.1, 0.5, 1.2, 3.4, 2.1], [2.2, 1.9, 0.3, 4.5, 0.8]])
2903
+ print(t.numpy())
2904
+ ```
2905
+ ```python exec="true" source="above" session="tensor" result="python"
2906
+ topk_values, topk_indices = t.topk(2, dim=1)
2907
+ print(topk_values.numpy())
2908
+ print(topk_indices.numpy())
2909
+ ```
2910
+ """
2911
+ if not sorted_: raise NotImplementedError("topk with sorted_=False is not supported")
2912
+ if k > self.shape[dim:=self._resolve_dim(dim)]: raise ValueError(f"selected index {k=} is out of range")
2913
+ x, idx = self.sort(dim, descending=largest)
2914
+ shrink_to_k = tuple((0, k) if i == dim else None for i in range(self.ndim))
2915
+ return x.shrink(shrink_to_k), idx.shrink(shrink_to_k)
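A plain-Python reference for the top-k semantics above: take the `k` largest (or smallest) values, and since the underlying sort is stable, ties keep their original index order. `topk_ref` is a hypothetical helper for illustration.

```python
# Reference top-k: values plus the indices they came from, ties kept in original order.
def topk_ref(xs, k, largest=True):
    order = sorted(range(len(xs)), key=lambda i: -xs[i] if largest else xs[i])  # stable sort
    idx = order[:k]
    return [xs[i] for i in idx], idx

print(topk_ref([0.1, 0.5, 1.2, 3.4, 2.1], 2))   # ([3.4, 2.1], [3, 4])
print(topk_ref([2.2, 1.9, 0.3, 4.5, 0.8], 2))   # ([4.5, 2.2], [3, 0])
```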
2916
+
2554
2917
  # ***** unary ops *****
2555
2918
 
2556
- def logical_not(self):
2919
+ def logical_not(self) -> Tensor:
2557
2920
  """
2558
2921
  Computes the logical NOT of the tensor element-wise.
2559
2922
 
@@ -2562,7 +2925,8 @@ class Tensor(SimpleMathTrait):
2562
2925
  ```
2563
2926
  """
2564
2927
  return self.cast(dtypes.bool)._apply_broadcasted_uop(UOp.ne, True)
2565
- def neg(self):
2928
+
2929
+ def neg(self) -> Tensor:
2566
2930
  """
2567
2931
  Negates the tensor element-wise.
2568
2932
 
@@ -2571,17 +2935,29 @@ class Tensor(SimpleMathTrait):
2571
2935
  ```
2572
2936
  """
2573
2937
  return self*-1 if self.dtype != dtypes.bool else self.logical_not()
2574
- def contiguous(self):
2938
+
2939
+ def contiguous(self, **kwargs) -> Tensor:
2575
2940
  """
2576
2941
  Returns a contiguous tensor.
2577
2942
  """
2578
- return self._apply_uop(UOp.contiguous)
2579
- def contiguous_backward(self):
2943
+ return self._apply_uop(UOp.contiguous, **kwargs)
2944
+
2945
+ def fuse(self) -> Tensor:
2946
+ """
2947
+ Fuses this computation into a single kernel, back to Ops.CONTIGUOUS on the inputs.
2948
+
2949
+ Useful for single kernel softmax and flash attention.
2950
+ Careful, this can break codegen or make kernels really slow.
2951
+ """
2952
+ return self._apply_uop(UOp.fuse)
2953
+
2954
+ def contiguous_backward(self) -> Tensor:
2580
2955
  """
2581
2956
  Inserts a contiguous operation in the backward pass.
2582
2957
  """
2583
2958
  return self._apply_uop(UOp.contiguous_backward)
2584
- def log(self):
2959
+
2960
+ def log(self) -> Tensor:
2585
2961
  """
2586
2962
  Computes the natural logarithm element-wise.
2587
2963
 
@@ -2592,7 +2968,8 @@ class Tensor(SimpleMathTrait):
2592
2968
  ```
2593
2969
  """
2594
2970
  return self.log2()*math.log(2)
2595
- def log2(self):
2971
+
2972
+ def log2(self) -> Tensor:
2596
2973
  """
2597
2974
  Computes the base-2 logarithm element-wise.
2598
2975
 
@@ -2603,7 +2980,8 @@ class Tensor(SimpleMathTrait):
2603
2980
  ```
2604
2981
  """
2605
2982
  return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.log2)
2606
- def exp(self):
2983
+
2984
+ def exp(self) -> Tensor:
2607
2985
  """
2608
2986
  Computes the exponential function element-wise.
2609
2987
 
@@ -2614,7 +2992,8 @@ class Tensor(SimpleMathTrait):
2614
2992
  ```
2615
2993
  """
2616
2994
  return self.mul(1/math.log(2)).exp2()
2617
- def exp2(self):
2995
+
2996
+ def exp2(self) -> Tensor:
2618
2997
  """
2619
2998
  Computes the base-2 exponential function element-wise.
2620
2999
 
@@ -2625,19 +3004,19 @@ class Tensor(SimpleMathTrait):
2625
3004
  ```
2626
3005
  """
2627
3006
  return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.exp2)
2628
- def relu(self):
3007
+
3008
+ def relu(self) -> Tensor:
2629
3009
  """
2630
3010
  Applies the Rectified Linear Unit (ReLU) function element-wise.
2631
3011
 
2632
- - Described: https://paperswithcode.com/method/relu
2633
-
2634
3012
  ```python exec="true" source="above" session="tensor" result="python"
2635
3013
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).relu().numpy())
2636
3014
  ```
2637
3015
  """
3016
+ # NOTE: if you write this as self.maximum(0) the gradient is wrong, passing through half when self is 0
2638
3017
  return (self>0).where(self, 0)
2639
3018
 
2640
- def sigmoid(self):
3019
+ def sigmoid(self) -> Tensor:
2641
3020
  """
2642
3021
  Applies the Sigmoid function element-wise.
2643
3022
 
@@ -2649,12 +3028,23 @@ class Tensor(SimpleMathTrait):
2649
3028
  """
2650
3029
  return (1 + (self * (-1/math.log(2))).exp2()).reciprocal()
2651
3030
 
2652
- def hardsigmoid(self, alpha:float=1/6, beta:float=0.5):
3031
+ def logsigmoid(self) -> Tensor:
3032
+ """
3033
+ Applies the LogSigmoid function element-wise.
3034
+
3035
+ - See: https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.logsigmoid.html
3036
+
3037
+ ```python exec="true" source="above" session="tensor" result="python"
3038
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).logsigmoid().numpy())
3039
+ ```
3040
+ """
3041
+ return -(-self).softplus()
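The one-liner above relies on the identity log(sigmoid(x)) = -softplus(-x), which stays finite for very negative inputs where the naive form blows up. A quick plain-math check of the identity; the rearranged `stable` form below is mine, shown only for verification.

```python
# log(sigmoid(x)) == -softplus(-x); the stable form avoids exp() of large positive numbers.
import math

def naive(x): return math.log(1.0 / (1.0 + math.exp(-x)))
def stable(x): return -math.log1p(math.exp(-abs(x))) + min(x, 0.0)   # == -softplus(-x)

for x in (-3.0, 0.0, 3.0):
    print(x, naive(x), stable(x))     # the two columns agree
print(stable(-800.0))                 # -800.0; naive(-800.0) would overflow in exp()
```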
3042
+
3043
+ def hardsigmoid(self, alpha:float=1/6, beta:float=0.5) -> Tensor:
2653
3044
  """
2654
3045
  Applies the Hardsigmoid function element-wise.
2655
- NOTE: default `alpha` and `beta` values is taken from torch
3046
+ NOTE: default `alpha` and `beta` values are taken from torch
2656
3047
 
2657
- - Described: https://paperswithcode.com/method/hard-sigmoid
2658
3048
  - See: https://pytorch.org/docs/stable/generated/torch.nn.functional.hardsigmoid.html
2659
3049
 
2660
3050
  ```python exec="true" source="above" session="tensor" result="python"
@@ -2663,7 +3053,7 @@ class Tensor(SimpleMathTrait):
2663
3053
  """
2664
3054
  return (alpha * self + beta).relu() - (alpha * self + beta - 1).relu()
2665
3055
 
2666
- def sqrt(self):
3056
+ def sqrt(self) -> Tensor:
2667
3057
  """
2668
3058
  Computes the square root of the tensor element-wise.
2669
3059
 
@@ -2672,7 +3062,8 @@ class Tensor(SimpleMathTrait):
2672
3062
  ```
2673
3063
  """
2674
3064
  return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.sqrt)
2675
- def rsqrt(self):
3065
+
3066
+ def rsqrt(self) -> Tensor:
2676
3067
  """
2677
3068
  Computes the reciprocal of the square root of the tensor element-wise.
2678
3069
 
@@ -2681,7 +3072,8 @@ class Tensor(SimpleMathTrait):
2681
3072
  ```
2682
3073
  """
2683
3074
  return self.sqrt().reciprocal()
2684
- def sin(self):
3075
+
3076
+ def sin(self) -> Tensor:
2685
3077
  """
2686
3078
  Computes the sine of the tensor element-wise.
2687
3079
 
@@ -2690,7 +3082,8 @@ class Tensor(SimpleMathTrait):
2690
3082
  ```
2691
3083
  """
2692
3084
  return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.sin)
2693
- def cos(self):
3085
+
3086
+ def cos(self) -> Tensor:
2694
3087
  """
2695
3088
  Computes the cosine of the tensor element-wise.
2696
3089
 
@@ -2699,7 +3092,8 @@ class Tensor(SimpleMathTrait):
2699
3092
  ```
2700
3093
  """
2701
3094
  return ((math.pi/2)-self).sin()
2702
- def tan(self):
3095
+
3096
+ def tan(self) -> Tensor:
2703
3097
  """
2704
3098
  Computes the tangent of the tensor element-wise.
2705
3099
 
@@ -2709,7 +3103,7 @@ class Tensor(SimpleMathTrait):
2709
3103
  """
2710
3104
  return self.sin() / self.cos()
2711
3105
 
2712
- def asin(self):
3106
+ def asin(self) -> Tensor:
2713
3107
  """
2714
3108
  Computes the inverse sine (arcsine) of the tensor element-wise.
2715
3109
 
@@ -2722,7 +3116,7 @@ class Tensor(SimpleMathTrait):
2722
3116
  x = math.pi / 2 - (1.0 - self.abs()).sqrt() * polyN(self.abs(), coefficients)
2723
3117
  return self.sign() * x
2724
3118
 
2725
- def acos(self):
3119
+ def acos(self) -> Tensor:
2726
3120
  """
2727
3121
  Computes the inverse cosine (arccosine) of the tensor element-wise.
2728
3122
 
@@ -2732,7 +3126,7 @@ class Tensor(SimpleMathTrait):
2732
3126
  """
2733
3127
  return math.pi / 2 - self.asin()
2734
3128
 
2735
- def atan(self):
3129
+ def atan(self) -> Tensor:
2736
3130
  """
2737
3131
  Computes the inverse tangent (arctan) of the tensor element-wise.
2738
3132
 
@@ -2752,7 +3146,8 @@ class Tensor(SimpleMathTrait):
2752
3146
  print(Tensor([-3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]).trunc().numpy())
2753
3147
  ```
2754
3148
  """
2755
- return self.cast(dtypes.int32).cast(self.dtype)
3149
+ return self._apply_uop(UOp.trunc)
3150
+
2756
3151
  def ceil(self: Tensor) -> Tensor:
2757
3152
  """
2758
3153
  Rounds the tensor element-wise towards positive infinity.
@@ -2762,6 +3157,7 @@ class Tensor(SimpleMathTrait):
2762
3157
  ```
2763
3158
  """
2764
3159
  return (self > (b := self.trunc())).where(b+1, b)
3160
+
2765
3161
  def floor(self: Tensor) -> Tensor:
2766
3162
  """
2767
3163
  Rounds the tensor element-wise towards negative infinity.
@@ -2771,6 +3167,7 @@ class Tensor(SimpleMathTrait):
2771
3167
  ```
2772
3168
  """
2773
3169
  return (self < (b := self.trunc())).where(b-1, b)
3170
+
2774
3171
  def round(self: Tensor) -> Tensor:
2775
3172
  """
2776
3173
  Rounds the tensor element-wise with rounding half to even.
@@ -2779,9 +3176,9 @@ class Tensor(SimpleMathTrait):
2779
3176
  print(Tensor([-3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]).round().numpy())
2780
3177
  ```
2781
3178
  """
2782
- return ((self > 0) == ((b := self.cast(dtypes.int32) / 2.0).cast(dtypes.int32) == b)).where((self - 0.5).ceil(), (self + 0.5).floor())
3179
+ return ((self > 0) == ((b := self.trunc() / 2.0).trunc() == b)).where((self - 0.5).ceil(), (self + 0.5).floor())
2783
3180
 
2784
- def isinf(self:Tensor, detect_positive:bool=True, detect_negative:bool=True):
3181
+ def isinf(self:Tensor, detect_positive:bool=True, detect_negative:bool=True) -> Tensor:
2785
3182
  """
2786
3183
  Checks the tensor element-wise to return True where the element is infinity, otherwise returns False
2787
3184
 
@@ -2790,7 +3187,8 @@ class Tensor(SimpleMathTrait):
2790
3187
  ```
2791
3188
  """
2792
3189
  return (self == float("inf")) * detect_positive + (self == float("-inf")) * detect_negative
2793
- def isnan(self:Tensor):
3190
+
3191
+ def isnan(self:Tensor) -> Tensor:
2794
3192
  """
2795
3193
  Checks the tensor element-wise to return True where the element is NaN, otherwise returns False
2796
3194
 
@@ -2800,7 +3198,17 @@ class Tensor(SimpleMathTrait):
2800
3198
  """
2801
3199
  return self != self
2802
3200
 
2803
- def lerp(self, end: Tensor, weight: Union[Tensor, float]) -> Tensor:
3201
+ def isfinite(self:Tensor) -> Tensor:
3202
+ """
3203
+ Checks the tensor element-wise to return True where the element is finite, otherwise returns False
3204
+
3205
+ ```python exec="true" source="above" session="tensor" result="python"
3206
+ print(Tensor([1, float('inf'), 2, float('-inf'), float('nan')]).isfinite().numpy())
3207
+ ```
3208
+ """
3209
+ return (self.isinf()|self.isnan()).logical_not()
3210
+
3211
+ def lerp(self, end:Tensor, weight:Tensor|float) -> Tensor:
2804
3212
  """
2805
3213
  Linearly interpolates between `self` and `end` by `weight`.
2806
3214
 
@@ -2813,7 +3221,7 @@ class Tensor(SimpleMathTrait):
2813
3221
  return (self+(((end - self).cast(dtypes.int8) * w_i + (1<<W_PREC-1)).cast(dtypes.uint16) >> W_PREC)).cast(dtypes.uint8)
2814
3222
  return self + (end - self) * weight
2815
3223
 
2816
- def square(self):
3224
+ def square(self) -> Tensor:
2817
3225
  """
2818
3226
  Squares the tensor element-wise.
2819
3227
  Equivalent to `self*self`.
@@ -2823,7 +3231,8 @@ class Tensor(SimpleMathTrait):
2823
3231
  ```
2824
3232
  """
2825
3233
  return self*self
2826
- def clamp(self, min_=None, max_=None):
3234
+
3235
+ def clamp(self, min_=None, max_=None) -> Tensor:
2827
3236
  """
2828
3237
  Clips (clamps) the values in the tensor between `min_` and `max_` element-wise.
2829
3238
  If `min_` is `None`, there is no lower bound. If `max_` is None, there is no upper bound.
@@ -2835,12 +3244,14 @@ class Tensor(SimpleMathTrait):
2835
3244
  if min_ is None and max_ is None: raise RuntimeError("at least one of 'min_' or 'max_' must not be None")
2836
3245
  ret = self.maximum(min_) if min_ is not None else self
2837
3246
  return ret.minimum(max_) if max_ is not None else ret
2838
- def clip(self, min_=None, max_=None):
3247
+
3248
+ def clip(self, min_=None, max_=None) -> Tensor:
2839
3249
  """
2840
3250
  Alias for `Tensor.clamp`.
2841
3251
  """
2842
3252
  return self.clamp(min_, max_)
2843
- def sign(self):
3253
+
3254
+ def sign(self) -> Tensor:
2844
3255
  """
2845
3256
  Returns the sign of the tensor element-wise.
2846
3257
 
@@ -2849,7 +3260,8 @@ class Tensor(SimpleMathTrait):
2849
3260
  ```
2850
3261
  """
2851
3262
  return self.ne(0).where((self<0).where(self.full_like(-1), self.full_like(1)), self.full_like(0)) + self*0
2852
- def abs(self):
3263
+
3264
+ def abs(self) -> Tensor:
2853
3265
  """
2854
3266
  Computes the absolute value of the tensor element-wise.
2855
3267
 
@@ -2858,9 +3270,10 @@ class Tensor(SimpleMathTrait):
2858
3270
  ```
2859
3271
  """
2860
3272
  return self * self.sign()
2861
- def reciprocal(self):
3273
+
3274
+ def reciprocal(self) -> Tensor:
2862
3275
  """
2863
- Compute `1/x` element-wise.
3276
+ Computes `1/x` element-wise.
2864
3277
 
2865
3278
  ```python exec="true" source="above" session="tensor" result="python"
2866
3279
  print(Tensor([1., 2., 3., 4.]).reciprocal().numpy())
@@ -2870,11 +3283,10 @@ class Tensor(SimpleMathTrait):
2870
3283
 
2871
3284
  # ***** activation functions *****
2872
3285
 
2873
- def elu(self, alpha=1.0):
3286
+ def elu(self, alpha=1.0) -> Tensor:
2874
3287
  """
2875
3288
  Applies the Exponential Linear Unit (ELU) function element-wise.
2876
3289
 
2877
- - Described: https://paperswithcode.com/method/elu
2878
3290
  - Paper: https://arxiv.org/abs/1511.07289v5
2879
3291
 
2880
3292
  ```python exec="true" source="above" session="tensor" result="python"
@@ -2883,11 +3295,10 @@ class Tensor(SimpleMathTrait):
2883
3295
  """
2884
3296
  return self.relu() - alpha*(1-self.exp()).relu()
2885
3297
 
2886
- def celu(self, alpha=1.0):
3298
+ def celu(self, alpha=1.0) -> Tensor:
2887
3299
  """
2888
3300
  Applies the Continuously differentiable Exponential Linear Unit (CELU) function element-wise.
2889
3301
 
2890
- - Described: https://paperswithcode.com/method/celu
2891
3302
  - Paper: https://arxiv.org/abs/1704.07483
2892
3303
 
2893
3304
  ```python exec="true" source="above" session="tensor" result="python"
@@ -2896,11 +3307,10 @@ class Tensor(SimpleMathTrait):
2896
3307
  """
2897
3308
  return self.maximum(0) + (alpha * ((self / alpha).exp() - 1)).minimum(0)
2898
3309
 
2899
- def selu(self, alpha=1.67326, gamma=1.0507):
3310
+ def selu(self, alpha=1.67326, gamma=1.0507) -> Tensor:
2900
3311
  """
2901
3312
  Applies the Scaled Exponential Linear Unit (SELU) function element-wise.
2902
3313
 
2903
- - Described: https://paperswithcode.com/method/selu
2904
3314
  - Paper: https://arxiv.org/abs/1706.02515v5
2905
3315
 
2906
3316
  ```python exec="true" source="above" session="tensor" result="python"
@@ -2909,7 +3319,7 @@ class Tensor(SimpleMathTrait):
2909
3319
  """
2910
3320
  return gamma * (self >= 0).detach().where(self, alpha * (self.exp() - 1))
2911
3321
 
2912
- def swish(self):
3322
+ def swish(self) -> Tensor:
2913
3323
  """
2914
3324
  See `.silu()`
2915
3325
 
@@ -2921,11 +3331,10 @@ class Tensor(SimpleMathTrait):
2921
3331
  """
2922
3332
  return self * self.sigmoid()
2923
3333
 
2924
- def silu(self):
3334
+ def silu(self) -> Tensor:
2925
3335
  """
2926
3336
  Applies the Sigmoid Linear Unit (SiLU) function element-wise.
2927
3337
 
2928
- - Described: https://paperswithcode.com/method/silu
2929
3338
  - Paper: https://arxiv.org/abs/1606.08415
2930
3339
 
2931
3340
  ```python exec="true" source="above" session="tensor" result="python"
@@ -2934,11 +3343,10 @@ class Tensor(SimpleMathTrait):
2934
3343
  """
2935
3344
  return self.swish() # The SiLU function is also known as the swish function.
2936
3345
 
2937
- def relu6(self):
3346
+ def relu6(self) -> Tensor:
2938
3347
  """
2939
3348
  Applies the ReLU6 function element-wise.
2940
3349
 
2941
- - Described: https://paperswithcode.com/method/relu6
2942
3350
  - Paper: https://arxiv.org/abs/1704.04861v1
2943
3351
 
2944
3352
  ```python exec="true" source="above" session="tensor" result="python"
@@ -2947,11 +3355,10 @@ class Tensor(SimpleMathTrait):
2947
3355
  """
2948
3356
  return self.relu() - (self-6).relu()
2949
3357
 
2950
- def hardswish(self):
3358
+ def hardswish(self) -> Tensor:
2951
3359
  """
2952
3360
  Applies the Hardswish function element-wise.
2953
3361
 
2954
- - Described: https://paperswithcode.com/method/hard-swish
2955
3362
  - Paper: https://arxiv.org/abs/1905.02244v5
2956
3363
 
2957
3364
  ```python exec="true" source="above" session="tensor" result="python"
@@ -2960,7 +3367,7 @@ class Tensor(SimpleMathTrait):
2960
3367
  """
2961
3368
  return self * (self+3).relu6() * (1/6)
2962
3369
 
2963
- def tanh(self):
3370
+ def tanh(self) -> Tensor:
2964
3371
  """
2965
3372
  Applies the Hyperbolic Tangent (tanh) function element-wise.
2966
3373
 
@@ -2972,7 +3379,7 @@ class Tensor(SimpleMathTrait):
2972
3379
  """
2973
3380
  return 2.0 * ((2.0 * self).sigmoid()) - 1.0
2974
3381
 
2975
- def sinh(self):
3382
+ def sinh(self) -> Tensor:
2976
3383
  """
2977
3384
  Applies the Hyperbolic Sine (sinh) function element-wise.
2978
3385
 
@@ -2984,7 +3391,7 @@ class Tensor(SimpleMathTrait):
2984
3391
  """
2985
3392
  return (self.exp() - self.neg().exp()) / 2
2986
3393
 
2987
- def cosh(self):
3394
+ def cosh(self) -> Tensor:
2988
3395
  """
2989
3396
  Applies the Hyperbolic Cosine (cosh) function element-wise.
2990
3397
 
@@ -2996,7 +3403,7 @@ class Tensor(SimpleMathTrait):
2996
3403
  """
2997
3404
  return (self.exp() + self.neg().exp()) / 2
2998
3405
 
2999
- def atanh(self):
3406
+ def atanh(self) -> Tensor:
3000
3407
  """
3001
3408
  Applies the Inverse Hyperbolic Tangent (atanh) function element-wise.
3002
3409
 
@@ -3008,7 +3415,7 @@ class Tensor(SimpleMathTrait):
3008
3415
  """
3009
3416
  return ((1 + self)/(1 - self)).log() / 2
3010
3417
 
3011
- def asinh(self):
3418
+ def asinh(self) -> Tensor:
3012
3419
  """
3013
3420
  Applies the Inverse Hyperbolic Sine (asinh) function element-wise.
3014
3421
 
@@ -3020,7 +3427,7 @@ class Tensor(SimpleMathTrait):
3020
3427
  """
3021
3428
  return (self + (self.square() + 1).sqrt()).log()
3022
3429
 
3023
- def acosh(self):
3430
+ def acosh(self) -> Tensor:
3024
3431
  """
3025
3432
  Applies the Inverse Hyperbolic Cosine (acosh) function element-wise.
3026
3433
 
@@ -3032,19 +3439,17 @@ class Tensor(SimpleMathTrait):
3032
3439
  """
3033
3440
  return (self + (self.square() - 1).sqrt()).log()
3034
3441
 
3035
- def hardtanh(self, min_val=-1, max_val=1):
3442
+ def hardtanh(self, min_val=-1, max_val=1) -> Tensor:
3036
3443
  """
3037
3444
  Applies the Hardtanh function element-wise.
3038
3445
 
3039
- - Described: https://paperswithcode.com/method/hardtanh-activation
3040
-
3041
3446
  ```python exec="true" source="above" session="tensor" result="python"
3042
3447
  print(Tensor([-1.5, -1.0, -0.5, 0., 0.5, 1.0, 1.5]).hardtanh().numpy())
3043
3448
  ```
3044
3449
  """
3045
3450
  return self.clip(min_val, max_val)
3046
3451
 
3047
- def erf(self):
3452
+ def erf(self) -> Tensor:
3048
3453
  """
3049
3454
  Applies error function element-wise.
3050
3455
 
@@ -3058,11 +3463,10 @@ class Tensor(SimpleMathTrait):
3058
3463
  t = 1.0 / (1.0 + 0.3275911 * self.abs())
3059
3464
  return self.sign() * (1.0 - t * polyN(t, [1.061405429, -1.453152027, 1.421413741, -0.284496736, 0.254829592]) * (-self.square()).exp())
3060
3465
 
3061
- def gelu(self):
3466
+ def gelu(self) -> Tensor:
3062
3467
  """
3063
3468
  Applies the Gaussian Error Linear Unit (GELU) function element-wise.
3064
3469
 
3065
- - Described: https://paperswithcode.com/method/gelu
3066
3470
  - Paper: https://arxiv.org/abs/1606.08415v5
3067
3471
 
3068
3472
  ```python exec="true" source="above" session="tensor" result="python"
@@ -3071,38 +3475,33 @@ class Tensor(SimpleMathTrait):
3071
3475
  """
3072
3476
  return 0.5 * self * (1 + (math.sqrt(2 / math.pi) * (self + 0.044715 * self ** 3)).tanh())
3073
3477
 
3074
- def quick_gelu(self):
3478
+ def quick_gelu(self) -> Tensor:
3075
3479
  """
3076
3480
  Applies the Sigmoid GELU approximation element-wise.
3077
3481
 
3078
- - Described: https://paperswithcode.com/method/gelu
3079
-
3080
3482
  ```python exec="true" source="above" session="tensor" result="python"
3081
3483
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).quick_gelu().numpy())
3082
3484
  ```
3083
3485
  """
3084
3486
  return self * (self * 1.702).sigmoid()
3085
3487
 
3086
- def leakyrelu(self, neg_slope=0.01):
3488
+ def leaky_relu(self, neg_slope=0.01) -> Tensor:
3087
3489
  """
3088
3490
  Applies the Leaky ReLU function element-wise.
3089
3491
 
3090
- - Described: https://paperswithcode.com/method/leaky-relu
3091
-
3092
3492
  ```python exec="true" source="above" session="tensor" result="python"
3093
- print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).leakyrelu().numpy())
3493
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).leaky_relu().numpy())
3094
3494
  ```
3095
3495
  ```python exec="true" source="above" session="tensor" result="python"
3096
- print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).leakyrelu(neg_slope=0.42).numpy())
3496
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).leaky_relu(neg_slope=0.42).numpy())
3097
3497
  ```
3098
3498
  """
3099
- return self.relu() - (-neg_slope*self).relu()
3499
+ return (self<0).where(neg_slope*self, self)
3100
3500
 
3101
- def mish(self):
3501
+ def mish(self) -> Tensor:
3102
3502
  """
3103
3503
  Applies the Mish function element-wise.
3104
3504
 
3105
- - Described: https://paperswithcode.com/method/mish
3106
3505
  - Paper: https://arxiv.org/abs/1908.08681v3
3107
3506
 
3108
3507
  ```python exec="true" source="above" session="tensor" result="python"
@@ -3111,24 +3510,21 @@ class Tensor(SimpleMathTrait):
3111
3510
  """
3112
3511
  return self * self.softplus().tanh()
3113
3512
 
3114
- def softplus(self, beta=1):
3513
+ def softplus(self, beta=1.0, threshold=20.0) -> Tensor:
3115
3514
  """
3116
3515
  Applies the Softplus function element-wise.
3117
-
3118
- - Described: https://paperswithcode.com/method/softplus
3516
+ For numerical stability, the implementation falls back to the identity function when `self * beta > threshold`.
3119
3517
 
3120
3518
  ```python exec="true" source="above" session="tensor" result="python"
3121
3519
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).softplus().numpy())
3122
3520
  ```
3123
3521
  """
3124
- return (1/beta) * (1 + (self*beta).exp()).log()
3522
+ return (self * beta > threshold).where(self, (1/beta) * (1 + (self*beta).exp()).log())
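Why the threshold fold above matters: for large `self * beta` the exponential overflows, while softplus(x) is already indistinguishable from x itself. A plain-math sketch (the function names are illustrative):

```python
# For x*beta above the threshold, softplus(x) ~= x, so returning x avoids exp() overflow.
import math

def softplus_naive(x, beta=1.0): return (1.0 / beta) * math.log(1.0 + math.exp(x * beta))
def softplus_thresholded(x, beta=1.0, threshold=20.0):
    return x if x * beta > threshold else softplus_naive(x, beta)

print(softplus_naive(20.0), softplus_thresholded(20.0))   # both ~= 20.000000002
print(softplus_thresholded(1000.0))                       # 1000.0; the naive form overflows
```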
3125
3523
 
3126
- def softsign(self):
3524
+ def softsign(self) -> Tensor:
3127
3525
  """
3128
3526
  Applies the Softsign function element-wise.
3129
3527
 
3130
- - Described: https://paperswithcode.com/method/softsign
3131
-
3132
3528
  ```python exec="true" source="above" session="tensor" result="python"
3133
3529
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).softsign().numpy())
3134
3530
  ```
@@ -3144,9 +3540,10 @@ class Tensor(SimpleMathTrait):
3144
3540
  # for each dimension, check either dim is 1, or it does not change
3145
3541
  if not all(resolve(s == ns) or resolve(s == 1) for s,ns in zip(shape, new_shape)):
3146
3542
  raise ValueError(f"cannot broadcast {self.shape} to {new_shape=}")
3147
- return self.reshape(shape)._apply_uop(UOp.expand, arg=new_shape)
3543
+ # NOTE: this cast is no-op in forward and uses sum_acc_dtype in the backward sum
3544
+ return self.reshape(shape).cast(sum_acc_dtype(self.dtype))._apply_uop(UOp.expand, arg=new_shape).cast(self.dtype)
3148
3545
 
3149
- def _broadcasted(self, y:Union[Tensor, UOp, ConstType], reverse:bool=False, match_dtype:bool=True) -> tuple[Tensor, Tensor]:
3546
+ def _broadcasted(self, y:Tensor|ConstType|UOp, reverse:bool=False, match_dtype:bool=True) -> tuple[Tensor, Tensor]:
3150
3547
  x: Tensor = self
3151
3548
  if not isinstance(y, Tensor):
3152
3549
  # make y a Tensor
@@ -3165,27 +3562,7 @@ class Tensor(SimpleMathTrait):
3165
3562
  # broadcast
3166
3563
  return x._broadcast_to(out_shape:=_broadcast_shape(x.shape, y.shape)), y._broadcast_to(out_shape)
3167
3564
 
3168
- def add(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
3169
- """
3170
- Adds `self` and `x`.
3171
- Equivalent to `self + x`.
3172
- Supports broadcasting to a common shape, type promotion, and integer, float, boolean inputs.
3173
-
3174
- ```python exec="true" source="above" session="tensor" result="python"
3175
- Tensor.manual_seed(42)
3176
- t = Tensor.randn(4)
3177
- print(t.numpy())
3178
- ```
3179
- ```python exec="true" source="above" session="tensor" result="python"
3180
- print(t.add(20).numpy())
3181
- ```
3182
- ```python exec="true" source="above" session="tensor" result="python"
3183
- print(t.add(Tensor([[2.0], [3.5]])).numpy())
3184
- ```
3185
- """
3186
- return self._apply_broadcasted_uop(UOp.add, x, reverse)
3187
-
3188
- def sub(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
3565
+ def sub(self, x:Tensor|ConstType, reverse=False) -> Tensor:
3189
3566
  """
3190
3567
  Subtracts `x` from `self`.
3191
3568
  Equivalent to `self - x`.
@@ -3206,40 +3583,7 @@ class Tensor(SimpleMathTrait):
3206
3583
  a, b = self._broadcasted(x, reverse)
3207
3584
  return a + (-b)
3208
3585
 
3209
- def mul(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
3210
- """
3211
- Multiplies `self` and `x`.
3212
- Equivalent to `self * x`.
3213
- Supports broadcasting to a common shape, type promotion, and integer, float, boolean inputs.
3214
-
3215
- ```python exec="true" source="above" session="tensor" result="python"
3216
- Tensor.manual_seed(42)
3217
- t = Tensor.randn(4)
3218
- print(t.numpy())
3219
- ```
3220
- ```python exec="true" source="above" session="tensor" result="python"
3221
- print(t.mul(3).numpy())
3222
- ```
3223
- ```python exec="true" source="above" session="tensor" result="python"
3224
- print(t.mul(Tensor([[-1.0], [2.0]])).numpy())
3225
- ```
3226
- """
3227
- return self._apply_broadcasted_uop(UOp.mul, x, reverse)
3228
-
3229
- def idiv(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
3230
- """
3231
- Divides `self` by `x`.
3232
- Equivalent to `self // x`.
3233
- Supports broadcasting to a common shape, type promotion, and integer inputs.
3234
- `idiv` performs integer division (truncate towards zero).
3235
-
3236
- ```python exec="true" source="above" session="tensor" result="python"
3237
- print(Tensor([-4, 7, 5, 4, -7, 8]).idiv(Tensor([2, -3, 8, -2, 3, 5])).numpy())
3238
- ```
3239
- """
3240
- return self._apply_broadcasted_uop(UOp.idiv, x, reverse)
3241
-
3242
- def div(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
3586
+ def div(self, x:Tensor|ConstType, reverse=False, rounding_mode:Literal["trunc", "floor"]|None=None) -> Tensor:
3243
3587
  """
3244
3588
  Divides `self` by `x`.
3245
3589
  Equivalent to `self / x`.
@@ -3259,9 +3603,21 @@ class Tensor(SimpleMathTrait):
3259
3603
  ```
3260
3604
  """
3261
3605
  numerator, denominator = self._broadcasted(x, reverse)
3262
- return numerator.cast(least_upper_float(numerator.dtype)) * denominator.cast(least_upper_float(denominator.dtype)).reciprocal()
3263
-
3264
- def mod(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
3606
+ d = numerator.cast(least_upper_float(numerator.dtype)) * denominator.cast(least_upper_float(denominator.dtype)).reciprocal()
3607
+ output_dtype = numerator.dtype if dtypes.is_int(numerator.dtype) else d.dtype
3608
+ if dtypes.is_int(dt:=least_upper_dtype(numerator.dtype, denominator.dtype)) and rounding_mode is not None:
3609
+ numerator, denominator = numerator.cast(dt), denominator.cast(dt)
3610
+ if rounding_mode == "trunc": return numerator.idiv(denominator)
3611
+ if rounding_mode == "floor":
3612
+ truncate_div, truncate_mod = numerator.idiv(denominator), numerator._apply_broadcasted_uop(UOp.mod, denominator)
3613
+ opposite_sign = ((numerator>0)&(denominator<0)) | ((numerator<0)&(denominator>0))
3614
+ return (opposite_sign&(truncate_mod!=0)).where(truncate_div-1, truncate_div)
3615
+ if rounding_mode == "trunc": return d.trunc().cast(output_dtype)
3616
+ if rounding_mode == "floor": return d.floor().cast(output_dtype)
3617
+ if rounding_mode is not None: raise RuntimeError(f"{rounding_mode=} is not supported")
3618
+ return d
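A worked example of the two integer rounding modes handled above, in plain Python. Truncation rounds toward zero, flooring rounds toward negative infinity, and the `opposite_sign & (truncate_mod != 0)` correction is exactly what turns the first into the second; the rewritten `mod` further below is then `a - floor_div * b`, matching Python's `%`.

```python
# Integer division: trunc (toward zero) vs floor (toward -inf); they differ only when
# the operands have opposite signs and the division is inexact.
def div_trunc(a, b):
    q = abs(a) // abs(b)
    return q if (a >= 0) == (b >= 0) else -q

def div_floor(a, b):
    q = div_trunc(a, b)
    r = a - q * b
    opposite_sign = (a > 0 and b < 0) or (a < 0 and b > 0)
    return q - 1 if opposite_sign and r != 0 else q

for a, b in [(7, 2), (-7, 2), (7, -2), (-7, -2)]:
    print(a, b, div_trunc(a, b), div_floor(a, b), a // b)   # floor column matches Python's //
```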
3619
+
3620
+ def mod(self, x:Tensor|ConstType, reverse=False) -> Tensor:
3265
3621
  """
3266
3622
  Mod `self` by `x`.
3267
3623
  Equivalent to `self % x`.
@@ -3272,57 +3628,11 @@ class Tensor(SimpleMathTrait):
3272
3628
  ```
3273
3629
  """
3274
3630
  a, b = self._broadcasted(x, reverse)
3275
- return (r := a._apply_uop(UOp.mod, b)) + b * (((r < 0) & (b > 0)) | ((r > 0) & (b < 0)))
3276
-
3277
- def xor(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
3278
- """
3279
- Computes bitwise xor of `self` and `x`.
3280
- Equivalent to `self ^ x`.
3281
- Supports broadcasting to a common shape, type promotion, and integer, boolean inputs.
3282
-
3283
- ```python exec="true" source="above" session="tensor" result="python"
3284
- print(Tensor([-1, -2, 3]).xor(Tensor([1, 0, 3])).numpy())
3285
- ```
3286
- ```python exec="true" source="above" session="tensor" result="python"
3287
- print(Tensor([True, True, False, False]).xor(Tensor([True, False, True, False])).numpy())
3288
- ```
3289
- """
3290
- if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
3291
- return self._apply_broadcasted_uop(UOp.xor, x, reverse)
3292
-
3293
- def bitwise_and(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
3294
- """
3295
- Compute the bit-wise AND of `self` and `x`.
3296
- Equivalent to `self & x`.
3297
- Supports broadcasting to a common shape, type promotion, and integer, boolean inputs.
3298
- ```python exec="true" source="above" session="tensor" result="python"
3299
- print(Tensor([2, 5, 255]).bitwise_and(Tensor([3, 14, 16])).numpy())
3300
- ```
3301
- ```python exec="true" source="above" session="tensor" result="python"
3302
- print(Tensor([True, True, False, False]).bitwise_and(Tensor([True, False, True, False])).numpy())
3303
- ```
3304
- """
3305
- if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
3306
- return self._apply_broadcasted_uop(UOp.bitwise_and, x, reverse)
3307
-
3308
- def bitwise_or(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
3309
- """
3310
- Compute the bit-wise OR of `self` and `x`.
3311
- Equivalent to `self | x`.
3312
- Supports broadcasting to a common shape, type promotion, and integer, boolean inputs.
3313
- ```python exec="true" source="above" session="tensor" result="python"
3314
- print(Tensor([2, 5, 255]).bitwise_or(Tensor([4, 4, 4])).numpy())
3315
- ```
3316
- ```python exec="true" source="above" session="tensor" result="python"
3317
- print(Tensor([True, True, False, False]).bitwise_or(Tensor([True, False, True, False])).numpy())
3318
- ```
3319
- """
3320
- if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
3321
- return self._apply_broadcasted_uop(UOp.bitwise_or, x, reverse)
3631
+ return a - a.div(b, rounding_mode="floor") * b
3322
3632
 
3323
3633
  def bitwise_not(self) -> Tensor:
3324
3634
  """
3325
- Compute the bit-wise NOT of `self`.
3635
+ Computes the bitwise NOT of `self`.
3326
3636
  Equivalent to `~self`.
3327
3637
  ```python exec="true" source="above" session="tensor" result="python"
3328
3638
  print(Tensor([0, 2, 5, 255], dtype="int8").bitwise_not().numpy())
@@ -3334,7 +3644,7 @@ class Tensor(SimpleMathTrait):
3334
3644
  if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
3335
3645
  return self.logical_not() if self.dtype == dtypes.bool else self ^ -1
3336
3646
 
3337
- def lshift(self, x:int):
3647
+ def lshift(self, x:int, reverse=False) -> Tensor:
3338
3648
  """
3339
3649
  Computes left arithmetic shift of `self` by `x` bits. `self` must have unsigned dtype.
3340
3650
  Equivalent to `self << x`.
@@ -3343,10 +3653,10 @@ class Tensor(SimpleMathTrait):
3343
3653
  print(Tensor([1, 3, 31], dtype=dtypes.uint8).lshift(2).numpy())
3344
3654
  ```
3345
3655
  """
3346
- assert dtypes.is_unsigned(self.dtype) and isinstance(x, int) and x >= 0, f"not supported {self.dtype=} {x=}"
3347
- return self.mul(2 ** x)
3656
+ assert dtypes.is_unsigned(self.dtype) and isinstance(x, int) and x >= 0 and not reverse, f"not supported {self.dtype=} {x=}"
3657
+ return self.mul(2 ** x, reverse)
3348
3658
 
3349
- def rshift(self, x:int):
3659
+ def rshift(self, x:int, reverse=False) -> Tensor:
3350
3660
  """
3351
3661
  Computes right arithmetic shift of `self` by `x` bits. `self` must have unsigned dtype.
3352
3662
  Equivalent to `self >> x`.
@@ -3355,10 +3665,10 @@ class Tensor(SimpleMathTrait):
3355
3665
  print(Tensor([4, 13, 125], dtype=dtypes.uint8).rshift(2).numpy())
3356
3666
  ```
3357
3667
  """
3358
- assert dtypes.is_unsigned(self.dtype) and isinstance(x, int) and x >= 0, f"not supported {self.dtype=} {x=}"
3359
- return self.idiv(2 ** x)
3668
+ assert dtypes.is_unsigned(self.dtype) and isinstance(x, int) and x >= 0 and not reverse, f"not supported {self.dtype=} {x=}"
3669
+ return self.idiv(2 ** x, reverse)
3360
3670
 
3361
- def pow(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
3671
+ def pow(self, x:Tensor|ConstType, reverse=False) -> Tensor:
3362
3672
  """
3363
3673
  Computes power of `self` with `x`.
3364
3674
  Equivalent to `self ** x`.
@@ -3375,13 +3685,13 @@ class Tensor(SimpleMathTrait):
3375
3685
  """
3376
3686
  base, exponent = self._broadcasted(x, reverse=reverse)
3377
3687
  # TODO: int pow
3378
- if not base.is_floating_point(): raise RuntimeError("base needs to be float")
3688
+ if not base.is_floating_point() and not (isinstance(x, int) and x >= 0): raise RuntimeError("base needs to be float")
3379
3689
 
3380
- # NOTE: pow(int, float) -> int
3381
3690
  ret = base._apply_uop(UOp.pow, exponent)
3382
- return ret.round().cast(self.dtype) if not dtypes.is_float(self.dtype) else ret
3691
+ # NOTE: pow(int, float) -> int
3692
+ return ret.round().cast(self.dtype) if not reverse and not dtypes.is_float(self.dtype) and dtypes.is_float(exponent.dtype) else ret
3383
3693
 
3384
- def maximum(self, x:Union[Tensor, ConstType]) -> Tensor:
3694
+ def maximum(self, x:Tensor|ConstType) -> Tensor:
3385
3695
  """
3386
3696
  Computes element-wise maximum of `self` and `x`.
3387
3697
 
@@ -3394,7 +3704,7 @@ class Tensor(SimpleMathTrait):
3394
3704
  """
3395
3705
  return self._apply_broadcasted_uop(UOp.maximum, x)
3396
3706
 
3397
- def minimum(self, x:Union[Tensor, ConstType]) -> Tensor:
3707
+ def minimum(self, x:Tensor|ConstType) -> Tensor:
3398
3708
  """
3399
3709
  Computes element-wise minimum of `self` and `x`.
3400
3710
 
@@ -3408,9 +3718,9 @@ class Tensor(SimpleMathTrait):
3408
3718
  t, x = self._broadcasted(x)
3409
3719
  return t._inverse().maximum(x._inverse())._inverse()
3410
3720
 
3411
- def where(self:Tensor, x:Union[Tensor, ConstType, sint], y:Union[Tensor, ConstType, sint]):
3721
+ def where(self:Tensor, x:Tensor|ConstType|sint, y:Tensor|ConstType|sint) -> Tensor:
3412
3722
  """
3413
- Return a tensor of elements selected from either `x` or `y`, depending on `self`.
3723
+ Returns a tensor of elements selected from either `x` or `y`, depending on `self`.
3414
3724
  `output_i = x_i if self_i else y_i`.
3415
3725
 
3416
3726
  ```python exec="true" source="above" session="tensor" result="python"
@@ -3432,14 +3742,22 @@ class Tensor(SimpleMathTrait):
3432
3742
  cond, y = cond._broadcasted(y, match_dtype=False)
3433
3743
  return cond.cast(dtypes.bool)._apply_uop(UOp.where, *x._broadcasted(y))
3434
3744
 
3435
- def masked_fill(self:Tensor, mask:Tensor, value:Union[Tensor, ConstType]): return mask.where(value, self)
3745
+ def copysign(self, other) -> Tensor:
3746
+ """
3747
+ Returns a tensor of with the magnitude of `self` and the sign of `other`, elementwise.
3748
+ """
3749
+ # NOTE: torch always return in float, we return based on the broadcasting rule.
3750
+ other = self._broadcasted(other)[1]
3751
+ # TODO: remove other*0?
3752
+ return (other < 0).where(-self.abs(), self.abs()) + other*0
3436
3753
 
3437
3754
  # ***** op wrappers *****
3438
3755
 
3439
3756
  def __invert__(self) -> Tensor: return self.bitwise_not()
3440
3757
 
3441
- def __lshift__(self, x) -> Tensor: return self.lshift(x)
3442
- def __rshift__(self, x) -> Tensor: return self.rshift(x)
3758
+ # TODO: combine with UOps __floordiv__
3759
+ def __floordiv__(self, x): return self.div(x, rounding_mode="floor")
3760
+ def __rfloordiv__(self, x): return self.div(x, rounding_mode="floor", reverse=True)
3443
3761
 
3444
3762
  def __pow__(self, x) -> Tensor: return self.pow(x)
3445
3763
  def __matmul__(self, x) -> Tensor: return self.matmul(x)
@@ -3452,11 +3770,11 @@ class Tensor(SimpleMathTrait):
3452
3770
  def __imul__(self, x) -> Tensor: return self.assign(self.mul(x))
3453
3771
  def __ipow__(self, x) -> Tensor: return self.assign(self.pow(x))
3454
3772
  def __itruediv__(self, x) -> Tensor: return self.assign(self.div(x))
3455
- def __ifloordiv__(self, x) -> Tensor: return self.assign(self.idiv(x))
3773
+ def __ifloordiv__(self, x) -> Tensor: return self.assign(self.__floordiv__(x))
3456
3774
  def __imatmul__(self, x) -> Tensor: return self.assign(self.matmul(x))
3457
3775
  def __iand__(self, x) -> Tensor: return self.assign(self.bitwise_and(x))
3458
3776
  def __ior__(self, x) -> Tensor: return self.assign(self.bitwise_or(x))
3459
- def __ixor__(self, x) -> Tensor: return self.assign(self.xor(x))
3777
+ def __ixor__(self, x) -> Tensor: return self.assign(self.bitwise_xor(x))
3460
3778
  def __ilshift__(self, x) -> Tensor: return self.assign(self.lshift(x))
3461
3779
  def __irshift__(self, x) -> Tensor: return self.assign(self.rshift(x))
3462
3780
 
@@ -3468,7 +3786,7 @@ class Tensor(SimpleMathTrait):
3468
3786
 
3469
3787
  # ***** functional nn ops *****
3470
3788
 
3471
- def linear(self, weight:Tensor, bias:Optional[Tensor]=None):
3789
+ def linear(self, weight:Tensor, bias:Tensor|None=None, dtype:DTypeLike|None=None) -> Tensor:
3472
3790
  """
3473
3791
  Applies a linear transformation to `self` using `weight` and `bias`.
3474
3792
 
@@ -3481,10 +3799,11 @@ class Tensor(SimpleMathTrait):
3481
3799
  print(t.linear(weight, bias).numpy())
3482
3800
  ```
3483
3801
  """
3802
+ if dtype is not None: return self.cast(dtype).linear(weight.cast(dtype), bias.cast(dtype) if bias is not None else bias)
3484
3803
  x = self.mul(weight) if len(weight.shape) == 1 else self.dot(weight)
3485
3804
  return x.add(bias) if bias is not None else x
3486
3805
 
3487
- def sequential(self, ll:list[Callable[[Tensor], Tensor]]):
3806
+ def sequential(self, ll:list[Callable[[Tensor], Tensor]]) -> Tensor:
3488
3807
  """
3489
3808
  Applies a sequence of functions to `self` chaining the output of each function to the input of the next.
3490
3809
 
@@ -3495,11 +3814,10 @@ class Tensor(SimpleMathTrait):
3495
3814
  """
3496
3815
  return functools.reduce(lambda x,f: f(x), ll, self)
3497
3816
 
3498
- def layernorm(self, axis:Union[int,tuple[int,...]]=-1, eps:float=1e-5) -> Tensor:
3817
+ def layernorm(self, axis:int|tuple[int,...]=-1, eps:float=1e-5) -> Tensor:
3499
3818
  """
3500
3819
  Applies Layer Normalization over a mini-batch of inputs.
3501
3820
 
3502
- - Described: https://paperswithcode.com/method/layer-normalization
3503
3821
  - Paper: https://arxiv.org/abs/1607.06450v1
3504
3822
 
3505
3823
  ```python exec="true" source="above" session="tensor" result="python"
@@ -3514,11 +3832,10 @@ class Tensor(SimpleMathTrait):
3514
3832
  y = (self - self.mean(axis, keepdim=True))
3515
3833
  return y.mul((y*y).mean(axis, keepdim=True).add(eps).rsqrt())
3516
3834
 
3517
- def batchnorm(self, weight:Optional[Tensor], bias:Optional[Tensor], mean:Tensor, invstd:Tensor, axis:Union[int,tuple[int,...]]=1) -> Tensor:
3835
+ def batchnorm(self, weight:Tensor|None, bias:Tensor|None, mean:Tensor, invstd:Tensor, axis:int|tuple[int, ...]=1) -> Tensor:
3518
3836
  """
3519
3837
  Applies Batch Normalization over a mini-batch of inputs.
3520
3838
 
3521
- - Described: https://paperswithcode.com/method/batch-normalization
3522
3839
  - Paper: https://arxiv.org/abs/1502.03167
3523
3840
 
3524
3841
  ```python exec="true" source="above" session="tensor" result="python"
@@ -3543,7 +3860,6 @@ class Tensor(SimpleMathTrait):
3543
3860
 
3544
3861
  NOTE: dropout is only applied when `Tensor.training` is `True`.
3545
3862
 
3546
- - Described: https://paperswithcode.com/method/dropout
3547
3863
  - Paper: https://jmlr.org/papers/v15/srivastava14a.html
3548
3864
 
3549
3865
  ```python exec="true" source="above" session="tensor" result="python"
@@ -3553,11 +3869,13 @@ class Tensor(SimpleMathTrait):
3553
3869
  print(t.dropout().numpy())
3554
3870
  ```
3555
3871
  """
3872
+ if not 0 <= p <= 1: raise ValueError(f"{p=} is out of range [0, 1]")
3556
3873
  if not Tensor.training or p == 0: return self
3874
+ if p == 1: return self.zeros_like()
3557
3875
  return (Tensor.rand_like(self, requires_grad=False, dtype=dtypes.default_float, contiguous=False) >= p).contiguous().where(self, 0) / (1.0 - p)
3558
3876
 
3559
3877
  # helper function commonly used for indexing
3560
- def _one_hot_along_dim(self:Tensor, num_classes:sint, dim:int=-1):
3878
+ def _one_hot_along_dim(self:Tensor, num_classes:sint, dim:int=-1) -> Tensor:
3561
3879
  if not dtypes.is_int(self.dtype): raise RuntimeError(f"_one_hot_along_dim expects int index tensor, getting {self.dtype}")
3562
3880
  offset = self.ndim - self._resolve_dim(dim) - 1
3563
3881
  return self == Tensor.arange(num_classes, device=self.device, requires_grad=False).reshape((num_classes,) + (1,) * offset)
@@ -3577,12 +3895,12 @@ class Tensor(SimpleMathTrait):
3577
3895
  if num_classes == -1: num_classes = (self.max()+1).item()
3578
3896
  return self[..., None]._one_hot_along_dim(num_classes).where(1, 0)
3579
3897
 
3580
- def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Tensor|None=None, dropout_p:float=0.0, is_causal:bool=False) -> Tensor:
3898
+ def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Tensor|None=None, dropout_p:float=0.0,
3899
+ is_causal:bool=False, enable_gqa:bool=False) -> Tensor:
3581
3900
  """
3582
3901
  Computes scaled dot-product attention.
3583
3902
  `self` is the query tensor, `key` is the key tensor, and `value` is the value tensor.
3584
3903
 
3585
- - Described: https://paperswithcode.com/method/scaled
3586
3904
  - Paper: https://arxiv.org/abs/1706.03762v7
3587
3905
 
3588
3906
  ```python exec="true" source="above" session="tensor" result="python"
@@ -3594,7 +3912,11 @@ class Tensor(SimpleMathTrait):
3594
3912
  """
3595
3913
  # NOTE: it also works when `key` and `value` have symbolic shape.
3596
3914
  assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
3597
- qk = self.matmul(key.transpose(-2,-1), acc_dtype=least_upper_dtype(self.dtype, key.dtype, dtypes.float32)) / math.sqrt(self.shape[-1])
3915
+ # GQA: https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
3916
+ if enable_gqa:
3917
+ key = key.repeat_interleave(self.shape[-3] // key.shape[-3], dim=-3)
3918
+ value = value.repeat_interleave(self.shape[-3] // value.shape[-3], dim=-3)
3919
+ qk = self.matmul(key.transpose(-2,-1), dtype=least_upper_dtype(self.dtype, key.dtype, dtypes.float32)) / math.sqrt(self.shape[-1])
3598
3920
  # handle attention mask
3599
3921
  if is_causal:
3600
3922
  if attn_mask is not None: raise RuntimeError("cannot set attn_mask when is_causal=True")
@@ -3602,7 +3924,7 @@ class Tensor(SimpleMathTrait):
3602
3924
  if attn_mask is not None:
3603
3925
  if attn_mask.dtype == dtypes.bool: attn_mask = attn_mask.where(0, -float("inf"))
3604
3926
  qk = qk + attn_mask
3605
- return qk.softmax(-1).cast(self.dtype).dropout(dropout_p) @ value
3927
+ return qk.cast(self.dtype).softmax(-1).dropout(dropout_p) @ value
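The `enable_gqa` branch above implements grouped-query attention: there are fewer key/value heads than query heads, each KV head is shared by a contiguous group of query heads, and K and V are repeated along the head dimension (dim -3) to match before the usual attention. A small shape-bookkeeping sketch; the head counts are made up for illustration.

```python
# Grouped-query attention head mapping: each KV head serves n_q_heads // n_kv_heads query heads.
n_q_heads, n_kv_heads = 8, 2
group = n_q_heads // n_kv_heads                           # 4 query heads per KV head
kv_heads = ["kv0", "kv1"]
expanded = [h for h in kv_heads for _ in range(group)]    # what repeat_interleave does along heads
print(expanded)                  # ['kv0', 'kv0', 'kv0', 'kv0', 'kv1', 'kv1', 'kv1', 'kv1']
print(len(expanded) == n_q_heads)
```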
3606
3928
 
3607
3929
  def _do_reduction(self, reduction:ReductionStr="mean") -> Tensor:
3608
3930
  if reduction not in get_args(ReductionStr): raise ValueError(f"{reduction=} must be one of {get_args(ReductionStr)}")
@@ -3623,7 +3945,7 @@ class Tensor(SimpleMathTrait):
3623
3945
  """
3624
3946
  return (-Y*self.log() - (1-Y)*(1-self).log())._do_reduction(reduction)
3625
3947
 
3626
- def binary_crossentropy_logits(self, Y:Tensor, reduction:ReductionStr="mean") -> Tensor:
3948
+ def binary_crossentropy_logits(self, Y:Tensor, reduction:ReductionStr="mean", pos_weight:Tensor|None=None) -> Tensor:
3627
3949
  """
3628
3950
  Computes the binary cross-entropy loss between `self` and `Y` where `self` is logits.
3629
3951
 
@@ -3635,7 +3957,8 @@ class Tensor(SimpleMathTrait):
3635
3957
  print(t.binary_crossentropy_logits(Y).item())
3636
3958
  ```
3637
3959
  """
3638
- return (self.maximum(0) - Y * self + (1 + self.abs().neg().exp()).log())._do_reduction(reduction)
3960
+ log_p, log_1_minus_p = self.logsigmoid(), (-self).logsigmoid()
3961
+ return (-((1 if pos_weight is None else pos_weight) * Y * log_p + (1-Y) * log_1_minus_p))._do_reduction(reduction)
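The logits form above is the usual binary cross-entropy with an optional `pos_weight` that scales only the positive term, i.e. -(w*y*log(p) + (1-y)*log(1-p)) with p = sigmoid(x). A plain-math check of that formula on a tiny batch (mean reduction), for comparison only:

```python
# Direct-formula BCE with logits; pos_weight scales the y=1 term only.
import math

def sigmoid(x): return 1.0 / (1.0 + math.exp(-x))

def bce_logits(xs, ys, pos_weight=1.0):
    losses = [-(pos_weight * y * math.log(sigmoid(x)) + (1 - y) * math.log(1 - sigmoid(x)))
              for x, y in zip(xs, ys)]
    return sum(losses) / len(losses)          # "mean" reduction

print(bce_logits([-1.0, 2.0, 0.5], [0.0, 1.0, 1.0]))
print(bce_logits([-1.0, 2.0, 0.5], [0.0, 1.0, 1.0], pos_weight=2.0))   # positives weighted 2x
```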
3639
3962
 
3640
3963
  def sparse_categorical_crossentropy(self, Y:Tensor, ignore_index:int=-1, label_smoothing=0.0, reduction:ReductionStr="mean") -> Tensor:
3641
3964
  """
@@ -3653,10 +3976,10 @@ class Tensor(SimpleMathTrait):
3653
3976
  ```
3654
3977
  """
3655
3978
  assert 0.0 <= label_smoothing <= 1.0, "label_smoothing must be in [0.0, 1.0]"
3656
- assert reduction in ("mean", "sum", "none"), "reduction must be one of ['mean', 'sum', 'none']"
3657
- log_probs, loss_mask = self.log_softmax(), (Y != ignore_index) if ignore_index != -1 else Y.ones_like(dtype=dtypes.bool)
3658
- y_counted = Y.to(self.device).flatten().reshape(-1, 1)._one_hot_along_dim(self.shape[-1])
3659
- y = (y_counted * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
3979
+ assert reduction in get_args(ReductionStr), f"reduction must be one of {get_args(ReductionStr)}"
3980
+ log_probs = self.log_softmax()
3981
+ loss_mask = (Y != ignore_index) if ignore_index != -1 else Y.ones_like(dtype=dtypes.bool)
3982
+ y = Y.to(self.device).unsqueeze(-1)._one_hot_along_dim(self.shape[-1], dim=-1) * loss_mask.unsqueeze(-1)
3660
3983
  smoothing = label_smoothing * (log_probs.mean(-1) * loss_mask)
3661
3984
  unreduced = ((1 - label_smoothing) * (log_probs * y).sum(-1) + smoothing)
3662
3985
  # NOTE: because of ignore_index, we can't use Tensor.mean (so can't use `_do_reduction` here)
@@ -3664,7 +3987,7 @@ class Tensor(SimpleMathTrait):
3664
3987
 
3665
3988
  def cross_entropy(self, Y:Tensor, reduction:ReductionStr="mean", label_smoothing:float=0.0) -> Tensor:
3666
3989
  """
3667
- Compute the cross entropy loss between input logits and target.
3990
+ Computes the cross entropy loss between input logits and target.
3668
3991
 
3669
3992
  NOTE: `self` are logits and `Y` are the target labels or class probabilities.
3670
3993
 
@@ -3682,14 +4005,16 @@ class Tensor(SimpleMathTrait):
3682
4005
  ```
3683
4006
  """
3684
4007
  assert 0.0 <= label_smoothing <= 1.0, "label_smoothing must be in [0.0, 1.0]"
3685
- Y = Y.one_hot(num_classes=cast(int, self.shape[1])) if Y.ndim < 2 else Y
3686
- Y = (1 - label_smoothing)*Y + label_smoothing / cast(int, Y.shape[1])
3687
- ret = -self.log_softmax(axis=1).mul(Y).sum(axis=1)
3688
- return ret._do_reduction(reduction)
4008
+ classes_dim = 0 if self.ndim == 1 else 1
4009
+ if self.shape != Y.shape:
4010
+ if self.max(classes_dim).shape != Y.shape: raise RuntimeError(f"shape mismatch: {self.shape=}, {Y.shape=}")
4011
+ Y = Y.unsqueeze(classes_dim)._one_hot_along_dim(num_classes=self.shape[classes_dim], dim=classes_dim)
4012
+ Y = (1 - label_smoothing)*Y + label_smoothing / int(Y.shape[classes_dim])
4013
+ return -self.log_softmax(classes_dim).mul(Y).sum(classes_dim)._do_reduction(reduction)
3689
4014
 
3690
- def nll_loss(self, Y:Tensor, weight:Optional[Tensor]=None, ignore_index:Optional[int]=None, reduction:ReductionStr="mean") -> Tensor:
4015
+ def nll_loss(self, Y:Tensor, weight:Tensor|None=None, ignore_index:int|None=None, reduction:ReductionStr="mean") -> Tensor:
3691
4016
  """
3692
- Compute the negative log likelihood loss between log-probabilities and target labels.
4017
+ Computes the negative log likelihood loss between log-probabilities and target labels.
3693
4018
 
3694
4019
  NOTE: `self` is log-probabilities and `Y` is the Y labels or class probabilities.
3695
4020
 
@@ -3711,6 +4036,87 @@ class Tensor(SimpleMathTrait):
3711
4036
  nll = -self.gather(1, Y.unsqueeze(1)).squeeze(1) * masked_weight
3712
4037
  return nll.sum() / masked_weight.sum() if reduction == "mean" else nll._do_reduction(reduction)
3713
4038
 
4039
+ def newton_schulz(self, steps:int, params:tuple[int, ...], eps:float=1.0e-7) -> Tensor:
4040
+ """
4041
+ Performs the Newton-Schulz algorithm with an odd polynomial. The degree of the polynomial depends on the number of `params`.
4042
+
4043
+ ```python exec="true" source="above" session="tensor" result="python"
4044
+ t = Tensor.randn(4, 4)
4045
+ print(t.newton_schulz(steps=5, params=(2,-1.5,0.5)).numpy())
4046
+ ```
4047
+ """
4048
+ assert self.ndim > 1, "NS only works for two or more dims"
4049
+ G = self / (self.square().sum(axis=(-2, -1), keepdim=True).sqrt() + eps)
4050
+ G = G.transpose(-2, -1) if self.shape[-2] > self.shape[-1] else G
4051
+ for _ in range(steps): G = sum(p * functools.reduce(lambda x, y: (y @ y.transpose(-2, -1)) @ x, [G]*i, G) for i,p in enumerate(params))
4052
+ return G.transpose(-2, -1) if self.shape[-2] > self.shape[-1] else G
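One way to read the loop above: each step applies the odd matrix polynomial sum_i p_i (G Gᵀ)^i G, which acts element-wise on the singular values of G (each s maps to p(s)) while leaving the singular vectors fixed, so repeated steps push the normalized singular values toward 1. A scalar sketch of that mapping, assuming the docstring's `params=(2, -1.5, 0.5)`:

```python
# The per-step polynomial, applied to a single singular value s in (0, 1].
def ns_poly(s, params=(2.0, -1.5, 0.5)):
    return sum(p * s ** (2 * i + 1) for i, p in enumerate(params))

s = 0.3                        # a singular value after the initial Frobenius normalization
for step in range(5):
    s = ns_poly(s)
    print(step, s)             # climbs toward the fixed point p(1) = 2 - 1.5 + 0.5 = 1
```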
4053
+
4054
+ def qr(self) -> tuple[Tensor, Tensor]:
4055
+ assert self.ndim > 1, f"expected two or more dimensions, got {self.ndim}"
4056
+ R = self.clone()
4057
+ b_shape, m, n = self.shape[0:self.ndim - 2], int(R.shape[-2]), int(R.shape[-1])
4058
+ Q = Tensor.eye(m, dtype = self.dtype).reshape((1,) * (len(self.shape) - 2) + 2 * (m,)).expand(b_shape + 2 * (m,)).contiguous()
4059
+ for i in range(int(min(m, n))):
4060
+ x = R[..., i:m, i]
4061
+ s = -x[..., 0].sign()
4062
+ u1 = x[..., 0] - s * x.square().sum(-1).sqrt()
4063
+ w = x.unsqueeze(-1) / u1.reshape(b_shape + 2 * (1,))
4064
+ w[..., 0, 0] = 1
4065
+ tau = (-s * u1 / x.square().sum(-1).sqrt()).reshape(b_shape + 2 * (1,)).expand(w.shape)
4066
+ R[..., i:m, :] = R[..., i:m, :] - (w * tau) @ (w.transpose(-2, -1) @ R[..., i:m, :])
4067
+ Q[..., :, i:m] = Q[..., :, i:m] - (Q[..., :, i:m] @ w) @ (tau.transpose(-2, -1) * w.transpose(-2, -1))
4068
+ return Q,R
4069
+
4070
+ def svd(self, full_matrices = True) -> tuple[Tensor, Tensor, Tensor]:
4071
+ #partial implementation of https://www.netlib.org/lapack/lawnspdf/lawn169.pdf , pg 26
4072
+ assert self.ndim > 1, f"expected two or more dimensions, got {self.ndim}"
4073
+ b_shape, m, n = self.shape[:-2], int(self.shape[-2]), int(self.shape[-1])
4074
+ #preprocess the matrix
4075
+ Q, R = (Tensor.qr(self) if m >= n else Tensor.qr(self.transpose(-2, -1)))
4076
+ num, q_num = int(min(m, n)), int(max(m, n))
4077
+ U = R.shrink(tuple([(0, self.shape[i]) for i in range(self.ndim - 2)] + [(0, num), (0, num)])).contiguous()
4078
+ V = Tensor.eye(num, dtype = self.dtype).reshape((1,) * (self.ndim - 2) + (num, num)).expand(b_shape + 2 * (num,)).contiguous()
4079
+ #prepare round robin pairing
4080
+ permute, inverse_permute = Tensor.arange(0, num, dtype = dtypes.int), Tensor.zeros(num, dtype = dtypes.int).contiguous()
4081
+ permute[num//2:num] = permute[num//2:num].flip(0)
4082
+ inverse_permute[permute] = Tensor.arange(num, dtype = dtypes.int)
4083
+ def one_round_jacobi(U, V,permute,inverse_permute):
4084
+ #pair all the columns
4085
+ V_permuted, runoff_V = (V[..., permute].split(num - 1, -1)) if num % 2 == 1 else (V[..., permute], None)
4086
+ V_left, V_right = V_permuted.split(num//2, -1)
4087
+ U_permuted, runoff_U = (U[..., permute].split(num - 1, -1)) if num % 2 == 1 else (U[..., permute], None)
4088
+ U_left, U_right = U_permuted.split(num//2, -1)
4089
+ #compute the jacobi rotations for each pairing
4090
+ gamma = (U_left * U_right).sum(-2).reshape(b_shape + (1, num//2))
4091
+ alpha, beta = U_permuted.square().sum(-2).unsqueeze(-2).split(num//2, -1)
4092
+ tau = (beta - alpha) / (2 * gamma)
4093
+ t = tau.sign() / (tau.abs() + (1 + tau.square()).sqrt())
4094
+ c = 1 / (1 + t.square()).sqrt()
4095
+ s = c * t
4096
+ #apply the rotations
4097
+ U_left, U_right = c * U_left - s * U_right, s * U_left + c * U_right
4098
+ U = U_left.cat(U_right.cat(runoff_U, dim = -1) if num % 2 == 1 else U_right, dim = -1)[..., inverse_permute]
4099
+ V_left, V_right = c * V_left - s * V_right, s * V_left + c * V_right
4100
+ V = V_left.cat(V_right.cat(runoff_V, dim = -1) if num % 2 == 1 else V_right, dim = -1)[..., inverse_permute]
4101
+ #prepare the next round robin pairings
4102
+ if num % 2 == 1: permute = ((permute - 1) % num)
4103
+ else: permute = permute[0].reshape(1).cat(((permute[1:num] - 2) % (num - 1)) + 1)
4104
+ inverse_permute = inverse_permute.scatter(0,permute,Tensor.arange(num,dtype=dtypes.int32))
4105
+ return U, V, permute, inverse_permute
4106
+ max_iterations, iterations_per_round = 1, int((num) * math.log2(num) * 2 + 2)#sorta heuristic, most use num*log2(num)
4107
+ for _ in range(max_iterations * iterations_per_round): U, V, permute, inverse_permute = one_round_jacobi(U, V, permute, inverse_permute)
4108
+ #extract singular values and sort. construct U from Q
4109
+ S, indices = U.square().sum(-2).sqrt().sort(dim = -1, descending=True)
4110
+ new_indices = Tensor.arange(num).reshape((1,) * (self.ndim - 1) + (num,)).expand(b_shape + 2 * (num,)).contiguous()
4111
+ new_indices[..., :num] = indices.reshape(b_shape + (1,) + (num,)).expand(b_shape + 2 * (num,))
4112
+ U,V = U.gather(-1, new_indices[...,0:num,0:num]) / S.unsqueeze(-2), V.gather(-1, new_indices[..., 0:num, 0:num]).realize()
4113
+
4114
+ padded_u = Tensor.eye(q_num, dtype = U.dtype).reshape((1,) * (self.ndim - 2) + 2 * (q_num,)).expand(b_shape + 2 * (q_num,)).contiguous()
4115
+ padded_u[..., 0:num, 0:num] = U
4116
+ U = Q @ padded_u
4117
+ if not full_matrices: U, V = U[..., 0:num], V[..., 0:num]
4118
+ return (U, S, V.transpose(-2,-1)) if m >= n else (V, S, U.transpose(-2, -1))
4119
+
3714
4120
  # ***** Tensor Properties *****
3715
4121
 
3716
4122
  @property
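The new `newton_schulz`, `qr`, and `svd` methods can be sanity-checked by reconstructing their input. Reading the loop in `newton_schulz`, one step with `params=(a, b, c)` computes `a*G + b*(G@G.T)@G + c*(G@G.T)@(G@G.T)@G`, which pushes the singular values of the Frobenius-normalized input toward 1. The block below is an illustrative sketch written against tinygrad 0.11.0; it is not part of the diff.

```python
import numpy as np
from tinygrad import Tensor

A = Tensor.randn(5, 3)

# QR: Q is (5, 5) and orthogonal, R is (5, 3) and (numerically) upper triangular
Q, R = A.qr()
print(np.abs((Q @ R).numpy() - A.numpy()).max())                        # ~0: Q @ R reconstructs A

# SVD (reduced): U is (5, 3), S is (3,), Vt is (3, 3)
U, S, Vt = A.svd(full_matrices=False)
print(np.abs((U @ (S.unsqueeze(-1) * Vt)).numpy() - A.numpy()).max())   # ~0: U diag(S) Vt reconstructs A

# newton_schulz approximately orthogonalizes: the columns of the result are near-orthonormal
O = A.newton_schulz(steps=10, params=(2, -1.5, 0.5))
print(np.abs((O.transpose(-2, -1) @ O).numpy() - np.eye(3)).max())      # small
```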
@@ -3760,8 +4166,8 @@ class Tensor(SimpleMathTrait):
 
   def is_floating_point(self) -> bool:
     """
-    Returns `True` if the tensor contains floating point types, i.e. is one of `dtype.float64`, `dtype.float32`,
-    `dtype.float16`, `dtype.bfloat16`.
+    Returns `True` if the tensor contains floating point types, i.e. is one of `dtypes.float64`, `dtypes.float32`,
+    `dtypes.float16`, `dtypes.bfloat16`.
 
     ```python exec="true" source="above" session="tensor" result="python"
     t = Tensor([8, 9], dtype=dtypes.float32)
@@ -3770,9 +4176,9 @@ class Tensor(SimpleMathTrait):
     """
     return dtypes.is_float(self.dtype)
 
-  def size(self, dim:Optional[int]=None) -> Union[sint, tuple[sint, ...]]:
+  def size(self, dim:int|None=None) -> sint|tuple[sint, ...]:
     """
-    Return the size of the tensor. If `dim` is specified, return the length along dimension `dim`. Otherwise return the shape of the tensor.
+    Returns the size of the tensor. If `dim` is specified, return the length along dimension `dim`. Otherwise return the shape of the tensor.
 
     ```python exec="true" source="above" session="tensor" result="python"
     t = Tensor([[4, 5, 6], [7, 8, 9]])
@@ -3786,7 +4192,7 @@ class Tensor(SimpleMathTrait):
 
   # ***** cast ops *****
 
-  def llvm_bf16_cast(self, dtype:DTypeLike):
+  def llvm_bf16_cast(self, dtype:DTypeLike) -> Tensor:
     # hack for devices that don't support bfloat16
     assert self.dtype == dtypes.bfloat16
     return self.to("LLVM").cast(dtype)
@@ -3834,7 +4240,10 @@ class Tensor(SimpleMathTrait):
     if (not isinstance(self.device, str) or not self.device.startswith("DISK")) and ns != os:
       new_uint, old_uint = to_dtype(f"uint{8*ns}"), to_dtype(f"uint{8*os}")
       tmp = self.bitcast(old_uint)
-      if ns > os: return functools.reduce(Tensor.add, (tmp[..., i::ns//os].cast(new_uint) << 8*i*os for i in range(ns//os))).bitcast(dtype)
+      if ns > os:
+        tmp = tmp.reshape(self.shape[:-1] + (self.shape[-1]//(rate := ns//os), rate))
+        nones = (None,) * (tmp.ndim - 1)
+        return functools.reduce(Tensor.add, (tmp.shrink(nones + ((i, i+1),)).cast(new_uint)<<8*i*os for i in range(rate))).squeeze(-1).bitcast(dtype)
       return Tensor.stack(*(tmp>>8*i*ns for i in range(os//ns)), dim=-1).flatten(-2).cast(new_uint).bitcast(dtype)
     return self._apply_uop(UOp.bitcast, dtype=dt) if self.dtype != dt else self
 
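The widening branch of `bitcast` (new element size > old element size) now reshapes the last axis into groups of `rate = ns//os` elements and sums the byte-shifted groups instead of strided indexing. A quick sanity check of that path, assuming a little-endian device buffer; this example is illustrative and not part of the diff:

```python
from tinygrad import Tensor, dtypes

t = Tensor([1, 0, 0, 0, 0, 1, 0, 0], dtype=dtypes.uint8)  # last dim must be divisible by rate=4
print(t.bitcast(dtypes.uint32).numpy())                   # [  1 256]: each group of 4 bytes packs into one uint32
```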
@@ -3898,9 +4307,14 @@ class Tensor(SimpleMathTrait):
     """
     return self.cast(dtypes.bool)
 
+  def bfloat16(self) -> Tensor: return self.cast(dtypes.bfloat16)
+  def double(self) -> Tensor: return self.cast(dtypes.double)
+  def long(self) -> Tensor: return self.cast(dtypes.long)
+  def short(self) -> Tensor: return self.cast(dtypes.short)
+
   # *** image Tensor function replacements ***
 
-  def image_dot(self, w:Tensor, acc_dtype:Optional[DTypeLike]=None) -> Tensor:
+  def image_dot(self, w:Tensor, dtype:DTypeLike|None=None) -> Tensor:
     # NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
     x, dx, dw = self, self.ndim, w.ndim
     if not (dx > 0 and dw > 0): raise RuntimeError(f"both tensors need to be at least 1D, got {dx}D and {dw}D")
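This hunk adds `Tensor.bfloat16/double/long/short` as thin shortcuts over `Tensor.cast`, alongside the existing `bool` helper shown above, and renames the `acc_dtype` keyword of the image ops to `dtype`. A small illustrative example of the cast shortcuts (not from the diff):

```python
from tinygrad import Tensor, dtypes

t = Tensor([1.0, 2.0])
print(t.double().dtype, t.long().dtype, t.short().dtype)  # dtypes.double dtypes.long dtypes.short
print(t.long().numpy())                                   # [1 2]
```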
@@ -3914,9 +4328,9 @@ class Tensor(SimpleMathTrait):
     cx = self.transpose(self.ndim-1, self.ndim-2).reshape((bs//groups, groups*cin, -1, 1))
     # groups*cout x cin x H, W
     cw = w.transpose(w.ndim-1, w.ndim-2).reshape((groups*cout, cin, 1, 1))
-    return cx.image_conv2d(cw, groups=groups, acc_dtype=acc_dtype).reshape(out_shape_t).transpose(self.ndim-1, self.ndim-2)
+    return cx.image_conv2d(cw, groups=groups, dtype=dtype).reshape(out_shape_t).transpose(self.ndim-1, self.ndim-2)
 
-  def image_conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, acc_dtype=None) -> Tensor:
+  def image_conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding=0, dtype=None) -> Tensor:
     base_image_type = dtypes.imageh if getenv("FLOAT16", 0) else dtypes.imagef
 
     (bs,_,iy,ix), (cout,cin,H,W) = self.shape, weight.shape
@@ -3965,7 +4379,7 @@ class Tensor(SimpleMathTrait):
     w = w.permute(0,4,2,5,1,3).reshape((1, 1, 1, *cout_expand, rcin_hi, rcin_lo, H, W))
 
     # the conv!
-    ret = (x*w).cast(base_image_type((bs*oy, ox*cout//4, 4)) if IMAGE >= 2 else dtypes.float32).sum((-4, -3, -2, -1), acc_dtype=acc_dtype)
+    ret = (x*w).cast(base_image_type((bs*oy, ox*cout//4, 4)) if IMAGE >= 2 else dtypes.float32).sum((-4, -3, -2, -1), dtype=dtype)
 
     # undo hack for non multiples of 4 on C.rcout
     if added_output_channels != 0:
@@ -3976,8 +4390,20 @@ class Tensor(SimpleMathTrait):
     ret = ret.reshape(bs, oy, ox, cout).permute(0,3,1,2)
     return ret if bias is None else ret.add(bias.reshape(1, -1, 1, 1))
 
-def _metadata_wrapper(fn):
-  def _wrapper(*args, **kwargs):
+P = ParamSpec("P")
+T = TypeVar("T")
+
+# this tracks the tensor.py METADATA, contextvars.ContextVar was switched to this due to thread safety issues
+class _ContextVar(Generic[T]):
+  def __init__(self, default:T): self.state:T = default
+  def get(self) -> T: return self.state
+  def set(self, x:T) -> T:
+    ret, self.state = self.state, x
+    return ret
+_METADATA: _ContextVar[Metadata|None] = _ContextVar(default=None)
+
+def _metadata_wrapper(fn: Callable[P, T]) -> Callable[P, T]:
+  def _wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
     if _METADATA.get() is not None: return fn(*args, **kwargs)
 
     if TRACEMETA >= 2:
@@ -4004,7 +4430,7 @@ def _metadata_wrapper(fn):
 
     token = _METADATA.set(Metadata(name=fn.__name__, caller=caller))
     ret = fn(*args, **kwargs)
-    _METADATA.reset(token)
+    _METADATA.set(token)
     return ret
   return _wrapper
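The metadata tracking switches from `contextvars.ContextVar` to a small holder whose `set` returns the previous value, so the former `reset(token)` becomes a second `set(token)`. A standalone re-implementation of that save/restore pattern for illustration; the real class lives in tinygrad/tensor.py and stores `Metadata` objects, not strings:

```python
from typing import Generic, TypeVar

T = TypeVar("T")

class _ContextVar(Generic[T]):
  def __init__(self, default:T): self.state:T = default
  def get(self) -> T: return self.state
  def set(self, x:T) -> T:
    ret, self.state = self.state, x  # returns the previous value, which doubles as the restore "token"
    return ret

_METADATA: _ContextVar[str|None] = _ContextVar(default=None)

token = _METADATA.set("relu")  # enter a traced op: stash whatever was active before
assert _METADATA.get() == "relu"
_METADATA.set(token)           # leave: restore the previous value (what ContextVar.reset(token) used to do)
assert _METADATA.get() is None
```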