tinygrad 0.9.1__py3-none-any.whl → 0.9.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. tinygrad/codegen/kernel.py +248 -115
  2. tinygrad/codegen/lowerer.py +215 -0
  3. tinygrad/codegen/transcendental.py +310 -0
  4. tinygrad/codegen/uopgraph.py +622 -0
  5. tinygrad/codegen/uops.py +235 -393
  6. tinygrad/device.py +428 -69
  7. tinygrad/dtype.py +18 -4
  8. tinygrad/engine/graph.py +19 -32
  9. tinygrad/engine/jit.py +148 -70
  10. tinygrad/engine/realize.py +127 -51
  11. tinygrad/engine/schedule.py +259 -216
  12. tinygrad/engine/search.py +29 -22
  13. tinygrad/function.py +9 -0
  14. tinygrad/helpers.py +87 -49
  15. tinygrad/lazy.py +34 -35
  16. tinygrad/multi.py +41 -36
  17. tinygrad/nn/__init__.py +39 -22
  18. tinygrad/nn/state.py +3 -3
  19. tinygrad/ops.py +63 -62
  20. tinygrad/renderer/__init__.py +43 -21
  21. tinygrad/renderer/assembly.py +104 -106
  22. tinygrad/renderer/cstyle.py +87 -60
  23. tinygrad/renderer/llvmir.py +21 -30
  24. tinygrad/runtime/autogen/amd_gpu.py +25208 -5753
  25. tinygrad/runtime/autogen/cuda.py +6 -162
  26. tinygrad/runtime/autogen/kfd.py +32 -0
  27. tinygrad/runtime/autogen/libc.py +4260 -0
  28. tinygrad/runtime/autogen/nvrtc.py +579 -0
  29. tinygrad/runtime/graph/clang.py +2 -2
  30. tinygrad/runtime/graph/cuda.py +8 -11
  31. tinygrad/runtime/graph/hcq.py +120 -107
  32. tinygrad/runtime/graph/metal.py +18 -15
  33. tinygrad/runtime/ops_amd.py +197 -305
  34. tinygrad/runtime/ops_clang.py +2 -2
  35. tinygrad/runtime/ops_cuda.py +36 -94
  36. tinygrad/runtime/ops_disk.py +3 -7
  37. tinygrad/runtime/ops_gpu.py +4 -2
  38. tinygrad/runtime/ops_hip.py +70 -0
  39. tinygrad/runtime/ops_metal.py +38 -27
  40. tinygrad/runtime/ops_nv.py +283 -363
  41. tinygrad/runtime/ops_python.py +26 -30
  42. tinygrad/runtime/support/compiler_cuda.py +78 -0
  43. tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +15 -1
  44. tinygrad/runtime/support/elf.py +38 -0
  45. tinygrad/shape/shapetracker.py +5 -14
  46. tinygrad/shape/symbolic.py +4 -8
  47. tinygrad/shape/view.py +34 -22
  48. tinygrad/tensor.py +399 -97
  49. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/METADATA +49 -48
  50. tinygrad-0.9.2.dist-info/RECORD +70 -0
  51. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/WHEEL +1 -1
  52. tinygrad/codegen/linearizer.py +0 -528
  53. tinygrad-0.9.1.dist-info/RECORD +0 -63
  54. /tinygrad/runtime/{driver → support}/__init__.py +0 -0
  55. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/LICENSE +0 -0
  56. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/top_level.txt +0 -0
tinygrad/tensor.py CHANGED
@@ -1,37 +1,39 @@
1
1
  # inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
2
2
  from __future__ import annotations
3
- import time, math, itertools, functools, struct
3
+ import dataclasses
4
+ import time, math, itertools, functools, struct, sys, inspect
4
5
  from contextlib import ContextDecorator
5
6
  from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Dict, DefaultDict, cast, get_args, Set
6
7
  from collections import defaultdict
7
8
  import numpy as np
8
9
 
9
- from tinygrad.dtype import DType, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype
10
+ from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype
10
11
  from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, get_shape, fully_flatten, dedup
11
- from tinygrad.helpers import IMAGE, DEBUG, WINO, THREEFRY
12
+ from tinygrad.helpers import IMAGE, DEBUG, WINO, THREEFRY, _METADATA, Metadata, TRACEMETA
12
13
  from tinygrad.lazy import LazyBuffer
13
14
  from tinygrad.multi import MultiLazyBuffer
14
- from tinygrad.ops import LoadOps, truncate
15
+ from tinygrad.ops import MetaOps, truncate
15
16
  from tinygrad.device import Device, Buffer, BufferOptions
16
17
  from tinygrad.shape.symbolic import sint, Variable, MulNode, SumNode, NumNode, Node
17
- from tinygrad.engine.realize import run_schedule
18
- from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars, memory_planner
18
+ from tinygrad.engine.realize import run_schedule, memory_planner
19
+ from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars
19
20
 
20
21
  # **** start with two base classes, Tensor and Function ****
21
22
 
22
23
  class Function:
23
- def __init__(self, device:Union[str, Tuple[str, ...]], *tensors:Tensor):
24
+ def __init__(self, device:Union[str, Tuple[str, ...]], *tensors:Tensor, metadata:Optional[Metadata]=None):
24
25
  self.device = device
25
26
  self.needs_input_grad = [t.requires_grad for t in tensors]
26
27
  self.requires_grad = True if any(self.needs_input_grad) else None if None in self.needs_input_grad else False
27
28
  if self.requires_grad: self.parents = tensors
29
+ self.metadata = metadata
28
30
 
29
31
  def forward(self, *args, **kwargs): raise NotImplementedError(f"forward not implemented for {type(self)}")
30
32
  def backward(self, *args, **kwargs): raise RuntimeError(f"backward not implemented for {type(self)}")
31
33
 
32
34
  @classmethod
33
35
  def apply(fxn:Type[Function], *x:Tensor, **kwargs) -> Tensor:
34
- ctx = fxn(x[0].device, *x)
36
+ ctx = fxn(x[0].device, *x, metadata=_METADATA.get())
35
37
  ret = Tensor.__new__(Tensor)
36
38
  ret.lazydata, ret.requires_grad, ret.grad = ctx.forward(*[t.lazydata for t in x], **kwargs), ctx.requires_grad, None
37
39
  ret._ctx = ctx if ctx.requires_grad and not Tensor.no_grad else None # used by autograd engine
@@ -39,24 +41,24 @@ class Function:
39
41
 
40
42
  import tinygrad.function as F
41
43
 
42
- def _loadop(op, shape:Tuple[sint,...], dtype:DType, device:Union[str, Tuple[str, ...]], arg=None, src:Tuple[LazyBuffer, ...]=()):
43
- if isinstance(device, str): return LazyBuffer.loadop(op, shape, dtype, device, arg, src)
44
- return MultiLazyBuffer([LazyBuffer.loadop(op, shape, dtype, d, arg, src) for d in device], None)
44
+ def _metaop(op, shape:Tuple[sint,...], dtype:DType, device:Union[str, Tuple[str, ...]], arg=None, src:Tuple[LazyBuffer, ...]=()):
45
+ if isinstance(device, str): return LazyBuffer.metaop(op, shape, dtype, device, arg, src)
46
+ return MultiLazyBuffer([LazyBuffer.metaop(op, shape, dtype, d, arg, src) for d in device], None)
45
47
 
46
- def _from_np_dtype(npdtype:type) -> DType: return dtypes.fields()[np.dtype(npdtype).name]
48
+ def _from_np_dtype(npdtype:np.dtype) -> DType: return dtypes.fields()[np.dtype(npdtype).name]
47
49
  def _to_np_dtype(dtype:DType) -> Optional[type]: return np.dtype(dtype.fmt).type if dtype.fmt is not None else None
