lucid-dl 2.11.5__tar.gz → 2.12.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153) hide show
  1. {lucid_dl-2.11.5/lucid_dl.egg-info → lucid_dl-2.12.1}/PKG-INFO +5 -1
  2. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/README.md +4 -0
  3. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/__init__.py +2 -2
  4. lucid_dl-2.12.1/lucid/_tensor/__init__.py +11 -0
  5. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/_tensor/base.py +2 -0
  6. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/_tensor/tensor.py +192 -3
  7. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/_util/__init__.py +14 -5
  8. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/_util/func.py +73 -0
  9. lucid_dl-2.12.1/lucid/datasets/__init__.py +2 -0
  10. lucid_dl-2.12.1/lucid/datasets/cifar.py +365 -0
  11. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/models/__init__.py +1 -0
  12. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/models/imgclf/vit.py +6 -4
  13. lucid_dl-2.12.1/lucid/models/seqclf/__init__.py +1 -0
  14. lucid_dl-2.12.1/lucid/models/seqclf/bert.py +31 -0
  15. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/nn/__init__.py +1 -1
  16. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/nn/_kernel/embedding.py +19 -16
  17. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/nn/functional/_util.py +40 -8
  18. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/nn/modules/attention.py +58 -6
  19. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/nn/modules/rnn.py +263 -46
  20. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/nn/modules/sparse.py +16 -1
  21. lucid_dl-2.12.1/lucid/nn/utils/__init__.py +2 -0
  22. lucid_dl-2.11.5/lucid/nn/util.py → lucid_dl-2.12.1/lucid/nn/utils/_grad.py +21 -2
  23. lucid_dl-2.12.1/lucid/nn/utils/rnn.py +237 -0
  24. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/transforms/image.py +2 -2
  25. {lucid_dl-2.11.5 → lucid_dl-2.12.1/lucid_dl.egg-info}/PKG-INFO +5 -1
  26. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid_dl.egg-info/SOURCES.txt +6 -2
  27. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/setup.py +1 -1
  28. lucid_dl-2.11.5/lucid/_tensor/__init__.py +0 -1
  29. lucid_dl-2.11.5/lucid/datasets/__init__.py +0 -3
  30. lucid_dl-2.11.5/lucid/datasets/cifar.py +0 -112
  31. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/LICENSE +0 -0
  32. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/_backend/__init__.py +0 -0
  33. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/_backend/core.py +0 -0
  34. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/_backend/metal.py +0 -0
  35. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/_func/__init__.py +0 -0
  36. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/_func/bfunc.py +0 -0
  37. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/_func/gfunc.py +0 -0
  38. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/_func/ufunc.py +0 -0
  39. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/_fusion/__init__.py +0 -0
  40. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/_fusion/base.py +0 -0
  41. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/_fusion/func.py +0 -0
  42. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/autograd/__init__.py +0 -0
  43. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/data/__init__.py +0 -0
  44. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/data/_base.py +0 -0
  45. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/data/_util.py +0 -0
  46. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/datasets/_base.py +0 -0
  47. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/datasets/mnist.py +0 -0
  48. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/einops/__init__.py +0 -0
  49. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/einops/_func.py +0 -0
  50. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/error.py +0 -0
  51. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/linalg/__init__.py +0 -0
  52. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/linalg/_func.py +0 -0
  53. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/models/imgclf/__init__.py +0 -0
  54. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/models/imgclf/alex.py +0 -0
  55. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/models/imgclf/coatnet.py +0 -0
  56. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/models/imgclf/convnext.py +0 -0
  57. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/models/imgclf/crossvit.py +0 -0
  58. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/models/imgclf/cspnet.py +0 -0
  59. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/models/imgclf/cvt.py +0 -0
  60. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/models/imgclf/dense.py +0 -0
  61. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/models/imgclf/efficient.py +0 -0
  62. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/models/imgclf/efficientformer.py +0 -0
  63. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/models/imgclf/inception.py +0 -0
  64. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/models/imgclf/inception_next.py +0 -0
  65. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/models/imgclf/inception_res.py +0 -0
  66. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/models/imgclf/lenet.py +0 -0
  67. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/models/imgclf/maxvit.py +0 -0
  68. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/models/imgclf/mobile.py +0 -0
  69. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/models/imgclf/pvt.py +0 -0
  70. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/models/imgclf/resnest.py +0 -0
  71. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/models/imgclf/resnet.py +0 -0
  72. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/models/imgclf/resnext.py +0 -0
  73. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/models/imgclf/senet.py +0 -0
  74. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/models/imgclf/sknet.py +0 -0
  75. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/models/imgclf/swin.py +0 -0
  76. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/models/imgclf/vgg.py +0 -0
  77. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/models/imgclf/xception.py +0 -0
  78. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/models/imgclf/zfnet.py +0 -0
  79. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/models/imggen/__init__.py +0 -0
  80. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/models/imggen/ddpm.py +0 -0
  81. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/models/imggen/ncsn.py +0 -0
  82. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/models/imggen/vae.py +0 -0
  83. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/models/objdet/__init__.py +0 -0
  84. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/models/objdet/detr.py +0 -0
  85. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/models/objdet/efficientdet.py +0 -0
  86. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/models/objdet/fast_rcnn.py +0 -0
  87. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/models/objdet/faster_rcnn.py +0 -0
  88. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/models/objdet/rcnn.py +0 -0
  89. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/models/objdet/util.py +0 -0
  90. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/models/objdet/yolo/__init__.py +0 -0
  91. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/models/objdet/yolo/yolo_v1.py +0 -0
  92. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/models/objdet/yolo/yolo_v2.py +0 -0
  93. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/models/objdet/yolo/yolo_v3.py +0 -0
  94. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/models/objdet/yolo/yolo_v4.py +0 -0
  95. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/models/seq2seq/__init__.py +0 -0
  96. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/models/seq2seq/transformer.py +0 -0
  97. /lucid_dl-2.11.5/lucid/models/util.py → /lucid_dl-2.12.1/lucid/models/utils.py +0 -0
  98. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/nn/_kernel/__init__.py +0 -0
  99. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/nn/_kernel/activation.py +0 -0
  100. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/nn/_kernel/attention.py +0 -0
  101. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/nn/_kernel/conv.py +0 -0
  102. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/nn/_kernel/loss.py +0 -0
  103. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/nn/_kernel/norm.py +0 -0
  104. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/nn/_kernel/pool.py +0 -0
  105. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/nn/functional/__init__.py +0 -0
  106. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/nn/functional/_activation.py +0 -0
  107. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/nn/functional/_attention.py +0 -0
  108. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/nn/functional/_conv.py +0 -0
  109. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/nn/functional/_drop.py +0 -0
  110. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/nn/functional/_linear.py +0 -0
  111. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/nn/functional/_loss.py +0 -0
  112. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/nn/functional/_norm.py +0 -0
  113. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/nn/functional/_pool.py +0 -0
  114. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/nn/functional/_spatial.py +0 -0
  115. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/nn/fused.py +0 -0
  116. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/nn/init/__init__.py +0 -0
  117. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/nn/init/_dist.py +0 -0
  118. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/nn/module.py +0 -0
  119. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/nn/modules/__init__.py +0 -0
  120. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/nn/modules/activation.py +0 -0
  121. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/nn/modules/conv.py +0 -0
  122. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/nn/modules/drop.py +0 -0
  123. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/nn/modules/einops.py +0 -0
  124. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/nn/modules/linear.py +0 -0
  125. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/nn/modules/loss.py +0 -0
  126. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/nn/modules/norm.py +0 -0
  127. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/nn/modules/pool.py +0 -0
  128. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/nn/modules/transformer.py +0 -0
  129. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/nn/modules/vision.py +0 -0
  130. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/nn/parameter.py +0 -0
  131. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/optim/__init__.py +0 -0
  132. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/optim/_base.py +0 -0
  133. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/optim/ada.py +0 -0
  134. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/optim/adam.py +0 -0
  135. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/optim/lr_scheduler/__init__.py +0 -0
  136. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/optim/lr_scheduler/_base.py +0 -0
  137. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/optim/lr_scheduler/_schedulers.py +0 -0
  138. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/optim/prop.py +0 -0
  139. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/optim/sgd.py +0 -0
  140. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/port.py +0 -0
  141. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/random/__init__.py +0 -0
  142. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/random/_func.py +0 -0
  143. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/transforms/__init__.py +0 -0
  144. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/transforms/_base.py +0 -0
  145. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/types.py +0 -0
  146. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/visual/__init__.py +0 -0
  147. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/visual/mermaid.py +0 -0
  148. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/weights/__init__.py +0 -0
  149. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid/weights/__init__.pyi +0 -0
  150. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid_dl.egg-info/dependency_links.txt +0 -0
  151. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid_dl.egg-info/requires.txt +0 -0
  152. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/lucid_dl.egg-info/top_level.txt +0 -0
  153. {lucid_dl-2.11.5 → lucid_dl-2.12.1}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: lucid-dl
