lucid-dl 2.7.9-py3-none-any.whl → 2.8.0-py3-none-any.whl
- lucid/nn/modules/__init__.py +1 -0
- lucid/nn/modules/rnn.py +97 -0
- lucid/optim/lr_scheduler/_schedulers.py +3 -1
- {lucid_dl-2.7.9.dist-info → lucid_dl-2.8.0.dist-info}/METADATA +1 -1
- {lucid_dl-2.7.9.dist-info → lucid_dl-2.8.0.dist-info}/RECORD +8 -7
- {lucid_dl-2.7.9.dist-info → lucid_dl-2.8.0.dist-info}/WHEEL +0 -0
- {lucid_dl-2.7.9.dist-info → lucid_dl-2.8.0.dist-info}/licenses/LICENSE +0 -0
- {lucid_dl-2.7.9.dist-info → lucid_dl-2.8.0.dist-info}/top_level.txt +0 -0
lucid/nn/modules/__init__.py
CHANGED
lucid/nn/modules/rnn.py
ADDED
@@ -0,0 +1,97 @@
+from typing import Literal
+
+import lucid
+import lucid.nn as nn
+import lucid.nn.functional as F
+
+from lucid._tensor import Tensor
+
+from .activation import Tanh, ReLU
+
+
+__all__ = ["RNNCell"]
+
+
+def _get_activation(nonlinearity: str) -> type[nn.Module]:
+    if nonlinearity == "tanh":
+        return Tanh
+    elif nonlinearity == "relu":
+        return ReLU
+    else:
+        raise ValueError(
+            f"Invalid nonlinearity '{nonlinearity}'. "
+            "Supported nonlinearities are 'tanh' and 'relu'."
+        )
+
+
+class RNNCell(nn.Module):
+    def __init__(
+        self,
+        input_size: int,
+        hidden_size: int,
+        bias: bool = True,
+        nonlinearity: Literal["tanh", "relu"] = "tanh",
+    ) -> None:
+        super().__init__()
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+        self.bias = bias
+        self.nonlinearity = _get_activation(nonlinearity)()
+
+        sqrt_k = 1.0 / (hidden_size**0.5)
+        self.weight_ih = nn.Parameter(
+            lucid.random.uniform(-sqrt_k, sqrt_k, (self.hidden_size, self.input_size))
+        )
+        self.weight_hh = nn.Parameter(
+            lucid.random.uniform(-sqrt_k, sqrt_k, (self.hidden_size, self.hidden_size))
+        )
+
+        if self.bias:
+            self.bias_ih = nn.Parameter(
+                lucid.random.uniform(-sqrt_k, sqrt_k, self.hidden_size)
+            )
+            self.bias_hh = nn.Parameter(
+                lucid.random.uniform(-sqrt_k, sqrt_k, self.hidden_size)
+            )
+        else:
+            self.bias_ih = None
+            self.bias_hh = None
+
+    def forward(self, input_: Tensor, hx: Tensor | None = None) -> Tensor:
+        if input_.ndim not in (1, 2):
+            raise ValueError(
+                "RNNCell expected input with 1 or 2 dimensions, "
+                f"got {input_.ndim} dimensions"
+            )
+
+        is_batched = input_.ndim == 2
+        if not is_batched:
+            input_ = input_.unsqueeze(axis=0)
+        batch_size = input_.shape[0]
+
+        if hx is None:
+            hx = lucid.zeros(
+                batch_size, self.hidden_size, dtype=input_.dtype, device=input_.device
+            )
+        else:
+            if hx.ndim not in (1, 2):
+                raise ValueError(
+                    "RNNCell expected hidden state with 1 or 2 dimensions, "
+                    f"got {hx.ndim} dimensions"
+                )
+            if hx.ndim == 1:
+                hx = hx.unsqueeze(axis=0)
+
+            if hx.shape[0] != batch_size or hx.shape[1] != self.hidden_size:
+                raise ValueError(
+                    "RNNCell expected hidden state with shape "
+                    f"({batch_size}, {self.hidden_size}), got {hx.shape}"
+                )
+
+        hy = F.linear(input_, self.weight_ih, self.bias_ih)
+        hy += F.linear(hx, self.weight_hh, self.bias_hh)
+        ret = self.nonlinearity(hy)
+
+        if not is_batched:
+            ret = ret.squeeze(axis=0)
+        return ret
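The new cell implements the classic Elman update, h' = nonlinearity(W_ih x + b_ih + W_hh h + b_hh), with every weight and bias drawn uniformly from ±1/√hidden_size. A minimal usage sketch of the added API (sizes are hypothetical; it assumes RNNCell is re-exported at lucid.nn and that module instances are callable via forward, as the torch-style API suggests):

import lucid
import lucid.nn as nn

cell = nn.RNNCell(input_size=8, hidden_size=16, nonlinearity="tanh")

# Drive the cell step by step over a toy 5-step sequence, batch size 4.
xs = [lucid.random.uniform(-1.0, 1.0, (4, 8)) for _ in range(5)]
hx = None  # forward() creates a zero initial state when hx is None
for x in xs:
    hx = cell(x, hx)  # hx: (4, 16); assumes __call__ dispatches to forward

Unbatched inputs of shape (input_size,) are also accepted: forward() unsqueezes them to a batch of one and squeezes the result back to (hidden_size,).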
lucid/optim/lr_scheduler/_schedulers.py
CHANGED
@@ -298,4 +298,6 @@ class NoamScheduler(LRScheduler):
         decay_term = step_num**-0.5
         lr_factor = scale * min(decay_term, warmup_term)
 
-
+        # Noam's schedule computes the absolute learning rate, so we ignore
+        # the optimizer's initial lr (base_lr) when returning the new values.
+        return [lr_factor for _ in self.base_lrs]
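For reference, NoamScheduler follows the schedule from "Attention Is All You Need": lr(step) = d_model^-0.5 · min(step^-0.5, step · warmup^-1.5). A standalone sketch of the value the new return produces; the definitions of scale and warmup_term are assumptions reconstructed from that formula, since only decay_term and the min() appear in the hunk:

def noam_lr(step_num: int, d_model: int = 512, warmup_steps: int = 4000) -> float:
    scale = d_model**-0.5  # assumed definition of `scale`
    warmup_term = step_num * warmup_steps**-1.5  # assumed definition of `warmup_term`
    decay_term = step_num**-0.5  # as in the hunk above
    return scale * min(decay_term, warmup_term)  # absolute lr, independent of base_lr

With these defaults the factor rises linearly to about 7.0e-4 at step 4000, then decays as step^-0.5, which is why the new code returns it unscaled for every parameter group.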
{lucid_dl-2.7.9.dist-info → lucid_dl-2.8.0.dist-info}/RECORD
CHANGED
@@ -89,7 +89,7 @@ lucid/nn/functional/_spatial.py,sha256=lazoSvVMFcauBWRbMOqmkgixA5bDes6scGHVWCgVm
 lucid/nn/functional/_util.py,sha256=oaMbR76XuFrFtEjLCUEQBPgfFObP98WGnkGQLtFz2uk,4949
 lucid/nn/init/__init__.py,sha256=YFi-HD2TEglweJ-gyX3n4UVZYzd70gcUi1dBu6hnOAY,1533
 lucid/nn/init/_dist.py,sha256=zk4IoECjCvs-U4DfN17-6cqBu0-fpunAWYAKT4YJPRE,2023
-lucid/nn/modules/__init__.py,sha256=
+lucid/nn/modules/__init__.py,sha256=mol5Gfy-3ab5hBYZRxX0vjiI0w5VyKtBxVwj_vrOAZs,285
 lucid/nn/modules/activation.py,sha256=CpiKpzgZHoCp8UO5taCJ9BuwFz5mYUs0o1_TQcEwQbQ,2823
 lucid/nn/modules/attention.py,sha256=pZi7IGsNFu2xCmeLMuyWgveMyi2QXtaKRKQ70yAeE0c,4407
 lucid/nn/modules/conv.py,sha256=KbtInQgKSw3U_qXiqy7x53DZM9YAMUq7sFas1nV7NxY,13932
@@ -99,6 +99,7 @@ lucid/nn/modules/linear.py,sha256=87cuFWYct9JlmtVC3jGR-8eouxxzANaVA6cd7p9r2Ho,28
 lucid/nn/modules/loss.py,sha256=pjEMIruhtpTHhHFsNThS9LFz-aI_DAXLqMV8KRXydEg,3431
 lucid/nn/modules/norm.py,sha256=qaaVQ2vfOUkPRLTHT4hgsRNxxN1--kdEhlrKXJmE--w,6803
 lucid/nn/modules/pool.py,sha256=ymVnS2NZjh08Tw0VeOfkB6AVrMeLmCKvgxkmEO3KUuw,5044
+lucid/nn/modules/rnn.py,sha256=iMwWLtTC9i3k_h8pI6to-2rJfiqkXX9hzeDDJT3i7XU,2968
 lucid/nn/modules/sparse.py,sha256=EpjiviED2nI55wUjh1twFwa4Lvlrzw0TR6lpCDGeSbo,1147
 lucid/nn/modules/transformer.py,sha256=z56emF_eX18pxRELjfmmsY-7Bn9h2yjIdxCaxs6YDwA,11246
 lucid/nn/modules/vision.py,sha256=8xYasT7TNj4NXwMwwJIw1nbV1paeWEFg_ZohXn9kZBg,1579
@@ -110,7 +111,7 @@ lucid/optim/prop.py,sha256=CbsWmoBb_g_8z16M3T6dMoSR9c72hm8M375IT1UHjpw,4740
 lucid/optim/sgd.py,sha256=DBZ1ZXQ9TfKZCRECfNRMDH9mvqUWCOPdY5TobnVxpz8,4477
 lucid/optim/lr_scheduler/__init__.py,sha256=kUoyN2g9nwTtEAqEVij832WSRvzEpKZywSJdfD7MQvY,58
 lucid/optim/lr_scheduler/_base.py,sha256=NNJnjwmJpsRXathrbLtH4tjfBHtwOiJ5HwF1_S6Ym5c,3092
-lucid/optim/lr_scheduler/_schedulers.py,sha256=
+lucid/optim/lr_scheduler/_schedulers.py,sha256=x6naustFYJUD8SEwdwzP8Wv4pDXWg-yp1HD5scvf1ZY,9365
 lucid/random/__init__.py,sha256=s8EAaKhEiTKT_vYjP4IFHx0xQVa1jqc_qIyvMauUu7M,2727
 lucid/random/_func.py,sha256=1Lu4m-ciEK037chNDGqv_j00RgGGzQ7UfslSfYActUk,2232
 lucid/transforms/__init__.py,sha256=DGznMbqhXdU9FLDMKnJawScO4HCqu40Sf_j4vJGJrjc,90
@@ -120,8 +121,8 @@ lucid/visual/__init__.py,sha256=6TuFDfmXTwpLyHl7_KqBfdzW6zqHjGzIFvymjFPlvjI,21
 lucid/visual/graph.py,sha256=YjpIDM_lloZARw3sCBiXPl_hT5A2gTk2fEHvwvJWXTk,4599
 lucid/weights/__init__.py,sha256=z1AikA3rOEeckWGkYWlcZkxNlJo9Xwa39PL6ly3hWnc,8801
 lucid/weights/__init__.pyi,sha256=lFonYC3cUx2Idolf3AEPnjFcyqcn3UDU84oJlZafqLY,3013
-lucid_dl-2.
-lucid_dl-2.
-lucid_dl-2.
-lucid_dl-2.
-lucid_dl-2.
+lucid_dl-2.8.0.dist-info/licenses/LICENSE,sha256=vxRFYnVD1IeYtsvw-KmoElfqrjxKHv1h9YTvsG54loQ,1065
+lucid_dl-2.8.0.dist-info/METADATA,sha256=zLT-u67ODSZMq4vod9l1GqZCQN-0XDl_lKXSTdVcSzY,11519
+lucid_dl-2.8.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+lucid_dl-2.8.0.dist-info/top_level.txt,sha256=uzP_qBx9iNWIHKJRlElYcBLYVqMpdm9Q1Ma63QPYbFc,6
+lucid_dl-2.8.0.dist-info/RECORD,,
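Each RECORD entry has the form path,sha256=<digest>,<size>, where the digest is the urlsafe-base64 sha256 of the file with '=' padding stripped, per the wheel spec. A small sketch (the helper name is ours) for recomputing an entry, e.g. to verify the new rnn.py line above against an unpacked wheel:

import base64
import hashlib

def record_entry(path: str) -> str:
    # urlsafe base64 of the raw sha256 digest, with '=' padding stripped
    data = open(path, "rb").read()
    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b"=")
    return f"{path},sha256={digest.decode()},{len(data)}"

# record_entry("lucid/nn/modules/rnn.py") should reproduce
# 'lucid/nn/modules/rnn.py,sha256=iMwWLtTC9i3k_h8pI6to-2rJfiqkXX9hzeDDJT3i7XU,2968'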
{lucid_dl-2.7.9.dist-info → lucid_dl-2.8.0.dist-info}/WHEEL
File without changes
{lucid_dl-2.7.9.dist-info → lucid_dl-2.8.0.dist-info}/licenses/LICENSE
File without changes
{lucid_dl-2.7.9.dist-info → lucid_dl-2.8.0.dist-info}/top_level.txt
File without changes