comfi-fast-grnn-torch 0.0.1 (tar.gz)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- comfi_fast_grnn_torch-0.0.1/PKG-INFO +46 -0
- comfi_fast_grnn_torch-0.0.1/README.md +32 -0
- comfi_fast_grnn_torch-0.0.1/pyproject.toml +24 -0
- comfi_fast_grnn_torch-0.0.1/setup.cfg +4 -0
- comfi_fast_grnn_torch-0.0.1/src/comfi_fast_grnn_torch/ComfiFastGRNN.py +290 -0
- comfi_fast_grnn_torch-0.0.1/src/comfi_fast_grnn_torch/__init__.py +0 -0
- comfi_fast_grnn_torch-0.0.1/src/comfi_fast_grnn_torch.egg-info/PKG-INFO +46 -0
- comfi_fast_grnn_torch-0.0.1/src/comfi_fast_grnn_torch.egg-info/SOURCES.txt +9 -0
- comfi_fast_grnn_torch-0.0.1/src/comfi_fast_grnn_torch.egg-info/dependency_links.txt +1 -0
- comfi_fast_grnn_torch-0.0.1/src/comfi_fast_grnn_torch.egg-info/requires.txt +4 -0
- comfi_fast_grnn_torch-0.0.1/src/comfi_fast_grnn_torch.egg-info/top_level.txt +1 -0

comfi_fast_grnn_torch-0.0.1/PKG-INFO

@@ -0,0 +1,46 @@
+Metadata-Version: 2.4
+Name: comfi_fast_grnn_torch
+Version: 0.0.1
+Summary: A PyTorch implementation of Fast ULCNet
+Author-email: Nicolas Arrieta Larraza <NIAL@bang-olufsen.dk>, Niels de Koeijer <NEMK@bang-olufsen.dk>
+License: MIT
+Project-URL: Homepage, https://github.com/narrietal/Fast-ULCNet
+Requires-Python: >=3.10
+Description-Content-Type: text/markdown
+Requires-Dist: torch>=2.5.1
+Requires-Dist: libsegmenter==1.0.4
+Requires-Dist: torchinfo==1.8.0
+Requires-Dist: CRM_pytorch==0.1.0
+
+# fast-ulcnet-torch
+Implements FastULCNet and Comfi-FastGRNN in torch.
+
+## Usage
+
+The `ComfiFastGRNN` module is designed to be a drop-in replacement for standard PyTorch RNN layers (like `nn.LSTM` or `nn.GRU`), but with added support for low-rank factorization and complementary filtering.
+
+### Basic Implementation
+Here is how to use the layer with default settings:
+
+```python
+import torch
+from comfi_fast_grnn_torch.ComfiFastGRNN import ComfiFastGRNN
+
+# 1. Initialize the layer
+# batch_first=True is the default for this implementation
+model = ComfiFastGRNN(
+    input_size=32,
+    hidden_size=64,
+    num_layers=1
+)
+
+# 2. Create dummy input: (Batch Size, Sequence Length, Input Size)
+x = torch.randn(10, 50, 32)
+
+# 3. Forward pass
+# Returns output (all timesteps) and final hidden state
+output, h_n = model(x)
+
+print(f"Output shape: {output.shape}")       # torch.Size([10, 50, 64])
+print(f"Hidden state shape: {h_n.shape}")    # torch.Size([1, 10, 64])
+```

comfi_fast_grnn_torch-0.0.1/README.md

@@ -0,0 +1,32 @@
+# fast-ulcnet-torch
+Implements FastULCNet and Comfi-FastGRNN in torch.
+
+## Usage
+
+The `ComfiFastGRNN` module is designed to be a drop-in replacement for standard PyTorch RNN layers (like `nn.LSTM` or `nn.GRU`), but with added support for low-rank factorization and complementary filtering.
+
+### Basic Implementation
+Here is how to use the layer with default settings:
+
+```python
+import torch
+from comfi_fast_grnn_torch.ComfiFastGRNN import ComfiFastGRNN
+
+# 1. Initialize the layer
+# batch_first=True is the default for this implementation
+model = ComfiFastGRNN(
+    input_size=32,
+    hidden_size=64,
+    num_layers=1
+)
+
+# 2. Create dummy input: (Batch Size, Sequence Length, Input Size)
+x = torch.randn(10, 50, 32)
+
+# 3. Forward pass
+# Returns output (all timesteps) and final hidden state
+output, h_n = model(x)
+
+print(f"Output shape: {output.shape}")       # torch.Size([10, 50, 64])
+print(f"Hidden state shape: {h_n.shape}")    # torch.Size([1, 10, 64])
+```
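
The README shows only the default configuration. The constructor in `ComfiFastGRNN.py` (shown later in this diff) also exposes the low-rank and complementary-filter parameters; the following is a minimal sketch of a non-default configuration, using only arguments visible in that source:

```python
import torch
from comfi_fast_grnn_torch.ComfiFastGRNN import ComfiFastGRNN

# Low-rank factorization stores W = W_1 @ W_2 and U = U_1 @ U_2,
# cutting parameters when the rank is well below input/hidden size.
model = ComfiFastGRNN(
    input_size=32,
    hidden_size=64,
    num_layers=2,
    w_rank=8,          # rank of the input-to-hidden matrix W
    u_rank=8,          # rank of the hidden-to-hidden matrix U
    gamma_init=0.999,  # complementary-filter hidden-state contribution
    lambda_init=0.0,   # complementary-filter drift modulation
)

x = torch.randn(10, 50, 32)
output, h_n = model(x)
print(output.shape)  # torch.Size([10, 50, 64])
print(h_n.shape)     # torch.Size([2, 10, 64]) -- one state per layer
```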

comfi_fast_grnn_torch-0.0.1/pyproject.toml

@@ -0,0 +1,24 @@
+[build-system]
+requires = ["setuptools>=61.0"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "comfi_fast_grnn_torch"
+version = "0.0.1"
+description = "A PyTorch implementation of Fast ULCNet"
+readme = "README.md"
+requires-python = ">=3.10"
+license = {text = "MIT"}
+authors = [
+    {name = "Nicolas Arrieta Larraza", email = "NIAL@bang-olufsen.dk"},
+    {name = "Niels de Koeijer", email = "NEMK@bang-olufsen.dk"}
+]
+dependencies = [
+    "torch>=2.5.1",
+    "libsegmenter==1.0.4",
+    "torchinfo==1.8.0",
+    "CRM_pytorch==0.1.0",
+]
+
+[project.urls]
+"Homepage" = "https://github.com/narrietal/Fast-ULCNet"
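
The package pins `torchinfo` as a runtime dependency, so model inspection appears to be part of the intended workflow. As a sketch of a plausible use, comparing full-rank and low-rank parameter counts (this usage is my illustration, not shown anywhere in the package; `summary` is torchinfo's standard entry point):

```python
from torchinfo import summary
from comfi_fast_grnn_torch.ComfiFastGRNN import ComfiFastGRNN

# Full-rank: the cell stores W (32x64) and U (64x64) outright.
full_rank = ComfiFastGRNN(input_size=32, hidden_size=64)

# Low-rank: W = W_1 @ W_2 and U = U_1 @ U_2 with inner dimension 8.
low_rank = ComfiFastGRNN(input_size=32, hidden_size=64, w_rank=8, u_rank=8)

summary(full_rank, input_size=(1, 50, 32))  # ~6.3k trainable parameters
summary(low_rank, input_size=(1, 50, 32))   # ~1.9k trainable parameters
```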

comfi_fast_grnn_torch-0.0.1/src/comfi_fast_grnn_torch/ComfiFastGRNN.py

@@ -0,0 +1,290 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+# -------------------------------
+# Auxiliary non-linearity generation function
+# -------------------------------
+def gen_non_linearity(A, non_linearity):
+    """
+    Returns the requested activation applied to a tensor.
+    """
+    if non_linearity == "tanh":
+        return torch.tanh(A)
+    elif non_linearity == "sigmoid":
+        return torch.sigmoid(A)
+    elif non_linearity == "relu":
+        return torch.relu(A)
+    elif non_linearity == "quantTanh":
+        return torch.clamp(A, -1.0, 1.0)
+    elif non_linearity == "quantSigm":
+        A = (A + 1.0) / 2.0
+        return torch.clamp(A, 0.0, 1.0)
+    elif non_linearity == "quantSigm4":
+        A = (A + 2.0) / 4.0
+        return torch.clamp(A, 0.0, 1.0)
+    elif callable(non_linearity):
+        return non_linearity(A)
+    else:
+        raise ValueError(
+            "non_linearity must be one of ['tanh', 'sigmoid', 'relu', 'quantTanh', 'quantSigm', 'quantSigm4'] or callable"
+        )
+
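The quantized variants above approximate tanh and sigmoid with cheap hard clamps; a short sketch of their behaviour (the expected outputs follow directly from the definitions above):

```python
import torch
from comfi_fast_grnn_torch.ComfiFastGRNN import gen_non_linearity

A = torch.tensor([-3.0, -0.5, 0.0, 0.5, 3.0])
print(gen_non_linearity(A, "quantTanh"))  # tensor([-1.0000, -0.5000, 0.0000, 0.5000, 1.0000])
print(gen_non_linearity(A, "quantSigm"))  # tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])
print(gen_non_linearity(A, torch.relu))   # callables are applied as-is
```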
+# -------------------------------
+# Comfi-FastGRNN cell
+# -------------------------------
+class ComfiFastGRNNCell(nn.Module):
+    '''
+    Comfi-FastGRNN Cell
+
+    This class ports the official FastGRNN cell code to PyTorch syntax, and
+    extends it with the trainable complementary-filter approach suggested in our paper.
+
+    Original FastGRNN cell code available at: https://github.com/microsoft/EdgeML/tree/master
+
+    The cell has both Full Rank and Low Rank formulations and
+    multiple activation functions for the gates.
+
+    hidden_size = # hidden units
+
+    gate_non_linearity = nonlinearity for the gate; can be chosen from
+        [tanh, sigmoid, relu, quantTanh, quantSigm]
+
+    update_non_linearity = nonlinearity for the final rnn update;
+        can be chosen from [tanh, sigmoid, relu, quantTanh, quantSigm]
+
+    w_rank = rank of the W matrix (creates two matrices if not None)
+    u_rank = rank of the U matrix (creates two matrices if not None)
+    zeta_init = init for zeta, the scale param
+    nu_init = init for nu, the translation param
+    lambda_init = init value for lambda, the CF drift modulation parameter
+    gamma_init = init value for gamma, the CF hidden state contribution parameter
+
+    Equations of the RNN state update:
+
+        z_t = gate_nl(W x_t + U h_{t-1} + B_g)
+        h_t^ = update_nl(W x_t + U h_{t-1} + B_h)
+        h_t = z_t*h_{t-1} + (sigmoid(zeta)(1-z_t) + sigmoid(nu))*h_t^
+        h_t_comfi = gamma*h_t + (1-gamma)*lambda
+
+    W and U can be further parameterised into low-rank versions by
+    W = matmul(W_1, W_2) and U = matmul(U_1, U_2)
+    '''
+
+    def __init__(
+        self,
+        input_size,
+        hidden_size,
+        gate_non_linearity="sigmoid",
+        update_non_linearity="tanh",
+        w_rank=None,
+        u_rank=None,
+        zeta_init=1.0,
+        nu_init=-4.0,
+        lambda_init=0.0,
+        gamma_init=0.999,
+    ):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.gate_non_linearity = gate_non_linearity
+        self.update_non_linearity = update_non_linearity
+        self.w_rank = w_rank
+        self.u_rank = u_rank
+
+        # --- Weight definitions ---
+        if w_rank is None:
+            self.w_matrix = nn.Parameter(torch.randn(input_size, hidden_size) * 0.1)
+        else:
+            self.w_matrix_1 = nn.Parameter(torch.randn(input_size, w_rank) * 0.1)
+            self.w_matrix_2 = nn.Parameter(torch.randn(w_rank, hidden_size) * 0.1)
+
+        if u_rank is None:
+            self.u_matrix = nn.Parameter(torch.randn(hidden_size, hidden_size) * 0.1)
+        else:
+            self.u_matrix_1 = nn.Parameter(torch.randn(hidden_size, u_rank) * 0.1)
+            self.u_matrix_2 = nn.Parameter(torch.randn(u_rank, hidden_size) * 0.1)
+
+        # --- Biases ---
+        self.bias_gate = nn.Parameter(torch.ones(1, hidden_size))
+        self.bias_update = nn.Parameter(torch.ones(1, hidden_size))
+
+        # --- Scalars ---
+        self.zeta = nn.Parameter(torch.tensor([[zeta_init]], dtype=torch.float32))
+        self.nu = nn.Parameter(torch.tensor([[nu_init]], dtype=torch.float32))
+        self.lambd = nn.Parameter(torch.tensor([lambda_init], dtype=torch.float32))
+        self.gamma = nn.Parameter(torch.tensor([gamma_init], dtype=torch.float32))
+
+    def forward(self, x, h_prev):
+        # Compute W*x
+        if self.w_rank is None:
+            W = x @ self.w_matrix
+        else:
+            W = x @ self.w_matrix_1 @ self.w_matrix_2
+
+        # Compute U*h_prev
+        if self.u_rank is None:
+            U = h_prev @ self.u_matrix
+        else:
+            U = h_prev @ self.u_matrix_1 @ self.u_matrix_2
+
+        # Gates
+        z = gen_non_linearity(W + U + self.bias_gate, self.gate_non_linearity)
+        h_hat = gen_non_linearity(W + U + self.bias_update, self.update_non_linearity)
+
+        # FastGRNN update
+        h = z * h_prev + (torch.sigmoid(self.zeta) * (1 - z) + torch.sigmoid(self.nu)) * h_hat
+
+        # Comfi-FastGRNN update
+        gamma_clamped = torch.clamp(self.gamma, 0.0, 1.0)
+        h_t_comfi = gamma_clamped * h + (1 - gamma_clamped) * self.lambd
+
+        return h_t_comfi
+
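The cell computes one recurrent step; sequence iteration, layer stacking, and directions live in the `ComfiFastGRNN` layer class that follows. A minimal sketch of driving the cell directly (shapes follow from the parameter definitions above):

```python
import torch
from comfi_fast_grnn_torch.ComfiFastGRNN import ComfiFastGRNNCell

cell = ComfiFastGRNNCell(input_size=32, hidden_size=64, w_rank=8, u_rank=8)

x_t = torch.randn(10, 32)      # a single timestep for a batch of 10
h_prev = torch.zeros(10, 64)   # initial hidden state

h_t = cell(x_t, h_prev)        # one application of the update equations
print(h_t.shape)               # torch.Size([10, 64])
```

At rank 8, the hidden-to-hidden path stores 64*8 + 8*64 = 1024 weights instead of the full-rank 64*64 = 4096, which is the point of the low-rank formulation.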
+# -------------------------------
+# Comfi-FastGRNN layer
+# -------------------------------
+class ComfiFastGRNN(nn.Module):
+    def __init__(
+        self,
+        input_size: int,
+        hidden_size: int,
+        num_layers: int = 1,
+        batch_first: bool = True,
+        dropout: float = 0.0,
+        bidirectional: bool = False,
+        # ---- Cell arguments (explicitly exposed) ----
+        gate_non_linearity: str = "sigmoid",
+        update_non_linearity: str = "tanh",
+        w_rank: int | None = None,
+        u_rank: int | None = None,
+        zeta_init: float = 1.0,
+        nu_init: float = -4.0,
+        lambda_init: float = 0.0,
+        gamma_init: float = 0.999,
+    ):
+        super().__init__()
+
+        # ---- Layer config ----
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+        self.num_layers = num_layers
+        self.batch_first = batch_first
+        self.dropout = dropout
+        self.bidirectional = bidirectional
+        self.num_directions = 2 if bidirectional else 1
+
+        # ---- Cell config (stored for export / repr / checkpoint clarity) ----
+        self.gate_non_linearity = gate_non_linearity
+        self.update_non_linearity = update_non_linearity
+        self.w_rank = w_rank
+        self.u_rank = u_rank
+        self.zeta_init = zeta_init
+        self.nu_init = nu_init
+        self.lambda_init = lambda_init
+        self.gamma_init = gamma_init
+
+        # ---- Cells ----
+        self.cells_fwd = nn.ModuleList()
+        self.cells_bwd = nn.ModuleList() if bidirectional else None
+
+        for layer in range(num_layers):
+            in_size = input_size if layer == 0 else hidden_size * self.num_directions
+
+            self.cells_fwd.append(
+                ComfiFastGRNNCell(
+                    input_size=in_size,
+                    hidden_size=hidden_size,
+                    gate_non_linearity=gate_non_linearity,
+                    update_non_linearity=update_non_linearity,
+                    w_rank=w_rank,
+                    u_rank=u_rank,
+                    zeta_init=zeta_init,
+                    nu_init=nu_init,
+                    lambda_init=lambda_init,
+                    gamma_init=gamma_init,
+                )
+            )
+
+            if bidirectional:
+                self.cells_bwd.append(
+                    ComfiFastGRNNCell(
+                        input_size=in_size,
+                        hidden_size=hidden_size,
+                        gate_non_linearity=gate_non_linearity,
+                        update_non_linearity=update_non_linearity,
+                        w_rank=w_rank,
+                        u_rank=u_rank,
+                        zeta_init=zeta_init,
+                        nu_init=nu_init,
+                        lambda_init=lambda_init,
+                        gamma_init=gamma_init,
+                    )
+                )
+
+    def forward(self, x, h0=None):
+        """
+        x: (batch, seq_len, input_size) if batch_first=True, else (seq_len, batch, input_size)
+        h0: (num_layers * num_directions, batch, hidden_size)
+
+        Returns:
+            output: (batch, seq_len, hidden_size * num_directions), transposed back if batch_first=False
+            h_n: (num_layers * num_directions, batch, hidden_size)
+        """
+        if not self.batch_first:
+            x = x.transpose(0, 1)
+
+        batch_size, seq_len, _ = x.size()
+
+        if h0 is None:
+            h0 = x.new_zeros(
+                self.num_layers * self.num_directions,
+                batch_size,
+                self.hidden_size,
+            )
+
+        layer_input = x
+        h_n = []
+
+        for layer in range(self.num_layers):
+            fw_cell = self.cells_fwd[layer]
+            h_fw = h0[layer * self.num_directions + 0]
+            fw_outs = []
+
+            # ---- forward ----
+            for t in range(seq_len):
+                h_fw = fw_cell(layer_input[:, t, :], h_fw)
+                fw_outs.append(h_fw.unsqueeze(1))
+
+            fw_out = torch.cat(fw_outs, dim=1)
+
+            if self.bidirectional:
+                bw_cell = self.cells_bwd[layer]
+                h_bw = h0[layer * self.num_directions + 1]
+                bw_outs = []
+
+                # ---- backward ----
+                for t in reversed(range(seq_len)):
+                    h_bw = bw_cell(layer_input[:, t, :], h_bw)
+                    bw_outs.append(h_bw.unsqueeze(1))
+
+                bw_outs.reverse()
+                bw_out = torch.cat(bw_outs, dim=1)
+
+                layer_out = torch.cat([fw_out, bw_out], dim=2)
+                h_n.extend([h_fw, h_bw])
+            else:
+                layer_out = fw_out
+                h_n.append(h_fw)
+
+            if self.dropout > 0.0 and layer < self.num_layers - 1:
+                layer_out = F.dropout(layer_out, p=self.dropout, training=self.training)
+
+            layer_input = layer_out
+
+        output = layer_input
+        h_n = torch.stack(h_n, dim=0)
+
+        if not self.batch_first:
+            output = output.transpose(0, 1)
+
+        return output, h_n
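
Because `forward` both accepts and returns the recurrent state, the layer can be run block-wise, carrying the state across chunks. A minimal streaming sketch (assuming the unidirectional default; a bidirectional model needs the whole sequence):

```python
import torch
from comfi_fast_grnn_torch.ComfiFastGRNN import ComfiFastGRNN

model = ComfiFastGRNN(input_size=32, hidden_size=64, num_layers=2)
model.eval()

x = torch.randn(1, 200, 32)  # one long sequence
h = None                     # forward() zero-initialises the state when h0 is None

with torch.no_grad():
    outputs = []
    for chunk in x.split(50, dim=1):  # process 50 frames at a time
        out, h = model(chunk, h)      # carry the hidden state across chunks
        outputs.append(out)

streamed = torch.cat(outputs, dim=1)
full, _ = model(x)                    # one-shot reference over the whole sequence
print(torch.allclose(streamed, full))  # True: chunked and one-shot runs match
```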

comfi_fast_grnn_torch-0.0.1/src/comfi_fast_grnn_torch/__init__.py

File without changes

comfi_fast_grnn_torch-0.0.1/src/comfi_fast_grnn_torch.egg-info/PKG-INFO

@@ -0,0 +1,46 @@
+Metadata-Version: 2.4
+Name: comfi_fast_grnn_torch
+Version: 0.0.1
+Summary: A PyTorch implementation of Fast ULCNet
+Author-email: Nicolas Arrieta Larraza <NIAL@bang-olufsen.dk>, Niels de Koeijer <NEMK@bang-olufsen.dk>
+License: MIT
+Project-URL: Homepage, https://github.com/narrietal/Fast-ULCNet
+Requires-Python: >=3.10
+Description-Content-Type: text/markdown
+Requires-Dist: torch>=2.5.1
+Requires-Dist: libsegmenter==1.0.4
+Requires-Dist: torchinfo==1.8.0
+Requires-Dist: CRM_pytorch==0.1.0
+
+# fast-ulcnet-torch
+Implements FastULCNet and Comfi-FastGRNN in torch.
+
+## Usage
+
+The `ComfiFastGRNN` module is designed to be a drop-in replacement for standard PyTorch RNN layers (like `nn.LSTM` or `nn.GRU`), but with added support for low-rank factorization and complementary filtering.
+
+### Basic Implementation
+Here is how to use the layer with default settings:
+
+```python
+import torch
+from comfi_fast_grnn_torch.ComfiFastGRNN import ComfiFastGRNN
+
+# 1. Initialize the layer
+# batch_first=True is the default for this implementation
+model = ComfiFastGRNN(
+    input_size=32,
+    hidden_size=64,
+    num_layers=1
+)
+
+# 2. Create dummy input: (Batch Size, Sequence Length, Input Size)
+x = torch.randn(10, 50, 32)
+
+# 3. Forward pass
+# Returns output (all timesteps) and final hidden state
+output, h_n = model(x)
+
+print(f"Output shape: {output.shape}")       # torch.Size([10, 50, 64])
+print(f"Hidden state shape: {h_n.shape}")    # torch.Size([1, 10, 64])
+```

comfi_fast_grnn_torch-0.0.1/src/comfi_fast_grnn_torch.egg-info/SOURCES.txt

@@ -0,0 +1,9 @@
+README.md
+pyproject.toml
+src/comfi_fast_grnn_torch/ComfiFastGRNN.py
+src/comfi_fast_grnn_torch/__init__.py
+src/comfi_fast_grnn_torch.egg-info/PKG-INFO
+src/comfi_fast_grnn_torch.egg-info/SOURCES.txt
+src/comfi_fast_grnn_torch.egg-info/dependency_links.txt
+src/comfi_fast_grnn_torch.egg-info/requires.txt
+src/comfi_fast_grnn_torch.egg-info/top_level.txt

comfi_fast_grnn_torch-0.0.1/src/comfi_fast_grnn_torch.egg-info/dependency_links.txt

@@ -0,0 +1 @@
+

comfi_fast_grnn_torch-0.0.1/src/comfi_fast_grnn_torch.egg-info/top_level.txt

@@ -0,0 +1 @@
+comfi_fast_grnn_torch