tinygrad 0.9.1__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. tinygrad/__init__.py +11 -6
  2. tinygrad/codegen/kernel.py +308 -175
  3. tinygrad/codegen/linearize.py +95 -0
  4. tinygrad/codegen/lowerer.py +143 -0
  5. tinygrad/codegen/transcendental.py +257 -0
  6. tinygrad/codegen/uopgraph.py +506 -0
  7. tinygrad/device.py +72 -171
  8. tinygrad/dtype.py +122 -47
  9. tinygrad/engine/jit.py +184 -87
  10. tinygrad/{lazy.py → engine/lazy.py} +74 -66
  11. tinygrad/engine/memory.py +51 -0
  12. tinygrad/engine/realize.py +86 -61
  13. tinygrad/engine/schedule.py +366 -317
  14. tinygrad/engine/search.py +58 -47
  15. tinygrad/function.py +59 -58
  16. tinygrad/helpers.py +120 -102
  17. tinygrad/multi.py +82 -78
  18. tinygrad/nn/__init__.py +116 -67
  19. tinygrad/nn/datasets.py +12 -5
  20. tinygrad/nn/optim.py +1 -1
  21. tinygrad/nn/state.py +91 -6
  22. tinygrad/ops.py +1126 -143
  23. tinygrad/renderer/__init__.py +47 -23
  24. tinygrad/renderer/cstyle.py +338 -265
  25. tinygrad/renderer/llvmir.py +125 -143
  26. tinygrad/renderer/ptx.py +225 -0
  27. tinygrad/runtime/autogen/adreno.py +17904 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +46974 -11993
  29. tinygrad/runtime/autogen/cuda.py +6 -162
  30. tinygrad/runtime/autogen/io_uring.py +97 -63
  31. tinygrad/runtime/autogen/kfd.py +60 -47
  32. tinygrad/runtime/autogen/kgsl.py +1386 -0
  33. tinygrad/runtime/autogen/libc.py +5462 -0
  34. tinygrad/runtime/autogen/nv_gpu.py +1976 -1957
  35. tinygrad/runtime/autogen/nvrtc.py +579 -0
  36. tinygrad/runtime/autogen/opencl.py +11 -11
  37. tinygrad/runtime/autogen/qcom_dsp.py +1739 -0
  38. tinygrad/runtime/graph/clang.py +3 -3
  39. tinygrad/runtime/graph/cuda.py +11 -15
  40. tinygrad/runtime/graph/hcq.py +120 -107
  41. tinygrad/runtime/graph/metal.py +71 -43
  42. tinygrad/runtime/ops_amd.py +244 -323
  43. tinygrad/runtime/ops_clang.py +12 -5
  44. tinygrad/runtime/ops_cloud.py +220 -0
  45. tinygrad/runtime/ops_cuda.py +42 -99
  46. tinygrad/runtime/ops_disk.py +25 -26
  47. tinygrad/runtime/ops_dsp.py +181 -0
  48. tinygrad/runtime/ops_gpu.py +29 -16
  49. tinygrad/runtime/ops_hip.py +68 -0
  50. tinygrad/runtime/ops_llvm.py +15 -10
  51. tinygrad/runtime/ops_metal.py +147 -64
  52. tinygrad/runtime/ops_nv.py +356 -397
  53. tinygrad/runtime/ops_python.py +78 -79
  54. tinygrad/runtime/ops_qcom.py +405 -0
  55. tinygrad/runtime/support/__init__.py +0 -0
  56. tinygrad/runtime/support/compiler_cuda.py +77 -0
  57. tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +13 -1
  58. tinygrad/runtime/support/elf.py +38 -0
  59. tinygrad/runtime/support/hcq.py +539 -0
  60. tinygrad/shape/shapetracker.py +40 -50
  61. tinygrad/shape/view.py +102 -63
  62. tinygrad/tensor.py +1109 -365
  63. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/METADATA +54 -50
  64. tinygrad-0.10.0.dist-info/RECORD +77 -0
  65. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/WHEEL +1 -1
  66. tinygrad/codegen/linearizer.py +0 -528
  67. tinygrad/codegen/uops.py +0 -451
  68. tinygrad/engine/graph.py +0 -100
  69. tinygrad/renderer/assembly.py +0 -269
  70. tinygrad/shape/symbolic.py +0 -327
  71. tinygrad-0.9.1.dist-info/RECORD +0 -63
  72. /tinygrad/{runtime/driver/__init__.py → py.typed} +0 -0
  73. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/LICENSE +0 -0
  74. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/top_level.txt +0 -0
tinygrad/tensor.py CHANGED
@@ -1,37 +1,37 @@
  # inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
  from __future__ import annotations
- import time, math, itertools, functools, struct
+ import time, math, itertools, functools, struct, sys, inspect, pathlib, string, dataclasses, hashlib
  from contextlib import ContextDecorator
- from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Dict, DefaultDict, cast, get_args, Set
+ from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Dict, DefaultDict, cast, get_args, Literal
  from collections import defaultdict
- import numpy as np

- from tinygrad.dtype import DType, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype
- from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, get_shape, fully_flatten, dedup
- from tinygrad.helpers import IMAGE, DEBUG, WINO, THREEFRY
- from tinygrad.lazy import LazyBuffer
+ from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
+ from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
+ from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN
  from tinygrad.multi import MultiLazyBuffer
- from tinygrad.ops import LoadOps, truncate
+ from tinygrad.ops import smax, smin, resolve, UOp, Ops, sint, Variable, SimpleMathTrait
  from tinygrad.device import Device, Buffer, BufferOptions
- from tinygrad.shape.symbolic import sint, Variable, MulNode, SumNode, NumNode, Node
+ from tinygrad.engine.lazy import LazyBuffer
  from tinygrad.engine.realize import run_schedule
- from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars, memory_planner
+ from tinygrad.engine.memory import memory_planner
+ from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars

  # **** start with two base classes, Tensor and Function ****

  class Function:
- def __init__(self, device:Union[str, Tuple[str, ...]], *tensors:Tensor):
+ def __init__(self, device:Union[str, Tuple[str, ...]], *tensors:Tensor, metadata:Optional[Metadata]=None):
  self.device = device
  self.needs_input_grad = [t.requires_grad for t in tensors]
  self.requires_grad = True if any(self.needs_input_grad) else None if None in self.needs_input_grad else False
  if self.requires_grad: self.parents = tensors
+ self.metadata = metadata

  def forward(self, *args, **kwargs): raise NotImplementedError(f"forward not implemented for {type(self)}")
  def backward(self, *args, **kwargs): raise RuntimeError(f"backward not implemented for {type(self)}")

  @classmethod
  def apply(fxn:Type[Function], *x:Tensor, **kwargs) -> Tensor:
- ctx = fxn(x[0].device, *x)
+ ctx = fxn(x[0].device, *x, metadata=_METADATA.get())
  ret = Tensor.__new__(Tensor)
  ret.lazydata, ret.requires_grad, ret.grad = ctx.forward(*[t.lazydata for t in x], **kwargs), ctx.requires_grad, None
  ret._ctx = ctx if ctx.requires_grad and not Tensor.no_grad else None # used by autograd engine
@@ -39,29 +39,39 @@ class Function:

  import tinygrad.function as F

- def _loadop(op, shape:Tuple[sint,...], dtype:DType, device:Union[str, Tuple[str, ...]], arg=None, src:Tuple[LazyBuffer, ...]=()):
- if isinstance(device, str): return LazyBuffer.loadop(op, shape, dtype, device, arg, src)
- return MultiLazyBuffer([LazyBuffer.loadop(op, shape, dtype, d, arg, src) for d in device], None)
+ def _metaop(op, shape:Tuple[sint,...], dtype:DType, device:Union[str, Tuple[str, ...]], arg=None, src:Tuple[LazyBuffer, ...]=()):
+ if isinstance(device, str): return LazyBuffer.metaop(op, shape, dtype, device, arg, src)
+ return MultiLazyBuffer([LazyBuffer.metaop(op, shape, dtype, d, arg, src) for d in device], None)

- def _from_np_dtype(npdtype:type) -> DType: return dtypes.fields()[np.dtype(npdtype).name]
- def _to_np_dtype(dtype:DType) -> Optional[type]: return np.dtype(dtype.fmt).type if dtype.fmt is not None else None
+ def _from_np_dtype(npdtype:'np.dtype') -> DType: # type: ignore [name-defined] # noqa: F821
+ import numpy as np
+ return dtypes.fields()[np.dtype(npdtype).name]
+ def _to_np_dtype(dtype:DType) -> Optional[type]:
+ import numpy as np
+ return np.dtype(dtype.fmt).type if dtype.fmt is not None else None

- def _fromnp(x: np.ndarray) -> LazyBuffer:
- ret = LazyBuffer.loadop(LoadOps.EMPTY, x.shape, _from_np_dtype(x.dtype), "NPY")
+ def _fromnp(x: 'np.ndarray') -> LazyBuffer: # type: ignore [name-defined] # noqa: F821
+ ret = LazyBuffer.metaop(Ops.EMPTY, x.shape, _from_np_dtype(x.dtype), "NPY")
  # fake realize
  ret.buffer.allocate(x)
  del ret.srcs
  return ret

+ def get_shape(x) -> Tuple[int, ...]:
+ # NOTE: str is special because __getitem__ on a str is still a str
+ if not hasattr(x, "__len__") or not hasattr(x, "__getitem__") or isinstance(x, str) or (hasattr(x, "shape") and x.shape == ()): return ()
+ if not all_same(subs:=[get_shape(xi) for xi in x]): raise ValueError(f"inhomogeneous shape from {x}")
+ return (len(subs),) + (subs[0] if subs else ())
+
  def _frompy(x:Union[List, Tuple, bytes], dtype:DType) -> LazyBuffer:
- if isinstance(x, bytes): ret, data = LazyBuffer.loadop(LoadOps.EMPTY, (len(x),), dtype, "PYTHON"), x
+ if isinstance(x, bytes): ret, data = LazyBuffer.metaop(Ops.EMPTY, (len(x)//dtype.itemsize,), dtype, "PYTHON"), x
  else:
- ret = LazyBuffer.loadop(LoadOps.EMPTY, get_shape(x), dtype, "PYTHON")
+ ret = LazyBuffer.metaop(Ops.EMPTY, get_shape(x), dtype, "PYTHON")
  assert dtype.fmt is not None, f"{dtype=} has None fmt"
  truncate_function = truncate[dtype]
  data = struct.pack(f"@{ret.size}{dtype.fmt}", *[truncate_function(xi) for xi in fully_flatten(x)])
  # fake realize
- ret.buffer.allocate(memoryview(data))
+ ret.buffer.allocate(memoryview(data if Device.DEFAULT != "PYTHON" else bytearray(data)))
  del ret.srcs
  return ret

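For reference, the `get_shape` helper added above infers the shape of a nested Python sequence recursively and rejects ragged input. A minimal standalone sketch of the same idea (stdlib only, not the tinygrad implementation itself):

```python
from typing import Tuple

def get_shape(x) -> Tuple[int, ...]:
    # scalars and strings count as 0-d: they have no meaningful sub-shape
    if not hasattr(x, "__len__") or not hasattr(x, "__getitem__") or isinstance(x, str): return ()
    subs = [get_shape(xi) for xi in x]
    if len(set(subs)) > 1: raise ValueError(f"inhomogeneous shape from {x}")
    return (len(subs),) + (subs[0] if subs else ())

print(get_shape([[1, 2, 3], [4, 5, 6]]))  # (2, 3)
print(get_shape(7))                       # ()
```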
@@ -85,9 +95,11 @@ def _pad_left(*shapes:Tuple[sint, ...]) -> Tuple[Tuple[sint, ...], ...]:
  max_dim = max(len(shape) for shape in shapes)
  return tuple((1,) * (max_dim - len(shape)) + shape for shape in shapes)
  def _broadcast_shape(*shapes:Tuple[sint, ...]) -> Tuple[sint, ...]:
- return tuple(0 if any(size == 0 for size in nth_dim_sizes) else max(nth_dim_sizes) for nth_dim_sizes in zip(*_pad_left(*shapes)))
+ return tuple(0 if 0 in nth_dim_sizes else smax(nth_dim_sizes) for nth_dim_sizes in zip(*_pad_left(*shapes)))
+
+ ReductionStr = Literal["mean", "sum", "none"]

- class Tensor:
+ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
  """
  A `Tensor` is a multi-dimensional matrix containing elements of a single data type.

@@ -103,9 +115,11 @@ class Tensor:
  training: ClassVar[bool] = False
  no_grad: ClassVar[bool] = False

- def __init__(self, data:Union[None, ConstType, List, Tuple, LazyBuffer, np.ndarray, bytes, MultiLazyBuffer, Variable],
- device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
+ def __init__(self, data:Union[None, ConstType, UOp, bytes, List, Tuple, LazyBuffer, MultiLazyBuffer, 'np.ndarray', pathlib.Path], # type: ignore [name-defined] # noqa: F821
+ device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DTypeLike]=None, requires_grad:Optional[bool]=None):
+ if dtype is not None: dtype = to_dtype(dtype)
  assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
+ if device is None and isinstance(data, pathlib.Path): device = f"DISK:{data.resolve()}" # keep it on the disk if device is None
  device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device)

  # tensors can have gradients if you have called .backward
@@ -119,42 +133,45 @@ class Tensor:
  self._ctx: Optional[Function] = None

  # create a LazyBuffer from the different types of inputs
- if isinstance(data, LazyBuffer): assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
- elif isinstance(data, get_args(ConstType)): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_py(data), device, data)
- elif isinstance(data, Variable): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_py(data.unbind()[1]), device, data)
- elif isinstance(data, bytes): data = _frompy(data, dtypes.uint8)
+ if isinstance(data, (LazyBuffer, MultiLazyBuffer)): assert dtype is None or dtype==data.dtype, "dtype doesn't match, and casting isn't supported"
+ elif data is None: data = _metaop(Ops.EMPTY, (0,), dtype or dtypes.default_float, device)
+ elif isinstance(data, get_args(ConstType)): data = _metaop(Ops.CONST, tuple(), dtype or dtypes.from_py(data), device, data)
+ elif isinstance(data, UOp):
+ assert data.op is Ops.BIND and data.src[0].op is Ops.DEFINE_VAR and data.src[1].op is Ops.CONST, f"can't create tensor from UOp {data}"
+ data = _metaop(Ops.CONST, tuple(), dtype or data.dtype, device, data)
+ elif isinstance(data, bytes): data = _frompy(data, dtypes.uint8 if dtype is None else dtype)
  elif isinstance(data, (list, tuple)):
  if dtype is None:
  if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtypes.bool
  else: dtype = dtypes.default_int if d and all_int(d) else dtypes.default_float
- if dtype == dtypes.bfloat16: data = Tensor(_fromnp(np.array(data, np.float32)), device=device).cast(dtypes.bfloat16).lazydata
- else: data = _fromnp(np.array(data).astype(_to_np_dtype(dtype)))
- elif data is None: data = _loadop(LoadOps.EMPTY, (0,), dtype or dtypes.default_float, device)
- elif isinstance(data, np.ndarray):
- if data.shape == (): data = _loadop(LoadOps.CONST, tuple(), dtype or _from_np_dtype(data.dtype), device, data.item())
- else: data = _fromnp(data.astype(npdtype) if dtype is not None and (npdtype:=_to_np_dtype(dtype)) is not None else data)
+ if dtype == dtypes.bfloat16: data = Tensor(_frompy(data, dtypes.float32), device=device).cast(dtypes.bfloat16).lazydata
+ else: data = _frompy(data, dtype)
+ elif str(type(data)) == "<class 'numpy.ndarray'>":
+ import numpy as np
+ assert isinstance(data, np.ndarray), f"expected np.ndarray, got {data}"
+ if data.shape == (): data = _metaop(Ops.CONST, tuple(), dtype or _from_np_dtype(data.dtype), device, data.item())
+ else: data = _fromnp(data.astype(npdtype) if dtype is not None and (npdtype:=_to_np_dtype(dtype)) is not None else data) # type: ignore [name-defined]
+ elif isinstance(data, pathlib.Path):
+ dtype = dtype or dtypes.uint8
+ data = _metaop(Ops.EMPTY, (data.stat().st_size // dtype.itemsize,), dtype, f"DISK:{data.resolve()}")

  # by this point, it has to be a LazyBuffer
- if not isinstance(data, (LazyBuffer, MultiLazyBuffer)):
- raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}")
+ if not isinstance(data, (LazyBuffer, MultiLazyBuffer)): raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}")

- # data is a LazyBuffer, but it might be on the wrong device
- if isinstance(device, tuple):
- # if device is a tuple, we should have/construct a MultiLazyBuffer
- if isinstance(data, MultiLazyBuffer):
- assert data.device == device, f"MultiLazyBuffer device mismatch, {data.device} != {device}"
- self.lazydata: Union[LazyBuffer, MultiLazyBuffer] = data
- else:
- self.lazydata = MultiLazyBuffer.from_sharded(data, device, None)
+ # data might be on a different device
+ if isinstance(device, str): self.lazydata:Union[LazyBuffer, MultiLazyBuffer] = data if data.device == device else data.copy_to_device(device)
+ # if device is a tuple, we should have/construct a MultiLazyBuffer
+ elif isinstance(data, LazyBuffer): self.lazydata = MultiLazyBuffer.from_sharded(data, device, None, None)
  else:
- self.lazydata = data if data.device == device else data.copy_to_device(device)
+ assert data.device == device, f"MultiLazyBuffer device mismatch, {data.device} != {device}"
+ self.lazydata = data

  class train(ContextDecorator):
  def __init__(self, mode:bool = True): self.mode = mode
  def __enter__(self): self.prev, Tensor.training = Tensor.training, self.mode
  def __exit__(self, exc_type, exc_value, traceback): Tensor.training = self.prev

- class inference_mode(ContextDecorator):
+ class test(ContextDecorator):
  def __init__(self, mode:bool = True): self.mode = mode
  def __enter__(self): self.prev, Tensor.no_grad = Tensor.no_grad, self.mode
  def __exit__(self, exc_type, exc_value, traceback): Tensor.no_grad = self.prev
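A short usage sketch of the constructor paths added above (bytes are reinterpreted according to `dtype`, and a `pathlib.Path` with no explicit device stays on disk). It assumes tinygrad 0.10.0 is installed and, for the second example, that a local file `weights.bin` exists:

```python
import pathlib
from tinygrad import Tensor, dtypes

# bytes are bit-reinterpreted according to dtype; length must be a multiple of itemsize
t = Tensor(b"\x01\x00\x02\x00", dtype=dtypes.uint16)
print(t.numpy())             # [1 2] on a little-endian machine

# a pathlib.Path with no explicit device becomes a DISK tensor (dtype defaults to uint8)
w = Tensor(pathlib.Path("weights.bin"))
print(w.device, w.dtype)     # DISK:/.../weights.bin uint8
```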
@@ -182,17 +199,18 @@ class Tensor:

  # ***** data handlers ****

- def schedule_with_vars(self, *lst:Tensor, seen:Optional[Set[LazyBuffer]]=None) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
- """Creates the schedule needed to realize these Tensor(s), with Variables."""
- if getenv("FUZZ_SCHEDULE"):
- from test.external.fuzz_schedule import fuzz_schedule
- fuzz_schedule(flatten([x.lazydata.lbs for x in (self,)+lst]))
- schedule, var_vals = create_schedule_with_vars(flatten([x.lazydata.lbs for x in (self,)+lst]), seen)
+ def schedule_with_vars(self, *lst:Tensor) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
+ """
+ Creates the schedule needed to realize these Tensor(s), with Variables.
+
+ NOTE: A Tensor can only be scheduled once.
+ """
+ schedule, var_vals = create_schedule_with_vars(flatten([x.lazydata.lbs for x in (self,)+lst]))
  return memory_planner(schedule), var_vals

- def schedule(self, *lst:Tensor, seen:Optional[Set[LazyBuffer]]=None) -> List[ScheduleItem]:
+ def schedule(self, *lst:Tensor) -> List[ScheduleItem]:
  """Creates the schedule needed to realize these Tensor(s)."""
- schedule, var_vals = self.schedule_with_vars(*lst, seen=seen)
+ schedule, var_vals = self.schedule_with_vars(*lst)
  assert len(var_vals) == 0
  return schedule

@@ -226,7 +244,7 @@ class Tensor:
  assert self.dtype == x.dtype, f"assign dtype mismatch {self.dtype} != {x.dtype}"
  assert not isinstance(self.lazydata, MultiLazyBuffer) or self.lazydata.axis == x.lazydata.axis, "axis must match on MultiLazyBuffer"
  assert not x.requires_grad # self requires_grad is okay?
- if not self.lazydata.is_realized(): return self.replace(x)
+ if not self.lazydata.is_realized: return self.replace(x)
  self.lazydata = self.lazydata.assign(x.lazydata)
  return self

@@ -239,7 +257,7 @@ class Tensor:
  def _data(self) -> memoryview:
  if 0 in self.shape: return memoryview(bytearray(0))
  # NOTE: this realizes on the object from as_buffer being a Python object
- cpu = self.cast(self.dtype.scalar()).contiguous().to("CLANG").realize()
+ cpu = self.cast(self.dtype.base).contiguous().to("CLANG").realize()
  buf = cast(Buffer, cast(LazyBuffer, cpu.lazydata).base.realized)
  if self.device != "CLANG": buf.options = BufferOptions(nolru=True)
  return buf.as_buffer(allow_zero_copy=True if self.device != "CLANG" else False)
@@ -283,7 +301,7 @@ class Tensor:
  """
  return self.data().tolist()

- def numpy(self) -> np.ndarray:
+ def numpy(self) -> 'np.ndarray': # type: ignore [name-defined] # noqa: F821
  """
  Returns the value of this tensor as a `numpy.ndarray`.

@@ -292,11 +310,21 @@ class Tensor:
  print(repr(t.numpy()))
  ```
  """
+ import numpy as np
  if self.dtype == dtypes.bfloat16: return self.float().numpy()
  assert _to_np_dtype(self.dtype) is not None, f"no np dtype for {self.dtype}"
  assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
  return np.frombuffer(self._data(), dtype=_to_np_dtype(self.dtype)).reshape(self.shape)

+ def clone(self) -> Tensor:
+ """
+ Creates a clone of this tensor, allocating a separate buffer for the data.
+ """
+ ret = Tensor(self.lazydata.clone(), self.device, requires_grad=self.requires_grad)
+ if self.grad is not None: ret.grad = self.grad.clone()
+ if hasattr(self, '_ctx'): ret._ctx = self._ctx
+ return ret
+
  def to(self, device:Optional[Union[str, Tuple[str, ...]]]) -> Tensor:
  """
  Moves the tensor to the given device.
@@ -318,38 +346,54 @@ class Tensor:
  if self.grad is not None and real.grad is not None: self.grad.lazydata = real.grad.lazydata
  self.lazydata = real.lazydata

- def shard(self, devices:Tuple[str, ...], axis:Optional[int]=None) -> Tensor:
+ def shard(self, devices:Tuple[str, ...], axis:Optional[int]=None, splits:Optional[Tuple[int, ...]]=None) -> Tensor:
  """
- Shards the tensor across the given devices.
+ Shards the tensor across the given devices. Optionally specify which axis to shard on, and how to split it across devices.
+
+ ```python exec="true" source="above" session="tensor" result="python"
+ t = Tensor.empty(2, 3)
+ print(t.shard((t.device, t.device), axis=1, splits=(2, 1)).lazydata)
+ ```
+
  """
  assert isinstance(self.lazydata, LazyBuffer), "can't shard a MultiLazyBuffer"
- canonical_devices = tuple(Device.canonicalize(x) for x in devices)
- if axis is not None and axis < 0: axis += len(self.shape)
- return Tensor(MultiLazyBuffer.from_sharded(self.lazydata, canonical_devices, axis), device=canonical_devices, requires_grad=self.requires_grad)
-
- def shard_(self, devices:Tuple[str, ...], axis:Optional[int]=None):
+ devices, bounds = tuple(Device.canonicalize(x) for x in devices), None
+ if axis is not None:
+ if axis < 0: axis += len(self.shape)
+ if splits is None:
+ if not isinstance(total:=self.shape[axis], int): raise RuntimeError(f"cannot shard symbolic shape {self.shape=}, {axis=}")
+ sz = ceildiv(total, len(devices))
+ splits = tuple([max(0, min(sz, total - sz*i)) for i in range(len(devices))])
+ assert sum(splits) == self.shape[axis], "specified splits do not sum up to axis shape"
+ boundaries = tuple(itertools.accumulate(splits))
+ bounds = tuple(zip((0,) + boundaries, boundaries))
+ return Tensor(MultiLazyBuffer.from_sharded(self.lazydata, devices, axis, bounds), device=devices, requires_grad=self.requires_grad)
+
+ def shard_(self, devices:Tuple[str, ...], axis:Optional[int]=None, splits:Optional[Tuple[int, ...]]=None):
  """
  Shards the tensor across the given devices in place.
  """
- self.lazydata = self.shard(devices, axis).lazydata
+ self.lazydata = self.shard(devices, axis, splits).lazydata
  return self

  @staticmethod
- def from_node(y:Node, **kwargs) -> Tensor:
- if isinstance(y, NumNode): return Tensor(y.b, **kwargs, requires_grad=False)
- if isinstance(y, Variable): return Tensor(y, **kwargs, requires_grad=False)
- if isinstance(y, MulNode): return Tensor.from_node(y.a, **kwargs) * y.b
- if isinstance(y, SumNode): return Tensor.from_node(y.nodes[0], **kwargs) + sum(y.nodes[1:])
- raise RuntimeError(f"unhandled Node {y}")
+ def from_uop(y:UOp, **kwargs) -> Tensor:
+ if y.op is Ops.BIND: return Tensor(y, **kwargs, requires_grad=False) # this is the only UOp allowed in Tensor
+ if y.op is Ops.CONST: return Tensor(y.arg, **kwargs, requires_grad=False)
+ if y.op is Ops.MUL: return Tensor.from_uop(y.src[0]) * Tensor.from_uop(y.src[1])
+ if y.op is Ops.ADD: return Tensor.from_uop(y.src[0]) + Tensor.from_uop(y.src[1])
+ if y.op is Ops.MAX: return Tensor.from_uop(y.src[0]).maximum(Tensor.from_uop(y.src[1]))
+ raise RuntimeError(f"unhandled UOp {y}")

- # ***** creation llop entrypoint *****
+ # ***** creation entrypoint *****

  @staticmethod
- def _loadop(op, shape, device:Optional[Union[Tuple[str, ...], str]]=None, dtype:Optional[DType]=None, arg=None, **kwargs):
+ def _metaop(op, shape, device:Optional[Union[Tuple[str, ...], str]]=None, dtype:Optional[DTypeLike]=None, arg=None, **kwargs):
+ dtype = to_dtype(dtype) if dtype is not None else dtypes.default_float
  if isinstance(device, tuple):
- return Tensor(MultiLazyBuffer([LazyBuffer.loadop(op, shape, dtype or dtypes.default_float, Device.canonicalize(d), arg) \
- for d in device], None), device, dtype, **kwargs)
- return Tensor(LazyBuffer.loadop(op, shape, dtype or dtypes.default_float, Device.canonicalize(device), arg), device, dtype, **kwargs)
+ return Tensor(MultiLazyBuffer([LazyBuffer.metaop(op, shape, dtype, Device.canonicalize(d), arg) for d in device], None),
+ device, dtype, **kwargs)
+ return Tensor(LazyBuffer.metaop(op, shape, dtype, Device.canonicalize(device), arg), device, dtype, **kwargs)

  @staticmethod
  def empty(*shape, **kwargs):
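The default `splits` in the new `shard` signature are near-equal chunks, and `bounds` are their prefix-sum intervals. A standalone sketch of that arithmetic, mirroring the lines above without calling tinygrad:

```python
import itertools
from math import ceil

def default_splits(total: int, ndev: int):
    sz = ceil(total / ndev)  # same role as ceildiv in tinygrad.helpers
    return tuple(max(0, min(sz, total - sz * i)) for i in range(ndev))

splits = default_splits(7, 3)                       # (3, 3, 1)
boundaries = tuple(itertools.accumulate(splits))    # (3, 6, 7)
bounds = tuple(zip((0,) + boundaries, boundaries))  # ((0, 3), (3, 6), (6, 7))
print(splits, bounds)
```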
@@ -364,10 +408,39 @@ class Tensor:
  print(t.shape)
  ```
  """
- return Tensor._loadop(LoadOps.EMPTY, argfix(*shape), **kwargs)
+ return Tensor._metaop(Ops.EMPTY, argfix(*shape), **kwargs)
+
+ @staticmethod
+ def from_blob(ptr:int, shape:Tuple[int, ...], **kwargs) -> Tensor:
+ """
+ Exposes the pointer as a Tensor without taking ownership of the original data.
+ The pointer must remain valid for the entire lifetime of the created Tensor.
+
+ You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
+ """
+
+ r = Tensor._metaop(Ops.EMPTY, shape, **kwargs)
+ r.lazydata.buffer.allocate(external_ptr=ptr)
+ del r.lazydata.srcs # fake realize
+ return r
+
+ @staticmethod
+ def from_url(url:str, gunzip:bool=False, **kwargs) -> Tensor:
+ """
+ Create a Tensor from a URL.
+
+ This is the preferred way to access Internet resources.
+ It currently returns a DISK Tensor, but in the future it may return an HTTP Tensor.
+ This also will soon become lazy (when possible) and not print progress without DEBUG.
+
+ The `gunzip` flag will gzip-extract the resource and return an extracted Tensor.
+ """
+ return Tensor(fetch(url, gunzip=gunzip), **kwargs)

  _seed: int = int(time.time())
- _rng_counter: Optional[Tensor] = None
+ _device_seeds: Dict[str, Tensor] = {}
+ _device_rng_counters: Dict[str, Tensor] = {}
  @staticmethod
  def manual_seed(seed=0):
  """
@@ -384,10 +457,17 @@ class Tensor:
  print(Tensor.rand(5).numpy())
  ```
  """
- Tensor._seed, Tensor._rng_counter = seed, Tensor([0], dtype=dtypes.uint32, requires_grad=False)
+ Tensor._seed, Tensor._device_seeds, Tensor._device_rng_counters = seed, {}, {}

  @staticmethod
- def rand(*shape, device:Optional[Union[Tuple[str, ...], str]]=None, dtype:Optional[DType]=None, **kwargs):
+ def _threefry_random_bits(key:Tensor, counts0:Tensor, counts1:Tensor):
+ x = (counts1.cast(dtypes.uint64) << 32) | counts0.cast(dtypes.uint64)
+ x = F.Threefry.apply(x, (key[1]._broadcast_to(x.shape).cast(dtypes.uint64) << 32) | key[0]._broadcast_to(x.shape).cast(dtypes.uint64))
+ counts0, counts1 = (x & 0xffffffff).cast(dtypes.uint32), ((x >> 32) & 0xffffffff).cast(dtypes.uint32)
+ return counts0.cat(counts1)
+
+ @staticmethod
+ def rand(*shape, device:Optional[str]=None, dtype:Optional[DTypeLike]=None, contiguous:bool=True, **kwargs) -> Tensor:
  """
  Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[0, 1)`.

@@ -400,35 +480,55 @@ class Tensor:
  print(t.numpy())
  ```
  """
- if Tensor._rng_counter is None: Tensor._rng_counter = Tensor([0], dtype=dtypes.uint32, requires_grad=False)
- if not THREEFRY.value:
- # for bfloat16, numpy rand passes buffer in float
- if (dtype or dtypes.default_float) == dtypes.bfloat16:
- return Tensor.rand(*shape, **kwargs, device=device, dtype=dtypes.float).cast(dtypes.bfloat16)
- return Tensor._loadop(LoadOps.CUSTOM, argfix(*shape), arg=custom_random, device=device, dtype=dtype, **kwargs)
-
- # threefry
- if (num := prod((shape:=argfix(*shape)))) == 0: return Tensor.zeros(shape, device=device, dtype=dtype, **kwargs)
- counts1 = (Tensor.arange(math.ceil(num / 2), device=device, dtype=dtypes.uint32, requires_grad=False)+Tensor._rng_counter.to(device)).realize()
- counts2 = counts1 + math.ceil(num / 2)
- Tensor._rng_counter.assign(Tensor._rng_counter + num).realize()
-
- rotations = [[13, 15, 26, 6], [17, 29, 16, 24]]
- ks = [0x0, Tensor._seed ^ 0x0 ^ 0x1BD11BDA, Tensor._seed]
-
- x = [counts1 + ks[-1], counts2 + ks[0]]
- for i in range(5):
- for r in rotations[i % 2]: x[0], x[1] = (x0 := x[0] + x[1]), x0 ^ ((x[1] << r) + (x[1] >> (32 - r)))
- x = [(x[0] + ks[i % 3]), (x[1] + ks[(i + 1) % 3] + i + 1)]
- out = x[0].cat(x[1]).rshift(8).cast(dtypes.float32).div(2 ** 24)[:num]
- out = out.reshape(shape).cast(dtypes.default_float if dtype is None else dtype)
+ if not dtypes.is_float(dtype := to_dtype(dtype or dtypes.default_float)): raise ValueError(f"rand only supports float dtypes, got {dtype}")
+ if not all_int(shape:=argfix(*shape)) or not all(s >= 0 for s in shape): raise ValueError(f"invalid input {shape=}")
+ if device is not None and not isinstance(device, str): raise ValueError(f"rand only supports single device, got {device=}")
+ _device = device = Device.canonicalize(device)
+
+ # when using MOCKGPU and NV generate rand on CLANG
+ if getenv("MOCKGPU") and device.startswith("NV"): device = "CLANG"
+
+ # generate per device seeds and rng counter if we haven't seen this device yet
+ if device not in Tensor._device_seeds:
+ Tensor._device_seeds[device] = Tensor(
+ [int.from_bytes(hashlib.sha256(len(Tensor._device_seeds).to_bytes(4, "big")).digest(), "big"), Tensor._seed],
+ device=device, dtype=dtypes.uint32, requires_grad=False)
+ Tensor._device_rng_counters[device] = Tensor([0], device=device, dtype=dtypes.uint32, requires_grad=False)
+ had_counter = False
+ else: had_counter = True
+
+ # if shape has 0, return zero tensor
+ if (numel := prod(shape)) == 0: return Tensor.zeros(shape, device=_device, dtype=dtype, **kwargs)
+ num = ceildiv(numel * dtype.itemsize, 4)
+
+ # increment rng counter for devices
+ if had_counter: Tensor._device_rng_counters[device].assign(Tensor._device_rng_counters[device] + num).contiguous()
+
+ # threefry random bits
+ counts0 = (Tensor.arange(ceildiv(num, 2), device=device, dtype=dtypes.uint32, requires_grad=False)+Tensor._device_rng_counters[device])
+ counts1 = counts0 + ceildiv(num, 2)
+ bits = Tensor._threefry_random_bits(Tensor._device_seeds[device], counts0, counts1)[:num]
+
+ # bitcast to uint with same number of bits
+ _, nmant = dtypes.finfo(dtype)
+ uint_dtype = {1: dtypes.uint8, 2: dtypes.uint16, 4: dtypes.uint32, 8: dtypes.uint64}[dtype.itemsize]
+ bits = bits.bitcast(uint_dtype)
+ # only randomize the mantissa bits and set the exponent to 1
+ one = Tensor.ones_like(bits, device=bits.device, dtype=dtype).bitcast(uint_dtype)
+ bits = bits.rshift((dtype.itemsize * 8) - nmant).bitwise_or(one)
+ # bitcast back to the original dtype and reshape
+ out = bits.bitcast(dtype)[:numel].sub(1).reshape(shape)
+
+ # move back to the original device if we were using MOCKGPU
+ if getenv("MOCKGPU") and _device: out = out.to(_device)
+
  out.requires_grad = kwargs.get("requires_grad")
- return out.contiguous()
+ return out.contiguous() if contiguous else out

  # ***** creation helper functions *****

  @staticmethod
- def full(shape:Tuple[sint, ...], fill_value:ConstType, **kwargs):
+ def full(shape:Tuple[sint, ...], fill_value:ConstType, **kwargs) -> Tensor:
  """
  Creates a tensor with the given shape, filled with the given value.

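The uniform sampling above uses a standard bit trick: keep only enough random bits to fill the mantissa, OR in the exponent bits of 1.0 so the value lands in [1, 2), then subtract 1 to get [0, 1). A standalone illustration with `struct`, independent of tinygrad:

```python
import random, struct

def bits_to_unit_float(r: int) -> float:
    one = struct.unpack("<I", struct.pack("<f", 1.0))[0]  # 0x3f800000, the bit pattern of 1.0
    x = (r >> (32 - 23)) | one                            # 23 random mantissa bits, exponent of 1.0 -> [1, 2)
    return struct.unpack("<f", struct.pack("<I", x))[0] - 1.0

print(bits_to_unit_float(random.getrandbits(32)))         # uniform in [0, 1)
```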
@@ -445,7 +545,7 @@ class Tensor:
  return Tensor(fill_value, **kwargs).reshape((1, )*len(new_shape := argfix(shape))).expand(new_shape)

  @staticmethod
- def zeros(*shape, **kwargs):
+ def zeros(*shape, **kwargs) -> Tensor:
  """
  Creates a tensor with the given shape, filled with zeros.

@@ -462,7 +562,7 @@ class Tensor:
  return Tensor.full(argfix(*shape), 0.0, **kwargs)

  @staticmethod
- def ones(*shape, **kwargs):
+ def ones(*shape, **kwargs) -> Tensor:
  """
  Creates a tensor with the given shape, filled with ones.

@@ -479,7 +579,7 @@ class Tensor:
  return Tensor.full(argfix(*shape), 1.0, **kwargs)

  @staticmethod
- def arange(start, stop=None, step=1, **kwargs):
+ def arange(start, stop=None, step=1, **kwargs) -> Tensor:
  """
  Returns a 1-D tensor of size `ceil((stop - start) / step)` with values from `[start, stop)`, with spacing between values given by `step`.

@@ -504,14 +604,35 @@ class Tensor:
  ```
  """
  if stop is None: stop, start = start, 0
- assert all(isinstance(s, (int, float)) for s in (start, stop, step)), f"symbolic arange not supported {start=}, {stop=}, {step=}"
  dtype = kwargs.pop("dtype", dtypes.default_float if any(isinstance(x, float) for x in (start, stop, step)) else dtypes.default_int)
- return (Tensor.full((math.ceil((stop-start)/step),), step, dtype=dtype, **kwargs)._cumsum() + (start - step)).cast(dtype)
+ # NOTE: this matches numpy, torch raises RuntimeError if stop-start and step have different signs
+ if (output_len:=ceildiv(stop-start, step)) <= 0: return Tensor([], dtype=dtype, **kwargs)
+ return (Tensor.full((output_len,), step, dtype=dtype, **kwargs)._cumsum() + (start - step)).cast(dtype)
+
+ @staticmethod
+ def linspace(start:Union[int, float], stop:Union[int, float], steps:int, **kwargs) -> Tensor:
+ """
+ Returns a 1-D tensor of `steps` evenly spaced values from `start` to `stop`, inclusive.
+
+ You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+ ```python exec="true" source="above" session="tensor" result="python"
+ print(Tensor.linspace(0, 10, 5).numpy())
+ ```
+ ```python exec="true" source="above" session="tensor" result="python"
+ print(Tensor.linspace(-1, 1, 5).numpy())
+ ```
+ """
+ if steps < 0: raise ValueError("number of steps must be non-negative")
+ if (dtype := to_dtype(kwargs.pop("dtype", dtypes.default_float))) == dtypes.bool: raise ValueError("linspace with bool dtype is not supported")
+ if steps == 1: return Tensor([start], dtype=dtype, **kwargs)
+ return (start + Tensor.arange(steps, **kwargs) * ((stop - start) / (steps - 1))).cast(dtype)

  @staticmethod
- def eye(dim:int, **kwargs):
+ def eye(n:int, m:Optional[int]=None, **kwargs) -> Tensor:
  """
- Creates an identity matrix of the given dimension.
+ Returns a 2-D tensor with `n` rows and `m` columns, with ones on the diagonal and zeros elsewhere.

  You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
  Additionally, all other keyword arguments are passed to the constructor of the tensor.
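As the new implementation shows, `linspace` is just an affine transform of `arange`; a quick check of that relationship (assuming tinygrad 0.10.0 is installed):

```python
from tinygrad import Tensor

start, stop, steps = -1.0, 1.0, 5
a = Tensor.linspace(start, stop, steps)
b = start + Tensor.arange(steps) * ((stop - start) / (steps - 1))
print(a.numpy(), b.numpy())   # both: [-1.  -0.5  0.   0.5  1. ]
```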
@@ -519,10 +640,16 @@ class Tensor:
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor.eye(3).numpy())
  ```
+
+ ```python exec="true" source="above" session="tensor" result="python"
+ print(Tensor.eye(2, 4).numpy())
+ ```
  """
- return Tensor.ones((dim,1),**kwargs).pad((None,(0,dim))).flatten().shrink(((0,dim*dim),)).reshape(dim, dim)
+ if n < 0 or (m is not None and m < 0): raise ValueError(f"cannot have negative {n=}, {m=}")
+ x = Tensor.ones((n,1),**kwargs).pad((None,(0,n))).flatten().shrink(((0,n*n),)).reshape(n,n)
+ return x if m is None else x.pad((None, (0, m-n))) if m > n else x.shrink((None, (0, m)))

- def full_like(self, fill_value:ConstType, **kwargs):
+ def full_like(self, fill_value:ConstType, **kwargs) -> Tensor:
  """
  Creates a tensor with the same shape as `self`, filled with the given value.
  If `dtype` is not specified, the dtype of `self` is used.
@@ -537,7 +664,7 @@ class Tensor:
  """
  return Tensor.full(self.shape, fill_value, dtype=kwargs.pop("dtype", self.dtype), device=kwargs.pop("device", self.device), **kwargs)

- def zeros_like(self, **kwargs):
+ def zeros_like(self, **kwargs) -> Tensor:
  """
  Creates a tensor with the same shape as `self`, filled with zeros.

@@ -551,7 +678,7 @@ class Tensor:
  """
  return self.full_like(0, **kwargs)

- def ones_like(self, **kwargs):
+ def ones_like(self, **kwargs) -> Tensor:
  """
  Creates a tensor with the same shape as `self`, filled with ones.

@@ -565,10 +692,31 @@ class Tensor:
  """
  return self.full_like(1, **kwargs)

+ def rand_like(self, **kwargs) -> Tensor:
+ """
+ Creates a tensor with the same shape and sharding as `self`, filled with random values from a uniform distribution over the interval `[0, 1)`.
+
+ You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
+ Additionally, all other keyword arguments are passed to the constructor of the tensor.
+
+ ```python exec="true" source="above" session="tensor" result="python"
+ t = Tensor.ones(2, 3)
+ print(Tensor.rand_like(t).numpy())
+ ```
+ """
+ dtype = kwargs.pop("dtype", self.dtype)
+ if isinstance(self.device, tuple) and isinstance(self.lazydata, MultiLazyBuffer):
+ if kwargs.get("device") is not None: raise RuntimeError("cannot specify `device` on `rand_like` of a multi device tensor")
+ if self.lazydata.axis is None: return Tensor.rand(*self.shape, dtype=dtype, **kwargs).shard(self.device)
+ contiguous = kwargs.pop("contiguous", True)
+ rands = [Tensor.rand(*lb.shape, device=lb.device, dtype=dtype, contiguous=contiguous, **kwargs).lazydata for lb in self.lazydata.lbs]
+ return Tensor(MultiLazyBuffer(cast(List[LazyBuffer], rands), self.lazydata.axis), device=self.device, dtype=dtype, **kwargs)
+ return Tensor.rand(*self.shape, device=kwargs.pop("device", self.device), dtype=dtype, **kwargs)
+
  # ***** rng hlops *****

  @staticmethod
- def randn(*shape, dtype:Optional[DType]=None, **kwargs) -> Tensor:
+ def randn(*shape, dtype:Optional[DTypeLike]=None, **kwargs) -> Tensor:
  """
  Creates a tensor with the given shape, filled with random values from a normal distribution with mean `0` and standard deviation `1`.
  If `dtype` is not specified, the default type is used.
@@ -600,7 +748,7 @@ class Tensor:
  ```
  """
  if not isinstance(low, int) or not isinstance(high, int): raise TypeError(f"{low=} and {high=} must be integers")
- dtype = kwargs.pop("dtype", dtypes.int32)
+ dtype = to_dtype(kwargs.pop("dtype", dtypes.int32))
  if not dtypes.is_int(dtype): raise TypeError(f"{dtype=} must be int")
  return Tensor.uniform(*shape, low=low, high=high, dtype=dtype, **kwargs)

@@ -706,7 +854,7 @@ class Tensor:
  assert replacement or num_samples == 1, "no replacement only supports num_samples = 1"
  weight = self.unsqueeze(0) if self.ndim == 1 else self
  cdf = (cw := weight.cumsum(1).float()) / cw[:, -1].unsqueeze(1)
- unif_samples = Tensor.rand(num_samples, cdf.shape[0], 1, device=self.device)
+ unif_samples = Tensor.rand(num_samples, cdf.shape[0], 1).to(self.device)
  indices = (unif_samples.expand((-1, -1, cdf.shape[1])) >= cdf).sum(2).permute((1, 0))
  return (indices.squeeze(0) if self.ndim == 1 else indices).cast(dtypes.int32)

@@ -715,39 +863,46 @@ class Tensor:
  def _deepwalk(self):
  def _walk(node, visited):
  visited.add(node)
- if getattr(node, "_ctx", None):
+ # if tensor is not leaf, reset grad
+ if (ctx := getattr(node, "_ctx", None)) is not None and len(ctx.parents) != 0: node.grad = None
+ if ctx:
  for i in node._ctx.parents:
  if i not in visited: yield from _walk(i, visited)
  yield node
  return list(_walk(self, set()))

- def backward(self) -> Tensor:
+ def backward(self, gradient:Optional[Tensor]=None, retain_graph:bool=False) -> Tensor:
  """
  Propagates the gradient of a tensor backwards through the computation graph.
- Must be used on a scalar tensor.
-
+ If the 'gradient' argument is not provided, the tensor must be a scalar, and the gradient is implicitly set to 1.0.
+ If 'retain_graph' is false, the graph used to compute the grads will be freed. Otherwise, it will be kept. Keeping it can increase memory usage.
  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
  t.sum().backward()
  print(t.grad.numpy())
  ```
  """
- assert self.shape == tuple(), f"backward can only be called for scalar tensors, but it has shape {self.shape})"
-
- # fill in the first grad with one. don't use Tensor.ones because we don't need contiguous
- # this is "implicit gradient creation"
- self.grad = Tensor(1.0, dtype=self.dtype, device=self.device, requires_grad=False)
+ toposorted = self._deepwalk()
+ if gradient is None:
+ assert self.shape == tuple(), "when no gradient is provided, backward must be called on a scalar tensor"
+ # fill in the first grad with one. don't use Tensor.ones because we don't need contiguous
+ # this is "implicit gradient creation"
+ gradient = Tensor(1.0, dtype=self.dtype, device=self.device, requires_grad=False)

- for t0 in reversed(self._deepwalk()):
+ assert self.shape == gradient.shape, f"grad shape must match tensor shape, {gradient.shape!r} != {self.shape!r}"
+ self.grad = gradient
+ for t0 in reversed(toposorted):
  if t0.grad is None: raise RuntimeError(f"tensor {t0} has no grad")
+ token = _METADATA.set(dataclasses.replace(md, backward=True) if (md := t0._ctx.metadata) is not None else None)
  grads = t0._ctx.backward(t0.grad.lazydata)
+ _METADATA.reset(token)
  grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None
  for g in ([grads] if len(t0._ctx.parents) == 1 else grads)]
  for t, g in zip(t0._ctx.parents, grads):
  if g is not None and t.requires_grad:
  assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"
  t.grad = g if t.grad is None else (t.grad + g)
- del t0._ctx
+ if not retain_graph: del t0._ctx
  return self

  # ***** movement low level ops *****
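A small usage sketch of the extended `backward` signature: a non-scalar output needs an explicit `gradient`, and `retain_graph=True` keeps the graph so a second pass can accumulate into leaf grads (assumes tinygrad 0.10.0 is installed):

```python
from tinygrad import Tensor

x = Tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x * 2
y.backward(gradient=Tensor([1.0, 1.0, 1.0]), retain_graph=True)
print(x.grad.numpy())   # [2. 2. 2.]
y.backward(gradient=Tensor([1.0, 1.0, 1.0]))
print(x.grad.numpy())   # [4. 4. 4.] -- leaf grads accumulate across passes
```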
@@ -822,7 +977,7 @@ class Tensor:
  ```
  """
  axis_arg = tuple(self._resolve_dim(x) for x in argfix(axis, *args))
- if len(axis_arg) != len(dedup(axis_arg)): raise RuntimeError(f"dim can appear at least once, getting {axis_arg}")
+ if len(axis_arg) != len(dedup(axis_arg)): raise RuntimeError(f"dim can appear at most once, getting {axis_arg}")
  return F.Flip.apply(self, axis=axis_arg)

  def shrink(self, arg:Tuple[Optional[Tuple[sint, sint]], ...]) -> Tensor:
@@ -842,30 +997,58 @@ class Tensor:
  print(t.shrink((((0, 2), (0, 2)))).numpy())
  ```
  """
- if all(x is None or x == (0,s) for x,s in zip(arg, self.shape)): return self
- return F.Shrink.apply(self, arg=tuple(x if x is not None else (0,s) for x,s in zip(arg, self.shape)))
+ if (shrink_arg:=[x if x is not None else (0,s) for x,s in zip(arg, self.shape)]) == [(0,s) for s in self.shape]: return self
+ return F.Shrink.apply(self, arg=tuple(shrink_arg))

- def pad(self, arg:Tuple[Optional[Tuple[sint, sint]], ...], value:float=0.0) -> Tensor:
+ def pad(self, padding:Union[Sequence[sint], Sequence[Optional[Tuple[sint, sint]]]], mode:str="constant", value:float=0.0) -> Tensor:
  """
- Returns a tensor that pads the each axis based on input arg.
- `arg` must have the same length as `self.ndim`.
- For each axis, it can be `None`, which means no pad, or a tuple `(pad_before, pad_after)`.
- If `value` is specified, the tensor is padded with `value` instead of `0.0`.
+ Returns a tensor with padding applied based on the input `padding`.
+ `padding` supports two padding structures:
+
+ 1. Flat padding: (padding_left, padding_right, padding_top, padding_bottom, ...)
+ - This structure matches PyTorch's pad.
+ - `padding` length must be even.
+
+ 2. Group padding: (..., (padding_top, padding_bottom), (padding_left, padding_right))
+ - This structure matches pad for jax, numpy, tensorflow and others.
+ - For each axis, padding can be `None`, meaning no padding, or a tuple `(start, end)`.
+ - `padding` must have the same length as `self.ndim`.
+
+ Padding values can be negative, resulting in dimension shrinks that work similarly to Python negative slices.
+ Padding mode is selected with `mode`, which supports `constant` and `reflect`.

  ```python exec="true" source="above" session="tensor" result="python"
- t = Tensor.arange(6).reshape(2, 3)
+ t = Tensor.arange(9).reshape(1, 1, 3, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
- print(t.pad(((None, (1, 2)))).numpy())
+ print(t.pad((1, 2, 0, -1)).numpy())
+ ```
+ ```python exec="true" source="above" session="tensor" result="python"
+ print(t.pad(((None, None, (0, -1), (1, 2)))).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
- print(t.pad(((None, (1, 2))), -2).numpy())
+ print(t.pad((1, 2, 0, -1), value=-float('inf')).numpy())
  ```
  """
- if all(x is None or x == (0,0) for x in arg): return self
- ret = F.Pad.apply(self, arg=(narg:=tuple(x if x is not None else (0,0) for x in arg)))
- return ret if 0 == value else ret + F.Pad.apply(Tensor.ones_like(self), arg=narg).where(0, value)
+ if mode not in {"constant", "reflect"}: raise NotImplementedError(f"{mode=} is not supported")
+ if (flat:=all(isinstance(p, (int,UOp)) for p in padding)) and len(padding)%2 != 0: raise ValueError("Flat padding must have even number of pads")
+ # turn flat padding into group padding
+ pX = ((0,0),)*(self.ndim - len(padding)//2) + tuple(zip(padding[-2::-2], padding[::-2])) if flat else padding
+ if len(pX) != self.ndim: raise ValueError(f"padding length is improper, {padding=} {self.ndim=}")
+ X, pX = self, cast(Tuple[Tuple[sint, sint]], tuple((0,0) if p is None else p for p in pX))
+ def _constant(x,px,v): return F.Pad.apply(x, arg=px) if v == 0 else F.Pad.apply(x, arg=px) + F.Pad.apply(Tensor.ones_like(x), arg=px).where(0, v)
+ # early return for symbolic with positive pads (no need to max)
+ if mode == "constant" and all(resolve(p >= 0) for p in flatten(pX)): return _constant(X, pX, value)
+ pads, shrinks = tuple((smax(pB,0), smax(pA,0)) for pB,pA in pX), lambda shape: tuple((-smin(pB,0),smin(pA+s,s)) for (pB,pA),s in zip(pX, shape))
+ if mode == "constant": return _constant(X.shrink(shrinks(X.shape)), pads, value)
+ assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
+ for d,(pB,pA) in enumerate(pads):
+ if pB >= (s:=X.shape[d]) or pA>=s: raise ValueError(f"Padding ({pB}, {pA}) should be less than the input size={s} for dim={d}.")
+ slcB, slcA, = slice(pB,0,-1), slice(s-2 if s-2>=0 else None, s-2-pA if s-2-pA>=0 else None, -1)
+ xB, xA = (X[[slc if i == d else slice(None) for i in range(X.ndim)]] if p > 0 else None for slc, p in ((slcB, pB), (slcA, pA)))
+ X = Tensor.cat(*(X_ for X_ in (xB, X, xA) if X_ is not None), dim=d)
+ return X.shrink(shrinks(X.shape))

  # ***** movement high level ops *****

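The two padding structures accepted by the new `pad` are interchangeable; for a 2-D tensor, flat PyTorch-style `(left, right, top, bottom)` maps onto group-style `((top, bottom), (left, right))`. A quick equivalence check (assumes tinygrad 0.10.0 is installed):

```python
from tinygrad import Tensor

t = Tensor.arange(9).reshape(3, 3)
a = t.pad((1, 1, 2, 0))        # flat: left=1, right=1, top=2, bottom=0
b = t.pad(((2, 0), (1, 1)))    # group: one (before, after) pair per axis
print((a == b).all().item())   # True
```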
@@ -897,7 +1080,7 @@ class Tensor:
  # 2. Bool indexing is not supported
  # 3. Out of bounds Tensor indexing results in 0
  # - e.g: Tensor([1, 2, 3])[Tensor([4, 3, 2])] -> [0, 0, 3] index 4 and 3 are out of bounds
- def __getitem__(self, indices) -> Tensor:
+ def _getitem(self, indices, v: Optional[Tensor] = None) -> Tensor:
  # 1. indices normalization and validation
  # treat internal tuples and lists as Tensors and standardize indices to list type
  if isinstance(indices, list) and all_int(indices): indices = [Tensor(indices, self.device, requires_grad=False)]
@@ -921,7 +1104,6 @@ class Tensor:

  # record None for dimension injection later and filter None and record rest of indices
  type_dim[None] = [dim for dim, i in enumerate(indices) if i is None]
- tensor_dims = [dim for dim, i in enumerate(indices) if isinstance(i, Tensor)]
  indices_filtered = [i for i in indices if i is not None]
  for dim,i in enumerate(indices_filtered): type_dim[type(i)].append(dim)

@@ -939,13 +1121,15 @@ class Tensor:
  indices_filtered[dim] = ((index, index+1), 1) if index >= 0 else ((size+index, size+index+1), 1)
  for dim in type_dim[slice]:
  if (index := indices_filtered[dim]).step == 0: raise ValueError(f"{index=} on {dim=} cannot have 0 as step")
+ if not all(isinstance(x, (int, type(None))) for x in (index.start, index.stop, index.step)):
+ raise TypeError(f"Unsupported slice for dimension {dim}. Expected slice with integers or None, got slice("
+ f"{', '.join(type(x).__name__ for x in (index.start, index.stop, index.step))}).")
  s, e, st = index.indices(self.shape[dim])
  indices_filtered[dim] = ((0, 0) if (st * (e - s)) < 0 else (s, e) if st > 0 else (e+1, s+1), st)
- # record tensors and skip all Tensor dims for basic indexing
- tensor_index: List[Tensor] = []
+ # skip all Tensor dims for basic indexing
  for dim in type_dim[Tensor]:
- tensor_index.append(index := indices_filtered[dim])
- if not dtypes.is_int(index.dtype): raise IndexError(f"{index.dtype=} on {dim=} is not supported, only int tensor indexing is supported")
+ dtype = indices_filtered[dim].dtype
+ if not dtypes.is_int(dtype): raise IndexError(f"{dtype=} on {dim=} is not supported, only int tensor indexing is supported")
  indices_filtered[dim] = ((0, self.shape[dim]), 1)

  new_slice, strides = ((), ()) if not indices_filtered else zip(*indices_filtered)
@@ -955,6 +1139,7 @@ class Tensor:
  if any(abs(st) != 1 for st in strides):
  strides = tuple(abs(s) for s in strides)
  # pad shape to multiple of stride
+ if not all_int(ret.shape): raise RuntimeError("symbolic shape not supported")
  ret = ret.pad(tuple((0, round_up(s, st) - s) for s, st in zip(ret.shape, strides)))
  ret = ret.reshape(tuple(flatten((s // st, st) for s, st in zip(ret.shape, strides))))
  ret = ret.shrink(tuple(flatten(((0, s), (0, 1)) for s in ret.shape[::2]))).reshape(ret.shape[::2])
@@ -968,6 +1153,7 @@ class Tensor:

  # 3. advanced indexing (copy)
  if type_dim[Tensor]:
+ dim_tensors = [(dim, i) for dim, i in enumerate(indices) if isinstance(i, Tensor)]
  # calculate dim of current ret by subtracting dims collapsed and adding dims injected up until tensor_dim
  def calc_dim(tensor_dim:int) -> int:
  return tensor_dim - sum(1 for d in dims_collapsed if tensor_dim >= d)
@@ -975,7 +1161,7 @@ class Tensor:
  assert all_int(ret.shape), f"does not support symbolic shape {ret.shape}"
  # track tensor_dim and tensor_index using a dict
  # calc_dim to get dim and use that to normalize the negative tensor indices
- idx: Dict[int,Tensor] = {(dim := calc_dim(td)):(tensor<0).where(ret.shape[dim],0) + tensor for td,tensor in zip(tensor_dims, tensor_index)}
+ idx: Dict[int,Tensor] = {(dim := calc_dim(td)):(tensor<0).where(ret.shape[dim],0) + tensor for td,tensor in dim_tensors}

  masks, first_dim, last_dim = [], min(idx.keys()), max(idx.keys())
  pre_reduce_shape = ret.shape[:first_dim] + (big_shape := _broadcast_shape(*(t.shape for t in idx.values()))) + ret.shape[first_dim:]
@@ -993,35 +1179,47 @@ class Tensor:
993
1179
  # inject 1's for the extra dims added in create masks
994
1180
  reshape_arg = ret.shape[:first_dim] + (1,) * len(big_shape) + ret.shape[first_dim:]
995
1181
  # sum reduce the extra dims introduced in create masks
996
- ret = (ret.reshape(reshape_arg) * mask).sum(tuple(i + len(big_shape) for i in idx.keys()), acc_dtype=ret.dtype)
1182
+ ret = (ret.reshape(reshape_arg) * mask).sum(sum_axis:=tuple(i + len(big_shape) for i in idx.keys()), acc_dtype=ret.dtype)
997
1183
 
998
1184
  # special permute case
999
1185
  if first_dim != 0 and len(idx) != 1 and tuple(idx.keys()) != tuple(range(first_dim, last_dim+1)):
1000
1186
  ret = ret.permute(*range(first_dim, first_dim+len(big_shape)), *range(0, first_dim), *range(first_dim+len(big_shape), ret.ndim))
1187
+
1188
+ # for advanced setitem, returns whole tensor with indices replaced
1189
+ if v is not None:
1190
+ vb = v.cast(self.dtype)._broadcast_to(_broadcast_shape(ret.shape, v.shape))
1191
+ # add back reduced dims from sum
1192
+ for dim in sum_axis: vb = vb.unsqueeze(dim)
1193
+ # axis to be reduced to match self.shape
1194
+ axis = tuple(range(first_dim, first_dim + len(big_shape)))
1195
+ # apply mask to v(broadcasted) and reduce such that if v contains repeated indices the last one remains
1196
+ vb = vb * mask
1197
+ for dim in axis: vb = functools.reduce(lambda x,y: y.where(y, x), vb.split(1, dim))
1198
+ # reduce mask and select from v(get rid of extra dims from reduce) for each True element in mask else select from self
1199
+ ret = mask.any(axis).where(vb.squeeze(), self)
1200
+
1001
1201
  return ret
1002
1202
 
1203
+ def __getitem__(self, indices) -> Tensor:
1204
+ return self._getitem(indices)
1205
+
1003
1206
  def __setitem__(self, indices, v:Union[Tensor, ConstType]) -> None:
1004
1207
  if isinstance(self.device, str) and self.device.startswith("DISK"):
1005
- self.__getitem__(indices).assign(v)
1208
+ self._getitem(indices).assign(v)
1006
1209
  return
1007
1210
  # NOTE: check that setitem target is valid first
1008
- assert all(lb.st.contiguous for lb in self.lazydata.lbs), "setitem target needs to be contiguous"
1211
+ if not all(lb.st.contiguous for lb in self.lazydata.lbs): raise RuntimeError("setitem target needs to be contiguous")
1009
1212
  if not isinstance(v, (Tensor, float, int, bool)): raise TypeError(f"can't set a {type(v).__name__} to a Tensor")
1010
1213
  if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype)
1011
1214
  if self.requires_grad or v.requires_grad: raise NotImplementedError("setitem with requires_grad is not supported")
1012
- if isinstance(indices, (Tensor, list)) or (isinstance(indices, tuple) and any(isinstance(i, (Tensor, list)) for i in indices)):
1013
- raise NotImplementedError("Advanced indexing setitem is not currently supported")
1014
-
1015
- assign_to = self.realize().__getitem__(indices)
1016
- # NOTE: contiguous to prevent const folding.
1017
- v = v.cast(assign_to.dtype)._broadcast_to(_broadcast_shape(assign_to.shape, v.shape)).contiguous()
1018
- assign_to.assign(v).realize()
1019
1215
 
1020
- # NOTE: using _slice is discouraged and things should migrate to pad and shrink
1021
- def _slice(self, arg:Sequence[Optional[Tuple[int, sint]]], value:float=0) -> Tensor:
1022
- arg_ = tuple(a if a is not None else (0, s) for s,a in zip(self.shape, arg))
1023
- padding = tuple((max(0, -l), max(0, r-s)) for s,(l,r) in zip(self.shape, arg_))
1024
- return self.pad(padding, value=value).shrink(tuple((l + pl, r + pl) for (l,r),(pl,_) in zip(arg_, padding)))
1216
+ res = self.realize()._getitem(indices, v)
1217
+ # if shapes match and data is not shared it's a copy and we assign to self
1218
+ if res.shape == self.shape and res.lazydata is not self.lazydata:
1219
+ self.assign(res).realize()
1220
+ else: # no copy, basic setitem
1221
+ v = v.cast(res.dtype)._broadcast_to(_broadcast_shape(res.shape, v.shape)).contiguous()
1222
+ res.assign(v).realize()
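
The rewritten `__setitem__` above routes assignment through `_getitem(indices, v)`, so tensor-index (advanced) setitem is no longer rejected. A minimal sketch of the new behavior, assuming a contiguous float target on the default device; the values are arbitrary:

```python
from tinygrad import Tensor

t = Tensor([[0.0, 0.0, 0.0] for _ in range(4)])  # contiguous setitem target
t[Tensor([0, 2])] = 1.0                          # tensor indices on the left-hand side
print(t.numpy())                                 # rows 0 and 2 become 1.0
```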
1025
1223
 
1026
1224
  def gather(self:Tensor, dim:int, index:Tensor) -> Tensor:
1027
1225
  """
@@ -1036,8 +1234,8 @@ class Tensor:
1036
1234
  ```
1037
1235
  """
1038
1236
  assert index.ndim == self.ndim, f"self.ndim must equal index.ndim, {self.ndim=}, {index.ndim=}"
1039
- assert all(s >= i for d,(s,i) in enumerate(zip(self.shape, index.shape)) if d != dim), "requires self.shape[d] >= index.shape[d] for all d != dim"
1040
1237
  dim = self._resolve_dim(dim)
1238
+ assert all(s >= i for d,(s,i) in enumerate(zip(self.shape, index.shape)) if d != dim), "requires self.shape[d] >= index.shape[d] for all d != dim"
1041
1239
  index = index.to(self.device)
1042
1240
  x = self.shrink(tuple((0, i) if d != dim else None for d,i in enumerate(index.shape))).unsqueeze(-1).transpose(-1, dim)
1043
1241
  return ((index.unsqueeze(-1) == Tensor.arange(self.shape[dim], requires_grad=False, device=self.device)) * x).sum(-1, acc_dtype=self.dtype)
@@ -1056,13 +1254,11 @@ class Tensor:
1056
1254
  ```
1057
1255
  """
1058
1256
  dim = self._resolve_dim(dim)
1059
- assert all(len(y.shape) == len(self.shape) and all(y.shape[i] == s for i,s in enumerate(self.shape) if i != dim) for y in args)
1060
- catargs = [self, *args]
1061
- cat_dims = [s.shape[dim] for s in catargs]
1062
- cat_dim_cumsum = [0, *itertools.accumulate(cat_dims)]
1063
- slc:List[List[Optional[Tuple[sint, sint]]]] = [[None for _ in self.shape] for _ in catargs]
1064
- for d,k,s in zip(cat_dims, cat_dim_cumsum[:-1], slc): s[dim] = (k, cat_dim_cumsum[-1] - k - d)
1065
- return functools.reduce(Tensor.__add__, [arg.pad(tuple(s)) for arg,s in zip(catargs, slc)])
1257
+ for arg in args: assert arg.ndim==self.ndim and all(ti==ai for i,(ti,ai) in enumerate(zip(self.shape, arg.shape)) if i!=dim)
1258
+ tensors = [self, *args]
1259
+ dim_cumsum = list(itertools.accumulate([t.shape[dim] for t in tensors], initial=0))
1260
+ for i,t in enumerate(tensors): tensors[i] = t.pad([(dim_cumsum[i], dim_cumsum[-1]-dim_cumsum[i+1]) if j==dim else None for j in range(t.ndim)])
1261
+ return functools.reduce(Tensor.add, tensors)
1066
1262
 
1067
1263
  def stack(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
1068
1264
  """
@@ -1077,7 +1273,20 @@ class Tensor:
1077
1273
  ```
1078
1274
  """
1079
1275
  # checks for shapes and number of dimensions delegated to cat
1080
- return self.unsqueeze(dim).cat(*[t.unsqueeze(dim) for t in args], dim=dim)
1276
+ return Tensor.cat(*[t.unsqueeze(dim) for t in [self, *args]], dim=dim)
1277
+
1278
+ def repeat_interleave(self, repeats:int, dim:Optional[int]=None) -> Tensor:
1279
+ """
1280
+ Repeat elements of a tensor.
1281
+
1282
+ ```python exec="true" source="above" session="tensor" result="python"
1283
+ t = Tensor([1, 2, 3])
1284
+ print(t.repeat_interleave(2).numpy())
1285
+ ```
1286
+ """
1287
+ x, dim = (self.flatten(), 0) if dim is None else (self, self._resolve_dim(dim))
1288
+ shp = x.shape
1289
+ return x.reshape(*shp[:dim+1], 1, *shp[dim+1:]).expand(*shp[:dim+1], repeats, *shp[dim+1:]).reshape(*shp[:dim], shp[dim]*repeats, *shp[dim+1:])
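
`repeat_interleave` is built purely from movement ops: insert a unit axis after `dim`, expand it to `repeats`, then fold it back in. A quick sketch of the effect along an explicit axis (values arbitrary):

```python
from tinygrad import Tensor

t = Tensor([[1, 2], [3, 4]])
# (2, 2) -> reshape (2, 2, 1) -> expand (2, 2, 3) -> reshape (2, 6)
print(t.repeat_interleave(3, dim=1).numpy())  # [[1 1 1 2 2 2] [3 3 3 4 4 4]]
```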
1081
1290
 
1082
1291
  def repeat(self, repeats, *args) -> Tensor:
1083
1292
  """
@@ -1093,16 +1302,16 @@ class Tensor:
1093
1302
  ```
1094
1303
  """
1095
1304
  repeats = argfix(repeats, *args)
1096
- base_shape = (1,) * (len(repeats) - self.ndim) + self.shape
1097
- new_shape = [x for b in base_shape for x in [1, b]]
1098
- expand_shape = [x for rs in zip(repeats, base_shape) for x in rs]
1305
+ base_shape = _pad_left(self.shape, repeats)[0]
1306
+ unsqueezed_shape = flatten([[1, s] for s in base_shape])
1307
+ expanded_shape = flatten([[r, s] for r,s in zip(repeats, base_shape)])
1099
1308
  final_shape = [r*s for r,s in zip(repeats, base_shape)]
1100
- return self.reshape(new_shape).expand(expand_shape).reshape(final_shape)
1309
+ return self.reshape(unsqueezed_shape).expand(expanded_shape).reshape(final_shape)
1101
1310
 
1102
- def _resolve_dim(self, dim:int, *, outer:bool=False) -> int:
1103
- if not -max(1, self.ndim+outer) <= dim < max(1, self.ndim+outer):
1104
- raise IndexError(f"{dim=} out of range {[-max(1, self.ndim+outer), max(1, self.ndim+outer)-1]}")
1105
- return dim + self.ndim+outer if dim < 0 else dim
1311
+ def _resolve_dim(self, dim:int, *, extra:bool=False) -> int:
1312
+ total = self.ndim + int(extra)
1313
+ if not -max(1, total) <= dim <= max(1, total)-1: raise IndexError(f"{dim=} out of range {[-max(1, total), max(1, total)-1]}")
1314
+ return dim + total if dim < 0 else dim
1106
1315
 
1107
1316
  def split(self, sizes:Union[int, List[int]], dim:int=0) -> Tuple[Tensor, ...]:
1108
1317
  """
@@ -1151,7 +1360,34 @@ class Tensor:
1151
1360
  assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
1152
1361
  assert chunks > 0, f"expect chunks to be greater than 0, got: {chunks}"
1153
1362
  dim = self._resolve_dim(dim)
1154
- return list(self.split(math.ceil(self.shape[dim]/chunks) if self.shape[dim] else [0]*chunks, dim=dim))
1363
+ return list(self.split(ceildiv(self.shape[dim], chunks) if self.shape[dim] else [0]*chunks, dim=dim))
1364
+
1365
+ def meshgrid(self:Tensor, *args:Tensor, indexing:Union[Literal["ij"], Literal["xy"]]="ij") -> Tuple[Tensor, ...]:
1366
+ """
1367
+ Generates coordinate matrices from coordinate vectors.
1368
+ Input tensors can be scalars or 1D tensors.
1369
+
1370
+ `indexing` determines how the output grids are aligned.
1371
+ `ij` indexing follows matrix-style indexing and `xy` indexing follows Cartesian-style indexing.
1372
+
1373
+ ```python exec="true" source="above" session="tensor" result="python"
1374
+ x, y = Tensor([1, 2, 3]), Tensor([4, 5, 6])
1375
+ grid_x, grid_y = x.meshgrid(y)
1376
+ print(grid_x.numpy())
1377
+ print(grid_y.numpy())
1378
+ ```
1379
+ ```python exec="true" source="above" session="tensor" result="python"
1380
+ grid_x, grid_y = x.meshgrid(y, indexing="xy")
1381
+ print(grid_x.numpy())
1382
+ print(grid_y.numpy())
1383
+ ```
1384
+ """
1385
+ if indexing not in ("ij", "xy"): raise RuntimeError(f'indexing must be in ("ij", "xy"), got {indexing}')
1386
+ if len(tensors:=(self, *args)) == 1: return tensors
1387
+ basis = tuple(range(len(tensors))) if indexing == "ij" else (1, 0) + tuple(range(2, len(tensors)))
1388
+ tensors = tuple(t.reshape((-1,) + (1,)*(len(args) - i)) for i,t in zip(basis, tensors))
1389
+ output_shape = _broadcast_shape(*(t.shape for t in tensors))
1390
+ return tuple(t._broadcast_to(output_shape) for t in tensors)
1155
1391
 
1156
1392
  def squeeze(self, dim:Optional[int]=None) -> Tensor:
1157
1393
  """
@@ -1185,25 +1421,9 @@ class Tensor:
1185
1421
  print(t.unsqueeze(1).numpy())
1186
1422
  ```
1187
1423
  """
1188
- dim = self._resolve_dim(dim, outer=True)
1424
+ dim = self._resolve_dim(dim, extra=True)
1189
1425
  return self.reshape(self.shape[:dim] + (1,) + self.shape[dim:])
1190
1426
 
1191
- def pad2d(self, padding:Sequence[int], value:float=0.0) -> Tensor:
1192
- """
1193
- Returns a tensor that pads the last two axes specified by `padding` (padding_left, padding_right, padding_top, padding_bottom).
1194
- If `value` is specified, the tensor is padded with `value` instead of `0.0`.
1195
-
1196
- ```python exec="true" source="above" session="tensor" result="python"
1197
- t = Tensor.arange(9).reshape(1, 1, 3, 3)
1198
- print(t.numpy())
1199
- ```
1200
- ```python exec="true" source="above" session="tensor" result="python"
1201
- print(t.pad2d((1, 1, 2, 0), value=-float("inf")).numpy())
1202
- ```
1203
- """
1204
- slc = [(-p0, s+p1) for p0,p1,s in zip(padding[::2], padding[1::2], self.shape[::-1])][::-1]
1205
- return self._slice([(0,s) for s in self.shape[:-(len(padding)//2)]] + slc, value=value)
1206
-
1207
1427
  @property
1208
1428
  def T(self) -> Tensor:
1209
1429
  """`.T` is an alias for `.transpose()`."""
@@ -1259,20 +1479,37 @@ class Tensor:
1259
1479
  dim = self._resolve_dim(dim)
1260
1480
  return self.reshape(self.shape[:dim] + sizes + self.shape[dim+1:])
1261
1481
 
1482
+ def roll(self, shifts:Union[int, Tuple[int, ...]], dims:Union[int, Tuple[int, ...]]) -> Tensor:
1483
+ """
1484
+ Rolls the tensor along specified dimension(s).
1485
+ The rolling operation is circular, meaning that elements that go beyond the edge are wrapped around to the beginning of the dimension.
1486
+
1487
+ ```python exec="true" source="above" session="tensor" result="python"
1488
+ t = Tensor.arange(4)
1489
+ print(t.roll(shifts=1, dims=0).numpy())
1490
+ ```
1491
+ ```python exec="true" source="above" session="tensor" result="python"
1492
+ print(t.roll(shifts=-1, dims=0).numpy())
1493
+ ```
1494
+ """
1495
+ dims, rolled = tuple(self._resolve_dim(d) for d in make_tuple(dims, 1)), self
1496
+ for dim, shift in zip(dims, make_tuple(shifts, 1)):
1497
+ shift = shift % self.shape[dim]
1498
+ rolled = Tensor.cat(rolled[tuple(slice(None) if i != dim else slice(-shift, None) for i in range(rolled.ndim))],
1499
+ rolled[tuple(slice(None) if i != dim else slice(None, -shift) for i in range(rolled.ndim))], dim=dim)
1500
+ return rolled
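
`roll` builds the circular shift out of two slices concatenated along the rolled axis. A short sketch rolling two axes at once (values arbitrary):

```python
from tinygrad import Tensor

t = Tensor.arange(6).reshape(2, 3)
# rows roll down by 1, columns roll left by 1
print(t.roll(shifts=(1, -1), dims=(0, 1)).numpy())  # [[4 5 3] [1 2 0]]
```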
1501
+
1262
1502
  # ***** reduce ops *****
1263
1503
 
1264
1504
  def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False) -> Tensor:
1265
- if self.ndim == 0:
1266
- if axis is not None and any(a not in [-1, 0] for a in fully_flatten([axis])): raise IndexError(f"{axis=} out of range of [-1, 0]")
1267
- axis = ()
1268
- axis_: Tuple[int, ...] = tuple(range(len(self.shape))) if axis is None else ((axis,) if isinstance(axis, int) else tuple(axis))
1269
- axis_ = tuple(self._resolve_dim(x) for x in axis_)
1270
- ret = fxn.apply(self, axis=axis_)
1271
- return ret if keepdim else ret.reshape(tuple(s for i,s in enumerate(self.shape) if i not in axis_))
1505
+ axis = tuple(self._resolve_dim(x) for x in (range(self.ndim) if axis is None else make_tuple(axis, 1)))
1506
+ if self.ndim == 0: axis = ()
1507
+ ret = fxn.apply(self, axis=axis)
1508
+ return ret if keepdim else ret.reshape(tuple(s for i,s in enumerate(self.shape) if i not in axis))
1272
1509
 
1273
- def sum(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, acc_dtype:Optional[DType]=None):
1510
+ def sum(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, acc_dtype:Optional[DTypeLike]=None):
1274
1511
  """
1275
- Sums the elements of the tensor along the specified axis or axes.
1512
+ Returns the sum of the elements of the tensor along the specified axis or axes.
1276
1513
 
1277
1514
  You can pass in `axis` and `keepdim` keyword arguments to control the axis along
1278
1515
  which the sum is computed and whether the reduced dimensions are retained.
@@ -1294,9 +1531,35 @@ class Tensor:
1294
1531
  print(t.sum(axis=1).numpy())
1295
1532
  ```
1296
1533
  """
1297
- ret = self.cast(acc_dtype or sum_acc_dtype(self.dtype))._reduce(F.Sum, axis, keepdim)
1534
+ ret = self.cast(sum_acc_dtype(self.dtype) if acc_dtype is None else acc_dtype)._reduce(F.Sum, axis, keepdim)
1298
1535
  return ret.cast(self.dtype) if acc_dtype is None and self.dtype in (dtypes.float16, dtypes.bfloat16) else ret
1299
1536
 
1537
+ def prod(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, acc_dtype:Optional[DTypeLike]=None):
1538
+ """
1539
+ Returns the product of the elements of the tensor along the specified axis or axes.
1540
+
1541
+ You can pass in `axis` and `keepdim` keyword arguments to control the axis along
1542
+ which the product is computed and whether the reduced dimensions are retained.
1543
+
1544
+ You can pass in `acc_dtype` keyword argument to control the data type of the accumulation.
1545
+ If not specified, the accumulation data type is chosen based on the input tensor's data type.
1546
+
1547
+ ```python exec="true" source="above" session="tensor" result="python"
1548
+ t = Tensor([-1, -2, -3, 1, 2, 3]).reshape(2, 3)
1549
+ print(t.numpy())
1550
+ ```
1551
+ ```python exec="true" source="above" session="tensor" result="python"
1552
+ print(t.prod().numpy())
1553
+ ```
1554
+ ```python exec="true" source="above" session="tensor" result="python"
1555
+ print(t.prod(axis=0).numpy())
1556
+ ```
1557
+ ```python exec="true" source="above" session="tensor" result="python"
1558
+ print(t.prod(axis=1).numpy())
1559
+ ```
1560
+ """
1561
+ return self.cast(acc_dtype if acc_dtype is not None else self.dtype)._reduce(F.Prod, axis, keepdim)
1562
+
1300
1563
  def max(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
1301
1564
  """
1302
1565
  Returns the maximum value of the tensor along the specified axis or axes.
@@ -1341,8 +1604,53 @@ class Tensor:
1341
1604
  print(t.min(axis=1, keepdim=True).numpy())
1342
1605
  ```
1343
1606
  """
1607
+ if dtypes.is_int(self.dtype) or self.dtype == dtypes.bool: return ~((~self).max(axis=axis, keepdim=keepdim))
1344
1608
  return -((-self).max(axis=axis, keepdim=keepdim))
1345
1609
 
1610
+ def any(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
1611
+ """
1612
+ Tests if any element evaluates to `True` along the specified axis or axes.
1613
+
1614
+ You can pass in `axis` and `keepdim` keyword arguments to control the reduce axis and whether the reduced dimensions are retained.
1615
+
1616
+ ```python exec="true" source="above" session="tensor" result="python"
1617
+ t = Tensor([[True, True], [True, False], [False, False]])
1618
+ print(t.numpy())
1619
+ ```
1620
+ ```python exec="true" source="above" session="tensor" result="python"
1621
+ print(t.any().numpy())
1622
+ ```
1623
+ ```python exec="true" source="above" session="tensor" result="python"
1624
+ print(t.any(axis=0).numpy())
1625
+ ```
1626
+ ```python exec="true" source="above" session="tensor" result="python"
1627
+ print(t.any(axis=1, keepdim=True).numpy())
1628
+ ```
1629
+ """
1630
+ return self.bool().max(axis, keepdim)
1631
+
1632
+ def all(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
1633
+ """
1634
+ Tests if all elements evaluate to `True` along the specified axis or axes.
1635
+
1636
+ You can pass in `axis` and `keepdim` keyword arguments to control the reduce axis and whether the reduced dimensions are retained.
1637
+
1638
+ ```python exec="true" source="above" session="tensor" result="python"
1639
+ t = Tensor([[True, True], [True, False], [False, False]])
1640
+ print(t.numpy())
1641
+ ```
1642
+ ```python exec="true" source="above" session="tensor" result="python"
1643
+ print(t.all().numpy())
1644
+ ```
1645
+ ```python exec="true" source="above" session="tensor" result="python"
1646
+ print(t.all(axis=0).numpy())
1647
+ ```
1648
+ ```python exec="true" source="above" session="tensor" result="python"
1649
+ print(t.all(axis=1, keepdim=True).numpy())
1650
+ ```
1651
+ """
1652
+ return self.logical_not().any(axis, keepdim).logical_not()
1653
+
1346
1654
  def mean(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
1347
1655
  """
1348
1656
  Returns the mean value of the tensor along the specified axis or axes.
@@ -1367,7 +1675,7 @@ class Tensor:
1367
1675
  """
1368
1676
  output_dtype = self.dtype if dtypes.is_float(self.dtype) else dtypes.float32
1369
1677
  numerator = self.cast(sum_acc_dtype(self.dtype)).sum(axis=axis, keepdim=keepdim)
1370
- return numerator.div(prod([si for si, so in zip(self.shape, self.sum(axis=axis, keepdim=True).shape) if si != so])).cast(output_dtype)
1678
+ return numerator.div(prod([si for si, so in zip(self.shape, self.sum(axis=axis, keepdim=True).shape) if resolve(si != so)])).cast(output_dtype)
1371
1679
 
1372
1680
  def var(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, correction=1):
1373
1681
  """
@@ -1392,8 +1700,8 @@ class Tensor:
1392
1700
  ```
1393
1701
  """
1394
1702
  squares = (self - self.mean(axis=axis, keepdim=True)).square()
1395
- n = prod([si for si, so in zip(self.shape, squares.sum(axis=axis, keepdim=True).shape) if si != so])
1396
- return squares.sum(axis=axis, keepdim=keepdim).div(max(0, n-correction))
1703
+ n = prod([si for si, so in zip(self.shape, squares.sum(axis=axis, keepdim=True).shape) if resolve(si != so)])
1704
+ return squares.sum(axis=axis, keepdim=keepdim).div(smax([0, n-correction]))
1397
1705
 
1398
1706
  def std(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, correction=1):
1399
1707
  """
@@ -1419,12 +1727,30 @@ class Tensor:
1419
1727
  """
1420
1728
  return self.var(axis, keepdim, correction).sqrt()
1421
1729
 
1422
- def _softmax(self, axis):
1423
- m = self - self.max(axis=axis, keepdim=True)
1730
+ def std_mean(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, correction=1):
1731
+ """
1732
+ Calculates the standard deviation and mean over the dimensions specified by `axis`.
1733
+ Syntactic sugar around `Tensor.std` and `Tensor.mean` to match `torch.std_mean`.
1734
+
1735
+ ```python exec="true" source="above" session="tensor" result="python"
1736
+ Tensor.manual_seed(42)
1737
+ t = Tensor.normal(2, 3, mean=2.5, std=0.5)
1738
+ print(t.numpy())
1739
+ ```
1740
+ ```python exec="true" source="above" session="tensor" result="python"
1741
+ std, mean = t.std_mean()
1742
+ print(std.numpy(), mean.numpy())
1743
+ ```
1744
+ """
1745
+ return self.std(axis, keepdim, correction), self.mean(axis, keepdim)
1746
+
1747
+ def _softmax(self, axis, dtype:Optional[DTypeLike]=None):
1748
+ x = self.cast(dtype) if dtype is not None else self
1749
+ m = x - x.max(axis=axis, keepdim=True).detach()
1424
1750
  e = m.exp()
1425
1751
  return m, e, e.sum(axis=axis, keepdim=True)
1426
1752
 
1427
- def softmax(self, axis=-1):
1753
+ def softmax(self, axis=-1, dtype:Optional[DTypeLike]=None):
1428
1754
  """
1429
1755
  Applies the softmax function to the tensor along the specified axis.
1430
1756
 
@@ -1444,10 +1770,10 @@ class Tensor:
1444
1770
  print(t.softmax(axis=0).numpy())
1445
1771
  ```
1446
1772
  """
1447
- _, e, ss = self._softmax(axis)
1773
+ _, e, ss = self._softmax(axis, dtype)
1448
1774
  return e.div(ss)
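
With the new `dtype` parameter, the softmax intermediates (max subtraction, exp and sum) can be accumulated in a wider type than the input; the result stays in that dtype. A minimal sketch, assuming the backend supports float16 inputs:

```python
from tinygrad import Tensor, dtypes

t = Tensor([[1.0, 2.0, 3.0]], dtype=dtypes.half)
out = t.softmax(axis=-1, dtype=dtypes.float32)  # computed and returned in float32
print(out.dtype, out.numpy())
```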
1449
1775
 
1450
- def log_softmax(self, axis=-1):
1776
+ def log_softmax(self, axis=-1, dtype:Optional[DTypeLike]=None):
1451
1777
  """
1452
1778
  Applies the log-softmax function to the tensor along the specified axis.
1453
1779
 
@@ -1467,7 +1793,7 @@ class Tensor:
1467
1793
  print(t.log_softmax(axis=0).numpy())
1468
1794
  ```
1469
1795
  """
1470
- m, _, ss = self._softmax(axis)
1796
+ m, _, ss = self._softmax(axis, dtype)
1471
1797
  return m - ss.log()
1472
1798
 
1473
1799
  def logsumexp(self, axis=None, keepdim=False):
@@ -1497,6 +1823,33 @@ class Tensor:
1497
1823
  m = self.max(axis=axis, keepdim=True)
1498
1824
  return (self - m).exp().sum(axis=axis, keepdim=keepdim).log() + m.squeeze(axis)
1499
1825
 
1826
+ def logcumsumexp(self, axis=0):
1827
+ """
1828
+ Computes the log-cumsum-exp of the tensor along the specified axis or axes.
1829
+
1830
+ The log-cumsum-exp function is a numerically stable way to compute the logarithm of the cumulative sum of exponentials.
1831
+
1832
+ You can pass in the `axis` keyword argument to control the axis along which
1833
+ the log-cumsum-exp is computed.
1834
+
1835
+ ```python exec="true" source="above" session="tensor" result="python"
1836
+ Tensor.manual_seed(42)
1837
+ t = Tensor.randn(2, 3)
1838
+ print(t.numpy())
1839
+ ```
1840
+ ```python exec="true" source="above" session="tensor" result="python"
1841
+ print(t.logcumsumexp().numpy())
1842
+ ```
1843
+ ```python exec="true" source="above" session="tensor" result="python"
1844
+ print(t.logcumsumexp(axis=0).numpy())
1845
+ ```
1846
+ ```python exec="true" source="above" session="tensor" result="python"
1847
+ print(t.logcumsumexp(axis=1).numpy())
1848
+ ```
1849
+ """
1850
+ m = self.max(axis=axis, keepdim=True)
1851
+ return (self - m).exp().cumsum(axis=axis).log() + m
1852
+
1500
1853
  def argmax(self, axis=None, keepdim=False):
1501
1854
  """
1502
1855
  Returns the indices of the maximum value of the tensor along the specified axis.
@@ -1521,8 +1874,8 @@ class Tensor:
1521
1874
  if axis is None: return self.flatten().argmax(0)
1522
1875
  axis = self._resolve_dim(axis)
1523
1876
  m = self == self.max(axis=axis, keepdim=True)
1524
- idx = m * Tensor.arange(self.shape[axis]-1,-1,-1, requires_grad=False, device=self.device).reshape(self.shape[axis], *[1]*(self.ndim-axis-1))
1525
- return (self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)-1).cast(dtypes.int32)
1877
+ idx = m * Tensor.arange(self.shape[axis],0,-1, requires_grad=False, device=self.device).reshape(self.shape[axis], *[1]*(self.ndim-axis-1))
1878
+ return (self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)).cast(dtypes.int32)
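
A worked instance of the new index arithmetic: for a row `[5, 9, 9]` the equality mask is `[0, 1, 1]`, the descending `arange(3, 0, -1)` is `[3, 2, 1]`, so `idx = [0, 2, 1]`, its max is `2`, and `3 - 2 = 1` selects the first occurrence of the maximum. A quick sketch:

```python
from tinygrad import Tensor

t = Tensor([5, 9, 9])
print(t.argmax().numpy())  # 1: ties resolve to the first occurrence
```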
1526
1879
 
1527
1880
  def argmin(self, axis=None, keepdim=False):
1528
1881
  """
@@ -1547,8 +1900,48 @@ class Tensor:
1547
1900
  """
1548
1901
  return (-self).argmax(axis=axis, keepdim=keepdim)
1549
1902
 
1903
+ def rearrange(self, formula: str, **sizes) -> Tensor:
1904
+ """
1905
+ Rearranges the input according to the given formula.
1906
+
1907
+ See: https://einops.rocks/api/rearrange/
1908
+
1909
+ ```python exec="true" source="above" session="tensor" result="python"
1910
+ x = Tensor([[1, 2], [3, 4]])
1911
+ print(Tensor.rearrange(x, "batch channel -> (batch channel)").numpy())
1912
+ ```
1913
+ """
1914
+ def parse_formula(formula: str):
1915
+ tokens = f" {formula} ".replace("…", "...").replace("(", " ( ").replace(")", " ) ").replace(" ", " ").replace(" 1 ", " ( ) ").split()
1916
+ lparens, rparens = map(lambda x: [i for i, ch in enumerate(tokens) if ch == x], ("(", ")"))
1917
+ pairs = list(zip(lparens, rparens))
1918
+ assert len(lparens) == len(rparens) and sorted(flatten(pairs)) == flatten(pairs), "bracket mismatch"
1919
+ return [name for name in tokens if name not in ("(", ")")], [(s - 2*i, e - 1 - 2*i) for i, (s, e) in enumerate(pairs)]
1920
+
1921
+ assert formula.count("->") == 1, 'need exactly one "->" in formula'
1922
+
1923
+ (lhs, unflatten_dims), (rhs, flatten_dims) = map(parse_formula, formula.split("->"))
1924
+
1925
+ for name in sizes: assert name in lhs, f"axis {name} is not used in transform"
1926
+ assert sorted(lhs) == sorted(rhs) and len(lhs) == len(set(lhs)), f"name mismatch in {formula}"
1927
+ for name in flatten((lhs, rhs)): assert name == "..." or (name.isidentifier() and "_" not in (name[0], name[-1])), f"invalid axis name {name}"
1928
+ assert "..." not in flatten([lhs[s:e] for s, e in unflatten_dims]), f"cannot have collapsed ellipsis (...) in lhs of {formula}"
1929
+ assert lhs.count("...") <= 1, f"too many ellipses in {formula}"
1930
+
1931
+ # resolve ellipsis
1932
+ if "..." in lhs: ell_len = len(self.shape) - len(lhs) + 1 + sum(e - s - 1 for s, e in unflatten_dims)
1933
+ lhs, rhs = map(lambda l: l[:(i:=l.index("..."))] + [f"...{j}" for j in range(ell_len)] + l[i + 1:] if "..." in l else l, (lhs, rhs))
1934
+ unflatten_dims = [(s + (ell_len - 1 if "...0" in lhs[:s] else 0), e + (ell_len - 1 if "...0" in lhs[:e] else 0)) for s, e in unflatten_dims]
1935
+ flatten_dims = [(s + (ell_len - 1 if "...0" in rhs[:s] else 0), e + (ell_len - 1 if "...0" in rhs[:e] else 0)) for s, e in flatten_dims]
1936
+
1937
+ # apply movement ops in order unflatten -> permute -> flatten/unsqueeze
1938
+ t = functools.reduce(lambda x, dims: x.unflatten(dims[0], tuple(sizes.get(lhs[d], -1) for d in range(*dims))), unflatten_dims, self)
1939
+ for i, name in enumerate(lhs): assert (name not in sizes) or sizes[name] == t.shape[i], f"size provided for dimension {name} incorrect"
1940
+ t = t.permute([lhs.index(name) for name in rhs])
1941
+ return functools.reduce(lambda x, dims: x.flatten(dims[0], dims[1] - 1) if dims[0]<dims[1] else x.unsqueeze(dims[0]), reversed(flatten_dims), t)
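
A short usage sketch of the einops-style `rearrange` added above, unflattening a grouped axis with explicit sizes and permuting; the axis names here are arbitrary:

```python
from tinygrad import Tensor

x = Tensor.arange(24).reshape(2, 3, 4)
# unflatten the last axis into (h, w), then move the batch axis to the end
y = x.rearrange("b c (h w) -> c h w b", h=2, w=2)
print(y.shape)  # (3, 2, 2, 2)
```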
1942
+
1550
1943
  @staticmethod
1551
- def einsum(formula:str, *raw_xs, acc_dtype:Optional[DType]=None) -> Tensor:
1944
+ def einsum(formula:str, *operands:Tensor|Sequence[Tensor], acc_dtype:Optional[DTypeLike]=None) -> Tensor:
1552
1945
  """
1553
1946
  Sums the product of the elements of the input tensors according to a formula based on the Einstein summation convention.
1554
1947
 
@@ -1560,11 +1953,20 @@ class Tensor:
1560
1953
  print(Tensor.einsum("ij,ij->", x, y).numpy())
1561
1954
  ```
1562
1955
  """
1563
- xs:Tuple[Tensor] = argfix(*raw_xs)
1564
- formula = formula.replace(" ", "")
1565
- inputs_str, output = formula.split("->") if "->" in formula else (formula, \
1566
- ''.join(c for c in sorted(formula) if formula.count(c) == 1 and c.isalpha()))
1567
- inputs = inputs_str.split(',')
1956
+ def parse_formula(formula:str, *operands:Tensor):
1957
+ if "..." in (formula := formula.replace(" ", "")):
1958
+ ell_chars, ell_longest = "".join(set(string.ascii_letters) - set(formula)), 0
1959
+ for i, inp in enumerate(filter(lambda x: "..." in x, inputs := formula.split("->")[0].split(","))):
1960
+ if (ell_count := max(operands[i].ndim, 1) - (len(inp) - len("..."))) > ell_longest: ell_longest = ell_count
1961
+ inputs[i] = inp.replace("...", ell_chars[-ell_count:])
1962
+ inputs_str, out_ellipse = ",".join(inputs), ell_chars[-ell_longest:]
1963
+ return (inputs_str, formula.split("->")[1].replace("...", out_ellipse)) if "->" in formula else \
1964
+ (inputs_str, out_ellipse + ''.join(sorted(c for c in inputs_str if inputs_str.count(c) == 1 and c.isalpha() and c not in out_ellipse)))
1965
+ return formula.split("->") if "->" in formula else (formula, ''.join(c for c in sorted(formula) if formula.count(c) == 1 and c.isalpha()))
1966
+
1967
+ xs:Tuple[Tensor, ...] = argfix(*operands)
1968
+ inputs_str, output = parse_formula(formula, *xs)
1969
+ inputs = inputs_str.split(",")
1568
1970
  assert len(xs) == len(inputs), f"number of inputs doesn't match number of operands in formula, expected {len(inputs)}, got {len(xs)}"
1569
1971
 
1570
1972
  # map the value of each letter in the formula
@@ -1576,41 +1978,43 @@ class Tensor:
1576
1978
  # permute to the sorted letter order, then reshape/expand to create dimensions for the missing letters
1577
1979
  xs_.append(x.permute(order).reshape([val if letter in letters else 1 for letter,val in letter_val]).expand([val for _,val in letter_val]))
1578
1980
 
1579
- # determine the inverse permutation to revert back to original order
1580
- rhs_letter_order = argsort(list(output))
1581
- rhs_order = argsort(rhs_letter_order)
1981
+ # ordinal encode the output alphabet
1982
+ rhs_order = argsort(argsort(list(output)))
1582
1983
 
1583
1984
  # sum over all axes that's not in the output, then permute to the output order
1584
1985
  return functools.reduce(lambda a,b:a*b, xs_) \
1585
- .sum(axis=[axis for axis,(letter,_) in enumerate(letter_val) if letter not in output],acc_dtype=acc_dtype).permute(rhs_order)
1986
+ .sum(axis=[axis for axis,(letter,_) in enumerate(letter_val) if letter not in output], acc_dtype=acc_dtype).permute(rhs_order)
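
The new `parse_formula` expands `...` into unused letters per operand, so leading batch dimensions can be carried through implicitly. A minimal sketch of a batched matmul written with an ellipsis (shapes arbitrary):

```python
from tinygrad import Tensor

a, b = Tensor.ones(4, 2, 3), Tensor.ones(4, 3, 5)
# "..." stands for the leading batch dimensions of each operand
print(Tensor.einsum("...ij,...jk->...ik", a, b).shape)  # (4, 2, 5)
```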
1586
1987
 
1587
1988
  # ***** processing ops *****
1588
1989
 
1589
1990
  def _pool(self, k_:Tuple[sint, ...], stride:Union[Tuple[int, ...], int]=1, dilation:Union[Tuple[int, ...], int]=1) -> Tensor:
1590
1991
  assert len(self.shape) >= len(k_), f"can't pool {self.shape} with {k_}"
1591
- assert all_int(self.shape) and all_int(k_), f"does not support symbolic {self.shape=}, {k_=}"
1592
- s_, d_ = make_pair(stride, len(k_)), make_pair(dilation, len(k_))
1992
+ s_, d_ = make_tuple(stride, len(k_)), make_tuple(dilation, len(k_))
1593
1993
  assert len(k_) == len(s_) == len(d_), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}"
1594
- noop_, i_ = [None] * len(self.shape[:-len(k_)]), self.shape[-len(k_):]
1595
- o_ = [math.ceil((i - d * (k-1))/s) for i,d,k,s in zip(i_, d_, k_, s_)]
1596
- if any(k > s for k,s in zip(k_, s_)) or any(d != 1 for d in d_):
1994
+ noop, i_ = [None] * (self.ndim-len(k_)), self.shape[-len(k_):]
1995
+ assert all(resolve(d*(k-1)+1 <= i) for k,d,i in zip(k_,d_,i_)), "kernel size cannot be greater than actual input size"
1996
+ o_ = [ceildiv(i-d*(k-1), s) for i,d,k,s in zip(i_,d_,k_,s_)]
1997
+ if any(resolve(k > s) for k,s in zip(k_,s_)) or any(d != 1 for d in d_):
1597
1998
  # repeats such that we don't need padding
1598
- xup = self.repeat([1]*len(noop_) + [math.ceil(k*(i+d) / i) for k,i,d in zip(k_, i_, d_)])
1999
+ x = self.repeat([1]*len(noop) + [ceildiv(k*(i+d), i) for k,i,d in zip(k_,i_,d_)])
1599
2000
  # handle dilation
1600
- xup = xup.shrink(tuple(noop_ + [(0,k*(i+d)) for k,i,d in zip(k_, i_, d_)])).reshape(noop_ + flatten((k,i+d) for k,i,d in zip(k_, i_, d_)))
2001
+ x = x.shrink(tuple(noop + [(0,k*(i+d)) for k,i,d in zip(k_,i_,d_)])).reshape(noop + flatten((k,i+d) for k,i,d in zip(k_,i_,d_)))
1601
2002
  # handle stride
1602
- xup = xup.shrink(noop_ + flatten(((0,k), (0,o*s)) for k,o,s in zip(k_, o_, s_))).reshape(noop_ + flatten((k,o,s) for k,o,s in zip(k_, o_, s_)))
1603
- xup = xup.shrink(noop_ + flatten(((0,k), (0,o), (0,1)) for k,o in zip(k_, o_))).reshape(noop_ + flatten((k,o) for k,o in zip(k_, o_)))
2003
+ x = x.shrink(tuple(noop + flatten(((0,k), (0,o*s)) for k,o,s in zip(k_,o_,s_)))).reshape(noop + flatten((k,o,s) for k,o,s in zip(k_,o_,s_)))
2004
+ x = x.shrink(tuple(noop + flatten(((0,k), (0,o), (0,1)) for k,o in zip(k_,o_)))).reshape(noop + flatten((k,o) for k,o in zip(k_,o_)))
1604
2005
  # permute to move reduce to the end
1605
- return xup.permute(*range(len(noop_)), *[len(noop_)+i*2+1 for i in range(len(i_))], *[len(noop_)+i*2 for i in range(len(i_))])
2006
+ return x.permute(*range(len(noop)), *[len(noop)+i*2+1 for i in range(len(i_))], *[len(noop)+i*2 for i in range(len(i_))])
1606
2007
  # TODO: once the shapetracker can optimize well, remove this alternative implementation
1607
- xup = self.pad(tuple(noop_ + [(0, max(0,o*s-i)) for i,o,s in zip(i_, o_, s_)])).shrink(tuple(noop_ + [(0,o*s) for o,s in zip(o_, s_)]))
1608
- xup = xup.reshape(noop_ + flatten(((o,s) for o,s in zip(o_, s_))))
1609
- xup = xup.shrink(noop_ + flatten(((0,o), (0,k)) for o,k in zip(o_, k_)))
1610
- return xup.permute(*range(len(noop_)), *[len(noop_)+i*2 for i in range(len(i_))], *[len(noop_)+i*2+1 for i in range(len(i_))])
2008
+ x = self.pad(tuple(noop + [(0, max(0,o*s-i)) for i,o,s in zip(i_,o_,s_)])).shrink(tuple(noop + [(0,o*s) for o,s in zip(o_,s_)]))
2009
+ x = x.reshape(noop + flatten(((o,s) for o,s in zip(o_,s_))))
2010
+ x = x.shrink(tuple(noop + flatten(((0,o), (0,k)) for o,k in zip(o_,k_))))
2011
+ return x.permute(*range(len(noop)), *[len(noop)+i*2 for i in range(len(i_))], *[len(noop)+i*2+1 for i in range(len(i_))])
2012
+
2013
+ def _padding2d(self, padding:Union[int, Sequence[int]], dims:int) -> Sequence[int]:
2014
+ return [padding]*2*dims if isinstance(padding, int) else (padding if len(padding) == 2*dims else [p for p in padding for _ in range(2)][::-1])
1611
2015
 
1612
2016
  # NOTE: these work for more than 2D
1613
- def avg_pool2d(self, kernel_size=(2,2), stride=None, dilation=1):
2017
+ def avg_pool2d(self, kernel_size=(2,2), stride=None, dilation=1, padding=0, count_include_pad=True):
1614
2018
  """
1615
2019
  Applies average pooling over a tensor.
1616
2020
 
@@ -1622,11 +2026,15 @@ class Tensor:
1622
2026
  t = Tensor.arange(25).reshape(1, 1, 5, 5)
1623
2027
  print(t.avg_pool2d().numpy())
1624
2028
  ```
2029
+ ```python exec="true" source="above" session="tensor" result="python"
2030
+ print(t.avg_pool2d(padding=1).numpy())
2031
+ ```
1625
2032
  """
1626
- kernel_size = make_pair(kernel_size)
1627
- return self._pool(kernel_size, stride if stride is not None else kernel_size, dilation).mean(axis=tuple(range(-len(kernel_size), 0)))
2033
+ padding_, axis = self._padding2d(padding, len(k_ := make_tuple(kernel_size, 2))), tuple(range(-len(k_), 0))
2034
+ def pool(x:Tensor) -> Tensor: return x.pad(padding_)._pool(k_, stride if stride is not None else k_, dilation)
2035
+ return pool(self).mean(axis=axis) if count_include_pad else pool(self).sum(axis=axis) / pool(self.ones_like()).sum(axis=axis)
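
With the new `padding` and `count_include_pad` arguments, the averaging denominator can either include the zero padding or count only the valid input cells (via the `pool(self.ones_like())` normalizer above). A small sketch of the difference on an all-ones input:

```python
from tinygrad import Tensor

t = Tensor.ones(1, 1, 4, 4)
print(t.avg_pool2d(kernel_size=2, stride=2, padding=1).numpy())                           # border windows are diluted, corners are 0.25
print(t.avg_pool2d(kernel_size=2, stride=2, padding=1, count_include_pad=False).numpy())  # all ones
```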
1628
2036
 
1629
- def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1):
2037
+ def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1, padding=0):
1630
2038
  """
1631
2039
  Applies max pooling over a tensor.
1632
2040
 
@@ -1638,11 +2046,15 @@ class Tensor:
1638
2046
  t = Tensor.arange(25).reshape(1, 1, 5, 5)
1639
2047
  print(t.max_pool2d().numpy())
1640
2048
  ```
2049
+ ```python exec="true" source="above" session="tensor" result="python"
2050
+ print(t.max_pool2d(padding=1).numpy())
2051
+ ```
1641
2052
  """
1642
- kernel_size = make_pair(kernel_size)
1643
- return self._pool(kernel_size, stride if stride is not None else kernel_size, dilation).max(axis=tuple(range(-len(kernel_size), 0)))
2053
+ padding_ = self._padding2d(padding, len(k_ := make_tuple(kernel_size, 2)))
2054
+ return self.pad(padding_, value=dtypes.min(self.dtype))._pool(k_, stride if stride is not None else k_, dilation).max(tuple(range(-len(k_), 0)))
1644
2055
 
1645
- def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, acc_dtype:Optional[DType]=None) -> Tensor:
2056
+ def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding:int|Tuple[int, ...]=0,
2057
+ acc_dtype:Optional[DTypeLike]=None) -> Tensor:
1646
2058
  """
1647
2059
  Applies a convolution over a tensor with a given `weight` and optional `bias`.
1648
2060
 
@@ -1656,13 +2068,14 @@ class Tensor:
1656
2068
  print(t.conv2d(w).numpy())
1657
2069
  ```
1658
2070
  """
2071
+ if IMAGE: return self.image_conv2d(weight, bias, groups, stride, dilation, padding, acc_dtype)
1659
2072
  (bs,cin_), (cout,cin), HW = self.shape[:2], weight.shape[:2], weight.shape[2:]
1660
2073
  assert groups*cin == cin_ and len(self.shape) == len(weight.shape), f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})" # noqa: E501
1661
2074
  if isinstance(padding, (tuple,list)): assert len(padding) == 2*len(HW) or len(padding) == len(HW), f"Expected padding of length {2*len(HW)} or {len(HW)}, but got {len(padding)} for tensor of shape {self.shape}" # noqa: E501
1662
- padding_ = [padding]*2*len(HW) if isinstance(padding, int) else (padding if len(padding) == 2*len(HW) else [p for p in padding for _ in range(2)][::-1]) # noqa: E501
2075
+ padding_ = self._padding2d(padding, len(HW))
1663
2076
 
1664
2077
  # conv2d is a pooling op (with padding)
1665
- x = self.pad2d(padding_)._pool(HW, stride, dilation) # (bs, groups*cin, oy, ox, H, W)
2078
+ x = self.pad(padding_)._pool(HW, stride, dilation) # (bs, groups*cin, oy, ox, H, W)
1666
2079
  rcout, oyx = cout//groups, x.shape[2:-len(HW)]
1667
2080
  if not all(x == 3 for x in HW) or stride != 1 or dilation != 1 or not WINO:
1668
2081
  # normal conv
@@ -1680,7 +2093,7 @@ class Tensor:
1680
2093
  # todo: stride == dilation
1681
2094
  # use padding to round up to 4x4 output tiles
1682
2095
  # (bs, cin_, tyx, HWI)
1683
- d = self.pad2d(sum([[padding_[i*2], padding_[i*2+1] + (-(dim + sum(padding_[i * 2:(i + 1) * 2]) - 2) % 4)] for i, dim in enumerate(self.shape[-len(HW):])], []))._pool(HWI, HWO) # noqa: E501
2096
+ d = self.pad(sum([[padding_[i*2], padding_[i*2+1] + (-(dim + sum(padding_[i * 2:(i + 1) * 2]) - 2) % 4)] for i, dim in enumerate(self.shape[-len(HW):])], []))._pool(HWI, HWO) # noqa: E501
1684
2097
  # move HW to the front: # (HWI, bs, cin_, tyx)
1685
2098
  d = d.permute(*range(len(d.shape)-len(HW),len(d.shape)), *range(len(d.shape)-len(HW)))
1686
2099
  tyx = d.shape[-len(HWI):] # dim of tiling
@@ -1719,7 +2132,7 @@ class Tensor:
1719
2132
  """
1720
2133
  x, w = self, weight.unflatten(0, (groups, -1)).transpose(1, 2).flip(*range(3, len(weight.shape)+1))
1721
2134
  HW = weight.shape[2:]
1722
- stride, dilation, padding, output_padding = [make_pair(x, len(HW)) for x in (stride, dilation, padding, output_padding)]
2135
+ stride, dilation, padding, output_padding = [make_tuple(x, len(HW)) for x in (stride, dilation, padding, output_padding)]
1723
2136
  if any(s>1 for s in stride):
1724
2137
  # handle strides: (k) -> reshape -> (k,1) -> pad -> (k,s) -> reshape -> (k*s) -> shrink (k-(s-1))
1725
2138
  x = x.reshape(None, None, *flatten((k,1) for k in x.shape[2:]))
@@ -1729,26 +2142,35 @@ class Tensor:
1729
2142
  padding = flatten((((k-1)*d-p,(k-1)*d-p+op) for k,d,p,op in reversed(list(zip(HW, dilation, padding, output_padding)))))
1730
2143
  return x.conv2d(w.flatten(end_dim=1), groups=groups, bias=bias, dilation=dilation, padding=padding)
1731
2144
 
1732
- def dot(self, w:Tensor, acc_dtype:Optional[DType]=None) -> Tensor:
2145
+ def dot(self, w:Tensor, acc_dtype:Optional[DTypeLike]=None) -> Tensor:
2146
+
1733
2147
  """
1734
2148
  Performs dot product between two tensors.
2149
+ If `w` is 1-D, it's a sum product over the last axis of `self` and `w`.
2150
+ If `w` is N-D with N>=2, it's a sum product over the last axis of `self` and the second-to-last axis of `w`.
1735
2151
 
1736
2152
  You can pass in the optional `acc_dtype` keyword argument to control the data type of the accumulation.
1737
2153
 
2154
+ ```python exec="true" source="above" session="tensor" result="python"
2155
+ a = Tensor([1, 2, 3])
2156
+ b = Tensor([1, 1, 0])
2157
+ print(a.dot(b).numpy())
2158
+ ```
1738
2159
  ```python exec="true" source="above" session="tensor" result="python"
1739
2160
  a = Tensor([[1, 2], [3, 4]])
1740
2161
  b = Tensor([[5, 6], [7, 8]])
1741
2162
  print(a.dot(b).numpy())
1742
2163
  ```
1743
2164
  """
1744
- n1, n2 = len(self.shape), len(w.shape)
1745
- assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D"
1746
- assert (L:=self.shape[-1]) == (R:=w.shape[-min(n2, 2)]), f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({L} != {R})"
1747
- x = self.reshape(*self.shape[0:-1], *[1]*min(n1-1, n2-1, 1), self.shape[-1])
1748
- w = w.reshape(*w.shape[0:-2], *[1]*min(n1-1, n2-1, 1), *w.shape[-min(n2, 2):]).transpose(-1, -min(n2, 2))
2165
+ if IMAGE: return self.image_dot(w, acc_dtype)
2166
+ x, dx, dw = self, self.ndim, w.ndim
2167
+ if not (dx > 0 and dw > 0): raise RuntimeError(f"both tensors need to be at least 1D, got {dx}D and {dw}D")
2168
+ if x.shape[-1] != w.shape[axis_w:=-min(w.ndim,2)]: raise RuntimeError(f"cannot dot {x.shape} and {w.shape}")
2169
+ x = x.reshape(*x.shape[0:-1], *[1]*min(dx-1, dw-1, 1), x.shape[-1])
2170
+ w = w.reshape(*w.shape[0:-2], *[1]*min(dx-1, dw-1, 1), *w.shape[axis_w:]).transpose(-1, axis_w)
1749
2171
  return (x*w).sum(-1, acc_dtype=acc_dtype).cast(least_upper_dtype(x.dtype, w.dtype) if acc_dtype is None else acc_dtype)
1750
2172
 
1751
- def matmul(self, x:Tensor, reverse=False, acc_dtype:Optional[DType]=None) -> Tensor:
2173
+ def matmul(self, x:Tensor, reverse=False, acc_dtype:Optional[DTypeLike]=None) -> Tensor:
1752
2174
  """
1753
2175
  Performs matrix multiplication between two tensors.
1754
2176
 
@@ -1766,7 +2188,7 @@ class Tensor:
1766
2188
  def _cumsum(self, axis:int=0, _first_zero=False) -> Tensor:
1767
2189
  assert self.shape[axis] != 0
1768
2190
  pl_sz = self.shape[axis] - int(not _first_zero)
1769
- return self.transpose(axis,-1).pad2d((pl_sz,-int(_first_zero)))._pool((self.shape[axis],)).sum(-1).transpose(axis,-1)
2191
+ return self.transpose(axis,-1).pad((pl_sz,-int(_first_zero)))._pool((self.shape[axis],)).sum(-1).transpose(axis,-1)
1770
2192
  def cumsum(self, axis:int=0) -> Tensor:
1771
2193
  """
1772
2194
  Computes the cumulative sum of the tensor along the specified axis.
@@ -1786,12 +2208,11 @@ class Tensor:
1786
2208
  # TODO: someday the optimizer will find this on it's own
1787
2209
  # for now this is a two stage cumsum
1788
2210
  SPLIT = 256
1789
- if self.shape[axis] <= SPLIT*2: return self._cumsum(axis)
1790
- ret = self.transpose(axis,-1).pad2d((round_up(self.shape[axis], SPLIT)-self.shape[axis], 0))
1791
- ret = ret.unflatten(-1, (-1, SPLIT))._cumsum(-1)
2211
+ if not isinstance(s:=self.shape[axis], int) or s <= SPLIT*2: return self._cumsum(axis)
2212
+ ret = self.transpose(axis,-1).pad((round_up(s, SPLIT)-s, 0)).unflatten(-1, (-1, SPLIT))._cumsum(-1)
1792
2213
  base_add = ret[..., -1]._cumsum(-1, _first_zero=True)
1793
2214
  base_add = base_add.unsqueeze(-1).expand(*base_add.shape, ret.shape[-1])
1794
- def fix(x:Tensor): return x.flatten(start_dim=-2)[..., -self.shape[axis]:].transpose(axis,-1)
2215
+ def fix(x:Tensor): return x.flatten(start_dim=-2)[..., -s:].transpose(axis,-1)
1795
2216
  return fix(ret) + fix(base_add)
1796
2217
 
1797
2218
  @staticmethod
@@ -1850,6 +2271,38 @@ class Tensor:
1850
2271
  """
1851
2272
  return Tensor._tri(self.shape[-2], self.shape[-1], diagonal=diagonal+1, device=self.device, dtype=dtypes.bool).where(0, self).cast(self.dtype)
1852
2273
 
2274
+ def interpolate(self, size:Tuple[int, ...], mode:str="linear", align_corners:bool=False) -> Tensor:
2275
+ """
2276
+ Downsamples or upsamples to the input `size`; accepts 0 to N batch dimensions.
2277
+
2278
+ The interpolation algorithm is selected with `mode` which currently only supports `linear`, `nearest` and `nearest-exact`.
2279
+ To run `bilinear` or `trilinear`, pass in a 2D or 3D size.
2280
+
2281
+ ```python exec="true" source="above" session="tensor" result="python"
2282
+ t = Tensor([[1, 2, 3, 4], [21, 22, 23, 24], [41, 42, 43, 44]])
2283
+ print(t.numpy())
2284
+ ```
2285
+ ```python exec="true" source="above" session="tensor" result="python"
2286
+ print(t.interpolate(size=(2,3), mode="linear").numpy())
2287
+ ```
2288
+ """
2289
+ assert isinstance(size, (tuple,list)) and all_int(size) and 0 < len(size) <= self.ndim, f"invalid {size=}"
2290
+ assert mode in ("linear", "nearest", "nearest-exact"), "only supports linear, nearest or nearest-exact interpolate"
2291
+ assert not (align_corners and mode != "linear"), "align_corners option can only be set with the interpolating mode linear"
2292
+ x, expand = self, list(self.shape)
2293
+ for i in range(-1,-len(size)-1,-1):
2294
+ scale = (self.shape[i] - int(align_corners)) / (size[i] - int(align_corners))
2295
+ arr, reshape = Tensor.arange(size[i], dtype=dtypes.float32, device=self.device), [1] * self.ndim
2296
+ reshape[i] = expand[i] = size[i]
2297
+ if mode == "linear":
2298
+ index = (scale*arr if align_corners else (scale*(arr+0.5))-0.5).clip(0, self.shape[i]-1)
2299
+ low, high, perc = [y.reshape(reshape).expand(expand) for y in (index.floor(), index.ceil(), index - index.floor())]
2300
+ x = x.gather(i, low).lerp(x.gather(i, high), perc)
2301
+ else:
2302
+ index = (scale*(arr+0.5) if mode=="nearest-exact" else scale*arr).cast(dtypes.int32).reshape(reshape).expand(expand)
2303
+ x = x.gather(i, index)
2304
+ return x.cast(self.dtype)
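
`interpolate` also covers plain integer upsampling via the nearest modes; a quick sketch using the same example tensor as the docstring:

```python
from tinygrad import Tensor

t = Tensor([[1, 2, 3, 4], [21, 22, 23, 24], [41, 42, 43, 44]])
print(t.interpolate(size=(6, 8), mode="nearest").numpy())  # each source cell repeated 2x2
```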
2305
+
1853
2306
  # ***** unary ops *****
1854
2307
 
1855
2308
  def logical_not(self):
@@ -1869,7 +2322,7 @@ class Tensor:
1869
2322
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).neg().numpy())
1870
2323
  ```
1871
2324
  """
1872
- return F.Neg.apply(self) if self.dtype != dtypes.bool else self.logical_not()
2325
+ return self*-1 if self.dtype != dtypes.bool else self.logical_not()
1873
2326
  def contiguous(self):
1874
2327
  """
1875
2328
  Returns a contiguous tensor.
@@ -1946,6 +2399,20 @@ class Tensor:
1946
2399
  ```
1947
2400
  """
1948
2401
  return F.Sigmoid.apply(self.cast(least_upper_float(self.dtype)))
2402
+ def hardsigmoid(self, alpha:float=1/6, beta:float=0.5):
2403
+ """
2404
+ Applies the Hardsigmoid function element-wise.
2405
+ NOTE: default `alpha` and `beta` values are taken from torch
2406
+
2407
+ - Described: https://paperswithcode.com/method/hard-sigmoid
2408
+ - See: https://pytorch.org/docs/stable/generated/torch.nn.functional.hardsigmoid.html
2409
+
2410
+ ```python exec="true" source="above" session="tensor" result="python"
2411
+ print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).hardsigmoid().numpy())
2412
+ ```
2413
+ """
2414
+ return (alpha * self + beta).relu() - (alpha * self + beta - 1).relu()
2415
+
1949
2416
  def sqrt(self):
1950
2417
  """
1951
2418
  Computes the square root of the tensor element-wise.
@@ -1999,7 +2466,7 @@ class Tensor:
1999
2466
  Truncates the tensor element-wise.
2000
2467
 
2001
2468
  ```python exec="true" source="above" session="tensor" result="python"
2002
- print(Tensor([-3.9, -2.1, -1.5, 0.5, 1.5, 2.1, 3.9]).trunc().numpy())
2469
+ print(Tensor([-3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]).trunc().numpy())
2003
2470
  ```
2004
2471
  """
2005
2472
  return self.cast(dtypes.int32).cast(self.dtype)
@@ -2008,7 +2475,7 @@ class Tensor:
2008
2475
  Rounds the tensor element-wise towards positive infinity.
2009
2476
 
2010
2477
  ```python exec="true" source="above" session="tensor" result="python"
2011
- print(Tensor([-3.9, -2.1, -1.5, 0.5, 1.5, 2.1, 3.9]).ceil().numpy())
2478
+ print(Tensor([-3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]).ceil().numpy())
2012
2479
  ```
2013
2480
  """
2014
2481
  return (self > (b := self.trunc())).where(b+1, b)
@@ -2017,19 +2484,39 @@ class Tensor:
2017
2484
  Rounds the tensor element-wise towards negative infinity.
2018
2485
 
2019
2486
  ```python exec="true" source="above" session="tensor" result="python"
2020
- print(Tensor([-3.9, -2.1, -1.5, 0.5, 1.5, 2.1, 3.9]).floor().numpy())
2487
+ print(Tensor([-3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]).floor().numpy())
2021
2488
  ```
2022
2489
  """
2023
2490
  return (self < (b := self.trunc())).where(b-1, b)
2024
2491
  def round(self: Tensor) -> Tensor:
2025
2492
  """
2026
- Rounds the tensor element-wise.
2493
+ Rounds the tensor element-wise with rounding half to even.
2027
2494
 
2028
2495
  ```python exec="true" source="above" session="tensor" result="python"
2029
- print(Tensor([-3.9, -2.1, -1.5, 0.5, 1.5, 2.1, 3.9]).round().numpy())
2496
+ print(Tensor([-3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]).round().numpy())
2030
2497
  ```
2031
2498
  """
2032
2499
  return ((self > 0) == ((b := self.cast(dtypes.int32) / 2.0).cast(dtypes.int32) == b)).where((self - 0.5).ceil(), (self + 0.5).floor())
2500
+
2501
+ def isinf(self:Tensor, detect_positive:bool=True, detect_negative:bool=True):
2502
+ """
2503
+ Checks the tensor element-wise to return True where the element is infinity, otherwise returns False
2504
+
2505
+ ```python exec="true" source="above" session="tensor" result="python"
2506
+ print(Tensor([1, float('inf'), 2, float('-inf'), float('nan')]).isinf().numpy())
2507
+ ```
2508
+ """
2509
+ return (self == float("inf")) * detect_positive + (self == float("-inf")) * detect_negative
2510
+ def isnan(self:Tensor):
2511
+ """
2512
+ Checks the tensor element-wise to return True where the element is NaN, otherwise returns False
2513
+
2514
+ ```python exec="true" source="above" session="tensor" result="python"
2515
+ print(Tensor([1, float('inf'), 2, float('-inf'), float('nan')]).isnan().numpy())
2516
+ ```
2517
+ """
2518
+ return self != self
2519
+
2033
2520
  def lerp(self, end: Tensor, weight: Union[Tensor, float]) -> Tensor:
2034
2521
  """
2035
2522
  Linearly interpolates between `self` and `end` by `weight`.
@@ -2038,7 +2525,11 @@ class Tensor:
2038
2525
  print(Tensor([1., 2., 3.]).lerp(Tensor([4., 5., 6.]), 0.5).numpy())
2039
2526
  ```
2040
2527
  """
2528
+ if self.dtype == dtypes.uint8 and isinstance(weight, Tensor):
2529
+ w_i = (weight * (1<<(W_PREC:=7)) + 0.5).cast(dtypes.int16)
2530
+ return (self+(((end - self).cast(dtypes.int8) * w_i + (1<<W_PREC-1)).cast(dtypes.uint16) >> W_PREC)).cast(dtypes.uint8)
2041
2531
  return self + (end - self) * weight
2532
+
2042
2533
  def square(self):
2043
2534
  """
2044
2535
  Squares the tensor element-wise.
@@ -2049,15 +2540,23 @@ class Tensor:
2049
2540
  ```
2050
2541
  """
2051
2542
  return self*self
2052
- def clip(self, min_, max_):
2543
+ def clamp(self, min_=None, max_=None):
2053
2544
  """
2054
2545
  Clips (clamps) the values in the tensor between `min_` and `max_` element-wise.
2546
+ If `min_` is `None`, there is no lower bound. If `max_` is `None`, there is no upper bound.
2055
2547
 
2056
2548
  ```python exec="true" source="above" session="tensor" result="python"
2057
2549
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).clip(-1, 1).numpy())
2058
2550
  ```
2059
2551
  """
2060
- return self.maximum(min_).minimum(max_)
2552
+ if min_ is None and max_ is None: raise RuntimeError("at least one of 'min_' or 'max_' must not be None")
2553
+ ret = self.maximum(min_) if min_ is not None else self
2554
+ return ret.minimum(max_) if max_ is not None else ret
2555
+ def clip(self, min_=None, max_=None):
2556
+ """
2557
+ Alias for `Tensor.clamp`.
2558
+ """
2559
+ return self.clamp(min_, max_)
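
Since `clamp` now accepts `None` bounds, one-sided clipping works directly; a brief sketch (values arbitrary):

```python
from tinygrad import Tensor

t = Tensor([-3.0, -1.0, 0.0, 2.0, 5.0])
print(t.clamp(min_=0).numpy())    # lower bound only -> [0. 0. 0. 2. 5.]
print(t.clamp(max_=1.5).numpy())  # upper bound only -> [-3. -1. 0. 1.5 1.5]
```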
2061
2560
  def sign(self):
2062
2561
  """
2063
2562
  Returns the sign of the tensor element-wise.
@@ -2249,6 +2748,20 @@ class Tensor:
2249
2748
  """
2250
2749
  return self.clip(min_val, max_val)
2251
2750
 
2751
+ def erf(self):
2752
+ """
2753
+ Applies the error function element-wise.
2754
+
2755
+ - Described: https://en.wikipedia.org/wiki/Error_function
2756
+
2757
+ ```python exec="true" source="above" session="tensor" result="python"
2758
+ print(Tensor([-1.5, -1.0, -0.5, 0., 0.5, 1.0, 1.5]).erf().numpy())
2759
+ ```
2760
+ """
2761
+ # https://personal.math.ubc.ca/~cbm/aands/page_299.htm 7.1.26
2762
+ t = 1.0 / (1.0 + 0.3275911 * self.abs())
2763
+ return self.sign() * (1.0 - t * polyN(t, [1.061405429, -1.453152027, 1.421413741, -0.284496736, 0.254829592]) * (-self.square()).exp())
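
For reference, the polynomial above is Abramowitz & Stegun 7.1.26 evaluated with Horner's rule (`polyN` takes coefficients from the highest power down), accurate to roughly 1.5e-7:

$$\operatorname{erf}(x) \approx \operatorname{sign}(x)\Bigl[1 - \bigl(a_1 t + a_2 t^2 + a_3 t^3 + a_4 t^4 + a_5 t^5\bigr)\,e^{-x^2}\Bigr],\qquad t = \frac{1}{1 + p\,\lvert x\rvert}$$

with p = 0.3275911, a1 = 0.254829592, a2 = -0.284496736, a3 = 1.421413741, a4 = -1.453152027, a5 = 1.061405429.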
2764
+
2252
2765
  def gelu(self):
2253
2766
  """
2254
2767
  Applies the Gaussian Error Linear Unit (GELU) function element-wise.
@@ -2333,17 +2846,18 @@ class Tensor:
2333
2846
  # first pad left with 1s https://data-apis.org/array-api/latest/API_specification/broadcasting.html
2334
2847
  padded, _ = _pad_left(self.shape, shape)
2335
2848
  # for each dimension, check either from_ is 1, or it does not change
2336
- if any(from_ != 1 and from_ != to for from_,to in zip(padded, shape)): raise ValueError(f"cannot broadcast from shape={self.shape} to {shape=}")
2849
+ if any(resolve(from_ != 1, False) and resolve(from_ != to, False) for from_,to in zip(padded, shape)):
2850
+ raise ValueError(f"cannot broadcast from shape={self.shape} to {shape=}")
2337
2851
  return F.Expand.apply(self.reshape(padded), shape=shape)
2338
2852
 
2339
- def _broadcasted(self, y:Union[Tensor, Node, ConstType], reverse:bool=False, match_dtype:bool=True) -> Tuple[Tensor, Tensor]:
2853
+ def _broadcasted(self, y:Union[Tensor, UOp, ConstType], reverse:bool=False, match_dtype:bool=True) -> Tuple[Tensor, Tensor]:
2340
2854
  x: Tensor = self
2341
2855
  if not isinstance(y, Tensor):
2342
2856
  # make y a Tensor
2343
- assert isinstance(y, (float, int, bool, Node)), f"{type(y)=}, {y=}"
2857
+ assert isinstance(y, (*get_args(ConstType), UOp)), f"{type(y)=}, {y=}"
2344
2858
  if isinstance(x.dtype, ImageDType) or dtypes.is_float(x.dtype) or (dtypes.is_int(x.dtype) and isinstance(y, int)): y_dtype = x.dtype
2345
- elif not isinstance(y, Node): y_dtype = dtypes.from_py(y)
2346
- if isinstance(y, Node): y = Tensor.from_node(y, device=x.device)
2859
+ elif not isinstance(y, UOp): y_dtype = dtypes.from_py(y)
2860
+ if isinstance(y, UOp): y = Tensor.from_uop(y, device=x.device)
2347
2861
  else: y = Tensor(dtypes.as_const(y, y_dtype), x.device, y_dtype, requires_grad=False)
2348
2862
 
2349
2863
  if match_dtype and x.dtype != y.dtype:
@@ -2421,12 +2935,25 @@ class Tensor:
2421
2935
  """
2422
2936
  return F.Mul.apply(*self._broadcasted(x, reverse))
2423
2937
 
2424
- def div(self, x:Union[Tensor, ConstType], reverse=False, upcast=True) -> Tensor:
2938
+ def idiv(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
2939
+ """
2940
+ Divides `self` by `x`.
2941
+ Equivalent to `self // x`.
2942
+ Supports broadcasting to a common shape, type promotion, and integer inputs.
2943
+ `idiv` performs integer division.
2944
+
2945
+ ```python exec="true" source="above" session="tensor" result="python"
2946
+ print(Tensor([1, 4, 10]).idiv(Tensor([2, 3, 4])).numpy())
2947
+ ```
2948
+ """
2949
+ return F.IDiv.apply(*self._broadcasted(x, reverse))
2950
+
2951
+ def div(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
2425
2952
  """
2426
2953
  Divides `self` by `x`.
2427
2954
  Equivalent to `self / x`.
2428
2955
  Supports broadcasting to a common shape, type promotion, and integer, float, boolean inputs.
2429
- By default, `div` performs true division. Set `upcast` to `False` for integer division.
2956
+ `div` performs true division.
2430
2957
 
2431
2958
  ```python exec="true" source="above" session="tensor" result="python"
2432
2959
  Tensor.manual_seed(42)
@@ -2439,13 +2966,9 @@ class Tensor:
2439
2966
  ```python exec="true" source="above" session="tensor" result="python"
2440
2967
  print(Tensor([1, 4, 10]).div(Tensor([2, 3, 4])).numpy())
2441
2968
  ```
2442
- ```python exec="true" source="above" session="tensor" result="python"
2443
- print(Tensor([1, 4, 10]).div(Tensor([2, 3, 4]), upcast=False).numpy())
2444
- ```
2445
2969
  """
2446
2970
  numerator, denominator = self._broadcasted(x, reverse)
2447
- if upcast: numerator, denominator = numerator.cast(least_upper_float(numerator.dtype)), denominator.cast(least_upper_float(denominator.dtype))
2448
- return F.Div.apply(numerator, denominator)
2971
+ return numerator.cast(least_upper_float(numerator.dtype)) * denominator.cast(least_upper_float(denominator.dtype)).reciprocal()
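
The old `upcast=False` flag is gone: `div` always upcasts to a float and multiplies by the reciprocal, while integer division moved to the separate `idiv` above. A brief sketch of the two paths:

```python
from tinygrad import Tensor

a, b = Tensor([1, 4, 10]), Tensor([2, 3, 4])
print(a.div(b).numpy())   # true division -> floats [0.5, 1.33..., 2.5]
print(a.idiv(b).numpy())  # integer division (old div(..., upcast=False)) -> [0 1 2]
```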
2449
2972
 
2450
2973
  def xor(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
2451
2974
  """
@@ -2460,8 +2983,53 @@ class Tensor:
2460
2983
  print(Tensor([True, True, False, False]).xor(Tensor([True, False, True, False])).numpy())
2461
2984
  ```
2462
2985
  """
2986
+ if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
2463
2987
  return F.Xor.apply(*self._broadcasted(x, reverse))
2464
2988
 
2989
+ def bitwise_and(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
2990
+ """
2991
+ Compute the bit-wise AND of `self` and `x`.
2992
+ Equivalent to `self & x`.
2993
+ Supports broadcasting to a common shape, type promotion, and integer, boolean inputs.
2994
+ ```python exec="true" source="above" session="tensor" result="python"
2995
+ print(Tensor([2, 5, 255]).bitwise_and(Tensor([3, 14, 16])).numpy())
2996
+ ```
2997
+ ```python exec="true" source="above" session="tensor" result="python"
2998
+ print(Tensor([True, True, False, False]).bitwise_and(Tensor([True, False, True, False])).numpy())
2999
+ ```
3000
+ """
3001
+ if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
3002
+ return F.BitwiseAnd.apply(*self._broadcasted(x, reverse))
3003
+
3004
+ def bitwise_or(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
3005
+ """
3006
+ Compute the bit-wise OR of `self` and `x`.
3007
+ Equivalent to `self | x`.
3008
+ Supports broadcasting to a common shape, type promotion, and integer, boolean inputs.
3009
+ ```python exec="true" source="above" session="tensor" result="python"
3010
+ print(Tensor([2, 5, 255]).bitwise_or(Tensor([4, 4, 4])).numpy())
3011
+ ```
3012
+ ```python exec="true" source="above" session="tensor" result="python"
3013
+ print(Tensor([True, True, False, False]).bitwise_or(Tensor([True, False, True, False])).numpy())
3014
+ ```
3015
+ """
3016
+ if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
3017
+ return F.BitwiseOr.apply(*self._broadcasted(x, reverse))
3018
+
3019
+ def bitwise_not(self) -> Tensor:
3020
+ """
3021
+ Compute the bit-wise NOT of `self`.
3022
+ Equivalent to `~self`.
3023
+ ```python exec="true" source="above" session="tensor" result="python"
3024
+ print(Tensor([0, 2, 5, 255], dtype="int8").bitwise_not().numpy())
3025
+ ```
3026
+ ```python exec="true" source="above" session="tensor" result="python"
3027
+ print(Tensor([True, False]).bitwise_not().numpy())
3028
+ ```
3029
+ """
3030
+ if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
3031
+ return self.logical_not() if self.dtype == dtypes.bool else self ^ ((1<<8*self.dtype.itemsize)-1)
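Editor's note: the integer path of `bitwise_not` is just an XOR against an all-ones mask of the dtype's width (`8*itemsize` bits), while bools fall back to `logical_not`. A small check of the new bitwise trio, assuming the API added in this hunk:
```python
from tinygrad import Tensor, dtypes

a = Tensor([2, 5, 255], dtype=dtypes.uint8)
print(a.bitwise_and(Tensor([3, 14, 16], dtype=dtypes.uint8)).numpy())  # 2, 4, 16
print(a.bitwise_or(Tensor([4, 4, 4], dtype=dtypes.uint8)).numpy())     # 6, 5, 255
print(a.bitwise_not().numpy())                      # XOR with 0xff: 253, 250, 0
print(Tensor([True, False]).bitwise_not().numpy())  # bool path uses logical_not: False, True
```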
3032
+
2465
3033
  def lshift(self, x:int):
2466
3034
  """
2467
3035
  Computes left arithmetic shift of `self` by `x` bits. `self` must have an unsigned dtype.
@@ -2484,7 +3052,7 @@ class Tensor:
2484
3052
  ```
2485
3053
  """
2486
3054
  assert dtypes.is_unsigned(self.dtype) and isinstance(x, int) and x >= 0, f"not supported {self.dtype=} {x=}"
2487
- return self.div(2 ** x, upcast=False)
3055
+ return self.idiv(2 ** x)
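Editor's note: the replaced body here belongs to `rshift` (the hunk header above skips past `lshift`'s docstring): right-shifting an unsigned tensor by `x` bits is now expressed as integer division by `2**x` instead of `div(..., upcast=False)`. A quick sanity check, assuming the shift operators wired up below:
```python
from tinygrad import Tensor, dtypes

t = Tensor([8, 9, 32], dtype=dtypes.uint8)
print((t >> 2).numpy())        # 2, 2, 8
print(t.idiv(2 ** 2).numpy())  # identical: shifting right by x equals integer division by 2**x
```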
2488
3056
 
2489
3057
  def pow(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
2490
3058
  """
@@ -2578,42 +3146,35 @@ class Tensor:
2578
3146
 
2579
3147
  # ***** op wrappers *****
2580
3148
 
2581
- def __neg__(self) -> Tensor: return self.neg()
3149
+ def __invert__(self) -> Tensor: return self.bitwise_not()
2582
3150
 
2583
- def __add__(self, x) -> Tensor: return self.add(x)
2584
- def __sub__(self, x) -> Tensor: return self.sub(x)
2585
- def __mul__(self, x) -> Tensor: return self.mul(x)
2586
- def __pow__(self, x) -> Tensor: return self.pow(x)
2587
- def __truediv__(self, x) -> Tensor: return self.div(x)
2588
- def __matmul__(self, x) -> Tensor: return self.matmul(x)
2589
- def __xor__(self, x) -> Tensor: return self.xor(x)
2590
3151
  def __lshift__(self, x) -> Tensor: return self.lshift(x)
2591
3152
  def __rshift__(self, x) -> Tensor: return self.rshift(x)
2592
3153
 
2593
- def __radd__(self, x) -> Tensor: return self.add(x, True)
2594
- def __rsub__(self, x) -> Tensor: return self.sub(x, True)
2595
- def __rmul__(self, x) -> Tensor: return self.mul(x, True)
3154
+ def __pow__(self, x) -> Tensor: return self.pow(x)
3155
+ def __matmul__(self, x) -> Tensor: return self.matmul(x)
3156
+
2596
3157
  def __rpow__(self, x) -> Tensor: return self.pow(x, True)
2597
- def __rtruediv__(self, x) -> Tensor: return self.div(x, True)
2598
3158
  def __rmatmul__(self, x) -> Tensor: return self.matmul(x, True)
2599
- def __rxor__(self, x) -> Tensor: return self.xor(x, True)
2600
3159
 
2601
3160
  def __iadd__(self, x) -> Tensor: return self.assign(self.add(x))
2602
3161
  def __isub__(self, x) -> Tensor: return self.assign(self.sub(x))
2603
3162
  def __imul__(self, x) -> Tensor: return self.assign(self.mul(x))
2604
3163
  def __ipow__(self, x) -> Tensor: return self.assign(self.pow(x))
2605
3164
  def __itruediv__(self, x) -> Tensor: return self.assign(self.div(x))
3165
+ def __ifloordiv__(self, x) -> Tensor: return self.assign(self.idiv(x))
2606
3166
  def __imatmul__(self, x) -> Tensor: return self.assign(self.matmul(x))
3167
+ def __iand__(self, x) -> Tensor: return self.assign(self.bitwise_and(x))
3168
+ def __ior__(self, x) -> Tensor: return self.assign(self.bitwise_or(x))
2607
3169
  def __ixor__(self, x) -> Tensor: return self.assign(self.xor(x))
2608
3170
  def __ilshift__(self, x) -> Tensor: return self.assign(self.lshift(x))
2609
3171
  def __irshift__(self, x) -> Tensor: return self.assign(self.rshift(x))
2610
3172
 
2611
- def __lt__(self, x) -> Tensor: return F.Less.apply(*self._broadcasted(x, False))
2612
- def __gt__(self, x) -> Tensor: return F.Less.apply(*self._broadcasted(x, True))
2613
- def __ge__(self, x) -> Tensor: return (self<x).logical_not()
2614
- def __le__(self, x) -> Tensor: return (self>x).logical_not()
2615
- def __ne__(self, x) -> Tensor: return F.Neq.apply(*self._broadcasted(x)) # type: ignore[override]
2616
- def __eq__(self, x) -> Tensor: return (self!=x).logical_not() # type: ignore[override]
3173
+ def lt(self, x) -> Tensor: return F.Less.apply(*self._broadcasted(x, False))
3174
+ def gt(self, x) -> Tensor: return F.Less.apply(*self._broadcasted(x, True))
3175
+ def ne(self, x) -> Tensor: return F.Neq.apply(*self._broadcasted(x))
3176
+
3177
+ def __eq__(self, x) -> Tensor: return self.eq(x) # type: ignore[override]
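Editor's note: the generic arithmetic dunders (`__add__`, `__truediv__`, ...) leave this block in 0.10.0 (they are presumably defined earlier in the file), and the comparisons become named `lt`/`gt`/`ne` methods, with `__eq__` delegating to `eq`. Note that `==` stays elementwise — a sketch, assuming only the methods visible in this hunk:
```python
from tinygrad import Tensor

a, b = Tensor([1, 2, 3]), Tensor([3, 2, 1])
print(a.lt(b).numpy())   # True, False, False
print(a.ne(b).numpy())   # True, False, True
print((a == b).numpy())  # elementwise bool Tensor, not a Python bool: False, True, False
```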
2617
3178
 
2618
3179
  # ***** functional nn ops *****
2619
3180
 
@@ -2644,7 +3205,7 @@ class Tensor:
2644
3205
  """
2645
3206
  return functools.reduce(lambda x,f: f(x), ll, self)
2646
3207
 
2647
- def layernorm(self, axis=-1, eps:float=1e-5) -> Tensor:
3208
+ def layernorm(self, axis:Union[int,Tuple[int,...]]=-1, eps:float=1e-5) -> Tensor:
2648
3209
  """
2649
3210
  Applies Layer Normalization over a mini-batch of inputs.
2650
3211
 
@@ -2703,17 +3264,20 @@ class Tensor:
2703
3264
  ```
2704
3265
  """
2705
3266
  if not Tensor.training or p == 0: return self
2706
- return self * (Tensor.rand(*self.shape, requires_grad=False, dtype=dtypes.default_float, device=self.device) >= p) * (1/(1.0 - p))
3267
+ return (Tensor.rand_like(self, requires_grad=False, dtype=dtypes.default_float, contiguous=False) >= p).contiguous().where(self, 0) / (1.0 - p)
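Editor's note: the new dropout builds a keep-mask from `rand_like` and rescales survivors by `1/(1-p)` (inverted dropout), so the output's expected value matches the input. A sketch, remembering that dropout is a no-op unless `Tensor.training` is set:
```python
from tinygrad import Tensor

Tensor.manual_seed(42)
Tensor.training = True
out = Tensor.ones(4, 4).dropout(p=0.5)
print(out.numpy())       # entries are either 0.0 or 1/(1-p) = 2.0
Tensor.training = False
```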
2707
3268
 
2708
- def one_hot(self, num_classes:int) -> Tensor:
3269
+ def one_hot(self, num_classes:int=-1) -> Tensor:
2709
3270
  """
2710
3271
  Converts `self` to a one-hot tensor.
2711
3272
 
3273
+ `num_classes` defaults to -1, in which case it is inferred as max(self) + 1.
3274
+
2712
3275
  ```python exec="true" source="above" session="tensor" result="python"
2713
3276
  t = Tensor([0, 1, 3, 3, 4])
2714
3277
  print(t.one_hot(5).numpy())
2715
3278
  ```
2716
3279
  """
3280
+ if num_classes == -1: num_classes = (self.max()+1).item()
2717
3281
  return (self[..., None] == Tensor.arange(num_classes, requires_grad=False, device=self.device)).where(1, 0)
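Editor's note: with the new default `num_classes=-1`, the class count is read off the data via `(self.max()+1).item()`, which forces a realize of `self`. For example:
```python
from tinygrad import Tensor

t = Tensor([0, 1, 3, 3, 4])
print(t.one_hot(5).numpy())  # explicit class count
print(t.one_hot().numpy())   # inferred as max(t)+1 == 5, same 5x5 result
```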
2718
3282
 
2719
3283
  def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Optional[Tensor]=None,
@@ -2739,39 +3303,45 @@ class Tensor:
2739
3303
  qk = self.matmul(key.transpose(-2,-1), acc_dtype=least_upper_dtype(self.dtype, key.dtype, dtypes.float32)) / math.sqrt(self.shape[-1])
2740
3304
  return ((qk+attn_mask) if attn_mask is not None else qk).softmax(-1).cast(self.dtype).dropout(dropout_p) @ value
2741
3305
 
2742
- def binary_crossentropy(self, y:Tensor) -> Tensor:
3306
+ def _do_reduction(self, reduction:ReductionStr="mean") -> Tensor:
3307
+ if reduction not in get_args(ReductionStr): raise ValueError(f"{reduction=} must be one of {get_args(ReductionStr)}")
3308
+ reductions: Dict[str, Callable[[Tensor], Tensor]] = {"mean": Tensor.mean, "sum": Tensor.sum, "none": lambda x: x}
3309
+ return reductions[reduction](self)
3310
+
3311
+ def binary_crossentropy(self, Y:Tensor, reduction:ReductionStr="mean") -> Tensor:
2743
3312
  """
2744
- Computes the binary cross-entropy loss between `self` and `y`.
3313
+ Computes the binary cross-entropy loss between `self` and `Y`.
2745
3314
 
2746
3315
  See: https://pytorch.org/docs/stable/generated/torch.nn.BCELoss.html
2747
3316
 
2748
3317
  ```python exec="true" source="above" session="tensor" result="python"
2749
3318
  t = Tensor([0.1, 0.9, 0.2])
2750
- y = Tensor([0, 1, 0])
2751
- print(t.binary_crossentropy(y).item())
3319
+ Y = Tensor([0, 1, 0])
3320
+ print(t.binary_crossentropy(Y).item())
2752
3321
  ```
2753
3322
  """
2754
- return (-y*self.log() - (1-y)*(1-self).log()).mean()
3323
+ return (-Y*self.log() - (1-Y)*(1-self).log())._do_reduction(reduction)
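Editor's note: the losses now share the `_do_reduction` helper, so `reduction` accepts `"mean"` (default), `"sum"`, or `"none"`. For instance, with `binary_crossentropy` as defined in this hunk:
```python
from tinygrad import Tensor

p, Y = Tensor([0.1, 0.9, 0.2]), Tensor([0, 1, 0])
print(p.binary_crossentropy(Y).item())                     # mean over elements (default)
print(p.binary_crossentropy(Y, reduction="sum").item())
print(p.binary_crossentropy(Y, reduction="none").numpy())  # per-element losses, same shape as Y
```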
2755
3324
 
2756
- def binary_crossentropy_logits(self, y:Tensor) -> Tensor:
3325
+ def binary_crossentropy_logits(self, Y:Tensor, reduction:ReductionStr="mean") -> Tensor:
2757
3326
  """
2758
- Computes the binary cross-entropy loss between `self` and `y` where `self` is logits.
3327
+ Computes the binary cross-entropy loss between `self` and `Y` where `self` is logits.
2759
3328
 
2760
3329
  See: https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html
2761
3330
 
2762
3331
  ```python exec="true" source="above" session="tensor" result="python"
2763
3332
  t = Tensor([-1, 2, -3])
2764
- y = Tensor([0, 1, 0])
2765
- print(t.binary_crossentropy_logits(y).item())
3333
+ Y = Tensor([0, 1, 0])
3334
+ print(t.binary_crossentropy_logits(Y).item())
2766
3335
  ```
2767
3336
  """
2768
- return (self.maximum(0) - y * self + (1 + self.abs().neg().exp()).log()).mean()
3337
+ return (self.maximum(0) - Y * self + (1 + self.abs().neg().exp()).log())._do_reduction(reduction)
2769
3338
 
2770
- def sparse_categorical_crossentropy(self, Y:Tensor, ignore_index=-1, label_smoothing=0.0) -> Tensor:
3339
+ def sparse_categorical_crossentropy(self, Y:Tensor, ignore_index:int=-1, label_smoothing=0.0, reduction:ReductionStr="mean") -> Tensor:
2771
3340
  """
2772
3341
  Computes the sparse categorical cross-entropy loss between `self` and `Y`.
2773
3342
 
2774
3343
  NOTE: `self` is logits and `Y` is the target labels.
3344
+ NOTE: unlike PyTorch, this function expects the class axis to be -1.
2775
3345
 
2776
3346
  See: https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
2777
3347
 
@@ -2782,19 +3352,145 @@ class Tensor:
2782
3352
  ```
2783
3353
  """
2784
3354
  assert 0.0 <= label_smoothing <= 1.0, "label_smoothing must be in [0.0, 1.0]"
2785
- log_probs, loss_mask = self.log_softmax(), (Y != ignore_index)
3355
+ assert reduction in ("mean", "sum", "none"), "reduction must be one of ['mean', 'sum', 'none']"
3356
+ log_probs, loss_mask = self.log_softmax(), (Y != ignore_index) if ignore_index != -1 else Y.ones_like(dtype=dtypes.bool)
2786
3357
  y_counter = Tensor.arange(self.shape[-1], requires_grad=False, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1])
2787
3358
  y = ((y_counter == Y.flatten().reshape(-1, 1)) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
2788
- smoothing = label_smoothing * (log_probs.mean(-1) * loss_mask).sum()
2789
- return -((1 - label_smoothing) * (log_probs * y).sum() + smoothing) / loss_mask.sum()
3359
+ smoothing = label_smoothing * (log_probs.mean(-1) * loss_mask)
3360
+ unreduced = ((1 - label_smoothing) * (log_probs * y).sum(-1) + smoothing)
3361
+ # NOTE: because of ignore_index, we can't use Tensor.mean (so can't use `_do_reduction` here)
3362
+ return -(unreduced.sum() / loss_mask.sum() if reduction == "mean" else (unreduced.sum() if reduction == "sum" else unreduced))
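Editor's note: `sparse_categorical_crossentropy` now also takes `reduction`, but because `ignore_index` removes entries from the mean's denominator it reduces by hand rather than through `_do_reduction`. A short example of the new keywords, assuming the signature in this hunk:
```python
from tinygrad import Tensor

logits = Tensor([[-1.0, 2.0, -3.0], [1.0, -2.0, 3.0]])
labels = Tensor([1, 2])
print(logits.sparse_categorical_crossentropy(labels).item())                       # mean (default)
print(logits.sparse_categorical_crossentropy(labels, reduction="none").numpy())    # one loss per sample
print(logits.sparse_categorical_crossentropy(labels, label_smoothing=0.1).item())  # smoothed targets
```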
3363
+
3364
+ def cross_entropy(self, Y:Tensor, reduction:ReductionStr="mean", label_smoothing:float=0.0) -> Tensor:
3365
+ """
3366
+ Compute the cross entropy loss between input logits and target.
3367
+
3368
+ NOTE: `self` is logits and `Y` is the target labels or class probabilities.
3369
+
3370
+ See: https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html
3371
+
3372
+ ```python exec="true" source="above" session="tensor" result="python"
3373
+ t = Tensor([[-1, 2, -3], [1, -2, 3]])
3374
+ Y = Tensor([1, 2])
3375
+ print(t.cross_entropy(Y).item())
3376
+ ```
3377
+ ```python exec="true" source="above" session="tensor" result="python"
3378
+ t = Tensor([[-1, 2, -3], [1, -2, 3]])
3379
+ Y = Tensor([1, 2])
3380
+ print(t.cross_entropy(Y, reduction='none').numpy())
3381
+ ```
3382
+ """
3383
+ assert 0.0 <= label_smoothing <= 1.0, "label_smoothing must be in [0.0, 1.0]"
3384
+ Y = Y.one_hot(num_classes=cast(int, self.shape[1])) if Y.ndim < 2 else Y
3385
+ Y = (1 - label_smoothing)*Y + label_smoothing / cast(int, Y.shape[1])
3386
+ ret = -self.log_softmax(axis=1).mul(Y).sum(axis=1)
3387
+ return ret._do_reduction(reduction)
3388
+
3389
+ def nll_loss(self, Y:Tensor, weight:Optional[Tensor]=None, ignore_index:Optional[int]=None, reduction:ReductionStr="mean") -> Tensor:
3390
+ """
3391
+ Compute the negative log likelihood loss between log-probabilities and target labels.
3392
+
3393
+ NOTE: `self` is log-probabilities and `Y` is the target labels or class probabilities.
3394
+
3395
+ See: https://pytorch.org/docs/stable/generated/torch.nn.functional.nll_loss.html
3396
+
3397
+ ```python exec="true" source="above" session="tensor" result="python"
3398
+ t = Tensor([[-1, 2, -3], [1, -2, 3]])
3399
+ Y = Tensor([1, 2])
3400
+ print(t.log_softmax().nll_loss(Y).item())
3401
+ ```
3402
+ ```python exec="true" source="above" session="tensor" result="python"
3403
+ t = Tensor([[-1, 2, -3], [1, -2, 3]])
3404
+ Y = Tensor([1, 2])
3405
+ print(t.log_softmax().nll_loss(Y, reduction='none').numpy())
3406
+ ```
3407
+ """
3408
+ weight = Tensor.ones_like(Y, requires_grad=False) if weight is None else weight[Y]
3409
+ masked_weight = weight if ignore_index is None else weight * (Y != ignore_index)
3410
+ nll = -self.gather(1, Y.unsqueeze(1)).squeeze(1) * masked_weight
3411
+ return nll.sum() / masked_weight.sum() if reduction == "mean" else nll._do_reduction(reduction)
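Editor's note: the two new losses compose — `cross_entropy` on logits is `log_softmax` over the class axis followed by `nll_loss`, matching the PyTorch convention of class axis 1. A quick equivalence check, assuming the signatures added above:
```python
from tinygrad import Tensor

t = Tensor([[-1.0, 2.0, -3.0], [1.0, -2.0, 3.0]])
Y = Tensor([1, 2])
print(t.cross_entropy(Y).item())                 # one-hots Y internally, class axis is 1
print(t.log_softmax(axis=1).nll_loss(Y).item())  # same value via the log-prob + NLL decomposition
```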
3412
+
3413
+ # ***** Tensor Properties *****
3414
+
3415
+ @property
3416
+ def ndim(self) -> int:
3417
+ """
3418
+ Returns the number of dimensions in the tensor.
3419
+
3420
+ ```python exec="true" source="above" session="tensor" result="python"
3421
+ t = Tensor([[1, 2], [3, 4]])
3422
+ print(t.ndim)
3423
+ ```
3424
+ """
3425
+ return len(self.shape)
3426
+
3427
+ def numel(self) -> sint:
3428
+ """
3429
+ Returns the total number of elements in the tensor.
3430
+
3431
+ ```python exec="true" source="above" session="tensor" result="python"
3432
+ t = Tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
3433
+ print(t.numel())
3434
+ ```
3435
+ """
3436
+ return prod(self.shape)
3437
+
3438
+ def element_size(self) -> int:
3439
+ """
3440
+ Returns the size in bytes of an individual element in the tensor.
3441
+
3442
+ ```python exec="true" source="above" session="tensor" result="python"
3443
+ t = Tensor([5], dtype=dtypes.int16)
3444
+ print(t.element_size())
3445
+ ```
3446
+ """
3447
+ return self.dtype.itemsize
3448
+
3449
+ def nbytes(self) -> int:
3450
+ """
3451
+ Returns the total number of bytes of all elements in the tensor.
3452
+
3453
+ ```python exec="true" source="above" session="tensor" result="python"
3454
+ t = Tensor([8, 9], dtype=dtypes.float)
3455
+ print(t.nbytes())
3456
+ ```
3457
+ """
3458
+ return self.numel() * self.element_size()
3459
+
3460
+ def is_floating_point(self) -> bool:
3461
+ """
3462
+ Returns `True` if the tensor's dtype is a floating point type, i.e. one of `dtypes.float64`, `dtypes.float32`,
3463
+ `dtypes.float16`, `dtypes.bfloat16`.
3464
+
3465
+ ```python exec="true" source="above" session="tensor" result="python"
3466
+ t = Tensor([8, 9], dtype=dtypes.float32)
3467
+ print(t.is_floating_point())
3468
+ ```
3469
+ """
3470
+ return dtypes.is_float(self.dtype)
3471
+
3472
+ def size(self, dim:Optional[int]=None) -> Union[sint, Tuple[sint, ...]]:
3473
+ """
3474
+ Returns the size of the tensor. If `dim` is specified, returns the length along dimension `dim`; otherwise returns the shape of the tensor.
3475
+
3476
+ ```python exec="true" source="above" session="tensor" result="python"
3477
+ t = Tensor([[4, 5, 6], [7, 8, 9]])
3478
+ print(t.size())
3479
+ ```
3480
+ ```python exec="true" source="above" session="tensor" result="python"
3481
+ print(t.size(dim=1))
3482
+ ```
3483
+ """
3484
+ return self.shape if dim is None else self.shape[dim]
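Editor's note: the former one-line convenience helpers are promoted to documented methods here; they all derive from `shape` and `dtype`. A compact tour, assuming the properties listed in this hunk:
```python
from tinygrad import Tensor, dtypes

t = Tensor.zeros(2, 3, dtype=dtypes.float16)
print(t.ndim, t.numel(), t.element_size(), t.nbytes())  # 2 6 2 12
print(t.size(), t.size(dim=1))                          # (2, 3) 3
print(t.is_floating_point())                            # True
```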
2790
3485
 
2791
3486
  # ***** cast ops *****
2792
3487
 
2793
- def llvm_bf16_cast(self, dtype:DType):
3488
+ def llvm_bf16_cast(self, dtype:DTypeLike):
2794
3489
  # hack for devices that don't support bfloat16
2795
3490
  assert self.dtype == dtypes.bfloat16
2796
3491
  return self.to("LLVM").bitcast(dtypes.uint16).cast(dtypes.uint32).mul(1<<16).bitcast(dtypes.float32).cast(dtype)
2797
- def cast(self, dtype:DType) -> Tensor:
3492
+
3493
+ def cast(self, dtype:DTypeLike) -> Tensor:
2798
3494
  """
2799
3495
  Casts `self` to the given `dtype`.
2800
3496
 
@@ -2807,8 +3503,9 @@ class Tensor:
2807
3503
  print(t.dtype, t.numpy())
2808
3504
  ```
2809
3505
  """
2810
- return self if self.dtype == dtype else F.Cast.apply(self, dtype=dtype)
2811
- def bitcast(self, dtype:DType) -> Tensor:
3506
+ return self if self.dtype == (dt:=to_dtype(dtype)) else F.Cast.apply(self, dtype=dt)
3507
+
3508
+ def bitcast(self, dtype:DTypeLike) -> Tensor:
2812
3509
  """
2813
3510
  Bitcasts `self` to the given `dtype` of the same itemsize.
2814
3511
 
@@ -2824,7 +3521,15 @@ class Tensor:
2824
3521
  ```
2825
3522
  """
2826
3523
  if self.requires_grad: raise RuntimeError("can't backprop through bitcast")
2827
- return F.Cast.apply(self, dtype=dtype, bitcast=True) if self.dtype != dtype else self
3524
+ dt = to_dtype(dtype)
3525
+ if (not isinstance(self.device, str) or not self.device.startswith("DISK")) and (ns:=dt.itemsize) != (os:=self.dtype.itemsize):
3526
+ if (self.shape[-1]*os) % ns != 0: raise RuntimeError("unsupported size in bitcast")
3527
+ new_uint, old_uint = to_dtype(f"uint{8*ns}"), to_dtype(f"uint{8*os}")
3528
+ tmp = self.bitcast(old_uint)
3529
+ if ns > os: return functools.reduce(Tensor.add, (tmp[..., i::ns//os].cast(new_uint) << 8*i*os for i in range(ns//os))).bitcast(dtype)
3530
+ return Tensor.stack(*(tmp>>8*i*ns for i in range(os//ns)), dim=-1).flatten(-2).cast(new_uint).bitcast(dtype)
3531
+ return F.Cast.apply(self, dtype=dt, bitcast=True) if self.dtype != dt else self
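Editor's note: `bitcast` can now change the itemsize (outside of DISK devices): the tensor is viewed as unsigned lanes, the last axis grows or shrinks by the size ratio, and widening recombines lanes with shifts. A round-trip sketch, assuming the behavior shown in this hunk:
```python
from tinygrad import Tensor, dtypes

t = Tensor([1.0], dtype=dtypes.float32)
print(t.bitcast(dtypes.uint32).numpy())  # same itemsize: 1065353216 == 0x3f800000
print(t.bitcast(dtypes.uint8).numpy())   # 4x smaller itemsize, low byte first: 0, 0, 128, 63
print(t.bitcast(dtypes.uint8).bitcast(dtypes.float32).numpy())  # widening recombines lanes: 1.0
```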
3532
+
2828
3533
  def float(self) -> Tensor:
2829
3534
  """
2830
3535
  Convenience method to cast `self` to a `float32` Tensor.
@@ -2839,6 +3544,7 @@ class Tensor:
2839
3544
  ```
2840
3545
  """
2841
3546
  return self.cast(dtypes.float32)
3547
+
2842
3548
  def half(self) -> Tensor:
2843
3549
  """
2844
3550
  Convenience method to cast `self` to a `float16` Tensor.
@@ -2854,23 +3560,44 @@ class Tensor:
2854
3560
  """
2855
3561
  return self.cast(dtypes.float16)
2856
3562
 
2857
- # ***** convenience stuff *****
3563
+ def int(self) -> Tensor:
3564
+ """
3565
+ Convenience method to cast `self` to an `int32` Tensor.
2858
3566
 
2859
- @property
2860
- def ndim(self) -> int: return len(self.shape)
2861
- def numel(self) -> sint: return prod(self.shape)
2862
- def element_size(self) -> int: return self.dtype.itemsize
2863
- def nbytes(self) -> int: return self.numel() * self.element_size()
2864
- def is_floating_point(self) -> bool: return dtypes.is_float(self.dtype)
2865
- def size(self, dim=None) -> Union[sint, Tuple[sint, ...]]: return self.shape if dim is None else self.shape[dim]
3567
+ ```python exec="true" source="above" session="tensor" result="python"
3568
+ t = Tensor([-1.5, -0.5, 0.0, 0.5, 1.5])
3569
+ print(t.dtype, t.numpy())
3570
+ ```
3571
+ ```python exec="true" source="above" session="tensor" result="python"
3572
+ t = t.int()
3573
+ print(t.dtype, t.numpy())
3574
+ ```
3575
+ """
3576
+ return self.cast(dtypes.int32)
3577
+
3578
+ def bool(self) -> Tensor:
3579
+ """
3580
+ Convenience method to cast `self` to a `bool` Tensor.
3581
+
3582
+ ```python exec="true" source="above" session="tensor" result="python"
3583
+ t = Tensor([-1, 0, 1])
3584
+ print(t.dtype, t.numpy())
3585
+ ```
3586
+ ```python exec="true" source="above" session="tensor" result="python"
3587
+ t = t.bool()
3588
+ print(t.dtype, t.numpy())
3589
+ ```
3590
+ """
3591
+ return self.cast(dtypes.bool)
2866
3592
 
2867
3593
  # *** image Tensor function replacements ***
2868
3594
 
2869
- def image_dot(self, w:Tensor, acc_dtype=None):
3595
+ def image_dot(self, w:Tensor, acc_dtype:Optional[DTypeLike]=None) -> Tensor:
2870
3596
  # NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
2871
- n1, n2 = len(self.shape), len(w.shape)
2872
- assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D"
2873
- assert self.shape[-1] == w.shape[-min(n2, 2)], f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({self.shape[-1]} != {w.shape[-min(n2, 2)]})" # noqa: E501
3597
+ x, dx, dw = self, self.ndim, w.ndim
3598
+ if not (dx > 0 and dw > 0): raise RuntimeError(f"both tensors need to be at least 1D, got {dx}D and {dw}D")
3599
+ if x.shape[-1] != w.shape[-min(w.ndim, 2)]: raise RuntimeError(f"cannot image_dot {x.shape} and {w.shape}")
3600
+
2874
3601
  bs, groups, cin, cout = prod(self.shape[0:-2]), prod(w.shape[0:-2]), w.shape[-2], w.shape[-1]
2875
3602
  out_shape_t = self.shape[0:-2] + (cout,-1) if len(self.shape) > 1 else (cout, )
2876
3603
 
@@ -2881,7 +3608,7 @@ class Tensor:
2881
3608
  cw = w.transpose(w.ndim-1, w.ndim-2).reshape((groups*cout, cin, 1, 1))
2882
3609
  return cx.image_conv2d(cw, groups=groups, acc_dtype=acc_dtype).reshape(out_shape_t).transpose(self.ndim-1, self.ndim-2)
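Editor's note: the comment at the top of `image_dot` is the whole trick — an m×k @ k×n matmul is a 1×1 convolution over the channel axis. A sanity check of that mapping using the ordinary `conv2d` and `randn` (which are not part of this hunk), rather than the image dtypes:
```python
from tinygrad import Tensor

m, k, n = 3, 4, 5
a, w = Tensor.randn(m, k), Tensor.randn(k, n)
# (m,k) @ (k,n) == (1,k,m,1).conv2d(n,k,1,1), up to the transposes below
via_conv = a.T.reshape(1, k, m, 1).conv2d(w.T.reshape(n, k, 1, 1)).reshape(n, m).T
print(((a @ w - via_conv).abs().max() < 1e-5).item())  # True
```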
2883
3610
 
2884
- def image_conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, acc_dtype=None):
3611
+ def image_conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, acc_dtype=None) -> Tensor:
2885
3612
  base_image_type = dtypes.imageh if getenv("FLOAT16", 0) else dtypes.imagef
2886
3613
 
2887
3614
  (bs,_,iy,ix), (cout,cin,H,W) = self.shape, weight.shape
@@ -2922,12 +3649,8 @@ class Tensor:
2922
3649
  if cin_last: w = w.reshape(cout//4, H, rcin_hi, W, 4, rcin_lo)
2923
3650
  else: w = w.reshape(cout//4, H, rcin_hi, W, rcin_lo, 4).permute(0,1,2,3,5,4)
2924
3651
 
2925
- # padding
2926
- padding_ = [padding]*4 if isinstance(padding, int) else (padding if len(padding) == 4 else [padding[1], padding[1], padding[0], padding[0]])
2927
- x = x._slice((None, (-padding_[2], x.shape[1]+padding_[3]), (-padding_[0], x.shape[2]+padding_[1]), None, None, None))
2928
-
2929
3652
  # prepare input
2930
- x = x.permute(0,3,4,5,1,2)._pool((H, W), stride, dilation) # -> (bs, groups, rcin_hi, rcin_lo, oy, ox, H, W)
3653
+ x = x.permute(0,3,4,5,1,2).pad(self._padding2d(padding, 2))._pool((H, W), stride, dilation) # -> (bs, groups, rcin_hi, rcin_lo, oy, ox, H, W)
2931
3654
  x = x.permute(0,4,5,1,2,3,6,7).reshape(bs, (oy := x.shape[4]), (ox := x.shape[5]), *cout_expand[0:2], 1, 1, rcin_hi, rcin_lo, H, W)
2932
3655
 
2933
3656
  # prepare weights
@@ -2945,18 +3668,39 @@ class Tensor:
2945
3668
  ret = ret.reshape(bs, oy, ox, cout).permute(0,3,1,2)
2946
3669
  return ret if bias is None else ret.add(bias.reshape(1, -1, 1, 1))
2947
3670
 
2948
- # register functions to move between devices
2949
- for device in Device._devices: setattr(Tensor, f"{device.lower()}", functools.partialmethod(Tensor.to, device))
2950
-
2951
- if IMAGE:
2952
- # if IMAGE>0 we install these replacement functions in Tensor (hack!)
2953
- setattr(Tensor, "conv2d", Tensor.image_conv2d)
2954
- setattr(Tensor, "dot", Tensor.image_dot)
2955
-
2956
- # TODO: eventually remove this
2957
- def custom_random(out:Buffer):
2958
- Tensor._seed += 1
2959
- rng = np.random.default_rng(Tensor._seed)
2960
- if out.dtype == dtypes.half: rng_np_buffer = (rng.integers(low=0, high=2047, size=out.size) / 2048).astype(np.half, copy=False)
2961
- else: rng_np_buffer = rng.random(size=out.size, dtype=np.float32).astype(dtype=_to_np_dtype(out.dtype), copy=False)
2962
- out.copyin(rng_np_buffer.data)
3671
+ def _metadata_wrapper(fn):
3672
+ def _wrapper(*args, **kwargs):
3673
+ if _METADATA.get() is not None: return fn(*args, **kwargs)
3674
+
3675
+ if TRACEMETA >= 2:
3676
+ caller_frame = sys._getframe(frame := 1)
3677
+ caller_module = caller_frame.f_globals.get("__name__", None)
3678
+ caller_func = caller_frame.f_code.co_name
3679
+ if caller_module is None: return fn(*args, **kwargs)
3680
+
3681
+ # if it's called from nn, we want to step up frames until we are out of nn
3682
+ while caller_module.startswith("tinygrad.nn") and "optim" not in caller_module:
3683
+ caller_frame = sys._getframe(frame := frame + 1)
3684
+ caller_module = caller_frame.f_globals.get("__name__", None)
3685
+ if caller_module is None: return fn(*args, **kwargs)
3686
+
3687
+ # if its called from a lambda in tinygrad we want to look two more frames up
3688
+ if caller_module.startswith("tinygrad") and caller_func == "<lambda>": caller_frame = sys._getframe(frame := frame + 2)
3689
+ caller_module = caller_frame.f_globals.get("__name__", None)
3690
+ if caller_module is None: return fn(*args, **kwargs)
3691
+ caller_func = caller_frame.f_code.co_name
3692
+ caller_lineno = caller_frame.f_lineno
3693
+
3694
+ caller = f"{caller_module}:{caller_lineno}::{caller_func}"
3695
+ else: caller = ""
3696
+
3697
+ token = _METADATA.set(Metadata(name=fn.__name__, caller=caller))
3698
+ ret = fn(*args, **kwargs)
3699
+ _METADATA.reset(token)
3700
+ return ret
3701
+ return _wrapper
3702
+
3703
+ if TRACEMETA >= 1:
3704
+ for name, fn in inspect.getmembers(Tensor, inspect.isfunction):
3705
+ if name in ["__class__", "__init__", "__new__", "__repr__", "backward", "sequential"]: continue
3706
+ setattr(Tensor, name, functools.wraps(fn)(_metadata_wrapper(fn)))
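Editor's note: the wrapper above tags only the outermost Tensor call — `_METADATA` is a ContextVar, so nested wrapped calls see it already set and return early, and `TRACEMETA>=2` additionally walks the stack to record the caller. A standalone sketch of that ContextVar pattern (names like `CURRENT_OP` and `tag_op` are illustrative, not tinygrad's):
```python
import contextvars, functools

CURRENT_OP: contextvars.ContextVar = contextvars.ContextVar("CURRENT_OP", default=None)

def tag_op(fn):
  @functools.wraps(fn)
  def wrapper(*args, **kwargs):
    if CURRENT_OP.get() is not None: return fn(*args, **kwargs)  # already inside a tagged op
    token = CURRENT_OP.set(fn.__name__)
    try: return fn(*args, **kwargs)
    finally: CURRENT_OP.reset(token)
  return wrapper

@tag_op
def inner(): return CURRENT_OP.get()

@tag_op
def outer(): return inner()

print(outer())  # "outer": the nested call inherits the outermost op's tag
```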