3
- Version: 2.11.5
3
+ Version: 2.12.1
4
4
  Summary: Lumerico's Comprehensive Interface for Deep Learning
5
5
  Home-page: https://github.com/ChanLumerico/lucid
6
6
  Author: ChanLumerico
@@ -48,6 +48,10 @@ Whether you're a student, educator, or an advanced researcher seeking to demysti
48
48
 
49
49
  ### 🔥 What's New
50
50
 
51
+ - New Tensor utility function added: `lucid.Tensor.expand`
52
+
53
+ - Added Type-Generic Tensors: `lucid.LongTensor`, `lucid.DoubleTensor`, etc.
54
+
51
55
  - Added new visual tool: `lucid.visual.build_tensor_mermaid_chart` which builds a Mermaid chart of given tensor's computatoinal graph
52
56
 
53
57
  - Added additional `nn.Module` hooks for richer introspection during training:
@@ -20,6 +20,10 @@ Whether you're a student, educator, or an advanced researcher seeking to demysti
20
20
 
21
21
  ### 🔥 What's New
22
22
 
23
+ - New Tensor utility function added: `lucid.Tensor.expand`
24
+
25
+ - Added Type-Generic Tensors: `lucid.LongTensor`, `lucid.DoubleTensor`, etc.
26
+
23
27
  - Added new visual tool: `lucid.visual.build_tensor_mermaid_chart` which builds a Mermaid chart of given tensor's computatoinal graph
