tinygrad 0.9.0__py3-none-any.whl → 0.9.1__py3-none-any.whl
- tinygrad/codegen/__init__.py +0 -0
- tinygrad/codegen/kernel.py +78 -90
- tinygrad/codegen/linearizer.py +237 -169
- tinygrad/codegen/uops.py +278 -242
- tinygrad/device.py +147 -10
- tinygrad/dtype.py +7 -7
- tinygrad/engine/graph.py +16 -16
- tinygrad/engine/jit.py +39 -36
- tinygrad/engine/realize.py +6 -5
- tinygrad/engine/schedule.py +15 -7
- tinygrad/engine/search.py +6 -3
- tinygrad/function.py +17 -23
- tinygrad/helpers.py +77 -8
- tinygrad/lazy.py +26 -26
- tinygrad/multi.py +13 -9
- tinygrad/nn/__init__.py +1 -1
- tinygrad/nn/datasets.py +2 -1
- tinygrad/nn/state.py +3 -4
- tinygrad/ops.py +49 -16
- tinygrad/renderer/__init__.py +8 -4
- tinygrad/renderer/assembly.py +93 -100
- tinygrad/renderer/cstyle.py +47 -42
- tinygrad/renderer/llvmir.py +30 -30
- tinygrad/runtime/__init__.py +0 -0
- tinygrad/runtime/autogen/amd_gpu.py +11504 -1
- tinygrad/runtime/autogen/comgr.py +36 -10
- tinygrad/runtime/autogen/hsa.py +146 -14
- tinygrad/runtime/autogen/io_uring.py +1486 -0
- tinygrad/runtime/autogen/nv_gpu.py +269 -0
- tinygrad/runtime/driver/__init__.py +0 -0
- tinygrad/runtime/driver/hip_comgr.py +20 -11
- tinygrad/runtime/graph/__init__.py +0 -0
- tinygrad/runtime/graph/clang.py +3 -2
- tinygrad/runtime/graph/cuda.py +2 -2
- tinygrad/runtime/graph/hcq.py +122 -78
- tinygrad/runtime/ops_amd.py +302 -316
- tinygrad/runtime/ops_cuda.py +3 -3
- tinygrad/runtime/ops_disk.py +70 -5
- tinygrad/runtime/ops_gpu.py +2 -2
- tinygrad/runtime/ops_metal.py +5 -6
- tinygrad/runtime/ops_npy.py +1 -1
- tinygrad/runtime/ops_nv.py +161 -166
- tinygrad/runtime/ops_python.py +20 -16
- tinygrad/shape/__init__.py +0 -0
- tinygrad/shape/shapetracker.py +5 -2
- tinygrad/shape/symbolic.py +1 -3
- tinygrad/shape/view.py +34 -19
- tinygrad/tensor.py +219 -135
- {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/METADATA +14 -6
- tinygrad-0.9.1.dist-info/RECORD +63 -0
- {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/WHEEL +1 -1
- tinygrad/runtime/driver/hsa.py +0 -143
- tinygrad/runtime/graph/hsa.py +0 -171
- tinygrad/runtime/ops_hsa.py +0 -278
- tinygrad-0.9.0.dist-info/RECORD +0 -60
- {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/LICENSE +0 -0
- {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/top_level.txt +0 -0
tinygrad/tensor.py
CHANGED
@@ -1,19 +1,19 @@
 # inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
 from __future__ import annotations
-import time, math, itertools, functools
+import time, math, itertools, functools, struct
 from contextlib import ContextDecorator
 from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Dict, DefaultDict, cast, get_args, Set
 from collections import defaultdict
 import numpy as np
 
 from tinygrad.dtype import DType, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype
-from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts,
+from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, get_shape, fully_flatten, dedup
 from tinygrad.helpers import IMAGE, DEBUG, WINO, THREEFRY
 from tinygrad.lazy import LazyBuffer
 from tinygrad.multi import MultiLazyBuffer
-from tinygrad.ops import LoadOps
+from tinygrad.ops import LoadOps, truncate
 from tinygrad.device import Device, Buffer, BufferOptions
-from tinygrad.shape.symbolic import sint, Variable, MulNode, Node
+from tinygrad.shape.symbolic import sint, Variable, MulNode, SumNode, NumNode, Node
 from tinygrad.engine.realize import run_schedule
 from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars, memory_planner
 
@@ -43,13 +43,28 @@ def _loadop(op, shape:Tuple[sint,...], dtype:DType, device:Union[str, Tuple[str,
   if isinstance(device, str): return LazyBuffer.loadop(op, shape, dtype, device, arg, src)
   return MultiLazyBuffer([LazyBuffer.loadop(op, shape, dtype, d, arg, src) for d in device], None)
 
-def
-
+def _from_np_dtype(npdtype:type) -> DType: return dtypes.fields()[np.dtype(npdtype).name]
+def _to_np_dtype(dtype:DType) -> Optional[type]: return np.dtype(dtype.fmt).type if dtype.fmt is not None else None
+
+def _fromnp(x: np.ndarray) -> LazyBuffer:
+  ret = LazyBuffer.loadop(LoadOps.EMPTY, x.shape, _from_np_dtype(x.dtype), "NPY")
   # fake realize
   ret.buffer.allocate(x)
   del ret.srcs
   return ret
 
+def _frompy(x:Union[List, Tuple, bytes], dtype:DType) -> LazyBuffer:
+  if isinstance(x, bytes): ret, data = LazyBuffer.loadop(LoadOps.EMPTY, (len(x),), dtype, "PYTHON"), x
+  else:
+    ret = LazyBuffer.loadop(LoadOps.EMPTY, get_shape(x), dtype, "PYTHON")
+    assert dtype.fmt is not None, f"{dtype=} has None fmt"
+    truncate_function = truncate[dtype]
+    data = struct.pack(f"@{ret.size}{dtype.fmt}", *[truncate_function(xi) for xi in fully_flatten(x)])
+  # fake realize
+  ret.buffer.allocate(memoryview(data))
+  del ret.srcs
+  return ret
+
 def _get_winograd_matcols(mat, dims:int, shp:Tuple[sint, ...], device:Union[str, Tuple[str, ...]]) -> List[List[Tensor]]:
   return [[Tensor.cat(*[Tensor.full(shp[:dim] + (1,) + shp[dim+1:], float(m[k]), device=device) for m in mat], dim=dim)
            for k in range(len(mat[0]))] for dim in range(dims)]
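The new `_frompy` path above flattens a nested Python list and packs it into raw bytes with `struct`, using the dtype's format character and a per-dtype truncate function. A minimal standalone sketch of just that packing step (plain `struct`, with a hypothetical `flatten_nested` stand-in for tinygrad's `fully_flatten`):

```python
import struct
from typing import Any, List

def flatten_nested(x: Any) -> List[Any]:
  # stand-in for tinygrad's fully_flatten: recursively flatten nested lists/tuples
  return [v for item in x for v in flatten_nested(item)] if isinstance(x, (list, tuple)) else [x]

data = [[1, 2, 3], [4, 5, 6]]
flat = flatten_nested(data)
# "@" = native byte order; "i" is the struct format char for a 32-bit int (what dtype.fmt would be for int32)
packed = struct.pack(f"@{len(flat)}i", *flat)
print(len(packed))  # 24 bytes = 6 values * 4 bytes
```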
@@ -66,8 +81,11 @@ def _apply_winograd_matrix(mat, t:Tensor, dims:int) -> Tensor:
   assert isinstance(ret, Tensor), "sum didn't return a Tensor"
   return ret
 
-def _pad_left(*
-
+def _pad_left(*shapes:Tuple[sint, ...]) -> Tuple[Tuple[sint, ...], ...]:
+  max_dim = max(len(shape) for shape in shapes)
+  return tuple((1,) * (max_dim - len(shape)) + shape for shape in shapes)
+def _broadcast_shape(*shapes:Tuple[sint, ...]) -> Tuple[sint, ...]:
+  return tuple(0 if any(size == 0 for size in nth_dim_sizes) else max(nth_dim_sizes) for nth_dim_sizes in zip(*_pad_left(*shapes)))
 
 class Tensor:
   """
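The two helpers added here encode the usual broadcasting rule: left-pad the shorter shape with 1s, then take the per-dimension maximum, with a 0 in any input forcing 0 in the output. A small self-contained restatement of that rule, outside of tinygrad:

```python
def pad_left(*shapes):
  # left-pad each shape with 1s so they all have the same rank
  max_dim = max(len(s) for s in shapes)
  return tuple((1,) * (max_dim - len(s)) + s for s in shapes)

def broadcast_shape(*shapes):
  # per-dimension max; a 0 anywhere keeps that output dimension 0
  return tuple(0 if 0 in sizes else max(sizes) for sizes in zip(*pad_left(*shapes)))

print(broadcast_shape((3, 1), (4, 1, 5)))  # (4, 3, 5)
```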
@@ -83,60 +101,75 @@ class Tensor:
   __slots__ = "lazydata", "requires_grad", "grad", "_ctx"
   __deletable__ = ('_ctx',)
   training: ClassVar[bool] = False
-  class train(ContextDecorator):
-    def __init__(self, mode:bool = True): self.mode = mode
-    def __enter__(self): self.prev, Tensor.training = Tensor.training, self.mode
-    def __exit__(self, exc_type, exc_value, traceback): Tensor.training = self.prev
-
   no_grad: ClassVar[bool] = False
-
-    def __init__(self, mode:bool = True): self.mode = mode
-    def __enter__(self): self.prev, Tensor.no_grad = Tensor.no_grad, self.mode
-    def __exit__(self, exc_type, exc_value, traceback): Tensor.no_grad = self.prev
+
   def __init__(self, data:Union[None, ConstType, List, Tuple, LazyBuffer, np.ndarray, bytes, MultiLazyBuffer, Variable],
                device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
     assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
     device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device)
-
+
+    # tensors can have gradients if you have called .backward
     self.grad: Optional[Tensor] = None
 
     # NOTE: this can be in three states. False and None: no gradient, True: gradient
     # None (the default) will be updated to True if it's put in an optimizer
     self.requires_grad: Optional[bool] = requires_grad
 
-    # internal
+    # internal variable used for autograd graph construction
    self._ctx: Optional[Function] = None
+
+    # create a LazyBuffer from the different types of inputs
     if isinstance(data, LazyBuffer): assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
     elif isinstance(data, get_args(ConstType)): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_py(data), device, data)
     elif isinstance(data, Variable): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_py(data.unbind()[1]), device, data)
-    elif isinstance(data, bytes): data =
-    elif data
-    elif isinstance(data, list):
+    elif isinstance(data, bytes): data = _frompy(data, dtypes.uint8)
+    elif isinstance(data, (list, tuple)):
       if dtype is None:
         if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtypes.bool
         else: dtype = dtypes.default_int if d and all_int(d) else dtypes.default_float
-      if dtype == dtypes.bfloat16: data = Tensor(
-      else: data =
+      if dtype == dtypes.bfloat16: data = Tensor(_fromnp(np.array(data, np.float32)), device=device).cast(dtypes.bfloat16).lazydata
+      else: data = _fromnp(np.array(data).astype(_to_np_dtype(dtype)))
+    elif data is None: data = _loadop(LoadOps.EMPTY, (0,), dtype or dtypes.default_float, device)
     elif isinstance(data, np.ndarray):
-      if data.shape == (): data = _loadop(LoadOps.CONST, tuple(), dtype or
-      else: data =
+      if data.shape == (): data = _loadop(LoadOps.CONST, tuple(), dtype or _from_np_dtype(data.dtype), device, data.item())
+      else: data = _fromnp(data.astype(npdtype) if dtype is not None and (npdtype:=_to_np_dtype(dtype)) is not None else data)
+
+    # by this point, it has to be a LazyBuffer
+    if not isinstance(data, (LazyBuffer, MultiLazyBuffer)):
+      raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}")
 
     # data is a LazyBuffer, but it might be on the wrong device
-    if not isinstance(data, (LazyBuffer, MultiLazyBuffer)): raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}")
     if isinstance(device, tuple):
-      #
-
+      # if device is a tuple, we should have/construct a MultiLazyBuffer
+      if isinstance(data, MultiLazyBuffer):
+        assert data.device == device, f"MultiLazyBuffer device mismatch, {data.device} != {device}"
+        self.lazydata: Union[LazyBuffer, MultiLazyBuffer] = data
+      else:
+        self.lazydata = MultiLazyBuffer.from_sharded(data, device, None)
     else:
       self.lazydata = data if data.device == device else data.copy_to_device(device)
 
-
+  class train(ContextDecorator):
+    def __init__(self, mode:bool = True): self.mode = mode
+    def __enter__(self): self.prev, Tensor.training = Tensor.training, self.mode
+    def __exit__(self, exc_type, exc_value, traceback): Tensor.training = self.prev
+
+  class inference_mode(ContextDecorator):
+    def __init__(self, mode:bool = True): self.mode = mode
+    def __enter__(self): self.prev, Tensor.no_grad = Tensor.no_grad, self.mode
+    def __exit__(self, exc_type, exc_value, traceback): Tensor.no_grad = self.prev
+
+  def __repr__(self):
+    return f"<Tensor {self.lazydata!r} on {self.device} with grad {(self.grad.lazydata if self.grad is not None else None)!r}>"
 
   # Python has a non moving GC, so this should be okay
   def __hash__(self): return id(self)
 
   def __bool__(self): raise TypeError("__bool__ on Tensor is not defined")
 
-  def __len__(self):
+  def __len__(self):
+    if not self.shape: raise TypeError("len() of a 0-d tensor")
+    return self.shape[0]
 
   @property
   def device(self) -> Union[str, Tuple[str, ...]]: return self.lazydata.device
@@ -196,6 +229,7 @@ class Tensor:
     if not self.lazydata.is_realized(): return self.replace(x)
     self.lazydata = self.lazydata.assign(x.lazydata)
     return self
+
   def detach(self) -> Tensor:
     """
     Returns a new tensor with the same data as this tensor, but detached from the autograd graph.
@@ -259,9 +293,9 @@ class Tensor:
     ```
     """
     if self.dtype == dtypes.bfloat16: return self.float().numpy()
-    assert self.dtype
+    assert _to_np_dtype(self.dtype) is not None, f"no np dtype for {self.dtype}"
     assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
-    return np.frombuffer(self._data(), dtype=self.dtype
+    return np.frombuffer(self._data(), dtype=_to_np_dtype(self.dtype)).reshape(self.shape)
 
   def to(self, device:Optional[Union[str, Tuple[str, ...]]]) -> Tensor:
     """
@@ -302,8 +336,10 @@ class Tensor:
 
   @staticmethod
   def from_node(y:Node, **kwargs) -> Tensor:
-    if isinstance(y,
+    if isinstance(y, NumNode): return Tensor(y.b, **kwargs, requires_grad=False)
     if isinstance(y, Variable): return Tensor(y, **kwargs, requires_grad=False)
+    if isinstance(y, MulNode): return Tensor.from_node(y.a, **kwargs) * y.b
+    if isinstance(y, SumNode): return Tensor.from_node(y.nodes[0], **kwargs) + sum(y.nodes[1:])
     raise RuntimeError(f"unhandled Node {y}")
 
   # ***** creation llop entrypoint *****
@@ -339,7 +375,13 @@ class Tensor:
 
     ```python exec="true" source="above" session="tensor" result="python"
     Tensor.manual_seed(42)
-    print(Tensor.
+    print(Tensor.rand(5).numpy())
+    print(Tensor.rand(5).numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    Tensor.manual_seed(42) # reset to the same seed
+    print(Tensor.rand(5).numpy())
+    print(Tensor.rand(5).numpy())
     ```
     """
     Tensor._seed, Tensor._rng_counter = seed, Tensor([0], dtype=dtypes.uint32, requires_grad=False)
@@ -724,8 +766,11 @@ class Tensor:
     print(t.reshape(2, 3).numpy())
     ```
     """
-
-    new_shape = tuple([
+    # resolve None and args
+    new_shape = tuple([s if s is not None else self.shape[i] for i,s in enumerate(argfix(shape, *args))])
+    # resolve -1
+    if (c := new_shape.count(-1)) > 1: raise RuntimeError(f"only one dimension can be inferred using -1, getting {new_shape}")
+    if c: new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape])
     return F.Reshape.apply(self, shape=new_shape) if new_shape != self.shape else self
 
   def expand(self, shape, *args) -> Tensor:
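The rewritten `reshape` first substitutes `None` with the existing dimension size, then infers a single `-1` from the remaining element count. The arithmetic in isolation (a sketch, not tinygrad code):

```python
from math import prod

def resolve_new_shape(old_shape, new_shape):
  # None keeps the existing dimension; at most one -1 is inferred from the rest
  resolved = tuple(s if s is not None else old_shape[i] for i, s in enumerate(new_shape))
  if resolved.count(-1) > 1: raise RuntimeError(f"only one dimension can be inferred using -1, got {resolved}")
  if -1 in resolved: resolved = tuple(-prod(old_shape) // prod(resolved) if s == -1 else s for s in resolved)
  return resolved

print(resolve_new_shape((2, 3, 4), (None, -1)))  # (2, 12)
```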
@@ -740,7 +785,7 @@ class Tensor:
     print(t.expand(4, -1).numpy())
     ```
     """
-    return self._broadcast_to(tuple(
+    return self._broadcast_to(tuple(from_ if to == -1 or to is None else to for from_, to in zip(*(_pad_left(self.shape, argfix(shape, *args))))))
 
   def permute(self, order, *args) -> Tensor:
     """
@@ -756,7 +801,9 @@ class Tensor:
     print(t.permute(1, 0).numpy())
     ```
     """
-
+    order_arg = tuple(self._resolve_dim(x) for x in argfix(order, *args))
+    if sorted(order_arg) != list(range(self.ndim)): raise RuntimeError(f"order is not a valid permutation, getting {order_arg}")
+    return F.Permute.apply(self, order=order_arg)
 
   def flip(self, axis, *args) -> Tensor:
     """
@@ -774,7 +821,9 @@ class Tensor:
     print(t.flip((0, 1)).numpy())
     ```
     """
-
+    axis_arg = tuple(self._resolve_dim(x) for x in argfix(axis, *args))
+    if len(axis_arg) != len(dedup(axis_arg)): raise RuntimeError(f"dim can appear at least once, getting {axis_arg}")
+    return F.Flip.apply(self, axis=axis_arg)
 
   def shrink(self, arg:Tuple[Optional[Tuple[sint, sint]], ...]) -> Tensor:
     """
@@ -831,7 +880,7 @@ class Tensor:
   # - first shrink the Tensor to X.shrink(((start, end),))
   # - then we apply stride through Optional[flip] -> pad -> reshape -> shrink
   # - flip where dim value is negative
-  # - pad
+  # - pad on dims to be multiple of strides, such that reshaping [dim_size_padded] -> [dim_size_padded // stride, stride] is possible
   # - shrink [dim_size_padded // stride, stride] -> [dim_size_padded // stride, 1]
   # - reshape [dim_size_padded // stride, 1] -> [dim_size_padded // stride] and now you have your stride
   # 3. None indexing (no copy)
@@ -847,13 +896,13 @@ class Tensor:
   # - for any tuple, Tuple[Union[List, Tuple, int]], must have homogeneous shape
   # 2. Bool indexing is not supported
   # 3. Out of bounds Tensor indexing results in 0
-  # - e.g: Tensor([1, 2, 3])[Tensor([4, 3, 2])] -> [0, 0, 3] index 4 and 3 are
+  # - e.g: Tensor([1, 2, 3])[Tensor([4, 3, 2])] -> [0, 0, 3] index 4 and 3 are out of bounds
   def __getitem__(self, indices) -> Tensor:
     # 1. indices normalization and validation
     # treat internal tuples and lists as Tensors and standardize indices to list type
     if isinstance(indices, list) and all_int(indices): indices = [Tensor(indices, self.device, requires_grad=False)]
     elif isinstance(indices, (tuple, list)):
-      indices = [Tensor(
+      indices = [Tensor(i, self.device, requires_grad=False) if isinstance(i, (tuple, list)) else i for i in indices]
     else: indices = [indices]
 
     # turn scalar Tensors into const val for int indexing if possible
@@ -865,24 +914,25 @@ class Tensor:
     ellipsis_idx = [dim for dim, i in enumerate(indices) if i is Ellipsis]
     fill_idx = ellipsis_idx[0] if ellipsis_idx else len(indices)
     num_indices = len(indices) - len(ellipsis_idx) - sum(1 for i in indices if i is None)
-    indices[fill_idx:fill_idx+1] = [slice(None)] * (
+    indices[fill_idx:fill_idx+1] = [slice(None)] * (self.ndim - num_indices)
 
     # use Dict[type, List[dimension]] to track elements in indices
     type_dim: DefaultDict[Union[type, None], List[int]] = defaultdict(list)
 
     # record None for dimension injection later and filter None and record rest of indices
     type_dim[None] = [dim for dim, i in enumerate(indices) if i is None]
-
+    tensor_dims = [dim for dim, i in enumerate(indices) if isinstance(i, Tensor)]
+    indices_filtered = [i for i in indices if i is not None]
     for dim,i in enumerate(indices_filtered): type_dim[type(i)].append(dim)
 
+    if len(ellipsis_idx) > 1: raise IndexError("indices can only have a single ellipsis ('...')")
     for index_type in type_dim:
       if index_type not in [None, int, slice, Tensor]: raise IndexError(f"{index_type=} not supported")
-    if len(ellipsis_idx) > 1: raise IndexError("indices can only have a single ellipsis ('...')")
     if num_indices > self.ndim: raise IndexError(f"too many {num_indices=} for {self.ndim=}")
 
     # 2. basic indexing, uses only movement ops (no copy)
-    # currently indices_filtered: Tuple[Union[
-    # turn indices in indices_filtered to Tuple[
+    # currently indices_filtered: Tuple[Union[int, slice, Tensor], ...]
+    # turn indices in indices_filtered to Tuple[new_slice, strides]
     for dim in type_dim[int]:
       if (index := indices_filtered[dim]) >= (size := self.shape[dim]) or index < -size:
         raise IndexError(f"{index=} is out of bounds on {dim=} with {size=}")
@@ -898,13 +948,16 @@ class Tensor:
       if not dtypes.is_int(index.dtype): raise IndexError(f"{index.dtype=} on {dim=} is not supported, only int tensor indexing is supported")
       indices_filtered[dim] = ((0, self.shape[dim]), 1)
 
-    new_slice, strides = ((),()) if not indices_filtered else zip(*indices_filtered)
-
-
+    new_slice, strides = ((), ()) if not indices_filtered else zip(*indices_filtered)
+    # flip negative strides
+    ret = self.shrink(new_slice).flip(tuple(i for i, st in enumerate(strides) if st < 0))
+    # handle stride != 1 or -1
+    if any(abs(st) != 1 for st in strides):
       strides = tuple(abs(s) for s in strides)
-
-      ret = ret.
-      ret = ret.
+      # pad shape to multiple of stride
+      ret = ret.pad(tuple((0, round_up(s, st) - s) for s, st in zip(ret.shape, strides)))
+      ret = ret.reshape(tuple(flatten((s // st, st) for s, st in zip(ret.shape, strides))))
+      ret = ret.shrink(tuple(flatten(((0, s), (0, 1)) for s in ret.shape[::2]))).reshape(ret.shape[::2])
 
     # inject 1 for dim where it's None and collapse dim for int
     new_shape = list(ret.shape)
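The stride handling in `__getitem__` above pads each dimension to a multiple of the stride, reshapes it into `(n // stride, stride)`, and keeps the first column. The same trick on a single NumPy axis, to show it reproduces ordinary step slicing:

```python
import numpy as np

# pad to a multiple of the stride, reshape to (n // stride, stride), keep column 0
x, stride = np.arange(10), 3
padded = np.pad(x, (0, (-len(x)) % stride))   # length 12, a multiple of 3
strided = padded.reshape(-1, stride)[:, 0]    # [0 3 6 9]
print(np.array_equal(strided, x[::3]))        # True
```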
@@ -912,17 +965,17 @@ class Tensor:
     for dim in (dims_collapsed := tuple(dim + sum(1 for d in type_dim[None] if dim >= d) for dim in reversed(type_dim[int]))): new_shape.pop(dim)
 
     ret = ret.reshape(new_shape)
-    assert all_int(ret.shape), f"does not support symbolic shape {ret.shape}"
 
     # 3. advanced indexing (copy)
     if type_dim[Tensor]:
       # calculate dim of current ret by subtracting dims collapsed and adding dims injected up until tensor_dim
       def calc_dim(tensor_dim:int) -> int:
-        return tensor_dim - sum(1 for d in dims_collapsed if tensor_dim >= d)
+        return tensor_dim - sum(1 for d in dims_collapsed if tensor_dim >= d)
 
+      assert all_int(ret.shape), f"does not support symbolic shape {ret.shape}"
       # track tensor_dim and tensor_index using a dict
       # calc_dim to get dim and use that to normalize the negative tensor indices
-      idx: Dict[int,Tensor] = {(dim := calc_dim(td)):(tensor<0).where(ret.shape[dim],0) + tensor for td,tensor in zip(
+      idx: Dict[int,Tensor] = {(dim := calc_dim(td)):(tensor<0).where(ret.shape[dim],0) + tensor for td,tensor in zip(tensor_dims, tensor_index)}
 
       masks, first_dim, last_dim = [], min(idx.keys()), max(idx.keys())
       pre_reduce_shape = ret.shape[:first_dim] + (big_shape := _broadcast_shape(*(t.shape for t in idx.values()))) + ret.shape[first_dim:]
@@ -938,9 +991,9 @@ class Tensor:
       mask: Tensor = functools.reduce(lambda x,y: x.mul(y), masks)
 
       # inject 1's for the extra dims added in create masks
-
+      reshape_arg = ret.shape[:first_dim] + (1,) * len(big_shape) + ret.shape[first_dim:]
       # sum reduce the extra dims introduced in create masks
-      ret = (ret.reshape(
+      ret = (ret.reshape(reshape_arg) * mask).sum(tuple(i + len(big_shape) for i in idx.keys()), acc_dtype=ret.dtype)
 
       # special permute case
       if first_dim != 0 and len(idx) != 1 and tuple(idx.keys()) != tuple(range(first_dim, last_dim+1)):
@@ -983,13 +1036,11 @@ class Tensor:
     ```
     """
     assert index.ndim == self.ndim, f"self.ndim must equal index.ndim, {self.ndim=}, {index.ndim=}"
-    assert all(s >= i for s,i in zip(self.shape, index.shape)), "
+    assert all(s >= i for d,(s,i) in enumerate(zip(self.shape, index.shape)) if d != dim), "requires self.shape[d] >= index.shape[d] for all d != dim"
     dim = self._resolve_dim(dim)
-    index = index.to(self.device)
-
-
-    return ((index == Tensor.arange(self.shape[dim], requires_grad=False, device=self.device)) * self.permute(*permarg).shrink(
-      tuple([*[(0,sh) for sh in index.shape[1:-1]], None])).unsqueeze(0)).sum(-1, acc_dtype=self.dtype).transpose(0, dim)
+    index = index.to(self.device)
+    x = self.shrink(tuple((0, i) if d != dim else None for d,i in enumerate(index.shape))).unsqueeze(-1).transpose(-1, dim)
+    return ((index.unsqueeze(-1) == Tensor.arange(self.shape[dim], requires_grad=False, device=self.device)) * x).sum(-1, acc_dtype=self.dtype)
 
   def cat(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
     """
@@ -1193,7 +1244,7 @@ class Tensor:
 
   def unflatten(self, dim:int, sizes:Tuple[int,...]):
     """
-
+    Unflattens dimension `dim` of the tensor into multiple dimensions specified by `sizes`. `Tensor.flatten()` is the inverse of this function.
 
     ```python exec="true" source="above" session="tensor" result="python"
     print(Tensor.ones(3, 4, 1).unflatten(1, (2, 2)).shape)
@@ -1210,16 +1261,16 @@ class Tensor:
 
   # ***** reduce ops *****
 
-  def _reduce(self, fxn:Type[Function], axis:Optional[Union[int,
+  def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False) -> Tensor:
     if self.ndim == 0:
-      if axis is not None and
-      axis =
+      if axis is not None and any(a not in [-1, 0] for a in fully_flatten([axis])): raise IndexError(f"{axis=} out of range of [-1, 0]")
+      axis = ()
     axis_: Tuple[int, ...] = tuple(range(len(self.shape))) if axis is None else ((axis,) if isinstance(axis, int) else tuple(axis))
     axis_ = tuple(self._resolve_dim(x) for x in axis_)
     ret = fxn.apply(self, axis=axis_)
     return ret if keepdim else ret.reshape(tuple(s for i,s in enumerate(self.shape) if i not in axis_))
 
-  def sum(self, axis=None, keepdim=False, acc_dtype:Optional[DType]=None):
+  def sum(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, acc_dtype:Optional[DType]=None):
     """
     Sums the elements of the tensor along the specified axis or axes.
 
@@ -1244,8 +1295,9 @@ class Tensor:
     ```
     """
     ret = self.cast(acc_dtype or sum_acc_dtype(self.dtype))._reduce(F.Sum, axis, keepdim)
-    return ret.cast(self.dtype) if self.dtype in
-
+    return ret.cast(self.dtype) if acc_dtype is None and self.dtype in (dtypes.float16, dtypes.bfloat16) else ret
+
+  def max(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
     """
     Returns the maximum value of the tensor along the specified axis or axes.
 
@@ -1267,7 +1319,8 @@ class Tensor:
     ```
     """
     return self._reduce(F.Max, axis, keepdim)
-
+
+  def min(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
     """
     Returns the minimum value of the tensor along the specified axis or axes.
 
@@ -1290,7 +1343,7 @@ class Tensor:
     """
     return -((-self).max(axis=axis, keepdim=keepdim))
 
-  def mean(self, axis=None, keepdim=False):
+  def mean(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
     """
     Returns the mean value of the tensor along the specified axis or axes.
 
@@ -1316,7 +1369,7 @@ class Tensor:
     numerator = self.cast(sum_acc_dtype(self.dtype)).sum(axis=axis, keepdim=keepdim)
     return numerator.div(prod([si for si, so in zip(self.shape, self.sum(axis=axis, keepdim=True).shape) if si != so])).cast(output_dtype)
 
-  def var(self, axis=None, keepdim=False, correction=1):
+  def var(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, correction=1):
     """
     Returns the variance of the tensor along the specified axis or axes.
 
@@ -1338,11 +1391,11 @@ class Tensor:
     print(t.var(axis=1).numpy())
     ```
     """
-
-
-    return
+    squares = (self - self.mean(axis=axis, keepdim=True)).square()
+    n = prod([si for si, so in zip(self.shape, squares.sum(axis=axis, keepdim=True).shape) if si != so])
+    return squares.sum(axis=axis, keepdim=keepdim).div(max(0, n-correction))
 
-  def std(self, axis=None, keepdim=False, correction=1):
+  def std(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, correction=1):
     """
     Returns the standard deviation of the tensor along the specified axis or axes.
 
@@ -1465,9 +1518,7 @@ class Tensor:
     print(t.argmax(axis=1).numpy()) # Returns the indices of the maximum values along axis 1.
     ```
     """
-    if axis is None:
-      idx = (self == self.max(axis)) * Tensor.arange(prod(self.shape)-1,-1,-1, requires_grad=False, device=self.device).reshape(self.shape)
-      return (prod(self.shape) - idx.max() - 1).cast(dtypes.int32)
+    if axis is None: return self.flatten().argmax(0)
     axis = self._resolve_dim(axis)
     m = self == self.max(axis=axis, keepdim=True)
     idx = m * Tensor.arange(self.shape[axis]-1,-1,-1, requires_grad=False, device=self.device).reshape(self.shape[axis], *[1]*(self.ndim-axis-1))
@@ -1511,12 +1562,13 @@ class Tensor:
     """
     xs:Tuple[Tensor] = argfix(*raw_xs)
     formula = formula.replace(" ", "")
-    inputs_str, output = formula.split("->") if "->" in formula else (formula,
-
+    inputs_str, output = formula.split("->") if "->" in formula else (formula, \
+      ''.join(c for c in sorted(formula) if formula.count(c) == 1 and c.isalpha()))
+    inputs = inputs_str.split(',')
     assert len(xs) == len(inputs), f"number of inputs doesn't match number of operands in formula, expected {len(inputs)}, got {len(xs)}"
 
     # map the value of each letter in the formula
-    letter_val = sorted(merge_dicts([
+    letter_val = sorted(merge_dicts([dict(zip(letters, tensor.shape)) for letters, tensor in zip(inputs, xs)]).items())
 
     xs_:List[Tensor] = []
     lhs = [sorted(enumerate(s), key=lambda e:e[1]) for s in inputs]
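When the einsum formula has no `->`, the new code derives the output subscripts as the letters that occur exactly once, in sorted order. The inference step on its own:

```python
# implicit einsum output: subscript letters that appear exactly once, sorted
formula = "ij,jk"
output = ''.join(c for c in sorted(formula) if formula.count(c) == 1 and c.isalpha())
print(output)  # "ik" -> "ij,jk" behaves like "ij,jk->ik"
```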
@@ -1540,19 +1592,18 @@ class Tensor:
     s_, d_ = make_pair(stride, len(k_)), make_pair(dilation, len(k_))
     assert len(k_) == len(s_) == len(d_), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}"
     noop_, i_ = [None] * len(self.shape[:-len(k_)]), self.shape[-len(k_):]
+    o_ = [math.ceil((i - d * (k-1))/s) for i,d,k,s in zip(i_, d_, k_, s_)]
     if any(k > s for k,s in zip(k_, s_)) or any(d != 1 for d in d_):
-      o_ = [(i - d * (k-1) - 1)//s + 1 for i,d,k,s in zip(i_, d_, k_, s_)]
       # repeats such that we don't need padding
       xup = self.repeat([1]*len(noop_) + [math.ceil(k*(i+d) / i) for k,i,d in zip(k_, i_, d_)])
-      #
+      # handle dilation
       xup = xup.shrink(tuple(noop_ + [(0,k*(i+d)) for k,i,d in zip(k_, i_, d_)])).reshape(noop_ + flatten((k,i+d) for k,i,d in zip(k_, i_, d_)))
       # handle stride
       xup = xup.shrink(noop_ + flatten(((0,k), (0,o*s)) for k,o,s in zip(k_, o_, s_))).reshape(noop_ + flatten((k,o,s) for k,o,s in zip(k_, o_, s_)))
       xup = xup.shrink(noop_ + flatten(((0,k), (0,o), (0,1)) for k,o in zip(k_, o_))).reshape(noop_ + flatten((k,o) for k,o in zip(k_, o_)))
       # permute to move reduce to the end
       return xup.permute(*range(len(noop_)), *[len(noop_)+i*2+1 for i in range(len(i_))], *[len(noop_)+i*2 for i in range(len(i_))])
-    # TODO: once the shapetracker can optimize well, remove this alternative implementation
-    o_ = [(i+(s-k))//s for i,s,k in zip(i_, s_, k_)]
+    # TODO: once the shapetracker can optimize well, remove this alternative implementation
     xup = self.pad(tuple(noop_ + [(0, max(0,o*s-i)) for i,o,s in zip(i_, o_, s_)])).shrink(tuple(noop_ + [(0,o*s) for o,s in zip(o_, s_)]))
     xup = xup.reshape(noop_ + flatten(((o,s) for o,s in zip(o_, s_))))
     xup = xup.shrink(noop_ + flatten(((0,o), (0,k)) for o,k in zip(o_, k_)))
@@ -1572,8 +1623,9 @@ class Tensor:
     print(t.avg_pool2d().numpy())
     ```
     """
-
-
+    kernel_size = make_pair(kernel_size)
+    return self._pool(kernel_size, stride if stride is not None else kernel_size, dilation).mean(axis=tuple(range(-len(kernel_size), 0)))
+
   def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1):
     """
     Applies max pooling over a tensor.
@@ -1587,8 +1639,8 @@ class Tensor:
     print(t.max_pool2d().numpy())
     ```
     """
-
-
+    kernel_size = make_pair(kernel_size)
+    return self._pool(kernel_size, stride if stride is not None else kernel_size, dilation).max(axis=tuple(range(-len(kernel_size), 0)))
 
   def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, acc_dtype:Optional[DType]=None) -> Tensor:
     """
@@ -1665,22 +1717,24 @@ class Tensor:
     print(t.conv_transpose2d(w).numpy())
     ```
     """
-
-
-    stride = make_pair(
+    x, w = self, weight.unflatten(0, (groups, -1)).transpose(1, 2).flip(*range(3, len(weight.shape)+1))
+    HW = weight.shape[2:]
+    stride, dilation, padding, output_padding = [make_pair(x, len(HW)) for x in (stride, dilation, padding, output_padding)]
     if any(s>1 for s in stride):
+      # handle strides: (k) -> reshape -> (k,1) -> pad -> (k,s) -> reshape -> (k*s) -> shrink (k-(s-1))
       x = x.reshape(None, None, *flatten((k,1) for k in x.shape[2:]))
       x = x.pad((None, None, *flatten((None,(0,s-1)) for s in stride)))
       x = x.reshape(None, None, *[k*s for k,s in zip(x.shape[2::2], stride)])
       x = x.shrink((None, None, *[(0,k-(s-1)) for k,s in zip(x.shape[2:], stride)]))
-    padding = flatten((((k-1)*d-p,(k-1)*d-p+op) for k,d,p,op in reversed(list(
-      zip(HW, make_pair(dilation, len(HW)), make_pair(padding, len(HW)), make_pair(output_padding, len(HW)))))))
+    padding = flatten((((k-1)*d-p,(k-1)*d-p+op) for k,d,p,op in reversed(list(zip(HW, dilation, padding, output_padding)))))
     return x.conv2d(w.flatten(end_dim=1), groups=groups, bias=bias, dilation=dilation, padding=padding)
 
   def dot(self, w:Tensor, acc_dtype:Optional[DType]=None) -> Tensor:
     """
     Performs dot product between two tensors.
 
+    You can pass in the optional `acc_dtype` keyword argument to control the data type of the accumulation.
+
     ```python exec="true" source="above" session="tensor" result="python"
     a = Tensor([[1, 2], [3, 4]])
     b = Tensor([[5, 6], [7, 8]])
@@ -1692,7 +1746,7 @@ class Tensor:
     assert (L:=self.shape[-1]) == (R:=w.shape[-min(n2, 2)]), f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({L} != {R})"
     x = self.reshape(*self.shape[0:-1], *[1]*min(n1-1, n2-1, 1), self.shape[-1])
     w = w.reshape(*w.shape[0:-2], *[1]*min(n1-1, n2-1, 1), *w.shape[-min(n2, 2):]).transpose(-1, -min(n2, 2))
-    return (x*w).sum(-1, acc_dtype=acc_dtype).cast(least_upper_dtype(x.dtype, w.dtype))
+    return (x*w).sum(-1, acc_dtype=acc_dtype).cast(least_upper_dtype(x.dtype, w.dtype) if acc_dtype is None else acc_dtype)
 
   def matmul(self, x:Tensor, reverse=False, acc_dtype:Optional[DType]=None) -> Tensor:
     """
@@ -1710,8 +1764,9 @@ class Tensor:
     return x.dot(self, acc_dtype=acc_dtype) if reverse else self.dot(x, acc_dtype=acc_dtype)
 
   def _cumsum(self, axis:int=0, _first_zero=False) -> Tensor:
-
-
+    assert self.shape[axis] != 0
+    pl_sz = self.shape[axis] - int(not _first_zero)
+    return self.transpose(axis,-1).pad2d((pl_sz,-int(_first_zero)))._pool((self.shape[axis],)).sum(-1).transpose(axis,-1)
   def cumsum(self, axis:int=0) -> Tensor:
     """
     Computes the cumulative sum of the tensor along the specified axis.
@@ -1726,48 +1781,74 @@ class Tensor:
     print(t.cumsum(1).numpy())
     ```
     """
+    axis = self._resolve_dim(axis)
+    if self.ndim == 0 or 0 in self.shape: return self
     # TODO: someday the optimizer will find this on it's own
     # for now this is a two stage cumsum
     SPLIT = 256
     if self.shape[axis] <= SPLIT*2: return self._cumsum(axis)
     ret = self.transpose(axis,-1).pad2d((round_up(self.shape[axis], SPLIT)-self.shape[axis], 0))
     ret = ret.unflatten(-1, (-1, SPLIT))._cumsum(-1)
-    base_add = ret[..., -1]._cumsum(-1, _first_zero=True)
+    base_add = ret[..., -1]._cumsum(-1, _first_zero=True)
     base_add = base_add.unsqueeze(-1).expand(*base_add.shape, ret.shape[-1])
     def fix(x:Tensor): return x.flatten(start_dim=-2)[..., -self.shape[axis]:].transpose(axis,-1)
     return fix(ret) + fix(base_add)
 
   @staticmethod
-  def _tri(r:sint, c:sint,
-    assert
-    if r == 0: return Tensor.zeros(
-
-
+  def _tri(r:sint, c:sint, diagonal:int=0, **kwargs) -> Tensor:
+    assert isinstance(r, int) and isinstance(c, int), f"does not support symbolic, getting {r=}, {c=}"
+    if r == 0 or c == 0 or diagonal >= c: return Tensor.zeros(r,c,**kwargs)
+    if r+diagonal <= 0: return Tensor.ones(r,c,**kwargs)
+    s = r+c-1
+    # build a (s, s) upper triangle
+    t = Tensor.ones(s,s,**kwargs).pad((None,(0,s))).flatten().shrink(((0,s*(2*s-1)),)).reshape(s,-1).shrink((None,(0,s)))
+    return t[:r,-diagonal:c-diagonal] if diagonal <= 0 else t[diagonal:r+diagonal,:c]
+
+  def triu(self, diagonal:int=0) -> Tensor:
     """
     Returns the upper triangular part of the tensor, the other elements are set to 0.
 
+    The argument `diagonal` determines which diagonal is on the boundary. `diagonal = 0` means the main diagonal.
+    Positive `diagonal` means above the main diagonal, and negative `diagonal` means below the main diagonal.
+
     ```python exec="true" source="above" session="tensor" result="python"
-    t = Tensor([[1, 2, 3
+    t = Tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
     print(t.numpy())
     ```
     ```python exec="true" source="above" session="tensor" result="python"
-    print(t.triu(
+    print(t.triu(diagonal=0).numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.triu(diagonal=1).numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.triu(diagonal=-1).numpy())
     ```
     """
-    return Tensor._tri(self.shape[-2], self.shape[-1],
-
+    return Tensor._tri(self.shape[-2], self.shape[-1], diagonal=diagonal, device=self.device, dtype=dtypes.bool).where(self, 0).cast(self.dtype)
+
+  def tril(self, diagonal:int=0) -> Tensor:
     """
     Returns the lower triangular part of the tensor, the other elements are set to 0.
 
+    The argument `diagonal` determines which diagonal is on the boundary. `diagonal = 0` means the main diagonal.
+    Positive `diagonal` means above the main diagonal, and negative `diagonal` means below the main diagonal.
+
     ```python exec="true" source="above" session="tensor" result="python"
-    t = Tensor([[1, 2, 3
+    t = Tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
     print(t.numpy())
     ```
     ```python exec="true" source="above" session="tensor" result="python"
-    print(t.tril().numpy())
+    print(t.tril(diagonal=0).numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.tril(diagonal=1).numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.tril(diagonal=-1).numpy())
     ```
     """
-    return Tensor._tri(self.shape[-2], self.shape[-1],
+    return Tensor._tri(self.shape[-2], self.shape[-1], diagonal=diagonal+1, device=self.device, dtype=dtypes.bool).where(0, self).cast(self.dtype)
 
   # ***** unary ops *****
 
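The `diagonal` argument added to `triu`/`tril` appears to follow the same convention as NumPy's `k` (and PyTorch's `diagonal`): 0 is the main diagonal, positive values move above it, negative values below. A NumPy sketch of what each call in the new docstrings keeps, assuming that convention holds:

```python
import numpy as np

a = np.arange(1, 13).reshape(3, 4)   # same values as the docstring example
print(np.triu(a, k=1))               # what t.triu(diagonal=1) would keep
print(np.tril(a, k=-1))              # what t.tril(diagonal=-1) would keep
```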
@@ -1779,7 +1860,7 @@ class Tensor:
     print(Tensor([False, True]).logical_not().numpy())
     ```
     """
-    return F.
+    return F.Neq.apply(*self.cast(dtypes.bool)._broadcasted(True))
   def neg(self):
     """
     Negates the tensor element-wise.
@@ -2179,7 +2260,7 @@ class Tensor:
     print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).gelu().numpy())
     ```
     """
-    return 0.5 * self * (1 + (
+    return 0.5 * self * (1 + (math.sqrt(2 / math.pi) * (self + 0.044715 * self ** 3)).tanh())
 
   def quick_gelu(self):
     """
@@ -2246,23 +2327,26 @@ class Tensor:
     return self / (1 + self.abs())
 
   # ***** broadcasted elementwise ops *****
-  def _broadcast_to(self, shape:Tuple[sint, ...]):
-
-    if self.ndim > len(shape)
-
-
-
-
+  def _broadcast_to(self, shape:Tuple[sint, ...]) -> Tensor:
+    if self.shape == shape: return self
+    if self.ndim > len(shape): raise ValueError(f"cannot broadcast tensor to fewer dimensions. shape={self.shape} to {shape=}")
+    # first pad left with 1s https://data-apis.org/array-api/latest/API_specification/broadcasting.html
+    padded, _ = _pad_left(self.shape, shape)
+    # for each dimension, check either from_ is 1, or it does not change
+    if any(from_ != 1 and from_ != to for from_,to in zip(padded, shape)): raise ValueError(f"cannot broadcast from shape={self.shape} to {shape=}")
+    return F.Expand.apply(self.reshape(padded), shape=shape)
+
+  def _broadcasted(self, y:Union[Tensor, Node, ConstType], reverse:bool=False, match_dtype:bool=True) -> Tuple[Tensor, Tensor]:
     x: Tensor = self
     if not isinstance(y, Tensor):
       # make y a Tensor
       assert isinstance(y, (float, int, bool, Node)), f"{type(y)=}, {y=}"
-      if isinstance(
-
-      if isinstance(y, Node): y = Tensor.from_node(y, device=
-      else: y = Tensor(dtypes.as_const(y, y_dtype),
+      if isinstance(x.dtype, ImageDType) or dtypes.is_float(x.dtype) or (dtypes.is_int(x.dtype) and isinstance(y, int)): y_dtype = x.dtype
+      elif not isinstance(y, Node): y_dtype = dtypes.from_py(y)
+      if isinstance(y, Node): y = Tensor.from_node(y, device=x.device)
+      else: y = Tensor(dtypes.as_const(y, y_dtype), x.device, y_dtype, requires_grad=False)
 
-    if match_dtype:
+    if match_dtype and x.dtype != y.dtype:
       output_dtype = least_upper_dtype(x.dtype, y.dtype)
       x, y = x.cast(output_dtype), y.cast(output_dtype)
 
@@ -2273,7 +2357,6 @@ class Tensor:
     return x._broadcast_to(out_shape), y._broadcast_to(out_shape)
 
   def _to_const_val(self, x:Union[Tensor, ConstType]) -> Union[Tensor, ConstType]:
-    # TODO: update with multi
     return x.lazydata.base.arg if isinstance(x, Tensor) and isinstance(x.lazydata, LazyBuffer) and x.lazydata.is_unrealized_unmasked_const() \
       and not x.requires_grad and self._broadcasted(x)[0].shape == self.shape else x
 
@@ -2315,7 +2398,8 @@ class Tensor:
     print(t.sub(Tensor([[2.0], [3.5]])).numpy())
     ```
     """
-
+    a, b = self._broadcasted(x, reverse)
+    return a + (-b)
 
   def mul(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
     """
@@ -2528,8 +2612,8 @@ class Tensor:
   def __gt__(self, x) -> Tensor: return F.Less.apply(*self._broadcasted(x, True))
   def __ge__(self, x) -> Tensor: return (self<x).logical_not()
   def __le__(self, x) -> Tensor: return (self>x).logical_not()
-  def
-  def
+  def __ne__(self, x) -> Tensor: return F.Neq.apply(*self._broadcasted(x)) # type: ignore[override]
+  def __eq__(self, x) -> Tensor: return (self!=x).logical_not() # type: ignore[override]
 
   # ***** functional nn ops *****
 
@@ -2652,8 +2736,8 @@ class Tensor:
     assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
     if is_causal: attn_mask = Tensor.ones(self.shape[-2], key.shape[-2], requires_grad=False, device=self.device).tril(0).cast(dtypes.bool)
     if attn_mask is not None and attn_mask.dtype == dtypes.bool: attn_mask = (attn_mask == 0).where(-float("inf"), 0)
-    qk = self
-    return ((qk+attn_mask) if attn_mask is not None else qk).softmax(-1).dropout(dropout_p) @ value
+    qk = self.matmul(key.transpose(-2,-1), acc_dtype=least_upper_dtype(self.dtype, key.dtype, dtypes.float32)) / math.sqrt(self.shape[-1])
+    return ((qk+attn_mask) if attn_mask is not None else qk).softmax(-1).cast(self.dtype).dropout(dropout_p) @ value
 
   def binary_crossentropy(self, y:Tensor) -> Tensor:
     """
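The reworked `qk` line computes `softmax(q @ k^T / sqrt(d))` with a float32 (or wider) accumulator before casting back to the input dtype. A NumPy reference of that computation, as a sketch for checking shapes rather than tinygrad's exact numerics:

```python
import numpy as np

def sdpa_reference(q, k, v, mask=None):
  # softmax(q @ k^T / sqrt(d)) @ v
  qk = q @ np.swapaxes(k, -2, -1) / np.sqrt(q.shape[-1])
  if mask is not None: qk = qk + mask
  w = np.exp(qk - qk.max(axis=-1, keepdims=True))
  w = w / w.sum(axis=-1, keepdims=True)
  return w @ v

q = k = v = np.random.randn(2, 4, 8).astype(np.float32)
print(sdpa_reference(q, k, v).shape)  # (2, 4, 8)
```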
@@ -2874,5 +2958,5 @@ def custom_random(out:Buffer):
   Tensor._seed += 1
   rng = np.random.default_rng(Tensor._seed)
   if out.dtype == dtypes.half: rng_np_buffer = (rng.integers(low=0, high=2047, size=out.size) / 2048).astype(np.half, copy=False)
-  else: rng_np_buffer = rng.random(size=out.size, dtype=np.float32).astype(dtype=out.dtype
+  else: rng_np_buffer = rng.random(size=out.size, dtype=np.float32).astype(dtype=_to_np_dtype(out.dtype), copy=False)
   out.copyin(rng_np_buffer.data)