tinygrad 0.9.1__py3-none-any.whl → 0.9.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/codegen/kernel.py +248 -115
- tinygrad/codegen/lowerer.py +215 -0
- tinygrad/codegen/transcendental.py +310 -0
- tinygrad/codegen/uopgraph.py +622 -0
- tinygrad/codegen/uops.py +235 -393
- tinygrad/device.py +428 -69
- tinygrad/dtype.py +18 -4
- tinygrad/engine/graph.py +19 -32
- tinygrad/engine/jit.py +148 -70
- tinygrad/engine/realize.py +127 -51
- tinygrad/engine/schedule.py +259 -216
- tinygrad/engine/search.py +29 -22
- tinygrad/function.py +9 -0
- tinygrad/helpers.py +87 -49
- tinygrad/lazy.py +34 -35
- tinygrad/multi.py +41 -36
- tinygrad/nn/__init__.py +39 -22
- tinygrad/nn/state.py +3 -3
- tinygrad/ops.py +63 -62
- tinygrad/renderer/__init__.py +43 -21
- tinygrad/renderer/assembly.py +104 -106
- tinygrad/renderer/cstyle.py +87 -60
- tinygrad/renderer/llvmir.py +21 -30
- tinygrad/runtime/autogen/amd_gpu.py +25208 -5753
- tinygrad/runtime/autogen/cuda.py +6 -162
- tinygrad/runtime/autogen/kfd.py +32 -0
- tinygrad/runtime/autogen/libc.py +4260 -0
- tinygrad/runtime/autogen/nvrtc.py +579 -0
- tinygrad/runtime/graph/clang.py +2 -2
- tinygrad/runtime/graph/cuda.py +8 -11
- tinygrad/runtime/graph/hcq.py +120 -107
- tinygrad/runtime/graph/metal.py +18 -15
- tinygrad/runtime/ops_amd.py +197 -305
- tinygrad/runtime/ops_clang.py +2 -2
- tinygrad/runtime/ops_cuda.py +36 -94
- tinygrad/runtime/ops_disk.py +3 -7
- tinygrad/runtime/ops_gpu.py +4 -2
- tinygrad/runtime/ops_hip.py +70 -0
- tinygrad/runtime/ops_metal.py +38 -27
- tinygrad/runtime/ops_nv.py +283 -363
- tinygrad/runtime/ops_python.py +26 -30
- tinygrad/runtime/support/compiler_cuda.py +78 -0
- tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +15 -1
- tinygrad/runtime/support/elf.py +38 -0
- tinygrad/shape/shapetracker.py +5 -14
- tinygrad/shape/symbolic.py +4 -8
- tinygrad/shape/view.py +34 -22
- tinygrad/tensor.py +399 -97
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/METADATA +49 -48
- tinygrad-0.9.2.dist-info/RECORD +70 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/WHEEL +1 -1
- tinygrad/codegen/linearizer.py +0 -528
- tinygrad-0.9.1.dist-info/RECORD +0 -63
- /tinygrad/runtime/{driver → support}/__init__.py +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/LICENSE +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/top_level.txt +0 -0
tinygrad/tensor.py
CHANGED
@@ -1,37 +1,39 @@
 # inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
 from __future__ import annotations
-import …
+import dataclasses
+import time, math, itertools, functools, struct, sys, inspect
 from contextlib import ContextDecorator
 from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Dict, DefaultDict, cast, get_args, Set
 from collections import defaultdict
 import numpy as np

-from tinygrad.dtype import DType, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype
+from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype
 from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, get_shape, fully_flatten, dedup
-from tinygrad.helpers import IMAGE, DEBUG, WINO, THREEFRY
+from tinygrad.helpers import IMAGE, DEBUG, WINO, THREEFRY, _METADATA, Metadata, TRACEMETA
 from tinygrad.lazy import LazyBuffer
 from tinygrad.multi import MultiLazyBuffer
-from tinygrad.ops import …
+from tinygrad.ops import MetaOps, truncate
 from tinygrad.device import Device, Buffer, BufferOptions
 from tinygrad.shape.symbolic import sint, Variable, MulNode, SumNode, NumNode, Node
-from tinygrad.engine.realize import run_schedule
-from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars
+from tinygrad.engine.realize import run_schedule, memory_planner
+from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars

 # **** start with two base classes, Tensor and Function ****

 class Function:
-  def __init__(self, device:Union[str, Tuple[str, ...]], *tensors:Tensor):
+  def __init__(self, device:Union[str, Tuple[str, ...]], *tensors:Tensor, metadata:Optional[Metadata]=None):
     self.device = device
     self.needs_input_grad = [t.requires_grad for t in tensors]
     self.requires_grad = True if any(self.needs_input_grad) else None if None in self.needs_input_grad else False
     if self.requires_grad: self.parents = tensors
+    self.metadata = metadata

   def forward(self, *args, **kwargs): raise NotImplementedError(f"forward not implemented for {type(self)}")
   def backward(self, *args, **kwargs): raise RuntimeError(f"backward not implemented for {type(self)}")

   @classmethod
   def apply(fxn:Type[Function], *x:Tensor, **kwargs) -> Tensor:
-    ctx = fxn(x[0].device, *x)
+    ctx = fxn(x[0].device, *x, metadata=_METADATA.get())
     ret = Tensor.__new__(Tensor)
     ret.lazydata, ret.requires_grad, ret.grad = ctx.forward(*[t.lazydata for t in x], **kwargs), ctx.requires_grad, None
     ret._ctx = ctx if ctx.requires_grad and not Tensor.no_grad else None # used by autograd engine
@@ -39,24 +41,24 @@ class Function:

 import tinygrad.function as F

-def …
-  if isinstance(device, str): return LazyBuffer.…
-  return MultiLazyBuffer([LazyBuffer.…
+def _metaop(op, shape:Tuple[sint,...], dtype:DType, device:Union[str, Tuple[str, ...]], arg=None, src:Tuple[LazyBuffer, ...]=()):
+  if isinstance(device, str): return LazyBuffer.metaop(op, shape, dtype, device, arg, src)
+  return MultiLazyBuffer([LazyBuffer.metaop(op, shape, dtype, d, arg, src) for d in device], None)

-def _from_np_dtype(npdtype:…
+def _from_np_dtype(npdtype:np.dtype) -> DType: return dtypes.fields()[np.dtype(npdtype).name]
 def _to_np_dtype(dtype:DType) -> Optional[type]: return np.dtype(dtype.fmt).type if dtype.fmt is not None else None

 def _fromnp(x: np.ndarray) -> LazyBuffer:
-  ret = LazyBuffer.…
+  ret = LazyBuffer.metaop(MetaOps.EMPTY, x.shape, _from_np_dtype(x.dtype), "NPY")
   # fake realize
   ret.buffer.allocate(x)
   del ret.srcs
   return ret

 def _frompy(x:Union[List, Tuple, bytes], dtype:DType) -> LazyBuffer:
-  if isinstance(x, bytes): ret, data = LazyBuffer.…
+  if isinstance(x, bytes): ret, data = LazyBuffer.metaop(MetaOps.EMPTY, (len(x)//dtype.itemsize,), dtype, "PYTHON"), x
   else:
-    ret = LazyBuffer.…
+    ret = LazyBuffer.metaop(MetaOps.EMPTY, get_shape(x), dtype, "PYTHON")
     assert dtype.fmt is not None, f"{dtype=} has None fmt"
     truncate_function = truncate[dtype]
     data = struct.pack(f"@{ret.size}{dtype.fmt}", *[truncate_function(xi) for xi in fully_flatten(x)])
@@ -85,7 +87,7 @@ def _pad_left(*shapes:Tuple[sint, ...]) -> Tuple[Tuple[sint, ...], ...]:
   max_dim = max(len(shape) for shape in shapes)
   return tuple((1,) * (max_dim - len(shape)) + shape for shape in shapes)
 def _broadcast_shape(*shapes:Tuple[sint, ...]) -> Tuple[sint, ...]:
-  return tuple(0 if …
+  return tuple(0 if 0 in nth_dim_sizes else max(nth_dim_sizes) for nth_dim_sizes in zip(*_pad_left(*shapes)))

 class Tensor:
   """
@@ -104,7 +106,8 @@ class Tensor:
   no_grad: ClassVar[bool] = False

   def __init__(self, data:Union[None, ConstType, List, Tuple, LazyBuffer, np.ndarray, bytes, MultiLazyBuffer, Variable],
-               device:Optional[Union[str, tuple, list]]=None, dtype:Optional[…
+               device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DTypeLike]=None, requires_grad:Optional[bool]=None):
+    if dtype is not None: dtype = to_dtype(dtype)
     assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
     device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device)

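The `DTypeLike`/`to_dtype` plumbing above normalizes dtype arguments once in `__init__`. A minimal usage sketch, assuming `to_dtype` (defined in `tinygrad/dtype.py`, not shown in this diff) also resolves dtype names given as strings:

```python
from tinygrad import Tensor, dtypes

a = Tensor([1, 2, 3], dtype=dtypes.float32)   # DType objects work as before
b = Tensor([1, 2, 3], dtype="float32")        # assumption: to_dtype() resolves the string name
print(a.dtype, b.dtype)
```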
@@ -120,18 +123,18 @@ class Tensor:

     # create a LazyBuffer from the different types of inputs
     if isinstance(data, LazyBuffer): assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
-    elif isinstance(data, get_args(ConstType)): data = …
-    elif isinstance(data, Variable): data = …
-    elif isinstance(data, bytes): data = _frompy(data, dtypes.uint8)
+    elif isinstance(data, get_args(ConstType)): data = _metaop(MetaOps.CONST, tuple(), dtype or dtypes.from_py(data), device, data)
+    elif isinstance(data, Variable): data = _metaop(MetaOps.CONST, tuple(), dtype or dtypes.from_py(data.unbind()[1]), device, data)
+    elif isinstance(data, bytes): data = _frompy(data, dtypes.uint8 if dtype is None else dtype)
     elif isinstance(data, (list, tuple)):
       if dtype is None:
         if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtypes.bool
         else: dtype = dtypes.default_int if d and all_int(d) else dtypes.default_float
       if dtype == dtypes.bfloat16: data = Tensor(_fromnp(np.array(data, np.float32)), device=device).cast(dtypes.bfloat16).lazydata
       else: data = _fromnp(np.array(data).astype(_to_np_dtype(dtype)))
-    elif data is None: data = …
+    elif data is None: data = _metaop(MetaOps.EMPTY, (0,), dtype or dtypes.default_float, device)
     elif isinstance(data, np.ndarray):
-      if data.shape == (): data = …
+      if data.shape == (): data = _metaop(MetaOps.CONST, tuple(), dtype or _from_np_dtype(data.dtype), device, data.item())
       else: data = _fromnp(data.astype(npdtype) if dtype is not None and (npdtype:=_to_np_dtype(dtype)) is not None else data)

     # by this point, it has to be a LazyBuffer
@@ -145,7 +148,7 @@ class Tensor:
         assert data.device == device, f"MultiLazyBuffer device mismatch, {data.device} != {device}"
         self.lazydata: Union[LazyBuffer, MultiLazyBuffer] = data
       else:
-        self.lazydata = MultiLazyBuffer.from_sharded(data, device, None)
+        self.lazydata = MultiLazyBuffer.from_sharded(data, device, None, None)
     else:
       self.lazydata = data if data.device == device else data.copy_to_device(device)

@@ -318,20 +321,34 @@ class Tensor:
     if self.grad is not None and real.grad is not None: self.grad.lazydata = real.grad.lazydata
     self.lazydata = real.lazydata

-  def shard(self, devices:Tuple[str, ...], axis:Optional[int]=None) -> Tensor:
+  def shard(self, devices:Tuple[str, ...], axis:Optional[int]=None, splits:Optional[Tuple[int, ...]]=None) -> Tensor:
     """
-    Shards the tensor across the given devices.
+    Shards the tensor across the given devices. Optionally specify which axis to shard on, and how to split it across devices.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.empty(2, 3)
+    print(t.shard((t.device, t.device), axis=1, splits=(2, 1)).lazydata)
+    ```
+
     """
     assert isinstance(self.lazydata, LazyBuffer), "can't shard a MultiLazyBuffer"
-    canonical_devices = tuple(Device.canonicalize(x) for x in devices)
-    if axis is not None …
-  …
-  …
-  …
+    canonical_devices, bounds = tuple(Device.canonicalize(x) for x in devices), None
+    if axis is not None:
+      if axis < 0: axis += len(self.shape)
+      if splits is None:
+        sz = round_up(self.shape[axis], len(devices)) // len(devices)
+        splits = tuple([max(0, min(sz, self.shape[axis] - sz*i)) for i in range(len(devices))])
+      assert sum(splits) == self.shape[axis], "specified splits do not sum up to axis shape"
+      boundaries = tuple(itertools.accumulate(splits))
+      bounds = tuple(zip((0,) + boundaries, boundaries))
+    return Tensor(MultiLazyBuffer.from_sharded(self.lazydata, canonical_devices, axis, bounds),
+                  device=canonical_devices, requires_grad=self.requires_grad)
+
+  def shard_(self, devices:Tuple[str, ...], axis:Optional[int]=None, splits:Optional[Tuple[int, ...]]=None):
     """
     Shards the tensor across the given devices in place.
     """
-    self.lazydata = self.shard(devices, axis).lazydata
+    self.lazydata = self.shard(devices, axis, splits).lazydata
     return self

   @staticmethod
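Building on the docstring above, a hedged sketch of uneven sharding with the new `splits` argument; device strings depend on your backend, so two copies of the default device are used here:

```python
from tinygrad import Tensor

t = Tensor.empty(6, 4)
# split axis 0 as 4 rows on the first device and 2 rows on the second
sharded = t.shard((t.device, t.device), axis=0, splits=(4, 2))
print(sharded.lazydata)  # MultiLazyBuffer carrying the per-device bounds
```

`shard_` takes the same `splits` argument and replaces the tensor's data in place.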
@@ -342,14 +359,14 @@ class Tensor:
     if isinstance(y, SumNode): return Tensor.from_node(y.nodes[0], **kwargs) + sum(y.nodes[1:])
     raise RuntimeError(f"unhandled Node {y}")

-  # ***** creation …
+  # ***** creation entrypoint *****

   @staticmethod
-  def …
+  def _metaop(op, shape, device:Optional[Union[Tuple[str, ...], str]]=None, dtype:Optional[DTypeLike]=None, arg=None, **kwargs):
     if isinstance(device, tuple):
-      return Tensor(MultiLazyBuffer([LazyBuffer.…
+      return Tensor(MultiLazyBuffer([LazyBuffer.metaop(op, shape, dtype or dtypes.default_float, Device.canonicalize(d), arg) \
         for d in device], None), device, dtype, **kwargs)
-    return Tensor(LazyBuffer.…
+    return Tensor(LazyBuffer.metaop(op, shape, dtype or dtypes.default_float, Device.canonicalize(device), arg), device, dtype, **kwargs)

   @staticmethod
   def empty(*shape, **kwargs):
@@ -364,7 +381,7 @@ class Tensor:
     print(t.shape)
     ```
     """
-    return Tensor.…
+    return Tensor._metaop(MetaOps.EMPTY, argfix(*shape), **kwargs)

   _seed: int = int(time.time())
   _rng_counter: Optional[Tensor] = None
@@ -384,10 +401,10 @@ class Tensor:
     print(Tensor.rand(5).numpy())
     ```
     """
-    Tensor._seed, Tensor._rng_counter = seed, …
+    Tensor._seed, Tensor._rng_counter = seed, None

   @staticmethod
-  def rand(*shape, device:Optional[Union[Tuple[str, ...], str]]=None, dtype:Optional[…
+  def rand(*shape, device:Optional[Union[Tuple[str, ...], str]]=None, dtype:Optional[DTypeLike]=None, **kwargs):
     """
     Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[0, 1)`.

@@ -400,27 +417,25 @@ class Tensor:
     print(t.numpy())
     ```
     """
-    if Tensor._rng_counter is None: Tensor._rng_counter = Tensor([0], dtype=dtypes.uint32, requires_grad=False)
+    if (had_counter := Tensor._rng_counter is None): Tensor._rng_counter = Tensor([0], dtype=dtypes.uint32, requires_grad=False)
+    if not all(s >= 0 for s in argfix(*shape)): raise ValueError(f"cannot create tensor with negative dimension in {shape=}")
     if not THREEFRY.value:
       # for bfloat16, numpy rand passes buffer in float
-      if (dtype or dtypes.default_float) == dtypes.bfloat16:
+      if to_dtype(dtype or dtypes.default_float) == dtypes.bfloat16:
         return Tensor.rand(*shape, **kwargs, device=device, dtype=dtypes.float).cast(dtypes.bfloat16)
-      return Tensor.…
+      return Tensor._metaop(MetaOps.CUSTOM, argfix(*shape), arg=custom_random, device=device, dtype=dtype, **kwargs)

     # threefry
     if (num := prod((shape:=argfix(*shape)))) == 0: return Tensor.zeros(shape, device=device, dtype=dtype, **kwargs)
-  …
+    if not had_counter: Tensor._rng_counter.assign(Tensor._rng_counter + num)
+    counts1 = (Tensor.arange(math.ceil(num / 2), device=device, dtype=dtypes.uint32, requires_grad=False)+Tensor._rng_counter.to(device))
     counts2 = counts1 + math.ceil(num / 2)
-    Tensor._rng_counter.assign(Tensor._rng_counter + num).realize()

-  …
-  …
+    x = counts2.cast(dtypes.uint64) << 32 | counts1.cast(dtypes.uint64)
+    x = F.Threefry.apply(*x._broadcasted(Tensor._seed))
+    counts1, counts2 = (x & 0xffffffff).cast(dtypes.uint32), ((x >> 32) & 0xffffffff).cast(dtypes.uint32)

-  …
-    for i in range(5):
-      for r in rotations[i % 2]: x[0], x[1] = (x0 := x[0] + x[1]), x0 ^ ((x[1] << r) + (x[1] >> (32 - r)))
-      x = [(x[0] + ks[i % 3]), (x[1] + ks[(i + 1) % 3] + i + 1)]
-    out = x[0].cat(x[1]).rshift(8).cast(dtypes.float32).div(2 ** 24)[:num]
+    out = counts1.cat(counts2).rshift(8).cast(dtypes.float32).div(2 ** 24)[:num]
     out = out.reshape(shape).cast(dtypes.default_float if dtype is None else dtype)
     out.requires_grad = kwargs.get("requires_grad")
     return out.contiguous()
@@ -506,12 +521,14 @@ class Tensor:
     if stop is None: stop, start = start, 0
     assert all(isinstance(s, (int, float)) for s in (start, stop, step)), f"symbolic arange not supported {start=}, {stop=}, {step=}"
     dtype = kwargs.pop("dtype", dtypes.default_float if any(isinstance(x, float) for x in (start, stop, step)) else dtypes.default_int)
+    # NOTE: this matches numpy, torch raises RuntimeError if stop-start and step have different signs
+    if (stop-start)/step <= 0: return Tensor([], dtype=dtype, **kwargs)
     return (Tensor.full((math.ceil((stop-start)/step),), step, dtype=dtype, **kwargs)._cumsum() + (start - step)).cast(dtype)

   @staticmethod
-  def eye(…
+  def eye(n:int, m:Optional[int]=None, **kwargs):
     """
-  …
+    Returns a 2-D tensor with `n` rows and `m` columns, with ones on the diagonal and zeros elsewhere.

     You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
     Additionally, all other keyword arguments are passed to the constructor of the tensor.
@@ -519,8 +536,13 @@ class Tensor:
     ```python exec="true" source="above" session="tensor" result="python"
     print(Tensor.eye(3).numpy())
     ```
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor.eye(2, 4).numpy())
+    ```
     """
-  …
+    if n < 0 or (m is not None and m < 0): raise ValueError(f"cannot have negative {n=}, {m=}")
+    return Tensor.ones((n,1),**kwargs).pad((None,(0,n))).flatten().shrink(((0,n*n),)).reshape(n,n)._slice((None,(0,n if m is None else m)))

   def full_like(self, fill_value:ConstType, **kwargs):
     """
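Two small behavior changes from the hunks above, in one sketch (the numpy comparison is implied by the NOTE in the diff):

```python
from tinygrad import Tensor

# a step that points away from stop now yields an empty tensor (numpy semantics)
print(Tensor.arange(5, 1).numpy())
# eye accepts an optional number of columns
print(Tensor.eye(2, 4).numpy())
```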
@@ -568,7 +590,7 @@ class Tensor:
   # ***** rng hlops *****

   @staticmethod
-  def randn(*shape, dtype:Optional[…
+  def randn(*shape, dtype:Optional[DTypeLike]=None, **kwargs) -> Tensor:
     """
     Creates a tensor with the given shape, filled with random values from a normal distribution with mean `0` and standard deviation `1`.
     If `dtype` is not specified, the default type is used.
@@ -721,10 +743,10 @@ class Tensor:
         yield node
     return list(_walk(self, set()))

-  def backward(self) -> Tensor:
+  def backward(self, gradient:Optional[Tensor]=None) -> Tensor:
     """
     Propagates the gradient of a tensor backwards through the computation graph.
-  …
+    If the 'gradient' argument is not provided, the tensor must be a scalar, and the gradient is implicitly set to 1.0.

     ```python exec="true" source="above" session="tensor" result="python"
     t = Tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
@@ -732,15 +754,20 @@ class Tensor:
     print(t.grad.numpy())
     ```
     """
-  …
+    if gradient is None:
+      assert self.shape == tuple(), "when no gradient is provided, backward must be called on a scalar tensor"
+      # fill in the first grad with one. don't use Tensor.ones because we don't need contiguous
+      # this is "implicit gradient creation"
+      gradient = Tensor(1.0, dtype=self.dtype, device=self.device, requires_grad=False)

-  …
-  …
-    self.grad = Tensor(1.0, dtype=self.dtype, device=self.device, requires_grad=False)
+    assert self.shape == gradient.shape, f"grad shape must match tensor shape, {gradient.shape!r} != {self.shape!r}"
+    self.grad = gradient

     for t0 in reversed(self._deepwalk()):
       if t0.grad is None: raise RuntimeError(f"tensor {t0} has no grad")
+      token = _METADATA.set(dataclasses.replace(md, backward=True) if (md := t0._ctx.metadata) is not None else None)
       grads = t0._ctx.backward(t0.grad.lazydata)
+      _METADATA.reset(token)
       grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None
                for g in ([grads] if len(t0._ctx.parents) == 1 else grads)]
       for t, g in zip(t0._ctx.parents, grads):
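A short sketch of the new explicit-gradient form of `backward`; per the added assert, the supplied gradient must match the tensor's shape:

```python
from tinygrad import Tensor

x = Tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x * 2
# non-scalar output: pass a gradient of the same shape instead of reducing to a scalar first
y.backward(Tensor([1.0, 1.0, 1.0]))
print(x.grad.numpy())  # [2. 2. 2.]
```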
@@ -1036,8 +1063,8 @@ class Tensor:
     ```
     """
     assert index.ndim == self.ndim, f"self.ndim must equal index.ndim, {self.ndim=}, {index.ndim=}"
-    assert all(s >= i for d,(s,i) in enumerate(zip(self.shape, index.shape)) if d != dim), "requires self.shape[d] >= index.shape[d] for all d != dim"
     dim = self._resolve_dim(dim)
+    assert all(s >= i for d,(s,i) in enumerate(zip(self.shape, index.shape)) if d != dim), "requires self.shape[d] >= index.shape[d] for all d != dim"
     index = index.to(self.device)
     x = self.shrink(tuple((0, i) if d != dim else None for d,i in enumerate(index.shape))).unsqueeze(-1).transpose(-1, dim)
     return ((index.unsqueeze(-1) == Tensor.arange(self.shape[dim], requires_grad=False, device=self.device)) * x).sum(-1, acc_dtype=self.dtype)
@@ -1079,6 +1106,19 @@ class Tensor:
     # checks for shapes and number of dimensions delegated to cat
     return self.unsqueeze(dim).cat(*[t.unsqueeze(dim) for t in args], dim=dim)

+  def repeat_interleave(self, repeats:int, dim:Optional[int]=None) -> Tensor:
+    """
+    Repeat elements of a tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([1, 2, 3])
+    print(t.repeat_interleave(2).numpy())
+    ```
+    """
+    x, dim = (self.flatten(), 0) if dim is None else (self, dim)
+    shp = x.shape
+    return x.reshape(*shp[:dim+1], 1, *shp[dim+1:]).expand(*shp[:dim+1], repeats, *shp[dim+1:]).reshape(*shp[:dim], shp[dim]*repeats, *shp[dim+1:])
+
   def repeat(self, repeats, *args) -> Tensor:
     """
     Repeats tensor number of times along each dimension specified by `repeats`.
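Usage sketch for the new `repeat_interleave`, with and without `dim`:

```python
from tinygrad import Tensor

t = Tensor([[1, 2], [3, 4]])
print(t.repeat_interleave(2).numpy())         # flattened: [1 1 2 2 3 3 4 4]
print(t.repeat_interleave(2, dim=0).numpy())  # rows doubled: [[1 2] [1 2] [3 4] [3 4]]
```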
@@ -1270,7 +1310,7 @@ class Tensor:
     ret = fxn.apply(self, axis=axis_)
     return ret if keepdim else ret.reshape(tuple(s for i,s in enumerate(self.shape) if i not in axis_))

-  def sum(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, acc_dtype:Optional[…
+  def sum(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, acc_dtype:Optional[DTypeLike]=None):
     """
     Sums the elements of the tensor along the specified axis or axes.

@@ -1343,6 +1383,50 @@ class Tensor:
     """
     return -((-self).max(axis=axis, keepdim=keepdim))

+  def any(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
+    """
+    Tests if any element evaluates to `True` along the specified axis or axes.
+
+    You can pass in `axis` and `keepdim` keyword arguments to control the reduce axis and whether the reduced dimensions are retained.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([[True, True], [True, False], [False, False]])
+    print(t.numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.any().numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.any(axis=0).numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.any(axis=1, keepdim=True).numpy())
+    ```
+    """
+    return self.bool().max(axis, keepdim)
+
+  def all(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
+    """
+    Tests if all element evaluates to `True` along the specified axis or axes.
+
+    You can pass in `axis` and `keepdim` keyword arguments to control the reduce axis and whether the reduced dimensions are retained.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([[True, True], [True, False], [False, False]])
+    print(t.numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.all().numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.all(axis=0).numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.all(axis=1, keepdim=True).numpy())
+    ```
+    """
+    return self.logical_not().any(axis, keepdim).logical_not()
+
   def mean(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
     """
     Returns the mean value of the tensor along the specified axis or axes.
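The new reductions compose from existing ops (`bool().max()`, and a double `logical_not` for `all`); a quick usage sketch:

```python
from tinygrad import Tensor

t = Tensor([[True, True], [True, False], [False, False]])
print(t.any(axis=1).numpy())  # [ True  True False]
print(t.all(axis=1).numpy())  # [ True False False]
```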
@@ -1548,7 +1632,7 @@ class Tensor:
     return (-self).argmax(axis=axis, keepdim=keepdim)

   @staticmethod
-  def einsum(formula:str, *raw_xs, acc_dtype:Optional[…
+  def einsum(formula:str, *raw_xs, acc_dtype:Optional[DTypeLike]=None) -> Tensor:
     """
     Sums the product of the elements of the input tensors according to a formula based on the Einstein summation convention.

@@ -1599,18 +1683,22 @@ class Tensor:
       # handle dilation
       xup = xup.shrink(tuple(noop_ + [(0,k*(i+d)) for k,i,d in zip(k_, i_, d_)])).reshape(noop_ + flatten((k,i+d) for k,i,d in zip(k_, i_, d_)))
       # handle stride
-      xup = xup.shrink(…
-  …
+      xup = xup.shrink(
+        tuple(noop_ + flatten(((0,k), (0,o*s)) for k,o,s in zip(k_, o_, s_)))).reshape(noop_ + flatten((k,o,s) for k,o,s in zip(k_, o_, s_)))
+      xup = xup.shrink(tuple(noop_ + flatten(((0,k), (0,o), (0,1)) for k,o in zip(k_, o_)))).reshape(noop_ + flatten((k,o) for k,o in zip(k_, o_)))
       # permute to move reduce to the end
       return xup.permute(*range(len(noop_)), *[len(noop_)+i*2+1 for i in range(len(i_))], *[len(noop_)+i*2 for i in range(len(i_))])
     # TODO: once the shapetracker can optimize well, remove this alternative implementation
     xup = self.pad(tuple(noop_ + [(0, max(0,o*s-i)) for i,o,s in zip(i_, o_, s_)])).shrink(tuple(noop_ + [(0,o*s) for o,s in zip(o_, s_)]))
     xup = xup.reshape(noop_ + flatten(((o,s) for o,s in zip(o_, s_))))
-    xup = xup.shrink(noop_ + flatten(((0,o), (0,k)) for o,k in zip(o_, k_)))
+    xup = xup.shrink(tuple(noop_ + flatten(((0,o), (0,k)) for o,k in zip(o_, k_))))
     return xup.permute(*range(len(noop_)), *[len(noop_)+i*2 for i in range(len(i_))], *[len(noop_)+i*2+1 for i in range(len(i_))])

+  def _padding2d(self, padding:Union[int, Tuple[int, ...]], dims:int) -> Sequence[int]:
+    return [padding]*2*dims if isinstance(padding, int) else (padding if len(padding) == 2*dims else [p for p in padding for _ in range(2)][::-1])
+
   # NOTE: these work for more than 2D
-  def avg_pool2d(self, kernel_size=(2,2), stride=None, dilation=1):
+  def avg_pool2d(self, kernel_size=(2,2), stride=None, dilation=1, padding=0, count_include_pad=True):
     """
     Applies average pooling over a tensor.

@@ -1622,11 +1710,15 @@ class Tensor:
     t = Tensor.arange(25).reshape(1, 1, 5, 5)
     print(t.avg_pool2d().numpy())
     ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.avg_pool2d(padding=1).numpy())
+    ```
     """
-  …
-    return …
+    padding_, axis = self._padding2d(padding, len(k_ := make_pair(kernel_size))), tuple(range(-len(k_), 0))
+    def pool(x:Tensor) -> Tensor: return x.pad2d(padding_)._pool(k_, stride if stride is not None else k_, dilation)
+    return pool(self).mean(axis=axis) if count_include_pad else pool(self).sum(axis=axis) / pool(self.ones_like()).sum(axis=axis)

-  def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1):
+  def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1, padding=0):
     """
     Applies max pooling over a tensor.

@@ -1638,11 +1730,14 @@ class Tensor:
     t = Tensor.arange(25).reshape(1, 1, 5, 5)
     print(t.max_pool2d().numpy())
     ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.max_pool2d(padding=1).numpy())
+    ```
     """
-  …
-    return self._pool(…
+    padding_ = self._padding2d(padding, len(k_ := make_pair(kernel_size)))
+    return self.pad2d(padding_, value=float('-inf'))._pool(k_, stride if stride is not None else k_, dilation).max(axis=tuple(range(-len(k_), 0)))

-  def conv2d(self, weight:Tensor, bias:…
+  def conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding=0, acc_dtype:DTypeLike|None=None) -> Tensor:
     """
     Applies a convolution over a tensor with a given `weight` and optional `bias`.

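A hedged sketch of the new pooling `padding` (and `count_include_pad` on the average), mirroring the docstring examples above:

```python
from tinygrad import Tensor

t = Tensor.arange(25).reshape(1, 1, 5, 5)
# zero padding is applied before pooling; count_include_pad picks the averaging denominator
print(t.avg_pool2d(padding=1).numpy())
print(t.avg_pool2d(padding=1, count_include_pad=False).numpy())
# max pooling pads with -inf so padded cells never win
print(t.max_pool2d(padding=1).numpy())
```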
@@ -1659,7 +1754,7 @@ class Tensor:
     (bs,cin_), (cout,cin), HW = self.shape[:2], weight.shape[:2], weight.shape[2:]
     assert groups*cin == cin_ and len(self.shape) == len(weight.shape), f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})" # noqa: E501
     if isinstance(padding, (tuple,list)): assert len(padding) == 2*len(HW) or len(padding) == len(HW), f"Expected padding of length {2*len(HW)} or {len(HW)}, but got {len(padding)} for tensor of shape {self.shape}" # noqa: E501
-    padding_ = …
+    padding_ = self._padding2d(padding, len(HW))

     # conv2d is a pooling op (with padding)
     x = self.pad2d(padding_)._pool(HW, stride, dilation) # (bs, groups*cin, oy, ox, H, W)
@@ -1729,7 +1824,7 @@ class Tensor:
     padding = flatten((((k-1)*d-p,(k-1)*d-p+op) for k,d,p,op in reversed(list(zip(HW, dilation, padding, output_padding)))))
     return x.conv2d(w.flatten(end_dim=1), groups=groups, bias=bias, dilation=dilation, padding=padding)

-  def dot(self, w:Tensor, acc_dtype:Optional[…
+  def dot(self, w:Tensor, acc_dtype:Optional[DTypeLike]=None) -> Tensor:
     """
     Performs dot product between two tensors.

@@ -1748,7 +1843,7 @@ class Tensor:
     w = w.reshape(*w.shape[0:-2], *[1]*min(n1-1, n2-1, 1), *w.shape[-min(n2, 2):]).transpose(-1, -min(n2, 2))
     return (x*w).sum(-1, acc_dtype=acc_dtype).cast(least_upper_dtype(x.dtype, w.dtype) if acc_dtype is None else acc_dtype)

-  def matmul(self, x:Tensor, reverse=False, acc_dtype:Optional[…
+  def matmul(self, x:Tensor, reverse=False, acc_dtype:Optional[DTypeLike]=None) -> Tensor:
     """
     Performs matrix multiplication between two tensors.

@@ -1850,6 +1945,33 @@ class Tensor:
     """
     return Tensor._tri(self.shape[-2], self.shape[-1], diagonal=diagonal+1, device=self.device, dtype=dtypes.bool).where(0, self).cast(self.dtype)

+  def interpolate(self, size:Tuple[int, ...], mode:str="linear", align_corners:bool=False) -> Tensor:
+    """
+    Downsamples or Upsamples to the input `size`, accepts 0 to N batch dimensions.
+
+    The interpolation algorithm is selected with `mode` which currently only supports `linear`.
+    To run `bilinear` or `trilinear`, pass in a 2D or 3D size.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([[1, 2, 3, 4], [21, 22, 23, 24], [41, 42, 43, 44]])
+    print(t.numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.interpolate(size=(2,3), mode="linear").numpy())
+    ```
+    """
+    assert isinstance(size, (tuple,list)) and all_int(size) and 0 < len(size) <= self.ndim, f"invalid {size=}"
+    assert mode == "linear", "only supports linear interpolate"
+    x, expand = self, list(self.shape)
+    for i in range(-len(size), 0):
+      scale = (self.shape[i] - int(align_corners)) / (size[i] - int(align_corners))
+      arr, reshape = Tensor.arange(size[i], dtype=dtypes.float32, device=self.device), [1] * self.ndim
+      index = (scale*arr if align_corners else (scale*(arr+0.5))-0.5).clip(0, self.shape[i]-1)
+      reshape[i] = expand[i] = size[i]
+      low, high, perc = [y.reshape(reshape).expand(expand) for y in (index.floor(), index.ceil(), index - index.floor())]
+      x = x.gather(i, low).lerp(x.gather(i, high), perc)
+    return x
+
   # ***** unary ops *****

   def logical_not(self):
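Usage sketch for `interpolate`, including `align_corners` (a 2-D `size` selects the bilinear path described in the docstring above):

```python
from tinygrad import Tensor

t = Tensor([[1, 2, 3, 4], [21, 22, 23, 24], [41, 42, 43, 44]])
print(t.interpolate(size=(2, 3), mode="linear").numpy())
print(t.interpolate(size=(2, 3), mode="linear", align_corners=True).numpy())
```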
@@ -1999,7 +2121,7 @@ class Tensor:
     Truncates the tensor element-wise.

     ```python exec="true" source="above" session="tensor" result="python"
-    print(Tensor([-3.…
+    print(Tensor([-3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]).trunc().numpy())
     ```
     """
     return self.cast(dtypes.int32).cast(self.dtype)
@@ -2008,7 +2130,7 @@ class Tensor:
     Rounds the tensor element-wise towards positive infinity.

     ```python exec="true" source="above" session="tensor" result="python"
-    print(Tensor([-3.…
+    print(Tensor([-3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]).ceil().numpy())
     ```
     """
     return (self > (b := self.trunc())).where(b+1, b)
@@ -2017,19 +2139,20 @@ class Tensor:
     Rounds the tensor element-wise towards negative infinity.

     ```python exec="true" source="above" session="tensor" result="python"
-    print(Tensor([-3.…
+    print(Tensor([-3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]).floor().numpy())
     ```
     """
     return (self < (b := self.trunc())).where(b-1, b)
   def round(self: Tensor) -> Tensor:
     """
-    Rounds the tensor element-wise.
+    Rounds the tensor element-wise with rounding half to even.

     ```python exec="true" source="above" session="tensor" result="python"
-    print(Tensor([-3.…
+    print(Tensor([-3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]).round().numpy())
     ```
     """
     return ((self > 0) == ((b := self.cast(dtypes.int32) / 2.0).cast(dtypes.int32) == b)).where((self - 0.5).ceil(), (self + 0.5).floor())
+
   def lerp(self, end: Tensor, weight: Union[Tensor, float]) -> Tensor:
     """
     Linearly interpolates between `self` and `end` by `weight`.
@@ -2039,6 +2162,7 @@ class Tensor:
     ```
     """
     return self + (end - self) * weight
+
   def square(self):
     """
     Squares the tensor element-wise.
@@ -2049,15 +2173,23 @@ class Tensor:
     ```
     """
     return self*self
-  def …
+  def clamp(self, min_=None, max_=None):
     """
     Clips (clamps) the values in the tensor between `min_` and `max_` element-wise.
+    If `min_` is `None`, there is no lower bound. If `max_` is None, there is no upper bound.

     ```python exec="true" source="above" session="tensor" result="python"
     print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).clip(-1, 1).numpy())
     ```
     """
-  …
+    if min_ is None and max_ is None: raise RuntimeError("at least one of 'min_' or 'max_' must not be None")
+    ret = self.maximum(min_) if min_ is not None else self
+    return ret.minimum(max_) if max_ is not None else ret
+  def clip(self, min_=None, max_=None):
+    """
+    Alias for `Tensor.clamp`.
+    """
+    return self.clamp(min_, max_)
   def sign(self):
     """
     Returns the sign of the tensor element-wise.
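A short sketch of the reworked `clamp` and its `clip` alias, including the one-sided forms the docstring now allows:

```python
from tinygrad import Tensor

t = Tensor([-3., -2., -1., 0., 1., 2., 3.])
print(t.clamp(-1, 1).numpy())   # both bounds
print(t.clamp(min_=0).numpy())  # lower bound only
print(t.clip(max_=2).numpy())   # clip is an alias for clamp
```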
@@ -2340,7 +2472,7 @@ class Tensor:
     x: Tensor = self
     if not isinstance(y, Tensor):
       # make y a Tensor
-      assert isinstance(y, (…
+      assert isinstance(y, (*get_args(ConstType), Node)), f"{type(y)=}, {y=}"
       if isinstance(x.dtype, ImageDType) or dtypes.is_float(x.dtype) or (dtypes.is_int(x.dtype) and isinstance(y, int)): y_dtype = x.dtype
       elif not isinstance(y, Node): y_dtype = dtypes.from_py(y)
       if isinstance(y, Node): y = Tensor.from_node(y, device=x.device)
@@ -2462,6 +2594,36 @@ class Tensor:
     """
     return F.Xor.apply(*self._broadcasted(x, reverse))

+  def bitwise_and(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
+    """
+    Compute the bit-wise AND of `self` and `x`.
+    Equivalent to `self & x`.
+    Supports broadcasting to a common shape, type promotion, and integer, boolean inputs.
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor([2, 5, 255]).bitwise_and(Tensor([3, 14, 16])).numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor([True, True, False, False]).bitwise_and(Tensor([True, False, True, False])).numpy())
+    ```
+    """
+    assert dtypes.is_int(self.dtype)
+    return F.BitwiseAnd.apply(*self._broadcasted(x, reverse))
+
+  def bitwise_or(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
+    """
+    Compute the bit-wise OR of `self` and `x`.
+    Equivalent to `self | x`.
+    Supports broadcasting to a common shape, type promotion, and integer, boolean inputs.
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor([2, 5, 255]).bitwise_or(Tensor([4, 4, 4])).numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor([True, True, False, False]).bitwise_or(Tensor([True, False, True, False])).numpy())
+    ```
+    """
+    assert dtypes.is_int(self.dtype)
+    return F.BitwiseOr.apply(*self._broadcasted(x, reverse))
+
   def lshift(self, x:int):
     """
     Computes left arithmetic shift of `self` by `x` bits. `self` must have unsigned dtype.
@@ -2586,6 +2748,8 @@ class Tensor:
   def __pow__(self, x) -> Tensor: return self.pow(x)
   def __truediv__(self, x) -> Tensor: return self.div(x)
   def __matmul__(self, x) -> Tensor: return self.matmul(x)
+  def __and__(self, x) -> Tensor: return self.bitwise_and(x)
+  def __or__(self, x) -> Tensor: return self.bitwise_or(x)
   def __xor__(self, x) -> Tensor: return self.xor(x)
   def __lshift__(self, x) -> Tensor: return self.lshift(x)
   def __rshift__(self, x) -> Tensor: return self.rshift(x)
@@ -2596,6 +2760,8 @@ class Tensor:
   def __rpow__(self, x) -> Tensor: return self.pow(x, True)
   def __rtruediv__(self, x) -> Tensor: return self.div(x, True)
   def __rmatmul__(self, x) -> Tensor: return self.matmul(x, True)
+  def __rand__(self, x) -> Tensor: return self.bitwise_and(x, True)
+  def __ror__(self, x) -> Tensor: return self.bitwise_or(x, True)
   def __rxor__(self, x) -> Tensor: return self.xor(x, True)

   def __iadd__(self, x) -> Tensor: return self.assign(self.add(x))
@@ -2604,6 +2770,8 @@ class Tensor:
   def __ipow__(self, x) -> Tensor: return self.assign(self.pow(x))
   def __itruediv__(self, x) -> Tensor: return self.assign(self.div(x))
   def __imatmul__(self, x) -> Tensor: return self.assign(self.matmul(x))
+  def __iand__(self, x) -> Tensor: return self.assign(self.bitwise_and(x))
+  def __ior__(self, x) -> Tensor: return self.assign(self.bitwise_or(x))
   def __ixor__(self, x) -> Tensor: return self.assign(self.xor(x))
   def __ilshift__(self, x) -> Tensor: return self.assign(self.lshift(x))
   def __irshift__(self, x) -> Tensor: return self.assign(self.rshift(x))
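With the dunder methods above, `&` and `|` (plus their reflected and in-place forms) route to the new `bitwise_and`/`bitwise_or`; a quick sketch on integer tensors:

```python
from tinygrad import Tensor

a, b = Tensor([2, 5, 255]), Tensor([3, 14, 16])
print((a & b).numpy())  # routes to bitwise_and
print((a | b).numpy())  # routes to bitwise_or
print((4 | a).numpy())  # reflected form uses bitwise_or(..., reverse=True)
```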
@@ -2788,13 +2956,87 @@ class Tensor:
     smoothing = label_smoothing * (log_probs.mean(-1) * loss_mask).sum()
     return -((1 - label_smoothing) * (log_probs * y).sum() + smoothing) / loss_mask.sum()

+  # ***** Tensor Properties *****
+
+  @property
+  def ndim(self) -> int:
+    """
+    Returns the number of dimensions in the tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([[1, 2], [3, 4]])
+    print(t.ndim)
+    ```
+    """
+    return len(self.shape)
+
+  def numel(self) -> sint:
+    """
+    Returns the total number of elements in the tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
+    print(t.numel())
+    ```
+    """
+    return prod(self.shape)
+
+  def element_size(self) -> int:
+    """
+    Returns the size in bytes of an individual element in the tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([5], dtype=dtypes.int16)
+    print(t.element_size())
+    ```
+    """
+    return self.dtype.itemsize
+
+  def nbytes(self) -> int:
+    """
+    Returns the total number of bytes of all elements in the tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([8, 9], dtype=dtypes.float)
+    print(t.nbytes())
+    ```
+    """
+    return self.numel() * self.element_size()
+
+  def is_floating_point(self) -> bool:
+    """
+    Returns `True` if the tensor contains floating point types, i.e. is one of `dtype.float64`, `dtype.float32`,
+    `dtype.float16`, `dtype.bfloat16`.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([8, 9], dtype=dtypes.float32)
+    print(t.is_floating_point())
+    ```
+    """
+    return dtypes.is_float(self.dtype)
+
+  def size(self, dim:Optional[int]=None) -> Union[sint, Tuple[sint, ...]]:
+    """
+    Return the size of the tensor. If `dim` is specified, return the length along dimension `dim`. Otherwise return the shape of the tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([[4, 5, 6], [7, 8, 9]])
+    print(t.size())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.size(dim=1))
+    ```
+    """
+    return self.shape if dim is None else self.shape[dim]
+
   # ***** cast ops *****

-  def llvm_bf16_cast(self, dtype:…
+  def llvm_bf16_cast(self, dtype:DTypeLike):
     # hack for devices that don't support bfloat16
     assert self.dtype == dtypes.bfloat16
     return self.to("LLVM").bitcast(dtypes.uint16).cast(dtypes.uint32).mul(1<<16).bitcast(dtypes.float32).cast(dtype)
-  …
+
+  def cast(self, dtype:DTypeLike) -> Tensor:
     """
     Casts `self` to the given `dtype`.

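A combined usage sketch of the new introspection helpers:

```python
from tinygrad import Tensor, dtypes

t = Tensor([[4, 5, 6], [7, 8, 9]], dtype=dtypes.int16)
print(t.ndim, t.numel(), t.element_size(), t.nbytes())  # 2 6 2 12
print(t.size(), t.size(dim=1))                          # (2, 3) 3
print(t.is_floating_point())                            # False
```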
@@ -2807,8 +3049,9 @@ class Tensor:
     print(t.dtype, t.numpy())
     ```
     """
-    return self if self.dtype == dtype else F.Cast.apply(self, dtype=…
-  …
+    return self if self.dtype == (dt:=to_dtype(dtype)) else F.Cast.apply(self, dtype=dt)
+
+  def bitcast(self, dtype:DTypeLike) -> Tensor:
     """
     Bitcasts `self` to the given `dtype` of the same itemsize.

@@ -2824,7 +3067,8 @@ class Tensor:
     ```
     """
     if self.requires_grad: raise RuntimeError("can't backprop through bitcast")
-    return F.Cast.apply(self, dtype=…
+    return F.Cast.apply(self, dtype=dt, bitcast=True) if self.dtype != (dt:=to_dtype(dtype)) else self
+
   def float(self) -> Tensor:
     """
     Convenience method to cast `self` to a `float32` Tensor.
@@ -2839,6 +3083,7 @@ class Tensor:
     ```
     """
     return self.cast(dtypes.float32)
+
   def half(self) -> Tensor:
     """
     Convenience method to cast `self` to a `float16` Tensor.
@@ -2854,15 +3099,35 @@ class Tensor:
     """
     return self.cast(dtypes.float16)

-  …
+  def int(self) -> Tensor:
+    """
+    Convenience method to cast `self` to a `int32` Tensor.

-  …
-  …
-  …
-  …
-  …
-  …
-  …
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([-1.5, -0.5, 0.0, 0.5, 1.5])
+    print(t.dtype, t.numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = t.int()
+    print(t.dtype, t.numpy())
+    ```
+    """
+    return self.cast(dtypes.int32)
+
+  def bool(self) -> Tensor:
+    """
+    Convenience method to cast `self` to a `bool` Tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([-1, 0, 1])
+    print(t.dtype, t.numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = t.bool()
+    print(t.dtype, t.numpy())
+    ```
+    """
+    return self.cast(dtypes.bool)

   # *** image Tensor function replacements ***

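`int()` and `bool()` join `float()`/`half()` as cast shorthands; a quick sketch:

```python
from tinygrad import Tensor

t = Tensor([-1.5, -0.5, 0.0, 0.5, 1.5])
print(t.int().numpy())   # cast to int32 truncates toward zero: [-1  0  0  0  1]
print(t.bool().numpy())  # nonzero values become True
```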
@@ -2960,3 +3225,40 @@ def custom_random(out:Buffer):
   if out.dtype == dtypes.half: rng_np_buffer = (rng.integers(low=0, high=2047, size=out.size) / 2048).astype(np.half, copy=False)
   else: rng_np_buffer = rng.random(size=out.size, dtype=np.float32).astype(dtype=_to_np_dtype(out.dtype), copy=False)
   out.copyin(rng_np_buffer.data)
+
+def _metadata_wrapper(fn):
+  def _wrapper(*args, **kwargs):
+    if _METADATA.get() is not None: return fn(*args, **kwargs)
+
+    if TRACEMETA >= 2:
+      caller_frame = sys._getframe(frame := 1)
+      caller_module = caller_frame.f_globals.get("__name__", None)
+      caller_func = caller_frame.f_code.co_name
+      if caller_module is None: return fn(*args, **kwargs)
+
+      # if its called from nn we want to step up frames until we are out of nn
+      while caller_module.startswith("tinygrad.nn") and "optim" not in caller_module:
+        caller_frame = sys._getframe(frame := frame + 1)
+        caller_module = caller_frame.f_globals.get("__name__", None)
+        if caller_module is None: return fn(*args, **kwargs)
+
+      # if its called from a lambda in tinygrad we want to look two more frames up
+      if caller_module.startswith("tinygrad") and caller_func == "<lambda>": caller_frame = sys._getframe(frame := frame + 2)
+      caller_module = caller_frame.f_globals.get("__name__", None)
+      if caller_module is None: return fn(*args, **kwargs)
+      caller_func = caller_frame.f_code.co_name
+      caller_lineno = caller_frame.f_lineno
+
+      caller = f"{caller_module}:{caller_lineno}::{caller_func}"
+    else: caller = ""
+
+    token = _METADATA.set(Metadata(name=fn.__name__, caller=caller))
+    ret = fn(*args, **kwargs)
+    _METADATA.reset(token)
+    return ret
+  return _wrapper
+
+if TRACEMETA >= 1:
+  for name, fn in inspect.getmembers(Tensor, inspect.isfunction):
+    if name in ["__class__", "__init__", "__new__", "__repr__", "backward", "sequential"]: continue
+    setattr(Tensor, name, functools.wraps(fn)(_metadata_wrapper(fn)))