48
50
 
49
51
  def _fromnp(x: np.ndarray) -> LazyBuffer:
50
- ret = LazyBuffer.loadop(LoadOps.EMPTY, x.shape, _from_np_dtype(x.dtype), "NPY")
52
+ ret = LazyBuffer.metaop(MetaOps.EMPTY, x.shape, _from_np_dtype(x.dtype), "NPY")
51
53
  # fake realize
52
54
  ret.buffer.allocate(x)
53
55
  del ret.srcs
54
56
  return ret
55
57
 
56
58
  def _frompy(x:Union[List, Tuple, bytes], dtype:DType) -> LazyBuffer:
57
- if isinstance(x, bytes): ret, data = LazyBuffer.loadop(LoadOps.EMPTY, (len(x),), dtype, "PYTHON"), x
59
+ if isinstance(x, bytes): ret, data = LazyBuffer.metaop(MetaOps.EMPTY, (len(x)//dtype.itemsize,), dtype, "PYTHON"), x
58
60
  else:
59
- ret = LazyBuffer.loadop(LoadOps.EMPTY, get_shape(x), dtype, "PYTHON")
61
+ ret = LazyBuffer.metaop(MetaOps.EMPTY, get_shape(x), dtype, "PYTHON")
60
62
  assert dtype.fmt is not None, f"{dtype=} has None fmt"
61
63
  truncate_function = truncate[dtype]
62
64
  data = struct.pack(f"@{ret.size}{dtype.fmt}", *[truncate_function(xi) for xi in fully_flatten(x)])
@@ -85,7 +87,7 @@ def _pad_left(*shapes:Tuple[sint, ...]) -> Tuple[Tuple[sint, ...], ...]:
85
87
  max_dim = max(len(shape) for shape in shapes)
86
88
  return tuple((1,) * (max_dim - len(shape)) + shape for shape in shapes)
87
89
  def _broadcast_shape(*shapes:Tuple[sint, ...]) -> Tuple[sint, ...]:
88
- return tuple(0 if any(size == 0 for size in nth_dim_sizes) else max(nth_dim_sizes) for nth_dim_sizes in zip(*_pad_left(*shapes)))
90
+ return tuple(0 if 0 in nth_dim_sizes else max(nth_dim_sizes) for nth_dim_sizes in zip(*_pad_left(*shapes)))
89
91
 
90
92
  class Tensor:
91
93
  """
@@ -104,7 +106,8 @@ class Tensor:
104
106
  no_grad: ClassVar[bool] = False
105
107
 
106
108
  def __init__(self, data:Union[None, ConstType, List, Tuple, LazyBuffer, np.ndarray, bytes, MultiLazyBuffer, Variable],
107
- device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
109
+ device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DTypeLike]=None, requires_grad:Optional[bool]=None):
110
+ if dtype is not None: dtype = to_dtype(dtype)
108
111
  assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
109
112
  device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device)
110
113
 
@@ -120,18 +123,18 @@ class Tensor:
120
123
 
121
124
  # create a LazyBuffer from the different types of inputs
122
125
  if isinstance(data, LazyBuffer): assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
123
- elif isinstance(data, get_args(ConstType)): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_py(data), device, data)
124
- elif isinstance(data, Variable): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_py(data.unbind()[1]), device, data)
125
- elif isinstance(data, bytes): data = _frompy(data, dtypes.uint8)
126
+ elif isinstance(data, get_args(ConstType)): data = _metaop(MetaOps.CONST, tuple(), dtype or dtypes.from_py(data), device, data)
127
+ elif isinstance(data, Variable): data = _metaop(MetaOps.CONST, tuple(), dtype or dtypes.from_py(data.unbind()[1]), device, data)
128
+ elif isinstance(data, bytes): data = _frompy(data, dtypes.uint8 if dtype is None else dtype)
126
129
  elif isinstance(data, (list, tuple)):
127
130
  if dtype is None:
128
131
  if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtypes.bool
129
132
  else: dtype = dtypes.default_int if d and all_int(d) else dtypes.default_float
130
133
  if dtype == dtypes.bfloat16: data = Tensor(_fromnp(np.array(data, np.float32)), device=device).cast(dtypes.bfloat16).lazydata
131
134
  else: data = _fromnp(np.array(data).astype(_to_np_dtype(dtype)))
132
- elif data is None: data = _loadop(LoadOps.EMPTY, (0,), dtype or dtypes.default_float, device)
135
+ elif data is None: data = _metaop(MetaOps.EMPTY, (0,), dtype or dtypes.default_float, device)
133
136
  elif isinstance(data, np.ndarray):
134
- if data.shape == (): data = _loadop(LoadOps.CONST, tuple(), dtype or _from_np_dtype(data.dtype), device, data.item())
137
+ if data.shape == (): data = _metaop(MetaOps.CONST, tuple(), dtype or _from_np_dtype(data.dtype), device, data.item())
135
138
  else: data = _fromnp(data.astype(npdtype) if dtype is not None and (npdtype:=_to_np_dtype(dtype)) is not None else data)
136
139
 
137
140
  # by this point, it has to be a LazyBuffer
@@ -145,7 +148,7 @@ class Tensor:
145
148
  assert data.device == device, f"MultiLazyBuffer device mismatch, {data.device} != {device}"
146
149
  self.lazydata: Union[LazyBuffer, MultiLazyBuffer] = data
147
150
  else:
148
- self.lazydata = MultiLazyBuffer.from_sharded(data, device, None)
151
+ self.lazydata = MultiLazyBuffer.from_sharded(data, device, None, None)
149
152
  else:
150
153
  self.lazydata = data if data.device == device else data.copy_to_device(device)
151
154
 
@@ -318,20 +321,34 @@ class Tensor:
318
321
  if self.grad is not None and real.grad is not None: self.grad.lazydata = real.grad.lazydata
319
322
  self.lazydata = real.lazydata
320
323
 
321
- def shard(self, devices:Tuple[str, ...], axis:Optional[int]=None) -> Tensor:
324
+ def shard(self, devices:Tuple[str, ...], axis:Optional[int]=None, splits:Optional[Tuple[int, ...]]=None) -> Tensor:
322
325
  """
323
- Shards the tensor across the given devices.
326
+ Shards the tensor across the given devices. Optionally specify which axis to shard on, and how to split it across devices.
327
+
328
+ ```python exec="true" source="above" session="tensor" result="python"
329
+ t = Tensor.empty(2, 3)
330
+ print(t.shard((t.device, t.device), axis=1, splits=(2, 1)).lazydata)
331
+ ```
332
+
324
333
  """
325
334
  assert isinstance(self.lazydata, LazyBuffer), "can't shard a MultiLazyBuffer"
326
- canonical_devices = tuple(Device.canonicalize(x) for x in devices)
327
- if axis is not None and axis < 0: axis += len(self.shape)
328
- return Tensor(MultiLazyBuffer.from_sharded(self.lazydata, canonical_devices, axis), device=canonical_devices, requires_grad=self.requires_grad)
329
-
330
- def shard_(self, devices:Tuple[str, ...], axis:Optional[int]=None):
335
+ canonical_devices, bounds = tuple(Device.canonicalize(x) for x in devices), None
336
+ if axis is not None:
337
+ if axis < 0: axis += len(self.shape)
338
+ if splits is None:
339
+ sz = round_up(self.shape[axis], len(devices)) // len(devices)
340
+ splits = tuple([max(0, min(sz, self.shape[axis] - sz*i)) for i in range(len(devices))])
341
+ assert sum(splits) == self.shape[axis], "specified splits do not sum up to axis shape"
342
+ boundaries = tuple(itertools.accumulate(splits))
343
+ bounds = tuple(zip((0,) + boundaries, boundaries))
344
+ return Tensor(MultiLazyBuffer.from_sharded(self.lazydata, canonical_devices, axis, bounds),
345
+ device=canonical_devices, requires_grad=self.requires_grad)
346
+
347
+ def shard_(self, devices:Tuple[str, ...], axis:Optional[int]=None, splits:Optional[Tuple[int, ...]]=None):
331
348
  """
332
349
  Shards the tensor across the given devices in place.
333
350
  """
334
- self.lazydata = self.shard(devices, axis).lazydata
351
+ self.lazydata = self.shard(devices, axis, splits).lazydata
335
352
  return self
336
353
 
337
354
  @staticmethod
@@ -342,14 +359,14 @@ class Tensor:
342
359
  if isinstance(y, SumNode): return Tensor.from_node(y.nodes[0], **kwargs) + sum(y.nodes[1:])
343
360
  raise RuntimeError(f"unhandled Node {y}")
344
361
 
345
- # ***** creation llop entrypoint *****
362
+ # ***** creation entrypoint *****
346
363
 
347
364
  @staticmethod
348
- def _loadop(op, shape, device:Optional[Union[Tuple[str, ...], str]]=None, dtype:Optional[DType]=None, arg=None, **kwargs):
365
+ def _metaop(op, shape, device:Optional[Union[Tuple[str, ...], str]]=None, dtype:Optional[DTypeLike]=None, arg=None, **kwargs):
349
366
  if isinstance(device, tuple):
350
- return Tensor(MultiLazyBuffer([LazyBuffer.loadop(op, shape, dtype or dtypes.default_float, Device.canonicalize(d), arg) \
367
+ return Tensor(MultiLazyBuffer([LazyBuffer.metaop(op, shape, dtype or dtypes.default_float, Device.canonicalize(d), arg) \
351
368
  for d in device], None), device, dtype, **kwargs)
352
- return Tensor(LazyBuffer.loadop(op, shape, dtype or dtypes.default_float, Device.canonicalize(device), arg), device, dtype, **kwargs)
369
+ return Tensor(LazyBuffer.metaop(op, shape, dtype or dtypes.default_float, Device.canonicalize(device), arg), device, dtype, **kwargs)
353
370
 
354
371
  @staticmethod
355
372
  def empty(*shape, **kwargs):
@@ -364,7 +381,7 @@ class Tensor:
364
381
  print(t.shape)
365
382
  ```
366
383
  """
367
- return Tensor._loadop(LoadOps.EMPTY, argfix(*shape), **kwargs)
384
+ return Tensor._metaop(MetaOps.EMPTY, argfix(*shape), **kwargs)
368
385
 
369
386
  _seed: int = int(time.time())
370
387
  _rng_counter: Optional[Tensor] = None
@@ -384,10 +401,10 @@ class Tensor:
384
401
  print(Tensor.rand(5).numpy())
385
402
  ```
386
403
  """
387
- Tensor._seed, Tensor._rng_counter = seed, Tensor([0], dtype=dtypes.uint32, requires_grad=False)
404
+ Tensor._seed, Tensor._rng_counter = seed, None
388
405
 
389
406
  @staticmethod
390
- def rand(*shape, device:Optional[Union[Tuple[str, ...], str]]=None, dtype:Optional[DType]=None, **kwargs):
407
+ def rand(*shape, device:Optional[Union[Tuple[str, ...], str]]=None, dtype:Optional[DTypeLike]=None, **kwargs):
391
408
  """
392
409
  Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[0, 1)`.
393
410
 
@@ -400,27 +417,25 @@ class Tensor:
400
417
  print(t.numpy())
401
418
  ```
402
419
  """
403
- if Tensor._rng_counter is None: Tensor._rng_counter = Tensor([0], dtype=dtypes.uint32, requires_grad=False)
420
+ if (had_counter := Tensor._rng_counter is None): Tensor._rng_counter = Tensor([0], dtype=dtypes.uint32, requires_grad=False)
421
+ if not all(s >= 0 for s in argfix(*shape)): raise ValueError(f"cannot create tensor with negative dimension in {shape=}")
404
422
  if not THREEFRY.value:
405
423
  # for bfloat16, numpy rand passes buffer in float
406
- if (dtype or dtypes.default_float) == dtypes.bfloat16:
424
+ if to_dtype(dtype or dtypes.default_float) == dtypes.bfloat16:
407
425
  return Tensor.rand(*shape, **kwargs, device=device, dtype=dtypes.float).cast(dtypes.bfloat16)
408
- return Tensor._loadop(LoadOps.CUSTOM, argfix(*shape), arg=custom_random, device=device, dtype=dtype, **kwargs)
426
+ return Tensor._metaop(MetaOps.CUSTOM, argfix(*shape), arg=custom_random, device=device, dtype=dtype, **kwargs)
409
427
 
410
428
  # threefry
411
429
  if (num := prod((shape:=argfix(*shape)))) == 0: return Tensor.zeros(shape, device=device, dtype=dtype, **kwargs)
412
- counts1 = (Tensor.arange(math.ceil(num / 2), device=device, dtype=dtypes.uint32, requires_grad=False)+Tensor._rng_counter.to(device)).realize()
430
+ if not had_counter: Tensor._rng_counter.assign(Tensor._rng_counter + num)
431
+ counts1 = (Tensor.arange(math.ceil(num / 2), device=device, dtype=dtypes.uint32, requires_grad=False)+Tensor._rng_counter.to(device))
413
432
  counts2 = counts1 + math.ceil(num / 2)
414
- Tensor._rng_counter.assign(Tensor._rng_counter + num).realize()
415
433
 
416
- rotations = [[13, 15, 26, 6], [17, 29, 16, 24]]
417
- ks = [0x0, Tensor._seed ^ 0x0 ^ 0x1BD11BDA, Tensor._seed]
434
+ x = counts2.cast(dtypes.uint64) << 32 | counts1.cast(dtypes.uint64)
435
+ x = F.Threefry.apply(*x._broadcasted(Tensor._seed))
436
+ counts1, counts2 = (x & 0xffffffff).cast(dtypes.uint32), ((x >> 32) & 0xffffffff).cast(dtypes.uint32)
418
437
 
419
- x = [counts1 + ks[-1], counts2 + ks[0]]
420
- for i in range(5):
421
- for r in rotations[i % 2]: x[0], x[1] = (x0 := x[0] + x[1]), x0 ^ ((x[1] << r) + (x[1] >> (32 - r)))
422
- x = [(x[0] + ks[i % 3]), (x[1] + ks[(i + 1) % 3] + i + 1)]
423
- out = x[0].cat(x[1]).rshift(8).cast(dtypes.float32).div(2 ** 24)[:num]
438
+ out = counts1.cat(counts2).rshift(8).cast(dtypes.float32).div(2 ** 24)[:num]
424
439
  out = out.reshape(shape).cast(dtypes.default_float if dtype is None else dtype)
425
440
  out.requires_grad = kwargs.get("requires_grad")
426
441
  return out.contiguous()
@@ -506,12 +521,14 @@ class Tensor:
506
521
  if stop is None: stop, start = start, 0
507
522
  assert all(isinstance(s, (int, float)) for s in (start, stop, step)), f"symbolic arange not supported {start=}, {stop=}, {step=}"
508
523
  dtype = kwargs.pop("dtype", dtypes.default_float if any(isinstance(x, float) for x in (start, stop, step)) else dtypes.default_int)
524
+ # NOTE: this matches numpy, torch raises RuntimeError if stop-start and step have different signs
525
+ if (stop-start)/step <= 0: return Tensor([], dtype=dtype, **kwargs)
509
526
  return (Tensor.full((math.ceil((stop-start)/step),), step, dtype=dtype, **kwargs)._cumsum() + (start - step)).cast(dtype)
510
527
 
511
528
  @staticmethod
512
- def eye(dim:int, **kwargs):
529
+ def eye(n:int, m:Optional[int]=None, **kwargs):
513
530
  """
514
- Creates an identity matrix of the given dimension.
531
+ Returns a 2-D tensor with `n` rows and `m` columns, with ones on the diagonal and zeros elsewhere.
515
532
 
516
533
  You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
517
534
  Additionally, all other keyword arguments are passed to the constructor of the tensor.
@@ -519,8 +536,13 @@ class Tensor:
519
536
  ```python exec="true" source="above" session="tensor" result="python"
520
537
  print(Tensor.eye(3).numpy())
521
538
  ```
539
+
540
+ ```python exec="true" source="above" session="tensor" result="python"
541
+ print(Tensor.eye(2, 4).numpy())
542
+ ```
522
543
  """
523
- return Tensor.ones((dim,1),**kwargs).pad((None,(0,dim))).flatten().shrink(((0,dim*dim),)).reshape(dim, dim)
544
+ if n < 0 or (m is not None and m < 0): raise ValueError(f"cannot have negative {n=}, {m=}")
545
+ return Tensor.ones((n,1),**kwargs).pad((None,(0,n))).flatten().shrink(((0,n*n),)).reshape(n,n)._slice((None,(0,n if m is None else m)))
524
546
 
525
547
  def full_like(self, fill_value:ConstType, **kwargs):
526
548
  """
@@ -568,7 +590,7 @@ class Tensor:
568
590
  # ***** rng hlops *****
569
591
 
570
592
  @staticmethod
571
- def randn(*shape, dtype:Optional[DType]=None, **kwargs) -> Tensor:
593
+ def randn(*shape, dtype:Optional[DTypeLike]=None, **kwargs) -> Tensor:
572
594
  """
573
595
  Creates a tensor with the given shape, filled with random values from a normal distribution with mean `0` and standard deviation `1`.
574
596
  If `dtype` is not specified, the default type is used.
@@ -721,10 +743,10 @@ class Tensor:
721
743
  yield node
722
744
  return list(_walk(self, set()))
723
745
 
724
- def backward(self) -> Tensor:
746
+ def backward(self, gradient:Optional[Tensor]=None) -> Tensor:
725
747
  """
726
748
  Propagates the gradient of a tensor backwards through the computation graph.
727
- Must be used on a scalar tensor.
749
+ If the 'gradient' argument is not provided, the tensor must be a scalar, and the gradient is implicitly set to 1.0.
728
750
 
729
751
  ```python exec="true" source="above" session="tensor" result="python"
730
752
  t = Tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
@@ -732,15 +754,20 @@ class Tensor:
732
754
  print(t.grad.numpy())
733
755
  ```
734
756
  """
735
- assert self.shape == tuple(), f"backward can only be called for scalar tensors, but it has shape {self.shape})"
757
+ if gradient is None:
758
+ assert self.shape == tuple(), "when no gradient is provided, backward must be called on a scalar tensor"
759
+ # fill in the first grad with one. don't use Tensor.ones because we don't need contiguous
760
+ # this is "implicit gradient creation"
761
+ gradient = Tensor(1.0, dtype=self.dtype, device=self.device, requires_grad=False)
736
762
 
737
- # fill in the first grad with one. don't use Tensor.ones because we don't need contiguous
738
- # this is "implicit gradient creation"
739
- self.grad = Tensor(1.0, dtype=self.dtype, device=self.device, requires_grad=False)
763
+ assert self.shape == gradient.shape, f"grad shape must match tensor shape, {gradient.shape!r} != {self.shape!r}"
764
+ self.grad = gradient
740
765
 
741
766
  for t0 in reversed(self._deepwalk()):
742
767
  if t0.grad is None: raise RuntimeError(f"tensor {t0} has no grad")
768
+ token = _METADATA.set(dataclasses.replace(md, backward=True) if (md := t0._ctx.metadata) is not None else None)
743
769
  grads = t0._ctx.backward(t0.grad.lazydata)
770
+ _METADATA.reset(token)
744
771
  grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None
745
772
  for g in ([grads] if len(t0._ctx.parents) == 1 else grads)]
746
773
  for t, g in zip(t0._ctx.parents, grads):
@@ -1036,8 +1063,8 @@ class Tensor:
1036
1063
  ```
1037
1064
  """
1038
1065
  assert index.ndim == self.ndim, f"self.ndim must equal index.ndim, {self.ndim=}, {index.ndim=}"
1039
- assert all(s >= i for d,(s,i) in enumerate(zip(self.shape, index.shape)) if d != dim), "requires self.shape[d] >= index.shape[d] for all d != dim"
1040
1066
  dim = self._resolve_dim(dim)
1067
+ assert all(s >= i for d,(s,i) in enumerate(zip(self.shape, index.shape)) if d != dim), "requires self.shape[d] >= index.shape[d] for all d != dim"
1041
1068
  index = index.to(self.device)
1042
1069
  x = self.shrink(tuple((0, i) if d != dim else None for d,i in enumerate(index.shape))).unsqueeze(-1).transpose(-1, dim)
1043
1070
  return ((index.unsqueeze(-1) == Tensor.arange(self.shape[dim], requires_grad=False, device=self.device)) * x).sum(-1, acc_dtype=self.dtype)
@@ -1079,6 +1106,19 @@ class Tensor:
1079
1106
  # checks for shapes and number of dimensions delegated to cat
1080
1107
  return self.unsqueeze(dim).cat(*[t.unsqueeze(dim) for t in args], dim=dim)
1081
1108
 
1109
+ def repeat_interleave(self, repeats:int, dim:Optional[int]=None) -> Tensor:
1110
+ """
1111
+ Repeat elements of a tensor.
1112
+
1113
+ ```python exec="true" source="above" session="tensor" result="python"
1114
+ t = Tensor([1, 2, 3])
1115
+ print(t.repeat_interleave(2).numpy())
1116
+ ```
1117
+ """
1118
+ x, dim = (self.flatten(), 0) if dim is None else (self, dim)
1119
+ shp = x.shape
1120
+ return x.reshape(*shp[:dim+1], 1, *shp[dim+1:]).expand(*shp[:dim+1], repeats, *shp[dim+1:]).reshape(*shp[:dim], shp[dim]*repeats, *shp[dim+1:])
1121
+
1082
1122
  def repeat(self, repeats, *args) -> Tensor:
1083
1123
  """
1084
1124
  Repeats tensor number of times along each dimension specified by `repeats`.
@@ -1270,7 +1310,7 @@ class Tensor:
1270
1310
  ret = fxn.apply(self, axis=axis_)
1271
1311
  return ret if keepdim else ret.reshape(tuple(s for i,s in enumerate(self.shape) if i not in axis_))
1272
1312
 
1273
- def sum(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, acc_dtype:Optional[DType]=None):
1313
+ def sum(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, acc_dtype:Optional[DTypeLike]=None):
1274
1314
  """
1275
1315
  Sums the elements of the tensor along the specified axis or axes.
1276
1316
 
@@ -1343,6 +1383,50 @@ class Tensor:
1343
1383
  """
1344
1384
  return -((-self).max(axis=axis, keepdim=keepdim))
1345
1385
 
1386
+ def any(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
1387
+ """
1388
+ Tests if any element evaluates to `True` along the specified axis or axes.
1389
+
1390
+ You can pass in `axis` and `keepdim` keyword arguments to control the reduce axis and whether the reduced dimensions are retained.
1391
+
1392
+ ```python exec="true" source="above" session="tensor" result="python"
1393
+ t = Tensor([[True, True], [True, False], [False, False]])
1394
+ print(t.numpy())
1395
+ ```
1396
+ ```python exec="true" source="above" session="tensor" result="python"
1397
+ print(t.any().numpy())
1398
+ ```
1399
+ ```python exec="true" source="above" session="tensor" result="python"
1400
+ print(t.any(axis=0).numpy())
1401
+ ```
1402
+ ```python exec="true" source="above" session="tensor" result="python"
1403
+ print(t.any(axis=1, keepdim=True).numpy())
1404
+ ```
1405
+ """
1406
+ return self.bool().max(axis, keepdim)
1407
+
1408
+ def all(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
1409
+ """
1410
+ Tests if all element evaluates to `True` along the specified axis or axes.
1411
+
1412
+ You can pass in `axis` and `keepdim` keyword arguments to control the reduce axis and whether the reduced dimensions are retained.
1413
+
1414
+ ```python exec="true" source="above" session="tensor" result="python"
1415
+ t = Tensor([[True, True], [True, False], [False, False]])
1416
+ print(t.numpy())
1417
+ ```
1418
+ ```python exec="true" source="above" session="tensor" result="python"
1419
+ print(t.all().numpy())
1420
+ ```
1421
+ ```python exec="true" source="above" session="tensor" result="python"
1422
+ print(t.all(axis=0).numpy())
1423
+ ```
1424
+ ```python exec="true" source="above" session="tensor" result="python"
1425
+ print(t.all(axis=1, keepdim=True).numpy())
1426
+ ```
1427
+ """
1428
+ return self.logical_not().any(axis, keepdim).logical_not()
1429
+
1346
1430
  def mean(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
1347
1431
  """
1348
1432
  Returns the mean value of the tensor along the specified axis or axes.
@@ -1548,7 +1632,7 @@ class Tensor:
1548
1632
  return (-self).argmax(axis=axis, keepdim=keepdim)
1549
1633
 
1550
1634
  @staticmethod
1551
- def einsum(formula:str, *raw_xs, acc_dtype:Optional[DType]=None) -> Tensor:
1635
+ def einsum(formula:str, *raw_xs, acc_dtype:Optional[DTypeLike]=None) -> Tensor:
1552
1636
  """
1553
1637
  Sums the product of the elements of the input tensors according to a formula based on the Einstein summation convention.
1554
1638
 
@@ -1599,18 +1683,22 @@ class Tensor:
1599
1683
  # handle dilation
1600
1684
  xup = xup.shrink(tuple(noop_ + [(0,k*(i+d)) for k,i,d in zip(k_, i_, d_)])).reshape(noop_ + flatten((k,i+d) for k,i,d in zip(k_, i_, d_)))
1601
1685
  # handle stride
1602
- xup = xup.shrink(noop_ + flatten(((0,k), (0,o*s)) for k,o,s in zip(k_, o_, s_))).reshape(noop_ + flatten((k,o,s) for k,o,s in zip(k_, o_, s_)))
1603
- xup = xup.shrink(noop_ + flatten(((0,k), (0,o), (0,1)) for k,o in zip(k_, o_))).reshape(noop_ + flatten((k,o) for k,o in zip(k_, o_)))
1686
+ xup = xup.shrink(
1687
+ tuple(noop_ + flatten(((0,k), (0,o*s)) for k,o,s in zip(k_, o_, s_)))).reshape(noop_ + flatten((k,o,s) for k,o,s in zip(k_, o_, s_)))
1688
+ xup = xup.shrink(tuple(noop_ + flatten(((0,k), (0,o), (0,1)) for k,o in zip(k_, o_)))).reshape(noop_ + flatten((k,o) for k,o in zip(k_, o_)))
1604
1689
  # permute to move reduce to the end
1605
1690
  return xup.permute(*range(len(noop_)), *[len(noop_)+i*2+1 for i in range(len(i_))], *[len(noop_)+i*2 for i in range(len(i_))])
1606
1691
  # TODO: once the shapetracker can optimize well, remove this alternative implementation
1607
1692
  xup = self.pad(tuple(noop_ + [(0, max(0,o*s-i)) for i,o,s in zip(i_, o_, s_)])).shrink(tuple(noop_ + [(0,o*s) for o,s in zip(o_, s_)]))
1608
1693
  xup = xup.reshape(noop_ + flatten(((o,s) for o,s in zip(o_, s_))))
1609
- xup = xup.shrink(noop_ + flatten(((0,o), (0,k)) for o,k in zip(o_, k_)))
1694
+ xup = xup.shrink(tuple(noop_ + flatten(((0,o), (0,k)) for o,k in zip(o_, k_))))
1610
1695
  return xup.permute(*range(len(noop_)), *[len(noop_)+i*2 for i in range(len(i_))], *[len(noop_)+i*2+1 for i in range(len(i_))])
1611
1696
 
1697
+ def _padding2d(self, padding:Union[int, Tuple[int, ...]], dims:int) -> Sequence[int]:
1698
+ return [padding]*2*dims if isinstance(padding, int) else (padding if len(padding) == 2*dims else [p for p in padding for _ in range(2)][::-1])
1699
+
1612
1700
  # NOTE: these work for more than 2D
1613
- def avg_pool2d(self, kernel_size=(2,2), stride=None, dilation=1):
1701
+ def avg_pool2d(self, kernel_size=(2,2), stride=None, dilation=1, padding=0, count_include_pad=True):
1614
1702
  """
1615
1703
  Applies average pooling over a tensor.
1616
1704
 
@@ -1622,11 +1710,15 @@ class Tensor:
1622
1710
  t = Tensor.arange(25).reshape(1, 1, 5, 5)
1623
1711
  print(t.avg_pool2d().numpy())
1624
1712
  ```
1713
+ ```python exec="true" source="above" session="tensor" result="python"
1714
+ print(t.avg_pool2d(padding=1).numpy())
1715
+ ```
1625
1716
  """
1626
- kernel_size = make_pair(kernel_size)
1627
- return self._pool(kernel_size, stride if stride is not None else kernel_size, dilation).mean(axis=tuple(range(-len(kernel_size), 0)))
1717
+ padding_, axis = self._padding2d(padding, len(k_ := make_pair(kernel_size))), tuple(range(-len(k_), 0))
1718
+ def pool(x:Tensor) -> Tensor: return x.pad2d(padding_)._pool(k_, stride if stride is not None else k_, dilation)
1719
+ return pool(self).mean(axis=axis) if count_include_pad else pool(self).sum(axis=axis) / pool(self.ones_like()).sum(axis=axis)
1628
1720
 
1629
- def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1):
1721
+ def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1, padding=0):
1630
1722
  """
1631
1723
  Applies max pooling over a tensor.
1632
1724
 
@@ -1638,11 +1730,14 @@ class Tensor:
1638
1730
  t = Tensor.arange(25).reshape(1, 1, 5, 5)
1639
1731
  print(t.max_pool2d().numpy())
1640
1732
  ```
1733
+ ```python exec="true" source="above" session="tensor" result="python"
1734
+ print(t.max_pool2d(padding=1).numpy())
1735
+ ```
1641
1736
  """
1642
- kernel_size = make_pair(kernel_size)
1643
- return self._pool(kernel_size, stride if stride is not None else kernel_size, dilation).max(axis=tuple(range(-len(kernel_size), 0)))
1737
+ padding_ = self._padding2d(padding, len(k_ := make_pair(kernel_size)))
1738
+ return self.pad2d(padding_, value=float('-inf'))._pool(k_, stride if stride is not None else k_, dilation).max(axis=tuple(range(-len(k_), 0)))
1644
1739
 
1645
- def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, acc_dtype:Optional[DType]=None) -> Tensor:
1740
+ def conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding=0, acc_dtype:DTypeLike|None=None) -> Tensor:
1646
1741
  """
1647
1742
  Applies a convolution over a tensor with a given `weight` and optional `bias`.
1648
1743
 
@@ -1659,7 +1754,7 @@ class Tensor:
1659
1754
  (bs,cin_), (cout,cin), HW = self.shape[:2], weight.shape[:2], weight.shape[2:]
1660
1755
  assert groups*cin == cin_ and len(self.shape) == len(weight.shape), f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})" # noqa: E501
1661
1756
  if isinstance(padding, (tuple,list)): assert len(padding) == 2*len(HW) or len(padding) == len(HW), f"Expected padding of length {2*len(HW)} or {len(HW)}, but got {len(padding)} for tensor of shape {self.shape}" # noqa: E501
1662
- padding_ = [padding]*2*len(HW) if isinstance(padding, int) else (padding if len(padding) == 2*len(HW) else [p for p in padding for _ in range(2)][::-1]) # noqa: E501
1757
+ padding_ = self._padding2d(padding, len(HW))
1663
1758
 
1664
1759
  # conv2d is a pooling op (with padding)
1665
1760
  x = self.pad2d(padding_)._pool(HW, stride, dilation) # (bs, groups*cin, oy, ox, H, W)
@@ -1729,7 +1824,7 @@ class Tensor:
1729
1824
  padding = flatten((((k-1)*d-p,(k-1)*d-p+op) for k,d,p,op in reversed(list(zip(HW, dilation, padding, output_padding)))))
1730
1825
  return x.conv2d(w.flatten(end_dim=1), groups=groups, bias=bias, dilation=dilation, padding=padding)
1731
1826
 
1732
- def dot(self, w:Tensor, acc_dtype:Optional[DType]=None) -> Tensor:
1827
+ def dot(self, w:Tensor, acc_dtype:Optional[DTypeLike]=None) -> Tensor:
1733
1828
  """
1734
1829
  Performs dot product between two tensors.
1735
1830
 
@@ -1748,7 +1843,7 @@ class Tensor:
1748
1843
  w = w.reshape(*w.shape[0:-2], *[1]*min(n1-1, n2-1, 1), *w.shape[-min(n2, 2):]).transpose(-1, -min(n2, 2))
1749
1844
  return (x*w).sum(-1, acc_dtype=acc_dtype).cast(least_upper_dtype(x.dtype, w.dtype) if acc_dtype is None else acc_dtype)
1750
1845
 
1751
- def matmul(self, x:Tensor, reverse=False, acc_dtype:Optional[DType]=None) -> Tensor:
1846
+ def matmul(self, x:Tensor, reverse=False, acc_dtype:Optional[DTypeLike]=None) -> Tensor:
1752
1847
  """
1753
1848
  Performs matrix multiplication between two tensors.
1754
1849
 
@@ -1850,6 +1945,33 @@ class Tensor:
1850
1945
  """
1851
1946
  return Tensor._tri(self.shape[-2], self.shape[-1], diagonal=diagonal+1, device=self.device, dtype=dtypes.bool).where(0, self).cast(self.dtype)
1852
1947
 
1948
def interpolate(self, size:Tuple[int, ...], mode:str="linear", align_corners:bool=False) -> Tensor:
  """
  Downsamples or upsamples to the input `size`, accepts 0 to N batch dimensions.

  The interpolation algorithm is selected with `mode` which currently only supports `linear`.
  To run `bilinear` or `trilinear`, pass in a 2D or 3D size.
  With `align_corners=True`, an output dimension of length 1 samples the first input element (scale 0)
  instead of dividing by zero.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([[1, 2, 3, 4], [21, 22, 23, 24], [41, 42, 43, 44]])
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.interpolate(size=(2,3), mode="linear").numpy())
  ```
  """
  assert isinstance(size, (tuple,list)) and all_int(size) and 0 < len(size) <= self.ndim, f"invalid {size=}"
  assert mode == "linear", "only supports linear interpolate"
  x, expand = self, list(self.shape)
  for i in range(-len(size), 0):
    # input-coordinate step per output position along axis i; with align_corners a length-1 output has no
    # second corner to align, so it maps to input index 0 (scale 0) rather than computing (n-1)/0
    if align_corners: scale = (self.shape[i] - 1) / (size[i] - 1) if size[i] > 1 else 0.0
    else: scale = self.shape[i] / size[i]
    arr, reshape = Tensor.arange(size[i], dtype=dtypes.float32, device=self.device), [1] * self.ndim
    # fractional source coordinate of every output position, clipped into the valid input range
    index = (scale*arr if align_corners else (scale*(arr+0.5))-0.5).clip(0, self.shape[i]-1)
    reshape[i] = expand[i] = size[i]
    # lower/upper neighbours and the blend weight, broadcast to the current output shape
    low, high, perc = [y.reshape(reshape).expand(expand) for y in (index.floor(), index.ceil(), index - index.floor())]
    x = x.gather(i, low).lerp(x.gather(i, high), perc)
  return x
1974
+
1853
1975
  # ***** unary ops *****
1854
1976
 
1855
1977
  def logical_not(self):
@@ -1999,7 +2121,7 @@ class Tensor:
1999
2121
  Truncates the tensor element-wise.
2000
2122
 
2001
2123
  ```python exec="true" source="above" session="tensor" result="python"
2002
- print(Tensor([-3.9, -2.1, -1.5, 0.5, 1.5, 2.1, 3.9]).trunc().numpy())
2124
+ print(Tensor([-3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]).trunc().numpy())
2003
2125
  ```
2004
2126
  """
2005
2127
  return self.cast(dtypes.int32).cast(self.dtype)
@@ -2008,7 +2130,7 @@ class Tensor:
2008
2130
  Rounds the tensor element-wise towards positive infinity.
2009
2131
 
2010
2132
  ```python exec="true" source="above" session="tensor" result="python"
2011
- print(Tensor([-3.9, -2.1, -1.5, 0.5, 1.5, 2.1, 3.9]).ceil().numpy())
2133
+ print(Tensor([-3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]).ceil().numpy())
2012
2134
  ```
2013
2135
  """
2014
2136
  return (self > (b := self.trunc())).where(b+1, b)
@@ -2017,19 +2139,20 @@ class Tensor:
2017
2139
  Rounds the tensor element-wise towards negative infinity.
2018
2140
 
2019
2141
  ```python exec="true" source="above" session="tensor" result="python"
2020
- print(Tensor([-3.9, -2.1, -1.5, 0.5, 1.5, 2.1, 3.9]).floor().numpy())
2142
+ print(Tensor([-3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]).floor().numpy())
2021
2143
  ```
2022
2144
  """
2023
2145
  return (self < (b := self.trunc())).where(b-1, b)
2024
2146
def round(self: Tensor) -> Tensor:
  """
  Rounds the tensor element-wise with rounding half to even.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]).round().numpy())
  ```
  """
  positive = self > 0
  # halving the truncated value and truncating again is exact only when the truncated value is even
  halved = self.cast(dtypes.int32) / 2.0
  trunc_is_even = halved.cast(dtypes.int32) == halved
  # True exactly when the integer just below x is even: ties then go down via ceil(x-0.5), otherwise up via floor(x+0.5)
  ties_down = positive == trunc_is_even
  return ties_down.where((self - 0.5).ceil(), (self + 0.5).floor())
2155
+
2033
2156
  def lerp(self, end: Tensor, weight: Union[Tensor, float]) -> Tensor:
2034
2157
  """
2035
2158
  Linearly interpolates between `self` and `end` by `weight`.
@@ -2039,6 +2162,7 @@ class Tensor:
2039
2162
  ```
