comfi-fast-grnn-torch 0.0.1__tar.gz

@@ -0,0 +1,46 @@
+ Metadata-Version: 2.4
+ Name: comfi_fast_grnn_torch
+ Version: 0.0.1
+ Summary: A PyTorch implementation of Fast ULCNet
+ Author-email: Nicolas Arrieta Larraza <NIAL@bang-olufsen.dk>, Niels de Koeijer <NEMK@bang-olufsen.dk>
+ License: MIT
+ Project-URL: Homepage, https://github.com/narrietal/Fast-ULCNet
+ Requires-Python: >=3.10
+ Description-Content-Type: text/markdown
+ Requires-Dist: torch>=2.5.1
+ Requires-Dist: libsegmenter==1.0.4
+ Requires-Dist: torchinfo==1.8.0
+ Requires-Dist: CRM_pytorch==0.1.0
+
+ # fast-ulcnet-torch
+ Implements FastULCNet and Comfi-FastGRNN in torch.
+
+ ## Usage
+
+ The `ComfiFastGRNN` module is designed as a drop-in replacement for standard PyTorch RNN layers (such as `nn.LSTM` or `nn.GRU`), with added support for low-rank factorization and complementary filtering.
+
+ ### Basic Usage
+ Here is how to use the layer with default settings:
+
+ ```python
+ import torch
+ from comfi_fast_grnn_torch import ComfiFastGRNN
+
+ # 1. Initialize the layer
+ # Note: unlike nn.GRU, batch_first=True is the default for this implementation
+ model = ComfiFastGRNN(
+     input_size=32,
+     hidden_size=64,
+     num_layers=1,
+ )
+
+ # 2. Create dummy input: (batch_size, seq_len, input_size)
+ x = torch.randn(10, 50, 32)
+
+ # 3. Forward pass
+ # Returns output (all timesteps) and the final hidden state
+ output, h_n = model(x)
+
+ print(f"Output shape: {output.shape}")       # torch.Size([10, 50, 64])
+ print(f"Hidden state shape: {h_n.shape}")    # torch.Size([1, 10, 64])
+ ```
@@ -0,0 +1,32 @@
+ # fast-ulcnet-torch
+ Implements FastULCNet and Comfi-FastGRNN in torch.
+
+ ## Usage
+
+ The `ComfiFastGRNN` module is designed as a drop-in replacement for standard PyTorch RNN layers (such as `nn.LSTM` or `nn.GRU`), with added support for low-rank factorization and complementary filtering.
+
+ ### Basic Usage
+ Here is how to use the layer with default settings:
+
+ ```python
+ import torch
+ from comfi_fast_grnn_torch import ComfiFastGRNN
+
+ # 1. Initialize the layer
+ # Note: unlike nn.GRU, batch_first=True is the default for this implementation
+ model = ComfiFastGRNN(
+     input_size=32,
+     hidden_size=64,
+     num_layers=1,
+ )
+
+ # 2. Create dummy input: (batch_size, seq_len, input_size)
+ x = torch.randn(10, 50, 32)
+
+ # 3. Forward pass
+ # Returns output (all timesteps) and the final hidden state
+ output, h_n = model(x)
+
+ print(f"Output shape: {output.shape}")       # torch.Size([10, 50, 64])
+ print(f"Hidden state shape: {h_n.shape}")    # torch.Size([1, 10, 64])
+ ```
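+
+ ### Low-Rank and Bidirectional Options
+ The constructor also exposes the cell's low-rank factorization parameters (`w_rank`, `u_rank`), the complementary-filter initializers (`gamma_init`, `lambda_init`), and the usual `num_layers`, `bidirectional`, and `dropout` options (see `ComfiFastGRNN.py`). A minimal sketch with illustrative, untuned values:
+
+ ```python
+ import torch
+ from comfi_fast_grnn_torch import ComfiFastGRNN
+
+ model = ComfiFastGRNN(
+     input_size=32,
+     hidden_size=64,
+     num_layers=2,
+     bidirectional=True,
+     dropout=0.1,
+     w_rank=16,          # factorize W = W_1 @ W_2 with inner rank 16
+     u_rank=16,          # factorize U = U_1 @ U_2 with inner rank 16
+     gamma_init=0.999,   # complementary-filter weight on the hidden state
+     lambda_init=0.0,    # complementary-filter drift parameter
+ )
+
+ x = torch.randn(10, 50, 32)
+ output, h_n = model(x)
+
+ print(f"Output shape: {output.shape}")       # torch.Size([10, 50, 128]) = hidden_size * 2 directions
+ print(f"Hidden state shape: {h_n.shape}")    # torch.Size([4, 10, 64])   = num_layers * 2 directions
+ ```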
@@ -0,0 +1,24 @@
+ [build-system]
+ requires = ["setuptools>=61.0"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "comfi_fast_grnn_torch"
+ version = "0.0.1"
+ description = "A PyTorch implementation of Fast ULCNet"
+ readme = "README.md"
+ requires-python = ">=3.10"
+ license = {text = "MIT"}
+ authors = [
+     {name = "Nicolas Arrieta Larraza", email = "NIAL@bang-olufsen.dk"},
+     {name = "Niels de Koeijer", email = "NEMK@bang-olufsen.dk"}
+ ]
+ dependencies = [
+     "torch>=2.5.1",
+     "libsegmenter==1.0.4",
+     "torchinfo==1.8.0",
+     "CRM_pytorch==0.1.0",
+ ]
+
+ [project.urls]
+ "Homepage" = "https://github.com/narrietal/Fast-ULCNet"
@@ -0,0 +1,4 @@
+ [egg_info]
+ tag_build =
+ tag_date = 0
+
@@ -0,0 +1,290 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ # -------------------------------
+ # Auxiliary non-linearity helper function
+ # -------------------------------
+ def gen_non_linearity(A, non_linearity):
+     """
+     Applies the activation named by `non_linearity` (or the callable itself) to the tensor A.
+     """
+     if non_linearity == "tanh":
+         return torch.tanh(A)
+     elif non_linearity == "sigmoid":
+         return torch.sigmoid(A)
+     elif non_linearity == "relu":
+         return torch.relu(A)
+     elif non_linearity == "quantTanh":
+         return torch.clamp(A, -1.0, 1.0)
+     elif non_linearity == "quantSigm":
+         A = (A + 1.0) / 2.0
+         return torch.clamp(A, 0.0, 1.0)
+     elif non_linearity == "quantSigm4":
+         A = (A + 2.0) / 4.0
+         return torch.clamp(A, 0.0, 1.0)
+     elif callable(non_linearity):
+         return non_linearity(A)
+     else:
+         raise ValueError(
+             "non_linearity must be one of ['tanh', 'sigmoid', 'relu', 'quantTanh', 'quantSigm', 'quantSigm4'] or callable"
+         )
+
+ # -------------------------------
+ # Comfi-FastGRNN cell
+ # -------------------------------
+ class ComfiFastGRNNCell(nn.Module):
+     '''
+     Comfi-FastGRNN Cell
+
+     This class is ported to PyTorch syntax from the official FastGRNN cell code, and
+     it is extended with the trainable complementary-filter approach suggested in our paper.
+
+     Original FastGRNN cell code available at: https://github.com/microsoft/EdgeML/tree/master
+
+     The cell has both full-rank and low-rank formulations and
+     multiple activation functions for the gates.
+
+     hidden_size = number of hidden units
+
+     gate_non_linearity = non-linearity for the gate; can be chosen from
+     [tanh, sigmoid, relu, quantTanh, quantSigm, quantSigm4]
+
+     update_non_linearity = non-linearity for the final RNN update;
+     can be chosen from [tanh, sigmoid, relu, quantTanh, quantSigm, quantSigm4]
+
+     w_rank = rank of the W matrix (creates two matrices if not None)
+     u_rank = rank of the U matrix (creates two matrices if not None)
+     zeta_init = init for zeta, the scale parameter
+     nu_init = init for nu, the translation parameter
+     lambda_init = init value for lambda, the CF drift modulation parameter
+     gamma_init = init value for gamma, the CF hidden-state contribution parameter
+
+     Equations of the RNN state update:
+
+     z_t = gate_nl(W x_t + U h_{t-1} + B_g)
+     h_t^ = update_nl(W x_t + U h_{t-1} + B_h)
+     h_t = z_t * h_{t-1} + (sigmoid(zeta) * (1 - z_t) + sigmoid(nu)) * h_t^
+     h_t_comfi = gamma * h_t + (1 - gamma) * lambda
+
+     W and U can be further parameterized into low-rank versions by
+     W = matmul(W_1, W_2) and U = matmul(U_1, U_2)
+     '''
+
+     def __init__(
+         self,
+         input_size,
+         hidden_size,
+         gate_non_linearity="sigmoid",
+         update_non_linearity="tanh",
+         w_rank=None,
+         u_rank=None,
+         zeta_init=1.0,
+         nu_init=-4.0,
+         lambda_init=0.0,
+         gamma_init=0.999,
+     ):
+         super().__init__()
+         self.hidden_size = hidden_size
+         self.gate_non_linearity = gate_non_linearity
+         self.update_non_linearity = update_non_linearity
+         self.w_rank = w_rank
+         self.u_rank = u_rank
+
+         # --- Weight definitions ---
+         if w_rank is None:
+             self.w_matrix = nn.Parameter(torch.randn(input_size, hidden_size) * 0.1)
+         else:
+             self.w_matrix_1 = nn.Parameter(torch.randn(input_size, w_rank) * 0.1)
+             self.w_matrix_2 = nn.Parameter(torch.randn(w_rank, hidden_size) * 0.1)
+
+         if u_rank is None:
+             self.u_matrix = nn.Parameter(torch.randn(hidden_size, hidden_size) * 0.1)
+         else:
+             self.u_matrix_1 = nn.Parameter(torch.randn(hidden_size, u_rank) * 0.1)
+             self.u_matrix_2 = nn.Parameter(torch.randn(u_rank, hidden_size) * 0.1)
+
+         # --- Biases ---
+         self.bias_gate = nn.Parameter(torch.ones(1, hidden_size))
+         self.bias_update = nn.Parameter(torch.ones(1, hidden_size))
+
+         # --- Scalars ---
+         self.zeta = nn.Parameter(torch.tensor([[zeta_init]], dtype=torch.float32))
+         self.nu = nn.Parameter(torch.tensor([[nu_init]], dtype=torch.float32))
+         self.lambd = nn.Parameter(torch.tensor([lambda_init], dtype=torch.float32))
+         self.gamma = nn.Parameter(torch.tensor([gamma_init], dtype=torch.float32))
+
+     def forward(self, x, h_prev):
+         # Compute W*x
+         if self.w_rank is None:
+             W = x @ self.w_matrix
+         else:
+             W = x @ self.w_matrix_1 @ self.w_matrix_2
+
+         # Compute U*h_prev
+         if self.u_rank is None:
+             U = h_prev @ self.u_matrix
+         else:
+             U = h_prev @ self.u_matrix_1 @ self.u_matrix_2
+
+         # Gates
+         z = gen_non_linearity(W + U + self.bias_gate, self.gate_non_linearity)
+         h_hat = gen_non_linearity(W + U + self.bias_update, self.update_non_linearity)
+
+         # FastGRNN update
+         h = z * h_prev + (torch.sigmoid(self.zeta) * (1 - z) + torch.sigmoid(self.nu)) * h_hat
+
+         # Comfi-FastGRNN update
+         gamma_clamped = torch.clamp(self.gamma, 0.0, 1.0)
+         h_t_comfi = gamma_clamped * h + (1 - gamma_clamped) * self.lambd
+
+         return h_t_comfi
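+
+ # Illustrative usage sketch (not part of the module's code path): a single
+ # recurrent step with a standalone cell, using arbitrary sizes. With
+ # w_rank/u_rank set, W and U are stored in factorized form.
+ #
+ #     cell = ComfiFastGRNNCell(input_size=8, hidden_size=16, w_rank=4, u_rank=4)
+ #     h = torch.zeros(2, 16)              # (batch, hidden_size)
+ #     h = cell(torch.randn(2, 8), h)      # one step -> shape (2, 16)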
+
+ # -------------------------------
+ # Comfi-FastGRNN layer
+ # -------------------------------
+ class ComfiFastGRNN(nn.Module):
+     def __init__(
+         self,
+         input_size: int,
+         hidden_size: int,
+         num_layers: int = 1,
+         batch_first: bool = True,
+         dropout: float = 0.0,
+         bidirectional: bool = False,
+         # ---- Cell arguments (explicitly exposed) ----
+         gate_non_linearity: str = "sigmoid",
+         update_non_linearity: str = "tanh",
+         w_rank: int | None = None,
+         u_rank: int | None = None,
+         zeta_init: float = 1.0,
+         nu_init: float = -4.0,
+         lambda_init: float = 0.0,
+         gamma_init: float = 0.999,
+     ):
+         super().__init__()
+
+         # ---- Layer config ----
+         self.input_size = input_size
+         self.hidden_size = hidden_size
+         self.num_layers = num_layers
+         self.batch_first = batch_first
+         self.dropout = dropout
+         self.bidirectional = bidirectional
+         self.num_directions = 2 if bidirectional else 1
+
+         # ---- Cell config (stored for export / repr / checkpoint clarity) ----
+         self.gate_non_linearity = gate_non_linearity
+         self.update_non_linearity = update_non_linearity
+         self.w_rank = w_rank
+         self.u_rank = u_rank
+         self.zeta_init = zeta_init
+         self.nu_init = nu_init
+         self.lambda_init = lambda_init
+         self.gamma_init = gamma_init
+
+         # ---- Cells ----
+         self.cells_fwd = nn.ModuleList()
+         self.cells_bwd = nn.ModuleList() if bidirectional else None
+
+         for layer in range(num_layers):
+             in_size = input_size if layer == 0 else hidden_size * self.num_directions
+
+             self.cells_fwd.append(
+                 ComfiFastGRNNCell(
+                     input_size=in_size,
+                     hidden_size=hidden_size,
+                     gate_non_linearity=gate_non_linearity,
+                     update_non_linearity=update_non_linearity,
+                     w_rank=w_rank,
+                     u_rank=u_rank,
+                     zeta_init=zeta_init,
+                     nu_init=nu_init,
+                     lambda_init=lambda_init,
+                     gamma_init=gamma_init,
+                 )
+             )
+
+             if bidirectional:
+                 self.cells_bwd.append(
+                     ComfiFastGRNNCell(
+                         input_size=in_size,
+                         hidden_size=hidden_size,
+                         gate_non_linearity=gate_non_linearity,
+                         update_non_linearity=update_non_linearity,
+                         w_rank=w_rank,
+                         u_rank=u_rank,
+                         zeta_init=zeta_init,
+                         nu_init=nu_init,
+                         lambda_init=lambda_init,
+                         gamma_init=gamma_init,
+                     )
+                 )
+
+     def forward(self, x, h0=None):
+         """
+         x: (batch, seq_len, input_size) if batch_first=True,
+            (seq_len, batch, input_size) otherwise
+         h0: (num_layers * num_directions, batch, hidden_size)
+
+         Returns:
+             output: (batch, seq_len, hidden_size * num_directions) if batch_first=True
+             h_n: (num_layers * num_directions, batch, hidden_size)
+         """
+         if not self.batch_first:
+             x = x.transpose(0, 1)
+
+         batch_size, seq_len, _ = x.size()
+
+         if h0 is None:
+             h0 = x.new_zeros(
+                 self.num_layers * self.num_directions,
+                 batch_size,
+                 self.hidden_size,
+             )
+
+         layer_input = x
+         h_n = []
+
+         for layer in range(self.num_layers):
+             fw_cell = self.cells_fwd[layer]
+             h_fw = h0[layer * self.num_directions + 0]
+             fw_outs = []
+
+             # ---- forward ----
+             for t in range(seq_len):
+                 h_fw = fw_cell(layer_input[:, t, :], h_fw)
+                 fw_outs.append(h_fw.unsqueeze(1))
+
+             fw_out = torch.cat(fw_outs, dim=1)
+
+             if self.bidirectional:
+                 bw_cell = self.cells_bwd[layer]
+                 h_bw = h0[layer * self.num_directions + 1]
+                 bw_outs = []
+
+                 # ---- backward ----
+                 for t in reversed(range(seq_len)):
+                     h_bw = bw_cell(layer_input[:, t, :], h_bw)
+                     bw_outs.append(h_bw.unsqueeze(1))
+
+                 bw_outs.reverse()
+                 bw_out = torch.cat(bw_outs, dim=1)
+
+                 layer_out = torch.cat([fw_out, bw_out], dim=2)
+                 h_n.extend([h_fw, h_bw])
+             else:
+                 layer_out = fw_out
+                 h_n.append(h_fw)
+
+             if self.dropout > 0.0 and layer < self.num_layers - 1:
+                 layer_out = F.dropout(layer_out, p=self.dropout, training=self.training)
+
+             layer_input = layer_out
+
+         output = layer_input
+         h_n = torch.stack(h_n, dim=0)
+
+         if not self.batch_first:
+             output = output.transpose(0, 1)
+
+         return output, h_n
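+
+
+ if __name__ == "__main__":
+     # Minimal smoke test (illustrative sketch, arbitrary sizes): run a
+     # low-rank, bidirectional, two-layer model on random data and check shapes.
+     model = ComfiFastGRNN(
+         input_size=32,
+         hidden_size=64,
+         num_layers=2,
+         bidirectional=True,
+         dropout=0.1,
+         w_rank=16,
+         u_rank=16,
+     )
+     x = torch.randn(10, 50, 32)
+     output, h_n = model(x)
+     assert output.shape == (10, 50, 128)  # (batch, seq_len, hidden_size * 2)
+     assert h_n.shape == (4, 10, 64)       # (num_layers * 2, batch, hidden_size)
+     print("ok:", tuple(output.shape), tuple(h_n.shape))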
@@ -0,0 +1,46 @@
+ Metadata-Version: 2.4
+ Name: comfi_fast_grnn_torch
+ Version: 0.0.1
+ Summary: A PyTorch implementation of Fast ULCNet
+ Author-email: Nicolas Arrieta Larraza <NIAL@bang-olufsen.dk>, Niels de Koeijer <NEMK@bang-olufsen.dk>
+ License: MIT
+ Project-URL: Homepage, https://github.com/narrietal/Fast-ULCNet
+ Requires-Python: >=3.10
+ Description-Content-Type: text/markdown
+ Requires-Dist: torch>=2.5.1
+ Requires-Dist: libsegmenter==1.0.4
+ Requires-Dist: torchinfo==1.8.0
+ Requires-Dist: CRM_pytorch==0.1.0
+
+ # fast-ulcnet-torch
+ Implements FastULCNet and Comfi-FastGRNN in torch.
+
+ ## Usage
+
+ The `ComfiFastGRNN` module is designed as a drop-in replacement for standard PyTorch RNN layers (such as `nn.LSTM` or `nn.GRU`), with added support for low-rank factorization and complementary filtering.
+
+ ### Basic Usage
+ Here is how to use the layer with default settings:
+
+ ```python
+ import torch
+ from comfi_fast_grnn_torch import ComfiFastGRNN
+
+ # 1. Initialize the layer
+ # Note: unlike nn.GRU, batch_first=True is the default for this implementation
+ model = ComfiFastGRNN(
+     input_size=32,
+     hidden_size=64,
+     num_layers=1,
+ )
+
+ # 2. Create dummy input: (batch_size, seq_len, input_size)
+ x = torch.randn(10, 50, 32)
+
+ # 3. Forward pass
+ # Returns output (all timesteps) and the final hidden state
+ output, h_n = model(x)
+
+ print(f"Output shape: {output.shape}")       # torch.Size([10, 50, 64])
+ print(f"Hidden state shape: {h_n.shape}")    # torch.Size([1, 10, 64])
+ ```
@@ -0,0 +1,9 @@
+ README.md
+ pyproject.toml
+ src/comfi_fast_grnn_torch/ComfiFastGRNN.py
+ src/comfi_fast_grnn_torch/__init__.py
+ src/comfi_fast_grnn_torch.egg-info/PKG-INFO
+ src/comfi_fast_grnn_torch.egg-info/SOURCES.txt
+ src/comfi_fast_grnn_torch.egg-info/dependency_links.txt
+ src/comfi_fast_grnn_torch.egg-info/requires.txt
+ src/comfi_fast_grnn_torch.egg-info/top_level.txt
@@ -0,0 +1,4 @@
+ torch>=2.5.1
+ libsegmenter==1.0.4
+ torchinfo==1.8.0
+ CRM_pytorch==0.1.0