tinygrad 0.10.1__py3-none-any.whl → 0.10.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/codegen/devectorizer.py +247 -0
- tinygrad/codegen/expander.py +121 -0
- tinygrad/codegen/kernel.py +35 -37
- tinygrad/codegen/linearize.py +19 -10
- tinygrad/codegen/lowerer.py +31 -8
- tinygrad/codegen/symbolic.py +476 -0
- tinygrad/codegen/transcendental.py +10 -0
- tinygrad/device.py +28 -11
- tinygrad/dtype.py +12 -3
- tinygrad/engine/jit.py +3 -2
- tinygrad/engine/multi.py +0 -1
- tinygrad/engine/realize.py +7 -4
- tinygrad/engine/schedule.py +227 -255
- tinygrad/engine/search.py +20 -27
- tinygrad/gradient.py +3 -0
- tinygrad/helpers.py +7 -4
- tinygrad/nn/state.py +2 -2
- tinygrad/ops.py +64 -329
- tinygrad/renderer/__init__.py +19 -3
- tinygrad/renderer/cstyle.py +39 -18
- tinygrad/renderer/llvmir.py +55 -18
- tinygrad/renderer/ptx.py +6 -2
- tinygrad/renderer/wgsl.py +20 -12
- tinygrad/runtime/autogen/libc.py +404 -71
- tinygrad/runtime/autogen/{libpciaccess.py → pci.py} +25 -715
- tinygrad/runtime/autogen/webgpu.py +6985 -0
- tinygrad/runtime/graph/metal.py +28 -29
- tinygrad/runtime/ops_amd.py +37 -34
- tinygrad/runtime/{ops_clang.py → ops_cpu.py} +4 -2
- tinygrad/runtime/ops_disk.py +1 -1
- tinygrad/runtime/ops_dsp.py +59 -33
- tinygrad/runtime/ops_llvm.py +14 -12
- tinygrad/runtime/ops_metal.py +78 -62
- tinygrad/runtime/ops_nv.py +9 -6
- tinygrad/runtime/ops_python.py +5 -5
- tinygrad/runtime/ops_webgpu.py +200 -38
- tinygrad/runtime/support/am/amdev.py +23 -11
- tinygrad/runtime/support/am/ip.py +10 -10
- tinygrad/runtime/support/elf.py +2 -0
- tinygrad/runtime/support/hcq.py +7 -5
- tinygrad/runtime/support/llvm.py +8 -14
- tinygrad/shape/shapetracker.py +3 -2
- tinygrad/shape/view.py +2 -3
- tinygrad/spec.py +21 -20
- tinygrad/tensor.py +150 -90
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
- tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
- tinygrad/viz/index.html +544 -0
- tinygrad/viz/perfetto.html +178 -0
- tinygrad/viz/serve.py +205 -0
- {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/METADATA +20 -8
- tinygrad-0.10.2.dist-info/RECORD +99 -0
- tinygrad/codegen/rewriter.py +0 -516
- tinygrad-0.10.1.dist-info/RECORD +0 -86
- {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
- {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +0 -0
- {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
tinygrad/tensor.py
CHANGED
@@ -2,7 +2,7 @@
 from __future__ import annotations
 import time, math, itertools, functools, struct, sys, inspect, pathlib, string, hashlib, weakref
 from contextlib import ContextDecorator
-from typing import …
+from typing import Callable, Optional, ClassVar, Union, Sequence, cast, get_args, Literal, TYPE_CHECKING, SupportsIndex
 from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
 from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
 from tinygrad.helpers import IMAGE, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap
@@ -68,7 +68,7 @@ def get_shape(x) -> tuple[int, ...]:
   if not all_same(subs:=[get_shape(xi) for xi in x]): raise ValueError(f"inhomogeneous shape from {x}")
   return (len(subs),) + (subs[0] if subs else ())
 
-def _frompy(x:Union[…
+def _frompy(x:Union[list, tuple, bytes], dtype:DType) -> UOp:
   if isinstance(x, bytes): ret, data = UOp.metaop(Ops.EMPTY, (len(x)//dtype.itemsize,), dtype, "PYTHON"), x
   else:
     ret = UOp.metaop(Ops.EMPTY, get_shape(x), dtype, "PYTHON")
@@ -131,13 +131,7 @@ class Tensor(SimpleMathTrait):
   training: ClassVar[bool] = False
   no_grad: ClassVar[bool] = False
 
-  def …
-    instance = super().__new__(cls)
-    all_tensors.add(weakref.ref(instance))
-    return instance
-  def __del__(self): all_tensors.discard(weakref.ref(self))
-
-  def __init__(self, data:Union[None, ConstType, bytes, List, Tuple, UOp, 'np.ndarray', pathlib.Path], # type: ignore [name-defined] # noqa: F821
+  def __init__(self, data:Union[None, ConstType, bytes, list, tuple, UOp, 'np.ndarray', pathlib.Path], # type: ignore [name-defined] # noqa: F821
                device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DTypeLike]=None, requires_grad:Optional[bool]=None):
     if dtype is not None: dtype = to_dtype(dtype)
     if device is None and isinstance(data, pathlib.Path): device = f"DISK:{data.resolve()}" # keep it on the disk if device is None
@@ -173,7 +167,7 @@ class Tensor(SimpleMathTrait):
       dtype = dtype or dtypes.uint8
       data = _metaop(Ops.EMPTY, (data.stat().st_size // dtype.itemsize,), dtype, f"DISK:{data.resolve()}")
 
-    # by this point, it has to be a …
+    # by this point, it has to be a UOp
     if not isinstance(data, UOp): raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}")
 
     # data might be on a different device
@@ -184,6 +178,19 @@ class Tensor(SimpleMathTrait):
       assert data.device == device, f"MultiLazyBuffer device mismatch, {data.device} != {device}"
     self.lazydata = data
 
+    # add to all_tensors after construction succeeds
+    all_tensors.add(weakref.ref(self))
+  def __del__(self): all_tensors.discard(weakref.ref(self))
+
+  def _apply_uop(self, fxn:Callable, *x:Tensor, **kwargs) -> Tensor:
+    new_uop: UOp = fxn(*[t.lazydata for t in (self,)+x], **kwargs)
+    needs_input_grad = [t.requires_grad for t in (self,)+x]
+    return Tensor(new_uop, device=new_uop.device, requires_grad=True if any(needs_input_grad) else None if None in needs_input_grad else False)
+
+  def _apply_broadcasted_uop(self, fxn:Callable, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
+    lhs,rhs = self._broadcasted(x, reverse)
+    return lhs._apply_uop(fxn, rhs)
+
   def requires_grad_(self, requires_grad=True) -> Tensor:
     self.requires_grad = requires_grad
     return self
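
Note: tensors returned by ops are now built through the normal `Tensor.__init__` path rather than `Tensor.__new__` plus manual attribute assignment, so `all_tensors` only ever tracks fully constructed tensors. A minimal sketch of the `requires_grad` propagation rule that `_apply_uop` keeps; the values are illustrative, not from the diff:

```python
from tinygrad import Tensor

a = Tensor([1.0, 2.0], requires_grad=True)
b = Tensor([3.0, 4.0])   # requires_grad defaults to None
c = a + b                # routed through _apply_uop internally
print(c.requires_grad)   # True: any input requiring grad makes the result require grad
```
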
@@ -221,17 +228,6 @@ class Tensor(SimpleMathTrait):
   @property
   def dtype(self) -> DType: return self.lazydata.dtype
 
-  def _apply_uop(self, fxn:Callable, *x:Tensor, **kwargs) -> Tensor:
-    ret = Tensor.__new__(Tensor)
-    needs_input_grad = [t.requires_grad for t in (self,)+x]
-    ret.requires_grad, ret.grad = True if any(needs_input_grad) else None if None in needs_input_grad else False, None
-    ret.lazydata = fxn(*[t.lazydata for t in (self,)+x], **kwargs)
-    return ret
-
-  def _apply_broadcasted_uop(self, fxn:Callable, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
-    lhs,rhs = self._broadcasted(x, reverse)
-    return lhs._apply_uop(fxn, rhs)
-
   # ***** data handlers ****
 
   def schedule_with_vars(self, *lst:Tensor) -> tuple[list[ScheduleItem], dict[Variable, int]]:
@@ -278,7 +274,7 @@ class Tensor(SimpleMathTrait):
   def assign(self, x) -> Tensor:
     # TODO: this is a hack for writing to DISK. remove with working assign
     if isinstance(self.device, str) and self.device.startswith("DISK"):
-      if x.__class__ is not Tensor: x = Tensor(x, device="…
+      if x.__class__ is not Tensor: x = Tensor(x, device="CPU", dtype=self.dtype)
       self.contiguous().realize().lazydata.base.realized.ensure_allocated().copyin(x._data())
       return self
     if x.__class__ is not Tensor: x = Tensor(x, device=self.device, dtype=self.dtype)
@@ -301,11 +297,11 @@ class Tensor(SimpleMathTrait):
   def _data(self) -> memoryview:
     if 0 in self.shape: return memoryview(bytearray(0))
     # NOTE: this realizes on the object from as_buffer being a Python object
-    cpu = self.cast(self.dtype.base).contiguous().to("…
+    cpu = self.cast(self.dtype.base).contiguous().to("CPU").realize()
     buf = cast(UOp, cpu.lazydata).base.realized
     assert buf is not None, f"{cast(UOp, cpu.lazydata).base} was not realized"
-    if self.device != "…
-    return buf.as_buffer(allow_zero_copy=True if self.device != "…
+    if self.device != "CPU": buf.options = BufferSpec(nolru=True)
+    return buf.as_buffer(allow_zero_copy=True if self.device != "CPU" else False)
 
   def data(self) -> memoryview:
     """
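
Note: these lines pin host-side staging to the `"CPU"` device; per the file list above, `ops_clang.py` was renamed to `ops_cpu.py` in this release. A small usage sketch of addressing that backend directly (illustrative values):

```python
from tinygrad import Tensor

t = Tensor([1, 2, 3], device="CPU")  # tinygrad's host CPU backend (ops_cpu.py, formerly ops_clang.py)
print(t.device, t.tolist())
```
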
@@ -333,16 +329,21 @@ class Tensor(SimpleMathTrait):
     assert self.numel() == 1, "must have one element for item"
     return self.data()[(0,) * len(self.shape)]
 
-  # TODO: should be Tensor.tolist() -> Union[list[ConstType], ConstType]. The …
+  # TODO: should be Tensor.tolist() -> Union[list[ConstType], ConstType]. The list is Sequence because mypy expects memoryview.tolist() -> list[int]
   # src: https://github.com/python/mypy/blob/release-1.6/mypy/typeshed/stdlib/builtins.pyi#L803
   def tolist(self) -> Union[Sequence[ConstType], ConstType]:
     """
     Returns the value of this tensor as a nested list.
+    Returns single value for const tensor.
 
     ```python exec="true" source="above" session="tensor" result="python"
     t = Tensor([1, 2, 3, 4])
     print(t.tolist())
     ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor(5)
+    print(t.tolist())
+    ```
     """
     return self.data().tolist()
 
@@ -519,8 +520,8 @@ class Tensor(SimpleMathTrait):
     if (numel := prod(shape)) == 0: return Tensor.zeros(shape, device=_device, dtype=dtype, **kwargs)
     num = ceildiv(numel * dtype.itemsize, 4)
 
-    # when using MOCKGPU and NV generate rand on …
-    if getenv("MOCKGPU") and device.startswith("NV"): device = "…
+    # when using MOCKGPU and NV generate rand on CPU
+    if getenv("MOCKGPU") and device.startswith("NV"): device = "CPU"
 
     # generate per device seeds and rng counter if we haven't seen this device yet
     if device not in Tensor._device_seeds:
@@ -1099,8 +1100,7 @@ class Tensor(SimpleMathTrait):
   def _getitem(self, indices, v: Optional[Tensor] = None) -> Tensor:
     # wrap single index into a list
     if (isinstance(indices, list) and all_int(indices)) or not isinstance(indices, (tuple, list)): indices = [indices]
-
-    x, indices = self, [self._to_const_val(i) if isinstance(i, Tensor) and i.shape == () else i for i in indices]
+    x, indices = self, list(indices)
 
     # filter ellipsis and fill with slice(None) or fill rest of indices with slice(None)
     if len(ellipsis_idx := [dim for dim, i in enumerate(indices) if i is Ellipsis]) > 1: raise IndexError("indices can only have a single ellipsis")
@@ -1117,7 +1117,7 @@ class Tensor(SimpleMathTrait):
         case list() | tuple() | Tensor():
           if not isinstance(index, Tensor): index = Tensor(index, self.device, requires_grad=False)
           if not dtypes.is_int(index.dtype): raise IndexError(f"index dtype {index.dtype} is not supported")
-          index = (index.to(self.device) < 0).where(size, …
+          index = (index.to(self.device) < 0).where(index+size, index) # treat negative index values
         case int() | UOp(): # sint
           if index >= size or index < -size: raise IndexError(f"{index=} is out of bounds with {size=}")
           boundary = [index, index+1] if index >= 0 else [index+size, index+size+1]
@@ -1190,7 +1190,7 @@ class Tensor(SimpleMathTrait):
     """
     Retrieve a sub-tensor using indexing.
 
-    Supported Index Types: `int | slice | Tensor | None | …
+    Supported Index Types: `int | slice | Tensor | None | list | tuple | Ellipsis`
 
     Examples:
     ```python exec="true" source="above" session="tensor" result="python"
@@ -1232,8 +1232,8 @@ class Tensor(SimpleMathTrait):
       return
     # NOTE: check that setitem target is valid first
     if not unwrap(self.lazydata.st).contiguous: raise RuntimeError("setitem target needs to be contiguous")
-    if …
-    if not isinstance(v, Tensor): …
+    if isinstance(v, get_args(ConstType)): v = Tensor(v, device=self.device, dtype=self.dtype)
+    if not isinstance(v, Tensor): raise TypeError(f"can't set a {type(v).__name__} to a Tensor")
     if self.requires_grad or v.requires_grad: raise NotImplementedError("setitem with requires_grad is not supported")
 
     res = self.realize()._getitem(indices, v)
@@ -1715,6 +1715,28 @@ class Tensor(SimpleMathTrait):
     """
     return self.logical_not().any(axis, keepdim).logical_not()
 
+  def isclose(self, other:Tensor, rtol:float=1e-05, atol:float=1e-08, equal_nan=False) -> Tensor:
+    """
+    Returns a new tensor with element-wise comparison of closeness to `other` within a tolerance.
+
+    The `rtol` and `atol` keyword arguments control the relative and absolute tolerance of the comparison.
+
+    By default, two `NaN` values are not close to each other. If `equal_nan` is `True`, two `NaN` values are considered close.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor([1e-7, 1e-8, 1e-9, float('nan')]).isclose(Tensor([0.0, 0.0, 0.0, float('nan')])).numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor([float('nan')]).isclose(Tensor([float('nan')]), equal_nan=True).numpy())
+    ```
+    """
+    # TODO: Tensor.isfinite
+    def isfinite(t): return (t.isinf()|t.isnan()).logical_not()
+    is_finite_close = isfinite(self) & isfinite(other) & ((self - other).abs() <= atol + rtol * other.abs())
+    is_infinite_close = (self.isinf() | other.isinf()) & (self == other)
+    is_nan_close = (self.isnan() & other.isnan()) & equal_nan
+    return is_finite_close | is_infinite_close | is_nan_close
+
   def mean(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
     """
     Returns the mean value of the tensor along the specified axis or axes.
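
Note: for finite inputs the new `isclose` reduces to the familiar `|self - other| <= atol + rtol * |other|` test, with matching infinities compared by equality and NaNs handled via `equal_nan`. A short usage sketch with illustrative values:

```python
from tinygrad import Tensor

a = Tensor([1.0, 2.0, float('inf')])
b = Tensor([1.0 + 1e-9, 2.1, float('inf')])
# finite elements: |a - b| <= atol + rtol * |b|; matching infinities count as close
print(a.isclose(b).tolist())  # [True, False, True]
```
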
@@ -1911,8 +1933,16 @@ class Tensor(SimpleMathTrait):
     print(t.logcumsumexp(axis=1).numpy())
     ```
     """
-    …
-    …
+    if self.ndim == 0: return self
+    axis = self._resolve_dim(axis)
+    x = self.transpose(axis, -1)
+    last_dim_size = x.shape[-1]
+    x_reshaped = x.reshape(-1, last_dim_size)
+    x_cummax = x_reshaped.cummax(-1).unsqueeze(-1)
+    x_expand = x_reshaped.unsqueeze(1).expand(*x_reshaped.shape, last_dim_size)
+    mask = Tensor.ones(last_dim_size, last_dim_size, requires_grad=False, device=self.device).tril().unsqueeze(0)
+    ret = ((x_expand - x_cummax).exp() * mask).sum(-1).log() + x_cummax.squeeze(-1)
+    return ret.reshape(*x.shape).transpose(-1, axis)
 
   def argmax(self, axis=None, keepdim=False):
     """
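
Note: the rewritten `logcumsumexp` subtracts the running maximum (`cummax`) before exponentiating, which keeps the intermediate `exp` values bounded. The quantity computed is `log(cumsum(exp(x)))` along the axis; a naive plain-Python reference (no stabilization), for comparison only:

```python
import math

def logcumsumexp_ref(xs):
  # log of the running sum of exp(x); overflows for large x, unlike the cummax-stabilized version above
  out, acc = [], 0.0
  for x in xs:
    acc += math.exp(x)
    out.append(math.log(acc))
  return out

print(logcumsumexp_ref([0.0, 1.0, 2.0]))
```
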
@@ -2020,7 +2050,7 @@ class Tensor(SimpleMathTrait):
     o_ = [ceildiv(i-d*(k-1), s) for i,d,k,s in zip(i_,d_,k_,s_)]
     if any(resolve(k > s) for k,s in zip(k_,s_)) or any(d != 1 for d in d_):
       # input size scaling factor to make sure shrink for stride is possible
-      f_ = [1 + int(resolve(o*s > i…
+      f_ = [1 + int(resolve(o*s > (i - d*(k-1)))) for o,s,i,d,k in zip(o_,s_,i_,d_,k_)]
       # # repeats such that we don't need padding
       x = self.repeat([1]*len(noop) + [ceildiv(k*(i*f+d),i) for k,i,d,f in zip(k_,i_,d_,f_)])
       # handle dilation
@@ -2041,7 +2071,7 @@ class Tensor(SimpleMathTrait):
       raise ValueError(f"Padding must be an int or a sequence of length {dims} or {2*dims}, but got {padding=} for {self.shape=} with {dims=}.")
     return [padding]*2*dims if isinstance(padding, int) else (padding if len(padding) == 2*dims else [p for p in padding for _ in range(2)][::-1])
 
-  def _apply_ceil_mode(self, pads:Sequence[int], k_:…
+  def _apply_ceil_mode(self, pads:Sequence[int], k_:tuple[sint, ...], s_:Union[tuple[int, ...], int], d_:Union[tuple[int, ...], int]) -> list[int]:
     (d_,s_), i_ = (make_tuple(x, len(k_)) for x in (d_,s_)), self.shape[-len(k_):]
     pads, grouped_pads = list(pads), _flat_to_grouped(pads)
     # https://arxiv.org/pdf/1603.07285 section 5.1, relationship 15.
@@ -2064,10 +2094,10 @@ class Tensor(SimpleMathTrait):
     1. `int` (single value):
       Applies the same padding value uniformly to all spatial dimensions.
 
-    2. `…
+    2. `tuple[int, ...]` (length = number of spatial dimensions):
       Specifies a distinct padding value for each spatial dimension in the form `(padding_height, padding_width, ...)`.
 
-    3. `…
+    3. `tuple[int, ...]` (length = 2 * number of spatial dimensions):
       Specifies explicit padding for each side of each spatial dimension in the form
       `(padding_left, padding_right, padding_top, padding_bottom, ...)`.
 
@@ -2111,10 +2141,10 @@ class Tensor(SimpleMathTrait):
     1. `int` (single value):
       Applies the same padding value uniformly to all spatial dimensions.
 
-    2. `…
+    2. `tuple[int, ...]` (length = number of spatial dimensions):
      Specifies a distinct padding value for each spatial dimension in the form `(padding_height, padding_width, ...)`.
 
-    3. `…
+    3. `tuple[int, ...]` (length = 2 * number of spatial dimensions):
       Specifies explicit padding for each side of each spatial dimension in the form
       `(padding_left, padding_right, padding_top, padding_bottom, ...)`.
 
@@ -2149,10 +2179,10 @@ class Tensor(SimpleMathTrait):
     1. `int` (single value):
       Applies the same padding value uniformly to all spatial dimensions.
 
-    2. `…
+    2. `tuple[int, ...]` (length = number of spatial dimensions):
       Specifies a distinct padding value for each spatial dimension in the form `(padding_height, padding_width, ...)`.
 
-    3. `…
+    3. `tuple[int, ...]` (length = 2 * number of spatial dimensions):
       Specifies explicit padding for each side of each spatial dimension in the form
       `(padding_left, padding_right, padding_top, padding_bottom, ...)`.
 
@@ -2222,10 +2252,10 @@ class Tensor(SimpleMathTrait):
     1. `int` (single value):
       Applies the same padding value uniformly to all spatial dimensions.
 
-    2. `…
+    2. `tuple[int, ...]` (length = number of spatial dimensions):
       Specifies a distinct padding value for each spatial dimension in the form `(padding_height, padding_width, ...)`.
 
-    3. `…
+    3. `tuple[int, ...]` (length = 2 * number of spatial dimensions):
       Specifies explicit padding for each side of each spatial dimension in the form
       `(padding_left, padding_right, padding_top, padding_bottom, ...)`.
 
@@ -2423,18 +2453,35 @@ class Tensor(SimpleMathTrait):
       reshape[i] = expand[i] = size[i]
       if mode == "linear":
         index = (scale*arr if align_corners else (scale*(arr+0.5))-0.5).clip(0, self.shape[i]-1)
-        low, high, perc = [y.reshape(reshape).expand(expand) for y in (index.floor(), index.ceil(), index - index.floor())]
+        low, high, perc = [y.reshape(reshape).expand(expand) for y in (index.floor().int(), index.ceil().int(), index - index.floor())]
         x = x.gather(i, low).lerp(x.gather(i, high), perc)
       else:
         index = (scale*(arr+0.5) if mode=="nearest-exact" else scale*arr).cast(dtypes.int32).reshape(reshape).expand(expand)
         x = x.gather(i, index)
     return x.cast(self.dtype)
 
+  def _pre_scatter(self, dim:int, index:Tensor, src:Tensor) -> tuple[Tensor, Tensor]:
+    index, dim = index.to(self.device), self._resolve_dim(dim)
+    assert index.ndim == self.ndim == src.ndim, f"self.ndim, index.ndim and src.dim must all equal, {self.ndim=} {index.ndim=} {src.ndim=}"
+    assert all((d == dim or self_ >= index_) and src_ >= index_ for d,(self_,index_,src_) in enumerate(zip(self.shape, index.shape, src.shape))), \
+      f"All dimensions of {index.shape=} should be <= to all dimensions of {src.shape=} and all dimensions except dimension {dim} of {self.shape=}"
+    if self.dtype != src.dtype: raise RuntimeError(f"expect {self.dtype=} to be equal to {src.dtype=}")
+    # shrink src to index shape to shrink away the unused values
+    src = src.shrink(tuple((0,s) for s in index.shape))
+    # prepare src and mask for reduce with respect to dim
+    src = src.unsqueeze(-1).expand(*src.shape, self.shape[dim]).transpose(-1, dim)
+    mask = index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim]).transpose(-1, dim)
+    # pad src and mask to self.shape so that reduce can be done with padded values as no-ops
+    src, mask = (x.pad(tuple((0, self.shape[i] - x.shape[i]) if i != dim else None for i in range(self.ndim)) + (None,)) for x in (src, mask))
+    return src, mask
+
   def scatter(self, dim:int, index:Tensor, src:Union[Tensor, ConstType], reduce:Union[None, Literal['multiply'], Literal['add']]=None) -> Tensor:
     """
     Scatters `src` values along an axis specified by `dim`.
     Apply `add` or `multiply` reduction operation with `reduce`.
 
+    NOTE: To use the `reduce` argument with a Tensor `src`, see `Tensor.scatter_reduce`.
+
     ```python exec="true" source="above" session="tensor" result="python"
     src = Tensor.arange(1, 11).reshape(2, 5)
     print(src.numpy())
@@ -2455,22 +2502,55 @@ class Tensor(SimpleMathTrait):
     ```
     """
     if reduce not in {None, "add", "multiply"}: raise TypeError(f"{reduce=} must be one of None, 'multiply', or 'add'")
-    …
-    …
-    …
-    …
-    …
-    # shrink src to index shape to shrink away the unused values
-    src = src.shrink(tuple((0,s) for s in index.shape))
-    # prepare src and mask for reduce with respect to dim
-    src = src.unsqueeze(-1).expand(*src.shape, self.shape[dim]).transpose(-1, dim)
-    mask = index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim]).transpose(-1, dim)
-    # pad src and mask to self.shape so that reduce can be done with padded values as no-ops
-    src, mask = (x.pad(tuple((0, self.shape[i] - x.shape[i]) if i != dim else None for i in range(self.ndim)) + (None,)) for x in (src, mask))
-    if reduce == "add": return mask.where(src, 0).sum(-1, acc_dtype=self.dtype) + self
-    if reduce == "multiply": return mask.where(src, 1).prod(-1, acc_dtype=self.dtype) * self
+    if reduce and isinstance(src, Tensor): raise TypeError("Tensor src is not supported with reduce arg. see scatter_reduce")
+    if not isinstance(src, Tensor): src = index.full_like(src, device=self.device, dtype=self.dtype)
+    if reduce == "add": return self.scatter_reduce(dim, index, src, "sum", include_self=True)
+    if reduce == "multiply": return self.scatter_reduce(dim, index, src, "prod", include_self=True)
+    src, mask = self._pre_scatter(dim, index, src)
     return _masked_setitem(self, src, mask, (-1,))
 
+  def scatter_reduce(self, dim:int, index:Tensor, src:Tensor, reduce:Literal["sum", "prod", "mean", "amax", "amin"],
+                     include_self:bool=True) -> Tensor:
+    """
+    Scatters `src` values along an axis specified by `dim`.
+    Apply `"sum"`, `"prod"`, `"mean"`, `"amax"`, or `"amin"` reduction operations with `reduce`.
+
+    Set `include_self=False` to exclude values in the `self` Tensor from the reduction.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    src = Tensor.arange(1, 11).cast(dtypes.float).reshape(2, 5)
+    print(src.numpy())
+    index = Tensor([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]])
+    print(index.numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor.ones(1, 5, dtype=src.dtype).scatter_reduce(0, index, src, reduce='sum').numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor.ones(1, 5, dtype=src.dtype).scatter_reduce(0, index, src, reduce='prod').numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor.ones(1, 5, dtype=src.dtype).scatter_reduce(0, index, src, reduce='mean', include_self=False).numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor([[-10, 20, 0, 5, 10]], dtype=src.dtype).scatter_reduce(0, index, src, reduce='amax').numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor([[-10, 20, 0, 5, 10]], dtype=src.dtype).scatter_reduce(0, index, src, reduce='amin').numpy())
+    ```
+    """
+    src, mask = self._pre_scatter(dim, index, src)
+    def _inv_mask(a:Union[Tensor, ConstType], b:Union[Tensor, ConstType]) -> Tensor: return mask.any(-1).logical_not().where(a, b)
+    # TODO: should not overwrite acc_dtype here?
+    if reduce == "sum": return mask.where(src, 0).sum(-1, acc_dtype=self.dtype).add(self if include_self else _inv_mask(self, 0))
+    if reduce == "prod": return mask.where(src, 1).prod(-1, acc_dtype=self.dtype).mul(self if include_self else _inv_mask(self, 1))
+    if reduce == "amax": return mask.where(src, m := dtypes.min(src.dtype)).max(-1).maximum(self if include_self else _inv_mask(self, m))
+    if reduce == "amin": return mask.where(src, m := dtypes.max(src.dtype)).min(-1).minimum(self if include_self else _inv_mask(self, m))
+    if reduce == "mean":
+      count = mask.where(1, 0).sum(-1, acc_dtype=self.dtype).add(1 if include_self else _inv_mask(1, 0))
+      return mask.where(src, 0).sum(-1, acc_dtype=self.dtype).add(self if include_self else _inv_mask(self, 0)).div(count)
+    raise RuntimeError(f"{reduce=} must be one of 'sum', 'prod', 'mean', 'amax', 'amin'")
+
   # ***** unary ops *****
 
   def logical_not(self):
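
Note: a small worked example of the new `scatter_reduce` in `"sum"` mode with `include_self=True`; shapes and values are illustrative:

```python
from tinygrad import Tensor

target = Tensor.ones(1, 5)
src = Tensor.arange(1, 11).float().reshape(2, 5)    # [[1..5], [6..10]]
index = Tensor([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]])  # every src element lands in row 0
# include_self=True: row 0 starts from the existing ones, then both src rows are summed in
print(target.scatter_reduce(0, index, src, reduce="sum").tolist())  # [[8.0, 10.0, 12.0, 14.0, 16.0]]
```
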
@@ -2600,7 +2680,7 @@ class Tensor(SimpleMathTrait):
     print(Tensor([1., 2., 3., 4.]).rsqrt().numpy())
     ```
     """
-    return self.…
+    return self.sqrt().reciprocal()
   def sin(self):
     """
     Computes the sine of the tensor element-wise.
@@ -3085,11 +3165,6 @@ class Tensor(SimpleMathTrait):
     # broadcast
     return x._broadcast_to(out_shape:=_broadcast_shape(x.shape, y.shape)), y._broadcast_to(out_shape)
 
-  # TODO: tensor should stop checking if things are const
-  def _to_const_val(self, x:Union[Tensor, ConstType]) -> Union[Tensor, ConstType]:
-    return x.lazydata.const_arg if isinstance(x, Tensor) and isinstance(x.lazydata, UOp) and x.lazydata.base.op is Ops.CONST \
-      and unwrap(x.lazydata.st).views[0].mask is None and not x.requires_grad and self._broadcasted(x)[0].shape == self.shape else x
-
   def add(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
     """
     Adds `self` and `x`.
@@ -3289,36 +3364,21 @@ class Tensor(SimpleMathTrait):
     Equivalent to `self ** x`.
 
     ```python exec="true" source="above" session="tensor" result="python"
-    print(Tensor([-1, 2, 3]).pow(2).numpy())
+    print(Tensor([-1, 2, 3]).pow(2.0).numpy())
     ```
     ```python exec="true" source="above" session="tensor" result="python"
     print(Tensor([-1, 2, 3]).pow(Tensor([-1.5, 0.5, 1.5])).numpy())
     ```
     ```python exec="true" source="above" session="tensor" result="python"
-    print((2 ** Tensor([-1, 2, 3])).numpy())
+    print((2.0 ** Tensor([-1, 2, 3])).numpy())
     ```
     """
-    x = self._to_const_val(x)
-    if not isinstance(x, Tensor) and not reverse:
-      # simple pow identities
-      if x < 0: return self.reciprocal().pow(-x).cast(self.dtype)
-      if x == 0: return 1 + self * 0
-      # rewrite pow 0.5 to sqrt
-      if int(x - 0.5) + 0.5 == x: return self.pow(int(x - 0.5)) * self.sqrt()
-      if int(x) == x: return self.pow(x // 2).square() * (1 if x % 2 == 0 else self)
-
-    # positive const ** self
-    if not isinstance(x, Tensor) and reverse and x > 0: return self.mul(math.log(x)).exp()
-
     base, exponent = self._broadcasted(x, reverse=reverse)
     # TODO: int pow
     if not base.is_floating_point(): raise RuntimeError("base needs to be float")
-    …
-    …
-    …
-    adj = (base < 0).detach().where((exponent != exponent.int()).detach().where(math.nan, (exponent.int()%2==1).where(-1, 1)), 1)
-    # fix 0 ** 0 = 1
-    ret = ((base == 0) * (exponent == 0)).detach().where(1, ret * adj)
+
+    # NOTE: pow(int, float) -> int
+    ret = base._apply_uop(UOp.pow, exponent)
     return ret.round().cast(self.dtype) if not dtypes.is_float(self.dtype) else ret
 
   def maximum(self, x:Union[Tensor, ConstType]) -> Tensor:
@@ -3332,9 +3392,7 @@ class Tensor(SimpleMathTrait):
     print(Tensor([-1, 2, 3]).maximum(Tensor([-4, -2, 9])).numpy())
     ```
     """
-    …
-    if self.is_floating_point(): return (self<x).detach().where(x, (self==x).detach().where(((self * 0.5 + x * 0.5).cast(self.dtype)), self))
-    return (self<x).detach().where(x, self)
+    return self._apply_broadcasted_uop(UOp.maximum, x)
 
   def minimum(self, x:Union[Tensor, ConstType]) -> Tensor:
     """
@@ -3500,6 +3558,7 @@ class Tensor(SimpleMathTrait):
 
   # helper function commonly used for indexing
   def _one_hot_along_dim(self:Tensor, num_classes:sint, dim:int=-1):
+    if not dtypes.is_int(self.dtype): raise RuntimeError(f"_one_hot_along_dim expects int index tensor, getting {self.dtype}")
     offset = self.ndim - self._resolve_dim(dim) - 1
     return self == Tensor.arange(num_classes, device=self.device, requires_grad=False).reshape((num_classes,) + (1,) * offset)
 
@@ -3514,6 +3573,7 @@ class Tensor(SimpleMathTrait):
     print(t.one_hot(5).numpy())
     ```
     """
+    if not dtypes.is_int(self.dtype): raise RuntimeError(f"expect integer dtype, getting {self.dtype=}")
    if num_classes == -1: num_classes = (self.max()+1).item()
     return self[..., None]._one_hot_along_dim(num_classes).where(1, 0)
 
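
Note: `pow` no longer special-cases constant exponents in the Tensor layer; everything lowers to a single `UOp.pow`, with integer-dtype results rounded back to the original dtype. Likewise, `one_hot` and `_one_hot_along_dim` now reject non-integer index dtypes up front. A minimal sketch of both behaviors (values illustrative):

```python
from tinygrad import Tensor

x = Tensor([1.0, 2.0, 3.0])
print(x.pow(2.0).tolist())   # [1.0, 4.0, 9.0]
print((2.0 ** x).tolist())   # reverse pow: [2.0, 4.0, 8.0]

print(Tensor([0, 2, 1]).one_hot(3).tolist())  # [[1, 0, 0], [0, 0, 1], [0, 1, 0]]
try:
  Tensor([0.0, 2.0]).one_hot(3)  # float indices now raise
except RuntimeError as e:
  print(e)
```
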