2040
2163
  """
2041
2164
  return self + (end - self) * weight
2165
+
2042
2166
  def square(self):
2043
2167
  """
2044
2168
  Squares the tensor element-wise.
@@ -2049,15 +2173,23 @@ class Tensor:
2049
2173
  ```
2050
2174
  """
2051
2175
  return self*self
2052
def clamp(self, min_=None, max_=None):
  """
  Clips (clamps) the values in the tensor between `min_` and `max_` element-wise.
  Passing `None` for `min_` leaves the tensor unbounded below; passing `None` for `max_` leaves it unbounded above.
  Raises `RuntimeError` when both bounds are `None`.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).clip(-1, 1).numpy())
  ```
  """
  if min_ is None and max_ is None: raise RuntimeError("at least one of 'min_' or 'max_' must not be None")
  out = self
  # lower bound first, then upper bound; an omitted bound is simply skipped
  if min_ is not None: out = out.maximum(min_)
  if max_ is not None: out = out.minimum(max_)
  return out
2188
def clip(self, min_=None, max_=None):
  """
  Alias for `Tensor.clamp`: forwards both bounds unchanged and returns its result.
  """
  result = self.clamp(min_, max_)
  return result
2061
2193
  def sign(self):
2062
2194
  """
2063
2195
  Returns the sign of the tensor element-wise.
