tinygrad 0.8.0__py3-none-any.whl → 0.9.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +6 -6
- tinygrad/codegen/__init__.py +0 -0
- tinygrad/codegen/kernel.py +253 -225
- tinygrad/codegen/linearizer.py +398 -436
- tinygrad/codegen/uops.py +451 -0
- tinygrad/device.py +268 -274
- tinygrad/dtype.py +56 -40
- tinygrad/engine/__init__.py +0 -0
- tinygrad/engine/graph.py +100 -0
- tinygrad/engine/jit.py +198 -0
- tinygrad/engine/realize.py +192 -0
- tinygrad/engine/schedule.py +370 -0
- tinygrad/engine/search.py +199 -0
- tinygrad/{mlops.py → function.py} +40 -32
- tinygrad/helpers.py +144 -46
- tinygrad/lazy.py +143 -242
- tinygrad/multi.py +173 -0
- tinygrad/nn/__init__.py +180 -9
- tinygrad/nn/datasets.py +8 -0
- tinygrad/nn/optim.py +106 -28
- tinygrad/nn/state.py +87 -19
- tinygrad/ops.py +104 -45
- tinygrad/renderer/__init__.py +65 -0
- tinygrad/renderer/assembly.py +269 -0
- tinygrad/renderer/cstyle.py +308 -210
- tinygrad/renderer/llvmir.py +119 -124
- tinygrad/runtime/__init__.py +0 -0
- tinygrad/runtime/autogen/amd_gpu.py +13403 -0
- tinygrad/runtime/autogen/comgr.py +891 -0
- tinygrad/runtime/autogen/cuda.py +5923 -0
- tinygrad/runtime/autogen/hip.py +5909 -0
- tinygrad/runtime/autogen/hsa.py +5893 -0
- tinygrad/runtime/autogen/io_uring.py +1486 -0
- tinygrad/runtime/autogen/kfd.py +812 -0
- tinygrad/runtime/autogen/nv_gpu.py +33597 -0
- tinygrad/runtime/autogen/opencl.py +1795 -0
- tinygrad/runtime/driver/__init__.py +0 -0
- tinygrad/runtime/driver/hip_comgr.py +56 -0
- tinygrad/runtime/graph/__init__.py +0 -0
- tinygrad/runtime/graph/clang.py +39 -0
- tinygrad/runtime/graph/cuda.py +59 -54
- tinygrad/runtime/graph/hcq.py +187 -0
- tinygrad/runtime/graph/metal.py +37 -41
- tinygrad/runtime/ops_amd.py +550 -0
- tinygrad/runtime/ops_clang.py +16 -14
- tinygrad/runtime/ops_cuda.py +129 -37
- tinygrad/runtime/ops_disk.py +111 -43
- tinygrad/runtime/ops_gpu.py +52 -50
- tinygrad/runtime/ops_llvm.py +36 -56
- tinygrad/runtime/ops_metal.py +41 -24
- tinygrad/runtime/ops_npy.py +9 -0
- tinygrad/runtime/ops_nv.py +625 -0
- tinygrad/runtime/ops_python.py +208 -0
- tinygrad/shape/__init__.py +0 -0
- tinygrad/shape/shapetracker.py +46 -107
- tinygrad/shape/symbolic.py +99 -98
- tinygrad/shape/view.py +162 -45
- tinygrad/tensor.py +2492 -483
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/LICENSE +1 -1
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/METADATA +31 -13
- tinygrad-0.9.1.dist-info/RECORD +63 -0
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/WHEEL +1 -1
- tinygrad/features/image.py +0 -93
- tinygrad/features/multi.py +0 -103
- tinygrad/features/search.py +0 -160
- tinygrad/graph.py +0 -106
- tinygrad/jit.py +0 -152
- tinygrad/realize.py +0 -50
- tinygrad/runtime/graph/hip.py +0 -24
- tinygrad/runtime/ops_cpu.py +0 -45
- tinygrad/runtime/ops_hip.py +0 -97
- tinygrad/runtime/ops_torch.py +0 -49
- tinygrad-0.8.0.dist-info/RECORD +0 -41
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/top_level.txt +0 -0
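The file list above amounts to an internal module reorganization (mlops.py renamed to function.py, the old top-level jit/realize/graph and features/ modules consolidated under tinygrad/engine/ and tinygrad/). A minimal sketch of how this shows up as import paths, assuming tinygrad 0.9.1 is installed; the 0.9.1 imports are the ones used by the new tensor.py below, while the "was" comments are inferred from the removed files and are not quoted from this diff:

```python
# Sketch only: import-path changes implied by the file moves listed above.
import tinygrad.function as F                     # was tinygrad/mlops.py (renamed)
from tinygrad.multi import MultiLazyBuffer        # was tinygrad/features/multi.py
from tinygrad.engine.realize import run_schedule  # was tinygrad/realize.py
from tinygrad.engine.schedule import create_schedule_with_vars, memory_planner  # new tinygrad/engine/schedule.py
```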
tinygrad/tensor.py
CHANGED
@@ -1,19 +1,21 @@
 # inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
 from __future__ import annotations
-import time, math, itertools
-from
+import time, math, itertools, functools, struct
+from contextlib import ContextDecorator
+from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Dict, DefaultDict, cast, get_args, Set
 from collections import defaultdict
-from functools import partialmethod, reduce
 import numpy as np
 
-from tinygrad.dtype import DType, dtypes, ImageDType, least_upper_float, least_upper_dtype
-from tinygrad.helpers import argfix, make_pair,
-from tinygrad.
-from tinygrad.
-from tinygrad.
-from tinygrad.
-from tinygrad.
-from tinygrad.
+from tinygrad.dtype import DType, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype
+from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, get_shape, fully_flatten, dedup
+from tinygrad.helpers import IMAGE, DEBUG, WINO, THREEFRY
+from tinygrad.lazy import LazyBuffer
+from tinygrad.multi import MultiLazyBuffer
+from tinygrad.ops import LoadOps, truncate
+from tinygrad.device import Device, Buffer, BufferOptions
+from tinygrad.shape.symbolic import sint, Variable, MulNode, SumNode, NumNode, Node
+from tinygrad.engine.realize import run_schedule
+from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars, memory_planner
 
 # **** start with two base classes, Tensor and Function ****
 
@@ -30,69 +32,145 @@ class Function:
   @classmethod
   def apply(fxn:Type[Function], *x:Tensor, **kwargs) -> Tensor:
     ctx = fxn(x[0].device, *x)
-    ret = Tensor
-
+    ret = Tensor.__new__(Tensor)
+    ret.lazydata, ret.requires_grad, ret.grad = ctx.forward(*[t.lazydata for t in x], **kwargs), ctx.requires_grad, None
+    ret._ctx = ctx if ctx.requires_grad and not Tensor.no_grad else None # used by autograd engine
     return ret
 
-import tinygrad.
+import tinygrad.function as F
 
-def _loadop(op, shape:Tuple[sint,...], dtype:DType, device:Union[str, Tuple[str, ...]], arg=None, src:
+def _loadop(op, shape:Tuple[sint,...], dtype:DType, device:Union[str, Tuple[str, ...]], arg=None, src:Tuple[LazyBuffer, ...]=()):
   if isinstance(device, str): return LazyBuffer.loadop(op, shape, dtype, device, arg, src)
   return MultiLazyBuffer([LazyBuffer.loadop(op, shape, dtype, d, arg, src) for d in device], None)
 
-
+def _from_np_dtype(npdtype:type) -> DType: return dtypes.fields()[np.dtype(npdtype).name]
+def _to_np_dtype(dtype:DType) -> Optional[type]: return np.dtype(dtype.fmt).type if dtype.fmt is not None else None
+
+def _fromnp(x: np.ndarray) -> LazyBuffer:
+  ret = LazyBuffer.loadop(LoadOps.EMPTY, x.shape, _from_np_dtype(x.dtype), "NPY")
+  # fake realize
+  ret.buffer.allocate(x)
+  del ret.srcs
+  return ret
+
+def _frompy(x:Union[List, Tuple, bytes], dtype:DType) -> LazyBuffer:
+  if isinstance(x, bytes): ret, data = LazyBuffer.loadop(LoadOps.EMPTY, (len(x),), dtype, "PYTHON"), x
+  else:
+    ret = LazyBuffer.loadop(LoadOps.EMPTY, get_shape(x), dtype, "PYTHON")
+    assert dtype.fmt is not None, f"{dtype=} has None fmt"
+    truncate_function = truncate[dtype]
+    data = struct.pack(f"@{ret.size}{dtype.fmt}", *[truncate_function(xi) for xi in fully_flatten(x)])
+  # fake realize
+  ret.buffer.allocate(memoryview(data))
+  del ret.srcs
+  return ret
+
+def _get_winograd_matcols(mat, dims:int, shp:Tuple[sint, ...], device:Union[str, Tuple[str, ...]]) -> List[List[Tensor]]:
+  return [[Tensor.cat(*[Tensor.full(shp[:dim] + (1,) + shp[dim+1:], float(m[k]), device=device) for m in mat], dim=dim)
+           for k in range(len(mat[0]))] for dim in range(dims)]
+
+# winograd conv 3 kernel f(4x4,3x3) see: http://arxiv.org/abs/1509.09308
+def _apply_winograd_matrix(mat, t:Tensor, dims:int) -> Tensor:
+  # multiply mat_1 @ mat_2 @ t with foldable constants, where mat_i acts on vector t along dimension i; roughly kron(mat, mat) @ t
+  # due to realize-before-expand rule in lazy.py, we must operate in this order: reshape -> expand -> arithmetic
+  t_ = t.reshape(t.shape[:dims] + (1,) * dims + t.shape[dims:]).expand(t.shape[:dims] + (len(mat),) * dims + t.shape[dims:]) # add output dims
+  # precalculate mat columns for each dim; prod(itertools.product(matcols)) gives the columns of kron(mat, mat, ...)
+  matcols = _get_winograd_matcols(mat, dims, t_.shape[dims:], t_.device)
+  # multiply each element of t_ by the corresponding stacked column of kron(mat, mat), producing only one view for each element of t
+  ret = sum(prod(col[idx] for col, idx in zip(matcols, mat_is)) * t_[mat_is] for mat_is in itertools.product(range(len(mat[0])), repeat=dims))
+  assert isinstance(ret, Tensor), "sum didn't return a Tensor"
+  return ret
+
+def _pad_left(*shapes:Tuple[sint, ...]) -> Tuple[Tuple[sint, ...], ...]:
+  max_dim = max(len(shape) for shape in shapes)
+  return tuple((1,) * (max_dim - len(shape)) + shape for shape in shapes)
+def _broadcast_shape(*shapes:Tuple[sint, ...]) -> Tuple[sint, ...]:
+  return tuple(0 if any(size == 0 for size in nth_dim_sizes) else max(nth_dim_sizes) for nth_dim_sizes in zip(*_pad_left(*shapes)))
 
 class Tensor:
+  """
+  A `Tensor` is a multi-dimensional matrix containing elements of a single data type.
+
+  ```python exec="true" session="tensor"
+  from tinygrad import Tensor, dtypes, nn
+  import numpy as np
+  import math
+  np.set_printoptions(precision=4)
+  ```
+  """
   __slots__ = "lazydata", "requires_grad", "grad", "_ctx"
   __deletable__ = ('_ctx',)
   training: ClassVar[bool] = False
-  class train:
-    def __init__(self, val=True): self.val = val
-    def __enter__(self): self.prev, Tensor.training = Tensor.training, self.val
-    def __exit__(self, exc_type, exc_value, traceback): Tensor.training = self.prev
-
   no_grad: ClassVar[bool] = False
-
+
+  def __init__(self, data:Union[None, ConstType, List, Tuple, LazyBuffer, np.ndarray, bytes, MultiLazyBuffer, Variable],
                device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
     assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
     device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device)
-
+
+    # tensors can have gradients if you have called .backward
     self.grad: Optional[Tensor] = None
 
     # NOTE: this can be in three states. False and None: no gradient, True: gradient
     # None (the default) will be updated to True if it's put in an optimizer
     self.requires_grad: Optional[bool] = requires_grad
 
-    # internal
+    # internal variable used for autograd graph construction
     self._ctx: Optional[Function] = None
+
+    # create a LazyBuffer from the different types of inputs
     if isinstance(data, LazyBuffer): assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
-    elif isinstance(data, get_args(
-    elif isinstance(data,
+    elif isinstance(data, get_args(ConstType)): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_py(data), device, data)
+    elif isinstance(data, Variable): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_py(data.unbind()[1]), device, data)
+    elif isinstance(data, bytes): data = _frompy(data, dtypes.uint8)
+    elif isinstance(data, (list, tuple)):
+      if dtype is None:
+        if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtypes.bool
+        else: dtype = dtypes.default_int if d and all_int(d) else dtypes.default_float
+      if dtype == dtypes.bfloat16: data = Tensor(_fromnp(np.array(data, np.float32)), device=device).cast(dtypes.bfloat16).lazydata
+      else: data = _fromnp(np.array(data).astype(_to_np_dtype(dtype)))
     elif data is None: data = _loadop(LoadOps.EMPTY, (0,), dtype or dtypes.default_float, device)
-    elif isinstance(data, list):
-      if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtype or dtypes.bool
-      elif d and all_int(d): dtype = dtype or dtypes.default_int
-      else: dtype = dtype or dtypes.default_float
-      # NOTE: cast at the end for the dtypes that do not have a numpy dtype
-      data = LazyBuffer.fromCPU(np.array(data, dtype.np)).cast(dtype)
     elif isinstance(data, np.ndarray):
-      if data.shape == (): data = _loadop(LoadOps.CONST, tuple(), dtype or
-      else: data =
+      if data.shape == (): data = _loadop(LoadOps.CONST, tuple(), dtype or _from_np_dtype(data.dtype), device, data.item())
+      else: data = _fromnp(data.astype(npdtype) if dtype is not None and (npdtype:=_to_np_dtype(dtype)) is not None else data)
+
+    # by this point, it has to be a LazyBuffer
+    if not isinstance(data, (LazyBuffer, MultiLazyBuffer)):
+      raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}")
 
     # data is a LazyBuffer, but it might be on the wrong device
-    if not isinstance(data, (LazyBuffer, MultiLazyBuffer)): raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}")
     if isinstance(device, tuple):
-      #
-
+      # if device is a tuple, we should have/construct a MultiLazyBuffer
+      if isinstance(data, MultiLazyBuffer):
+        assert data.device == device, f"MultiLazyBuffer device mismatch, {data.device} != {device}"
+        self.lazydata: Union[LazyBuffer, MultiLazyBuffer] = data
+      else:
+        self.lazydata = MultiLazyBuffer.from_sharded(data, device, None)
     else:
       self.lazydata = data if data.device == device else data.copy_to_device(device)
 
+  class train(ContextDecorator):
+    def __init__(self, mode:bool = True): self.mode = mode
+    def __enter__(self): self.prev, Tensor.training = Tensor.training, self.mode
+    def __exit__(self, exc_type, exc_value, traceback): Tensor.training = self.prev
+
+  class inference_mode(ContextDecorator):
+    def __init__(self, mode:bool = True): self.mode = mode
+    def __enter__(self): self.prev, Tensor.no_grad = Tensor.no_grad, self.mode
+    def __exit__(self, exc_type, exc_value, traceback): Tensor.no_grad = self.prev
+
   def __repr__(self):
-    return f"<Tensor {self.lazydata!r} on {self.device} with grad {(self.grad.lazydata if self.grad else None)!r}>"
+    return f"<Tensor {self.lazydata!r} on {self.device} with grad {(self.grad.lazydata if self.grad is not None else None)!r}>"
 
   # Python has a non moving GC, so this should be okay
   def __hash__(self): return id(self)
 
+  def __bool__(self): raise TypeError("__bool__ on Tensor is not defined")
+
+  def __len__(self):
+    if not self.shape: raise TypeError("len() of a 0-d tensor")
+    return self.shape[0]
+
   @property
   def device(self) -> Union[str, Tuple[str, ...]]: return self.lazydata.device
 
@@ -104,148 +182,522 @@ class Tensor:
 
   # ***** data handlers ****
 
-
-
-
+  def schedule_with_vars(self, *lst:Tensor, seen:Optional[Set[LazyBuffer]]=None) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
+    """Creates the schedule needed to realize these Tensor(s), with Variables."""
+    if getenv("FUZZ_SCHEDULE"):
+      from test.external.fuzz_schedule import fuzz_schedule
+      fuzz_schedule(flatten([x.lazydata.lbs for x in (self,)+lst]))
+    schedule, var_vals = create_schedule_with_vars(flatten([x.lazydata.lbs for x in (self,)+lst]), seen)
+    return memory_planner(schedule), var_vals
+
+  def schedule(self, *lst:Tensor, seen:Optional[Set[LazyBuffer]]=None) -> List[ScheduleItem]:
+    """Creates the schedule needed to realize these Tensor(s)."""
+    schedule, var_vals = self.schedule_with_vars(*lst, seen=seen)
+    assert len(var_vals) == 0
+    return schedule
+
+  def realize(self, *lst:Tensor, do_update_stats=True) -> Tensor:
+    """Triggers the computation needed to create these Tensor(s)."""
+    run_schedule(*self.schedule_with_vars(*lst), do_update_stats=do_update_stats)
+    return self
 
-  def
-
+  def replace(self, x:Tensor) -> Tensor:
+    """
+    Replaces the data of this tensor with the data of another tensor. Only the shape of the tensors must match.
+    """
+    # used for replacing a Tensor with a new version of it (potentially with a different device and dtype)
+    assert not x.requires_grad and getattr(self, '_ctx', None) is None
+    assert self.shape == x.shape, f"replace shape mismatch {self.shape} != {x.shape}"
+    self.lazydata = x.lazydata
     return self
 
   def assign(self, x) -> Tensor:
     # TODO: this is a hack for writing to DISK. remove with working assign
     if isinstance(self.device, str) and self.device.startswith("DISK"):
-      if x.__class__ is not Tensor: x = Tensor(x, device="
+      if x.__class__ is not Tensor: x = Tensor(x, device="NPY", dtype=self.dtype)
       self.contiguous().realize().lazydata.base.realized.copyin(x.numpy().data)
       return self
     if x.__class__ is not Tensor: x = Tensor(x, device=self.device, dtype=self.dtype)
+    if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}")
+    if self.lazydata is x.lazydata: return self # a self assign is a NOOP
     # NOTE: we allow cross device assign
     assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}"
+    assert self.device == x.device, f"assign device mismatch {self.device} != {x.device}"
+    assert self.dtype == x.dtype, f"assign dtype mismatch {self.dtype} != {x.dtype}"
+    assert not isinstance(self.lazydata, MultiLazyBuffer) or self.lazydata.axis == x.lazydata.axis, "axis must match on MultiLazyBuffer"
    assert not x.requires_grad # self requires_grad is okay?
-    if
-
-    if isinstance(self.lazydata, MultiLazyBuffer):
-      for d,s in zip(x.lazydata.lbs, self.lazydata.lbs): d.output_buffer = s.base.realized
-    else:
-      if self.lazydata.base.realized is not None: x.lazydata.output_buffer = self.lazydata.base.realized
-    self.lazydata = x.lazydata
+    if not self.lazydata.is_realized(): return self.replace(x)
+    self.lazydata = self.lazydata.assign(x.lazydata)
     return self
-  def detach(self) -> Tensor: return Tensor(self.lazydata, device=self.device, requires_grad=False)
 
-
-
+  def detach(self) -> Tensor:
+    """
+    Returns a new tensor with the same data as this tensor, but detached from the autograd graph.
+    """
+    return Tensor(self.lazydata, device=self.device, requires_grad=False)
+
+  def _data(self) -> memoryview:
+    if 0 in self.shape: return memoryview(bytearray(0))
+    # NOTE: this realizes on the object from as_buffer being a Python object
+    cpu = self.cast(self.dtype.scalar()).contiguous().to("CLANG").realize()
+    buf = cast(Buffer, cast(LazyBuffer, cpu.lazydata).base.realized)
+    if self.device != "CLANG": buf.options = BufferOptions(nolru=True)
+    return buf.as_buffer(allow_zero_copy=True if self.device != "CLANG" else False)
+
+  def data(self) -> memoryview:
+    """
+    Returns the data of this tensor as a memoryview.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([1, 2, 3, 4])
+    print(np.frombuffer(t.data(), dtype=np.int32))
+    ```
+    """
+    assert self.dtype.fmt is not None, f"no fmt dtype for {self.dtype}"
+    assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
+    return self._data().cast(self.dtype.fmt, self.shape)
+
+  def item(self) -> ConstType:
+    """
+    Returns the value of this tensor as a standard Python number.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor(42)
+    print(t.item())
+    ```
+    """
+    assert self.dtype.fmt is not None, f"no fmt dtype for {self.dtype}"
     assert self.numel() == 1, "must have one element for item"
-    return
-
+    return self._data().cast(self.dtype.fmt)[0]
+
+  # TODO: should be Tensor.tolist() -> Union[List[ConstType], ConstType]. The List is Sequence because mypy expects memoryview.tolist() -> list[int]
+  # src: https://github.com/python/mypy/blob/release-1.6/mypy/typeshed/stdlib/builtins.pyi#L803
+  def tolist(self) -> Union[Sequence[ConstType], ConstType]:
+    """
+    Returns the value of this tensor as a nested list.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([1, 2, 3, 4])
+    print(t.tolist())
+    ```
+    """
+    return self.data().tolist()
 
-  # TODO: this should import numpy and use .data() to construct the array
   def numpy(self) -> np.ndarray:
-
-
-
-
-
-
-
-
-
-
+    """
+    Returns the value of this tensor as a `numpy.ndarray`.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([1, 2, 3, 4])
+    print(repr(t.numpy()))
+    ```
+    """
+    if self.dtype == dtypes.bfloat16: return self.float().numpy()
+    assert _to_np_dtype(self.dtype) is not None, f"no np dtype for {self.dtype}"
+    assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
+    return np.frombuffer(self._data(), dtype=_to_np_dtype(self.dtype)).reshape(self.shape)
+
+  def to(self, device:Optional[Union[str, Tuple[str, ...]]]) -> Tensor:
+    """
+    Moves the tensor to the given device.
+    """
+    device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device)
+    if device == self.device: return self
+    if not isinstance(device, str): return self.shard(device)
+    ret = Tensor(self.lazydata, device, requires_grad=self.requires_grad)
+    if self.grad is not None: ret.grad = self.grad.to(device)
+    if hasattr(self, '_ctx'): ret._ctx = self._ctx
     return ret
 
-  def to_(self, device:Optional[str]):
-
-
-
-
+  def to_(self, device:Optional[Union[str, Tuple[str, ...]]]):
+    """
+    Moves the tensor to the given device in place.
+    """
+    real = self.to(device)
+    # TODO: is this assign?
+    if self.grad is not None and real.grad is not None: self.grad.lazydata = real.grad.lazydata
+    self.lazydata = real.lazydata
 
   def shard(self, devices:Tuple[str, ...], axis:Optional[int]=None) -> Tensor:
+    """
+    Shards the tensor across the given devices.
+    """
     assert isinstance(self.lazydata, LazyBuffer), "can't shard a MultiLazyBuffer"
     canonical_devices = tuple(Device.canonicalize(x) for x in devices)
-
+    if axis is not None and axis < 0: axis += len(self.shape)
+    return Tensor(MultiLazyBuffer.from_sharded(self.lazydata, canonical_devices, axis), device=canonical_devices, requires_grad=self.requires_grad)
 
   def shard_(self, devices:Tuple[str, ...], axis:Optional[int]=None):
+    """
+    Shards the tensor across the given devices in place.
+    """
     self.lazydata = self.shard(devices, axis).lazydata
     return self
 
+  @staticmethod
+  def from_node(y:Node, **kwargs) -> Tensor:
+    if isinstance(y, NumNode): return Tensor(y.b, **kwargs, requires_grad=False)
+    if isinstance(y, Variable): return Tensor(y, **kwargs, requires_grad=False)
+    if isinstance(y, MulNode): return Tensor.from_node(y.a, **kwargs) * y.b
+    if isinstance(y, SumNode): return Tensor.from_node(y.nodes[0], **kwargs) + sum(y.nodes[1:])
+    raise RuntimeError(f"unhandled Node {y}")
+
   # ***** creation llop entrypoint *****
 
   @staticmethod
-  def _loadop(op, shape, device:Optional[str]=None, dtype:Optional[DType]=None, arg=None, **kwargs):
-
+  def _loadop(op, shape, device:Optional[Union[Tuple[str, ...], str]]=None, dtype:Optional[DType]=None, arg=None, **kwargs):
+    if isinstance(device, tuple):
+      return Tensor(MultiLazyBuffer([LazyBuffer.loadop(op, shape, dtype or dtypes.default_float, Device.canonicalize(d), arg) \
+        for d in device], None), device, dtype, **kwargs)
+    return Tensor(LazyBuffer.loadop(op, shape, dtype or dtypes.default_float, Device.canonicalize(device), arg), device, dtype, **kwargs)
 
   @staticmethod
-  def empty(*shape, **kwargs):
+  def empty(*shape, **kwargs):
+    """
+    Creates an empty tensor with the given shape.
+
+    You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
+    Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.empty(2, 3)
+    print(t.shape)
+    ```
+    """
+    return Tensor._loadop(LoadOps.EMPTY, argfix(*shape), **kwargs)
 
   _seed: int = int(time.time())
+  _rng_counter: Optional[Tensor] = None
   @staticmethod
-  def manual_seed(seed=0):
+  def manual_seed(seed=0):
+    """
+    Sets the seed for random operations.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    Tensor.manual_seed(42)
+    print(Tensor.rand(5).numpy())
+    print(Tensor.rand(5).numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    Tensor.manual_seed(42) # reset to the same seed
+    print(Tensor.rand(5).numpy())
+    print(Tensor.rand(5).numpy())
+    ```
+    """
+    Tensor._seed, Tensor._rng_counter = seed, Tensor([0], dtype=dtypes.uint32, requires_grad=False)
 
   @staticmethod
-  def rand(*shape,
+  def rand(*shape, device:Optional[Union[Tuple[str, ...], str]]=None, dtype:Optional[DType]=None, **kwargs):
+    """
+    Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[0, 1)`.
+
+    You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
+    Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    Tensor.manual_seed(42)
+    t = Tensor.rand(2, 3)
+    print(t.numpy())
+    ```
+    """
+    if Tensor._rng_counter is None: Tensor._rng_counter = Tensor([0], dtype=dtypes.uint32, requires_grad=False)
+    if not THREEFRY.value:
+      # for bfloat16, numpy rand passes buffer in float
+      if (dtype or dtypes.default_float) == dtypes.bfloat16:
+        return Tensor.rand(*shape, **kwargs, device=device, dtype=dtypes.float).cast(dtypes.bfloat16)
+      return Tensor._loadop(LoadOps.CUSTOM, argfix(*shape), arg=custom_random, device=device, dtype=dtype, **kwargs)
+
+    # threefry
+    if (num := prod((shape:=argfix(*shape)))) == 0: return Tensor.zeros(shape, device=device, dtype=dtype, **kwargs)
+    counts1 = (Tensor.arange(math.ceil(num / 2), device=device, dtype=dtypes.uint32, requires_grad=False)+Tensor._rng_counter.to(device)).realize()
+    counts2 = counts1 + math.ceil(num / 2)
+    Tensor._rng_counter.assign(Tensor._rng_counter + num).realize()
+
+    rotations = [[13, 15, 26, 6], [17, 29, 16, 24]]
+    ks = [0x0, Tensor._seed ^ 0x0 ^ 0x1BD11BDA, Tensor._seed]
+
+    x = [counts1 + ks[-1], counts2 + ks[0]]
+    for i in range(5):
+      for r in rotations[i % 2]: x[0], x[1] = (x0 := x[0] + x[1]), x0 ^ ((x[1] << r) + (x[1] >> (32 - r)))
+      x = [(x[0] + ks[i % 3]), (x[1] + ks[(i + 1) % 3] + i + 1)]
+    out = x[0].cat(x[1]).rshift(8).cast(dtypes.float32).div(2 ** 24)[:num]
+    out = out.reshape(shape).cast(dtypes.default_float if dtype is None else dtype)
+    out.requires_grad = kwargs.get("requires_grad")
+    return out.contiguous()
 
   # ***** creation helper functions *****
 
   @staticmethod
-  def full(shape:Tuple[sint, ...], fill_value:
+  def full(shape:Tuple[sint, ...], fill_value:ConstType, **kwargs):
+    """
+    Creates a tensor with the given shape, filled with the given value.
+
+    You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
+    Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor.full((2, 3), 42).numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor.full((2, 3), False).numpy())
+    ```
+    """
     return Tensor(fill_value, **kwargs).reshape((1, )*len(new_shape := argfix(shape))).expand(new_shape)
 
   @staticmethod
-  def zeros(*shape, **kwargs):
+  def zeros(*shape, **kwargs):
+    """
+    Creates a tensor with the given shape, filled with zeros.
+
+    You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
+    Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor.zeros(2, 3).numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor.zeros(2, 3, dtype=dtypes.int32).numpy())
+    ```
+    """
+    return Tensor.full(argfix(*shape), 0.0, **kwargs)
 
   @staticmethod
-  def ones(*shape, **kwargs):
+  def ones(*shape, **kwargs):
+    """
+    Creates a tensor with the given shape, filled with ones.
+
+    You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
+    Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor.ones(2, 3).numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor.ones(2, 3, dtype=dtypes.int32).numpy())
+    ```
+    """
+    return Tensor.full(argfix(*shape), 1.0, **kwargs)
 
   @staticmethod
   def arange(start, stop=None, step=1, **kwargs):
+    """
+    Returns a 1-D tensor of size `ceil((stop - start) / step)` with values from `[start, stop)`, with spacing between values given by `step`.
+
+    If `stop` is not specified, values are generated from `[0, start)` with the given `step`.
+
+    If `stop` is specified, values are generated from `[start, stop)` with the given `step`.
+
+    You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
+    Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor.arange(5).numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor.arange(5, 10).numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor.arange(5, 10, 2).numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor.arange(5.5, 10, 2).numpy())
+    ```
+    """
     if stop is None: stop, start = start, 0
+    assert all(isinstance(s, (int, float)) for s in (start, stop, step)), f"symbolic arange not supported {start=}, {stop=}, {step=}"
     dtype = kwargs.pop("dtype", dtypes.default_float if any(isinstance(x, float) for x in (start, stop, step)) else dtypes.default_int)
-    return (Tensor.full((math.ceil((stop-start)/step),), step, dtype=dtype, **kwargs).
+    return (Tensor.full((math.ceil((stop-start)/step),), step, dtype=dtype, **kwargs)._cumsum() + (start - step)).cast(dtype)
 
   @staticmethod
   def eye(dim:int, **kwargs):
+    """
+    Creates an identity matrix of the given dimension.
+
+    You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
+    Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor.eye(3).numpy())
+    ```
+    """
     return Tensor.ones((dim,1),**kwargs).pad((None,(0,dim))).flatten().shrink(((0,dim*dim),)).reshape(dim, dim)
 
-  def full_like(self, fill_value:
+  def full_like(self, fill_value:ConstType, **kwargs):
+    """
+    Creates a tensor with the same shape as `self`, filled with the given value.
+    If `dtype` is not specified, the dtype of `self` is used.
+
+    You can pass in the `device` keyword argument to control device of the tensor.
+    Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.ones(2, 3)
+    print(Tensor.full_like(t, 42).numpy())
+    ```
+    """
     return Tensor.full(self.shape, fill_value, dtype=kwargs.pop("dtype", self.dtype), device=kwargs.pop("device", self.device), **kwargs)
-
-  def
+
+  def zeros_like(self, **kwargs):
+    """
+    Creates a tensor with the same shape as `self`, filled with zeros.
+
+    You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
+    Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.ones(2, 3)
+    print(Tensor.zeros_like(t).numpy())
+    ```
+    """
+    return self.full_like(0, **kwargs)
+
+  def ones_like(self, **kwargs):
+    """
+    Creates a tensor with the same shape as `self`, filled with ones.
+
+    You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
+    Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.zeros(2, 3)
+    print(Tensor.ones_like(t).numpy())
+    ```
+    """
+    return self.full_like(1, **kwargs)
 
   # ***** rng hlops *****
 
   @staticmethod
   def randn(*shape, dtype:Optional[DType]=None, **kwargs) -> Tensor:
+    """
+    Creates a tensor with the given shape, filled with random values from a normal distribution with mean `0` and standard deviation `1`.
+    If `dtype` is not specified, the default type is used.
+
+    You can pass in the `device` keyword argument to control device of the tensor.
+    Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    Tensor.manual_seed(42)
+    print(Tensor.randn(2, 3).numpy())
+    ```
+    """
     # https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
-    src = Tensor.rand((2, *argfix(*shape)), **kwargs)
+    src = Tensor.rand((2, *argfix(*shape)), **{**kwargs, "dtype": dtypes.float32})
     return src[0].mul(2*math.pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(dtype or dtypes.default_float)
 
   @staticmethod
-  def randint(*shape, low=0, high=10, **kwargs) -> Tensor:
+  def randint(*shape, low=0, high=10, **kwargs) -> Tensor:
+    """
+    Creates a tensor with the given shape, filled with random integer values generated uniformly from the interval `[low, high)`.
+    If `dtype` is not specified, the default type is used.
+
+    You can pass in the `device` keyword argument to control device of the tensor.
+    Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    Tensor.manual_seed(42)
+    print(Tensor.randint(2, 3, low=5, high=10).numpy())
+    ```
+    """
+    if not isinstance(low, int) or not isinstance(high, int): raise TypeError(f"{low=} and {high=} must be integers")
+    dtype = kwargs.pop("dtype", dtypes.int32)
+    if not dtypes.is_int(dtype): raise TypeError(f"{dtype=} must be int")
+    return Tensor.uniform(*shape, low=low, high=high, dtype=dtype, **kwargs)
 
   @staticmethod
-  def normal(*shape, mean=0.0, std=1.0, **kwargs) -> Tensor:
+  def normal(*shape, mean=0.0, std=1.0, **kwargs) -> Tensor:
+    """
+    Creates a tensor with the given shape, filled with random values from a normal distribution with the given `mean` and standard deviation `std`.
+
+    You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
+    Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    Tensor.manual_seed(42)
+    print(Tensor.normal(2, 3, mean=10, std=2).numpy())
+    ```
+    """
+    return (std * Tensor.randn(*shape, **kwargs)) + mean
 
   @staticmethod
   def uniform(*shape, low=0.0, high=1.0, **kwargs) -> Tensor:
+    """
+    Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[low, high)`.
+
+    You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
+    Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    Tensor.manual_seed(42)
+    print(Tensor.uniform(2, 3, low=2, high=10).numpy())
+    ```
+    """
     dtype = kwargs.pop("dtype", dtypes.default_float)
     return ((high-low) * Tensor.rand(*shape, **kwargs)).cast(dtype) + low
 
   @staticmethod
-  def scaled_uniform(*shape, **kwargs) -> Tensor:
+  def scaled_uniform(*shape, **kwargs) -> Tensor:
+    """
+    Creates a tensor with the given shape, filled with random values from a uniform distribution
+    over the interval `[-prod(shape)**-0.5, prod(shape)**-0.5)`.
+
+    You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
+    Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    Tensor.manual_seed(42)
+    print(Tensor.scaled_uniform(2, 3).numpy())
+    ```
+    """
+    return Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul(prod(argfix(*shape))**-0.5)
 
   # https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform
   @staticmethod
   def glorot_uniform(*shape, **kwargs) -> Tensor:
+    """
+    <https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform>
+
+    You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
+    Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    Tensor.manual_seed(42)
+    print(Tensor.glorot_uniform(2, 3).numpy())
+    ```
+    """
     return Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul((6/(argfix(*shape)[0]+prod(argfix(*shape)[1:])))**0.5)
 
   # https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_
   @staticmethod
   def kaiming_uniform(*shape, a:float = 0.01, **kwargs) -> Tensor:
+    """
+    <https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_>
+
+    You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
+    Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    Tensor.manual_seed(42)
+    print(Tensor.kaiming_uniform(2, 3).numpy())
+    ```
+    """
     bound = math.sqrt(3.0) * math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(argfix(*shape)[1:]))
     return Tensor.uniform(*shape, low=-bound, high=bound, **kwargs)
 
   # https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_
   @staticmethod
   def kaiming_normal(*shape, a:float = 0.01, **kwargs) -> Tensor:
+    """
+    <https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_>
+
+    You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
+    Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    Tensor.manual_seed(42)
+    print(Tensor.kaiming_normal(2, 3).numpy())
+    ```
+    """
     std = math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(argfix(*shape)[1:]))
     return Tensor.normal(*shape, mean=0.0, std=std, **kwargs)
 
@@ -254,31 +706,40 @@ class Tensor:
     assert replacement or num_samples == 1, "no replacement only supports num_samples = 1"
     weight = self.unsqueeze(0) if self.ndim == 1 else self
     cdf = (cw := weight.cumsum(1).float()) / cw[:, -1].unsqueeze(1)
-    unif_samples = Tensor.rand(num_samples, cdf.shape[0], 1)
+    unif_samples = Tensor.rand(num_samples, cdf.shape[0], 1, device=self.device)
     indices = (unif_samples.expand((-1, -1, cdf.shape[1])) >= cdf).sum(2).permute((1, 0))
-    return (indices.squeeze(0) if self.ndim == 1 else indices).cast(dtypes.
+    return (indices.squeeze(0) if self.ndim == 1 else indices).cast(dtypes.int32)
 
   # ***** toposort and backward pass *****
 
-  def
-    def
+  def _deepwalk(self):
+    def _walk(node, visited):
       visited.add(node)
       if getattr(node, "_ctx", None):
         for i in node._ctx.parents:
-          if i not in visited:
-
-
-    return _deepwalk(self, set(), [])
+          if i not in visited: yield from _walk(i, visited)
+      yield node
+    return list(_walk(self, set()))
 
   def backward(self) -> Tensor:
+    """
+    Propagates the gradient of a tensor backwards through the computation graph.
+    Must be used on a scalar tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
+    t.sum().backward()
+    print(t.grad.numpy())
+    ```
+    """
     assert self.shape == tuple(), f"backward can only be called for scalar tensors, but it has shape {self.shape})"
 
     # fill in the first grad with one. don't use Tensor.ones because we don't need contiguous
     # this is "implicit gradient creation"
-    self.grad = Tensor(1.0, device=self.device, requires_grad=False)
+    self.grad = Tensor(1.0, dtype=self.dtype, device=self.device, requires_grad=False)
 
-    for t0 in reversed(self.
-
+    for t0 in reversed(self._deepwalk()):
+      if t0.grad is None: raise RuntimeError(f"tensor {t0} has no grad")
       grads = t0._ctx.backward(t0.grad.lazydata)
       grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None
                for g in ([grads] if len(t0._ctx.parents) == 1 else grads)]
@@ -289,290 +750,825 @@ class Tensor:
       del t0._ctx
     return self
 
-  # ***** movement
+  # ***** movement low level ops *****
+
+  def view(self, *shape) -> Tensor:
+    """`.view` is an alias for `.reshape`."""
+    return self.reshape(shape)
 
   def reshape(self, shape, *args) -> Tensor:
-
-
+    """
+    Returns a tensor with the same data as the original tensor but with a different shape.
+    `shape` can be passed as a tuple or as separate arguments.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.arange(6)
+    print(t.reshape(2, 3).numpy())
+    ```
+    """
+    # resolve None and args
+    new_shape = tuple([s if s is not None else self.shape[i] for i,s in enumerate(argfix(shape, *args))])
+    # resolve -1
+    if (c := new_shape.count(-1)) > 1: raise RuntimeError(f"only one dimension can be inferred using -1, getting {new_shape}")
+    if c: new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape])
+    return F.Reshape.apply(self, shape=new_shape) if new_shape != self.shape else self
+
   def expand(self, shape, *args) -> Tensor:
-
-
-
-
+    """
+    Returns a tensor that is expanded to the shape that is specified.
+    Expand can also increase the number of dimensions that a tensor has.
+
+    Passing a `-1` or `None` to a dimension means that its size will not be changed.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([1, 2, 3])
+    print(t.expand(4, -1).numpy())
+    ```
+    """
+    return self._broadcast_to(tuple(from_ if to == -1 or to is None else to for from_, to in zip(*(_pad_left(self.shape, argfix(shape, *args))))))
+
+  def permute(self, order, *args) -> Tensor:
+    """
+    Returns a tensor that is a permutation of the original tensor.
+    The new tensor has the same data as the original tensor but with the dimensions permuted according to the order specified.
+    `order` can be passed as a tuple or as separate arguments.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.arange(6).reshape(2, 3)
+    print(t.numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.permute(1, 0).numpy())
+    ```
+    """
+    order_arg = tuple(self._resolve_dim(x) for x in argfix(order, *args))
+    if sorted(order_arg) != list(range(self.ndim)): raise RuntimeError(f"order is not a valid permutation, getting {order_arg}")
+    return F.Permute.apply(self, order=order_arg)
+
+  def flip(self, axis, *args) -> Tensor:
+    """
+    Returns a tensor that reverses the order of the original tensor along given `axis`.
+    `axis` can be passed as a tuple or as separate arguments.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.arange(6).reshape(2, 3)
+    print(t.numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.flip(0).numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.flip((0, 1)).numpy())
+    ```
+    """
+    axis_arg = tuple(self._resolve_dim(x) for x in argfix(axis, *args))
+    if len(axis_arg) != len(dedup(axis_arg)): raise RuntimeError(f"dim can appear at least once, getting {axis_arg}")
+    return F.Flip.apply(self, axis=axis_arg)
+
   def shrink(self, arg:Tuple[Optional[Tuple[sint, sint]], ...]) -> Tensor:
-
-
+    """
+    Returns a tensor that shrinks the each axis based on input arg.
+    `arg` must have the same length as `self.ndim`.
+    For each axis, it can be `None`, which means no shrink, or a tuple `(start, end)` that works the same as Python slice.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.arange(9).reshape(3, 3)
+    print(t.numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.shrink(((None, (1, 3)))).numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.shrink((((0, 2), (0, 2)))).numpy())
+    ```
+    """
+    if all(x is None or x == (0,s) for x,s in zip(arg, self.shape)): return self
+    return F.Shrink.apply(self, arg=tuple(x if x is not None else (0,s) for x,s in zip(arg, self.shape)))
+
   def pad(self, arg:Tuple[Optional[Tuple[sint, sint]], ...], value:float=0.0) -> Tensor:
+    """
+    Returns a tensor that pads the each axis based on input arg.
+    `arg` must have the same length as `self.ndim`.
+    For each axis, it can be `None`, which means no pad, or a tuple `(pad_before, pad_after)`.
+    If `value` is specified, the tensor is padded with `value` instead of `0.0`.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.arange(6).reshape(2, 3)
+    print(t.numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.pad(((None, (1, 2)))).numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.pad(((None, (1, 2))), -2).numpy())
+    ```
+    """
     if all(x is None or x == (0,0) for x in arg): return self
-    ret =
-    return ret if 0 == value else ret +
-
-  # ***** movement
-
-  #
-  #
-  #
-  #
-  #
-  #
-  #
-  #
-  #
-  #
-  #
-  #
-  #
-  #
-  #
-  #
-  #
-  #
-  #
-  #
-  #
-  #
-  #
-  #
-  #
-  #
-  #
-
+    ret = F.Pad.apply(self, arg=(narg:=tuple(x if x is not None else (0,0) for x in arg)))
+    return ret if 0 == value else ret + F.Pad.apply(Tensor.ones_like(self), arg=narg).where(0, value)
+
+  # ***** movement high level ops *****
+
+  # Supported Indexing Implementations:
+  # 1. Int indexing (no copy)
+  # - for all dims where there's int, shrink -> reshape
+  # - negative indices are taken relative to the end of the sequence, so X[-2] returns the 2nd-to-last element
+  # - X = Tensor.rand(4,5,9); X[2,-2] shrinks the Tensor to X.shrink(((2, 3), (3, 4), (0, 9))) -> X.shape=(1,1,9)
+  # - Then we reshape (collapse) the int dim away such that for X: (1,1,9) -> (9,)
+  # 2. Slice indexing (no copy)
+  # - for all dims where slice is start:end:stride, shrink -> Optional[flip] -> pad -> reshape -> shrink
+  # - first shrink the Tensor to X.shrink(((start, end),))
+  # - then we apply stride through Optional[flip] -> pad -> reshape -> shrink
+  # - flip where dim value is negative
+  # - pad on dims to be multiple of strides, such that reshaping [dim_size_padded] -> [dim_size_padded // stride, stride] is possible
+  # - shrink [dim_size_padded // stride, stride] -> [dim_size_padded // stride, 1]
+  # - reshape [dim_size_padded // stride, 1] -> [dim_size_padded // stride] and now you have your stride
+  # 3. None indexing (no copy)
+  # - reshape (inject) a dim at the dim where there's None
+  # 4. Tensor indexing (copy)
+  # - use Tensor.arange == tensor_index to create masks for dims with Tensors (adds a dim for each mask)
+  # - combine masks together with mul
+  # - apply mask to self by mask * self
+  # - sum reduce away the extra dims added from creating masks
+  # Tiny Things:
+  # 1. Supported indices: Union[int, slice, Tensor, None, List, Tuple, Ellipsis]
+  # - for any list, List[Union[List, Tuple, int]], must have homogeneous shape
+  # - for any tuple, Tuple[Union[List, Tuple, int]], must have homogeneous shape
+  # 2. Bool indexing is not supported
+  # 3. Out of bounds Tensor indexing results in 0
+  # - e.g: Tensor([1, 2, 3])[Tensor([4, 3, 2])] -> [0, 0, 3] index 4 and 3 are out of bounds
+  def __getitem__(self, indices) -> Tensor:
     # 1. indices normalization and validation
     # treat internal tuples and lists as Tensors and standardize indices to list type
-    if isinstance(indices, (
-
-
-    else: indices = [Tensor(list(i), requires_grad=False, device=self.device) if isinstance(i, (tuple, list)) else i for i in indices]
+    if isinstance(indices, list) and all_int(indices): indices = [Tensor(indices, self.device, requires_grad=False)]
+    elif isinstance(indices, (tuple, list)):
+      indices = [Tensor(i, self.device, requires_grad=False) if isinstance(i, (tuple, list)) else i for i in indices]
     else: indices = [indices]
 
+    # turn scalar Tensors into const val for int indexing if possible
+    indices = [self._to_const_val(i) if isinstance(i, Tensor) and i.shape == () else i for i in indices]
+    # move Tensor indices to the same device as self
+    indices = [i.to(self.device) if isinstance(i, Tensor) else i for i in indices]
+
     # filter ellipsis and fill with slice(None) or fill rest of indices with slice(None)
     ellipsis_idx = [dim for dim, i in enumerate(indices) if i is Ellipsis]
     fill_idx = ellipsis_idx[0] if ellipsis_idx else len(indices)
-
-    indices[fill_idx:fill_idx+1] = [slice(None)] * (
+    num_indices = len(indices) - len(ellipsis_idx) - sum(1 for i in indices if i is None)
+    indices[fill_idx:fill_idx+1] = [slice(None)] * (self.ndim - num_indices)
 
     # use Dict[type, List[dimension]] to track elements in indices
     type_dim: DefaultDict[Union[type, None], List[int]] = defaultdict(list)
 
     # record None for dimension injection later and filter None and record rest of indices
     type_dim[None] = [dim for dim, i in enumerate(indices) if i is None]
-
+    tensor_dims = [dim for dim, i in enumerate(indices) if isinstance(i, Tensor)]
+    indices_filtered = [i for i in indices if i is not None]
     for dim,i in enumerate(indices_filtered): type_dim[type(i)].append(dim)
 
-
-
-
-    if
-    if any(isinstance(i, slice) and i.step == 0 for i in indices): raise ValueError('slice step cannot be 0')
-    if num_slices > len(self.shape): raise IndexError(f"too many indices for tensor of dimension {len(self.shape)}")
+    if len(ellipsis_idx) > 1: raise IndexError("indices can only have a single ellipsis ('...')")
+    for index_type in type_dim:
+      if index_type not in [None, int, slice, Tensor]: raise IndexError(f"{index_type=} not supported")
+    if num_indices > self.ndim: raise IndexError(f"too many {num_indices=} for {self.ndim=}")
 
-    # 2. basic indexing (no copy)
-    # currently indices_filtered: Tuple[Union[
-    # turn indices in indices_filtered to Tuple[
+    # 2. basic indexing, uses only movement ops (no copy)
+    # currently indices_filtered: Tuple[Union[int, slice, Tensor], ...]
+    # turn indices in indices_filtered to Tuple[new_slice, strides]
     for dim in type_dim[int]:
       if (index := indices_filtered[dim]) >= (size := self.shape[dim]) or index < -size:
-        raise IndexError(f"{index=} is out of bounds
+        raise IndexError(f"{index=} is out of bounds on {dim=} with {size=}")
       indices_filtered[dim] = ((index, index+1), 1) if index >= 0 else ((size+index, size+index+1), 1)
     for dim in type_dim[slice]:
-
-
-
-
-
-
-
-
+      if (index := indices_filtered[dim]).step == 0: raise ValueError(f"{index=} on {dim=} cannot have 0 as step")
+      s, e, st = index.indices(self.shape[dim])
+      indices_filtered[dim] = ((0, 0) if (st * (e - s)) < 0 else (s, e) if st > 0 else (e+1, s+1), st)
+    # record tensors and skip all Tensor dims for basic indexing
+    tensor_index: List[Tensor] = []
+    for dim in type_dim[Tensor]:
+      tensor_index.append(index := indices_filtered[dim])
+      if not dtypes.is_int(index.dtype): raise IndexError(f"{index.dtype=} on {dim=} is not supported, only int tensor indexing is supported")
+      indices_filtered[dim] = ((0, self.shape[dim]), 1)
+
+    new_slice, strides = ((), ()) if not indices_filtered else zip(*indices_filtered)
+    # flip negative strides
+    ret = self.shrink(new_slice).flip(tuple(i for i, st in enumerate(strides) if st < 0))
+    # handle stride != 1 or -1
+    if any(abs(st) != 1 for st in strides):
       strides = tuple(abs(s) for s in strides)
-
-      ret = ret.
-      ret = ret.
+      # pad shape to multiple of stride
+      ret = ret.pad(tuple((0, round_up(s, st) - s) for s, st in zip(ret.shape, strides)))
+      ret = ret.reshape(tuple(flatten((s // st, st) for s, st in zip(ret.shape, strides))))
|
960
|
+
ret = ret.shrink(tuple(flatten(((0, s), (0, 1)) for s in ret.shape[::2]))).reshape(ret.shape[::2])
|
389
961
|
|
390
962
|
# inject 1 for dim where it's None and collapse dim for int
|
391
963
|
new_shape = list(ret.shape)
|
392
964
|
for dim in type_dim[None]: new_shape.insert(dim, 1)
|
393
|
-
for dim in (dims_collapsed :=
|
394
|
-
assert all_int(new_shape), f"does not support symbolic shape {new_shape}"
|
965
|
+
for dim in (dims_collapsed := tuple(dim + sum(1 for d in type_dim[None] if dim >= d) for dim in reversed(type_dim[int]))): new_shape.pop(dim)
|
395
966
|
|
396
967
|
ret = ret.reshape(new_shape)
|
397
968
|
|
398
969
|
# 3. advanced indexing (copy)
|
399
970
|
if type_dim[Tensor]:
|
971
|
+
# calculate dim of current ret by subtracting dims collapsed and adding dims injected up until tensor_dim
|
972
|
+
def calc_dim(tensor_dim:int) -> int:
|
973
|
+
return tensor_dim - sum(1 for d in dims_collapsed if tensor_dim >= d)
|
974
|
+
|
975
|
+
assert all_int(ret.shape), f"does not support symbolic shape {ret.shape}"
|
976
|
+
# track tensor_dim and tensor_index using a dict
|
977
|
+
# calc_dim to get dim and use that to normalize the negative tensor indices
|
978
|
+
idx: Dict[int,Tensor] = {(dim := calc_dim(td)):(tensor<0).where(ret.shape[dim],0) + tensor for td,tensor in zip(tensor_dims, tensor_index)}
|
979
|
+
|
980
|
+
masks, first_dim, last_dim = [], min(idx.keys()), max(idx.keys())
|
981
|
+
pre_reduce_shape = ret.shape[:first_dim] + (big_shape := _broadcast_shape(*(t.shape for t in idx.values()))) + ret.shape[first_dim:]
|
400
982
|
|
401
|
-
#
|
402
|
-
|
403
|
-
|
404
|
-
|
405
|
-
|
406
|
-
|
407
|
-
|
408
|
-
|
409
|
-
|
410
|
-
|
411
|
-
#
|
412
|
-
|
413
|
-
|
414
|
-
|
415
|
-
first_idx = [idx[0].reshape(*[1]*tdim[0], *[1]*(1 + max_dim - idx[0].ndim), *idx[0].shape, *[1]*(ret.ndim - tdim[0] - 1))]
|
416
|
-
rest_idx = [i.reshape(*[1]*tdim[0], *[1]*(max_dim - i.ndim), *i.shape, *[1]*(ret.ndim - tdim[0] - n)) for n,i in enumerate(idx[1:], 1)]
|
417
|
-
reshaped_idx = first_idx + rest_idx
|
418
|
-
ret = ret.reshape(*ret.shape[:sum_dim[0]+1], *[1]*max_dim, *ret.shape[sum_dim[0]+1:])
|
419
|
-
|
420
|
-
# iteratively eq -> mul -> sum fancy index
|
421
|
-
try:
|
422
|
-
for a,i,sd in zip(arange, reshaped_idx, sum_dim): ret = (a==i).mul(ret).sum(sd)
|
423
|
-
except AssertionError as exc: raise IndexError(f"cannot broadcast with index shapes {', '.join(str(i.shape) for i in idx)}") from exc
|
983
|
+
# create masks
|
984
|
+
for dim, i in idx.items():
|
985
|
+
try: i = i.reshape(i.shape + (1,)*(ret.ndim - first_dim)).expand(pre_reduce_shape)
|
986
|
+
except ValueError as e: raise IndexError(f"cannot broadcast indices: {e}") from e
|
987
|
+
a = Tensor.arange(ret.shape[dim], device=self.device, requires_grad=False).reshape((ret.shape[dim],) + (1,)*(ret.ndim - dim - 1))
|
988
|
+
masks.append(i == a)
|
989
|
+
|
990
|
+
# reduce masks to 1 mask
|
991
|
+
mask: Tensor = functools.reduce(lambda x,y: x.mul(y), masks)
|
992
|
+
|
993
|
+
# inject 1's for the extra dims added in create masks
|
994
|
+
reshape_arg = ret.shape[:first_dim] + (1,) * len(big_shape) + ret.shape[first_dim:]
|
995
|
+
# sum reduce the extra dims introduced in create masks
|
996
|
+
ret = (ret.reshape(reshape_arg) * mask).sum(tuple(i + len(big_shape) for i in idx.keys()), acc_dtype=ret.dtype)
|
424
997
|
|
425
998
|
# special permute case
|
426
|
-
if
|
427
|
-
|
428
|
-
ret = ret.permute(ret_dims[tdim[0]:tdim[0]+max_dim] + ret_dims[:tdim[0]] + ret_dims[tdim[0]+max_dim:])
|
999
|
+
if first_dim != 0 and len(idx) != 1 and tuple(idx.keys()) != tuple(range(first_dim, last_dim+1)):
|
1000
|
+
ret = ret.permute(*range(first_dim, first_dim+len(big_shape)), *range(0, first_dim), *range(first_dim+len(big_shape), ret.ndim))
|
429
1001
|
return ret
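The arange-mask reduction used for Tensor indices above can be checked by hand. A small sketch of the `arange == index` one-hot trick for a 1-D tensor (the names are illustrative):

```python
from tinygrad import Tensor

x = Tensor([10, 20, 30])
idx = Tensor([2, 0])
# one mask row per requested index: compare against positions 0..2, then mask-multiply and sum-reduce
mask = Tensor.arange(3) == idx.reshape(2, 1)    # shape (2, 3)
print((mask * x).sum(axis=1).numpy())           # [30 10], the same as x[idx].numpy()
```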
|
430
1002
|
|
431
|
-
def __setitem__(self,indices,v
|
432
|
-
|
433
|
-
|
434
|
-
|
1003
|
+
def __setitem__(self, indices, v:Union[Tensor, ConstType]) -> None:
|
1004
|
+
if isinstance(self.device, str) and self.device.startswith("DISK"):
|
1005
|
+
self.__getitem__(indices).assign(v)
|
1006
|
+
return
|
1007
|
+
# NOTE: check that setitem target is valid first
|
1008
|
+
assert all(lb.st.contiguous for lb in self.lazydata.lbs), "setitem target needs to be contiguous"
|
1009
|
+
if not isinstance(v, (Tensor, float, int, bool)): raise TypeError(f"can't set a {type(v).__name__} to a Tensor")
|
1010
|
+
if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype)
|
1011
|
+
if self.requires_grad or v.requires_grad: raise NotImplementedError("setitem with requires_grad is not supported")
|
1012
|
+
if isinstance(indices, (Tensor, list)) or (isinstance(indices, tuple) and any(isinstance(i, (Tensor, list)) for i in indices)):
|
1013
|
+
raise NotImplementedError("Advanced indexing setitem is not currently supported")
|
1014
|
+
|
1015
|
+
assign_to = self.realize().__getitem__(indices)
|
1016
|
+
# NOTE: contiguous to prevent const folding.
|
1017
|
+
v = v.cast(assign_to.dtype)._broadcast_to(_broadcast_shape(assign_to.shape, v.shape)).contiguous()
|
1018
|
+
assign_to.assign(v).realize()
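A brief usage sketch of this setitem path, assuming a realized contiguous target and basic (non-Tensor) indices:

```python
from tinygrad import Tensor

t = Tensor.zeros(3, 3).contiguous().realize()  # setitem needs a contiguous target
t[1, :] = 5.0                                  # broadcasts the scalar over the selected row
print(t.numpy())
```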
|
1019
|
+
|
1020
|
+
# NOTE: using _slice is discouraged and things should migrate to pad and shrink
|
1021
|
+
def _slice(self, arg:Sequence[Optional[Tuple[int, sint]]], value:float=0) -> Tensor:
|
435
1022
|
arg_ = tuple(a if a is not None else (0, s) for s,a in zip(self.shape, arg))
|
436
1023
|
padding = tuple((max(0, -l), max(0, r-s)) for s,(l,r) in zip(self.shape, arg_))
|
437
1024
|
return self.pad(padding, value=value).shrink(tuple((l + pl, r + pl) for (l,r),(pl,_) in zip(arg_, padding)))
|
438
1025
|
|
439
|
-
def gather(self:Tensor,
|
440
|
-
|
441
|
-
|
442
|
-
|
443
|
-
|
444
|
-
|
445
|
-
|
446
|
-
|
1026
|
+
def gather(self:Tensor, dim:int, index:Tensor) -> Tensor:
|
1027
|
+
"""
|
1028
|
+
Gathers values along an axis specified by `dim`.
|
1029
|
+
|
1030
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1031
|
+
t = Tensor([[1, 2], [3, 4]])
|
1032
|
+
print(t.numpy())
|
1033
|
+
```
|
1034
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1035
|
+
print(t.gather(1, Tensor([[0, 0], [1, 0]])).numpy())
|
1036
|
+
```
|
1037
|
+
"""
|
1038
|
+
assert index.ndim == self.ndim, f"self.ndim must equal index.ndim, {self.ndim=}, {index.ndim=}"
|
1039
|
+
assert all(s >= i for d,(s,i) in enumerate(zip(self.shape, index.shape)) if d != dim), "requires self.shape[d] >= index.shape[d] for all d != dim"
|
1040
|
+
dim = self._resolve_dim(dim)
|
1041
|
+
index = index.to(self.device)
|
1042
|
+
x = self.shrink(tuple((0, i) if d != dim else None for d,i in enumerate(index.shape))).unsqueeze(-1).transpose(-1, dim)
|
1043
|
+
return ((index.unsqueeze(-1) == Tensor.arange(self.shape[dim], requires_grad=False, device=self.device)) * x).sum(-1, acc_dtype=self.dtype)
|
447
1044
|
|
448
1045
|
def cat(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
|
449
|
-
|
1046
|
+
"""
|
1047
|
+
Concatenates self with other `Tensor` in `args` along an axis specified by `dim`.
|
1048
|
+
All tensors must have the same shape except in the concatenating dimension.
|
1049
|
+
|
1050
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1051
|
+
t0, t1, t2 = Tensor([[1, 2]]), Tensor([[3, 4]]), Tensor([[5, 6]])
|
1052
|
+
print(t0.cat(t1, t2, dim=0).numpy())
|
1053
|
+
```
|
1054
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1055
|
+
print(t0.cat(t1, t2, dim=1).numpy())
|
1056
|
+
```
|
1057
|
+
"""
|
1058
|
+
dim = self._resolve_dim(dim)
|
450
1059
|
assert all(len(y.shape) == len(self.shape) and all(y.shape[i] == s for i,s in enumerate(self.shape) if i != dim) for y in args)
|
451
1060
|
catargs = [self, *args]
|
452
|
-
|
453
|
-
|
454
|
-
shape_cumsum = [0, *itertools.accumulate(shapes)]
|
1061
|
+
cat_dims = [s.shape[dim] for s in catargs]
|
1062
|
+
cat_dim_cumsum = [0, *itertools.accumulate(cat_dims)]
|
455
1063
|
slc:List[List[Optional[Tuple[sint, sint]]]] = [[None for _ in self.shape] for _ in catargs]
|
456
|
-
for
|
457
|
-
return reduce(Tensor.__add__, [arg.pad(tuple(s)) for arg,s in zip(catargs, slc)])
|
458
|
-
|
459
|
-
|
460
|
-
|
461
|
-
|
1064
|
+
for d,k,s in zip(cat_dims, cat_dim_cumsum[:-1], slc): s[dim] = (k, cat_dim_cumsum[-1] - k - d)
|
1065
|
+
return functools.reduce(Tensor.__add__, [arg.pad(tuple(s)) for arg,s in zip(catargs, slc)])
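The implementation above concatenates by padding every argument out to the full output extent and adding. A one-dimensional sketch of that idea:

```python
from tinygrad import Tensor

a, b = Tensor([1, 2]), Tensor([3, 4, 5])
pa = a.pad(((0, 3),))        # [1 2 0 0 0]
pb = b.pad(((2, 0),))        # [0 0 3 4 5]
print((pa + pb).numpy())     # [1 2 3 4 5], the same as a.cat(b).numpy()
```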
|
1066
|
+
|
1067
|
+
def stack(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
|
1068
|
+
"""
|
1069
|
+
Concatenates self with other `Tensor` in `args` along a new dimension specified by `dim`.
|
1070
|
+
|
1071
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1072
|
+
t0, t1, t2 = Tensor([1, 2]), Tensor([3, 4]), Tensor([5, 6])
|
1073
|
+
print(t0.stack(t1, t2, dim=0).numpy())
|
1074
|
+
```
|
1075
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1076
|
+
print(t0.stack(t1, t2, dim=1).numpy())
|
1077
|
+
```
|
1078
|
+
"""
|
462
1079
|
# checks for shapes and number of dimensions delegated to cat
|
463
|
-
return
|
464
|
-
|
465
|
-
def repeat(self, repeats
|
1080
|
+
return self.unsqueeze(dim).cat(*[t.unsqueeze(dim) for t in args], dim=dim)
|
1081
|
+
|
1082
|
+
def repeat(self, repeats, *args) -> Tensor:
|
1083
|
+
"""
|
1084
|
+
Repeats the tensor the given number of times along each dimension specified by `repeats`.
|
1085
|
+
`repeats` can be passed as a tuple or as separate arguments.
|
1086
|
+
|
1087
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1088
|
+
t = Tensor([1, 2, 3])
|
1089
|
+
print(t.repeat(4, 2).numpy())
|
1090
|
+
```
|
1091
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1092
|
+
print(t.repeat(4, 2, 1).shape)
|
1093
|
+
```
|
1094
|
+
"""
|
1095
|
+
repeats = argfix(repeats, *args)
|
466
1096
|
base_shape = (1,) * (len(repeats) - self.ndim) + self.shape
|
467
1097
|
new_shape = [x for b in base_shape for x in [1, b]]
|
468
1098
|
expand_shape = [x for rs in zip(repeats, base_shape) for x in rs]
|
469
1099
|
final_shape = [r*s for r,s in zip(repeats, base_shape)]
|
470
1100
|
return self.reshape(new_shape).expand(expand_shape).reshape(final_shape)
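The reshape -> expand -> reshape chain above is how `repeat` tiles data with movement ops only. A one-dimensional sketch:

```python
from tinygrad import Tensor

t = Tensor([1, 2, 3])
r = t.reshape(1, 3).expand(2, 3).reshape(6)   # interleave a broadcast dim, then flatten it away
print(r.numpy())                              # [1 2 3 1 2 3], the same as t.repeat(2).numpy()
```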
|
471
1101
|
|
1102
|
+
def _resolve_dim(self, dim:int, *, outer:bool=False) -> int:
|
1103
|
+
if not -max(1, self.ndim+outer) <= dim < max(1, self.ndim+outer):
|
1104
|
+
raise IndexError(f"{dim=} out of range {[-max(1, self.ndim+outer), max(1, self.ndim+outer)-1]}")
|
1105
|
+
return dim + self.ndim+outer if dim < 0 else dim
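A quick check of the dim-normalization helper above (an internal helper; the values are illustrative). `outer=True` widens the valid range by one so `unsqueeze` can address the slot just past the last axis:

```python
from tinygrad import Tensor

t = Tensor.ones(2, 3, 4)
print(t._resolve_dim(-1))               # 2
print(t._resolve_dim(-1, outer=True))   # 3, the position unsqueeze(-1) inserts at
```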
|
1106
|
+
|
472
1107
|
def split(self, sizes:Union[int, List[int]], dim:int=0) -> Tuple[Tensor, ...]:
|
1108
|
+
"""
|
1109
|
+
Splits the tensor into chunks along the dimension specified by `dim`.
|
1110
|
+
If `sizes` is an integer, it splits into equally sized chunks if possible, otherwise the last chunk will be smaller.
|
1111
|
+
If `sizes` is a list, it splits into `len(sizes)` chunks with size in `dim` according to `sizes`.
|
1112
|
+
|
1113
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1114
|
+
t = Tensor.arange(10).reshape(5, 2)
|
1115
|
+
print(t.numpy())
|
1116
|
+
```
|
1117
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1118
|
+
split = t.split(2)
|
1119
|
+
print("\\n".join([repr(x.numpy()) for x in split]))
|
1120
|
+
```
|
1121
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1122
|
+
split = t.split([1, 4])
|
1123
|
+
print("\\n".join([repr(x.numpy()) for x in split]))
|
1124
|
+
```
|
1125
|
+
"""
|
473
1126
|
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
|
474
|
-
dim =
|
475
|
-
if isinstance(sizes, int):
|
1127
|
+
dim = self._resolve_dim(dim)
|
1128
|
+
if isinstance(sizes, int): sizes = [min(sizes, self.shape[dim]-i) for i in range(0, max(1, self.shape[dim]), max(1, sizes))]
|
1129
|
+
assert sum(sizes) == self.shape[dim], f"expect sizes to sum exactly to {self.shape[dim]}, but got {sum(sizes)}"
|
476
1130
|
return tuple(self[sl] for sl in [tuple([slice(None)]*dim + [slice(sum(sizes[:i]), sum(sizes[:i + 1]))]) for i in range(len(sizes))])
|
477
1131
|
|
478
|
-
def chunk(self,
|
1132
|
+
def chunk(self, chunks:int, dim:int=0) -> List[Tensor]:
|
1133
|
+
"""
|
1134
|
+
Splits the tensor into `chunks` number of chunks along the dimension `dim`.
|
1135
|
+
If the tensor size along `dim` is not divisible by `chunks`, all returned chunks will be the same size except the last one.
|
1136
|
+
The function may return fewer than the specified number of chunks.
|
1137
|
+
|
1138
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1139
|
+
chunked = Tensor.arange(11).chunk(6)
|
1140
|
+
print("\\n".join([repr(x.numpy()) for x in chunked]))
|
1141
|
+
```
|
1142
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1143
|
+
chunked = Tensor.arange(12).chunk(6)
|
1144
|
+
print("\\n".join([repr(x.numpy()) for x in chunked]))
|
1145
|
+
```
|
1146
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1147
|
+
chunked = Tensor.arange(13).chunk(6)
|
1148
|
+
print("\\n".join([repr(x.numpy()) for x in chunked]))
|
1149
|
+
```
|
1150
|
+
"""
|
479
1151
|
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
|
480
|
-
|
481
|
-
|
482
|
-
return
|
1152
|
+
assert chunks > 0, f"expect chunks to be greater than 0, got: {chunks}"
|
1153
|
+
dim = self._resolve_dim(dim)
|
1154
|
+
return list(self.split(math.ceil(self.shape[dim]/chunks) if self.shape[dim] else [0]*chunks, dim=dim))
|
483
1155
|
|
484
1156
|
def squeeze(self, dim:Optional[int]=None) -> Tensor:
|
1157
|
+
"""
|
1158
|
+
Returns a tensor with the specified dimensions of size 1 removed from the input.
|
1159
|
+
If `dim` is not specified, all dimensions with size 1 are removed.
|
1160
|
+
|
1161
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1162
|
+
t = Tensor.zeros(2, 1, 2, 1, 2)
|
1163
|
+
print(t.squeeze().shape)
|
1164
|
+
```
|
1165
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1166
|
+
print(t.squeeze(0).shape)
|
1167
|
+
```
|
1168
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1169
|
+
print(t.squeeze(1).shape)
|
1170
|
+
```
|
1171
|
+
"""
|
485
1172
|
if dim is None: return self.reshape(tuple(dim for dim in self.shape if dim != 1))
|
486
|
-
|
487
|
-
if not
|
488
|
-
if dim < 0: dim += self.ndim
|
489
|
-
return self if self.shape[dim] != 1 else self.reshape(self.shape[:dim] + self.shape[dim+1:])
|
1173
|
+
dim = self._resolve_dim(dim)
|
1174
|
+
return self if not self.ndim or self.shape[dim] != 1 else self.reshape(self.shape[:dim] + self.shape[dim+1:])
|
490
1175
|
|
491
1176
|
def unsqueeze(self, dim:int) -> Tensor:
|
492
|
-
|
1177
|
+
"""
|
1178
|
+
Returns a tensor with a new dimension of size 1 inserted at the specified `dim`.
|
1179
|
+
|
1180
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1181
|
+
t = Tensor([1, 2, 3, 4])
|
1182
|
+
print(t.unsqueeze(0).numpy())
|
1183
|
+
```
|
1184
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1185
|
+
print(t.unsqueeze(1).numpy())
|
1186
|
+
```
|
1187
|
+
"""
|
1188
|
+
dim = self._resolve_dim(dim, outer=True)
|
493
1189
|
return self.reshape(self.shape[:dim] + (1,) + self.shape[dim:])
|
494
1190
|
|
495
|
-
|
496
|
-
|
1191
|
+
def pad2d(self, padding:Sequence[int], value:float=0.0) -> Tensor:
|
1192
|
+
"""
|
1193
|
+
Returns a tensor that pads the last two axes specified by `padding` (padding_left, padding_right, padding_top, padding_bottom).
|
1194
|
+
If `value` is specified, the tensor is padded with `value` instead of `0.0`.
|
1195
|
+
|
1196
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1197
|
+
t = Tensor.arange(9).reshape(1, 1, 3, 3)
|
1198
|
+
print(t.numpy())
|
1199
|
+
```
|
1200
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1201
|
+
print(t.pad2d((1, 1, 2, 0), value=-float("inf")).numpy())
|
1202
|
+
```
|
1203
|
+
"""
|
497
1204
|
slc = [(-p0, s+p1) for p0,p1,s in zip(padding[::2], padding[1::2], self.shape[::-1])][::-1]
|
498
|
-
return self.
|
1205
|
+
return self._slice([(0,s) for s in self.shape[:-(len(padding)//2)]] + slc, value=value)
|
499
1206
|
|
500
1207
|
@property
|
501
|
-
def T(self) -> Tensor:
|
502
|
-
|
1208
|
+
def T(self) -> Tensor:
|
1209
|
+
"""`.T` is an alias for `.transpose()`."""
|
1210
|
+
return self.transpose()
|
1211
|
+
|
1212
|
+
def transpose(self, dim0=1, dim1=0) -> Tensor:
|
1213
|
+
"""
|
1214
|
+
Returns a tensor that is a transposed version of the original tensor.
|
1215
|
+
The given dimensions `dim0` and `dim1` are swapped.
|
1216
|
+
|
1217
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1218
|
+
t = Tensor.arange(6).reshape(2, 3)
|
1219
|
+
print(t.numpy())
|
1220
|
+
```
|
1221
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1222
|
+
print(t.transpose(0, 1).numpy())
|
1223
|
+
```
|
1224
|
+
"""
|
503
1225
|
order = list(range(self.ndim))
|
504
|
-
order[
|
1226
|
+
order[dim0], order[dim1] = order[dim1], order[dim0]
|
505
1227
|
return self.permute(order)
|
1228
|
+
|
506
1229
|
def flatten(self, start_dim=0, end_dim=-1):
|
507
|
-
|
1230
|
+
"""
|
1231
|
+
Flattens the tensor by reshaping it into a one-dimensional tensor.
|
1232
|
+
If `start_dim` or `end_dim` are passed, only dimensions starting with `start_dim` and ending with `end_dim` are flattened.
|
1233
|
+
|
1234
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1235
|
+
t = Tensor.arange(8).reshape(2, 2, 2)
|
1236
|
+
print(t.flatten().numpy())
|
1237
|
+
```
|
1238
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1239
|
+
print(t.flatten(start_dim=1).numpy())
|
1240
|
+
```
|
1241
|
+
"""
|
1242
|
+
start_dim, end_dim = self._resolve_dim(start_dim), self._resolve_dim(end_dim)
|
508
1243
|
return self.reshape(self.shape[:start_dim] + (prod(self.shape[start_dim:end_dim+1]), ) + self.shape[end_dim+1:])
|
1244
|
+
|
509
1245
|
def unflatten(self, dim:int, sizes:Tuple[int,...]):
|
510
|
-
|
1246
|
+
"""
|
1247
|
+
Unflattens dimension `dim` of the tensor into multiple dimensions specified by `sizes`. `Tensor.flatten()` is the inverse of this function.
|
1248
|
+
|
1249
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1250
|
+
print(Tensor.ones(3, 4, 1).unflatten(1, (2, 2)).shape)
|
1251
|
+
```
|
1252
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1253
|
+
print(Tensor.ones(3, 4, 1).unflatten(1, (-1, 2)).shape)
|
1254
|
+
```
|
1255
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1256
|
+
print(Tensor.ones(5, 12, 3).unflatten(-2, (2, 2, 3, 1, 1)).shape)
|
1257
|
+
```
|
1258
|
+
"""
|
1259
|
+
dim = self._resolve_dim(dim)
|
511
1260
|
return self.reshape(self.shape[:dim] + sizes + self.shape[dim+1:])
|
512
1261
|
|
513
1262
|
# ***** reduce ops *****
|
514
1263
|
|
515
|
-
def _reduce(self, fxn:Type[Function], axis:Optional[Union[int,
|
516
|
-
|
517
|
-
|
518
|
-
|
519
|
-
|
520
|
-
|
521
|
-
ret = fxn.apply(self,
|
522
|
-
return ret if keepdim else ret.reshape(shape
|
523
|
-
|
524
|
-
def sum(self, axis=None, keepdim=False):
|
525
|
-
|
526
|
-
|
527
|
-
|
528
|
-
|
529
|
-
|
530
|
-
|
531
|
-
|
532
|
-
|
533
|
-
|
534
|
-
|
535
|
-
|
536
|
-
|
537
|
-
|
538
|
-
|
539
|
-
|
540
|
-
|
541
|
-
|
542
|
-
|
1264
|
+
def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False) -> Tensor:
|
1265
|
+
if self.ndim == 0:
|
1266
|
+
if axis is not None and any(a not in [-1, 0] for a in fully_flatten([axis])): raise IndexError(f"{axis=} out of range of [-1, 0]")
|
1267
|
+
axis = ()
|
1268
|
+
axis_: Tuple[int, ...] = tuple(range(len(self.shape))) if axis is None else ((axis,) if isinstance(axis, int) else tuple(axis))
|
1269
|
+
axis_ = tuple(self._resolve_dim(x) for x in axis_)
|
1270
|
+
ret = fxn.apply(self, axis=axis_)
|
1271
|
+
return ret if keepdim else ret.reshape(tuple(s for i,s in enumerate(self.shape) if i not in axis_))
|
1272
|
+
|
1273
|
+
def sum(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, acc_dtype:Optional[DType]=None):
|
1274
|
+
"""
|
1275
|
+
Sums the elements of the tensor along the specified axis or axes.
|
1276
|
+
|
1277
|
+
You can pass in `axis` and `keepdim` keyword arguments to control the axis along
|
1278
|
+
which the sum is computed and whether the reduced dimensions are retained.
|
1279
|
+
|
1280
|
+
You can pass in `acc_dtype` keyword argument to control the data type of the accumulation.
|
1281
|
+
If not specified, the accumulation data type is chosen based on the input tensor's data type.
|
1282
|
+
|
1283
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1284
|
+
t = Tensor.arange(6).reshape(2, 3)
|
1285
|
+
print(t.numpy())
|
1286
|
+
```
|
1287
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1288
|
+
print(t.sum().numpy())
|
1289
|
+
```
|
1290
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1291
|
+
print(t.sum(axis=0).numpy())
|
1292
|
+
```
|
1293
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1294
|
+
print(t.sum(axis=1).numpy())
|
1295
|
+
```
|
1296
|
+
"""
|
1297
|
+
ret = self.cast(acc_dtype or sum_acc_dtype(self.dtype))._reduce(F.Sum, axis, keepdim)
|
1298
|
+
return ret.cast(self.dtype) if acc_dtype is None and self.dtype in (dtypes.float16, dtypes.bfloat16) else ret
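A small sketch of the accumulation behaviour described above: half-precision inputs are summed in a wider dtype and cast back, unless `acc_dtype` is passed explicitly:

```python
from tinygrad import Tensor, dtypes

t = Tensor.ones(4, dtype=dtypes.float16)
print(t.sum().dtype)                           # float16: accumulated in float32, then cast back
print(t.sum(acc_dtype=dtypes.float32).dtype)   # float32: kept when acc_dtype is given
```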
|
1299
|
+
|
1300
|
+
def max(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
|
1301
|
+
"""
|
1302
|
+
Returns the maximum value of the tensor along the specified axis or axes.
|
1303
|
+
|
1304
|
+
You can pass in `axis` and `keepdim` keyword arguments to control the axis along
|
1305
|
+
which the maximum is computed and whether the reduced dimensions are retained.
|
1306
|
+
|
1307
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1308
|
+
t = Tensor([[1, 0, 2], [5, 4, 3]])
|
1309
|
+
print(t.numpy())
|
1310
|
+
```
|
1311
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1312
|
+
print(t.max().numpy())
|
1313
|
+
```
|
1314
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1315
|
+
print(t.max(axis=0).numpy())
|
1316
|
+
```
|
1317
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1318
|
+
print(t.max(axis=1, keepdim=True).numpy())
|
1319
|
+
```
|
1320
|
+
"""
|
1321
|
+
return self._reduce(F.Max, axis, keepdim)
|
1322
|
+
|
1323
|
+
def min(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
|
1324
|
+
"""
|
1325
|
+
Returns the minimum value of the tensor along the specified axis or axes.
|
1326
|
+
|
1327
|
+
You can pass in `axis` and `keepdim` keyword arguments to control the axis along
|
1328
|
+
which the minimum is computed and whether the reduced dimensions are retained.
|
1329
|
+
|
1330
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1331
|
+
t = Tensor([[1, 0, 2], [5, 4, 3]])
|
1332
|
+
print(t.numpy())
|
1333
|
+
```
|
1334
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1335
|
+
print(t.min().numpy())
|
1336
|
+
```
|
1337
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1338
|
+
print(t.min(axis=0).numpy())
|
1339
|
+
```
|
1340
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1341
|
+
print(t.min(axis=1, keepdim=True).numpy())
|
1342
|
+
```
|
1343
|
+
"""
|
1344
|
+
return -((-self).max(axis=axis, keepdim=keepdim))
|
1345
|
+
|
1346
|
+
def mean(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
|
1347
|
+
"""
|
1348
|
+
Returns the mean value of the tensor along the specified axis or axes.
|
1349
|
+
|
1350
|
+
You can pass in `axis` and `keepdim` keyword arguments to control the axis along
|
1351
|
+
which the mean is computed and whether the reduced dimensions are retained.
|
1352
|
+
|
1353
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1354
|
+
Tensor.manual_seed(42)
|
1355
|
+
t = Tensor.normal(2, 3, mean=2.5, std=0.5)
|
1356
|
+
print(t.numpy())
|
1357
|
+
```
|
1358
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1359
|
+
print(t.mean().numpy())
|
1360
|
+
```
|
1361
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1362
|
+
print(t.mean(axis=0).numpy())
|
1363
|
+
```
|
1364
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1365
|
+
print(t.mean(axis=1).numpy())
|
1366
|
+
```
|
1367
|
+
"""
|
1368
|
+
output_dtype = self.dtype if dtypes.is_float(self.dtype) else dtypes.float32
|
1369
|
+
numerator = self.cast(sum_acc_dtype(self.dtype)).sum(axis=axis, keepdim=keepdim)
|
1370
|
+
return numerator.div(prod([si for si, so in zip(self.shape, self.sum(axis=axis, keepdim=True).shape) if si != so])).cast(output_dtype)
|
1371
|
+
|
1372
|
+
def var(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, correction=1):
|
1373
|
+
"""
|
1374
|
+
Returns the variance of the tensor along the specified axis or axes.
|
1375
|
+
|
1376
|
+
You can pass in `axis`, `keepdim`, and `correction` keyword arguments to control the axis along
|
1377
|
+
which the variance is computed, whether the reduced dimensions are retained, and the Bessel's correction applied.
|
1378
|
+
|
1379
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1380
|
+
Tensor.manual_seed(42)
|
1381
|
+
t = Tensor.normal(2, 3, mean=2.5, std=0.5)
|
1382
|
+
print(t.numpy())
|
1383
|
+
```
|
1384
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1385
|
+
print(t.var().numpy())
|
1386
|
+
```
|
1387
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1388
|
+
print(t.var(axis=0).numpy())
|
1389
|
+
```
|
1390
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1391
|
+
print(t.var(axis=1).numpy())
|
1392
|
+
```
|
1393
|
+
"""
|
1394
|
+
squares = (self - self.mean(axis=axis, keepdim=True)).square()
|
1395
|
+
n = prod([si for si, so in zip(self.shape, squares.sum(axis=axis, keepdim=True).shape) if si != so])
|
1396
|
+
return squares.sum(axis=axis, keepdim=keepdim).div(max(0, n-correction))
|
1397
|
+
|
1398
|
+
def std(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, correction=1):
|
1399
|
+
"""
|
1400
|
+
Returns the standard deviation of the tensor along the specified axis or axes.
|
1401
|
+
|
1402
|
+
You can pass in `axis`, `keepdim`, and `correction` keyword arguments to control the axis along
|
1403
|
+
which the standard deviation is computed, whether the reduced dimensions are retained, and the Bessel's correction applied.
|
1404
|
+
|
1405
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1406
|
+
Tensor.manual_seed(42)
|
1407
|
+
t = Tensor.normal(2, 3, mean=2.5, std=0.5)
|
1408
|
+
print(t.numpy())
|
1409
|
+
```
|
1410
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1411
|
+
print(t.std().numpy())
|
1412
|
+
```
|
1413
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1414
|
+
print(t.std(axis=0).numpy())
|
1415
|
+
```
|
1416
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1417
|
+
print(t.std(axis=1).numpy())
|
1418
|
+
```
|
1419
|
+
"""
|
1420
|
+
return self.var(axis, keepdim, correction).sqrt()
|
1421
|
+
|
543
1422
|
def _softmax(self, axis):
|
544
1423
|
m = self - self.max(axis=axis, keepdim=True)
|
545
1424
|
e = m.exp()
|
546
1425
|
return m, e, e.sum(axis=axis, keepdim=True)
|
547
1426
|
|
548
1427
|
def softmax(self, axis=-1):
|
1428
|
+
"""
|
1429
|
+
Applies the softmax function to the tensor along the specified axis.
|
1430
|
+
|
1431
|
+
Rescales the elements of the tensor such that they lie in the range [0, 1] and sum to 1.
|
1432
|
+
|
1433
|
+
You can pass in the `axis` keyword argument to control the axis along which the softmax is computed.
|
1434
|
+
|
1435
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1436
|
+
Tensor.manual_seed(42)
|
1437
|
+
t = Tensor.randn(2, 3)
|
1438
|
+
print(t.numpy())
|
1439
|
+
```
|
1440
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1441
|
+
print(t.softmax().numpy())
|
1442
|
+
```
|
1443
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1444
|
+
print(t.softmax(axis=0).numpy())
|
1445
|
+
```
|
1446
|
+
"""
|
549
1447
|
_, e, ss = self._softmax(axis)
|
550
1448
|
return e.div(ss)
|
551
1449
|
|
552
1450
|
def log_softmax(self, axis=-1):
|
1451
|
+
"""
|
1452
|
+
Applies the log-softmax function to the tensor along the specified axis.
|
1453
|
+
|
1454
|
+
The log-softmax function is a numerically stable alternative to the softmax function in log space.
|
1455
|
+
|
1456
|
+
You can pass in the `axis` keyword argument to control the axis along which the log-softmax is computed.
|
1457
|
+
|
1458
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1459
|
+
Tensor.manual_seed(42)
|
1460
|
+
t = Tensor.randn(2, 3)
|
1461
|
+
print(t.numpy())
|
1462
|
+
```
|
1463
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1464
|
+
print(t.log_softmax().numpy())
|
1465
|
+
```
|
1466
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1467
|
+
print(t.log_softmax(axis=0).numpy())
|
1468
|
+
```
|
1469
|
+
"""
|
553
1470
|
m, _, ss = self._softmax(axis)
|
554
1471
|
return m - ss.log()
|
555
1472
|
|
1473
|
+
def logsumexp(self, axis=None, keepdim=False):
|
1474
|
+
"""
|
1475
|
+
Computes the log-sum-exp of the tensor along the specified axis or axes.
|
1476
|
+
|
1477
|
+
The log-sum-exp function is a numerically stable way to compute the logarithm of the sum of exponentials.
|
1478
|
+
|
1479
|
+
You can pass in `axis` and `keepdim` keyword arguments to control the axis along
|
1480
|
+
which the log-sum-exp is computed and whether the reduced dimensions are retained.
|
1481
|
+
|
1482
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1483
|
+
Tensor.manual_seed(42)
|
1484
|
+
t = Tensor.randn(2, 3)
|
1485
|
+
print(t.numpy())
|
1486
|
+
```
|
1487
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1488
|
+
print(t.logsumexp().numpy())
|
1489
|
+
```
|
1490
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1491
|
+
print(t.logsumexp(axis=0).numpy())
|
1492
|
+
```
|
1493
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1494
|
+
print(t.logsumexp(axis=1).numpy())
|
1495
|
+
```
|
1496
|
+
"""
|
1497
|
+
m = self.max(axis=axis, keepdim=True)
|
1498
|
+
return (self - m).exp().sum(axis=axis, keepdim=keepdim).log() + m.squeeze(axis)
|
1499
|
+
|
556
1500
|
def argmax(self, axis=None, keepdim=False):
|
557
|
-
|
558
|
-
|
559
|
-
|
560
|
-
|
1501
|
+
"""
|
1502
|
+
Returns the indices of the maximum value of the tensor along the specified axis.
|
1503
|
+
|
1504
|
+
You can pass in `axis` and `keepdim` keyword arguments to control the axis along
|
1505
|
+
which the maximum is computed and whether the reduced dimensions are retained.
|
1506
|
+
|
1507
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1508
|
+
t = Tensor([[1, 0, 2], [5, 4, 3]])
|
1509
|
+
print(t.numpy())
|
1510
|
+
```
|
1511
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1512
|
+
print(t.argmax().numpy()) # Returns the index of the maximum value in the flattened tensor.
|
1513
|
+
```
|
1514
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1515
|
+
print(t.argmax(axis=0).numpy()) # Returns the indices of the maximum values along axis 0.
|
1516
|
+
```
|
1517
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1518
|
+
print(t.argmax(axis=1).numpy()) # Returns the indices of the maximum values along axis 1.
|
1519
|
+
```
|
1520
|
+
"""
|
1521
|
+
if axis is None: return self.flatten().argmax(0)
|
1522
|
+
axis = self._resolve_dim(axis)
|
561
1523
|
m = self == self.max(axis=axis, keepdim=True)
|
562
1524
|
idx = m * Tensor.arange(self.shape[axis]-1,-1,-1, requires_grad=False, device=self.device).reshape(self.shape[axis], *[1]*(self.ndim-axis-1))
|
563
|
-
return (self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)-1).cast(dtypes.
|
564
|
-
|
1525
|
+
return (self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)-1).cast(dtypes.int32)
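The reversed `arange` above is what makes ties resolve to the first occurrence. A by-hand sketch for a 1-D tensor with a repeated maximum:

```python
from tinygrad import Tensor

t = Tensor([3, 1, 3])
m = t == t.max()                              # [True, False, True]
idx = m * Tensor.arange(2, -1, -1)            # [2, 0, 0]: positions counted from the end
print((t.shape[0] - idx.max() - 1).numpy())   # 0, the first index holding the maximum
```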
|
1526
|
+
|
1527
|
+
def argmin(self, axis=None, keepdim=False):
|
1528
|
+
"""
|
1529
|
+
Returns the indices of the minimum value of the tensor along the specified axis.
|
1530
|
+
|
1531
|
+
You can pass in `axis` and `keepdim` keyword arguments to control the axis along
|
1532
|
+
which the minimum is computed and whether the reduced dimensions are retained.
|
1533
|
+
|
1534
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1535
|
+
t = Tensor([[1, 0, 2], [5, 4, 3]])
|
1536
|
+
print(t.numpy())
|
1537
|
+
```
|
1538
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1539
|
+
print(t.argmin().numpy()) # Returns the index of the minimum value in the flattened tensor.
|
1540
|
+
```
|
1541
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1542
|
+
print(t.argmin(axis=0).numpy()) # Returns the indices of the minimum values along axis 0.
|
1543
|
+
```
|
1544
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1545
|
+
print(t.argmin(axis=1).numpy()) # Returns the indices of the minimum values along axis 1.
|
1546
|
+
```
|
1547
|
+
"""
|
1548
|
+
return (-self).argmax(axis=axis, keepdim=keepdim)
|
565
1549
|
|
566
1550
|
@staticmethod
|
567
|
-
def einsum(formula:str, *raw_xs) -> Tensor:
|
1551
|
+
def einsum(formula:str, *raw_xs, acc_dtype:Optional[DType]=None) -> Tensor:
|
1552
|
+
"""
|
1553
|
+
Sums the product of the elements of the input tensors according to a formula based on the Einstein summation convention.
|
1554
|
+
|
1555
|
+
See: https://pytorch.org/docs/stable/generated/torch.einsum.html
|
1556
|
+
|
1557
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1558
|
+
x = Tensor([[1, 2], [3, 4]])
|
1559
|
+
y = Tensor([[5, 6], [7, 8]])
|
1560
|
+
print(Tensor.einsum("ij,ij->", x, y).numpy())
|
1561
|
+
```
|
1562
|
+
"""
|
568
1563
|
xs:Tuple[Tensor] = argfix(*raw_xs)
|
569
1564
|
formula = formula.replace(" ", "")
|
570
|
-
inputs_str, output = formula.split("->") if "->" in formula else (formula,
|
571
|
-
|
1565
|
+
inputs_str, output = formula.split("->") if "->" in formula else (formula, \
|
1566
|
+
''.join(c for c in sorted(formula) if formula.count(c) == 1 and c.isalpha()))
|
1567
|
+
inputs = inputs_str.split(',')
|
572
1568
|
assert len(xs) == len(inputs), f"number of inputs doesn't match number of operands in formula, expected {len(inputs)}, got {len(xs)}"
|
573
1569
|
|
574
1570
|
# map the value of each letter in the formula
|
575
|
-
letter_val = sorted(merge_dicts([
|
1571
|
+
letter_val = sorted(merge_dicts([dict(zip(letters, tensor.shape)) for letters, tensor in zip(inputs, xs)]).items())
|
576
1572
|
|
577
1573
|
xs_:List[Tensor] = []
|
578
1574
|
lhs = [sorted(enumerate(s), key=lambda e:e[1]) for s in inputs]
|
@@ -580,9 +1576,13 @@ class Tensor:
|
|
580
1576
|
# permute to the sorted letter order, then reshape/expand to create dimensions for the missing letters
|
581
1577
|
xs_.append(x.permute(order).reshape([val if letter in letters else 1 for letter,val in letter_val]).expand([val for _,val in letter_val]))
|
582
1578
|
|
583
|
-
|
1579
|
+
# determine the inverse permutation to revert back to original order
|
1580
|
+
rhs_letter_order = argsort(list(output))
|
1581
|
+
rhs_order = argsort(rhs_letter_order)
|
1582
|
+
|
584
1583
|
# sum over all axes that's not in the output, then permute to the output order
|
585
|
-
return reduce(lambda a,b:a*b, xs_)
|
1584
|
+
return functools.reduce(lambda a,b:a*b, xs_) \
|
1585
|
+
.sum(axis=[axis for axis,(letter,_) in enumerate(letter_val) if letter not in output],acc_dtype=acc_dtype).permute(rhs_order)
|
586
1586
|
|
587
1587
|
# ***** processing ops *****
|
588
1588
|
|
@@ -590,46 +1590,72 @@ class Tensor:
|
|
590
1590
|
assert len(self.shape) >= len(k_), f"can't pool {self.shape} with {k_}"
|
591
1591
|
assert all_int(self.shape) and all_int(k_), f"does not support symbolic {self.shape=}, {k_=}"
|
592
1592
|
s_, d_ = make_pair(stride, len(k_)), make_pair(dilation, len(k_))
|
593
|
-
assert len(k_) == len(s_)
|
594
|
-
|
1593
|
+
assert len(k_) == len(s_) == len(d_), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}"
|
1594
|
+
noop_, i_ = [None] * len(self.shape[:-len(k_)]), self.shape[-len(k_):]
|
1595
|
+
o_ = [math.ceil((i - d * (k-1))/s) for i,d,k,s in zip(i_, d_, k_, s_)]
|
595
1596
|
if any(k > s for k,s in zip(k_, s_)) or any(d != 1 for d in d_):
|
596
|
-
|
597
|
-
|
598
|
-
|
599
|
-
|
600
|
-
|
601
|
-
xup = xup.
|
602
|
-
xup = xup.
|
603
|
-
#
|
604
|
-
|
605
|
-
|
606
|
-
|
607
|
-
|
608
|
-
|
609
|
-
|
610
|
-
xup = self.slice(slc_prefix + [(0,o*s) for o,s in zip(o_, s_)])
|
611
|
-
xup = xup.reshape(*prefix, *flatten(((o, s) for o,s in zip(o_, s_))))
|
612
|
-
xup = xup.slice(slc_prefix + flatten(((0,o), (0,k)) for o,k in zip(o_, k_)))
|
613
|
-
return xup.permute(*range(len(prefix)), *[len(prefix)+i*2 for i in range(len(k_))], *[len(prefix)+i*2+1 for i in range(len(k_))])
|
1597
|
+
# repeats such that we don't need padding
|
1598
|
+
xup = self.repeat([1]*len(noop_) + [math.ceil(k*(i+d) / i) for k,i,d in zip(k_, i_, d_)])
|
1599
|
+
# handle dilation
|
1600
|
+
xup = xup.shrink(tuple(noop_ + [(0,k*(i+d)) for k,i,d in zip(k_, i_, d_)])).reshape(noop_ + flatten((k,i+d) for k,i,d in zip(k_, i_, d_)))
|
1601
|
+
# handle stride
|
1602
|
+
xup = xup.shrink(noop_ + flatten(((0,k), (0,o*s)) for k,o,s in zip(k_, o_, s_))).reshape(noop_ + flatten((k,o,s) for k,o,s in zip(k_, o_, s_)))
|
1603
|
+
xup = xup.shrink(noop_ + flatten(((0,k), (0,o), (0,1)) for k,o in zip(k_, o_))).reshape(noop_ + flatten((k,o) for k,o in zip(k_, o_)))
|
1604
|
+
# permute to move reduce to the end
|
1605
|
+
return xup.permute(*range(len(noop_)), *[len(noop_)+i*2+1 for i in range(len(i_))], *[len(noop_)+i*2 for i in range(len(i_))])
|
1606
|
+
# TODO: once the shapetracker can optimize well, remove this alternative implementation
|
1607
|
+
xup = self.pad(tuple(noop_ + [(0, max(0,o*s-i)) for i,o,s in zip(i_, o_, s_)])).shrink(tuple(noop_ + [(0,o*s) for o,s in zip(o_, s_)]))
|
1608
|
+
xup = xup.reshape(noop_ + flatten(((o,s) for o,s in zip(o_, s_))))
|
1609
|
+
xup = xup.shrink(noop_ + flatten(((0,o), (0,k)) for o,k in zip(o_, k_)))
|
1610
|
+
return xup.permute(*range(len(noop_)), *[len(noop_)+i*2 for i in range(len(i_))], *[len(noop_)+i*2+1 for i in range(len(i_))])
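On a tiny input, the pooling helper above simply windows the data into `(..., out_position, kernel)` blocks for the reduce ops that follow. A sketch using the internal `_pool` helper directly:

```python
from tinygrad import Tensor

t = Tensor.arange(6).reshape(1, 6)
print(t._pool((2,), stride=2).numpy())   # [[[0 1] [2 3] [4 5]]]: 3 windows of kernel size 2
print(t._pool((2,), stride=2).shape)     # (1, 3, 2)
```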
|
614
1611
|
|
615
1612
|
# NOTE: these work for more than 2D
|
616
|
-
def avg_pool2d(self, kernel_size=(2,2), stride=None, dilation=1):
|
617
|
-
|
1613
|
+
def avg_pool2d(self, kernel_size=(2,2), stride=None, dilation=1):
|
1614
|
+
"""
|
1615
|
+
Applies average pooling over a tensor.
|
618
1616
|
|
619
|
-
|
620
|
-
|
621
|
-
|
622
|
-
|
623
|
-
|
624
|
-
|
625
|
-
|
626
|
-
|
627
|
-
|
628
|
-
|
629
|
-
return
|
1617
|
+
NOTE: unlike PyTorch, this implementation is not limited to only 2d pooling and instead works for any number of dimensions.
|
1618
|
+
|
1619
|
+
See: https://paperswithcode.com/method/average-pooling
|
1620
|
+
|
1621
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1622
|
+
t = Tensor.arange(25).reshape(1, 1, 5, 5)
|
1623
|
+
print(t.avg_pool2d().numpy())
|
1624
|
+
```
|
1625
|
+
"""
|
1626
|
+
kernel_size = make_pair(kernel_size)
|
1627
|
+
return self._pool(kernel_size, stride if stride is not None else kernel_size, dilation).mean(axis=tuple(range(-len(kernel_size), 0)))
|
1628
|
+
|
1629
|
+
def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1):
|
1630
|
+
"""
|
1631
|
+
Applies max pooling over a tensor.
|
630
1632
|
|
631
|
-
|
632
|
-
|
1633
|
+
NOTE: unlike PyTorch, this implementation is not limited to only 2d pooling and instead works for any number of dimensions.
|
1634
|
+
|
1635
|
+
See: https://paperswithcode.com/method/max-pooling
|
1636
|
+
|
1637
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1638
|
+
t = Tensor.arange(25).reshape(1, 1, 5, 5)
|
1639
|
+
print(t.max_pool2d().numpy())
|
1640
|
+
```
|
1641
|
+
"""
|
1642
|
+
kernel_size = make_pair(kernel_size)
|
1643
|
+
return self._pool(kernel_size, stride if stride is not None else kernel_size, dilation).max(axis=tuple(range(-len(kernel_size), 0)))
|
1644
|
+
|
1645
|
+
def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, acc_dtype:Optional[DType]=None) -> Tensor:
|
1646
|
+
"""
|
1647
|
+
Applies a convolution over a tensor with a given `weight` and optional `bias`.
|
1648
|
+
|
1649
|
+
NOTE: unlike PyTorch, this implementation is not limited to only 2d convolutions and instead works for any number of dimensions.
|
1650
|
+
|
1651
|
+
See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
|
1652
|
+
|
1653
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1654
|
+
t = Tensor.arange(9).reshape(1, 1, 3, 3)
|
1655
|
+
w = Tensor.ones(1, 1, 2, 2)
|
1656
|
+
print(t.conv2d(w).numpy())
|
1657
|
+
```
|
1658
|
+
"""
|
633
1659
|
(bs,cin_), (cout,cin), HW = self.shape[:2], weight.shape[:2], weight.shape[2:]
|
634
1660
|
assert groups*cin == cin_ and len(self.shape) == len(weight.shape), f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})" # noqa: E501
|
635
1661
|
if isinstance(padding, (tuple,list)): assert len(padding) == 2*len(HW) or len(padding) == len(HW), f"Expected padding of length {2*len(HW)} or {len(HW)}, but got {len(padding)} for tensor of shape {self.shape}" # noqa: E501
|
@@ -638,19 +1664,17 @@ class Tensor:
|
|
638
1664
|
# conv2d is a pooling op (with padding)
|
639
1665
|
x = self.pad2d(padding_)._pool(HW, stride, dilation) # (bs, groups*cin, oy, ox, H, W)
|
640
1666
|
rcout, oyx = cout//groups, x.shape[2:-len(HW)]
|
641
|
-
if not all(x == 3 for x in HW) or stride != 1 or dilation != 1 or not
|
1667
|
+
if not all(x == 3 for x in HW) or stride != 1 or dilation != 1 or not WINO:
|
642
1668
|
# normal conv
|
643
1669
|
x = x.reshape(bs, groups, cin, 1, *oyx, *HW).expand(bs, groups, cin, rcout, *oyx, *HW).permute(0,1,3,*[4+i for i in range(len(oyx))],2,*[4+len(oyx)+i for i in range(len(HW))]) # noqa: E501
|
644
1670
|
|
645
1671
|
# conv! broadcasted to (bs, groups, rcout, *oyx, cin, *HW)
|
646
|
-
ret = (x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW)).sum([-1-i for i in range(1+len(oyx))], keepdim=True).reshape(bs, cout, *oyx) # noqa: E501
|
1672
|
+
ret = (x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW)).sum([-1-i for i in range(1+len(oyx))], keepdim=True, acc_dtype=acc_dtype).reshape(bs, cout, *oyx) # noqa: E501
|
647
1673
|
return ret if bias is None else ret.add(bias.reshape(1, -1, *[1] * len(HW)))
|
648
1674
|
|
649
|
-
# winograd conv 3 kernel f(4x4,3x3) see: http://arxiv.org/abs/1509.09308
|
650
|
-
def apply_matrix(mat, t, dim=0): return t if dim == len(HW) else Tensor.stack([apply_matrix(mat, sum(mm*t[j] for j,mm in enumerate(m) if mm), dim=dim+1) for m in mat]) # noqa: E501
|
651
1675
|
HWI, HWO = (6,) * len(HW), (4,) * len(HW) # F(4x4,3x3) winograd tiles
|
652
|
-
winograd_Bt = [[4, 0, -5, 0, 1, 0], [0, -4, -4, 1, 1, 0], [0, 4, -4, -1, 1, 0], [0, -2, -1, 2, 1, 0], [0, 2, -1, -2, 1, 0], [0, 4, 0, -5, 0, 1]]
|
653
1676
|
winograd_G = [[1/4, 0, 0], [-1/6, -1/6, -1/6], [-1/6, 1/6, -1/6], [1/24, 1/12, 1/6], [1/24, -1/12, 1/6], [0, 0, 1]]
|
1677
|
+
winograd_Bt = [[4, 0, -5, 0, 1, 0], [0, -4, -4, 1, 1, 0], [0, 4, -4, -1, 1, 0], [0, -2, -1, 2, 1, 0], [0, 2, -1, -2, 1, 0], [0, 4, 0, -5, 0, 1]]
|
654
1678
|
winograd_At = [[1, 1, 1, 1, 1, 0], [0, 1, -1, 2, -2, 0], [0, 1, 1, 4, 4, 0], [0, 1, -1, 8, -8, 1]] # applying At in pre-order doubles compile time
|
655
1679
|
|
656
1680
|
# todo: stride == dilation
|
@@ -658,19 +1682,19 @@ class Tensor:
|
|
658
1682
|
# (bs, cin_, tyx, HWI)
|
659
1683
|
d = self.pad2d(sum([[padding_[i*2], padding_[i*2+1] + (-(dim + sum(padding_[i * 2:(i + 1) * 2]) - 2) % 4)] for i, dim in enumerate(self.shape[-len(HW):])], []))._pool(HWI, HWO) # noqa: E501
|
660
1684
|
# move HW to the front: # (HWI, bs, cin_, tyx)
|
661
|
-
d = d.permute(*range(len(d.shape)-len(HW),len(d.shape)), *range(len(d.shape)-len(HW)))
|
1685
|
+
d = d.permute(*range(len(d.shape)-len(HW),len(d.shape)), *range(len(d.shape)-len(HW)))
|
662
1686
|
tyx = d.shape[-len(HWI):] # dim of tiling
|
663
1687
|
|
664
1688
|
g = weight.permute(*range(len(weight.shape)-len(HW),len(weight.shape)), *range(len(weight.shape)-len(HW))) # move HW to the front
|
665
1689
|
|
666
1690
|
# compute 6x6 winograd tiles: GgGt, BtdB
|
667
1691
|
# (HWI, groups * rcout, cin) -> (HWI, bs=1, groups, rcout, cin, tyx=(1,1))
|
668
|
-
gfactors =
|
1692
|
+
gfactors = _apply_winograd_matrix(winograd_G, g, len(HW)).reshape(*HWI, 1, groups, rcout, cin, *([1]*len(tyx)))
|
669
1693
|
# (HWI, bs, cin_, tyx) -> (HWI, bs, groups, 1 ,cin, *tyx)
|
670
|
-
dfactors =
|
1694
|
+
dfactors = _apply_winograd_matrix(winograd_Bt, d, len(HW)).reshape(*HWI, bs, groups, 1, cin, *tyx)
|
671
1695
|
|
672
1696
|
# matmul; sum across cin: (HWI, bs, groups, rcout, *tyx); then HWI -> HWO: (HWO, bs, groups, rcout, *tyx)
|
673
|
-
ret =
|
1697
|
+
ret = _apply_winograd_matrix(winograd_At, (gfactors * dfactors).sum(axis=-1-len(HW), acc_dtype=acc_dtype), len(HW))
|
674
1698
|
|
675
1699
|
# interleave tyx and HWO: (bs, groups, rcout, oy, HO, ox, WO)
|
676
1700
|
ret = ret.permute([*range(len(HW), len(ret.shape)-len(HW)), *[i+o for i in range(len(HW)) for o in [len(ret.shape)-len(HW),0]]])
|
@@ -679,163 +1703,880 @@ class Tensor:
|
|
679
1703
|
|
680
1704
|
return (ret if bias is None else ret.add(bias.reshape(1, -1, *[1 for _ in range(len(HW))]))).contiguous().contiguous_backward()
|
681
1705
|
|
682
|
-
def
|
1706
|
+
def conv_transpose2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, output_padding=0) -> Tensor:
|
1707
|
+
"""
|
1708
|
+
Applies a transposed convolution over a tensor with a given `weight` and optional `bias`.
|
1709
|
+
|
1710
|
+
NOTE: unlike PyTorch, this implementation is not limited to only 2d transposed convolutions and instead works for any number of dimensions.
|
1711
|
+
|
1712
|
+
See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
|
1713
|
+
|
1714
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1715
|
+
t = Tensor.arange(9).reshape(1, 1, 3, 3)
|
1716
|
+
w = Tensor.ones(1, 1, 2, 2)
|
1717
|
+
print(t.conv_transpose2d(w).numpy())
|
1718
|
+
```
|
1719
|
+
"""
|
1720
|
+
x, w = self, weight.unflatten(0, (groups, -1)).transpose(1, 2).flip(*range(3, len(weight.shape)+1))
|
1721
|
+
HW = weight.shape[2:]
|
1722
|
+
stride, dilation, padding, output_padding = [make_pair(x, len(HW)) for x in (stride, dilation, padding, output_padding)]
|
1723
|
+
if any(s>1 for s in stride):
|
1724
|
+
# handle strides: (k) -> reshape -> (k,1) -> pad -> (k,s) -> reshape -> (k*s) -> shrink (k-(s-1))
|
1725
|
+
x = x.reshape(None, None, *flatten((k,1) for k in x.shape[2:]))
|
1726
|
+
x = x.pad((None, None, *flatten((None,(0,s-1)) for s in stride)))
|
1727
|
+
x = x.reshape(None, None, *[k*s for k,s in zip(x.shape[2::2], stride)])
|
1728
|
+
x = x.shrink((None, None, *[(0,k-(s-1)) for k,s in zip(x.shape[2:], stride)]))
|
1729
|
+
padding = flatten((((k-1)*d-p,(k-1)*d-p+op) for k,d,p,op in reversed(list(zip(HW, dilation, padding, output_padding)))))
|
1730
|
+
return x.conv2d(w.flatten(end_dim=1), groups=groups, bias=bias, dilation=dilation, padding=padding)
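The stride handling above upsamples the input by interleaving zeros before running an ordinary convolution. A single-axis sketch of that reshape -> pad -> reshape -> shrink chain (three elements, stride 2):

```python
from tinygrad import Tensor

x = Tensor([[1, 2, 3]])                                # (1, 3): one row, spatial length 3
x = x.reshape(1, 3, 1).pad(((0, 0), (0, 0), (0, 1)))   # (1, 3, 2): pair each value with a zero
x = x.reshape(1, 6).shrink(((0, 1), (0, 5)))           # flatten, then drop the trailing zero
print(x.numpy())                                       # [[1 0 2 0 3]]
```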
|
1731
|
+
|
1732
|
+
def dot(self, w:Tensor, acc_dtype:Optional[DType]=None) -> Tensor:
|
1733
|
+
"""
|
1734
|
+
Performs dot product between two tensors.
|
1735
|
+
|
1736
|
+
You can pass in the optional `acc_dtype` keyword argument to control the data type of the accumulation.
|
1737
|
+
|
1738
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1739
|
+
a = Tensor([[1, 2], [3, 4]])
|
1740
|
+
b = Tensor([[5, 6], [7, 8]])
|
1741
|
+
print(a.dot(b).numpy())
|
1742
|
+
```
|
1743
|
+
"""
|
683
1744
|
n1, n2 = len(self.shape), len(w.shape)
|
684
1745
|
assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D"
|
685
|
-
assert self.shape[-1] == w.shape[-min(n2, 2)], f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({
|
1746
|
+
assert (L:=self.shape[-1]) == (R:=w.shape[-min(n2, 2)]), f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({L} != {R})"
|
686
1747
|
x = self.reshape(*self.shape[0:-1], *[1]*min(n1-1, n2-1, 1), self.shape[-1])
|
687
1748
|
w = w.reshape(*w.shape[0:-2], *[1]*min(n1-1, n2-1, 1), *w.shape[-min(n2, 2):]).transpose(-1, -min(n2, 2))
|
688
|
-
return (x*w).sum(-1).cast(least_upper_dtype(x.dtype, w.dtype))
|
1749
|
+
return (x*w).sum(-1, acc_dtype=acc_dtype).cast(least_upper_dtype(x.dtype, w.dtype) if acc_dtype is None else acc_dtype)
|
1750
|
+
|
1751
|
+
def matmul(self, x:Tensor, reverse=False, acc_dtype:Optional[DType]=None) -> Tensor:
|
1752
|
+
"""
|
1753
|
+
Performs matrix multiplication between two tensors.
|
689
1754
|
|
690
|
-
|
1755
|
+
You can pass in the `reverse` keyword argument to control the order of the matrix multiplication.
|
1756
|
+
You can pass in the optional `acc_dtype` keyword argument to control the data type of the accumulation.
|
1757
|
+
|
1758
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1759
|
+
a = Tensor([[1, 2], [3, 4]])
|
1760
|
+
b = Tensor([[5, 6], [7, 8]])
|
1761
|
+
print(a.matmul(b).numpy())
|
1762
|
+
```
|
1763
|
+
"""
|
1764
|
+
return x.dot(self, acc_dtype=acc_dtype) if reverse else self.dot(x, acc_dtype=acc_dtype)
|
691
1765
|
|
692
1766
|
def _cumsum(self, axis:int=0, _first_zero=False) -> Tensor:
|
693
|
-
|
1767
|
+
assert self.shape[axis] != 0
|
1768
|
+
pl_sz = self.shape[axis] - int(not _first_zero)
|
1769
|
+
return self.transpose(axis,-1).pad2d((pl_sz,-int(_first_zero)))._pool((self.shape[axis],)).sum(-1).transpose(axis,-1)
|
694
1770
|
def cumsum(self, axis:int=0) -> Tensor:
|
1771
|
+
"""
|
1772
|
+
Computes the cumulative sum of the tensor along the specified axis.
|
1773
|
+
|
1774
|
+
You can pass in the `axis` keyword argument to control the axis along which the cumulative sum is computed.
|
1775
|
+
|
1776
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1777
|
+
t = Tensor.ones(2, 3)
|
1778
|
+
print(t.numpy())
|
1779
|
+
```
|
1780
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1781
|
+
print(t.cumsum(1).numpy())
|
1782
|
+
```
|
1783
|
+
"""
|
1784
|
+
axis = self._resolve_dim(axis)
|
1785
|
+
if self.ndim == 0 or 0 in self.shape: return self
|
695
1786
|
# TODO: someday the optimizer will find this on it's own
|
696
1787
|
# for now this is a two stage cumsum
|
697
1788
|
SPLIT = 256
|
698
1789
|
if self.shape[axis] <= SPLIT*2: return self._cumsum(axis)
|
699
1790
|
ret = self.transpose(axis,-1).pad2d((round_up(self.shape[axis], SPLIT)-self.shape[axis], 0))
|
700
1791
|
ret = ret.unflatten(-1, (-1, SPLIT))._cumsum(-1)
|
701
|
-
base_add = ret[..., -1]._cumsum(-1, _first_zero=True)
|
1792
|
+
base_add = ret[..., -1]._cumsum(-1, _first_zero=True)
|
702
1793
|
base_add = base_add.unsqueeze(-1).expand(*base_add.shape, ret.shape[-1])
|
703
1794
|
def fix(x:Tensor): return x.flatten(start_dim=-2)[..., -self.shape[axis]:].transpose(axis,-1)
|
704
1795
|
return fix(ret) + fix(base_add)
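`_cumsum` above turns a cumulative sum into left-pad, pool, then sum, and the two-stage SPLIT path builds on the same idea. A hand-rolled version of the single-stage case:

```python
from tinygrad import Tensor

t = Tensor([1, 2, 3, 4])
# left-pad so window i ends exactly at element i, then sum each length-4 window
windows = t.pad2d((3, 0))._pool((4,))
print(windows.sum(-1).numpy())   # [ 1  3  6 10], the same as t.cumsum().numpy()
```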
|
705
1796
|
|
706
1797
|
@staticmethod
|
707
|
-
def _tri(r:sint, c:sint,
|
708
|
-
assert
|
709
|
-
|
710
|
-
|
711
|
-
|
712
|
-
|
713
|
-
|
714
|
-
|
715
|
-
|
716
|
-
def
|
717
|
-
|
718
|
-
|
719
|
-
|
720
|
-
|
721
|
-
|
722
|
-
|
723
|
-
|
724
|
-
|
725
|
-
|
726
|
-
|
727
|
-
|
728
|
-
|
729
|
-
|
730
|
-
|
731
|
-
|
732
|
-
|
733
|
-
|
734
|
-
|
735
|
-
|
736
|
-
|
737
|
-
|
738
|
-
|
739
|
-
def
|
740
|
-
|
741
|
-
|
742
|
-
|
743
|
-
|
744
|
-
|
745
|
-
|
746
|
-
|
747
|
-
|
748
|
-
|
749
|
-
|
750
|
-
|
751
|
-
|
752
|
-
|
753
|
-
|
754
|
-
|
755
|
-
|
756
|
-
|
757
|
-
|
758
|
-
|
759
|
-
|
760
|
-
|
761
|
-
|
762
|
-
|
763
|
-
|
764
|
-
|
765
|
-
|
766
|
-
|
767
|
-
|
1798
|
+
def _tri(r:sint, c:sint, diagonal:int=0, **kwargs) -> Tensor:
|
1799
|
+
assert isinstance(r, int) and isinstance(c, int), f"does not support symbolic, getting {r=}, {c=}"
|
1800
|
+
if r == 0 or c == 0 or diagonal >= c: return Tensor.zeros(r,c,**kwargs)
|
1801
|
+
if r+diagonal <= 0: return Tensor.ones(r,c,**kwargs)
|
1802
|
+
s = r+c-1
|
1803
|
+
# build a (s, s) upper triangle
|
1804
|
+
t = Tensor.ones(s,s,**kwargs).pad((None,(0,s))).flatten().shrink(((0,s*(2*s-1)),)).reshape(s,-1).shrink((None,(0,s)))
|
1805
|
+
return t[:r,-diagonal:c-diagonal] if diagonal <= 0 else t[diagonal:r+diagonal,:c]
|
1806
|
+
|
1807
|
+
def triu(self, diagonal:int=0) -> Tensor:
|
1808
|
+
"""
|
1809
|
+
Returns the upper triangular part of the tensor, the other elements are set to 0.
|
1810
|
+
|
1811
|
+
The argument `diagonal` determines which diagonal is on the boundary. `diagonal = 0` means the main diagonal.
|
1812
|
+
Positive `diagonal` means above the main diagonal, and negative `diagonal` means below the main diagonal.
|
1813
|
+
|
1814
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1815
|
+
t = Tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
|
1816
|
+
print(t.numpy())
|
1817
|
+
```
|
1818
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1819
|
+
print(t.triu(diagonal=0).numpy())
|
1820
|
+
```
|
1821
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1822
|
+
print(t.triu(diagonal=1).numpy())
|
1823
|
+
```
|
1824
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1825
|
+
print(t.triu(diagonal=-1).numpy())
|
1826
|
+
```
|
1827
|
+
"""
|
1828
|
+
return Tensor._tri(self.shape[-2], self.shape[-1], diagonal=diagonal, device=self.device, dtype=dtypes.bool).where(self, 0).cast(self.dtype)
|
1829
|
+
|
1830
|
+
def tril(self, diagonal:int=0) -> Tensor:
|
1831
|
+
"""
|
1832
|
+
Returns the lower triangular part of the tensor, the other elements are set to 0.
|
1833
|
+
|
1834
|
+
The argument `diagonal` determines which diagonal is on the boundary. `diagonal = 0` means the main diagonal.
|
1835
|
+
Positive `diagonal` means above the main diagonal, and negative `diagonal` means below the main diagonal.
|
1836
|
+
|
1837
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1838
|
+
t = Tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
|
1839
|
+
print(t.numpy())
|
1840
|
+
```
|
1841
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1842
|
+
print(t.tril(diagonal=0).numpy())
|
1843
|
+
```
|
1844
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1845
|
+
print(t.tril(diagonal=1).numpy())
|
1846
|
+
```
|
1847
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1848
|
+
print(t.tril(diagonal=-1).numpy())
|
1849
|
+
```
|
1850
|
+
"""
|
1851
|
+
return Tensor._tri(self.shape[-2], self.shape[-1], diagonal=diagonal+1, device=self.device, dtype=dtypes.bool).where(0, self).cast(self.dtype)
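As a rough cross-check of how `_tri` is used by `triu` and `tril` above (illustrative NumPy only, not tinygrad code): `triu` keeps the masked branch, while `tril` shifts the mask down one diagonal and swaps the branches:

```python
import numpy as np

t = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
upper = np.triu(np.ones(t.shape, dtype=bool), k=0)                 # stand-in for _tri(rows, cols, diagonal=0)
print(np.where(upper, t, 0))                                       # triu: keep elements on/above the diagonal
print(np.where(np.triu(np.ones(t.shape, dtype=bool), k=1), 0, t))  # tril: mask shifted by one, branches swapped
```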
|
1852
|
+
|
1853
|
+
# ***** unary ops *****
|
1854
|
+
|
1855
|
+
def logical_not(self):
|
1856
|
+
"""
|
1857
|
+
Computes the logical NOT of the tensor element-wise.
|
1858
|
+
|
1859
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1860
|
+
print(Tensor([False, True]).logical_not().numpy())
|
1861
|
+
```
|
1862
|
+
"""
|
1863
|
+
return F.Neq.apply(*self.cast(dtypes.bool)._broadcasted(True))
|
1864
|
+
def neg(self):
|
1865
|
+
"""
|
1866
|
+
Negates the tensor element-wise.
|
1867
|
+
|
1868
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1869
|
+
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).neg().numpy())
|
1870
|
+
```
|
1871
|
+
"""
|
1872
|
+
return F.Neg.apply(self) if self.dtype != dtypes.bool else self.logical_not()
|
1873
|
+
def contiguous(self):
|
1874
|
+
"""
|
1875
|
+
Returns a contiguous tensor.
|
1876
|
+
"""
|
1877
|
+
return F.Contiguous.apply(self)
|
1878
|
+
def contiguous_backward(self):
|
1879
|
+
"""
|
1880
|
+
Inserts a contiguous operation in the backward pass.
|
1881
|
+
"""
|
1882
|
+
return F.ContiguousBackward.apply(self)
|
1883
|
+
def log(self):
|
1884
|
+
"""
|
1885
|
+
Computes the natural logarithm element-wise.
|
1886
|
+
|
1887
|
+
See: https://en.wikipedia.org/wiki/Logarithm
|
1888
|
+
|
1889
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1890
|
+
print(Tensor([1., 2., 4., 8.]).log().numpy())
|
1891
|
+
```
|
1892
|
+
"""
|
1893
|
+
return F.Log.apply(self.cast(least_upper_float(self.dtype)))
|
1894
|
+
def log2(self):
|
1895
|
+
"""
|
1896
|
+
Computes the base-2 logarithm element-wise.
|
1897
|
+
|
1898
|
+
See: https://en.wikipedia.org/wiki/Logarithm
|
1899
|
+
|
1900
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1901
|
+
print(Tensor([1., 2., 4., 8.]).log2().numpy())
|
1902
|
+
```
|
1903
|
+
"""
|
1904
|
+
return self.log()/math.log(2)
|
1905
|
+
def exp(self):
|
1906
|
+
"""
|
1907
|
+
Computes the exponential function element-wise.
|
1908
|
+
|
1909
|
+
See: https://en.wikipedia.org/wiki/Exponential_function
|
1910
|
+
|
1911
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1912
|
+
print(Tensor([0., 1., 2., 3.]).exp().numpy())
|
1913
|
+
```
|
1914
|
+
"""
|
1915
|
+
return F.Exp.apply(self.cast(least_upper_float(self.dtype)))
|
1916
|
+
def exp2(self):
|
1917
|
+
"""
|
1918
|
+
Computes the base-2 exponential function element-wise.
|
1919
|
+
|
1920
|
+
See: https://en.wikipedia.org/wiki/Exponential_function
|
1921
|
+
|
1922
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1923
|
+
print(Tensor([0., 1., 2., 3.]).exp2().numpy())
|
1924
|
+
```
|
1925
|
+
"""
|
1926
|
+
return F.Exp.apply(self*math.log(2))
|
1927
|
+
def relu(self):
|
1928
|
+
"""
|
1929
|
+
Applies the Rectified Linear Unit (ReLU) function element-wise.
|
1930
|
+
|
1931
|
+
- Described: https://paperswithcode.com/method/relu
|
1932
|
+
|
1933
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1934
|
+
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).relu().numpy())
|
1935
|
+
```
|
1936
|
+
"""
|
1937
|
+
return F.Relu.apply(self)
|
1938
|
+
def sigmoid(self):
|
1939
|
+
"""
|
1940
|
+
Applies the Sigmoid function element-wise.
|
1941
|
+
|
1942
|
+
- Described: https://en.wikipedia.org/wiki/Sigmoid_function
|
1943
|
+
|
1944
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1945
|
+
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sigmoid().numpy())
|
1946
|
+
```
|
1947
|
+
"""
|
1948
|
+
return F.Sigmoid.apply(self.cast(least_upper_float(self.dtype)))
|
1949
|
+
def sqrt(self):
|
1950
|
+
"""
|
1951
|
+
Computes the square root of the tensor element-wise.
|
1952
|
+
|
1953
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1954
|
+
print(Tensor([1., 2., 3., 4.]).sqrt().numpy())
|
1955
|
+
```
|
1956
|
+
"""
|
1957
|
+
return F.Sqrt.apply(self.cast(least_upper_float(self.dtype)))
|
1958
|
+
def rsqrt(self):
|
1959
|
+
"""
|
1960
|
+
Computes the reciprocal of the square root of the tensor element-wise.
|
1961
|
+
|
1962
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1963
|
+
print(Tensor([1., 2., 3., 4.]).rsqrt().numpy())
|
1964
|
+
```
|
1965
|
+
"""
|
1966
|
+
return self.reciprocal().sqrt()
|
1967
|
+
def sin(self):
|
1968
|
+
"""
|
1969
|
+
Computes the sine of the tensor element-wise.
|
1970
|
+
|
1971
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1972
|
+
print(Tensor([0., math.pi/2, math.pi, 3*math.pi/2, 2*math.pi]).sin().numpy())
|
1973
|
+
```
|
1974
|
+
"""
|
1975
|
+
return F.Sin.apply(self.cast(least_upper_float(self.dtype)))
|
1976
|
+
def cos(self):
|
1977
|
+
"""
|
1978
|
+
Computes the cosine of the tensor element-wise.
|
1979
|
+
|
1980
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1981
|
+
print(Tensor([0., math.pi/2, math.pi, 3*math.pi/2, 2*math.pi]).cos().numpy())
|
1982
|
+
```
|
1983
|
+
"""
|
1984
|
+
return ((math.pi/2)-self).sin()
|
1985
|
+
def tan(self):
|
1986
|
+
"""
|
1987
|
+
Computes the tangent of the tensor element-wise.
|
1988
|
+
|
1989
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1990
|
+
print(Tensor([0., math.pi/4, math.pi/2, 3*math.pi/4, math.pi]).tan().numpy())
|
1991
|
+
```
|
1992
|
+
"""
|
1993
|
+
return self.sin() / self.cos()
|
1994
|
+
|
1995
|
+
# ***** math functions *****
|
1996
|
+
|
1997
|
+
def trunc(self: Tensor) -> Tensor:
|
1998
|
+
"""
|
1999
|
+
Truncates the tensor element-wise.
|
2000
|
+
|
2001
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2002
|
+
print(Tensor([-3.9, -2.1, -1.5, 0.5, 1.5, 2.1, 3.9]).trunc().numpy())
|
2003
|
+
```
|
2004
|
+
"""
|
2005
|
+
return self.cast(dtypes.int32).cast(self.dtype)
|
2006
|
+
def ceil(self: Tensor) -> Tensor:
|
2007
|
+
"""
|
2008
|
+
Rounds the tensor element-wise towards positive infinity.
|
2009
|
+
|
2010
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2011
|
+
print(Tensor([-3.9, -2.1, -1.5, 0.5, 1.5, 2.1, 3.9]).ceil().numpy())
|
2012
|
+
```
|
2013
|
+
"""
|
2014
|
+
return (self > (b := self.trunc())).where(b+1, b)
|
2015
|
+
def floor(self: Tensor) -> Tensor:
|
2016
|
+
"""
|
2017
|
+
Rounds the tensor element-wise towards negative infinity.
|
2018
|
+
|
2019
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2020
|
+
print(Tensor([-3.9, -2.1, -1.5, 0.5, 1.5, 2.1, 3.9]).floor().numpy())
|
2021
|
+
```
|
2022
|
+
"""
|
2023
|
+
return (self < (b := self.trunc())).where(b-1, b)
|
2024
|
+
def round(self: Tensor) -> Tensor:
|
2025
|
+
"""
|
2026
|
+
Rounds the tensor element-wise.
|
2027
|
+
|
2028
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2029
|
+
print(Tensor([-3.9, -2.1, -1.5, 0.5, 1.5, 2.1, 3.9]).round().numpy())
|
2030
|
+
```
|
2031
|
+
"""
|
2032
|
+
return ((self > 0) == ((b := self.cast(dtypes.int32) / 2.0).cast(dtypes.int32) == b)).where((self - 0.5).ceil(), (self + 0.5).floor())
|
2033
|
+
def lerp(self, end: Tensor, weight: Union[Tensor, float]) -> Tensor:
|
2034
|
+
"""
|
2035
|
+
Linearly interpolates between `self` and `end` by `weight`.
|
2036
|
+
|
2037
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2038
|
+
print(Tensor([1., 2., 3.]).lerp(Tensor([4., 5., 6.]), 0.5).numpy())
|
2039
|
+
```
|
2040
|
+
"""
|
2041
|
+
return self + (end - self) * weight
|
2042
|
+
def square(self):
|
2043
|
+
"""
|
2044
|
+
Squares the tensor element-wise.
|
2045
|
+
Equivalent to `self*self`.
|
2046
|
+
|
2047
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2048
|
+
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).square().numpy())
|
2049
|
+
```
|
2050
|
+
"""
|
2051
|
+
return self*self
|
2052
|
+
def clip(self, min_, max_):
|
2053
|
+
"""
|
2054
|
+
Clips (clamps) the values in the tensor between `min_` and `max_` element-wise.
|
2055
|
+
|
2056
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2057
|
+
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).clip(-1, 1).numpy())
|
2058
|
+
```
|
2059
|
+
"""
|
2060
|
+
return self.maximum(min_).minimum(max_)
|
2061
|
+
def sign(self):
|
2062
|
+
"""
|
2063
|
+
Returns the sign of the tensor element-wise.
|
2064
|
+
|
2065
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2066
|
+
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sign().numpy())
|
2067
|
+
```
|
2068
|
+
"""
|
2069
|
+
return F.Sign.apply(self)
|
2070
|
+
def abs(self):
|
2071
|
+
"""
|
2072
|
+
Computes the absolute value of the tensor element-wise.
|
2073
|
+
|
2074
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2075
|
+
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).abs().numpy())
|
2076
|
+
```
|
2077
|
+
"""
|
2078
|
+
return self * self.sign()
|
2079
|
+
def reciprocal(self):
|
2080
|
+
"""
|
2081
|
+
Compute `1/x` element-wise.
|
2082
|
+
|
2083
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2084
|
+
print(Tensor([1., 2., 3., 4.]).reciprocal().numpy())
|
2085
|
+
```
|
2086
|
+
"""
|
2087
|
+
return F.Reciprocal.apply(self.cast(least_upper_float(self.dtype)))
|
2088
|
+
|
2089
|
+
# ***** activation functions *****
|
2090
|
+
|
2091
|
+
def elu(self, alpha=1.0):
|
2092
|
+
"""
|
2093
|
+
Applies the Exponential Linear Unit (ELU) function element-wise.
|
2094
|
+
|
2095
|
+
- Described: https://paperswithcode.com/method/elu
|
2096
|
+
- Paper: https://arxiv.org/abs/1511.07289v5
|
2097
|
+
|
2098
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2099
|
+
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).elu().numpy())
|
2100
|
+
```
|
2101
|
+
"""
|
2102
|
+
return self.relu() - alpha*(1-self.exp()).relu()
|
2103
|
+
|
2104
|
+
def celu(self, alpha=1.0):
|
2105
|
+
"""
|
2106
|
+
Applies the Continuously differentiable Exponential Linear Unit (CELU) function element-wise.
|
2107
|
+
|
2108
|
+
- Described: https://paperswithcode.com/method/celu
|
2109
|
+
- Paper: https://arxiv.org/abs/1704.07483
|
2110
|
+
|
2111
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2112
|
+
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).celu().numpy())
|
2113
|
+
```
|
2114
|
+
"""
|
2115
|
+
return self.maximum(0) + (alpha * ((self / alpha).exp() - 1)).minimum(0)
|
2116
|
+
|
2117
|
+
def swish(self):
|
2118
|
+
"""
|
2119
|
+
See `.silu()`
|
2120
|
+
|
2121
|
+
- Paper: https://arxiv.org/abs/1710.05941v1
|
2122
|
+
|
2123
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2124
|
+
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).swish().numpy())
|
2125
|
+
```
|
2126
|
+
"""
|
2127
|
+
return self * self.sigmoid()
|
2128
|
+
|
2129
|
+
def silu(self):
|
2130
|
+
"""
|
2131
|
+
Applies the Sigmoid Linear Unit (SiLU) function element-wise.
|
2132
|
+
|
2133
|
+
- Described: https://paperswithcode.com/method/silu
|
2134
|
+
- Paper: https://arxiv.org/abs/1606.08415
|
2135
|
+
|
2136
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2137
|
+
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).silu().numpy())
|
2138
|
+
```
|
2139
|
+
"""
|
2140
|
+
return self.swish() # The SiLU function is also known as the swish function.
|
2141
|
+
|
2142
|
+
def relu6(self):
|
2143
|
+
"""
|
2144
|
+
Applies the ReLU6 function element-wise.
|
2145
|
+
|
2146
|
+
- Described: https://paperswithcode.com/method/relu6
|
2147
|
+
- Paper: https://arxiv.org/abs/1704.04861v1
|
2148
|
+
|
2149
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2150
|
+
print(Tensor([-9., -6., -3., 0., 3., 6., 9.]).relu6().numpy())
|
2151
|
+
```
|
2152
|
+
"""
|
2153
|
+
return self.relu() - (self-6).relu()
|
2154
|
+
|
2155
|
+
def hardswish(self):
|
2156
|
+
"""
|
2157
|
+
Applies the Hardswish function element-wise.
|
2158
|
+
|
2159
|
+
- Described: https://paperswithcode.com/method/hard-swish
|
2160
|
+
- Paper: https://arxiv.org/abs/1905.02244v5
|
2161
|
+
|
2162
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2163
|
+
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).hardswish().numpy())
|
2164
|
+
```
|
2165
|
+
"""
|
2166
|
+
return self * (self+3).relu6() * (1/6)
|
2167
|
+
|
2168
|
+
def tanh(self):
|
2169
|
+
"""
|
2170
|
+
Applies the Hyperbolic Tangent (tanh) function element-wise.
|
2171
|
+
|
2172
|
+
- Described: https://en.wikipedia.org/wiki/Hyperbolic_functions#Tanh
|
2173
|
+
|
2174
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2175
|
+
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).tanh().numpy())
|
2176
|
+
```
|
2177
|
+
"""
|
2178
|
+
return 2.0 * ((2.0 * self).sigmoid()) - 1.0
|
2179
|
+
|
2180
|
+
def sinh(self):
|
2181
|
+
"""
|
2182
|
+
Applies the Hyperbolic Sine (sinh) function element-wise.
|
2183
|
+
|
2184
|
+
- Described: https://en.wikipedia.org/wiki/Hyperbolic_functions#Sinh
|
2185
|
+
|
2186
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2187
|
+
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sinh().numpy())
|
2188
|
+
```
|
2189
|
+
"""
|
2190
|
+
return (self.exp() - self.neg().exp()) / 2
|
2191
|
+
|
2192
|
+
def cosh(self):
|
2193
|
+
"""
|
2194
|
+
Applies the Hyperbolic Cosine (cosh) function element-wise.
|
2195
|
+
|
2196
|
+
- Described: https://en.wikipedia.org/wiki/Hyperbolic_functions#Cosh
|
2197
|
+
|
2198
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2199
|
+
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).cosh().numpy())
|
2200
|
+
```
|
2201
|
+
"""
|
2202
|
+
return (self.exp() + self.neg().exp()) / 2
|
2203
|
+
|
2204
|
+
def atanh(self):
|
2205
|
+
"""
|
2206
|
+
Applies the Inverse Hyperbolic Tangent (atanh) function element-wise.
|
2207
|
+
|
2208
|
+
- Described: https://en.wikipedia.org/wiki/Inverse_hyperbolic_functions#atanh
|
2209
|
+
|
2210
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2211
|
+
print(Tensor([-0.9, -0.6, -0.3, 0., 0.3, 0.6, 0.9]).atanh().numpy())
|
2212
|
+
```
|
2213
|
+
"""
|
2214
|
+
return ((1 + self)/(1 - self)).log() / 2
|
2215
|
+
|
2216
|
+
def asinh(self):
|
2217
|
+
"""
|
2218
|
+
Applies the Inverse Hyperbolic Sine (asinh) function element-wise.
|
2219
|
+
|
2220
|
+
- Described: https://en.wikipedia.org/wiki/Inverse_hyperbolic_functions#asinh
|
2221
|
+
|
2222
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2223
|
+
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).asinh().numpy())
|
2224
|
+
```
|
2225
|
+
"""
|
2226
|
+
return (self + (self.square() + 1).sqrt()).log()
|
2227
|
+
|
2228
|
+
def acosh(self):
|
2229
|
+
"""
|
2230
|
+
Applies the Inverse Hyperbolic Cosine (acosh) function element-wise.
|
2231
|
+
|
2232
|
+
- Described: https://en.wikipedia.org/wiki/Inverse_hyperbolic_functions#acosh
|
2233
|
+
|
2234
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2235
|
+
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).acosh().numpy())
|
2236
|
+
```
|
2237
|
+
"""
|
2238
|
+
return (self + (self.square() - 1).sqrt()).log()
|
2239
|
+
|
2240
|
+
def hardtanh(self, min_val=-1, max_val=1):
|
2241
|
+
"""
|
2242
|
+
Applies the Hardtanh function element-wise.
|
2243
|
+
|
2244
|
+
- Described: https://paperswithcode.com/method/hardtanh-activation
|
2245
|
+
|
2246
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2247
|
+
print(Tensor([-1.5, -1.0, -0.5, 0., 0.5, 1.0, 1.5]).hardtanh().numpy())
|
2248
|
+
```
|
2249
|
+
"""
|
2250
|
+
return self.clip(min_val, max_val)
|
2251
|
+
|
2252
|
+
def gelu(self):
|
2253
|
+
"""
|
2254
|
+
Applies the Gaussian Error Linear Unit (GELU) function element-wise.
|
2255
|
+
|
2256
|
+
- Described: https://paperswithcode.com/method/gelu
|
2257
|
+
- Paper: https://arxiv.org/abs/1606.08415v5
|
2258
|
+
|
2259
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2260
|
+
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).gelu().numpy())
|
2261
|
+
```
|
2262
|
+
"""
|
2263
|
+
return 0.5 * self * (1 + (math.sqrt(2 / math.pi) * (self + 0.044715 * self ** 3)).tanh())
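The expression above is the usual tanh approximation of GELU; a small plain-Python comparison against the erf-based definition (illustrative, not from this diff):

```python
import math

def gelu_tanh(x: float) -> float:
    return 0.5 * x * (1 + math.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x ** 3)))

def gelu_erf(x: float) -> float:
    return 0.5 * x * (1 + math.erf(x / math.sqrt(2)))

# the approximation stays within roughly 1e-3 of the exact form on typical activations
print([round(gelu_tanh(v) - gelu_erf(v), 5) for v in (-3.0, -1.0, 0.0, 1.0, 3.0)])
```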
|
2264
|
+
|
2265
|
+
def quick_gelu(self):
|
2266
|
+
"""
|
2267
|
+
Applies the Sigmoid GELU approximation element-wise.
|
2268
|
+
|
2269
|
+
- Described: https://paperswithcode.com/method/gelu
|
2270
|
+
|
2271
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2272
|
+
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).quick_gelu().numpy())
|
2273
|
+
```
|
2274
|
+
"""
|
2275
|
+
return self * (self * 1.702).sigmoid()
|
2276
|
+
|
2277
|
+
def leakyrelu(self, neg_slope=0.01):
|
2278
|
+
"""
|
2279
|
+
Applies the Leaky ReLU function element-wise.
|
2280
|
+
|
2281
|
+
- Described: https://paperswithcode.com/method/leaky-relu
|
2282
|
+
|
2283
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2284
|
+
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).leakyrelu().numpy())
|
2285
|
+
```
|
2286
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2287
|
+
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).leakyrelu(neg_slope=0.42).numpy())
|
2288
|
+
```
|
2289
|
+
"""
|
2290
|
+
return self.relu() - (-neg_slope*self).relu()
|
2291
|
+
|
2292
|
+
def mish(self):
|
2293
|
+
"""
|
2294
|
+
Applies the Mish function element-wise.
|
2295
|
+
|
2296
|
+
- Described: https://paperswithcode.com/method/mish
|
2297
|
+
- Paper: https://arxiv.org/abs/1908.08681v3
|
2298
|
+
|
2299
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2300
|
+
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).mish().numpy())
|
2301
|
+
```
|
2302
|
+
"""
|
2303
|
+
return self * self.softplus().tanh()
|
2304
|
+
|
2305
|
+
def softplus(self, beta=1):
|
2306
|
+
"""
|
2307
|
+
Applies the Softplus function element-wise.
|
2308
|
+
|
2309
|
+
- Described: https://paperswithcode.com/method/softplus
|
2310
|
+
|
2311
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2312
|
+
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).softplus().numpy())
|
2313
|
+
```
|
2314
|
+
"""
|
2315
|
+
return (1/beta) * (1 + (self*beta).exp()).log()
|
2316
|
+
|
2317
|
+
def softsign(self):
|
2318
|
+
"""
|
2319
|
+
Applies the Softsign function element-wise.
|
2320
|
+
|
2321
|
+
- Described: https://paperswithcode.com/method/softsign
|
2322
|
+
|
2323
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2324
|
+
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).softsign().numpy())
|
2325
|
+
```
|
2326
|
+
"""
|
2327
|
+
return self / (1 + self.abs())
|
2328
|
+
|
2329
|
+
# ***** broadcasted elementwise ops *****
|
2330
|
+
def _broadcast_to(self, shape:Tuple[sint, ...]) -> Tensor:
|
2331
|
+
if self.shape == shape: return self
|
2332
|
+
if self.ndim > len(shape): raise ValueError(f"cannot broadcast tensor to fewer dimensions. shape={self.shape} to {shape=}")
|
2333
|
+
# first pad left with 1s https://data-apis.org/array-api/latest/API_specification/broadcasting.html
|
2334
|
+
padded, _ = _pad_left(self.shape, shape)
|
2335
|
+
# for each dimension, check either from_ is 1, or it does not change
|
2336
|
+
if any(from_ != 1 and from_ != to for from_,to in zip(padded, shape)): raise ValueError(f"cannot broadcast from shape={self.shape} to {shape=}")
|
2337
|
+
return F.Expand.apply(self.reshape(padded), shape=shape)
|
2338
|
+
|
2339
|
+
def _broadcasted(self, y:Union[Tensor, Node, ConstType], reverse:bool=False, match_dtype:bool=True) -> Tuple[Tensor, Tensor]:
|
768
2340
|
x: Tensor = self
|
769
2341
|
if not isinstance(y, Tensor):
|
770
2342
|
# make y a Tensor
|
771
|
-
|
772
|
-
if isinstance( …
- y = Tensor(y, …
|
2343
|
+
assert isinstance(y, (float, int, bool, Node)), f"{type(y)=}, {y=}"
|
2344
|
+
if isinstance(x.dtype, ImageDType) or dtypes.is_float(x.dtype) or (dtypes.is_int(x.dtype) and isinstance(y, int)): y_dtype = x.dtype
|
2345
|
+
elif not isinstance(y, Node): y_dtype = dtypes.from_py(y)
|
2346
|
+
if isinstance(y, Node): y = Tensor.from_node(y, device=x.device)
|
2347
|
+
else: y = Tensor(dtypes.as_const(y, y_dtype), x.device, y_dtype, requires_grad=False)
|
775
2348
|
|
776
|
-
if match_dtype:
|
2349
|
+
if match_dtype and x.dtype != y.dtype:
|
777
2350
|
output_dtype = least_upper_dtype(x.dtype, y.dtype)
|
778
2351
|
x, y = x.cast(output_dtype), y.cast(output_dtype)
|
779
2352
|
|
780
2353
|
if reverse: x, y = y, x
|
781
2354
|
|
782
|
-
# …
- broadcasted_shape = tuple(max(xi, yi) for xi, yi in zip(x.shape, y.shape))
- return x.expand(broadcasted_shape), y.expand(broadcasted_shape)
|
2355
|
+
# broadcast
|
2356
|
+
out_shape = _broadcast_shape(x.shape, y.shape)
|
2357
|
+
return x._broadcast_to(out_shape), y._broadcast_to(out_shape)
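`_broadcast_shape` is a tinygrad helper defined elsewhere in this file; as a labelled assumption, the left-padding rule referenced in the comment above behaves roughly like this illustrative stand-in (not the library's implementation):

```python
def broadcast_shape(*shapes):
    # left-pad with 1s, then every non-1 size along a dimension must agree
    ndim = max(len(s) for s in shapes)
    padded = [(1,) * (ndim - len(s)) + tuple(s) for s in shapes]
    out = []
    for dims in zip(*padded):
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1: raise ValueError(f"cannot broadcast {shapes}")
        out.append(sizes.pop() if sizes else 1)
    return tuple(out)

print(broadcast_shape((2, 1, 4), (3, 1)))  # (2, 3, 4)
```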
|
788
2358
|
|
789
|
-
def _to_const_val(self, x:Union[Tensor,
|
790
|
-
|
791
|
-
return x.lazydata.base.arg if isinstance(x, Tensor) and isinstance(x.lazydata, LazyBuffer) and x.lazydata.is_unrealized_contiguous_const() \
|
2359
|
+
def _to_const_val(self, x:Union[Tensor, ConstType]) -> Union[Tensor, ConstType]:
|
2360
|
+
return x.lazydata.base.arg if isinstance(x, Tensor) and isinstance(x.lazydata, LazyBuffer) and x.lazydata.is_unrealized_unmasked_const() \
|
792
2361
|
and not x.requires_grad and self._broadcasted(x)[0].shape == self.shape else x
|
793
2362
|
|
794
|
-
def add(self, x:Union[Tensor, …
|
2363
|
+
def add(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
|
2364
|
+
"""
|
2365
|
+
Adds `self` and `x`.
|
2366
|
+
Equivalent to `self + x`.
|
2367
|
+
Supports broadcasting to a common shape, type promotion, and integer, float, boolean inputs.
|
2368
|
+
|
2369
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2370
|
+
Tensor.manual_seed(42)
|
2371
|
+
t = Tensor.randn(4)
|
2372
|
+
print(t.numpy())
|
2373
|
+
```
|
2374
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2375
|
+
print(t.add(20).numpy())
|
2376
|
+
```
|
2377
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2378
|
+
print(t.add(Tensor([[2.0], [3.5]])).numpy())
|
2379
|
+
```
|
2380
|
+
"""
|
2381
|
+
return F.Add.apply(*self._broadcasted(x, reverse))
|
2382
|
+
|
2383
|
+
def sub(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
|
2384
|
+
"""
|
2385
|
+
Subtracts `x` from `self`.
|
2386
|
+
Equivalent to `self - x`.
|
2387
|
+
Supports broadcasting to a common shape, type promotion, and integer, float, boolean inputs.
|
2388
|
+
|
2389
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2390
|
+
Tensor.manual_seed(42)
|
2391
|
+
t = Tensor.randn(4)
|
2392
|
+
print(t.numpy())
|
2393
|
+
```
|
2394
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2395
|
+
print(t.sub(20).numpy())
|
2396
|
+
```
|
2397
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2398
|
+
print(t.sub(Tensor([[2.0], [3.5]])).numpy())
|
2399
|
+
```
|
2400
|
+
"""
|
2401
|
+
a, b = self._broadcasted(x, reverse)
|
2402
|
+
return a + (-b)
|
2403
|
+
|
2404
|
+
def mul(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
|
2405
|
+
"""
|
2406
|
+
Multiplies `self` and `x`.
|
2407
|
+
Equivalent to `self * x`.
|
2408
|
+
Supports broadcasting to a common shape, type promotion, and integer, float, boolean inputs.
|
2409
|
+
|
2410
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2411
|
+
Tensor.manual_seed(42)
|
2412
|
+
t = Tensor.randn(4)
|
2413
|
+
print(t.numpy())
|
2414
|
+
```
|
2415
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2416
|
+
print(t.mul(3).numpy())
|
2417
|
+
```
|
2418
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2419
|
+
print(t.mul(Tensor([[-1.0], [2.0]])).numpy())
|
2420
|
+
```
|
2421
|
+
"""
|
2422
|
+
return F.Mul.apply(*self._broadcasted(x, reverse))
|
2423
|
+
|
2424
|
+
def div(self, x:Union[Tensor, ConstType], reverse=False, upcast=True) -> Tensor:
|
2425
|
+
"""
|
2426
|
+
Divides `self` by `x`.
|
2427
|
+
Equivalent to `self / x`.
|
2428
|
+
Supports broadcasting to a common shape, type promotion, and integer, float, boolean inputs.
|
2429
|
+
By default, `div` performs true division. Set `upcast` to `False` for integer division.
|
2430
|
+
|
2431
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2432
|
+
Tensor.manual_seed(42)
|
2433
|
+
t = Tensor.randn(4)
|
2434
|
+
print(t.numpy())
|
2435
|
+
```
|
2436
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2437
|
+
print(t.div(3).numpy())
|
2438
|
+
```
|
2439
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2440
|
+
print(Tensor([1, 4, 10]).div(Tensor([2, 3, 4])).numpy())
|
2441
|
+
```
|
2442
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2443
|
+
print(Tensor([1, 4, 10]).div(Tensor([2, 3, 4]), upcast=False).numpy())
|
2444
|
+
```
|
2445
|
+
"""
|
2446
|
+
numerator, denominator = self._broadcasted(x, reverse)
|
2447
|
+
if upcast: numerator, denominator = numerator.cast(least_upper_float(numerator.dtype)), denominator.cast(least_upper_float(denominator.dtype))
|
2448
|
+
return F.Div.apply(numerator, denominator)
|
2449
|
+
|
2450
|
+
def xor(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
|
2451
|
+
"""
|
2452
|
+
Computes bitwise xor of `self` and `x`.
|
2453
|
+
Equivalent to `self ^ x`.
|
2454
|
+
Supports broadcasting to a common shape, type promotion, and integer, boolean inputs.
|
2455
|
+
|
2456
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2457
|
+
print(Tensor([-1, -2, 3]).xor(Tensor([1, 0, 3])).numpy())
|
2458
|
+
```
|
2459
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2460
|
+
print(Tensor([True, True, False, False]).xor(Tensor([True, False, True, False])).numpy())
|
2461
|
+
```
|
2462
|
+
"""
|
2463
|
+
return F.Xor.apply(*self._broadcasted(x, reverse))
|
2464
|
+
|
2465
|
+
def lshift(self, x:int):
|
2466
|
+
"""
|
2467
|
+
Computes left arithmetic shift of `self` by `x` bits. `self` must have unsigned dtype.
|
2468
|
+
Equivalent to `self << x`.
|
2469
|
+
|
2470
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2471
|
+
print(Tensor([1, 3, 31], dtype=dtypes.uint8).lshift(2).numpy())
|
2472
|
+
```
|
2473
|
+
"""
|
2474
|
+
assert dtypes.is_unsigned(self.dtype) and isinstance(x, int) and x >= 0, f"not supported {self.dtype=} {x=}"
|
2475
|
+
return self.mul(2 ** x)
|
2476
|
+
|
2477
|
+
def rshift(self, x:int):
|
2478
|
+
"""
|
2479
|
+
Computes right arithmetic shift of `self` by `x` bits. `self` must have unsigned dtype.
|
2480
|
+
Equivalent to `self >> x`.
|
2481
|
+
|
2482
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2483
|
+
print(Tensor([4, 13, 125], dtype=dtypes.uint8).rshift(2).numpy())
|
2484
|
+
```
|
2485
|
+
"""
|
2486
|
+
assert dtypes.is_unsigned(self.dtype) and isinstance(x, int) and x >= 0, f"not supported {self.dtype=} {x=}"
|
2487
|
+
return self.div(2 ** x, upcast=False)
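Both shift helpers above lean on the fact that shifts of unsigned integers are multiplies and divides by powers of two:

```python
x = 13
assert x << 2 == x * 2**2   # left shift is multiplication by 2**n
assert x >> 2 == x // 2**2  # right shift is floor division by 2**n
print(x << 2, x >> 2)       # 52 3
```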
|
2488
|
+
|
2489
|
+
def pow(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
|
2490
|
+
"""
|
2491
|
+
Computes power of `self` with `x`.
|
2492
|
+
Equivalent to `self ** x`.
|
2493
|
+
|
2494
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2495
|
+
print(Tensor([-1, 2, 3]).pow(2).numpy())
|
2496
|
+
```
|
2497
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2498
|
+
print(Tensor([-1, 2, 3]).pow(Tensor([-1.5, 0.5, 1.5])).numpy())
|
2499
|
+
```
|
2500
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2501
|
+
print((2 ** Tensor([-1, 2, 3])).numpy())
|
2502
|
+
```
|
2503
|
+
"""
|
809
2504
|
x = self._to_const_val(x)
|
810
2505
|
if not isinstance(x, Tensor) and not reverse:
|
811
2506
|
# simple pow identities
|
812
2507
|
if x < 0: return self.reciprocal().pow(-x)
|
813
|
-
if x
|
814
|
-
if x
|
2508
|
+
if x == 0: return 1 + self * 0
|
2509
|
+
if int(x - 0.5) + 0.5 == x: return self.pow(int(x - 0.5)) * self.sqrt()
|
2510
|
+
if int(x) == x: return self.pow(x // 2).square() * (1 if x % 2 == 0 else self)
|
2511
|
+
|
2512
|
+
# positive const ** self
|
815
2513
|
if not isinstance(x, Tensor) and reverse and x > 0: return self.mul(math.log(x)).exp()
|
816
|
-
|
817
|
-
|
818
|
-
|
819
|
-
|
820
|
-
|
821
|
-
|
822
|
-
|
823
|
-
|
824
|
-
|
825
|
-
inject_nan = ( …
|
2514
|
+
|
2515
|
+
base, exponent = self._broadcasted(x, reverse=reverse)
|
2516
|
+
# start with b ** e = exp(e * log(b))
|
2517
|
+
ret = base.abs().log().mul(exponent).exp()
|
2518
|
+
# correct sign of negative base with odd exponent (cos has a period of 2pi so we use it here to get the oddness of the exponent)
|
2519
|
+
negative_base = (base < 0).detach().where(1, 0)
|
2520
|
+
# 1 for non-negative base or negative even exponent, -1 for negative odd exponent, don't care about non-integer exponent
|
2521
|
+
correct_sign = 1 + negative_base * ((exponent * math.pi).cos() - 1)
|
2522
|
+
# inject nan for negative base and non-integer exponent
|
2523
|
+
inject_nan = (negative_base * (exponent != exponent.trunc())).detach().where(math.nan, 1)
|
2524
|
+
# apply correct_sign inject_nan, and fix 0 ** 0 = 1
|
2525
|
+
return ((base == 0) * (exponent == 0)).detach().where(1, ret * correct_sign * inject_nan)
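The new `pow` path computes `exp(e * log|b|)` and then patches the sign, NaN, and `0 ** 0` cases; an illustrative NumPy rendering of the same steps (not tinygrad code):

```python
import numpy as np

def pow_via_exp_log(base: np.ndarray, exponent: np.ndarray) -> np.ndarray:
    with np.errstate(divide="ignore", invalid="ignore"):
        ret = np.exp(exponent * np.log(np.abs(base)))                      # |b| ** e
        negative_base = (base < 0).astype(base.dtype)
        correct_sign = 1 + negative_base * (np.cos(exponent * np.pi) - 1)  # -1 for negative base and odd exponent
        inject_nan = np.where(negative_base * (exponent != np.trunc(exponent)), np.nan, 1)
        return np.where((base == 0) * (exponent == 0), 1, ret * correct_sign * inject_nan)

b = np.array([-3.0, -2.0, 2.0, 0.0]); e = np.array([3.0, 0.5, -2.0, 0.0])
print(pow_via_exp_log(b, e))  # [-27. nan 0.25 1.], matching b ** e
```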
|
2526
|
+
|
2527
|
+
def maximum(self, x:Union[Tensor, ConstType]) -> Tensor:
|
2528
|
+
"""
|
2529
|
+
Computes element-wise maximum of `self` and `x`.
|
2530
|
+
|
2531
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2532
|
+
print(Tensor([-1, 2, 3]).maximum(1).numpy())
|
2533
|
+
```
|
2534
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2535
|
+
print(Tensor([-1, 2, 3]).maximum(Tensor([-4, -2, 9])).numpy())
|
2536
|
+
```
|
2537
|
+
"""
|
2538
|
+
return (self<x).detach().where(x, (self==x).detach().where(((self * 0.5 + x * 0.5).cast(self.dtype)), self))
|
2539
|
+
|
2540
|
+
def minimum(self, x:Union[Tensor, ConstType]) -> Tensor:
|
2541
|
+
"""
|
2542
|
+
Computes element-wise minimum of `self` and `x`.
|
2543
|
+
|
2544
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2545
|
+
print(Tensor([-1, 2, 3]).minimum(1).numpy())
|
2546
|
+
```
|
2547
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2548
|
+
print(Tensor([-1, 2, 3]).minimum(Tensor([-4, -2, 9])).numpy())
|
2549
|
+
```
|
2550
|
+
"""
|
2551
|
+
return -((-self).maximum(-x))
|
2552
|
+
|
2553
|
+
def where(self:Tensor, x:Union[Tensor, ConstType], y:Union[Tensor, ConstType]):
|
2554
|
+
"""
|
2555
|
+
Return a tensor of elements selected from either `x` or `y`, depending on `self`.
|
2556
|
+
`output_i = x_i if self_i else y_i`.
|
2557
|
+
|
2558
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2559
|
+
cond = Tensor([[True, True, False], [True, False, False]])
|
2560
|
+
print(cond.where(1, 3).numpy())
|
2561
|
+
```
|
2562
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2563
|
+
Tensor.manual_seed(42)
|
2564
|
+
cond = Tensor.randn(2, 3)
|
2565
|
+
print(cond.numpy())
|
2566
|
+
```
|
2567
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2568
|
+
print((cond > 0).where(cond, -float("inf")).numpy())
|
2569
|
+
```
|
2570
|
+
"""
|
2571
|
+
if isinstance(x, Tensor): x, y = x._broadcasted(y)
|
2572
|
+
elif isinstance(y, Tensor): y, x = y._broadcasted(x)
|
2573
|
+
cond, x = self._broadcasted(x, match_dtype=False)
|
2574
|
+
cond, y = cond._broadcasted(y, match_dtype=False)
|
2575
|
+
return F.Where.apply(cond.cast(dtypes.bool), *x._broadcasted(y))
|
2576
|
+
|
2577
|
+
def masked_fill(self:Tensor, mask:Tensor, value:Union[Tensor, ConstType]): return mask.where(value, self)
|
2578
|
+
|
2579
|
+
# ***** op wrappers *****
|
839
2580
|
|
840
2581
|
def __neg__(self) -> Tensor: return self.neg()
|
841
2582
|
|
@@ -846,6 +2587,8 @@ class Tensor:
|
|
846
2587
|
def __truediv__(self, x) -> Tensor: return self.div(x)
|
847
2588
|
def __matmul__(self, x) -> Tensor: return self.matmul(x)
|
848
2589
|
def __xor__(self, x) -> Tensor: return self.xor(x)
|
2590
|
+
def __lshift__(self, x) -> Tensor: return self.lshift(x)
|
2591
|
+
def __rshift__(self, x) -> Tensor: return self.rshift(x)
|
849
2592
|
|
850
2593
|
def __radd__(self, x) -> Tensor: return self.add(x, True)
|
851
2594
|
def __rsub__(self, x) -> Tensor: return self.sub(x, True)
|
@@ -862,69 +2605,254 @@ class Tensor:
|
|
862
2605
|
def __itruediv__(self, x) -> Tensor: return self.assign(self.div(x))
|
863
2606
|
def __imatmul__(self, x) -> Tensor: return self.assign(self.matmul(x))
|
864
2607
|
def __ixor__(self, x) -> Tensor: return self.assign(self.xor(x))
|
2608
|
+
def __ilshift__(self, x) -> Tensor: return self.assign(self.lshift(x))
|
2609
|
+
def __irshift__(self, x) -> Tensor: return self.assign(self.rshift(x))
|
865
2610
|
|
866
|
-
def __lt__(self, x) -> Tensor: return
|
867
|
-
def __gt__(self, x) -> Tensor: return
|
2611
|
+
def __lt__(self, x) -> Tensor: return F.Less.apply(*self._broadcasted(x, False))
|
2612
|
+
def __gt__(self, x) -> Tensor: return F.Less.apply(*self._broadcasted(x, True))
|
868
2613
|
def __ge__(self, x) -> Tensor: return (self<x).logical_not()
|
869
2614
|
def __le__(self, x) -> Tensor: return (self>x).logical_not()
|
870
|
-
def
|
871
|
-
def
|
2615
|
+
def __ne__(self, x) -> Tensor: return F.Neq.apply(*self._broadcasted(x)) # type: ignore[override]
|
2616
|
+
def __eq__(self, x) -> Tensor: return (self!=x).logical_not() # type: ignore[override]
|
872
2617
|
|
873
2618
|
# ***** functional nn ops *****
|
874
2619
|
|
875
2620
|
def linear(self, weight:Tensor, bias:Optional[Tensor]=None):
|
2621
|
+
"""
|
2622
|
+
Applies a linear transformation to `self` using `weight` and `bias`.
|
2623
|
+
|
2624
|
+
See: https://pytorch.org/docs/stable/generated/torch.nn.Linear.html
|
2625
|
+
|
2626
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2627
|
+
t = Tensor([[1, 2], [3, 4]])
|
2628
|
+
weight = Tensor([[1, 2], [3, 4]])
|
2629
|
+
bias = Tensor([1, 2])
|
2630
|
+
print(t.linear(weight, bias).numpy())
|
2631
|
+
```
|
2632
|
+
"""
|
876
2633
|
x = self.mul(weight) if len(weight.shape) == 1 else self.dot(weight)
|
877
2634
|
return x.add(bias) if bias is not None else x
|
878
2635
|
|
879
|
-
def sequential(self, ll:List[Callable[[Tensor], Tensor]]):
|
2636
|
+
def sequential(self, ll:List[Callable[[Tensor], Tensor]]):
|
2637
|
+
"""
|
2638
|
+
Applies a sequence of functions to `self` chaining the output of each function to the input of the next.
|
2639
|
+
|
2640
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2641
|
+
t = Tensor([1, 2, 3])
|
2642
|
+
print(t.sequential([lambda x: x * 2, lambda x: x + 1]).numpy())
|
2643
|
+
```
|
2644
|
+
"""
|
2645
|
+
return functools.reduce(lambda x,f: f(x), ll, self)
|
880
2646
|
|
881
2647
|
def layernorm(self, axis=-1, eps:float=1e-5) -> Tensor:
|
2648
|
+
"""
|
2649
|
+
Applies Layer Normalization over a mini-batch of inputs.
|
2650
|
+
|
2651
|
+
- Described: https://paperswithcode.com/method/layer-normalization
|
2652
|
+
- Paper: https://arxiv.org/abs/1607.06450v1
|
2653
|
+
|
2654
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2655
|
+
t = Tensor.randn(8, 10, 16) * 2 + 8
|
2656
|
+
print(t.mean().item(), t.std().item())
|
2657
|
+
```
|
2658
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2659
|
+
t = t.layernorm()
|
2660
|
+
print(t.mean().item(), t.std().item())
|
2661
|
+
```
|
2662
|
+
"""
|
882
2663
|
y = (self - self.mean(axis, keepdim=True))
|
883
2664
|
return y.mul((y*y).mean(axis, keepdim=True).add(eps).rsqrt())
|
884
2665
|
|
885
|
-
def batchnorm(self, weight:Optional[Tensor], bias:Optional[Tensor], mean:Tensor, invstd:Tensor) -> Tensor:
|
886
|
-
|
887
|
-
|
888
|
-
|
889
|
-
|
2666
|
+
def batchnorm(self, weight:Optional[Tensor], bias:Optional[Tensor], mean:Tensor, invstd:Tensor, axis:Union[int,Tuple[int,...]]=1) -> Tensor:
|
2667
|
+
"""
|
2668
|
+
Applies Batch Normalization over a mini-batch of inputs.
|
2669
|
+
|
2670
|
+
- Described: https://paperswithcode.com/method/batch-normalization
|
2671
|
+
- Paper: https://arxiv.org/abs/1502.03167
|
2672
|
+
|
2673
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2674
|
+
t = Tensor.randn(8, 4, 16, 16) * 2 + 8
|
2675
|
+
print(t.mean().item(), t.std().item())
|
2676
|
+
```
|
2677
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2678
|
+
t = t.batchnorm(None, None, t.mean(axis=(0,2,3)), t.var(axis=(0,2,3)).add(1e-5).rsqrt())
|
2679
|
+
print(t.mean().item(), t.std().item())
|
2680
|
+
```
|
2681
|
+
"""
|
2682
|
+
axis_ = argfix(axis)
|
2683
|
+
shape = tuple(s if ax in axis_ else 1 for ax, s in enumerate(self.shape))
|
2684
|
+
x = self - mean.reshape(shape)
|
2685
|
+
if weight is not None: x = x * weight.reshape(shape)
|
2686
|
+
ret = x.mul(invstd.reshape(shape) if len(invstd.shape) == len(axis_) else invstd)
|
2687
|
+
return (ret + bias.reshape(shape)) if bias is not None else ret
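A NumPy reference for the `axis`-aware normalization the rewritten `batchnorm` performs (illustrative and simplified; it ignores the multi-axis and pre-reshaped `invstd` branches above):

```python
import numpy as np

def batchnorm_reference(x, mean, invstd, weight=None, bias=None, axis=1):
    # reshape the per-channel statistics so they broadcast along `axis`
    shape = tuple(s if ax == axis else 1 for ax, s in enumerate(x.shape))
    y = (x - mean.reshape(shape)) * invstd.reshape(shape)
    if weight is not None: y = y * weight.reshape(shape)
    if bias is not None: y = y + bias.reshape(shape)
    return y

x = np.random.randn(8, 4, 16, 16).astype(np.float32) * 2 + 8
m, v = x.mean(axis=(0, 2, 3)), x.var(axis=(0, 2, 3))
out = batchnorm_reference(x, m, 1 / np.sqrt(v + 1e-5))
print(out.mean(), out.std())  # ~0, ~1
```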
|
890
2688
|
|
891
2689
|
def dropout(self, p=0.5) -> Tensor:
|
892
|
-
|
893
|
-
|
2690
|
+
"""
|
2691
|
+
Applies dropout to `self`.
|
2692
|
+
|
2693
|
+
NOTE: dropout is only applied when `Tensor.training` is `True`.
|
2694
|
+
|
2695
|
+
- Described: https://paperswithcode.com/method/dropout
|
2696
|
+
- Paper: https://jmlr.org/papers/v15/srivastava14a.html
|
894
2697
|
|
895
|
-
|
896
|
-
|
2698
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2699
|
+
Tensor.manual_seed(42)
|
2700
|
+
t = Tensor.randn(2, 2)
|
2701
|
+
with Tensor.train():
|
2702
|
+
print(t.dropout().numpy())
|
2703
|
+
```
|
2704
|
+
"""
|
2705
|
+
if not Tensor.training or p == 0: return self
|
2706
|
+
return self * (Tensor.rand(*self.shape, requires_grad=False, dtype=dtypes.default_float, device=self.device) >= p) * (1/(1.0 - p))
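This is inverted dropout: surviving activations are scaled by `1/(1-p)` so the expected value is unchanged at inference time. A tiny NumPy sketch, not from this diff:

```python
import numpy as np

def inverted_dropout(x: np.ndarray, p: float = 0.5, seed: int = 0) -> np.ndarray:
    keep = np.random.default_rng(seed).random(x.shape) >= p
    return x * keep / (1.0 - p)  # rescale survivors so the expected output equals the input

print(inverted_dropout(np.ones(1000)).mean())  # close to 1.0
```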
|
2707
|
+
|
2708
|
+
def one_hot(self, num_classes:int) -> Tensor:
|
2709
|
+
"""
|
2710
|
+
Converts `self` to a one-hot tensor.
|
2711
|
+
|
2712
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2713
|
+
t = Tensor([0, 1, 3, 3, 4])
|
2714
|
+
print(t.one_hot(5).numpy())
|
2715
|
+
```
|
2716
|
+
"""
|
2717
|
+
return (self[..., None] == Tensor.arange(num_classes, requires_grad=False, device=self.device)).where(1, 0)
|
2718
|
+
|
2719
|
+
def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Optional[Tensor]=None,
|
2720
|
+
dropout_p:float=0.0, is_causal:bool=False) -> Tensor:
|
2721
|
+
"""
|
2722
|
+
Computes scaled dot-product attention.
|
2723
|
+
`self` is the query tensor, `key` is the key tensor, and `value` is the value tensor.
|
2724
|
+
|
2725
|
+
- Described: https://paperswithcode.com/method/scaled
|
2726
|
+
- Paper: https://arxiv.org/abs/1706.03762v7
|
2727
|
+
|
2728
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2729
|
+
q = Tensor.randn(2, 4, 8)
|
2730
|
+
k = Tensor.randn(2, 4, 8)
|
2731
|
+
v = Tensor.randn(2, 4, 8)
|
2732
|
+
print(q.scaled_dot_product_attention(k, v).numpy())
|
2733
|
+
```
|
2734
|
+
"""
|
2735
|
+
# NOTE: it also works when `key` and `value` have symbolic shape.
|
897
2736
|
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
|
898
2737
|
if is_causal: attn_mask = Tensor.ones(self.shape[-2], key.shape[-2], requires_grad=False, device=self.device).tril(0).cast(dtypes.bool)
|
899
2738
|
if attn_mask is not None and attn_mask.dtype == dtypes.bool: attn_mask = (attn_mask == 0).where(-float("inf"), 0)
|
900
|
-
qk = self
|
901
|
-
return ((qk+attn_mask) if attn_mask is not None else qk).softmax(-1).dropout(dropout_p) @ value
|
2739
|
+
qk = self.matmul(key.transpose(-2,-1), acc_dtype=least_upper_dtype(self.dtype, key.dtype, dtypes.float32)) / math.sqrt(self.shape[-1])
|
2740
|
+
return ((qk+attn_mask) if attn_mask is not None else qk).softmax(-1).cast(self.dtype).dropout(dropout_p) @ value
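A compact NumPy reference of the formula implemented above, `softmax(q @ k^T / sqrt(d)) @ v` (illustrative only; it ignores dropout, mask dtype handling, and the accumulation dtype):

```python
import numpy as np

def sdpa_reference(q, k, v, attn_mask=None):
    qk = q @ k.swapaxes(-2, -1) / np.sqrt(q.shape[-1])  # scaled scores
    if attn_mask is not None: qk = qk + attn_mask
    w = np.exp(qk - qk.max(-1, keepdims=True))          # numerically stable softmax
    return (w / w.sum(-1, keepdims=True)) @ v

q, k, v = (np.random.randn(2, 4, 8).astype(np.float32) for _ in range(3))
print(sdpa_reference(q, k, v).shape)  # (2, 4, 8)
```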
|
902
2741
|
|
903
2742
|
def binary_crossentropy(self, y:Tensor) -> Tensor:
|
2743
|
+
"""
|
2744
|
+
Computes the binary cross-entropy loss between `self` and `y`.
|
2745
|
+
|
2746
|
+
See: https://pytorch.org/docs/stable/generated/torch.nn.BCELoss.html
|
2747
|
+
|
2748
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2749
|
+
t = Tensor([0.1, 0.9, 0.2])
|
2750
|
+
y = Tensor([0, 1, 0])
|
2751
|
+
print(t.binary_crossentropy(y).item())
|
2752
|
+
```
|
2753
|
+
"""
|
904
2754
|
return (-y*self.log() - (1-y)*(1-self).log()).mean()
|
905
2755
|
|
906
2756
|
def binary_crossentropy_logits(self, y:Tensor) -> Tensor:
|
2757
|
+
"""
|
2758
|
+
Computes the binary cross-entropy loss between `self` and `y` where `self` is logits.
|
2759
|
+
|
2760
|
+
See: https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html
|
2761
|
+
|
2762
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2763
|
+
t = Tensor([-1, 2, -3])
|
2764
|
+
y = Tensor([0, 1, 0])
|
2765
|
+
print(t.binary_crossentropy_logits(y).item())
|
2766
|
+
```
|
2767
|
+
"""
|
907
2768
|
return (self.maximum(0) - y * self + (1 + self.abs().neg().exp()).log()).mean()
|
908
2769
|
|
909
|
-
def sparse_categorical_crossentropy(self, Y:Tensor, ignore_index=-1) -> Tensor:
|
910
|
-
|
911
|
-
|
2770
|
+
def sparse_categorical_crossentropy(self, Y:Tensor, ignore_index=-1, label_smoothing=0.0) -> Tensor:
|
2771
|
+
"""
|
2772
|
+
Computes the sparse categorical cross-entropy loss between `self` and `Y`.
|
2773
|
+
|
2774
|
+
NOTE: `self` is logits and `Y` is the target labels.
|
2775
|
+
|
2776
|
+
See: https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
|
2777
|
+
|
2778
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2779
|
+
t = Tensor([[-1, 2, -3], [1, -2, 3]])
|
2780
|
+
Y = Tensor([1, 2])
|
2781
|
+
print(t.sparse_categorical_crossentropy(Y).item())
|
2782
|
+
```
|
2783
|
+
"""
|
2784
|
+
assert 0.0 <= label_smoothing <= 1.0, "label_smoothing must be in [0.0, 1.0]"
|
2785
|
+
log_probs, loss_mask = self.log_softmax(), (Y != ignore_index)
|
912
2786
|
y_counter = Tensor.arange(self.shape[-1], requires_grad=False, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1])
|
913
|
-
y = ((y_counter == Y.flatten().reshape(-1, 1))
|
914
|
-
|
2787
|
+
y = ((y_counter == Y.flatten().reshape(-1, 1)) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
|
2788
|
+
smoothing = label_smoothing * (log_probs.mean(-1) * loss_mask).sum()
|
2789
|
+
return -((1 - label_smoothing) * (log_probs * y).sum() + smoothing) / loss_mask.sum()
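An illustrative NumPy reference for the masked, label-smoothed loss above (not tinygrad code; `ignore_index` entries are dropped from both the picked log-probabilities and the smoothing term):

```python
import numpy as np

def sparse_ce_reference(logits, targets, ignore_index=-1, label_smoothing=0.0):
    z = logits - logits.max(-1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(-1, keepdims=True))  # log-softmax
    mask = targets != ignore_index
    picked = np.take_along_axis(log_probs, np.where(mask, targets, 0)[..., None], -1)[..., 0]
    smoothing = label_smoothing * (log_probs.mean(-1) * mask).sum()
    return -((1 - label_smoothing) * (picked * mask).sum() + smoothing) / mask.sum()

print(sparse_ce_reference(np.array([[-1., 2., -3.], [1., -2., 3.]]), np.array([1, 2])))
```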
|
915
2790
|
|
916
2791
|
# ***** cast ops *****
|
917
2792
|
|
918
|
-
def
|
919
|
-
if self.dtype == dtype: return self
|
2793
|
+
def llvm_bf16_cast(self, dtype:DType):
|
920
2794
|
# hack for devices that don't support bfloat16
|
921
|
-
|
922
|
-
return
|
2795
|
+
assert self.dtype == dtypes.bfloat16
|
2796
|
+
return self.to("LLVM").bitcast(dtypes.uint16).cast(dtypes.uint32).mul(1<<16).bitcast(dtypes.float32).cast(dtype)
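The uint16 to uint32 shift-and-bitcast above is the standard bfloat16-to-float32 widening (bfloat16 stores the top half of a float32); a quick struct-based illustration, not from this diff:

```python
import struct

def bf16_bits_to_float(bits: int) -> float:
    # shift the 16 stored bits into the high half of a uint32, then reinterpret as float32
    return struct.unpack("<f", struct.pack("<I", bits << 16))[0]

print(bf16_bits_to_float(0x3F80))  # 1.0
print(bf16_bits_to_float(0x4049))  # 3.140625, bfloat16's nearest value to pi
```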
|
2797
|
+
def cast(self, dtype:DType) -> Tensor:
|
2798
|
+
"""
|
2799
|
+
Casts `self` to the given `dtype`.
|
2800
|
+
|
2801
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2802
|
+
t = Tensor([-1, 2.5, 3], dtype=dtypes.float)
|
2803
|
+
print(t.dtype, t.numpy())
|
2804
|
+
```
|
2805
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2806
|
+
t = t.cast(dtypes.int32)
|
2807
|
+
print(t.dtype, t.numpy())
|
2808
|
+
```
|
2809
|
+
"""
|
2810
|
+
return self if self.dtype == dtype else F.Cast.apply(self, dtype=dtype)
|
923
2811
|
def bitcast(self, dtype:DType) -> Tensor:
|
924
|
-
|
925
|
-
|
926
|
-
|
927
|
-
|
2812
|
+
"""
|
2813
|
+
Bitcasts `self` to the given `dtype` of the same itemsize.
|
2814
|
+
|
2815
|
+
`self` must not require a gradient.
|
2816
|
+
|
2817
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2818
|
+
t = Tensor([-1, 2, 3], dtype=dtypes.int32)
|
2819
|
+
print(t.dtype, t.numpy())
|
2820
|
+
```
|
2821
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2822
|
+
t = t.bitcast(dtypes.uint32)
|
2823
|
+
print(t.dtype, t.numpy())
|
2824
|
+
```
|
2825
|
+
"""
|
2826
|
+
if self.requires_grad: raise RuntimeError("can't backprop through bitcast")
|
2827
|
+
return F.Cast.apply(self, dtype=dtype, bitcast=True) if self.dtype != dtype else self
|
2828
|
+
def float(self) -> Tensor:
|
2829
|
+
"""
|
2830
|
+
Convenience method to cast `self` to a `float32` Tensor.
|
2831
|
+
|
2832
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2833
|
+
t = Tensor([-1, 2, 3], dtype=dtypes.int32)
|
2834
|
+
print(t.dtype, t.numpy())
|
2835
|
+
```
|
2836
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2837
|
+
t = t.float()
|
2838
|
+
print(t.dtype, t.numpy())
|
2839
|
+
```
|
2840
|
+
"""
|
2841
|
+
return self.cast(dtypes.float32)
|
2842
|
+
def half(self) -> Tensor:
|
2843
|
+
"""
|
2844
|
+
Convenience method to cast `self` to a `float16` Tensor.
|
2845
|
+
|
2846
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2847
|
+
t = Tensor([-1, 2, 3], dtype=dtypes.int32)
|
2848
|
+
print(t.dtype, t.numpy())
|
2849
|
+
```
|
2850
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2851
|
+
t = t.half()
|
2852
|
+
print(t.dtype, t.numpy())
|
2853
|
+
```
|
2854
|
+
"""
|
2855
|
+
return self.cast(dtypes.float16)
|
928
2856
|
|
929
2857
|
# ***** convenience stuff *****
|
930
2858
|
|
@@ -934,20 +2862,101 @@ class Tensor:
|
|
934
2862
|
def element_size(self) -> int: return self.dtype.itemsize
|
935
2863
|
def nbytes(self) -> int: return self.numel() * self.element_size()
|
936
2864
|
def is_floating_point(self) -> bool: return dtypes.is_float(self.dtype)
|
2865
|
+
def size(self, dim=None) -> Union[sint, Tuple[sint, ...]]: return self.shape if dim is None else self.shape[dim]
|
2866
|
+
|
2867
|
+
# *** image Tensor function replacements ***
|
2868
|
+
|
2869
|
+
def image_dot(self, w:Tensor, acc_dtype=None):
|
2870
|
+
# NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
|
2871
|
+
n1, n2 = len(self.shape), len(w.shape)
|
2872
|
+
assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D"
|
2873
|
+
assert self.shape[-1] == w.shape[-min(n2, 2)], f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({self.shape[-1]} != {w.shape[-min(n2, 2)]})" # noqa: E501
|
2874
|
+
bs, groups, cin, cout = prod(self.shape[0:-2]), prod(w.shape[0:-2]), w.shape[-2], w.shape[-1]
|
2875
|
+
out_shape_t = self.shape[0:-2] + (cout,-1) if len(self.shape) > 1 else (cout, )
|
2876
|
+
|
2877
|
+
# NOTE: with NHWC we can remove the transposes
|
2878
|
+
# bs x groups*cin x H x W
|
2879
|
+
cx = self.transpose(self.ndim-1, self.ndim-2).reshape((bs//groups, groups*cin, -1, 1))
|
2880
|
+
# groups*cout x cin x H, W
|
2881
|
+
cw = w.transpose(w.ndim-1, w.ndim-2).reshape((groups*cout, cin, 1, 1))
|
2882
|
+
return cx.image_conv2d(cw, groups=groups, acc_dtype=acc_dtype).reshape(out_shape_t).transpose(self.ndim-1, self.ndim-2)
|
2883
|
+
|
2884
|
+
def image_conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, acc_dtype=None):
|
2885
|
+
base_image_type = dtypes.imageh if getenv("FLOAT16", 0) else dtypes.imagef
|
2886
|
+
|
2887
|
+
(bs,_,iy,ix), (cout,cin,H,W) = self.shape, weight.shape
|
2888
|
+
x, w = self, weight.reshape(groups, (rcout := cout//groups), cin, H, W)
|
2889
|
+
|
2890
|
+
# hack for non multiples of 4 on cin
|
2891
|
+
if cin % 4 != 0 and not (cin == 1 and groups%4 == 0):
|
2892
|
+
x = x.reshape(bs, groups, cin, iy, ix) # do this always?
|
2893
|
+
added_input_channels = 4 - (cin % 4)
|
2894
|
+
w = w.pad(tuple((0, added_input_channels) if i == 2 else None for i in range(w.ndim)))
|
2895
|
+
x = x.pad(tuple((0, added_input_channels) if i == 2 else None for i in range(x.ndim)))
|
2896
|
+
cin = cin + added_input_channels
|
2897
|
+
x = x.reshape(bs, groups*cin, iy, ix)
|
2898
|
+
|
2899
|
+
# hack for non multiples of 4 on rcout
|
2900
|
+
added_output_channels = 0
|
2901
|
+
if rcout % 4 != 0 and not (rcout == 1 and groups%4 == 0):
|
2902
|
+
added_output_channels = 4 - (rcout % 4)
|
2903
|
+
rcout += added_output_channels
|
2904
|
+
cout = groups * rcout
|
2905
|
+
w = w.pad(tuple((0, added_output_channels) if i == 1 else None for i in range(w.ndim)))
|
2906
|
+
|
2907
|
+
# packed (note: flipping bs and iy would make the auto-padding work)
|
2908
|
+
x = x.permute(0,2,3,1)
|
2909
|
+
cin_last = iy == 1 and ix == 1
|
2910
|
+
if cin == 1: w = w.reshape(cout//4,4,H,W).permute(0,2,3,1)
|
2911
|
+
elif cin_last: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,1,3)
|
2912
|
+
else: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,3,1)
|
2913
|
+
|
2914
|
+
# contiguous creates the image, and early realize static weights (TODO: test for the static weight)
|
2915
|
+
if IMAGE >= 2: x,w = x.cast(base_image_type((bs*iy, ix*groups*cin//4, 4))), w.cast(base_image_type((cout//4, H*W*cin, 4)))
|
2916
|
+
x, w = x.contiguous(), w.contiguous()
|
2917
|
+
|
2918
|
+
# expand out
|
2919
|
+
rcin_hi, rcin_lo = cin//4 if cin >= 4 else 1, 4 if cin >= 4 else 1
|
2920
|
+
cout_expand = [groups//4 if cin == 1 else groups, 4 if cin == 1 else 1, rcout//4 if rcout >= 4 else 1, 4 if rcout >= 4 else 1]
|
2921
|
+
x = x.reshape(bs, iy, ix, groups, rcin_hi, rcin_lo)
|
2922
|
+
if cin_last: w = w.reshape(cout//4, H, rcin_hi, W, 4, rcin_lo)
|
2923
|
+
else: w = w.reshape(cout//4, H, rcin_hi, W, rcin_lo, 4).permute(0,1,2,3,5,4)
|
2924
|
+
|
2925
|
+
# padding
|
2926
|
+
padding_ = [padding]*4 if isinstance(padding, int) else (padding if len(padding) == 4 else [padding[1], padding[1], padding[0], padding[0]])
|
2927
|
+
x = x._slice((None, (-padding_[2], x.shape[1]+padding_[3]), (-padding_[0], x.shape[2]+padding_[1]), None, None, None))
|
2928
|
+
|
2929
|
+
# prepare input
|
2930
|
+
x = x.permute(0,3,4,5,1,2)._pool((H, W), stride, dilation) # -> (bs, groups, rcin_hi, rcin_lo, oy, ox, H, W)
|
2931
|
+
x = x.permute(0,4,5,1,2,3,6,7).reshape(bs, (oy := x.shape[4]), (ox := x.shape[5]), *cout_expand[0:2], 1, 1, rcin_hi, rcin_lo, H, W)
|
2932
|
+
|
2933
|
+
# prepare weights
|
2934
|
+
w = w.permute(0,4,2,5,1,3).reshape((1, 1, 1, *cout_expand, rcin_hi, rcin_lo, H, W))
|
2935
|
+
|
2936
|
+
# the conv!
|
2937
|
+
ret = (x*w).cast(base_image_type((bs*oy, ox*cout//4, 4)) if IMAGE >= 2 else dtypes.float32).sum((-4, -3, -2, -1), acc_dtype=acc_dtype)
|
2938
|
+
|
2939
|
+
# undo hack for non multiples of 4 on C.rcout
|
2940
|
+
if added_output_channels != 0:
|
2941
|
+
ret = ret.reshape(bs, oy, ox, groups, rcout)[:, :, :, :, :-added_output_channels]
|
2942
|
+
cout = groups * (rcout - added_output_channels)
|
2943
|
+
|
2944
|
+
# NCHW output
|
2945
|
+
ret = ret.reshape(bs, oy, ox, cout).permute(0,3,1,2)
|
2946
|
+
return ret if bias is None else ret.add(bias.reshape(1, -1, 1, 1))
|
937
2947
|
|
938
2948
|
# register functions to move between devices
|
939
|
-
for device in Device._devices: setattr(Tensor, f"{device.lower()}", partialmethod(Tensor.to, device))
|
2949
|
+
for device in Device._devices: setattr(Tensor, f"{device.lower()}", functools.partialmethod(Tensor.to, device))
|
940
2950
|
|
941
2951
|
if IMAGE:
|
942
2952
|
# if IMAGE>0 we install these replacement functions in Tensor (hack!)
|
943
|
-
|
944
|
-
setattr(Tensor, "
|
945
|
-
setattr(Tensor, "dot", image_dot)
|
2953
|
+
setattr(Tensor, "conv2d", Tensor.image_conv2d)
|
2954
|
+
setattr(Tensor, "dot", Tensor.image_dot)
|
946
2955
|
|
947
|
-
# TODO: remove
|
2956
|
+
# TODO: eventually remove this
|
948
2957
|
def custom_random(out:Buffer):
|
949
2958
|
Tensor._seed += 1
|
950
|
-
if DEBUG >= 2: print(f"*** {out.device} rand seed {Tensor._seed} size {out.size:<15d} dtype {out.dtype}")
|
951
2959
|
rng = np.random.default_rng(Tensor._seed)
|
952
|
-
rng_np_buffer = rng.
|
2960
|
+
if out.dtype == dtypes.half: rng_np_buffer = (rng.integers(low=0, high=2047, size=out.size) / 2048).astype(np.half, copy=False)
|
2961
|
+
else: rng_np_buffer = rng.random(size=out.size, dtype=np.float32).astype(dtype=_to_np_dtype(out.dtype), copy=False)
|
953
2962
|
out.copyin(rng_np_buffer.data)
|