tinygrad 0.8.0__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. tinygrad/__init__.py +6 -6
  2. tinygrad/codegen/kernel.py +230 -190
  3. tinygrad/codegen/linearizer.py +278 -384
  4. tinygrad/codegen/uops.py +415 -0
  5. tinygrad/device.py +132 -275
  6. tinygrad/dtype.py +53 -37
  7. tinygrad/engine/__init__.py +0 -0
  8. tinygrad/engine/graph.py +100 -0
  9. tinygrad/engine/jit.py +195 -0
  10. tinygrad/engine/realize.py +191 -0
  11. tinygrad/engine/schedule.py +362 -0
  12. tinygrad/engine/search.py +196 -0
  13. tinygrad/{mlops.py → function.py} +28 -14
  14. tinygrad/helpers.py +72 -43
  15. tinygrad/lazy.py +141 -240
  16. tinygrad/multi.py +169 -0
  17. tinygrad/nn/__init__.py +179 -8
  18. tinygrad/nn/datasets.py +7 -0
  19. tinygrad/nn/optim.py +106 -28
  20. tinygrad/nn/state.py +86 -17
  21. tinygrad/ops.py +70 -44
  22. tinygrad/renderer/__init__.py +61 -0
  23. tinygrad/renderer/assembly.py +276 -0
  24. tinygrad/renderer/cstyle.py +299 -206
  25. tinygrad/renderer/llvmir.py +118 -123
  26. tinygrad/runtime/autogen/amd_gpu.py +1900 -0
  27. tinygrad/runtime/autogen/comgr.py +865 -0
  28. tinygrad/runtime/autogen/cuda.py +5923 -0
  29. tinygrad/runtime/autogen/hip.py +5909 -0
  30. tinygrad/runtime/autogen/hsa.py +5761 -0
  31. tinygrad/runtime/autogen/kfd.py +812 -0
  32. tinygrad/runtime/autogen/nv_gpu.py +33328 -0
  33. tinygrad/runtime/autogen/opencl.py +1795 -0
  34. tinygrad/runtime/driver/hip_comgr.py +47 -0
  35. tinygrad/runtime/driver/hsa.py +143 -0
  36. tinygrad/runtime/graph/clang.py +38 -0
  37. tinygrad/runtime/graph/cuda.py +59 -54
  38. tinygrad/runtime/graph/hcq.py +143 -0
  39. tinygrad/runtime/graph/hsa.py +171 -0
  40. tinygrad/runtime/graph/metal.py +37 -41
  41. tinygrad/runtime/ops_amd.py +564 -0
  42. tinygrad/runtime/ops_clang.py +16 -14
  43. tinygrad/runtime/ops_cuda.py +130 -38
  44. tinygrad/runtime/ops_disk.py +45 -42
  45. tinygrad/runtime/ops_gpu.py +52 -50
  46. tinygrad/runtime/ops_hsa.py +278 -0
  47. tinygrad/runtime/ops_llvm.py +36 -56
  48. tinygrad/runtime/ops_metal.py +42 -24
  49. tinygrad/runtime/ops_npy.py +9 -0
  50. tinygrad/runtime/ops_nv.py +630 -0
  51. tinygrad/runtime/ops_python.py +204 -0
  52. tinygrad/shape/shapetracker.py +41 -105
  53. tinygrad/shape/symbolic.py +98 -95
  54. tinygrad/shape/view.py +137 -35
  55. tinygrad/tensor.py +2367 -442
  56. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
  57. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/METADATA +19 -9
  58. tinygrad-0.9.0.dist-info/RECORD +60 -0
  59. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
  60. tinygrad/features/image.py +0 -93
  61. tinygrad/features/multi.py +0 -103
  62. tinygrad/features/search.py +0 -160
  63. tinygrad/graph.py +0 -106
  64. tinygrad/jit.py +0 -152
  65. tinygrad/realize.py +0 -50
  66. tinygrad/runtime/graph/hip.py +0 -24
  67. tinygrad/runtime/ops_cpu.py +0 -45
  68. tinygrad/runtime/ops_hip.py +0 -97
  69. tinygrad/runtime/ops_torch.py +0 -49
  70. tinygrad-0.8.0.dist-info/RECORD +0 -41
  71. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
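Most of the churn in this release is a reorganization rather than new functionality: `tinygrad/mlops.py` becomes `tinygrad/function.py`, the old top-level `jit.py`, `realize.py`, and `graph.py` move under the new `tinygrad/engine/` package, and `tinygrad/features/multi.py` and `features/search.py` become `tinygrad/multi.py` and `tinygrad/engine/search.py`. A hedged sketch of what that means for import paths, using only the moves visible in the list above (top-level re-exports, if any, are not shown in this diff):

```python
# Import paths before (0.8.0) and after (0.9.0), inferred from the renames above.
# 0.8.0:
#   import tinygrad.mlops as mlops
#   from tinygrad.features.multi import MultiLazyBuffer
#   from tinygrad.realize import run_schedule

# 0.9.0 (these three are exactly what the new tensor.py imports below):
import tinygrad.function as F
from tinygrad.multi import MultiLazyBuffer
from tinygrad.engine.realize import run_schedule
```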
tinygrad/tensor.py CHANGED
@@ -1,19 +1,21 @@
  # inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
  from __future__ import annotations
- import time, math, itertools
- from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Iterable, DefaultDict, cast, get_args
+ import time, math, itertools, functools
+ from contextlib import ContextDecorator
+ from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Dict, DefaultDict, cast, get_args, Set
  from collections import defaultdict
- from functools import partialmethod, reduce
  import numpy as np

- from tinygrad.dtype import DType, dtypes, ImageDType, least_upper_float, least_upper_dtype
- from tinygrad.helpers import argfix, make_pair, getenv, IMAGE, DEBUG, flatten, prod, all_int, round_up, merge_dicts, fully_flatten
- from tinygrad.lazy import LazyBuffer, create_schedule
- from tinygrad.features.multi import MultiLazyBuffer
+ from tinygrad.dtype import DType, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype
+ from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, fully_flatten, argsort, getenv
+ from tinygrad.helpers import IMAGE, DEBUG, WINO, THREEFRY
+ from tinygrad.lazy import LazyBuffer
+ from tinygrad.multi import MultiLazyBuffer
  from tinygrad.ops import LoadOps
- from tinygrad.device import Device, Buffer
- from tinygrad.shape.symbolic import sint
- from tinygrad.realize import run_schedule
+ from tinygrad.device import Device, Buffer, BufferOptions
+ from tinygrad.shape.symbolic import sint, Variable, MulNode, Node
+ from tinygrad.engine.realize import run_schedule
+ from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars, memory_planner

  # **** start with two base classes, Tensor and Function ****

@@ -30,29 +32,68 @@ class Function:
  @classmethod
  def apply(fxn:Type[Function], *x:Tensor, **kwargs) -> Tensor:
  ctx = fxn(x[0].device, *x)
- ret = Tensor(ctx.forward(*[t.lazydata for t in x], **kwargs), device=ctx.device, requires_grad=ctx.requires_grad)
- if ctx.requires_grad and not Tensor.no_grad: ret._ctx = ctx # used by autograd engine
+ ret = Tensor.__new__(Tensor)
+ ret.lazydata, ret.requires_grad, ret.grad = ctx.forward(*[t.lazydata for t in x], **kwargs), ctx.requires_grad, None
+ ret._ctx = ctx if ctx.requires_grad and not Tensor.no_grad else None # used by autograd engine
  return ret

- import tinygrad.mlops as mlops
+ import tinygrad.function as F

- def _loadop(op, shape:Tuple[sint,...], dtype:DType, device:Union[str, Tuple[str, ...]], arg=None, src:Optional[LazyBuffer]=None):
+ def _loadop(op, shape:Tuple[sint,...], dtype:DType, device:Union[str, Tuple[str, ...]], arg=None, src:Tuple[LazyBuffer, ...]=()):
  if isinstance(device, str): return LazyBuffer.loadop(op, shape, dtype, device, arg, src)
  return MultiLazyBuffer([LazyBuffer.loadop(op, shape, dtype, d, arg, src) for d in device], None)

- Scalar = Union[float, int, bool]
+ def _fromcpu(x: np.ndarray) -> LazyBuffer:
+ ret = LazyBuffer.loadop(LoadOps.EMPTY, x.shape, dtypes.from_np(x.dtype), "NPY")
+ # fake realize
+ ret.buffer.allocate(x)
+ del ret.srcs
+ return ret
+
+ def _get_winograd_matcols(mat, dims:int, shp:Tuple[sint, ...], device:Union[str, Tuple[str, ...]]) -> List[List[Tensor]]:
+ return [[Tensor.cat(*[Tensor.full(shp[:dim] + (1,) + shp[dim+1:], float(m[k]), device=device) for m in mat], dim=dim)
+ for k in range(len(mat[0]))] for dim in range(dims)]
+
+ # winograd conv 3 kernel f(4x4,3x3) see: http://arxiv.org/abs/1509.09308
+ def _apply_winograd_matrix(mat, t:Tensor, dims:int) -> Tensor:
+ # multiply mat_1 @ mat_2 @ t with foldable constants, where mat_i acts on vector t along dimension i; roughly kron(mat, mat) @ t
+ # due to realize-before-expand rule in lazy.py, we must operate in this order: reshape -> expand -> arithmetic
+ t_ = t.reshape(t.shape[:dims] + (1,) * dims + t.shape[dims:]).expand(t.shape[:dims] + (len(mat),) * dims + t.shape[dims:]) # add output dims
+ # precalculate mat columns for each dim; prod(itertools.product(matcols)) gives the columns of kron(mat, mat, ...)
+ matcols = _get_winograd_matcols(mat, dims, t_.shape[dims:], t_.device)
+ # multiply each element of t_ by the corresponding stacked column of kron(mat, mat), producing only one view for each element of t
+ ret = sum(prod(col[idx] for col, idx in zip(matcols, mat_is)) * t_[mat_is] for mat_is in itertools.product(range(len(mat[0])), repeat=dims))
+ assert isinstance(ret, Tensor), "sum didn't return a Tensor"
+ return ret
+
+ def _pad_left(*shps:Tuple[sint, ...], v=1): return tuple((v,) * (max(len(i_) for i_ in shps) - len(i)) + i for i in shps)
+ def _broadcast_shape(*shps:Tuple[sint, ...]): return tuple(0 if any(sh_ == 0 for sh_ in sh) else max(sh) for sh in zip(*_pad_left(*shps)))
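The "roughly kron(mat, mat) @ t" comment in `_apply_winograd_matrix` is an identity that can be checked outside tinygrad: applying a matrix along each of two axes of `t` is the same as the Kronecker product acting on the flattened tensor. A standalone NumPy sketch (the 4x3 `mat` below is an arbitrary small matrix, not the actual Winograd transform):

```python
import numpy as np

mat = np.array([[1., 0., -1.], [0., 1., 1.], [0., -1., 1.], [0., 1., 0.]])  # any 4x3 matrix
t = np.arange(9.).reshape(3, 3)

per_axis = mat @ t @ mat.T                                    # mat applied along dim 0, then dim 1
via_kron = (np.kron(mat, mat) @ t.reshape(-1)).reshape(4, 4)  # kron(mat, mat) on the flattened t
assert np.allclose(per_axis, via_kron)
```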

  class Tensor:
+ """
+ A `Tensor` is a multi-dimensional matrix containing elements of a single data type.
+
+ ```python exec="true" session="tensor"
+ from tinygrad import Tensor, dtypes, nn
+ import numpy as np
+ import math
+ np.set_printoptions(precision=4)
+ ```
+ """
  __slots__ = "lazydata", "requires_grad", "grad", "_ctx"
  __deletable__ = ('_ctx',)
  training: ClassVar[bool] = False
- class train:
- def __init__(self, val=True): self.val = val
- def __enter__(self): self.prev, Tensor.training = Tensor.training, self.val
+ class train(ContextDecorator):
+ def __init__(self, mode:bool = True): self.mode = mode
+ def __enter__(self): self.prev, Tensor.training = Tensor.training, self.mode
  def __exit__(self, exc_type, exc_value, traceback): Tensor.training = self.prev

  no_grad: ClassVar[bool] = False
- def __init__(self, data:Union[None, Scalar, List, Tuple, LazyBuffer, np.ndarray, bytes, MultiLazyBuffer],
+ class inference_mode(ContextDecorator):
+ def __init__(self, mode:bool = True): self.mode = mode
+ def __enter__(self): self.prev, Tensor.no_grad = Tensor.no_grad, self.mode
+ def __exit__(self, exc_type, exc_value, traceback): Tensor.no_grad = self.prev
+ def __init__(self, data:Union[None, ConstType, List, Tuple, LazyBuffer, np.ndarray, bytes, MultiLazyBuffer, Variable],
56
97
  device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
57
98
  assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
58
99
  device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device)
@@ -66,18 +107,19 @@ class Tensor:
66
107
  # internal variables used for autograd graph construction
67
108
  self._ctx: Optional[Function] = None
68
109
  if isinstance(data, LazyBuffer): assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
69
- elif isinstance(data, get_args(Scalar)): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_py(data), device, data)
70
- elif isinstance(data, bytes): data = LazyBuffer.fromCPU(np.frombuffer(data, np.uint8))
110
+ elif isinstance(data, get_args(ConstType)): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_py(data), device, data)
111
+ elif isinstance(data, Variable): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_py(data.unbind()[1]), device, data)
112
+ elif isinstance(data, bytes): data = _fromcpu(np.frombuffer(data, np.uint8))
71
113
  elif data is None: data = _loadop(LoadOps.EMPTY, (0,), dtype or dtypes.default_float, device)
72
114
  elif isinstance(data, list):
73
- if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtype or dtypes.bool
74
- elif d and all_int(d): dtype = dtype or dtypes.default_int
75
- else: dtype = dtype or dtypes.default_float
76
- # NOTE: cast at the end for the dtypes that do not have a numpy dtype
77
- data = LazyBuffer.fromCPU(np.array(data, dtype.np)).cast(dtype)
115
+ if dtype is None:
116
+ if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtypes.bool
117
+ else: dtype = dtypes.default_int if d and all_int(d) else dtypes.default_float
118
+ if dtype == dtypes.bfloat16: data = Tensor(_fromcpu(np.array(data, np.float32)), device=device).cast(dtypes.bfloat16).lazydata
119
+ else: data = _fromcpu(np.array(data, dtype.np))
78
120
  elif isinstance(data, np.ndarray):
79
121
  if data.shape == (): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_np(data.dtype), device, data.item())
80
- else: data = LazyBuffer.fromCPU(data.astype(dtype.np) if dtype is not None and dtype.np is not None else data)
122
+ else: data = _fromcpu(data.astype(dtype.np) if dtype is not None and dtype.np is not None else data)
81
123
 
82
124
  # data is a LazyBuffer, but it might be on the wrong device
83
125
  if not isinstance(data, (LazyBuffer, MultiLazyBuffer)): raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}")
@@ -87,12 +129,15 @@ class Tensor:
87
129
  else:
88
130
  self.lazydata = data if data.device == device else data.copy_to_device(device)
89
131
 
90
- def __repr__(self):
91
- return f"<Tensor {self.lazydata!r} on {self.device} with grad {(self.grad.lazydata if self.grad else None)!r}>"
132
+ def __repr__(self): return f"<Tensor {self.lazydata!r} on {self.device} with grad {(self.grad.lazydata if self.grad is not None else None)!r}>"
92
133
 
93
134
  # Python has a non moving GC, so this should be okay
94
135
  def __hash__(self): return id(self)
95
136
 
137
+ def __bool__(self): raise TypeError("__bool__ on Tensor is not defined")
138
+
139
+ def __len__(self): return self.shape[0] if len(self.shape) else 1
140
+
96
141
  @property
97
142
  def device(self) -> Union[str, Tuple[str, ...]]: return self.lazydata.device
98
143
 
@@ -104,148 +149,513 @@ class Tensor:

  # ***** data handlers ****

- @staticmethod
- def corealize(lst:Iterable[Tensor]):
- return run_schedule(create_schedule(flatten([x.lazydata.lbs if isinstance(x.lazydata, MultiLazyBuffer) else [x.lazydata] for x in lst])))
+ def schedule_with_vars(self, *lst:Tensor, seen:Optional[Set[LazyBuffer]]=None) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
+ """Creates the schedule needed to realize these Tensor(s), with Variables."""
+ if getenv("FUZZ_SCHEDULE"):
+ from test.external.fuzz_schedule import fuzz_schedule
+ fuzz_schedule(flatten([x.lazydata.lbs for x in (self,)+lst]))
+ schedule, var_vals = create_schedule_with_vars(flatten([x.lazydata.lbs for x in (self,)+lst]), seen)
+ return memory_planner(schedule), var_vals
+
+ def schedule(self, *lst:Tensor, seen:Optional[Set[LazyBuffer]]=None) -> List[ScheduleItem]:
+ """Creates the schedule needed to realize these Tensor(s)."""
+ schedule, var_vals = self.schedule_with_vars(*lst, seen=seen)
+ assert len(var_vals) == 0
+ return schedule
+
+ def realize(self, *lst:Tensor, do_update_stats=True) -> Tensor:
+ """Triggers the computation needed to create these Tensor(s)."""
+ run_schedule(*self.schedule_with_vars(*lst), do_update_stats=do_update_stats)
+ return self
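The static `Tensor.corealize(lst)` is gone; its role is covered by the instance methods above, which accept extra tensors so several outputs can be scheduled and run together. A usage sketch based only on the signatures shown here (shapes are arbitrary):

```python
from tinygrad import Tensor

a, b = Tensor.rand(16), Tensor.rand(16)
c, d = a + b, a * b                # still lazy: nothing has been computed yet
print(len(c.schedule(d)))          # the ScheduleItems needed to realize both c and d
c.realize(d)                       # runs the kernels; replaces 0.8.0's Tensor.corealize([c, d])
```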
110
170
 
111
- def realize(self) -> Tensor:
112
- run_schedule(self.lazydata.schedule())
171
+ def replace(self, x:Tensor) -> Tensor:
172
+ """
173
+ Replaces the data of this tensor with the data of another tensor. Only the shape of the tensors must match.
174
+ """
175
+ # used for replacing a Tensor with a new version of it (potentially with a different device and dtype)
176
+ assert not x.requires_grad and getattr(self, '_ctx', None) is None
177
+ assert self.shape == x.shape, f"replace shape mismatch {self.shape} != {x.shape}"
178
+ self.lazydata = x.lazydata
113
179
  return self
114
180
 
115
181
  def assign(self, x) -> Tensor:
116
182
  # TODO: this is a hack for writing to DISK. remove with working assign
117
183
  if isinstance(self.device, str) and self.device.startswith("DISK"):
118
- if x.__class__ is not Tensor: x = Tensor(x, device="CPU", dtype=self.dtype)
184
+ if x.__class__ is not Tensor: x = Tensor(x, device="NPY", dtype=self.dtype)
119
185
  self.contiguous().realize().lazydata.base.realized.copyin(x.numpy().data)
120
186
  return self
121
187
  if x.__class__ is not Tensor: x = Tensor(x, device=self.device, dtype=self.dtype)
188
+ if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}")
189
+ if self.lazydata is x.lazydata: return self # a self assign is a NOOP
122
190
  # NOTE: we allow cross device assign
123
191
  assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}"
192
+ assert self.device == x.device, f"assign device mismatch {self.device} != {x.device}"
193
+ assert self.dtype == x.dtype, f"assign dtype mismatch {self.dtype} != {x.dtype}"
194
+ assert not isinstance(self.lazydata, MultiLazyBuffer) or self.lazydata.axis == x.lazydata.axis, "axis must match on MultiLazyBuffer"
124
195
  assert not x.requires_grad # self requires_grad is okay?
125
- if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}")
126
- if self.dtype == x.dtype and not getenv("DISALLOW_ASSIGN"):
127
- if isinstance(self.lazydata, MultiLazyBuffer):
128
- for d,s in zip(x.lazydata.lbs, self.lazydata.lbs): d.output_buffer = s.base.realized
129
- else:
130
- if self.lazydata.base.realized is not None: x.lazydata.output_buffer = self.lazydata.base.realized
131
- self.lazydata = x.lazydata
196
+ if not self.lazydata.is_realized(): return self.replace(x)
197
+ self.lazydata = self.lazydata.assign(x.lazydata)
132
198
  return self
133
- def detach(self) -> Tensor: return Tensor(self.lazydata, device=self.device, requires_grad=False)
134
-
135
- # TODO: these are good places to start removing numpy
136
- def item(self) -> Scalar:
199
+ def detach(self) -> Tensor:
200
+ """
201
+ Returns a new tensor with the same data as this tensor, but detached from the autograd graph.
202
+ """
203
+ return Tensor(self.lazydata, device=self.device, requires_grad=False)
204
+
205
+ def _data(self) -> memoryview:
206
+ if 0 in self.shape: return memoryview(bytearray(0))
207
+ # NOTE: this realizes on the object from as_buffer being a Python object
208
+ cpu = self.cast(self.dtype.scalar()).contiguous().to("CLANG").realize()
209
+ buf = cast(Buffer, cast(LazyBuffer, cpu.lazydata).base.realized)
210
+ if self.device != "CLANG": buf.options = BufferOptions(nolru=True)
211
+ return buf.as_buffer(allow_zero_copy=True if self.device != "CLANG" else False)
212
+
213
+ def data(self) -> memoryview:
214
+ """
215
+ Returns the data of this tensor as a memoryview.
216
+
217
+ ```python exec="true" source="above" session="tensor" result="python"
218
+ t = Tensor([1, 2, 3, 4])
219
+ print(np.frombuffer(t.data(), dtype=np.int32))
220
+ ```
221
+ """
222
+ assert self.dtype.fmt is not None, f"no fmt dtype for {self.dtype}"
223
+ assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
224
+ return self._data().cast(self.dtype.fmt, self.shape)
225
+
226
+ def item(self) -> ConstType:
227
+ """
228
+ Returns the value of this tensor as a standard Python number.
229
+
230
+ ```python exec="true" source="above" session="tensor" result="python"
231
+ t = Tensor(42)
232
+ print(t.item())
233
+ ```
234
+ """
235
+ assert self.dtype.fmt is not None, f"no fmt dtype for {self.dtype}"
137
236
  assert self.numel() == 1, "must have one element for item"
138
- return cast(Buffer, self.contiguous().realize().lazydata.base.realized).toCPU().item()
139
- def data(self) -> memoryview: return self.numpy().data
237
+ return self._data().cast(self.dtype.fmt)[0]
238
+
239
+ # TODO: should be Tensor.tolist() -> Union[List[ConstType], ConstType]. The List is Sequence because mypy expects memoryview.tolist() -> list[int]
240
+ # src: https://github.com/python/mypy/blob/release-1.6/mypy/typeshed/stdlib/builtins.pyi#L803
241
+ def tolist(self) -> Union[Sequence[ConstType], ConstType]:
242
+ """
243
+ Returns the value of this tensor as a nested list.
244
+
245
+ ```python exec="true" source="above" session="tensor" result="python"
246
+ t = Tensor([1, 2, 3, 4])
247
+ print(t.tolist())
248
+ ```
249
+ """
250
+ return self.data().tolist()
140
251
 
141
- # TODO: this should import numpy and use .data() to construct the array
142
252
  def numpy(self) -> np.ndarray:
143
- assert all_int(self.shape), f"no numpy if shape is symbolic, {self.shape=}"
144
- assert self.dtype.np is not None, f"no numpy dtype for {self.dtype}"
145
- if 0 in self.shape: return np.zeros(self.shape, dtype=self.dtype.np)
146
- t = self if isinstance(self.device, str) else self.to("CPU")
147
- return t.cast(self.dtype.scalar()).contiguous().realize().lazydata.base.realized.toCPU().astype(self.dtype.np, copy=True).reshape(self.shape)
148
-
149
- def to(self, device:Optional[str]) -> Tensor:
150
- if device is None or device == self.device: return self
151
- ret = Tensor(self.lazydata, device)
152
- if self.grad: ret.grad = self.grad.to(device)
253
+ """
254
+ Returns the value of this tensor as a `numpy.ndarray`.
255
+
256
+ ```python exec="true" source="above" session="tensor" result="python"
257
+ t = Tensor([1, 2, 3, 4])
258
+ print(repr(t.numpy()))
259
+ ```
260
+ """
261
+ if self.dtype == dtypes.bfloat16: return self.float().numpy()
262
+ assert self.dtype.np is not None, f"no np dtype for {self.dtype}"
263
+ assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
264
+ return np.frombuffer(self._data(), dtype=self.dtype.np).reshape(self.shape)
265
+
266
+ def to(self, device:Optional[Union[str, Tuple[str, ...]]]) -> Tensor:
267
+ """
268
+ Moves the tensor to the given device.
269
+ """
270
+ device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device)
271
+ if device == self.device: return self
272
+ if not isinstance(device, str): return self.shard(device)
273
+ ret = Tensor(self.lazydata, device, requires_grad=self.requires_grad)
274
+ if self.grad is not None: ret.grad = self.grad.to(device)
275
+ if hasattr(self, '_ctx'): ret._ctx = self._ctx
153
276
  return ret
154
277
 
155
- def to_(self, device:Optional[str]):
156
- if device is None or device == self.device: return
157
- if self.grad: self.grad = self.grad.to_(device)
158
- _ret = Tensor(self.lazydata, device)
159
- self.lazydata = _ret.lazydata
278
+ def to_(self, device:Optional[Union[str, Tuple[str, ...]]]):
279
+ """
280
+ Moves the tensor to the given device in place.
281
+ """
282
+ real = self.to(device)
283
+ # TODO: is this assign?
284
+ if self.grad is not None and real.grad is not None: self.grad.lazydata = real.grad.lazydata
285
+ self.lazydata = real.lazydata
160
286
 
161
287
  def shard(self, devices:Tuple[str, ...], axis:Optional[int]=None) -> Tensor:
288
+ """
289
+ Shards the tensor across the given devices.
290
+ """
162
291
  assert isinstance(self.lazydata, LazyBuffer), "can't shard a MultiLazyBuffer"
163
292
  canonical_devices = tuple(Device.canonicalize(x) for x in devices)
164
- return Tensor(MultiLazyBuffer.from_sharded(self.lazydata, canonical_devices, axis), device=canonical_devices)
293
+ if axis is not None and axis < 0: axis += len(self.shape)
294
+ return Tensor(MultiLazyBuffer.from_sharded(self.lazydata, canonical_devices, axis), device=canonical_devices, requires_grad=self.requires_grad)
165
295
 
166
296
  def shard_(self, devices:Tuple[str, ...], axis:Optional[int]=None):
297
+ """
298
+ Shards the tensor across the given devices in place.
299
+ """
167
300
  self.lazydata = self.shard(devices, axis).lazydata
168
301
  return self
169
302
 
303
+ @staticmethod
304
+ def from_node(y:Node, **kwargs) -> Tensor:
305
+ if isinstance(y, MulNode): return Tensor.from_node(y.a, **kwargs) * y.b
306
+ if isinstance(y, Variable): return Tensor(y, **kwargs, requires_grad=False)
307
+ raise RuntimeError(f"unhandled Node {y}")
308
+
170
309
  # ***** creation llop entrypoint *****
171
310
 
172
311
  @staticmethod
173
- def _loadop(op, shape, device:Optional[str]=None, dtype:Optional[DType]=None, arg=None, **kwargs):
174
- return Tensor(LazyBuffer.loadop(op, shape, dtype or dtypes.default_float, Device.canonicalize(device), arg), dtype=dtype, device=device, **kwargs)
312
+ def _loadop(op, shape, device:Optional[Union[Tuple[str, ...], str]]=None, dtype:Optional[DType]=None, arg=None, **kwargs):
313
+ if isinstance(device, tuple):
314
+ return Tensor(MultiLazyBuffer([LazyBuffer.loadop(op, shape, dtype or dtypes.default_float, Device.canonicalize(d), arg) \
315
+ for d in device], None), device, dtype, **kwargs)
316
+ return Tensor(LazyBuffer.loadop(op, shape, dtype or dtypes.default_float, Device.canonicalize(device), arg), device, dtype, **kwargs)
175
317
 
176
318
  @staticmethod
177
- def empty(*shape, **kwargs): return Tensor._loadop(LoadOps.EMPTY, argfix(*shape), **kwargs)
319
+ def empty(*shape, **kwargs):
320
+ """
321
+ Creates an empty tensor with the given shape.
322
+
323
+ You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
324
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
325
+
326
+ ```python exec="true" source="above" session="tensor" result="python"
327
+ t = Tensor.empty(2, 3)
328
+ print(t.shape)
329
+ ```
330
+ """
331
+ return Tensor._loadop(LoadOps.EMPTY, argfix(*shape), **kwargs)
178
332
 
179
333
  _seed: int = int(time.time())
334
+ _rng_counter: Optional[Tensor] = None
180
335
  @staticmethod
181
- def manual_seed(seed=0): Tensor._seed = seed
336
+ def manual_seed(seed=0):
337
+ """
338
+ Sets the seed for random operations.
339
+
340
+ ```python exec="true" source="above" session="tensor" result="python"
341
+ Tensor.manual_seed(42)
342
+ print(Tensor._seed)
343
+ ```
344
+ """
345
+ Tensor._seed, Tensor._rng_counter = seed, Tensor([0], dtype=dtypes.uint32, requires_grad=False)
182
346
 
183
347
  @staticmethod
184
- def rand(*shape, **kwargs): return Tensor._loadop(LoadOps.CUSTOM, argfix(*shape), arg=custom_random, **kwargs)
348
+ def rand(*shape, device:Optional[Union[Tuple[str, ...], str]]=None, dtype:Optional[DType]=None, **kwargs):
349
+ """
350
+ Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[0, 1)`.
351
+
352
+ You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
353
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
354
+
355
+ ```python exec="true" source="above" session="tensor" result="python"
356
+ Tensor.manual_seed(42)
357
+ t = Tensor.rand(2, 3)
358
+ print(t.numpy())
359
+ ```
360
+ """
361
+ if Tensor._rng_counter is None: Tensor._rng_counter = Tensor([0], dtype=dtypes.uint32, requires_grad=False)
362
+ if not THREEFRY.value:
363
+ # for bfloat16, numpy rand passes buffer in float
364
+ if (dtype or dtypes.default_float) == dtypes.bfloat16:
365
+ return Tensor.rand(*shape, **kwargs, device=device, dtype=dtypes.float).cast(dtypes.bfloat16)
366
+ return Tensor._loadop(LoadOps.CUSTOM, argfix(*shape), arg=custom_random, device=device, dtype=dtype, **kwargs)
367
+
368
+ # threefry
369
+ if (num := prod((shape:=argfix(*shape)))) == 0: return Tensor.zeros(shape, device=device, dtype=dtype, **kwargs)
370
+ counts1 = (Tensor.arange(math.ceil(num / 2), device=device, dtype=dtypes.uint32, requires_grad=False)+Tensor._rng_counter.to(device)).realize()
371
+ counts2 = counts1 + math.ceil(num / 2)
372
+ Tensor._rng_counter.assign(Tensor._rng_counter + num).realize()
373
+
374
+ rotations = [[13, 15, 26, 6], [17, 29, 16, 24]]
375
+ ks = [0x0, Tensor._seed ^ 0x0 ^ 0x1BD11BDA, Tensor._seed]
376
+
377
+ x = [counts1 + ks[-1], counts2 + ks[0]]
378
+ for i in range(5):
379
+ for r in rotations[i % 2]: x[0], x[1] = (x0 := x[0] + x[1]), x0 ^ ((x[1] << r) + (x[1] >> (32 - r)))
380
+ x = [(x[0] + ks[i % 3]), (x[1] + ks[(i + 1) % 3] + i + 1)]
381
+ out = x[0].cat(x[1]).rshift(8).cast(dtypes.float32).div(2 ** 24)[:num]
382
+ out = out.reshape(shape).cast(dtypes.default_float if dtype is None else dtype)
383
+ out.requires_grad = kwargs.get("requires_grad")
384
+ return out.contiguous()
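The counter-based generator above is Threefry-2x32: two 32-bit counter halves are mixed with a seed-derived key schedule through five key injections and twenty add-rotate-xor steps, then the mixed words are turned into floats in `[0, 1)`. A plain-Python sketch of the same round structure (scalar ints with explicit 32-bit masking standing in for the uint32 Tensors):

```python
M32 = 0xFFFFFFFF
def rotl32(v: int, r: int) -> int: return ((v << r) | (v >> (32 - r))) & M32

def threefry2x32(c0: int, c1: int, seed: int) -> list:
  rotations = [[13, 15, 26, 6], [17, 29, 16, 24]]
  ks = [0, (seed ^ 0x1BD11BDA) & M32, seed & M32]
  x = [(c0 + ks[-1]) & M32, (c1 + ks[0]) & M32]
  for i in range(5):
    for r in rotations[i % 2]:
      x[0] = (x[0] + x[1]) & M32          # add
      x[1] = rotl32(x[1], r) ^ x[0]       # rotate the old x[1], xor with the new x[0]
    x = [(x[0] + ks[i % 3]) & M32, (x[1] + ks[(i + 1) % 3] + i + 1) & M32]  # key injection
  return x

# same seed and counter always give the same words, which is what makes rand reproducible
assert threefry2x32(0, 1, 42) == threefry2x32(0, 1, 42)
```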
185
385
 
186
386
  # ***** creation helper functions *****
187
387
 
188
388
  @staticmethod
189
- def full(shape:Tuple[sint, ...], fill_value:Scalar, **kwargs):
389
+ def full(shape:Tuple[sint, ...], fill_value:ConstType, **kwargs):
390
+ """
391
+ Creates a tensor with the given shape, filled with the given value.
392
+
393
+ You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
394
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
395
+
396
+ ```python exec="true" source="above" session="tensor" result="python"
397
+ print(Tensor.full((2, 3), 42).numpy())
398
+ ```
399
+ ```python exec="true" source="above" session="tensor" result="python"
400
+ print(Tensor.full((2, 3), False).numpy())
401
+ ```
402
+ """
190
403
  return Tensor(fill_value, **kwargs).reshape((1, )*len(new_shape := argfix(shape))).expand(new_shape)
191
404
 
192
405
  @staticmethod
193
- def zeros(*shape, **kwargs): return Tensor.full(argfix(*shape), 0.0, **kwargs)
406
+ def zeros(*shape, **kwargs):
407
+ """
408
+ Creates a tensor with the given shape, filled with zeros.
409
+
410
+ You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
411
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
412
+
413
+ ```python exec="true" source="above" session="tensor" result="python"
414
+ print(Tensor.zeros(2, 3).numpy())
415
+ ```
416
+ ```python exec="true" source="above" session="tensor" result="python"
417
+ print(Tensor.zeros(2, 3, dtype=dtypes.int32).numpy())
418
+ ```
419
+ """
420
+ return Tensor.full(argfix(*shape), 0.0, **kwargs)
194
421
 
195
422
  @staticmethod
196
- def ones(*shape, **kwargs): return Tensor.full(argfix(*shape), 1.0, **kwargs)
423
+ def ones(*shape, **kwargs):
424
+ """
425
+ Creates a tensor with the given shape, filled with ones.
426
+
427
+ You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
428
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
429
+
430
+ ```python exec="true" source="above" session="tensor" result="python"
431
+ print(Tensor.ones(2, 3).numpy())
432
+ ```
433
+ ```python exec="true" source="above" session="tensor" result="python"
434
+ print(Tensor.ones(2, 3, dtype=dtypes.int32).numpy())
435
+ ```
436
+ """
437
+ return Tensor.full(argfix(*shape), 1.0, **kwargs)
197
438
 
198
439
  @staticmethod
199
440
  def arange(start, stop=None, step=1, **kwargs):
441
+ """
442
+ Returns a 1-D tensor of size `ceil((stop - start) / step)` with values from `[start, stop)`, with spacing between values given by `step`.
443
+
444
+ If `stop` is not specified, values are generated from `[0, start)` with the given `step`.
445
+
446
+ If `stop` is specified, values are generated from `[start, stop)` with the given `step`.
447
+
448
+ You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
449
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
450
+
451
+ ```python exec="true" source="above" session="tensor" result="python"
452
+ print(Tensor.arange(5).numpy())
453
+ ```
454
+ ```python exec="true" source="above" session="tensor" result="python"
455
+ print(Tensor.arange(5, 10).numpy())
456
+ ```
457
+ ```python exec="true" source="above" session="tensor" result="python"
458
+ print(Tensor.arange(5, 10, 2).numpy())
459
+ ```
460
+ ```python exec="true" source="above" session="tensor" result="python"
461
+ print(Tensor.arange(5.5, 10, 2).numpy())
462
+ ```
463
+ """
200
464
  if stop is None: stop, start = start, 0
465
+ assert all(isinstance(s, (int, float)) for s in (start, stop, step)), f"symbolic arange not supported {start=}, {stop=}, {step=}"
201
466
  dtype = kwargs.pop("dtype", dtypes.default_float if any(isinstance(x, float) for x in (start, stop, step)) else dtypes.default_int)
202
- return (Tensor.full((math.ceil((stop-start)/step),), step, dtype=dtype, **kwargs).cumsum() + (start - step)).cast(dtype)
467
+ return (Tensor.full((math.ceil((stop-start)/step),), step, dtype=dtype, **kwargs)._cumsum() + (start - step)).cast(dtype)
203
468
 
204
469
  @staticmethod
205
470
  def eye(dim:int, **kwargs):
471
+ """
472
+ Creates an identity matrix of the given dimension.
473
+
474
+ You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
475
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
476
+
477
+ ```python exec="true" source="above" session="tensor" result="python"
478
+ print(Tensor.eye(3).numpy())
479
+ ```
480
+ """
206
481
  return Tensor.ones((dim,1),**kwargs).pad((None,(0,dim))).flatten().shrink(((0,dim*dim),)).reshape(dim, dim)
207
482
 
208
- def full_like(self, fill_value:Scalar, **kwargs):
483
+ def full_like(self, fill_value:ConstType, **kwargs):
484
+ """
485
+ Creates a tensor with the same shape as `self`, filled with the given value.
486
+ If `dtype` is not specified, the dtype of `self` is used.
487
+
488
+ You can pass in the `device` keyword argument to control device of the tensor.
489
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
490
+
491
+ ```python exec="true" source="above" session="tensor" result="python"
492
+ t = Tensor.ones(2, 3)
493
+ print(Tensor.full_like(t, 42).numpy())
494
+ ```
495
+ """
209
496
  return Tensor.full(self.shape, fill_value, dtype=kwargs.pop("dtype", self.dtype), device=kwargs.pop("device", self.device), **kwargs)
210
- def zeros_like(self, **kwargs): return self.full_like(0, **kwargs)
211
- def ones_like(self, **kwargs): return self.full_like(1, **kwargs)
497
+
498
+ def zeros_like(self, **kwargs):
499
+ """
500
+ Creates a tensor with the same shape as `self`, filled with zeros.
501
+
502
+ You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
503
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
504
+
505
+ ```python exec="true" source="above" session="tensor" result="python"
506
+ t = Tensor.ones(2, 3)
507
+ print(Tensor.zeros_like(t).numpy())
508
+ ```
509
+ """
510
+ return self.full_like(0, **kwargs)
511
+
512
+ def ones_like(self, **kwargs):
513
+ """
514
+ Creates a tensor with the same shape as `self`, filled with ones.
515
+
516
+ You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
517
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
518
+
519
+ ```python exec="true" source="above" session="tensor" result="python"
520
+ t = Tensor.zeros(2, 3)
521
+ print(Tensor.ones_like(t).numpy())
522
+ ```
523
+ """
524
+ return self.full_like(1, **kwargs)
212
525
 
213
526
  # ***** rng hlops *****
214
527
 
215
528
  @staticmethod
216
529
  def randn(*shape, dtype:Optional[DType]=None, **kwargs) -> Tensor:
530
+ """
531
+ Creates a tensor with the given shape, filled with random values from a normal distribution with mean `0` and standard deviation `1`.
532
+ If `dtype` is not specified, the default type is used.
533
+
534
+ You can pass in the `device` keyword argument to control device of the tensor.
535
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
536
+
537
+ ```python exec="true" source="above" session="tensor" result="python"
538
+ Tensor.manual_seed(42)
539
+ print(Tensor.randn(2, 3).numpy())
540
+ ```
541
+ """
217
542
  # https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
218
- src = Tensor.rand((2, *argfix(*shape)), **kwargs)
543
+ src = Tensor.rand((2, *argfix(*shape)), **{**kwargs, "dtype": dtypes.float32})
219
544
  return src[0].mul(2*math.pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(dtype or dtypes.default_float)
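The two uniform tensors feed the Box-Muller transform linked above; the same arithmetic in NumPy, as a quick sanity check rather than tinygrad code:

```python
import numpy as np

u1, u2 = np.random.rand(2, 100_000)
z = np.cos(2 * np.pi * u1) * np.sqrt(-2 * np.log(1 - u2))   # mirrors src[0] and src[1] above
print(round(float(z.mean()), 2), round(float(z.std()), 2))  # approximately 0.0 and 1.0
```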
220
545
 
221
546
  @staticmethod
222
- def randint(*shape, low=0, high=10, **kwargs) -> Tensor: return Tensor.uniform(*shape, low=low, high=high, dtype=dtypes.int32)
547
+ def randint(*shape, low=0, high=10, **kwargs) -> Tensor:
548
+ """
549
+ Creates a tensor with the given shape, filled with random integer values generated uniformly from the interval `[low, high)`.
550
+ If `dtype` is not specified, the default type is used.
551
+
552
+ You can pass in the `device` keyword argument to control device of the tensor.
553
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
554
+
555
+ ```python exec="true" source="above" session="tensor" result="python"
556
+ Tensor.manual_seed(42)
557
+ print(Tensor.randint(2, 3, low=5, high=10).numpy())
558
+ ```
559
+ """
560
+ if not isinstance(low, int) or not isinstance(high, int): raise TypeError(f"{low=} and {high=} must be integers")
561
+ dtype = kwargs.pop("dtype", dtypes.int32)
562
+ if not dtypes.is_int(dtype): raise TypeError(f"{dtype=} must be int")
563
+ return Tensor.uniform(*shape, low=low, high=high, dtype=dtype, **kwargs)
223
564
 
224
565
  @staticmethod
225
- def normal(*shape, mean=0.0, std=1.0, **kwargs) -> Tensor: return (std * Tensor.randn(*shape, **kwargs)) + mean
566
+ def normal(*shape, mean=0.0, std=1.0, **kwargs) -> Tensor:
567
+ """
568
+ Creates a tensor with the given shape, filled with random values from a normal distribution with the given `mean` and standard deviation `std`.
569
+
570
+ You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
571
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
572
+
573
+ ```python exec="true" source="above" session="tensor" result="python"
574
+ Tensor.manual_seed(42)
575
+ print(Tensor.normal(2, 3, mean=10, std=2).numpy())
576
+ ```
577
+ """
578
+ return (std * Tensor.randn(*shape, **kwargs)) + mean
226
579
 
227
580
  @staticmethod
228
581
  def uniform(*shape, low=0.0, high=1.0, **kwargs) -> Tensor:
582
+ """
583
+ Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[low, high)`.
584
+
585
+ You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
586
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
587
+
588
+ ```python exec="true" source="above" session="tensor" result="python"
589
+ Tensor.manual_seed(42)
590
+ print(Tensor.uniform(2, 3, low=2, high=10).numpy())
591
+ ```
592
+ """
229
593
  dtype = kwargs.pop("dtype", dtypes.default_float)
230
594
  return ((high-low) * Tensor.rand(*shape, **kwargs)).cast(dtype) + low
231
595
 
232
596
  @staticmethod
233
- def scaled_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul(prod(argfix(*shape))**-0.5)
597
+ def scaled_uniform(*shape, **kwargs) -> Tensor:
598
+ """
599
+ Creates a tensor with the given shape, filled with random values from a uniform distribution
600
+ over the interval `[-prod(shape)**-0.5, prod(shape)**-0.5)`.
601
+
602
+ You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
603
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
604
+
605
+ ```python exec="true" source="above" session="tensor" result="python"
606
+ Tensor.manual_seed(42)
607
+ print(Tensor.scaled_uniform(2, 3).numpy())
608
+ ```
609
+ """
610
+ return Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul(prod(argfix(*shape))**-0.5)
234
611
 
235
612
  # https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform
236
613
  @staticmethod
237
614
  def glorot_uniform(*shape, **kwargs) -> Tensor:
615
+ """
616
+ <https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform>
617
+
618
+ You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
619
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
620
+
621
+ ```python exec="true" source="above" session="tensor" result="python"
622
+ Tensor.manual_seed(42)
623
+ print(Tensor.glorot_uniform(2, 3).numpy())
624
+ ```
625
+ """
238
626
  return Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul((6/(argfix(*shape)[0]+prod(argfix(*shape)[1:])))**0.5)
239
627
 
240
628
  # https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_
241
629
  @staticmethod
242
630
  def kaiming_uniform(*shape, a:float = 0.01, **kwargs) -> Tensor:
631
+ """
632
+ <https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_>
633
+
634
+ You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
635
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
636
+
637
+ ```python exec="true" source="above" session="tensor" result="python"
638
+ Tensor.manual_seed(42)
639
+ print(Tensor.kaiming_uniform(2, 3).numpy())
640
+ ```
641
+ """
243
642
  bound = math.sqrt(3.0) * math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(argfix(*shape)[1:]))
244
643
  return Tensor.uniform(*shape, low=-bound, high=bound, **kwargs)
245
644
 
246
645
  # https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_
247
646
  @staticmethod
248
647
  def kaiming_normal(*shape, a:float = 0.01, **kwargs) -> Tensor:
648
+ """
649
+ <https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_>
650
+
651
+ You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
652
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
653
+
654
+ ```python exec="true" source="above" session="tensor" result="python"
655
+ Tensor.manual_seed(42)
656
+ print(Tensor.kaiming_normal(2, 3).numpy())
657
+ ```
658
+ """
249
659
  std = math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(argfix(*shape)[1:]))
250
660
  return Tensor.normal(*shape, mean=0.0, std=std, **kwargs)
251
661
 
@@ -254,31 +664,40 @@ class Tensor:
  assert replacement or num_samples == 1, "no replacement only supports num_samples = 1"
  weight = self.unsqueeze(0) if self.ndim == 1 else self
  cdf = (cw := weight.cumsum(1).float()) / cw[:, -1].unsqueeze(1)
- unif_samples = Tensor.rand(num_samples, cdf.shape[0], 1)
+ unif_samples = Tensor.rand(num_samples, cdf.shape[0], 1, device=self.device)
  indices = (unif_samples.expand((-1, -1, cdf.shape[1])) >= cdf).sum(2).permute((1, 0))
- return (indices.squeeze(0) if self.ndim == 1 else indices).cast(dtypes.default_int)
+ return (indices.squeeze(0) if self.ndim == 1 else indices).cast(dtypes.int32)
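multinomial is inverse-CDF sampling: each uniform draw is compared against the normalized cumulative weights, and the number of CDF entries it clears is the sampled index. The same idea in NumPy (illustrative only, not the tinygrad implementation):

```python
import numpy as np

weights = np.array([0.1, 0.2, 0.7])
cdf = np.cumsum(weights) / weights.sum()           # [0.1, 0.3, 1.0]
u = np.random.rand(10_000, 1)                      # one uniform per sample
samples = (u >= cdf).sum(axis=1)                   # index = how many cdf entries the draw clears
print(np.bincount(samples, minlength=3) / 10_000)  # approaches [0.1, 0.2, 0.7]
```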
260
670
 
261
671
  # ***** toposort and backward pass *****
262
672
 
263
- def deepwalk(self):
264
- def _deepwalk(node, visited, nodes):
673
+ def _deepwalk(self):
674
+ def _walk(node, visited):
265
675
  visited.add(node)
266
676
  if getattr(node, "_ctx", None):
267
677
  for i in node._ctx.parents:
268
- if i not in visited: _deepwalk(i, visited, nodes)
269
- nodes.append(node)
270
- return nodes
271
- return _deepwalk(self, set(), [])
678
+ if i not in visited: yield from _walk(i, visited)
679
+ yield node
680
+ return list(_walk(self, set()))
272
681
 
273
682
  def backward(self) -> Tensor:
683
+ """
684
+ Propagates the gradient of a tensor backwards through the computation graph.
685
+ Must be used on a scalar tensor.
686
+
687
+ ```python exec="true" source="above" session="tensor" result="python"
688
+ t = Tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
689
+ t.sum().backward()
690
+ print(t.grad.numpy())
691
+ ```
692
+ """
274
693
  assert self.shape == tuple(), f"backward can only be called for scalar tensors, but it has shape {self.shape})"
275
694
 
276
695
  # fill in the first grad with one. don't use Tensor.ones because we don't need contiguous
277
696
  # this is "implicit gradient creation"
278
- self.grad = Tensor(1.0, device=self.device, requires_grad=False)
697
+ self.grad = Tensor(1.0, dtype=self.dtype, device=self.device, requires_grad=False)
279
698
 
280
- for t0 in reversed(self.deepwalk()):
281
- assert (t0.grad is not None)
699
+ for t0 in reversed(self._deepwalk()):
700
+ if t0.grad is None: raise RuntimeError(f"tensor {t0} has no grad")
282
701
  grads = t0._ctx.backward(t0.grad.lazydata)
283
702
  grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None
284
703
  for g in ([grads] if len(t0._ctx.parents) == 1 else grads)]
@@ -289,67 +708,164 @@ class Tensor:
289
708
  del t0._ctx
290
709
  return self
291
710
 
292
- # ***** movement mlops *****
711
+ # ***** movement low level ops *****
712
+
713
+ def view(self, *shape) -> Tensor:
714
+ """`.view` is an alias for `.reshape`."""
715
+ return self.reshape(shape)
293
716
 
294
717
  def reshape(self, shape, *args) -> Tensor:
718
+ """
719
+ Returns a tensor with the same data as the original tensor but with a different shape.
720
+ `shape` can be passed as a tuple or as separate arguments.
721
+
722
+ ```python exec="true" source="above" session="tensor" result="python"
723
+ t = Tensor.arange(6)
724
+ print(t.reshape(2, 3).numpy())
725
+ ```
726
+ """
295
727
  new_shape = argfix(shape, *args)
296
- return mlops.Reshape.apply(self, shape=tuple([-prod(self.shape) // prod(new_shape) if s == -1 else (s if s is not None else self.shape[i]) for i,s in enumerate(new_shape)])) # noqa: E501
728
+ new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else (s if s is not None else self.shape[i]) for i,s in enumerate(new_shape)])
729
+ return F.Reshape.apply(self, shape=new_shape) if new_shape != self.shape else self
730
+
297
731
  def expand(self, shape, *args) -> Tensor:
298
- if shape == self.shape: return self
299
- return mlops.Expand.apply(self, shape=tuple([x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args))]))
300
- def permute(self, order, *args) -> Tensor: return mlops.Permute.apply(self, order=argfix(order, *args))
301
- def flip(self, axis, *args) -> Tensor: return mlops.Flip.apply(self, axis=[x if x >= 0 else x+len(self.shape) for x in argfix(axis, *args)])
732
+ """
733
+ Returns a tensor that is expanded to the shape that is specified.
734
+ Expand can also increase the number of dimensions that a tensor has.
735
+
736
+ Passing a `-1` or `None` to a dimension means that its size will not be changed.
737
+
738
+ ```python exec="true" source="above" session="tensor" result="python"
739
+ t = Tensor([1, 2, 3])
740
+ print(t.expand(4, -1).numpy())
741
+ ```
742
+ """
743
+ return self._broadcast_to(tuple(sh if s==-1 or s is None else s for s, sh in zip(*(_pad_left(argfix(shape, *args), self.shape)))))
744
+
745
+ def permute(self, order, *args) -> Tensor:
746
+ """
747
+ Returns a tensor that is a permutation of the original tensor.
748
+ The new tensor has the same data as the original tensor but with the dimensions permuted according to the order specified.
749
+ `order` can be passed as a tuple or as separate arguments.
750
+
751
+ ```python exec="true" source="above" session="tensor" result="python"
752
+ t = Tensor.arange(6).reshape(2, 3)
753
+ print(t.numpy())
754
+ ```
755
+ ```python exec="true" source="above" session="tensor" result="python"
756
+ print(t.permute(1, 0).numpy())
757
+ ```
758
+ """
759
+ return F.Permute.apply(self, order=argfix(order, *args))
760
+
761
+ def flip(self, axis, *args) -> Tensor:
762
+ """
763
+ Returns a tensor that reverses the order of the original tensor along given `axis`.
764
+ `axis` can be passed as a tuple or as separate arguments.
765
+
766
+ ```python exec="true" source="above" session="tensor" result="python"
767
+ t = Tensor.arange(6).reshape(2, 3)
768
+ print(t.numpy())
769
+ ```
770
+ ```python exec="true" source="above" session="tensor" result="python"
771
+ print(t.flip(0).numpy())
772
+ ```
773
+ ```python exec="true" source="above" session="tensor" result="python"
774
+ print(t.flip((0, 1)).numpy())
775
+ ```
776
+ """
777
+ return F.Flip.apply(self, axis=[x if x >= 0 else x+len(self.shape) for x in argfix(axis, *args)])
778
+
302
779
  def shrink(self, arg:Tuple[Optional[Tuple[sint, sint]], ...]) -> Tensor:
303
- if not any(x is not None and x != (0,s) for x,s in zip(arg, self.shape)): return self
304
- return mlops.Shrink.apply(self, arg=tuple(x if x is not None else (0,s) for x,s in zip(arg, self.shape)))
780
+ """
781
+ Returns a tensor that shrinks the each axis based on input arg.
782
+ `arg` must have the same length as `self.ndim`.
783
+ For each axis, it can be `None`, which means no shrink, or a tuple `(start, end)` that works the same as Python slice.
784
+
785
+ ```python exec="true" source="above" session="tensor" result="python"
786
+ t = Tensor.arange(9).reshape(3, 3)
787
+ print(t.numpy())
788
+ ```
789
+ ```python exec="true" source="above" session="tensor" result="python"
790
+ print(t.shrink(((None, (1, 3)))).numpy())
791
+ ```
792
+ ```python exec="true" source="above" session="tensor" result="python"
793
+ print(t.shrink((((0, 2), (0, 2)))).numpy())
794
+ ```
795
+ """
796
+ if all(x is None or x == (0,s) for x,s in zip(arg, self.shape)): return self
797
+ return F.Shrink.apply(self, arg=tuple(x if x is not None else (0,s) for x,s in zip(arg, self.shape)))
798
+
305
799
  def pad(self, arg:Tuple[Optional[Tuple[sint, sint]], ...], value:float=0.0) -> Tensor:
800
+ """
801
+ Returns a tensor that pads the each axis based on input arg.
802
+ `arg` must have the same length as `self.ndim`.
803
+ For each axis, it can be `None`, which means no pad, or a tuple `(pad_before, pad_after)`.
804
+ If `value` is specified, the tensor is padded with `value` instead of `0.0`.
805
+
806
+ ```python exec="true" source="above" session="tensor" result="python"
807
+ t = Tensor.arange(6).reshape(2, 3)
808
+ print(t.numpy())
809
+ ```
810
+ ```python exec="true" source="above" session="tensor" result="python"
811
+ print(t.pad(((None, (1, 2)))).numpy())
812
+ ```
813
+ ```python exec="true" source="above" session="tensor" result="python"
814
+ print(t.pad(((None, (1, 2))), -2).numpy())
815
+ ```
816
+ """
306
817
  if all(x is None or x == (0,0) for x in arg): return self
307
- ret = mlops.Pad.apply(self, arg=(narg:=tuple(x if x is not None else (0,0) for x in arg)))
308
- return ret if 0 == value else ret + mlops.Pad.apply(Tensor.ones_like(self), arg=narg).where(0, value)
309
-
310
- # ***** movement hlops *****
311
-
312
- # - Negative indices are taken relative to the end of the sequence, so X[-2] returns the 2nd-to-last element
313
- # - A slice i:j returns the elements with indices in [i, j)
314
- # - If omitted, i and j will default to 0 and N, respectively, where N is the length of the sequence
315
- # - Negative values for i and j are taken relative to the end of the sequence
316
- # - Both i and j will be clamped to the range (-N, N], where N in the length of the sequence
317
- # - Indexing with None on a given axis will add a new dimension of size one before that axis
318
- # - Empty slices are not allowed (tensors with 0s in shape have to be supported first, for all backends).
319
- # - For a slice [i:j:k] finding the correct indices is delegated to slice.indices(len).
320
- # - Strides > 1 and < 0 are now allowed!:
321
- # - This works by applying Shrink -> [[Flip -> ] Pad -> Reshape -> Shrink] -> Reshape (ops in brackets are optional)
322
- # - Idea of stride < 0 support:
323
- # - Do the slice first, flip the axes were slice.step is negative, do slice.step -> -slice.step. Go to steps below.
324
- # - Idea of stride `s` > 1 support (Pad -> Reshape -> Shrink):
325
- # - Instead of doing [::s] on axis [dim_sz], do [:, 0] on axes [dim_sz_padded // s, s].
326
- # - So pad dim_sz with as many zeros as needed (dim_sz -> dim_sz_padded) so that reshape to [dim_sz_padded // s, s]
327
- # is possible.
328
- # - Apply Shrink to do the slice [:, 0] on axes of shapes [dim_sz_padded // s, s].
329
- # - Fancy indexing and combined indexing is supported
330
- # - Combined indexing works by letting regular slicing finish first -> computing the resulting dims w.r.t to Tensors passed in -> fancy indexing
331
- # - Any Tensors passed in __getitem__ will perform (CMPEQ with arange -> MUL with self -> SUM_REDUCE) iteratively
332
- # - The first iteration will expand the dim of self while consecutive iterations will reduce the dim
333
- # - There's a special case where a permute is needed at the end:
334
- # - if first Tensor passed in (expand dims) is not at dim 0
335
- # - and following Tensors does not follow consecutively to the end of fancy indexing's dims
336
- # TODO: boolean indices
337
- # TODO: figure out the exact acceptable types for indices, especially for internal list/tuple types
338
- # TODO: update docs
339
- def __getitem__(self, indices: Union[int, slice, Tensor, None, List, Tuple]) -> Tensor: # no ellipsis type...
818
+ ret = F.Pad.apply(self, arg=(narg:=tuple(x if x is not None else (0,0) for x in arg)))
819
+ return ret if 0 == value else ret + F.Pad.apply(Tensor.ones_like(self), arg=narg).where(0, value)
820
+
821
+ # ***** movement high level ops *****
+
+ # Supported Indexing Implementations:
+ # 1. Int indexing (no copy)
+ # - for all dims where there's int, shrink -> reshape
+ # - negative indices are taken relative to the end of the sequence, so X[-2] returns the 2nd-to-last element
+ # - X = Tensor.rand(4,5,9); X[2,-2] shrinks the Tensor to X.shrink(((2, 3), (3, 4), (0, 9))) -> X.shape=(1,1,9)
+ # - Then we reshape (collapse) the int dim away such that for X: (1,1,9) -> (9,)
+ # 2. Slice indexing (no copy)
+ # - for all dims where slice is start:end:stride, shrink -> Optional[flip] -> pad -> reshape -> shrink
+ # - first shrink the Tensor to X.shrink(((start, end),))
+ # - then we apply stride through Optional[flip] -> pad -> reshape -> shrink
+ # - flip where dim value is negative
+ # - pad 0's on dims such that reshaping [dim_size_padded] -> [dim_size_padded // stride, stride] is possible
+ # - shrink [dim_size_padded // stride, stride] -> [dim_size_padded // stride, 1]
+ # - reshape [dim_size_padded // stride, 1] -> [dim_size_padded // stride] and now you have your stride
+ # 3. None indexing (no copy)
+ # - reshape (inject) a dim at the dim where there's None
+ # 4. Tensor indexing (copy)
+ # - use Tensor.arange == tensor_index to create masks for dims with Tensors (adds a dim for each mask)
+ # - combine masks together with mul
+ # - apply mask to self by mask * self
+ # - sum reduce away the extra dims added from creating masks
+ # Tiny Things:
+ # 1. Supported indices: Union[int, slice, Tensor, None, List, Tuple, Ellipsis]
+ # - for any list, List[Union[List, Tuple, int]], must have homogeneous shape
+ # - for any tuple, Tuple[Union[List, Tuple, int]], must have homogeneous shape
+ # 2. Bool indexing is not supported
+ # 3. Out of bounds Tensor indexing results in 0
+ # - e.g: Tensor([1, 2, 3])[Tensor([4, 3, 2])] -> [0, 0, 3] index 4 and 3 are OOB
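The stride handling described in point 2 above, sketched in NumPy so the shapes are visible (standalone, not tinygrad code):

```python
import numpy as np

x, stride = np.arange(7), 3
padded = np.pad(x, (0, -len(x) % stride))  # pad zeros until the length divides evenly: (9,)
rows = padded.reshape(-1, stride)          # (3, 3): each row starts at a multiple of stride
strided = rows[:, 0]                       # "shrink" to column 0 -> [0, 3, 6]
assert (strided == x[::stride]).all()
```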
851
+ def __getitem__(self, indices) -> Tensor:
340
852
  # 1. indices normalization and validation
341
853
  # treat internal tuples and lists as Tensors and standardize indices to list type
342
- if isinstance(indices, (tuple, list)):
343
- # special case <indices: List[int]>, a lil ugly
344
- if isinstance(indices, list) and all_int(indices): indices = [Tensor(indices, requires_grad=False, device=self.device)]
345
- else: indices = [Tensor(list(i), requires_grad=False, device=self.device) if isinstance(i, (tuple, list)) else i for i in indices]
854
+ if isinstance(indices, list) and all_int(indices): indices = [Tensor(indices, self.device, requires_grad=False)]
855
+ elif isinstance(indices, (tuple, list)):
856
+ indices = [Tensor(list(i), self.device, requires_grad=False) if isinstance(i, (tuple, list)) else i for i in indices]
346
857
  else: indices = [indices]
347
858
 
859
+ # turn scalar Tensors into const val for int indexing if possible
860
+ indices = [self._to_const_val(i) if isinstance(i, Tensor) and i.shape == () else i for i in indices]
861
+ # move Tensor indices to the same device as self
862
+ indices = [i.to(self.device) if isinstance(i, Tensor) else i for i in indices]
863
+
348
864
  # filter ellipsis and fill with slice(None) or fill rest of indices with slice(None)
349
865
  ellipsis_idx = [dim for dim, i in enumerate(indices) if i is Ellipsis]
350
866
  fill_idx = ellipsis_idx[0] if ellipsis_idx else len(indices)
351
- num_slices = len(indices) - len(ellipsis_idx) - sum(1 for i in indices if i is None)
352
- indices[fill_idx:fill_idx+1] = [slice(None)] * (len(self.shape) - num_slices)
867
+ num_indices = len(indices) - len(ellipsis_idx) - sum(1 for i in indices if i is None)
868
+ indices[fill_idx:fill_idx+1] = [slice(None)] * (len(self.shape) - num_indices)
353
869
 
354
870
  # use Dict[type, List[dimension]] to track elements in indices
355
871
  type_dim: DefaultDict[Union[type, None], List[int]] = defaultdict(list)
@@ -359,212 +875,640 @@ class Tensor:
359
875
  indices_filtered = [v for v in indices if v is not None]
360
876
  for dim,i in enumerate(indices_filtered): type_dim[type(i)].append(dim)
361
877
 
362
- # validation! raise Errors
363
- if slice in type_dim and self.ndim == 0: raise IndexError("slice cannot be applied to a 0-dim tensor.")
364
- if len(ellipsis_idx) > 1: raise IndexError("an index can only have a single ellipsis ('...')")
365
- if float in type_dim: raise IndexError("float type is not valid index")
366
- if any(isinstance(i, slice) and i.step == 0 for i in indices): raise ValueError('slice step cannot be 0')
367
- if num_slices > len(self.shape): raise IndexError(f"too many indices for tensor of dimension {len(self.shape)}")
878
+ for index_type in type_dim:
879
+ if index_type not in [None, int, slice, Tensor]: raise IndexError(f"{index_type=} not supported")
880
+ if len(ellipsis_idx) > 1: raise IndexError("indices can only have a single ellipsis ('...')")
881
+ if num_indices > self.ndim: raise IndexError(f"too many {num_indices=} for {self.ndim=}")
368
882
 
369
- # 2. basic indexing (no copy)
883
+ # 2. basic indexing, uses only movement ops (no copy)
370
884
  # currently indices_filtered: Tuple[Union[slice, int, Tensor], ...]
371
885
  # turn indices in indices_filtered to Tuple[shrink_arg, strides]
372
886
  for dim in type_dim[int]:
373
887
  if (index := indices_filtered[dim]) >= (size := self.shape[dim]) or index < -size:
374
- raise IndexError(f"{index=} is out of bounds for dimension {dim} with {size=}")
888
+ raise IndexError(f"{index=} is out of bounds on {dim=} with {size=}")
375
889
  indices_filtered[dim] = ((index, index+1), 1) if index >= 0 else ((size+index, size+index+1), 1)
376
890
  for dim in type_dim[slice]:
377
- s, e, st = indices_filtered[dim].indices(self.shape[dim])
378
- indices_filtered[dim] = ((0, 0) if (st > 0 and e < s) or (st <= 0 and e > s) else (s, e) if st > 0 else (e+1, s+1), st)
379
- for dim in type_dim[Tensor]: indices_filtered[dim] = ((0, self.shape[dim]), 1)
891
+ if (index := indices_filtered[dim]).step == 0: raise ValueError(f"{index=} on {dim=} cannot have 0 as step")
892
+ s, e, st = index.indices(self.shape[dim])
893
+ indices_filtered[dim] = ((0, 0) if (st * (e - s)) < 0 else (s, e) if st > 0 else (e+1, s+1), st)
894
+ # record tensors and skip all Tensor dims for basic indexing
895
+ tensor_index: List[Tensor] = []
896
+ for dim in type_dim[Tensor]:
897
+ tensor_index.append(index := indices_filtered[dim])
898
+ if not dtypes.is_int(index.dtype): raise IndexError(f"{index.dtype=} on {dim=} is not supported, only int tensor indexing is supported")
899
+ indices_filtered[dim] = ((0, self.shape[dim]), 1)
380
900
 
381
901
  new_slice, strides = ((),()) if not indices_filtered else zip(*indices_filtered)
382
- ret = self.shrink(new_slice).flip(axis=[i for i, s in enumerate(strides) if s < 0])
383
- # add strides by pad -> reshape -> shrink
902
+ ret = self.shrink(new_slice).flip(tuple(i for i, s in enumerate(strides) if s < 0))
384
903
  if any(abs(s) != 1 for s in strides):
385
904
  strides = tuple(abs(s) for s in strides)
386
905
  ret = ret.pad(tuple((0, round_up(sh, s) - sh) for s, sh in zip(strides, ret.shape)))
387
- ret = ret.reshape(flatten([sh // s, s] for s, sh in zip(strides, ret.shape)))
906
+ ret = ret.reshape(tuple(flatten((sh // s, s) for s, sh in zip(strides, ret.shape))))
388
907
  ret = ret.shrink(tuple(flatten(((0, sh), (0, 1)) for sh in ret.shape[::2]))).reshape(ret.shape[::2])
389
908
 
390
909
  # inject 1 for dim where it's None and collapse dim for int
391
910
  new_shape = list(ret.shape)
392
911
  for dim in type_dim[None]: new_shape.insert(dim, 1)
393
- for dim in (dims_collapsed := [dim + sum(1 for d in type_dim[None] if dim >= d) for dim in reversed(type_dim[int])]): new_shape.pop(dim)
394
- assert all_int(new_shape), f"does not support symbolic shape {new_shape}"
912
+ for dim in (dims_collapsed := tuple(dim + sum(1 for d in type_dim[None] if dim >= d) for dim in reversed(type_dim[int]))): new_shape.pop(dim)
395
913
 
396
914
  ret = ret.reshape(new_shape)
915
+ assert all_int(ret.shape), f"does not support symbolic shape {ret.shape}"
397
916
 
398
917
  # 3. advanced indexing (copy)
399
918
  if type_dim[Tensor]:
919
+ # calculate dim of current ret by subtracting dims collapsed and adding dims injected up until tensor_dim
920
+ def calc_dim(tensor_dim:int) -> int:
921
+ return tensor_dim - sum(1 for d in dims_collapsed if tensor_dim >= d) + sum(1 for d in type_dim[None] if tensor_dim >= d)
922
+
923
+ # track tensor_dim and tensor_index using a dict
924
+ # calc_dim to get dim and use that to normalize the negative tensor indices
925
+ idx: Dict[int,Tensor] = {(dim := calc_dim(td)):(tensor<0).where(ret.shape[dim],0) + tensor for td,tensor in zip(type_dim[Tensor], tensor_index)}
400
926
 
401
- # extract tensors and tensor dimensions
402
- idx, tdim = [], []
403
- for tensor_dim in type_dim[Tensor]:
404
- dims_collapsed_, dims_injected = sum(1 for d in dims_collapsed if tensor_dim >= d), sum(1 for d in type_dim[None] if tensor_dim >= d)
405
- tdim.append(td := tensor_dim - dims_collapsed_ + dims_injected)
406
- # normalize the negative tensor indices
407
- idx.append(((t := indices[tensor_dim + dims_injected]) < 0).where(ret.shape[td], 0) + t)
408
- # TODO uint8 and bool tensor indexing
409
- if not (dtypes.is_int(t.dtype) or t.dtype == dtypes.bool): raise IndexError("tensors used as indices must be int or bool tensors")
410
-
411
- # compute sum_dim, arange, and idx
412
- max_dim = max(i.ndim for i in idx)
413
- sum_dim = [d if n==0 else d+max_dim-n for n,d in enumerate(tdim)]
414
- arange = [Tensor.arange(ret.shape[d], requires_grad=False, device=self.device).reshape(*[1]*sd, ret.shape[d], *[1]*(ret.ndim + max_dim - n - sd - 1)) for n,(sd,d) in enumerate(zip(sum_dim, tdim))] # noqa: E501
415
- first_idx = [idx[0].reshape(*[1]*tdim[0], *[1]*(1 + max_dim - idx[0].ndim), *idx[0].shape, *[1]*(ret.ndim - tdim[0] - 1))]
416
- rest_idx = [i.reshape(*[1]*tdim[0], *[1]*(max_dim - i.ndim), *i.shape, *[1]*(ret.ndim - tdim[0] - n)) for n,i in enumerate(idx[1:], 1)]
417
- reshaped_idx = first_idx + rest_idx
418
- ret = ret.reshape(*ret.shape[:sum_dim[0]+1], *[1]*max_dim, *ret.shape[sum_dim[0]+1:])
419
-
420
- # iteratively eq -> mul -> sum fancy index
421
- try:
422
- for a,i,sd in zip(arange, reshaped_idx, sum_dim): ret = (a==i).mul(ret).sum(sd)
423
- except AssertionError as exc: raise IndexError(f"cannot broadcast with index shapes {', '.join(str(i.shape) for i in idx)}") from exc
927
+ masks, first_dim, last_dim = [], min(idx.keys()), max(idx.keys())
928
+ pre_reduce_shape = ret.shape[:first_dim] + (big_shape := _broadcast_shape(*(t.shape for t in idx.values()))) + ret.shape[first_dim:]
929
+
930
+ # create masks
931
+ for dim, i in idx.items():
932
+ try: i = i.reshape(i.shape + (1,)*(ret.ndim - first_dim)).expand(pre_reduce_shape)
933
+ except ValueError as e: raise IndexError(f"cannot broadcast indices: {e}") from e
934
+ a = Tensor.arange(ret.shape[dim], device=self.device, requires_grad=False).reshape((ret.shape[dim],) + (1,)*(ret.ndim - dim - 1))
935
+ masks.append(i == a)
936
+
937
+ # reduce masks to 1 mask
938
+ mask: Tensor = functools.reduce(lambda x,y: x.mul(y), masks)
939
+
940
+ # inject 1's for the extra dims added in create masks
941
+ sh = ret.shape[:first_dim] + (1,) * len(big_shape) + ret.shape[first_dim:]
942
+ # sum reduce the extra dims introduced in create masks
943
+ ret = (ret.reshape(sh) * mask).sum(tuple(i + len(big_shape) for i in idx.keys()), acc_dtype=ret.dtype)
424
944
 
425
945
  # special permute case
426
- if tdim[0] != 0 and len(tdim) != 1 and tdim != list(range(tdim[0], tdim[-1]+1)):
427
- ret_dims = list(range(ret.ndim))
428
- ret = ret.permute(ret_dims[tdim[0]:tdim[0]+max_dim] + ret_dims[:tdim[0]] + ret_dims[tdim[0]+max_dim:])
946
+ if first_dim != 0 and len(idx) != 1 and tuple(idx.keys()) != tuple(range(first_dim, last_dim+1)):
947
+ ret = ret.permute(*range(first_dim, first_dim+len(big_shape)), *range(0, first_dim), *range(first_dim+len(big_shape), ret.ndim))
429
948
  return ret
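The arange-mask trick used in the advanced-indexing branch above can be written out directly for the 1-D case; an illustrative sketch (not part of the diff), assuming only the public API:

```python
from tinygrad import Tensor

x = Tensor([10, 20, 30])
idx = Tensor([2, 0, 2])
# one mask row per index: compare each index against an arange over the indexed dim
mask = idx.reshape(3, 1) == Tensor.arange(3).reshape(1, 3)  # (3, 3) of one-hot rows
# multiply and sum-reduce the indexed dim away, as in step 4 of the indexing outline
print((mask * x).sum(axis=1).numpy())                       # [30 10 30]
```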
430
949
 
431
- def __setitem__(self,indices,v): return self.__getitem__(indices).assign(v)
432
-
433
- # NOTE: using slice is discouraged and things should migrate to pad and shrink
434
- def slice(self, arg:Sequence[Optional[Tuple[int, sint]]], value:float=0) -> Tensor:
950
+ def __setitem__(self, indices, v:Union[Tensor, ConstType]) -> None:
951
+ if isinstance(self.device, str) and self.device.startswith("DISK"):
952
+ self.__getitem__(indices).assign(v)
953
+ return
954
+ # NOTE: check that setitem target is valid first
955
+ assert all(lb.st.contiguous for lb in self.lazydata.lbs), "setitem target needs to be contiguous"
956
+ if not isinstance(v, (Tensor, float, int, bool)): raise TypeError(f"can't set a {type(v).__name__} to a Tensor")
957
+ if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype)
958
+ if self.requires_grad or v.requires_grad: raise NotImplementedError("setitem with requires_grad is not supported")
959
+ if isinstance(indices, (Tensor, list)) or (isinstance(indices, tuple) and any(isinstance(i, (Tensor, list)) for i in indices)):
960
+ raise NotImplementedError("Advanced indexing setitem is not currently supported")
961
+
962
+ assign_to = self.realize().__getitem__(indices)
963
+ # NOTE: contiguous to prevent const folding.
964
+ v = v.cast(assign_to.dtype)._broadcast_to(_broadcast_shape(assign_to.shape, v.shape)).contiguous()
965
+ assign_to.assign(v).realize()
966
+
967
+ # NOTE: using _slice is discouraged and things should migrate to pad and shrink
968
+ def _slice(self, arg:Sequence[Optional[Tuple[int, sint]]], value:float=0) -> Tensor:
435
969
  arg_ = tuple(a if a is not None else (0, s) for s,a in zip(self.shape, arg))
436
970
  padding = tuple((max(0, -l), max(0, r-s)) for s,(l,r) in zip(self.shape, arg_))
437
971
  return self.pad(padding, value=value).shrink(tuple((l + pl, r + pl) for (l,r),(pl,_) in zip(arg_, padding)))
438
972
 
439
- def gather(self:Tensor, idx:Tensor, dim:int) -> Tensor:
440
- assert idx.ndim == self.ndim, "self.ndim must equal idx.ndim"
441
- assert all(s >= i for s,i in zip(self.shape, idx.shape)), "all dim of idx.shape must be smaller than self.shape"
442
- if dim < 0: dim += self.ndim
443
- idx = idx.transpose(ax1=dim, ax2=0).unsqueeze(-1)
973
+ def gather(self:Tensor, dim:int, index:Tensor) -> Tensor:
974
+ """
975
+ Gathers values along an axis specified by `dim`.
976
+
977
+ ```python exec="true" source="above" session="tensor" result="python"
978
+ t = Tensor([[1, 2], [3, 4]])
979
+ print(t.numpy())
980
+ ```
981
+ ```python exec="true" source="above" session="tensor" result="python"
982
+ print(t.gather(1, Tensor([[0, 0], [1, 0]])).numpy())
983
+ ```
984
+ """
985
+ assert index.ndim == self.ndim, f"self.ndim must equal index.ndim, {self.ndim=}, {index.ndim=}"
986
+ assert all(s >= i for s,i in zip(self.shape, index.shape)), "all dim of index.shape must be smaller than self.shape"
987
+ dim = self._resolve_dim(dim)
988
+ index = index.to(self.device).transpose(0, dim).unsqueeze(-1)
444
989
  permarg = list(range(self.ndim))
445
990
  permarg = permarg[1:dim] + [permarg[0]] + permarg[dim+1:] + [permarg[dim]] if dim != 0 else permarg[1:] + [permarg[0]]
446
- return ((idx == Tensor.arange(self.shape[dim], requires_grad=False, device=self.device)) * self.permute(*permarg).shrink(tuple([*[(0,sh) for sh in idx.shape[1:-1]], (0,self.shape[dim])])).unsqueeze(0)).sum(-1).transpose(ax1=0, ax2=dim) # noqa: E501
991
+ return ((index == Tensor.arange(self.shape[dim], requires_grad=False, device=self.device)) * self.permute(*permarg).shrink(
992
+ tuple([*[(0,sh) for sh in index.shape[1:-1]], None])).unsqueeze(0)).sum(-1, acc_dtype=self.dtype).transpose(0, dim)
447
993
 
448
994
  def cat(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
449
- if dim < 0: dim += self.ndim
995
+ """
996
+ Concatenates self with other `Tensor` in `args` along an axis specified by `dim`.
997
+ All tensors must have the same shape except in the concatenating dimension.
998
+
999
+ ```python exec="true" source="above" session="tensor" result="python"
1000
+ t0, t1, t2 = Tensor([[1, 2]]), Tensor([[3, 4]]), Tensor([[5, 6]])
1001
+ print(t0.cat(t1, t2, dim=0).numpy())
1002
+ ```
1003
+ ```python exec="true" source="above" session="tensor" result="python"
1004
+ print(t0.cat(t1, t2, dim=1).numpy())
1005
+ ```
1006
+ """
1007
+ dim = self._resolve_dim(dim)
450
1008
  assert all(len(y.shape) == len(self.shape) and all(y.shape[i] == s for i,s in enumerate(self.shape) if i != dim) for y in args)
451
1009
  catargs = [self, *args]
452
- assert all(t.shape for t in catargs), "zero-dimensional tensor cannot be concatenated"
453
- shapes = [s.shape[dim] for s in catargs]
454
- shape_cumsum = [0, *itertools.accumulate(shapes)]
1010
+ cat_dims = [s.shape[dim] for s in catargs]
1011
+ cat_dim_cumsum = [0, *itertools.accumulate(cat_dims)]
455
1012
  slc:List[List[Optional[Tuple[sint, sint]]]] = [[None for _ in self.shape] for _ in catargs]
456
- for shp,k,s in zip(shapes, shape_cumsum[:-1], slc): s[dim] = (k, shape_cumsum[-1] - k - shp)
457
- return reduce(Tensor.__add__, [arg.pad(tuple(s)) for arg,s in zip(catargs, slc)])
458
-
459
- @staticmethod
460
- def stack(tensors:Sequence[Tensor], dim:int=0) -> Tensor:
461
- unsqueezed_tensors = [tensor.unsqueeze(dim) for tensor in tensors]
1013
+ for d,k,s in zip(cat_dims, cat_dim_cumsum[:-1], slc): s[dim] = (k, cat_dim_cumsum[-1] - k - d)
1014
+ return functools.reduce(Tensor.__add__, [arg.pad(tuple(s)) for arg,s in zip(catargs, slc)])
1015
+
1016
+ def stack(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
1017
+ """
1018
+ Concatenates self with the other `Tensor`s in `args` along a new dimension specified by `dim`.
1019
+
1020
+ ```python exec="true" source="above" session="tensor" result="python"
1021
+ t0, t1, t2 = Tensor([1, 2]), Tensor([3, 4]), Tensor([5, 6])
1022
+ print(t0.stack(t1, t2, dim=0).numpy())
1023
+ ```
1024
+ ```python exec="true" source="above" session="tensor" result="python"
1025
+ print(t0.stack(t1, t2, dim=1).numpy())
1026
+ ```
1027
+ """
462
1028
  # checks for shapes and number of dimensions delegated to cat
463
- return unsqueezed_tensors[0].cat(*unsqueezed_tensors[1:], dim=dim)
464
-
465
- def repeat(self, repeats:Sequence[int]) -> Tensor:
1029
+ return self.unsqueeze(dim).cat(*[t.unsqueeze(dim) for t in args], dim=dim)
1030
+
1031
+ def repeat(self, repeats, *args) -> Tensor:
1032
+ """
1033
+ Repeats the tensor along each dimension the number of times given by `repeats`.
1034
+ `repeats` can be passed as a tuple or as separate arguments.
1035
+
1036
+ ```python exec="true" source="above" session="tensor" result="python"
1037
+ t = Tensor([1, 2, 3])
1038
+ print(t.repeat(4, 2).numpy())
1039
+ ```
1040
+ ```python exec="true" source="above" session="tensor" result="python"
1041
+ print(t.repeat(4, 2, 1).shape)
1042
+ ```
1043
+ """
1044
+ repeats = argfix(repeats, *args)
466
1045
  base_shape = (1,) * (len(repeats) - self.ndim) + self.shape
467
1046
  new_shape = [x for b in base_shape for x in [1, b]]
468
1047
  expand_shape = [x for rs in zip(repeats, base_shape) for x in rs]
469
1048
  final_shape = [r*s for r,s in zip(repeats, base_shape)]
470
1049
  return self.reshape(new_shape).expand(expand_shape).reshape(final_shape)
471
1050
 
1051
+ def _resolve_dim(self, dim:int, *, outer:bool=False) -> int:
1052
+ if not -max(1, self.ndim+outer) <= dim < max(1, self.ndim+outer):
1053
+ raise IndexError(f"{dim=} out of range {[-max(1, self.ndim+outer), max(1, self.ndim+outer)-1]}")
1054
+ return dim + self.ndim+outer if dim < 0 else dim
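A quick illustration of what the new `_resolve_dim` helper accepts, inferred from the two lines above (`_resolve_dim` is internal, so this is illustrative only):

```python
from tinygrad import Tensor

t = Tensor.ones(2, 3, 4)               # ndim = 3
print(t._resolve_dim(-1))              # 2: negative dims count from the end
print(t._resolve_dim(-1, outer=True))  # 3: outer=True also allows the one-past-the-end slot (used by unsqueeze)
# t._resolve_dim(3) raises IndexError: dim=3 out of range [-3, 2]
```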
1055
+
472
1056
  def split(self, sizes:Union[int, List[int]], dim:int=0) -> Tuple[Tensor, ...]:
1057
+ """
1058
+ Splits the tensor into chunks along the dimension specified by `dim`.
1059
+ If `sizes` is an integer, it splits into equally sized chunks if possible, otherwise the last chunk will be smaller.
1060
+ If `sizes` is a list, it splits into `len(sizes)` chunks whose sizes along `dim` are given by `sizes`.
1061
+
1062
+ ```python exec="true" source="above" session="tensor" result="python"
1063
+ t = Tensor.arange(10).reshape(5, 2)
1064
+ print(t.numpy())
1065
+ ```
1066
+ ```python exec="true" source="above" session="tensor" result="python"
1067
+ split = t.split(2)
1068
+ print("\\n".join([repr(x.numpy()) for x in split]))
1069
+ ```
1070
+ ```python exec="true" source="above" session="tensor" result="python"
1071
+ split = t.split([1, 4])
1072
+ print("\\n".join([repr(x.numpy()) for x in split]))
1073
+ ```
1074
+ """
473
1075
  assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
474
- dim = dim + self.ndim if dim < 0 else dim
475
- if isinstance(sizes, int): return tuple(self.chunk(math.ceil(self.shape[dim]/sizes)))
1076
+ dim = self._resolve_dim(dim)
1077
+ if isinstance(sizes, int): sizes = [min(sizes, self.shape[dim]-i) for i in range(0, max(1, self.shape[dim]), max(1, sizes))]
1078
+ assert sum(sizes) == self.shape[dim], f"expect sizes to sum exactly to {self.shape[dim]}, but got {sum(sizes)}"
476
1079
  return tuple(self[sl] for sl in [tuple([slice(None)]*dim + [slice(sum(sizes[:i]), sum(sizes[:i + 1]))]) for i in range(len(sizes))])
477
1080
 
478
- def chunk(self, num:int, dim:int=0) -> List[Tensor]:
1081
+ def chunk(self, chunks:int, dim:int=0) -> List[Tensor]:
1082
+ """
1083
+ Splits the tensor into `chunks` number of chunks along the dimension `dim`.
1084
+ If the tensor size along `dim` is not divisible by `chunks`, all returned chunks will be the same size except the last one.
1085
+ The function may return fewer than the specified number of chunks.
1086
+
1087
+ ```python exec="true" source="above" session="tensor" result="python"
1088
+ chunked = Tensor.arange(11).chunk(6)
1089
+ print("\\n".join([repr(x.numpy()) for x in chunked]))
1090
+ ```
1091
+ ```python exec="true" source="above" session="tensor" result="python"
1092
+ chunked = Tensor.arange(12).chunk(6)
1093
+ print("\\n".join([repr(x.numpy()) for x in chunked]))
1094
+ ```
1095
+ ```python exec="true" source="above" session="tensor" result="python"
1096
+ chunked = Tensor.arange(13).chunk(6)
1097
+ print("\\n".join([repr(x.numpy()) for x in chunked]))
1098
+ ```
1099
+ """
479
1100
  assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
480
- dim, step = dim + self.ndim if dim < 0 else dim, math.ceil(self.shape[dim]/num)
481
- slice_params = [[slice(None)]*dim + [slice(k, k + step)] for k in range(0, self.shape[dim], step)]
482
- return [self[tuple(sl)] for sl in slice_params]
1101
+ assert chunks > 0, f"expect chunks to be greater than 0, got: {chunks}"
1102
+ dim = self._resolve_dim(dim)
1103
+ return list(self.split(math.ceil(self.shape[dim]/chunks) if self.shape[dim] else [0]*chunks, dim=dim))
483
1104
 
484
1105
  def squeeze(self, dim:Optional[int]=None) -> Tensor:
1106
+ """
1107
+ Returns a tensor with the specified dimension of size 1 removed.
1108
+ If `dim` is not specified, all dimensions with size 1 are removed.
1109
+
1110
+ ```python exec="true" source="above" session="tensor" result="python"
1111
+ t = Tensor.zeros(2, 1, 2, 1, 2)
1112
+ print(t.squeeze().shape)
1113
+ ```
1114
+ ```python exec="true" source="above" session="tensor" result="python"
1115
+ print(t.squeeze(0).shape)
1116
+ ```
1117
+ ```python exec="true" source="above" session="tensor" result="python"
1118
+ print(t.squeeze(1).shape)
1119
+ ```
1120
+ """
485
1121
  if dim is None: return self.reshape(tuple(dim for dim in self.shape if dim != 1))
486
- if self.ndim == 0 and dim in [-1, 0]: return self # this is to match torch behavior
487
- if not -self.ndim <= dim <= self.ndim-1: raise IndexError(f"{dim=} out of range {[-self.ndim, self.ndim-1] if self.ndim else [-1, 0]}")
488
- if dim < 0: dim += self.ndim
489
- return self if self.shape[dim] != 1 else self.reshape(self.shape[:dim] + self.shape[dim+1:])
1122
+ dim = self._resolve_dim(dim)
1123
+ return self if not self.ndim or self.shape[dim] != 1 else self.reshape(self.shape[:dim] + self.shape[dim+1:])
490
1124
 
491
1125
  def unsqueeze(self, dim:int) -> Tensor:
492
- if dim < 0: dim = self.ndim + dim + 1
1126
+ """
1127
+ Returns a tensor with a new dimension of size 1 inserted at the specified `dim`.
1128
+
1129
+ ```python exec="true" source="above" session="tensor" result="python"
1130
+ t = Tensor([1, 2, 3, 4])
1131
+ print(t.unsqueeze(0).numpy())
1132
+ ```
1133
+ ```python exec="true" source="above" session="tensor" result="python"
1134
+ print(t.unsqueeze(1).numpy())
1135
+ ```
1136
+ """
1137
+ dim = self._resolve_dim(dim, outer=True)
493
1138
  return self.reshape(self.shape[:dim] + (1,) + self.shape[dim:])
494
1139
 
495
- # (padding_left, padding_right, padding_top, padding_bottom)
496
- def pad2d(self, padding:Sequence[int], value:float=0) -> Tensor:
1140
+ def pad2d(self, padding:Sequence[int], value:float=0.0) -> Tensor:
1141
+ """
1142
+ Returns a tensor that pads the last two axes specified by `padding` (padding_left, padding_right, padding_top, padding_bottom).
1143
+ If `value` is specified, the tensor is padded with `value` instead of `0.0`.
1144
+
1145
+ ```python exec="true" source="above" session="tensor" result="python"
1146
+ t = Tensor.arange(9).reshape(1, 1, 3, 3)
1147
+ print(t.numpy())
1148
+ ```
1149
+ ```python exec="true" source="above" session="tensor" result="python"
1150
+ print(t.pad2d((1, 1, 2, 0), value=-float("inf")).numpy())
1151
+ ```
1152
+ """
497
1153
  slc = [(-p0, s+p1) for p0,p1,s in zip(padding[::2], padding[1::2], self.shape[::-1])][::-1]
498
- return self.slice([(0,s) for s in self.shape[:-(len(padding)//2)]] + slc, value=value)
1154
+ return self._slice([(0,s) for s in self.shape[:-(len(padding)//2)]] + slc, value=value)
499
1155
 
500
1156
  @property
501
- def T(self) -> Tensor: return self.transpose()
502
- def transpose(self, ax1=1, ax2=0) -> Tensor:
1157
+ def T(self) -> Tensor:
1158
+ """`.T` is an alias for `.transpose()`."""
1159
+ return self.transpose()
1160
+
1161
+ def transpose(self, dim0=1, dim1=0) -> Tensor:
1162
+ """
1163
+ Returns a tensor that is a transposed version of the original tensor.
1164
+ The given dimensions `dim0` and `dim1` are swapped.
1165
+
1166
+ ```python exec="true" source="above" session="tensor" result="python"
1167
+ t = Tensor.arange(6).reshape(2, 3)
1168
+ print(t.numpy())
1169
+ ```
1170
+ ```python exec="true" source="above" session="tensor" result="python"
1171
+ print(t.transpose(0, 1).numpy())
1172
+ ```
1173
+ """
503
1174
  order = list(range(self.ndim))
504
- order[ax1], order[ax2] = order[ax2], order[ax1]
1175
+ order[dim0], order[dim1] = order[dim1], order[dim0]
505
1176
  return self.permute(order)
1177
+
506
1178
  def flatten(self, start_dim=0, end_dim=-1):
507
- start_dim, end_dim = start_dim + self.ndim if start_dim < 0 else start_dim, end_dim + self.ndim if end_dim < 0 else end_dim
1179
+ """
1180
+ Flattens the tensor by reshaping it into a one-dimensional tensor.
1181
+ If `start_dim` or `end_dim` are passed, only dimensions starting with `start_dim` and ending with `end_dim` are flattened.
1182
+
1183
+ ```python exec="true" source="above" session="tensor" result="python"
1184
+ t = Tensor.arange(8).reshape(2, 2, 2)
1185
+ print(t.flatten().numpy())
1186
+ ```
1187
+ ```python exec="true" source="above" session="tensor" result="python"
1188
+ print(t.flatten(start_dim=1).numpy())
1189
+ ```
1190
+ """
1191
+ start_dim, end_dim = self._resolve_dim(start_dim), self._resolve_dim(end_dim)
508
1192
  return self.reshape(self.shape[:start_dim] + (prod(self.shape[start_dim:end_dim+1]), ) + self.shape[end_dim+1:])
1193
+
509
1194
  def unflatten(self, dim:int, sizes:Tuple[int,...]):
510
- if dim < 0: dim += self.ndim
1195
+ """
1196
+ Expands dimension `dim` of the tensor over multiple dimensions specified by `sizes`.
1197
+
1198
+ ```python exec="true" source="above" session="tensor" result="python"
1199
+ print(Tensor.ones(3, 4, 1).unflatten(1, (2, 2)).shape)
1200
+ ```
1201
+ ```python exec="true" source="above" session="tensor" result="python"
1202
+ print(Tensor.ones(3, 4, 1).unflatten(1, (-1, 2)).shape)
1203
+ ```
1204
+ ```python exec="true" source="above" session="tensor" result="python"
1205
+ print(Tensor.ones(5, 12, 3).unflatten(-2, (2, 2, 3, 1, 1)).shape)
1206
+ ```
1207
+ """
1208
+ dim = self._resolve_dim(dim)
511
1209
  return self.reshape(self.shape[:dim] + sizes + self.shape[dim+1:])
512
1210
 
513
1211
  # ***** reduce ops *****
514
1212
 
515
1213
  def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Tuple[int, ...]]]=None, keepdim=False) -> Tensor:
516
- axis_: List[int] = list(range(len(self.shape))) if axis is None else ([axis] if isinstance(axis, int) else list(axis))
517
- axis_ = [x if x >= 0 else x+len(self.shape) for x in axis_]
518
- shape = tuple(s for i,s in enumerate(self.shape) if i not in axis_)
519
- if 0 in self.shape and 0 not in shape:
520
- return Tensor.full(tuple(1 if s == 0 else s for s in self.shape) if keepdim else shape, {mlops.Sum: 0.0, mlops.Max: -float("inf")}[fxn])
521
- ret = fxn.apply(self, new_shape=tuple([1 if i in axis_ else s for i,s in enumerate(self.shape)]))
522
- return ret if keepdim else ret.reshape(shape=shape)
523
-
524
- def sum(self, axis=None, keepdim=False):
525
- acc_dtype = least_upper_dtype(self.dtype, dtypes.uint) if dtypes.is_unsigned(self.dtype) else \
526
- least_upper_dtype(self.dtype, dtypes.int) if (dtypes.is_int(self.dtype) or self.dtype==dtypes.bool) else \
527
- least_upper_dtype(self.dtype, dtypes.float)
528
- # cast back to float16 or bfloat16 to match torch / jax behavior, but we use float for acc
529
- output_dtype = self.dtype if self.dtype in (dtypes.float16, dtypes.bfloat16) else acc_dtype
530
- return self.cast(acc_dtype)._reduce(mlops.Sum, axis, keepdim).cast(output_dtype)
531
-
532
- def max(self, axis=None, keepdim=False): return self._reduce(mlops.Max, axis, keepdim)
533
- def min(self, axis=None, keepdim=False): return -((-self).max(axis=axis, keepdim=keepdim))
1214
+ if self.ndim == 0:
1215
+ if axis is not None and axis not in [-1, 0]: raise IndexError(f"{axis=} out of range of [-1, 0]")
1216
+ axis = None
1217
+ axis_: Tuple[int, ...] = tuple(range(len(self.shape))) if axis is None else ((axis,) if isinstance(axis, int) else tuple(axis))
1218
+ axis_ = tuple(self._resolve_dim(x) for x in axis_)
1219
+ ret = fxn.apply(self, axis=axis_)
1220
+ return ret if keepdim else ret.reshape(tuple(s for i,s in enumerate(self.shape) if i not in axis_))
1221
+
1222
+ def sum(self, axis=None, keepdim=False, acc_dtype:Optional[DType]=None):
1223
+ """
1224
+ Sums the elements of the tensor along the specified axis or axes.
1225
+
1226
+ You can pass in `axis` and `keepdim` keyword arguments to control the axis along
1227
+ which the sum is computed and whether the reduced dimensions are retained.
1228
+
1229
+ You can pass in `acc_dtype` keyword argument to control the data type of the accumulation.
1230
+ If not specified, the accumulation data type is chosen based on the input tensor's data type.
1231
+
1232
+ ```python exec="true" source="above" session="tensor" result="python"
1233
+ t = Tensor.arange(6).reshape(2, 3)
1234
+ print(t.numpy())
1235
+ ```
1236
+ ```python exec="true" source="above" session="tensor" result="python"
1237
+ print(t.sum().numpy())
1238
+ ```
1239
+ ```python exec="true" source="above" session="tensor" result="python"
1240
+ print(t.sum(axis=0).numpy())
1241
+ ```
1242
+ ```python exec="true" source="above" session="tensor" result="python"
1243
+ print(t.sum(axis=1).numpy())
1244
+ ```
1245
+ """
1246
+ ret = self.cast(acc_dtype or sum_acc_dtype(self.dtype))._reduce(F.Sum, axis, keepdim)
1247
+ return ret.cast(self.dtype) if self.dtype in {dtypes.float16, dtypes.bfloat16} else ret
1248
+ def max(self, axis=None, keepdim=False):
1249
+ """
1250
+ Returns the maximum value of the tensor along the specified axis or axes.
1251
+
1252
+ You can pass in `axis` and `keepdim` keyword arguments to control the axis along
1253
+ which the maximum is computed and whether the reduced dimensions are retained.
1254
+
1255
+ ```python exec="true" source="above" session="tensor" result="python"
1256
+ t = Tensor([[1, 0, 2], [5, 4, 3]])
1257
+ print(t.numpy())
1258
+ ```
1259
+ ```python exec="true" source="above" session="tensor" result="python"
1260
+ print(t.max().numpy())
1261
+ ```
1262
+ ```python exec="true" source="above" session="tensor" result="python"
1263
+ print(t.max(axis=0).numpy())
1264
+ ```
1265
+ ```python exec="true" source="above" session="tensor" result="python"
1266
+ print(t.max(axis=1, keepdim=True).numpy())
1267
+ ```
1268
+ """
1269
+ return self._reduce(F.Max, axis, keepdim)
1270
+ def min(self, axis=None, keepdim=False):
1271
+ """
1272
+ Returns the minimum value of the tensor along the specified axis or axes.
1273
+
1274
+ You can pass in `axis` and `keepdim` keyword arguments to control the axis along
1275
+ which the minimum is computed and whether the reduced dimensions are retained.
1276
+
1277
+ ```python exec="true" source="above" session="tensor" result="python"
1278
+ t = Tensor([[1, 0, 2], [5, 4, 3]])
1279
+ print(t.numpy())
1280
+ ```
1281
+ ```python exec="true" source="above" session="tensor" result="python"
1282
+ print(t.min().numpy())
1283
+ ```
1284
+ ```python exec="true" source="above" session="tensor" result="python"
1285
+ print(t.min(axis=0).numpy())
1286
+ ```
1287
+ ```python exec="true" source="above" session="tensor" result="python"
1288
+ print(t.min(axis=1, keepdim=True).numpy())
1289
+ ```
1290
+ """
1291
+ return -((-self).max(axis=axis, keepdim=keepdim))
534
1292
 
535
1293
  def mean(self, axis=None, keepdim=False):
536
- assert all_int(self.shape), "does not support symbolic shape"
537
- out = self.sum(axis=axis, keepdim=keepdim)
538
- return out.mul(prod(out.shape)/prod(self.shape)) if 0 not in self.shape else out
539
- def std(self, axis=None, keepdim=False, correction=1):
1294
+ """
1295
+ Returns the mean value of the tensor along the specified axis or axes.
1296
+
1297
+ You can pass in `axis` and `keepdim` keyword arguments to control the axis along
1298
+ which the mean is computed and whether the reduced dimensions are retained.
1299
+
1300
+ ```python exec="true" source="above" session="tensor" result="python"
1301
+ Tensor.manual_seed(42)
1302
+ t = Tensor.normal(2, 3, mean=2.5, std=0.5)
1303
+ print(t.numpy())
1304
+ ```
1305
+ ```python exec="true" source="above" session="tensor" result="python"
1306
+ print(t.mean().numpy())
1307
+ ```
1308
+ ```python exec="true" source="above" session="tensor" result="python"
1309
+ print(t.mean(axis=0).numpy())
1310
+ ```
1311
+ ```python exec="true" source="above" session="tensor" result="python"
1312
+ print(t.mean(axis=1).numpy())
1313
+ ```
1314
+ """
1315
+ output_dtype = self.dtype if dtypes.is_float(self.dtype) else dtypes.float32
1316
+ numerator = self.cast(sum_acc_dtype(self.dtype)).sum(axis=axis, keepdim=keepdim)
1317
+ return numerator.div(prod([si for si, so in zip(self.shape, self.sum(axis=axis, keepdim=True).shape) if si != so])).cast(output_dtype)
1318
+
1319
+ def var(self, axis=None, keepdim=False, correction=1):
1320
+ """
1321
+ Returns the variance of the tensor along the specified axis or axes.
1322
+
1323
+ You can pass in `axis`, `keepdim`, and `correction` keyword arguments to control the axis along
1324
+ which the variance is computed, whether the reduced dimensions are retained, and the Bessel's correction applied.
1325
+
1326
+ ```python exec="true" source="above" session="tensor" result="python"
1327
+ Tensor.manual_seed(42)
1328
+ t = Tensor.normal(2, 3, mean=2.5, std=0.5)
1329
+ print(t.numpy())
1330
+ ```
1331
+ ```python exec="true" source="above" session="tensor" result="python"
1332
+ print(t.var().numpy())
1333
+ ```
1334
+ ```python exec="true" source="above" session="tensor" result="python"
1335
+ print(t.var(axis=0).numpy())
1336
+ ```
1337
+ ```python exec="true" source="above" session="tensor" result="python"
1338
+ print(t.var(axis=1).numpy())
1339
+ ```
1340
+ """
540
1341
  assert all_int(self.shape), "does not support symbolic shape"
541
1342
  square_sum = ((self - self.mean(axis=axis, keepdim=True)).square()).sum(axis=axis, keepdim=keepdim)
542
- return square_sum.div(prod(self.shape)/prod(square_sum.shape)-correction).sqrt()
1343
+ return square_sum.div(max(0, prod(self.shape)/prod(square_sum.shape)-correction))
1344
+
1345
+ def std(self, axis=None, keepdim=False, correction=1):
1346
+ """
1347
+ Returns the standard deviation of the tensor along the specified axis or axes.
1348
+
1349
+ You can pass in `axis`, `keepdim`, and `correction` keyword arguments to control the axis along
1350
+ which the standard deviation is computed, whether the reduced dimensions are retained, and the Bessel's correction applied.
1351
+
1352
+ ```python exec="true" source="above" session="tensor" result="python"
1353
+ Tensor.manual_seed(42)
1354
+ t = Tensor.normal(2, 3, mean=2.5, std=0.5)
1355
+ print(t.numpy())
1356
+ ```
1357
+ ```python exec="true" source="above" session="tensor" result="python"
1358
+ print(t.std().numpy())
1359
+ ```
1360
+ ```python exec="true" source="above" session="tensor" result="python"
1361
+ print(t.std(axis=0).numpy())
1362
+ ```
1363
+ ```python exec="true" source="above" session="tensor" result="python"
1364
+ print(t.std(axis=1).numpy())
1365
+ ```
1366
+ """
1367
+ return self.var(axis, keepdim, correction).sqrt()
1368
+
543
1369
  def _softmax(self, axis):
544
1370
  m = self - self.max(axis=axis, keepdim=True)
545
1371
  e = m.exp()
546
1372
  return m, e, e.sum(axis=axis, keepdim=True)
547
1373
 
548
1374
  def softmax(self, axis=-1):
1375
+ """
1376
+ Applies the softmax function to the tensor along the specified axis.
1377
+
1378
+ Rescales the elements of the tensor such that they lie in the range [0, 1] and sum to 1.
1379
+
1380
+ You can pass in the `axis` keyword argument to control the axis along which the softmax is computed.
1381
+
1382
+ ```python exec="true" source="above" session="tensor" result="python"
1383
+ Tensor.manual_seed(42)
1384
+ t = Tensor.randn(2, 3)
1385
+ print(t.numpy())
1386
+ ```
1387
+ ```python exec="true" source="above" session="tensor" result="python"
1388
+ print(t.softmax().numpy())
1389
+ ```
1390
+ ```python exec="true" source="above" session="tensor" result="python"
1391
+ print(t.softmax(axis=0).numpy())
1392
+ ```
1393
+ """
549
1394
  _, e, ss = self._softmax(axis)
550
1395
  return e.div(ss)
551
1396
 
552
1397
  def log_softmax(self, axis=-1):
1398
+ """
1399
+ Applies the log-softmax function to the tensor along the specified axis.
1400
+
1401
+ The log-softmax function is a numerically stable alternative to the softmax function in log space.
1402
+
1403
+ You can pass in the `axis` keyword argument to control the axis along which the log-softmax is computed.
1404
+
1405
+ ```python exec="true" source="above" session="tensor" result="python"
1406
+ Tensor.manual_seed(42)
1407
+ t = Tensor.randn(2, 3)
1408
+ print(t.numpy())
1409
+ ```
1410
+ ```python exec="true" source="above" session="tensor" result="python"
1411
+ print(t.log_softmax().numpy())
1412
+ ```
1413
+ ```python exec="true" source="above" session="tensor" result="python"
1414
+ print(t.log_softmax(axis=0).numpy())
1415
+ ```
1416
+ """
553
1417
  m, _, ss = self._softmax(axis)
554
1418
  return m - ss.log()
555
1419
 
1420
+ def logsumexp(self, axis=None, keepdim=False):
1421
+ """
1422
+ Computes the log-sum-exp of the tensor along the specified axis or axes.
1423
+
1424
+ The log-sum-exp function is a numerically stable way to compute the logarithm of the sum of exponentials.
1425
+
1426
+ You can pass in `axis` and `keepdim` keyword arguments to control the axis along
1427
+ which the log-sum-exp is computed and whether the reduced dimensions are retained.
1428
+
1429
+ ```python exec="true" source="above" session="tensor" result="python"
1430
+ Tensor.manual_seed(42)
1431
+ t = Tensor.randn(2, 3)
1432
+ print(t.numpy())
1433
+ ```
1434
+ ```python exec="true" source="above" session="tensor" result="python"
1435
+ print(t.logsumexp().numpy())
1436
+ ```
1437
+ ```python exec="true" source="above" session="tensor" result="python"
1438
+ print(t.logsumexp(axis=0).numpy())
1439
+ ```
1440
+ ```python exec="true" source="above" session="tensor" result="python"
1441
+ print(t.logsumexp(axis=1).numpy())
1442
+ ```
1443
+ """
1444
+ m = self.max(axis=axis, keepdim=True)
1445
+ return (self - m).exp().sum(axis=axis, keepdim=keepdim).log() + m.squeeze(axis)
1446
+
556
1447
  def argmax(self, axis=None, keepdim=False):
1448
+ """
1449
+ Returns the indices of the maximum value of the tensor along the specified axis.
1450
+
1451
+ You can pass in `axis` and `keepdim` keyword arguments to control the axis along
1452
+ which the maximum is computed and whether the reduced dimensions are retained.
1453
+
1454
+ ```python exec="true" source="above" session="tensor" result="python"
1455
+ t = Tensor([[1, 0, 2], [5, 4, 3]])
1456
+ print(t.numpy())
1457
+ ```
1458
+ ```python exec="true" source="above" session="tensor" result="python"
1459
+ print(t.argmax().numpy()) # Returns the index of the maximum value in the flattened tensor.
1460
+ ```
1461
+ ```python exec="true" source="above" session="tensor" result="python"
1462
+ print(t.argmax(axis=0).numpy()) # Returns the indices of the maximum values along axis 0.
1463
+ ```
1464
+ ```python exec="true" source="above" session="tensor" result="python"
1465
+ print(t.argmax(axis=1).numpy()) # Returns the indices of the maximum values along axis 1.
1466
+ ```
1467
+ """
557
1468
  if axis is None:
558
1469
  idx = (self == self.max(axis)) * Tensor.arange(prod(self.shape)-1,-1,-1, requires_grad=False, device=self.device).reshape(self.shape)
559
- return (prod(self.shape) - idx.max() - 1).cast(dtypes.default_int)
560
- axis = axis + len(self.shape) if axis < 0 else axis
1470
+ return (prod(self.shape) - idx.max() - 1).cast(dtypes.int32)
1471
+ axis = self._resolve_dim(axis)
561
1472
  m = self == self.max(axis=axis, keepdim=True)
562
1473
  idx = m * Tensor.arange(self.shape[axis]-1,-1,-1, requires_grad=False, device=self.device).reshape(self.shape[axis], *[1]*(self.ndim-axis-1))
563
- return (self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)-1).cast(dtypes.default_int)
564
- def argmin(self, axis=None, keepdim=False): return (-self).argmax(axis=axis, keepdim=keepdim)
1474
+ return (self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)-1).cast(dtypes.int32)
1475
+
1476
+ def argmin(self, axis=None, keepdim=False):
1477
+ """
1478
+ Returns the indices of the minimum value of the tensor along the specified axis.
1479
+
1480
+ You can pass in `axis` and `keepdim` keyword arguments to control the axis along
1481
+ which the minimum is computed and whether the reduced dimensions are retained.
1482
+
1483
+ ```python exec="true" source="above" session="tensor" result="python"
1484
+ t = Tensor([[1, 0, 2], [5, 4, 3]])
1485
+ print(t.numpy())
1486
+ ```
1487
+ ```python exec="true" source="above" session="tensor" result="python"
1488
+ print(t.argmin().numpy()) # Returns the index of the minimum value in the flattened tensor.
1489
+ ```
1490
+ ```python exec="true" source="above" session="tensor" result="python"
1491
+ print(t.argmin(axis=0).numpy()) # Returns the indices of the minimum values along axis 0.
1492
+ ```
1493
+ ```python exec="true" source="above" session="tensor" result="python"
1494
+ print(t.argmin(axis=1).numpy()) # Returns the indices of the minimum values along axis 1.
1495
+ ```
1496
+ """
1497
+ return (-self).argmax(axis=axis, keepdim=keepdim)
565
1498
 
566
1499
  @staticmethod
567
- def einsum(formula:str, *raw_xs) -> Tensor:
1500
+ def einsum(formula:str, *raw_xs, acc_dtype:Optional[DType]=None) -> Tensor:
1501
+ """
1502
+ Sums the product of the elements of the input tensors according to a formula based on the Einstein summation convention.
1503
+
1504
+ See: https://pytorch.org/docs/stable/generated/torch.einsum.html
1505
+
1506
+ ```python exec="true" source="above" session="tensor" result="python"
1507
+ x = Tensor([[1, 2], [3, 4]])
1508
+ y = Tensor([[5, 6], [7, 8]])
1509
+ print(Tensor.einsum("ij,ij->", x, y).numpy())
1510
+ ```
1511
+ """
568
1512
  xs:Tuple[Tensor] = argfix(*raw_xs)
569
1513
  formula = formula.replace(" ", "")
570
1514
  inputs_str, output = formula.split("->") if "->" in formula else (formula, sorted(formula))
@@ -580,9 +1524,13 @@ class Tensor:
580
1524
  # permute to the sorted letter order, then reshape/expand to create dimensions for the missing letters
581
1525
  xs_.append(x.permute(order).reshape([val if letter in letters else 1 for letter,val in letter_val]).expand([val for _,val in letter_val]))
582
1526
 
583
- rhs_order, rhs_letters = tuple(zip(*sorted(enumerate(output), key=lambda e:e[1]))) or ([], [])
1527
+ # determine the inverse permutation to revert back to original order
1528
+ rhs_letter_order = argsort(list(output))
1529
+ rhs_order = argsort(rhs_letter_order)
1530
+
584
1531
  # sum over all axes that's not in the output, then permute to the output order
585
- return reduce(lambda a,b:a*b, xs_).sum(axis=[axis for axis,(letter,_) in enumerate(letter_val) if letter not in rhs_letters]).permute(rhs_order)
1532
+ return functools.reduce(lambda a,b:a*b, xs_) \
1533
+ .sum(axis=[axis for axis,(letter,_) in enumerate(letter_val) if letter not in output],acc_dtype=acc_dtype).permute(rhs_order)
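The double `argsort` above is the usual trick for turning the output letters into a permutation of the sorted-letter axes: `argsort(argsort(seq))[i]` is the rank of `seq[i]` in sorted order, which is exactly the axis that output position `i` should pull from after the product was permuted to sorted letter order. A standalone check in plain Python (hypothetical `output`, local `argsort` just for illustration):

```python
def argsort(seq): return sorted(range(len(seq)), key=seq.__getitem__)

output = ["k", "i", "j"]               # hypothetical einsum output letters
rhs_letter_order = argsort(output)     # [1, 2, 0]: sorted order is i, j, k
rhs_order = argsort(rhs_letter_order)  # [2, 0, 1]: rank of each output letter in sorted order
print(rhs_letter_order, rhs_order)
```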
586
1534
 
587
1535
  # ***** processing ops *****
588
1536
 
@@ -590,46 +1538,72 @@ class Tensor:
590
1538
  assert len(self.shape) >= len(k_), f"can't pool {self.shape} with {k_}"
591
1539
  assert all_int(self.shape) and all_int(k_), f"does not support symbolic {self.shape=}, {k_=}"
592
1540
  s_, d_ = make_pair(stride, len(k_)), make_pair(dilation, len(k_))
593
- assert len(k_) == len(s_) and len(k_) == len(d_), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}"
594
- slc_prefix, prefix, i_ = [(0,x) for x in self.shape[0:-len(k_)]], self.shape[0:-len(k_)], self.shape[-len(k_):]
1541
+ assert len(k_) == len(s_) == len(d_), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}"
1542
+ noop_, i_ = [None] * len(self.shape[:-len(k_)]), self.shape[-len(k_):]
595
1543
  if any(k > s for k,s in zip(k_, s_)) or any(d != 1 for d in d_):
596
1544
  o_ = [(i - d * (k-1) - 1)//s + 1 for i,d,k,s in zip(i_, d_, k_, s_)]
597
- e_ = [math.ceil(k*(i+d) / i) for k,i,d in zip(k_, i_, d_)] # expands such that we don't need padding
598
- xup = self.reshape(*prefix, *flatten((1,i) for i in i_)).expand(*prefix, *flatten((e,i) for e,i in zip(e_, i_))).reshape(*prefix, *[e*i for e,i in zip(e_, i_)]) # noqa: E501
599
- # slide by dilation
600
- xup = xup.slice(slc_prefix + [(0,k*(i+d)) for k,i,d in zip(k_, i_, d_)])
601
- xup = xup.reshape(*prefix, *flatten((k,i+d) for k,i,d in zip(k_, i_, d_)))
602
- xup = xup.slice(slc_prefix + flatten(((0,k), (0,o*s)) for k,o,s in zip(k_, o_, s_)))
603
- # handle stride, and permute to move reduce to the end
604
- xup = xup.reshape(*prefix, *flatten((k,o,s) for k,o,s in zip(k_, o_, s_)))
605
- xup = xup.slice(slc_prefix + flatten(((0,k), (0,o), (0,1)) for k,o in zip(k_, o_)))
606
- xup = xup.reshape(*prefix, *flatten((k,o) for k,o in zip(k_, o_)))
607
- return xup.permute(*range(len(prefix)), *[len(prefix)+i*2+1 for i in range(len(k_))], *[len(prefix)+i*2 for i in range(len(k_))])
1545
+ # repeats such that we don't need padding
1546
+ xup = self.repeat([1]*len(noop_) + [math.ceil(k*(i+d) / i) for k,i,d in zip(k_, i_, d_)])
1547
+ # slice by dilation
1548
+ xup = xup.shrink(tuple(noop_ + [(0,k*(i+d)) for k,i,d in zip(k_, i_, d_)])).reshape(noop_ + flatten((k,i+d) for k,i,d in zip(k_, i_, d_)))
1549
+ # handle stride
1550
+ xup = xup.shrink(noop_ + flatten(((0,k), (0,o*s)) for k,o,s in zip(k_, o_, s_))).reshape(noop_ + flatten((k,o,s) for k,o,s in zip(k_, o_, s_)))
1551
+ xup = xup.shrink(noop_ + flatten(((0,k), (0,o), (0,1)) for k,o in zip(k_, o_))).reshape(noop_ + flatten((k,o) for k,o in zip(k_, o_)))
1552
+ # permute to move reduce to the end
1553
+ return xup.permute(*range(len(noop_)), *[len(noop_)+i*2+1 for i in range(len(i_))], *[len(noop_)+i*2 for i in range(len(i_))])
608
1554
  # TODO: once the shapetracker can optimize well, remove this alternative implementation. or not if the CPU implementation doesn't use ShapeTracker
609
1555
  o_ = [(i+(s-k))//s for i,s,k in zip(i_, s_, k_)]
610
- xup = self.slice(slc_prefix + [(0,o*s) for o,s in zip(o_, s_)])
611
- xup = xup.reshape(*prefix, *flatten(((o, s) for o,s in zip(o_, s_))))
612
- xup = xup.slice(slc_prefix + flatten(((0,o), (0,k)) for o,k in zip(o_, k_)))
613
- return xup.permute(*range(len(prefix)), *[len(prefix)+i*2 for i in range(len(k_))], *[len(prefix)+i*2+1 for i in range(len(k_))])
1556
+ xup = self.pad(tuple(noop_ + [(0, max(0,o*s-i)) for i,o,s in zip(i_, o_, s_)])).shrink(tuple(noop_ + [(0,o*s) for o,s in zip(o_, s_)]))
1557
+ xup = xup.reshape(noop_ + flatten(((o,s) for o,s in zip(o_, s_))))
1558
+ xup = xup.shrink(noop_ + flatten(((0,o), (0,k)) for o,k in zip(o_, k_)))
1559
+ return xup.permute(*range(len(noop_)), *[len(noop_)+i*2 for i in range(len(i_))], *[len(noop_)+i*2+1 for i in range(len(i_))])
614
1560
 
615
1561
  # NOTE: these work for more than 2D
616
- def avg_pool2d(self, kernel_size=(2,2), stride=None, dilation=1): return self._pool(make_pair(kernel_size), stride if stride is not None else kernel_size, dilation).mean(axis=tuple(range(0-len(make_pair(kernel_size)), 0))) # noqa: E501
617
- def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1): return self._pool(make_pair(kernel_size), stride if stride is not None else kernel_size, dilation).max(axis=tuple(range(0-len(make_pair(kernel_size)), 0))) # noqa: E501
618
-
619
- def conv_transpose2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, output_padding=0) -> Tensor:
620
- HW, trailing = weight.shape[2:], list(range(3, len(weight.shape)+1))
621
- x, w = self, weight.unflatten(0, (groups, -1)).permute(0,2,1,*trailing).flip(trailing)
622
- stride = make_pair(stride, len(HW))
623
- if any(s>1 for s in stride):
624
- x = x.reshape(None, None, *flatten((k,1) for k in x.shape[2:]))
625
- x = x.pad((None, None, *flatten((None,(0,s-1)) for s in stride)))
626
- x = x.reshape(None, None, *[k*s for k,s in zip(x.shape[2::2], stride)])
627
- x = x.shrink((None, None, *[(0,k-(s-1)) for k,s in zip(x.shape[2:], stride)]))
628
- padding = flatten((((k-1)*d-p,(k-1)*d-p+op) for k,d,p,op in reversed(list(zip(HW, make_pair(dilation, len(HW)), make_pair(padding, len(HW)), make_pair(output_padding, len(HW))))))) # noqa: E501
629
- return x.conv2d(w.flatten(end_dim=1), groups=groups, bias=bias, dilation=dilation, padding=padding)
630
-
631
- wino = getenv("WINO", 0)
632
- def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0) -> Tensor:
1562
+ def avg_pool2d(self, kernel_size=(2,2), stride=None, dilation=1):
1563
+ """
1564
+ Applies average pooling over a tensor.
1565
+
1566
+ NOTE: unlike PyTorch, this implementation is not limited to only 2d pooling and instead works for any number of dimensions.
1567
+
1568
+ See: https://paperswithcode.com/method/average-pooling
1569
+
1570
+ ```python exec="true" source="above" session="tensor" result="python"
1571
+ t = Tensor.arange(25).reshape(1, 1, 5, 5)
1572
+ print(t.avg_pool2d().numpy())
1573
+ ```
1574
+ """
1575
+ return self._pool(
1576
+ make_pair(kernel_size), stride if stride is not None else kernel_size, dilation).mean(axis=tuple(range(0-len(make_pair(kernel_size)), 0)))
1577
+ def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1):
1578
+ """
1579
+ Applies max pooling over a tensor.
1580
+
1581
+ NOTE: unlike PyTorch, this implementation is not limited to only 2d pooling and instead works for any number of dimensions.
1582
+
1583
+ See: https://paperswithcode.com/method/max-pooling
1584
+
1585
+ ```python exec="true" source="above" session="tensor" result="python"
1586
+ t = Tensor.arange(25).reshape(1, 1, 5, 5)
1587
+ print(t.max_pool2d().numpy())
1588
+ ```
1589
+ """
1590
+ return self._pool(
1591
+ make_pair(kernel_size), stride if stride is not None else kernel_size, dilation).max(axis=tuple(range(0-len(make_pair(kernel_size)), 0)))
1592
+
1593
+ def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, acc_dtype:Optional[DType]=None) -> Tensor:
1594
+ """
1595
+ Applies a convolution over a tensor with a given `weight` and optional `bias`.
1596
+
1597
+ NOTE: unlike PyTorch, this implementation is not limited to only 2d convolutions and instead works for any number of dimensions.
1598
+
1599
+ See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
1600
+
1601
+ ```python exec="true" source="above" session="tensor" result="python"
1602
+ t = Tensor.arange(9).reshape(1, 1, 3, 3)
1603
+ w = Tensor.ones(1, 1, 2, 2)
1604
+ print(t.conv2d(w).numpy())
1605
+ ```
1606
+ """
633
1607
  (bs,cin_), (cout,cin), HW = self.shape[:2], weight.shape[:2], weight.shape[2:]
634
1608
  assert groups*cin == cin_ and len(self.shape) == len(weight.shape), f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})" # noqa: E501
635
1609
  if isinstance(padding, (tuple,list)): assert len(padding) == 2*len(HW) or len(padding) == len(HW), f"Expected padding of length {2*len(HW)} or {len(HW)}, but got {len(padding)} for tensor of shape {self.shape}" # noqa: E501
@@ -638,19 +1612,17 @@ class Tensor:
638
1612
  # conv2d is a pooling op (with padding)
639
1613
  x = self.pad2d(padding_)._pool(HW, stride, dilation) # (bs, groups*cin, oy, ox, H, W)
640
1614
  rcout, oyx = cout//groups, x.shape[2:-len(HW)]
641
- if not all(x == 3 for x in HW) or stride != 1 or dilation != 1 or not Tensor.wino:
1615
+ if not all(x == 3 for x in HW) or stride != 1 or dilation != 1 or not WINO:
642
1616
  # normal conv
643
1617
  x = x.reshape(bs, groups, cin, 1, *oyx, *HW).expand(bs, groups, cin, rcout, *oyx, *HW).permute(0,1,3,*[4+i for i in range(len(oyx))],2,*[4+len(oyx)+i for i in range(len(HW))]) # noqa: E501
644
1618
 
645
1619
  # conv! broadcasted to (bs, groups, rcout, *oyx, cin, *HW)
646
- ret = (x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW)).sum([-1-i for i in range(1+len(oyx))], keepdim=True).reshape(bs, cout, *oyx) # noqa: E501
1620
+ ret = (x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW)).sum([-1-i for i in range(1+len(oyx))], keepdim=True, acc_dtype=acc_dtype).reshape(bs, cout, *oyx) # noqa: E501
647
1621
  return ret if bias is None else ret.add(bias.reshape(1, -1, *[1] * len(HW)))
648
1622
 
649
- # winograd conv 3 kernel f(4x4,3x3) see: http://arxiv.org/abs/1509.09308
650
- def apply_matrix(mat, t, dim=0): return t if dim == len(HW) else Tensor.stack([apply_matrix(mat, sum(mm*t[j] for j,mm in enumerate(m) if mm), dim=dim+1) for m in mat]) # noqa: E501
651
1623
  HWI, HWO = (6,) * len(HW), (4,) * len(HW) # F(4x4,3x3) winograd tiles
652
- winograd_Bt = [[4, 0, -5, 0, 1, 0], [0, -4, -4, 1, 1, 0], [0, 4, -4, -1, 1, 0], [0, -2, -1, 2, 1, 0], [0, 2, -1, -2, 1, 0], [0, 4, 0, -5, 0, 1]]
653
1624
  winograd_G = [[1/4, 0, 0], [-1/6, -1/6, -1/6], [-1/6, 1/6, -1/6], [1/24, 1/12, 1/6], [1/24, -1/12, 1/6], [0, 0, 1]]
1625
+ winograd_Bt = [[4, 0, -5, 0, 1, 0], [0, -4, -4, 1, 1, 0], [0, 4, -4, -1, 1, 0], [0, -2, -1, 2, 1, 0], [0, 2, -1, -2, 1, 0], [0, 4, 0, -5, 0, 1]]
654
1626
  winograd_At = [[1, 1, 1, 1, 1, 0], [0, 1, -1, 2, -2, 0], [0, 1, 1, 4, 4, 0], [0, 1, -1, 8, -8, 1]] # applying At in pre-order doubles compile time
655
1627
 
656
1628
  # todo: stride == dilation
@@ -658,19 +1630,19 @@ class Tensor:
658
1630
  # (bs, cin_, tyx, HWI)
659
1631
  d = self.pad2d(sum([[padding_[i*2], padding_[i*2+1] + (-(dim + sum(padding_[i * 2:(i + 1) * 2]) - 2) % 4)] for i, dim in enumerate(self.shape[-len(HW):])], []))._pool(HWI, HWO) # noqa: E501
660
1632
  # move HW to the front: # (HWI, bs, cin_, tyx)
661
- d = d.permute(*range(len(d.shape)-len(HW),len(d.shape)), *range(len(d.shape)-len(HW))).contiguous_backward()
1633
+ d = d.permute(*range(len(d.shape)-len(HW),len(d.shape)), *range(len(d.shape)-len(HW)))
662
1634
  tyx = d.shape[-len(HWI):] # dim of tiling
663
1635
 
664
1636
  g = weight.permute(*range(len(weight.shape)-len(HW),len(weight.shape)), *range(len(weight.shape)-len(HW))) # move HW to the front
665
1637
 
666
1638
  # compute 6x6 winograd tiles: GgGt, BtdB
667
1639
  # (HWI, groups * rcout, cin) -> (HWI, bs=1, groups, rcout, cin, tyx=(1,1))
668
- gfactors = apply_matrix(winograd_G, g).contiguous().reshape(*HWI, 1, groups, rcout, cin, *([1]*len(tyx)))
1640
+ gfactors = _apply_winograd_matrix(winograd_G, g, len(HW)).reshape(*HWI, 1, groups, rcout, cin, *([1]*len(tyx)))
669
1641
  # (HWI, bs, cin_, tyx) -> (HWI, bs, groups, 1 ,cin, *tyx)
670
- dfactors = apply_matrix(winograd_Bt, d).contiguous().reshape(*HWI, bs, groups, 1, cin, *tyx)
1642
+ dfactors = _apply_winograd_matrix(winograd_Bt, d, len(HW)).reshape(*HWI, bs, groups, 1, cin, *tyx)
671
1643
 
672
1644
  # matmul; sum across cin: (HWI, bs, groups, rcout, *tyx); then HWI -> HWO: (HWO, bs, groups, rcout, *tyx)
673
- ret = apply_matrix(winograd_At, (gfactors * dfactors).sum(axis=-1-len(HW)))
1645
+ ret = _apply_winograd_matrix(winograd_At, (gfactors * dfactors).sum(axis=-1-len(HW), acc_dtype=acc_dtype), len(HW))
674
1646
 
675
1647
  # interleave tyx and HWO: (bs, groups, rcout, oy, HO, ox, WO)
676
1648
  ret = ret.permute([*range(len(HW), len(ret.shape)-len(HW)), *[i+o for i in range(len(HW)) for o in [len(ret.shape)-len(HW),0]]])
@@ -679,19 +1651,81 @@ class Tensor:
679
1651
 
680
1652
  return (ret if bias is None else ret.add(bias.reshape(1, -1, *[1 for _ in range(len(HW))]))).contiguous().contiguous_backward()
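For readers tracing the Winograd branch: with the `winograd_G`, `winograd_Bt`, and `winograd_At` constants above, each output tile follows the standard F(4x4, 3x3) form from Lavin & Gray (arXiv:1509.09308), applied one spatial dimension at a time by `_apply_winograd_matrix`, with the sum over input channels taken between the two transforms (the `(gfactors * dfactors).sum(...)` line):

$$Y = A^\top \big[ (G\,g\,G^\top) \odot (B^\top d\,B) \big] A$$

where $g$ is a 3x3 filter tile, $d$ a 6x6 input tile, and $\odot$ denotes elementwise multiplication.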
681
1653
 
682
- def dot(self, w:Tensor) -> Tensor:
1654
+ def conv_transpose2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, output_padding=0) -> Tensor:
1655
+ """
1656
+ Applies a transposed convolution over a tensor with a given `weight` and optional `bias`.
1657
+
1658
+ NOTE: unlike PyTorch, this implementation is not limited to only 2d transposed convolutions and instead works for any number of dimensions.
1659
+
1660
+ See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
1661
+
1662
+ ```python exec="true" source="above" session="tensor" result="python"
1663
+ t = Tensor.arange(9).reshape(1, 1, 3, 3)
1664
+ w = Tensor.ones(1, 1, 2, 2)
1665
+ print(t.conv_transpose2d(w).numpy())
1666
+ ```
1667
+ """
1668
+ HW, trailing = weight.shape[2:], list(range(3, len(weight.shape)+1))
1669
+ x, w = self, weight.unflatten(0, (groups, -1)).permute(0,2,1,*trailing).flip(trailing)
1670
+ stride = make_pair(stride, len(HW))
1671
+ if any(s>1 for s in stride):
1672
+ x = x.reshape(None, None, *flatten((k,1) for k in x.shape[2:]))
1673
+ x = x.pad((None, None, *flatten((None,(0,s-1)) for s in stride)))
1674
+ x = x.reshape(None, None, *[k*s for k,s in zip(x.shape[2::2], stride)])
1675
+ x = x.shrink((None, None, *[(0,k-(s-1)) for k,s in zip(x.shape[2:], stride)]))
1676
+ padding = flatten((((k-1)*d-p,(k-1)*d-p+op) for k,d,p,op in reversed(list(
1677
+ zip(HW, make_pair(dilation, len(HW)), make_pair(padding, len(HW)), make_pair(output_padding, len(HW)))))))
1678
+ return x.conv2d(w.flatten(end_dim=1), groups=groups, bias=bias, dilation=dilation, padding=padding)
1679
+
1680
+ def dot(self, w:Tensor, acc_dtype:Optional[DType]=None) -> Tensor:
1681
+ """
1682
+ Performs dot product between two tensors.
1683
+
1684
+ ```python exec="true" source="above" session="tensor" result="python"
1685
+ a = Tensor([[1, 2], [3, 4]])
1686
+ b = Tensor([[5, 6], [7, 8]])
1687
+ print(a.dot(b).numpy())
1688
+ ```
1689
+ """
683
1690
  n1, n2 = len(self.shape), len(w.shape)
684
1691
  assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D"
685
- assert self.shape[-1] == w.shape[-min(n2, 2)], f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({self.shape[-1]} != {w.shape[-min(n2, 2)]})" # noqa: E501
1692
+ assert (L:=self.shape[-1]) == (R:=w.shape[-min(n2, 2)]), f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({L} != {R})"
686
1693
  x = self.reshape(*self.shape[0:-1], *[1]*min(n1-1, n2-1, 1), self.shape[-1])
687
1694
  w = w.reshape(*w.shape[0:-2], *[1]*min(n1-1, n2-1, 1), *w.shape[-min(n2, 2):]).transpose(-1, -min(n2, 2))
688
- return (x*w).sum(-1).cast(least_upper_dtype(x.dtype, w.dtype))
1695
+ return (x*w).sum(-1, acc_dtype=acc_dtype).cast(least_upper_dtype(x.dtype, w.dtype))
1696
+
1697
+ def matmul(self, x:Tensor, reverse=False, acc_dtype:Optional[DType]=None) -> Tensor:
1698
+ """
1699
+ Performs matrix multiplication between two tensors.
1700
+
1701
+ You can pass in the `reverse` keyword argument to control the order of the matrix multiplication.
1702
+ You can pass in the optional `acc_dtype` keyword argument to control the data type of the accumulation.
689
1703
 
690
- def matmul(self, x:Tensor, reverse=False) -> Tensor: return x.dot(self) if reverse else self.dot(x)
1704
+ ```python exec="true" source="above" session="tensor" result="python"
1705
+ a = Tensor([[1, 2], [3, 4]])
1706
+ b = Tensor([[5, 6], [7, 8]])
1707
+ print(a.matmul(b).numpy())
1708
+ ```
1709
+ """
1710
+ return x.dot(self, acc_dtype=acc_dtype) if reverse else self.dot(x, acc_dtype=acc_dtype)
691
1711
 
692
1712
  def _cumsum(self, axis:int=0, _first_zero=False) -> Tensor:
693
- return self.transpose(axis,-1).pad2d((self.shape[axis]-int(not _first_zero),0))._pool((self.shape[axis],)).sum(-1).transpose(axis,-1)
1713
+ pl_sz = self.shape[axis] - int(not _first_zero and self.shape[axis] != 0)
1714
+ return self.transpose(axis,-1).pad2d((pl_sz,0))._pool((self.shape[axis] or 1,)).sum(-1).transpose(axis,-1)
694
1715
  def cumsum(self, axis:int=0) -> Tensor:
1716
+ """
1717
+ Computes the cumulative sum of the tensor along the specified axis.
1718
+
1719
+ You can pass in the `axis` keyword argument to control the axis along which the cumulative sum is computed.
1720
+
1721
+ ```python exec="true" source="above" session="tensor" result="python"
1722
+ t = Tensor.ones(2, 3)
1723
+ print(t.numpy())
1724
+ ```
1725
+ ```python exec="true" source="above" session="tensor" result="python"
1726
+ print(t.cumsum(1).numpy())
1727
+ ```
1728
+ """
695
1729
  # TODO: someday the optimizer will find this on its own
696
1730
  # for now this is a two-stage cumsum
697
1731
  SPLIT = 256
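Editor's note: the rest of `cumsum` falls outside this hunk, but the comments point at a chunked, two-stage scheme. A plain-NumPy sketch of that general idea (the kernel tinygrad actually builds differs in detail):
```python
import numpy as np

# editor's sketch: cumsum within fixed-size chunks, then add the running total
# of all previous chunks
def two_stage_cumsum(x: np.ndarray, split: int = 256) -> np.ndarray:
  pad = (-len(x)) % split
  chunks = np.pad(x, (0, pad)).reshape(-1, split)
  local = chunks.cumsum(axis=1)                              # stage 1: per-chunk cumsum
  offsets = np.concatenate(([0], local[:-1, -1].cumsum()))   # stage 2: cumsum of chunk totals
  return (local + offsets[:, None]).reshape(-1)[:len(x)]

x = np.random.randn(1000).astype(np.float32)
assert np.allclose(two_stage_cumsum(x), np.cumsum(x), atol=1e-3)
```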
@@ -706,72 +1740,527 @@ class Tensor:
706
1740
  @staticmethod
707
1741
  def _tri(r:sint, c:sint, k:int=0, **kwargs) -> Tensor:
708
1742
  assert all_int((r,c)), "does not support symbolic"
1743
+ if r == 0: return Tensor.zeros((r, c), **kwargs)
709
1744
  return Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(-k, c-k, **kwargs).unsqueeze(0).expand(r,c)
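Editor's note: `_tri` builds the triangular mask purely from two broadcast `arange`s; a small NumPy illustration of the same comparison:
```python
import numpy as np

# row index i <= column index j shifted by k gives an upper-triangular boolean
# mask, which triu/tril then feed into `where` to keep or zero elements
r, c, k = 3, 4, 0
mask = np.arange(r)[:, None] <= np.arange(-k, c - k)[None, :]
print(mask.astype(int))
```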
710
- def triu(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k, device=self.device).where(self, 0)
711
- def tril(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, device=self.device).where(0, self)
712
-
713
- # ***** mlops (unary) *****
714
-
715
- def neg(self): return mlops.Neg.apply(self)
716
- def logical_not(self): return self.neg() if self.dtype == dtypes.bool else (1.0-self)
717
- def contiguous(self): return mlops.Contiguous.apply(self)
718
- def contiguous_backward(self): return mlops.ContiguousBackward.apply(self)
719
- def log(self): return mlops.Log.apply(self.cast(least_upper_float(self.dtype)))
720
- def log2(self): return self.log()/math.log(2)
721
- def exp(self): return mlops.Exp.apply(self.cast(least_upper_float(self.dtype)))
722
- def exp2(self): return mlops.Exp.apply(self*math.log(2))
723
- def relu(self): return mlops.Relu.apply(self)
724
- def sigmoid(self): return mlops.Sigmoid.apply(self.cast(least_upper_float(self.dtype)))
725
- def sin(self): return mlops.Sin.apply(self.cast(least_upper_float(self.dtype)))
726
- def sqrt(self): return mlops.Sqrt.apply(self.cast(least_upper_float(self.dtype)))
727
- def rsqrt(self): return self.reciprocal().sqrt()
728
- def cos(self): return ((math.pi/2)-self).sin()
729
- def tan(self): return self.sin() / self.cos()
730
-
731
- # ***** math functions (unary) *****
732
-
733
- def trunc(self: Tensor) -> Tensor: return self.cast(dtypes.int32).cast(self.dtype)
734
- def ceil(self: Tensor) -> Tensor: return (self > (b := self.trunc())).where(b+1, b)
735
- def floor(self: Tensor) -> Tensor: return (self < (b := self.trunc())).where(b-1, b)
736
-
737
- def square(self): return self*self
738
- def clip(self, min_, max_): return self.maximum(min_).minimum(max_)
739
- def abs(self): return self.relu() + (-self).relu()
740
- def sign(self): return ((self.float()) / (self.float().abs() + 1e-12)).cast(self.dtype)
741
- def reciprocal(self): return 1.0/self
742
-
743
- # ***** activation functions (unary) *****
744
-
745
- def elu(self, alpha=1.0): return self.relu() - alpha*(1-self.exp()).relu()
746
- def celu(self, alpha=1.0): return self.maximum(0) + (alpha * ((self / alpha).exp() - 1)).minimum(0)
747
- def swish(self): return self * self.sigmoid()
748
- def silu(self): return self.swish() # The SiLU function is also known as the swish function.
749
- def relu6(self): return self.relu() - (self-6).relu()
750
- def hardswish(self): return self * (self+3).relu6() * (1/6)
751
- def tanh(self): return 2.0 * ((2.0 * self).sigmoid()) - 1.0
752
- def sinh(self): return (self.exp() - self.neg().exp()) / 2
753
- def cosh(self): return (self.exp() + self.neg().exp()) / 2
754
- def atanh(self): return ((1 + self)/(1 - self)).log() / 2
755
- def asinh(self): return (self + (self.square() + 1).sqrt()).log()
756
- def acosh(self): return (self + (self.square() - 1).sqrt()).log()
757
- def hardtanh(self, min_val=-1, max_val=1): return self.clip(min_val, max_val)
758
- def gelu(self): return 0.5 * self * (1 + (self * 0.7978845608 * (1 + 0.044715 * self * self)).tanh())
759
- def quick_gelu(self): return self * (self * 1.702).sigmoid()
760
- def leakyrelu(self, neg_slope=0.01): return self.relu() - (-neg_slope*self).relu()
761
- def mish(self): return self * self.softplus().tanh()
762
- def softplus(self, beta=1): return (1/beta) * (1 + (self*beta).exp()).log()
763
- def softsign(self): return self / (1 + self.abs())
764
-
765
- # ***** broadcasted elementwise mlops *****
766
-
767
- def _broadcasted(self, y:Union[Tensor, Scalar], reverse:bool=False, match_dtype:bool=True) -> Tuple[Tensor, Tensor]:
1745
+ def triu(self, k:int=0) -> Tensor:
1746
+ """
1747
+ Returns the upper triangular part of the tensor; the other elements are set to 0.
1748
+
1749
+ ```python exec="true" source="above" session="tensor" result="python"
1750
+ t = Tensor([[1, 2, 3], [4, 5, 6]])
1751
+ print(t.numpy())
1752
+ ```
1753
+ ```python exec="true" source="above" session="tensor" result="python"
1754
+ print(t.triu(k=1).numpy())
1755
+ ```
1756
+ """
1757
+ return Tensor._tri(self.shape[-2], self.shape[-1], k=k, device=self.device).where(self, 0)
1758
+ def tril(self, k:int=0) -> Tensor:
1759
+ """
1760
+ Returns the lower triangular part of the tensor; the other elements are set to 0.
1761
+
1762
+ ```python exec="true" source="above" session="tensor" result="python"
1763
+ t = Tensor([[1, 2, 3], [4, 5, 6]])
1764
+ print(t.numpy())
1765
+ ```
1766
+ ```python exec="true" source="above" session="tensor" result="python"
1767
+ print(t.tril().numpy())
1768
+ ```
1769
+ """
1770
+ return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, device=self.device).where(0, self)
1771
+
1772
+ # ***** unary ops *****
1773
+
1774
+ def logical_not(self):
1775
+ """
1776
+ Computes the logical NOT of the tensor element-wise.
1777
+
1778
+ ```python exec="true" source="above" session="tensor" result="python"
1779
+ print(Tensor([False, True]).logical_not().numpy())
1780
+ ```
1781
+ """
1782
+ return F.Eq.apply(*self._broadcasted(False))
1783
+ def neg(self):
1784
+ """
1785
+ Negates the tensor element-wise.
1786
+
1787
+ ```python exec="true" source="above" session="tensor" result="python"
1788
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).neg().numpy())
1789
+ ```
1790
+ """
1791
+ return F.Neg.apply(self) if self.dtype != dtypes.bool else self.logical_not()
1792
+ def contiguous(self):
1793
+ """
1794
+ Returns a contiguous tensor.
1795
+ """
1796
+ return F.Contiguous.apply(self)
1797
+ def contiguous_backward(self):
1798
+ """
1799
+ Inserts a contiguous operation in the backward pass.
1800
+ """
1801
+ return F.ContiguousBackward.apply(self)
1802
+ def log(self):
1803
+ """
1804
+ Computes the natural logarithm element-wise.
1805
+
1806
+ See: https://en.wikipedia.org/wiki/Logarithm
1807
+
1808
+ ```python exec="true" source="above" session="tensor" result="python"
1809
+ print(Tensor([1., 2., 4., 8.]).log().numpy())
1810
+ ```
1811
+ """
1812
+ return F.Log.apply(self.cast(least_upper_float(self.dtype)))
1813
+ def log2(self):
1814
+ """
1815
+ Computes the base-2 logarithm element-wise.
1816
+
1817
+ See: https://en.wikipedia.org/wiki/Logarithm
1818
+
1819
+ ```python exec="true" source="above" session="tensor" result="python"
1820
+ print(Tensor([1., 2., 4., 8.]).log2().numpy())
1821
+ ```
1822
+ """
1823
+ return self.log()/math.log(2)
1824
+ def exp(self):
1825
+ """
1826
+ Computes the exponential function element-wise.
1827
+
1828
+ See: https://en.wikipedia.org/wiki/Exponential_function
1829
+
1830
+ ```python exec="true" source="above" session="tensor" result="python"
1831
+ print(Tensor([0., 1., 2., 3.]).exp().numpy())
1832
+ ```
1833
+ """
1834
+ return F.Exp.apply(self.cast(least_upper_float(self.dtype)))
1835
+ def exp2(self):
1836
+ """
1837
+ Computes the base-2 exponential function element-wise.
1838
+
1839
+ See: https://en.wikipedia.org/wiki/Exponential_function
1840
+
1841
+ ```python exec="true" source="above" session="tensor" result="python"
1842
+ print(Tensor([0., 1., 2., 3.]).exp2().numpy())
1843
+ ```
1844
+ """
1845
+ return F.Exp.apply(self*math.log(2))
1846
+ def relu(self):
1847
+ """
1848
+ Applies the Rectified Linear Unit (ReLU) function element-wise.
1849
+
1850
+ - Described: https://paperswithcode.com/method/relu
1851
+
1852
+ ```python exec="true" source="above" session="tensor" result="python"
1853
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).relu().numpy())
1854
+ ```
1855
+ """
1856
+ return F.Relu.apply(self)
1857
+ def sigmoid(self):
1858
+ """
1859
+ Applies the Sigmoid function element-wise.
1860
+
1861
+ - Described: https://en.wikipedia.org/wiki/Sigmoid_function
1862
+
1863
+ ```python exec="true" source="above" session="tensor" result="python"
1864
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sigmoid().numpy())
1865
+ ```
1866
+ """
1867
+ return F.Sigmoid.apply(self.cast(least_upper_float(self.dtype)))
1868
+ def sqrt(self):
1869
+ """
1870
+ Computes the square root of the tensor element-wise.
1871
+
1872
+ ```python exec="true" source="above" session="tensor" result="python"
1873
+ print(Tensor([1., 2., 3., 4.]).sqrt().numpy())
1874
+ ```
1875
+ """
1876
+ return F.Sqrt.apply(self.cast(least_upper_float(self.dtype)))
1877
+ def rsqrt(self):
1878
+ """
1879
+ Computes the reciprocal of the square root of the tensor element-wise.
1880
+
1881
+ ```python exec="true" source="above" session="tensor" result="python"
1882
+ print(Tensor([1., 2., 3., 4.]).rsqrt().numpy())
1883
+ ```
1884
+ """
1885
+ return self.reciprocal().sqrt()
1886
+ def sin(self):
1887
+ """
1888
+ Computes the sine of the tensor element-wise.
1889
+
1890
+ ```python exec="true" source="above" session="tensor" result="python"
1891
+ print(Tensor([0., math.pi/2, math.pi, 3*math.pi/2, 2*math.pi]).sin().numpy())
1892
+ ```
1893
+ """
1894
+ return F.Sin.apply(self.cast(least_upper_float(self.dtype)))
1895
+ def cos(self):
1896
+ """
1897
+ Computes the cosine of the tensor element-wise.
1898
+
1899
+ ```python exec="true" source="above" session="tensor" result="python"
1900
+ print(Tensor([0., math.pi/2, math.pi, 3*math.pi/2, 2*math.pi]).cos().numpy())
1901
+ ```
1902
+ """
1903
+ return ((math.pi/2)-self).sin()
1904
+ def tan(self):
1905
+ """
1906
+ Computes the tangent of the tensor element-wise.
1907
+
1908
+ ```python exec="true" source="above" session="tensor" result="python"
1909
+ print(Tensor([0., math.pi/4, math.pi/2, 3*math.pi/4, math.pi]).tan().numpy())
1910
+ ```
1911
+ """
1912
+ return self.sin() / self.cos()
1913
+
1914
+ # ***** math functions *****
1915
+
1916
+ def trunc(self: Tensor) -> Tensor:
1917
+ """
1918
+ Truncates the tensor element-wise.
1919
+
1920
+ ```python exec="true" source="above" session="tensor" result="python"
1921
+ print(Tensor([-3.9, -2.1, -1.5, 0.5, 1.5, 2.1, 3.9]).trunc().numpy())
1922
+ ```
1923
+ """
1924
+ return self.cast(dtypes.int32).cast(self.dtype)
1925
+ def ceil(self: Tensor) -> Tensor:
1926
+ """
1927
+ Rounds the tensor element-wise towards positive infinity.
1928
+
1929
+ ```python exec="true" source="above" session="tensor" result="python"
1930
+ print(Tensor([-3.9, -2.1, -1.5, 0.5, 1.5, 2.1, 3.9]).ceil().numpy())
1931
+ ```
1932
+ """
1933
+ return (self > (b := self.trunc())).where(b+1, b)
1934
+ def floor(self: Tensor) -> Tensor:
1935
+ """
1936
+ Rounds the tensor element-wise towards negative infinity.
1937
+
1938
+ ```python exec="true" source="above" session="tensor" result="python"
1939
+ print(Tensor([-3.9, -2.1, -1.5, 0.5, 1.5, 2.1, 3.9]).floor().numpy())
1940
+ ```
1941
+ """
1942
+ return (self < (b := self.trunc())).where(b-1, b)
1943
+ def round(self: Tensor) -> Tensor:
1944
+ """
1945
+ Rounds the tensor element-wise, with halfway values going to the nearest even integer.
1946
+
1947
+ ```python exec="true" source="above" session="tensor" result="python"
1948
+ print(Tensor([-3.9, -2.1, -1.5, 0.5, 1.5, 2.1, 3.9]).round().numpy())
1949
+ ```
1950
+ """
1951
+ return ((self > 0) == ((b := self.cast(dtypes.int32) / 2.0).cast(dtypes.int32) == b)).where((self - 0.5).ceil(), (self + 0.5).floor())
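Editor's note: the parity test in `round` above gives round-half-to-even ("banker's rounding") for the halfway cases checked by hand; a quick demonstration using the Tensor API as in the surrounding examples:
```python
# halfway values go to the nearest even integer: 0.5 -> 0, 1.5 -> 2, 2.5 -> 2, -1.5 -> -2
print(Tensor([0.5, 1.5, 2.5, -0.5, -1.5]).round().numpy())
```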
1952
+ def lerp(self, end: Tensor, weight: Union[Tensor, float]) -> Tensor:
1953
+ """
1954
+ Linearly interpolates between `self` and `end` by `weight`.
1955
+
1956
+ ```python exec="true" source="above" session="tensor" result="python"
1957
+ print(Tensor([1., 2., 3.]).lerp(Tensor([4., 5., 6.]), 0.5).numpy())
1958
+ ```
1959
+ """
1960
+ return self + (end - self) * weight
1961
+ def square(self):
1962
+ """
1963
+ Squares the tensor element-wise.
1964
+ Equivalent to `self*self`.
1965
+
1966
+ ```python exec="true" source="above" session="tensor" result="python"
1967
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).square().numpy())
1968
+ ```
1969
+ """
1970
+ return self*self
1971
+ def clip(self, min_, max_):
1972
+ """
1973
+ Clips (clamps) the values in the tensor between `min_` and `max_` element-wise.
1974
+
1975
+ ```python exec="true" source="above" session="tensor" result="python"
1976
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).clip(-1, 1).numpy())
1977
+ ```
1978
+ """
1979
+ return self.maximum(min_).minimum(max_)
1980
+ def sign(self):
1981
+ """
1982
+ Returns the sign of the tensor element-wise.
1983
+
1984
+ ```python exec="true" source="above" session="tensor" result="python"
1985
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sign().numpy())
1986
+ ```
1987
+ """
1988
+ return F.Sign.apply(self)
1989
+ def abs(self):
1990
+ """
1991
+ Computes the absolute value of the tensor element-wise.
1992
+
1993
+ ```python exec="true" source="above" session="tensor" result="python"
1994
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).abs().numpy())
1995
+ ```
1996
+ """
1997
+ return self * self.sign()
1998
+ def reciprocal(self):
1999
+ """
2000
+ Computes `1/x` element-wise.
2001
+
2002
+ ```python exec="true" source="above" session="tensor" result="python"
2003
+ print(Tensor([1., 2., 3., 4.]).reciprocal().numpy())
2004
+ ```
2005
+ """
2006
+ return F.Reciprocal.apply(self.cast(least_upper_float(self.dtype)))
2007
+
2008
+ # ***** activation functions *****
2009
+
2010
+ def elu(self, alpha=1.0):
2011
+ """
2012
+ Applies the Exponential Linear Unit (ELU) function element-wise.
2013
+
2014
+ - Described: https://paperswithcode.com/method/elu
2015
+ - Paper: https://arxiv.org/abs/1511.07289v5
2016
+
2017
+ ```python exec="true" source="above" session="tensor" result="python"
2018
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).elu().numpy())
2019
+ ```
2020
+ """
2021
+ return self.relu() - alpha*(1-self.exp()).relu()
2022
+
2023
+ def celu(self, alpha=1.0):
2024
+ """
2025
+ Applies the Continuously differentiable Exponential Linear Unit (CELU) function element-wise.
2026
+
2027
+ - Described: https://paperswithcode.com/method/celu
2028
+ - Paper: https://arxiv.org/abs/1704.07483
2029
+
2030
+ ```python exec="true" source="above" session="tensor" result="python"
2031
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).celu().numpy())
2032
+ ```
2033
+ """
2034
+ return self.maximum(0) + (alpha * ((self / alpha).exp() - 1)).minimum(0)
2035
+
2036
+ def swish(self):
2037
+ """
2038
+ See `.silu()`
2039
+
2040
+ - Paper: https://arxiv.org/abs/1710.05941v1
2041
+
2042
+ ```python exec="true" source="above" session="tensor" result="python"
2043
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).swish().numpy())
2044
+ ```
2045
+ """
2046
+ return self * self.sigmoid()
2047
+
2048
+ def silu(self):
2049
+ """
2050
+ Applies the Sigmoid Linear Unit (SiLU) function element-wise.
2051
+
2052
+ - Described: https://paperswithcode.com/method/silu
2053
+ - Paper: https://arxiv.org/abs/1606.08415
2054
+
2055
+ ```python exec="true" source="above" session="tensor" result="python"
2056
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).silu().numpy())
2057
+ ```
2058
+ """
2059
+ return self.swish() # The SiLU function is also known as the swish function.
2060
+
2061
+ def relu6(self):
2062
+ """
2063
+ Applies the ReLU6 function element-wise.
2064
+
2065
+ - Described: https://paperswithcode.com/method/relu6
2066
+ - Paper: https://arxiv.org/abs/1704.04861v1
2067
+
2068
+ ```python exec="true" source="above" session="tensor" result="python"
2069
+ print(Tensor([-9., -6., -3., 0., 3., 6., 9.]).relu6().numpy())
2070
+ ```
2071
+ """
2072
+ return self.relu() - (self-6).relu()
2073
+
2074
+ def hardswish(self):
2075
+ """
2076
+ Applies the Hardswish function element-wise.
2077
+
2078
+ - Described: https://paperswithcode.com/method/hard-swish
2079
+ - Paper: https://arxiv.org/abs/1905.02244v5
2080
+
2081
+ ```python exec="true" source="above" session="tensor" result="python"
2082
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).hardswish().numpy())
2083
+ ```
2084
+ """
2085
+ return self * (self+3).relu6() * (1/6)
2086
+
2087
+ def tanh(self):
2088
+ """
2089
+ Applies the Hyperbolic Tangent (tanh) function element-wise.
2090
+
2091
+ - Described: https://en.wikipedia.org/wiki/Hyperbolic_functions#Tanh
2092
+
2093
+ ```python exec="true" source="above" session="tensor" result="python"
2094
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).tanh().numpy())
2095
+ ```
2096
+ """
2097
+ return 2.0 * ((2.0 * self).sigmoid()) - 1.0
2098
+
2099
+ def sinh(self):
2100
+ """
2101
+ Applies the Hyperbolic Sine (sinh) function element-wise.
2102
+
2103
+ - Described: https://en.wikipedia.org/wiki/Hyperbolic_functions#Sinh
2104
+
2105
+ ```python exec="true" source="above" session="tensor" result="python"
2106
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sinh().numpy())
2107
+ ```
2108
+ """
2109
+ return (self.exp() - self.neg().exp()) / 2
2110
+
2111
+ def cosh(self):
2112
+ """
2113
+ Applies the Hyperbolic Cosine (cosh) function element-wise.
2114
+
2115
+ - Described: https://en.wikipedia.org/wiki/Hyperbolic_functions#Cosh
2116
+
2117
+ ```python exec="true" source="above" session="tensor" result="python"
2118
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).cosh().numpy())
2119
+ ```
2120
+ """
2121
+ return (self.exp() + self.neg().exp()) / 2
2122
+
2123
+ def atanh(self):
2124
+ """
2125
+ Applies the Inverse Hyperbolic Tangent (atanh) function element-wise.
2126
+
2127
+ - Described: https://en.wikipedia.org/wiki/Inverse_hyperbolic_functions#atanh
2128
+
2129
+ ```python exec="true" source="above" session="tensor" result="python"
2130
+ print(Tensor([-0.9, -0.6, -0.3, 0., 0.3, 0.6, 0.9]).atanh().numpy())
2131
+ ```
2132
+ """
2133
+ return ((1 + self)/(1 - self)).log() / 2
2134
+
2135
+ def asinh(self):
2136
+ """
2137
+ Applies the Inverse Hyperbolic Sine (asinh) function element-wise.
2138
+
2139
+ - Described: https://en.wikipedia.org/wiki/Inverse_hyperbolic_functions#asinh
2140
+
2141
+ ```python exec="true" source="above" session="tensor" result="python"
2142
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).asinh().numpy())
2143
+ ```
2144
+ """
2145
+ return (self + (self.square() + 1).sqrt()).log()
2146
+
2147
+ def acosh(self):
2148
+ """
2149
+ Applies the Inverse Hyperbolic Cosine (acosh) function element-wise.
2150
+
2151
+ - Described: https://en.wikipedia.org/wiki/Inverse_hyperbolic_functions#acosh
2152
+
2153
+ ```python exec="true" source="above" session="tensor" result="python"
2154
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).acosh().numpy())
2155
+ ```
2156
+ """
2157
+ return (self + (self.square() - 1).sqrt()).log()
2158
+
2159
+ def hardtanh(self, min_val=-1, max_val=1):
2160
+ """
2161
+ Applies the Hardtanh function element-wise.
2162
+
2163
+ - Described: https://paperswithcode.com/method/hardtanh-activation
2164
+
2165
+ ```python exec="true" source="above" session="tensor" result="python"
2166
+ print(Tensor([-1.5, -1.0, -0.5, 0., 0.5, 1.0, 1.5]).hardtanh().numpy())
2167
+ ```
2168
+ """
2169
+ return self.clip(min_val, max_val)
2170
+
2171
+ def gelu(self):
2172
+ """
2173
+ Applies the Gaussian Error Linear Unit (GELU) function element-wise.
2174
+
2175
+ - Described: https://paperswithcode.com/method/gelu
2176
+ - Paper: https://arxiv.org/abs/1606.08415v5
2177
+
2178
+ ```python exec="true" source="above" session="tensor" result="python"
2179
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).gelu().numpy())
2180
+ ```
2181
+ """
2182
+ return 0.5 * self * (1 + (self * 0.7978845608 * (1 + 0.044715 * self * self)).tanh())
2183
+
2184
+ def quick_gelu(self):
2185
+ """
2186
+ Applies the Sigmoid GELU approximation element-wise.
2187
+
2188
+ - Described: https://paperswithcode.com/method/gelu
2189
+
2190
+ ```python exec="true" source="above" session="tensor" result="python"
2191
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).quick_gelu().numpy())
2192
+ ```
2193
+ """
2194
+ return self * (self * 1.702).sigmoid()
2195
+
2196
+ def leakyrelu(self, neg_slope=0.01):
2197
+ """
2198
+ Applies the Leaky ReLU function element-wise.
2199
+
2200
+ - Described: https://paperswithcode.com/method/leaky-relu
2201
+
2202
+ ```python exec="true" source="above" session="tensor" result="python"
2203
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).leakyrelu().numpy())
2204
+ ```
2205
+ ```python exec="true" source="above" session="tensor" result="python"
2206
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).leakyrelu(neg_slope=0.42).numpy())
2207
+ ```
2208
+ """
2209
+ return self.relu() - (-neg_slope*self).relu()
2210
+
2211
+ def mish(self):
2212
+ """
2213
+ Applies the Mish function element-wise.
2214
+
2215
+ - Described: https://paperswithcode.com/method/mish
2216
+ - Paper: https://arxiv.org/abs/1908.08681v3
2217
+
2218
+ ```python exec="true" source="above" session="tensor" result="python"
2219
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).mish().numpy())
2220
+ ```
2221
+ """
2222
+ return self * self.softplus().tanh()
2223
+
2224
+ def softplus(self, beta=1):
2225
+ """
2226
+ Applies the Softplus function element-wise.
2227
+
2228
+ - Described: https://paperswithcode.com/method/softplus
2229
+
2230
+ ```python exec="true" source="above" session="tensor" result="python"
2231
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).softplus().numpy())
2232
+ ```
2233
+ """
2234
+ return (1/beta) * (1 + (self*beta).exp()).log()
2235
+
2236
+ def softsign(self):
2237
+ """
2238
+ Applies the Softsign function element-wise.
2239
+
2240
+ - Described: https://paperswithcode.com/method/softsign
2241
+
2242
+ ```python exec="true" source="above" session="tensor" result="python"
2243
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).softsign().numpy())
2244
+ ```
2245
+ """
2246
+ return self / (1 + self.abs())
2247
+
2248
+ # ***** broadcasted elementwise ops *****
2249
+ def _broadcast_to(self, shape:Tuple[sint, ...]):
2250
+ reshape_arg, _ = _pad_left(self.shape, shape)
2251
+ if self.ndim > len(shape) or not all(sh in {s,1} or (s==0 and sh==1) for sh,s in zip(reshape_arg, shape)):
2252
+ raise ValueError(f"cannot broadcast tensor with shape={self.shape} to {shape=}")
2253
+ return F.Expand.apply(self.reshape(reshape_arg), shape=shape) if shape != self.shape else self
2254
+
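Editor's note: `_pad_left` and `_broadcast_shape` are module-level helpers defined outside this hunk; a hedged sketch of the NumPy-style semantics they presumably implement (the real helpers also special-case size-0 dimensions, as the check in `_broadcast_to` suggests):
```python
# editor's sketch, not the actual tinygrad helpers
def pad_left(*shapes):
  longest = max(len(s) for s in shapes)
  return [(1,) * (longest - len(s)) + tuple(s) for s in shapes]

def broadcast_shape(*shapes):
  out = []
  for dims in zip(*pad_left(*shapes)):
    # sizes are compatible when they are equal or one of them is 1
    if len(set(d for d in dims if d != 1)) > 1: raise ValueError(f"cannot broadcast {shapes}")
    out.append(max(dims))
  return tuple(out)

print(broadcast_shape((2, 1), (4,)))   # (2, 4)
```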
2255
+ def _broadcasted(self, y:Union[Tensor, ConstType], reverse:bool=False, match_dtype:bool=True) -> Tuple[Tensor, Tensor]:
768
2256
  x: Tensor = self
769
2257
  if not isinstance(y, Tensor):
770
2258
  # make y a Tensor
771
- if 0 in self.shape: return self, self.full_like(y)
2259
+ assert isinstance(y, (float, int, bool, Node)), f"{type(y)=}, {y=}"
772
2260
  if isinstance(self.dtype, ImageDType) or dtypes.is_float(x.dtype) or (dtypes.is_int(x.dtype) and isinstance(y, int)): y_dtype = x.dtype
773
2261
  else: y_dtype = dtypes.from_py(y)
774
- y = Tensor(y, self.device, y_dtype, requires_grad=False)
2262
+ if isinstance(y, Node): y = Tensor.from_node(y, device=self.device)
2263
+ else: y = Tensor(dtypes.as_const(y, y_dtype), self.device, y_dtype, requires_grad=False)
775
2264
 
776
2265
  if match_dtype:
777
2266
  output_dtype = least_upper_dtype(x.dtype, y.dtype)
@@ -779,63 +2268,231 @@ class Tensor:
779
2268
 
780
2269
  if reverse: x, y = y, x
781
2270
 
782
- # left pad shape with 1s
783
- if len(y.shape) < len(x.shape): y = y.reshape((1,) * (len(x.shape) - len(y.shape)) + y.shape)
784
- elif len(x.shape) < len(y.shape): x = x.reshape((1,) * (len(y.shape) - len(x.shape)) + x.shape)
2271
+ # broadcast
2272
+ out_shape = _broadcast_shape(x.shape, y.shape)
2273
+ return x._broadcast_to(out_shape), y._broadcast_to(out_shape)
785
2274
 
786
- broadcasted_shape = tuple(max(xi, yi) for xi, yi in zip(x.shape, y.shape))
787
- return x.expand(broadcasted_shape), y.expand(broadcasted_shape)
788
-
789
- def _to_const_val(self, x:Union[Tensor, Scalar]) -> Union[Tensor, Scalar]:
2275
+ def _to_const_val(self, x:Union[Tensor, ConstType]) -> Union[Tensor, ConstType]:
790
2276
  # TODO: update with multi
791
- return x.lazydata.base.arg if isinstance(x, Tensor) and isinstance(x.lazydata, LazyBuffer) and x.lazydata.is_unrealized_contiguous_const() \
2277
+ return x.lazydata.base.arg if isinstance(x, Tensor) and isinstance(x.lazydata, LazyBuffer) and x.lazydata.is_unrealized_unmasked_const() \
792
2278
  and not x.requires_grad and self._broadcasted(x)[0].shape == self.shape else x
793
2279
 
794
- def add(self, x:Union[Tensor, Scalar], reverse=False) -> Tensor:
795
- x = self._to_const_val(x)
796
- return mlops.Add.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x else self
797
- def sub(self, x:Union[Tensor, Scalar], reverse=False) -> Tensor:
798
- x = self._to_const_val(x)
799
- return mlops.Sub.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x else (-self if reverse else self)
800
- def mul(self, x:Union[Tensor, Scalar], reverse=False) -> Tensor:
801
- x = self._to_const_val(x)
802
- if x.__class__ is not Tensor and x == 0.0: return mlops.Zero.apply(self)
803
- if x.__class__ is not Tensor and x == -1.0: return -self
804
- return mlops.Mul.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x != 1.0 else self
805
- def div(self, x:Union[Tensor, Scalar], reverse=False) -> Tensor:
806
- x = self._to_const_val(x)
807
- return mlops.Div.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or reverse or not x or not dtypes.is_float(self.dtype) else self.mul(1/x) # noqa: E501
808
- def pow(self, x:Union[Tensor, Scalar], reverse=False) -> Tensor:
2280
+ def add(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
2281
+ """
2282
+ Adds `self` and `x`.
2283
+ Equivalent to `self + x`.
2284
+ Supports broadcasting to a common shape, type promotion, and integer, float, boolean inputs.
2285
+
2286
+ ```python exec="true" source="above" session="tensor" result="python"
2287
+ Tensor.manual_seed(42)
2288
+ t = Tensor.randn(4)
2289
+ print(t.numpy())
2290
+ ```
2291
+ ```python exec="true" source="above" session="tensor" result="python"
2292
+ print(t.add(20).numpy())
2293
+ ```
2294
+ ```python exec="true" source="above" session="tensor" result="python"
2295
+ print(t.add(Tensor([[2.0], [3.5]])).numpy())
2296
+ ```
2297
+ """
2298
+ return F.Add.apply(*self._broadcasted(x, reverse))
2299
+
2300
+ def sub(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
2301
+ """
2302
+ Subtracts `x` from `self`.
2303
+ Equivalent to `self - x`.
2304
+ Supports broadcasting to a common shape, type promotion, and integer, float, boolean inputs.
2305
+
2306
+ ```python exec="true" source="above" session="tensor" result="python"
2307
+ Tensor.manual_seed(42)
2308
+ t = Tensor.randn(4)
2309
+ print(t.numpy())
2310
+ ```
2311
+ ```python exec="true" source="above" session="tensor" result="python"
2312
+ print(t.sub(20).numpy())
2313
+ ```
2314
+ ```python exec="true" source="above" session="tensor" result="python"
2315
+ print(t.sub(Tensor([[2.0], [3.5]])).numpy())
2316
+ ```
2317
+ """
2318
+ return F.Sub.apply(*self._broadcasted(x, reverse))
2319
+
2320
+ def mul(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
2321
+ """
2322
+ Multiplies `self` and `x`.
2323
+ Equivalent to `self * x`.
2324
+ Supports broadcasting to a common shape, type promotion, and integer, float, boolean inputs.
2325
+
2326
+ ```python exec="true" source="above" session="tensor" result="python"
2327
+ Tensor.manual_seed(42)
2328
+ t = Tensor.randn(4)
2329
+ print(t.numpy())
2330
+ ```
2331
+ ```python exec="true" source="above" session="tensor" result="python"
2332
+ print(t.mul(3).numpy())
2333
+ ```
2334
+ ```python exec="true" source="above" session="tensor" result="python"
2335
+ print(t.mul(Tensor([[-1.0], [2.0]])).numpy())
2336
+ ```
2337
+ """
2338
+ return F.Mul.apply(*self._broadcasted(x, reverse))
2339
+
2340
+ def div(self, x:Union[Tensor, ConstType], reverse=False, upcast=True) -> Tensor:
2341
+ """
2342
+ Divides `self` by `x`.
2343
+ Equivalent to `self / x`.
2344
+ Supports broadcasting to a common shape, type promotion, and integer, float, boolean inputs.
2345
+ By default, `div` performs true division. Set `upcast` to `False` for integer division.
2346
+
2347
+ ```python exec="true" source="above" session="tensor" result="python"
2348
+ Tensor.manual_seed(42)
2349
+ t = Tensor.randn(4)
2350
+ print(t.numpy())
2351
+ ```
2352
+ ```python exec="true" source="above" session="tensor" result="python"
2353
+ print(t.div(3).numpy())
2354
+ ```
2355
+ ```python exec="true" source="above" session="tensor" result="python"
2356
+ print(Tensor([1, 4, 10]).div(Tensor([2, 3, 4])).numpy())
2357
+ ```
2358
+ ```python exec="true" source="above" session="tensor" result="python"
2359
+ print(Tensor([1, 4, 10]).div(Tensor([2, 3, 4]), upcast=False).numpy())
2360
+ ```
2361
+ """
2362
+ numerator, denominator = self._broadcasted(x, reverse)
2363
+ if upcast: numerator, denominator = numerator.cast(least_upper_float(numerator.dtype)), denominator.cast(least_upper_float(denominator.dtype))
2364
+ return F.Div.apply(numerator, denominator)
2365
+
2366
+ def xor(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
2367
+ """
2368
+ Computes bitwise xor of `self` and `x`.
2369
+ Equivalent to `self ^ x`.
2370
+ Supports broadcasting to a common shape, type promotion, and integer, boolean inputs.
2371
+
2372
+ ```python exec="true" source="above" session="tensor" result="python"
2373
+ print(Tensor([-1, -2, 3]).xor(Tensor([1, 0, 3])).numpy())
2374
+ ```
2375
+ ```python exec="true" source="above" session="tensor" result="python"
2376
+ print(Tensor([True, True, False, False]).xor(Tensor([True, False, True, False])).numpy())
2377
+ ```
2378
+ """
2379
+ return F.Xor.apply(*self._broadcasted(x, reverse))
2380
+
2381
+ def lshift(self, x:int):
2382
+ """
2383
+ Computes left arithmetic shift of `self` by `x` bits. `self` must have unsigned dtype.
2384
+ Equivalent to `self << x`.
2385
+
2386
+ ```python exec="true" source="above" session="tensor" result="python"
2387
+ print(Tensor([1, 3, 31], dtype=dtypes.uint8).lshift(2).numpy())
2388
+ ```
2389
+ """
2390
+ assert dtypes.is_unsigned(self.dtype) and isinstance(x, int) and x >= 0, f"not supported {self.dtype=} {x=}"
2391
+ return self.mul(2 ** x)
2392
+
2393
+ def rshift(self, x:int):
2394
+ """
2395
+ Computes right arithmetic shift of `self` by `x` bits. `self` must have unsigned dtype.
2396
+ Equivalent to `self >> x`.
2397
+
2398
+ ```python exec="true" source="above" session="tensor" result="python"
2399
+ print(Tensor([4, 13, 125], dtype=dtypes.uint8).rshift(2).numpy())
2400
+ ```
2401
+ """
2402
+ assert dtypes.is_unsigned(self.dtype) and isinstance(x, int) and x >= 0, f"not supported {self.dtype=} {x=}"
2403
+ return self.div(2 ** x, upcast=False)
2404
+
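Editor's note: because the dtype is unsigned, the shifts reduce to multiplication and integer division by `2 ** x`, which is exactly what the two methods above do; a quick comparison against Python's `<<`/`>>` (assuming `Tensor` and `dtypes` are imported as in the surrounding examples):
```python
t = Tensor([1, 3, 31], dtype=dtypes.uint8)
print((t << 2).numpy(), [v << 2 for v in [1, 3, 31]])     # 4 12 124 in both
u = Tensor([4, 13, 125], dtype=dtypes.uint8)
print((u >> 2).numpy(), [v >> 2 for v in [4, 13, 125]])   # 1 3 31 in both
```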
2405
+ def pow(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
2406
+ """
2407
+ Computes power of `self` with `x`.
2408
+ Equivalent to `self ** x`.
2409
+
2410
+ ```python exec="true" source="above" session="tensor" result="python"
2411
+ print(Tensor([-1, 2, 3]).pow(2).numpy())
2412
+ ```
2413
+ ```python exec="true" source="above" session="tensor" result="python"
2414
+ print(Tensor([-1, 2, 3]).pow(Tensor([-1.5, 0.5, 1.5])).numpy())
2415
+ ```
2416
+ ```python exec="true" source="above" session="tensor" result="python"
2417
+ print((2 ** Tensor([-1, 2, 3])).numpy())
2418
+ ```
2419
+ """
809
2420
  x = self._to_const_val(x)
810
2421
  if not isinstance(x, Tensor) and not reverse:
811
2422
  # simple pow identities
812
2423
  if x < 0: return self.reciprocal().pow(-x)
813
- if x in [3,2,1,0]: return reduce(lambda acc,_: acc * self, range(int(x)), mlops.Zero.apply(self)+1)
814
- if x == 0.5: return self.sqrt()
2424
+ if x == 0: return 1 + self * 0
2425
+ if int(x - 0.5) + 0.5 == x: return self.pow(int(x - 0.5)) * self.sqrt()
2426
+ if int(x) == x: return self.pow(x // 2).square() * (1 if x % 2 == 0 else self)
2427
+
2428
+ # positive const ** self
815
2429
  if not isinstance(x, Tensor) and reverse and x > 0: return self.mul(math.log(x)).exp()
816
- ar = self.abs().log().mul(x).exp() if not reverse or isinstance(x, Tensor) else self.mul(math.log(abs(x))).exp()
817
- # correct sign of negative numbers raised to a power (cos has a period of 2pi so we use it here to get the oddness of the power)
818
- sign = (x * math.pi).cos() if isinstance(x, Tensor) else math.cos(x * math.pi) if not reverse else (self * math.pi).cos()
819
- # we only need to correct the sign if the base is negative
820
- base_sign = ((self.sign() if not reverse else x.sign() if isinstance(x, Tensor) else math.copysign(1, x)) - 1) / -2
821
- # we need 0 to be positive so we need to correct base_sign when the base is 0
822
- base_sign = base_sign - (1.5 * (1 - (self.sign().abs() if not reverse else x.sign().abs() if isinstance(x, Tensor) else abs(int(bool(x))))))
823
- # inject nan if the base is negative and the power is not an integer
824
- to_nan = (((x - x.trunc()) * 1e10).abs().clip(0, 1) if isinstance(x, Tensor) else int(bool(x - int(x))) if not reverse else ((self - self.trunc()) * 1e10).abs().clip(0, 1)) * base_sign # noqa: E501
825
- inject_nan = ((((-to_nan) * 2) + 1)).log().add(1) if isinstance(to_nan, Tensor) else 1 if not to_nan else float("nan")
826
- return ar.mul(sign * base_sign + (1 - base_sign)).mul(inject_nan)
827
- def xor(self, x:Tensor, reverse=False) -> Tensor: return mlops.Xor.apply(*self._broadcasted(x, reverse))
828
-
829
- # TODO: this implicitly changes dtype with /2
830
- def maximum(self, x:Union[Tensor, Scalar]) -> Tensor: return (self<x).detach().where(x, (self>x).detach().where(self, (self+x)/2))
831
- def minimum(self, x:Union[Tensor, Scalar]) -> Tensor: return -((-self).maximum(-x))
832
-
833
- def where(self:Tensor, input_:Union[Tensor, Scalar], other:Union[Tensor, Scalar]):
834
- x_,y = self._broadcasted(input_, match_dtype=False)
835
- x,z = x_._broadcasted(other, match_dtype=False)
836
- return mlops.Where.apply(x.cast(dtypes.bool), *y._broadcasted(z))
837
-
838
- # ***** op wrappers (wasted lines to make the typechecker happy) *****
2430
+
2431
+ base, exponent = self._broadcasted(x, reverse=reverse)
2432
+ # start with b ** e = exp(e * log(b))
2433
+ ret = base.abs().log().mul(exponent).exp()
2434
+ # correct sign of negative base with odd exponent (cos has a period of 2pi so we use it here to get the oddness of the exponent)
2435
+ negative_base = (base < 0).detach().where(1, 0)
2436
+ # 1 when the base is non-negative or the exponent is an even integer, -1 when the base is negative and the exponent is an odd integer; non-integer exponents on a negative base become NaN via inject_nan below
2437
+ correct_sign = 1 + negative_base * ((exponent * math.pi).cos() - 1)
2438
+ # inject nan for negative base and non-integer exponent
2439
+ inject_nan = (negative_base * (exponent != exponent.trunc())).detach().where(math.nan, 1)
2440
+ # apply correct_sign and inject_nan, and fix 0 ** 0 = 1
2441
+ return ((base == 0) * (exponent == 0)).detach().where(1, ret * correct_sign * inject_nan)
2442
+
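Editor's note: a plain-`math` walk-through of the identity the implementation above is built on, `b ** e = exp(e * log|b|)` with the sign recovered from `cos(pi * e)`:
```python
import math

base, exponent = -2.0, 3.0
magnitude = math.exp(exponent * math.log(abs(base)))  # exp(e * log|b|), about 8.0
sign = math.cos(exponent * math.pi)                   # about -1.0 for an odd integer exponent
print(magnitude * sign, (-2.0) ** 3.0)                # both approximately -8.0
```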
2443
+ def maximum(self, x:Union[Tensor, ConstType]) -> Tensor:
2444
+ """
2445
+ Computes element-wise maximum of `self` and `x`.
2446
+
2447
+ ```python exec="true" source="above" session="tensor" result="python"
2448
+ print(Tensor([-1, 2, 3]).maximum(1).numpy())
2449
+ ```
2450
+ ```python exec="true" source="above" session="tensor" result="python"
2451
+ print(Tensor([-1, 2, 3]).maximum(Tensor([-4, -2, 9])).numpy())
2452
+ ```
2453
+ """
2454
+ return (self<x).detach().where(x, (self==x).detach().where(((self * 0.5 + x * 0.5).cast(self.dtype)), self))
2455
+
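Editor's note: the `(self==x)` branch above returns the average at ties, which under autograd should split the gradient evenly between the two inputs; a small sketch of that expectation (behavior inferred from the expression above, not verified against the release):
```python
a, b = Tensor([2.0], requires_grad=True), Tensor([2.0], requires_grad=True)
a.maximum(b).sum().backward()
print(a.grad.numpy(), b.grad.numpy())  # expected 0.5 and 0.5 from the averaging branch
```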
2456
+ def minimum(self, x:Union[Tensor, ConstType]) -> Tensor:
2457
+ """
2458
+ Computes element-wise minimum of `self` and `x`.
2459
+
2460
+ ```python exec="true" source="above" session="tensor" result="python"
2461
+ print(Tensor([-1, 2, 3]).minimum(1).numpy())
2462
+ ```
2463
+ ```python exec="true" source="above" session="tensor" result="python"
2464
+ print(Tensor([-1, 2, 3]).minimum(Tensor([-4, -2, 9])).numpy())
2465
+ ```
2466
+ """
2467
+ return -((-self).maximum(-x))
2468
+
2469
+ def where(self:Tensor, x:Union[Tensor, ConstType], y:Union[Tensor, ConstType]):
2470
+ """
2471
+ Return a tensor of elements selected from either `x` or `y`, depending on `self`.
2472
+ `output_i = x_i if self_i else y_i`.
2473
+
2474
+ ```python exec="true" source="above" session="tensor" result="python"
2475
+ cond = Tensor([[True, True, False], [True, False, False]])
2476
+ print(cond.where(1, 3).numpy())
2477
+ ```
2478
+ ```python exec="true" source="above" session="tensor" result="python"
2479
+ Tensor.manual_seed(42)
2480
+ cond = Tensor.randn(2, 3)
2481
+ print(cond.numpy())
2482
+ ```
2483
+ ```python exec="true" source="above" session="tensor" result="python"
2484
+ print((cond > 0).where(cond, -float("inf")).numpy())
2485
+ ```
2486
+ """
2487
+ if isinstance(x, Tensor): x, y = x._broadcasted(y)
2488
+ elif isinstance(y, Tensor): y, x = y._broadcasted(x)
2489
+ cond, x = self._broadcasted(x, match_dtype=False)
2490
+ cond, y = cond._broadcasted(y, match_dtype=False)
2491
+ return F.Where.apply(cond.cast(dtypes.bool), *x._broadcasted(y))
2492
+
2493
+ def masked_fill(self:Tensor, mask:Tensor, value:Union[Tensor, ConstType]): return mask.where(value, self)
2494
+
2495
+ # ***** op wrappers *****
839
2496
 
840
2497
  def __neg__(self) -> Tensor: return self.neg()
841
2498
 
@@ -846,6 +2503,8 @@ class Tensor:
846
2503
  def __truediv__(self, x) -> Tensor: return self.div(x)
847
2504
  def __matmul__(self, x) -> Tensor: return self.matmul(x)
848
2505
  def __xor__(self, x) -> Tensor: return self.xor(x)
2506
+ def __lshift__(self, x) -> Tensor: return self.lshift(x)
2507
+ def __rshift__(self, x) -> Tensor: return self.rshift(x)
849
2508
 
850
2509
  def __radd__(self, x) -> Tensor: return self.add(x, True)
851
2510
  def __rsub__(self, x) -> Tensor: return self.sub(x, True)
@@ -862,38 +2521,134 @@ class Tensor:
862
2521
  def __itruediv__(self, x) -> Tensor: return self.assign(self.div(x))
863
2522
  def __imatmul__(self, x) -> Tensor: return self.assign(self.matmul(x))
864
2523
  def __ixor__(self, x) -> Tensor: return self.assign(self.xor(x))
2524
+ def __ilshift__(self, x) -> Tensor: return self.assign(self.lshift(x))
2525
+ def __irshift__(self, x) -> Tensor: return self.assign(self.rshift(x))
865
2526
 
866
- def __lt__(self, x) -> Tensor: return mlops.Less.apply(*self._broadcasted(x, False))
867
- def __gt__(self, x) -> Tensor: return mlops.Less.apply(*self._broadcasted(x, True))
2527
+ def __lt__(self, x) -> Tensor: return F.Less.apply(*self._broadcasted(x, False))
2528
+ def __gt__(self, x) -> Tensor: return F.Less.apply(*self._broadcasted(x, True))
868
2529
  def __ge__(self, x) -> Tensor: return (self<x).logical_not()
869
2530
  def __le__(self, x) -> Tensor: return (self>x).logical_not()
870
- def __eq__(self, x) -> Tensor: return mlops.Eq.apply(*self._broadcasted(x, True)) # type: ignore[override]
2531
+ def __eq__(self, x) -> Tensor: return F.Eq.apply(*self._broadcasted(x, True)) # type: ignore[override]
871
2532
  def __ne__(self, x) -> Tensor: return (self==x).logical_not() # type: ignore[override]
872
2533
 
873
2534
  # ***** functional nn ops *****
874
2535
 
875
2536
  def linear(self, weight:Tensor, bias:Optional[Tensor]=None):
2537
+ """
2538
+ Applies a linear transformation to `self` using `weight` and `bias`.
2539
+
2540
+ See: https://pytorch.org/docs/stable/generated/torch.nn.Linear.html
2541
+
2542
+ ```python exec="true" source="above" session="tensor" result="python"
2543
+ t = Tensor([[1, 2], [3, 4]])
2544
+ weight = Tensor([[1, 2], [3, 4]])
2545
+ bias = Tensor([1, 2])
2546
+ print(t.linear(weight, bias).numpy())
2547
+ ```
2548
+ """
876
2549
  x = self.mul(weight) if len(weight.shape) == 1 else self.dot(weight)
877
2550
  return x.add(bias) if bias is not None else x
878
2551
 
879
- def sequential(self, ll:List[Callable[[Tensor], Tensor]]): return reduce(lambda x,f: f(x), ll, self)
2552
+ def sequential(self, ll:List[Callable[[Tensor], Tensor]]):
2553
+ """
2554
+ Applies a sequence of functions to `self`, chaining the output of each function to the input of the next.
2555
+
2556
+ ```python exec="true" source="above" session="tensor" result="python"
2557
+ t = Tensor([1, 2, 3])
2558
+ print(t.sequential([lambda x: x * 2, lambda x: x + 1]).numpy())
2559
+ ```
2560
+ """
2561
+ return functools.reduce(lambda x,f: f(x), ll, self)
880
2562
 
881
2563
  def layernorm(self, axis=-1, eps:float=1e-5) -> Tensor:
2564
+ """
2565
+ Applies Layer Normalization over a mini-batch of inputs.
2566
+
2567
+ - Described: https://paperswithcode.com/method/layer-normalization
2568
+ - Paper: https://arxiv.org/abs/1607.06450v1
2569
+
2570
+ ```python exec="true" source="above" session="tensor" result="python"
2571
+ t = Tensor.randn(8, 10, 16) * 2 + 8
2572
+ print(t.mean().item(), t.std().item())
2573
+ ```
2574
+ ```python exec="true" source="above" session="tensor" result="python"
2575
+ t = t.layernorm()
2576
+ print(t.mean().item(), t.std().item())
2577
+ ```
2578
+ """
882
2579
  y = (self - self.mean(axis, keepdim=True))
883
2580
  return y.mul((y*y).mean(axis, keepdim=True).add(eps).rsqrt())
884
2581
 
885
- def batchnorm(self, weight:Optional[Tensor], bias:Optional[Tensor], mean:Tensor, invstd:Tensor) -> Tensor:
886
- x = (self - mean.reshape(shape=[1, -1, 1, 1]))
887
- if weight: x = x * weight.reshape(shape=[1, -1, 1, 1])
888
- ret = x.mul(invstd.reshape(shape=[1, -1, 1, 1]) if len(invstd.shape) == 1 else invstd)
889
- return (ret + bias.reshape(shape=[1, -1, 1, 1])) if bias else ret
2582
+ def batchnorm(self, weight:Optional[Tensor], bias:Optional[Tensor], mean:Tensor, invstd:Tensor, axis:Union[int,Tuple[int,...]]=1) -> Tensor:
2583
+ """
2584
+ Applies Batch Normalization over a mini-batch of inputs.
2585
+
2586
+ - Described: https://paperswithcode.com/method/batch-normalization
2587
+ - Paper: https://arxiv.org/abs/1502.03167
2588
+
2589
+ ```python exec="true" source="above" session="tensor" result="python"
2590
+ t = Tensor.randn(8, 4, 16, 16) * 2 + 8
2591
+ print(t.mean().item(), t.std().item())
2592
+ ```
2593
+ ```python exec="true" source="above" session="tensor" result="python"
2594
+ t = t.batchnorm(None, None, t.mean(axis=(0,2,3)), t.var(axis=(0,2,3)).add(1e-5).rsqrt())
2595
+ print(t.mean().item(), t.std().item())
2596
+ ```
2597
+ """
2598
+ axis_ = argfix(axis)
2599
+ shape = tuple(s if ax in axis_ else 1 for ax, s in enumerate(self.shape))
2600
+ x = self - mean.reshape(shape)
2601
+ if weight is not None: x = x * weight.reshape(shape)
2602
+ ret = x.mul(invstd.reshape(shape) if len(invstd.shape) == len(axis_) else invstd)
2603
+ return (ret + bias.reshape(shape)) if bias is not None else ret
890
2604
 
891
2605
  def dropout(self, p=0.5) -> Tensor:
892
- if not Tensor.training or p == 0: return self
893
- return self * (Tensor.rand(*self.shape, requires_grad=False, device=self.device) >= p) * (1/(1.0 - p))
2606
+ """
2607
+ Applies dropout to `self`.
2608
+
2609
+ NOTE: dropout is only applied when `Tensor.training` is `True`.
2610
+
2611
+ - Described: https://paperswithcode.com/method/dropout
2612
+ - Paper: https://jmlr.org/papers/v15/srivastava14a.html
894
2613
 
895
- def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Optional[Tensor]=None, dropout_p:float=0.0, is_causal:bool=False) -> Tensor: # noqa: E501
896
- # NOTE: it works if key, value have symbolic shape
2614
+ ```python exec="true" source="above" session="tensor" result="python"
2615
+ Tensor.manual_seed(42)
2616
+ t = Tensor.randn(2, 2)
2617
+ with Tensor.train():
2618
+ print(t.dropout().numpy())
2619
+ ```
2620
+ """
2621
+ if not Tensor.training or p == 0: return self
2622
+ return self * (Tensor.rand(*self.shape, requires_grad=False, dtype=dtypes.default_float, device=self.device) >= p) * (1/(1.0 - p))
2623
+
2624
+ def one_hot(self, num_classes:int) -> Tensor:
2625
+ """
2626
+ Converts `self` to a one-hot tensor.
2627
+
2628
+ ```python exec="true" source="above" session="tensor" result="python"
2629
+ t = Tensor([0, 1, 3, 3, 4])
2630
+ print(t.one_hot(5).numpy())
2631
+ ```
2632
+ """
2633
+ return (self[..., None] == Tensor.arange(num_classes, requires_grad=False, device=self.device)).where(1, 0)
2634
+
2635
+ def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Optional[Tensor]=None,
2636
+ dropout_p:float=0.0, is_causal:bool=False) -> Tensor:
2637
+ """
2638
+ Computes scaled dot-product attention.
2639
+ `self` is the query tensor, `key` is the key tensor, and `value` is the value tensor.
2640
+
2641
+ - Described: https://paperswithcode.com/method/scaled
2642
+ - Paper: https://arxiv.org/abs/1706.03762v7
2643
+
2644
+ ```python exec="true" source="above" session="tensor" result="python"
2645
+ q = Tensor.randn(2, 4, 8)
2646
+ k = Tensor.randn(2, 4, 8)
2647
+ v = Tensor.randn(2, 4, 8)
2648
+ print(q.scaled_dot_product_attention(k, v).numpy())
2649
+ ```
2650
+ """
2651
+ # NOTE: it also works when `key` and `value` have symbolic shape.
897
2652
  assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
898
2653
  if is_causal: attn_mask = Tensor.ones(self.shape[-2], key.shape[-2], requires_grad=False, device=self.device).tril(0).cast(dtypes.bool)
899
2654
  if attn_mask is not None and attn_mask.dtype == dtypes.bool: attn_mask = (attn_mask == 0).where(-float("inf"), 0)
@@ -901,30 +2656,119 @@ class Tensor:
901
2656
  return ((qk+attn_mask) if attn_mask is not None else qk).softmax(-1).dropout(dropout_p) @ value
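Editor's note: written out by hand, the method computes `softmax(q @ k^T / sqrt(d) + mask) @ v` with optional dropout; a sketch comparing against that manual form, assuming the standard `1/sqrt(d)` scaling on the `qk` line just above this hunk:
```python
import math

q, k, v = Tensor.randn(2, 4, 8), Tensor.randn(2, 4, 8), Tensor.randn(2, 4, 8)
manual = ((q @ k.transpose(-2, -1)) / math.sqrt(q.shape[-1])).softmax(-1) @ v
print((manual - q.scaled_dot_product_attention(k, v)).abs().max().numpy())  # ~0
```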
902
2657
 
903
2658
  def binary_crossentropy(self, y:Tensor) -> Tensor:
2659
+ """
2660
+ Computes the binary cross-entropy loss between `self` and `y`.
2661
+
2662
+ See: https://pytorch.org/docs/stable/generated/torch.nn.BCELoss.html
2663
+
2664
+ ```python exec="true" source="above" session="tensor" result="python"
2665
+ t = Tensor([0.1, 0.9, 0.2])
2666
+ y = Tensor([0, 1, 0])
2667
+ print(t.binary_crossentropy(y).item())
2668
+ ```
2669
+ """
904
2670
  return (-y*self.log() - (1-y)*(1-self).log()).mean()
905
2671
 
906
2672
  def binary_crossentropy_logits(self, y:Tensor) -> Tensor:
2673
+ """
2674
+ Computes the binary cross-entropy loss between `self` and `y`, where `self` holds the logits.
2675
+
2676
+ See: https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html
2677
+
2678
+ ```python exec="true" source="above" session="tensor" result="python"
2679
+ t = Tensor([-1, 2, -3])
2680
+ y = Tensor([0, 1, 0])
2681
+ print(t.binary_crossentropy_logits(y).item())
2682
+ ```
2683
+ """
907
2684
  return (self.maximum(0) - y * self + (1 + self.abs().neg().exp()).log()).mean()
908
2685
 
909
- def sparse_categorical_crossentropy(self, Y:Tensor, ignore_index=-1) -> Tensor:
910
- # NOTE: self is a logits input
911
- loss_mask = (Y != ignore_index)
2686
+ def sparse_categorical_crossentropy(self, Y:Tensor, ignore_index=-1, label_smoothing=0.0) -> Tensor:
2687
+ """
2688
+ Computes the sparse categorical cross-entropy loss between `self` and `Y`.
2689
+
2690
+ NOTE: `self` is logits and `Y` is the target labels.
2691
+
2692
+ See: https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
2693
+
2694
+ ```python exec="true" source="above" session="tensor" result="python"
2695
+ t = Tensor([[-1, 2, -3], [1, -2, 3]])
2696
+ Y = Tensor([1, 2])
2697
+ print(t.sparse_categorical_crossentropy(Y).item())
2698
+ ```
2699
+ """
2700
+ assert 0.0 <= label_smoothing <= 1.0, "label_smoothing must be in [0.0, 1.0]"
2701
+ log_probs, loss_mask = self.log_softmax(), (Y != ignore_index)
912
2702
  y_counter = Tensor.arange(self.shape[-1], requires_grad=False, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1])
913
- y = ((y_counter == Y.flatten().reshape(-1, 1)).where(-1, 0) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
914
- return self.log_softmax().mul(y).sum() / loss_mask.sum()
2703
+ y = ((y_counter == Y.flatten().reshape(-1, 1)) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
2704
+ smoothing = label_smoothing * (log_probs.mean(-1) * loss_mask).sum()
2705
+ return -((1 - label_smoothing) * (log_probs * y).sum() + smoothing) / loss_mask.sum()
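Editor's note: with mask `m`, the smoothed loss above is `-[(1-eps) * sum(m * log p[y]) + eps * sum(m * mean_c log p[c])] / sum(m)`; with `eps = 0` it reduces to the plain masked NLL, which this sketch checks using `one_hot` from this same file:
```python
t = Tensor([[-1.0, 2.0, -3.0], [1.0, -2.0, 3.0]])
Y = Tensor([1, 2])
# plain mean NLL of the log-softmax at the target labels
manual = -(t.log_softmax() * Y.one_hot(3)).sum(-1).mean()
print(manual.numpy(), t.sparse_categorical_crossentropy(Y).numpy())  # same value
```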
915
2706
 
916
2707
  # ***** cast ops *****
917
2708
 
918
- def cast(self, dtype:DType) -> Tensor:
919
- if self.dtype == dtype: return self
2709
+ def llvm_bf16_cast(self, dtype:DType):
920
2710
  # hack for devices that don't support bfloat16
921
- if self.dtype == dtypes.bfloat16: return self.bitcast(dtypes.uint16).cast(dtypes.uint32).mul(1<<16).bitcast(dtypes.float32).cast(dtype)
922
- return mlops.Cast.apply(self, dtype=dtype)
2711
+ assert self.dtype == dtypes.bfloat16
2712
+ return self.to("LLVM").bitcast(dtypes.uint16).cast(dtypes.uint32).mul(1<<16).bitcast(dtypes.float32).cast(dtype)
2713
+ def cast(self, dtype:DType) -> Tensor:
2714
+ """
2715
+ Casts `self` to the given `dtype`.
2716
+
2717
+ ```python exec="true" source="above" session="tensor" result="python"
2718
+ t = Tensor([-1, 2.5, 3], dtype=dtypes.float)
2719
+ print(t.dtype, t.numpy())
2720
+ ```
2721
+ ```python exec="true" source="above" session="tensor" result="python"
2722
+ t = t.cast(dtypes.int32)
2723
+ print(t.dtype, t.numpy())
2724
+ ```
2725
+ """
2726
+ return self if self.dtype == dtype else F.Cast.apply(self, dtype=dtype)
923
2727
  def bitcast(self, dtype:DType) -> Tensor:
924
- assert self.dtype.itemsize == dtype.itemsize, "can't bitcast mismatched dtype itemsizes"
925
- return mlops.Cast.apply(self, dtype=dtype, bitcast=True) if self.dtype != dtype else self
926
- def float(self) -> Tensor: return self.cast(dtypes.float32)
927
- def half(self) -> Tensor: return self.cast(dtypes.float16)
2728
+ """
2729
+ Bitcasts `self` to the given `dtype` of the same itemsize.
2730
+
2731
+ `self` must not require a gradient.
2732
+
2733
+ ```python exec="true" source="above" session="tensor" result="python"
2734
+ t = Tensor([-1, 2, 3], dtype=dtypes.int32)
2735
+ print(t.dtype, t.numpy())
2736
+ ```
2737
+ ```python exec="true" source="above" session="tensor" result="python"
2738
+ t = t.bitcast(dtypes.uint32)
2739
+ print(t.dtype, t.numpy())
2740
+ ```
2741
+ """
2742
+ if self.requires_grad: raise RuntimeError("can't backprop through bitcast")
2743
+ return F.Cast.apply(self, dtype=dtype, bitcast=True) if self.dtype != dtype else self
2744
+ def float(self) -> Tensor:
2745
+ """
2746
+ Convenience method to cast `self` to a `float32` Tensor.
2747
+
2748
+ ```python exec="true" source="above" session="tensor" result="python"
2749
+ t = Tensor([-1, 2, 3], dtype=dtypes.int32)
2750
+ print(t.dtype, t.numpy())
2751
+ ```
2752
+ ```python exec="true" source="above" session="tensor" result="python"
2753
+ t = t.float()
2754
+ print(t.dtype, t.numpy())
2755
+ ```
2756
+ """
2757
+ return self.cast(dtypes.float32)
2758
+ def half(self) -> Tensor:
2759
+ """
2760
+ Convenience method to cast `self` to a `float16` Tensor.
2761
+
2762
+ ```python exec="true" source="above" session="tensor" result="python"
2763
+ t = Tensor([-1, 2, 3], dtype=dtypes.int32)
2764
+ print(t.dtype, t.numpy())
2765
+ ```
2766
+ ```python exec="true" source="above" session="tensor" result="python"
2767
+ t = t.half()
2768
+ print(t.dtype, t.numpy())
2769
+ ```
2770
+ """
2771
+ return self.cast(dtypes.float16)
928
2772
 
929
2773
  # ***** convenience stuff *****
930
2774
 
@@ -934,20 +2778,101 @@ class Tensor:
934
2778
  def element_size(self) -> int: return self.dtype.itemsize
935
2779
  def nbytes(self) -> int: return self.numel() * self.element_size()
936
2780
  def is_floating_point(self) -> bool: return dtypes.is_float(self.dtype)
2781
+ def size(self, dim=None) -> Union[sint, Tuple[sint, ...]]: return self.shape if dim is None else self.shape[dim]
2782
+
2783
+ # *** image Tensor function replacements ***
2784
+
2785
+ def image_dot(self, w:Tensor, acc_dtype=None):
2786
+ # NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
2787
+ n1, n2 = len(self.shape), len(w.shape)
2788
+ assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D"
2789
+ assert self.shape[-1] == w.shape[-min(n2, 2)], f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({self.shape[-1]} != {w.shape[-min(n2, 2)]})" # noqa: E501
2790
+ bs, groups, cin, cout = prod(self.shape[0:-2]), prod(w.shape[0:-2]), w.shape[-2], w.shape[-1]
2791
+ out_shape_t = self.shape[0:-2] + (cout,-1) if len(self.shape) > 1 else (cout, )
2792
+
2793
+ # NOTE: with NHWC we can remove the transposes
2794
+ # bs x groups*cin x H x W
2795
+ cx = self.transpose(self.ndim-1, self.ndim-2).reshape((bs//groups, groups*cin, -1, 1))
2796
+ # groups*cout x cin x H, W
2797
+ cw = w.transpose(w.ndim-1, w.ndim-2).reshape((groups*cout, cin, 1, 1))
2798
+ return cx.image_conv2d(cw, groups=groups, acc_dtype=acc_dtype).reshape(out_shape_t).transpose(self.ndim-1, self.ndim-2)
2799
+
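Editor's note: the NOTE above relies on a 1x1 convolution being a pure channel mix, so an `m x k @ k x n` matmul can be phrased as a conv over `k` input channels; a plain-NumPy check of that identity via `einsum`, independent of tinygrad:
```python
import numpy as np

m, k, n = 3, 4, 5
A, B = np.random.randn(m, k), np.random.randn(k, n)
x = A.T.reshape(1, k, m, 1)   # (bs=1, cin=k, H=m, W=1)
w = B.T.reshape(n, k, 1, 1)   # (cout=n, cin=k, kH=1, kW=1)
# a 1x1 conv: out[b, co, h, w] = sum_ci w[co, ci] * x[b, ci, h, w]
out = np.einsum('oc,bchw->bohw', w[:, :, 0, 0], x)
assert np.allclose(out[0, :, :, 0].T, A @ B)
```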
2800
+ def image_conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, acc_dtype=None):
2801
+ base_image_type = dtypes.imageh if getenv("FLOAT16", 0) else dtypes.imagef
2802
+
2803
+ (bs,_,iy,ix), (cout,cin,H,W) = self.shape, weight.shape
2804
+ x, w = self, weight.reshape(groups, (rcout := cout//groups), cin, H, W)
2805
+
2806
+ # hack for non multiples of 4 on cin
2807
+ if cin % 4 != 0 and not (cin == 1 and groups%4 == 0):
2808
+ x = x.reshape(bs, groups, cin, iy, ix) # do this always?
2809
+ added_input_channels = 4 - (cin % 4)
2810
+ w = w.pad(tuple((0, added_input_channels) if i == 2 else None for i in range(w.ndim)))
2811
+ x = x.pad(tuple((0, added_input_channels) if i == 2 else None for i in range(x.ndim)))
2812
+ cin = cin + added_input_channels
2813
+ x = x.reshape(bs, groups*cin, iy, ix)
2814
+
2815
+ # hack for non multiples of 4 on rcout
2816
+ added_output_channels = 0
2817
+ if rcout % 4 != 0 and not (rcout == 1 and groups%4 == 0):
2818
+ added_output_channels = 4 - (rcout % 4)
2819
+ rcout += added_output_channels
2820
+ cout = groups * rcout
2821
+ w = w.pad(tuple((0, added_output_channels) if i == 1 else None for i in range(w.ndim)))
2822
+
2823
+ # packed (note: flipping bs and iy would make the auto-padding work)
2824
+ x = x.permute(0,2,3,1)
2825
+ cin_last = iy == 1 and ix == 1
2826
+ if cin == 1: w = w.reshape(cout//4,4,H,W).permute(0,2,3,1)
2827
+ elif cin_last: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,1,3)
2828
+ else: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,3,1)
2829
+
2830
+ # contiguous creates the image and early-realizes the static weights (TODO: test for the static weight)
2831
+ if IMAGE >= 2: x,w = x.cast(base_image_type((bs*iy, ix*groups*cin//4, 4))), w.cast(base_image_type((cout//4, H*W*cin, 4)))
2832
+ x, w = x.contiguous(), w.contiguous()
2833
+
2834
+ # expand out
2835
+ rcin_hi, rcin_lo = cin//4 if cin >= 4 else 1, 4 if cin >= 4 else 1
2836
+ cout_expand = [groups//4 if cin == 1 else groups, 4 if cin == 1 else 1, rcout//4 if rcout >= 4 else 1, 4 if rcout >= 4 else 1]
2837
+ x = x.reshape(bs, iy, ix, groups, rcin_hi, rcin_lo)
2838
+ if cin_last: w = w.reshape(cout//4, H, rcin_hi, W, 4, rcin_lo)
2839
+ else: w = w.reshape(cout//4, H, rcin_hi, W, rcin_lo, 4).permute(0,1,2,3,5,4)
2840
+
2841
+ # padding
2842
+ padding_ = [padding]*4 if isinstance(padding, int) else (padding if len(padding) == 4 else [padding[1], padding[1], padding[0], padding[0]])
2843
+ x = x._slice((None, (-padding_[2], x.shape[1]+padding_[3]), (-padding_[0], x.shape[2]+padding_[1]), None, None, None))
2844
+
2845
+ # prepare input
2846
+ x = x.permute(0,3,4,5,1,2)._pool((H, W), stride, dilation) # -> (bs, groups, rcin_hi, rcin_lo, oy, ox, H, W)
2847
+ x = x.permute(0,4,5,1,2,3,6,7).reshape(bs, (oy := x.shape[4]), (ox := x.shape[5]), *cout_expand[0:2], 1, 1, rcin_hi, rcin_lo, H, W)
2848
+
2849
+ # prepare weights
2850
+ w = w.permute(0,4,2,5,1,3).reshape((1, 1, 1, *cout_expand, rcin_hi, rcin_lo, H, W))
2851
+
2852
+ # the conv!
2853
+ ret = (x*w).cast(base_image_type((bs*oy, ox*cout//4, 4)) if IMAGE >= 2 else dtypes.float32).sum((-4, -3, -2, -1), acc_dtype=acc_dtype)
2854
+
2855
+ # undo hack for non multiples of 4 on C.rcout
2856
+ if added_output_channels != 0:
2857
+ ret = ret.reshape(bs, oy, ox, groups, rcout)[:, :, :, :, :-added_output_channels]
2858
+ cout = groups * (rcout - added_output_channels)
2859
+
2860
+ # NCHW output
2861
+ ret = ret.reshape(bs, oy, ox, cout).permute(0,3,1,2)
2862
+ return ret if bias is None else ret.add(bias.reshape(1, -1, 1, 1))
937
2863
 
938
2864
  # register functions to move between devices
939
- for device in Device._devices: setattr(Tensor, f"{device.lower()}", partialmethod(Tensor.to, device))
2865
+ for device in Device._devices: setattr(Tensor, f"{device.lower()}", functools.partialmethod(Tensor.to, device))
940
2866
 
941
2867
  if IMAGE:
942
2868
  # if IMAGE>0 we install these replacement functions in Tensor (hack!)
943
- from tinygrad.features.image import image_conv2d, image_dot
944
- setattr(Tensor, "conv2d", image_conv2d)
945
- setattr(Tensor, "dot", image_dot)
2869
+ setattr(Tensor, "conv2d", Tensor.image_conv2d)
2870
+ setattr(Tensor, "dot", Tensor.image_dot)
946
2871
 
947
- # TODO: remove the custom op and replace with threefry
2872
+ # TODO: eventually remove this
948
2873
  def custom_random(out:Buffer):
949
2874
  Tensor._seed += 1
950
- if DEBUG >= 2: print(f"*** {out.device} rand seed {Tensor._seed} size {out.size:<15d} dtype {out.dtype}")
951
2875
  rng = np.random.default_rng(Tensor._seed)
952
- rng_np_buffer = rng.random(size=out.size, dtype=np.float32).astype(dtype=out.dtype.np, copy=False)
2876
+ if out.dtype == dtypes.half: rng_np_buffer = (rng.integers(low=0, high=2047, size=out.size) / 2048).astype(np.half, copy=False)
2877
+ else: rng_np_buffer = rng.random(size=out.size, dtype=np.float32).astype(dtype=out.dtype.np, copy=False)
953
2878
  out.copyin(rng_np_buffer.data)