@@ -2340,7 +2472,7 @@ class Tensor:
2340
2472
  x: Tensor = self
2341
2473
  if not isinstance(y, Tensor):
2342
2474
  # make y a Tensor
2343
- assert isinstance(y, (float, int, bool, Node)), f"{type(y)=}, {y=}"
2475
+ assert isinstance(y, (*get_args(ConstType), Node)), f"{type(y)=}, {y=}"
2344
2476
  if isinstance(x.dtype, ImageDType) or dtypes.is_float(x.dtype) or (dtypes.is_int(x.dtype) and isinstance(y, int)): y_dtype = x.dtype
2345
2477
  elif not isinstance(y, Node): y_dtype = dtypes.from_py(y)
2346
2478
  if isinstance(y, Node): y = Tensor.from_node(y, device=x.device)
@@ -2462,6 +2594,36 @@ class Tensor:
2462
2594
  """
2463
2595
  return F.Xor.apply(*self._broadcasted(x, reverse))
2464
2596
 
2597
def bitwise_and(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
  """
  Compute the bit-wise AND of `self` and `x`.
  Equivalent to `self & x`.
  Supports broadcasting to a common shape, type promotion, and integer, boolean inputs.
  Raises `AssertionError` if `self` is neither an integer nor a boolean tensor.
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([2, 5, 255]).bitwise_and(Tensor([3, 14, 16])).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([True, True, False, False]).bitwise_and(Tensor([True, False, True, False])).numpy())
  ```
  """
  # bool inputs are advertised (and exemplified) above, so accept them explicitly instead of relying on is_int
  assert dtypes.is_int(self.dtype) or self.dtype == dtypes.bool, f"{self.dtype} is not supported"
  return F.BitwiseAnd.apply(*self._broadcasted(x, reverse))
2611
+
2612
def bitwise_or(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
  """
  Compute the bit-wise OR of `self` and `x`.
  Equivalent to `self | x`.
  Supports broadcasting to a common shape, type promotion, and integer, boolean inputs.
  Raises `AssertionError` if `self` is neither an integer nor a boolean tensor.
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([2, 5, 255]).bitwise_or(Tensor([4, 4, 4])).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([True, True, False, False]).bitwise_or(Tensor([True, False, True, False])).numpy())
  ```
  """
  # bool inputs are advertised (and exemplified) above, so accept them explicitly instead of relying on is_int
  assert dtypes.is_int(self.dtype) or self.dtype == dtypes.bool, f"{self.dtype} is not supported"
  return F.BitwiseOr.apply(*self._broadcasted(x, reverse))
2626
+
2465
2627
  def lshift(self, x:int):
2466
2628
  """
