tinygrad 0.8.0__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +6 -6
- tinygrad/codegen/kernel.py +230 -190
- tinygrad/codegen/linearizer.py +278 -384
- tinygrad/codegen/uops.py +415 -0
- tinygrad/device.py +132 -275
- tinygrad/dtype.py +53 -37
- tinygrad/engine/__init__.py +0 -0
- tinygrad/engine/graph.py +100 -0
- tinygrad/engine/jit.py +195 -0
- tinygrad/engine/realize.py +191 -0
- tinygrad/engine/schedule.py +362 -0
- tinygrad/engine/search.py +196 -0
- tinygrad/{mlops.py → function.py} +28 -14
- tinygrad/helpers.py +72 -43
- tinygrad/lazy.py +141 -240
- tinygrad/multi.py +169 -0
- tinygrad/nn/__init__.py +179 -8
- tinygrad/nn/datasets.py +7 -0
- tinygrad/nn/optim.py +106 -28
- tinygrad/nn/state.py +86 -17
- tinygrad/ops.py +70 -44
- tinygrad/renderer/__init__.py +61 -0
- tinygrad/renderer/assembly.py +276 -0
- tinygrad/renderer/cstyle.py +299 -206
- tinygrad/renderer/llvmir.py +118 -123
- tinygrad/runtime/autogen/amd_gpu.py +1900 -0
- tinygrad/runtime/autogen/comgr.py +865 -0
- tinygrad/runtime/autogen/cuda.py +5923 -0
- tinygrad/runtime/autogen/hip.py +5909 -0
- tinygrad/runtime/autogen/hsa.py +5761 -0
- tinygrad/runtime/autogen/kfd.py +812 -0
- tinygrad/runtime/autogen/nv_gpu.py +33328 -0
- tinygrad/runtime/autogen/opencl.py +1795 -0
- tinygrad/runtime/driver/hip_comgr.py +47 -0
- tinygrad/runtime/driver/hsa.py +143 -0
- tinygrad/runtime/graph/clang.py +38 -0
- tinygrad/runtime/graph/cuda.py +59 -54
- tinygrad/runtime/graph/hcq.py +143 -0
- tinygrad/runtime/graph/hsa.py +171 -0
- tinygrad/runtime/graph/metal.py +37 -41
- tinygrad/runtime/ops_amd.py +564 -0
- tinygrad/runtime/ops_clang.py +16 -14
- tinygrad/runtime/ops_cuda.py +130 -38
- tinygrad/runtime/ops_disk.py +45 -42
- tinygrad/runtime/ops_gpu.py +52 -50
- tinygrad/runtime/ops_hsa.py +278 -0
- tinygrad/runtime/ops_llvm.py +36 -56
- tinygrad/runtime/ops_metal.py +42 -24
- tinygrad/runtime/ops_npy.py +9 -0
- tinygrad/runtime/ops_nv.py +630 -0
- tinygrad/runtime/ops_python.py +204 -0
- tinygrad/shape/shapetracker.py +41 -105
- tinygrad/shape/symbolic.py +98 -95
- tinygrad/shape/view.py +137 -35
- tinygrad/tensor.py +2367 -442
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/METADATA +19 -9
- tinygrad-0.9.0.dist-info/RECORD +60 -0
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
- tinygrad/features/image.py +0 -93
- tinygrad/features/multi.py +0 -103
- tinygrad/features/search.py +0 -160
- tinygrad/graph.py +0 -106
- tinygrad/jit.py +0 -152
- tinygrad/realize.py +0 -50
- tinygrad/runtime/graph/hip.py +0 -24
- tinygrad/runtime/ops_cpu.py +0 -45
- tinygrad/runtime/ops_hip.py +0 -97
- tinygrad/runtime/ops_torch.py +0 -49
- tinygrad-0.8.0.dist-info/RECORD +0 -41
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
tinygrad/tensor.py
CHANGED (removed 0.8.0 lines appear truncated in this rendering)
@@ -1,19 +1,21 @@
 # inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
 from __future__ import annotations
-import time, math, itertools
-from
+import time, math, itertools, functools
+from contextlib import ContextDecorator
+from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Dict, DefaultDict, cast, get_args, Set
 from collections import defaultdict
-from functools import partialmethod, reduce
 import numpy as np

-from tinygrad.dtype import DType, dtypes, ImageDType, least_upper_float, least_upper_dtype
-from tinygrad.helpers import argfix, make_pair,
-from tinygrad.
-from tinygrad.
+from tinygrad.dtype import DType, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype
+from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, fully_flatten, argsort, getenv
+from tinygrad.helpers import IMAGE, DEBUG, WINO, THREEFRY
+from tinygrad.lazy import LazyBuffer
+from tinygrad.multi import MultiLazyBuffer
 from tinygrad.ops import LoadOps
-from tinygrad.device import Device, Buffer
-from tinygrad.shape.symbolic import sint
-from tinygrad.realize import run_schedule
+from tinygrad.device import Device, Buffer, BufferOptions
+from tinygrad.shape.symbolic import sint, Variable, MulNode, Node
+from tinygrad.engine.realize import run_schedule
+from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars, memory_planner

 # **** start with two base classes, Tensor and Function ****

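The import hunk above is the visible surface of the package reorganization listed at the top of this diff: `mlops.py` became `function.py`, scheduling and realization moved under `tinygrad.engine`, and multi-device support moved from `tinygrad/features/` to `tinygrad/multi.py`. A hypothetical migration sketch for code that reached into these internals in 0.8.0 (the top-level `from tinygrad import Tensor` entry point is unchanged; the old paths in the comments are inferred from the rename listing, not shown verbatim in this diff):

```python
# 0.8.0-era internal imports (paths removed in this release):
#   from tinygrad.realize import run_schedule
#   from tinygrad.features.multi import MultiLazyBuffer
#   import tinygrad.mlops              # now tinygrad.function

# 0.9.0 locations, matching the import hunk above:
from tinygrad.engine.realize import run_schedule
from tinygrad.engine.schedule import create_schedule_with_vars, memory_planner
from tinygrad.multi import MultiLazyBuffer
import tinygrad.function as F  # autograd ops formerly in tinygrad.mlops
```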
@@ -30,29 +32,68 @@ class Function:
   @classmethod
   def apply(fxn:Type[Function], *x:Tensor, **kwargs) -> Tensor:
     ctx = fxn(x[0].device, *x)
-    ret = Tensor
-
+    ret = Tensor.__new__(Tensor)
+    ret.lazydata, ret.requires_grad, ret.grad = ctx.forward(*[t.lazydata for t in x], **kwargs), ctx.requires_grad, None
+    ret._ctx = ctx if ctx.requires_grad and not Tensor.no_grad else None  # used by autograd engine
     return ret

-import tinygrad.
+import tinygrad.function as F

-def _loadop(op, shape:Tuple[sint,...], dtype:DType, device:Union[str, Tuple[str, ...]], arg=None, src:
+def _loadop(op, shape:Tuple[sint,...], dtype:DType, device:Union[str, Tuple[str, ...]], arg=None, src:Tuple[LazyBuffer, ...]=()):
   if isinstance(device, str): return LazyBuffer.loadop(op, shape, dtype, device, arg, src)
   return MultiLazyBuffer([LazyBuffer.loadop(op, shape, dtype, d, arg, src) for d in device], None)

-
+def _fromcpu(x: np.ndarray) -> LazyBuffer:
+  ret = LazyBuffer.loadop(LoadOps.EMPTY, x.shape, dtypes.from_np(x.dtype), "NPY")
+  # fake realize
+  ret.buffer.allocate(x)
+  del ret.srcs
+  return ret
+
+def _get_winograd_matcols(mat, dims:int, shp:Tuple[sint, ...], device:Union[str, Tuple[str, ...]]) -> List[List[Tensor]]:
+  return [[Tensor.cat(*[Tensor.full(shp[:dim] + (1,) + shp[dim+1:], float(m[k]), device=device) for m in mat], dim=dim)
+           for k in range(len(mat[0]))] for dim in range(dims)]
+
+# winograd conv 3 kernel f(4x4,3x3) see: http://arxiv.org/abs/1509.09308
+def _apply_winograd_matrix(mat, t:Tensor, dims:int) -> Tensor:
+  # multiply mat_1 @ mat_2 @ t with foldable constants, where mat_i acts on vector t along dimension i; roughly kron(mat, mat) @ t
+  # due to realize-before-expand rule in lazy.py, we must operate in this order: reshape -> expand -> arithmetic
+  t_ = t.reshape(t.shape[:dims] + (1,) * dims + t.shape[dims:]).expand(t.shape[:dims] + (len(mat),) * dims + t.shape[dims:])  # add output dims
+  # precalculate mat columns for each dim; prod(itertools.product(matcols)) gives the columns of kron(mat, mat, ...)
+  matcols = _get_winograd_matcols(mat, dims, t_.shape[dims:], t_.device)
+  # multiply each element of t_ by the corresponding stacked column of kron(mat, mat), producing only one view for each element of t
+  ret = sum(prod(col[idx] for col, idx in zip(matcols, mat_is)) * t_[mat_is] for mat_is in itertools.product(range(len(mat[0])), repeat=dims))
+  assert isinstance(ret, Tensor), "sum didn't return a Tensor"
+  return ret
+
+def _pad_left(*shps:Tuple[sint, ...], v=1): return tuple((v,) * (max(len(i_) for i_ in shps) - len(i)) + i for i in shps)
+def _broadcast_shape(*shps:Tuple[sint, ...]): return tuple(0 if any(sh_ == 0 for sh_ in sh) else max(sh) for sh in zip(*_pad_left(*shps)))

 class Tensor:
+  """
+  A `Tensor` is a multi-dimensional matrix containing elements of a single data type.
+
+  ```python exec="true" session="tensor"
+  from tinygrad import Tensor, dtypes, nn
+  import numpy as np
+  import math
+  np.set_printoptions(precision=4)
+  ```
+  """
   __slots__ = "lazydata", "requires_grad", "grad", "_ctx"
   __deletable__ = ('_ctx',)
   training: ClassVar[bool] = False
-  class train:
-    def __init__(self,
-    def __enter__(self): self.prev, Tensor.training = Tensor.training, self.
+  class train(ContextDecorator):
+    def __init__(self, mode:bool = True): self.mode = mode
+    def __enter__(self): self.prev, Tensor.training = Tensor.training, self.mode
     def __exit__(self, exc_type, exc_value, traceback): Tensor.training = self.prev

   no_grad: ClassVar[bool] = False
-
+  class inference_mode(ContextDecorator):
+    def __init__(self, mode:bool = True): self.mode = mode
+    def __enter__(self): self.prev, Tensor.no_grad = Tensor.no_grad, self.mode
+    def __exit__(self, exc_type, exc_value, traceback): Tensor.no_grad = self.prev
+  def __init__(self, data:Union[None, ConstType, List, Tuple, LazyBuffer, np.ndarray, bytes, MultiLazyBuffer, Variable],
               device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
     assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
     device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device)
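`Tensor.train` and the new `Tensor.inference_mode` both subclass `contextlib.ContextDecorator` in this hunk, so they work either as context managers or as function decorators. A small usage sketch (not code from the package):

```python
from tinygrad import Tensor

# As context managers: each one toggles a class-level flag for the duration of the block.
with Tensor.train():
  assert Tensor.training          # training-mode code (e.g. dropout active) goes here

with Tensor.inference_mode():
  assert Tensor.no_grad           # no autograd graph is built inside this block

# Because they subclass ContextDecorator, they also work as decorators:
@Tensor.inference_mode()
def predict(x: Tensor) -> Tensor:
  return (x * 2).realize()

print(predict(Tensor([1.0, 2.0])).numpy())
```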
@@ -66,18 +107,19 @@ class Tensor:
     # internal variables used for autograd graph construction
     self._ctx: Optional[Function] = None
     if isinstance(data, LazyBuffer): assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
-    elif isinstance(data, get_args(
-    elif isinstance(data,
+    elif isinstance(data, get_args(ConstType)): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_py(data), device, data)
+    elif isinstance(data, Variable): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_py(data.unbind()[1]), device, data)
+    elif isinstance(data, bytes): data = _fromcpu(np.frombuffer(data, np.uint8))
     elif data is None: data = _loadop(LoadOps.EMPTY, (0,), dtype or dtypes.default_float, device)
     elif isinstance(data, list):
-      if
-
-
-
-      data =
+      if dtype is None:
+        if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtypes.bool
+        else: dtype = dtypes.default_int if d and all_int(d) else dtypes.default_float
+      if dtype == dtypes.bfloat16: data = Tensor(_fromcpu(np.array(data, np.float32)), device=device).cast(dtypes.bfloat16).lazydata
+      else: data = _fromcpu(np.array(data, dtype.np))
     elif isinstance(data, np.ndarray):
       if data.shape == (): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_np(data.dtype), device, data.item())
-      else: data =
+      else: data = _fromcpu(data.astype(dtype.np) if dtype is not None and dtype.np is not None else data)

     # data is a LazyBuffer, but it might be on the wrong device
     if not isinstance(data, (LazyBuffer, MultiLazyBuffer)): raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}")
@@ -87,12 +129,15 @@ class Tensor:
     else:
       self.lazydata = data if data.device == device else data.copy_to_device(device)

-  def __repr__(self):
-    return f"<Tensor {self.lazydata!r} on {self.device} with grad {(self.grad.lazydata if self.grad else None)!r}>"
+  def __repr__(self): return f"<Tensor {self.lazydata!r} on {self.device} with grad {(self.grad.lazydata if self.grad is not None else None)!r}>"

   # Python has a non moving GC, so this should be okay
   def __hash__(self): return id(self)

+  def __bool__(self): raise TypeError("__bool__ on Tensor is not defined")
+
+  def __len__(self): return self.shape[0] if len(self.shape) else 1
+
   @property
   def device(self) -> Union[str, Tuple[str, ...]]: return self.lazydata.device

@@ -104,148 +149,513 @@ class Tensor:

   # ***** data handlers ****

-
-
-
+  def schedule_with_vars(self, *lst:Tensor, seen:Optional[Set[LazyBuffer]]=None) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
+    """Creates the schedule needed to realize these Tensor(s), with Variables."""
+    if getenv("FUZZ_SCHEDULE"):
+      from test.external.fuzz_schedule import fuzz_schedule
+      fuzz_schedule(flatten([x.lazydata.lbs for x in (self,)+lst]))
+    schedule, var_vals = create_schedule_with_vars(flatten([x.lazydata.lbs for x in (self,)+lst]), seen)
+    return memory_planner(schedule), var_vals
+
+  def schedule(self, *lst:Tensor, seen:Optional[Set[LazyBuffer]]=None) -> List[ScheduleItem]:
+    """Creates the schedule needed to realize these Tensor(s)."""
+    schedule, var_vals = self.schedule_with_vars(*lst, seen=seen)
+    assert len(var_vals) == 0
+    return schedule
+
+  def realize(self, *lst:Tensor, do_update_stats=True) -> Tensor:
+    """Triggers the computation needed to create these Tensor(s)."""
+    run_schedule(*self.schedule_with_vars(*lst), do_update_stats=do_update_stats)
+    return self

-  def
-
+  def replace(self, x:Tensor) -> Tensor:
+    """
+    Replaces the data of this tensor with the data of another tensor. Only the shape of the tensors must match.
+    """
+    # used for replacing a Tensor with a new version of it (potentially with a different device and dtype)
+    assert not x.requires_grad and getattr(self, '_ctx', None) is None
+    assert self.shape == x.shape, f"replace shape mismatch {self.shape} != {x.shape}"
+    self.lazydata = x.lazydata
     return self

   def assign(self, x) -> Tensor:
     # TODO: this is a hack for writing to DISK. remove with working assign
     if isinstance(self.device, str) and self.device.startswith("DISK"):
-      if x.__class__ is not Tensor: x = Tensor(x, device="
+      if x.__class__ is not Tensor: x = Tensor(x, device="NPY", dtype=self.dtype)
       self.contiguous().realize().lazydata.base.realized.copyin(x.numpy().data)
       return self
     if x.__class__ is not Tensor: x = Tensor(x, device=self.device, dtype=self.dtype)
+    if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}")
+    if self.lazydata is x.lazydata: return self  # a self assign is a NOOP
     # NOTE: we allow cross device assign
     assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}"
+    assert self.device == x.device, f"assign device mismatch {self.device} != {x.device}"
+    assert self.dtype == x.dtype, f"assign dtype mismatch {self.dtype} != {x.dtype}"
+    assert not isinstance(self.lazydata, MultiLazyBuffer) or self.lazydata.axis == x.lazydata.axis, "axis must match on MultiLazyBuffer"
     assert not x.requires_grad  # self requires_grad is okay?
-    if
-
-    if isinstance(self.lazydata, MultiLazyBuffer):
-      for d,s in zip(x.lazydata.lbs, self.lazydata.lbs): d.output_buffer = s.base.realized
-    else:
-      if self.lazydata.base.realized is not None: x.lazydata.output_buffer = self.lazydata.base.realized
-    self.lazydata = x.lazydata
+    if not self.lazydata.is_realized(): return self.replace(x)
+    self.lazydata = self.lazydata.assign(x.lazydata)
     return self
-  def detach(self) -> Tensor:
-
-
-
+  def detach(self) -> Tensor:
+    """
+    Returns a new tensor with the same data as this tensor, but detached from the autograd graph.
+    """
+    return Tensor(self.lazydata, device=self.device, requires_grad=False)
+
+  def _data(self) -> memoryview:
+    if 0 in self.shape: return memoryview(bytearray(0))
+    # NOTE: this realizes on the object from as_buffer being a Python object
+    cpu = self.cast(self.dtype.scalar()).contiguous().to("CLANG").realize()
+    buf = cast(Buffer, cast(LazyBuffer, cpu.lazydata).base.realized)
+    if self.device != "CLANG": buf.options = BufferOptions(nolru=True)
+    return buf.as_buffer(allow_zero_copy=True if self.device != "CLANG" else False)
+
+  def data(self) -> memoryview:
+    """
+    Returns the data of this tensor as a memoryview.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([1, 2, 3, 4])
+    print(np.frombuffer(t.data(), dtype=np.int32))
+    ```
+    """
+    assert self.dtype.fmt is not None, f"no fmt dtype for {self.dtype}"
+    assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
+    return self._data().cast(self.dtype.fmt, self.shape)
+
+  def item(self) -> ConstType:
+    """
+    Returns the value of this tensor as a standard Python number.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor(42)
+    print(t.item())
+    ```
+    """
+    assert self.dtype.fmt is not None, f"no fmt dtype for {self.dtype}"
     assert self.numel() == 1, "must have one element for item"
-    return
-
+    return self._data().cast(self.dtype.fmt)[0]
+
+  # TODO: should be Tensor.tolist() -> Union[List[ConstType], ConstType]. The List is Sequence because mypy expects memoryview.tolist() -> list[int]
+  # src: https://github.com/python/mypy/blob/release-1.6/mypy/typeshed/stdlib/builtins.pyi#L803
+  def tolist(self) -> Union[Sequence[ConstType], ConstType]:
+    """
+    Returns the value of this tensor as a nested list.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([1, 2, 3, 4])
+    print(t.tolist())
+    ```
+    """
+    return self.data().tolist()

-  # TODO: this should import numpy and use .data() to construct the array
   def numpy(self) -> np.ndarray:
-
-
-
-
-
-
-
-
-
-
+    """
+    Returns the value of this tensor as a `numpy.ndarray`.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([1, 2, 3, 4])
+    print(repr(t.numpy()))
+    ```
+    """
+    if self.dtype == dtypes.bfloat16: return self.float().numpy()
+    assert self.dtype.np is not None, f"no np dtype for {self.dtype}"
+    assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
+    return np.frombuffer(self._data(), dtype=self.dtype.np).reshape(self.shape)
+
+  def to(self, device:Optional[Union[str, Tuple[str, ...]]]) -> Tensor:
+    """
+    Moves the tensor to the given device.
+    """
+    device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device)
+    if device == self.device: return self
+    if not isinstance(device, str): return self.shard(device)
+    ret = Tensor(self.lazydata, device, requires_grad=self.requires_grad)
+    if self.grad is not None: ret.grad = self.grad.to(device)
+    if hasattr(self, '_ctx'): ret._ctx = self._ctx
     return ret

-  def to_(self, device:Optional[str]):
-
-
-
-
+  def to_(self, device:Optional[Union[str, Tuple[str, ...]]]):
+    """
+    Moves the tensor to the given device in place.
+    """
+    real = self.to(device)
+    # TODO: is this assign?
+    if self.grad is not None and real.grad is not None: self.grad.lazydata = real.grad.lazydata
+    self.lazydata = real.lazydata

   def shard(self, devices:Tuple[str, ...], axis:Optional[int]=None) -> Tensor:
+    """
+    Shards the tensor across the given devices.
+    """
     assert isinstance(self.lazydata, LazyBuffer), "can't shard a MultiLazyBuffer"
     canonical_devices = tuple(Device.canonicalize(x) for x in devices)
-
+    if axis is not None and axis < 0: axis += len(self.shape)
+    return Tensor(MultiLazyBuffer.from_sharded(self.lazydata, canonical_devices, axis), device=canonical_devices, requires_grad=self.requires_grad)

   def shard_(self, devices:Tuple[str, ...], axis:Optional[int]=None):
+    """
+    Shards the tensor across the given devices in place.
+    """
     self.lazydata = self.shard(devices, axis).lazydata
     return self

+  @staticmethod
+  def from_node(y:Node, **kwargs) -> Tensor:
+    if isinstance(y, MulNode): return Tensor.from_node(y.a, **kwargs) * y.b
+    if isinstance(y, Variable): return Tensor(y, **kwargs, requires_grad=False)
+    raise RuntimeError(f"unhandled Node {y}")
+
   # ***** creation llop entrypoint *****

   @staticmethod
-  def _loadop(op, shape, device:Optional[str]=None, dtype:Optional[DType]=None, arg=None, **kwargs):
-
+  def _loadop(op, shape, device:Optional[Union[Tuple[str, ...], str]]=None, dtype:Optional[DType]=None, arg=None, **kwargs):
+    if isinstance(device, tuple):
+      return Tensor(MultiLazyBuffer([LazyBuffer.loadop(op, shape, dtype or dtypes.default_float, Device.canonicalize(d), arg) \
+                                     for d in device], None), device, dtype, **kwargs)
+    return Tensor(LazyBuffer.loadop(op, shape, dtype or dtypes.default_float, Device.canonicalize(device), arg), device, dtype, **kwargs)

   @staticmethod
-  def empty(*shape, **kwargs):
+  def empty(*shape, **kwargs):
+    """
+    Creates an empty tensor with the given shape.
+
+    You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
+    Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.empty(2, 3)
+    print(t.shape)
+    ```
+    """
+    return Tensor._loadop(LoadOps.EMPTY, argfix(*shape), **kwargs)

   _seed: int = int(time.time())
+  _rng_counter: Optional[Tensor] = None
   @staticmethod
-  def manual_seed(seed=0):
+  def manual_seed(seed=0):
+    """
+    Sets the seed for random operations.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    Tensor.manual_seed(42)
+    print(Tensor._seed)
+    ```
+    """
+    Tensor._seed, Tensor._rng_counter = seed, Tensor([0], dtype=dtypes.uint32, requires_grad=False)

   @staticmethod
-  def rand(*shape,
+  def rand(*shape, device:Optional[Union[Tuple[str, ...], str]]=None, dtype:Optional[DType]=None, **kwargs):
+    """
+    Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[0, 1)`.
+
+    You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
+    Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    Tensor.manual_seed(42)
+    t = Tensor.rand(2, 3)
+    print(t.numpy())
+    ```
+    """
+    if Tensor._rng_counter is None: Tensor._rng_counter = Tensor([0], dtype=dtypes.uint32, requires_grad=False)
+    if not THREEFRY.value:
+      # for bfloat16, numpy rand passes buffer in float
+      if (dtype or dtypes.default_float) == dtypes.bfloat16:
+        return Tensor.rand(*shape, **kwargs, device=device, dtype=dtypes.float).cast(dtypes.bfloat16)
+      return Tensor._loadop(LoadOps.CUSTOM, argfix(*shape), arg=custom_random, device=device, dtype=dtype, **kwargs)
+
+    # threefry
+    if (num := prod((shape:=argfix(*shape)))) == 0: return Tensor.zeros(shape, device=device, dtype=dtype, **kwargs)
+    counts1 = (Tensor.arange(math.ceil(num / 2), device=device, dtype=dtypes.uint32, requires_grad=False)+Tensor._rng_counter.to(device)).realize()
+    counts2 = counts1 + math.ceil(num / 2)
+    Tensor._rng_counter.assign(Tensor._rng_counter + num).realize()
+
+    rotations = [[13, 15, 26, 6], [17, 29, 16, 24]]
+    ks = [0x0, Tensor._seed ^ 0x0 ^ 0x1BD11BDA, Tensor._seed]
+
+    x = [counts1 + ks[-1], counts2 + ks[0]]
+    for i in range(5):
+      for r in rotations[i % 2]: x[0], x[1] = (x0 := x[0] + x[1]), x0 ^ ((x[1] << r) + (x[1] >> (32 - r)))
+      x = [(x[0] + ks[i % 3]), (x[1] + ks[(i + 1) % 3] + i + 1)]
+    out = x[0].cat(x[1]).rshift(8).cast(dtypes.float32).div(2 ** 24)[:num]
+    out = out.reshape(shape).cast(dtypes.default_float if dtype is None else dtype)
+    out.requires_grad = kwargs.get("requires_grad")
+    return out.contiguous()

   # ***** creation helper functions *****

   @staticmethod
-  def full(shape:Tuple[sint, ...], fill_value:
+  def full(shape:Tuple[sint, ...], fill_value:ConstType, **kwargs):
+    """
+    Creates a tensor with the given shape, filled with the given value.
+
+    You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
+    Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor.full((2, 3), 42).numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor.full((2, 3), False).numpy())
+    ```
+    """
     return Tensor(fill_value, **kwargs).reshape((1, )*len(new_shape := argfix(shape))).expand(new_shape)

   @staticmethod
-  def zeros(*shape, **kwargs):
+  def zeros(*shape, **kwargs):
+    """
+    Creates a tensor with the given shape, filled with zeros.
+
+    You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
+    Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor.zeros(2, 3).numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor.zeros(2, 3, dtype=dtypes.int32).numpy())
+    ```
+    """
+    return Tensor.full(argfix(*shape), 0.0, **kwargs)

   @staticmethod
-  def ones(*shape, **kwargs):
+  def ones(*shape, **kwargs):
+    """
+    Creates a tensor with the given shape, filled with ones.
+
+    You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
+    Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor.ones(2, 3).numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor.ones(2, 3, dtype=dtypes.int32).numpy())
+    ```
+    """
+    return Tensor.full(argfix(*shape), 1.0, **kwargs)

   @staticmethod
   def arange(start, stop=None, step=1, **kwargs):
+    """
+    Returns a 1-D tensor of size `ceil((stop - start) / step)` with values from `[start, stop)`, with spacing between values given by `step`.
+
+    If `stop` is not specified, values are generated from `[0, start)` with the given `step`.
+
+    If `stop` is specified, values are generated from `[start, stop)` with the given `step`.
+
+    You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
+    Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor.arange(5).numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor.arange(5, 10).numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor.arange(5, 10, 2).numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor.arange(5.5, 10, 2).numpy())
+    ```
+    """
     if stop is None: stop, start = start, 0
+    assert all(isinstance(s, (int, float)) for s in (start, stop, step)), f"symbolic arange not supported {start=}, {stop=}, {step=}"
     dtype = kwargs.pop("dtype", dtypes.default_float if any(isinstance(x, float) for x in (start, stop, step)) else dtypes.default_int)
-    return (Tensor.full((math.ceil((stop-start)/step),), step, dtype=dtype, **kwargs).
+    return (Tensor.full((math.ceil((stop-start)/step),), step, dtype=dtype, **kwargs)._cumsum() + (start - step)).cast(dtype)

   @staticmethod
   def eye(dim:int, **kwargs):
+    """
+    Creates an identity matrix of the given dimension.
+
+    You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
+    Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor.eye(3).numpy())
+    ```
+    """
     return Tensor.ones((dim,1),**kwargs).pad((None,(0,dim))).flatten().shrink(((0,dim*dim),)).reshape(dim, dim)

-  def full_like(self, fill_value:
+  def full_like(self, fill_value:ConstType, **kwargs):
+    """
+    Creates a tensor with the same shape as `self`, filled with the given value.
+    If `dtype` is not specified, the dtype of `self` is used.
+
+    You can pass in the `device` keyword argument to control device of the tensor.
+    Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.ones(2, 3)
+    print(Tensor.full_like(t, 42).numpy())
+    ```
+    """
     return Tensor.full(self.shape, fill_value, dtype=kwargs.pop("dtype", self.dtype), device=kwargs.pop("device", self.device), **kwargs)
-
-  def
+
+  def zeros_like(self, **kwargs):
+    """
+    Creates a tensor with the same shape as `self`, filled with zeros.
+
+    You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
+    Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.ones(2, 3)
+    print(Tensor.zeros_like(t).numpy())
+    ```
+    """
+    return self.full_like(0, **kwargs)
+
+  def ones_like(self, **kwargs):
+    """
+    Creates a tensor with the same shape as `self`, filled with ones.
+
+    You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
+    Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.zeros(2, 3)
+    print(Tensor.ones_like(t).numpy())
+    ```
+    """
+    return self.full_like(1, **kwargs)

   # ***** rng hlops *****

   @staticmethod
   def randn(*shape, dtype:Optional[DType]=None, **kwargs) -> Tensor:
+    """
+    Creates a tensor with the given shape, filled with random values from a normal distribution with mean `0` and standard deviation `1`.
+    If `dtype` is not specified, the default type is used.
+
+    You can pass in the `device` keyword argument to control device of the tensor.
+    Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    Tensor.manual_seed(42)
+    print(Tensor.randn(2, 3).numpy())
+    ```
+    """
     # https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
-    src = Tensor.rand((2, *argfix(*shape)), **kwargs)
+    src = Tensor.rand((2, *argfix(*shape)), **{**kwargs, "dtype": dtypes.float32})
     return src[0].mul(2*math.pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(dtype or dtypes.default_float)

   @staticmethod
-  def randint(*shape, low=0, high=10, **kwargs) -> Tensor:
+  def randint(*shape, low=0, high=10, **kwargs) -> Tensor:
+    """
+    Creates a tensor with the given shape, filled with random integer values generated uniformly from the interval `[low, high)`.
+    If `dtype` is not specified, the default type is used.
+
+    You can pass in the `device` keyword argument to control device of the tensor.
+    Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    Tensor.manual_seed(42)
+    print(Tensor.randint(2, 3, low=5, high=10).numpy())
+    ```
+    """
+    if not isinstance(low, int) or not isinstance(high, int): raise TypeError(f"{low=} and {high=} must be integers")
+    dtype = kwargs.pop("dtype", dtypes.int32)
+    if not dtypes.is_int(dtype): raise TypeError(f"{dtype=} must be int")
+    return Tensor.uniform(*shape, low=low, high=high, dtype=dtype, **kwargs)

   @staticmethod
-  def normal(*shape, mean=0.0, std=1.0, **kwargs) -> Tensor:
+  def normal(*shape, mean=0.0, std=1.0, **kwargs) -> Tensor:
+    """
+    Creates a tensor with the given shape, filled with random values from a normal distribution with the given `mean` and standard deviation `std`.
+
+    You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
+    Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    Tensor.manual_seed(42)
+    print(Tensor.normal(2, 3, mean=10, std=2).numpy())
+    ```
+    """
+    return (std * Tensor.randn(*shape, **kwargs)) + mean

   @staticmethod
   def uniform(*shape, low=0.0, high=1.0, **kwargs) -> Tensor:
+    """
+    Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[low, high)`.
+
+    You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
+    Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    Tensor.manual_seed(42)
+    print(Tensor.uniform(2, 3, low=2, high=10).numpy())
+    ```
+    """
     dtype = kwargs.pop("dtype", dtypes.default_float)
     return ((high-low) * Tensor.rand(*shape, **kwargs)).cast(dtype) + low

   @staticmethod
-  def scaled_uniform(*shape, **kwargs) -> Tensor:
+  def scaled_uniform(*shape, **kwargs) -> Tensor:
+    """
+    Creates a tensor with the given shape, filled with random values from a uniform distribution
+    over the interval `[-prod(shape)**-0.5, prod(shape)**-0.5)`.
+
+    You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
+    Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    Tensor.manual_seed(42)
+    print(Tensor.scaled_uniform(2, 3).numpy())
+    ```
+    """
+    return Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul(prod(argfix(*shape))**-0.5)

   # https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform
   @staticmethod
   def glorot_uniform(*shape, **kwargs) -> Tensor:
+    """
+    <https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform>
+
+    You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
+    Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    Tensor.manual_seed(42)
+    print(Tensor.glorot_uniform(2, 3).numpy())
+    ```
+    """
     return Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul((6/(argfix(*shape)[0]+prod(argfix(*shape)[1:])))**0.5)

   # https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_
   @staticmethod
   def kaiming_uniform(*shape, a:float = 0.01, **kwargs) -> Tensor:
+    """
+    <https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_>
+
+    You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
+    Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    Tensor.manual_seed(42)
+    print(Tensor.kaiming_uniform(2, 3).numpy())
+    ```
+    """
     bound = math.sqrt(3.0) * math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(argfix(*shape)[1:]))
     return Tensor.uniform(*shape, low=-bound, high=bound, **kwargs)

   # https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_
   @staticmethod
   def kaiming_normal(*shape, a:float = 0.01, **kwargs) -> Tensor:
+    """
+    <https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_>
+
+    You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
+    Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    Tensor.manual_seed(42)
+    print(Tensor.kaiming_normal(2, 3).numpy())
+    ```
+    """
     std = math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(argfix(*shape)[1:]))
     return Tensor.normal(*shape, mean=0.0, std=std, **kwargs)

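The hunk above introduces the 0.9.0 data-handling surface: `schedule()`/`schedule_with_vars()` expose the lazy computation, `realize()` runs it, and `data()`, `item()`, `tolist()`, and `numpy()` all read back through the memoryview-based `_data()`. A small usage sketch (not from the package):

```python
from tinygrad import Tensor

# schedule() exposes the pending lazy computation as ScheduleItems without running it
a = Tensor([1, 2, 3, 4]) * 2
print(len(a.schedule()), "schedule item(s)")

# realize() runs the schedule; the memoryview-backed accessors then read the result
b = (Tensor([1, 2, 3, 4]) * 2).realize()
print(b.tolist())        # [2, 4, 6, 8]
print(b.numpy())         # numpy() is now built on top of _data()
print(Tensor(7).item())  # 7
```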
@@ -254,31 +664,40 @@ class Tensor:
     assert replacement or num_samples == 1, "no replacement only supports num_samples = 1"
     weight = self.unsqueeze(0) if self.ndim == 1 else self
     cdf = (cw := weight.cumsum(1).float()) / cw[:, -1].unsqueeze(1)
-    unif_samples = Tensor.rand(num_samples, cdf.shape[0], 1)
+    unif_samples = Tensor.rand(num_samples, cdf.shape[0], 1, device=self.device)
     indices = (unif_samples.expand((-1, -1, cdf.shape[1])) >= cdf).sum(2).permute((1, 0))
-    return (indices.squeeze(0) if self.ndim == 1 else indices).cast(dtypes.
+    return (indices.squeeze(0) if self.ndim == 1 else indices).cast(dtypes.int32)

   # ***** toposort and backward pass *****

-  def
-  def
+  def _deepwalk(self):
+    def _walk(node, visited):
       visited.add(node)
       if getattr(node, "_ctx", None):
         for i in node._ctx.parents:
-          if i not in visited:
-
-
-    return _deepwalk(self, set(), [])
+          if i not in visited: yield from _walk(i, visited)
+        yield node
+    return list(_walk(self, set()))

   def backward(self) -> Tensor:
+    """
+    Propagates the gradient of a tensor backwards through the computation graph.
+    Must be used on a scalar tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
+    t.sum().backward()
+    print(t.grad.numpy())
+    ```
+    """
     assert self.shape == tuple(), f"backward can only be called for scalar tensors, but it has shape {self.shape})"

     # fill in the first grad with one. don't use Tensor.ones because we don't need contiguous
     # this is "implicit gradient creation"
-    self.grad = Tensor(1.0, device=self.device, requires_grad=False)
+    self.grad = Tensor(1.0, dtype=self.dtype, device=self.device, requires_grad=False)

-    for t0 in reversed(self.
-
+    for t0 in reversed(self._deepwalk()):
+      if t0.grad is None: raise RuntimeError(f"tensor {t0} has no grad")
       grads = t0._ctx.backward(t0.grad.lazydata)
       grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None
         for g in ([grads] if len(t0._ctx.parents) == 1 else grads)]
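The `_deepwalk` rewrite above replaces the old accumulator-list recursion with a generator-based post-order walk. A standalone sketch in plain Python (not tinygrad code) of the same idea, where `parents` stands in for `node._ctx.parents` and, as in the hunk, only nodes that have parents are yielded:

```python
def topo_order(root, parents):
  # post-order DFS over the autograd graph, collected into a list
  def _walk(node, visited):
    visited.add(node)
    if parents.get(node):
      for p in parents[node]:
        if p not in visited: yield from _walk(p, visited)
      yield node
  return list(_walk(root, set()))

# backward() then propagates gradients in reversed(...) order, i.e. 'loss' first here
print(topo_order("loss", {"loss": ["mul"], "mul": ["w", "x"]}))  # ['mul', 'loss']
```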
@@ -289,67 +708,164 @@ class Tensor:
       del t0._ctx
     return self

-  # ***** movement
+  # ***** movement low level ops *****
+
+  def view(self, *shape) -> Tensor:
+    """`.view` is an alias for `.reshape`."""
+    return self.reshape(shape)

   def reshape(self, shape, *args) -> Tensor:
+    """
+    Returns a tensor with the same data as the original tensor but with a different shape.
+    `shape` can be passed as a tuple or as separate arguments.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.arange(6)
+    print(t.reshape(2, 3).numpy())
+    ```
+    """
     new_shape = argfix(shape, *args)
-
+    new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else (s if s is not None else self.shape[i]) for i,s in enumerate(new_shape)])
+    return F.Reshape.apply(self, shape=new_shape) if new_shape != self.shape else self
+
   def expand(self, shape, *args) -> Tensor:
-
-
-
-
+    """
+    Returns a tensor that is expanded to the shape that is specified.
+    Expand can also increase the number of dimensions that a tensor has.
+
+    Passing a `-1` or `None` to a dimension means that its size will not be changed.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([1, 2, 3])
+    print(t.expand(4, -1).numpy())
+    ```
+    """
+    return self._broadcast_to(tuple(sh if s==-1 or s is None else s for s, sh in zip(*(_pad_left(argfix(shape, *args), self.shape)))))
+
+  def permute(self, order, *args) -> Tensor:
+    """
+    Returns a tensor that is a permutation of the original tensor.
+    The new tensor has the same data as the original tensor but with the dimensions permuted according to the order specified.
+    `order` can be passed as a tuple or as separate arguments.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.arange(6).reshape(2, 3)
+    print(t.numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.permute(1, 0).numpy())
+    ```
+    """
+    return F.Permute.apply(self, order=argfix(order, *args))
+
+  def flip(self, axis, *args) -> Tensor:
+    """
+    Returns a tensor that reverses the order of the original tensor along given `axis`.
+    `axis` can be passed as a tuple or as separate arguments.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.arange(6).reshape(2, 3)
+    print(t.numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.flip(0).numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.flip((0, 1)).numpy())
+    ```
+    """
+    return F.Flip.apply(self, axis=[x if x >= 0 else x+len(self.shape) for x in argfix(axis, *args)])
+
   def shrink(self, arg:Tuple[Optional[Tuple[sint, sint]], ...]) -> Tensor:
-
-
+    """
+    Returns a tensor that shrinks the each axis based on input arg.
+    `arg` must have the same length as `self.ndim`.
+    For each axis, it can be `None`, which means no shrink, or a tuple `(start, end)` that works the same as Python slice.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.arange(9).reshape(3, 3)
+    print(t.numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.shrink(((None, (1, 3)))).numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.shrink((((0, 2), (0, 2)))).numpy())
+    ```
+    """
+    if all(x is None or x == (0,s) for x,s in zip(arg, self.shape)): return self
+    return F.Shrink.apply(self, arg=tuple(x if x is not None else (0,s) for x,s in zip(arg, self.shape)))
+
   def pad(self, arg:Tuple[Optional[Tuple[sint, sint]], ...], value:float=0.0) -> Tensor:
+    """
+    Returns a tensor that pads the each axis based on input arg.
+    `arg` must have the same length as `self.ndim`.
+    For each axis, it can be `None`, which means no pad, or a tuple `(pad_before, pad_after)`.
+    If `value` is specified, the tensor is padded with `value` instead of `0.0`.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.arange(6).reshape(2, 3)
+    print(t.numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.pad(((None, (1, 2)))).numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.pad(((None, (1, 2))), -2).numpy())
+    ```
+    """
     if all(x is None or x == (0,0) for x in arg): return self
-    ret =
-    return ret if 0 == value else ret +
-
-  # ***** movement
-
-  #
-  #
-  #
-  #
-  #
-  #
-  #
-  #
-  #
-  #
-  #
-  #
-  #
-  #
-  #
-  #
-  #
-  #
-  #
-  #
-  #
-  #
-  #
-  #
-  #
-  #
-  #
-
+    ret = F.Pad.apply(self, arg=(narg:=tuple(x if x is not None else (0,0) for x in arg)))
+    return ret if 0 == value else ret + F.Pad.apply(Tensor.ones_like(self), arg=narg).where(0, value)
+
+  # ***** movement high level ops *****
+
+  # Supported Indexing Implementations:
+  # 1. Int indexing (no copy)
+  #   - for all dims where there's int, shrink -> reshape
+  #   - negative indices are taken relative to the end of the sequence, so X[-2] returns the 2nd-to-last element
+  #   - X = Tensor.rand(4,5,9); X[2,-2] shrinks the Tensor to X.shrink(((2, 3), (3, 4), (0, 9))) -> X.shape=(1,1,9)
+  #   - Then we reshape (collapse) the int dim away such that for X: (1,1,9) -> (9,)
+  # 2. Slice indexing (no copy)
+  #   - for all dims where slice is start:end:stride, shrink -> Optional[flip] -> pad -> reshape -> shrink
+  #   - first shrink the Tensor to X.shrink(((start, end),))
+  #   - then we apply stride through Optional[flip] -> pad -> reshape -> shrink
+  #     - flip where dim value is negative
+  #     - pad 0's on dims such that reshaping [dim_size_padded] -> [dim_size_padded // stride, stride] is possible
+  #     - shrink [dim_size_padded // stride, stride] -> [dim_size_padded // stride, 1]
+  #     - reshape [dim_size_padded // stride, 1] -> [dim_size_padded // stride] and now you have your stride
+  # 3. None indexing (no copy)
+  #   - reshape (inject) a dim at the dim where there's None
+  # 4. Tensor indexing (copy)
+  #   - use Tensor.arange == tensor_index to create masks for dims with Tensors (adds a dim for each mask)
+  #   - combine masks together with mul
+  #   - apply mask to self by mask * self
+  #   - sum reduce away the extra dims added from creating masks
+  # Tiny Things:
+  # 1. Supported indices: Union[int, slice, Tensor, None, List, Tuple, Ellipsis]
+  #   - for any list, List[Union[List, Tuple, int]], must have homogeneous shape
+  #   - for any tuple, Tuple[Union[List, Tuple, int]], must have homogeneous shape
+  # 2. Bool indexing is not supported
+  # 3. Out of bounds Tensor indexing results in 0
+  #   - e.g: Tensor([1, 2, 3])[Tensor([4, 3, 2])] -> [0, 0, 3] index 4 and 3 are OOB
+  def __getitem__(self, indices) -> Tensor:
     # 1. indices normalization and validation
     # treat internal tuples and lists as Tensors and standardize indices to list type
-    if isinstance(indices, (
-
-
-    else: indices = [Tensor(list(i), requires_grad=False, device=self.device) if isinstance(i, (tuple, list)) else i for i in indices]
+    if isinstance(indices, list) and all_int(indices): indices = [Tensor(indices, self.device, requires_grad=False)]
+    elif isinstance(indices, (tuple, list)):
+      indices = [Tensor(list(i), self.device, requires_grad=False) if isinstance(i, (tuple, list)) else i for i in indices]
     else: indices = [indices]

+    # turn scalar Tensors into const val for int indexing if possible
+    indices = [self._to_const_val(i) if isinstance(i, Tensor) and i.shape == () else i for i in indices]
+    # move Tensor indices to the same device as self
+    indices = [i.to(self.device) if isinstance(i, Tensor) else i for i in indices]
+
     # filter ellipsis and fill with slice(None) or fill rest of indices with slice(None)
     ellipsis_idx = [dim for dim, i in enumerate(indices) if i is Ellipsis]
     fill_idx = ellipsis_idx[0] if ellipsis_idx else len(indices)
-
-    indices[fill_idx:fill_idx+1] = [slice(None)] * (len(self.shape) -
+    num_indices = len(indices) - len(ellipsis_idx) - sum(1 for i in indices if i is None)
+    indices[fill_idx:fill_idx+1] = [slice(None)] * (len(self.shape) - num_indices)

     # use Dict[type, List[dimension]] to track elements in indices
     type_dim: DefaultDict[Union[type, None], List[int]] = defaultdict(list)
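The comment block added above describes slice-with-stride indexing as shrink → optional flip → pad → reshape → shrink. A hand-rolled sketch of that decomposition on a 1-D tensor, for illustration only (`x[1:9:3]` performs the same steps internally per axis):

```python
from tinygrad import Tensor
from tinygrad.helpers import round_up

x = Tensor.arange(10)          # we want the equivalent of x[1:9:3] -> [1, 4, 7]
s, e, st = 1, 9, 3
ret = x.shrink(((s, e),))      # keep elements 1..8 (8 elements)
padded = round_up(ret.shape[0], st)
ret = ret.pad(((0, padded - ret.shape[0]),))   # pad with 0s so the length divides the stride
ret = ret.reshape(padded // st, st)            # one row per stride group
ret = ret.shrink(((0, padded // st), (0, 1))).reshape(padded // st)  # keep the first column
print(ret.numpy())             # [1 4 7]
```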
@@ -359,212 +875,640 @@ class Tensor:
|
|
359
875
|
indices_filtered = [v for v in indices if v is not None]
|
360
876
|
     for dim,i in enumerate(indices_filtered): type_dim[type(i)].append(dim)
 
-
-
-    if len(ellipsis_idx) > 1: raise IndexError("
-    if
-    if any(isinstance(i, slice) and i.step == 0 for i in indices): raise ValueError('slice step cannot be 0')
-    if num_slices > len(self.shape): raise IndexError(f"too many indices for tensor of dimension {len(self.shape)}")
+    for index_type in type_dim:
+      if index_type not in [None, int, slice, Tensor]: raise IndexError(f"{index_type=} not supported")
+    if len(ellipsis_idx) > 1: raise IndexError("indices can only have a single ellipsis ('...')")
+    if num_indices > self.ndim: raise IndexError(f"too many {num_indices=} for {self.ndim=}")
 
-    # 2. basic indexing (no copy)
+    # 2. basic indexing, uses only movement ops (no copy)
     # currently indices_filtered: Tuple[Union[slice, int, Tensor], ...]
     # turn indices in indices_filtered to Tuple[shrink_arg, strides]
     for dim in type_dim[int]:
       if (index := indices_filtered[dim]) >= (size := self.shape[dim]) or index < -size:
-        raise IndexError(f"{index=} is out of bounds
+        raise IndexError(f"{index=} is out of bounds on {dim=} with {size=}")
       indices_filtered[dim] = ((index, index+1), 1) if index >= 0 else ((size+index, size+index+1), 1)
     for dim in type_dim[slice]:
-
-
-
+      if (index := indices_filtered[dim]).step == 0: raise ValueError(f"{index=} on {dim=} cannot have 0 as step")
+      s, e, st = index.indices(self.shape[dim])
+      indices_filtered[dim] = ((0, 0) if (st * (e - s)) < 0 else (s, e) if st > 0 else (e+1, s+1), st)
+    # record tensors and skip all Tensor dims for basic indexing
+    tensor_index: List[Tensor] = []
+    for dim in type_dim[Tensor]:
+      tensor_index.append(index := indices_filtered[dim])
+      if not dtypes.is_int(index.dtype): raise IndexError(f"{index.dtype=} on {dim=} is not supported, only int tensor indexing is supported")
+      indices_filtered[dim] = ((0, self.shape[dim]), 1)
 
     new_slice, strides = ((),()) if not indices_filtered else zip(*indices_filtered)
-    ret = self.shrink(new_slice).flip(
-    # add strides by pad -> reshape -> shrink
+    ret = self.shrink(new_slice).flip(tuple(i for i, s in enumerate(strides) if s < 0))
     if any(abs(s) != 1 for s in strides):
       strides = tuple(abs(s) for s in strides)
       ret = ret.pad(tuple((0, round_up(sh, s) - sh) for s, sh in zip(strides, ret.shape)))
-      ret = ret.reshape(flatten(
+      ret = ret.reshape(tuple(flatten((sh // s, s) for s, sh in zip(strides, ret.shape))))
       ret = ret.shrink(tuple(flatten(((0, sh), (0, 1)) for sh in ret.shape[::2]))).reshape(ret.shape[::2])
 
     # inject 1 for dim where it's None and collapse dim for int
     new_shape = list(ret.shape)
     for dim in type_dim[None]: new_shape.insert(dim, 1)
-    for dim in (dims_collapsed :=
-    assert all_int(new_shape), f"does not support symbolic shape {new_shape}"
+    for dim in (dims_collapsed := tuple(dim + sum(1 for d in type_dim[None] if dim >= d) for dim in reversed(type_dim[int]))): new_shape.pop(dim)
 
     ret = ret.reshape(new_shape)
+    assert all_int(ret.shape), f"does not support symbolic shape {ret.shape}"
 
     # 3. advanced indexing (copy)
     if type_dim[Tensor]:
+      # calculate dim of current ret by subtracting dims collapsed and adding dims injected up until tensor_dim
+      def calc_dim(tensor_dim:int) -> int:
+        return tensor_dim - sum(1 for d in dims_collapsed if tensor_dim >= d) + sum(1 for d in type_dim[None] if tensor_dim >= d)
+
+      # track tensor_dim and tensor_index using a dict
+      # calc_dim to get dim and use that to normalize the negative tensor indices
+      idx: Dict[int,Tensor] = {(dim := calc_dim(td)):(tensor<0).where(ret.shape[dim],0) + tensor for td,tensor in zip(type_dim[Tensor], tensor_index)}
 
-
-
-
-
-
-
-
-
-
-
-      #
-
-
-
-
-
-
-      ret = ret.reshape(*ret.shape[:sum_dim[0]+1], *[1]*max_dim, *ret.shape[sum_dim[0]+1:])
-
-      # iteratively eq -> mul -> sum fancy index
-      try:
-        for a,i,sd in zip(arange, reshaped_idx, sum_dim): ret = (a==i).mul(ret).sum(sd)
-      except AssertionError as exc: raise IndexError(f"cannot broadcast with index shapes {', '.join(str(i.shape) for i in idx)}") from exc
+      masks, first_dim, last_dim = [], min(idx.keys()), max(idx.keys())
+      pre_reduce_shape = ret.shape[:first_dim] + (big_shape := _broadcast_shape(*(t.shape for t in idx.values()))) + ret.shape[first_dim:]
+
+      # create masks
+      for dim, i in idx.items():
+        try: i = i.reshape(i.shape + (1,)*(ret.ndim - first_dim)).expand(pre_reduce_shape)
+        except ValueError as e: raise IndexError(f"cannot broadcast indices: {e}") from e
+        a = Tensor.arange(ret.shape[dim], device=self.device, requires_grad=False).reshape((ret.shape[dim],) + (1,)*(ret.ndim - dim - 1))
+        masks.append(i == a)
+
+      # reduce masks to 1 mask
+      mask: Tensor = functools.reduce(lambda x,y: x.mul(y), masks)
+
+      # inject 1's for the extra dims added in create masks
+      sh = ret.shape[:first_dim] + (1,) * len(big_shape) + ret.shape[first_dim:]
+      # sum reduce the extra dims introduced in create masks
+      ret = (ret.reshape(sh) * mask).sum(tuple(i + len(big_shape) for i in idx.keys()), acc_dtype=ret.dtype)
 
       # special permute case
-      if
-
-        ret = ret.permute(ret_dims[tdim[0]:tdim[0]+max_dim] + ret_dims[:tdim[0]] + ret_dims[tdim[0]+max_dim:])
+      if first_dim != 0 and len(idx) != 1 and tuple(idx.keys()) != tuple(range(first_dim, last_dim+1)):
+        ret = ret.permute(*range(first_dim, first_dim+len(big_shape)), *range(0, first_dim), *range(first_dim+len(big_shape), ret.ndim))
     return ret
 
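The indexing path above splits into basic indexing (pure movement ops: shrink, flip, pad, reshape) and advanced indexing with int tensors, which is built from broadcasted one-hot masks and a sum reduce. A minimal sketch of both paths, assuming only the 0.9.0 `Tensor` API visible in this diff; the shapes and values are illustrative:

```python
from tinygrad import Tensor

t = Tensor.arange(12).reshape(3, 4)
# basic indexing: ints, slices (including negative steps) and None use only movement ops
print(t[1, ::-2].numpy())        # strided slice of row 1
print(t[None, 0].shape)          # None injects a dim, the int collapses one
# advanced indexing: an int Tensor index takes the mask-and-sum path above
print(t[Tensor([2, 0])].numpy()) # gathers rows 2 and 0
```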
-  def __setitem__(self,indices,v
-
-
-
+  def __setitem__(self, indices, v:Union[Tensor, ConstType]) -> None:
+    if isinstance(self.device, str) and self.device.startswith("DISK"):
+      self.__getitem__(indices).assign(v)
+      return
+    # NOTE: check that setitem target is valid first
+    assert all(lb.st.contiguous for lb in self.lazydata.lbs), "setitem target needs to be contiguous"
+    if not isinstance(v, (Tensor, float, int, bool)): raise TypeError(f"can't set a {type(v).__name__} to a Tensor")
+    if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype)
+    if self.requires_grad or v.requires_grad: raise NotImplementedError("setitem with requires_grad is not supported")
+    if isinstance(indices, (Tensor, list)) or (isinstance(indices, tuple) and any(isinstance(i, (Tensor, list)) for i in indices)):
+      raise NotImplementedError("Advanced indexing setitem is not currently supported")
+
+    assign_to = self.realize().__getitem__(indices)
+    # NOTE: contiguous to prevent const folding.
+    v = v.cast(assign_to.dtype)._broadcast_to(_broadcast_shape(assign_to.shape, v.shape)).contiguous()
+    assign_to.assign(v).realize()
+
+  # NOTE: using _slice is discouraged and things should migrate to pad and shrink
+  def _slice(self, arg:Sequence[Optional[Tuple[int, sint]]], value:float=0) -> Tensor:
     arg_ = tuple(a if a is not None else (0, s) for s,a in zip(self.shape, arg))
     padding = tuple((max(0, -l), max(0, r-s)) for s,(l,r) in zip(self.shape, arg_))
     return self.pad(padding, value=value).shrink(tuple((l + pl, r + pl) for (l,r),(pl,_) in zip(arg_, padding)))
 
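`__setitem__` above only supports basic indexing on a contiguous target: it casts and broadcasts the value to the selected view and then realizes the assignment. A small sketch, assuming a realized contiguous tensor and the semantics shown above (names and values are illustrative):

```python
from tinygrad import Tensor

t = Tensor.arange(9).reshape(3, 3).contiguous().realize()
t[1] = Tensor([10, 11, 12])   # value is cast and broadcast to the selected row
t[0, 0:2] = 7                 # scalars are wrapped in a Tensor first
print(t.numpy())
```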
-  def gather(self:Tensor,
-
-
-
-
+  def gather(self:Tensor, dim:int, index:Tensor) -> Tensor:
+    """
+    Gathers values along an axis specified by `dim`.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([[1, 2], [3, 4]])
+    print(t.numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.gather(1, Tensor([[0, 0], [1, 0]])).numpy())
+    ```
+    """
+    assert index.ndim == self.ndim, f"self.ndim must equal index.ndim, {self.ndim=}, {index.ndim=}"
+    assert all(s >= i for s,i in zip(self.shape, index.shape)), "all dim of index.shape must be smaller than self.shape"
+    dim = self._resolve_dim(dim)
+    index = index.to(self.device).transpose(0, dim).unsqueeze(-1)
     permarg = list(range(self.ndim))
     permarg = permarg[1:dim] + [permarg[0]] + permarg[dim+1:] + [permarg[dim]] if dim != 0 else permarg[1:] + [permarg[0]]
-    return ((
+    return ((index == Tensor.arange(self.shape[dim], requires_grad=False, device=self.device)) * self.permute(*permarg).shrink(
+      tuple([*[(0,sh) for sh in index.shape[1:-1]], None])).unsqueeze(0)).sum(-1, acc_dtype=self.dtype).transpose(0, dim)
 
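The `gather` implementation above builds a one-hot comparison against `Tensor.arange` along the gathered axis and reduces it away with a sum. The same trick can be written out by hand; a sketch for `dim=1`, using only ops from this file (shapes chosen for illustration):

```python
from tinygrad import Tensor

src = Tensor([[1, 2], [3, 4]])
idx = Tensor([[0, 0], [1, 0]])
# one-hot over the gathered axis, then multiply and reduce that axis away
onehot = (idx.unsqueeze(-1) == Tensor.arange(src.shape[1]))
print((onehot * src.unsqueeze(1)).sum(-1).numpy())
print(src.gather(1, idx).numpy())  # same result via gather
```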
   def cat(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
-
+    """
+    Concatenates self with other `Tensor` in `args` along an axis specified by `dim`.
+    All tensors must have the same shape except in the concatenating dimension.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t0, t1, t2 = Tensor([[1, 2]]), Tensor([[3, 4]]), Tensor([[5, 6]])
+    print(t0.cat(t1, t2, dim=0).numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t0.cat(t1, t2, dim=1).numpy())
+    ```
+    """
+    dim = self._resolve_dim(dim)
     assert all(len(y.shape) == len(self.shape) and all(y.shape[i] == s for i,s in enumerate(self.shape) if i != dim) for y in args)
     catargs = [self, *args]
-
-
-    shape_cumsum = [0, *itertools.accumulate(shapes)]
+    cat_dims = [s.shape[dim] for s in catargs]
+    cat_dim_cumsum = [0, *itertools.accumulate(cat_dims)]
     slc:List[List[Optional[Tuple[sint, sint]]]] = [[None for _ in self.shape] for _ in catargs]
-    for
-    return reduce(Tensor.__add__, [arg.pad(tuple(s)) for arg,s in zip(catargs, slc)])
-
-
-
-
+    for d,k,s in zip(cat_dims, cat_dim_cumsum[:-1], slc): s[dim] = (k, cat_dim_cumsum[-1] - k - d)
+    return functools.reduce(Tensor.__add__, [arg.pad(tuple(s)) for arg,s in zip(catargs, slc)])
+
+  def stack(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
+    """
+    Concatenates self with other `Tensor` in `args` along a new dimension specified by `dim`.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t0, t1, t2 = Tensor([1, 2]), Tensor([3, 4]), Tensor([5, 6])
+    print(t0.stack(t1, t2, dim=0).numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t0.stack(t1, t2, dim=1).numpy())
+    ```
+    """
     # checks for shapes and number of dimensions delegated to cat
-    return
-
-  def repeat(self, repeats
+    return self.unsqueeze(dim).cat(*[t.unsqueeze(dim) for t in args], dim=dim)
+
+  def repeat(self, repeats, *args) -> Tensor:
+    """
+    Repeats tensor number of times along each dimension specified by `repeats`.
+    `repeats` can be passed as a tuple or as separate arguments.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([1, 2, 3])
+    print(t.repeat(4, 2).numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.repeat(4, 2, 1).shape)
+    ```
+    """
+    repeats = argfix(repeats, *args)
     base_shape = (1,) * (len(repeats) - self.ndim) + self.shape
     new_shape = [x for b in base_shape for x in [1, b]]
     expand_shape = [x for rs in zip(repeats, base_shape) for x in rs]
     final_shape = [r*s for r,s in zip(repeats, base_shape)]
     return self.reshape(new_shape).expand(expand_shape).reshape(final_shape)
 
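As the `cat` body above shows, concatenation is expressed as zero-padding each argument out to the full concatenated size and summing the results, and `stack` is just `unsqueeze` followed by `cat`. A sketch of that equivalence for `dim=1` (purely illustrative):

```python
from tinygrad import Tensor

a, b = Tensor([[1, 2]]), Tensor([[3, 4]])
padded = a.pad(((0, 0), (0, 2))) + b.pad(((0, 0), (2, 0)))
print(padded.numpy())
print(a.cat(b, dim=1).numpy())   # same values, built the same way internally
print(a.stack(b, dim=0).shape)   # stack is unsqueeze + cat
```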
471
1050
|
|
1051
|
+
def _resolve_dim(self, dim:int, *, outer:bool=False) -> int:
|
1052
|
+
if not -max(1, self.ndim+outer) <= dim < max(1, self.ndim+outer):
|
1053
|
+
raise IndexError(f"{dim=} out of range {[-max(1, self.ndim+outer), max(1, self.ndim+outer)-1]}")
|
1054
|
+
return dim + self.ndim+outer if dim < 0 else dim
|
1055
|
+
|
472
1056
|
def split(self, sizes:Union[int, List[int]], dim:int=0) -> Tuple[Tensor, ...]:
|
1057
|
+
"""
|
1058
|
+
Splits the tensor into chunks along the dimension specified by `dim`.
|
1059
|
+
If `sizes` is an integer, it splits into equally sized chunks if possible, otherwise the last chunk will be smaller.
|
1060
|
+
If `sizes` is a list, it splits into `len(sizes)` chunks with size in `dim` according to `size`.
|
1061
|
+
|
1062
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1063
|
+
t = Tensor.arange(10).reshape(5, 2)
|
1064
|
+
print(t.numpy())
|
1065
|
+
```
|
1066
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1067
|
+
split = t.split(2)
|
1068
|
+
print("\\n".join([repr(x.numpy()) for x in split]))
|
1069
|
+
```
|
1070
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1071
|
+
split = t.split([1, 4])
|
1072
|
+
print("\\n".join([repr(x.numpy()) for x in split]))
|
1073
|
+
```
|
1074
|
+
"""
|
473
1075
|
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
|
474
|
-
dim =
|
475
|
-
if isinstance(sizes, int):
|
1076
|
+
dim = self._resolve_dim(dim)
|
1077
|
+
if isinstance(sizes, int): sizes = [min(sizes, self.shape[dim]-i) for i in range(0, max(1, self.shape[dim]), max(1, sizes))]
|
1078
|
+
assert sum(sizes) == self.shape[dim], f"expect sizes to sum exactly to {self.shape[dim]}, but got {sum(sizes)}"
|
476
1079
|
return tuple(self[sl] for sl in [tuple([slice(None)]*dim + [slice(sum(sizes[:i]), sum(sizes[:i + 1]))]) for i in range(len(sizes))])
|
477
1080
|
|
478
|
-
def chunk(self,
|
1081
|
+
def chunk(self, chunks:int, dim:int=0) -> List[Tensor]:
|
1082
|
+
"""
|
1083
|
+
Splits the tensor into `chunks` number of chunks along the dimension `dim`.
|
1084
|
+
If the tensor size along `dim` is not divisible by `chunks`, all returned chunks will be the same size except the last one.
|
1085
|
+
The function may return fewer than the specified number of chunks.
|
1086
|
+
|
1087
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1088
|
+
chunked = Tensor.arange(11).chunk(6)
|
1089
|
+
print("\\n".join([repr(x.numpy()) for x in chunked]))
|
1090
|
+
```
|
1091
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1092
|
+
chunked = Tensor.arange(12).chunk(6)
|
1093
|
+
print("\\n".join([repr(x.numpy()) for x in chunked]))
|
1094
|
+
```
|
1095
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1096
|
+
chunked = Tensor.arange(13).chunk(6)
|
1097
|
+
print("\\n".join([repr(x.numpy()) for x in chunked]))
|
1098
|
+
```
|
1099
|
+
"""
|
479
1100
|
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
|
480
|
-
|
481
|
-
|
482
|
-
return
|
1101
|
+
assert chunks > 0, f"expect chunks to be greater than 0, got: {chunks}"
|
1102
|
+
dim = self._resolve_dim(dim)
|
1103
|
+
return list(self.split(math.ceil(self.shape[dim]/chunks) if self.shape[dim] else [0]*chunks, dim=dim))
|
483
1104
|
|
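`_resolve_dim` normalizes negative axes, `split` slices according to explicit sizes, and `chunk` derives those sizes with a ceiling division, so it can return fewer pieces than asked for. A short sketch (shapes only, values illustrative):

```python
from tinygrad import Tensor

t = Tensor.arange(13)
print([c.shape for c in t.chunk(6)])      # ceil(13/6)=3 per chunk, so only 5 chunks come back
print([s.shape for s in t.split([6, 7])]) # explicit sizes must sum to the dim size
```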
   def squeeze(self, dim:Optional[int]=None) -> Tensor:
+    """
+    Returns a tensor with specified dimensions of input of size 1 removed.
+    If `dim` is not specified, all dimensions with size 1 are removed.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.zeros(2, 1, 2, 1, 2)
+    print(t.squeeze().shape)
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.squeeze(0).shape)
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.squeeze(1).shape)
+    ```
+    """
     if dim is None: return self.reshape(tuple(dim for dim in self.shape if dim != 1))
-
-    if not
-    if dim < 0: dim += self.ndim
-    return self if self.shape[dim] != 1 else self.reshape(self.shape[:dim] + self.shape[dim+1:])
+    dim = self._resolve_dim(dim)
+    return self if not self.ndim or self.shape[dim] != 1 else self.reshape(self.shape[:dim] + self.shape[dim+1:])
 
   def unsqueeze(self, dim:int) -> Tensor:
-
+    """
+    Returns a tensor with a new dimension of size 1 inserted at the specified `dim`.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([1, 2, 3, 4])
+    print(t.unsqueeze(0).numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.unsqueeze(1).numpy())
+    ```
+    """
+    dim = self._resolve_dim(dim, outer=True)
     return self.reshape(self.shape[:dim] + (1,) + self.shape[dim:])
 
-
-
+  def pad2d(self, padding:Sequence[int], value:float=0.0) -> Tensor:
+    """
+    Returns a tensor that pads the last two axes specified by `padding` (padding_left, padding_right, padding_top, padding_bottom).
+    If `value` is specified, the tensor is padded with `value` instead of `0.0`.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.arange(9).reshape(1, 1, 3, 3)
+    print(t.numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.pad2d((1, 1, 2, 0), value=-float("inf")).numpy())
+    ```
+    """
     slc = [(-p0, s+p1) for p0,p1,s in zip(padding[::2], padding[1::2], self.shape[::-1])][::-1]
-    return self.
+    return self._slice([(0,s) for s in self.shape[:-(len(padding)//2)]] + slc, value=value)
 
   @property
-  def T(self) -> Tensor:
-
+  def T(self) -> Tensor:
+    """`.T` is an alias for `.transpose()`."""
+    return self.transpose()
+
+  def transpose(self, dim0=1, dim1=0) -> Tensor:
+    """
+    Returns a tensor that is a transposed version of the original tensor.
+    The given dimensions `dim0` and `dim1` are swapped.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.arange(6).reshape(2, 3)
+    print(t.numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.transpose(0, 1).numpy())
+    ```
+    """
     order = list(range(self.ndim))
-    order[
+    order[dim0], order[dim1] = order[dim1], order[dim0]
     return self.permute(order)
+
   def flatten(self, start_dim=0, end_dim=-1):
-
+    """
+    Flattens the tensor by reshaping it into a one-dimensional tensor.
+    If `start_dim` or `end_dim` are passed, only dimensions starting with `start_dim` and ending with `end_dim` are flattened.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.arange(8).reshape(2, 2, 2)
+    print(t.flatten().numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.flatten(start_dim=1).numpy())
+    ```
+    """
+    start_dim, end_dim = self._resolve_dim(start_dim), self._resolve_dim(end_dim)
     return self.reshape(self.shape[:start_dim] + (prod(self.shape[start_dim:end_dim+1]), ) + self.shape[end_dim+1:])
+
   def unflatten(self, dim:int, sizes:Tuple[int,...]):
-
+    """
+    Expands dimension `dim` of the tensor over multiple dimensions specified by `sizes`.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor.ones(3, 4, 1).unflatten(1, (2, 2)).shape)
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor.ones(3, 4, 1).unflatten(1, (-1, 2)).shape)
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor.ones(5, 12, 3).unflatten(-2, (2, 2, 3, 1, 1)).shape)
+    ```
+    """
+    dim = self._resolve_dim(dim)
     return self.reshape(self.shape[:dim] + sizes + self.shape[dim+1:])
 
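The movement helpers above are all reshapes and permutes; `pad2d` takes its padding as (left, right, top, bottom) over the last two axes, and `flatten`/`unflatten` round-trip through `_resolve_dim`. A small sketch of those conventions (shapes only):

```python
from tinygrad import Tensor

t = Tensor.arange(6).reshape(1, 1, 2, 3)
print(t.pad2d((1, 0, 0, 2)).shape)                       # W grows by 1 on the left, H by 2 at the bottom
print(t.flatten(start_dim=2).shape)                      # last two axes collapse to one
print(t.flatten(start_dim=2).unflatten(2, (2, 3)).shape) # and unflatten restores them
print(t.transpose(-1, -2).shape)                         # swaps the last two axes
```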
   # ***** reduce ops *****
 
   def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Tuple[int, ...]]]=None, keepdim=False) -> Tensor:
-
-
-
-
-
-    ret = fxn.apply(self,
-    return ret if keepdim else ret.reshape(shape
-
-  def sum(self, axis=None, keepdim=False):
-
-
-
-
-
-
-
-
-
+    if self.ndim == 0:
+      if axis is not None and axis not in [-1, 0]: raise IndexError(f"{axis=} out of range of [-1, 0]")
+      axis = None
+    axis_: Tuple[int, ...] = tuple(range(len(self.shape))) if axis is None else ((axis,) if isinstance(axis, int) else tuple(axis))
+    axis_ = tuple(self._resolve_dim(x) for x in axis_)
+    ret = fxn.apply(self, axis=axis_)
+    return ret if keepdim else ret.reshape(tuple(s for i,s in enumerate(self.shape) if i not in axis_))
+
+  def sum(self, axis=None, keepdim=False, acc_dtype:Optional[DType]=None):
+    """
+    Sums the elements of the tensor along the specified axis or axes.
+
+    You can pass in `axis` and `keepdim` keyword arguments to control the axis along
+    which the sum is computed and whether the reduced dimensions are retained.
+
+    You can pass in `acc_dtype` keyword argument to control the data type of the accumulation.
+    If not specified, the accumulation data type is chosen based on the input tensor's data type.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.arange(6).reshape(2, 3)
+    print(t.numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.sum().numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.sum(axis=0).numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.sum(axis=1).numpy())
+    ```
+    """
+    ret = self.cast(acc_dtype or sum_acc_dtype(self.dtype))._reduce(F.Sum, axis, keepdim)
+    return ret.cast(self.dtype) if self.dtype in {dtypes.float16, dtypes.bfloat16} else ret
+  def max(self, axis=None, keepdim=False):
+    """
+    Returns the maximum value of the tensor along the specified axis or axes.
+
+    You can pass in `axis` and `keepdim` keyword arguments to control the axis along
+    which the maximum is computed and whether the reduced dimensions are retained.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([[1, 0, 2], [5, 4, 3]])
+    print(t.numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.max().numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.max(axis=0).numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.max(axis=1, keepdim=True).numpy())
+    ```
+    """
+    return self._reduce(F.Max, axis, keepdim)
+  def min(self, axis=None, keepdim=False):
+    """
+    Returns the minimum value of the tensor along the specified axis or axes.
+
+    You can pass in `axis` and `keepdim` keyword arguments to control the axis along
+    which the minimum is computed and whether the reduced dimensions are retained.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([[1, 0, 2], [5, 4, 3]])
+    print(t.numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.min().numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.min(axis=0).numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.min(axis=1, keepdim=True).numpy())
+    ```
+    """
+    return -((-self).max(axis=axis, keepdim=keepdim))
 
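`_reduce` above normalizes `axis` to a tuple of resolved dims and drops the reduced axes unless `keepdim=True`; `sum` optionally takes an `acc_dtype`, and `min` is implemented as a negated `max`. A quick sketch of the axis/keepdim behaviour:

```python
from tinygrad import Tensor

t = Tensor.arange(6).reshape(2, 3)
print(t.sum(axis=1).shape)                # (2,)  reduced axis is dropped
print(t.sum(axis=1, keepdim=True).shape)  # (2, 1)
print(t.min(axis=0).numpy())              # computed as -((-t).max(axis=0))
```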
   def mean(self, axis=None, keepdim=False):
-
-
-
-
+    """
+    Returns the mean value of the tensor along the specified axis or axes.
+
+    You can pass in `axis` and `keepdim` keyword arguments to control the axis along
+    which the mean is computed and whether the reduced dimensions are retained.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    Tensor.manual_seed(42)
+    t = Tensor.normal(2, 3, mean=2.5, std=0.5)
+    print(t.numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.mean().numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.mean(axis=0).numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.mean(axis=1).numpy())
+    ```
+    """
+    output_dtype = self.dtype if dtypes.is_float(self.dtype) else dtypes.float32
+    numerator = self.cast(sum_acc_dtype(self.dtype)).sum(axis=axis, keepdim=keepdim)
+    return numerator.div(prod([si for si, so in zip(self.shape, self.sum(axis=axis, keepdim=True).shape) if si != so])).cast(output_dtype)
+
+  def var(self, axis=None, keepdim=False, correction=1):
+    """
+    Returns the variance of the tensor along the specified axis or axes.
+
+    You can pass in `axis`, `keepdim`, and `correction` keyword arguments to control the axis along
+    which the variance is computed, whether the reduced dimensions are retained, and the Bessel's correction applied.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    Tensor.manual_seed(42)
+    t = Tensor.normal(2, 3, mean=2.5, std=0.5)
+    print(t.numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.var().numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.var(axis=0).numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.var(axis=1).numpy())
+    ```
+    """
     assert all_int(self.shape), "does not support symbolic shape"
     square_sum = ((self - self.mean(axis=axis, keepdim=True)).square()).sum(axis=axis, keepdim=keepdim)
-    return square_sum.div(prod(self.shape)/prod(square_sum.shape)-correction)
+    return square_sum.div(max(0, prod(self.shape)/prod(square_sum.shape)-correction))
+
+  def std(self, axis=None, keepdim=False, correction=1):
+    """
+    Returns the standard deviation of the tensor along the specified axis or axes.
+
+    You can pass in `axis`, `keepdim`, and `correction` keyword arguments to control the axis along
+    which the standard deviation is computed, whether the reduced dimensions are retained, and the Bessel's correction applied.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    Tensor.manual_seed(42)
+    t = Tensor.normal(2, 3, mean=2.5, std=0.5)
+    print(t.numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.std().numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.std(axis=0).numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.std(axis=1).numpy())
+    ```
+    """
+    return self.var(axis, keepdim, correction).sqrt()
+
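`var` above divides the centred square sum by `max(0, n - correction)`, i.e. Bessel's correction by default, and `std` is just its square root. A small numeric sketch:

```python
from tinygrad import Tensor

t = Tensor([1.0, 2.0, 3.0, 4.0])
print(t.var().numpy())              # sum((x - mean)^2) / (n - 1)
print(t.var(correction=0).numpy())  # population variance: divide by n
print(t.std().numpy())              # sqrt of the default var
```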
   def _softmax(self, axis):
     m = self - self.max(axis=axis, keepdim=True)
     e = m.exp()
     return m, e, e.sum(axis=axis, keepdim=True)
 
   def softmax(self, axis=-1):
+    """
+    Applies the softmax function to the tensor along the specified axis.
+
+    Rescales the elements of the tensor such that they lie in the range [0, 1] and sum to 1.
+
+    You can pass in the `axis` keyword argument to control the axis along which the softmax is computed.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    Tensor.manual_seed(42)
+    t = Tensor.randn(2, 3)
+    print(t.numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.softmax().numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.softmax(axis=0).numpy())
+    ```
+    """
     _, e, ss = self._softmax(axis)
     return e.div(ss)
 
   def log_softmax(self, axis=-1):
+    """
+    Applies the log-softmax function to the tensor along the specified axis.
+
+    The log-softmax function is a numerically stable alternative to the softmax function in log space.
+
+    You can pass in the `axis` keyword argument to control the axis along which the log-softmax is computed.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    Tensor.manual_seed(42)
+    t = Tensor.randn(2, 3)
+    print(t.numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.log_softmax().numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.log_softmax(axis=0).numpy())
+    ```
+    """
     m, _, ss = self._softmax(axis)
     return m - ss.log()
 
+  def logsumexp(self, axis=None, keepdim=False):
+    """
+    Computes the log-sum-exp of the tensor along the specified axis or axes.
+
+    The log-sum-exp function is a numerically stable way to compute the logarithm of the sum of exponentials.
+
+    You can pass in `axis` and `keepdim` keyword arguments to control the axis along
+    which the log-sum-exp is computed and whether the reduced dimensions are retained.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    Tensor.manual_seed(42)
+    t = Tensor.randn(2, 3)
+    print(t.numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.logsumexp().numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.logsumexp(axis=0).numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.logsumexp(axis=1).numpy())
+    ```
+    """
+    m = self.max(axis=axis, keepdim=True)
+    return (self - m).exp().sum(axis=axis, keepdim=keepdim).log() + m.squeeze(axis)
+
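`softmax`, `log_softmax` and `logsumexp` above all share the max-subtraction trick for numerical stability, and they satisfy `log_softmax(x) = x - logsumexp(x)` along the reduced axis. A sketch of that identity (equal up to floating point error):

```python
from tinygrad import Tensor

t = Tensor([[1.0, 2.0, 3.0]])
print(t.log_softmax().numpy())
print((t - t.logsumexp(axis=1, keepdim=True)).numpy())  # same values
```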
   def argmax(self, axis=None, keepdim=False):
+    """
+    Returns the indices of the maximum value of the tensor along the specified axis.
+
+    You can pass in `axis` and `keepdim` keyword arguments to control the axis along
+    which the maximum is computed and whether the reduced dimensions are retained.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([[1, 0, 2], [5, 4, 3]])
+    print(t.numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.argmax().numpy()) # Returns the index of the maximum value in the flattened tensor.
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.argmax(axis=0).numpy()) # Returns the indices of the maximum values along axis 0.
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.argmax(axis=1).numpy()) # Returns the indices of the maximum values along axis 1.
+    ```
+    """
     if axis is None:
       idx = (self == self.max(axis)) * Tensor.arange(prod(self.shape)-1,-1,-1, requires_grad=False, device=self.device).reshape(self.shape)
-      return (prod(self.shape) - idx.max() - 1).cast(dtypes.
-    axis =
+      return (prod(self.shape) - idx.max() - 1).cast(dtypes.int32)
+    axis = self._resolve_dim(axis)
     m = self == self.max(axis=axis, keepdim=True)
     idx = m * Tensor.arange(self.shape[axis]-1,-1,-1, requires_grad=False, device=self.device).reshape(self.shape[axis], *[1]*(self.ndim-axis-1))
-    return (self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)-1).cast(dtypes.
-
+    return (self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)-1).cast(dtypes.int32)
+
+  def argmin(self, axis=None, keepdim=False):
+    """
+    Returns the indices of the minimum value of the tensor along the specified axis.
+
+    You can pass in `axis` and `keepdim` keyword arguments to control the axis along
+    which the minimum is computed and whether the reduced dimensions are retained.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([[1, 0, 2], [5, 4, 3]])
+    print(t.numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.argmin().numpy()) # Returns the index of the minimum value in the flattened tensor.
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.argmin(axis=0).numpy()) # Returns the indices of the minimum values along axis 0.
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.argmin(axis=1).numpy()) # Returns the indices of the minimum values along axis 1.
+    ```
+    """
+    return (-self).argmax(axis=axis, keepdim=keepdim)
 
   @staticmethod
-  def einsum(formula:str, *raw_xs) -> Tensor:
+  def einsum(formula:str, *raw_xs, acc_dtype:Optional[DType]=None) -> Tensor:
+    """
+    Sums the product of the elements of the input tensors according to a formula based on the Einstein summation convention.
+
+    See: https://pytorch.org/docs/stable/generated/torch.einsum.html
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    x = Tensor([[1, 2], [3, 4]])
+    y = Tensor([[5, 6], [7, 8]])
+    print(Tensor.einsum("ij,ij->", x, y).numpy())
+    ```
+    """
     xs:Tuple[Tensor] = argfix(*raw_xs)
     formula = formula.replace(" ", "")
     inputs_str, output = formula.split("->") if "->" in formula else (formula, sorted(formula))
@@ -580,9 +1524,13 @@ class Tensor:
       # permute to the sorted letter order, then reshape/expand to create dimensions for the missing letters
       xs_.append(x.permute(order).reshape([val if letter in letters else 1 for letter,val in letter_val]).expand([val for _,val in letter_val]))
 
-
+    # determine the inverse permutation to revert back to original order
+    rhs_letter_order = argsort(list(output))
+    rhs_order = argsort(rhs_letter_order)
+
     # sum over all axes that's not in the output, then permute to the output order
-    return reduce(lambda a,b:a*b, xs_)
+    return functools.reduce(lambda a,b:a*b, xs_) \
+      .sum(axis=[axis for axis,(letter,_) in enumerate(letter_val) if letter not in output],acc_dtype=acc_dtype).permute(rhs_order)
 
   # ***** processing ops *****
 
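`einsum` above reshapes every operand onto a common letter-ordered set of axes, multiplies, sums the letters missing from the output, and permutes back. A sketch showing it agrees with `matmul` for the standard contraction (the `acc_dtype` argument is optional):

```python
from tinygrad import Tensor

a = Tensor([[1, 2], [3, 4]])
b = Tensor([[5, 6], [7, 8]])
print(Tensor.einsum("ik,kj->ij", a, b).numpy())
print(a.matmul(b).numpy())  # same contraction
```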
@@ -590,46 +1538,72 @@ class Tensor:
     assert len(self.shape) >= len(k_), f"can't pool {self.shape} with {k_}"
     assert all_int(self.shape) and all_int(k_), f"does not support symbolic {self.shape=}, {k_=}"
     s_, d_ = make_pair(stride, len(k_)), make_pair(dilation, len(k_))
-    assert len(k_) == len(s_)
-
+    assert len(k_) == len(s_) == len(d_), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}"
+    noop_, i_ = [None] * len(self.shape[:-len(k_)]), self.shape[-len(k_):]
     if any(k > s for k,s in zip(k_, s_)) or any(d != 1 for d in d_):
       o_ = [(i - d * (k-1) - 1)//s + 1 for i,d,k,s in zip(i_, d_, k_, s_)]
-
-      xup = self.
-      #
-      xup = xup.
-
-      xup = xup.
-
-
-
-      xup = xup.reshape(*prefix, *flatten((k,o) for k,o in zip(k_, o_)))
-      return xup.permute(*range(len(prefix)), *[len(prefix)+i*2+1 for i in range(len(k_))], *[len(prefix)+i*2 for i in range(len(k_))])
+      # repeats such that we don't need padding
+      xup = self.repeat([1]*len(noop_) + [math.ceil(k*(i+d) / i) for k,i,d in zip(k_, i_, d_)])
+      # slice by dilation
+      xup = xup.shrink(tuple(noop_ + [(0,k*(i+d)) for k,i,d in zip(k_, i_, d_)])).reshape(noop_ + flatten((k,i+d) for k,i,d in zip(k_, i_, d_)))
+      # handle stride
+      xup = xup.shrink(noop_ + flatten(((0,k), (0,o*s)) for k,o,s in zip(k_, o_, s_))).reshape(noop_ + flatten((k,o,s) for k,o,s in zip(k_, o_, s_)))
+      xup = xup.shrink(noop_ + flatten(((0,k), (0,o), (0,1)) for k,o in zip(k_, o_))).reshape(noop_ + flatten((k,o) for k,o in zip(k_, o_)))
+      # permute to move reduce to the end
+      return xup.permute(*range(len(noop_)), *[len(noop_)+i*2+1 for i in range(len(i_))], *[len(noop_)+i*2 for i in range(len(i_))])
     # TODO: once the shapetracker can optimize well, remove this alternative implementation. or not if the CPU implementation doesn't use ShapeTracker
     o_ = [(i+(s-k))//s for i,s,k in zip(i_, s_, k_)]
-    xup = self.
-    xup = xup.reshape(
-    xup = xup.
-    return xup.permute(*range(len(
+    xup = self.pad(tuple(noop_ + [(0, max(0,o*s-i)) for i,o,s in zip(i_, o_, s_)])).shrink(tuple(noop_ + [(0,o*s) for o,s in zip(o_, s_)]))
+    xup = xup.reshape(noop_ + flatten(((o,s) for o,s in zip(o_, s_))))
+    xup = xup.shrink(noop_ + flatten(((0,o), (0,k)) for o,k in zip(o_, k_)))
+    return xup.permute(*range(len(noop_)), *[len(noop_)+i*2 for i in range(len(i_))], *[len(noop_)+i*2+1 for i in range(len(i_))])
 
   # NOTE: these work for more than 2D
-  def avg_pool2d(self, kernel_size=(2,2), stride=None, dilation=1):
-
-
-
-
-
-
-
-
-
-
-
-
-    return
-
-
-
+  def avg_pool2d(self, kernel_size=(2,2), stride=None, dilation=1):
+    """
+    Applies average pooling over a tensor.
+
+    NOTE: unlike PyTorch, this implementation is not limited to only 2d pooling and instead works for any number of dimensions.
+
+    See: https://paperswithcode.com/method/average-pooling
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.arange(25).reshape(1, 1, 5, 5)
+    print(t.avg_pool2d().numpy())
+    ```
+    """
+    return self._pool(
+      make_pair(kernel_size), stride if stride is not None else kernel_size, dilation).mean(axis=tuple(range(0-len(make_pair(kernel_size)), 0)))
+  def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1):
+    """
+    Applies max pooling over a tensor.
+
+    NOTE: unlike PyTorch, this implementation is not limited to only 2d pooling and instead works for any number of dimensions.
+
+    See: https://paperswithcode.com/method/max-pooling
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.arange(25).reshape(1, 1, 5, 5)
+    print(t.max_pool2d().numpy())
+    ```
+    """
+    return self._pool(
+      make_pair(kernel_size), stride if stride is not None else kernel_size, dilation).max(axis=tuple(range(0-len(make_pair(kernel_size)), 0)))
+
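`_pool` above computes the pooled size per axis as `(i - d*(k-1) - 1)//s + 1` when dilation or overlapping kernels are involved, and both pooling wrappers simply reduce over the trailing kernel axes. A sketch of the resulting shapes:

```python
from tinygrad import Tensor

t = Tensor.arange(25).reshape(1, 1, 5, 5)
print(t.max_pool2d(kernel_size=(3, 3), stride=2).shape)  # (5 - (3-1) - 1)//2 + 1 = 2 per spatial axis
print(t.avg_pool2d().shape)                              # default 2x2 kernel, stride = kernel
```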
+  def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, acc_dtype:Optional[DType]=None) -> Tensor:
+    """
+    Applies a convolution over a tensor with a given `weight` and optional `bias`.
+
+    NOTE: unlike PyTorch, this implementation is not limited to only 2d convolutions and instead works for any number of dimensions.
+
+    See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.arange(9).reshape(1, 1, 3, 3)
+    w = Tensor.ones(1, 1, 2, 2)
+    print(t.conv2d(w).numpy())
+    ```
+    """
     (bs,cin_), (cout,cin), HW = self.shape[:2], weight.shape[:2], weight.shape[2:]
     assert groups*cin == cin_ and len(self.shape) == len(weight.shape), f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})"  # noqa: E501
     if isinstance(padding, (tuple,list)): assert len(padding) == 2*len(HW) or len(padding) == len(HW), f"Expected padding of length {2*len(HW)} or {len(HW)}, but got {len(padding)} for tensor of shape {self.shape}"  # noqa: E501
@@ -638,19 +1612,17 @@ class Tensor:
     # conv2d is a pooling op (with padding)
     x = self.pad2d(padding_)._pool(HW, stride, dilation)  # (bs, groups*cin, oy, ox, H, W)
     rcout, oyx = cout//groups, x.shape[2:-len(HW)]
-    if not all(x == 3 for x in HW) or stride != 1 or dilation != 1 or not
+    if not all(x == 3 for x in HW) or stride != 1 or dilation != 1 or not WINO:
       # normal conv
       x = x.reshape(bs, groups, cin, 1, *oyx, *HW).expand(bs, groups, cin, rcout, *oyx, *HW).permute(0,1,3,*[4+i for i in range(len(oyx))],2,*[4+len(oyx)+i for i in range(len(HW))])  # noqa: E501
 
       # conv! broadcasted to (bs, groups, rcout, *oyx, cin, *HW)
-      ret = (x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW)).sum([-1-i for i in range(1+len(oyx))], keepdim=True).reshape(bs, cout, *oyx)  # noqa: E501
+      ret = (x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW)).sum([-1-i for i in range(1+len(oyx))], keepdim=True, acc_dtype=acc_dtype).reshape(bs, cout, *oyx)  # noqa: E501
       return ret if bias is None else ret.add(bias.reshape(1, -1, *[1] * len(HW)))
 
-    # winograd conv 3 kernel f(4x4,3x3) see: http://arxiv.org/abs/1509.09308
-    def apply_matrix(mat, t, dim=0): return t if dim == len(HW) else Tensor.stack([apply_matrix(mat, sum(mm*t[j] for j,mm in enumerate(m) if mm), dim=dim+1) for m in mat])  # noqa: E501
     HWI, HWO = (6,) * len(HW), (4,) * len(HW)  # F(4x4,3x3) winograd tiles
-    winograd_Bt = [[4, 0, -5, 0, 1, 0], [0, -4, -4, 1, 1, 0], [0, 4, -4, -1, 1, 0], [0, -2, -1, 2, 1, 0], [0, 2, -1, -2, 1, 0], [0, 4, 0, -5, 0, 1]]
     winograd_G = [[1/4, 0, 0], [-1/6, -1/6, -1/6], [-1/6, 1/6, -1/6], [1/24, 1/12, 1/6], [1/24, -1/12, 1/6], [0, 0, 1]]
+    winograd_Bt = [[4, 0, -5, 0, 1, 0], [0, -4, -4, 1, 1, 0], [0, 4, -4, -1, 1, 0], [0, -2, -1, 2, 1, 0], [0, 2, -1, -2, 1, 0], [0, 4, 0, -5, 0, 1]]
     winograd_At = [[1, 1, 1, 1, 1, 0], [0, 1, -1, 2, -2, 0], [0, 1, 1, 4, 4, 0], [0, 1, -1, 8, -8, 1]]  # applying At in pre-order doubles compile time
 
     # todo: stride == dilation
@@ -658,19 +1630,19 @@ class Tensor:
     # (bs, cin_, tyx, HWI)
     d = self.pad2d(sum([[padding_[i*2], padding_[i*2+1] + (-(dim + sum(padding_[i * 2:(i + 1) * 2]) - 2) % 4)] for i, dim in enumerate(self.shape[-len(HW):])], []))._pool(HWI, HWO)  # noqa: E501
     # move HW to the front: # (HWI, bs, cin_, tyx)
-    d = d.permute(*range(len(d.shape)-len(HW),len(d.shape)), *range(len(d.shape)-len(HW)))
+    d = d.permute(*range(len(d.shape)-len(HW),len(d.shape)), *range(len(d.shape)-len(HW)))
     tyx = d.shape[-len(HWI):]  # dim of tiling
 
     g = weight.permute(*range(len(weight.shape)-len(HW),len(weight.shape)), *range(len(weight.shape)-len(HW)))  # move HW to the front
 
     # compute 6x6 winograd tiles: GgGt, BtdB
     # (HWI, groups * rcout, cin) -> (HWI, bs=1, groups, rcout, cin, tyx=(1,1))
-    gfactors =
+    gfactors = _apply_winograd_matrix(winograd_G, g, len(HW)).reshape(*HWI, 1, groups, rcout, cin, *([1]*len(tyx)))
     # (HWI, bs, cin_, tyx) -> (HWI, bs, groups, 1 ,cin, *tyx)
-    dfactors =
+    dfactors = _apply_winograd_matrix(winograd_Bt, d, len(HW)).reshape(*HWI, bs, groups, 1, cin, *tyx)
 
     # matmul; sum across cin: (HWI, bs, groups, rcout, *tyx); then HWI -> HWO: (HWO, bs, groups, rcout, *tyx)
-    ret =
+    ret = _apply_winograd_matrix(winograd_At, (gfactors * dfactors).sum(axis=-1-len(HW), acc_dtype=acc_dtype), len(HW))
 
     # interleave tyx and HWO: (bs, groups, rcout, oy, HO, ox, WO)
     ret = ret.permute([*range(len(HW), len(ret.shape)-len(HW)), *[i+o for i in range(len(HW)) for o in [len(ret.shape)-len(HW),0]]])
@@ -679,19 +1651,81 @@ class Tensor:
 
     return (ret if bias is None else ret.add(bias.reshape(1, -1, *[1 for _ in range(len(HW))]))).contiguous().contiguous_backward()
 
-  def
+  def conv_transpose2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, output_padding=0) -> Tensor:
+    """
+    Applies a transposed convolution over a tensor with a given `weight` and optional `bias`.
+
+    NOTE: unlike PyTorch, this implementation is not limited to only 2d transposed convolutions and instead works for any number of dimensions.
+
+    See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.arange(9).reshape(1, 1, 3, 3)
+    w = Tensor.ones(1, 1, 2, 2)
+    print(t.conv_transpose2d(w).numpy())
+    ```
+    """
+    HW, trailing = weight.shape[2:], list(range(3, len(weight.shape)+1))
+    x, w = self, weight.unflatten(0, (groups, -1)).permute(0,2,1,*trailing).flip(trailing)
+    stride = make_pair(stride, len(HW))
+    if any(s>1 for s in stride):
+      x = x.reshape(None, None, *flatten((k,1) for k in x.shape[2:]))
+      x = x.pad((None, None, *flatten((None,(0,s-1)) for s in stride)))
+      x = x.reshape(None, None, *[k*s for k,s in zip(x.shape[2::2], stride)])
+      x = x.shrink((None, None, *[(0,k-(s-1)) for k,s in zip(x.shape[2:], stride)]))
+    padding = flatten((((k-1)*d-p,(k-1)*d-p+op) for k,d,p,op in reversed(list(
+      zip(HW, make_pair(dilation, len(HW)), make_pair(padding, len(HW)), make_pair(output_padding, len(HW)))))))
+    return x.conv2d(w.flatten(end_dim=1), groups=groups, bias=bias, dilation=dilation, padding=padding)
+
+  def dot(self, w:Tensor, acc_dtype:Optional[DType]=None) -> Tensor:
+    """
+    Performs dot product between two tensors.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    a = Tensor([[1, 2], [3, 4]])
+    b = Tensor([[5, 6], [7, 8]])
+    print(a.dot(b).numpy())
+    ```
+    """
     n1, n2 = len(self.shape), len(w.shape)
     assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D"
-    assert self.shape[-1] == w.shape[-min(n2, 2)], f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({
+    assert (L:=self.shape[-1]) == (R:=w.shape[-min(n2, 2)]), f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({L} != {R})"
     x = self.reshape(*self.shape[0:-1], *[1]*min(n1-1, n2-1, 1), self.shape[-1])
     w = w.reshape(*w.shape[0:-2], *[1]*min(n1-1, n2-1, 1), *w.shape[-min(n2, 2):]).transpose(-1, -min(n2, 2))
-    return (x*w).sum(-1).cast(least_upper_dtype(x.dtype, w.dtype))
+    return (x*w).sum(-1, acc_dtype=acc_dtype).cast(least_upper_dtype(x.dtype, w.dtype))
+
+  def matmul(self, x:Tensor, reverse=False, acc_dtype:Optional[DType]=None) -> Tensor:
+    """
+    Performs matrix multiplication between two tensors.
+
+    You can pass in the `reverse` keyword argument to control the order of the matrix multiplication.
+    You can pass in the optional `acc_dtype` keyword argument to control the data type of the accumulation.
 
-
+    ```python exec="true" source="above" session="tensor" result="python"
+    a = Tensor([[1, 2], [3, 4]])
+    b = Tensor([[5, 6], [7, 8]])
+    print(a.matmul(b).numpy())
+    ```
+    """
+    return x.dot(self, acc_dtype=acc_dtype) if reverse else self.dot(x, acc_dtype=acc_dtype)
 
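`conv2d` above is `pad2d` followed by `_pool` and a broadcasted multiply-reduce (with an optional Winograd path gated by `WINO`), and `conv_transpose2d` rewrites itself as a `conv2d` over a dilated input and a flipped weight. A small shape sketch, assuming the default non-Winograd path:

```python
from tinygrad import Tensor

x = Tensor.randn(1, 1, 4, 4)
w = Tensor.randn(1, 1, 3, 3)
print(x.conv2d(w).shape)              # (1, 1, 2, 2): 4 - (3-1) per spatial axis
print(x.conv2d(w, padding=1).shape)   # (1, 1, 4, 4): padding=1 keeps the spatial size
print(x.conv2d(w, stride=2).conv_transpose2d(w, stride=2).shape)
```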
   def _cumsum(self, axis:int=0, _first_zero=False) -> Tensor:
-
+    pl_sz = self.shape[axis] - int(not _first_zero and self.shape[axis] != 0)
+    return self.transpose(axis,-1).pad2d((pl_sz,0))._pool((self.shape[axis] or 1,)).sum(-1).transpose(axis,-1)
   def cumsum(self, axis:int=0) -> Tensor:
+    """
+    Computes the cumulative sum of the tensor along the specified axis.
+
+    You can pass in the `axis` keyword argument to control the axis along which the cumulative sum is computed.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.ones(2, 3)
+    print(t.numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.cumsum(1).numpy())
+    ```
+    """
     # TODO: someday the optimizer will find this on it's own
     # for now this is a two stage cumsum
     SPLIT = 256
@@ -706,72 +1740,527 @@ class Tensor:
   @staticmethod
   def _tri(r:sint, c:sint, k:int=0, **kwargs) -> Tensor:
     assert all_int((r,c)), "does not support symbolic"
+    if r == 0: return Tensor.zeros((r, c), **kwargs)
     return Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(-k, c-k, **kwargs).unsqueeze(0).expand(r,c)
-  def triu(self, k:int=0) -> Tensor:
-
-
-
-
-
-
-
-
-
-
-
-
-  def
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-  def
-
-
-
-
-
-
-
-
-  def
-
-
-
-
-
-
-
-
-  def
-
-
-
-
-  def
-
-
-
-
-  def
+  def triu(self, k:int=0) -> Tensor:
+    """
+    Returns the upper triangular part of the tensor, the other elements are set to 0.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([[1, 2, 3], [4, 5, 6]])
+    print(t.numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.triu(k=1).numpy())
+    ```
+    """
+    return Tensor._tri(self.shape[-2], self.shape[-1], k=k, device=self.device).where(self, 0)
+  def tril(self, k:int=0) -> Tensor:
+    """
+    Returns the lower triangular part of the tensor, the other elements are set to 0.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([[1, 2, 3], [4, 5, 6]])
+    print(t.numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.tril().numpy())
+    ```
+    """
+    return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, device=self.device).where(0, self)
+
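`cumsum` above is built from a pad, `_pool` and sum (with a two-stage split for long axes), and `triu`/`tril` mask against the `_tri` comparison matrix; strictly-upper plus lower-including-diagonal covers every element exactly once. A quick sketch:

```python
from tinygrad import Tensor

t = Tensor.ones(3, 3)
print(t.cumsum(axis=1).numpy())          # running sums along each row
print((t.triu(k=1) + t.tril()).numpy())  # recombines to the original matrix
```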
1772
|
+
# ***** unary ops *****
|
1773
|
+
|
1774
|
+
def logical_not(self):
|
1775
|
+
"""
|
1776
|
+
Computes the logical NOT of the tensor element-wise.
|
1777
|
+
|
1778
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1779
|
+
print(Tensor([False, True]).logical_not().numpy())
|
1780
|
+
```
|
1781
|
+
"""
|
1782
|
+
return F.Eq.apply(*self._broadcasted(False))
|
1783
|
+
def neg(self):
|
1784
|
+
"""
|
1785
|
+
Negates the tensor element-wise.
|
1786
|
+
|
1787
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1788
|
+
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).neg().numpy())
|
1789
|
+
```
|
1790
|
+
"""
|
1791
|
+
return F.Neg.apply(self) if self.dtype != dtypes.bool else self.logical_not()
|
1792
|
+
def contiguous(self):
|
1793
|
+
"""
|
1794
|
+
Returns a contiguous tensor.
|
1795
|
+
"""
|
1796
|
+
return F.Contiguous.apply(self)
|
1797
|
+
def contiguous_backward(self):
|
1798
|
+
"""
|
1799
|
+
Inserts a contiguous operation in the backward pass.
|
1800
|
+
"""
|
1801
|
+
return F.ContiguousBackward.apply(self)
|
1802
|
+
def log(self):
|
1803
|
+
"""
|
1804
|
+
Computes the natural logarithm element-wise.
|
1805
|
+
|
1806
|
+
See: https://en.wikipedia.org/wiki/Logarithm
|
1807
|
+
|
1808
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1809
|
+
print(Tensor([1., 2., 4., 8.]).log().numpy())
|
1810
|
+
```
|
1811
|
+
"""
|
1812
|
+
return F.Log.apply(self.cast(least_upper_float(self.dtype)))
|
1813
|
+
def log2(self):
|
1814
|
+
"""
|
1815
|
+
Computes the base-2 logarithm element-wise.
|
1816
|
+
|
1817
|
+
See: https://en.wikipedia.org/wiki/Logarithm
|
1818
|
+
|
1819
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1820
|
+
print(Tensor([1., 2., 4., 8.]).log2().numpy())
|
1821
|
+
```
|
1822
|
+
"""
|
1823
|
+
return self.log()/math.log(2)
|
1824
|
+
def exp(self):
|
1825
|
+
"""
|
1826
|
+
Computes the exponential function element-wise.
|
1827
|
+
|
1828
|
+
See: https://en.wikipedia.org/wiki/Exponential_function
|
1829
|
+
|
1830
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1831
|
+
print(Tensor([0., 1., 2., 3.]).exp().numpy())
|
1832
|
+
```
|
1833
|
+
"""
|
1834
|
+
return F.Exp.apply(self.cast(least_upper_float(self.dtype)))
|
1835
|
+
def exp2(self):
|
1836
|
+
"""
|
1837
|
+
Computes the base-2 exponential function element-wise.
|
1838
|
+
|
1839
|
+
See: https://en.wikipedia.org/wiki/Exponential_function
|
1840
|
+
|
1841
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1842
|
+
print(Tensor([0., 1., 2., 3.]).exp2().numpy())
|
1843
|
+
```
|
1844
|
+
"""
|
1845
|
+
return F.Exp.apply(self*math.log(2))
|
1846
|
+
def relu(self):
|
1847
|
+
"""
|
1848
|
+
Applies the Rectified Linear Unit (ReLU) function element-wise.
|
1849
|
+
|
1850
|
+
- Described: https://paperswithcode.com/method/relu
|
1851
|
+
|
1852
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1853
|
+
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).relu().numpy())
|
1854
|
+
```
|
1855
|
+
"""
|
1856
|
+
return F.Relu.apply(self)
|
1857
|
+
def sigmoid(self):
|
1858
|
+
"""
|
1859
|
+
Applies the Sigmoid function element-wise.
|
1860
|
+
|
1861
|
+
- Described: https://en.wikipedia.org/wiki/Sigmoid_function
|
1862
|
+
|
1863
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1864
|
+
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sigmoid().numpy())
|
1865
|
+
```
|
1866
|
+
"""
|
1867
|
+
return F.Sigmoid.apply(self.cast(least_upper_float(self.dtype)))
|
1868
|
+
def sqrt(self):
|
1869
|
+
"""
|
1870
|
+
Computes the square root of the tensor element-wise.
|
1871
|
+
|
1872
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1873
|
+
print(Tensor([1., 2., 3., 4.]).sqrt().numpy())
|
1874
|
+
```
|
1875
|
+
"""
|
1876
|
+
return F.Sqrt.apply(self.cast(least_upper_float(self.dtype)))
|
1877
|
+
def rsqrt(self):
|
1878
|
+
"""
|
1879
|
+
Computes the reciprocal of the square root of the tensor element-wise.
|
1880
|
+
|
1881
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1882
|
+
print(Tensor([1., 2., 3., 4.]).rsqrt().numpy())
|
1883
|
+
```
|
1884
|
+
"""
|
1885
|
+
return self.reciprocal().sqrt()
|
1886
|
+
def sin(self):
|
1887
|
+
"""
|
1888
|
+
Computes the sine of the tensor element-wise.
|
1889
|
+
|
1890
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1891
|
+
print(Tensor([0., math.pi/2, math.pi, 3*math.pi/2, 2*math.pi]).sin().numpy())
|
1892
|
+
```
|
1893
|
+
"""
|
1894
|
+
return F.Sin.apply(self.cast(least_upper_float(self.dtype)))
|
1895
|
+
def cos(self):
|
1896
|
+
"""
|
1897
|
+
Computes the cosine of the tensor element-wise.
|
1898
|
+
|
1899
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1900
|
+
print(Tensor([0., math.pi/2, math.pi, 3*math.pi/2, 2*math.pi]).cos().numpy())
|
1901
|
+
```
|
1902
|
+
"""
|
1903
|
+
return ((math.pi/2)-self).sin()
|
1904
|
+
def tan(self):
|
1905
|
+
"""
|
1906
|
+
Computes the tangent of the tensor element-wise.
|
1907
|
+
|
1908
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1909
|
+
print(Tensor([0., math.pi/4, math.pi/2, 3*math.pi/4, math.pi]).tan().numpy())
|
1910
|
+
```
|
1911
|
+
"""
|
1912
|
+
return self.sin() / self.cos()
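In this version `cos` and `tan` are both derived from `sin`: cos(x) = sin(π/2 − x) and tan(x) = sin(x)/cos(x). A minimal numerical cross-check against Python's `math` module (a sketch; it assumes tinygrad 0.9.0 is importable the same way as in the docstring examples above):

```python
import math
from tinygrad import Tensor  # assumption: tinygrad 0.9.0 installed

xs = [0.3, 1.0, 2.5]
t = Tensor(xs)
# cos comes from the shift identity, tan from the ratio sin/cos
for x, c, tn in zip(xs, t.cos().numpy(), t.tan().numpy()):
  print(f"cos {c:+.5f} vs {math.cos(x):+.5f}   tan {tn:+.5f} vs {math.tan(x):+.5f}")
```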
|
1913
|
+
|
1914
|
+
# ***** math functions *****
|
1915
|
+
|
1916
|
+
def trunc(self: Tensor) -> Tensor:
|
1917
|
+
"""
|
1918
|
+
Truncates the tensor element-wise.
|
1919
|
+
|
1920
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1921
|
+
print(Tensor([-3.9, -2.1, -1.5, 0.5, 1.5, 2.1, 3.9]).trunc().numpy())
|
1922
|
+
```
|
1923
|
+
"""
|
1924
|
+
return self.cast(dtypes.int32).cast(self.dtype)
|
1925
|
+
def ceil(self: Tensor) -> Tensor:
|
1926
|
+
"""
|
1927
|
+
Rounds the tensor element-wise towards positive infinity.
|
1928
|
+
|
1929
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1930
|
+
print(Tensor([-3.9, -2.1, -1.5, 0.5, 1.5, 2.1, 3.9]).ceil().numpy())
|
1931
|
+
```
|
1932
|
+
"""
|
1933
|
+
return (self > (b := self.trunc())).where(b+1, b)
|
1934
|
+
def floor(self: Tensor) -> Tensor:
|
1935
|
+
"""
|
1936
|
+
Rounds the tensor element-wise towards negative infinity.
|
1937
|
+
|
1938
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1939
|
+
print(Tensor([-3.9, -2.1, -1.5, 0.5, 1.5, 2.1, 3.9]).floor().numpy())
|
1940
|
+
```
|
1941
|
+
"""
|
1942
|
+
return (self < (b := self.trunc())).where(b-1, b)
|
1943
|
+
def round(self: Tensor) -> Tensor:
|
1944
|
+
"""
|
1945
|
+
Rounds the tensor element-wise.
|
1946
|
+
|
1947
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1948
|
+
print(Tensor([-3.9, -2.1, -1.5, 0.5, 1.5, 2.1, 3.9]).round().numpy())
|
1949
|
+
```
|
1950
|
+
"""
|
1951
|
+
return ((self > 0) == ((b := self.cast(dtypes.int32) / 2.0).cast(dtypes.int32) == b)).where((self - 0.5).ceil(), (self + 0.5).floor())
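The condition `(self > 0) == (trunc(self) is even)` picks between `ceil(x - 0.5)` and `floor(x + 0.5)`; worked through case by case, this appears to implement round-half-to-even, the same tie-breaking rule as Python's built-in `round`. A quick side-by-side comparison (a sketch that only prints; it assumes tinygrad 0.9.0 is installed):

```python
from tinygrad import Tensor  # assumption: tinygrad 0.9.0 installed

vals = [-2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.7]
print(Tensor(vals).round().numpy())      # tensor rounding
print([float(round(v)) for v in vals])   # Python's round, which also rounds ties to even
```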
|
1952
|
+
def lerp(self, end: Tensor, weight: Union[Tensor, float]) -> Tensor:
|
1953
|
+
"""
|
1954
|
+
Linearly interpolates between `self` and `end` by `weight`.
|
1955
|
+
|
1956
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1957
|
+
print(Tensor([1., 2., 3.]).lerp(Tensor([4., 5., 6.]), 0.5).numpy())
|
1958
|
+
```
|
1959
|
+
"""
|
1960
|
+
return self + (end - self) * weight
|
1961
|
+
def square(self):
|
1962
|
+
"""
|
1963
|
+
Squares the tensor element-wise.
|
1964
|
+
Equivalent to `self*self`.
|
1965
|
+
|
1966
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1967
|
+
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).square().numpy())
|
1968
|
+
```
|
1969
|
+
"""
|
1970
|
+
return self*self
|
1971
|
+
def clip(self, min_, max_):
|
1972
|
+
"""
|
1973
|
+
Clips (clamps) the values in the tensor between `min_` and `max_` element-wise.
|
1974
|
+
|
1975
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1976
|
+
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).clip(-1, 1).numpy())
|
1977
|
+
```
|
1978
|
+
"""
|
1979
|
+
return self.maximum(min_).minimum(max_)
|
1980
|
+
def sign(self):
|
1981
|
+
"""
|
1982
|
+
Returns the sign of the tensor element-wise.
|
1983
|
+
|
1984
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1985
|
+
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sign().numpy())
|
1986
|
+
```
|
1987
|
+
"""
|
1988
|
+
return F.Sign.apply(self)
|
1989
|
+
def abs(self):
|
1990
|
+
"""
|
1991
|
+
Computes the absolute value of the tensor element-wise.
|
1992
|
+
|
1993
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1994
|
+
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).abs().numpy())
|
1995
|
+
```
|
1996
|
+
"""
|
1997
|
+
return self * self.sign()
|
1998
|
+
def reciprocal(self):
|
1999
|
+
"""
|
2000
|
+
Compute `1/x` element-wise.
|
2001
|
+
|
2002
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2003
|
+
print(Tensor([1., 2., 3., 4.]).reciprocal().numpy())
|
2004
|
+
```
|
2005
|
+
"""
|
2006
|
+
return F.Reciprocal.apply(self.cast(least_upper_float(self.dtype)))
|
2007
|
+
|
2008
|
+
# ***** activation functions *****
|
2009
|
+
|
2010
|
+
def elu(self, alpha=1.0):
|
2011
|
+
"""
|
2012
|
+
Applies the Exponential Linear Unit (ELU) function element-wise.
|
2013
|
+
|
2014
|
+
- Described: https://paperswithcode.com/method/elu
|
2015
|
+
- Paper: https://arxiv.org/abs/1511.07289v5
|
2016
|
+
|
2017
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2018
|
+
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).elu().numpy())
|
2019
|
+
```
|
2020
|
+
"""
|
2021
|
+
return self.relu() - alpha*(1-self.exp()).relu()
|
2022
|
+
|
2023
|
+
def celu(self, alpha=1.0):
|
2024
|
+
"""
|
2025
|
+
Applies the Continuously differentiable Exponential Linear Unit (CELU) function element-wise.
|
2026
|
+
|
2027
|
+
- Described: https://paperswithcode.com/method/celu
|
2028
|
+
- Paper: https://arxiv.org/abs/1704.07483
|
2029
|
+
|
2030
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2031
|
+
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).celu().numpy())
|
2032
|
+
```
|
2033
|
+
"""
|
2034
|
+
return self.maximum(0) + (alpha * ((self / alpha).exp() - 1)).minimum(0)
|
2035
|
+
|
2036
|
+
def swish(self):
|
2037
|
+
"""
|
2038
|
+
See `.silu()`
|
2039
|
+
|
2040
|
+
- Paper: https://arxiv.org/abs/1710.05941v1
|
2041
|
+
|
2042
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2043
|
+
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).swish().numpy())
|
2044
|
+
```
|
2045
|
+
"""
|
2046
|
+
return self * self.sigmoid()
|
2047
|
+
|
2048
|
+
def silu(self):
|
2049
|
+
"""
|
2050
|
+
Applies the Sigmoid Linear Unit (SiLU) function element-wise.
|
2051
|
+
|
2052
|
+
- Described: https://paperswithcode.com/method/silu
|
2053
|
+
- Paper: https://arxiv.org/abs/1606.08415
|
2054
|
+
|
2055
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2056
|
+
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).silu().numpy())
|
2057
|
+
```
|
2058
|
+
"""
|
2059
|
+
return self.swish() # The SiLU function is also known as the swish function.
|
2060
|
+
|
2061
|
+
def relu6(self):
|
2062
|
+
"""
|
2063
|
+
Applies the ReLU6 function element-wise.
|
2064
|
+
|
2065
|
+
- Described: https://paperswithcode.com/method/relu6
|
2066
|
+
- Paper: https://arxiv.org/abs/1704.04861v1
|
2067
|
+
|
2068
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2069
|
+
print(Tensor([-9., -6., -3., 0., 3., 6., 9.]).relu6().numpy())
|
2070
|
+
```
|
2071
|
+
"""
|
2072
|
+
return self.relu() - (self-6).relu()
|
2073
|
+
|
2074
|
+
def hardswish(self):
|
2075
|
+
"""
|
2076
|
+
Applies the Hardswish function element-wise.
|
2077
|
+
|
2078
|
+
- Described: https://paperswithcode.com/method/hard-swish
|
2079
|
+
- Paper: https://arxiv.org/abs/1905.02244v5
|
2080
|
+
|
2081
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2082
|
+
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).hardswish().numpy())
|
2083
|
+
```
|
2084
|
+
"""
|
2085
|
+
return self * (self+3).relu6() * (1/6)
|
2086
|
+
|
2087
|
+
def tanh(self):
|
2088
|
+
"""
|
2089
|
+
Applies the Hyperbolic Tangent (tanh) function element-wise.
|
2090
|
+
|
2091
|
+
- Described: https://en.wikipedia.org/wiki/Hyperbolic_functions#Tanh
|
2092
|
+
|
2093
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2094
|
+
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).tanh().numpy())
|
2095
|
+
```
|
2096
|
+
"""
|
2097
|
+
return 2.0 * ((2.0 * self).sigmoid()) - 1.0
|
2098
|
+
|
2099
|
+
def sinh(self):
|
2100
|
+
"""
|
2101
|
+
Applies the Hyperbolic Sine (sinh) function element-wise.
|
2102
|
+
|
2103
|
+
- Described: https://en.wikipedia.org/wiki/Hyperbolic_functions#Sinh
|
2104
|
+
|
2105
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2106
|
+
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sinh().numpy())
|
2107
|
+
```
|
2108
|
+
"""
|
2109
|
+
return (self.exp() - self.neg().exp()) / 2
|
2110
|
+
|
2111
|
+
def cosh(self):
|
2112
|
+
"""
|
2113
|
+
Applies the Hyperbolic Cosine (cosh) function element-wise.
|
2114
|
+
|
2115
|
+
- Described: https://en.wikipedia.org/wiki/Hyperbolic_functions#Cosh
|
2116
|
+
|
2117
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2118
|
+
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).cosh().numpy())
|
2119
|
+
```
|
2120
|
+
"""
|
2121
|
+
return (self.exp() + self.neg().exp()) / 2
|
2122
|
+
|
2123
|
+
def atanh(self):
|
2124
|
+
"""
|
2125
|
+
Applies the Inverse Hyperbolic Tangent (atanh) function element-wise.
|
2126
|
+
|
2127
|
+
- Described: https://en.wikipedia.org/wiki/Inverse_hyperbolic_functions#atanh
|
2128
|
+
|
2129
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2130
|
+
print(Tensor([-0.9, -0.6, -0.3, 0., 0.3, 0.6, 0.9]).atanh().numpy())
|
2131
|
+
```
|
2132
|
+
"""
|
2133
|
+
return ((1 + self)/(1 - self)).log() / 2
|
2134
|
+
|
2135
|
+
def asinh(self):
|
2136
|
+
"""
|
2137
|
+
Applies the Inverse Hyperbolic Sine (asinh) function element-wise.
|
2138
|
+
|
2139
|
+
- Described: https://en.wikipedia.org/wiki/Inverse_hyperbolic_functions#asinh
|
2140
|
+
|
2141
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2142
|
+
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).asinh().numpy())
|
2143
|
+
```
|
2144
|
+
"""
|
2145
|
+
return (self + (self.square() + 1).sqrt()).log()
|
2146
|
+
|
2147
|
+
def acosh(self):
|
2148
|
+
"""
|
2149
|
+
Applies the Inverse Hyperbolic Cosine (acosh) function element-wise.
|
2150
|
+
|
2151
|
+
- Described: https://en.wikipedia.org/wiki/Inverse_hyperbolic_functions#acosh
|
2152
|
+
|
2153
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2154
|
+
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).acosh().numpy())
|
2155
|
+
```
|
2156
|
+
"""
|
2157
|
+
return (self + (self.square() - 1).sqrt()).log()
|
2158
|
+
|
2159
|
+
def hardtanh(self, min_val=-1, max_val=1):
|
2160
|
+
"""
|
2161
|
+
Applies the Hardtanh function element-wise.
|
2162
|
+
|
2163
|
+
- Described: https://paperswithcode.com/method/hardtanh-activation
|
2164
|
+
|
2165
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2166
|
+
print(Tensor([-1.5, -1.0, -0.5, 0., 0.5, 1.0, 1.5]).hardtanh().numpy())
|
2167
|
+
```
|
2168
|
+
"""
|
2169
|
+
return self.clip(min_val, max_val)
|
2170
|
+
|
2171
|
+
def gelu(self):
|
2172
|
+
"""
|
2173
|
+
Applies the Gaussian Error Linear Unit (GELU) function element-wise.
|
2174
|
+
|
2175
|
+
- Described: https://paperswithcode.com/method/gelu
|
2176
|
+
- Paper: https://arxiv.org/abs/1606.08415v5
|
2177
|
+
|
2178
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2179
|
+
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).gelu().numpy())
|
2180
|
+
```
|
2181
|
+
"""
|
2182
|
+
return 0.5 * self * (1 + (self * 0.7978845608 * (1 + 0.044715 * self * self)).tanh())
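The constant `0.7978845608` is √(2/π), so this is the familiar tanh approximation of GELU rather than the exact erf-based form. A sketch comparing the two definitions numerically (the erf version is computed with `math.erf`; tinygrad 0.9.0 is assumed installed):

```python
import math
from tinygrad import Tensor  # assumption: tinygrad 0.9.0 installed

xs = [-2.0, -1.0, 0.0, 1.0, 2.0]
approx = Tensor(xs).gelu().numpy()
exact = [0.5 * x * (1 + math.erf(x / math.sqrt(2))) for x in xs]  # erf-based GELU
for a, e in zip(approx, exact): print(f"tanh approx {a:+.5f}   erf exact {e:+.5f}")
```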
|
2183
|
+
|
2184
|
+
def quick_gelu(self):
|
2185
|
+
"""
|
2186
|
+
Applies the Sigmoid GELU approximation element-wise.
|
2187
|
+
|
2188
|
+
- Described: https://paperswithcode.com/method/gelu
|
2189
|
+
|
2190
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2191
|
+
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).quick_gelu().numpy())
|
2192
|
+
```
|
2193
|
+
"""
|
2194
|
+
return self * (self * 1.702).sigmoid()
|
2195
|
+
|
2196
|
+
def leakyrelu(self, neg_slope=0.01):
|
2197
|
+
"""
|
2198
|
+
Applies the Leaky ReLU function element-wise.
|
2199
|
+
|
2200
|
+
- Described: https://paperswithcode.com/method/leaky-relu
|
2201
|
+
|
2202
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2203
|
+
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).leakyrelu().numpy())
|
2204
|
+
```
|
2205
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2206
|
+
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).leakyrelu(neg_slope=0.42).numpy())
|
2207
|
+
```
|
2208
|
+
"""
|
2209
|
+
return self.relu() - (-neg_slope*self).relu()
|
2210
|
+
|
2211
|
+
def mish(self):
|
2212
|
+
"""
|
2213
|
+
Applies the Mish function element-wise.
|
2214
|
+
|
2215
|
+
- Described: https://paperswithcode.com/method/mish
|
2216
|
+
- Paper: https://arxiv.org/abs/1908.08681v3
|
2217
|
+
|
2218
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2219
|
+
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).mish().numpy())
|
2220
|
+
```
|
2221
|
+
"""
|
2222
|
+
return self * self.softplus().tanh()
|
2223
|
+
|
2224
|
+
def softplus(self, beta=1):
|
2225
|
+
"""
|
2226
|
+
Applies the Softplus function element-wise.
|
2227
|
+
|
2228
|
+
- Described: https://paperswithcode.com/method/softplus
|
2229
|
+
|
2230
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2231
|
+
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).softplus().numpy())
|
2232
|
+
```
|
2233
|
+
"""
|
2234
|
+
return (1/beta) * (1 + (self*beta).exp()).log()
|
2235
|
+
|
2236
|
+
def softsign(self):
|
2237
|
+
"""
|
2238
|
+
Applies the Softsign function element-wise.
|
2239
|
+
|
2240
|
+
- Described: https://paperswithcode.com/method/softsign
|
2241
|
+
|
2242
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2243
|
+
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).softsign().numpy())
|
2244
|
+
```
|
2245
|
+
"""
|
2246
|
+
return self / (1 + self.abs())
+
+ # ***** broadcasted elementwise ops *****
+ def _broadcast_to(self, shape:Tuple[sint, ...]):
+ reshape_arg, _ = _pad_left(self.shape, shape)
+ if self.ndim > len(shape) or not all(sh in {s,1} or (s==0 and sh==1) for sh,s in zip(reshape_arg, shape)):
+ raise ValueError(f"cannot broadcast tensor with shape={self.shape} to {shape=}")
+ return F.Expand.apply(self.reshape(reshape_arg), shape=shape) if shape != self.shape else self
+
+ def _broadcasted(self, y:Union[Tensor, ConstType], reverse:bool=False, match_dtype:bool=True) -> Tuple[Tensor, Tensor]:
  x: Tensor = self
  if not isinstance(y, Tensor):
  # make y a Tensor
-
+ assert isinstance(y, (float, int, bool, Node)), f"{type(y)=}, {y=}"
  if isinstance(self.dtype, ImageDType) or dtypes.is_float(x.dtype) or (dtypes.is_int(x.dtype) and isinstance(y, int)): y_dtype = x.dtype
  else: y_dtype = dtypes.from_py(y)
- y = Tensor(y, self.device
+ if isinstance(y, Node): y = Tensor.from_node(y, device=self.device)
+ else: y = Tensor(dtypes.as_const(y, y_dtype), self.device, y_dtype, requires_grad=False)

  if match_dtype:
  output_dtype = least_upper_dtype(x.dtype, y.dtype)
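`_pad_left` and `_broadcast_shape` are internal helpers that do not appear in this excerpt; the rule they implement is the usual NumPy one: right-align the shapes, pad the shorter ones with 1s, and require each aligned pair of sizes to match or be 1. A standalone sketch of that rule (the helper name `broadcast_shape` is made up for illustration, not the library's own):

```python
from typing import Tuple

def broadcast_shape(*shapes: Tuple[int, ...]) -> Tuple[int, ...]:
  # NumPy-style rule: right-align the shapes, pad with 1s on the left,
  # then every column must agree or contain a 1; the result takes the larger size.
  ndim = max(len(s) for s in shapes)
  padded = [(1,) * (ndim - len(s)) + tuple(s) for s in shapes]
  out = []
  for dims in zip(*padded):
    sizes = {d for d in dims if d != 1}
    if len(sizes) > 1: raise ValueError(f"cannot broadcast {shapes}")
    out.append(sizes.pop() if sizes else 1)
  return tuple(out)

print(broadcast_shape((4, 1), (3,)))       # (4, 3)
print(broadcast_shape((2, 1, 5), (3, 1)))  # (2, 3, 5)
```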
@@ -779,63 +2268,231 @@ class Tensor:

  if reverse: x, y = y, x

- #
-
-
+ # broadcast
+ out_shape = _broadcast_shape(x.shape, y.shape)
+ return x._broadcast_to(out_shape), y._broadcast_to(out_shape)

-
- return x.expand(broadcasted_shape), y.expand(broadcasted_shape)
-
- def _to_const_val(self, x:Union[Tensor, Scalar]) -> Union[Tensor, Scalar]:
+ def _to_const_val(self, x:Union[Tensor, ConstType]) -> Union[Tensor, ConstType]:
  # TODO: update with multi
- return x.lazydata.base.arg if isinstance(x, Tensor) and isinstance(x.lazydata, LazyBuffer) and x.lazydata.
+ return x.lazydata.base.arg if isinstance(x, Tensor) and isinstance(x.lazydata, LazyBuffer) and x.lazydata.is_unrealized_unmasked_const() \
  and not x.requires_grad and self._broadcasted(x)[0].shape == self.shape else x

- def add(self, x:Union[Tensor,
-
-
-
-
-
-
-
-
-
-
-
-
-
+ def add(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
|
2281
|
+
"""
|
2282
|
+
Adds `self` and `x`.
|
2283
|
+
Equivalent to `self + x`.
|
2284
|
+
Supports broadcasting to a common shape, type promotion, and integer, float, boolean inputs.
|
2285
|
+
|
2286
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2287
|
+
Tensor.manual_seed(42)
|
2288
|
+
t = Tensor.randn(4)
|
2289
|
+
print(t.numpy())
|
2290
|
+
```
|
2291
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2292
|
+
print(t.add(20).numpy())
|
2293
|
+
```
|
2294
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2295
|
+
print(t.add(Tensor([[2.0], [3.5]])).numpy())
|
2296
|
+
```
|
2297
|
+
"""
|
2298
|
+
return F.Add.apply(*self._broadcasted(x, reverse))
|
2299
|
+
|
2300
|
+
def sub(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
|
2301
|
+
"""
|
2302
|
+
Subtracts `x` from `self`.
|
2303
|
+
Equivalent to `self - x`.
|
2304
|
+
Supports broadcasting to a common shape, type promotion, and integer, float, boolean inputs.
|
2305
|
+
|
2306
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2307
|
+
Tensor.manual_seed(42)
|
2308
|
+
t = Tensor.randn(4)
|
2309
|
+
print(t.numpy())
|
2310
|
+
```
|
2311
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2312
|
+
print(t.sub(20).numpy())
|
2313
|
+
```
|
2314
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2315
|
+
print(t.sub(Tensor([[2.0], [3.5]])).numpy())
|
2316
|
+
```
|
2317
|
+
"""
|
2318
|
+
return F.Sub.apply(*self._broadcasted(x, reverse))
|
2319
|
+
|
2320
|
+
def mul(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
|
2321
|
+
"""
|
2322
|
+
Multiplies `self` and `x`.
|
2323
|
+
Equivalent to `self * x`.
|
2324
|
+
Supports broadcasting to a common shape, type promotion, and integer, float, boolean inputs.
|
2325
|
+
|
2326
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2327
|
+
Tensor.manual_seed(42)
|
2328
|
+
t = Tensor.randn(4)
|
2329
|
+
print(t.numpy())
|
2330
|
+
```
|
2331
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2332
|
+
print(t.mul(3).numpy())
|
2333
|
+
```
|
2334
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2335
|
+
print(t.mul(Tensor([[-1.0], [2.0]])).numpy())
|
2336
|
+
```
|
2337
|
+
"""
|
2338
|
+
return F.Mul.apply(*self._broadcasted(x, reverse))
|
2339
|
+
|
2340
|
+
def div(self, x:Union[Tensor, ConstType], reverse=False, upcast=True) -> Tensor:
|
2341
|
+
"""
|
2342
|
+
Divides `self` by `x`.
|
2343
|
+
Equivalent to `self / x`.
|
2344
|
+
Supports broadcasting to a common shape, type promotion, and integer, float, boolean inputs.
|
2345
|
+
By default, `div` performs true division. Set `upcast` to `False` for integer division.
|
2346
|
+
|
2347
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2348
|
+
Tensor.manual_seed(42)
|
2349
|
+
t = Tensor.randn(4)
|
2350
|
+
print(t.numpy())
|
2351
|
+
```
|
2352
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2353
|
+
print(t.div(3).numpy())
|
2354
|
+
```
|
2355
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2356
|
+
print(Tensor([1, 4, 10]).div(Tensor([2, 3, 4])).numpy())
|
2357
|
+
```
|
2358
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2359
|
+
print(Tensor([1, 4, 10]).div(Tensor([2, 3, 4]), upcast=False).numpy())
|
2360
|
+
```
|
2361
|
+
"""
|
2362
|
+
numerator, denominator = self._broadcasted(x, reverse)
|
2363
|
+
if upcast: numerator, denominator = numerator.cast(least_upper_float(numerator.dtype)), denominator.cast(least_upper_float(denominator.dtype))
|
2364
|
+
return F.Div.apply(numerator, denominator)
|
2365
|
+
|
2366
|
+
def xor(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
|
2367
|
+
"""
|
2368
|
+
Computes bitwise xor of `self` and `x`.
|
2369
|
+
Equivalent to `self ^ x`.
|
2370
|
+
Supports broadcasting to a common shape, type promotion, and integer, boolean inputs.
|
2371
|
+
|
2372
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2373
|
+
print(Tensor([-1, -2, 3]).xor(Tensor([1, 0, 3])).numpy())
|
2374
|
+
```
|
2375
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2376
|
+
print(Tensor([True, True, False, False]).xor(Tensor([True, False, True, False])).numpy())
|
2377
|
+
```
|
2378
|
+
"""
|
2379
|
+
return F.Xor.apply(*self._broadcasted(x, reverse))
|
2380
|
+
|
2381
|
+
def lshift(self, x:int):
|
2382
|
+
"""
|
2383
|
+
Computes left arithmetic shift of `self` by `x` bits. `self` must have unsigned dtype.
|
2384
|
+
Equivalent to `self << x`.
|
2385
|
+
|
2386
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2387
|
+
print(Tensor([1, 3, 31], dtype=dtypes.uint8).lshift(2).numpy())
|
2388
|
+
```
|
2389
|
+
"""
|
2390
|
+
assert dtypes.is_unsigned(self.dtype) and isinstance(x, int) and x >= 0, f"not supported {self.dtype=} {x=}"
|
2391
|
+
return self.mul(2 ** x)
|
2392
|
+
|
2393
|
+
def rshift(self, x:int):
|
2394
|
+
"""
|
2395
|
+
Computes right arithmetic shift of `self` by `x` bits. `self` must have unsigned dtype.
|
2396
|
+
Equivalent to `self >> x`.
|
2397
|
+
|
2398
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2399
|
+
print(Tensor([4, 13, 125], dtype=dtypes.uint8).rshift(2).numpy())
|
2400
|
+
```
|
2401
|
+
"""
|
2402
|
+
assert dtypes.is_unsigned(self.dtype) and isinstance(x, int) and x >= 0, f"not supported {self.dtype=} {x=}"
|
2403
|
+
return self.div(2 ** x, upcast=False)
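Both shifts are lowered to ordinary arithmetic: `x << k` becomes `x * 2**k` and `x >> k` becomes integer division by `2**k` (`upcast=False` keeps the unsigned dtype). A sketch reusing the uint8 values from the docstrings (the `tinygrad.dtype` import path is the 0.9.0 layout and is an assumption):

```python
from tinygrad import Tensor            # assumption: tinygrad 0.9.0 installed
from tinygrad.dtype import dtypes      # assumption: dtypes lives in tinygrad/dtype.py

t = Tensor([4, 13, 125], dtype=dtypes.uint8)
print((t >> 2).numpy())                     # [ 1  3 31]
print(t.div(2 ** 2, upcast=False).numpy())  # the same thing, spelled out
print((t << 1).numpy())                     # [  8  26 250]
```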
|
2404
|
+
|
2405
|
+
def pow(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
|
2406
|
+
"""
|
2407
|
+
Computes power of `self` with `x`.
|
2408
|
+
Equivalent to `self ** x`.
|
2409
|
+
|
2410
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2411
|
+
print(Tensor([-1, 2, 3]).pow(2).numpy())
|
2412
|
+
```
|
2413
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2414
|
+
print(Tensor([-1, 2, 3]).pow(Tensor([-1.5, 0.5, 1.5])).numpy())
|
2415
|
+
```
|
2416
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2417
|
+
print((2 ** Tensor([-1, 2, 3])).numpy())
|
2418
|
+
```
|
2419
|
+
"""
|
809
2420
|
x = self._to_const_val(x)
|
810
2421
|
if not isinstance(x, Tensor) and not reverse:
|
811
2422
|
# simple pow identities
|
812
2423
|
if x < 0: return self.reciprocal().pow(-x)
- if x
- if x
+ if x == 0: return 1 + self * 0
+ if int(x - 0.5) + 0.5 == x: return self.pow(int(x - 0.5)) * self.sqrt()
+ if int(x) == x: return self.pow(x // 2).square() * (1 if x % 2 == 0 else self)
+
+ # positive const ** self
  if not isinstance(x, Tensor) and reverse and x > 0: return self.mul(math.log(x)).exp()
-
-
-
-
-
-
-
-
-
- inject_nan = (
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+ base, exponent = self._broadcasted(x, reverse=reverse)
+ # start with b ** e = exp(e * log(b))
+ ret = base.abs().log().mul(exponent).exp()
+ # correct sign of negative base with odd exponent (cos has a period of 2pi so we use it here to get the oddness of the exponent)
+ negative_base = (base < 0).detach().where(1, 0)
+ # 1 for non-negative base or negative even exponent, -1 for negative odd exponent, don't care about non-integer exponent
+ correct_sign = 1 + negative_base * ((exponent * math.pi).cos() - 1)
+ # inject nan for negative base and non-integer exponent
+ inject_nan = (negative_base * (exponent != exponent.trunc())).detach().where(math.nan, 1)
+ # apply correct_sign inject_nan, and fix 0 ** 0 = 1
+ return ((base == 0) * (exponent == 0)).detach().where(1, ret * correct_sign * inject_nan)
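For tensor exponents the method computes |b|^e = exp(e·log|b|), flips the sign with cos(πe) when the base is negative and the exponent is an odd integer, injects NaN for a negative base with a fractional exponent, and pins 0**0 to 1. A sketch exercising those branches against plain Python (printed rather than asserted, since NaN never compares equal; tinygrad 0.9.0 assumed installed):

```python
import math
from tinygrad import Tensor  # assumption: tinygrad 0.9.0 installed

base = Tensor([-2.0, -2.0, -2.0, 0.0])
exps = Tensor([3.0, 2.0, 0.5, 0.0])
print(base.pow(exps).numpy())                      # expect roughly [-8.  4. nan  1.]
print([(-2.0) ** 3, (-2.0) ** 2, math.nan, 1.0])   # reference values
```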
|
2442
|
+
|
2443
|
+
def maximum(self, x:Union[Tensor, ConstType]) -> Tensor:
|
2444
|
+
"""
|
2445
|
+
Computes element-wise maximum of `self` and `x`.
|
2446
|
+
|
2447
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2448
|
+
print(Tensor([-1, 2, 3]).maximum(1).numpy())
|
2449
|
+
```
|
2450
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2451
|
+
print(Tensor([-1, 2, 3]).maximum(Tensor([-4, -2, 9])).numpy())
|
2452
|
+
```
|
2453
|
+
"""
|
2454
|
+
return (self<x).detach().where(x, (self==x).detach().where(((self * 0.5 + x * 0.5).cast(self.dtype)), self))
|
2455
|
+
|
2456
|
+
def minimum(self, x:Union[Tensor, ConstType]) -> Tensor:
|
2457
|
+
"""
|
2458
|
+
Computes element-wise minimum of `self` and `x`.
|
2459
|
+
|
2460
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2461
|
+
print(Tensor([-1, 2, 3]).minimum(1).numpy())
|
2462
|
+
```
|
2463
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2464
|
+
print(Tensor([-1, 2, 3]).minimum(Tensor([-4, -2, 9])).numpy())
|
2465
|
+
```
|
2466
|
+
"""
|
2467
|
+
return -((-self).maximum(-x))
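`minimum` reuses `maximum` through the identity min(a, b) = −max(−a, −b); a one-liner showing both spellings agree (sketch, tinygrad 0.9.0 assumed):

```python
from tinygrad import Tensor  # assumption: tinygrad 0.9.0 installed

a, b = Tensor([-1., 2., 3.]), Tensor([-4., -2., 9.])
print(a.minimum(b).numpy())            # [-4. -2.  3.]
print((-((-a).maximum(-b))).numpy())   # same result via the identity
```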
|
2468
|
+
|
2469
|
+
def where(self:Tensor, x:Union[Tensor, ConstType], y:Union[Tensor, ConstType]):
|
2470
|
+
"""
|
2471
|
+
Return a tensor of elements selected from either `x` or `y`, depending on `self`.
|
2472
|
+
`output_i = x_i if self_i else y_i`.
|
2473
|
+
|
2474
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2475
|
+
cond = Tensor([[True, True, False], [True, False, False]])
|
2476
|
+
print(cond.where(1, 3).numpy())
|
2477
|
+
```
|
2478
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2479
|
+
Tensor.manual_seed(42)
|
2480
|
+
cond = Tensor.randn(2, 3)
|
2481
|
+
print(cond.numpy())
|
2482
|
+
```
|
2483
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2484
|
+
print((cond > 0).where(cond, -float("inf")).numpy())
|
2485
|
+
```
|
2486
|
+
"""
|
2487
|
+
if isinstance(x, Tensor): x, y = x._broadcasted(y)
|
2488
|
+
elif isinstance(y, Tensor): y, x = y._broadcasted(x)
|
2489
|
+
cond, x = self._broadcasted(x, match_dtype=False)
|
2490
|
+
cond, y = cond._broadcasted(y, match_dtype=False)
|
2491
|
+
return F.Where.apply(cond.cast(dtypes.bool), *x._broadcasted(y))
|
2492
|
+
|
2493
|
+
def masked_fill(self:Tensor, mask:Tensor, value:Union[Tensor, ConstType]): return mask.where(value, self)
|
2494
|
+
|
2495
|
+
# ***** op wrappers *****
|
839
2496
|
|
840
2497
|
def __neg__(self) -> Tensor: return self.neg()
|
841
2498
|
|
@@ -846,6 +2503,8 @@ class Tensor:
|
|
846
2503
|
def __truediv__(self, x) -> Tensor: return self.div(x)
|
847
2504
|
def __matmul__(self, x) -> Tensor: return self.matmul(x)
|
848
2505
|
def __xor__(self, x) -> Tensor: return self.xor(x)
|
2506
|
+
def __lshift__(self, x) -> Tensor: return self.lshift(x)
|
2507
|
+
def __rshift__(self, x) -> Tensor: return self.rshift(x)
|
849
2508
|
|
850
2509
|
def __radd__(self, x) -> Tensor: return self.add(x, True)
|
851
2510
|
def __rsub__(self, x) -> Tensor: return self.sub(x, True)
|
@@ -862,38 +2521,134 @@ class Tensor:
|
|
862
2521
|
def __itruediv__(self, x) -> Tensor: return self.assign(self.div(x))
|
863
2522
|
def __imatmul__(self, x) -> Tensor: return self.assign(self.matmul(x))
|
864
2523
|
def __ixor__(self, x) -> Tensor: return self.assign(self.xor(x))
|
2524
|
+
def __ilshift__(self, x) -> Tensor: return self.assign(self.lshift(x))
|
2525
|
+
def __irshift__(self, x) -> Tensor: return self.assign(self.rshift(x))
|
865
2526
|
- def __lt__(self, x) -> Tensor: return
- def __gt__(self, x) -> Tensor: return
+ def __lt__(self, x) -> Tensor: return F.Less.apply(*self._broadcasted(x, False))
+ def __gt__(self, x) -> Tensor: return F.Less.apply(*self._broadcasted(x, True))
  def __ge__(self, x) -> Tensor: return (self<x).logical_not()
  def __le__(self, x) -> Tensor: return (self>x).logical_not()
- def __eq__(self, x) -> Tensor: return
+ def __eq__(self, x) -> Tensor: return F.Eq.apply(*self._broadcasted(x, True)) # type: ignore[override]
|
871
2532
|
def __ne__(self, x) -> Tensor: return (self==x).logical_not() # type: ignore[override]
|
872
2533
|
|
873
2534
|
# ***** functional nn ops *****
|
874
2535
|
|
875
2536
|
def linear(self, weight:Tensor, bias:Optional[Tensor]=None):
|
2537
|
+
"""
|
2538
|
+
Applies a linear transformation to `self` using `weight` and `bias`.
|
2539
|
+
|
2540
|
+
See: https://pytorch.org/docs/stable/generated/torch.nn.Linear.html
|
2541
|
+
|
2542
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2543
|
+
t = Tensor([[1, 2], [3, 4]])
|
2544
|
+
weight = Tensor([[1, 2], [3, 4]])
|
2545
|
+
bias = Tensor([1, 2])
|
2546
|
+
print(t.linear(weight, bias).numpy())
|
2547
|
+
```
|
2548
|
+
"""
|
876
2549
|
x = self.mul(weight) if len(weight.shape) == 1 else self.dot(weight)
|
877
2550
|
return x.add(bias) if bias is not None else x
|
878
2551
|
- def sequential(self, ll:List[Callable[[Tensor], Tensor]]):
+ def sequential(self, ll:List[Callable[[Tensor], Tensor]]):
|
2553
|
+
"""
|
2554
|
+
Applies a sequence of functions to `self` chaining the output of each function to the input of the next.
|
2555
|
+
|
2556
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2557
|
+
t = Tensor([1, 2, 3])
|
2558
|
+
print(t.sequential([lambda x: x * 2, lambda x: x + 1]).numpy())
|
2559
|
+
```
|
2560
|
+
"""
|
2561
|
+
return functools.reduce(lambda x,f: f(x), ll, self)
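`sequential` is a left fold (`functools.reduce`) over the list of callables, feeding each output into the next; the explicit loop below is the same chaining spelled out (sketch, tinygrad 0.9.0 assumed):

```python
from tinygrad import Tensor  # assumption: tinygrad 0.9.0 installed

layers = [lambda x: x * 2, lambda x: x + 1]
t = Tensor([1, 2, 3])
out = t.sequential(layers)
ref = t
for f in layers: ref = f(ref)    # equivalent explicit chaining
print(out.numpy(), ref.numpy())  # both [3 5 7]
```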
|
880
2562
|
|
881
2563
|
def layernorm(self, axis=-1, eps:float=1e-5) -> Tensor:
|
2564
|
+
"""
|
2565
|
+
Applies Layer Normalization over a mini-batch of inputs.
|
2566
|
+
|
2567
|
+
- Described: https://paperswithcode.com/method/layer-normalization
|
2568
|
+
- Paper: https://arxiv.org/abs/1607.06450v1
|
2569
|
+
|
2570
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2571
|
+
t = Tensor.randn(8, 10, 16) * 2 + 8
|
2572
|
+
print(t.mean().item(), t.std().item())
|
2573
|
+
```
|
2574
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2575
|
+
t = t.layernorm()
|
2576
|
+
print(t.mean().item(), t.std().item())
|
2577
|
+
```
|
2578
|
+
"""
|
882
2579
|
y = (self - self.mean(axis, keepdim=True))
|
883
2580
|
return y.mul((y*y).mean(axis, keepdim=True).add(eps).rsqrt())
|
884
2581
|
- def batchnorm(self, weight:Optional[Tensor], bias:Optional[Tensor], mean:Tensor, invstd:Tensor) -> Tensor:
-
-
-
-
+ def batchnorm(self, weight:Optional[Tensor], bias:Optional[Tensor], mean:Tensor, invstd:Tensor, axis:Union[int,Tuple[int,...]]=1) -> Tensor:
|
2583
|
+
"""
|
2584
|
+
Applies Batch Normalization over a mini-batch of inputs.
|
2585
|
+
|
2586
|
+
- Described: https://paperswithcode.com/method/batch-normalization
|
2587
|
+
- Paper: https://arxiv.org/abs/1502.03167
|
2588
|
+
|
2589
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2590
|
+
t = Tensor.randn(8, 4, 16, 16) * 2 + 8
|
2591
|
+
print(t.mean().item(), t.std().item())
|
2592
|
+
```
|
2593
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2594
|
+
t = t.batchnorm(None, None, t.mean(axis=(0,2,3)), t.var(axis=(0,2,3)).add(1e-5).rsqrt())
|
2595
|
+
print(t.mean().item(), t.std().item())
|
2596
|
+
```
|
2597
|
+
"""
|
2598
|
+
axis_ = argfix(axis)
|
2599
|
+
shape = tuple(s if ax in axis_ else 1 for ax, s in enumerate(self.shape))
|
2600
|
+
x = self - mean.reshape(shape)
|
2601
|
+
if weight is not None: x = x * weight.reshape(shape)
|
2602
|
+
ret = x.mul(invstd.reshape(shape) if len(invstd.shape) == len(axis_) else invstd)
|
2603
|
+
return (ret + bias.reshape(shape)) if bias is not None else ret
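The new `axis` parameter (default 1, the channel axis of an NCHW input) controls which dimensions keep their size when the running statistics are reshaped; every other dimension is collapsed to 1 so `mean` and `invstd` broadcast over batch and spatial dims. A simplified pure-Python sketch of that reshape, using the shapes from the docstring example (tuple-only `axis`, unlike the `argfix` normalization in the method):

```python
def stat_shape(input_shape, axis=(1,)):
  # keep the size along the normalized axes, collapse everything else to 1
  return tuple(s if ax in axis else 1 for ax, s in enumerate(input_shape))

print(stat_shape((8, 4, 16, 16)))          # (1, 4, 1, 1): per-channel stats for NCHW
print(stat_shape((8, 10, 16), axis=(2,)))  # (1, 1, 16)
```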
|
890
2604
|
|
891
2605
|
def dropout(self, p=0.5) -> Tensor:
|
892
|
-
|
893
|
-
|
2606
|
+
"""
|
2607
|
+
Applies dropout to `self`.
|
2608
|
+
|
2609
|
+
NOTE: dropout is only applied when `Tensor.training` is `True`.
|
2610
|
+
|
2611
|
+
- Described: https://paperswithcode.com/method/dropout
|
2612
|
+
- Paper: https://jmlr.org/papers/v15/srivastava14a.html
|
894
2613
|
|
895
|
-
|
896
|
-
|
2614
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2615
|
+
Tensor.manual_seed(42)
|
2616
|
+
t = Tensor.randn(2, 2)
|
2617
|
+
with Tensor.train():
|
2618
|
+
print(t.dropout().numpy())
|
2619
|
+
```
|
2620
|
+
"""
|
2621
|
+
if not Tensor.training or p == 0: return self
|
2622
|
+
return self * (Tensor.rand(*self.shape, requires_grad=False, dtype=dtypes.default_float, device=self.device) >= p) * (1/(1.0 - p))
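This is inverted dropout: surviving activations are scaled by 1/(1 − p) at training time so their expected value matches evaluation, where dropout is a no-op. A sketch on an all-ones input, where every kept entry should come out as exactly 1/(1 − p) (tinygrad 0.9.0 assumed installed):

```python
from tinygrad import Tensor  # assumption: tinygrad 0.9.0 installed

Tensor.manual_seed(0)
t = Tensor.ones(4, 4)
with Tensor.train():            # dropout only fires while training
  d = t.dropout(0.5).numpy()
print(d)                        # entries are either 0.0 or 2.0 (= 1/(1-0.5))
```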
|
2623
|
+
|
2624
|
+
def one_hot(self, num_classes:int) -> Tensor:
|
2625
|
+
"""
|
2626
|
+
Converts `self` to a one-hot tensor.
|
2627
|
+
|
2628
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2629
|
+
t = Tensor([0, 1, 3, 3, 4])
|
2630
|
+
print(t.one_hot(5).numpy())
|
2631
|
+
```
|
2632
|
+
"""
|
2633
|
+
return (self[..., None] == Tensor.arange(num_classes, requires_grad=False, device=self.device)).where(1, 0)
|
2634
|
+
|
2635
|
+
def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Optional[Tensor]=None,
|
2636
|
+
dropout_p:float=0.0, is_causal:bool=False) -> Tensor:
|
2637
|
+
"""
|
2638
|
+
Computes scaled dot-product attention.
|
2639
|
+
`self` is the query tensor, `key` is the key tensor, and `value` is the value tensor.
|
2640
|
+
|
2641
|
+
- Described: https://paperswithcode.com/method/scaled
|
2642
|
+
- Paper: https://arxiv.org/abs/1706.03762v7
|
2643
|
+
|
2644
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2645
|
+
q = Tensor.randn(2, 4, 8)
|
2646
|
+
k = Tensor.randn(2, 4, 8)
|
2647
|
+
v = Tensor.randn(2, 4, 8)
|
2648
|
+
print(q.scaled_dot_product_attention(k, v).numpy())
|
2649
|
+
```
|
2650
|
+
"""
|
2651
|
+
# NOTE: it also works when `key` and `value` have symbolic shape.
|
897
2652
|
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
|
898
2653
|
if is_causal: attn_mask = Tensor.ones(self.shape[-2], key.shape[-2], requires_grad=False, device=self.device).tril(0).cast(dtypes.bool)
|
899
2654
|
if attn_mask is not None and attn_mask.dtype == dtypes.bool: attn_mask = (attn_mask == 0).where(-float("inf"), 0)
|
@@ -901,30 +2656,119 @@ class Tensor:
|
|
901
2656
|
return ((qk+attn_mask) if attn_mask is not None else qk).softmax(-1).dropout(dropout_p) @ value
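Putting the pieces together, the method computes softmax(QKᵀ/√d + mask)·V, with the causal mask built from a lower-triangular matrix of ones turned into 0/−inf. A cross-check that rebuilds the unmasked case from the public tensor ops shown above (a sketch for comparison, not the library's own code path; tinygrad 0.9.0 assumed):

```python
import math
from tinygrad import Tensor  # assumption: tinygrad 0.9.0 installed

Tensor.manual_seed(42)
q, k, v = Tensor.randn(2, 4, 8), Tensor.randn(2, 4, 8), Tensor.randn(2, 4, 8)
qk = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])   # scaled attention scores
out_manual = qk.softmax(-1) @ v
out_method = q.scaled_dot_product_attention(k, v)
print((out_manual - out_method).abs().max().numpy())    # should be ~0
```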
|
902
2657
|
|
903
2658
|
def binary_crossentropy(self, y:Tensor) -> Tensor:
|
2659
|
+
"""
|
2660
|
+
Computes the binary cross-entropy loss between `self` and `y`.
|
2661
|
+
|
2662
|
+
See: https://pytorch.org/docs/stable/generated/torch.nn.BCELoss.html
|
2663
|
+
|
2664
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2665
|
+
t = Tensor([0.1, 0.9, 0.2])
|
2666
|
+
y = Tensor([0, 1, 0])
|
2667
|
+
print(t.binary_crossentropy(y).item())
|
2668
|
+
```
|
2669
|
+
"""
|
904
2670
|
return (-y*self.log() - (1-y)*(1-self).log()).mean()
|
905
2671
|
|
906
2672
|
def binary_crossentropy_logits(self, y:Tensor) -> Tensor:
|
2673
|
+
"""
|
2674
|
+
Computes the binary cross-entropy loss between `self` and `y` where `self` is logits.
|
2675
|
+
|
2676
|
+
See: https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html
|
2677
|
+
|
2678
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2679
|
+
t = Tensor([-1, 2, -3])
|
2680
|
+
y = Tensor([0, 1, 0])
|
2681
|
+
print(t.binary_crossentropy_logits(y).item())
|
2682
|
+
```
|
2683
|
+
"""
|
907
2684
|
return (self.maximum(0) - y * self + (1 + self.abs().neg().exp()).log()).mean()
|
908
2685
|
- def sparse_categorical_crossentropy(self, Y:Tensor, ignore_index=-1) -> Tensor:
-
-
+ def sparse_categorical_crossentropy(self, Y:Tensor, ignore_index=-1, label_smoothing=0.0) -> Tensor:
+ """
+ Computes the sparse categorical cross-entropy loss between `self` and `Y`.
+
+ NOTE: `self` is logits and `Y` is the target labels.
+
+ See: https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
+
+ ```python exec="true" source="above" session="tensor" result="python"
+ t = Tensor([[-1, 2, -3], [1, -2, 3]])
+ Y = Tensor([1, 2])
+ print(t.sparse_categorical_crossentropy(Y).item())
+ ```
+ """
+ assert 0.0 <= label_smoothing <= 1.0, "label_smoothing must be in [0.0, 1.0]"
+ log_probs, loss_mask = self.log_softmax(), (Y != ignore_index)
  y_counter = Tensor.arange(self.shape[-1], requires_grad=False, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1])
- y = ((y_counter == Y.flatten().reshape(-1, 1))
-
+ y = ((y_counter == Y.flatten().reshape(-1, 1)) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
+ smoothing = label_smoothing * (log_probs.mean(-1) * loss_mask).sum()
+ return -((1 - label_smoothing) * (log_probs * y).sum() + smoothing) / loss_mask.sum()
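With label smoothing, the target distribution mixes (1 − ε) of the one-hot label with ε spread uniformly over the K classes, which is what the `smoothing` term implements via the mean of the log-probabilities. A pure-Python numeric sketch for a single example with ε = 0.1 (my own rearrangement of the formula above, not library code):

```python
import math

logits, target, eps = [-1.0, 2.0, -3.0], 1, 0.1
m = max(logits)
log_probs = [x - m - math.log(sum(math.exp(v - m) for v in logits)) for x in logits]
# cross-entropy against the smoothed target (1-eps)*onehot + eps/K*uniform
loss = -((1 - eps) * log_probs[target] + eps * sum(log_probs) / len(log_probs))
print(loss)
```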
|
915
2706
|
|
916
2707
|
# ***** cast ops *****
|
917
2708
|
- def
- if self.dtype == dtype: return self
+ def llvm_bf16_cast(self, dtype:DType):
  # hack for devices that don't support bfloat16
-
- return
+ assert self.dtype == dtypes.bfloat16
+ return self.to("LLVM").bitcast(dtypes.uint16).cast(dtypes.uint32).mul(1<<16).bitcast(dtypes.float32).cast(dtype)
|
2713
|
+
def cast(self, dtype:DType) -> Tensor:
|
2714
|
+
"""
|
2715
|
+
Casts `self` to the given `dtype`.
|
2716
|
+
|
2717
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2718
|
+
t = Tensor([-1, 2.5, 3], dtype=dtypes.float)
|
2719
|
+
print(t.dtype, t.numpy())
|
2720
|
+
```
|
2721
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2722
|
+
t = t.cast(dtypes.int32)
|
2723
|
+
print(t.dtype, t.numpy())
|
2724
|
+
```
|
2725
|
+
"""
|
2726
|
+
return self if self.dtype == dtype else F.Cast.apply(self, dtype=dtype)
|
923
2727
|
def bitcast(self, dtype:DType) -> Tensor:
|
924
|
-
|
925
|
-
|
926
|
-
|
927
|
-
|
2728
|
+
"""
|
2729
|
+
Bitcasts `self` to the given `dtype` of the same itemsize.
|
2730
|
+
|
2731
|
+
`self` must not require a gradient.
|
2732
|
+
|
2733
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2734
|
+
t = Tensor([-1, 2, 3], dtype=dtypes.int32)
|
2735
|
+
print(t.dtype, t.numpy())
|
2736
|
+
```
|
2737
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2738
|
+
t = t.bitcast(dtypes.uint32)
|
2739
|
+
print(t.dtype, t.numpy())
|
2740
|
+
```
|
2741
|
+
"""
|
2742
|
+
if self.requires_grad: raise RuntimeError("can't backprop through bitcast")
|
2743
|
+
return F.Cast.apply(self, dtype=dtype, bitcast=True) if self.dtype != dtype else self
|
2744
|
+
def float(self) -> Tensor:
|
2745
|
+
"""
|
2746
|
+
Convenience method to cast `self` to a `float32` Tensor.
|
2747
|
+
|
2748
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2749
|
+
t = Tensor([-1, 2, 3], dtype=dtypes.int32)
|
2750
|
+
print(t.dtype, t.numpy())
|
2751
|
+
```
|
2752
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2753
|
+
t = t.float()
|
2754
|
+
print(t.dtype, t.numpy())
|
2755
|
+
```
|
2756
|
+
"""
|
2757
|
+
return self.cast(dtypes.float32)
|
2758
|
+
def half(self) -> Tensor:
|
2759
|
+
"""
|
2760
|
+
Convenience method to cast `self` to a `float16` Tensor.
|
2761
|
+
|
2762
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2763
|
+
t = Tensor([-1, 2, 3], dtype=dtypes.int32)
|
2764
|
+
print(t.dtype, t.numpy())
|
2765
|
+
```
|
2766
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2767
|
+
t = t.half()
|
2768
|
+
print(t.dtype, t.numpy())
|
2769
|
+
```
|
2770
|
+
"""
|
2771
|
+
return self.cast(dtypes.float16)
|
928
2772
|
|
929
2773
|
# ***** convenience stuff *****
|
930
2774
|
|
@@ -934,20 +2778,101 @@ class Tensor:
|
|
934
2778
|
def element_size(self) -> int: return self.dtype.itemsize
|
935
2779
|
def nbytes(self) -> int: return self.numel() * self.element_size()
|
936
2780
|
def is_floating_point(self) -> bool: return dtypes.is_float(self.dtype)
|
2781
|
+
def size(self, dim=None) -> Union[sint, Tuple[sint, ...]]: return self.shape if dim is None else self.shape[dim]
|
2782
|
+
|
2783
|
+
# *** image Tensor function replacements ***
|
2784
|
+
|
2785
|
+
def image_dot(self, w:Tensor, acc_dtype=None):
|
2786
|
+
# NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
|
2787
|
+
n1, n2 = len(self.shape), len(w.shape)
|
2788
|
+
assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D"
|
2789
|
+
assert self.shape[-1] == w.shape[-min(n2, 2)], f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({self.shape[-1]} != {w.shape[-min(n2, 2)]})" # noqa: E501
|
2790
|
+
bs, groups, cin, cout = prod(self.shape[0:-2]), prod(w.shape[0:-2]), w.shape[-2], w.shape[-1]
|
2791
|
+
out_shape_t = self.shape[0:-2] + (cout,-1) if len(self.shape) > 1 else (cout, )
|
2792
|
+
|
2793
|
+
# NOTE: with NHWC we can remove the transposes
|
2794
|
+
# bs x groups*cin x H x W
|
2795
|
+
cx = self.transpose(self.ndim-1, self.ndim-2).reshape((bs//groups, groups*cin, -1, 1))
|
2796
|
+
# groups*cout x cin x H, W
|
2797
|
+
cw = w.transpose(w.ndim-1, w.ndim-2).reshape((groups*cout, cin, 1, 1))
|
2798
|
+
return cx.image_conv2d(cw, groups=groups, acc_dtype=acc_dtype).reshape(out_shape_t).transpose(self.ndim-1, self.ndim-2)
|
2799
|
+
|
2800
|
+
def image_conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, acc_dtype=None):
|
2801
|
+
base_image_type = dtypes.imageh if getenv("FLOAT16", 0) else dtypes.imagef
|
2802
|
+
|
2803
|
+
(bs,_,iy,ix), (cout,cin,H,W) = self.shape, weight.shape
|
2804
|
+
x, w = self, weight.reshape(groups, (rcout := cout//groups), cin, H, W)
|
2805
|
+
|
2806
|
+
# hack for non multiples of 4 on cin
|
2807
|
+
if cin % 4 != 0 and not (cin == 1 and groups%4 == 0):
|
2808
|
+
x = x.reshape(bs, groups, cin, iy, ix) # do this always?
|
2809
|
+
added_input_channels = 4 - (cin % 4)
|
2810
|
+
w = w.pad(tuple((0, added_input_channels) if i == 2 else None for i in range(w.ndim)))
|
2811
|
+
x = x.pad(tuple((0, added_input_channels) if i == 2 else None for i in range(x.ndim)))
|
2812
|
+
cin = cin + added_input_channels
|
2813
|
+
x = x.reshape(bs, groups*cin, iy, ix)
|
2814
|
+
|
2815
|
+
# hack for non multiples of 4 on rcout
|
2816
|
+
added_output_channels = 0
|
2817
|
+
if rcout % 4 != 0 and not (rcout == 1 and groups%4 == 0):
|
2818
|
+
added_output_channels = 4 - (rcout % 4)
|
2819
|
+
rcout += added_output_channels
|
2820
|
+
cout = groups * rcout
|
2821
|
+
w = w.pad(tuple((0, added_output_channels) if i == 1 else None for i in range(w.ndim)))
|
2822
|
+
|
2823
|
+
# packed (note: flipping bs and iy would make the auto-padding work)
|
2824
|
+
x = x.permute(0,2,3,1)
|
2825
|
+
cin_last = iy == 1 and ix == 1
|
2826
|
+
if cin == 1: w = w.reshape(cout//4,4,H,W).permute(0,2,3,1)
|
2827
|
+
elif cin_last: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,1,3)
|
2828
|
+
else: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,3,1)
|
2829
|
+
|
2830
|
+
# contiguous creates the image, and early realize static weights (TODO: test for the static weight)
|
2831
|
+
if IMAGE >= 2: x,w = x.cast(base_image_type((bs*iy, ix*groups*cin//4, 4))), w.cast(base_image_type((cout//4, H*W*cin, 4)))
|
2832
|
+
x, w = x.contiguous(), w.contiguous()
|
2833
|
+
|
2834
|
+
# expand out
|
2835
|
+
rcin_hi, rcin_lo = cin//4 if cin >= 4 else 1, 4 if cin >= 4 else 1
|
2836
|
+
cout_expand = [groups//4 if cin == 1 else groups, 4 if cin == 1 else 1, rcout//4 if rcout >= 4 else 1, 4 if rcout >= 4 else 1]
|
2837
|
+
x = x.reshape(bs, iy, ix, groups, rcin_hi, rcin_lo)
|
2838
|
+
if cin_last: w = w.reshape(cout//4, H, rcin_hi, W, 4, rcin_lo)
|
2839
|
+
else: w = w.reshape(cout//4, H, rcin_hi, W, rcin_lo, 4).permute(0,1,2,3,5,4)
|
2840
|
+
|
2841
|
+
# padding
|
2842
|
+
padding_ = [padding]*4 if isinstance(padding, int) else (padding if len(padding) == 4 else [padding[1], padding[1], padding[0], padding[0]])
|
2843
|
+
x = x._slice((None, (-padding_[2], x.shape[1]+padding_[3]), (-padding_[0], x.shape[2]+padding_[1]), None, None, None))
|
2844
|
+
|
2845
|
+
# prepare input
|
2846
|
+
x = x.permute(0,3,4,5,1,2)._pool((H, W), stride, dilation) # -> (bs, groups, rcin_hi, rcin_lo, oy, ox, H, W)
|
2847
|
+
x = x.permute(0,4,5,1,2,3,6,7).reshape(bs, (oy := x.shape[4]), (ox := x.shape[5]), *cout_expand[0:2], 1, 1, rcin_hi, rcin_lo, H, W)
|
2848
|
+
|
2849
|
+
# prepare weights
|
2850
|
+
w = w.permute(0,4,2,5,1,3).reshape((1, 1, 1, *cout_expand, rcin_hi, rcin_lo, H, W))
|
2851
|
+
|
2852
|
+
# the conv!
|
2853
|
+
ret = (x*w).cast(base_image_type((bs*oy, ox*cout//4, 4)) if IMAGE >= 2 else dtypes.float32).sum((-4, -3, -2, -1), acc_dtype=acc_dtype)
|
2854
|
+
|
2855
|
+
# undo hack for non multiples of 4 on C.rcout
|
2856
|
+
if added_output_channels != 0:
|
2857
|
+
ret = ret.reshape(bs, oy, ox, groups, rcout)[:, :, :, :, :-added_output_channels]
|
2858
|
+
cout = groups * (rcout - added_output_channels)
|
2859
|
+
|
2860
|
+
# NCHW output
|
2861
|
+
ret = ret.reshape(bs, oy, ox, cout).permute(0,3,1,2)
|
2862
|
+
return ret if bias is None else ret.add(bias.reshape(1, -1, 1, 1))
|
937
2863
|
|
938
2864
|
# register functions to move between devices
- for device in Device._devices: setattr(Tensor, f"{device.lower()}", partialmethod(Tensor.to, device))
+ for device in Device._devices: setattr(Tensor, f"{device.lower()}", functools.partialmethod(Tensor.to, device))

  if IMAGE:
  # if IMAGE>0 we install these replacement functions in Tensor (hack!)
-
- setattr(Tensor, "
- setattr(Tensor, "dot", image_dot)
+ setattr(Tensor, "conv2d", Tensor.image_conv2d)
+ setattr(Tensor, "dot", Tensor.image_dot)

- # TODO: remove
+ # TODO: eventually remove this
  def custom_random(out:Buffer):
  Tensor._seed += 1
- if DEBUG >= 2: print(f"*** {out.device} rand seed {Tensor._seed} size {out.size:<15d} dtype {out.dtype}")
  rng = np.random.default_rng(Tensor._seed)
- rng_np_buffer = rng.
+ if out.dtype == dtypes.half: rng_np_buffer = (rng.integers(low=0, high=2047, size=out.size) / 2048).astype(np.half, copy=False)
+ else: rng_np_buffer = rng.random(size=out.size, dtype=np.float32).astype(dtype=out.dtype.np, copy=False)
  out.copyin(rng_np_buffer.data)