lucid-dl 2.8.0__tar.gz → 2.8.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {lucid_dl-2.8.0/lucid_dl.egg-info → lucid_dl-2.8.5}/PKG-INFO +10 -10
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/README.md +9 -9
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/linalg/_func.py +3 -14
- lucid_dl-2.8.5/lucid/nn/modules/rnn.py +529 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5/lucid_dl.egg-info}/PKG-INFO +10 -10
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/setup.py +1 -1
- lucid_dl-2.8.0/lucid/nn/modules/rnn.py +0 -97
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/LICENSE +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/__init__.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/_backend/__init__.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/_backend/core.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/_backend/metal.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/_func/__init__.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/_func/bfunc.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/_func/gfunc.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/_func/ufunc.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/_tensor/__init__.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/_tensor/tensor.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/_tensor/tensor_ops.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/_util/__init__.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/_util/func.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/data/__init__.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/data/_base.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/data/_util.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/datasets/__init__.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/datasets/_base.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/datasets/cifar.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/datasets/mnist.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/einops/__init__.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/einops/_func.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/error.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/linalg/__init__.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/models/__init__.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/models/imgclf/__init__.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/models/imgclf/alex.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/models/imgclf/coatnet.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/models/imgclf/convnext.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/models/imgclf/crossvit.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/models/imgclf/cspnet.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/models/imgclf/cvt.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/models/imgclf/dense.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/models/imgclf/efficient.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/models/imgclf/efficientformer.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/models/imgclf/inception.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/models/imgclf/inception_next.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/models/imgclf/inception_res.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/models/imgclf/lenet.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/models/imgclf/maxvit.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/models/imgclf/mobile.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/models/imgclf/pvt.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/models/imgclf/resnest.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/models/imgclf/resnet.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/models/imgclf/resnext.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/models/imgclf/senet.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/models/imgclf/sknet.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/models/imgclf/swin.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/models/imgclf/vgg.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/models/imgclf/vit.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/models/imgclf/xception.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/models/imgclf/zfnet.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/models/imggen/__init__.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/models/imggen/ddpm.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/models/imggen/vae.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/models/objdet/__init__.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/models/objdet/detr.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/models/objdet/efficientdet.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/models/objdet/fast_rcnn.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/models/objdet/faster_rcnn.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/models/objdet/rcnn.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/models/objdet/util.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/models/objdet/yolo/__init__.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/models/objdet/yolo/yolo_v1.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/models/objdet/yolo/yolo_v2.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/models/objdet/yolo/yolo_v3.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/models/objdet/yolo/yolo_v4.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/models/seq2seq/__init__.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/models/seq2seq/transformer.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/models/util.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/nn/__init__.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/nn/functional/__init__.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/nn/functional/_activation.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/nn/functional/_attention.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/nn/functional/_conv.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/nn/functional/_drop.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/nn/functional/_linear.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/nn/functional/_loss.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/nn/functional/_norm.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/nn/functional/_pool.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/nn/functional/_spatial.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/nn/functional/_util.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/nn/fused.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/nn/init/__init__.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/nn/init/_dist.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/nn/module.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/nn/modules/__init__.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/nn/modules/activation.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/nn/modules/attention.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/nn/modules/conv.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/nn/modules/drop.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/nn/modules/einops.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/nn/modules/linear.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/nn/modules/loss.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/nn/modules/norm.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/nn/modules/pool.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/nn/modules/sparse.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/nn/modules/transformer.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/nn/modules/vision.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/nn/parameter.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/nn/util.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/optim/__init__.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/optim/_base.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/optim/ada.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/optim/adam.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/optim/lr_scheduler/__init__.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/optim/lr_scheduler/_base.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/optim/lr_scheduler/_schedulers.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/optim/prop.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/optim/sgd.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/port.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/random/__init__.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/random/_func.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/transforms/__init__.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/transforms/_base.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/transforms/image.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/types.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/visual/__init__.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/visual/graph.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/weights/__init__.py +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid/weights/__init__.pyi +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid_dl.egg-info/SOURCES.txt +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid_dl.egg-info/dependency_links.txt +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid_dl.egg-info/requires.txt +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/lucid_dl.egg-info/top_level.txt +0 -0
- {lucid_dl-2.8.0 → lucid_dl-2.8.5}/setup.cfg +0 -0
PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: lucid-dl
-Version: 2.8.0
+Version: 2.8.5
 Summary: Lumerico's Comprehensive Interface for Deep Learning
 Home-page: https://github.com/ChanLumerico/lucid
 Author: ChanLumerico
@@ -30,29 +30,29 @@ Dynamic: summary
 
 
 
-
+
 
 
-
+
 
 **Lucid** is a minimalist deep learning framework built entirely from scratch in Python. It offers a pedagogically rich environment to explore the foundations of modern deep learning systems, including autodiff, neural network modules, and GPU acceleration — all while staying lightweight, readable, and free of complex dependencies.
 
 Whether you're a student, educator, or an advanced researcher seeking to demystify deep learning internals, Lucid provides a transparent and highly introspectable API that faithfully replicates key behaviors of major frameworks like PyTorch, yet in a form simple enough to study line by line.
 
-[📑 Lucid Documentation](https://chanlumerico.github.io/lucid/build/html/index.html) |
+[📑 Lucid Documentation](https://chanlumerico.github.io/lucid/build/html/index.html) | [✏️ Lucid DevLog](https://velog.io/@lumerico284/series/Lucid-Development) |
 [🤗 Lucid Huggingface](https://huggingface.co/ChanLumerico/lucid)
 
+#### Other Languages
+
+[🇰🇷 README.md in Korean](https://github.com/ChanLumerico/lucid/blob/main/README.kr.md)
+
 ### 🔥 What's New
 
 - Now supports [**`Safetensors`**](https://github.com/huggingface/safetensors) for Lucid neural module porting along with the legacy `.lcd` format
 
-- Added
-
-  - `nn.util.grad_norm` - Returns the global norm of the gradients
-  - `nn.util.clip_grad_norm` - Rescales the gradients based on the global norm
-  - `nn.util.clip_grad_value` - Rescales the gradients based on their values.
+- Added new neural module category `nn.rnn`, including:
 
-
+  `nn.RNNBase`, `nn.RNN`, `nn.LSTM`, `nn.GRU`, `nn.RNNCell`, `nn.LSTMCell`, `nn.GRUCell`
 
 ## 🔧 How to Install
 
README.md

@@ -2,29 +2,29 @@
 
 
 
-
+
 
 
-
+
 
 **Lucid** is a minimalist deep learning framework built entirely from scratch in Python. It offers a pedagogically rich environment to explore the foundations of modern deep learning systems, including autodiff, neural network modules, and GPU acceleration — all while staying lightweight, readable, and free of complex dependencies.
 
 Whether you're a student, educator, or an advanced researcher seeking to demystify deep learning internals, Lucid provides a transparent and highly introspectable API that faithfully replicates key behaviors of major frameworks like PyTorch, yet in a form simple enough to study line by line.
 
-[📑 Lucid Documentation](https://chanlumerico.github.io/lucid/build/html/index.html) |
+[📑 Lucid Documentation](https://chanlumerico.github.io/lucid/build/html/index.html) | [✏️ Lucid DevLog](https://velog.io/@lumerico284/series/Lucid-Development) |
 [🤗 Lucid Huggingface](https://huggingface.co/ChanLumerico/lucid)
 
+#### Other Languages
+
+[🇰🇷 README.md in Korean](https://github.com/ChanLumerico/lucid/blob/main/README.kr.md)
+
 ### 🔥 What's New
 
 - Now supports [**`Safetensors`**](https://github.com/huggingface/safetensors) for Lucid neural module porting along with the legacy `.lcd` format
 
-- Added
-
-  - `nn.util.grad_norm` - Returns the global norm of the gradients
-  - `nn.util.clip_grad_norm` - Rescales the gradients based on the global norm
-  - `nn.util.clip_grad_value` - Rescales the gradients based on their values.
+- Added new neural module category `nn.rnn`, including:
 
-
+  `nn.RNNBase`, `nn.RNN`, `nn.LSTM`, `nn.GRU`, `nn.RNNCell`, `nn.LSTMCell`, `nn.GRUCell`
 
 ## 🔧 How to Install
 
lucid/linalg/_func.py

@@ -142,7 +142,6 @@ class cholesky(operation):
         return int((1 / 3) * a.shape[-1] ** 3)
 
 
-@fallback
 class norm(operation):
     def __init__(
         self,
@@ -168,20 +167,10 @@ class norm(operation):
 
     @unary_func_op(device="gpu")
     def gpu(self, a: Tensor) -> _FuncOpReturnType:
-
-
+        mx_ord = self.ord if not (self.ord == 2 and a.ndim > 2) else None
+        self.result = Tensor(
+            mx.linalg.norm(a.data, ord=mx_ord, axis=self.axis, keepdims=self.keepdims)
         )
-        if fallback_:
-            result_data = np.linalg.norm(
-                a.data, ord=self.ord, axis=self.axis, keepdims=self.keepdims
-            )
-            self.result = Tensor(result_data, device="gpu")
-        else:
-            result_data = mx.linalg.norm(
-                a.data, ord=self.ord, axis=self.axis, keepdims=self.keepdims
-            )
-            self.result = Tensor(result_data)
-
         return self.result, partial(self.__grad__, a=a, lib_=mx)
 
     def __grad__(self, a: Tensor, lib_: ModuleType) -> _GradFuncType:
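The hunk above drops the NumPy fallback from the GPU path of `norm`: the Metal backend now calls `mx.linalg.norm` directly and only remaps `ord` when needed. A minimal standalone sketch of that remapping (an illustration, not the library's public API; the `ord == 2 and ndim > 2` guard is copied from the hunk, and the assumption is that MLX does not accept `ord=2` for inputs with more than two dimensions):

import mlx.core as mx


def gpu_norm(data: mx.array, ord=2, axis=None, keepdims=False) -> mx.array:
    # Same guard as in norm.gpu: for >2-D inputs, ord=2 is mapped to ord=None
    # so mx.linalg.norm uses its default (vector/Frobenius) behavior instead.
    mx_ord = ord if not (ord == 2 and data.ndim > 2) else None
    return mx.linalg.norm(data, ord=mx_ord, axis=axis, keepdims=keepdims)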
lucid_dl-2.8.5/lucid/nn/modules/rnn.py (new file)

@@ -0,0 +1,529 @@
from typing import Literal

import lucid
import lucid.nn as nn
import lucid.nn.functional as F

from lucid._tensor import Tensor
from lucid.types import Numeric, _DeviceType

from .activation import Tanh, ReLU


__all__ = ["RNNCell", "LSTMCell", "GRUCell", "RNNBase", "RNN", "LSTM", "GRU"]


def _get_activation(nonlinearity: str) -> type[nn.Module]:
    if nonlinearity == "tanh":
        return Tanh
    elif nonlinearity == "relu":
        return ReLU
    else:
        raise ValueError(
            f"Invalid nonlinearity '{nonlinearity}'. "
            "Supported nonlinearities are 'tanh' and 'relu'."
        )


class RNNCell(nn.Module):
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        bias: bool = True,
        nonlinearity: Literal["tanh", "relu"] = "tanh",
    ) -> None:
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        self.nonlinearity = _get_activation(nonlinearity)()

        sqrt_k = 1.0 / (hidden_size**0.5)
        self.weight_ih = nn.Parameter(
            lucid.random.uniform(-sqrt_k, sqrt_k, (self.hidden_size, self.input_size))
        )
        self.weight_hh = nn.Parameter(
            lucid.random.uniform(-sqrt_k, sqrt_k, (self.hidden_size, self.hidden_size))
        )

        if self.bias:
            self.bias_ih = nn.Parameter(
                lucid.random.uniform(-sqrt_k, sqrt_k, self.hidden_size)
            )
            self.bias_hh = nn.Parameter(
                lucid.random.uniform(-sqrt_k, sqrt_k, self.hidden_size)
            )
        else:
            self.bias_ih = None
            self.bias_hh = None

    def forward(self, input_: Tensor, hx: Tensor | None = None) -> Tensor:
        if input_.ndim not in (1, 2):
            raise ValueError(
                "RNNCell expected input with 1 or 2 dimensions, "
                f"got {input_.ndim} dimensions"
            )

        is_batched = input_.ndim == 2
        if not is_batched:
            input_ = input_.unsqueeze(axis=0)
        batch_size = input_.shape[0]

        if hx is None:
            hx = lucid.zeros(
                batch_size, self.hidden_size, dtype=input_.dtype, device=input_.device
            )
        else:
            if hx.ndim not in (1, 2):
                raise ValueError(
                    "RNNCell expected hidden state with 1 or 2 dimensions, "
                    f"got {hx.ndim} dimensions"
                )
            if hx.ndim == 1:
                hx = hx.unsqueeze(axis=0)

            if hx.shape[0] != batch_size or hx.shape[1] != self.hidden_size:
                raise ValueError(
                    "RNNCell expected hidden state with shape "
                    f"({batch_size}, {self.hidden_size}), got {hx.shape}"
                )

        hy = F.linear(input_, self.weight_ih, self.bias_ih)
        hy += F.linear(hx, self.weight_hh, self.bias_hh)
        ret = self.nonlinearity(hy)

        if not is_batched:
            ret = ret.squeeze(axis=0)
        return ret


class LSTMCell(nn.Module):
    def __init__(
        self, input_size: int, hidden_size: int, bias: bool = True, **kwargs
    ) -> None:
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias

        sqrt_k = 1.0 / (hidden_size**0.5)
        self.weight_ih = nn.Parameter(
            lucid.random.uniform(
                -sqrt_k, sqrt_k, (4 * self.hidden_size, self.input_size)
            )
        )
        self.weight_hh = nn.Parameter(
            lucid.random.uniform(
                -sqrt_k, sqrt_k, (4 * self.hidden_size, self.hidden_size)
            )
        )

        if self.bias:
            self.bias_ih = nn.Parameter(
                lucid.random.uniform(-sqrt_k, sqrt_k, 4 * self.hidden_size)
            )
            self.bias_hh = nn.Parameter(
                lucid.random.uniform(-sqrt_k, sqrt_k, 4 * self.hidden_size)
            )
        else:
            self.bias_ih = None
            self.bias_hh = None

    def forward(
        self, input_: Tensor, hx: tuple[Tensor, Tensor] | None = None
    ) -> tuple[Tensor, Tensor]:
        if input_.ndim not in (1, 2):
            raise ValueError(
                "LSTMCell expected input with 1 or 2 dimensions, "
                f"got {input_.ndim} dimensions"
            )

        is_batched = input_.ndim == 2
        if not is_batched:
            input_ = input_.unsqueeze(axis=0)
        batch_size = input_.shape[0]

        if hx is None:
            h_t = lucid.zeros(
                batch_size, self.hidden_size, dtype=input_.dtype, device=input_.device
            )
            c_t = lucid.zeros(
                batch_size, self.hidden_size, dtype=input_.dtype, device=input_.device
            )
        else:
            h_t, c_t = hx
            if h_t.ndim not in (1, 2) or c_t.ndim not in (1, 2):
                raise ValueError(
                    "LSTMCell expected hidden state and cell state with 1 or 2 dimensions"
                )

            if h_t.ndim == 1:
                h_t = h_t.unsqueeze(axis=0)
            if c_t.ndim == 1:
                c_t = c_t.unsqueeze(axis=0)

            if h_t.shape[0] != batch_size or h_t.shape[1] != self.hidden_size:
                raise ValueError(
                    "LSTMCell expected hidden state with shape "
                    f"({batch_size}, {self.hidden_size}), got {h_t.shape}"
                )
            if c_t.shape[0] != batch_size or c_t.shape[1] != self.hidden_size:
                raise ValueError(
                    "LSTMCell expected cell state with shape "
                    f"({batch_size}, {self.hidden_size}), got {c_t.shape}"
                )

        gates = F.linear(input_, self.weight_ih, self.bias_ih)
        gates += F.linear(h_t, self.weight_hh, self.bias_hh)

        i_t, f_t, g_t, o_t = lucid.split(gates, 4, axis=1)
        i_t = F.sigmoid(i_t)
        f_t = F.sigmoid(f_t)
        g_t = F.tanh(g_t)
        o_t = F.sigmoid(o_t)

        c_t = f_t * c_t + i_t * g_t
        h_t = o_t * F.tanh(c_t)

        if not is_batched:
            h_t = h_t.squeeze(axis=0)
            c_t = c_t.squeeze(axis=0)
        return h_t, c_t


class GRUCell(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, bias: bool = True) -> None:
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias

        sqrt_k = 1.0 / (hidden_size**0.5)
        self.weight_ih = nn.Parameter(
            lucid.random.uniform(
                -sqrt_k, sqrt_k, (3 * self.hidden_size, self.input_size)
            )
        )
        self.weight_hh = nn.Parameter(
            lucid.random.uniform(
                -sqrt_k, sqrt_k, (3 * self.hidden_size, self.hidden_size)
            )
        )

        if self.bias:
            self.bias_ih = nn.Parameter(
                lucid.random.uniform(-sqrt_k, sqrt_k, 3 * self.hidden_size)
            )
            self.bias_hh = nn.Parameter(
                lucid.random.uniform(-sqrt_k, sqrt_k, 3 * self.hidden_size)
            )
        else:
            self.bias_ih = None
            self.bias_hh = None

    def forward(self, input_: Tensor, hx: Tensor | None = None) -> Tensor:
        if input_.ndim not in (1, 2):
            raise ValueError(
                "GRUCell expected input with 1 or 2 dimensions, "
                f"got {input_.ndim} dimensions"
            )

        is_batched = input_.ndim == 2
        if not is_batched:
            input_ = input_.unsqueeze(axis=0)
        batch_size = input_.shape[0]

        if hx is None:
            hx = lucid.zeros(
                batch_size, self.hidden_size, dtype=input_.dtype, device=input_.device
            )
        else:
            if hx.ndim not in (1, 2):
                raise ValueError(
                    "GRUCell expected hidden state with 1 or 2 dimensions, "
                    f"got {hx.ndim} dimensions"
                )

            if hx.ndim == 1:
                hx = hx.unsqueeze(axis=0)
            if hx.shape[0] != batch_size or hx.shape[1] != self.hidden_size:
                raise ValueError(
                    "GRUCell expected hidden state with shape "
                    f"({batch_size}, {self.hidden_size}), got {hx.shape}"
                )

        input_gates = F.linear(input_, self.weight_ih, self.bias_ih)
        hidden_gates = F.linear(hx, self.weight_hh, self.bias_hh)

        i_r, i_z, i_n = lucid.split(input_gates, 3, axis=1)
        h_r, h_z, h_n = lucid.split(hidden_gates, 3, axis=1)

        r_t = F.sigmoid(i_r + h_r)
        z_t = F.sigmoid(i_z + h_z)
        n_t = F.tanh(i_n + r_t * h_n)

        h_t = (1 - z_t) * n_t + z_t * hx

        if not is_batched:
            h_t = h_t.squeeze(axis=0)
        return h_t


class RNNBase(nn.Module):
    def __init__(
        self,
        mode: Literal["RNN_TANH", "RNN_RELU", "LSTM", "GRU"],
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        bias: bool = True,
        batch_first: bool = False,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.is_lstm = False
        cell_kwargs = {}
        nonlinearity = "tanh"

        if mode == "RNN_TANH":
            cell_cls = RNNCell
            cell_kwargs: dict[str, object] = {"nonlinearity": nonlinearity}
        elif mode == "RNN_RELU":
            nonlinearity = "relu"
            cell_cls = RNNCell
            cell_kwargs = {"nonlinearity": nonlinearity}
        elif mode == "LSTM":
            cell_cls = LSTMCell
            self.is_lstm = True
        elif mode == "GRU":
            cell_cls = GRUCell
        else:
            raise ValueError(
                f"Invalid mode '{mode}'. Supported modes are 'RNN_TANH', "
                "'RNN_RELU', 'LSTM', or 'GRU'."
            )

        self.mode = mode
        self.nonlinearity = nonlinearity

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first
        self.dropout = float(dropout)

        layers: list[nn.Module] = []
        for layer in range(num_layers):
            layer_input_size = input_size if layer == 0 else hidden_size
            layers.append(
                cell_cls(
                    input_size=layer_input_size,
                    hidden_size=hidden_size,
                    bias=bias,
                    **cell_kwargs,
                )
            )
        self.layers = nn.ModuleList(layers)

    def _init_hidden(
        self, batch_size: int, dtype: Numeric, device: _DeviceType
    ) -> Tensor | tuple[Tensor, Tensor]:
        if self.is_lstm:
            h0 = lucid.zeros(
                self.num_layers,
                batch_size,
                self.hidden_size,
                dtype=dtype,
                device=device,
            )
            c0 = lucid.zeros(
                self.num_layers,
                batch_size,
                self.hidden_size,
                dtype=dtype,
                device=device,
            )
            return h0, c0
        return lucid.zeros(
            self.num_layers, batch_size, self.hidden_size, dtype=dtype, device=device
        )

    def forward(
        self, input_: Tensor, hx: Tensor | tuple[Tensor, Tensor] | None = None
    ) -> tuple[Tensor, Tensor] | tuple[Tensor, tuple[Tensor, Tensor]]:
        if input_.ndim != 3:
            raise ValueError(
                f"RNNBase expected input with 3 dimensions, got {input_.ndim} dimensions"
            )

        if self.batch_first:
            input_ = input_.swapaxes(0, 1)

        seq_len, batch_size, feat = input_.shape
        if feat != self.input_size:
            raise ValueError(
                f"RNNBase expected input with feature size {self.input_size}, got {feat}"
            )

        if self.is_lstm:
            if hx is None:
                hx = self._init_hidden(batch_size, input_.dtype, input_.device)
            if not (
                isinstance(hx, (tuple, list))
                and len(hx) == 2
                and isinstance(hx[0], Tensor)
                and isinstance(hx[1], Tensor)
            ):
                raise ValueError("LSTM expects hx as a tuple of (h_0, c_0)")

            h0, c0 = hx
            if h0.ndim == 2:
                h0 = h0.unsqueeze(axis=0)
            if c0.ndim == 2:
                c0 = c0.unsqueeze(axis=0)

            if h0.ndim != 3 or c0.ndim != 3:
                raise ValueError("LSTM expects h_0 and c_0 with 3 dimensions")
            if h0.shape[0] != self.num_layers or c0.shape[0] != self.num_layers:
                raise ValueError("Incorrect number of layers in h_0 or c_0")
            if h0.shape[1] != batch_size or c0.shape[1] != batch_size:
                raise ValueError("Incorrect batch size in h_0 or c_0")
            if h0.shape[2] != self.hidden_size or c0.shape[2] != self.hidden_size:
                raise ValueError("Incorrect hidden size in h_0 or c_0")

            hx_h, hx_c = h0, c0

        else:
            if hx is None:
                hx = self._init_hidden(batch_size, input_.dtype, input_.device)
            if hx.ndim == 2:
                hx = hx.unsqueeze(axis=0)
            if hx.ndim != 3:
                raise ValueError(
                    f"RNNBase expected hidden state with 3 dimensions, got {hx.ndim} dimensions"
                )

            if hx.shape[0] != self.num_layers or hx.shape[1] != batch_size:
                raise ValueError("hx has incorrect shape")
            if hx.shape[2] != self.hidden_size:
                raise ValueError("Incorrect hidden size in hx")

        layer_input = input_
        h_n_list: list[Tensor] = []
        c_n_list: list[Tensor] | None = [] if self.is_lstm else None

        for layer_idx, cell in enumerate(self.layers):
            if self.is_lstm:
                h_t = hx_h[layer_idx]
                c_t = hx_c[layer_idx]
            else:
                h_t = hx[layer_idx]
            outputs = []

            for t in range(seq_len):
                if self.is_lstm:
                    h_t, c_t = cell(layer_input[t], (h_t, c_t))
                    outputs.append(h_t.unsqueeze(axis=0))
                else:
                    h_t = cell(layer_input[t], h_t)
                    outputs.append(h_t.unsqueeze(axis=0))

            layer_output = lucid.concatenate(tuple(outputs), axis=0)

            if self.training and self.dropout > 0.0 and layer_idx < self.num_layers - 1:
                layer_output = F.dropout(layer_output, p=self.dropout)

            h_n_list.append(h_t.unsqueeze(axis=0))
            if self.is_lstm and c_n_list is not None:
                c_n_list.append(c_t.unsqueeze(axis=0))
            layer_input = layer_output

        output = layer_input
        h_n = lucid.concatenate(tuple(h_n_list), axis=0)
        if self.is_lstm and c_n_list is not None:
            c_n = lucid.concatenate(tuple(c_n_list), axis=0)

        if self.batch_first:
            output = output.swapaxes(0, 1)

        if self.is_lstm and c_n_list is not None:
            return output, (h_n, c_n)
        return output, h_n


class RNN(RNNBase):
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        nonlinearity: Literal["tanh", "relu"] = "tanh",
        bias: bool = True,
        batch_first: bool = False,
        dropout: float = 0.0,
    ) -> None:
        if nonlinearity == "tanh":
            mode = "RNN_TANH"
        elif nonlinearity == "relu":
            mode = "RNN_RELU"
        else:
            raise ValueError(
                f"Invalid nonlinearity '{nonlinearity}'. "
                "Supported nonlinearities are 'tanh' and 'relu'."
            )

        super().__init__(
            mode=mode,
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bias=bias,
            batch_first=batch_first,
            dropout=dropout,
        )


class LSTM(RNNBase):
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        bias: bool = True,
        batch_first: bool = False,
        dropout: float = 0.0,
    ) -> None:
        mode = "LSTM"
        super().__init__(
            mode=mode,
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bias=bias,
            batch_first=batch_first,
            dropout=dropout,
        )


class GRU(RNNBase):
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        bias: bool = True,
        batch_first: bool = False,
        dropout: float = 0.0,
    ) -> None:
        mode = "GRU"
        super().__init__(
            mode=mode,
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bias=bias,
            batch_first=batch_first,
            dropout=dropout,
        )
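For orientation, here is a minimal usage sketch of the new recurrent modules. The constructor arguments, the input layout (`(seq_len, batch, input_size)` by default, or `(batch, seq_len, input_size)` with `batch_first=True`), and the return values follow the code above; treating the modules as callables and the specific sizes chosen here are assumptions for illustration.

import lucid
import lucid.nn as nn

seq_len, batch, input_size, hidden_size = 12, 4, 8, 16

# Default layout is (seq_len, batch, input_size).
x = lucid.random.uniform(-1.0, 1.0, (seq_len, batch, input_size))

rnn = nn.RNN(input_size, hidden_size, num_layers=2, nonlinearity="tanh")
out, h_n = rnn(x)  # out: (seq_len, batch, hidden), h_n: (num_layers, batch, hidden)

lstm = nn.LSTM(input_size, hidden_size)
out, (h_n, c_n) = lstm(x)  # LSTM additionally returns the final cell state c_n

cell = nn.GRUCell(input_size, hidden_size)
h_t = cell(x[0])  # single time step; the hidden state defaults to zeros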