2467
2629
  Computes left arithmetic shift of `self` by `x` bits. `self` must have unsigned dtype.
@@ -2586,6 +2748,8 @@ class Tensor:
2586
2748
# binary operator overloads: each forwards its right-hand operand to the matching named method
def __pow__(self, x) -> Tensor:
  return self.pow(x)
def __truediv__(self, x) -> Tensor:
  return self.div(x)
def __matmul__(self, x) -> Tensor:
  return self.matmul(x)
def __and__(self, x) -> Tensor:
  return self.bitwise_and(x)
def __or__(self, x) -> Tensor:
  return self.bitwise_or(x)
def __xor__(self, x) -> Tensor:
  return self.xor(x)
def __lshift__(self, x) -> Tensor:
  return self.lshift(x)
def __rshift__(self, x) -> Tensor:
  return self.rshift(x)
@@ -2596,6 +2760,8 @@ class Tensor:
2596
2760
# reflected operator overloads: same named methods as the forward forms, called with reverse=True
def __rpow__(self, x) -> Tensor:
  return self.pow(x, True)
def __rtruediv__(self, x) -> Tensor:
  return self.div(x, True)
def __rmatmul__(self, x) -> Tensor:
  return self.matmul(x, True)
def __rand__(self, x) -> Tensor:
  return self.bitwise_and(x, True)
