tinygrad 0.9.1__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +11 -6
- tinygrad/codegen/kernel.py +308 -175
- tinygrad/codegen/linearize.py +95 -0
- tinygrad/codegen/lowerer.py +143 -0
- tinygrad/codegen/transcendental.py +257 -0
- tinygrad/codegen/uopgraph.py +506 -0
- tinygrad/device.py +72 -171
- tinygrad/dtype.py +122 -47
- tinygrad/engine/jit.py +184 -87
- tinygrad/{lazy.py → engine/lazy.py} +74 -66
- tinygrad/engine/memory.py +51 -0
- tinygrad/engine/realize.py +86 -61
- tinygrad/engine/schedule.py +366 -317
- tinygrad/engine/search.py +58 -47
- tinygrad/function.py +59 -58
- tinygrad/helpers.py +120 -102
- tinygrad/multi.py +82 -78
- tinygrad/nn/__init__.py +116 -67
- tinygrad/nn/datasets.py +12 -5
- tinygrad/nn/optim.py +1 -1
- tinygrad/nn/state.py +91 -6
- tinygrad/ops.py +1126 -143
- tinygrad/renderer/__init__.py +47 -23
- tinygrad/renderer/cstyle.py +338 -265
- tinygrad/renderer/llvmir.py +125 -143
- tinygrad/renderer/ptx.py +225 -0
- tinygrad/runtime/autogen/adreno.py +17904 -0
- tinygrad/runtime/autogen/amd_gpu.py +46974 -11993
- tinygrad/runtime/autogen/cuda.py +6 -162
- tinygrad/runtime/autogen/io_uring.py +97 -63
- tinygrad/runtime/autogen/kfd.py +60 -47
- tinygrad/runtime/autogen/kgsl.py +1386 -0
- tinygrad/runtime/autogen/libc.py +5462 -0
- tinygrad/runtime/autogen/nv_gpu.py +1976 -1957
- tinygrad/runtime/autogen/nvrtc.py +579 -0
- tinygrad/runtime/autogen/opencl.py +11 -11
- tinygrad/runtime/autogen/qcom_dsp.py +1739 -0
- tinygrad/runtime/graph/clang.py +3 -3
- tinygrad/runtime/graph/cuda.py +11 -15
- tinygrad/runtime/graph/hcq.py +120 -107
- tinygrad/runtime/graph/metal.py +71 -43
- tinygrad/runtime/ops_amd.py +244 -323
- tinygrad/runtime/ops_clang.py +12 -5
- tinygrad/runtime/ops_cloud.py +220 -0
- tinygrad/runtime/ops_cuda.py +42 -99
- tinygrad/runtime/ops_disk.py +25 -26
- tinygrad/runtime/ops_dsp.py +181 -0
- tinygrad/runtime/ops_gpu.py +29 -16
- tinygrad/runtime/ops_hip.py +68 -0
- tinygrad/runtime/ops_llvm.py +15 -10
- tinygrad/runtime/ops_metal.py +147 -64
- tinygrad/runtime/ops_nv.py +356 -397
- tinygrad/runtime/ops_python.py +78 -79
- tinygrad/runtime/ops_qcom.py +405 -0
- tinygrad/runtime/support/__init__.py +0 -0
- tinygrad/runtime/support/compiler_cuda.py +77 -0
- tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +13 -1
- tinygrad/runtime/support/elf.py +38 -0
- tinygrad/runtime/support/hcq.py +539 -0
- tinygrad/shape/shapetracker.py +40 -50
- tinygrad/shape/view.py +102 -63
- tinygrad/tensor.py +1109 -365
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/METADATA +54 -50
- tinygrad-0.10.0.dist-info/RECORD +77 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/linearizer.py +0 -528
- tinygrad/codegen/uops.py +0 -451
- tinygrad/engine/graph.py +0 -100
- tinygrad/renderer/assembly.py +0 -269
- tinygrad/shape/symbolic.py +0 -327
- tinygrad-0.9.1.dist-info/RECORD +0 -63
- /tinygrad/{runtime/driver/__init__.py → py.typed} +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/LICENSE +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/top_level.txt +0 -0
tinygrad/tensor.py
CHANGED
@@ -1,37 +1,37 @@
 # inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
 from __future__ import annotations
-import time, math, itertools, functools, struct
+import time, math, itertools, functools, struct, sys, inspect, pathlib, string, dataclasses, hashlib
 from contextlib import ContextDecorator
-from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Dict, DefaultDict, cast, get_args,
+from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Dict, DefaultDict, cast, get_args, Literal
 from collections import defaultdict
-import numpy as np

-from tinygrad.dtype import DType, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype
-from tinygrad.helpers import argfix,
-from tinygrad.helpers import IMAGE, DEBUG, WINO,
-from tinygrad.lazy import LazyBuffer
+from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
+from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
+from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN
 from tinygrad.multi import MultiLazyBuffer
-from tinygrad.ops import
+from tinygrad.ops import smax, smin, resolve, UOp, Ops, sint, Variable, SimpleMathTrait
 from tinygrad.device import Device, Buffer, BufferOptions
-from tinygrad.
+from tinygrad.engine.lazy import LazyBuffer
 from tinygrad.engine.realize import run_schedule
-from tinygrad.engine.
+from tinygrad.engine.memory import memory_planner
+from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars

 # **** start with two base classes, Tensor and Function ****

 class Function:
-  def __init__(self, device:Union[str, Tuple[str, ...]], *tensors:Tensor):
+  def __init__(self, device:Union[str, Tuple[str, ...]], *tensors:Tensor, metadata:Optional[Metadata]=None):
     self.device = device
     self.needs_input_grad = [t.requires_grad for t in tensors]
     self.requires_grad = True if any(self.needs_input_grad) else None if None in self.needs_input_grad else False
     if self.requires_grad: self.parents = tensors
+    self.metadata = metadata

   def forward(self, *args, **kwargs): raise NotImplementedError(f"forward not implemented for {type(self)}")
   def backward(self, *args, **kwargs): raise RuntimeError(f"backward not implemented for {type(self)}")

   @classmethod
   def apply(fxn:Type[Function], *x:Tensor, **kwargs) -> Tensor:
-    ctx = fxn(x[0].device, *x)
+    ctx = fxn(x[0].device, *x, metadata=_METADATA.get())
     ret = Tensor.__new__(Tensor)
     ret.lazydata, ret.requires_grad, ret.grad = ctx.forward(*[t.lazydata for t in x], **kwargs), ctx.requires_grad, None
     ret._ctx = ctx if ctx.requires_grad and not Tensor.no_grad else None # used by autograd engine
@@ -39,29 +39,39 @@ class Function:

 import tinygrad.function as F

-def
-  if isinstance(device, str): return LazyBuffer.
-  return MultiLazyBuffer([LazyBuffer.
+def _metaop(op, shape:Tuple[sint,...], dtype:DType, device:Union[str, Tuple[str, ...]], arg=None, src:Tuple[LazyBuffer, ...]=()):
+  if isinstance(device, str): return LazyBuffer.metaop(op, shape, dtype, device, arg, src)
+  return MultiLazyBuffer([LazyBuffer.metaop(op, shape, dtype, d, arg, src) for d in device], None)

-def _from_np_dtype(npdtype:
-
+def _from_np_dtype(npdtype:'np.dtype') -> DType: # type: ignore [name-defined] # noqa: F821
+  import numpy as np
+  return dtypes.fields()[np.dtype(npdtype).name]
+def _to_np_dtype(dtype:DType) -> Optional[type]:
+  import numpy as np
+  return np.dtype(dtype.fmt).type if dtype.fmt is not None else None

-def _fromnp(x: np.ndarray) -> LazyBuffer:
-  ret = LazyBuffer.
+def _fromnp(x: 'np.ndarray') -> LazyBuffer: # type: ignore [name-defined] # noqa: F821
+  ret = LazyBuffer.metaop(Ops.EMPTY, x.shape, _from_np_dtype(x.dtype), "NPY")
   # fake realize
   ret.buffer.allocate(x)
   del ret.srcs
   return ret

+def get_shape(x) -> Tuple[int, ...]:
+  # NOTE: str is special because __getitem__ on a str is still a str
+  if not hasattr(x, "__len__") or not hasattr(x, "__getitem__") or isinstance(x, str) or (hasattr(x, "shape") and x.shape == ()): return ()
+  if not all_same(subs:=[get_shape(xi) for xi in x]): raise ValueError(f"inhomogeneous shape from {x}")
+  return (len(subs),) + (subs[0] if subs else ())
+
 def _frompy(x:Union[List, Tuple, bytes], dtype:DType) -> LazyBuffer:
-  if isinstance(x, bytes): ret, data = LazyBuffer.
+  if isinstance(x, bytes): ret, data = LazyBuffer.metaop(Ops.EMPTY, (len(x)//dtype.itemsize,), dtype, "PYTHON"), x
   else:
-    ret = LazyBuffer.
+    ret = LazyBuffer.metaop(Ops.EMPTY, get_shape(x), dtype, "PYTHON")
     assert dtype.fmt is not None, f"{dtype=} has None fmt"
     truncate_function = truncate[dtype]
     data = struct.pack(f"@{ret.size}{dtype.fmt}", *[truncate_function(xi) for xi in fully_flatten(x)])
   # fake realize
-  ret.buffer.allocate(memoryview(data))
+  ret.buffer.allocate(memoryview(data if Device.DEFAULT != "PYTHON" else bytearray(data)))
   del ret.srcs
   return ret

|
|
85
95
|
max_dim = max(len(shape) for shape in shapes)
|
86
96
|
return tuple((1,) * (max_dim - len(shape)) + shape for shape in shapes)
|
87
97
|
def _broadcast_shape(*shapes:Tuple[sint, ...]) -> Tuple[sint, ...]:
|
88
|
-
return tuple(0 if
|
98
|
+
return tuple(0 if 0 in nth_dim_sizes else smax(nth_dim_sizes) for nth_dim_sizes in zip(*_pad_left(*shapes)))
|
99
|
+
|
100
|
+
ReductionStr = Literal["mean", "sum", "none"]
|
89
101
|
|
90
|
-
class Tensor:
|
102
|
+
class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
91
103
|
"""
|
92
104
|
A `Tensor` is a multi-dimensional matrix containing elements of a single data type.
|
93
105
|
|
@@ -103,9 +115,11 @@ class Tensor:
|
|
103
115
|
training: ClassVar[bool] = False
|
104
116
|
no_grad: ClassVar[bool] = False
|
105
117
|
|
106
|
-
def __init__(self, data:Union[None, ConstType, List, Tuple, LazyBuffer, np.ndarray,
|
107
|
-
device:Optional[Union[str, tuple, list]]=None, dtype:Optional[
|
118
|
+
def __init__(self, data:Union[None, ConstType, UOp, bytes, List, Tuple, LazyBuffer, MultiLazyBuffer, 'np.ndarray', pathlib.Path], # type: ignore [name-defined] # noqa: F821
|
119
|
+
device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DTypeLike]=None, requires_grad:Optional[bool]=None):
|
120
|
+
if dtype is not None: dtype = to_dtype(dtype)
|
108
121
|
assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
|
122
|
+
if device is None and isinstance(data, pathlib.Path): device = f"DISK:{data.resolve()}" # keep it on the disk if device is None
|
109
123
|
device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device)
|
110
124
|
|
111
125
|
# tensors can have gradients if you have called .backward
|
@@ -119,42 +133,45 @@ class Tensor:
|
|
119
133
|
self._ctx: Optional[Function] = None
|
120
134
|
|
121
135
|
# create a LazyBuffer from the different types of inputs
|
122
|
-
if isinstance(data, LazyBuffer): assert dtype is None or dtype
|
123
|
-
elif
|
124
|
-
elif isinstance(data,
|
125
|
-
elif isinstance(data,
|
136
|
+
if isinstance(data, (LazyBuffer, MultiLazyBuffer)): assert dtype is None or dtype==data.dtype, "dtype doesn't match, and casting isn't supported"
|
137
|
+
elif data is None: data = _metaop(Ops.EMPTY, (0,), dtype or dtypes.default_float, device)
|
138
|
+
elif isinstance(data, get_args(ConstType)): data = _metaop(Ops.CONST, tuple(), dtype or dtypes.from_py(data), device, data)
|
139
|
+
elif isinstance(data, UOp):
|
140
|
+
assert data.op is Ops.BIND and data.src[0].op is Ops.DEFINE_VAR and data.src[1].op is Ops.CONST, f"can't create tensor from UOp {data}"
|
141
|
+
data = _metaop(Ops.CONST, tuple(), dtype or data.dtype, device, data)
|
142
|
+
elif isinstance(data, bytes): data = _frompy(data, dtypes.uint8 if dtype is None else dtype)
|
126
143
|
elif isinstance(data, (list, tuple)):
|
127
144
|
if dtype is None:
|
128
145
|
if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtypes.bool
|
129
146
|
else: dtype = dtypes.default_int if d and all_int(d) else dtypes.default_float
|
130
|
-
if dtype == dtypes.bfloat16: data = Tensor(
|
131
|
-
else: data =
|
132
|
-
elif
|
133
|
-
|
134
|
-
|
135
|
-
|
147
|
+
if dtype == dtypes.bfloat16: data = Tensor(_frompy(data, dtypes.float32), device=device).cast(dtypes.bfloat16).lazydata
|
148
|
+
else: data = _frompy(data, dtype)
|
149
|
+
elif str(type(data)) == "<class 'numpy.ndarray'>":
|
150
|
+
import numpy as np
|
151
|
+
assert isinstance(data, np.ndarray), f"expected np.ndarray, got {data}"
|
152
|
+
if data.shape == (): data = _metaop(Ops.CONST, tuple(), dtype or _from_np_dtype(data.dtype), device, data.item())
|
153
|
+
else: data = _fromnp(data.astype(npdtype) if dtype is not None and (npdtype:=_to_np_dtype(dtype)) is not None else data) # type: ignore [name-defined]
|
154
|
+
elif isinstance(data, pathlib.Path):
|
155
|
+
dtype = dtype or dtypes.uint8
|
156
|
+
data = _metaop(Ops.EMPTY, (data.stat().st_size // dtype.itemsize,), dtype, f"DISK:{data.resolve()}")
|
136
157
|
|
137
158
|
# by this point, it has to be a LazyBuffer
|
138
|
-
if not isinstance(data, (LazyBuffer, MultiLazyBuffer)):
|
139
|
-
raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}")
|
159
|
+
if not isinstance(data, (LazyBuffer, MultiLazyBuffer)): raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}")
|
140
160
|
|
141
|
-
# data
|
142
|
-
if isinstance(device,
|
143
|
-
|
144
|
-
|
145
|
-
assert data.device == device, f"MultiLazyBuffer device mismatch, {data.device} != {device}"
|
146
|
-
self.lazydata: Union[LazyBuffer, MultiLazyBuffer] = data
|
147
|
-
else:
|
148
|
-
self.lazydata = MultiLazyBuffer.from_sharded(data, device, None)
|
161
|
+
# data might be on a different device
|
162
|
+
if isinstance(device, str): self.lazydata:Union[LazyBuffer, MultiLazyBuffer] = data if data.device == device else data.copy_to_device(device)
|
163
|
+
# if device is a tuple, we should have/construct a MultiLazyBuffer
|
164
|
+
elif isinstance(data, LazyBuffer): self.lazydata = MultiLazyBuffer.from_sharded(data, device, None, None)
|
149
165
|
else:
|
150
|
-
|
166
|
+
assert data.device == device, f"MultiLazyBuffer device mismatch, {data.device} != {device}"
|
167
|
+
self.lazydata = data
|
151
168
|
|
152
169
|
class train(ContextDecorator):
|
153
170
|
def __init__(self, mode:bool = True): self.mode = mode
|
154
171
|
def __enter__(self): self.prev, Tensor.training = Tensor.training, self.mode
|
155
172
|
def __exit__(self, exc_type, exc_value, traceback): Tensor.training = self.prev
|
156
173
|
|
157
|
-
class
|
174
|
+
class test(ContextDecorator):
|
158
175
|
def __init__(self, mode:bool = True): self.mode = mode
|
159
176
|
def __enter__(self): self.prev, Tensor.no_grad = Tensor.no_grad, self.mode
|
160
177
|
def __exit__(self, exc_type, exc_value, traceback): Tensor.no_grad = self.prev
|
@@ -182,17 +199,18 @@ class Tensor:
|
|
182
199
|
|
183
200
|
# ***** data handlers ****
|
184
201
|
|
185
|
-
def schedule_with_vars(self, *lst:Tensor
|
186
|
-
"""
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
202
|
+
def schedule_with_vars(self, *lst:Tensor) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
|
203
|
+
"""
|
204
|
+
Creates the schedule needed to realize these Tensor(s), with Variables.
|
205
|
+
|
206
|
+
NOTE: A Tensor can only be scheduled once.
|
207
|
+
"""
|
208
|
+
schedule, var_vals = create_schedule_with_vars(flatten([x.lazydata.lbs for x in (self,)+lst]))
|
191
209
|
return memory_planner(schedule), var_vals
|
192
210
|
|
193
|
-
def schedule(self, *lst:Tensor
|
211
|
+
def schedule(self, *lst:Tensor) -> List[ScheduleItem]:
|
194
212
|
"""Creates the schedule needed to realize these Tensor(s)."""
|
195
|
-
schedule, var_vals = self.schedule_with_vars(*lst
|
213
|
+
schedule, var_vals = self.schedule_with_vars(*lst)
|
196
214
|
assert len(var_vals) == 0
|
197
215
|
return schedule
|
198
216
|
|
@@ -226,7 +244,7 @@ class Tensor:
|
|
226
244
|
assert self.dtype == x.dtype, f"assign dtype mismatch {self.dtype} != {x.dtype}"
|
227
245
|
assert not isinstance(self.lazydata, MultiLazyBuffer) or self.lazydata.axis == x.lazydata.axis, "axis must match on MultiLazyBuffer"
|
228
246
|
assert not x.requires_grad # self requires_grad is okay?
|
229
|
-
if not self.lazydata.is_realized
|
247
|
+
if not self.lazydata.is_realized: return self.replace(x)
|
230
248
|
self.lazydata = self.lazydata.assign(x.lazydata)
|
231
249
|
return self
|
232
250
|
|
@@ -239,7 +257,7 @@ class Tensor:
|
|
239
257
|
def _data(self) -> memoryview:
|
240
258
|
if 0 in self.shape: return memoryview(bytearray(0))
|
241
259
|
# NOTE: this realizes on the object from as_buffer being a Python object
|
242
|
-
cpu = self.cast(self.dtype.
|
260
|
+
cpu = self.cast(self.dtype.base).contiguous().to("CLANG").realize()
|
243
261
|
buf = cast(Buffer, cast(LazyBuffer, cpu.lazydata).base.realized)
|
244
262
|
if self.device != "CLANG": buf.options = BufferOptions(nolru=True)
|
245
263
|
return buf.as_buffer(allow_zero_copy=True if self.device != "CLANG" else False)
|
@@ -283,7 +301,7 @@ class Tensor:
|
|
283
301
|
"""
|
284
302
|
return self.data().tolist()
|
285
303
|
|
286
|
-
def numpy(self) -> np.ndarray:
|
304
|
+
def numpy(self) -> 'np.ndarray': # type: ignore [name-defined] # noqa: F821
|
287
305
|
"""
|
288
306
|
Returns the value of this tensor as a `numpy.ndarray`.
|
289
307
|
|
@@ -292,11 +310,21 @@ class Tensor:
|
|
292
310
|
print(repr(t.numpy()))
|
293
311
|
```
|
294
312
|
"""
|
313
|
+
import numpy as np
|
295
314
|
if self.dtype == dtypes.bfloat16: return self.float().numpy()
|
296
315
|
assert _to_np_dtype(self.dtype) is not None, f"no np dtype for {self.dtype}"
|
297
316
|
assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
|
298
317
|
return np.frombuffer(self._data(), dtype=_to_np_dtype(self.dtype)).reshape(self.shape)
|
299
318
|
|
319
|
+
def clone(self) -> Tensor:
|
320
|
+
"""
|
321
|
+
Creates a clone of this tensor allocating a seperate buffer for the data.
|
322
|
+
"""
|
323
|
+
ret = Tensor(self.lazydata.clone(), self.device, requires_grad=self.requires_grad)
|
324
|
+
if self.grad is not None: ret.grad = self.grad.clone()
|
325
|
+
if hasattr(self, '_ctx'): ret._ctx = self._ctx
|
326
|
+
return ret
|
327
|
+
|
300
328
|
def to(self, device:Optional[Union[str, Tuple[str, ...]]]) -> Tensor:
|
301
329
|
"""
|
302
330
|
Moves the tensor to the given device.
|
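A short usage sketch of the new `Tensor.clone` from the hunk above; the only claim taken from the diff is that the clone gets its own buffer (and its own grad/ctx when present):

```python
from tinygrad import Tensor

a = Tensor([1.0, 2.0, 3.0], requires_grad=True)
b = a.clone()                 # same values, separate buffer
b = (b + 1).realize()         # work on the clone without touching `a`
print(a.numpy(), b.numpy())
```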
@@ -318,38 +346,54 @@ class Tensor:
     if self.grad is not None and real.grad is not None: self.grad.lazydata = real.grad.lazydata
     self.lazydata = real.lazydata

-  def shard(self, devices:Tuple[str, ...], axis:Optional[int]=None) -> Tensor:
+  def shard(self, devices:Tuple[str, ...], axis:Optional[int]=None, splits:Optional[Tuple[int, ...]]=None) -> Tensor:
     """
-    Shards the tensor across the given devices.
+    Shards the tensor across the given devices. Optionally specify which axis to shard on, and how to split it across devices.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.empty(2, 3)
+    print(t.shard((t.device, t.device), axis=1, splits=(2, 1)).lazydata)
+    ```
+
     """
     assert isinstance(self.lazydata, LazyBuffer), "can't shard a MultiLazyBuffer"
-
-    if axis is not None
-
-
-
+    devices, bounds = tuple(Device.canonicalize(x) for x in devices), None
+    if axis is not None:
+      if axis < 0: axis += len(self.shape)
+      if splits is None:
+        if not isinstance(total:=self.shape[axis], int): raise RuntimeError(f"cannot shard symbolic shape {self.shape=}, {axis=}")
+        sz = ceildiv(total, len(devices))
+        splits = tuple([max(0, min(sz, total - sz*i)) for i in range(len(devices))])
+      assert sum(splits) == self.shape[axis], "specified splits do not sum up to axis shape"
+      boundaries = tuple(itertools.accumulate(splits))
+      bounds = tuple(zip((0,) + boundaries, boundaries))
+    return Tensor(MultiLazyBuffer.from_sharded(self.lazydata, devices, axis, bounds), device=devices, requires_grad=self.requires_grad)
+
+  def shard_(self, devices:Tuple[str, ...], axis:Optional[int]=None, splits:Optional[Tuple[int, ...]]=None):
     """
     Shards the tensor across the given devices in place.
     """
-    self.lazydata = self.shard(devices, axis).lazydata
+    self.lazydata = self.shard(devices, axis, splits).lazydata
     return self

   @staticmethod
-  def
-    if
-    if
-    if
-    if
-
+  def from_uop(y:UOp, **kwargs) -> Tensor:
+    if y.op is Ops.BIND: return Tensor(y, **kwargs, requires_grad=False) # this is the only UOp allowed in Tensor
+    if y.op is Ops.CONST: return Tensor(y.arg, **kwargs, requires_grad=False)
+    if y.op is Ops.MUL: return Tensor.from_uop(y.src[0]) * Tensor.from_uop(y.src[1])
+    if y.op is Ops.ADD: return Tensor.from_uop(y.src[0]) + Tensor.from_uop(y.src[1])
+    if y.op is Ops.MAX: return Tensor.from_uop(y.src[0]).maximum(Tensor.from_uop(y.src[1]))
+    raise RuntimeError(f"unhandled UOp {y}")

-  # ***** creation
+  # ***** creation entrypoint *****

   @staticmethod
-  def
+  def _metaop(op, shape, device:Optional[Union[Tuple[str, ...], str]]=None, dtype:Optional[DTypeLike]=None, arg=None, **kwargs):
+    dtype = to_dtype(dtype) if dtype is not None else dtypes.default_float
     if isinstance(device, tuple):
-      return Tensor(MultiLazyBuffer([LazyBuffer.
-
-    return Tensor(LazyBuffer.
+      return Tensor(MultiLazyBuffer([LazyBuffer.metaop(op, shape, dtype, Device.canonicalize(d), arg) for d in device], None),
+                    device, dtype, **kwargs)
+    return Tensor(LazyBuffer.metaop(op, shape, dtype, Device.canonicalize(device), arg), device, dtype, **kwargs)

   @staticmethod
   def empty(*shape, **kwargs):
|
|
364
408
|
print(t.shape)
|
365
409
|
```
|
366
410
|
"""
|
367
|
-
return Tensor.
|
411
|
+
return Tensor._metaop(Ops.EMPTY, argfix(*shape), **kwargs)
|
412
|
+
|
413
|
+
@staticmethod
|
414
|
+
def from_blob(ptr:int, shape:Tuple[int, ...], **kwargs) -> Tensor:
|
415
|
+
"""
|
416
|
+
Exposes the pointer as a Tensor without taking ownership of the original data.
|
417
|
+
The pointer must remain valid for the entire lifetime of the created Tensor.
|
418
|
+
|
419
|
+
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
420
|
+
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
421
|
+
"""
|
422
|
+
|
423
|
+
r = Tensor._metaop(Ops.EMPTY, shape, **kwargs)
|
424
|
+
r.lazydata.buffer.allocate(external_ptr=ptr)
|
425
|
+
del r.lazydata.srcs # fake realize
|
426
|
+
return r
|
427
|
+
|
428
|
+
@staticmethod
|
429
|
+
def from_url(url:str, gunzip:bool=False, **kwargs) -> Tensor:
|
430
|
+
"""
|
431
|
+
Create a Tensor from a URL.
|
432
|
+
|
433
|
+
This is the preferred way to access Internet resources.
|
434
|
+
It currently returns a DISK Tensor, but in the future it may return an HTTP Tensor.
|
435
|
+
This also will soon become lazy (when possible) and not print progress without DEBUG.
|
436
|
+
|
437
|
+
THe `gunzip` flag will gzip extract the resource and return an extracted Tensor.
|
438
|
+
"""
|
439
|
+
return Tensor(fetch(url, gunzip=gunzip), **kwargs)
|
368
440
|
|
369
441
|
_seed: int = int(time.time())
|
370
|
-
|
442
|
+
_device_seeds: Dict[str, Tensor] = {}
|
443
|
+
_device_rng_counters: Dict[str, Tensor] = {}
|
371
444
|
@staticmethod
|
372
445
|
def manual_seed(seed=0):
|
373
446
|
"""
|
@@ -384,10 +457,17 @@ class Tensor:
|
|
384
457
|
print(Tensor.rand(5).numpy())
|
385
458
|
```
|
386
459
|
"""
|
387
|
-
Tensor._seed, Tensor.
|
460
|
+
Tensor._seed, Tensor._device_seeds, Tensor._device_rng_counters = seed, {}, {}
|
388
461
|
|
389
462
|
@staticmethod
|
390
|
-
def
|
463
|
+
def _threefry_random_bits(key:Tensor, counts0:Tensor, counts1:Tensor):
|
464
|
+
x = (counts1.cast(dtypes.uint64) << 32) | counts0.cast(dtypes.uint64)
|
465
|
+
x = F.Threefry.apply(x, (key[1]._broadcast_to(x.shape).cast(dtypes.uint64) << 32) | key[0]._broadcast_to(x.shape).cast(dtypes.uint64))
|
466
|
+
counts0, counts1 = (x & 0xffffffff).cast(dtypes.uint32), ((x >> 32) & 0xffffffff).cast(dtypes.uint32)
|
467
|
+
return counts0.cat(counts1)
|
468
|
+
|
469
|
+
@staticmethod
|
470
|
+
def rand(*shape, device:Optional[str]=None, dtype:Optional[DTypeLike]=None, contiguous:bool=True, **kwargs) -> Tensor:
|
391
471
|
"""
|
392
472
|
Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[0, 1)`.
|
393
473
|
|
@@ -400,35 +480,55 @@ class Tensor:
|
|
400
480
|
print(t.numpy())
|
401
481
|
```
|
402
482
|
"""
|
403
|
-
if
|
404
|
-
if not
|
405
|
-
|
406
|
-
|
407
|
-
|
408
|
-
|
409
|
-
|
410
|
-
|
411
|
-
|
412
|
-
|
413
|
-
|
414
|
-
|
415
|
-
|
416
|
-
|
417
|
-
|
418
|
-
|
419
|
-
|
420
|
-
|
421
|
-
|
422
|
-
|
423
|
-
|
424
|
-
|
483
|
+
if not dtypes.is_float(dtype := to_dtype(dtype or dtypes.default_float)): raise ValueError(f"rand only supports float dtypes, got {dtype}")
|
484
|
+
if not all_int(shape:=argfix(*shape)) or not all(s >= 0 for s in shape): raise ValueError(f"invalid input {shape=}")
|
485
|
+
if device is not None and not isinstance(device, str): raise ValueError(f"rand only supports single device, got {device=}")
|
486
|
+
_device = device = Device.canonicalize(device)
|
487
|
+
|
488
|
+
# when using MOCKGPU and NV generate rand on CLANG
|
489
|
+
if getenv("MOCKGPU") and device.startswith("NV"): device = "CLANG"
|
490
|
+
|
491
|
+
# generate per device seeds and rng counter if we haven't seen this device yet
|
492
|
+
if device not in Tensor._device_seeds:
|
493
|
+
Tensor._device_seeds[device] = Tensor(
|
494
|
+
[int.from_bytes(hashlib.sha256(len(Tensor._device_seeds).to_bytes(4, "big")).digest(), "big"), Tensor._seed],
|
495
|
+
device=device, dtype=dtypes.uint32, requires_grad=False)
|
496
|
+
Tensor._device_rng_counters[device] = Tensor([0], device=device, dtype=dtypes.uint32, requires_grad=False)
|
497
|
+
had_counter = False
|
498
|
+
else: had_counter = True
|
499
|
+
|
500
|
+
# if shape has 0, return zero tensor
|
501
|
+
if (numel := prod(shape)) == 0: return Tensor.zeros(shape, device=_device, dtype=dtype, **kwargs)
|
502
|
+
num = ceildiv(numel * dtype.itemsize, 4)
|
503
|
+
|
504
|
+
# increment rng counter for devices
|
505
|
+
if had_counter: Tensor._device_rng_counters[device].assign(Tensor._device_rng_counters[device] + num).contiguous()
|
506
|
+
|
507
|
+
# threefry random bits
|
508
|
+
counts0 = (Tensor.arange(ceildiv(num, 2), device=device, dtype=dtypes.uint32, requires_grad=False)+Tensor._device_rng_counters[device])
|
509
|
+
counts1 = counts0 + ceildiv(num, 2)
|
510
|
+
bits = Tensor._threefry_random_bits(Tensor._device_seeds[device], counts0, counts1)[:num]
|
511
|
+
|
512
|
+
# bitcast to uint with same number of bits
|
513
|
+
_, nmant = dtypes.finfo(dtype)
|
514
|
+
uint_dtype = {1: dtypes.uint8, 2: dtypes.uint16, 4: dtypes.uint32, 8: dtypes.uint64}[dtype.itemsize]
|
515
|
+
bits = bits.bitcast(uint_dtype)
|
516
|
+
# only randomize the mantissa bits and set the exponent to 1
|
517
|
+
one = Tensor.ones_like(bits, device=bits.device, dtype=dtype).bitcast(uint_dtype)
|
518
|
+
bits = bits.rshift((dtype.itemsize * 8) - nmant).bitwise_or(one)
|
519
|
+
# bitcast back to the original dtype and reshape
|
520
|
+
out = bits.bitcast(dtype)[:numel].sub(1).reshape(shape)
|
521
|
+
|
522
|
+
# move back to the original device if we were using MOCKGPU
|
523
|
+
if getenv("MOCKGPU") and _device: out = out.to(_device)
|
524
|
+
|
425
525
|
out.requires_grad = kwargs.get("requires_grad")
|
426
|
-
return out.contiguous()
|
526
|
+
return out.contiguous() if contiguous else out
|
427
527
|
|
428
528
|
# ***** creation helper functions *****
|
429
529
|
|
430
530
|
@staticmethod
|
431
|
-
def full(shape:Tuple[sint, ...], fill_value:ConstType, **kwargs):
|
531
|
+
def full(shape:Tuple[sint, ...], fill_value:ConstType, **kwargs) -> Tensor:
|
432
532
|
"""
|
433
533
|
Creates a tensor with the given shape, filled with the given value.
|
434
534
|
|
@@ -445,7 +545,7 @@ class Tensor:
|
|
445
545
|
return Tensor(fill_value, **kwargs).reshape((1, )*len(new_shape := argfix(shape))).expand(new_shape)
|
446
546
|
|
447
547
|
@staticmethod
|
448
|
-
def zeros(*shape, **kwargs):
|
548
|
+
def zeros(*shape, **kwargs) -> Tensor:
|
449
549
|
"""
|
450
550
|
Creates a tensor with the given shape, filled with zeros.
|
451
551
|
|
@@ -462,7 +562,7 @@ class Tensor:
|
|
462
562
|
return Tensor.full(argfix(*shape), 0.0, **kwargs)
|
463
563
|
|
464
564
|
@staticmethod
|
465
|
-
def ones(*shape, **kwargs):
|
565
|
+
def ones(*shape, **kwargs) -> Tensor:
|
466
566
|
"""
|
467
567
|
Creates a tensor with the given shape, filled with ones.
|
468
568
|
|
@@ -479,7 +579,7 @@ class Tensor:
|
|
479
579
|
return Tensor.full(argfix(*shape), 1.0, **kwargs)
|
480
580
|
|
481
581
|
@staticmethod
|
482
|
-
def arange(start, stop=None, step=1, **kwargs):
|
582
|
+
def arange(start, stop=None, step=1, **kwargs) -> Tensor:
|
483
583
|
"""
|
484
584
|
Returns a 1-D tensor of size `ceil((stop - start) / step)` with values from `[start, stop)`, with spacing between values given by `step`.
|
485
585
|
|
@@ -504,14 +604,35 @@ class Tensor:
|
|
504
604
|
```
|
505
605
|
"""
|
506
606
|
if stop is None: stop, start = start, 0
|
507
|
-
assert all(isinstance(s, (int, float)) for s in (start, stop, step)), f"symbolic arange not supported {start=}, {stop=}, {step=}"
|
508
607
|
dtype = kwargs.pop("dtype", dtypes.default_float if any(isinstance(x, float) for x in (start, stop, step)) else dtypes.default_int)
|
509
|
-
|
608
|
+
# NOTE: this matches numpy, torch raises RuntimeError if stop-start and step have different signs
|
609
|
+
if (output_len:=ceildiv(stop-start, step)) <= 0: return Tensor([], dtype=dtype, **kwargs)
|
610
|
+
return (Tensor.full((output_len,), step, dtype=dtype, **kwargs)._cumsum() + (start - step)).cast(dtype)
|
611
|
+
|
612
|
+
@staticmethod
|
613
|
+
def linspace(start:Union[int, float], stop:Union[int, float], steps:int, **kwargs) -> Tensor:
|
614
|
+
"""
|
615
|
+
Returns a 1-D tensor of `steps` evenly spaced values from `start` to `stop`, inclusive.
|
616
|
+
|
617
|
+
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
618
|
+
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
619
|
+
|
620
|
+
```python exec="true" source="above" session="tensor" result="python"
|
621
|
+
print(Tensor.linspace(0, 10, 5).numpy())
|
622
|
+
```
|
623
|
+
```python exec="true" source="above" session="tensor" result="python"
|
624
|
+
print(Tensor.linspace(-1, 1, 5).numpy())
|
625
|
+
```
|
626
|
+
"""
|
627
|
+
if steps < 0: raise ValueError("number of steps must be non-negative")
|
628
|
+
if (dtype := to_dtype(kwargs.pop("dtype", dtypes.default_float))) == dtypes.bool: raise ValueError("linspace with bool dtype is not supported")
|
629
|
+
if steps == 1: return Tensor([start], dtype=dtype, **kwargs)
|
630
|
+
return (start + Tensor.arange(steps, **kwargs) * ((stop - start) / (steps - 1))).cast(dtype)
|
510
631
|
|
511
632
|
@staticmethod
|
512
|
-
def eye(
|
633
|
+
def eye(n:int, m:Optional[int]=None, **kwargs) -> Tensor:
|
513
634
|
"""
|
514
|
-
|
635
|
+
Returns a 2-D tensor with `n` rows and `m` columns, with ones on the diagonal and zeros elsewhere.
|
515
636
|
|
516
637
|
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
517
638
|
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
@@ -519,10 +640,16 @@ class Tensor:
|
|
519
640
|
```python exec="true" source="above" session="tensor" result="python"
|
520
641
|
print(Tensor.eye(3).numpy())
|
521
642
|
```
|
643
|
+
|
644
|
+
```python exec="true" source="above" session="tensor" result="python"
|
645
|
+
print(Tensor.eye(2, 4).numpy())
|
646
|
+
```
|
522
647
|
"""
|
523
|
-
|
648
|
+
if n < 0 or (m is not None and m < 0): raise ValueError(f"cannot have negative {n=}, {m=}")
|
649
|
+
x = Tensor.ones((n,1),**kwargs).pad((None,(0,n))).flatten().shrink(((0,n*n),)).reshape(n,n)
|
650
|
+
return x if m is None else x.pad((None, (0, m-n))) if m > n else x.shrink((None, (0, m)))
|
524
651
|
|
525
|
-
def full_like(self, fill_value:ConstType, **kwargs):
|
652
|
+
def full_like(self, fill_value:ConstType, **kwargs) -> Tensor:
|
526
653
|
"""
|
527
654
|
Creates a tensor with the same shape as `self`, filled with the given value.
|
528
655
|
If `dtype` is not specified, the dtype of `self` is used.
|
@@ -537,7 +664,7 @@ class Tensor:
|
|
537
664
|
"""
|
538
665
|
return Tensor.full(self.shape, fill_value, dtype=kwargs.pop("dtype", self.dtype), device=kwargs.pop("device", self.device), **kwargs)
|
539
666
|
|
540
|
-
def zeros_like(self, **kwargs):
|
667
|
+
def zeros_like(self, **kwargs) -> Tensor:
|
541
668
|
"""
|
542
669
|
Creates a tensor with the same shape as `self`, filled with zeros.
|
543
670
|
|
@@ -551,7 +678,7 @@ class Tensor:
|
|
551
678
|
"""
|
552
679
|
return self.full_like(0, **kwargs)
|
553
680
|
|
554
|
-
def ones_like(self, **kwargs):
|
681
|
+
def ones_like(self, **kwargs) -> Tensor:
|
555
682
|
"""
|
556
683
|
Creates a tensor with the same shape as `self`, filled with ones.
|
557
684
|
|
@@ -565,10 +692,31 @@ class Tensor:
|
|
565
692
|
"""
|
566
693
|
return self.full_like(1, **kwargs)
|
567
694
|
|
695
|
+
def rand_like(self, **kwargs) -> Tensor:
|
696
|
+
"""
|
697
|
+
Creates a tensor with the same shape and sharding as `self`, filled with random values from a uniform distribution over the interval `[0, 1)`.
|
698
|
+
|
699
|
+
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
700
|
+
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
701
|
+
|
702
|
+
```python exec="true" source="above" session="tensor" result="python"
|
703
|
+
t = Tensor.ones(2, 3)
|
704
|
+
print(Tensor.rand_like(t).numpy())
|
705
|
+
```
|
706
|
+
"""
|
707
|
+
dtype = kwargs.pop("dtype", self.dtype)
|
708
|
+
if isinstance(self.device, tuple) and isinstance(self.lazydata, MultiLazyBuffer):
|
709
|
+
if kwargs.get("device") is not None: raise RuntimeError("cannot specify `device` on `rand_like` of a multi device tensor")
|
710
|
+
if self.lazydata.axis is None: return Tensor.rand(*self.shape, dtype=dtype, **kwargs).shard(self.device)
|
711
|
+
contiguous = kwargs.pop("contiguous", True)
|
712
|
+
rands = [Tensor.rand(*lb.shape, device=lb.device, dtype=dtype, contiguous=contiguous, **kwargs).lazydata for lb in self.lazydata.lbs]
|
713
|
+
return Tensor(MultiLazyBuffer(cast(List[LazyBuffer], rands), self.lazydata.axis), device=self.device, dtype=dtype, **kwargs)
|
714
|
+
return Tensor.rand(*self.shape, device=kwargs.pop("device", self.device), dtype=dtype, **kwargs)
|
715
|
+
|
568
716
|
# ***** rng hlops *****
|
569
717
|
|
570
718
|
@staticmethod
|
571
|
-
def randn(*shape, dtype:Optional[
|
719
|
+
def randn(*shape, dtype:Optional[DTypeLike]=None, **kwargs) -> Tensor:
|
572
720
|
"""
|
573
721
|
Creates a tensor with the given shape, filled with random values from a normal distribution with mean `0` and standard deviation `1`.
|
574
722
|
If `dtype` is not specified, the default type is used.
|
@@ -600,7 +748,7 @@ class Tensor:
|
|
600
748
|
```
|
601
749
|
"""
|
602
750
|
if not isinstance(low, int) or not isinstance(high, int): raise TypeError(f"{low=} and {high=} must be integers")
|
603
|
-
dtype = kwargs.pop("dtype", dtypes.int32)
|
751
|
+
dtype = to_dtype(kwargs.pop("dtype", dtypes.int32))
|
604
752
|
if not dtypes.is_int(dtype): raise TypeError(f"{dtype=} must be int")
|
605
753
|
return Tensor.uniform(*shape, low=low, high=high, dtype=dtype, **kwargs)
|
606
754
|
|
@@ -706,7 +854,7 @@ class Tensor:
|
|
706
854
|
assert replacement or num_samples == 1, "no replacement only supports num_samples = 1"
|
707
855
|
weight = self.unsqueeze(0) if self.ndim == 1 else self
|
708
856
|
cdf = (cw := weight.cumsum(1).float()) / cw[:, -1].unsqueeze(1)
|
709
|
-
unif_samples = Tensor.rand(num_samples, cdf.shape[0], 1
|
857
|
+
unif_samples = Tensor.rand(num_samples, cdf.shape[0], 1).to(self.device)
|
710
858
|
indices = (unif_samples.expand((-1, -1, cdf.shape[1])) >= cdf).sum(2).permute((1, 0))
|
711
859
|
return (indices.squeeze(0) if self.ndim == 1 else indices).cast(dtypes.int32)
|
712
860
|
|
@@ -715,39 +863,46 @@ class Tensor:
|
|
715
863
|
def _deepwalk(self):
|
716
864
|
def _walk(node, visited):
|
717
865
|
visited.add(node)
|
718
|
-
if
|
866
|
+
# if tensor is not leaf, reset grad
|
867
|
+
if (ctx := getattr(node, "_ctx", None)) is not None and len(ctx.parents) != 0: node.grad = None
|
868
|
+
if ctx:
|
719
869
|
for i in node._ctx.parents:
|
720
870
|
if i not in visited: yield from _walk(i, visited)
|
721
871
|
yield node
|
722
872
|
return list(_walk(self, set()))
|
723
873
|
|
724
|
-
def backward(self) -> Tensor:
|
874
|
+
def backward(self, gradient:Optional[Tensor]=None, retain_graph:bool=False) -> Tensor:
|
725
875
|
"""
|
726
876
|
Propagates the gradient of a tensor backwards through the computation graph.
|
727
|
-
|
728
|
-
|
877
|
+
If the 'gradient' argument is not provided, the tensor must be a scalar, and the gradient is implicitly set to 1.0.
|
878
|
+
If 'retain_graph' is false, the graph used to compute the grads will be freed. Otherwise, it will be kept. Keeping it can increase memory usage.
|
729
879
|
```python exec="true" source="above" session="tensor" result="python"
|
730
880
|
t = Tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
|
731
881
|
t.sum().backward()
|
732
882
|
print(t.grad.numpy())
|
733
883
|
```
|
734
884
|
"""
|
735
|
-
|
736
|
-
|
737
|
-
|
738
|
-
|
739
|
-
|
885
|
+
toposorted = self._deepwalk()
|
886
|
+
if gradient is None:
|
887
|
+
assert self.shape == tuple(), "when no gradient is provided, backward must be called on a scalar tensor"
|
888
|
+
# fill in the first grad with one. don't use Tensor.ones because we don't need contiguous
|
889
|
+
# this is "implicit gradient creation"
|
890
|
+
gradient = Tensor(1.0, dtype=self.dtype, device=self.device, requires_grad=False)
|
740
891
|
|
741
|
-
|
892
|
+
assert self.shape == gradient.shape, f"grad shape must match tensor shape, {gradient.shape!r} != {self.shape!r}"
|
893
|
+
self.grad = gradient
|
894
|
+
for t0 in reversed(toposorted):
|
742
895
|
if t0.grad is None: raise RuntimeError(f"tensor {t0} has no grad")
|
896
|
+
token = _METADATA.set(dataclasses.replace(md, backward=True) if (md := t0._ctx.metadata) is not None else None)
|
743
897
|
grads = t0._ctx.backward(t0.grad.lazydata)
|
898
|
+
_METADATA.reset(token)
|
744
899
|
grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None
|
745
900
|
for g in ([grads] if len(t0._ctx.parents) == 1 else grads)]
|
746
901
|
for t, g in zip(t0._ctx.parents, grads):
|
747
902
|
if g is not None and t.requires_grad:
|
748
903
|
assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"
|
749
904
|
t.grad = g if t.grad is None else (t.grad + g)
|
750
|
-
del t0._ctx
|
905
|
+
if not retain_graph: del t0._ctx
|
751
906
|
return self
|
752
907
|
|
753
908
|
# ***** movement low level ops *****
|
@@ -822,7 +977,7 @@ class Tensor:
|
|
822
977
|
```
|
823
978
|
"""
|
824
979
|
axis_arg = tuple(self._resolve_dim(x) for x in argfix(axis, *args))
|
825
|
-
if len(axis_arg) != len(dedup(axis_arg)): raise RuntimeError(f"dim can appear at
|
980
|
+
if len(axis_arg) != len(dedup(axis_arg)): raise RuntimeError(f"dim can appear at most once, getting {axis_arg}")
|
826
981
|
return F.Flip.apply(self, axis=axis_arg)
|
827
982
|
|
828
983
|
def shrink(self, arg:Tuple[Optional[Tuple[sint, sint]], ...]) -> Tensor:
|
@@ -842,30 +997,58 @@ class Tensor:
|
|
842
997
|
print(t.shrink((((0, 2), (0, 2)))).numpy())
|
843
998
|
```
|
844
999
|
"""
|
845
|
-
if
|
846
|
-
return F.Shrink.apply(self, arg=tuple(
|
1000
|
+
if (shrink_arg:=[x if x is not None else (0,s) for x,s in zip(arg, self.shape)]) == [(0,s) for s in self.shape]: return self
|
1001
|
+
return F.Shrink.apply(self, arg=tuple(shrink_arg))
|
847
1002
|
|
848
|
-
def pad(self,
|
1003
|
+
def pad(self, padding:Union[Sequence[sint], Sequence[Optional[Tuple[sint, sint]]]], mode:str="constant", value:float=0.0) -> Tensor:
|
849
1004
|
"""
|
850
|
-
Returns a tensor
|
851
|
-
`
|
852
|
-
|
853
|
-
|
1005
|
+
Returns a tensor with padding applied based on the input `padding`.
|
1006
|
+
`padding` supports two padding structures:
|
1007
|
+
|
1008
|
+
1. Flat padding: (padding_left, padding_right, padding_top, padding_bottom, ...)
|
1009
|
+
- This structure matches PyTorch's pad.
|
1010
|
+
- `padding` length must be even.
|
1011
|
+
|
1012
|
+
2. Group padding: (..., (padding_top, padding_bottom), (padding_left, padding_right))
|
1013
|
+
- This structure matches pad for jax, numpy, tensorflow and others.
|
1014
|
+
- For each axis, padding can be `None`, meaning no padding, or a tuple `(start, end)`.
|
1015
|
+
- `padding` must have the same length as `self.ndim`.
|
1016
|
+
|
1017
|
+
Padding values can be negative, resulting in dimension shrinks that work similarly to Python negative slices.
|
1018
|
+
Padding modes is selected with `mode` which supports `constant` and `reflect`.
|
854
1019
|
|
855
1020
|
```python exec="true" source="above" session="tensor" result="python"
|
856
|
-
t = Tensor.arange(
|
1021
|
+
t = Tensor.arange(9).reshape(1, 1, 3, 3)
|
857
1022
|
print(t.numpy())
|
858
1023
|
```
|
859
1024
|
```python exec="true" source="above" session="tensor" result="python"
|
860
|
-
print(t.pad((
|
1025
|
+
print(t.pad((1, 2, 0, -1)).numpy())
|
1026
|
+
```
|
1027
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1028
|
+
print(t.pad(((None, None, (0, -1), (1, 2)))).numpy())
|
861
1029
|
```
|
862
1030
|
```python exec="true" source="above" session="tensor" result="python"
|
863
|
-
print(t.pad((
|
1031
|
+
print(t.pad((1, 2, 0, -1), value=-float('inf')).numpy())
|
864
1032
|
```
|
865
1033
|
"""
|
866
|
-
if
|
867
|
-
|
868
|
-
|
1034
|
+
if mode not in {"constant", "reflect"}: raise NotImplementedError(f"{mode=} is not supported")
|
1035
|
+
if (flat:=all(isinstance(p, (int,UOp)) for p in padding)) and len(padding)%2 != 0: raise ValueError("Flat padding must have even number of pads")
|
1036
|
+
# turn flat padding into group padding
|
1037
|
+
pX = ((0,0),)*(self.ndim - len(padding)//2) + tuple(zip(padding[-2::-2], padding[::-2])) if flat else padding
|
1038
|
+
if len(pX) != self.ndim: raise ValueError(f"padding length is improper, {padding=} {self.ndim=}")
|
1039
|
+
X, pX = self, cast(Tuple[Tuple[sint, sint]], tuple((0,0) if p is None else p for p in pX))
|
1040
|
+
def _constant(x,px,v): return F.Pad.apply(x, arg=px) if v == 0 else F.Pad.apply(x, arg=px) + F.Pad.apply(Tensor.ones_like(x), arg=px).where(0, v)
|
1041
|
+
# early return for symbolic with positive pads (no need to max)
|
1042
|
+
if mode == "constant" and all(resolve(p >= 0) for p in flatten(pX)): return _constant(X, pX, value)
|
1043
|
+
pads, shrinks = tuple((smax(pB,0), smax(pA,0)) for pB,pA in pX), lambda shape: tuple((-smin(pB,0),smin(pA+s,s)) for (pB,pA),s in zip(pX, shape))
|
1044
|
+
if mode == "constant": return _constant(X.shrink(shrinks(X.shape)), pads, value)
|
1045
|
+
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
|
1046
|
+
for d,(pB,pA) in enumerate(pads):
|
1047
|
+
if pB >= (s:=X.shape[d]) or pA>=s: raise ValueError(f"Padding ({pB}, {pA}) should be less than the input size={s} for dim={d}.")
|
1048
|
+
slcB, slcA, = slice(pB,0,-1), slice(s-2 if s-2>=0 else None, s-2-pA if s-2-pA>=0 else None, -1)
|
1049
|
+
xB, xA = (X[[slc if i == d else slice(None) for i in range(X.ndim)]] if p > 0 else None for slc, p in ((slcB, pB), (slcA, pA)))
|
1050
|
+
X = Tensor.cat(*(X_ for X_ in (xB, X, xA) if X_ is not None), dim=d)
|
1051
|
+
return X.shrink(shrinks(X.shape))
|
869
1052
|
|
870
1053
|
# ***** movement high level ops *****
|
871
1054
|
|
@@ -897,7 +1080,7 @@ class Tensor:
|
|
897
1080
|
# 2. Bool indexing is not supported
|
898
1081
|
# 3. Out of bounds Tensor indexing results in 0
|
899
1082
|
# - e.g: Tensor([1, 2, 3])[Tensor([4, 3, 2])] -> [0, 0, 3] index 4 and 3 are out of bounds
|
900
|
-
def
|
1083
|
+
def _getitem(self, indices, v: Optional[Tensor] = None) -> Tensor:
|
901
1084
|
# 1. indices normalization and validation
|
902
1085
|
# treat internal tuples and lists as Tensors and standardize indices to list type
|
903
1086
|
if isinstance(indices, list) and all_int(indices): indices = [Tensor(indices, self.device, requires_grad=False)]
|
@@ -921,7 +1104,6 @@ class Tensor:
|
|
921
1104
|
|
922
1105
|
# record None for dimension injection later and filter None and record rest of indices
|
923
1106
|
type_dim[None] = [dim for dim, i in enumerate(indices) if i is None]
|
924
|
-
tensor_dims = [dim for dim, i in enumerate(indices) if isinstance(i, Tensor)]
|
925
1107
|
indices_filtered = [i for i in indices if i is not None]
|
926
1108
|
for dim,i in enumerate(indices_filtered): type_dim[type(i)].append(dim)
|
927
1109
|
|
@@ -939,13 +1121,15 @@ class Tensor:
|
|
939
1121
|
indices_filtered[dim] = ((index, index+1), 1) if index >= 0 else ((size+index, size+index+1), 1)
|
940
1122
|
for dim in type_dim[slice]:
|
941
1123
|
if (index := indices_filtered[dim]).step == 0: raise ValueError(f"{index=} on {dim=} cannot have 0 as step")
|
1124
|
+
if not all(isinstance(x, (int, type(None))) for x in (index.start, index.stop, index.step)):
|
1125
|
+
raise TypeError(f"Unsupported slice for dimension {dim}. Expected slice with integers or None, got slice("
|
1126
|
+
f"{', '.join(type(x).__name__ for x in (index.start, index.stop, index.step))}).")
|
942
1127
|
s, e, st = index.indices(self.shape[dim])
|
943
1128
|
indices_filtered[dim] = ((0, 0) if (st * (e - s)) < 0 else (s, e) if st > 0 else (e+1, s+1), st)
|
944
|
-
#
|
945
|
-
tensor_index: List[Tensor] = []
|
1129
|
+
# skip all Tensor dims for basic indexing
|
946
1130
|
for dim in type_dim[Tensor]:
|
947
|
-
|
948
|
-
if not dtypes.is_int(
|
1131
|
+
dtype = indices_filtered[dim].dtype
|
1132
|
+
if not dtypes.is_int(dtype): raise IndexError(f"{dtype=} on {dim=} is not supported, only int tensor indexing is supported")
|
949
1133
|
indices_filtered[dim] = ((0, self.shape[dim]), 1)
|
950
1134
|
|
951
1135
|
new_slice, strides = ((), ()) if not indices_filtered else zip(*indices_filtered)
|
@@ -955,6 +1139,7 @@ class Tensor:
|
|
955
1139
|
if any(abs(st) != 1 for st in strides):
|
956
1140
|
strides = tuple(abs(s) for s in strides)
|
957
1141
|
# pad shape to multiple of stride
|
1142
|
+
if not all_int(ret.shape): raise RuntimeError("symbolic shape not supprted")
|
958
1143
|
ret = ret.pad(tuple((0, round_up(s, st) - s) for s, st in zip(ret.shape, strides)))
|
959
1144
|
ret = ret.reshape(tuple(flatten((s // st, st) for s, st in zip(ret.shape, strides))))
|
960
1145
|
ret = ret.shrink(tuple(flatten(((0, s), (0, 1)) for s in ret.shape[::2]))).reshape(ret.shape[::2])
|
@@ -968,6 +1153,7 @@ class Tensor:
|
|
968
1153
|
|
969
1154
|
# 3. advanced indexing (copy)
|
970
1155
|
if type_dim[Tensor]:
|
1156
|
+
dim_tensors = [(dim, i) for dim, i in enumerate(indices) if isinstance(i, Tensor)]
|
971
1157
|
# calculate dim of current ret by subtracting dims collapsed and adding dims injected up until tensor_dim
|
972
1158
|
def calc_dim(tensor_dim:int) -> int:
|
973
1159
|
return tensor_dim - sum(1 for d in dims_collapsed if tensor_dim >= d)
|
@@ -975,7 +1161,7 @@ class Tensor:
|
|
975
1161
|
assert all_int(ret.shape), f"does not support symbolic shape {ret.shape}"
|
976
1162
|
# track tensor_dim and tensor_index using a dict
|
977
1163
|
# calc_dim to get dim and use that to normalize the negative tensor indices
|
978
|
-
idx: Dict[int,Tensor] = {(dim := calc_dim(td)):(tensor<0).where(ret.shape[dim],0) + tensor for td,tensor in
|
1164
|
+
idx: Dict[int,Tensor] = {(dim := calc_dim(td)):(tensor<0).where(ret.shape[dim],0) + tensor for td,tensor in dim_tensors}
|
979
1165
|
|
980
1166
|
masks, first_dim, last_dim = [], min(idx.keys()), max(idx.keys())
|
981
1167
|
pre_reduce_shape = ret.shape[:first_dim] + (big_shape := _broadcast_shape(*(t.shape for t in idx.values()))) + ret.shape[first_dim:]
|
@@ -993,35 +1179,47 @@ class Tensor:
|
|
993
1179
|
# inject 1's for the extra dims added in create masks
|
994
1180
|
reshape_arg = ret.shape[:first_dim] + (1,) * len(big_shape) + ret.shape[first_dim:]
|
995
1181
|
# sum reduce the extra dims introduced in create masks
|
996
|
-
ret = (ret.reshape(reshape_arg) * mask).sum(tuple(i + len(big_shape) for i in idx.keys()), acc_dtype=ret.dtype)
|
1182
|
+
ret = (ret.reshape(reshape_arg) * mask).sum(sum_axis:=tuple(i + len(big_shape) for i in idx.keys()), acc_dtype=ret.dtype)
|
997
1183
|
|
998
1184
|
# special permute case
|
999
1185
|
if first_dim != 0 and len(idx) != 1 and tuple(idx.keys()) != tuple(range(first_dim, last_dim+1)):
|
1000
1186
|
ret = ret.permute(*range(first_dim, first_dim+len(big_shape)), *range(0, first_dim), *range(first_dim+len(big_shape), ret.ndim))
|
1187
|
+
|
1188
|
+
# for advanced setitem, returns whole tensor with indices replaced
|
1189
|
+
if v is not None:
|
1190
|
+
vb = v.cast(self.dtype)._broadcast_to(_broadcast_shape(ret.shape, v.shape))
|
1191
|
+
# add back reduced dims from sum
|
1192
|
+
for dim in sum_axis: vb = vb.unsqueeze(dim)
|
1193
|
+
# axis to be reduced to match self.shape
|
1194
|
+
axis = tuple(range(first_dim, first_dim + len(big_shape)))
|
1195
|
+
# apply mask to v(broadcasted) and reduce such that if v contains repeated indices the last one remains
|
1196
|
+
vb = vb * mask
|
1197
|
+
for dim in axis: vb = functools.reduce(lambda x,y: y.where(y, x), vb.split(1, dim))
|
1198
|
+
# reduce mask and select from v(get rid of extra dims from reduce) for each True element in mask else select from self
|
1199
|
+
ret = mask.any(axis).where(vb.squeeze(), self)
|
1200
|
+
|
1001
1201
|
return ret
|
1002
1202
|
|
1203
|
+
def __getitem__(self, indices) -> Tensor:
|
1204
|
+
return self._getitem(indices)
|
1205
|
+
|
1003
1206
|
def __setitem__(self, indices, v:Union[Tensor, ConstType]) -> None:
|
1004
1207
|
if isinstance(self.device, str) and self.device.startswith("DISK"):
|
1005
|
-
self.
|
1208
|
+
self._getitem(indices).assign(v)
|
1006
1209
|
return
|
1007
1210
|
# NOTE: check that setitem target is valid first
|
1008
|
-
|
1211
|
+
if not all(lb.st.contiguous for lb in self.lazydata.lbs): raise RuntimeError("setitem target needs to be contiguous")
|
1009
1212
|
if not isinstance(v, (Tensor, float, int, bool)): raise TypeError(f"can't set a {type(v).__name__} to a Tensor")
|
1010
1213
|
if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype)
|
1011
1214
|
if self.requires_grad or v.requires_grad: raise NotImplementedError("setitem with requires_grad is not supported")
|
1012
|
-
if isinstance(indices, (Tensor, list)) or (isinstance(indices, tuple) and any(isinstance(i, (Tensor, list)) for i in indices)):
|
1013
|
-
raise NotImplementedError("Advanced indexing setitem is not currently supported")
|
1014
|
-
|
1015
|
-
assign_to = self.realize().__getitem__(indices)
|
1016
|
-
# NOTE: contiguous to prevent const folding.
|
1017
|
-
v = v.cast(assign_to.dtype)._broadcast_to(_broadcast_shape(assign_to.shape, v.shape)).contiguous()
|
1018
|
-
assign_to.assign(v).realize()
|
1019
1215
|
|
1020
|
-
|
1021
|
-
|
1022
|
-
|
1023
|
-
|
1024
|
-
|
1216
|
+
res = self.realize()._getitem(indices, v)
|
1217
|
+
# if shapes match and data is not shared it's a copy and we assign to self
|
1218
|
+
if res.shape == self.shape and res.lazydata is not self.lazydata:
|
1219
|
+
self.assign(res).realize()
|
1220
|
+
else: # no copy, basic setitem
|
1221
|
+
v = v.cast(res.dtype)._broadcast_to(_broadcast_shape(res.shape, v.shape)).contiguous()
|
1222
|
+
res.assign(v).realize()
|
1025
1223
|
|
1026
1224
|
def gather(self:Tensor, dim:int, index:Tensor) -> Tensor:
|
1027
1225
|
"""
|
@@ -1036,8 +1234,8 @@ class Tensor:
|
|
1036
1234
|
```
|
1037
1235
|
"""
|
1038
1236
|
assert index.ndim == self.ndim, f"self.ndim must equal index.ndim, {self.ndim=}, {index.ndim=}"
|
1039
|
-
assert all(s >= i for d,(s,i) in enumerate(zip(self.shape, index.shape)) if d != dim), "requires self.shape[d] >= index.shape[d] for all d != dim"
|
1040
1237
|
dim = self._resolve_dim(dim)
|
1238
|
+
assert all(s >= i for d,(s,i) in enumerate(zip(self.shape, index.shape)) if d != dim), "requires self.shape[d] >= index.shape[d] for all d != dim"
|
1041
1239
|
index = index.to(self.device)
|
1042
1240
|
x = self.shrink(tuple((0, i) if d != dim else None for d,i in enumerate(index.shape))).unsqueeze(-1).transpose(-1, dim)
|
1043
1241
|
return ((index.unsqueeze(-1) == Tensor.arange(self.shape[dim], requires_grad=False, device=self.device)) * x).sum(-1, acc_dtype=self.dtype)
|
@@ -1056,13 +1254,11 @@ class Tensor:
|
|
1056
1254
|
```
|
1057
1255
|
"""
|
1058
1256
|
dim = self._resolve_dim(dim)
|
1059
|
-
assert
|
1060
|
-
|
1061
|
-
|
1062
|
-
|
1063
|
-
|
1064
|
-
for d,k,s in zip(cat_dims, cat_dim_cumsum[:-1], slc): s[dim] = (k, cat_dim_cumsum[-1] - k - d)
|
1065
|
-
return functools.reduce(Tensor.__add__, [arg.pad(tuple(s)) for arg,s in zip(catargs, slc)])
|
1257
|
+
for arg in args: assert arg.ndim==self.ndim and all(ti==ai for i,(ti,ai) in enumerate(zip(self.shape, arg.shape)) if i!=dim)
|
1258
|
+
tensors = [self, *args]
|
1259
|
+
dim_cumsum = list(itertools.accumulate([t.shape[dim] for t in tensors], initial=0))
|
1260
|
+
for i,t in enumerate(tensors): tensors[i] = t.pad([(dim_cumsum[i], dim_cumsum[-1]-dim_cumsum[i+1]) if j==dim else None for j in range(t.ndim)])
|
1261
|
+
return functools.reduce(Tensor.add, tensors)
|
1066
1262
|
|
1067
1263
|
def stack(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
|
1068
1264
|
"""
|
@@ -1077,7 +1273,20 @@ class Tensor:
|
|
1077
1273
|
```
|
1078
1274
|
"""
|
1079
1275
|
# checks for shapes and number of dimensions delegated to cat
|
1080
|
-
return
|
1276
|
+
return Tensor.cat(*[t.unsqueeze(dim) for t in [self, *args]], dim=dim)
|
1277
|
+
|
1278
|
+
def repeat_interleave(self, repeats:int, dim:Optional[int]=None) -> Tensor:
|
1279
|
+
"""
|
1280
|
+
Repeat elements of a tensor.
|
1281
|
+
|
1282
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1283
|
+
t = Tensor([1, 2, 3])
|
1284
|
+
print(t.repeat_interleave(2).numpy())
|
1285
|
+
```
|
1286
|
+
"""
|
1287
|
+
x, dim = (self.flatten(), 0) if dim is None else (self, self._resolve_dim(dim))
|
1288
|
+
shp = x.shape
|
1289
|
+
return x.reshape(*shp[:dim+1], 1, *shp[dim+1:]).expand(*shp[:dim+1], repeats, *shp[dim+1:]).reshape(*shp[:dim], shp[dim]*repeats, *shp[dim+1:])
|
1081
1290
|
|
1082
1291
|
def repeat(self, repeats, *args) -> Tensor:
|
1083
1292
|
"""
|
@@ -1093,16 +1302,16 @@ class Tensor:
|
|
1093
1302
|
```
|
1094
1303
|
"""
|
1095
1304
|
repeats = argfix(repeats, *args)
|
1096
|
-
base_shape = (
|
1097
|
-
|
1098
|
-
|
1305
|
+
base_shape = _pad_left(self.shape, repeats)[0]
|
1306
|
+
unsqueezed_shape = flatten([[1, s] for s in base_shape])
|
1307
|
+
expanded_shape = flatten([[r, s] for r,s in zip(repeats, base_shape)])
|
1099
1308
|
final_shape = [r*s for r,s in zip(repeats, base_shape)]
|
1100
|
-
return self.reshape(
|
1309
|
+
return self.reshape(unsqueezed_shape).expand(expanded_shape).reshape(final_shape)
|
1101
1310
|
|
1102
|
-
def _resolve_dim(self, dim:int, *,
|
1103
|
-
|
1104
|
-
|
1105
|
-
return dim +
|
1311
|
+
def _resolve_dim(self, dim:int, *, extra:bool=False) -> int:
|
1312
|
+
total = self.ndim + int(extra)
|
1313
|
+
if not -max(1, total) <= dim <= max(1, total)-1: raise IndexError(f"{dim=} out of range {[-max(1, total), max(1, total)-1]}")
|
1314
|
+
return dim + total if dim < 0 else dim
|
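`_resolve_dim` is internal, but its behaviour shows through any axis argument: negative dims count from the end, and `extra=True` (used by `unsqueeze`) allows one position past the last axis. A rough illustration under the same import assumption:

```python
from tinygrad import Tensor

t = Tensor.ones(2, 3)
print(t.sum(axis=-1).shape)   # axis -1 resolves to axis 1 -> (2,)
print(t.unsqueeze(-1).shape)  # extra=True lets -1 mean "append a new axis" -> (2, 3, 1)
```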
1106
1315
|
|
1107
1316
|
def split(self, sizes:Union[int, List[int]], dim:int=0) -> Tuple[Tensor, ...]:
|
1108
1317
|
"""
|
@@ -1151,7 +1360,34 @@ class Tensor:
|
|
1151
1360
|
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
|
1152
1361
|
assert chunks > 0, f"expect chunks to be greater than 0, got: {chunks}"
|
1153
1362
|
dim = self._resolve_dim(dim)
|
1154
|
-
return list(self.split(
|
1363
|
+
return list(self.split(ceildiv(self.shape[dim], chunks) if self.shape[dim] else [0]*chunks, dim=dim))
|
1364
|
+
|
1365
|
+
def meshgrid(self:Tensor, *args:Tensor, indexing:Union[Literal["ij"], Literal["xy"]]="ij") -> Tuple[Tensor, ...]:
|
1366
|
+
"""
|
1367
|
+
Generates coordinate matrices from coordinate vectors.
|
1368
|
+
Input tensors can be scalars or 1D tensors.
|
1369
|
+
|
1370
|
+
`indexing` determines how the output grids are aligned.
|
1371
|
+
`ij` indexing follows matrix-style indexing and `xy` indexing follows Cartesian-style indexing.
|
1372
|
+
|
1373
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1374
|
+
x, y = Tensor([1, 2, 3]), Tensor([4, 5, 6])
|
1375
|
+
grid_x, grid_y = x.meshgrid(y)
|
1376
|
+
print(grid_x.numpy())
|
1377
|
+
print(grid_y.numpy())
|
1378
|
+
```
|
1379
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1380
|
+
grid_x, grid_y = x.meshgrid(y, indexing="xy")
|
1381
|
+
print(grid_x.numpy())
|
1382
|
+
print(grid_y.numpy())
|
1383
|
+
```
|
1384
|
+
"""
|
1385
|
+
if indexing not in ("ij", "xy"): raise RuntimeError(f'indexing must be in ("ij", "xy"), got {indexing}')
|
1386
|
+
if len(tensors:=(self, *args)) == 1: return tensors
|
1387
|
+
basis = tuple(range(len(tensors))) if indexing == "ij" else (1, 0) + tuple(range(2, len(tensors)))
|
1388
|
+
tensors = tuple(t.reshape((-1,) + (1,)*(len(args) - i)) for i,t in zip(basis, tensors))
|
1389
|
+
output_shape = _broadcast_shape(*(t.shape for t in tensors))
|
1390
|
+
return tuple(t._broadcast_to(output_shape) for t in tensors)
|
1155
1391
|
|
1156
1392
|
def squeeze(self, dim:Optional[int]=None) -> Tensor:
|
1157
1393
|
"""
|
@@ -1185,25 +1421,9 @@ class Tensor:
|
|
1185
1421
|
print(t.unsqueeze(1).numpy())
|
1186
1422
|
```
|
1187
1423
|
"""
|
1188
|
-
dim = self._resolve_dim(dim,
|
1424
|
+
dim = self._resolve_dim(dim, extra=True)
|
1189
1425
|
return self.reshape(self.shape[:dim] + (1,) + self.shape[dim:])
|
1190
1426
|
|
1191
|
-
def pad2d(self, padding:Sequence[int], value:float=0.0) -> Tensor:
|
1192
|
-
"""
|
1193
|
-
Returns a tensor that pads the last two axes specified by `padding` (padding_left, padding_right, padding_top, padding_bottom).
|
1194
|
-
If `value` is specified, the tensor is padded with `value` instead of `0.0`.
|
1195
|
-
|
1196
|
-
```python exec="true" source="above" session="tensor" result="python"
|
1197
|
-
t = Tensor.arange(9).reshape(1, 1, 3, 3)
|
1198
|
-
print(t.numpy())
|
1199
|
-
```
|
1200
|
-
```python exec="true" source="above" session="tensor" result="python"
|
1201
|
-
print(t.pad2d((1, 1, 2, 0), value=-float("inf")).numpy())
|
1202
|
-
```
|
1203
|
-
"""
|
1204
|
-
slc = [(-p0, s+p1) for p0,p1,s in zip(padding[::2], padding[1::2], self.shape[::-1])][::-1]
|
1205
|
-
return self._slice([(0,s) for s in self.shape[:-(len(padding)//2)]] + slc, value=value)
|
1206
|
-
|
1207
1427
|
@property
|
1208
1428
|
def T(self) -> Tensor:
|
1209
1429
|
"""`.T` is an alias for `.transpose()`."""
|
@@ -1259,20 +1479,37 @@ class Tensor:
|
|
1259
1479
|
dim = self._resolve_dim(dim)
|
1260
1480
|
return self.reshape(self.shape[:dim] + sizes + self.shape[dim+1:])
|
1261
1481
|
|
1482
|
+
def roll(self, shifts:Union[int, Tuple[int, ...]], dims:Union[int, Tuple[int, ...]]) -> Tensor:
|
1483
|
+
"""
|
1484
|
+
Rolls the tensor along specified dimension(s).
|
1485
|
+
The rolling operation is circular, meaning that elements that go beyond the edge are wrapped around to the beginning of the dimension.
|
1486
|
+
|
1487
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1488
|
+
t = Tensor.arange(4)
|
1489
|
+
print(t.roll(shifts=1, dims=0).numpy())
|
1490
|
+
```
|
1491
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1492
|
+
print(t.roll(shifts=-1, dims=0).numpy())
|
1493
|
+
```
|
1494
|
+
"""
|
1495
|
+
dims, rolled = tuple(self._resolve_dim(d) for d in make_tuple(dims, 1)), self
|
1496
|
+
for dim, shift in zip(dims, make_tuple(shifts, 1)):
|
1497
|
+
shift = shift % self.shape[dim]
|
1498
|
+
rolled = Tensor.cat(rolled[tuple(slice(None) if i != dim else slice(-shift, None) for i in range(rolled.ndim))],
|
1499
|
+
rolled[tuple(slice(None) if i != dim else slice(None, -shift) for i in range(rolled.ndim))], dim=dim)
|
1500
|
+
return rolled
|
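A sketch of rolling along more than one dimension at once (illustrative only, same assumed import):

```python
from tinygrad import Tensor

t = Tensor.arange(6).reshape(2, 3)
# roll rows by 1 and columns by 1; elements wrap around circularly
print(t.roll(shifts=(1, 1), dims=(0, 1)).numpy())  # expected: [[5 3 4] [2 0 1]]
```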
1501
|
+
|
1262
1502
|
# ***** reduce ops *****
|
1263
1503
|
|
1264
1504
|
def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False) -> Tensor:
|
1265
|
-
|
1266
|
-
|
1267
|
-
|
1268
|
-
|
1269
|
-
axis_ = tuple(self._resolve_dim(x) for x in axis_)
|
1270
|
-
ret = fxn.apply(self, axis=axis_)
|
1271
|
-
return ret if keepdim else ret.reshape(tuple(s for i,s in enumerate(self.shape) if i not in axis_))
|
1505
|
+
axis = tuple(self._resolve_dim(x) for x in (range(self.ndim) if axis is None else make_tuple(axis, 1)))
|
1506
|
+
if self.ndim == 0: axis = ()
|
1507
|
+
ret = fxn.apply(self, axis=axis)
|
1508
|
+
return ret if keepdim else ret.reshape(tuple(s for i,s in enumerate(self.shape) if i not in axis))
|
1272
1509
|
|
1273
|
-
def sum(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, acc_dtype:Optional[
|
1510
|
+
def sum(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, acc_dtype:Optional[DTypeLike]=None):
|
1274
1511
|
"""
|
1275
|
-
|
1512
|
+
Returns the sum of the elements of the tensor along the specified axis or axes.
|
1276
1513
|
|
1277
1514
|
You can pass in `axis` and `keepdim` keyword arguments to control the axis along
|
1278
1515
|
which the sum is computed and whether the reduced dimensions are retained.
|
@@ -1294,9 +1531,35 @@ class Tensor:
|
|
1294
1531
|
print(t.sum(axis=1).numpy())
|
1295
1532
|
```
|
1296
1533
|
"""
|
1297
|
-
ret = self.cast(
|
1534
|
+
ret = self.cast(sum_acc_dtype(self.dtype) if acc_dtype is None else acc_dtype)._reduce(F.Sum, axis, keepdim)
|
1298
1535
|
return ret.cast(self.dtype) if acc_dtype is None and self.dtype in (dtypes.float16, dtypes.bfloat16) else ret
|
1299
1536
|
|
1537
|
+
def prod(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, acc_dtype:Optional[DTypeLike]=None):
|
1538
|
+
"""
|
1539
|
+
Returns the product of the elements of the tensor along the specified axis or axes.
|
1540
|
+
|
1541
|
+
You can pass in `axis` and `keepdim` keyword arguments to control the axis along
|
1542
|
+
which the product is computed and whether the reduced dimensions are retained.
|
1543
|
+
|
1544
|
+
You can pass in `acc_dtype` keyword argument to control the data type of the accumulation.
|
1545
|
+
If not specified, the accumulation data type is chosen based on the input tensor's data type.
|
1546
|
+
|
1547
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1548
|
+
t = Tensor([-1, -2, -3, 1, 2, 3]).reshape(2, 3)
|
1549
|
+
print(t.numpy())
|
1550
|
+
```
|
1551
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1552
|
+
print(t.prod().numpy())
|
1553
|
+
```
|
1554
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1555
|
+
print(t.prod(axis=0).numpy())
|
1556
|
+
```
|
1557
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1558
|
+
print(t.prod(axis=1).numpy())
|
1559
|
+
```
|
1560
|
+
"""
|
1561
|
+
return self.cast(acc_dtype if acc_dtype is not None else self.dtype)._reduce(F.Prod, axis, keepdim)
|
1562
|
+
|
1300
1563
|
def max(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
|
1301
1564
|
"""
|
1302
1565
|
Returns the maximum value of the tensor along the specified axis or axes.
|
@@ -1341,8 +1604,53 @@ class Tensor:
|
|
1341
1604
|
print(t.min(axis=1, keepdim=True).numpy())
|
1342
1605
|
```
|
1343
1606
|
"""
|
1607
|
+
if dtypes.is_int(self.dtype) or self.dtype == dtypes.bool: return ~((~self).max(axis=axis, keepdim=keepdim))
|
1344
1608
|
return -((-self).max(axis=axis, keepdim=keepdim))
|
1345
1609
|
|
1610
|
+
def any(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
|
1611
|
+
"""
|
1612
|
+
Tests if any element evaluates to `True` along the specified axis or axes.
|
1613
|
+
|
1614
|
+
You can pass in `axis` and `keepdim` keyword arguments to control the reduce axis and whether the reduced dimensions are retained.
|
1615
|
+
|
1616
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1617
|
+
t = Tensor([[True, True], [True, False], [False, False]])
|
1618
|
+
print(t.numpy())
|
1619
|
+
```
|
1620
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1621
|
+
print(t.any().numpy())
|
1622
|
+
```
|
1623
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1624
|
+
print(t.any(axis=0).numpy())
|
1625
|
+
```
|
1626
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1627
|
+
print(t.any(axis=1, keepdim=True).numpy())
|
1628
|
+
```
|
1629
|
+
"""
|
1630
|
+
return self.bool().max(axis, keepdim)
|
1631
|
+
|
1632
|
+
def all(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
|
1633
|
+
"""
|
1634
|
+
Tests if all elements evaluate to `True` along the specified axis or axes.
|
1635
|
+
|
1636
|
+
You can pass in `axis` and `keepdim` keyword arguments to control the reduce axis and whether the reduced dimensions are retained.
|
1637
|
+
|
1638
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1639
|
+
t = Tensor([[True, True], [True, False], [False, False]])
|
1640
|
+
print(t.numpy())
|
1641
|
+
```
|
1642
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1643
|
+
print(t.all().numpy())
|
1644
|
+
```
|
1645
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1646
|
+
print(t.all(axis=0).numpy())
|
1647
|
+
```
|
1648
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1649
|
+
print(t.all(axis=1, keepdim=True).numpy())
|
1650
|
+
```
|
1651
|
+
"""
|
1652
|
+
return self.logical_not().any(axis, keepdim).logical_not()
|
1653
|
+
|
1346
1654
|
def mean(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
|
1347
1655
|
"""
|
1348
1656
|
Returns the mean value of the tensor along the specified axis or axes.
|
@@ -1367,7 +1675,7 @@ class Tensor:
|
|
1367
1675
|
"""
|
1368
1676
|
output_dtype = self.dtype if dtypes.is_float(self.dtype) else dtypes.float32
|
1369
1677
|
numerator = self.cast(sum_acc_dtype(self.dtype)).sum(axis=axis, keepdim=keepdim)
|
1370
|
-
return numerator.div(prod([si for si, so in zip(self.shape, self.sum(axis=axis, keepdim=True).shape) if si != so])).cast(output_dtype)
|
1678
|
+
return numerator.div(prod([si for si, so in zip(self.shape, self.sum(axis=axis, keepdim=True).shape) if resolve(si != so)])).cast(output_dtype)
|
1371
1679
|
|
1372
1680
|
def var(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, correction=1):
|
1373
1681
|
"""
|
@@ -1392,8 +1700,8 @@ class Tensor:
|
|
1392
1700
|
```
|
1393
1701
|
"""
|
1394
1702
|
squares = (self - self.mean(axis=axis, keepdim=True)).square()
|
1395
|
-
n = prod([si for si, so in zip(self.shape, squares.sum(axis=axis, keepdim=True).shape) if si != so])
|
1396
|
-
return squares.sum(axis=axis, keepdim=keepdim).div(
|
1703
|
+
n = prod([si for si, so in zip(self.shape, squares.sum(axis=axis, keepdim=True).shape) if resolve(si != so)])
|
1704
|
+
return squares.sum(axis=axis, keepdim=keepdim).div(smax([0, n-correction]))
|
1397
1705
|
|
1398
1706
|
def std(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, correction=1):
|
1399
1707
|
"""
|
@@ -1419,12 +1727,30 @@ class Tensor:
|
|
1419
1727
|
"""
|
1420
1728
|
return self.var(axis, keepdim, correction).sqrt()
|
1421
1729
|
|
1422
|
-
def
|
1423
|
-
|
1730
|
+
def std_mean(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, correction=1):
|
1731
|
+
"""
|
1732
|
+
Calculates the standard deviation and mean over the dimensions specified by `axis`.
|
1733
|
+
Syntactic sugar around `Tensor.std` and `Tensor.mean` to match `torch.std_mean`.
|
1734
|
+
|
1735
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1736
|
+
Tensor.manual_seed(42)
|
1737
|
+
t = Tensor.normal(2, 3, mean=2.5, std=0.5)
|
1738
|
+
print(t.numpy())
|
1739
|
+
```
|
1740
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1741
|
+
std, mean = t.std_mean()
|
1742
|
+
print(std.numpy(), mean.numpy())
|
1743
|
+
```
|
1744
|
+
"""
|
1745
|
+
return self.std(axis, keepdim, correction), self.mean(axis, keepdim)
|
1746
|
+
|
1747
|
+
def _softmax(self, axis, dtype:Optional[DTypeLike]=None):
|
1748
|
+
x = self.cast(dtype) if dtype is not None else self
|
1749
|
+
m = x - x.max(axis=axis, keepdim=True).detach()
|
1424
1750
|
e = m.exp()
|
1425
1751
|
return m, e, e.sum(axis=axis, keepdim=True)
|
1426
1752
|
|
1427
|
-
def softmax(self, axis=-1):
|
1753
|
+
def softmax(self, axis=-1, dtype:Optional[DTypeLike]=None):
|
1428
1754
|
"""
|
1429
1755
|
Applies the softmax function to the tensor along the specified axis.
|
1430
1756
|
|
@@ -1444,10 +1770,10 @@ class Tensor:
|
|
1444
1770
|
print(t.softmax(axis=0).numpy())
|
1445
1771
|
```
|
1446
1772
|
"""
|
1447
|
-
_, e, ss = self._softmax(axis)
|
1773
|
+
_, e, ss = self._softmax(axis, dtype)
|
1448
1774
|
return e.div(ss)
|
1449
1775
|
|
1450
|
-
def log_softmax(self, axis=-1):
|
1776
|
+
def log_softmax(self, axis=-1, dtype:Optional[DTypeLike]=None):
|
1451
1777
|
"""
|
1452
1778
|
Applies the log-softmax function to the tensor along the specified axis.
|
1453
1779
|
|
@@ -1467,7 +1793,7 @@ class Tensor:
|
|
1467
1793
|
print(t.log_softmax(axis=0).numpy())
|
1468
1794
|
```
|
1469
1795
|
"""
|
1470
|
-
m, _, ss = self._softmax(axis)
|
1796
|
+
m, _, ss = self._softmax(axis, dtype)
|
1471
1797
|
return m - ss.log()
|
1472
1798
|
|
1473
1799
|
def logsumexp(self, axis=None, keepdim=False):
|
@@ -1497,6 +1823,33 @@ class Tensor:
|
|
1497
1823
|
m = self.max(axis=axis, keepdim=True)
|
1498
1824
|
return (self - m).exp().sum(axis=axis, keepdim=keepdim).log() + m.squeeze(axis)
|
1499
1825
|
|
1826
|
+
def logcumsumexp(self, axis=0):
|
1827
|
+
"""
|
1828
|
+
Computes the log-cumsum-exp of the tensor along the specified axis or axes.
|
1829
|
+
|
1830
|
+
The log-cumsum-exp function is a numerically stable way to compute the logarithm of the cumulative sum of exponentials.
|
1831
|
+
|
1832
|
+
You can pass in the `axis` keyword argument to control the axis along which
|
1833
|
+
the log-cumsum-exp is computed.
|
1834
|
+
|
1835
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1836
|
+
Tensor.manual_seed(42)
|
1837
|
+
t = Tensor.randn(2, 3)
|
1838
|
+
print(t.numpy())
|
1839
|
+
```
|
1840
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1841
|
+
print(t.logcumsumexp().numpy())
|
1842
|
+
```
|
1843
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1844
|
+
print(t.logcumsumexp(axis=0).numpy())
|
1845
|
+
```
|
1846
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1847
|
+
print(t.logcumsumexp(axis=1).numpy())
|
1848
|
+
```
|
1849
|
+
"""
|
1850
|
+
m = self.max(axis=axis, keepdim=True)
|
1851
|
+
return (self - m).exp().cumsum(axis=axis).log() + m
|
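One way to sanity-check the new `logcumsumexp`: the last entry along the scanned axis should agree with `logsumexp` over that axis. A small sketch, not part of the diff:

```python
from tinygrad import Tensor

t = Tensor([[1.0, 2.0, 3.0], [0.5, 0.5, 0.5]])
print(t.logcumsumexp(axis=1).numpy())  # running log-sum-exp along each row
print(t.logsumexp(axis=1).numpy())     # matches the last column above
```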
1852
|
+
|
1500
1853
|
def argmax(self, axis=None, keepdim=False):
|
1501
1854
|
"""
|
1502
1855
|
Returns the indices of the maximum value of the tensor along the specified axis.
|
@@ -1521,8 +1874,8 @@ class Tensor:
|
|
1521
1874
|
if axis is None: return self.flatten().argmax(0)
|
1522
1875
|
axis = self._resolve_dim(axis)
|
1523
1876
|
m = self == self.max(axis=axis, keepdim=True)
|
1524
|
-
idx = m * Tensor.arange(self.shape[axis]
|
1525
|
-
return (self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)
|
1877
|
+
idx = m * Tensor.arange(self.shape[axis],0,-1, requires_grad=False, device=self.device).reshape(self.shape[axis], *[1]*(self.ndim-axis-1))
|
1878
|
+
return (self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)).cast(dtypes.int32)
|
1526
1879
|
|
1527
1880
|
def argmin(self, axis=None, keepdim=False):
|
1528
1881
|
"""
|
@@ -1547,8 +1900,48 @@ class Tensor:
|
|
1547
1900
|
"""
|
1548
1901
|
return (-self).argmax(axis=axis, keepdim=keepdim)
|
1549
1902
|
|
1903
|
+
def rearrange(self, formula: str, **sizes) -> Tensor:
|
1904
|
+
"""
|
1905
|
+
Rearranges input according to formula
|
1906
|
+
|
1907
|
+
See: https://einops.rocks/api/rearrange/
|
1908
|
+
|
1909
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1910
|
+
x = Tensor([[1, 2], [3, 4]])
|
1911
|
+
print(Tensor.rearrange(x, "batch channel -> (batch channel)").numpy())
|
1912
|
+
```
|
1913
|
+
"""
|
1914
|
+
def parse_formula(formula: str):
|
1915
|
+
tokens = f" {formula} ".replace("…", "...").replace("(", " ( ").replace(")", " ) ").replace(" ", " ").replace(" 1 ", " ( ) ").split()
|
1916
|
+
lparens, rparens = map(lambda x: [i for i, ch in enumerate(tokens) if ch == x], ("(", ")"))
|
1917
|
+
pairs = list(zip(lparens, rparens))
|
1918
|
+
assert len(lparens) == len(rparens) and sorted(flatten(pairs)) == flatten(pairs), "bracket mismatch"
|
1919
|
+
return [name for name in tokens if name not in ("(", ")")], [(s - 2*i, e - 1 - 2*i) for i, (s, e) in enumerate(pairs)]
|
1920
|
+
|
1921
|
+
assert formula.count("->") == 1, 'need exactly one "->" in formula'
|
1922
|
+
|
1923
|
+
(lhs, unflatten_dims), (rhs, flatten_dims) = map(parse_formula, formula.split("->"))
|
1924
|
+
|
1925
|
+
for name in sizes: assert name in lhs, f"axis {name} is not used in transform"
|
1926
|
+
assert sorted(lhs) == sorted(rhs) and len(lhs) == len(set(lhs)), f"name mismatch in {formula}"
|
1927
|
+
for name in flatten((lhs, rhs)): assert name == "..." or (name.isidentifier() and "_" not in (name[0], name[-1])), f"invalid axis name {name}"
|
1928
|
+
assert "..." not in flatten([lhs[s:e] for s, e in unflatten_dims]), f"cannot have collapsed ellipsis (...) in lhs of {formula}"
|
1929
|
+
assert lhs.count("...") <= 1, f"too many ellipses in {formula}"
|
1930
|
+
|
1931
|
+
# resolve ellipsis
|
1932
|
+
if "..." in lhs: ell_len = len(self.shape) - len(lhs) + 1 + sum(e - s - 1 for s, e in unflatten_dims)
|
1933
|
+
lhs, rhs = map(lambda l: l[:(i:=l.index("..."))] + [f"...{j}" for j in range(ell_len)] + l[i + 1:] if "..." in l else l, (lhs, rhs))
|
1934
|
+
unflatten_dims = [(s + (ell_len - 1 if "...0" in lhs[:s] else 0), e + (ell_len - 1 if "...0" in lhs[:e] else 0)) for s, e in unflatten_dims]
|
1935
|
+
flatten_dims = [(s + (ell_len - 1 if "...0" in rhs[:s] else 0), e + (ell_len - 1 if "...0" in rhs[:e] else 0)) for s, e in flatten_dims]
|
1936
|
+
|
1937
|
+
# apply movement ops in order unflatten -> permute -> flatten/unsqueeze
|
1938
|
+
t = functools.reduce(lambda x, dims: x.unflatten(dims[0], tuple(sizes.get(lhs[d], -1) for d in range(*dims))), unflatten_dims, self)
|
1939
|
+
for i, name in enumerate(lhs): assert (name not in sizes) or sizes[name] == t.shape[i], f"size provided for dimension {name} incorrect"
|
1940
|
+
t = t.permute([lhs.index(name) for name in rhs])
|
1941
|
+
return functools.reduce(lambda x, dims: x.flatten(dims[0], dims[1] - 1) if dims[0]<dims[1] else x.unsqueeze(dims[0]), reversed(flatten_dims), t)
|
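Two further `rearrange` patterns that exercise the grouping syntax, with `w=2` as an illustrative size hint (sketch only, not from the diff):

```python
from tinygrad import Tensor

x = Tensor.arange(24).reshape(2, 3, 4)
print(x.rearrange("b c d -> b (c d)").shape)           # flatten the last two axes -> (2, 12)
print(x.rearrange("b c (h w) -> b c h w", w=2).shape)  # split the last axis using the size hint -> (2, 3, 2, 2)
```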
1942
|
+
|
1550
1943
|
@staticmethod
|
1551
|
-
def einsum(formula:str, *
|
1944
|
+
def einsum(formula:str, *operands:Tensor|Sequence[Tensor], acc_dtype:Optional[DTypeLike]=None) -> Tensor:
|
1552
1945
|
"""
|
1553
1946
|
Sums the product of the elements of the input tensors according to a formula based on the Einstein summation convention.
|
1554
1947
|
|
@@ -1560,11 +1953,20 @@ class Tensor:
|
|
1560
1953
|
print(Tensor.einsum("ij,ij->", x, y).numpy())
|
1561
1954
|
```
|
1562
1955
|
"""
|
1563
|
-
|
1564
|
-
|
1565
|
-
|
1566
|
-
|
1567
|
-
|
1956
|
+
def parse_formula(formula:str, *operands:Tensor):
|
1957
|
+
if "..." in (formula := formula.replace(" ", "")):
|
1958
|
+
ell_chars, ell_longest = "".join(set(string.ascii_letters) - set(formula)), 0
|
1959
|
+
for i, inp in enumerate(filter(lambda x: "..." in x, inputs := formula.split("->")[0].split(","))):
|
1960
|
+
if (ell_count := max(operands[i].ndim, 1) - (len(inp) - len("..."))) > ell_longest: ell_longest = ell_count
|
1961
|
+
inputs[i] = inp.replace("...", ell_chars[-ell_count:])
|
1962
|
+
inputs_str, out_ellipse = ",".join(inputs), ell_chars[-ell_longest:]
|
1963
|
+
return (inputs_str, formula.split("->")[1].replace("...", out_ellipse)) if "->" in formula else \
|
1964
|
+
(inputs_str, out_ellipse + ''.join(sorted(c for c in inputs_str if inputs_str.count(c) == 1 and c.isalpha() and c not in out_ellipse)))
|
1965
|
+
return formula.split("->") if "->" in formula else (formula, ''.join(c for c in sorted(formula) if formula.count(c) == 1 and c.isalpha()))
|
1966
|
+
|
1967
|
+
xs:Tuple[Tensor, ...] = argfix(*operands)
|
1968
|
+
inputs_str, output = parse_formula(formula, *xs)
|
1969
|
+
inputs = inputs_str.split(",")
|
1568
1970
|
assert len(xs) == len(inputs), f"number of inputs doesn't match number of operands in formula, expected {len(inputs)}, got {len(xs)}"
|
1569
1971
|
|
1570
1972
|
# map the value of each letter in the formula
|
@@ -1576,41 +1978,43 @@ class Tensor:
|
|
1576
1978
|
# permute to the sorted letter order, then reshape/expand to create dimensions for the missing letters
|
1577
1979
|
xs_.append(x.permute(order).reshape([val if letter in letters else 1 for letter,val in letter_val]).expand([val for _,val in letter_val]))
|
1578
1980
|
|
1579
|
-
#
|
1580
|
-
|
1581
|
-
rhs_order = argsort(rhs_letter_order)
|
1981
|
+
# ordinal encode the output alphabet
|
1982
|
+
rhs_order = argsort(argsort(list(output)))
|
1582
1983
|
|
1583
1984
|
# sum over all axes that's not in the output, then permute to the output order
|
1584
1985
|
return functools.reduce(lambda a,b:a*b, xs_) \
|
1585
|
-
.sum(axis=[axis for axis,(letter,_) in enumerate(letter_val) if letter not in output],acc_dtype=acc_dtype).permute(rhs_order)
|
1986
|
+
.sum(axis=[axis for axis,(letter,_) in enumerate(letter_val) if letter not in output], acc_dtype=acc_dtype).permute(rhs_order)
|
1586
1987
|
|
1587
1988
|
# ***** processing ops *****
|
1588
1989
|
|
1589
1990
|
def _pool(self, k_:Tuple[sint, ...], stride:Union[Tuple[int, ...], int]=1, dilation:Union[Tuple[int, ...], int]=1) -> Tensor:
|
1590
1991
|
assert len(self.shape) >= len(k_), f"can't pool {self.shape} with {k_}"
|
1591
|
-
|
1592
|
-
s_, d_ = make_pair(stride, len(k_)), make_pair(dilation, len(k_))
|
1992
|
+
s_, d_ = make_tuple(stride, len(k_)), make_tuple(dilation, len(k_))
|
1593
1993
|
assert len(k_) == len(s_) == len(d_), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}"
|
1594
|
-
|
1595
|
-
|
1596
|
-
|
1994
|
+
noop, i_ = [None] * (self.ndim-len(k_)), self.shape[-len(k_):]
|
1995
|
+
assert all(resolve(d*(k-1)+1 <= i) for k,d,i in zip(k_,d_,i_)), "kernel size cannot be greater than actual input size"
|
1996
|
+
o_ = [ceildiv(i-d*(k-1), s) for i,d,k,s in zip(i_,d_,k_,s_)]
|
1997
|
+
if any(resolve(k > s) for k,s in zip(k_,s_)) or any(d != 1 for d in d_):
|
1597
1998
|
# repeats such that we don't need padding
|
1598
|
-
|
1999
|
+
x = self.repeat([1]*len(noop) + [ceildiv(k*(i+d), i) for k,i,d in zip(k_,i_,d_)])
|
1599
2000
|
# handle dilation
|
1600
|
-
|
2001
|
+
x = x.shrink(tuple(noop + [(0,k*(i+d)) for k,i,d in zip(k_,i_,d_)])).reshape(noop + flatten((k,i+d) for k,i,d in zip(k_,i_,d_)))
|
1601
2002
|
# handle stride
|
1602
|
-
|
1603
|
-
|
2003
|
+
x = x.shrink(tuple(noop + flatten(((0,k), (0,o*s)) for k,o,s in zip(k_,o_,s_)))).reshape(noop + flatten((k,o,s) for k,o,s in zip(k_,o_,s_)))
|
2004
|
+
x = x.shrink(tuple(noop + flatten(((0,k), (0,o), (0,1)) for k,o in zip(k_,o_)))).reshape(noop + flatten((k,o) for k,o in zip(k_,o_)))
|
1604
2005
|
# permute to move reduce to the end
|
1605
|
-
return
|
2006
|
+
return x.permute(*range(len(noop)), *[len(noop)+i*2+1 for i in range(len(i_))], *[len(noop)+i*2 for i in range(len(i_))])
|
1606
2007
|
# TODO: once the shapetracker can optimize well, remove this alternative implementation
|
1607
|
-
|
1608
|
-
|
1609
|
-
|
1610
|
-
return
|
2008
|
+
x = self.pad(tuple(noop + [(0, max(0,o*s-i)) for i,o,s in zip(i_,o_,s_)])).shrink(tuple(noop + [(0,o*s) for o,s in zip(o_,s_)]))
|
2009
|
+
x = x.reshape(noop + flatten(((o,s) for o,s in zip(o_,s_))))
|
2010
|
+
x = x.shrink(tuple(noop + flatten(((0,o), (0,k)) for o,k in zip(o_,k_))))
|
2011
|
+
return x.permute(*range(len(noop)), *[len(noop)+i*2 for i in range(len(i_))], *[len(noop)+i*2+1 for i in range(len(i_))])
|
2012
|
+
|
2013
|
+
def _padding2d(self, padding:Union[int, Sequence[int]], dims:int) -> Sequence[int]:
|
2014
|
+
return [padding]*2*dims if isinstance(padding, int) else (padding if len(padding) == 2*dims else [p for p in padding for _ in range(2)][::-1])
|
1611
2015
|
|
1612
2016
|
# NOTE: these work for more than 2D
|
1613
|
-
def avg_pool2d(self, kernel_size=(2,2), stride=None, dilation=1):
|
2017
|
+
def avg_pool2d(self, kernel_size=(2,2), stride=None, dilation=1, padding=0, count_include_pad=True):
|
1614
2018
|
"""
|
1615
2019
|
Applies average pooling over a tensor.
|
1616
2020
|
|
@@ -1622,11 +2026,15 @@ class Tensor:
|
|
1622
2026
|
t = Tensor.arange(25).reshape(1, 1, 5, 5)
|
1623
2027
|
print(t.avg_pool2d().numpy())
|
1624
2028
|
```
|
2029
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2030
|
+
print(t.avg_pool2d(padding=1).numpy())
|
2031
|
+
```
|
1625
2032
|
"""
|
1626
|
-
|
1627
|
-
return
|
2033
|
+
padding_, axis = self._padding2d(padding, len(k_ := make_tuple(kernel_size, 2))), tuple(range(-len(k_), 0))
|
2034
|
+
def pool(x:Tensor) -> Tensor: return x.pad(padding_)._pool(k_, stride if stride is not None else k_, dilation)
|
2035
|
+
return pool(self).mean(axis=axis) if count_include_pad else pool(self).sum(axis=axis) / pool(self.ones_like()).sum(axis=axis)
|
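The new `padding` and `count_include_pad` arguments change what the average is taken over; a small sketch of the difference (not from the diff):

```python
from tinygrad import Tensor

t = Tensor.ones(1, 1, 2, 2)
print(t.avg_pool2d(kernel_size=2, padding=1).numpy())                           # zeros from padding are averaged in -> 0.25
print(t.avg_pool2d(kernel_size=2, padding=1, count_include_pad=False).numpy())  # padding excluded from the divisor -> 1.0
```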
1628
2036
|
|
1629
|
-
def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1):
|
2037
|
+
def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1, padding=0):
|
1630
2038
|
"""
|
1631
2039
|
Applies max pooling over a tensor.
|
1632
2040
|
|
@@ -1638,11 +2046,15 @@ class Tensor:
|
|
1638
2046
|
t = Tensor.arange(25).reshape(1, 1, 5, 5)
|
1639
2047
|
print(t.max_pool2d().numpy())
|
1640
2048
|
```
|
2049
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2050
|
+
print(t.max_pool2d(padding=1).numpy())
|
2051
|
+
```
|
1641
2052
|
"""
|
1642
|
-
|
1643
|
-
return self._pool(
|
2053
|
+
padding_ = self._padding2d(padding, len(k_ := make_tuple(kernel_size, 2)))
|
2054
|
+
return self.pad(padding_, value=dtypes.min(self.dtype))._pool(k_, stride if stride is not None else k_, dilation).max(tuple(range(-len(k_), 0)))
|
1644
2055
|
|
1645
|
-
def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding
|
2056
|
+
def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding:int|Tuple[int, ...]=0,
|
2057
|
+
acc_dtype:Optional[DTypeLike]=None) -> Tensor:
|
1646
2058
|
"""
|
1647
2059
|
Applies a convolution over a tensor with a given `weight` and optional `bias`.
|
1648
2060
|
|
@@ -1656,13 +2068,14 @@ class Tensor:
|
|
1656
2068
|
print(t.conv2d(w).numpy())
|
1657
2069
|
```
|
1658
2070
|
"""
|
2071
|
+
if IMAGE: return self.image_conv2d(weight, bias, groups, stride, dilation, padding, acc_dtype)
|
1659
2072
|
(bs,cin_), (cout,cin), HW = self.shape[:2], weight.shape[:2], weight.shape[2:]
|
1660
2073
|
assert groups*cin == cin_ and len(self.shape) == len(weight.shape), f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})" # noqa: E501
|
1661
2074
|
if isinstance(padding, (tuple,list)): assert len(padding) == 2*len(HW) or len(padding) == len(HW), f"Expected padding of length {2*len(HW)} or {len(HW)}, but got {len(padding)} for tensor of shape {self.shape}" # noqa: E501
|
1662
|
-
padding_ =
|
2075
|
+
padding_ = self._padding2d(padding, len(HW))
|
1663
2076
|
|
1664
2077
|
# conv2d is a pooling op (with padding)
|
1665
|
-
x = self.
|
2078
|
+
x = self.pad(padding_)._pool(HW, stride, dilation) # (bs, groups*cin, oy, ox, H, W)
|
1666
2079
|
rcout, oyx = cout//groups, x.shape[2:-len(HW)]
|
1667
2080
|
if not all(x == 3 for x in HW) or stride != 1 or dilation != 1 or not WINO:
|
1668
2081
|
# normal conv
|
@@ -1680,7 +2093,7 @@ class Tensor:
|
|
1680
2093
|
# todo: stride == dilation
|
1681
2094
|
# use padding to round up to 4x4 output tiles
|
1682
2095
|
# (bs, cin_, tyx, HWI)
|
1683
|
-
d = self.
|
2096
|
+
d = self.pad(sum([[padding_[i*2], padding_[i*2+1] + (-(dim + sum(padding_[i * 2:(i + 1) * 2]) - 2) % 4)] for i, dim in enumerate(self.shape[-len(HW):])], []))._pool(HWI, HWO) # noqa: E501
|
1684
2097
|
# move HW to the front: # (HWI, bs, cin_, tyx)
|
1685
2098
|
d = d.permute(*range(len(d.shape)-len(HW),len(d.shape)), *range(len(d.shape)-len(HW)))
|
1686
2099
|
tyx = d.shape[-len(HWI):] # dim of tiling
|
@@ -1719,7 +2132,7 @@ class Tensor:
|
|
1719
2132
|
"""
|
1720
2133
|
x, w = self, weight.unflatten(0, (groups, -1)).transpose(1, 2).flip(*range(3, len(weight.shape)+1))
|
1721
2134
|
HW = weight.shape[2:]
|
1722
|
-
stride, dilation, padding, output_padding = [
|
2135
|
+
stride, dilation, padding, output_padding = [make_tuple(x, len(HW)) for x in (stride, dilation, padding, output_padding)]
|
1723
2136
|
if any(s>1 for s in stride):
|
1724
2137
|
# handle strides: (k) -> reshape -> (k,1) -> pad -> (k,s) -> reshape -> (k*s) -> shrink (k-(s-1))
|
1725
2138
|
x = x.reshape(None, None, *flatten((k,1) for k in x.shape[2:]))
|
@@ -1729,26 +2142,35 @@ class Tensor:
|
|
1729
2142
|
padding = flatten((((k-1)*d-p,(k-1)*d-p+op) for k,d,p,op in reversed(list(zip(HW, dilation, padding, output_padding)))))
|
1730
2143
|
return x.conv2d(w.flatten(end_dim=1), groups=groups, bias=bias, dilation=dilation, padding=padding)
|
1731
2144
|
|
1732
|
-
def dot(self, w:Tensor, acc_dtype:Optional[
|
2145
|
+
def dot(self, w:Tensor, acc_dtype:Optional[DTypeLike]=None) -> Tensor:
|
2146
|
+
|
1733
2147
|
"""
|
1734
2148
|
Performs dot product between two tensors.
|
2149
|
+
If `w` is 1-D, it's a sum product over the last axis of `self` and `w`.
|
2150
|
+
If `w` is N-D with N>=2, it's a sum product over the last axis of `self` and the second-to-last axis of `w`.
|
1735
2151
|
|
1736
2152
|
You can pass in the optional `acc_dtype` keyword argument to control the data type of the accumulation.
|
1737
2153
|
|
2154
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2155
|
+
a = Tensor([1, 2, 3])
|
2156
|
+
b = Tensor([1, 1, 0])
|
2157
|
+
print(a.dot(b).numpy())
|
2158
|
+
```
|
1738
2159
|
```python exec="true" source="above" session="tensor" result="python"
|
1739
2160
|
a = Tensor([[1, 2], [3, 4]])
|
1740
2161
|
b = Tensor([[5, 6], [7, 8]])
|
1741
2162
|
print(a.dot(b).numpy())
|
1742
2163
|
```
|
1743
2164
|
"""
|
1744
|
-
|
1745
|
-
|
1746
|
-
|
1747
|
-
x
|
1748
|
-
|
2165
|
+
if IMAGE: return self.image_dot(w, acc_dtype)
|
2166
|
+
x, dx, dw = self, self.ndim, w.ndim
|
2167
|
+
if not (dx > 0 and dw > 0): raise RuntimeError(f"both tensors need to be at least 1D, got {dx}D and {dw}D")
|
2168
|
+
if x.shape[-1] != w.shape[axis_w:=-min(w.ndim,2)]: raise RuntimeError(f"cannot dot {x.shape} and {w.shape}")
|
2169
|
+
x = x.reshape(*x.shape[0:-1], *[1]*min(dx-1, dw-1, 1), x.shape[-1])
|
2170
|
+
w = w.reshape(*w.shape[0:-2], *[1]*min(dx-1, dw-1, 1), *w.shape[axis_w:]).transpose(-1, axis_w)
|
1749
2171
|
return (x*w).sum(-1, acc_dtype=acc_dtype).cast(least_upper_dtype(x.dtype, w.dtype) if acc_dtype is None else acc_dtype)
|
1750
2172
|
|
1751
|
-
def matmul(self, x:Tensor, reverse=False, acc_dtype:Optional[
|
2173
|
+
def matmul(self, x:Tensor, reverse=False, acc_dtype:Optional[DTypeLike]=None) -> Tensor:
|
1752
2174
|
"""
|
1753
2175
|
Performs matrix multiplication between two tensors.
|
1754
2176
|
|
@@ -1766,7 +2188,7 @@ class Tensor:
|
|
1766
2188
|
def _cumsum(self, axis:int=0, _first_zero=False) -> Tensor:
|
1767
2189
|
assert self.shape[axis] != 0
|
1768
2190
|
pl_sz = self.shape[axis] - int(not _first_zero)
|
1769
|
-
return self.transpose(axis,-1).
|
2191
|
+
return self.transpose(axis,-1).pad((pl_sz,-int(_first_zero)))._pool((self.shape[axis],)).sum(-1).transpose(axis,-1)
|
1770
2192
|
def cumsum(self, axis:int=0) -> Tensor:
|
1771
2193
|
"""
|
1772
2194
|
Computes the cumulative sum of the tensor along the specified axis.
|
@@ -1786,12 +2208,11 @@ class Tensor:
|
|
1786
2208
|
# TODO: someday the optimizer will find this on its own
|
1787
2209
|
# for now this is a two stage cumsum
|
1788
2210
|
SPLIT = 256
|
1789
|
-
if self.shape[axis] <= SPLIT*2: return self._cumsum(axis)
|
1790
|
-
ret = self.transpose(axis,-1).
|
1791
|
-
ret = ret.unflatten(-1, (-1, SPLIT))._cumsum(-1)
|
2211
|
+
if not isinstance(s:=self.shape[axis], int) or s <= SPLIT*2: return self._cumsum(axis)
|
2212
|
+
ret = self.transpose(axis,-1).pad((round_up(s, SPLIT)-s, 0)).unflatten(-1, (-1, SPLIT))._cumsum(-1)
|
1792
2213
|
base_add = ret[..., -1]._cumsum(-1, _first_zero=True)
|
1793
2214
|
base_add = base_add.unsqueeze(-1).expand(*base_add.shape, ret.shape[-1])
|
1794
|
-
def fix(x:Tensor): return x.flatten(start_dim=-2)[..., -
|
2215
|
+
def fix(x:Tensor): return x.flatten(start_dim=-2)[..., -s:].transpose(axis,-1)
|
1795
2216
|
return fix(ret) + fix(base_add)
|
1796
2217
|
|
1797
2218
|
@staticmethod
|
@@ -1850,6 +2271,38 @@ class Tensor:
|
|
1850
2271
|
"""
|
1851
2272
|
return Tensor._tri(self.shape[-2], self.shape[-1], diagonal=diagonal+1, device=self.device, dtype=dtypes.bool).where(0, self).cast(self.dtype)
|
1852
2273
|
|
2274
|
+
def interpolate(self, size:Tuple[int, ...], mode:str="linear", align_corners:bool=False) -> Tensor:
|
2275
|
+
"""
|
2276
|
+
Downsamples or upsamples to the input `size`; accepts 0 to N batch dimensions.
|
2277
|
+
|
2278
|
+
The interpolation algorithm is selected with `mode` which currently only supports `linear`, `nearest` and `nearest-exact`.
|
2279
|
+
To run `bilinear` or `trilinear`, pass in a 2D or 3D size.
|
2280
|
+
|
2281
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2282
|
+
t = Tensor([[1, 2, 3, 4], [21, 22, 23, 24], [41, 42, 43, 44]])
|
2283
|
+
print(t.numpy())
|
2284
|
+
```
|
2285
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2286
|
+
print(t.interpolate(size=(2,3), mode="linear").numpy())
|
2287
|
+
```
|
2288
|
+
"""
|
2289
|
+
assert isinstance(size, (tuple,list)) and all_int(size) and 0 < len(size) <= self.ndim, f"invalid {size=}"
|
2290
|
+
assert mode in ("linear", "nearest", "nearest-exact"), "only supports linear, nearest or nearest-exact interpolate"
|
2291
|
+
assert not (align_corners and mode != "linear"), "align_corners option can only be set with the interpolating mode linear"
|
2292
|
+
x, expand = self, list(self.shape)
|
2293
|
+
for i in range(-1,-len(size)-1,-1):
|
2294
|
+
scale = (self.shape[i] - int(align_corners)) / (size[i] - int(align_corners))
|
2295
|
+
arr, reshape = Tensor.arange(size[i], dtype=dtypes.float32, device=self.device), [1] * self.ndim
|
2296
|
+
reshape[i] = expand[i] = size[i]
|
2297
|
+
if mode == "linear":
|
2298
|
+
index = (scale*arr if align_corners else (scale*(arr+0.5))-0.5).clip(0, self.shape[i]-1)
|
2299
|
+
low, high, perc = [y.reshape(reshape).expand(expand) for y in (index.floor(), index.ceil(), index - index.floor())]
|
2300
|
+
x = x.gather(i, low).lerp(x.gather(i, high), perc)
|
2301
|
+
else:
|
2302
|
+
index = (scale*(arr+0.5) if mode=="nearest-exact" else scale*arr).cast(dtypes.int32).reshape(reshape).expand(expand)
|
2303
|
+
x = x.gather(i, index)
|
2304
|
+
return x.cast(self.dtype)
|
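A quick sketch contrasting the `nearest` and `linear` modes of the new `interpolate` (illustrative only, assuming the public `Tensor` import):

```python
from tinygrad import Tensor

t = Tensor([[1.0, 2.0, 3.0, 4.0]])
print(t.interpolate(size=(2,), mode="nearest").numpy())  # picks source columns 0 and 2 -> [[1. 3.]]
print(t.interpolate(size=(8,), mode="linear").numpy())   # upsamples the last axis to 8 samples
```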
2305
|
+
|
1853
2306
|
# ***** unary ops *****
|
1854
2307
|
|
1855
2308
|
def logical_not(self):
|
@@ -1869,7 +2322,7 @@ class Tensor:
|
|
1869
2322
|
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).neg().numpy())
|
1870
2323
|
```
|
1871
2324
|
"""
|
1872
|
-
return
|
2325
|
+
return self*-1 if self.dtype != dtypes.bool else self.logical_not()
|
1873
2326
|
def contiguous(self):
|
1874
2327
|
"""
|
1875
2328
|
Returns a contiguous tensor.
|
@@ -1946,6 +2399,20 @@ class Tensor:
|
|
1946
2399
|
```
|
1947
2400
|
"""
|
1948
2401
|
return F.Sigmoid.apply(self.cast(least_upper_float(self.dtype)))
|
2402
|
+
def hardsigmoid(self, alpha:float=1/6, beta:float=0.5):
|
2403
|
+
"""
|
2404
|
+
Applies the Hardsigmoid function element-wise.
|
2405
|
+
NOTE: default `alpha` and `beta` values are taken from torch
|
2406
|
+
|
2407
|
+
- Described: https://paperswithcode.com/method/hard-sigmoid
|
2408
|
+
- See: https://pytorch.org/docs/stable/generated/torch.nn.functional.hardsigmoid.html
|
2409
|
+
|
2410
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2411
|
+
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).hardsigmoid().numpy())
|
2412
|
+
```
|
2413
|
+
"""
|
2414
|
+
return (alpha * self + beta).relu() - (alpha * self + beta - 1).relu()
|
2415
|
+
|
1949
2416
|
def sqrt(self):
|
1950
2417
|
"""
|
1951
2418
|
Computes the square root of the tensor element-wise.
|
@@ -1999,7 +2466,7 @@ class Tensor:
|
|
1999
2466
|
Truncates the tensor element-wise.
|
2000
2467
|
|
2001
2468
|
```python exec="true" source="above" session="tensor" result="python"
|
2002
|
-
print(Tensor([-3.
|
2469
|
+
print(Tensor([-3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]).trunc().numpy())
|
2003
2470
|
```
|
2004
2471
|
"""
|
2005
2472
|
return self.cast(dtypes.int32).cast(self.dtype)
|
@@ -2008,7 +2475,7 @@ class Tensor:
|
|
2008
2475
|
Rounds the tensor element-wise towards positive infinity.
|
2009
2476
|
|
2010
2477
|
```python exec="true" source="above" session="tensor" result="python"
|
2011
|
-
print(Tensor([-3.
|
2478
|
+
print(Tensor([-3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]).ceil().numpy())
|
2012
2479
|
```
|
2013
2480
|
"""
|
2014
2481
|
return (self > (b := self.trunc())).where(b+1, b)
|
@@ -2017,19 +2484,39 @@ class Tensor:
|
|
2017
2484
|
Rounds the tensor element-wise towards negative infinity.
|
2018
2485
|
|
2019
2486
|
```python exec="true" source="above" session="tensor" result="python"
|
2020
|
-
print(Tensor([-3.
|
2487
|
+
print(Tensor([-3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]).floor().numpy())
|
2021
2488
|
```
|
2022
2489
|
"""
|
2023
2490
|
return (self < (b := self.trunc())).where(b-1, b)
|
2024
2491
|
def round(self: Tensor) -> Tensor:
|
2025
2492
|
"""
|
2026
|
-
Rounds the tensor element-wise.
|
2493
|
+
Rounds the tensor element-wise with rounding half to even.
|
2027
2494
|
|
2028
2495
|
```python exec="true" source="above" session="tensor" result="python"
|
2029
|
-
print(Tensor([-3.
|
2496
|
+
print(Tensor([-3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]).round().numpy())
|
2030
2497
|
```
|
2031
2498
|
"""
|
2032
2499
|
return ((self > 0) == ((b := self.cast(dtypes.int32) / 2.0).cast(dtypes.int32) == b)).where((self - 0.5).ceil(), (self + 0.5).floor())
|
2500
|
+
|
2501
|
+
def isinf(self:Tensor, detect_positive:bool=True, detect_negative:bool=True):
|
2502
|
+
"""
|
2503
|
+
Checks the tensor element-wise and returns True where the element is infinity, otherwise False.
|
2504
|
+
|
2505
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2506
|
+
print(Tensor([1, float('inf'), 2, float('-inf'), float('nan')]).isinf().numpy())
|
2507
|
+
```
|
2508
|
+
"""
|
2509
|
+
return (self == float("inf")) * detect_positive + (self == float("-inf")) * detect_negative
|
2510
|
+
def isnan(self:Tensor):
|
2511
|
+
"""
|
2512
|
+
Checks the tensor element-wise and returns True where the element is NaN, otherwise False.
|
2513
|
+
|
2514
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2515
|
+
print(Tensor([1, float('inf'), 2, float('-inf'), float('nan')]).isnan().numpy())
|
2516
|
+
```
|
2517
|
+
"""
|
2518
|
+
return self != self
|
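The `detect_positive` / `detect_negative` flags restrict which infinities are reported; a short sketch, not from the diff:

```python
from tinygrad import Tensor

t = Tensor([1.0, float("inf"), float("-inf"), float("nan")])
print(t.isinf().numpy())                       # both infinities flagged
print(t.isinf(detect_negative=False).numpy())  # only +inf flagged
print(t.isnan().numpy())                       # only the NaN flagged
```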
2519
|
+
|
2033
2520
|
def lerp(self, end: Tensor, weight: Union[Tensor, float]) -> Tensor:
|
2034
2521
|
"""
|
2035
2522
|
Linearly interpolates between `self` and `end` by `weight`.
|
@@ -2038,7 +2525,11 @@ class Tensor:
|
|
2038
2525
|
print(Tensor([1., 2., 3.]).lerp(Tensor([4., 5., 6.]), 0.5).numpy())
|
2039
2526
|
```
|
2040
2527
|
"""
|
2528
|
+
if self.dtype == dtypes.uint8 and isinstance(weight, Tensor):
|
2529
|
+
w_i = (weight * (1<<(W_PREC:=7)) + 0.5).cast(dtypes.int16)
|
2530
|
+
return (self+(((end - self).cast(dtypes.int8) * w_i + (1<<W_PREC-1)).cast(dtypes.uint16) >> W_PREC)).cast(dtypes.uint8)
|
2041
2531
|
return self + (end - self) * weight
|
2532
|
+
|
2042
2533
|
def square(self):
|
2043
2534
|
"""
|
2044
2535
|
Squares the tensor element-wise.
|
@@ -2049,15 +2540,23 @@ class Tensor:
|
|
2049
2540
|
```
|
2050
2541
|
"""
|
2051
2542
|
return self*self
|
2052
|
-
def
|
2543
|
+
def clamp(self, min_=None, max_=None):
|
2053
2544
|
"""
|
2054
2545
|
Clips (clamps) the values in the tensor between `min_` and `max_` element-wise.
|
2546
|
+
If `min_` is `None`, there is no lower bound. If `max_` is `None`, there is no upper bound.
|
2055
2547
|
|
2056
2548
|
```python exec="true" source="above" session="tensor" result="python"
|
2057
2549
|
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).clip(-1, 1).numpy())
|
2058
2550
|
```
|
2059
2551
|
"""
|
2060
|
-
|
2552
|
+
if min_ is None and max_ is None: raise RuntimeError("at least one of 'min_' or 'max_' must not be None")
|
2553
|
+
ret = self.maximum(min_) if min_ is not None else self
|
2554
|
+
return ret.minimum(max_) if max_ is not None else ret
|
2555
|
+
def clip(self, min_=None, max_=None):
|
2556
|
+
"""
|
2557
|
+
Alias for `Tensor.clamp`.
|
2558
|
+
"""
|
2559
|
+
return self.clamp(min_, max_)
|
2061
2560
|
def sign(self):
|
2062
2561
|
"""
|
2063
2562
|
Returns the sign of the tensor element-wise.
|
@@ -2249,6 +2748,20 @@ class Tensor:
|
|
2249
2748
|
"""
|
2250
2749
|
return self.clip(min_val, max_val)
|
2251
2750
|
|
2751
|
+
def erf(self):
|
2752
|
+
"""
|
2753
|
+
Applies the error function element-wise.
|
2754
|
+
|
2755
|
+
- Described: https://en.wikipedia.org/wiki/Error_function
|
2756
|
+
|
2757
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2758
|
+
print(Tensor([-1.5, -1.0, -0.5, 0., 0.5, 1.0, 1.5]).erf().numpy())
|
2759
|
+
```
|
2760
|
+
"""
|
2761
|
+
# https://personal.math.ubc.ca/~cbm/aands/page_299.htm 7.1.26
|
2762
|
+
t = 1.0 / (1.0 + 0.3275911 * self.abs())
|
2763
|
+
return self.sign() * (1.0 - t * polyN(t, [1.061405429, -1.453152027, 1.421413741, -0.284496736, 0.254829592]) * (-self.square()).exp())
|
2764
|
+
|
2252
2765
|
def gelu(self):
|
2253
2766
|
"""
|
2254
2767
|
Applies the Gaussian Error Linear Unit (GELU) function element-wise.
|
@@ -2333,17 +2846,18 @@ class Tensor:
|
|
2333
2846
|
# first pad left with 1s https://data-apis.org/array-api/latest/API_specification/broadcasting.html
|
2334
2847
|
padded, _ = _pad_left(self.shape, shape)
|
2335
2848
|
# for each dimension, check either from_ is 1, or it does not change
|
2336
|
-
if any(from_ != 1 and from_ != to for from_,to in zip(padded, shape)):
|
2849
|
+
if any(resolve(from_ != 1, False) and resolve(from_ != to, False) for from_,to in zip(padded, shape)):
|
2850
|
+
raise ValueError(f"cannot broadcast from shape={self.shape} to {shape=}")
|
2337
2851
|
return F.Expand.apply(self.reshape(padded), shape=shape)
|
2338
2852
|
|
2339
|
-
def _broadcasted(self, y:Union[Tensor,
|
2853
|
+
def _broadcasted(self, y:Union[Tensor, UOp, ConstType], reverse:bool=False, match_dtype:bool=True) -> Tuple[Tensor, Tensor]:
|
2340
2854
|
x: Tensor = self
|
2341
2855
|
if not isinstance(y, Tensor):
|
2342
2856
|
# make y a Tensor
|
2343
|
-
assert isinstance(y, (
|
2857
|
+
assert isinstance(y, (*get_args(ConstType), UOp)), f"{type(y)=}, {y=}"
|
2344
2858
|
if isinstance(x.dtype, ImageDType) or dtypes.is_float(x.dtype) or (dtypes.is_int(x.dtype) and isinstance(y, int)): y_dtype = x.dtype
|
2345
|
-
elif not isinstance(y,
|
2346
|
-
if isinstance(y,
|
2859
|
+
elif not isinstance(y, UOp): y_dtype = dtypes.from_py(y)
|
2860
|
+
if isinstance(y, UOp): y = Tensor.from_uop(y, device=x.device)
|
2347
2861
|
else: y = Tensor(dtypes.as_const(y, y_dtype), x.device, y_dtype, requires_grad=False)
|
2348
2862
|
|
2349
2863
|
if match_dtype and x.dtype != y.dtype:
|
@@ -2421,12 +2935,25 @@ class Tensor:
|
|
2421
2935
|
"""
|
2422
2936
|
return F.Mul.apply(*self._broadcasted(x, reverse))
|
2423
2937
|
|
2424
|
-
def
|
2938
|
+
def idiv(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
|
2939
|
+
"""
|
2940
|
+
Divides `self` by `x`.
|
2941
|
+
Equivalent to `self // x`.
|
2942
|
+
Supports broadcasting to a common shape, type promotion, and integer inputs.
|
2943
|
+
`idiv` performs integer division.
|
2944
|
+
|
2945
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2946
|
+
print(Tensor([1, 4, 10]).idiv(Tensor([2, 3, 4])).numpy())
|
2947
|
+
```
|
2948
|
+
"""
|
2949
|
+
return F.IDiv.apply(*self._broadcasted(x, reverse))
|
2950
|
+
|
2951
|
+
def div(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
|
2425
2952
|
"""
|
2426
2953
|
Divides `self` by `x`.
|
2427
2954
|
Equivalent to `self / x`.
|
2428
2955
|
Supports broadcasting to a common shape, type promotion, and integer, float, boolean inputs.
|
2429
|
-
|
2956
|
+
`div` performs true division.
|
2430
2957
|
|
2431
2958
|
```python exec="true" source="above" session="tensor" result="python"
|
2432
2959
|
Tensor.manual_seed(42)
|
@@ -2439,13 +2966,9 @@ class Tensor:
|
|
2439
2966
|
```python exec="true" source="above" session="tensor" result="python"
|
2440
2967
|
print(Tensor([1, 4, 10]).div(Tensor([2, 3, 4])).numpy())
|
2441
2968
|
```
|
2442
|
-
```python exec="true" source="above" session="tensor" result="python"
|
2443
|
-
print(Tensor([1, 4, 10]).div(Tensor([2, 3, 4]), upcast=False).numpy())
|
2444
|
-
```
|
2445
2969
|
"""
|
2446
2970
|
numerator, denominator = self._broadcasted(x, reverse)
|
2447
|
-
|
2448
|
-
return F.Div.apply(numerator, denominator)
|
2971
|
+
return numerator.cast(least_upper_float(numerator.dtype)) * denominator.cast(least_upper_float(denominator.dtype)).reciprocal()
|
2449
2972
|
|
2450
2973
|
def xor(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
|
2451
2974
|
"""
|
@@ -2460,8 +2983,53 @@ class Tensor:
|
|
2460
2983
|
print(Tensor([True, True, False, False]).xor(Tensor([True, False, True, False])).numpy())
|
2461
2984
|
```
|
2462
2985
|
"""
|
2986
|
+
if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
|
2463
2987
|
return F.Xor.apply(*self._broadcasted(x, reverse))
|
2464
2988
|
|
2989
|
+
def bitwise_and(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
|
2990
|
+
"""
|
2991
|
+
Compute the bit-wise AND of `self` and `x`.
|
2992
|
+
Equivalent to `self & x`.
|
2993
|
+
Supports broadcasting to a common shape, type promotion, and integer, boolean inputs.
|
2994
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2995
|
+
print(Tensor([2, 5, 255]).bitwise_and(Tensor([3, 14, 16])).numpy())
|
2996
|
+
```
|
2997
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2998
|
+
print(Tensor([True, True, False, False]).bitwise_and(Tensor([True, False, True, False])).numpy())
|
2999
|
+
```
|
3000
|
+
"""
|
3001
|
+
if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
|
3002
|
+
return F.BitwiseAnd.apply(*self._broadcasted(x, reverse))
|
3003
|
+
|
3004
|
+
def bitwise_or(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
|
3005
|
+
"""
|
3006
|
+
Compute the bit-wise OR of `self` and `x`.
|
3007
|
+
Equivalent to `self | x`.
|
3008
|
+
Supports broadcasting to a common shape, type promotion, and integer, boolean inputs.
|
3009
|
+
```python exec="true" source="above" session="tensor" result="python"
|
3010
|
+
print(Tensor([2, 5, 255]).bitwise_or(Tensor([4, 4, 4])).numpy())
|
3011
|
+
```
|
3012
|
+
```python exec="true" source="above" session="tensor" result="python"
|
3013
|
+
print(Tensor([True, True, False, False]).bitwise_or(Tensor([True, False, True, False])).numpy())
|
3014
|
+
```
|
3015
|
+
"""
|
3016
|
+
if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
|
3017
|
+
return F.BitwiseOr.apply(*self._broadcasted(x, reverse))
|
3018
|
+
|
3019
|
+
def bitwise_not(self) -> Tensor:
|
3020
|
+
"""
|
3021
|
+
Compute the bit-wise NOT of `self`.
|
3022
|
+
Equivalent to `~self`.
|
3023
|
+
```python exec="true" source="above" session="tensor" result="python"
|
3024
|
+
print(Tensor([0, 2, 5, 255], dtype="int8").bitwise_not().numpy())
|
3025
|
+
```
|
3026
|
+
```python exec="true" source="above" session="tensor" result="python"
|
3027
|
+
print(Tensor([True, False]).bitwise_not().numpy())
|
3028
|
+
```
|
3029
|
+
"""
|
3030
|
+
if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
|
3031
|
+
return self.logical_not() if self.dtype == dtypes.bool else self ^ ((1<<8*self.dtype.itemsize)-1)
|
3032
|
+
|
2465
3033
|
def lshift(self, x:int):
|
2466
3034
|
"""
|
2467
3035
|
Computes left arithmetic shift of `self` by `x` bits. `self` must have unsigned dtype.
|
@@ -2484,7 +3052,7 @@ class Tensor:
|
|
2484
3052
|
```
|
2485
3053
|
"""
|
2486
3054
|
assert dtypes.is_unsigned(self.dtype) and isinstance(x, int) and x >= 0, f"not supported {self.dtype=} {x=}"
|
2487
|
-
return self.
|
3055
|
+
return self.idiv(2 ** x)
|
2488
3056
|
|
2489
3057
|
def pow(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
|
2490
3058
|
"""
|
@@ -2578,42 +3146,35 @@ class Tensor:
|
|
2578
3146
|
|
2579
3147
|
# ***** op wrappers *****
|
2580
3148
|
|
2581
|
-
def
|
3149
|
+
def __invert__(self) -> Tensor: return self.bitwise_not()
|
2582
3150
|
|
2583
|
-
def __add__(self, x) -> Tensor: return self.add(x)
|
2584
|
-
def __sub__(self, x) -> Tensor: return self.sub(x)
|
2585
|
-
def __mul__(self, x) -> Tensor: return self.mul(x)
|
2586
|
-
def __pow__(self, x) -> Tensor: return self.pow(x)
|
2587
|
-
def __truediv__(self, x) -> Tensor: return self.div(x)
|
2588
|
-
def __matmul__(self, x) -> Tensor: return self.matmul(x)
|
2589
|
-
def __xor__(self, x) -> Tensor: return self.xor(x)
|
2590
3151
|
def __lshift__(self, x) -> Tensor: return self.lshift(x)
|
2591
3152
|
def __rshift__(self, x) -> Tensor: return self.rshift(x)
|
2592
3153
|
|
2593
|
-
def
|
2594
|
-
def
|
2595
|
-
|
3154
|
+
def __pow__(self, x) -> Tensor: return self.pow(x)
|
3155
|
+
def __matmul__(self, x) -> Tensor: return self.matmul(x)
|
3156
|
+
|
2596
3157
|
def __rpow__(self, x) -> Tensor: return self.pow(x, True)
|
2597
|
-
def __rtruediv__(self, x) -> Tensor: return self.div(x, True)
|
2598
3158
|
def __rmatmul__(self, x) -> Tensor: return self.matmul(x, True)
|
2599
|
-
def __rxor__(self, x) -> Tensor: return self.xor(x, True)
|
2600
3159
|
|
2601
3160
|
def __iadd__(self, x) -> Tensor: return self.assign(self.add(x))
|
2602
3161
|
def __isub__(self, x) -> Tensor: return self.assign(self.sub(x))
|
2603
3162
|
def __imul__(self, x) -> Tensor: return self.assign(self.mul(x))
|
2604
3163
|
def __ipow__(self, x) -> Tensor: return self.assign(self.pow(x))
|
2605
3164
|
def __itruediv__(self, x) -> Tensor: return self.assign(self.div(x))
|
3165
|
+
def __ifloordiv__(self, x) -> Tensor: return self.assign(self.idiv(x))
|
2606
3166
|
def __imatmul__(self, x) -> Tensor: return self.assign(self.matmul(x))
|
3167
|
+
def __iand__(self, x) -> Tensor: return self.assign(self.bitwise_and(x))
|
3168
|
+
def __ior__(self, x) -> Tensor: return self.assign(self.bitwise_or(x))
|
2607
3169
|
def __ixor__(self, x) -> Tensor: return self.assign(self.xor(x))
|
2608
3170
|
def __ilshift__(self, x) -> Tensor: return self.assign(self.lshift(x))
|
2609
3171
|
def __irshift__(self, x) -> Tensor: return self.assign(self.rshift(x))
|
2610
3172
|
|
2611
|
-
def
|
2612
|
-
def
|
2613
|
-
def
|
2614
|
-
|
2615
|
-
def
|
2616
|
-
def __eq__(self, x) -> Tensor: return (self!=x).logical_not() # type: ignore[override]
|
3173
|
+
def lt(self, x) -> Tensor: return F.Less.apply(*self._broadcasted(x, False))
|
3174
|
+
def gt(self, x) -> Tensor: return F.Less.apply(*self._broadcasted(x, True))
|
3175
|
+
def ne(self, x) -> Tensor: return F.Neq.apply(*self._broadcasted(x))
|
3176
|
+
|
3177
|
+
def __eq__(self, x) -> Tensor: return self.eq(x) # type: ignore[override]
|
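With the reshuffled operator wrappers above, the bitwise and integer-division dunders route to the new methods. A short sketch; the `//` mapping to `idiv` is an assumption based on the `__ifloordiv__` line in this hunk:

```python
from tinygrad import Tensor

print((~Tensor([True, False])).numpy())  # __invert__ -> bitwise_not
print(Tensor([7, 9]).idiv(2).numpy())    # integer division -> [3 4]
print((Tensor([7, 9]) // 2).numpy())     # assumed to call idiv via __floordiv__
```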
2617
3178
|
|
2618
3179
|
# ***** functional nn ops *****
|
2619
3180
|
|
@@ -2644,7 +3205,7 @@ class Tensor:
|
|
2644
3205
|
"""
|
2645
3206
|
return functools.reduce(lambda x,f: f(x), ll, self)
|
2646
3207
|
|
2647
|
-
def layernorm(self, axis=-1, eps:float=1e-5) -> Tensor:
|
3208
|
+
def layernorm(self, axis:Union[int,Tuple[int,...]]=-1, eps:float=1e-5) -> Tensor:
|
2648
3209
|
"""
|
2649
3210
|
Applies Layer Normalization over a mini-batch of inputs.
|
2650
3211
|
|
@@ -2703,17 +3264,20 @@ class Tensor:
|
|
2703
3264
|
```
|
2704
3265
|
"""
|
2705
3266
|
if not Tensor.training or p == 0: return self
|
2706
|
-
return
|
3267
|
+
return (Tensor.rand_like(self, requires_grad=False, dtype=dtypes.default_float, contiguous=False) >= p).contiguous().where(self, 0) / (1.0 - p)
|
2707
3268
|
|
2708
|
-
def one_hot(self, num_classes:int) -> Tensor:
|
3269
|
+
def one_hot(self, num_classes:int=-1) -> Tensor:
|
2709
3270
|
"""
|
2710
3271
|
Converts `self` to a one-hot tensor.
|
2711
3272
|
|
3273
|
+
`num_classes` defaults to -1, which means num_classes will be inferred as max(self) + 1.
|
3274
|
+
|
2712
3275
|
```python exec="true" source="above" session="tensor" result="python"
|
2713
3276
|
t = Tensor([0, 1, 3, 3, 4])
|
2714
3277
|
print(t.one_hot(5).numpy())
|
2715
3278
|
```
|
2716
3279
|
"""
|
3280
|
+
if num_classes == -1: num_classes = (self.max()+1).item()
|
2717
3281
|
return (self[..., None] == Tensor.arange(num_classes, requires_grad=False, device=self.device)).where(1, 0)
|
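A sketch of the new default `num_classes=-1`, where the class count is inferred from the data (not part of the diff):

```python
from tinygrad import Tensor

t = Tensor([0, 1, 3, 3, 4])
print(t.one_hot().shape)   # num_classes inferred as max(t)+1 == 5 -> (5, 5)
print(t.one_hot().numpy())
```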
2718
3282
|
|
2719
3283
|
def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Optional[Tensor]=None,
|
@@ -2739,39 +3303,45 @@ class Tensor:
|
|
2739
3303
|
qk = self.matmul(key.transpose(-2,-1), acc_dtype=least_upper_dtype(self.dtype, key.dtype, dtypes.float32)) / math.sqrt(self.shape[-1])
|
2740
3304
|
return ((qk+attn_mask) if attn_mask is not None else qk).softmax(-1).cast(self.dtype).dropout(dropout_p) @ value
|
2741
3305
|
|
2742
|
-
def
|
3306
|
+
def _do_reduction(self, reduction:ReductionStr="mean") -> Tensor:
|
3307
|
+
if reduction not in get_args(ReductionStr): raise ValueError(f"{reduction=} must be one of {get_args(ReductionStr)}")
|
3308
|
+
reductions: Dict[str, Callable[[Tensor], Tensor]] = {"mean": Tensor.mean, "sum": Tensor.sum, "none": lambda x: x}
|
3309
|
+
return reductions[reduction](self)
|
3310
|
+
|
3311
|
+
def binary_crossentropy(self, Y:Tensor, reduction:ReductionStr="mean") -> Tensor:
|
2743
3312
|
"""
|
2744
|
-
Computes the binary cross-entropy loss between `self` and `
|
3313
|
+
Computes the binary cross-entropy loss between `self` and `Y`.
|
2745
3314
|
|
2746
3315
|
See: https://pytorch.org/docs/stable/generated/torch.nn.BCELoss.html
|
2747
3316
|
|
2748
3317
|
```python exec="true" source="above" session="tensor" result="python"
|
2749
3318
|
t = Tensor([0.1, 0.9, 0.2])
|
2750
|
-
|
2751
|
-
print(t.binary_crossentropy(
|
3319
|
+
Y = Tensor([0, 1, 0])
|
3320
|
+
print(t.binary_crossentropy(Y).item())
|
2752
3321
|
```
|
2753
3322
|
"""
|
2754
|
-
return (-
|
3323
|
+
return (-Y*self.log() - (1-Y)*(1-self).log())._do_reduction(reduction)
|
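The new `reduction` argument follows the torch convention ("mean", "sum", "none"); a brief illustrative sketch:

```python
from tinygrad import Tensor

p, y = Tensor([0.1, 0.9, 0.2]), Tensor([0, 1, 0])
print(p.binary_crossentropy(y, reduction="none").numpy())  # per-element loss
print(p.binary_crossentropy(y, reduction="sum").item())    # summed loss
```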
2755
3324
|
|
2756
|
-
def binary_crossentropy_logits(self,
|
3325
|
+
def binary_crossentropy_logits(self, Y:Tensor, reduction:ReductionStr="mean") -> Tensor:
|
2757
3326
|
"""
|
2758
|
-
Computes the binary cross-entropy loss between `self` and `
|
3327
|
+
Computes the binary cross-entropy loss between `self` and `Y` where `self` is logits.
|
2759
3328
|
|
2760
3329
|
See: https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html
|
2761
3330
|
|
2762
3331
|
```python exec="true" source="above" session="tensor" result="python"
|
2763
3332
|
t = Tensor([-1, 2, -3])
|
2764
|
-
|
2765
|
-
print(t.binary_crossentropy_logits(
|
3333
|
+
Y = Tensor([0, 1, 0])
|
3334
|
+
print(t.binary_crossentropy_logits(Y).item())
|
2766
3335
|
```
|
2767
3336
|
"""
|
2768
|
-
return (self.maximum(0) -
|
3337
|
+
return (self.maximum(0) - Y * self + (1 + self.abs().neg().exp()).log())._do_reduction(reduction)
|
2769
3338
|
|
2770
|
-
def sparse_categorical_crossentropy(self, Y:Tensor, ignore_index=-1, label_smoothing=0.0) -> Tensor:
|
3339
|
+
def sparse_categorical_crossentropy(self, Y:Tensor, ignore_index:int=-1, label_smoothing=0.0, reduction:ReductionStr="mean") -> Tensor:
|
2771
3340
|
"""
|
2772
3341
|
Computes the sparse categorical cross-entropy loss between `self` and `Y`.
|
2773
3342
|
|
2774
3343
|
NOTE: `self` is logits and `Y` is the target labels.
|
3344
|
+
NOTE: unlike PyTorch, this function expects the class axis to be -1
|
2775
3345
|
|
2776
3346
|
See: https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
|
2777
3347
|
|
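The new `reduction` parameter threads through `_do_reduction`, so the same three modes apply to both loss variants above. A short illustration with arbitrary values:

```python
from tinygrad import Tensor

p = Tensor([0.1, 0.9, 0.2])          # probabilities
y = Tensor([0, 1, 0])
print(p.binary_crossentropy(y).item())                      # "mean" is the default
print(p.binary_crossentropy(y, reduction="sum").item())
print(p.binary_crossentropy(y, reduction="none").numpy())

logits = Tensor([-1.0, 2.0, -3.0])
print(logits.binary_crossentropy_logits(y, reduction="none").numpy())
```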
@@ -2782,19 +3352,145 @@ class Tensor:
     ```
     """
     assert 0.0 <= label_smoothing <= 1.0, "label_smoothing must be in [0.0, 1.0]"
-
+    assert reduction in ("mean", "sum", "none"), "reduction must be one of ['mean', 'sum', 'none']"
+    log_probs, loss_mask = self.log_softmax(), (Y != ignore_index) if ignore_index != -1 else Y.ones_like(dtype=dtypes.bool)
     y_counter = Tensor.arange(self.shape[-1], requires_grad=False, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1])
     y = ((y_counter == Y.flatten().reshape(-1, 1)) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
-    smoothing = label_smoothing * (log_probs.mean(-1) * loss_mask)
-
+    smoothing = label_smoothing * (log_probs.mean(-1) * loss_mask)
+    unreduced = ((1 - label_smoothing) * (log_probs * y).sum(-1) + smoothing)
+    # NOTE: because of ignore_index, we can't use Tensor.mean (so can't use `_do_reduction` here)
+    return -(unreduced.sum() / loss_mask.sum() if reduction == "mean" else (unreduced.sum() if reduction == "sum" else unreduced))
+
+  def cross_entropy(self, Y:Tensor, reduction:ReductionStr="mean", label_smoothing:float=0.0) -> Tensor:
+    """
+    Compute the cross entropy loss between input logits and target.
+
+    NOTE: `self` are logits and `Y` are the target labels or class probabilities.
+
+    See: https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([[-1, 2, -3], [1, -2, 3]])
+    Y = Tensor([1, 2])
+    print(t.cross_entropy(Y).item())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([[-1, 2, -3], [1, -2, 3]])
+    Y = Tensor([1, 2])
+    print(t.cross_entropy(Y, reduction='none').numpy())
+    ```
+    """
+    assert 0.0 <= label_smoothing <= 1.0, "label_smoothing must be in [0.0, 1.0]"
+    Y = Y.one_hot(num_classes=cast(int, self.shape[1])) if Y.ndim < 2 else Y
+    Y = (1 - label_smoothing)*Y + label_smoothing / cast(int, Y.shape[1])
+    ret = -self.log_softmax(axis=1).mul(Y).sum(axis=1)
+    return ret._do_reduction(reduction)
+
+  def nll_loss(self, Y:Tensor, weight:Optional[Tensor]=None, ignore_index:Optional[int]=None, reduction:ReductionStr="mean") -> Tensor:
+    """
+    Compute the negative log likelihood loss between log-probabilities and target labels.
+
+    NOTE: `self` is log-probabilities and `Y` is the Y labels or class probabilities.
+
+    See: https://pytorch.org/docs/stable/generated/torch.nn.functional.nll_loss.html
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([[-1, 2, -3], [1, -2, 3]])
+    Y = Tensor([1, 2])
+    print(t.log_softmax().nll_loss(Y).item())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([[-1, 2, -3], [1, -2, 3]])
+    Y = Tensor([1, 2])
+    print(t.log_softmax().nll_loss(Y, reduction='none').numpy())
+    ```
+    """
+    weight = Tensor.ones_like(Y, requires_grad=False) if weight is None else weight[Y]
+    masked_weight = weight if ignore_index is None else weight * (Y != ignore_index)
+    nll = -self.gather(1, Y.unsqueeze(1)).squeeze(1) * masked_weight
+    return nll.sum() / masked_weight.sum() if reduction == "mean" else nll._do_reduction(reduction)
+
+  # ***** Tensor Properties *****
+
+  @property
+  def ndim(self) -> int:
+    """
+    Returns the number of dimensions in the tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([[1, 2], [3, 4]])
+    print(t.ndim)
+    ```
+    """
+    return len(self.shape)
+
+  def numel(self) -> sint:
+    """
+    Returns the total number of elements in the tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
+    print(t.numel())
+    ```
+    """
+    return prod(self.shape)
+
+  def element_size(self) -> int:
+    """
+    Returns the size in bytes of an individual element in the tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([5], dtype=dtypes.int16)
+    print(t.element_size())
+    ```
+    """
+    return self.dtype.itemsize
+
+  def nbytes(self) -> int:
+    """
+    Returns the total number of bytes of all elements in the tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([8, 9], dtype=dtypes.float)
+    print(t.nbytes())
+    ```
+    """
+    return self.numel() * self.element_size()
+
+  def is_floating_point(self) -> bool:
+    """
+    Returns `True` if the tensor contains floating point types, i.e. is one of `dtype.float64`, `dtype.float32`,
+    `dtype.float16`, `dtype.bfloat16`.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([8, 9], dtype=dtypes.float32)
+    print(t.is_floating_point())
+    ```
+    """
+    return dtypes.is_float(self.dtype)
+
+  def size(self, dim:Optional[int]=None) -> Union[sint, Tuple[sint, ...]]:
+    """
+    Return the size of the tensor. If `dim` is specified, return the length along dimension `dim`. Otherwise return the shape of the tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([[4, 5, 6], [7, 8, 9]])
+    print(t.size())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.size(dim=1))
+    ```
+    """
+    return self.shape if dim is None else self.shape[dim]

   # ***** cast ops *****

-  def llvm_bf16_cast(self, dtype:
+  def llvm_bf16_cast(self, dtype:DTypeLike):
     # hack for devices that don't support bfloat16
     assert self.dtype == dtypes.bfloat16
     return self.to("LLVM").bitcast(dtypes.uint16).cast(dtypes.uint32).mul(1<<16).bitcast(dtypes.float32).cast(dtype)
-
+
+  def cast(self, dtype:DTypeLike) -> Tensor:
     """
     Casts `self` to the given `dtype`.

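Among the additions in the hunk above, `cross_entropy` and `nll_loss` mirror their PyTorch namesakes, and with default settings the two compose as expected. A small check using the same toy logits as the docstrings:

```python
from tinygrad import Tensor

logits = Tensor([[-1.0, 2.0, -3.0], [1.0, -2.0, 3.0]])
target = Tensor([1, 2])
print(logits.cross_entropy(target).item())
print(logits.cross_entropy(target, reduction="none").numpy())
# cross_entropy on logits should agree with nll_loss on log-probabilities
print(logits.log_softmax().nll_loss(target).item())
```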
@@ -2807,8 +3503,9 @@ class Tensor:
     print(t.dtype, t.numpy())
     ```
     """
-    return self if self.dtype == dtype else F.Cast.apply(self, dtype=
-
+    return self if self.dtype == (dt:=to_dtype(dtype)) else F.Cast.apply(self, dtype=dt)
+
+  def bitcast(self, dtype:DTypeLike) -> Tensor:
     """
     Bitcasts `self` to the given `dtype` of the same itemsize.

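The `cast` rewrite routes its argument through `to_dtype`, so anything matching `DTypeLike` is accepted; the `to_dtype(f"uint{8*ns}")` call further down suggests that includes dtype names given as strings (treat the string form here as an assumption):

```python
from tinygrad import Tensor, dtypes

t = Tensor([1, 2, 3])
print(t.cast(dtypes.float16).dtype)
print(t.cast("float16").dtype)   # assumed equivalent: to_dtype resolves the name
print(t.cast(t.dtype) is t)      # casting to the same dtype is a no-op
```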
@@ -2824,7 +3521,15 @@ class Tensor:
     ```
     """
     if self.requires_grad: raise RuntimeError("can't backprop through bitcast")
-
+    dt = to_dtype(dtype)
+    if (not isinstance(self.device, str) or not self.device.startswith("DISK")) and (ns:=dt.itemsize) != (os:=self.dtype.itemsize):
+      if (self.shape[-1]*os) % ns != 0: raise RuntimeError("unsupported size in bitcast")
+      new_uint, old_uint = to_dtype(f"uint{8*ns}"), to_dtype(f"uint{8*os}")
+      tmp = self.bitcast(old_uint)
+      if ns > os: return functools.reduce(Tensor.add, (tmp[..., i::ns//os].cast(new_uint) << 8*i*os for i in range(ns//os))).bitcast(dtype)
+      return Tensor.stack(*(tmp>>8*i*ns for i in range(os//ns)), dim=-1).flatten(-2).cast(new_uint).bitcast(dtype)
+    return F.Cast.apply(self, dtype=dt, bitcast=True) if self.dtype != dt else self
+
   def float(self) -> Tensor:
     """
     Convenience method to cast `self` to a `float32` Tensor.
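With the added branch above, `bitcast` can change the element size anywhere except the DISK device, by splitting or fusing elements along the last axis. A round-trip sketch:

```python
from tinygrad import Tensor, dtypes

t = Tensor([1, 256], dtype=dtypes.uint32)
b = t.bitcast(dtypes.uint8)                 # 4 bytes per uint32 element
print(b.shape)                              # last axis grows by the itemsize ratio
print(b.bitcast(dtypes.uint32).numpy())     # fuses back to [1, 256]
```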
@@ -2839,6 +3544,7 @@ class Tensor:
     ```
     """
     return self.cast(dtypes.float32)
+
   def half(self) -> Tensor:
     """
     Convenience method to cast `self` to a `float16` Tensor.
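Stepping back from the cast helpers for a moment, the introspection properties added earlier in this diff (`ndim`, `size`, `numel`, `element_size`, `nbytes`, `is_floating_point`) compose naturally; a quick check:

```python
from tinygrad import Tensor, dtypes

t = Tensor([[1, 2, 3], [4, 5, 6]], dtype=dtypes.int16)
print(t.ndim, t.size(), t.size(dim=1))          # 2 (2, 3) 3
print(t.numel(), t.element_size(), t.nbytes())  # 6 2 12
print(t.is_floating_point())                    # False
```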
@@ -2854,23 +3560,44 @@ class Tensor:
     """
     return self.cast(dtypes.float16)

-
+  def int(self) -> Tensor:
+    """
+    Convenience method to cast `self` to a `int32` Tensor.

-
-
-
-
-
-
-
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([-1.5, -0.5, 0.0, 0.5, 1.5])
+    print(t.dtype, t.numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = t.int()
+    print(t.dtype, t.numpy())
+    ```
+    """
+    return self.cast(dtypes.int32)
+
+  def bool(self) -> Tensor:
+    """
+    Convenience method to cast `self` to a `bool` Tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([-1, 0, 1])
+    print(t.dtype, t.numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = t.bool()
+    print(t.dtype, t.numpy())
+    ```
+    """
+    return self.cast(dtypes.bool)

   # *** image Tensor function replacements ***

-  def image_dot(self, w:Tensor, acc_dtype=None):
+  def image_dot(self, w:Tensor, acc_dtype:Optional[DTypeLike]=None) -> Tensor:
     # NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
-
-
-
+    x, dx, dw = self, self.ndim, w.ndim
+    if not (dx > 0 and dw > 0): raise RuntimeError(f"both tensors need to be at least 1D, got {dx}D and {dw}D")
+    if x.shape[-1] != w.shape[-min(w.ndim, 2)]: raise RuntimeError(f"cannot image_dot {x.shape} and {w.shape}")
+
     bs, groups, cin, cout = prod(self.shape[0:-2]), prod(w.shape[0:-2]), w.shape[-2], w.shape[-1]
     out_shape_t = self.shape[0:-2] + (cout,-1) if len(self.shape) > 1 else (cout, )

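The new `int()` and `bool()` helpers round out the existing `float()`/`half()` convenience casts; for example:

```python
from tinygrad import Tensor, dtypes

t = Tensor([-1.5, -0.5, 0.0, 0.5, 1.5])
print(t.int().dtype, t.int().numpy())     # equivalent to cast(dtypes.int32)
print(Tensor([-1, 0, 1]).bool().numpy())  # nonzero values become True
```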
@@ -2881,7 +3608,7 @@ class Tensor:
     cw = w.transpose(w.ndim-1, w.ndim-2).reshape((groups*cout, cin, 1, 1))
     return cx.image_conv2d(cw, groups=groups, acc_dtype=acc_dtype).reshape(out_shape_t).transpose(self.ndim-1, self.ndim-2)

-  def image_conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, acc_dtype=None):
+  def image_conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, acc_dtype=None) -> Tensor:
     base_image_type = dtypes.imageh if getenv("FLOAT16", 0) else dtypes.imagef

     (bs,_,iy,ix), (cout,cin,H,W) = self.shape, weight.shape
@@ -2922,12 +3649,8 @@ class Tensor:
     if cin_last: w = w.reshape(cout//4, H, rcin_hi, W, 4, rcin_lo)
     else: w = w.reshape(cout//4, H, rcin_hi, W, rcin_lo, 4).permute(0,1,2,3,5,4)

-    # padding
-    padding_ = [padding]*4 if isinstance(padding, int) else (padding if len(padding) == 4 else [padding[1], padding[1], padding[0], padding[0]])
-    x = x._slice((None, (-padding_[2], x.shape[1]+padding_[3]), (-padding_[0], x.shape[2]+padding_[1]), None, None, None))
-
     # prepare input
-    x = x.permute(0,3,4,5,1,2)._pool((H, W), stride, dilation) # -> (bs, groups, rcin_hi, rcin_lo, oy, ox, H, W)
+    x = x.permute(0,3,4,5,1,2).pad(self._padding2d(padding, 2))._pool((H, W), stride, dilation) # -> (bs, groups, rcin_hi, rcin_lo, oy, ox, H, W)
     x = x.permute(0,4,5,1,2,3,6,7).reshape(bs, (oy := x.shape[4]), (ox := x.shape[5]), *cout_expand[0:2], 1, 1, rcin_hi, rcin_lo, H, W)

     # prepare weights
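The manual slicing above is replaced by a `pad` call fed from a `_padding2d` helper, which is not shown in this hunk. The removed lines suggest it performs the usual 2-D padding normalization; a rough sketch with illustrative names:

```python
# Normalization the removed inline code performed: an int pads all four sides
# equally, a 2-element sequence (a, b) expands to [b, b, a, a], and a
# 4-element sequence passes through unchanged.
def normalize_padding2d(padding):
  if isinstance(padding, int): return [padding] * 4
  return list(padding) if len(padding) == 4 else [padding[1], padding[1], padding[0], padding[0]]

assert normalize_padding2d(1) == [1, 1, 1, 1]
assert normalize_padding2d((2, 3)) == [3, 3, 2, 2]
```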
@@ -2945,18 +3668,39 @@ class Tensor:
     ret = ret.reshape(bs, oy, ox, cout).permute(0,3,1,2)
     return ret if bias is None else ret.add(bias.reshape(1, -1, 1, 1))

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+def _metadata_wrapper(fn):
+  def _wrapper(*args, **kwargs):
+    if _METADATA.get() is not None: return fn(*args, **kwargs)
+
+    if TRACEMETA >= 2:
+      caller_frame = sys._getframe(frame := 1)
+      caller_module = caller_frame.f_globals.get("__name__", None)
+      caller_func = caller_frame.f_code.co_name
+      if caller_module is None: return fn(*args, **kwargs)
+
+      # if its called from nn we want to step up frames until we are out of nn
+      while caller_module.startswith("tinygrad.nn") and "optim" not in caller_module:
+        caller_frame = sys._getframe(frame := frame + 1)
+        caller_module = caller_frame.f_globals.get("__name__", None)
+        if caller_module is None: return fn(*args, **kwargs)
+
+      # if its called from a lambda in tinygrad we want to look two more frames up
+      if caller_module.startswith("tinygrad") and caller_func == "<lambda>": caller_frame = sys._getframe(frame := frame + 2)
+      caller_module = caller_frame.f_globals.get("__name__", None)
+      if caller_module is None: return fn(*args, **kwargs)
+      caller_func = caller_frame.f_code.co_name
+      caller_lineno = caller_frame.f_lineno
+
+      caller = f"{caller_module}:{caller_lineno}::{caller_func}"
+    else: caller = ""
+
+    token = _METADATA.set(Metadata(name=fn.__name__, caller=caller))
+    ret = fn(*args, **kwargs)
+    _METADATA.reset(token)
+    return ret
+  return _wrapper
+
+if TRACEMETA >= 1:
+  for name, fn in inspect.getmembers(Tensor, inspect.isfunction):
+    if name in ["__class__", "__init__", "__new__", "__repr__", "backward", "sequential"]: continue
+    setattr(Tensor, name, functools.wraps(fn)(_metadata_wrapper(fn)))
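The wrapper above is essentially a ContextVar-guarded decorator: the outermost wrapped Tensor call stamps its name and call site into `_METADATA`, and nested calls leave it alone. A standalone sketch of the same pattern, with illustrative names rather than tinygrad's own:

```python
import functools
import sys
from contextvars import ContextVar

_META: ContextVar = ContextVar("meta", default=None)

def with_caller_metadata(fn):
  @functools.wraps(fn)
  def wrapper(*args, **kwargs):
    if _META.get() is not None: return fn(*args, **kwargs)  # keep the outermost caller's record
    frame = sys._getframe(1)
    token = _META.set(f"{fn.__name__} from {frame.f_globals.get('__name__')}:{frame.f_lineno}")
    try: return fn(*args, **kwargs)
    finally: _META.reset(token)
  return wrapper

@with_caller_metadata
def outer(): return inner()

@with_caller_metadata
def inner(): return _META.get()   # still tagged with outer()'s call site

print(outer())
```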