tinygrad 0.9.0__py3-none-any.whl → 0.9.1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (57)
  1. tinygrad/codegen/__init__.py +0 -0
  2. tinygrad/codegen/kernel.py +78 -90
  3. tinygrad/codegen/linearizer.py +237 -169
  4. tinygrad/codegen/uops.py +278 -242
  5. tinygrad/device.py +147 -10
  6. tinygrad/dtype.py +7 -7
  7. tinygrad/engine/graph.py +16 -16
  8. tinygrad/engine/jit.py +39 -36
  9. tinygrad/engine/realize.py +6 -5
  10. tinygrad/engine/schedule.py +15 -7
  11. tinygrad/engine/search.py +6 -3
  12. tinygrad/function.py +17 -23
  13. tinygrad/helpers.py +77 -8
  14. tinygrad/lazy.py +26 -26
  15. tinygrad/multi.py +13 -9
  16. tinygrad/nn/__init__.py +1 -1
  17. tinygrad/nn/datasets.py +2 -1
  18. tinygrad/nn/state.py +3 -4
  19. tinygrad/ops.py +49 -16
  20. tinygrad/renderer/__init__.py +8 -4
  21. tinygrad/renderer/assembly.py +93 -100
  22. tinygrad/renderer/cstyle.py +47 -42
  23. tinygrad/renderer/llvmir.py +30 -30
  24. tinygrad/runtime/__init__.py +0 -0
  25. tinygrad/runtime/autogen/amd_gpu.py +11504 -1
  26. tinygrad/runtime/autogen/comgr.py +36 -10
  27. tinygrad/runtime/autogen/hsa.py +146 -14
  28. tinygrad/runtime/autogen/io_uring.py +1486 -0
  29. tinygrad/runtime/autogen/nv_gpu.py +269 -0
  30. tinygrad/runtime/driver/__init__.py +0 -0
  31. tinygrad/runtime/driver/hip_comgr.py +20 -11
  32. tinygrad/runtime/graph/__init__.py +0 -0
  33. tinygrad/runtime/graph/clang.py +3 -2
  34. tinygrad/runtime/graph/cuda.py +2 -2
  35. tinygrad/runtime/graph/hcq.py +122 -78
  36. tinygrad/runtime/ops_amd.py +302 -316
  37. tinygrad/runtime/ops_cuda.py +3 -3
  38. tinygrad/runtime/ops_disk.py +70 -5
  39. tinygrad/runtime/ops_gpu.py +2 -2
  40. tinygrad/runtime/ops_metal.py +5 -6
  41. tinygrad/runtime/ops_npy.py +1 -1
  42. tinygrad/runtime/ops_nv.py +161 -166
  43. tinygrad/runtime/ops_python.py +20 -16
  44. tinygrad/shape/__init__.py +0 -0
  45. tinygrad/shape/shapetracker.py +5 -2
  46. tinygrad/shape/symbolic.py +1 -3
  47. tinygrad/shape/view.py +34 -19
  48. tinygrad/tensor.py +219 -135
  49. {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/METADATA +14 -6
  50. tinygrad-0.9.1.dist-info/RECORD +63 -0
  51. {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/WHEEL +1 -1
  52. tinygrad/runtime/driver/hsa.py +0 -143
  53. tinygrad/runtime/graph/hsa.py +0 -171
  54. tinygrad/runtime/ops_hsa.py +0 -278
  55. tinygrad-0.9.0.dist-info/RECORD +0 -60
  56. {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/LICENSE +0 -0
  57. {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/top_level.txt +0 -0
tinygrad/tensor.py CHANGED
@@ -1,19 +1,19 @@
  # inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
  from __future__ import annotations
- import time, math, itertools, functools
+ import time, math, itertools, functools, struct
  from contextlib import ContextDecorator
  from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Dict, DefaultDict, cast, get_args, Set
  from collections import defaultdict
  import numpy as np

  from tinygrad.dtype import DType, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype
- from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, fully_flatten, argsort, getenv
+ from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, get_shape, fully_flatten, dedup
  from tinygrad.helpers import IMAGE, DEBUG, WINO, THREEFRY
  from tinygrad.lazy import LazyBuffer
  from tinygrad.multi import MultiLazyBuffer
- from tinygrad.ops import LoadOps
+ from tinygrad.ops import LoadOps, truncate
  from tinygrad.device import Device, Buffer, BufferOptions
- from tinygrad.shape.symbolic import sint, Variable, MulNode, Node
+ from tinygrad.shape.symbolic import sint, Variable, MulNode, SumNode, NumNode, Node
  from tinygrad.engine.realize import run_schedule
  from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars, memory_planner

@@ -43,13 +43,28 @@ def _loadop(op, shape:Tuple[sint,...], dtype:DType, device:Union[str, Tuple[str,
  if isinstance(device, str): return LazyBuffer.loadop(op, shape, dtype, device, arg, src)
  return MultiLazyBuffer([LazyBuffer.loadop(op, shape, dtype, d, arg, src) for d in device], None)

- def _fromcpu(x: np.ndarray) -> LazyBuffer:
- ret = LazyBuffer.loadop(LoadOps.EMPTY, x.shape, dtypes.from_np(x.dtype), "NPY")
+ def _from_np_dtype(npdtype:type) -> DType: return dtypes.fields()[np.dtype(npdtype).name]
+ def _to_np_dtype(dtype:DType) -> Optional[type]: return np.dtype(dtype.fmt).type if dtype.fmt is not None else None
+
+ def _fromnp(x: np.ndarray) -> LazyBuffer:
+ ret = LazyBuffer.loadop(LoadOps.EMPTY, x.shape, _from_np_dtype(x.dtype), "NPY")
  # fake realize
  ret.buffer.allocate(x)
  del ret.srcs
  return ret

+ def _frompy(x:Union[List, Tuple, bytes], dtype:DType) -> LazyBuffer:
+ if isinstance(x, bytes): ret, data = LazyBuffer.loadop(LoadOps.EMPTY, (len(x),), dtype, "PYTHON"), x
+ else:
+ ret = LazyBuffer.loadop(LoadOps.EMPTY, get_shape(x), dtype, "PYTHON")
+ assert dtype.fmt is not None, f"{dtype=} has None fmt"
+ truncate_function = truncate[dtype]
+ data = struct.pack(f"@{ret.size}{dtype.fmt}", *[truncate_function(xi) for xi in fully_flatten(x)])
+ # fake realize
+ ret.buffer.allocate(memoryview(data))
+ del ret.srcs
+ return ret
+
  def _get_winograd_matcols(mat, dims:int, shp:Tuple[sint, ...], device:Union[str, Tuple[str, ...]]) -> List[List[Tensor]]:
  return [[Tensor.cat(*[Tensor.full(shp[:dim] + (1,) + shp[dim+1:], float(m[k]), device=device) for m in mat], dim=dim)
  for k in range(len(mat[0]))] for dim in range(dims)]
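The `_frompy` path added above builds the buffer straight from Python data: the nested list is flattened and packed with `struct` using the dtype's format character, with no numpy round trip. A minimal standalone sketch of that packing step (plain Python; `flatten_nested` is a hypothetical stand-in for tinygrad's `fully_flatten`):

```python
import struct

def flatten_nested(x):
    # hypothetical stand-in for tinygrad.helpers.fully_flatten
    return [v for item in x for v in (flatten_nested(item) if isinstance(item, (list, tuple)) else [item])]

data = [[1, 2], [3, 4]]
flat = flatten_nested(data)                    # [1, 2, 3, 4]
packed = struct.pack(f"@{len(flat)}i", *flat)  # native-endian ints, like a dtype with fmt "i"
print(list(struct.unpack(f"@{len(flat)}i", packed)))  # [1, 2, 3, 4]
```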
@@ -66,8 +81,11 @@ def _apply_winograd_matrix(mat, t:Tensor, dims:int) -> Tensor:
  assert isinstance(ret, Tensor), "sum didn't return a Tensor"
  return ret

- def _pad_left(*shps:Tuple[sint, ...], v=1): return tuple((v,) * (max(len(i_) for i_ in shps) - len(i)) + i for i in shps)
- def _broadcast_shape(*shps:Tuple[sint, ...]): return tuple(0 if any(sh_ == 0 for sh_ in sh) else max(sh) for sh in zip(*_pad_left(*shps)))
+ def _pad_left(*shapes:Tuple[sint, ...]) -> Tuple[Tuple[sint, ...], ...]:
+ max_dim = max(len(shape) for shape in shapes)
+ return tuple((1,) * (max_dim - len(shape)) + shape for shape in shapes)
+ def _broadcast_shape(*shapes:Tuple[sint, ...]) -> Tuple[sint, ...]:
+ return tuple(0 if any(size == 0 for size in nth_dim_sizes) else max(nth_dim_sizes) for nth_dim_sizes in zip(*_pad_left(*shapes)))
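The rewritten `_pad_left` / `_broadcast_shape` implement standard left-padded broadcasting: shorter shapes get leading 1s, then each dimension takes the maximum size, with a 0-sized dimension forcing 0. A small plain-Python sketch of the same rule, assuming integer shapes only:

```python
def pad_left(*shapes):
    # pad shorter shapes with leading 1s so every shape has the same rank
    max_dim = max(len(s) for s in shapes)
    return tuple((1,) * (max_dim - len(s)) + tuple(s) for s in shapes)

def broadcast_shape(*shapes):
    # per dimension: 0 wins if present, otherwise take the largest size
    return tuple(0 if 0 in sizes else max(sizes) for sizes in zip(*pad_left(*shapes)))

print(broadcast_shape((3, 1, 5), (4, 5)))  # (3, 4, 5)
print(broadcast_shape((0, 4), (1, 4)))     # (0, 4)
```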

  class Tensor:
  """
@@ -83,60 +101,75 @@ class Tensor:
  __slots__ = "lazydata", "requires_grad", "grad", "_ctx"
  __deletable__ = ('_ctx',)
  training: ClassVar[bool] = False
- class train(ContextDecorator):
- def __init__(self, mode:bool = True): self.mode = mode
- def __enter__(self): self.prev, Tensor.training = Tensor.training, self.mode
- def __exit__(self, exc_type, exc_value, traceback): Tensor.training = self.prev
-
  no_grad: ClassVar[bool] = False
- class inference_mode(ContextDecorator):
- def __init__(self, mode:bool = True): self.mode = mode
- def __enter__(self): self.prev, Tensor.no_grad = Tensor.no_grad, self.mode
- def __exit__(self, exc_type, exc_value, traceback): Tensor.no_grad = self.prev
+
  def __init__(self, data:Union[None, ConstType, List, Tuple, LazyBuffer, np.ndarray, bytes, MultiLazyBuffer, Variable],
  device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
  assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
  device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device)
- # tensors have gradients, buffers do not
+
+ # tensors can have gradients if you have called .backward
  self.grad: Optional[Tensor] = None

  # NOTE: this can be in three states. False and None: no gradient, True: gradient
  # None (the default) will be updated to True if it's put in an optimizer
  self.requires_grad: Optional[bool] = requires_grad

- # internal variables used for autograd graph construction
+ # internal variable used for autograd graph construction
  self._ctx: Optional[Function] = None
+
+ # create a LazyBuffer from the different types of inputs
  if isinstance(data, LazyBuffer): assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
  elif isinstance(data, get_args(ConstType)): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_py(data), device, data)
  elif isinstance(data, Variable): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_py(data.unbind()[1]), device, data)
- elif isinstance(data, bytes): data = _fromcpu(np.frombuffer(data, np.uint8))
- elif data is None: data = _loadop(LoadOps.EMPTY, (0,), dtype or dtypes.default_float, device)
- elif isinstance(data, list):
+ elif isinstance(data, bytes): data = _frompy(data, dtypes.uint8)
+ elif isinstance(data, (list, tuple)):
  if dtype is None:
  if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtypes.bool
  else: dtype = dtypes.default_int if d and all_int(d) else dtypes.default_float
- if dtype == dtypes.bfloat16: data = Tensor(_fromcpu(np.array(data, np.float32)), device=device).cast(dtypes.bfloat16).lazydata
- else: data = _fromcpu(np.array(data, dtype.np))
+ if dtype == dtypes.bfloat16: data = Tensor(_fromnp(np.array(data, np.float32)), device=device).cast(dtypes.bfloat16).lazydata
+ else: data = _fromnp(np.array(data).astype(_to_np_dtype(dtype)))
+ elif data is None: data = _loadop(LoadOps.EMPTY, (0,), dtype or dtypes.default_float, device)
  elif isinstance(data, np.ndarray):
- if data.shape == (): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_np(data.dtype), device, data.item())
- else: data = _fromcpu(data.astype(dtype.np) if dtype is not None and dtype.np is not None else data)
+ if data.shape == (): data = _loadop(LoadOps.CONST, tuple(), dtype or _from_np_dtype(data.dtype), device, data.item())
+ else: data = _fromnp(data.astype(npdtype) if dtype is not None and (npdtype:=_to_np_dtype(dtype)) is not None else data)
+
+ # by this point, it has to be a LazyBuffer
+ if not isinstance(data, (LazyBuffer, MultiLazyBuffer)):
+ raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}")

  # data is a LazyBuffer, but it might be on the wrong device
- if not isinstance(data, (LazyBuffer, MultiLazyBuffer)): raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}")
  if isinstance(device, tuple):
- # TODO: what if it's a MultiLazyBuffer on other devices?
- self.lazydata: Union[LazyBuffer, MultiLazyBuffer] = MultiLazyBuffer.from_sharded(data, device, None) if isinstance(data, LazyBuffer) else data
+ # if device is a tuple, we should have/construct a MultiLazyBuffer
+ if isinstance(data, MultiLazyBuffer):
+ assert data.device == device, f"MultiLazyBuffer device mismatch, {data.device} != {device}"
+ self.lazydata: Union[LazyBuffer, MultiLazyBuffer] = data
+ else:
+ self.lazydata = MultiLazyBuffer.from_sharded(data, device, None)
  else:
  self.lazydata = data if data.device == device else data.copy_to_device(device)

- def __repr__(self): return f"<Tensor {self.lazydata!r} on {self.device} with grad {(self.grad.lazydata if self.grad is not None else None)!r}>"
+ class train(ContextDecorator):
+ def __init__(self, mode:bool = True): self.mode = mode
+ def __enter__(self): self.prev, Tensor.training = Tensor.training, self.mode
+ def __exit__(self, exc_type, exc_value, traceback): Tensor.training = self.prev
+
+ class inference_mode(ContextDecorator):
+ def __init__(self, mode:bool = True): self.mode = mode
+ def __enter__(self): self.prev, Tensor.no_grad = Tensor.no_grad, self.mode
+ def __exit__(self, exc_type, exc_value, traceback): Tensor.no_grad = self.prev
+
+ def __repr__(self):
+ return f"<Tensor {self.lazydata!r} on {self.device} with grad {(self.grad.lazydata if self.grad is not None else None)!r}>"

  # Python has a non moving GC, so this should be okay
  def __hash__(self): return id(self)

  def __bool__(self): raise TypeError("__bool__ on Tensor is not defined")

- def __len__(self): return self.shape[0] if len(self.shape) else 1
+ def __len__(self):
+ if not self.shape: raise TypeError("len() of a 0-d tensor")
+ return self.shape[0]

  @property
  def device(self) -> Union[str, Tuple[str, ...]]: return self.lazydata.device
@@ -196,6 +229,7 @@ class Tensor:
  if not self.lazydata.is_realized(): return self.replace(x)
  self.lazydata = self.lazydata.assign(x.lazydata)
  return self
+
  def detach(self) -> Tensor:
  """
  Returns a new tensor with the same data as this tensor, but detached from the autograd graph.
@@ -259,9 +293,9 @@ class Tensor:
  ```
  """
  if self.dtype == dtypes.bfloat16: return self.float().numpy()
- assert self.dtype.np is not None, f"no np dtype for {self.dtype}"
+ assert _to_np_dtype(self.dtype) is not None, f"no np dtype for {self.dtype}"
  assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
- return np.frombuffer(self._data(), dtype=self.dtype.np).reshape(self.shape)
+ return np.frombuffer(self._data(), dtype=_to_np_dtype(self.dtype)).reshape(self.shape)

  def to(self, device:Optional[Union[str, Tuple[str, ...]]]) -> Tensor:
  """
@@ -302,8 +336,10 @@ class Tensor:

  @staticmethod
  def from_node(y:Node, **kwargs) -> Tensor:
- if isinstance(y, MulNode): return Tensor.from_node(y.a, **kwargs) * y.b
+ if isinstance(y, NumNode): return Tensor(y.b, **kwargs, requires_grad=False)
  if isinstance(y, Variable): return Tensor(y, **kwargs, requires_grad=False)
+ if isinstance(y, MulNode): return Tensor.from_node(y.a, **kwargs) * y.b
+ if isinstance(y, SumNode): return Tensor.from_node(y.nodes[0], **kwargs) + sum(y.nodes[1:])
  raise RuntimeError(f"unhandled Node {y}")

  # ***** creation llop entrypoint *****
@@ -339,7 +375,13 @@ class Tensor:

  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
- print(Tensor._seed)
+ print(Tensor.rand(5).numpy())
+ print(Tensor.rand(5).numpy())
+ ```
+ ```python exec="true" source="above" session="tensor" result="python"
+ Tensor.manual_seed(42) # reset to the same seed
+ print(Tensor.rand(5).numpy())
+ print(Tensor.rand(5).numpy())
  ```
  """
  Tensor._seed, Tensor._rng_counter = seed, Tensor([0], dtype=dtypes.uint32, requires_grad=False)
@@ -724,8 +766,11 @@ class Tensor:
  print(t.reshape(2, 3).numpy())
  ```
  """
- new_shape = argfix(shape, *args)
- new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else (s if s is not None else self.shape[i]) for i,s in enumerate(new_shape)])
+ # resolve None and args
+ new_shape = tuple([s if s is not None else self.shape[i] for i,s in enumerate(argfix(shape, *args))])
+ # resolve -1
+ if (c := new_shape.count(-1)) > 1: raise RuntimeError(f"only one dimension can be inferred using -1, getting {new_shape}")
+ if c: new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape])
  return F.Reshape.apply(self, shape=new_shape) if new_shape != self.shape else self

  def expand(self, shape, *args) -> Tensor:
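With the reshape change above, `None` entries are resolved to the current dimension size before a single `-1` is inferred, and passing more than one `-1` now raises. A quick illustration, assuming tinygrad 0.9.1 is installed:

```python
from tinygrad import Tensor

t = Tensor.arange(12).reshape(3, 4)
print(t.reshape(None, -1).shape)  # (3, 4): None keeps dim 0, -1 infers the rest
print(t.reshape(-1, 6).shape)     # (2, 6)
try:
    t.reshape(-1, -1)
except RuntimeError as e:
    print(e)                      # only one dimension can be inferred using -1, ...
```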
@@ -740,7 +785,7 @@ class Tensor:
  print(t.expand(4, -1).numpy())
  ```
  """
- return self._broadcast_to(tuple(sh if s==-1 or s is None else s for s, sh in zip(*(_pad_left(argfix(shape, *args), self.shape)))))
+ return self._broadcast_to(tuple(from_ if to == -1 or to is None else to for from_, to in zip(*(_pad_left(self.shape, argfix(shape, *args))))))

  def permute(self, order, *args) -> Tensor:
  """
@@ -756,7 +801,9 @@ class Tensor:
  print(t.permute(1, 0).numpy())
  ```
  """
- return F.Permute.apply(self, order=argfix(order, *args))
+ order_arg = tuple(self._resolve_dim(x) for x in argfix(order, *args))
+ if sorted(order_arg) != list(range(self.ndim)): raise RuntimeError(f"order is not a valid permutation, getting {order_arg}")
+ return F.Permute.apply(self, order=order_arg)

  def flip(self, axis, *args) -> Tensor:
  """
@@ -774,7 +821,9 @@ class Tensor:
  print(t.flip((0, 1)).numpy())
  ```
  """
- return F.Flip.apply(self, axis=[x if x >= 0 else x+len(self.shape) for x in argfix(axis, *args)])
+ axis_arg = tuple(self._resolve_dim(x) for x in argfix(axis, *args))
+ if len(axis_arg) != len(dedup(axis_arg)): raise RuntimeError(f"dim can appear at least once, getting {axis_arg}")
+ return F.Flip.apply(self, axis=axis_arg)

  def shrink(self, arg:Tuple[Optional[Tuple[sint, sint]], ...]) -> Tensor:
  """
@@ -831,7 +880,7 @@ class Tensor:
  # - first shrink the Tensor to X.shrink(((start, end),))
  # - then we apply stride through Optional[flip] -> pad -> reshape -> shrink
  # - flip where dim value is negative
- # - pad 0's on dims such that reshaping [dim_size_padded] -> [dim_size_padded // stride, stride] is possible
+ # - pad on dims to be multiple of strides, such that reshaping [dim_size_padded] -> [dim_size_padded // stride, stride] is possible
  # - shrink [dim_size_padded // stride, stride] -> [dim_size_padded // stride, 1]
  # - reshape [dim_size_padded // stride, 1] -> [dim_size_padded // stride] and now you have your stride
  # 3. None indexing (no copy)
@@ -847,13 +896,13 @@ class Tensor:
  # - for any tuple, Tuple[Union[List, Tuple, int]], must have homogeneous shape
  # 2. Bool indexing is not supported
  # 3. Out of bounds Tensor indexing results in 0
- # - e.g: Tensor([1, 2, 3])[Tensor([4, 3, 2])] -> [0, 0, 3] index 4 and 3 are OOB
+ # - e.g: Tensor([1, 2, 3])[Tensor([4, 3, 2])] -> [0, 0, 3] index 4 and 3 are out of bounds
  def __getitem__(self, indices) -> Tensor:
  # 1. indices normalization and validation
  # treat internal tuples and lists as Tensors and standardize indices to list type
  if isinstance(indices, list) and all_int(indices): indices = [Tensor(indices, self.device, requires_grad=False)]
  elif isinstance(indices, (tuple, list)):
- indices = [Tensor(list(i), self.device, requires_grad=False) if isinstance(i, (tuple, list)) else i for i in indices]
+ indices = [Tensor(i, self.device, requires_grad=False) if isinstance(i, (tuple, list)) else i for i in indices]
  else: indices = [indices]

  # turn scalar Tensors into const val for int indexing if possible
@@ -865,24 +914,25 @@ class Tensor:
  ellipsis_idx = [dim for dim, i in enumerate(indices) if i is Ellipsis]
  fill_idx = ellipsis_idx[0] if ellipsis_idx else len(indices)
  num_indices = len(indices) - len(ellipsis_idx) - sum(1 for i in indices if i is None)
- indices[fill_idx:fill_idx+1] = [slice(None)] * (len(self.shape) - num_indices)
+ indices[fill_idx:fill_idx+1] = [slice(None)] * (self.ndim - num_indices)

  # use Dict[type, List[dimension]] to track elements in indices
  type_dim: DefaultDict[Union[type, None], List[int]] = defaultdict(list)

  # record None for dimension injection later and filter None and record rest of indices
  type_dim[None] = [dim for dim, i in enumerate(indices) if i is None]
- indices_filtered = [v for v in indices if v is not None]
+ tensor_dims = [dim for dim, i in enumerate(indices) if isinstance(i, Tensor)]
+ indices_filtered = [i for i in indices if i is not None]
  for dim,i in enumerate(indices_filtered): type_dim[type(i)].append(dim)

+ if len(ellipsis_idx) > 1: raise IndexError("indices can only have a single ellipsis ('...')")
  for index_type in type_dim:
  if index_type not in [None, int, slice, Tensor]: raise IndexError(f"{index_type=} not supported")
- if len(ellipsis_idx) > 1: raise IndexError("indices can only have a single ellipsis ('...')")
  if num_indices > self.ndim: raise IndexError(f"too many {num_indices=} for {self.ndim=}")

  # 2. basic indexing, uses only movement ops (no copy)
- # currently indices_filtered: Tuple[Union[slice, int, Tensor], ...]
- # turn indices in indices_filtered to Tuple[shrink_arg, strides]
+ # currently indices_filtered: Tuple[Union[int, slice, Tensor], ...]
+ # turn indices in indices_filtered to Tuple[new_slice, strides]
  for dim in type_dim[int]:
  if (index := indices_filtered[dim]) >= (size := self.shape[dim]) or index < -size:
  raise IndexError(f"{index=} is out of bounds on {dim=} with {size=}")
@@ -898,13 +948,16 @@ class Tensor:
  if not dtypes.is_int(index.dtype): raise IndexError(f"{index.dtype=} on {dim=} is not supported, only int tensor indexing is supported")
  indices_filtered[dim] = ((0, self.shape[dim]), 1)

- new_slice, strides = ((),()) if not indices_filtered else zip(*indices_filtered)
- ret = self.shrink(new_slice).flip(tuple(i for i, s in enumerate(strides) if s < 0))
- if any(abs(s) != 1 for s in strides):
+ new_slice, strides = ((), ()) if not indices_filtered else zip(*indices_filtered)
+ # flip negative strides
+ ret = self.shrink(new_slice).flip(tuple(i for i, st in enumerate(strides) if st < 0))
+ # handle stride != 1 or -1
+ if any(abs(st) != 1 for st in strides):
  strides = tuple(abs(s) for s in strides)
- ret = ret.pad(tuple((0, round_up(sh, s) - sh) for s, sh in zip(strides, ret.shape)))
- ret = ret.reshape(tuple(flatten((sh // s, s) for s, sh in zip(strides, ret.shape))))
- ret = ret.shrink(tuple(flatten(((0, sh), (0, 1)) for sh in ret.shape[::2]))).reshape(ret.shape[::2])
+ # pad shape to multiple of stride
+ ret = ret.pad(tuple((0, round_up(s, st) - s) for s, st in zip(ret.shape, strides)))
+ ret = ret.reshape(tuple(flatten((s // st, st) for s, st in zip(ret.shape, strides))))
+ ret = ret.shrink(tuple(flatten(((0, s), (0, 1)) for s in ret.shape[::2]))).reshape(ret.shape[::2])

  # inject 1 for dim where it's None and collapse dim for int
  new_shape = list(ret.shape)
@@ -912,17 +965,17 @@ class Tensor:
  for dim in (dims_collapsed := tuple(dim + sum(1 for d in type_dim[None] if dim >= d) for dim in reversed(type_dim[int]))): new_shape.pop(dim)

  ret = ret.reshape(new_shape)
- assert all_int(ret.shape), f"does not support symbolic shape {ret.shape}"

  # 3. advanced indexing (copy)
  if type_dim[Tensor]:
  # calculate dim of current ret by subtracting dims collapsed and adding dims injected up until tensor_dim
  def calc_dim(tensor_dim:int) -> int:
- return tensor_dim - sum(1 for d in dims_collapsed if tensor_dim >= d) + sum(1 for d in type_dim[None] if tensor_dim >= d)
+ return tensor_dim - sum(1 for d in dims_collapsed if tensor_dim >= d)

+ assert all_int(ret.shape), f"does not support symbolic shape {ret.shape}"
  # track tensor_dim and tensor_index using a dict
  # calc_dim to get dim and use that to normalize the negative tensor indices
- idx: Dict[int,Tensor] = {(dim := calc_dim(td)):(tensor<0).where(ret.shape[dim],0) + tensor for td,tensor in zip(type_dim[Tensor], tensor_index)}
+ idx: Dict[int,Tensor] = {(dim := calc_dim(td)):(tensor<0).where(ret.shape[dim],0) + tensor for td,tensor in zip(tensor_dims, tensor_index)}

  masks, first_dim, last_dim = [], min(idx.keys()), max(idx.keys())
  pre_reduce_shape = ret.shape[:first_dim] + (big_shape := _broadcast_shape(*(t.shape for t in idx.values()))) + ret.shape[first_dim:]
@@ -938,9 +991,9 @@ class Tensor:
  mask: Tensor = functools.reduce(lambda x,y: x.mul(y), masks)

  # inject 1's for the extra dims added in create masks
- sh = ret.shape[:first_dim] + (1,) * len(big_shape) + ret.shape[first_dim:]
+ reshape_arg = ret.shape[:first_dim] + (1,) * len(big_shape) + ret.shape[first_dim:]
  # sum reduce the extra dims introduced in create masks
- ret = (ret.reshape(sh) * mask).sum(tuple(i + len(big_shape) for i in idx.keys()), acc_dtype=ret.dtype)
+ ret = (ret.reshape(reshape_arg) * mask).sum(tuple(i + len(big_shape) for i in idx.keys()), acc_dtype=ret.dtype)

  # special permute case
  if first_dim != 0 and len(idx) != 1 and tuple(idx.keys()) != tuple(range(first_dim, last_dim+1)):
@@ -983,13 +1036,11 @@ class Tensor:
  ```
  """
  assert index.ndim == self.ndim, f"self.ndim must equal index.ndim, {self.ndim=}, {index.ndim=}"
- assert all(s >= i for s,i in zip(self.shape, index.shape)), "all dim of index.shape must be smaller than self.shape"
+ assert all(s >= i for d,(s,i) in enumerate(zip(self.shape, index.shape)) if d != dim), "requires self.shape[d] >= index.shape[d] for all d != dim"
  dim = self._resolve_dim(dim)
- index = index.to(self.device).transpose(0, dim).unsqueeze(-1)
- permarg = list(range(self.ndim))
- permarg = permarg[1:dim] + [permarg[0]] + permarg[dim+1:] + [permarg[dim]] if dim != 0 else permarg[1:] + [permarg[0]]
- return ((index == Tensor.arange(self.shape[dim], requires_grad=False, device=self.device)) * self.permute(*permarg).shrink(
- tuple([*[(0,sh) for sh in index.shape[1:-1]], None])).unsqueeze(0)).sum(-1, acc_dtype=self.dtype).transpose(0, dim)
+ index = index.to(self.device)
+ x = self.shrink(tuple((0, i) if d != dim else None for d,i in enumerate(index.shape))).unsqueeze(-1).transpose(-1, dim)
+ return ((index.unsqueeze(-1) == Tensor.arange(self.shape[dim], requires_grad=False, device=self.device)) * x).sum(-1, acc_dtype=self.dtype)

  def cat(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
  """
@@ -1193,7 +1244,7 @@ class Tensor:

  def unflatten(self, dim:int, sizes:Tuple[int,...]):
  """
- Expands dimension `dim` of the tensor over multiple dimensions specified by `sizes`.
+ Unflattens dimension `dim` of the tensor into multiple dimensions specified by `sizes`. `Tensor.flatten()` is the inverse of this function.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor.ones(3, 4, 1).unflatten(1, (2, 2)).shape)
@@ -1210,16 +1261,16 @@ class Tensor:

  # ***** reduce ops *****

- def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Tuple[int, ...]]]=None, keepdim=False) -> Tensor:
+ def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False) -> Tensor:
  if self.ndim == 0:
- if axis is not None and axis not in [-1, 0]: raise IndexError(f"{axis=} out of range of [-1, 0]")
- axis = None
+ if axis is not None and any(a not in [-1, 0] for a in fully_flatten([axis])): raise IndexError(f"{axis=} out of range of [-1, 0]")
+ axis = ()
  axis_: Tuple[int, ...] = tuple(range(len(self.shape))) if axis is None else ((axis,) if isinstance(axis, int) else tuple(axis))
  axis_ = tuple(self._resolve_dim(x) for x in axis_)
  ret = fxn.apply(self, axis=axis_)
  return ret if keepdim else ret.reshape(tuple(s for i,s in enumerate(self.shape) if i not in axis_))

- def sum(self, axis=None, keepdim=False, acc_dtype:Optional[DType]=None):
+ def sum(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, acc_dtype:Optional[DType]=None):
  """
  Sums the elements of the tensor along the specified axis or axes.

@@ -1244,8 +1295,9 @@ class Tensor:
  ```
  """
  ret = self.cast(acc_dtype or sum_acc_dtype(self.dtype))._reduce(F.Sum, axis, keepdim)
- return ret.cast(self.dtype) if self.dtype in {dtypes.float16, dtypes.bfloat16} else ret
- def max(self, axis=None, keepdim=False):
+ return ret.cast(self.dtype) if acc_dtype is None and self.dtype in (dtypes.float16, dtypes.bfloat16) else ret
+
+ def max(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
  """
  Returns the maximum value of the tensor along the specified axis or axes.

@@ -1267,7 +1319,8 @@ class Tensor:
  ```
  """
  return self._reduce(F.Max, axis, keepdim)
- def min(self, axis=None, keepdim=False):
+
+ def min(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
  """
  Returns the minimum value of the tensor along the specified axis or axes.

@@ -1290,7 +1343,7 @@ class Tensor:
  """
  return -((-self).max(axis=axis, keepdim=keepdim))

- def mean(self, axis=None, keepdim=False):
+ def mean(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
  """
  Returns the mean value of the tensor along the specified axis or axes.

@@ -1316,7 +1369,7 @@ class Tensor:
  numerator = self.cast(sum_acc_dtype(self.dtype)).sum(axis=axis, keepdim=keepdim)
  return numerator.div(prod([si for si, so in zip(self.shape, self.sum(axis=axis, keepdim=True).shape) if si != so])).cast(output_dtype)

- def var(self, axis=None, keepdim=False, correction=1):
+ def var(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, correction=1):
  """
  Returns the variance of the tensor along the specified axis or axes.

@@ -1338,11 +1391,11 @@ class Tensor:
  print(t.var(axis=1).numpy())
  ```
  """
- assert all_int(self.shape), "does not support symbolic shape"
- square_sum = ((self - self.mean(axis=axis, keepdim=True)).square()).sum(axis=axis, keepdim=keepdim)
- return square_sum.div(max(0, prod(self.shape)/prod(square_sum.shape)-correction))
+ squares = (self - self.mean(axis=axis, keepdim=True)).square()
+ n = prod([si for si, so in zip(self.shape, squares.sum(axis=axis, keepdim=True).shape) if si != so])
+ return squares.sum(axis=axis, keepdim=keepdim).div(max(0, n-correction))

- def std(self, axis=None, keepdim=False, correction=1):
+ def std(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, correction=1):
  """
  Returns the standard deviation of the tensor along the specified axis or axes.

@@ -1465,9 +1518,7 @@ class Tensor:
  print(t.argmax(axis=1).numpy()) # Returns the indices of the maximum values along axis 1.
  ```
  """
- if axis is None:
- idx = (self == self.max(axis)) * Tensor.arange(prod(self.shape)-1,-1,-1, requires_grad=False, device=self.device).reshape(self.shape)
- return (prod(self.shape) - idx.max() - 1).cast(dtypes.int32)
+ if axis is None: return self.flatten().argmax(0)
  axis = self._resolve_dim(axis)
  m = self == self.max(axis=axis, keepdim=True)
  idx = m * Tensor.arange(self.shape[axis]-1,-1,-1, requires_grad=False, device=self.device).reshape(self.shape[axis], *[1]*(self.ndim-axis-1))
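The argmax refactor above keeps the usual semantics for `axis=None`: the result is the index of the maximum in the flattened tensor, now computed as `self.flatten().argmax(0)`. A short check, assuming tinygrad 0.9.1:

```python
from tinygrad import Tensor

t = Tensor([[1, 3], [2, 0]])
print(t.argmax().numpy())             # 1, the flat index of the maximum
print(t.flatten().argmax(0).numpy())  # 1, the equivalent explicit form
print(t.argmax(axis=1).numpy())       # [1 0]
```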
@@ -1511,12 +1562,13 @@ class Tensor:
  """
  xs:Tuple[Tensor] = argfix(*raw_xs)
  formula = formula.replace(" ", "")
- inputs_str, output = formula.split("->") if "->" in formula else (formula, sorted(formula))
- inputs = [x for x in cast(str,inputs_str).split(',')]
+ inputs_str, output = formula.split("->") if "->" in formula else (formula, \
+ ''.join(c for c in sorted(formula) if formula.count(c) == 1 and c.isalpha()))
+ inputs = inputs_str.split(',')
  assert len(xs) == len(inputs), f"number of inputs doesn't match number of operands in formula, expected {len(inputs)}, got {len(xs)}"

  # map the value of each letter in the formula
- letter_val = sorted(merge_dicts([{letter:dim for letter, dim in zip(letters, tensor.shape)} for letters, tensor in zip(inputs, xs)]).items())
+ letter_val = sorted(merge_dicts([dict(zip(letters, tensor.shape)) for letters, tensor in zip(inputs, xs)]).items())

  xs_:List[Tensor] = []
  lhs = [sorted(enumerate(s), key=lambda e:e[1]) for s in inputs]
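The einsum change above only affects formulas without an explicit `->`: the implicit output is now built from the letters that occur exactly once in the formula, in sorted order, so repeated indices are contracted. A short example, assuming tinygrad 0.9.1:

```python
from tinygrad import Tensor

a, b = Tensor.ones(2, 3), Tensor.ones(3, 4)
print(Tensor.einsum("ij,jk", a, b).shape)      # (2, 4): 'j' repeats, so the implicit output is "ik"
print(Tensor.einsum("ij,jk->ik", a, b).shape)  # (2, 4): same as the explicit form
```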
@@ -1540,19 +1592,18 @@ class Tensor:
  s_, d_ = make_pair(stride, len(k_)), make_pair(dilation, len(k_))
  assert len(k_) == len(s_) == len(d_), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}"
  noop_, i_ = [None] * len(self.shape[:-len(k_)]), self.shape[-len(k_):]
+ o_ = [math.ceil((i - d * (k-1))/s) for i,d,k,s in zip(i_, d_, k_, s_)]
  if any(k > s for k,s in zip(k_, s_)) or any(d != 1 for d in d_):
- o_ = [(i - d * (k-1) - 1)//s + 1 for i,d,k,s in zip(i_, d_, k_, s_)]
  # repeats such that we don't need padding
  xup = self.repeat([1]*len(noop_) + [math.ceil(k*(i+d) / i) for k,i,d in zip(k_, i_, d_)])
- # slice by dilation
+ # handle dilation
  xup = xup.shrink(tuple(noop_ + [(0,k*(i+d)) for k,i,d in zip(k_, i_, d_)])).reshape(noop_ + flatten((k,i+d) for k,i,d in zip(k_, i_, d_)))
  # handle stride
  xup = xup.shrink(noop_ + flatten(((0,k), (0,o*s)) for k,o,s in zip(k_, o_, s_))).reshape(noop_ + flatten((k,o,s) for k,o,s in zip(k_, o_, s_)))
  xup = xup.shrink(noop_ + flatten(((0,k), (0,o), (0,1)) for k,o in zip(k_, o_))).reshape(noop_ + flatten((k,o) for k,o in zip(k_, o_)))
  # permute to move reduce to the end
  return xup.permute(*range(len(noop_)), *[len(noop_)+i*2+1 for i in range(len(i_))], *[len(noop_)+i*2 for i in range(len(i_))])
- # TODO: once the shapetracker can optimize well, remove this alternative implementation. or not if the CPU implementation doesn't use ShapeTracker
- o_ = [(i+(s-k))//s for i,s,k in zip(i_, s_, k_)]
+ # TODO: once the shapetracker can optimize well, remove this alternative implementation
  xup = self.pad(tuple(noop_ + [(0, max(0,o*s-i)) for i,o,s in zip(i_, o_, s_)])).shrink(tuple(noop_ + [(0,o*s) for o,s in zip(o_, s_)]))
  xup = xup.reshape(noop_ + flatten(((o,s) for o,s in zip(o_, s_))))
  xup = xup.shrink(noop_ + flatten(((0,o), (0,k)) for o,k in zip(o_, k_)))
@@ -1572,8 +1623,9 @@ class Tensor:
  print(t.avg_pool2d().numpy())
  ```
  """
- return self._pool(
- make_pair(kernel_size), stride if stride is not None else kernel_size, dilation).mean(axis=tuple(range(0-len(make_pair(kernel_size)), 0)))
+ kernel_size = make_pair(kernel_size)
+ return self._pool(kernel_size, stride if stride is not None else kernel_size, dilation).mean(axis=tuple(range(-len(kernel_size), 0)))
+
  def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1):
  """
  Applies max pooling over a tensor.
@@ -1587,8 +1639,8 @@ class Tensor:
  print(t.max_pool2d().numpy())
  ```
  """
- return self._pool(
- make_pair(kernel_size), stride if stride is not None else kernel_size, dilation).max(axis=tuple(range(0-len(make_pair(kernel_size)), 0)))
+ kernel_size = make_pair(kernel_size)
+ return self._pool(kernel_size, stride if stride is not None else kernel_size, dilation).max(axis=tuple(range(-len(kernel_size), 0)))

  def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, acc_dtype:Optional[DType]=None) -> Tensor:
  """
@@ -1665,22 +1717,24 @@ class Tensor:
  print(t.conv_transpose2d(w).numpy())
  ```
  """
- HW, trailing = weight.shape[2:], list(range(3, len(weight.shape)+1))
- x, w = self, weight.unflatten(0, (groups, -1)).permute(0,2,1,*trailing).flip(trailing)
- stride = make_pair(stride, len(HW))
+ x, w = self, weight.unflatten(0, (groups, -1)).transpose(1, 2).flip(*range(3, len(weight.shape)+1))
+ HW = weight.shape[2:]
+ stride, dilation, padding, output_padding = [make_pair(x, len(HW)) for x in (stride, dilation, padding, output_padding)]
  if any(s>1 for s in stride):
+ # handle strides: (k) -> reshape -> (k,1) -> pad -> (k,s) -> reshape -> (k*s) -> shrink (k-(s-1))
  x = x.reshape(None, None, *flatten((k,1) for k in x.shape[2:]))
  x = x.pad((None, None, *flatten((None,(0,s-1)) for s in stride)))
  x = x.reshape(None, None, *[k*s for k,s in zip(x.shape[2::2], stride)])
  x = x.shrink((None, None, *[(0,k-(s-1)) for k,s in zip(x.shape[2:], stride)]))
- padding = flatten((((k-1)*d-p,(k-1)*d-p+op) for k,d,p,op in reversed(list(
- zip(HW, make_pair(dilation, len(HW)), make_pair(padding, len(HW)), make_pair(output_padding, len(HW)))))))
+ padding = flatten((((k-1)*d-p,(k-1)*d-p+op) for k,d,p,op in reversed(list(zip(HW, dilation, padding, output_padding)))))
  return x.conv2d(w.flatten(end_dim=1), groups=groups, bias=bias, dilation=dilation, padding=padding)

  def dot(self, w:Tensor, acc_dtype:Optional[DType]=None) -> Tensor:
  """
  Performs dot product between two tensors.

+ You can pass in the optional `acc_dtype` keyword argument to control the data type of the accumulation.
+
  ```python exec="true" source="above" session="tensor" result="python"
  a = Tensor([[1, 2], [3, 4]])
  b = Tensor([[5, 6], [7, 8]])
@@ -1692,7 +1746,7 @@ class Tensor:
  assert (L:=self.shape[-1]) == (R:=w.shape[-min(n2, 2)]), f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({L} != {R})"
  x = self.reshape(*self.shape[0:-1], *[1]*min(n1-1, n2-1, 1), self.shape[-1])
  w = w.reshape(*w.shape[0:-2], *[1]*min(n1-1, n2-1, 1), *w.shape[-min(n2, 2):]).transpose(-1, -min(n2, 2))
- return (x*w).sum(-1, acc_dtype=acc_dtype).cast(least_upper_dtype(x.dtype, w.dtype))
+ return (x*w).sum(-1, acc_dtype=acc_dtype).cast(least_upper_dtype(x.dtype, w.dtype) if acc_dtype is None else acc_dtype)

  def matmul(self, x:Tensor, reverse=False, acc_dtype:Optional[DType]=None) -> Tensor:
  """
@@ -1710,8 +1764,9 @@ class Tensor:
  return x.dot(self, acc_dtype=acc_dtype) if reverse else self.dot(x, acc_dtype=acc_dtype)

  def _cumsum(self, axis:int=0, _first_zero=False) -> Tensor:
- pl_sz = self.shape[axis] - int(not _first_zero and self.shape[axis] != 0)
- return self.transpose(axis,-1).pad2d((pl_sz,0))._pool((self.shape[axis] or 1,)).sum(-1).transpose(axis,-1)
+ assert self.shape[axis] != 0
+ pl_sz = self.shape[axis] - int(not _first_zero)
+ return self.transpose(axis,-1).pad2d((pl_sz,-int(_first_zero)))._pool((self.shape[axis],)).sum(-1).transpose(axis,-1)
  def cumsum(self, axis:int=0) -> Tensor:
  """
  Computes the cumulative sum of the tensor along the specified axis.
@@ -1726,48 +1781,74 @@ class Tensor:
  print(t.cumsum(1).numpy())
  ```
  """
+ axis = self._resolve_dim(axis)
+ if self.ndim == 0 or 0 in self.shape: return self
  # TODO: someday the optimizer will find this on it's own
  # for now this is a two stage cumsum
  SPLIT = 256
  if self.shape[axis] <= SPLIT*2: return self._cumsum(axis)
  ret = self.transpose(axis,-1).pad2d((round_up(self.shape[axis], SPLIT)-self.shape[axis], 0))
  ret = ret.unflatten(-1, (-1, SPLIT))._cumsum(-1)
- base_add = ret[..., -1]._cumsum(-1, _first_zero=True)[..., :-1]
+ base_add = ret[..., -1]._cumsum(-1, _first_zero=True)
  base_add = base_add.unsqueeze(-1).expand(*base_add.shape, ret.shape[-1])
  def fix(x:Tensor): return x.flatten(start_dim=-2)[..., -self.shape[axis]:].transpose(axis,-1)
  return fix(ret) + fix(base_add)

  @staticmethod
- def _tri(r:sint, c:sint, k:int=0, **kwargs) -> Tensor:
- assert all_int((r,c)), "does not support symbolic"
- if r == 0: return Tensor.zeros((r, c), **kwargs)
- return Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(-k, c-k, **kwargs).unsqueeze(0).expand(r,c)
- def triu(self, k:int=0) -> Tensor:
+ def _tri(r:sint, c:sint, diagonal:int=0, **kwargs) -> Tensor:
+ assert isinstance(r, int) and isinstance(c, int), f"does not support symbolic, getting {r=}, {c=}"
+ if r == 0 or c == 0 or diagonal >= c: return Tensor.zeros(r,c,**kwargs)
+ if r+diagonal <= 0: return Tensor.ones(r,c,**kwargs)
+ s = r+c-1
+ # build a (s, s) upper triangle
+ t = Tensor.ones(s,s,**kwargs).pad((None,(0,s))).flatten().shrink(((0,s*(2*s-1)),)).reshape(s,-1).shrink((None,(0,s)))
+ return t[:r,-diagonal:c-diagonal] if diagonal <= 0 else t[diagonal:r+diagonal,:c]
+
+ def triu(self, diagonal:int=0) -> Tensor:
  """
  Returns the upper triangular part of the tensor, the other elements are set to 0.

+ The argument `diagonal` determines which diagonal is on the boundary. `diagonal = 0` means the main diagonal.
+ Positive `diagonal` means above the main diagonal, and negative `diagonal` means below the main diagonal.
+
  ```python exec="true" source="above" session="tensor" result="python"
- t = Tensor([[1, 2, 3], [4, 5, 6]])
+ t = Tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
- print(t.triu(k=1).numpy())
+ print(t.triu(diagonal=0).numpy())
+ ```
+ ```python exec="true" source="above" session="tensor" result="python"
+ print(t.triu(diagonal=1).numpy())
+ ```
+ ```python exec="true" source="above" session="tensor" result="python"
+ print(t.triu(diagonal=-1).numpy())
  ```
  """
- return Tensor._tri(self.shape[-2], self.shape[-1], k=k, device=self.device).where(self, 0)
- def tril(self, k:int=0) -> Tensor:
+ return Tensor._tri(self.shape[-2], self.shape[-1], diagonal=diagonal, device=self.device, dtype=dtypes.bool).where(self, 0).cast(self.dtype)
+
+ def tril(self, diagonal:int=0) -> Tensor:
  """
  Returns the lower triangular part of the tensor, the other elements are set to 0.

+ The argument `diagonal` determines which diagonal is on the boundary. `diagonal = 0` means the main diagonal.
+ Positive `diagonal` means above the main diagonal, and negative `diagonal` means below the main diagonal.
+
  ```python exec="true" source="above" session="tensor" result="python"
- t = Tensor([[1, 2, 3], [4, 5, 6]])
+ t = Tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
- print(t.tril().numpy())
+ print(t.tril(diagonal=0).numpy())
+ ```
+ ```python exec="true" source="above" session="tensor" result="python"
+ print(t.tril(diagonal=1).numpy())
+ ```
+ ```python exec="true" source="above" session="tensor" result="python"
+ print(t.tril(diagonal=-1).numpy())
  ```
  """
- return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, device=self.device).where(0, self)
+ return Tensor._tri(self.shape[-2], self.shape[-1], diagonal=diagonal+1, device=self.device, dtype=dtypes.bool).where(0, self).cast(self.dtype)

  # ***** unary ops *****

@@ -1779,7 +1860,7 @@ class Tensor:
  print(Tensor([False, True]).logical_not().numpy())
  ```
  """
- return F.Eq.apply(*self._broadcasted(False))
+ return F.Neq.apply(*self.cast(dtypes.bool)._broadcasted(True))
  def neg(self):
  """
  Negates the tensor element-wise.
@@ -2179,7 +2260,7 @@ class Tensor:
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).gelu().numpy())
  ```
  """
- return 0.5 * self * (1 + (self * 0.7978845608 * (1 + 0.044715 * self * self)).tanh())
+ return 0.5 * self * (1 + (math.sqrt(2 / math.pi) * (self + 0.044715 * self ** 3)).tanh())
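The gelu rewrite only replaces the inlined constant with its closed form; `math.sqrt(2 / math.pi)` is the 0.7978845608 used before, so the tanh approximation is essentially unchanged. A one-line check:

```python
import math
print(math.sqrt(2 / math.pi))  # 0.7978845608028654, matching the previous inline constant
```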

  def quick_gelu(self):
  """
@@ -2246,23 +2327,26 @@ class Tensor:
  return self / (1 + self.abs())

  # ***** broadcasted elementwise ops *****
- def _broadcast_to(self, shape:Tuple[sint, ...]):
- reshape_arg, _ = _pad_left(self.shape, shape)
- if self.ndim > len(shape) or not all(sh in {s,1} or (s==0 and sh==1) for sh,s in zip(reshape_arg, shape)):
- raise ValueError(f"cannot broadcast tensor with shape={self.shape} to {shape=}")
- return F.Expand.apply(self.reshape(reshape_arg), shape=shape) if shape != self.shape else self
-
- def _broadcasted(self, y:Union[Tensor, ConstType], reverse:bool=False, match_dtype:bool=True) -> Tuple[Tensor, Tensor]:
+ def _broadcast_to(self, shape:Tuple[sint, ...]) -> Tensor:
+ if self.shape == shape: return self
+ if self.ndim > len(shape): raise ValueError(f"cannot broadcast tensor to fewer dimensions. shape={self.shape} to {shape=}")
+ # first pad left with 1s https://data-apis.org/array-api/latest/API_specification/broadcasting.html
+ padded, _ = _pad_left(self.shape, shape)
+ # for each dimension, check either from_ is 1, or it does not change
+ if any(from_ != 1 and from_ != to for from_,to in zip(padded, shape)): raise ValueError(f"cannot broadcast from shape={self.shape} to {shape=}")
+ return F.Expand.apply(self.reshape(padded), shape=shape)
+
+ def _broadcasted(self, y:Union[Tensor, Node, ConstType], reverse:bool=False, match_dtype:bool=True) -> Tuple[Tensor, Tensor]:
  x: Tensor = self
  if not isinstance(y, Tensor):
  # make y a Tensor
  assert isinstance(y, (float, int, bool, Node)), f"{type(y)=}, {y=}"
- if isinstance(self.dtype, ImageDType) or dtypes.is_float(x.dtype) or (dtypes.is_int(x.dtype) and isinstance(y, int)): y_dtype = x.dtype
- else: y_dtype = dtypes.from_py(y)
- if isinstance(y, Node): y = Tensor.from_node(y, device=self.device)
- else: y = Tensor(dtypes.as_const(y, y_dtype), self.device, y_dtype, requires_grad=False)
+ if isinstance(x.dtype, ImageDType) or dtypes.is_float(x.dtype) or (dtypes.is_int(x.dtype) and isinstance(y, int)): y_dtype = x.dtype
+ elif not isinstance(y, Node): y_dtype = dtypes.from_py(y)
+ if isinstance(y, Node): y = Tensor.from_node(y, device=x.device)
+ else: y = Tensor(dtypes.as_const(y, y_dtype), x.device, y_dtype, requires_grad=False)

- if match_dtype:
+ if match_dtype and x.dtype != y.dtype:
  output_dtype = least_upper_dtype(x.dtype, y.dtype)
  x, y = x.cast(output_dtype), y.cast(output_dtype)

@@ -2273,7 +2357,6 @@ class Tensor:
  return x._broadcast_to(out_shape), y._broadcast_to(out_shape)

  def _to_const_val(self, x:Union[Tensor, ConstType]) -> Union[Tensor, ConstType]:
- # TODO: update with multi
  return x.lazydata.base.arg if isinstance(x, Tensor) and isinstance(x.lazydata, LazyBuffer) and x.lazydata.is_unrealized_unmasked_const() \
  and not x.requires_grad and self._broadcasted(x)[0].shape == self.shape else x

@@ -2315,7 +2398,8 @@ class Tensor:
  print(t.sub(Tensor([[2.0], [3.5]])).numpy())
  ```
  """
- return F.Sub.apply(*self._broadcasted(x, reverse))
+ a, b = self._broadcasted(x, reverse)
+ return a + (-b)

  def mul(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
  """
@@ -2528,8 +2612,8 @@ class Tensor:
  def __gt__(self, x) -> Tensor: return F.Less.apply(*self._broadcasted(x, True))
  def __ge__(self, x) -> Tensor: return (self<x).logical_not()
  def __le__(self, x) -> Tensor: return (self>x).logical_not()
- def __eq__(self, x) -> Tensor: return F.Eq.apply(*self._broadcasted(x, True)) # type: ignore[override]
- def __ne__(self, x) -> Tensor: return (self==x).logical_not() # type: ignore[override]
+ def __ne__(self, x) -> Tensor: return F.Neq.apply(*self._broadcasted(x)) # type: ignore[override]
+ def __eq__(self, x) -> Tensor: return (self!=x).logical_not() # type: ignore[override]

  # ***** functional nn ops *****

@@ -2652,8 +2736,8 @@ class Tensor:
  assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
  if is_causal: attn_mask = Tensor.ones(self.shape[-2], key.shape[-2], requires_grad=False, device=self.device).tril(0).cast(dtypes.bool)
  if attn_mask is not None and attn_mask.dtype == dtypes.bool: attn_mask = (attn_mask == 0).where(-float("inf"), 0)
- qk = self @ key.transpose(-2,-1) / math.sqrt(self.shape[-1])
- return ((qk+attn_mask) if attn_mask is not None else qk).softmax(-1).dropout(dropout_p) @ value
+ qk = self.matmul(key.transpose(-2,-1), acc_dtype=least_upper_dtype(self.dtype, key.dtype, dtypes.float32)) / math.sqrt(self.shape[-1])
+ return ((qk+attn_mask) if attn_mask is not None else qk).softmax(-1).cast(self.dtype).dropout(dropout_p) @ value

  def binary_crossentropy(self, y:Tensor) -> Tensor:
  """
@@ -2874,5 +2958,5 @@ def custom_random(out:Buffer):
  Tensor._seed += 1
  rng = np.random.default_rng(Tensor._seed)
  if out.dtype == dtypes.half: rng_np_buffer = (rng.integers(low=0, high=2047, size=out.size) / 2048).astype(np.half, copy=False)
- else: rng_np_buffer = rng.random(size=out.size, dtype=np.float32).astype(dtype=out.dtype.np, copy=False)
+ else: rng_np_buffer = rng.random(size=out.size, dtype=np.float32).astype(dtype=_to_np_dtype(out.dtype), copy=False)
  out.copyin(rng_np_buffer.data)