lucid-dl 2.12.0__tar.gz → 2.12.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (151)
  1. {lucid_dl-2.12.0/lucid_dl.egg-info → lucid_dl-2.12.1}/PKG-INFO +5 -1
  2. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/README.md +4 -0
  3. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/__init__.py +2 -2
  4. lucid_dl-2.12.1/lucid/_tensor/__init__.py +11 -0
  5. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/_tensor/base.py +2 -0
  6. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/_tensor/tensor.py +192 -3
  7. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/_util/__init__.py +14 -5
  8. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/_util/func.py +73 -0
  9. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/__init__.py +1 -0
  10. lucid_dl-2.12.1/lucid/models/seqclf/__init__.py +1 -0
  11. lucid_dl-2.12.1/lucid/models/seqclf/bert.py +31 -0
  12. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/_kernel/embedding.py +19 -16
  13. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/functional/_util.py +40 -8
  14. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/modules/attention.py +58 -6
  15. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/modules/rnn.py +133 -21
  16. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/modules/sparse.py +16 -1
  17. {lucid_dl-2.12.0 → lucid_dl-2.12.1/lucid_dl.egg-info}/PKG-INFO +5 -1
  18. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid_dl.egg-info/SOURCES.txt +2 -0
  19. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/setup.py +1 -1
  20. lucid_dl-2.12.0/lucid/_tensor/__init__.py +0 -1
  21. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/LICENSE +0 -0
  22. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/_backend/__init__.py +0 -0
  23. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/_backend/core.py +0 -0
  24. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/_backend/metal.py +0 -0
  25. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/_func/__init__.py +0 -0
  26. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/_func/bfunc.py +0 -0
  27. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/_func/gfunc.py +0 -0
  28. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/_func/ufunc.py +0 -0
  29. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/_fusion/__init__.py +0 -0
  30. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/_fusion/base.py +0 -0
  31. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/_fusion/func.py +0 -0
  32. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/autograd/__init__.py +0 -0
  33. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/data/__init__.py +0 -0
  34. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/data/_base.py +0 -0
  35. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/data/_util.py +0 -0
  36. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/datasets/__init__.py +0 -0
  37. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/datasets/_base.py +0 -0
  38. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/datasets/cifar.py +0 -0
  39. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/datasets/mnist.py +0 -0
  40. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/einops/__init__.py +0 -0
  41. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/einops/_func.py +0 -0
  42. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/error.py +0 -0
  43. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/linalg/__init__.py +0 -0
  44. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/linalg/_func.py +0 -0
  45. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imgclf/__init__.py +0 -0
  46. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imgclf/alex.py +0 -0
  47. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imgclf/coatnet.py +0 -0
  48. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imgclf/convnext.py +0 -0
  49. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imgclf/crossvit.py +0 -0
  50. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imgclf/cspnet.py +0 -0
  51. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imgclf/cvt.py +0 -0
  52. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imgclf/dense.py +0 -0
  53. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imgclf/efficient.py +0 -0
  54. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imgclf/efficientformer.py +0 -0
  55. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imgclf/inception.py +0 -0
  56. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imgclf/inception_next.py +0 -0
  57. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imgclf/inception_res.py +0 -0
  58. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imgclf/lenet.py +0 -0
  59. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imgclf/maxvit.py +0 -0
  60. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imgclf/mobile.py +0 -0
  61. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imgclf/pvt.py +0 -0
  62. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imgclf/resnest.py +0 -0
  63. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imgclf/resnet.py +0 -0
  64. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imgclf/resnext.py +0 -0
  65. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imgclf/senet.py +0 -0
  66. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imgclf/sknet.py +0 -0
  67. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imgclf/swin.py +0 -0
  68. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imgclf/vgg.py +0 -0
  69. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imgclf/vit.py +0 -0
  70. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imgclf/xception.py +0 -0
  71. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imgclf/zfnet.py +0 -0
  72. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imggen/__init__.py +0 -0
  73. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imggen/ddpm.py +0 -0
  74. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imggen/ncsn.py +0 -0
  75. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/imggen/vae.py +0 -0
  76. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/objdet/__init__.py +0 -0
  77. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/objdet/detr.py +0 -0
  78. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/objdet/efficientdet.py +0 -0
  79. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/objdet/fast_rcnn.py +0 -0
  80. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/objdet/faster_rcnn.py +0 -0
  81. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/objdet/rcnn.py +0 -0
  82. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/objdet/util.py +0 -0
  83. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/objdet/yolo/__init__.py +0 -0
  84. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/objdet/yolo/yolo_v1.py +0 -0
  85. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/objdet/yolo/yolo_v2.py +0 -0
  86. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/objdet/yolo/yolo_v3.py +0 -0
  87. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/objdet/yolo/yolo_v4.py +0 -0
  88. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/seq2seq/__init__.py +0 -0
  89. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/seq2seq/transformer.py +0 -0
  90. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/models/utils.py +0 -0
  91. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/__init__.py +0 -0
  92. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/_kernel/__init__.py +0 -0
  93. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/_kernel/activation.py +0 -0
  94. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/_kernel/attention.py +0 -0
  95. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/_kernel/conv.py +0 -0
  96. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/_kernel/loss.py +0 -0
  97. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/_kernel/norm.py +0 -0
  98. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/_kernel/pool.py +0 -0
  99. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/functional/__init__.py +0 -0
  100. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/functional/_activation.py +0 -0
  101. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/functional/_attention.py +0 -0
  102. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/functional/_conv.py +0 -0
  103. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/functional/_drop.py +0 -0
  104. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/functional/_linear.py +0 -0
  105. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/functional/_loss.py +0 -0
  106. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/functional/_norm.py +0 -0
  107. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/functional/_pool.py +0 -0
  108. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/functional/_spatial.py +0 -0
  109. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/fused.py +0 -0
  110. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/init/__init__.py +0 -0
  111. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/init/_dist.py +0 -0
  112. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/module.py +0 -0
  113. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/modules/__init__.py +0 -0
  114. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/modules/activation.py +0 -0
  115. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/modules/conv.py +0 -0
  116. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/modules/drop.py +0 -0
  117. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/modules/einops.py +0 -0
  118. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/modules/linear.py +0 -0
  119. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/modules/loss.py +0 -0
  120. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/modules/norm.py +0 -0
  121. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/modules/pool.py +0 -0
  122. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/modules/transformer.py +0 -0
  123. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/modules/vision.py +0 -0
  124. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/parameter.py +0 -0
  125. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/utils/__init__.py +0 -0
  126. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/utils/_grad.py +0 -0
  127. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/nn/utils/rnn.py +0 -0
  128. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/optim/__init__.py +0 -0
  129. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/optim/_base.py +0 -0
  130. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/optim/ada.py +0 -0
  131. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/optim/adam.py +0 -0
  132. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/optim/lr_scheduler/__init__.py +0 -0
  133. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/optim/lr_scheduler/_base.py +0 -0
  134. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/optim/lr_scheduler/_schedulers.py +0 -0
  135. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/optim/prop.py +0 -0
  136. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/optim/sgd.py +0 -0
  137. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/port.py +0 -0
  138. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/random/__init__.py +0 -0
  139. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/random/_func.py +0 -0
  140. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/transforms/__init__.py +0 -0
  141. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/transforms/_base.py +0 -0
  142. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/transforms/image.py +0 -0
  143. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/types.py +0 -0
  144. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/visual/__init__.py +0 -0
  145. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/visual/mermaid.py +0 -0
  146. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/weights/__init__.py +0 -0
  147. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid/weights/__init__.pyi +0 -0
  148. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid_dl.egg-info/dependency_links.txt +0 -0
  149. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid_dl.egg-info/requires.txt +0 -0
  150. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/lucid_dl.egg-info/top_level.txt +0 -0
  151. {lucid_dl-2.12.0 → lucid_dl-2.12.1}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: lucid-dl
- Version: 2.12.0
+ Version: 2.12.1
  Summary: Lumerico's Comprehensive Interface for Deep Learning
  Home-page: https://github.com/ChanLumerico/lucid
  Author: ChanLumerico
@@ -48,6 +48,10 @@ Whether you're a student, educator, or an advanced researcher seeking to demysti

  ### 🔥 What's New

+ - New Tensor utility function added: `lucid.Tensor.expand`
+
+ - Added Type-Generic Tensors: `lucid.LongTensor`, `lucid.DoubleTensor`, etc.
+
  - Added new visual tool: `lucid.visual.build_tensor_mermaid_chart` which builds a Mermaid chart of given tensor's computatoinal graph

  - Added additional `nn.Module` hooks for richer introspection during training:
@@ -20,6 +20,10 @@ Whether you're a student, educator, or an advanced researcher seeking to demysti

  ### 🔥 What's New

+ - New Tensor utility function added: `lucid.Tensor.expand`
+
+ - Added Type-Generic Tensors: `lucid.LongTensor`, `lucid.DoubleTensor`, etc.
+
  - Added new visual tool: `lucid.visual.build_tensor_mermaid_chart` which builds a Mermaid chart of given tensor's computatoinal graph

  - Added additional `nn.Module` hooks for richer introspection during training:
@@ -25,7 +25,7 @@ import json
  import math
  import numpy as np

- from lucid._tensor import Tensor
+ from lucid._tensor import *
  from lucid._func import *
  from lucid._util import *

@@ -308,7 +308,7 @@ def register_model(func: _ModuleReturnFunc) -> _ModuleReturnFunc:


  def _conv_view_limit_mb() -> int:
-     from lucid._kernel import conv as _conv_kernel
+     from lucid.nn._kernel import conv as _conv_kernel

      return _conv_kernel.get_conv_view_limit_mb()

@@ -0,0 +1,11 @@
+ from lucid._tensor.tensor import (
+     Tensor,
+     LongTensor,
+     IntTensor,
+     ShortTensor,
+     CharTensor,
+     HalfTensor,
+     FloatTensor,
+     DoubleTensor,
+     BoolTensor,
+ )
@@ -108,6 +108,8 @@ class _TensorBase:

      def broadcast_to(self, shape: _ShapeLike) -> Self: ...

+     def expand(self, *sizes: int | _ShapeLike) -> Self: ...
+
      def chunk(self, chunks: int, axis: int = 0) -> tuple[Self, ...]: ...

      def swapaxes(self, axis1: int, axis2: int) -> Self: ...
@@ -1,4 +1,15 @@
- from typing import Callable, Iterator, Optional, Self, SupportsIndex, Any, overload
+ from typing import (
+     Callable,
+     Iterator,
+     Optional,
+     Self,
+     SupportsIndex,
+     Any,
+     overload,
+     Generic,
+     TypeVar,
+     ClassVar,
+ )
  from types import NoneType
  from collections import deque

@@ -22,15 +33,32 @@ from lucid._backend.core import BackwardOperation, Operation, noop
  from lucid._backend.metal import mx, parse_mlx_indexing, check_metal_availability


+ __all__ = [
+     "Tensor",
+     "FloatTensor",
+     "DoubleTensor",
+     "HalfTensor",
+     "CharTensor",
+     "ShortTensor",
+     "IntTensor",
+     "LongTensor",
+     "BoolTensor",
+ ]
+
+
+ DType = TypeVar("DType", bound=Numeric | bool)
+
  _HookType = Callable[["Tensor", _NumPyArray | _MLXArray], None]

  _dtype_map = {int: types.Int64, float: types.Float64, complex: types.Complex64}


- class Tensor(_TensorBase, _TensorInplace):
+ class Tensor(Generic[DType], _TensorBase, _TensorInplace):
+     _fixed_dtype: ClassVar[Numeric | None] = None
+
      def __init__(
          self,
-         data: _ArrayOrScalar | _MLXArray,
+         data: _ArrayOrScalar,
          requires_grad: bool = False,
          keep_grad: bool = False,
          dtype: _BuiltinNumeric | Numeric | None = None,
@@ -39,6 +67,9 @@ class Tensor(_TensorBase, _TensorInplace):
          self._is_free = False
          self._is_bool_tensor = False

+         if self._fixed_dtype is not None:
+             dtype = self._fixed_dtype
+
          if dtype is bool:
              self._is_bool_tensor = True
              dtype = None
@@ -285,6 +316,12 @@ class Tensor(_TensorBase, _TensorInplace):
          dtype = device_or_dtype
          return self.astype(dtype)

+     def cpu(self) -> Self:
+         return self.to(device="cpu")
+
+     def gpu(self) -> Self:
+         return self.to(device="gpu")
+
      def is_cpu(self) -> bool:
          return self.device == "cpu"

@@ -480,3 +517,155 @@ class Tensor(_TensorBase, _TensorInplace):

      def bool(self) -> Self:
          return self.astype(bool)
+
+
+ class LongTensor(Tensor[types.Int64]):
+     _fixed_dtype: ClassVar[Numeric | None] = types.Int64
+
+     def __init__(
+         self,
+         data: _ArrayOrScalar,
+         requires_grad: bool = False,
+         keep_grad: bool = False,
+         device: _DeviceType = "cpu",
+     ) -> None:
+         super().__init__(
+             data=data,
+             requires_grad=requires_grad,
+             keep_grad=keep_grad,
+             dtype=types.Int64,
+             device=device,
+         )
+
+
+ class IntTensor(Tensor[types.Int32]):
+     _fixed_dtype: ClassVar[Numeric | None] = types.Int32
+
+     def __init__(
+         self,
+         data: _ArrayOrScalar,
+         requires_grad: bool = False,
+         keep_grad: bool = False,
+         device: _DeviceType = "cpu",
+     ) -> None:
+         super().__init__(
+             data=data,
+             requires_grad=requires_grad,
+             keep_grad=keep_grad,
+             dtype=types.Int32,
+             device=device,
+         )
+
+
+ class ShortTensor(Tensor[types.Int16]):
+     _fixed_dtype: ClassVar[Numeric | None] = types.Int16
+
+     def __init__(
+         self,
+         data: _ArrayOrScalar,
+         requires_grad: bool = False,
+         keep_grad: bool = False,
+         device: _DeviceType = "cpu",
+     ) -> None:
+         super().__init__(
+             data=data,
+             requires_grad=requires_grad,
+             keep_grad=keep_grad,
+             dtype=types.Int16,
+             device=device,
+         )
+
+
+ class CharTensor(Tensor[types.Int8]):
+     _fixed_dtype: ClassVar[Numeric | None] = types.Int8
+
+     def __init__(
+         self,
+         data: _ArrayOrScalar,
+         requires_grad: bool = False,
+         keep_grad: bool = False,
+         device: _DeviceType = "cpu",
+     ) -> None:
+         super().__init__(
+             data=data,
+             requires_grad=requires_grad,
+             keep_grad=keep_grad,
+             dtype=types.Int8,
+             device=device,
+         )
+
+
+ class HalfTensor(Tensor[types.Float16]):
+     _fixed_dtype: ClassVar[Numeric | None] = types.Float16
+
+     def __init__(
+         self,
+         data: _ArrayOrScalar,
+         requires_grad: bool = False,
+         keep_grad: bool = False,
+         device: _DeviceType = "cpu",
+     ) -> None:
+         super().__init__(
+             data=data,
+             requires_grad=requires_grad,
+             keep_grad=keep_grad,
+             dtype=types.Float16,
+             device=device,
+         )
+
+
+ class FloatTensor(Tensor[types.Float32]):
+     _fixed_dtype: ClassVar[Numeric | None] = types.Float32
+
+     def __init__(
+         self,
+         data: _ArrayOrScalar,
+         requires_grad: bool = False,
+         keep_grad: bool = False,
+         device: _DeviceType = "cpu",
+     ) -> None:
+         super().__init__(
+             data=data,
+             requires_grad=requires_grad,
+             keep_grad=keep_grad,
+             dtype=types.Float32,
+             device=device,
+         )
+
+
+ class DoubleTensor(Tensor[types.Float64]):
+     _fixed_dtype: ClassVar[Numeric | None] = types.Float64
+
+     def __init__(
+         self,
+         data: _ArrayOrScalar,
+         requires_grad: bool = False,
+         keep_grad: bool = False,
+         device: _DeviceType = "cpu",
+     ) -> None:
+         super().__init__(
+             data=data,
+             requires_grad=requires_grad,
+             keep_grad=keep_grad,
+             dtype=types.Float64,
+             device=device,
+         )
+
+
+ class BoolTensor(Tensor[bool]):
+     _fixed_dtype: ClassVar[Numeric | None] = None
+
+     def __init__(
+         self,
+         data: _ArrayOrScalar,
+         requires_grad: bool = False,
+         keep_grad: bool = False,
+         device: _DeviceType = "cpu",
+     ) -> None:
+         super().__init__(
+             data=data,
+             requires_grad=requires_grad,
+             keep_grad=keep_grad,
+             dtype=bool,
+             device=device,
+         )
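The hunk above adds eight dtype-specific `Tensor` subclasses plus the `cpu()`/`gpu()` shorthands. A minimal usage sketch based only on the constructors shown in this diff (not on lucid's documentation); the data values are made up, and the `gpu()` call assumes a Metal device is available:

import lucid

# LongTensor pins its dtype through _fixed_dtype, so the data is always
# stored as types.Int64 regardless of the literal's type.
idx = lucid.LongTensor([0, 2, 5])
x = lucid.FloatTensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)

x_gpu = x.gpu()      # shorthand for x.to(device="gpu")
x_cpu = x_gpu.cpu()  # back to the CPU (NumPy-backed) representation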
@@ -9,11 +9,11 @@ from lucid._util import func
  # fmt: off
  __all__ = [
      "reshape", "squeeze", "unsqueeze", "expand_dims", "ravel", "stack", "hstack",
-     "vstack", "concatenate", "pad", "repeat", "tile", "flatten", "meshgrid",
-     "split", "tril", "triu", "broadcast_to", "chunk", "masked_fill", "roll",
-     "unbind", "sort", "nonzero", "unique", "topk", "argsort", "histogramdd",
-     "histogram", "histogram2d", "where", "nonzero", "argmin", "argmax",
-     "diagonal",
+     "vstack", "concatenate", "pad", "repeat", "tile", "flatten", "meshgrid",
+     "split", "tril", "triu", "broadcast_to", "expand", "chunk", "masked_fill",
+     "roll", "unbind", "sort", "nonzero", "unique", "topk", "argsort",
+     "histogramdd", "histogram", "histogram2d", "where", "nonzero", "argmin",
+     "argmax", "diagonal",
  ]
  # fmt: on

@@ -106,6 +106,14 @@ def broadcast_to(a: Tensor, /, shape: _ShapeLike) -> Tensor:
      return func.broadcast_to(shape)(a)


+ def expand(a: Tensor, /, *sizes: int | _ShapeLike) -> Tensor:
+     if len(sizes) == 1 and isinstance(sizes[0], (tuple, list)):
+         shape = sizes[0]
+     else:
+         shape = sizes
+     return func.expand(shape)(a)
+
+
  def chunk(a: Tensor, /, chunks: int, axis: int = 0) -> tuple[Tensor, ...]:
      return func.chunk(chunks, axis)(a)

@@ -257,6 +265,7 @@ Tensor.split = split
  Tensor.tril = tril
  Tensor.triu = triu
  Tensor.broadcast_to = broadcast_to
+ Tensor.expand = expand
  Tensor.chunk = chunk
  Tensor.masked_fill = masked_fill
  Tensor.roll = roll
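With the wrapper above registered both as `lucid.expand` and as a `Tensor` method, the call accepts either separate sizes or a single shape tuple, and (per the `expand` operation added in `lucid/_util/func.py` below) `-1` keeps an existing dimension. A hedged usage sketch, not taken from lucid's docs:

import lucid

a = lucid.Tensor([[1.0], [2.0], [3.0]])  # shape (3, 1)

b = a.expand(3, 4)               # broadcast the singleton axis -> (3, 4)
c = a.expand(-1, 4)              # -1 keeps the existing size of that axis
d = lucid.expand(a, (2, 3, 4))   # leading axes may also be prepended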
@@ -605,6 +605,79 @@ class broadcast_to(Operation):
          return self.result.grad.reshape(self.original_shape)


+ class expand(Operation):
+     def __init__(self, shape: _ShapeLike) -> None:
+         super().__init__()
+         self.shape = shape
+
+     def _resolve_shape(self, input_shape: tuple[int, ...]) -> tuple[int, ...]:
+         shape = tuple(int(dim) for dim in self.shape)
+         if len(shape) == 0:
+             raise ValueError("expand() expects at least one dimension.")
+
+         if len(shape) < len(input_shape):
+             raise ValueError(
+                 "expand() cannot shrink the number of dimensions from "
+                 f"{len(input_shape)} to {len(shape)}."
+             )
+
+         ndim_diff = len(shape) - len(input_shape)
+         padded_input = (1,) * ndim_diff + input_shape
+
+         resolved: list[int] = []
+         for axis, (target_dim, input_dim) in enumerate(zip(shape, padded_input)):
+             if target_dim == -1:
+                 if axis < ndim_diff:
+                     raise ValueError(
+                         "expand() cannot use -1 in a leading, "
+                         "non-existing dimension."
+                     )
+                 target_dim = input_dim
+
+             elif target_dim < -1:
+                 raise ValueError("expand() size must be >= -1.")
+
+             if input_dim == target_dim:
+                 resolved.append(target_dim)
+             elif input_dim == 1 and target_dim >= 0:
+                 resolved.append(target_dim)
+             else:
+                 raise ValueError(
+                     "expand() cannot expand dimension "
+                     f"{axis} from {input_dim} to {target_dim}."
+                 )
+
+         return tuple(resolved)
+
+     @unary_func_op()
+     def cpu(self, a: Tensor) -> _FuncOpReturnType:
+         self.original_shape = a.shape
+         self.expanded_shape = self._resolve_shape(a.shape)
+
+         self.result = Tensor(np.broadcast_to(a.data, self.expanded_shape))
+         return self.result, self.__grad__
+
+     @unary_func_op(device="gpu")
+     def gpu(self, a: Tensor) -> _FuncOpReturnType:
+         self.original_shape = a.shape
+         self.expanded_shape = self._resolve_shape(a.shape)
+
+         self.result = Tensor(mx.broadcast_to(a.data, self.expanded_shape))
+         return self.result, self.__grad__
+
+     def __grad__(self) -> _GradType:
+         input_shape = self.original_shape
+         ndim_diff = len(self.expanded_shape) - len(input_shape)
+         if ndim_diff > 0:
+             input_shape = (1,) * ndim_diff + input_shape
+
+         for axis, (in_dim, out_dim) in enumerate(zip(input_shape, self.expanded_shape)):
+             if in_dim == 1 and out_dim > 1:
+                 self.result.grad = self.result.grad.sum(axis=axis, keepdims=True)
+
+         return self.result.grad.reshape(self.original_shape)
+
+
  class chunk(Operation):
      def __init__(self, chunks: int, axis: int) -> None:
          super().__init__()
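The backward pass of `expand` mirrors that of `broadcast_to`: every axis that was broadcast from size 1 (or newly prepended) must have the incoming gradient summed back before reshaping to the original shape. A standalone NumPy sketch of that reduction, independent of lucid's `Operation` machinery:

import numpy as np

original_shape = (3, 1)
expanded_shape = (2, 3, 4)

x = np.random.rand(*original_shape)
y = np.broadcast_to(x, expanded_shape)   # forward: a view, no copy
grad_out = np.ones(expanded_shape)       # upstream gradient w.r.t. y

# Sum over every axis that was expanded, then collapse the prepended axes.
grad_in = grad_out
padded = (1,) * (len(expanded_shape) - len(original_shape)) + original_shape
for axis, (in_dim, out_dim) in enumerate(zip(padded, expanded_shape)):
    if in_dim == 1 and out_dim > 1:
        grad_in = grad_in.sum(axis=axis, keepdims=True)
grad_in = grad_in.reshape(original_shape)  # each entry equals 2 * 4 = 8.0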
@@ -2,3 +2,4 @@ from .imgclf import *
  from .imggen import *
  from .objdet import *
  from .seq2seq import *
+ from .seqclf import *
@@ -0,0 +1 @@
+ from .bert import *
@@ -0,0 +1,31 @@
+ import lucid
+ import lucid.nn as nn
+ import lucid.nn.functional as F
+
+ from lucid._tensor import Tensor
+
+
+ class _BertEmbeddings(nn.Module):
+     def __init__(
+         self,
+         vocab_size: int,
+         hidden_size: int,
+         pad_token_id: int,
+         max_position_embeddings: int,
+         type_vocab_size: int,
+         layer_norm_eps: float,
+         hidden_dropout_prob: float,
+     ) -> None:
+         super().__init__()
+         self.word_embeddings = nn.Embedding(vocab_size, hidden_size, pad_token_id)
+         self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
+         self.token_type_embeddings = nn.Embedding(type_vocab_size)
+
+         self.layernorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
+         self.dropout = nn.Dropout(hidden_dropout_prob)
+
+         self.position_ids: nn.Buffer
+         self.register_buffer(
+             "position_ids", nn.Buffer(lucid.arange(max_position_embeddings))
+         )
+         # TODO: Implement `lucid.Tensor.expand`
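The trailing TODO points at the buffer registered above: a BERT-style forward pass slices `position_ids` to the sequence length and expands it across the batch, which is what the new `Tensor.expand` enables. A hypothetical sketch of such a continuation; the `forward` method and its argument names are not part of this diff:

# Hypothetical method on _BertEmbeddings (not included in this release):
def forward(self, input_ids: Tensor, token_type_ids: Tensor) -> Tensor:
    batch_size, seq_len = input_ids.shape
    position_ids = self.position_ids[:seq_len].expand(batch_size, seq_len)

    embeds = (
        self.word_embeddings(input_ids)
        + self.position_embeddings(position_ids)
        + self.token_type_embeddings(token_type_ids)
    )
    return self.dropout(self.layernorm(embeds))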
@@ -1,4 +1,4 @@
- import functools
+ from functools import partial
  from types import ModuleType

  import numpy as np
@@ -7,49 +7,44 @@ from lucid._backend.core import Operation, func_op, _FuncOpReturnType, _GradType
  from lucid._backend.metal import mx
  from lucid._tensor import Tensor

- from lucid.types import _DeviceType, _TensorData
-
-
- def _as_int_array(arr, lib_: ModuleType) -> _TensorData:
-     if lib_ is np:
-         return arr.astype(np.int64)
-     return arr.astype(mx.int32)
-

  class embedding_kernel(Operation):
-     def __init__(self) -> None:
+     def __init__(self, padding_idx: int = -1) -> None:
          super().__init__()
+         self.padding_idx = int(padding_idx)
          self._indices = None
          self._num_embeddings = None

      def clear(self) -> None:
          super().clear()
+         self.padding_idx = -1
          self._indices = None
          self._num_embeddings = None

      @func_op(n_in=2, n_ret=1)
      def cpu(self, indices: Tensor, weight: Tensor) -> _FuncOpReturnType:
-         return self._forward(indices, weight, lib_=np, device="cpu")
+         return self._forward(indices, weight, lib_=np)

      @func_op(n_in=2, n_ret=1, device="gpu")
      def gpu(self, indices: Tensor, weight: Tensor) -> _FuncOpReturnType:
-         return self._forward(indices, weight, lib_=mx, device="gpu")
+         return self._forward(indices, weight, lib_=mx)

      def _forward(
-         self, indices: Tensor, weight: Tensor, lib_: ModuleType, device: _DeviceType
+         self, indices: Tensor, weight: Tensor, lib_: ModuleType
      ) -> _FuncOpReturnType:
-         idx = _as_int_array(indices.data, lib_)
+         idx = indices.data
          out = weight.data[idx]

          self._indices = idx
          self._num_embeddings = int(weight.shape[0])

-         self.result = Tensor(out, device=device)
-         return self.result, functools.partial(self.__grad__, lib_=lib_)
+         self.result = Tensor(out)
+         return self.result, partial(self.__grad__, lib_=lib_)

      def __grad__(self, lib_: ModuleType) -> _GradType:
          if self.result is None or self.result.grad is None:
              raise RuntimeError("embedding backward called before forward.")
+
          if self._indices is None or self._num_embeddings is None:
              raise RuntimeError("embedding cached data missing.")

@@ -58,15 +53,23 @@ class embedding_kernel(Operation):
          grad_flat = grad_out.reshape(idx.shape[0], -1)

          if lib_ is np:
+             if self.padding_idx >= 0:
+                 keep = idx != self.padding_idx
+                 idx = idx[keep]
+                 grad_flat = grad_flat[keep]
+
              grad_w = np.zeros(
                  (self._num_embeddings, grad_flat.shape[1]), dtype=grad_out.dtype
              )
              np.add.at(grad_w, idx, grad_flat)
+
          else:
              grad_w = mx.zeros(
                  (self._num_embeddings, grad_flat.shape[1]), dtype=grad_out.dtype
              )
              for i in range(idx.shape[0]):
+                 if self.padding_idx >= 0 and int(idx[i]) == self.padding_idx:
+                     continue
                  grad_w = grad_w.at[idx[i]].add(grad_flat[i])

          return None, grad_w
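On the NumPy path, the weight gradient is a scatter-add of the upstream gradient rows into the embedding rows they were gathered from, with rows belonging to `padding_idx` dropped first so the padding vector receives no gradient. A standalone NumPy illustration of that idea (shapes and values are made up):

import numpy as np

num_embeddings, dim, padding_idx = 5, 3, 0
idx = np.array([0, 2, 2, 4])              # flattened lookup indices; 0 is padding
grad_flat = np.ones((idx.shape[0], dim))  # upstream gradient, one row per lookup

keep = idx != padding_idx                 # drop the padding positions
idx_kept, grad_kept = idx[keep], grad_flat[keep]

grad_w = np.zeros((num_embeddings, dim))
np.add.at(grad_w, idx_kept, grad_kept)    # unbuffered scatter-add

# grad_w[0] stays zero (padding); grad_w[2] accumulates two rows.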
@@ -1,3 +1,5 @@
+ import numpy as np
+
  import lucid
  import lucid.nn.functional

@@ -5,6 +7,7 @@ from lucid._tensor import Tensor
  from lucid.types import _Scalar, Numeric

  from lucid.nn._kernel.embedding import embedding_kernel
+ from lucid._backend.metal import mx


  def _interpolate_bilinear(
@@ -131,17 +134,46 @@ def embedding(
      max_norm: float | None = None,
      norm_type: float = 2.0,
  ) -> Tensor:
+     num_embeddings = int(weight.shape[0])
+     if padding_idx is None:
+         pad = -1
+     else:
+         pad = int(padding_idx)
+         if pad < 0:
+             pad += num_embeddings
+         if pad < 0 or pad >= num_embeddings:
+             raise IndexError("padding_idx out of range.")
+
      indices = input_.astype(lucid.Int)
-     op = embedding_kernel()
-     output = op(indices, weight)
-     if padding_idx is not None:
-         mask = input_.data == padding_idx
-         output *= 1 - mask[..., None]
+     idx_data = indices.data
+
+     if (idx_data < 0).any() or (idx_data >= num_embeddings).any():
+         raise IndexError("embedding indices out of range.")

      if max_norm is not None:
-         norm = (output**norm_type).sum(axis=-1, keepdims=True) ** (1 / norm_type)
-         scaling = max_norm / (norm + (norm == 0))
-         output *= scaling
+         lib_ = np if weight.is_cpu() else mx
+         flat = idx_data.reshape(-1)
+
+         w = weight.data[flat]
+         if norm_type <= 0:
+             raise ValueError("norm_type must be positive.")
+
+         norms = (lib_.abs(w) ** norm_type).sum(axis=1) ** (1.0 / norm_type)
+         scale = lib_.minimum(1.0, max_norm / (norms + (norms == 0)))
+
+         if pad >= 0:
+             mask = flat == pad
+             mask_f = mask.astype(scale.dtype)
+             scale = scale * (1 - mask_f) + mask_f
+
+         weight.data[flat] = w * scale[:, None]
+
+     op = embedding_kernel(padding_idx=pad)
+     output = op(indices, weight)
+
+     if pad >= 0:
+         mask = input_.data == pad
+         output *= 1 - mask[..., None]

      return output
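Taken together, `embedding` now validates the indices, optionally renormalizes the looked-up rows of `weight` in place so their p-norm does not exceed `max_norm` (leaving the padding row untouched), and zeroes the output rows at `padding_idx`. A hedged usage sketch, assuming the function is re-exported as `lucid.nn.functional.embedding`; the tensor values are made up:

import lucid
import lucid.nn.functional as F

weight = lucid.Tensor([[0.0, 0.0], [3.0, 4.0], [6.0, 8.0]], requires_grad=True)
ids = lucid.Tensor([[1, 2, 0]])  # index 0 is treated as padding below

out = F.embedding(ids, weight, padding_idx=0, max_norm=5.0, norm_type=2.0)
# Row 2 of `weight` is rescaled in place from L2 norm 10 to 5; row 1 already
# satisfies the bound; the padding row keeps its values, and its output
# positions (and gradient contributions) are zeroed.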