def __ror__(self, x) -> Tensor:
  return self.bitwise_or(x, True)
def __rxor__(self, x) -> Tensor:
  return self.xor(x, True)

# in-place overloads compute the out-of-place result and assign it back into self
def __iadd__(self, x) -> Tensor:
  return self.assign(self.add(x))
@@ -2604,6 +2770,8 @@ class Tensor:
2604
2770
# in-place operator overloads: build the out-of-place result, then assign it back into self
def __ipow__(self, x) -> Tensor:
  return self.assign(self.pow(x))
def __itruediv__(self, x) -> Tensor:
  return self.assign(self.div(x))
def __imatmul__(self, x) -> Tensor:
  return self.assign(self.matmul(x))
def __iand__(self, x) -> Tensor:
  return self.assign(self.bitwise_and(x))
def __ior__(self, x) -> Tensor:
  return self.assign(self.bitwise_or(x))
def __ixor__(self, x) -> Tensor:
  return self.assign(self.xor(x))
def __ilshift__(self, x) -> Tensor:
  return self.assign(self.lshift(x))
def __irshift__(self, x) -> Tensor:
  return self.assign(self.rshift(x))
@@ -2788,13 +2956,87 @@ class Tensor:
2788
2956
  smoothing = label_smoothing * (log_probs.mean(-1) * loss_mask).sum()
