tinygrad 0.10.0__py3-none-any.whl → 0.10.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. tinygrad/codegen/kernel.py +114 -172
  2. tinygrad/codegen/linearize.py +211 -81
  3. tinygrad/codegen/lowerer.py +30 -35
  4. tinygrad/codegen/{uopgraph.py → rewriter.py} +69 -59
  5. tinygrad/codegen/transcendental.py +12 -13
  6. tinygrad/device.py +170 -47
  7. tinygrad/dtype.py +28 -26
  8. tinygrad/engine/jit.py +80 -63
  9. tinygrad/engine/memory.py +4 -5
  10. tinygrad/engine/multi.py +162 -0
  11. tinygrad/engine/realize.py +58 -107
  12. tinygrad/engine/schedule.py +381 -314
  13. tinygrad/engine/search.py +40 -44
  14. tinygrad/gradient.py +70 -0
  15. tinygrad/helpers.py +77 -58
  16. tinygrad/nn/__init__.py +30 -32
  17. tinygrad/nn/datasets.py +1 -2
  18. tinygrad/nn/optim.py +22 -26
  19. tinygrad/nn/state.py +89 -64
  20. tinygrad/ops.py +562 -446
  21. tinygrad/renderer/__init__.py +79 -36
  22. tinygrad/renderer/cstyle.py +70 -84
  23. tinygrad/renderer/llvmir.py +32 -20
  24. tinygrad/renderer/ptx.py +79 -99
  25. tinygrad/renderer/wgsl.py +87 -0
  26. tinygrad/runtime/autogen/amd_gpu.py +39507 -12
  27. tinygrad/runtime/autogen/comgr.py +2 -0
  28. tinygrad/runtime/autogen/kfd.py +4 -3
  29. tinygrad/runtime/autogen/kgsl.py +1 -1
  30. tinygrad/runtime/autogen/libpciaccess.py +2023 -0
  31. tinygrad/runtime/autogen/llvm.py +11379 -0
  32. tinygrad/runtime/autogen/vfio.py +891 -0
  33. tinygrad/runtime/graph/cuda.py +8 -9
  34. tinygrad/runtime/graph/hcq.py +84 -79
  35. tinygrad/runtime/graph/metal.py +19 -21
  36. tinygrad/runtime/ops_amd.py +488 -327
  37. tinygrad/runtime/ops_clang.py +15 -28
  38. tinygrad/runtime/ops_cloud.py +34 -34
  39. tinygrad/runtime/ops_cuda.py +30 -27
  40. tinygrad/runtime/ops_disk.py +62 -63
  41. tinygrad/runtime/ops_dsp.py +129 -38
  42. tinygrad/runtime/ops_gpu.py +30 -30
  43. tinygrad/runtime/ops_hip.py +29 -31
  44. tinygrad/runtime/ops_llvm.py +45 -40
  45. tinygrad/runtime/ops_metal.py +93 -73
  46. tinygrad/runtime/ops_npy.py +2 -2
  47. tinygrad/runtime/ops_nv.py +232 -270
  48. tinygrad/runtime/ops_python.py +51 -46
  49. tinygrad/runtime/ops_qcom.py +129 -157
  50. tinygrad/runtime/ops_webgpu.py +63 -0
  51. tinygrad/runtime/support/allocator.py +94 -0
  52. tinygrad/runtime/support/am/__init__.py +0 -0
  53. tinygrad/runtime/support/am/amdev.py +384 -0
  54. tinygrad/runtime/support/am/ip.py +463 -0
  55. tinygrad/runtime/support/compiler_cuda.py +4 -2
  56. tinygrad/runtime/support/elf.py +26 -4
  57. tinygrad/runtime/support/hcq.py +254 -324
  58. tinygrad/runtime/support/llvm.py +32 -0
  59. tinygrad/shape/shapetracker.py +84 -53
  60. tinygrad/shape/view.py +103 -138
  61. tinygrad/spec.py +154 -0
  62. tinygrad/tensor.py +744 -496
  63. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/METADATA +32 -21
  64. tinygrad-0.10.1.dist-info/RECORD +86 -0
  65. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/WHEEL +1 -1
  66. tinygrad/engine/lazy.py +0 -228
  67. tinygrad/function.py +0 -212
  68. tinygrad/multi.py +0 -177
  69. tinygrad/runtime/graph/clang.py +0 -39
  70. tinygrad-0.10.0.dist-info/RECORD +0 -77
  71. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/LICENSE +0 -0
  72. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/top_level.txt +0 -0
tinygrad/tensor.py CHANGED
@@ -1,47 +1,53 @@
1
1
  # inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
2
2
  from __future__ import annotations
3
- import time, math, itertools, functools, struct, sys, inspect, pathlib, string, dataclasses, hashlib
3
+ import time, math, itertools, functools, struct, sys, inspect, pathlib, string, hashlib, weakref
4
4
  from contextlib import ContextDecorator
5
- from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Dict, DefaultDict, cast, get_args, Literal
6
- from collections import defaultdict
7
-
5
+ from typing import List, Tuple, Callable, Optional, ClassVar, Union, Sequence, cast, get_args, Literal, TYPE_CHECKING, SupportsIndex
8
6
  from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
9
7
  from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
10
- from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN
11
- from tinygrad.multi import MultiLazyBuffer
12
- from tinygrad.ops import smax, smin, resolve, UOp, Ops, sint, Variable, SimpleMathTrait
13
- from tinygrad.device import Device, Buffer, BufferOptions
14
- from tinygrad.engine.lazy import LazyBuffer
8
+ from tinygrad.helpers import IMAGE, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap
9
+ from tinygrad.engine.multi import get_multi_map
10
+ from tinygrad.gradient import compute_gradient
11
+ from tinygrad.ops import smax, smin, resolve, UOp, Ops, sint, Variable, SimpleMathTrait, identity_element
12
+ from tinygrad.spec import tensor_uop_spec, type_verify
13
+ from tinygrad.device import Device, BufferSpec
15
14
  from tinygrad.engine.realize import run_schedule
16
15
  from tinygrad.engine.memory import memory_planner
17
16
  from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars
18
17
 
19
- # **** start with two base classes, Tensor and Function ****
18
+ # *** all in scope Tensors are here. this gets relevant UOps ***
20
19
 
21
- class Function:
22
- def __init__(self, device:Union[str, Tuple[str, ...]], *tensors:Tensor, metadata:Optional[Metadata]=None):
23
- self.device = device
24
- self.needs_input_grad = [t.requires_grad for t in tensors]
25
- self.requires_grad = True if any(self.needs_input_grad) else None if None in self.needs_input_grad else False
26
- if self.requires_grad: self.parents = tensors
27
- self.metadata = metadata
20
+ all_tensors: set[weakref.ref[Tensor]] = set()
28
21
 
29
- def forward(self, *args, **kwargs): raise NotImplementedError(f"forward not implemented for {type(self)}")
30
- def backward(self, *args, **kwargs): raise RuntimeError(f"backward not implemented for {type(self)}")
22
+ def _apply_map_to_tensors(applied_map:dict[UOp, UOp]) -> None:
23
+ # get all children of keys in applied_map
24
+ all_uops: set[UOp] = set()
25
+ search_uops = list(applied_map)
26
+ while len(search_uops):
27
+ x = search_uops.pop(0)
28
+ if x in all_uops: continue
29
+ all_uops.add(x)
30
+ search_uops.extend([u for c in x.children if (u:=c()) is not None])
31
31
 
32
- @classmethod
33
- def apply(fxn:Type[Function], *x:Tensor, **kwargs) -> Tensor:
34
- ctx = fxn(x[0].device, *x, metadata=_METADATA.get())
35
- ret = Tensor.__new__(Tensor)
36
- ret.lazydata, ret.requires_grad, ret.grad = ctx.forward(*[t.lazydata for t in x], **kwargs), ctx.requires_grad, None
37
- ret._ctx = ctx if ctx.requires_grad and not Tensor.no_grad else None # used by autograd engine
38
- return ret
32
+ # link the found UOps back to Tensors. exit early if there's no Tensors to realize
33
+ # NOTE: this uses all_tensors, but it's fast
34
+ fixed_tensors: list[Tensor] = [t for tref in all_tensors if (t:=tref()) is not None and t.lazydata in all_uops]
35
+
36
+ if len(fixed_tensors):
37
+ # potentially rewrite all the discovered Tensors
38
+ sink = UOp.sink(*[t.lazydata for t in fixed_tensors])
39
+ new_sink = sink.substitute(applied_map)
39
40
 
40
- import tinygrad.function as F
41
+ # set the relevant lazydata to the realized UOps
42
+ for t,s,ns in zip(fixed_tensors, sink.src, new_sink.src):
43
+ if s is ns: continue
44
+ t.lazydata = ns
41
45
 
42
- def _metaop(op, shape:Tuple[sint,...], dtype:DType, device:Union[str, Tuple[str, ...]], arg=None, src:Tuple[LazyBuffer, ...]=()):
43
- if isinstance(device, str): return LazyBuffer.metaop(op, shape, dtype, device, arg, src)
44
- return MultiLazyBuffer([LazyBuffer.metaop(op, shape, dtype, d, arg, src) for d in device], None)
46
+ # **** Tensor helper functions ****
47
+
48
+ def _metaop(op, shape:tuple[sint,...], dtype:DType, device:Union[str, tuple[str, ...]], arg=None):
49
+ if isinstance(device, str): return UOp.metaop(op, shape, dtype, device, arg)
50
+ return UOp.multi(*[UOp.metaop(op, shape, dtype, d, arg) for d in device], axis=None)
45
51
 
46
52
  def _from_np_dtype(npdtype:'np.dtype') -> DType: # type: ignore [name-defined] # noqa: F821
47
53
  import numpy as np
@@ -50,33 +56,31 @@ def _to_np_dtype(dtype:DType) -> Optional[type]:
50
56
  import numpy as np
51
57
  return np.dtype(dtype.fmt).type if dtype.fmt is not None else None
52
58
 
53
- def _fromnp(x: 'np.ndarray') -> LazyBuffer: # type: ignore [name-defined] # noqa: F821
54
- ret = LazyBuffer.metaop(Ops.EMPTY, x.shape, _from_np_dtype(x.dtype), "NPY")
59
+ def _fromnp(x: 'np.ndarray') -> UOp: # type: ignore [name-defined] # noqa: F821
60
+ ret = UOp.metaop(Ops.EMPTY, x.shape, _from_np_dtype(x.dtype), "NPY")
55
61
  # fake realize
56
62
  ret.buffer.allocate(x)
57
- del ret.srcs
58
63
  return ret
59
64
 
60
- def get_shape(x) -> Tuple[int, ...]:
65
+ def get_shape(x) -> tuple[int, ...]:
61
66
  # NOTE: str is special because __getitem__ on a str is still a str
62
67
  if not hasattr(x, "__len__") or not hasattr(x, "__getitem__") or isinstance(x, str) or (hasattr(x, "shape") and x.shape == ()): return ()
63
68
  if not all_same(subs:=[get_shape(xi) for xi in x]): raise ValueError(f"inhomogeneous shape from {x}")
64
69
  return (len(subs),) + (subs[0] if subs else ())
65
70
 
