tinygrad 0.7.0__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77) hide show
  1. tinygrad/__init__.py +6 -0
  2. tinygrad/codegen/kernel.py +572 -83
  3. tinygrad/codegen/linearizer.py +415 -395
  4. tinygrad/codegen/uops.py +415 -0
  5. tinygrad/device.py +183 -0
  6. tinygrad/dtype.py +113 -0
  7. tinygrad/engine/__init__.py +0 -0
  8. tinygrad/engine/graph.py +100 -0
  9. tinygrad/engine/jit.py +195 -0
  10. tinygrad/engine/realize.py +191 -0
  11. tinygrad/engine/schedule.py +362 -0
  12. tinygrad/engine/search.py +196 -0
  13. tinygrad/{mlops.py → function.py} +76 -55
  14. tinygrad/helpers.py +196 -89
  15. tinygrad/lazy.py +210 -371
  16. tinygrad/multi.py +169 -0
  17. tinygrad/nn/__init__.py +202 -22
  18. tinygrad/nn/datasets.py +7 -0
  19. tinygrad/nn/optim.py +112 -32
  20. tinygrad/nn/state.py +136 -39
  21. tinygrad/ops.py +119 -202
  22. tinygrad/renderer/__init__.py +61 -0
  23. tinygrad/renderer/assembly.py +276 -0
  24. tinygrad/renderer/cstyle.py +353 -166
  25. tinygrad/renderer/llvmir.py +150 -138
  26. tinygrad/runtime/autogen/amd_gpu.py +1900 -0
  27. tinygrad/runtime/autogen/comgr.py +865 -0
  28. tinygrad/runtime/autogen/cuda.py +5923 -0
  29. tinygrad/runtime/autogen/hip.py +5909 -0
  30. tinygrad/runtime/autogen/hsa.py +5761 -0
  31. tinygrad/runtime/autogen/kfd.py +812 -0
  32. tinygrad/runtime/autogen/nv_gpu.py +33328 -0
  33. tinygrad/runtime/autogen/opencl.py +1795 -0
  34. tinygrad/runtime/driver/hip_comgr.py +47 -0
  35. tinygrad/runtime/driver/hsa.py +143 -0
  36. tinygrad/runtime/graph/clang.py +38 -0
  37. tinygrad/runtime/graph/cuda.py +81 -0
  38. tinygrad/runtime/graph/hcq.py +143 -0
  39. tinygrad/runtime/graph/hsa.py +171 -0
  40. tinygrad/runtime/graph/metal.py +75 -0
  41. tinygrad/runtime/ops_amd.py +564 -0
  42. tinygrad/runtime/ops_clang.py +24 -77
  43. tinygrad/runtime/ops_cuda.py +175 -89
  44. tinygrad/runtime/ops_disk.py +56 -33
  45. tinygrad/runtime/ops_gpu.py +92 -95
  46. tinygrad/runtime/ops_hsa.py +278 -0
  47. tinygrad/runtime/ops_llvm.py +39 -60
  48. tinygrad/runtime/ops_metal.py +92 -74
  49. tinygrad/runtime/ops_npy.py +9 -0
  50. tinygrad/runtime/ops_nv.py +630 -0
  51. tinygrad/runtime/ops_python.py +204 -0
  52. tinygrad/shape/shapetracker.py +86 -254
  53. tinygrad/shape/symbolic.py +166 -141
  54. tinygrad/shape/view.py +296 -0
  55. tinygrad/tensor.py +2619 -448
  56. {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
  57. tinygrad-0.9.0.dist-info/METADATA +227 -0
  58. tinygrad-0.9.0.dist-info/RECORD +60 -0
  59. {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
  60. tinygrad/codegen/assembly.py +0 -190
  61. tinygrad/codegen/optimizer.py +0 -379
  62. tinygrad/codegen/search.py +0 -72
  63. tinygrad/graph.py +0 -83
  64. tinygrad/jit.py +0 -57
  65. tinygrad/nn/image.py +0 -100
  66. tinygrad/renderer/assembly_arm64.py +0 -169
  67. tinygrad/renderer/assembly_ptx.py +0 -98
  68. tinygrad/renderer/wgsl.py +0 -53
  69. tinygrad/runtime/lib.py +0 -113
  70. tinygrad/runtime/ops_cpu.py +0 -51
  71. tinygrad/runtime/ops_hip.py +0 -82
  72. tinygrad/runtime/ops_shm.py +0 -29
  73. tinygrad/runtime/ops_torch.py +0 -30
  74. tinygrad/runtime/ops_webgpu.py +0 -45
  75. tinygrad-0.7.0.dist-info/METADATA +0 -212
  76. tinygrad-0.7.0.dist-info/RECORD +0 -40
  77. {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
tinygrad/tensor.py CHANGED
@@ -1,18 +1,26 @@
1
1
  # inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
2
2
  from __future__ import annotations
3
- import time, math
4
- from functools import partialmethod, reduce
5
- from itertools import accumulate
3
+ import time, math, itertools, functools
4
+ from contextlib import ContextDecorator
5
+ from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Dict, DefaultDict, cast, get_args, Set
6
+ from collections import defaultdict
6
7
  import numpy as np
7
- from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence
8
8
 
9
- from tinygrad.helpers import ImageDType, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes
9
+ from tinygrad.dtype import DType, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype
10
+ from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, fully_flatten, argsort, getenv
11
+ from tinygrad.helpers import IMAGE, DEBUG, WINO, THREEFRY
10
12
  from tinygrad.lazy import LazyBuffer
11
- from tinygrad.ops import Device, LoadOps
13
+ from tinygrad.multi import MultiLazyBuffer
14
+ from tinygrad.ops import LoadOps
15
+ from tinygrad.device import Device, Buffer, BufferOptions
16
+ from tinygrad.shape.symbolic import sint, Variable, MulNode, Node
17
+ from tinygrad.engine.realize import run_schedule
18
+ from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars, memory_planner
19
+
20
+ # **** start with two base classes, Tensor and Function ****
12
21
 
13
- # An instantiation of the Function is the Context
14
22
  class Function:
15
- def __init__(self, device:str, *tensors:Tensor):
23
+ def __init__(self, device:Union[str, Tuple[str, ...]], *tensors:Tensor):
16
24
  self.device = device
17
25
  self.needs_input_grad = [t.requires_grad for t in tensors]
18
26
  self.requires_grad = True if any(self.needs_input_grad) else None if None in self.needs_input_grad else False
@@ -24,23 +32,71 @@ class Function:
24
32
  @classmethod
25
33
  def apply(fxn:Type[Function], *x:Tensor, **kwargs) -> Tensor:
26
34
  ctx = fxn(x[0].device, *x)
27
- ret = Tensor(ctx.forward(*[t.lazydata for t in x], **kwargs), device=ctx.device, requires_grad=ctx.requires_grad)
28
- if ctx.requires_grad and not Tensor.no_grad: ret._ctx = ctx # used by autograd engine
35
+ ret = Tensor.__new__(Tensor)
36
+ ret.lazydata, ret.requires_grad, ret.grad = ctx.forward(*[t.lazydata for t in x], **kwargs), ctx.requires_grad, None
37
+ ret._ctx = ctx if ctx.requires_grad and not Tensor.no_grad else None # used by autograd engine
29
38
  return ret
30
39
 
31
- import tinygrad.mlops as mlops
32
-
33
- # **** start with two base classes, Tensor and Function ****
40
+ import tinygrad.function as F
41
+
42
+ def _loadop(op, shape:Tuple[sint,...], dtype:DType, device:Union[str, Tuple[str, ...]], arg=None, src:Tuple[LazyBuffer, ...]=()):
43
+ if isinstance(device, str): return LazyBuffer.loadop(op, shape, dtype, device, arg, src)
44
+ return MultiLazyBuffer([LazyBuffer.loadop(op, shape, dtype, d, arg, src) for d in device], None)
45
+
46
+ def _fromcpu(x: np.ndarray) -> LazyBuffer:
47
+ ret = LazyBuffer.loadop(LoadOps.EMPTY, x.shape, dtypes.from_np(x.dtype), "NPY")
48
+ # fake realize
49
+ ret.buffer.allocate(x)
50
+ del ret.srcs
51
+ return ret
52
+
53
+ def _get_winograd_matcols(mat, dims:int, shp:Tuple[sint, ...], device:Union[str, Tuple[str, ...]]) -> List[List[Tensor]]:
54
+ return [[Tensor.cat(*[Tensor.full(shp[:dim] + (1,) + shp[dim+1:], float(m[k]), device=device) for m in mat], dim=dim)
55
+ for k in range(len(mat[0]))] for dim in range(dims)]
56
+
57
+ # winograd conv 3 kernel f(4x4,3x3) see: http://arxiv.org/abs/1509.09308
58
+ def _apply_winograd_matrix(mat, t:Tensor, dims:int) -> Tensor:
59
+ # multiply mat_1 @ mat_2 @ t with foldable constants, where mat_i acts on vector t along dimension i; roughly kron(mat, mat) @ t
60
+ # due to realize-before-expand rule in lazy.py, we must operate in this order: reshape -> expand -> arithmetic
61
+ t_ = t.reshape(t.shape[:dims] + (1,) * dims + t.shape[dims:]).expand(t.shape[:dims] + (len(mat),) * dims + t.shape[dims:]) # add output dims
62
+ # precalculate mat columns for each dim; prod(itertools.product(matcols)) gives the columns of kron(mat, mat, ...)
63
+ matcols = _get_winograd_matcols(mat, dims, t_.shape[dims:], t_.device)
64
+ # multiply each element of t_ by the corresponding stacked column of kron(mat, mat), producing only one view for each element of t
65
+ ret = sum(prod(col[idx] for col, idx in zip(matcols, mat_is)) * t_[mat_is] for mat_is in itertools.product(range(len(mat[0])), repeat=dims))
66
+ assert isinstance(ret, Tensor), "sum didn't return a Tensor"
67
+ return ret
68
+
69
+ def _pad_left(*shps:Tuple[sint, ...], v=1): return tuple((v,) * (max(len(i_) for i_ in shps) - len(i)) + i for i in shps)
70
+ def _broadcast_shape(*shps:Tuple[sint, ...]): return tuple(0 if any(sh_ == 0 for sh_ in sh) else max(sh) for sh in zip(*_pad_left(*shps)))
34
71
 
35
72
  class Tensor:
73
+ """
74
+ A `Tensor` is a multi-dimensional matrix containing elements of a single data type.
75
+
76
+ ```python exec="true" session="tensor"
77
+ from tinygrad import Tensor, dtypes, nn
78
+ import numpy as np
79
+ import math
80
+ np.set_printoptions(precision=4)
81
+ ```
82
+ """
36
83
  __slots__ = "lazydata", "requires_grad", "grad", "_ctx"
37
84
  __deletable__ = ('_ctx',)
38
85
  training: ClassVar[bool] = False
86
+ class train(ContextDecorator):
87
+ def __init__(self, mode:bool = True): self.mode = mode
88
+ def __enter__(self): self.prev, Tensor.training = Tensor.training, self.mode
89
+ def __exit__(self, exc_type, exc_value, traceback): Tensor.training = self.prev
90
+
39
91
  no_grad: ClassVar[bool] = False
40
- default_type: ClassVar[DType] = dtypes.float32
41
- def __init__(self, data:Union[int, float, list, LazyBuffer, np.ndarray], device:Optional[str]=None, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
92
+ class inference_mode(ContextDecorator):
93
+ def __init__(self, mode:bool = True): self.mode = mode
94
+ def __enter__(self): self.prev, Tensor.no_grad = Tensor.no_grad, self.mode
95
+ def __exit__(self, exc_type, exc_value, traceback): Tensor.no_grad = self.prev
96
+ def __init__(self, data:Union[None, ConstType, List, Tuple, LazyBuffer, np.ndarray, bytes, MultiLazyBuffer, Variable],
97
+ device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
42
98
  assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
43
- device = Device.canonicalize(device)
99
+ device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device)
44
100
  # tensors have gradients, buffers do not
45
101
  self.grad: Optional[Tensor] = None
46
102
 
@@ -51,169 +107,597 @@ class Tensor:
51
107
  # internal variables used for autograd graph construction
52
108
  self._ctx: Optional[Function] = None
53
109
  if isinstance(data, LazyBuffer): assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
54
- elif isinstance(data, (int, float)):
55
- self.lazydata = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or Tensor.default_type, device, data)
56
- return
57
- elif data.__class__ is list:
58
- assert dtype is None or dtype.np is not None, f"{dtype} doesn't have a numpy dtype"
59
- data = LazyBuffer.fromCPU(np.array(data, dtype=(dtype or Tensor.default_type).np))
110
+ elif isinstance(data, get_args(ConstType)): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_py(data), device, data)
111
+ elif isinstance(data, Variable): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_py(data.unbind()[1]), device, data)
112
+ elif isinstance(data, bytes): data = _fromcpu(np.frombuffer(data, np.uint8))
113
+ elif data is None: data = _loadop(LoadOps.EMPTY, (0,), dtype or dtypes.default_float, device)
114
+ elif isinstance(data, list):
115
+ if dtype is None:
116
+ if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtypes.bool
117
+ else: dtype = dtypes.default_int if d and all_int(d) else dtypes.default_float
118
+ if dtype == dtypes.bfloat16: data = Tensor(_fromcpu(np.array(data, np.float32)), device=device).cast(dtypes.bfloat16).lazydata
119
+ else: data = _fromcpu(np.array(data, dtype.np))
60
120
  elif isinstance(data, np.ndarray):
61
- data = LazyBuffer.fromCPU(data)
62
- else: raise RuntimeError(f"can't create Tensor from {data}")
121
+ if data.shape == (): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_np(data.dtype), device, data.item())
122
+ else: data = _fromcpu(data.astype(dtype.np) if dtype is not None and dtype.np is not None else data)
63
123
 
64
- self.lazydata = data if data.device == device else LazyBuffer.loadop(LoadOps.FROM, data.shape, data.dtype, device, src=data)
124
+ # data is a LazyBuffer, but it might be on the wrong device
125
+ if not isinstance(data, (LazyBuffer, MultiLazyBuffer)): raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}")
126
+ if isinstance(device, tuple):
127
+ # TODO: what if it's a MultiLazyBuffer on other devices?
128
+ self.lazydata: Union[LazyBuffer, MultiLazyBuffer] = MultiLazyBuffer.from_sharded(data, device, None) if isinstance(data, LazyBuffer) else data
129
+ else:
130
+ self.lazydata = data if data.device == device else data.copy_to_device(device)
65
131
 
66
- def __repr__(self):
67
- return f"<Tensor {self.lazydata!r} on {self.device} with grad {(self.grad.lazydata if self.grad else None)!r}>"
132
+ def __repr__(self): return f"<Tensor {self.lazydata!r} on {self.device} with grad {(self.grad.lazydata if self.grad is not None else None)!r}>"
68
133
 
69
134
  # Python has a non moving GC, so this should be okay
70
135
  def __hash__(self): return id(self)
71
136
 
137
+ def __bool__(self): raise TypeError("__bool__ on Tensor is not defined")
138
+
139
+ def __len__(self): return self.shape[0] if len(self.shape) else 1
140
+
72
141
  @property
73
- def device(self) -> str: return self.lazydata.device
142
+ def device(self) -> Union[str, Tuple[str, ...]]: return self.lazydata.device
74
143
 
75
144
  @property
76
- def shape(self) -> Tuple[int, ...]: return self.lazydata.shape
145
+ def shape(self) -> Tuple[sint, ...]: return self.lazydata.shape
77
146
 
78
147
  @property
79
148
  def dtype(self) -> DType: return self.lazydata.dtype
80
149
 
81
150
  # ***** data handlers ****
82
151
 
83
- def realize(self) -> Tensor:
84
- self.lazydata.realize()
152
+ def schedule_with_vars(self, *lst:Tensor, seen:Optional[Set[LazyBuffer]]=None) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
153
+ """Creates the schedule needed to realize these Tensor(s), with Variables."""
154
+ if getenv("FUZZ_SCHEDULE"):
155
+ from test.external.fuzz_schedule import fuzz_schedule
156
+ fuzz_schedule(flatten([x.lazydata.lbs for x in (self,)+lst]))
157
+ schedule, var_vals = create_schedule_with_vars(flatten([x.lazydata.lbs for x in (self,)+lst]), seen)
158
+ return memory_planner(schedule), var_vals
159
+
160
+ def schedule(self, *lst:Tensor, seen:Optional[Set[LazyBuffer]]=None) -> List[ScheduleItem]:
161
+ """Creates the schedule needed to realize these Tensor(s)."""
162
+ schedule, var_vals = self.schedule_with_vars(*lst, seen=seen)
163
+ assert len(var_vals) == 0
164
+ return schedule
165
+
166
+ def realize(self, *lst:Tensor, do_update_stats=True) -> Tensor:
167
+ """Triggers the computation needed to create these Tensor(s)."""
168
+ run_schedule(*self.schedule_with_vars(*lst), do_update_stats=do_update_stats)
169
+ return self
170
+
171
+ def replace(self, x:Tensor) -> Tensor:
172
+ """
173
+ Replaces the data of this tensor with the data of another tensor. Only the shape of the tensors must match.
174
+ """
175
+ # used for replacing a Tensor with a new version of it (potentially with a different device and dtype)
176
+ assert not x.requires_grad and getattr(self, '_ctx', None) is None
177
+ assert self.shape == x.shape, f"replace shape mismatch {self.shape} != {x.shape}"
178
+ self.lazydata = x.lazydata
85
179
  return self
86
180
 
87
181
  def assign(self, x) -> Tensor:
88
- # TODO: this is a hack for writing to DISK
89
- if self.device.startswith("DISK"):
90
- if x.__class__ is not Tensor: x = Tensor(x, device="CPU", dtype=self.dtype)
91
- self.lazydata.realize().realized._copyin(x.numpy()) # type: ignore
182
+ # TODO: this is a hack for writing to DISK. remove with working assign
183
+ if isinstance(self.device, str) and self.device.startswith("DISK"):
184
+ if x.__class__ is not Tensor: x = Tensor(x, device="NPY", dtype=self.dtype)
185
+ self.contiguous().realize().lazydata.base.realized.copyin(x.numpy().data)
92
186
  return self
93
187
  if x.__class__ is not Tensor: x = Tensor(x, device=self.device, dtype=self.dtype)
94
- assert self.shape == x.shape and self.device == x.device, f"assign shape mismatch {self.shape} != {x.shape} or device mismatch {self.device} != {x.device}"
95
- assert not x.requires_grad # self requires_grad is okay?
96
188
  if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}")
97
- if self.lazydata.realized is not None and not getenv("DISALLOW_ASSIGN"): x.lazydata.output_buffer = self.lazydata.realized
98
- self.lazydata = x.lazydata
189
+ if self.lazydata is x.lazydata: return self # a self assign is a NOOP
190
+ # NOTE: we allow cross device assign
191
+ assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}"
192
+ assert self.device == x.device, f"assign device mismatch {self.device} != {x.device}"
193
+ assert self.dtype == x.dtype, f"assign dtype mismatch {self.dtype} != {x.dtype}"
194
+ assert not isinstance(self.lazydata, MultiLazyBuffer) or self.lazydata.axis == x.lazydata.axis, "axis must match on MultiLazyBuffer"
195
+ assert not x.requires_grad # self requires_grad is okay?
196
+ if not self.lazydata.is_realized(): return self.replace(x)
197
+ self.lazydata = self.lazydata.assign(x.lazydata)
99
198
  return self
199
+ def detach(self) -> Tensor:
200
+ """
201
+ Returns a new tensor with the same data as this tensor, but detached from the autograd graph.
202
+ """
203
+ return Tensor(self.lazydata, device=self.device, requires_grad=False)
204
+
205
+ def _data(self) -> memoryview:
206
+ if 0 in self.shape: return memoryview(bytearray(0))
207
+ # NOTE: this realizes on the object from as_buffer being a Python object
208
+ cpu = self.cast(self.dtype.scalar()).contiguous().to("CLANG").realize()
209
+ buf = cast(Buffer, cast(LazyBuffer, cpu.lazydata).base.realized)
210
+ if self.device != "CLANG": buf.options = BufferOptions(nolru=True)
211
+ return buf.as_buffer(allow_zero_copy=True if self.device != "CLANG" else False)
212
+
213
+ def data(self) -> memoryview:
214
+ """
215
+ Returns the data of this tensor as a memoryview.
216
+
217
+ ```python exec="true" source="above" session="tensor" result="python"
218
+ t = Tensor([1, 2, 3, 4])
219
+ print(np.frombuffer(t.data(), dtype=np.int32))
220
+ ```
221
+ """
222
+ assert self.dtype.fmt is not None, f"no fmt dtype for {self.dtype}"
223
+ assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
224
+ return self._data().cast(self.dtype.fmt, self.shape)
225
+
226
+ def item(self) -> ConstType:
227
+ """
228
+ Returns the value of this tensor as a standard Python number.
229
+
230
+ ```python exec="true" source="above" session="tensor" result="python"
231
+ t = Tensor(42)
232
+ print(t.item())
233
+ ```
234
+ """
235
+ assert self.dtype.fmt is not None, f"no fmt dtype for {self.dtype}"
236
+ assert self.numel() == 1, "must have one element for item"
237
+ return self._data().cast(self.dtype.fmt)[0]
238
+
239
+ # TODO: should be Tensor.tolist() -> Union[List[ConstType], ConstType]. The List is Sequence because mypy expects memoryview.tolist() -> list[int]
240
+ # src: https://github.com/python/mypy/blob/release-1.6/mypy/typeshed/stdlib/builtins.pyi#L803
241
+ def tolist(self) -> Union[Sequence[ConstType], ConstType]:
242
+ """
243
+ Returns the value of this tensor as a nested list.
244
+
245
+ ```python exec="true" source="above" session="tensor" result="python"
246
+ t = Tensor([1, 2, 3, 4])
247
+ print(t.tolist())
248
+ ```
249
+ """
250
+ return self.data().tolist()
251
+
252
+ def numpy(self) -> np.ndarray:
253
+ """
254
+ Returns the value of this tensor as a `numpy.ndarray`.
255
+
256
+ ```python exec="true" source="above" session="tensor" result="python"
257
+ t = Tensor([1, 2, 3, 4])
258
+ print(repr(t.numpy()))
259
+ ```
260
+ """
261
+ if self.dtype == dtypes.bfloat16: return self.float().numpy()
262
+ assert self.dtype.np is not None, f"no np dtype for {self.dtype}"
263
+ assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
264
+ return np.frombuffer(self._data(), dtype=self.dtype.np).reshape(self.shape)
265
+
266
+ def to(self, device:Optional[Union[str, Tuple[str, ...]]]) -> Tensor:
267
+ """
268
+ Moves the tensor to the given device.
269
+ """
270
+ device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device)
271
+ if device == self.device: return self
272
+ if not isinstance(device, str): return self.shard(device)
273
+ ret = Tensor(self.lazydata, device, requires_grad=self.requires_grad)
274
+ if self.grad is not None: ret.grad = self.grad.to(device)
275
+ if hasattr(self, '_ctx'): ret._ctx = self._ctx
276
+ return ret
100
277
 
101
- def detach(self): return Tensor(self.lazydata, device=self.device, requires_grad=False)
102
- def numpy(self) -> np.ndarray: return self.lazydata.toCPU()
103
-
104
- # TODO: if things are realized this won't work
105
- def to_(self, device:str):
106
- assert self.lazydata.realized is None
107
- self.lazydata.device = device
108
- if self.grad: self.grad.to_(device)
278
+ def to_(self, device:Optional[Union[str, Tuple[str, ...]]]):
279
+ """
280
+ Moves the tensor to the given device in place.
281
+ """
282
+ real = self.to(device)
283
+ # TODO: is this assign?
284
+ if self.grad is not None and real.grad is not None: self.grad.lazydata = real.grad.lazydata
285
+ self.lazydata = real.lazydata
286
+
287
+ def shard(self, devices:Tuple[str, ...], axis:Optional[int]=None) -> Tensor:
288
+ """
289
+ Shards the tensor across the given devices.
290
+ """
291
+ assert isinstance(self.lazydata, LazyBuffer), "can't shard a MultiLazyBuffer"
292
+ canonical_devices = tuple(Device.canonicalize(x) for x in devices)
293
+ if axis is not None and axis < 0: axis += len(self.shape)
294
+ return Tensor(MultiLazyBuffer.from_sharded(self.lazydata, canonical_devices, axis), device=canonical_devices, requires_grad=self.requires_grad)
295
+
296
+ def shard_(self, devices:Tuple[str, ...], axis:Optional[int]=None):
297
+ """
298
+ Shards the tensor across the given devices in place.
299
+ """
300
+ self.lazydata = self.shard(devices, axis).lazydata
301
+ return self
109
302
 
110
- def to(self, device:str) -> Tensor:
111
- ret = Tensor(self.lazydata, device)
112
- if self.grad: ret.grad = self.grad.to(device)
113
- return ret
303
+ @staticmethod
304
+ def from_node(y:Node, **kwargs) -> Tensor:
305
+ if isinstance(y, MulNode): return Tensor.from_node(y.a, **kwargs) * y.b
306
+ if isinstance(y, Variable): return Tensor(y, **kwargs, requires_grad=False)
307
+ raise RuntimeError(f"unhandled Node {y}")
114
308
 
115
309
  # ***** creation llop entrypoint *****
116
310
 
117
311
  @staticmethod
118
- def _loadop(op, sz, device:Optional[str]=None, dtype:Optional[DType]=None, arg=None, **kwargs):
119
- return Tensor(LazyBuffer.loadop(op, [sz], Tensor.default_type if dtype is None else dtype, Device.canonicalize(device), arg), dtype=dtype, device=device, **kwargs)
312
+ def _loadop(op, shape, device:Optional[Union[Tuple[str, ...], str]]=None, dtype:Optional[DType]=None, arg=None, **kwargs):
313
+ if isinstance(device, tuple):
314
+ return Tensor(MultiLazyBuffer([LazyBuffer.loadop(op, shape, dtype or dtypes.default_float, Device.canonicalize(d), arg) \
315
+ for d in device], None), device, dtype, **kwargs)
316
+ return Tensor(LazyBuffer.loadop(op, shape, dtype or dtypes.default_float, Device.canonicalize(device), arg), device, dtype, **kwargs)
120
317
 
121
318
  @staticmethod
122
- def empty(*shape, **kwargs): return Tensor._loadop(LoadOps.EMPTY, math.prod(shape), **kwargs).reshape(shape)
319
+ def empty(*shape, **kwargs):
320
+ """
321
+ Creates an empty tensor with the given shape.
322
+
323
+ You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
324
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
325
+
326
+ ```python exec="true" source="above" session="tensor" result="python"
327
+ t = Tensor.empty(2, 3)
328
+ print(t.shape)
329
+ ```
330
+ """
331
+ return Tensor._loadop(LoadOps.EMPTY, argfix(*shape), **kwargs)
123
332
 
124
333
  _seed: int = int(time.time())
334
+ _rng_counter: Optional[Tensor] = None
125
335
  @staticmethod
126
- def manual_seed(seed=0): Tensor._seed = seed
336
+ def manual_seed(seed=0):
337
+ """
338
+ Sets the seed for random operations.
339
+
340
+ ```python exec="true" source="above" session="tensor" result="python"
341
+ Tensor.manual_seed(42)
342
+ print(Tensor._seed)
343
+ ```
344
+ """
345
+ Tensor._seed, Tensor._rng_counter = seed, Tensor([0], dtype=dtypes.uint32, requires_grad=False)
127
346
 
128
347
  @staticmethod
129
- def rand(*shape, **kwargs):
130
- Tensor._seed += 1
131
- return Tensor._loadop(LoadOps.RAND, math.prod(shape), arg=Tensor._seed, **kwargs).reshape(shape)
348
+ def rand(*shape, device:Optional[Union[Tuple[str, ...], str]]=None, dtype:Optional[DType]=None, **kwargs):
349
+ """
350
+ Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[0, 1)`.
351
+
352
+ You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
353
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
354
+
355
+ ```python exec="true" source="above" session="tensor" result="python"
356
+ Tensor.manual_seed(42)
357
+ t = Tensor.rand(2, 3)
358
+ print(t.numpy())
359
+ ```
360
+ """
361
+ if Tensor._rng_counter is None: Tensor._rng_counter = Tensor([0], dtype=dtypes.uint32, requires_grad=False)
362
+ if not THREEFRY.value:
363
+ # for bfloat16, numpy rand passes buffer in float
364
+ if (dtype or dtypes.default_float) == dtypes.bfloat16:
365
+ return Tensor.rand(*shape, **kwargs, device=device, dtype=dtypes.float).cast(dtypes.bfloat16)
366
+ return Tensor._loadop(LoadOps.CUSTOM, argfix(*shape), arg=custom_random, device=device, dtype=dtype, **kwargs)
367
+
368
+ # threefry
369
+ if (num := prod((shape:=argfix(*shape)))) == 0: return Tensor.zeros(shape, device=device, dtype=dtype, **kwargs)
370
+ counts1 = (Tensor.arange(math.ceil(num / 2), device=device, dtype=dtypes.uint32, requires_grad=False)+Tensor._rng_counter.to(device)).realize()
371
+ counts2 = counts1 + math.ceil(num / 2)
372
+ Tensor._rng_counter.assign(Tensor._rng_counter + num).realize()
373
+
374
+ rotations = [[13, 15, 26, 6], [17, 29, 16, 24]]
375
+ ks = [0x0, Tensor._seed ^ 0x0 ^ 0x1BD11BDA, Tensor._seed]
376
+
377
+ x = [counts1 + ks[-1], counts2 + ks[0]]
378
+ for i in range(5):
379
+ for r in rotations[i % 2]: x[0], x[1] = (x0 := x[0] + x[1]), x0 ^ ((x[1] << r) + (x[1] >> (32 - r)))
380
+ x = [(x[0] + ks[i % 3]), (x[1] + ks[(i + 1) % 3] + i + 1)]
381
+ out = x[0].cat(x[1]).rshift(8).cast(dtypes.float32).div(2 ** 24)[:num]
382
+ out = out.reshape(shape).cast(dtypes.default_float if dtype is None else dtype)
383
+ out.requires_grad = kwargs.get("requires_grad")
384
+ return out.contiguous()
132
385
 
133
386
  # ***** creation helper functions *****
134
387
 
135
388
  @staticmethod
136
- def full(shape:Tuple[int, ...], fill_value, **kwargs): return Tensor(fill_value, **kwargs).reshape([1]*len(new_shape := argfix(shape))).expand(new_shape)
389
+ def full(shape:Tuple[sint, ...], fill_value:ConstType, **kwargs):
390
+ """
391
+ Creates a tensor with the given shape, filled with the given value.
392
+
393
+ You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
394
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
395
+
396
+ ```python exec="true" source="above" session="tensor" result="python"
397
+ print(Tensor.full((2, 3), 42).numpy())
398
+ ```
399
+ ```python exec="true" source="above" session="tensor" result="python"
400
+ print(Tensor.full((2, 3), False).numpy())
401
+ ```
402
+ """
403
+ return Tensor(fill_value, **kwargs).reshape((1, )*len(new_shape := argfix(shape))).expand(new_shape)
137
404
 
138
405
  @staticmethod
139
- def zeros(*shape, **kwargs): return Tensor.full(argfix(*shape), 0, **kwargs)
406
+ def zeros(*shape, **kwargs):
407
+ """
408
+ Creates a tensor with the given shape, filled with zeros.
409
+
410
+ You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
411
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
412
+
413
+ ```python exec="true" source="above" session="tensor" result="python"
414
+ print(Tensor.zeros(2, 3).numpy())
415
+ ```
416
+ ```python exec="true" source="above" session="tensor" result="python"
417
+ print(Tensor.zeros(2, 3, dtype=dtypes.int32).numpy())
418
+ ```
419
+ """
420
+ return Tensor.full(argfix(*shape), 0.0, **kwargs)
140
421
 
141
422
  @staticmethod
142
- def ones(*shape, **kwargs): return Tensor.full(argfix(*shape), 1, **kwargs)
423
+ def ones(*shape, **kwargs):
424
+ """
425
+ Creates a tensor with the given shape, filled with ones.
426
+
427
+ You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
428
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
429
+
430
+ ```python exec="true" source="above" session="tensor" result="python"
431
+ print(Tensor.ones(2, 3).numpy())
432
+ ```
433
+ ```python exec="true" source="above" session="tensor" result="python"
434
+ print(Tensor.ones(2, 3, dtype=dtypes.int32).numpy())
435
+ ```
436
+ """
437
+ return Tensor.full(argfix(*shape), 1.0, **kwargs)
143
438
 
144
439
  @staticmethod
145
440
  def arange(start, stop=None, step=1, **kwargs):
441
+ """
442
+ Returns a 1-D tensor of size `ceil((stop - start) / step)` with values from `[start, stop)`, with spacing between values given by `step`.
443
+
444
+ If `stop` is not specified, values are generated from `[0, start)` with the given `step`.
445
+
446
+ If `stop` is specified, values are generated from `[start, stop)` with the given `step`.
447
+
448
+ You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
449
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
450
+
451
+ ```python exec="true" source="above" session="tensor" result="python"
452
+ print(Tensor.arange(5).numpy())
453
+ ```
454
+ ```python exec="true" source="above" session="tensor" result="python"
455
+ print(Tensor.arange(5, 10).numpy())
456
+ ```
457
+ ```python exec="true" source="above" session="tensor" result="python"
458
+ print(Tensor.arange(5, 10, 2).numpy())
459
+ ```
460
+ ```python exec="true" source="above" session="tensor" result="python"
461
+ print(Tensor.arange(5.5, 10, 2).numpy())
462
+ ```
463
+ """
146
464
  if stop is None: stop, start = start, 0
147
- return Tensor.full((math.ceil((stop-start)/step),), step, **kwargs).cumsum() + (start - step)
148
-
149
- @staticmethod
150
- def full_like(tensor, fill_value, **kwargs):
151
- return Tensor.full(tensor.shape, fill_value=fill_value, dtype=kwargs.pop("dtype", tensor.dtype), device=kwargs.pop("device", tensor.device), **kwargs)
465
+ assert all(isinstance(s, (int, float)) for s in (start, stop, step)), f"symbolic arange not supported {start=}, {stop=}, {step=}"
466
+ dtype = kwargs.pop("dtype", dtypes.default_float if any(isinstance(x, float) for x in (start, stop, step)) else dtypes.default_int)
467
+ return (Tensor.full((math.ceil((stop-start)/step),), step, dtype=dtype, **kwargs)._cumsum() + (start - step)).cast(dtype)
152
468
 
153
469
  @staticmethod
154
- def zeros_like(tensor, **kwargs): return Tensor.full_like(tensor, 0, **kwargs)
155
-
156
- @staticmethod
157
- def ones_like(tensor, **kwargs): return Tensor.full_like(tensor, 1, **kwargs)
158
-
159
- @staticmethod
160
- def eye(dim:int, **kwargs): return Tensor.full((dim,1),1,**kwargs).pad(((0,0),(0,dim))).reshape(dim*(dim+1)).shrink(((0,dim*dim),)).reshape(dim, dim)
470
+ def eye(dim:int, **kwargs):
471
+ """
472
+ Creates an identity matrix of the given dimension.
473
+
474
+ You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
475
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
476
+
477
+ ```python exec="true" source="above" session="tensor" result="python"
478
+ print(Tensor.eye(3).numpy())
479
+ ```
480
+ """
481
+ return Tensor.ones((dim,1),**kwargs).pad((None,(0,dim))).flatten().shrink(((0,dim*dim),)).reshape(dim, dim)
482
+
483
+ def full_like(self, fill_value:ConstType, **kwargs):
484
+ """
485
+ Creates a tensor with the same shape as `self`, filled with the given value.
486
+ If `dtype` is not specified, the dtype of `self` is used.
487
+
488
+ You can pass in the `device` keyword argument to control device of the tensor.
489
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
490
+
491
+ ```python exec="true" source="above" session="tensor" result="python"
492
+ t = Tensor.ones(2, 3)
493
+ print(Tensor.full_like(t, 42).numpy())
494
+ ```
495
+ """
496
+ return Tensor.full(self.shape, fill_value, dtype=kwargs.pop("dtype", self.dtype), device=kwargs.pop("device", self.device), **kwargs)
497
+
498
+ def zeros_like(self, **kwargs):
499
+ """
500
+ Creates a tensor with the same shape as `self`, filled with zeros.
501
+
502
+ You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
503
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
504
+
505
+ ```python exec="true" source="above" session="tensor" result="python"
506
+ t = Tensor.ones(2, 3)
507
+ print(Tensor.zeros_like(t).numpy())
508
+ ```
509
+ """
510
+ return self.full_like(0, **kwargs)
511
+
512
+ def ones_like(self, **kwargs):
513
+ """
514
+ Creates a tensor with the same shape as `self`, filled with ones.
515
+
516
+ You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
517
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
518
+
519
+ ```python exec="true" source="above" session="tensor" result="python"
520
+ t = Tensor.zeros(2, 3)
521
+ print(Tensor.ones_like(t).numpy())
522
+ ```
523
+ """
524
+ return self.full_like(1, **kwargs)
161
525
 
162
526
  # ***** rng hlops *****
163
527
 
164
528
  @staticmethod
165
529
  def randn(*shape, dtype:Optional[DType]=None, **kwargs) -> Tensor:
530
+ """
531
+ Creates a tensor with the given shape, filled with random values from a normal distribution with mean `0` and standard deviation `1`.
532
+ If `dtype` is not specified, the default type is used.
533
+
534
+ You can pass in the `device` keyword argument to control device of the tensor.
535
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
536
+
537
+ ```python exec="true" source="above" session="tensor" result="python"
538
+ Tensor.manual_seed(42)
539
+ print(Tensor.randn(2, 3).numpy())
540
+ ```
541
+ """
166
542
  # https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
167
- src = Tensor.rand(2, *shape, **kwargs)
168
- return src[0].mul(2*math.pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(Tensor.default_type if dtype is None else dtype)
543
+ src = Tensor.rand((2, *argfix(*shape)), **{**kwargs, "dtype": dtypes.float32})
544
+ return src[0].mul(2*math.pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(dtype or dtypes.default_float)
545
+
546
+ @staticmethod
547
+ def randint(*shape, low=0, high=10, **kwargs) -> Tensor:
548
+ """
549
+ Creates a tensor with the given shape, filled with random integer values generated uniformly from the interval `[low, high)`.
550
+ If `dtype` is not specified, the default type is used.
551
+
552
+ You can pass in the `device` keyword argument to control device of the tensor.
553
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
554
+
555
+ ```python exec="true" source="above" session="tensor" result="python"
556
+ Tensor.manual_seed(42)
557
+ print(Tensor.randint(2, 3, low=5, high=10).numpy())
558
+ ```
559
+ """
560
+ if not isinstance(low, int) or not isinstance(high, int): raise TypeError(f"{low=} and {high=} must be integers")
561
+ dtype = kwargs.pop("dtype", dtypes.int32)
562
+ if not dtypes.is_int(dtype): raise TypeError(f"{dtype=} must be int")
563
+ return Tensor.uniform(*shape, low=low, high=high, dtype=dtype, **kwargs)
169
564
 
170
565
  @staticmethod
171
- def normal(*shape, mean=0.0, std=1.0, **kwargs) -> Tensor: return (std * Tensor.randn(*shape, **kwargs)) + mean
566
+ def normal(*shape, mean=0.0, std=1.0, **kwargs) -> Tensor:
567
+ """
568
+ Creates a tensor with the given shape, filled with random values from a normal distribution with the given `mean` and standard deviation `std`.
569
+
570
+ You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
571
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
572
+
573
+ ```python exec="true" source="above" session="tensor" result="python"
574
+ Tensor.manual_seed(42)
575
+ print(Tensor.normal(2, 3, mean=10, std=2).numpy())
576
+ ```
577
+ """
578
+ return (std * Tensor.randn(*shape, **kwargs)) + mean
172
579
 
173
580
  @staticmethod
174
- def uniform(*shape, low=-1.0, high=1.0, **kwargs) -> Tensor:
175
- dtype = kwargs.pop("dtype", Tensor.default_type)
581
+ def uniform(*shape, low=0.0, high=1.0, **kwargs) -> Tensor:
582
+ """
583
+ Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[low, high)`.
584
+
585
+ You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
586
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
587
+
588
+ ```python exec="true" source="above" session="tensor" result="python"
589
+ Tensor.manual_seed(42)
590
+ print(Tensor.uniform(2, 3, low=2, high=10).numpy())
591
+ ```
592
+ """
593
+ dtype = kwargs.pop("dtype", dtypes.default_float)
176
594
  return ((high-low) * Tensor.rand(*shape, **kwargs)).cast(dtype) + low
177
595
 
178
596
  @staticmethod
179
- def scaled_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, **kwargs).mul(math.prod(shape)**-0.5)
597
+ def scaled_uniform(*shape, **kwargs) -> Tensor:
598
+ """
599
+ Creates a tensor with the given shape, filled with random values from a uniform distribution
600
+ over the interval `[-prod(shape)**-0.5, prod(shape)**-0.5)`.
601
+
602
+ You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
603
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
604
+
605
+ ```python exec="true" source="above" session="tensor" result="python"
606
+ Tensor.manual_seed(42)
607
+ print(Tensor.scaled_uniform(2, 3).numpy())
608
+ ```
609
+ """
610
+ return Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul(prod(argfix(*shape))**-0.5)
180
611
 
181
612
  # https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform
182
613
  @staticmethod
183
- def glorot_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, **kwargs).mul((6/(shape[0]+math.prod(shape[1:])))**0.5)
614
+ def glorot_uniform(*shape, **kwargs) -> Tensor:
615
+ """
616
+ <https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform>
617
+
618
+ You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
619
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
620
+
621
+ ```python exec="true" source="above" session="tensor" result="python"
622
+ Tensor.manual_seed(42)
623
+ print(Tensor.glorot_uniform(2, 3).numpy())
624
+ ```
625
+ """
626
+ return Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul((6/(argfix(*shape)[0]+prod(argfix(*shape)[1:])))**0.5)
184
627
 
185
628
  # https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_
186
629
  @staticmethod
187
630
  def kaiming_uniform(*shape, a:float = 0.01, **kwargs) -> Tensor:
188
- bound = math.sqrt(3.0) * math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(math.prod(shape[1:]))
631
+ """
632
+ <https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_>
633
+
634
+ You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
635
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
636
+
637
+ ```python exec="true" source="above" session="tensor" result="python"
638
+ Tensor.manual_seed(42)
639
+ print(Tensor.kaiming_uniform(2, 3).numpy())
640
+ ```
641
+ """
642
+ bound = math.sqrt(3.0) * math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(argfix(*shape)[1:]))
189
643
  return Tensor.uniform(*shape, low=-bound, high=bound, **kwargs)
190
644
 
191
645
  # https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_
192
646
  @staticmethod
193
647
  def kaiming_normal(*shape, a:float = 0.01, **kwargs) -> Tensor:
194
- std = math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(math.prod(shape[1:]))
648
+ """
649
+ <https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_>
650
+
651
+ You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
652
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
653
+
654
+ ```python exec="true" source="above" session="tensor" result="python"
655
+ Tensor.manual_seed(42)
656
+ print(Tensor.kaiming_normal(2, 3).numpy())
657
+ ```
658
+ """
659
+ std = math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(argfix(*shape)[1:]))
195
660
  return Tensor.normal(*shape, mean=0.0, std=std, **kwargs)
196
661
 
662
+ def multinomial(self:Tensor, num_samples:int = 1, replacement:bool = False) -> Tensor:
663
+ assert 1 <= self.ndim <= 2 and num_samples > 0, f"{self.ndim=} must be 1 or 2 dim, {num_samples=} must be positive"
664
+ assert replacement or num_samples == 1, "no replacement only supports num_samples = 1"
665
+ weight = self.unsqueeze(0) if self.ndim == 1 else self
666
+ cdf = (cw := weight.cumsum(1).float()) / cw[:, -1].unsqueeze(1)
667
+ unif_samples = Tensor.rand(num_samples, cdf.shape[0], 1, device=self.device)
668
+ indices = (unif_samples.expand((-1, -1, cdf.shape[1])) >= cdf).sum(2).permute((1, 0))
669
+ return (indices.squeeze(0) if self.ndim == 1 else indices).cast(dtypes.int32)
670
+
197
671
  # ***** toposort and backward pass *****
198
- def deepwalk(self):
199
- def _deepwalk(node, visited, nodes):
672
+
673
+ def _deepwalk(self):
674
+ def _walk(node, visited):
200
675
  visited.add(node)
201
676
  if getattr(node, "_ctx", None):
202
677
  for i in node._ctx.parents:
203
- if i not in visited: _deepwalk(i, visited, nodes)
204
- nodes.append(node)
205
- return nodes
206
- return _deepwalk(self, set(), [])
207
-
208
- def backward(self):
678
+ if i not in visited: yield from _walk(i, visited)
679
+ yield node
680
+ return list(_walk(self, set()))
681
+
682
+ def backward(self) -> Tensor:
683
+ """
684
+ Propagates the gradient of a tensor backwards through the computation graph.
685
+ Must be used on a scalar tensor.
686
+
687
+ ```python exec="true" source="above" session="tensor" result="python"
688
+ t = Tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
689
+ t.sum().backward()
690
+ print(t.grad.numpy())
691
+ ```
692
+ """
209
693
  assert self.shape == tuple(), f"backward can only be called for scalar tensors, but it has shape {self.shape})"
210
694
 
211
695
  # fill in the first grad with one. don't use Tensor.ones because we don't need contiguous
212
696
  # this is "implicit gradient creation"
213
- self.grad = Tensor(1, device=self.device, requires_grad=False)
697
+ self.grad = Tensor(1.0, dtype=self.dtype, device=self.device, requires_grad=False)
214
698
 
215
- for t0 in reversed(self.deepwalk()):
216
- assert (t0.grad is not None)
699
+ for t0 in reversed(self._deepwalk()):
700
+ if t0.grad is None: raise RuntimeError(f"tensor {t0} has no grad")
217
701
  grads = t0._ctx.backward(t0.grad.lazydata)
218
702
  grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None
219
703
  for g in ([grads] if len(t0._ctx.parents) == 1 else grads)]
@@ -222,405 +706,1805 @@ class Tensor:
222
706
  assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"
223
707
  t.grad = g if t.grad is None else (t.grad + g)
224
708
  del t0._ctx
709
+ return self
710
+
711
+ # ***** movement low level ops *****
712
+
713
+ def view(self, *shape) -> Tensor:
714
+ """`.view` is an alias for `.reshape`."""
715
+ return self.reshape(shape)
225
716
 
226
- # ***** movement mlops *****
227
717
  def reshape(self, shape, *args) -> Tensor:
718
+ """
719
+ Returns a tensor with the same data as the original tensor but with a different shape.
720
+ `shape` can be passed as a tuple or as separate arguments.
721
+
722
+ ```python exec="true" source="above" session="tensor" result="python"
723
+ t = Tensor.arange(6)
724
+ print(t.reshape(2, 3).numpy())
725
+ ```
726
+ """
228
727
  new_shape = argfix(shape, *args)
229
- assert 0 not in new_shape, f"zeros not allowed in shape {new_shape}"
230
- return mlops.Reshape.apply(self, shape=tuple([-math.prod(self.shape) // math.prod(new_shape) if s == -1 else s for s in new_shape]))
231
- def expand(self, shape, *args) -> Tensor: return mlops.Expand.apply(self, shape=tuple([x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args))]))
232
- def permute(self, order, *args) -> Tensor: return mlops.Permute.apply(self, order=argfix(order, *args))
233
- def flip(self, axis, *args) -> Tensor: return mlops.Flip.apply(self, axis=[x if x >= 0 else x+len(self.shape) for x in argfix(axis, *args)])
234
- def shrink(self, arg:Tuple[Tuple[int, int], ...]) -> Tensor: return mlops.Shrink.apply(self, arg=arg) if any(x != (0,s) for x,s in zip(arg, self.shape)) else self
235
- def pad(self, arg: Tuple[Tuple[int, int], ...], value:float=0) -> Tensor:
236
- ret = mlops.Pad.apply(self, arg=arg) if any(x != (0, 0) for x in arg) else self
237
- return ret if 0 == value else ret + mlops.Pad.apply(Tensor.ones_like(self), arg=arg).where(0, value)
238
-
239
- # ***** movement hlops *****
240
-
241
- # - Negative indices are taken relative to the end of the sequence, so X[-2] returns the 2nd-to-last element
242
- # - A slice i:j returns the elements with indices in [i, j)
243
- # - If omitted, i and j will default to 0 and N, respectively, where N is the length of the sequence
244
- # - Negative values for i and j are taken relative to the end of the sequence
245
- # - Both i and j will be clamped to the range (-N, N], where N in the length of the sequence
246
- # - Indexing with np.newaxis or None on a given axis will add a new dimension of size one before that axis
247
- # - Empty slices are not allowed (tensors with 0s in shape have to be supported first, for all backends).
248
- # - For a slice [i:j:k] finding the correct indices is delegated to slice.indices(len).
249
- # - Strides > 1 and < 0 are now allowed!:
250
- # - This works by applying Shrink -> [[Flip -> ] Pad -> Reshape -> Shrink] -> Reshape (ops in brackets are optional)
251
- # - Idea of stride < 0 support:
252
- # - Do the slice first, flip the axes were slice.step is negative, do slice.step -> -slice.step. Go to steps below.
253
- # - Idea of stride `s` > 1 support (Pad -> Reshape -> Shrink):
254
- # - Instead of doing [::s] on axis [dim_sz], do [:, 0] on axes [dim_sz_padded // s, s].
255
- # - So pad dim_sz with as many zeros as needed (dim_sz -> dim_sz_padded) so that reshape to [dim_sz_padded // s, s]
256
- # is possible.
257
- # - Apply Shrink to do the slice [:, 0] on axes of shapes [dim_sz_padded // s, s].
258
- # - Fancy indexing and combined indexing is supported
259
- # - Combined indexing works by letting regular slicing finish first -> computing the resulting dims w.r.t to Tensors passed in -> fancy indexing
260
- # - Any Tensors passed in __getitem__ will perform (CMPEQ with arange -> MUL with self -> SUM_REDUCE) iteratively
261
- # - The first iteration will expand the dim of self while consecutive iterations will reduce the dim
262
- # - The dims are reduced at sum_dim for each Tensor passed in
263
- # - There's a special case where a permute is needed at the end:
264
- # - if first Tensor passed in (expand dims) is not at dim 0
265
- # - and following Tensors does not follow consecutively to the end of fancy indexing's dims
266
- def __getitem__(self, val):
267
- def normalize_int(e, i, dim_sz):
268
- if -dim_sz <= e < dim_sz: return e if e != -1 else dim_sz-1
269
- raise IndexError(f"index {e} is out of bounds for dimension {i} with size {self.shape[i]}")
270
- orig_slices = list(val) if isinstance(val, tuple) else [val]
271
- if (num_slices := sum(isinstance(v, (slice, int, Tensor)) for v in orig_slices)) > len(self.shape):
272
- raise IndexError(f"too many indices for tensor of dimension {len(self.shape)}")
273
- ellipses_found = [i for i, v in enumerate(orig_slices) if v is Ellipsis]
274
- if len(ellipses_found) > 1: raise IndexError("an index can only have a single ellipsis ('...')")
275
- ellipsis_idx = ellipses_found[0] if ellipses_found else len(orig_slices)
276
- orig_slices[ellipsis_idx:ellipsis_idx+1] = [slice(None)] * (len(self.shape) - num_slices)
277
-
278
- tensor_found = [(i,v) for i, v in enumerate(orig_slices) if isinstance(v, Tensor)]
279
- orig_slices = [slice(None) if isinstance(v, Tensor) else v for v in orig_slices]
280
- valid_slices = [s for s in orig_slices if s is not None]
281
- valid_slices = [v if isinstance(v, slice) else slice(y := normalize_int(v, i, dim_sz), y+1) for i, (v, dim_sz) in enumerate(zip(valid_slices, self.shape))]
282
- start, stop, strides = zip(*y) if (y := [s.indices(dim_sz) for s, dim_sz in zip(valid_slices, self.shape)]) else ((), (), ())
283
- new_slice = tuple((s, e) if st > 0 else (e+1, s+1) for s, e, st in zip(start, stop, strides))
284
- # Shrink
285
- sliced_tensor = self.shrink(new_slice)
286
- new_shape = sliced_tensor.shape
287
- # Flip
288
- if (flip_axes := tuple(i for i, s in enumerate(strides) if s < 0)):
289
- sliced_tensor = sliced_tensor.flip(axis=flip_axes)
290
- if any(s > 1 or s < 0 for s in strides):
291
- # normalize if negative strides
728
+ new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else (s if s is not None else self.shape[i]) for i,s in enumerate(new_shape)])
729
+ return F.Reshape.apply(self, shape=new_shape) if new_shape != self.shape else self
730
+
731
+ def expand(self, shape, *args) -> Tensor:
732
+ """
733
+ Returns a tensor that is expanded to the shape that is specified.
734
+ Expand can also increase the number of dimensions that a tensor has.
735
+
736
+ Passing a `-1` or `None` to a dimension means that its size will not be changed.
737
+
738
+ ```python exec="true" source="above" session="tensor" result="python"
739
+ t = Tensor([1, 2, 3])
740
+ print(t.expand(4, -1).numpy())
741
+ ```
742
+ """
743
+ return self._broadcast_to(tuple(sh if s==-1 or s is None else s for s, sh in zip(*(_pad_left(argfix(shape, *args), self.shape)))))
744
+
745
+ def permute(self, order, *args) -> Tensor:
746
+ """
747
+ Returns a tensor that is a permutation of the original tensor.
748
+ The new tensor has the same data as the original tensor but with the dimensions permuted according to the order specified.
749
+ `order` can be passed as a tuple or as separate arguments.
750
+
751
+ ```python exec="true" source="above" session="tensor" result="python"
752
+ t = Tensor.arange(6).reshape(2, 3)
753
+ print(t.numpy())
754
+ ```
755
+ ```python exec="true" source="above" session="tensor" result="python"
756
+ print(t.permute(1, 0).numpy())
757
+ ```
758
+ """
759
+ return F.Permute.apply(self, order=argfix(order, *args))
760
+
761
+ def flip(self, axis, *args) -> Tensor:
762
+ """
763
+ Returns a tensor that reverses the order of the original tensor along given `axis`.
764
+ `axis` can be passed as a tuple or as separate arguments.
765
+
766
+ ```python exec="true" source="above" session="tensor" result="python"
767
+ t = Tensor.arange(6).reshape(2, 3)
768
+ print(t.numpy())
769
+ ```
770
+ ```python exec="true" source="above" session="tensor" result="python"
771
+ print(t.flip(0).numpy())
772
+ ```
773
+ ```python exec="true" source="above" session="tensor" result="python"
774
+ print(t.flip((0, 1)).numpy())
775
+ ```
776
+ """
777
+ return F.Flip.apply(self, axis=[x if x >= 0 else x+len(self.shape) for x in argfix(axis, *args)])
778
+
779
+ def shrink(self, arg:Tuple[Optional[Tuple[sint, sint]], ...]) -> Tensor:
780
+ """
781
+ Returns a tensor that shrinks the each axis based on input arg.
782
+ `arg` must have the same length as `self.ndim`.
783
+ For each axis, it can be `None`, which means no shrink, or a tuple `(start, end)` that works the same as Python slice.
784
+
785
+ ```python exec="true" source="above" session="tensor" result="python"
786
+ t = Tensor.arange(9).reshape(3, 3)
787
+ print(t.numpy())
788
+ ```
789
+ ```python exec="true" source="above" session="tensor" result="python"
790
+ print(t.shrink(((None, (1, 3)))).numpy())
791
+ ```
792
+ ```python exec="true" source="above" session="tensor" result="python"
793
+ print(t.shrink((((0, 2), (0, 2)))).numpy())
794
+ ```
795
+ """
796
+ if all(x is None or x == (0,s) for x,s in zip(arg, self.shape)): return self
797
+ return F.Shrink.apply(self, arg=tuple(x if x is not None else (0,s) for x,s in zip(arg, self.shape)))
798
+
799
+ def pad(self, arg:Tuple[Optional[Tuple[sint, sint]], ...], value:float=0.0) -> Tensor:
800
+ """
801
+ Returns a tensor that pads the each axis based on input arg.
802
+ `arg` must have the same length as `self.ndim`.
803
+ For each axis, it can be `None`, which means no pad, or a tuple `(pad_before, pad_after)`.
804
+ If `value` is specified, the tensor is padded with `value` instead of `0.0`.
805
+
806
+ ```python exec="true" source="above" session="tensor" result="python"
807
+ t = Tensor.arange(6).reshape(2, 3)
808
+ print(t.numpy())
809
+ ```
810
+ ```python exec="true" source="above" session="tensor" result="python"
811
+ print(t.pad(((None, (1, 2)))).numpy())
812
+ ```
813
+ ```python exec="true" source="above" session="tensor" result="python"
814
+ print(t.pad(((None, (1, 2))), -2).numpy())
815
+ ```
816
+ """
817
+ if all(x is None or x == (0,0) for x in arg): return self
818
+ ret = F.Pad.apply(self, arg=(narg:=tuple(x if x is not None else (0,0) for x in arg)))
819
+ return ret if 0 == value else ret + F.Pad.apply(Tensor.ones_like(self), arg=narg).where(0, value)
820
+
821
+ # ***** movement high level ops *****
822
+
823
+ # Supported Indexing Implementations:
824
+ # 1. Int indexing (no copy)
825
+ # - for all dims where there's int, shrink -> reshape
826
+ # - negative indices are taken relative to the end of the sequence, so X[-2] returns the 2nd-to-last element
827
+ # - X = Tensor.rand(4,5,9); X[2,-2] shrinks the Tensor to X.shrink(((2, 3), (3, 4), (0, 9))) -> X.shape=(1,1,9)
828
+ # - Then we reshape (collapse) the int dim away such that for X: (1,1,9) -> (9,)
829
+ # 2. Slice indexing (no copy)
830
+ # - for all dims where slice is start:end:stride, shrink -> Optional[flip] -> pad -> reshape -> shrink
831
+ # - first shrink the Tensor to X.shrink(((start, end),))
832
+ # - then we apply stride through Optional[flip] -> pad -> reshape -> shrink
833
+ # - flip where dim value is negative
834
+ # - pad 0's on dims such that reshaping [dim_size_padded] -> [dim_size_padded // stride, stride] is possible
835
+ # - shrink [dim_size_padded // stride, stride] -> [dim_size_padded // stride, 1]
836
+ # - reshape [dim_size_padded // stride, 1] -> [dim_size_padded // stride] and now you have your stride
837
+ # 3. None indexing (no copy)
838
+ # - reshape (inject) a dim at the dim where there's None
839
+ # 4. Tensor indexing (copy)
840
+ # - use Tensor.arange == tensor_index to create masks for dims with Tensors (adds a dim for each mask)
841
+ # - combine masks together with mul
842
+ # - apply mask to self by mask * self
843
+ # - sum reduce away the extra dims added from creating masks
844
+ # Tiny Things:
845
+ # 1. Supported indices: Union[int, slice, Tensor, None, List, Tuple, Ellipsis]
846
+ # - for any list, List[Union[List, Tuple, int]], must have homogeneous shape
847
+ # - for any tuple, Tuple[Union[List, Tuple, int]], must have homogeneous shape
848
+ # 2. Bool indexing is not supported
849
+ # 3. Out of bounds Tensor indexing results in 0
850
+ # - e.g: Tensor([1, 2, 3])[Tensor([4, 3, 2])] -> [0, 0, 3] index 4 and 3 are OOB
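A minimal sketch tying the four indexing modes above to concrete shapes, using only the `Tensor` API shown in this file (illustrative, not part of the diffed source):

```python
from tinygrad import Tensor

X = Tensor.arange(12).reshape(3, 4)
print(X[1].shape)                 # int indexing: shrink then collapse the dim -> (4,)
print(X[::2].shape)               # strided slice: shrink/pad/reshape/shrink -> (2, 4)
print(X[:, None].shape)           # None injects a new dim -> (3, 1, 4)
print(X[Tensor([2, 0])].numpy())  # Tensor indexing: arange masks, multiply, sum-reduce (copy)
```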
851
+ def __getitem__(self, indices) -> Tensor:
852
+ # 1. indices normalization and validation
853
+ # treat internal tuples and lists as Tensors and standardize indices to list type
854
+ if isinstance(indices, list) and all_int(indices): indices = [Tensor(indices, self.device, requires_grad=False)]
855
+ elif isinstance(indices, (tuple, list)):
856
+ indices = [Tensor(list(i), self.device, requires_grad=False) if isinstance(i, (tuple, list)) else i for i in indices]
857
+ else: indices = [indices]
858
+
859
+ # turn scalar Tensors into const val for int indexing if possible
860
+ indices = [self._to_const_val(i) if isinstance(i, Tensor) and i.shape == () else i for i in indices]
861
+ # move Tensor indices to the same device as self
862
+ indices = [i.to(self.device) if isinstance(i, Tensor) else i for i in indices]
863
+
864
+ # filter ellipsis and fill with slice(None) or fill rest of indices with slice(None)
865
+ ellipsis_idx = [dim for dim, i in enumerate(indices) if i is Ellipsis]
866
+ fill_idx = ellipsis_idx[0] if ellipsis_idx else len(indices)
867
+ num_indices = len(indices) - len(ellipsis_idx) - sum(1 for i in indices if i is None)
868
+ indices[fill_idx:fill_idx+1] = [slice(None)] * (len(self.shape) - num_indices)
869
+
870
+ # use Dict[type, List[dimension]] to track elements in indices
871
+ type_dim: DefaultDict[Union[type, None], List[int]] = defaultdict(list)
872
+
873
+ # record None for dimension injection later and filter None and record rest of indices
874
+ type_dim[None] = [dim for dim, i in enumerate(indices) if i is None]
875
+ indices_filtered = [v for v in indices if v is not None]
876
+ for dim,i in enumerate(indices_filtered): type_dim[type(i)].append(dim)
877
+
878
+ for index_type in type_dim:
879
+ if index_type not in [None, int, slice, Tensor]: raise IndexError(f"{index_type=} not supported")
880
+ if len(ellipsis_idx) > 1: raise IndexError("indices can only have a single ellipsis ('...')")
881
+ if num_indices > self.ndim: raise IndexError(f"too many {num_indices=} for {self.ndim=}")
882
+
883
+ # 2. basic indexing, uses only movement ops (no copy)
884
+ # currently indices_filtered: Tuple[Union[slice, int, Tensor], ...]
885
+ # turn indices in indices_filtered to Tuple[shrink_arg, strides]
886
+ for dim in type_dim[int]:
887
+ if (index := indices_filtered[dim]) >= (size := self.shape[dim]) or index < -size:
888
+ raise IndexError(f"{index=} is out of bounds on {dim=} with {size=}")
889
+ indices_filtered[dim] = ((index, index+1), 1) if index >= 0 else ((size+index, size+index+1), 1)
890
+ for dim in type_dim[slice]:
891
+ if (index := indices_filtered[dim]).step == 0: raise ValueError(f"{index=} on {dim=} cannot have 0 as step")
892
+ s, e, st = index.indices(self.shape[dim])
893
+ indices_filtered[dim] = ((0, 0) if (st * (e - s)) < 0 else (s, e) if st > 0 else (e+1, s+1), st)
894
+ # record tensors and skip all Tensor dims for basic indexing
895
+ tensor_index: List[Tensor] = []
896
+ for dim in type_dim[Tensor]:
897
+ tensor_index.append(index := indices_filtered[dim])
898
+ if not dtypes.is_int(index.dtype): raise IndexError(f"{index.dtype=} on {dim=} is not supported, only int tensor indexing is supported")
899
+ indices_filtered[dim] = ((0, self.shape[dim]), 1)
900
+
901
+ new_slice, strides = ((),()) if not indices_filtered else zip(*indices_filtered)
902
+ ret = self.shrink(new_slice).flip(tuple(i for i, s in enumerate(strides) if s < 0))
903
+ if any(abs(s) != 1 for s in strides):
292
904
  strides = tuple(abs(s) for s in strides)
293
- def num_zeros(step, dim_sz): return 0 if step == 1 or (y := dim_sz % step) == 0 else (step - y)
294
- # Pad: add pad at the end: [dim_sz] -> [dim_sz_padded]
295
- paddings = tuple((0, num_zeros(s, dim_sz)) for s, dim_sz in zip(strides, sliced_tensor.shape))
296
- padded_tensor = sliced_tensor.pad(paddings)
297
- # Reshape: [dim_sz_padded] -> [dim_sz_padded // s, s]
298
- new_shape = flatten([sh // s, s] for sh, s in zip(padded_tensor.shape, strides))
299
- reshaped_tensor = padded_tensor.reshape(new_shape)
300
- # Shrink: do [:, 0]
301
- new_shape = new_shape[::2]
302
- final_slice = tuple(flatten(((0, sh), (0, 1)) for sh in new_shape))
303
- sliced_tensor = reshaped_tensor.shrink(final_slice)
304
- final_shape, it_shape = [], iter(new_shape)
305
- sub = [0] * len(tensor_found)
306
- for i,s in enumerate(orig_slices):
307
- if isinstance(s, (int, slice)):
308
- dim_shape = next(it_shape)
309
- if isinstance(s, slice): final_shape.append(dim_shape)
310
- elif tensor_found:
311
- for i_ in range(len(tensor_found)):
312
- if tensor_found[i_][0] > i: sub[i_] -= 1
313
- else: # s is None
314
- final_shape.append(1)
315
- ret = sliced_tensor.reshape(tuple(final_shape)) # Reshape
316
- if tensor_found: # Fancy/tensor indexing
317
- for i,s in enumerate(sub): tensor_found[i] = (tensor_found[i][0]+s, tensor_found[i][1])
318
- dim = [i[0] for i in tensor_found]
319
- idx = [i[1].sign().contiguous().__neg__().contiguous().relu() * ret.shape[i[0]] + i[1] for i in tensor_found] # TODO first contiguous fixes torch+cpu_only CI, but it causes llvm to fail. Second one fixes llvm
320
- max_dim = max(i.ndim for i in idx)
321
- idx = [i.reshape(*[1]*(max_dim-i.ndim), *i.shape) for i in idx]
322
- sum_dim = [d+max_dim-n for n,d in enumerate(dim)]
323
- new_idx = idx[0].reshape(*[1]*dim[0], 1,*idx[0].shape, *[1]*(ret.ndim-dim[0]-1))
324
- arange = Tensor.arange(ret.shape[dim[0]], dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(*[1]*dim[0], ret.shape[dim[0]], *[1]*idx[0].ndim, *[1]*(ret.ndim-dim[0]-1))
325
- ret = (ret.reshape(*ret.shape[:dim[0]+1], *[1]*idx[0].ndim, *ret.shape[dim[0]+1:]) * (arange == new_idx)).sum(dim[0])
326
- for idx_,d in zip(idx[1:],sum_dim[1:]):
327
- new_idx = idx_.reshape(*[1]*dim[0], *idx_.shape, *[1]*(ret.ndim-dim[0]-idx_.ndim))
328
- arange = Tensor.arange(ret.shape[d], dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(*[1]*(d), ret.shape[d], *[1]*(ret.ndim-d-1))
329
- ret = ((new_idx == arange) * ret).sum(d)
330
- if dim[0] != 0 and dim != list(range(dim[0], dim[-1]+1)) and len(dim) != 1: # special permute case
331
- order = list(range(ret.ndim))
332
- order = order[dim[0]:dim[0]+idx[0].ndim] + order[:dim[0]] + order[dim[0]+idx[0].ndim:]
333
- ret = ret.permute(order=order)
905
+ ret = ret.pad(tuple((0, round_up(sh, s) - sh) for s, sh in zip(strides, ret.shape)))
906
+ ret = ret.reshape(tuple(flatten((sh // s, s) for s, sh in zip(strides, ret.shape))))
907
+ ret = ret.shrink(tuple(flatten(((0, sh), (0, 1)) for sh in ret.shape[::2]))).reshape(ret.shape[::2])
908
+
909
+ # inject 1 for dim where it's None and collapse dim for int
910
+ new_shape = list(ret.shape)
911
+ for dim in type_dim[None]: new_shape.insert(dim, 1)
912
+ for dim in (dims_collapsed := tuple(dim + sum(1 for d in type_dim[None] if dim >= d) for dim in reversed(type_dim[int]))): new_shape.pop(dim)
913
+
914
+ ret = ret.reshape(new_shape)
915
+ assert all_int(ret.shape), f"does not support symbolic shape {ret.shape}"
916
+
917
+ # 3. advanced indexing (copy)
918
+ if type_dim[Tensor]:
919
+ # calculate dim of current ret by subtracting dims collapsed and adding dims injected up until tensor_dim
920
+ def calc_dim(tensor_dim:int) -> int:
921
+ return tensor_dim - sum(1 for d in dims_collapsed if tensor_dim >= d) + sum(1 for d in type_dim[None] if tensor_dim >= d)
922
+
923
+ # track tensor_dim and tensor_index using a dict
924
+ # calc_dim to get dim and use that to normalize the negative tensor indices
925
+ idx: Dict[int,Tensor] = {(dim := calc_dim(td)):(tensor<0).where(ret.shape[dim],0) + tensor for td,tensor in zip(type_dim[Tensor], tensor_index)}
926
+
927
+ masks, first_dim, last_dim = [], min(idx.keys()), max(idx.keys())
928
+ pre_reduce_shape = ret.shape[:first_dim] + (big_shape := _broadcast_shape(*(t.shape for t in idx.values()))) + ret.shape[first_dim:]
929
+
930
+ # create masks
931
+ for dim, i in idx.items():
932
+ try: i = i.reshape(i.shape + (1,)*(ret.ndim - first_dim)).expand(pre_reduce_shape)
933
+ except ValueError as e: raise IndexError(f"cannot broadcast indices: {e}") from e
934
+ a = Tensor.arange(ret.shape[dim], device=self.device, requires_grad=False).reshape((ret.shape[dim],) + (1,)*(ret.ndim - dim - 1))
935
+ masks.append(i == a)
936
+
937
+ # reduce masks to 1 mask
938
+ mask: Tensor = functools.reduce(lambda x,y: x.mul(y), masks)
939
+
940
+ # inject 1's for the extra dims added in create masks
941
+ sh = ret.shape[:first_dim] + (1,) * len(big_shape) + ret.shape[first_dim:]
942
+ # sum reduce the extra dims introduced in create masks
943
+ ret = (ret.reshape(sh) * mask).sum(tuple(i + len(big_shape) for i in idx.keys()), acc_dtype=ret.dtype)
944
+
945
+ # special permute case
946
+ if first_dim != 0 and len(idx) != 1 and tuple(idx.keys()) != tuple(range(first_dim, last_dim+1)):
947
+ ret = ret.permute(*range(first_dim, first_dim+len(big_shape)), *range(0, first_dim), *range(first_dim+len(big_shape), ret.ndim))
334
948
  return ret
335
949
 
336
- # NOTE: using slice is discouraged and things should migrate to pad and shrink
337
- def slice(self, arg:Sequence[Optional[Tuple[int, int]]], value:float=0) -> Tensor:
338
- arg_ = tuple([a if a is not None else (0,s) for s,a in zip(self.shape, arg)])
339
- padding = tuple([(max(0, -p[0]), max(0, p[1]-self.shape[i])) for i,p in enumerate(arg_)])
340
- return self.pad(padding, value=value).shrink(tuple([(p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg_)]))
341
-
342
- def gather(self: Tensor, idx: Tensor, dim: int):
343
- assert idx.ndim == self.ndim, "self.ndim must equal idx.ndim"
344
- assert all(s >= i for s,i in zip(self.shape, idx.shape)), "all dim of idx.shape must be smaller than self.shape"
345
- if dim < 0: dim += self.ndim
346
- idx = idx.transpose(ax1=dim, ax2=0).unsqueeze(-1)
950
+ def __setitem__(self, indices, v:Union[Tensor, ConstType]) -> None:
951
+ if isinstance(self.device, str) and self.device.startswith("DISK"):
952
+ self.__getitem__(indices).assign(v)
953
+ return
954
+ # NOTE: check that setitem target is valid first
955
+ assert all(lb.st.contiguous for lb in self.lazydata.lbs), "setitem target needs to be contiguous"
956
+ if not isinstance(v, (Tensor, float, int, bool)): raise TypeError(f"can't set a {type(v).__name__} to a Tensor")
957
+ if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype)
958
+ if self.requires_grad or v.requires_grad: raise NotImplementedError("setitem with requires_grad is not supported")
959
+ if isinstance(indices, (Tensor, list)) or (isinstance(indices, tuple) and any(isinstance(i, (Tensor, list)) for i in indices)):
960
+ raise NotImplementedError("Advanced indexing setitem is not currently supported")
961
+
962
+ assign_to = self.realize().__getitem__(indices)
963
+ # NOTE: contiguous to prevent const folding.
964
+ v = v.cast(assign_to.dtype)._broadcast_to(_broadcast_shape(assign_to.shape, v.shape)).contiguous()
965
+ assign_to.assign(v).realize()
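A minimal usage sketch for `__setitem__`, assuming a contiguous target as the assert above requires (basic, non-Tensor indices only):

```python
from tinygrad import Tensor

t = Tensor.zeros(2, 3).contiguous()  # setitem target must be contiguous
t[1, :] = 5.0                        # broadcast the scalar into the selected view and assign
print(t.numpy())                     # row 1 is now all 5.0
```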
966
+
967
+ # NOTE: using _slice is discouraged and things should migrate to pad and shrink
968
+ def _slice(self, arg:Sequence[Optional[Tuple[int, sint]]], value:float=0) -> Tensor:
969
+ arg_ = tuple(a if a is not None else (0, s) for s,a in zip(self.shape, arg))
970
+ padding = tuple((max(0, -l), max(0, r-s)) for s,(l,r) in zip(self.shape, arg_))
971
+ return self.pad(padding, value=value).shrink(tuple((l + pl, r + pl) for (l,r),(pl,_) in zip(arg_, padding)))
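The `_slice` helper above is just pad-then-shrink; a tiny sketch of the equivalence using the public `pad`/`shrink` shown in this file:

```python
from tinygrad import Tensor

t = Tensor([1, 2, 3])
# taking elements -1..4 of a length-3 tensor: out-of-range parts take the pad value
print(t._slice(((-1, 4),)).numpy())                # [0 1 2 3 0]
print(t.pad(((1, 1),)).shrink(((0, 5),)).numpy())  # same result via pad + shrink
```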
972
+
973
+ def gather(self:Tensor, dim:int, index:Tensor) -> Tensor:
974
+ """
975
+ Gathers values along an axis specified by `dim`.
976
+
977
+ ```python exec="true" source="above" session="tensor" result="python"
978
+ t = Tensor([[1, 2], [3, 4]])
979
+ print(t.numpy())
980
+ ```
981
+ ```python exec="true" source="above" session="tensor" result="python"
982
+ print(t.gather(1, Tensor([[0, 0], [1, 0]])).numpy())
983
+ ```
984
+ """
985
+ assert index.ndim == self.ndim, f"self.ndim must equal index.ndim, {self.ndim=}, {index.ndim=}"
986
+ assert all(s >= i for s,i in zip(self.shape, index.shape)), "all dim of index.shape must be smaller than self.shape"
987
+ dim = self._resolve_dim(dim)
988
+ index = index.to(self.device).transpose(0, dim).unsqueeze(-1)
347
989
  permarg = list(range(self.ndim))
348
990
  permarg = permarg[1:dim] + [permarg[0]] + permarg[dim+1:] + [permarg[dim]] if dim != 0 else permarg[1:] + [permarg[0]]
349
- return ((idx == Tensor.arange(self.shape[dim], dtype=dtypes.int32, requires_grad=False, device=self.device)) * self.permute(*permarg).shrink(tuple([*[(0,sh) for sh in idx.shape[1:-1]], (0,self.shape[dim])])).unsqueeze(0)).sum(-1).transpose(ax1=0, ax2=dim)
350
-
351
- def cat(self, *args, dim=0):
352
- dim = (dim + len(self.shape)) if dim < 0 else dim
991
+ return ((index == Tensor.arange(self.shape[dim], requires_grad=False, device=self.device)) * self.permute(*permarg).shrink(
992
+ tuple([*[(0,sh) for sh in index.shape[1:-1]], None])).unsqueeze(0)).sum(-1, acc_dtype=self.dtype).transpose(0, dim)
993
+
994
+ def cat(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
995
+ """
996
+ Concatenates self with the other tensors in `args` along the axis specified by `dim`.
997
+ All tensors must have the same shape except in the concatenating dimension.
998
+
999
+ ```python exec="true" source="above" session="tensor" result="python"
1000
+ t0, t1, t2 = Tensor([[1, 2]]), Tensor([[3, 4]]), Tensor([[5, 6]])
1001
+ print(t0.cat(t1, t2, dim=0).numpy())
1002
+ ```
1003
+ ```python exec="true" source="above" session="tensor" result="python"
1004
+ print(t0.cat(t1, t2, dim=1).numpy())
1005
+ ```
1006
+ """
1007
+ dim = self._resolve_dim(dim)
353
1008
  assert all(len(y.shape) == len(self.shape) and all(y.shape[i] == s for i,s in enumerate(self.shape) if i != dim) for y in args)
354
1009
  catargs = [self, *args]
355
- assert all(t.shape for t in catargs), "zero-dimensional tensor cannot be concatenated"
356
- shapes = [s.shape[dim] for s in catargs]
357
- shape_cumsum = [0, *accumulate(shapes)]
358
- slc = [[(0, 0) for _ in self.shape] for _ in catargs]
359
- for shp,k,s in zip(shapes, shape_cumsum[:-1], slc):
360
- s[dim] = (k, shape_cumsum[-1] - k - shp)
361
- return reduce(Tensor.__add__, [arg.pad(tuple(s)) for arg,s in zip(catargs, slc)])
362
-
363
- @staticmethod
364
- def stack(tensors, dim=0):
365
- first = tensors[0].unsqueeze(dim)
366
- unsqueezed_tensors = [tensor.unsqueeze(dim) for tensor in tensors[1:]]
1010
+ cat_dims = [s.shape[dim] for s in catargs]
1011
+ cat_dim_cumsum = [0, *itertools.accumulate(cat_dims)]
1012
+ slc:List[List[Optional[Tuple[sint, sint]]]] = [[None for _ in self.shape] for _ in catargs]
1013
+ for d,k,s in zip(cat_dims, cat_dim_cumsum[:-1], slc): s[dim] = (k, cat_dim_cumsum[-1] - k - d)
1014
+ return functools.reduce(Tensor.__add__, [arg.pad(tuple(s)) for arg,s in zip(catargs, slc)])
1015
+
1016
+ def stack(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
1017
+ """
1018
+ Concatenates self with the other tensors in `args` along a new dimension specified by `dim`.
1019
+
1020
+ ```python exec="true" source="above" session="tensor" result="python"
1021
+ t0, t1, t2 = Tensor([1, 2]), Tensor([3, 4]), Tensor([5, 6])
1022
+ print(t0.stack(t1, t2, dim=0).numpy())
1023
+ ```
1024
+ ```python exec="true" source="above" session="tensor" result="python"
1025
+ print(t0.stack(t1, t2, dim=1).numpy())
1026
+ ```
1027
+ """
367
1028
  # checks for shapes and number of dimensions delegated to cat
368
- return first.cat(*unsqueezed_tensors, dim=dim)
369
-
370
- def repeat(self, repeats):
1029
+ return self.unsqueeze(dim).cat(*[t.unsqueeze(dim) for t in args], dim=dim)
1030
+
1031
+ def repeat(self, repeats, *args) -> Tensor:
1032
+ """
1033
+ Repeats the tensor along each dimension the number of times specified by `repeats`.
1034
+ `repeats` can be passed as a tuple or as separate arguments.
1035
+
1036
+ ```python exec="true" source="above" session="tensor" result="python"
1037
+ t = Tensor([1, 2, 3])
1038
+ print(t.repeat(4, 2).numpy())
1039
+ ```
1040
+ ```python exec="true" source="above" session="tensor" result="python"
1041
+ print(t.repeat(4, 2, 1).shape)
1042
+ ```
1043
+ """
1044
+ repeats = argfix(repeats, *args)
371
1045
  base_shape = (1,) * (len(repeats) - self.ndim) + self.shape
372
1046
  new_shape = [x for b in base_shape for x in [1, b]]
373
1047
  expand_shape = [x for rs in zip(repeats, base_shape) for x in rs]
374
1048
  final_shape = [r*s for r,s in zip(repeats, base_shape)]
375
1049
  return self.reshape(new_shape).expand(expand_shape).reshape(final_shape)
376
1050
 
377
- def chunk(self, num:int, dim:int) -> List[Tensor]:
378
- dim, step = dim + self.ndim if dim < 0 else dim, math.ceil(self.shape[dim]/num)
379
- slice_params = [[slice(None)]*dim + [slice(k, k + step)] for k in range(0, self.shape[dim], step)]
380
- return [self[tuple(sl)] for sl in slice_params]
381
-
382
- def squeeze(self, dim=None):
383
- if dim is None: return self if 1 not in self.shape else self.reshape(*[size for size in self.shape if size != 1])
384
- if dim <= 0 and self.ndim == 0: return self # This is to match PyTorch behavior
385
- if not -self.ndim <= dim < self.ndim: raise IndexError(f"Dimension out of range (expected to be in range of [{-self.ndim if self.ndim > 0 else self.ndim-1}, {self.ndim-1 if self.ndim > 0 else self.ndim}], but got {dim})")
386
- if dim < 0: dim += self.ndim
387
- return self if self.shape[dim] != 1 else self.reshape(*[size for idx, size in enumerate(self.shape) if idx != dim])
388
-
389
- def unsqueeze(self, dim):
390
- if dim < 0: dim = len(self.shape) + dim + 1
1051
+ def _resolve_dim(self, dim:int, *, outer:bool=False) -> int:
1052
+ if not -max(1, self.ndim+outer) <= dim < max(1, self.ndim+outer):
1053
+ raise IndexError(f"{dim=} out of range {[-max(1, self.ndim+outer), max(1, self.ndim+outer)-1]}")
1054
+ return dim + self.ndim+outer if dim < 0 else dim
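A short sketch of how `_resolve_dim` normalizes negative axes; the `outer=True` form is what `unsqueeze` uses to allow insertion one past the last dimension:

```python
from tinygrad import Tensor

t = Tensor.ones(2, 3, 4)
print(t._resolve_dim(-1))              # 2: last existing dim
print(t._resolve_dim(-1, outer=True))  # 3: one past the end, valid insertion point
print(t.unsqueeze(-1).shape)           # (2, 3, 4, 1)
```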
1055
+
1056
+ def split(self, sizes:Union[int, List[int]], dim:int=0) -> Tuple[Tensor, ...]:
1057
+ """
1058
+ Splits the tensor into chunks along the dimension specified by `dim`.
1059
+ If `sizes` is an integer, it splits into equally sized chunks if possible, otherwise the last chunk will be smaller.
1060
+ If `sizes` is a list, it splits into `len(sizes)` chunks with size in `dim` according to `sizes`.
1061
+
1062
+ ```python exec="true" source="above" session="tensor" result="python"
1063
+ t = Tensor.arange(10).reshape(5, 2)
1064
+ print(t.numpy())
1065
+ ```
1066
+ ```python exec="true" source="above" session="tensor" result="python"
1067
+ split = t.split(2)
1068
+ print("\\n".join([repr(x.numpy()) for x in split]))
1069
+ ```
1070
+ ```python exec="true" source="above" session="tensor" result="python"
1071
+ split = t.split([1, 4])
1072
+ print("\\n".join([repr(x.numpy()) for x in split]))
1073
+ ```
1074
+ """
1075
+ assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
1076
+ dim = self._resolve_dim(dim)
1077
+ if isinstance(sizes, int): sizes = [min(sizes, self.shape[dim]-i) for i in range(0, max(1, self.shape[dim]), max(1, sizes))]
1078
+ assert sum(sizes) == self.shape[dim], f"expect sizes to sum exactly to {self.shape[dim]}, but got {sum(sizes)}"
1079
+ return tuple(self[sl] for sl in [tuple([slice(None)]*dim + [slice(sum(sizes[:i]), sum(sizes[:i + 1]))]) for i in range(len(sizes))])
1080
+
1081
+ def chunk(self, chunks:int, dim:int=0) -> List[Tensor]:
1082
+ """
1083
+ Splits the tensor into `chunks` number of chunks along the dimension `dim`.
1084
+ If the tensor size along `dim` is not divisible by `chunks`, all returned chunks will be the same size except the last one.
1085
+ The function may return fewer than the specified number of chunks.
1086
+
1087
+ ```python exec="true" source="above" session="tensor" result="python"
1088
+ chunked = Tensor.arange(11).chunk(6)
1089
+ print("\\n".join([repr(x.numpy()) for x in chunked]))
1090
+ ```
1091
+ ```python exec="true" source="above" session="tensor" result="python"
1092
+ chunked = Tensor.arange(12).chunk(6)
1093
+ print("\\n".join([repr(x.numpy()) for x in chunked]))
1094
+ ```
1095
+ ```python exec="true" source="above" session="tensor" result="python"
1096
+ chunked = Tensor.arange(13).chunk(6)
1097
+ print("\\n".join([repr(x.numpy()) for x in chunked]))
1098
+ ```
1099
+ """
1100
+ assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
1101
+ assert chunks > 0, f"expect chunks to be greater than 0, got: {chunks}"
1102
+ dim = self._resolve_dim(dim)
1103
+ return list(self.split(math.ceil(self.shape[dim]/chunks) if self.shape[dim] else [0]*chunks, dim=dim))
1104
+
1105
+ def squeeze(self, dim:Optional[int]=None) -> Tensor:
1106
+ """
1107
+ Returns a tensor with the dimension specified by `dim` removed if it has size 1.
1108
+ If `dim` is not specified, all dimensions with size 1 are removed.
1109
+
1110
+ ```python exec="true" source="above" session="tensor" result="python"
1111
+ t = Tensor.zeros(2, 1, 2, 1, 2)
1112
+ print(t.squeeze().shape)
1113
+ ```
1114
+ ```python exec="true" source="above" session="tensor" result="python"
1115
+ print(t.squeeze(0).shape)
1116
+ ```
1117
+ ```python exec="true" source="above" session="tensor" result="python"
1118
+ print(t.squeeze(1).shape)
1119
+ ```
1120
+ """
1121
+ if dim is None: return self.reshape(tuple(dim for dim in self.shape if dim != 1))
1122
+ dim = self._resolve_dim(dim)
1123
+ return self if not self.ndim or self.shape[dim] != 1 else self.reshape(self.shape[:dim] + self.shape[dim+1:])
1124
+
1125
+ def unsqueeze(self, dim:int) -> Tensor:
1126
+ """
1127
+ Returns a tensor with a new dimension of size 1 inserted at the specified `dim`.
1128
+
1129
+ ```python exec="true" source="above" session="tensor" result="python"
1130
+ t = Tensor([1, 2, 3, 4])
1131
+ print(t.unsqueeze(0).numpy())
1132
+ ```
1133
+ ```python exec="true" source="above" session="tensor" result="python"
1134
+ print(t.unsqueeze(1).numpy())
1135
+ ```
1136
+ """
1137
+ dim = self._resolve_dim(dim, outer=True)
391
1138
  return self.reshape(self.shape[:dim] + (1,) + self.shape[dim:])
392
1139
 
393
- # (padding_left, padding_right, padding_top, padding_bottom)
394
- def pad2d(self, padding:Union[List[int], Tuple[int, ...]], value:float=0):
1140
+ def pad2d(self, padding:Sequence[int], value:float=0.0) -> Tensor:
1141
+ """
1142
+ Returns a tensor with the last two axes padded as specified by `padding` (padding_left, padding_right, padding_top, padding_bottom).
1143
+ If `value` is specified, the tensor is padded with `value` instead of `0.0`.
1144
+
1145
+ ```python exec="true" source="above" session="tensor" result="python"
1146
+ t = Tensor.arange(9).reshape(1, 1, 3, 3)
1147
+ print(t.numpy())
1148
+ ```
1149
+ ```python exec="true" source="above" session="tensor" result="python"
1150
+ print(t.pad2d((1, 1, 2, 0), value=-float("inf")).numpy())
1151
+ ```
1152
+ """
395
1153
  slc = [(-p0, s+p1) for p0,p1,s in zip(padding[::2], padding[1::2], self.shape[::-1])][::-1]
396
- return self.slice([(0,s) for s in self.shape[:-(len(padding)//2)]] + slc, value=value)
1154
+ return self._slice([(0,s) for s in self.shape[:-(len(padding)//2)]] + slc, value=value)
397
1155
 
398
1156
  @property
399
- def T(self) -> Tensor: return self.transpose()
400
- def transpose(self, ax1=1, ax2=0) -> Tensor:
401
- order = list(range(len(self.shape)))
402
- order[ax1], order[ax2] = order[ax2], order[ax1]
1157
+ def T(self) -> Tensor:
1158
+ """`.T` is an alias for `.transpose()`."""
1159
+ return self.transpose()
1160
+
1161
+ def transpose(self, dim0=1, dim1=0) -> Tensor:
1162
+ """
1163
+ Returns a tensor that is a transposed version of the original tensor.
1164
+ The given dimensions `dim0` and `dim1` are swapped.
1165
+
1166
+ ```python exec="true" source="above" session="tensor" result="python"
1167
+ t = Tensor.arange(6).reshape(2, 3)
1168
+ print(t.numpy())
1169
+ ```
1170
+ ```python exec="true" source="above" session="tensor" result="python"
1171
+ print(t.transpose(0, 1).numpy())
1172
+ ```
1173
+ """
1174
+ order = list(range(self.ndim))
1175
+ order[dim0], order[dim1] = order[dim1], order[dim0]
403
1176
  return self.permute(order)
404
- def flatten(self, start_dim=0): return self.reshape(shape=self.shape[:start_dim] + (-1,))
405
1177
 
406
- # ***** reduce ops *****
1178
+ def flatten(self, start_dim=0, end_dim=-1):
1179
+ """
1180
+ Flattens the tensor by reshaping it into a one-dimensional tensor.
1181
+ If `start_dim` or `end_dim` are passed, only dimensions starting with `start_dim` and ending with `end_dim` are flattened.
1182
+
1183
+ ```python exec="true" source="above" session="tensor" result="python"
1184
+ t = Tensor.arange(8).reshape(2, 2, 2)
1185
+ print(t.flatten().numpy())
1186
+ ```
1187
+ ```python exec="true" source="above" session="tensor" result="python"
1188
+ print(t.flatten(start_dim=1).numpy())
1189
+ ```
1190
+ """
1191
+ start_dim, end_dim = self._resolve_dim(start_dim), self._resolve_dim(end_dim)
1192
+ return self.reshape(self.shape[:start_dim] + (prod(self.shape[start_dim:end_dim+1]), ) + self.shape[end_dim+1:])
1193
+
1194
+ def unflatten(self, dim:int, sizes:Tuple[int,...]):
1195
+ """
1196
+ Expands dimension `dim` of the tensor over multiple dimensions specified by `sizes`.
1197
+
1198
+ ```python exec="true" source="above" session="tensor" result="python"
1199
+ print(Tensor.ones(3, 4, 1).unflatten(1, (2, 2)).shape)
1200
+ ```
1201
+ ```python exec="true" source="above" session="tensor" result="python"
1202
+ print(Tensor.ones(3, 4, 1).unflatten(1, (-1, 2)).shape)
1203
+ ```
1204
+ ```python exec="true" source="above" session="tensor" result="python"
1205
+ print(Tensor.ones(5, 12, 3).unflatten(-2, (2, 2, 3, 1, 1)).shape)
1206
+ ```
1207
+ """
1208
+ dim = self._resolve_dim(dim)
1209
+ return self.reshape(self.shape[:dim] + sizes + self.shape[dim+1:])
407
1210
 
408
- def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Tuple[int, ...]]]=None, keepdim=False):
409
- axis_: List[int] = list(range(len(self.shape))) if axis is None else ([axis] if axis.__class__ is int else list(axis)) # type: ignore
410
- axis_ = [x if x >= 0 else x+len(self.shape) for x in axis_]
411
- shape = [self.shape[i] for i in range(len(self.shape)) if i not in axis_]
412
- ret = fxn.apply(self, new_shape=tuple([1 if i in axis_ else self.shape[i] for i in range(len(self.shape))]))
413
- return ret if keepdim else ret.reshape(shape=shape)
1211
+ # ***** reduce ops *****
414
1212
 
415
- def sum(self, axis=None, keepdim=False): return self._reduce(mlops.Sum, axis, keepdim)
416
- def max(self, axis=None, keepdim=False): return self._reduce(mlops.Max, axis, keepdim)
417
- def min(self, axis=None, keepdim=False): return -((-self).max(axis=axis, keepdim=keepdim))
1213
+ def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Tuple[int, ...]]]=None, keepdim=False) -> Tensor:
1214
+ if self.ndim == 0:
1215
+ if axis is not None and axis not in [-1, 0]: raise IndexError(f"{axis=} out of range of [-1, 0]")
1216
+ axis = None
1217
+ axis_: Tuple[int, ...] = tuple(range(len(self.shape))) if axis is None else ((axis,) if isinstance(axis, int) else tuple(axis))
1218
+ axis_ = tuple(self._resolve_dim(x) for x in axis_)
1219
+ ret = fxn.apply(self, axis=axis_)
1220
+ return ret if keepdim else ret.reshape(tuple(s for i,s in enumerate(self.shape) if i not in axis_))
1221
+
1222
+ def sum(self, axis=None, keepdim=False, acc_dtype:Optional[DType]=None):
1223
+ """
1224
+ Sums the elements of the tensor along the specified axis or axes.
1225
+
1226
+ You can pass in `axis` and `keepdim` keyword arguments to control the axis along
1227
+ which the maximum is computed and whether the reduced dimensions are retained.
1228
+
1229
+ You can pass in `acc_dtype` keyword argument to control the data type of the accumulation.
1230
+ If not specified, the accumulation data type is chosen based on the input tensor's data type.
1231
+
1232
+ ```python exec="true" source="above" session="tensor" result="python"
1233
+ t = Tensor.arange(6).reshape(2, 3)
1234
+ print(t.numpy())
1235
+ ```
1236
+ ```python exec="true" source="above" session="tensor" result="python"
1237
+ print(t.sum().numpy())
1238
+ ```
1239
+ ```python exec="true" source="above" session="tensor" result="python"
1240
+ print(t.sum(axis=0).numpy())
1241
+ ```
1242
+ ```python exec="true" source="above" session="tensor" result="python"
1243
+ print(t.sum(axis=1).numpy())
1244
+ ```
1245
+ """
1246
+ ret = self.cast(acc_dtype or sum_acc_dtype(self.dtype))._reduce(F.Sum, axis, keepdim)
1247
+ return ret.cast(self.dtype) if self.dtype in {dtypes.float16, dtypes.bfloat16} else ret
1248
+ def max(self, axis=None, keepdim=False):
1249
+ """
1250
+ Returns the maximum value of the tensor along the specified axis or axes.
1251
+
1252
+ You can pass in `axis` and `keepdim` keyword arguments to control the axis along
1253
+ which the maximum is computed and whether the reduced dimensions are retained.
1254
+
1255
+ ```python exec="true" source="above" session="tensor" result="python"
1256
+ t = Tensor([[1, 0, 2], [5, 4, 3]])
1257
+ print(t.numpy())
1258
+ ```
1259
+ ```python exec="true" source="above" session="tensor" result="python"
1260
+ print(t.max().numpy())
1261
+ ```
1262
+ ```python exec="true" source="above" session="tensor" result="python"
1263
+ print(t.max(axis=0).numpy())
1264
+ ```
1265
+ ```python exec="true" source="above" session="tensor" result="python"
1266
+ print(t.max(axis=1, keepdim=True).numpy())
1267
+ ```
1268
+ """
1269
+ return self._reduce(F.Max, axis, keepdim)
1270
+ def min(self, axis=None, keepdim=False):
1271
+ """
1272
+ Returns the minimum value of the tensor along the specified axis or axes.
1273
+
1274
+ You can pass in `axis` and `keepdim` keyword arguments to control the axis along
1275
+ which the minimum is computed and whether the reduced dimensions are retained.
1276
+
1277
+ ```python exec="true" source="above" session="tensor" result="python"
1278
+ t = Tensor([[1, 0, 2], [5, 4, 3]])
1279
+ print(t.numpy())
1280
+ ```
1281
+ ```python exec="true" source="above" session="tensor" result="python"
1282
+ print(t.min().numpy())
1283
+ ```
1284
+ ```python exec="true" source="above" session="tensor" result="python"
1285
+ print(t.min(axis=0).numpy())
1286
+ ```
1287
+ ```python exec="true" source="above" session="tensor" result="python"
1288
+ print(t.min(axis=1, keepdim=True).numpy())
1289
+ ```
1290
+ """
1291
+ return -((-self).max(axis=axis, keepdim=keepdim))
418
1292
 
419
1293
  def mean(self, axis=None, keepdim=False):
420
- out = self.sum(axis=axis, keepdim=keepdim)
421
- return out * (math.prod(out.shape)/math.prod(self.shape))
422
- def std(self, axis=None, keepdim=False, correction=1):
1294
+ """
1295
+ Returns the mean value of the tensor along the specified axis or axes.
1296
+
1297
+ You can pass in `axis` and `keepdim` keyword arguments to control the axis along
1298
+ which the mean is computed and whether the reduced dimensions are retained.
1299
+
1300
+ ```python exec="true" source="above" session="tensor" result="python"
1301
+ Tensor.manual_seed(42)
1302
+ t = Tensor.normal(2, 3, mean=2.5, std=0.5)
1303
+ print(t.numpy())
1304
+ ```
1305
+ ```python exec="true" source="above" session="tensor" result="python"
1306
+ print(t.mean().numpy())
1307
+ ```
1308
+ ```python exec="true" source="above" session="tensor" result="python"
1309
+ print(t.mean(axis=0).numpy())
1310
+ ```
1311
+ ```python exec="true" source="above" session="tensor" result="python"
1312
+ print(t.mean(axis=1).numpy())
1313
+ ```
1314
+ """
1315
+ output_dtype = self.dtype if dtypes.is_float(self.dtype) else dtypes.float32
1316
+ numerator = self.cast(sum_acc_dtype(self.dtype)).sum(axis=axis, keepdim=keepdim)
1317
+ return numerator.div(prod([si for si, so in zip(self.shape, self.sum(axis=axis, keepdim=True).shape) if si != so])).cast(output_dtype)
1318
+
1319
+ def var(self, axis=None, keepdim=False, correction=1):
1320
+ """
1321
+ Returns the variance of the tensor along the specified axis or axes.
1322
+
1323
+ You can pass in `axis`, `keepdim`, and `correction` keyword arguments to control the axis along
1324
+ which the variance is computed, whether the reduced dimensions are retained, and the Bessel's correction applied.
1325
+
1326
+ ```python exec="true" source="above" session="tensor" result="python"
1327
+ Tensor.manual_seed(42)
1328
+ t = Tensor.normal(2, 3, mean=2.5, std=0.5)
1329
+ print(t.numpy())
1330
+ ```
1331
+ ```python exec="true" source="above" session="tensor" result="python"
1332
+ print(t.var().numpy())
1333
+ ```
1334
+ ```python exec="true" source="above" session="tensor" result="python"
1335
+ print(t.var(axis=0).numpy())
1336
+ ```
1337
+ ```python exec="true" source="above" session="tensor" result="python"
1338
+ print(t.var(axis=1).numpy())
1339
+ ```
1340
+ """
1341
+ assert all_int(self.shape), "does not support symbolic shape"
423
1342
  square_sum = ((self - self.mean(axis=axis, keepdim=True)).square()).sum(axis=axis, keepdim=keepdim)
424
- return (square_sum / (math.prod(self.shape)/math.prod(square_sum.shape)-correction)).sqrt()
1343
+ return square_sum.div(max(0, prod(self.shape)/prod(square_sum.shape)-correction))
1344
+
1345
+ def std(self, axis=None, keepdim=False, correction=1):
1346
+ """
1347
+ Returns the standard deviation of the tensor along the specified axis or axes.
1348
+
1349
+ You can pass in `axis`, `keepdim`, and `correction` keyword arguments to control the axis along
1350
+ which the standard deviation is computed, whether the reduced dimensions are retained, and the Bessel's correction applied.
1351
+
1352
+ ```python exec="true" source="above" session="tensor" result="python"
1353
+ Tensor.manual_seed(42)
1354
+ t = Tensor.normal(2, 3, mean=2.5, std=0.5)
1355
+ print(t.numpy())
1356
+ ```
1357
+ ```python exec="true" source="above" session="tensor" result="python"
1358
+ print(t.std().numpy())
1359
+ ```
1360
+ ```python exec="true" source="above" session="tensor" result="python"
1361
+ print(t.std(axis=0).numpy())
1362
+ ```
1363
+ ```python exec="true" source="above" session="tensor" result="python"
1364
+ print(t.std(axis=1).numpy())
1365
+ ```
1366
+ """
1367
+ return self.var(axis, keepdim, correction).sqrt()
1368
+
425
1369
  def _softmax(self, axis):
426
1370
  m = self - self.max(axis=axis, keepdim=True)
427
1371
  e = m.exp()
428
1372
  return m, e, e.sum(axis=axis, keepdim=True)
429
1373
 
430
1374
  def softmax(self, axis=-1):
1375
+ """
1376
+ Applies the softmax function to the tensor along the specified axis.
1377
+
1378
+ Rescales the elements of the tensor such that they lie in the range [0, 1] and sum to 1.
1379
+
1380
+ You can pass in the `axis` keyword argument to control the axis along which the softmax is computed.
1381
+
1382
+ ```python exec="true" source="above" session="tensor" result="python"
1383
+ Tensor.manual_seed(42)
1384
+ t = Tensor.randn(2, 3)
1385
+ print(t.numpy())
1386
+ ```
1387
+ ```python exec="true" source="above" session="tensor" result="python"
1388
+ print(t.softmax().numpy())
1389
+ ```
1390
+ ```python exec="true" source="above" session="tensor" result="python"
1391
+ print(t.softmax(axis=0).numpy())
1392
+ ```
1393
+ """
431
1394
  _, e, ss = self._softmax(axis)
432
1395
  return e.div(ss)
433
1396
 
434
1397
  def log_softmax(self, axis=-1):
1398
+ """
1399
+ Applies the log-softmax function to the tensor along the specified axis.
1400
+
1401
+ The log-softmax function is a numerically stable alternative to the softmax function in log space.
1402
+
1403
+ You can pass in the `axis` keyword argument to control the axis along which the log-softmax is computed.
1404
+
1405
+ ```python exec="true" source="above" session="tensor" result="python"
1406
+ Tensor.manual_seed(42)
1407
+ t = Tensor.randn(2, 3)
1408
+ print(t.numpy())
1409
+ ```
1410
+ ```python exec="true" source="above" session="tensor" result="python"
1411
+ print(t.log_softmax().numpy())
1412
+ ```
1413
+ ```python exec="true" source="above" session="tensor" result="python"
1414
+ print(t.log_softmax(axis=0).numpy())
1415
+ ```
1416
+ """
435
1417
  m, _, ss = self._softmax(axis)
436
1418
  return m - ss.log()
437
1419
 
1420
+ def logsumexp(self, axis=None, keepdim=False):
1421
+ """
1422
+ Computes the log-sum-exp of the tensor along the specified axis or axes.
1423
+
1424
+ The log-sum-exp function is a numerically stable way to compute the logarithm of the sum of exponentials.
1425
+
1426
+ You can pass in `axis` and `keepdim` keyword arguments to control the axis along
1427
+ which the log-sum-exp is computed and whether the reduced dimensions are retained.
1428
+
1429
+ ```python exec="true" source="above" session="tensor" result="python"
1430
+ Tensor.manual_seed(42)
1431
+ t = Tensor.randn(2, 3)
1432
+ print(t.numpy())
1433
+ ```
1434
+ ```python exec="true" source="above" session="tensor" result="python"
1435
+ print(t.logsumexp().numpy())
1436
+ ```
1437
+ ```python exec="true" source="above" session="tensor" result="python"
1438
+ print(t.logsumexp(axis=0).numpy())
1439
+ ```
1440
+ ```python exec="true" source="above" session="tensor" result="python"
1441
+ print(t.logsumexp(axis=1).numpy())
1442
+ ```
1443
+ """
1444
+ m = self.max(axis=axis, keepdim=True)
1445
+ return (self - m).exp().sum(axis=axis, keepdim=keepdim).log() + m.squeeze(axis)
1446
+
438
1447
  def argmax(self, axis=None, keepdim=False):
1448
+ """
1449
+ Returns the indices of the maximum value of the tensor along the specified axis.
1450
+
1451
+ You can pass in `axis` and `keepdim` keyword arguments to control the axis along
1452
+ which the maximum is computed and whether the reduced dimensions are retained.
1453
+
1454
+ ```python exec="true" source="above" session="tensor" result="python"
1455
+ t = Tensor([[1, 0, 2], [5, 4, 3]])
1456
+ print(t.numpy())
1457
+ ```
1458
+ ```python exec="true" source="above" session="tensor" result="python"
1459
+ print(t.argmax().numpy()) # Returns the index of the maximum value in the flattened tensor.
1460
+ ```
1461
+ ```python exec="true" source="above" session="tensor" result="python"
1462
+ print(t.argmax(axis=0).numpy()) # Returns the indices of the maximum values along axis 0.
1463
+ ```
1464
+ ```python exec="true" source="above" session="tensor" result="python"
1465
+ print(t.argmax(axis=1).numpy()) # Returns the indices of the maximum values along axis 1.
1466
+ ```
1467
+ """
439
1468
  if axis is None:
440
- idx = (self == self.max(axis)) * Tensor.arange(math.prod(self.shape)-1,-1,-1, dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(self.shape)
441
- return math.prod(self.shape) - idx.max() - 1
442
- axis = axis + len(self.shape) if axis < 0 else axis
1469
+ idx = (self == self.max(axis)) * Tensor.arange(prod(self.shape)-1,-1,-1, requires_grad=False, device=self.device).reshape(self.shape)
1470
+ return (prod(self.shape) - idx.max() - 1).cast(dtypes.int32)
1471
+ axis = self._resolve_dim(axis)
443
1472
  m = self == self.max(axis=axis, keepdim=True)
444
- idx = m * Tensor.arange(self.shape[axis]-1,-1,-1, dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(self.shape[axis], *[1]*(self.ndim-axis-1))
445
- return self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)-1
446
- def argmin(self, axis=None, keepdim=False): return (-self).argmax(axis=axis, keepdim=keepdim)
1473
+ idx = m * Tensor.arange(self.shape[axis]-1,-1,-1, requires_grad=False, device=self.device).reshape(self.shape[axis], *[1]*(self.ndim-axis-1))
1474
+ return (self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)-1).cast(dtypes.int32)
1475
+
1476
+ def argmin(self, axis=None, keepdim=False):
1477
+ """
1478
+ Returns the indices of the minimum value of the tensor along the specified axis.
1479
+
1480
+ You can pass in `axis` and `keepdim` keyword arguments to control the axis along
1481
+ which the minimum is computed and whether the reduced dimensions are retained.
1482
+
1483
+ ```python exec="true" source="above" session="tensor" result="python"
1484
+ t = Tensor([[1, 0, 2], [5, 4, 3]])
1485
+ print(t.numpy())
1486
+ ```
1487
+ ```python exec="true" source="above" session="tensor" result="python"
1488
+ print(t.argmin().numpy()) # Returns the index of the minimum value in the flattened tensor.
1489
+ ```
1490
+ ```python exec="true" source="above" session="tensor" result="python"
1491
+ print(t.argmin(axis=0).numpy()) # Returns the indices of the minimum values along axis 0.
1492
+ ```
1493
+ ```python exec="true" source="above" session="tensor" result="python"
1494
+ print(t.argmin(axis=1).numpy()) # Returns the indices of the minimum values along axis 1.
1495
+ ```
1496
+ """
1497
+ return (-self).argmax(axis=axis, keepdim=keepdim)
1498
+
1499
+ @staticmethod
1500
+ def einsum(formula:str, *raw_xs, acc_dtype:Optional[DType]=None) -> Tensor:
1501
+ """
1502
+ Sums the product of the elements of the input tensors according to a formula based on the Einstein summation convention.
1503
+
1504
+ See: https://pytorch.org/docs/stable/generated/torch.einsum.html
1505
+
1506
+ ```python exec="true" source="above" session="tensor" result="python"
1507
+ x = Tensor([[1, 2], [3, 4]])
1508
+ y = Tensor([[5, 6], [7, 8]])
1509
+ print(Tensor.einsum("ij,ij->", x, y).numpy())
1510
+ ```
1511
+ """
1512
+ xs:Tuple[Tensor] = argfix(*raw_xs)
1513
+ formula = formula.replace(" ", "")
1514
+ inputs_str, output = formula.split("->") if "->" in formula else (formula, sorted(formula))
1515
+ inputs = [x for x in cast(str,inputs_str).split(',')]
1516
+ assert len(xs) == len(inputs), f"number of inputs doesn't match number of operands in formula, expected {len(inputs)}, got {len(xs)}"
1517
+
1518
+ # map the value of each letter in the formula
1519
+ letter_val = sorted(merge_dicts([{letter:dim for letter, dim in zip(letters, tensor.shape)} for letters, tensor in zip(inputs, xs)]).items())
1520
+
1521
+ xs_:List[Tensor] = []
1522
+ lhs = [sorted(enumerate(s), key=lambda e:e[1]) for s in inputs]
1523
+ for x,(order,letters) in zip(xs, [list(zip(*l)) for l in lhs]):
1524
+ # permute to the sorted letter order, then reshape/expand to create dimensions for the missing letters
1525
+ xs_.append(x.permute(order).reshape([val if letter in letters else 1 for letter,val in letter_val]).expand([val for _,val in letter_val]))
1526
+
1527
+ # determine the inverse permutation to revert back to original order
1528
+ rhs_letter_order = argsort(list(output))
1529
+ rhs_order = argsort(rhs_letter_order)
1530
+
1531
+ # sum over all axes that's not in the output, then permute to the output order
1532
+ return functools.reduce(lambda a,b:a*b, xs_) \
1533
+ .sum(axis=[axis for axis,(letter,_) in enumerate(letter_val) if letter not in output],acc_dtype=acc_dtype).permute(rhs_order)
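To make the letter-mapping above concrete, a couple of extra `einsum` sketches beyond the docstring's inner-product example (illustrative only):

```python
from tinygrad import Tensor

x = Tensor([[1, 2], [3, 4]])
y = Tensor([[5, 6], [7, 8]])
print(Tensor.einsum("ij,jk->ik", x, y).numpy())  # matrix multiply: the shared letter j is summed away
print(Tensor.einsum("ij->ji", x).numpy())        # pure permutation: no letters are reduced
```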
447
1534
 
448
1535
  # ***** processing ops *****
449
1536
 
450
- def _pool(self, k_:Tuple[int, ...], stride:Union[Tuple[int, ...], int]=1, dilation:Union[Tuple[int, ...], int]=1) -> Tensor:
1537
+ def _pool(self, k_:Tuple[sint, ...], stride:Union[Tuple[int, ...], int]=1, dilation:Union[Tuple[int, ...], int]=1) -> Tensor:
451
1538
  assert len(self.shape) >= len(k_), f"can't pool {self.shape} with {k_}"
1539
+ assert all_int(self.shape) and all_int(k_), f"does not support symbolic {self.shape=}, {k_=}"
452
1540
  s_, d_ = make_pair(stride, len(k_)), make_pair(dilation, len(k_))
453
- assert len(k_) == len(s_) and len(k_) == len(d_), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}"
454
- slc_prefix, prefix, i_ = [(0,x) for x in self.shape[0:-len(k_)]], self.shape[0:-len(k_)], self.shape[-len(k_):]
1541
+ assert len(k_) == len(s_) == len(d_), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}"
1542
+ noop_, i_ = [None] * len(self.shape[:-len(k_)]), self.shape[-len(k_):]
455
1543
  if any(k > s for k,s in zip(k_, s_)) or any(d != 1 for d in d_):
456
1544
  o_ = [(i - d * (k-1) - 1)//s + 1 for i,d,k,s in zip(i_, d_, k_, s_)]
457
- e_ = [math.ceil(k*(i+d) / i) for k,i,d in zip(k_, i_, d_)] # expands such that we don't need padding
458
- xup = self.reshape(*prefix, *flatten((1,i) for i in i_)).expand(*prefix, *flatten((e,i) for e,i in zip(e_, i_))).reshape(*prefix, *[e*i for e,i in zip(e_, i_)])
459
- # slide by dilation
460
- xup = xup.slice(slc_prefix + [(0,k*(i+d)) for k,i,d in zip(k_, i_, d_)])
461
- xup = xup.reshape(*prefix, *flatten((k,i+d) for k,i,d in zip(k_, i_, d_)))
462
- xup = xup.slice(slc_prefix + flatten(((0,k), (0,o*s)) for k,o,s in zip(k_, o_, s_)))
463
- # handle stride, and permute to move reduce to the end
464
- xup = xup.reshape(*prefix, *flatten((k,o,s) for k,o,s in zip(k_, o_, s_)))
465
- xup = xup.slice(slc_prefix + flatten(((0,k), (0,o), (0,1)) for k,o in zip(k_, o_)))
466
- xup = xup.reshape(*prefix, *flatten((k,o) for k,o in zip(k_, o_)))
467
- return xup.permute(*range(len(prefix)), *[len(prefix)+i*2+1 for i in range(len(k_))], *[len(prefix)+i*2 for i in range(len(k_))])
1545
+ # repeats such that we don't need padding
1546
+ xup = self.repeat([1]*len(noop_) + [math.ceil(k*(i+d) / i) for k,i,d in zip(k_, i_, d_)])
1547
+ # slice by dilation
1548
+ xup = xup.shrink(tuple(noop_ + [(0,k*(i+d)) for k,i,d in zip(k_, i_, d_)])).reshape(noop_ + flatten((k,i+d) for k,i,d in zip(k_, i_, d_)))
1549
+ # handle stride
1550
+ xup = xup.shrink(noop_ + flatten(((0,k), (0,o*s)) for k,o,s in zip(k_, o_, s_))).reshape(noop_ + flatten((k,o,s) for k,o,s in zip(k_, o_, s_)))
1551
+ xup = xup.shrink(noop_ + flatten(((0,k), (0,o), (0,1)) for k,o in zip(k_, o_))).reshape(noop_ + flatten((k,o) for k,o in zip(k_, o_)))
1552
+ # permute to move reduce to the end
1553
+ return xup.permute(*range(len(noop_)), *[len(noop_)+i*2+1 for i in range(len(i_))], *[len(noop_)+i*2 for i in range(len(i_))])
468
1554
  # TODO: once the shapetracker can optimize well, remove this alternative implementation. or not if the CPU implementation doesn't use ShapeTracker
469
1555
  o_ = [(i+(s-k))//s for i,s,k in zip(i_, s_, k_)]
470
- xup = self.slice(slc_prefix + [(0,o*s) for o,s in zip(o_, s_)])
471
- xup = xup.reshape(*prefix, *flatten(((o, s) for o,s in zip(o_, s_))))
472
- xup = xup.slice(slc_prefix + flatten(((0,o), (0,k)) for o,k in zip(o_, k_)))
473
- return xup.permute(*range(len(prefix)), *[len(prefix)+i*2 for i in range(len(k_))], *[len(prefix)+i*2+1 for i in range(len(k_))])
1556
+ xup = self.pad(tuple(noop_ + [(0, max(0,o*s-i)) for i,o,s in zip(i_, o_, s_)])).shrink(tuple(noop_ + [(0,o*s) for o,s in zip(o_, s_)]))
1557
+ xup = xup.reshape(noop_ + flatten(((o,s) for o,s in zip(o_, s_))))
1558
+ xup = xup.shrink(noop_ + flatten(((0,o), (0,k)) for o,k in zip(o_, k_)))
1559
+ return xup.permute(*range(len(noop_)), *[len(noop_)+i*2 for i in range(len(i_))], *[len(noop_)+i*2+1 for i in range(len(i_))])
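A shape-only sketch of what `_pool` produces (it is the shared window-extraction helper behind `avg_pool2d`, `max_pool2d`, and `conv2d`): leading dims are untouched, the output spatial dims come next, and the kernel dims land at the end.

```python
from tinygrad import Tensor

t = Tensor.arange(16).reshape(1, 1, 4, 4)
# 2x2 windows with stride 2 (the default layout avg_pool2d/max_pool2d use)
print(t._pool((2, 2), stride=2).shape)  # (1, 1, 2, 2, 2, 2)
# stride 1 gives overlapping windows
print(t._pool((2, 2), stride=1).shape)  # (1, 1, 3, 3, 2, 2)
```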
474
1560
 
475
1561
  # NOTE: these work for more than 2D
476
- def avg_pool2d(self, kernel_size=(2,2), stride=None): return self._pool(make_pair(kernel_size), stride if stride is not None else kernel_size).mean(axis=tuple(range(0-len(make_pair(kernel_size)), 0)))
477
- def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1): return self._pool(make_pair(kernel_size), stride if stride is not None else kernel_size, dilation).max(axis=tuple(range(0-len(make_pair(kernel_size)), 0)))
478
-
479
- def conv_transpose2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, output_padding=0) -> Tensor:
480
- HW, trailing = weight.shape[2:], list(range(3, len(weight.shape)+1))
481
- x, w = self, weight.reshape(groups, weight.shape[0]//groups, weight.shape[1], *weight.shape[2:]).permute(0,2,1,*trailing).flip(trailing)
482
- stride = make_pair(stride, len(HW))
483
- if any(s>1 for s in stride):
484
- x = x.reshape(*x.shape[:2], *flatten((k,1) for k in x.shape[2:]))
485
- x = x.pad(((0,0), (0,0), *flatten(((0,0),(0,s-1)) for s in stride)))
486
- x = x.reshape(*x.shape[:2], *[k*s for k,s in zip(x.shape[2::2], stride)])
487
- x = x.shrink(((0,x.shape[0]), (0,x.shape[1]), *[(0,k-(s-1)) for k,s in zip(x.shape[2:], stride)]))
488
- padding = flatten((((k-1)*d-p,(k-1)*d-p+op) for k,d,p,op in reversed(list(zip(HW, make_pair(dilation, len(HW)), make_pair(padding, len(HW)), make_pair(output_padding, len(HW)))))))
489
- return x.conv2d(w.reshape(w.shape[0]*w.shape[1],*w.shape[2:]), groups=groups, bias=bias, dilation=dilation, padding=padding)
490
-
491
- def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0) -> Tensor:
1562
+ def avg_pool2d(self, kernel_size=(2,2), stride=None, dilation=1):
1563
+ """
1564
+ Applies average pooling over a tensor.
1565
+
1566
+ NOTE: unlike PyTorch, this implementation is not limited to only 2d pooling and instead works for any number of dimensions.
1567
+
1568
+ See: https://paperswithcode.com/method/average-pooling
1569
+
1570
+ ```python exec="true" source="above" session="tensor" result="python"
1571
+ t = Tensor.arange(25).reshape(1, 1, 5, 5)
1572
+ print(t.avg_pool2d().numpy())
1573
+ ```
1574
+ """
1575
+ return self._pool(
1576
+ make_pair(kernel_size), stride if stride is not None else kernel_size, dilation).mean(axis=tuple(range(0-len(make_pair(kernel_size)), 0)))
1577
+ def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1):
1578
+ """
1579
+ Applies max pooling over a tensor.
1580
+
1581
+ NOTE: unlike PyTorch, this implementation is not limited to only 2d pooling and instead works for any number of dimensions.
1582
+
1583
+ See: https://paperswithcode.com/method/max-pooling
1584
+
1585
+ ```python exec="true" source="above" session="tensor" result="python"
1586
+ t = Tensor.arange(25).reshape(1, 1, 5, 5)
1587
+ print(t.max_pool2d().numpy())
1588
+ ```
1589
+ """
1590
+ return self._pool(
1591
+ make_pair(kernel_size), stride if stride is not None else kernel_size, dilation).max(axis=tuple(range(0-len(make_pair(kernel_size)), 0)))
1592
+
1593
+ def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, acc_dtype:Optional[DType]=None) -> Tensor:
1594
+ """
1595
+ Applies a convolution over a tensor with a given `weight` and optional `bias`.
1596
+
1597
+ NOTE: unlike PyTorch, this implementation is not limited to only 2d convolutions and instead works for any number of dimensions.
1598
+
1599
+ See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
1600
+
1601
+ ```python exec="true" source="above" session="tensor" result="python"
1602
+ t = Tensor.arange(9).reshape(1, 1, 3, 3)
1603
+ w = Tensor.ones(1, 1, 2, 2)
1604
+ print(t.conv2d(w).numpy())
1605
+ ```
1606
+ """
492
1607
  (bs,cin_), (cout,cin), HW = self.shape[:2], weight.shape[:2], weight.shape[2:]
493
- assert groups*cin == cin_ and len(self.shape) == len(weight.shape), f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})"
494
- if isinstance(padding, (tuple,list)): assert len(padding) == 2*len(HW) or len(padding) == len(HW), f"Expected padding of length {2*len(HW)} or {len(HW)}, but got {len(padding)} for tensor of shape {self.shape}"
495
- padding_ = [padding]*2*len(HW) if isinstance(padding, int) else (padding if len(padding) == 2*len(HW) else [p for p in padding for _ in range(2)][::-1])
1608
+ assert groups*cin == cin_ and len(self.shape) == len(weight.shape), f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})" # noqa: E501
1609
+ if isinstance(padding, (tuple,list)): assert len(padding) == 2*len(HW) or len(padding) == len(HW), f"Expected padding of length {2*len(HW)} or {len(HW)}, but got {len(padding)} for tensor of shape {self.shape}" # noqa: E501
1610
+ padding_ = [padding]*2*len(HW) if isinstance(padding, int) else (padding if len(padding) == 2*len(HW) else [p for p in padding for _ in range(2)][::-1]) # noqa: E501
496
1611
 
497
1612
  # conv2d is a pooling op (with padding)
498
1613
  x = self.pad2d(padding_)._pool(HW, stride, dilation) # (bs, groups*cin, oy, ox, H, W)
499
1614
  rcout, oyx = cout//groups, x.shape[2:-len(HW)]
500
- x = x.reshape(bs, groups, cin, 1, *oyx, *HW).expand(bs, groups, cin, rcout, *oyx, *HW).permute(0,1,3,*[4+i for i in range(len(oyx))],2,*[4+len(oyx)+i for i in range(len(HW))])
1615
+ if not all(x == 3 for x in HW) or stride != 1 or dilation != 1 or not WINO:
1616
+ # normal conv
1617
+ x = x.reshape(bs, groups, cin, 1, *oyx, *HW).expand(bs, groups, cin, rcout, *oyx, *HW).permute(0,1,3,*[4+i for i in range(len(oyx))],2,*[4+len(oyx)+i for i in range(len(HW))]) # noqa: E501
1618
+
1619
+ # conv! broadcasted to (bs, groups, rcout, *oyx, cin, *HW)
1620
+ ret = (x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW)).sum([-1-i for i in range(1+len(oyx))], keepdim=True, acc_dtype=acc_dtype).reshape(bs, cout, *oyx) # noqa: E501
1621
+ return ret if bias is None else ret.add(bias.reshape(1, -1, *[1] * len(HW)))
1622
+
1623
+ HWI, HWO = (6,) * len(HW), (4,) * len(HW) # F(4x4,3x3) winograd tiles
1624
+ winograd_G = [[1/4, 0, 0], [-1/6, -1/6, -1/6], [-1/6, 1/6, -1/6], [1/24, 1/12, 1/6], [1/24, -1/12, 1/6], [0, 0, 1]]
1625
+ winograd_Bt = [[4, 0, -5, 0, 1, 0], [0, -4, -4, 1, 1, 0], [0, 4, -4, -1, 1, 0], [0, -2, -1, 2, 1, 0], [0, 2, -1, -2, 1, 0], [0, 4, 0, -5, 0, 1]]
1626
+ winograd_At = [[1, 1, 1, 1, 1, 0], [0, 1, -1, 2, -2, 0], [0, 1, 1, 4, 4, 0], [0, 1, -1, 8, -8, 1]] # applying At in pre-order doubles compile time
1627
+
1628
+ # todo: stride == dilation
1629
+ # use padding to round up to 4x4 output tiles
1630
+ # (bs, cin_, tyx, HWI)
1631
+ d = self.pad2d(sum([[padding_[i*2], padding_[i*2+1] + (-(dim + sum(padding_[i * 2:(i + 1) * 2]) - 2) % 4)] for i, dim in enumerate(self.shape[-len(HW):])], []))._pool(HWI, HWO) # noqa: E501
1632
+ # move HW to the front: # (HWI, bs, cin_, tyx)
1633
+ d = d.permute(*range(len(d.shape)-len(HW),len(d.shape)), *range(len(d.shape)-len(HW)))
1634
+ tyx = d.shape[-len(HWI):] # dim of tiling
1635
+
1636
+ g = weight.permute(*range(len(weight.shape)-len(HW),len(weight.shape)), *range(len(weight.shape)-len(HW))) # move HW to the front
1637
+
1638
+ # compute 6x6 winograd tiles: GgGt, BtdB
1639
+ # (HWI, groups * rcout, cin) -> (HWI, bs=1, groups, rcout, cin, tyx=(1,1))
1640
+ gfactors = _apply_winograd_matrix(winograd_G, g, len(HW)).reshape(*HWI, 1, groups, rcout, cin, *([1]*len(tyx)))
1641
+ # (HWI, bs, cin_, tyx) -> (HWI, bs, groups, 1 ,cin, *tyx)
1642
+ dfactors = _apply_winograd_matrix(winograd_Bt, d, len(HW)).reshape(*HWI, bs, groups, 1, cin, *tyx)
1643
+
1644
+ # matmul; sum across cin: (HWI, bs, groups, rcout, *tyx); then HWI -> HWO: (HWO, bs, groups, rcout, *tyx)
1645
+ ret = _apply_winograd_matrix(winograd_At, (gfactors * dfactors).sum(axis=-1-len(HW), acc_dtype=acc_dtype), len(HW))
1646
+
1647
+ # interleave tyx and HWO: (bs, groups, rcout, oy, HO, ox, WO)
1648
+ ret = ret.permute([*range(len(HW), len(ret.shape)-len(HW)), *[i+o for i in range(len(HW)) for o in [len(ret.shape)-len(HW),0]]])
1649
+ # merge groups and rcout, tyx and HWO: (bs, groups, cout, *yx), shrink to final
1650
+ ret = ret.reshape(bs, cout, *[c * HWO[i] for i, c in enumerate(tyx)]).shrink(tuple((0, s) for s in [bs, cout, *oyx]))
1651
+
1652
+ return (ret if bias is None else ret.add(bias.reshape(1, -1, *[1 for _ in range(len(HW))]))).contiguous().contiguous_backward()
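For orientation, a small sketch of a `conv2d` call that can take the Winograd branch above: 3x3 kernel, stride 1, dilation 1, with the `WINO` flag enabled (assuming the flag is exposed as the `WINO=1` environment variable, as elsewhere in tinygrad); the normal path is taken otherwise.

```python
from tinygrad import Tensor

x = Tensor.randn(1, 4, 8, 8)  # (bs, cin, H, W)
w = Tensor.randn(8, 4, 3, 3)  # (cout, cin, 3, 3)
out = x.conv2d(w, padding=1)  # padding=1 keeps the spatial size with a 3x3 kernel
print(out.shape)              # (1, 8, 8, 8)
```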
1653
+
1654
+ def conv_transpose2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, output_padding=0) -> Tensor:
1655
+ """
1656
+ Applies a transposed convolution over a tensor with a given `weight` and optional `bias`.
1657
+
1658
+ NOTE: unlike PyTorch, this implementation is not limited to only 2d transposed convolutions and instead works for any number of dimensions.
501
1659
 
502
- # conv! broadcasted to (bs, groups, rcout, *oyx, cin, *HW)
503
- ret = (x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW)).sum([-1-i for i in range(1+len(oyx))], keepdim=True).reshape(bs, cout, *oyx)
504
- return ret if bias is None else ret.add(bias.reshape(1, -1, *[1] * len(HW)))
1660
+ See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
505
1661
 
506
- def dot(self, w:Tensor) -> Tensor:
1662
+ ```python exec="true" source="above" session="tensor" result="python"
1663
+ t = Tensor.arange(9).reshape(1, 1, 3, 3)
1664
+ w = Tensor.ones(1, 1, 2, 2)
1665
+ print(t.conv_transpose2d(w).numpy())
1666
+ ```
1667
+ """
1668
+ HW, trailing = weight.shape[2:], list(range(3, len(weight.shape)+1))
1669
+ x, w = self, weight.unflatten(0, (groups, -1)).permute(0,2,1,*trailing).flip(trailing)
1670
+ stride = make_pair(stride, len(HW))
1671
+ if any(s>1 for s in stride):
1672
+ x = x.reshape(None, None, *flatten((k,1) for k in x.shape[2:]))
1673
+ x = x.pad((None, None, *flatten((None,(0,s-1)) for s in stride)))
1674
+ x = x.reshape(None, None, *[k*s for k,s in zip(x.shape[2::2], stride)])
1675
+ x = x.shrink((None, None, *[(0,k-(s-1)) for k,s in zip(x.shape[2:], stride)]))
1676
+ padding = flatten((((k-1)*d-p,(k-1)*d-p+op) for k,d,p,op in reversed(list(
1677
+ zip(HW, make_pair(dilation, len(HW)), make_pair(padding, len(HW)), make_pair(output_padding, len(HW)))))))
1678
+ return x.conv2d(w.flatten(end_dim=1), groups=groups, bias=bias, dilation=dilation, padding=padding)
1679
+
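The stride handling above implements transposed convolution by zero-stuffing the input (the reshape/pad/shrink sequence inserts `stride-1` zeros between samples), flipping the kernel, swapping the channel axes, and calling the regular `conv2d` with adjusted padding. A minimal 1D NumPy sketch of that equivalence (the helper names are illustrative, not tinygrad APIs):

```python
import numpy as np

def conv_transpose1d_ref(x, w, stride):
    # direct definition: y[j] = sum_i x[i] * w[j - i*stride]
    n, k = len(x), len(w)
    y = np.zeros((n - 1) * stride + k)
    for i, xi in enumerate(x):
        y[i*stride:i*stride+k] += xi * w
    return y

def conv_transpose1d_via_conv(x, w, stride):
    # zero-stuff the input (the reshape/pad/shrink trick above), then a plain full convolution
    up = np.zeros((len(x) - 1) * stride + 1)
    up[::stride] = x
    return np.convolve(up, w)  # full conv == pad by k-1 and correlate with the flipped kernel

x, w = np.array([1., 2., 3.]), np.array([1., 10., 100.])
assert np.allclose(conv_transpose1d_ref(x, w, 2), conv_transpose1d_via_conv(x, w, 2))
```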
1680
+ def dot(self, w:Tensor, acc_dtype:Optional[DType]=None) -> Tensor:
1681
+ """
1682
+ Performs dot product between two tensors.
1683
+
1684
+ ```python exec="true" source="above" session="tensor" result="python"
1685
+ a = Tensor([[1, 2], [3, 4]])
1686
+ b = Tensor([[5, 6], [7, 8]])
1687
+ print(a.dot(b).numpy())
1688
+ ```
1689
+ """
507
1690
  n1, n2 = len(self.shape), len(w.shape)
508
1691
  assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D"
509
- assert self.shape[-1] == w.shape[-min(n2, 2)], f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({self.shape[-1]} != {w.shape[-min(n2, 2)]})"
1692
+ assert (L:=self.shape[-1]) == (R:=w.shape[-min(n2, 2)]), f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({L} != {R})"
510
1693
  x = self.reshape(*self.shape[0:-1], *[1]*min(n1-1, n2-1, 1), self.shape[-1])
511
1694
  w = w.reshape(*w.shape[0:-2], *[1]*min(n1-1, n2-1, 1), *w.shape[-min(n2, 2):]).transpose(-1, -min(n2, 2))
512
- return (x*w).sum(-1)
513
-
514
- def cumsum(self, axis:int=0) -> Tensor: return self.transpose(axis,-1).pad2d((self.shape[axis]-1,0))._pool((self.shape[axis],)).sum(-1).transpose(axis,-1)
515
-
516
- # ***** mlops (unary) *****
517
-
518
- def contiguous(self): return mlops.Contiguous.apply(self)
519
- def log(self): return mlops.Log.apply(self)
520
- def log2(self): return mlops.Log.apply(self)/math.log(2)
521
- def exp(self): return mlops.Exp.apply(self)
522
- def relu(self): return mlops.Relu.apply(self)
523
- def sigmoid(self): return mlops.Sigmoid.apply(self)
524
- def sin(self): return mlops.Sin.apply(self)
525
- def sqrt(self): return mlops.Sqrt.apply(self)
526
- def rsqrt(self): return (1/self).sqrt()
527
- def cos(self): return ((math.pi/2)-self).sin()
528
- def tan(self): return self.sin() / self.cos()
1695
+ return (x*w).sum(-1, acc_dtype=acc_dtype).cast(least_upper_dtype(x.dtype, w.dtype))
1696
+
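`dot` reduces matrix multiplication to a broadcasted elementwise product followed by a sum over the shared axis: the last axis of `self` is lined up against the second-to-last axis of `w`. A NumPy sketch of the same shape manipulation for the 2D case, independent of tinygrad:

```python
import numpy as np

rng = np.random.default_rng(0)
a, b = rng.standard_normal((4, 3)), rng.standard_normal((3, 5))
# (4,3) @ (3,5): view as (4,1,3) and (1,5,3), multiply, then sum the shared last axis
out = (a[:, None, :] * b.T[None, :, :]).sum(-1)
assert np.allclose(out, a @ b)
```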
1697
+ def matmul(self, x:Tensor, reverse=False, acc_dtype:Optional[DType]=None) -> Tensor:
1698
+ """
1699
+ Performs matrix multiplication between two tensors.
1700
+
1701
+ You can pass in the `reverse` keyword argument to control the order of the matrix multiplication.
1702
+ You can pass in the optional `acc_dtype` keyword argument to control the data type of the accumulation.
1703
+
1704
+ ```python exec="true" source="above" session="tensor" result="python"
1705
+ a = Tensor([[1, 2], [3, 4]])
1706
+ b = Tensor([[5, 6], [7, 8]])
1707
+ print(a.matmul(b).numpy())
1708
+ ```
1709
+ """
1710
+ return x.dot(self, acc_dtype=acc_dtype) if reverse else self.dot(x, acc_dtype=acc_dtype)
1711
+
1712
+ def _cumsum(self, axis:int=0, _first_zero=False) -> Tensor:
1713
+ pl_sz = self.shape[axis] - int(not _first_zero and self.shape[axis] != 0)
1714
+ return self.transpose(axis,-1).pad2d((pl_sz,0))._pool((self.shape[axis] or 1,)).sum(-1).transpose(axis,-1)
1715
+ def cumsum(self, axis:int=0) -> Tensor:
1716
+ """
1717
+ Computes the cumulative sum of the tensor along the specified axis.
1718
+
1719
+ You can pass in the `axis` keyword argument to control the axis along which the cumulative sum is computed.
1720
+
1721
+ ```python exec="true" source="above" session="tensor" result="python"
1722
+ t = Tensor.ones(2, 3)
1723
+ print(t.numpy())
1724
+ ```
1725
+ ```python exec="true" source="above" session="tensor" result="python"
1726
+ print(t.cumsum(1).numpy())
1727
+ ```
1728
+ """
1729
+ # TODO: someday the optimizer will find this on its own
1730
+ # for now this is a two stage cumsum
1731
+ SPLIT = 256
1732
+ if self.shape[axis] <= SPLIT*2: return self._cumsum(axis)
1733
+ ret = self.transpose(axis,-1).pad2d((round_up(self.shape[axis], SPLIT)-self.shape[axis], 0))
1734
+ ret = ret.unflatten(-1, (-1, SPLIT))._cumsum(-1)
1735
+ base_add = ret[..., -1]._cumsum(-1, _first_zero=True)[..., :-1]
1736
+ base_add = base_add.unsqueeze(-1).expand(*base_add.shape, ret.shape[-1])
1737
+ def fix(x:Tensor): return x.flatten(start_dim=-2)[..., -self.shape[axis]:].transpose(axis,-1)
1738
+ return fix(ret) + fix(base_add)
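The two-stage scheme above splits the axis into chunks of `SPLIT`, takes a cumulative sum inside each chunk, and then adds the running total of all preceding chunks (`base_add`, an exclusive cumsum of the chunk totals). A small NumPy sketch of the same idea with a tiny chunk size (the helper name is illustrative):

```python
import numpy as np

def two_stage_cumsum(x, split=4):
    # pad on the left so the length is a multiple of `split`, cumsum inside each chunk,
    # then add the running total of the preceding chunks (an exclusive cumsum of chunk sums)
    pad = -len(x) % split
    chunks = np.concatenate([np.zeros(pad, x.dtype), x]).reshape(-1, split)
    local = chunks.cumsum(axis=1)
    base = np.concatenate([[0], local[:-1, -1].cumsum()])
    return (local + base[:, None]).reshape(-1)[pad:]

x = np.arange(1, 11, dtype=np.float64)
assert np.allclose(two_stage_cumsum(x), x.cumsum())
```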
529
1739
 
530
1740
  @staticmethod
531
- def _tri(r:int, c:int, k:int=0, **kwargs) -> Tensor: return Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(-k, c-k, **kwargs).unsqueeze(0).expand(r,c)
532
- def triu(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k, dtype=self.dtype, device=self.device).where(self, Tensor.zeros_like(self))
533
- def tril(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, dtype=self.dtype, device=self.device).where(Tensor.zeros_like(self), self)
534
-
535
- # ***** math functions (unary) *****
536
- def trunc(self: Tensor) -> Tensor: return self.cast(dtypes.int32).contiguous().cast(self.dtype)
537
- def ceil(self: Tensor) -> Tensor: return (self > (b := self.trunc())).where(b+1, b)
538
- def floor(self: Tensor) -> Tensor: return (self < (b := self.trunc())).where(b-1, b)
539
-
540
- def __neg__(self): return 0.0-self
541
- def square(self): return self*self
542
- def clip(self, min_, max_): return self.maximum(min_).minimum(max_)
543
- def abs(self): return self.relu() + (-self).relu()
544
- def sign(self): return self / (self.abs() + 1e-10)
545
- def reciprocal(self): return 1.0/self
546
-
547
- # ***** activation functions (unary) *****
548
- def elu(self, alpha=1.0): return self.relu() - alpha*(1-self.exp()).relu()
549
- def celu(self, alpha=1.0): return self.maximum(0) + (alpha * ((self / alpha).exp() - 1)).minimum(0)
550
- def swish(self): return self * self.sigmoid()
551
- def silu(self): return self.swish() # The SiLU function is also known as the swish function.
552
- def relu6(self): return self.relu() - (self-6).relu()
553
- def hardswish(self): return self * (self+3).relu6() * (1/6)
554
- def tanh(self): return 2.0 * ((2.0 * self).sigmoid()) - 1.0
555
- def hardtanh(self, min_val=-1, max_val=1): return self.clip(min_val, max_val)
556
- def gelu(self): return 0.5 * self * (1 + (self * 0.7978845608 * (1 + 0.044715 * self * self)).tanh())
557
- def quick_gelu(self): return self * (self * 1.702).sigmoid()
558
- def leakyrelu(self, neg_slope=0.01): return self.relu() - (-neg_slope*self).relu()
559
- def mish(self): return self * self.softplus().tanh()
560
- def softplus(self, beta=1): return (1/beta) * (1 + (self*beta).exp()).log()
561
- def softsign(self): return self / (1 + self.abs())
562
-
563
- # ***** broadcasted binary mlops *****
564
-
565
- def _broadcasted(self, y:Union[Tensor, float], reverse:bool=False) -> Tuple[Tensor, Tensor]:
1741
+ def _tri(r:sint, c:sint, k:int=0, **kwargs) -> Tensor:
1742
+ assert all_int((r,c)), "does not support symbolic"
1743
+ if r == 0: return Tensor.zeros((r, c), **kwargs)
1744
+ return Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(-k, c-k, **kwargs).unsqueeze(0).expand(r,c)
1745
+ def triu(self, k:int=0) -> Tensor:
1746
+ """
1747
+ Returns the upper triangular part of the tensor, the other elements are set to 0.
1748
+
1749
+ ```python exec="true" source="above" session="tensor" result="python"
1750
+ t = Tensor([[1, 2, 3], [4, 5, 6]])
1751
+ print(t.numpy())
1752
+ ```
1753
+ ```python exec="true" source="above" session="tensor" result="python"
1754
+ print(t.triu(k=1).numpy())
1755
+ ```
1756
+ """
1757
+ return Tensor._tri(self.shape[-2], self.shape[-1], k=k, device=self.device).where(self, 0)
1758
+ def tril(self, k:int=0) -> Tensor:
1759
+ """
1760
+ Returns the lower triangular part of the tensor, the other elements are set to 0.
1761
+
1762
+ ```python exec="true" source="above" session="tensor" result="python"
1763
+ t = Tensor([[1, 2, 3], [4, 5, 6]])
1764
+ print(t.numpy())
1765
+ ```
1766
+ ```python exec="true" source="above" session="tensor" result="python"
1767
+ print(t.tril().numpy())
1768
+ ```
1769
+ """
1770
+ return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, device=self.device).where(0, self)
1771
+
1772
+ # ***** unary ops *****
1773
+
1774
+ def logical_not(self):
1775
+ """
1776
+ Computes the logical NOT of the tensor element-wise.
1777
+
1778
+ ```python exec="true" source="above" session="tensor" result="python"
1779
+ print(Tensor([False, True]).logical_not().numpy())
1780
+ ```
1781
+ """
1782
+ return F.Eq.apply(*self._broadcasted(False))
1783
+ def neg(self):
1784
+ """
1785
+ Negates the tensor element-wise.
1786
+
1787
+ ```python exec="true" source="above" session="tensor" result="python"
1788
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).neg().numpy())
1789
+ ```
1790
+ """
1791
+ return F.Neg.apply(self) if self.dtype != dtypes.bool else self.logical_not()
1792
+ def contiguous(self):
1793
+ """
1794
+ Returns a contiguous tensor.
1795
+ """
1796
+ return F.Contiguous.apply(self)
1797
+ def contiguous_backward(self):
1798
+ """
1799
+ Inserts a contiguous operation in the backward pass.
1800
+ """
1801
+ return F.ContiguousBackward.apply(self)
1802
+ def log(self):
1803
+ """
1804
+ Computes the natural logarithm element-wise.
1805
+
1806
+ See: https://en.wikipedia.org/wiki/Logarithm
1807
+
1808
+ ```python exec="true" source="above" session="tensor" result="python"
1809
+ print(Tensor([1., 2., 4., 8.]).log().numpy())
1810
+ ```
1811
+ """
1812
+ return F.Log.apply(self.cast(least_upper_float(self.dtype)))
1813
+ def log2(self):
1814
+ """
1815
+ Computes the base-2 logarithm element-wise.
1816
+
1817
+ See: https://en.wikipedia.org/wiki/Logarithm
1818
+
1819
+ ```python exec="true" source="above" session="tensor" result="python"
1820
+ print(Tensor([1., 2., 4., 8.]).log2().numpy())
1821
+ ```
1822
+ """
1823
+ return self.log()/math.log(2)
1824
+ def exp(self):
1825
+ """
1826
+ Computes the exponential function element-wise.
1827
+
1828
+ See: https://en.wikipedia.org/wiki/Exponential_function
1829
+
1830
+ ```python exec="true" source="above" session="tensor" result="python"
1831
+ print(Tensor([0., 1., 2., 3.]).exp().numpy())
1832
+ ```
1833
+ """
1834
+ return F.Exp.apply(self.cast(least_upper_float(self.dtype)))
1835
+ def exp2(self):
1836
+ """
1837
+ Computes the base-2 exponential function element-wise.
1838
+
1839
+ See: https://en.wikipedia.org/wiki/Exponential_function
1840
+
1841
+ ```python exec="true" source="above" session="tensor" result="python"
1842
+ print(Tensor([0., 1., 2., 3.]).exp2().numpy())
1843
+ ```
1844
+ """
1845
+ return F.Exp.apply(self*math.log(2))
1846
+ def relu(self):
1847
+ """
1848
+ Applies the Rectified Linear Unit (ReLU) function element-wise.
1849
+
1850
+ - Described: https://paperswithcode.com/method/relu
1851
+
1852
+ ```python exec="true" source="above" session="tensor" result="python"
1853
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).relu().numpy())
1854
+ ```
1855
+ """
1856
+ return F.Relu.apply(self)
1857
+ def sigmoid(self):
1858
+ """
1859
+ Applies the Sigmoid function element-wise.
1860
+
1861
+ - Described: https://en.wikipedia.org/wiki/Sigmoid_function
1862
+
1863
+ ```python exec="true" source="above" session="tensor" result="python"
1864
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sigmoid().numpy())
1865
+ ```
1866
+ """
1867
+ return F.Sigmoid.apply(self.cast(least_upper_float(self.dtype)))
1868
+ def sqrt(self):
1869
+ """
1870
+ Computes the square root of the tensor element-wise.
1871
+
1872
+ ```python exec="true" source="above" session="tensor" result="python"
1873
+ print(Tensor([1., 2., 3., 4.]).sqrt().numpy())
1874
+ ```
1875
+ """
1876
+ return F.Sqrt.apply(self.cast(least_upper_float(self.dtype)))
1877
+ def rsqrt(self):
1878
+ """
1879
+ Computes the reciprocal of the square root of the tensor element-wise.
1880
+
1881
+ ```python exec="true" source="above" session="tensor" result="python"
1882
+ print(Tensor([1., 2., 3., 4.]).rsqrt().numpy())
1883
+ ```
1884
+ """
1885
+ return self.reciprocal().sqrt()
1886
+ def sin(self):
1887
+ """
1888
+ Computes the sine of the tensor element-wise.
1889
+
1890
+ ```python exec="true" source="above" session="tensor" result="python"
1891
+ print(Tensor([0., math.pi/2, math.pi, 3*math.pi/2, 2*math.pi]).sin().numpy())
1892
+ ```
1893
+ """
1894
+ return F.Sin.apply(self.cast(least_upper_float(self.dtype)))
1895
+ def cos(self):
1896
+ """
1897
+ Computes the cosine of the tensor element-wise.
1898
+
1899
+ ```python exec="true" source="above" session="tensor" result="python"
1900
+ print(Tensor([0., math.pi/2, math.pi, 3*math.pi/2, 2*math.pi]).cos().numpy())
1901
+ ```
1902
+ """
1903
+ return ((math.pi/2)-self).sin()
1904
+ def tan(self):
1905
+ """
1906
+ Computes the tangent of the tensor element-wise.
1907
+
1908
+ ```python exec="true" source="above" session="tensor" result="python"
1909
+ print(Tensor([0., math.pi/4, math.pi/2, 3*math.pi/4, math.pi]).tan().numpy())
1910
+ ```
1911
+ """
1912
+ return self.sin() / self.cos()
1913
+
1914
+ # ***** math functions *****
1915
+
1916
+ def trunc(self: Tensor) -> Tensor:
1917
+ """
1918
+ Truncates the tensor element-wise.
1919
+
1920
+ ```python exec="true" source="above" session="tensor" result="python"
1921
+ print(Tensor([-3.9, -2.1, -1.5, 0.5, 1.5, 2.1, 3.9]).trunc().numpy())
1922
+ ```
1923
+ """
1924
+ return self.cast(dtypes.int32).cast(self.dtype)
1925
+ def ceil(self: Tensor) -> Tensor:
1926
+ """
1927
+ Rounds the tensor element-wise towards positive infinity.
1928
+
1929
+ ```python exec="true" source="above" session="tensor" result="python"
1930
+ print(Tensor([-3.9, -2.1, -1.5, 0.5, 1.5, 2.1, 3.9]).ceil().numpy())
1931
+ ```
1932
+ """
1933
+ return (self > (b := self.trunc())).where(b+1, b)
1934
+ def floor(self: Tensor) -> Tensor:
1935
+ """
1936
+ Rounds the tensor element-wise towards negative infinity.
1937
+
1938
+ ```python exec="true" source="above" session="tensor" result="python"
1939
+ print(Tensor([-3.9, -2.1, -1.5, 0.5, 1.5, 2.1, 3.9]).floor().numpy())
1940
+ ```
1941
+ """
1942
+ return (self < (b := self.trunc())).where(b-1, b)
1943
+ def round(self: Tensor) -> Tensor:
1944
+ """
1945
+ Rounds the tensor element-wise.
1946
+
1947
+ ```python exec="true" source="above" session="tensor" result="python"
1948
+ print(Tensor([-3.9, -2.1, -1.5, 0.5, 1.5, 2.1, 3.9]).round().numpy())
1949
+ ```
1950
+ """
1951
+ return ((self > 0) == ((b := self.cast(dtypes.int32) / 2.0).cast(dtypes.int32) == b)).where((self - 0.5).ceil(), (self + 0.5).floor())
1952
+ def lerp(self, end: Tensor, weight: Union[Tensor, float]) -> Tensor:
1953
+ """
1954
+ Linearly interpolates between `self` and `end` by `weight`.
1955
+
1956
+ ```python exec="true" source="above" session="tensor" result="python"
1957
+ print(Tensor([1., 2., 3.]).lerp(Tensor([4., 5., 6.]), 0.5).numpy())
1958
+ ```
1959
+ """
1960
+ return self + (end - self) * weight
1961
+ def square(self):
1962
+ """
1963
+ Squares the tensor element-wise.
1964
+ Equivalent to `self*self`.
1965
+
1966
+ ```python exec="true" source="above" session="tensor" result="python"
1967
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).square().numpy())
1968
+ ```
1969
+ """
1970
+ return self*self
1971
+ def clip(self, min_, max_):
1972
+ """
1973
+ Clips (clamps) the values in the tensor between `min_` and `max_` element-wise.
1974
+
1975
+ ```python exec="true" source="above" session="tensor" result="python"
1976
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).clip(-1, 1).numpy())
1977
+ ```
1978
+ """
1979
+ return self.maximum(min_).minimum(max_)
1980
+ def sign(self):
1981
+ """
1982
+ Returns the sign of the tensor element-wise.
1983
+
1984
+ ```python exec="true" source="above" session="tensor" result="python"
1985
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sign().numpy())
1986
+ ```
1987
+ """
1988
+ return F.Sign.apply(self)
1989
+ def abs(self):
1990
+ """
1991
+ Computes the absolute value of the tensor element-wise.
1992
+
1993
+ ```python exec="true" source="above" session="tensor" result="python"
1994
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).abs().numpy())
1995
+ ```
1996
+ """
1997
+ return self * self.sign()
1998
+ def reciprocal(self):
1999
+ """
2000
+ Compute `1/x` element-wise.
2001
+
2002
+ ```python exec="true" source="above" session="tensor" result="python"
2003
+ print(Tensor([1., 2., 3., 4.]).reciprocal().numpy())
2004
+ ```
2005
+ """
2006
+ return F.Reciprocal.apply(self.cast(least_upper_float(self.dtype)))
2007
+
2008
+ # ***** activation functions *****
2009
+
2010
+ def elu(self, alpha=1.0):
2011
+ """
2012
+ Applies the Exponential Linear Unit (ELU) function element-wise.
2013
+
2014
+ - Described: https://paperswithcode.com/method/elu
2015
+ - Paper: https://arxiv.org/abs/1511.07289v5
2016
+
2017
+ ```python exec="true" source="above" session="tensor" result="python"
2018
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).elu().numpy())
2019
+ ```
2020
+ """
2021
+ return self.relu() - alpha*(1-self.exp()).relu()
2022
+
2023
+ def celu(self, alpha=1.0):
2024
+ """
2025
+ Applies the Continuously differentiable Exponential Linear Unit (CELU) function element-wise.
2026
+
2027
+ - Described: https://paperswithcode.com/method/celu
2028
+ - Paper: https://arxiv.org/abs/1704.07483
2029
+
2030
+ ```python exec="true" source="above" session="tensor" result="python"
2031
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).celu().numpy())
2032
+ ```
2033
+ """
2034
+ return self.maximum(0) + (alpha * ((self / alpha).exp() - 1)).minimum(0)
2035
+
2036
+ def swish(self):
2037
+ """
2038
+ See `.silu()`
2039
+
2040
+ - Paper: https://arxiv.org/abs/1710.05941v1
2041
+
2042
+ ```python exec="true" source="above" session="tensor" result="python"
2043
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).swish().numpy())
2044
+ ```
2045
+ """
2046
+ return self * self.sigmoid()
2047
+
2048
+ def silu(self):
2049
+ """
2050
+ Applies the Sigmoid Linear Unit (SiLU) function element-wise.
2051
+
2052
+ - Described: https://paperswithcode.com/method/silu
2053
+ - Paper: https://arxiv.org/abs/1606.08415
2054
+
2055
+ ```python exec="true" source="above" session="tensor" result="python"
2056
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).silu().numpy())
2057
+ ```
2058
+ """
2059
+ return self.swish() # The SiLU function is also known as the swish function.
2060
+
2061
+ def relu6(self):
2062
+ """
2063
+ Applies the ReLU6 function element-wise.
2064
+
2065
+ - Described: https://paperswithcode.com/method/relu6
2066
+ - Paper: https://arxiv.org/abs/1704.04861v1
2067
+
2068
+ ```python exec="true" source="above" session="tensor" result="python"
2069
+ print(Tensor([-9., -6., -3., 0., 3., 6., 9.]).relu6().numpy())
2070
+ ```
2071
+ """
2072
+ return self.relu() - (self-6).relu()
2073
+
2074
+ def hardswish(self):
2075
+ """
2076
+ Applies the Hardswish function element-wise.
2077
+
2078
+ - Described: https://paperswithcode.com/method/hard-swish
2079
+ - Paper: https://arxiv.org/abs/1905.02244v5
2080
+
2081
+ ```python exec="true" source="above" session="tensor" result="python"
2082
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).hardswish().numpy())
2083
+ ```
2084
+ """
2085
+ return self * (self+3).relu6() * (1/6)
2086
+
2087
+ def tanh(self):
2088
+ """
2089
+ Applies the Hyperbolic Tangent (tanh) function element-wise.
2090
+
2091
+ - Described: https://en.wikipedia.org/wiki/Hyperbolic_functions#Tanh
2092
+
2093
+ ```python exec="true" source="above" session="tensor" result="python"
2094
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).tanh().numpy())
2095
+ ```
2096
+ """
2097
+ return 2.0 * ((2.0 * self).sigmoid()) - 1.0
2098
+
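`tanh` is built from `sigmoid` via the identity tanh(x) = 2·sigmoid(2x) − 1, which a quick NumPy check confirms:

```python
import numpy as np

x = np.linspace(-3, 3, 7)
assert np.allclose(2 / (1 + np.exp(-2 * x)) - 1, np.tanh(x))
```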
2099
+ def sinh(self):
2100
+ """
2101
+ Applies the Hyperbolic Sine (sinh) function element-wise.
2102
+
2103
+ - Described: https://en.wikipedia.org/wiki/Hyperbolic_functions#Sinh
2104
+
2105
+ ```python exec="true" source="above" session="tensor" result="python"
2106
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sinh().numpy())
2107
+ ```
2108
+ """
2109
+ return (self.exp() - self.neg().exp()) / 2
2110
+
2111
+ def cosh(self):
2112
+ """
2113
+ Applies the Hyperbolic Cosine (cosh) function element-wise.
2114
+
2115
+ - Described: https://en.wikipedia.org/wiki/Hyperbolic_functions#Cosh
2116
+
2117
+ ```python exec="true" source="above" session="tensor" result="python"
2118
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).cosh().numpy())
2119
+ ```
2120
+ """
2121
+ return (self.exp() + self.neg().exp()) / 2
2122
+
2123
+ def atanh(self):
2124
+ """
2125
+ Applies the Inverse Hyperbolic Tangent (atanh) function element-wise.
2126
+
2127
+ - Described: https://en.wikipedia.org/wiki/Inverse_hyperbolic_functions#atanh
2128
+
2129
+ ```python exec="true" source="above" session="tensor" result="python"
2130
+ print(Tensor([-0.9, -0.6, -0.3, 0., 0.3, 0.6, 0.9]).atanh().numpy())
2131
+ ```
2132
+ """
2133
+ return ((1 + self)/(1 - self)).log() / 2
2134
+
2135
+ def asinh(self):
2136
+ """
2137
+ Applies the Inverse Hyperbolic Sine (asinh) function element-wise.
2138
+
2139
+ - Described: https://en.wikipedia.org/wiki/Inverse_hyperbolic_functions#asinh
2140
+
2141
+ ```python exec="true" source="above" session="tensor" result="python"
2142
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).asinh().numpy())
2143
+ ```
2144
+ """
2145
+ return (self + (self.square() + 1).sqrt()).log()
2146
+
2147
+ def acosh(self):
2148
+ """
2149
+ Applies the Inverse Hyperbolic Cosine (acosh) function element-wise.
2150
+
2151
+ - Described: https://en.wikipedia.org/wiki/Inverse_hyperbolic_functions#acosh
2152
+
2153
+ ```python exec="true" source="above" session="tensor" result="python"
2154
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).acosh().numpy())
2155
+ ```
2156
+ """
2157
+ return (self + (self.square() - 1).sqrt()).log()
2158
+
2159
+ def hardtanh(self, min_val=-1, max_val=1):
2160
+ """
2161
+ Applies the Hardtanh function element-wise.
2162
+
2163
+ - Described: https://paperswithcode.com/method/hardtanh-activation
2164
+
2165
+ ```python exec="true" source="above" session="tensor" result="python"
2166
+ print(Tensor([-1.5, -1.0, -0.5, 0., 0.5, 1.0, 1.5]).hardtanh().numpy())
2167
+ ```
2168
+ """
2169
+ return self.clip(min_val, max_val)
2170
+
2171
+ def gelu(self):
2172
+ """
2173
+ Applies the Gaussian Error Linear Unit (GELU) function element-wise.
2174
+
2175
+ - Described: https://paperswithcode.com/method/gelu
2176
+ - Paper: https://arxiv.org/abs/1606.08415v5
2177
+
2178
+ ```python exec="true" source="above" session="tensor" result="python"
2179
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).gelu().numpy())
2180
+ ```
2181
+ """
2182
+ return 0.5 * self * (1 + (self * 0.7978845608 * (1 + 0.044715 * self * self)).tanh())
2183
+
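The constant 0.7978845608 above is sqrt(2/pi); the expression is the common tanh approximation of the exact erf-based GELU. A NumPy comparison of the two over a small range, independent of tinygrad:

```python
import math
import numpy as np

x = np.linspace(-3, 3, 61)
approx = 0.5 * x * (1 + np.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x**3)))
exact = 0.5 * x * (1 + np.array([math.erf(v / math.sqrt(2)) for v in x]))
print(np.abs(approx - exact).max())  # the gap stays small over this range
```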
2184
+ def quick_gelu(self):
2185
+ """
2186
+ Applies the Sigmoid GELU approximation element-wise.
2187
+
2188
+ - Described: https://paperswithcode.com/method/gelu
2189
+
2190
+ ```python exec="true" source="above" session="tensor" result="python"
2191
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).quick_gelu().numpy())
2192
+ ```
2193
+ """
2194
+ return self * (self * 1.702).sigmoid()
2195
+
2196
+ def leakyrelu(self, neg_slope=0.01):
2197
+ """
2198
+ Applies the Leaky ReLU function element-wise.
2199
+
2200
+ - Described: https://paperswithcode.com/method/leaky-relu
2201
+
2202
+ ```python exec="true" source="above" session="tensor" result="python"
2203
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).leakyrelu().numpy())
2204
+ ```
2205
+ ```python exec="true" source="above" session="tensor" result="python"
2206
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).leakyrelu(neg_slope=0.42).numpy())
2207
+ ```
2208
+ """
2209
+ return self.relu() - (-neg_slope*self).relu()
2210
+
2211
+ def mish(self):
2212
+ """
2213
+ Applies the Mish function element-wise.
2214
+
2215
+ - Described: https://paperswithcode.com/method/mish
2216
+ - Paper: https://arxiv.org/abs/1908.08681v3
2217
+
2218
+ ```python exec="true" source="above" session="tensor" result="python"
2219
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).mish().numpy())
2220
+ ```
2221
+ """
2222
+ return self * self.softplus().tanh()
2223
+
2224
+ def softplus(self, beta=1):
2225
+ """
2226
+ Applies the Softplus function element-wise.
2227
+
2228
+ - Described: https://paperswithcode.com/method/softplus
2229
+
2230
+ ```python exec="true" source="above" session="tensor" result="python"
2231
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).softplus().numpy())
2232
+ ```
2233
+ """
2234
+ return (1/beta) * (1 + (self*beta).exp()).log()
2235
+
2236
+ def softsign(self):
2237
+ """
2238
+ Applies the Softsign function element-wise.
2239
+
2240
+ - Described: https://paperswithcode.com/method/softsign
2241
+
2242
+ ```python exec="true" source="above" session="tensor" result="python"
2243
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).softsign().numpy())
2244
+ ```
2245
+ """
2246
+ return self / (1 + self.abs())
2247
+
2248
+ # ***** broadcasted elementwise ops *****
2249
+ def _broadcast_to(self, shape:Tuple[sint, ...]):
2250
+ reshape_arg, _ = _pad_left(self.shape, shape)
2251
+ if self.ndim > len(shape) or not all(sh in {s,1} or (s==0 and sh==1) for sh,s in zip(reshape_arg, shape)):
2252
+ raise ValueError(f"cannot broadcast tensor with shape={self.shape} to {shape=}")
2253
+ return F.Expand.apply(self.reshape(reshape_arg), shape=shape) if shape != self.shape else self
2254
+
2255
+ def _broadcasted(self, y:Union[Tensor, ConstType], reverse:bool=False, match_dtype:bool=True) -> Tuple[Tensor, Tensor]:
566
2256
  x: Tensor = self
567
2257
  if not isinstance(y, Tensor):
568
- y = Tensor(y, device=self.device, requires_grad=False, dtype=self.dtype if self.dtype != dtypes.bool and self.dtype.__class__ is not ImageDType else dtypes.float32)
2258
+ # make y a Tensor
2259
+ assert isinstance(y, (float, int, bool, Node)), f"{type(y)=}, {y=}"
2260
+ if isinstance(self.dtype, ImageDType) or dtypes.is_float(x.dtype) or (dtypes.is_int(x.dtype) and isinstance(y, int)): y_dtype = x.dtype
2261
+ else: y_dtype = dtypes.from_py(y)
2262
+ if isinstance(y, Node): y = Tensor.from_node(y, device=self.device)
2263
+ else: y = Tensor(dtypes.as_const(y, y_dtype), self.device, y_dtype, requires_grad=False)
2264
+
2265
+ if match_dtype:
2266
+ output_dtype = least_upper_dtype(x.dtype, y.dtype)
2267
+ x, y = x.cast(output_dtype), y.cast(output_dtype)
2268
+
569
2269
  if reverse: x, y = y, x
570
- if (xshape:=x.shape) == (yshape:=y.shape): return (x, y)
571
-
572
- shape_delta = len(xshape) - len(yshape)
573
- if shape_delta > 0: y = y.reshape((1,) * shape_delta + yshape)
574
- elif shape_delta < 0: x = x.reshape((1,) * -shape_delta + xshape)
575
- if (xshape:=x.shape) == (yshape:=y.shape): return (x, y)
576
-
577
- shape_ret = tuple([max(x, y) for x, y in zip(xshape, yshape)])
578
- if xshape != shape_ret: x = x.expand(shape_ret)
579
- if yshape != shape_ret: y = y.expand(shape_ret)
580
- return (x, y)
581
-
582
- def add(self, x:Union[Tensor, float], reverse=False) -> Tensor: return mlops.Add.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x else self
583
- def sub(self, x:Union[Tensor, float], reverse=False) -> Tensor: return mlops.Sub.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x or reverse else self
584
- def mul(self, x:Union[Tensor, float], reverse=False) -> Tensor: return mlops.Mul.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x != 1.0 else self
585
- def div(self, x:Union[Tensor, float], reverse=False) -> Tensor: return mlops.Div.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or reverse or not x or not dtypes.is_float(self.dtype) else self.mul(1/x)
586
- def pow(self, x:Union[Tensor, float], reverse=False) -> Tensor:
587
- if x.__class__ is not Tensor and not reverse:
2270
+
2271
+ # broadcast
2272
+ out_shape = _broadcast_shape(x.shape, y.shape)
2273
+ return x._broadcast_to(out_shape), y._broadcast_to(out_shape)
2274
+
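`_broadcasted` follows the usual NumPy-style rule: pad the shorter shape with leading 1s, then every axis pair must either match or contain a 1, and the output takes the larger extent. A standalone sketch of that shape rule (`broadcast_shape` here is an illustrative helper, not the `_broadcast_shape` used above):

```python
def broadcast_shape(*shapes):
    # left-pad the shorter shapes with 1s, then take the non-1 extent of each axis pair
    max_len = max(len(s) for s in shapes)
    padded = [(1,) * (max_len - len(s)) + tuple(s) for s in shapes]
    out = []
    for dims in zip(*padded):
        non_one = {d for d in dims if d != 1}
        assert len(non_one) <= 1, f"cannot broadcast {shapes}"
        out.append(non_one.pop() if non_one else 1)
    return tuple(out)

assert broadcast_shape((2, 1, 4), (3, 1)) == (2, 3, 4)
```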
2275
+ def _to_const_val(self, x:Union[Tensor, ConstType]) -> Union[Tensor, ConstType]:
2276
+ # TODO: update with multi
2277
+ return x.lazydata.base.arg if isinstance(x, Tensor) and isinstance(x.lazydata, LazyBuffer) and x.lazydata.is_unrealized_unmasked_const() \
2278
+ and not x.requires_grad and self._broadcasted(x)[0].shape == self.shape else x
2279
+
2280
+ def add(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
2281
+ """
2282
+ Adds `self` and `x`.
2283
+ Equivalent to `self + x`.
2284
+ Supports broadcasting to a common shape, type promotion, and integer, float, boolean inputs.
2285
+
2286
+ ```python exec="true" source="above" session="tensor" result="python"
2287
+ Tensor.manual_seed(42)
2288
+ t = Tensor.randn(4)
2289
+ print(t.numpy())
2290
+ ```
2291
+ ```python exec="true" source="above" session="tensor" result="python"
2292
+ print(t.add(20).numpy())
2293
+ ```
2294
+ ```python exec="true" source="above" session="tensor" result="python"
2295
+ print(t.add(Tensor([[2.0], [3.5]])).numpy())
2296
+ ```
2297
+ """
2298
+ return F.Add.apply(*self._broadcasted(x, reverse))
2299
+
2300
+ def sub(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
2301
+ """
2302
+ Subtracts `x` from `self`.
2303
+ Equivalent to `self - x`.
2304
+ Supports broadcasting to a common shape, type promotion, and integer, float, boolean inputs.
2305
+
2306
+ ```python exec="true" source="above" session="tensor" result="python"
2307
+ Tensor.manual_seed(42)
2308
+ t = Tensor.randn(4)
2309
+ print(t.numpy())
2310
+ ```
2311
+ ```python exec="true" source="above" session="tensor" result="python"
2312
+ print(t.sub(20).numpy())
2313
+ ```
2314
+ ```python exec="true" source="above" session="tensor" result="python"
2315
+ print(t.sub(Tensor([[2.0], [3.5]])).numpy())
2316
+ ```
2317
+ """
2318
+ return F.Sub.apply(*self._broadcasted(x, reverse))
2319
+
2320
+ def mul(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
2321
+ """
2322
+ Multiplies `self` and `x`.
2323
+ Equivalent to `self * x`.
2324
+ Supports broadcasting to a common shape, type promotion, and integer, float, boolean inputs.
2325
+
2326
+ ```python exec="true" source="above" session="tensor" result="python"
2327
+ Tensor.manual_seed(42)
2328
+ t = Tensor.randn(4)
2329
+ print(t.numpy())
2330
+ ```
2331
+ ```python exec="true" source="above" session="tensor" result="python"
2332
+ print(t.mul(3).numpy())
2333
+ ```
2334
+ ```python exec="true" source="above" session="tensor" result="python"
2335
+ print(t.mul(Tensor([[-1.0], [2.0]])).numpy())
2336
+ ```
2337
+ """
2338
+ return F.Mul.apply(*self._broadcasted(x, reverse))
2339
+
2340
+ def div(self, x:Union[Tensor, ConstType], reverse=False, upcast=True) -> Tensor:
2341
+ """
2342
+ Divides `self` by `x`.
2343
+ Equivalent to `self / x`.
2344
+ Supports broadcasting to a common shape, type promotion, and integer, float, boolean inputs.
2345
+ By default, `div` performs true division. Set `upcast` to `False` for integer division.
2346
+
2347
+ ```python exec="true" source="above" session="tensor" result="python"
2348
+ Tensor.manual_seed(42)
2349
+ t = Tensor.randn(4)
2350
+ print(t.numpy())
2351
+ ```
2352
+ ```python exec="true" source="above" session="tensor" result="python"
2353
+ print(t.div(3).numpy())
2354
+ ```
2355
+ ```python exec="true" source="above" session="tensor" result="python"
2356
+ print(Tensor([1, 4, 10]).div(Tensor([2, 3, 4])).numpy())
2357
+ ```
2358
+ ```python exec="true" source="above" session="tensor" result="python"
2359
+ print(Tensor([1, 4, 10]).div(Tensor([2, 3, 4]), upcast=False).numpy())
2360
+ ```
2361
+ """
2362
+ numerator, denominator = self._broadcasted(x, reverse)
2363
+ if upcast: numerator, denominator = numerator.cast(least_upper_float(numerator.dtype)), denominator.cast(least_upper_float(denominator.dtype))
2364
+ return F.Div.apply(numerator, denominator)
2365
+
2366
+ def xor(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
2367
+ """
2368
+ Computes bitwise xor of `self` and `x`.
2369
+ Equivalent to `self ^ x`.
2370
+ Supports broadcasting to a common shape, type promotion, and integer, boolean inputs.
2371
+
2372
+ ```python exec="true" source="above" session="tensor" result="python"
2373
+ print(Tensor([-1, -2, 3]).xor(Tensor([1, 0, 3])).numpy())
2374
+ ```
2375
+ ```python exec="true" source="above" session="tensor" result="python"
2376
+ print(Tensor([True, True, False, False]).xor(Tensor([True, False, True, False])).numpy())
2377
+ ```
2378
+ """
2379
+ return F.Xor.apply(*self._broadcasted(x, reverse))
2380
+
2381
+ def lshift(self, x:int):
2382
+ """
2383
+ Computes left arithmetic shift of `self` by `x` bits. `self` must have unsigned dtype.
2384
+ Equivalent to `self << x`.
2385
+
2386
+ ```python exec="true" source="above" session="tensor" result="python"
2387
+ print(Tensor([1, 3, 31], dtype=dtypes.uint8).lshift(2).numpy())
2388
+ ```
2389
+ """
2390
+ assert dtypes.is_unsigned(self.dtype) and isinstance(x, int) and x >= 0, f"not supported {self.dtype=} {x=}"
2391
+ return self.mul(2 ** x)
2392
+
2393
+ def rshift(self, x:int):
2394
+ """
2395
+ Computes right arithmetic shift of `self` by `x` bits. `self` must have unsigned dtype.
2396
+ Equivalent to `self >> x`.
2397
+
2398
+ ```python exec="true" source="above" session="tensor" result="python"
2399
+ print(Tensor([4, 13, 125], dtype=dtypes.uint8).rshift(2).numpy())
2400
+ ```
2401
+ """
2402
+ assert dtypes.is_unsigned(self.dtype) and isinstance(x, int) and x >= 0, f"not supported {self.dtype=} {x=}"
2403
+ return self.div(2 ** x, upcast=False)
2404
+
2405
+ def pow(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
2406
+ """
2407
+ Computes power of `self` with `x`.
2408
+ Equivalent to `self ** x`.
2409
+
2410
+ ```python exec="true" source="above" session="tensor" result="python"
2411
+ print(Tensor([-1, 2, 3]).pow(2).numpy())
2412
+ ```
2413
+ ```python exec="true" source="above" session="tensor" result="python"
2414
+ print(Tensor([-1, 2, 3]).pow(Tensor([-1.5, 0.5, 1.5])).numpy())
2415
+ ```
2416
+ ```python exec="true" source="above" session="tensor" result="python"
2417
+ print((2 ** Tensor([-1, 2, 3])).numpy())
2418
+ ```
2419
+ """
2420
+ x = self._to_const_val(x)
2421
+ if not isinstance(x, Tensor) and not reverse:
588
2422
  # simple pow identities
589
- if x < 0: return (1.0/self).pow(-x)
590
- if x == 2.0: return self*self
591
- if x == 1.0: return self
592
- if x == 0.5: return self.sqrt()
2423
+ if x < 0: return self.reciprocal().pow(-x)
2424
+ if x == 0: return 1 + self * 0
2425
+ if int(x - 0.5) + 0.5 == x: return self.pow(int(x - 0.5)) * self.sqrt()
2426
+ if int(x) == x: return self.pow(x // 2).square() * (1 if x % 2 == 0 else self)
2427
+
2428
+ # positive const ** self
593
2429
  if not isinstance(x, Tensor) and reverse and x > 0: return self.mul(math.log(x)).exp()
594
- ar = self.abs().log().mul(x).exp() if not reverse or isinstance(x, Tensor) else self.mul(math.log(abs(x))).exp()
595
- # correct sign of negative numbers raised to a power (cos has a period of 2pi so we use it here to get the oddness of the power)
596
- sign = (x * math.pi).cos() if isinstance(x, Tensor) else math.cos(x * math.pi) if not reverse else (self * math.pi).cos()
597
- # we only need to correct the sign if the base is negative
598
- base_sign = ((self.sign() if not reverse else x.sign() if isinstance(x, Tensor) else math.copysign(1, x)) - 1) / -2
599
- # we need 0 to be positive so we need to correct base_sign when the base is 0
600
- base_sign = base_sign - (1.5 * (1 - (self.sign().abs() if not reverse else x.sign().abs() if isinstance(x, Tensor) else abs(int(bool(x))))))
601
- # inject nan if the base is negative and the power is not an integer
602
- to_nan = (((x - x.trunc()) * 1e10).abs().clip(0, 1) if isinstance(x, Tensor) else int(bool(x - int(x))) if not reverse else ((self - self.trunc()) * 1e10).abs().clip(0, 1)) * base_sign
603
- inject_nan = ((((-to_nan) * 2) + 1)).log().add(1) if isinstance(to_nan, Tensor) else 1 if not to_nan else float("nan")
604
- return ar.mul(sign * base_sign + (1 - base_sign)).mul(inject_nan)
605
- def matmul(self, x:Tensor, reverse=False) -> Tensor: return x.dot(self) if reverse else self.dot(x)
606
-
607
- def maximum(self, x:Union[Tensor, float]) -> Tensor: return (self<x).detach().where(x, (self>x).detach().where(self, (self+x)/2))
608
- def minimum(self, x:Union[Tensor, float]) -> Tensor: return -((-self).maximum(-x))
609
-
610
- def where(self:Tensor, input_:Union[Tensor, float], other:Union[Tensor, float]):
611
- x_,y = self._broadcasted(input_)
612
- x,z = x_._broadcasted(other)
613
- return mlops.Where.apply(x, *y._broadcasted(z))
614
-
615
- # ***** binary op wrappers (18 wasted lines to make the typechecker happy) *****
616
-
617
- # NOTE: __pow__ and friends are broken in mypyc with the ** operator
2430
+
2431
+ base, exponent = self._broadcasted(x, reverse=reverse)
2432
+ # start with b ** e = exp(e * log(b))
2433
+ ret = base.abs().log().mul(exponent).exp()
2434
+ # correct sign of negative base with odd exponent (cos has a period of 2pi so we use it here to get the oddness of the exponent)
2435
+ negative_base = (base < 0).detach().where(1, 0)
2436
+ # 1 for non-negative base or negative even exponent, -1 for negative odd exponent, don't care about non-integer exponent
2437
+ correct_sign = 1 + negative_base * ((exponent * math.pi).cos() - 1)
2438
+ # inject nan for negative base and non-integer exponent
2439
+ inject_nan = (negative_base * (exponent != exponent.trunc())).detach().where(math.nan, 1)
2440
+ # apply correct_sign inject_nan, and fix 0 ** 0 = 1
2441
+ return ((base == 0) * (exponent == 0)).detach().where(1, ret * correct_sign * inject_nan)
2442
+
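For the general case, `pow` computes exp(e·log|b|) and then repairs the sign for negative bases, using cos(π·e) as a smooth stand-in for (−1)^e at integer exponents. A NumPy check of that sign fix for a negative base and integer exponents:

```python
import numpy as np

b, e = -3.0, np.array([0., 1., 2., 3., 4.])
via_exp_log = np.exp(e * np.log(abs(b))) * np.cos(np.pi * e)   # |b|**e with the sign restored
assert np.allclose(via_exp_log, b ** e)
```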
2443
+ def maximum(self, x:Union[Tensor, ConstType]) -> Tensor:
2444
+ """
2445
+ Computes element-wise maximum of `self` and `x`.
2446
+
2447
+ ```python exec="true" source="above" session="tensor" result="python"
2448
+ print(Tensor([-1, 2, 3]).maximum(1).numpy())
2449
+ ```
2450
+ ```python exec="true" source="above" session="tensor" result="python"
2451
+ print(Tensor([-1, 2, 3]).maximum(Tensor([-4, -2, 9])).numpy())
2452
+ ```
2453
+ """
2454
+ return (self<x).detach().where(x, (self==x).detach().where(((self * 0.5 + x * 0.5).cast(self.dtype)), self))
2455
+
2456
+ def minimum(self, x:Union[Tensor, ConstType]) -> Tensor:
2457
+ """
2458
+ Computes element-wise minimum of `self` and `x`.
2459
+
2460
+ ```python exec="true" source="above" session="tensor" result="python"
2461
+ print(Tensor([-1, 2, 3]).minimum(1).numpy())
2462
+ ```
2463
+ ```python exec="true" source="above" session="tensor" result="python"
2464
+ print(Tensor([-1, 2, 3]).minimum(Tensor([-4, -2, 9])).numpy())
2465
+ ```
2466
+ """
2467
+ return -((-self).maximum(-x))
2468
+
2469
+ def where(self:Tensor, x:Union[Tensor, ConstType], y:Union[Tensor, ConstType]):
2470
+ """
2471
+ Return a tensor of elements selected from either `x` or `y`, depending on `self`.
2472
+ `output_i = x_i if self_i else y_i`.
2473
+
2474
+ ```python exec="true" source="above" session="tensor" result="python"
2475
+ cond = Tensor([[True, True, False], [True, False, False]])
2476
+ print(cond.where(1, 3).numpy())
2477
+ ```
2478
+ ```python exec="true" source="above" session="tensor" result="python"
2479
+ Tensor.manual_seed(42)
2480
+ cond = Tensor.randn(2, 3)
2481
+ print(cond.numpy())
2482
+ ```
2483
+ ```python exec="true" source="above" session="tensor" result="python"
2484
+ print((cond > 0).where(cond, -float("inf")).numpy())
2485
+ ```
2486
+ """
2487
+ if isinstance(x, Tensor): x, y = x._broadcasted(y)
2488
+ elif isinstance(y, Tensor): y, x = y._broadcasted(x)
2489
+ cond, x = self._broadcasted(x, match_dtype=False)
2490
+ cond, y = cond._broadcasted(y, match_dtype=False)
2491
+ return F.Where.apply(cond.cast(dtypes.bool), *x._broadcasted(y))
2492
+
2493
+ def masked_fill(self:Tensor, mask:Tensor, value:Union[Tensor, ConstType]): return mask.where(value, self)
2494
+
2495
+ # ***** op wrappers *****
2496
+
2497
+ def __neg__(self) -> Tensor: return self.neg()
2498
+
618
2499
  def __add__(self, x) -> Tensor: return self.add(x)
619
2500
  def __sub__(self, x) -> Tensor: return self.sub(x)
620
2501
  def __mul__(self, x) -> Tensor: return self.mul(x)
621
2502
  def __pow__(self, x) -> Tensor: return self.pow(x)
622
2503
  def __truediv__(self, x) -> Tensor: return self.div(x)
623
2504
  def __matmul__(self, x) -> Tensor: return self.matmul(x)
2505
+ def __xor__(self, x) -> Tensor: return self.xor(x)
2506
+ def __lshift__(self, x) -> Tensor: return self.lshift(x)
2507
+ def __rshift__(self, x) -> Tensor: return self.rshift(x)
624
2508
 
625
2509
  def __radd__(self, x) -> Tensor: return self.add(x, True)
626
2510
  def __rsub__(self, x) -> Tensor: return self.sub(x, True)
@@ -628,6 +2512,7 @@ class Tensor:
628
2512
  def __rpow__(self, x) -> Tensor: return self.pow(x, True)
629
2513
  def __rtruediv__(self, x) -> Tensor: return self.div(x, True)
630
2514
  def __rmatmul__(self, x) -> Tensor: return self.matmul(x, True)
2515
+ def __rxor__(self, x) -> Tensor: return self.xor(x, True)
631
2516
 
632
2517
  def __iadd__(self, x) -> Tensor: return self.assign(self.add(x))
633
2518
  def __isub__(self, x) -> Tensor: return self.assign(self.sub(x))
@@ -635,73 +2520,359 @@ class Tensor:
635
2520
  def __ipow__(self, x) -> Tensor: return self.assign(self.pow(x))
636
2521
  def __itruediv__(self, x) -> Tensor: return self.assign(self.div(x))
637
2522
  def __imatmul__(self, x) -> Tensor: return self.assign(self.matmul(x))
2523
+ def __ixor__(self, x) -> Tensor: return self.assign(self.xor(x))
2524
+ def __ilshift__(self, x) -> Tensor: return self.assign(self.lshift(x))
2525
+ def __irshift__(self, x) -> Tensor: return self.assign(self.rshift(x))
638
2526
 
639
- def __lt__(self, x) -> Tensor: return mlops.Less.apply(*self._broadcasted(x, False))
640
- def __gt__(self, x) -> Tensor: return mlops.Less.apply(*self._broadcasted(x, True))
641
- def __ge__(self, x) -> Tensor: return 1.0-(self<x)
642
- def __le__(self, x) -> Tensor: return 1.0-(self>x)
643
- def __ne__(self, x) -> Tensor: return (self<x) + (self>x) # type: ignore
644
- def __eq__(self, x) -> Tensor: return 1.0-(self != x) # type: ignore
2527
+ def __lt__(self, x) -> Tensor: return F.Less.apply(*self._broadcasted(x, False))
2528
+ def __gt__(self, x) -> Tensor: return F.Less.apply(*self._broadcasted(x, True))
2529
+ def __ge__(self, x) -> Tensor: return (self<x).logical_not()
2530
+ def __le__(self, x) -> Tensor: return (self>x).logical_not()
2531
+ def __eq__(self, x) -> Tensor: return F.Eq.apply(*self._broadcasted(x, True)) # type: ignore[override]
2532
+ def __ne__(self, x) -> Tensor: return (self==x).logical_not() # type: ignore[override]
645
2533
 
646
2534
  # ***** functional nn ops *****
647
2535
 
648
2536
  def linear(self, weight:Tensor, bias:Optional[Tensor]=None):
2537
+ """
2538
+ Applies a linear transformation to `self` using `weight` and `bias`.
2539
+
2540
+ See: https://pytorch.org/docs/stable/generated/torch.nn.Linear.html
2541
+
2542
+ ```python exec="true" source="above" session="tensor" result="python"
2543
+ t = Tensor([[1, 2], [3, 4]])
2544
+ weight = Tensor([[1, 2], [3, 4]])
2545
+ bias = Tensor([1, 2])
2546
+ print(t.linear(weight, bias).numpy())
2547
+ ```
2548
+ """
649
2549
  x = self.mul(weight) if len(weight.shape) == 1 else self.dot(weight)
650
2550
  return x.add(bias) if bias is not None else x
651
2551
 
652
- def sequential(self, ll:List[Callable[[Tensor], Tensor]]): return reduce(lambda x,f: f(x), ll, self)
2552
+ def sequential(self, ll:List[Callable[[Tensor], Tensor]]):
2553
+ """
2554
+ Applies a sequence of functions to `self` chaining the output of each function to the input of the next.
2555
+
2556
+ ```python exec="true" source="above" session="tensor" result="python"
2557
+ t = Tensor([1, 2, 3])
2558
+ print(t.sequential([lambda x: x * 2, lambda x: x + 1]).numpy())
2559
+ ```
2560
+ """
2561
+ return functools.reduce(lambda x,f: f(x), ll, self)
653
2562
 
654
2563
  def layernorm(self, axis=-1, eps:float=1e-5) -> Tensor:
2564
+ """
2565
+ Applies Layer Normalization over a mini-batch of inputs.
2566
+
2567
+ - Described: https://paperswithcode.com/method/layer-normalization
2568
+ - Paper: https://arxiv.org/abs/1607.06450v1
2569
+
2570
+ ```python exec="true" source="above" session="tensor" result="python"
2571
+ t = Tensor.randn(8, 10, 16) * 2 + 8
2572
+ print(t.mean().item(), t.std().item())
2573
+ ```
2574
+ ```python exec="true" source="above" session="tensor" result="python"
2575
+ t = t.layernorm()
2576
+ print(t.mean().item(), t.std().item())
2577
+ ```
2578
+ """
655
2579
  y = (self - self.mean(axis, keepdim=True))
656
2580
  return y.mul((y*y).mean(axis, keepdim=True).add(eps).rsqrt())
657
2581
 
658
- def batchnorm(self, weight:Optional[Tensor], bias:Optional[Tensor], mean:Tensor, invstd:Tensor) -> Tensor:
659
- x = (self - mean.reshape(shape=[1, -1, 1, 1]))
660
- if weight: x = x * weight.reshape(shape=[1, -1, 1, 1])
661
- ret = x.mul(invstd.reshape(shape=[1, -1, 1, 1]) if len(invstd.shape) == 1 else invstd)
662
- return (ret + bias.reshape(shape=[1, -1, 1, 1])) if bias else ret
2582
+ def batchnorm(self, weight:Optional[Tensor], bias:Optional[Tensor], mean:Tensor, invstd:Tensor, axis:Union[int,Tuple[int,...]]=1) -> Tensor:
2583
+ """
2584
+ Applies Batch Normalization over a mini-batch of inputs.
2585
+
2586
+ - Described: https://paperswithcode.com/method/batch-normalization
2587
+ - Paper: https://arxiv.org/abs/1502.03167
2588
+
2589
+ ```python exec="true" source="above" session="tensor" result="python"
2590
+ t = Tensor.randn(8, 4, 16, 16) * 2 + 8
2591
+ print(t.mean().item(), t.std().item())
2592
+ ```
2593
+ ```python exec="true" source="above" session="tensor" result="python"
2594
+ t = t.batchnorm(None, None, t.mean(axis=(0,2,3)), t.var(axis=(0,2,3)).add(1e-5).rsqrt())
2595
+ print(t.mean().item(), t.std().item())
2596
+ ```
2597
+ """
2598
+ axis_ = argfix(axis)
2599
+ shape = tuple(s if ax in axis_ else 1 for ax, s in enumerate(self.shape))
2600
+ x = self - mean.reshape(shape)
2601
+ if weight is not None: x = x * weight.reshape(shape)
2602
+ ret = x.mul(invstd.reshape(shape) if len(invstd.shape) == len(axis_) else invstd)
2603
+ return (ret + bias.reshape(shape)) if bias is not None else ret
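`batchnorm` itself only applies the affine normalization; the caller supplies the per-channel `mean` and `invstd` (and optionally `weight`/`bias`). A NumPy sketch of what the docstring example computes for the default `axis=1`, independent of tinygrad:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((8, 4, 16, 16)) * 2 + 8
mean, var = x.mean(axis=(0, 2, 3)), x.var(axis=(0, 2, 3))
shape = (1, -1, 1, 1)                              # broadcast over every axis except channels
y = (x - mean.reshape(shape)) * ((var + 1e-5) ** -0.5).reshape(shape)
print(round(y.mean(), 6), round(y.std(), 6))       # ~0 and ~1 after normalization
```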
663
2604
 
664
2605
  def dropout(self, p=0.5) -> Tensor:
665
- if not Tensor.training or p == 0: return self
666
- mask = (Tensor.rand(*self.shape, requires_grad=False, device=self.device) >= p).cast(dtypes.bool)
667
- return self * mask * (1/(1.0 - p))
2606
+ """
2607
+ Applies dropout to `self`.
668
2608
 
669
- def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Optional[Tensor]=None, dropout_p:float=0.0, is_causal:bool=False) -> Tensor:
670
- if is_causal: attn_mask = Tensor.ones(self.shape[-2], key.shape[-2], requires_grad=False, device=self.device).tril(0).cast(dtypes.bool)
671
- if attn_mask is not None and attn_mask.dtype == dtypes.bool: attn_mask = (attn_mask == 0).where(-float("inf"), attn_mask)
672
- return (self @ key.transpose(-2,-1) / math.sqrt(self.shape[-1]) + attn_mask).softmax(-1).dropout(dropout_p) @ value
2609
+ NOTE: dropout is only applied when `Tensor.training` is `True`.
2610
+
2611
+ - Described: https://paperswithcode.com/method/dropout
2612
+ - Paper: https://jmlr.org/papers/v15/srivastava14a.html
673
2613
 
674
- def sparse_categorical_crossentropy(self, Y, ignore_index=-1) -> Tensor:
675
- loss_mask = Y != ignore_index
676
- y_counter = Tensor.arange(self.shape[-1], dtype=dtypes.int32, requires_grad=False, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1])
677
- y = ((y_counter == Y.flatten().reshape(-1, 1)).where(-1.0, 0) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
678
- return self.log_softmax().mul(y).sum() / loss_mask.sum()
2614
+ ```python exec="true" source="above" session="tensor" result="python"
2615
+ Tensor.manual_seed(42)
2616
+ t = Tensor.randn(2, 2)
2617
+ with Tensor.train():
2618
+ print(t.dropout().numpy())
2619
+ ```
2620
+ """
2621
+ if not Tensor.training or p == 0: return self
2622
+ return self * (Tensor.rand(*self.shape, requires_grad=False, dtype=dtypes.default_float, device=self.device) >= p) * (1/(1.0 - p))
2623
+
2624
+ def one_hot(self, num_classes:int) -> Tensor:
2625
+ """
2626
+ Converts `self` to a one-hot tensor.
2627
+
2628
+ ```python exec="true" source="above" session="tensor" result="python"
2629
+ t = Tensor([0, 1, 3, 3, 4])
2630
+ print(t.one_hot(5).numpy())
2631
+ ```
2632
+ """
2633
+ return (self[..., None] == Tensor.arange(num_classes, requires_grad=False, device=self.device)).where(1, 0)
2634
+
2635
+ def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Optional[Tensor]=None,
2636
+ dropout_p:float=0.0, is_causal:bool=False) -> Tensor:
2637
+ """
2638
+ Computes scaled dot-product attention.
2639
+ `self` is the query tensor, `key` is the key tensor, and `value` is the value tensor.
2640
+
2641
+ - Described: https://paperswithcode.com/method/scaled
2642
+ - Paper: https://arxiv.org/abs/1706.03762v7
2643
+
2644
+ ```python exec="true" source="above" session="tensor" result="python"
2645
+ q = Tensor.randn(2, 4, 8)
2646
+ k = Tensor.randn(2, 4, 8)
2647
+ v = Tensor.randn(2, 4, 8)
2648
+ print(q.scaled_dot_product_attention(k, v).numpy())
2649
+ ```
2650
+ """
2651
+ # NOTE: it also works when `key` and `value` have symbolic shape.
2652
+ assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
2653
+ if is_causal: attn_mask = Tensor.ones(self.shape[-2], key.shape[-2], requires_grad=False, device=self.device).tril(0).cast(dtypes.bool)
2654
+ if attn_mask is not None and attn_mask.dtype == dtypes.bool: attn_mask = (attn_mask == 0).where(-float("inf"), 0)
2655
+ qk = self @ key.transpose(-2,-1) / math.sqrt(self.shape[-1])
2656
+ return ((qk+attn_mask) if attn_mask is not None else qk).softmax(-1).dropout(dropout_p) @ value
2657
+
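The implementation follows the standard formula softmax(Q·Kᵀ/√d)·V, with a boolean mask converted to additive −inf before the softmax. A plain NumPy reference of the unmasked case (illustrative only, not the tinygrad code path):

```python
import numpy as np

def sdpa_ref(q, k, v):
    # softmax(Q K^T / sqrt(d)) V, with the usual max-subtraction for a stable softmax
    scores = q @ k.swapaxes(-2, -1) / np.sqrt(q.shape[-1])
    weights = np.exp(scores - scores.max(-1, keepdims=True))
    weights /= weights.sum(-1, keepdims=True)
    return weights @ v

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((2, 4, 8)) for _ in range(3))
print(sdpa_ref(q, k, v).shape)  # (2, 4, 8)
```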
2658
+ def binary_crossentropy(self, y:Tensor) -> Tensor:
2659
+ """
2660
+ Computes the binary cross-entropy loss between `self` and `y`.
2661
+
2662
+ See: https://pytorch.org/docs/stable/generated/torch.nn.BCELoss.html
2663
+
2664
+ ```python exec="true" source="above" session="tensor" result="python"
2665
+ t = Tensor([0.1, 0.9, 0.2])
2666
+ y = Tensor([0, 1, 0])
2667
+ print(t.binary_crossentropy(y).item())
2668
+ ```
2669
+ """
2670
+ return (-y*self.log() - (1-y)*(1-self).log()).mean()
2671
+
2672
+ def binary_crossentropy_logits(self, y:Tensor) -> Tensor:
2673
+ """
2674
+ Computes the binary cross-entropy loss between `self` and `y` where `self` is logits.
2675
+
2676
+ See: https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html
2677
+
2678
+ ```python exec="true" source="above" session="tensor" result="python"
2679
+ t = Tensor([-1, 2, -3])
2680
+ y = Tensor([0, 1, 0])
2681
+ print(t.binary_crossentropy_logits(y).item())
2682
+ ```
2683
+ """
2684
+ return (self.maximum(0) - y * self + (1 + self.abs().neg().exp()).log()).mean()
2685
+
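The expression above is the numerically stable form of binary cross-entropy on logits: −y·log σ(x) − (1−y)·log(1−σ(x)) = max(x,0) − x·y + log(1+e^(−|x|)). A quick NumPy check of that identity:

```python
import numpy as np

x, y = np.array([-1., 2., -3.]), np.array([0., 1., 0.])
naive  = -(y * np.log(1/(1+np.exp(-x))) + (1-y) * np.log(1 - 1/(1+np.exp(-x))))
stable = np.maximum(x, 0) - x*y + np.log(1 + np.exp(-np.abs(x)))
assert np.allclose(naive, stable)
```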
2686
+ def sparse_categorical_crossentropy(self, Y:Tensor, ignore_index=-1, label_smoothing=0.0) -> Tensor:
2687
+ """
2688
+ Computes the sparse categorical cross-entropy loss between `self` and `Y`.
2689
+
2690
+ NOTE: `self` is logits and `Y` is the target labels.
2691
+
2692
+ See: https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
2693
+
2694
+ ```python exec="true" source="above" session="tensor" result="python"
2695
+ t = Tensor([[-1, 2, -3], [1, -2, 3]])
2696
+ Y = Tensor([1, 2])
2697
+ print(t.sparse_categorical_crossentropy(Y).item())
2698
+ ```
2699
+ """
2700
+ assert 0.0 <= label_smoothing <= 1.0, "label_smoothing must be in [0.0, 1.0]"
2701
+ log_probs, loss_mask = self.log_softmax(), (Y != ignore_index)
2702
+ y_counter = Tensor.arange(self.shape[-1], requires_grad=False, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1])
2703
+ y = ((y_counter == Y.flatten().reshape(-1, 1)) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
2704
+ smoothing = label_smoothing * (log_probs.mean(-1) * loss_mask).sum()
2705
+ return -((1 - label_smoothing) * (log_probs * y).sum() + smoothing) / loss_mask.sum()
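Putting the pieces together: a log-softmax, a mask for `ignore_index`, the picked log-probability per target, and a label-smoothing term that mixes in the mean log-probability. A NumPy reference of the same loss (the helper name is illustrative):

```python
import numpy as np

def sparse_ce_ref(logits, target, ignore_index=-1, label_smoothing=0.0):
    logp = logits - np.log(np.exp(logits).sum(-1, keepdims=True))        # log_softmax
    mask = target != ignore_index
    picked = np.take_along_axis(logp, np.maximum(target, 0)[..., None], -1)[..., 0]
    smooth = label_smoothing * (logp.mean(-1) * mask).sum()
    return -((1 - label_smoothing) * (picked * mask).sum() + smooth) / mask.sum()

logits = np.array([[-1., 2., -3.], [1., -2., 3.]])
target = np.array([1, 2])
print(sparse_ce_ref(logits, target))  # ~0.094, matching the docstring example above
```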
679
2706
 
680
2707
  # ***** cast ops *****
681
2708
 
682
- def cast(self, dtype:DType) -> Tensor: return mlops.Cast.apply(self, dtype=dtype) if self.dtype != dtype else self
2709
+ def llvm_bf16_cast(self, dtype:DType):
2710
+ # hack for devices that don't support bfloat16
2711
+ assert self.dtype == dtypes.bfloat16
2712
+ return self.to("LLVM").bitcast(dtypes.uint16).cast(dtypes.uint32).mul(1<<16).bitcast(dtypes.float32).cast(dtype)
2713
+ def cast(self, dtype:DType) -> Tensor:
2714
+ """
2715
+ Casts `self` to the given `dtype`.
2716
+
2717
+ ```python exec="true" source="above" session="tensor" result="python"
2718
+ t = Tensor([-1, 2.5, 3], dtype=dtypes.float)
2719
+ print(t.dtype, t.numpy())
2720
+ ```
2721
+ ```python exec="true" source="above" session="tensor" result="python"
2722
+ t = t.cast(dtypes.int32)
2723
+ print(t.dtype, t.numpy())
2724
+ ```
2725
+ """
2726
+ return self if self.dtype == dtype else F.Cast.apply(self, dtype=dtype)
683
2727
  def bitcast(self, dtype:DType) -> Tensor:
684
- assert self.dtype.itemsize == dtype.itemsize, "can't bitcast mismatched dtype itemsizes"
685
- return mlops.Cast.apply(self, dtype=dtype, bitcast=True) if self.dtype != dtype else self
686
- def float(self) -> Tensor: return self.cast(dtypes.float32)
687
- def half(self) -> Tensor: return self.cast(dtypes.float16)
2728
+ """
2729
+ Bitcasts `self` to the given `dtype` of the same itemsize.
2730
+
2731
+ `self` must not require a gradient.
2732
+
2733
+ ```python exec="true" source="above" session="tensor" result="python"
2734
+ t = Tensor([-1, 2, 3], dtype=dtypes.int32)
2735
+ print(t.dtype, t.numpy())
2736
+ ```
2737
+ ```python exec="true" source="above" session="tensor" result="python"
2738
+ t = t.bitcast(dtypes.uint32)
2739
+ print(t.dtype, t.numpy())
2740
+ ```
2741
+ """
2742
+ if self.requires_grad: raise RuntimeError("can't backprop through bitcast")
2743
+ return F.Cast.apply(self, dtype=dtype, bitcast=True) if self.dtype != dtype else self
2744
+ def float(self) -> Tensor:
2745
+ """
2746
+ Convenience method to cast `self` to a `float32` Tensor.
2747
+
2748
+ ```python exec="true" source="above" session="tensor" result="python"
2749
+ t = Tensor([-1, 2, 3], dtype=dtypes.int32)
2750
+ print(t.dtype, t.numpy())
2751
+ ```
2752
+ ```python exec="true" source="above" session="tensor" result="python"
2753
+ t = t.float()
2754
+ print(t.dtype, t.numpy())
2755
+ ```
2756
+ """
2757
+ return self.cast(dtypes.float32)
2758
+ def half(self) -> Tensor:
2759
+ """
2760
+ Convenience method to cast `self` to a `float16` Tensor.
2761
+
2762
+ ```python exec="true" source="above" session="tensor" result="python"
2763
+ t = Tensor([-1, 2, 3], dtype=dtypes.int32)
2764
+ print(t.dtype, t.numpy())
2765
+ ```
2766
+ ```python exec="true" source="above" session="tensor" result="python"
2767
+ t = t.half()
2768
+ print(t.dtype, t.numpy())
2769
+ ```
2770
+ """
2771
+ return self.cast(dtypes.float16)
688
2772
 
689
2773
  # ***** convenience stuff *****
690
2774
 
691
2775
  @property
692
2776
  def ndim(self) -> int: return len(self.shape)
693
- def numel(self) -> int: return math.prod(self.shape)
2777
+ def numel(self) -> sint: return prod(self.shape)
694
2778
  def element_size(self) -> int: return self.dtype.itemsize
695
2779
  def nbytes(self) -> int: return self.numel() * self.element_size()
696
2780
  def is_floating_point(self) -> bool: return dtypes.is_float(self.dtype)
2781
+ def size(self, dim=None) -> Union[sint, Tuple[sint, ...]]: return self.shape if dim is None else self.shape[dim]
2782
+
+  # *** image Tensor function replacements ***
+
+  def image_dot(self, w:Tensor, acc_dtype=None):
+    # NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
+    n1, n2 = len(self.shape), len(w.shape)
+    assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D"
+    assert self.shape[-1] == w.shape[-min(n2, 2)], f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({self.shape[-1]} != {w.shape[-min(n2, 2)]})" # noqa: E501
+    bs, groups, cin, cout = prod(self.shape[0:-2]), prod(w.shape[0:-2]), w.shape[-2], w.shape[-1]
+    out_shape_t = self.shape[0:-2] + (cout,-1) if len(self.shape) > 1 else (cout, )
+
+    # NOTE: with NHWC we can remove the transposes
+    # bs x groups*cin x H x W
+    cx = self.transpose(self.ndim-1, self.ndim-2).reshape((bs//groups, groups*cin, -1, 1))
+    # groups*cout x cin x H, W
+    cw = w.transpose(w.ndim-1, w.ndim-2).reshape((groups*cout, cin, 1, 1))
+    return cx.image_conv2d(cw, groups=groups, acc_dtype=acc_dtype).reshape(out_shape_t).transpose(self.ndim-1, self.ndim-2)
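The 1x1-conv identity in the NOTE above is what image_dot builds on: laying an (m,k) matrix out as (1,k,m,1) activations and a (k,n) matrix as (n,k,1,1) filters makes each output channel at each spatial position a dot product over k, i.e. one entry of the matmul. A NumPy sketch of that equivalence (plain arrays, no image dtypes; the einsum stands in for the 1x1 conv):

```python
import numpy as np

m, k, n = 3, 4, 5
a = np.random.rand(m, k).astype(np.float32)
b = np.random.rand(k, n).astype(np.float32)

# activations (1, k, m, 1) and 1x1 filters (n, k, 1, 1), as in the NOTE above
act = a.T.reshape(1, k, m, 1)
filt = b.T.reshape(n, k, 1, 1)

# a 1x1 conv is a per-position dot product over the channel axis
out = np.einsum('bchw,ocij->bohw', act, filt)      # (1, n, m, 1)
assert np.allclose(out[0, :, :, 0].T, a @ b, atol=1e-5)
```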
+
+  def image_conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, acc_dtype=None):
+    base_image_type = dtypes.imageh if getenv("FLOAT16", 0) else dtypes.imagef
+
+    (bs,_,iy,ix), (cout,cin,H,W) = self.shape, weight.shape
+    x, w = self, weight.reshape(groups, (rcout := cout//groups), cin, H, W)
+
+    # hack for non multiples of 4 on cin
+    if cin % 4 != 0 and not (cin == 1 and groups%4 == 0):
+      x = x.reshape(bs, groups, cin, iy, ix) # do this always?
+      added_input_channels = 4 - (cin % 4)
+      w = w.pad(tuple((0, added_input_channels) if i == 2 else None for i in range(w.ndim)))
+      x = x.pad(tuple((0, added_input_channels) if i == 2 else None for i in range(x.ndim)))
+      cin = cin + added_input_channels
+      x = x.reshape(bs, groups*cin, iy, ix)
+
+    # hack for non multiples of 4 on rcout
+    added_output_channels = 0
+    if rcout % 4 != 0 and not (rcout == 1 and groups%4 == 0):
+      added_output_channels = 4 - (rcout % 4)
+      rcout += added_output_channels
+      cout = groups * rcout
+      w = w.pad(tuple((0, added_output_channels) if i == 1 else None for i in range(w.ndim)))
+
+    # packed (note: flipping bs and iy would make the auto-padding work)
+    x = x.permute(0,2,3,1)
+    cin_last = iy == 1 and ix == 1
+    if cin == 1: w = w.reshape(cout//4,4,H,W).permute(0,2,3,1)
+    elif cin_last: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,1,3)
+    else: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,3,1)
+
+    # contiguous creates the image, and early realize static weights (TODO: test for the static weight)
+    if IMAGE >= 2: x,w = x.cast(base_image_type((bs*iy, ix*groups*cin//4, 4))), w.cast(base_image_type((cout//4, H*W*cin, 4)))
+    x, w = x.contiguous(), w.contiguous()
+
+    # expand out
+    rcin_hi, rcin_lo = cin//4 if cin >= 4 else 1, 4 if cin >= 4 else 1
+    cout_expand = [groups//4 if cin == 1 else groups, 4 if cin == 1 else 1, rcout//4 if rcout >= 4 else 1, 4 if rcout >= 4 else 1]
+    x = x.reshape(bs, iy, ix, groups, rcin_hi, rcin_lo)
+    if cin_last: w = w.reshape(cout//4, H, rcin_hi, W, 4, rcin_lo)
+    else: w = w.reshape(cout//4, H, rcin_hi, W, rcin_lo, 4).permute(0,1,2,3,5,4)
+
+    # padding
+    padding_ = [padding]*4 if isinstance(padding, int) else (padding if len(padding) == 4 else [padding[1], padding[1], padding[0], padding[0]])
+    x = x._slice((None, (-padding_[2], x.shape[1]+padding_[3]), (-padding_[0], x.shape[2]+padding_[1]), None, None, None))
+
+    # prepare input
+    x = x.permute(0,3,4,5,1,2)._pool((H, W), stride, dilation) # -> (bs, groups, rcin_hi, rcin_lo, oy, ox, H, W)
+    x = x.permute(0,4,5,1,2,3,6,7).reshape(bs, (oy := x.shape[4]), (ox := x.shape[5]), *cout_expand[0:2], 1, 1, rcin_hi, rcin_lo, H, W)
+
+    # prepare weights
+    w = w.permute(0,4,2,5,1,3).reshape((1, 1, 1, *cout_expand, rcin_hi, rcin_lo, H, W))
+
+    # the conv!
+    ret = (x*w).cast(base_image_type((bs*oy, ox*cout//4, 4)) if IMAGE >= 2 else dtypes.float32).sum((-4, -3, -2, -1), acc_dtype=acc_dtype)
+
+    # undo hack for non multiples of 4 on C.rcout
+    if added_output_channels != 0:
+      ret = ret.reshape(bs, oy, ox, groups, rcout)[:, :, :, :, :-added_output_channels]
+      cout = groups * (rcout - added_output_channels)
+
+    # NCHW output
+    ret = ret.reshape(bs, oy, ox, cout).permute(0,3,1,2)
+    return ret if bias is None else ret.add(bias.reshape(1, -1, 1, 1))
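Two fiddly details in image_conv2d above are plain arithmetic: channel counts are padded up to a multiple of 4 so they map onto the 4-wide image texel layout, and `padding` is normalized to a 4-element per-side list before slicing. A standalone restatement of both rules (the helper names here are illustrative, not tinygrad API):

```python
# channels appended so that c becomes a multiple of 4 (0 when it already is)
def pad_to_multiple_of_4(c: int) -> int:
    return 0 if c % 4 == 0 else 4 - (c % 4)

assert [pad_to_multiple_of_4(c) for c in (1, 2, 3, 4, 5, 8)] == [3, 2, 1, 0, 3, 0]

# padding normalization as written above: an int covers all four sides,
# a 2-element pair (p0, p1) expands to [p1, p1, p0, p0]
def normalize_padding(padding):
    if isinstance(padding, int): return [padding]*4
    return list(padding) if len(padding) == 4 else [padding[1], padding[1], padding[0], padding[0]]

assert normalize_padding(1) == [1, 1, 1, 1]
assert normalize_padding((2, 3)) == [3, 3, 2, 2]
```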
 
 # register functions to move between devices
-for device in Device._buffers:
-  setattr(Tensor, f"{device.lower()}", partialmethod(Tensor.to, device))
-  setattr(Tensor, f"{device.lower()}_", partialmethod(Tensor.to_, device))
+for device in Device._devices: setattr(Tensor, f"{device.lower()}", functools.partialmethod(Tensor.to, device))
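After the replacement loop above, each backend gets a lowercase shortcut method equivalent to `.to(device)`, bound with functools.partialmethod; the in-place `*_` variants from 0.7.0 are dropped. A rough usage sketch (assumes the CLANG backend is available on your machine):

```python
from tinygrad import Tensor

t = Tensor([1, 2, 3])
a = t.to("CLANG")   # explicit move
b = t.clang()       # shortcut generated by the setattr loop above
print(a.device, b.device)  # CLANG CLANG
```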
 
 if IMAGE:
   # if IMAGE>0 we install these replacement functions in Tensor (hack!)
-  from tinygrad.nn.image import image_conv2d, image_dot
-  setattr(Tensor, "conv2d", image_conv2d)
-  setattr(Tensor, "dot", image_dot)
+  setattr(Tensor, "conv2d", Tensor.image_conv2d)
+  setattr(Tensor, "dot", Tensor.image_dot)
+
+# TODO: eventually remove this
+def custom_random(out:Buffer):
+  Tensor._seed += 1
+  rng = np.random.default_rng(Tensor._seed)
+  if out.dtype == dtypes.half: rng_np_buffer = (rng.integers(low=0, high=2047, size=out.size) / 2048).astype(np.half, copy=False)
+  else: rng_np_buffer = rng.random(size=out.size, dtype=np.float32).astype(dtype=out.dtype.np, copy=False)
+  out.copyin(rng_np_buffer.data)
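custom_random fills a raw Buffer from a freshly seeded NumPy generator. The half-precision branch draws integers below 2047 and divides by 2048, so every sample is a dyadic rational that float16 represents exactly. A quick NumPy check of that premise, outside tinygrad:

```python
import numpy as np

rng = np.random.default_rng(0)
vals = rng.integers(low=0, high=2047, size=1000) / 2048   # float64 values k/2048, k < 2047
halves = vals.astype(np.half)

# k/2048 with k < 2048 needs at most 11 significant bits over a power-of-two
# denominator, so the cast to float16 is lossless
assert np.all(halves.astype(np.float64) == vals)
```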