2789
2957
  return -((1 - label_smoothing) * (log_probs * y).sum() + smoothing) / loss_mask.sum()
2790
2958
 
2959
+ # ***** Tensor Properties *****
2960
+
2961
@property
def ndim(self) -> int:
  """
  Number of dimensions of the tensor, i.e. the length of its shape.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([[1, 2], [3, 4]])
  print(t.ndim)
  ```
  """
  dims = self.shape
  return len(dims)
2972
+
2973
def numel(self) -> sint:
  """
  Total number of elements in the tensor: the product of all entries of its shape.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
  print(t.numel())
  ```
  """
  dims = self.shape
  return prod(dims)
2983
+
2984
def element_size(self) -> int:
  """
  Size in bytes of one element of the tensor, taken from its dtype's `itemsize`.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([5], dtype=dtypes.int16)
  print(t.element_size())
  ```
  """
  dt = self.dtype
  return dt.itemsize
2994
+
2995
def nbytes(self) -> int:
  """
  Total number of bytes occupied by all elements: element count times per-element size.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([8, 9], dtype=dtypes.float)
  print(t.nbytes())
  ```
  """
  count = self.numel()
  width = self.element_size()
  return count * width
3005
+
3006
def is_floating_point(self) -> bool:
  """
  Returns `True` if the tensor holds a floating point dtype, i.e. one of `dtypes.float64`, `dtypes.float32`,
  `dtypes.float16`, `dtypes.bfloat16`.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([8, 9], dtype=dtypes.float32)
  print(t.is_floating_point())
  ```
  """
  dt = self.dtype
  return dtypes.is_float(dt)
