tinygrad 0.8.0__py3-none-any.whl → 0.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. tinygrad/__init__.py +6 -6
  2. tinygrad/codegen/__init__.py +0 -0
  3. tinygrad/codegen/kernel.py +253 -225
  4. tinygrad/codegen/linearizer.py +398 -436
  5. tinygrad/codegen/uops.py +451 -0
  6. tinygrad/device.py +268 -274
  7. tinygrad/dtype.py +56 -40
  8. tinygrad/engine/__init__.py +0 -0
  9. tinygrad/engine/graph.py +100 -0
  10. tinygrad/engine/jit.py +198 -0
  11. tinygrad/engine/realize.py +192 -0
  12. tinygrad/engine/schedule.py +370 -0
  13. tinygrad/engine/search.py +199 -0
  14. tinygrad/{mlops.py → function.py} +40 -32
  15. tinygrad/helpers.py +144 -46
  16. tinygrad/lazy.py +143 -242
  17. tinygrad/multi.py +173 -0
  18. tinygrad/nn/__init__.py +180 -9
  19. tinygrad/nn/datasets.py +8 -0
  20. tinygrad/nn/optim.py +106 -28
  21. tinygrad/nn/state.py +87 -19
  22. tinygrad/ops.py +104 -45
  23. tinygrad/renderer/__init__.py +65 -0
  24. tinygrad/renderer/assembly.py +269 -0
  25. tinygrad/renderer/cstyle.py +308 -210
  26. tinygrad/renderer/llvmir.py +119 -124
  27. tinygrad/runtime/__init__.py +0 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +13403 -0
  29. tinygrad/runtime/autogen/comgr.py +891 -0
  30. tinygrad/runtime/autogen/cuda.py +5923 -0
  31. tinygrad/runtime/autogen/hip.py +5909 -0
  32. tinygrad/runtime/autogen/hsa.py +5893 -0
  33. tinygrad/runtime/autogen/io_uring.py +1486 -0
  34. tinygrad/runtime/autogen/kfd.py +812 -0
  35. tinygrad/runtime/autogen/nv_gpu.py +33597 -0
  36. tinygrad/runtime/autogen/opencl.py +1795 -0
  37. tinygrad/runtime/driver/__init__.py +0 -0
  38. tinygrad/runtime/driver/hip_comgr.py +56 -0
  39. tinygrad/runtime/graph/__init__.py +0 -0
  40. tinygrad/runtime/graph/clang.py +39 -0
  41. tinygrad/runtime/graph/cuda.py +59 -54
  42. tinygrad/runtime/graph/hcq.py +187 -0
  43. tinygrad/runtime/graph/metal.py +37 -41
  44. tinygrad/runtime/ops_amd.py +550 -0
  45. tinygrad/runtime/ops_clang.py +16 -14
  46. tinygrad/runtime/ops_cuda.py +129 -37
  47. tinygrad/runtime/ops_disk.py +111 -43
  48. tinygrad/runtime/ops_gpu.py +52 -50
  49. tinygrad/runtime/ops_llvm.py +36 -56
  50. tinygrad/runtime/ops_metal.py +41 -24
  51. tinygrad/runtime/ops_npy.py +9 -0
  52. tinygrad/runtime/ops_nv.py +625 -0
  53. tinygrad/runtime/ops_python.py +208 -0
  54. tinygrad/shape/__init__.py +0 -0
  55. tinygrad/shape/shapetracker.py +46 -107
  56. tinygrad/shape/symbolic.py +99 -98
  57. tinygrad/shape/view.py +162 -45
  58. tinygrad/tensor.py +2492 -483
  59. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/LICENSE +1 -1
  60. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/METADATA +31 -13
  61. tinygrad-0.9.1.dist-info/RECORD +63 -0
  62. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/WHEEL +1 -1
  63. tinygrad/features/image.py +0 -93
  64. tinygrad/features/multi.py +0 -103
  65. tinygrad/features/search.py +0 -160
  66. tinygrad/graph.py +0 -106
  67. tinygrad/jit.py +0 -152
  68. tinygrad/realize.py +0 -50
  69. tinygrad/runtime/graph/hip.py +0 -24
  70. tinygrad/runtime/ops_cpu.py +0 -45
  71. tinygrad/runtime/ops_hip.py +0 -97
  72. tinygrad/runtime/ops_torch.py +0 -49
  73. tinygrad-0.8.0.dist-info/RECORD +0 -41
  74. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/top_level.txt +0 -0
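Several internal modules moved in this release (mlops.py → function.py, features/multi.py → multi.py, and the scheduling/JIT/realize machinery into the new tinygrad/engine/ package). As an illustrative sketch only, the import lines below show how the old internal paths map to the new ones; they are taken from the tensor.py hunks further down, and whether they are part of the public API is not stated in this diff.

```python
# Illustrative only: module moves implied by the file list above (0.8.0 -> 0.9.1).
#   tinygrad.mlops                  ->  tinygrad.function
#   tinygrad.features.multi         ->  tinygrad.multi
#   tinygrad.realize / tinygrad.jit ->  tinygrad.engine.realize / tinygrad.engine.jit
from tinygrad.multi import MultiLazyBuffer                # was tinygrad.features.multi
from tinygrad.engine.realize import run_schedule          # was tinygrad.realize
from tinygrad.engine.schedule import create_schedule_with_vars, memory_planner  # new in 0.9.x
```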
tinygrad/tensor.py CHANGED
@@ -1,19 +1,21 @@
  # inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
  from __future__ import annotations
- import time, math, itertools
- from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Iterable, DefaultDict, cast, get_args
+ import time, math, itertools, functools, struct
+ from contextlib import ContextDecorator
+ from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Dict, DefaultDict, cast, get_args, Set
  from collections import defaultdict
- from functools import partialmethod, reduce
  import numpy as np

- from tinygrad.dtype import DType, dtypes, ImageDType, least_upper_float, least_upper_dtype
- from tinygrad.helpers import argfix, make_pair, getenv, IMAGE, DEBUG, flatten, prod, all_int, round_up, merge_dicts, fully_flatten
- from tinygrad.lazy import LazyBuffer, create_schedule
- from tinygrad.features.multi import MultiLazyBuffer
- from tinygrad.ops import LoadOps
- from tinygrad.device import Device, Buffer
- from tinygrad.shape.symbolic import sint
- from tinygrad.realize import run_schedule
+ from tinygrad.dtype import DType, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype
+ from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, get_shape, fully_flatten, dedup
+ from tinygrad.helpers import IMAGE, DEBUG, WINO, THREEFRY
+ from tinygrad.lazy import LazyBuffer
+ from tinygrad.multi import MultiLazyBuffer
+ from tinygrad.ops import LoadOps, truncate
+ from tinygrad.device import Device, Buffer, BufferOptions
+ from tinygrad.shape.symbolic import sint, Variable, MulNode, SumNode, NumNode, Node
+ from tinygrad.engine.realize import run_schedule
+ from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars, memory_planner

  # **** start with two base classes, Tensor and Function ****

@@ -30,69 +32,145 @@ class Function:
  @classmethod
  def apply(fxn:Type[Function], *x:Tensor, **kwargs) -> Tensor:
  ctx = fxn(x[0].device, *x)
- ret = Tensor(ctx.forward(*[t.lazydata for t in x], **kwargs), device=ctx.device, requires_grad=ctx.requires_grad)
- if ctx.requires_grad and not Tensor.no_grad: ret._ctx = ctx # used by autograd engine
+ ret = Tensor.__new__(Tensor)
+ ret.lazydata, ret.requires_grad, ret.grad = ctx.forward(*[t.lazydata for t in x], **kwargs), ctx.requires_grad, None
+ ret._ctx = ctx if ctx.requires_grad and not Tensor.no_grad else None # used by autograd engine
  return ret

- import tinygrad.mlops as mlops
+ import tinygrad.function as F

- def _loadop(op, shape:Tuple[sint,...], dtype:DType, device:Union[str, Tuple[str, ...]], arg=None, src:Optional[LazyBuffer]=None):
+ def _loadop(op, shape:Tuple[sint,...], dtype:DType, device:Union[str, Tuple[str, ...]], arg=None, src:Tuple[LazyBuffer, ...]=()):
  if isinstance(device, str): return LazyBuffer.loadop(op, shape, dtype, device, arg, src)
  return MultiLazyBuffer([LazyBuffer.loadop(op, shape, dtype, d, arg, src) for d in device], None)

- Scalar = Union[float, int, bool]
+ def _from_np_dtype(npdtype:type) -> DType: return dtypes.fields()[np.dtype(npdtype).name]
+ def _to_np_dtype(dtype:DType) -> Optional[type]: return np.dtype(dtype.fmt).type if dtype.fmt is not None else None
+
+ def _fromnp(x: np.ndarray) -> LazyBuffer:
+ ret = LazyBuffer.loadop(LoadOps.EMPTY, x.shape, _from_np_dtype(x.dtype), "NPY")
+ # fake realize
+ ret.buffer.allocate(x)
+ del ret.srcs
+ return ret
+
+ def _frompy(x:Union[List, Tuple, bytes], dtype:DType) -> LazyBuffer:
+ if isinstance(x, bytes): ret, data = LazyBuffer.loadop(LoadOps.EMPTY, (len(x),), dtype, "PYTHON"), x
+ else:
+ ret = LazyBuffer.loadop(LoadOps.EMPTY, get_shape(x), dtype, "PYTHON")
+ assert dtype.fmt is not None, f"{dtype=} has None fmt"
+ truncate_function = truncate[dtype]
+ data = struct.pack(f"@{ret.size}{dtype.fmt}", *[truncate_function(xi) for xi in fully_flatten(x)])
+ # fake realize
+ ret.buffer.allocate(memoryview(data))
+ del ret.srcs
+ return ret
+
+ def _get_winograd_matcols(mat, dims:int, shp:Tuple[sint, ...], device:Union[str, Tuple[str, ...]]) -> List[List[Tensor]]:
+ return [[Tensor.cat(*[Tensor.full(shp[:dim] + (1,) + shp[dim+1:], float(m[k]), device=device) for m in mat], dim=dim)
+ for k in range(len(mat[0]))] for dim in range(dims)]
+
+ # winograd conv 3 kernel f(4x4,3x3) see: http://arxiv.org/abs/1509.09308
+ def _apply_winograd_matrix(mat, t:Tensor, dims:int) -> Tensor:
+ # multiply mat_1 @ mat_2 @ t with foldable constants, where mat_i acts on vector t along dimension i; roughly kron(mat, mat) @ t
+ # due to realize-before-expand rule in lazy.py, we must operate in this order: reshape -> expand -> arithmetic
+ t_ = t.reshape(t.shape[:dims] + (1,) * dims + t.shape[dims:]).expand(t.shape[:dims] + (len(mat),) * dims + t.shape[dims:]) # add output dims
+ # precalculate mat columns for each dim; prod(itertools.product(matcols)) gives the columns of kron(mat, mat, ...)
+ matcols = _get_winograd_matcols(mat, dims, t_.shape[dims:], t_.device)
+ # multiply each element of t_ by the corresponding stacked column of kron(mat, mat), producing only one view for each element of t
+ ret = sum(prod(col[idx] for col, idx in zip(matcols, mat_is)) * t_[mat_is] for mat_is in itertools.product(range(len(mat[0])), repeat=dims))
+ assert isinstance(ret, Tensor), "sum didn't return a Tensor"
+ return ret
+
+ def _pad_left(*shapes:Tuple[sint, ...]) -> Tuple[Tuple[sint, ...], ...]:
+ max_dim = max(len(shape) for shape in shapes)
+ return tuple((1,) * (max_dim - len(shape)) + shape for shape in shapes)
+ def _broadcast_shape(*shapes:Tuple[sint, ...]) -> Tuple[sint, ...]:
+ return tuple(0 if any(size == 0 for size in nth_dim_sizes) else max(nth_dim_sizes) for nth_dim_sizes in zip(*_pad_left(*shapes)))

  class Tensor:
+ """
+ A `Tensor` is a multi-dimensional matrix containing elements of a single data type.
+
+ ```python exec="true" session="tensor"
+ from tinygrad import Tensor, dtypes, nn
+ import numpy as np
+ import math
+ np.set_printoptions(precision=4)
+ ```
+ """
  __slots__ = "lazydata", "requires_grad", "grad", "_ctx"
  __deletable__ = ('_ctx',)
  training: ClassVar[bool] = False
- class train:
- def __init__(self, val=True): self.val = val
- def __enter__(self): self.prev, Tensor.training = Tensor.training, self.val
- def __exit__(self, exc_type, exc_value, traceback): Tensor.training = self.prev
-
  no_grad: ClassVar[bool] = False
- def __init__(self, data:Union[None, Scalar, List, Tuple, LazyBuffer, np.ndarray, bytes, MultiLazyBuffer],
+
+ def __init__(self, data:Union[None, ConstType, List, Tuple, LazyBuffer, np.ndarray, bytes, MultiLazyBuffer, Variable],
  device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
  assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
  device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device)
- # tensors have gradients, buffers do not
+
+ # tensors can have gradients if you have called .backward
  self.grad: Optional[Tensor] = None

  # NOTE: this can be in three states. False and None: no gradient, True: gradient
  # None (the default) will be updated to True if it's put in an optimizer
  self.requires_grad: Optional[bool] = requires_grad

- # internal variables used for autograd graph construction
+ # internal variable used for autograd graph construction
  self._ctx: Optional[Function] = None
+
+ # create a LazyBuffer from the different types of inputs
  if isinstance(data, LazyBuffer): assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
- elif isinstance(data, get_args(Scalar)): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_py(data), device, data)
- elif isinstance(data, bytes): data = LazyBuffer.fromCPU(np.frombuffer(data, np.uint8))
+ elif isinstance(data, get_args(ConstType)): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_py(data), device, data)
+ elif isinstance(data, Variable): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_py(data.unbind()[1]), device, data)
+ elif isinstance(data, bytes): data = _frompy(data, dtypes.uint8)
+ elif isinstance(data, (list, tuple)):
+ if dtype is None:
+ if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtypes.bool
+ else: dtype = dtypes.default_int if d and all_int(d) else dtypes.default_float
+ if dtype == dtypes.bfloat16: data = Tensor(_fromnp(np.array(data, np.float32)), device=device).cast(dtypes.bfloat16).lazydata
+ else: data = _fromnp(np.array(data).astype(_to_np_dtype(dtype)))
  elif data is None: data = _loadop(LoadOps.EMPTY, (0,), dtype or dtypes.default_float, device)
- elif isinstance(data, list):
- if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtype or dtypes.bool
- elif d and all_int(d): dtype = dtype or dtypes.default_int
- else: dtype = dtype or dtypes.default_float
- # NOTE: cast at the end for the dtypes that do not have a numpy dtype
- data = LazyBuffer.fromCPU(np.array(data, dtype.np)).cast(dtype)
  elif isinstance(data, np.ndarray):
- if data.shape == (): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_np(data.dtype), device, data.item())
- else: data = LazyBuffer.fromCPU(data.astype(dtype.np) if dtype is not None and dtype.np is not None else data)
+ if data.shape == (): data = _loadop(LoadOps.CONST, tuple(), dtype or _from_np_dtype(data.dtype), device, data.item())
+ else: data = _fromnp(data.astype(npdtype) if dtype is not None and (npdtype:=_to_np_dtype(dtype)) is not None else data)
+
+ # by this point, it has to be a LazyBuffer
+ if not isinstance(data, (LazyBuffer, MultiLazyBuffer)):
+ raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}")

  # data is a LazyBuffer, but it might be on the wrong device
- if not isinstance(data, (LazyBuffer, MultiLazyBuffer)): raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}")
  if isinstance(device, tuple):
- # TODO: what if it's a MultiLazyBuffer on other devices?
- self.lazydata: Union[LazyBuffer, MultiLazyBuffer] = MultiLazyBuffer.from_sharded(data, device, None) if isinstance(data, LazyBuffer) else data
+ # if device is a tuple, we should have/construct a MultiLazyBuffer
+ if isinstance(data, MultiLazyBuffer):
+ assert data.device == device, f"MultiLazyBuffer device mismatch, {data.device} != {device}"
+ self.lazydata: Union[LazyBuffer, MultiLazyBuffer] = data
+ else:
+ self.lazydata = MultiLazyBuffer.from_sharded(data, device, None)
  else:
  self.lazydata = data if data.device == device else data.copy_to_device(device)

+ class train(ContextDecorator):
+ def __init__(self, mode:bool = True): self.mode = mode
+ def __enter__(self): self.prev, Tensor.training = Tensor.training, self.mode
+ def __exit__(self, exc_type, exc_value, traceback): Tensor.training = self.prev
+
+ class inference_mode(ContextDecorator):
+ def __init__(self, mode:bool = True): self.mode = mode
+ def __enter__(self): self.prev, Tensor.no_grad = Tensor.no_grad, self.mode
+ def __exit__(self, exc_type, exc_value, traceback): Tensor.no_grad = self.prev
+
  def __repr__(self):
- return f"<Tensor {self.lazydata!r} on {self.device} with grad {(self.grad.lazydata if self.grad else None)!r}>"
+ return f"<Tensor {self.lazydata!r} on {self.device} with grad {(self.grad.lazydata if self.grad is not None else None)!r}>"

  # Python has a non moving GC, so this should be okay
  def __hash__(self): return id(self)

+ def __bool__(self): raise TypeError("__bool__ on Tensor is not defined")
+
+ def __len__(self):
+ if not self.shape: raise TypeError("len() of a 0-d tensor")
+ return self.shape[0]
+
  @property
  def device(self) -> Union[str, Tuple[str, ...]]: return self.lazydata.device

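The hunk above also rewrites `Tensor.train` (and adds `Tensor.inference_mode`) as `contextlib.ContextDecorator` subclasses, so the `Tensor.training` and `Tensor.no_grad` flags can now be toggled either with a `with` block or as a decorator. A minimal usage sketch; the `step` helper and its body are hypothetical:

```python
from tinygrad import Tensor

# As a context manager: Tensor.training is set to True inside the block and restored on exit.
with Tensor.train():
    assert Tensor.training

# As a decorator (possible because of the new ContextDecorator base):
# Tensor.no_grad is True for the duration of the decorated call.
@Tensor.inference_mode()
def step(x: Tensor) -> Tensor:  # hypothetical helper
    return (x * 2).sum()

print(step(Tensor([1.0, 2.0, 3.0])).item())
```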
@@ -104,148 +182,522 @@ class Tensor:

  # ***** data handlers ****

- @staticmethod
- def corealize(lst:Iterable[Tensor]):
- return run_schedule(create_schedule(flatten([x.lazydata.lbs if isinstance(x.lazydata, MultiLazyBuffer) else [x.lazydata] for x in lst])))
+ def schedule_with_vars(self, *lst:Tensor, seen:Optional[Set[LazyBuffer]]=None) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
+ """Creates the schedule needed to realize these Tensor(s), with Variables."""
+ if getenv("FUZZ_SCHEDULE"):
+ from test.external.fuzz_schedule import fuzz_schedule
+ fuzz_schedule(flatten([x.lazydata.lbs for x in (self,)+lst]))
+ schedule, var_vals = create_schedule_with_vars(flatten([x.lazydata.lbs for x in (self,)+lst]), seen)
+ return memory_planner(schedule), var_vals
+
+ def schedule(self, *lst:Tensor, seen:Optional[Set[LazyBuffer]]=None) -> List[ScheduleItem]:
+ """Creates the schedule needed to realize these Tensor(s)."""
+ schedule, var_vals = self.schedule_with_vars(*lst, seen=seen)
+ assert len(var_vals) == 0
+ return schedule
+
+ def realize(self, *lst:Tensor, do_update_stats=True) -> Tensor:
+ """Triggers the computation needed to create these Tensor(s)."""
+ run_schedule(*self.schedule_with_vars(*lst), do_update_stats=do_update_stats)
+ return self

- def realize(self) -> Tensor:
- run_schedule(self.lazydata.schedule())
+ def replace(self, x:Tensor) -> Tensor:
+ """
+ Replaces the data of this tensor with the data of another tensor. Only the shape of the tensors must match.
+ """
+ # used for replacing a Tensor with a new version of it (potentially with a different device and dtype)
+ assert not x.requires_grad and getattr(self, '_ctx', None) is None
+ assert self.shape == x.shape, f"replace shape mismatch {self.shape} != {x.shape}"
+ self.lazydata = x.lazydata
  return self

  def assign(self, x) -> Tensor:
  # TODO: this is a hack for writing to DISK. remove with working assign
  if isinstance(self.device, str) and self.device.startswith("DISK"):
- if x.__class__ is not Tensor: x = Tensor(x, device="CPU", dtype=self.dtype)
+ if x.__class__ is not Tensor: x = Tensor(x, device="NPY", dtype=self.dtype)
  self.contiguous().realize().lazydata.base.realized.copyin(x.numpy().data)
  return self
  if x.__class__ is not Tensor: x = Tensor(x, device=self.device, dtype=self.dtype)
+ if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}")
+ if self.lazydata is x.lazydata: return self # a self assign is a NOOP
  # NOTE: we allow cross device assign
  assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}"
+ assert self.device == x.device, f"assign device mismatch {self.device} != {x.device}"
+ assert self.dtype == x.dtype, f"assign dtype mismatch {self.dtype} != {x.dtype}"
+ assert not isinstance(self.lazydata, MultiLazyBuffer) or self.lazydata.axis == x.lazydata.axis, "axis must match on MultiLazyBuffer"
  assert not x.requires_grad # self requires_grad is okay?
- if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}")
- if self.dtype == x.dtype and not getenv("DISALLOW_ASSIGN"):
- if isinstance(self.lazydata, MultiLazyBuffer):
- for d,s in zip(x.lazydata.lbs, self.lazydata.lbs): d.output_buffer = s.base.realized
- else:
- if self.lazydata.base.realized is not None: x.lazydata.output_buffer = self.lazydata.base.realized
- self.lazydata = x.lazydata
+ if not self.lazydata.is_realized(): return self.replace(x)
+ self.lazydata = self.lazydata.assign(x.lazydata)
  return self
- def detach(self) -> Tensor: return Tensor(self.lazydata, device=self.device, requires_grad=False)

- # TODO: these are good places to start removing numpy
- def item(self) -> Scalar:
+ def detach(self) -> Tensor:
+ """
+ Returns a new tensor with the same data as this tensor, but detached from the autograd graph.
+ """
+ return Tensor(self.lazydata, device=self.device, requires_grad=False)
+
+ def _data(self) -> memoryview:
+ if 0 in self.shape: return memoryview(bytearray(0))
+ # NOTE: this realizes on the object from as_buffer being a Python object
+ cpu = self.cast(self.dtype.scalar()).contiguous().to("CLANG").realize()
+ buf = cast(Buffer, cast(LazyBuffer, cpu.lazydata).base.realized)
+ if self.device != "CLANG": buf.options = BufferOptions(nolru=True)
+ return buf.as_buffer(allow_zero_copy=True if self.device != "CLANG" else False)
+
+ def data(self) -> memoryview:
+ """
+ Returns the data of this tensor as a memoryview.
+
+ ```python exec="true" source="above" session="tensor" result="python"
+ t = Tensor([1, 2, 3, 4])
+ print(np.frombuffer(t.data(), dtype=np.int32))
+ ```
+ """
+ assert self.dtype.fmt is not None, f"no fmt dtype for {self.dtype}"
+ assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
+ return self._data().cast(self.dtype.fmt, self.shape)
+
+ def item(self) -> ConstType:
+ """
+ Returns the value of this tensor as a standard Python number.
+
+ ```python exec="true" source="above" session="tensor" result="python"
+ t = Tensor(42)
+ print(t.item())
+ ```
+ """
+ assert self.dtype.fmt is not None, f"no fmt dtype for {self.dtype}"
  assert self.numel() == 1, "must have one element for item"
- return cast(Buffer, self.contiguous().realize().lazydata.base.realized).toCPU().item()
- def data(self) -> memoryview: return self.numpy().data
+ return self._data().cast(self.dtype.fmt)[0]
+
+ # TODO: should be Tensor.tolist() -> Union[List[ConstType], ConstType]. The List is Sequence because mypy expects memoryview.tolist() -> list[int]
+ # src: https://github.com/python/mypy/blob/release-1.6/mypy/typeshed/stdlib/builtins.pyi#L803
+ def tolist(self) -> Union[Sequence[ConstType], ConstType]:
+ """
+ Returns the value of this tensor as a nested list.
+
+ ```python exec="true" source="above" session="tensor" result="python"
+ t = Tensor([1, 2, 3, 4])
+ print(t.tolist())
+ ```
+ """
+ return self.data().tolist()

- # TODO: this should import numpy and use .data() to construct the array
  def numpy(self) -> np.ndarray:
- assert all_int(self.shape), f"no numpy if shape is symbolic, {self.shape=}"
- assert self.dtype.np is not None, f"no numpy dtype for {self.dtype}"
- if 0 in self.shape: return np.zeros(self.shape, dtype=self.dtype.np)
- t = self if isinstance(self.device, str) else self.to("CPU")
- return t.cast(self.dtype.scalar()).contiguous().realize().lazydata.base.realized.toCPU().astype(self.dtype.np, copy=True).reshape(self.shape)
-
- def to(self, device:Optional[str]) -> Tensor:
- if device is None or device == self.device: return self
- ret = Tensor(self.lazydata, device)
- if self.grad: ret.grad = self.grad.to(device)
+ """
+ Returns the value of this tensor as a `numpy.ndarray`.
+
+ ```python exec="true" source="above" session="tensor" result="python"
+ t = Tensor([1, 2, 3, 4])
+ print(repr(t.numpy()))
+ ```
+ """
+ if self.dtype == dtypes.bfloat16: return self.float().numpy()
+ assert _to_np_dtype(self.dtype) is not None, f"no np dtype for {self.dtype}"
+ assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
+ return np.frombuffer(self._data(), dtype=_to_np_dtype(self.dtype)).reshape(self.shape)
+
+ def to(self, device:Optional[Union[str, Tuple[str, ...]]]) -> Tensor:
+ """
+ Moves the tensor to the given device.
+ """
+ device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device)
+ if device == self.device: return self
+ if not isinstance(device, str): return self.shard(device)
+ ret = Tensor(self.lazydata, device, requires_grad=self.requires_grad)
+ if self.grad is not None: ret.grad = self.grad.to(device)
+ if hasattr(self, '_ctx'): ret._ctx = self._ctx
  return ret

- def to_(self, device:Optional[str]):
- if device is None or device == self.device: return
- if self.grad: self.grad = self.grad.to_(device)
- _ret = Tensor(self.lazydata, device)
- self.lazydata = _ret.lazydata
+ def to_(self, device:Optional[Union[str, Tuple[str, ...]]]):
+ """
+ Moves the tensor to the given device in place.
+ """
+ real = self.to(device)
+ # TODO: is this assign?
+ if self.grad is not None and real.grad is not None: self.grad.lazydata = real.grad.lazydata
+ self.lazydata = real.lazydata

  def shard(self, devices:Tuple[str, ...], axis:Optional[int]=None) -> Tensor:
+ """
+ Shards the tensor across the given devices.
+ """
  assert isinstance(self.lazydata, LazyBuffer), "can't shard a MultiLazyBuffer"
  canonical_devices = tuple(Device.canonicalize(x) for x in devices)
- return Tensor(MultiLazyBuffer.from_sharded(self.lazydata, canonical_devices, axis), device=canonical_devices)
+ if axis is not None and axis < 0: axis += len(self.shape)
+ return Tensor(MultiLazyBuffer.from_sharded(self.lazydata, canonical_devices, axis), device=canonical_devices, requires_grad=self.requires_grad)

  def shard_(self, devices:Tuple[str, ...], axis:Optional[int]=None):
+ """
+ Shards the tensor across the given devices in place.
+ """
  self.lazydata = self.shard(devices, axis).lazydata
  return self

+ @staticmethod
+ def from_node(y:Node, **kwargs) -> Tensor:
+ if isinstance(y, NumNode): return Tensor(y.b, **kwargs, requires_grad=False)
+ if isinstance(y, Variable): return Tensor(y, **kwargs, requires_grad=False)
+ if isinstance(y, MulNode): return Tensor.from_node(y.a, **kwargs) * y.b
+ if isinstance(y, SumNode): return Tensor.from_node(y.nodes[0], **kwargs) + sum(y.nodes[1:])
+ raise RuntimeError(f"unhandled Node {y}")
+
  # ***** creation llop entrypoint *****

  @staticmethod
- def _loadop(op, shape, device:Optional[str]=None, dtype:Optional[DType]=None, arg=None, **kwargs):
- return Tensor(LazyBuffer.loadop(op, shape, dtype or dtypes.default_float, Device.canonicalize(device), arg), dtype=dtype, device=device, **kwargs)
+ def _loadop(op, shape, device:Optional[Union[Tuple[str, ...], str]]=None, dtype:Optional[DType]=None, arg=None, **kwargs):
+ if isinstance(device, tuple):
+ return Tensor(MultiLazyBuffer([LazyBuffer.loadop(op, shape, dtype or dtypes.default_float, Device.canonicalize(d), arg) \
+ for d in device], None), device, dtype, **kwargs)
+ return Tensor(LazyBuffer.loadop(op, shape, dtype or dtypes.default_float, Device.canonicalize(device), arg), device, dtype, **kwargs)

  @staticmethod
- def empty(*shape, **kwargs): return Tensor._loadop(LoadOps.EMPTY, argfix(*shape), **kwargs)
+ def empty(*shape, **kwargs):
+ """
+ Creates an empty tensor with the given shape.
+
+ You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+ ```python exec="true" source="above" session="tensor" result="python"
+ t = Tensor.empty(2, 3)
+ print(t.shape)
+ ```
+ """
+ return Tensor._loadop(LoadOps.EMPTY, argfix(*shape), **kwargs)

  _seed: int = int(time.time())
+ _rng_counter: Optional[Tensor] = None
  @staticmethod
- def manual_seed(seed=0): Tensor._seed = seed
+ def manual_seed(seed=0):
+ """
+ Sets the seed for random operations.
+
+ ```python exec="true" source="above" session="tensor" result="python"
+ Tensor.manual_seed(42)
+ print(Tensor.rand(5).numpy())
+ print(Tensor.rand(5).numpy())
+ ```
+ ```python exec="true" source="above" session="tensor" result="python"
+ Tensor.manual_seed(42) # reset to the same seed
+ print(Tensor.rand(5).numpy())
+ print(Tensor.rand(5).numpy())
+ ```
+ """
+ Tensor._seed, Tensor._rng_counter = seed, Tensor([0], dtype=dtypes.uint32, requires_grad=False)

  @staticmethod
- def rand(*shape, **kwargs): return Tensor._loadop(LoadOps.CUSTOM, argfix(*shape), arg=custom_random, **kwargs)
+ def rand(*shape, device:Optional[Union[Tuple[str, ...], str]]=None, dtype:Optional[DType]=None, **kwargs):
+ """
+ Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[0, 1)`.
+
+ You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+ ```python exec="true" source="above" session="tensor" result="python"
+ Tensor.manual_seed(42)
+ t = Tensor.rand(2, 3)
+ print(t.numpy())
+ ```
+ """
+ if Tensor._rng_counter is None: Tensor._rng_counter = Tensor([0], dtype=dtypes.uint32, requires_grad=False)
+ if not THREEFRY.value:
+ # for bfloat16, numpy rand passes buffer in float
+ if (dtype or dtypes.default_float) == dtypes.bfloat16:
+ return Tensor.rand(*shape, **kwargs, device=device, dtype=dtypes.float).cast(dtypes.bfloat16)
+ return Tensor._loadop(LoadOps.CUSTOM, argfix(*shape), arg=custom_random, device=device, dtype=dtype, **kwargs)
+
+ # threefry
+ if (num := prod((shape:=argfix(*shape)))) == 0: return Tensor.zeros(shape, device=device, dtype=dtype, **kwargs)
+ counts1 = (Tensor.arange(math.ceil(num / 2), device=device, dtype=dtypes.uint32, requires_grad=False)+Tensor._rng_counter.to(device)).realize()
+ counts2 = counts1 + math.ceil(num / 2)
+ Tensor._rng_counter.assign(Tensor._rng_counter + num).realize()
+
+ rotations = [[13, 15, 26, 6], [17, 29, 16, 24]]
+ ks = [0x0, Tensor._seed ^ 0x0 ^ 0x1BD11BDA, Tensor._seed]
+
+ x = [counts1 + ks[-1], counts2 + ks[0]]
+ for i in range(5):
+ for r in rotations[i % 2]: x[0], x[1] = (x0 := x[0] + x[1]), x0 ^ ((x[1] << r) + (x[1] >> (32 - r)))
+ x = [(x[0] + ks[i % 3]), (x[1] + ks[(i + 1) % 3] + i + 1)]
+ out = x[0].cat(x[1]).rshift(8).cast(dtypes.float32).div(2 ** 24)[:num]
+ out = out.reshape(shape).cast(dtypes.default_float if dtype is None else dtype)
+ out.requires_grad = kwargs.get("requires_grad")
+ return out.contiguous()

  # ***** creation helper functions *****

  @staticmethod
- def full(shape:Tuple[sint, ...], fill_value:Scalar, **kwargs):
+ def full(shape:Tuple[sint, ...], fill_value:ConstType, **kwargs):
+ """
+ Creates a tensor with the given shape, filled with the given value.
+
+ You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+ ```python exec="true" source="above" session="tensor" result="python"
+ print(Tensor.full((2, 3), 42).numpy())
+ ```
+ ```python exec="true" source="above" session="tensor" result="python"
+ print(Tensor.full((2, 3), False).numpy())
+ ```
+ """
  return Tensor(fill_value, **kwargs).reshape((1, )*len(new_shape := argfix(shape))).expand(new_shape)

  @staticmethod
- def zeros(*shape, **kwargs): return Tensor.full(argfix(*shape), 0.0, **kwargs)
+ def zeros(*shape, **kwargs):
+ """
+ Creates a tensor with the given shape, filled with zeros.
+
+ You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+ ```python exec="true" source="above" session="tensor" result="python"
+ print(Tensor.zeros(2, 3).numpy())
+ ```
+ ```python exec="true" source="above" session="tensor" result="python"
+ print(Tensor.zeros(2, 3, dtype=dtypes.int32).numpy())
+ ```
+ """
+ return Tensor.full(argfix(*shape), 0.0, **kwargs)

  @staticmethod
- def ones(*shape, **kwargs): return Tensor.full(argfix(*shape), 1.0, **kwargs)
+ def ones(*shape, **kwargs):
+ """
+ Creates a tensor with the given shape, filled with ones.
+
+ You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+ ```python exec="true" source="above" session="tensor" result="python"
+ print(Tensor.ones(2, 3).numpy())
+ ```
+ ```python exec="true" source="above" session="tensor" result="python"
+ print(Tensor.ones(2, 3, dtype=dtypes.int32).numpy())
+ ```
+ """
+ return Tensor.full(argfix(*shape), 1.0, **kwargs)

  @staticmethod
  def arange(start, stop=None, step=1, **kwargs):
+ """
+ Returns a 1-D tensor of size `ceil((stop - start) / step)` with values from `[start, stop)`, with spacing between values given by `step`.
+
+ If `stop` is not specified, values are generated from `[0, start)` with the given `step`.
+
+ If `stop` is specified, values are generated from `[start, stop)` with the given `step`.
+
+ You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+ ```python exec="true" source="above" session="tensor" result="python"
+ print(Tensor.arange(5).numpy())
+ ```
+ ```python exec="true" source="above" session="tensor" result="python"
+ print(Tensor.arange(5, 10).numpy())
+ ```
+ ```python exec="true" source="above" session="tensor" result="python"
+ print(Tensor.arange(5, 10, 2).numpy())
+ ```
+ ```python exec="true" source="above" session="tensor" result="python"
+ print(Tensor.arange(5.5, 10, 2).numpy())
+ ```
+ """
  if stop is None: stop, start = start, 0
+ assert all(isinstance(s, (int, float)) for s in (start, stop, step)), f"symbolic arange not supported {start=}, {stop=}, {step=}"
  dtype = kwargs.pop("dtype", dtypes.default_float if any(isinstance(x, float) for x in (start, stop, step)) else dtypes.default_int)
- return (Tensor.full((math.ceil((stop-start)/step),), step, dtype=dtype, **kwargs).cumsum() + (start - step)).cast(dtype)
+ return (Tensor.full((math.ceil((stop-start)/step),), step, dtype=dtype, **kwargs)._cumsum() + (start - step)).cast(dtype)

  @staticmethod
  def eye(dim:int, **kwargs):
+ """
+ Creates an identity matrix of the given dimension.
+
+ You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+ ```python exec="true" source="above" session="tensor" result="python"
+ print(Tensor.eye(3).numpy())
+ ```
+ """
  return Tensor.ones((dim,1),**kwargs).pad((None,(0,dim))).flatten().shrink(((0,dim*dim),)).reshape(dim, dim)

- def full_like(self, fill_value:Scalar, **kwargs):
+ def full_like(self, fill_value:ConstType, **kwargs):
+ """
+ Creates a tensor with the same shape as `self`, filled with the given value.
+ If `dtype` is not specified, the dtype of `self` is used.
+
+ You can pass in the `device` keyword argument to control device of the tensor.
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+ ```python exec="true" source="above" session="tensor" result="python"
+ t = Tensor.ones(2, 3)
+ print(Tensor.full_like(t, 42).numpy())
+ ```
+ """
  return Tensor.full(self.shape, fill_value, dtype=kwargs.pop("dtype", self.dtype), device=kwargs.pop("device", self.device), **kwargs)
- def zeros_like(self, **kwargs): return self.full_like(0, **kwargs)
- def ones_like(self, **kwargs): return self.full_like(1, **kwargs)
+
+ def zeros_like(self, **kwargs):
+ """
+ Creates a tensor with the same shape as `self`, filled with zeros.
+
+ You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+ ```python exec="true" source="above" session="tensor" result="python"
+ t = Tensor.ones(2, 3)
+ print(Tensor.zeros_like(t).numpy())
+ ```
+ """
+ return self.full_like(0, **kwargs)
+
+ def ones_like(self, **kwargs):
+ """
+ Creates a tensor with the same shape as `self`, filled with ones.
+
+ You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+ ```python exec="true" source="above" session="tensor" result="python"
+ t = Tensor.zeros(2, 3)
+ print(Tensor.ones_like(t).numpy())
+ ```
+ """
+ return self.full_like(1, **kwargs)

  # ***** rng hlops *****

  @staticmethod
  def randn(*shape, dtype:Optional[DType]=None, **kwargs) -> Tensor:
+ """
+ Creates a tensor with the given shape, filled with random values from a normal distribution with mean `0` and standard deviation `1`.
+ If `dtype` is not specified, the default type is used.
+
+ You can pass in the `device` keyword argument to control device of the tensor.
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+ ```python exec="true" source="above" session="tensor" result="python"
+ Tensor.manual_seed(42)
+ print(Tensor.randn(2, 3).numpy())
+ ```
+ """
  # https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
- src = Tensor.rand((2, *argfix(*shape)), **kwargs)
+ src = Tensor.rand((2, *argfix(*shape)), **{**kwargs, "dtype": dtypes.float32})
  return src[0].mul(2*math.pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(dtype or dtypes.default_float)

  @staticmethod
- def randint(*shape, low=0, high=10, **kwargs) -> Tensor: return Tensor.uniform(*shape, low=low, high=high, dtype=dtypes.int32)
+ def randint(*shape, low=0, high=10, **kwargs) -> Tensor:
+ """
+ Creates a tensor with the given shape, filled with random integer values generated uniformly from the interval `[low, high)`.
+ If `dtype` is not specified, the default type is used.
+
+ You can pass in the `device` keyword argument to control device of the tensor.
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+ ```python exec="true" source="above" session="tensor" result="python"
+ Tensor.manual_seed(42)
+ print(Tensor.randint(2, 3, low=5, high=10).numpy())
+ ```
+ """
+ if not isinstance(low, int) or not isinstance(high, int): raise TypeError(f"{low=} and {high=} must be integers")
+ dtype = kwargs.pop("dtype", dtypes.int32)
+ if not dtypes.is_int(dtype): raise TypeError(f"{dtype=} must be int")
+ return Tensor.uniform(*shape, low=low, high=high, dtype=dtype, **kwargs)

  @staticmethod
- def normal(*shape, mean=0.0, std=1.0, **kwargs) -> Tensor: return (std * Tensor.randn(*shape, **kwargs)) + mean
+ def normal(*shape, mean=0.0, std=1.0, **kwargs) -> Tensor:
+ """
+ Creates a tensor with the given shape, filled with random values from a normal distribution with the given `mean` and standard deviation `std`.
+
+ You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+ ```python exec="true" source="above" session="tensor" result="python"
+ Tensor.manual_seed(42)
+ print(Tensor.normal(2, 3, mean=10, std=2).numpy())
+ ```
+ """
+ return (std * Tensor.randn(*shape, **kwargs)) + mean

  @staticmethod
  def uniform(*shape, low=0.0, high=1.0, **kwargs) -> Tensor:
+ """
+ Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[low, high)`.
+
+ You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+ ```python exec="true" source="above" session="tensor" result="python"
+ Tensor.manual_seed(42)
+ print(Tensor.uniform(2, 3, low=2, high=10).numpy())
+ ```
+ """
  dtype = kwargs.pop("dtype", dtypes.default_float)
  return ((high-low) * Tensor.rand(*shape, **kwargs)).cast(dtype) + low

  @staticmethod
- def scaled_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul(prod(argfix(*shape))**-0.5)
+ def scaled_uniform(*shape, **kwargs) -> Tensor:
+ """
+ Creates a tensor with the given shape, filled with random values from a uniform distribution
+ over the interval `[-prod(shape)**-0.5, prod(shape)**-0.5)`.
+
+ You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+ ```python exec="true" source="above" session="tensor" result="python"
+ Tensor.manual_seed(42)
+ print(Tensor.scaled_uniform(2, 3).numpy())
+ ```
+ """
+ return Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul(prod(argfix(*shape))**-0.5)

  # https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform
  @staticmethod
  def glorot_uniform(*shape, **kwargs) -> Tensor:
+ """
+ <https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform>
+
+ You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+ ```python exec="true" source="above" session="tensor" result="python"
+ Tensor.manual_seed(42)
+ print(Tensor.glorot_uniform(2, 3).numpy())
+ ```
+ """
  return Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul((6/(argfix(*shape)[0]+prod(argfix(*shape)[1:])))**0.5)

  # https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_
  @staticmethod
  def kaiming_uniform(*shape, a:float = 0.01, **kwargs) -> Tensor:
+ """
+ <https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_>
+
+ You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+ ```python exec="true" source="above" session="tensor" result="python"
+ Tensor.manual_seed(42)
+ print(Tensor.kaiming_uniform(2, 3).numpy())
+ ```
+ """
  bound = math.sqrt(3.0) * math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(argfix(*shape)[1:]))
  return Tensor.uniform(*shape, low=-bound, high=bound, **kwargs)

  # https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_
  @staticmethod
  def kaiming_normal(*shape, a:float = 0.01, **kwargs) -> Tensor:
+ """
+ <https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_>
+
+ You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+ ```python exec="true" source="above" session="tensor" result="python"
+ Tensor.manual_seed(42)
+ print(Tensor.kaiming_normal(2, 3).numpy())
+ ```
+ """
  std = math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(argfix(*shape)[1:]))
  return Tensor.normal(*shape, mean=0.0, std=std, **kwargs)

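The hunk above replaces the static `Tensor.corealize` with the instance methods `schedule_with_vars`, `schedule`, and `realize`, and reroutes `data`/`item`/`tolist`/`numpy` through the memoryview-based `_data` instead of `toCPU()`. A short sketch of that flow using the new API (illustrative; device, dtype defaults, and the number of schedule items depend on the local install and backend):

```python
from tinygrad import Tensor

a = Tensor([1.0, 2.0, 3.0, 4.0])
b = Tensor.ones(4)
c = a + b                 # still lazy, nothing has been computed yet

sched = c.schedule()      # List[ScheduleItem] needed to realize c (asserts there are no unbound Variables)
c.realize()               # rebuilds the schedule and runs it via run_schedule

print(len(sched))         # number of schedule items (backend-dependent)
print(c.tolist())         # memoryview-backed readback, no numpy required
print(c.numpy())          # numpy array built with np.frombuffer over _data()
```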
@@ -254,31 +706,40 @@ class Tensor:
  assert replacement or num_samples == 1, "no replacement only supports num_samples = 1"
  weight = self.unsqueeze(0) if self.ndim == 1 else self
  cdf = (cw := weight.cumsum(1).float()) / cw[:, -1].unsqueeze(1)
- unif_samples = Tensor.rand(num_samples, cdf.shape[0], 1)
+ unif_samples = Tensor.rand(num_samples, cdf.shape[0], 1, device=self.device)
  indices = (unif_samples.expand((-1, -1, cdf.shape[1])) >= cdf).sum(2).permute((1, 0))
- return (indices.squeeze(0) if self.ndim == 1 else indices).cast(dtypes.default_int)
+ return (indices.squeeze(0) if self.ndim == 1 else indices).cast(dtypes.int32)

  # ***** toposort and backward pass *****

- def deepwalk(self):
- def _deepwalk(node, visited, nodes):
+ def _deepwalk(self):
+ def _walk(node, visited):
  visited.add(node)
  if getattr(node, "_ctx", None):
  for i in node._ctx.parents:
- if i not in visited: _deepwalk(i, visited, nodes)
- nodes.append(node)
- return nodes
- return _deepwalk(self, set(), [])
+ if i not in visited: yield from _walk(i, visited)
+ yield node
+ return list(_walk(self, set()))

  def backward(self) -> Tensor:
+ """
+ Propagates the gradient of a tensor backwards through the computation graph.
+ Must be used on a scalar tensor.
+
+ ```python exec="true" source="above" session="tensor" result="python"
+ t = Tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
+ t.sum().backward()
+ print(t.grad.numpy())
+ ```
+ """
  assert self.shape == tuple(), f"backward can only be called for scalar tensors, but it has shape {self.shape})"

  # fill in the first grad with one. don't use Tensor.ones because we don't need contiguous
  # this is "implicit gradient creation"
- self.grad = Tensor(1.0, device=self.device, requires_grad=False)
+ self.grad = Tensor(1.0, dtype=self.dtype, device=self.device, requires_grad=False)

- for t0 in reversed(self.deepwalk()):
- assert (t0.grad is not None)
+ for t0 in reversed(self._deepwalk()):
+ if t0.grad is None: raise RuntimeError(f"tensor {t0} has no grad")
  grads = t0._ctx.backward(t0.grad.lazydata)
  grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None
  for g in ([grads] if len(t0._ctx.parents) == 1 else grads)]
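The backward-pass hunk above turns `deepwalk` into a generator-based `_deepwalk` and seeds the implicit gradient with the tensor's own dtype. The behavior it implements is the usual reverse-mode chain rule over the toposorted graph; a small worked check (illustrative):

```python
from tinygrad import Tensor

x = Tensor([2.0, 3.0], requires_grad=True)
y = (x * x).sum()      # y = sum(x_i^2), a scalar, so backward() is allowed
y.backward()           # walks reversed(_deepwalk()) and fills in the grads
print(x.grad.numpy())  # dy/dx_i = 2 * x_i -> [4. 6.]
```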
@@ -289,290 +750,825 @@ class Tensor:
289
750
  del t0._ctx
290
751
  return self
291
752
 
292
- # ***** movement mlops *****
753
+ # ***** movement low level ops *****
754
+
755
+ def view(self, *shape) -> Tensor:
756
+ """`.view` is an alias for `.reshape`."""
757
+ return self.reshape(shape)
293
758
 
294
759
  def reshape(self, shape, *args) -> Tensor:
295
- new_shape = argfix(shape, *args)
296
- return mlops.Reshape.apply(self, shape=tuple([-prod(self.shape) // prod(new_shape) if s == -1 else (s if s is not None else self.shape[i]) for i,s in enumerate(new_shape)])) # noqa: E501
760
+ """
761
+ Returns a tensor with the same data as the original tensor but with a different shape.
762
+ `shape` can be passed as a tuple or as separate arguments.
763
+
764
+ ```python exec="true" source="above" session="tensor" result="python"
765
+ t = Tensor.arange(6)
766
+ print(t.reshape(2, 3).numpy())
767
+ ```
768
+ """
769
+ # resolve None and args
770
+ new_shape = tuple([s if s is not None else self.shape[i] for i,s in enumerate(argfix(shape, *args))])
771
+ # resolve -1
772
+ if (c := new_shape.count(-1)) > 1: raise RuntimeError(f"only one dimension can be inferred using -1, getting {new_shape}")
773
+ if c: new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape])
774
+ return F.Reshape.apply(self, shape=new_shape) if new_shape != self.shape else self
775
+
297
776
  def expand(self, shape, *args) -> Tensor:
298
- if shape == self.shape: return self
299
- return mlops.Expand.apply(self, shape=tuple([x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args))]))
300
- def permute(self, order, *args) -> Tensor: return mlops.Permute.apply(self, order=argfix(order, *args))
301
- def flip(self, axis, *args) -> Tensor: return mlops.Flip.apply(self, axis=[x if x >= 0 else x+len(self.shape) for x in argfix(axis, *args)])
777
+ """
778
+ Returns a tensor that is expanded to the shape that is specified.
779
+ Expand can also increase the number of dimensions that a tensor has.
780
+
781
+ Passing a `-1` or `None` to a dimension means that its size will not be changed.
782
+
783
+ ```python exec="true" source="above" session="tensor" result="python"
784
+ t = Tensor([1, 2, 3])
785
+ print(t.expand(4, -1).numpy())
786
+ ```
787
+ """
788
+ return self._broadcast_to(tuple(from_ if to == -1 or to is None else to for from_, to in zip(*(_pad_left(self.shape, argfix(shape, *args))))))
789
+
790
+ def permute(self, order, *args) -> Tensor:
791
+ """
792
+ Returns a tensor that is a permutation of the original tensor.
793
+ The new tensor has the same data as the original tensor but with the dimensions permuted according to the order specified.
794
+ `order` can be passed as a tuple or as separate arguments.
795
+
796
+ ```python exec="true" source="above" session="tensor" result="python"
797
+ t = Tensor.arange(6).reshape(2, 3)
798
+ print(t.numpy())
799
+ ```
800
+ ```python exec="true" source="above" session="tensor" result="python"
801
+ print(t.permute(1, 0).numpy())
802
+ ```
803
+ """
804
+ order_arg = tuple(self._resolve_dim(x) for x in argfix(order, *args))
805
+ if sorted(order_arg) != list(range(self.ndim)): raise RuntimeError(f"order is not a valid permutation, getting {order_arg}")
806
+ return F.Permute.apply(self, order=order_arg)
807
+
808
+ def flip(self, axis, *args) -> Tensor:
809
+ """
810
+ Returns a tensor that reverses the order of the original tensor along given `axis`.
811
+ `axis` can be passed as a tuple or as separate arguments.
812
+
813
+ ```python exec="true" source="above" session="tensor" result="python"
814
+ t = Tensor.arange(6).reshape(2, 3)
815
+ print(t.numpy())
816
+ ```
817
+ ```python exec="true" source="above" session="tensor" result="python"
818
+ print(t.flip(0).numpy())
819
+ ```
820
+ ```python exec="true" source="above" session="tensor" result="python"
821
+ print(t.flip((0, 1)).numpy())
822
+ ```
823
+ """
824
+ axis_arg = tuple(self._resolve_dim(x) for x in argfix(axis, *args))
825
+ if len(axis_arg) != len(dedup(axis_arg)): raise RuntimeError(f"dim can appear at least once, getting {axis_arg}")
826
+ return F.Flip.apply(self, axis=axis_arg)
827
+
302
828
  def shrink(self, arg:Tuple[Optional[Tuple[sint, sint]], ...]) -> Tensor:
303
- if not any(x is not None and x != (0,s) for x,s in zip(arg, self.shape)): return self
304
- return mlops.Shrink.apply(self, arg=tuple(x if x is not None else (0,s) for x,s in zip(arg, self.shape)))
829
+ """
830
+ Returns a tensor that shrinks the each axis based on input arg.
831
+ `arg` must have the same length as `self.ndim`.
832
+ For each axis, it can be `None`, which means no shrink, or a tuple `(start, end)` that works the same as Python slice.
833
+
834
+ ```python exec="true" source="above" session="tensor" result="python"
835
+ t = Tensor.arange(9).reshape(3, 3)
836
+ print(t.numpy())
837
+ ```
838
+ ```python exec="true" source="above" session="tensor" result="python"
839
+ print(t.shrink(((None, (1, 3)))).numpy())
840
+ ```
841
+ ```python exec="true" source="above" session="tensor" result="python"
842
+ print(t.shrink((((0, 2), (0, 2)))).numpy())
843
+ ```
844
+ """
845
+ if all(x is None or x == (0,s) for x,s in zip(arg, self.shape)): return self
846
+ return F.Shrink.apply(self, arg=tuple(x if x is not None else (0,s) for x,s in zip(arg, self.shape)))
847
+
305
848
  def pad(self, arg:Tuple[Optional[Tuple[sint, sint]], ...], value:float=0.0) -> Tensor:
849
+ """
850
+ Returns a tensor that pads the each axis based on input arg.
851
+ `arg` must have the same length as `self.ndim`.
852
+ For each axis, it can be `None`, which means no pad, or a tuple `(pad_before, pad_after)`.
853
+ If `value` is specified, the tensor is padded with `value` instead of `0.0`.
854
+
855
+ ```python exec="true" source="above" session="tensor" result="python"
856
+ t = Tensor.arange(6).reshape(2, 3)
857
+ print(t.numpy())
858
+ ```
859
+ ```python exec="true" source="above" session="tensor" result="python"
860
+ print(t.pad(((None, (1, 2)))).numpy())
861
+ ```
862
+ ```python exec="true" source="above" session="tensor" result="python"
863
+ print(t.pad(((None, (1, 2))), -2).numpy())
864
+ ```
865
+ """
306
866
  if all(x is None or x == (0,0) for x in arg): return self
307
- ret = mlops.Pad.apply(self, arg=(narg:=tuple(x if x is not None else (0,0) for x in arg)))
308
- return ret if 0 == value else ret + mlops.Pad.apply(Tensor.ones_like(self), arg=narg).where(0, value)
309
-
310
- # ***** movement hlops *****
311
-
312
- # - Negative indices are taken relative to the end of the sequence, so X[-2] returns the 2nd-to-last element
313
- # - A slice i:j returns the elements with indices in [i, j)
314
- # - If omitted, i and j will default to 0 and N, respectively, where N is the length of the sequence
315
- # - Negative values for i and j are taken relative to the end of the sequence
316
- # - Both i and j will be clamped to the range (-N, N], where N in the length of the sequence
317
- # - Indexing with None on a given axis will add a new dimension of size one before that axis
318
- # - Empty slices are not allowed (tensors with 0s in shape have to be supported first, for all backends).
319
- # - For a slice [i:j:k] finding the correct indices is delegated to slice.indices(len).
320
- # - Strides > 1 and < 0 are now allowed!:
321
- # - This works by applying Shrink -> [[Flip -> ] Pad -> Reshape -> Shrink] -> Reshape (ops in brackets are optional)
322
- # - Idea of stride < 0 support:
323
- # - Do the slice first, flip the axes were slice.step is negative, do slice.step -> -slice.step. Go to steps below.
324
- # - Idea of stride `s` > 1 support (Pad -> Reshape -> Shrink):
325
- # - Instead of doing [::s] on axis [dim_sz], do [:, 0] on axes [dim_sz_padded // s, s].
326
- # - So pad dim_sz with as many zeros as needed (dim_sz -> dim_sz_padded) so that reshape to [dim_sz_padded // s, s]
327
- # is possible.
328
- # - Apply Shrink to do the slice [:, 0] on axes of shapes [dim_sz_padded // s, s].
329
- # - Fancy indexing and combined indexing is supported
330
- # - Combined indexing works by letting regular slicing finish first -> computing the resulting dims w.r.t to Tensors passed in -> fancy indexing
331
- # - Any Tensors passed in __getitem__ will perform (CMPEQ with arange -> MUL with self -> SUM_REDUCE) iteratively
332
- # - The first iteration will expand the dim of self while consecutive iterations will reduce the dim
333
- # - There's a special case where a permute is needed at the end:
334
- # - if first Tensor passed in (expand dims) is not at dim 0
335
- # - and following Tensors does not follow consecutively to the end of fancy indexing's dims
336
- # TODO: boolean indices
337
- # TODO: figure out the exact acceptable types for indices, especially for internal list/tuple types
338
- # TODO: update docs
339
- def __getitem__(self, indices: Union[int, slice, Tensor, None, List, Tuple]) -> Tensor: # no ellipsis type...
867
+ ret = F.Pad.apply(self, arg=(narg:=tuple(x if x is not None else (0,0) for x in arg)))
868
+ return ret if 0 == value else ret + F.Pad.apply(Tensor.ones_like(self), arg=narg).where(0, value)
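The non-zero `value` case above is just a zero pad plus a masked fill: pad `ones_like(self)` the same way, and wherever that mask is 0 the region is new, so `value` is added there. A minimal sketch of that composition using only the public API documented here (the numbers are illustrative):

```python
from tinygrad import Tensor

t = Tensor([[1., 2.], [3., 4.]])
print(t.pad(((1, 1), (1, 1)), -7).numpy())

# same result built by hand: zero-pad, then add the value only in the padded border
mask = Tensor.ones_like(t).pad(((1, 1), (1, 1)))      # 1 inside the original region, 0 in the border
print((t.pad(((1, 1), (1, 1))) + mask.where(0, -7)).numpy())
```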
869
+
870
+ # ***** movement high level ops *****
871
+
872
+ # Supported Indexing Implementations:
873
+ # 1. Int indexing (no copy)
874
+ # - for all dims where there's int, shrink -> reshape
875
+ # - negative indices are taken relative to the end of the sequence, so X[-2] returns the 2nd-to-last element
876
+ # - X = Tensor.rand(4,5,9); X[2,-2] shrinks the Tensor to X.shrink(((2, 3), (3, 4), (0, 9))) -> X.shape=(1,1,9)
877
+ # - Then we reshape (collapse) the int dim away such that for X: (1,1,9) -> (9,)
878
+ # 2. Slice indexing (no copy)
879
+ # - for all dims where slice is start:end:stride, shrink -> Optional[flip] -> pad -> reshape -> shrink
880
+ # - first shrink the Tensor to X.shrink(((start, end),))
881
+ # - then we apply stride through Optional[flip] -> pad -> reshape -> shrink
882
+ # - flip the dims whose stride is negative
883
+ # - pad on dims to be multiple of strides, such that reshaping [dim_size_padded] -> [dim_size_padded // stride, stride] is possible
884
+ # - shrink [dim_size_padded // stride, stride] -> [dim_size_padded // stride, 1]
885
+ # - reshape [dim_size_padded // stride, 1] -> [dim_size_padded // stride] and now you have your stride
886
+ # 3. None indexing (no copy)
887
+ # - reshape (inject) a dim at the dim where there's None
888
+ # 4. Tensor indexing (copy)
889
+ # - use Tensor.arange == tensor_index to create masks for dims with Tensors (adds a dim for each mask)
890
+ # - combine masks together with mul
891
+ # - apply mask to self by mask * self
892
+ # - sum reduce away the extra dims added from creating masks
893
+ # Tiny Things:
894
+ # 1. Supported indices: Union[int, slice, Tensor, None, List, Tuple, Ellipsis]
895
+ # - for any list, List[Union[List, Tuple, int]], must have homogeneous shape
896
+ # - for any tuple, Tuple[Union[List, Tuple, int]], must have homogeneous shape
897
+ # 2. Bool indexing is not supported
898
+ # 3. Out of bounds Tensor indexing results in 0
899
+ # - e.g: Tensor([1, 2, 3])[Tensor([4, 3, 2])] -> [0, 0, 3] index 4 and 3 are out of bounds
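As a concrete illustration of implementation 2 above, a strided basic slice can be rebuilt from nothing but movement ops; a minimal sketch (the explicit shapes are for illustration, not part of the source):

```python
from tinygrad import Tensor

x = Tensor.arange(7)
print(x[::2].numpy())                      # [0 2 4 6]

# the same result from movement ops only: pad -> reshape -> shrink -> reshape
manual = x.pad(((0, 1),))                  # pad 7 -> 8, a multiple of stride 2
manual = manual.reshape(4, 2)              # [dim_size_padded // stride, stride]
manual = manual.shrink(((0, 4), (0, 1)))   # keep the first element of each stride group
print(manual.reshape(4).numpy())           # [0 2 4 6]
```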
900
+ def __getitem__(self, indices) -> Tensor:
340
901
  # 1. indices normalization and validation
341
902
  # treat internal tuples and lists as Tensors and standardize indices to list type
342
- if isinstance(indices, (tuple, list)):
343
- # special case <indices: List[int]>, a lil ugly
344
- if isinstance(indices, list) and all_int(indices): indices = [Tensor(indices, requires_grad=False, device=self.device)]
345
- else: indices = [Tensor(list(i), requires_grad=False, device=self.device) if isinstance(i, (tuple, list)) else i for i in indices]
903
+ if isinstance(indices, list) and all_int(indices): indices = [Tensor(indices, self.device, requires_grad=False)]
904
+ elif isinstance(indices, (tuple, list)):
905
+ indices = [Tensor(i, self.device, requires_grad=False) if isinstance(i, (tuple, list)) else i for i in indices]
346
906
  else: indices = [indices]
347
907
 
908
+ # turn scalar Tensors into const val for int indexing if possible
909
+ indices = [self._to_const_val(i) if isinstance(i, Tensor) and i.shape == () else i for i in indices]
910
+ # move Tensor indices to the same device as self
911
+ indices = [i.to(self.device) if isinstance(i, Tensor) else i for i in indices]
912
+
348
913
  # filter ellipsis and fill with slice(None) or fill rest of indices with slice(None)
349
914
  ellipsis_idx = [dim for dim, i in enumerate(indices) if i is Ellipsis]
350
915
  fill_idx = ellipsis_idx[0] if ellipsis_idx else len(indices)
351
- num_slices = len(indices) - len(ellipsis_idx) - sum(1 for i in indices if i is None)
352
- indices[fill_idx:fill_idx+1] = [slice(None)] * (len(self.shape) - num_slices)
916
+ num_indices = len(indices) - len(ellipsis_idx) - sum(1 for i in indices if i is None)
917
+ indices[fill_idx:fill_idx+1] = [slice(None)] * (self.ndim - num_indices)
353
918
 
354
919
  # use Dict[type, List[dimension]] to track elements in indices
355
920
  type_dim: DefaultDict[Union[type, None], List[int]] = defaultdict(list)
356
921
 
357
922
  # record None for dimension injection later and filter None and record rest of indices
358
923
  type_dim[None] = [dim for dim, i in enumerate(indices) if i is None]
359
- indices_filtered = [v for v in indices if v is not None]
924
+ tensor_dims = [dim for dim, i in enumerate(indices) if isinstance(i, Tensor)]
925
+ indices_filtered = [i for i in indices if i is not None]
360
926
  for dim,i in enumerate(indices_filtered): type_dim[type(i)].append(dim)
361
927
 
362
- # validation! raise Errors
363
- if slice in type_dim and self.ndim == 0: raise IndexError("slice cannot be applied to a 0-dim tensor.")
364
- if len(ellipsis_idx) > 1: raise IndexError("an index can only have a single ellipsis ('...')")
365
- if float in type_dim: raise IndexError("float type is not valid index")
366
- if any(isinstance(i, slice) and i.step == 0 for i in indices): raise ValueError('slice step cannot be 0')
367
- if num_slices > len(self.shape): raise IndexError(f"too many indices for tensor of dimension {len(self.shape)}")
928
+ if len(ellipsis_idx) > 1: raise IndexError("indices can only have a single ellipsis ('...')")
929
+ for index_type in type_dim:
930
+ if index_type not in [None, int, slice, Tensor]: raise IndexError(f"{index_type=} not supported")
931
+ if num_indices > self.ndim: raise IndexError(f"too many {num_indices=} for {self.ndim=}")
368
932
 
369
- # 2. basic indexing (no copy)
370
- # currently indices_filtered: Tuple[Union[slice, int, Tensor], ...]
371
- # turn indices in indices_filtered to Tuple[shrink_arg, strides]
933
+ # 2. basic indexing, uses only movement ops (no copy)
934
+ # currently indices_filtered: Tuple[Union[int, slice, Tensor], ...]
935
+ # turn indices in indices_filtered to Tuple[new_slice, strides]
372
936
  for dim in type_dim[int]:
373
937
  if (index := indices_filtered[dim]) >= (size := self.shape[dim]) or index < -size:
374
- raise IndexError(f"{index=} is out of bounds for dimension {dim} with {size=}")
938
+ raise IndexError(f"{index=} is out of bounds on {dim=} with {size=}")
375
939
  indices_filtered[dim] = ((index, index+1), 1) if index >= 0 else ((size+index, size+index+1), 1)
376
940
  for dim in type_dim[slice]:
377
- s, e, st = indices_filtered[dim].indices(self.shape[dim])
378
- indices_filtered[dim] = ((0, 0) if (st > 0 and e < s) or (st <= 0 and e > s) else (s, e) if st > 0 else (e+1, s+1), st)
379
- for dim in type_dim[Tensor]: indices_filtered[dim] = ((0, self.shape[dim]), 1)
380
-
381
- new_slice, strides = ((),()) if not indices_filtered else zip(*indices_filtered)
382
- ret = self.shrink(new_slice).flip(axis=[i for i, s in enumerate(strides) if s < 0])
383
- # add strides by pad -> reshape -> shrink
384
- if any(abs(s) != 1 for s in strides):
941
+ if (index := indices_filtered[dim]).step == 0: raise ValueError(f"{index=} on {dim=} cannot have 0 as step")
942
+ s, e, st = index.indices(self.shape[dim])
943
+ indices_filtered[dim] = ((0, 0) if (st * (e - s)) < 0 else (s, e) if st > 0 else (e+1, s+1), st)
944
+ # record tensors and skip all Tensor dims for basic indexing
945
+ tensor_index: List[Tensor] = []
946
+ for dim in type_dim[Tensor]:
947
+ tensor_index.append(index := indices_filtered[dim])
948
+ if not dtypes.is_int(index.dtype): raise IndexError(f"{index.dtype=} on {dim=} is not supported, only int tensor indexing is supported")
949
+ indices_filtered[dim] = ((0, self.shape[dim]), 1)
950
+
951
+ new_slice, strides = ((), ()) if not indices_filtered else zip(*indices_filtered)
952
+ # flip negative strides
953
+ ret = self.shrink(new_slice).flip(tuple(i for i, st in enumerate(strides) if st < 0))
954
+ # handle stride != 1 or -1
955
+ if any(abs(st) != 1 for st in strides):
385
956
  strides = tuple(abs(s) for s in strides)
386
- ret = ret.pad(tuple((0, round_up(sh, s) - sh) for s, sh in zip(strides, ret.shape)))
387
- ret = ret.reshape(flatten([sh // s, s] for s, sh in zip(strides, ret.shape)))
388
- ret = ret.shrink(tuple(flatten(((0, sh), (0, 1)) for sh in ret.shape[::2]))).reshape(ret.shape[::2])
957
+ # pad shape to multiple of stride
958
+ ret = ret.pad(tuple((0, round_up(s, st) - s) for s, st in zip(ret.shape, strides)))
959
+ ret = ret.reshape(tuple(flatten((s // st, st) for s, st in zip(ret.shape, strides))))
960
+ ret = ret.shrink(tuple(flatten(((0, s), (0, 1)) for s in ret.shape[::2]))).reshape(ret.shape[::2])
389
961
 
390
962
  # inject 1 for dim where it's None and collapse dim for int
391
963
  new_shape = list(ret.shape)
392
964
  for dim in type_dim[None]: new_shape.insert(dim, 1)
393
- for dim in (dims_collapsed := [dim + sum(1 for d in type_dim[None] if dim >= d) for dim in reversed(type_dim[int])]): new_shape.pop(dim)
394
- assert all_int(new_shape), f"does not support symbolic shape {new_shape}"
965
+ for dim in (dims_collapsed := tuple(dim + sum(1 for d in type_dim[None] if dim >= d) for dim in reversed(type_dim[int]))): new_shape.pop(dim)
395
966
 
396
967
  ret = ret.reshape(new_shape)
397
968
 
398
969
  # 3. advanced indexing (copy)
399
970
  if type_dim[Tensor]:
971
+ # calculate dim of current ret by subtracting dims collapsed and adding dims injected up until tensor_dim
972
+ def calc_dim(tensor_dim:int) -> int:
973
+ return tensor_dim - sum(1 for d in dims_collapsed if tensor_dim >= d)
974
+
975
+ assert all_int(ret.shape), f"does not support symbolic shape {ret.shape}"
976
+ # track tensor_dim and tensor_index using a dict
977
+ # calc_dim to get dim and use that to normalize the negative tensor indices
978
+ idx: Dict[int,Tensor] = {(dim := calc_dim(td)):(tensor<0).where(ret.shape[dim],0) + tensor for td,tensor in zip(tensor_dims, tensor_index)}
979
+
980
+ masks, first_dim, last_dim = [], min(idx.keys()), max(idx.keys())
981
+ pre_reduce_shape = ret.shape[:first_dim] + (big_shape := _broadcast_shape(*(t.shape for t in idx.values()))) + ret.shape[first_dim:]
400
982
 
401
- # extract tensors and tensor dimensions
402
- idx, tdim = [], []
403
- for tensor_dim in type_dim[Tensor]:
404
- dims_collapsed_, dims_injected = sum(1 for d in dims_collapsed if tensor_dim >= d), sum(1 for d in type_dim[None] if tensor_dim >= d)
405
- tdim.append(td := tensor_dim - dims_collapsed_ + dims_injected)
406
- # normalize the negative tensor indices
407
- idx.append(((t := indices[tensor_dim + dims_injected]) < 0).where(ret.shape[td], 0) + t)
408
- # TODO uint8 and bool tensor indexing
409
- if not (dtypes.is_int(t.dtype) or t.dtype == dtypes.bool): raise IndexError("tensors used as indices must be int or bool tensors")
410
-
411
- # compute sum_dim, arange, and idx
412
- max_dim = max(i.ndim for i in idx)
413
- sum_dim = [d if n==0 else d+max_dim-n for n,d in enumerate(tdim)]
414
- arange = [Tensor.arange(ret.shape[d], requires_grad=False, device=self.device).reshape(*[1]*sd, ret.shape[d], *[1]*(ret.ndim + max_dim - n - sd - 1)) for n,(sd,d) in enumerate(zip(sum_dim, tdim))] # noqa: E501
415
- first_idx = [idx[0].reshape(*[1]*tdim[0], *[1]*(1 + max_dim - idx[0].ndim), *idx[0].shape, *[1]*(ret.ndim - tdim[0] - 1))]
416
- rest_idx = [i.reshape(*[1]*tdim[0], *[1]*(max_dim - i.ndim), *i.shape, *[1]*(ret.ndim - tdim[0] - n)) for n,i in enumerate(idx[1:], 1)]
417
- reshaped_idx = first_idx + rest_idx
418
- ret = ret.reshape(*ret.shape[:sum_dim[0]+1], *[1]*max_dim, *ret.shape[sum_dim[0]+1:])
419
-
420
- # iteratively eq -> mul -> sum fancy index
421
- try:
422
- for a,i,sd in zip(arange, reshaped_idx, sum_dim): ret = (a==i).mul(ret).sum(sd)
423
- except AssertionError as exc: raise IndexError(f"cannot broadcast with index shapes {', '.join(str(i.shape) for i in idx)}") from exc
983
+ # create masks
984
+ for dim, i in idx.items():
985
+ try: i = i.reshape(i.shape + (1,)*(ret.ndim - first_dim)).expand(pre_reduce_shape)
986
+ except ValueError as e: raise IndexError(f"cannot broadcast indices: {e}") from e
987
+ a = Tensor.arange(ret.shape[dim], device=self.device, requires_grad=False).reshape((ret.shape[dim],) + (1,)*(ret.ndim - dim - 1))
988
+ masks.append(i == a)
989
+
990
+ # reduce masks to 1 mask
991
+ mask: Tensor = functools.reduce(lambda x,y: x.mul(y), masks)
992
+
993
+ # inject 1's for the extra dims added in create masks
994
+ reshape_arg = ret.shape[:first_dim] + (1,) * len(big_shape) + ret.shape[first_dim:]
995
+ # sum reduce the extra dims introduced in create masks
996
+ ret = (ret.reshape(reshape_arg) * mask).sum(tuple(i + len(big_shape) for i in idx.keys()), acc_dtype=ret.dtype)
424
997
 
425
998
  # special permute case
426
- if tdim[0] != 0 and len(tdim) != 1 and tdim != list(range(tdim[0], tdim[-1]+1)):
427
- ret_dims = list(range(ret.ndim))
428
- ret = ret.permute(ret_dims[tdim[0]:tdim[0]+max_dim] + ret_dims[:tdim[0]] + ret_dims[tdim[0]+max_dim:])
999
+ if first_dim != 0 and len(idx) != 1 and tuple(idx.keys()) != tuple(range(first_dim, last_dim+1)):
1000
+ ret = ret.permute(*range(first_dim, first_dim+len(big_shape)), *range(0, first_dim), *range(first_dim+len(big_shape), ret.ndim))
429
1001
  return ret
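The Tensor-indexing branch above reduces to the arange-mask trick described in implementation 4; a minimal hand-rolled sketch of that reduction on a 1-D tensor (the reshapes are illustrative, not the exact internal layout):

```python
from tinygrad import Tensor

x = Tensor([10, 20, 30, 40])
idx = Tensor([3, 0, 2])
print(x[idx].numpy())                      # [40 10 30]

# same result via arange == index masks, then a sum-reduce over the masked copy
arange = Tensor.arange(4).reshape(1, 4)    # candidate positions along the indexed dim
mask = idx.reshape(3, 1) == arange         # (3, 4) one-hot rows, one per index
print((mask * x).sum(axis=1).numpy())      # [40 10 30]
```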
430
1002
 
431
- def __setitem__(self,indices,v): return self.__getitem__(indices).assign(v)
432
-
433
- # NOTE: using slice is discouraged and things should migrate to pad and shrink
434
- def slice(self, arg:Sequence[Optional[Tuple[int, sint]]], value:float=0) -> Tensor:
1003
+ def __setitem__(self, indices, v:Union[Tensor, ConstType]) -> None:
1004
+ if isinstance(self.device, str) and self.device.startswith("DISK"):
1005
+ self.__getitem__(indices).assign(v)
1006
+ return
1007
+ # NOTE: check that setitem target is valid first
1008
+ assert all(lb.st.contiguous for lb in self.lazydata.lbs), "setitem target needs to be contiguous"
1009
+ if not isinstance(v, (Tensor, float, int, bool)): raise TypeError(f"can't set a {type(v).__name__} to a Tensor")
1010
+ if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype)
1011
+ if self.requires_grad or v.requires_grad: raise NotImplementedError("setitem with requires_grad is not supported")
1012
+ if isinstance(indices, (Tensor, list)) or (isinstance(indices, tuple) and any(isinstance(i, (Tensor, list)) for i in indices)):
1013
+ raise NotImplementedError("Advanced indexing setitem is not currently supported")
1014
+
1015
+ assign_to = self.realize().__getitem__(indices)
1016
+ # NOTE: contiguous to prevent const folding.
1017
+ v = v.cast(assign_to.dtype)._broadcast_to(_broadcast_shape(assign_to.shape, v.shape)).contiguous()
1018
+ assign_to.assign(v).realize()
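A small usage sketch of the non-DISK `__setitem__` path above; it assumes the target tensor was created contiguous and does not require grad, as the asserts demand:

```python
from tinygrad import Tensor

t = Tensor.zeros(2, 3).contiguous()   # setitem target must be contiguous
t[1, :2] = 5.0                        # scalar is broadcast into the selected view
print(t.numpy())                      # [[0. 0. 0.] [5. 5. 0.]]
```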
1019
+
1020
+ # NOTE: using _slice is discouraged and things should migrate to pad and shrink
1021
+ def _slice(self, arg:Sequence[Optional[Tuple[int, sint]]], value:float=0) -> Tensor:
435
1022
  arg_ = tuple(a if a is not None else (0, s) for s,a in zip(self.shape, arg))
436
1023
  padding = tuple((max(0, -l), max(0, r-s)) for s,(l,r) in zip(self.shape, arg_))
437
1024
  return self.pad(padding, value=value).shrink(tuple((l + pl, r + pl) for (l,r),(pl,_) in zip(arg_, padding)))
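Because `_slice` pads first and shrinks second, bounds outside the tensor are tolerated and the overhang is filled with `value`; a minimal sketch (note this is a private helper, so the call is illustrative only):

```python
from tinygrad import Tensor

t = Tensor([1, 2, 3])
print(t._slice(((-1, 4),)).numpy())            # [0 1 2 3 0]
print(t._slice(((-1, 4),), value=9).numpy())   # [9 1 2 3 9]
```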
438
1025
 
439
- def gather(self:Tensor, idx:Tensor, dim:int) -> Tensor:
440
- assert idx.ndim == self.ndim, "self.ndim must equal idx.ndim"
441
- assert all(s >= i for s,i in zip(self.shape, idx.shape)), "all dim of idx.shape must be smaller than self.shape"
442
- if dim < 0: dim += self.ndim
443
- idx = idx.transpose(ax1=dim, ax2=0).unsqueeze(-1)
444
- permarg = list(range(self.ndim))
445
- permarg = permarg[1:dim] + [permarg[0]] + permarg[dim+1:] + [permarg[dim]] if dim != 0 else permarg[1:] + [permarg[0]]
446
- return ((idx == Tensor.arange(self.shape[dim], requires_grad=False, device=self.device)) * self.permute(*permarg).shrink(tuple([*[(0,sh) for sh in idx.shape[1:-1]], (0,self.shape[dim])])).unsqueeze(0)).sum(-1).transpose(ax1=0, ax2=dim) # noqa: E501
1026
+ def gather(self:Tensor, dim:int, index:Tensor) -> Tensor:
1027
+ """
1028
+ Gathers values along an axis specified by `dim`.
1029
+
1030
+ ```python exec="true" source="above" session="tensor" result="python"
1031
+ t = Tensor([[1, 2], [3, 4]])
1032
+ print(t.numpy())
1033
+ ```
1034
+ ```python exec="true" source="above" session="tensor" result="python"
1035
+ print(t.gather(1, Tensor([[0, 0], [1, 0]])).numpy())
1036
+ ```
1037
+ """
1038
+ assert index.ndim == self.ndim, f"self.ndim must equal index.ndim, {self.ndim=}, {index.ndim=}"
1039
+ assert all(s >= i for d,(s,i) in enumerate(zip(self.shape, index.shape)) if d != dim), "requires self.shape[d] >= index.shape[d] for all d != dim"
1040
+ dim = self._resolve_dim(dim)
1041
+ index = index.to(self.device)
1042
+ x = self.shrink(tuple((0, i) if d != dim else None for d,i in enumerate(index.shape))).unsqueeze(-1).transpose(-1, dim)
1043
+ return ((index.unsqueeze(-1) == Tensor.arange(self.shape[dim], requires_grad=False, device=self.device)) * x).sum(-1, acc_dtype=self.dtype)
447
1044
 
448
1045
  def cat(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
449
- if dim < 0: dim += self.ndim
1046
+ """
1047
+ Concatenates self with other `Tensor` in `args` along an axis specified by `dim`.
1048
+ All tensors must have the same shape except in the concatenating dimension.
1049
+
1050
+ ```python exec="true" source="above" session="tensor" result="python"
1051
+ t0, t1, t2 = Tensor([[1, 2]]), Tensor([[3, 4]]), Tensor([[5, 6]])
1052
+ print(t0.cat(t1, t2, dim=0).numpy())
1053
+ ```
1054
+ ```python exec="true" source="above" session="tensor" result="python"
1055
+ print(t0.cat(t1, t2, dim=1).numpy())
1056
+ ```
1057
+ """
1058
+ dim = self._resolve_dim(dim)
450
1059
  assert all(len(y.shape) == len(self.shape) and all(y.shape[i] == s for i,s in enumerate(self.shape) if i != dim) for y in args)
451
1060
  catargs = [self, *args]
452
- assert all(t.shape for t in catargs), "zero-dimensional tensor cannot be concatenated"
453
- shapes = [s.shape[dim] for s in catargs]
454
- shape_cumsum = [0, *itertools.accumulate(shapes)]
1061
+ cat_dims = [s.shape[dim] for s in catargs]
1062
+ cat_dim_cumsum = [0, *itertools.accumulate(cat_dims)]
455
1063
  slc:List[List[Optional[Tuple[sint, sint]]]] = [[None for _ in self.shape] for _ in catargs]
456
- for shp,k,s in zip(shapes, shape_cumsum[:-1], slc): s[dim] = (k, shape_cumsum[-1] - k - shp)
457
- return reduce(Tensor.__add__, [arg.pad(tuple(s)) for arg,s in zip(catargs, slc)])
458
-
459
- @staticmethod
460
- def stack(tensors:Sequence[Tensor], dim:int=0) -> Tensor:
461
- unsqueezed_tensors = [tensor.unsqueeze(dim) for tensor in tensors]
1064
+ for d,k,s in zip(cat_dims, cat_dim_cumsum[:-1], slc): s[dim] = (k, cat_dim_cumsum[-1] - k - d)
1065
+ return functools.reduce(Tensor.__add__, [arg.pad(tuple(s)) for arg,s in zip(catargs, slc)])
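`cat` above is padding plus addition: each operand is zero-padded out to the final extent along `dim`, then the padded tensors are summed. A minimal sketch of that composition:

```python
from tinygrad import Tensor

a, b = Tensor([[1, 2]]), Tensor([[3, 4]])
print(a.cat(b, dim=1).numpy())                                     # [[1 2 3 4]]

# the same result written out as pad-then-add
print((a.pad(((0, 0), (0, 2))) + b.pad(((0, 0), (2, 0)))).numpy())
```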
1066
+
1067
+ def stack(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
1068
+ """
1069
+ Concatenates `self` with the other tensors in `args` along a new dimension specified by `dim`.
1070
+
1071
+ ```python exec="true" source="above" session="tensor" result="python"
1072
+ t0, t1, t2 = Tensor([1, 2]), Tensor([3, 4]), Tensor([5, 6])
1073
+ print(t0.stack(t1, t2, dim=0).numpy())
1074
+ ```
1075
+ ```python exec="true" source="above" session="tensor" result="python"
1076
+ print(t0.stack(t1, t2, dim=1).numpy())
1077
+ ```
1078
+ """
462
1079
  # checks for shapes and number of dimensions delegated to cat
463
- return unsqueezed_tensors[0].cat(*unsqueezed_tensors[1:], dim=dim)
464
-
465
- def repeat(self, repeats:Sequence[int]) -> Tensor:
1080
+ return self.unsqueeze(dim).cat(*[t.unsqueeze(dim) for t in args], dim=dim)
1081
+
1082
+ def repeat(self, repeats, *args) -> Tensor:
1083
+ """
1084
+ Repeats the tensor the given number of times along each dimension, as specified by `repeats`.
1085
+ `repeats` can be passed as a tuple or as separate arguments.
1086
+
1087
+ ```python exec="true" source="above" session="tensor" result="python"
1088
+ t = Tensor([1, 2, 3])
1089
+ print(t.repeat(4, 2).numpy())
1090
+ ```
1091
+ ```python exec="true" source="above" session="tensor" result="python"
1092
+ print(t.repeat(4, 2, 1).shape)
1093
+ ```
1094
+ """
1095
+ repeats = argfix(repeats, *args)
466
1096
  base_shape = (1,) * (len(repeats) - self.ndim) + self.shape
467
1097
  new_shape = [x for b in base_shape for x in [1, b]]
468
1098
  expand_shape = [x for rs in zip(repeats, base_shape) for x in rs]
469
1099
  final_shape = [r*s for r,s in zip(repeats, base_shape)]
470
1100
  return self.reshape(new_shape).expand(expand_shape).reshape(final_shape)
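The reshape → expand → reshape pipeline above is easiest to see on a 1-D tensor; a minimal sketch with the shapes written out (illustrative, not part of the source):

```python
from tinygrad import Tensor

t = Tensor([1, 2, 3])
print(t.repeat(2).numpy())                              # [1 2 3 1 2 3]

# the same result spelled out: insert a unit axis, expand it, then collapse
print(t.reshape(1, 3).expand(2, 3).reshape(6).numpy())  # [1 2 3 1 2 3]
```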
471
1101
 
1102
+ def _resolve_dim(self, dim:int, *, outer:bool=False) -> int:
1103
+ if not -max(1, self.ndim+outer) <= dim < max(1, self.ndim+outer):
1104
+ raise IndexError(f"{dim=} out of range {[-max(1, self.ndim+outer), max(1, self.ndim+outer)-1]}")
1105
+ return dim + self.ndim+outer if dim < 0 else dim
1106
+
472
1107
  def split(self, sizes:Union[int, List[int]], dim:int=0) -> Tuple[Tensor, ...]:
1108
+ """
1109
+ Splits the tensor into chunks along the dimension specified by `dim`.
1110
+ If `sizes` is an integer, it splits into equally sized chunks if possible, otherwise the last chunk will be smaller.
1111
+ If `sizes` is a list, it splits into `len(sizes)` chunks with size in `dim` according to `sizes`.
1112
+
1113
+ ```python exec="true" source="above" session="tensor" result="python"
1114
+ t = Tensor.arange(10).reshape(5, 2)
1115
+ print(t.numpy())
1116
+ ```
1117
+ ```python exec="true" source="above" session="tensor" result="python"
1118
+ split = t.split(2)
1119
+ print("\\n".join([repr(x.numpy()) for x in split]))
1120
+ ```
1121
+ ```python exec="true" source="above" session="tensor" result="python"
1122
+ split = t.split([1, 4])
1123
+ print("\\n".join([repr(x.numpy()) for x in split]))
1124
+ ```
1125
+ """
473
1126
  assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
474
- dim = dim + self.ndim if dim < 0 else dim
475
- if isinstance(sizes, int): return tuple(self.chunk(math.ceil(self.shape[dim]/sizes)))
1127
+ dim = self._resolve_dim(dim)
1128
+ if isinstance(sizes, int): sizes = [min(sizes, self.shape[dim]-i) for i in range(0, max(1, self.shape[dim]), max(1, sizes))]
1129
+ assert sum(sizes) == self.shape[dim], f"expect sizes to sum exactly to {self.shape[dim]}, but got {sum(sizes)}"
476
1130
  return tuple(self[sl] for sl in [tuple([slice(None)]*dim + [slice(sum(sizes[:i]), sum(sizes[:i + 1]))]) for i in range(len(sizes))])
477
1131
 
478
- def chunk(self, num:int, dim:int=0) -> List[Tensor]:
1132
+ def chunk(self, chunks:int, dim:int=0) -> List[Tensor]:
1133
+ """
1134
+ Splits the tensor into `chunks` number of chunks along the dimension `dim`.
1135
+ If the tensor size along `dim` is not divisible by `chunks`, all returned chunks will be the same size except the last one.
1136
+ The function may return fewer than the specified number of chunks.
1137
+
1138
+ ```python exec="true" source="above" session="tensor" result="python"
1139
+ chunked = Tensor.arange(11).chunk(6)
1140
+ print("\\n".join([repr(x.numpy()) for x in chunked]))
1141
+ ```
1142
+ ```python exec="true" source="above" session="tensor" result="python"
1143
+ chunked = Tensor.arange(12).chunk(6)
1144
+ print("\\n".join([repr(x.numpy()) for x in chunked]))
1145
+ ```
1146
+ ```python exec="true" source="above" session="tensor" result="python"
1147
+ chunked = Tensor.arange(13).chunk(6)
1148
+ print("\\n".join([repr(x.numpy()) for x in chunked]))
1149
+ ```
1150
+ """
479
1151
  assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
480
- dim, step = dim + self.ndim if dim < 0 else dim, math.ceil(self.shape[dim]/num)
481
- slice_params = [[slice(None)]*dim + [slice(k, k + step)] for k in range(0, self.shape[dim], step)]
482
- return [self[tuple(sl)] for sl in slice_params]
1152
+ assert chunks > 0, f"expect chunks to be greater than 0, got: {chunks}"
1153
+ dim = self._resolve_dim(dim)
1154
+ return list(self.split(math.ceil(self.shape[dim]/chunks) if self.shape[dim] else [0]*chunks, dim=dim))
483
1155
 
484
1156
  def squeeze(self, dim:Optional[int]=None) -> Tensor:
1157
+ """
1158
+ Returns a tensor with the specified dimension of size 1 removed.
1159
+ If `dim` is not specified, all dimensions with size 1 are removed.
1160
+
1161
+ ```python exec="true" source="above" session="tensor" result="python"
1162
+ t = Tensor.zeros(2, 1, 2, 1, 2)
1163
+ print(t.squeeze().shape)
1164
+ ```
1165
+ ```python exec="true" source="above" session="tensor" result="python"
1166
+ print(t.squeeze(0).shape)
1167
+ ```
1168
+ ```python exec="true" source="above" session="tensor" result="python"
1169
+ print(t.squeeze(1).shape)
1170
+ ```
1171
+ """
485
1172
  if dim is None: return self.reshape(tuple(dim for dim in self.shape if dim != 1))
486
- if self.ndim == 0 and dim in [-1, 0]: return self # this is to match torch behavior
487
- if not -self.ndim <= dim <= self.ndim-1: raise IndexError(f"{dim=} out of range {[-self.ndim, self.ndim-1] if self.ndim else [-1, 0]}")
488
- if dim < 0: dim += self.ndim
489
- return self if self.shape[dim] != 1 else self.reshape(self.shape[:dim] + self.shape[dim+1:])
1173
+ dim = self._resolve_dim(dim)
1174
+ return self if not self.ndim or self.shape[dim] != 1 else self.reshape(self.shape[:dim] + self.shape[dim+1:])
490
1175
 
491
1176
  def unsqueeze(self, dim:int) -> Tensor:
492
- if dim < 0: dim = self.ndim + dim + 1
1177
+ """
1178
+ Returns a tensor with a new dimension of size 1 inserted at the specified `dim`.
1179
+
1180
+ ```python exec="true" source="above" session="tensor" result="python"
1181
+ t = Tensor([1, 2, 3, 4])
1182
+ print(t.unsqueeze(0).numpy())
1183
+ ```
1184
+ ```python exec="true" source="above" session="tensor" result="python"
1185
+ print(t.unsqueeze(1).numpy())
1186
+ ```
1187
+ """
1188
+ dim = self._resolve_dim(dim, outer=True)
493
1189
  return self.reshape(self.shape[:dim] + (1,) + self.shape[dim:])
494
1190
 
495
- # (padding_left, padding_right, padding_top, padding_bottom)
496
- def pad2d(self, padding:Sequence[int], value:float=0) -> Tensor:
1191
+ def pad2d(self, padding:Sequence[int], value:float=0.0) -> Tensor:
1192
+ """
1193
+ Returns a tensor that pads the last two axes specified by `padding` (padding_left, padding_right, padding_top, padding_bottom).
1194
+ If `value` is specified, the tensor is padded with `value` instead of `0.0`.
1195
+
1196
+ ```python exec="true" source="above" session="tensor" result="python"
1197
+ t = Tensor.arange(9).reshape(1, 1, 3, 3)
1198
+ print(t.numpy())
1199
+ ```
1200
+ ```python exec="true" source="above" session="tensor" result="python"
1201
+ print(t.pad2d((1, 1, 2, 0), value=-float("inf")).numpy())
1202
+ ```
1203
+ """
497
1204
  slc = [(-p0, s+p1) for p0,p1,s in zip(padding[::2], padding[1::2], self.shape[::-1])][::-1]
498
- return self.slice([(0,s) for s in self.shape[:-(len(padding)//2)]] + slc, value=value)
1205
+ return self._slice([(0,s) for s in self.shape[:-(len(padding)//2)]] + slc, value=value)
499
1206
 
500
1207
  @property
501
- def T(self) -> Tensor: return self.transpose()
502
- def transpose(self, ax1=1, ax2=0) -> Tensor:
1208
+ def T(self) -> Tensor:
1209
+ """`.T` is an alias for `.transpose()`."""
1210
+ return self.transpose()
1211
+
1212
+ def transpose(self, dim0=1, dim1=0) -> Tensor:
1213
+ """
1214
+ Returns a tensor that is a transposed version of the original tensor.
1215
+ The given dimensions `dim0` and `dim1` are swapped.
1216
+
1217
+ ```python exec="true" source="above" session="tensor" result="python"
1218
+ t = Tensor.arange(6).reshape(2, 3)
1219
+ print(t.numpy())
1220
+ ```
1221
+ ```python exec="true" source="above" session="tensor" result="python"
1222
+ print(t.transpose(0, 1).numpy())
1223
+ ```
1224
+ """
503
1225
  order = list(range(self.ndim))
504
- order[ax1], order[ax2] = order[ax2], order[ax1]
1226
+ order[dim0], order[dim1] = order[dim1], order[dim0]
505
1227
  return self.permute(order)
1228
+
506
1229
  def flatten(self, start_dim=0, end_dim=-1):
507
- start_dim, end_dim = start_dim + self.ndim if start_dim < 0 else start_dim, end_dim + self.ndim if end_dim < 0 else end_dim
1230
+ """
1231
+ Flattens the tensor by reshaping it into a one-dimensional tensor.
1232
+ If `start_dim` or `end_dim` are passed, only dimensions starting with `start_dim` and ending with `end_dim` are flattened.
1233
+
1234
+ ```python exec="true" source="above" session="tensor" result="python"
1235
+ t = Tensor.arange(8).reshape(2, 2, 2)
1236
+ print(t.flatten().numpy())
1237
+ ```
1238
+ ```python exec="true" source="above" session="tensor" result="python"
1239
+ print(t.flatten(start_dim=1).numpy())
1240
+ ```
1241
+ """
1242
+ start_dim, end_dim = self._resolve_dim(start_dim), self._resolve_dim(end_dim)
508
1243
  return self.reshape(self.shape[:start_dim] + (prod(self.shape[start_dim:end_dim+1]), ) + self.shape[end_dim+1:])
1244
+
509
1245
  def unflatten(self, dim:int, sizes:Tuple[int,...]):
510
- if dim < 0: dim += self.ndim
1246
+ """
1247
+ Unflattens dimension `dim` of the tensor into multiple dimensions specified by `sizes`. `Tensor.flatten()` is the inverse of this function.
1248
+
1249
+ ```python exec="true" source="above" session="tensor" result="python"
1250
+ print(Tensor.ones(3, 4, 1).unflatten(1, (2, 2)).shape)
1251
+ ```
1252
+ ```python exec="true" source="above" session="tensor" result="python"
1253
+ print(Tensor.ones(3, 4, 1).unflatten(1, (-1, 2)).shape)
1254
+ ```
1255
+ ```python exec="true" source="above" session="tensor" result="python"
1256
+ print(Tensor.ones(5, 12, 3).unflatten(-2, (2, 2, 3, 1, 1)).shape)
1257
+ ```
1258
+ """
1259
+ dim = self._resolve_dim(dim)
511
1260
  return self.reshape(self.shape[:dim] + sizes + self.shape[dim+1:])
512
1261
 
513
1262
  # ***** reduce ops *****
514
1263
 
515
- def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Tuple[int, ...]]]=None, keepdim=False) -> Tensor:
516
- axis_: List[int] = list(range(len(self.shape))) if axis is None else ([axis] if isinstance(axis, int) else list(axis))
517
- axis_ = [x if x >= 0 else x+len(self.shape) for x in axis_]
518
- shape = tuple(s for i,s in enumerate(self.shape) if i not in axis_)
519
- if 0 in self.shape and 0 not in shape:
520
- return Tensor.full(tuple(1 if s == 0 else s for s in self.shape) if keepdim else shape, {mlops.Sum: 0.0, mlops.Max: -float("inf")}[fxn])
521
- ret = fxn.apply(self, new_shape=tuple([1 if i in axis_ else s for i,s in enumerate(self.shape)]))
522
- return ret if keepdim else ret.reshape(shape=shape)
523
-
524
- def sum(self, axis=None, keepdim=False):
525
- acc_dtype = least_upper_dtype(self.dtype, dtypes.uint) if dtypes.is_unsigned(self.dtype) else \
526
- least_upper_dtype(self.dtype, dtypes.int) if (dtypes.is_int(self.dtype) or self.dtype==dtypes.bool) else \
527
- least_upper_dtype(self.dtype, dtypes.float)
528
- # cast back to float16 or bfloat16 to match torch / jax behavior, but we use float for acc
529
- output_dtype = self.dtype if self.dtype in (dtypes.float16, dtypes.bfloat16) else acc_dtype
530
- return self.cast(acc_dtype)._reduce(mlops.Sum, axis, keepdim).cast(output_dtype)
531
-
532
- def max(self, axis=None, keepdim=False): return self._reduce(mlops.Max, axis, keepdim)
533
- def min(self, axis=None, keepdim=False): return -((-self).max(axis=axis, keepdim=keepdim))
534
-
535
- def mean(self, axis=None, keepdim=False):
536
- assert all_int(self.shape), "does not support symbolic shape"
537
- out = self.sum(axis=axis, keepdim=keepdim)
538
- return out.mul(prod(out.shape)/prod(self.shape)) if 0 not in self.shape else out
539
- def std(self, axis=None, keepdim=False, correction=1):
540
- assert all_int(self.shape), "does not support symbolic shape"
541
- square_sum = ((self - self.mean(axis=axis, keepdim=True)).square()).sum(axis=axis, keepdim=keepdim)
542
- return square_sum.div(prod(self.shape)/prod(square_sum.shape)-correction).sqrt()
1264
+ def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False) -> Tensor:
1265
+ if self.ndim == 0:
1266
+ if axis is not None and any(a not in [-1, 0] for a in fully_flatten([axis])): raise IndexError(f"{axis=} out of range of [-1, 0]")
1267
+ axis = ()
1268
+ axis_: Tuple[int, ...] = tuple(range(len(self.shape))) if axis is None else ((axis,) if isinstance(axis, int) else tuple(axis))
1269
+ axis_ = tuple(self._resolve_dim(x) for x in axis_)
1270
+ ret = fxn.apply(self, axis=axis_)
1271
+ return ret if keepdim else ret.reshape(tuple(s for i,s in enumerate(self.shape) if i not in axis_))
1272
+
1273
+ def sum(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, acc_dtype:Optional[DType]=None):
1274
+ """
1275
+ Sums the elements of the tensor along the specified axis or axes.
1276
+
1277
+ You can pass in `axis` and `keepdim` keyword arguments to control the axis along
1278
+ which the sum is computed and whether the reduced dimensions are retained.
1279
+
1280
+ You can pass in `acc_dtype` keyword argument to control the data type of the accumulation.
1281
+ If not specified, the accumulation data type is chosen based on the input tensor's data type.
1282
+
1283
+ ```python exec="true" source="above" session="tensor" result="python"
1284
+ t = Tensor.arange(6).reshape(2, 3)
1285
+ print(t.numpy())
1286
+ ```
1287
+ ```python exec="true" source="above" session="tensor" result="python"
1288
+ print(t.sum().numpy())
1289
+ ```
1290
+ ```python exec="true" source="above" session="tensor" result="python"
1291
+ print(t.sum(axis=0).numpy())
1292
+ ```
1293
+ ```python exec="true" source="above" session="tensor" result="python"
1294
+ print(t.sum(axis=1).numpy())
1295
+ ```
1296
+ """
1297
+ ret = self.cast(acc_dtype or sum_acc_dtype(self.dtype))._reduce(F.Sum, axis, keepdim)
1298
+ return ret.cast(self.dtype) if acc_dtype is None and self.dtype in (dtypes.float16, dtypes.bfloat16) else ret
1299
+
1300
+ def max(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
1301
+ """
1302
+ Returns the maximum value of the tensor along the specified axis or axes.
1303
+
1304
+ You can pass in `axis` and `keepdim` keyword arguments to control the axis along
1305
+ which the maximum is computed and whether the reduced dimensions are retained.
1306
+
1307
+ ```python exec="true" source="above" session="tensor" result="python"
1308
+ t = Tensor([[1, 0, 2], [5, 4, 3]])
1309
+ print(t.numpy())
1310
+ ```
1311
+ ```python exec="true" source="above" session="tensor" result="python"
1312
+ print(t.max().numpy())
1313
+ ```
1314
+ ```python exec="true" source="above" session="tensor" result="python"
1315
+ print(t.max(axis=0).numpy())
1316
+ ```
1317
+ ```python exec="true" source="above" session="tensor" result="python"
1318
+ print(t.max(axis=1, keepdim=True).numpy())
1319
+ ```
1320
+ """
1321
+ return self._reduce(F.Max, axis, keepdim)
1322
+
1323
+ def min(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
1324
+ """
1325
+ Returns the minimum value of the tensor along the specified axis or axes.
1326
+
1327
+ You can pass in `axis` and `keepdim` keyword arguments to control the axis along
1328
+ which the minimum is computed and whether the reduced dimensions are retained.
1329
+
1330
+ ```python exec="true" source="above" session="tensor" result="python"
1331
+ t = Tensor([[1, 0, 2], [5, 4, 3]])
1332
+ print(t.numpy())
1333
+ ```
1334
+ ```python exec="true" source="above" session="tensor" result="python"
1335
+ print(t.min().numpy())
1336
+ ```
1337
+ ```python exec="true" source="above" session="tensor" result="python"
1338
+ print(t.min(axis=0).numpy())
1339
+ ```
1340
+ ```python exec="true" source="above" session="tensor" result="python"
1341
+ print(t.min(axis=1, keepdim=True).numpy())
1342
+ ```
1343
+ """
1344
+ return -((-self).max(axis=axis, keepdim=keepdim))
1345
+
1346
+ def mean(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
1347
+ """
1348
+ Returns the mean value of the tensor along the specified axis or axes.
1349
+
1350
+ You can pass in `axis` and `keepdim` keyword arguments to control the axis along
1351
+ which the mean is computed and whether the reduced dimensions are retained.
1352
+
1353
+ ```python exec="true" source="above" session="tensor" result="python"
1354
+ Tensor.manual_seed(42)
1355
+ t = Tensor.normal(2, 3, mean=2.5, std=0.5)
1356
+ print(t.numpy())
1357
+ ```
1358
+ ```python exec="true" source="above" session="tensor" result="python"
1359
+ print(t.mean().numpy())
1360
+ ```
1361
+ ```python exec="true" source="above" session="tensor" result="python"
1362
+ print(t.mean(axis=0).numpy())
1363
+ ```
1364
+ ```python exec="true" source="above" session="tensor" result="python"
1365
+ print(t.mean(axis=1).numpy())
1366
+ ```
1367
+ """
1368
+ output_dtype = self.dtype if dtypes.is_float(self.dtype) else dtypes.float32
1369
+ numerator = self.cast(sum_acc_dtype(self.dtype)).sum(axis=axis, keepdim=keepdim)
1370
+ return numerator.div(prod([si for si, so in zip(self.shape, self.sum(axis=axis, keepdim=True).shape) if si != so])).cast(output_dtype)
1371
+
1372
+ def var(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, correction=1):
1373
+ """
1374
+ Returns the variance of the tensor along the specified axis or axes.
1375
+
1376
+ You can pass in `axis`, `keepdim`, and `correction` keyword arguments to control the axis along
1377
+ which the variance is computed, whether the reduced dimensions are retained, and the Bessel's correction applied.
1378
+
1379
+ ```python exec="true" source="above" session="tensor" result="python"
1380
+ Tensor.manual_seed(42)
1381
+ t = Tensor.normal(2, 3, mean=2.5, std=0.5)
1382
+ print(t.numpy())
1383
+ ```
1384
+ ```python exec="true" source="above" session="tensor" result="python"
1385
+ print(t.var().numpy())
1386
+ ```
1387
+ ```python exec="true" source="above" session="tensor" result="python"
1388
+ print(t.var(axis=0).numpy())
1389
+ ```
1390
+ ```python exec="true" source="above" session="tensor" result="python"
1391
+ print(t.var(axis=1).numpy())
1392
+ ```
1393
+ """
1394
+ squares = (self - self.mean(axis=axis, keepdim=True)).square()
1395
+ n = prod([si for si, so in zip(self.shape, squares.sum(axis=axis, keepdim=True).shape) if si != so])
1396
+ return squares.sum(axis=axis, keepdim=keepdim).div(max(0, n-correction))
1397
+
1398
+ def std(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, correction=1):
1399
+ """
1400
+ Returns the standard deviation of the tensor along the specified axis or axes.
1401
+
1402
+ You can pass in `axis`, `keepdim`, and `correction` keyword arguments to control the axis along
1403
+ which the standard deviation is computed, whether the reduced dimensions are retained, and the Bessel's correction applied.
1404
+
1405
+ ```python exec="true" source="above" session="tensor" result="python"
1406
+ Tensor.manual_seed(42)
1407
+ t = Tensor.normal(2, 3, mean=2.5, std=0.5)
1408
+ print(t.numpy())
1409
+ ```
1410
+ ```python exec="true" source="above" session="tensor" result="python"
1411
+ print(t.std().numpy())
1412
+ ```
1413
+ ```python exec="true" source="above" session="tensor" result="python"
1414
+ print(t.std(axis=0).numpy())
1415
+ ```
1416
+ ```python exec="true" source="above" session="tensor" result="python"
1417
+ print(t.std(axis=1).numpy())
1418
+ ```
1419
+ """
1420
+ return self.var(axis, keepdim, correction).sqrt()
1421
+
543
1422
  def _softmax(self, axis):
544
1423
  m = self - self.max(axis=axis, keepdim=True)
545
1424
  e = m.exp()
546
1425
  return m, e, e.sum(axis=axis, keepdim=True)
547
1426
 
548
1427
  def softmax(self, axis=-1):
1428
+ """
1429
+ Applies the softmax function to the tensor along the specified axis.
1430
+
1431
+ Rescales the elements of the tensor such that they lie in the range [0, 1] and sum to 1.
1432
+
1433
+ You can pass in the `axis` keyword argument to control the axis along which the softmax is computed.
1434
+
1435
+ ```python exec="true" source="above" session="tensor" result="python"
1436
+ Tensor.manual_seed(42)
1437
+ t = Tensor.randn(2, 3)
1438
+ print(t.numpy())
1439
+ ```
1440
+ ```python exec="true" source="above" session="tensor" result="python"
1441
+ print(t.softmax().numpy())
1442
+ ```
1443
+ ```python exec="true" source="above" session="tensor" result="python"
1444
+ print(t.softmax(axis=0).numpy())
1445
+ ```
1446
+ """
549
1447
  _, e, ss = self._softmax(axis)
550
1448
  return e.div(ss)
551
1449
 
552
1450
  def log_softmax(self, axis=-1):
1451
+ """
1452
+ Applies the log-softmax function to the tensor along the specified axis.
1453
+
1454
+ The log-softmax function is a numerically stable alternative to the softmax function in log space.
1455
+
1456
+ You can pass in the `axis` keyword argument to control the axis along which the log-softmax is computed.
1457
+
1458
+ ```python exec="true" source="above" session="tensor" result="python"
1459
+ Tensor.manual_seed(42)
1460
+ t = Tensor.randn(2, 3)
1461
+ print(t.numpy())
1462
+ ```
1463
+ ```python exec="true" source="above" session="tensor" result="python"
1464
+ print(t.log_softmax().numpy())
1465
+ ```
1466
+ ```python exec="true" source="above" session="tensor" result="python"
1467
+ print(t.log_softmax(axis=0).numpy())
1468
+ ```
1469
+ """
553
1470
  m, _, ss = self._softmax(axis)
554
1471
  return m - ss.log()
555
1472
 
1473
+ def logsumexp(self, axis=None, keepdim=False):
1474
+ """
1475
+ Computes the log-sum-exp of the tensor along the specified axis or axes.
1476
+
1477
+ The log-sum-exp function is a numerically stable way to compute the logarithm of the sum of exponentials.
1478
+
1479
+ You can pass in `axis` and `keepdim` keyword arguments to control the axis along
1480
+ which the log-sum-exp is computed and whether the reduced dimensions are retained.
1481
+
1482
+ ```python exec="true" source="above" session="tensor" result="python"
1483
+ Tensor.manual_seed(42)
1484
+ t = Tensor.randn(2, 3)
1485
+ print(t.numpy())
1486
+ ```
1487
+ ```python exec="true" source="above" session="tensor" result="python"
1488
+ print(t.logsumexp().numpy())
1489
+ ```
1490
+ ```python exec="true" source="above" session="tensor" result="python"
1491
+ print(t.logsumexp(axis=0).numpy())
1492
+ ```
1493
+ ```python exec="true" source="above" session="tensor" result="python"
1494
+ print(t.logsumexp(axis=1).numpy())
1495
+ ```
1496
+ """
1497
+ m = self.max(axis=axis, keepdim=True)
1498
+ return (self - m).exp().sum(axis=axis, keepdim=keepdim).log() + m.squeeze(axis)
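The max-subtraction above is what provides the numerical stability; a short sketch contrasting it with the naive formula (assuming float32, where the naive sum of exponentials overflows):

```python
from tinygrad import Tensor

t = Tensor([1000.0, 1000.0])
print(t.exp().sum().log().numpy())   # naive sum-exp overflows to inf in float32
print(t.logsumexp().numpy())         # stable: 1000 + log(2) ~= 1000.69
```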
1499
+
556
1500
  def argmax(self, axis=None, keepdim=False):
557
- if axis is None:
558
- idx = (self == self.max(axis)) * Tensor.arange(prod(self.shape)-1,-1,-1, requires_grad=False, device=self.device).reshape(self.shape)
559
- return (prod(self.shape) - idx.max() - 1).cast(dtypes.default_int)
560
- axis = axis + len(self.shape) if axis < 0 else axis
1501
+ """
1502
+ Returns the indices of the maximum value of the tensor along the specified axis.
1503
+
1504
+ You can pass in `axis` and `keepdim` keyword arguments to control the axis along
1505
+ which the maximum is computed and whether the reduced dimensions are retained.
1506
+
1507
+ ```python exec="true" source="above" session="tensor" result="python"
1508
+ t = Tensor([[1, 0, 2], [5, 4, 3]])
1509
+ print(t.numpy())
1510
+ ```
1511
+ ```python exec="true" source="above" session="tensor" result="python"
1512
+ print(t.argmax().numpy()) # Returns the index of the maximum value in the flattened tensor.
1513
+ ```
1514
+ ```python exec="true" source="above" session="tensor" result="python"
1515
+ print(t.argmax(axis=0).numpy()) # Returns the indices of the maximum values along axis 0.
1516
+ ```
1517
+ ```python exec="true" source="above" session="tensor" result="python"
1518
+ print(t.argmax(axis=1).numpy()) # Returns the indices of the maximum values along axis 1.
1519
+ ```
1520
+ """
1521
+ if axis is None: return self.flatten().argmax(0)
1522
+ axis = self._resolve_dim(axis)
561
1523
  m = self == self.max(axis=axis, keepdim=True)
562
1524
  idx = m * Tensor.arange(self.shape[axis]-1,-1,-1, requires_grad=False, device=self.device).reshape(self.shape[axis], *[1]*(self.ndim-axis-1))
563
- return (self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)-1).cast(dtypes.default_int)
564
- def argmin(self, axis=None, keepdim=False): return (-self).argmax(axis=axis, keepdim=keepdim)
1525
+ return (self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)-1).cast(dtypes.int32)
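The reversed-arange trick above picks out the index of the first maximum; a minimal sketch reproducing it by hand on a 1-D tensor (illustrative only):

```python
from tinygrad import Tensor

t = Tensor([3, 1, 3, 2])
print(t.argmax().numpy())   # 0, the first occurrence of the maximum

# by hand: mark the maxima, weight them with a reversed arange, then undo the offset
m = t == t.max()
print((t.shape[0] - (m * Tensor.arange(3, -1, -1)).max() - 1).numpy())   # 0
```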
1526
+
1527
+ def argmin(self, axis=None, keepdim=False):
1528
+ """
1529
+ Returns the indices of the minimum value of the tensor along the specified axis.
1530
+
1531
+ You can pass in `axis` and `keepdim` keyword arguments to control the axis along
1532
+ which the minimum is computed and whether the reduced dimensions are retained.
1533
+
1534
+ ```python exec="true" source="above" session="tensor" result="python"
1535
+ t = Tensor([[1, 0, 2], [5, 4, 3]])
1536
+ print(t.numpy())
1537
+ ```
1538
+ ```python exec="true" source="above" session="tensor" result="python"
1539
+ print(t.argmin().numpy()) # Returns the index of the minimum value in the flattened tensor.
1540
+ ```
1541
+ ```python exec="true" source="above" session="tensor" result="python"
1542
+ print(t.argmin(axis=0).numpy()) # Returns the indices of the minimum values along axis 0.
1543
+ ```
1544
+ ```python exec="true" source="above" session="tensor" result="python"
1545
+ print(t.argmin(axis=1).numpy()) # Returns the indices of the minimum values along axis 1.
1546
+ ```
1547
+ """
1548
+ return (-self).argmax(axis=axis, keepdim=keepdim)
565
1549
 
566
1550
  @staticmethod
567
- def einsum(formula:str, *raw_xs) -> Tensor:
1551
+ def einsum(formula:str, *raw_xs, acc_dtype:Optional[DType]=None) -> Tensor:
1552
+ """
1553
+ Sums the product of the elements of the input tensors according to a formula based on the Einstein summation convention.
1554
+
1555
+ See: https://pytorch.org/docs/stable/generated/torch.einsum.html
1556
+
1557
+ ```python exec="true" source="above" session="tensor" result="python"
1558
+ x = Tensor([[1, 2], [3, 4]])
1559
+ y = Tensor([[5, 6], [7, 8]])
1560
+ print(Tensor.einsum("ij,ij->", x, y).numpy())
1561
+ ```
1562
+ """
568
1563
  xs:Tuple[Tensor] = argfix(*raw_xs)
569
1564
  formula = formula.replace(" ", "")
570
- inputs_str, output = formula.split("->") if "->" in formula else (formula, sorted(formula))
571
- inputs = [x for x in cast(str,inputs_str).split(',')]
1565
+ inputs_str, output = formula.split("->") if "->" in formula else (formula, \
1566
+ ''.join(c for c in sorted(formula) if formula.count(c) == 1 and c.isalpha()))
1567
+ inputs = inputs_str.split(',')
572
1568
  assert len(xs) == len(inputs), f"number of inputs doesn't match number of operands in formula, expected {len(inputs)}, got {len(xs)}"
573
1569
 
574
1570
  # map the value of each letter in the formula
575
- letter_val = sorted(merge_dicts([{letter:dim for letter, dim in zip(letters, tensor.shape)} for letters, tensor in zip(inputs, xs)]).items())
1571
+ letter_val = sorted(merge_dicts([dict(zip(letters, tensor.shape)) for letters, tensor in zip(inputs, xs)]).items())
576
1572
 
577
1573
  xs_:List[Tensor] = []
578
1574
  lhs = [sorted(enumerate(s), key=lambda e:e[1]) for s in inputs]
@@ -580,9 +1576,13 @@ class Tensor:
580
1576
  # permute to the sorted letter order, then reshape/expand to create dimensions for the missing letters
581
1577
  xs_.append(x.permute(order).reshape([val if letter in letters else 1 for letter,val in letter_val]).expand([val for _,val in letter_val]))
582
1578
 
583
- rhs_order, rhs_letters = tuple(zip(*sorted(enumerate(output), key=lambda e:e[1]))) or ([], [])
1579
+ # determine the inverse permutation to revert back to original order
1580
+ rhs_letter_order = argsort(list(output))
1581
+ rhs_order = argsort(rhs_letter_order)
1582
+
584
1583
  # sum over all axes that's not in the output, then permute to the output order
585
- return reduce(lambda a,b:a*b, xs_).sum(axis=[axis for axis,(letter,_) in enumerate(letter_val) if letter not in rhs_letters]).permute(rhs_order)
1584
+ return functools.reduce(lambda a,b:a*b, xs_) \
1585
+ .sum(axis=[axis for axis,(letter,_) in enumerate(letter_val) if letter not in output],acc_dtype=acc_dtype).permute(rhs_order)
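Beyond the elementwise example in the docstring, the same reshape/expand/sum machinery covers contractions such as matrix multiplication; a short sketch:

```python
from tinygrad import Tensor

x = Tensor([[1, 2], [3, 4]])
y = Tensor([[5, 6], [7, 8]])
print(Tensor.einsum("ij,jk->ik", x, y).numpy())   # [[19 22] [43 50]]
print(x.matmul(y).numpy())                        # same contraction via matmul
```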
586
1586
 
587
1587
  # ***** processing ops *****
588
1588
 
@@ -590,46 +1590,72 @@ class Tensor:
590
1590
  assert len(self.shape) >= len(k_), f"can't pool {self.shape} with {k_}"
591
1591
  assert all_int(self.shape) and all_int(k_), f"does not support symbolic {self.shape=}, {k_=}"
592
1592
  s_, d_ = make_pair(stride, len(k_)), make_pair(dilation, len(k_))
593
- assert len(k_) == len(s_) and len(k_) == len(d_), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}"
594
- slc_prefix, prefix, i_ = [(0,x) for x in self.shape[0:-len(k_)]], self.shape[0:-len(k_)], self.shape[-len(k_):]
1593
+ assert len(k_) == len(s_) == len(d_), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}"
1594
+ noop_, i_ = [None] * len(self.shape[:-len(k_)]), self.shape[-len(k_):]
1595
+ o_ = [math.ceil((i - d * (k-1))/s) for i,d,k,s in zip(i_, d_, k_, s_)]
595
1596
  if any(k > s for k,s in zip(k_, s_)) or any(d != 1 for d in d_):
596
- o_ = [(i - d * (k-1) - 1)//s + 1 for i,d,k,s in zip(i_, d_, k_, s_)]
597
- e_ = [math.ceil(k*(i+d) / i) for k,i,d in zip(k_, i_, d_)] # expands such that we don't need padding
598
- xup = self.reshape(*prefix, *flatten((1,i) for i in i_)).expand(*prefix, *flatten((e,i) for e,i in zip(e_, i_))).reshape(*prefix, *[e*i for e,i in zip(e_, i_)]) # noqa: E501
599
- # slide by dilation
600
- xup = xup.slice(slc_prefix + [(0,k*(i+d)) for k,i,d in zip(k_, i_, d_)])
601
- xup = xup.reshape(*prefix, *flatten((k,i+d) for k,i,d in zip(k_, i_, d_)))
602
- xup = xup.slice(slc_prefix + flatten(((0,k), (0,o*s)) for k,o,s in zip(k_, o_, s_)))
603
- # handle stride, and permute to move reduce to the end
604
- xup = xup.reshape(*prefix, *flatten((k,o,s) for k,o,s in zip(k_, o_, s_)))
605
- xup = xup.slice(slc_prefix + flatten(((0,k), (0,o), (0,1)) for k,o in zip(k_, o_)))
606
- xup = xup.reshape(*prefix, *flatten((k,o) for k,o in zip(k_, o_)))
607
- return xup.permute(*range(len(prefix)), *[len(prefix)+i*2+1 for i in range(len(k_))], *[len(prefix)+i*2 for i in range(len(k_))])
608
- # TODO: once the shapetracker can optimize well, remove this alternative implementation. or not if the CPU implementation doesn't use ShapeTracker
609
- o_ = [(i+(s-k))//s for i,s,k in zip(i_, s_, k_)]
610
- xup = self.slice(slc_prefix + [(0,o*s) for o,s in zip(o_, s_)])
611
- xup = xup.reshape(*prefix, *flatten(((o, s) for o,s in zip(o_, s_))))
612
- xup = xup.slice(slc_prefix + flatten(((0,o), (0,k)) for o,k in zip(o_, k_)))
613
- return xup.permute(*range(len(prefix)), *[len(prefix)+i*2 for i in range(len(k_))], *[len(prefix)+i*2+1 for i in range(len(k_))])
1597
+ # repeats such that we don't need padding
1598
+ xup = self.repeat([1]*len(noop_) + [math.ceil(k*(i+d) / i) for k,i,d in zip(k_, i_, d_)])
1599
+ # handle dilation
1600
+ xup = xup.shrink(tuple(noop_ + [(0,k*(i+d)) for k,i,d in zip(k_, i_, d_)])).reshape(noop_ + flatten((k,i+d) for k,i,d in zip(k_, i_, d_)))
1601
+ # handle stride
1602
+ xup = xup.shrink(noop_ + flatten(((0,k), (0,o*s)) for k,o,s in zip(k_, o_, s_))).reshape(noop_ + flatten((k,o,s) for k,o,s in zip(k_, o_, s_)))
1603
+ xup = xup.shrink(noop_ + flatten(((0,k), (0,o), (0,1)) for k,o in zip(k_, o_))).reshape(noop_ + flatten((k,o) for k,o in zip(k_, o_)))
1604
+ # permute to move reduce to the end
1605
+ return xup.permute(*range(len(noop_)), *[len(noop_)+i*2+1 for i in range(len(i_))], *[len(noop_)+i*2 for i in range(len(i_))])
1606
+ # TODO: once the shapetracker can optimize well, remove this alternative implementation
1607
+ xup = self.pad(tuple(noop_ + [(0, max(0,o*s-i)) for i,o,s in zip(i_, o_, s_)])).shrink(tuple(noop_ + [(0,o*s) for o,s in zip(o_, s_)]))
1608
+ xup = xup.reshape(noop_ + flatten(((o,s) for o,s in zip(o_, s_))))
1609
+ xup = xup.shrink(noop_ + flatten(((0,o), (0,k)) for o,k in zip(o_, k_)))
1610
+ return xup.permute(*range(len(noop_)), *[len(noop_)+i*2 for i in range(len(i_))], *[len(noop_)+i*2+1 for i in range(len(i_))])
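The first branch above (kernel larger than stride, or dilation) is what allows overlapping windows; it can be exercised through the public pooling API. A small sketch, with the expected means noted in the comment:

```python
from tinygrad import Tensor

t = Tensor.arange(9).reshape(1, 1, 3, 3)
# kernel 2 with stride 1 overlaps windows, taking the k > s path in _pool
print(t.avg_pool2d(kernel_size=2, stride=1).numpy())   # [[[[2. 3.] [5. 6.]]]]
```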
614
1611
 
615
1612
  # NOTE: these work for more than 2D
616
- def avg_pool2d(self, kernel_size=(2,2), stride=None, dilation=1): return self._pool(make_pair(kernel_size), stride if stride is not None else kernel_size, dilation).mean(axis=tuple(range(0-len(make_pair(kernel_size)), 0))) # noqa: E501
617
- def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1): return self._pool(make_pair(kernel_size), stride if stride is not None else kernel_size, dilation).max(axis=tuple(range(0-len(make_pair(kernel_size)), 0))) # noqa: E501
1613
+ def avg_pool2d(self, kernel_size=(2,2), stride=None, dilation=1):
1614
+ """
1615
+ Applies average pooling over a tensor.
618
1616
 
619
- def conv_transpose2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, output_padding=0) -> Tensor:
620
- HW, trailing = weight.shape[2:], list(range(3, len(weight.shape)+1))
621
- x, w = self, weight.unflatten(0, (groups, -1)).permute(0,2,1,*trailing).flip(trailing)
622
- stride = make_pair(stride, len(HW))
623
- if any(s>1 for s in stride):
624
- x = x.reshape(None, None, *flatten((k,1) for k in x.shape[2:]))
625
- x = x.pad((None, None, *flatten((None,(0,s-1)) for s in stride)))
626
- x = x.reshape(None, None, *[k*s for k,s in zip(x.shape[2::2], stride)])
627
- x = x.shrink((None, None, *[(0,k-(s-1)) for k,s in zip(x.shape[2:], stride)]))
628
- padding = flatten((((k-1)*d-p,(k-1)*d-p+op) for k,d,p,op in reversed(list(zip(HW, make_pair(dilation, len(HW)), make_pair(padding, len(HW)), make_pair(output_padding, len(HW))))))) # noqa: E501
629
- return x.conv2d(w.flatten(end_dim=1), groups=groups, bias=bias, dilation=dilation, padding=padding)
1617
+ NOTE: unlike PyTorch, this implementation is not limited to only 2d pooling and instead works for any number of dimensions.
1618
+
1619
+ See: https://paperswithcode.com/method/average-pooling
1620
+
1621
+ ```python exec="true" source="above" session="tensor" result="python"
1622
+ t = Tensor.arange(25).reshape(1, 1, 5, 5)
1623
+ print(t.avg_pool2d().numpy())
1624
+ ```
1625
+ """
1626
+ kernel_size = make_pair(kernel_size)
1627
+ return self._pool(kernel_size, stride if stride is not None else kernel_size, dilation).mean(axis=tuple(range(-len(kernel_size), 0)))
1628
+
1629
+ def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1):
1630
+ """
1631
+ Applies max pooling over a tensor.
630
1632
 
631
- wino = getenv("WINO", 0)
632
- def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0) -> Tensor:
1633
+ NOTE: unlike PyTorch, this implementation is not limited to only 2d pooling and instead works for any number of dimensions.
1634
+
1635
+ See: https://paperswithcode.com/method/max-pooling
1636
+
1637
+ ```python exec="true" source="above" session="tensor" result="python"
1638
+ t = Tensor.arange(25).reshape(1, 1, 5, 5)
1639
+ print(t.max_pool2d().numpy())
1640
+ ```
1641
+ """
1642
+ kernel_size = make_pair(kernel_size)
1643
+ return self._pool(kernel_size, stride if stride is not None else kernel_size, dilation).max(axis=tuple(range(-len(kernel_size), 0)))
1644
+
1645
+ def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, acc_dtype:Optional[DType]=None) -> Tensor:
1646
+ """
1647
+ Applies a convolution over a tensor with a given `weight` and optional `bias`.
1648
+
1649
+ NOTE: unlike PyTorch, this implementation is not limited to only 2d convolutions and instead works for any number of dimensions.
1650
+
1651
+ See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
1652
+
1653
+ ```python exec="true" source="above" session="tensor" result="python"
1654
+ t = Tensor.arange(9).reshape(1, 1, 3, 3)
1655
+ w = Tensor.ones(1, 1, 2, 2)
1656
+ print(t.conv2d(w).numpy())
1657
+ ```
1658
+ """
633
1659
  (bs,cin_), (cout,cin), HW = self.shape[:2], weight.shape[:2], weight.shape[2:]
634
1660
  assert groups*cin == cin_ and len(self.shape) == len(weight.shape), f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})" # noqa: E501
635
1661
  if isinstance(padding, (tuple,list)): assert len(padding) == 2*len(HW) or len(padding) == len(HW), f"Expected padding of length {2*len(HW)} or {len(HW)}, but got {len(padding)} for tensor of shape {self.shape}" # noqa: E501
@@ -638,19 +1664,17 @@ class Tensor:
638
1664
  # conv2d is a pooling op (with padding)
639
1665
  x = self.pad2d(padding_)._pool(HW, stride, dilation) # (bs, groups*cin, oy, ox, H, W)
640
1666
  rcout, oyx = cout//groups, x.shape[2:-len(HW)]
641
- if not all(x == 3 for x in HW) or stride != 1 or dilation != 1 or not Tensor.wino:
1667
+ if not all(x == 3 for x in HW) or stride != 1 or dilation != 1 or not WINO:
642
1668
  # normal conv
643
1669
  x = x.reshape(bs, groups, cin, 1, *oyx, *HW).expand(bs, groups, cin, rcout, *oyx, *HW).permute(0,1,3,*[4+i for i in range(len(oyx))],2,*[4+len(oyx)+i for i in range(len(HW))]) # noqa: E501
644
1670
 
645
1671
  # conv! broadcasted to (bs, groups, rcout, *oyx, cin, *HW)
646
- ret = (x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW)).sum([-1-i for i in range(1+len(oyx))], keepdim=True).reshape(bs, cout, *oyx) # noqa: E501
1672
+ ret = (x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW)).sum([-1-i for i in range(1+len(oyx))], keepdim=True, acc_dtype=acc_dtype).reshape(bs, cout, *oyx) # noqa: E501
647
1673
  return ret if bias is None else ret.add(bias.reshape(1, -1, *[1] * len(HW)))
648
1674
 
649
- # winograd conv 3 kernel f(4x4,3x3) see: http://arxiv.org/abs/1509.09308
650
- def apply_matrix(mat, t, dim=0): return t if dim == len(HW) else Tensor.stack([apply_matrix(mat, sum(mm*t[j] for j,mm in enumerate(m) if mm), dim=dim+1) for m in mat]) # noqa: E501
651
1675
  HWI, HWO = (6,) * len(HW), (4,) * len(HW) # F(4x4,3x3) winograd tiles
652
- winograd_Bt = [[4, 0, -5, 0, 1, 0], [0, -4, -4, 1, 1, 0], [0, 4, -4, -1, 1, 0], [0, -2, -1, 2, 1, 0], [0, 2, -1, -2, 1, 0], [0, 4, 0, -5, 0, 1]]
653
1676
  winograd_G = [[1/4, 0, 0], [-1/6, -1/6, -1/6], [-1/6, 1/6, -1/6], [1/24, 1/12, 1/6], [1/24, -1/12, 1/6], [0, 0, 1]]
1677
+ winograd_Bt = [[4, 0, -5, 0, 1, 0], [0, -4, -4, 1, 1, 0], [0, 4, -4, -1, 1, 0], [0, -2, -1, 2, 1, 0], [0, 2, -1, -2, 1, 0], [0, 4, 0, -5, 0, 1]]
654
1678
  winograd_At = [[1, 1, 1, 1, 1, 0], [0, 1, -1, 2, -2, 0], [0, 1, 1, 4, 4, 0], [0, 1, -1, 8, -8, 1]] # applying At in pre-order doubles compile time
655
1679
 
656
1680
  # todo: stride == dilation
@@ -658,19 +1682,19 @@ class Tensor:
658
1682
  # (bs, cin_, tyx, HWI)
659
1683
  d = self.pad2d(sum([[padding_[i*2], padding_[i*2+1] + (-(dim + sum(padding_[i * 2:(i + 1) * 2]) - 2) % 4)] for i, dim in enumerate(self.shape[-len(HW):])], []))._pool(HWI, HWO) # noqa: E501
660
1684
  # move HW to the front: # (HWI, bs, cin_, tyx)
661
- d = d.permute(*range(len(d.shape)-len(HW),len(d.shape)), *range(len(d.shape)-len(HW))).contiguous_backward()
1685
+ d = d.permute(*range(len(d.shape)-len(HW),len(d.shape)), *range(len(d.shape)-len(HW)))
662
1686
  tyx = d.shape[-len(HWI):] # dim of tiling
663
1687
 
664
1688
  g = weight.permute(*range(len(weight.shape)-len(HW),len(weight.shape)), *range(len(weight.shape)-len(HW))) # move HW to the front
665
1689
 
666
1690
  # compute 6x6 winograd tiles: GgGt, BtdB
667
1691
  # (HWI, groups * rcout, cin) -> (HWI, bs=1, groups, rcout, cin, tyx=(1,1))
668
- gfactors = apply_matrix(winograd_G, g).contiguous().reshape(*HWI, 1, groups, rcout, cin, *([1]*len(tyx)))
1692
+ gfactors = _apply_winograd_matrix(winograd_G, g, len(HW)).reshape(*HWI, 1, groups, rcout, cin, *([1]*len(tyx)))
669
1693
  # (HWI, bs, cin_, tyx) -> (HWI, bs, groups, 1 ,cin, *tyx)
670
- dfactors = apply_matrix(winograd_Bt, d).contiguous().reshape(*HWI, bs, groups, 1, cin, *tyx)
1694
+ dfactors = _apply_winograd_matrix(winograd_Bt, d, len(HW)).reshape(*HWI, bs, groups, 1, cin, *tyx)
671
1695
 
672
1696
  # matmul; sum across cin: (HWI, bs, groups, rcout, *tyx); then HWI -> HWO: (HWO, bs, groups, rcout, *tyx)
673
- ret = apply_matrix(winograd_At, (gfactors * dfactors).sum(axis=-1-len(HW)))
1697
+ ret = _apply_winograd_matrix(winograd_At, (gfactors * dfactors).sum(axis=-1-len(HW), acc_dtype=acc_dtype), len(HW))
674
1698
 
675
1699
  # interleave tyx and HWO: (bs, groups, rcout, oy, HO, ox, WO)
676
1700
  ret = ret.permute([*range(len(HW), len(ret.shape)-len(HW)), *[i+o for i in range(len(HW)) for o in [len(ret.shape)-len(HW),0]]])
@@ -679,163 +1703,880 @@ class Tensor:
679
1703
 
680
1704
  return (ret if bias is None else ret.add(bias.reshape(1, -1, *[1 for _ in range(len(HW))]))).contiguous().contiguous_backward()
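The `_apply_winograd_matrix` path above relies on the F(4x4,3x3) Winograd identity (the arXiv:1509.09308 reference in the removed comment). A standalone numpy sanity check of the 1-D building block F(4,3), using the `winograd_G`, `winograd_Bt` and `winograd_At` matrices from this diff (a sketch for illustration, not the tinygrad code path):

```python
# Check the 1-D F(4,3) identity: the four "valid" outputs of a 3-tap filter
# equal At @ ((G @ g) * (Bt @ d)).  Illustrative only.
import numpy as np

G  = np.array([[1/4, 0, 0], [-1/6, -1/6, -1/6], [-1/6, 1/6, -1/6],
               [1/24, 1/12, 1/6], [1/24, -1/12, 1/6], [0, 0, 1]])
Bt = np.array([[4, 0, -5, 0, 1, 0], [0, -4, -4, 1, 1, 0], [0, 4, -4, -1, 1, 0],
               [0, -2, -1, 2, 1, 0], [0, 2, -1, -2, 1, 0], [0, 4, 0, -5, 0, 1]], dtype=float)
At = np.array([[1, 1, 1, 1, 1, 0], [0, 1, -1, 2, -2, 0],
               [0, 1, 1, 4, 4, 0], [0, 1, -1, 8, -8, 1]], dtype=float)

rng = np.random.default_rng(0)
d, g = rng.standard_normal(6), rng.standard_normal(3)          # input tile, 3-tap filter
winograd = At @ ((G @ g) * (Bt @ d))                           # 6 multiplies in the transform domain
direct   = np.array([np.dot(d[i:i+3], g) for i in range(4)])   # 4*3 = 12 multiplies, computed directly
assert np.allclose(winograd, direct)
```

In 2-D the same transforms are applied along each spatial axis, which is what `_apply_winograd_matrix(..., len(HW))` does; per 4x4 output tile this replaces 144 multiply-accumulates with 36 elementwise multiplies in the transform domain.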
681
1705
 
682
- def dot(self, w:Tensor) -> Tensor:
1706
+ def conv_transpose2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, output_padding=0) -> Tensor:
1707
+ """
1708
+ Applies a transposed convolution over a tensor with a given `weight` and optional `bias`.
1709
+
1710
+ NOTE: unlike PyTorch, this implementation is not limited to only 2d transposed convolutions and instead works for any number of dimensions.
1711
+
1712
+ See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
1713
+
1714
+ ```python exec="true" source="above" session="tensor" result="python"
1715
+ t = Tensor.arange(9).reshape(1, 1, 3, 3)
1716
+ w = Tensor.ones(1, 1, 2, 2)
1717
+ print(t.conv_transpose2d(w).numpy())
1718
+ ```
1719
+ """
1720
+ x, w = self, weight.unflatten(0, (groups, -1)).transpose(1, 2).flip(*range(3, len(weight.shape)+1))
1721
+ HW = weight.shape[2:]
1722
+ stride, dilation, padding, output_padding = [make_pair(x, len(HW)) for x in (stride, dilation, padding, output_padding)]
1723
+ if any(s>1 for s in stride):
1724
+ # handle strides: (k) -> reshape -> (k,1) -> pad -> (k,s) -> reshape -> (k*s) -> shrink (k-(s-1))
1725
+ x = x.reshape(None, None, *flatten((k,1) for k in x.shape[2:]))
1726
+ x = x.pad((None, None, *flatten((None,(0,s-1)) for s in stride)))
1727
+ x = x.reshape(None, None, *[k*s for k,s in zip(x.shape[2::2], stride)])
1728
+ x = x.shrink((None, None, *[(0,k-(s-1)) for k,s in zip(x.shape[2:], stride)]))
1729
+ padding = flatten((((k-1)*d-p,(k-1)*d-p+op) for k,d,p,op in reversed(list(zip(HW, dilation, padding, output_padding)))))
1730
+ return x.conv2d(w.flatten(end_dim=1), groups=groups, bias=bias, dilation=dilation, padding=padding)
1731
+
1732
+ def dot(self, w:Tensor, acc_dtype:Optional[DType]=None) -> Tensor:
1733
+ """
1734
+ Performs dot product between two tensors.
1735
+
1736
+ You can pass in the optional `acc_dtype` keyword argument to control the data type of the accumulation.
1737
+
1738
+ ```python exec="true" source="above" session="tensor" result="python"
1739
+ a = Tensor([[1, 2], [3, 4]])
1740
+ b = Tensor([[5, 6], [7, 8]])
1741
+ print(a.dot(b).numpy())
1742
+ ```
1743
+ """
683
1744
  n1, n2 = len(self.shape), len(w.shape)
684
1745
  assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D"
685
- assert self.shape[-1] == w.shape[-min(n2, 2)], f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({self.shape[-1]} != {w.shape[-min(n2, 2)]})" # noqa: E501
1746
+ assert (L:=self.shape[-1]) == (R:=w.shape[-min(n2, 2)]), f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({L} != {R})"
686
1747
  x = self.reshape(*self.shape[0:-1], *[1]*min(n1-1, n2-1, 1), self.shape[-1])
687
1748
  w = w.reshape(*w.shape[0:-2], *[1]*min(n1-1, n2-1, 1), *w.shape[-min(n2, 2):]).transpose(-1, -min(n2, 2))
688
- return (x*w).sum(-1).cast(least_upper_dtype(x.dtype, w.dtype))
1749
+ return (x*w).sum(-1, acc_dtype=acc_dtype).cast(least_upper_dtype(x.dtype, w.dtype) if acc_dtype is None else acc_dtype)
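As the code shows, `dot` is expressed with movement ops plus a reduce: the operands are reshaped so the shared dimension lines up, multiplied with broadcasting, and summed over the last axis with the requested `acc_dtype`. A rough numpy analogue of that reduction for the plain 2-D case (a sketch, not the exact reshape logic above):

```python
# A 2-D matmul as a broadcasted elementwise product plus a sum over the
# shared dimension (illustrative only).
import numpy as np

a = np.arange(6).reshape(2, 3).astype(np.float32)           # (2, 3)
b = np.arange(12).reshape(3, 4).astype(np.float32)          # (3, 4)
via_broadcast = (a[:, None, :] * b.T[None, :, :]).sum(-1)   # (2,1,3) * (1,4,3) -> sum over 3
assert np.allclose(via_broadcast, a @ b)
```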
1750
+
1751
+ def matmul(self, x:Tensor, reverse=False, acc_dtype:Optional[DType]=None) -> Tensor:
1752
+ """
1753
+ Performs matrix multiplication between two tensors.
689
1754
 
690
- def matmul(self, x:Tensor, reverse=False) -> Tensor: return x.dot(self) if reverse else self.dot(x)
1755
+ You can pass in the `reverse` keyword argument to control the order of the matrix multiplication.
1756
+ You can pass in the optional `acc_dtype` keyword argument to control the data type of the accumulation.
1757
+
1758
+ ```python exec="true" source="above" session="tensor" result="python"
1759
+ a = Tensor([[1, 2], [3, 4]])
1760
+ b = Tensor([[5, 6], [7, 8]])
1761
+ print(a.matmul(b).numpy())
1762
+ ```
1763
+ """
1764
+ return x.dot(self, acc_dtype=acc_dtype) if reverse else self.dot(x, acc_dtype=acc_dtype)
691
1765
 
692
1766
  def _cumsum(self, axis:int=0, _first_zero=False) -> Tensor:
693
- return self.transpose(axis,-1).pad2d((self.shape[axis]-int(not _first_zero),0))._pool((self.shape[axis],)).sum(-1).transpose(axis,-1)
1767
+ assert self.shape[axis] != 0
1768
+ pl_sz = self.shape[axis] - int(not _first_zero)
1769
+ return self.transpose(axis,-1).pad2d((pl_sz,-int(_first_zero)))._pool((self.shape[axis],)).sum(-1).transpose(axis,-1)
694
1770
  def cumsum(self, axis:int=0) -> Tensor:
1771
+ """
1772
+ Computes the cumulative sum of the tensor along the specified axis.
1773
+
1774
+ You can pass in the `axis` keyword argument to control the axis along which the cumulative sum is computed.
1775
+
1776
+ ```python exec="true" source="above" session="tensor" result="python"
1777
+ t = Tensor.ones(2, 3)
1778
+ print(t.numpy())
1779
+ ```
1780
+ ```python exec="true" source="above" session="tensor" result="python"
1781
+ print(t.cumsum(1).numpy())
1782
+ ```
1783
+ """
1784
+ axis = self._resolve_dim(axis)
1785
+ if self.ndim == 0 or 0 in self.shape: return self
695
1786
  # TODO: someday the optimizer will find this on its own
696
1787
  # for now this is a two stage cumsum
697
1788
  SPLIT = 256
698
1789
  if self.shape[axis] <= SPLIT*2: return self._cumsum(axis)
699
1790
  ret = self.transpose(axis,-1).pad2d((round_up(self.shape[axis], SPLIT)-self.shape[axis], 0))
700
1791
  ret = ret.unflatten(-1, (-1, SPLIT))._cumsum(-1)
701
- base_add = ret[..., -1]._cumsum(-1, _first_zero=True)[..., :-1]
1792
+ base_add = ret[..., -1]._cumsum(-1, _first_zero=True)
702
1793
  base_add = base_add.unsqueeze(-1).expand(*base_add.shape, ret.shape[-1])
703
1794
  def fix(x:Tensor): return x.flatten(start_dim=-2)[..., -self.shape[axis]:].transpose(axis,-1)
704
1795
  return fix(ret) + fix(base_add)
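The comment above describes a two-stage scheme: a cumulative sum inside each `SPLIT`-sized block, plus a cumulative sum of the block totals that is broadcast back as a per-block carry. A small numpy sketch of the idea (for brevity it assumes the length divides evenly; the tinygrad code pads to `round_up(n, SPLIT)` instead):

```python
# Two-stage (blocked) cumulative sum, sketch only.
import numpy as np

def blocked_cumsum(x: np.ndarray, split: int = 4) -> np.ndarray:
    blocks = x.reshape(-1, split)
    local = blocks.cumsum(axis=-1)                              # stage 1: cumsum inside each block
    carry = np.concatenate([[0], local[:, -1].cumsum()[:-1]])   # stage 2: running total of earlier blocks
    return (local + carry[:, None]).reshape(-1)

x = np.arange(1, 13)
assert np.array_equal(blocked_cumsum(x), x.cumsum())
```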
705
1796
 
706
1797
  @staticmethod
707
- def _tri(r:sint, c:sint, k:int=0, **kwargs) -> Tensor:
708
- assert all_int((r,c)), "does not support symbolic"
709
- return Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(-k, c-k, **kwargs).unsqueeze(0).expand(r,c)
710
- def triu(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k, device=self.device).where(self, 0)
711
- def tril(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, device=self.device).where(0, self)
712
-
713
- # ***** mlops (unary) *****
714
-
715
- def neg(self): return mlops.Neg.apply(self)
716
- def logical_not(self): return self.neg() if self.dtype == dtypes.bool else (1.0-self)
717
- def contiguous(self): return mlops.Contiguous.apply(self)
718
- def contiguous_backward(self): return mlops.ContiguousBackward.apply(self)
719
- def log(self): return mlops.Log.apply(self.cast(least_upper_float(self.dtype)))
720
- def log2(self): return self.log()/math.log(2)
721
- def exp(self): return mlops.Exp.apply(self.cast(least_upper_float(self.dtype)))
722
- def exp2(self): return mlops.Exp.apply(self*math.log(2))
723
- def relu(self): return mlops.Relu.apply(self)
724
- def sigmoid(self): return mlops.Sigmoid.apply(self.cast(least_upper_float(self.dtype)))
725
- def sin(self): return mlops.Sin.apply(self.cast(least_upper_float(self.dtype)))
726
- def sqrt(self): return mlops.Sqrt.apply(self.cast(least_upper_float(self.dtype)))
727
- def rsqrt(self): return self.reciprocal().sqrt()
728
- def cos(self): return ((math.pi/2)-self).sin()
729
- def tan(self): return self.sin() / self.cos()
730
-
731
- # ***** math functions (unary) *****
732
-
733
- def trunc(self: Tensor) -> Tensor: return self.cast(dtypes.int32).cast(self.dtype)
734
- def ceil(self: Tensor) -> Tensor: return (self > (b := self.trunc())).where(b+1, b)
735
- def floor(self: Tensor) -> Tensor: return (self < (b := self.trunc())).where(b-1, b)
736
-
737
- def square(self): return self*self
738
- def clip(self, min_, max_): return self.maximum(min_).minimum(max_)
739
- def abs(self): return self.relu() + (-self).relu()
740
- def sign(self): return ((self.float()) / (self.float().abs() + 1e-12)).cast(self.dtype)
741
- def reciprocal(self): return 1.0/self
742
-
743
- # ***** activation functions (unary) *****
744
-
745
- def elu(self, alpha=1.0): return self.relu() - alpha*(1-self.exp()).relu()
746
- def celu(self, alpha=1.0): return self.maximum(0) + (alpha * ((self / alpha).exp() - 1)).minimum(0)
747
- def swish(self): return self * self.sigmoid()
748
- def silu(self): return self.swish() # The SiLU function is also known as the swish function.
749
- def relu6(self): return self.relu() - (self-6).relu()
750
- def hardswish(self): return self * (self+3).relu6() * (1/6)
751
- def tanh(self): return 2.0 * ((2.0 * self).sigmoid()) - 1.0
752
- def sinh(self): return (self.exp() - self.neg().exp()) / 2
753
- def cosh(self): return (self.exp() + self.neg().exp()) / 2
754
- def atanh(self): return ((1 + self)/(1 - self)).log() / 2
755
- def asinh(self): return (self + (self.square() + 1).sqrt()).log()
756
- def acosh(self): return (self + (self.square() - 1).sqrt()).log()
757
- def hardtanh(self, min_val=-1, max_val=1): return self.clip(min_val, max_val)
758
- def gelu(self): return 0.5 * self * (1 + (self * 0.7978845608 * (1 + 0.044715 * self * self)).tanh())
759
- def quick_gelu(self): return self * (self * 1.702).sigmoid()
760
- def leakyrelu(self, neg_slope=0.01): return self.relu() - (-neg_slope*self).relu()
761
- def mish(self): return self * self.softplus().tanh()
762
- def softplus(self, beta=1): return (1/beta) * (1 + (self*beta).exp()).log()
763
- def softsign(self): return self / (1 + self.abs())
764
-
765
- # ***** broadcasted elementwise mlops *****
766
-
767
- def _broadcasted(self, y:Union[Tensor, Scalar], reverse:bool=False, match_dtype:bool=True) -> Tuple[Tensor, Tensor]:
1798
+ def _tri(r:sint, c:sint, diagonal:int=0, **kwargs) -> Tensor:
1799
+ assert isinstance(r, int) and isinstance(c, int), f"does not support symbolic, getting {r=}, {c=}"
1800
+ if r == 0 or c == 0 or diagonal >= c: return Tensor.zeros(r,c,**kwargs)
1801
+ if r+diagonal <= 0: return Tensor.ones(r,c,**kwargs)
1802
+ s = r+c-1
1803
+ # build a (s, s) upper triangle
1804
+ t = Tensor.ones(s,s,**kwargs).pad((None,(0,s))).flatten().shrink(((0,s*(2*s-1)),)).reshape(s,-1).shrink((None,(0,s)))
1805
+ return t[:r,-diagonal:c-diagonal] if diagonal <= 0 else t[diagonal:r+diagonal,:c]
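The rewritten `_tri` builds the triangle from movement ops alone: padding each row of a `ones(s, s)` block with `s` zeros and re-wrapping the flat buffer at width `2*s - 1` shifts every row one column to the right. A numpy re-creation of the trick (illustration only):

```python
# Build an upper triangle by padding, flattening, and re-wrapping a ones block.
import numpy as np

s = 4
t = np.pad(np.ones((s, s), dtype=int), ((0, 0), (0, s)))        # (s, 2s): each row is s ones then s zeros
t = t.reshape(-1)[:s * (2 * s - 1)].reshape(s, 2 * s - 1)[:, :s]  # re-wrap at width 2s-1, keep first s cols
assert np.array_equal(t, np.triu(np.ones((s, s), dtype=int)))
```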
1806
+
1807
+ def triu(self, diagonal:int=0) -> Tensor:
1808
+ """
1809
+ Returns the upper triangular part of the tensor, the other elements are set to 0.
1810
+
1811
+ The argument `diagonal` determines which diagonal is on the boundary. `diagonal = 0` means the main diagonal.
1812
+ Positive `diagonal` means above the main diagonal, and negative `diagonal` means below the main diagonal.
1813
+
1814
+ ```python exec="true" source="above" session="tensor" result="python"
1815
+ t = Tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
1816
+ print(t.numpy())
1817
+ ```
1818
+ ```python exec="true" source="above" session="tensor" result="python"
1819
+ print(t.triu(diagonal=0).numpy())
1820
+ ```
1821
+ ```python exec="true" source="above" session="tensor" result="python"
1822
+ print(t.triu(diagonal=1).numpy())
1823
+ ```
1824
+ ```python exec="true" source="above" session="tensor" result="python"
1825
+ print(t.triu(diagonal=-1).numpy())
1826
+ ```
1827
+ """
1828
+ return Tensor._tri(self.shape[-2], self.shape[-1], diagonal=diagonal, device=self.device, dtype=dtypes.bool).where(self, 0).cast(self.dtype)
1829
+
1830
+ def tril(self, diagonal:int=0) -> Tensor:
1831
+ """
1832
+ Returns the lower triangular part of the tensor, the other elements are set to 0.
1833
+
1834
+ The argument `diagonal` determines which diagonal is on the boundary. `diagonal = 0` means the main diagonal.
1835
+ Positive `diagonal` means above the main diagonal, and negative `diagonal` means below the main diagonal.
1836
+
1837
+ ```python exec="true" source="above" session="tensor" result="python"
1838
+ t = Tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
1839
+ print(t.numpy())
1840
+ ```
1841
+ ```python exec="true" source="above" session="tensor" result="python"
1842
+ print(t.tril(diagonal=0).numpy())
1843
+ ```
1844
+ ```python exec="true" source="above" session="tensor" result="python"
1845
+ print(t.tril(diagonal=1).numpy())
1846
+ ```
1847
+ ```python exec="true" source="above" session="tensor" result="python"
1848
+ print(t.tril(diagonal=-1).numpy())
1849
+ ```
1850
+ """
1851
+ return Tensor._tri(self.shape[-2], self.shape[-1], diagonal=diagonal+1, device=self.device, dtype=dtypes.bool).where(0, self).cast(self.dtype)
1852
+
1853
+ # ***** unary ops *****
1854
+
1855
+ def logical_not(self):
1856
+ """
1857
+ Computes the logical NOT of the tensor element-wise.
1858
+
1859
+ ```python exec="true" source="above" session="tensor" result="python"
1860
+ print(Tensor([False, True]).logical_not().numpy())
1861
+ ```
1862
+ """
1863
+ return F.Neq.apply(*self.cast(dtypes.bool)._broadcasted(True))
1864
+ def neg(self):
1865
+ """
1866
+ Negates the tensor element-wise.
1867
+
1868
+ ```python exec="true" source="above" session="tensor" result="python"
1869
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).neg().numpy())
1870
+ ```
1871
+ """
1872
+ return F.Neg.apply(self) if self.dtype != dtypes.bool else self.logical_not()
1873
+ def contiguous(self):
1874
+ """
1875
+ Returns a contiguous tensor.
1876
+ """
1877
+ return F.Contiguous.apply(self)
1878
+ def contiguous_backward(self):
1879
+ """
1880
+ Inserts a contiguous operation in the backward pass.
1881
+ """
1882
+ return F.ContiguousBackward.apply(self)
1883
+ def log(self):
1884
+ """
1885
+ Computes the natural logarithm element-wise.
1886
+
1887
+ See: https://en.wikipedia.org/wiki/Logarithm
1888
+
1889
+ ```python exec="true" source="above" session="tensor" result="python"
1890
+ print(Tensor([1., 2., 4., 8.]).log().numpy())
1891
+ ```
1892
+ """
1893
+ return F.Log.apply(self.cast(least_upper_float(self.dtype)))
1894
+ def log2(self):
1895
+ """
1896
+ Computes the base-2 logarithm element-wise.
1897
+
1898
+ See: https://en.wikipedia.org/wiki/Logarithm
1899
+
1900
+ ```python exec="true" source="above" session="tensor" result="python"
1901
+ print(Tensor([1., 2., 4., 8.]).log2().numpy())
1902
+ ```
1903
+ """
1904
+ return self.log()/math.log(2)
1905
+ def exp(self):
1906
+ """
1907
+ Computes the exponential function element-wise.
1908
+
1909
+ See: https://en.wikipedia.org/wiki/Exponential_function
1910
+
1911
+ ```python exec="true" source="above" session="tensor" result="python"
1912
+ print(Tensor([0., 1., 2., 3.]).exp().numpy())
1913
+ ```
1914
+ """
1915
+ return F.Exp.apply(self.cast(least_upper_float(self.dtype)))
1916
+ def exp2(self):
1917
+ """
1918
+ Computes the base-2 exponential function element-wise.
1919
+
1920
+ See: https://en.wikipedia.org/wiki/Exponential_function
1921
+
1922
+ ```python exec="true" source="above" session="tensor" result="python"
1923
+ print(Tensor([0., 1., 2., 3.]).exp2().numpy())
1924
+ ```
1925
+ """
1926
+ return F.Exp.apply(self*math.log(2))
1927
+ def relu(self):
1928
+ """
1929
+ Applies the Rectified Linear Unit (ReLU) function element-wise.
1930
+
1931
+ - Described: https://paperswithcode.com/method/relu
1932
+
1933
+ ```python exec="true" source="above" session="tensor" result="python"
1934
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).relu().numpy())
1935
+ ```
1936
+ """
1937
+ return F.Relu.apply(self)
1938
+ def sigmoid(self):
1939
+ """
1940
+ Applies the Sigmoid function element-wise.
1941
+
1942
+ - Described: https://en.wikipedia.org/wiki/Sigmoid_function
1943
+
1944
+ ```python exec="true" source="above" session="tensor" result="python"
1945
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sigmoid().numpy())
1946
+ ```
1947
+ """
1948
+ return F.Sigmoid.apply(self.cast(least_upper_float(self.dtype)))
1949
+ def sqrt(self):
1950
+ """
1951
+ Computes the square root of the tensor element-wise.
1952
+
1953
+ ```python exec="true" source="above" session="tensor" result="python"
1954
+ print(Tensor([1., 2., 3., 4.]).sqrt().numpy())
1955
+ ```
1956
+ """
1957
+ return F.Sqrt.apply(self.cast(least_upper_float(self.dtype)))
1958
+ def rsqrt(self):
1959
+ """
1960
+ Computes the reciprocal of the square root of the tensor element-wise.
1961
+
1962
+ ```python exec="true" source="above" session="tensor" result="python"
1963
+ print(Tensor([1., 2., 3., 4.]).rsqrt().numpy())
1964
+ ```
1965
+ """
1966
+ return self.reciprocal().sqrt()
1967
+ def sin(self):
1968
+ """
1969
+ Computes the sine of the tensor element-wise.
1970
+
1971
+ ```python exec="true" source="above" session="tensor" result="python"
1972
+ print(Tensor([0., math.pi/2, math.pi, 3*math.pi/2, 2*math.pi]).sin().numpy())
1973
+ ```
1974
+ """
1975
+ return F.Sin.apply(self.cast(least_upper_float(self.dtype)))
1976
+ def cos(self):
1977
+ """
1978
+ Computes the cosine of the tensor element-wise.
1979
+
1980
+ ```python exec="true" source="above" session="tensor" result="python"
1981
+ print(Tensor([0., math.pi/2, math.pi, 3*math.pi/2, 2*math.pi]).cos().numpy())
1982
+ ```
1983
+ """
1984
+ return ((math.pi/2)-self).sin()
1985
+ def tan(self):
1986
+ """
1987
+ Computes the tangent of the tensor element-wise.
1988
+
1989
+ ```python exec="true" source="above" session="tensor" result="python"
1990
+ print(Tensor([0., math.pi/4, math.pi/2, 3*math.pi/4, math.pi]).tan().numpy())
1991
+ ```
1992
+ """
1993
+ return self.sin() / self.cos()
1994
+
1995
+ # ***** math functions *****
1996
+
1997
+ def trunc(self: Tensor) -> Tensor:
1998
+ """
1999
+ Truncates the tensor element-wise.
2000
+
2001
+ ```python exec="true" source="above" session="tensor" result="python"
2002
+ print(Tensor([-3.9, -2.1, -1.5, 0.5, 1.5, 2.1, 3.9]).trunc().numpy())
2003
+ ```
2004
+ """
2005
+ return self.cast(dtypes.int32).cast(self.dtype)
2006
+ def ceil(self: Tensor) -> Tensor:
2007
+ """
2008
+ Rounds the tensor element-wise towards positive infinity.
2009
+
2010
+ ```python exec="true" source="above" session="tensor" result="python"
2011
+ print(Tensor([-3.9, -2.1, -1.5, 0.5, 1.5, 2.1, 3.9]).ceil().numpy())
2012
+ ```
2013
+ """
2014
+ return (self > (b := self.trunc())).where(b+1, b)
2015
+ def floor(self: Tensor) -> Tensor:
2016
+ """
2017
+ Rounds the tensor element-wise towards negative infinity.
2018
+
2019
+ ```python exec="true" source="above" session="tensor" result="python"
2020
+ print(Tensor([-3.9, -2.1, -1.5, 0.5, 1.5, 2.1, 3.9]).floor().numpy())
2021
+ ```
2022
+ """
2023
+ return (self < (b := self.trunc())).where(b-1, b)
2024
+ def round(self: Tensor) -> Tensor:
2025
+ """
2026
+ Rounds the tensor element-wise.
2027
+
2028
+ ```python exec="true" source="above" session="tensor" result="python"
2029
+ print(Tensor([-3.9, -2.1, -1.5, 0.5, 1.5, 2.1, 3.9]).round().numpy())
2030
+ ```
2031
+ """
2032
+ return ((self > 0) == ((b := self.cast(dtypes.int32) / 2.0).cast(dtypes.int32) == b)).where((self - 0.5).ceil(), (self + 0.5).floor())
2033
+ def lerp(self, end: Tensor, weight: Union[Tensor, float]) -> Tensor:
2034
+ """
2035
+ Linearly interpolates between `self` and `end` by `weight`.
2036
+
2037
+ ```python exec="true" source="above" session="tensor" result="python"
2038
+ print(Tensor([1., 2., 3.]).lerp(Tensor([4., 5., 6.]), 0.5).numpy())
2039
+ ```
2040
+ """
2041
+ return self + (end - self) * weight
2042
+ def square(self):
2043
+ """
2044
+ Squares the tensor element-wise.
2045
+ Equivalent to `self*self`.
2046
+
2047
+ ```python exec="true" source="above" session="tensor" result="python"
2048
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).square().numpy())
2049
+ ```
2050
+ """
2051
+ return self*self
2052
+ def clip(self, min_, max_):
2053
+ """
2054
+ Clips (clamps) the values in the tensor between `min_` and `max_` element-wise.
2055
+
2056
+ ```python exec="true" source="above" session="tensor" result="python"
2057
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).clip(-1, 1).numpy())
2058
+ ```
2059
+ """
2060
+ return self.maximum(min_).minimum(max_)
2061
+ def sign(self):
2062
+ """
2063
+ Returns the sign of the tensor element-wise.
2064
+
2065
+ ```python exec="true" source="above" session="tensor" result="python"
2066
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sign().numpy())
2067
+ ```
2068
+ """
2069
+ return F.Sign.apply(self)
2070
+ def abs(self):
2071
+ """
2072
+ Computes the absolute value of the tensor element-wise.
2073
+
2074
+ ```python exec="true" source="above" session="tensor" result="python"
2075
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).abs().numpy())
2076
+ ```
2077
+ """
2078
+ return self * self.sign()
2079
+ def reciprocal(self):
2080
+ """
2081
+ Compute `1/x` element-wise.
2082
+
2083
+ ```python exec="true" source="above" session="tensor" result="python"
2084
+ print(Tensor([1., 2., 3., 4.]).reciprocal().numpy())
2085
+ ```
2086
+ """
2087
+ return F.Reciprocal.apply(self.cast(least_upper_float(self.dtype)))
2088
+
2089
+ # ***** activation functions *****
2090
+
2091
+ def elu(self, alpha=1.0):
2092
+ """
2093
+ Applies the Exponential Linear Unit (ELU) function element-wise.
2094
+
2095
+ - Described: https://paperswithcode.com/method/elu
2096
+ - Paper: https://arxiv.org/abs/1511.07289v5
2097
+
2098
+ ```python exec="true" source="above" session="tensor" result="python"
2099
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).elu().numpy())
2100
+ ```
2101
+ """
2102
+ return self.relu() - alpha*(1-self.exp()).relu()
2103
+
2104
+ def celu(self, alpha=1.0):
2105
+ """
2106
+ Applies the Continuously differentiable Exponential Linear Unit (CELU) function element-wise.
2107
+
2108
+ - Described: https://paperswithcode.com/method/celu
2109
+ - Paper: https://arxiv.org/abs/1704.07483
2110
+
2111
+ ```python exec="true" source="above" session="tensor" result="python"
2112
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).celu().numpy())
2113
+ ```
2114
+ """
2115
+ return self.maximum(0) + (alpha * ((self / alpha).exp() - 1)).minimum(0)
2116
+
2117
+ def swish(self):
2118
+ """
2119
+ See `.silu()`
2120
+
2121
+ - Paper: https://arxiv.org/abs/1710.05941v1
2122
+
2123
+ ```python exec="true" source="above" session="tensor" result="python"
2124
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).swish().numpy())
2125
+ ```
2126
+ """
2127
+ return self * self.sigmoid()
2128
+
2129
+ def silu(self):
2130
+ """
2131
+ Applies the Sigmoid Linear Unit (SiLU) function element-wise.
2132
+
2133
+ - Described: https://paperswithcode.com/method/silu
2134
+ - Paper: https://arxiv.org/abs/1606.08415
2135
+
2136
+ ```python exec="true" source="above" session="tensor" result="python"
2137
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).silu().numpy())
2138
+ ```
2139
+ """
2140
+ return self.swish() # The SiLU function is also known as the swish function.
2141
+
2142
+ def relu6(self):
2143
+ """
2144
+ Applies the ReLU6 function element-wise.
2145
+
2146
+ - Described: https://paperswithcode.com/method/relu6
2147
+ - Paper: https://arxiv.org/abs/1704.04861v1
2148
+
2149
+ ```python exec="true" source="above" session="tensor" result="python"
2150
+ print(Tensor([-9., -6., -3., 0., 3., 6., 9.]).relu6().numpy())
2151
+ ```
2152
+ """
2153
+ return self.relu() - (self-6).relu()
2154
+
2155
+ def hardswish(self):
2156
+ """
2157
+ Applies the Hardswish function element-wise.
2158
+
2159
+ - Described: https://paperswithcode.com/method/hard-swish
2160
+ - Paper: https://arxiv.org/abs/1905.02244v5
2161
+
2162
+ ```python exec="true" source="above" session="tensor" result="python"
2163
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).hardswish().numpy())
2164
+ ```
2165
+ """
2166
+ return self * (self+3).relu6() * (1/6)
2167
+
2168
+ def tanh(self):
2169
+ """
2170
+ Applies the Hyperbolic Tangent (tanh) function element-wise.
2171
+
2172
+ - Described: https://en.wikipedia.org/wiki/Hyperbolic_functions#Tanh
2173
+
2174
+ ```python exec="true" source="above" session="tensor" result="python"
2175
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).tanh().numpy())
2176
+ ```
2177
+ """
2178
+ return 2.0 * ((2.0 * self).sigmoid()) - 1.0
2179
+
2180
+ def sinh(self):
2181
+ """
2182
+ Applies the Hyperbolic Sine (sinh) function element-wise.
2183
+
2184
+ - Described: https://en.wikipedia.org/wiki/Hyperbolic_functions#Sinh
2185
+
2186
+ ```python exec="true" source="above" session="tensor" result="python"
2187
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sinh().numpy())
2188
+ ```
2189
+ """
2190
+ return (self.exp() - self.neg().exp()) / 2
2191
+
2192
+ def cosh(self):
2193
+ """
2194
+ Applies the Hyperbolic Cosine (cosh) function element-wise.
2195
+
2196
+ - Described: https://en.wikipedia.org/wiki/Hyperbolic_functions#Cosh
2197
+
2198
+ ```python exec="true" source="above" session="tensor" result="python"
2199
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).cosh().numpy())
2200
+ ```
2201
+ """
2202
+ return (self.exp() + self.neg().exp()) / 2
2203
+
2204
+ def atanh(self):
2205
+ """
2206
+ Applies the Inverse Hyperbolic Tangent (atanh) function element-wise.
2207
+
2208
+ - Described: https://en.wikipedia.org/wiki/Inverse_hyperbolic_functions#atanh
2209
+
2210
+ ```python exec="true" source="above" session="tensor" result="python"
2211
+ print(Tensor([-0.9, -0.6, -0.3, 0., 0.3, 0.6, 0.9]).atanh().numpy())
2212
+ ```
2213
+ """
2214
+ return ((1 + self)/(1 - self)).log() / 2
2215
+
2216
+ def asinh(self):
2217
+ """
2218
+ Applies the Inverse Hyperbolic Sine (asinh) function element-wise.
2219
+
2220
+ - Described: https://en.wikipedia.org/wiki/Inverse_hyperbolic_functions#asinh
2221
+
2222
+ ```python exec="true" source="above" session="tensor" result="python"
2223
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).asinh().numpy())
2224
+ ```
2225
+ """
2226
+ return (self + (self.square() + 1).sqrt()).log()
2227
+
2228
+ def acosh(self):
2229
+ """
2230
+ Applies the Inverse Hyperbolic Cosine (acosh) function element-wise.
2231
+
2232
+ - Described: https://en.wikipedia.org/wiki/Inverse_hyperbolic_functions#acosh
2233
+
2234
+ ```python exec="true" source="above" session="tensor" result="python"
2235
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).acosh().numpy())
2236
+ ```
2237
+ """
2238
+ return (self + (self.square() - 1).sqrt()).log()
2239
+
2240
+ def hardtanh(self, min_val=-1, max_val=1):
2241
+ """
2242
+ Applies the Hardtanh function element-wise.
2243
+
2244
+ - Described: https://paperswithcode.com/method/hardtanh-activation
2245
+
2246
+ ```python exec="true" source="above" session="tensor" result="python"
2247
+ print(Tensor([-1.5, -1.0, -0.5, 0., 0.5, 1.0, 1.5]).hardtanh().numpy())
2248
+ ```
2249
+ """
2250
+ return self.clip(min_val, max_val)
2251
+
2252
+ def gelu(self):
2253
+ """
2254
+ Applies the Gaussian Error Linear Unit (GELU) function element-wise.
2255
+
2256
+ - Described: https://paperswithcode.com/method/gelu
2257
+ - Paper: https://arxiv.org/abs/1606.08415v5
2258
+
2259
+ ```python exec="true" source="above" session="tensor" result="python"
2260
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).gelu().numpy())
2261
+ ```
2262
+ """
2263
+ return 0.5 * self * (1 + (math.sqrt(2 / math.pi) * (self + 0.044715 * self ** 3)).tanh())
2264
+
2265
+ def quick_gelu(self):
2266
+ """
2267
+ Applies the Sigmoid GELU approximation element-wise.
2268
+
2269
+ - Described: https://paperswithcode.com/method/gelu
2270
+
2271
+ ```python exec="true" source="above" session="tensor" result="python"
2272
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).quick_gelu().numpy())
2273
+ ```
2274
+ """
2275
+ return self * (self * 1.702).sigmoid()
2276
+
2277
+ def leakyrelu(self, neg_slope=0.01):
2278
+ """
2279
+ Applies the Leaky ReLU function element-wise.
2280
+
2281
+ - Described: https://paperswithcode.com/method/leaky-relu
2282
+
2283
+ ```python exec="true" source="above" session="tensor" result="python"
2284
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).leakyrelu().numpy())
2285
+ ```
2286
+ ```python exec="true" source="above" session="tensor" result="python"
2287
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).leakyrelu(neg_slope=0.42).numpy())
2288
+ ```
2289
+ """
2290
+ return self.relu() - (-neg_slope*self).relu()
2291
+
2292
+ def mish(self):
2293
+ """
2294
+ Applies the Mish function element-wise.
2295
+
2296
+ - Described: https://paperswithcode.com/method/mish
2297
+ - Paper: https://arxiv.org/abs/1908.08681v3
2298
+
2299
+ ```python exec="true" source="above" session="tensor" result="python"
2300
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).mish().numpy())
2301
+ ```
2302
+ """
2303
+ return self * self.softplus().tanh()
2304
+
2305
+ def softplus(self, beta=1):
2306
+ """
2307
+ Applies the Softplus function element-wise.
2308
+
2309
+ - Described: https://paperswithcode.com/method/softplus
2310
+
2311
+ ```python exec="true" source="above" session="tensor" result="python"
2312
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).softplus().numpy())
2313
+ ```
2314
+ """
2315
+ return (1/beta) * (1 + (self*beta).exp()).log()
2316
+
2317
+ def softsign(self):
2318
+ """
2319
+ Applies the Softsign function element-wise.
2320
+
2321
+ - Described: https://paperswithcode.com/method/softsign
2322
+
2323
+ ```python exec="true" source="above" session="tensor" result="python"
2324
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).softsign().numpy())
2325
+ ```
2326
+ """
2327
+ return self / (1 + self.abs())
2328
+
2329
+ # ***** broadcasted elementwise ops *****
2330
+ def _broadcast_to(self, shape:Tuple[sint, ...]) -> Tensor:
2331
+ if self.shape == shape: return self
2332
+ if self.ndim > len(shape): raise ValueError(f"cannot broadcast tensor to fewer dimensions. shape={self.shape} to {shape=}")
2333
+ # first pad left with 1s https://data-apis.org/array-api/latest/API_specification/broadcasting.html
2334
+ padded, _ = _pad_left(self.shape, shape)
2335
+ # for each dimension, check either from_ is 1, or it does not change
2336
+ if any(from_ != 1 and from_ != to for from_,to in zip(padded, shape)): raise ValueError(f"cannot broadcast from shape={self.shape} to {shape=}")
2337
+ return F.Expand.apply(self.reshape(padded), shape=shape)
2338
+
2339
+ def _broadcasted(self, y:Union[Tensor, Node, ConstType], reverse:bool=False, match_dtype:bool=True) -> Tuple[Tensor, Tensor]:
768
2340
  x: Tensor = self
769
2341
  if not isinstance(y, Tensor):
770
2342
  # make y a Tensor
771
- if 0 in self.shape: return self, self.full_like(y)
772
- if isinstance(self.dtype, ImageDType) or dtypes.is_float(x.dtype) or (dtypes.is_int(x.dtype) and isinstance(y, int)): y_dtype = x.dtype
773
- else: y_dtype = dtypes.from_py(y)
774
- y = Tensor(y, self.device, y_dtype, requires_grad=False)
2343
+ assert isinstance(y, (float, int, bool, Node)), f"{type(y)=}, {y=}"
2344
+ if isinstance(x.dtype, ImageDType) or dtypes.is_float(x.dtype) or (dtypes.is_int(x.dtype) and isinstance(y, int)): y_dtype = x.dtype
2345
+ elif not isinstance(y, Node): y_dtype = dtypes.from_py(y)
2346
+ if isinstance(y, Node): y = Tensor.from_node(y, device=x.device)
2347
+ else: y = Tensor(dtypes.as_const(y, y_dtype), x.device, y_dtype, requires_grad=False)
775
2348
 
776
- if match_dtype:
2349
+ if match_dtype and x.dtype != y.dtype:
777
2350
  output_dtype = least_upper_dtype(x.dtype, y.dtype)
778
2351
  x, y = x.cast(output_dtype), y.cast(output_dtype)
779
2352
 
780
2353
  if reverse: x, y = y, x
781
2354
 
782
- # left pad shape with 1s
783
- if len(y.shape) < len(x.shape): y = y.reshape((1,) * (len(x.shape) - len(y.shape)) + y.shape)
784
- elif len(x.shape) < len(y.shape): x = x.reshape((1,) * (len(y.shape) - len(x.shape)) + x.shape)
785
-
786
- broadcasted_shape = tuple(max(xi, yi) for xi, yi in zip(x.shape, y.shape))
787
- return x.expand(broadcasted_shape), y.expand(broadcasted_shape)
2355
+ # broadcast
2356
+ out_shape = _broadcast_shape(x.shape, y.shape)
2357
+ return x._broadcast_to(out_shape), y._broadcast_to(out_shape)
788
2358
 
789
- def _to_const_val(self, x:Union[Tensor, Scalar]) -> Union[Tensor, Scalar]:
790
- # TODO: update with multi
791
- return x.lazydata.base.arg if isinstance(x, Tensor) and isinstance(x.lazydata, LazyBuffer) and x.lazydata.is_unrealized_contiguous_const() \
2359
+ def _to_const_val(self, x:Union[Tensor, ConstType]) -> Union[Tensor, ConstType]:
2360
+ return x.lazydata.base.arg if isinstance(x, Tensor) and isinstance(x.lazydata, LazyBuffer) and x.lazydata.is_unrealized_unmasked_const() \
792
2361
  and not x.requires_grad and self._broadcasted(x)[0].shape == self.shape else x
793
2362
 
794
- def add(self, x:Union[Tensor, Scalar], reverse=False) -> Tensor:
795
- x = self._to_const_val(x)
796
- return mlops.Add.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x else self
797
- def sub(self, x:Union[Tensor, Scalar], reverse=False) -> Tensor:
798
- x = self._to_const_val(x)
799
- return mlops.Sub.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x else (-self if reverse else self)
800
- def mul(self, x:Union[Tensor, Scalar], reverse=False) -> Tensor:
801
- x = self._to_const_val(x)
802
- if x.__class__ is not Tensor and x == 0.0: return mlops.Zero.apply(self)
803
- if x.__class__ is not Tensor and x == -1.0: return -self
804
- return mlops.Mul.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x != 1.0 else self
805
- def div(self, x:Union[Tensor, Scalar], reverse=False) -> Tensor:
806
- x = self._to_const_val(x)
807
- return mlops.Div.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or reverse or not x or not dtypes.is_float(self.dtype) else self.mul(1/x) # noqa: E501
808
- def pow(self, x:Union[Tensor, Scalar], reverse=False) -> Tensor:
2363
+ def add(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
2364
+ """
2365
+ Adds `self` and `x`.
2366
+ Equivalent to `self + x`.
2367
+ Supports broadcasting to a common shape, type promotion, and integer, float, boolean inputs.
2368
+
2369
+ ```python exec="true" source="above" session="tensor" result="python"
2370
+ Tensor.manual_seed(42)
2371
+ t = Tensor.randn(4)
2372
+ print(t.numpy())
2373
+ ```
2374
+ ```python exec="true" source="above" session="tensor" result="python"
2375
+ print(t.add(20).numpy())
2376
+ ```
2377
+ ```python exec="true" source="above" session="tensor" result="python"
2378
+ print(t.add(Tensor([[2.0], [3.5]])).numpy())
2379
+ ```
2380
+ """
2381
+ return F.Add.apply(*self._broadcasted(x, reverse))
2382
+
2383
+ def sub(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
2384
+ """
2385
+ Subtracts `x` from `self`.
2386
+ Equivalent to `self - x`.
2387
+ Supports broadcasting to a common shape, type promotion, and integer, float, boolean inputs.
2388
+
2389
+ ```python exec="true" source="above" session="tensor" result="python"
2390
+ Tensor.manual_seed(42)
2391
+ t = Tensor.randn(4)
2392
+ print(t.numpy())
2393
+ ```
2394
+ ```python exec="true" source="above" session="tensor" result="python"
2395
+ print(t.sub(20).numpy())
2396
+ ```
2397
+ ```python exec="true" source="above" session="tensor" result="python"
2398
+ print(t.sub(Tensor([[2.0], [3.5]])).numpy())
2399
+ ```
2400
+ """
2401
+ a, b = self._broadcasted(x, reverse)
2402
+ return a + (-b)
2403
+
2404
+ def mul(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
2405
+ """
2406
+ Multiplies `self` and `x`.
2407
+ Equivalent to `self * x`.
2408
+ Supports broadcasting to a common shape, type promotion, and integer, float, boolean inputs.
2409
+
2410
+ ```python exec="true" source="above" session="tensor" result="python"
2411
+ Tensor.manual_seed(42)
2412
+ t = Tensor.randn(4)
2413
+ print(t.numpy())
2414
+ ```
2415
+ ```python exec="true" source="above" session="tensor" result="python"
2416
+ print(t.mul(3).numpy())
2417
+ ```
2418
+ ```python exec="true" source="above" session="tensor" result="python"
2419
+ print(t.mul(Tensor([[-1.0], [2.0]])).numpy())
2420
+ ```
2421
+ """
2422
+ return F.Mul.apply(*self._broadcasted(x, reverse))
2423
+
2424
+ def div(self, x:Union[Tensor, ConstType], reverse=False, upcast=True) -> Tensor:
2425
+ """
2426
+ Divides `self` by `x`.
2427
+ Equivalent to `self / x`.
2428
+ Supports broadcasting to a common shape, type promotion, and integer, float, boolean inputs.
2429
+ By default, `div` performs true division. Set `upcast` to `False` for integer division.
2430
+
2431
+ ```python exec="true" source="above" session="tensor" result="python"
2432
+ Tensor.manual_seed(42)
2433
+ t = Tensor.randn(4)
2434
+ print(t.numpy())
2435
+ ```
2436
+ ```python exec="true" source="above" session="tensor" result="python"
2437
+ print(t.div(3).numpy())
2438
+ ```
2439
+ ```python exec="true" source="above" session="tensor" result="python"
2440
+ print(Tensor([1, 4, 10]).div(Tensor([2, 3, 4])).numpy())
2441
+ ```
2442
+ ```python exec="true" source="above" session="tensor" result="python"
2443
+ print(Tensor([1, 4, 10]).div(Tensor([2, 3, 4]), upcast=False).numpy())
2444
+ ```
2445
+ """
2446
+ numerator, denominator = self._broadcasted(x, reverse)
2447
+ if upcast: numerator, denominator = numerator.cast(least_upper_float(numerator.dtype)), denominator.cast(least_upper_float(denominator.dtype))
2448
+ return F.Div.apply(numerator, denominator)
2449
+
2450
+ def xor(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
2451
+ """
2452
+ Computes bitwise xor of `self` and `x`.
2453
+ Equivalent to `self ^ x`.
2454
+ Supports broadcasting to a common shape, type promotion, and integer, boolean inputs.
2455
+
2456
+ ```python exec="true" source="above" session="tensor" result="python"
2457
+ print(Tensor([-1, -2, 3]).xor(Tensor([1, 0, 3])).numpy())
2458
+ ```
2459
+ ```python exec="true" source="above" session="tensor" result="python"
2460
+ print(Tensor([True, True, False, False]).xor(Tensor([True, False, True, False])).numpy())
2461
+ ```
2462
+ """
2463
+ return F.Xor.apply(*self._broadcasted(x, reverse))
2464
+
2465
+ def lshift(self, x:int):
2466
+ """
2467
+ Computes left arithmetic shift of `self` by `x` bits. `self` must have unsigned dtype.
2468
+ Equivalent to `self << x`.
2469
+
2470
+ ```python exec="true" source="above" session="tensor" result="python"
2471
+ print(Tensor([1, 3, 31], dtype=dtypes.uint8).lshift(2).numpy())
2472
+ ```
2473
+ """
2474
+ assert dtypes.is_unsigned(self.dtype) and isinstance(x, int) and x >= 0, f"not supported {self.dtype=} {x=}"
2475
+ return self.mul(2 ** x)
2476
+
2477
+ def rshift(self, x:int):
2478
+ """
2479
+ Computes right arithmetic shift of `self` by `x` bits. `self` must have unsigned dtype.
2480
+ Equivalent to `self >> x`.
2481
+
2482
+ ```python exec="true" source="above" session="tensor" result="python"
2483
+ print(Tensor([4, 13, 125], dtype=dtypes.uint8).rshift(2).numpy())
2484
+ ```
2485
+ """
2486
+ assert dtypes.is_unsigned(self.dtype) and isinstance(x, int) and x >= 0, f"not supported {self.dtype=} {x=}"
2487
+ return self.div(2 ** x, upcast=False)
2488
+
2489
+ def pow(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
2490
+ """
2491
+ Computes power of `self` with `x`.
2492
+ Equivalent to `self ** x`.
2493
+
2494
+ ```python exec="true" source="above" session="tensor" result="python"
2495
+ print(Tensor([-1, 2, 3]).pow(2).numpy())
2496
+ ```
2497
+ ```python exec="true" source="above" session="tensor" result="python"
2498
+ print(Tensor([-1, 2, 3]).pow(Tensor([-1.5, 0.5, 1.5])).numpy())
2499
+ ```
2500
+ ```python exec="true" source="above" session="tensor" result="python"
2501
+ print((2 ** Tensor([-1, 2, 3])).numpy())
2502
+ ```
2503
+ """
809
2504
  x = self._to_const_val(x)
810
2505
  if not isinstance(x, Tensor) and not reverse:
811
2506
  # simple pow identities
812
2507
  if x < 0: return self.reciprocal().pow(-x)
813
- if x in [3,2,1,0]: return reduce(lambda acc,_: acc * self, range(int(x)), mlops.Zero.apply(self)+1)
814
- if x == 0.5: return self.sqrt()
2508
+ if x == 0: return 1 + self * 0
2509
+ if int(x - 0.5) + 0.5 == x: return self.pow(int(x - 0.5)) * self.sqrt()
2510
+ if int(x) == x: return self.pow(x // 2).square() * (1 if x % 2 == 0 else self)
2511
+
2512
+ # positive const ** self
815
2513
  if not isinstance(x, Tensor) and reverse and x > 0: return self.mul(math.log(x)).exp()
816
- ar = self.abs().log().mul(x).exp() if not reverse or isinstance(x, Tensor) else self.mul(math.log(abs(x))).exp()
817
- # correct sign of negative numbers raised to a power (cos has a period of 2pi so we use it here to get the oddness of the power)
818
- sign = (x * math.pi).cos() if isinstance(x, Tensor) else math.cos(x * math.pi) if not reverse else (self * math.pi).cos()
819
- # we only need to correct the sign if the base is negative
820
- base_sign = ((self.sign() if not reverse else x.sign() if isinstance(x, Tensor) else math.copysign(1, x)) - 1) / -2
821
- # we need 0 to be positive so we need to correct base_sign when the base is 0
822
- base_sign = base_sign - (1.5 * (1 - (self.sign().abs() if not reverse else x.sign().abs() if isinstance(x, Tensor) else abs(int(bool(x))))))
823
- # inject nan if the base is negative and the power is not an integer
824
- to_nan = (((x - x.trunc()) * 1e10).abs().clip(0, 1) if isinstance(x, Tensor) else int(bool(x - int(x))) if not reverse else ((self - self.trunc()) * 1e10).abs().clip(0, 1)) * base_sign # noqa: E501
825
- inject_nan = ((((-to_nan) * 2) + 1)).log().add(1) if isinstance(to_nan, Tensor) else 1 if not to_nan else float("nan")
826
- return ar.mul(sign * base_sign + (1 - base_sign)).mul(inject_nan)
827
- def xor(self, x:Tensor, reverse=False) -> Tensor: return mlops.Xor.apply(*self._broadcasted(x, reverse))
828
-
829
- # TODO: this implicitly changes dtype with /2
830
- def maximum(self, x:Union[Tensor, Scalar]) -> Tensor: return (self<x).detach().where(x, (self>x).detach().where(self, (self+x)/2))
831
- def minimum(self, x:Union[Tensor, Scalar]) -> Tensor: return -((-self).maximum(-x))
832
-
833
- def where(self:Tensor, input_:Union[Tensor, Scalar], other:Union[Tensor, Scalar]):
834
- x_,y = self._broadcasted(input_, match_dtype=False)
835
- x,z = x_._broadcasted(other, match_dtype=False)
836
- return mlops.Where.apply(x.cast(dtypes.bool), *y._broadcasted(z))
837
-
838
- # ***** op wrappers (wasted lines to make the typechecker happy) *****
2514
+
2515
+ base, exponent = self._broadcasted(x, reverse=reverse)
2516
+ # start with b ** e = exp(e * log(b))
2517
+ ret = base.abs().log().mul(exponent).exp()
2518
+ # correct sign of negative base with odd exponent (cos has a period of 2pi so we use it here to get the oddness of the exponent)
2519
+ negative_base = (base < 0).detach().where(1, 0)
2520
+ # 1 for non-negative base or negative even exponent, -1 for negative odd exponent, don't care about non-integer exponent
2521
+ correct_sign = 1 + negative_base * ((exponent * math.pi).cos() - 1)
2522
+ # inject nan for negative base and non-integer exponent
2523
+ inject_nan = (negative_base * (exponent != exponent.trunc())).detach().where(math.nan, 1)
2524
+ # apply correct_sign and inject_nan, and fix 0 ** 0 = 1
2525
+ return ((base == 0) * (exponent == 0)).detach().where(1, ret * correct_sign * inject_nan)
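The rewritten `pow` computes `|b| ** e` as `exp(e * log|b|)` and then repairs the sign, the NaN case, and `0 ** 0` with masks. A scalar numpy sketch of the same decomposition (illustrative only; the tensor version above uses broadcasted `where` masks instead of Python branches):

```python
# b ** e via exp/log with sign correction for negative bases, sketch only.
import numpy as np

def signed_pow(base: float, exp: float) -> float:
    if base == 0 and exp == 0: return 1.0
    ret = np.exp(exp * np.log(abs(base)))
    if base < 0:
        if exp != np.trunc(exp): return np.nan   # negative base, non-integer exponent
        ret *= np.cos(np.pi * exp)               # -1 for odd exponents, +1 for even ones
    return ret

for b, e in [(-2.0, 3.0), (-2.0, 4.0), (3.0, -2.0), (-2.0, 0.5)]:
    expected = b ** e if not (b < 0 and e != int(e)) else float("nan")
    print(b, e, signed_pow(b, e), expected)
```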
2526
+
2527
+ def maximum(self, x:Union[Tensor, ConstType]) -> Tensor:
2528
+ """
2529
+ Computes element-wise maximum of `self` and `x`.
2530
+
2531
+ ```python exec="true" source="above" session="tensor" result="python"
2532
+ print(Tensor([-1, 2, 3]).maximum(1).numpy())
2533
+ ```
2534
+ ```python exec="true" source="above" session="tensor" result="python"
2535
+ print(Tensor([-1, 2, 3]).maximum(Tensor([-4, -2, 9])).numpy())
2536
+ ```
2537
+ """
2538
+ return (self<x).detach().where(x, (self==x).detach().where(((self * 0.5 + x * 0.5).cast(self.dtype)), self))
2539
+
2540
+ def minimum(self, x:Union[Tensor, ConstType]) -> Tensor:
2541
+ """
2542
+ Computes element-wise minimum of `self` and `x`.
2543
+
2544
+ ```python exec="true" source="above" session="tensor" result="python"
2545
+ print(Tensor([-1, 2, 3]).minimum(1).numpy())
2546
+ ```
2547
+ ```python exec="true" source="above" session="tensor" result="python"
2548
+ print(Tensor([-1, 2, 3]).minimum(Tensor([-4, -2, 9])).numpy())
2549
+ ```
2550
+ """
2551
+ return -((-self).maximum(-x))
2552
+
2553
+ def where(self:Tensor, x:Union[Tensor, ConstType], y:Union[Tensor, ConstType]):
2554
+ """
2555
+ Return a tensor of elements selected from either `x` or `y`, depending on `self`.
2556
+ `output_i = x_i if self_i else y_i`.
2557
+
2558
+ ```python exec="true" source="above" session="tensor" result="python"
2559
+ cond = Tensor([[True, True, False], [True, False, False]])
2560
+ print(cond.where(1, 3).numpy())
2561
+ ```
2562
+ ```python exec="true" source="above" session="tensor" result="python"
2563
+ Tensor.manual_seed(42)
2564
+ cond = Tensor.randn(2, 3)
2565
+ print(cond.numpy())
2566
+ ```
2567
+ ```python exec="true" source="above" session="tensor" result="python"
2568
+ print((cond > 0).where(cond, -float("inf")).numpy())
2569
+ ```
2570
+ """
2571
+ if isinstance(x, Tensor): x, y = x._broadcasted(y)
2572
+ elif isinstance(y, Tensor): y, x = y._broadcasted(x)
2573
+ cond, x = self._broadcasted(x, match_dtype=False)
2574
+ cond, y = cond._broadcasted(y, match_dtype=False)
2575
+ return F.Where.apply(cond.cast(dtypes.bool), *x._broadcasted(y))
2576
+
2577
+ def masked_fill(self:Tensor, mask:Tensor, value:Union[Tensor, ConstType]): return mask.where(value, self)
2578
+
2579
+ # ***** op wrappers *****
839
2580
 
840
2581
  def __neg__(self) -> Tensor: return self.neg()
841
2582
 
@@ -846,6 +2587,8 @@ class Tensor:
846
2587
  def __truediv__(self, x) -> Tensor: return self.div(x)
847
2588
  def __matmul__(self, x) -> Tensor: return self.matmul(x)
848
2589
  def __xor__(self, x) -> Tensor: return self.xor(x)
2590
+ def __lshift__(self, x) -> Tensor: return self.lshift(x)
2591
+ def __rshift__(self, x) -> Tensor: return self.rshift(x)
849
2592
 
850
2593
  def __radd__(self, x) -> Tensor: return self.add(x, True)
851
2594
  def __rsub__(self, x) -> Tensor: return self.sub(x, True)
@@ -862,69 +2605,254 @@ class Tensor:
862
2605
  def __itruediv__(self, x) -> Tensor: return self.assign(self.div(x))
863
2606
  def __imatmul__(self, x) -> Tensor: return self.assign(self.matmul(x))
864
2607
  def __ixor__(self, x) -> Tensor: return self.assign(self.xor(x))
2608
+ def __ilshift__(self, x) -> Tensor: return self.assign(self.lshift(x))
2609
+ def __irshift__(self, x) -> Tensor: return self.assign(self.rshift(x))
865
2610
 
866
- def __lt__(self, x) -> Tensor: return mlops.Less.apply(*self._broadcasted(x, False))
867
- def __gt__(self, x) -> Tensor: return mlops.Less.apply(*self._broadcasted(x, True))
2611
+ def __lt__(self, x) -> Tensor: return F.Less.apply(*self._broadcasted(x, False))
2612
+ def __gt__(self, x) -> Tensor: return F.Less.apply(*self._broadcasted(x, True))
868
2613
  def __ge__(self, x) -> Tensor: return (self<x).logical_not()
869
2614
  def __le__(self, x) -> Tensor: return (self>x).logical_not()
870
- def __eq__(self, x) -> Tensor: return mlops.Eq.apply(*self._broadcasted(x, True)) # type: ignore[override]
871
- def __ne__(self, x) -> Tensor: return (self==x).logical_not() # type: ignore[override]
2615
+ def __ne__(self, x) -> Tensor: return F.Neq.apply(*self._broadcasted(x)) # type: ignore[override]
2616
+ def __eq__(self, x) -> Tensor: return (self!=x).logical_not() # type: ignore[override]
872
2617
 
873
2618
  # ***** functional nn ops *****
874
2619
 
875
2620
  def linear(self, weight:Tensor, bias:Optional[Tensor]=None):
2621
+ """
2622
+ Applies a linear transformation to `self` using `weight` and `bias`.
2623
+
2624
+ See: https://pytorch.org/docs/stable/generated/torch.nn.Linear.html
2625
+
2626
+ ```python exec="true" source="above" session="tensor" result="python"
2627
+ t = Tensor([[1, 2], [3, 4]])
2628
+ weight = Tensor([[1, 2], [3, 4]])
2629
+ bias = Tensor([1, 2])
2630
+ print(t.linear(weight, bias).numpy())
2631
+ ```
2632
+ """
876
2633
  x = self.mul(weight) if len(weight.shape) == 1 else self.dot(weight)
877
2634
  return x.add(bias) if bias is not None else x
878
2635
 
879
- def sequential(self, ll:List[Callable[[Tensor], Tensor]]): return reduce(lambda x,f: f(x), ll, self)
2636
+ def sequential(self, ll:List[Callable[[Tensor], Tensor]]):
2637
+ """
2638
+ Applies a sequence of functions to `self` chaining the output of each function to the input of the next.
2639
+
2640
+ ```python exec="true" source="above" session="tensor" result="python"
2641
+ t = Tensor([1, 2, 3])
2642
+ print(t.sequential([lambda x: x * 2, lambda x: x + 1]).numpy())
2643
+ ```
2644
+ """
2645
+ return functools.reduce(lambda x,f: f(x), ll, self)
880
2646
 
881
2647
  def layernorm(self, axis=-1, eps:float=1e-5) -> Tensor:
2648
+ """
2649
+ Applies Layer Normalization over a mini-batch of inputs.
2650
+
2651
+ - Described: https://paperswithcode.com/method/layer-normalization
2652
+ - Paper: https://arxiv.org/abs/1607.06450v1
2653
+
2654
+ ```python exec="true" source="above" session="tensor" result="python"
2655
+ t = Tensor.randn(8, 10, 16) * 2 + 8
2656
+ print(t.mean().item(), t.std().item())
2657
+ ```
2658
+ ```python exec="true" source="above" session="tensor" result="python"
2659
+ t = t.layernorm()
2660
+ print(t.mean().item(), t.std().item())
2661
+ ```
2662
+ """
882
2663
  y = (self - self.mean(axis, keepdim=True))
883
2664
  return y.mul((y*y).mean(axis, keepdim=True).add(eps).rsqrt())
884
2665
 
885
- def batchnorm(self, weight:Optional[Tensor], bias:Optional[Tensor], mean:Tensor, invstd:Tensor) -> Tensor:
886
- x = (self - mean.reshape(shape=[1, -1, 1, 1]))
887
- if weight: x = x * weight.reshape(shape=[1, -1, 1, 1])
888
- ret = x.mul(invstd.reshape(shape=[1, -1, 1, 1]) if len(invstd.shape) == 1 else invstd)
889
- return (ret + bias.reshape(shape=[1, -1, 1, 1])) if bias else ret
2666
+ def batchnorm(self, weight:Optional[Tensor], bias:Optional[Tensor], mean:Tensor, invstd:Tensor, axis:Union[int,Tuple[int,...]]=1) -> Tensor:
2667
+ """
2668
+ Applies Batch Normalization over a mini-batch of inputs.
2669
+
2670
+ - Described: https://paperswithcode.com/method/batch-normalization
2671
+ - Paper: https://arxiv.org/abs/1502.03167
2672
+
2673
+ ```python exec="true" source="above" session="tensor" result="python"
2674
+ t = Tensor.randn(8, 4, 16, 16) * 2 + 8
2675
+ print(t.mean().item(), t.std().item())
2676
+ ```
2677
+ ```python exec="true" source="above" session="tensor" result="python"
2678
+ t = t.batchnorm(None, None, t.mean(axis=(0,2,3)), t.var(axis=(0,2,3)).add(1e-5).rsqrt())
2679
+ print(t.mean().item(), t.std().item())
2680
+ ```
2681
+ """
2682
+ axis_ = argfix(axis)
2683
+ shape = tuple(s if ax in axis_ else 1 for ax, s in enumerate(self.shape))
2684
+ x = self - mean.reshape(shape)
2685
+ if weight is not None: x = x * weight.reshape(shape)
2686
+ ret = x.mul(invstd.reshape(shape) if len(invstd.shape) == len(axis_) else invstd)
2687
+ return (ret + bias.reshape(shape)) if bias is not None else ret
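`batchnorm` now derives the reshape from the `axis` argument instead of hard-coding `[1, -1, 1, 1]`, so it can normalize over axes other than the channel dimension. A numpy cross-check of the normalization it performs with per-channel statistics, mirroring the docstring example (a sketch, not the tinygrad call):

```python
# Per-channel normalization with precomputed mean / inverse std, sketch only.
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(8.0, 2.0, size=(8, 4, 16, 16))
mean = x.mean(axis=(0, 2, 3), keepdims=True)
invstd = 1.0 / np.sqrt(x.var(axis=(0, 2, 3), keepdims=True) + 1e-5)
y = (x - mean) * invstd                 # what batchnorm(None, None, mean, invstd) computes
assert np.allclose(y.mean(axis=(0, 2, 3)), 0, atol=1e-9)
assert np.allclose(y.std(axis=(0, 2, 3)), 1, atol=1e-3)
```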
890
2688
 
891
2689
  def dropout(self, p=0.5) -> Tensor:
892
- if not Tensor.training or p == 0: return self
893
- return self * (Tensor.rand(*self.shape, requires_grad=False, device=self.device) >= p) * (1/(1.0 - p))
2690
+ """
2691
+ Applies dropout to `self`.
2692
+
2693
+ NOTE: dropout is only applied when `Tensor.training` is `True`.
2694
+
2695
+ - Described: https://paperswithcode.com/method/dropout
2696
+ - Paper: https://jmlr.org/papers/v15/srivastava14a.html
894
2697
 
895
- def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Optional[Tensor]=None, dropout_p:float=0.0, is_causal:bool=False) -> Tensor: # noqa: E501
896
- # NOTE: it works if key, value have symbolic shape
2698
+ ```python exec="true" source="above" session="tensor" result="python"
2699
+ Tensor.manual_seed(42)
2700
+ t = Tensor.randn(2, 2)
2701
+ with Tensor.train():
2702
+ print(t.dropout().numpy())
2703
+ ```
2704
+ """
2705
+ if not Tensor.training or p == 0: return self
2706
+ return self * (Tensor.rand(*self.shape, requires_grad=False, dtype=dtypes.default_float, device=self.device) >= p) * (1/(1.0 - p))
2707
+
2708
+ def one_hot(self, num_classes:int) -> Tensor:
2709
+ """
2710
+ Converts `self` to a one-hot tensor.
2711
+
2712
+ ```python exec="true" source="above" session="tensor" result="python"
2713
+ t = Tensor([0, 1, 3, 3, 4])
2714
+ print(t.one_hot(5).numpy())
2715
+ ```
2716
+ """
2717
+ return (self[..., None] == Tensor.arange(num_classes, requires_grad=False, device=self.device)).where(1, 0)
+
+ def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Optional[Tensor]=None,
+ dropout_p:float=0.0, is_causal:bool=False) -> Tensor:
+ """
+ Computes scaled dot-product attention.
+ `self` is the query tensor, `key` is the key tensor, and `value` is the value tensor.
+
+ - Described: https://paperswithcode.com/method/scaled
+ - Paper: https://arxiv.org/abs/1706.03762v7
+
+ ```python exec="true" source="above" session="tensor" result="python"
+ q = Tensor.randn(2, 4, 8)
+ k = Tensor.randn(2, 4, 8)
+ v = Tensor.randn(2, 4, 8)
+ print(q.scaled_dot_product_attention(k, v).numpy())
+ ```
+ """
+ # NOTE: it also works when `key` and `value` have symbolic shape.
  assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
  if is_causal: attn_mask = Tensor.ones(self.shape[-2], key.shape[-2], requires_grad=False, device=self.device).tril(0).cast(dtypes.bool)
  if attn_mask is not None and attn_mask.dtype == dtypes.bool: attn_mask = (attn_mask == 0).where(-float("inf"), 0)
- qk = self @ key.transpose(-2,-1) / math.sqrt(self.shape[-1])
- return ((qk+attn_mask) if attn_mask is not None else qk).softmax(-1).dropout(dropout_p) @ value
+ qk = self.matmul(key.transpose(-2,-1), acc_dtype=least_upper_dtype(self.dtype, key.dtype, dtypes.float32)) / math.sqrt(self.shape[-1])
+ return ((qk+attn_mask) if attn_mask is not None else qk).softmax(-1).cast(self.dtype).dropout(dropout_p) @ value
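Compared to 0.8.0, the attention scores are now accumulated in at least float32 (`least_upper_dtype(..., dtypes.float32)`) and the softmax output is cast back to the input dtype, which matters for half-precision inputs. The quantity computed is still the standard softmax(QKᵀ/√d)·V; a NumPy reference sketch with the same shapes as the docstring example (illustrative only):

```python
import numpy as np

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((2, 4, 8)).astype(np.float32) for _ in range(3))

scores = q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1])   # (2, 4, 4)
w = np.exp(scores - scores.max(-1, keepdims=True))         # numerically stable softmax
w /= w.sum(-1, keepdims=True)
out = w @ v                                                 # (2, 4, 8)
print(out.shape)
```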
 
  def binary_crossentropy(self, y:Tensor) -> Tensor:
+ """
+ Computes the binary cross-entropy loss between `self` and `y`.
+
+ See: https://pytorch.org/docs/stable/generated/torch.nn.BCELoss.html
+
+ ```python exec="true" source="above" session="tensor" result="python"
+ t = Tensor([0.1, 0.9, 0.2])
+ y = Tensor([0, 1, 0])
+ print(t.binary_crossentropy(y).item())
+ ```
+ """
  return (-y*self.log() - (1-y)*(1-self).log()).mean()
 
  def binary_crossentropy_logits(self, y:Tensor) -> Tensor:
+ """
+ Computes the binary cross-entropy loss between `self` and `y` where `self` is logits.
+
+ See: https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html
+
+ ```python exec="true" source="above" session="tensor" result="python"
+ t = Tensor([-1, 2, -3])
+ y = Tensor([0, 1, 0])
+ print(t.binary_crossentropy_logits(y).item())
+ ```
+ """
  return (self.maximum(0) - y * self + (1 + self.abs().neg().exp()).log()).mean()
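`binary_crossentropy_logits` uses the usual numerically stable rewrite max(x, 0) - x*y + log(1 + e^(-|x|)), which avoids overflow in `exp` for large-magnitude logits. A quick NumPy check that it matches the naive sigmoid-then-BCE form on the docstring inputs (illustrative sketch):

```python
import numpy as np

x = np.array([-1.0, 2.0, -3.0])   # logits
y = np.array([0.0, 1.0, 0.0])     # targets

sig = 1 / (1 + np.exp(-x))
naive = -(y * np.log(sig) + (1 - y) * np.log(1 - sig))
stable = np.maximum(x, 0) - x * y + np.log1p(np.exp(-np.abs(x)))
print(naive.mean(), stable.mean())   # both ~0.1629
```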
 
- def sparse_categorical_crossentropy(self, Y:Tensor, ignore_index=-1) -> Tensor:
- # NOTE: self is a logits input
- loss_mask = (Y != ignore_index)
+ def sparse_categorical_crossentropy(self, Y:Tensor, ignore_index=-1, label_smoothing=0.0) -> Tensor:
+ """
+ Computes the sparse categorical cross-entropy loss between `self` and `Y`.
+
+ NOTE: `self` is logits and `Y` is the target labels.
+
+ See: https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
+
+ ```python exec="true" source="above" session="tensor" result="python"
+ t = Tensor([[-1, 2, -3], [1, -2, 3]])
+ Y = Tensor([1, 2])
+ print(t.sparse_categorical_crossentropy(Y).item())
+ ```
+ """
+ assert 0.0 <= label_smoothing <= 1.0, "label_smoothing must be in [0.0, 1.0]"
+ log_probs, loss_mask = self.log_softmax(), (Y != ignore_index)
  y_counter = Tensor.arange(self.shape[-1], requires_grad=False, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1])
- y = ((y_counter == Y.flatten().reshape(-1, 1)).where(-1, 0) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
- return self.log_softmax().mul(y).sum() / loss_mask.sum()
+ y = ((y_counter == Y.flatten().reshape(-1, 1)) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
+ smoothing = label_smoothing * (log_probs.mean(-1) * loss_mask).sum()
+ return -((1 - label_smoothing) * (log_probs * y).sum() + smoothing) / loss_mask.sum()
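The new `label_smoothing` argument blends the one-hot target with a uniform distribution over classes, so the per-sample loss becomes -[(1-ε)·log p(y) + ε·mean_c log p(c)], averaged over the unmasked samples; that is what the last two added lines compute. A NumPy sketch of the same formula on the docstring inputs (illustrative, not the tinygrad implementation):

```python
import numpy as np

logits = np.array([[-1.0, 2.0, -3.0], [1.0, -2.0, 3.0]])
labels = np.array([1, 2])
eps = 0.1

log_probs = logits - np.log(np.exp(logits).sum(-1, keepdims=True))   # log_softmax
nll = -log_probs[np.arange(len(labels)), labels]                     # -log p(correct class)
uniform = -log_probs.mean(-1)                                        # -mean_c log p(c)
print(((1 - eps) * nll + eps * uniform).mean())
```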
 
  # ***** cast ops *****
 
- def cast(self, dtype:DType) -> Tensor:
- if self.dtype == dtype: return self
+ def llvm_bf16_cast(self, dtype:DType):
  # hack for devices that don't support bfloat16
- if self.dtype == dtypes.bfloat16: return self.bitcast(dtypes.uint16).cast(dtypes.uint32).mul(1<<16).bitcast(dtypes.float32).cast(dtype)
- return mlops.Cast.apply(self, dtype=dtype)
+ assert self.dtype == dtypes.bfloat16
+ return self.to("LLVM").bitcast(dtypes.uint16).cast(dtypes.uint32).mul(1<<16).bitcast(dtypes.float32).cast(dtype)
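The bfloat16 workaround relies on bfloat16 being the top 16 bits of an IEEE 754 float32: widening the raw bits and shifting left by 16 reproduces the value exactly, after which an ordinary float cast applies. A standalone sketch of the bit manipulation using `struct` (not tinygrad code):

```python
import struct

def bf16_bits_to_float(bits16: int) -> float:
    # bfloat16 keeps the sign, the 8 exponent bits, and the top 7 mantissa bits of
    # float32, so placing them in the high half of a float32 recovers the value
    return struct.unpack("<f", struct.pack("<I", bits16 << 16))[0]

print(bf16_bits_to_float(0x3F80))  # 1.0
print(bf16_bits_to_float(0x4049))  # 3.140625 (pi rounded to bfloat16)
```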
+ def cast(self, dtype:DType) -> Tensor:
+ """
+ Casts `self` to the given `dtype`.
+
+ ```python exec="true" source="above" session="tensor" result="python"
+ t = Tensor([-1, 2.5, 3], dtype=dtypes.float)
+ print(t.dtype, t.numpy())
+ ```
+ ```python exec="true" source="above" session="tensor" result="python"
+ t = t.cast(dtypes.int32)
+ print(t.dtype, t.numpy())
+ ```
+ """
+ return self if self.dtype == dtype else F.Cast.apply(self, dtype=dtype)
  def bitcast(self, dtype:DType) -> Tensor:
- assert self.dtype.itemsize == dtype.itemsize, "can't bitcast mismatched dtype itemsizes"
- return mlops.Cast.apply(self, dtype=dtype, bitcast=True) if self.dtype != dtype else self
- def float(self) -> Tensor: return self.cast(dtypes.float32)
- def half(self) -> Tensor: return self.cast(dtypes.float16)
+ """
+ Bitcasts `self` to the given `dtype` of the same itemsize.
+
+ `self` must not require a gradient.
+
+ ```python exec="true" source="above" session="tensor" result="python"
+ t = Tensor([-1, 2, 3], dtype=dtypes.int32)
+ print(t.dtype, t.numpy())
+ ```
+ ```python exec="true" source="above" session="tensor" result="python"
+ t = t.bitcast(dtypes.uint32)
+ print(t.dtype, t.numpy())
+ ```
+ """
+ if self.requires_grad: raise RuntimeError("can't backprop through bitcast")
+ return F.Cast.apply(self, dtype=dtype, bitcast=True) if self.dtype != dtype else self
+ def float(self) -> Tensor:
+ """
+ Convenience method to cast `self` to a `float32` Tensor.
+
+ ```python exec="true" source="above" session="tensor" result="python"
+ t = Tensor([-1, 2, 3], dtype=dtypes.int32)
+ print(t.dtype, t.numpy())
+ ```
+ ```python exec="true" source="above" session="tensor" result="python"
+ t = t.float()
+ print(t.dtype, t.numpy())
+ ```
+ """
+ return self.cast(dtypes.float32)
+ def half(self) -> Tensor:
+ """
+ Convenience method to cast `self` to a `float16` Tensor.
+
+ ```python exec="true" source="above" session="tensor" result="python"
+ t = Tensor([-1, 2, 3], dtype=dtypes.int32)
+ print(t.dtype, t.numpy())
+ ```
+ ```python exec="true" source="above" session="tensor" result="python"
+ t = t.half()
+ print(t.dtype, t.numpy())
+ ```
+ """
+ return self.cast(dtypes.float16)
 
  # ***** convenience stuff *****
 
@@ -934,20 +2862,101 @@ class Tensor:
  def element_size(self) -> int: return self.dtype.itemsize
  def nbytes(self) -> int: return self.numel() * self.element_size()
  def is_floating_point(self) -> bool: return dtypes.is_float(self.dtype)
+ def size(self, dim=None) -> Union[sint, Tuple[sint, ...]]: return self.shape if dim is None else self.shape[dim]
+
+ # *** image Tensor function replacements ***
+
+ def image_dot(self, w:Tensor, acc_dtype=None):
+ # NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
+ n1, n2 = len(self.shape), len(w.shape)
+ assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D"
+ assert self.shape[-1] == w.shape[-min(n2, 2)], f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({self.shape[-1]} != {w.shape[-min(n2, 2)]})" # noqa: E501
+ bs, groups, cin, cout = prod(self.shape[0:-2]), prod(w.shape[0:-2]), w.shape[-2], w.shape[-1]
+ out_shape_t = self.shape[0:-2] + (cout,-1) if len(self.shape) > 1 else (cout, )
+
+ # NOTE: with NHWC we can remove the transposes
+ # bs x groups*cin x H x W
+ cx = self.transpose(self.ndim-1, self.ndim-2).reshape((bs//groups, groups*cin, -1, 1))
+ # groups*cout x cin x H, W
+ cw = w.transpose(w.ndim-1, w.ndim-2).reshape((groups*cout, cin, 1, 1))
+ return cx.image_conv2d(cw, groups=groups, acc_dtype=acc_dtype).reshape(out_shape_t).transpose(self.ndim-1, self.ndim-2)
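`image_dot` rewrites a matmul as a grouped 1x1 convolution so the IMAGE backend can keep operands in texture layouts: an (m,k) @ (k,n) product is a pointwise conv of a (1,k,m,1) NCHW input with (n,k,1,1) filters, as the NOTE above says. A NumPy sketch checking that shape identity (illustrative only):

```python
import numpy as np

m, k, n = 3, 5, 4
rng = np.random.default_rng(0)
a, b = rng.standard_normal((m, k)), rng.standard_normal((k, n))

x = a.T.reshape(1, k, m, 1)    # channels hold the contraction dimension k
w = b.T.reshape(n, k, 1, 1)    # one output channel per column of b
conv = np.einsum("bchw,oc->bohw", x, w[:, :, 0, 0])   # a 1x1 (pointwise) convolution
out = conv.reshape(n, m).T

print(np.allclose(a @ b, out))  # True
```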
+
+ def image_conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, acc_dtype=None):
+ base_image_type = dtypes.imageh if getenv("FLOAT16", 0) else dtypes.imagef
+
+ (bs,_,iy,ix), (cout,cin,H,W) = self.shape, weight.shape
+ x, w = self, weight.reshape(groups, (rcout := cout//groups), cin, H, W)
+
+ # hack for non multiples of 4 on cin
+ if cin % 4 != 0 and not (cin == 1 and groups%4 == 0):
+ x = x.reshape(bs, groups, cin, iy, ix) # do this always?
+ added_input_channels = 4 - (cin % 4)
+ w = w.pad(tuple((0, added_input_channels) if i == 2 else None for i in range(w.ndim)))
+ x = x.pad(tuple((0, added_input_channels) if i == 2 else None for i in range(x.ndim)))
+ cin = cin + added_input_channels
+ x = x.reshape(bs, groups*cin, iy, ix)
+
+ # hack for non multiples of 4 on rcout
+ added_output_channels = 0
+ if rcout % 4 != 0 and not (rcout == 1 and groups%4 == 0):
+ added_output_channels = 4 - (rcout % 4)
+ rcout += added_output_channels
+ cout = groups * rcout
+ w = w.pad(tuple((0, added_output_channels) if i == 1 else None for i in range(w.ndim)))
+
+ # packed (note: flipping bs and iy would make the auto-padding work)
+ x = x.permute(0,2,3,1)
+ cin_last = iy == 1 and ix == 1
+ if cin == 1: w = w.reshape(cout//4,4,H,W).permute(0,2,3,1)
+ elif cin_last: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,1,3)
+ else: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,3,1)
+
+ # contiguous creates the image, and early realize static weights (TODO: test for the static weight)
+ if IMAGE >= 2: x,w = x.cast(base_image_type((bs*iy, ix*groups*cin//4, 4))), w.cast(base_image_type((cout//4, H*W*cin, 4)))
+ x, w = x.contiguous(), w.contiguous()
+
+ # expand out
+ rcin_hi, rcin_lo = cin//4 if cin >= 4 else 1, 4 if cin >= 4 else 1
+ cout_expand = [groups//4 if cin == 1 else groups, 4 if cin == 1 else 1, rcout//4 if rcout >= 4 else 1, 4 if rcout >= 4 else 1]
+ x = x.reshape(bs, iy, ix, groups, rcin_hi, rcin_lo)
+ if cin_last: w = w.reshape(cout//4, H, rcin_hi, W, 4, rcin_lo)
+ else: w = w.reshape(cout//4, H, rcin_hi, W, rcin_lo, 4).permute(0,1,2,3,5,4)
+
+ # padding
+ padding_ = [padding]*4 if isinstance(padding, int) else (padding if len(padding) == 4 else [padding[1], padding[1], padding[0], padding[0]])
+ x = x._slice((None, (-padding_[2], x.shape[1]+padding_[3]), (-padding_[0], x.shape[2]+padding_[1]), None, None, None))
+
+ # prepare input
+ x = x.permute(0,3,4,5,1,2)._pool((H, W), stride, dilation) # -> (bs, groups, rcin_hi, rcin_lo, oy, ox, H, W)
+ x = x.permute(0,4,5,1,2,3,6,7).reshape(bs, (oy := x.shape[4]), (ox := x.shape[5]), *cout_expand[0:2], 1, 1, rcin_hi, rcin_lo, H, W)
+
+ # prepare weights
+ w = w.permute(0,4,2,5,1,3).reshape((1, 1, 1, *cout_expand, rcin_hi, rcin_lo, H, W))
+
+ # the conv!
+ ret = (x*w).cast(base_image_type((bs*oy, ox*cout//4, 4)) if IMAGE >= 2 else dtypes.float32).sum((-4, -3, -2, -1), acc_dtype=acc_dtype)
+
+ # undo hack for non multiples of 4 on C.rcout
+ if added_output_channels != 0:
+ ret = ret.reshape(bs, oy, ox, groups, rcout)[:, :, :, :, :-added_output_channels]
+ cout = groups * (rcout - added_output_channels)
+
+ # NCHW output
+ ret = ret.reshape(bs, oy, ox, cout).permute(0,3,1,2)
+ return ret if bias is None else ret.add(bias.reshape(1, -1, 1, 1))
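Both "hack" blocks pad the input and output channel counts up to a multiple of 4 because the image dtypes pack four values per texel; the code above also skips the padding in the `cin == 1` / `rcout == 1` grouped cases. The bookkeeping itself is just `4 - (c % 4)` extra channels when `c` is not already a multiple of 4 (tiny sketch, hypothetical helper name):

```python
def channels_to_add(c: int) -> int:
    # extra channels so that 4-wide texel packing lines up
    return 0 if c % 4 == 0 else 4 - (c % 4)

for c in (1, 3, 4, 6, 8):
    print(c, "->", c + channels_to_add(c))   # 1->4, 3->4, 4->4, 6->8, 8->8
```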
 
  # register functions to move between devices
- for device in Device._devices: setattr(Tensor, f"{device.lower()}", partialmethod(Tensor.to, device))
+ for device in Device._devices: setattr(Tensor, f"{device.lower()}", functools.partialmethod(Tensor.to, device))
 
  if IMAGE:
  # if IMAGE>0 we install these replacement functions in Tensor (hack!)
- from tinygrad.features.image import image_conv2d, image_dot
- setattr(Tensor, "conv2d", image_conv2d)
- setattr(Tensor, "dot", image_dot)
+ setattr(Tensor, "conv2d", Tensor.image_conv2d)
+ setattr(Tensor, "dot", Tensor.image_dot)
 
- # TODO: remove the custom op and replace with threefry
+ # TODO: eventually remove this
  def custom_random(out:Buffer):
  Tensor._seed += 1
- if DEBUG >= 2: print(f"*** {out.device} rand seed {Tensor._seed} size {out.size:<15d} dtype {out.dtype}")
  rng = np.random.default_rng(Tensor._seed)
- rng_np_buffer = rng.random(size=out.size, dtype=np.float32).astype(dtype=out.dtype.np, copy=False)
+ if out.dtype == dtypes.half: rng_np_buffer = (rng.integers(low=0, high=2047, size=out.size) / 2048).astype(np.half, copy=False)
+ else: rng_np_buffer = rng.random(size=out.size, dtype=np.float32).astype(dtype=_to_np_dtype(out.dtype), copy=False)
  out.copyin(rng_np_buffer.data)
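For half precision, `custom_random` now draws integers below 2047 and divides by 2048, so every sample is an exact multiple of 2^-11: such values pass through float16 unchanged and stay strictly below 1.0, presumably to avoid the rounding-to-1.0 that a float32-to-float16 cast of a uniform sample could produce. A NumPy sketch of that property (illustrative only):

```python
import numpy as np

rng = np.random.default_rng(0)
vals32 = rng.integers(low=0, high=2047, size=8) / 2048   # exact multiples of 1/2048
vals16 = vals32.astype(np.half)

print(vals16)
print(np.all(vals16 == vals32), np.all(vals16 < 1.0))    # True True
```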