tinygrad 0.10.0__py3-none-any.whl → 0.10.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/codegen/kernel.py +114 -172
- tinygrad/codegen/linearize.py +211 -81
- tinygrad/codegen/lowerer.py +30 -35
- tinygrad/codegen/{uopgraph.py → rewriter.py} +69 -59
- tinygrad/codegen/transcendental.py +12 -13
- tinygrad/device.py +170 -47
- tinygrad/dtype.py +28 -26
- tinygrad/engine/jit.py +80 -63
- tinygrad/engine/memory.py +4 -5
- tinygrad/engine/multi.py +162 -0
- tinygrad/engine/realize.py +58 -107
- tinygrad/engine/schedule.py +381 -314
- tinygrad/engine/search.py +40 -44
- tinygrad/gradient.py +70 -0
- tinygrad/helpers.py +77 -58
- tinygrad/nn/__init__.py +30 -32
- tinygrad/nn/datasets.py +1 -2
- tinygrad/nn/optim.py +22 -26
- tinygrad/nn/state.py +89 -64
- tinygrad/ops.py +562 -446
- tinygrad/renderer/__init__.py +79 -36
- tinygrad/renderer/cstyle.py +70 -84
- tinygrad/renderer/llvmir.py +32 -20
- tinygrad/renderer/ptx.py +79 -99
- tinygrad/renderer/wgsl.py +87 -0
- tinygrad/runtime/autogen/amd_gpu.py +39507 -12
- tinygrad/runtime/autogen/comgr.py +2 -0
- tinygrad/runtime/autogen/kfd.py +4 -3
- tinygrad/runtime/autogen/kgsl.py +1 -1
- tinygrad/runtime/autogen/libpciaccess.py +2023 -0
- tinygrad/runtime/autogen/llvm.py +11379 -0
- tinygrad/runtime/autogen/vfio.py +891 -0
- tinygrad/runtime/graph/cuda.py +8 -9
- tinygrad/runtime/graph/hcq.py +84 -79
- tinygrad/runtime/graph/metal.py +19 -21
- tinygrad/runtime/ops_amd.py +488 -327
- tinygrad/runtime/ops_clang.py +15 -28
- tinygrad/runtime/ops_cloud.py +34 -34
- tinygrad/runtime/ops_cuda.py +30 -27
- tinygrad/runtime/ops_disk.py +62 -63
- tinygrad/runtime/ops_dsp.py +129 -38
- tinygrad/runtime/ops_gpu.py +30 -30
- tinygrad/runtime/ops_hip.py +29 -31
- tinygrad/runtime/ops_llvm.py +45 -40
- tinygrad/runtime/ops_metal.py +93 -73
- tinygrad/runtime/ops_npy.py +2 -2
- tinygrad/runtime/ops_nv.py +232 -270
- tinygrad/runtime/ops_python.py +51 -46
- tinygrad/runtime/ops_qcom.py +129 -157
- tinygrad/runtime/ops_webgpu.py +63 -0
- tinygrad/runtime/support/allocator.py +94 -0
- tinygrad/runtime/support/am/__init__.py +0 -0
- tinygrad/runtime/support/am/amdev.py +384 -0
- tinygrad/runtime/support/am/ip.py +463 -0
- tinygrad/runtime/support/compiler_cuda.py +4 -2
- tinygrad/runtime/support/elf.py +26 -4
- tinygrad/runtime/support/hcq.py +254 -324
- tinygrad/runtime/support/llvm.py +32 -0
- tinygrad/shape/shapetracker.py +84 -53
- tinygrad/shape/view.py +103 -138
- tinygrad/spec.py +154 -0
- tinygrad/tensor.py +744 -496
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/METADATA +32 -21
- tinygrad-0.10.1.dist-info/RECORD +86 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/WHEEL +1 -1
- tinygrad/engine/lazy.py +0 -228
- tinygrad/function.py +0 -212
- tinygrad/multi.py +0 -177
- tinygrad/runtime/graph/clang.py +0 -39
- tinygrad-0.10.0.dist-info/RECORD +0 -77
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/LICENSE +0 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/top_level.txt +0 -0
tinygrad/tensor.py
CHANGED
@@ -1,47 +1,53 @@
|
|
1
1
|
# inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
|
2
2
|
from __future__ import annotations
|
3
|
-
import time, math, itertools, functools, struct, sys, inspect, pathlib, string,
|
3
|
+
import time, math, itertools, functools, struct, sys, inspect, pathlib, string, hashlib, weakref
|
4
4
|
from contextlib import ContextDecorator
|
5
|
-
from typing import List, Tuple, Callable, Optional, ClassVar,
|
6
|
-
from collections import defaultdict
|
7
|
-
|
5
|
+
from typing import List, Tuple, Callable, Optional, ClassVar, Union, Sequence, cast, get_args, Literal, TYPE_CHECKING, SupportsIndex
|
8
6
|
from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
|
9
7
|
from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
|
10
|
-
from tinygrad.helpers import IMAGE,
|
11
|
-
from tinygrad.multi import
|
12
|
-
from tinygrad.
|
13
|
-
from tinygrad.
|
14
|
-
from tinygrad.
|
8
|
+
from tinygrad.helpers import IMAGE, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap
|
9
|
+
from tinygrad.engine.multi import get_multi_map
|
10
|
+
from tinygrad.gradient import compute_gradient
|
11
|
+
from tinygrad.ops import smax, smin, resolve, UOp, Ops, sint, Variable, SimpleMathTrait, identity_element
|
12
|
+
from tinygrad.spec import tensor_uop_spec, type_verify
|
13
|
+
from tinygrad.device import Device, BufferSpec
|
15
14
|
from tinygrad.engine.realize import run_schedule
|
16
15
|
from tinygrad.engine.memory import memory_planner
|
17
16
|
from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars
|
18
17
|
|
19
|
-
#
|
18
|
+
# *** all in scope Tensors are here. this gets relevant UOps ***
|
20
19
|
|
21
|
-
|
22
|
-
def __init__(self, device:Union[str, Tuple[str, ...]], *tensors:Tensor, metadata:Optional[Metadata]=None):
|
23
|
-
self.device = device
|
24
|
-
self.needs_input_grad = [t.requires_grad for t in tensors]
|
25
|
-
self.requires_grad = True if any(self.needs_input_grad) else None if None in self.needs_input_grad else False
|
26
|
-
if self.requires_grad: self.parents = tensors
|
27
|
-
self.metadata = metadata
|
20
|
+
all_tensors: set[weakref.ref[Tensor]] = set()
|
28
21
|
|
29
|
-
|
30
|
-
|
22
|
+
def _apply_map_to_tensors(applied_map:dict[UOp, UOp]) -> None:
|
23
|
+
# get all children of keys in applied_map
|
24
|
+
all_uops: set[UOp] = set()
|
25
|
+
search_uops = list(applied_map)
|
26
|
+
while len(search_uops):
|
27
|
+
x = search_uops.pop(0)
|
28
|
+
if x in all_uops: continue
|
29
|
+
all_uops.add(x)
|
30
|
+
search_uops.extend([u for c in x.children if (u:=c()) is not None])
|
31
31
|
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
32
|
+
# link the found UOps back to Tensors. exit early if there's no Tensors to realize
|
33
|
+
# NOTE: this uses all_tensors, but it's fast
|
34
|
+
fixed_tensors: list[Tensor] = [t for tref in all_tensors if (t:=tref()) is not None and t.lazydata in all_uops]
|
35
|
+
|
36
|
+
if len(fixed_tensors):
|
37
|
+
# potentially rewrite all the discovered Tensors
|
38
|
+
sink = UOp.sink(*[t.lazydata for t in fixed_tensors])
|
39
|
+
new_sink = sink.substitute(applied_map)
|
39
40
|
|
40
|
-
|
41
|
+
# set the relevant lazydata to the realized UOps
|
42
|
+
for t,s,ns in zip(fixed_tensors, sink.src, new_sink.src):
|
43
|
+
if s is ns: continue
|
44
|
+
t.lazydata = ns
|
41
45
|
|
42
|
-
|
43
|
-
|
44
|
-
|
46
|
+
# **** Tensor helper functions ****
|
47
|
+
|
48
|
+
def _metaop(op, shape:tuple[sint,...], dtype:DType, device:Union[str, tuple[str, ...]], arg=None):
|
49
|
+
if isinstance(device, str): return UOp.metaop(op, shape, dtype, device, arg)
|
50
|
+
return UOp.multi(*[UOp.metaop(op, shape, dtype, d, arg) for d in device], axis=None)
|
45
51
|
|
46
52
|
def _from_np_dtype(npdtype:'np.dtype') -> DType: # type: ignore [name-defined] # noqa: F821
|
47
53
|
import numpy as np
|
@@ -50,33 +56,31 @@ def _to_np_dtype(dtype:DType) -> Optional[type]:
|
|
50
56
|
import numpy as np
|
51
57
|
return np.dtype(dtype.fmt).type if dtype.fmt is not None else None
|
52
58
|
|
53
|
-
def _fromnp(x: 'np.ndarray') ->
|
54
|
-
ret =
|
59
|
+
def _fromnp(x: 'np.ndarray') -> UOp: # type: ignore [name-defined] # noqa: F821
|
60
|
+
ret = UOp.metaop(Ops.EMPTY, x.shape, _from_np_dtype(x.dtype), "NPY")
|
55
61
|
# fake realize
|
56
62
|
ret.buffer.allocate(x)
|
57
|
-
del ret.srcs
|
58
63
|
return ret
|
59
64
|
|
60
|
-
def get_shape(x) ->
|
65
|
+
def get_shape(x) -> tuple[int, ...]:
|
61
66
|
# NOTE: str is special because __getitem__ on a str is still a str
|
62
67
|
if not hasattr(x, "__len__") or not hasattr(x, "__getitem__") or isinstance(x, str) or (hasattr(x, "shape") and x.shape == ()): return ()
|
63
68
|
if not all_same(subs:=[get_shape(xi) for xi in x]): raise ValueError(f"inhomogeneous shape from {x}")
|
64
69
|
return (len(subs),) + (subs[0] if subs else ())
|
65
70
|
|
66
|
-
def _frompy(x:Union[List, Tuple, bytes], dtype:DType) ->
|
67
|
-
if isinstance(x, bytes): ret, data =
|
71
|
+
def _frompy(x:Union[List, Tuple, bytes], dtype:DType) -> UOp:
|
72
|
+
if isinstance(x, bytes): ret, data = UOp.metaop(Ops.EMPTY, (len(x)//dtype.itemsize,), dtype, "PYTHON"), x
|
68
73
|
else:
|
69
|
-
ret =
|
74
|
+
ret = UOp.metaop(Ops.EMPTY, get_shape(x), dtype, "PYTHON")
|
70
75
|
assert dtype.fmt is not None, f"{dtype=} has None fmt"
|
71
76
|
truncate_function = truncate[dtype]
|
72
77
|
data = struct.pack(f"@{ret.size}{dtype.fmt}", *[truncate_function(xi) for xi in fully_flatten(x)])
|
73
78
|
# fake realize
|
74
79
|
ret.buffer.allocate(memoryview(data if Device.DEFAULT != "PYTHON" else bytearray(data)))
|
75
|
-
del ret.srcs
|
76
80
|
return ret
|
77
81
|
|
78
|
-
def _get_winograd_matcols(mat, dims:int, shp:
|
79
|
-
return [[Tensor.cat(*[Tensor.full(shp[:dim] + (1,) + shp[dim+1:], float(m[k]), device=device) for m in mat], dim=dim)
|
82
|
+
def _get_winograd_matcols(mat, dims:int, shp:tuple[sint, ...], device:Union[str, tuple[str, ...]], dtype:DType) -> list[list[Tensor]]:
|
83
|
+
return [[Tensor.cat(*[Tensor.full(shp[:dim] + (1,) + shp[dim+1:], float(m[k]), device=device, dtype=dtype) for m in mat], dim=dim)
|
80
84
|
for k in range(len(mat[0]))] for dim in range(dims)]
|
81
85
|
|
82
86
|
# winograd conv 3 kernel f(4x4,3x3) see: http://arxiv.org/abs/1509.09308
|
@@ -85,21 +89,34 @@ def _apply_winograd_matrix(mat, t:Tensor, dims:int) -> Tensor:
|
|
85
89
|
# due to realize-before-expand rule in lazy.py, we must operate in this order: reshape -> expand -> arithmetic
|
86
90
|
t_ = t.reshape(t.shape[:dims] + (1,) * dims + t.shape[dims:]).expand(t.shape[:dims] + (len(mat),) * dims + t.shape[dims:]) # add output dims
|
87
91
|
# precalculate mat columns for each dim; prod(itertools.product(matcols)) gives the columns of kron(mat, mat, ...)
|
88
|
-
matcols = _get_winograd_matcols(mat, dims, t_.shape[dims:], t_.device)
|
92
|
+
matcols = _get_winograd_matcols(mat, dims, t_.shape[dims:], t_.device, t_.dtype)
|
89
93
|
# multiply each element of t_ by the corresponding stacked column of kron(mat, mat), producing only one view for each element of t
|
90
94
|
ret = sum(prod(col[idx] for col, idx in zip(matcols, mat_is)) * t_[mat_is] for mat_is in itertools.product(range(len(mat[0])), repeat=dims))
|
91
95
|
assert isinstance(ret, Tensor), "sum didn't return a Tensor"
|
92
96
|
return ret
|
93
97
|
|
94
|
-
def
|
98
|
+
def _align_left(*shapes:tuple[sint, ...]) -> tuple[tuple[sint, ...], ...]:
|
99
|
+
# unsqueeze left to make every shape same length
|
95
100
|
max_dim = max(len(shape) for shape in shapes)
|
96
101
|
return tuple((1,) * (max_dim - len(shape)) + shape for shape in shapes)
|
97
|
-
def _broadcast_shape(*shapes:
|
98
|
-
return tuple(0 if 0 in nth_dim_sizes else smax(nth_dim_sizes) for nth_dim_sizes in zip(*
|
102
|
+
def _broadcast_shape(*shapes:tuple[sint, ...]) -> tuple[sint, ...]:
|
103
|
+
return tuple(0 if 0 in nth_dim_sizes else smax(nth_dim_sizes) for nth_dim_sizes in zip(*_align_left(*shapes)))
|
104
|
+
|
105
|
+
def _masked_setitem(target:Tensor, values:Tensor, mask:Tensor, axes:tuple[int, ...]):
|
106
|
+
# apply mask to values (already broadcasted) and reduce such that if mask contains repeated indices the last one remains
|
107
|
+
values = values * mask
|
108
|
+
for dim in axes: mask, values = functools.reduce(lambda x,y: (x[0]|y[0], y[0].where(y[1], x[1])), zip(mask.split(1, dim), values.split(1, dim)))
|
109
|
+
# remove extra dims from reduce
|
110
|
+
for dim in reversed(axes): mask, values = mask.squeeze(dim), values.squeeze(dim)
|
111
|
+
# select from values for each True element in mask else select from self
|
112
|
+
return mask.where(values, target)
|
113
|
+
|
114
|
+
# `(padding_left, padding_right, padding_top, padding_bottom, ...)` -> `(..., (padding_top, padding_bottom), (padding_left, padding_right))`
|
115
|
+
def _flat_to_grouped(padding:Sequence[sint]) -> tuple[tuple[sint, sint], ...]: return tuple(zip(padding[-2::-2], padding[::-2]))
|
99
116
|
|
100
117
|
ReductionStr = Literal["mean", "sum", "none"]
|
101
118
|
|
102
|
-
class Tensor(SimpleMathTrait):
|
119
|
+
class Tensor(SimpleMathTrait):
|
103
120
|
"""
|
104
121
|
A `Tensor` is a multi-dimensional matrix containing elements of a single data type.
|
105
122
|
|
@@ -110,15 +127,19 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
110
127
|
np.set_printoptions(precision=4)
|
111
128
|
```
|
112
129
|
"""
|
113
|
-
__slots__ = "lazydata", "requires_grad", "grad"
|
114
|
-
__deletable__ = ('_ctx',)
|
130
|
+
__slots__ = "lazydata", "requires_grad", "grad"
|
115
131
|
training: ClassVar[bool] = False
|
116
132
|
no_grad: ClassVar[bool] = False
|
117
133
|
|
118
|
-
def
|
134
|
+
def __new__(cls, *args, **kwargs):
|
135
|
+
instance = super().__new__(cls)
|
136
|
+
all_tensors.add(weakref.ref(instance))
|
137
|
+
return instance
|
138
|
+
def __del__(self): all_tensors.discard(weakref.ref(self))
|
139
|
+
|
140
|
+
def __init__(self, data:Union[None, ConstType, bytes, List, Tuple, UOp, 'np.ndarray', pathlib.Path], # type: ignore [name-defined] # noqa: F821
|
119
141
|
device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DTypeLike]=None, requires_grad:Optional[bool]=None):
|
120
142
|
if dtype is not None: dtype = to_dtype(dtype)
|
121
|
-
assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
|
122
143
|
if device is None and isinstance(data, pathlib.Path): device = f"DISK:{data.resolve()}" # keep it on the disk if device is None
|
123
144
|
device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device)
|
124
145
|
|
@@ -129,21 +150,18 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
129
150
|
# None (the default) will be updated to True if it's put in an optimizer
|
130
151
|
self.requires_grad: Optional[bool] = requires_grad
|
131
152
|
|
132
|
-
# internal variable used for autograd graph construction
|
133
|
-
self._ctx: Optional[Function] = None
|
134
|
-
|
135
153
|
# create a LazyBuffer from the different types of inputs
|
136
|
-
if isinstance(data,
|
154
|
+
if isinstance(data, UOp):
|
155
|
+
assert dtype is None or dtype==data.dtype, "dtype doesn't match, and casting isn't supported"
|
156
|
+
# NOTE: this is here because LazyBuffer = UOp
|
157
|
+
if isinstance(data, UOp) and data.op is Ops.BIND: data = _metaop(Ops.BIND, tuple(), dtype or data.dtype, device, data)
|
137
158
|
elif data is None: data = _metaop(Ops.EMPTY, (0,), dtype or dtypes.default_float, device)
|
138
159
|
elif isinstance(data, get_args(ConstType)): data = _metaop(Ops.CONST, tuple(), dtype or dtypes.from_py(data), device, data)
|
139
|
-
elif isinstance(data, UOp):
|
140
|
-
assert data.op is Ops.BIND and data.src[0].op is Ops.DEFINE_VAR and data.src[1].op is Ops.CONST, f"can't create tensor from UOp {data}"
|
141
|
-
data = _metaop(Ops.CONST, tuple(), dtype or data.dtype, device, data)
|
142
160
|
elif isinstance(data, bytes): data = _frompy(data, dtypes.uint8 if dtype is None else dtype)
|
143
161
|
elif isinstance(data, (list, tuple)):
|
144
162
|
if dtype is None:
|
145
163
|
if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtypes.bool
|
146
|
-
else: dtype = dtypes.default_int if d and all_int(d) else dtypes.default_float
|
164
|
+
else: dtype = dtypes.default_int if d and all_int(d) else dtypes.default_float # NOTE: this works because all_int([True, False]) is True
|
147
165
|
if dtype == dtypes.bfloat16: data = Tensor(_frompy(data, dtypes.float32), device=device).cast(dtypes.bfloat16).lazydata
|
148
166
|
else: data = _frompy(data, dtype)
|
149
167
|
elif str(type(data)) == "<class 'numpy.ndarray'>":
|
@@ -156,16 +174,20 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
156
174
|
data = _metaop(Ops.EMPTY, (data.stat().st_size // dtype.itemsize,), dtype, f"DISK:{data.resolve()}")
|
157
175
|
|
158
176
|
# by this point, it has to be a LazyBuffer
|
159
|
-
if not isinstance(data,
|
177
|
+
if not isinstance(data, UOp): raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}")
|
160
178
|
|
161
179
|
# data might be on a different device
|
162
|
-
if isinstance(device, str): self.lazydata:
|
180
|
+
if isinstance(device, str): self.lazydata:UOp = data if data.device == device else data.copy_to_device(device)
|
163
181
|
# if device is a tuple, we should have/construct a MultiLazyBuffer
|
164
|
-
elif isinstance(data,
|
182
|
+
elif isinstance(data, UOp) and isinstance(data.device, str): self.lazydata = Tensor(data).shard(device).lazydata
|
165
183
|
else:
|
166
184
|
assert data.device == device, f"MultiLazyBuffer device mismatch, {data.device} != {device}"
|
167
185
|
self.lazydata = data
|
168
186
|
|
187
|
+
def requires_grad_(self, requires_grad=True) -> Tensor:
|
188
|
+
self.requires_grad = requires_grad
|
189
|
+
return self
|
190
|
+
|
169
191
|
class train(ContextDecorator):
|
170
192
|
def __init__(self, mode:bool = True): self.mode = mode
|
171
193
|
def __enter__(self): self.prev, Tensor.training = Tensor.training, self.mode
|
@@ -177,7 +199,9 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
177
199
|
def __exit__(self, exc_type, exc_value, traceback): Tensor.no_grad = self.prev
|
178
200
|
|
179
201
|
def __repr__(self):
|
180
|
-
|
202
|
+
ld = self.lazydata
|
203
|
+
ld_repr = f"<UOp {ld.device} {ld.shape} {str(ld.dtype)[7:]} {ld.st if ld.base is not ld else (ld.op, ld.realized)}>"
|
204
|
+
return f"<Tensor {ld_repr} on {self.device} with grad {(self.grad.lazydata if self.grad is not None else None)!r}>"
|
181
205
|
|
182
206
|
# Python has a non moving GC, so this should be okay
|
183
207
|
def __hash__(self): return id(self)
|
@@ -189,26 +213,49 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
189
213
|
return self.shape[0]
|
190
214
|
|
191
215
|
@property
|
192
|
-
def device(self) -> Union[str,
|
216
|
+
def device(self) -> Union[str, tuple[str, ...]]: return self.lazydata.device
|
193
217
|
|
194
218
|
@property
|
195
|
-
def shape(self) ->
|
219
|
+
def shape(self) -> tuple[sint, ...]: return self.lazydata.shape
|
196
220
|
|
197
221
|
@property
|
198
222
|
def dtype(self) -> DType: return self.lazydata.dtype
|
199
223
|
|
224
|
+
def _apply_uop(self, fxn:Callable, *x:Tensor, **kwargs) -> Tensor:
|
225
|
+
ret = Tensor.__new__(Tensor)
|
226
|
+
needs_input_grad = [t.requires_grad for t in (self,)+x]
|
227
|
+
ret.requires_grad, ret.grad = True if any(needs_input_grad) else None if None in needs_input_grad else False, None
|
228
|
+
ret.lazydata = fxn(*[t.lazydata for t in (self,)+x], **kwargs)
|
229
|
+
return ret
|
230
|
+
|
231
|
+
def _apply_broadcasted_uop(self, fxn:Callable, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
|
232
|
+
lhs,rhs = self._broadcasted(x, reverse)
|
233
|
+
return lhs._apply_uop(fxn, rhs)
|
234
|
+
|
200
235
|
# ***** data handlers ****
|
201
236
|
|
202
|
-
def schedule_with_vars(self, *lst:Tensor) ->
|
237
|
+
def schedule_with_vars(self, *lst:Tensor) -> tuple[list[ScheduleItem], dict[Variable, int]]:
|
203
238
|
"""
|
204
239
|
Creates the schedule needed to realize these Tensor(s), with Variables.
|
205
240
|
|
206
241
|
NOTE: A Tensor can only be scheduled once.
|
207
242
|
"""
|
208
|
-
|
243
|
+
big_sink = UOp.sink(*[x.lazydata for x in (self,)+lst])
|
244
|
+
|
245
|
+
# TODO: move this to scheduler tensor_map pass
|
246
|
+
if any(x.op is Ops.MULTI for x in big_sink.toposort):
|
247
|
+
# multi fixup
|
248
|
+
_apply_map_to_tensors(get_multi_map(big_sink))
|
249
|
+
big_sink = UOp.sink(*flatten([x.lazydata.src if x.lazydata.op is Ops.MULTI else [x.lazydata] for x in (self,)+lst]))
|
250
|
+
|
251
|
+
# verify Tensors match the spec
|
252
|
+
if __debug__: type_verify(list(big_sink.toposort), tensor_uop_spec)
|
253
|
+
|
254
|
+
schedule, var_vals, becomes_map = create_schedule_with_vars(big_sink)
|
255
|
+
_apply_map_to_tensors(becomes_map)
|
209
256
|
return memory_planner(schedule), var_vals
|
210
257
|
|
211
|
-
def schedule(self, *lst:Tensor) ->
|
258
|
+
def schedule(self, *lst:Tensor) -> list[ScheduleItem]:
|
212
259
|
"""Creates the schedule needed to realize these Tensor(s)."""
|
213
260
|
schedule, var_vals = self.schedule_with_vars(*lst)
|
214
261
|
assert len(var_vals) == 0
|
@@ -224,7 +271,6 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
224
271
|
Replaces the data of this tensor with the data of another tensor. Only the shape of the tensors must match.
|
225
272
|
"""
|
226
273
|
# used for replacing a Tensor with a new version of it (potentially with a different device and dtype)
|
227
|
-
assert not x.requires_grad and getattr(self, '_ctx', None) is None
|
228
274
|
assert self.shape == x.shape, f"replace shape mismatch {self.shape} != {x.shape}"
|
229
275
|
self.lazydata = x.lazydata
|
230
276
|
return self
|
@@ -232,17 +278,15 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
232
278
|
def assign(self, x) -> Tensor:
|
233
279
|
# TODO: this is a hack for writing to DISK. remove with working assign
|
234
280
|
if isinstance(self.device, str) and self.device.startswith("DISK"):
|
235
|
-
if x.__class__ is not Tensor: x = Tensor(x, device="
|
236
|
-
self.contiguous().realize().lazydata.base.realized.copyin(x.
|
281
|
+
if x.__class__ is not Tensor: x = Tensor(x, device="CLANG", dtype=self.dtype)
|
282
|
+
self.contiguous().realize().lazydata.base.realized.ensure_allocated().copyin(x._data())
|
237
283
|
return self
|
238
284
|
if x.__class__ is not Tensor: x = Tensor(x, device=self.device, dtype=self.dtype)
|
239
|
-
if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}")
|
240
285
|
if self.lazydata is x.lazydata: return self # a self assign is a NOOP
|
241
286
|
# NOTE: we allow cross device assign
|
242
287
|
assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}"
|
243
288
|
assert self.device == x.device, f"assign device mismatch {self.device} != {x.device}"
|
244
289
|
assert self.dtype == x.dtype, f"assign dtype mismatch {self.dtype} != {x.dtype}"
|
245
|
-
assert not isinstance(self.lazydata, MultiLazyBuffer) or self.lazydata.axis == x.lazydata.axis, "axis must match on MultiLazyBuffer"
|
246
290
|
assert not x.requires_grad # self requires_grad is okay?
|
247
291
|
if not self.lazydata.is_realized: return self.replace(x)
|
248
292
|
self.lazydata = self.lazydata.assign(x.lazydata)
|
@@ -252,14 +296,15 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
252
296
|
"""
|
253
297
|
Returns a new tensor with the same data as this tensor, but detached from the autograd graph.
|
254
298
|
"""
|
255
|
-
return Tensor(self.lazydata, device=self.device, requires_grad=False)
|
299
|
+
return Tensor(self.lazydata.detach(), device=self.device, requires_grad=False)
|
256
300
|
|
257
301
|
def _data(self) -> memoryview:
|
258
302
|
if 0 in self.shape: return memoryview(bytearray(0))
|
259
303
|
# NOTE: this realizes on the object from as_buffer being a Python object
|
260
304
|
cpu = self.cast(self.dtype.base).contiguous().to("CLANG").realize()
|
261
|
-
buf = cast(
|
262
|
-
|
305
|
+
buf = cast(UOp, cpu.lazydata).base.realized
|
306
|
+
assert buf is not None, f"{cast(UOp, cpu.lazydata).base} was not realized"
|
307
|
+
if self.device != "CLANG": buf.options = BufferSpec(nolru=True)
|
263
308
|
return buf.as_buffer(allow_zero_copy=True if self.device != "CLANG" else False)
|
264
309
|
|
265
310
|
def data(self) -> memoryview:
|
@@ -271,9 +316,10 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
271
316
|
print(np.frombuffer(t.data(), dtype=np.int32))
|
272
317
|
```
|
273
318
|
"""
|
274
|
-
assert self.dtype.fmt is not None, f"no fmt dtype for {self.dtype}"
|
319
|
+
assert self.dtype.base.fmt is not None, f"no fmt dtype for {self.dtype.base}"
|
275
320
|
assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
|
276
|
-
|
321
|
+
if TYPE_CHECKING or sys.version_info < (3, 12): assert self.dtype.base.fmt != "e"
|
322
|
+
return cast(memoryview, self._data().cast(self.dtype.base.fmt) if 0 in self.shape else self._data().cast(self.dtype.base.fmt, self.shape))
|
277
323
|
|
278
324
|
def item(self) -> ConstType:
|
279
325
|
"""
|
@@ -284,11 +330,10 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
284
330
|
print(t.item())
|
285
331
|
```
|
286
332
|
"""
|
287
|
-
assert self.dtype.fmt is not None, f"no fmt dtype for {self.dtype}"
|
288
333
|
assert self.numel() == 1, "must have one element for item"
|
289
|
-
return self.
|
334
|
+
return self.data()[(0,) * len(self.shape)]
|
290
335
|
|
291
|
-
# TODO: should be Tensor.tolist() -> Union[
|
336
|
+
# TODO: should be Tensor.tolist() -> Union[list[ConstType], ConstType]. The List is Sequence because mypy expects memoryview.tolist() -> list[int]
|
292
337
|
# src: https://github.com/python/mypy/blob/release-1.6/mypy/typeshed/stdlib/builtins.pyi#L803
|
293
338
|
def tolist(self) -> Union[Sequence[ConstType], ConstType]:
|
294
339
|
"""
|
@@ -311,21 +356,20 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
311
356
|
```
|
312
357
|
"""
|
313
358
|
import numpy as np
|
314
|
-
if self.dtype == dtypes.bfloat16: return self.float().numpy()
|
315
|
-
assert _to_np_dtype(self.dtype) is not None, f"no np dtype for {self.dtype}"
|
359
|
+
if self.dtype.base == dtypes.bfloat16: return self.float().numpy()
|
360
|
+
assert _to_np_dtype(self.dtype.base) is not None, f"no np dtype for {self.dtype.base}"
|
316
361
|
assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
|
317
|
-
return np.frombuffer(self._data(), dtype=_to_np_dtype(self.dtype)).reshape(self.shape)
|
362
|
+
return np.frombuffer(self._data(), dtype=_to_np_dtype(self.dtype.base)).reshape(self.shape)
|
318
363
|
|
319
364
|
def clone(self) -> Tensor:
|
320
365
|
"""
|
321
|
-
Creates a clone of this tensor allocating a
|
366
|
+
Creates a clone of this tensor allocating a separate buffer for the data.
|
322
367
|
"""
|
323
368
|
ret = Tensor(self.lazydata.clone(), self.device, requires_grad=self.requires_grad)
|
324
369
|
if self.grad is not None: ret.grad = self.grad.clone()
|
325
|
-
if hasattr(self, '_ctx'): ret._ctx = self._ctx
|
326
370
|
return ret
|
327
371
|
|
328
|
-
def to(self, device:Optional[Union[str,
|
372
|
+
def to(self, device:Optional[Union[str, tuple[str, ...]]]) -> Tensor:
|
329
373
|
"""
|
330
374
|
Moves the tensor to the given device.
|
331
375
|
"""
|
@@ -334,47 +378,35 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
334
378
|
if not isinstance(device, str): return self.shard(device)
|
335
379
|
ret = Tensor(self.lazydata, device, requires_grad=self.requires_grad)
|
336
380
|
if self.grad is not None: ret.grad = self.grad.to(device)
|
337
|
-
if hasattr(self, '_ctx'): ret._ctx = self._ctx
|
338
381
|
return ret
|
339
382
|
|
340
|
-
def to_(self, device:Optional[Union[str,
|
383
|
+
def to_(self, device:Optional[Union[str, tuple[str, ...]]]):
|
341
384
|
"""
|
342
385
|
Moves the tensor to the given device in place.
|
343
386
|
"""
|
344
387
|
real = self.to(device)
|
345
|
-
|
346
|
-
|
347
|
-
self.lazydata = real.lazydata
|
388
|
+
if self.grad is not None and real.grad is not None: self.grad.replace(real.grad)
|
389
|
+
return self.replace(real)
|
348
390
|
|
349
|
-
def shard(self, devices:
|
391
|
+
def shard(self, devices:tuple[str, ...], axis:Optional[int]=None) -> Tensor:
|
350
392
|
"""
|
351
|
-
Shards the tensor across the given devices. Optionally specify which axis to shard on
|
393
|
+
Shards the tensor across the given devices. Optionally specify which axis to shard on.
|
352
394
|
|
353
395
|
```python exec="true" source="above" session="tensor" result="python"
|
354
|
-
t = Tensor.empty(2,
|
355
|
-
print(t.shard((t.device, t.device), axis=1
|
396
|
+
t = Tensor.empty(2, 4)
|
397
|
+
print(t.shard((t.device, t.device), axis=1).lazydata)
|
356
398
|
```
|
357
|
-
|
358
399
|
"""
|
359
|
-
assert isinstance(self.
|
360
|
-
devices
|
361
|
-
if axis is not None
|
362
|
-
|
363
|
-
if splits is None:
|
364
|
-
if not isinstance(total:=self.shape[axis], int): raise RuntimeError(f"cannot shard symbolic shape {self.shape=}, {axis=}")
|
365
|
-
sz = ceildiv(total, len(devices))
|
366
|
-
splits = tuple([max(0, min(sz, total - sz*i)) for i in range(len(devices))])
|
367
|
-
assert sum(splits) == self.shape[axis], "specified splits do not sum up to axis shape"
|
368
|
-
boundaries = tuple(itertools.accumulate(splits))
|
369
|
-
bounds = tuple(zip((0,) + boundaries, boundaries))
|
370
|
-
return Tensor(MultiLazyBuffer.from_sharded(self.lazydata, devices, axis, bounds), device=devices, requires_grad=self.requires_grad)
|
400
|
+
assert isinstance(self.device, str), "can't shard a MultiLazyBuffer"
|
401
|
+
devices = tuple(Device.canonicalize(x) for x in devices)
|
402
|
+
mlb = self.lazydata.shard(devices, self._resolve_dim(axis) if axis is not None else None)
|
403
|
+
return Tensor(mlb, device=devices, requires_grad=self.requires_grad)
|
371
404
|
|
372
|
-
def shard_(self, devices:
|
405
|
+
def shard_(self, devices:tuple[str, ...], axis:Optional[int]=None):
|
373
406
|
"""
|
374
407
|
Shards the tensor across the given devices in place.
|
375
408
|
"""
|
376
|
-
self.
|
377
|
-
return self
|
409
|
+
return self.replace(self.shard(devices, axis))
|
378
410
|
|
379
411
|
@staticmethod
|
380
412
|
def from_uop(y:UOp, **kwargs) -> Tensor:
|
@@ -382,18 +414,17 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
382
414
|
if y.op is Ops.CONST: return Tensor(y.arg, **kwargs, requires_grad=False)
|
383
415
|
if y.op is Ops.MUL: return Tensor.from_uop(y.src[0]) * Tensor.from_uop(y.src[1])
|
384
416
|
if y.op is Ops.ADD: return Tensor.from_uop(y.src[0]) + Tensor.from_uop(y.src[1])
|
385
|
-
if y.op is Ops.MAX: return Tensor.from_uop(y.src[0]).maximum(Tensor.from_uop(y.src[1]))
|
386
417
|
raise RuntimeError(f"unhandled UOp {y}")
|
387
418
|
|
388
419
|
# ***** creation entrypoint *****
|
389
420
|
|
390
421
|
@staticmethod
|
391
|
-
def _metaop(op, shape, device:Optional[Union[
|
422
|
+
def _metaop(op, shape, device:Optional[Union[tuple[str, ...], str]]=None, dtype:Optional[DTypeLike]=None, arg=None, **kwargs):
|
392
423
|
dtype = to_dtype(dtype) if dtype is not None else dtypes.default_float
|
393
424
|
if isinstance(device, tuple):
|
394
|
-
return Tensor(
|
425
|
+
return Tensor(UOp.multi(*[UOp.metaop(op, shape, dtype, Device.canonicalize(d), arg) for d in device], axis=None),
|
395
426
|
device, dtype, **kwargs)
|
396
|
-
return Tensor(
|
427
|
+
return Tensor(UOp.metaop(op, shape, dtype, Device.canonicalize(device), arg), device, dtype, **kwargs)
|
397
428
|
|
398
429
|
@staticmethod
|
399
430
|
def empty(*shape, **kwargs):
|
@@ -411,7 +442,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
411
442
|
return Tensor._metaop(Ops.EMPTY, argfix(*shape), **kwargs)
|
412
443
|
|
413
444
|
@staticmethod
|
414
|
-
def from_blob(ptr:int, shape:
|
445
|
+
def from_blob(ptr:int, shape:tuple[int, ...], **kwargs) -> Tensor:
|
415
446
|
"""
|
416
447
|
Exposes the pointer as a Tensor without taking ownership of the original data.
|
417
448
|
The pointer must remain valid for the entire lifetime of the created Tensor.
|
@@ -422,7 +453,6 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
422
453
|
|
423
454
|
r = Tensor._metaop(Ops.EMPTY, shape, **kwargs)
|
424
455
|
r.lazydata.buffer.allocate(external_ptr=ptr)
|
425
|
-
del r.lazydata.srcs # fake realize
|
426
456
|
return r
|
427
457
|
|
428
458
|
@staticmethod
|
@@ -439,8 +469,8 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
439
469
|
return Tensor(fetch(url, gunzip=gunzip), **kwargs)
|
440
470
|
|
441
471
|
_seed: int = int(time.time())
|
442
|
-
_device_seeds:
|
443
|
-
_device_rng_counters:
|
472
|
+
_device_seeds: dict[str, Tensor] = {}
|
473
|
+
_device_rng_counters: dict[str, Tensor] = {}
|
444
474
|
@staticmethod
|
445
475
|
def manual_seed(seed=0):
|
446
476
|
"""
|
@@ -462,7 +492,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
462
492
|
@staticmethod
|
463
493
|
def _threefry_random_bits(key:Tensor, counts0:Tensor, counts1:Tensor):
|
464
494
|
x = (counts1.cast(dtypes.uint64) << 32) | counts0.cast(dtypes.uint64)
|
465
|
-
x =
|
495
|
+
x = x._apply_uop(UOp.threefry, (key[1]._broadcast_to(x.shape).cast(dtypes.uint64) << 32) | key[0]._broadcast_to(x.shape).cast(dtypes.uint64))
|
466
496
|
counts0, counts1 = (x & 0xffffffff).cast(dtypes.uint32), ((x >> 32) & 0xffffffff).cast(dtypes.uint32)
|
467
497
|
return counts0.cat(counts1)
|
468
498
|
|
@@ -485,6 +515,10 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
485
515
|
if device is not None and not isinstance(device, str): raise ValueError(f"rand only supports single device, got {device=}")
|
486
516
|
_device = device = Device.canonicalize(device)
|
487
517
|
|
518
|
+
# if shape has 0, return zero tensor
|
519
|
+
if (numel := prod(shape)) == 0: return Tensor.zeros(shape, device=_device, dtype=dtype, **kwargs)
|
520
|
+
num = ceildiv(numel * dtype.itemsize, 4)
|
521
|
+
|
488
522
|
# when using MOCKGPU and NV generate rand on CLANG
|
489
523
|
if getenv("MOCKGPU") and device.startswith("NV"): device = "CLANG"
|
490
524
|
|
@@ -494,15 +528,8 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
494
528
|
[int.from_bytes(hashlib.sha256(len(Tensor._device_seeds).to_bytes(4, "big")).digest(), "big"), Tensor._seed],
|
495
529
|
device=device, dtype=dtypes.uint32, requires_grad=False)
|
496
530
|
Tensor._device_rng_counters[device] = Tensor([0], device=device, dtype=dtypes.uint32, requires_grad=False)
|
497
|
-
had_counter = False
|
498
|
-
else: had_counter = True
|
499
|
-
|
500
|
-
# if shape has 0, return zero tensor
|
501
|
-
if (numel := prod(shape)) == 0: return Tensor.zeros(shape, device=_device, dtype=dtype, **kwargs)
|
502
|
-
num = ceildiv(numel * dtype.itemsize, 4)
|
503
|
-
|
504
531
|
# increment rng counter for devices
|
505
|
-
|
532
|
+
else: Tensor._device_rng_counters[device].assign(Tensor._device_rng_counters[device] + num).contiguous()
|
506
533
|
|
507
534
|
# threefry random bits
|
508
535
|
counts0 = (Tensor.arange(ceildiv(num, 2), device=device, dtype=dtypes.uint32, requires_grad=False)+Tensor._device_rng_counters[device])
|
@@ -528,7 +555,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
528
555
|
# ***** creation helper functions *****
|
529
556
|
|
530
557
|
@staticmethod
|
531
|
-
def full(shape:
|
558
|
+
def full(shape:tuple[sint, ...], fill_value:ConstType, **kwargs) -> Tensor:
|
532
559
|
"""
|
533
560
|
Creates a tensor with the given shape, filled with the given value.
|
534
561
|
|
@@ -607,7 +634,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
607
634
|
dtype = kwargs.pop("dtype", dtypes.default_float if any(isinstance(x, float) for x in (start, stop, step)) else dtypes.default_int)
|
608
635
|
# NOTE: this matches numpy, torch raises RuntimeError if stop-start and step have different signs
|
609
636
|
if (output_len:=ceildiv(stop-start, step)) <= 0: return Tensor([], dtype=dtype, **kwargs)
|
610
|
-
return (Tensor.full((output_len,), step, dtype=dtype, **kwargs).
|
637
|
+
return (Tensor.full((output_len,), step, dtype=dtype, **kwargs)._cumalu(0, Ops.ADD) + (start - step)).cast(dtype)
|
611
638
|
|
612
639
|
@staticmethod
|
613
640
|
def linspace(start:Union[int, float], stop:Union[int, float], steps:int, **kwargs) -> Tensor:
|
@@ -705,18 +732,19 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
705
732
|
```
|
706
733
|
"""
|
707
734
|
dtype = kwargs.pop("dtype", self.dtype)
|
708
|
-
if isinstance(self.device, tuple)
|
735
|
+
if isinstance(self.device, tuple):
|
709
736
|
if kwargs.get("device") is not None: raise RuntimeError("cannot specify `device` on `rand_like` of a multi device tensor")
|
710
737
|
if self.lazydata.axis is None: return Tensor.rand(*self.shape, dtype=dtype, **kwargs).shard(self.device)
|
711
738
|
contiguous = kwargs.pop("contiguous", True)
|
712
|
-
|
713
|
-
|
739
|
+
sharded_shape = tuple(s//len(self.device) if a==self.lazydata.axis else s for a,s in enumerate(self.shape))
|
740
|
+
rands = [Tensor.rand(sharded_shape, device=d, dtype=dtype, contiguous=contiguous, **kwargs).lazydata for d in self.device]
|
741
|
+
return Tensor(UOp.multi(*rands, axis=self.lazydata.axis), device=self.device, dtype=dtype, **kwargs)
|
714
742
|
return Tensor.rand(*self.shape, device=kwargs.pop("device", self.device), dtype=dtype, **kwargs)
|
715
743
|
|
716
744
|
# ***** rng hlops *****
|
717
745
|
|
718
746
|
@staticmethod
|
719
|
-
def randn(*shape, dtype:Optional[DTypeLike]=None, **kwargs) -> Tensor:
|
747
|
+
def randn(*shape, dtype:Optional[DTypeLike]=None, requires_grad:Optional[bool]=None, **kwargs) -> Tensor:
|
720
748
|
"""
|
721
749
|
Creates a tensor with the given shape, filled with random values from a normal distribution with mean `0` and standard deviation `1`.
|
722
750
|
If `dtype` is not specified, the default type is used.
|
@@ -731,10 +759,10 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
731
759
|
"""
|
732
760
|
# https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
|
733
761
|
src = Tensor.rand((2, *argfix(*shape)), **{**kwargs, "dtype": dtypes.float32})
|
734
|
-
return src[0].mul(2*math.pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(dtype or dtypes.default_float)
|
762
|
+
return (src[0].mul(2*math.pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(dtype or dtypes.default_float)).requires_grad_(requires_grad)
|
735
763
|
|
736
764
|
@staticmethod
|
737
|
-
def randint(*shape, low=0, high=10, **kwargs) -> Tensor:
|
765
|
+
def randint(*shape, low=0, high=10, dtype=dtypes.int32, **kwargs) -> Tensor:
|
738
766
|
"""
|
739
767
|
Creates a tensor with the given shape, filled with random integer values generated uniformly from the interval `[low, high)`.
|
740
768
|
If `dtype` is not specified, the default type is used.
|
@@ -748,12 +776,12 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
748
776
|
```
|
749
777
|
"""
|
750
778
|
if not isinstance(low, int) or not isinstance(high, int): raise TypeError(f"{low=} and {high=} must be integers")
|
751
|
-
dtype = to_dtype(
|
779
|
+
dtype = to_dtype(dtype)
|
752
780
|
if not dtypes.is_int(dtype): raise TypeError(f"{dtype=} must be int")
|
753
781
|
return Tensor.uniform(*shape, low=low, high=high, dtype=dtype, **kwargs)
|
754
782
|
|
755
783
|
@staticmethod
|
756
|
-
def normal(*shape, mean=0.0, std=1.0, **kwargs) -> Tensor:
|
784
|
+
def normal(*shape, mean=0.0, std=1.0, requires_grad:Optional[bool]=None, **kwargs) -> Tensor:
|
757
785
|
"""
|
758
786
|
Creates a tensor with the given shape, filled with random values from a normal distribution with the given `mean` and standard deviation `std`.
|
759
787
|
|
@@ -765,10 +793,10 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
765
793
|
print(Tensor.normal(2, 3, mean=10, std=2).numpy())
|
766
794
|
```
|
767
795
|
"""
|
768
|
-
return (std * Tensor.randn(*shape, **kwargs)) + mean
|
796
|
+
return ((std * Tensor.randn(*shape, **kwargs)) + mean).requires_grad_(requires_grad)
|
769
797
|
|
770
798
|
@staticmethod
|
771
|
-
def uniform(*shape, low=0.0, high=1.0, **kwargs) -> Tensor:
|
799
|
+
def uniform(*shape, low=0.0, high=1.0, dtype:Optional[DTypeLike]=None, requires_grad:Optional[bool]=None, **kwargs) -> Tensor:
|
772
800
|
"""
|
773
801
|
Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[low, high)`.
|
774
802
|
|
@@ -780,8 +808,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
780
808
|
print(Tensor.uniform(2, 3, low=2, high=10).numpy())
|
781
809
|
```
|
782
810
|
"""
|
783
|
-
|
784
|
-
return ((high-low) * Tensor.rand(*shape, **kwargs)).cast(dtype) + low
|
811
|
+
return (((high-low) * Tensor.rand(*shape, **kwargs)).cast(dtype or dtypes.default_float) + low).requires_grad_(requires_grad)
|
785
812
|
|
786
813
|
@staticmethod
|
787
814
|
def scaled_uniform(*shape, **kwargs) -> Tensor:
|
@@ -860,49 +887,52 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
860
887
|
|
861
888
|
# ***** toposort and backward pass *****
|
862
889
|
|
863
|
-
def
|
864
|
-
|
865
|
-
|
866
|
-
|
867
|
-
|
868
|
-
|
869
|
-
|
870
|
-
|
871
|
-
|
872
|
-
|
890
|
+
def gradient(self, *targets:Tensor, gradient:Optional[Tensor]=None, materialize_grads=False) -> list[Tensor]:
|
891
|
+
"""
|
892
|
+
Compute the gradient of the targets with respect to self.
|
893
|
+
|
894
|
+
```python exec="true" source="above" session="tensor" result="python"
|
895
|
+
x = Tensor.eye(3)
|
896
|
+
y = Tensor([[2.0,0,-2.0]])
|
897
|
+
z = y.matmul(x).sum()
|
898
|
+
dx, dy = z.gradient(x, y)
|
899
|
+
|
900
|
+
print(dx.tolist()) # dz/dx
|
901
|
+
print(dy.tolist()) # dz/dy
|
902
|
+
```
|
903
|
+
"""
|
904
|
+
assert gradient is not None or self.shape == tuple(), "when no gradient is provided, backward must be called on a scalar tensor"
|
905
|
+
if gradient is None: gradient = Tensor(1.0, dtype=self.dtype, device=self.device, requires_grad=False)
|
906
|
+
rets = []
|
907
|
+
target_uops = [x.lazydata for x in targets]
|
908
|
+
grads = compute_gradient(self.lazydata, gradient.lazydata, set(target_uops))
|
909
|
+
ret = []
|
910
|
+
for x in target_uops:
|
911
|
+
if (y:=grads.get(x)) is None:
|
912
|
+
if materialize_grads: y = x.const_like(0)
|
913
|
+
else: raise RuntimeError(f"{x}\n\nnot found in\n\n{self.lazydata}")
|
914
|
+
ret.append(y)
|
915
|
+
rets.append(ret)
|
916
|
+
# create returned Tensors
|
917
|
+
return [Tensor(u, device=t.device) for t,u in zip(targets, rets[0])]
|
873
918
|
|
874
|
-
def backward(self, gradient:Optional[Tensor]=None
|
919
|
+
def backward(self, gradient:Optional[Tensor]=None) -> Tensor:
|
875
920
|
"""
|
876
921
|
Propagates the gradient of a tensor backwards through the computation graph.
|
877
922
|
If the 'gradient' argument is not provided, the tensor must be a scalar, and the gradient is implicitly set to 1.0.
|
878
|
-
If 'retain_graph' is false, the graph used to compute the grads will be freed. Otherwise, it will be kept. Keeping it can increase memory usage.
|
879
923
|
```python exec="true" source="above" session="tensor" result="python"
|
880
924
|
t = Tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
|
881
925
|
t.sum().backward()
|
882
926
|
print(t.grad.numpy())
|
883
927
|
```
|
884
928
|
"""
|
885
|
-
|
886
|
-
if
|
887
|
-
|
888
|
-
|
889
|
-
|
890
|
-
|
891
|
-
|
892
|
-
assert self.shape == gradient.shape, f"grad shape must match tensor shape, {gradient.shape!r} != {self.shape!r}"
|
893
|
-
self.grad = gradient
|
894
|
-
for t0 in reversed(toposorted):
|
895
|
-
if t0.grad is None: raise RuntimeError(f"tensor {t0} has no grad")
|
896
|
-
token = _METADATA.set(dataclasses.replace(md, backward=True) if (md := t0._ctx.metadata) is not None else None)
|
897
|
-
grads = t0._ctx.backward(t0.grad.lazydata)
|
898
|
-
_METADATA.reset(token)
|
899
|
-
grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None
|
900
|
-
for g in ([grads] if len(t0._ctx.parents) == 1 else grads)]
|
901
|
-
for t, g in zip(t0._ctx.parents, grads):
|
902
|
-
if g is not None and t.requires_grad:
|
903
|
-
assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"
|
904
|
-
t.grad = g if t.grad is None else (t.grad + g)
|
905
|
-
if not retain_graph: del t0._ctx
|
929
|
+
all_uops = self.lazydata.toposort
|
930
|
+
tensors_need_grad: list[Tensor] = [t for tref in all_tensors if (t:=tref()) is not None and \
|
931
|
+
t.lazydata in all_uops and t.requires_grad and not Tensor.no_grad]
|
932
|
+
# clear contexts
|
933
|
+
for t,g in zip(tensors_need_grad, self.gradient(*tensors_need_grad, gradient=gradient, materialize_grads=True)):
|
934
|
+
assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"
|
935
|
+
t.grad = g if t.grad is None else (t.grad + g)
|
906
936
|
return self
|
907
937
|
|
908
938
|
# ***** movement low level ops *****
|
@@ -926,7 +956,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
926
956
|
# resolve -1
|
927
957
|
if (c := new_shape.count(-1)) > 1: raise RuntimeError(f"only one dimension can be inferred using -1, getting {new_shape}")
|
928
958
|
if c: new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape])
|
929
|
-
return
|
959
|
+
return self._apply_uop(UOp.reshape, arg=new_shape) if new_shape != self.shape else self
|
930
960
|
|
931
961
|
def expand(self, shape, *args) -> Tensor:
|
932
962
|
"""
|
@@ -940,7 +970,8 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
940
970
|
print(t.expand(4, -1).numpy())
|
941
971
|
```
|
942
972
|
"""
|
943
|
-
|
973
|
+
new_shape = tuple(from_ if to == -1 or to is None else to for from_, to in zip(*(_align_left(self.shape, argfix(shape, *args)))))
|
974
|
+
return self._broadcast_to(new_shape)
|
944
975
|
|
945
976
|
def permute(self, order, *args) -> Tensor:
|
946
977
|
"""
|
@@ -958,7 +989,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
958
989
|
"""
|
959
990
|
order_arg = tuple(self._resolve_dim(x) for x in argfix(order, *args))
|
960
991
|
if sorted(order_arg) != list(range(self.ndim)): raise RuntimeError(f"order is not a valid permutation, getting {order_arg}")
|
961
|
-
return
|
992
|
+
return self._apply_uop(UOp.permute, arg=order_arg)
|
962
993
|
|
963
994
|
def flip(self, axis, *args) -> Tensor:
|
964
995
|
"""
|
@@ -978,9 +1009,9 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
978
1009
|
"""
|
979
1010
|
axis_arg = tuple(self._resolve_dim(x) for x in argfix(axis, *args))
|
980
1011
|
if len(axis_arg) != len(dedup(axis_arg)): raise RuntimeError(f"dim can appear at most once, getting {axis_arg}")
|
981
|
-
return
|
1012
|
+
return self._apply_uop(UOp.flip, arg=tuple([i in axis_arg for i in range(len(self.shape))]))
|
982
1013
|
|
983
|
-
def shrink(self, arg:
|
1014
|
+
def shrink(self, arg:tuple[Optional[tuple[sint, sint]], ...]) -> Tensor:
|
984
1015
|
"""
|
985
1016
|
Returns a tensor that shrinks the each axis based on input arg.
|
986
1017
|
`arg` must have the same length as `self.ndim`.
|
@@ -998,24 +1029,25 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
998
1029
|
```
|
999
1030
|
"""
|
1000
1031
|
if (shrink_arg:=[x if x is not None else (0,s) for x,s in zip(arg, self.shape)]) == [(0,s) for s in self.shape]: return self
|
1001
|
-
return
|
1032
|
+
return self._apply_uop(UOp.shrink, arg=tuple(shrink_arg))
|
1002
1033
|
|
1003
|
-
def pad(self, padding:Union[Sequence[sint], Sequence[Optional[
|
1034
|
+
def pad(self, padding:Union[Sequence[sint], Sequence[Optional[tuple[sint, sint]]]], mode:str="constant", value:float=0.0) -> Tensor:
|
1004
1035
|
"""
|
1005
1036
|
Returns a tensor with padding applied based on the input `padding`.
|
1037
|
+
|
1006
1038
|
`padding` supports two padding structures:
|
1007
1039
|
|
1008
|
-
1. Flat padding: (padding_left, padding_right, padding_top, padding_bottom, ...)
|
1009
|
-
|
1010
|
-
|
1040
|
+
1. Flat padding: `(padding_left, padding_right, padding_top, padding_bottom, ...)`
|
1041
|
+
- This structure matches PyTorch's pad.
|
1042
|
+
- `padding` length must be even.
|
1011
1043
|
|
1012
|
-
2. Group padding: (..., (padding_top, padding_bottom), (padding_left, padding_right))
|
1013
|
-
|
1014
|
-
|
1015
|
-
|
1044
|
+
2. Group padding: `(..., (padding_top, padding_bottom), (padding_left, padding_right))`
|
1045
|
+
- This structure matches pad for JAX, NumPy, TensorFlow, and others.
|
1046
|
+
- For each axis, padding can be `None`, meaning no padding, or a tuple `(start, end)`.
|
1047
|
+
- `padding` must have the same length as `self.ndim`.
|
1016
1048
|
|
1017
1049
|
Padding values can be negative, resulting in dimension shrinks that work similarly to Python negative slices.
|
1018
|
-
Padding modes is selected with `mode` which supports `constant` and `
|
1050
|
+
Padding modes is selected with `mode` which supports `constant`, `reflect` and `replicate`.
|
1019
1051
|
|
1020
1052
|
```python exec="true" source="above" session="tensor" result="python"
|
1021
1053
|
t = Tensor.arange(9).reshape(1, 1, 3, 3)
|
@@ -1031,176 +1063,167 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
1031
1063
|
print(t.pad((1, 2, 0, -1), value=-float('inf')).numpy())
|
1032
1064
|
```
|
1033
1065
|
"""
|
1034
|
-
if mode not in {"constant", "reflect"}: raise NotImplementedError(f"{mode=} is not supported")
|
1035
|
-
|
1036
|
-
|
1037
|
-
|
1066
|
+
if mode not in {"constant", "reflect", "replicate", "circular"}: raise NotImplementedError(f"{mode=} is not supported")
|
1067
|
+
# flat padding
|
1068
|
+
if all(isinstance(p, (int,UOp)) for p in padding):
|
1069
|
+
if len(padding)%2 != 0: raise ValueError("Flat padding must have even number of pads")
|
1070
|
+
pX = _flat_to_grouped(tuple(cast(Sequence[sint], padding)) + (0,0)*(self.ndim - len(padding)//2))
|
1071
|
+
# group padding
|
1072
|
+
else: pX = tuple((0,0) if p is None else p for p in cast(Sequence[Optional[tuple[sint, sint]]], padding))
|
1038
1073
|
if len(pX) != self.ndim: raise ValueError(f"padding length is improper, {padding=} {self.ndim=}")
|
1039
|
-
X,
|
1040
|
-
|
1041
|
-
|
1042
|
-
|
1043
|
-
|
1044
|
-
|
1074
|
+
X, pads = self, tuple((smax(pB,0), smax(pA,0)) for pB,pA in pX)
|
1075
|
+
if mode == "constant":
|
1076
|
+
def _constant(x:Tensor,px,v):
|
1077
|
+
return x._apply_uop(UOp.pad, arg=px) if v == 0 else (x._apply_uop(UOp.pad, arg=px)+Tensor.ones_like(x)._apply_uop(UOp.pad, arg=px).where(0,v))
|
1078
|
+
return _constant(X, pX, value) if all(resolve(p >= 0) for p in flatten(pX)) else \
|
1079
|
+
_constant(X.shrink(tuple((-smin(pB,0),smin(pA+s,s)) for (pB,pA),s in zip(pX, X.shape))), pads, value)
|
1045
1080
|
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
|
1081
|
+
if mode == "circular":
|
1082
|
+
if any(pB>sh or pA>sh for (pB,pA),sh in zip(pX, X.shape)): raise ValueError('Padding value causes wrapping around more than once.')
|
1083
|
+
if any(pB<0 or pA<0 for pB,pA in pX): raise NotImplementedError("Negative pads with circular pads is not supported")
|
1084
|
+
orig_shape, X = X.shape, X.repeat(tuple(1 + bool(pB) + bool(pA) for pB,pA in pads))
|
1085
|
+
return X.shrink(tuple((0 if pB == 0 else osh-pB, xsh if pA == 0 else xsh-osh+pA) for (pB,pA),osh,xsh in zip(pads, orig_shape, X.shape)))
|
1046
1086
|
for d,(pB,pA) in enumerate(pads):
|
1047
|
-
if
|
1048
|
-
|
1049
|
-
|
1087
|
+
if mode == "reflect":
|
1088
|
+
if pB >= (s:=X.shape[d]) or pA>=s: raise ValueError(f"Padding ({pB}, {pA}) should be less than the input size={s} for dim={d}.")
|
1089
|
+
slcB, slcA, = slice(pB,0,-1), slice(s-2 if s-2>=0 else None, s-2-pA if s-2-pA>=0 else None, -1)
|
1090
|
+
xB, xA = (X[[slc if i == d else slice(None) for i in range(X.ndim)]] if p > 0 else None for slc, p in ((slcB, pB), (slcA, pA)))
|
1091
|
+
if mode == "replicate":
|
1092
|
+
shrB, shrA, = tuple((0,1) if i==d else None for i in range(X.ndim)), tuple((X.shape[i]-1,X.shape[i]) if i==d else None for i in range(X.ndim))
|
1093
|
+
xB, xA = (X.shrink(shr).expand(tuple(p if i==d else None for i in range(X.ndim))) if p > 0 else None for shr, p in ((shrB, pB), (shrA, pA)))
|
1050
1094
|
X = Tensor.cat(*(X_ for X_ in (xB, X, xA) if X_ is not None), dim=d)
|
1051
|
-
return X.shrink(
|
1095
|
+
return X.shrink(tuple((-min(pB,0), min(pA+s,s)) for (pB,pA),s in zip(pX, X.shape)))
|
1052
1096
|
|
1053
1097
|
# ***** movement high level ops *****
|
1054
1098
|
|
1055
|
-
# Supported Indexing Implementations:
|
1056
|
-
# 1. Int indexing (no copy)
|
1057
|
-
# - for all dims where there's int, shrink -> reshape
|
1058
|
-
# - negative indices are taken relative to the end of the sequence, so X[-2] returns the 2nd-to-last element
|
1059
|
-
# - X = Tensor.rand(4,5,9); X[2,-2] shrinks the Tensor to X.shrink(((2, 3), (3, 4), (0, 9))) -> X.shape=(1,1,9)
|
1060
|
-
# - Then we reshape (collapse) the int dim away such that for X: (1,1,9) -> (9,)
|
1061
|
-
# 2. Slice indexing (no copy)
|
1062
|
-
# - for all dims where slice is start:end:stride, shrink -> Optional[flip] -> pad -> reshape -> shrink
|
1063
|
-
# - first shrink the Tensor to X.shrink(((start, end),))
|
1064
|
-
# - then we apply stride through Optional[flip] -> pad -> reshape -> shrink
|
1065
|
-
# - flip where dim value is negative
|
1066
|
-
# - pad on dims to be multiple of strides, such that reshaping [dim_size_padded] -> [dim_size_padded // stride, stride] is possible
|
1067
|
-
# - shrink [dim_size_padded // stride, stride] -> [dim_size_padded // stride, 1]
|
1068
|
-
# - reshape [dim_size_padded // stride, 1] -> [dim_size_padded // stride] and now you have your stride
|
1069
|
-
# 3. None indexing (no copy)
|
1070
|
-
# - reshape (inject) a dim at the dim where there's None
|
1071
|
-
# 4. Tensor indexing (copy)
|
1072
|
-
# - use Tensor.arange == tensor_index to create masks for dims with Tensors (adds a dim for each mask)
|
1073
|
-
# - combine masks together with mul
|
1074
|
-
# - apply mask to self by mask * self
|
1075
|
-
# - sum reduce away the extra dims added from creating masks
|
1076
|
-
# Tiny Things:
|
1077
|
-
# 1. Supported indices: Union[int, slice, Tensor, None, List, Tuple, Ellipsis]
|
1078
|
-
# - for any list, List[Union[List, Tuple, int]], must have homogeneous shape
|
1079
|
-
# - for any tuple, Tuple[Union[List, Tuple, int]], must have homogeneous shape
|
1080
|
-
# 2. Bool indexing is not supported
|
1081
|
-
# 3. Out of bounds Tensor indexing results in 0
|
1082
|
-
# - e.g: Tensor([1, 2, 3])[Tensor([4, 3, 2])] -> [0, 0, 3] index 4 and 3 are out of bounds
|
1083
1099
|
def _getitem(self, indices, v: Optional[Tensor] = None) -> Tensor:
|
1084
|
-
#
|
1085
|
-
|
1086
|
-
if isinstance(indices, list) and all_int(indices): indices = [Tensor(indices, self.device, requires_grad=False)]
|
1087
|
-
elif isinstance(indices, (tuple, list)):
|
1088
|
-
indices = [Tensor(i, self.device, requires_grad=False) if isinstance(i, (tuple, list)) else i for i in indices]
|
1089
|
-
else: indices = [indices]
|
1090
|
-
|
1100
|
+
# wrap single index into a list
|
1101
|
+
if (isinstance(indices, list) and all_int(indices)) or not isinstance(indices, (tuple, list)): indices = [indices]
|
1091
1102
|
# turn scalar Tensors into const val for int indexing if possible
|
1092
|
-
indices = [self._to_const_val(i) if isinstance(i, Tensor) and i.shape == () else i for i in indices]
|
1093
|
-
# move Tensor indices to the same device as self
|
1094
|
-
indices = [i.to(self.device) if isinstance(i, Tensor) else i for i in indices]
|
1103
|
+
x, indices = self, [self._to_const_val(i) if isinstance(i, Tensor) and i.shape == () else i for i in indices]
|
1095
1104
|
|
1096
1105
|
# filter ellipsis and fill with slice(None) or fill rest of indices with slice(None)
|
1097
|
-
ellipsis_idx
|
1106
|
+
if len(ellipsis_idx := [dim for dim, i in enumerate(indices) if i is Ellipsis]) > 1: raise IndexError("indices can only have a single ellipsis")
|
1098
1107
|
fill_idx = ellipsis_idx[0] if ellipsis_idx else len(indices)
|
1099
1108
|
num_indices = len(indices) - len(ellipsis_idx) - sum(1 for i in indices if i is None)
|
1100
|
-
indices[fill_idx:fill_idx+1] = [slice(None)] * (self.ndim - num_indices)
|
1101
|
-
|
1102
|
-
# use Dict[type, List[dimension]] to track elements in indices
|
1103
|
-
type_dim: DefaultDict[Union[type, None], List[int]] = defaultdict(list)
|
1104
|
-
|
1105
|
-
# record None for dimension injection later and filter None and record rest of indices
|
1106
|
-
type_dim[None] = [dim for dim, i in enumerate(indices) if i is None]
|
1107
|
-
indices_filtered = [i for i in indices if i is not None]
|
1108
|
-
for dim,i in enumerate(indices_filtered): type_dim[type(i)].append(dim)
|
1109
|
-
|
1110
|
-
if len(ellipsis_idx) > 1: raise IndexError("indices can only have a single ellipsis ('...')")
|
1111
|
-
for index_type in type_dim:
|
1112
|
-
if index_type not in [None, int, slice, Tensor]: raise IndexError(f"{index_type=} not supported")
|
1113
1109
|
if num_indices > self.ndim: raise IndexError(f"too many {num_indices=} for {self.ndim=}")
|
1110
|
+
indices[fill_idx:fill_idx+1] = [slice(None)] * (self.ndim - num_indices)
|
1114
1111
|
|
1115
|
-
|
1116
|
-
|
1117
|
-
|
1118
|
-
|
1119
|
-
|
1120
|
-
|
1121
|
-
|
1122
|
-
|
1123
|
-
|
1124
|
-
|
1125
|
-
|
1126
|
-
|
1127
|
-
|
1128
|
-
|
1129
|
-
|
1130
|
-
|
1131
|
-
|
1132
|
-
|
1133
|
-
|
1134
|
-
|
1135
|
-
|
1136
|
-
|
1137
|
-
|
1138
|
-
|
1139
|
-
|
1140
|
-
|
1141
|
-
|
1142
|
-
|
1143
|
-
|
1144
|
-
|
1145
|
-
|
1146
|
-
|
1147
|
-
|
1148
|
-
|
1149
|
-
|
1150
|
-
|
1151
|
-
|
1152
|
-
|
1153
|
-
|
1154
|
-
|
1155
|
-
|
1156
|
-
|
1157
|
-
|
1158
|
-
|
1159
|
-
|
1160
|
-
|
1161
|
-
|
1162
|
-
|
1163
|
-
|
1164
|
-
|
1165
|
-
|
1166
|
-
|
1167
|
-
pre_reduce_shape = ret.shape[:first_dim] + (big_shape := _broadcast_shape(*(t.shape for t in idx.values()))) + ret.shape[first_dim:]
|
1168
|
-
|
1169
|
-
# create masks
|
1170
|
-
for dim, i in idx.items():
|
1171
|
-
try: i = i.reshape(i.shape + (1,)*(ret.ndim - first_dim)).expand(pre_reduce_shape)
|
1112
|
+
indices_parsed, dim = [], 0
|
1113
|
+
for index in indices:
|
1114
|
+
size = 1 if index is None else self.shape[dim]
|
1115
|
+
boundary, stride = [0, size], 1 # defaults
|
1116
|
+
match index:
|
1117
|
+
case list() | tuple() | Tensor():
|
1118
|
+
if not isinstance(index, Tensor): index = Tensor(index, self.device, requires_grad=False)
|
1119
|
+
if not dtypes.is_int(index.dtype): raise IndexError(f"index dtype {index.dtype} is not supported")
|
1120
|
+
index = (index.to(self.device) < 0).where(size, 0) + index # treat negative index values
|
1121
|
+
case int() | UOp(): # sint
|
1122
|
+
if index >= size or index < -size: raise IndexError(f"{index=} is out of bounds with {size=}")
|
1123
|
+
boundary = [index, index+1] if index >= 0 else [index+size, index+size+1]
|
1124
|
+
case slice():
|
1125
|
+
if index.step == 0: raise ValueError(f"{index=} cannot have 0 as step")
|
1126
|
+
if not all(isinstance(s,int) or s is None for s in (index.start,index.stop,index.step)): raise TypeError("only int slicing is supported")
|
1127
|
+
# handle int slicing
|
1128
|
+
*boundary, stride = index.indices(cast(SupportsIndex, size))
|
1129
|
+
if stride * (boundary[1] - boundary[0]) < 0: boundary = [0, 0]
|
1130
|
+
elif stride < 0: boundary = [boundary[1] + 1, boundary[0] + 1]
|
1131
|
+
# update size for slice
|
1132
|
+
size = ceildiv((boundary[1] - boundary[0]), abs(stride))
|
1133
|
+
case None: pass # do nothing
|
1134
|
+
case _: raise IndexError(f"{type(index).__name__} indexing is not supported")
|
1135
|
+
indices_parsed.append({"index":index, "size":size, "boundary":tuple(boundary), "stride":stride})
|
1136
|
+
if index is not None: dim += 1
|
1137
|
+
|
1138
|
+
# movement op indexing
|
1139
|
+
if mops := [i for i in indices_parsed if i['index'] is not None]:
|
1140
|
+
# flip negative strides
|
1141
|
+
shrinks, strides = zip(*((i['boundary'], i['stride']) for i in mops))
|
1142
|
+
x = x.shrink(shrinks).flip(tuple(i for i,st in enumerate(strides) if st < 0))
|
1143
|
+
# handle stride != 1 or -1
|
1144
|
+
if any(abs(st) != 1 for st in strides):
|
1145
|
+
strides = tuple(abs(s) for s in strides)
|
1146
|
+
# pad shape to multiple of stride
|
1147
|
+
if not all_int(x.shape): raise RuntimeError("symbolic shape not supported")
|
1148
|
+
x = x.pad(tuple((0, round_up(s, st) - s) for s, st in zip(x.shape, strides)))
|
1149
|
+
x = x.reshape(tuple(flatten((s // st, st) for s, st in zip(x.shape, strides))))
|
1150
|
+
x = x.shrink(tuple(flatten(((0, s), (0, 1)) for s in x.shape[::2]))).reshape(x.shape[::2])
|
1151
|
+
|
1152
|
+
# dim injection from None by including None dim size (which is 1) and dim collapse by skipping int dim size
|
1153
|
+
x = x.reshape(tuple(index['size'] for index in indices_parsed if not isinstance(index['index'], int)))
|
1154
|
+
|
1155
|
+
# tensor indexing
|
1156
|
+
if tops := [(d,i) for d,i in enumerate(i_ for i_ in indices_parsed if not isinstance(i_['index'], int)) if isinstance(i['index'], Tensor)]:
|
1157
|
+
# unload the tensor object into actual tensors
|
1158
|
+
dims, tensors, masks = [d for d,_ in tops], cast(list[Tensor], [i['index'] for _,i in tops]), []
|
1159
|
+
pre_reduce_shape = x.shape[:dims[0]] + (big_shape := _broadcast_shape(*(t.shape for t in tensors))) + x.shape[dims[0]:]
|
1160
|
+
|
1161
|
+
# create index masks
|
1162
|
+
for dim, tensor in zip(dims, tensors):
|
1163
|
+
try: i = tensor.reshape(tensor.shape + (1,)*(x.ndim - dims[0])).expand(pre_reduce_shape)
|
1172
1164
|
except ValueError as e: raise IndexError(f"cannot broadcast indices: {e}") from e
|
1173
|
-
|
1174
|
-
masks.append(i == a)
|
1165
|
+
masks.append(i._one_hot_along_dim(num_classes=x.shape[dim], dim=(dim - x.ndim)))
|
1175
1166
|
|
1176
1167
|
# reduce masks to 1 mask
|
1177
1168
|
mask: Tensor = functools.reduce(lambda x,y: x.mul(y), masks)
|
1178
1169
|
|
1179
1170
|
# inject 1's for the extra dims added in create masks
|
1180
|
-
reshape_arg =
|
1171
|
+
reshape_arg = x.shape[:dims[0]] + (1,) * len(big_shape) + x.shape[dims[0]:]
|
1181
1172
|
# sum reduce the extra dims introduced in create masks
|
1182
|
-
|
1173
|
+
x = (x.reshape(reshape_arg) * mask).sum(sum_axis:=tuple(d + len(big_shape) for d in dims), acc_dtype=x.dtype)
|
1183
1174
|
|
1184
1175
|
# special permute case
|
1185
|
-
if
|
1186
|
-
|
1176
|
+
if dims[0] != 0 and len(dims) != 1 and tuple(dims) != tuple(range(dims[0], dims[-1]+1)):
|
1177
|
+
x = x.permute(*range(dims[0], dims[0]+len(big_shape)), *range(0, dims[0]), *range(dims[0]+len(big_shape), x.ndim))
|
1187
1178
|
|
1188
1179
|
# for advanced setitem, returns whole tensor with indices replaced
|
1189
1180
|
if v is not None:
|
1190
|
-
vb = v.cast(self.dtype)._broadcast_to(_broadcast_shape(
|
1181
|
+
vb = v.cast(self.dtype)._broadcast_to(_broadcast_shape(x.shape, v.shape))
|
1191
1182
|
# add back reduced dims from sum
|
1192
1183
|
for dim in sum_axis: vb = vb.unsqueeze(dim)
|
1193
|
-
# axis to be reduced to match self.shape
|
1194
|
-
|
1195
|
-
# apply mask to v(broadcasted) and reduce such that if v contains repeated indices the last one remains
|
1196
|
-
vb = vb * mask
|
1197
|
-
for dim in axis: vb = functools.reduce(lambda x,y: y.where(y, x), vb.split(1, dim))
|
1198
|
-
# reduce mask and select from v(get rid of extra dims from reduce) for each True element in mask else select from self
|
1199
|
-
ret = mask.any(axis).where(vb.squeeze(), self)
|
1184
|
+
# run _masked_setitem on tuple of axis that is to be reduced to match self.shape
|
1185
|
+
x = _masked_setitem(self, vb, mask, tuple(range(dims[0], dims[0] + len(big_shape))))
|
1200
1186
|
|
1201
|
-
return
|
1187
|
+
return x
|
1202
1188
|
|
1203
1189
|
def __getitem__(self, indices) -> Tensor:
|
1190
|
+
"""
|
1191
|
+
Retrieve a sub-tensor using indexing.
|
1192
|
+
|
1193
|
+
Supported Index Types: `int | slice | Tensor | None | List | Tuple | Ellipsis`
|
1194
|
+
|
1195
|
+
Examples:
|
1196
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1197
|
+
t = Tensor.arange(12).reshape(3, 4)
|
1198
|
+
print(t.numpy())
|
1199
|
+
```
|
1200
|
+
|
1201
|
+
- Int Indexing: Select an element or sub-tensor using integers for each dimension.
|
1202
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1203
|
+
print(t[1, 2].numpy())
|
1204
|
+
```
|
1205
|
+
|
1206
|
+
- Slice Indexing: Select a range of elements using slice notation (`start:end:stride`).
|
1207
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1208
|
+
print(t[0:2, ::2].numpy())
|
1209
|
+
```
|
1210
|
+
|
1211
|
+
- Tensor Indexing: Use another tensor as indices for advanced indexing. Using `tuple` or `list` here also works.
|
1212
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1213
|
+
print(t[Tensor([2, 0, 1]), Tensor([1, 2, 3])].numpy())
|
1214
|
+
```
|
1215
|
+
|
1216
|
+
- `None` Indexing: Add a new dimension to the tensor.
|
1217
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1218
|
+
print(t[:, None].shape)
|
1219
|
+
```
|
1220
|
+
|
1221
|
+
NOTE: Out-of-bounds indexing results in a value of `0`.
|
1222
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1223
|
+
t = Tensor([1, 2, 3])
|
1224
|
+
print(t[Tensor([4, 3, 2])].numpy())
|
1225
|
+
```
|
1226
|
+
"""
|
1204
1227
|
return self._getitem(indices)
|
1205
1228
|
|
1206
1229
|
def __setitem__(self, indices, v:Union[Tensor, ConstType]) -> None:
|
@@ -1208,7 +1231,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
1208
1231
|
self._getitem(indices).assign(v)
|
1209
1232
|
return
|
1210
1233
|
# NOTE: check that setitem target is valid first
|
1211
|
-
if not
|
1234
|
+
if not unwrap(self.lazydata.st).contiguous: raise RuntimeError("setitem target needs to be contiguous")
|
1212
1235
|
if not isinstance(v, (Tensor, float, int, bool)): raise TypeError(f"can't set a {type(v).__name__} to a Tensor")
|
1213
1236
|
if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype)
|
1214
1237
|
if self.requires_grad or v.requires_grad: raise NotImplementedError("setitem with requires_grad is not supported")
|
@@ -1238,7 +1261,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
1238
1261
|
assert all(s >= i for d,(s,i) in enumerate(zip(self.shape, index.shape)) if d != dim), "requires self.shape[d] >= index.shape[d] for all d != dim"
|
1239
1262
|
index = index.to(self.device)
|
1240
1263
|
x = self.shrink(tuple((0, i) if d != dim else None for d,i in enumerate(index.shape))).unsqueeze(-1).transpose(-1, dim)
|
1241
|
-
return (
|
1264
|
+
return (x * index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim])).sum(-1, acc_dtype=self.dtype)
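`gather` now builds a one-hot mask from `index` along `dim` and sum-reduces it away, so each output element selects the corresponding entry of `self`. A small usage sketch with assumed toy values:

```python
from tinygrad import Tensor

t = Tensor([[1, 2], [3, 4]])
idx = Tensor([[0, 0], [1, 0]])
print(t.gather(1, idx).numpy())  # [[1 1] [4 3]]: out[i, j] = t[i, idx[i, j]]
```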
|
1242
1265
|
|
1243
1266
|
def cat(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
|
1244
1267
|
"""
|
@@ -1302,7 +1325,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
1302
1325
|
```
|
1303
1326
|
"""
|
1304
1327
|
repeats = argfix(repeats, *args)
|
1305
|
-
base_shape =
|
1328
|
+
base_shape = _align_left(self.shape, repeats)[0]
|
1306
1329
|
unsqueezed_shape = flatten([[1, s] for s in base_shape])
|
1307
1330
|
expanded_shape = flatten([[r, s] for r,s in zip(repeats, base_shape)])
|
1308
1331
|
final_shape = [r*s for r,s in zip(repeats, base_shape)]
|
@@ -1313,7 +1336,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
1313
1336
|
if not -max(1, total) <= dim <= max(1, total)-1: raise IndexError(f"{dim=} out of range {[-max(1, total), max(1, total)-1]}")
|
1314
1337
|
return dim + total if dim < 0 else dim
|
1315
1338
|
|
1316
|
-
def split(self, sizes:Union[int,
|
1339
|
+
def split(self, sizes:Union[int, list[int]], dim:int=0) -> tuple[Tensor, ...]:
|
1317
1340
|
"""
|
1318
1341
|
Splits the tensor into chunks along the dimension specified by `dim`.
|
1319
1342
|
If `sizes` is an integer, it splits into equally sized chunks if possible, otherwise the last chunk will be smaller.
|
@@ -1338,7 +1361,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
1338
1361
|
assert sum(sizes) == self.shape[dim], f"expect sizes to sum exactly to {self.shape[dim]}, but got {sum(sizes)}"
|
1339
1362
|
return tuple(self[sl] for sl in [tuple([slice(None)]*dim + [slice(sum(sizes[:i]), sum(sizes[:i + 1]))]) for i in range(len(sizes))])
|
1340
1363
|
|
1341
|
-
def chunk(self, chunks:int, dim:int=0) ->
|
1364
|
+
def chunk(self, chunks:int, dim:int=0) -> list[Tensor]:
|
1342
1365
|
"""
|
1343
1366
|
Splits the tensor into `chunks` number of chunks along the dimension `dim`.
|
1344
1367
|
If the tensor size along `dim` is not divisible by `chunks`, all returned chunks will be the same size except the last one.
|
@@ -1362,7 +1385,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
1362
1385
|
dim = self._resolve_dim(dim)
|
1363
1386
|
return list(self.split(ceildiv(self.shape[dim], chunks) if self.shape[dim] else [0]*chunks, dim=dim))
|
1364
1387
|
|
1365
|
-
def meshgrid(self:Tensor, *args:Tensor, indexing:Union[Literal["ij"], Literal["xy"]]="ij") ->
|
1388
|
+
def meshgrid(self:Tensor, *args:Tensor, indexing:Union[Literal["ij"], Literal["xy"]]="ij") -> tuple[Tensor, ...]:
|
1366
1389
|
"""
|
1367
1390
|
Generates coordinate matrices from coordinate vectors.
|
1368
1391
|
Input tensors can be scalars or 1D tensors.
|
@@ -1462,7 +1485,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
1462
1485
|
start_dim, end_dim = self._resolve_dim(start_dim), self._resolve_dim(end_dim)
|
1463
1486
|
return self.reshape(self.shape[:start_dim] + (prod(self.shape[start_dim:end_dim+1]), ) + self.shape[end_dim+1:])
|
1464
1487
|
|
1465
|
-
def unflatten(self, dim:int, sizes:
|
1488
|
+
def unflatten(self, dim:int, sizes:tuple[int,...]):
|
1466
1489
|
"""
|
1467
1490
|
Unflattens dimension `dim` of the tensor into multiple dimensions specified by `sizes`. `Tensor.flatten()` is the inverse of this function.
|
1468
1491
|
|
@@ -1479,7 +1502,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
1479
1502
|
dim = self._resolve_dim(dim)
|
1480
1503
|
return self.reshape(self.shape[:dim] + sizes + self.shape[dim+1:])
|
1481
1504
|
|
1482
|
-
def roll(self, shifts:Union[int,
|
1505
|
+
def roll(self, shifts:Union[int, tuple[int, ...]], dims:Union[int, tuple[int, ...]]) -> Tensor:
|
1483
1506
|
"""
|
1484
1507
|
Rolls the tensor along specified dimension(s).
|
1485
1508
|
The rolling operation is circular, meaning that elements that go beyond the edge are wrapped around to the beginning of the dimension.
|
@@ -1499,12 +1522,52 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
1499
1522
|
rolled[tuple(slice(None) if i != dim else slice(None, -shift) for i in range(rolled.ndim))], dim=dim)
|
1500
1523
|
return rolled
|
1501
1524
|
|
1525
|
+
def rearrange(self, formula:str, **sizes) -> Tensor:
|
1526
|
+
"""
|
1527
|
+
Rearranges input according to formula
|
1528
|
+
|
1529
|
+
See: https://einops.rocks/api/rearrange/
|
1530
|
+
|
1531
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1532
|
+
x = Tensor([[1, 2], [3, 4]])
|
1533
|
+
print(Tensor.rearrange(x, "batch channel -> (batch channel)").numpy())
|
1534
|
+
```
|
1535
|
+
"""
|
1536
|
+
def parse_formula(formula: str):
|
1537
|
+
tokens = f" {formula} ".replace("…", "...").replace("(", " ( ").replace(")", " ) ").replace(" ", " ").replace(" 1 ", " ( ) ").split()
|
1538
|
+
lparens, rparens = map(lambda x: [i for i, ch in enumerate(tokens) if ch == x], ("(", ")"))
|
1539
|
+
pairs = list(zip(lparens, rparens))
|
1540
|
+
assert len(lparens) == len(rparens) and sorted(flatten(pairs)) == flatten(pairs), "bracket mismatch"
|
1541
|
+
return [name for name in tokens if name not in ("(", ")")], [(s - 2*i, e - 1 - 2*i) for i, (s, e) in enumerate(pairs)]
|
1542
|
+
|
1543
|
+
assert formula.count("->") == 1, 'need exactly one "->" in formula'
|
1544
|
+
|
1545
|
+
(lhs, unflatten_dims), (rhs, flatten_dims) = map(parse_formula, formula.split("->"))
|
1546
|
+
|
1547
|
+
for name in sizes: assert name in lhs, f"axis {name} is not used in transform"
|
1548
|
+
assert sorted(lhs) == sorted(rhs) and len(lhs) == len(set(lhs)), f"name mismatch in {formula}"
|
1549
|
+
for name in flatten((lhs, rhs)): assert name == "..." or (name.isidentifier() and "_" not in (name[0], name[-1])), f"invalid axis name {name}"
|
1550
|
+
assert "..." not in flatten([lhs[s:e] for s, e in unflatten_dims]), f"cannot have collapsed ellipsis (...) in lhs of {formula}"
|
1551
|
+
assert lhs.count("...") <= 1, f"too many ellipses in {formula}"
|
1552
|
+
|
1553
|
+
# resolve ellipsis
|
1554
|
+
if "..." in lhs: ell_len = len(self.shape) - len(lhs) + 1 + sum(e - s - 1 for s, e in unflatten_dims)
|
1555
|
+
lhs, rhs = map(lambda l: l[:(i:=l.index("..."))] + [f"...{j}" for j in range(ell_len)] + l[i + 1:] if "..." in l else l, (lhs, rhs))
|
1556
|
+
unflatten_dims = [(s + (ell_len - 1 if "...0" in lhs[:s] else 0), e + (ell_len - 1 if "...0" in lhs[:e] else 0)) for s, e in unflatten_dims]
|
1557
|
+
flatten_dims = [(s + (ell_len - 1 if "...0" in rhs[:s] else 0), e + (ell_len - 1 if "...0" in rhs[:e] else 0)) for s, e in flatten_dims]
|
1558
|
+
|
1559
|
+
# apply movement ops in order unflatten -> permute -> flatten/unsqueeze
|
1560
|
+
t = functools.reduce(lambda x, dims: x.unflatten(dims[0], tuple(sizes.get(lhs[d], -1) for d in range(*dims))), unflatten_dims, self)
|
1561
|
+
for i, name in enumerate(lhs): assert (name not in sizes) or sizes[name] == t.shape[i], f"size provided for dimension {name} incorrect"
|
1562
|
+
t = t.permute([lhs.index(name) for name in rhs])
|
1563
|
+
return functools.reduce(lambda x, dims: x.flatten(dims[0], dims[1] - 1) if dims[0]<dims[1] else x.unsqueeze(dims[0]), reversed(flatten_dims), t)
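A usage sketch of the new `rearrange`, showing named axis sizes and the ellipsis handling implemented above; the shapes are arbitrary assumptions:

```python
from tinygrad import Tensor

x = Tensor.arange(24).reshape(2, 3, 4)
# unflatten the last axis into two named axes, then permute them to the front
print(x.rearrange("b c (h w) -> h w b c", h=2, w=2).shape)  # (2, 2, 2, 3)
# "..." stands for any number of untouched axes
print(x.rearrange("... w -> w ...").shape)                  # (4, 2, 3)
```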
|
1564
|
+
|
1502
1565
|
# ***** reduce ops *****
|
1503
1566
|
|
1504
|
-
def _reduce(self,
|
1567
|
+
def _reduce(self, op:Ops, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False) -> Tensor:
|
1505
1568
|
axis = tuple(self._resolve_dim(x) for x in (range(self.ndim) if axis is None else make_tuple(axis, 1)))
|
1506
1569
|
if self.ndim == 0: axis = ()
|
1507
|
-
ret =
|
1570
|
+
ret = self._apply_uop(UOp.r, op=op, axis=axis)
|
1508
1571
|
return ret if keepdim else ret.reshape(tuple(s for i,s in enumerate(self.shape) if i not in axis))
|
1509
1572
|
|
1510
1573
|
def sum(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, acc_dtype:Optional[DTypeLike]=None):
|
@@ -1531,7 +1594,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
1531
1594
|
print(t.sum(axis=1).numpy())
|
1532
1595
|
```
|
1533
1596
|
"""
|
1534
|
-
ret = self.cast(sum_acc_dtype(self.dtype) if acc_dtype is None else acc_dtype)._reduce(
|
1597
|
+
ret = self.cast(sum_acc_dtype(self.dtype) if acc_dtype is None else acc_dtype)._reduce(Ops.ADD, axis, keepdim)
|
1535
1598
|
return ret.cast(self.dtype) if acc_dtype is None and self.dtype in (dtypes.float16, dtypes.bfloat16) else ret
|
1536
1599
|
|
1537
1600
|
def prod(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, acc_dtype:Optional[DTypeLike]=None):
|
@@ -1558,7 +1621,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
1558
1621
|
print(t.prod(axis=1).numpy())
|
1559
1622
|
```
|
1560
1623
|
"""
|
1561
|
-
return self.cast(acc_dtype if acc_dtype is not None else self.dtype)._reduce(
|
1624
|
+
return self.cast(acc_dtype if acc_dtype is not None else self.dtype)._reduce(Ops.MUL, axis, keepdim)
|
1562
1625
|
|
1563
1626
|
def max(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
|
1564
1627
|
"""
|
@@ -1581,7 +1644,9 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
1581
1644
|
print(t.max(axis=1, keepdim=True).numpy())
|
1582
1645
|
```
|
1583
1646
|
"""
|
1584
|
-
return self._reduce(
|
1647
|
+
return self._reduce(Ops.MAX, axis, keepdim)
|
1648
|
+
|
1649
|
+
def _inverse(self): return -self if self.is_floating_point() else ~self if dtypes.is_int(self.dtype) else self.logical_not()
|
1585
1650
|
|
1586
1651
|
def min(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
|
1587
1652
|
"""
|
@@ -1604,8 +1669,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
1604
1669
|
print(t.min(axis=1, keepdim=True).numpy())
|
1605
1670
|
```
|
1606
1671
|
"""
|
1607
|
-
|
1608
|
-
return -((-self).max(axis=axis, keepdim=keepdim))
|
1672
|
+
return self._inverse().max(axis=axis, keepdim=keepdim)._inverse()
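`_inverse` works for `min` because each branch is strictly order-reversing: negation for floats, bitwise-not for ints (`~x == -x - 1` in two's complement), logical-not for bools, so `min(x) == inverse(max(inverse(x)))`. A quick check of the integer case, assuming only the public API:

```python
from tinygrad import Tensor

t = Tensor([[3, -7, 5], [0, 2, -1]])
print(t.min(axis=1).numpy())        # [-7 -1]
print((~(~t).max(axis=1)).numpy())  # same values via the bitwise-not identity
```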
|
1609
1673
|
|
1610
1674
|
def any(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
|
1611
1675
|
"""
|
@@ -1745,8 +1809,8 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
1745
1809
|
return self.std(axis, keepdim, correction), self.mean(axis, keepdim)
|
1746
1810
|
|
1747
1811
|
def _softmax(self, axis, dtype:Optional[DTypeLike]=None):
|
1748
|
-
|
1749
|
-
|
1812
|
+
m = self - self.max(axis=axis, keepdim=True).detach()
|
1813
|
+
if dtype is not None: m = m.cast(dtype)
|
1750
1814
|
e = m.exp()
|
1751
1815
|
return m, e, e.sum(axis=axis, keepdim=True)
|
1752
1816
|
|
@@ -1898,47 +1962,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
1898
1962
|
print(t.argmin(axis=1).numpy()) # Returns the indices of the minimum values along axis 1.
|
1899
1963
|
```
|
1900
1964
|
"""
|
1901
|
-
return (
|
1902
|
-
|
1903
|
-
def rearrange(self, formula: str, **sizes) -> Tensor:
|
1904
|
-
"""
|
1905
|
-
Rearranges input according to formula
|
1906
|
-
|
1907
|
-
See: https://einops.rocks/api/rearrange/
|
1908
|
-
|
1909
|
-
```python exec="true" source="above" session="tensor" result="python"
|
1910
|
-
x = Tensor([[1, 2], [3, 4]])
|
1911
|
-
print(Tensor.rearrange(x, "batch channel -> (batch channel)).numpy())
|
1912
|
-
```
|
1913
|
-
"""
|
1914
|
-
def parse_formula(formula: str):
|
1915
|
-
tokens = f" {formula} ".replace("…", "...").replace("(", " ( ").replace(")", " ) ").replace(" ", " ").replace(" 1 ", " ( ) ").split()
|
1916
|
-
lparens, rparens = map(lambda x: [i for i, ch in enumerate(tokens) if ch == x], ("(", ")"))
|
1917
|
-
pairs = list(zip(lparens, rparens))
|
1918
|
-
assert len(lparens) == len(rparens) and sorted(flatten(pairs)) == flatten(pairs), "bracket mismatch"
|
1919
|
-
return [name for name in tokens if name not in ("(", ")")], [(s - 2*i, e - 1 - 2*i) for i, (s, e) in enumerate(pairs)]
|
1920
|
-
|
1921
|
-
assert formula.count("->") == 1, 'need exactly one "->" in formula'
|
1922
|
-
|
1923
|
-
(lhs, unflatten_dims), (rhs, flatten_dims) = map(parse_formula, formula.split("->"))
|
1924
|
-
|
1925
|
-
for name in sizes: assert name in lhs, f"axis {name} is not used in transform"
|
1926
|
-
assert sorted(lhs) == sorted(rhs) and len(lhs) == len(set(lhs)), f"name mismatch in {formula}"
|
1927
|
-
for name in flatten((lhs, rhs)): assert name == "..." or (name.isidentifier() and "_" not in (name[0], name[-1])), f"invalid axis name {name}"
|
1928
|
-
assert "..." not in flatten([lhs[s:e] for s, e in unflatten_dims]), f"cannot have collapsed ellipsis (...) in lhs of {formula}"
|
1929
|
-
assert lhs.count("...") <= 1, f"too many ellipses in {formula}"
|
1930
|
-
|
1931
|
-
# resolve ellipsis
|
1932
|
-
if "..." in lhs: ell_len = len(self.shape) - len(lhs) + 1 + sum(e - s - 1 for s, e in unflatten_dims)
|
1933
|
-
lhs, rhs = map(lambda l: l[:(i:=l.index("..."))] + [f"...{j}" for j in range(ell_len)] + l[i + 1:] if "..." in l else l, (lhs, rhs))
|
1934
|
-
unflatten_dims = [(s + (ell_len - 1 if "...0" in lhs[:s] else 0), e + (ell_len - 1 if "...0" in lhs[:e] else 0)) for s, e in unflatten_dims]
|
1935
|
-
flatten_dims = [(s + (ell_len - 1 if "...0" in rhs[:s] else 0), e + (ell_len - 1 if "...0" in rhs[:e] else 0)) for s, e in flatten_dims]
|
1936
|
-
|
1937
|
-
# apply movement ops in order unflatten -> permute -> flatten/unsqueeze
|
1938
|
-
t = functools.reduce(lambda x, dims: x.unflatten(dims[0], tuple(sizes.get(lhs[d], -1) for d in range(*dims))), unflatten_dims, self)
|
1939
|
-
for i, name in enumerate(lhs): assert (name not in sizes) or sizes[name] == t.shape[i], f"size provided for dimension {name} incorrect"
|
1940
|
-
t = t.permute([lhs.index(name) for name in rhs])
|
1941
|
-
return functools.reduce(lambda x, dims: x.flatten(dims[0], dims[1] - 1) if dims[0]<dims[1] else x.unsqueeze(dims[0]), reversed(flatten_dims), t)
|
1965
|
+
return self._inverse().argmax(axis=axis, keepdim=keepdim)
|
1942
1966
|
|
1943
1967
|
@staticmethod
|
1944
1968
|
def einsum(formula:str, *operands:Tensor|Sequence[Tensor], acc_dtype:Optional[DTypeLike]=None) -> Tensor:
|
@@ -1964,7 +1988,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
1964
1988
|
(inputs_str, out_ellipse + ''.join(sorted(c for c in inputs_str if inputs_str.count(c) == 1 and c.isalpha() and c not in out_ellipse)))
|
1965
1989
|
return formula.split("->") if "->" in formula else (formula, ''.join(c for c in sorted(formula) if formula.count(c) == 1 and c.isalpha()))
|
1966
1990
|
|
1967
|
-
xs:
|
1991
|
+
xs:tuple[Tensor, ...] = argfix(*operands)
|
1968
1992
|
inputs_str, output = parse_formula(formula, *xs)
|
1969
1993
|
inputs = inputs_str.split(",")
|
1970
1994
|
assert len(xs) == len(inputs), f"number of inputs doesn't match number of operands in formula, expected {len(inputs)}, got {len(xs)}"
|
@@ -1972,7 +1996,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
1972
1996
|
# map the value of each letter in the formula
|
1973
1997
|
letter_val = sorted(merge_dicts([dict(zip(letters, tensor.shape)) for letters, tensor in zip(inputs, xs)]).items())
|
1974
1998
|
|
1975
|
-
xs_:
|
1999
|
+
xs_:list[Tensor] = []
|
1976
2000
|
lhs = [sorted(enumerate(s), key=lambda e:e[1]) for s in inputs]
|
1977
2001
|
for x,(order,letters) in zip(xs, [list(zip(*l)) for l in lhs]):
|
1978
2002
|
# permute to the sorted letter order, then reshape/expand to create dimensions for the missing letters
|
@@ -1987,7 +2011,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
1987
2011
|
|
1988
2012
|
# ***** processing ops *****
|
1989
2013
|
|
1990
|
-
def _pool(self, k_:
|
2014
|
+
def _pool(self, k_:tuple[sint, ...], stride:Union[tuple[int, ...], int]=1, dilation:Union[tuple[int, ...], int]=1) -> Tensor:
|
1991
2015
|
assert len(self.shape) >= len(k_), f"can't pool {self.shape} with {k_}"
|
1992
2016
|
s_, d_ = make_tuple(stride, len(k_)), make_tuple(dilation, len(k_))
|
1993
2017
|
assert len(k_) == len(s_) == len(d_), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}"
|
@@ -1995,10 +2019,12 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
1995
2019
|
assert all(resolve(d*(k-1)+1 <= i) for k,d,i in zip(k_,d_,i_)), "kernel size cannot be greater than actual input size"
|
1996
2020
|
o_ = [ceildiv(i-d*(k-1), s) for i,d,k,s in zip(i_,d_,k_,s_)]
|
1997
2021
|
if any(resolve(k > s) for k,s in zip(k_,s_)) or any(d != 1 for d in d_):
|
1998
|
-
#
|
1999
|
-
|
2022
|
+
# input size scaling factor to make sure shrink for stride is possible
|
2023
|
+
f_ = [1 + int(resolve(o*s > i+d)) for o,s,i,d in zip(o_,s_,i_,d_)]
|
2024
|
+
# repeats such that we don't need padding
|
2025
|
+
x = self.repeat([1]*len(noop) + [ceildiv(k*(i*f+d),i) for k,i,d,f in zip(k_,i_,d_,f_)])
|
2000
2026
|
# handle dilation
|
2001
|
-
x = x.shrink(tuple(noop + [(0,k*(i+d)) for k,i,d in zip(k_,i_,d_)])).reshape(noop + flatten((k,i+d) for k,i,d in zip(k_,i_,d_)))
|
2027
|
+
x = x.shrink(tuple(noop + [(0,k*(i*f+d)) for k,i,d,f in zip(k_,i_,d_,f_)])).reshape(noop + flatten((k,(i*f+d)) for k,i,d,f in zip(k_,i_,d_,f_)))
|
2002
2028
|
# handle stride
|
2003
2029
|
x = x.shrink(tuple(noop + flatten(((0,k), (0,o*s)) for k,o,s in zip(k_,o_,s_)))).reshape(noop + flatten((k,o,s) for k,o,s in zip(k_,o_,s_)))
|
2004
2030
|
x = x.shrink(tuple(noop + flatten(((0,k), (0,o), (0,1)) for k,o in zip(k_,o_)))).reshape(noop + flatten((k,o) for k,o in zip(k_,o_)))
|
@@ -2010,14 +2036,44 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
2010
2036
|
x = x.shrink(tuple(noop + flatten(((0,o), (0,k)) for o,k in zip(o_,k_))))
|
2011
2037
|
return x.permute(*range(len(noop)), *[len(noop)+i*2 for i in range(len(i_))], *[len(noop)+i*2+1 for i in range(len(i_))])
|
2012
2038
|
|
2013
|
-
def
|
2039
|
+
def _resolve_pool_pads(self, padding:Union[int, Sequence[int]], dims:int) -> Sequence[int]:
|
2040
|
+
if not isinstance(padding, int) and not (len(padding) == 2*dims or len(padding) == dims):
|
2041
|
+
raise ValueError(f"Padding must be an int or a sequence of length {dims} or {2*dims}, but got {padding=} for {self.shape=} with {dims=}.")
|
2014
2042
|
return [padding]*2*dims if isinstance(padding, int) else (padding if len(padding) == 2*dims else [p for p in padding for _ in range(2)][::-1])
|
2015
2043
|
|
2044
|
+
def _apply_ceil_mode(self, pads:Sequence[int], k_:Tuple[sint, ...], s_:Union[Tuple[int, ...], int], d_:Union[Tuple[int, ...], int]) -> List[int]:
|
2045
|
+
(d_,s_), i_ = (make_tuple(x, len(k_)) for x in (d_,s_)), self.shape[-len(k_):]
|
2046
|
+
pads, grouped_pads = list(pads), _flat_to_grouped(pads)
|
2047
|
+
# https://arxiv.org/pdf/1603.07285 section 5.1, relationship 15.
|
2048
|
+
o_ = [ceildiv(i+pB+pA - (d*(k-1)+1), s) + 1 for i,d,k,s,(pB,pA) in zip(i_,d_,k_,s_,grouped_pads)]
|
2049
|
+
for dim,(o,i,s,k,d,(pB,pA)) in enumerate(zip(o_,i_,s_,k_,d_,grouped_pads)):
|
2050
|
+
# we have to do additional padding before `_pool` so that `o_` in `_pool` is calculated correctly
|
2051
|
+
# `s*(o-1) + (d*(k-1)+1) - (i+pB+pA)` -> last_sliding_window_start + full_kernel_size - padded_input_shape
|
2052
|
+
# we decrease padding in the case that a sliding window starts in the end padded region, thereby decreasing `o_` in `_pool`
|
2053
|
+
# `smax(s*(o-1) - (pB+i-1), 0)` -> last_sliding_window_start - (pad_before + input_size - zero_offset)
|
2054
|
+
pads[-1-dim*2] += s*(o-1) + (d*(k-1)+1) - (i+pB+pA) - smax(s*(o-1) - (pB+i-1), 0)
|
2055
|
+
return pads
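With no padding, kernel 2 and stride 2 on a length-5 axis, the relationship above gives `ceil((5 - 2)/2) + 1 = 3` windows under `ceil_mode=True` versus 2 otherwise; the extra end padding computed here is what lets `_pool` emit that last partial window. A shape-only check, where the 5x5 input is an assumed toy example:

```python
from tinygrad import Tensor

t = Tensor.arange(25).reshape(1, 1, 5, 5).float()
print(t.max_pool2d(kernel_size=2, stride=2).shape)                  # (1, 1, 2, 2)
print(t.max_pool2d(kernel_size=2, stride=2, ceil_mode=True).shape)  # (1, 1, 3, 3)
```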
|
2056
|
+
|
2016
2057
|
# NOTE: these work for more than 2D
|
2017
|
-
def avg_pool2d(self, kernel_size=(2,2), stride=None, dilation=1, padding=0, count_include_pad=True):
|
2058
|
+
def avg_pool2d(self, kernel_size=(2,2), stride=None, dilation=1, padding=0, ceil_mode=False, count_include_pad=True):
|
2018
2059
|
"""
|
2019
2060
|
Applies average pooling over a tensor.
|
2020
2061
|
|
2062
|
+
This function supports three different types of `padding`
|
2063
|
+
|
2064
|
+
1. `int` (single value):
|
2065
|
+
Applies the same padding value uniformly to all spatial dimensions.
|
2066
|
+
|
2067
|
+
2. `Tuple[int, ...]` (length = number of spatial dimensions):
|
2068
|
+
Specifies a distinct padding value for each spatial dimension in the form `(padding_height, padding_width, ...)`.
|
2069
|
+
|
2070
|
+
3. `Tuple[int, ...]` (length = 2 * number of spatial dimensions):
|
2071
|
+
Specifies explicit padding for each side of each spatial dimension in the form
|
2072
|
+
`(padding_left, padding_right, padding_top, padding_bottom, ...)`.
|
2073
|
+
|
2074
|
+
When `ceil_mode` is set to `True`, output shape will be determined using ceil division.
|
2075
|
+
When `count_include_pad` is set to `False`, zero padding will not be included in the averaging calculation.
|
2076
|
+
|
2021
2077
|
NOTE: unlike PyTorch, this implementation is not limited to only 2d pooling and instead works for any number of dimensions.
|
2022
2078
|
|
2023
2079
|
See: https://paperswithcode.com/method/average-pooling
|
@@ -2027,17 +2083,43 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
2027
2083
|
print(t.avg_pool2d().numpy())
|
2028
2084
|
```
|
2029
2085
|
```python exec="true" source="above" session="tensor" result="python"
|
2086
|
+
print(t.avg_pool2d(ceil_mode=True).numpy())
|
2087
|
+
```
|
2088
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2030
2089
|
print(t.avg_pool2d(padding=1).numpy())
|
2031
2090
|
```
|
2091
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2092
|
+
print(t.avg_pool2d(padding=1, count_include_pad=False).numpy())
|
2093
|
+
```
|
2032
2094
|
"""
|
2033
|
-
|
2034
|
-
def pool(x:Tensor) -> Tensor: return x.pad(padding_)._pool(k_, stride if stride is not None else k_, dilation)
|
2035
|
-
|
2095
|
+
axis = tuple(range(-len(k_ := make_tuple(kernel_size, 2)), 0))
|
2096
|
+
def pool(x:Tensor, padding_:Sequence[int]) -> Tensor: return x.pad(padding_)._pool(k_, stride if stride is not None else k_, dilation)
|
2097
|
+
reg_pads = self._resolve_pool_pads(padding, len(k_))
|
2098
|
+
ceil_pads = self._apply_ceil_mode(reg_pads, k_, stride if stride is not None else k_, dilation)
|
2099
|
+
if not count_include_pad:
|
2100
|
+
pads = ceil_pads if ceil_mode else reg_pads
|
2101
|
+
return pool(self, pads).sum(axis) / pool(self.ones_like(), pads).sum(axis)
|
2102
|
+
if not ceil_mode: return pool(self, reg_pads).mean(axis)
|
2103
|
+
return pool(self, ceil_pads).sum(axis) / pool(self.pad(reg_pads).ones_like(), tuple(cp-rp for cp,rp in zip(ceil_pads, reg_pads))).sum(axis)
|
2036
2104
|
|
2037
|
-
def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1, padding=0):
|
2105
|
+
def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1, padding=0, ceil_mode=False):
|
2038
2106
|
"""
|
2039
2107
|
Applies max pooling over a tensor.
|
2040
2108
|
|
2109
|
+
This function supports three different types of `padding`
|
2110
|
+
|
2111
|
+
1. `int` (single value):
|
2112
|
+
Applies the same padding value uniformly to all spatial dimensions.
|
2113
|
+
|
2114
|
+
2. `Tuple[int, ...]` (length = number of spatial dimensions):
|
2115
|
+
Specifies a distinct padding value for each spatial dimension in the form `(padding_height, padding_width, ...)`.
|
2116
|
+
|
2117
|
+
3. `Tuple[int, ...]` (length = 2 * number of spatial dimensions):
|
2118
|
+
Specifies explicit padding for each side of each spatial dimension in the form
|
2119
|
+
`(padding_left, padding_right, padding_top, padding_bottom, ...)`.
|
2120
|
+
|
2121
|
+
When `ceil_mode` is set to `True`, output shape will be determined using ceil division.
|
2122
|
+
|
2041
2123
|
NOTE: unlike PyTorch, this implementation is not limited to only 2d pooling and instead works for any number of dimensions.
|
2042
2124
|
|
2043
2125
|
See: https://paperswithcode.com/method/max-pooling
|
@@ -2047,17 +2129,33 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
2047
2129
|
print(t.max_pool2d().numpy())
|
2048
2130
|
```
|
2049
2131
|
```python exec="true" source="above" session="tensor" result="python"
|
2132
|
+
print(t.max_pool2d(ceil_mode=True).numpy())
|
2133
|
+
```
|
2134
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2050
2135
|
print(t.max_pool2d(padding=1).numpy())
|
2051
2136
|
```
|
2052
2137
|
"""
|
2053
|
-
|
2054
|
-
|
2138
|
+
pads = self._resolve_pool_pads(padding, len(k_ := make_tuple(kernel_size, 2)))
|
2139
|
+
if ceil_mode: pads = self._apply_ceil_mode(pads, k_, stride if stride is not None else k_, dilation)
|
2140
|
+
return self.pad(pads, value=dtypes.min(self.dtype))._pool(k_, stride if stride is not None else k_, dilation).max(tuple(range(-len(k_), 0)))
|
2055
2141
|
|
2056
|
-
def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding:int|
|
2142
|
+
def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding:int|tuple[int, ...]=0,
|
2057
2143
|
acc_dtype:Optional[DTypeLike]=None) -> Tensor:
|
2058
2144
|
"""
|
2059
2145
|
Applies a convolution over a tensor with a given `weight` and optional `bias`.
|
2060
2146
|
|
2147
|
+
This function supports three different types of `padding`
|
2148
|
+
|
2149
|
+
1. `int` (single value):
|
2150
|
+
Applies the same padding value uniformly to all spatial dimensions.
|
2151
|
+
|
2152
|
+
2. `Tuple[int, ...]` (length = number of spatial dimensions):
|
2153
|
+
Specifies a distinct padding value for each spatial dimension in the form `(padding_height, padding_width, ...)`.
|
2154
|
+
|
2155
|
+
3. `Tuple[int, ...]` (length = 2 * number of spatial dimensions):
|
2156
|
+
Specifies explicit padding for each side of each spatial dimension in the form
|
2157
|
+
`(padding_left, padding_right, padding_top, padding_bottom, ...)`.
|
2158
|
+
|
2061
2159
|
NOTE: unlike PyTorch, this implementation is not limited to only 2d convolutions and instead works for any number of dimensions.
|
2062
2160
|
|
2063
2161
|
See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
|
@@ -2070,9 +2168,8 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
2070
2168
|
"""
|
2071
2169
|
if IMAGE: return self.image_conv2d(weight, bias, groups, stride, dilation, padding, acc_dtype)
|
2072
2170
|
(bs,cin_), (cout,cin), HW = self.shape[:2], weight.shape[:2], weight.shape[2:]
|
2171
|
+
padding_ = self._resolve_pool_pads(padding, len(HW))
|
2073
2172
|
assert groups*cin == cin_ and len(self.shape) == len(weight.shape), f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})" # noqa: E501
|
2074
|
-
if isinstance(padding, (tuple,list)): assert len(padding) == 2*len(HW) or len(padding) == len(HW), f"Expected padding of length {2*len(HW)} or {len(HW)}, but got {len(padding)} for tensor of shape {self.shape}" # noqa: E501
|
2075
|
-
padding_ = self._padding2d(padding, len(HW))
|
2076
2173
|
|
2077
2174
|
# conv2d is a pooling op (with padding)
|
2078
2175
|
x = self.pad(padding_)._pool(HW, stride, dilation) # (bs, groups*cin, oy, ox, H, W)
|
@@ -2120,6 +2217,18 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
2120
2217
|
"""
|
2121
2218
|
Applies a transposed convolution over a tensor with a given `weight` and optional `bias`.
|
2122
2219
|
|
2220
|
+
This function supports three different types of `padding`
|
2221
|
+
|
2222
|
+
1. `int` (single value):
|
2223
|
+
Applies the same padding value uniformly to all spatial dimensions.
|
2224
|
+
|
2225
|
+
2. `Tuple[int, ...]` (length = number of spatial dimensions):
|
2226
|
+
Specifies a distinct padding value for each spatial dimension in the form `(padding_height, padding_width, ...)`.
|
2227
|
+
|
2228
|
+
3. `Tuple[int, ...]` (length = 2 * number of spatial dimensions):
|
2229
|
+
Specifies explicit padding for each side of each spatial dimension in the form
|
2230
|
+
`(padding_left, padding_right, padding_top, padding_bottom, ...)`.
|
2231
|
+
|
2123
2232
|
NOTE: unlike PyTorch, this implementation is not limited to only 2d transposed convolutions and instead works for any number of dimensions.
|
2124
2233
|
|
2125
2234
|
See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
|
@@ -2132,14 +2241,15 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
2132
2241
|
"""
|
2133
2242
|
x, w = self, weight.unflatten(0, (groups, -1)).transpose(1, 2).flip(*range(3, len(weight.shape)+1))
|
2134
2243
|
HW = weight.shape[2:]
|
2135
|
-
|
2244
|
+
padding = _flat_to_grouped(self._resolve_pool_pads(padding, len(HW)))
|
2245
|
+
stride, dilation, output_padding = [make_tuple(x, len(HW)) for x in (stride, dilation, output_padding)]
|
2136
2246
|
if any(s>1 for s in stride):
|
2137
2247
|
# handle strides: (k) -> reshape -> (k,1) -> pad -> (k,s) -> reshape -> (k*s) -> shrink (k-(s-1))
|
2138
2248
|
x = x.reshape(None, None, *flatten((k,1) for k in x.shape[2:]))
|
2139
2249
|
x = x.pad((None, None, *flatten((None,(0,s-1)) for s in stride)))
|
2140
2250
|
x = x.reshape(None, None, *[k*s for k,s in zip(x.shape[2::2], stride)])
|
2141
2251
|
x = x.shrink((None, None, *[(0,k-(s-1)) for k,s in zip(x.shape[2:], stride)]))
|
2142
|
-
padding = flatten((((k-1)*d-
|
2252
|
+
padding = flatten((((k-1)*d-pB,(k-1)*d-pA+op) for k,d,(pB,pA),op in reversed(list(zip(HW, dilation, padding, output_padding)))))
|
2143
2253
|
return x.conv2d(w.flatten(end_dim=1), groups=groups, bias=bias, dilation=dilation, padding=padding)
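The stride branch above zero-dilates the input and then runs a regular convolution with flipped weights and `(k-1)*d - p` padding, reproducing the usual transposed-convolution output size `(i-1)*s - 2*p + d*(k-1) + output_padding + 1`. A shape-only sketch; all sizes are arbitrary assumptions:

```python
from tinygrad import Tensor

x = Tensor.randn(1, 4, 8, 8)
w = Tensor.randn(4, 3, 3, 3)  # (cin, cout, kH, kW), the PyTorch ConvTranspose2d layout
y = x.conv_transpose2d(w, stride=2, padding=1, output_padding=1)
print(y.shape)  # (1, 3, 16, 16): (8-1)*2 - 2*1 + 1*(3-1) + 1 + 1 = 16
```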
|
2144
2254
|
|
2145
2255
|
def dot(self, w:Tensor, acc_dtype:Optional[DTypeLike]=None) -> Tensor:
|
@@ -2185,15 +2295,28 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
2185
2295
|
"""
|
2186
2296
|
return x.dot(self, acc_dtype=acc_dtype) if reverse else self.dot(x, acc_dtype=acc_dtype)
|
2187
2297
|
|
2188
|
-
def
|
2189
|
-
assert self.shape[axis] != 0
|
2190
|
-
pl_sz = self.shape[axis] - int(not
|
2191
|
-
|
2298
|
+
def _cumalu(self, axis:int, op:Ops, _include_initial=False) -> Tensor:
|
2299
|
+
assert self.shape[axis] != 0 and op in (Ops.ADD, Ops.MAX)
|
2300
|
+
pl_sz = self.shape[axis] - int(not _include_initial)
|
2301
|
+
pooled = self.transpose(axis,-1).pad((pl_sz, -int(_include_initial)), value=identity_element(op, self.dtype))._pool((self.shape[axis],))
|
2302
|
+
return (pooled.sum(-1) if op is Ops.ADD else pooled.max(-1)).transpose(axis,-1)
|
2303
|
+
|
2304
|
+
def _split_cumalu(self, axis:int, op:Ops) -> Tensor:
|
2305
|
+
axis = self._resolve_dim(axis)
|
2306
|
+
if self.ndim == 0 or 0 in self.shape: return self
|
2307
|
+
# TODO: someday the optimizer will find this on its own
|
2308
|
+
# for now this is a two stage cumsum
|
2309
|
+
SPLIT = 256
|
2310
|
+
if not isinstance(s:=self.shape[axis], int) or s <= SPLIT*2: return self._cumalu(axis, op)
|
2311
|
+
ret = self.transpose(axis,-1).pad((round_up(s, SPLIT)-s, 0), value=identity_element(op, self.dtype)).unflatten(-1, (-1, SPLIT))._cumalu(-1, op)
|
2312
|
+
base = ret[..., -1]._cumalu(-1, op, _include_initial=True)
|
2313
|
+
base = base.unsqueeze(-1).expand(*base.shape, ret.shape[-1])
|
2314
|
+
def fix(x:Tensor): return x.flatten(start_dim=-2)[..., -s:].transpose(axis,-1)
|
2315
|
+
return fix(ret) + fix(base) if op is Ops.ADD else fix(ret).maximum(fix(base))
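The split path scans each 256-element block, then scans the per-block carries and folds them back in, so the result matches a plain one-pass scan. A quick equivalence check against NumPy; the length 1000 is just an assumed size above the `SPLIT*2` threshold:

```python
import numpy as np
from tinygrad import Tensor

t = Tensor.arange(1000).float()  # long enough to take the two-stage path
print(np.allclose(t.cumsum(0).numpy(), np.cumsum(t.numpy())))  # True
print(t.cummax(0).numpy()[-1])                                 # 999.0
```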
|
2316
|
+
|
2192
2317
|
def cumsum(self, axis:int=0) -> Tensor:
|
2193
2318
|
"""
|
2194
|
-
Computes the cumulative sum of the tensor along the specified axis
|
2195
|
-
|
2196
|
-
You can pass in the `axis` keyword argument to control the axis along which the cumulative sum is computed.
|
2319
|
+
Computes the cumulative sum of the tensor along the specified `axis`.
|
2197
2320
|
|
2198
2321
|
```python exec="true" source="above" session="tensor" result="python"
|
2199
2322
|
t = Tensor.ones(2, 3)
|
@@ -2203,17 +2326,21 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
2203
2326
|
print(t.cumsum(1).numpy())
|
2204
2327
|
```
|
2205
2328
|
"""
|
2206
|
-
|
2207
|
-
|
2208
|
-
|
2209
|
-
|
2210
|
-
|
2211
|
-
|
2212
|
-
|
2213
|
-
|
2214
|
-
|
2215
|
-
|
2216
|
-
|
2329
|
+
return self._split_cumalu(axis, Ops.ADD)
|
2330
|
+
|
2331
|
+
def cummax(self, axis:int=0) -> Tensor:
|
2332
|
+
"""
|
2333
|
+
Computes the cumulative max of the tensor along the specified `axis`.
|
2334
|
+
|
2335
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2336
|
+
t = Tensor([0, 1, -1, 2, -2, 3, -3])
|
2337
|
+
print(t.numpy())
|
2338
|
+
```
|
2339
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2340
|
+
print(t.cummax(0).numpy())
|
2341
|
+
```
|
2342
|
+
"""
|
2343
|
+
return self._split_cumalu(axis, Ops.MAX)
|
2217
2344
|
|
2218
2345
|
@staticmethod
|
2219
2346
|
def _tri(r:sint, c:sint, diagonal:int=0, **kwargs) -> Tensor:
|
@@ -2271,7 +2398,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
2271
2398
|
"""
|
2272
2399
|
return Tensor._tri(self.shape[-2], self.shape[-1], diagonal=diagonal+1, device=self.device, dtype=dtypes.bool).where(0, self).cast(self.dtype)
|
2273
2400
|
|
2274
|
-
def interpolate(self, size:
|
2401
|
+
def interpolate(self, size:tuple[int, ...], mode:str="linear", align_corners:bool=False) -> Tensor:
|
2275
2402
|
"""
|
2276
2403
|
Downsamples or Upsamples to the input `size`, accepts 0 to N batch dimensions.
|
2277
2404
|
|
@@ -2303,6 +2430,47 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
2303
2430
|
x = x.gather(i, index)
|
2304
2431
|
return x.cast(self.dtype)
|
2305
2432
|
|
2433
|
+
def scatter(self, dim:int, index:Tensor, src:Union[Tensor, ConstType], reduce:Union[None, Literal['multiply'], Literal['add']]=None) -> Tensor:
|
2434
|
+
"""
|
2435
|
+
Scatters `src` values along an axis specified by `dim`.
|
2436
|
+
Apply `add` or `multiply` reduction operation with `reduce`.
|
2437
|
+
|
2438
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2439
|
+
src = Tensor.arange(1, 11).reshape(2, 5)
|
2440
|
+
print(src.numpy())
|
2441
|
+
```
|
2442
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2443
|
+
index = Tensor([[0, 1, 2, 0]])
|
2444
|
+
print(Tensor.zeros(3, 5, dtype=src.dtype).scatter(0, index, src).numpy())
|
2445
|
+
```
|
2446
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2447
|
+
index = Tensor([[0, 1, 2], [0, 1, 4]])
|
2448
|
+
print(Tensor.zeros(3, 5, dtype=src.dtype).scatter(1, index, src).numpy())
|
2449
|
+
```
|
2450
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2451
|
+
print(Tensor.full((2, 4), 2.0).scatter(1, Tensor([[2], [3]]), 1.23, reduce='multiply').numpy())
|
2452
|
+
```
|
2453
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2454
|
+
print(Tensor.full((2, 4), 2.0).scatter(1, Tensor([[2], [3]]), 1.23, reduce='add').numpy())
|
2455
|
+
```
|
2456
|
+
"""
|
2457
|
+
if reduce not in {None, "add", "multiply"}: raise TypeError(f"{reduce=} must be one of None, 'multiply', or 'add'")
|
2458
|
+
index, dim = index.to(self.device), self._resolve_dim(dim)
|
2459
|
+
src = src.cast(self.dtype) if isinstance(src, Tensor) else Tensor(src, device=self.device, dtype=self.dtype)._broadcast_to(index.shape)
|
2460
|
+
assert index.ndim == self.ndim == src.ndim, f"self.ndim, index.ndim and src.dim must all equal, {self.ndim=} {index.ndim=} {src.ndim=}"
|
2461
|
+
assert all((d == dim or self_ >= index_) and src_ >= index_ for d,(self_,index_,src_) in enumerate(zip(self.shape, index.shape, src.shape))), \
|
2462
|
+
f"All dimensions of {index.shape=} should be <= to all dimensions of {src.shape=} and all dimensions except dimension {dim} of {self.shape=}"
|
2463
|
+
# shrink src to index shape to shrink away the unused values
|
2464
|
+
src = src.shrink(tuple((0,s) for s in index.shape))
|
2465
|
+
# prepare src and mask for reduce with respect to dim
|
2466
|
+
src = src.unsqueeze(-1).expand(*src.shape, self.shape[dim]).transpose(-1, dim)
|
2467
|
+
mask = index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim]).transpose(-1, dim)
|
2468
|
+
# pad src and mask to self.shape so that reduce can be done with padded values as no-ops
|
2469
|
+
src, mask = (x.pad(tuple((0, self.shape[i] - x.shape[i]) if i != dim else None for i in range(self.ndim)) + (None,)) for x in (src, mask))
|
2470
|
+
if reduce == "add": return mask.where(src, 0).sum(-1, acc_dtype=self.dtype) + self
|
2471
|
+
if reduce == "multiply": return mask.where(src, 1).prod(-1, acc_dtype=self.dtype) * self
|
2472
|
+
return _masked_setitem(self, src, mask, (-1,))
|
2473
|
+
|
2306
2474
|
# ***** unary ops *****
|
2307
2475
|
|
2308
2476
|
def logical_not(self):
|
@@ -2313,7 +2481,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
2313
2481
|
print(Tensor([False, True]).logical_not().numpy())
|
2314
2482
|
```
|
2315
2483
|
"""
|
2316
|
-
return
|
2484
|
+
return self.cast(dtypes.bool)._apply_broadcasted_uop(UOp.ne, True)
|
2317
2485
|
def neg(self):
|
2318
2486
|
"""
|
2319
2487
|
Negates the tensor element-wise.
|
@@ -2327,12 +2495,12 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
2327
2495
|
"""
|
2328
2496
|
Returns a contiguous tensor.
|
2329
2497
|
"""
|
2330
|
-
return
|
2498
|
+
return self._apply_uop(UOp.contiguous)
|
2331
2499
|
def contiguous_backward(self):
|
2332
2500
|
"""
|
2333
2501
|
Inserts a contiguous operation in the backward pass.
|
2334
2502
|
"""
|
2335
|
-
return
|
2503
|
+
return self._apply_uop(UOp.contiguous_backward)
|
2336
2504
|
def log(self):
|
2337
2505
|
"""
|
2338
2506
|
Computes the natural logarithm element-wise.
|
@@ -2343,7 +2511,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
2343
2511
|
print(Tensor([1., 2., 4., 8.]).log().numpy())
|
2344
2512
|
```
|
2345
2513
|
"""
|
2346
|
-
return
|
2514
|
+
return self.log2()*math.log(2)
|
2347
2515
|
def log2(self):
|
2348
2516
|
"""
|
2349
2517
|
Computes the base-2 logarithm element-wise.
|
@@ -2354,7 +2522,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
2354
2522
|
print(Tensor([1., 2., 4., 8.]).log2().numpy())
|
2355
2523
|
```
|
2356
2524
|
"""
|
2357
|
-
return self.
|
2525
|
+
return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.log2)
|
2358
2526
|
def exp(self):
|
2359
2527
|
"""
|
2360
2528
|
Computes the exponential function element-wise.
|
@@ -2365,7 +2533,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
2365
2533
|
print(Tensor([0., 1., 2., 3.]).exp().numpy())
|
2366
2534
|
```
|
2367
2535
|
"""
|
2368
|
-
return
|
2536
|
+
return self.mul(1/math.log(2)).exp2()
|
2369
2537
|
def exp2(self):
|
2370
2538
|
"""
|
2371
2539
|
Computes the base-2 exponential function element-wise.
|
@@ -2376,7 +2544,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
2376
2544
|
print(Tensor([0., 1., 2., 3.]).exp2().numpy())
|
2377
2545
|
```
|
2378
2546
|
"""
|
2379
|
-
return
|
2547
|
+
return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.exp2)
|
2380
2548
|
def relu(self):
|
2381
2549
|
"""
|
2382
2550
|
Applies the Rectified Linear Unit (ReLU) function element-wise.
|
@@ -2387,7 +2555,8 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
2387
2555
|
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).relu().numpy())
|
2388
2556
|
```
|
2389
2557
|
"""
|
2390
|
-
return
|
2558
|
+
return (self>0).where(self, 0)
|
2559
|
+
|
2391
2560
|
def sigmoid(self):
|
2392
2561
|
"""
|
2393
2562
|
Applies the Sigmoid function element-wise.
|
@@ -2398,7 +2567,8 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
2398
2567
|
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sigmoid().numpy())
|
2399
2568
|
```
|
2400
2569
|
"""
|
2401
|
-
return
|
2570
|
+
return (1 + (self * (-1/math.log(2))).exp2()).reciprocal()
|
2571
|
+
|
2402
2572
|
def hardsigmoid(self, alpha:float=1/6, beta:float=0.5):
|
2403
2573
|
"""
|
2404
2574
|
Applies the Hardsigmoid function element-wise.
|
@@ -2421,7 +2591,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
2421
2591
|
print(Tensor([1., 2., 3., 4.]).sqrt().numpy())
|
2422
2592
|
```
|
2423
2593
|
"""
|
2424
|
-
return
|
2594
|
+
return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.sqrt)
|
2425
2595
|
def rsqrt(self):
|
2426
2596
|
"""
|
2427
2597
|
Computes the reciprocal of the square root of the tensor element-wise.
|
@@ -2439,7 +2609,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
2439
2609
|
print(Tensor([0., math.pi/2, math.pi, 3*math.pi/2, 2*math.pi]).sin().numpy())
|
2440
2610
|
```
|
2441
2611
|
"""
|
2442
|
-
return
|
2612
|
+
return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.sin)
|
2443
2613
|
def cos(self):
|
2444
2614
|
"""
|
2445
2615
|
Computes the cosine of the tensor element-wise.
|
@@ -2459,6 +2629,39 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
2459
2629
|
"""
|
2460
2630
|
return self.sin() / self.cos()
|
2461
2631
|
|
2632
|
+
def asin(self):
|
2633
|
+
"""
|
2634
|
+
Computes the inverse sine (arcsine) of the tensor element-wise.
|
2635
|
+
|
2636
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2637
|
+
print(Tensor([-0.9, -0.6, -0.3, 0., 0.3, 0.6, 0.9]).asin().numpy())
|
2638
|
+
```
|
2639
|
+
"""
|
2640
|
+
# https://personal.math.ubc.ca/~cbm/aands/page_81.htm 4.4.46
|
2641
|
+
coefficients = [-0.0012624911, 0.0066700901, -0.0170881256, 0.0308918810, -0.0501743046, 0.0889789874, -0.2145988016, 1.5707963050]
|
2642
|
+
x = math.pi / 2 - (1.0 - self.abs()).sqrt() * polyN(self.abs(), coefficients)
|
2643
|
+
return self.sign() * x
|
2644
|
+
|
2645
|
+
def acos(self):
|
2646
|
+
"""
|
2647
|
+
Computes the inverse cosine (arccosine) of the tensor element-wise.
|
2648
|
+
|
2649
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2650
|
+
print(Tensor([-0.9, -0.6, -0.3, 0., 0.3, 0.6, 0.9]).acos().numpy())
|
2651
|
+
```
|
2652
|
+
"""
|
2653
|
+
return math.pi / 2 - self.asin()
|
2654
|
+
|
2655
|
+
def atan(self):
|
2656
|
+
"""
|
2657
|
+
Computes the inverse tangent (arctan) of the tensor element-wise.
|
2658
|
+
|
2659
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2660
|
+
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).atan().numpy())
|
2661
|
+
```
|
2662
|
+
"""
|
2663
|
+
return (self / (1 + self * self).sqrt()).asin()
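`asin` uses the Abramowitz & Stegun 4.4.46 polynomial referenced above, and `acos`/`atan` are derived from it via `acos(x) = pi/2 - asin(x)` and `atan(x) = asin(x / sqrt(1 + x^2))`. A spot-check against Python's `math` module; the sample points are arbitrary:

```python
import math
from tinygrad import Tensor

xs = [-0.9, -0.25, 0.0, 0.5, 0.99]
approx = Tensor(xs).asin().numpy()
print(max(abs(float(a) - math.asin(v)) for a, v in zip(approx, xs)))  # small, well below 1e-5 in float32
```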
|
2664
|
+
|
2462
2665
|
# ***** math functions *****
|
2463
2666
|
|
2464
2667
|
def trunc(self: Tensor) -> Tensor:
|
@@ -2565,7 +2768,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
2565
2768
|
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sign().numpy())
|
2566
2769
|
```
|
2567
2770
|
"""
|
2568
|
-
return
|
2771
|
+
return self.ne(0).where((self<0).where(self.full_like(-1), self.full_like(1)), self.full_like(0)) + self*0
|
2569
2772
|
def abs(self):
|
2570
2773
|
"""
|
2571
2774
|
Computes the absolute value of the tensor element-wise.
|
@@ -2583,7 +2786,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
2583
2786
|
print(Tensor([1., 2., 3., 4.]).reciprocal().numpy())
|
2584
2787
|
```
|
2585
2788
|
"""
|
2586
|
-
return
|
2789
|
+
return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.reciprocal)
|
2587
2790
|
|
2588
2791
|
# ***** activation functions *****
|
2589
2792
|
|
@@ -2613,6 +2816,19 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
2613
2816
|
"""
|
2614
2817
|
return self.maximum(0) + (alpha * ((self / alpha).exp() - 1)).minimum(0)
|
2615
2818
|
|
2819
|
+
def selu(self, alpha=1.67326, gamma=1.0507):
|
2820
|
+
"""
|
2821
|
+
Applies the Scaled Exponential Linear Unit (SELU) function element-wise.
|
2822
|
+
|
2823
|
+
- Described: https://paperswithcode.com/method/selu
|
2824
|
+
- Paper: https://arxiv.org/abs/1706.02515v5
|
2825
|
+
|
2826
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2827
|
+
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).selu().numpy())
|
2828
|
+
```
|
2829
|
+
"""
|
2830
|
+
return gamma * (self >= 0).detach().where(self, alpha * (self.exp() - 1))
|
2831
|
+
|
2616
2832
|
def swish(self):
|
2617
2833
|
"""
|
2618
2834
|
See `.silu()`
|
@@ -2840,17 +3056,17 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
2840
3056
|
return self / (1 + self.abs())
|
2841
3057
|
|
2842
3058
|
# ***** broadcasted elementwise ops *****
|
2843
|
-
def _broadcast_to(self,
|
2844
|
-
if self.shape ==
|
2845
|
-
if self.ndim > len(
|
2846
|
-
# first
|
2847
|
-
|
2848
|
-
# for each dimension, check either
|
2849
|
-
if
|
2850
|
-
raise ValueError(f"cannot broadcast
|
2851
|
-
return
|
2852
|
-
|
2853
|
-
def _broadcasted(self, y:Union[Tensor, UOp, ConstType], reverse:bool=False, match_dtype:bool=True) ->
|
3059
|
+
def _broadcast_to(self, new_shape:tuple[sint, ...]) -> Tensor:
|
3060
|
+
if self.shape == new_shape: return self
|
3061
|
+
if self.ndim > len(new_shape): raise ValueError(f"cannot broadcast tensor to fewer dimensions. shape={self.shape} to {new_shape=}")
|
3062
|
+
# first unsqueeze left with 1s https://data-apis.org/array-api/latest/API_specification/broadcasting.html
|
3063
|
+
shape, _ = _align_left(self.shape, new_shape)
|
3064
|
+
# for each dimension, check either dim is 1, or it does not change
|
3065
|
+
if not all(resolve(s == ns) or resolve(s == 1) for s,ns in zip(shape, new_shape)):
|
3066
|
+
raise ValueError(f"cannot broadcast {self.shape} to {new_shape=}")
|
3067
|
+
return self.reshape(shape)._apply_uop(UOp.expand, arg=new_shape)
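Broadcasting left-pads the smaller shape with 1s and then expands the size-1 axes, following the Array API rule linked above. A small sketch:

```python
from tinygrad import Tensor

a = Tensor.ones(3, 1, 5)
b = Tensor.ones(4, 5)      # treated as (1, 4, 5) after left-alignment
print((a + b).shape)       # (3, 4, 5)
```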
|
3068
|
+
|
3069
|
+
def _broadcasted(self, y:Union[Tensor, UOp, ConstType], reverse:bool=False, match_dtype:bool=True) -> tuple[Tensor, Tensor]:
|
2854
3070
|
x: Tensor = self
|
2855
3071
|
if not isinstance(y, Tensor):
|
2856
3072
|
# make y a Tensor
|
@@ -2867,12 +3083,12 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
     if reverse: x, y = y, x
 
     # broadcast
-    out_shape
-    return x._broadcast_to(out_shape), y._broadcast_to(out_shape)
+    return x._broadcast_to(out_shape:=_broadcast_shape(x.shape, y.shape)), y._broadcast_to(out_shape)
 
+  # TODO: tensor should stop checking if things are const
   def _to_const_val(self, x:Union[Tensor, ConstType]) -> Union[Tensor, ConstType]:
-    return x.lazydata.
-      and not x.requires_grad and self._broadcasted(x)[0].shape == self.shape else x
+    return x.lazydata.const_arg if isinstance(x, Tensor) and isinstance(x.lazydata, UOp) and x.lazydata.base.op is Ops.CONST \
+      and unwrap(x.lazydata.st).views[0].mask is None and not x.requires_grad and self._broadcasted(x)[0].shape == self.shape else x
 
   def add(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
     """
@@ -2892,7 +3108,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
     print(t.add(Tensor([[2.0], [3.5]])).numpy())
     ```
     """
-    return
+    return self._apply_broadcasted_uop(UOp.add, x, reverse)
 
   def sub(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
     """
@@ -2933,20 +3149,20 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
     print(t.mul(Tensor([[-1.0], [2.0]])).numpy())
     ```
     """
-    return
+    return self._apply_broadcasted_uop(UOp.mul, x, reverse)
 
   def idiv(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
     """
     Divides `self` by `x`.
     Equivalent to `self // x`.
     Supports broadcasting to a common shape, type promotion, and integer inputs.
-    `idiv` performs integer division.
+    `idiv` performs integer division (truncate towards zero).
 
     ```python exec="true" source="above" session="tensor" result="python"
-    print(Tensor([
+    print(Tensor([-4, 7, 5, 4, -7, 8]).idiv(Tensor([2, -3, 8, -2, 3, 5])).numpy())
     ```
     """
-    return
+    return self._apply_broadcasted_uop(UOp.idiv, x, reverse)
 
   def div(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
     """
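As the updated docstring says, `idiv` truncates towards zero, which differs from Python's `//` (floor division) whenever the operands have opposite signs. A quick plain-Python illustration using the same values as the docstring example:

```python
import math

pairs = [(-4, 2), (7, -3), (5, 8), (4, -2), (-7, 3), (8, 5)]
trunc_div = [math.trunc(a / b) for a, b in pairs]  # truncate towards zero, like idiv
floor_div = [a // b for a, b in pairs]             # Python's floor division
print(trunc_div)  # [-2, -2, 0, -2, -2, 1]
print(floor_div)  # [-2, -3, 0, -2, -3, 1]
```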
@@ -2970,6 +3186,19 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
     numerator, denominator = self._broadcasted(x, reverse)
     return numerator.cast(least_upper_float(numerator.dtype)) * denominator.cast(least_upper_float(denominator.dtype)).reciprocal()
 
+  def mod(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
+    """
+    Mod `self` by `x`.
+    Equivalent to `self % x`.
+    Supports broadcasting to a common shape, type promotion, and integer inputs.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor([-4, 7, 5, 4, -7, 8]).mod(Tensor([2, -3, 8, -2, 3, 5])).numpy())
+    ```
+    """
+    a, b = self._broadcasted(x, reverse)
+    return (r := a._apply_uop(UOp.mod, b)) + b * (((r < 0) & (b > 0)) | ((r > 0) & (b < 0)))
+
   def xor(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
     """
     Computes bitwise xor of `self` and `x`.
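The new `mod` turns a truncating remainder (the backend `UOp.mod`) into Python-style floored modulo: when the remainder and the divisor have opposite signs, one extra divisor is added back. A standalone sketch of that correction, using `math.fmod` as the truncating remainder:

```python
import math

def floored_mod(a: int, b: int) -> int:
  r = int(math.fmod(a, b))                            # truncating remainder, sign follows the dividend
  return r + b if (r < 0 < b) or (b < 0 < r) else r   # shift to the divisor's sign when they disagree

print([floored_mod(a, b) for a, b in [(-4, 2), (7, -3), (5, 8), (4, -2), (-7, 3), (8, 5)]])
# [0, -2, 5, 0, 2, 3], the same as Python's % operator
```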
@@ -2984,7 +3213,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
     ```
     """
     if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
-    return
+    return self._apply_broadcasted_uop(UOp.xor, x, reverse)
 
   def bitwise_and(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
     """
@@ -2999,7 +3228,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
     ```
     """
     if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
-    return
+    return self._apply_broadcasted_uop(UOp.bitwise_and, x, reverse)
 
   def bitwise_or(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
     """
@@ -3014,7 +3243,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
     ```
     """
     if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
-    return
+    return self._apply_broadcasted_uop(UOp.bitwise_or, x, reverse)
 
   def bitwise_not(self) -> Tensor:
     """
@@ -3028,7 +3257,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
     ```
     """
     if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
-    return self.logical_not() if self.dtype == dtypes.bool else self ^
+    return self.logical_not() if self.dtype == dtypes.bool else self ^ -1
 
   def lshift(self, x:int):
     """
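The `bitwise_not` change relies on the two's-complement identity `~x == x ^ -1` (xor against an all-ones word), which holds for any signed integer width. Quick check in plain Python:

```python
xs = [0, 1, -1, 41, -42, 2**31 - 1]
print([x ^ -1 for x in xs])  # xor with all-ones, as in the diff
print([~x for x in xs])      # identical to bitwise not
```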
@@ -3072,8 +3301,9 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
     x = self._to_const_val(x)
     if not isinstance(x, Tensor) and not reverse:
       # simple pow identities
-      if x < 0: return self.reciprocal().pow(-x)
+      if x < 0: return self.reciprocal().pow(-x).cast(self.dtype)
       if x == 0: return 1 + self * 0
+      # rewrite pow 0.5 to sqrt
       if int(x - 0.5) + 0.5 == x: return self.pow(int(x - 0.5)) * self.sqrt()
       if int(x) == x: return self.pow(x // 2).square() * (1 if x % 2 == 0 else self)
 
@@ -3081,16 +3311,15 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
     if not isinstance(x, Tensor) and reverse and x > 0: return self.mul(math.log(x)).exp()
 
     base, exponent = self._broadcasted(x, reverse=reverse)
+    # TODO: int pow
+    if not base.is_floating_point(): raise RuntimeError("base needs to be float")
     # start with b ** e = exp(e * log(b))
     ret = base.abs().log().mul(exponent).exp()
-    #
-
-    #
-
-
-    inject_nan = (negative_base * (exponent != exponent.trunc())).detach().where(math.nan, 1)
-    # apply correct_sign inject_nan, and fix 0 ** 0 = 1
-    return ((base == 0) * (exponent == 0)).detach().where(1, ret * correct_sign * inject_nan)
+    # negative base adjustment: nan for non-integer exponent and -1 for odd exponent
+    adj = (base < 0).detach().where((exponent != exponent.int()).detach().where(math.nan, (exponent.int()%2==1).where(-1, 1)), 1)
+    # fix 0 ** 0 = 1
+    ret = ((base == 0) * (exponent == 0)).detach().where(1, ret * adj)
+    return ret.round().cast(self.dtype) if not dtypes.is_float(self.dtype) else ret
 
   def maximum(self, x:Union[Tensor, ConstType]) -> Tensor:
     """
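The rewritten negative-base handling computes `|b| ** e` via exp/log and then patches it: NaN when the exponent is not an integer, a sign flip when it is an odd integer, and `0 ** 0` forced to 1. A scalar sketch of that rule (illustrative, not the Tensor code path):

```python
import math

def pow_with_neg_base(b: float, e: float) -> float:
  if b == 0 and e == 0: return 1.0
  ret = math.exp(e * math.log(abs(b))) if b != 0 else 0.0  # |b| ** e
  if b < 0:
    if e != int(e): return math.nan    # negative base, non-integer exponent
    if int(e) % 2 == 1: ret = -ret     # odd integer exponent flips the sign
  return ret

print([round(pow_with_neg_base(-2.0, 3.0), 6), round(pow_with_neg_base(-2.0, 2.0), 6), pow_with_neg_base(-2.0, 0.5)])
# [-8.0, 4.0, nan]
```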
@@ -3103,7 +3332,9 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
     print(Tensor([-1, 2, 3]).maximum(Tensor([-4, -2, 9])).numpy())
     ```
     """
-
+    # NOTE: the mid-point is for backward, revisit after new gradient API
+    if self.is_floating_point(): return (self<x).detach().where(x, (self==x).detach().where(((self * 0.5 + x * 0.5).cast(self.dtype)), self))
+    return (self<x).detach().where(x, self)
 
   def minimum(self, x:Union[Tensor, ConstType]) -> Tensor:
     """
@@ -3116,9 +3347,10 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
     print(Tensor([-1, 2, 3]).minimum(Tensor([-4, -2, 9])).numpy())
     ```
     """
-
+    t, x = self._broadcasted(x)
+    return t._inverse().maximum(x._inverse())._inverse()
 
-  def where(self:Tensor, x:Union[Tensor, ConstType], y:Union[Tensor, ConstType]):
+  def where(self:Tensor, x:Union[Tensor, ConstType, sint], y:Union[Tensor, ConstType, sint]):
     """
     Return a tensor of elements selected from either `x` or `y`, depending on `self`.
     `output_i = x_i if self_i else y_i`.
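`minimum` is now routed through `maximum` by applying an involution (`_inverse`) to both operands and to the result, which for ordinary numeric values amounts to the classic identity `min(a, b) == -max(-a, -b)`. In plain Python:

```python
a, b = [-1, 2, 3], [-4, -2, 9]
print([-max(-x, -y) for x, y in zip(a, b)])  # min via negated max -> [-4, -2, 3]
print([min(x, y) for x, y in zip(a, b)])     # same result
```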
@@ -3140,7 +3372,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
     elif isinstance(y, Tensor): y, x = y._broadcasted(x)
     cond, x = self._broadcasted(x, match_dtype=False)
     cond, y = cond._broadcasted(y, match_dtype=False)
-    return
+    return cond.cast(dtypes.bool)._apply_uop(UOp.where, *x._broadcasted(y))
 
   def masked_fill(self:Tensor, mask:Tensor, value:Union[Tensor, ConstType]): return mask.where(value, self)
 
@@ -3170,9 +3402,9 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
   def __ilshift__(self, x) -> Tensor: return self.assign(self.lshift(x))
   def __irshift__(self, x) -> Tensor: return self.assign(self.rshift(x))
 
-  def
-  def
-  def ne(self, x) -> Tensor: return
+  def __lt__(self, x) -> Tensor: return self._apply_broadcasted_uop(UOp.__lt__, x, False)
+  def __gt__(self, x) -> Tensor: return self._apply_broadcasted_uop(UOp.__lt__, x, True)
+  def ne(self, x) -> Tensor: return self._apply_broadcasted_uop(UOp.ne, x, False)
 
   def __eq__(self, x) -> Tensor: return self.eq(x) # type: ignore[override]
 
@@ -3194,7 +3426,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
     x = self.mul(weight) if len(weight.shape) == 1 else self.dot(weight)
     return x.add(bias) if bias is not None else x
 
-  def sequential(self, ll:
+  def sequential(self, ll:list[Callable[[Tensor], Tensor]]):
     """
     Applies a sequence of functions to `self` chaining the output of each function to the input of the next.
 
@@ -3205,7 +3437,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
     """
     return functools.reduce(lambda x,f: f(x), ll, self)
 
-  def layernorm(self, axis:Union[int,
+  def layernorm(self, axis:Union[int,tuple[int,...]]=-1, eps:float=1e-5) -> Tensor:
     """
     Applies Layer Normalization over a mini-batch of inputs.
 
@@ -3224,7 +3456,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
     y = (self - self.mean(axis, keepdim=True))
     return y.mul((y*y).mean(axis, keepdim=True).add(eps).rsqrt())
 
-  def batchnorm(self, weight:Optional[Tensor], bias:Optional[Tensor], mean:Tensor, invstd:Tensor, axis:Union[int,
+  def batchnorm(self, weight:Optional[Tensor], bias:Optional[Tensor], mean:Tensor, invstd:Tensor, axis:Union[int,tuple[int,...]]=1) -> Tensor:
     """
     Applies Batch Normalization over a mini-batch of inputs.
 
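The layernorm context lines above compute `(x - mean) * rsqrt(mean((x - mean)**2) + eps)` over the chosen axis, with no affine parameters. A list-based sketch of the same computation over a single axis:

```python
import math

def layernorm_1d(xs: list[float], eps: float = 1e-5) -> list[float]:
  mean = sum(xs) / len(xs)
  y = [x - mean for x in xs]               # center
  var = sum(v * v for v in y) / len(y)     # biased variance, i.e. mean((x - mean)**2)
  inv_std = 1.0 / math.sqrt(var + eps)     # rsqrt(var + eps)
  return [v * inv_std for v in y]

print([round(v, 4) for v in layernorm_1d([1.0, 2.0, 3.0, 4.0])])  # ~[-1.3416, -0.4472, 0.4472, 1.3416]
```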
@@ -3266,6 +3498,11 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
     if not Tensor.training or p == 0: return self
     return (Tensor.rand_like(self, requires_grad=False, dtype=dtypes.default_float, contiguous=False) >= p).contiguous().where(self, 0) / (1.0 - p)
 
+  # helper function commonly used for indexing
+  def _one_hot_along_dim(self:Tensor, num_classes:sint, dim:int=-1):
+    offset = self.ndim - self._resolve_dim(dim) - 1
+    return self == Tensor.arange(num_classes, device=self.device, requires_grad=False).reshape((num_classes,) + (1,) * offset)
+
   def one_hot(self, num_classes:int=-1) -> Tensor:
     """
     Converts `self` to a one-hot tensor.
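The new `_one_hot_along_dim` helper builds one-hot vectors with nothing but broadcasting: compare the (unsqueezed) class indices against `arange(num_classes)`. The same trick spelled out in plain Python:

```python
labels, num_classes = [2, 0, 1], 3
one_hot = [[1 if c == y else 0 for c in range(num_classes)] for y in labels]  # y == arange(num_classes)
print(one_hot)  # [[0, 0, 1], [1, 0, 0], [0, 1, 0]]
```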
@@ -3278,10 +3515,9 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
     ```
     """
     if num_classes == -1: num_classes = (self.max()+1).item()
-    return
+    return self[..., None]._one_hot_along_dim(num_classes).where(1, 0)
 
-  def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:
-                                   dropout_p:float=0.0, is_causal:bool=False) -> Tensor:
+  def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Tensor|None=None, dropout_p:float=0.0, is_causal:bool=False) -> Tensor:
     """
     Computes scaled dot-product attention.
     `self` is the query tensor, `key` is the key tensor, and `value` is the value tensor.
@@ -3298,14 +3534,19 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
     """
     # NOTE: it also works when `key` and `value` have symbolic shape.
     assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
-    if is_causal: attn_mask = Tensor.ones(self.shape[-2], key.shape[-2], requires_grad=False, device=self.device).tril(0).cast(dtypes.bool)
-    if attn_mask is not None and attn_mask.dtype == dtypes.bool: attn_mask = (attn_mask == 0).where(-float("inf"), 0)
     qk = self.matmul(key.transpose(-2,-1), acc_dtype=least_upper_dtype(self.dtype, key.dtype, dtypes.float32)) / math.sqrt(self.shape[-1])
-
+    # handle attention mask
+    if is_causal:
+      if attn_mask is not None: raise RuntimeError("cannot set attn_mask when is_causal=True")
+      attn_mask = qk.ones_like(requires_grad=False, device=self.device, dtype=dtypes.bool).tril()
+    if attn_mask is not None:
+      if attn_mask.dtype == dtypes.bool: attn_mask = attn_mask.where(0, -float("inf"))
+      qk = qk + attn_mask
+    return qk.softmax(-1).cast(self.dtype).dropout(dropout_p) @ value
 
   def _do_reduction(self, reduction:ReductionStr="mean") -> Tensor:
     if reduction not in get_args(ReductionStr): raise ValueError(f"{reduction=} must be one of {get_args(ReductionStr)}")
-    reductions:
+    reductions: dict[str, Callable[[Tensor], Tensor]] = {"mean": Tensor.mean, "sum": Tensor.sum, "none": lambda x: x}
     return reductions[reduction](self)
 
   def binary_crossentropy(self, Y:Tensor, reduction:ReductionStr="mean") -> Tensor:
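The reworked masking in `scaled_dot_product_attention`: a causal call builds a lower-triangular boolean mask, boolean masks are converted to additive 0 / -inf masks, the mask is added to the scaled scores, and the result is softmaxed. A minimal sketch of that masking logic on plain nested lists (single head, no batching or dropout):

```python
import math

def softmax(row: list[float]) -> list[float]:
  m = max(row)
  e = [math.exp(v - m) for v in row]
  s = sum(e)
  return [v / s for v in e]

def causal_attention_weights(scores: list[list[float]]) -> list[list[float]]:
  # scores[i][j] is q_i . k_j / sqrt(d); positions j > i get an additive -inf (the boolean tril mask)
  T = len(scores)
  masked = [[scores[i][j] if j <= i else -math.inf for j in range(T)] for i in range(T)]
  return [softmax(row) for row in masked]

for row in causal_attention_weights([[0.1, 0.2, 0.3], [0.0, 0.5, 0.1], [0.2, 0.2, 0.2]]):
  print([round(v, 3) for v in row])
```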
@@ -3354,8 +3595,8 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
     assert 0.0 <= label_smoothing <= 1.0, "label_smoothing must be in [0.0, 1.0]"
     assert reduction in ("mean", "sum", "none"), "reduction must be one of ['mean', 'sum', 'none']"
     log_probs, loss_mask = self.log_softmax(), (Y != ignore_index) if ignore_index != -1 else Y.ones_like(dtype=dtypes.bool)
-
-    y = (
+    y_counted = Y.to(self.device).flatten().reshape(-1, 1)._one_hot_along_dim(self.shape[-1])
+    y = (y_counted * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
     smoothing = label_smoothing * (log_probs.mean(-1) * loss_mask)
     unreduced = ((1 - label_smoothing) * (log_probs * y).sum(-1) + smoothing)
     # NOTE: because of ignore_index, we can't use Tensor.mean (so can't use `_do_reduction` here)
@@ -3469,7 +3710,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
     """
     return dtypes.is_float(self.dtype)
 
-  def size(self, dim:Optional[int]=None) -> Union[sint,
+  def size(self, dim:Optional[int]=None) -> Union[sint, tuple[sint, ...]]:
     """
     Return the size of the tensor. If `dim` is specified, return the length along dimension `dim`. Otherwise return the shape of the tensor.
 
@@ -3488,7 +3729,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
   def llvm_bf16_cast(self, dtype:DTypeLike):
     # hack for devices that don't support bfloat16
     assert self.dtype == dtypes.bfloat16
-    return self.to("LLVM").
+    return self.to("LLVM").cast(dtype)
 
   def cast(self, dtype:DTypeLike) -> Tensor:
     """
@@ -3502,8 +3743,15 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
     t = t.cast(dtypes.int32)
     print(t.dtype, t.numpy())
     ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = t.cast(dtypes.uint8)
+    print(t.dtype, t.numpy())
+    ```
     """
-
+    if (dt:=to_dtype(dtype)) in {dtypes.uint8, dtypes.uint16} and dtypes.is_float(self.dtype):
+      # NOTE: values within the int32 range and outside the unsigned dtype range will cause values to wrap around
+      return self._apply_uop(UOp.cast, dtype=dtypes.int32)._apply_uop(UOp.cast, dtype=dt)
+    return self if self.dtype == dt else self._apply_uop(UOp.cast, dtype=dt)
 
   def bitcast(self, dtype:DTypeLike) -> Tensor:
     """
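The new cast path for float tensors going to `uint8`/`uint16` casts through int32 first, so out-of-range values wrap modulo 2**8 (or 2**16) instead of saturating, as the NOTE in the diff says. A simplified model of that behavior in plain Python:

```python
vals = [-1.0, 255.0, 256.0, 300.7]
as_int32 = [int(v) for v in vals]        # truncate towards zero, like the float -> int32 cast
as_uint8 = [v & 0xFF for v in as_int32]  # wrap modulo 256, like the int32 -> uint8 cast
print(as_uint8)  # [255, 255, 0, 44]
```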
@@ -3522,13 +3770,13 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
     """
     if self.requires_grad: raise RuntimeError("can't backprop through bitcast")
     dt = to_dtype(dtype)
-    if (
-
+    if (ns:=dt.itemsize) != (os:=self.dtype.itemsize) and (self.shape[-1]*os) % ns != 0: raise RuntimeError("unsupported size in bitcast")
+    if (not isinstance(self.device, str) or not self.device.startswith("DISK")) and ns != os:
       new_uint, old_uint = to_dtype(f"uint{8*ns}"), to_dtype(f"uint{8*os}")
       tmp = self.bitcast(old_uint)
       if ns > os: return functools.reduce(Tensor.add, (tmp[..., i::ns//os].cast(new_uint) << 8*i*os for i in range(ns//os))).bitcast(dtype)
       return Tensor.stack(*(tmp>>8*i*ns for i in range(os//ns)), dim=-1).flatten(-2).cast(new_uint).bitcast(dtype)
-    return
+    return self._apply_uop(UOp.bitcast, dtype=dt) if self.dtype != dt else self
 
   def float(self) -> Tensor:
     """
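When the element size changes, `bitcast` reinterprets groups of narrow unsigned integers as one wider value by shifting and summing (little-endian), or splits a wide value with shifts in the other direction. A standalone sketch of the widening direction, cross-checked against `struct`:

```python
import struct

# four uint8 values reinterpreted as one uint32, matching sum(byte_i << 8*i) in the hunk above
bytes_in = [0x78, 0x56, 0x34, 0x12]
widened = sum(b << (8 * i) for i, b in enumerate(bytes_in))
print(hex(widened))                                  # 0x12345678
print(hex(struct.unpack("<I", bytes(bytes_in))[0]))  # same value via struct
```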
@@ -3650,7 +3898,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
     else: w = w.reshape(cout//4, H, rcin_hi, W, rcin_lo, 4).permute(0,1,2,3,5,4)
 
     # prepare input
-    x = x.permute(0,3,4,5,1,2).pad(self.
+    x = x.permute(0,3,4,5,1,2).pad(self._resolve_pool_pads(padding,2))._pool((H,W), stride, dilation)# -> (bs, groups, rcin_hi, rcin_lo, oy, ox, H, W)
     x = x.permute(0,4,5,1,2,3,6,7).reshape(bs, (oy := x.shape[4]), (ox := x.shape[5]), *cout_expand[0:2], 1, 1, rcin_hi, rcin_lo, H, W)
 
     # prepare weights
@@ -3702,5 +3950,5 @@ def _metadata_wrapper(fn):
 
 if TRACEMETA >= 1:
   for name, fn in inspect.getmembers(Tensor, inspect.isfunction):
-    if name in ["__class__", "__init__", "__new__", "__repr__", "backward", "sequential"]: continue
+    if name in ["__class__", "__init__", "__new__", "__repr__", "backward", "sequential", "gradient"]: continue
     setattr(Tensor, name, functools.wraps(fn)(_metadata_wrapper(fn)))