tinygrad 0.10.1__py3-none-any.whl → 0.10.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. tinygrad/codegen/devectorizer.py +247 -0
  2. tinygrad/codegen/expander.py +121 -0
  3. tinygrad/codegen/kernel.py +35 -37
  4. tinygrad/codegen/linearize.py +19 -10
  5. tinygrad/codegen/lowerer.py +31 -8
  6. tinygrad/codegen/symbolic.py +476 -0
  7. tinygrad/codegen/transcendental.py +10 -0
  8. tinygrad/device.py +28 -11
  9. tinygrad/dtype.py +12 -3
  10. tinygrad/engine/jit.py +3 -2
  11. tinygrad/engine/multi.py +0 -1
  12. tinygrad/engine/realize.py +7 -4
  13. tinygrad/engine/schedule.py +227 -255
  14. tinygrad/engine/search.py +20 -27
  15. tinygrad/gradient.py +3 -0
  16. tinygrad/helpers.py +7 -4
  17. tinygrad/nn/state.py +2 -2
  18. tinygrad/ops.py +64 -329
  19. tinygrad/renderer/__init__.py +19 -3
  20. tinygrad/renderer/cstyle.py +39 -18
  21. tinygrad/renderer/llvmir.py +55 -18
  22. tinygrad/renderer/ptx.py +6 -2
  23. tinygrad/renderer/wgsl.py +20 -12
  24. tinygrad/runtime/autogen/libc.py +404 -71
  25. tinygrad/runtime/autogen/{libpciaccess.py → pci.py} +25 -715
  26. tinygrad/runtime/autogen/webgpu.py +6985 -0
  27. tinygrad/runtime/graph/metal.py +28 -29
  28. tinygrad/runtime/ops_amd.py +37 -34
  29. tinygrad/runtime/{ops_clang.py → ops_cpu.py} +4 -2
  30. tinygrad/runtime/ops_disk.py +1 -1
  31. tinygrad/runtime/ops_dsp.py +59 -33
  32. tinygrad/runtime/ops_llvm.py +14 -12
  33. tinygrad/runtime/ops_metal.py +78 -62
  34. tinygrad/runtime/ops_nv.py +9 -6
  35. tinygrad/runtime/ops_python.py +5 -5
  36. tinygrad/runtime/ops_webgpu.py +200 -38
  37. tinygrad/runtime/support/am/amdev.py +23 -11
  38. tinygrad/runtime/support/am/ip.py +10 -10
  39. tinygrad/runtime/support/elf.py +2 -0
  40. tinygrad/runtime/support/hcq.py +7 -5
  41. tinygrad/runtime/support/llvm.py +8 -14
  42. tinygrad/shape/shapetracker.py +3 -2
  43. tinygrad/shape/view.py +2 -3
  44. tinygrad/spec.py +21 -20
  45. tinygrad/tensor.py +150 -90
  46. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
  47. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
  48. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
  49. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
  50. tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
  51. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
  52. tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
  53. tinygrad/viz/index.html +544 -0
  54. tinygrad/viz/perfetto.html +178 -0
  55. tinygrad/viz/serve.py +205 -0
  56. {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/METADATA +20 -8
  57. tinygrad-0.10.2.dist-info/RECORD +99 -0
  58. tinygrad/codegen/rewriter.py +0 -516
  59. tinygrad-0.10.1.dist-info/RECORD +0 -86
  60. {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
  61. {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +0 -0
  62. {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
tinygrad/tensor.py CHANGED
@@ -2,7 +2,7 @@
  from __future__ import annotations
  import time, math, itertools, functools, struct, sys, inspect, pathlib, string, hashlib, weakref
  from contextlib import ContextDecorator
- from typing import List, Tuple, Callable, Optional, ClassVar, Union, Sequence, cast, get_args, Literal, TYPE_CHECKING, SupportsIndex
+ from typing import Callable, Optional, ClassVar, Union, Sequence, cast, get_args, Literal, TYPE_CHECKING, SupportsIndex
  from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
  from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
  from tinygrad.helpers import IMAGE, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap
@@ -68,7 +68,7 @@ def get_shape(x) -> tuple[int, ...]:
    if not all_same(subs:=[get_shape(xi) for xi in x]): raise ValueError(f"inhomogeneous shape from {x}")
    return (len(subs),) + (subs[0] if subs else ())

- def _frompy(x:Union[List, Tuple, bytes], dtype:DType) -> UOp:
+ def _frompy(x:Union[list, tuple, bytes], dtype:DType) -> UOp:
    if isinstance(x, bytes): ret, data = UOp.metaop(Ops.EMPTY, (len(x)//dtype.itemsize,), dtype, "PYTHON"), x
    else:
      ret = UOp.metaop(Ops.EMPTY, get_shape(x), dtype, "PYTHON")
@@ -131,13 +131,7 @@ class Tensor(SimpleMathTrait):
    training: ClassVar[bool] = False
    no_grad: ClassVar[bool] = False

-   def __new__(cls, *args, **kwargs):
-     instance = super().__new__(cls)
-     all_tensors.add(weakref.ref(instance))
-     return instance
-   def __del__(self): all_tensors.discard(weakref.ref(self))
-
-   def __init__(self, data:Union[None, ConstType, bytes, List, Tuple, UOp, 'np.ndarray', pathlib.Path], # type: ignore [name-defined] # noqa: F821
+   def __init__(self, data:Union[None, ConstType, bytes, list, tuple, UOp, 'np.ndarray', pathlib.Path], # type: ignore [name-defined] # noqa: F821
                 device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DTypeLike]=None, requires_grad:Optional[bool]=None):
      if dtype is not None: dtype = to_dtype(dtype)
      if device is None and isinstance(data, pathlib.Path): device = f"DISK:{data.resolve()}" # keep it on the disk if device is None
@@ -173,7 +167,7 @@ class Tensor(SimpleMathTrait):
        dtype = dtype or dtypes.uint8
        data = _metaop(Ops.EMPTY, (data.stat().st_size // dtype.itemsize,), dtype, f"DISK:{data.resolve()}")

-     # by this point, it has to be a LazyBuffer
+     # by this point, it has to be a UOp
      if not isinstance(data, UOp): raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}")

      # data might be on a different device
@@ -184,6 +178,19 @@ class Tensor(SimpleMathTrait):
        assert data.device == device, f"MultiLazyBuffer device mismatch, {data.device} != {device}"
        self.lazydata = data

+     # add to all_tensors after construction succeeds
+     all_tensors.add(weakref.ref(self))
+   def __del__(self): all_tensors.discard(weakref.ref(self))
+
+   def _apply_uop(self, fxn:Callable, *x:Tensor, **kwargs) -> Tensor:
+     new_uop: UOp = fxn(*[t.lazydata for t in (self,)+x], **kwargs)
+     needs_input_grad = [t.requires_grad for t in (self,)+x]
+     return Tensor(new_uop, device=new_uop.device, requires_grad=True if any(needs_input_grad) else None if None in needs_input_grad else False)
+
+   def _apply_broadcasted_uop(self, fxn:Callable, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
+     lhs,rhs = self._broadcasted(x, reverse)
+     return lhs._apply_uop(fxn, rhs)
+
    def requires_grad_(self, requires_grad=True) -> Tensor:
      self.requires_grad = requires_grad
      return self
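Note on the hunk above: registration in `all_tensors` now happens at the end of `__init__` rather than in `__new__`, and `_apply_uop` builds its result through the normal `Tensor` constructor. A minimal plain-Python sketch of the `requires_grad` propagation rule encoded in the `_apply_uop` return line (a restatement for clarity, not tinygrad code):

```python
from typing import Optional

def propagate_requires_grad(flags: list[Optional[bool]]) -> Optional[bool]:
    # mirrors: True if any(needs_input_grad) else None if None in needs_input_grad else False
    if any(flags): return True      # any input that tracks gradients makes the output track them
    if None in flags: return None   # otherwise an undecided input leaves the output undecided
    return False                    # all inputs explicitly False -> output does not track gradients

assert propagate_requires_grad([True, None]) is True
assert propagate_requires_grad([None, False]) is None
assert propagate_requires_grad([False, False]) is False
```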
@@ -221,17 +228,6 @@ class Tensor(SimpleMathTrait):
    @property
    def dtype(self) -> DType: return self.lazydata.dtype

-   def _apply_uop(self, fxn:Callable, *x:Tensor, **kwargs) -> Tensor:
-     ret = Tensor.__new__(Tensor)
-     needs_input_grad = [t.requires_grad for t in (self,)+x]
-     ret.requires_grad, ret.grad = True if any(needs_input_grad) else None if None in needs_input_grad else False, None
-     ret.lazydata = fxn(*[t.lazydata for t in (self,)+x], **kwargs)
-     return ret
-
-   def _apply_broadcasted_uop(self, fxn:Callable, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
-     lhs,rhs = self._broadcasted(x, reverse)
-     return lhs._apply_uop(fxn, rhs)
-
    # ***** data handlers ****

    def schedule_with_vars(self, *lst:Tensor) -> tuple[list[ScheduleItem], dict[Variable, int]]:
@@ -278,7 +274,7 @@ class Tensor(SimpleMathTrait):
    def assign(self, x) -> Tensor:
      # TODO: this is a hack for writing to DISK. remove with working assign
      if isinstance(self.device, str) and self.device.startswith("DISK"):
-       if x.__class__ is not Tensor: x = Tensor(x, device="CLANG", dtype=self.dtype)
+       if x.__class__ is not Tensor: x = Tensor(x, device="CPU", dtype=self.dtype)
        self.contiguous().realize().lazydata.base.realized.ensure_allocated().copyin(x._data())
        return self
      if x.__class__ is not Tensor: x = Tensor(x, device=self.device, dtype=self.dtype)
@@ -301,11 +297,11 @@ class Tensor(SimpleMathTrait):
    def _data(self) -> memoryview:
      if 0 in self.shape: return memoryview(bytearray(0))
      # NOTE: this realizes on the object from as_buffer being a Python object
-     cpu = self.cast(self.dtype.base).contiguous().to("CLANG").realize()
+     cpu = self.cast(self.dtype.base).contiguous().to("CPU").realize()
      buf = cast(UOp, cpu.lazydata).base.realized
      assert buf is not None, f"{cast(UOp, cpu.lazydata).base} was not realized"
-     if self.device != "CLANG": buf.options = BufferSpec(nolru=True)
-     return buf.as_buffer(allow_zero_copy=True if self.device != "CLANG" else False)
+     if self.device != "CPU": buf.options = BufferSpec(nolru=True)
+     return buf.as_buffer(allow_zero_copy=True if self.device != "CPU" else False)

    def data(self) -> memoryview:
      """
@@ -333,16 +329,21 @@ class Tensor(SimpleMathTrait):
      assert self.numel() == 1, "must have one element for item"
      return self.data()[(0,) * len(self.shape)]

-   # TODO: should be Tensor.tolist() -> Union[list[ConstType], ConstType]. The List is Sequence because mypy expects memoryview.tolist() -> list[int]
+   # TODO: should be Tensor.tolist() -> Union[list[ConstType], ConstType]. The list is Sequence because mypy expects memoryview.tolist() -> list[int]
    # src: https://github.com/python/mypy/blob/release-1.6/mypy/typeshed/stdlib/builtins.pyi#L803
    def tolist(self) -> Union[Sequence[ConstType], ConstType]:
      """
      Returns the value of this tensor as a nested list.
+     Returns single value for const tensor.

      ```python exec="true" source="above" session="tensor" result="python"
      t = Tensor([1, 2, 3, 4])
      print(t.tolist())
      ```
+     ```python exec="true" source="above" session="tensor" result="python"
+     t = Tensor(5)
+     print(t.tolist())
+     ```
      """
      return self.data().tolist()

@@ -519,8 +520,8 @@ class Tensor(SimpleMathTrait):
      if (numel := prod(shape)) == 0: return Tensor.zeros(shape, device=_device, dtype=dtype, **kwargs)
      num = ceildiv(numel * dtype.itemsize, 4)

-     # when using MOCKGPU and NV generate rand on CLANG
-     if getenv("MOCKGPU") and device.startswith("NV"): device = "CLANG"
+     # when using MOCKGPU and NV generate rand on CPU
+     if getenv("MOCKGPU") and device.startswith("NV"): device = "CPU"

      # generate per device seeds and rng counter if we haven't seen this device yet
      if device not in Tensor._device_seeds:
@@ -1099,8 +1100,7 @@ class Tensor(SimpleMathTrait):
    def _getitem(self, indices, v: Optional[Tensor] = None) -> Tensor:
      # wrap single index into a list
      if (isinstance(indices, list) and all_int(indices)) or not isinstance(indices, (tuple, list)): indices = [indices]
-     # turn scalar Tensors into const val for int indexing if possible
-     x, indices = self, [self._to_const_val(i) if isinstance(i, Tensor) and i.shape == () else i for i in indices]
+     x, indices = self, list(indices)

      # filter ellipsis and fill with slice(None) or fill rest of indices with slice(None)
      if len(ellipsis_idx := [dim for dim, i in enumerate(indices) if i is Ellipsis]) > 1: raise IndexError("indices can only have a single ellipsis")
@@ -1117,7 +1117,7 @@ class Tensor(SimpleMathTrait):
        case list() | tuple() | Tensor():
          if not isinstance(index, Tensor): index = Tensor(index, self.device, requires_grad=False)
          if not dtypes.is_int(index.dtype): raise IndexError(f"index dtype {index.dtype} is not supported")
-         index = (index.to(self.device) < 0).where(size, 0) + index # treat negative index values
+         index = (index.to(self.device) < 0).where(index+size, index) # treat negative index values
        case int() | UOp(): # sint
          if index >= size or index < -size: raise IndexError(f"{index=} is out of bounds with {size=}")
          boundary = [index, index+1] if index >= 0 else [index+size, index+size+1]
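The negative-index handling above changes from `(index < 0).where(size, 0) + index` to `(index < 0).where(index+size, index)`; both map a negative index to `index + size` and leave non-negative indices unchanged, the new form just selects the already-shifted value instead of adding afterwards. A tiny sketch of the wrapping rule (a hypothetical helper, scalar case only):

```python
def wrap_index(i: int, size: int) -> int:
    # negative indices count from the end: -1 -> size-1, -size -> 0
    return i + size if i < 0 else i

assert wrap_index(-1, 5) == 4
assert wrap_index(0, 5) == 0
assert wrap_index(3, 5) == 3
```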
@@ -1190,7 +1190,7 @@ class Tensor(SimpleMathTrait):
      """
      Retrieve a sub-tensor using indexing.

-     Supported Index Types: `int | slice | Tensor | None | List | Tuple | Ellipsis`
+     Supported Index Types: `int | slice | Tensor | None | list | tuple | Ellipsis`

      Examples:
      ```python exec="true" source="above" session="tensor" result="python"
@@ -1232,8 +1232,8 @@ class Tensor(SimpleMathTrait):
        return
      # NOTE: check that setitem target is valid first
      if not unwrap(self.lazydata.st).contiguous: raise RuntimeError("setitem target needs to be contiguous")
-     if not isinstance(v, (Tensor, float, int, bool)): raise TypeError(f"can't set a {type(v).__name__} to a Tensor")
-     if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype)
+     if isinstance(v, get_args(ConstType)): v = Tensor(v, device=self.device, dtype=self.dtype)
+     if not isinstance(v, Tensor): raise TypeError(f"can't set a {type(v).__name__} to a Tensor")
      if self.requires_grad or v.requires_grad: raise NotImplementedError("setitem with requires_grad is not supported")

      res = self.realize()._getitem(indices, v)
@@ -1715,6 +1715,28 @@ class Tensor(SimpleMathTrait):
      """
      return self.logical_not().any(axis, keepdim).logical_not()

+   def isclose(self, other:Tensor, rtol:float=1e-05, atol:float=1e-08, equal_nan=False) -> Tensor:
+     """
+     Returns a new tensor with element-wise comparison of closeness to `other` within a tolerance.
+
+     The `rtol` and `atol` keyword arguments control the relative and absolute tolerance of the comparison.
+
+     By default, two `NaN` values are not close to each other. If `equal_nan` is `True`, two `NaN` values are considered close.
+
+     ```python exec="true" source="above" session="tensor" result="python"
+     print(Tensor([1e-7, 1e-8, 1e-9, float('nan')]).isclose(Tensor([0.0, 0.0, 0.0, float('nan')])).numpy())
+     ```
+     ```python exec="true" source="above" session="tensor" result="python"
+     print(Tensor([float('nan')]).isclose(Tensor([float('nan')]), equal_nan=True).numpy())
+     ```
+     """
+     # TODO: Tensor.isfinite
+     def isfinite(t): return (t.isinf()|t.isnan()).logical_not()
+     is_finite_close = isfinite(self) & isfinite(other) & ((self - other).abs() <= atol + rtol * other.abs())
+     is_infinite_close = (self.isinf() | other.isinf()) & (self == other)
+     is_nan_close = (self.isnan() & other.isnan()) & equal_nan
+     return is_finite_close | is_infinite_close | is_nan_close
+
    def mean(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
      """
      Returns the mean value of the tensor along the specified axis or axes.
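The new `isclose` above uses the usual `|self - other| <= atol + rtol * |other|` test for finite values, treats equal infinities as close, and only treats `NaN` as close to `NaN` when `equal_nan=True`. A short usage sketch beyond the docstring examples (assumes the default device can realize the computation):

```python
from tinygrad import Tensor

a = Tensor([1.0, 2.0, float('inf'), float('nan')])
b = Tensor([1.0 + 1e-9, 2.1, float('inf'), float('nan')])

print(a.isclose(b).tolist())                  # tiny diff close, 2.0 vs 2.1 not close, inf==inf close, nan not close
print(a.isclose(b, equal_nan=True).tolist())  # same, except nan vs nan is now close
```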
@@ -1911,8 +1933,16 @@ class Tensor(SimpleMathTrait):
      print(t.logcumsumexp(axis=1).numpy())
      ```
      """
-     m = self.max(axis=axis, keepdim=True)
-     return (self - m).exp().cumsum(axis=axis).log() + m
+     if self.ndim == 0: return self
+     axis = self._resolve_dim(axis)
+     x = self.transpose(axis, -1)
+     last_dim_size = x.shape[-1]
+     x_reshaped = x.reshape(-1, last_dim_size)
+     x_cummax = x_reshaped.cummax(-1).unsqueeze(-1)
+     x_expand = x_reshaped.unsqueeze(1).expand(*x_reshaped.shape, last_dim_size)
+     mask = Tensor.ones(last_dim_size, last_dim_size, requires_grad=False, device=self.device).tril().unsqueeze(0)
+     ret = ((x_expand - x_cummax).exp() * mask).sum(-1).log() + x_cummax.squeeze(-1)
+     return ret.reshape(*x.shape).transpose(-1, axis)

    def argmax(self, axis=None, keepdim=False):
      """
@@ -2020,7 +2050,7 @@ class Tensor(SimpleMathTrait):
      o_ = [ceildiv(i-d*(k-1), s) for i,d,k,s in zip(i_,d_,k_,s_)]
      if any(resolve(k > s) for k,s in zip(k_,s_)) or any(d != 1 for d in d_):
        # input size scaling factor to make sure shrink for stride is possible
-       f_ = [1 + int(resolve(o*s > i+d)) for o,s,i,d in zip(o_,s_,i_,d_)]
+       f_ = [1 + int(resolve(o*s > (i - d*(k-1)))) for o,s,i,d,k in zip(o_,s_,i_,d_,k_)]
        # # repeats such that we don't need padding
        x = self.repeat([1]*len(noop) + [ceildiv(k*(i*f+d),i) for k,i,d,f in zip(k_,i_,d_,f_)])
        # handle dilation
@@ -2041,7 +2071,7 @@ class Tensor(SimpleMathTrait):
        raise ValueError(f"Padding must be an int or a sequence of length {dims} or {2*dims}, but got {padding=} for {self.shape=} with {dims=}.")
      return [padding]*2*dims if isinstance(padding, int) else (padding if len(padding) == 2*dims else [p for p in padding for _ in range(2)][::-1])

-   def _apply_ceil_mode(self, pads:Sequence[int], k_:Tuple[sint, ...], s_:Union[Tuple[int, ...], int], d_:Union[Tuple[int, ...], int]) -> List[int]:
+   def _apply_ceil_mode(self, pads:Sequence[int], k_:tuple[sint, ...], s_:Union[tuple[int, ...], int], d_:Union[tuple[int, ...], int]) -> list[int]:
      (d_,s_), i_ = (make_tuple(x, len(k_)) for x in (d_,s_)), self.shape[-len(k_):]
      pads, grouped_pads = list(pads), _flat_to_grouped(pads)
      # https://arxiv.org/pdf/1603.07285 section 5.1, relationship 15.
@@ -2064,10 +2094,10 @@ class Tensor(SimpleMathTrait):
      1. `int` (single value):
        Applies the same padding value uniformly to all spatial dimensions.

-     2. `Tuple[int, ...]` (length = number of spatial dimensions):
+     2. `tuple[int, ...]` (length = number of spatial dimensions):
        Specifies a distinct padding value for each spatial dimension in the form `(padding_height, padding_width, ...)`.

-     3. `Tuple[int, ...]` (length = 2 * number of spatial dimensions):
+     3. `tuple[int, ...]` (length = 2 * number of spatial dimensions):
        Specifies explicit padding for each side of each spatial dimension in the form
        `(padding_left, padding_right, padding_top, padding_bottom, ...)`.

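As a quick illustration of the three accepted `padding` forms described in the docstring above (using `max_pool2d` as an assumed example of a method that takes this argument; only the output shapes are shown):

```python
from tinygrad import Tensor

x = Tensor.ones(1, 1, 4, 4)
print(x.max_pool2d(kernel_size=2, padding=0).shape)             # 1. int: same padding on every side
print(x.max_pool2d(kernel_size=2, padding=(1, 0)).shape)        # 2. one value per spatial dim: (pad_h, pad_w)
print(x.max_pool2d(kernel_size=2, padding=(1, 1, 0, 0)).shape)  # 3. one value per side: (left, right, top, bottom)
```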
@@ -2111,10 +2141,10 @@ class Tensor(SimpleMathTrait):
      1. `int` (single value):
        Applies the same padding value uniformly to all spatial dimensions.

-     2. `Tuple[int, ...]` (length = number of spatial dimensions):
+     2. `tuple[int, ...]` (length = number of spatial dimensions):
        Specifies a distinct padding value for each spatial dimension in the form `(padding_height, padding_width, ...)`.

-     3. `Tuple[int, ...]` (length = 2 * number of spatial dimensions):
+     3. `tuple[int, ...]` (length = 2 * number of spatial dimensions):
        Specifies explicit padding for each side of each spatial dimension in the form
        `(padding_left, padding_right, padding_top, padding_bottom, ...)`.

@@ -2149,10 +2179,10 @@ class Tensor(SimpleMathTrait):
      1. `int` (single value):
        Applies the same padding value uniformly to all spatial dimensions.

-     2. `Tuple[int, ...]` (length = number of spatial dimensions):
+     2. `tuple[int, ...]` (length = number of spatial dimensions):
        Specifies a distinct padding value for each spatial dimension in the form `(padding_height, padding_width, ...)`.

-     3. `Tuple[int, ...]` (length = 2 * number of spatial dimensions):
+     3. `tuple[int, ...]` (length = 2 * number of spatial dimensions):
        Specifies explicit padding for each side of each spatial dimension in the form
        `(padding_left, padding_right, padding_top, padding_bottom, ...)`.

@@ -2222,10 +2252,10 @@ class Tensor(SimpleMathTrait):
      1. `int` (single value):
        Applies the same padding value uniformly to all spatial dimensions.

-     2. `Tuple[int, ...]` (length = number of spatial dimensions):
+     2. `tuple[int, ...]` (length = number of spatial dimensions):
        Specifies a distinct padding value for each spatial dimension in the form `(padding_height, padding_width, ...)`.

-     3. `Tuple[int, ...]` (length = 2 * number of spatial dimensions):
+     3. `tuple[int, ...]` (length = 2 * number of spatial dimensions):
        Specifies explicit padding for each side of each spatial dimension in the form
        `(padding_left, padding_right, padding_top, padding_bottom, ...)`.

@@ -2423,18 +2453,35 @@ class Tensor(SimpleMathTrait):
          reshape[i] = expand[i] = size[i]
          if mode == "linear":
            index = (scale*arr if align_corners else (scale*(arr+0.5))-0.5).clip(0, self.shape[i]-1)
-           low, high, perc = [y.reshape(reshape).expand(expand) for y in (index.floor(), index.ceil(), index - index.floor())]
+           low, high, perc = [y.reshape(reshape).expand(expand) for y in (index.floor().int(), index.ceil().int(), index - index.floor())]
            x = x.gather(i, low).lerp(x.gather(i, high), perc)
          else:
            index = (scale*(arr+0.5) if mode=="nearest-exact" else scale*arr).cast(dtypes.int32).reshape(reshape).expand(expand)
            x = x.gather(i, index)
      return x.cast(self.dtype)

+   def _pre_scatter(self, dim:int, index:Tensor, src:Tensor) -> tuple[Tensor, Tensor]:
+     index, dim = index.to(self.device), self._resolve_dim(dim)
+     assert index.ndim == self.ndim == src.ndim, f"self.ndim, index.ndim and src.dim must all equal, {self.ndim=} {index.ndim=} {src.ndim=}"
+     assert all((d == dim or self_ >= index_) and src_ >= index_ for d,(self_,index_,src_) in enumerate(zip(self.shape, index.shape, src.shape))), \
+       f"All dimensions of {index.shape=} should be <= to all dimensions of {src.shape=} and all dimensions except dimension {dim} of {self.shape=}"
+     if self.dtype != src.dtype: raise RuntimeError(f"expect {self.dtype=} to be equal to {src.dtype=}")
+     # shrink src to index shape to shrink away the unused values
+     src = src.shrink(tuple((0,s) for s in index.shape))
+     # prepare src and mask for reduce with respect to dim
+     src = src.unsqueeze(-1).expand(*src.shape, self.shape[dim]).transpose(-1, dim)
+     mask = index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim]).transpose(-1, dim)
+     # pad src and mask to self.shape so that reduce can be done with padded values as no-ops
+     src, mask = (x.pad(tuple((0, self.shape[i] - x.shape[i]) if i != dim else None for i in range(self.ndim)) + (None,)) for x in (src, mask))
+     return src, mask
+
    def scatter(self, dim:int, index:Tensor, src:Union[Tensor, ConstType], reduce:Union[None, Literal['multiply'], Literal['add']]=None) -> Tensor:
      """
      Scatters `src` values along an axis specified by `dim`.
      Apply `add` or `multiply` reduction operation with `reduce`.

+     NOTE: To use the `reduce` argument with a Tensor `src`, see `Tensor.scatter_reduce`.
+
      ```python exec="true" source="above" session="tensor" result="python"
      src = Tensor.arange(1, 11).reshape(2, 5)
      print(src.numpy())
@@ -2455,22 +2502,55 @@ class Tensor(SimpleMathTrait):
      ```
      """
      if reduce not in {None, "add", "multiply"}: raise TypeError(f"{reduce=} must be one of None, 'multiply', or 'add'")
-     index, dim = index.to(self.device), self._resolve_dim(dim)
-     src = src.cast(self.dtype) if isinstance(src, Tensor) else Tensor(src, device=self.device, dtype=self.dtype)._broadcast_to(index.shape)
-     assert index.ndim == self.ndim == src.ndim, f"self.ndim, index.ndim and src.dim must all equal, {self.ndim=} {index.ndim=} {src.ndim=}"
-     assert all((d == dim or self_ >= index_) and src_ >= index_ for d,(self_,index_,src_) in enumerate(zip(self.shape, index.shape, src.shape))), \
-       f"All dimensions of {index.shape=} should be <= to all dimensions of {src.shape=} and all dimensions except dimension {dim} of {self.shape=}"
-     # shrink src to index shape to shrink away the unused values
-     src = src.shrink(tuple((0,s) for s in index.shape))
-     # prepare src and mask for reduce with respect to dim
-     src = src.unsqueeze(-1).expand(*src.shape, self.shape[dim]).transpose(-1, dim)
-     mask = index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim]).transpose(-1, dim)
-     # pad src and mask to self.shape so that reduce can be done with padded values as no-ops
-     src, mask = (x.pad(tuple((0, self.shape[i] - x.shape[i]) if i != dim else None for i in range(self.ndim)) + (None,)) for x in (src, mask))
-     if reduce == "add": return mask.where(src, 0).sum(-1, acc_dtype=self.dtype) + self
-     if reduce == "multiply": return mask.where(src, 1).prod(-1, acc_dtype=self.dtype) * self
+     if reduce and isinstance(src, Tensor): raise TypeError("Tensor src is not supported with reduce arg. see scatter_reduce")
+     if not isinstance(src, Tensor): src = index.full_like(src, device=self.device, dtype=self.dtype)
+     if reduce == "add": return self.scatter_reduce(dim, index, src, "sum", include_self=True)
+     if reduce == "multiply": return self.scatter_reduce(dim, index, src, "prod", include_self=True)
+     src, mask = self._pre_scatter(dim, index, src)
      return _masked_setitem(self, src, mask, (-1,))

+   def scatter_reduce(self, dim:int, index:Tensor, src:Tensor, reduce:Literal["sum", "prod", "mean", "amax", "amin"],
+                      include_self:bool=True) -> Tensor:
+     """
+     Scatters `src` values along an axis specified by `dim`.
+     Apply `"sum"`, `"prod"`, `"mean"`, `"amax"`, or `"amin"` reduction operations with `reduce`.
+
+     Set `include_self=False` to exclude values in the `self` Tensor from the reduction.
+
+     ```python exec="true" source="above" session="tensor" result="python"
+     src = Tensor.arange(1, 11).cast(dtypes.float).reshape(2, 5)
+     print(src.numpy())
+     index = Tensor([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]])
+     print(index.numpy())
+     ```
+     ```python exec="true" source="above" session="tensor" result="python"
+     print(Tensor.ones(1, 5, dtype=src.dtype).scatter_reduce(0, index, src, reduce='sum').numpy())
+     ```
+     ```python exec="true" source="above" session="tensor" result="python"
+     print(Tensor.ones(1, 5, dtype=src.dtype).scatter_reduce(0, index, src, reduce='prod').numpy())
+     ```
+     ```python exec="true" source="above" session="tensor" result="python"
+     print(Tensor.ones(1, 5, dtype=src.dtype).scatter_reduce(0, index, src, reduce='mean', include_self=False).numpy())
+     ```
+     ```python exec="true" source="above" session="tensor" result="python"
+     print(Tensor([[-10, 20, 0, 5, 10]], dtype=src.dtype).scatter_reduce(0, index, src, reduce='amax').numpy())
+     ```
+     ```python exec="true" source="above" session="tensor" result="python"
+     print(Tensor([[-10, 20, 0, 5, 10]], dtype=src.dtype).scatter_reduce(0, index, src, reduce='amin').numpy())
+     ```
+     """
+     src, mask = self._pre_scatter(dim, index, src)
+     def _inv_mask(a:Union[Tensor, ConstType], b:Union[Tensor, ConstType]) -> Tensor: return mask.any(-1).logical_not().where(a, b)
+     # TODO: should not overwrite acc_dtype here?
+     if reduce == "sum": return mask.where(src, 0).sum(-1, acc_dtype=self.dtype).add(self if include_self else _inv_mask(self, 0))
+     if reduce == "prod": return mask.where(src, 1).prod(-1, acc_dtype=self.dtype).mul(self if include_self else _inv_mask(self, 1))
+     if reduce == "amax": return mask.where(src, m := dtypes.min(src.dtype)).max(-1).maximum(self if include_self else _inv_mask(self, m))
+     if reduce == "amin": return mask.where(src, m := dtypes.max(src.dtype)).min(-1).minimum(self if include_self else _inv_mask(self, m))
+     if reduce == "mean":
+       count = mask.where(1, 0).sum(-1, acc_dtype=self.dtype).add(1 if include_self else _inv_mask(1, 0))
+       return mask.where(src, 0).sum(-1, acc_dtype=self.dtype).add(self if include_self else _inv_mask(self, 0)).div(count)
+     raise RuntimeError(f"{reduce=} must be one of 'sum', 'prod', 'mean', 'amax', 'amin'")
+
    # ***** unary ops *****

    def logical_not(self):
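The `_pre_scatter` helper introduced above expresses scatter as a one-hot mask plus a reduce over an extra trailing axis, and `scatter(..., reduce=...)` now routes to the new `scatter_reduce`. A rough plain-Python sketch of what a sum-reduce scatter computes in the 1-D case (illustrative only, not the tinygrad implementation):

```python
def scatter_sum_1d(self_vals: list[float], index: list[int], src: list[float]) -> list[float]:
    out = list(self_vals)
    for j in range(len(out)):
        # the one-hot mask picks out the src entries whose index targets slot j
        out[j] += sum(s for i, s in zip(index, src) if i == j)
    return out

# matches scatter(0, index, src, reduce="add") semantics on a 1-D tensor
print(scatter_sum_1d([0.0, 0.0, 0.0], index=[0, 2, 0], src=[1.0, 2.0, 3.0]))  # [4.0, 0.0, 2.0]
```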
@@ -2600,7 +2680,7 @@ class Tensor(SimpleMathTrait):
      print(Tensor([1., 2., 3., 4.]).rsqrt().numpy())
      ```
      """
-     return self.reciprocal().sqrt()
+     return self.sqrt().reciprocal()
    def sin(self):
      """
      Computes the sine of the tensor element-wise.
@@ -3085,11 +3165,6 @@ class Tensor(SimpleMathTrait):
      # broadcast
      return x._broadcast_to(out_shape:=_broadcast_shape(x.shape, y.shape)), y._broadcast_to(out_shape)

-   # TODO: tensor should stop checking if things are const
-   def _to_const_val(self, x:Union[Tensor, ConstType]) -> Union[Tensor, ConstType]:
-     return x.lazydata.const_arg if isinstance(x, Tensor) and isinstance(x.lazydata, UOp) and x.lazydata.base.op is Ops.CONST \
-       and unwrap(x.lazydata.st).views[0].mask is None and not x.requires_grad and self._broadcasted(x)[0].shape == self.shape else x
-
    def add(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
      """
      Adds `self` and `x`.
@@ -3289,36 +3364,21 @@ class Tensor(SimpleMathTrait):
      Equivalent to `self ** x`.

      ```python exec="true" source="above" session="tensor" result="python"
-     print(Tensor([-1, 2, 3]).pow(2).numpy())
+     print(Tensor([-1, 2, 3]).pow(2.0).numpy())
      ```
      ```python exec="true" source="above" session="tensor" result="python"
      print(Tensor([-1, 2, 3]).pow(Tensor([-1.5, 0.5, 1.5])).numpy())
      ```
      ```python exec="true" source="above" session="tensor" result="python"
-     print((2 ** Tensor([-1, 2, 3])).numpy())
+     print((2.0 ** Tensor([-1, 2, 3])).numpy())
      ```
      """
-     x = self._to_const_val(x)
-     if not isinstance(x, Tensor) and not reverse:
-       # simple pow identities
-       if x < 0: return self.reciprocal().pow(-x).cast(self.dtype)
-       if x == 0: return 1 + self * 0
-       # rewrite pow 0.5 to sqrt
-       if int(x - 0.5) + 0.5 == x: return self.pow(int(x - 0.5)) * self.sqrt()
-       if int(x) == x: return self.pow(x // 2).square() * (1 if x % 2 == 0 else self)
-
-     # positive const ** self
-     if not isinstance(x, Tensor) and reverse and x > 0: return self.mul(math.log(x)).exp()
-
      base, exponent = self._broadcasted(x, reverse=reverse)
      # TODO: int pow
      if not base.is_floating_point(): raise RuntimeError("base needs to be float")
-     # start with b ** e = exp(e * log(b))
-     ret = base.abs().log().mul(exponent).exp()
-     # negative base adjustment: nan for non-integer exponent and -1 for odd exponent
-     adj = (base < 0).detach().where((exponent != exponent.int()).detach().where(math.nan, (exponent.int()%2==1).where(-1, 1)), 1)
-     # fix 0 ** 0 = 1
-     ret = ((base == 0) * (exponent == 0)).detach().where(1, ret * adj)
+
+     # NOTE: pow(int, float) -> int
+     ret = base._apply_uop(UOp.pow, exponent)
      return ret.round().cast(self.dtype) if not dtypes.is_float(self.dtype) else ret

    def maximum(self, x:Union[Tensor, ConstType]) -> Tensor:
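The `pow` body above now delegates to a single `UOp.pow` instead of the explicit `exp(e * log(|b|))` decomposition with its negative-base and `0 ** 0` adjustments; the docstring examples are unchanged apart from using float exponents. A brief usage sketch along the lines of those examples (assuming the documented semantics carry over from the previous release):

```python
from tinygrad import Tensor

print(Tensor([-1.0, 2.0, 3.0]).pow(2.0).numpy())   # squaring, including a negative base
print(Tensor([4.0, 9.0]).pow(0.5).numpy())         # fractional exponent acts like sqrt
print((2.0 ** Tensor([-1.0, 2.0, 3.0])).numpy())   # reverse pow: scalar base, tensor exponent
```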
@@ -3332,9 +3392,7 @@ class Tensor(SimpleMathTrait):
      print(Tensor([-1, 2, 3]).maximum(Tensor([-4, -2, 9])).numpy())
      ```
      """
-     # NOTE: the mid-point is for backward, revisit after new gradient API
-     if self.is_floating_point(): return (self<x).detach().where(x, (self==x).detach().where(((self * 0.5 + x * 0.5).cast(self.dtype)), self))
-     return (self<x).detach().where(x, self)
+     return self._apply_broadcasted_uop(UOp.maximum, x)

    def minimum(self, x:Union[Tensor, ConstType]) -> Tensor:
      """
@@ -3500,6 +3558,7 @@ class Tensor(SimpleMathTrait):

    # helper function commonly used for indexing
    def _one_hot_along_dim(self:Tensor, num_classes:sint, dim:int=-1):
+     if not dtypes.is_int(self.dtype): raise RuntimeError(f"_one_hot_along_dim expects int index tensor, getting {self.dtype}")
      offset = self.ndim - self._resolve_dim(dim) - 1
      return self == Tensor.arange(num_classes, device=self.device, requires_grad=False).reshape((num_classes,) + (1,) * offset)

@@ -3514,6 +3573,7 @@ class Tensor(SimpleMathTrait):
      print(t.one_hot(5).numpy())
      ```
      """
+     if not dtypes.is_int(self.dtype): raise RuntimeError(f"expect integer dtype, getting {self.dtype=}")
      if num_classes == -1: num_classes = (self.max()+1).item()
      return self[..., None]._one_hot_along_dim(num_classes).where(1, 0)
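Both `_one_hot_along_dim` and `one_hot` above now reject non-integer index tensors up front. A short sketch of the intended integer-index path, in the style of the docstring example (the commented line shows the kind of input the new check rules out):

```python
from tinygrad import Tensor

t = Tensor([0, 2, 1])                 # integer dtype, as the new check requires
print(t.one_hot(3).numpy())           # [[1,0,0],[0,0,1],[0,1,0]]
# Tensor([0.0, 2.0, 1.0]).one_hot(3)  # float indices would now raise a RuntimeError
```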