lucid-dl 2.7.9__py3-none-any.whl → 2.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -10,3 +10,4 @@ from .attention import *
10
10
  from .transformer import *
11
11
  from .sparse import *
12
12
  from .einops import *
13
+ from .rnn import *
@@ -0,0 +1,97 @@
1
+ from typing import Literal
2
+
3
+ import lucid
4
+ import lucid.nn as nn
5
+ import lucid.nn.functional as F
6
+
7
+ from lucid._tensor import Tensor
8
+
9
+ from .activation import Tanh, ReLU
10
+
11
+
12
+ __all__ = ["RNNCell"]
13
+
14
+
15
def _get_activation(nonlinearity: str) -> type[nn.Module]:
    """Map a nonlinearity name to its activation module class.

    Args:
        nonlinearity: Name of the activation; ``"tanh"`` or ``"relu"``.

    Returns:
        The ``Tanh`` or ``ReLU`` class itself (not an instance); the caller
        is responsible for instantiating it.

    Raises:
        ValueError: If ``nonlinearity`` is neither ``"tanh"`` nor ``"relu"``.
    """
    if nonlinearity == "tanh":
        return Tanh
    if nonlinearity == "relu":
        return ReLU

    # Anything else is unsupported; report the offending value verbatim.
    message = (
        f"Invalid nonlinearity '{nonlinearity}'. "
        "Supported nonlinearities are 'tanh' and 'relu'."
    )
    raise ValueError(message)