24
28
 
25
29
  - Added additional `nn.Module` hooks for richer introspection during training:
@@ -25,7 +25,7 @@ import json
25
25
  import math
26
26
  import numpy as np
27
27
 
28
- from lucid._tensor import Tensor
28
+ from lucid._tensor import *
29
29
  from lucid._func import *
30
30
  from lucid._util import *
31
31
 
@@ -308,7 +308,7 @@ def register_model(func: _ModuleReturnFunc) -> _ModuleReturnFunc:
308
308
 
309
309
 
310
310
  def _conv_view_limit_mb() -> int:
311
- from lucid._kernel import conv as _conv_kernel
311
+ from lucid.nn._kernel import conv as _conv_kernel
312
312
 
313
313
  return _conv_kernel.get_conv_view_limit_mb()
314
314
 
@@ -0,0 +1,11 @@
1
+ from lucid._tensor.tensor import (
2
+ Tensor,
3
+ LongTensor,
4
+ IntTensor,
5
+ ShortTensor,
6
+ CharTensor,
7
+ HalfTensor,
8
+ FloatTensor,
9
+ DoubleTensor,
10
+ BoolTensor,
11
+ )
@@ -108,6 +108,8 @@ class _TensorBase:
108
108
 
109
109
  def broadcast_to(self, shape: _ShapeLike) -> Self: ...
110
110
 
111
+ def expand(self, *sizes: int | _ShapeLike) -> Self: ...
112
+
111
113
  def chunk(self, chunks: int, axis: int = 0) -> tuple[Self, ...]: ...
112
114
 
113
115
  def swapaxes(self, axis1: int, axis2: int) -> Self: ...
@@ -1,4 +1,15 @@
1
- from typing import Callable, Iterator, Optional, Self, SupportsIndex, Any, overload
1
+ from typing import (
2
+ Callable,
3
+ Iterator,
4
+ Optional,
5
+ Self,
6
+ SupportsIndex,
7
+ Any,
8
+ overload,
9
+ Generic,
10
+ TypeVar,
11
+ ClassVar,
12
+ )
2
13
  from types import NoneType
3
14
  from collections import deque
4
15
 
@@ -22,15 +33,32 @@ from lucid._backend.core import BackwardOperation, Operation, noop
22
33
  from lucid._backend.metal import mx, parse_mlx_indexing, check_metal_availability
23
34
 
24
35
 
