tinygrad 0.10.0__py3-none-any.whl → 0.10.2__py3-none-any.whl

This diff shows the changes between publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
Files changed (88)
  1. tinygrad/codegen/devectorizer.py +247 -0
  2. tinygrad/codegen/expander.py +121 -0
  3. tinygrad/codegen/kernel.py +141 -201
  4. tinygrad/codegen/linearize.py +223 -84
  5. tinygrad/codegen/lowerer.py +60 -42
  6. tinygrad/codegen/symbolic.py +476 -0
  7. tinygrad/codegen/transcendental.py +22 -13
  8. tinygrad/device.py +187 -47
  9. tinygrad/dtype.py +39 -28
  10. tinygrad/engine/jit.py +83 -65
  11. tinygrad/engine/memory.py +4 -5
  12. tinygrad/engine/multi.py +161 -0
  13. tinygrad/engine/realize.py +62 -108
  14. tinygrad/engine/schedule.py +396 -357
  15. tinygrad/engine/search.py +55 -66
  16. tinygrad/gradient.py +73 -0
  17. tinygrad/helpers.py +81 -59
  18. tinygrad/nn/__init__.py +30 -32
  19. tinygrad/nn/datasets.py +1 -2
  20. tinygrad/nn/optim.py +22 -26
  21. tinygrad/nn/state.py +91 -66
  22. tinygrad/ops.py +492 -641
  23. tinygrad/renderer/__init__.py +95 -36
  24. tinygrad/renderer/cstyle.py +99 -92
  25. tinygrad/renderer/llvmir.py +83 -34
  26. tinygrad/renderer/ptx.py +83 -99
  27. tinygrad/renderer/wgsl.py +95 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +39507 -12
  29. tinygrad/runtime/autogen/comgr.py +2 -0
  30. tinygrad/runtime/autogen/kfd.py +4 -3
  31. tinygrad/runtime/autogen/kgsl.py +1 -1
  32. tinygrad/runtime/autogen/libc.py +404 -71
  33. tinygrad/runtime/autogen/llvm.py +11379 -0
  34. tinygrad/runtime/autogen/pci.py +1333 -0
  35. tinygrad/runtime/autogen/vfio.py +891 -0
  36. tinygrad/runtime/autogen/webgpu.py +6985 -0
  37. tinygrad/runtime/graph/cuda.py +8 -9
  38. tinygrad/runtime/graph/hcq.py +84 -79
  39. tinygrad/runtime/graph/metal.py +40 -43
  40. tinygrad/runtime/ops_amd.py +498 -334
  41. tinygrad/runtime/ops_cloud.py +34 -34
  42. tinygrad/runtime/ops_cpu.py +24 -0
  43. tinygrad/runtime/ops_cuda.py +30 -27
  44. tinygrad/runtime/ops_disk.py +62 -63
  45. tinygrad/runtime/ops_dsp.py +159 -42
  46. tinygrad/runtime/ops_gpu.py +30 -30
  47. tinygrad/runtime/ops_hip.py +29 -31
  48. tinygrad/runtime/ops_llvm.py +48 -41
  49. tinygrad/runtime/ops_metal.py +149 -113
  50. tinygrad/runtime/ops_npy.py +2 -2
  51. tinygrad/runtime/ops_nv.py +238 -273
  52. tinygrad/runtime/ops_python.py +55 -50
  53. tinygrad/runtime/ops_qcom.py +129 -157
  54. tinygrad/runtime/ops_webgpu.py +225 -0
  55. tinygrad/runtime/support/allocator.py +94 -0
  56. tinygrad/runtime/support/am/__init__.py +0 -0
  57. tinygrad/runtime/support/am/amdev.py +396 -0
  58. tinygrad/runtime/support/am/ip.py +463 -0
  59. tinygrad/runtime/support/compiler_cuda.py +4 -2
  60. tinygrad/runtime/support/elf.py +28 -4
  61. tinygrad/runtime/support/hcq.py +256 -324
  62. tinygrad/runtime/support/llvm.py +26 -0
  63. tinygrad/shape/shapetracker.py +85 -53
  64. tinygrad/shape/view.py +104 -140
  65. tinygrad/spec.py +155 -0
  66. tinygrad/tensor.py +835 -527
  67. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
  68. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
  69. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
  70. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
  71. tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
  72. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
  73. tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
  74. tinygrad/viz/index.html +544 -0
  75. tinygrad/viz/perfetto.html +178 -0
  76. tinygrad/viz/serve.py +205 -0
  77. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/METADATA +48 -25
  78. tinygrad-0.10.2.dist-info/RECORD +99 -0
  79. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +1 -1
  80. tinygrad/codegen/uopgraph.py +0 -506
  81. tinygrad/engine/lazy.py +0 -228
  82. tinygrad/function.py +0 -212
  83. tinygrad/multi.py +0 -177
  84. tinygrad/runtime/graph/clang.py +0 -39
  85. tinygrad/runtime/ops_clang.py +0 -35
  86. tinygrad-0.10.0.dist-info/RECORD +0 -77
  87. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
  88. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
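The headline change in the `tinygrad/tensor.py` diff below is the removal of the `Function`-based autograd machinery (`tinygrad/function.py` and `tinygrad/engine/lazy.py` are deleted) in favor of UOp-level gradients from the new `tinygrad/gradient.py`. A minimal usage sketch of the resulting `Tensor.gradient` API, adapted from the docstring example that appears later in this diff (illustrative only, not part of the diff itself):

```python
from tinygrad import Tensor

# gradients are computed directly from the UOp graph; no requires_grad or
# backward() bookkeeping is needed for this form
x = Tensor.eye(3)
y = Tensor([[2.0, 0, -2.0]])
z = y.matmul(x).sum()

dx, dy = z.gradient(x, y)   # dz/dx, dz/dy
print(dx.tolist())
print(dy.tolist())
```

`Tensor.backward` still exists and, as the diff shows, now routes through the same `gradient` call for every live tensor with `requires_grad=True`.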
tinygrad/tensor.py CHANGED
@@ -1,47 +1,53 @@
  # inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
  from __future__ import annotations
- import time, math, itertools, functools, struct, sys, inspect, pathlib, string, dataclasses, hashlib
+ import time, math, itertools, functools, struct, sys, inspect, pathlib, string, hashlib, weakref
  from contextlib import ContextDecorator
- from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Dict, DefaultDict, cast, get_args, Literal
- from collections import defaultdict
-
+ from typing import Callable, Optional, ClassVar, Union, Sequence, cast, get_args, Literal, TYPE_CHECKING, SupportsIndex
  from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
  from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
- from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN
- from tinygrad.multi import MultiLazyBuffer
- from tinygrad.ops import smax, smin, resolve, UOp, Ops, sint, Variable, SimpleMathTrait
- from tinygrad.device import Device, Buffer, BufferOptions
- from tinygrad.engine.lazy import LazyBuffer
+ from tinygrad.helpers import IMAGE, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap
+ from tinygrad.engine.multi import get_multi_map
+ from tinygrad.gradient import compute_gradient
+ from tinygrad.ops import smax, smin, resolve, UOp, Ops, sint, Variable, SimpleMathTrait, identity_element
+ from tinygrad.spec import tensor_uop_spec, type_verify
+ from tinygrad.device import Device, BufferSpec
  from tinygrad.engine.realize import run_schedule
  from tinygrad.engine.memory import memory_planner
  from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars

- # **** start with two base classes, Tensor and Function ****
-
- class Function:
- def __init__(self, device:Union[str, Tuple[str, ...]], *tensors:Tensor, metadata:Optional[Metadata]=None):
- self.device = device
- self.needs_input_grad = [t.requires_grad for t in tensors]
- self.requires_grad = True if any(self.needs_input_grad) else None if None in self.needs_input_grad else False
- if self.requires_grad: self.parents = tensors
- self.metadata = metadata
-
- def forward(self, *args, **kwargs): raise NotImplementedError(f"forward not implemented for {type(self)}")
- def backward(self, *args, **kwargs): raise RuntimeError(f"backward not implemented for {type(self)}")
-
- @classmethod
- def apply(fxn:Type[Function], *x:Tensor, **kwargs) -> Tensor:
- ctx = fxn(x[0].device, *x, metadata=_METADATA.get())
- ret = Tensor.__new__(Tensor)
- ret.lazydata, ret.requires_grad, ret.grad = ctx.forward(*[t.lazydata for t in x], **kwargs), ctx.requires_grad, None
- ret._ctx = ctx if ctx.requires_grad and not Tensor.no_grad else None # used by autograd engine
- return ret
+ # *** all in scope Tensors are here. this gets relevant UOps ***
+
+ all_tensors: set[weakref.ref[Tensor]] = set()
+
+ def _apply_map_to_tensors(applied_map:dict[UOp, UOp]) -> None:
+ # get all children of keys in applied_map
+ all_uops: set[UOp] = set()
+ search_uops = list(applied_map)
+ while len(search_uops):
+ x = search_uops.pop(0)
+ if x in all_uops: continue
+ all_uops.add(x)
+ search_uops.extend([u for c in x.children if (u:=c()) is not None])
+
+ # link the found UOps back to Tensors. exit early if there's no Tensors to realize
+ # NOTE: this uses all_tensors, but it's fast
+ fixed_tensors: list[Tensor] = [t for tref in all_tensors if (t:=tref()) is not None and t.lazydata in all_uops]

- import tinygrad.function as F
+ if len(fixed_tensors):
+ # potentially rewrite all the discovered Tensors
+ sink = UOp.sink(*[t.lazydata for t in fixed_tensors])
+ new_sink = sink.substitute(applied_map)

- def _metaop(op, shape:Tuple[sint,...], dtype:DType, device:Union[str, Tuple[str, ...]], arg=None, src:Tuple[LazyBuffer, ...]=()):
- if isinstance(device, str): return LazyBuffer.metaop(op, shape, dtype, device, arg, src)
- return MultiLazyBuffer([LazyBuffer.metaop(op, shape, dtype, d, arg, src) for d in device], None)
+ # set the relevant lazydata to the realized UOps
+ for t,s,ns in zip(fixed_tensors, sink.src, new_sink.src):
+ if s is ns: continue
+ t.lazydata = ns
+
+ # **** Tensor helper functions ****
+
+ def _metaop(op, shape:tuple[sint,...], dtype:DType, device:Union[str, tuple[str, ...]], arg=None):
+ if isinstance(device, str): return UOp.metaop(op, shape, dtype, device, arg)
+ return UOp.multi(*[UOp.metaop(op, shape, dtype, d, arg) for d in device], axis=None)

  def _from_np_dtype(npdtype:'np.dtype') -> DType: # type: ignore [name-defined] # noqa: F821
  import numpy as np
@@ -50,33 +56,31 @@ def _to_np_dtype(dtype:DType) -> Optional[type]:
  import numpy as np
  return np.dtype(dtype.fmt).type if dtype.fmt is not None else None

- def _fromnp(x: 'np.ndarray') -> LazyBuffer: # type: ignore [name-defined] # noqa: F821
- ret = LazyBuffer.metaop(Ops.EMPTY, x.shape, _from_np_dtype(x.dtype), "NPY")
+ def _fromnp(x: 'np.ndarray') -> UOp: # type: ignore [name-defined] # noqa: F821
+ ret = UOp.metaop(Ops.EMPTY, x.shape, _from_np_dtype(x.dtype), "NPY")
  # fake realize
  ret.buffer.allocate(x)
- del ret.srcs
  return ret

- def get_shape(x) -> Tuple[int, ...]:
+ def get_shape(x) -> tuple[int, ...]:
  # NOTE: str is special because __getitem__ on a str is still a str
  if not hasattr(x, "__len__") or not hasattr(x, "__getitem__") or isinstance(x, str) or (hasattr(x, "shape") and x.shape == ()): return ()
  if not all_same(subs:=[get_shape(xi) for xi in x]): raise ValueError(f"inhomogeneous shape from {x}")
  return (len(subs),) + (subs[0] if subs else ())

- def _frompy(x:Union[List, Tuple, bytes], dtype:DType) -> LazyBuffer:
- if isinstance(x, bytes): ret, data = LazyBuffer.metaop(Ops.EMPTY, (len(x)//dtype.itemsize,), dtype, "PYTHON"), x
+ def _frompy(x:Union[list, tuple, bytes], dtype:DType) -> UOp:
+ if isinstance(x, bytes): ret, data = UOp.metaop(Ops.EMPTY, (len(x)//dtype.itemsize,), dtype, "PYTHON"), x
  else:
- ret = LazyBuffer.metaop(Ops.EMPTY, get_shape(x), dtype, "PYTHON")
+ ret = UOp.metaop(Ops.EMPTY, get_shape(x), dtype, "PYTHON")
  assert dtype.fmt is not None, f"{dtype=} has None fmt"
  truncate_function = truncate[dtype]
  data = struct.pack(f"@{ret.size}{dtype.fmt}", *[truncate_function(xi) for xi in fully_flatten(x)])
  # fake realize
  ret.buffer.allocate(memoryview(data if Device.DEFAULT != "PYTHON" else bytearray(data)))
- del ret.srcs
  return ret

- def _get_winograd_matcols(mat, dims:int, shp:Tuple[sint, ...], device:Union[str, Tuple[str, ...]]) -> List[List[Tensor]]:
- return [[Tensor.cat(*[Tensor.full(shp[:dim] + (1,) + shp[dim+1:], float(m[k]), device=device) for m in mat], dim=dim)
+ def _get_winograd_matcols(mat, dims:int, shp:tuple[sint, ...], device:Union[str, tuple[str, ...]], dtype:DType) -> list[list[Tensor]]:
+ return [[Tensor.cat(*[Tensor.full(shp[:dim] + (1,) + shp[dim+1:], float(m[k]), device=device, dtype=dtype) for m in mat], dim=dim)
  for k in range(len(mat[0]))] for dim in range(dims)]

  # winograd conv 3 kernel f(4x4,3x3) see: http://arxiv.org/abs/1509.09308
@@ -85,21 +89,34 @@ def _apply_winograd_matrix(mat, t:Tensor, dims:int) -> Tensor:
  # due to realize-before-expand rule in lazy.py, we must operate in this order: reshape -> expand -> arithmetic
  t_ = t.reshape(t.shape[:dims] + (1,) * dims + t.shape[dims:]).expand(t.shape[:dims] + (len(mat),) * dims + t.shape[dims:]) # add output dims
  # precalculate mat columns for each dim; prod(itertools.product(matcols)) gives the columns of kron(mat, mat, ...)
- matcols = _get_winograd_matcols(mat, dims, t_.shape[dims:], t_.device)
+ matcols = _get_winograd_matcols(mat, dims, t_.shape[dims:], t_.device, t_.dtype)
  # multiply each element of t_ by the corresponding stacked column of kron(mat, mat), producing only one view for each element of t
  ret = sum(prod(col[idx] for col, idx in zip(matcols, mat_is)) * t_[mat_is] for mat_is in itertools.product(range(len(mat[0])), repeat=dims))
  assert isinstance(ret, Tensor), "sum didn't return a Tensor"
  return ret

- def _pad_left(*shapes:Tuple[sint, ...]) -> Tuple[Tuple[sint, ...], ...]:
+ def _align_left(*shapes:tuple[sint, ...]) -> tuple[tuple[sint, ...], ...]:
+ # unsqueeze left to make every shape same length
  max_dim = max(len(shape) for shape in shapes)
  return tuple((1,) * (max_dim - len(shape)) + shape for shape in shapes)
- def _broadcast_shape(*shapes:Tuple[sint, ...]) -> Tuple[sint, ...]:
- return tuple(0 if 0 in nth_dim_sizes else smax(nth_dim_sizes) for nth_dim_sizes in zip(*_pad_left(*shapes)))
+ def _broadcast_shape(*shapes:tuple[sint, ...]) -> tuple[sint, ...]:
+ return tuple(0 if 0 in nth_dim_sizes else smax(nth_dim_sizes) for nth_dim_sizes in zip(*_align_left(*shapes)))
+
+ def _masked_setitem(target:Tensor, values:Tensor, mask:Tensor, axes:tuple[int, ...]):
+ # apply mask to values (already broadcasted) and reduce such that if mask contains repeated indices the last one remains
+ values = values * mask
+ for dim in axes: mask, values = functools.reduce(lambda x,y: (x[0]|y[0], y[0].where(y[1], x[1])), zip(mask.split(1, dim), values.split(1, dim)))
+ # remove extra dims from reduce
+ for dim in reversed(axes): mask, values = mask.squeeze(dim), values.squeeze(dim)
+ # select from values for each True element in mask else select from self
+ return mask.where(values, target)
+
+ # `(padding_left, padding_right, padding_top, padding_bottom, ...)` -> `(..., (padding_top, padding_bottom), (padding_left, padding_right))`
+ def _flat_to_grouped(padding:Sequence[sint]) -> tuple[tuple[sint, sint], ...]: return tuple(zip(padding[-2::-2], padding[::-2]))

  ReductionStr = Literal["mean", "sum", "none"]

- class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
+ class Tensor(SimpleMathTrait):
  """
  A `Tensor` is a multi-dimensional matrix containing elements of a single data type.

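For reference, the new `_flat_to_grouped` helper added above converts PyTorch-style flat padding into per-axis pairs, innermost axis last-to-first. A standalone re-derivation of the same one-liner with concrete integer values (illustrative; the real helper also accepts symbolic `sint`s):

```python
from typing import Sequence

def flat_to_grouped(padding: Sequence[int]) -> tuple[tuple[int, int], ...]:
    # (left, right, top, bottom, ...) -> (..., (top, bottom), (left, right))
    return tuple(zip(padding[-2::-2], padding[::-2]))

print(flat_to_grouped((1, 2, 3, 4)))  # ((3, 4), (1, 2))
print(flat_to_grouped((1, 2)))        # ((1, 2),)
```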
@@ -110,15 +127,13 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  np.set_printoptions(precision=4)
  ```
  """
- __slots__ = "lazydata", "requires_grad", "grad", "_ctx"
- __deletable__ = ('_ctx',)
+ __slots__ = "lazydata", "requires_grad", "grad"
  training: ClassVar[bool] = False
  no_grad: ClassVar[bool] = False

- def __init__(self, data:Union[None, ConstType, UOp, bytes, List, Tuple, LazyBuffer, MultiLazyBuffer, 'np.ndarray', pathlib.Path], # type: ignore [name-defined] # noqa: F821
+ def __init__(self, data:Union[None, ConstType, bytes, list, tuple, UOp, 'np.ndarray', pathlib.Path], # type: ignore [name-defined] # noqa: F821
  device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DTypeLike]=None, requires_grad:Optional[bool]=None):
  if dtype is not None: dtype = to_dtype(dtype)
- assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
  if device is None and isinstance(data, pathlib.Path): device = f"DISK:{data.resolve()}" # keep it on the disk if device is None
  device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device)

@@ -129,21 +144,18 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  # None (the default) will be updated to True if it's put in an optimizer
  self.requires_grad: Optional[bool] = requires_grad

- # internal variable used for autograd graph construction
- self._ctx: Optional[Function] = None
-
  # create a LazyBuffer from the different types of inputs
- if isinstance(data, (LazyBuffer, MultiLazyBuffer)): assert dtype is None or dtype==data.dtype, "dtype doesn't match, and casting isn't supported"
+ if isinstance(data, UOp):
+ assert dtype is None or dtype==data.dtype, "dtype doesn't match, and casting isn't supported"
+ # NOTE: this is here because LazyBuffer = UOp
+ if isinstance(data, UOp) and data.op is Ops.BIND: data = _metaop(Ops.BIND, tuple(), dtype or data.dtype, device, data)
  elif data is None: data = _metaop(Ops.EMPTY, (0,), dtype or dtypes.default_float, device)
  elif isinstance(data, get_args(ConstType)): data = _metaop(Ops.CONST, tuple(), dtype or dtypes.from_py(data), device, data)
- elif isinstance(data, UOp):
- assert data.op is Ops.BIND and data.src[0].op is Ops.DEFINE_VAR and data.src[1].op is Ops.CONST, f"can't create tensor from UOp {data}"
- data = _metaop(Ops.CONST, tuple(), dtype or data.dtype, device, data)
  elif isinstance(data, bytes): data = _frompy(data, dtypes.uint8 if dtype is None else dtype)
  elif isinstance(data, (list, tuple)):
  if dtype is None:
  if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtypes.bool
- else: dtype = dtypes.default_int if d and all_int(d) else dtypes.default_float
+ else: dtype = dtypes.default_int if d and all_int(d) else dtypes.default_float # NOTE: this works because all_int([True, False]) is True
  if dtype == dtypes.bfloat16: data = Tensor(_frompy(data, dtypes.float32), device=device).cast(dtypes.bfloat16).lazydata
  else: data = _frompy(data, dtype)
  elif str(type(data)) == "<class 'numpy.ndarray'>":
@@ -155,17 +167,34 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  dtype = dtype or dtypes.uint8
  data = _metaop(Ops.EMPTY, (data.stat().st_size // dtype.itemsize,), dtype, f"DISK:{data.resolve()}")

- # by this point, it has to be a LazyBuffer
- if not isinstance(data, (LazyBuffer, MultiLazyBuffer)): raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}")
+ # by this point, it has to be a UOp
+ if not isinstance(data, UOp): raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}")

  # data might be on a different device
- if isinstance(device, str): self.lazydata:Union[LazyBuffer, MultiLazyBuffer] = data if data.device == device else data.copy_to_device(device)
+ if isinstance(device, str): self.lazydata:UOp = data if data.device == device else data.copy_to_device(device)
  # if device is a tuple, we should have/construct a MultiLazyBuffer
- elif isinstance(data, LazyBuffer): self.lazydata = MultiLazyBuffer.from_sharded(data, device, None, None)
+ elif isinstance(data, UOp) and isinstance(data.device, str): self.lazydata = Tensor(data).shard(device).lazydata
  else:
  assert data.device == device, f"MultiLazyBuffer device mismatch, {data.device} != {device}"
  self.lazydata = data

+ # add to all_tensors after construction succeeds
+ all_tensors.add(weakref.ref(self))
+ def __del__(self): all_tensors.discard(weakref.ref(self))
+
+ def _apply_uop(self, fxn:Callable, *x:Tensor, **kwargs) -> Tensor:
+ new_uop: UOp = fxn(*[t.lazydata for t in (self,)+x], **kwargs)
+ needs_input_grad = [t.requires_grad for t in (self,)+x]
+ return Tensor(new_uop, device=new_uop.device, requires_grad=True if any(needs_input_grad) else None if None in needs_input_grad else False)
+
+ def _apply_broadcasted_uop(self, fxn:Callable, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
+ lhs,rhs = self._broadcasted(x, reverse)
+ return lhs._apply_uop(fxn, rhs)
+
+ def requires_grad_(self, requires_grad=True) -> Tensor:
+ self.requires_grad = requires_grad
+ return self
+
  class train(ContextDecorator):
  def __init__(self, mode:bool = True): self.mode = mode
  def __enter__(self): self.prev, Tensor.training = Tensor.training, self.mode
@@ -177,7 +206,9 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  def __exit__(self, exc_type, exc_value, traceback): Tensor.no_grad = self.prev

  def __repr__(self):
- return f"<Tensor {self.lazydata!r} on {self.device} with grad {(self.grad.lazydata if self.grad is not None else None)!r}>"
+ ld = self.lazydata
+ ld_repr = f"<UOp {ld.device} {ld.shape} {str(ld.dtype)[7:]} {ld.st if ld.base is not ld else (ld.op, ld.realized)}>"
+ return f"<Tensor {ld_repr} on {self.device} with grad {(self.grad.lazydata if self.grad is not None else None)!r}>"

  # Python has a non moving GC, so this should be okay
  def __hash__(self): return id(self)
@@ -189,26 +220,38 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  return self.shape[0]

  @property
- def device(self) -> Union[str, Tuple[str, ...]]: return self.lazydata.device
+ def device(self) -> Union[str, tuple[str, ...]]: return self.lazydata.device

  @property
- def shape(self) -> Tuple[sint, ...]: return self.lazydata.shape
+ def shape(self) -> tuple[sint, ...]: return self.lazydata.shape

  @property
  def dtype(self) -> DType: return self.lazydata.dtype

  # ***** data handlers ****

- def schedule_with_vars(self, *lst:Tensor) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
+ def schedule_with_vars(self, *lst:Tensor) -> tuple[list[ScheduleItem], dict[Variable, int]]:
  """
  Creates the schedule needed to realize these Tensor(s), with Variables.

  NOTE: A Tensor can only be scheduled once.
  """
- schedule, var_vals = create_schedule_with_vars(flatten([x.lazydata.lbs for x in (self,)+lst]))
+ big_sink = UOp.sink(*[x.lazydata for x in (self,)+lst])
+
+ # TODO: move this to scheduler tensor_map pass
+ if any(x.op is Ops.MULTI for x in big_sink.toposort):
+ # multi fixup
+ _apply_map_to_tensors(get_multi_map(big_sink))
+ big_sink = UOp.sink(*flatten([x.lazydata.src if x.lazydata.op is Ops.MULTI else [x.lazydata] for x in (self,)+lst]))
+
+ # verify Tensors match the spec
+ if __debug__: type_verify(list(big_sink.toposort), tensor_uop_spec)
+
+ schedule, var_vals, becomes_map = create_schedule_with_vars(big_sink)
+ _apply_map_to_tensors(becomes_map)
  return memory_planner(schedule), var_vals

- def schedule(self, *lst:Tensor) -> List[ScheduleItem]:
+ def schedule(self, *lst:Tensor) -> list[ScheduleItem]:
  """Creates the schedule needed to realize these Tensor(s)."""
  schedule, var_vals = self.schedule_with_vars(*lst)
  assert len(var_vals) == 0
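`schedule_with_vars` now builds one `UOp.sink` over the requested tensors, fixes up `Ops.MULTI` nodes, verifies the graph against `tensor_uop_spec`, and applies the scheduler's `becomes_map` back onto live tensors. A small hedged sketch of the public entry points that sit on top of it (assumed typical usage; the docstring above notes a Tensor can only be scheduled once, so the two expressions below are kept independent):

```python
from tinygrad import Tensor

# schedule() returns the ScheduleItems that would realize the tensor, without running them
sched = (Tensor.ones(4) * 2).schedule()
print(len(sched))

# realize() builds the schedule and runs it in one step
print(((Tensor([1.0, 2.0, 3.0]) + 1).relu()).realize().tolist())
```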
@@ -224,7 +267,6 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  Replaces the data of this tensor with the data of another tensor. Only the shape of the tensors must match.
  """
  # used for replacing a Tensor with a new version of it (potentially with a different device and dtype)
- assert not x.requires_grad and getattr(self, '_ctx', None) is None
  assert self.shape == x.shape, f"replace shape mismatch {self.shape} != {x.shape}"
  self.lazydata = x.lazydata
  return self
@@ -232,17 +274,15 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  def assign(self, x) -> Tensor:
  # TODO: this is a hack for writing to DISK. remove with working assign
  if isinstance(self.device, str) and self.device.startswith("DISK"):
- if x.__class__ is not Tensor: x = Tensor(x, device="NPY", dtype=self.dtype)
- self.contiguous().realize().lazydata.base.realized.copyin(x.numpy().data)
+ if x.__class__ is not Tensor: x = Tensor(x, device="CPU", dtype=self.dtype)
+ self.contiguous().realize().lazydata.base.realized.ensure_allocated().copyin(x._data())
  return self
  if x.__class__ is not Tensor: x = Tensor(x, device=self.device, dtype=self.dtype)
- if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}")
  if self.lazydata is x.lazydata: return self # a self assign is a NOOP
  # NOTE: we allow cross device assign
  assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}"
  assert self.device == x.device, f"assign device mismatch {self.device} != {x.device}"
  assert self.dtype == x.dtype, f"assign dtype mismatch {self.dtype} != {x.dtype}"
- assert not isinstance(self.lazydata, MultiLazyBuffer) or self.lazydata.axis == x.lazydata.axis, "axis must match on MultiLazyBuffer"
  assert not x.requires_grad # self requires_grad is okay?
  if not self.lazydata.is_realized: return self.replace(x)
  self.lazydata = self.lazydata.assign(x.lazydata)
@@ -252,15 +292,16 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  """
  Returns a new tensor with the same data as this tensor, but detached from the autograd graph.
  """
- return Tensor(self.lazydata, device=self.device, requires_grad=False)
+ return Tensor(self.lazydata.detach(), device=self.device, requires_grad=False)

  def _data(self) -> memoryview:
  if 0 in self.shape: return memoryview(bytearray(0))
  # NOTE: this realizes on the object from as_buffer being a Python object
- cpu = self.cast(self.dtype.base).contiguous().to("CLANG").realize()
- buf = cast(Buffer, cast(LazyBuffer, cpu.lazydata).base.realized)
- if self.device != "CLANG": buf.options = BufferOptions(nolru=True)
- return buf.as_buffer(allow_zero_copy=True if self.device != "CLANG" else False)
+ cpu = self.cast(self.dtype.base).contiguous().to("CPU").realize()
+ buf = cast(UOp, cpu.lazydata).base.realized
+ assert buf is not None, f"{cast(UOp, cpu.lazydata).base} was not realized"
+ if self.device != "CPU": buf.options = BufferSpec(nolru=True)
+ return buf.as_buffer(allow_zero_copy=True if self.device != "CPU" else False)

  def data(self) -> memoryview:
  """
@@ -271,9 +312,10 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  print(np.frombuffer(t.data(), dtype=np.int32))
  ```
  """
- assert self.dtype.fmt is not None, f"no fmt dtype for {self.dtype}"
+ assert self.dtype.base.fmt is not None, f"no fmt dtype for {self.dtype.base}"
  assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
- return self._data().cast(self.dtype.fmt, self.shape)
+ if TYPE_CHECKING or sys.version_info < (3, 12): assert self.dtype.base.fmt != "e"
+ return cast(memoryview, self._data().cast(self.dtype.base.fmt) if 0 in self.shape else self._data().cast(self.dtype.base.fmt, self.shape))

  def item(self) -> ConstType:
  """
@@ -284,20 +326,24 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  print(t.item())
  ```
  """
- assert self.dtype.fmt is not None, f"no fmt dtype for {self.dtype}"
  assert self.numel() == 1, "must have one element for item"
- return self._data().cast(self.dtype.fmt)[0]
+ return self.data()[(0,) * len(self.shape)]

- # TODO: should be Tensor.tolist() -> Union[List[ConstType], ConstType]. The List is Sequence because mypy expects memoryview.tolist() -> list[int]
+ # TODO: should be Tensor.tolist() -> Union[list[ConstType], ConstType]. The list is Sequence because mypy expects memoryview.tolist() -> list[int]
  # src: https://github.com/python/mypy/blob/release-1.6/mypy/typeshed/stdlib/builtins.pyi#L803
  def tolist(self) -> Union[Sequence[ConstType], ConstType]:
  """
  Returns the value of this tensor as a nested list.
+ Returns single value for const tensor.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([1, 2, 3, 4])
  print(t.tolist())
  ```
+ ```python exec="true" source="above" session="tensor" result="python"
+ t = Tensor(5)
+ print(t.tolist())
+ ```
  """
  return self.data().tolist()

@@ -311,21 +357,20 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  ```
  """
  import numpy as np
- if self.dtype == dtypes.bfloat16: return self.float().numpy()
- assert _to_np_dtype(self.dtype) is not None, f"no np dtype for {self.dtype}"
+ if self.dtype.base == dtypes.bfloat16: return self.float().numpy()
+ assert _to_np_dtype(self.dtype.base) is not None, f"no np dtype for {self.dtype.base}"
  assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
- return np.frombuffer(self._data(), dtype=_to_np_dtype(self.dtype)).reshape(self.shape)
+ return np.frombuffer(self._data(), dtype=_to_np_dtype(self.dtype.base)).reshape(self.shape)

  def clone(self) -> Tensor:
  """
- Creates a clone of this tensor allocating a seperate buffer for the data.
+ Creates a clone of this tensor allocating a separate buffer for the data.
  """
  ret = Tensor(self.lazydata.clone(), self.device, requires_grad=self.requires_grad)
  if self.grad is not None: ret.grad = self.grad.clone()
- if hasattr(self, '_ctx'): ret._ctx = self._ctx
  return ret

- def to(self, device:Optional[Union[str, Tuple[str, ...]]]) -> Tensor:
+ def to(self, device:Optional[Union[str, tuple[str, ...]]]) -> Tensor:
  """
  Moves the tensor to the given device.
  """
@@ -334,47 +379,35 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  if not isinstance(device, str): return self.shard(device)
  ret = Tensor(self.lazydata, device, requires_grad=self.requires_grad)
  if self.grad is not None: ret.grad = self.grad.to(device)
- if hasattr(self, '_ctx'): ret._ctx = self._ctx
  return ret

- def to_(self, device:Optional[Union[str, Tuple[str, ...]]]):
+ def to_(self, device:Optional[Union[str, tuple[str, ...]]]):
  """
  Moves the tensor to the given device in place.
  """
  real = self.to(device)
- # TODO: is this assign?
- if self.grad is not None and real.grad is not None: self.grad.lazydata = real.grad.lazydata
- self.lazydata = real.lazydata
+ if self.grad is not None and real.grad is not None: self.grad.replace(real.grad)
+ return self.replace(real)

- def shard(self, devices:Tuple[str, ...], axis:Optional[int]=None, splits:Optional[Tuple[int, ...]]=None) -> Tensor:
+ def shard(self, devices:tuple[str, ...], axis:Optional[int]=None) -> Tensor:
  """
- Shards the tensor across the given devices. Optionally specify which axis to shard on, and how to split it across devices.
+ Shards the tensor across the given devices. Optionally specify which axis to shard on.

  ```python exec="true" source="above" session="tensor" result="python"
- t = Tensor.empty(2, 3)
- print(t.shard((t.device, t.device), axis=1, splits=(2, 1)).lazydata)
+ t = Tensor.empty(2, 4)
+ print(t.shard((t.device, t.device), axis=1).lazydata)
  ```
-
  """
- assert isinstance(self.lazydata, LazyBuffer), "can't shard a MultiLazyBuffer"
- devices, bounds = tuple(Device.canonicalize(x) for x in devices), None
- if axis is not None:
- if axis < 0: axis += len(self.shape)
- if splits is None:
- if not isinstance(total:=self.shape[axis], int): raise RuntimeError(f"cannot shard symbolic shape {self.shape=}, {axis=}")
- sz = ceildiv(total, len(devices))
- splits = tuple([max(0, min(sz, total - sz*i)) for i in range(len(devices))])
- assert sum(splits) == self.shape[axis], "specified splits do not sum up to axis shape"
- boundaries = tuple(itertools.accumulate(splits))
- bounds = tuple(zip((0,) + boundaries, boundaries))
- return Tensor(MultiLazyBuffer.from_sharded(self.lazydata, devices, axis, bounds), device=devices, requires_grad=self.requires_grad)
+ assert isinstance(self.device, str), "can't shard a MultiLazyBuffer"
+ devices = tuple(Device.canonicalize(x) for x in devices)
+ mlb = self.lazydata.shard(devices, self._resolve_dim(axis) if axis is not None else None)
+ return Tensor(mlb, device=devices, requires_grad=self.requires_grad)

- def shard_(self, devices:Tuple[str, ...], axis:Optional[int]=None, splits:Optional[Tuple[int, ...]]=None):
+ def shard_(self, devices:tuple[str, ...], axis:Optional[int]=None):
  """
  Shards the tensor across the given devices in place.
  """
- self.lazydata = self.shard(devices, axis, splits).lazydata
- return self
+ return self.replace(self.shard(devices, axis))

  @staticmethod
  def from_uop(y:UOp, **kwargs) -> Tensor:
@@ -382,18 +415,17 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  if y.op is Ops.CONST: return Tensor(y.arg, **kwargs, requires_grad=False)
  if y.op is Ops.MUL: return Tensor.from_uop(y.src[0]) * Tensor.from_uop(y.src[1])
  if y.op is Ops.ADD: return Tensor.from_uop(y.src[0]) + Tensor.from_uop(y.src[1])
- if y.op is Ops.MAX: return Tensor.from_uop(y.src[0]).maximum(Tensor.from_uop(y.src[1]))
  raise RuntimeError(f"unhandled UOp {y}")

  # ***** creation entrypoint *****

  @staticmethod
- def _metaop(op, shape, device:Optional[Union[Tuple[str, ...], str]]=None, dtype:Optional[DTypeLike]=None, arg=None, **kwargs):
+ def _metaop(op, shape, device:Optional[Union[tuple[str, ...], str]]=None, dtype:Optional[DTypeLike]=None, arg=None, **kwargs):
  dtype = to_dtype(dtype) if dtype is not None else dtypes.default_float
  if isinstance(device, tuple):
- return Tensor(MultiLazyBuffer([LazyBuffer.metaop(op, shape, dtype, Device.canonicalize(d), arg) for d in device], None),
+ return Tensor(UOp.multi(*[UOp.metaop(op, shape, dtype, Device.canonicalize(d), arg) for d in device], axis=None),
  device, dtype, **kwargs)
- return Tensor(LazyBuffer.metaop(op, shape, dtype, Device.canonicalize(device), arg), device, dtype, **kwargs)
+ return Tensor(UOp.metaop(op, shape, dtype, Device.canonicalize(device), arg), device, dtype, **kwargs)

  @staticmethod
  def empty(*shape, **kwargs):
@@ -411,7 +443,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  return Tensor._metaop(Ops.EMPTY, argfix(*shape), **kwargs)

  @staticmethod
- def from_blob(ptr:int, shape:Tuple[int, ...], **kwargs) -> Tensor:
+ def from_blob(ptr:int, shape:tuple[int, ...], **kwargs) -> Tensor:
  """
  Exposes the pointer as a Tensor without taking ownership of the original data.
  The pointer must remain valid for the entire lifetime of the created Tensor.
@@ -422,7 +454,6 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method

  r = Tensor._metaop(Ops.EMPTY, shape, **kwargs)
  r.lazydata.buffer.allocate(external_ptr=ptr)
- del r.lazydata.srcs # fake realize
  return r

  @staticmethod
@@ -439,8 +470,8 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  return Tensor(fetch(url, gunzip=gunzip), **kwargs)

  _seed: int = int(time.time())
- _device_seeds: Dict[str, Tensor] = {}
- _device_rng_counters: Dict[str, Tensor] = {}
+ _device_seeds: dict[str, Tensor] = {}
+ _device_rng_counters: dict[str, Tensor] = {}
  @staticmethod
  def manual_seed(seed=0):
  """
@@ -462,7 +493,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  @staticmethod
  def _threefry_random_bits(key:Tensor, counts0:Tensor, counts1:Tensor):
  x = (counts1.cast(dtypes.uint64) << 32) | counts0.cast(dtypes.uint64)
- x = F.Threefry.apply(x, (key[1]._broadcast_to(x.shape).cast(dtypes.uint64) << 32) | key[0]._broadcast_to(x.shape).cast(dtypes.uint64))
+ x = x._apply_uop(UOp.threefry, (key[1]._broadcast_to(x.shape).cast(dtypes.uint64) << 32) | key[0]._broadcast_to(x.shape).cast(dtypes.uint64))
  counts0, counts1 = (x & 0xffffffff).cast(dtypes.uint32), ((x >> 32) & 0xffffffff).cast(dtypes.uint32)
  return counts0.cat(counts1)

@@ -485,8 +516,12 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  if device is not None and not isinstance(device, str): raise ValueError(f"rand only supports single device, got {device=}")
  _device = device = Device.canonicalize(device)

- # when using MOCKGPU and NV generate rand on CLANG
- if getenv("MOCKGPU") and device.startswith("NV"): device = "CLANG"
+ # if shape has 0, return zero tensor
+ if (numel := prod(shape)) == 0: return Tensor.zeros(shape, device=_device, dtype=dtype, **kwargs)
+ num = ceildiv(numel * dtype.itemsize, 4)
+
+ # when using MOCKGPU and NV generate rand on CPU
+ if getenv("MOCKGPU") and device.startswith("NV"): device = "CPU"

  # generate per device seeds and rng counter if we haven't seen this device yet
  if device not in Tensor._device_seeds:
@@ -494,15 +529,8 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  [int.from_bytes(hashlib.sha256(len(Tensor._device_seeds).to_bytes(4, "big")).digest(), "big"), Tensor._seed],
  device=device, dtype=dtypes.uint32, requires_grad=False)
  Tensor._device_rng_counters[device] = Tensor([0], device=device, dtype=dtypes.uint32, requires_grad=False)
- had_counter = False
- else: had_counter = True
-
- # if shape has 0, return zero tensor
- if (numel := prod(shape)) == 0: return Tensor.zeros(shape, device=_device, dtype=dtype, **kwargs)
- num = ceildiv(numel * dtype.itemsize, 4)
-
  # increment rng counter for devices
- if had_counter: Tensor._device_rng_counters[device].assign(Tensor._device_rng_counters[device] + num).contiguous()
+ else: Tensor._device_rng_counters[device].assign(Tensor._device_rng_counters[device] + num).contiguous()

  # threefry random bits
  counts0 = (Tensor.arange(ceildiv(num, 2), device=device, dtype=dtypes.uint32, requires_grad=False)+Tensor._device_rng_counters[device])
@@ -528,7 +556,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  # ***** creation helper functions *****

  @staticmethod
- def full(shape:Tuple[sint, ...], fill_value:ConstType, **kwargs) -> Tensor:
+ def full(shape:tuple[sint, ...], fill_value:ConstType, **kwargs) -> Tensor:
  """
  Creates a tensor with the given shape, filled with the given value.

@@ -607,7 +635,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  dtype = kwargs.pop("dtype", dtypes.default_float if any(isinstance(x, float) for x in (start, stop, step)) else dtypes.default_int)
  # NOTE: this matches numpy, torch raises RuntimeError if stop-start and step have different signs
  if (output_len:=ceildiv(stop-start, step)) <= 0: return Tensor([], dtype=dtype, **kwargs)
- return (Tensor.full((output_len,), step, dtype=dtype, **kwargs)._cumsum() + (start - step)).cast(dtype)
+ return (Tensor.full((output_len,), step, dtype=dtype, **kwargs)._cumalu(0, Ops.ADD) + (start - step)).cast(dtype)

  @staticmethod
  def linspace(start:Union[int, float], stop:Union[int, float], steps:int, **kwargs) -> Tensor:
@@ -705,18 +733,19 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  ```
  """
  dtype = kwargs.pop("dtype", self.dtype)
- if isinstance(self.device, tuple) and isinstance(self.lazydata, MultiLazyBuffer):
+ if isinstance(self.device, tuple):
  if kwargs.get("device") is not None: raise RuntimeError("cannot specify `device` on `rand_like` of a multi device tensor")
  if self.lazydata.axis is None: return Tensor.rand(*self.shape, dtype=dtype, **kwargs).shard(self.device)
  contiguous = kwargs.pop("contiguous", True)
- rands = [Tensor.rand(*lb.shape, device=lb.device, dtype=dtype, contiguous=contiguous, **kwargs).lazydata for lb in self.lazydata.lbs]
- return Tensor(MultiLazyBuffer(cast(List[LazyBuffer], rands), self.lazydata.axis), device=self.device, dtype=dtype, **kwargs)
+ sharded_shape = tuple(s//len(self.device) if a==self.lazydata.axis else s for a,s in enumerate(self.shape))
+ rands = [Tensor.rand(sharded_shape, device=d, dtype=dtype, contiguous=contiguous, **kwargs).lazydata for d in self.device]
+ return Tensor(UOp.multi(*rands, axis=self.lazydata.axis), device=self.device, dtype=dtype, **kwargs)
  return Tensor.rand(*self.shape, device=kwargs.pop("device", self.device), dtype=dtype, **kwargs)

  # ***** rng hlops *****

  @staticmethod
- def randn(*shape, dtype:Optional[DTypeLike]=None, **kwargs) -> Tensor:
+ def randn(*shape, dtype:Optional[DTypeLike]=None, requires_grad:Optional[bool]=None, **kwargs) -> Tensor:
  """
  Creates a tensor with the given shape, filled with random values from a normal distribution with mean `0` and standard deviation `1`.
  If `dtype` is not specified, the default type is used.
@@ -731,10 +760,10 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  """
  # https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
  src = Tensor.rand((2, *argfix(*shape)), **{**kwargs, "dtype": dtypes.float32})
- return src[0].mul(2*math.pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(dtype or dtypes.default_float)
+ return (src[0].mul(2*math.pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(dtype or dtypes.default_float)).requires_grad_(requires_grad)

  @staticmethod
- def randint(*shape, low=0, high=10, **kwargs) -> Tensor:
+ def randint(*shape, low=0, high=10, dtype=dtypes.int32, **kwargs) -> Tensor:
  """
  Creates a tensor with the given shape, filled with random integer values generated uniformly from the interval `[low, high)`.
  If `dtype` is not specified, the default type is used.
@@ -748,12 +777,12 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  ```
  """
  if not isinstance(low, int) or not isinstance(high, int): raise TypeError(f"{low=} and {high=} must be integers")
- dtype = to_dtype(kwargs.pop("dtype", dtypes.int32))
+ dtype = to_dtype(dtype)
  if not dtypes.is_int(dtype): raise TypeError(f"{dtype=} must be int")
  return Tensor.uniform(*shape, low=low, high=high, dtype=dtype, **kwargs)

  @staticmethod
- def normal(*shape, mean=0.0, std=1.0, **kwargs) -> Tensor:
+ def normal(*shape, mean=0.0, std=1.0, requires_grad:Optional[bool]=None, **kwargs) -> Tensor:
  """
  Creates a tensor with the given shape, filled with random values from a normal distribution with the given `mean` and standard deviation `std`.

@@ -765,10 +794,10 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  print(Tensor.normal(2, 3, mean=10, std=2).numpy())
  ```
  """
- return (std * Tensor.randn(*shape, **kwargs)) + mean
+ return ((std * Tensor.randn(*shape, **kwargs)) + mean).requires_grad_(requires_grad)

  @staticmethod
- def uniform(*shape, low=0.0, high=1.0, **kwargs) -> Tensor:
+ def uniform(*shape, low=0.0, high=1.0, dtype:Optional[DTypeLike]=None, requires_grad:Optional[bool]=None, **kwargs) -> Tensor:
  """
  Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[low, high)`.

@@ -780,8 +809,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  print(Tensor.uniform(2, 3, low=2, high=10).numpy())
  ```
  """
- dtype = kwargs.pop("dtype", dtypes.default_float)
- return ((high-low) * Tensor.rand(*shape, **kwargs)).cast(dtype) + low
+ return (((high-low) * Tensor.rand(*shape, **kwargs)).cast(dtype or dtypes.default_float) + low).requires_grad_(requires_grad)

  @staticmethod
  def scaled_uniform(*shape, **kwargs) -> Tensor:
@@ -860,49 +888,52 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method

  # ***** toposort and backward pass *****

- def _deepwalk(self):
- def _walk(node, visited):
- visited.add(node)
- # if tensor is not leaf, reset grad
- if (ctx := getattr(node, "_ctx", None)) is not None and len(ctx.parents) != 0: node.grad = None
- if ctx:
- for i in node._ctx.parents:
- if i not in visited: yield from _walk(i, visited)
- yield node
- return list(_walk(self, set()))
+ def gradient(self, *targets:Tensor, gradient:Optional[Tensor]=None, materialize_grads=False) -> list[Tensor]:
+ """
+ Compute the gradient of the targets with respect to self.

- def backward(self, gradient:Optional[Tensor]=None, retain_graph:bool=False) -> Tensor:
+ ```python exec="true" source="above" session="tensor" result="python"
+ x = Tensor.eye(3)
+ y = Tensor([[2.0,0,-2.0]])
+ z = y.matmul(x).sum()
+ dx, dy = z.gradient(x, y)
+
+ print(dx.tolist()) # dz/dx
+ print(dy.tolist()) # dz/dy
+ ```
+ """
+ assert gradient is not None or self.shape == tuple(), "when no gradient is provided, backward must be called on a scalar tensor"
+ if gradient is None: gradient = Tensor(1.0, dtype=self.dtype, device=self.device, requires_grad=False)
+ rets = []
+ target_uops = [x.lazydata for x in targets]
+ grads = compute_gradient(self.lazydata, gradient.lazydata, set(target_uops))
+ ret = []
+ for x in target_uops:
+ if (y:=grads.get(x)) is None:
+ if materialize_grads: y = x.const_like(0)
+ else: raise RuntimeError(f"{x}\n\nnot found in\n\n{self.lazydata}")
+ ret.append(y)
+ rets.append(ret)
+ # create returned Tensors
+ return [Tensor(u, device=t.device) for t,u in zip(targets, rets[0])]
+
+ def backward(self, gradient:Optional[Tensor]=None) -> Tensor:
  """
  Propagates the gradient of a tensor backwards through the computation graph.
  If the 'gradient' argument is not provided, the tensor must be a scalar, and the gradient is implicitly set to 1.0.
- If 'retain_graph' is false, the graph used to compute the grads will be freed. Otherwise, it will be kept. Keeping it can increase memory usage.
  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
  t.sum().backward()
  print(t.grad.numpy())
  ```
  """
- toposorted = self._deepwalk()
- if gradient is None:
- assert self.shape == tuple(), "when no gradient is provided, backward must be called on a scalar tensor"
- # fill in the first grad with one. don't use Tensor.ones because we don't need contiguous
- # this is "implicit gradient creation"
- gradient = Tensor(1.0, dtype=self.dtype, device=self.device, requires_grad=False)
-
- assert self.shape == gradient.shape, f"grad shape must match tensor shape, {gradient.shape!r} != {self.shape!r}"
- self.grad = gradient
- for t0 in reversed(toposorted):
- if t0.grad is None: raise RuntimeError(f"tensor {t0} has no grad")
- token = _METADATA.set(dataclasses.replace(md, backward=True) if (md := t0._ctx.metadata) is not None else None)
- grads = t0._ctx.backward(t0.grad.lazydata)
- _METADATA.reset(token)
- grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None
- for g in ([grads] if len(t0._ctx.parents) == 1 else grads)]
- for t, g in zip(t0._ctx.parents, grads):
- if g is not None and t.requires_grad:
- assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"
- t.grad = g if t.grad is None else (t.grad + g)
- if not retain_graph: del t0._ctx
+ all_uops = self.lazydata.toposort
+ tensors_need_grad: list[Tensor] = [t for tref in all_tensors if (t:=tref()) is not None and \
+ t.lazydata in all_uops and t.requires_grad and not Tensor.no_grad]
+ # clear contexts
+ for t,g in zip(tensors_need_grad, self.gradient(*tensors_need_grad, gradient=gradient, materialize_grads=True)):
+ assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"
+ t.grad = g if t.grad is None else (t.grad + g)
  return self

  # ***** movement low level ops *****
@@ -926,7 +957,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  # resolve -1
  if (c := new_shape.count(-1)) > 1: raise RuntimeError(f"only one dimension can be inferred using -1, getting {new_shape}")
  if c: new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape])
- return F.Reshape.apply(self, shape=new_shape) if new_shape != self.shape else self
+ return self._apply_uop(UOp.reshape, arg=new_shape) if new_shape != self.shape else self

  def expand(self, shape, *args) -> Tensor:
  """
@@ -940,7 +971,8 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  print(t.expand(4, -1).numpy())
  ```
  """
- return self._broadcast_to(tuple(from_ if to == -1 or to is None else to for from_, to in zip(*(_pad_left(self.shape, argfix(shape, *args))))))
+ new_shape = tuple(from_ if to == -1 or to is None else to for from_, to in zip(*(_align_left(self.shape, argfix(shape, *args)))))
+ return self._broadcast_to(new_shape)

  def permute(self, order, *args) -> Tensor:
  """
@@ -958,7 +990,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  """
  order_arg = tuple(self._resolve_dim(x) for x in argfix(order, *args))
  if sorted(order_arg) != list(range(self.ndim)): raise RuntimeError(f"order is not a valid permutation, getting {order_arg}")
- return F.Permute.apply(self, order=order_arg)
+ return self._apply_uop(UOp.permute, arg=order_arg)

  def flip(self, axis, *args) -> Tensor:
  """
@@ -978,9 +1010,9 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  """
  axis_arg = tuple(self._resolve_dim(x) for x in argfix(axis, *args))
  if len(axis_arg) != len(dedup(axis_arg)): raise RuntimeError(f"dim can appear at most once, getting {axis_arg}")
- return F.Flip.apply(self, axis=axis_arg)
+ return self._apply_uop(UOp.flip, arg=tuple([i in axis_arg for i in range(len(self.shape))]))

- def shrink(self, arg:Tuple[Optional[Tuple[sint, sint]], ...]) -> Tensor:
+ def shrink(self, arg:tuple[Optional[tuple[sint, sint]], ...]) -> Tensor:
  """
  Returns a tensor that shrinks the each axis based on input arg.
  `arg` must have the same length as `self.ndim`.
@@ -998,24 +1030,25 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  ```
  """
  if (shrink_arg:=[x if x is not None else (0,s) for x,s in zip(arg, self.shape)]) == [(0,s) for s in self.shape]: return self
- return F.Shrink.apply(self, arg=tuple(shrink_arg))
+ return self._apply_uop(UOp.shrink, arg=tuple(shrink_arg))

- def pad(self, padding:Union[Sequence[sint], Sequence[Optional[Tuple[sint, sint]]]], mode:str="constant", value:float=0.0) -> Tensor:
+ def pad(self, padding:Union[Sequence[sint], Sequence[Optional[tuple[sint, sint]]]], mode:str="constant", value:float=0.0) -> Tensor:
  """
  Returns a tensor with padding applied based on the input `padding`.
+
  `padding` supports two padding structures:

- 1. Flat padding: (padding_left, padding_right, padding_top, padding_bottom, ...)
- - This structure matches PyTorch's pad.
- - `padding` length must be even.
+ 1. Flat padding: `(padding_left, padding_right, padding_top, padding_bottom, ...)`
+ - This structure matches PyTorch's pad.
+ - `padding` length must be even.

- 2. Group padding: (..., (padding_top, padding_bottom), (padding_left, padding_right))
- - This structure matches pad for jax, numpy, tensorflow and others.
- - For each axis, padding can be `None`, meaning no padding, or a tuple `(start, end)`.
- - `padding` must have the same length as `self.ndim`.
+ 2. Group padding: `(..., (padding_top, padding_bottom), (padding_left, padding_right))`
+ - This structure matches pad for JAX, NumPy, TensorFlow, and others.
+ - For each axis, padding can be `None`, meaning no padding, or a tuple `(start, end)`.
+ - `padding` must have the same length as `self.ndim`.

  Padding values can be negative, resulting in dimension shrinks that work similarly to Python negative slices.
- Padding modes is selected with `mode` which supports `constant` and `reflect`.
+ Padding modes is selected with `mode` which supports `constant`, `reflect` and `replicate`.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.arange(9).reshape(1, 1, 3, 3)
@@ -1031,176 +1064,166 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
1031
1064
  print(t.pad((1, 2, 0, -1), value=-float('inf')).numpy())
1032
1065
  ```
1033
1066
  """
1034
- if mode not in {"constant", "reflect"}: raise NotImplementedError(f"{mode=} is not supported")
1035
- if (flat:=all(isinstance(p, (int,UOp)) for p in padding)) and len(padding)%2 != 0: raise ValueError("Flat padding must have even number of pads")
1036
- # turn flat padding into group padding
1037
- pX = ((0,0),)*(self.ndim - len(padding)//2) + tuple(zip(padding[-2::-2], padding[::-2])) if flat else padding
1067
+ if mode not in {"constant", "reflect", "replicate", "circular"}: raise NotImplementedError(f"{mode=} is not supported")
1068
+ # flat padding
1069
+ if all(isinstance(p, (int,UOp)) for p in padding):
1070
+ if len(padding)%2 != 0: raise ValueError("Flat padding must have even number of pads")
1071
+ pX = _flat_to_grouped(tuple(cast(Sequence[sint], padding)) + (0,0)*(self.ndim - len(padding)//2))
1072
+ # group padding
1073
+ else: pX = tuple((0,0) if p is None else p for p in cast(Sequence[Optional[tuple[sint, sint]]], padding))
1038
1074
  if len(pX) != self.ndim: raise ValueError(f"padding length is improper, {padding=} {self.ndim=}")
1039
- X, pX = self, cast(Tuple[Tuple[sint, sint]], tuple((0,0) if p is None else p for p in pX))
1040
- def _constant(x,px,v): return F.Pad.apply(x, arg=px) if v == 0 else F.Pad.apply(x, arg=px) + F.Pad.apply(Tensor.ones_like(x), arg=px).where(0, v)
1041
- # early return for symbolic with positive pads (no need to max)
1042
- if mode == "constant" and all(resolve(p >= 0) for p in flatten(pX)): return _constant(X, pX, value)
1043
- pads, shrinks = tuple((smax(pB,0), smax(pA,0)) for pB,pA in pX), lambda shape: tuple((-smin(pB,0),smin(pA+s,s)) for (pB,pA),s in zip(pX, shape))
1044
- if mode == "constant": return _constant(X.shrink(shrinks(X.shape)), pads, value)
1075
+ X, pads = self, tuple((smax(pB,0), smax(pA,0)) for pB,pA in pX)
1076
+ if mode == "constant":
1077
+ def _constant(x:Tensor,px,v):
1078
+ return x._apply_uop(UOp.pad, arg=px) if v == 0 else (x._apply_uop(UOp.pad, arg=px)+Tensor.ones_like(x)._apply_uop(UOp.pad, arg=px).where(0,v))
1079
+ return _constant(X, pX, value) if all(resolve(p >= 0) for p in flatten(pX)) else \
1080
+ _constant(X.shrink(tuple((-smin(pB,0),smin(pA+s,s)) for (pB,pA),s in zip(pX, X.shape))), pads, value)
1045
1081
  assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
1082
+ if mode == "circular":
1083
+ if any(pB>sh or pA>sh for (pB,pA),sh in zip(pX, X.shape)): raise ValueError('Padding value causes wrapping around more than once.')
1084
+ if any(pB<0 or pA<0 for pB,pA in pX): raise NotImplementedError("Negative pads with circular pads is not supported")
1085
+ orig_shape, X = X.shape, X.repeat(tuple(1 + bool(pB) + bool(pA) for pB,pA in pads))
1086
+ return X.shrink(tuple((0 if pB == 0 else osh-pB, xsh if pA == 0 else xsh-osh+pA) for (pB,pA),osh,xsh in zip(pads, orig_shape, X.shape)))
1046
1087
  for d,(pB,pA) in enumerate(pads):
1047
- if pB >= (s:=X.shape[d]) or pA>=s: raise ValueError(f"Padding ({pB}, {pA}) should be less than the input size={s} for dim={d}.")
1048
- slcB, slcA, = slice(pB,0,-1), slice(s-2 if s-2>=0 else None, s-2-pA if s-2-pA>=0 else None, -1)
1049
- xB, xA = (X[[slc if i == d else slice(None) for i in range(X.ndim)]] if p > 0 else None for slc, p in ((slcB, pB), (slcA, pA)))
1088
+ if mode == "reflect":
1089
+ if pB >= (s:=X.shape[d]) or pA>=s: raise ValueError(f"Padding ({pB}, {pA}) should be less than the input size={s} for dim={d}.")
1090
+ slcB, slcA, = slice(pB,0,-1), slice(s-2 if s-2>=0 else None, s-2-pA if s-2-pA>=0 else None, -1)
1091
+ xB, xA = (X[[slc if i == d else slice(None) for i in range(X.ndim)]] if p > 0 else None for slc, p in ((slcB, pB), (slcA, pA)))
1092
+ if mode == "replicate":
1093
+ shrB, shrA, = tuple((0,1) if i==d else None for i in range(X.ndim)), tuple((X.shape[i]-1,X.shape[i]) if i==d else None for i in range(X.ndim))
1094
+ xB, xA = (X.shrink(shr).expand(tuple(p if i==d else None for i in range(X.ndim))) if p > 0 else None for shr, p in ((shrB, pB), (shrA, pA)))
1050
1095
  X = Tensor.cat(*(X_ for X_ in (xB, X, xA) if X_ is not None), dim=d)
1051
- return X.shrink(shrinks(X.shape))
1096
+ return X.shrink(tuple((-min(pB,0), min(pA+s,s)) for (pB,pA),s in zip(pX, X.shape)))
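A brief usage sketch of the two padding structures and the padding modes described in the docstring above, assuming the `Tensor.pad` signature shown in this diff (expected shapes noted in comments):

```python
from tinygrad import Tensor

t = Tensor.arange(9).reshape(1, 1, 3, 3)

# flat (PyTorch-style) padding: (left, right, top, bottom), applied to the trailing dims
print(t.pad((1, 1, 0, 0)).shape)                   # (1, 1, 3, 5)

# grouped padding: one (before, after) tuple or None per axis
print(t.pad((None, None, (1, 1), (2, 2))).shape)   # (1, 1, 5, 7)

# non-constant modes mirror, repeat, or wrap the border values
print(t.pad((1, 1, 1, 1), mode="reflect").numpy())
print(t.pad((1, 1, 1, 1), mode="replicate").numpy())
print(t.pad((1, 1, 1, 1), mode="circular").numpy())
```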
1052
1097
 
1053
1098
  # ***** movement high level ops *****
1054
1099
 
1055
- # Supported Indexing Implementations:
1056
- # 1. Int indexing (no copy)
1057
- # - for all dims where there's int, shrink -> reshape
1058
- # - negative indices are taken relative to the end of the sequence, so X[-2] returns the 2nd-to-last element
1059
- # - X = Tensor.rand(4,5,9); X[2,-2] shrinks the Tensor to X.shrink(((2, 3), (3, 4), (0, 9))) -> X.shape=(1,1,9)
1060
- # - Then we reshape (collapse) the int dim away such that for X: (1,1,9) -> (9,)
1061
- # 2. Slice indexing (no copy)
1062
- # - for all dims where slice is start:end:stride, shrink -> Optional[flip] -> pad -> reshape -> shrink
1063
- # - first shrink the Tensor to X.shrink(((start, end),))
1064
- # - then we apply stride through Optional[flip] -> pad -> reshape -> shrink
1065
- # - flip where dim value is negative
1066
- # - pad on dims to be multiple of strides, such that reshaping [dim_size_padded] -> [dim_size_padded // stride, stride] is possible
1067
- # - shrink [dim_size_padded // stride, stride] -> [dim_size_padded // stride, 1]
1068
- # - reshape [dim_size_padded // stride, 1] -> [dim_size_padded // stride] and now you have your stride
1069
- # 3. None indexing (no copy)
1070
- # - reshape (inject) a dim at the dim where there's None
1071
- # 4. Tensor indexing (copy)
1072
- # - use Tensor.arange == tensor_index to create masks for dims with Tensors (adds a dim for each mask)
1073
- # - combine masks together with mul
1074
- # - apply mask to self by mask * self
1075
- # - sum reduce away the extra dims added from creating masks
1076
- # Tiny Things:
1077
- # 1. Supported indices: Union[int, slice, Tensor, None, List, Tuple, Ellipsis]
1078
- # - for any list, List[Union[List, Tuple, int]], must have homogeneous shape
1079
- # - for any tuple, Tuple[Union[List, Tuple, int]], must have homogeneous shape
1080
- # 2. Bool indexing is not supported
1081
- # 3. Out of bounds Tensor indexing results in 0
1082
- # - e.g: Tensor([1, 2, 3])[Tensor([4, 3, 2])] -> [0, 0, 3] index 4 and 3 are out of bounds
1083
1100
  def _getitem(self, indices, v: Optional[Tensor] = None) -> Tensor:
1084
- # 1. indices normalization and validation
1085
- # treat internal tuples and lists as Tensors and standardize indices to list type
1086
- if isinstance(indices, list) and all_int(indices): indices = [Tensor(indices, self.device, requires_grad=False)]
1087
- elif isinstance(indices, (tuple, list)):
1088
- indices = [Tensor(i, self.device, requires_grad=False) if isinstance(i, (tuple, list)) else i for i in indices]
1089
- else: indices = [indices]
1090
-
1091
- # turn scalar Tensors into const val for int indexing if possible
1092
- indices = [self._to_const_val(i) if isinstance(i, Tensor) and i.shape == () else i for i in indices]
1093
- # move Tensor indices to the same device as self
1094
- indices = [i.to(self.device) if isinstance(i, Tensor) else i for i in indices]
1101
+ # wrap single index into a list
1102
+ if (isinstance(indices, list) and all_int(indices)) or not isinstance(indices, (tuple, list)): indices = [indices]
1103
+ x, indices = self, list(indices)
1095
1104
 
1096
1105
  # filter ellipsis and fill with slice(None) or fill rest of indices with slice(None)
1097
- ellipsis_idx = [dim for dim, i in enumerate(indices) if i is Ellipsis]
1106
+ if len(ellipsis_idx := [dim for dim, i in enumerate(indices) if i is Ellipsis]) > 1: raise IndexError("indices can only have a single ellipsis")
1098
1107
  fill_idx = ellipsis_idx[0] if ellipsis_idx else len(indices)
1099
1108
  num_indices = len(indices) - len(ellipsis_idx) - sum(1 for i in indices if i is None)
1100
- indices[fill_idx:fill_idx+1] = [slice(None)] * (self.ndim - num_indices)
1101
-
1102
- # use Dict[type, List[dimension]] to track elements in indices
1103
- type_dim: DefaultDict[Union[type, None], List[int]] = defaultdict(list)
1104
-
1105
- # record None for dimension injection later and filter None and record rest of indices
1106
- type_dim[None] = [dim for dim, i in enumerate(indices) if i is None]
1107
- indices_filtered = [i for i in indices if i is not None]
1108
- for dim,i in enumerate(indices_filtered): type_dim[type(i)].append(dim)
1109
-
1110
- if len(ellipsis_idx) > 1: raise IndexError("indices can only have a single ellipsis ('...')")
1111
- for index_type in type_dim:
1112
- if index_type not in [None, int, slice, Tensor]: raise IndexError(f"{index_type=} not supported")
1113
1109
  if num_indices > self.ndim: raise IndexError(f"too many {num_indices=} for {self.ndim=}")
1110
+ indices[fill_idx:fill_idx+1] = [slice(None)] * (self.ndim - num_indices)
1114
1111
 
1115
- # 2. basic indexing, uses only movement ops (no copy)
1116
- # currently indices_filtered: Tuple[Union[int, slice, Tensor], ...]
1117
- # turn indices in indices_filtered to Tuple[new_slice, strides]
1118
- for dim in type_dim[int]:
1119
- if (index := indices_filtered[dim]) >= (size := self.shape[dim]) or index < -size:
1120
- raise IndexError(f"{index=} is out of bounds on {dim=} with {size=}")
1121
- indices_filtered[dim] = ((index, index+1), 1) if index >= 0 else ((size+index, size+index+1), 1)
1122
- for dim in type_dim[slice]:
1123
- if (index := indices_filtered[dim]).step == 0: raise ValueError(f"{index=} on {dim=} cannot have 0 as step")
1124
- if not all(isinstance(x, (int, type(None))) for x in (index.start, index.stop, index.step)):
1125
- raise TypeError(f"Unsupported slice for dimension {dim}. Expected slice with integers or None, got slice("
1126
- f"{', '.join(type(x).__name__ for x in (index.start, index.stop, index.step))}).")
1127
- s, e, st = index.indices(self.shape[dim])
1128
- indices_filtered[dim] = ((0, 0) if (st * (e - s)) < 0 else (s, e) if st > 0 else (e+1, s+1), st)
1129
- # skip all Tensor dims for basic indexing
1130
- for dim in type_dim[Tensor]:
1131
- dtype = indices_filtered[dim].dtype
1132
- if not dtypes.is_int(dtype): raise IndexError(f"{dtype=} on {dim=} is not supported, only int tensor indexing is supported")
1133
- indices_filtered[dim] = ((0, self.shape[dim]), 1)
1134
-
1135
- new_slice, strides = ((), ()) if not indices_filtered else zip(*indices_filtered)
1136
- # flip negative strides
1137
- ret = self.shrink(new_slice).flip(tuple(i for i, st in enumerate(strides) if st < 0))
1138
- # handle stride != 1 or -1
1139
- if any(abs(st) != 1 for st in strides):
1140
- strides = tuple(abs(s) for s in strides)
1141
- # pad shape to multiple of stride
1142
- if not all_int(ret.shape): raise RuntimeError("symbolic shape not supprted")
1143
- ret = ret.pad(tuple((0, round_up(s, st) - s) for s, st in zip(ret.shape, strides)))
1144
- ret = ret.reshape(tuple(flatten((s // st, st) for s, st in zip(ret.shape, strides))))
1145
- ret = ret.shrink(tuple(flatten(((0, s), (0, 1)) for s in ret.shape[::2]))).reshape(ret.shape[::2])
1146
-
1147
- # inject 1 for dim where it's None and collapse dim for int
1148
- new_shape = list(ret.shape)
1149
- for dim in type_dim[None]: new_shape.insert(dim, 1)
1150
- for dim in (dims_collapsed := tuple(dim + sum(1 for d in type_dim[None] if dim >= d) for dim in reversed(type_dim[int]))): new_shape.pop(dim)
1151
-
1152
- ret = ret.reshape(new_shape)
1153
-
1154
- # 3. advanced indexing (copy)
1155
- if type_dim[Tensor]:
1156
- dim_tensors = [(dim, i) for dim, i in enumerate(indices) if isinstance(i, Tensor)]
1157
- # calculate dim of current ret by subtracting dims collapsed and adding dims injected up until tensor_dim
1158
- def calc_dim(tensor_dim:int) -> int:
1159
- return tensor_dim - sum(1 for d in dims_collapsed if tensor_dim >= d)
1160
-
1161
- assert all_int(ret.shape), f"does not support symbolic shape {ret.shape}"
1162
- # track tensor_dim and tensor_index using a dict
1163
- # calc_dim to get dim and use that to normalize the negative tensor indices
1164
- idx: Dict[int,Tensor] = {(dim := calc_dim(td)):(tensor<0).where(ret.shape[dim],0) + tensor for td,tensor in dim_tensors}
1165
-
1166
- masks, first_dim, last_dim = [], min(idx.keys()), max(idx.keys())
1167
- pre_reduce_shape = ret.shape[:first_dim] + (big_shape := _broadcast_shape(*(t.shape for t in idx.values()))) + ret.shape[first_dim:]
1168
-
1169
- # create masks
1170
- for dim, i in idx.items():
1171
- try: i = i.reshape(i.shape + (1,)*(ret.ndim - first_dim)).expand(pre_reduce_shape)
1112
+ indices_parsed, dim = [], 0
1113
+ for index in indices:
1114
+ size = 1 if index is None else self.shape[dim]
1115
+ boundary, stride = [0, size], 1 # defaults
1116
+ match index:
1117
+ case list() | tuple() | Tensor():
1118
+ if not isinstance(index, Tensor): index = Tensor(index, self.device, requires_grad=False)
1119
+ if not dtypes.is_int(index.dtype): raise IndexError(f"index dtype {index.dtype} is not supported")
1120
+ index = (index.to(self.device) < 0).where(index+size, index) # treat negative index values
1121
+ case int() | UOp(): # sint
1122
+ if index >= size or index < -size: raise IndexError(f"{index=} is out of bounds with {size=}")
1123
+ boundary = [index, index+1] if index >= 0 else [index+size, index+size+1]
1124
+ case slice():
1125
+ if index.step == 0: raise ValueError(f"{index=} cannot have 0 as step")
1126
+ if not all(isinstance(s,int) or s is None for s in (index.start,index.stop,index.step)): raise TypeError("only int slicing is supported")
1127
+ # handle int slicing
1128
+ *boundary, stride = index.indices(cast(SupportsIndex, size))
1129
+ if stride * (boundary[1] - boundary[0]) < 0: boundary = [0, 0]
1130
+ elif stride < 0: boundary = [boundary[1] + 1, boundary[0] + 1]
1131
+ # update size for slice
1132
+ size = ceildiv((boundary[1] - boundary[0]), abs(stride))
1133
+ case None: pass # do nothing
1134
+ case _: raise IndexError(f"{type(index).__name__} indexing is not supported")
1135
+ indices_parsed.append({"index":index, "size":size, "boundary":tuple(boundary), "stride":stride})
1136
+ if index is not None: dim += 1
1137
+
1138
+ # movement op indexing
1139
+ if mops := [i for i in indices_parsed if i['index'] is not None]:
1140
+ # flip negative strides
1141
+ shrinks, strides = zip(*((i['boundary'], i['stride']) for i in mops))
1142
+ x = x.shrink(shrinks).flip(tuple(i for i,st in enumerate(strides) if st < 0))
1143
+ # handle stride != 1 or -1
1144
+ if any(abs(st) != 1 for st in strides):
1145
+ strides = tuple(abs(s) for s in strides)
1146
+ # pad shape to multiple of stride
1147
+ if not all_int(x.shape): raise RuntimeError("symbolic shape not supported")
1148
+ x = x.pad(tuple((0, round_up(s, st) - s) for s, st in zip(x.shape, strides)))
1149
+ x = x.reshape(tuple(flatten((s // st, st) for s, st in zip(x.shape, strides))))
1150
+ x = x.shrink(tuple(flatten(((0, s), (0, 1)) for s in x.shape[::2]))).reshape(x.shape[::2])
1151
+
1152
+ # dim injection from None by including None dim size (which is 1) and dim collapse by skipping int dim size
1153
+ x = x.reshape(tuple(index['size'] for index in indices_parsed if not isinstance(index['index'], int)))
1154
+
1155
+ # tensor indexing
1156
+ if tops := [(d,i) for d,i in enumerate(i_ for i_ in indices_parsed if not isinstance(i_['index'], int)) if isinstance(i['index'], Tensor)]:
1157
+ # unload the tensor object into actual tensors
1158
+ dims, tensors, masks = [d for d,_ in tops], cast(list[Tensor], [i['index'] for _,i in tops]), []
1159
+ pre_reduce_shape = x.shape[:dims[0]] + (big_shape := _broadcast_shape(*(t.shape for t in tensors))) + x.shape[dims[0]:]
1160
+
1161
+ # create index masks
1162
+ for dim, tensor in zip(dims, tensors):
1163
+ try: i = tensor.reshape(tensor.shape + (1,)*(x.ndim - dims[0])).expand(pre_reduce_shape)
1172
1164
  except ValueError as e: raise IndexError(f"cannot broadcast indices: {e}") from e
1173
- a = Tensor.arange(ret.shape[dim], device=self.device, requires_grad=False).reshape((ret.shape[dim],) + (1,)*(ret.ndim - dim - 1))
1174
- masks.append(i == a)
1165
+ masks.append(i._one_hot_along_dim(num_classes=x.shape[dim], dim=(dim - x.ndim)))
1175
1166
 
1176
1167
  # reduce masks to 1 mask
1177
1168
  mask: Tensor = functools.reduce(lambda x,y: x.mul(y), masks)
1178
1169
 
1179
1170
  # inject 1's for the extra dims added in create masks
1180
- reshape_arg = ret.shape[:first_dim] + (1,) * len(big_shape) + ret.shape[first_dim:]
1171
+ reshape_arg = x.shape[:dims[0]] + (1,) * len(big_shape) + x.shape[dims[0]:]
1181
1172
  # sum reduce the extra dims introduced in create masks
1182
- ret = (ret.reshape(reshape_arg) * mask).sum(sum_axis:=tuple(i + len(big_shape) for i in idx.keys()), acc_dtype=ret.dtype)
1173
+ x = (x.reshape(reshape_arg) * mask).sum(sum_axis:=tuple(d + len(big_shape) for d in dims), acc_dtype=x.dtype)
1183
1174
 
1184
1175
  # special permute case
1185
- if first_dim != 0 and len(idx) != 1 and tuple(idx.keys()) != tuple(range(first_dim, last_dim+1)):
1186
- ret = ret.permute(*range(first_dim, first_dim+len(big_shape)), *range(0, first_dim), *range(first_dim+len(big_shape), ret.ndim))
1176
+ if dims[0] != 0 and len(dims) != 1 and tuple(dims) != tuple(range(dims[0], dims[-1]+1)):
1177
+ x = x.permute(*range(dims[0], dims[0]+len(big_shape)), *range(0, dims[0]), *range(dims[0]+len(big_shape), x.ndim))
1187
1178
 
1188
1179
  # for advanced setitem, returns whole tensor with indices replaced
1189
1180
  if v is not None:
1190
- vb = v.cast(self.dtype)._broadcast_to(_broadcast_shape(ret.shape, v.shape))
1181
+ vb = v.cast(self.dtype)._broadcast_to(_broadcast_shape(x.shape, v.shape))
1191
1182
  # add back reduced dims from sum
1192
1183
  for dim in sum_axis: vb = vb.unsqueeze(dim)
1193
- # axis to be reduced to match self.shape
1194
- axis = tuple(range(first_dim, first_dim + len(big_shape)))
1195
- # apply mask to v(broadcasted) and reduce such that if v contains repeated indices the last one remains
1196
- vb = vb * mask
1197
- for dim in axis: vb = functools.reduce(lambda x,y: y.where(y, x), vb.split(1, dim))
1198
- # reduce mask and select from v(get rid of extra dims from reduce) for each True element in mask else select from self
1199
- ret = mask.any(axis).where(vb.squeeze(), self)
1184
+ # run _masked_setitem on tuple of axis that is to be reduced to match self.shape
1185
+ x = _masked_setitem(self, vb, mask, tuple(range(dims[0], dims[0] + len(big_shape))))
1200
1186
 
1201
- return ret
1187
+ return x
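The slice-with-stride branch above works purely with movement ops: shrink to the slice bounds, flip if the stride is negative, pad up to a multiple of the stride, reshape to `(n // stride, stride)`, and keep only the first column. The same idea in plain Python lists; `strided_view` is an illustrative helper, not part of tinygrad:

```python
def strided_view(xs: list, start: int, stop: int, stride: int) -> list:
    """Mimic xs[start:stop:stride] for a positive stride using shrink/pad/reshape-style steps."""
    xs = xs[start:stop]                                           # shrink to the slice bounds
    xs = xs + [None] * ((-len(xs)) % stride)                      # pad length up to a multiple of stride
    rows = [xs[i:i + stride] for i in range(0, len(xs), stride)]  # reshape to (n // stride, stride)
    return [row[0] for row in rows]                               # shrink the inner dim to 1

assert strided_view(list(range(10)), 1, 9, 3) == list(range(10))[1:9:3]  # [1, 4, 7]
```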
1202
1188
 
1203
1189
  def __getitem__(self, indices) -> Tensor:
1190
+ """
1191
+ Retrieve a sub-tensor using indexing.
1192
+
1193
+ Supported Index Types: `int | slice | Tensor | None | list | tuple | Ellipsis`
1194
+
1195
+ Examples:
1196
+ ```python exec="true" source="above" session="tensor" result="python"
1197
+ t = Tensor.arange(12).reshape(3, 4)
1198
+ print(t.numpy())
1199
+ ```
1200
+
1201
+ - Int Indexing: Select an element or sub-tensor using integers for each dimension.
1202
+ ```python exec="true" source="above" session="tensor" result="python"
1203
+ print(t[1, 2].numpy())
1204
+ ```
1205
+
1206
+ - Slice Indexing: Select a range of elements using slice notation (`start:end:stride`).
1207
+ ```python exec="true" source="above" session="tensor" result="python"
1208
+ print(t[0:2, ::2].numpy())
1209
+ ```
1210
+
1211
+ - Tensor Indexing: Use another tensor as indices for advanced indexing. Using `tuple` or `list` here also works.
1212
+ ```python exec="true" source="above" session="tensor" result="python"
1213
+ print(t[Tensor([2, 0, 1]), Tensor([1, 2, 3])].numpy())
1214
+ ```
1215
+
1216
+ - `None` Indexing: Add a new dimension to the tensor.
1217
+ ```python exec="true" source="above" session="tensor" result="python"
1218
+ print(t[:, None].shape)
1219
+ ```
1220
+
1221
+ NOTE: Out-of-bounds indexing results in a value of `0`.
1222
+ ```python exec="true" source="above" session="tensor" result="python"
1223
+ t = Tensor([1, 2, 3])
1224
+ print(t[Tensor([4, 3, 2])].numpy())
1225
+ ```
1226
+ """
1204
1227
  return self._getitem(indices)
1205
1228
 
1206
1229
  def __setitem__(self, indices, v:Union[Tensor, ConstType]) -> None:
@@ -1208,9 +1231,9 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
1208
1231
  self._getitem(indices).assign(v)
1209
1232
  return
1210
1233
  # NOTE: check that setitem target is valid first
1211
- if not all(lb.st.contiguous for lb in self.lazydata.lbs): raise RuntimeError("setitem target needs to be contiguous")
1212
- if not isinstance(v, (Tensor, float, int, bool)): raise TypeError(f"can't set a {type(v).__name__} to a Tensor")
1213
- if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype)
1234
+ if not unwrap(self.lazydata.st).contiguous: raise RuntimeError("setitem target needs to be contiguous")
1235
+ if isinstance(v, get_args(ConstType)): v = Tensor(v, device=self.device, dtype=self.dtype)
1236
+ if not isinstance(v, Tensor): raise TypeError(f"can't set a {type(v).__name__} to a Tensor")
1214
1237
  if self.requires_grad or v.requires_grad: raise NotImplementedError("setitem with requires_grad is not supported")
1215
1238
 
1216
1239
  res = self.realize()._getitem(indices, v)
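A minimal sketch of `__setitem__` under the checks above (a contiguous target, values that are either constants or Tensors); the tensor-index assignment on the last line goes through the masked-setitem path. This shows expected behavior, not an excerpt from the test suite:

```python
from tinygrad import Tensor

t = Tensor.zeros(3, 4).contiguous().realize()
t[1] = 5.0                            # broadcast a constant into a row view
t[0, 2] = 7.0                         # single element
t[Tensor([2]), Tensor([0, 3])] = 9.0  # advanced (tensor) indexing assignment
print(t.numpy())
```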
@@ -1238,7 +1261,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
1238
1261
  assert all(s >= i for d,(s,i) in enumerate(zip(self.shape, index.shape)) if d != dim), "requires self.shape[d] >= index.shape[d] for all d != dim"
1239
1262
  index = index.to(self.device)
1240
1263
  x = self.shrink(tuple((0, i) if d != dim else None for d,i in enumerate(index.shape))).unsqueeze(-1).transpose(-1, dim)
1241
- return ((index.unsqueeze(-1) == Tensor.arange(self.shape[dim], requires_grad=False, device=self.device)) * x).sum(-1, acc_dtype=self.dtype)
1264
+ return (x * index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim])).sum(-1, acc_dtype=self.dtype)
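The return above builds a one-hot mask from `index` along `dim`, multiplies it with the gathered window, and sums it away, giving `torch.gather`-style semantics. A short usage sketch:

```python
from tinygrad import Tensor

t = Tensor([[1, 2], [3, 4]])
idx = Tensor([[0, 0], [1, 0]])
# out[i, j] = t[idx[i, j], j] when gathering along dim=0
print(t.gather(0, idx).numpy())   # [[1 2] [3 2]]
```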
1242
1265
 
1243
1266
  def cat(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
1244
1267
  """
@@ -1302,7 +1325,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
1302
1325
  ```
1303
1326
  """
1304
1327
  repeats = argfix(repeats, *args)
1305
- base_shape = _pad_left(self.shape, repeats)[0]
1328
+ base_shape = _align_left(self.shape, repeats)[0]
1306
1329
  unsqueezed_shape = flatten([[1, s] for s in base_shape])
1307
1330
  expanded_shape = flatten([[r, s] for r,s in zip(repeats, base_shape)])
1308
1331
  final_shape = [r*s for r,s in zip(repeats, base_shape)]
@@ -1313,7 +1336,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
1313
1336
  if not -max(1, total) <= dim <= max(1, total)-1: raise IndexError(f"{dim=} out of range {[-max(1, total), max(1, total)-1]}")
1314
1337
  return dim + total if dim < 0 else dim
1315
1338
 
1316
- def split(self, sizes:Union[int, List[int]], dim:int=0) -> Tuple[Tensor, ...]:
1339
+ def split(self, sizes:Union[int, list[int]], dim:int=0) -> tuple[Tensor, ...]:
1317
1340
  """
1318
1341
  Splits the tensor into chunks along the dimension specified by `dim`.
1319
1342
  If `sizes` is an integer, it splits into equally sized chunks if possible, otherwise the last chunk will be smaller.
@@ -1338,7 +1361,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
1338
1361
  assert sum(sizes) == self.shape[dim], f"expect sizes to sum exactly to {self.shape[dim]}, but got {sum(sizes)}"
1339
1362
  return tuple(self[sl] for sl in [tuple([slice(None)]*dim + [slice(sum(sizes[:i]), sum(sizes[:i + 1]))]) for i in range(len(sizes))])
1340
1363
 
1341
- def chunk(self, chunks:int, dim:int=0) -> List[Tensor]:
1364
+ def chunk(self, chunks:int, dim:int=0) -> list[Tensor]:
1342
1365
  """
1343
1366
  Splits the tensor into `chunks` number of chunks along the dimension `dim`.
1344
1367
  If the tensor size along `dim` is not divisible by `chunks`, all returned chunks will be the same size except the last one.
@@ -1362,7 +1385,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
1362
1385
  dim = self._resolve_dim(dim)
1363
1386
  return list(self.split(ceildiv(self.shape[dim], chunks) if self.shape[dim] else [0]*chunks, dim=dim))
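A quick sketch of the difference between `split` (chunk sizes, or one size repeated) and `chunk` (a target number of chunks), using the signatures shown above:

```python
from tinygrad import Tensor

t = Tensor.arange(10)
print([p.shape for p in t.split(4)])       # pieces of 4 elements: [(4,), (4,), (2,)]
print([p.shape for p in t.split([3, 7])])  # explicit sizes must sum to the dim size
print([p.shape for p in t.chunk(3)])       # 3 chunks: [(4,), (4,), (2,)]
```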
1364
1387
 
1365
- def meshgrid(self:Tensor, *args:Tensor, indexing:Union[Literal["ij"], Literal["xy"]]="ij") -> Tuple[Tensor, ...]:
1388
+ def meshgrid(self:Tensor, *args:Tensor, indexing:Union[Literal["ij"], Literal["xy"]]="ij") -> tuple[Tensor, ...]:
1366
1389
  """
1367
1390
  Generates coordinate matrices from coordinate vectors.
1368
1391
  Input tensors can be scalars or 1D tensors.
@@ -1462,7 +1485,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
1462
1485
  start_dim, end_dim = self._resolve_dim(start_dim), self._resolve_dim(end_dim)
1463
1486
  return self.reshape(self.shape[:start_dim] + (prod(self.shape[start_dim:end_dim+1]), ) + self.shape[end_dim+1:])
1464
1487
 
1465
- def unflatten(self, dim:int, sizes:Tuple[int,...]):
1488
+ def unflatten(self, dim:int, sizes:tuple[int,...]):
1466
1489
  """
1467
1490
  Unflattens dimension `dim` of the tensor into multiple dimensions specified by `sizes`. `Tensor.flatten()` is the inverse of this function.
1468
1491
 
@@ -1479,7 +1502,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
1479
1502
  dim = self._resolve_dim(dim)
1480
1503
  return self.reshape(self.shape[:dim] + sizes + self.shape[dim+1:])
1481
1504
 
1482
- def roll(self, shifts:Union[int, Tuple[int, ...]], dims:Union[int, Tuple[int, ...]]) -> Tensor:
1505
+ def roll(self, shifts:Union[int, tuple[int, ...]], dims:Union[int, tuple[int, ...]]) -> Tensor:
1483
1506
  """
1484
1507
  Rolls the tensor along specified dimension(s).
1485
1508
  The rolling operation is circular, meaning that elements that go beyond the edge are wrapped around to the beginning of the dimension.
@@ -1499,12 +1522,52 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
1499
1522
  rolled[tuple(slice(None) if i != dim else slice(None, -shift) for i in range(rolled.ndim))], dim=dim)
1500
1523
  return rolled
1501
1524
 
1525
+ def rearrange(self, formula:str, **sizes) -> Tensor:
1526
+ """
1527
+ Rearranges the input tensor according to the einops-style `formula`.
1528
+
1529
+ See: https://einops.rocks/api/rearrange/
1530
+
1531
+ ```python exec="true" source="above" session="tensor" result="python"
1532
+ x = Tensor([[1, 2], [3, 4]])
1533
+ print(Tensor.rearrange(x, "batch channel -> (batch channel)").numpy())
1534
+ ```
1535
+ """
1536
+ def parse_formula(formula: str):
1537
+ tokens = f" {formula} ".replace("…", "...").replace("(", " ( ").replace(")", " ) ").replace(" ", " ").replace(" 1 ", " ( ) ").split()
1538
+ lparens, rparens = map(lambda x: [i for i, ch in enumerate(tokens) if ch == x], ("(", ")"))
1539
+ pairs = list(zip(lparens, rparens))
1540
+ assert len(lparens) == len(rparens) and sorted(flatten(pairs)) == flatten(pairs), "bracket mismatch"
1541
+ return [name for name in tokens if name not in ("(", ")")], [(s - 2*i, e - 1 - 2*i) for i, (s, e) in enumerate(pairs)]
1542
+
1543
+ assert formula.count("->") == 1, 'need exactly one "->" in formula'
1544
+
1545
+ (lhs, unflatten_dims), (rhs, flatten_dims) = map(parse_formula, formula.split("->"))
1546
+
1547
+ for name in sizes: assert name in lhs, f"axis {name} is not used in transform"
1548
+ assert sorted(lhs) == sorted(rhs) and len(lhs) == len(set(lhs)), f"name mismatch in {formula}"
1549
+ for name in flatten((lhs, rhs)): assert name == "..." or (name.isidentifier() and "_" not in (name[0], name[-1])), f"invalid axis name {name}"
1550
+ assert "..." not in flatten([lhs[s:e] for s, e in unflatten_dims]), f"cannot have collapsed ellipsis (...) in lhs of {formula}"
1551
+ assert lhs.count("...") <= 1, f"too many ellipses in {formula}"
1552
+
1553
+ # resolve ellipsis
1554
+ if "..." in lhs: ell_len = len(self.shape) - len(lhs) + 1 + sum(e - s - 1 for s, e in unflatten_dims)
1555
+ lhs, rhs = map(lambda l: l[:(i:=l.index("..."))] + [f"...{j}" for j in range(ell_len)] + l[i + 1:] if "..." in l else l, (lhs, rhs))
1556
+ unflatten_dims = [(s + (ell_len - 1 if "...0" in lhs[:s] else 0), e + (ell_len - 1 if "...0" in lhs[:e] else 0)) for s, e in unflatten_dims]
1557
+ flatten_dims = [(s + (ell_len - 1 if "...0" in rhs[:s] else 0), e + (ell_len - 1 if "...0" in rhs[:e] else 0)) for s, e in flatten_dims]
1558
+
1559
+ # apply movement ops in order unflatten -> permute -> flatten/unsqueeze
1560
+ t = functools.reduce(lambda x, dims: x.unflatten(dims[0], tuple(sizes.get(lhs[d], -1) for d in range(*dims))), unflatten_dims, self)
1561
+ for i, name in enumerate(lhs): assert (name not in sizes) or sizes[name] == t.shape[i], f"size provided for dimension {name} incorrect"
1562
+ t = t.permute([lhs.index(name) for name in rhs])
1563
+ return functools.reduce(lambda x, dims: x.flatten(dims[0], dims[1] - 1) if dims[0]<dims[1] else x.unsqueeze(dims[0]), reversed(flatten_dims), t)
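The relocated `rearrange` resolves the einops-style formula into unflatten, permute, and flatten movement ops; a brief usage sketch with a named axis size:

```python
from tinygrad import Tensor

x = Tensor.arange(24).reshape(6, 4)
# split the leading axis into (b, c) with b=2; no reordering needed afterwards
print(x.rearrange("(b c) w -> b c w", b=2).shape)  # (2, 3, 4)
# a pure permutation
print(x.rearrange("b w -> w b").shape)             # (4, 6)
```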
1564
+
1502
1565
  # ***** reduce ops *****
1503
1566
 
1504
- def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False) -> Tensor:
1567
+ def _reduce(self, op:Ops, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False) -> Tensor:
1505
1568
  axis = tuple(self._resolve_dim(x) for x in (range(self.ndim) if axis is None else make_tuple(axis, 1)))
1506
1569
  if self.ndim == 0: axis = ()
1507
- ret = fxn.apply(self, axis=axis)
1570
+ ret = self._apply_uop(UOp.r, op=op, axis=axis)
1508
1571
  return ret if keepdim else ret.reshape(tuple(s for i,s in enumerate(self.shape) if i not in axis))
1509
1572
 
1510
1573
  def sum(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, acc_dtype:Optional[DTypeLike]=None):
@@ -1531,7 +1594,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
1531
1594
  print(t.sum(axis=1).numpy())
1532
1595
  ```
1533
1596
  """
1534
- ret = self.cast(sum_acc_dtype(self.dtype) if acc_dtype is None else acc_dtype)._reduce(F.Sum, axis, keepdim)
1597
+ ret = self.cast(sum_acc_dtype(self.dtype) if acc_dtype is None else acc_dtype)._reduce(Ops.ADD, axis, keepdim)
1535
1598
  return ret.cast(self.dtype) if acc_dtype is None and self.dtype in (dtypes.float16, dtypes.bfloat16) else ret
1536
1599
 
1537
1600
  def prod(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, acc_dtype:Optional[DTypeLike]=None):
@@ -1558,7 +1621,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
1558
1621
  print(t.prod(axis=1).numpy())
1559
1622
  ```
1560
1623
  """
1561
- return self.cast(acc_dtype if acc_dtype is not None else self.dtype)._reduce(F.Prod, axis, keepdim)
1624
+ return self.cast(acc_dtype if acc_dtype is not None else self.dtype)._reduce(Ops.MUL, axis, keepdim)
1562
1625
 
1563
1626
  def max(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
1564
1627
  """
@@ -1581,7 +1644,9 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
1581
1644
  print(t.max(axis=1, keepdim=True).numpy())
1582
1645
  ```
1583
1646
  """
1584
- return self._reduce(F.Max, axis, keepdim)
1647
+ return self._reduce(Ops.MAX, axis, keepdim)
1648
+
1649
+ def _inverse(self): return -self if self.is_floating_point() else ~self if dtypes.is_int(self.dtype) else self.logical_not()
1585
1650
 
1586
1651
  def min(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
1587
1652
  """
@@ -1604,8 +1669,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
1604
1669
  print(t.min(axis=1, keepdim=True).numpy())
1605
1670
  ```
1606
1671
  """
1607
- if dtypes.is_int(self.dtype) or self.dtype == dtypes.bool: return ~((~self).max(axis=axis, keepdim=keepdim))
1608
- return -((-self).max(axis=axis, keepdim=keepdim))
1672
+ return self._inverse().max(axis=axis, keepdim=keepdim)._inverse()
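`min` is now expressed through `_inverse`: negate floats (bitwise-invert ints, logical-not bools), take `max`, and invert again. For floats this is the identity `min(x) == -max(-x)`:

```python
from tinygrad import Tensor

t = Tensor([[3.0, -1.0, 2.0], [0.5, 7.0, -4.0]])
assert (t.min(axis=1).numpy() == (-((-t).max(axis=1))).numpy()).all()
print(t.min(axis=1).numpy())   # [-1. -4.]
```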
1609
1673
 
1610
1674
  def any(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
1611
1675
  """
@@ -1651,6 +1715,28 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
1651
1715
  """
1652
1716
  return self.logical_not().any(axis, keepdim).logical_not()
1653
1717
 
1718
+ def isclose(self, other:Tensor, rtol:float=1e-05, atol:float=1e-08, equal_nan=False) -> Tensor:
1719
+ """
1720
+ Returns a new tensor with element-wise comparison of closeness to `other` within a tolerance.
1721
+
1722
+ The `rtol` and `atol` keyword arguments control the relative and absolute tolerance of the comparison.
1723
+
1724
+ By default, two `NaN` values are not close to each other. If `equal_nan` is `True`, two `NaN` values are considered close.
1725
+
1726
+ ```python exec="true" source="above" session="tensor" result="python"
1727
+ print(Tensor([1e-7, 1e-8, 1e-9, float('nan')]).isclose(Tensor([0.0, 0.0, 0.0, float('nan')])).numpy())
1728
+ ```
1729
+ ```python exec="true" source="above" session="tensor" result="python"
1730
+ print(Tensor([float('nan')]).isclose(Tensor([float('nan')]), equal_nan=True).numpy())
1731
+ ```
1732
+ """
1733
+ # TODO: Tensor.isfinite
1734
+ def isfinite(t): return (t.isinf()|t.isnan()).logical_not()
1735
+ is_finite_close = isfinite(self) & isfinite(other) & ((self - other).abs() <= atol + rtol * other.abs())
1736
+ is_infinite_close = (self.isinf() | other.isinf()) & (self == other)
1737
+ is_nan_close = (self.isnan() & other.isnan()) & equal_nan
1738
+ return is_finite_close | is_infinite_close | is_nan_close
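The result combines a finite check against `atol + rtol * |other|`, an equal-infinities check, and an optional NaN check; a small sketch of how the tolerances interact (expected outputs in comments):

```python
from tinygrad import Tensor

a = Tensor([1.0, 1e-8, float('inf'), float('nan')])
b = Tensor([1.0001, 0.0, float('inf'), float('nan')])
print(a.isclose(b).numpy())                  # [False  True  True False]
print(a.isclose(b, rtol=1e-3).numpy())       # first entry becomes True
print(a.isclose(b, equal_nan=True).numpy())  # the NaN pair becomes True
```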
1739
+
1654
1740
  def mean(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
1655
1741
  """
1656
1742
  Returns the mean value of the tensor along the specified axis or axes.
@@ -1745,8 +1831,8 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
1745
1831
  return self.std(axis, keepdim, correction), self.mean(axis, keepdim)
1746
1832
 
1747
1833
  def _softmax(self, axis, dtype:Optional[DTypeLike]=None):
1748
- x = self.cast(dtype) if dtype is not None else self
1749
- m = x - x.max(axis=axis, keepdim=True).detach()
1834
+ m = self - self.max(axis=axis, keepdim=True).detach()
1835
+ if dtype is not None: m = m.cast(dtype)
1750
1836
  e = m.exp()
1751
1837
  return m, e, e.sum(axis=axis, keepdim=True)
1752
1838
 
@@ -1847,8 +1933,16 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
1847
1933
  print(t.logcumsumexp(axis=1).numpy())
1848
1934
  ```
1849
1935
  """
1850
- m = self.max(axis=axis, keepdim=True)
1851
- return (self - m).exp().cumsum(axis=axis).log() + m
1936
+ if self.ndim == 0: return self
1937
+ axis = self._resolve_dim(axis)
1938
+ x = self.transpose(axis, -1)
1939
+ last_dim_size = x.shape[-1]
1940
+ x_reshaped = x.reshape(-1, last_dim_size)
1941
+ x_cummax = x_reshaped.cummax(-1).unsqueeze(-1)
1942
+ x_expand = x_reshaped.unsqueeze(1).expand(*x_reshaped.shape, last_dim_size)
1943
+ mask = Tensor.ones(last_dim_size, last_dim_size, requires_grad=False, device=self.device).tril().unsqueeze(0)
1944
+ ret = ((x_expand - x_cummax).exp() * mask).sum(-1).log() + x_cummax.squeeze(-1)
1945
+ return ret.reshape(*x.shape).transpose(-1, axis)
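The rewrite subtracts a running `cummax` before exponentiating, so `logcumsumexp` stays finite where the naive `log(cumsum(exp(x)))` overflows; a sketch of the difference (the finite values in the comment are approximate):

```python
from tinygrad import Tensor

t = Tensor([1000.0, 1000.5, 1001.0])
print(t.exp().cumsum(0).log().numpy())  # naive version overflows: [inf inf inf]
print(t.logcumsumexp(0).numpy())        # finite: roughly [1000.0, 1000.97, 1001.68]
```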
1852
1946
 
1853
1947
  def argmax(self, axis=None, keepdim=False):
1854
1948
  """
@@ -1898,47 +1992,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
1898
1992
  print(t.argmin(axis=1).numpy()) # Returns the indices of the minimum values along axis 1.
1899
1993
  ```
1900
1994
  """
1901
- return (-self).argmax(axis=axis, keepdim=keepdim)
1902
-
1903
- def rearrange(self, formula: str, **sizes) -> Tensor:
1904
- """
1905
- Rearranges input according to formula
1906
-
1907
- See: https://einops.rocks/api/rearrange/
1908
-
1909
- ```python exec="true" source="above" session="tensor" result="python"
1910
- x = Tensor([[1, 2], [3, 4]])
1911
- print(Tensor.rearrange(x, "batch channel -> (batch channel)).numpy())
1912
- ```
1913
- """
1914
- def parse_formula(formula: str):
1915
- tokens = f" {formula} ".replace("…", "...").replace("(", " ( ").replace(")", " ) ").replace(" ", " ").replace(" 1 ", " ( ) ").split()
1916
- lparens, rparens = map(lambda x: [i for i, ch in enumerate(tokens) if ch == x], ("(", ")"))
1917
- pairs = list(zip(lparens, rparens))
1918
- assert len(lparens) == len(rparens) and sorted(flatten(pairs)) == flatten(pairs), "bracket mismatch"
1919
- return [name for name in tokens if name not in ("(", ")")], [(s - 2*i, e - 1 - 2*i) for i, (s, e) in enumerate(pairs)]
1920
-
1921
- assert formula.count("->") == 1, 'need exactly one "->" in formula'
1922
-
1923
- (lhs, unflatten_dims), (rhs, flatten_dims) = map(parse_formula, formula.split("->"))
1924
-
1925
- for name in sizes: assert name in lhs, f"axis {name} is not used in transform"
1926
- assert sorted(lhs) == sorted(rhs) and len(lhs) == len(set(lhs)), f"name mismatch in {formula}"
1927
- for name in flatten((lhs, rhs)): assert name == "..." or (name.isidentifier() and "_" not in (name[0], name[-1])), f"invalid axis name {name}"
1928
- assert "..." not in flatten([lhs[s:e] for s, e in unflatten_dims]), f"cannot have collapsed ellipsis (...) in lhs of {formula}"
1929
- assert lhs.count("...") <= 1, f"too many ellipses in {formula}"
1930
-
1931
- # resolve ellipsis
1932
- if "..." in lhs: ell_len = len(self.shape) - len(lhs) + 1 + sum(e - s - 1 for s, e in unflatten_dims)
1933
- lhs, rhs = map(lambda l: l[:(i:=l.index("..."))] + [f"...{j}" for j in range(ell_len)] + l[i + 1:] if "..." in l else l, (lhs, rhs))
1934
- unflatten_dims = [(s + (ell_len - 1 if "...0" in lhs[:s] else 0), e + (ell_len - 1 if "...0" in lhs[:e] else 0)) for s, e in unflatten_dims]
1935
- flatten_dims = [(s + (ell_len - 1 if "...0" in rhs[:s] else 0), e + (ell_len - 1 if "...0" in rhs[:e] else 0)) for s, e in flatten_dims]
1936
-
1937
- # apply movement ops in order unflatten -> permute -> flatten/unsqueeze
1938
- t = functools.reduce(lambda x, dims: x.unflatten(dims[0], tuple(sizes.get(lhs[d], -1) for d in range(*dims))), unflatten_dims, self)
1939
- for i, name in enumerate(lhs): assert (name not in sizes) or sizes[name] == t.shape[i], f"size provided for dimension {name} incorrect"
1940
- t = t.permute([lhs.index(name) for name in rhs])
1941
- return functools.reduce(lambda x, dims: x.flatten(dims[0], dims[1] - 1) if dims[0]<dims[1] else x.unsqueeze(dims[0]), reversed(flatten_dims), t)
1995
+ return self._inverse().argmax(axis=axis, keepdim=keepdim)
1942
1996
 
1943
1997
  @staticmethod
1944
1998
  def einsum(formula:str, *operands:Tensor|Sequence[Tensor], acc_dtype:Optional[DTypeLike]=None) -> Tensor:
@@ -1964,7 +2018,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
1964
2018
  (inputs_str, out_ellipse + ''.join(sorted(c for c in inputs_str if inputs_str.count(c) == 1 and c.isalpha() and c not in out_ellipse)))
1965
2019
  return formula.split("->") if "->" in formula else (formula, ''.join(c for c in sorted(formula) if formula.count(c) == 1 and c.isalpha()))
1966
2020
 
1967
- xs:Tuple[Tensor, ...] = argfix(*operands)
2021
+ xs:tuple[Tensor, ...] = argfix(*operands)
1968
2022
  inputs_str, output = parse_formula(formula, *xs)
1969
2023
  inputs = inputs_str.split(",")
1970
2024
  assert len(xs) == len(inputs), f"number of inputs doesn't match number of operands in formula, expected {len(inputs)}, got {len(xs)}"
@@ -1972,7 +2026,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
1972
2026
  # map the value of each letter in the formula
1973
2027
  letter_val = sorted(merge_dicts([dict(zip(letters, tensor.shape)) for letters, tensor in zip(inputs, xs)]).items())
1974
2028
 
1975
- xs_:List[Tensor] = []
2029
+ xs_:list[Tensor] = []
1976
2030
  lhs = [sorted(enumerate(s), key=lambda e:e[1]) for s in inputs]
1977
2031
  for x,(order,letters) in zip(xs, [list(zip(*l)) for l in lhs]):
1978
2032
  # permute to the sorted letter order, then reshape/expand to create dimensions for the missing letters
@@ -1987,7 +2041,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
1987
2041
 
1988
2042
  # ***** processing ops *****
1989
2043
 
1990
- def _pool(self, k_:Tuple[sint, ...], stride:Union[Tuple[int, ...], int]=1, dilation:Union[Tuple[int, ...], int]=1) -> Tensor:
2044
+ def _pool(self, k_:tuple[sint, ...], stride:Union[tuple[int, ...], int]=1, dilation:Union[tuple[int, ...], int]=1) -> Tensor:
1991
2045
  assert len(self.shape) >= len(k_), f"can't pool {self.shape} with {k_}"
1992
2046
  s_, d_ = make_tuple(stride, len(k_)), make_tuple(dilation, len(k_))
1993
2047
  assert len(k_) == len(s_) == len(d_), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}"
@@ -1995,10 +2049,12 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
1995
2049
  assert all(resolve(d*(k-1)+1 <= i) for k,d,i in zip(k_,d_,i_)), "kernel size cannot be greater than actual input size"
1996
2050
  o_ = [ceildiv(i-d*(k-1), s) for i,d,k,s in zip(i_,d_,k_,s_)]
1997
2051
  if any(resolve(k > s) for k,s in zip(k_,s_)) or any(d != 1 for d in d_):
1998
- # repeats such that we don't need padding
1999
- x = self.repeat([1]*len(noop) + [ceildiv(k*(i+d), i) for k,i,d in zip(k_,i_,d_)])
2052
+ # input size scaling factor to make sure shrink for stride is possible
2053
+ f_ = [1 + int(resolve(o*s > (i - d*(k-1)))) for o,s,i,d,k in zip(o_,s_,i_,d_,k_)]
2054
+ # repeats such that we don't need padding
2055
+ x = self.repeat([1]*len(noop) + [ceildiv(k*(i*f+d),i) for k,i,d,f in zip(k_,i_,d_,f_)])
2000
2056
  # handle dilation
2001
- x = x.shrink(tuple(noop + [(0,k*(i+d)) for k,i,d in zip(k_,i_,d_)])).reshape(noop + flatten((k,i+d) for k,i,d in zip(k_,i_,d_)))
2057
+ x = x.shrink(tuple(noop + [(0,k*(i*f+d)) for k,i,d,f in zip(k_,i_,d_,f_)])).reshape(noop + flatten((k,(i*f+d)) for k,i,d,f in zip(k_,i_,d_,f_)))
2002
2058
  # handle stride
2003
2059
  x = x.shrink(tuple(noop + flatten(((0,k), (0,o*s)) for k,o,s in zip(k_,o_,s_)))).reshape(noop + flatten((k,o,s) for k,o,s in zip(k_,o_,s_)))
2004
2060
  x = x.shrink(tuple(noop + flatten(((0,k), (0,o), (0,1)) for k,o in zip(k_,o_)))).reshape(noop + flatten((k,o) for k,o in zip(k_,o_)))
@@ -2010,14 +2066,44 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2010
2066
  x = x.shrink(tuple(noop + flatten(((0,o), (0,k)) for o,k in zip(o_,k_))))
2011
2067
  return x.permute(*range(len(noop)), *[len(noop)+i*2 for i in range(len(i_))], *[len(noop)+i*2+1 for i in range(len(i_))])
2012
2068
 
2013
- def _padding2d(self, padding:Union[int, Sequence[int]], dims:int) -> Sequence[int]:
2069
+ def _resolve_pool_pads(self, padding:Union[int, Sequence[int]], dims:int) -> Sequence[int]:
2070
+ if not isinstance(padding, int) and not (len(padding) == 2*dims or len(padding) == dims):
2071
+ raise ValueError(f"Padding must be an int or a sequence of length {dims} or {2*dims}, but got {padding=} for {self.shape=} with {dims=}.")
2014
2072
  return [padding]*2*dims if isinstance(padding, int) else (padding if len(padding) == 2*dims else [p for p in padding for _ in range(2)][::-1])
2015
2073
 
2074
+ def _apply_ceil_mode(self, pads:Sequence[int], k_:tuple[sint, ...], s_:Union[tuple[int, ...], int], d_:Union[tuple[int, ...], int]) -> list[int]:
2075
+ (d_,s_), i_ = (make_tuple(x, len(k_)) for x in (d_,s_)), self.shape[-len(k_):]
2076
+ pads, grouped_pads = list(pads), _flat_to_grouped(pads)
2077
+ # https://arxiv.org/pdf/1603.07285 section 5.1, relationship 15.
2078
+ o_ = [ceildiv(i+pB+pA - (d*(k-1)+1), s) + 1 for i,d,k,s,(pB,pA) in zip(i_,d_,k_,s_,grouped_pads)]
2079
+ for dim,(o,i,s,k,d,(pB,pA)) in enumerate(zip(o_,i_,s_,k_,d_,grouped_pads)):
2080
+ # we have to do additional padding before `_pool` so that `o_` in `_pool` is calculated correctly
2081
+ # `s*(o-1) + (d*(k-1)+1) - (i+pB+pA)` -> last_sliding_window_start + full_kernel_size - padded_input_shape
2082
+ # we decrease padding in the case that a sliding window starts in the end padded region, thereby decreasing `o_` in `_pool`
2083
+ # `smax(s*(o-1) - (pB+i-1), 0)` -> last_sliding_window_start - (pad_before + input_size - zero_offset)
2084
+ pads[-1-dim*2] += s*(o-1) + (d*(k-1)+1) - (i+pB+pA) - smax(s*(o-1) - (pB+i-1), 0)
2085
+ return pads
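The padding adjustment above targets the output size `ceil((i + pB + pA - (d*(k-1)+1)) / s) + 1` from relationship 15 of arXiv 1603.07285, while dropping windows that would start entirely inside the end padding. The floor/ceil arithmetic for one axis, worked out in plain Python:

```python
from math import ceil

i, k, s, d, pB, pA = 5, 2, 2, 1, 0, 0  # input 5, kernel 2, stride 2, no dilation or padding
floor_out = (i + pB + pA - (d * (k - 1) + 1)) // s + 1      # regular output size -> 2
ceil_out = ceil((i + pB + pA - (d * (k - 1) + 1)) / s) + 1  # ceil_mode output size -> 3
print(floor_out, ceil_out)
```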
2086
+
2016
2087
  # NOTE: these work for more than 2D
2017
- def avg_pool2d(self, kernel_size=(2,2), stride=None, dilation=1, padding=0, count_include_pad=True):
2088
+ def avg_pool2d(self, kernel_size=(2,2), stride=None, dilation=1, padding=0, ceil_mode=False, count_include_pad=True):
2018
2089
  """
2019
2090
  Applies average pooling over a tensor.
2020
2091
 
2092
+ This function supports three different types of `padding`:
2093
+
2094
+ 1. `int` (single value):
2095
+ Applies the same padding value uniformly to all spatial dimensions.
2096
+
2097
+ 2. `tuple[int, ...]` (length = number of spatial dimensions):
2098
+ Specifies a distinct padding value for each spatial dimension in the form `(padding_height, padding_width, ...)`.
2099
+
2100
+ 3. `tuple[int, ...]` (length = 2 * number of spatial dimensions):
2101
+ Specifies explicit padding for each side of each spatial dimension in the form
2102
+ `(padding_left, padding_right, padding_top, padding_bottom, ...)`.
2103
+
2104
+ When `ceil_mode` is set to `True`, output shape will be determined using ceil division.
2105
+ When `count_include_pad` is set to `False`, zero padding will not be included in the averaging calculation.
2106
+
2021
2107
  NOTE: unlike PyTorch, this implementation is not limited to only 2d pooling and instead works for any number of dimensions.
2022
2108
 
2023
2109
  See: https://paperswithcode.com/method/average-pooling
@@ -2027,17 +2113,43 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2027
2113
  print(t.avg_pool2d().numpy())
2028
2114
  ```
2029
2115
  ```python exec="true" source="above" session="tensor" result="python"
2116
+ print(t.avg_pool2d(ceil_mode=True).numpy())
2117
+ ```
2118
+ ```python exec="true" source="above" session="tensor" result="python"
2030
2119
  print(t.avg_pool2d(padding=1).numpy())
2031
2120
  ```
2121
+ ```python exec="true" source="above" session="tensor" result="python"
2122
+ print(t.avg_pool2d(padding=1, count_include_pad=False).numpy())
2123
+ ```
2032
2124
  """
2033
- padding_, axis = self._padding2d(padding, len(k_ := make_tuple(kernel_size, 2))), tuple(range(-len(k_), 0))
2034
- def pool(x:Tensor) -> Tensor: return x.pad(padding_)._pool(k_, stride if stride is not None else k_, dilation)
2035
- return pool(self).mean(axis=axis) if count_include_pad else pool(self).sum(axis=axis) / pool(self.ones_like()).sum(axis=axis)
2125
+ axis = tuple(range(-len(k_ := make_tuple(kernel_size, 2)), 0))
2126
+ def pool(x:Tensor, padding_:Sequence[int]) -> Tensor: return x.pad(padding_)._pool(k_, stride if stride is not None else k_, dilation)
2127
+ reg_pads = self._resolve_pool_pads(padding, len(k_))
2128
+ ceil_pads = self._apply_ceil_mode(reg_pads, k_, stride if stride is not None else k_, dilation)
2129
+ if not count_include_pad:
2130
+ pads = ceil_pads if ceil_mode else reg_pads
2131
+ return pool(self, pads).sum(axis) / pool(self.ones_like(), pads).sum(axis)
2132
+ if not ceil_mode: return pool(self, reg_pads).mean(axis)
2133
+ return pool(self, ceil_pads).sum(axis) / pool(self.pad(reg_pads).ones_like(), tuple(cp-rp for cp,rp in zip(ceil_pads, reg_pads))).sum(axis)
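A usage sketch of the new keyword arguments, assuming the `avg_pool2d` signature above (expected output shapes in comments):

```python
from tinygrad import Tensor

t = Tensor.arange(25).reshape(1, 1, 5, 5).float()
print(t.avg_pool2d().shape)                                    # (1, 1, 2, 2): floor division of the 5x5 input
print(t.avg_pool2d(ceil_mode=True).shape)                      # (1, 1, 3, 3): partial border windows are kept
print(t.avg_pool2d(padding=1, count_include_pad=False).shape)  # (1, 1, 3, 3): zero padding excluded from the divisor
```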
2036
2134
 
2037
- def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1, padding=0):
2135
+ def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1, padding=0, ceil_mode=False):
2038
2136
  """
2039
2137
  Applies max pooling over a tensor.
2040
2138
 
2139
+ This function supports three different types of `padding`:
2140
+
2141
+ 1. `int` (single value):
2142
+ Applies the same padding value uniformly to all spatial dimensions.
2143
+
2144
+ 2. `tuple[int, ...]` (length = number of spatial dimensions):
2145
+ Specifies a distinct padding value for each spatial dimension in the form `(padding_height, padding_width, ...)`.
2146
+
2147
+ 3. `tuple[int, ...]` (length = 2 * number of spatial dimensions):
2148
+ Specifies explicit padding for each side of each spatial dimension in the form
2149
+ `(padding_left, padding_right, padding_top, padding_bottom, ...)`.
2150
+
2151
+ When `ceil_mode` is set to `True`, output shape will be determined using ceil division.
2152
+
2041
2153
  NOTE: unlike PyTorch, this implementation is not limited to only 2d pooling and instead works for any number of dimensions.
2042
2154
 
2043
2155
  See: https://paperswithcode.com/method/max-pooling
@@ -2047,17 +2159,33 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2047
2159
  print(t.max_pool2d().numpy())
2048
2160
  ```
2049
2161
  ```python exec="true" source="above" session="tensor" result="python"
2162
+ print(t.max_pool2d(ceil_mode=True).numpy())
2163
+ ```
2164
+ ```python exec="true" source="above" session="tensor" result="python"
2050
2165
  print(t.max_pool2d(padding=1).numpy())
2051
2166
  ```
2052
2167
  """
2053
- padding_ = self._padding2d(padding, len(k_ := make_tuple(kernel_size, 2)))
2054
- return self.pad(padding_, value=dtypes.min(self.dtype))._pool(k_, stride if stride is not None else k_, dilation).max(tuple(range(-len(k_), 0)))
2168
+ pads = self._resolve_pool_pads(padding, len(k_ := make_tuple(kernel_size, 2)))
2169
+ if ceil_mode: pads = self._apply_ceil_mode(pads, k_, stride if stride is not None else k_, dilation)
2170
+ return self.pad(pads, value=dtypes.min(self.dtype))._pool(k_, stride if stride is not None else k_, dilation).max(tuple(range(-len(k_), 0)))
2055
2171
 
2056
- def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding:int|Tuple[int, ...]=0,
2172
+ def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding:int|tuple[int, ...]=0,
2057
2173
  acc_dtype:Optional[DTypeLike]=None) -> Tensor:
2058
2174
  """
2059
2175
  Applies a convolution over a tensor with a given `weight` and optional `bias`.
2060
2176
 
2177
+ This function supports three different types of `padding`:
2178
+
2179
+ 1. `int` (single value):
2180
+ Applies the same padding value uniformly to all spatial dimensions.
2181
+
2182
+ 2. `tuple[int, ...]` (length = number of spatial dimensions):
2183
+ Specifies a distinct padding value for each spatial dimension in the form `(padding_height, padding_width, ...)`.
2184
+
2185
+ 3. `tuple[int, ...]` (length = 2 * number of spatial dimensions):
2186
+ Specifies explicit padding for each side of each spatial dimension in the form
2187
+ `(padding_left, padding_right, padding_top, padding_bottom, ...)`.
2188
+
2061
2189
  NOTE: unlike PyTorch, this implementation is not limited to only 2d convolutions and instead works for any number of dimensions.
2062
2190
 
2063
2191
  See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
@@ -2070,9 +2198,8 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2070
2198
  """
2071
2199
  if IMAGE: return self.image_conv2d(weight, bias, groups, stride, dilation, padding, acc_dtype)
2072
2200
  (bs,cin_), (cout,cin), HW = self.shape[:2], weight.shape[:2], weight.shape[2:]
2201
+ padding_ = self._resolve_pool_pads(padding, len(HW))
2073
2202
  assert groups*cin == cin_ and len(self.shape) == len(weight.shape), f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})" # noqa: E501
2074
- if isinstance(padding, (tuple,list)): assert len(padding) == 2*len(HW) or len(padding) == len(HW), f"Expected padding of length {2*len(HW)} or {len(HW)}, but got {len(padding)} for tensor of shape {self.shape}" # noqa: E501
2075
- padding_ = self._padding2d(padding, len(HW))
2076
2203
 
2077
2204
  # conv2d is a pooling op (with padding)
2078
2205
  x = self.pad(padding_)._pool(HW, stride, dilation) # (bs, groups*cin, oy, ox, H, W)
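With padding routed through `_resolve_pool_pads`, `conv2d` now accepts the same three padding forms as the pooling ops; a brief sketch (output shapes assume a 3x3 kernel, stride 1, dilation 1):

```python
from tinygrad import Tensor

x = Tensor.randn(1, 3, 8, 8)
w = Tensor.randn(4, 3, 3, 3)
print(x.conv2d(w, padding=1).shape)             # same padding on every side: (1, 4, 8, 8)
print(x.conv2d(w, padding=(1, 2)).shape)        # per-dimension (height, width): (1, 4, 8, 10)
print(x.conv2d(w, padding=(0, 1, 2, 0)).shape)  # per-side (left, right, top, bottom): (1, 4, 8, 7)
```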
@@ -2120,6 +2247,18 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2120
2247
  """
2121
2248
  Applies a transposed convolution over a tensor with a given `weight` and optional `bias`.
2122
2249
 
2250
+ This function supports three different types of `padding`:
2251
+
2252
+ 1. `int` (single value):
2253
+ Applies the same padding value uniformly to all spatial dimensions.
2254
+
2255
+ 2. `tuple[int, ...]` (length = number of spatial dimensions):
2256
+ Specifies a distinct padding value for each spatial dimension in the form `(padding_height, padding_width, ...)`.
2257
+
2258
+ 3. `tuple[int, ...]` (length = 2 * number of spatial dimensions):
2259
+ Specifies explicit padding for each side of each spatial dimension in the form
2260
+ `(padding_left, padding_right, padding_top, padding_bottom, ...)`.
2261
+
2123
2262
  NOTE: unlike PyTorch, this implementation is not limited to only 2d transposed convolutions and instead works for any number of dimensions.
2124
2263
 
2125
2264
  See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
@@ -2132,14 +2271,15 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2132
2271
  """
2133
2272
  x, w = self, weight.unflatten(0, (groups, -1)).transpose(1, 2).flip(*range(3, len(weight.shape)+1))
2134
2273
  HW = weight.shape[2:]
2135
- stride, dilation, padding, output_padding = [make_tuple(x, len(HW)) for x in (stride, dilation, padding, output_padding)]
2274
+ padding = _flat_to_grouped(self._resolve_pool_pads(padding, len(HW)))
2275
+ stride, dilation, output_padding = [make_tuple(x, len(HW)) for x in (stride, dilation, output_padding)]
2136
2276
  if any(s>1 for s in stride):
2137
2277
  # handle strides: (k) -> reshape -> (k,1) -> pad -> (k,s) -> reshape -> (k*s) -> shrink (k-(s-1))
2138
2278
  x = x.reshape(None, None, *flatten((k,1) for k in x.shape[2:]))
2139
2279
  x = x.pad((None, None, *flatten((None,(0,s-1)) for s in stride)))
2140
2280
  x = x.reshape(None, None, *[k*s for k,s in zip(x.shape[2::2], stride)])
2141
2281
  x = x.shrink((None, None, *[(0,k-(s-1)) for k,s in zip(x.shape[2:], stride)]))
2142
- padding = flatten((((k-1)*d-p,(k-1)*d-p+op) for k,d,p,op in reversed(list(zip(HW, dilation, padding, output_padding)))))
2282
+ padding = flatten((((k-1)*d-pB,(k-1)*d-pA+op) for k,d,(pB,pA),op in reversed(list(zip(HW, dilation, padding, output_padding)))))
2143
2283
  return x.conv2d(w.flatten(end_dim=1), groups=groups, bias=bias, dilation=dilation, padding=padding)
2144
2284
 
2145
2285
  def dot(self, w:Tensor, acc_dtype:Optional[DTypeLike]=None) -> Tensor:
@@ -2185,15 +2325,28 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2185
2325
  """
2186
2326
  return x.dot(self, acc_dtype=acc_dtype) if reverse else self.dot(x, acc_dtype=acc_dtype)
2187
2327
 
2188
- def _cumsum(self, axis:int=0, _first_zero=False) -> Tensor:
2189
- assert self.shape[axis] != 0
2190
- pl_sz = self.shape[axis] - int(not _first_zero)
2191
- return self.transpose(axis,-1).pad((pl_sz,-int(_first_zero)))._pool((self.shape[axis],)).sum(-1).transpose(axis,-1)
2328
+ def _cumalu(self, axis:int, op:Ops, _include_initial=False) -> Tensor:
2329
+ assert self.shape[axis] != 0 and op in (Ops.ADD, Ops.MAX)
2330
+ pl_sz = self.shape[axis] - int(not _include_initial)
2331
+ pooled = self.transpose(axis,-1).pad((pl_sz, -int(_include_initial)), value=identity_element(op, self.dtype))._pool((self.shape[axis],))
2332
+ return (pooled.sum(-1) if op is Ops.ADD else pooled.max(-1)).transpose(axis,-1)
2333
+
2334
+ def _split_cumalu(self, axis:int, op:Ops) -> Tensor:
2335
+ axis = self._resolve_dim(axis)
2336
+ if self.ndim == 0 or 0 in self.shape: return self
2337
+ # TODO: someday the optimizer will find this on its own
2338
+ # for now this is a two stage cumsum
2339
+ SPLIT = 256
2340
+ if not isinstance(s:=self.shape[axis], int) or s <= SPLIT*2: return self._cumalu(axis, op)
2341
+ ret = self.transpose(axis,-1).pad((round_up(s, SPLIT)-s, 0), value=identity_element(op, self.dtype)).unflatten(-1, (-1, SPLIT))._cumalu(-1, op)
2342
+ base = ret[..., -1]._cumalu(-1, op, _include_initial=True)
2343
+ base = base.unsqueeze(-1).expand(*base.shape, ret.shape[-1])
2344
+ def fix(x:Tensor): return x.flatten(start_dim=-2)[..., -s:].transpose(axis,-1)
2345
+ return fix(ret) + fix(base) if op is Ops.ADD else fix(ret).maximum(fix(base))
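The split path chunks the axis into blocks of `SPLIT` elements, runs the cumulative op inside each block, then folds in the running totals of the preceding blocks. The same two-stage idea for a cumulative sum in plain Python; `two_stage_cumsum` is illustrative only, not tinygrad code:

```python
from itertools import accumulate

def two_stage_cumsum(xs: list[float], split: int = 4) -> list[float]:
    blocks = [xs[i:i + split] for i in range(0, len(xs), split)]
    partial = [list(accumulate(b)) for b in blocks]  # stage 1: cumsum inside each block
    base, out = 0.0, []
    for b in partial:                                # stage 2: add the exclusive running block totals
        out += [v + base for v in b]
        base += b[-1]
    return out

xs = [1.0] * 10
assert two_stage_cumsum(xs) == list(accumulate(xs))
```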
2346
+
2192
2347
  def cumsum(self, axis:int=0) -> Tensor:
2193
2348
  """
2194
- Computes the cumulative sum of the tensor along the specified axis.
2195
-
2196
- You can pass in the `axis` keyword argument to control the axis along which the cumulative sum is computed.
2349
+ Computes the cumulative sum of the tensor along the specified `axis`.
2197
2350
 
2198
2351
  ```python exec="true" source="above" session="tensor" result="python"
2199
2352
  t = Tensor.ones(2, 3)
@@ -2203,17 +2356,21 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2203
2356
  print(t.cumsum(1).numpy())
2204
2357
  ```
2205
2358
  """
2206
- axis = self._resolve_dim(axis)
2207
- if self.ndim == 0 or 0 in self.shape: return self
2208
- # TODO: someday the optimizer will find this on it's own
2209
- # for now this is a two stage cumsum
2210
- SPLIT = 256
2211
- if not isinstance(s:=self.shape[axis], int) or s <= SPLIT*2: return self._cumsum(axis)
2212
- ret = self.transpose(axis,-1).pad((round_up(s, SPLIT)-s, 0)).unflatten(-1, (-1, SPLIT))._cumsum(-1)
2213
- base_add = ret[..., -1]._cumsum(-1, _first_zero=True)
2214
- base_add = base_add.unsqueeze(-1).expand(*base_add.shape, ret.shape[-1])
2215
- def fix(x:Tensor): return x.flatten(start_dim=-2)[..., -s:].transpose(axis,-1)
2216
- return fix(ret) + fix(base_add)
2359
+ return self._split_cumalu(axis, Ops.ADD)
2360
+
2361
+ def cummax(self, axis:int=0) -> Tensor:
2362
+ """
2363
+ Computes the cumulative max of the tensor along the specified `axis`.
2364
+
2365
+ ```python exec="true" source="above" session="tensor" result="python"
2366
+ t = Tensor([0, 1, -1, 2, -2, 3, -3])
2367
+ print(t.numpy())
2368
+ ```
2369
+ ```python exec="true" source="above" session="tensor" result="python"
2370
+ print(t.cummax(0).numpy())
2371
+ ```
2372
+ """
2373
+ return self._split_cumalu(axis, Ops.MAX)
2217
2374
 
2218
2375
  @staticmethod
2219
2376
  def _tri(r:sint, c:sint, diagonal:int=0, **kwargs) -> Tensor:
@@ -2271,7 +2428,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2271
2428
  """
2272
2429
  return Tensor._tri(self.shape[-2], self.shape[-1], diagonal=diagonal+1, device=self.device, dtype=dtypes.bool).where(0, self).cast(self.dtype)
2273
2430
 
2274
- def interpolate(self, size:Tuple[int, ...], mode:str="linear", align_corners:bool=False) -> Tensor:
2431
+ def interpolate(self, size:tuple[int, ...], mode:str="linear", align_corners:bool=False) -> Tensor:
2275
2432
  """
2276
2433
  Downsamples or Upsamples to the input `size`, accepts 0 to N batch dimensions.
2277
2434
 
@@ -2296,13 +2453,104 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2296
2453
  reshape[i] = expand[i] = size[i]
2297
2454
  if mode == "linear":
2298
2455
  index = (scale*arr if align_corners else (scale*(arr+0.5))-0.5).clip(0, self.shape[i]-1)
2299
- low, high, perc = [y.reshape(reshape).expand(expand) for y in (index.floor(), index.ceil(), index - index.floor())]
2456
+ low, high, perc = [y.reshape(reshape).expand(expand) for y in (index.floor().int(), index.ceil().int(), index - index.floor())]
2300
2457
  x = x.gather(i, low).lerp(x.gather(i, high), perc)
2301
2458
  else:
2302
2459
  index = (scale*(arr+0.5) if mode=="nearest-exact" else scale*arr).cast(dtypes.int32).reshape(reshape).expand(expand)
2303
2460
  x = x.gather(i, index)
2304
2461
  return x.cast(self.dtype)
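In the `linear` branch above, each output coordinate maps to a fractional source index, and the result is a `gather` at the floor and ceil of that index blended by `lerp`. A rough sketch of that index math for one axis, assuming the usual `scale = in_sz/out_sz` used when `align_corners=False`; the helper name is hypothetical:

```python
import math

# Sketch of the gather/lerp sample points for interpolate(mode="linear", align_corners=False).
def linear_sample_points(in_sz:int, out_sz:int):
  scale = in_sz / out_sz
  pts = []
  for i in range(out_sz):
    idx = min(max(scale * (i + 0.5) - 0.5, 0.0), in_sz - 1)  # clipped fractional source index
    pts.append((math.floor(idx), math.ceil(idx), idx - math.floor(idx)))  # gather low/high, lerp by frac
  return pts

print(linear_sample_points(4, 6))
```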
2305
2462
 
2463
+ def _pre_scatter(self, dim:int, index:Tensor, src:Tensor) -> tuple[Tensor, Tensor]:
2464
+ index, dim = index.to(self.device), self._resolve_dim(dim)
2465
+ assert index.ndim == self.ndim == src.ndim, f"self.ndim, index.ndim and src.ndim must all be equal, {self.ndim=} {index.ndim=} {src.ndim=}"
2466
+ assert all((d == dim or self_ >= index_) and src_ >= index_ for d,(self_,index_,src_) in enumerate(zip(self.shape, index.shape, src.shape))), \
2467
+ f"All dimensions of {index.shape=} should be <= to all dimensions of {src.shape=} and all dimensions except dimension {dim} of {self.shape=}"
2468
+ if self.dtype != src.dtype: raise RuntimeError(f"expect {self.dtype=} to be equal to {src.dtype=}")
2469
+ # shrink src to index shape to shrink away the unused values
2470
+ src = src.shrink(tuple((0,s) for s in index.shape))
2471
+ # prepare src and mask for reduce with respect to dim
2472
+ src = src.unsqueeze(-1).expand(*src.shape, self.shape[dim]).transpose(-1, dim)
2473
+ mask = index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim]).transpose(-1, dim)
2474
+ # pad src and mask to self.shape so that reduce can be done with padded values as no-ops
2475
+ src, mask = (x.pad(tuple((0, self.shape[i] - x.shape[i]) if i != dim else None for i in range(self.ndim)) + (None,)) for x in (src, mask))
2476
+ return src, mask
2477
+
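`_pre_scatter` reformulates scatter as a masked reduction: every index becomes a one-hot row over `self.shape[dim]`, `src` is broadcast against that mask, and the caller reduces over the extra axis. A toy 1-D version of that formulation in plain Python (names are illustrative, this is the sum reduction, i.e. `scatter_reduce(..., 'sum')`):

```python
# Toy 1-D illustration of the masked-reduce formulation behind _pre_scatter.
def scatter_add_1d(out, index, src):
  n = len(out)
  for idx, val in zip(index, src):
    one_hot = [int(j == idx) for j in range(n)]        # mask row for this element
    out = [o + m * val for o, m in zip(out, one_hot)]  # masked accumulate
  return out

print(scatter_add_1d([0, 0, 0, 0, 0], index=[0, 2, 2, 4], src=[1, 2, 3, 4]))  # [1, 0, 5, 0, 4]
```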
2478
+ def scatter(self, dim:int, index:Tensor, src:Union[Tensor, ConstType], reduce:Union[None, Literal['multiply'], Literal['add']]=None) -> Tensor:
2479
+ """
2480
+ Scatters `src` values along an axis specified by `dim`.
2481
+ Applies the `add` or `multiply` reduction operation specified by `reduce`.
2482
+
2483
+ NOTE: To use the `reduce` argument with a Tensor `src`, see `Tensor.scatter_reduce`.
2484
+
2485
+ ```python exec="true" source="above" session="tensor" result="python"
2486
+ src = Tensor.arange(1, 11).reshape(2, 5)
2487
+ print(src.numpy())
2488
+ ```
2489
+ ```python exec="true" source="above" session="tensor" result="python"
2490
+ index = Tensor([[0, 1, 2, 0]])
2491
+ print(Tensor.zeros(3, 5, dtype=src.dtype).scatter(0, index, src).numpy())
2492
+ ```
2493
+ ```python exec="true" source="above" session="tensor" result="python"
2494
+ index = Tensor([[0, 1, 2], [0, 1, 4]])
2495
+ print(Tensor.zeros(3, 5, dtype=src.dtype).scatter(1, index, src).numpy())
2496
+ ```
2497
+ ```python exec="true" source="above" session="tensor" result="python"
2498
+ print(Tensor.full((2, 4), 2.0).scatter(1, Tensor([[2], [3]]), 1.23, reduce='multiply').numpy())
2499
+ ```
2500
+ ```python exec="true" source="above" session="tensor" result="python"
2501
+ print(Tensor.full((2, 4), 2.0).scatter(1, Tensor([[2], [3]]), 1.23, reduce='add').numpy())
2502
+ ```
2503
+ """
2504
+ if reduce not in {None, "add", "multiply"}: raise TypeError(f"{reduce=} must be one of None, 'multiply', or 'add'")
2505
+ if reduce and isinstance(src, Tensor): raise TypeError("Tensor src is not supported with reduce arg. see scatter_reduce")
2506
+ if not isinstance(src, Tensor): src = index.full_like(src, device=self.device, dtype=self.dtype)
2507
+ if reduce == "add": return self.scatter_reduce(dim, index, src, "sum", include_self=True)
2508
+ if reduce == "multiply": return self.scatter_reduce(dim, index, src, "prod", include_self=True)
2509
+ src, mask = self._pre_scatter(dim, index, src)
2510
+ return _masked_setitem(self, src, mask, (-1,))
2511
+
2512
+ def scatter_reduce(self, dim:int, index:Tensor, src:Tensor, reduce:Literal["sum", "prod", "mean", "amax", "amin"],
2513
+ include_self:bool=True) -> Tensor:
2514
+ """
2515
+ Scatters `src` values along an axis specified by `dim`.
2516
+ Apply `"sum"`, `"prod"`, `"mean"`, `"amax"`, or `"amin"` reduction operations with `reduce`.
2517
+
2518
+ Set `include_self=False` to exclude values in the `self` Tensor from the reduction.
2519
+
2520
+ ```python exec="true" source="above" session="tensor" result="python"
2521
+ src = Tensor.arange(1, 11).cast(dtypes.float).reshape(2, 5)
2522
+ print(src.numpy())
2523
+ index = Tensor([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]])
2524
+ print(index.numpy())
2525
+ ```
2526
+ ```python exec="true" source="above" session="tensor" result="python"
2527
+ print(Tensor.ones(1, 5, dtype=src.dtype).scatter_reduce(0, index, src, reduce='sum').numpy())
2528
+ ```
2529
+ ```python exec="true" source="above" session="tensor" result="python"
2530
+ print(Tensor.ones(1, 5, dtype=src.dtype).scatter_reduce(0, index, src, reduce='prod').numpy())
2531
+ ```
2532
+ ```python exec="true" source="above" session="tensor" result="python"
2533
+ print(Tensor.ones(1, 5, dtype=src.dtype).scatter_reduce(0, index, src, reduce='mean', include_self=False).numpy())
2534
+ ```
2535
+ ```python exec="true" source="above" session="tensor" result="python"
2536
+ print(Tensor([[-10, 20, 0, 5, 10]], dtype=src.dtype).scatter_reduce(0, index, src, reduce='amax').numpy())
2537
+ ```
2538
+ ```python exec="true" source="above" session="tensor" result="python"
2539
+ print(Tensor([[-10, 20, 0, 5, 10]], dtype=src.dtype).scatter_reduce(0, index, src, reduce='amin').numpy())
2540
+ ```
2541
+ """
2542
+ src, mask = self._pre_scatter(dim, index, src)
2543
+ def _inv_mask(a:Union[Tensor, ConstType], b:Union[Tensor, ConstType]) -> Tensor: return mask.any(-1).logical_not().where(a, b)
2544
+ # TODO: should not overwrite acc_dtype here?
2545
+ if reduce == "sum": return mask.where(src, 0).sum(-1, acc_dtype=self.dtype).add(self if include_self else _inv_mask(self, 0))
2546
+ if reduce == "prod": return mask.where(src, 1).prod(-1, acc_dtype=self.dtype).mul(self if include_self else _inv_mask(self, 1))
2547
+ if reduce == "amax": return mask.where(src, m := dtypes.min(src.dtype)).max(-1).maximum(self if include_self else _inv_mask(self, m))
2548
+ if reduce == "amin": return mask.where(src, m := dtypes.max(src.dtype)).min(-1).minimum(self if include_self else _inv_mask(self, m))
2549
+ if reduce == "mean":
2550
+ count = mask.where(1, 0).sum(-1, acc_dtype=self.dtype).add(1 if include_self else _inv_mask(1, 0))
2551
+ return mask.where(src, 0).sum(-1, acc_dtype=self.dtype).add(self if include_self else _inv_mask(self, 0)).div(count)
2552
+ raise RuntimeError(f"{reduce=} must be one of 'sum', 'prod', 'mean', 'amax', 'amin'")
2553
+
2306
2554
  # ***** unary ops *****
2307
2555
 
2308
2556
  def logical_not(self):
@@ -2313,7 +2561,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2313
2561
  print(Tensor([False, True]).logical_not().numpy())
2314
2562
  ```
2315
2563
  """
2316
- return F.Neq.apply(*self.cast(dtypes.bool)._broadcasted(True))
2564
+ return self.cast(dtypes.bool)._apply_broadcasted_uop(UOp.ne, True)
2317
2565
  def neg(self):
2318
2566
  """
2319
2567
  Negates the tensor element-wise.
@@ -2327,12 +2575,12 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2327
2575
  """
2328
2576
  Returns a contiguous tensor.
2329
2577
  """
2330
- return F.Contiguous.apply(self)
2578
+ return self._apply_uop(UOp.contiguous)
2331
2579
  def contiguous_backward(self):
2332
2580
  """
2333
2581
  Inserts a contiguous operation in the backward pass.
2334
2582
  """
2335
- return F.ContiguousBackward.apply(self)
2583
+ return self._apply_uop(UOp.contiguous_backward)
2336
2584
  def log(self):
2337
2585
  """
2338
2586
  Computes the natural logarithm element-wise.
@@ -2343,7 +2591,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2343
2591
  print(Tensor([1., 2., 4., 8.]).log().numpy())
2344
2592
  ```
2345
2593
  """
2346
- return F.Log.apply(self.cast(least_upper_float(self.dtype)))
2594
+ return self.log2()*math.log(2)
2347
2595
  def log2(self):
2348
2596
  """
2349
2597
  Computes the base-2 logarithm element-wise.
@@ -2354,7 +2602,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2354
2602
  print(Tensor([1., 2., 4., 8.]).log2().numpy())
2355
2603
  ```
2356
2604
  """
2357
- return self.log()/math.log(2)
2605
+ return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.log2)
2358
2606
  def exp(self):
2359
2607
  """
2360
2608
  Computes the exponential function element-wise.
@@ -2365,7 +2613,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2365
2613
  print(Tensor([0., 1., 2., 3.]).exp().numpy())
2366
2614
  ```
2367
2615
  """
2368
- return F.Exp.apply(self.cast(least_upper_float(self.dtype)))
2616
+ return self.mul(1/math.log(2)).exp2()
2369
2617
  def exp2(self):
2370
2618
  """
2371
2619
  Computes the base-2 exponential function element-wise.
@@ -2376,7 +2624,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2376
2624
  print(Tensor([0., 1., 2., 3.]).exp2().numpy())
2377
2625
  ```
2378
2626
  """
2379
- return F.Exp.apply(self*math.log(2))
2627
+ return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.exp2)
2380
2628
  def relu(self):
2381
2629
  """
2382
2630
  Applies the Rectified Linear Unit (ReLU) function element-wise.
@@ -2387,7 +2635,8 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2387
2635
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).relu().numpy())
2388
2636
  ```
2389
2637
  """
2390
- return F.Relu.apply(self)
2638
+ return (self>0).where(self, 0)
2639
+
2391
2640
  def sigmoid(self):
2392
2641
  """
2393
2642
  Applies the Sigmoid function element-wise.
@@ -2398,7 +2647,8 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2398
2647
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sigmoid().numpy())
2399
2648
  ```
2400
2649
  """
2401
- return F.Sigmoid.apply(self.cast(least_upper_float(self.dtype)))
2650
+ return (1 + (self * (-1/math.log(2))).exp2()).reciprocal()
2651
+
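The `log`, `exp`, and `sigmoid` rewrites above route everything through the base-2 primitives `log2`/`exp2`. A quick numerical check of the identities they rely on, done with the standard `math` module rather than tinygrad:

```python
import math

# log(x) = log2(x)*ln(2), exp(x) = 2**(x/ln(2)), sigmoid(x) = 1/(1 + 2**(-x/ln(2)))
for x in [0.5, 1.0, 2.0, 4.0]:
  assert abs(math.log(x) - math.log2(x) * math.log(2)) < 1e-12
for x in [-2.0, 0.0, 1.0, 3.0]:
  assert abs(math.exp(x) - 2 ** (x / math.log(2))) < 1e-12
  assert abs(1 / (1 + math.exp(-x)) - 1 / (1 + 2 ** (x * (-1 / math.log(2))))) < 1e-12
print("identities hold")
```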
2402
2652
  def hardsigmoid(self, alpha:float=1/6, beta:float=0.5):
2403
2653
  """
2404
2654
  Applies the Hardsigmoid function element-wise.
@@ -2421,7 +2671,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2421
2671
  print(Tensor([1., 2., 3., 4.]).sqrt().numpy())
2422
2672
  ```
2423
2673
  """
2424
- return F.Sqrt.apply(self.cast(least_upper_float(self.dtype)))
2674
+ return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.sqrt)
2425
2675
  def rsqrt(self):
2426
2676
  """
2427
2677
  Computes the reciprocal of the square root of the tensor element-wise.
@@ -2430,7 +2680,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2430
2680
  print(Tensor([1., 2., 3., 4.]).rsqrt().numpy())
2431
2681
  ```
2432
2682
  """
2433
- return self.reciprocal().sqrt()
2683
+ return self.sqrt().reciprocal()
2434
2684
  def sin(self):
2435
2685
  """
2436
2686
  Computes the sine of the tensor element-wise.
@@ -2439,7 +2689,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2439
2689
  print(Tensor([0., math.pi/2, math.pi, 3*math.pi/2, 2*math.pi]).sin().numpy())
2440
2690
  ```
2441
2691
  """
2442
- return F.Sin.apply(self.cast(least_upper_float(self.dtype)))
2692
+ return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.sin)
2443
2693
  def cos(self):
2444
2694
  """
2445
2695
  Computes the cosine of the tensor element-wise.
@@ -2459,6 +2709,39 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2459
2709
  """
2460
2710
  return self.sin() / self.cos()
2461
2711
 
2712
+ def asin(self):
2713
+ """
2714
+ Computes the inverse sine (arcsine) of the tensor element-wise.
2715
+
2716
+ ```python exec="true" source="above" session="tensor" result="python"
2717
+ print(Tensor([-0.9, -0.6, -0.3, 0., 0.3, 0.6, 0.9]).asin().numpy())
2718
+ ```
2719
+ """
2720
+ # https://personal.math.ubc.ca/~cbm/aands/page_81.htm 4.4.46
2721
+ coefficients = [-0.0012624911, 0.0066700901, -0.0170881256, 0.0308918810, -0.0501743046, 0.0889789874, -0.2145988016, 1.5707963050]
2722
+ x = math.pi / 2 - (1.0 - self.abs()).sqrt() * polyN(self.abs(), coefficients)
2723
+ return self.sign() * x
2724
+
2725
+ def acos(self):
2726
+ """
2727
+ Computes the inverse cosine (arccosine) of the tensor element-wise.
2728
+
2729
+ ```python exec="true" source="above" session="tensor" result="python"
2730
+ print(Tensor([-0.9, -0.6, -0.3, 0., 0.3, 0.6, 0.9]).acos().numpy())
2731
+ ```
2732
+ """
2733
+ return math.pi / 2 - self.asin()
2734
+
2735
+ def atan(self):
2736
+ """
2737
+ Computes the inverse tangent (arctan) of the tensor element-wise.
2738
+
2739
+ ```python exec="true" source="above" session="tensor" result="python"
2740
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).atan().numpy())
2741
+ ```
2742
+ """
2743
+ return (self / (1 + self * self).sqrt()).asin()
2744
+
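`asin` uses the Abramowitz & Stegun 4.4.46 polynomial, and `acos`/`atan` are derived from it through identities. A small check of those identities against the `math` module (the polynomial coefficients themselves are not re-verified here):

```python
import math

# acos(x) = pi/2 - asin(x); atan(x) = asin(x / sqrt(1 + x*x))
for x in [-0.9, -0.3, 0.0, 0.3, 0.9]:
  assert abs(math.acos(x) - (math.pi / 2 - math.asin(x))) < 1e-12
for x in [-3.0, -1.0, 0.0, 1.0, 3.0]:
  assert abs(math.atan(x) - math.asin(x / math.sqrt(1 + x * x))) < 1e-12
print("identities hold")
```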
2462
2745
  # ***** math functions *****
2463
2746
 
2464
2747
  def trunc(self: Tensor) -> Tensor:
@@ -2565,7 +2848,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2565
2848
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sign().numpy())
2566
2849
  ```
2567
2850
  """
2568
- return F.Sign.apply(self)
2851
+ return self.ne(0).where((self<0).where(self.full_like(-1), self.full_like(1)), self.full_like(0)) + self*0
2569
2852
  def abs(self):
2570
2853
  """
2571
2854
  Computes the absolute value of the tensor element-wise.
@@ -2583,7 +2866,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2583
2866
  print(Tensor([1., 2., 3., 4.]).reciprocal().numpy())
2584
2867
  ```
2585
2868
  """
2586
- return F.Reciprocal.apply(self.cast(least_upper_float(self.dtype)))
2869
+ return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.reciprocal)
2587
2870
 
2588
2871
  # ***** activation functions *****
2589
2872
 
@@ -2613,6 +2896,19 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2613
2896
  """
2614
2897
  return self.maximum(0) + (alpha * ((self / alpha).exp() - 1)).minimum(0)
2615
2898
 
2899
+ def selu(self, alpha=1.67326, gamma=1.0507):
2900
+ """
2901
+ Applies the Scaled Exponential Linear Unit (SELU) function element-wise.
2902
+
2903
+ - Described: https://paperswithcode.com/method/selu
2904
+ - Paper: https://arxiv.org/abs/1706.02515v5
2905
+
2906
+ ```python exec="true" source="above" session="tensor" result="python"
2907
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).selu().numpy())
2908
+ ```
2909
+ """
2910
+ return gamma * (self >= 0).detach().where(self, alpha * (self.exp() - 1))
2911
+
2616
2912
  def swish(self):
2617
2913
  """
2618
2914
  See `.silu()`
@@ -2840,17 +3136,17 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2840
3136
  return self / (1 + self.abs())
2841
3137
 
2842
3138
  # ***** broadcasted elementwise ops *****
2843
- def _broadcast_to(self, shape:Tuple[sint, ...]) -> Tensor:
2844
- if self.shape == shape: return self
2845
- if self.ndim > len(shape): raise ValueError(f"cannot broadcast tensor to fewer dimensions. shape={self.shape} to {shape=}")
2846
- # first pad left with 1s https://data-apis.org/array-api/latest/API_specification/broadcasting.html
2847
- padded, _ = _pad_left(self.shape, shape)
2848
- # for each dimension, check either from_ is 1, or it does not change
2849
- if any(resolve(from_ != 1, False) and resolve(from_ != to, False) for from_,to in zip(padded, shape)):
2850
- raise ValueError(f"cannot broadcast from shape={self.shape} to {shape=}")
2851
- return F.Expand.apply(self.reshape(padded), shape=shape)
2852
-
2853
- def _broadcasted(self, y:Union[Tensor, UOp, ConstType], reverse:bool=False, match_dtype:bool=True) -> Tuple[Tensor, Tensor]:
3139
+ def _broadcast_to(self, new_shape:tuple[sint, ...]) -> Tensor:
3140
+ if self.shape == new_shape: return self
3141
+ if self.ndim > len(new_shape): raise ValueError(f"cannot broadcast tensor to fewer dimensions. shape={self.shape} to {new_shape=}")
3142
+ # first unsqueeze left with 1s https://data-apis.org/array-api/latest/API_specification/broadcasting.html
3143
+ shape, _ = _align_left(self.shape, new_shape)
3144
+ # for each dimension, check either dim is 1, or it does not change
3145
+ if not all(resolve(s == ns) or resolve(s == 1) for s,ns in zip(shape, new_shape)):
3146
+ raise ValueError(f"cannot broadcast {self.shape} to {new_shape=}")
3147
+ return self.reshape(shape)._apply_uop(UOp.expand, arg=new_shape)
3148
+
3149
+ def _broadcasted(self, y:Union[Tensor, UOp, ConstType], reverse:bool=False, match_dtype:bool=True) -> tuple[Tensor, Tensor]:
2854
3150
  x: Tensor = self
2855
3151
  if not isinstance(y, Tensor):
2856
3152
  # make y a Tensor
@@ -2867,12 +3163,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2867
3163
  if reverse: x, y = y, x
2868
3164
 
2869
3165
  # broadcast
2870
- out_shape = _broadcast_shape(x.shape, y.shape)
2871
- return x._broadcast_to(out_shape), y._broadcast_to(out_shape)
2872
-
2873
- def _to_const_val(self, x:Union[Tensor, ConstType]) -> Union[Tensor, ConstType]:
2874
- return x.lazydata.base.arg if isinstance(x, Tensor) and isinstance(x.lazydata, LazyBuffer) and x.lazydata.is_unrealized_unmasked_const() \
2875
- and not x.requires_grad and self._broadcasted(x)[0].shape == self.shape else x
3166
+ return x._broadcast_to(out_shape:=_broadcast_shape(x.shape, y.shape)), y._broadcast_to(out_shape)
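`_broadcast_to` follows the array-API rule: left-pad the smaller shape with 1s, then every dimension must either match or be 1. A simplified pure-Python stand-in for the `_broadcast_shape` helper used above (the function below is illustrative, not the tinygrad implementation):

```python
# Broadcasting rule: align shapes on the right, each dim must match or be 1.
def broadcast_shape(*shapes):
  max_ndim = max(len(s) for s in shapes)
  padded = [(1,) * (max_ndim - len(s)) + tuple(s) for s in shapes]
  out = []
  for dims in zip(*padded):
    non_one = {d for d in dims if d != 1}
    if len(non_one) > 1: raise ValueError(f"cannot broadcast {shapes}")
    out.append(non_one.pop() if non_one else 1)
  return tuple(out)

print(broadcast_shape((2, 1, 3), (4, 3)))  # (2, 4, 3)
```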
2876
3167
 
2877
3168
  def add(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
2878
3169
  """
@@ -2892,7 +3183,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2892
3183
  print(t.add(Tensor([[2.0], [3.5]])).numpy())
2893
3184
  ```
2894
3185
  """
2895
- return F.Add.apply(*self._broadcasted(x, reverse))
3186
+ return self._apply_broadcasted_uop(UOp.add, x, reverse)
2896
3187
 
2897
3188
  def sub(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
2898
3189
  """
@@ -2933,20 +3224,20 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2933
3224
  print(t.mul(Tensor([[-1.0], [2.0]])).numpy())
2934
3225
  ```
2935
3226
  """
2936
- return F.Mul.apply(*self._broadcasted(x, reverse))
3227
+ return self._apply_broadcasted_uop(UOp.mul, x, reverse)
2937
3228
 
2938
3229
  def idiv(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
2939
3230
  """
2940
3231
  Divides `self` by `x`.
2941
3232
  Equivalent to `self // x`.
2942
3233
  Supports broadcasting to a common shape, type promotion, and integer inputs.
2943
- `idiv` performs integer division.
3234
+ `idiv` performs integer division (truncate towards zero).
2944
3235
 
2945
3236
  ```python exec="true" source="above" session="tensor" result="python"
2946
- print(Tensor([1, 4, 10]).idiv(Tensor([2, 3, 4])).numpy())
3237
+ print(Tensor([-4, 7, 5, 4, -7, 8]).idiv(Tensor([2, -3, 8, -2, 3, 5])).numpy())
2947
3238
  ```
2948
3239
  """
2949
- return F.IDiv.apply(*self._broadcasted(x, reverse))
3240
+ return self._apply_broadcasted_uop(UOp.idiv, x, reverse)
2950
3241
 
2951
3242
  def div(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
2952
3243
  """
@@ -2970,6 +3261,19 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2970
3261
  numerator, denominator = self._broadcasted(x, reverse)
2971
3262
  return numerator.cast(least_upper_float(numerator.dtype)) * denominator.cast(least_upper_float(denominator.dtype)).reciprocal()
2972
3263
 
3264
+ def mod(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
3265
+ """
3266
+ Mod `self` by `x`.
3267
+ Equivalent to `self % x`.
3268
+ Supports broadcasting to a common shape, type promotion, and integer inputs.
3269
+
3270
+ ```python exec="true" source="above" session="tensor" result="python"
3271
+ print(Tensor([-4, 7, 5, 4, -7, 8]).mod(Tensor([2, -3, 8, -2, 3, 5])).numpy())
3272
+ ```
3273
+ """
3274
+ a, b = self._broadcasted(x, reverse)
3275
+ return (r := a._apply_uop(UOp.mod, b)) + b * (((r < 0) & (b > 0)) | ((r > 0) & (b < 0)))
3276
+
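The `mod` implementation starts from the backend's truncate-toward-zero remainder and adds `b` whenever the remainder and divisor have opposite signs, which yields floored (Python-style) modulo. A pure-Python check of that correction, illustrative only:

```python
import math

# Sign correction: truncated remainder -> floored modulo (result takes the sign of b).
def floored_mod(a:int, b:int) -> int:
  r = a - math.trunc(a / b) * b                        # C-style remainder (sign of a)
  return r + b * int((r < 0 and b > 0) or (r > 0 and b < 0))

for a, b in [(-4, 2), (7, -3), (5, 8), (4, -2), (-7, 3), (8, 5)]:
  assert floored_mod(a, b) == a % b                    # matches Python's %
print("ok")
```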
2973
3277
  def xor(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
2974
3278
  """
2975
3279
  Computes bitwise xor of `self` and `x`.
@@ -2984,7 +3288,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2984
3288
  ```
2985
3289
  """
2986
3290
  if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
2987
- return F.Xor.apply(*self._broadcasted(x, reverse))
3291
+ return self._apply_broadcasted_uop(UOp.xor, x, reverse)
2988
3292
 
2989
3293
  def bitwise_and(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
2990
3294
  """
@@ -2999,7 +3303,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
2999
3303
  ```
3000
3304
  """
3001
3305
  if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
3002
- return F.BitwiseAnd.apply(*self._broadcasted(x, reverse))
3306
+ return self._apply_broadcasted_uop(UOp.bitwise_and, x, reverse)
3003
3307
 
3004
3308
  def bitwise_or(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
3005
3309
  """
@@ -3014,7 +3318,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
3014
3318
  ```
3015
3319
  """
3016
3320
  if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
3017
- return F.BitwiseOr.apply(*self._broadcasted(x, reverse))
3321
+ return self._apply_broadcasted_uop(UOp.bitwise_or, x, reverse)
3018
3322
 
3019
3323
  def bitwise_not(self) -> Tensor:
3020
3324
  """
@@ -3028,7 +3332,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
3028
3332
  ```
3029
3333
  """
3030
3334
  if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
3031
- return self.logical_not() if self.dtype == dtypes.bool else self ^ ((1<<8*self.dtype.itemsize)-1)
3335
+ return self.logical_not() if self.dtype == dtypes.bool else self ^ -1
3032
3336
 
3033
3337
  def lshift(self, x:int):
3034
3338
  """
@@ -3060,37 +3364,22 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
3060
3364
  Equivalent to `self ** x`.
3061
3365
 
3062
3366
  ```python exec="true" source="above" session="tensor" result="python"
3063
- print(Tensor([-1, 2, 3]).pow(2).numpy())
3367
+ print(Tensor([-1, 2, 3]).pow(2.0).numpy())
3064
3368
  ```
3065
3369
  ```python exec="true" source="above" session="tensor" result="python"
3066
3370
  print(Tensor([-1, 2, 3]).pow(Tensor([-1.5, 0.5, 1.5])).numpy())
3067
3371
  ```
3068
3372
  ```python exec="true" source="above" session="tensor" result="python"
3069
- print((2 ** Tensor([-1, 2, 3])).numpy())
3373
+ print((2.0 ** Tensor([-1, 2, 3])).numpy())
3070
3374
  ```
3071
3375
  """
3072
- x = self._to_const_val(x)
3073
- if not isinstance(x, Tensor) and not reverse:
3074
- # simple pow identities
3075
- if x < 0: return self.reciprocal().pow(-x)
3076
- if x == 0: return 1 + self * 0
3077
- if int(x - 0.5) + 0.5 == x: return self.pow(int(x - 0.5)) * self.sqrt()
3078
- if int(x) == x: return self.pow(x // 2).square() * (1 if x % 2 == 0 else self)
3079
-
3080
- # positive const ** self
3081
- if not isinstance(x, Tensor) and reverse and x > 0: return self.mul(math.log(x)).exp()
3082
-
3083
3376
  base, exponent = self._broadcasted(x, reverse=reverse)
3084
- # start with b ** e = exp(e * log(b))
3085
- ret = base.abs().log().mul(exponent).exp()
3086
- # correct sign of negative base with odd exponent (cos has a period of 2pi so we use it here to get the oddness of the exponent)
3087
- negative_base = (base < 0).detach().where(1, 0)
3088
- # 1 for non-negative base or negative even exponent, -1 for negative odd exponent, don't care about non-integer exponent
3089
- correct_sign = 1 + negative_base * ((exponent * math.pi).cos() - 1)
3090
- # inject nan for negative base and non-integer exponent
3091
- inject_nan = (negative_base * (exponent != exponent.trunc())).detach().where(math.nan, 1)
3092
- # apply correct_sign inject_nan, and fix 0 ** 0 = 1
3093
- return ((base == 0) * (exponent == 0)).detach().where(1, ret * correct_sign * inject_nan)
3377
+ # TODO: int pow
3378
+ if not base.is_floating_point(): raise RuntimeError("base needs to be float")
3379
+
3380
+ # NOTE: pow(int, float) -> int
3381
+ ret = base._apply_uop(UOp.pow, exponent)
3382
+ return ret.round().cast(self.dtype) if not dtypes.is_float(self.dtype) else ret
3094
3383
 
3095
3384
  def maximum(self, x:Union[Tensor, ConstType]) -> Tensor:
3096
3385
  """
@@ -3103,7 +3392,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
3103
3392
  print(Tensor([-1, 2, 3]).maximum(Tensor([-4, -2, 9])).numpy())
3104
3393
  ```
3105
3394
  """
3106
- return (self<x).detach().where(x, (self==x).detach().where(((self * 0.5 + x * 0.5).cast(self.dtype)), self))
3395
+ return self._apply_broadcasted_uop(UOp.maximum, x)
3107
3396
 
3108
3397
  def minimum(self, x:Union[Tensor, ConstType]) -> Tensor:
3109
3398
  """
@@ -3116,9 +3405,10 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
3116
3405
  print(Tensor([-1, 2, 3]).minimum(Tensor([-4, -2, 9])).numpy())
3117
3406
  ```
3118
3407
  """
3119
- return -((-self).maximum(-x))
3408
+ t, x = self._broadcasted(x)
3409
+ return t._inverse().maximum(x._inverse())._inverse()
3120
3410
 
3121
- def where(self:Tensor, x:Union[Tensor, ConstType], y:Union[Tensor, ConstType]):
3411
+ def where(self:Tensor, x:Union[Tensor, ConstType, sint], y:Union[Tensor, ConstType, sint]):
3122
3412
  """
3123
3413
  Return a tensor of elements selected from either `x` or `y`, depending on `self`.
3124
3414
  `output_i = x_i if self_i else y_i`.
@@ -3140,7 +3430,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
3140
3430
  elif isinstance(y, Tensor): y, x = y._broadcasted(x)
3141
3431
  cond, x = self._broadcasted(x, match_dtype=False)
3142
3432
  cond, y = cond._broadcasted(y, match_dtype=False)
3143
- return F.Where.apply(cond.cast(dtypes.bool), *x._broadcasted(y))
3433
+ return cond.cast(dtypes.bool)._apply_uop(UOp.where, *x._broadcasted(y))
3144
3434
 
3145
3435
  def masked_fill(self:Tensor, mask:Tensor, value:Union[Tensor, ConstType]): return mask.where(value, self)
3146
3436
 
@@ -3170,9 +3460,9 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
3170
3460
  def __ilshift__(self, x) -> Tensor: return self.assign(self.lshift(x))
3171
3461
  def __irshift__(self, x) -> Tensor: return self.assign(self.rshift(x))
3172
3462
 
3173
- def lt(self, x) -> Tensor: return F.Less.apply(*self._broadcasted(x, False))
3174
- def gt(self, x) -> Tensor: return F.Less.apply(*self._broadcasted(x, True))
3175
- def ne(self, x) -> Tensor: return F.Neq.apply(*self._broadcasted(x))
3463
+ def __lt__(self, x) -> Tensor: return self._apply_broadcasted_uop(UOp.__lt__, x, False)
3464
+ def __gt__(self, x) -> Tensor: return self._apply_broadcasted_uop(UOp.__lt__, x, True)
3465
+ def ne(self, x) -> Tensor: return self._apply_broadcasted_uop(UOp.ne, x, False)
3176
3466
 
3177
3467
  def __eq__(self, x) -> Tensor: return self.eq(x) # type: ignore[override]
3178
3468
 
@@ -3194,7 +3484,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
3194
3484
  x = self.mul(weight) if len(weight.shape) == 1 else self.dot(weight)
3195
3485
  return x.add(bias) if bias is not None else x
3196
3486
 
3197
- def sequential(self, ll:List[Callable[[Tensor], Tensor]]):
3487
+ def sequential(self, ll:list[Callable[[Tensor], Tensor]]):
3198
3488
  """
3199
3489
  Applies a sequence of functions to `self` chaining the output of each function to the input of the next.
3200
3490
 
@@ -3205,7 +3495,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
3205
3495
  """
3206
3496
  return functools.reduce(lambda x,f: f(x), ll, self)
3207
3497
 
3208
- def layernorm(self, axis:Union[int,Tuple[int,...]]=-1, eps:float=1e-5) -> Tensor:
3498
+ def layernorm(self, axis:Union[int,tuple[int,...]]=-1, eps:float=1e-5) -> Tensor:
3209
3499
  """
3210
3500
  Applies Layer Normalization over a mini-batch of inputs.
3211
3501
 
@@ -3224,7 +3514,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
3224
3514
  y = (self - self.mean(axis, keepdim=True))
3225
3515
  return y.mul((y*y).mean(axis, keepdim=True).add(eps).rsqrt())
3226
3516
 
3227
- def batchnorm(self, weight:Optional[Tensor], bias:Optional[Tensor], mean:Tensor, invstd:Tensor, axis:Union[int,Tuple[int,...]]=1) -> Tensor:
3517
+ def batchnorm(self, weight:Optional[Tensor], bias:Optional[Tensor], mean:Tensor, invstd:Tensor, axis:Union[int,tuple[int,...]]=1) -> Tensor:
3228
3518
  """
3229
3519
  Applies Batch Normalization over a mini-batch of inputs.
3230
3520
 
@@ -3266,6 +3556,12 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
3266
3556
  if not Tensor.training or p == 0: return self
3267
3557
  return (Tensor.rand_like(self, requires_grad=False, dtype=dtypes.default_float, contiguous=False) >= p).contiguous().where(self, 0) / (1.0 - p)
3268
3558
 
3559
+ # helper function commonly used for indexing
3560
+ def _one_hot_along_dim(self:Tensor, num_classes:sint, dim:int=-1):
3561
+ if not dtypes.is_int(self.dtype): raise RuntimeError(f"_one_hot_along_dim expects int index tensor, getting {self.dtype}")
3562
+ offset = self.ndim - self._resolve_dim(dim) - 1
3563
+ return self == Tensor.arange(num_classes, device=self.device, requires_grad=False).reshape((num_classes,) + (1,) * offset)
3564
+
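`_one_hot_along_dim` builds its mask purely by broadcasting: the integer index tensor is compared against `arange(num_classes)` reshaped so the class axis lines up with `dim`. The same comparison written out in plain Python for a 1-D index with `dim=-1`, for illustration:

```python
# Broadcasted `index == arange(num_classes)` gives a boolean one-hot mask.
num_classes = 4
index = [1, 3]
mask = [[i == c for c in range(num_classes)] for i in index]
print(mask)  # [[False, True, False, False], [False, False, False, True]]
```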
3269
3565
  def one_hot(self, num_classes:int=-1) -> Tensor:
3270
3566
  """
3271
3567
  Converts `self` to a one-hot tensor.
@@ -3277,11 +3573,11 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
3277
3573
  print(t.one_hot(5).numpy())
3278
3574
  ```
3279
3575
  """
3576
+ if not dtypes.is_int(self.dtype): raise RuntimeError(f"expect integer dtype, getting {self.dtype=}")
3280
3577
  if num_classes == -1: num_classes = (self.max()+1).item()
3281
- return (self[..., None] == Tensor.arange(num_classes, requires_grad=False, device=self.device)).where(1, 0)
3578
+ return self[..., None]._one_hot_along_dim(num_classes).where(1, 0)
3282
3579
 
3283
- def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Optional[Tensor]=None,
3284
- dropout_p:float=0.0, is_causal:bool=False) -> Tensor:
3580
+ def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Tensor|None=None, dropout_p:float=0.0, is_causal:bool=False) -> Tensor:
3285
3581
  """
3286
3582
  Computes scaled dot-product attention.
3287
3583
  `self` is the query tensor, `key` is the key tensor, and `value` is the value tensor.
@@ -3298,14 +3594,19 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
3298
3594
  """
3299
3595
  # NOTE: it also works when `key` and `value` have symbolic shape.
3300
3596
  assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
3301
- if is_causal: attn_mask = Tensor.ones(self.shape[-2], key.shape[-2], requires_grad=False, device=self.device).tril(0).cast(dtypes.bool)
3302
- if attn_mask is not None and attn_mask.dtype == dtypes.bool: attn_mask = (attn_mask == 0).where(-float("inf"), 0)
3303
3597
  qk = self.matmul(key.transpose(-2,-1), acc_dtype=least_upper_dtype(self.dtype, key.dtype, dtypes.float32)) / math.sqrt(self.shape[-1])
3304
- return ((qk+attn_mask) if attn_mask is not None else qk).softmax(-1).cast(self.dtype).dropout(dropout_p) @ value
3598
+ # handle attention mask
3599
+ if is_causal:
3600
+ if attn_mask is not None: raise RuntimeError("cannot set attn_mask when is_causal=True")
3601
+ attn_mask = qk.ones_like(requires_grad=False, device=self.device, dtype=dtypes.bool).tril()
3602
+ if attn_mask is not None:
3603
+ if attn_mask.dtype == dtypes.bool: attn_mask = attn_mask.where(0, -float("inf"))
3604
+ qk = qk + attn_mask
3605
+ return qk.softmax(-1).cast(self.dtype).dropout(dropout_p) @ value
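The reworked mask handling treats everything as an additive mask: a boolean mask becomes 0 where attention is allowed and `-inf` where it is not, and `is_causal=True` is just the lower-triangular boolean mask of that form. A small plain-Python sketch of the conversion (not the tinygrad kernel):

```python
import math

# Boolean attention mask -> additive mask (0 keeps a position, -inf removes it).
def additive_mask(bool_mask):
  return [[0.0 if keep else -math.inf for keep in row] for row in bool_mask]

causal = [[k <= q for k in range(4)] for q in range(4)]  # tril: query q attends keys <= q
for row in additive_mask(causal): print(row)
```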
3305
3606
 
3306
3607
  def _do_reduction(self, reduction:ReductionStr="mean") -> Tensor:
3307
3608
  if reduction not in get_args(ReductionStr): raise ValueError(f"{reduction=} must be one of {get_args(ReductionStr)}")
3308
- reductions: Dict[str, Callable[[Tensor], Tensor]] = {"mean": Tensor.mean, "sum": Tensor.sum, "none": lambda x: x}
3609
+ reductions: dict[str, Callable[[Tensor], Tensor]] = {"mean": Tensor.mean, "sum": Tensor.sum, "none": lambda x: x}
3309
3610
  return reductions[reduction](self)
3310
3611
 
3311
3612
  def binary_crossentropy(self, Y:Tensor, reduction:ReductionStr="mean") -> Tensor:
@@ -3354,8 +3655,8 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
3354
3655
  assert 0.0 <= label_smoothing <= 1.0, "label_smoothing must be in [0.0, 1.0]"
3355
3656
  assert reduction in ("mean", "sum", "none"), "reduction must be one of ['mean', 'sum', 'none']"
3356
3657
  log_probs, loss_mask = self.log_softmax(), (Y != ignore_index) if ignore_index != -1 else Y.ones_like(dtype=dtypes.bool)
3357
- y_counter = Tensor.arange(self.shape[-1], requires_grad=False, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1])
3358
- y = ((y_counter == Y.flatten().reshape(-1, 1)) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
3658
+ y_counted = Y.to(self.device).flatten().reshape(-1, 1)._one_hot_along_dim(self.shape[-1])
3659
+ y = (y_counted * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
3359
3660
  smoothing = label_smoothing * (log_probs.mean(-1) * loss_mask)
3360
3661
  unreduced = ((1 - label_smoothing) * (log_probs * y).sum(-1) + smoothing)
3361
3662
  # NOTE: because of ignore_index, we can't use Tensor.mean (so can't use `_do_reduction` here)
@@ -3469,7 +3770,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
3469
3770
  """
3470
3771
  return dtypes.is_float(self.dtype)
3471
3772
 
3472
- def size(self, dim:Optional[int]=None) -> Union[sint, Tuple[sint, ...]]:
3773
+ def size(self, dim:Optional[int]=None) -> Union[sint, tuple[sint, ...]]:
3473
3774
  """
3474
3775
  Return the size of the tensor. If `dim` is specified, return the length along dimension `dim`. Otherwise return the shape of the tensor.
3475
3776
 
@@ -3488,7 +3789,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
3488
3789
  def llvm_bf16_cast(self, dtype:DTypeLike):
3489
3790
  # hack for devices that don't support bfloat16
3490
3791
  assert self.dtype == dtypes.bfloat16
3491
- return self.to("LLVM").bitcast(dtypes.uint16).cast(dtypes.uint32).mul(1<<16).bitcast(dtypes.float32).cast(dtype)
3792
+ return self.to("LLVM").cast(dtype)
3492
3793
 
3493
3794
  def cast(self, dtype:DTypeLike) -> Tensor:
3494
3795
  """
@@ -3502,8 +3803,15 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
3502
3803
  t = t.cast(dtypes.int32)
3503
3804
  print(t.dtype, t.numpy())
3504
3805
  ```
3806
+ ```python exec="true" source="above" session="tensor" result="python"
3807
+ t = t.cast(dtypes.uint8)
3808
+ print(t.dtype, t.numpy())
3809
+ ```
3505
3810
  """
3506
- return self if self.dtype == (dt:=to_dtype(dtype)) else F.Cast.apply(self, dtype=dt)
3811
+ if (dt:=to_dtype(dtype)) in {dtypes.uint8, dtypes.uint16} and dtypes.is_float(self.dtype):
3812
+ # NOTE: values within the int32 range and outside the unsigned dtype range will cause values to wrap around
3813
+ return self._apply_uop(UOp.cast, dtype=dtypes.int32)._apply_uop(UOp.cast, dtype=dt)
3814
+ return self if self.dtype == dt else self._apply_uop(UOp.cast, dtype=dt)
3507
3815
 
3508
3816
  def bitcast(self, dtype:DTypeLike) -> Tensor:
3509
3817
  """
@@ -3522,13 +3830,13 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
3522
3830
  """
3523
3831
  if self.requires_grad: raise RuntimeError("can't backprop through bitcast")
3524
3832
  dt = to_dtype(dtype)
3525
- if (not isinstance(self.device, str) or not self.device.startswith("DISK")) and (ns:=dt.itemsize) != (os:=self.dtype.itemsize):
3526
- if (self.shape[-1]*os) % ns != 0: raise RuntimeError("unsupported size in bitcast")
3833
+ if (ns:=dt.itemsize) != (os:=self.dtype.itemsize) and (self.shape[-1]*os) % ns != 0: raise RuntimeError("unsupported size in bitcast")
3834
+ if (not isinstance(self.device, str) or not self.device.startswith("DISK")) and ns != os:
3527
3835
  new_uint, old_uint = to_dtype(f"uint{8*ns}"), to_dtype(f"uint{8*os}")
3528
3836
  tmp = self.bitcast(old_uint)
3529
3837
  if ns > os: return functools.reduce(Tensor.add, (tmp[..., i::ns//os].cast(new_uint) << 8*i*os for i in range(ns//os))).bitcast(dtype)
3530
3838
  return Tensor.stack(*(tmp>>8*i*ns for i in range(os//ns)), dim=-1).flatten(-2).cast(new_uint).bitcast(dtype)
3531
- return F.Cast.apply(self, dtype=dt, bitcast=True) if self.dtype != dt else self
3839
+ return self._apply_uop(UOp.bitcast, dtype=dt) if self.dtype != dt else self
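When `bitcast` changes the itemsize, the tensor is first viewed as unsigned ints and the elements are split or recombined along the last axis with shifts. The byte-level picture for a single float32 and its two uint16 halves, done with `struct` purely for illustration:

```python
import struct

# One float32 viewed as two little-endian uint16 halves and recombined.
(u32,) = struct.unpack("<I", struct.pack("<f", 3.14))
lo, hi = u32 & 0xFFFF, (u32 >> 16) & 0xFFFF   # itemsize shrinks: one element -> two (shift + mask)
assert lo | (hi << 16) == u32                 # itemsize grows: two elements -> one (shift + or)
print(hex(u32), hex(lo), hex(hi))
```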
3532
3840
 
3533
3841
  def float(self) -> Tensor:
3534
3842
  """
@@ -3650,7 +3958,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
3650
3958
  else: w = w.reshape(cout//4, H, rcin_hi, W, rcin_lo, 4).permute(0,1,2,3,5,4)
3651
3959
 
3652
3960
  # prepare input
3653
- x = x.permute(0,3,4,5,1,2).pad(self._padding2d(padding, 2))._pool((H, W), stride, dilation) # -> (bs, groups, rcin_hi, rcin_lo, oy, ox, H, W)
3961
+ x = x.permute(0,3,4,5,1,2).pad(self._resolve_pool_pads(padding,2))._pool((H,W), stride, dilation)# -> (bs, groups, rcin_hi, rcin_lo, oy, ox, H, W)
3654
3962
  x = x.permute(0,4,5,1,2,3,6,7).reshape(bs, (oy := x.shape[4]), (ox := x.shape[5]), *cout_expand[0:2], 1, 1, rcin_hi, rcin_lo, H, W)
3655
3963
 
3656
3964
  # prepare weights
@@ -3702,5 +4010,5 @@ def _metadata_wrapper(fn):
3702
4010
 
3703
4011
  if TRACEMETA >= 1:
3704
4012
  for name, fn in inspect.getmembers(Tensor, inspect.isfunction):
3705
- if name in ["__class__", "__init__", "__new__", "__repr__", "backward", "sequential"]: continue
4013
+ if name in ["__class__", "__init__", "__new__", "__repr__", "backward", "sequential", "gradient"]: continue
3706
4014
  setattr(Tensor, name, functools.wraps(fn)(_metadata_wrapper(fn)))