pocket_tts-1.0.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1 @@
+ """Modules used for building the models."""
@@ -0,0 +1,161 @@
+ import math
+ import warnings
+
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+
+ from pocket_tts.modules.stateful_module import StatefulModule
+
+
+ def get_extra_padding_for_conv1d(
+     x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
+ ) -> int:
+     """See `pad_for_conv1d`."""
+     length = x.shape[-1]
+     n_frames = (length - kernel_size + padding_total) / stride + 1
+     ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
+     return ideal_length - length
+
+
+ def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
+     """Pad for a convolution to make sure that the last window is full.
+     Extra padding is added at the end. This is required to ensure that we can rebuild
+     an output of the same length, as otherwise, even with padding, some time steps
+     might get removed.
+     For instance, with total padding = 4, kernel size = 4, stride = 2:
+         0 0 1 2 3 4 5 0 0   # (0s are padding)
+         1   2   3           # (output frames of a convolution, last 0 is never used)
+         0 0 1 2 3 4 5 0     # (output of tr. conv., but pos. 5 is going to get removed as padding)
+         1 2 3 4             # once padding is removed, we are missing one time step!
+     """
+     extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
+     return F.pad(x, (0, extra_padding))
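
To make the padding arithmetic concrete, here is a small check of the formula above (a sketch assuming `get_extra_padding_for_conv1d` from this file is in scope; only the input's length matters):

```python
import torch

# length=5, kernel_size=4, stride=2, padding_total=4:
#   n_frames     = (5 - 4 + 4) / 2 + 1 = 3.5 -> ceil -> 4 full frames
#   ideal_length = (4 - 1) * 2 + (4 - 4) = 6  -> 1 extra sample of padding
x = torch.zeros(1, 1, 5)
assert get_extra_padding_for_conv1d(x, kernel_size=4, stride=2, padding_total=4) == 1
```

This matches the docstring's example: with the extra sample, the last window is full and the transposed convolution can rebuild every time step.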
+
+
+ class StreamingConv1d(StatefulModule):
+     """Conv1d with some built-in handling of asymmetric or causal padding
+     and normalization.
+     """
+
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         kernel_size: int,
+         stride: int = 1,
+         dilation: int = 1,
+         groups: int = 1,
+         bias: bool = True,
+         pad_mode: str = "constant",
+     ):
+         super().__init__()
+         assert pad_mode in ["constant", "replicate"], pad_mode
+         self.pad_mode = pad_mode
+         # Warn the user on an unusual combination of dilation and stride.
+         if stride > 1 and dilation > 1:
+             warnings.warn(
+                 "StreamingConv1d has been initialized with stride > 1 and dilation > 1"
+                 f" (kernel_size={kernel_size}, stride={stride}, dilation={dilation})."
+             )
+         self.conv = nn.Conv1d(
+             in_channels,
+             out_channels,
+             kernel_size,
+             stride,
+             dilation=dilation,
+             groups=groups,
+             bias=bias,
+         )
+
+     @property
+     def _stride(self) -> int:
+         return self.conv.stride[0]
+
+     @property
+     def _kernel_size(self) -> int:
+         return self.conv.kernel_size[0]
+
+     @property
+     def _effective_kernel_size(self) -> int:
+         dilation = self.conv.dilation[0]
+         return (self._kernel_size - 1) * dilation + 1  # effective kernel size with dilations
+
+     def init_state(self, batch_size: int, sequence_length: int) -> dict[str, torch.Tensor]:
+         stride = self._stride
+         # Effective kernel size accounting for dilation.
+         kernel = self._effective_kernel_size
+         # Carry the last `kernel - stride` input samples between chunks; zeros double
+         # as the initial left padding. `first` tracks whether a batch entry has seen
+         # any input yet (needed for replicate padding).
+         previous = torch.zeros(batch_size, self.conv.in_channels, kernel - stride)
+         first = torch.ones(batch_size, dtype=torch.bool)
+         return dict(previous=previous, first=first)
+
+     def forward(self, x, model_state: dict | None):
+         B, C, T = x.shape
+         S = self._stride
+         assert T > 0 and T % S == 0, "Steps must be a multiple of the stride"
+         if model_state is None:
+             state = self.init_state(B, 0)
+         else:
+             state = self.get_state(model_state)
+         TP = state["previous"].shape[-1]
+         if TP and self.pad_mode == "replicate":
+             assert T >= TP, "Not enough content to pad streaming."
+             # On the very first chunk, replicate the first sample instead of zeros.
+             init = x[..., :1]
+             state["previous"][:] = torch.where(
+                 state["first"].view(-1, 1, 1), init, state["previous"]
+             )
+         if TP:
+             x = torch.cat([state["previous"], x], dim=-1)
+         y = self.conv(x)
+         if TP:
+             state["previous"][:] = x[..., -TP:]
+             if self.pad_mode == "replicate":
+                 state["first"] = torch.zeros_like(state["first"])
+         return y
+
+
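The state handling above implements a standard streaming trick: keep the last `kernel - stride` input samples and prepend them to the next chunk, so chunked outputs match one offline pass. A minimal self-contained sketch with plain PyTorch (independent of `StatefulModule`; the hyperparameters are arbitrary):

```python
import torch
from torch import nn
from torch.nn import functional as F

torch.manual_seed(0)
conv = nn.Conv1d(1, 1, kernel_size=4, stride=2)
x = torch.randn(1, 1, 16)

offline = conv(F.pad(x, (2, 0)))  # causal left padding of K - S = 2

previous = torch.zeros(1, 1, 2)   # carried samples; zeros double as initial padding
chunks = []
for chunk in x.split(4, dim=-1):  # chunk length must be a multiple of the stride
    buf = torch.cat([previous, chunk], dim=-1)
    chunks.append(conv(buf))
    previous = buf[..., -2:]
streamed = torch.cat(chunks, dim=-1)

assert torch.allclose(offline, streamed, atol=1e-6)
```
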
+ class StreamingConvTranspose1d(StatefulModule):
+     """ConvTranspose1d with some built-in handling of asymmetric or causal padding
+     and normalization.
+     """
+
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         kernel_size: int,
+         stride: int = 1,
+         groups: int = 1,
+         bias: bool = True,
+     ):
+         super().__init__()
+         self.convtr = nn.ConvTranspose1d(
+             in_channels, out_channels, kernel_size, stride, groups=groups, bias=bias
+         )
+
+     @property
+     def _stride(self) -> int:
+         return self.convtr.stride[0]
+
+     @property
+     def _kernel_size(self) -> int:
+         return self.convtr.kernel_size[0]
+
+     def init_state(self, batch_size: int, sequence_length: int) -> dict[str, torch.Tensor]:
+         K = self._kernel_size
+         S = self._stride
+         # The last `K - S` output samples of a chunk still overlap with the next
+         # chunk's output, so they are held back as a partial sum.
+         return dict(partial=torch.zeros(batch_size, self.convtr.out_channels, K - S))
+
+     def forward(self, x, mimi_state: dict):
+         layer_state = self.get_state(mimi_state)["partial"]
+         y = self.convtr(x)
+         PT = layer_state.shape[-1]
+         if PT > 0:
+             # Complete the overlapping head with the partial sum from the last chunk.
+             y[..., :PT] += layer_state
+             bias = self.convtr.bias
+             for_partial = y[..., -PT:]
+             if bias is not None:
+                 # The bias would otherwise be added twice to the overlapping samples.
+                 for_partial -= bias[:, None]
+             layer_state[:] = for_partial
+             y = y[..., :-PT]
+         return y
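
On the transposed side, the last `K - S` outputs of each chunk are only partial sums: those positions still receive contributions from the next chunk, which is exactly what the `partial` state stores (and why the bias must not be added twice). A sketch of the same overlap-add logic with plain PyTorch, bias disabled for simplicity:

```python
import torch
from torch import nn

torch.manual_seed(0)
tconv = nn.ConvTranspose1d(1, 1, kernel_size=4, stride=2, bias=False)
x = torch.randn(1, 1, 8)

offline = tconv(x)  # length (8 - 1) * 2 + 4 = 18

partial = torch.zeros(1, 1, 2)  # K - S = 2 overlapping samples carried over
out = []
for chunk in x.split(4, dim=-1):
    y = tconv(chunk)              # length (4 - 1) * 2 + 4 = 10
    y[..., :2] += partial         # complete the tail of the previous chunk
    partial = y[..., -2:].clone()
    out.append(y[..., :-2])       # emit only fully-summed samples
streamed = torch.cat(out, dim=-1)

assert torch.allclose(offline[..., : streamed.shape[-1]], streamed, atol=1e-6)
```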
@@ -0,0 +1,18 @@
+ import torch
+ from torch import nn
+
+
+ class DummyQuantizer(nn.Module):
+     """Simplified quantizer that only provides the output projection for TTS.
+
+     This removes all unnecessary quantization logic since we don't use actual quantization.
+     """
+
+     def __init__(self, dimension: int, output_dimension: int):
+         super().__init__()
+         self.dimension = dimension
+         self.output_dimension = output_dimension
+         self.output_proj = torch.nn.Conv1d(self.dimension, self.output_dimension, 1, bias=False)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.output_proj(x)
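
The 1x1 convolution here is just a per-timestep linear projection, which is all that survives of the quantizer interface. A quick equivalence check (arbitrary sizes):

```python
import torch
from torch import nn

proj = nn.Conv1d(8, 16, 1, bias=False)
x = torch.randn(2, 8, 50)
y_conv = proj(x)                                              # [2, 16, 50]
y_lin = torch.einsum("oc,bct->bot", proj.weight[:, :, 0], x)  # same projection
assert torch.allclose(y_conv, y_lin, atol=1e-6)
```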
@@ -0,0 +1,11 @@
+ import torch
+ import torch.nn as nn
+
+
+ class LayerScale(nn.Module):
+     def __init__(self, channels: int, init: float):
+         super().__init__()
+         self.scale = nn.Parameter(torch.full((channels,), init))
+
+     def forward(self, x: torch.Tensor):
+         return self.scale * x
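
`LayerScale` applies a learned per-channel gain over the last dimension, typically initialized to a small constant so residual branches start close to the identity. Using the class as defined above:

```python
import torch

ls = LayerScale(channels=4, init=1e-4)
x = torch.randn(2, 10, 4)               # [batch, time, channels]
assert torch.allclose(ls(x), x * 1e-4)  # broadcast over the channel dimension
```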
@@ -0,0 +1,285 @@
+ from typing import NamedTuple
+
+ import torch
+ import torch.nn as nn
+ from einops import rearrange
+ from torch.nn import functional as F
+ from typing_extensions import Self
+
+ from pocket_tts.modules.layer_scale import LayerScale
+ from pocket_tts.modules.rope import RotaryEmbedding
+ from pocket_tts.modules.stateful_module import StatefulModule
+ from pocket_tts.modules.transformer import StreamingMultiheadAttention
+ from pocket_tts.utils.config import FlowLMTransformerConfig
+
+
+ class KVCacheResult(NamedTuple):
+     keys: torch.Tensor
+     values: torch.Tensor
+     positions: torch.Tensor
+
+     @staticmethod
+     def from_kv(keys: torch.Tensor, values: torch.Tensor) -> "KVCacheResult":
+         B, H, T, D = keys.shape
+         assert tuple(values.shape[:-1]) == (B, H, T)
+         positions = torch.arange(T, device=keys.device, dtype=torch.long)
+         return KVCacheResult(keys, values, positions.expand(B, -1))
+
+
+ def complete(
+     cache: torch.Tensor, end_offset: torch.Tensor, k: torch.Tensor, v: torch.Tensor
+ ) -> KVCacheResult:
+     capacity = cache.shape[3]
+     assert k.shape[:-1] == v.shape[:-1], (k.shape, v.shape)
+     B, H, T, D = k.shape
+     assert T > 0
+     # Write the T new steps into the ring buffer, wrapping around at `capacity`.
+     indexes = torch.arange(T, device=end_offset.device, dtype=end_offset.dtype)
+     indexes = indexes + end_offset.view(-1, 1)
+     indexes = indexes % capacity
+     # indexes is [B, T]
+     # k is [B, H, T, D]
+     # cache is [B, H, T', D]
+     this_indexes = indexes.view(B, 1, T, 1)
+     this_indexes = this_indexes.expand(-1, H, T, D)
+     cache[0].scatter_(2, this_indexes, k)
+     cache[1].scatter_(2, this_indexes, v)
+
+     keys = cache[0]
+     values = cache[1]
+
+     indexes = torch.arange(capacity, device=end_offset.device, dtype=torch.long)
+
+     # end_index corresponds to the actual index where the last value was written.
+     last_offset = end_offset.view(-1, 1) + T - 1
+     end_index = last_offset % capacity
+     delta = indexes - end_index
+
+     # Recover the absolute position stored in each slot; slots that were never
+     # written get position -1 so they can be masked out.
+     positions = torch.where(delta <= 0, last_offset + delta, last_offset + delta - capacity)
+     end_offset[:] = end_offset + T
+     invalid = indexes >= end_offset.view(-1, 1)
+     positions = torch.where(invalid, torch.full_like(positions, -1), positions)
+
+     return KVCacheResult(keys, values, positions)
+
+
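A small trace of the ring buffer (a sketch; in the module below, the capacity comes from `init_state`'s `sequence_length`). Writing three steps twice into a buffer of capacity 4 wraps around, and slots that were never written report position -1 so the attention mask can exclude them:

```python
import torch

B, H, D, capacity = 1, 1, 4, 4
cache = torch.zeros(2, B, H, capacity, D)
end_offset = torch.zeros(B, dtype=torch.long)

k = torch.randn(B, H, 3, D)
v = torch.randn(B, H, 3, D)
r1 = complete(cache, end_offset, k, v)
print(r1.positions)  # tensor([[ 0,  1,  2, -1]])

k2 = torch.randn(B, H, 3, D)
v2 = torch.randn(B, H, 3, D)
r2 = complete(cache, end_offset, k2, v2)
print(r2.positions)  # tensor([[4, 5, 2, 3]]) -- steps 0 and 1 were overwritten
```
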
+ class MimiStreamingMultiheadAttention(StatefulModule):
+     def __init__(self, embed_dim: int, num_heads: int, context: int, rope: RotaryEmbedding):
+         super().__init__()
+
+         self.embed_dim = embed_dim
+         self.context = context
+         self.rope = rope
+         self.num_heads = num_heads
+         out_dim = 3 * embed_dim
+
+         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
+         self.in_proj = nn.Linear(embed_dim, out_dim, bias=False)
+
+     def init_state(self, batch_size: int, sequence_length: int) -> dict[str, torch.Tensor]:
+         dim_per_head = self.embed_dim // self.num_heads
+
+         state = {}
+         state["offset"] = torch.zeros(batch_size, dtype=torch.long)
+         state["cache"] = torch.zeros((2, batch_size, self.num_heads, sequence_length, dim_per_head))
+         state["end_offset"] = torch.zeros(batch_size, dtype=torch.long)
+         return state
+
+     def increment_step(self, state, increment: int = 1):
+         state["offset"] += increment
+
+     def _complete_kv(self, k, v, model_state: dict | None) -> KVCacheResult:
+         if model_state is None:
+             # Stateless path: no cache, attend over the current chunk only.
+             return KVCacheResult.from_kv(k, v)
+         else:
+             layer_state = self.get_state(model_state)
+             return complete(layer_state["cache"], layer_state["end_offset"], k, v)
+
+     def forward(self, query: torch.Tensor, model_state: dict | None) -> torch.Tensor:
+         B, T = query.shape[:2]
+
+         if model_state is None:
+             offset = torch.zeros(B, device=query.device, dtype=torch.long)
+         else:
+             offset = self.get_state(model_state)["offset"]
+
+         projected = self.in_proj(query)
+
+         q, k, v = rearrange(projected, "b t (p h d) -> p b h t d", p=3, h=self.num_heads)
+
+         # Permute from [b, h, t, d] to [b, t, h, d] for rope
+         q = q.permute(0, 2, 1, 3)
+         k = k.permute(0, 2, 1, 3)
+         q, k = self.rope(q, k, offset)
+         # Permute back from [b, t, h, d] to [b, h, t, d]
+         q = q.permute(0, 2, 1, 3)
+         k = k.permute(0, 2, 1, 3)
+
+         k, v, pos_k = self._complete_kv(k, v, model_state)
+         pos_k = pos_k[:, None]
+         pos_q = offset.view(-1, 1, 1) + torch.arange(T, device=q.device, dtype=torch.long).view(
+             -1, 1
+         )
+         # Boolean attention mask: causal (delta >= 0), within the context window
+         # (delta < context), and restricted to valid cache slots (pos_k >= 0).
+         delta = pos_q - pos_k
+         attn_bias = (pos_k >= 0) & (delta >= 0)
+         attn_bias = attn_bias & (delta < self.context)
+         attn_bias = attn_bias[:, None]
+
+         x = F.scaled_dot_product_attention(q, k, v, attn_bias, dropout_p=0.0)
+
+         x = rearrange(x, "b h t d -> b t (h d)")
+         x = self.out_proj(x)
+         return x
+
+
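The boolean `attn_bias` enforces both causality and a sliding window: a query at absolute position p may attend to a cached key at position k only when `0 <= p - k < context` (and the slot is valid, `pos_k >= 0`). For a fresh sequence with `context=3`:

```python
import torch

T, context = 5, 3
pos = torch.arange(T)
delta = pos.view(-1, 1) - pos.view(1, -1)  # pos_q - pos_k
mask = (delta >= 0) & (delta < context)
print(mask.int())
# tensor([[1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0],
#         [1, 1, 1, 0, 0],
#         [0, 1, 1, 1, 0],
#         [0, 0, 1, 1, 1]])
```
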
+ class StreamingTransformerLayer(nn.Module):
+     def __init__(
+         self,
+         d_model: int,
+         num_heads: int,
+         dim_feedforward: int,
+         context: int | None,
+         rope: RotaryEmbedding,
+         layer_scale: float | None = None,
+         attention_kind: str = "mimi",
+     ):
+         super().__init__()
+         # Use our streaming multi-head attention as self_attn.
+         if attention_kind == "mimi":
+             # TODO: we should actually use StreamingMultiheadAttention here and add context
+             # window support. And we should then delete MimiStreamingMultiheadAttention.
+             # The implementation is really close.
+             self.self_attn = MimiStreamingMultiheadAttention(
+                 context=context, rope=rope, embed_dim=d_model, num_heads=num_heads
+             )
+         else:
+             self.self_attn = StreamingMultiheadAttention(
+                 rope=rope, embed_dim=d_model, num_heads=num_heads
+             )
+         self.norm1 = nn.LayerNorm(d_model, eps=1e-5)
+         self.norm2 = nn.LayerNorm(d_model, eps=1e-5)
+
+         self.linear1 = nn.Linear(d_model, dim_feedforward, bias=False)
+         self.linear2 = nn.Linear(dim_feedforward, d_model, bias=False)
+
+         if layer_scale is None:
+             self.layer_scale_1 = nn.Identity()
+             self.layer_scale_2 = nn.Identity()
+         else:
+             self.layer_scale_1 = LayerScale(d_model, layer_scale)
+             self.layer_scale_2 = LayerScale(d_model, layer_scale)
+
+     def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
+         x_orig = x
+         x = self.norm2(x)
+         update = self.linear2(F.gelu(self.linear1(x)))
+         return x_orig.to(update) + self.layer_scale_2(update)
+
+     def _sa_block(self, x: torch.Tensor, model_state: dict | None) -> torch.Tensor:
+         x_orig = x
+         x = self.norm1(x)
+         update = self.self_attn(x, model_state)
+         return x_orig.to(update) + self.layer_scale_1(update)
+
+     def forward(self, x: torch.Tensor, model_state: dict | None) -> torch.Tensor:
+         x = self._sa_block(x, model_state)
+         x = self._ff_block(x)
+         return x
+
+
+ class StreamingTransformer(nn.Module):
+     def __init__(
+         self,
+         d_model: int,
+         num_heads: int,
+         num_layers: int,
+         layer_scale: float | None = None,
+         dim_feedforward: int | list[int] = 2048,
+         context: int | None = None,
+         max_period: float = 10_000.0,
+         kind: str = "mimi",
+     ):
+         super().__init__()
+         assert d_model % num_heads == 0
+         self.max_period = max_period
+
+         self.rope = RotaryEmbedding(max_period=max_period)
+
+         self.layers = nn.ModuleList()
+         for _ in range(num_layers):
+             self.layers.append(
+                 StreamingTransformerLayer(
+                     d_model=d_model,
+                     num_heads=num_heads,
+                     dim_feedforward=dim_feedforward,
+                     context=context,
+                     rope=self.rope,
+                     layer_scale=layer_scale,
+                     attention_kind=kind,
+                 )
+             )
+
+     @classmethod
+     def from_pydantic_config(cls, config: FlowLMTransformerConfig) -> Self:
+         dim_feedforward = int(config.d_model * config.hidden_scale)
+         return cls(
+             d_model=config.d_model,
+             num_heads=config.num_heads,
+             num_layers=config.num_layers,
+             dim_feedforward=dim_feedforward,
+             max_period=float(config.max_period),
+             kind="flow_lm",
+         )
+
+     def forward(self, x: torch.Tensor, model_state: dict | None):
+         for layer in self.layers:
+             x = layer(x, model_state)
+         return x
+
+
+ class ProjectedTransformer(nn.Module):
+     def __init__(
+         self,
+         input_dimension: int,
+         output_dimensions: tuple[int, ...],
+         d_model: int,
+         num_heads: int,
+         num_layers: int,
+         layer_scale: float,
+         context: int,
+         max_period: float,
+         dim_feedforward: int,
+     ):
+         super().__init__()
+         self.transformer = StreamingTransformer(
+             d_model=d_model,
+             num_heads=num_heads,
+             num_layers=num_layers,
+             layer_scale=layer_scale,
+             context=context,
+             max_period=max_period,
+             dim_feedforward=dim_feedforward,
+         )
+         self.input_dimension = input_dimension
+         self.output_dimensions = output_dimensions
+         self.input_proj = None
+         if d_model != input_dimension:
+             self.input_proj = nn.Linear(input_dimension, d_model, bias=False)
+
+         self.output_projs = nn.ModuleList()
+         for output_dimension in output_dimensions:
+             if d_model == output_dimension:
+                 self.output_projs.append(nn.Identity())
+             else:
+                 self.output_projs.append(nn.Linear(d_model, output_dimension, bias=False))
+
+     def forward(self, x, model_state: dict | None):
+         x = x.transpose(1, 2)
+         if self.input_proj is not None:
+             x = self.input_proj(x)
+         z = self.transformer(x, model_state)
+         ys = []
+         for output_proj in self.output_projs:
+             y = output_proj(z)
+             y = y.transpose(1, 2)
+             ys.append(y)
+         return ys
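
`ProjectedTransformer` takes channels-first input `[B, C, T]` and returns one channels-first output per entry in `output_dimensions`. A shape-check sketch (the hyperparameters are arbitrary, and this assumes the package's `RotaryEmbedding` module imported above behaves as used in the attention code):

```python
import torch

model = ProjectedTransformer(
    input_dimension=512,
    output_dimensions=(512, 128),
    d_model=256,
    num_heads=4,
    num_layers=2,
    layer_scale=0.01,
    context=64,
    max_period=10_000.0,
    dim_feedforward=1024,
)
x = torch.randn(1, 512, 10)
y_a, y_b = model(x, model_state=None)
print(y_a.shape, y_b.shape)  # torch.Size([1, 512, 10]) torch.Size([1, 128, 10])
```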
@@ -0,0 +1,215 @@
+ """
+ Taken from
+ https://github.com/LTH14/mar/blob/fe470ac24afbee924668d8c5c83e9fec60af3a73/models/diffloss.py
+ """
+
+ import math
+
+ import torch
+ import torch.nn as nn
+ from typing_extensions import Self
+
+ from pocket_tts.utils.config import FlowLMConfig
+
+
+ def modulate(x, shift, scale):
+     return x * (1 + scale) + shift
+
+
+ def _rms_norm(x: torch.Tensor, alpha: torch.Tensor, eps: float):
+     assert x.dim() >= alpha.dim()
+     x_dtype = x.dtype
+     # Note: this uses the (mean-centered) variance rather than the plain mean of
+     # squares of classic RMSNorm.
+     var = eps + x.var(dim=-1, keepdim=True)
+     y = (x * (alpha.to(var) * torch.rsqrt(var))).to(x_dtype)
+     return y
+
+
+ class RMSNorm(nn.Module):
+     def __init__(self, dim: int, eps: float = 1e-5):
+         super().__init__()
+         self.eps = eps
+         alpha_shape = (dim,)
+         self.alpha = nn.Parameter(torch.full(alpha_shape, 1.0, requires_grad=True))
+
+     def forward(self, x: torch.Tensor):
+         return _rms_norm(x, self.alpha, self.eps)
+
+
+ class LayerNorm(nn.Module):
+     """Reimplementation of LayerNorm because the default one doesn't support jvp."""
+
+     def __init__(self, channels, eps=1e-6, elementwise_affine=True):
+         super().__init__()
+         self.eps = eps
+         if elementwise_affine:
+             self.weight = nn.Parameter(torch.ones(channels))
+             self.bias = nn.Parameter(torch.zeros(channels))
+
+     def forward(self, x):
+         mean = x.mean(dim=-1, keepdim=True)
+         var = x.var(dim=-1, unbiased=False, keepdim=True)
+         x = (x - mean) / torch.sqrt(var + self.eps)
+         if hasattr(self, "weight"):
+             x = x * self.weight + self.bias
+         return x
+
+
+ class TimestepEmbedder(nn.Module):
+     """Embeds scalar timesteps into vector representations."""
+
+     def __init__(
+         self, hidden_size: int, frequency_embedding_size: int = 256, max_period: int = 10000
+     ):
+         super().__init__()
+         blocks = [
+             nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+             nn.SiLU(),
+             nn.Linear(hidden_size, hidden_size, bias=True),
+         ]
+         blocks.append(RMSNorm(hidden_size))
+         self.mlp = nn.Sequential(*blocks)
+         self.frequency_embedding_size = frequency_embedding_size
+         half = frequency_embedding_size // 2
+         self.register_buffer(
+             "freqs", torch.exp(-math.log(max_period) * torch.arange(start=0, end=half) / half)
+         )
+
+     def forward(self, t):
+         args = t * self.freqs.to(t.dtype)
+         embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+         assert not (self.frequency_embedding_size % 2)
+         t_emb = self.mlp(embedding)
+         return t_emb
+
+
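The frequency buffer implements the standard sinusoidal timestep embedding: geometrically spaced frequencies from 1 down to 1/max_period, with cosine and sine halves concatenated before the MLP. A standalone sketch of the embedding itself (`t` is expected to broadcast against `freqs`, e.g. shape `[N, 1]`):

```python
import math
import torch

half, max_period = 128, 10000
freqs = torch.exp(-math.log(max_period) * torch.arange(half) / half)
t = torch.tensor([[0.0], [0.5], [1.0]])  # three timesteps, shape [3, 1]
emb = torch.cat([torch.cos(t * freqs), torch.sin(t * freqs)], dim=-1)
print(emb.shape)                         # torch.Size([3, 256])
```
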
+ class ResBlock(nn.Module):
+     """
+     A residual block with adaLN modulation (the channel count is unchanged).
+     :param channels: the number of input channels.
+     """
+
+     def __init__(self, channels):
+         super().__init__()
+         self.channels = channels
+
+         self.in_ln = LayerNorm(channels, eps=1e-6)
+         self.mlp = nn.Sequential(
+             nn.Linear(channels, channels, bias=True),
+             nn.SiLU(),
+             nn.Linear(channels, channels, bias=True),
+         )
+
+         self.adaLN_modulation = nn.Sequential(
+             nn.SiLU(), nn.Linear(channels, 3 * channels, bias=True)
+         )
+
+     def forward(self, x, y):
+         shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(y).chunk(3, dim=-1)
+         h = modulate(self.in_ln(x), shift_mlp, scale_mlp)
+         h = self.mlp(h)
+         return x + gate_mlp * h
+
+
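The adaLN pattern above turns the conditioning vector into a per-channel shift, scale, and gate: the normalized activations are affine-modulated via `modulate`, and the MLP output is gated back onto the residual stream. A shape check using the classes defined in this file:

```python
import torch

block = ResBlock(channels=8)
x = torch.randn(2, 8)     # activations
y = torch.randn(2, 8)     # conditioning, same width as the block
print(block(x, y).shape)  # torch.Size([2, 8])
```
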
+ class FinalLayer(nn.Module):
+     """
+     The final layer, adopted from DiT.
+     """
+
+     def __init__(self, model_channels, out_channels):
+         super().__init__()
+         self.norm_final = LayerNorm(model_channels, elementwise_affine=False, eps=1e-6)
+         self.linear = nn.Linear(model_channels, out_channels, bias=True)
+         self.adaLN_modulation = nn.Sequential(
+             nn.SiLU(), nn.Linear(model_channels, 2 * model_channels, bias=True)
+         )
+
+     def forward(self, x, c):
+         shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
+         x = modulate(self.norm_final(x), shift, scale)
+         x = self.linear(x)
+         return x
+
+
+ class SimpleMLPAdaLN(nn.Module):
+     """Taken from https://arxiv.org/abs/2406.11838.
+
+     The MLP for Diffusion Loss.
+     :param in_channels: channels in the input Tensor.
+     :param model_channels: base channel count for the model.
+     :param out_channels: channels in the output Tensor.
+     :param cond_channels: channels in the condition.
+     :param num_res_blocks: number of residual blocks.
+     """
+
+     def __init__(
+         self,
+         in_channels,
+         model_channels,
+         out_channels,
+         cond_channels,
+         num_res_blocks,
+         num_time_conds=1,
+     ):
+         super().__init__()
+
+         self.in_channels = in_channels
+         self.model_channels = model_channels
+         self.out_channels = out_channels
+         self.num_res_blocks = num_res_blocks
+         self.num_time_conds = num_time_conds
+
+         assert num_time_conds != 1
+         self.time_embed = nn.ModuleList(
+             [TimestepEmbedder(model_channels) for _ in range(num_time_conds)]
+         )
+         self.cond_embed = nn.Linear(cond_channels, model_channels)
+
+         self.input_proj = nn.Linear(in_channels, model_channels)
+
+         res_blocks = []
+         for i in range(num_res_blocks):
+             res_blocks.append(ResBlock(model_channels))
+
+         self.res_blocks = nn.ModuleList(res_blocks)
+         self.final_layer = FinalLayer(model_channels, out_channels)
+
+     @classmethod
+     def from_pydantic_config(cls, cfg: FlowLMConfig, latent_dim: int, cond_dim: int) -> Self:
+         config = cfg.flow
+
+         flow_dim = config.dim
+         flow_depth = config.depth
+         num_time_conds = 2
+         return SimpleMLPAdaLN(
+             latent_dim, flow_dim, latent_dim, cond_dim, flow_depth, num_time_conds=num_time_conds
+         )
+
+     def forward(
+         self, c: torch.Tensor, s: torch.Tensor, t: torch.Tensor, x: torch.Tensor
+     ) -> torch.Tensor:
+         """
+         Apply the model to an input batch.
+         :param c: conditioning from AR transformer.
+         :param s: start time tensor.
+         :param t: target time tensor.
+         :param x: an [N x C] Tensor of inputs.
+         :return: an [N x C] Tensor of outputs.
+         """
+         # Combine time conditions.
+         ts = [s, t]
+         x = self.input_proj(x)
+         assert len(ts) == self.num_time_conds, (
+             f"Expected {self.num_time_conds} time conditions, got {len(ts)}"
+         )
+         assert self.num_time_conds != 1
+         t_combined = (
+             sum(self.time_embed[i](ts[i]) for i in range(self.num_time_conds)) / self.num_time_conds
+         )
+         c = self.cond_embed(c)
+         y = t_combined + c
+
+         for block in self.res_blocks:
+             x = block(x, y)
+
+         return self.final_layer(x, y)
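
An end-to-end shape check for the flow MLP (arbitrary sizes; note that the constructor rejects `num_time_conds == 1` and `forward` expects exactly two time conditions, `s` and `t`, each shaped to broadcast against the frequency buffer, e.g. `[N, 1]`):

```python
import torch

net = SimpleMLPAdaLN(
    in_channels=16,
    model_channels=64,
    out_channels=16,
    cond_channels=32,
    num_res_blocks=2,
    num_time_conds=2,
)
c = torch.randn(4, 32)        # conditioning from the AR transformer
s = torch.rand(4, 1)          # start time
t = torch.rand(4, 1)          # target time
x = torch.randn(4, 16)        # latent input
print(net(c, s, t, x).shape)  # torch.Size([4, 16])
```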