tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff shows the changes between two publicly available package versions as they were released to a supported registry. It is provided for informational purposes only.
- tinygrad/__init__.py +1 -1
- tinygrad/apps/llm.py +206 -0
- tinygrad/codegen/__init__.py +116 -0
- tinygrad/codegen/devectorizer.py +315 -172
- tinygrad/codegen/expander.py +8 -16
- tinygrad/codegen/gpudims.py +89 -0
- tinygrad/codegen/linearize.py +205 -203
- tinygrad/codegen/lowerer.py +92 -139
- tinygrad/codegen/opt/__init__.py +38 -0
- tinygrad/codegen/opt/heuristic.py +125 -0
- tinygrad/codegen/opt/kernel.py +510 -0
- tinygrad/{engine → codegen/opt}/search.py +51 -35
- tinygrad/codegen/opt/swizzler.py +134 -0
- tinygrad/codegen/opt/tc.py +127 -0
- tinygrad/codegen/quantize.py +67 -0
- tinygrad/device.py +122 -132
- tinygrad/dtype.py +152 -35
- tinygrad/engine/jit.py +81 -54
- tinygrad/engine/memory.py +46 -27
- tinygrad/engine/realize.py +82 -41
- tinygrad/engine/schedule.py +70 -445
- tinygrad/frontend/__init__.py +0 -0
- tinygrad/frontend/onnx.py +1253 -0
- tinygrad/frontend/torch.py +5 -0
- tinygrad/gradient.py +19 -27
- tinygrad/helpers.py +95 -47
- tinygrad/nn/__init__.py +7 -8
- tinygrad/nn/optim.py +72 -41
- tinygrad/nn/state.py +37 -23
- tinygrad/renderer/__init__.py +40 -60
- tinygrad/renderer/cstyle.py +143 -128
- tinygrad/renderer/llvmir.py +113 -62
- tinygrad/renderer/ptx.py +50 -32
- tinygrad/renderer/wgsl.py +27 -23
- tinygrad/runtime/autogen/am/am.py +5861 -0
- tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
- tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
- tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
- tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
- tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
- tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
- tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
- tinygrad/runtime/autogen/comgr.py +35 -9
- tinygrad/runtime/autogen/comgr_3.py +906 -0
- tinygrad/runtime/autogen/cuda.py +2419 -494
- tinygrad/runtime/autogen/hsa.py +57 -16
- tinygrad/runtime/autogen/ib.py +7171 -0
- tinygrad/runtime/autogen/io_uring.py +917 -118
- tinygrad/runtime/autogen/kfd.py +748 -26
- tinygrad/runtime/autogen/libc.py +613 -218
- tinygrad/runtime/autogen/libusb.py +1643 -0
- tinygrad/runtime/autogen/nv/nv.py +8602 -0
- tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
- tinygrad/runtime/autogen/opencl.py +2 -4
- tinygrad/runtime/autogen/sqtt.py +1789 -0
- tinygrad/runtime/autogen/vfio.py +3 -3
- tinygrad/runtime/autogen/webgpu.py +273 -264
- tinygrad/runtime/graph/cuda.py +3 -3
- tinygrad/runtime/graph/hcq.py +68 -29
- tinygrad/runtime/graph/metal.py +29 -13
- tinygrad/runtime/graph/remote.py +114 -0
- tinygrad/runtime/ops_amd.py +537 -320
- tinygrad/runtime/ops_cpu.py +108 -7
- tinygrad/runtime/ops_cuda.py +12 -14
- tinygrad/runtime/ops_disk.py +13 -10
- tinygrad/runtime/ops_dsp.py +47 -40
- tinygrad/runtime/ops_gpu.py +13 -11
- tinygrad/runtime/ops_hip.py +6 -9
- tinygrad/runtime/ops_llvm.py +35 -15
- tinygrad/runtime/ops_metal.py +29 -19
- tinygrad/runtime/ops_npy.py +5 -3
- tinygrad/runtime/ops_null.py +28 -0
- tinygrad/runtime/ops_nv.py +306 -234
- tinygrad/runtime/ops_python.py +62 -52
- tinygrad/runtime/ops_qcom.py +28 -39
- tinygrad/runtime/ops_remote.py +482 -0
- tinygrad/runtime/ops_webgpu.py +28 -28
- tinygrad/runtime/support/am/amdev.py +114 -249
- tinygrad/runtime/support/am/ip.py +211 -172
- tinygrad/runtime/support/amd.py +138 -0
- tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
- tinygrad/runtime/support/compiler_cuda.py +8 -11
- tinygrad/runtime/support/elf.py +2 -1
- tinygrad/runtime/support/hcq.py +184 -97
- tinygrad/runtime/support/ib.py +172 -0
- tinygrad/runtime/support/llvm.py +3 -4
- tinygrad/runtime/support/memory.py +251 -0
- tinygrad/runtime/support/nv/__init__.py +0 -0
- tinygrad/runtime/support/nv/ip.py +581 -0
- tinygrad/runtime/support/nv/nvdev.py +183 -0
- tinygrad/runtime/support/system.py +170 -0
- tinygrad/runtime/support/usb.py +268 -0
- tinygrad/runtime/support/webgpu.py +18 -0
- tinygrad/schedule/__init__.py +0 -0
- tinygrad/schedule/grouper.py +119 -0
- tinygrad/schedule/kernelize.py +368 -0
- tinygrad/schedule/multi.py +231 -0
- tinygrad/shape/shapetracker.py +40 -46
- tinygrad/shape/view.py +88 -52
- tinygrad/tensor.py +968 -542
- tinygrad/uop/__init__.py +117 -0
- tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
- tinygrad/uop/mathtraits.py +169 -0
- tinygrad/uop/ops.py +1021 -0
- tinygrad/uop/spec.py +228 -0
- tinygrad/{codegen → uop}/symbolic.py +239 -216
- tinygrad/uop/upat.py +163 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
- tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
- tinygrad/viz/index.html +203 -403
- tinygrad/viz/js/index.js +718 -0
- tinygrad/viz/js/worker.js +29 -0
- tinygrad/viz/serve.py +224 -102
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
- tinygrad-0.11.0.dist-info/RECORD +141 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/kernel.py +0 -693
- tinygrad/engine/multi.py +0 -161
- tinygrad/ops.py +0 -1003
- tinygrad/runtime/ops_cloud.py +0 -220
- tinygrad/runtime/support/allocator.py +0 -94
- tinygrad/spec.py +0 -155
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
- tinygrad/viz/perfetto.html +0 -178
- tinygrad-0.10.2.dist-info/RECORD +0 -99
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
tinygrad/tensor.py
CHANGED
@@ -2,65 +2,55 @@
 from __future__ import annotations
 import time, math, itertools, functools, struct, sys, inspect, pathlib, string, hashlib, weakref
 from contextlib import ContextDecorator
-from typing import Callable,
+from typing import Callable, ClassVar, Sequence, cast, get_args, Literal, SupportsIndex, ParamSpec, TypeVar, Generic
 from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
+from tinygrad.dtype import _from_np_dtype, _to_np_dtype
 from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
-from tinygrad.helpers import IMAGE, WINO,
-from tinygrad.engine.multi import get_multi_map
+from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap, DEBUG, is_numpy_ndarray
 from tinygrad.gradient import compute_gradient
-from tinygrad.ops import smax, smin, resolve, UOp, Ops, sint, Variable,
-from tinygrad.spec import tensor_uop_spec, type_verify
-from tinygrad.device import Device,
+from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, Variable, MathTrait, identity_element, all_metadata
+from tinygrad.uop.spec import tensor_uop_spec, type_verify
+from tinygrad.device import Device, Buffer
 from tinygrad.engine.realize import run_schedule
 from tinygrad.engine.memory import memory_planner
 from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars
+from tinygrad.schedule.kernelize import get_kernelize_map

 # *** all in scope Tensors are here. this gets relevant UOps ***

-all_tensors:
+all_tensors: dict[weakref.ref[Tensor], None] = {}
+def _find_all_tensors_for_uops(all_uops: set[UOp]) -> list[Tensor]:
+  return [t for tref in all_tensors if (t:=tref()) is not None and t.uop in all_uops]

-def _apply_map_to_tensors(applied_map:dict[UOp, UOp]) -> None:
+def _apply_map_to_tensors(applied_map:dict[UOp, UOp], name:str|None=None) -> None:
   # get all children of keys in applied_map
   all_uops: set[UOp] = set()
   search_uops = list(applied_map)
   while len(search_uops):
-    x = search_uops.pop(
+    x = search_uops.pop()
     if x in all_uops: continue
     all_uops.add(x)
     search_uops.extend([u for c in x.children if (u:=c()) is not None])

   # link the found UOps back to Tensors. exit early if there's no Tensors to realize
   # NOTE: this uses all_tensors, but it's fast
-
-
-  if len(fixed_tensors):
+  if len(fixed_tensors := _find_all_tensors_for_uops(all_uops)):
     # potentially rewrite all the discovered Tensors
-    sink = UOp.sink(*[t.
-    new_sink = sink.substitute(applied_map)
+    sink = UOp.sink(*[t.uop for t in fixed_tensors])
+    new_sink = sink.substitute(applied_map, name=name)

-    # set the relevant
+    # set the relevant uop to the realized UOps
     for t,s,ns in zip(fixed_tensors, sink.src, new_sink.src):
       if s is ns: continue
-      t.
+      t.uop = ns

 # **** Tensor helper functions ****

-def _metaop(op, shape:tuple[sint,...], dtype:DType, device:Union[str, tuple[str, ...]], arg=None):
-  if isinstance(device, str): return UOp.metaop(op, shape, dtype, device, arg)
-  return UOp.multi(*[UOp.metaop(op, shape, dtype, d, arg) for d in device], axis=None)
-
-def _from_np_dtype(npdtype:'np.dtype') -> DType: # type: ignore [name-defined] # noqa: F821
-  import numpy as np
-  return dtypes.fields()[np.dtype(npdtype).name]
-def _to_np_dtype(dtype:DType) -> Optional[type]:
-  import numpy as np
-  return np.dtype(dtype.fmt).type if dtype.fmt is not None else None
-
 def _fromnp(x: 'np.ndarray') -> UOp: # type: ignore [name-defined] # noqa: F821
-  ret = UOp.
+  ret = UOp.new_buffer("NPY", x.size, _from_np_dtype(x.dtype))
   # fake realize
   ret.buffer.allocate(x)
-  return ret
+  return ret.reshape(x.shape)

 def get_shape(x) -> tuple[int, ...]:
   # NOTE: str is special because __getitem__ on a str is still a str
@@ -68,10 +58,10 @@ def get_shape(x) -> tuple[int, ...]:
   if not all_same(subs:=[get_shape(xi) for xi in x]): raise ValueError(f"inhomogeneous shape from {x}")
   return (len(subs),) + (subs[0] if subs else ())

-def _frompy(x:
-  if isinstance(x, bytes): ret, data = UOp.
+def _frompy(x:list|tuple|bytes, dtype:DType) -> UOp:
+  if isinstance(x, bytes): ret, data = UOp.new_buffer("PYTHON", len(x)//dtype.itemsize, dtype), x
   else:
-    ret = UOp.
+    ret = UOp.new_buffer("PYTHON", prod(shape:=get_shape(x)), dtype).reshape(shape)
     assert dtype.fmt is not None, f"{dtype=} has None fmt"
     truncate_function = truncate[dtype]
     data = struct.pack(f"@{ret.size}{dtype.fmt}", *[truncate_function(xi) for xi in fully_flatten(x)])
@@ -79,7 +69,7 @@ def _frompy(x:Union[list, tuple, bytes], dtype:DType) -> UOp:
   ret.buffer.allocate(memoryview(data if Device.DEFAULT != "PYTHON" else bytearray(data)))
   return ret

-def _get_winograd_matcols(mat, dims:int, shp:tuple[sint, ...], device:
+def _get_winograd_matcols(mat, dims:int, shp:tuple[sint, ...], device:str|tuple[str, ...], dtype:DType) -> list[list[Tensor]]:
   return [[Tensor.cat(*[Tensor.full(shp[:dim] + (1,) + shp[dim+1:], float(m[k]), device=device, dtype=dtype) for m in mat], dim=dim)
           for k in range(len(mat[0]))] for dim in range(dims)]

@@ -102,13 +92,12 @@ def _align_left(*shapes:tuple[sint, ...]) -> tuple[tuple[sint, ...], ...]:
 def _broadcast_shape(*shapes:tuple[sint, ...]) -> tuple[sint, ...]:
   return tuple(0 if 0 in nth_dim_sizes else smax(nth_dim_sizes) for nth_dim_sizes in zip(*_align_left(*shapes)))

-def _masked_setitem(target:Tensor, values:Tensor, mask:Tensor, axes:tuple[int, ...]):
-  #
-  values = values * mask
+def _masked_setitem(target:Tensor, values:Tensor, mask:Tensor, axes:tuple[int, ...]) -> Tensor:
+  # reduce such that if mask contains repeated indices the last one remains
   for dim in axes: mask, values = functools.reduce(lambda x,y: (x[0]|y[0], y[0].where(y[1], x[1])), zip(mask.split(1, dim), values.split(1, dim)))
   # remove extra dims from reduce
   for dim in reversed(axes): mask, values = mask.squeeze(dim), values.squeeze(dim)
-  # select from values for each True element in mask else select from
+  # select from values for each True element in mask else select from target
   return mask.where(values, target)

 # `(padding_left, padding_right, padding_top, padding_bottom, ...)` -> `(..., (padding_top, padding_bottom), (padding_left, padding_right))`
@@ -116,7 +105,7 @@ def _flat_to_grouped(padding:Sequence[sint]) -> tuple[tuple[sint, sint], ...]: r

 ReductionStr = Literal["mean", "sum", "none"]

-class Tensor(
+class Tensor(MathTrait):
   """
   A `Tensor` is a multi-dimensional matrix containing elements of a single data type.

@@ -127,70 +116,76 @@ class Tensor(SimpleMathTrait):
   np.set_printoptions(precision=4)
   ```
   """
-  __slots__ = "
+  __slots__ = "uop", "requires_grad", "grad"
   training: ClassVar[bool] = False
-  no_grad: ClassVar[bool] = False

-  def __init__(self, data:
-               device:
+  def __init__(self, data:ConstType|bytes|list|tuple|UOp|'np.ndarray'|pathlib.Path|None, # type: ignore [name-defined] # noqa: F821
+               device:str|tuple|list|None=None, dtype:DTypeLike|None=None, requires_grad:bool|None=None):
     if dtype is not None: dtype = to_dtype(dtype)
     if device is None and isinstance(data, pathlib.Path): device = f"DISK:{data.resolve()}" # keep it on the disk if device is None
     device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device)

     # tensors can have gradients if you have called .backward
-    self.grad:
+    self.grad:Tensor|None = None

     # NOTE: this can be in three states. False and None: no gradient, True: gradient
     # None (the default) will be updated to True if it's put in an optimizer
-    self.requires_grad:
+    self.requires_grad:bool|None = requires_grad

-    # create a
+    # create a UOp from the different types of inputs
     if isinstance(data, UOp):
       assert dtype is None or dtype==data.dtype, "dtype doesn't match, and casting isn't supported"
-
-
-
-
+      if data.op is Ops.BIND:
+        var, val = data.unbind()
+        # give the bound constant a device
+        const = UOp.const(var.dtype, val, device, ())
+        data = data.replace(src=(var.replace(src=const.src), const))
+    elif data is None: data = UOp.const(dtype or dtypes.default_float, 0, device, ())
+    elif isinstance(data, get_args(ConstType)): data = UOp.const(dtype or dtypes.from_py(data), data, device, ())
     elif isinstance(data, bytes): data = _frompy(data, dtypes.uint8 if dtype is None else dtype)
     elif isinstance(data, (list, tuple)):
       if dtype is None:
         if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtypes.bool
         else: dtype = dtypes.default_int if d and all_int(d) else dtypes.default_float # NOTE: this works because all_int([True, False]) is True
-      if dtype
+      if dtype in [dtypes.bfloat16, *dtypes.fp8s]: data = Tensor(_frompy(data, dtypes.float32), device=device).cast(dtype).uop
       else: data = _frompy(data, dtype)
-    elif
+    elif is_numpy_ndarray(data):
       import numpy as np
       assert isinstance(data, np.ndarray), f"expected np.ndarray, got {data}"
-      if data.shape == (): data =
+      if data.shape == (): data = UOp.const(dtype or _from_np_dtype(data.dtype), data.item(), device, ())
       else: data = _fromnp(data.astype(npdtype) if dtype is not None and (npdtype:=_to_np_dtype(dtype)) is not None else data) # type: ignore [name-defined]
     elif isinstance(data, pathlib.Path):
       dtype = dtype or dtypes.uint8
-      data =
+      data = UOp.new_buffer(f"DISK:{data.resolve()}", data.stat().st_size // dtype.itemsize, dtype)

     # by this point, it has to be a UOp
     if not isinstance(data, UOp): raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}")

     # data might be on a different device
-    if isinstance(device, str): self.
+    if isinstance(device, str): self.uop:UOp = data if data.device == device else data.copy_to_device(device)
     # if device is a tuple, we should have/construct a MultiLazyBuffer
-    elif isinstance(data
+    elif isinstance(data.device, str): self.uop = Tensor(data).shard(device).uop
     else:
       assert data.device == device, f"MultiLazyBuffer device mismatch, {data.device} != {device}"
-      self.
+      self.uop = data

     # add to all_tensors after construction succeeds
-    all_tensors
-  def __del__(self): all_tensors.
+    all_tensors[weakref.ref(self)] = None
+  def __del__(self): all_tensors.pop(weakref.ref(self), None)

   def _apply_uop(self, fxn:Callable, *x:Tensor, **kwargs) -> Tensor:
-    new_uop: UOp = fxn(*[t.
+    new_uop: UOp = fxn(*[t.uop for t in (self,)+x], **kwargs)
+    if (metadata:=_METADATA.get()) is not None: all_metadata[new_uop] = (metadata,)
     needs_input_grad = [t.requires_grad for t in (self,)+x]
     return Tensor(new_uop, device=new_uop.device, requires_grad=True if any(needs_input_grad) else None if None in needs_input_grad else False)

-  def _apply_broadcasted_uop(self, fxn:Callable, x:
+  def _apply_broadcasted_uop(self, fxn:Callable, x:Tensor|ConstType, reverse=False) -> Tensor:
     lhs,rhs = self._broadcasted(x, reverse)
     return lhs._apply_uop(fxn, rhs)

+  # _binop is used by MathTrait
+  def _binop(self, op, x, reverse): return self._apply_broadcasted_uop(lambda *u: UOp.alu(u[0], op, *u[1:]), x, reverse)
+
   def requires_grad_(self, requires_grad=True) -> Tensor:
     self.requires_grad = requires_grad
     return self
@@ -200,15 +195,10 @@ class Tensor(SimpleMathTrait):
     def __enter__(self): self.prev, Tensor.training = Tensor.training, self.mode
     def __exit__(self, exc_type, exc_value, traceback): Tensor.training = self.prev

-  class test(ContextDecorator):
-    def __init__(self, mode:bool = True): self.mode = mode
-    def __enter__(self): self.prev, Tensor.no_grad = Tensor.no_grad, self.mode
-    def __exit__(self, exc_type, exc_value, traceback): Tensor.no_grad = self.prev
-
   def __repr__(self):
-    ld = self.
+    ld = self.uop
     ld_repr = f"<UOp {ld.device} {ld.shape} {str(ld.dtype)[7:]} {ld.st if ld.base is not ld else (ld.op, ld.realized)}>"
-    return f"<Tensor {ld_repr} on {self.device} with grad {(self.grad.
+    return f"<Tensor {ld_repr} on {self.device} with grad {(self.grad.uop if self.grad is not None else None)!r}>"

   # Python has a non moving GC, so this should be okay
   def __hash__(self): return id(self)
@@ -220,36 +210,50 @@ class Tensor(SimpleMathTrait):
     return self.shape[0]

   @property
-  def device(self) ->
+  def device(self) -> str|tuple[str, ...]: return self.uop.device

   @property
-  def shape(self) -> tuple[sint, ...]: return self.
+  def shape(self) -> tuple[sint, ...]: return self.uop.shape

   @property
-  def dtype(self) -> DType: return self.
+  def dtype(self) -> DType: return self.uop.dtype

   # ***** data handlers ****

+  def kernelize(self, *lst:Tensor) -> Tensor:
+    """
+    Creates the kernels and buffers needed to realize these Tensor(s).
+
+    NOTE: Kernelize can be called multiple times on a Tensor
+    """
+    big_sink = UOp.sink(*[x.uop for x in (self,)+lst])
+
+    # verify Tensors match the spec
+    if __debug__: type_verify(list(big_sink.toposort()), tensor_uop_spec)
+
+    becomes_map = get_kernelize_map(big_sink)
+    _apply_map_to_tensors(becomes_map, name="Apply Kernelize Map")
+    return self
+
   def schedule_with_vars(self, *lst:Tensor) -> tuple[list[ScheduleItem], dict[Variable, int]]:
     """
     Creates the schedule needed to realize these Tensor(s), with Variables.

     NOTE: A Tensor can only be scheduled once.
     """
-
+    st = time.perf_counter()
+    self.kernelize(*lst)
+    sink = UOp.sink(*[x.uop for x in (self,)+lst])

-    #
-
-
-    _apply_map_to_tensors(get_multi_map(big_sink))
-    big_sink = UOp.sink(*flatten([x.lazydata.src if x.lazydata.op is Ops.MULTI else [x.lazydata] for x in (self,)+lst]))
-
-    # verify Tensors match the spec
-    if __debug__: type_verify(list(big_sink.toposort), tensor_uop_spec)
+    # remove all ASSIGNs, after scheduling, the tensors are just buffers
+    remove_assign_map = {u:u.buf_uop for u in sink.toposort() if u.op is Ops.ASSIGN}
+    _apply_map_to_tensors(remove_assign_map, name="Remove Assigns")

-
-
-
+    # create the schedule
+    schedule, var_vals = create_schedule_with_vars(sink)
+    schedule = memory_planner(schedule)
+    if DEBUG >= 1 and len(schedule) > 1: print(f"scheduled {len(schedule)} kernels in {(time.perf_counter()-st)*1000:.2f} ms")
+    return schedule, var_vals

   def schedule(self, *lst:Tensor) -> list[ScheduleItem]:
     """Creates the schedule needed to realize these Tensor(s)."""
@@ -262,46 +266,43 @@ class Tensor(SimpleMathTrait):
     run_schedule(*self.schedule_with_vars(*lst), do_update_stats=do_update_stats)
     return self

-  def replace(self, x:Tensor) -> Tensor:
+  def replace(self, x:Tensor, allow_shape_mismatch=False) -> Tensor:
     """
     Replaces the data of this tensor with the data of another tensor. Only the shape of the tensors must match.
     """
     # used for replacing a Tensor with a new version of it (potentially with a different device and dtype)
-    assert self.shape == x.shape, f"replace shape mismatch {self.shape} != {x.shape}"
-    self.
+    assert self.shape == x.shape or allow_shape_mismatch, f"replace shape mismatch {self.shape} != {x.shape}"
+    self.uop = x.uop
     return self

   def assign(self, x) -> Tensor:
     # TODO: this is a hack for writing to DISK. remove with working assign
     if isinstance(self.device, str) and self.device.startswith("DISK"):
       if x.__class__ is not Tensor: x = Tensor(x, device="CPU", dtype=self.dtype)
-      self.
+      self._buffer().copyin(x._data())
       return self
     if x.__class__ is not Tensor: x = Tensor(x, device=self.device, dtype=self.dtype)
-    if self.
+    if self.uop is x.uop: return self # a self assign is a NOOP
     # NOTE: we allow cross device assign
+    # broadcast x
+    if least_upper_dtype(self.dtype, x.dtype) == self.dtype: x = x._broadcast_to(self.shape).cast(self.dtype)
     assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}"
     assert self.device == x.device, f"assign device mismatch {self.device} != {x.device}"
     assert self.dtype == x.dtype, f"assign dtype mismatch {self.dtype} != {x.dtype}"
-
-    if not self.lazydata.is_realized: return self.replace(x)
-    self.lazydata = self.lazydata.assign(x.lazydata)
+    self.uop = self.uop.assign(x.uop)
     return self

   def detach(self) -> Tensor:
     """
     Returns a new tensor with the same data as this tensor, but detached from the autograd graph.
     """
-    return Tensor(self.
+    return Tensor(self.uop.detach(), device=self.device, requires_grad=False)

-  def
-
-
-
-
-    assert buf is not None, f"{cast(UOp, cpu.lazydata).base} was not realized"
-    if self.device != "CPU": buf.options = BufferSpec(nolru=True)
-    return buf.as_buffer(allow_zero_copy=True if self.device != "CPU" else False)
+  def _buffer(self) -> Buffer:
+    x = self.cast(self.dtype.base).contiguous()
+    if isinstance(self.device, tuple): x = x.to("CPU")
+    return cast(Buffer, x.realize().uop.base.buffer).ensure_allocated()
+  def _data(self) -> memoryview: return self._buffer().as_buffer()

   def data(self) -> memoryview:
     """
@@ -312,10 +313,9 @@ class Tensor(SimpleMathTrait):
     print(np.frombuffer(t.data(), dtype=np.int32))
     ```
     """
-
+    if 0 in self.shape: return memoryview(bytearray(0)).cast(self.dtype.base.fmt)
     assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
-
-    return cast(memoryview, self._data().cast(self.dtype.base.fmt) if 0 in self.shape else self._data().cast(self.dtype.base.fmt, self.shape))
+    return self._buffer().as_typed_buffer(self.shape)

   def item(self) -> ConstType:
     """
@@ -331,7 +331,7 @@ class Tensor(SimpleMathTrait):

   # TODO: should be Tensor.tolist() -> Union[list[ConstType], ConstType]. The list is Sequence because mypy expects memoryview.tolist() -> list[int]
   # src: https://github.com/python/mypy/blob/release-1.6/mypy/typeshed/stdlib/builtins.pyi#L803
-  def tolist(self) ->
+  def tolist(self) -> Sequence[ConstType]|ConstType:
     """
     Returns the value of this tensor as a nested list.
     Returns single value for const tensor.
@@ -345,6 +345,7 @@ class Tensor(SimpleMathTrait):
     print(t.tolist())
     ```
     """
+    if self.dtype in (dtypes.bfloat16, *dtypes.fp8s): return self.cast(dtypes.float32).tolist()
     return self.data().tolist()

   def numpy(self) -> 'np.ndarray': # type: ignore [name-defined] # noqa: F821
@@ -356,32 +357,32 @@ class Tensor(SimpleMathTrait):
     print(repr(t.numpy()))
     ```
     """
+    assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
     import numpy as np
     if self.dtype.base == dtypes.bfloat16: return self.float().numpy()
-
-
-    return np.frombuffer(self._data(), dtype=_to_np_dtype(self.dtype.base)).reshape(self.shape)
+    if 0 in self.shape: return np.empty(self.shape, dtype=_to_np_dtype(self.dtype.base))
+    return self._buffer().numpy().reshape(self.shape)

   def clone(self) -> Tensor:
     """
     Creates a clone of this tensor allocating a separate buffer for the data.
     """
-    ret = Tensor(self.
+    ret = Tensor.empty(self.shape, device=self.device, dtype=self.dtype)
     if self.grad is not None: ret.grad = self.grad.clone()
-    return ret
+    return ret.assign(self)

-  def to(self, device:
+  def to(self, device:str|tuple[str, ...]|None) -> Tensor:
     """
     Moves the tensor to the given device.
     """
     device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device)
     if device == self.device: return self
     if not isinstance(device, str): return self.shard(device)
-    ret = Tensor(self.
+    ret = Tensor(self.uop, device, requires_grad=self.requires_grad)
     if self.grad is not None: ret.grad = self.grad.to(device)
     return ret

-  def to_(self, device:
+  def to_(self, device:str|tuple[str, ...]|None) -> Tensor:
     """
     Moves the tensor to the given device in place.
     """
@@ -389,21 +390,21 @@ class Tensor(SimpleMathTrait):
     if self.grad is not None and real.grad is not None: self.grad.replace(real.grad)
     return self.replace(real)

-  def shard(self, devices:tuple[str, ...], axis:
+  def shard(self, devices:tuple[str, ...], axis:int|None=None) -> Tensor:
     """
     Shards the tensor across the given devices. Optionally specify which axis to shard on.

     ```python exec="true" source="above" session="tensor" result="python"
     t = Tensor.empty(2, 4)
-    print(t.shard((t.device, t.device), axis=1).
+    print(t.shard((t.device, t.device), axis=1).uop)
     ```
     """
     assert isinstance(self.device, str), "can't shard a MultiLazyBuffer"
     devices = tuple(Device.canonicalize(x) for x in devices)
-    mlb = self.
+    mlb = self.uop.shard(devices, self._resolve_dim(axis)) if axis is not None else self.uop.copy_to_device(devices)
     return Tensor(mlb, device=devices, requires_grad=self.requires_grad)

-  def shard_(self, devices:tuple[str, ...], axis:
+  def shard_(self, devices:tuple[str, ...], axis:int|None=None) -> Tensor:
     """
     Shards the tensor across the given devices in place.
     """
@@ -411,7 +412,7 @@ class Tensor(SimpleMathTrait):

   @staticmethod
   def from_uop(y:UOp, **kwargs) -> Tensor:
-    if y.op is Ops.BIND: return Tensor(y, **kwargs, requires_grad=False)
+    if y.op is Ops.BIND: return Tensor(y, **kwargs, requires_grad=False)
     if y.op is Ops.CONST: return Tensor(y.arg, **kwargs, requires_grad=False)
     if y.op is Ops.MUL: return Tensor.from_uop(y.src[0]) * Tensor.from_uop(y.src[1])
     if y.op is Ops.ADD: return Tensor.from_uop(y.src[0]) + Tensor.from_uop(y.src[1])
@@ -420,15 +421,7 @@ class Tensor(SimpleMathTrait):
   # ***** creation entrypoint *****

   @staticmethod
-  def
-    dtype = to_dtype(dtype) if dtype is not None else dtypes.default_float
-    if isinstance(device, tuple):
-      return Tensor(UOp.multi(*[UOp.metaop(op, shape, dtype, Device.canonicalize(d), arg) for d in device], axis=None),
-                    device, dtype, **kwargs)
-    return Tensor(UOp.metaop(op, shape, dtype, Device.canonicalize(device), arg), device, dtype, **kwargs)
-
-  @staticmethod
-  def empty(*shape, **kwargs):
+  def empty(*shape, device:str|tuple[str, ...]|None=None, dtype:DTypeLike|None=None, **kwargs) -> Tensor:
     """
     Creates an empty tensor with the given shape.

@@ -440,7 +433,11 @@ class Tensor(SimpleMathTrait):
     print(t.shape)
     ```
     """
-
+    dtype, shape = to_dtype(dtype) if dtype is not None else dtypes.default_float, argfix(*shape)
+    if not isinstance(size:=prod([x.vmax if isinstance(x, UOp) else x for x in shape]), int): raise ValueError(f"size must be int {size}")
+    # TODO: add test for multidevice tensor
+    device = tuple(Device.canonicalize(d) for d in device) if isinstance(device, tuple) else Device.canonicalize(device)
+    return Tensor(UOp.new_buffer(device, size, dtype), device, dtype, **kwargs).reshape(shape)

   @staticmethod
   def from_blob(ptr:int, shape:tuple[int, ...], **kwargs) -> Tensor:
@@ -451,21 +448,21 @@ class Tensor(SimpleMathTrait):
     You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
     Additionally, all other keyword arguments are passed to the constructor of the tensor.
     """
-
-
-    r.
+    r = Tensor.empty(*shape, **kwargs)
+    assert isinstance(r.device, str)
+    cast(Buffer, r.uop.buffer).allocate(external_ptr=ptr)
     return r

   @staticmethod
   def from_url(url:str, gunzip:bool=False, **kwargs) -> Tensor:
     """
-
+    Creates a Tensor from a URL.

     This is the preferred way to access Internet resources.
     It currently returns a DISK Tensor, but in the future it may return an HTTP Tensor.
     This also will soon become lazy (when possible) and not print progress without DEBUG.

-
+    The `gunzip` flag will gzip extract the resource and return an extracted Tensor.
     """
     return Tensor(fetch(url, gunzip=gunzip), **kwargs)

@@ -473,7 +470,7 @@ class Tensor(SimpleMathTrait):
   _device_seeds: dict[str, Tensor] = {}
   _device_rng_counters: dict[str, Tensor] = {}
   @staticmethod
-  def manual_seed(seed=0):
+  def manual_seed(seed=0) -> None:
     """
     Sets the seed for random operations.

@@ -491,14 +488,14 @@ class Tensor(SimpleMathTrait):
     Tensor._seed, Tensor._device_seeds, Tensor._device_rng_counters = seed, {}, {}

   @staticmethod
-  def _threefry_random_bits(key:Tensor, counts0:Tensor, counts1:Tensor):
+  def _threefry_random_bits(key:Tensor, counts0:Tensor, counts1:Tensor) -> Tensor:
     x = (counts1.cast(dtypes.uint64) << 32) | counts0.cast(dtypes.uint64)
     x = x._apply_uop(UOp.threefry, (key[1]._broadcast_to(x.shape).cast(dtypes.uint64) << 32) | key[0]._broadcast_to(x.shape).cast(dtypes.uint64))
     counts0, counts1 = (x & 0xffffffff).cast(dtypes.uint32), ((x >> 32) & 0xffffffff).cast(dtypes.uint32)
     return counts0.cat(counts1)

   @staticmethod
-  def rand(*shape, device:
+  def rand(*shape, device:str|None=None, dtype:DTypeLike|None=None, contiguous:bool=True, **kwargs) -> Tensor:
     """
     Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[0, 1)`.

@@ -514,26 +511,24 @@ class Tensor(SimpleMathTrait):
     if not dtypes.is_float(dtype := to_dtype(dtype or dtypes.default_float)): raise ValueError(f"rand only supports float dtypes, got {dtype}")
     if not all_int(shape:=argfix(*shape)) or not all(s >= 0 for s in shape): raise ValueError(f"invalid input {shape=}")
     if device is not None and not isinstance(device, str): raise ValueError(f"rand only supports single device, got {device=}")
-
+    device = Device.canonicalize(device)

     # if shape has 0, return zero tensor
-    if (numel := prod(shape)) == 0: return Tensor.zeros(shape, device=
+    if (numel := prod(shape)) == 0: return Tensor.zeros(shape, device=device, dtype=dtype, **kwargs)
     num = ceildiv(numel * dtype.itemsize, 4)

-    # when using MOCKGPU and NV generate rand on CPU
-    if getenv("MOCKGPU") and device.startswith("NV"): device = "CPU"
-
     # generate per device seeds and rng counter if we haven't seen this device yet
     if device not in Tensor._device_seeds:
       Tensor._device_seeds[device] = Tensor(
         [int.from_bytes(hashlib.sha256(len(Tensor._device_seeds).to_bytes(4, "big")).digest(), "big"), Tensor._seed],
         device=device, dtype=dtypes.uint32, requires_grad=False)
-      Tensor._device_rng_counters[device] = Tensor([
+      Tensor._device_rng_counters[device] = Tensor([num], device=device, dtype=dtypes.uint32, requires_grad=False)
     # increment rng counter for devices
     else: Tensor._device_rng_counters[device].assign(Tensor._device_rng_counters[device] + num).contiguous()

     # threefry random bits
-
+    bits_count = Tensor._device_rng_counters[device] - num
+    counts0 = (Tensor.arange(ceildiv(num, 2), device=device, dtype=dtypes.uint32, requires_grad=False)+bits_count)
     counts1 = counts0 + ceildiv(num, 2)
     bits = Tensor._threefry_random_bits(Tensor._device_seeds[device], counts0, counts1)[:num]

@@ -545,12 +540,7 @@ class Tensor(SimpleMathTrait):
     one = Tensor.ones_like(bits, device=bits.device, dtype=dtype).bitcast(uint_dtype)
     bits = bits.rshift((dtype.itemsize * 8) - nmant).bitwise_or(one)
     # bitcast back to the original dtype and reshape
-    out = bits.bitcast(dtype)[:numel].sub(1).reshape(shape)
-
-    # move back to the original device if we were using MOCKGPU
-    if getenv("MOCKGPU") and _device: out = out.to(_device)
-
-    out.requires_grad = kwargs.get("requires_grad")
+    out = bits.bitcast(dtype)[:numel].sub(1).reshape(shape).requires_grad_(kwargs.get("requires_grad"))
     return out.contiguous() if contiguous else out

   # ***** creation helper functions *****
@@ -638,7 +628,7 @@ class Tensor(SimpleMathTrait):
     return (Tensor.full((output_len,), step, dtype=dtype, **kwargs)._cumalu(0, Ops.ADD) + (start - step)).cast(dtype)

   @staticmethod
-  def linspace(start:
+  def linspace(start:int|float, stop:int|float, steps:int, **kwargs) -> Tensor:
     """
     Returns a 1-D tensor of `steps` evenly spaced values from `start` to `stop`, inclusive.

@@ -658,7 +648,7 @@ class Tensor(SimpleMathTrait):
     return (start + Tensor.arange(steps, **kwargs) * ((stop - start) / (steps - 1))).cast(dtype)

   @staticmethod
-  def eye(n:int, m:
+  def eye(n:int, m:int|None=None, **kwargs) -> Tensor:
     """
     Returns a 2-D tensor with `n` rows and `m` columns, with ones on the diagonal and zeros elsewhere.

@@ -674,7 +664,7 @@ class Tensor(SimpleMathTrait):
     ```
     """
     if n < 0 or (m is not None and m < 0): raise ValueError(f"cannot have negative {n=}, {m=}")
-    x = Tensor.ones(
+    x = Tensor.ones(n, **kwargs).diag()
     return x if m is None else x.pad((None, (0, m-n))) if m > n else x.shrink((None, (0, m)))

   def full_like(self, fill_value:ConstType, **kwargs) -> Tensor:
@@ -735,17 +725,34 @@ class Tensor(SimpleMathTrait):
     dtype = kwargs.pop("dtype", self.dtype)
     if isinstance(self.device, tuple):
       if kwargs.get("device") is not None: raise RuntimeError("cannot specify `device` on `rand_like` of a multi device tensor")
-      if self.
+      if self.uop.axis is None: return Tensor.rand(*self.shape, dtype=dtype, **kwargs).shard(self.device)
       contiguous = kwargs.pop("contiguous", True)
-      sharded_shape = tuple(s//len(self.device) if a==self.
-      rands =
-
+      sharded_shape = tuple(s//len(self.device) if a==self.uop.axis else s for a,s in enumerate(self.shape))
+      rands = UOp(Ops.MSTACK, dtype=dtype,
+                  src=tuple([Tensor.rand(sharded_shape, device=d, dtype=dtype, contiguous=contiguous, **kwargs).uop for d in self.device]))
+      return Tensor(UOp.multi(rands, axis=self.uop.axis), device=self.device, dtype=dtype, **kwargs)
     return Tensor.rand(*self.shape, device=kwargs.pop("device", self.device), dtype=dtype, **kwargs)

   # ***** rng hlops *****

+  def randn_like(self, dtype:DTypeLike|None=None, requires_grad:bool|None=None, **kwargs) -> Tensor:
+    """
+    Creates a tensor with the same shape and sharding as `self`, filled with random values from a normal distribution with mean 0 and variance 1.
+
+    You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
+    Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.ones(2, 3)
+    print(Tensor.randn_like(t).numpy())
+    ```
+    """
+    src = self.stack(self).rand_like(**{**kwargs, "dtype": dtypes.float32})
+    # https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
+    return (src[0].mul(2*math.pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(dtype or self.dtype)).requires_grad_(requires_grad)
+
   @staticmethod
-  def randn(*shape, dtype:
+  def randn(*shape, dtype:DTypeLike|None=None, requires_grad:bool|None=None, **kwargs) -> Tensor:
     """
     Creates a tensor with the given shape, filled with random values from a normal distribution with mean `0` and standard deviation `1`.
     If `dtype` is not specified, the default type is used.
@@ -758,9 +765,7 @@ class Tensor(SimpleMathTrait):
     print(Tensor.randn(2, 3).numpy())
     ```
     """
-
-    src = Tensor.rand((2, *argfix(*shape)), **{**kwargs, "dtype": dtypes.float32})
-    return (src[0].mul(2*math.pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(dtype or dtypes.default_float)).requires_grad_(requires_grad)
+    return Tensor.empty(*shape, **kwargs).randn_like(dtype=dtype, requires_grad=requires_grad)

   @staticmethod
   def randint(*shape, low=0, high=10, dtype=dtypes.int32, **kwargs) -> Tensor:
@@ -782,7 +787,7 @@ class Tensor(SimpleMathTrait):
     return Tensor.uniform(*shape, low=low, high=high, dtype=dtype, **kwargs)

   @staticmethod
-  def normal(*shape, mean=0.0, std=1.0, requires_grad:
+  def normal(*shape, mean=0.0, std=1.0, requires_grad:bool|None=None, **kwargs) -> Tensor:
     """
     Creates a tensor with the given shape, filled with random values from a normal distribution with the given `mean` and standard deviation `std`.

@@ -797,7 +802,7 @@ class Tensor(SimpleMathTrait):
     return ((std * Tensor.randn(*shape, **kwargs)) + mean).requires_grad_(requires_grad)

   @staticmethod
-  def uniform(*shape, low=0.0, high=1.0, dtype:
+  def uniform(*shape, low=0.0, high=1.0, dtype:DTypeLike|None=None, requires_grad:bool|None=None, **kwargs) -> Tensor:
     """
     Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[low, high)`.

@@ -877,7 +882,29 @@ class Tensor(SimpleMathTrait):
     std = math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(argfix(*shape)[1:]))
     return Tensor.normal(*shape, mean=0.0, std=std, **kwargs)

+  @staticmethod
+  def randperm(n:int, device=None, dtype=dtypes.int32, **kwargs) -> Tensor:
+    """
+    Returns a tensor with a random permutation of integers from `0` to `n-1`.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    Tensor.manual_seed(42)
+    print(Tensor.randperm(6).numpy())
+    ```
+    """
+    return Tensor.rand(n, device=device, **kwargs).argsort().cast(dtype)
+
   def multinomial(self:Tensor, num_samples:int = 1, replacement:bool = False) -> Tensor:
+    """
+    Returns a tensor with `num_samples` indices sampled from a multinomial distribution weighted by `self`.
+
+    NOTE: `replacement=False` for `num_samples > 1` is not supported yet.
+    ```python exec="true" source="above" session="tensor" result="python"
+    Tensor.manual_seed(42)
+    t = Tensor([1, 2, 3, 4])
+    print(t.multinomial(20, replacement=True).numpy())
+    ```
+    """
     assert 1 <= self.ndim <= 2 and num_samples > 0, f"{self.ndim=} must be 1 or 2 dim, {num_samples=} must be positive"
     assert replacement or num_samples == 1, "no replacement only supports num_samples = 1"
     weight = self.unsqueeze(0) if self.ndim == 1 else self
@@ -888,9 +915,9 @@ class Tensor(SimpleMathTrait):

   # ***** toposort and backward pass *****

-  def gradient(self, *targets:Tensor, gradient:
+  def gradient(self, *targets:Tensor, gradient:Tensor|None=None, materialize_grads=False) -> list[Tensor]:
     """
-
+    Computes the gradient of the targets with respect to self.

     ```python exec="true" source="above" session="tensor" result="python"
     x = Tensor.eye(3)
@@ -903,21 +930,20 @@ class Tensor(SimpleMathTrait):
     ```
     """
     assert gradient is not None or self.shape == tuple(), "when no gradient is provided, backward must be called on a scalar tensor"
+    if not (self.is_floating_point() and all(t.is_floating_point() for t in targets)): raise RuntimeError("only float Tensors have gradient")
     if gradient is None: gradient = Tensor(1.0, dtype=self.dtype, device=self.device, requires_grad=False)
-
-
-    grads = compute_gradient(self.lazydata, gradient.lazydata, set(target_uops))
+    target_uops = [x.uop for x in targets]
+    grads = compute_gradient(self.uop, gradient.uop, set(target_uops))
     ret = []
     for x in target_uops:
       if (y:=grads.get(x)) is None:
         if materialize_grads: y = x.const_like(0)
-        else: raise RuntimeError(f"{x}\n\nnot found in\n\n{self.
+        else: raise RuntimeError(f"{x}\n\nnot found in\n\n{self.uop}")
       ret.append(y)
-      rets.append(ret)
     # create returned Tensors
-    return [Tensor(u, device=t.device) for t,u in zip(targets,
+    return [Tensor(u, device=t.device) for t,u in zip(targets, ret)]

-  def backward(self, gradient:
+  def backward(self, gradient:Tensor|None=None) -> Tensor:
     """
     Propagates the gradient of a tensor backwards through the computation graph.
     If the 'gradient' argument is not provided, the tensor must be a scalar, and the gradient is implicitly set to 1.0.
@@ -927,9 +953,9 @@ class Tensor(SimpleMathTrait):
     print(t.grad.numpy())
     ```
     """
-    all_uops = self.
+    all_uops = self.uop.toposort()
     tensors_need_grad: list[Tensor] = [t for tref in all_tensors if (t:=tref()) is not None and \
-      t.
+      t.uop in all_uops and t.requires_grad]
     # clear contexts
     for t,g in zip(tensors_need_grad, self.gradient(*tensors_need_grad, gradient=gradient, materialize_grads=True)):
       assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"
@@ -938,9 +964,9 @@ class Tensor(SimpleMathTrait):

   # ***** movement low level ops *****

-  def view(self, *
+  def view(self, shape:tuple[sint, ...], *args) -> Tensor:
     """`.view` is an alias for `.reshape`."""
-    return self.reshape(shape)
+    return self.reshape(shape, *args)

   def reshape(self, shape, *args) -> Tensor:
     """
@@ -981,11 +1007,11 @@ class Tensor(SimpleMathTrait):
     `order` can be passed as a tuple or as separate arguments.

     ```python exec="true" source="above" session="tensor" result="python"
-    t = Tensor.
-    print(t.
+    t = Tensor.empty(2, 3, 5)
+    print(t.shape)
     ```
     ```python exec="true" source="above" session="tensor" result="python"
-    print(t.permute(
+    print(t.permute(2, 0, 1).shape)
     ```
     """
     order_arg = tuple(self._resolve_dim(x) for x in argfix(order, *args))
@@ -1012,7 +1038,7 @@ class Tensor(SimpleMathTrait):
     if len(axis_arg) != len(dedup(axis_arg)): raise RuntimeError(f"dim can appear at most once, getting {axis_arg}")
     return self._apply_uop(UOp.flip, arg=tuple([i in axis_arg for i in range(len(self.shape))]))

-  def shrink(self, arg:tuple[
+  def shrink(self, arg:tuple[tuple[sint, sint]|None, ...]) -> Tensor:
     """
     Returns a tensor that shrinks the each axis based on input arg.
     `arg` must have the same length as `self.ndim`.
@@ -1032,7 +1058,7 @@ class Tensor(SimpleMathTrait):
     if (shrink_arg:=[x if x is not None else (0,s) for x,s in zip(arg, self.shape)]) == [(0,s) for s in self.shape]: return self
     return self._apply_uop(UOp.shrink, arg=tuple(shrink_arg))

-  def pad(self, padding:
+  def pad(self, padding:Sequence[sint]|Sequence[tuple[sint, sint]|None], mode:str="constant", value:float=0.0) -> Tensor:
     """
     Returns a tensor with padding applied based on the input `padding`.

@@ -1070,11 +1096,11 @@ class Tensor(SimpleMathTrait):
       if len(padding)%2 != 0: raise ValueError("Flat padding must have even number of pads")
       pX = _flat_to_grouped(tuple(cast(Sequence[sint], padding)) + (0,0)*(self.ndim - len(padding)//2))
     # group padding
-    else: pX = tuple((0,0) if p is None else p for p in cast(Sequence[
+    else: pX = tuple((0,0) if p is None else p for p in cast(Sequence[tuple[sint, sint]|None], padding))
     if len(pX) != self.ndim: raise ValueError(f"padding length is improper, {padding=} {self.ndim=}")
     X, pads = self, tuple((smax(pB,0), smax(pA,0)) for pB,pA in pX)
     if mode == "constant":
-      def _constant(x:Tensor,px,v):
+      def _constant(x:Tensor,px,v) -> Tensor:
         return x._apply_uop(UOp.pad, arg=px) if v == 0 else (x._apply_uop(UOp.pad, arg=px)+Tensor.ones_like(x)._apply_uop(UOp.pad, arg=px).where(0,v))
       return _constant(X, pX, value) if all(resolve(p >= 0) for p in flatten(pX)) else \
         _constant(X.shrink(tuple((-smin(pB,0),smin(pA+s,s)) for (pB,pA),s in zip(pX, X.shape))), pads, value)
@@ -1097,16 +1123,17 @@ class Tensor(SimpleMathTrait):

   # ***** movement high level ops *****

-  def _getitem(self, indices, v:
+  def _getitem(self, indices, v: Tensor|None = None) -> Tensor:
     # wrap single index into a list
     if (isinstance(indices, list) and all_int(indices)) or not isinstance(indices, (tuple, list)): indices = [indices]
     x, indices = self, list(indices)

-    #
+    # fill ellipsis or rest of indices with slice(None)
     if len(ellipsis_idx := [dim for dim, i in enumerate(indices) if i is Ellipsis]) > 1: raise IndexError("indices can only have a single ellipsis")
-
+    # NOTE: None adds a dim later
     num_indices = len(indices) - len(ellipsis_idx) - sum(1 for i in indices if i is None)
     if num_indices > self.ndim: raise IndexError(f"too many {num_indices=} for {self.ndim=}")
+    fill_idx = ellipsis_idx[0] if ellipsis_idx else len(indices)
     indices[fill_idx:fill_idx+1] = [slice(None)] * (self.ndim - num_indices)

     indices_parsed, dim = [], 0
@@ -1114,22 +1141,32 @@ class Tensor(SimpleMathTrait):
       size = 1 if index is None else self.shape[dim]
       boundary, stride = [0, size], 1 # defaults
       match index:
-        case
-          if not isinstance(index, Tensor): index = Tensor(index, self.device, requires_grad=False)
+        case Tensor():
           if not dtypes.is_int(index.dtype): raise IndexError(f"index dtype {index.dtype} is not supported")
-          index = (index
+          index = (index < 0).where(index+size, index).to(self.device) # treat negative index values
+        case list() | tuple():
+          if not dtypes.is_int((ti:=Tensor(index)).dtype): raise IndexError(f"{index=} contains non-int element")
+          index = Tensor([i+size if i<0 else i for i in fully_flatten(index)], self.device, requires_grad=False).reshape(ti.shape)
         case int() | UOp(): # sint
           if index >= size or index < -size: raise IndexError(f"{index=} is out of bounds with {size=}")
+          # TODO: is this right for (negative) symbolic?
           boundary = [index, index+1] if index >= 0 else [index+size, index+size+1]
         case slice():
           if index.step == 0: raise ValueError(f"{index=} cannot have 0 as step")
-
-
-
-          if
-
-
-
+          start, stop = 0 if index.start is None else index.start, size if index.stop is None else index.stop
+          step = 1 if index.step is None else index.step
+          boundary, stride = [start, stop], step
+          if all(isinstance(s, int) for s in (start,stop,step)):
+            # handle int slicing
+            *boundary, stride = index.indices(cast(SupportsIndex, size))
+            if stride * (boundary[1] - boundary[0]) < 0: boundary = [0, 0]
+            elif stride < 0: boundary = [boundary[1] + 1, boundary[0] + 1]
+            # update size for slice
+            size = ceildiv((boundary[1] - boundary[0]), abs(stride))
+          elif resolve(step == 1, False) and all(isinstance(s,sint) for s in (start, stop)) and resolve((stop-start) > 0, False):
+            # simple symbolic slice
+            size = cast(sint, cast(UOp, (stop - start)).ssimplify())
+          else: raise TypeError(f"slice {index=} is not supported")
         case None: pass # do nothing
         case _: raise IndexError(f"{type(index).__name__} indexing is not supported")
       indices_parsed.append({"index":index, "size":size, "boundary":tuple(boundary), "stride":stride})
@@ -1140,9 +1177,9 @@ class Tensor(SimpleMathTrait):
     # flip negative strides
     shrinks, strides = zip(*((i['boundary'], i['stride']) for i in mops))
     x = x.shrink(shrinks).flip(tuple(i for i,st in enumerate(strides) if st < 0))
-
-
-
+    strides = tuple(map(abs, strides))
+    # apply stride
+    if any(st != 1 for st in strides):
       # pad shape to multiple of stride
       if not all_int(x.shape): raise RuntimeError("symbolic shape not supported")
       x = x.pad(tuple((0, round_up(s, st) - s) for s, st in zip(x.shape, strides)))
@@ -1150,7 +1187,7 @@ class Tensor(SimpleMathTrait):
       x = x.shrink(tuple(flatten(((0, s), (0, 1)) for s in x.shape[::2]))).reshape(x.shape[::2])

     # dim injection from None by including None dim size (which is 1) and dim collapse by skipping int dim size
-    x = x.reshape(tuple(index['size'] for index in indices_parsed if not isinstance(index['index'], int)))
+    x = x.reshape(tuple(index['size'] for index in indices_parsed if not isinstance(index['index'], (int, UOp))))

     # tensor indexing
     if tops := [(d,i) for d,i in enumerate(i_ for i_ in indices_parsed if not isinstance(i_['index'], int)) if isinstance(i['index'], Tensor)]:
@@ -1170,7 +1207,7 @@ class Tensor(SimpleMathTrait):
       # inject 1's for the extra dims added in create masks
       reshape_arg = x.shape[:dims[0]] + (1,) * len(big_shape) + x.shape[dims[0]:]
       # sum reduce the extra dims introduced in create masks
-      x = (x.reshape(reshape_arg) * mask).sum(sum_axis:=tuple(d + len(big_shape) for d in dims),
+      x = (x.reshape(reshape_arg) * mask).sum(sum_axis:=tuple(d + len(big_shape) for d in dims), dtype=x.dtype)

       # special permute case
       if dims[0] != 0 and len(dims) != 1 and tuple(dims) != tuple(range(dims[0], dims[-1]+1)):
@@ -1188,7 +1225,7 @@ class Tensor(SimpleMathTrait):

   def __getitem__(self, indices) -> Tensor:
     """
-
+    Retrieves a sub-tensor using indexing.

     Supported Index Types: `int | slice | Tensor | None | list | tuple | Ellipsis`

@@ -1226,19 +1263,19 @@ class Tensor(SimpleMathTrait):
     """
     return self._getitem(indices)

-  def __setitem__(self, indices, v:
+  def __setitem__(self, indices, v:Tensor|ConstType) -> None:
     if isinstance(self.device, str) and self.device.startswith("DISK"):
-      self._getitem(indices).assign(v)
+      self.realize()._getitem(indices).assign(v)
       return
     # NOTE: check that setitem target is valid first
-    if not unwrap(self.
+    if not unwrap(self.uop.st).contiguous: raise RuntimeError("setitem target needs to be contiguous")
     if isinstance(v, get_args(ConstType)): v = Tensor(v, device=self.device, dtype=self.dtype)
     if not isinstance(v, Tensor): raise TypeError(f"can't set a {type(v).__name__} to a Tensor")
     if self.requires_grad or v.requires_grad: raise NotImplementedError("setitem with requires_grad is not supported")

     res = self.realize()._getitem(indices, v)
     # if shapes match and data is not shared it's a copy and we assign to self
-    if res.shape == self.shape and res.
+    if res.shape == self.shape and res.uop is not self.uop:
       self.assign(res).realize()
     else: # no copy, basic setitem
       v = v.cast(res.dtype)._broadcast_to(_broadcast_shape(res.shape, v.shape)).contiguous()
@@ -1261,7 +1298,7 @@ class Tensor(SimpleMathTrait):
     assert all(s >= i for d,(s,i) in enumerate(zip(self.shape, index.shape)) if d != dim), "requires self.shape[d] >= index.shape[d] for all d != dim"
     index = index.to(self.device)
     x = self.shrink(tuple((0, i) if d != dim else None for d,i in enumerate(index.shape))).unsqueeze(-1).transpose(-1, dim)
-    return (x * index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim])).sum(-1,
+    return (x * index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim])).sum(-1, dtype=self.dtype)

   def cat(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
     """
@@ -1296,11 +1333,11 @@ class Tensor(SimpleMathTrait):
     ```
     """
     # checks for shapes and number of dimensions delegated to cat
-    return Tensor.cat(*[t.unsqueeze(dim) for t in
+    return Tensor.cat(*[t.unsqueeze(dim) for t in argfix(self, *args)], dim=dim)

-  def repeat_interleave(self, repeats:int, dim:
+  def repeat_interleave(self, repeats:int, dim:int|None=None) -> Tensor:
|
1302
1339
|
"""
|
1303
|
-
|
1340
|
+
Repeats elements of a tensor.
|
1304
1341
|
|
1305
1342
|
```python exec="true" source="above" session="tensor" result="python"
|
1306
1343
|
t = Tensor([1, 2, 3])
|
@@ -1336,7 +1373,7 @@ class Tensor(SimpleMathTrait):
|
|
1336
1373
|
if not -max(1, total) <= dim <= max(1, total)-1: raise IndexError(f"{dim=} out of range {[-max(1, total), max(1, total)-1]}")
|
1337
1374
|
return dim + total if dim < 0 else dim
|
1338
1375
|
|
1339
|
-
def split(self, sizes:
|
1376
|
+
def split(self, sizes:int|Sequence[int], dim:int=0) -> tuple[Tensor, ...]:
|
1340
1377
|
"""
|
1341
1378
|
Splits the tensor into chunks along the dimension specified by `dim`.
|
1342
1379
|
If `sizes` is an integer, it splits into equally sized chunks if possible, otherwise the last chunk will be smaller.
|
@@ -1385,7 +1422,31 @@ class Tensor(SimpleMathTrait):
|
|
1385
1422
|
dim = self._resolve_dim(dim)
|
1386
1423
|
return list(self.split(ceildiv(self.shape[dim], chunks) if self.shape[dim] else [0]*chunks, dim=dim))
|
1387
1424
|
|
1388
|
-
def
|
1425
|
+
def unfold(self, dim:int, size:sint, step:int) -> Tensor:
|
1426
|
+
"""
|
1427
|
+
Unfolds the tensor along dimension `dim` into overlapping windows.
|
1428
|
+
|
1429
|
+
Each window has length `size` and begins every `step` elements of `self`.
|
1430
|
+
Returns the input tensor with dimension `dim` replaced by dims `(n_windows, size)`
|
1431
|
+
where `n_windows = (self.shape[dim] - size) // step + 1`.
|
1432
|
+
|
1433
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1434
|
+
unfolded = Tensor.arange(8).unfold(0,2,2)
|
1435
|
+
print("\\n".join([repr(x.numpy()) for x in unfolded]))
|
1436
|
+
```
|
1437
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1438
|
+
unfolded = Tensor.arange(27).reshape(3,3,3).unfold(-1,2,3)
|
1439
|
+
print("\\n".join([repr(x.numpy()) for x in unfolded]))
|
1440
|
+
```
|
1441
|
+
"""
|
1442
|
+
if size < 0: raise RuntimeError(f'size must be >= 0 but got {size=}')
|
1443
|
+
if step <= 0: raise RuntimeError(f'step must be > 0 but got {step=}')
|
1444
|
+
if size > self.shape[dim]: raise RuntimeError(f'maximum size for tensor at dimension {dim} is {self.shape[dim]} but size is {size}')
|
1445
|
+
dim = self._resolve_dim(dim)
|
1446
|
+
perm_to_last = tuple(i for i in range(self.ndim) if i != dim) + (dim,)
|
1447
|
+
return self.permute(perm_to_last)._pool((size,), step).permute(argsort(perm_to_last) + (self.ndim,))
|
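The window count in the docstring above, `n_windows = (self.shape[dim] - size) // step + 1`, can be checked against ordinary Python slicing. The snippet below is an editor's sketch using only the public `Tensor` API shown in this diff; the input length, `size`, and `step` values are chosen here for illustration.

```python
from tinygrad import Tensor

t = Tensor.arange(8)
u = t.unfold(0, size=2, step=3)               # windows starting at 0, 3, 6
n_windows = (8 - 2) // 3 + 1                  # 3, per the docstring formula
assert u.shape == (n_windows, 2)

# reference windows via ordinary list slicing
ref = [list(range(8))[i*3:i*3+2] for i in range(n_windows)]
assert u.tolist() == ref                      # [[0, 1], [3, 4], [6, 7]]
```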
1448
|
+
|
1449
|
+
def meshgrid(self:Tensor, *args:Tensor, indexing:Literal["ij", "xy"]="ij") -> tuple[Tensor, ...]:
|
1389
1450
|
"""
|
1390
1451
|
Generates coordinate matrices from coordinate vectors.
|
1391
1452
|
Input tensors can be scalars or 1D tensors.
|
@@ -1412,7 +1473,7 @@ class Tensor(SimpleMathTrait):
|
|
1412
1473
|
output_shape = _broadcast_shape(*(t.shape for t in tensors))
|
1413
1474
|
return tuple(t._broadcast_to(output_shape) for t in tensors)
|
1414
1475
|
|
1415
|
-
def squeeze(self, dim:
|
1476
|
+
def squeeze(self, dim:int|None=None) -> Tensor:
|
1416
1477
|
"""
|
1417
1478
|
Returns a tensor with specified dimensions of input of size 1 removed.
|
1418
1479
|
If `dim` is not specified, all dimensions with size 1 are removed.
|
@@ -1469,7 +1530,7 @@ class Tensor(SimpleMathTrait):
|
|
1469
1530
|
order[dim0], order[dim1] = order[dim1], order[dim0]
|
1470
1531
|
return self.permute(order)
|
1471
1532
|
|
1472
|
-
def flatten(self, start_dim=0, end_dim=-1):
|
1533
|
+
def flatten(self, start_dim=0, end_dim=-1) -> Tensor:
|
1473
1534
|
"""
|
1474
1535
|
Flattens the tensor by reshaping it into a one-dimensional tensor.
|
1475
1536
|
If `start_dim` or `end_dim` are passed, only dimensions starting with `start_dim` and ending with `end_dim` are flattened.
|
@@ -1485,7 +1546,7 @@ class Tensor(SimpleMathTrait):
|
|
1485
1546
|
start_dim, end_dim = self._resolve_dim(start_dim), self._resolve_dim(end_dim)
|
1486
1547
|
return self.reshape(self.shape[:start_dim] + (prod(self.shape[start_dim:end_dim+1]), ) + self.shape[end_dim+1:])
|
1487
1548
|
|
1488
|
-
def unflatten(self, dim:int, sizes:tuple[int,...]):
|
1549
|
+
def unflatten(self, dim:int, sizes:tuple[int,...]) -> Tensor:
|
1489
1550
|
"""
|
1490
1551
|
Unflattens dimension `dim` of the tensor into multiple dimensions specified by `sizes`. `Tensor.flatten()` is the inverse of this function.
|
1491
1552
|
|
@@ -1502,7 +1563,33 @@ class Tensor(SimpleMathTrait):
|
|
1502
1563
|
dim = self._resolve_dim(dim)
|
1503
1564
|
return self.reshape(self.shape[:dim] + sizes + self.shape[dim+1:])
|
1504
1565
|
|
1505
|
-
def
|
1566
|
+
def diag(self) -> Tensor:
|
1567
|
+
"""
|
1568
|
+
Returns a 2-D square tensor with the elements of the input as the main diagonal.
|
1569
|
+
|
1570
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1571
|
+
print(Tensor([1, 2, 3]).diag().numpy())
|
1572
|
+
```
|
1573
|
+
"""
|
1574
|
+
if self.ndim != 1: raise ValueError(f"expect input to be 1-D, getting {self.ndim}-D")
|
1575
|
+
return self.unsqueeze(-1).pad((None,(0,n:=self.shape[0]))).flatten().shrink(((0,n*n),)).reshape(n,n)
|
1576
|
+
|
1577
|
+
def diagonal(self) -> Tensor:
|
1578
|
+
"""
|
1579
|
+
Returns a view of the input tensor containing its main diagonal elements.
|
1580
|
+
|
1581
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1582
|
+
t = Tensor.arange(9).reshape(3, 3)
|
1583
|
+
print(t.numpy())
|
1584
|
+
```
|
1585
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1586
|
+
print(t.diagonal().numpy())
|
1587
|
+
```
|
1588
|
+
"""
|
1589
|
+
if self.ndim != 2 or (n:=self.shape[0]) != self.shape[1]: raise ValueError(f"only 2-D square tensor is supported, getting {self.shape=}")
|
1590
|
+
return self.flatten().pad(((0, n))).reshape(n, n+1)[:, 0]
|
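Both `diag` and `diagonal` lean on the same off-by-one layout trick: padding a flattened n x n buffer with n extra elements makes each logical row one element longer, so consecutive diagonal entries sit exactly one (n+1)-stride apart. A plain-Python sketch of the `diagonal` direction, with a hypothetical 4x4 example and no tinygrad dependency:

```python
# off-by-one stride trick behind diag/diagonal, in plain Python
n = 4
flat = list(range(n * n))                       # flattened n x n matrix
padded = flat + [0] * n                         # pad to length n*(n+1)
rows = [padded[i*(n+1):(i+1)*(n+1)] for i in range(n)]
diag = [row[0] for row in rows]                 # column 0 == main diagonal
assert diag == [i * (n + 1) for i in range(n)]  # 0, 5, 10, 15
```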
1591
|
+
|
1592
|
+
def roll(self, shifts:int|tuple[int, ...], dims:int|tuple[int, ...]|None=None) -> Tensor:
|
1506
1593
|
"""
|
1507
1594
|
Rolls the tensor along specified dimension(s).
|
1508
1595
|
The rolling operation is circular, meaning that elements that go beyond the edge are wrapped around to the beginning of the dimension.
|
@@ -1515,12 +1602,11 @@ class Tensor(SimpleMathTrait):
|
|
1515
1602
|
print(t.roll(shifts=-1, dims=0).numpy())
|
1516
1603
|
```
|
1517
1604
|
"""
|
1518
|
-
dims
|
1519
|
-
|
1520
|
-
|
1521
|
-
|
1522
|
-
|
1523
|
-
return rolled
|
1605
|
+
if dims is None: return self.flatten().roll(shifts, 0).reshape(self.shape)
|
1606
|
+
dims, shifts, slices = tuple(self._resolve_dim(d) for d in make_tuple(dims, 1)), make_tuple(shifts, 1), [slice(None)] * self.ndim
|
1607
|
+
if len(dims) != len(shifts): raise RuntimeError(f"{len(dims)=} != {len(shifts)=}")
|
1608
|
+
for dim, shift in zip(dims, shifts): slices[dim] = slice(delta:=self.shape[dim]-shift%self.shape[dim], delta+self.shape[dim])
|
1609
|
+
return self.repeat(*tuple(2 if i in dims else 1 for i in range(self.ndim)))[slices]
|
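The rewritten `roll` duplicates each rolled axis once with `repeat` and then slices a window of the original length starting at `shape[dim] - shift % shape[dim]`. A one-dimensional plain-Python sketch of that idea; the data and shift are illustrative values, not taken from the release:

```python
# roll by repeat-and-slice: double the axis, take a shifted window
data, shift = [0, 1, 2, 3, 4], 2
n = len(data)
doubled = data + data                 # repeat(2) along the rolled dim
start = n - shift % n                 # the `delta` in the diff above
assert doubled[start:start + n] == [3, 4, 0, 1, 2]
```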
1524
1610
|
|
1525
1611
|
def rearrange(self, formula:str, **sizes) -> Tensor:
|
1526
1612
|
"""
|
@@ -1562,22 +1648,61 @@ class Tensor(SimpleMathTrait):
|
|
1562
1648
|
t = t.permute([lhs.index(name) for name in rhs])
|
1563
1649
|
return functools.reduce(lambda x, dims: x.flatten(dims[0], dims[1] - 1) if dims[0]<dims[1] else x.unsqueeze(dims[0]), reversed(flatten_dims), t)
|
1564
1650
|
|
1651
|
+
def masked_select(self, mask):
|
1652
|
+
"""
|
1653
|
+
Selects elements from `self` based on the boolean `mask`.
|
1654
|
+
|
1655
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1656
|
+
t = Tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
|
1657
|
+
mask = Tensor([[True, False, True], [False, True, False], [False, False, True]])
|
1658
|
+
print(t.numpy())
|
1659
|
+
print(mask.numpy())
|
1660
|
+
```
|
1661
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1662
|
+
print(t.masked_select(mask).numpy())
|
1663
|
+
```
|
1664
|
+
"""
|
1665
|
+
if not dtypes.is_bool(mask.dtype): raise RuntimeError(f"masked_select expects bool mask tensor, got {mask.dtype}")
|
1666
|
+
x, mask = self.flatten(), mask._broadcast_to(self.shape).flatten()
|
1667
|
+
mask_cumsum = mask.cumsum()
|
1668
|
+
counts = Tensor.zeros(mask_cumsum[-1].item(), dtype=dtypes.int32)
|
1669
|
+
idxs = counts.scatter(0, mask_cumsum, 1, reduce='add').cumsum()
|
1670
|
+
return x[idxs]
|
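For reference, the result of `masked_select` is simply flatten-and-filter; the cumsum/scatter construction above exists so the gather can be expressed with tensor ops, and the `.item()` call is what makes the output length data-dependent. A plain-Python restatement of the expected output, reusing the docstring's example values:

```python
t    = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
mask = [[True, False, True], [False, True, False], [False, False, True]]
flat_t    = [v for row in t    for v in row]
flat_mask = [m for row in mask for m in row]
assert [v for v, m in zip(flat_t, flat_mask) if m] == [0, 2, 4, 8]
```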
1671
|
+
|
1672
|
+
def masked_fill(self:Tensor, mask:Tensor, value:Tensor|ConstType) -> Tensor:
|
1673
|
+
"""
|
1674
|
+
Replaces `self` with `value` wherever the elements of `mask` are True.
|
1675
|
+
|
1676
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1677
|
+
t = Tensor([1, 2, 3, 4, 5])
|
1678
|
+
mask = Tensor([True, False, True, False, False])
|
1679
|
+
print(t.masked_fill(mask, -12).numpy())
|
1680
|
+
```
|
1681
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1682
|
+
t = Tensor([1, 2, 3, 4, 5])
|
1683
|
+
mask = Tensor([True, False, True, False, False])
|
1684
|
+
value = Tensor([-1, -2, -3, -4, -5])
|
1685
|
+
print(t.masked_fill(mask, value).numpy())
|
1686
|
+
```
|
1687
|
+
"""
|
1688
|
+
return mask.where(value, self)
|
1689
|
+
|
1565
1690
|
# ***** reduce ops *****
|
1566
1691
|
|
1567
|
-
def _reduce(self, op:Ops, axis:
|
1692
|
+
def _reduce(self, op:Ops, axis:int|Sequence[int]|None=None, keepdim=False) -> Tensor:
|
1568
1693
|
axis = tuple(self._resolve_dim(x) for x in (range(self.ndim) if axis is None else make_tuple(axis, 1)))
|
1569
1694
|
if self.ndim == 0: axis = ()
|
1570
1695
|
ret = self._apply_uop(UOp.r, op=op, axis=axis)
|
1571
1696
|
return ret if keepdim else ret.reshape(tuple(s for i,s in enumerate(self.shape) if i not in axis))
|
1572
1697
|
|
1573
|
-
def sum(self, axis:
|
1698
|
+
def sum(self, axis:int|Sequence[int]|None=None, keepdim=False, dtype:DTypeLike|None=None) -> Tensor:
|
1574
1699
|
"""
|
1575
1700
|
Returns the sum of the elements of the tensor along the specified axis or axes.
|
1576
1701
|
|
1577
1702
|
You can pass in `axis` and `keepdim` keyword arguments to control the axis along
|
1578
1703
|
which the sum is computed and whether the reduced dimensions are retained.
|
1579
1704
|
|
1580
|
-
You can pass in `
|
1705
|
+
You can pass in `dtype` keyword argument to control the data type of the accumulation.
|
1581
1706
|
If not specified, the accumulation data type is chosen based on the input tensor's data type.
|
1582
1707
|
|
1583
1708
|
```python exec="true" source="above" session="tensor" result="python"
|
@@ -1594,17 +1719,17 @@ class Tensor(SimpleMathTrait):
|
|
1594
1719
|
print(t.sum(axis=1).numpy())
|
1595
1720
|
```
|
1596
1721
|
"""
|
1597
|
-
ret = self.cast(sum_acc_dtype(self.dtype) if
|
1598
|
-
return ret.cast(self.dtype) if
|
1722
|
+
ret = self.cast(sum_acc_dtype(self.dtype) if dtype is None else dtype)._reduce(Ops.ADD, axis, keepdim)
|
1723
|
+
return ret.cast(self.dtype) if dtype is None and self.dtype in (dtypes.float16, dtypes.bfloat16) else ret
|
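The two lines above encode the accumulation rule: with no `dtype`, half-precision inputs accumulate in `sum_acc_dtype` and the result is cast back, while an explicit `dtype` is kept in the output. A short usage sketch, assuming the public `Tensor` and `dtypes` exports; the behavior asserted here is read off the diff, not verified against the release:

```python
from tinygrad import Tensor, dtypes

t = Tensor([1.0, 2.0, 3.0], dtype=dtypes.float16)
# default: accumulate in a wider float, cast the result back to float16
assert t.sum().dtype == dtypes.float16
# explicit dtype: the accumulation dtype is kept in the result
assert t.sum(dtype=dtypes.float32).dtype == dtypes.float32
```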
1599
1724
|
|
1600
|
-
def prod(self, axis:
|
1725
|
+
def prod(self, axis:int|Sequence[int]|None=None, keepdim=False, dtype:DTypeLike|None=None) -> Tensor:
|
1601
1726
|
"""
|
1602
1727
|
Returns the product of the elements of the tensor along the specified axis or axes.
|
1603
1728
|
|
1604
1729
|
You can pass in `axis` and `keepdim` keyword arguments to control the axis along
|
1605
1730
|
which the product is computed and whether the reduced dimensions are retained.
|
1606
1731
|
|
1607
|
-
You can pass in `
|
1732
|
+
You can pass in `dtype` keyword argument to control the data type of the accumulation.
|
1608
1733
|
If not specified, the accumulation data type is chosen based on the input tensor's data type.
|
1609
1734
|
|
1610
1735
|
```python exec="true" source="above" session="tensor" result="python"
|
@@ -1621,9 +1746,9 @@ class Tensor(SimpleMathTrait):
|
|
1621
1746
|
print(t.prod(axis=1).numpy())
|
1622
1747
|
```
|
1623
1748
|
"""
|
1624
|
-
return self.cast(
|
1749
|
+
return self.cast(dtype if dtype is not None else self.dtype)._reduce(Ops.MUL, axis, keepdim)
|
1625
1750
|
|
1626
|
-
def max(self, axis:
|
1751
|
+
def max(self, axis:int|Sequence[int]|None=None, keepdim=False) -> Tensor:
|
1627
1752
|
"""
|
1628
1753
|
Returns the maximum value of the tensor along the specified axis or axes.
|
1629
1754
|
|
@@ -1646,9 +1771,9 @@ class Tensor(SimpleMathTrait):
|
|
1646
1771
|
"""
|
1647
1772
|
return self._reduce(Ops.MAX, axis, keepdim)
|
1648
1773
|
|
1649
|
-
def _inverse(self): return -self if self.is_floating_point() else ~self if dtypes.is_int(self.dtype) else self.logical_not()
|
1774
|
+
def _inverse(self) -> Tensor: return -self if self.is_floating_point() else ~self if dtypes.is_int(self.dtype) else self.logical_not()
|
1650
1775
|
|
1651
|
-
def min(self, axis:
|
1776
|
+
def min(self, axis:int|Sequence[int]|None=None, keepdim=False) -> Tensor:
|
1652
1777
|
"""
|
1653
1778
|
Returns the minimum value of the tensor along the specified axis or axes.
|
1654
1779
|
|
@@ -1671,7 +1796,7 @@ class Tensor(SimpleMathTrait):
|
|
1671
1796
|
"""
|
1672
1797
|
return self._inverse().max(axis=axis, keepdim=keepdim)._inverse()
|
1673
1798
|
|
1674
|
-
def any(self, axis:
|
1799
|
+
def any(self, axis:int|Sequence[int]|None=None, keepdim=False) -> Tensor:
|
1675
1800
|
"""
|
1676
1801
|
Tests if any element evaluates to `True` along the specified axis or axes.
|
1677
1802
|
|
@@ -1693,7 +1818,7 @@ class Tensor(SimpleMathTrait):
|
|
1693
1818
|
"""
|
1694
1819
|
return self.bool().max(axis, keepdim)
|
1695
1820
|
|
1696
|
-
def all(self, axis:
|
1821
|
+
def all(self, axis:int|Sequence[int]|None=None, keepdim=False) -> Tensor:
|
1697
1822
|
"""
|
1698
1823
|
Tests if all elements evaluate to `True` along the specified axis or axes.
|
1699
1824
|
|
@@ -1730,14 +1855,12 @@ class Tensor(SimpleMathTrait):
|
|
1730
1855
|
print(Tensor([float('nan')]).isclose(Tensor([float('nan')]), equal_nan=True).numpy())
|
1731
1856
|
```
|
1732
1857
|
"""
|
1733
|
-
|
1734
|
-
def isfinite(t): return (t.isinf()|t.isnan()).logical_not()
|
1735
|
-
is_finite_close = isfinite(self) & isfinite(other) & ((self - other).abs() <= atol + rtol * other.abs())
|
1858
|
+
is_finite_close = self.isfinite() & other.isfinite() & ((self - other).abs() <= atol + rtol * other.abs())
|
1736
1859
|
is_infinite_close = (self.isinf() | other.isinf()) & (self == other)
|
1737
1860
|
is_nan_close = (self.isnan() & other.isnan()) & equal_nan
|
1738
1861
|
return is_finite_close | is_infinite_close | is_nan_close
|
1739
1862
|
|
1740
|
-
def mean(self, axis:
|
1863
|
+
def mean(self, axis:int|Sequence[int]|None=None, keepdim=False) -> Tensor:
|
1741
1864
|
"""
|
1742
1865
|
Returns the mean value of the tensor along the specified axis or axes.
|
1743
1866
|
|
@@ -1761,9 +1884,10 @@ class Tensor(SimpleMathTrait):
|
|
1761
1884
|
"""
|
1762
1885
|
output_dtype = self.dtype if dtypes.is_float(self.dtype) else dtypes.float32
|
1763
1886
|
numerator = self.cast(sum_acc_dtype(self.dtype)).sum(axis=axis, keepdim=keepdim)
|
1764
|
-
return numerator.div(prod([si for si, so in zip(self.shape, self.sum(axis=axis, keepdim=True).shape) if resolve(si != so)]))
|
1887
|
+
return numerator.div(prod([cast(int, si) for si, so in zip(self.shape, self.sum(axis=axis, keepdim=True).shape) if resolve(si != so)])) \
|
1888
|
+
.cast(output_dtype)
|
1765
1889
|
|
1766
|
-
def var(self, axis:
|
1890
|
+
def var(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1) -> Tensor:
|
1767
1891
|
"""
|
1768
1892
|
Returns the variance of the tensor along the specified axis or axes.
|
1769
1893
|
|
@@ -1789,7 +1913,24 @@ class Tensor(SimpleMathTrait):
|
|
1789
1913
|
n = prod([si for si, so in zip(self.shape, squares.sum(axis=axis, keepdim=True).shape) if resolve(si != so)])
|
1790
1914
|
return squares.sum(axis=axis, keepdim=keepdim).div(smax([0, n-correction]))
|
1791
1915
|
|
1792
|
-
def
|
1916
|
+
def var_mean(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1) -> tuple[Tensor, Tensor]:
|
1917
|
+
"""
|
1918
|
+
Calculates the variance and mean over the dimensions specified by dim.
|
1919
|
+
Syntactic sugar around `Tensor.var` and `Tensor.mean` to match `torch.var_mean`.
|
1920
|
+
|
1921
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1922
|
+
Tensor.manual_seed(42)
|
1923
|
+
t = Tensor.normal(2, 3, mean=2.5, std=0.5)
|
1924
|
+
print(t.numpy())
|
1925
|
+
```
|
1926
|
+
```python exec="true" source="above" session="tensor" result="python"
|
1927
|
+
var, mean = t.var_mean()
|
1928
|
+
print(var.numpy(), mean.numpy())
|
1929
|
+
```
|
1930
|
+
"""
|
1931
|
+
return self.var(axis, keepdim, correction), self.mean(axis, keepdim)
|
1932
|
+
|
1933
|
+
def std(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1) -> Tensor:
|
1793
1934
|
"""
|
1794
1935
|
Returns the standard deviation of the tensor along the specified axis or axes.
|
1795
1936
|
|
@@ -1813,7 +1954,7 @@ class Tensor(SimpleMathTrait):
|
|
1813
1954
|
"""
|
1814
1955
|
return self.var(axis, keepdim, correction).sqrt()
|
1815
1956
|
|
1816
|
-
def std_mean(self, axis:
|
1957
|
+
def std_mean(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1) -> tuple[Tensor, Tensor]:
|
1817
1958
|
"""
|
1818
1959
|
Calculates the standard deviation and mean over the dimensions specified by dim.
|
1819
1960
|
Syntactic sugar around `Tensor.std` and `Tensor.mean` to match `torch.std_mean`.
|
@@ -1830,13 +1971,100 @@ class Tensor(SimpleMathTrait):
|
|
1830
1971
|
"""
|
1831
1972
|
return self.std(axis, keepdim, correction), self.mean(axis, keepdim)
|
1832
1973
|
|
1833
|
-
def
|
1974
|
+
def keccak(self, cfg:str|tuple[int, int]="sha3_256"):
|
1975
|
+
"""
|
1976
|
+
Calculates a Keccak hash over the last dimension. Uses "sha3_256" by default.
|
1977
|
+
|
1978
|
+
```python exec="false" source="above" session="tensor" result="python"
|
1979
|
+
t = Tensor(b"Hello World!").keccak()
|
1980
|
+
print(t.data().hex())
|
1981
|
+
```
|
1982
|
+
"""
|
1983
|
+
|
1984
|
+
# https://keccak.team/keccak_specs_summary.html
|
1985
|
+
|
1986
|
+
def ctensor(l: Sequence[ConstType], dtype: DType = dtypes.uint64):
|
1987
|
+
# TODO: contiguous is here for compile speed
|
1988
|
+
return Tensor.stack(*(Tensor(v, dtype=dtype, device=self.device) for v in l)).contiguous()
|
1989
|
+
rot_offsets = [44, 43, 21, 14, 28, 20, 3, 45, 61, 1, 6, 25, 8, 18, 27, 36, 10, 15, 56, 62, 55, 39, 41, 2]
|
1990
|
+
rot_offsets_v0, rot_offsets_v1 = ctensor([0] + [1 << v for v in rot_offsets]), ctensor([1] + [1 << (64 - v) for v in rot_offsets])
|
1991
|
+
|
1992
|
+
# calculated from π step
|
1993
|
+
reorder_indexes = ctensor([0,6,12,18,24,3,9,10,16,22,1,7,13,19,20,4,5,11,17,23,2,8,14,15,21], dtype=dtypes.int32)
|
1994
|
+
rnd_const_masks = [ctensor([v]).pad((0, 24)) for v in (1, 0x8082, 0x800000000000808a, 0x8000000080008000, 0x808b, 0x80000001, 0x8000000080008081,
|
1995
|
+
0x8000000000008009, 0x8a, 0x88, 0x80008009, 0x8000000a, 0x8000808b, 0x800000000000008b, 0x8000000000008089, 0x8000000000008003,
|
1996
|
+
0x8000000000008002, 0x8000000000000080, 0x800a, 0x800000008000000a, 0x8000000080008081, 0x8000000000008080, 0x80000001, 0x8000000080008008)]
|
1997
|
+
|
1998
|
+
rate, dsbyte = {"sha3_224": (144, 6), "sha3_256": (136, 6), "shake_128": (168, 31)}[cfg] if isinstance(cfg, str) else cfg
|
1999
|
+
data, data_pad = self.bitcast(dtypes.uint8).reshape(prod(self.shape[:-1]), self.shape[-1]), rate - (self.shape[-1] * self.dtype.itemsize % rate)
|
2000
|
+
# pad batches then pad blocks
|
2001
|
+
data = data.pad((None, (0, data_pad))).reshape(bs := data.shape[0], -1, rate).pad((None, None, (0, 200 - rate)))
|
2002
|
+
|
2003
|
+
# create pad mask
|
2004
|
+
lbe = prod(data.shape[1:]) + rate - data_pad - 200
|
2005
|
+
if data_pad == 1: mb = [(lbe, 0), (1, dsbyte ^ 0x80), (200 - rate, 0)]
|
2006
|
+
else: mb = [(lbe, 0), (1, dsbyte), (data_pad - 2, 0), (1, 0x80), (200 - rate, 0)]
|
2007
|
+
pad_mask = Tensor.cat(*(Tensor(v, dtype=dtypes.uint8, device=data.device).expand(l) for l, v in mb if l > 0)).unsqueeze(0)
|
2008
|
+
|
2009
|
+
data = (data.flatten(1) ^ pad_mask).reshape(*data.shape[:2], 200).bitcast(dtypes.uint64)
|
2010
|
+
|
2011
|
+
state = Tensor.zeros(bs, 25, device=self.device, dtype=dtypes.uint64)
|
2012
|
+
for k in range(int(data.shape[1])):
|
2013
|
+
state = state.bitwise_xor(data[:,k].reshape(bs, 25))
|
2014
|
+
for i in range(24): # f1600
|
2015
|
+
# θ step
|
2016
|
+
p = state.reshape(bs, 5, 5).transpose(2, 1)
|
2017
|
+
t1 = (p[:,:,0] ^ p[:,:,1] ^ p[:,:,2] ^ p[:,:,3] ^ p[:,:,4]).roll(-1, 1) # xor reduce
|
2018
|
+
state = state ^ (t1.roll(2, 1).bitwise_xor((t1 << 1) ^ (t1 >> 63)).unsqueeze(2).expand(bs, 5, 5).transpose(2, 1).flatten(1))
|
2019
|
+
# ρ and π steps
|
2020
|
+
state = state[:, reorder_indexes]
|
2021
|
+
state = (state * rot_offsets_v0).bitwise_or(state // rot_offsets_v1).reshape(bs, 5, 5)
|
2022
|
+
# χ and ι step
|
2023
|
+
state = state.bitwise_xor(~state.roll(shifts=-1, dims=2) & state.roll(shifts=-2, dims=2))
|
2024
|
+
state = state.flatten(1) ^ rnd_const_masks[i]
|
2025
|
+
# NOTE: kernelize here to prevent internal stack from growing proportionally with data size
|
2026
|
+
state = state.kernelize()
|
2027
|
+
return state.bitcast(dtypes.uint8)[:,:(obytes:=(200 - rate) // 2)].reshape(*self.shape[:-1], obytes)
|
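Assuming the "sha3_256" configuration (rate 136, domain byte 6) implements standard SHA3-256, which the diff's construction suggests but does not prove by itself, the returned digest is `(200 - rate) // 2 = 32` bytes and should match `hashlib`. An editor's cross-check sketch:

```python
import hashlib
from tinygrad import Tensor

msg = b"Hello World!"
digest = Tensor(msg).keccak("sha3_256").data()   # 32 bytes for rate=136
assert bytes(digest) == hashlib.sha3_256(msg).digest()
```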
2028
|
+
|
2029
|
+
def _hash_1mb(self) -> Tensor:
|
2030
|
+
assert self.dtype == dtypes.uint8, "only support uint8 tensors for hashing"
|
2031
|
+
assert self.ndim == 2, "only support batched 1d tensors"
|
2032
|
+
assert self.shape[1] == 1024 * 1024, "only support messages of 1mb"
|
2033
|
+
|
2034
|
+
blocks = self.shape[0] * self.shape[1] // 4096
|
2035
|
+
data = self.reshape(blocks, 4096)
|
2036
|
+
block_hashes = data.keccak("shake_128").reshape(self.shape[0], 4096)
|
2037
|
+
return block_hashes.keccak("shake_128").reshape(self.shape[0], 16)
|
2038
|
+
|
2039
|
+
def hash(self) -> Tensor:
|
2040
|
+
"""
|
2041
|
+
Calculates a 16-byte hash of the tensor.
|
2042
|
+
```python exec="false" source="above" session="tensor" result="python"
|
2043
|
+
t = Tensor(b"Hello World!").hash()
|
2044
|
+
print(t.data().hex())
|
2045
|
+
```
|
2046
|
+
"""
|
2047
|
+
|
2048
|
+
data = self.flatten().bitcast(dtypes.uint8)
|
2049
|
+
if (tsize := data.shape[0]) % 2**20 != 0: data = data.pad((0, 2**20 - tsize % 2**20))
|
2050
|
+
base_chunks = ceildiv(data.shape[0], 2**20)
|
2051
|
+
tree_depth = math.ceil(math.log(base_chunks, 65536)) if base_chunks > 1 else 0
|
2052
|
+
|
2053
|
+
level_chunks = base_chunks
|
2054
|
+
for _ in range(tree_depth + 1):
|
2055
|
+
data = data.reshape(level_chunks, 2**20)._hash_1mb().flatten()
|
2056
|
+
if (tsize := data.shape[0]) % 2**20 != 0: data = data.pad((0, 2**20 - tsize % 2**20))
|
2057
|
+
level_chunks = ceildiv(data.shape[0], 2**20)
|
2058
|
+
|
2059
|
+
return data[:16]
|
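The tree shape follows directly from the arithmetic above: the flat bytes are padded to 1 MiB base chunks, and every later level packs 65536 child digests (65536 * 16 bytes = 1 MiB) into one chunk. The helper below mirrors that calculation; `tree_depth` is the editor's name, not part of the API:

```python
import math

def tree_depth(nbytes: int) -> int:
    # 1 MiB base chunks; each level above packs 65536 sixteen-byte digests per chunk
    base_chunks = math.ceil(max(nbytes, 1) / 2**20)
    return math.ceil(math.log(base_chunks, 65536)) if base_chunks > 1 else 0

assert tree_depth(1) == 0                 # a single padded 1 MiB chunk
assert tree_depth(2**20 + 1) == 1         # two base chunks need one extra level
assert tree_depth(65536 * 2**20) == 1
assert tree_depth(65536 * 2**20 + 1) == 2
```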
2060
|
+
|
2061
|
+
def _softmax(self, axis, dtype:DTypeLike|None=None) -> tuple[Tensor, Tensor, Tensor]:
|
1834
2062
|
m = self - self.max(axis=axis, keepdim=True).detach()
|
1835
2063
|
if dtype is not None: m = m.cast(dtype)
|
1836
2064
|
e = m.exp()
|
1837
2065
|
return m, e, e.sum(axis=axis, keepdim=True)
|
1838
2066
|
|
1839
|
-
def softmax(self, axis=-1, dtype:
|
2067
|
+
def softmax(self, axis=-1, dtype:DTypeLike|None=None, _single_kernel=getenv("SINGLE_KERNEL_SOFTMAX")) -> Tensor:
|
1840
2068
|
"""
|
1841
2069
|
Applies the softmax function to the tensor along the specified axis.
|
1842
2070
|
|
@@ -1856,10 +2084,13 @@ class Tensor(SimpleMathTrait):
|
|
1856
2084
|
print(t.softmax(axis=0).numpy())
|
1857
2085
|
```
|
1858
2086
|
"""
|
2087
|
+
if _single_kernel:
|
2088
|
+
_, e, ss = self.contiguous()._softmax(axis, dtype)
|
2089
|
+
return e.div(ss).fuse()
|
1859
2090
|
_, e, ss = self._softmax(axis, dtype)
|
1860
2091
|
return e.div(ss)
|
1861
2092
|
|
1862
|
-
def log_softmax(self, axis=-1, dtype:
|
2093
|
+
def log_softmax(self, axis=-1, dtype:DTypeLike|None=None) -> Tensor:
|
1863
2094
|
"""
|
1864
2095
|
Applies the log-softmax function to the tensor along the specified axis.
|
1865
2096
|
|
@@ -1882,7 +2113,7 @@ class Tensor(SimpleMathTrait):
|
|
1882
2113
|
m, _, ss = self._softmax(axis, dtype)
|
1883
2114
|
return m - ss.log()
|
1884
2115
|
|
1885
|
-
def logsumexp(self, axis=None, keepdim=False):
|
2116
|
+
def logsumexp(self, axis=None, keepdim=False) -> Tensor:
|
1886
2117
|
"""
|
1887
2118
|
Computes the log-sum-exp of the tensor along the specified axis or axes.
|
1888
2119
|
|
@@ -1909,14 +2140,14 @@ class Tensor(SimpleMathTrait):
|
|
1909
2140
|
m = self.max(axis=axis, keepdim=True)
|
1910
2141
|
return (self - m).exp().sum(axis=axis, keepdim=keepdim).log() + m.squeeze(axis)
|
1911
2142
|
|
1912
|
-
def logcumsumexp(self, axis=0):
|
2143
|
+
def logcumsumexp(self, axis=0) -> Tensor:
|
1913
2144
|
"""
|
1914
2145
|
Computes the log-cumsum-exp of the tensor along the specified axis or axes.
|
1915
2146
|
|
1916
2147
|
The log-cumsum-exp function is a numerically stable way to compute the logarithm of the cumulative sum of exponentials.
|
1917
2148
|
|
1918
2149
|
You can pass in the `axis` keyword argument to control the axis along which
|
1919
|
-
the log-
|
2150
|
+
the log-cumsum-exp is computed.
|
1920
2151
|
|
1921
2152
|
```python exec="true" source="above" session="tensor" result="python"
|
1922
2153
|
Tensor.manual_seed(42)
|
@@ -1934,17 +2165,15 @@ class Tensor(SimpleMathTrait):
|
|
1934
2165
|
```
|
1935
2166
|
"""
|
1936
2167
|
if self.ndim == 0: return self
|
1937
|
-
axis = self._resolve_dim(axis)
|
1938
2168
|
x = self.transpose(axis, -1)
|
1939
2169
|
last_dim_size = x.shape[-1]
|
1940
|
-
|
1941
|
-
x_cummax =
|
1942
|
-
|
1943
|
-
|
1944
|
-
ret
|
1945
|
-
return ret.reshape(*x.shape).transpose(-1, axis)
|
2170
|
+
x_unsqueezed = x.unsqueeze(-2).expand((None,)*(self.ndim-1)+(last_dim_size, None))
|
2171
|
+
x_cummax = x.cummax(-1)
|
2172
|
+
mask = Tensor.ones(last_dim_size, last_dim_size, requires_grad=False, device=self.device).tril()
|
2173
|
+
ret = mask.where(x_unsqueezed - x_cummax.unsqueeze(-1), dtypes.min(self.dtype)).exp().sum(-1).log() + x_cummax
|
2174
|
+
return ret.transpose(-1, axis)
|
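The new implementation computes `logcumsumexp(x)_i = log(sum_{j<=i} exp(x_j - m_i)) + m_i` with `m_i = cummax(x)_i`, which equals the plain log-cum-sum-exp but keeps every exponent argument non-positive. A naive plain-Python reference of that identity; the input values and tolerance are the editor's choice:

```python
import math

def logcumsumexp_ref(xs):
    # numerically-stable reference: subtract the running max before exp
    out, running_max = [], -math.inf
    for i, x in enumerate(xs):
        running_max = max(running_max, x)
        s = sum(math.exp(v - running_max) for v in xs[:i+1])
        out.append(math.log(s) + running_max)
    return out

vals = [0.0, 1.0, -2.0, 3.0]
naive = [math.log(sum(math.exp(v) for v in vals[:i+1])) for i in range(len(vals))]
assert all(abs(a - b) < 1e-9 for a, b in zip(logcumsumexp_ref(vals), naive))
```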
1946
2175
|
|
1947
|
-
def argmax(self, axis=None, keepdim=False):
|
2176
|
+
def argmax(self, axis=None, keepdim=False) -> Tensor:
|
1948
2177
|
"""
|
1949
2178
|
Returns the indices of the maximum value of the tensor along the specified axis.
|
1950
2179
|
|
@@ -1971,7 +2200,7 @@ class Tensor(SimpleMathTrait):
|
|
1971
2200
|
idx = m * Tensor.arange(self.shape[axis],0,-1, requires_grad=False, device=self.device).reshape(self.shape[axis], *[1]*(self.ndim-axis-1))
|
1972
2201
|
return (self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)).cast(dtypes.int32)
|
1973
2202
|
|
1974
|
-
def argmin(self, axis=None, keepdim=False):
|
2203
|
+
def argmin(self, axis=None, keepdim=False) -> Tensor:
|
1975
2204
|
"""
|
1976
2205
|
Returns the indices of the minimum value of the tensor along the specified axis.
|
1977
2206
|
|
@@ -1995,7 +2224,7 @@ class Tensor(SimpleMathTrait):
|
|
1995
2224
|
return self._inverse().argmax(axis=axis, keepdim=keepdim)
|
1996
2225
|
|
1997
2226
|
@staticmethod
|
1998
|
-
def einsum(formula:str, *operands:Tensor|Sequence[Tensor],
|
2227
|
+
def einsum(formula:str, *operands:Tensor|Sequence[Tensor], dtype:DTypeLike|None=None) -> Tensor:
|
1999
2228
|
"""
|
2000
2229
|
Sums the product of the elements of the input tensors according to a formula based on the Einstein summation convention.
|
2001
2230
|
|
@@ -2009,7 +2238,7 @@ class Tensor(SimpleMathTrait):
|
|
2009
2238
|
"""
|
2010
2239
|
def parse_formula(formula:str, *operands:Tensor):
|
2011
2240
|
if "..." in (formula := formula.replace(" ", "")):
|
2012
|
-
ell_chars, ell_longest = "".join(
|
2241
|
+
ell_chars, ell_longest = "".join(c for c in string.ascii_letters if c not in formula), 0
|
2013
2242
|
for i, inp in enumerate(filter(lambda x: "..." in x, inputs := formula.split("->")[0].split(","))):
|
2014
2243
|
if (ell_count := max(operands[i].ndim, 1) - (len(inp) - len("..."))) > ell_longest: ell_longest = ell_count
|
2015
2244
|
inputs[i] = inp.replace("...", ell_chars[-ell_count:])
|
@@ -2037,11 +2266,11 @@ class Tensor(SimpleMathTrait):
|
|
2037
2266
|
|
2038
2267
|
# sum over all axes that's not in the output, then permute to the output order
|
2039
2268
|
return functools.reduce(lambda a,b:a*b, xs_) \
|
2040
|
-
.sum(axis=[axis for axis,(letter,_) in enumerate(letter_val) if letter not in output],
|
2269
|
+
.sum(axis=[axis for axis,(letter,_) in enumerate(letter_val) if letter not in output], dtype=dtype).permute(rhs_order)
|
2041
2270
|
|
2042
2271
|
# ***** processing ops *****
|
2043
2272
|
|
2044
|
-
def _pool(self, k_:tuple[sint, ...], stride:
|
2273
|
+
def _pool(self, k_:tuple[sint, ...], stride:int|tuple[int, ...]=1, dilation:int|tuple[int, ...]=1) -> Tensor:
|
2045
2274
|
assert len(self.shape) >= len(k_), f"can't pool {self.shape} with {k_}"
|
2046
2275
|
s_, d_ = make_tuple(stride, len(k_)), make_tuple(dilation, len(k_))
|
2047
2276
|
assert len(k_) == len(s_) == len(d_), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}"
|
@@ -2066,12 +2295,12 @@ class Tensor(SimpleMathTrait):
|
|
2066
2295
|
x = x.shrink(tuple(noop + flatten(((0,o), (0,k)) for o,k in zip(o_,k_))))
|
2067
2296
|
return x.permute(*range(len(noop)), *[len(noop)+i*2 for i in range(len(i_))], *[len(noop)+i*2+1 for i in range(len(i_))])
|
2068
2297
|
|
2069
|
-
def _resolve_pool_pads(self, padding:
|
2298
|
+
def _resolve_pool_pads(self, padding:int|Sequence[int], dims:int) -> Sequence[int]:
|
2070
2299
|
if not isinstance(padding, int) and not (len(padding) == 2*dims or len(padding) == dims):
|
2071
2300
|
raise ValueError(f"Padding must be an int or a sequence of length {dims} or {2*dims}, but got {padding=} for {self.shape=} with {dims=}.")
|
2072
2301
|
return [padding]*2*dims if isinstance(padding, int) else (padding if len(padding) == 2*dims else [p for p in padding for _ in range(2)][::-1])
|
2073
2302
|
|
2074
|
-
def _apply_ceil_mode(self, pads:Sequence[int], k_:tuple[sint, ...], s_:
|
2303
|
+
def _apply_ceil_mode(self, pads:Sequence[int], k_:tuple[sint, ...], s_:int|tuple[int, ...], d_:int|tuple[int, ...]) -> list[int]:
|
2075
2304
|
(d_,s_), i_ = (make_tuple(x, len(k_)) for x in (d_,s_)), self.shape[-len(k_):]
|
2076
2305
|
pads, grouped_pads = list(pads), _flat_to_grouped(pads)
|
2077
2306
|
# https://arxiv.org/pdf/1603.07285 section 5.1, relationship 15.
|
@@ -2085,7 +2314,8 @@ class Tensor(SimpleMathTrait):
|
|
2085
2314
|
return pads
|
2086
2315
|
|
2087
2316
|
# NOTE: these work for more than 2D
|
2088
|
-
def avg_pool2d(self, kernel_size=(2,2), stride=None, dilation=1, padding
|
2317
|
+
def avg_pool2d(self, kernel_size:tuple[int, ...]=(2,2), stride=None, dilation=1, padding:int|tuple[int, ...]=0,
|
2318
|
+
ceil_mode=False, count_include_pad=True) -> Tensor:
|
2089
2319
|
"""
|
2090
2320
|
Applies average pooling over a tensor.
|
2091
2321
|
|
@@ -2106,8 +2336,6 @@ class Tensor(SimpleMathTrait):
|
|
2106
2336
|
|
2107
2337
|
NOTE: unlike PyTorch, this implementation is not limited to only 2d pooling and instead works for any number of dimensions.
|
2108
2338
|
|
2109
|
-
See: https://paperswithcode.com/method/average-pooling
|
2110
|
-
|
2111
2339
|
```python exec="true" source="above" session="tensor" result="python"
|
2112
2340
|
t = Tensor.arange(25).reshape(1, 1, 5, 5)
|
2113
2341
|
print(t.avg_pool2d().numpy())
|
@@ -2132,7 +2360,8 @@ class Tensor(SimpleMathTrait):
|
|
2132
2360
|
if not ceil_mode: return pool(self, reg_pads).mean(axis)
|
2133
2361
|
return pool(self, ceil_pads).sum(axis) / pool(self.pad(reg_pads).ones_like(), tuple(cp-rp for cp,rp in zip(ceil_pads, reg_pads))).sum(axis)
|
2134
2362
|
|
2135
|
-
def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1, padding=0,
|
2363
|
+
def max_pool2d(self, kernel_size:tuple[int, ...]=(2,2), stride=None, dilation=1, padding:int|tuple[int, ...]=0,
|
2364
|
+
ceil_mode=False, return_indices=False) -> Tensor | tuple[Tensor, Tensor]:
|
2136
2365
|
"""
|
2137
2366
|
Applies max pooling over a tensor.
|
2138
2367
|
|
@@ -2149,11 +2378,10 @@ class Tensor(SimpleMathTrait):
|
|
2149
2378
|
`(padding_left, padding_right, padding_top, padding_bottom, ...)`.
|
2150
2379
|
|
2151
2380
|
When `ceil_mode` is set to `True`, output shape will be determined using ceil division.
|
2381
|
+
When `return_indices` is set to `True`, the argmax will be returned along with the max values.
|
2152
2382
|
|
2153
2383
|
NOTE: unlike PyTorch, this implementation is not limited to only 2d pooling and instead works for any number of dimensions.
|
2154
2384
|
|
2155
|
-
See: https://paperswithcode.com/method/max-pooling
|
2156
|
-
|
2157
2385
|
```python exec="true" source="above" session="tensor" result="python"
|
2158
2386
|
t = Tensor.arange(25).reshape(1, 1, 5, 5)
|
2159
2387
|
print(t.max_pool2d().numpy())
|
@@ -2165,12 +2393,50 @@ class Tensor(SimpleMathTrait):
|
|
2165
2393
|
print(t.max_pool2d(padding=1).numpy())
|
2166
2394
|
```
|
2167
2395
|
"""
|
2168
|
-
|
2396
|
+
axis = tuple(range(-len(k_ := make_tuple(kernel_size, 2)), 0))
|
2397
|
+
pads = self._resolve_pool_pads(padding, len(k_))
|
2169
2398
|
if ceil_mode: pads = self._apply_ceil_mode(pads, k_, stride if stride is not None else k_, dilation)
|
2170
|
-
|
2399
|
+
pooled = self.pad(pads, value=dtypes.min(self.dtype))._pool(k_, stride if stride is not None else k_, dilation)
|
2400
|
+
if not return_indices: return pooled.max(axis)
|
2401
|
+
spatial_sz = math.prod(spatial_shape := self.shape[-len(k_):])
|
2402
|
+
idx = Tensor.arange(spatial_sz,0,-1, requires_grad=False, device=self.device).reshape(spatial_shape)
|
2403
|
+
m = pooled == pooled.max(axis, keepdim=True)
|
2404
|
+
idx = m * idx.pad(pads, value=dtypes.min(idx.dtype))._pool(k_, stride if stride is not None else k_, dilation)
|
2405
|
+
return pooled.max(axis), spatial_sz - idx.max(axis)
|
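The `return_indices` path reuses the reversed-arange trick from `argmax`: positions that attain the window max keep a value counting down from `spatial_sz`, so `spatial_sz - max(...)` is the flat index of the first maximum. A sketch of that identity in plain Python, with the window standing in for the whole (flattened) spatial extent and example values chosen here:

```python
# first-occurrence argmax via a reversed arange
window = [3, 7, 7, 1]
spatial_sz = len(window)
rev = list(range(spatial_sz, 0, -1))      # [4, 3, 2, 1]
m = max(window)
masked = [r if v == m else 0 for v, r in zip(window, rev)]
assert spatial_sz - max(masked) == 1      # index of the first 7
```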
2406
|
+
|
2407
|
+
def max_unpool2d(self, indices:Tensor, kernel_size:tuple[int, ...]=(2,2), stride=None, dilation=1, padding:int|tuple[int, ...]=0, output_size=None):
|
2408
|
+
"""
|
2409
|
+
Performs a partial inverse of `max_pool2d` using the indices from the argmax.
|
2410
|
+
|
2411
|
+
When `output_size` is provided, it is used to disambiguate the output shape, which the pooling parameters alone do not always determine uniquely.
|
2412
|
+
|
2413
|
+
NOTE: unlike PyTorch, this implementation is not limited to only 2d pooling and instead works for any number of dimensions.
|
2414
|
+
|
2415
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2416
|
+
t = Tensor.arange(1, 17).reshape(1, 1, 4, 4)
|
2417
|
+
print(t.numpy())
|
2418
|
+
```
|
2419
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2420
|
+
output, indices = Tensor.max_pool2d(t, return_indices=True)
|
2421
|
+
print(output.numpy())
|
2422
|
+
print(indices.numpy())
|
2423
|
+
```
|
2424
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2425
|
+
print(Tensor.max_unpool2d(output, indices).numpy())
|
2426
|
+
```
|
2427
|
+
"""
|
2428
|
+
bs,c,*spatial_shape = self.shape
|
2429
|
+
if output_size is None:
|
2430
|
+
k_,d_,s_ = (make_tuple(x, len(spatial_shape)) for x in (kernel_size, dilation, stride if stride is not None else kernel_size))
|
2431
|
+
p_ = _flat_to_grouped(self._resolve_pool_pads(padding, len(spatial_shape)))
|
2432
|
+
# https://arxiv.org/pdf/1603.07285 inverse of relationship 15 in section 5.1.
|
2433
|
+
output_size = tuple((i-1)*s - (pB+pA) + (d*(k-1)+1) for i,k,d,s,(pA,pB) in zip(spatial_shape,k_,d_,s_,p_))
|
2434
|
+
else: output_size = output_size[-len(spatial_shape):]
|
2435
|
+
ret = (indices.reshape(bs,c,1,-1)._one_hot_along_dim(prod(output_size), 2) * self.reshape(bs,c,1,-1)).sum(3)
|
2436
|
+
return ret.reshape(bs,c,*output_size)
|
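When `output_size` is omitted, the spatial extent is reconstructed by inverting the usual pooling shape relation; the round trip is exact only when the forward floor division has no remainder, which is precisely the ambiguity an explicit `output_size` resolves. A sketch with hypothetical helper names (`pool_out`, `unpool_out`):

```python
def pool_out(i, k, s=1, d=1, p=0):
    # forward pooled size (p = total padding on that dim)
    return (i + p - d * (k - 1) - 1) // s + 1

def unpool_out(o, k, s=1, d=1, p=0):
    # inverse relation used by max_unpool2d when output_size is omitted
    return (o - 1) * s - p + (d * (k - 1) + 1)

# round-trips exactly when the forward division has no remainder
assert pool_out(4, k=2, s=2) == 2 and unpool_out(2, k=2, s=2) == 4
assert pool_out(6, k=3, s=3) == 2 and unpool_out(2, k=3, s=3) == 6
```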
2171
2437
|
|
2172
|
-
def conv2d(self, weight:Tensor, bias:
|
2173
|
-
|
2438
|
+
def conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding:int|tuple[int, ...]=0,
|
2439
|
+
dtype:DTypeLike|None=None) -> Tensor:
|
2174
2440
|
"""
|
2175
2441
|
Applies a convolution over a tensor with a given `weight` and optional `bias`.
|
2176
2442
|
|
@@ -2196,7 +2462,7 @@ class Tensor(SimpleMathTrait):
|
|
2196
2462
|
print(t.conv2d(w).numpy())
|
2197
2463
|
```
|
2198
2464
|
"""
|
2199
|
-
if IMAGE: return self.image_conv2d(weight, bias, groups, stride, dilation, padding,
|
2465
|
+
if IMAGE: return self.image_conv2d(weight, bias, groups, stride, dilation, padding, dtype)
|
2200
2466
|
(bs,cin_), (cout,cin), HW = self.shape[:2], weight.shape[:2], weight.shape[2:]
|
2201
2467
|
padding_ = self._resolve_pool_pads(padding, len(HW))
|
2202
2468
|
assert groups*cin == cin_ and len(self.shape) == len(weight.shape), f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})" # noqa: E501
|
@@ -2209,7 +2475,7 @@ class Tensor(SimpleMathTrait):
|
|
2209
2475
|
x = x.reshape(bs, groups, cin, 1, *oyx, *HW).expand(bs, groups, cin, rcout, *oyx, *HW).permute(0,1,3,*[4+i for i in range(len(oyx))],2,*[4+len(oyx)+i for i in range(len(HW))]) # noqa: E501
|
2210
2476
|
|
2211
2477
|
# conv! broadcasted to (bs, groups, rcout, *oyx, cin, *HW)
|
2212
|
-
ret = (x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW)).sum([-1-i for i in range(1+len(oyx))], keepdim=True,
|
2478
|
+
ret = (x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW)).sum([-1-i for i in range(1+len(oyx))], keepdim=True, dtype=dtype).reshape(bs, cout, *oyx) # noqa: E501
|
2213
2479
|
return ret if bias is None else ret.add(bias.reshape(1, -1, *[1] * len(HW)))
|
2214
2480
|
|
2215
2481
|
HWI, HWO = (6,) * len(HW), (4,) * len(HW) # F(4x4,3x3) winograd tiles
|
@@ -2217,7 +2483,7 @@ class Tensor(SimpleMathTrait):
|
|
2217
2483
|
winograd_Bt = [[4, 0, -5, 0, 1, 0], [0, -4, -4, 1, 1, 0], [0, 4, -4, -1, 1, 0], [0, -2, -1, 2, 1, 0], [0, 2, -1, -2, 1, 0], [0, 4, 0, -5, 0, 1]]
|
2218
2484
|
winograd_At = [[1, 1, 1, 1, 1, 0], [0, 1, -1, 2, -2, 0], [0, 1, 1, 4, 4, 0], [0, 1, -1, 8, -8, 1]] # applying At in pre-order doubles compile time
|
2219
2485
|
|
2220
|
-
#
|
2486
|
+
# TODO: stride == dilation
|
2221
2487
|
# use padding to round up to 4x4 output tiles
|
2222
2488
|
# (bs, cin_, tyx, HWI)
|
2223
2489
|
d = self.pad(sum([[padding_[i*2], padding_[i*2+1] + (-(dim + sum(padding_[i * 2:(i + 1) * 2]) - 2) % 4)] for i, dim in enumerate(self.shape[-len(HW):])], []))._pool(HWI, HWO) # noqa: E501
|
@@ -2234,7 +2500,7 @@ class Tensor(SimpleMathTrait):
|
|
2234
2500
|
dfactors = _apply_winograd_matrix(winograd_Bt, d, len(HW)).reshape(*HWI, bs, groups, 1, cin, *tyx)
|
2235
2501
|
|
2236
2502
|
# matmul; sum across cin: (HWI, bs, groups, rcout, *tyx); then HWI -> HWO: (HWO, bs, groups, rcout, *tyx)
|
2237
|
-
ret = _apply_winograd_matrix(winograd_At, (gfactors * dfactors).sum(axis=-1-len(HW),
|
2503
|
+
ret = _apply_winograd_matrix(winograd_At, (gfactors * dfactors).sum(axis=-1-len(HW), dtype=dtype), len(HW))
|
2238
2504
|
|
2239
2505
|
# interleave tyx and HWO: (bs, groups, rcout, oy, HO, ox, WO)
|
2240
2506
|
ret = ret.permute([*range(len(HW), len(ret.shape)-len(HW)), *[i+o for i in range(len(HW)) for o in [len(ret.shape)-len(HW),0]]])
|
@@ -2243,7 +2509,7 @@ class Tensor(SimpleMathTrait):
|
|
2243
2509
|
|
2244
2510
|
return (ret if bias is None else ret.add(bias.reshape(1, -1, *[1 for _ in range(len(HW))]))).contiguous().contiguous_backward()
|
2245
2511
|
|
2246
|
-
def conv_transpose2d(self, weight:Tensor, bias:
|
2512
|
+
def conv_transpose2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding=0, output_padding=0) -> Tensor:
|
2247
2513
|
"""
|
2248
2514
|
Applies a transposed convolution over a tensor with a given `weight` and optional `bias`.
|
2249
2515
|
|
@@ -2282,14 +2548,14 @@ class Tensor(SimpleMathTrait):
|
|
2282
2548
|
padding = flatten((((k-1)*d-pB,(k-1)*d-pA+op) for k,d,(pB,pA),op in reversed(list(zip(HW, dilation, padding, output_padding)))))
|
2283
2549
|
return x.conv2d(w.flatten(end_dim=1), groups=groups, bias=bias, dilation=dilation, padding=padding)
|
2284
2550
|
|
2285
|
-
def dot(self, w:Tensor,
|
2551
|
+
def dot(self, w:Tensor, dtype:DTypeLike|None=None) -> Tensor:
|
2286
2552
|
|
2287
2553
|
"""
|
2288
2554
|
Performs dot product between two tensors.
|
2289
2555
|
If `w` is 1-D, it's a sum product over the last axis of `self` and `w`.
|
2290
2556
|
If `w` is N-D with N>=2, it's a sum product over the last axis of `self` and the second-to-last axis of `w`.
|
2291
2557
|
|
2292
|
-
You can pass in the optional `
|
2558
|
+
You can pass in the optional `dtype` keyword argument to control the data type of the accumulation.
|
2293
2559
|
|
2294
2560
|
```python exec="true" source="above" session="tensor" result="python"
|
2295
2561
|
a = Tensor([1, 2, 3])
|
@@ -2302,20 +2568,20 @@ class Tensor(SimpleMathTrait):
|
|
2302
2568
|
print(a.dot(b).numpy())
|
2303
2569
|
```
|
2304
2570
|
"""
|
2305
|
-
if IMAGE: return self.image_dot(w,
|
2571
|
+
if IMAGE: return self.image_dot(w, dtype)
|
2306
2572
|
x, dx, dw = self, self.ndim, w.ndim
|
2307
2573
|
if not (dx > 0 and dw > 0): raise RuntimeError(f"both tensors need to be at least 1D, got {dx}D and {dw}D")
|
2308
2574
|
if x.shape[-1] != w.shape[axis_w:=-min(w.ndim,2)]: raise RuntimeError(f"cannot dot {x.shape} and {w.shape}")
|
2309
2575
|
x = x.reshape(*x.shape[0:-1], *[1]*min(dx-1, dw-1, 1), x.shape[-1])
|
2310
2576
|
w = w.reshape(*w.shape[0:-2], *[1]*min(dx-1, dw-1, 1), *w.shape[axis_w:]).transpose(-1, axis_w)
|
2311
|
-
return (x*w).sum(-1,
|
2577
|
+
return (x*w).sum(-1, dtype=dtype).cast(least_upper_dtype(x.dtype, w.dtype) if dtype is None else dtype)
|
2312
2578
|
|
2313
|
-
def matmul(self, x:Tensor, reverse=False,
|
2579
|
+
def matmul(self, x:Tensor, reverse=False, dtype:DTypeLike|None=None) -> Tensor:
|
2314
2580
|
"""
|
2315
2581
|
Performs matrix multiplication between two tensors.
|
2316
2582
|
|
2317
2583
|
You can pass in the `reverse` keyword argument to control the order of the matrix multiplication.
|
2318
|
-
You can pass in the optional `
|
2584
|
+
You can pass in the optional `dtype` keyword argument to control the data type of the accumulation.
|
2319
2585
|
|
2320
2586
|
```python exec="true" source="above" session="tensor" result="python"
|
2321
2587
|
a = Tensor([[1, 2], [3, 4]])
|
@@ -2323,26 +2589,26 @@ class Tensor(SimpleMathTrait):
|
|
2323
2589
|
print(a.matmul(b).numpy())
|
2324
2590
|
```
|
2325
2591
|
"""
|
2326
|
-
return x.dot(self,
|
2592
|
+
return x.dot(self, dtype=dtype) if reverse else self.dot(x, dtype=dtype)
|
2327
2593
|
|
2328
2594
|
def _cumalu(self, axis:int, op:Ops, _include_initial=False) -> Tensor:
|
2329
|
-
assert self.shape[axis] != 0 and op in (Ops.ADD, Ops.MAX)
|
2595
|
+
assert self.shape[axis] != 0 and op in (Ops.ADD, Ops.MAX, Ops.MUL)
|
2330
2596
|
pl_sz = self.shape[axis] - int(not _include_initial)
|
2331
2597
|
pooled = self.transpose(axis,-1).pad((pl_sz, -int(_include_initial)), value=identity_element(op, self.dtype))._pool((self.shape[axis],))
|
2332
|
-
return
|
2598
|
+
return {Ops.ADD: pooled.sum(-1), Ops.MAX: pooled.max(-1), Ops.MUL: pooled.prod(-1)}[op].transpose(axis, -1)
|
2333
2599
|
|
2334
2600
|
def _split_cumalu(self, axis:int, op:Ops) -> Tensor:
|
2335
2601
|
axis = self._resolve_dim(axis)
|
2336
2602
|
if self.ndim == 0 or 0 in self.shape: return self
|
2337
|
-
# TODO: someday the optimizer will find this on
|
2603
|
+
# TODO: someday the optimizer will find this on its own
|
2338
2604
|
# for now this is a two stage cumsum
|
2339
2605
|
SPLIT = 256
|
2340
2606
|
if not isinstance(s:=self.shape[axis], int) or s <= SPLIT*2: return self._cumalu(axis, op)
|
2341
2607
|
ret = self.transpose(axis,-1).pad((round_up(s, SPLIT)-s, 0), value=identity_element(op, self.dtype)).unflatten(-1, (-1, SPLIT))._cumalu(-1, op)
|
2342
2608
|
base = ret[..., -1]._cumalu(-1, op, _include_initial=True)
|
2343
2609
|
base = base.unsqueeze(-1).expand(*base.shape, ret.shape[-1])
|
2344
|
-
def fix(x:Tensor): return x.flatten(start_dim=-2)[..., -s:].transpose(axis,-1)
|
2345
|
-
return
|
2610
|
+
def fix(x: Tensor) -> Tensor: return x.flatten(start_dim=-2)[..., -s:].transpose(axis,-1)
|
2611
|
+
return {Ops.ADD: Tensor.__add__, Ops.MAX: Tensor.maximum, Ops.MUL: Tensor.__mul__}[op](fix(ret), fix(base))
|
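`_split_cumalu` is a two-stage scan: pad the axis on the left to a multiple of `SPLIT`, scan inside each block, then combine each block with the exclusive scan of the previous block totals. A plain-Python sketch of the additive case; `split_cumsum` is a hypothetical helper and the block size is shrunk to 4 for readability:

```python
from itertools import accumulate

def split_cumsum(xs, split=4):
    # two-stage prefix sum: per-block scan + exclusive scan of block totals
    pad = (-len(xs)) % split
    xs = [0] * pad + list(xs)                        # left-pad, as in the diff
    blocks = [xs[i:i+split] for i in range(0, len(xs), split)]
    local = [list(accumulate(b)) for b in blocks]
    base = list(accumulate([b[-1] for b in local], initial=0))[:-1]
    out = [v + base[bi] for bi, b in enumerate(local) for v in b]
    return out[pad:]                                 # drop the padded prefix

data = list(range(1, 11))
assert split_cumsum(data) == list(accumulate(data))
```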
2346
2612
|
|
2347
2613
|
def cumsum(self, axis:int=0) -> Tensor:
|
2348
2614
|
"""
|
@@ -2358,6 +2624,20 @@ class Tensor(SimpleMathTrait):
|
|
2358
2624
|
"""
|
2359
2625
|
return self._split_cumalu(axis, Ops.ADD)
|
2360
2626
|
|
2627
|
+
def cumprod(self, axis:int) -> Tensor:
|
2628
|
+
"""
|
2629
|
+
Computes the cumulative product of the elements of the tensor along the specified `axis`.
|
2630
|
+
|
2631
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2632
|
+
t = Tensor.arange(1, 7).reshape(2, 3)
|
2633
|
+
print(t.numpy())
|
2634
|
+
```
|
2635
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2636
|
+
print(t.cumprod(axis=0).numpy())
|
2637
|
+
```
|
2638
|
+
"""
|
2639
|
+
return self._split_cumalu(axis, Ops.MUL)
|
2640
|
+
|
2361
2641
|
def cummax(self, axis:int=0) -> Tensor:
|
2362
2642
|
"""
|
2363
2643
|
Computes the cumulative max of the tensor along the specified `axis`.
|
@@ -2403,7 +2683,7 @@ class Tensor(SimpleMathTrait):
|
|
2403
2683
|
print(t.triu(diagonal=-1).numpy())
|
2404
2684
|
```
|
2405
2685
|
"""
|
2406
|
-
return Tensor._tri(self.shape[-2], self.shape[-1], diagonal=diagonal, device=self.device, dtype=dtypes.bool).where(self,
|
2686
|
+
return Tensor._tri(self.shape[-2], self.shape[-1], diagonal=diagonal, device=self.device, dtype=dtypes.bool).where(self, self.zeros_like())
|
2407
2687
|
|
2408
2688
|
def tril(self, diagonal:int=0) -> Tensor:
|
2409
2689
|
"""
|
@@ -2426,7 +2706,7 @@ class Tensor(SimpleMathTrait):
|
|
2426
2706
|
print(t.tril(diagonal=-1).numpy())
|
2427
2707
|
```
|
2428
2708
|
"""
|
2429
|
-
return Tensor._tri(self.shape[-2], self.shape[-1], diagonal=diagonal+1, device=self.device, dtype=dtypes.bool).where(
|
2709
|
+
return Tensor._tri(self.shape[-2], self.shape[-1], diagonal=diagonal+1, device=self.device, dtype=dtypes.bool).where(self.zeros_like(), self)
|
2430
2710
|
|
2431
2711
|
def interpolate(self, size:tuple[int, ...], mode:str="linear", align_corners:bool=False) -> Tensor:
|
2432
2712
|
"""
|
@@ -2462,7 +2742,7 @@ class Tensor(SimpleMathTrait):
|
|
2462
2742
|
|
2463
2743
|
def _pre_scatter(self, dim:int, index:Tensor, src:Tensor) -> tuple[Tensor, Tensor]:
|
2464
2744
|
index, dim = index.to(self.device), self._resolve_dim(dim)
|
2465
|
-
assert index.ndim == self.ndim == src.ndim, f"self.ndim, index.ndim and src.
|
2745
|
+
assert index.ndim == self.ndim == src.ndim, f"self.ndim, index.ndim and src.ndim must all equal, {self.ndim=} {index.ndim=} {src.ndim=}"
|
2466
2746
|
assert all((d == dim or self_ >= index_) and src_ >= index_ for d,(self_,index_,src_) in enumerate(zip(self.shape, index.shape, src.shape))), \
|
2467
2747
|
f"All dimensions of {index.shape=} should be <= to all dimensions of {src.shape=} and all dimensions except dimension {dim} of {self.shape=}"
|
2468
2748
|
if self.dtype != src.dtype: raise RuntimeError(f"expect {self.dtype=} to be equal to {src.dtype=}")
|
@@ -2475,7 +2755,7 @@ class Tensor(SimpleMathTrait):
|
|
2475
2755
|
src, mask = (x.pad(tuple((0, self.shape[i] - x.shape[i]) if i != dim else None for i in range(self.ndim)) + (None,)) for x in (src, mask))
|
2476
2756
|
return src, mask
|
2477
2757
|
|
2478
|
-
def scatter(self, dim:int, index:Tensor, src:
|
2758
|
+
def scatter(self, dim:int, index:Tensor, src:Tensor|ConstType, reduce:Literal['multiply', 'add']|None=None) -> Tensor:
|
2479
2759
|
"""
|
2480
2760
|
Scatters `src` values along an axis specified by `dim`.
|
2481
2761
|
Apply `add` or `multiply` reduction operation with `reduce`.
|
@@ -2540,20 +2820,103 @@ class Tensor(SimpleMathTrait):
|
|
2540
2820
|
```
|
2541
2821
|
"""
|
2542
2822
|
src, mask = self._pre_scatter(dim, index, src)
|
2543
|
-
def _inv_mask(a:
|
2544
|
-
|
2545
|
-
if reduce == "
|
2546
|
-
if reduce == "prod": return mask.where(src, 1).prod(-1, acc_dtype=self.dtype).mul(self if include_self else _inv_mask(self, 1))
|
2823
|
+
def _inv_mask(a:Tensor|ConstType, b:Tensor|ConstType) -> Tensor: return mask.any(-1).logical_not().where(a, b)
|
2824
|
+
if reduce == "sum": return mask.where(src, 0).sum(-1).add(self if include_self else _inv_mask(self, 0))
|
2825
|
+
if reduce == "prod": return mask.where(src, 1).prod(-1).mul(self if include_self else _inv_mask(self, 1))
|
2547
2826
|
if reduce == "amax": return mask.where(src, m := dtypes.min(src.dtype)).max(-1).maximum(self if include_self else _inv_mask(self, m))
|
2548
2827
|
if reduce == "amin": return mask.where(src, m := dtypes.max(src.dtype)).min(-1).minimum(self if include_self else _inv_mask(self, m))
|
2549
2828
|
if reduce == "mean":
|
2550
|
-
count = mask.where(1, 0).sum(-1
|
2551
|
-
return mask.where(src, 0).sum(-1
|
2829
|
+
count = mask.where(1, 0).sum(-1).add(1 if include_self else _inv_mask(1, 0))
|
2830
|
+
return mask.where(src, 0).sum(-1).add(self if include_self else _inv_mask(self, 0)).div(count)
|
2552
2831
|
raise RuntimeError(f"{reduce=} must be one of 'sum', 'prod', 'mean', 'amax', 'amin'")
|
2553
2832
|
|
2833
|
+
def sort(self, dim:int=-1, descending:bool=False) -> tuple[Tensor, Tensor]:
|
2834
|
+
"""
|
2835
|
+
Performs a bitonic sort on the tensor along the specified dimension.
|
2836
|
+
|
2837
|
+
Order of indices for equivalent elements is always preserved.
|
2838
|
+
|
2839
|
+
See: https://en.wikipedia.org/wiki/Bitonic_sorter
|
2840
|
+
|
2841
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2842
|
+
t = Tensor([[0.1, 0.5, 1.2, 3.4, 2.1], [2.2, 1.9, 0.3, 4.5, 0.8]])
|
2843
|
+
print(t.numpy())
|
2844
|
+
```
|
2845
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2846
|
+
sorted_values, indices = t.sort(dim=1, descending=True)
|
2847
|
+
print(sorted_values.numpy())
|
2848
|
+
print(indices.numpy())
|
2849
|
+
```
|
2850
|
+
"""
|
2851
|
+
x, dim = self, self._resolve_dim(dim)
|
2852
|
+
if (orig_len:= x.shape[dim]) <= 1: return x, x.zeros_like(dtype=dtypes.default_int)
|
2853
|
+
# pad to power of 2
|
2854
|
+
n_stages = (orig_len-1).bit_length()
|
2855
|
+
pads = tuple((0, 2**n_stages - orig_len) if i == dim else None for i in range(x.ndim))
|
2856
|
+
x = x.pad(pads, value=dtypes.min(x.dtype) if descending else dtypes.max(x.dtype)).unflatten(dim, (2,)*n_stages)
|
2857
|
+
# https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort1.svg
|
2858
|
+
for stage in range(1, n_stages+1):
|
2859
|
+
if stage != n_stages:
|
2860
|
+
# flip so arrows of green boxes point the same way as blue boxes
|
2861
|
+
crossover_dim = dim + n_stages - stage - 1
|
2862
|
+
blue_box, green_box = x.split(1, crossover_dim)
|
2863
|
+
flip_dims = tuple(-i for i in range(1, stage+1+(self.ndim-dim)))
|
2864
|
+
x = (blue_box.cat(green_box.flip(flip_dims), dim=crossover_dim)).contiguous()
|
2865
|
+
for substage in range(stage-1, -1, -1):
|
2866
|
+
partner_dim = dim + n_stages - substage - 1
|
2867
|
+
x_top, x_bottom = x.split(1, partner_dim)
|
2868
|
+
x_larger, x_smaller = x_top.maximum(x_bottom), x_top.minimum(x_bottom)
|
2869
|
+
x = (x_larger.cat(x_smaller, dim=partner_dim) if descending else x_smaller.cat(x_larger, dim=partner_dim)).contiguous()
|
2870
|
+
if stage != n_stages:
|
2871
|
+
# flip wires back to undo the crossover
|
2872
|
+
blue_box, flipped_green_box = x.split(1, crossover_dim)
|
2873
|
+
x = blue_box.cat(flipped_green_box.flip(flip_dims), dim=crossover_dim)
|
2874
|
+
x = x.flatten(dim, dim+n_stages-1).shrink(tuple((0, s) for s in self.shape))
|
2875
|
+
# compute indices for sorted values
|
2876
|
+
mask = Tensor.ones(orig_len, orig_len, dtype=dtypes.bool, device=self.device).tril().reshape((None, None) + (1,)*(self.ndim-dim-1))
|
2877
|
+
def compute_counts(t:Tensor): return (mask & (t.unsqueeze(dim) == t.unsqueeze(dim+1))).sum(dim+1)
|
2878
|
+
count_orig, count_sorted = compute_counts(self), compute_counts(x)
|
2879
|
+
cond = (self.unsqueeze(dim+1) == x.unsqueeze(dim)) & (count_orig.unsqueeze(dim+1) == count_sorted.unsqueeze(dim))
|
2880
|
+
idx = Tensor.arange(orig_len, device=self.device).reshape(tuple(orig_len if i == dim else 1 for i in range(x.ndim)))
|
2881
|
+
idx = (cond * idx.unsqueeze(dim+1)).sum(dim)
|
2882
|
+
return x, idx
|
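The nested `stage`/`substage` loops build the standard bitonic network: after padding to a power of two `2**n_stages`, stage `s` performs `s` compare/exchange substages, for `n_stages*(n_stages+1)//2` rounds in total. The helper below just restates that count; `bitonic_rounds` is the editor's name, not part of the API:

```python
def bitonic_rounds(n: int) -> int:
    # rounds of compare/exchange for a length-n input padded to 2**k
    k = max(n - 1, 0).bit_length()       # n_stages in the diff above
    return k * (k + 1) // 2

assert bitonic_rounds(2) == 1
assert bitonic_rounds(8) == 6            # k=3 -> 3+2+1
assert bitonic_rounds(1000) == 55        # padded to 1024, k=10
```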
2883
|
+
|
2884
|
+
def argsort(self, dim:int=-1, descending:bool=False) -> Tensor:
|
2885
|
+
"""
|
2886
|
+
Returns the indices that sort the tensor along the given `dim` by value, in the order specified by `descending`.
|
2887
|
+
|
2888
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2889
|
+
t = Tensor([[2, 3, 4, 1], [1, 4, 3, 2]])
|
2890
|
+
print(t.argsort().numpy())
|
2891
|
+
```
|
2892
|
+
"""
|
2893
|
+
return self.sort(dim, descending)[1]
|
2894
|
+
|
2895
|
+
def topk(self, k:int, dim:int=-1, largest:bool=True, sorted_:bool=True) -> tuple[Tensor, Tensor]:
|
2896
|
+
"""
|
2897
|
+
Computes the top-k elements of the tensor along the specified `dim`.
|
2898
|
+
|
2899
|
+
Order of indices for equivalent elements is always preserved.
|
2900
|
+
|
2901
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2902
|
+
t = Tensor([[0.1, 0.5, 1.2, 3.4, 2.1], [2.2, 1.9, 0.3, 4.5, 0.8]])
|
2903
|
+
print(t.numpy())
|
2904
|
+
```
|
2905
|
+
```python exec="true" source="above" session="tensor" result="python"
|
2906
|
+
topk_values, topk_indices = t.topk(2, dim=1)
|
2907
|
+
print(topk_values.numpy())
|
2908
|
+
print(topk_indices.numpy())
|
2909
|
+
```
|
2910
|
+
"""
|
2911
|
+
if not sorted_: raise NotImplementedError("topk with sorted_=False is not supported")
|
2912
|
+
if k > self.shape[dim:=self._resolve_dim(dim)]: raise ValueError(f"selected index {k=} is out of range")
|
2913
|
+
x, idx = self.sort(dim, descending=largest)
|
2914
|
+
shrink_to_k = tuple((0, k) if i == dim else None for i in range(self.ndim))
|
2915
|
+
return x.shrink(shrink_to_k), idx.shrink(shrink_to_k)
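Since `topk` is sort-then-shrink, `largest=False` returns the `k` smallest values per slice. A quick usage sketch (illustrative only, assuming the public `from tinygrad import Tensor` entry point):

```python
from tinygrad import Tensor

t = Tensor([[0.1, 0.5, 1.2, 3.4, 2.1], [2.2, 1.9, 0.3, 4.5, 0.8]])
vals, idxs = t.topk(2, dim=1, largest=False)  # two smallest values per row
print(vals.numpy())  # expected [[0.1, 0.5], [0.3, 0.8]]
print(idxs.numpy())  # expected [[0, 1], [2, 4]]
```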
|
2916
|
+
|
2554
2917
|
# ***** unary ops *****
|
2555
2918
|
|
2556
|
-
def logical_not(self):
|
2919
|
+
def logical_not(self) -> Tensor:
|
2557
2920
|
"""
|
2558
2921
|
Computes the logical NOT of the tensor element-wise.
|
2559
2922
|
|
@@ -2562,7 +2925,8 @@ class Tensor(SimpleMathTrait):
|
|
2562
2925
|
```
|
2563
2926
|
"""
|
2564
2927
|
return self.cast(dtypes.bool)._apply_broadcasted_uop(UOp.ne, True)
|
2565
|
-
|
2928
|
+
|
2929
|
+
def neg(self) -> Tensor:
|
2566
2930
|
"""
|
2567
2931
|
Negates the tensor element-wise.
|
2568
2932
|
|
@@ -2571,17 +2935,29 @@ class Tensor(SimpleMathTrait):
|
|
2571
2935
|
```
|
2572
2936
|
"""
|
2573
2937
|
return self*-1 if self.dtype != dtypes.bool else self.logical_not()
|
2574
|
-
|
2938
|
+
|
2939
|
+
def contiguous(self, **kwargs) -> Tensor:
|
2575
2940
|
"""
|
2576
2941
|
Returns a contiguous tensor.
|
2577
2942
|
"""
|
2578
|
-
return self._apply_uop(UOp.contiguous)
|
2579
|
-
|
2943
|
+
return self._apply_uop(UOp.contiguous, **kwargs)
|
2944
|
+
|
2945
|
+
def fuse(self) -> Tensor:
|
2946
|
+
"""
|
2947
|
+
Fuses this computation into a single kernel, reaching back to the Ops.CONTIGUOUS boundaries on the inputs.
|
2948
|
+
|
2949
|
+
Useful for single kernel softmax and flash attention.
|
2950
|
+
Careful, this can break codegen or make kernels really slow.
|
2951
|
+
"""
|
2952
|
+
return self._apply_uop(UOp.fuse)
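A hedged usage sketch of the new `fuse` hint; whether the chain actually lands in one kernel (and whether that is profitable) depends on the scheduler and backend, as the warning above notes:

```python
from tinygrad import Tensor

x = Tensor.randn(32, 1024)
y = x.softmax(-1).fuse()  # request that the softmax chain be realized as a single kernel
print(y.shape)
```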
|
2953
|
+
|
2954
|
+
def contiguous_backward(self) -> Tensor:
|
2580
2955
|
"""
|
2581
2956
|
Inserts a contiguous operation in the backward pass.
|
2582
2957
|
"""
|
2583
2958
|
return self._apply_uop(UOp.contiguous_backward)
|
2584
|
-
|
2959
|
+
|
2960
|
+
def log(self) -> Tensor:
|
2585
2961
|
"""
|
2586
2962
|
Computes the natural logarithm element-wise.
|
2587
2963
|
|
@@ -2592,7 +2968,8 @@ class Tensor(SimpleMathTrait):
|
|
2592
2968
|
```
|
2593
2969
|
"""
|
2594
2970
|
return self.log2()*math.log(2)
|
2595
|
-
|
2971
|
+
|
2972
|
+
def log2(self) -> Tensor:
|
2596
2973
|
"""
|
2597
2974
|
Computes the base-2 logarithm element-wise.
|
2598
2975
|
|
@@ -2603,7 +2980,8 @@ class Tensor(SimpleMathTrait):
|
|
2603
2980
|
```
|
2604
2981
|
"""
|
2605
2982
|
return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.log2)
|
2606
|
-
|
2983
|
+
|
2984
|
+
def exp(self) -> Tensor:
|
2607
2985
|
"""
|
2608
2986
|
Computes the exponential function element-wise.
|
2609
2987
|
|
@@ -2614,7 +2992,8 @@ class Tensor(SimpleMathTrait):
|
|
2614
2992
|
```
|
2615
2993
|
"""
|
2616
2994
|
return self.mul(1/math.log(2)).exp2()
|
2617
|
-
|
2995
|
+
|
2996
|
+
def exp2(self) -> Tensor:
|
2618
2997
|
"""
|
2619
2998
|
Computes the base-2 exponential function element-wise.
|
2620
2999
|
|
@@ -2625,19 +3004,19 @@ class Tensor(SimpleMathTrait):
|
|
2625
3004
|
```
|
2626
3005
|
"""
|
2627
3006
|
return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.exp2)
|
2628
|
-
|
3007
|
+
|
3008
|
+
def relu(self) -> Tensor:
|
2629
3009
|
"""
|
2630
3010
|
Applies the Rectified Linear Unit (ReLU) function element-wise.
|
2631
3011
|
|
2632
|
-
- Described: https://paperswithcode.com/method/relu
|
2633
|
-
|
2634
3012
|
```python exec="true" source="above" session="tensor" result="python"
|
2635
3013
|
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).relu().numpy())
|
2636
3014
|
```
|
2637
3015
|
"""
|
3016
|
+
# NOTE: if you write this as self.maximum(0) the gradient is wrong, passing through half when self is 0
|
2638
3017
|
return (self>0).where(self, 0)
|
2639
3018
|
|
2640
|
-
def sigmoid(self):
|
3019
|
+
def sigmoid(self) -> Tensor:
|
2641
3020
|
"""
|
2642
3021
|
Applies the Sigmoid function element-wise.
|
2643
3022
|
|
@@ -2649,12 +3028,23 @@ class Tensor(SimpleMathTrait):
|
|
2649
3028
|
"""
|
2650
3029
|
return (1 + (self * (-1/math.log(2))).exp2()).reciprocal()
|
2651
3030
|
|
2652
|
-
def
|
3031
|
+
def logsigmoid(self) -> Tensor:
|
3032
|
+
"""
|
3033
|
+
Applies the LogSigmoid function element-wise.
|
3034
|
+
|
3035
|
+
- See: https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.logsigmoid.html
|
3036
|
+
|
3037
|
+
```python exec="true" source="above" session="tensor" result="python"
|
3038
|
+
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).logsigmoid().numpy())
|
3039
|
+
```
|
3040
|
+
"""
|
3041
|
+
return -(-self).softplus()
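The one-liner follows from a standard identity for the logistic sigmoid $\sigma(x)$:

$$\log \sigma(x) = \log\frac{1}{1+e^{-x}} = -\log\left(1+e^{-x}\right) = -\operatorname{softplus}(-x),$$

so `logsigmoid` inherits the numerical-stability threshold of `softplus` defined further down.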
|
3042
|
+
|
3043
|
+
def hardsigmoid(self, alpha:float=1/6, beta:float=0.5) -> Tensor:
|
2653
3044
|
"""
|
2654
3045
|
Applies the Hardsigmoid function element-wise.
|
2655
|
-
NOTE: default `alpha` and `beta` values
|
3046
|
+
NOTE: default `alpha` and `beta` values are taken from torch
|
2656
3047
|
|
2657
|
-
- Described: https://paperswithcode.com/method/hard-sigmoid
|
2658
3048
|
- See: https://pytorch.org/docs/stable/generated/torch.nn.functional.hardsigmoid.html
|
2659
3049
|
|
2660
3050
|
```python exec="true" source="above" session="tensor" result="python"
|
@@ -2663,7 +3053,7 @@ class Tensor(SimpleMathTrait):
|
|
2663
3053
|
"""
|
2664
3054
|
return (alpha * self + beta).relu() - (alpha * self + beta - 1).relu()
|
2665
3055
|
|
2666
|
-
def sqrt(self):
|
3056
|
+
def sqrt(self) -> Tensor:
|
2667
3057
|
"""
|
2668
3058
|
Computes the square root of the tensor element-wise.
|
2669
3059
|
|
@@ -2672,7 +3062,8 @@ class Tensor(SimpleMathTrait):
|
|
2672
3062
|
```
|
2673
3063
|
"""
|
2674
3064
|
return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.sqrt)
|
2675
|
-
|
3065
|
+
|
3066
|
+
def rsqrt(self) -> Tensor:
|
2676
3067
|
"""
|
2677
3068
|
Computes the reciprocal of the square root of the tensor element-wise.
|
2678
3069
|
|
@@ -2681,7 +3072,8 @@ class Tensor(SimpleMathTrait):
|
|
2681
3072
|
```
|
2682
3073
|
"""
|
2683
3074
|
return self.sqrt().reciprocal()
|
2684
|
-
|
3075
|
+
|
3076
|
+
def sin(self) -> Tensor:
|
2685
3077
|
"""
|
2686
3078
|
Computes the sine of the tensor element-wise.
|
2687
3079
|
|
@@ -2690,7 +3082,8 @@ class Tensor(SimpleMathTrait):
|
|
2690
3082
|
```
|
2691
3083
|
"""
|
2692
3084
|
return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.sin)
|
2693
|
-
|
3085
|
+
|
3086
|
+
def cos(self) -> Tensor:
|
2694
3087
|
"""
|
2695
3088
|
Computes the cosine of the tensor element-wise.
|
2696
3089
|
|
@@ -2699,7 +3092,8 @@ class Tensor(SimpleMathTrait):
|
|
2699
3092
|
```
|
2700
3093
|
"""
|
2701
3094
|
return ((math.pi/2)-self).sin()
|
2702
|
-
|
3095
|
+
|
3096
|
+
def tan(self) -> Tensor:
|
2703
3097
|
"""
|
2704
3098
|
Computes the tangent of the tensor element-wise.
|
2705
3099
|
|
@@ -2709,7 +3103,7 @@ class Tensor(SimpleMathTrait):
|
|
2709
3103
|
"""
|
2710
3104
|
return self.sin() / self.cos()
|
2711
3105
|
|
2712
|
-
def asin(self):
|
3106
|
+
def asin(self) -> Tensor:
|
2713
3107
|
"""
|
2714
3108
|
Computes the inverse sine (arcsine) of the tensor element-wise.
|
2715
3109
|
|
@@ -2722,7 +3116,7 @@ class Tensor(SimpleMathTrait):
|
|
2722
3116
|
x = math.pi / 2 - (1.0 - self.abs()).sqrt() * polyN(self.abs(), coefficients)
|
2723
3117
|
return self.sign() * x
|
2724
3118
|
|
2725
|
-
def acos(self):
|
3119
|
+
def acos(self) -> Tensor:
|
2726
3120
|
"""
|
2727
3121
|
Computes the inverse cosine (arccosine) of the tensor element-wise.
|
2728
3122
|
|
@@ -2732,7 +3126,7 @@ class Tensor(SimpleMathTrait):
|
|
2732
3126
|
"""
|
2733
3127
|
return math.pi / 2 - self.asin()
|
2734
3128
|
|
2735
|
-
def atan(self):
|
3129
|
+
def atan(self) -> Tensor:
|
2736
3130
|
"""
|
2737
3131
|
Computes the inverse tangent (arctan) of the tensor element-wise.
|
2738
3132
|
|
@@ -2752,7 +3146,8 @@ class Tensor(SimpleMathTrait):
|
|
2752
3146
|
print(Tensor([-3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]).trunc().numpy())
|
2753
3147
|
```
|
2754
3148
|
"""
|
2755
|
-
return self.
|
3149
|
+
return self._apply_uop(UOp.trunc)
|
3150
|
+
|
2756
3151
|
def ceil(self: Tensor) -> Tensor:
|
2757
3152
|
"""
|
2758
3153
|
Rounds the tensor element-wise towards positive infinity.
|
@@ -2762,6 +3157,7 @@ class Tensor(SimpleMathTrait):
|
|
2762
3157
|
```
|
2763
3158
|
"""
|
2764
3159
|
return (self > (b := self.trunc())).where(b+1, b)
|
3160
|
+
|
2765
3161
|
def floor(self: Tensor) -> Tensor:
|
2766
3162
|
"""
|
2767
3163
|
Rounds the tensor element-wise towards negative infinity.
|
@@ -2771,6 +3167,7 @@ class Tensor(SimpleMathTrait):
|
|
2771
3167
|
```
|
2772
3168
|
"""
|
2773
3169
|
return (self < (b := self.trunc())).where(b-1, b)
|
3170
|
+
|
2774
3171
|
def round(self: Tensor) -> Tensor:
|
2775
3172
|
"""
|
2776
3173
|
Rounds the tensor element-wise with rounding half to even.
|
@@ -2779,9 +3176,9 @@ class Tensor(SimpleMathTrait):
|
|
2779
3176
|
print(Tensor([-3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]).round().numpy())
|
2780
3177
|
```
|
2781
3178
|
"""
|
2782
|
-
return ((self > 0) == ((b := self.
|
3179
|
+
return ((self > 0) == ((b := self.trunc() / 2.0).trunc() == b)).where((self - 0.5).ceil(), (self + 0.5).floor())
|
2783
3180
|
|
2784
|
-
def isinf(self:Tensor, detect_positive:bool=True, detect_negative:bool=True):
|
3181
|
+
def isinf(self:Tensor, detect_positive:bool=True, detect_negative:bool=True) -> Tensor:
|
2785
3182
|
"""
|
2786
3183
|
Checks the tensor element-wise to return True where the element is infinity, otherwise returns False
|
2787
3184
|
|
@@ -2790,7 +3187,8 @@ class Tensor(SimpleMathTrait):
|
|
2790
3187
|
```
|
2791
3188
|
"""
|
2792
3189
|
return (self == float("inf")) * detect_positive + (self == float("-inf")) * detect_negative
|
2793
|
-
|
3190
|
+
|
3191
|
+
def isnan(self:Tensor) -> Tensor:
|
2794
3192
|
"""
|
2795
3193
|
Checks the tensor element-wise to return True where the element is NaN, otherwise returns False
|
2796
3194
|
|
@@ -2800,7 +3198,17 @@ class Tensor(SimpleMathTrait):
|
|
2800
3198
|
"""
|
2801
3199
|
return self != self
|
2802
3200
|
|
2803
|
-
def
|
3201
|
+
def isfinite(self:Tensor) -> Tensor:
|
3202
|
+
"""
|
3203
|
+
Checks the tensor element-wise to return True where the element is finite, otherwise returns False
|
3204
|
+
|
3205
|
+
```python exec="true" source="above" session="tensor" result="python"
|
3206
|
+
print(Tensor([1, float('inf'), 2, float('-inf'), float('nan')]).isfinite().numpy())
|
3207
|
+
```
|
3208
|
+
"""
|
3209
|
+
return (self.isinf()|self.isnan()).logical_not()
|
3210
|
+
|
3211
|
+
def lerp(self, end:Tensor, weight:Tensor|float) -> Tensor:
|
2804
3212
|
"""
|
2805
3213
|
Linearly interpolates between `self` and `end` by `weight`.
|
2806
3214
|
|
@@ -2813,7 +3221,7 @@ class Tensor(SimpleMathTrait):
|
|
2813
3221
|
return (self+(((end - self).cast(dtypes.int8) * w_i + (1<<W_PREC-1)).cast(dtypes.uint16) >> W_PREC)).cast(dtypes.uint8)
|
2814
3222
|
return self + (end - self) * weight
|
2815
3223
|
|
2816
|
-
def square(self):
|
3224
|
+
def square(self) -> Tensor:
|
2817
3225
|
"""
|
2818
3226
|
Squares the tensor element-wise.
|
2819
3227
|
Equivalent to `self*self`.
|
@@ -2823,7 +3231,8 @@ class Tensor(SimpleMathTrait):
|
|
2823
3231
|
```
|
2824
3232
|
"""
|
2825
3233
|
return self*self
|
2826
|
-
|
3234
|
+
|
3235
|
+
def clamp(self, min_=None, max_=None) -> Tensor:
|
2827
3236
|
"""
|
2828
3237
|
Clips (clamps) the values in the tensor between `min_` and `max_` element-wise.
|
2829
3238
|
If `min_` is `None`, there is no lower bound. If `max_` is None, there is no upper bound.
|
@@ -2835,12 +3244,14 @@ class Tensor(SimpleMathTrait):
|
|
2835
3244
|
if min_ is None and max_ is None: raise RuntimeError("at least one of 'min_' or 'max_' must not be None")
|
2836
3245
|
ret = self.maximum(min_) if min_ is not None else self
|
2837
3246
|
return ret.minimum(max_) if max_ is not None else ret
|
2838
|
-
|
3247
|
+
|
3248
|
+
def clip(self, min_=None, max_=None) -> Tensor:
|
2839
3249
|
"""
|
2840
3250
|
Alias for `Tensor.clamp`.
|
2841
3251
|
"""
|
2842
3252
|
return self.clamp(min_, max_)
|
2843
|
-
|
3253
|
+
|
3254
|
+
def sign(self) -> Tensor:
|
2844
3255
|
"""
|
2845
3256
|
Returns the sign of the tensor element-wise.
|
2846
3257
|
|
@@ -2849,7 +3260,8 @@ class Tensor(SimpleMathTrait):
|
|
2849
3260
|
```
|
2850
3261
|
"""
|
2851
3262
|
return self.ne(0).where((self<0).where(self.full_like(-1), self.full_like(1)), self.full_like(0)) + self*0
|
2852
|
-
|
3263
|
+
|
3264
|
+
def abs(self) -> Tensor:
|
2853
3265
|
"""
|
2854
3266
|
Computes the absolute value of the tensor element-wise.
|
2855
3267
|
|
@@ -2858,9 +3270,10 @@ class Tensor(SimpleMathTrait):
|
|
2858
3270
|
```
|
2859
3271
|
"""
|
2860
3272
|
return self * self.sign()
|
2861
|
-
|
3273
|
+
|
3274
|
+
def reciprocal(self) -> Tensor:
|
2862
3275
|
"""
|
2863
|
-
|
3276
|
+
Computes `1/x` element-wise.
|
2864
3277
|
|
2865
3278
|
```python exec="true" source="above" session="tensor" result="python"
|
2866
3279
|
print(Tensor([1., 2., 3., 4.]).reciprocal().numpy())
|
@@ -2870,11 +3283,10 @@ class Tensor(SimpleMathTrait):
|
|
2870
3283
|
|
2871
3284
|
# ***** activation functions *****
|
2872
3285
|
|
2873
|
-
def elu(self, alpha=1.0):
|
3286
|
+
def elu(self, alpha=1.0) -> Tensor:
|
2874
3287
|
"""
|
2875
3288
|
Applies the Exponential Linear Unit (ELU) function element-wise.
|
2876
3289
|
|
2877
|
-
- Described: https://paperswithcode.com/method/elu
|
2878
3290
|
- Paper: https://arxiv.org/abs/1511.07289v5
|
2879
3291
|
|
2880
3292
|
```python exec="true" source="above" session="tensor" result="python"
|
@@ -2883,11 +3295,10 @@ class Tensor(SimpleMathTrait):
|
|
2883
3295
|
"""
|
2884
3296
|
return self.relu() - alpha*(1-self.exp()).relu()
|
2885
3297
|
|
2886
|
-
def celu(self, alpha=1.0):
|
3298
|
+
def celu(self, alpha=1.0) -> Tensor:
|
2887
3299
|
"""
|
2888
3300
|
Applies the Continuously differentiable Exponential Linear Unit (CELU) function element-wise.
|
2889
3301
|
|
2890
|
-
- Described: https://paperswithcode.com/method/celu
|
2891
3302
|
- Paper: https://arxiv.org/abs/1704.07483
|
2892
3303
|
|
2893
3304
|
```python exec="true" source="above" session="tensor" result="python"
|
@@ -2896,11 +3307,10 @@ class Tensor(SimpleMathTrait):
|
|
2896
3307
|
"""
|
2897
3308
|
return self.maximum(0) + (alpha * ((self / alpha).exp() - 1)).minimum(0)
|
2898
3309
|
|
2899
|
-
def selu(self, alpha=1.67326, gamma=1.0507):
|
3310
|
+
def selu(self, alpha=1.67326, gamma=1.0507) -> Tensor:
|
2900
3311
|
"""
|
2901
3312
|
Applies the Scaled Exponential Linear Unit (SELU) function element-wise.
|
2902
3313
|
|
2903
|
-
- Described: https://paperswithcode.com/method/selu
|
2904
3314
|
- Paper: https://arxiv.org/abs/1706.02515v5
|
2905
3315
|
|
2906
3316
|
```python exec="true" source="above" session="tensor" result="python"
|
@@ -2909,7 +3319,7 @@ class Tensor(SimpleMathTrait):
|
|
2909
3319
|
"""
|
2910
3320
|
return gamma * (self >= 0).detach().where(self, alpha * (self.exp() - 1))
|
2911
3321
|
|
2912
|
-
def swish(self):
|
3322
|
+
def swish(self) -> Tensor:
|
2913
3323
|
"""
|
2914
3324
|
See `.silu()`
|
2915
3325
|
|
@@ -2921,11 +3331,10 @@ class Tensor(SimpleMathTrait):
|
|
2921
3331
|
"""
|
2922
3332
|
return self * self.sigmoid()
|
2923
3333
|
|
2924
|
-
def silu(self):
|
3334
|
+
def silu(self) -> Tensor:
|
2925
3335
|
"""
|
2926
3336
|
Applies the Sigmoid Linear Unit (SiLU) function element-wise.
|
2927
3337
|
|
2928
|
-
- Described: https://paperswithcode.com/method/silu
|
2929
3338
|
- Paper: https://arxiv.org/abs/1606.08415
|
2930
3339
|
|
2931
3340
|
```python exec="true" source="above" session="tensor" result="python"
|
@@ -2934,11 +3343,10 @@ class Tensor(SimpleMathTrait):
|
|
2934
3343
|
"""
|
2935
3344
|
return self.swish() # The SiLU function is also known as the swish function.
|
2936
3345
|
|
2937
|
-
def relu6(self):
|
3346
|
+
def relu6(self) -> Tensor:
|
2938
3347
|
"""
|
2939
3348
|
Applies the ReLU6 function element-wise.
|
2940
3349
|
|
2941
|
-
- Described: https://paperswithcode.com/method/relu6
|
2942
3350
|
- Paper: https://arxiv.org/abs/1704.04861v1
|
2943
3351
|
|
2944
3352
|
```python exec="true" source="above" session="tensor" result="python"
|
@@ -2947,11 +3355,10 @@ class Tensor(SimpleMathTrait):
|
|
2947
3355
|
"""
|
2948
3356
|
return self.relu() - (self-6).relu()
|
2949
3357
|
|
2950
|
-
def hardswish(self):
|
3358
|
+
def hardswish(self) -> Tensor:
|
2951
3359
|
"""
|
2952
3360
|
Applies the Hardswish function element-wise.
|
2953
3361
|
|
2954
|
-
- Described: https://paperswithcode.com/method/hard-swish
|
2955
3362
|
- Paper: https://arxiv.org/abs/1905.02244v5
|
2956
3363
|
|
2957
3364
|
```python exec="true" source="above" session="tensor" result="python"
|
@@ -2960,7 +3367,7 @@ class Tensor(SimpleMathTrait):
|
|
2960
3367
|
"""
|
2961
3368
|
return self * (self+3).relu6() * (1/6)
|
2962
3369
|
|
2963
|
-
def tanh(self):
|
3370
|
+
def tanh(self) -> Tensor:
|
2964
3371
|
"""
|
2965
3372
|
Applies the Hyperbolic Tangent (tanh) function element-wise.
|
2966
3373
|
|
@@ -2972,7 +3379,7 @@ class Tensor(SimpleMathTrait):
|
|
2972
3379
|
"""
|
2973
3380
|
return 2.0 * ((2.0 * self).sigmoid()) - 1.0
|
2974
3381
|
|
2975
|
-
def sinh(self):
|
3382
|
+
def sinh(self) -> Tensor:
|
2976
3383
|
"""
|
2977
3384
|
Applies the Hyperbolic Sine (sinh) function element-wise.
|
2978
3385
|
|
@@ -2984,7 +3391,7 @@ class Tensor(SimpleMathTrait):
|
|
2984
3391
|
"""
|
2985
3392
|
return (self.exp() - self.neg().exp()) / 2
|
2986
3393
|
|
2987
|
-
def cosh(self):
|
3394
|
+
def cosh(self) -> Tensor:
|
2988
3395
|
"""
|
2989
3396
|
Applies the Hyperbolic Cosine (cosh) function element-wise.
|
2990
3397
|
|
@@ -2996,7 +3403,7 @@ class Tensor(SimpleMathTrait):
|
|
2996
3403
|
"""
|
2997
3404
|
return (self.exp() + self.neg().exp()) / 2
|
2998
3405
|
|
2999
|
-
def atanh(self):
|
3406
|
+
def atanh(self) -> Tensor:
|
3000
3407
|
"""
|
3001
3408
|
Applies the Inverse Hyperbolic Tangent (atanh) function element-wise.
|
3002
3409
|
|
@@ -3008,7 +3415,7 @@ class Tensor(SimpleMathTrait):
|
|
3008
3415
|
"""
|
3009
3416
|
return ((1 + self)/(1 - self)).log() / 2
|
3010
3417
|
|
3011
|
-
def asinh(self):
|
3418
|
+
def asinh(self) -> Tensor:
|
3012
3419
|
"""
|
3013
3420
|
Applies the Inverse Hyperbolic Sine (asinh) function element-wise.
|
3014
3421
|
|
@@ -3020,7 +3427,7 @@ class Tensor(SimpleMathTrait):
|
|
3020
3427
|
"""
|
3021
3428
|
return (self + (self.square() + 1).sqrt()).log()
|
3022
3429
|
|
3023
|
-
def acosh(self):
|
3430
|
+
def acosh(self) -> Tensor:
|
3024
3431
|
"""
|
3025
3432
|
Applies the Inverse Hyperbolic Cosine (acosh) function element-wise.
|
3026
3433
|
|
@@ -3032,19 +3439,17 @@ class Tensor(SimpleMathTrait):
|
|
3032
3439
|
"""
|
3033
3440
|
return (self + (self.square() - 1).sqrt()).log()
|
3034
3441
|
|
3035
|
-
def hardtanh(self, min_val=-1, max_val=1):
|
3442
|
+
def hardtanh(self, min_val=-1, max_val=1) -> Tensor:
|
3036
3443
|
"""
|
3037
3444
|
Applies the Hardtanh function element-wise.
|
3038
3445
|
|
3039
|
-
- Described: https://paperswithcode.com/method/hardtanh-activation
|
3040
|
-
|
3041
3446
|
```python exec="true" source="above" session="tensor" result="python"
|
3042
3447
|
print(Tensor([-1.5, -1.0, -0.5, 0., 0.5, 1.0, 1.5]).hardtanh().numpy())
|
3043
3448
|
```
|
3044
3449
|
"""
|
3045
3450
|
return self.clip(min_val, max_val)
|
3046
3451
|
|
3047
|
-
def erf(self):
|
3452
|
+
def erf(self) -> Tensor:
|
3048
3453
|
"""
|
3049
3454
|
Applies error function element-wise.
|
3050
3455
|
|
@@ -3058,11 +3463,10 @@ class Tensor(SimpleMathTrait):
|
|
3058
3463
|
t = 1.0 / (1.0 + 0.3275911 * self.abs())
|
3059
3464
|
return self.sign() * (1.0 - t * polyN(t, [1.061405429, -1.453152027, 1.421413741, -0.284496736, 0.254829592]) * (-self.square()).exp())
|
3060
3465
|
|
3061
|
-
def gelu(self):
|
3466
|
+
def gelu(self) -> Tensor:
|
3062
3467
|
"""
|
3063
3468
|
Applies the Gaussian Error Linear Unit (GELU) function element-wise.
|
3064
3469
|
|
3065
|
-
- Described: https://paperswithcode.com/method/gelu
|
3066
3470
|
- Paper: https://arxiv.org/abs/1606.08415v5
|
3067
3471
|
|
3068
3472
|
```python exec="true" source="above" session="tensor" result="python"
|
@@ -3071,38 +3475,33 @@ class Tensor(SimpleMathTrait):
|
|
3071
3475
|
"""
|
3072
3476
|
return 0.5 * self * (1 + (math.sqrt(2 / math.pi) * (self + 0.044715 * self ** 3)).tanh())
|
3073
3477
|
|
3074
|
-
def quick_gelu(self):
|
3478
|
+
def quick_gelu(self) -> Tensor:
|
3075
3479
|
"""
|
3076
3480
|
Applies the Sigmoid GELU approximation element-wise.
|
3077
3481
|
|
3078
|
-
- Described: https://paperswithcode.com/method/gelu
|
3079
|
-
|
3080
3482
|
```python exec="true" source="above" session="tensor" result="python"
|
3081
3483
|
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).quick_gelu().numpy())
|
3082
3484
|
```
|
3083
3485
|
"""
|
3084
3486
|
return self * (self * 1.702).sigmoid()
|
3085
3487
|
|
3086
|
-
def
|
3488
|
+
def leaky_relu(self, neg_slope=0.01) -> Tensor:
|
3087
3489
|
"""
|
3088
3490
|
Applies the Leaky ReLU function element-wise.
|
3089
3491
|
|
3090
|
-
- Described: https://paperswithcode.com/method/leaky-relu
|
3091
|
-
|
3092
3492
|
```python exec="true" source="above" session="tensor" result="python"
|
3093
|
-
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).
|
3493
|
+
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).leaky_relu().numpy())
|
3094
3494
|
```
|
3095
3495
|
```python exec="true" source="above" session="tensor" result="python"
|
3096
|
-
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).
|
3496
|
+
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).leaky_relu(neg_slope=0.42).numpy())
|
3097
3497
|
```
|
3098
3498
|
"""
|
3099
|
-
return self.
|
3499
|
+
return (self<0).where(neg_slope*self, self)
|
3100
3500
|
|
3101
|
-
def mish(self):
|
3501
|
+
def mish(self) -> Tensor:
|
3102
3502
|
"""
|
3103
3503
|
Applies the Mish function element-wise.
|
3104
3504
|
|
3105
|
-
- Described: https://paperswithcode.com/method/mish
|
3106
3505
|
- Paper: https://arxiv.org/abs/1908.08681v3
|
3107
3506
|
|
3108
3507
|
```python exec="true" source="above" session="tensor" result="python"
|
@@ -3111,24 +3510,21 @@ class Tensor(SimpleMathTrait):
|
|
3111
3510
|
"""
|
3112
3511
|
return self * self.softplus().tanh()
|
3113
3512
|
|
3114
|
-
def softplus(self, beta=1):
|
3513
|
+
def softplus(self, beta=1.0, threshold=20.0) -> Tensor:
|
3115
3514
|
"""
|
3116
3515
|
Applies the Softplus function element-wise.
|
3117
|
-
|
3118
|
-
- Described: https://paperswithcode.com/method/softplus
|
3516
|
+
For numerical stability, the implementation folds into identity function when `self * beta > threshold`.
|
3119
3517
|
|
3120
3518
|
```python exec="true" source="above" session="tensor" result="python"
|
3121
3519
|
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).softplus().numpy())
|
3122
3520
|
```
|
3123
3521
|
"""
|
3124
|
-
return (1/beta) * (1 + (self*beta).exp()).log()
|
3522
|
+
return (self * beta > threshold).where(self, (1/beta) * (1 + (self*beta).exp()).log())
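In formula form, with the `beta` and `threshold` parameters above,

$$\operatorname{softplus}_\beta(x) = \frac{1}{\beta}\log\left(1+e^{\beta x}\right) = x + \frac{1}{\beta}\log\left(1+e^{-\beta x}\right),$$

so for $\beta x$ above the default threshold of 20 the dropped correction term is at most $\log(1+e^{-20})/\beta \approx 2\times 10^{-9}/\beta$, while folding to the identity avoids evaluating `exp` on large arguments where it would overflow.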
|
3125
3523
|
|
3126
|
-
def softsign(self):
|
3524
|
+
def softsign(self) -> Tensor:
|
3127
3525
|
"""
|
3128
3526
|
Applies the Softsign function element-wise.
|
3129
3527
|
|
3130
|
-
- Described: https://paperswithcode.com/method/softsign
|
3131
|
-
|
3132
3528
|
```python exec="true" source="above" session="tensor" result="python"
|
3133
3529
|
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).softsign().numpy())
|
3134
3530
|
```
|
@@ -3144,9 +3540,10 @@ class Tensor(SimpleMathTrait):
|
|
3144
3540
|
# for each dimension, check either dim is 1, or it does not change
|
3145
3541
|
if not all(resolve(s == ns) or resolve(s == 1) for s,ns in zip(shape, new_shape)):
|
3146
3542
|
raise ValueError(f"cannot broadcast {self.shape} to {new_shape=}")
|
3147
|
-
|
3543
|
+
# NOTE: this cast is no-op in forward and uses sum_acc_dtype in the backward sum
|
3544
|
+
return self.reshape(shape).cast(sum_acc_dtype(self.dtype))._apply_uop(UOp.expand, arg=new_shape).cast(self.dtype)
|
3148
3545
|
|
3149
|
-
def _broadcasted(self, y:
|
3546
|
+
def _broadcasted(self, y:Tensor|ConstType|UOp, reverse:bool=False, match_dtype:bool=True) -> tuple[Tensor, Tensor]:
|
3150
3547
|
x: Tensor = self
|
3151
3548
|
if not isinstance(y, Tensor):
|
3152
3549
|
# make y a Tensor
|
@@ -3165,27 +3562,7 @@ class Tensor(SimpleMathTrait):
|
|
3165
3562
|
# broadcast
|
3166
3563
|
return x._broadcast_to(out_shape:=_broadcast_shape(x.shape, y.shape)), y._broadcast_to(out_shape)
|
3167
3564
|
|
3168
|
-
def
|
3169
|
-
"""
|
3170
|
-
Adds `self` and `x`.
|
3171
|
-
Equivalent to `self + x`.
|
3172
|
-
Supports broadcasting to a common shape, type promotion, and integer, float, boolean inputs.
|
3173
|
-
|
3174
|
-
```python exec="true" source="above" session="tensor" result="python"
|
3175
|
-
Tensor.manual_seed(42)
|
3176
|
-
t = Tensor.randn(4)
|
3177
|
-
print(t.numpy())
|
3178
|
-
```
|
3179
|
-
```python exec="true" source="above" session="tensor" result="python"
|
3180
|
-
print(t.add(20).numpy())
|
3181
|
-
```
|
3182
|
-
```python exec="true" source="above" session="tensor" result="python"
|
3183
|
-
print(t.add(Tensor([[2.0], [3.5]])).numpy())
|
3184
|
-
```
|
3185
|
-
"""
|
3186
|
-
return self._apply_broadcasted_uop(UOp.add, x, reverse)
|
3187
|
-
|
3188
|
-
def sub(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
|
3565
|
+
def sub(self, x:Tensor|ConstType, reverse=False) -> Tensor:
|
3189
3566
|
"""
|
3190
3567
|
Subtracts `x` from `self`.
|
3191
3568
|
Equivalent to `self - x`.
|
@@ -3206,40 +3583,7 @@ class Tensor(SimpleMathTrait):
|
|
3206
3583
|
a, b = self._broadcasted(x, reverse)
|
3207
3584
|
return a + (-b)
|
3208
3585
|
|
3209
|
-
def
|
3210
|
-
"""
|
3211
|
-
Multiplies `self` and `x`.
|
3212
|
-
Equivalent to `self * x`.
|
3213
|
-
Supports broadcasting to a common shape, type promotion, and integer, float, boolean inputs.
|
3214
|
-
|
3215
|
-
```python exec="true" source="above" session="tensor" result="python"
|
3216
|
-
Tensor.manual_seed(42)
|
3217
|
-
t = Tensor.randn(4)
|
3218
|
-
print(t.numpy())
|
3219
|
-
```
|
3220
|
-
```python exec="true" source="above" session="tensor" result="python"
|
3221
|
-
print(t.mul(3).numpy())
|
3222
|
-
```
|
3223
|
-
```python exec="true" source="above" session="tensor" result="python"
|
3224
|
-
print(t.mul(Tensor([[-1.0], [2.0]])).numpy())
|
3225
|
-
```
|
3226
|
-
"""
|
3227
|
-
return self._apply_broadcasted_uop(UOp.mul, x, reverse)
|
3228
|
-
|
3229
|
-
def idiv(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
|
3230
|
-
"""
|
3231
|
-
Divides `self` by `x`.
|
3232
|
-
Equivalent to `self // x`.
|
3233
|
-
Supports broadcasting to a common shape, type promotion, and integer inputs.
|
3234
|
-
`idiv` performs integer division (truncate towards zero).
|
3235
|
-
|
3236
|
-
```python exec="true" source="above" session="tensor" result="python"
|
3237
|
-
print(Tensor([-4, 7, 5, 4, -7, 8]).idiv(Tensor([2, -3, 8, -2, 3, 5])).numpy())
|
3238
|
-
```
|
3239
|
-
"""
|
3240
|
-
return self._apply_broadcasted_uop(UOp.idiv, x, reverse)
|
3241
|
-
|
3242
|
-
def div(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
|
3586
|
+
def div(self, x:Tensor|ConstType, reverse=False, rounding_mode:Literal["trunc", "floor"]|None=None) -> Tensor:
|
3243
3587
|
"""
|
3244
3588
|
Divides `self` by `x`.
|
3245
3589
|
Equivalent to `self / x`.
|
@@ -3259,9 +3603,21 @@ class Tensor(SimpleMathTrait):
|
|
3259
3603
|
```
|
3260
3604
|
"""
|
3261
3605
|
numerator, denominator = self._broadcasted(x, reverse)
|
3262
|
-
|
3263
|
-
|
3264
|
-
|
3606
|
+
d = numerator.cast(least_upper_float(numerator.dtype)) * denominator.cast(least_upper_float(denominator.dtype)).reciprocal()
|
3607
|
+
output_dtype = numerator.dtype if dtypes.is_int(numerator.dtype) else d.dtype
|
3608
|
+
if dtypes.is_int(dt:=least_upper_dtype(numerator.dtype, denominator.dtype)) and rounding_mode is not None:
|
3609
|
+
numerator, denominator = numerator.cast(dt), denominator.cast(dt)
|
3610
|
+
if rounding_mode == "trunc": return numerator.idiv(denominator)
|
3611
|
+
if rounding_mode == "floor":
|
3612
|
+
truncate_div, truncate_mod = numerator.idiv(denominator), numerator._apply_broadcasted_uop(UOp.mod, denominator)
|
3613
|
+
opposite_sign = ((numerator>0)&(denominator<0)) | ((numerator<0)&(denominator>0))
|
3614
|
+
return (opposite_sign&(truncate_mod!=0)).where(truncate_div-1, truncate_div)
|
3615
|
+
if rounding_mode == "trunc": return d.trunc().cast(output_dtype)
|
3616
|
+
if rounding_mode == "floor": return d.floor().cast(output_dtype)
|
3617
|
+
if rounding_mode is not None: raise RuntimeError(f"{rounding_mode=} is not supported")
|
3618
|
+
return d
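A quick sketch contrasting the three code paths above on integer tensors: `rounding_mode=None` gives true division, `"trunc"` rounds toward zero, and `"floor"` rounds toward negative infinity (example assumes the public `from tinygrad import Tensor` entry point):

```python
from tinygrad import Tensor

a, b = Tensor([-7, 7, -7, 7]), Tensor([2, 2, -2, -2])
print(a.div(b).numpy())                         # true division: [-3.5, 3.5, 3.5, -3.5]
print(a.div(b, rounding_mode="trunc").numpy())  # toward zero:   [-3, 3, 3, -3]
print(a.div(b, rounding_mode="floor").numpy())  # toward -inf:   [-4, 3, 3, -4]
```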
|
3619
|
+
|
3620
|
+
def mod(self, x:Tensor|ConstType, reverse=False) -> Tensor:
|
3265
3621
|
"""
|
3266
3622
|
Mod `self` by `x`.
|
3267
3623
|
Equivalent to `self % x`.
|
@@ -3272,57 +3628,11 @@ class Tensor(SimpleMathTrait):
|
|
3272
3628
|
```
|
3273
3629
|
"""
|
3274
3630
|
a, b = self._broadcasted(x, reverse)
|
3275
|
-
return
|
3276
|
-
|
3277
|
-
def xor(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
|
3278
|
-
"""
|
3279
|
-
Computes bitwise xor of `self` and `x`.
|
3280
|
-
Equivalent to `self ^ x`.
|
3281
|
-
Supports broadcasting to a common shape, type promotion, and integer, boolean inputs.
|
3282
|
-
|
3283
|
-
```python exec="true" source="above" session="tensor" result="python"
|
3284
|
-
print(Tensor([-1, -2, 3]).xor(Tensor([1, 0, 3])).numpy())
|
3285
|
-
```
|
3286
|
-
```python exec="true" source="above" session="tensor" result="python"
|
3287
|
-
print(Tensor([True, True, False, False]).xor(Tensor([True, False, True, False])).numpy())
|
3288
|
-
```
|
3289
|
-
"""
|
3290
|
-
if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
|
3291
|
-
return self._apply_broadcasted_uop(UOp.xor, x, reverse)
|
3292
|
-
|
3293
|
-
def bitwise_and(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
|
3294
|
-
"""
|
3295
|
-
Compute the bit-wise AND of `self` and `x`.
|
3296
|
-
Equivalent to `self & x`.
|
3297
|
-
Supports broadcasting to a common shape, type promotion, and integer, boolean inputs.
|
3298
|
-
```python exec="true" source="above" session="tensor" result="python"
|
3299
|
-
print(Tensor([2, 5, 255]).bitwise_and(Tensor([3, 14, 16])).numpy())
|
3300
|
-
```
|
3301
|
-
```python exec="true" source="above" session="tensor" result="python"
|
3302
|
-
print(Tensor([True, True, False, False]).bitwise_and(Tensor([True, False, True, False])).numpy())
|
3303
|
-
```
|
3304
|
-
"""
|
3305
|
-
if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
|
3306
|
-
return self._apply_broadcasted_uop(UOp.bitwise_and, x, reverse)
|
3307
|
-
|
3308
|
-
def bitwise_or(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
|
3309
|
-
"""
|
3310
|
-
Compute the bit-wise OR of `self` and `x`.
|
3311
|
-
Equivalent to `self | x`.
|
3312
|
-
Supports broadcasting to a common shape, type promotion, and integer, boolean inputs.
|
3313
|
-
```python exec="true" source="above" session="tensor" result="python"
|
3314
|
-
print(Tensor([2, 5, 255]).bitwise_or(Tensor([4, 4, 4])).numpy())
|
3315
|
-
```
|
3316
|
-
```python exec="true" source="above" session="tensor" result="python"
|
3317
|
-
print(Tensor([True, True, False, False]).bitwise_or(Tensor([True, False, True, False])).numpy())
|
3318
|
-
```
|
3319
|
-
"""
|
3320
|
-
if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
|
3321
|
-
return self._apply_broadcasted_uop(UOp.bitwise_or, x, reverse)
|
3631
|
+
return a - a.div(b, rounding_mode="floor") * b
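Because `mod` is now defined through floor division, the result takes the sign of the divisor, matching Python's `%` operator. A small illustrative check:

```python
from tinygrad import Tensor

print(Tensor([-7, 7, -7, 7]).mod(Tensor([3, 3, -3, -3])).numpy())  # expected [2, 1, -1, -2]
```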
|
3322
3632
|
|
3323
3633
|
def bitwise_not(self) -> Tensor:
|
3324
3634
|
"""
|
3325
|
-
|
3635
|
+
Computes the bitwise NOT of `self`.
|
3326
3636
|
Equivalent to `~self`.
|
3327
3637
|
```python exec="true" source="above" session="tensor" result="python"
|
3328
3638
|
print(Tensor([0, 2, 5, 255], dtype="int8").bitwise_not().numpy())
|
@@ -3334,7 +3644,7 @@ class Tensor(SimpleMathTrait):
|
|
3334
3644
|
if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
|
3335
3645
|
return self.logical_not() if self.dtype == dtypes.bool else self ^ -1
|
3336
3646
|
|
3337
|
-
def lshift(self, x:int):
|
3647
|
+
def lshift(self, x:int, reverse=False) -> Tensor:
|
3338
3648
|
"""
|
3339
3649
|
Computes left arithmetic shift of `self` by `x` bits. `self` must have unsigned dtype.
|
3340
3650
|
Equivalent to `self << x`.
|
@@ -3343,10 +3653,10 @@ class Tensor(SimpleMathTrait):
|
|
3343
3653
|
print(Tensor([1, 3, 31], dtype=dtypes.uint8).lshift(2).numpy())
|
3344
3654
|
```
|
3345
3655
|
"""
|
3346
|
-
assert dtypes.is_unsigned(self.dtype) and isinstance(x, int) and x >= 0, f"not supported {self.dtype=} {x=}"
|
3347
|
-
return self.mul(2 ** x)
|
3656
|
+
assert dtypes.is_unsigned(self.dtype) and isinstance(x, int) and x >= 0 and not reverse, f"not supported {self.dtype=} {x=}"
|
3657
|
+
return self.mul(2 ** x, reverse)
|
3348
3658
|
|
3349
|
-
def rshift(self, x:int):
|
3659
|
+
def rshift(self, x:int, reverse=False) -> Tensor:
|
3350
3660
|
"""
|
3351
3661
|
Computes right arithmetic shift of `self` by `x` bits. `self` must have unsigned dtype.
|
3352
3662
|
Equivalent to `self >> x`.
|
@@ -3355,10 +3665,10 @@ class Tensor(SimpleMathTrait):
|
|
3355
3665
|
print(Tensor([4, 13, 125], dtype=dtypes.uint8).rshift(2).numpy())
|
3356
3666
|
```
|
3357
3667
|
"""
|
3358
|
-
assert dtypes.is_unsigned(self.dtype) and isinstance(x, int) and x >= 0, f"not supported {self.dtype=} {x=}"
|
3359
|
-
return self.idiv(2 ** x)
|
3668
|
+
assert dtypes.is_unsigned(self.dtype) and isinstance(x, int) and x >= 0 and not reverse, f"not supported {self.dtype=} {x=}"
|
3669
|
+
return self.idiv(2 ** x, reverse)
|
3360
3670
|
|
3361
|
-
def pow(self, x:
|
3671
|
+
def pow(self, x:Tensor|ConstType, reverse=False) -> Tensor:
|
3362
3672
|
"""
|
3363
3673
|
Computes power of `self` with `x`.
|
3364
3674
|
Equivalent to `self ** x`.
|
@@ -3375,13 +3685,13 @@ class Tensor(SimpleMathTrait):
|
|
3375
3685
|
"""
|
3376
3686
|
base, exponent = self._broadcasted(x, reverse=reverse)
|
3377
3687
|
# TODO: int pow
|
3378
|
-
if not base.is_floating_point(): raise RuntimeError("base needs to be float")
|
3688
|
+
if not base.is_floating_point() and not (isinstance(x, int) and x >= 0): raise RuntimeError("base needs to be float")
|
3379
3689
|
|
3380
|
-
# NOTE: pow(int, float) -> int
|
3381
3690
|
ret = base._apply_uop(UOp.pow, exponent)
|
3382
|
-
|
3691
|
+
# NOTE: pow(int, float) -> int
|
3692
|
+
return ret.round().cast(self.dtype) if not reverse and not dtypes.is_float(self.dtype) and dtypes.is_float(exponent.dtype) else ret
|
3383
3693
|
|
3384
|
-
def maximum(self, x:
|
3694
|
+
def maximum(self, x:Tensor|ConstType) -> Tensor:
|
3385
3695
|
"""
|
3386
3696
|
Computes element-wise maximum of `self` and `x`.
|
3387
3697
|
|
@@ -3394,7 +3704,7 @@ class Tensor(SimpleMathTrait):
|
|
3394
3704
|
"""
|
3395
3705
|
return self._apply_broadcasted_uop(UOp.maximum, x)
|
3396
3706
|
|
3397
|
-
def minimum(self, x:
|
3707
|
+
def minimum(self, x:Tensor|ConstType) -> Tensor:
|
3398
3708
|
"""
|
3399
3709
|
Computes element-wise minimum of `self` and `x`.
|
3400
3710
|
|
@@ -3408,9 +3718,9 @@ class Tensor(SimpleMathTrait):
|
|
3408
3718
|
t, x = self._broadcasted(x)
|
3409
3719
|
return t._inverse().maximum(x._inverse())._inverse()
|
3410
3720
|
|
3411
|
-
def where(self:Tensor, x:
|
3721
|
+
def where(self:Tensor, x:Tensor|ConstType|sint, y:Tensor|ConstType|sint) -> Tensor:
|
3412
3722
|
"""
|
3413
|
-
|
3723
|
+
Returns a tensor of elements selected from either `x` or `y`, depending on `self`.
|
3414
3724
|
`output_i = x_i if self_i else y_i`.
|
3415
3725
|
|
3416
3726
|
```python exec="true" source="above" session="tensor" result="python"
|
@@ -3432,14 +3742,22 @@ class Tensor(SimpleMathTrait):
|
|
3432
3742
|
cond, y = cond._broadcasted(y, match_dtype=False)
|
3433
3743
|
return cond.cast(dtypes.bool)._apply_uop(UOp.where, *x._broadcasted(y))
|
3434
3744
|
|
3435
|
-
def
|
3745
|
+
def copysign(self, other) -> Tensor:
|
3746
|
+
"""
|
3747
|
+
Returns a tensor with the magnitude of `self` and the sign of `other`, element-wise.
|
3748
|
+
"""
|
3749
|
+
# NOTE: torch always returns float; we return based on the broadcasting rule.
|
3750
|
+
other = self._broadcasted(other)[1]
|
3751
|
+
# TODO: remove other*0?
|
3752
|
+
return (other < 0).where(-self.abs(), self.abs()) + other*0
|
3436
3753
|
|
3437
3754
|
# ***** op wrappers *****
|
3438
3755
|
|
3439
3756
|
def __invert__(self) -> Tensor: return self.bitwise_not()
|
3440
3757
|
|
3441
|
-
|
3442
|
-
def
|
3758
|
+
# TODO: combine with UOps __floordiv__
|
3759
|
+
def __floordiv__(self, x): return self.div(x, rounding_mode="floor")
|
3760
|
+
def __rfloordiv__(self, x): return self.div(x, rounding_mode="floor", reverse=True)
|
3443
3761
|
|
3444
3762
|
def __pow__(self, x) -> Tensor: return self.pow(x)
|
3445
3763
|
def __matmul__(self, x) -> Tensor: return self.matmul(x)
|
@@ -3452,11 +3770,11 @@ class Tensor(SimpleMathTrait):
|
|
3452
3770
|
def __imul__(self, x) -> Tensor: return self.assign(self.mul(x))
|
3453
3771
|
def __ipow__(self, x) -> Tensor: return self.assign(self.pow(x))
|
3454
3772
|
def __itruediv__(self, x) -> Tensor: return self.assign(self.div(x))
|
3455
|
-
def __ifloordiv__(self, x) -> Tensor: return self.assign(self.
|
3773
|
+
def __ifloordiv__(self, x) -> Tensor: return self.assign(self.__floordiv__(x))
|
3456
3774
|
def __imatmul__(self, x) -> Tensor: return self.assign(self.matmul(x))
|
3457
3775
|
def __iand__(self, x) -> Tensor: return self.assign(self.bitwise_and(x))
|
3458
3776
|
def __ior__(self, x) -> Tensor: return self.assign(self.bitwise_or(x))
|
3459
|
-
def __ixor__(self, x) -> Tensor: return self.assign(self.
|
3777
|
+
def __ixor__(self, x) -> Tensor: return self.assign(self.bitwise_xor(x))
|
3460
3778
|
def __ilshift__(self, x) -> Tensor: return self.assign(self.lshift(x))
|
3461
3779
|
def __irshift__(self, x) -> Tensor: return self.assign(self.rshift(x))
|
3462
3780
|
|
@@ -3468,7 +3786,7 @@ class Tensor(SimpleMathTrait):
|
|
3468
3786
|
|
3469
3787
|
# ***** functional nn ops *****
|
3470
3788
|
|
3471
|
-
def linear(self, weight:Tensor, bias:
|
3789
|
+
def linear(self, weight:Tensor, bias:Tensor|None=None, dtype:DTypeLike|None=None) -> Tensor:
|
3472
3790
|
"""
|
3473
3791
|
Applies a linear transformation to `self` using `weight` and `bias`.
|
3474
3792
|
|
@@ -3481,10 +3799,11 @@ class Tensor(SimpleMathTrait):
|
|
3481
3799
|
print(t.linear(weight, bias).numpy())
|
3482
3800
|
```
|
3483
3801
|
"""
|
3802
|
+
if dtype is not None: return self.cast(dtype).linear(weight.cast(dtype), bias.cast(dtype) if bias is not None else bias)
|
3484
3803
|
x = self.mul(weight) if len(weight.shape) == 1 else self.dot(weight)
|
3485
3804
|
return x.add(bias) if bias is not None else x
|
3486
3805
|
|
3487
|
-
def sequential(self, ll:list[Callable[[Tensor], Tensor]]):
|
3806
|
+
def sequential(self, ll:list[Callable[[Tensor], Tensor]]) -> Tensor:
|
3488
3807
|
"""
|
3489
3808
|
Applies a sequence of functions to `self` chaining the output of each function to the input of the next.
|
3490
3809
|
|
@@ -3495,11 +3814,10 @@ class Tensor(SimpleMathTrait):
|
|
3495
3814
|
"""
|
3496
3815
|
return functools.reduce(lambda x,f: f(x), ll, self)
|
3497
3816
|
|
3498
|
-
def layernorm(self, axis:
|
3817
|
+
def layernorm(self, axis:int|tuple[int,...]=-1, eps:float=1e-5) -> Tensor:
|
3499
3818
|
"""
|
3500
3819
|
Applies Layer Normalization over a mini-batch of inputs.
|
3501
3820
|
|
3502
|
-
- Described: https://paperswithcode.com/method/layer-normalization
|
3503
3821
|
- Paper: https://arxiv.org/abs/1607.06450v1
|
3504
3822
|
|
3505
3823
|
```python exec="true" source="above" session="tensor" result="python"
|
@@ -3514,11 +3832,10 @@ class Tensor(SimpleMathTrait):
|
|
3514
3832
|
y = (self - self.mean(axis, keepdim=True))
|
3515
3833
|
return y.mul((y*y).mean(axis, keepdim=True).add(eps).rsqrt())
|
3516
3834
|
|
3517
|
-
def batchnorm(self, weight:
|
3835
|
+
def batchnorm(self, weight:Tensor|None, bias:Tensor|None, mean:Tensor, invstd:Tensor, axis:int|tuple[int, ...]=1) -> Tensor:
|
3518
3836
|
"""
|
3519
3837
|
Applies Batch Normalization over a mini-batch of inputs.
|
3520
3838
|
|
3521
|
-
- Described: https://paperswithcode.com/method/batch-normalization
|
3522
3839
|
- Paper: https://arxiv.org/abs/1502.03167
|
3523
3840
|
|
3524
3841
|
```python exec="true" source="above" session="tensor" result="python"
|
@@ -3543,7 +3860,6 @@ class Tensor(SimpleMathTrait):
|
|
3543
3860
|
|
3544
3861
|
NOTE: dropout is only applied when `Tensor.training` is `True`.
|
3545
3862
|
|
3546
|
-
- Described: https://paperswithcode.com/method/dropout
|
3547
3863
|
- Paper: https://jmlr.org/papers/v15/srivastava14a.html
|
3548
3864
|
|
3549
3865
|
```python exec="true" source="above" session="tensor" result="python"
|
@@ -3553,11 +3869,13 @@ class Tensor(SimpleMathTrait):
|
|
3553
3869
|
print(t.dropout().numpy())
|
3554
3870
|
```
|
3555
3871
|
"""
|
3872
|
+
if not 0 <= p <= 1: raise ValueError(f"{p=} is out of range [0, 1]")
|
3556
3873
|
if not Tensor.training or p == 0: return self
|
3874
|
+
if p == 1: return self.zeros_like()
|
3557
3875
|
return (Tensor.rand_like(self, requires_grad=False, dtype=dtypes.default_float, contiguous=False) >= p).contiguous().where(self, 0) / (1.0 - p)
|
3558
3876
|
|
3559
3877
|
# helper function commonly used for indexing
|
3560
|
-
def _one_hot_along_dim(self:Tensor, num_classes:sint, dim:int=-1):
|
3878
|
+
def _one_hot_along_dim(self:Tensor, num_classes:sint, dim:int=-1) -> Tensor:
|
3561
3879
|
if not dtypes.is_int(self.dtype): raise RuntimeError(f"_one_hot_along_dim expects int index tensor, getting {self.dtype}")
|
3562
3880
|
offset = self.ndim - self._resolve_dim(dim) - 1
|
3563
3881
|
return self == Tensor.arange(num_classes, device=self.device, requires_grad=False).reshape((num_classes,) + (1,) * offset)
|
@@ -3577,12 +3895,12 @@ class Tensor(SimpleMathTrait):
|
|
3577
3895
|
if num_classes == -1: num_classes = (self.max()+1).item()
|
3578
3896
|
return self[..., None]._one_hot_along_dim(num_classes).where(1, 0)
|
3579
3897
|
|
3580
|
-
def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Tensor|None=None, dropout_p:float=0.0,
|
3898
|
+
def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Tensor|None=None, dropout_p:float=0.0,
|
3899
|
+
is_causal:bool=False, enable_gqa:bool=False) -> Tensor:
|
3581
3900
|
"""
|
3582
3901
|
Computes scaled dot-product attention.
|
3583
3902
|
`self` is the query tensor, `key` is the key tensor, and `value` is the value tensor.
|
3584
3903
|
|
3585
|
-
- Described: https://paperswithcode.com/method/scaled
|
3586
3904
|
- Paper: https://arxiv.org/abs/1706.03762v7
|
3587
3905
|
|
3588
3906
|
```python exec="true" source="above" session="tensor" result="python"
|
@@ -3594,7 +3912,11 @@ class Tensor(SimpleMathTrait):
|
|
3594
3912
|
"""
|
3595
3913
|
# NOTE: it also works when `key` and `value` have symbolic shape.
|
3596
3914
|
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
|
3597
|
-
|
3915
|
+
# GQA: https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
3916
|
+
if enable_gqa:
|
3917
|
+
key = key.repeat_interleave(self.shape[-3] // key.shape[-3], dim=-3)
|
3918
|
+
value = value.repeat_interleave(self.shape[-3] // value.shape[-3], dim=-3)
|
3919
|
+
qk = self.matmul(key.transpose(-2,-1), dtype=least_upper_dtype(self.dtype, key.dtype, dtypes.float32)) / math.sqrt(self.shape[-1])
|
3598
3920
|
# handle attention mask
|
3599
3921
|
if is_causal:
|
3600
3922
|
if attn_mask is not None: raise RuntimeError("cannot set attn_mask when is_causal=True")
|
@@ -3602,7 +3924,7 @@ class Tensor(SimpleMathTrait):
|
|
3602
3924
|
if attn_mask is not None:
|
3603
3925
|
if attn_mask.dtype == dtypes.bool: attn_mask = attn_mask.where(0, -float("inf"))
|
3604
3926
|
qk = qk + attn_mask
|
3605
|
-
return qk.
|
3927
|
+
return qk.cast(self.dtype).softmax(-1).dropout(dropout_p) @ value
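A sketch of the grouped-query attention path added above: many query heads share a smaller set of key/value heads, which `enable_gqa=True` expands with `repeat_interleave`. Shapes follow the torch convention referenced in the comment; the example is illustrative only:

```python
from tinygrad import Tensor

q = Tensor.randn(1, 8, 16, 64)  # (batch, q_heads, seq, head_dim)
k = Tensor.randn(1, 2, 16, 64)  # (batch, kv_heads, seq, head_dim); 8 query heads share 2 kv heads
v = Tensor.randn(1, 2, 16, 64)
out = q.scaled_dot_product_attention(k, v, is_causal=True, enable_gqa=True)
print(out.shape)  # (1, 8, 16, 64)
```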
|
3606
3928
|
|
3607
3929
|
def _do_reduction(self, reduction:ReductionStr="mean") -> Tensor:
|
3608
3930
|
if reduction not in get_args(ReductionStr): raise ValueError(f"{reduction=} must be one of {get_args(ReductionStr)}")
|
@@ -3623,7 +3945,7 @@ class Tensor(SimpleMathTrait):
|
|
3623
3945
|
"""
|
3624
3946
|
return (-Y*self.log() - (1-Y)*(1-self).log())._do_reduction(reduction)
|
3625
3947
|
|
3626
|
-
def binary_crossentropy_logits(self, Y:Tensor, reduction:ReductionStr="mean") -> Tensor:
|
3948
|
+
def binary_crossentropy_logits(self, Y:Tensor, reduction:ReductionStr="mean", pos_weight:Tensor|None=None) -> Tensor:
|
3627
3949
|
"""
|
3628
3950
|
Computes the binary cross-entropy loss between `self` and `Y` where `self` is logits.
|
3629
3951
|
|
@@ -3635,7 +3957,8 @@ class Tensor(SimpleMathTrait):
|
|
3635
3957
|
print(t.binary_crossentropy_logits(Y).item())
|
3636
3958
|
```
|
3637
3959
|
"""
|
3638
|
-
|
3960
|
+
log_p, log_1_minus_p = self.logsigmoid(), (-self).logsigmoid()
|
3961
|
+
return (-((1 if pos_weight is None else pos_weight) * Y * log_p + (1-Y) * log_1_minus_p))._do_reduction(reduction)
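Written out, the per-element loss computed above (with $p_c$ = `pos_weight`, defaulting to 1) is

$$\ell(x, y) = -\Big[p_c\, y \log \sigma(x) + (1-y)\log \sigma(-x)\Big],$$

using $\log(1-\sigma(x)) = \log\sigma(-x)$ so that both terms go through the numerically stable `logsigmoid`.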
|
3639
3962
|
|
3640
3963
|
def sparse_categorical_crossentropy(self, Y:Tensor, ignore_index:int=-1, label_smoothing=0.0, reduction:ReductionStr="mean") -> Tensor:
|
3641
3964
|
"""
|
@@ -3653,10 +3976,10 @@ class Tensor(SimpleMathTrait):
|
|
3653
3976
|
```
|
3654
3977
|
"""
|
3655
3978
|
assert 0.0 <= label_smoothing <= 1.0, "label_smoothing must be in [0.0, 1.0]"
|
3656
|
-
assert reduction in (
|
3657
|
-
log_probs
|
3658
|
-
|
3659
|
-
y = (
|
3979
|
+
assert reduction in get_args(ReductionStr), f"reduction must be one of {get_args(ReductionStr)}"
|
3980
|
+
log_probs = self.log_softmax()
|
3981
|
+
loss_mask = (Y != ignore_index) if ignore_index != -1 else Y.ones_like(dtype=dtypes.bool)
|
3982
|
+
y = Y.to(self.device).unsqueeze(-1)._one_hot_along_dim(self.shape[-1], dim=-1) * loss_mask.unsqueeze(-1)
|
3660
3983
|
smoothing = label_smoothing * (log_probs.mean(-1) * loss_mask)
|
3661
3984
|
unreduced = ((1 - label_smoothing) * (log_probs * y).sum(-1) + smoothing)
|
3662
3985
|
# NOTE: because of ignore_index, we can't use Tensor.mean (so can't use `_do_reduction` here)
|
@@ -3664,7 +3987,7 @@ class Tensor(SimpleMathTrait):
|
|
3664
3987
|
|
3665
3988
|
def cross_entropy(self, Y:Tensor, reduction:ReductionStr="mean", label_smoothing:float=0.0) -> Tensor:
|
3666
3989
|
"""
|
3667
|
-
|
3990
|
+
Computes the cross entropy loss between input logits and target.
|
3668
3991
|
|
3669
3992
|
NOTE: `self` are logits and `Y` are the target labels or class probabilities.
|
3670
3993
|
|
@@ -3682,14 +4005,16 @@ class Tensor(SimpleMathTrait):
|
|
3682
4005
|
```
|
3683
4006
|
"""
|
3684
4007
|
assert 0.0 <= label_smoothing <= 1.0, "label_smoothing must be in [0.0, 1.0]"
|
3685
|
-
|
3686
|
-
|
3687
|
-
|
3688
|
-
|
4008
|
+
classes_dim = 0 if self.ndim == 1 else 1
|
4009
|
+
if self.shape != Y.shape:
|
4010
|
+
if self.max(classes_dim).shape != Y.shape: raise RuntimeError(f"shape mismatch: {self.shape=}, {Y.shape=}")
|
4011
|
+
Y = Y.unsqueeze(classes_dim)._one_hot_along_dim(num_classes=self.shape[classes_dim], dim=classes_dim)
|
4012
|
+
Y = (1 - label_smoothing)*Y + label_smoothing / int(Y.shape[classes_dim])
|
4013
|
+
return -self.log_softmax(classes_dim).mul(Y).sum(classes_dim)._do_reduction(reduction)
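With $C$ classes, one-hot (or user-provided) targets $y$ and smoothing $\varepsilon$, the targets become $\tilde{y}_c = (1-\varepsilon)\,y_c + \varepsilon/C$ and the per-sample loss reduced above is

$$\mathcal{L} = -\sum_{c=1}^{C} \tilde{y}_c \,\log\operatorname{softmax}(x)_c .$$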
|
3689
4014
|
|
3690
|
-
def nll_loss(self, Y:Tensor, weight:
|
4015
|
+
def nll_loss(self, Y:Tensor, weight:Tensor|None=None, ignore_index:int|None=None, reduction:ReductionStr="mean") -> Tensor:
|
3691
4016
|
"""
|
3692
|
-
|
4017
|
+
Computes the negative log likelihood loss between log-probabilities and target labels.
|
3693
4018
|
|
3694
4019
|
NOTE: `self` is log-probabilities and `Y` is the Y labels or class probabilities.
|
3695
4020
|
|
@@ -3711,6 +4036,87 @@ class Tensor(SimpleMathTrait):
|
|
3711
4036
|
nll = -self.gather(1, Y.unsqueeze(1)).squeeze(1) * masked_weight
|
3712
4037
|
return nll.sum() / masked_weight.sum() if reduction == "mean" else nll._do_reduction(reduction)
|
3713
4038
|
|
4039
|
+
def newton_schulz(self, steps:int, params:tuple[int, ...], eps:float=1.0e-7) -> Tensor:
|
4040
|
+
"""
|
4041
|
+
Performs the Newton-Schulz iteration for odd polynomials. The degree of the odd polynomial depends on the number of `params`.
|
4042
|
+
|
4043
|
+
```python exec="true" source="above" session="tensor" result="python"
|
4044
|
+
t = Tensor.randn(4, 4)
|
4045
|
+
print(t.newton_schulz(steps=5, params=(2,-1.5,0.5)).numpy())
|
4046
|
+
```
|
4047
|
+
"""
|
4048
|
+
assert self.ndim > 1, "NS only works for two or more dims"
|
4049
|
+
G = self / (self.square().sum(axis=(-2, -1), keepdim=True).sqrt() + eps)
|
4050
|
+
G = G.transpose(-2, -1) if self.shape[-2] > self.shape[-1] else G
|
4051
|
+
for _ in range(steps): G = sum(p * functools.reduce(lambda x, y: (y @ y.transpose(-2, -1)) @ x, [G]*i, G) for i,p in enumerate(params))
|
4052
|
+
return G.transpose(-2, -1) if self.shape[-2] > self.shape[-1] else G
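A usage sketch: with the classical Newton-Schulz coefficients $(1.5, -0.5)$, i.e. the odd polynomial $p(x)=1.5x-0.5x^3$ applied to the (normalized) singular values, the iteration pushes the singular values toward 1, so the output is approximately orthogonal. The coefficient choice here is an assumption for illustration, not a value taken from the library:

```python
from tinygrad import Tensor

t = Tensor.randn(6, 4)
o = t.newton_schulz(steps=15, params=(1.5, -0.5))  # classic Newton-Schulz polynomial (assumed values)
print((o.transpose(-2, -1) @ o).numpy())           # approximately the 4x4 identity
```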
|
4053
|
+
|
4054
|
+
def qr(self) -> tuple[Tensor, Tensor]:
|
4055
|
+
assert self.ndim > 1, f"expected two or more dimensions, got {self.ndim}"
|
4056
|
+
R = self.clone()
|
4057
|
+
b_shape, m, n = self.shape[0:self.ndim - 2], int(R.shape[-2]), int(R.shape[-1])
|
4058
|
+
Q = Tensor.eye(m, dtype = self.dtype).reshape((1,) * (len(self.shape) - 2) + 2 * (m,)).expand(b_shape + 2 * (m,)).contiguous()
|
4059
|
+
for i in range(int(min(m, n))):
|
4060
|
+
x = R[..., i:m, i]
|
4061
|
+
s = -x[..., 0].sign()
|
4062
|
+
u1 = x[..., 0] - s * x.square().sum(-1).sqrt()
|
4063
|
+
w = x.unsqueeze(-1) / u1.reshape(b_shape + 2 * (1,))
|
4064
|
+
w[..., 0, 0] = 1
|
4065
|
+
tau = (-s * u1 / x.square().sum(-1).sqrt()).reshape(b_shape + 2 * (1,)).expand(w.shape)
|
4066
|
+
R[..., i:m, :] = R[..., i:m, :] - (w * tau) @ (w.transpose(-2, -1) @ R[..., i:m, :])
|
4067
|
+
Q[..., :, i:m] = Q[..., :, i:m] - (Q[..., :, i:m] @ w) @ (tau.transpose(-2, -1) * w.transpose(-2, -1))
|
4068
|
+
return Q,R
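A quick correctness sketch for the Householder QR above: `Q` should be orthogonal and `Q @ R` should reconstruct the input (illustrative check; tolerances depend on dtype):

```python
from tinygrad import Tensor

a = Tensor.randn(5, 3)
q, r = a.qr()
print((q @ r - a).abs().max().numpy())    # ~0
print((q.transpose(-2, -1) @ q).numpy())  # approximately the 5x5 identity
```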
|
4069
|
+
|
4070
|
+
def svd(self, full_matrices = True) -> tuple[Tensor, Tensor, Tensor]:
|
4071
|
+
#partial implementation of https://www.netlib.org/lapack/lawnspdf/lawn169.pdf , pg 26
|
4072
|
+
assert self.ndim > 1, f"expected two or more dimensions, got {self.ndim}"
|
4073
|
+
b_shape, m, n = self.shape[:-2], int(self.shape[-2]), int(self.shape[-1])
|
4074
|
+
#preprocess the matrix
|
4075
|
+
Q, R = (Tensor.qr(self) if m >= n else Tensor.qr(self.transpose(-2, -1)))
|
4076
|
+
num, q_num = int(min(m, n)), int(max(m, n))
|
4077
|
+
U = R.shrink(tuple([(0, self.shape[i]) for i in range(self.ndim - 2)] + [(0, num), (0, num)])).contiguous()
|
4078
|
+
V = Tensor.eye(num, dtype = self.dtype).reshape((1,) * (self.ndim - 2) + (num, num)).expand(b_shape + 2 * (num,)).contiguous()
|
4079
|
+
#prepare round robin pairing
|
4080
|
+
permute, inverse_permute = Tensor.arange(0, num, dtype = dtypes.int), Tensor.zeros(num, dtype = dtypes.int).contiguous()
|
4081
|
+
permute[num//2:num] = permute[num//2:num].flip(0)
|
4082
|
+
inverse_permute[permute] = Tensor.arange(num, dtype = dtypes.int)
|
4083
|
+
def one_round_jacobi(U, V,permute,inverse_permute):
|
4084
|
+
#pair all the columns
|
4085
|
+
V_permuted, runoff_V = (V[..., permute].split(num - 1, -1)) if num % 2 == 1 else (V[..., permute], None)
|
4086
|
+
V_left, V_right = V_permuted.split(num//2, -1)
|
4087
|
+
U_permuted, runoff_U = (U[..., permute].split(num - 1, -1)) if num % 2 == 1 else (U[..., permute], None)
|
4088
|
+
U_left, U_right = U_permuted.split(num//2, -1)
|
4089
|
+
#compute the jacobi rotations for each pairing
|
4090
|
+
gamma = (U_left * U_right).sum(-2).reshape(b_shape + (1, num//2))
|
4091
|
+
alpha, beta = U_permuted.square().sum(-2).unsqueeze(-2).split(num//2, -1)
|
4092
|
+
tau = (beta - alpha) / (2 * gamma)
|
4093
|
+
t = tau.sign() / (tau.abs() + (1 + tau.square()).sqrt())
|
4094
|
+
c = 1 / (1 + t.square()).sqrt()
|
4095
|
+
s = c * t
|
4096
|
+
#apply the rotations
|
4097
|
+
U_left, U_right = c * U_left - s * U_right, s * U_left + c * U_right
|
4098
|
+
U = U_left.cat(U_right.cat(runoff_U, dim = -1) if num % 2 == 1 else U_right, dim = -1)[..., inverse_permute]
|
4099
|
+
V_left, V_right = c * V_left - s * V_right, s * V_left + c * V_right
|
4100
|
+
V = V_left.cat(V_right.cat(runoff_V, dim = -1) if num % 2 == 1 else V_right, dim = -1)[..., inverse_permute]
|
4101
|
+
#prepare the next round robin pairings
|
4102
|
+
if num % 2 == 1: permute = ((permute - 1) % num)
|
4103
|
+
else: permute = permute[0].reshape(1).cat(((permute[1:num] - 2) % (num - 1)) + 1)
|
4104
|
+
inverse_permute = inverse_permute.scatter(0,permute,Tensor.arange(num,dtype=dtypes.int32))
|
4105
|
+
return U, V, permute, inverse_permute
|
4106
|
+
max_iterations, iterations_per_round = 1, int((num) * math.log2(num) * 2 + 2)#sorta heuristic, most use num*log2(num)
|
4107
|
+
for _ in range(max_iterations * iterations_per_round): U, V, permute, inverse_permute = one_round_jacobi(U, V, permute, inverse_permute)
|
4108
|
+
#extract singular values and sort. construct U from Q
|
4109
|
+
S, indices = U.square().sum(-2).sqrt().sort(dim = -1, descending=True)
|
4110
|
+
new_indices = Tensor.arange(num).reshape((1,) * (self.ndim - 1) + (num,)).expand(b_shape + 2 * (num,)).contiguous()
|
4111
|
+
new_indices[..., :num] = indices.reshape(b_shape + (1,) + (num,)).expand(b_shape + 2 * (num,))
|
4112
|
+
U,V = U.gather(-1, new_indices[...,0:num,0:num]) / S.unsqueeze(-2), V.gather(-1, new_indices[..., 0:num, 0:num]).realize()
|
4113
|
+
|
4114
|
+
padded_u = Tensor.eye(q_num, dtype = U.dtype).reshape((1,) * (self.ndim - 2) + 2 * (q_num,)).expand(b_shape + 2 * (q_num,)).contiguous()
|
4115
|
+
padded_u[..., 0:num, 0:num] = U
|
4116
|
+
U = Q @ padded_u
|
4117
|
+
if not full_matrices: U, V = U[..., 0:num], V[..., 0:num]
|
4118
|
+
return (U, S, V.transpose(-2,-1)) if m >= n else (V, S, U.transpose(-2, -1))
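And a reconstruction check for the one-sided Jacobi SVD: with `full_matrices=False` the thin factors satisfy $U\,\mathrm{diag}(S)\,V^{T} \approx A$ (again an illustrative sketch, not a test from the repo):

```python
from tinygrad import Tensor

a = Tensor.randn(5, 3)
u, s, vt = a.svd(full_matrices=False)
print(u.shape, s.shape, vt.shape)                            # (5, 3) (3,) (3, 3)
print((u @ (s.unsqueeze(-2) * vt) - a).abs().max().numpy())  # ~0
```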
|
4119
|
+
|
   # ***** Tensor Properties *****

   @property
@@ -3760,8 +4166,8 @@ class Tensor(SimpleMathTrait):

   def is_floating_point(self) -> bool:
     """
-    Returns `True` if the tensor contains floating point types, i.e. is one of `
-    `
+    Returns `True` if the tensor contains floating point types, i.e. is one of `dtypes.float64`, `dtypes.float32`,
+    `dtypes.float16`, `dtypes.bfloat16`.

     ```python exec="true" source="above" session="tensor" result="python"
     t = Tensor([8, 9], dtype=dtypes.float32)
@@ -3770,9 +4176,9 @@ class Tensor(SimpleMathTrait):
     """
     return dtypes.is_float(self.dtype)

-  def size(self, dim:
+  def size(self, dim:int|None=None) -> sint|tuple[sint, ...]:
     """
-
+    Returns the size of the tensor. If `dim` is specified, return the length along dimension `dim`. Otherwise return the shape of the tensor.

     ```python exec="true" source="above" session="tensor" result="python"
     t = Tensor([[4, 5, 6], [7, 8, 9]])
@@ -3786,7 +4192,7 @@ class Tensor(SimpleMathTrait):

   # ***** cast ops *****

-  def llvm_bf16_cast(self, dtype:DTypeLike):
+  def llvm_bf16_cast(self, dtype:DTypeLike) -> Tensor:
     # hack for devices that don't support bfloat16
     assert self.dtype == dtypes.bfloat16
     return self.to("LLVM").cast(dtype)
@@ -3834,7 +4240,10 @@ class Tensor(SimpleMathTrait):
     if (not isinstance(self.device, str) or not self.device.startswith("DISK")) and ns != os:
       new_uint, old_uint = to_dtype(f"uint{8*ns}"), to_dtype(f"uint{8*os}")
       tmp = self.bitcast(old_uint)
-      if ns > os:
+      if ns > os:
+        tmp = tmp.reshape(self.shape[:-1] + (self.shape[-1]//(rate := ns//os), rate))
+        nones = (None,) * (tmp.ndim - 1)
+        return functools.reduce(Tensor.add, (tmp.shrink(nones + ((i, i+1),)).cast(new_uint)<<8*i*os for i in range(rate))).squeeze(-1).bitcast(dtype)
       return Tensor.stack(*(tmp>>8*i*ns for i in range(os//ns)), dim=-1).flatten(-2).cast(new_uint).bitcast(dtype)
     return self._apply_uop(UOp.bitcast, dtype=dt) if self.dtype != dt else self

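In the widening branch added above, a bitcast from `rate = ns//os` narrow unsigned integers to one wider one is built by casting each narrow slice up and shifting it left by `8*i*os` bits before summing. A standalone sketch of that combine rule (plain Python, hypothetical `widen` helper), cross-checked against a little-endian `struct` unpack:

```python
import struct

def widen(narrow: list[int], os_bytes: int) -> int:
  # shift the i-th narrow element up by 8*i*os bits and sum, as the reduce above does
  return sum(v << (8 * i * os_bytes) for i, v in enumerate(narrow))

vals = [0x78, 0x56, 0x34, 0x12]  # four uint8 values
assert widen(vals, 1) == struct.unpack("<I", bytes(vals))[0] == 0x12345678
print(hex(widen(vals, 1)))
```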
@@ -3898,9 +4307,14 @@ class Tensor(SimpleMathTrait):
     """
     return self.cast(dtypes.bool)

+  def bfloat16(self) -> Tensor: return self.cast(dtypes.bfloat16)
+  def double(self) -> Tensor: return self.cast(dtypes.double)
+  def long(self) -> Tensor: return self.cast(dtypes.long)
+  def short(self) -> Tensor: return self.cast(dtypes.short)
+
   # *** image Tensor function replacements ***

-  def image_dot(self, w:Tensor,
+  def image_dot(self, w:Tensor, dtype:DTypeLike|None=None) -> Tensor:
     # NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
     x, dx, dw = self, self.ndim, w.ndim
     if not (dx > 0 and dw > 0): raise RuntimeError(f"both tensors need to be at least 1D, got {dx}D and {dw}D")
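The four dtype helpers added in this hunk are thin wrappers over `cast`, matching the existing `float()`/`half()`/`int()`/`bool()` family. A short usage sketch, assuming the usual aliases in `tinygrad.dtype` (`long`, `short`, `double` for `int64`, `int16`, `float64`):

```python
from tinygrad import Tensor

t = Tensor([1.5, 2.5, 3.5])
print(t.long().dtype)      # dtypes.long  (int64)
print(t.short().dtype)     # dtypes.short (int16)
print(t.double().dtype)    # dtypes.double (float64)
print(t.bfloat16().dtype)  # dtypes.bfloat16; kernel support depends on the device
```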
@@ -3914,9 +4328,9 @@ class Tensor(SimpleMathTrait):
     cx = self.transpose(self.ndim-1, self.ndim-2).reshape((bs//groups, groups*cin, -1, 1))
     # groups*cout x cin x H, W
     cw = w.transpose(w.ndim-1, w.ndim-2).reshape((groups*cout, cin, 1, 1))
-    return cx.image_conv2d(cw, groups=groups,
+    return cx.image_conv2d(cw, groups=groups, dtype=dtype).reshape(out_shape_t).transpose(self.ndim-1, self.ndim-2)

-  def image_conv2d(self, weight:Tensor, bias:
+  def image_conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding=0, dtype=None) -> Tensor:
     base_image_type = dtypes.imageh if getenv("FLOAT16", 0) else dtypes.imagef

     (bs,_,iy,ix), (cout,cin,H,W) = self.shape, weight.shape
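The NOTE in `image_dot` relies on the identity that an `m x k @ k x n` matmul can be expressed as a 1x1 convolution over a `(1, k, m, 1)` input with `(n, k, 1, 1)` weights. A small NumPy sketch (hypothetical `conv1x1` helper) checking that equivalence:

```python
import numpy as np

def conv1x1(x, w):
  # NCHW 1x1 convolution: out[b,o,y,x] = sum_c in[b,c,y,x] * w[o,c,0,0]
  return np.einsum("bchw,oc->bohw", x, w[:, :, 0, 0])

m, k, n = 4, 5, 3
A, B = np.random.randn(m, k), np.random.randn(k, n)
x = A.T.reshape(1, k, m, 1)           # (1, k, m, 1) input
w = B.T.reshape(n, k, 1, 1)           # (n, k, 1, 1) weights
out = conv1x1(x, w).reshape(n, m).T   # back to (m, n)
print(np.allclose(out, A @ B))
```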
@@ -3965,7 +4379,7 @@ class Tensor(SimpleMathTrait):
     w = w.permute(0,4,2,5,1,3).reshape((1, 1, 1, *cout_expand, rcin_hi, rcin_lo, H, W))

     # the conv!
-    ret = (x*w).cast(base_image_type((bs*oy, ox*cout//4, 4)) if IMAGE >= 2 else dtypes.float32).sum((-4, -3, -2, -1),
+    ret = (x*w).cast(base_image_type((bs*oy, ox*cout//4, 4)) if IMAGE >= 2 else dtypes.float32).sum((-4, -3, -2, -1), dtype=dtype)

     # undo hack for non multiples of 4 on C.rcout
     if added_output_channels != 0:
@@ -3976,8 +4390,20 @@ class Tensor(SimpleMathTrait):
     ret = ret.reshape(bs, oy, ox, cout).permute(0,3,1,2)
     return ret if bias is None else ret.add(bias.reshape(1, -1, 1, 1))

-
-
+P = ParamSpec("P")
+T = TypeVar("T")
+
+# this tracks the tensor.py METADATA, contextvars.ContextVar was switched to this due to thread safety issues
+class _ContextVar(Generic[T]):
+  def __init__(self, default:T): self.state:T = default
+  def get(self) -> T: return self.state
+  def set(self, x:T) -> T:
+    ret, self.state = self.state, x
+    return ret
+_METADATA: _ContextVar[Metadata|None] = _ContextVar(default=None)
+
+def _metadata_wrapper(fn: Callable[P, T]) -> Callable[P, T]:
+  def _wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
     if _METADATA.get() is not None: return fn(*args, **kwargs)

     if TRACEMETA >= 2:
@@ -4004,7 +4430,7 @@ def _metadata_wrapper(fn):

     token = _METADATA.set(Metadata(name=fn.__name__, caller=caller))
     ret = fn(*args, **kwargs)
-    _METADATA.
+    _METADATA.set(token)
     return ret
   return _wrapper

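The `_ContextVar` added above replaces `contextvars.ContextVar`, and its `set` returns the previous value rather than an opaque token, which is why `_metadata_wrapper` restores state with `_METADATA.set(token)`. A minimal standalone sketch of that save/restore pattern (hypothetical `_Var`/`meta` names):

```python
from typing import Generic, TypeVar

T = TypeVar("T")

class _Var(Generic[T]):
  def __init__(self, default: T): self.state: T = default
  def get(self) -> T: return self.state
  def set(self, x: T) -> T:
    # return the previous value so the caller can restore it later
    prev, self.state = self.state, x
    return prev

meta: _Var[str|None] = _Var(None)
token = meta.set("relu")   # enter a scope: remember what was set before
assert meta.get() == "relu"
meta.set(token)            # leave the scope: restore the previous value
assert meta.get() is None
```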