25
+
26
+
27
class RNNCell(nn.Module):
    """Single-step recurrent cell with a ``tanh`` or ``relu`` nonlinearity.

    For an input ``x`` and previous hidden state ``h`` the cell returns
    ``act(F.linear(x, weight_ih, bias_ih) + F.linear(h, weight_hh, bias_hh))``.

    Args:
        input_size: Number of features expected in the input.
        hidden_size: Number of features in the hidden state.
        bias: When ``False`` the two bias terms are ``None`` and are passed
            as such to ``F.linear``.
        nonlinearity: Name of the activation, ``"tanh"`` or ``"relu"``.

    Raises:
        ValueError: If ``nonlinearity`` is not a supported name.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        bias: bool = True,
        nonlinearity: Literal["tanh", "relu"] = "tanh",
    ) -> None:
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        # Stored as a module instance, not as the name string.
        self.nonlinearity = _get_activation(nonlinearity)()

        # All parameters are drawn with the same bounds, +/- 1/sqrt(hidden_size).
        sqrt_k = 1.0 / (hidden_size**0.5)
        self.weight_ih = nn.Parameter(
            lucid.random.uniform(-sqrt_k, sqrt_k, (self.hidden_size, self.input_size))
        )
        self.weight_hh = nn.Parameter(
            lucid.random.uniform(-sqrt_k, sqrt_k, (self.hidden_size, self.hidden_size))
        )

        if self.bias:
            self.bias_ih = nn.Parameter(
                lucid.random.uniform(-sqrt_k, sqrt_k, self.hidden_size)
            )
            self.bias_hh = nn.Parameter(
                lucid.random.uniform(-sqrt_k, sqrt_k, self.hidden_size)
            )
        else:
            self.bias_ih = None
            self.bias_hh = None

    def forward(self, input_: Tensor, hx: Tensor | None = None) -> Tensor:
        """Advance the cell by one step.

        Args:
            input_: Input of shape ``(input_size,)`` or
                ``(batch, input_size)``.
            hx: Previous hidden state of shape ``(hidden_size,)`` or
                ``(batch, hidden_size)``. When omitted, a zero state with the
                input's dtype and device is used.

        Returns:
            The new hidden state. It is 2-D for 2-D input; for 1-D input the
            leading batch axis added internally is squeezed away again.

        Raises:
            ValueError: If ``input_`` or ``hx`` is not 1-D or 2-D, if the
                input's feature size differs from ``input_size``, or if the
                hidden state's shape is not ``(batch, hidden_size)``.
        """
        if input_.ndim not in (1, 2):
            raise ValueError(
                "RNNCell expected input with 1 or 2 dimensions, "
                f"got {input_.ndim} dimensions"
            )

        # Work on a 2-D view throughout; remember whether to undo it at the end.
        is_batched = input_.ndim == 2
        if not is_batched:
            input_ = input_.unsqueeze(axis=0)
        batch_size = input_.shape[0]

        # Fail early with a clear message, mirroring the hidden-state check
        # below, instead of deferring a feature-size mismatch to F.linear.
        if input_.shape[1] != self.input_size:
            raise ValueError(
                "RNNCell expected input with feature size "
                f"{self.input_size}, got {input_.shape[1]}"
            )

        if hx is None:
            hx = lucid.zeros(
                batch_size, self.hidden_size, dtype=input_.dtype, device=input_.device
            )
        else:
            if hx.ndim not in (1, 2):
                raise ValueError(
                    "RNNCell expected hidden state with 1 or 2 dimensions, "
                    f"got {hx.ndim} dimensions"
                )
            if hx.ndim == 1:
                hx = hx.unsqueeze(axis=0)

            if hx.shape[0] != batch_size or hx.shape[1] != self.hidden_size:
                raise ValueError(
                    "RNNCell expected hidden state with shape "
                    f"({batch_size}, {self.hidden_size}), got {hx.shape}"
                )

        # Out-of-place sum: the first projection is never the target of an
        # in-place update.
        hy = F.linear(input_, self.weight_ih, self.bias_ih) + F.linear(
            hx, self.weight_hh, self.bias_hh
        )
        ret = self.nonlinearity(hy)

        if not is_batched:
            ret = ret.squeeze(axis=0)
        return ret
@@ -298,4 +298,6 @@ class NoamScheduler(LRScheduler):
298
298
  decay_term = step_num**-0.5
299
299
  lr_factor = scale * min(decay_term, warmup_term)
300
300
 
301
- return [base_lr * lr_factor for base_lr in self.base_lrs]
301
+ # Noam's schedule computes the absolute learning rate, so we ignore
302
+ # the optimizer's initial lr (base_lr) when returning the new values.
303
+ return [lr_factor for _ in self.base_lrs]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: lucid-dl
3
- Version: 2.7.9
3
+ Version: 2.8.0
4
4
  Summary: Lumerico's Comprehensive Interface for Deep Learning
5
5
  Home-page: https://github.com/ChanLumerico/lucid
6
6
  Author: ChanLumerico
@@ -89,7 +89,7 @@ lucid/nn/functional/_spatial.py,sha256=lazoSvVMFcauBWRbMOqmkgixA5bDes6scGHVWCgVm
89
89
  lucid/nn/functional/_util.py,sha256=oaMbR76XuFrFtEjLCUEQBPgfFObP98WGnkGQLtFz2uk,4949
90
90
  lucid/nn/init/__init__.py,sha256=YFi-HD2TEglweJ-gyX3n4UVZYzd70gcUi1dBu6hnOAY,1533
91
91
  lucid/nn/init/_dist.py,sha256=zk4IoECjCvs-U4DfN17-6cqBu0-fpunAWYAKT4YJPRE,2023
92
- lucid/nn/modules/__init__.py,sha256=RSH073CXMrZDhkaLKp5KqxwNDBrUc9vP4cwq5-b95kk,266
92
+ lucid/nn/modules/__init__.py,sha256=mol5Gfy-3ab5hBYZRxX0vjiI0w5VyKtBxVwj_vrOAZs,285
93
93
  lucid/nn/modules/activation.py,sha256=CpiKpzgZHoCp8UO5taCJ9BuwFz5mYUs0o1_TQcEwQbQ,2823
94
94
  lucid/nn/modules/attention.py,sha256=pZi7IGsNFu2xCmeLMuyWgveMyi2QXtaKRKQ70yAeE0c,4407
95
95
  lucid/nn/modules/conv.py,sha256=KbtInQgKSw3U_qXiqy7x53DZM9YAMUq7sFas1nV7NxY,13932
@@ -99,6 +99,7 @@ lucid/nn/modules/linear.py,sha256=87cuFWYct9JlmtVC3jGR-8eouxxzANaVA6cd7p9r2Ho,28
99
99
  lucid/nn/modules/loss.py,sha256=pjEMIruhtpTHhHFsNThS9LFz-aI_DAXLqMV8KRXydEg,3431
100
100
  lucid/nn/modules/norm.py,sha256=qaaVQ2vfOUkPRLTHT4hgsRNxxN1--kdEhlrKXJmE--w,6803
101
101
  lucid/nn/modules/pool.py,sha256=ymVnS2NZjh08Tw0VeOfkB6AVrMeLmCKvgxkmEO3KUuw,5044
102
+ lucid/nn/modules/rnn.py,sha256=iMwWLtTC9i3k_h8pI6to-2rJfiqkXX9hzeDDJT3i7XU,2968
102
103
  lucid/nn/modules/sparse.py,sha256=EpjiviED2nI55wUjh1twFwa4Lvlrzw0TR6lpCDGeSbo,1147
103
104
  lucid/nn/modules/transformer.py,sha256=z56emF_eX18pxRELjfmmsY-7Bn9h2yjIdxCaxs6YDwA,11246
104
105
  lucid/nn/modules/vision.py,sha256=8xYasT7TNj4NXwMwwJIw1nbV1paeWEFg_ZohXn9kZBg,1579
@@ -110,7 +111,7 @@ lucid/optim/prop.py,sha256=CbsWmoBb_g_8z16M3T6dMoSR9c72hm8M375IT1UHjpw,4740
110
111
  lucid/optim/sgd.py,sha256=DBZ1ZXQ9TfKZCRECfNRMDH9mvqUWCOPdY5TobnVxpz8,4477
111
112
  lucid/optim/lr_scheduler/__init__.py,sha256=kUoyN2g9nwTtEAqEVij832WSRvzEpKZywSJdfD7MQvY,58
112
113
  lucid/optim/lr_scheduler/_base.py,sha256=NNJnjwmJpsRXathrbLtH4tjfBHtwOiJ5HwF1_S6Ym5c,3092
113
- lucid/optim/lr_scheduler/_schedulers.py,sha256=wxG6XvlTozz2TP57yXQL-krtSiO0hy2bySZq_sRDjh0,9227
114
+ lucid/optim/lr_scheduler/_schedulers.py,sha256=x6naustFYJUD8SEwdwzP8Wv4pDXWg-yp1HD5scvf1ZY,9365
114
115
  lucid/random/__init__.py,sha256=s8EAaKhEiTKT_vYjP4IFHx0xQVa1jqc_qIyvMauUu7M,2727
115
116
  lucid/random/_func.py,sha256=1Lu4m-ciEK037chNDGqv_j00RgGGzQ7UfslSfYActUk,2232
116
117
  lucid/transforms/__init__.py,sha256=DGznMbqhXdU9FLDMKnJawScO4HCqu40Sf_j4vJGJrjc,90
@@ -120,8 +121,8 @@ lucid/visual/__init__.py,sha256=6TuFDfmXTwpLyHl7_KqBfdzW6zqHjGzIFvymjFPlvjI,21
120
121
  lucid/visual/graph.py,sha256=YjpIDM_lloZARw3sCBiXPl_hT5A2gTk2fEHvwvJWXTk,4599
121
122
  lucid/weights/__init__.py,sha256=z1AikA3rOEeckWGkYWlcZkxNlJo9Xwa39PL6ly3hWnc,8801
122
123
  lucid/weights/__init__.pyi,sha256=lFonYC3cUx2Idolf3AEPnjFcyqcn3UDU84oJlZafqLY,3013
123
- lucid_dl-2.7.9.dist-info/licenses/LICENSE,sha256=vxRFYnVD1IeYtsvw-KmoElfqrjxKHv1h9YTvsG54loQ,1065
124
- lucid_dl-2.7.9.dist-info/METADATA,sha256=-h9IwL5SzshStMIt4eYUB0rWJ7nTELdGWSD9FIIWM00,11519
125
- lucid_dl-2.7.9.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
126
- lucid_dl-2.7.9.dist-info/top_level.txt,sha256=uzP_qBx9iNWIHKJRlElYcBLYVqMpdm9Q1Ma63QPYbFc,6
127
- lucid_dl-2.7.9.dist-info/RECORD,,
124
+ lucid_dl-2.8.0.dist-info/licenses/LICENSE,sha256=vxRFYnVD1IeYtsvw-KmoElfqrjxKHv1h9YTvsG54loQ,1065
125
+ lucid_dl-2.8.0.dist-info/METADATA,sha256=zLT-u67ODSZMq4vod9l1GqZCQN-0XDl_lKXSTdVcSzY,11519
126
+ lucid_dl-2.8.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
127
+ lucid_dl-2.8.0.dist-info/top_level.txt,sha256=uzP_qBx9iNWIHKJRlElYcBLYVqMpdm9Q1Ma63QPYbFc,6
128
+ lucid_dl-2.8.0.dist-info/RECORD,,