ml4gw 0.2.0__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ml4gw might be problematic. Click here for more details.

@@ -0,0 +1,156 @@
1
+ from collections.abc import Callable, Sequence
2
+ from typing import Optional
3
+
4
+ import torch
5
+
6
+ from ml4gw.nn.autoencoder.base import Autoencoder
7
+ from ml4gw.nn.autoencoder.skip_connection import SkipConnection
8
+ from ml4gw.nn.autoencoder.utils import match_size
9
+
10
# Any callable, with an arbitrary argument list, that returns a torch Module
# (e.g. a norm-layer class or factory). `Callable[..., T]` is the standard
# spelling for "any arguments"; wrapping the ellipsis in a list is not.
Module = Callable[..., torch.nn.Module]
11
+
12
+
13
class ConvBlock(Autoencoder):
    """
    One convolutional encoder layer paired with one
    transposed-convolution decoder layer, each followed
    by a normalization layer and an activation.

    The base class is always initialized with
    `skip_connection=None`; the `skip_connection`
    argument given here is only consulted to work out
    how many input channels the decoder layer receives.

    Args:
        in_channels:
            Number of channels of the encoder input.
        encode_channels:
            Number of encoder output channels per group;
            the encoder emits `encode_channels * groups`
            channels in total.
        kernel_size:
            Kernel size shared by both convolutions.
            Padding is set to `(kernel_size - 1) // 2`.
        stride:
            Stride shared by both convolutions.
        groups:
            Number of groups shared by both convolutions.
        activation:
            Module applied after the encoder norm, and
            after the decoder norm too unless
            `output_activation` is given.
        norm:
            Callable mapping a channel count to a
            normalization module.
        decode_channels:
            Number of channels produced by the decoder.
            Falls back to `in_channels` when falsy.
        output_activation:
            Module applied after the decoder norm.
            Defaults to `activation` when `None`.
        skip_connection:
            If given, its `get_out_channels` determines
            the decoder's number of input channels.
    """

    def __init__(
        self,
        in_channels: int,
        encode_channels: int,
        kernel_size: int,
        stride: int = 1,
        groups: int = 1,
        activation: torch.nn.Module = torch.nn.ReLU(),
        norm: Module = torch.nn.BatchNorm1d,
        decode_channels: Optional[int] = None,
        output_activation: Optional[torch.nn.Module] = None,
        skip_connection: Optional[SkipConnection] = None,
    ) -> None:
        super().__init__(skip_connection=None)

        self.kernel_size = kernel_size
        self.padding = int((kernel_size - 1) // 2)
        self.stride = stride

        # settings common to the encoder and decoder convolutions
        conv_kwargs = dict(
            kernel_size=kernel_size,
            stride=stride,
            padding=self.padding,
            bias=False,
            groups=groups,
        )

        latent_channels = encode_channels * groups
        self.encode_layer = torch.nn.Conv1d(
            in_channels, latent_channels, **conv_kwargs
        )

        # by default the decoder restores the encoder's input channels
        decode_channels = decode_channels or in_channels

        # a skip connection may change how many channels
        # actually arrive at the decoder
        decoder_inputs = latent_channels
        if skip_connection is not None:
            decoder_inputs = skip_connection.get_out_channels(decoder_inputs)
        self.decode_layer = torch.nn.ConvTranspose1d(
            decoder_inputs, decode_channels, **conv_kwargs
        )

        self.activation = activation
        self.output_activation = (
            activation if output_activation is None else output_activation
        )

        self.encode_norm = norm(latent_channels)
        self.decode_norm = norm(decode_channels)

    def encode(self, X):
        """Apply convolution, normalization and activation, in that order."""
        normed = self.encode_norm(self.encode_layer(X))
        return self.activation(normed)

    def decode(self, X):
        """
        Apply transposed convolution, normalization
        and the output activation, in that order.
        """
        normed = self.decode_norm(self.decode_layer(X))
        return self.output_activation(normed)
76
+
77
+
78
class ConvolutionalAutoencoder(Autoencoder):
    """
    A stack of `ConvBlock` encoder/decoder pairs, one
    per entry of `encode_channels`.

    Every decoder layer decodes back to the number of
    channels that was fed to its matching encoder layer.
    The exception is the last decoder, whose channel
    count can be chosen freely via `decode_channels`.

    Likewise, every layer uses `activation`, except for
    the last decoder layer, which uses
    `output_activation` (defaulting to `activation`).

    The output of `forward` is passed through
    `match_size` with the last-dimension size of the
    input as the target.
    """

    def __init__(
        self,
        in_channels: int,
        encode_channels: Sequence[int],
        kernel_size: int,
        stride: int = 1,
        groups: int = 1,
        activation: torch.nn.Module = torch.nn.ReLU(),
        output_activation: Optional[torch.nn.Module] = None,
        norm: Module = torch.nn.BatchNorm1d,
        decode_channels: Optional[int] = None,
        skip_connection: Optional[SkipConnection] = None,
    ) -> None:
        # TODO: work out how to build the skip connection
        # dynamically. One idea is for the base architecture
        # to find the arguments it shares with the skip
        # connection class (e.g. `groups`) and supply them,
        # so that callers could pass the class instead of
        # a ready-made instance.
        super().__init__(skip_connection=skip_connection)

        output_activation = output_activation or activation
        for i, channels in enumerate(encode_channels):
            if i == 0:
                # this block holds the last decoder: it decodes
                # to the requested number of channels (if that
                # is `None`, ConvBlock falls back to in_channels)
                # and is the only one that may use a different
                # activation
                block_decode = decode_channels
                block_out_act = output_activation
            else:
                # intermediate blocks decode back to their own
                # input channels and reuse the shared activation
                block_decode = in_channels
                block_out_act = None

            # TODO: it was once considered to withhold the
            # skip connection from the innermost block so that
            # it doesn't skip to itself. It is unclear that
            # this makes sense, so every block gets it.
            block = ConvBlock(
                in_channels=in_channels,
                encode_channels=channels,
                kernel_size=kernel_size,
                stride=stride,
                groups=groups,
                activation=activation,
                norm=norm,
                decode_channels=block_decode,
                output_activation=block_out_act,
                skip_connection=skip_connection,
            )
            self.blocks.append(block)

            # the next block consumes this block's encoder output
            in_channels = channels * groups

    def decode(self, *X, states=None, input_size: Optional[int] = None):
        """
        Run the base class decoder. If `input_size` is given,
        pass the result through `match_size` with that target.
        """
        decoded = super().decode(*X, states=states)
        if input_size is None:
            return decoded
        return match_size(decoded, input_size)

    def forward(self, X):
        """
        Run the base class forward pass, then pass the result
        through `match_size` with the input's last-dimension
        size as the target.
        """
        # record the size up front, before X is handed on
        target_size = X.size(-1)
        reconstructed = super().forward(X)
        return match_size(reconstructed, target_size)
@@ -0,0 +1,46 @@
1
+ import torch
2
+
3
+ from ml4gw.nn.autoencoder.utils import match_size
4
+
5
+
6
class SkipConnection(torch.nn.Module):
    """
    Base skip connection. Its forward pass only passes
    `X` through `match_size`, using the last-dimension
    size of `state` as the target, and it reports an
    unchanged channel count.
    """

    def forward(self, X: torch.Tensor, state: torch.Tensor):
        """Return `match_size(X, ...)` targeting `state`'s last-dim size."""
        target_size = state.size(-1)
        return match_size(X, target_size)

    def get_out_channels(self, in_channels):
        """Number of channels this connection emits for `in_channels`."""
        return in_channels
12
+
13
+
14
class AddSkipConnect(SkipConnection):
    """
    Skip connection that adds `state` to `X` after
    the base class forward pass has been applied to `X`.
    """

    def forward(self, X, state):
        """Return the sum of the base-class result for `X` and `state`."""
        resized = super().forward(X, state)
        return resized + state
18
+
19
+
20
class ConcatSkipConnect(SkipConnection):
    """
    Skip connection that concatenates `X` and `state`
    along the channel dimension (dim 1), doubling the
    number of channels.

    With `groups > 1` the channels of `X` and `state`
    are interleaved group by group. Each contiguous
    block of `2 * num_channels / groups` output channels
    then holds one group of `X` followed by the
    matching group of `state`.

    Args:
        groups:
            Number of contiguous channel groups to
            interleave. With 1, the tensors are simply
            concatenated.
    """

    def __init__(self, groups: int = 1):
        super().__init__()
        self.groups = groups

    def get_out_channels(self, in_channels):
        """Concatenation doubles the channel count."""
        return 2 * in_channels

    def forward(self, X, state):
        """
        Concatenate `X` (after the base class forward
        pass) with `state` along dim 1.

        Raises:
            ValueError:
                If `groups > 1` and the number of channels
                of `X` is not a multiple of `groups`.
        """
        X = super().forward(X, state)
        if self.groups == 1:
            return torch.cat([X, state], dim=1)

        num_channels = X.size(1)
        rem = num_channels % self.groups
        if rem:
            raise ValueError(
                "Number of channels in input tensor {} cannot "
                "be divided evenly into {} groups".format(
                    num_channels, self.groups
                )
            )

        # An integer passed to torch.split is the size of each
        # chunk, not the number of chunks. To get `groups`
        # chunks we split by the per-group channel count.
        group_size = num_channels // self.groups
        X = torch.split(X, group_size, dim=1)
        state = torch.split(state, group_size, dim=1)

        # interleave: X group 0, state group 0, X group 1, ...
        frags = [frag for pair in zip(X, state) for frag in pair]
        return torch.cat(frags, dim=1)
@@ -0,0 +1,14 @@
1
+ import torch
2
+
3
+
4
def match_size(X: torch.Tensor, target_size: int):
    """
    Zero-pad or crop the last dimension of `X` so
    that it has length `target_size`.

    The adjustment is split between both ends. When the
    difference is odd, padding puts the extra sample on
    the right, while cropping removes the extra sample
    from the left.

    Args:
        X:
            Tensor whose last dimension should be resized.
        target_size:
            Desired length of the last dimension.
    Returns:
        `X` itself if it already has the right length,
        otherwise a padded copy or a cropped slice of it.
    """
    diff = target_size - X.size(-1)
    if diff == 0:
        return X

    left = int(diff // 2)
    right = diff - left
    if diff > 0:
        return torch.nn.functional.pad(X, (left, right))

    # Cropping: `left` and `right` are both non-positive here.
    # `-left` is how many samples to drop from the front.
    # `right` is already a valid negative stop index and must
    # not be negated; 0 has to become None, otherwise the
    # slice would be empty.
    start = -left
    stop = right or None
    return X[..., start:stop]
ml4gw/nn/norm.py ADDED
@@ -0,0 +1,97 @@
1
+ from typing import Callable, Optional
2
+
3
+ import torch
4
+
5
+ NormLayer = Callable[[int], torch.nn.Module]
6
+
7
+
8
class GroupNorm1D(torch.nn.Module):
    """
    Custom implementation of GroupNorm which is faster than the
    out-of-the-box PyTorch version at inference time.

    Args:
        num_channels:
            Number of channels of the input.
        num_groups:
            Number of groups to normalize over. If falsy,
            each channel forms its own group.
        eps:
            Constant added to the variance before taking
            the square root.
    Raises:
        ValueError:
            If `num_channels` is not a multiple of `num_groups`.
    """

    def __init__(
        self,
        num_channels: int,
        num_groups: Optional[int] = None,
        eps: float = 1e-5,
    ):
        super().__init__()
        if not num_groups:
            num_groups = num_channels
        if num_channels % num_groups != 0:
            raise ValueError("num_groups must be a factor of num_channels")

        self.num_channels = num_channels
        self.num_groups = num_groups
        self.channels_per_group = num_channels // num_groups
        self.eps = eps

        # one scale and one offset per channel, with a trailing
        # axis of size 1 so they broadcast over the last dimension
        self.weight = torch.nn.Parameter(torch.ones(num_channels, 1))
        self.bias = torch.nn.Parameter(torch.zeros(num_channels, 1))

    def forward(self, x):
        # the variance is obtained from the first two
        # moments: Var[x] = E[x**2] - E[x]**2
        if self.num_groups == self.num_channels:
            # every channel is its own group, so the
            # per-channel moments are all that's needed
            mean = x.mean(-1, keepdims=True)
            sq_mean = (x**2).mean(-1, keepdims=True)
        else:
            # stack the per-channel moments, average them
            # within each group, then broadcast the group
            # values back out to the full channel dimension
            moments = torch.stack([x.mean(-1), (x**2).mean(-1)], dim=1)
            moments = moments.reshape(
                -1, 2, self.num_groups, self.channels_per_group
            )
            moments = moments.mean(-1, keepdims=True)
            moments = moments.expand(-1, -1, -1, self.channels_per_group)
            moments = moments.reshape(-1, 2, self.num_channels, 1)
            mean = moments[:, 0]
            sq_mean = moments[:, 1]

        # fold the statistics into the affine parameters so
        # that only one multiply and one add touch the full
        # last axis
        std = (sq_mean - mean**2 + self.eps) ** 0.5
        scale = self.weight / std
        shift = self.bias - scale * mean
        return shift + x * scale
61
+
62
+
63
class GroupNorm1DGetter:
    """
    Factory that can be used as a NormLayer: calling it
    with a number of channels returns a `GroupNorm1D`
    for that many channels. Useful for command-line
    parameterization with jsonargparse.

    Args:
        groups:
            Upper bound on the number of groups. If
            `None`, `GroupNorm1D` is left to pick its
            own default.
    """

    def __init__(self, groups: Optional[int] = None) -> None:
        self.groups = groups

    def __call__(self, num_channels: int) -> torch.nn.Module:
        # never request more groups than there are channels
        num_groups = (
            None if self.groups is None else min(num_channels, self.groups)
        )
        return GroupNorm1D(num_channels, num_groups)
79
+
80
+
81
+ # TODO: generalize the faster GroupNorm1D implementation to 2D
82
class GroupNorm2DGetter:
    """
    Factory that can be used as a NormLayer: calling it
    with a number of channels returns a
    `torch.nn.GroupNorm` for that many channels. Useful
    for command-line parameterization with jsonargparse.

    Args:
        groups:
            Upper bound on the number of groups. If
            `None`, one group per channel is used.
    """

    def __init__(self, groups: Optional[int] = None) -> None:
        self.groups = groups

    def __call__(self, num_channels: int) -> torch.nn.Module:
        # never request more groups than there are channels
        num_groups = (
            num_channels
            if self.groups is None
            else min(num_channels, self.groups)
        )
        return torch.nn.GroupNorm(num_groups, num_channels)
@@ -0,0 +1,2 @@
1
+ from .resnet_1d import ResNet1D
2
+ from .resnet_2d import ResNet2D