3017
+
3018
def size(self, dim:Optional[int]=None) -> Union[sint, Tuple[sint, ...]]:
  """
  Returns the tensor's shape, or just the length along `dim` when `dim` is given (negative indices allowed).

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([[4, 5, 6], [7, 8, 9]])
  print(t.size())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.size(dim=1))
  ```
  """
  if dim is not None: return self.shape[dim]
  return self.shape
3031
+
2791
3032
  # ***** cast ops *****
2792
3033
 
2793
def llvm_bf16_cast(self, dtype:DTypeLike):
  """
  Casts a bfloat16 tensor to `dtype` by routing it through the LLVM device.

  Hack for devices that don't support bfloat16: the 16-bit pattern is widened to uint32, shifted into the upper
  16 bits (multiply by 1<<16), reinterpreted as float32, and finally cast to `dtype`.
  """
  assert self.dtype == dtypes.bfloat16
  widened = self.to("LLVM").bitcast(dtypes.uint16).cast(dtypes.uint32)
  as_float32 = widened.mul(1<<16).bitcast(dtypes.float32)
  return as_float32.cast(dtype)
2797
- def cast(self, dtype:DType) -> Tensor:
3038
+
3039
+ def cast(self, dtype:DTypeLike) -> Tensor:
2798
3040
  """
2799
3041
  Casts `self` to the given `dtype`.
2800
3042
 
@@ -2807,8 +3049,9 @@ class Tensor:
2807
3049
  print(t.dtype, t.numpy())
2808
3050
  ```
2809
3051
  """
2810
- return self if self.dtype == dtype else F.Cast.apply(self, dtype=dtype)
2811
- def bitcast(self, dtype:DType) -> Tensor:
3052
+ return self if self.dtype == (dt:=to_dtype(dtype)) else F.Cast.apply(self, dtype=dt)
3053
+
3054
+ def bitcast(self, dtype:DTypeLike) -> Tensor:
2812
3055
  """
2813
3056
  Bitcasts `self` to the given `dtype` of the same itemsize.
2814
3057
 
@@ -2824,7 +3067,8 @@ class Tensor:
2824
3067
  ```
2825
3068
  """
2826
3069
  if self.requires_grad: raise RuntimeError("can't backprop through bitcast")
2827
- return F.Cast.apply(self, dtype=dtype, bitcast=True) if self.dtype != dtype else self
3070
+ return F.Cast.apply(self, dtype=dt, bitcast=True) if self.dtype != (dt:=to_dtype(dtype)) else self
3071
+
2828
3072
  def float(self) -> Tensor:
2829
3073
  """
2830
3074
  Convenience method to cast `self` to a `float32` Tensor.
@@ -2839,6 +3083,7 @@ class Tensor:
2839
3083
  ```
2840
3084
  """
2841
3085
  return self.cast(dtypes.float32)
3086
+
2842
3087
  def half(self) -> Tensor:
2843
3088
  """
2844
3089
  Convenience method to cast `self` to a `float16` Tensor.
@@ -2854,15 +3099,35 @@ class Tensor:
2854
3099
  """
2855
3100
  return self.cast(dtypes.float16)
2856
3101
 
2857
- # ***** convenience stuff *****
3102
def int(self) -> Tensor:
  """
  Convenience method to cast `self` to an `int32` Tensor.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([-1.5, -0.5, 0.0, 0.5, 1.5])
  print(t.dtype, t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  t = t.int()
  print(t.dtype, t.numpy())
  ```
  """
  target = dtypes.int32
  return self.cast(target)
3116
+
3117
def bool(self) -> Tensor:
  """
  Convenience method to cast `self` to a `bool` Tensor.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([-1, 0, 1])
  print(t.dtype, t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  t = t.bool()
  print(t.dtype, t.numpy())
  ```
  """
  target = dtypes.bool
  return self.cast(target)
2866
3131
 
2867
3132
  # *** image Tensor function replacements ***
2868
3133
 
@@ -2960,3 +3225,40 @@ def custom_random(out:Buffer):
2960
3225
  if out.dtype == dtypes.half: rng_np_buffer = (rng.integers(low=0, high=2047, size=out.size) / 2048).astype(np.half, copy=False)
2961
3226
  else: rng_np_buffer = rng.random(size=out.size, dtype=np.float32).astype(dtype=_to_np_dtype(out.dtype), copy=False)
2962
3227
  out.copyin(rng_np_buffer.data)
3228
+
3229
def _metadata_wrapper(fn):
  """
  Wraps `fn` so that a call to it runs with the `_METADATA` context variable set to `Metadata(name=fn.__name__, caller=...)`.

  If `_METADATA` is already set (a nested call), `fn` runs directly, so only the outermost call is recorded.
  With TRACEMETA >= 2 the caller is a "module:lineno::function" string found by walking up the stack
  (skipping tinygrad.nn frames other than optimizers, and two extra frames when invoked from a tinygrad lambda);
  otherwise the caller is "". The previous `_METADATA` value is restored even if `fn` raises.
  """
  def _wrapper(*args, **kwargs):
    if _METADATA.get() is not None: return fn(*args, **kwargs)

    if TRACEMETA >= 2:
      caller_frame = sys._getframe(frame := 1)
      caller_module = caller_frame.f_globals.get("__name__", None)
      caller_func = caller_frame.f_code.co_name
      if caller_module is None: return fn(*args, **kwargs)

      # if its called from nn we want to step up frames until we are out of nn
      while caller_module.startswith("tinygrad.nn") and "optim" not in caller_module:
        caller_frame = sys._getframe(frame := frame + 1)
        caller_module = caller_frame.f_globals.get("__name__", None)
        if caller_module is None: return fn(*args, **kwargs)

      # if its called from a lambda in tinygrad we want to look two more frames up
      if caller_module.startswith("tinygrad") and caller_func == "<lambda>": caller_frame = sys._getframe(frame := frame + 2)
      caller_module = caller_frame.f_globals.get("__name__", None)
      if caller_module is None: return fn(*args, **kwargs)
      caller_func = caller_frame.f_code.co_name
      caller_lineno = caller_frame.f_lineno

      caller = f"{caller_module}:{caller_lineno}::{caller_func}"
    else: caller = ""

    token = _METADATA.set(Metadata(name=fn.__name__, caller=caller))
    # reset in `finally`: if fn raises, stale metadata would otherwise stay set and make every later
    # wrapped call take the "already set" fast path with the wrong record attached
    try: return fn(*args, **kwargs)
    finally: _METADATA.reset(token)
  return _wrapper
3260
+
3261
if TRACEMETA >= 1:
  # replace each function attribute of Tensor (except the names listed) with its _metadata_wrapper-wrapped
  # version, keeping the original name/docstring via functools.wraps
  for name, fn in inspect.getmembers(Tensor, inspect.isfunction):
    if name not in ["__class__", "__init__", "__new__", "__repr__", "backward", "sequential"]:
      setattr(Tensor, name, functools.wraps(fn)(_metadata_wrapper(fn)))