tinygrad 0.10.0__py3-none-any.whl → 0.10.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/codegen/devectorizer.py +247 -0
- tinygrad/codegen/expander.py +121 -0
- tinygrad/codegen/kernel.py +141 -201
- tinygrad/codegen/linearize.py +223 -84
- tinygrad/codegen/lowerer.py +60 -42
- tinygrad/codegen/symbolic.py +476 -0
- tinygrad/codegen/transcendental.py +22 -13
- tinygrad/device.py +187 -47
- tinygrad/dtype.py +39 -28
- tinygrad/engine/jit.py +83 -65
- tinygrad/engine/memory.py +4 -5
- tinygrad/engine/multi.py +161 -0
- tinygrad/engine/realize.py +62 -108
- tinygrad/engine/schedule.py +396 -357
- tinygrad/engine/search.py +55 -66
- tinygrad/gradient.py +73 -0
- tinygrad/helpers.py +81 -59
- tinygrad/nn/__init__.py +30 -32
- tinygrad/nn/datasets.py +1 -2
- tinygrad/nn/optim.py +22 -26
- tinygrad/nn/state.py +91 -66
- tinygrad/ops.py +492 -641
- tinygrad/renderer/__init__.py +95 -36
- tinygrad/renderer/cstyle.py +99 -92
- tinygrad/renderer/llvmir.py +83 -34
- tinygrad/renderer/ptx.py +83 -99
- tinygrad/renderer/wgsl.py +95 -0
- tinygrad/runtime/autogen/amd_gpu.py +39507 -12
- tinygrad/runtime/autogen/comgr.py +2 -0
- tinygrad/runtime/autogen/kfd.py +4 -3
- tinygrad/runtime/autogen/kgsl.py +1 -1
- tinygrad/runtime/autogen/libc.py +404 -71
- tinygrad/runtime/autogen/llvm.py +11379 -0
- tinygrad/runtime/autogen/pci.py +1333 -0
- tinygrad/runtime/autogen/vfio.py +891 -0
- tinygrad/runtime/autogen/webgpu.py +6985 -0
- tinygrad/runtime/graph/cuda.py +8 -9
- tinygrad/runtime/graph/hcq.py +84 -79
- tinygrad/runtime/graph/metal.py +40 -43
- tinygrad/runtime/ops_amd.py +498 -334
- tinygrad/runtime/ops_cloud.py +34 -34
- tinygrad/runtime/ops_cpu.py +24 -0
- tinygrad/runtime/ops_cuda.py +30 -27
- tinygrad/runtime/ops_disk.py +62 -63
- tinygrad/runtime/ops_dsp.py +159 -42
- tinygrad/runtime/ops_gpu.py +30 -30
- tinygrad/runtime/ops_hip.py +29 -31
- tinygrad/runtime/ops_llvm.py +48 -41
- tinygrad/runtime/ops_metal.py +149 -113
- tinygrad/runtime/ops_npy.py +2 -2
- tinygrad/runtime/ops_nv.py +238 -273
- tinygrad/runtime/ops_python.py +55 -50
- tinygrad/runtime/ops_qcom.py +129 -157
- tinygrad/runtime/ops_webgpu.py +225 -0
- tinygrad/runtime/support/allocator.py +94 -0
- tinygrad/runtime/support/am/__init__.py +0 -0
- tinygrad/runtime/support/am/amdev.py +396 -0
- tinygrad/runtime/support/am/ip.py +463 -0
- tinygrad/runtime/support/compiler_cuda.py +4 -2
- tinygrad/runtime/support/elf.py +28 -4
- tinygrad/runtime/support/hcq.py +256 -324
- tinygrad/runtime/support/llvm.py +26 -0
- tinygrad/shape/shapetracker.py +85 -53
- tinygrad/shape/view.py +104 -140
- tinygrad/spec.py +155 -0
- tinygrad/tensor.py +835 -527
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
- tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
- tinygrad/viz/index.html +544 -0
- tinygrad/viz/perfetto.html +178 -0
- tinygrad/viz/serve.py +205 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/METADATA +48 -25
- tinygrad-0.10.2.dist-info/RECORD +99 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +1 -1
- tinygrad/codegen/uopgraph.py +0 -506
- tinygrad/engine/lazy.py +0 -228
- tinygrad/function.py +0 -212
- tinygrad/multi.py +0 -177
- tinygrad/runtime/graph/clang.py +0 -39
- tinygrad/runtime/ops_clang.py +0 -35
- tinygrad-0.10.0.dist-info/RECORD +0 -77
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
tinygrad/tensor.py
CHANGED
@@ -1,47 +1,53 @@
 # inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
 from __future__ import annotations
-import time, math, itertools, functools, struct, sys, inspect, pathlib, string,
+import time, math, itertools, functools, struct, sys, inspect, pathlib, string, hashlib, weakref
 from contextlib import ContextDecorator
-from typing import
-from collections import defaultdict
-
+from typing import Callable, Optional, ClassVar, Union, Sequence, cast, get_args, Literal, TYPE_CHECKING, SupportsIndex
 from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
 from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
-from tinygrad.helpers import IMAGE,
-from tinygrad.multi import
-from tinygrad.
-from tinygrad.
-from tinygrad.
+from tinygrad.helpers import IMAGE, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap
+from tinygrad.engine.multi import get_multi_map
+from tinygrad.gradient import compute_gradient
+from tinygrad.ops import smax, smin, resolve, UOp, Ops, sint, Variable, SimpleMathTrait, identity_element
+from tinygrad.spec import tensor_uop_spec, type_verify
+from tinygrad.device import Device, BufferSpec
 from tinygrad.engine.realize import run_schedule
 from tinygrad.engine.memory import memory_planner
 from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars

-#
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-ret.lazydata, ret.requires_grad, ret.grad = ctx.forward(*[t.lazydata for t in x], **kwargs), ctx.requires_grad, None
-ret._ctx = ctx if ctx.requires_grad and not Tensor.no_grad else None # used by autograd engine
-return ret
+# *** all in scope Tensors are here. this gets relevant UOps ***
+
+all_tensors: set[weakref.ref[Tensor]] = set()
+
+def _apply_map_to_tensors(applied_map:dict[UOp, UOp]) -> None:
+  # get all children of keys in applied_map
+  all_uops: set[UOp] = set()
+  search_uops = list(applied_map)
+  while len(search_uops):
+    x = search_uops.pop(0)
+    if x in all_uops: continue
+    all_uops.add(x)
+    search_uops.extend([u for c in x.children if (u:=c()) is not None])
+
+  # link the found UOps back to Tensors. exit early if there's no Tensors to realize
+  # NOTE: this uses all_tensors, but it's fast
+  fixed_tensors: list[Tensor] = [t for tref in all_tensors if (t:=tref()) is not None and t.lazydata in all_uops]

-
+  if len(fixed_tensors):
+    # potentially rewrite all the discovered Tensors
+    sink = UOp.sink(*[t.lazydata for t in fixed_tensors])
+    new_sink = sink.substitute(applied_map)

-
-
-
+    # set the relevant lazydata to the realized UOps
+    for t,s,ns in zip(fixed_tensors, sink.src, new_sink.src):
+      if s is ns: continue
+      t.lazydata = ns
+
+# **** Tensor helper functions ****
+
+def _metaop(op, shape:tuple[sint,...], dtype:DType, device:Union[str, tuple[str, ...]], arg=None):
+  if isinstance(device, str): return UOp.metaop(op, shape, dtype, device, arg)
+  return UOp.multi(*[UOp.metaop(op, shape, dtype, d, arg) for d in device], axis=None)

 def _from_np_dtype(npdtype:'np.dtype') -> DType: # type: ignore [name-defined] # noqa: F821
   import numpy as np
@@ -50,33 +56,31 @@ def _to_np_dtype(dtype:DType) -> Optional[type]:
   import numpy as np
   return np.dtype(dtype.fmt).type if dtype.fmt is not None else None

-def _fromnp(x: 'np.ndarray') ->
-  ret =
+def _fromnp(x: 'np.ndarray') -> UOp: # type: ignore [name-defined] # noqa: F821
+  ret = UOp.metaop(Ops.EMPTY, x.shape, _from_np_dtype(x.dtype), "NPY")
   # fake realize
   ret.buffer.allocate(x)
-  del ret.srcs
   return ret

-def get_shape(x) ->
+def get_shape(x) -> tuple[int, ...]:
   # NOTE: str is special because __getitem__ on a str is still a str
   if not hasattr(x, "__len__") or not hasattr(x, "__getitem__") or isinstance(x, str) or (hasattr(x, "shape") and x.shape == ()): return ()
   if not all_same(subs:=[get_shape(xi) for xi in x]): raise ValueError(f"inhomogeneous shape from {x}")
   return (len(subs),) + (subs[0] if subs else ())

-def _frompy(x:Union[
-  if isinstance(x, bytes): ret, data =
+def _frompy(x:Union[list, tuple, bytes], dtype:DType) -> UOp:
+  if isinstance(x, bytes): ret, data = UOp.metaop(Ops.EMPTY, (len(x)//dtype.itemsize,), dtype, "PYTHON"), x
   else:
-    ret =
+    ret = UOp.metaop(Ops.EMPTY, get_shape(x), dtype, "PYTHON")
     assert dtype.fmt is not None, f"{dtype=} has None fmt"
     truncate_function = truncate[dtype]
     data = struct.pack(f"@{ret.size}{dtype.fmt}", *[truncate_function(xi) for xi in fully_flatten(x)])
   # fake realize
   ret.buffer.allocate(memoryview(data if Device.DEFAULT != "PYTHON" else bytearray(data)))
-  del ret.srcs
   return ret

-def _get_winograd_matcols(mat, dims:int, shp:
-  return [[Tensor.cat(*[Tensor.full(shp[:dim] + (1,) + shp[dim+1:], float(m[k]), device=device) for m in mat], dim=dim)
+def _get_winograd_matcols(mat, dims:int, shp:tuple[sint, ...], device:Union[str, tuple[str, ...]], dtype:DType) -> list[list[Tensor]]:
+  return [[Tensor.cat(*[Tensor.full(shp[:dim] + (1,) + shp[dim+1:], float(m[k]), device=device, dtype=dtype) for m in mat], dim=dim)
     for k in range(len(mat[0]))] for dim in range(dims)]

 # winograd conv 3 kernel f(4x4,3x3) see: http://arxiv.org/abs/1509.09308
@@ -85,21 +89,34 @@ def _apply_winograd_matrix(mat, t:Tensor, dims:int) -> Tensor:
   # due to realize-before-expand rule in lazy.py, we must operate in this order: reshape -> expand -> arithmetic
   t_ = t.reshape(t.shape[:dims] + (1,) * dims + t.shape[dims:]).expand(t.shape[:dims] + (len(mat),) * dims + t.shape[dims:]) # add output dims
   # precalculate mat columns for each dim; prod(itertools.product(matcols)) gives the columns of kron(mat, mat, ...)
-  matcols = _get_winograd_matcols(mat, dims, t_.shape[dims:], t_.device)
+  matcols = _get_winograd_matcols(mat, dims, t_.shape[dims:], t_.device, t_.dtype)
   # multiply each element of t_ by the corresponding stacked column of kron(mat, mat), producing only one view for each element of t
   ret = sum(prod(col[idx] for col, idx in zip(matcols, mat_is)) * t_[mat_is] for mat_is in itertools.product(range(len(mat[0])), repeat=dims))
   assert isinstance(ret, Tensor), "sum didn't return a Tensor"
   return ret

-def
+def _align_left(*shapes:tuple[sint, ...]) -> tuple[tuple[sint, ...], ...]:
+  # unsqueeze left to make every shape same length
   max_dim = max(len(shape) for shape in shapes)
   return tuple((1,) * (max_dim - len(shape)) + shape for shape in shapes)
-def _broadcast_shape(*shapes:
-  return tuple(0 if 0 in nth_dim_sizes else smax(nth_dim_sizes) for nth_dim_sizes in zip(*
+def _broadcast_shape(*shapes:tuple[sint, ...]) -> tuple[sint, ...]:
+  return tuple(0 if 0 in nth_dim_sizes else smax(nth_dim_sizes) for nth_dim_sizes in zip(*_align_left(*shapes)))
+
+def _masked_setitem(target:Tensor, values:Tensor, mask:Tensor, axes:tuple[int, ...]):
+  # apply mask to values (already broadcasted) and reduce such that if mask contains repeated indices the last one remains
+  values = values * mask
+  for dim in axes: mask, values = functools.reduce(lambda x,y: (x[0]|y[0], y[0].where(y[1], x[1])), zip(mask.split(1, dim), values.split(1, dim)))
+  # remove extra dims from reduce
+  for dim in reversed(axes): mask, values = mask.squeeze(dim), values.squeeze(dim)
+  # select from values for each True element in mask else select from self
+  return mask.where(values, target)
+
+# `(padding_left, padding_right, padding_top, padding_bottom, ...)` -> `(..., (padding_top, padding_bottom), (padding_left, padding_right))`
+def _flat_to_grouped(padding:Sequence[sint]) -> tuple[tuple[sint, sint], ...]: return tuple(zip(padding[-2::-2], padding[::-2]))

 ReductionStr = Literal["mean", "sum", "none"]

-class Tensor(SimpleMathTrait):
+class Tensor(SimpleMathTrait):
   """
   A `Tensor` is a multi-dimensional matrix containing elements of a single data type.

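
The new `_flat_to_grouped` helper above converts PyTorch-style flat padding into per-axis `(before, after)` pairs by walking the flat sequence backwards in steps of two. A minimal standalone sketch of that conversion (not taken from the package, just the same slicing applied to a literal):

```python
def _flat_to_grouped(padding):
  # (left, right, top, bottom, ...) -> (..., (top, bottom), (left, right))
  return tuple(zip(padding[-2::-2], padding[::-2]))

# flat (left=1, right=2, top=3, bottom=4) becomes ((top, bottom), (left, right))
assert _flat_to_grouped((1, 2, 3, 4)) == ((3, 4), (1, 2))
```
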
@@ -110,15 +127,13 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
   np.set_printoptions(precision=4)
   ```
   """
-  __slots__ = "lazydata", "requires_grad", "grad"
-  __deletable__ = ('_ctx',)
+  __slots__ = "lazydata", "requires_grad", "grad"
   training: ClassVar[bool] = False
   no_grad: ClassVar[bool] = False

-  def __init__(self, data:Union[None, ConstType,
+  def __init__(self, data:Union[None, ConstType, bytes, list, tuple, UOp, 'np.ndarray', pathlib.Path], # type: ignore [name-defined] # noqa: F821
               device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DTypeLike]=None, requires_grad:Optional[bool]=None):
     if dtype is not None: dtype = to_dtype(dtype)
-    assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
     if device is None and isinstance(data, pathlib.Path): device = f"DISK:{data.resolve()}" # keep it on the disk if device is None
     device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device)

@@ -129,21 +144,18 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
     # None (the default) will be updated to True if it's put in an optimizer
     self.requires_grad: Optional[bool] = requires_grad

-    # internal variable used for autograd graph construction
-    self._ctx: Optional[Function] = None
-
     # create a LazyBuffer from the different types of inputs
-    if isinstance(data,
+    if isinstance(data, UOp):
+      assert dtype is None or dtype==data.dtype, "dtype doesn't match, and casting isn't supported"
+      # NOTE: this is here because LazyBuffer = UOp
+      if isinstance(data, UOp) and data.op is Ops.BIND: data = _metaop(Ops.BIND, tuple(), dtype or data.dtype, device, data)
     elif data is None: data = _metaop(Ops.EMPTY, (0,), dtype or dtypes.default_float, device)
     elif isinstance(data, get_args(ConstType)): data = _metaop(Ops.CONST, tuple(), dtype or dtypes.from_py(data), device, data)
-    elif isinstance(data, UOp):
-      assert data.op is Ops.BIND and data.src[0].op is Ops.DEFINE_VAR and data.src[1].op is Ops.CONST, f"can't create tensor from UOp {data}"
-      data = _metaop(Ops.CONST, tuple(), dtype or data.dtype, device, data)
     elif isinstance(data, bytes): data = _frompy(data, dtypes.uint8 if dtype is None else dtype)
     elif isinstance(data, (list, tuple)):
       if dtype is None:
         if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtypes.bool
-        else: dtype = dtypes.default_int if d and all_int(d) else dtypes.default_float
+        else: dtype = dtypes.default_int if d and all_int(d) else dtypes.default_float # NOTE: this works because all_int([True, False]) is True
       if dtype == dtypes.bfloat16: data = Tensor(_frompy(data, dtypes.float32), device=device).cast(dtypes.bfloat16).lazydata
       else: data = _frompy(data, dtype)
     elif str(type(data)) == "<class 'numpy.ndarray'>":
@@ -155,17 +167,34 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
       dtype = dtype or dtypes.uint8
       data = _metaop(Ops.EMPTY, (data.stat().st_size // dtype.itemsize,), dtype, f"DISK:{data.resolve()}")

-    # by this point, it has to be a
-    if not isinstance(data,
+    # by this point, it has to be a UOp
+    if not isinstance(data, UOp): raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}")

     # data might be on a different device
-    if isinstance(device, str): self.lazydata:
+    if isinstance(device, str): self.lazydata:UOp = data if data.device == device else data.copy_to_device(device)
     # if device is a tuple, we should have/construct a MultiLazyBuffer
-    elif isinstance(data,
+    elif isinstance(data, UOp) and isinstance(data.device, str): self.lazydata = Tensor(data).shard(device).lazydata
     else:
       assert data.device == device, f"MultiLazyBuffer device mismatch, {data.device} != {device}"
       self.lazydata = data

+    # add to all_tensors after construction succeeds
+    all_tensors.add(weakref.ref(self))
+  def __del__(self): all_tensors.discard(weakref.ref(self))
+
+  def _apply_uop(self, fxn:Callable, *x:Tensor, **kwargs) -> Tensor:
+    new_uop: UOp = fxn(*[t.lazydata for t in (self,)+x], **kwargs)
+    needs_input_grad = [t.requires_grad for t in (self,)+x]
+    return Tensor(new_uop, device=new_uop.device, requires_grad=True if any(needs_input_grad) else None if None in needs_input_grad else False)
+
+  def _apply_broadcasted_uop(self, fxn:Callable, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
+    lhs,rhs = self._broadcasted(x, reverse)
+    return lhs._apply_uop(fxn, rhs)
+
+  def requires_grad_(self, requires_grad=True) -> Tensor:
+    self.requires_grad = requires_grad
+    return self
+
   class train(ContextDecorator):
     def __init__(self, mode:bool = True): self.mode = mode
     def __enter__(self): self.prev, Tensor.training = Tensor.training, self.mode
@@ -177,7 +206,9 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
     def __exit__(self, exc_type, exc_value, traceback): Tensor.no_grad = self.prev

   def __repr__(self):
-
+    ld = self.lazydata
+    ld_repr = f"<UOp {ld.device} {ld.shape} {str(ld.dtype)[7:]} {ld.st if ld.base is not ld else (ld.op, ld.realized)}>"
+    return f"<Tensor {ld_repr} on {self.device} with grad {(self.grad.lazydata if self.grad is not None else None)!r}>"

   # Python has a non moving GC, so this should be okay
   def __hash__(self): return id(self)
@@ -189,26 +220,38 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
     return self.shape[0]

   @property
-  def device(self) -> Union[str,
+  def device(self) -> Union[str, tuple[str, ...]]: return self.lazydata.device

   @property
-  def shape(self) ->
+  def shape(self) -> tuple[sint, ...]: return self.lazydata.shape

   @property
   def dtype(self) -> DType: return self.lazydata.dtype

   # ***** data handlers ****

-  def schedule_with_vars(self, *lst:Tensor) ->
+  def schedule_with_vars(self, *lst:Tensor) -> tuple[list[ScheduleItem], dict[Variable, int]]:
     """
     Creates the schedule needed to realize these Tensor(s), with Variables.

     NOTE: A Tensor can only be scheduled once.
     """
-
+    big_sink = UOp.sink(*[x.lazydata for x in (self,)+lst])
+
+    # TODO: move this to scheduler tensor_map pass
+    if any(x.op is Ops.MULTI for x in big_sink.toposort):
+      # multi fixup
+      _apply_map_to_tensors(get_multi_map(big_sink))
+      big_sink = UOp.sink(*flatten([x.lazydata.src if x.lazydata.op is Ops.MULTI else [x.lazydata] for x in (self,)+lst]))
+
+    # verify Tensors match the spec
+    if __debug__: type_verify(list(big_sink.toposort), tensor_uop_spec)
+
+    schedule, var_vals, becomes_map = create_schedule_with_vars(big_sink)
+    _apply_map_to_tensors(becomes_map)
     return memory_planner(schedule), var_vals

-  def schedule(self, *lst:Tensor) ->
+  def schedule(self, *lst:Tensor) -> list[ScheduleItem]:
     """Creates the schedule needed to realize these Tensor(s)."""
     schedule, var_vals = self.schedule_with_vars(*lst)
     assert len(var_vals) == 0
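
For context, `schedule()` in the hunk above returns the list of `ScheduleItem`s needed to realize a tensor after the MULTI fixup, spec verification, and memory planning shown. A minimal usage sketch, assuming the public `tinygrad` package layout and that `run_schedule` accepts the schedule and variable values positionally as shown:

```python
from tinygrad import Tensor
from tinygrad.engine.realize import run_schedule

x = Tensor.rand(4, 4)
y = (x + 1).sum()
sched, var_vals = y.schedule_with_vars()  # a Tensor can only be scheduled once
run_schedule(sched, var_vals)             # execute the kernels; y is now realized
print(y.item())
```
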
@@ -224,7 +267,6 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
     Replaces the data of this tensor with the data of another tensor. Only the shape of the tensors must match.
     """
     # used for replacing a Tensor with a new version of it (potentially with a different device and dtype)
-    assert not x.requires_grad and getattr(self, '_ctx', None) is None
     assert self.shape == x.shape, f"replace shape mismatch {self.shape} != {x.shape}"
     self.lazydata = x.lazydata
     return self
@@ -232,17 +274,15 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
   def assign(self, x) -> Tensor:
     # TODO: this is a hack for writing to DISK. remove with working assign
     if isinstance(self.device, str) and self.device.startswith("DISK"):
-      if x.__class__ is not Tensor: x = Tensor(x, device="
-      self.contiguous().realize().lazydata.base.realized.copyin(x.
+      if x.__class__ is not Tensor: x = Tensor(x, device="CPU", dtype=self.dtype)
+      self.contiguous().realize().lazydata.base.realized.ensure_allocated().copyin(x._data())
       return self
     if x.__class__ is not Tensor: x = Tensor(x, device=self.device, dtype=self.dtype)
-    if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}")
     if self.lazydata is x.lazydata: return self # a self assign is a NOOP
     # NOTE: we allow cross device assign
     assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}"
     assert self.device == x.device, f"assign device mismatch {self.device} != {x.device}"
     assert self.dtype == x.dtype, f"assign dtype mismatch {self.dtype} != {x.dtype}"
-    assert not isinstance(self.lazydata, MultiLazyBuffer) or self.lazydata.axis == x.lazydata.axis, "axis must match on MultiLazyBuffer"
     assert not x.requires_grad # self requires_grad is okay?
     if not self.lazydata.is_realized: return self.replace(x)
     self.lazydata = self.lazydata.assign(x.lazydata)
@@ -252,15 +292,16 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
     """
     Returns a new tensor with the same data as this tensor, but detached from the autograd graph.
     """
-    return Tensor(self.lazydata, device=self.device, requires_grad=False)
+    return Tensor(self.lazydata.detach(), device=self.device, requires_grad=False)

   def _data(self) -> memoryview:
     if 0 in self.shape: return memoryview(bytearray(0))
     # NOTE: this realizes on the object from as_buffer being a Python object
-    cpu = self.cast(self.dtype.base).contiguous().to("
-    buf = cast(
-
-
+    cpu = self.cast(self.dtype.base).contiguous().to("CPU").realize()
+    buf = cast(UOp, cpu.lazydata).base.realized
+    assert buf is not None, f"{cast(UOp, cpu.lazydata).base} was not realized"
+    if self.device != "CPU": buf.options = BufferSpec(nolru=True)
+    return buf.as_buffer(allow_zero_copy=True if self.device != "CPU" else False)

   def data(self) -> memoryview:
     """
@@ -271,9 +312,10 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
     print(np.frombuffer(t.data(), dtype=np.int32))
     ```
     """
-    assert self.dtype.fmt is not None, f"no fmt dtype for {self.dtype}"
+    assert self.dtype.base.fmt is not None, f"no fmt dtype for {self.dtype.base}"
     assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
-
+    if TYPE_CHECKING or sys.version_info < (3, 12): assert self.dtype.base.fmt != "e"
+    return cast(memoryview, self._data().cast(self.dtype.base.fmt) if 0 in self.shape else self._data().cast(self.dtype.base.fmt, self.shape))

   def item(self) -> ConstType:
     """
@@ -284,20 +326,24 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
     print(t.item())
     ```
     """
-    assert self.dtype.fmt is not None, f"no fmt dtype for {self.dtype}"
     assert self.numel() == 1, "must have one element for item"
-    return self.
+    return self.data()[(0,) * len(self.shape)]

-  # TODO: should be Tensor.tolist() -> Union[
+  # TODO: should be Tensor.tolist() -> Union[list[ConstType], ConstType]. The list is Sequence because mypy expects memoryview.tolist() -> list[int]
   # src: https://github.com/python/mypy/blob/release-1.6/mypy/typeshed/stdlib/builtins.pyi#L803
   def tolist(self) -> Union[Sequence[ConstType], ConstType]:
     """
     Returns the value of this tensor as a nested list.
+    Returns single value for const tensor.

     ```python exec="true" source="above" session="tensor" result="python"
     t = Tensor([1, 2, 3, 4])
     print(t.tolist())
     ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor(5)
+    print(t.tolist())
+    ```
     """
     return self.data().tolist()

@@ -311,21 +357,20 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
     ```
     """
     import numpy as np
-    if self.dtype == dtypes.bfloat16: return self.float().numpy()
-    assert _to_np_dtype(self.dtype) is not None, f"no np dtype for {self.dtype}"
+    if self.dtype.base == dtypes.bfloat16: return self.float().numpy()
+    assert _to_np_dtype(self.dtype.base) is not None, f"no np dtype for {self.dtype.base}"
     assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
-    return np.frombuffer(self._data(), dtype=_to_np_dtype(self.dtype)).reshape(self.shape)
+    return np.frombuffer(self._data(), dtype=_to_np_dtype(self.dtype.base)).reshape(self.shape)

   def clone(self) -> Tensor:
     """
-    Creates a clone of this tensor allocating a
+    Creates a clone of this tensor allocating a separate buffer for the data.
     """
     ret = Tensor(self.lazydata.clone(), self.device, requires_grad=self.requires_grad)
     if self.grad is not None: ret.grad = self.grad.clone()
-    if hasattr(self, '_ctx'): ret._ctx = self._ctx
     return ret

-  def to(self, device:Optional[Union[str,
+  def to(self, device:Optional[Union[str, tuple[str, ...]]]) -> Tensor:
     """
     Moves the tensor to the given device.
     """
@@ -334,47 +379,35 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
     if not isinstance(device, str): return self.shard(device)
     ret = Tensor(self.lazydata, device, requires_grad=self.requires_grad)
     if self.grad is not None: ret.grad = self.grad.to(device)
-    if hasattr(self, '_ctx'): ret._ctx = self._ctx
     return ret

-  def to_(self, device:Optional[Union[str,
+  def to_(self, device:Optional[Union[str, tuple[str, ...]]]):
     """
     Moves the tensor to the given device in place.
     """
     real = self.to(device)
-
-
-    self.lazydata = real.lazydata
+    if self.grad is not None and real.grad is not None: self.grad.replace(real.grad)
+    return self.replace(real)

-  def shard(self, devices:
+  def shard(self, devices:tuple[str, ...], axis:Optional[int]=None) -> Tensor:
     """
-    Shards the tensor across the given devices. Optionally specify which axis to shard on
+    Shards the tensor across the given devices. Optionally specify which axis to shard on.

     ```python exec="true" source="above" session="tensor" result="python"
-    t = Tensor.empty(2,
-    print(t.shard((t.device, t.device), axis=1
+    t = Tensor.empty(2, 4)
+    print(t.shard((t.device, t.device), axis=1).lazydata)
     ```
-
     """
-    assert isinstance(self.
-    devices
-    if axis is not None
-
-    if splits is None:
-      if not isinstance(total:=self.shape[axis], int): raise RuntimeError(f"cannot shard symbolic shape {self.shape=}, {axis=}")
-      sz = ceildiv(total, len(devices))
-      splits = tuple([max(0, min(sz, total - sz*i)) for i in range(len(devices))])
-    assert sum(splits) == self.shape[axis], "specified splits do not sum up to axis shape"
-    boundaries = tuple(itertools.accumulate(splits))
-    bounds = tuple(zip((0,) + boundaries, boundaries))
-    return Tensor(MultiLazyBuffer.from_sharded(self.lazydata, devices, axis, bounds), device=devices, requires_grad=self.requires_grad)
+    assert isinstance(self.device, str), "can't shard a MultiLazyBuffer"
+    devices = tuple(Device.canonicalize(x) for x in devices)
+    mlb = self.lazydata.shard(devices, self._resolve_dim(axis) if axis is not None else None)
+    return Tensor(mlb, device=devices, requires_grad=self.requires_grad)

-  def shard_(self, devices:
+  def shard_(self, devices:tuple[str, ...], axis:Optional[int]=None):
     """
     Shards the tensor across the given devices in place.
     """
-    self.
-    return self
+    return self.replace(self.shard(devices, axis))

   @staticmethod
   def from_uop(y:UOp, **kwargs) -> Tensor:
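
The rewritten `shard` above delegates the actual splitting to `self.lazydata.shard` rather than computing per-device bounds in the Tensor layer. A small usage sketch (the doubled default device is only illustrative; any tuple of canonical device strings works the same way):

```python
from tinygrad import Tensor

t = Tensor.arange(8).reshape(2, 4)
st = t.shard((t.device, t.device), axis=1)  # shape stays (2, 4); device becomes a tuple
print(st.shape, st.device)
```
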
@@ -382,18 +415,17 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
     if y.op is Ops.CONST: return Tensor(y.arg, **kwargs, requires_grad=False)
     if y.op is Ops.MUL: return Tensor.from_uop(y.src[0]) * Tensor.from_uop(y.src[1])
     if y.op is Ops.ADD: return Tensor.from_uop(y.src[0]) + Tensor.from_uop(y.src[1])
-    if y.op is Ops.MAX: return Tensor.from_uop(y.src[0]).maximum(Tensor.from_uop(y.src[1]))
     raise RuntimeError(f"unhandled UOp {y}")

   # ***** creation entrypoint *****

   @staticmethod
-  def _metaop(op, shape, device:Optional[Union[
+  def _metaop(op, shape, device:Optional[Union[tuple[str, ...], str]]=None, dtype:Optional[DTypeLike]=None, arg=None, **kwargs):
     dtype = to_dtype(dtype) if dtype is not None else dtypes.default_float
     if isinstance(device, tuple):
-      return Tensor(
+      return Tensor(UOp.multi(*[UOp.metaop(op, shape, dtype, Device.canonicalize(d), arg) for d in device], axis=None),
                     device, dtype, **kwargs)
-    return Tensor(
+    return Tensor(UOp.metaop(op, shape, dtype, Device.canonicalize(device), arg), device, dtype, **kwargs)

   @staticmethod
   def empty(*shape, **kwargs):
@@ -411,7 +443,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
     return Tensor._metaop(Ops.EMPTY, argfix(*shape), **kwargs)

   @staticmethod
-  def from_blob(ptr:int, shape:
+  def from_blob(ptr:int, shape:tuple[int, ...], **kwargs) -> Tensor:
     """
     Exposes the pointer as a Tensor without taking ownership of the original data.
     The pointer must remain valid for the entire lifetime of the created Tensor.
@@ -422,7 +454,6 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method

     r = Tensor._metaop(Ops.EMPTY, shape, **kwargs)
     r.lazydata.buffer.allocate(external_ptr=ptr)
-    del r.lazydata.srcs # fake realize
     return r

   @staticmethod
@@ -439,8 +470,8 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
     return Tensor(fetch(url, gunzip=gunzip), **kwargs)

   _seed: int = int(time.time())
-  _device_seeds:
-  _device_rng_counters:
+  _device_seeds: dict[str, Tensor] = {}
+  _device_rng_counters: dict[str, Tensor] = {}
   @staticmethod
   def manual_seed(seed=0):
     """
@@ -462,7 +493,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
   @staticmethod
   def _threefry_random_bits(key:Tensor, counts0:Tensor, counts1:Tensor):
     x = (counts1.cast(dtypes.uint64) << 32) | counts0.cast(dtypes.uint64)
-    x =
+    x = x._apply_uop(UOp.threefry, (key[1]._broadcast_to(x.shape).cast(dtypes.uint64) << 32) | key[0]._broadcast_to(x.shape).cast(dtypes.uint64))
     counts0, counts1 = (x & 0xffffffff).cast(dtypes.uint32), ((x >> 32) & 0xffffffff).cast(dtypes.uint32)
     return counts0.cat(counts1)

@@ -485,8 +516,12 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
     if device is not None and not isinstance(device, str): raise ValueError(f"rand only supports single device, got {device=}")
     _device = device = Device.canonicalize(device)

-    #
-    if
+    # if shape has 0, return zero tensor
+    if (numel := prod(shape)) == 0: return Tensor.zeros(shape, device=_device, dtype=dtype, **kwargs)
+    num = ceildiv(numel * dtype.itemsize, 4)
+
+    # when using MOCKGPU and NV generate rand on CPU
+    if getenv("MOCKGPU") and device.startswith("NV"): device = "CPU"

     # generate per device seeds and rng counter if we haven't seen this device yet
     if device not in Tensor._device_seeds:
@@ -494,15 +529,8 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
         [int.from_bytes(hashlib.sha256(len(Tensor._device_seeds).to_bytes(4, "big")).digest(), "big"), Tensor._seed],
         device=device, dtype=dtypes.uint32, requires_grad=False)
       Tensor._device_rng_counters[device] = Tensor([0], device=device, dtype=dtypes.uint32, requires_grad=False)
-      had_counter = False
-    else: had_counter = True
-
-    # if shape has 0, return zero tensor
-    if (numel := prod(shape)) == 0: return Tensor.zeros(shape, device=_device, dtype=dtype, **kwargs)
-    num = ceildiv(numel * dtype.itemsize, 4)
-
     # increment rng counter for devices
-
+    else: Tensor._device_rng_counters[device].assign(Tensor._device_rng_counters[device] + num).contiguous()

     # threefry random bits
     counts0 = (Tensor.arange(ceildiv(num, 2), device=device, dtype=dtypes.uint32, requires_grad=False)+Tensor._device_rng_counters[device])
@@ -528,7 +556,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
   # ***** creation helper functions *****

   @staticmethod
-  def full(shape:
+  def full(shape:tuple[sint, ...], fill_value:ConstType, **kwargs) -> Tensor:
     """
     Creates a tensor with the given shape, filled with the given value.

@@ -607,7 +635,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
     dtype = kwargs.pop("dtype", dtypes.default_float if any(isinstance(x, float) for x in (start, stop, step)) else dtypes.default_int)
     # NOTE: this matches numpy, torch raises RuntimeError if stop-start and step have different signs
     if (output_len:=ceildiv(stop-start, step)) <= 0: return Tensor([], dtype=dtype, **kwargs)
-    return (Tensor.full((output_len,), step, dtype=dtype, **kwargs).
+    return (Tensor.full((output_len,), step, dtype=dtype, **kwargs)._cumalu(0, Ops.ADD) + (start - step)).cast(dtype)

   @staticmethod
   def linspace(start:Union[int, float], stop:Union[int, float], steps:int, **kwargs) -> Tensor:
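
The `arange` change above builds the range as a cumulative sum: `Tensor.full((output_len,), step)` accumulated along axis 0 gives `step, 2*step, ...`, and adding `start - step` shifts that to `start, start+step, ...`. A quick check of that identity using the public API:

```python
from tinygrad import Tensor

# output_len = ceildiv(10-3, 2) = 4, cumsum of [2,2,2,2] is [2,4,6,8], plus (3-2) gives [3,5,7,9]
assert Tensor.arange(3, 10, 2).tolist() == [3, 5, 7, 9]
assert Tensor.arange(5).tolist() == [0, 1, 2, 3, 4]
```
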
@@ -705,18 +733,19 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
     ```
     """
     dtype = kwargs.pop("dtype", self.dtype)
-    if isinstance(self.device, tuple)
+    if isinstance(self.device, tuple):
       if kwargs.get("device") is not None: raise RuntimeError("cannot specify `device` on `rand_like` of a multi device tensor")
       if self.lazydata.axis is None: return Tensor.rand(*self.shape, dtype=dtype, **kwargs).shard(self.device)
       contiguous = kwargs.pop("contiguous", True)
-
-
+      sharded_shape = tuple(s//len(self.device) if a==self.lazydata.axis else s for a,s in enumerate(self.shape))
+      rands = [Tensor.rand(sharded_shape, device=d, dtype=dtype, contiguous=contiguous, **kwargs).lazydata for d in self.device]
+      return Tensor(UOp.multi(*rands, axis=self.lazydata.axis), device=self.device, dtype=dtype, **kwargs)
     return Tensor.rand(*self.shape, device=kwargs.pop("device", self.device), dtype=dtype, **kwargs)

   # ***** rng hlops *****

   @staticmethod
-  def randn(*shape, dtype:Optional[DTypeLike]=None, **kwargs) -> Tensor:
+  def randn(*shape, dtype:Optional[DTypeLike]=None, requires_grad:Optional[bool]=None, **kwargs) -> Tensor:
     """
     Creates a tensor with the given shape, filled with random values from a normal distribution with mean `0` and standard deviation `1`.
     If `dtype` is not specified, the default type is used.
@@ -731,10 +760,10 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
     """
     # https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
     src = Tensor.rand((2, *argfix(*shape)), **{**kwargs, "dtype": dtypes.float32})
-    return src[0].mul(2*math.pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(dtype or dtypes.default_float)
+    return (src[0].mul(2*math.pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(dtype or dtypes.default_float)).requires_grad_(requires_grad)

   @staticmethod
-  def randint(*shape, low=0, high=10, **kwargs) -> Tensor:
+  def randint(*shape, low=0, high=10, dtype=dtypes.int32, **kwargs) -> Tensor:
     """
     Creates a tensor with the given shape, filled with random integer values generated uniformly from the interval `[low, high)`.
     If `dtype` is not specified, the default type is used.
@@ -748,12 +777,12 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
     ```
     """
     if not isinstance(low, int) or not isinstance(high, int): raise TypeError(f"{low=} and {high=} must be integers")
-    dtype = to_dtype(
+    dtype = to_dtype(dtype)
     if not dtypes.is_int(dtype): raise TypeError(f"{dtype=} must be int")
     return Tensor.uniform(*shape, low=low, high=high, dtype=dtype, **kwargs)

   @staticmethod
-  def normal(*shape, mean=0.0, std=1.0, **kwargs) -> Tensor:
+  def normal(*shape, mean=0.0, std=1.0, requires_grad:Optional[bool]=None, **kwargs) -> Tensor:
     """
     Creates a tensor with the given shape, filled with random values from a normal distribution with the given `mean` and standard deviation `std`.

@@ -765,10 +794,10 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
     print(Tensor.normal(2, 3, mean=10, std=2).numpy())
     ```
     """
-    return (std * Tensor.randn(*shape, **kwargs)) + mean
+    return ((std * Tensor.randn(*shape, **kwargs)) + mean).requires_grad_(requires_grad)

   @staticmethod
-  def uniform(*shape, low=0.0, high=1.0, **kwargs) -> Tensor:
+  def uniform(*shape, low=0.0, high=1.0, dtype:Optional[DTypeLike]=None, requires_grad:Optional[bool]=None, **kwargs) -> Tensor:
     """
     Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[low, high)`.

@@ -780,8 +809,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
     print(Tensor.uniform(2, 3, low=2, high=10).numpy())
     ```
     """
-
-    return ((high-low) * Tensor.rand(*shape, **kwargs)).cast(dtype) + low
+    return (((high-low) * Tensor.rand(*shape, **kwargs)).cast(dtype or dtypes.default_float) + low).requires_grad_(requires_grad)

   @staticmethod
   def scaled_uniform(*shape, **kwargs) -> Tensor:
@@ -860,49 +888,52 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method

   # ***** toposort and backward pass *****

-  def
-
-
-      # if tensor is not leaf, reset grad
-      if (ctx := getattr(node, "_ctx", None)) is not None and len(ctx.parents) != 0: node.grad = None
-      if ctx:
-        for i in node._ctx.parents:
-          if i not in visited: yield from _walk(i, visited)
-      yield node
-    return list(_walk(self, set()))
+  def gradient(self, *targets:Tensor, gradient:Optional[Tensor]=None, materialize_grads=False) -> list[Tensor]:
+    """
+    Compute the gradient of the targets with respect to self.

-
+    ```python exec="true" source="above" session="tensor" result="python"
+    x = Tensor.eye(3)
+    y = Tensor([[2.0,0,-2.0]])
+    z = y.matmul(x).sum()
+    dx, dy = z.gradient(x, y)
+
+    print(dx.tolist()) # dz/dx
+    print(dy.tolist()) # dz/dy
+    ```
+    """
+    assert gradient is not None or self.shape == tuple(), "when no gradient is provided, backward must be called on a scalar tensor"
+    if gradient is None: gradient = Tensor(1.0, dtype=self.dtype, device=self.device, requires_grad=False)
+    rets = []
+    target_uops = [x.lazydata for x in targets]
+    grads = compute_gradient(self.lazydata, gradient.lazydata, set(target_uops))
+    ret = []
+    for x in target_uops:
+      if (y:=grads.get(x)) is None:
+        if materialize_grads: y = x.const_like(0)
+        else: raise RuntimeError(f"{x}\n\nnot found in\n\n{self.lazydata}")
+      ret.append(y)
+    rets.append(ret)
+    # create returned Tensors
+    return [Tensor(u, device=t.device) for t,u in zip(targets, rets[0])]
+
+  def backward(self, gradient:Optional[Tensor]=None) -> Tensor:
     """
     Propagates the gradient of a tensor backwards through the computation graph.
     If the 'gradient' argument is not provided, the tensor must be a scalar, and the gradient is implicitly set to 1.0.
-    If 'retain_graph' is false, the graph used to compute the grads will be freed. Otherwise, it will be kept. Keeping it can increase memory usage.
     ```python exec="true" source="above" session="tensor" result="python"
     t = Tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
     t.sum().backward()
     print(t.grad.numpy())
     ```
     """
-
-    if
-
-
-
-
-
-    assert self.shape == gradient.shape, f"grad shape must match tensor shape, {gradient.shape!r} != {self.shape!r}"
-    self.grad = gradient
-    for t0 in reversed(toposorted):
-      if t0.grad is None: raise RuntimeError(f"tensor {t0} has no grad")
-      token = _METADATA.set(dataclasses.replace(md, backward=True) if (md := t0._ctx.metadata) is not None else None)
-      grads = t0._ctx.backward(t0.grad.lazydata)
-      _METADATA.reset(token)
-      grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None
-        for g in ([grads] if len(t0._ctx.parents) == 1 else grads)]
-      for t, g in zip(t0._ctx.parents, grads):
-        if g is not None and t.requires_grad:
-          assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"
-          t.grad = g if t.grad is None else (t.grad + g)
-      if not retain_graph: del t0._ctx
+    all_uops = self.lazydata.toposort
+    tensors_need_grad: list[Tensor] = [t for tref in all_tensors if (t:=tref()) is not None and \
+      t.lazydata in all_uops and t.requires_grad and not Tensor.no_grad]
+    # clear contexts
+    for t,g in zip(tensors_need_grad, self.gradient(*tensors_need_grad, gradient=gradient, materialize_grads=True)):
+      assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"
+      t.grad = g if t.grad is None else (t.grad + g)
     return self

   # ***** movement low level ops *****
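
The hunk above replaces the per-`Function` backward walk with a functional `gradient` API built on `compute_gradient`, and reimplements `backward` on top of it by scanning the live tensors tracked in `all_tensors`. A rough sketch of how the two entry points relate, assuming the API shown above:

```python
from tinygrad import Tensor

w = Tensor([1.0, 2.0, 3.0], requires_grad=True)
loss = (w * w).sum()

(dw,) = loss.gradient(w)  # functional: returns the grads, leaves w.grad untouched
loss.backward()           # imperative: finds w via all_tensors and fills w.grad
print(dw.tolist(), w.grad.tolist())  # both should be [2.0, 4.0, 6.0]
```
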
@@ -926,7 +957,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
     # resolve -1
     if (c := new_shape.count(-1)) > 1: raise RuntimeError(f"only one dimension can be inferred using -1, getting {new_shape}")
     if c: new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape])
-    return
+    return self._apply_uop(UOp.reshape, arg=new_shape) if new_shape != self.shape else self

   def expand(self, shape, *args) -> Tensor:
     """
@@ -940,7 +971,8 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
     print(t.expand(4, -1).numpy())
     ```
     """
-
+    new_shape = tuple(from_ if to == -1 or to is None else to for from_, to in zip(*(_align_left(self.shape, argfix(shape, *args)))))
+    return self._broadcast_to(new_shape)

   def permute(self, order, *args) -> Tensor:
     """
@@ -958,7 +990,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
     """
     order_arg = tuple(self._resolve_dim(x) for x in argfix(order, *args))
     if sorted(order_arg) != list(range(self.ndim)): raise RuntimeError(f"order is not a valid permutation, getting {order_arg}")
-    return
+    return self._apply_uop(UOp.permute, arg=order_arg)

   def flip(self, axis, *args) -> Tensor:
     """
@@ -978,9 +1010,9 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
     """
     axis_arg = tuple(self._resolve_dim(x) for x in argfix(axis, *args))
     if len(axis_arg) != len(dedup(axis_arg)): raise RuntimeError(f"dim can appear at most once, getting {axis_arg}")
-    return
+    return self._apply_uop(UOp.flip, arg=tuple([i in axis_arg for i in range(len(self.shape))]))

-  def shrink(self, arg:
+  def shrink(self, arg:tuple[Optional[tuple[sint, sint]], ...]) -> Tensor:
     """
     Returns a tensor that shrinks the each axis based on input arg.
     `arg` must have the same length as `self.ndim`.
@@ -998,24 +1030,25 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
     ```
     """
     if (shrink_arg:=[x if x is not None else (0,s) for x,s in zip(arg, self.shape)]) == [(0,s) for s in self.shape]: return self
-    return
+    return self._apply_uop(UOp.shrink, arg=tuple(shrink_arg))

-  def pad(self, padding:Union[Sequence[sint], Sequence[Optional[
+  def pad(self, padding:Union[Sequence[sint], Sequence[Optional[tuple[sint, sint]]]], mode:str="constant", value:float=0.0) -> Tensor:
     """
     Returns a tensor with padding applied based on the input `padding`.
+
     `padding` supports two padding structures:

-    1. Flat padding: (padding_left, padding_right, padding_top, padding_bottom, ...)
-
-
+    1. Flat padding: `(padding_left, padding_right, padding_top, padding_bottom, ...)`
+       - This structure matches PyTorch's pad.
+       - `padding` length must be even.

-    2. Group padding: (..., (padding_top, padding_bottom), (padding_left, padding_right))
-
-
-
+    2. Group padding: `(..., (padding_top, padding_bottom), (padding_left, padding_right))`
+       - This structure matches pad for JAX, NumPy, TensorFlow, and others.
+       - For each axis, padding can be `None`, meaning no padding, or a tuple `(start, end)`.
+       - `padding` must have the same length as `self.ndim`.

     Padding values can be negative, resulting in dimension shrinks that work similarly to Python negative slices.
-    Padding modes is selected with `mode` which supports `constant` and `
+    Padding modes is selected with `mode` which supports `constant`, `reflect` and `replicate`.

     ```python exec="true" source="above" session="tensor" result="python"
     t = Tensor.arange(9).reshape(1, 1, 3, 3)
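
`pad` above now accepts `replicate` (and, per the mode check further down, `circular`) in addition to `constant` and `reflect`. A small sketch of the new modes, assuming 4-D NCHW input as in the docstring example:

```python
from tinygrad import Tensor

t = Tensor.arange(1, 5).reshape(1, 1, 2, 2)
print(t.pad((1, 1, 1, 1), mode="replicate").numpy())  # edge values are repeated outward
print(t.pad((1, 1, 1, 1), mode="circular").numpy())   # values wrap around each padded axis
```
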
@@ -1031,176 +1064,166 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
1031
1064
|
print(t.pad((1, 2, 0, -1), value=-float('inf')).numpy())
|
1032
1065
|
```
|
1033
1066
|
"""
|
1034
|
-
if mode not in {"constant", "reflect"}: raise NotImplementedError(f"{mode=} is not supported")
|
1035
|
-
|
1036
|
-
|
1037
|
-
|
1067
|
+
if mode not in {"constant", "reflect", "replicate", "circular"}: raise NotImplementedError(f"{mode=} is not supported")
|
1068
|
+
# flat padding
|
1069
|
+
if all(isinstance(p, (int,UOp)) for p in padding):
|
1070
|
+
if len(padding)%2 != 0: raise ValueError("Flat padding must have even number of pads")
|
1071
|
+
pX = _flat_to_grouped(tuple(cast(Sequence[sint], padding)) + (0,0)*(self.ndim - len(padding)//2))
|
1072
|
+
# group padding
|
1073
|
+
else: pX = tuple((0,0) if p is None else p for p in cast(Sequence[Optional[tuple[sint, sint]]], padding))
|
1038
1074
|
if len(pX) != self.ndim: raise ValueError(f"padding length is improper, {padding=} {self.ndim=}")
|
1039
|
-
X,
|
1040
|
-
|
1041
|
-
|
1042
|
-
|
1043
|
-
|
1044
|
-
|
1075
|
+
X, pads = self, tuple((smax(pB,0), smax(pA,0)) for pB,pA in pX)
|
1076
|
+
if mode == "constant":
|
1077
|
+
def _constant(x:Tensor,px,v):
|
1078
|
+
return x._apply_uop(UOp.pad, arg=px) if v == 0 else (x._apply_uop(UOp.pad, arg=px)+Tensor.ones_like(x)._apply_uop(UOp.pad, arg=px).where(0,v))
|
1079
|
+
return _constant(X, pX, value) if all(resolve(p >= 0) for p in flatten(pX)) else \
|
1080
|
+
_constant(X.shrink(tuple((-smin(pB,0),smin(pA+s,s)) for (pB,pA),s in zip(pX, X.shape))), pads, value)
|
1045
1081
|
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
|
1082
|
+
if mode == "circular":
|
1083
|
+
if any(pB>sh or pA>sh for (pB,pA),sh in zip(pX, X.shape)): raise ValueError('Padding value causes wrapping around more than once.')
|
1084
|
+
if any(pB<0 or pA<0 for pB,pA in pX): raise NotImplementedError("Negative pads with circular pads is not supported")
|
1085
|
+
orig_shape, X = X.shape, X.repeat(tuple(1 + bool(pB) + bool(pA) for pB,pA in pads))
|
1086
|
+
return X.shrink(tuple((0 if pB == 0 else osh-pB, xsh if pA == 0 else xsh-osh+pA) for (pB,pA),osh,xsh in zip(pads, orig_shape, X.shape)))
|
1046
1087
|
for d,(pB,pA) in enumerate(pads):
|
1047
|
-
if
|
1048
|
-
|
1049
|
-
|
1088
|
+
if mode == "reflect":
|
1089
|
+
if pB >= (s:=X.shape[d]) or pA>=s: raise ValueError(f"Padding ({pB}, {pA}) should be less than the input size={s} for dim={d}.")
|
1090
|
+
slcB, slcA, = slice(pB,0,-1), slice(s-2 if s-2>=0 else None, s-2-pA if s-2-pA>=0 else None, -1)
|
1091
|
+
xB, xA = (X[[slc if i == d else slice(None) for i in range(X.ndim)]] if p > 0 else None for slc, p in ((slcB, pB), (slcA, pA)))
|
1092
|
+
if mode == "replicate":
|
1093
|
+
shrB, shrA, = tuple((0,1) if i==d else None for i in range(X.ndim)), tuple((X.shape[i]-1,X.shape[i]) if i==d else None for i in range(X.ndim))
|
1094
|
+
xB, xA = (X.shrink(shr).expand(tuple(p if i==d else None for i in range(X.ndim))) if p > 0 else None for shr, p in ((shrB, pB), (shrA, pA)))
|
1050
1095
|
X = Tensor.cat(*(X_ for X_ in (xB, X, xA) if X_ is not None), dim=d)
|
1051
|
-
return X.shrink(
|
1096
|
+
return X.shrink(tuple((-min(pB,0), min(pA+s,s)) for (pB,pA),s in zip(pX, X.shape)))
|
1052
1097
|
|
1053
1098
|
# ***** movement high level ops *****
|
1054
1099
|
|
1055
|
-
# Supported Indexing Implementations:
|
1056
|
-
# 1. Int indexing (no copy)
|
1057
|
-
# - for all dims where there's int, shrink -> reshape
|
1058
|
-
# - negative indices are taken relative to the end of the sequence, so X[-2] returns the 2nd-to-last element
|
1059
|
-
# - X = Tensor.rand(4,5,9); X[2,-2] shrinks the Tensor to X.shrink(((2, 3), (3, 4), (0, 9))) -> X.shape=(1,1,9)
|
1060
|
-
# - Then we reshape (collapse) the int dim away such that for X: (1,1,9) -> (9,)
|
1061
|
-
# 2. Slice indexing (no copy)
|
1062
|
-
# - for all dims where slice is start:end:stride, shrink -> Optional[flip] -> pad -> reshape -> shrink
|
1063
|
-
# - first shrink the Tensor to X.shrink(((start, end),))
|
1064
|
-
# - then we apply stride through Optional[flip] -> pad -> reshape -> shrink
|
1065
|
-
# - flip where dim value is negative
|
1066
|
-
# - pad on dims to be multiple of strides, such that reshaping [dim_size_padded] -> [dim_size_padded // stride, stride] is possible
|
1067
|
-
# - shrink [dim_size_padded // stride, stride] -> [dim_size_padded // stride, 1]
|
1068
|
-
# - reshape [dim_size_padded // stride, 1] -> [dim_size_padded // stride] and now you have your stride
|
1069
|
-
# 3. None indexing (no copy)
|
1070
|
-
# - reshape (inject) a dim at the dim where there's None
|
1071
|
-
# 4. Tensor indexing (copy)
|
1072
|
-
# - use Tensor.arange == tensor_index to create masks for dims with Tensors (adds a dim for each mask)
|
1073
|
-
# - combine masks together with mul
|
1074
|
-
# - apply mask to self by mask * self
|
1075
|
-
# - sum reduce away the extra dims added from creating masks
|
1076
|
-
# Tiny Things:
|
1077
|
-
# 1. Supported indices: Union[int, slice, Tensor, None, List, Tuple, Ellipsis]
|
1078
|
-
# - for any list, List[Union[List, Tuple, int]], must have homogeneous shape
|
1079
|
-
# - for any tuple, Tuple[Union[List, Tuple, int]], must have homogeneous shape
|
1080
|
-
# 2. Bool indexing is not supported
|
1081
|
-
# 3. Out of bounds Tensor indexing results in 0
|
1082
|
-
# - e.g: Tensor([1, 2, 3])[Tensor([4, 3, 2])] -> [0, 0, 3] index 4 and 3 are out of bounds
|
1083
1100
|
def _getitem(self, indices, v: Optional[Tensor] = None) -> Tensor:
|
1084
|
-
#
|
1085
|
-
|
1086
|
-
|
1087
|
-
elif isinstance(indices, (tuple, list)):
|
1088
|
-
indices = [Tensor(i, self.device, requires_grad=False) if isinstance(i, (tuple, list)) else i for i in indices]
|
1089
|
-
else: indices = [indices]
|
1090
|
-
|
1091
|
-
# turn scalar Tensors into const val for int indexing if possible
|
1092
|
-
indices = [self._to_const_val(i) if isinstance(i, Tensor) and i.shape == () else i for i in indices]
|
1093
|
-
# move Tensor indices to the same device as self
|
1094
|
-
indices = [i.to(self.device) if isinstance(i, Tensor) else i for i in indices]
|
1101
|
+
# wrap single index into a list
|
1102
|
+
if (isinstance(indices, list) and all_int(indices)) or not isinstance(indices, (tuple, list)): indices = [indices]
|
1103
|
+
x, indices = self, list(indices)
|
1095
1104
|
|
1096
1105
|
# filter ellipsis and fill with slice(None) or fill rest of indices with slice(None)
|
1097
|
-
ellipsis_idx
|
1106
|
+
if len(ellipsis_idx := [dim for dim, i in enumerate(indices) if i is Ellipsis]) > 1: raise IndexError("indices can only have a single ellipsis")
|
1098
1107
|
fill_idx = ellipsis_idx[0] if ellipsis_idx else len(indices)
|
1099
1108
|
num_indices = len(indices) - len(ellipsis_idx) - sum(1 for i in indices if i is None)
|
1100
|
-
indices[fill_idx:fill_idx+1] = [slice(None)] * (self.ndim - num_indices)
|
1101
|
-
|
1102
|
-
# use Dict[type, List[dimension]] to track elements in indices
|
1103
|
-
type_dim: DefaultDict[Union[type, None], List[int]] = defaultdict(list)
|
1104
|
-
|
1105
|
-
# record None for dimension injection later and filter None and record rest of indices
|
1106
|
-
type_dim[None] = [dim for dim, i in enumerate(indices) if i is None]
|
1107
|
-
indices_filtered = [i for i in indices if i is not None]
|
1108
|
-
for dim,i in enumerate(indices_filtered): type_dim[type(i)].append(dim)
|
1109
|
-
|
1110
|
-
if len(ellipsis_idx) > 1: raise IndexError("indices can only have a single ellipsis ('...')")
|
1111
|
-
for index_type in type_dim:
|
1112
|
-
if index_type not in [None, int, slice, Tensor]: raise IndexError(f"{index_type=} not supported")
|
1113
1109
|
if num_indices > self.ndim: raise IndexError(f"too many {num_indices=} for {self.ndim=}")
|
1110
|
+
indices[fill_idx:fill_idx+1] = [slice(None)] * (self.ndim - num_indices)
|
1114
1111
|
|
1115-1166
|
-
(52 removed lines of the previous _getitem implementation; their contents are not rendered in this diff view)
1167
|
-
pre_reduce_shape = ret.shape[:first_dim] + (big_shape := _broadcast_shape(*(t.shape for t in idx.values()))) + ret.shape[first_dim:]
|
1168
|
-
|
1169
|
-
# create masks
|
1170
|
-
for dim, i in idx.items():
|
1171
|
-
try: i = i.reshape(i.shape + (1,)*(ret.ndim - first_dim)).expand(pre_reduce_shape)
|
1112
|
+
indices_parsed, dim = [], 0
|
1113
|
+
for index in indices:
|
1114
|
+
size = 1 if index is None else self.shape[dim]
|
1115
|
+
boundary, stride = [0, size], 1 # defaults
|
1116
|
+
match index:
|
1117
|
+
case list() | tuple() | Tensor():
|
1118
|
+
if not isinstance(index, Tensor): index = Tensor(index, self.device, requires_grad=False)
|
1119
|
+
if not dtypes.is_int(index.dtype): raise IndexError(f"index dtype {index.dtype} is not supported")
|
1120
|
+
index = (index.to(self.device) < 0).where(index+size, index) # treat negative index values
|
1121
|
+
case int() | UOp(): # sint
|
1122
|
+
if index >= size or index < -size: raise IndexError(f"{index=} is out of bounds with {size=}")
|
1123
|
+
boundary = [index, index+1] if index >= 0 else [index+size, index+size+1]
|
1124
|
+
case slice():
|
1125
|
+
if index.step == 0: raise ValueError(f"{index=} cannot have 0 as step")
|
1126
|
+
if not all(isinstance(s,int) or s is None for s in (index.start,index.stop,index.step)): raise TypeError("only int slicing is supported")
|
1127
|
+
# handle int slicing
|
1128
|
+
*boundary, stride = index.indices(cast(SupportsIndex, size))
|
1129
|
+
if stride * (boundary[1] - boundary[0]) < 0: boundary = [0, 0]
|
1130
|
+
elif stride < 0: boundary = [boundary[1] + 1, boundary[0] + 1]
|
1131
|
+
# update size for slice
|
1132
|
+
size = ceildiv((boundary[1] - boundary[0]), abs(stride))
|
1133
|
+
case None: pass # do nothing
|
1134
|
+
case _: raise IndexError(f"{type(index).__name__} indexing is not supported")
|
1135
|
+
indices_parsed.append({"index":index, "size":size, "boundary":tuple(boundary), "stride":stride})
|
1136
|
+
if index is not None: dim += 1
|
1137
|
+
|
1138
|
+
# movement op indexing
|
1139
|
+
if mops := [i for i in indices_parsed if i['index'] is not None]:
|
1140
|
+
# flip negative strides
|
1141
|
+
shrinks, strides = zip(*((i['boundary'], i['stride']) for i in mops))
|
1142
|
+
x = x.shrink(shrinks).flip(tuple(i for i,st in enumerate(strides) if st < 0))
|
1143
|
+
# handle stride != 1 or -1
|
1144
|
+
if any(abs(st) != 1 for st in strides):
|
1145
|
+
strides = tuple(abs(s) for s in strides)
|
1146
|
+
# pad shape to multiple of stride
|
1147
|
+
if not all_int(x.shape): raise RuntimeError("symbolic shape not supported")
|
1148
|
+
x = x.pad(tuple((0, round_up(s, st) - s) for s, st in zip(x.shape, strides)))
|
1149
|
+
x = x.reshape(tuple(flatten((s // st, st) for s, st in zip(x.shape, strides))))
|
1150
|
+
x = x.shrink(tuple(flatten(((0, s), (0, 1)) for s in x.shape[::2]))).reshape(x.shape[::2])
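The three lines above realize `abs(stride) != 1` purely with movement ops: pad each dimension up to a multiple of its stride, reshape to `(size // stride, stride)`, then keep only the first column of each stride group. A plain-Python sketch of the same trick on one dimension (`strided_take` is a hypothetical name; list concatenation stands in for `pad`):

```python
def strided_take(xs, stride):
    # pad to a multiple of stride so the reshape below is exact
    padded = xs + [0] * (-len(xs) % stride)
    # reshape [n*stride] -> [n, stride], then shrink to [n, 1] by keeping column 0
    rows = [padded[i:i + stride] for i in range(0, len(padded), stride)]
    return [row[0] for row in rows]

print(strided_take(list(range(10)), 3))  # [0, 3, 6, 9] == list(range(10))[::3]
```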
|
1151
|
+
|
1152
|
+
# dim injection from None by including None dim size (which is 1) and dim collapse by skipping int dim size
|
1153
|
+
x = x.reshape(tuple(index['size'] for index in indices_parsed if not isinstance(index['index'], int)))
|
1154
|
+
|
1155
|
+
# tensor indexing
|
1156
|
+
if tops := [(d,i) for d,i in enumerate(i_ for i_ in indices_parsed if not isinstance(i_['index'], int)) if isinstance(i['index'], Tensor)]:
|
1157
|
+
# unload the tensor object into actual tensors
|
1158
|
+
dims, tensors, masks = [d for d,_ in tops], cast(list[Tensor], [i['index'] for _,i in tops]), []
|
1159
|
+
pre_reduce_shape = x.shape[:dims[0]] + (big_shape := _broadcast_shape(*(t.shape for t in tensors))) + x.shape[dims[0]:]
|
1160
|
+
|
1161
|
+
# create index masks
|
1162
|
+
for dim, tensor in zip(dims, tensors):
|
1163
|
+
try: i = tensor.reshape(tensor.shape + (1,)*(x.ndim - dims[0])).expand(pre_reduce_shape)
|
1172
1164
|
except ValueError as e: raise IndexError(f"cannot broadcast indices: {e}") from e
|
1173
|
-
|
1174
|
-
masks.append(i == a)
|
1165
|
+
masks.append(i._one_hot_along_dim(num_classes=x.shape[dim], dim=(dim - x.ndim)))
|
1175
1166
|
|
1176
1167
|
# reduce masks to 1 mask
|
1177
1168
|
mask: Tensor = functools.reduce(lambda x,y: x.mul(y), masks)
|
1178
1169
|
|
1179
1170
|
# inject 1's for the extra dims added in create masks
|
1180
|
-
reshape_arg =
|
1171
|
+
reshape_arg = x.shape[:dims[0]] + (1,) * len(big_shape) + x.shape[dims[0]:]
|
1181
1172
|
# sum reduce the extra dims introduced in create masks
|
1182
|
-
|
1173
|
+
x = (x.reshape(reshape_arg) * mask).sum(sum_axis:=tuple(d + len(big_shape) for d in dims), acc_dtype=x.dtype)
|
1183
1174
|
|
1184
1175
|
# special permute case
|
1185
|
-
if
|
1186
|
-
|
1176
|
+
if dims[0] != 0 and len(dims) != 1 and tuple(dims) != tuple(range(dims[0], dims[-1]+1)):
|
1177
|
+
x = x.permute(*range(dims[0], dims[0]+len(big_shape)), *range(0, dims[0]), *range(dims[0]+len(big_shape), x.ndim))
|
1187
1178
|
|
1188
1179
|
# for advanced setitem, returns whole tensor with indices replaced
|
1189
1180
|
if v is not None:
|
1190
|
-
vb = v.cast(self.dtype)._broadcast_to(_broadcast_shape(
|
1181
|
+
vb = v.cast(self.dtype)._broadcast_to(_broadcast_shape(x.shape, v.shape))
|
1191
1182
|
# add back reduced dims from sum
|
1192
1183
|
for dim in sum_axis: vb = vb.unsqueeze(dim)
|
1193
|
-
# axis to be reduced to match self.shape
|
1194
|
-
|
1195
|
-
# apply mask to v(broadcasted) and reduce such that if v contains repeated indices the last one remains
|
1196
|
-
vb = vb * mask
|
1197
|
-
for dim in axis: vb = functools.reduce(lambda x,y: y.where(y, x), vb.split(1, dim))
|
1198
|
-
# reduce mask and select from v(get rid of extra dims from reduce) for each True element in mask else select from self
|
1199
|
-
ret = mask.any(axis).where(vb.squeeze(), self)
|
1184
|
+
# run _masked_setitem on tuple of axis that is to be reduced to match self.shape
|
1185
|
+
x = _masked_setitem(self, vb, mask, tuple(range(dims[0], dims[0] + len(big_shape))))
|
1200
1186
|
|
1201
|
-
return
|
1187
|
+
return x
|
1202
1188
|
|
1203
1189
|
def __getitem__(self, indices) -> Tensor:
|
1190
|
+
"""
|
1191
|
+
Retrieve a sub-tensor using indexing.
|
1192
|
+
|
1193
|
+
Supported Index Types: `int | slice | Tensor | None | list | tuple | Ellipsis`
|
1194
|
+
|
1195
|
+
Examples:
|
1196
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1197
|
+
t = Tensor.arange(12).reshape(3, 4)
|
1198
|
+
print(t.numpy())
|
1199
|
+
```
|
1200
|
+
|
1201
|
+
- Int Indexing: Select an element or sub-tensor using integers for each dimension.
|
1202
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1203
|
+
print(t[1, 2].numpy())
|
1204
|
+
```
|
1205
|
+
|
1206
|
+
- Slice Indexing: Select a range of elements using slice notation (`start:end:stride`).
|
1207
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1208
|
+
print(t[0:2, ::2].numpy())
|
1209
|
+
```
|
1210
|
+
|
1211
|
+
- Tensor Indexing: Use another tensor as indices for advanced indexing. Using `tuple` or `list` here also works.
|
1212
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1213
|
+
print(t[Tensor([2, 0, 1]), Tensor([1, 2, 3])].numpy())
|
1214
|
+
```
|
1215
|
+
|
1216
|
+
- `None` Indexing: Add a new dimension to the tensor.
|
1217
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1218
|
+
print(t[:, None].shape)
|
1219
|
+
```
|
1220
|
+
|
1221
|
+
NOTE: Out-of-bounds indexing results in a value of `0`.
|
1222
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1223
|
+
t = Tensor([1, 2, 3])
|
1224
|
+
print(t[Tensor([4, 3, 2])].numpy())
|
1225
|
+
```
|
1226
|
+
"""
|
1204
1227
|
return self._getitem(indices)
|
1205
1228
|
|
1206
1229
|
def __setitem__(self, indices, v:Union[Tensor, ConstType]) -> None:
|
@@ -1208,9 +1231,9 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
1208
1231
|
self._getitem(indices).assign(v)
|
1209
1232
|
return
|
1210
1233
|
# NOTE: check that setitem target is valid first
|
1211
|
-
if not
|
1212
|
-
if
|
1213
|
-
if not isinstance(v, Tensor):
|
1234
|
+
if not unwrap(self.lazydata.st).contiguous: raise RuntimeError("setitem target needs to be contiguous")
|
1235
|
+
if isinstance(v, get_args(ConstType)): v = Tensor(v, device=self.device, dtype=self.dtype)
|
1236
|
+
if not isinstance(v, Tensor): raise TypeError(f"can't set a {type(v).__name__} to a Tensor")
|
1214
1237
|
if self.requires_grad or v.requires_grad: raise NotImplementedError("setitem with requires_grad is not supported")
|
1215
1238
|
|
1216
1239
|
res = self.realize()._getitem(indices, v)
|
@@ -1238,7 +1261,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
1238
1261
|
assert all(s >= i for d,(s,i) in enumerate(zip(self.shape, index.shape)) if d != dim), "requires self.shape[d] >= index.shape[d] for all d != dim"
|
1239
1262
|
index = index.to(self.device)
|
1240
1263
|
x = self.shrink(tuple((0, i) if d != dim else None for d,i in enumerate(index.shape))).unsqueeze(-1).transpose(-1, dim)
|
1241
|
-
return (
|
1264
|
+
return (x * index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim])).sum(-1, acc_dtype=self.dtype)
|
1242
1265
|
|
1243
1266
|
def cat(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
|
1244
1267
|
"""
|
@@ -1302,7 +1325,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
1302
1325
|
```
|
1303
1326
|
"""
|
1304
1327
|
repeats = argfix(repeats, *args)
|
1305
|
-
base_shape =
|
1328
|
+
base_shape = _align_left(self.shape, repeats)[0]
|
1306
1329
|
unsqueezed_shape = flatten([[1, s] for s in base_shape])
|
1307
1330
|
expanded_shape = flatten([[r, s] for r,s in zip(repeats, base_shape)])
|
1308
1331
|
final_shape = [r*s for r,s in zip(repeats, base_shape)]
|
@@ -1313,7 +1336,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
1313
1336
|
if not -max(1, total) <= dim <= max(1, total)-1: raise IndexError(f"{dim=} out of range {[-max(1, total), max(1, total)-1]}")
|
1314
1337
|
return dim + total if dim < 0 else dim
|
1315
1338
|
|
1316
|
-
def split(self, sizes:Union[int,
|
1339
|
+
def split(self, sizes:Union[int, list[int]], dim:int=0) -> tuple[Tensor, ...]:
|
1317
1340
|
"""
|
1318
1341
|
Splits the tensor into chunks along the dimension specified by `dim`.
|
1319
1342
|
If `sizes` is an integer, it splits into equally sized chunks if possible, otherwise the last chunk will be smaller.
|
@@ -1338,7 +1361,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
1338
1361
|
assert sum(sizes) == self.shape[dim], f"expect sizes to sum exactly to {self.shape[dim]}, but got {sum(sizes)}"
|
1339
1362
|
return tuple(self[sl] for sl in [tuple([slice(None)]*dim + [slice(sum(sizes[:i]), sum(sizes[:i + 1]))]) for i in range(len(sizes))])
|
1340
1363
|
|
1341
|
-
def chunk(self, chunks:int, dim:int=0) ->
|
1364
|
+
def chunk(self, chunks:int, dim:int=0) -> list[Tensor]:
|
1342
1365
|
"""
|
1343
1366
|
Splits the tensor into `chunks` number of chunks along the dimension `dim`.
|
1344
1367
|
If the tensor size along `dim` is not divisible by `chunks`, all returned chunks will be the same size except the last one.
|
@@ -1362,7 +1385,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
1362
1385
|
dim = self._resolve_dim(dim)
|
1363
1386
|
return list(self.split(ceildiv(self.shape[dim], chunks) if self.shape[dim] else [0]*chunks, dim=dim))
|
1364
1387
|
|
1365
|
-
def meshgrid(self:Tensor, *args:Tensor, indexing:Union[Literal["ij"], Literal["xy"]]="ij") ->
|
1388
|
+
def meshgrid(self:Tensor, *args:Tensor, indexing:Union[Literal["ij"], Literal["xy"]]="ij") -> tuple[Tensor, ...]:
|
1366
1389
|
"""
|
1367
1390
|
Generates coordinate matrices from coordinate vectors.
|
1368
1391
|
Input tensors can be scalars or 1D tensors.
|
@@ -1462,7 +1485,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
1462
1485
|
start_dim, end_dim = self._resolve_dim(start_dim), self._resolve_dim(end_dim)
|
1463
1486
|
return self.reshape(self.shape[:start_dim] + (prod(self.shape[start_dim:end_dim+1]), ) + self.shape[end_dim+1:])
|
1464
1487
|
|
1465
|
-
def unflatten(self, dim:int, sizes:
|
1488
|
+
def unflatten(self, dim:int, sizes:tuple[int,...]):
|
1466
1489
|
"""
|
1467
1490
|
Unflattens dimension `dim` of the tensor into multiple dimensions specified by `sizes`. `Tensor.flatten()` is the inverse of this function.
|
1468
1491
|
|
@@ -1479,7 +1502,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
1479
1502
|
dim = self._resolve_dim(dim)
|
1480
1503
|
return self.reshape(self.shape[:dim] + sizes + self.shape[dim+1:])
|
1481
1504
|
|
1482
|
-
def roll(self, shifts:Union[int,
|
1505
|
+
def roll(self, shifts:Union[int, tuple[int, ...]], dims:Union[int, tuple[int, ...]]) -> Tensor:
|
1483
1506
|
"""
|
1484
1507
|
Rolls the tensor along specified dimension(s).
|
1485
1508
|
The rolling operation is circular, meaning that elements that go beyond the edge are wrapped around to the beginning of the dimension.
|
@@ -1499,12 +1522,52 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
1499
1522
|
rolled[tuple(slice(None) if i != dim else slice(None, -shift) for i in range(rolled.ndim))], dim=dim)
|
1500
1523
|
return rolled
|
1501
1524
|
|
1525
|
+
def rearrange(self, formula:str, **sizes) -> Tensor:
|
1526
|
+
"""
|
1527
|
+
Rearranges input according to formula
|
1528
|
+
|
1529
|
+
See: https://einops.rocks/api/rearrange/
|
1530
|
+
|
1531
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1532
|
+
x = Tensor([[1, 2], [3, 4]])
|
1533
|
+
print(Tensor.rearrange(x, "batch channel -> (batch channel)").numpy())
|
1534
|
+
```
|
1535
|
+
"""
|
1536
|
+
def parse_formula(formula: str):
|
1537
|
+
tokens = f" {formula} ".replace("…", "...").replace("(", " ( ").replace(")", " ) ").replace(" ", " ").replace(" 1 ", " ( ) ").split()
|
1538
|
+
lparens, rparens = map(lambda x: [i for i, ch in enumerate(tokens) if ch == x], ("(", ")"))
|
1539
|
+
pairs = list(zip(lparens, rparens))
|
1540
|
+
assert len(lparens) == len(rparens) and sorted(flatten(pairs)) == flatten(pairs), "bracket mismatch"
|
1541
|
+
return [name for name in tokens if name not in ("(", ")")], [(s - 2*i, e - 1 - 2*i) for i, (s, e) in enumerate(pairs)]
|
1542
|
+
|
1543
|
+
assert formula.count("->") == 1, 'need exactly one "->" in formula'
|
1544
|
+
|
1545
|
+
(lhs, unflatten_dims), (rhs, flatten_dims) = map(parse_formula, formula.split("->"))
|
1546
|
+
|
1547
|
+
for name in sizes: assert name in lhs, f"axis {name} is not used in transform"
|
1548
|
+
assert sorted(lhs) == sorted(rhs) and len(lhs) == len(set(lhs)), f"name mismatch in {formula}"
|
1549
|
+
for name in flatten((lhs, rhs)): assert name == "..." or (name.isidentifier() and "_" not in (name[0], name[-1])), f"invalid axis name {name}"
|
1550
|
+
assert "..." not in flatten([lhs[s:e] for s, e in unflatten_dims]), f"cannot have collapsed ellipsis (...) in lhs of {formula}"
|
1551
|
+
assert lhs.count("...") <= 1, f"too many ellipses in {formula}"
|
1552
|
+
|
1553
|
+
# resolve ellipsis
|
1554
|
+
if "..." in lhs: ell_len = len(self.shape) - len(lhs) + 1 + sum(e - s - 1 for s, e in unflatten_dims)
|
1555
|
+
lhs, rhs = map(lambda l: l[:(i:=l.index("..."))] + [f"...{j}" for j in range(ell_len)] + l[i + 1:] if "..." in l else l, (lhs, rhs))
|
1556
|
+
unflatten_dims = [(s + (ell_len - 1 if "...0" in lhs[:s] else 0), e + (ell_len - 1 if "...0" in lhs[:e] else 0)) for s, e in unflatten_dims]
|
1557
|
+
flatten_dims = [(s + (ell_len - 1 if "...0" in rhs[:s] else 0), e + (ell_len - 1 if "...0" in rhs[:e] else 0)) for s, e in flatten_dims]
|
1558
|
+
|
1559
|
+
# apply movement ops in order unflatten -> permute -> flatten/unsqueeze
|
1560
|
+
t = functools.reduce(lambda x, dims: x.unflatten(dims[0], tuple(sizes.get(lhs[d], -1) for d in range(*dims))), unflatten_dims, self)
|
1561
|
+
for i, name in enumerate(lhs): assert (name not in sizes) or sizes[name] == t.shape[i], f"size provided for dimension {name} incorrect"
|
1562
|
+
t = t.permute([lhs.index(name) for name in rhs])
|
1563
|
+
return functools.reduce(lambda x, dims: x.flatten(dims[0], dims[1] - 1) if dims[0]<dims[1] else x.unsqueeze(dims[0]), reversed(flatten_dims), t)
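`rearrange` therefore reduces to exactly three movement steps: unflatten the grouped lhs axes, permute into rhs order, then flatten (or unsqueeze) the rhs groups. A small equivalence check, assuming `tinygrad` is importable as shown; the manual chain below is just what that pipeline produces for this particular formula, not a general implementation:

```python
from tinygrad import Tensor

x = Tensor.arange(12).reshape(2, 6)            # b=2, (h w)=6
a = x.rearrange("b (h w) -> h (b w)", w=2)     # unflatten -> permute -> flatten
b = x.unflatten(1, (3, 2)).permute(1, 0, 2).flatten(1, 2)
print(a.numpy())
print(b.numpy())                               # same values, shape (3, 4)
```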
|
1564
|
+
|
1502
1565
|
# ***** reduce ops *****
|
1503
1566
|
|
1504
|
-
def _reduce(self,
|
1567
|
+
def _reduce(self, op:Ops, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False) -> Tensor:
|
1505
1568
|
axis = tuple(self._resolve_dim(x) for x in (range(self.ndim) if axis is None else make_tuple(axis, 1)))
|
1506
1569
|
if self.ndim == 0: axis = ()
|
1507
|
-
ret =
|
1570
|
+
ret = self._apply_uop(UOp.r, op=op, axis=axis)
|
1508
1571
|
return ret if keepdim else ret.reshape(tuple(s for i,s in enumerate(self.shape) if i not in axis))
|
1509
1572
|
|
1510
1573
|
def sum(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, acc_dtype:Optional[DTypeLike]=None):
|
@@ -1531,7 +1594,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
1531
1594
|
print(t.sum(axis=1).numpy())
|
1532
1595
|
```
|
1533
1596
|
"""
|
1534
|
-
ret = self.cast(sum_acc_dtype(self.dtype) if acc_dtype is None else acc_dtype)._reduce(
|
1597
|
+
ret = self.cast(sum_acc_dtype(self.dtype) if acc_dtype is None else acc_dtype)._reduce(Ops.ADD, axis, keepdim)
|
1535
1598
|
return ret.cast(self.dtype) if acc_dtype is None and self.dtype in (dtypes.float16, dtypes.bfloat16) else ret
|
1536
1599
|
|
1537
1600
|
def prod(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, acc_dtype:Optional[DTypeLike]=None):
|
@@ -1558,7 +1621,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
1558
1621
|
print(t.prod(axis=1).numpy())
|
1559
1622
|
```
|
1560
1623
|
"""
|
1561
|
-
return self.cast(acc_dtype if acc_dtype is not None else self.dtype)._reduce(
|
1624
|
+
return self.cast(acc_dtype if acc_dtype is not None else self.dtype)._reduce(Ops.MUL, axis, keepdim)
|
1562
1625
|
|
1563
1626
|
def max(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
|
1564
1627
|
"""
|
@@ -1581,7 +1644,9 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
1581
1644
|
print(t.max(axis=1, keepdim=True).numpy())
|
1582
1645
|
```
|
1583
1646
|
"""
|
1584
|
-
return self._reduce(
|
1647
|
+
return self._reduce(Ops.MAX, axis, keepdim)
|
1648
|
+
|
1649
|
+
def _inverse(self): return -self if self.is_floating_point() else ~self if dtypes.is_int(self.dtype) else self.logical_not()
|
1585
1650
|
|
1586
1651
|
def min(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
|
1587
1652
|
"""
|
@@ -1604,8 +1669,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
1604
1669
|
print(t.min(axis=1, keepdim=True).numpy())
|
1605
1670
|
```
|
1606
1671
|
"""
|
1607
|
-
|
1608
|
-
return -((-self).max(axis=axis, keepdim=keepdim))
|
1672
|
+
return self._inverse().max(axis=axis, keepdim=keepdim)._inverse()
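`min` is routed through `max` via `_inverse`: negate floats, bitwise-invert ints, logical-not bools, reduce, then invert again. A plain-Python illustration of the integer case, where `~x = -x - 1` is strictly order-reversing, so the maximum of the inverted values recovers the minimum:

```python
xs = [3, -7, 12, 0]
# ~x reverses the ordering of ints, so max over ~x picks out min over x
assert ~max(~x for x in xs) == min(xs)
print(~max(~x for x in xs))  # -7
```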
|
1609
1673
|
|
1610
1674
|
def any(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
|
1611
1675
|
"""
|
@@ -1651,6 +1715,28 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
1651
1715
|
"""
|
1652
1716
|
return self.logical_not().any(axis, keepdim).logical_not()
|
1653
1717
|
|
1718
|
+
def isclose(self, other:Tensor, rtol:float=1e-05, atol:float=1e-08, equal_nan=False) -> Tensor:
|
1719
|
+
"""
|
1720
|
+
Returns a new tensor with element-wise comparison of closeness to `other` within a tolerance.
|
1721
|
+
|
1722
|
+
The `rtol` and `atol` keyword arguments control the relative and absolute tolerance of the comparison.
|
1723
|
+
|
1724
|
+
By default, two `NaN` values are not close to each other. If `equal_nan` is `True`, two `NaN` values are considered close.
|
1725
|
+
|
1726
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1727
|
+
print(Tensor([1e-7, 1e-8, 1e-9, float('nan')]).isclose(Tensor([0.0, 0.0, 0.0, float('nan')])).numpy())
|
1728
|
+
```
|
1729
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1730
|
+
print(Tensor([float('nan')]).isclose(Tensor([float('nan')]), equal_nan=True).numpy())
|
1731
|
+
```
|
1732
|
+
"""
|
1733
|
+
# TODO: Tensor.isfinite
|
1734
|
+
def isfinite(t): return (t.isinf()|t.isnan()).logical_not()
|
1735
|
+
is_finite_close = isfinite(self) & isfinite(other) & ((self - other).abs() <= atol + rtol * other.abs())
|
1736
|
+
is_infinite_close = (self.isinf() | other.isinf()) & (self == other)
|
1737
|
+
is_nan_close = (self.isnan() & other.isnan()) & equal_nan
|
1738
|
+
return is_finite_close | is_infinite_close | is_nan_close
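The returned mask is the union of three cases: both values finite and within `atol + rtol * |other|`, both infinite and equal, or both NaN when `equal_nan` is set. A scalar sketch of the same predicate in plain Python (`isclose_scalar` is a hypothetical helper for illustration only):

```python
import math

def isclose_scalar(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
    finite_close = math.isfinite(a) and math.isfinite(b) and abs(a - b) <= atol + rtol * abs(b)
    infinite_close = (math.isinf(a) or math.isinf(b)) and a == b
    nan_close = math.isnan(a) and math.isnan(b) and equal_nan
    return finite_close or infinite_close or nan_close

print([isclose_scalar(x, 0.0) for x in (1e-7, 1e-8, 1e-9)])            # [False, True, True]
print(isclose_scalar(float("nan"), float("nan"), equal_nan=True))       # True
```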
|
1739
|
+
|
1654
1740
|
def mean(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
|
1655
1741
|
"""
|
1656
1742
|
Returns the mean value of the tensor along the specified axis or axes.
|
@@ -1745,8 +1831,8 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
1745
1831
|
return self.std(axis, keepdim, correction), self.mean(axis, keepdim)
|
1746
1832
|
|
1747
1833
|
def _softmax(self, axis, dtype:Optional[DTypeLike]=None):
|
1748
|
-
|
1749
|
-
|
1834
|
+
m = self - self.max(axis=axis, keepdim=True).detach()
|
1835
|
+
if dtype is not None: m = m.cast(dtype)
|
1750
1836
|
e = m.exp()
|
1751
1837
|
return m, e, e.sum(axis=axis, keepdim=True)
|
1752
1838
|
|
@@ -1847,8 +1933,16 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
1847
1933
|
print(t.logcumsumexp(axis=1).numpy())
|
1848
1934
|
```
|
1849
1935
|
"""
|
1850
|
-
|
1851
|
-
|
1936
|
+
if self.ndim == 0: return self
|
1937
|
+
axis = self._resolve_dim(axis)
|
1938
|
+
x = self.transpose(axis, -1)
|
1939
|
+
last_dim_size = x.shape[-1]
|
1940
|
+
x_reshaped = x.reshape(-1, last_dim_size)
|
1941
|
+
x_cummax = x_reshaped.cummax(-1).unsqueeze(-1)
|
1942
|
+
x_expand = x_reshaped.unsqueeze(1).expand(*x_reshaped.shape, last_dim_size)
|
1943
|
+
mask = Tensor.ones(last_dim_size, last_dim_size, requires_grad=False, device=self.device).tril().unsqueeze(0)
|
1944
|
+
ret = ((x_expand - x_cummax).exp() * mask).sum(-1).log() + x_cummax.squeeze(-1)
|
1945
|
+
return ret.reshape(*x.shape).transpose(-1, axis)
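The reduction above subtracts the running maximum (`cummax`) before exponentiating, so the summed exponentials cannot overflow, and adds it back after the log. A 1-D plain-Python sketch of that numerically stable recurrence (`logcumsumexp_1d` is a hypothetical helper):

```python
import math

def logcumsumexp_1d(xs):
    out, running_max, acc = [], -math.inf, 0.0
    for x in xs:
        new_max = max(running_max, x)
        # rescale the accumulated sum of exponentials whenever the running max changes
        acc = acc * math.exp(running_max - new_max) + math.exp(x - new_max)
        running_max = new_max
        out.append(math.log(acc) + running_max)
    return out

print(logcumsumexp_1d([1000.0, 1000.0, 1000.0]))  # ~[1000.0, 1000.69, 1001.10], no overflow
```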
|
1852
1946
|
|
1853
1947
|
def argmax(self, axis=None, keepdim=False):
|
1854
1948
|
"""
|
@@ -1898,47 +1992,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
1898
1992
|
print(t.argmin(axis=1).numpy()) # Returns the indices of the minimum values along axis 1.
|
1899
1993
|
```
|
1900
1994
|
"""
|
1901
|
-
return (
|
1902
|
-
|
1903
|
-
def rearrange(self, formula: str, **sizes) -> Tensor:
|
1904
|
-
"""
|
1905
|
-
Rearranges input according to formula
|
1906
|
-
|
1907
|
-
See: https://einops.rocks/api/rearrange/
|
1908
|
-
|
1909
|
-
```python exec="true" source="above" session="tensor" result="python"
|
1910
|
-
x = Tensor([[1, 2], [3, 4]])
|
1911
|
-
print(Tensor.rearrange(x, "batch channel -> (batch channel)).numpy())
|
1912
|
-
```
|
1913
|
-
"""
|
1914
|
-
def parse_formula(formula: str):
|
1915
|
-
tokens = f" {formula} ".replace("…", "...").replace("(", " ( ").replace(")", " ) ").replace(" ", " ").replace(" 1 ", " ( ) ").split()
|
1916
|
-
lparens, rparens = map(lambda x: [i for i, ch in enumerate(tokens) if ch == x], ("(", ")"))
|
1917
|
-
pairs = list(zip(lparens, rparens))
|
1918
|
-
assert len(lparens) == len(rparens) and sorted(flatten(pairs)) == flatten(pairs), "bracket mismatch"
|
1919
|
-
return [name for name in tokens if name not in ("(", ")")], [(s - 2*i, e - 1 - 2*i) for i, (s, e) in enumerate(pairs)]
|
1920
|
-
|
1921
|
-
assert formula.count("->") == 1, 'need exactly one "->" in formula'
|
1922
|
-
|
1923
|
-
(lhs, unflatten_dims), (rhs, flatten_dims) = map(parse_formula, formula.split("->"))
|
1924
|
-
|
1925
|
-
for name in sizes: assert name in lhs, f"axis {name} is not used in transform"
|
1926
|
-
assert sorted(lhs) == sorted(rhs) and len(lhs) == len(set(lhs)), f"name mismatch in {formula}"
|
1927
|
-
for name in flatten((lhs, rhs)): assert name == "..." or (name.isidentifier() and "_" not in (name[0], name[-1])), f"invalid axis name {name}"
|
1928
|
-
assert "..." not in flatten([lhs[s:e] for s, e in unflatten_dims]), f"cannot have collapsed ellipsis (...) in lhs of {formula}"
|
1929
|
-
assert lhs.count("...") <= 1, f"too many ellipses in {formula}"
|
1930
|
-
|
1931
|
-
# resolve ellipsis
|
1932
|
-
if "..." in lhs: ell_len = len(self.shape) - len(lhs) + 1 + sum(e - s - 1 for s, e in unflatten_dims)
|
1933
|
-
lhs, rhs = map(lambda l: l[:(i:=l.index("..."))] + [f"...{j}" for j in range(ell_len)] + l[i + 1:] if "..." in l else l, (lhs, rhs))
|
1934
|
-
unflatten_dims = [(s + (ell_len - 1 if "...0" in lhs[:s] else 0), e + (ell_len - 1 if "...0" in lhs[:e] else 0)) for s, e in unflatten_dims]
|
1935
|
-
flatten_dims = [(s + (ell_len - 1 if "...0" in rhs[:s] else 0), e + (ell_len - 1 if "...0" in rhs[:e] else 0)) for s, e in flatten_dims]
|
1936
|
-
|
1937
|
-
# apply movement ops in order unflatten -> permute -> flatten/unsqueeze
|
1938
|
-
t = functools.reduce(lambda x, dims: x.unflatten(dims[0], tuple(sizes.get(lhs[d], -1) for d in range(*dims))), unflatten_dims, self)
|
1939
|
-
for i, name in enumerate(lhs): assert (name not in sizes) or sizes[name] == t.shape[i], f"size provided for dimension {name} incorrect"
|
1940
|
-
t = t.permute([lhs.index(name) for name in rhs])
|
1941
|
-
return functools.reduce(lambda x, dims: x.flatten(dims[0], dims[1] - 1) if dims[0]<dims[1] else x.unsqueeze(dims[0]), reversed(flatten_dims), t)
|
1995
|
+
return self._inverse().argmax(axis=axis, keepdim=keepdim)
|
1942
1996
|
|
1943
1997
|
@staticmethod
|
1944
1998
|
def einsum(formula:str, *operands:Tensor|Sequence[Tensor], acc_dtype:Optional[DTypeLike]=None) -> Tensor:
|
@@ -1964,7 +2018,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
1964
2018
|
(inputs_str, out_ellipse + ''.join(sorted(c for c in inputs_str if inputs_str.count(c) == 1 and c.isalpha() and c not in out_ellipse)))
|
1965
2019
|
return formula.split("->") if "->" in formula else (formula, ''.join(c for c in sorted(formula) if formula.count(c) == 1 and c.isalpha()))
|
1966
2020
|
|
1967
|
-
xs:
|
2021
|
+
xs:tuple[Tensor, ...] = argfix(*operands)
|
1968
2022
|
inputs_str, output = parse_formula(formula, *xs)
|
1969
2023
|
inputs = inputs_str.split(",")
|
1970
2024
|
assert len(xs) == len(inputs), f"number of inputs doesn't match number of operands in formula, expected {len(inputs)}, got {len(xs)}"
|
@@ -1972,7 +2026,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
1972
2026
|
# map the value of each letter in the formula
|
1973
2027
|
letter_val = sorted(merge_dicts([dict(zip(letters, tensor.shape)) for letters, tensor in zip(inputs, xs)]).items())
|
1974
2028
|
|
1975
|
-
xs_:
|
2029
|
+
xs_:list[Tensor] = []
|
1976
2030
|
lhs = [sorted(enumerate(s), key=lambda e:e[1]) for s in inputs]
|
1977
2031
|
for x,(order,letters) in zip(xs, [list(zip(*l)) for l in lhs]):
|
1978
2032
|
# permute to the sorted letter order, then reshape/expand to create dimensions for the missing letters
|
@@ -1987,7 +2041,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
1987
2041
|
|
1988
2042
|
# ***** processing ops *****
|
1989
2043
|
|
1990
|
-
def _pool(self, k_:
|
2044
|
+
def _pool(self, k_:tuple[sint, ...], stride:Union[tuple[int, ...], int]=1, dilation:Union[tuple[int, ...], int]=1) -> Tensor:
|
1991
2045
|
assert len(self.shape) >= len(k_), f"can't pool {self.shape} with {k_}"
|
1992
2046
|
s_, d_ = make_tuple(stride, len(k_)), make_tuple(dilation, len(k_))
|
1993
2047
|
assert len(k_) == len(s_) == len(d_), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}"
|
@@ -1995,10 +2049,12 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
1995
2049
|
assert all(resolve(d*(k-1)+1 <= i) for k,d,i in zip(k_,d_,i_)), "kernel size cannot be greater than actual input size"
|
1996
2050
|
o_ = [ceildiv(i-d*(k-1), s) for i,d,k,s in zip(i_,d_,k_,s_)]
|
1997
2051
|
if any(resolve(k > s) for k,s in zip(k_,s_)) or any(d != 1 for d in d_):
|
1998
|
-
#
|
1999
|
-
|
2052
|
+
# input size scaling factor to make sure shrink for stride is possible
|
2053
|
+
f_ = [1 + int(resolve(o*s > (i - d*(k-1)))) for o,s,i,d,k in zip(o_,s_,i_,d_,k_)]
|
2054
|
+
# repeat such that we don't need padding
|
2055
|
+
x = self.repeat([1]*len(noop) + [ceildiv(k*(i*f+d),i) for k,i,d,f in zip(k_,i_,d_,f_)])
|
2000
2056
|
# handle dilation
|
2001
|
-
x = x.shrink(tuple(noop + [(0,k*(i+d)) for k,i,d in zip(k_,i_,d_)])).reshape(noop + flatten((k,i+d) for k,i,d in zip(k_,i_,d_)))
|
2057
|
+
x = x.shrink(tuple(noop + [(0,k*(i*f+d)) for k,i,d,f in zip(k_,i_,d_,f_)])).reshape(noop + flatten((k,(i*f+d)) for k,i,d,f in zip(k_,i_,d_,f_)))
|
2002
2058
|
# handle stride
|
2003
2059
|
x = x.shrink(tuple(noop + flatten(((0,k), (0,o*s)) for k,o,s in zip(k_,o_,s_)))).reshape(noop + flatten((k,o,s) for k,o,s in zip(k_,o_,s_)))
|
2004
2060
|
x = x.shrink(tuple(noop + flatten(((0,k), (0,o), (0,1)) for k,o in zip(k_,o_)))).reshape(noop + flatten((k,o) for k,o in zip(k_,o_)))
|
@@ -2010,14 +2066,44 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
2010
2066
|
x = x.shrink(tuple(noop + flatten(((0,o), (0,k)) for o,k in zip(o_,k_))))
|
2011
2067
|
return x.permute(*range(len(noop)), *[len(noop)+i*2 for i in range(len(i_))], *[len(noop)+i*2+1 for i in range(len(i_))])
|
2012
2068
|
|
2013
|
-
def
|
2069
|
+
def _resolve_pool_pads(self, padding:Union[int, Sequence[int]], dims:int) -> Sequence[int]:
|
2070
|
+
if not isinstance(padding, int) and not (len(padding) == 2*dims or len(padding) == dims):
|
2071
|
+
raise ValueError(f"Padding must be an int or a sequence of length {dims} or {2*dims}, but got {padding=} for {self.shape=} with {dims=}.")
|
2014
2072
|
return [padding]*2*dims if isinstance(padding, int) else (padding if len(padding) == 2*dims else [p for p in padding for _ in range(2)][::-1])
|
2015
2073
|
|
2074
|
+
def _apply_ceil_mode(self, pads:Sequence[int], k_:tuple[sint, ...], s_:Union[tuple[int, ...], int], d_:Union[tuple[int, ...], int]) -> list[int]:
|
2075
|
+
(d_,s_), i_ = (make_tuple(x, len(k_)) for x in (d_,s_)), self.shape[-len(k_):]
|
2076
|
+
pads, grouped_pads = list(pads), _flat_to_grouped(pads)
|
2077
|
+
# https://arxiv.org/pdf/1603.07285 section 5.1, relationship 15.
|
2078
|
+
o_ = [ceildiv(i+pB+pA - (d*(k-1)+1), s) + 1 for i,d,k,s,(pB,pA) in zip(i_,d_,k_,s_,grouped_pads)]
|
2079
|
+
for dim,(o,i,s,k,d,(pB,pA)) in enumerate(zip(o_,i_,s_,k_,d_,grouped_pads)):
|
2080
|
+
# we have to do additional padding before `_pool` so that `o_` in `_pool` is calculated correctly
|
2081
|
+
# `s*(o-1) + (d*(k-1)+1) - (i+pB+pA)` -> last_sliding_window_start + full_kernel_size - padded_input_shape
|
2082
|
+
# we decrease padding in the case that a sliding window starts in the end padded region, thereby decreasing `o_` in `_pool`
|
2083
|
+
# `smax(s*(o-1) - (pB+i-1), 0)` -> last_sliding_window_start - (pad_before + input_size - zero_offset)
|
2084
|
+
pads[-1-dim*2] += s*(o-1) + (d*(k-1)+1) - (i+pB+pA) - smax(s*(o-1) - (pB+i-1), 0)
|
2085
|
+
return pads
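The padding adjustment exists so that `_pool` ends up computing the ceil-division output length from the relationship cited above. A plain-Python check of that output-size formula for a 1-D case (`pool_out` is a hypothetical helper):

```python
import math

def pool_out(i, k, s=1, d=1, pB=0, pA=0, ceil_mode=False):
    # relationship 15 from arxiv 1603.07285: o = floor_or_ceil((i + pB + pA - (d*(k-1)+1)) / s) + 1
    div = (i + pB + pA - (d * (k - 1) + 1)) / s
    return (math.ceil(div) if ceil_mode else math.floor(div)) + 1

# 7 elements, kernel 2, stride 2: floor gives 3 windows, ceil gives 4
print(pool_out(7, 2, s=2), pool_out(7, 2, s=2, ceil_mode=True))  # 3 4
```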
|
2086
|
+
|
2016
2087
|
# NOTE: these work for more than 2D
|
2017
|
-
def avg_pool2d(self, kernel_size=(2,2), stride=None, dilation=1, padding=0, count_include_pad=True):
|
2088
|
+
def avg_pool2d(self, kernel_size=(2,2), stride=None, dilation=1, padding=0, ceil_mode=False, count_include_pad=True):
|
2018
2089
|
"""
|
2019
2090
|
Applies average pooling over a tensor.
|
2020
2091
|
|
2092
|
+
This function supports three different types of `padding`
|
2093
|
+
|
2094
|
+
1. `int` (single value):
|
2095
|
+
Applies the same padding value uniformly to all spatial dimensions.
|
2096
|
+
|
2097
|
+
2. `tuple[int, ...]` (length = number of spatial dimensions):
|
2098
|
+
Specifies a distinct padding value for each spatial dimension in the form `(padding_height, padding_width, ...)`.
|
2099
|
+
|
2100
|
+
3. `tuple[int, ...]` (length = 2 * number of spatial dimensions):
|
2101
|
+
Specifies explicit padding for each side of each spatial dimension in the form
|
2102
|
+
`(padding_left, padding_right, padding_top, padding_bottom, ...)`.
|
2103
|
+
|
2104
|
+
When `ceil_mode` is set to `True`, output shape will be determined using ceil division.
|
2105
|
+
When `count_include_pad` is set to `False`, zero padding will not be included in the averaging calculation.
|
2106
|
+
|
2021
2107
|
NOTE: unlike PyTorch, this implementation is not limited to only 2d pooling and instead works for any number of dimensions.
|
2022
2108
|
|
2023
2109
|
See: https://paperswithcode.com/method/average-pooling
|
@@ -2027,17 +2113,43 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
2027
2113
|
print(t.avg_pool2d().numpy())
|
2028
2114
|
```
|
2029
2115
|
```python exec="true" source="above" session="tensor" result="python"
|
2116
|
+
print(t.avg_pool2d(ceil_mode=True).numpy())
|
2117
|
+
```
|
2118
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2030
2119
|
print(t.avg_pool2d(padding=1).numpy())
|
2031
2120
|
```
|
2121
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2122
|
+
print(t.avg_pool2d(padding=1, count_include_pad=False).numpy())
|
2123
|
+
```
|
2032
2124
|
"""
|
2033
|
-
|
2034
|
-
def pool(x:Tensor) -> Tensor: return x.pad(padding_)._pool(k_, stride if stride is not None else k_, dilation)
|
2035
|
-
|
2125
|
+
axis = tuple(range(-len(k_ := make_tuple(kernel_size, 2)), 0))
|
2126
|
+
def pool(x:Tensor, padding_:Sequence[int]) -> Tensor: return x.pad(padding_)._pool(k_, stride if stride is not None else k_, dilation)
|
2127
|
+
reg_pads = self._resolve_pool_pads(padding, len(k_))
|
2128
|
+
ceil_pads = self._apply_ceil_mode(reg_pads, k_, stride if stride is not None else k_, dilation)
|
2129
|
+
if not count_include_pad:
|
2130
|
+
pads = ceil_pads if ceil_mode else reg_pads
|
2131
|
+
return pool(self, pads).sum(axis) / pool(self.ones_like(), pads).sum(axis)
|
2132
|
+
if not ceil_mode: return pool(self, reg_pads).mean(axis)
|
2133
|
+
return pool(self, ceil_pads).sum(axis) / pool(self.pad(reg_pads).ones_like(), tuple(cp-rp for cp,rp in zip(ceil_pads, reg_pads))).sum(axis)
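With `count_include_pad=False`, the pooled window sum is divided by a pooled count over a ones tensor instead of by the full window size, so zeros introduced by padding do not dilute the average. A tiny 1-D illustration of the divisor difference in plain Python:

```python
vals = [2.0, 4.0]          # input; the second pooling window spills one cell into padding
padded = vals + [0.0]      # value padding
counts = [1.0, 1.0, 0.0]   # ones for real cells, zeros for padded cells

window = padded[1:3]                           # the window that overlaps the padding
include_pad = sum(window) / len(window)        # 2.0 : padding counted in the divisor
exclude_pad = sum(window) / sum(counts[1:3])   # 4.0 : divisor counts only real cells
print(include_pad, exclude_pad)
```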
|
2036
2134
|
|
2037
|
-
def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1, padding=0):
|
2135
|
+
def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1, padding=0, ceil_mode=False):
|
2038
2136
|
"""
|
2039
2137
|
Applies max pooling over a tensor.
|
2040
2138
|
|
2139
|
+
This function supports three different types of `padding`
|
2140
|
+
|
2141
|
+
1. `int` (single value):
|
2142
|
+
Applies the same padding value uniformly to all spatial dimensions.
|
2143
|
+
|
2144
|
+
2. `tuple[int, ...]` (length = number of spatial dimensions):
|
2145
|
+
Specifies a distinct padding value for each spatial dimension in the form `(padding_height, padding_width, ...)`.
|
2146
|
+
|
2147
|
+
3. `tuple[int, ...]` (length = 2 * number of spatial dimensions):
|
2148
|
+
Specifies explicit padding for each side of each spatial dimension in the form
|
2149
|
+
`(padding_left, padding_right, padding_top, padding_bottom, ...)`.
|
2150
|
+
|
2151
|
+
When `ceil_mode` is set to `True`, output shape will be determined using ceil division.
|
2152
|
+
|
2041
2153
|
NOTE: unlike PyTorch, this implementation is not limited to only 2d pooling and instead works for any number of dimensions.
|
2042
2154
|
|
2043
2155
|
See: https://paperswithcode.com/method/max-pooling
|
@@ -2047,17 +2159,33 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
2047
2159
|
print(t.max_pool2d().numpy())
|
2048
2160
|
```
|
2049
2161
|
```python exec="true" source="above" session="tensor" result="python"
|
2162
|
+
print(t.max_pool2d(ceil_mode=True).numpy())
|
2163
|
+
```
|
2164
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2050
2165
|
print(t.max_pool2d(padding=1).numpy())
|
2051
2166
|
```
|
2052
2167
|
"""
|
2053
|
-
|
2054
|
-
|
2168
|
+
pads = self._resolve_pool_pads(padding, len(k_ := make_tuple(kernel_size, 2)))
|
2169
|
+
if ceil_mode: pads = self._apply_ceil_mode(pads, k_, stride if stride is not None else k_, dilation)
|
2170
|
+
return self.pad(pads, value=dtypes.min(self.dtype))._pool(k_, stride if stride is not None else k_, dilation).max(tuple(range(-len(k_), 0)))
|
2055
2171
|
|
2056
|
-
def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding:int|
|
2172
|
+
def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding:int|tuple[int, ...]=0,
|
2057
2173
|
acc_dtype:Optional[DTypeLike]=None) -> Tensor:
|
2058
2174
|
"""
|
2059
2175
|
Applies a convolution over a tensor with a given `weight` and optional `bias`.
|
2060
2176
|
|
2177
|
+
This function supports three different types of `padding`
|
2178
|
+
|
2179
|
+
1. `int` (single value):
|
2180
|
+
Applies the same padding value uniformly to all spatial dimensions.
|
2181
|
+
|
2182
|
+
2. `tuple[int, ...]` (length = number of spatial dimensions):
|
2183
|
+
Specifies a distinct padding value for each spatial dimension in the form `(padding_height, padding_width, ...)`.
|
2184
|
+
|
2185
|
+
3. `tuple[int, ...]` (length = 2 * number of spatial dimensions):
|
2186
|
+
Specifies explicit padding for each side of each spatial dimension in the form
|
2187
|
+
`(padding_left, padding_right, padding_top, padding_bottom, ...)`.
|
2188
|
+
|
2061
2189
|
NOTE: unlike PyTorch, this implementation is not limited to only 2d convolutions and instead works for any number of dimensions.
|
2062
2190
|
|
2063
2191
|
See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
|
@@ -2070,9 +2198,8 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
2070
2198
|
"""
|
2071
2199
|
if IMAGE: return self.image_conv2d(weight, bias, groups, stride, dilation, padding, acc_dtype)
|
2072
2200
|
(bs,cin_), (cout,cin), HW = self.shape[:2], weight.shape[:2], weight.shape[2:]
|
2201
|
+
padding_ = self._resolve_pool_pads(padding, len(HW))
|
2073
2202
|
assert groups*cin == cin_ and len(self.shape) == len(weight.shape), f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})" # noqa: E501
|
2074
|
-
if isinstance(padding, (tuple,list)): assert len(padding) == 2*len(HW) or len(padding) == len(HW), f"Expected padding of length {2*len(HW)} or {len(HW)}, but got {len(padding)} for tensor of shape {self.shape}" # noqa: E501
|
2075
|
-
padding_ = self._padding2d(padding, len(HW))
|
2076
2203
|
|
2077
2204
|
# conv2d is a pooling op (with padding)
|
2078
2205
|
x = self.pad(padding_)._pool(HW, stride, dilation) # (bs, groups*cin, oy, ox, H, W)
|
@@ -2120,6 +2247,18 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
2120
2247
|
"""
|
2121
2248
|
Applies a transposed convolution over a tensor with a given `weight` and optional `bias`.
|
2122
2249
|
|
2250
|
+
This function supports three different types of `padding`
|
2251
|
+
|
2252
|
+
1. `int` (single value):
|
2253
|
+
Applies the same padding value uniformly to all spatial dimensions.
|
2254
|
+
|
2255
|
+
2. `tuple[int, ...]` (length = number of spatial dimensions):
|
2256
|
+
Specifies a distinct padding value for each spatial dimension in the form `(padding_height, padding_width, ...)`.
|
2257
|
+
|
2258
|
+
3. `tuple[int, ...]` (length = 2 * number of spatial dimensions):
|
2259
|
+
Specifies explicit padding for each side of each spatial dimension in the form
|
2260
|
+
`(padding_left, padding_right, padding_top, padding_bottom, ...)`.
|
2261
|
+
|
2123
2262
|
NOTE: unlike PyTorch, this implementation is not limited to only 2d transposed convolutions and instead works for any number of dimensions.
|
2124
2263
|
|
2125
2264
|
See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
|
@@ -2132,14 +2271,15 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
2132
2271
|
"""
|
2133
2272
|
x, w = self, weight.unflatten(0, (groups, -1)).transpose(1, 2).flip(*range(3, len(weight.shape)+1))
|
2134
2273
|
HW = weight.shape[2:]
|
2135
|
-
|
2274
|
+
padding = _flat_to_grouped(self._resolve_pool_pads(padding, len(HW)))
|
2275
|
+
stride, dilation, output_padding = [make_tuple(x, len(HW)) for x in (stride, dilation, output_padding)]
|
2136
2276
|
if any(s>1 for s in stride):
|
2137
2277
|
# handle strides: (k) -> reshape -> (k,1) -> pad -> (k,s) -> reshape -> (k*s) -> shrink (k-(s-1))
|
2138
2278
|
x = x.reshape(None, None, *flatten((k,1) for k in x.shape[2:]))
|
2139
2279
|
x = x.pad((None, None, *flatten((None,(0,s-1)) for s in stride)))
|
2140
2280
|
x = x.reshape(None, None, *[k*s for k,s in zip(x.shape[2::2], stride)])
|
2141
2281
|
x = x.shrink((None, None, *[(0,k-(s-1)) for k,s in zip(x.shape[2:], stride)]))
|
2142
|
-
padding = flatten((((k-1)*d-
|
2282
|
+
padding = flatten((((k-1)*d-pB,(k-1)*d-pA+op) for k,d,(pB,pA),op in reversed(list(zip(HW, dilation, padding, output_padding)))))
|
2143
2283
|
return x.conv2d(w.flatten(end_dim=1), groups=groups, bias=bias, dilation=dilation, padding=padding)
|
2144
2284
|
|
2145
2285
|
def dot(self, w:Tensor, acc_dtype:Optional[DTypeLike]=None) -> Tensor:
|
@@ -2185,15 +2325,28 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
2185
2325
|
"""
|
2186
2326
|
return x.dot(self, acc_dtype=acc_dtype) if reverse else self.dot(x, acc_dtype=acc_dtype)
|
2187
2327
|
|
2188
|
-
def
|
2189
|
-
assert self.shape[axis] != 0
|
2190
|
-
pl_sz = self.shape[axis] - int(not
|
2191
|
-
|
2328
|
+
def _cumalu(self, axis:int, op:Ops, _include_initial=False) -> Tensor:
|
2329
|
+
assert self.shape[axis] != 0 and op in (Ops.ADD, Ops.MAX)
|
2330
|
+
pl_sz = self.shape[axis] - int(not _include_initial)
|
2331
|
+
pooled = self.transpose(axis,-1).pad((pl_sz, -int(_include_initial)), value=identity_element(op, self.dtype))._pool((self.shape[axis],))
|
2332
|
+
return (pooled.sum(-1) if op is Ops.ADD else pooled.max(-1)).transpose(axis,-1)
|
2333
|
+
|
2334
|
+
def _split_cumalu(self, axis:int, op:Ops) -> Tensor:
|
2335
|
+
axis = self._resolve_dim(axis)
|
2336
|
+
if self.ndim == 0 or 0 in self.shape: return self
|
2337
|
+
# TODO: someday the optimizer will find this on its own
|
2338
|
+
# for now this is a two stage cumsum
|
2339
|
+
SPLIT = 256
|
2340
|
+
if not isinstance(s:=self.shape[axis], int) or s <= SPLIT*2: return self._cumalu(axis, op)
|
2341
|
+
ret = self.transpose(axis,-1).pad((round_up(s, SPLIT)-s, 0), value=identity_element(op, self.dtype)).unflatten(-1, (-1, SPLIT))._cumalu(-1, op)
|
2342
|
+
base = ret[..., -1]._cumalu(-1, op, _include_initial=True)
|
2343
|
+
base = base.unsqueeze(-1).expand(*base.shape, ret.shape[-1])
|
2344
|
+
def fix(x:Tensor): return x.flatten(start_dim=-2)[..., -s:].transpose(axis,-1)
|
2345
|
+
return fix(ret) + fix(base) if op is Ops.ADD else fix(ret).maximum(fix(base))
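For long axes the cumulative op runs in two stages: a cumulative op inside fixed-size blocks, then a cumulative op over the per-block totals that is broadcast back onto every block. A plain-Python sketch of the `Ops.ADD` case with a small block size standing in for `SPLIT` (`split_cumsum` is a hypothetical helper):

```python
from itertools import accumulate

def split_cumsum(xs, split=4):
    # pad on the left so the length is a multiple of the block size
    pad = -len(xs) % split
    padded = [0] * pad + list(xs)
    blocks = [padded[i:i + split] for i in range(0, len(padded), split)]
    local = [list(accumulate(b)) for b in blocks]               # stage 1: cumsum within each block
    base = list(accumulate([0] + [l[-1] for l in local[:-1]]))  # stage 2: running total of earlier blocks
    out = [v + base[i] for i, l in enumerate(local) for v in l]
    return out[pad:]                                            # drop the left padding again

xs = list(range(1, 11))
print(split_cumsum(xs))      # [1, 3, 6, 10, 15, 21, 28, 36, 45, 55]
print(list(accumulate(xs)))  # same
```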
|
2346
|
+
|
2192
2347
|
def cumsum(self, axis:int=0) -> Tensor:
|
2193
2348
|
"""
|
2194
|
-
Computes the cumulative sum of the tensor along the specified axis
|
2195
|
-
|
2196
|
-
You can pass in the `axis` keyword argument to control the axis along which the cumulative sum is computed.
|
2349
|
+
Computes the cumulative sum of the tensor along the specified `axis`.
|
2197
2350
|
|
2198
2351
|
```python exec="true" source="above" session="tensor" result="python"
|
2199
2352
|
t = Tensor.ones(2, 3)
|
@@ -2203,17 +2356,21 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
2203
2356
|
print(t.cumsum(1).numpy())
|
2204
2357
|
```
|
2205
2358
|
"""
|
2206
|
-
|
2207
|
-
|
2208
|
-
|
2209
|
-
|
2210
|
-
|
2211
|
-
|
2212
|
-
|
2213
|
-
|
2214
|
-
|
2215
|
-
|
2216
|
-
|
2359
|
+
return self._split_cumalu(axis, Ops.ADD)
|
2360
|
+
|
2361
|
+
def cummax(self, axis:int=0) -> Tensor:
|
2362
|
+
"""
|
2363
|
+
Computes the cumulative max of the tensor along the specified `axis`.
|
2364
|
+
|
2365
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2366
|
+
t = Tensor([0, 1, -1, 2, -2, 3, -3])
|
2367
|
+
print(t.numpy())
|
2368
|
+
```
|
2369
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2370
|
+
print(t.cummax(0).numpy())
|
2371
|
+
```
|
2372
|
+
"""
|
2373
|
+
return self._split_cumalu(axis, Ops.MAX)
|
2217
2374
|
|
2218
2375
|
@staticmethod
|
2219
2376
|
def _tri(r:sint, c:sint, diagonal:int=0, **kwargs) -> Tensor:
|
@@ -2271,7 +2428,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
2271
2428
|
"""
|
2272
2429
|
return Tensor._tri(self.shape[-2], self.shape[-1], diagonal=diagonal+1, device=self.device, dtype=dtypes.bool).where(0, self).cast(self.dtype)
|
2273
2430
|
|
2274
|
-
def interpolate(self, size:
|
2431
|
+
def interpolate(self, size:tuple[int, ...], mode:str="linear", align_corners:bool=False) -> Tensor:
|
2275
2432
|
"""
|
2276
2433
|
Downsamples or Upsamples to the input `size`, accepts 0 to N batch dimensions.
|
2277
2434
|
|
@@ -2296,13 +2453,104 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|
2296
2453
|
reshape[i] = expand[i] = size[i]
|
2297
2454
|
if mode == "linear":
|
2298
2455
|
index = (scale*arr if align_corners else (scale*(arr+0.5))-0.5).clip(0, self.shape[i]-1)
|
2299
|
-
low, high, perc = [y.reshape(reshape).expand(expand) for y in (index.floor(), index.ceil(), index - index.floor())]
|
2456
|
+
low, high, perc = [y.reshape(reshape).expand(expand) for y in (index.floor().int(), index.ceil().int(), index - index.floor())]
|
2300
2457
|
x = x.gather(i, low).lerp(x.gather(i, high), perc)
|
2301
2458
|
else:
|
2302
2459
|
index = (scale*(arr+0.5) if mode=="nearest-exact" else scale*arr).cast(dtypes.int32).reshape(reshape).expand(expand)
|
2303
2460
|
x = x.gather(i, index)
|
2304
2461
|
return x.cast(self.dtype)
|
2305
2462
|
|
2463
|
+
def _pre_scatter(self, dim:int, index:Tensor, src:Tensor) -> tuple[Tensor, Tensor]:
|
2464
|
+
index, dim = index.to(self.device), self._resolve_dim(dim)
|
2465
|
+
assert index.ndim == self.ndim == src.ndim, f"self.ndim, index.ndim and src.dim must all equal, {self.ndim=} {index.ndim=} {src.ndim=}"
|
2466
|
+
assert all((d == dim or self_ >= index_) and src_ >= index_ for d,(self_,index_,src_) in enumerate(zip(self.shape, index.shape, src.shape))), \
|
2467
|
+
f"All dimensions of {index.shape=} should be <= to all dimensions of {src.shape=} and all dimensions except dimension {dim} of {self.shape=}"
|
2468
|
+
if self.dtype != src.dtype: raise RuntimeError(f"expect {self.dtype=} to be equal to {src.dtype=}")
|
2469
|
+
# shrink src to index shape to shrink away the unused values
|
2470
|
+
src = src.shrink(tuple((0,s) for s in index.shape))
|
2471
|
+
# prepare src and mask for reduce with respect to dim
|
2472
|
+
src = src.unsqueeze(-1).expand(*src.shape, self.shape[dim]).transpose(-1, dim)
|
2473
|
+
mask = index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim]).transpose(-1, dim)
|
2474
|
+
# pad src and mask to self.shape so that reduce can be done with padded values as no-ops
|
2475
|
+
src, mask = (x.pad(tuple((0, self.shape[i] - x.shape[i]) if i != dim else None for i in range(self.ndim)) + (None,)) for x in (src, mask))
|
2476
|
+
return src, mask
|
2477
|
+
|
2478
|
+
def scatter(self, dim:int, index:Tensor, src:Union[Tensor, ConstType], reduce:Union[None, Literal['multiply'], Literal['add']]=None) -> Tensor:
|
2479
|
+
"""
|
2480
|
+
Scatters `src` values along an axis specified by `dim`.
|
2481
|
+
Apply `add` or `multiply` reduction operation with `reduce`.
|
2482
|
+
|
2483
|
+
NOTE: To use the `reduce` argument with a Tensor `src`, see `Tensor.scatter_reduce`.
|
2484
|
+
|
2485
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2486
|
+
src = Tensor.arange(1, 11).reshape(2, 5)
|
2487
|
+
print(src.numpy())
|
2488
|
+
```
|
2489
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2490
|
+
index = Tensor([[0, 1, 2, 0]])
|
2491
|
+
print(Tensor.zeros(3, 5, dtype=src.dtype).scatter(0, index, src).numpy())
|
2492
|
+
```
|
2493
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2494
|
+
index = Tensor([[0, 1, 2], [0, 1, 4]])
|
2495
|
+
print(Tensor.zeros(3, 5, dtype=src.dtype).scatter(1, index, src).numpy())
|
2496
|
+
```
|
2497
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2498
|
+
print(Tensor.full((2, 4), 2.0).scatter(1, Tensor([[2], [3]]), 1.23, reduce='multiply').numpy())
|
2499
|
+
```
|
2500
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2501
|
+
print(Tensor.full((2, 4), 2.0).scatter(1, Tensor([[2], [3]]), 1.23, reduce='add').numpy())
|
2502
|
+
```
|
2503
|
+
"""
|
2504
|
+
if reduce not in {None, "add", "multiply"}: raise TypeError(f"{reduce=} must be one of None, 'multiply', or 'add'")
|
2505
|
+
if reduce and isinstance(src, Tensor): raise TypeError("Tensor src is not supported with reduce arg. see scatter_reduce")
|
2506
|
+
if not isinstance(src, Tensor): src = index.full_like(src, device=self.device, dtype=self.dtype)
|
2507
|
+
if reduce == "add": return self.scatter_reduce(dim, index, src, "sum", include_self=True)
|
2508
|
+
if reduce == "multiply": return self.scatter_reduce(dim, index, src, "prod", include_self=True)
|
2509
|
+
src, mask = self._pre_scatter(dim, index, src)
|
2510
|
+
return _masked_setitem(self, src, mask, (-1,))
|
2511
|
+
|
2512
|
+
def scatter_reduce(self, dim:int, index:Tensor, src:Tensor, reduce:Literal["sum", "prod", "mean", "amax", "amin"],
|
2513
|
+
include_self:bool=True) -> Tensor:
|
2514
|
+
"""
|
2515
|
+
Scatters `src` values along an axis specified by `dim`.
|
2516
|
+
Apply `"sum"`, `"prod"`, `"mean"`, `"amax"`, or `"amin"` reduction operations with `reduce`.
|
2517
|
+
|
2518
|
+
Set `include_self=False` to exclude values in the `self` Tensor from the reduction.
|
2519
|
+
|
2520
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2521
|
+
src = Tensor.arange(1, 11).cast(dtypes.float).reshape(2, 5)
|
2522
|
+
print(src.numpy())
|
2523
|
+
index = Tensor([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]])
|
2524
|
+
print(index.numpy())
|
2525
|
+
```
|
2526
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2527
|
+
print(Tensor.ones(1, 5, dtype=src.dtype).scatter_reduce(0, index, src, reduce='sum').numpy())
|
2528
|
+
```
|
2529
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2530
|
+
print(Tensor.ones(1, 5, dtype=src.dtype).scatter_reduce(0, index, src, reduce='prod').numpy())
|
2531
|
+
```
|
2532
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2533
|
+
print(Tensor.ones(1, 5, dtype=src.dtype).scatter_reduce(0, index, src, reduce='mean', include_self=False).numpy())
|
2534
|
+
```
|
2535
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2536
|
+
print(Tensor([[-10, 20, 0, 5, 10]], dtype=src.dtype).scatter_reduce(0, index, src, reduce='amax').numpy())
|
2537
|
+
```
|
2538
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2539
|
+
print(Tensor([[-10, 20, 0, 5, 10]], dtype=src.dtype).scatter_reduce(0, index, src, reduce='amin').numpy())
|
2540
|
+
```
|
2541
|
+
"""
|
2542
|
+
src, mask = self._pre_scatter(dim, index, src)
|
2543
|
+
def _inv_mask(a:Union[Tensor, ConstType], b:Union[Tensor, ConstType]) -> Tensor: return mask.any(-1).logical_not().where(a, b)
|
2544
|
+
# TODO: should not overwrite acc_dtype here?
|
2545
|
+
if reduce == "sum": return mask.where(src, 0).sum(-1, acc_dtype=self.dtype).add(self if include_self else _inv_mask(self, 0))
|
2546
|
+
if reduce == "prod": return mask.where(src, 1).prod(-1, acc_dtype=self.dtype).mul(self if include_self else _inv_mask(self, 1))
|
2547
|
+
if reduce == "amax": return mask.where(src, m := dtypes.min(src.dtype)).max(-1).maximum(self if include_self else _inv_mask(self, m))
|
2548
|
+
if reduce == "amin": return mask.where(src, m := dtypes.max(src.dtype)).min(-1).minimum(self if include_self else _inv_mask(self, m))
|
2549
|
+
if reduce == "mean":
|
2550
|
+
count = mask.where(1, 0).sum(-1, acc_dtype=self.dtype).add(1 if include_self else _inv_mask(1, 0))
|
2551
|
+
return mask.where(src, 0).sum(-1, acc_dtype=self.dtype).add(self if include_self else _inv_mask(self, 0)).div(count)
|
2552
|
+
raise RuntimeError(f"{reduce=} must be one of 'sum', 'prod', 'mean', 'amax', 'amin'")
|
2553
|
+
|
2306
2554
|
# ***** unary ops *****
|
2307
2555
|
|
2308
2556
|
def logical_not(self):
|
@@ -2313,7 +2561,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  print(Tensor([False, True]).logical_not().numpy())
  ```
  """
- return
+ return self.cast(dtypes.bool)._apply_broadcasted_uop(UOp.ne, True)
  def neg(self):
  """
  Negates the tensor element-wise.
@@ -2327,12 +2575,12 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  """
  Returns a contiguous tensor.
  """
- return
+ return self._apply_uop(UOp.contiguous)
  def contiguous_backward(self):
  """
  Inserts a contiguous operation in the backward pass.
  """
- return
+ return self._apply_uop(UOp.contiguous_backward)
  def log(self):
  """
  Computes the natural logarithm element-wise.
@@ -2343,7 +2591,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  print(Tensor([1., 2., 4., 8.]).log().numpy())
  ```
  """
- return
+ return self.log2()*math.log(2)
  def log2(self):
  """
  Computes the base-2 logarithm element-wise.
@@ -2354,7 +2602,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  print(Tensor([1., 2., 4., 8.]).log2().numpy())
  ```
  """
- return self.
+ return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.log2)
  def exp(self):
  """
  Computes the exponential function element-wise.
@@ -2365,7 +2613,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  print(Tensor([0., 1., 2., 3.]).exp().numpy())
  ```
  """
- return
+ return self.mul(1/math.log(2)).exp2()
  def exp2(self):
  """
  Computes the base-2 exponential function element-wise.
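The rewrites above reduce `log` and `exp` to the base-2 primitives: log x = log2(x)·ln 2 and exp x = exp2(x/ln 2). A quick check of those identities with plain `math` (just the arithmetic being relied on, not tinygrad code):

```python
import math

for x in (1.0, 2.0, 4.0, 8.0):
    # log(x) == log2(x) * ln(2)
    assert math.isclose(math.log(x), math.log2(x) * math.log(2))
for x in (0.0, 1.0, 2.0, 3.0):
    # exp(x) == exp2(x * (1/ln 2))
    assert math.isclose(math.exp(x), 2.0 ** (x * (1 / math.log(2))))
print("base-2 identities hold")
```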
@@ -2376,7 +2624,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  print(Tensor([0., 1., 2., 3.]).exp2().numpy())
  ```
  """
- return
+ return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.exp2)
  def relu(self):
  """
  Applies the Rectified Linear Unit (ReLU) function element-wise.
@@ -2387,7 +2635,8 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).relu().numpy())
  ```
  """
- return
+ return (self>0).where(self, 0)
+
  def sigmoid(self):
  """
  Applies the Sigmoid function element-wise.
@@ -2398,7 +2647,8 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sigmoid().numpy())
  ```
  """
- return
+ return (1 + (self * (-1/math.log(2))).exp2()).reciprocal()
+
  def hardsigmoid(self, alpha:float=1/6, beta:float=0.5):
  """
  Applies the Hardsigmoid function element-wise.
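`sigmoid` is likewise expressed through `exp2`: 1/(1 + e^(−x)) = 1/(1 + 2^(−x/ln 2)). A quick numeric check of that rewrite in plain Python:

```python
import math

def sigmoid_ref(x):  return 1.0 / (1.0 + math.exp(-x))
def sigmoid_exp2(x): return 1.0 / (1.0 + 2.0 ** (x * (-1.0 / math.log(2))))

for x in (-3.0, -1.0, 0.0, 1.0, 3.0):
    assert math.isclose(sigmoid_ref(x), sigmoid_exp2(x))
print("sigmoid via exp2 matches the textbook form")
```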
@@ -2421,7 +2671,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  print(Tensor([1., 2., 3., 4.]).sqrt().numpy())
  ```
  """
- return
+ return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.sqrt)
  def rsqrt(self):
  """
  Computes the reciprocal of the square root of the tensor element-wise.
@@ -2430,7 +2680,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  print(Tensor([1., 2., 3., 4.]).rsqrt().numpy())
  ```
  """
- return self.
+ return self.sqrt().reciprocal()
  def sin(self):
  """
  Computes the sine of the tensor element-wise.
@@ -2439,7 +2689,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  print(Tensor([0., math.pi/2, math.pi, 3*math.pi/2, 2*math.pi]).sin().numpy())
  ```
  """
- return
+ return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.sin)
  def cos(self):
  """
  Computes the cosine of the tensor element-wise.
@@ -2459,6 +2709,39 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  """
  return self.sin() / self.cos()

+ def asin(self):
+ """
+ Computes the inverse sine (arcsine) of the tensor element-wise.
+
+ ```python exec="true" source="above" session="tensor" result="python"
+ print(Tensor([-0.9, -0.6, -0.3, 0., 0.3, 0.6, 0.9]).asin().numpy())
+ ```
+ """
+ # https://personal.math.ubc.ca/~cbm/aands/page_81.htm 4.4.46
+ coefficients = [-0.0012624911, 0.0066700901, -0.0170881256, 0.0308918810, -0.0501743046, 0.0889789874, -0.2145988016, 1.5707963050]
+ x = math.pi / 2 - (1.0 - self.abs()).sqrt() * polyN(self.abs(), coefficients)
+ return self.sign() * x
+
+ def acos(self):
+ """
+ Computes the inverse cosine (arccosine) of the tensor element-wise.
+
+ ```python exec="true" source="above" session="tensor" result="python"
+ print(Tensor([-0.9, -0.6, -0.3, 0., 0.3, 0.6, 0.9]).acos().numpy())
+ ```
+ """
+ return math.pi / 2 - self.asin()
+
+ def atan(self):
+ """
+ Computes the inverse tangent (arctan) of the tensor element-wise.
+
+ ```python exec="true" source="above" session="tensor" result="python"
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).atan().numpy())
+ ```
+ """
+ return (self / (1 + self * self).sqrt()).asin()
+
  # ***** math functions *****

  def trunc(self: Tensor) -> Tensor:
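The new `asin` uses the Abramowitz & Stegun 4.4.46 polynomial referenced in its comment: asin(x) ≈ π/2 − √(1−x)·P(x) on [0, 1], extended to negative x by odd symmetry, with `acos` and `atan` derived from it. A standalone sketch of that approximation (same coefficients as the hunk; it assumes `polyN` is Horner evaluation with the highest-degree coefficient first), compared against `math.asin`:

```python
import math

# coefficients from the hunk above (A&S 4.4.46), highest degree first
COEFFS = [-0.0012624911, 0.0066700901, -0.0170881256, 0.0308918810,
          -0.0501743046, 0.0889789874, -0.2145988016, 1.5707963050]

def poly(x, coeffs):
    # Horner evaluation
    acc = 0.0
    for c in coeffs: acc = acc * x + c
    return acc

def asin_approx(x):
    y = math.pi / 2 - math.sqrt(1.0 - abs(x)) * poly(abs(x), COEFFS)
    return math.copysign(y, x)

for x in (-0.9, -0.3, 0.0, 0.3, 0.9):
    assert abs(asin_approx(x) - math.asin(x)) < 1e-4  # A&S quotes ~2e-8 absolute error
print("polynomial asin matches math.asin")
```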
@@ -2565,7 +2848,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sign().numpy())
  ```
  """
- return
+ return self.ne(0).where((self<0).where(self.full_like(-1), self.full_like(1)), self.full_like(0)) + self*0
  def abs(self):
  """
  Computes the absolute value of the tensor element-wise.
@@ -2583,7 +2866,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  print(Tensor([1., 2., 3., 4.]).reciprocal().numpy())
  ```
  """
- return
+ return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.reciprocal)

  # ***** activation functions *****

@@ -2613,6 +2896,19 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  """
  return self.maximum(0) + (alpha * ((self / alpha).exp() - 1)).minimum(0)

+ def selu(self, alpha=1.67326, gamma=1.0507):
+ """
+ Applies the Scaled Exponential Linear Unit (SELU) function element-wise.
+
+ - Described: https://paperswithcode.com/method/selu
+ - Paper: https://arxiv.org/abs/1706.02515v5
+
+ ```python exec="true" source="above" session="tensor" result="python"
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).selu().numpy())
+ ```
+ """
+ return gamma * (self >= 0).detach().where(self, alpha * (self.exp() - 1))
+
  def swish(self):
  """
  See `.silu()`
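SELU above is γ·x for x ≥ 0 and γ·α·(eˣ − 1) for x < 0, with the fixed constants from the hunk's signature. A plain-Python restatement of that formula:

```python
import math

ALPHA, GAMMA = 1.67326, 1.0507  # defaults from the selu() signature above

def selu_ref(x):
    return GAMMA * (x if x >= 0 else ALPHA * (math.exp(x) - 1.0))

print([round(selu_ref(x), 4) for x in (-3.0, -1.0, 0.0, 1.0, 3.0)])
# negative inputs saturate toward -GAMMA*ALPHA (about -1.758), positive inputs are scaled by GAMMA
```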
@@ -2840,17 +3136,17 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  return self / (1 + self.abs())

  # ***** broadcasted elementwise ops *****
- def _broadcast_to(self,
- if self.shape ==
- if self.ndim > len(
- # first
-
- # for each dimension, check either
- if
- raise ValueError(f"cannot broadcast
- return
-
- def _broadcasted(self, y:Union[Tensor, UOp, ConstType], reverse:bool=False, match_dtype:bool=True) ->
+ def _broadcast_to(self, new_shape:tuple[sint, ...]) -> Tensor:
+ if self.shape == new_shape: return self
+ if self.ndim > len(new_shape): raise ValueError(f"cannot broadcast tensor to fewer dimensions. shape={self.shape} to {new_shape=}")
+ # first unsqueeze left with 1s https://data-apis.org/array-api/latest/API_specification/broadcasting.html
+ shape, _ = _align_left(self.shape, new_shape)
+ # for each dimension, check either dim is 1, or it does not change
+ if not all(resolve(s == ns) or resolve(s == 1) for s,ns in zip(shape, new_shape)):
+ raise ValueError(f"cannot broadcast {self.shape} to {new_shape=}")
+ return self.reshape(shape)._apply_uop(UOp.expand, arg=new_shape)
+
+ def _broadcasted(self, y:Union[Tensor, UOp, ConstType], reverse:bool=False, match_dtype:bool=True) -> tuple[Tensor, Tensor]:
  x: Tensor = self
  if not isinstance(y, Tensor):
  # make y a Tensor
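`_broadcast_to` above implements the array-API broadcasting rule: left-pad the source shape with 1s to the target rank, then require every dimension to be either unchanged or 1. A minimal pure-Python sketch of that check (hypothetical helper name, not tinygrad's internals):

```python
def broadcastable_to(shape: tuple[int, ...], new_shape: tuple[int, ...]) -> bool:
    if len(shape) > len(new_shape): return False           # cannot drop dimensions
    padded = (1,) * (len(new_shape) - len(shape)) + shape  # left-pad with 1s
    return all(s == ns or s == 1 for s, ns in zip(padded, new_shape))

assert broadcastable_to((2, 1), (3, 2, 4)) is True    # 1s expand, leading dim is added
assert broadcastable_to((2, 3), (2, 4)) is False      # 3 can neither stay nor expand
assert broadcastable_to((3, 2, 4), (2, 4)) is False   # fewer target dims than source
```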
@@ -2867,12 +3163,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  if reverse: x, y = y, x

  # broadcast
- out_shape
- return x._broadcast_to(out_shape), y._broadcast_to(out_shape)
-
- def _to_const_val(self, x:Union[Tensor, ConstType]) -> Union[Tensor, ConstType]:
- return x.lazydata.base.arg if isinstance(x, Tensor) and isinstance(x.lazydata, LazyBuffer) and x.lazydata.is_unrealized_unmasked_const() \
- and not x.requires_grad and self._broadcasted(x)[0].shape == self.shape else x
+ return x._broadcast_to(out_shape:=_broadcast_shape(x.shape, y.shape)), y._broadcast_to(out_shape)

  def add(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
  """
@@ -2892,7 +3183,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  print(t.add(Tensor([[2.0], [3.5]])).numpy())
  ```
  """
- return
+ return self._apply_broadcasted_uop(UOp.add, x, reverse)

  def sub(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
  """
@@ -2933,20 +3224,20 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  print(t.mul(Tensor([[-1.0], [2.0]])).numpy())
  ```
  """
- return
+ return self._apply_broadcasted_uop(UOp.mul, x, reverse)

  def idiv(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
  """
  Divides `self` by `x`.
  Equivalent to `self // x`.
  Supports broadcasting to a common shape, type promotion, and integer inputs.
- `idiv` performs integer division.
+ `idiv` performs integer division (truncate towards zero).

  ```python exec="true" source="above" session="tensor" result="python"
- print(Tensor([
+ print(Tensor([-4, 7, 5, 4, -7, 8]).idiv(Tensor([2, -3, 8, -2, 3, 5])).numpy())
  ```
  """
- return
+ return self._apply_broadcasted_uop(UOp.idiv, x, reverse)

  def div(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
  """
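The updated `idiv` docstring states the rounding rule: integer division truncates toward zero (C semantics), which differs from Python's floor-dividing `//` whenever exactly one operand is negative. A quick illustration using the values from the docstring example:

```python
# truncate-toward-zero division (what the idiv docstring above describes)
def trunc_div(a, b): return int(a / b)

pairs = [(-4, 2), (7, -3), (5, 8), (4, -2), (-7, 3), (8, 5)]
print([trunc_div(a, b) for a, b in pairs])  # [-2, -2, 0, -2, -2, 1]
print([a // b for a, b in pairs])           # Python floors: [-2, -3, 0, -2, -3, 1]
```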
@@ -2970,6 +3261,19 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  numerator, denominator = self._broadcasted(x, reverse)
  return numerator.cast(least_upper_float(numerator.dtype)) * denominator.cast(least_upper_float(denominator.dtype)).reciprocal()

+ def mod(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
+ """
+ Mod `self` by `x`.
+ Equivalent to `self % x`.
+ Supports broadcasting to a common shape, type promotion, and integer inputs.
+
+ ```python exec="true" source="above" session="tensor" result="python"
+ print(Tensor([-4, 7, 5, 4, -7, 8]).mod(Tensor([2, -3, 8, -2, 3, 5])).numpy())
+ ```
+ """
+ a, b = self._broadcasted(x, reverse)
+ return (r := a._apply_uop(UOp.mod, b)) + b * (((r < 0) & (b > 0)) | ((r > 0) & (b < 0)))
+
  def xor(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
  """
  Computes bitwise xor of `self` and `x`.
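The new `mod` takes the truncated remainder produced by `UOp.mod` (sign follows the dividend) and adds `b` whenever that remainder's sign disagrees with the divisor's, which yields Python-style floored modulo (sign follows the divisor). The same correction in plain Python:

```python
def trunc_mod(a, b): return a - int(a / b) * b  # C-style remainder, sign of a

def floored_mod(a, b):
    r = trunc_mod(a, b)
    # add b when r and b have opposite signs, mirroring the correction term in the hunk
    return r + b if (r < 0 and b > 0) or (r > 0 and b < 0) else r

pairs = [(-4, 2), (7, -3), (5, 8), (4, -2), (-7, 3), (8, 5)]
assert [floored_mod(a, b) for a, b in pairs] == [a % b for a, b in pairs]
print([floored_mod(a, b) for a, b in pairs])  # [0, -2, 5, 0, 2, 3]
```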
@@ -2984,7 +3288,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  ```
  """
  if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
- return
+ return self._apply_broadcasted_uop(UOp.xor, x, reverse)

  def bitwise_and(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
  """
@@ -2999,7 +3303,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  ```
  """
  if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
- return
+ return self._apply_broadcasted_uop(UOp.bitwise_and, x, reverse)

  def bitwise_or(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
  """
@@ -3014,7 +3318,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  ```
  """
  if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
- return
+ return self._apply_broadcasted_uop(UOp.bitwise_or, x, reverse)

  def bitwise_not(self) -> Tensor:
  """
@@ -3028,7 +3332,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  ```
  """
  if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
- return self.logical_not() if self.dtype == dtypes.bool else self ^
+ return self.logical_not() if self.dtype == dtypes.bool else self ^ -1

  def lshift(self, x:int):
  """
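For integer dtypes, `bitwise_not` is now XOR with −1: on two's-complement integers −1 is all ones, so `x ^ -1` flips every bit and equals `~x`. A one-line check in plain Python:

```python
vals = [0, 1, 5, -3, 127, -128]
assert all((v ^ -1) == ~v for v in vals)
print([v ^ -1 for v in vals])  # [-1, -2, -6, 2, -128, 127]
```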
@@ -3060,37 +3364,22 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  Equivalent to `self ** x`.

  ```python exec="true" source="above" session="tensor" result="python"
- print(Tensor([-1, 2, 3]).pow(2).numpy())
+ print(Tensor([-1, 2, 3]).pow(2.0).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-1, 2, 3]).pow(Tensor([-1.5, 0.5, 1.5])).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
- print((2 ** Tensor([-1, 2, 3])).numpy())
+ print((2.0 ** Tensor([-1, 2, 3])).numpy())
  ```
  """
- x = self._to_const_val(x)
- if not isinstance(x, Tensor) and not reverse:
- # simple pow identities
- if x < 0: return self.reciprocal().pow(-x)
- if x == 0: return 1 + self * 0
- if int(x - 0.5) + 0.5 == x: return self.pow(int(x - 0.5)) * self.sqrt()
- if int(x) == x: return self.pow(x // 2).square() * (1 if x % 2 == 0 else self)
-
- # positive const ** self
- if not isinstance(x, Tensor) and reverse and x > 0: return self.mul(math.log(x)).exp()
-
  base, exponent = self._broadcasted(x, reverse=reverse)
- #
-
-
-
-
-
- # inject nan for negative base and non-integer exponent
- inject_nan = (negative_base * (exponent != exponent.trunc())).detach().where(math.nan, 1)
- # apply correct_sign inject_nan, and fix 0 ** 0 = 1
- return ((base == 0) * (exponent == 0)).detach().where(1, ret * correct_sign * inject_nan)
+ # TODO: int pow
+ if not base.is_floating_point(): raise RuntimeError("base needs to be float")
+
+ # NOTE: pow(int, float) -> int
+ ret = base._apply_uop(UOp.pow, exponent)
+ return ret.round().cast(self.dtype) if not dtypes.is_float(self.dtype) else ret

  def maximum(self, x:Union[Tensor, ConstType]) -> Tensor:
  """
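`pow` now lowers directly to a `UOp.pow` on a float base: after broadcasting, a non-float base raises, and when `self` is an integer tensor the float result is rounded and cast back, so `pow(int, float) -> int` as the NOTE says. A small usage sketch of the documented behavior (expected results in the comments assume the 0.10.2 semantics above):

```python
from tinygrad import Tensor

a = Tensor([2, 3, 4])            # integer tensor
b = a.pow(2.0)                   # float exponent on an int base
print(b.dtype, b.numpy())        # stays integer: the float result is rounded and cast back

c = Tensor([1., 4., 9.]) ** 0.5  # float base keeps its float dtype
print(c.dtype, c.numpy())

# Tensor([2, 3]).pow(Tensor([2, 3])) raises: the broadcast base must be float ("base needs to be float")
```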
@@ -3103,7 +3392,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  print(Tensor([-1, 2, 3]).maximum(Tensor([-4, -2, 9])).numpy())
  ```
  """
- return
+ return self._apply_broadcasted_uop(UOp.maximum, x)

  def minimum(self, x:Union[Tensor, ConstType]) -> Tensor:
  """
@@ -3116,9 +3405,10 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  print(Tensor([-1, 2, 3]).minimum(Tensor([-4, -2, 9])).numpy())
  ```
  """
-
+ t, x = self._broadcasted(x)
+ return t._inverse().maximum(x._inverse())._inverse()

- def where(self:Tensor, x:Union[Tensor, ConstType], y:Union[Tensor, ConstType]):
+ def where(self:Tensor, x:Union[Tensor, ConstType, sint], y:Union[Tensor, ConstType, sint]):
  """
  Return a tensor of elements selected from either `x` or `y`, depending on `self`.
  `output_i = x_i if self_i else y_i`.
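`minimum` is implemented above by routing through `maximum` with an `_inverse()` on both sides, relying on the identity min(a, b) = −max(−a, −b) (the inverse helper stands in for negation on dtypes where plain negation is not the right inverse). The identity itself, checked in plain Python:

```python
pairs = [(-1, -4), (2, -2), (3, 9), (0, 0), (-5.5, 1.25)]
assert all(min(a, b) == -max(-a, -b) for a, b in pairs)
print([-max(-a, -b) for a, b in pairs])  # [-4, -2, 3, 0, -5.5]
```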
@@ -3140,7 +3430,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  elif isinstance(y, Tensor): y, x = y._broadcasted(x)
  cond, x = self._broadcasted(x, match_dtype=False)
  cond, y = cond._broadcasted(y, match_dtype=False)
- return
+ return cond.cast(dtypes.bool)._apply_uop(UOp.where, *x._broadcasted(y))

  def masked_fill(self:Tensor, mask:Tensor, value:Union[Tensor, ConstType]): return mask.where(value, self)

@@ -3170,9 +3460,9 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  def __ilshift__(self, x) -> Tensor: return self.assign(self.lshift(x))
  def __irshift__(self, x) -> Tensor: return self.assign(self.rshift(x))

- def
- def
- def ne(self, x) -> Tensor: return
+ def __lt__(self, x) -> Tensor: return self._apply_broadcasted_uop(UOp.__lt__, x, False)
+ def __gt__(self, x) -> Tensor: return self._apply_broadcasted_uop(UOp.__lt__, x, True)
+ def ne(self, x) -> Tensor: return self._apply_broadcasted_uop(UOp.ne, x, False)

  def __eq__(self, x) -> Tensor: return self.eq(x) # type: ignore[override]

@@ -3194,7 +3484,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  x = self.mul(weight) if len(weight.shape) == 1 else self.dot(weight)
  return x.add(bias) if bias is not None else x

- def sequential(self, ll:
+ def sequential(self, ll:list[Callable[[Tensor], Tensor]]):
  """
  Applies a sequence of functions to `self` chaining the output of each function to the input of the next.

@@ -3205,7 +3495,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  """
  return functools.reduce(lambda x,f: f(x), ll, self)

- def layernorm(self, axis:Union[int,
+ def layernorm(self, axis:Union[int,tuple[int,...]]=-1, eps:float=1e-5) -> Tensor:
  """
  Applies Layer Normalization over a mini-batch of inputs.

@@ -3224,7 +3514,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  y = (self - self.mean(axis, keepdim=True))
  return y.mul((y*y).mean(axis, keepdim=True).add(eps).rsqrt())

- def batchnorm(self, weight:Optional[Tensor], bias:Optional[Tensor], mean:Tensor, invstd:Tensor, axis:Union[int,
+ def batchnorm(self, weight:Optional[Tensor], bias:Optional[Tensor], mean:Tensor, invstd:Tensor, axis:Union[int,tuple[int,...]]=1) -> Tensor:
  """
  Applies Batch Normalization over a mini-batch of inputs.

@@ -3266,6 +3556,12 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  if not Tensor.training or p == 0: return self
  return (Tensor.rand_like(self, requires_grad=False, dtype=dtypes.default_float, contiguous=False) >= p).contiguous().where(self, 0) / (1.0 - p)

+ # helper function commonly used for indexing
+ def _one_hot_along_dim(self:Tensor, num_classes:sint, dim:int=-1):
+ if not dtypes.is_int(self.dtype): raise RuntimeError(f"_one_hot_along_dim expects int index tensor, getting {self.dtype}")
+ offset = self.ndim - self._resolve_dim(dim) - 1
+ return self == Tensor.arange(num_classes, device=self.device, requires_grad=False).reshape((num_classes,) + (1,) * offset)
+
  def one_hot(self, num_classes:int=-1) -> Tensor:
  """
  Converts `self` to a one-hot tensor.
@@ -3277,11 +3573,11 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  print(t.one_hot(5).numpy())
  ```
  """
+ if not dtypes.is_int(self.dtype): raise RuntimeError(f"expect integer dtype, getting {self.dtype=}")
  if num_classes == -1: num_classes = (self.max()+1).item()
- return
+ return self[..., None]._one_hot_along_dim(num_classes).where(1, 0)

- def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:
- dropout_p:float=0.0, is_causal:bool=False) -> Tensor:
+ def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Tensor|None=None, dropout_p:float=0.0, is_causal:bool=False) -> Tensor:
  """
  Computes scaled dot-product attention.
  `self` is the query tensor, `key` is the key tensor, and `value` is the value tensor.
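`one_hot` now reuses `_one_hot_along_dim`: append a trailing axis to the index tensor, compare it against a broadcast `arange(num_classes)`, and turn the boolean matches into 0/1. The same construction written against the public API (a sketch; the helper itself is private):

```python
from tinygrad import Tensor

idx = Tensor([0, 3, 1])  # integer class indices
num_classes = 5
# compare the trailing singleton axis against 0..num_classes-1, then map bools to 0/1
onehot = (idx[..., None] == Tensor.arange(num_classes)).where(1, 0)
print(onehot.numpy())
# rows are one-hot: [[1 0 0 0 0], [0 0 0 1 0], [0 1 0 0 0]]
```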
@@ -3298,14 +3594,19 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  """
  # NOTE: it also works when `key` and `value` have symbolic shape.
  assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
- if is_causal: attn_mask = Tensor.ones(self.shape[-2], key.shape[-2], requires_grad=False, device=self.device).tril(0).cast(dtypes.bool)
- if attn_mask is not None and attn_mask.dtype == dtypes.bool: attn_mask = (attn_mask == 0).where(-float("inf"), 0)
  qk = self.matmul(key.transpose(-2,-1), acc_dtype=least_upper_dtype(self.dtype, key.dtype, dtypes.float32)) / math.sqrt(self.shape[-1])
-
+ # handle attention mask
+ if is_causal:
+ if attn_mask is not None: raise RuntimeError("cannot set attn_mask when is_causal=True")
+ attn_mask = qk.ones_like(requires_grad=False, device=self.device, dtype=dtypes.bool).tril()
+ if attn_mask is not None:
+ if attn_mask.dtype == dtypes.bool: attn_mask = attn_mask.where(0, -float("inf"))
+ qk = qk + attn_mask
+ return qk.softmax(-1).cast(self.dtype).dropout(dropout_p) @ value

  def _do_reduction(self, reduction:ReductionStr="mean") -> Tensor:
  if reduction not in get_args(ReductionStr): raise ValueError(f"{reduction=} must be one of {get_args(ReductionStr)}")
- reductions:
+ reductions: dict[str, Callable[[Tensor], Tensor]] = {"mean": Tensor.mean, "sum": Tensor.sum, "none": lambda x: x}
  return reductions[reduction](self)

  def binary_crossentropy(self, Y:Tensor, reduction:ReductionStr="mean") -> Tensor:
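In the rewritten attention path, the causal case builds a lower-triangular boolean mask shaped like `qk`, boolean masks are converted to additive form (0 where allowed, −inf where masked), and the mask is added to the scaled scores before the softmax. A small sketch of that masking step on its own, with made-up sizes:

```python
from tinygrad import Tensor, dtypes

q, k, v = Tensor.randn(1, 4, 8), Tensor.randn(1, 4, 8), Tensor.randn(1, 4, 8)
scores = q.matmul(k.transpose(-2, -1)) / (8 ** 0.5)

# boolean causal mask -> additive mask, as in the hunk above
causal = scores.ones_like(dtype=dtypes.bool).tril()
additive = causal.where(0, -float("inf"))
out = (scores + additive).softmax(-1) @ v
print(out.shape)  # (1, 4, 8)

# or simply: q.scaled_dot_product_attention(k, v, is_causal=True)
```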
@@ -3354,8 +3655,8 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  assert 0.0 <= label_smoothing <= 1.0, "label_smoothing must be in [0.0, 1.0]"
  assert reduction in ("mean", "sum", "none"), "reduction must be one of ['mean', 'sum', 'none']"
  log_probs, loss_mask = self.log_softmax(), (Y != ignore_index) if ignore_index != -1 else Y.ones_like(dtype=dtypes.bool)
-
- y = (
+ y_counted = Y.to(self.device).flatten().reshape(-1, 1)._one_hot_along_dim(self.shape[-1])
+ y = (y_counted * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
  smoothing = label_smoothing * (log_probs.mean(-1) * loss_mask)
  unreduced = ((1 - label_smoothing) * (log_probs * y).sum(-1) + smoothing)
  # NOTE: because of ignore_index, we can't use Tensor.mean (so can't use `_do_reduction` here)
@@ -3469,7 +3770,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  """
  return dtypes.is_float(self.dtype)

- def size(self, dim:Optional[int]=None) -> Union[sint,
+ def size(self, dim:Optional[int]=None) -> Union[sint, tuple[sint, ...]]:
  """
  Return the size of the tensor. If `dim` is specified, return the length along dimension `dim`. Otherwise return the shape of the tensor.

@@ -3488,7 +3789,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  def llvm_bf16_cast(self, dtype:DTypeLike):
  # hack for devices that don't support bfloat16
  assert self.dtype == dtypes.bfloat16
- return self.to("LLVM").
+ return self.to("LLVM").cast(dtype)

  def cast(self, dtype:DTypeLike) -> Tensor:
  """
@@ -3502,8 +3803,15 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  t = t.cast(dtypes.int32)
  print(t.dtype, t.numpy())
  ```
+ ```python exec="true" source="above" session="tensor" result="python"
+ t = t.cast(dtypes.uint8)
+ print(t.dtype, t.numpy())
+ ```
  """
-
+ if (dt:=to_dtype(dtype)) in {dtypes.uint8, dtypes.uint16} and dtypes.is_float(self.dtype):
+ # NOTE: values within the int32 range and outside the unsigned dtype range will cause values to wrap around
+ return self._apply_uop(UOp.cast, dtype=dtypes.int32)._apply_uop(UOp.cast, dtype=dt)
+ return self if self.dtype == dt else self._apply_uop(UOp.cast, dtype=dt)

  def bitcast(self, dtype:DTypeLike) -> Tensor:
  """
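Casting a float tensor to `uint8`/`uint16` now goes through an intermediate `int32` cast, so negative floats and values past the unsigned range wrap instead of clamping, per the NOTE in the hunk. A small usage sketch; the values in the comment are what two's-complement wrap-around would give, assuming that behavior:

```python
from tinygrad import Tensor, dtypes

t = Tensor([-1.0, 255.0, 256.0, 300.5])
u = t.cast(dtypes.uint8)   # float -> int32 -> uint8 under the hood
print(u.dtype, u.numpy())  # expected wrap-around: 255, 255, 0, 44
```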
@@ -3522,13 +3830,13 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  """
  if self.requires_grad: raise RuntimeError("can't backprop through bitcast")
  dt = to_dtype(dtype)
- if (
-
+ if (ns:=dt.itemsize) != (os:=self.dtype.itemsize) and (self.shape[-1]*os) % ns != 0: raise RuntimeError("unsupported size in bitcast")
+ if (not isinstance(self.device, str) or not self.device.startswith("DISK")) and ns != os:
  new_uint, old_uint = to_dtype(f"uint{8*ns}"), to_dtype(f"uint{8*os}")
  tmp = self.bitcast(old_uint)
  if ns > os: return functools.reduce(Tensor.add, (tmp[..., i::ns//os].cast(new_uint) << 8*i*os for i in range(ns//os))).bitcast(dtype)
  return Tensor.stack(*(tmp>>8*i*ns for i in range(os//ns)), dim=-1).flatten(-2).cast(new_uint).bitcast(dtype)
- return
+ return self._apply_uop(UOp.bitcast, dtype=dt) if self.dtype != dt else self

  def float(self) -> Tensor:
  """
@@ -3650,7 +3958,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  else: w = w.reshape(cout//4, H, rcin_hi, W, rcin_lo, 4).permute(0,1,2,3,5,4)

  # prepare input
- x = x.permute(0,3,4,5,1,2).pad(self.
+ x = x.permute(0,3,4,5,1,2).pad(self._resolve_pool_pads(padding,2))._pool((H,W), stride, dilation)# -> (bs, groups, rcin_hi, rcin_lo, oy, ox, H, W)
  x = x.permute(0,4,5,1,2,3,6,7).reshape(bs, (oy := x.shape[4]), (ox := x.shape[5]), *cout_expand[0:2], 1, 1, rcin_hi, rcin_lo, H, W)

  # prepare weights
@@ -3702,5 +4010,5 @@ def _metadata_wrapper(fn):

  if TRACEMETA >= 1:
  for name, fn in inspect.getmembers(Tensor, inspect.isfunction):
- if name in ["__class__", "__init__", "__new__", "__repr__", "backward", "sequential"]: continue
+ if name in ["__class__", "__init__", "__new__", "__repr__", "backward", "sequential", "gradient"]: continue
  setattr(Tensor, name, functools.wraps(fn)(_metadata_wrapper(fn)))