66
- def _frompy(x:Union[List, Tuple, bytes], dtype:DType) -> LazyBuffer:
67
- if isinstance(x, bytes): ret, data = LazyBuffer.metaop(Ops.EMPTY, (len(x)//dtype.itemsize,), dtype, "PYTHON"), x
71
+ def _frompy(x:Union[List, Tuple, bytes], dtype:DType) -> UOp:
72
+ if isinstance(x, bytes): ret, data = UOp.metaop(Ops.EMPTY, (len(x)//dtype.itemsize,), dtype, "PYTHON"), x
68
73
  else:
69
- ret = LazyBuffer.metaop(Ops.EMPTY, get_shape(x), dtype, "PYTHON")
74
+ ret = UOp.metaop(Ops.EMPTY, get_shape(x), dtype, "PYTHON")
70
75
  assert dtype.fmt is not None, f"{dtype=} has None fmt"
71
76
  truncate_function = truncate[dtype]
72
77
  data = struct.pack(f"@{ret.size}{dtype.fmt}", *[truncate_function(xi) for xi in fully_flatten(x)])
73
78
  # fake realize
74
79
  ret.buffer.allocate(memoryview(data if Device.DEFAULT != "PYTHON" else bytearray(data)))
75
- del ret.srcs
76
80
  return ret
77
81
 
78
- def _get_winograd_matcols(mat, dims:int, shp:Tuple[sint, ...], device:Union[str, Tuple[str, ...]]) -> List[List[Tensor]]:
79
- return [[Tensor.cat(*[Tensor.full(shp[:dim] + (1,) + shp[dim+1:], float(m[k]), device=device) for m in mat], dim=dim)
82
+ def _get_winograd_matcols(mat, dims:int, shp:tuple[sint, ...], device:Union[str, tuple[str, ...]], dtype:DType) -> list[list[Tensor]]:
83
+ return [[Tensor.cat(*[Tensor.full(shp[:dim] + (1,) + shp[dim+1:], float(m[k]), device=device, dtype=dtype) for m in mat], dim=dim)
80
84
  for k in range(len(mat[0]))] for dim in range(dims)]
81
85
 
82
86
  # winograd conv 3 kernel f(4x4,3x3) see: http://arxiv.org/abs/1509.09308
@@ -85,21 +89,34 @@ def _apply_winograd_matrix(mat, t:Tensor, dims:int) -> Tensor:
85
89
  # due to realize-before-expand rule in lazy.py, we must operate in this order: reshape -> expand -> arithmetic
86
90
  t_ = t.reshape(t.shape[:dims] + (1,) * dims + t.shape[dims:]).expand(t.shape[:dims] + (len(mat),) * dims + t.shape[dims:]) # add output dims
87
91
  # precalculate mat columns for each dim; prod(itertools.product(matcols)) gives the columns of kron(mat, mat, ...)
88
- matcols = _get_winograd_matcols(mat, dims, t_.shape[dims:], t_.device)
92
+ matcols = _get_winograd_matcols(mat, dims, t_.shape[dims:], t_.device, t_.dtype)
89
93
  # multiply each element of t_ by the corresponding stacked column of kron(mat, mat), producing only one view for each element of t
90
94
  ret = sum(prod(col[idx] for col, idx in zip(matcols, mat_is)) * t_[mat_is] for mat_is in itertools.product(range(len(mat[0])), repeat=dims))
91
95
  assert isinstance(ret, Tensor), "sum didn't return a Tensor"
92
96
  return ret
93
97
 
94
- def _pad_left(*shapes:Tuple[sint, ...]) -> Tuple[Tuple[sint, ...], ...]:
98
+ def _align_left(*shapes:tuple[sint, ...]) -> tuple[tuple[sint, ...], ...]:
99
+ # unsqueeze left to make every shape same length
95
100
  max_dim = max(len(shape) for shape in shapes)
96
101
  return tuple((1,) * (max_dim - len(shape)) + shape for shape in shapes)
97
- def _broadcast_shape(*shapes:Tuple[sint, ...]) -> Tuple[sint, ...]:
98
- return tuple(0 if 0 in nth_dim_sizes else smax(nth_dim_sizes) for nth_dim_sizes in zip(*_pad_left(*shapes)))
102
+ def _broadcast_shape(*shapes:tuple[sint, ...]) -> tuple[sint, ...]:
103
+ return tuple(0 if 0 in nth_dim_sizes else smax(nth_dim_sizes) for nth_dim_sizes in zip(*_align_left(*shapes)))
104
+
105
+ def _masked_setitem(target:Tensor, values:Tensor, mask:Tensor, axes:tuple[int, ...]):
106
+ # apply mask to values (already broadcasted) and reduce such that if mask contains repeated indices the last one remains
107
+ values = values * mask
108
+ for dim in axes: mask, values = functools.reduce(lambda x,y: (x[0]|y[0], y[0].where(y[1], x[1])), zip(mask.split(1, dim), values.split(1, dim)))
109
+ # remove extra dims from reduce
110
+ for dim in reversed(axes): mask, values = mask.squeeze(dim), values.squeeze(dim)
111
+ # select from values for each True element in mask else select from self
112
+ return mask.where(values, target)
113
+
114
+ # `(padding_left, padding_right, padding_top, padding_bottom, ...)` -> `(..., (padding_top, padding_bottom), (padding_left, padding_right))`
115
+ def _flat_to_grouped(padding:Sequence[sint]) -> tuple[tuple[sint, sint], ...]: return tuple(zip(padding[-2::-2], padding[::-2]))
99
116
 
100
117
  ReductionStr = Literal["mean", "sum", "none"]
101
118
 
102
- class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
119
+ class Tensor(SimpleMathTrait):
103
120
  """
104
121
  A `Tensor` is a multi-dimensional matrix containing elements of a single data type.
105
122
 
@@ -110,15 +127,19 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
110
127
  np.set_printoptions(precision=4)
111
128
  ```
112
129
  """
113
- __slots__ = "lazydata", "requires_grad", "grad", "_ctx"
114
- __deletable__ = ('_ctx',)
130
+ __slots__ = "lazydata", "requires_grad", "grad"
115
131
  training: ClassVar[bool] = False
116
132
  no_grad: ClassVar[bool] = False
117
133
 
118
- def __init__(self, data:Union[None, ConstType, UOp, bytes, List, Tuple, LazyBuffer, MultiLazyBuffer, 'np.ndarray', pathlib.Path], # type: ignore [name-defined] # noqa: F821
134
+ def __new__(cls, *args, **kwargs):
135
+ instance = super().__new__(cls)
136
+ all_tensors.add(weakref.ref(instance))
137
+ return instance
138
+ def __del__(self): all_tensors.discard(weakref.ref(self))
139
+
140
+ def __init__(self, data:Union[None, ConstType, bytes, List, Tuple, UOp, 'np.ndarray', pathlib.Path], # type: ignore [name-defined] # noqa: F821
119
141
  device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DTypeLike]=None, requires_grad:Optional[bool]=None):
120
142
  if dtype is not None: dtype = to_dtype(dtype)
121
- assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
122
143
  if device is None and isinstance(data, pathlib.Path): device = f"DISK:{data.resolve()}" # keep it on the disk if device is None
123
144
  device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device)
124
145
 
@@ -129,21 +150,18 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
129
150
  # None (the default) will be updated to True if it's put in an optimizer
130
151
  self.requires_grad: Optional[bool] = requires_grad
131
152
 
132
- # internal variable used for autograd graph construction
133
- self._ctx: Optional[Function] = None
134
-
135
153
  # create a LazyBuffer from the different types of inputs
136
- if isinstance(data, (LazyBuffer, MultiLazyBuffer)): assert dtype is None or dtype==data.dtype, "dtype doesn't match, and casting isn't supported"
154
+ if isinstance(data, UOp):
155
+ assert dtype is None or dtype==data.dtype, "dtype doesn't match, and casting isn't supported"
156
+ # NOTE: this is here because LazyBuffer = UOp
157
+ if isinstance(data, UOp) and data.op is Ops.BIND: data = _metaop(Ops.BIND, tuple(), dtype or data.dtype, device, data)
137
158
  elif data is None: data = _metaop(Ops.EMPTY, (0,), dtype or dtypes.default_float, device)
138
159
  elif isinstance(data, get_args(ConstType)): data = _metaop(Ops.CONST, tuple(), dtype or dtypes.from_py(data), device, data)
139
- elif isinstance(data, UOp):
140
- assert data.op is Ops.BIND and data.src[0].op is Ops.DEFINE_VAR and data.src[1].op is Ops.CONST, f"can't create tensor from UOp {data}"
141
- data = _metaop(Ops.CONST, tuple(), dtype or data.dtype, device, data)
142
160
  elif isinstance(data, bytes): data = _frompy(data, dtypes.uint8 if dtype is None else dtype)
143
161
  elif isinstance(data, (list, tuple)):
144
162
  if dtype is None:
145
163
  if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtypes.bool
146
- else: dtype = dtypes.default_int if d and all_int(d) else dtypes.default_float
164
+ else: dtype = dtypes.default_int if d and all_int(d) else dtypes.default_float # NOTE: this works because all_int([True, False]) is True
147
165
  if dtype == dtypes.bfloat16: data = Tensor(_frompy(data, dtypes.float32), device=device).cast(dtypes.bfloat16).lazydata
148
166
  else: data = _frompy(data, dtype)
149
167
  elif str(type(data)) == "<class 'numpy.ndarray'>":
@@ -156,16 +174,20 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
156
174
  data = _metaop(Ops.EMPTY, (data.stat().st_size // dtype.itemsize,), dtype, f"DISK:{data.resolve()}")
157
175
 
158
176
  # by this point, it has to be a LazyBuffer
159
- if not isinstance(data, (LazyBuffer, MultiLazyBuffer)): raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}")
177
+ if not isinstance(data, UOp): raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}")
160
178
 
161
179
  # data might be on a different device
162
- if isinstance(device, str): self.lazydata:Union[LazyBuffer, MultiLazyBuffer] = data if data.device == device else data.copy_to_device(device)
180
+ if isinstance(device, str): self.lazydata:UOp = data if data.device == device else data.copy_to_device(device)
163
181
  # if device is a tuple, we should have/construct a MultiLazyBuffer
164
- elif isinstance(data, LazyBuffer): self.lazydata = MultiLazyBuffer.from_sharded(data, device, None, None)
182
+ elif isinstance(data, UOp) and isinstance(data.device, str): self.lazydata = Tensor(data).shard(device).lazydata
165
183
  else:
166
184
  assert data.device == device, f"MultiLazyBuffer device mismatch, {data.device} != {device}"
167
185
  self.lazydata = data
168
186
 
187
+ def requires_grad_(self, requires_grad=True) -> Tensor:
188
+ self.requires_grad = requires_grad
189
+ return self
190
+
169
191
  class train(ContextDecorator):
170
192
  def __init__(self, mode:bool = True): self.mode = mode
171
193
  def __enter__(self): self.prev, Tensor.training = Tensor.training, self.mode
@@ -177,7 +199,9 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
177
199
  def __exit__(self, exc_type, exc_value, traceback): Tensor.no_grad = self.prev
178
200
 
179
201
  def __repr__(self):
180
- return f"<Tensor {self.lazydata!r} on {self.device} with grad {(self.grad.lazydata if self.grad is not None else None)!r}>"
202
+ ld = self.lazydata
203
+ ld_repr = f"<UOp {ld.device} {ld.shape} {str(ld.dtype)[7:]} {ld.st if ld.base is not ld else (ld.op, ld.realized)}>"
204
+ return f"<Tensor {ld_repr} on {self.device} with grad {(self.grad.lazydata if self.grad is not None else None)!r}>"
181
205
 
182
206
  # Python has a non moving GC, so this should be okay
183
207
  def __hash__(self): return id(self)
@@ -189,26 +213,49 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
189
213
  return self.shape[0]
190
214
 
191
215
  @property
192
- def device(self) -> Union[str, Tuple[str, ...]]: return self.lazydata.device
216
+ def device(self) -> Union[str, tuple[str, ...]]: return self.lazydata.device
193
217
 
194
218
  @property
195
- def shape(self) -> Tuple[sint, ...]: return self.lazydata.shape
219
+ def shape(self) -> tuple[sint, ...]: return self.lazydata.shape
196
220
 
197
221
  @property
198
222
  def dtype(self) -> DType: return self.lazydata.dtype
199
223
 
224
+ def _apply_uop(self, fxn:Callable, *x:Tensor, **kwargs) -> Tensor:
225
+ ret = Tensor.__new__(Tensor)
226
+ needs_input_grad = [t.requires_grad for t in (self,)+x]
227
+ ret.requires_grad, ret.grad = True if any(needs_input_grad) else None if None in needs_input_grad else False, None
228
+ ret.lazydata = fxn(*[t.lazydata for t in (self,)+x], **kwargs)
229
+ return ret
230
+
231
+ def _apply_broadcasted_uop(self, fxn:Callable, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
232
+ lhs,rhs = self._broadcasted(x, reverse)
233
+ return lhs._apply_uop(fxn, rhs)
234
+
200
235
  # ***** data handlers ****
201
236
 
202
- def schedule_with_vars(self, *lst:Tensor) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
237
+ def schedule_with_vars(self, *lst:Tensor) -> tuple[list[ScheduleItem], dict[Variable, int]]:
203
238
  """
204
239
  Creates the schedule needed to realize these Tensor(s), with Variables.
205
240
 
206
241
  NOTE: A Tensor can only be scheduled once.
207
242
  """
208
- schedule, var_vals = create_schedule_with_vars(flatten([x.lazydata.lbs for x in (self,)+lst]))
243
+ big_sink = UOp.sink(*[x.lazydata for x in (self,)+lst])
244
+
245
+ # TODO: move this to scheduler tensor_map pass
246
+ if any(x.op is Ops.MULTI for x in big_sink.toposort):
247
+ # multi fixup
248
+ _apply_map_to_tensors(get_multi_map(big_sink))
249
+ big_sink = UOp.sink(*flatten([x.lazydata.src if x.lazydata.op is Ops.MULTI else [x.lazydata] for x in (self,)+lst]))
250
+
251
+ # verify Tensors match the spec
252
+ if __debug__: type_verify(list(big_sink.toposort), tensor_uop_spec)
253
+
254
+ schedule, var_vals, becomes_map = create_schedule_with_vars(big_sink)
255
+ _apply_map_to_tensors(becomes_map)
209
256
  return memory_planner(schedule), var_vals
210
257
 
211
- def schedule(self, *lst:Tensor) -> List[ScheduleItem]:
258
+ def schedule(self, *lst:Tensor) -> list[ScheduleItem]:
212
259
  """Creates the schedule needed to realize these Tensor(s)."""
213
260
  schedule, var_vals = self.schedule_with_vars(*lst)
214
261
  assert len(var_vals) == 0
@@ -224,7 +271,6 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
224
271
  Replaces the data of this tensor with the data of another tensor. Only the shape of the tensors must match.
225
272
  """
226
273
  # used for replacing a Tensor with a new version of it (potentially with a different device and dtype)
227
- assert not x.requires_grad and getattr(self, '_ctx', None) is None
228
274
  assert self.shape == x.shape, f"replace shape mismatch {self.shape} != {x.shape}"
229
275
  self.lazydata = x.lazydata
230
276
  return self
@@ -232,17 +278,15 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
232
278
  def assign(self, x) -> Tensor:
233
279
  # TODO: this is a hack for writing to DISK. remove with working assign
234
280
  if isinstance(self.device, str) and self.device.startswith("DISK"):
235
- if x.__class__ is not Tensor: x = Tensor(x, device="NPY", dtype=self.dtype)
236
- self.contiguous().realize().lazydata.base.realized.copyin(x.numpy().data)
281
+ if x.__class__ is not Tensor: x = Tensor(x, device="CLANG", dtype=self.dtype)
282
+ self.contiguous().realize().lazydata.base.realized.ensure_allocated().copyin(x._data())
237
283
  return self
238
284
  if x.__class__ is not Tensor: x = Tensor(x, device=self.device, dtype=self.dtype)
239
- if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}")
240
285
  if self.lazydata is x.lazydata: return self # a self assign is a NOOP
241
286
  # NOTE: we allow cross device assign
242
287
  assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}"
243
288
  assert self.device == x.device, f"assign device mismatch {self.device} != {x.device}"
244
289
  assert self.dtype == x.dtype, f"assign dtype mismatch {self.dtype} != {x.dtype}"
245
- assert not isinstance(self.lazydata, MultiLazyBuffer) or self.lazydata.axis == x.lazydata.axis, "axis must match on MultiLazyBuffer"
246
290
  assert not x.requires_grad # self requires_grad is okay?
247
291
  if not self.lazydata.is_realized: return self.replace(x)
248
292
  self.lazydata = self.lazydata.assign(x.lazydata)
@@ -252,14 +296,15 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
252
296
  """
253
297
  Returns a new tensor with the same data as this tensor, but detached from the autograd graph.
254
298
  """
255
- return Tensor(self.lazydata, device=self.device, requires_grad=False)
299
+ return Tensor(self.lazydata.detach(), device=self.device, requires_grad=False)
256
300
 
257
301
  def _data(self) -> memoryview:
258
302
  if 0 in self.shape: return memoryview(bytearray(0))
259
303
  # NOTE: this realizes on the object from as_buffer being a Python object
260
304
  cpu = self.cast(self.dtype.base).contiguous().to("CLANG").realize()
261
- buf = cast(Buffer, cast(LazyBuffer, cpu.lazydata).base.realized)
262
- if self.device != "CLANG": buf.options = BufferOptions(nolru=True)
305
+ buf = cast(UOp, cpu.lazydata).base.realized
306
+ assert buf is not None, f"{cast(UOp, cpu.lazydata).base} was not realized"
307
+ if self.device != "CLANG": buf.options = BufferSpec(nolru=True)
263
308
  return buf.as_buffer(allow_zero_copy=True if self.device != "CLANG" else False)
264
309
 
265
310
  def data(self) -> memoryview:
@@ -271,9 +316,10 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
271
316
  print(np.frombuffer(t.data(), dtype=np.int32))
272
317
  ```
273
318
  """
274
- assert self.dtype.fmt is not None, f"no fmt dtype for {self.dtype}"
319
+ assert self.dtype.base.fmt is not None, f"no fmt dtype for {self.dtype.base}"
275
320
  assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
276
- return self._data().cast(self.dtype.fmt, self.shape)
321
+ if TYPE_CHECKING or sys.version_info < (3, 12): assert self.dtype.base.fmt != "e"
322
+ return cast(memoryview, self._data().cast(self.dtype.base.fmt) if 0 in self.shape else self._data().cast(self.dtype.base.fmt, self.shape))
277
323
 
278
324
  def item(self) -> ConstType:
279
325
  """
@@ -284,11 +330,10 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
284
330
  print(t.item())
285
331
  ```
286
332
  """
287
- assert self.dtype.fmt is not None, f"no fmt dtype for {self.dtype}"
288
333
  assert self.numel() == 1, "must have one element for item"
289
- return self._data().cast(self.dtype.fmt)[0]
334
+ return self.data()[(0,) * len(self.shape)]
290
335
 
291
- # TODO: should be Tensor.tolist() -> Union[List[ConstType], ConstType]. The List is Sequence because mypy expects memoryview.tolist() -> list[int]
336
+ # TODO: should be Tensor.tolist() -> Union[list[ConstType], ConstType]. The List is Sequence because mypy expects memoryview.tolist() -> list[int]
292
337
  # src: https://github.com/python/mypy/blob/release-1.6/mypy/typeshed/stdlib/builtins.pyi#L803
293
338
  def tolist(self) -> Union[Sequence[ConstType], ConstType]:
294
339
  """
@@ -311,21 +356,20 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
311
356
  ```
312
357
  """
313
358
  import numpy as np
314
- if self.dtype == dtypes.bfloat16: return self.float().numpy()
315
- assert _to_np_dtype(self.dtype) is not None, f"no np dtype for {self.dtype}"
359
+ if self.dtype.base == dtypes.bfloat16: return self.float().numpy()
360
+ assert _to_np_dtype(self.dtype.base) is not None, f"no np dtype for {self.dtype.base}"
316
361
  assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
317
- return np.frombuffer(self._data(), dtype=_to_np_dtype(self.dtype)).reshape(self.shape)
362
+ return np.frombuffer(self._data(), dtype=_to_np_dtype(self.dtype.base)).reshape(self.shape)
318
363
 
319
364
  def clone(self) -> Tensor:
320
365
  """
321
- Creates a clone of this tensor allocating a seperate buffer for the data.
366
+ Creates a clone of this tensor allocating a separate buffer for the data.
322
367
  """
323
368
  ret = Tensor(self.lazydata.clone(), self.device, requires_grad=self.requires_grad)
324
369
  if self.grad is not None: ret.grad = self.grad.clone()
325
- if hasattr(self, '_ctx'): ret._ctx = self._ctx
326
370
  return ret
327
371
 
328
- def to(self, device:Optional[Union[str, Tuple[str, ...]]]) -> Tensor:
372
+ def to(self, device:Optional[Union[str, tuple[str, ...]]]) -> Tensor:
329
373
  """
330
374
  Moves the tensor to the given device.
331
375
  """
@@ -334,47 +378,35 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
334
378
  if not isinstance(device, str): return self.shard(device)
335
379
  ret = Tensor(self.lazydata, device, requires_grad=self.requires_grad)
336
380
  if self.grad is not None: ret.grad = self.grad.to(device)
337
- if hasattr(self, '_ctx'): ret._ctx = self._ctx
338
381
  return ret
339
382
 
340
- def to_(self, device:Optional[Union[str, Tuple[str, ...]]]):
383
+ def to_(self, device:Optional[Union[str, tuple[str, ...]]]):
341
384
  """
342
385
  Moves the tensor to the given device in place.
343
386
  """
344
387
  real = self.to(device)
345
- # TODO: is this assign?
346
- if self.grad is not None and real.grad is not None: self.grad.lazydata = real.grad.lazydata
347
- self.lazydata = real.lazydata
388
+ if self.grad is not None and real.grad is not None: self.grad.replace(real.grad)
389
+ return self.replace(real)
348
390
 
349
- def shard(self, devices:Tuple[str, ...], axis:Optional[int]=None, splits:Optional[Tuple[int, ...]]=None) -> Tensor:
391
+ def shard(self, devices:tuple[str, ...], axis:Optional[int]=None) -> Tensor:
350
392
  """
351
- Shards the tensor across the given devices. Optionally specify which axis to shard on, and how to split it across devices.
393
+ Shards the tensor across the given devices. Optionally specify which axis to shard on.
352
394
 
353
395
  ```python exec="true" source="above" session="tensor" result="python"
354
- t = Tensor.empty(2, 3)
355
- print(t.shard((t.device, t.device), axis=1, splits=(2, 1)).lazydata)
396
+ t = Tensor.empty(2, 4)
397
+ print(t.shard((t.device, t.device), axis=1).lazydata)
356
398
  ```
357
-
358
399
  """
359
- assert isinstance(self.lazydata, LazyBuffer), "can't shard a MultiLazyBuffer"
360
- devices, bounds = tuple(Device.canonicalize(x) for x in devices), None
361
- if axis is not None:
362
- if axis < 0: axis += len(self.shape)
363
- if splits is None:
364
- if not isinstance(total:=self.shape[axis], int): raise RuntimeError(f"cannot shard symbolic shape {self.shape=}, {axis=}")
365
- sz = ceildiv(total, len(devices))
366
- splits = tuple([max(0, min(sz, total - sz*i)) for i in range(len(devices))])
367
- assert sum(splits) == self.shape[axis], "specified splits do not sum up to axis shape"
368
- boundaries = tuple(itertools.accumulate(splits))
369
- bounds = tuple(zip((0,) + boundaries, boundaries))
370
- return Tensor(MultiLazyBuffer.from_sharded(self.lazydata, devices, axis, bounds), device=devices, requires_grad=self.requires_grad)
400
+ assert isinstance(self.device, str), "can't shard a MultiLazyBuffer"
401
+ devices = tuple(Device.canonicalize(x) for x in devices)
402
+ mlb = self.lazydata.shard(devices, self._resolve_dim(axis) if axis is not None else None)
403
+ return Tensor(mlb, device=devices, requires_grad=self.requires_grad)
371
404
 
372
- def shard_(self, devices:Tuple[str, ...], axis:Optional[int]=None, splits:Optional[Tuple[int, ...]]=None):
405
+ def shard_(self, devices:tuple[str, ...], axis:Optional[int]=None):
373
406
  """
374
407
  Shards the tensor across the given devices in place.
375
408
  """
376
- self.lazydata = self.shard(devices, axis, splits).lazydata
377
- return self
409
+ return self.replace(self.shard(devices, axis))
378
410
 
379
411
  @staticmethod
380
412
  def from_uop(y:UOp, **kwargs) -> Tensor:
@@ -382,18 +414,17 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
382
414
  if y.op is Ops.CONST: return Tensor(y.arg, **kwargs, requires_grad=False)
383
415
  if y.op is Ops.MUL: return Tensor.from_uop(y.src[0]) * Tensor.from_uop(y.src[1])
384
416
  if y.op is Ops.ADD: return Tensor.from_uop(y.src[0]) + Tensor.from_uop(y.src[1])
385
- if y.op is Ops.MAX: return Tensor.from_uop(y.src[0]).maximum(Tensor.from_uop(y.src[1]))
386
417
  raise RuntimeError(f"unhandled UOp {y}")
387
418
 
388
419
  # ***** creation entrypoint *****
389
420
 
390
421
  @staticmethod
391
- def _metaop(op, shape, device:Optional[Union[Tuple[str, ...], str]]=None, dtype:Optional[DTypeLike]=None, arg=None, **kwargs):
422
+ def _metaop(op, shape, device:Optional[Union[tuple[str, ...], str]]=None, dtype:Optional[DTypeLike]=None, arg=None, **kwargs):
392
423
  dtype = to_dtype(dtype) if dtype is not None else dtypes.default_float
393
424
  if isinstance(device, tuple):
394
- return Tensor(MultiLazyBuffer([LazyBuffer.metaop(op, shape, dtype, Device.canonicalize(d), arg) for d in device], None),
425
+ return Tensor(UOp.multi(*[UOp.metaop(op, shape, dtype, Device.canonicalize(d), arg) for d in device], axis=None),
395
426
  device, dtype, **kwargs)
396
- return Tensor(LazyBuffer.metaop(op, shape, dtype, Device.canonicalize(device), arg), device, dtype, **kwargs)
427
+ return Tensor(UOp.metaop(op, shape, dtype, Device.canonicalize(device), arg), device, dtype, **kwargs)
397
428
 
398
429
  @staticmethod
399
430
  def empty(*shape, **kwargs):
@@ -411,7 +442,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
411
442
  return Tensor._metaop(Ops.EMPTY, argfix(*shape), **kwargs)
412
443
 
413
444
  @staticmethod
414
- def from_blob(ptr:int, shape:Tuple[int, ...], **kwargs) -> Tensor:
445
+ def from_blob(ptr:int, shape:tuple[int, ...], **kwargs) -> Tensor:
415
446
  """
416
447
  Exposes the pointer as a Tensor without taking ownership of the original data.
417
448
  The pointer must remain valid for the entire lifetime of the created Tensor.
@@ -422,7 +453,6 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
422
453
 
423
454
  r = Tensor._metaop(Ops.EMPTY, shape, **kwargs)
424
455
  r.lazydata.buffer.allocate(external_ptr=ptr)
425
- del r.lazydata.srcs # fake realize
426
456
  return r
427
457
 
428
458
  @staticmethod
@@ -439,8 +469,8 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
439
469
  return Tensor(fetch(url, gunzip=gunzip), **kwargs)
440
470
 
441
471
  _seed: int = int(time.time())
442
- _device_seeds: Dict[str, Tensor] = {}
443
- _device_rng_counters: Dict[str, Tensor] = {}
472
+ _device_seeds: dict[str, Tensor] = {}
473
+ _device_rng_counters: dict[str, Tensor] = {}
444
474
  @staticmethod
445
475
  def manual_seed(seed=0):
446
476
  """
@@ -462,7 +492,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
462
492
  @staticmethod
463
493
  def _threefry_random_bits(key:Tensor, counts0:Tensor, counts1:Tensor):
464
494
  x = (counts1.cast(dtypes.uint64) << 32) | counts0.cast(dtypes.uint64)
465
- x = F.Threefry.apply(x, (key[1]._broadcast_to(x.shape).cast(dtypes.uint64) << 32) | key[0]._broadcast_to(x.shape).cast(dtypes.uint64))
495
+ x = x._apply_uop(UOp.threefry, (key[1]._broadcast_to(x.shape).cast(dtypes.uint64) << 32) | key[0]._broadcast_to(x.shape).cast(dtypes.uint64))
466
496
  counts0, counts1 = (x & 0xffffffff).cast(dtypes.uint32), ((x >> 32) & 0xffffffff).cast(dtypes.uint32)
467
497
  return counts0.cat(counts1)
468
498
 
@@ -485,6 +515,10 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
485
515
  if device is not None and not isinstance(device, str): raise ValueError(f"rand only supports single device, got {device=}")
486
516
  _device = device = Device.canonicalize(device)
487
517
 
518
+ # if shape has 0, return zero tensor
519
+ if (numel := prod(shape)) == 0: return Tensor.zeros(shape, device=_device, dtype=dtype, **kwargs)
520
+ num = ceildiv(numel * dtype.itemsize, 4)
521
+
488
522
  # when using MOCKGPU and NV generate rand on CLANG
489
523
  if getenv("MOCKGPU") and device.startswith("NV"): device = "CLANG"
490
524
 
@@ -494,15 +528,8 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
494
528
  [int.from_bytes(hashlib.sha256(len(Tensor._device_seeds).to_bytes(4, "big")).digest(), "big"), Tensor._seed],
495
529
  device=device, dtype=dtypes.uint32, requires_grad=False)
496
530
  Tensor._device_rng_counters[device] = Tensor([0], device=device, dtype=dtypes.uint32, requires_grad=False)
497
- had_counter = False
498
- else: had_counter = True
499
-
500
- # if shape has 0, return zero tensor
501
- if (numel := prod(shape)) == 0: return Tensor.zeros(shape, device=_device, dtype=dtype, **kwargs)
502
- num = ceildiv(numel * dtype.itemsize, 4)
503
-
504
531
  # increment rng counter for devices
505
- if had_counter: Tensor._device_rng_counters[device].assign(Tensor._device_rng_counters[device] + num).contiguous()
532
+ else: Tensor._device_rng_counters[device].assign(Tensor._device_rng_counters[device] + num).contiguous()
506
533
 
507
534
  # threefry random bits
508
535
  counts0 = (Tensor.arange(ceildiv(num, 2), device=device, dtype=dtypes.uint32, requires_grad=False)+Tensor._device_rng_counters[device])
@@ -528,7 +555,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
528
555
  # ***** creation helper functions *****
529
556
 
530
557
  @staticmethod
531
- def full(shape:Tuple[sint, ...], fill_value:ConstType, **kwargs) -> Tensor:
558
+ def full(shape:tuple[sint, ...], fill_value:ConstType, **kwargs) -> Tensor:
532
559
  """
533
560
  Creates a tensor with the given shape, filled with the given value.
534
561
 
@@ -607,7 +634,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
607
634
  dtype = kwargs.pop("dtype", dtypes.default_float if any(isinstance(x, float) for x in (start, stop, step)) else dtypes.default_int)
608
635
  # NOTE: this matches numpy, torch raises RuntimeError if stop-start and step have different signs
609
636
  if (output_len:=ceildiv(stop-start, step)) <= 0: return Tensor([], dtype=dtype, **kwargs)
610
- return (Tensor.full((output_len,), step, dtype=dtype, **kwargs)._cumsum() + (start - step)).cast(dtype)
637
+ return (Tensor.full((output_len,), step, dtype=dtype, **kwargs)._cumalu(0, Ops.ADD) + (start - step)).cast(dtype)
611
638
 
612
639
  @staticmethod
613
640
  def linspace(start:Union[int, float], stop:Union[int, float], steps:int, **kwargs) -> Tensor:
@@ -705,18 +732,19 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
705
732
  ```
706
733
  """
707
734
  dtype = kwargs.pop("dtype", self.dtype)
708
- if isinstance(self.device, tuple) and isinstance(self.lazydata, MultiLazyBuffer):
735
+ if isinstance(self.device, tuple):
709
736
  if kwargs.get("device") is not None: raise RuntimeError("cannot specify `device` on `rand_like` of a multi device tensor")
710
737
  if self.lazydata.axis is None: return Tensor.rand(*self.shape, dtype=dtype, **kwargs).shard(self.device)
711
738
  contiguous = kwargs.pop("contiguous", True)
712
- rands = [Tensor.rand(*lb.shape, device=lb.device, dtype=dtype, contiguous=contiguous, **kwargs).lazydata for lb in self.lazydata.lbs]
713
- return Tensor(MultiLazyBuffer(cast(List[LazyBuffer], rands), self.lazydata.axis), device=self.device, dtype=dtype, **kwargs)
739
+ sharded_shape = tuple(s//len(self.device) if a==self.lazydata.axis else s for a,s in enumerate(self.shape))
740
+ rands = [Tensor.rand(sharded_shape, device=d, dtype=dtype, contiguous=contiguous, **kwargs).lazydata for d in self.device]
741
+ return Tensor(UOp.multi(*rands, axis=self.lazydata.axis), device=self.device, dtype=dtype, **kwargs)
714
742
  return Tensor.rand(*self.shape, device=kwargs.pop("device", self.device), dtype=dtype, **kwargs)
715
743
 
716
744
  # ***** rng hlops *****
717
745
 
718
746
  @staticmethod
719
- def randn(*shape, dtype:Optional[DTypeLike]=None, **kwargs) -> Tensor:
747
+ def randn(*shape, dtype:Optional[DTypeLike]=None, requires_grad:Optional[bool]=None, **kwargs) -> Tensor:
720
748
  """
721
749
  Creates a tensor with the given shape, filled with random values from a normal distribution with mean `0` and standard deviation `1`.
722
750
  If `dtype` is not specified, the default type is used.
@@ -731,10 +759,10 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
731
759
  """
732
760
  # https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
733
761
  src = Tensor.rand((2, *argfix(*shape)), **{**kwargs, "dtype": dtypes.float32})
734
- return src[0].mul(2*math.pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(dtype or dtypes.default_float)
762
+ return (src[0].mul(2*math.pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(dtype or dtypes.default_float)).requires_grad_(requires_grad)
735
763
 
736
764
  @staticmethod
737
- def randint(*shape, low=0, high=10, **kwargs) -> Tensor:
765
+ def randint(*shape, low=0, high=10, dtype=dtypes.int32, **kwargs) -> Tensor:
738
766
  """
739
767
  Creates a tensor with the given shape, filled with random integer values generated uniformly from the interval `[low, high)`.
740
768
  If `dtype` is not specified, the default type is used.
@@ -748,12 +776,12 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
748
776
  ```
749
777
  """
750
778
  if not isinstance(low, int) or not isinstance(high, int): raise TypeError(f"{low=} and {high=} must be integers")
751
- dtype = to_dtype(kwargs.pop("dtype", dtypes.int32))
779
+ dtype = to_dtype(dtype)
752
780
  if not dtypes.is_int(dtype): raise TypeError(f"{dtype=} must be int")
753
781
  return Tensor.uniform(*shape, low=low, high=high, dtype=dtype, **kwargs)
754
782
 
755
783
  @staticmethod
756
- def normal(*shape, mean=0.0, std=1.0, **kwargs) -> Tensor:
784
+ def normal(*shape, mean=0.0, std=1.0, requires_grad:Optional[bool]=None, **kwargs) -> Tensor:
757
785
  """
758
786
  Creates a tensor with the given shape, filled with random values from a normal distribution with the given `mean` and standard deviation `std`.
759
787
 
@@ -765,10 +793,10 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
765
793
  print(Tensor.normal(2, 3, mean=10, std=2).numpy())
766
794
  ```
767
795
  """
768
- return (std * Tensor.randn(*shape, **kwargs)) + mean
796
+ return ((std * Tensor.randn(*shape, **kwargs)) + mean).requires_grad_(requires_grad)
769
797
 
770
798
  @staticmethod
771
- def uniform(*shape, low=0.0, high=1.0, **kwargs) -> Tensor:
799
+ def uniform(*shape, low=0.0, high=1.0, dtype:Optional[DTypeLike]=None, requires_grad:Optional[bool]=None, **kwargs) -> Tensor:
772
800
  """
773
801
  Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[low, high)`.
774
802
 
@@ -780,8 +808,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
780
808
  print(Tensor.uniform(2, 3, low=2, high=10).numpy())
781
809
  ```
782
810
  """
783
- dtype = kwargs.pop("dtype", dtypes.default_float)
784
- return ((high-low) * Tensor.rand(*shape, **kwargs)).cast(dtype) + low
811
+ return (((high-low) * Tensor.rand(*shape, **kwargs)).cast(dtype or dtypes.default_float) + low).requires_grad_(requires_grad)
785
812
 
786
813
  @staticmethod
787
814
  def scaled_uniform(*shape, **kwargs) -> Tensor:
@@ -860,49 +887,52 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
860
887
 
861
888
  # ***** toposort and backward pass *****
862
889
 
863
- def _deepwalk(self):
864
- def _walk(node, visited):
865
- visited.add(node)
866
- # if tensor is not leaf, reset grad
867
- if (ctx := getattr(node, "_ctx", None)) is not None and len(ctx.parents) != 0: node.grad = None
868
- if ctx:
869
- for i in node._ctx.parents:
870
- if i not in visited: yield from _walk(i, visited)
871
- yield node
872
- return list(_walk(self, set()))
890
+ def gradient(self, *targets:Tensor, gradient:Optional[Tensor]=None, materialize_grads=False) -> list[Tensor]:
891
+ """
892
+ Compute the gradient of the targets with respect to self.
893
+
894
+ ```python exec="true" source="above" session="tensor" result="python"
895
+ x = Tensor.eye(3)
896
+ y = Tensor([[2.0,0,-2.0]])
897
+ z = y.matmul(x).sum()
898
+ dx, dy = z.gradient(x, y)
899
+
900
+ print(dx.tolist()) # dz/dx
901
+ print(dy.tolist()) # dz/dy
902
+ ```
903
+ """
904
+ assert gradient is not None or self.shape == tuple(), "when no gradient is provided, backward must be called on a scalar tensor"
905
+ if gradient is None: gradient = Tensor(1.0, dtype=self.dtype, device=self.device, requires_grad=False)
906
+ rets = []
907
+ target_uops = [x.lazydata for x in targets]
908
+ grads = compute_gradient(self.lazydata, gradient.lazydata, set(target_uops))
909
+ ret = []
910
+ for x in target_uops:
911
+ if (y:=grads.get(x)) is None:
912
+ if materialize_grads: y = x.const_like(0)
913
+ else: raise RuntimeError(f"{x}\n\nnot found in\n\n{self.lazydata}")
914
+ ret.append(y)
915
+ rets.append(ret)
916
+ # create returned Tensors
917
+ return [Tensor(u, device=t.device) for t,u in zip(targets, rets[0])]
873
918
 
874
- def backward(self, gradient:Optional[Tensor]=None, retain_graph:bool=False) -> Tensor:
919
+ def backward(self, gradient:Optional[Tensor]=None) -> Tensor:
875
920
  """
876
921
  Propagates the gradient of a tensor backwards through the computation graph.
877
922
  If the 'gradient' argument is not provided, the tensor must be a scalar, and the gradient is implicitly set to 1.0.
878
- If 'retain_graph' is false, the graph used to compute the grads will be freed. Otherwise, it will be kept. Keeping it can increase memory usage.
879
923
  ```python exec="true" source="above" session="tensor" result="python"
880
924
  t = Tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
881
925
  t.sum().backward()
882
926
  print(t.grad.numpy())
883
927
  ```
884
928
  """
885
- toposorted = self._deepwalk()
886
- if gradient is None:
887
- assert self.shape == tuple(), "when no gradient is provided, backward must be called on a scalar tensor"
888
- # fill in the first grad with one. don't use Tensor.ones because we don't need contiguous
889
- # this is "implicit gradient creation"
890
- gradient = Tensor(1.0, dtype=self.dtype, device=self.device, requires_grad=False)
891
-
892
- assert self.shape == gradient.shape, f"grad shape must match tensor shape, {gradient.shape!r} != {self.shape!r}"
893
- self.grad = gradient
894
- for t0 in reversed(toposorted):
895
- if t0.grad is None: raise RuntimeError(f"tensor {t0} has no grad")
896
- token = _METADATA.set(dataclasses.replace(md, backward=True) if (md := t0._ctx.metadata) is not None else None)
897
- grads = t0._ctx.backward(t0.grad.lazydata)
898
- _METADATA.reset(token)
899
- grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None
900
- for g in ([grads] if len(t0._ctx.parents) == 1 else grads)]
901
- for t, g in zip(t0._ctx.parents, grads):
902
- if g is not None and t.requires_grad:
903
- assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"
904
- t.grad = g if t.grad is None else (t.grad + g)
905
- if not retain_graph: del t0._ctx
929
+ all_uops = self.lazydata.toposort
930
+ tensors_need_grad: list[Tensor] = [t for tref in all_tensors if (t:=tref()) is not None and \
931
+ t.lazydata in all_uops and t.requires_grad and not Tensor.no_grad]
932
+ # clear contexts
933
+ for t,g in zip(tensors_need_grad, self.gradient(*tensors_need_grad, gradient=gradient, materialize_grads=True)):
934
+ assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"
935
+ t.grad = g if t.grad is None else (t.grad + g)
906
936
  return self
907
937
 
908
938
  # ***** movement low level ops *****
@@ -926,7 +956,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
926
956
  # resolve -1
927
957
  if (c := new_shape.count(-1)) > 1: raise RuntimeError(f"only one dimension can be inferred using -1, getting {new_shape}")
928
958
  if c: new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape])
929
- return F.Reshape.apply(self, shape=new_shape) if new_shape != self.shape else self
959
+ return self._apply_uop(UOp.reshape, arg=new_shape) if new_shape != self.shape else self
930
960
 
931
961
  def expand(self, shape, *args) -> Tensor:
932
962
  """
@@ -940,7 +970,8 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
940
970
  print(t.expand(4, -1).numpy())
941
971
  ```
942
972
  """
943
- return self._broadcast_to(tuple(from_ if to == -1 or to is None else to for from_, to in zip(*(_pad_left(self.shape, argfix(shape, *args))))))
973
+ new_shape = tuple(from_ if to == -1 or to is None else to for from_, to in zip(*(_align_left(self.shape, argfix(shape, *args)))))
974
+ return self._broadcast_to(new_shape)
944
975
 
945
976
  def permute(self, order, *args) -> Tensor:
946
977
  """
@@ -958,7 +989,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
958
989
  """
959
990
  order_arg = tuple(self._resolve_dim(x) for x in argfix(order, *args))
960
991
  if sorted(order_arg) != list(range(self.ndim)): raise RuntimeError(f"order is not a valid permutation, getting {order_arg}")
961
- return F.Permute.apply(self, order=order_arg)
992
+ return self._apply_uop(UOp.permute, arg=order_arg)
962
993
 
963
994
  def flip(self, axis, *args) -> Tensor:
964
995
  """
@@ -978,9 +1009,9 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
978
1009
  """
979
1010
  axis_arg = tuple(self._resolve_dim(x) for x in argfix(axis, *args))
980
1011
  if len(axis_arg) != len(dedup(axis_arg)): raise RuntimeError(f"dim can appear at most once, getting {axis_arg}")
981
- return F.Flip.apply(self, axis=axis_arg)
1012
+ return self._apply_uop(UOp.flip, arg=tuple([i in axis_arg for i in range(len(self.shape))]))
982
1013
 
983
- def shrink(self, arg:Tuple[Optional[Tuple[sint, sint]], ...]) -> Tensor:
1014
+ def shrink(self, arg:tuple[Optional[tuple[sint, sint]], ...]) -> Tensor:
984
1015
  """
985
1016
  Returns a tensor that shrinks the each axis based on input arg.
986
1017
  `arg` must have the same length as `self.ndim`.
@@ -998,24 +1029,25 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
998
1029
  ```
999
1030
  """
1000
1031
  if (shrink_arg:=[x if x is not None else (0,s) for x,s in zip(arg, self.shape)]) == [(0,s) for s in self.shape]: return self
1001
- return F.Shrink.apply(self, arg=tuple(shrink_arg))
1032
+ return self._apply_uop(UOp.shrink, arg=tuple(shrink_arg))
1002
1033
 
1003
- def pad(self, padding:Union[Sequence[sint], Sequence[Optional[Tuple[sint, sint]]]], mode:str="constant", value:float=0.0) -> Tensor:
1034
+ def pad(self, padding:Union[Sequence[sint], Sequence[Optional[tuple[sint, sint]]]], mode:str="constant", value:float=0.0) -> Tensor:
1004
1035
  """
1005
1036
  Returns a tensor with padding applied based on the input `padding`.
1037
+
1006
1038
  `padding` supports two padding structures:
1007
1039
 
1008
- 1. Flat padding: (padding_left, padding_right, padding_top, padding_bottom, ...)
1009
- - This structure matches PyTorch's pad.
1010
- - `padding` length must be even.
1040
+ 1. Flat padding: `(padding_left, padding_right, padding_top, padding_bottom, ...)`
1041
+ - This structure matches PyTorch's pad.
1042
+ - `padding` length must be even.
1011
1043
 
1012
- 2. Group padding: (..., (padding_top, padding_bottom), (padding_left, padding_right))
1013
- - This structure matches pad for jax, numpy, tensorflow and others.
1014
- - For each axis, padding can be `None`, meaning no padding, or a tuple `(start, end)`.
1015
- - `padding` must have the same length as `self.ndim`.
1044
+ 2. Group padding: `(..., (padding_top, padding_bottom), (padding_left, padding_right))`
1045
+ - This structure matches pad for JAX, NumPy, TensorFlow, and others.
1046
+ - For each axis, padding can be `None`, meaning no padding, or a tuple `(start, end)`.
1047
+ - `padding` must have the same length as `self.ndim`.
1016
1048
 
1017
1049
  Padding values can be negative, resulting in dimension shrinks that work similarly to Python negative slices.
1018
- Padding modes is selected with `mode` which supports `constant` and `reflect`.
1050
+ Padding modes is selected with `mode` which supports `constant`, `reflect` and `replicate`.
1019
1051
 
1020
1052
  ```python exec="true" source="above" session="tensor" result="python"
1021
1053
  t = Tensor.arange(9).reshape(1, 1, 3, 3)
@@ -1031,176 +1063,167 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
1031
1063
  print(t.pad((1, 2, 0, -1), value=-float('inf')).numpy())
1032
1064
  ```
1033
1065
  """
1034
- if mode not in {"constant", "reflect"}: raise NotImplementedError(f"{mode=} is not supported")
1035
- if (flat:=all(isinstance(p, (int,UOp)) for p in padding)) and len(padding)%2 != 0: raise ValueError("Flat padding must have even number of pads")
1036
- # turn flat padding into group padding
1037
- pX = ((0,0),)*(self.ndim - len(padding)//2) + tuple(zip(padding[-2::-2], padding[::-2])) if flat else padding
1066
+ if mode not in {"constant", "reflect", "replicate", "circular"}: raise NotImplementedError(f"{mode=} is not supported")
1067
+ # flat padding
1068
+ if all(isinstance(p, (int,UOp)) for p in padding):
1069
+ if len(padding)%2 != 0: raise ValueError("Flat padding must have even number of pads")
1070
+ pX = _flat_to_grouped(tuple(cast(Sequence[sint], padding)) + (0,0)*(self.ndim - len(padding)//2))
1071
+ # group padding
1072
+ else: pX = tuple((0,0) if p is None else p for p in cast(Sequence[Optional[tuple[sint, sint]]], padding))
1038
1073
  if len(pX) != self.ndim: raise ValueError(f"padding length is improper, {padding=} {self.ndim=}")
1039
- X, pX = self, cast(Tuple[Tuple[sint, sint]], tuple((0,0) if p is None else p for p in pX))
1040
- def _constant(x,px,v): return F.Pad.apply(x, arg=px) if v == 0 else F.Pad.apply(x, arg=px) + F.Pad.apply(Tensor.ones_like(x), arg=px).where(0, v)
1041
- # early return for symbolic with positive pads (no need to max)
1042
- if mode == "constant" and all(resolve(p >= 0) for p in flatten(pX)): return _constant(X, pX, value)
1043
- pads, shrinks = tuple((smax(pB,0), smax(pA,0)) for pB,pA in pX), lambda shape: tuple((-smin(pB,0),smin(pA+s,s)) for (pB,pA),s in zip(pX, shape))
1044
- if mode == "constant": return _constant(X.shrink(shrinks(X.shape)), pads, value)
1074
+ X, pads = self, tuple((smax(pB,0), smax(pA,0)) for pB,pA in pX)
1075
+ if mode == "constant":
1076
+ def _constant(x:Tensor,px,v):
1077
+ return x._apply_uop(UOp.pad, arg=px) if v == 0 else (x._apply_uop(UOp.pad, arg=px)+Tensor.ones_like(x)._apply_uop(UOp.pad, arg=px).where(0,v))
1078
+ return _constant(X, pX, value) if all(resolve(p >= 0) for p in flatten(pX)) else \
1079
+ _constant(X.shrink(tuple((-smin(pB,0),smin(pA+s,s)) for (pB,pA),s in zip(pX, X.shape))), pads, value)
1045
1080
  assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
1081
+ if mode == "circular":
1082
+ if any(pB>sh or pA>sh for (pB,pA),sh in zip(pX, X.shape)): raise ValueError('Padding value causes wrapping around more than once.')
1083
+ if any(pB<0 or pA<0 for pB,pA in pX): raise NotImplementedError("Negative pads with circular pads is not supported")
1084
+ orig_shape, X = X.shape, X.repeat(tuple(1 + bool(pB) + bool(pA) for pB,pA in pads))
1085
+ return X.shrink(tuple((0 if pB == 0 else osh-pB, xsh if pA == 0 else xsh-osh+pA) for (pB,pA),osh,xsh in zip(pads, orig_shape, X.shape)))
1046
1086
  for d,(pB,pA) in enumerate(pads):
1047
- if pB >= (s:=X.shape[d]) or pA>=s: raise ValueError(f"Padding ({pB}, {pA}) should be less than the input size={s} for dim={d}.")
1048
- slcB, slcA, = slice(pB,0,-1), slice(s-2 if s-2>=0 else None, s-2-pA if s-2-pA>=0 else None, -1)
1049
- xB, xA = (X[[slc if i == d else slice(None) for i in range(X.ndim)]] if p > 0 else None for slc, p in ((slcB, pB), (slcA, pA)))
1087
+ if mode == "reflect":
1088
+ if pB >= (s:=X.shape[d]) or pA>=s: raise ValueError(f"Padding ({pB}, {pA}) should be less than the input size={s} for dim={d}.")
1089
+ slcB, slcA, = slice(pB,0,-1), slice(s-2 if s-2>=0 else None, s-2-pA if s-2-pA>=0 else None, -1)
1090
+ xB, xA = (X[[slc if i == d else slice(None) for i in range(X.ndim)]] if p > 0 else None for slc, p in ((slcB, pB), (slcA, pA)))
1091
+ if mode == "replicate":
1092
+ shrB, shrA, = tuple((0,1) if i==d else None for i in range(X.ndim)), tuple((X.shape[i]-1,X.shape[i]) if i==d else None for i in range(X.ndim))
1093
+ xB, xA = (X.shrink(shr).expand(tuple(p if i==d else None for i in range(X.ndim))) if p > 0 else None for shr, p in ((shrB, pB), (shrA, pA)))
1050
1094
  X = Tensor.cat(*(X_ for X_ in (xB, X, xA) if X_ is not None), dim=d)
1051
- return X.shrink(shrinks(X.shape))
1095
+ return X.shrink(tuple((-min(pB,0), min(pA+s,s)) for (pB,pA),s in zip(pX, X.shape)))
1052
1096
 
1053
1097
  # ***** movement high level ops *****
1054
1098
 
1055
- # Supported Indexing Implementations:
1056
- # 1. Int indexing (no copy)
1057
- # - for all dims where there's int, shrink -> reshape
1058
- # - negative indices are taken relative to the end of the sequence, so X[-2] returns the 2nd-to-last element
1059
- # - X = Tensor.rand(4,5,9); X[2,-2] shrinks the Tensor to X.shrink(((2, 3), (3, 4), (0, 9))) -> X.shape=(1,1,9)
1060
- # - Then we reshape (collapse) the int dim away such that for X: (1,1,9) -> (9,)
1061
- # 2. Slice indexing (no copy)
1062
- # - for all dims where slice is start:end:stride, shrink -> Optional[flip] -> pad -> reshape -> shrink
1063
- # - first shrink the Tensor to X.shrink(((start, end),))
1064
- # - then we apply stride through Optional[flip] -> pad -> reshape -> shrink
1065
- # - flip where dim value is negative
1066
- # - pad on dims to be multiple of strides, such that reshaping [dim_size_padded] -> [dim_size_padded // stride, stride] is possible
1067
- # - shrink [dim_size_padded // stride, stride] -> [dim_size_padded // stride, 1]
1068
- # - reshape [dim_size_padded // stride, 1] -> [dim_size_padded // stride] and now you have your stride
1069
- # 3. None indexing (no copy)
1070
- # - reshape (inject) a dim at the dim where there's None
1071
- # 4. Tensor indexing (copy)
1072
- # - use Tensor.arange == tensor_index to create masks for dims with Tensors (adds a dim for each mask)
1073
- # - combine masks together with mul
1074
- # - apply mask to self by mask * self
1075
- # - sum reduce away the extra dims added from creating masks
1076
- # Tiny Things:
1077
- # 1. Supported indices: Union[int, slice, Tensor, None, List, Tuple, Ellipsis]
1078
- # - for any list, List[Union[List, Tuple, int]], must have homogeneous shape
1079
- # - for any tuple, Tuple[Union[List, Tuple, int]], must have homogeneous shape
1080
- # 2. Bool indexing is not supported
1081
- # 3. Out of bounds Tensor indexing results in 0
1082
- # - e.g: Tensor([1, 2, 3])[Tensor([4, 3, 2])] -> [0, 0, 3] index 4 and 3 are out of bounds
1083
1099
  def _getitem(self, indices, v: Optional[Tensor] = None) -> Tensor:
1084
- # 1. indices normalization and validation
1085
- # treat internal tuples and lists as Tensors and standardize indices to list type
1086
- if isinstance(indices, list) and all_int(indices): indices = [Tensor(indices, self.device, requires_grad=False)]
1087
- elif isinstance(indices, (tuple, list)):
1088
- indices = [Tensor(i, self.device, requires_grad=False) if isinstance(i, (tuple, list)) else i for i in indices]
1089
- else: indices = [indices]
1090
-
1100
+ # wrap single index into a list
1101
+ if (isinstance(indices, list) and all_int(indices)) or not isinstance(indices, (tuple, list)): indices = [indices]
1091
1102
  # turn scalar Tensors into const val for int indexing if possible
1092
- indices = [self._to_const_val(i) if isinstance(i, Tensor) and i.shape == () else i for i in indices]
1093
- # move Tensor indices to the same device as self
1094
- indices = [i.to(self.device) if isinstance(i, Tensor) else i for i in indices]
1103
+ x, indices = self, [self._to_const_val(i) if isinstance(i, Tensor) and i.shape == () else i for i in indices]
1095
1104
 
1096
1105
  # filter ellipsis and fill with slice(None) or fill rest of indices with slice(None)
1097
- ellipsis_idx = [dim for dim, i in enumerate(indices) if i is Ellipsis]
1106
+ if len(ellipsis_idx := [dim for dim, i in enumerate(indices) if i is Ellipsis]) > 1: raise IndexError("indices can only have a single ellipsis")
1098
1107
  fill_idx = ellipsis_idx[0] if ellipsis_idx else len(indices)
1099
1108
  num_indices = len(indices) - len(ellipsis_idx) - sum(1 for i in indices if i is None)
1100
- indices[fill_idx:fill_idx+1] = [slice(None)] * (self.ndim - num_indices)
1101
-
1102
- # use Dict[type, List[dimension]] to track elements in indices
1103
- type_dim: DefaultDict[Union[type, None], List[int]] = defaultdict(list)
1104
-
1105
- # record None for dimension injection later and filter None and record rest of indices
1106
- type_dim[None] = [dim for dim, i in enumerate(indices) if i is None]
1107
- indices_filtered = [i for i in indices if i is not None]
1108
- for dim,i in enumerate(indices_filtered): type_dim[type(i)].append(dim)
1109
-
1110
- if len(ellipsis_idx) > 1: raise IndexError("indices can only have a single ellipsis ('...')")
1111
- for index_type in type_dim:
1112
- if index_type not in [None, int, slice, Tensor]: raise IndexError(f"{index_type=} not supported")
1113
1109
  if num_indices > self.ndim: raise IndexError(f"too many {num_indices=} for {self.ndim=}")
1110
+ indices[fill_idx:fill_idx+1] = [slice(None)] * (self.ndim - num_indices)
1114
1111
 
1115
- # 2. basic indexing, uses only movement ops (no copy)
1116
- # currently indices_filtered: Tuple[Union[int, slice, Tensor], ...]
1117
- # turn indices in indices_filtered to Tuple[new_slice, strides]
1118
- for dim in type_dim[int]:
1119
- if (index := indices_filtered[dim]) >= (size := self.shape[dim]) or index < -size:
1120
- raise IndexError(f"{index=} is out of bounds on {dim=} with {size=}")
1121
- indices_filtered[dim] = ((index, index+1), 1) if index >= 0 else ((size+index, size+index+1), 1)
1122
- for dim in type_dim[slice]:
1123
- if (index := indices_filtered[dim]).step == 0: raise ValueError(f"{index=} on {dim=} cannot have 0 as step")
1124
- if not all(isinstance(x, (int, type(None))) for x in (index.start, index.stop, index.step)):
1125
- raise TypeError(f"Unsupported slice for dimension {dim}. Expected slice with integers or None, got slice("
1126
- f"{', '.join(type(x).__name__ for x in (index.start, index.stop, index.step))}).")
1127
- s, e, st = index.indices(self.shape[dim])
1128
- indices_filtered[dim] = ((0, 0) if (st * (e - s)) < 0 else (s, e) if st > 0 else (e+1, s+1), st)
1129
- # skip all Tensor dims for basic indexing
1130
- for dim in type_dim[Tensor]:
1131
- dtype = indices_filtered[dim].dtype
1132
- if not dtypes.is_int(dtype): raise IndexError(f"{dtype=} on {dim=} is not supported, only int tensor indexing is supported")
1133
- indices_filtered[dim] = ((0, self.shape[dim]), 1)
1134
-
1135
- new_slice, strides = ((), ()) if not indices_filtered else zip(*indices_filtered)
1136
- # flip negative strides
1137
- ret = self.shrink(new_slice).flip(tuple(i for i, st in enumerate(strides) if st < 0))
1138
- # handle stride != 1 or -1
1139
- if any(abs(st) != 1 for st in strides):
1140
- strides = tuple(abs(s) for s in strides)
1141
- # pad shape to multiple of stride
1142
- if not all_int(ret.shape): raise RuntimeError("symbolic shape not supprted")
1143
- ret = ret.pad(tuple((0, round_up(s, st) - s) for s, st in zip(ret.shape, strides)))
1144
- ret = ret.reshape(tuple(flatten((s // st, st) for s, st in zip(ret.shape, strides))))
1145
- ret = ret.shrink(tuple(flatten(((0, s), (0, 1)) for s in ret.shape[::2]))).reshape(ret.shape[::2])
1146
-
1147
- # inject 1 for dim where it's None and collapse dim for int
1148
- new_shape = list(ret.shape)
1149
- for dim in type_dim[None]: new_shape.insert(dim, 1)
1150
- for dim in (dims_collapsed := tuple(dim + sum(1 for d in type_dim[None] if dim >= d) for dim in reversed(type_dim[int]))): new_shape.pop(dim)
1151
-
1152
- ret = ret.reshape(new_shape)
1153
-
1154
- # 3. advanced indexing (copy)
1155
- if type_dim[Tensor]:
1156
- dim_tensors = [(dim, i) for dim, i in enumerate(indices) if isinstance(i, Tensor)]
1157
- # calculate dim of current ret by subtracting dims collapsed and adding dims injected up until tensor_dim
1158
- def calc_dim(tensor_dim:int) -> int:
1159
- return tensor_dim - sum(1 for d in dims_collapsed if tensor_dim >= d)
1160
-
1161
- assert all_int(ret.shape), f"does not support symbolic shape {ret.shape}"
1162
- # track tensor_dim and tensor_index using a dict
1163
- # calc_dim to get dim and use that to normalize the negative tensor indices
1164
- idx: Dict[int,Tensor] = {(dim := calc_dim(td)):(tensor<0).where(ret.shape[dim],0) + tensor for td,tensor in dim_tensors}
1165
-
1166
- masks, first_dim, last_dim = [], min(idx.keys()), max(idx.keys())
1167
- pre_reduce_shape = ret.shape[:first_dim] + (big_shape := _broadcast_shape(*(t.shape for t in idx.values()))) + ret.shape[first_dim:]
1168
-
1169
- # create masks
1170
- for dim, i in idx.items():
1171
- try: i = i.reshape(i.shape + (1,)*(ret.ndim - first_dim)).expand(pre_reduce_shape)
1112
+ indices_parsed, dim = [], 0
1113
+ for index in indices:
1114
+ size = 1 if index is None else self.shape[dim]
1115
+ boundary, stride = [0, size], 1 # defaults
1116
+ match index:
1117
+ case list() | tuple() | Tensor():
1118
+ if not isinstance(index, Tensor): index = Tensor(index, self.device, requires_grad=False)
1119
+ if not dtypes.is_int(index.dtype): raise IndexError(f"index dtype {index.dtype} is not supported")
1120
+ index = (index.to(self.device) < 0).where(size, 0) + index # treat negative index values
1121
+ case int() | UOp(): # sint
1122
+ if index >= size or index < -size: raise IndexError(f"{index=} is out of bounds with {size=}")
1123
+ boundary = [index, index+1] if index >= 0 else [index+size, index+size+1]
1124
+ case slice():
1125
+ if index.step == 0: raise ValueError(f"{index=} cannot have 0 as step")
1126
+ if not all(isinstance(s,int) or s is None for s in (index.start,index.stop,index.step)): raise TypeError("only int slicing is supported")
1127
+ # handle int slicing
1128
+ *boundary, stride = index.indices(cast(SupportsIndex, size))
1129
+ if stride * (boundary[1] - boundary[0]) < 0: boundary = [0, 0]
1130
+ elif stride < 0: boundary = [boundary[1] + 1, boundary[0] + 1]
1131
+ # update size for slice
1132
+ size = ceildiv((boundary[1] - boundary[0]), abs(stride))
1133
+ case None: pass # do nothing
1134
+ case _: raise IndexError(f"{type(index).__name__} indexing is not supported")
1135
+ indices_parsed.append({"index":index, "size":size, "boundary":tuple(boundary), "stride":stride})
1136
+ if index is not None: dim += 1
1137
+
1138
+ # movement op indexing
1139
+ if mops := [i for i in indices_parsed if i['index'] is not None]:
1140
+ # flip negative strides
1141
+ shrinks, strides = zip(*((i['boundary'], i['stride']) for i in mops))
1142
+ x = x.shrink(shrinks).flip(tuple(i for i,st in enumerate(strides) if st < 0))
1143
+ # handle stride != 1 or -1
1144
+ if any(abs(st) != 1 for st in strides):
1145
+ strides = tuple(abs(s) for s in strides)
1146
+ # pad shape to multiple of stride
1147
+ if not all_int(x.shape): raise RuntimeError("symbolic shape not supported")
1148
+ x = x.pad(tuple((0, round_up(s, st) - s) for s, st in zip(x.shape, strides)))
1149
+ x = x.reshape(tuple(flatten((s // st, st) for s, st in zip(x.shape, strides))))
1150
+ x = x.shrink(tuple(flatten(((0, s), (0, 1)) for s in x.shape[::2]))).reshape(x.shape[::2])
1151
+
1152
+ # dim injection from None by including None dim size (which is 1) and dim collapse by skipping int dim size
1153
+ x = x.reshape(tuple(index['size'] for index in indices_parsed if not isinstance(index['index'], int)))
1154
+
1155
+ # tensor indexing
1156
+ if tops := [(d,i) for d,i in enumerate(i_ for i_ in indices_parsed if not isinstance(i_['index'], int)) if isinstance(i['index'], Tensor)]:
1157
+ # unload the tensor object into actual tensors
1158
+ dims, tensors, masks = [d for d,_ in tops], cast(list[Tensor], [i['index'] for _,i in tops]), []
1159
+ pre_reduce_shape = x.shape[:dims[0]] + (big_shape := _broadcast_shape(*(t.shape for t in tensors))) + x.shape[dims[0]:]
1160
+
1161
+ # create index masks
1162
+ for dim, tensor in zip(dims, tensors):
1163
+ try: i = tensor.reshape(tensor.shape + (1,)*(x.ndim - dims[0])).expand(pre_reduce_shape)
1172
1164
  except ValueError as e: raise IndexError(f"cannot broadcast indices: {e}") from e
1173
- a = Tensor.arange(ret.shape[dim], device=self.device, requires_grad=False).reshape((ret.shape[dim],) + (1,)*(ret.ndim - dim - 1))
1174
- masks.append(i == a)
1165
+ masks.append(i._one_hot_along_dim(num_classes=x.shape[dim], dim=(dim - x.ndim)))
1175
1166
 
1176
1167
  # reduce masks to 1 mask
1177
1168
  mask: Tensor = functools.reduce(lambda x,y: x.mul(y), masks)
1178
1169
 
1179
1170
  # inject 1's for the extra dims added in create masks
1180
- reshape_arg = ret.shape[:first_dim] + (1,) * len(big_shape) + ret.shape[first_dim:]
1171
+ reshape_arg = x.shape[:dims[0]] + (1,) * len(big_shape) + x.shape[dims[0]:]
1181
1172
  # sum reduce the extra dims introduced in create masks
1182
- ret = (ret.reshape(reshape_arg) * mask).sum(sum_axis:=tuple(i + len(big_shape) for i in idx.keys()), acc_dtype=ret.dtype)
1173
+ x = (x.reshape(reshape_arg) * mask).sum(sum_axis:=tuple(d + len(big_shape) for d in dims), acc_dtype=x.dtype)
1183
1174
 
1184
1175
  # special permute case
1185
- if first_dim != 0 and len(idx) != 1 and tuple(idx.keys()) != tuple(range(first_dim, last_dim+1)):
1186
- ret = ret.permute(*range(first_dim, first_dim+len(big_shape)), *range(0, first_dim), *range(first_dim+len(big_shape), ret.ndim))
1176
+ if dims[0] != 0 and len(dims) != 1 and tuple(dims) != tuple(range(dims[0], dims[-1]+1)):
1177
+ x = x.permute(*range(dims[0], dims[0]+len(big_shape)), *range(0, dims[0]), *range(dims[0]+len(big_shape), x.ndim))
1187
1178
 
1188
1179
  # for advanced setitem, returns whole tensor with indices replaced
1189
1180
  if v is not None:
1190
- vb = v.cast(self.dtype)._broadcast_to(_broadcast_shape(ret.shape, v.shape))
1181
+ vb = v.cast(self.dtype)._broadcast_to(_broadcast_shape(x.shape, v.shape))
1191
1182
  # add back reduced dims from sum
1192
1183
  for dim in sum_axis: vb = vb.unsqueeze(dim)
1193
- # axis to be reduced to match self.shape
1194
- axis = tuple(range(first_dim, first_dim + len(big_shape)))
1195
- # apply mask to v(broadcasted) and reduce such that if v contains repeated indices the last one remains
1196
- vb = vb * mask
1197
- for dim in axis: vb = functools.reduce(lambda x,y: y.where(y, x), vb.split(1, dim))
1198
- # reduce mask and select from v(get rid of extra dims from reduce) for each True element in mask else select from self
1199
- ret = mask.any(axis).where(vb.squeeze(), self)
1184
+ # run _masked_setitem on tuple of axis that is to be reduced to match self.shape
1185
+ x = _masked_setitem(self, vb, mask, tuple(range(dims[0], dims[0] + len(big_shape))))
1200
1186
 
1201
- return ret
1187
+ return x
1202
1188
 
1203
1189
  def __getitem__(self, indices) -> Tensor:
1190
+ """
1191
+ Retrieve a sub-tensor using indexing.
1192
+
1193
+ Supported Index Types: `int | slice | Tensor | None | List | Tuple | Ellipsis`
1194
+
1195
+ Examples:
1196
+ ```python exec="true" source="above" session="tensor" result="python"
1197
+ t = Tensor.arange(12).reshape(3, 4)
1198
+ print(t.numpy())
1199
+ ```
1200
+
1201
+ - Int Indexing: Select an element or sub-tensor using integers for each dimension.
1202
+ ```python exec="true" source="above" session="tensor" result="python"
1203
+ print(t[1, 2].numpy())
1204
+ ```
1205
+
1206
+ - Slice Indexing: Select a range of elements using slice notation (`start:end:stride`).
1207
+ ```python exec="true" source="above" session="tensor" result="python"
1208
+ print(t[0:2, ::2].numpy())
1209
+ ```
1210
+
1211
+ - Tensor Indexing: Use another tensor as indices for advanced indexing. Using `tuple` or `list` here also works.
1212
+ ```python exec="true" source="above" session="tensor" result="python"
1213
+ print(t[Tensor([2, 0, 1]), Tensor([1, 2, 3])].numpy())
1214
+ ```
1215
+
1216
+ - `None` Indexing: Add a new dimension to the tensor.
1217
+ ```python exec="true" source="above" session="tensor" result="python"
1218
+ print(t[:, None].shape)
1219
+ ```
1220
+
1221
+ NOTE: Out-of-bounds indexing results in a value of `0`.
1222
+ ```python exec="true" source="above" session="tensor" result="python"
1223
+ t = Tensor([1, 2, 3])
1224
+ print(t[Tensor([4, 3, 2])].numpy())
1225
+ ```
1226
+ """
1204
1227
  return self._getitem(indices)
1205
1228
 
1206
1229
  def __setitem__(self, indices, v:Union[Tensor, ConstType]) -> None:
@@ -1208,7 +1231,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
1208
1231
  self._getitem(indices).assign(v)
1209
1232
  return
1210
1233
  # NOTE: check that setitem target is valid first
1211
- if not all(lb.st.contiguous for lb in self.lazydata.lbs): raise RuntimeError("setitem target needs to be contiguous")
1234
+ if not unwrap(self.lazydata.st).contiguous: raise RuntimeError("setitem target needs to be contiguous")
1212
1235
  if not isinstance(v, (Tensor, float, int, bool)): raise TypeError(f"can't set a {type(v).__name__} to a Tensor")
1213
1236
  if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype)
1214
1237
  if self.requires_grad or v.requires_grad: raise NotImplementedError("setitem with requires_grad is not supported")
@@ -1238,7 +1261,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
1238
1261
  assert all(s >= i for d,(s,i) in enumerate(zip(self.shape, index.shape)) if d != dim), "requires self.shape[d] >= index.shape[d] for all d != dim"
1239
1262
  index = index.to(self.device)
1240
1263
  x = self.shrink(tuple((0, i) if d != dim else None for d,i in enumerate(index.shape))).unsqueeze(-1).transpose(-1, dim)
1241
- return ((index.unsqueeze(-1) == Tensor.arange(self.shape[dim], requires_grad=False, device=self.device)) * x).sum(-1, acc_dtype=self.dtype)
1264
+ return (x * index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim])).sum(-1, acc_dtype=self.dtype)
1242
1265
 
1243
1266
  def cat(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
1244
1267
  """
@@ -1302,7 +1325,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
1302
1325
  ```
1303
1326
  """
1304
1327
  repeats = argfix(repeats, *args)
1305
- base_shape = _pad_left(self.shape, repeats)[0]
1328
+ base_shape = _align_left(self.shape, repeats)[0]
1306
1329
  unsqueezed_shape = flatten([[1, s] for s in base_shape])
1307
1330
  expanded_shape = flatten([[r, s] for r,s in zip(repeats, base_shape)])
1308
1331
  final_shape = [r*s for r,s in zip(repeats, base_shape)]
@@ -1313,7 +1336,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
1313
1336
  if not -max(1, total) <= dim <= max(1, total)-1: raise IndexError(f"{dim=} out of range {[-max(1, total), max(1, total)-1]}")
1314
1337
  return dim + total if dim < 0 else dim
1315
1338
 
1316
- def split(self, sizes:Union[int, List[int]], dim:int=0) -> Tuple[Tensor, ...]:
1339
+ def split(self, sizes:Union[int, list[int]], dim:int=0) -> tuple[Tensor, ...]:
1317
1340
  """
1318
1341
  Splits the tensor into chunks along the dimension specified by `dim`.
1319
1342
  If `sizes` is an integer, it splits into equally sized chunks if possible, otherwise the last chunk will be smaller.
@@ -1338,7 +1361,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
1338
1361
  assert sum(sizes) == self.shape[dim], f"expect sizes to sum exactly to {self.shape[dim]}, but got {sum(sizes)}"
1339
1362
  return tuple(self[sl] for sl in [tuple([slice(None)]*dim + [slice(sum(sizes[:i]), sum(sizes[:i + 1]))]) for i in range(len(sizes))])
1340
1363
 
1341
- def chunk(self, chunks:int, dim:int=0) -> List[Tensor]:
1364
+ def chunk(self, chunks:int, dim:int=0) -> list[Tensor]:
1342
1365
  """
1343
1366
  Splits the tensor into `chunks` number of chunks along the dimension `dim`.
1344
1367
  If the tensor size along `dim` is not divisible by `chunks`, all returned chunks will be the same size except the last one.
@@ -1362,7 +1385,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
1362
1385
  dim = self._resolve_dim(dim)
1363
1386
  return list(self.split(ceildiv(self.shape[dim], chunks) if self.shape[dim] else [0]*chunks, dim=dim))
1364
1387
 
1365
- def meshgrid(self:Tensor, *args:Tensor, indexing:Union[Literal["ij"], Literal["xy"]]="ij") -> Tuple[Tensor, ...]:
1388
+ def meshgrid(self:Tensor, *args:Tensor, indexing:Union[Literal["ij"], Literal["xy"]]="ij") -> tuple[Tensor, ...]:
1366
1389
  """
1367
1390
  Generates coordinate matrices from coordinate vectors.
1368
1391
  Input tensors can be scalars or 1D tensors.
@@ -1462,7 +1485,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
1462
1485
  start_dim, end_dim = self._resolve_dim(start_dim), self._resolve_dim(end_dim)
1463
1486
  return self.reshape(self.shape[:start_dim] + (prod(self.shape[start_dim:end_dim+1]), ) + self.shape[end_dim+1:])
1464
1487
 
1465
- def unflatten(self, dim:int, sizes:Tuple[int,...]):
1488
+ def unflatten(self, dim:int, sizes:tuple[int,...]):
1466
1489
  """
1467
1490
  Unflattens dimension `dim` of the tensor into multiple dimensions specified by `sizes`. `Tensor.flatten()` is the inverse of this function.
1468
1491
 
@@ -1479,7 +1502,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
1479
1502
  dim = self._resolve_dim(dim)
1480
1503
  return self.reshape(self.shape[:dim] + sizes + self.shape[dim+1:])
1481
1504
 
1482
- def roll(self, shifts:Union[int, Tuple[int, ...]], dims:Union[int, Tuple[int, ...]]) -> Tensor:
1505
+ def roll(self, shifts:Union[int, tuple[int, ...]], dims:Union[int, tuple[int, ...]]) -> Tensor:
1483
1506
  """
1484
1507
  Rolls the tensor along specified dimension(s).
1485
1508
  The rolling operation is circular, meaning that elements that go beyond the edge are wrapped around to the beginning of the dimension.
@@ -1499,12 +1522,52 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
1499
1522
  rolled[tuple(slice(None) if i != dim else slice(None, -shift) for i in range(rolled.ndim))], dim=dim)
1500
1523
  return rolled
1501
1524
 
1525
+ def rearrange(self, formula:str, **sizes) -> Tensor:
1526
+ """
1527
+ Rearranges input according to formula
1528
+
1529
+ See: https://einops.rocks/api/rearrange/
1530
+
1531
+ ```python exec="true" source="above" session="tensor" result="python"
1532
+ x = Tensor([[1, 2], [3, 4]])
1533
+ print(Tensor.rearrange(x, "batch channel -> (batch channel)").numpy())
1534
+ ```
1535
+ """
1536
+ def parse_formula(formula: str):
1537
+ tokens = f" {formula} ".replace("…", "...").replace("(", " ( ").replace(")", " ) ").replace(" ", " ").replace(" 1 ", " ( ) ").split()
1538
+ lparens, rparens = map(lambda x: [i for i, ch in enumerate(tokens) if ch == x], ("(", ")"))
1539
+ pairs = list(zip(lparens, rparens))
1540
+ assert len(lparens) == len(rparens) and sorted(flatten(pairs)) == flatten(pairs), "bracket mismatch"
1541
+ return [name for name in tokens if name not in ("(", ")")], [(s - 2*i, e - 1 - 2*i) for i, (s, e) in enumerate(pairs)]
1542
+
1543
+ assert formula.count("->") == 1, 'need exactly one "->" in formula'
1544
+
1545
+ (lhs, unflatten_dims), (rhs, flatten_dims) = map(parse_formula, formula.split("->"))
1546
+
1547
+ for name in sizes: assert name in lhs, f"axis {name} is not used in transform"
1548
+ assert sorted(lhs) == sorted(rhs) and len(lhs) == len(set(lhs)), f"name mismatch in {formula}"
1549
+ for name in flatten((lhs, rhs)): assert name == "..." or (name.isidentifier() and "_" not in (name[0], name[-1])), f"invalid axis name {name}"
1550
+ assert "..." not in flatten([lhs[s:e] for s, e in unflatten_dims]), f"cannot have collapsed ellipsis (...) in lhs of {formula}"
1551
+ assert lhs.count("...") <= 1, f"too many ellipses in {formula}"
1552
+
1553
+ # resolve ellipsis
1554
+ if "..." in lhs: ell_len = len(self.shape) - len(lhs) + 1 + sum(e - s - 1 for s, e in unflatten_dims)
1555
+ lhs, rhs = map(lambda l: l[:(i:=l.index("..."))] + [f"...{j}" for j in range(ell_len)] + l[i + 1:] if "..." in l else l, (lhs, rhs))
1556
+ unflatten_dims = [(s + (ell_len - 1 if "...0" in lhs[:s] else 0), e + (ell_len - 1 if "...0" in lhs[:e] else 0)) for s, e in unflatten_dims]
1557
+ flatten_dims = [(s + (ell_len - 1 if "...0" in rhs[:s] else 0), e + (ell_len - 1 if "...0" in rhs[:e] else 0)) for s, e in flatten_dims]
1558
+
1559
+ # apply movement ops in order unflatten -> permute -> flatten/unsqueeze
1560
+ t = functools.reduce(lambda x, dims: x.unflatten(dims[0], tuple(sizes.get(lhs[d], -1) for d in range(*dims))), unflatten_dims, self)
1561
+ for i, name in enumerate(lhs): assert (name not in sizes) or sizes[name] == t.shape[i], f"size provided for dimension {name} incorrect"
1562
+ t = t.permute([lhs.index(name) for name in rhs])
1563
+ return functools.reduce(lambda x, dims: x.flatten(dims[0], dims[1] - 1) if dims[0]<dims[1] else x.unsqueeze(dims[0]), reversed(flatten_dims), t)
1564
+
1502
1565
  # ***** reduce ops *****
1503
1566
 
1504
- def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False) -> Tensor:
1567
+ def _reduce(self, op:Ops, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False) -> Tensor:
1505
1568
  axis = tuple(self._resolve_dim(x) for x in (range(self.ndim) if axis is None else make_tuple(axis, 1)))
1506
1569
  if self.ndim == 0: axis = ()
1507
- ret = fxn.apply(self, axis=axis)
1570
+ ret = self._apply_uop(UOp.r, op=op, axis=axis)
1508
1571
  return ret if keepdim else ret.reshape(tuple(s for i,s in enumerate(self.shape) if i not in axis))
1509
1572
 
1510
1573
  def sum(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, acc_dtype:Optional[DTypeLike]=None):
@@ -1531,7 +1594,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
1531
1594
  print(t.sum(axis=1).numpy())
1532
1595
  ```
1533
1596
  """
1534
- ret = self.cast(sum_acc_dtype(self.dtype) if acc_dtype is None else acc_dtype)._reduce(F.Sum, axis, keepdim)
1597
+ ret = self.cast(sum_acc_dtype(self.dtype) if acc_dtype is None else acc_dtype)._reduce(Ops.ADD, axis, keepdim)
1535
1598
  return ret.cast(self.dtype) if acc_dtype is None and self.dtype in (dtypes.float16, dtypes.bfloat16) else ret
1536
1599
 
1537
1600
  def prod(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, acc_dtype:Optional[DTypeLike]=None):
@@ -1558,7 +1621,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
1558
1621
  print(t.prod(axis=1).numpy())
1559
1622
  ```
1560
1623
  """
1561
- return self.cast(acc_dtype if acc_dtype is not None else self.dtype)._reduce(F.Prod, axis, keepdim)
1624
+ return self.cast(acc_dtype if acc_dtype is not None else self.dtype)._reduce(Ops.MUL, axis, keepdim)
1562
1625
 
1563
1626
  def max(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
1564
1627
  """
@@ -1581,7 +1644,9 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
1581
1644
  print(t.max(axis=1, keepdim=True).numpy())
1582
1645
  ```
1583
1646
  """
1584
- return self._reduce(F.Max, axis, keepdim)
1647
+ return self._reduce(Ops.MAX, axis, keepdim)
1648
+
1649
+ def _inverse(self): return -self if self.is_floating_point() else ~self if dtypes.is_int(self.dtype) else self.logical_not()
1585
1650
 
1586
1651
  def min(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
1587
1652
  """
@@ -1604,8 +1669,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
1604
1669
  print(t.min(axis=1, keepdim=True).numpy())
1605
1670
  ```
1606
1671
  """
1607
- if dtypes.is_int(self.dtype) or self.dtype == dtypes.bool: return ~((~self).max(axis=axis, keepdim=keepdim))
1608
- return -((-self).max(axis=axis, keepdim=keepdim))
1672
+ return self._inverse().max(axis=axis, keepdim=keepdim)._inverse()
1609
1673
 
1610
1674
  def any(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
1611
1675
  """
@@ -1745,8 +1809,8 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
1745
1809
  return self.std(axis, keepdim, correction), self.mean(axis, keepdim)
1746
1810
 
1747
1811
  def _softmax(self, axis, dtype:Optional[DTypeLike]=None):
1748
- x = self.cast(dtype) if dtype is not None else self
1749
- m = x - x.max(axis=axis, keepdim=True).detach()
1812
+ m = self - self.max(axis=axis, keepdim=True).detach()
1813
+ if dtype is not None: m = m.cast(dtype)
1750
1814
  e = m.exp()
1751
1815
  return m, e, e.sum(axis=axis, keepdim=True)
1752
1816
 
@@ -1898,47 +1962,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
1898
1962
  print(t.argmin(axis=1).numpy()) # Returns the indices of the minimum values along axis 1.
1899
1963
  ```
1900
1964
  """
1901
- return (-self).argmax(axis=axis, keepdim=keepdim)
1902
-
1903
- def rearrange(self, formula: str, **sizes) -> Tensor:
1904
- """
1905
- Rearranges input according to formula
1906
-
1907
- See: https://einops.rocks/api/rearrange/
1908
-
1909
- ```python exec="true" source="above" session="tensor" result="python"
1910
- x = Tensor([[1, 2], [3, 4]])
1911
- print(Tensor.rearrange(x, "batch channel -> (batch channel)).numpy())
1912
- ```
1913
- """
1914
- def parse_formula(formula: str):
1915
- tokens = f" {formula} ".replace("…", "...").replace("(", " ( ").replace(")", " ) ").replace(" ", " ").replace(" 1 ", " ( ) ").split()
1916
- lparens, rparens = map(lambda x: [i for i, ch in enumerate(tokens) if ch == x], ("(", ")"))
1917
- pairs = list(zip(lparens, rparens))
1918
- assert len(lparens) == len(rparens) and sorted(flatten(pairs)) == flatten(pairs), "bracket mismatch"
1919
- return [name for name in tokens if name not in ("(", ")")], [(s - 2*i, e - 1 - 2*i) for i, (s, e) in enumerate(pairs)]
1920
-
1921
- assert formula.count("->") == 1, 'need exactly one "->" in formula'
1922
-
1923
- (lhs, unflatten_dims), (rhs, flatten_dims) = map(parse_formula, formula.split("->"))
1924
-
1925
- for name in sizes: assert name in lhs, f"axis {name} is not used in transform"
1926
- assert sorted(lhs) == sorted(rhs) and len(lhs) == len(set(lhs)), f"name mismatch in {formula}"
1927
- for name in flatten((lhs, rhs)): assert name == "..." or (name.isidentifier() and "_" not in (name[0], name[-1])), f"invalid axis name {name}"
1928
- assert "..." not in flatten([lhs[s:e] for s, e in unflatten_dims]), f"cannot have collapsed ellipsis (...) in lhs of {formula}"
1929
- assert lhs.count("...") <= 1, f"too many ellipses in {formula}"
1930
-
1931
- # resolve ellipsis
1932
- if "..." in lhs: ell_len = len(self.shape) - len(lhs) + 1 + sum(e - s - 1 for s, e in unflatten_dims)
1933
- lhs, rhs = map(lambda l: l[:(i:=l.index("..."))] + [f"...{j}" for j in range(ell_len)] + l[i + 1:] if "..." in l else l, (lhs, rhs))
1934
- unflatten_dims = [(s + (ell_len - 1 if "...0" in lhs[:s] else 0), e + (ell_len - 1 if "...0" in lhs[:e] else 0)) for s, e in unflatten_dims]
1935
- flatten_dims = [(s + (ell_len - 1 if "...0" in rhs[:s] else 0), e + (ell_len - 1 if "...0" in rhs[:e] else 0)) for s, e in flatten_dims]
1936
-
1937
- # apply movement ops in order unflatten -> permute -> flatten/unsqueeze
1938
- t = functools.reduce(lambda x, dims: x.unflatten(dims[0], tuple(sizes.get(lhs[d], -1) for d in range(*dims))), unflatten_dims, self)
1939
- for i, name in enumerate(lhs): assert (name not in sizes) or sizes[name] == t.shape[i], f"size provided for dimension {name} incorrect"
1940
- t = t.permute([lhs.index(name) for name in rhs])
1941
- return functools.reduce(lambda x, dims: x.flatten(dims[0], dims[1] - 1) if dims[0]<dims[1] else x.unsqueeze(dims[0]), reversed(flatten_dims), t)
1965
+ return self._inverse().argmax(axis=axis, keepdim=keepdim)
1942
1966
 
1943
1967
  @staticmethod
1944
1968
  def einsum(formula:str, *operands:Tensor|Sequence[Tensor], acc_dtype:Optional[DTypeLike]=None) -> Tensor:
@@ -1964,7 +1988,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
1964
1988
  (inputs_str, out_ellipse + ''.join(sorted(c for c in inputs_str if inputs_str.count(c) == 1 and c.isalpha() and c not in out_ellipse)))
1965
1989
  return formula.split("->") if "->" in formula else (formula, ''.join(c for c in sorted(formula) if formula.count(c) == 1 and c.isalpha()))
1966
1990
 
1967
- xs:Tuple[Tensor, ...] = argfix(*operands)
1991
+ xs:tuple[Tensor, ...] = argfix(*operands)
1968
1992
  inputs_str, output = parse_formula(formula, *xs)
1969
1993
  inputs = inputs_str.split(",")
1970
1994
  assert len(xs) == len(inputs), f"number of inputs doesn't match number of operands in formula, expected {len(inputs)}, got {len(xs)}"
@@ -1972,7 +1996,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
1972
1996
  # map the value of each letter in the formula
1973
1997
  letter_val = sorted(merge_dicts([dict(zip(letters, tensor.shape)) for letters, tensor in zip(inputs, xs)]).items())
1974
1998
 
1975
- xs_:List[Tensor] = []
1999
+ xs_:list[Tensor] = []
1976
2000
  lhs = [sorted(enumerate(s), key=lambda e:e[1]) for s in inputs]
1977
2001
  for x,(order,letters) in zip(xs, [list(zip(*l)) for l in lhs]):
1978
2002
  # permute to the sorted letter order, then reshape/expand to create dimensions for the missing letters
@@ -1987,7 +2011,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
1987
2011
 
1988
2012
  # ***** processing ops *****
1989
2013
 
1990
- def _pool(self, k_:Tuple[sint, ...], stride:Union[Tuple[int, ...], int]=1, dilation:Union[Tuple[int, ...], int]=1) -> Tensor:
2014
+ def _pool(self, k_:tuple[sint, ...], stride:Union[tuple[int, ...], int]=1, dilation:Union[tuple[int, ...], int]=1) -> Tensor:
1991
2015
  assert len(self.shape) >= len(k_), f"can't pool {self.shape} with {k_}"
1992
2016
  s_, d_ = make_tuple(stride, len(k_)), make_tuple(dilation, len(k_))
1993
2017
  assert len(k_) == len(s_) == len(d_), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}"
@@ -1995,10 +2019,12 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
1995
2019
  assert all(resolve(d*(k-1)+1 <= i) for k,d,i in zip(k_,d_,i_)), "kernel size cannot be greater than actual input size"
1996
2020
  o_ = [ceildiv(i-d*(k-1), s) for i,d,k,s in zip(i_,d_,k_,s_)]
1997
2021
  if any(resolve(k > s) for k,s in zip(k_,s_)) or any(d != 1 for d in d_):
1998
- # repeats such that we don't need padding
1999
- x = self.repeat([1]*len(noop) + [ceildiv(k*(i+d), i) for k,i,d in zip(k_,i_,d_)])
2022
+ # input size scaling factor to make sure shrink for stride is possible
2023
+ f_ = [1 + int(resolve(o*s > i+d)) for o,s,i,d in zip(o_,s_,i_,d_)]
2024
+ # # repeats such that we don't need padding
2025
+ x = self.repeat([1]*len(noop) + [ceildiv(k*(i*f+d),i) for k,i,d,f in zip(k_,i_,d_,f_)])
2000
2026
  # handle dilation
2001
- x = x.shrink(tuple(noop + [(0,k*(i+d)) for k,i,d in zip(k_,i_,d_)])).reshape(noop + flatten((k,i+d) for k,i,d in zip(k_,i_,d_)))
2027
+ x = x.shrink(tuple(noop + [(0,k*(i*f+d)) for k,i,d,f in zip(k_,i_,d_,f_)])).reshape(noop + flatten((k,(i*f+d)) for k,i,d,f in zip(k_,i_,d_,f_)))
2002
2028
  # handle stride
2003
2029
  x = x.shrink(tuple(noop + flatten(((0,k), (0,o*s)) for k,o,s in zip(k_,o_,s_)))).reshape(noop + flatten((k,o,s) for k,o,s in zip(k_,o_,s_)))
2004
2030
  x = x.shrink(tuple(noop + flatten(((0,k), (0,o), (0,1)) for k,o in zip(k_,o_)))).reshape(noop + flatten((k,o) for k,o in zip(k_,o_)))
@@ -2010,14 +2036,44 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2010
2036
  x = x.shrink(tuple(noop + flatten(((0,o), (0,k)) for o,k in zip(o_,k_))))
2011
2037
  return x.permute(*range(len(noop)), *[len(noop)+i*2 for i in range(len(i_))], *[len(noop)+i*2+1 for i in range(len(i_))])
2012
2038
 
2013
- def _padding2d(self, padding:Union[int, Sequence[int]], dims:int) -> Sequence[int]:
2039
+ def _resolve_pool_pads(self, padding:Union[int, Sequence[int]], dims:int) -> Sequence[int]:
2040
+ if not isinstance(padding, int) and not (len(padding) == 2*dims or len(padding) == dims):
2041
+ raise ValueError(f"Padding must be an int or a sequence of length {dims} or {2*dims}, but got {padding=} for {self.shape=} with {dims=}.")
2014
2042
  return [padding]*2*dims if isinstance(padding, int) else (padding if len(padding) == 2*dims else [p for p in padding for _ in range(2)][::-1])
2015
2043
 
2044
+ def _apply_ceil_mode(self, pads:Sequence[int], k_:Tuple[sint, ...], s_:Union[Tuple[int, ...], int], d_:Union[Tuple[int, ...], int]) -> List[int]:
2045
+ (d_,s_), i_ = (make_tuple(x, len(k_)) for x in (d_,s_)), self.shape[-len(k_):]
2046
+ pads, grouped_pads = list(pads), _flat_to_grouped(pads)
2047
+ # https://arxiv.org/pdf/1603.07285 section 5.1, relationship 15.
2048
+ o_ = [ceildiv(i+pB+pA - (d*(k-1)+1), s) + 1 for i,d,k,s,(pB,pA) in zip(i_,d_,k_,s_,grouped_pads)]
2049
+ for dim,(o,i,s,k,d,(pB,pA)) in enumerate(zip(o_,i_,s_,k_,d_,grouped_pads)):
2050
+ # we have to do additional padding before `_pool` so that `o_` in `_pool` is calculated correctly
2051
+ # `s*(o-1) + (d*(k-1)+1) - (i+pB+pA)` -> last_sliding_window_start + full_kernel_size - padded_input_shape
2052
+ # we decrease padding in the case that a sliding window starts in the end padded region, thereby decreasing `o_` in `_pool`
2053
+ # `smax(s*(o-1) - (pB+i-1), 0)` -> last_sliding_window_start - (pad_before + input_size - zero_offset)
2054
+ pads[-1-dim*2] += s*(o-1) + (d*(k-1)+1) - (i+pB+pA) - smax(s*(o-1) - (pB+i-1), 0)
2055
+ return pads
2056
+
2016
2057
  # NOTE: these work for more than 2D
2017
- def avg_pool2d(self, kernel_size=(2,2), stride=None, dilation=1, padding=0, count_include_pad=True):
2058
+ def avg_pool2d(self, kernel_size=(2,2), stride=None, dilation=1, padding=0, ceil_mode=False, count_include_pad=True):
2018
2059
  """
2019
2060
  Applies average pooling over a tensor.
2020
2061
 
2062
+ This function supports three different types of `padding`
2063
+
2064
+ 1. `int` (single value):
2065
+ Applies the same padding value uniformly to all spatial dimensions.
2066
+
2067
+ 2. `Tuple[int, ...]` (length = number of spatial dimensions):
2068
+ Specifies a distinct padding value for each spatial dimension in the form `(padding_height, padding_width, ...)`.
2069
+
2070
+ 3. `Tuple[int, ...]` (length = 2 * number of spatial dimensions):
2071
+ Specifies explicit padding for each side of each spatial dimension in the form
2072
+ `(padding_left, padding_right, padding_top, padding_bottom, ...)`.
2073
+
2074
+ When `ceil_mode` is set to `True`, output shape will be determined using ceil division.
2075
+ When `count_include_pad` is set to `False`, zero padding will not be included in the averaging calculation.
2076
+
2021
2077
  NOTE: unlike PyTorch, this implementation is not limited to only 2d pooling and instead works for any number of dimensions.
2022
2078
 
2023
2079
  See: https://paperswithcode.com/method/average-pooling
@@ -2027,17 +2083,43 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2027
2083
  print(t.avg_pool2d().numpy())
2028
2084
  ```
2029
2085
  ```python exec="true" source="above" session="tensor" result="python"
2086
+ print(t.avg_pool2d(ceil_mode=True).numpy())
2087
+ ```
2088
+ ```python exec="true" source="above" session="tensor" result="python"
2030
2089
  print(t.avg_pool2d(padding=1).numpy())
2031
2090
  ```
2091
+ ```python exec="true" source="above" session="tensor" result="python"
2092
+ print(t.avg_pool2d(padding=1, count_include_pad=False).numpy())
2093
+ ```
2032
2094
  """
2033
- padding_, axis = self._padding2d(padding, len(k_ := make_tuple(kernel_size, 2))), tuple(range(-len(k_), 0))
2034
- def pool(x:Tensor) -> Tensor: return x.pad(padding_)._pool(k_, stride if stride is not None else k_, dilation)
2035
- return pool(self).mean(axis=axis) if count_include_pad else pool(self).sum(axis=axis) / pool(self.ones_like()).sum(axis=axis)
2095
+ axis = tuple(range(-len(k_ := make_tuple(kernel_size, 2)), 0))
2096
+ def pool(x:Tensor, padding_:Sequence[int]) -> Tensor: return x.pad(padding_)._pool(k_, stride if stride is not None else k_, dilation)
2097
+ reg_pads = self._resolve_pool_pads(padding, len(k_))
2098
+ ceil_pads = self._apply_ceil_mode(reg_pads, k_, stride if stride is not None else k_, dilation)
2099
+ if not count_include_pad:
2100
+ pads = ceil_pads if ceil_mode else reg_pads
2101
+ return pool(self, pads).sum(axis) / pool(self.ones_like(), pads).sum(axis)
2102
+ if not ceil_mode: return pool(self, reg_pads).mean(axis)
2103
+ return pool(self, ceil_pads).sum(axis) / pool(self.pad(reg_pads).ones_like(), tuple(cp-rp for cp,rp in zip(ceil_pads, reg_pads))).sum(axis)
2036
2104
 
2037
- def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1, padding=0):
2105
+ def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1, padding=0, ceil_mode=False):
2038
2106
  """
2039
2107
  Applies max pooling over a tensor.
2040
2108
 
2109
+ This function supports three different types of `padding`
2110
+
2111
+ 1. `int` (single value):
2112
+ Applies the same padding value uniformly to all spatial dimensions.
2113
+
2114
+ 2. `Tuple[int, ...]` (length = number of spatial dimensions):
2115
+ Specifies a distinct padding value for each spatial dimension in the form `(padding_height, padding_width, ...)`.
2116
+
2117
+ 3. `Tuple[int, ...]` (length = 2 * number of spatial dimensions):
2118
+ Specifies explicit padding for each side of each spatial dimension in the form
2119
+ `(padding_left, padding_right, padding_top, padding_bottom, ...)`.
2120
+
2121
+ When `ceil_mode` is set to `True`, output shape will be determined using ceil division.
2122
+
2041
2123
  NOTE: unlike PyTorch, this implementation is not limited to only 2d pooling and instead works for any number of dimensions.
2042
2124
 
2043
2125
  See: https://paperswithcode.com/method/max-pooling
@@ -2047,17 +2129,33 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2047
2129
  print(t.max_pool2d().numpy())
2048
2130
  ```
2049
2131
  ```python exec="true" source="above" session="tensor" result="python"
2132
+ print(t.max_pool2d(ceil_mode=True).numpy())
2133
+ ```
2134
+ ```python exec="true" source="above" session="tensor" result="python"
2050
2135
  print(t.max_pool2d(padding=1).numpy())
2051
2136
  ```
2052
2137
  """
2053
- padding_ = self._padding2d(padding, len(k_ := make_tuple(kernel_size, 2)))
2054
- return self.pad(padding_, value=dtypes.min(self.dtype))._pool(k_, stride if stride is not None else k_, dilation).max(tuple(range(-len(k_), 0)))
2138
+ pads = self._resolve_pool_pads(padding, len(k_ := make_tuple(kernel_size, 2)))
2139
+ if ceil_mode: pads = self._apply_ceil_mode(pads, k_, stride if stride is not None else k_, dilation)
2140
+ return self.pad(pads, value=dtypes.min(self.dtype))._pool(k_, stride if stride is not None else k_, dilation).max(tuple(range(-len(k_), 0)))
2055
2141
 
2056
- def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding:int|Tuple[int, ...]=0,
2142
+ def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding:int|tuple[int, ...]=0,
2057
2143
  acc_dtype:Optional[DTypeLike]=None) -> Tensor:
2058
2144
  """
2059
2145
  Applies a convolution over a tensor with a given `weight` and optional `bias`.
2060
2146
 
2147
+ This function supports three different types of `padding`
2148
+
2149
+ 1. `int` (single value):
2150
+ Applies the same padding value uniformly to all spatial dimensions.
2151
+
2152
+ 2. `Tuple[int, ...]` (length = number of spatial dimensions):
2153
+ Specifies a distinct padding value for each spatial dimension in the form `(padding_height, padding_width, ...)`.
2154
+
2155
+ 3. `Tuple[int, ...]` (length = 2 * number of spatial dimensions):
2156
+ Specifies explicit padding for each side of each spatial dimension in the form
2157
+ `(padding_left, padding_right, padding_top, padding_bottom, ...)`.
2158
+
2061
2159
  NOTE: unlike PyTorch, this implementation is not limited to only 2d convolutions and instead works for any number of dimensions.
2062
2160
 
2063
2161
  See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
@@ -2070,9 +2168,8 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2070
2168
  """
2071
2169
  if IMAGE: return self.image_conv2d(weight, bias, groups, stride, dilation, padding, acc_dtype)
2072
2170
  (bs,cin_), (cout,cin), HW = self.shape[:2], weight.shape[:2], weight.shape[2:]
2171
+ padding_ = self._resolve_pool_pads(padding, len(HW))
2073
2172
  assert groups*cin == cin_ and len(self.shape) == len(weight.shape), f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})" # noqa: E501
2074
- if isinstance(padding, (tuple,list)): assert len(padding) == 2*len(HW) or len(padding) == len(HW), f"Expected padding of length {2*len(HW)} or {len(HW)}, but got {len(padding)} for tensor of shape {self.shape}" # noqa: E501
2075
- padding_ = self._padding2d(padding, len(HW))
2076
2173
 
2077
2174
  # conv2d is a pooling op (with padding)
2078
2175
  x = self.pad(padding_)._pool(HW, stride, dilation) # (bs, groups*cin, oy, ox, H, W)
@@ -2120,6 +2217,18 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2120
2217
  """
2121
2218
  Applies a transposed convolution over a tensor with a given `weight` and optional `bias`.
2122
2219
 
2220
+ This function supports three different types of `padding`
2221
+
2222
+ 1. `int` (single value):
2223
+ Applies the same padding value uniformly to all spatial dimensions.
2224
+
2225
+ 2. `Tuple[int, ...]` (length = number of spatial dimensions):
2226
+ Specifies a distinct padding value for each spatial dimension in the form `(padding_height, padding_width, ...)`.
2227
+
2228
+ 3. `Tuple[int, ...]` (length = 2 * number of spatial dimensions):
2229
+ Specifies explicit padding for each side of each spatial dimension in the form
2230
+ `(padding_left, padding_right, padding_top, padding_bottom, ...)`.
2231
+
2123
2232
  NOTE: unlike PyTorch, this implementation is not limited to only 2d transposed convolutions and instead works for any number of dimensions.
2124
2233
 
2125
2234
  See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
@@ -2132,14 +2241,15 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2132
2241
  """
2133
2242
  x, w = self, weight.unflatten(0, (groups, -1)).transpose(1, 2).flip(*range(3, len(weight.shape)+1))
2134
2243
  HW = weight.shape[2:]
2135
- stride, dilation, padding, output_padding = [make_tuple(x, len(HW)) for x in (stride, dilation, padding, output_padding)]
2244
+ padding = _flat_to_grouped(self._resolve_pool_pads(padding, len(HW)))
2245
+ stride, dilation, output_padding = [make_tuple(x, len(HW)) for x in (stride, dilation, output_padding)]
2136
2246
  if any(s>1 for s in stride):
2137
2247
  # handle strides: (k) -> reshape -> (k,1) -> pad -> (k,s) -> reshape -> (k*s) -> shrink (k-(s-1))
2138
2248
  x = x.reshape(None, None, *flatten((k,1) for k in x.shape[2:]))
2139
2249
  x = x.pad((None, None, *flatten((None,(0,s-1)) for s in stride)))
2140
2250
  x = x.reshape(None, None, *[k*s for k,s in zip(x.shape[2::2], stride)])
2141
2251
  x = x.shrink((None, None, *[(0,k-(s-1)) for k,s in zip(x.shape[2:], stride)]))
2142
- padding = flatten((((k-1)*d-p,(k-1)*d-p+op) for k,d,p,op in reversed(list(zip(HW, dilation, padding, output_padding)))))
2252
+ padding = flatten((((k-1)*d-pB,(k-1)*d-pA+op) for k,d,(pB,pA),op in reversed(list(zip(HW, dilation, padding, output_padding)))))
2143
2253
  return x.conv2d(w.flatten(end_dim=1), groups=groups, bias=bias, dilation=dilation, padding=padding)
2144
2254
 
2145
2255
  def dot(self, w:Tensor, acc_dtype:Optional[DTypeLike]=None) -> Tensor:
@@ -2185,15 +2295,28 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2185
2295
  """
2186
2296
  return x.dot(self, acc_dtype=acc_dtype) if reverse else self.dot(x, acc_dtype=acc_dtype)
2187
2297
 
2188
- def _cumsum(self, axis:int=0, _first_zero=False) -> Tensor:
2189
- assert self.shape[axis] != 0
2190
- pl_sz = self.shape[axis] - int(not _first_zero)
2191
- return self.transpose(axis,-1).pad((pl_sz,-int(_first_zero)))._pool((self.shape[axis],)).sum(-1).transpose(axis,-1)
2298
+ def _cumalu(self, axis:int, op:Ops, _include_initial=False) -> Tensor:
2299
+ assert self.shape[axis] != 0 and op in (Ops.ADD, Ops.MAX)
2300
+ pl_sz = self.shape[axis] - int(not _include_initial)
2301
+ pooled = self.transpose(axis,-1).pad((pl_sz, -int(_include_initial)), value=identity_element(op, self.dtype))._pool((self.shape[axis],))
2302
+ return (pooled.sum(-1) if op is Ops.ADD else pooled.max(-1)).transpose(axis,-1)
2303
+
2304
+ def _split_cumalu(self, axis:int, op:Ops) -> Tensor:
2305
+ axis = self._resolve_dim(axis)
2306
+ if self.ndim == 0 or 0 in self.shape: return self
2307
+ # TODO: someday the optimizer will find this on it's own
2308
+ # for now this is a two stage cumsum
2309
+ SPLIT = 256
2310
+ if not isinstance(s:=self.shape[axis], int) or s <= SPLIT*2: return self._cumalu(axis, op)
2311
+ ret = self.transpose(axis,-1).pad((round_up(s, SPLIT)-s, 0), value=identity_element(op, self.dtype)).unflatten(-1, (-1, SPLIT))._cumalu(-1, op)
2312
+ base = ret[..., -1]._cumalu(-1, op, _include_initial=True)
2313
+ base = base.unsqueeze(-1).expand(*base.shape, ret.shape[-1])
2314
+ def fix(x:Tensor): return x.flatten(start_dim=-2)[..., -s:].transpose(axis,-1)
2315
+ return fix(ret) + fix(base) if op is Ops.ADD else fix(ret).maximum(fix(base))
2316
+
2192
2317
  def cumsum(self, axis:int=0) -> Tensor:
2193
2318
  """
2194
- Computes the cumulative sum of the tensor along the specified axis.
2195
-
2196
- You can pass in the `axis` keyword argument to control the axis along which the cumulative sum is computed.
2319
+ Computes the cumulative sum of the tensor along the specified `axis`.
2197
2320
 
2198
2321
  ```python exec="true" source="above" session="tensor" result="python"
2199
2322
  t = Tensor.ones(2, 3)
@@ -2203,17 +2326,21 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2203
2326
  print(t.cumsum(1).numpy())
2204
2327
  ```
2205
2328
  """
2206
- axis = self._resolve_dim(axis)
2207
- if self.ndim == 0 or 0 in self.shape: return self
2208
- # TODO: someday the optimizer will find this on it's own
2209
- # for now this is a two stage cumsum
2210
- SPLIT = 256
2211
- if not isinstance(s:=self.shape[axis], int) or s <= SPLIT*2: return self._cumsum(axis)
2212
- ret = self.transpose(axis,-1).pad((round_up(s, SPLIT)-s, 0)).unflatten(-1, (-1, SPLIT))._cumsum(-1)
2213
- base_add = ret[..., -1]._cumsum(-1, _first_zero=True)
2214
- base_add = base_add.unsqueeze(-1).expand(*base_add.shape, ret.shape[-1])
2215
- def fix(x:Tensor): return x.flatten(start_dim=-2)[..., -s:].transpose(axis,-1)
2216
- return fix(ret) + fix(base_add)
2329
+ return self._split_cumalu(axis, Ops.ADD)
2330
+
2331
+ def cummax(self, axis:int=0) -> Tensor:
2332
+ """
2333
+ Computes the cumulative max of the tensor along the specified `axis`.
2334
+
2335
+ ```python exec="true" source="above" session="tensor" result="python"
2336
+ t = Tensor([0, 1, -1, 2, -2, 3, -3])
2337
+ print(t.numpy())
2338
+ ```
2339
+ ```python exec="true" source="above" session="tensor" result="python"
2340
+ print(t.cummax(0).numpy())
2341
+ ```
2342
+ """
2343
+ return self._split_cumalu(axis, Ops.MAX)
2217
2344
 
2218
2345
  @staticmethod
2219
2346
  def _tri(r:sint, c:sint, diagonal:int=0, **kwargs) -> Tensor:
@@ -2271,7 +2398,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2271
2398
  """
2272
2399
  return Tensor._tri(self.shape[-2], self.shape[-1], diagonal=diagonal+1, device=self.device, dtype=dtypes.bool).where(0, self).cast(self.dtype)
2273
2400
 
2274
- def interpolate(self, size:Tuple[int, ...], mode:str="linear", align_corners:bool=False) -> Tensor:
2401
+ def interpolate(self, size:tuple[int, ...], mode:str="linear", align_corners:bool=False) -> Tensor:
2275
2402
  """
2276
2403
  Downsamples or Upsamples to the input `size`, accepts 0 to N batch dimensions.
2277
2404
 
@@ -2303,6 +2430,47 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2303
2430
  x = x.gather(i, index)
2304
2431
  return x.cast(self.dtype)
2305
2432
 
2433
+ def scatter(self, dim:int, index:Tensor, src:Union[Tensor, ConstType], reduce:Union[None, Literal['multiply'], Literal['add']]=None) -> Tensor:
2434
+ """
2435
+ Scatters `src` values along an axis specified by `dim`.
2436
+ Apply `add` or `multiply` reduction operation with `reduce`.
2437
+
2438
+ ```python exec="true" source="above" session="tensor" result="python"
2439
+ src = Tensor.arange(1, 11).reshape(2, 5)
2440
+ print(src.numpy())
2441
+ ```
2442
+ ```python exec="true" source="above" session="tensor" result="python"
2443
+ index = Tensor([[0, 1, 2, 0]])
2444
+ print(Tensor.zeros(3, 5, dtype=src.dtype).scatter(0, index, src).numpy())
2445
+ ```
2446
+ ```python exec="true" source="above" session="tensor" result="python"
2447
+ index = Tensor([[0, 1, 2], [0, 1, 4]])
2448
+ print(Tensor.zeros(3, 5, dtype=src.dtype).scatter(1, index, src).numpy())
2449
+ ```
2450
+ ```python exec="true" source="above" session="tensor" result="python"
2451
+ print(Tensor.full((2, 4), 2.0).scatter(1, Tensor([[2], [3]]), 1.23, reduce='multiply').numpy())
2452
+ ```
2453
+ ```python exec="true" source="above" session="tensor" result="python"
2454
+ print(Tensor.full((2, 4), 2.0).scatter(1, Tensor([[2], [3]]), 1.23, reduce='add').numpy())
2455
+ ```
2456
+ """
2457
+ if reduce not in {None, "add", "multiply"}: raise TypeError(f"{reduce=} must be one of None, 'multiply', or 'add'")
2458
+ index, dim = index.to(self.device), self._resolve_dim(dim)
2459
+ src = src.cast(self.dtype) if isinstance(src, Tensor) else Tensor(src, device=self.device, dtype=self.dtype)._broadcast_to(index.shape)
2460
+ assert index.ndim == self.ndim == src.ndim, f"self.ndim, index.ndim and src.dim must all equal, {self.ndim=} {index.ndim=} {src.ndim=}"
2461
+ assert all((d == dim or self_ >= index_) and src_ >= index_ for d,(self_,index_,src_) in enumerate(zip(self.shape, index.shape, src.shape))), \
2462
+ f"All dimensions of {index.shape=} should be <= to all dimensions of {src.shape=} and all dimensions except dimension {dim} of {self.shape=}"
2463
+ # shrink src to index shape to shrink away the unused values
2464
+ src = src.shrink(tuple((0,s) for s in index.shape))
2465
+ # prepare src and mask for reduce with respect to dim
2466
+ src = src.unsqueeze(-1).expand(*src.shape, self.shape[dim]).transpose(-1, dim)
2467
+ mask = index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim]).transpose(-1, dim)
2468
+ # pad src and mask to self.shape so that reduce can be done with padded values as no-ops
2469
+ src, mask = (x.pad(tuple((0, self.shape[i] - x.shape[i]) if i != dim else None for i in range(self.ndim)) + (None,)) for x in (src, mask))
2470
+ if reduce == "add": return mask.where(src, 0).sum(-1, acc_dtype=self.dtype) + self
2471
+ if reduce == "multiply": return mask.where(src, 1).prod(-1, acc_dtype=self.dtype) * self
2472
+ return _masked_setitem(self, src, mask, (-1,))
2473
+
2306
2474
  # ***** unary ops *****
2307
2475
 
2308
2476
  def logical_not(self):
@@ -2313,7 +2481,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2313
2481
  print(Tensor([False, True]).logical_not().numpy())
2314
2482
  ```
2315
2483
  """
2316
- return F.Neq.apply(*self.cast(dtypes.bool)._broadcasted(True))
2484
+ return self.cast(dtypes.bool)._apply_broadcasted_uop(UOp.ne, True)
2317
2485
  def neg(self):
2318
2486
  """
2319
2487
  Negates the tensor element-wise.
@@ -2327,12 +2495,12 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2327
2495
  """
2328
2496
  Returns a contiguous tensor.
2329
2497
  """
2330
- return F.Contiguous.apply(self)
2498
+ return self._apply_uop(UOp.contiguous)
2331
2499
  def contiguous_backward(self):
2332
2500
  """
2333
2501
  Inserts a contiguous operation in the backward pass.
2334
2502
  """
2335
- return F.ContiguousBackward.apply(self)
2503
+ return self._apply_uop(UOp.contiguous_backward)
2336
2504
  def log(self):
2337
2505
  """
2338
2506
  Computes the natural logarithm element-wise.
@@ -2343,7 +2511,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2343
2511
  print(Tensor([1., 2., 4., 8.]).log().numpy())
2344
2512
  ```
2345
2513
  """
2346
- return F.Log.apply(self.cast(least_upper_float(self.dtype)))
2514
+ return self.log2()*math.log(2)
2347
2515
  def log2(self):
2348
2516
  """
2349
2517
  Computes the base-2 logarithm element-wise.
@@ -2354,7 +2522,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2354
2522
  print(Tensor([1., 2., 4., 8.]).log2().numpy())
2355
2523
  ```
2356
2524
  """
2357
- return self.log()/math.log(2)
2525
+ return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.log2)
2358
2526
  def exp(self):
2359
2527
  """
2360
2528
  Computes the exponential function element-wise.
@@ -2365,7 +2533,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2365
2533
  print(Tensor([0., 1., 2., 3.]).exp().numpy())
2366
2534
  ```
2367
2535
  """
2368
- return F.Exp.apply(self.cast(least_upper_float(self.dtype)))
2536
+ return self.mul(1/math.log(2)).exp2()
2369
2537
  def exp2(self):
2370
2538
  """
2371
2539
  Computes the base-2 exponential function element-wise.
@@ -2376,7 +2544,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2376
2544
  print(Tensor([0., 1., 2., 3.]).exp2().numpy())
2377
2545
  ```
2378
2546
  """
2379
- return F.Exp.apply(self*math.log(2))
2547
+ return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.exp2)
2380
2548
  def relu(self):
2381
2549
  """
2382
2550
  Applies the Rectified Linear Unit (ReLU) function element-wise.
@@ -2387,7 +2555,8 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2387
2555
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).relu().numpy())
2388
2556
  ```
2389
2557
  """
2390
- return F.Relu.apply(self)
2558
+ return (self>0).where(self, 0)
2559
+
2391
2560
  def sigmoid(self):
2392
2561
  """
2393
2562
  Applies the Sigmoid function element-wise.
@@ -2398,7 +2567,8 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2398
2567
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sigmoid().numpy())
2399
2568
  ```
2400
2569
  """
2401
- return F.Sigmoid.apply(self.cast(least_upper_float(self.dtype)))
2570
+ return (1 + (self * (-1/math.log(2))).exp2()).reciprocal()
2571
+
2402
2572
  def hardsigmoid(self, alpha:float=1/6, beta:float=0.5):
2403
2573
  """
2404
2574
  Applies the Hardsigmoid function element-wise.
@@ -2421,7 +2591,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2421
2591
  print(Tensor([1., 2., 3., 4.]).sqrt().numpy())
2422
2592
  ```
2423
2593
  """
2424
- return F.Sqrt.apply(self.cast(least_upper_float(self.dtype)))
2594
+ return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.sqrt)
2425
2595
  def rsqrt(self):
2426
2596
  """
2427
2597
  Computes the reciprocal of the square root of the tensor element-wise.
@@ -2439,7 +2609,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2439
2609
  print(Tensor([0., math.pi/2, math.pi, 3*math.pi/2, 2*math.pi]).sin().numpy())
2440
2610
  ```
2441
2611
  """
2442
- return F.Sin.apply(self.cast(least_upper_float(self.dtype)))
2612
+ return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.sin)
2443
2613
  def cos(self):
2444
2614
  """
2445
2615
  Computes the cosine of the tensor element-wise.
@@ -2459,6 +2629,39 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2459
2629
  """
2460
2630
  return self.sin() / self.cos()
2461
2631
 
2632
+ def asin(self):
2633
+ """
2634
+ Computes the inverse sine (arcsine) of the tensor element-wise.
2635
+
2636
+ ```python exec="true" source="above" session="tensor" result="python"
2637
+ print(Tensor([-0.9, -0.6, -0.3, 0., 0.3, 0.6, 0.9]).asin().numpy())
2638
+ ```
2639
+ """
2640
+ # https://personal.math.ubc.ca/~cbm/aands/page_81.htm 4.4.46
2641
+ coefficients = [-0.0012624911, 0.0066700901, -0.0170881256, 0.0308918810, -0.0501743046, 0.0889789874, -0.2145988016, 1.5707963050]
2642
+ x = math.pi / 2 - (1.0 - self.abs()).sqrt() * polyN(self.abs(), coefficients)
2643
+ return self.sign() * x
2644
+
2645
+ def acos(self):
2646
+ """
2647
+ Computes the inverse cosine (arccosine) of the tensor element-wise.
2648
+
2649
+ ```python exec="true" source="above" session="tensor" result="python"
2650
+ print(Tensor([-0.9, -0.6, -0.3, 0., 0.3, 0.6, 0.9]).acos().numpy())
2651
+ ```
2652
+ """
2653
+ return math.pi / 2 - self.asin()
2654
+
2655
+ def atan(self):
2656
+ """
2657
+ Computes the inverse tangent (arctan) of the tensor element-wise.
2658
+
2659
+ ```python exec="true" source="above" session="tensor" result="python"
2660
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).atan().numpy())
2661
+ ```
2662
+ """
2663
+ return (self / (1 + self * self).sqrt()).asin()
2664
+
2462
2665
  # ***** math functions *****
2463
2666
 
2464
2667
  def trunc(self: Tensor) -> Tensor:
@@ -2565,7 +2768,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2565
2768
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sign().numpy())
2566
2769
  ```
2567
2770
  """
2568
- return F.Sign.apply(self)
2771
+ return self.ne(0).where((self<0).where(self.full_like(-1), self.full_like(1)), self.full_like(0)) + self*0
2569
2772
  def abs(self):
2570
2773
  """
2571
2774
  Computes the absolute value of the tensor element-wise.
@@ -2583,7 +2786,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2583
2786
  print(Tensor([1., 2., 3., 4.]).reciprocal().numpy())
2584
2787
  ```
2585
2788
  """
2586
- return F.Reciprocal.apply(self.cast(least_upper_float(self.dtype)))
2789
+ return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.reciprocal)
2587
2790
 
2588
2791
  # ***** activation functions *****
2589
2792
 
@@ -2613,6 +2816,19 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2613
2816
  """
2614
2817
  return self.maximum(0) + (alpha * ((self / alpha).exp() - 1)).minimum(0)
2615
2818
 
2819
+ def selu(self, alpha=1.67326, gamma=1.0507):
2820
+ """
2821
+ Applies the Scaled Exponential Linear Unit (SELU) function element-wise.
2822
+
2823
+ - Described: https://paperswithcode.com/method/selu
2824
+ - Paper: https://arxiv.org/abs/1706.02515v5
2825
+
2826
+ ```python exec="true" source="above" session="tensor" result="python"
2827
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).selu().numpy())
2828
+ ```
2829
+ """
2830
+ return gamma * (self >= 0).detach().where(self, alpha * (self.exp() - 1))
2831
+
2616
2832
  def swish(self):
2617
2833
  """
2618
2834
  See `.silu()`
@@ -2840,17 +3056,17 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2840
3056
  return self / (1 + self.abs())
2841
3057
 
2842
3058
  # ***** broadcasted elementwise ops *****
2843
- def _broadcast_to(self, shape:Tuple[sint, ...]) -> Tensor:
2844
- if self.shape == shape: return self
2845
- if self.ndim > len(shape): raise ValueError(f"cannot broadcast tensor to fewer dimensions. shape={self.shape} to {shape=}")
2846
- # first pad left with 1s https://data-apis.org/array-api/latest/API_specification/broadcasting.html
2847
- padded, _ = _pad_left(self.shape, shape)
2848
- # for each dimension, check either from_ is 1, or it does not change
2849
- if any(resolve(from_ != 1, False) and resolve(from_ != to, False) for from_,to in zip(padded, shape)):
2850
- raise ValueError(f"cannot broadcast from shape={self.shape} to {shape=}")
2851
- return F.Expand.apply(self.reshape(padded), shape=shape)
2852
-
2853
- def _broadcasted(self, y:Union[Tensor, UOp, ConstType], reverse:bool=False, match_dtype:bool=True) -> Tuple[Tensor, Tensor]:
3059
+ def _broadcast_to(self, new_shape:tuple[sint, ...]) -> Tensor:
3060
+ if self.shape == new_shape: return self
3061
+ if self.ndim > len(new_shape): raise ValueError(f"cannot broadcast tensor to fewer dimensions. shape={self.shape} to {new_shape=}")
3062
+ # first unsqueeze left with 1s https://data-apis.org/array-api/latest/API_specification/broadcasting.html
3063
+ shape, _ = _align_left(self.shape, new_shape)
3064
+ # for each dimension, check either dim is 1, or it does not change
3065
+ if not all(resolve(s == ns) or resolve(s == 1) for s,ns in zip(shape, new_shape)):
3066
+ raise ValueError(f"cannot broadcast {self.shape} to {new_shape=}")
3067
+ return self.reshape(shape)._apply_uop(UOp.expand, arg=new_shape)
3068
+
3069
+ def _broadcasted(self, y:Union[Tensor, UOp, ConstType], reverse:bool=False, match_dtype:bool=True) -> tuple[Tensor, Tensor]:
2854
3070
  x: Tensor = self
2855
3071
  if not isinstance(y, Tensor):
2856
3072
  # make y a Tensor
@@ -2867,12 +3083,12 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2867
3083
  if reverse: x, y = y, x
2868
3084
 
2869
3085
  # broadcast
2870
- out_shape = _broadcast_shape(x.shape, y.shape)
2871
- return x._broadcast_to(out_shape), y._broadcast_to(out_shape)
3086
+ return x._broadcast_to(out_shape:=_broadcast_shape(x.shape, y.shape)), y._broadcast_to(out_shape)
2872
3087
 
3088
+ # TODO: tensor should stop checking if things are const
2873
3089
  def _to_const_val(self, x:Union[Tensor, ConstType]) -> Union[Tensor, ConstType]:
2874
- return x.lazydata.base.arg if isinstance(x, Tensor) and isinstance(x.lazydata, LazyBuffer) and x.lazydata.is_unrealized_unmasked_const() \
2875
- and not x.requires_grad and self._broadcasted(x)[0].shape == self.shape else x
3090
+ return x.lazydata.const_arg if isinstance(x, Tensor) and isinstance(x.lazydata, UOp) and x.lazydata.base.op is Ops.CONST \
3091
+ and unwrap(x.lazydata.st).views[0].mask is None and not x.requires_grad and self._broadcasted(x)[0].shape == self.shape else x
2876
3092
 
2877
3093
  def add(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
2878
3094
  """
@@ -2892,7 +3108,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2892
3108
  print(t.add(Tensor([[2.0], [3.5]])).numpy())
2893
3109
  ```
2894
3110
  """
2895
- return F.Add.apply(*self._broadcasted(x, reverse))
3111
+ return self._apply_broadcasted_uop(UOp.add, x, reverse)
2896
3112
 
2897
3113
  def sub(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
2898
3114
  """
@@ -2933,20 +3149,20 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2933
3149
  print(t.mul(Tensor([[-1.0], [2.0]])).numpy())
2934
3150
  ```
2935
3151
  """
2936
- return F.Mul.apply(*self._broadcasted(x, reverse))
3152
+ return self._apply_broadcasted_uop(UOp.mul, x, reverse)
2937
3153
 
2938
3154
  def idiv(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
2939
3155
  """
2940
3156
  Divides `self` by `x`.
2941
3157
  Equivalent to `self // x`.
2942
3158
  Supports broadcasting to a common shape, type promotion, and integer inputs.
2943
- `idiv` performs integer division.
3159
+ `idiv` performs integer division (truncate towards zero).
2944
3160
 
2945
3161
  ```python exec="true" source="above" session="tensor" result="python"
2946
- print(Tensor([1, 4, 10]).idiv(Tensor([2, 3, 4])).numpy())
3162
+ print(Tensor([-4, 7, 5, 4, -7, 8]).idiv(Tensor([2, -3, 8, -2, 3, 5])).numpy())
2947
3163
  ```
2948
3164
  """
2949
- return F.IDiv.apply(*self._broadcasted(x, reverse))
3165
+ return self._apply_broadcasted_uop(UOp.idiv, x, reverse)
2950
3166
 
2951
3167
  def div(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
2952
3168
  """
@@ -2970,6 +3186,19 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2970
3186
  numerator, denominator = self._broadcasted(x, reverse)
2971
3187
  return numerator.cast(least_upper_float(numerator.dtype)) * denominator.cast(least_upper_float(denominator.dtype)).reciprocal()
2972
3188
 
3189
+ def mod(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
3190
+ """
3191
+ Mod `self` by `x`.
3192
+ Equivalent to `self % x`.
3193
+ Supports broadcasting to a common shape, type promotion, and integer inputs.
3194
+
3195
+ ```python exec="true" source="above" session="tensor" result="python"
3196
+ print(Tensor([-4, 7, 5, 4, -7, 8]).mod(Tensor([2, -3, 8, -2, 3, 5])).numpy())
3197
+ ```
3198
+ """
3199
+ a, b = self._broadcasted(x, reverse)
3200
+ return (r := a._apply_uop(UOp.mod, b)) + b * (((r < 0) & (b > 0)) | ((r > 0) & (b < 0)))
3201
+
2973
3202
  def xor(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
2974
3203
  """
2975
3204
  Computes bitwise xor of `self` and `x`.
@@ -2984,7 +3213,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2984
3213
  ```
2985
3214
  """
2986
3215
  if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
2987
- return F.Xor.apply(*self._broadcasted(x, reverse))
3216
+ return self._apply_broadcasted_uop(UOp.xor, x, reverse)
2988
3217
 
2989
3218
  def bitwise_and(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
2990
3219
  """
@@ -2999,7 +3228,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2999
3228
  ```
3000
3229
  """
3001
3230
  if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
3002
- return F.BitwiseAnd.apply(*self._broadcasted(x, reverse))
3231
+ return self._apply_broadcasted_uop(UOp.bitwise_and, x, reverse)
3003
3232
 
3004
3233
  def bitwise_or(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
3005
3234
  """
@@ -3014,7 +3243,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
3014
3243
  ```
3015
3244
  """
3016
3245
  if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
3017
- return F.BitwiseOr.apply(*self._broadcasted(x, reverse))
3246
+ return self._apply_broadcasted_uop(UOp.bitwise_or, x, reverse)
3018
3247
 
3019
3248
  def bitwise_not(self) -> Tensor:
3020
3249
  """
@@ -3028,7 +3257,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
3028
3257
  ```
3029
3258
  """
3030
3259
  if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
3031
- return self.logical_not() if self.dtype == dtypes.bool else self ^ ((1<<8*self.dtype.itemsize)-1)
3260
+ return self.logical_not() if self.dtype == dtypes.bool else self ^ -1
3032
3261
 
3033
3262
  def lshift(self, x:int):
3034
3263
  """
@@ -3072,8 +3301,9 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
3072
3301
  x = self._to_const_val(x)
3073
3302
  if not isinstance(x, Tensor) and not reverse:
3074
3303
  # simple pow identities
3075
- if x < 0: return self.reciprocal().pow(-x)
3304
+ if x < 0: return self.reciprocal().pow(-x).cast(self.dtype)
3076
3305
  if x == 0: return 1 + self * 0
3306
+ # rewrite pow 0.5 to sqrt
3077
3307
  if int(x - 0.5) + 0.5 == x: return self.pow(int(x - 0.5)) * self.sqrt()
3078
3308
  if int(x) == x: return self.pow(x // 2).square() * (1 if x % 2 == 0 else self)
3079
3309
 
@@ -3081,16 +3311,15 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
3081
3311
  if not isinstance(x, Tensor) and reverse and x > 0: return self.mul(math.log(x)).exp()
3082
3312
 
3083
3313
  base, exponent = self._broadcasted(x, reverse=reverse)
3314
+ # TODO: int pow
3315
+ if not base.is_floating_point(): raise RuntimeError("base needs to be float")
3084
3316
  # start with b ** e = exp(e * log(b))
3085
3317
  ret = base.abs().log().mul(exponent).exp()
3086
- # correct sign of negative base with odd exponent (cos has a period of 2pi so we use it here to get the oddness of the exponent)
3087
- negative_base = (base < 0).detach().where(1, 0)
3088
- # 1 for non-negative base or negative even exponent, -1 for negative odd exponent, don't care about non-integer exponent
3089
- correct_sign = 1 + negative_base * ((exponent * math.pi).cos() - 1)
3090
- # inject nan for negative base and non-integer exponent
3091
- inject_nan = (negative_base * (exponent != exponent.trunc())).detach().where(math.nan, 1)
3092
- # apply correct_sign inject_nan, and fix 0 ** 0 = 1
3093
- return ((base == 0) * (exponent == 0)).detach().where(1, ret * correct_sign * inject_nan)
3318
+ # negative base adjustment: nan for non-integer exponent and -1 for odd exponent
3319
+ adj = (base < 0).detach().where((exponent != exponent.int()).detach().where(math.nan, (exponent.int()%2==1).where(-1, 1)), 1)
3320
+ # fix 0 ** 0 = 1
3321
+ ret = ((base == 0) * (exponent == 0)).detach().where(1, ret * adj)
3322
+ return ret.round().cast(self.dtype) if not dtypes.is_float(self.dtype) else ret
3094
3323
 
3095
3324
  def maximum(self, x:Union[Tensor, ConstType]) -> Tensor:
3096
3325
  """
@@ -3103,7 +3332,9 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
3103
3332
  print(Tensor([-1, 2, 3]).maximum(Tensor([-4, -2, 9])).numpy())
3104
3333
  ```
3105
3334
  """
3106
- return (self<x).detach().where(x, (self==x).detach().where(((self * 0.5 + x * 0.5).cast(self.dtype)), self))
3335
+ # NOTE: the mid-point is for backward, revisit after new gradient API
3336
+ if self.is_floating_point(): return (self<x).detach().where(x, (self==x).detach().where(((self * 0.5 + x * 0.5).cast(self.dtype)), self))
3337
+ return (self<x).detach().where(x, self)
3107
3338
 
3108
3339
  def minimum(self, x:Union[Tensor, ConstType]) -> Tensor:
3109
3340
  """
@@ -3116,9 +3347,10 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
3116
3347
  print(Tensor([-1, 2, 3]).minimum(Tensor([-4, -2, 9])).numpy())
3117
3348
  ```
3118
3349
  """
3119
- return -((-self).maximum(-x))
3350
+ t, x = self._broadcasted(x)
3351
+ return t._inverse().maximum(x._inverse())._inverse()
3120
3352
 
3121
- def where(self:Tensor, x:Union[Tensor, ConstType], y:Union[Tensor, ConstType]):
3353
+ def where(self:Tensor, x:Union[Tensor, ConstType, sint], y:Union[Tensor, ConstType, sint]):
3122
3354
  """
3123
3355
  Return a tensor of elements selected from either `x` or `y`, depending on `self`.
3124
3356
  `output_i = x_i if self_i else y_i`.
@@ -3140,7 +3372,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
3140
3372
  elif isinstance(y, Tensor): y, x = y._broadcasted(x)
3141
3373
  cond, x = self._broadcasted(x, match_dtype=False)
3142
3374
  cond, y = cond._broadcasted(y, match_dtype=False)
3143
- return F.Where.apply(cond.cast(dtypes.bool), *x._broadcasted(y))
3375
+ return cond.cast(dtypes.bool)._apply_uop(UOp.where, *x._broadcasted(y))
3144
3376
 
3145
3377
  def masked_fill(self:Tensor, mask:Tensor, value:Union[Tensor, ConstType]): return mask.where(value, self)
3146
3378
 
@@ -3170,9 +3402,9 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
3170
3402
  def __ilshift__(self, x) -> Tensor: return self.assign(self.lshift(x))
3171
3403
  def __irshift__(self, x) -> Tensor: return self.assign(self.rshift(x))
3172
3404
 
3173
- def lt(self, x) -> Tensor: return F.Less.apply(*self._broadcasted(x, False))
3174
- def gt(self, x) -> Tensor: return F.Less.apply(*self._broadcasted(x, True))
3175
- def ne(self, x) -> Tensor: return F.Neq.apply(*self._broadcasted(x))
3405
+ def __lt__(self, x) -> Tensor: return self._apply_broadcasted_uop(UOp.__lt__, x, False)
3406
+ def __gt__(self, x) -> Tensor: return self._apply_broadcasted_uop(UOp.__lt__, x, True)
3407
+ def ne(self, x) -> Tensor: return self._apply_broadcasted_uop(UOp.ne, x, False)
3176
3408
 
3177
3409
  def __eq__(self, x) -> Tensor: return self.eq(x) # type: ignore[override]
3178
3410
 
@@ -3194,7 +3426,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
3194
3426
  x = self.mul(weight) if len(weight.shape) == 1 else self.dot(weight)
3195
3427
  return x.add(bias) if bias is not None else x
3196
3428
 
3197
- def sequential(self, ll:List[Callable[[Tensor], Tensor]]):
3429
+ def sequential(self, ll:list[Callable[[Tensor], Tensor]]):
3198
3430
  """
3199
3431
  Applies a sequence of functions to `self` chaining the output of each function to the input of the next.
3200
3432
 
@@ -3205,7 +3437,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
3205
3437
  """
3206
3438
  return functools.reduce(lambda x,f: f(x), ll, self)
3207
3439
 
3208
- def layernorm(self, axis:Union[int,Tuple[int,...]]=-1, eps:float=1e-5) -> Tensor:
3440
+ def layernorm(self, axis:Union[int,tuple[int,...]]=-1, eps:float=1e-5) -> Tensor:
3209
3441
  """
3210
3442
  Applies Layer Normalization over a mini-batch of inputs.
3211
3443
 
@@ -3224,7 +3456,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
3224
3456
  y = (self - self.mean(axis, keepdim=True))
3225
3457
  return y.mul((y*y).mean(axis, keepdim=True).add(eps).rsqrt())
3226
3458
 
3227
- def batchnorm(self, weight:Optional[Tensor], bias:Optional[Tensor], mean:Tensor, invstd:Tensor, axis:Union[int,Tuple[int,...]]=1) -> Tensor:
3459
+ def batchnorm(self, weight:Optional[Tensor], bias:Optional[Tensor], mean:Tensor, invstd:Tensor, axis:Union[int,tuple[int,...]]=1) -> Tensor:
3228
3460
  """
3229
3461
  Applies Batch Normalization over a mini-batch of inputs.
3230
3462
 
@@ -3266,6 +3498,11 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
3266
3498
  if not Tensor.training or p == 0: return self
3267
3499
  return (Tensor.rand_like(self, requires_grad=False, dtype=dtypes.default_float, contiguous=False) >= p).contiguous().where(self, 0) / (1.0 - p)
3268
3500
 
3501
+ # helper function commonly used for indexing
3502
+ def _one_hot_along_dim(self:Tensor, num_classes:sint, dim:int=-1):
3503
+ offset = self.ndim - self._resolve_dim(dim) - 1
3504
+ return self == Tensor.arange(num_classes, device=self.device, requires_grad=False).reshape((num_classes,) + (1,) * offset)
3505
+
3269
3506
  def one_hot(self, num_classes:int=-1) -> Tensor:
3270
3507
  """
3271
3508
  Converts `self` to a one-hot tensor.
@@ -3278,10 +3515,9 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
3278
3515
  ```
3279
3516
  """
3280
3517
  if num_classes == -1: num_classes = (self.max()+1).item()
3281
- return (self[..., None] == Tensor.arange(num_classes, requires_grad=False, device=self.device)).where(1, 0)
3518
+ return self[..., None]._one_hot_along_dim(num_classes).where(1, 0)
3282
3519
 
3283
- def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Optional[Tensor]=None,
3284
- dropout_p:float=0.0, is_causal:bool=False) -> Tensor:
3520
+ def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Tensor|None=None, dropout_p:float=0.0, is_causal:bool=False) -> Tensor:
3285
3521
  """
3286
3522
  Computes scaled dot-product attention.
3287
3523
  `self` is the query tensor, `key` is the key tensor, and `value` is the value tensor.
@@ -3298,14 +3534,19 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
3298
3534
  """
3299
3535
  # NOTE: it also works when `key` and `value` have symbolic shape.
3300
3536
  assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
3301
- if is_causal: attn_mask = Tensor.ones(self.shape[-2], key.shape[-2], requires_grad=False, device=self.device).tril(0).cast(dtypes.bool)
3302
- if attn_mask is not None and attn_mask.dtype == dtypes.bool: attn_mask = (attn_mask == 0).where(-float("inf"), 0)
3303
3537
  qk = self.matmul(key.transpose(-2,-1), acc_dtype=least_upper_dtype(self.dtype, key.dtype, dtypes.float32)) / math.sqrt(self.shape[-1])
3304
- return ((qk+attn_mask) if attn_mask is not None else qk).softmax(-1).cast(self.dtype).dropout(dropout_p) @ value
3538
+ # handle attention mask
3539
+ if is_causal:
3540
+ if attn_mask is not None: raise RuntimeError("cannot set attn_mask when is_causal=True")
3541
+ attn_mask = qk.ones_like(requires_grad=False, device=self.device, dtype=dtypes.bool).tril()
3542
+ if attn_mask is not None:
3543
+ if attn_mask.dtype == dtypes.bool: attn_mask = attn_mask.where(0, -float("inf"))
3544
+ qk = qk + attn_mask
3545
+ return qk.softmax(-1).cast(self.dtype).dropout(dropout_p) @ value
3305
3546
 
3306
3547
  def _do_reduction(self, reduction:ReductionStr="mean") -> Tensor:
3307
3548
  if reduction not in get_args(ReductionStr): raise ValueError(f"{reduction=} must be one of {get_args(ReductionStr)}")
3308
- reductions: Dict[str, Callable[[Tensor], Tensor]] = {"mean": Tensor.mean, "sum": Tensor.sum, "none": lambda x: x}
3549
+ reductions: dict[str, Callable[[Tensor], Tensor]] = {"mean": Tensor.mean, "sum": Tensor.sum, "none": lambda x: x}
3309
3550
  return reductions[reduction](self)
3310
3551
 
3311
3552
  def binary_crossentropy(self, Y:Tensor, reduction:ReductionStr="mean") -> Tensor:
@@ -3354,8 +3595,8 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
3354
3595
  assert 0.0 <= label_smoothing <= 1.0, "label_smoothing must be in [0.0, 1.0]"
3355
3596
  assert reduction in ("mean", "sum", "none"), "reduction must be one of ['mean', 'sum', 'none']"
3356
3597
  log_probs, loss_mask = self.log_softmax(), (Y != ignore_index) if ignore_index != -1 else Y.ones_like(dtype=dtypes.bool)
3357
- y_counter = Tensor.arange(self.shape[-1], requires_grad=False, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1])
3358
- y = ((y_counter == Y.flatten().reshape(-1, 1)) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
3598
+ y_counted = Y.to(self.device).flatten().reshape(-1, 1)._one_hot_along_dim(self.shape[-1])
3599
+ y = (y_counted * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
3359
3600
  smoothing = label_smoothing * (log_probs.mean(-1) * loss_mask)
3360
3601
  unreduced = ((1 - label_smoothing) * (log_probs * y).sum(-1) + smoothing)
3361
3602
  # NOTE: because of ignore_index, we can't use Tensor.mean (so can't use `_do_reduction` here)
@@ -3469,7 +3710,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
3469
3710
  """
3470
3711
  return dtypes.is_float(self.dtype)
3471
3712
 
3472
- def size(self, dim:Optional[int]=None) -> Union[sint, Tuple[sint, ...]]:
3713
+ def size(self, dim:Optional[int]=None) -> Union[sint, tuple[sint, ...]]:
3473
3714
  """
3474
3715
  Return the size of the tensor. If `dim` is specified, return the length along dimension `dim`. Otherwise return the shape of the tensor.
3475
3716
 
@@ -3488,7 +3729,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
3488
3729
  def llvm_bf16_cast(self, dtype:DTypeLike):
3489
3730
  # hack for devices that don't support bfloat16
3490
3731
  assert self.dtype == dtypes.bfloat16
3491
- return self.to("LLVM").bitcast(dtypes.uint16).cast(dtypes.uint32).mul(1<<16).bitcast(dtypes.float32).cast(dtype)
3732
+ return self.to("LLVM").cast(dtype)
3492
3733
 
3493
3734
  def cast(self, dtype:DTypeLike) -> Tensor:
3494
3735
  """
@@ -3502,8 +3743,15 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
3502
3743
  t = t.cast(dtypes.int32)
3503
3744
  print(t.dtype, t.numpy())
3504
3745
  ```
3746
+ ```python exec="true" source="above" session="tensor" result="python"
3747
+ t = t.cast(dtypes.uint8)
3748
+ print(t.dtype, t.numpy())
3749
+ ```
3505
3750
  """
3506
- return self if self.dtype == (dt:=to_dtype(dtype)) else F.Cast.apply(self, dtype=dt)
3751
+ if (dt:=to_dtype(dtype)) in {dtypes.uint8, dtypes.uint16} and dtypes.is_float(self.dtype):
3752
+ # NOTE: values within the int32 range and outside the unsigned dtype range will cause values to wrap around
3753
+ return self._apply_uop(UOp.cast, dtype=dtypes.int32)._apply_uop(UOp.cast, dtype=dt)
3754
+ return self if self.dtype == dt else self._apply_uop(UOp.cast, dtype=dt)
3507
3755
 
3508
3756
  def bitcast(self, dtype:DTypeLike) -> Tensor:
3509
3757
  """
@@ -3522,13 +3770,13 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
3522
3770
  """
3523
3771
  if self.requires_grad: raise RuntimeError("can't backprop through bitcast")
3524
3772
  dt = to_dtype(dtype)
3525
- if (not isinstance(self.device, str) or not self.device.startswith("DISK")) and (ns:=dt.itemsize) != (os:=self.dtype.itemsize):
3526
- if (self.shape[-1]*os) % ns != 0: raise RuntimeError("unsupported size in bitcast")
3773
+ if (ns:=dt.itemsize) != (os:=self.dtype.itemsize) and (self.shape[-1]*os) % ns != 0: raise RuntimeError("unsupported size in bitcast")
3774
+ if (not isinstance(self.device, str) or not self.device.startswith("DISK")) and ns != os:
3527
3775
  new_uint, old_uint = to_dtype(f"uint{8*ns}"), to_dtype(f"uint{8*os}")
3528
3776
  tmp = self.bitcast(old_uint)
3529
3777
  if ns > os: return functools.reduce(Tensor.add, (tmp[..., i::ns//os].cast(new_uint) << 8*i*os for i in range(ns//os))).bitcast(dtype)
3530
3778
  return Tensor.stack(*(tmp>>8*i*ns for i in range(os//ns)), dim=-1).flatten(-2).cast(new_uint).bitcast(dtype)
3531
- return F.Cast.apply(self, dtype=dt, bitcast=True) if self.dtype != dt else self
3779
+ return self._apply_uop(UOp.bitcast, dtype=dt) if self.dtype != dt else self
3532
3780
 
3533
3781
  def float(self) -> Tensor:
3534
3782
  """
@@ -3650,7 +3898,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
3650
3898
  else: w = w.reshape(cout//4, H, rcin_hi, W, rcin_lo, 4).permute(0,1,2,3,5,4)
3651
3899
 
3652
3900
  # prepare input
3653
- x = x.permute(0,3,4,5,1,2).pad(self._padding2d(padding, 2))._pool((H, W), stride, dilation) # -> (bs, groups, rcin_hi, rcin_lo, oy, ox, H, W)
3901
+ x = x.permute(0,3,4,5,1,2).pad(self._resolve_pool_pads(padding,2))._pool((H,W), stride, dilation)# -> (bs, groups, rcin_hi, rcin_lo, oy, ox, H, W)
3654
3902
  x = x.permute(0,4,5,1,2,3,6,7).reshape(bs, (oy := x.shape[4]), (ox := x.shape[5]), *cout_expand[0:2], 1, 1, rcin_hi, rcin_lo, H, W)
3655
3903
 
3656
3904
  # prepare weights
@@ -3702,5 +3950,5 @@ def _metadata_wrapper(fn):
3702
3950
 
3703
3951
  if TRACEMETA >= 1:
3704
3952
  for name, fn in inspect.getmembers(Tensor, inspect.isfunction):
3705
- if name in ["__class__", "__init__", "__new__", "__repr__", "backward", "sequential"]: continue
3953
+ if name in ["__class__", "__init__", "__new__", "__repr__", "backward", "sequential", "gradient"]: continue
3706
3954
  setattr(Tensor, name, functools.wraps(fn)(_metadata_wrapper(fn)))