36
+ __all__ = [
37
+ "Tensor",
38
+ "FloatTensor",
39
+ "DoubleTensor",
40
+ "HalfTensor",
41
+ "CharTensor",
42
+ "ShortTensor",
43
+ "IntTensor",
44
+ "LongTensor",
45
+ "BoolTensor",
46
+ ]
47
+
48
+
49
+ DType = TypeVar("DType", bound=Numeric | bool)
50
+
25
51
  _HookType = Callable[["Tensor", _NumPyArray | _MLXArray], None]
26
52
 
27
53
  _dtype_map = {int: types.Int64, float: types.Float64, complex: types.Complex64}
28
54
 
29
55
 
30
- class Tensor(_TensorBase, _TensorInplace):
56
+ class Tensor(Generic[DType], _TensorBase, _TensorInplace):
57
+ _fixed_dtype: ClassVar[Numeric | None] = None
58
+
31
59
  def __init__(
32
60
  self,
33
- data: _ArrayOrScalar | _MLXArray,
61
+ data: _ArrayOrScalar,
34
62
  requires_grad: bool = False,
35
63
  keep_grad: bool = False,
36
64
  dtype: _BuiltinNumeric | Numeric | None = None,
@@ -39,6 +67,9 @@ class Tensor(_TensorBase, _TensorInplace):
39
67
  self._is_free = False
40
68
  self._is_bool_tensor = False
41
69
 
70
+ if self._fixed_dtype is not None:
71
+ dtype = self._fixed_dtype
72
+
42
73
  if dtype is bool:
43
74
  self._is_bool_tensor = True
44
75
  dtype = None
@@ -285,6 +316,12 @@ class Tensor(_TensorBase, _TensorInplace):
285
316
  dtype = device_or_dtype
286
317
  return self.astype(dtype)
287
318
 
319
+ def cpu(self) -> Self:
320
+ return self.to(device="cpu")
321
+
322
+ def gpu(self) -> Self:
323
+ return self.to(device="gpu")
324
+
288
325
  def is_cpu(self) -> bool:
289
326
  return self.device == "cpu"
290
327
 
@@ -480,3 +517,155 @@ class Tensor(_TensorBase, _TensorInplace):
480
517
 
481
518
  def bool(self) -> Self:
482
519
  return self.astype(bool)
520
+
521
+
522
+ class LongTensor(Tensor[types.Int64]):
523
+ _fixed_dtype: ClassVar[Numeric | None] = types.Int64
524
+
525
+ def __init__(
526
+ self,
527
+ data: _ArrayOrScalar,
528
+ requires_grad: bool = False,
529
+ keep_grad: bool = False,
530
+ device: _DeviceType = "cpu",
531
+ ) -> None:
532
+ super().__init__(
533
+ data=data,
534
+ requires_grad=requires_grad,
535
+ keep_grad=keep_grad,
536
+ dtype=types.Int64,
537
+ device=device,
538
+ )
539
+
540
+
541
+ class IntTensor(Tensor[types.Int32]):
542
+ _fixed_dtype: ClassVar[Numeric | None] = types.Int32
543
+
544
+ def __init__(
545
+ self,
546
+ data: _ArrayOrScalar,
547
+ requires_grad: bool = False,
548
+ keep_grad: bool = False,
549
+ device: _DeviceType = "cpu",
550
+ ) -> None:
551
+ super().__init__(
552
+ data=data,
553
+ requires_grad=requires_grad,
554
+ keep_grad=keep_grad,
555
+ dtype=types.Int32,
556
+ device=device,
557
+ )
558
+
559
+
560
+ class ShortTensor(Tensor[types.Int16]):
561
+ _fixed_dtype: ClassVar[Numeric | None] = types.Int16
562
+
563
+ def __init__(
564
+ self,
565
+ data: _ArrayOrScalar,
566
+ requires_grad: bool = False,
567
+ keep_grad: bool = False,
568
+ device: _DeviceType = "cpu",
569
+ ) -> None:
570
+ super().__init__(
571
+ data=data,
572
+ requires_grad=requires_grad,
573
+ keep_grad=keep_grad,
574
+ dtype=types.Int16,
575
+ device=device,
576
+ )
577
+
578
+
579
+ class CharTensor(Tensor[types.Int8]):
580
+ _fixed_dtype: ClassVar[Numeric | None] = types.Int8
581
+
582
+ def __init__(
583
+ self,
584
+ data: _ArrayOrScalar,
585
+ requires_grad: bool = False,
586
+ keep_grad: bool = False,
587
+ device: _DeviceType = "cpu",
588
+ ) -> None:
589
+ super().__init__(
590
+ data=data,
591
+ requires_grad=requires_grad,
592
+ keep_grad=keep_grad,
593
+ dtype=types.Int8,
594
+ device=device,
595
+ )
596
+
597
+
598
+ class HalfTensor(Tensor[types.Float16]):
599
+ _fixed_dtype: ClassVar[Numeric | None] = types.Float16
600
+
601
+ def __init__(
602
+ self,
603
+ data: _ArrayOrScalar,
604
+ requires_grad: bool = False,
605
+ keep_grad: bool = False,
606
+ device: _DeviceType = "cpu",
607
+ ) -> None:
608
+ super().__init__(
609
+ data=data,
610
+ requires_grad=requires_grad,
611
+ keep_grad=keep_grad,
612
+ dtype=types.Float16,
613
+ device=device,
614
+ )
615
+
616
+
617
+ class FloatTensor(Tensor[types.Float32]):
618
+ _fixed_dtype: ClassVar[Numeric | None] = types.Float32
619
+
620
+ def __init__(
621
+ self,
622
+ data: _ArrayOrScalar,
623
+ requires_grad: bool = False,
624
+ keep_grad: bool = False,
625
+ device: _DeviceType = "cpu",
626
+ ) -> None:
627
+ super().__init__(
628
+ data=data,
629
+ requires_grad=requires_grad,
630
+ keep_grad=keep_grad,
631
+ dtype=types.Float32,
632
+ device=device,
633
+ )
634
+
635
+
636
+ class DoubleTensor(Tensor[types.Float64]):
637
+ _fixed_dtype: ClassVar[Numeric | None] = types.Float64
638
+
639
+ def __init__(
640
+ self,
641
+ data: _ArrayOrScalar,
642
+ requires_grad: bool = False,
643
+ keep_grad: bool = False,
644
+ device: _DeviceType = "cpu",
645
+ ) -> None:
646
+ super().__init__(
647
+ data=data,
648
+ requires_grad=requires_grad,
649
+ keep_grad=keep_grad,
650
+ dtype=types.Float64,
651
+ device=device,
652
+ )
653
+
654
+
655
+ class BoolTensor(Tensor[bool]):
656
+ _fixed_dtype: ClassVar[Numeric | None] = None
657
+
658
+ def __init__(
659
+ self,
660
+ data: _ArrayOrScalar,
661
+ requires_grad: bool = False,
662
+ keep_grad: bool = False,
663
+ device: _DeviceType = "cpu",
664
+ ) -> None:
665
+ super().__init__(
666
+ data=data,
667
+ requires_grad=requires_grad,
668
+ keep_grad=keep_grad,
669
+ dtype=bool,
670
+ device=device,
671
+ )
@@ -9,11 +9,11 @@ from lucid._util import func
9
9
  # fmt: off
10
10
  __all__ = [
11
11
  "reshape", "squeeze", "unsqueeze", "expand_dims", "ravel", "stack", "hstack",
12
- "vstack", "concatenate", "pad", "repeat", "tile", "flatten", "meshgrid",
13
- "split", "tril", "triu", "broadcast_to", "chunk", "masked_fill", "roll",
14
- "unbind", "sort", "nonzero", "unique", "topk", "argsort", "histogramdd",
15
- "histogram", "histogram2d", "where", "nonzero", "argmin", "argmax",
16
- "diagonal",
12
+ "vstack", "concatenate", "pad", "repeat", "tile", "flatten", "meshgrid",
13
+ "split", "tril", "triu", "broadcast_to", "expand", "chunk", "masked_fill",
14
+ "roll", "unbind", "sort", "nonzero", "unique", "topk", "argsort",
15
+ "histogramdd", "histogram", "histogram2d", "where", "nonzero", "argmin",
16
+ "argmax", "diagonal",
17
17
  ]
18
18
  # fmt: on
19
19
 
@@ -106,6 +106,14 @@ def broadcast_to(a: Tensor, /, shape: _ShapeLike) -> Tensor:
106
106
  return func.broadcast_to(shape)(a)
107
107
 
108
108
 
109
+ def expand(a: Tensor, /, *sizes: int | _ShapeLike) -> Tensor:
110
+ if len(sizes) == 1 and isinstance(sizes[0], (tuple, list)):
111
+ shape = sizes[0]
112
+ else:
113
+ shape = sizes
114
+ return func.expand(shape)(a)
115
+
116
+
109
117
  def chunk(a: Tensor, /, chunks: int, axis: int = 0) -> tuple[Tensor, ...]:
110
118
  return func.chunk(chunks, axis)(a)
111
119
 
@@ -257,6 +265,7 @@ Tensor.split = split
257
265
  Tensor.tril = tril
258
266
  Tensor.triu = triu
259
267
  Tensor.broadcast_to = broadcast_to
268
+ Tensor.expand = expand
260
269
  Tensor.chunk = chunk
261
270
  Tensor.masked_fill = masked_fill
262
271
  Tensor.roll = roll
@@ -605,6 +605,79 @@ class broadcast_to(Operation):
605
605
  return self.result.grad.reshape(self.original_shape)
606
606
 
607
607
 
608
+ class expand(Operation):
609
+ def __init__(self, shape: _ShapeLike) -> None:
610
+ super().__init__()
611
+ self.shape = shape
612
+
613
+ def _resolve_shape(self, input_shape: tuple[int, ...]) -> tuple[int, ...]:
614
+ shape = tuple(int(dim) for dim in self.shape)
615
+ if len(shape) == 0:
616
+ raise ValueError("expand() expects at least one dimension.")
617
+
618
+ if len(shape) < len(input_shape):
619
+ raise ValueError(
620
+ "expand() cannot shrink the number of dimensions from "
621
+ f"{len(input_shape)} to {len(shape)}."
622
+ )
623
+
624
+ ndim_diff = len(shape) - len(input_shape)
625
+ padded_input = (1,) * ndim_diff + input_shape
626
+
627
+ resolved: list[int] = []
628
+ for axis, (target_dim, input_dim) in enumerate(zip(shape, padded_input)):
629
+ if target_dim == -1:
630
+ if axis < ndim_diff:
631
+ raise ValueError(
632
+ "expand() cannot use -1 in a leading, "
633
+ "non-existing dimension."
634
+ )
635
+ target_dim = input_dim
636
+
637
+ elif target_dim < -1:
638
+ raise ValueError("expand() size must be >= -1.")
639
+
640
+ if input_dim == target_dim:
641
+ resolved.append(target_dim)
642
+ elif input_dim == 1 and target_dim >= 0:
643
+ resolved.append(target_dim)
644
+ else:
645
+ raise ValueError(
646
+ "expand() cannot expand dimension "
647
+ f"{axis} from {input_dim} to {target_dim}."
648
+ )
649
+
650
+ return tuple(resolved)
651
+
652
+ @unary_func_op()
653
+ def cpu(self, a: Tensor) -> _FuncOpReturnType:
654
+ self.original_shape = a.shape
655
+ self.expanded_shape = self._resolve_shape(a.shape)
656
+
657
+ self.result = Tensor(np.broadcast_to(a.data, self.expanded_shape))
658
+ return self.result, self.__grad__
659
+
660
+ @unary_func_op(device="gpu")
661
+ def gpu(self, a: Tensor) -> _FuncOpReturnType:
662
+ self.original_shape = a.shape
663
+ self.expanded_shape = self._resolve_shape(a.shape)
664
+
665
+ self.result = Tensor(mx.broadcast_to(a.data, self.expanded_shape))
666
+ return self.result, self.__grad__
667
+
668
+ def __grad__(self) -> _GradType:
669
+ input_shape = self.original_shape
670
+ ndim_diff = len(self.expanded_shape) - len(input_shape)
671
+ if ndim_diff > 0:
672
+ input_shape = (1,) * ndim_diff + input_shape
673
+
674
+ for axis, (in_dim, out_dim) in enumerate(zip(input_shape, self.expanded_shape)):
675
+ if in_dim == 1 and out_dim > 1:
676
+ self.result.grad = self.result.grad.sum(axis=axis, keepdims=True)
677
+
678
+ return self.result.grad.reshape(self.original_shape)
679
+
680
+
608
681
  class chunk(Operation):
609
682
  def __init__(self, chunks: int, axis: int) -> None:
610
683
  super().__init__()
@@ -0,0 +1,2 @@
1
+ from .mnist import *
2
+ from .cifar import *