cortexflowx 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cortexflow/__init__.py +78 -0
- cortexflow/_types.py +115 -0
- cortexflow/brain2audio.py +298 -0
- cortexflow/brain2img.py +234 -0
- cortexflow/brain2text.py +278 -0
- cortexflow/brain_encoder.py +228 -0
- cortexflow/dit.py +397 -0
- cortexflow/flow_matching.py +236 -0
- cortexflow/training.py +283 -0
- cortexflow/vae.py +232 -0
- cortexflowx-0.1.0.dist-info/METADATA +218 -0
- cortexflowx-0.1.0.dist-info/RECORD +14 -0
- cortexflowx-0.1.0.dist-info/WHEEL +4 -0
- cortexflowx-0.1.0.dist-info/licenses/LICENSE +190 -0
cortexflow/__init__.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
"""CortexFlow — Brain decoding with Diffusion Transformers & Flow Matching.
|
|
2
|
+
|
|
3
|
+
Reconstruct what someone saw, heard, or thought from fMRI brain activity.
|
|
4
|
+
|
|
5
|
+
Architecture: fMRI → BrainEncoder → DiT (flow matching) → stimulus
|
|
6
|
+
|
|
7
|
+
Modules:
|
|
8
|
+
dit Diffusion Transformer backbone (AdaLN-Zero, QK-Norm, SwiGLU)
|
|
9
|
+
flow_matching Rectified flow training & ODE sampling
|
|
10
|
+
vae Latent space compression (2D images, 1D audio)
|
|
11
|
+
brain_encoder fMRI → conditioning embeddings (global + tokens)
|
|
12
|
+
brain2img Full brain → image pipeline
|
|
13
|
+
brain2audio Full brain → audio pipeline
|
|
14
|
+
brain2text Full brain → text pipeline (autoregressive)
|
|
15
|
+
training Training loops, schedulers, synthetic data
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
from cortexflow._types import (
|
|
19
|
+
BrainData,
|
|
20
|
+
CortexFlowError,
|
|
21
|
+
DiTConfig,
|
|
22
|
+
FlowConfig,
|
|
23
|
+
Modality,
|
|
24
|
+
ReconstructionResult,
|
|
25
|
+
TrainingConfig,
|
|
26
|
+
VAEConfig,
|
|
27
|
+
)
|
|
28
|
+
from cortexflow.brain2audio import AudioDiT, Brain2Audio, build_brain2audio
|
|
29
|
+
from cortexflow.brain2img import Brain2Image, build_brain2img
|
|
30
|
+
from cortexflow.brain2text import Brain2Text, BrainTextDecoder, build_brain2text
|
|
31
|
+
from cortexflow.brain_encoder import (
|
|
32
|
+
BrainEncoder,
|
|
33
|
+
ROIBrainEncoder,
|
|
34
|
+
SubjectAdapter,
|
|
35
|
+
make_synthetic_brain_data,
|
|
36
|
+
)
|
|
37
|
+
from cortexflow.dit import DiffusionTransformer
|
|
38
|
+
from cortexflow.flow_matching import EMAModel, RectifiedFlowMatcher
|
|
39
|
+
from cortexflow.training import SyntheticBrainDataset, Trainer, WarmupCosineScheduler
|
|
40
|
+
from cortexflow.vae import AudioVAE, LatentVAE
|
|
41
|
+
|
|
42
|
+
__version__ = "0.1.0"
|
|
43
|
+
|
|
44
|
+
__all__ = [
|
|
45
|
+
# Types
|
|
46
|
+
"BrainData",
|
|
47
|
+
"ReconstructionResult",
|
|
48
|
+
"DiTConfig",
|
|
49
|
+
"FlowConfig",
|
|
50
|
+
"VAEConfig",
|
|
51
|
+
"TrainingConfig",
|
|
52
|
+
"Modality",
|
|
53
|
+
"CortexFlowError",
|
|
54
|
+
# Core
|
|
55
|
+
"DiffusionTransformer",
|
|
56
|
+
"RectifiedFlowMatcher",
|
|
57
|
+
"EMAModel",
|
|
58
|
+
"LatentVAE",
|
|
59
|
+
"AudioVAE",
|
|
60
|
+
"BrainEncoder",
|
|
61
|
+
"ROIBrainEncoder",
|
|
62
|
+
"SubjectAdapter",
|
|
63
|
+
# Pipelines
|
|
64
|
+
"Brain2Image",
|
|
65
|
+
"Brain2Audio",
|
|
66
|
+
"Brain2Text",
|
|
67
|
+
"AudioDiT",
|
|
68
|
+
"BrainTextDecoder",
|
|
69
|
+
# Factories
|
|
70
|
+
"build_brain2img",
|
|
71
|
+
"build_brain2audio",
|
|
72
|
+
"build_brain2text",
|
|
73
|
+
"make_synthetic_brain_data",
|
|
74
|
+
# Training
|
|
75
|
+
"Trainer",
|
|
76
|
+
"WarmupCosineScheduler",
|
|
77
|
+
"SyntheticBrainDataset",
|
|
78
|
+
]
|
cortexflow/_types.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
"""Core data types for cortexflow."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from enum import Enum
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class Modality(str, Enum):
    """Target stimulus modality a decoder reconstructs from brain activity.

    Members also inherit from ``str``, so ``Modality.IMAGE == "image"``
    holds and values serialize cleanly into configs and JSON.
    """

    IMAGE = "image"  # visual stimuli (brain2img pipeline)
    AUDIO = "audio"  # auditory stimuli (brain2audio pipeline)
    TEXT = "text"    # linguistic stimuli (brain2text pipeline)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass
class BrainData:
    """One batch (or a single sample) of fMRI activity.

    Attributes:
        voxels: Activity tensor, ``(batch, n_voxels)`` or ``(n_voxels,)``.
        subject_id: Identifier consumed by subject-specific adapters, if any.
        roi_mask: Boolean mask selecting ROI voxels, if any.
        tr: Repetition time in seconds (fMRI temporal resolution).
    """

    voxels: torch.Tensor
    subject_id: str | None = None
    roi_mask: torch.Tensor | None = None
    tr: float = 2.0

    @property
    def n_voxels(self) -> int:
        """Voxel count — the size of the last axis of ``voxels``."""
        return self.voxels.shape[-1]

    @property
    def batch_size(self) -> int:
        """Batch dimension; a bare 1-D voxel vector counts as one sample."""
        return 1 if self.voxels.ndim == 1 else self.voxels.shape[0]
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@dataclass
class ReconstructionResult:
    """Bundle returned by a pipeline's ``reconstruct`` call."""

    modality: Modality             # which kind of stimulus was decoded
    output: torch.Tensor           # decoded output (image / waveform / token ids)
    brain_condition: torch.Tensor  # conditioning embeddings that drove sampling
    n_steps: int = 50              # diffusion/ODE steps used
    cfg_scale: float = 1.0         # classifier-free guidance scale
    metadata: dict[str, Any] = field(default_factory=dict)  # free-form extras
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
@dataclass
class TrainingConfig:
    """Hyperparameters for training a brain decoder.

    Groups optimizer, schedule, EMA, and bookkeeping settings; every value
    can be overridden per-experiment via keyword arguments.
    """

    learning_rate: float = 1e-4   # base optimizer learning rate
    batch_size: int = 16
    n_epochs: int = 100
    warmup_steps: int = 500       # steps of LR warmup before the main schedule
    weight_decay: float = 0.01
    ema_decay: float = 0.9999     # decay for the weight moving average
    grad_clip: float = 1.0        # gradient clipping threshold
    mixed_precision: bool = False
    log_every: int = 100          # steps between log lines
    save_every: int = 1000        # steps between checkpoints
    eval_every: int = 500         # steps between evaluations
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
@dataclass
class DiTConfig:
    """Architecture knobs for the Diffusion Transformer backbone."""

    in_channels: int = 4          # latent channels the DiT denoises
    hidden_dim: int = 768
    depth: int = 12               # number of transformer blocks
    num_heads: int = 12
    patch_size: int = 2
    cond_dim: int = 768           # brain conditioning embedding dimension
    mlp_ratio: float = 4.0        # MLP hidden width as a multiple of hidden_dim
    qk_norm: bool = True          # normalize queries/keys before attention
    use_cross_attn: bool = True   # attend to brain tokens via cross-attention
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
@dataclass
class VAEConfig:
    """Architecture knobs for the latent VAE."""

    in_channels: int = 3          # RGB input
    latent_channels: int = 4
    # Per-stage encoder widths; default_factory keeps the list per-instance.
    hidden_dims: list[int] = field(default_factory=lambda: [64, 128, 256, 512])
    kl_weight: float = 1e-6       # weight of the KL term in the VAE loss
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
@dataclass
class FlowConfig:
    """Knobs for rectified flow matching training and sampling."""

    num_steps: int = 50               # default ODE sampling steps
    sigma_min: float = 1e-5
    logit_normal: bool = True         # SD3-style logit-normal timestep sampling
    logit_normal_mean: float = 0.0
    logit_normal_std: float = 1.0
    solver: str = "euler"             # "euler" or "midpoint"
    cfg_scale: float = 4.0            # default classifier-free guidance scale
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
class CortexFlowError(Exception):
    """Root of the cortexflow exception hierarchy.

    Catch this to handle any error the library raises deliberately.
    """
|
|
@@ -0,0 +1,298 @@
|
|
|
1
|
+
"""Brain → Audio reconstruction pipeline.
|
|
2
|
+
|
|
3
|
+
Reconstruct what someone heard from their fMRI activity.
|
|
4
|
+
|
|
5
|
+
Architecture: fMRI → BrainEncoder → DiT (flow matching on mel spectrograms)
|
|
6
|
+
→ Griffin-Lim or learned vocoder → waveform.
|
|
7
|
+
|
|
8
|
+
The DiT operates on 1D latent sequences (compressed mel spectrograms)
|
|
9
|
+
rather than 2D spatial maps.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
import math
|
|
15
|
+
|
|
16
|
+
import torch
|
|
17
|
+
import torch.nn as nn
|
|
18
|
+
import torch.nn.functional as F
|
|
19
|
+
|
|
20
|
+
from cortexflow._types import (
|
|
21
|
+
BrainData,
|
|
22
|
+
DiTConfig,
|
|
23
|
+
FlowConfig,
|
|
24
|
+
Modality,
|
|
25
|
+
ReconstructionResult,
|
|
26
|
+
)
|
|
27
|
+
from cortexflow.brain_encoder import BrainEncoder
|
|
28
|
+
from cortexflow.flow_matching import RectifiedFlowMatcher
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
# ── 1D DiT for audio ────────────────────────────────────────────────────
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class DiTBlock1D(nn.Module):
    """DiT block adapted for 1D sequences (audio spectrograms).

    Pre-norm transformer block with AdaLN-Zero conditioning: a single
    conditioning vector ``c`` produces per-block shift/scale/gate triples
    for the self-attention and MLP branches. An optional (ungated)
    cross-attention branch attends to brain tokens.

    Args:
        hidden_dim: Token embedding width; also the width of ``c``.
        num_heads: Attention heads for both self- and cross-attention.
        cond_dim: Width of the brain tokens fed to cross-attention.
        mlp_ratio: MLP hidden width as a multiple of ``hidden_dim``.
    """

    def __init__(self, hidden_dim: int, num_heads: int, cond_dim: int, mlp_ratio: float = 4.0):
        super().__init__()
        # Maps c -> 6 * hidden_dim: (shift, scale, gate) for attn and MLP branches.
        self.adaLN = nn.Sequential(nn.SiLU(), nn.Linear(hidden_dim, 6 * hidden_dim))
        # Norms are affine-free; scale/shift come from adaLN instead.
        self.norm1 = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6)
        self.attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
        self.norm_cross = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6)
        self.cross_attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
        # Project brain tokens to hidden_dim only when widths differ.
        self.cond_proj = nn.Linear(cond_dim, hidden_dim) if cond_dim != hidden_dim else nn.Identity()
        self.norm2 = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6)
        mlp_h = int(hidden_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, mlp_h), nn.SiLU(),
            nn.Linear(mlp_h, hidden_dim),
        )
        # AdaLN-Zero init: zeroing the final linear makes all gates start at 0,
        # so each gated branch begins as an identity mapping.
        nn.init.zeros_(self.adaLN[-1].weight)
        nn.init.zeros_(self.adaLN[-1].bias)

    def forward(self, x: torch.Tensor, c: torch.Tensor, brain_tokens: torch.Tensor | None = None) -> torch.Tensor:
        """Apply the block.

        Args:
            x: ``(B, T, hidden_dim)`` token sequence.
            c: ``(B, hidden_dim)`` conditioning vector (timestep + global brain).
            brain_tokens: Optional ``(B, K, cond_dim)`` tokens for cross-attention.

        Returns:
            ``(B, T, hidden_dim)`` updated sequence.
        """
        # shift/scale/gate for the attention branch (1) and the MLP branch (2).
        s1, sc1, g1, s2, sc2, g2 = self.adaLN(c).chunk(6, dim=-1)

        # Self-attention branch, modulated then gated (unsqueeze broadcasts over T).
        h = self.norm1(x) * (1 + sc1.unsqueeze(1)) + s1.unsqueeze(1)
        h, _ = self.attn(h, h, h, need_weights=False)
        x = x + g1.unsqueeze(1) * h

        if brain_tokens is not None:
            # Cross-attention over brain tokens.
            # NOTE(review): this residual is ungated/not zero-initialized,
            # unlike the AdaLN-Zero branches — confirm that is intentional.
            h = self.norm_cross(x)
            kv = self.cond_proj(brain_tokens)
            h, _ = self.cross_attn(h, kv, kv, need_weights=False)
            x = x + h

        # MLP branch, modulated then gated.
        h = self.norm2(x) * (1 + sc2.unsqueeze(1)) + s2.unsqueeze(1)
        h = self.mlp(h)
        x = x + g2.unsqueeze(1) * h
        return x
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class AudioDiT(nn.Module):
    """1D Diffusion Transformer for mel spectrogram generation.

    Operates on a sequence of mel frames (or compressed latents): each
    time frame becomes one token of width ``hidden_dim``, conditioned on a
    sinusoidal timestep embedding plus a global brain embedding, with
    optional per-block cross-attention to brain tokens.

    Args:
        n_mels: Mel bins per frame (token input/output width).
        seq_len: Maximum number of frames; learned positional embeddings
            cover exactly this many positions, so inputs must have T <= seq_len.
        hidden_dim: Transformer width. Should be even so the sinusoidal
            timestep embedding fills all ``hidden_dim`` features.
        depth: Number of DiTBlock1D layers.
        num_heads: Attention heads per block.
        cond_dim: Width of the brain conditioning (global and tokens).
    """

    def __init__(
        self,
        n_mels: int = 80,
        seq_len: int = 128,
        hidden_dim: int = 256,
        depth: int = 6,
        num_heads: int = 8,
        cond_dim: int = 256,
    ):
        super().__init__()
        self.n_mels = n_mels
        self.seq_len = seq_len
        self.input_proj = nn.Linear(n_mels, hidden_dim)
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, hidden_dim))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

        # Timestep embedding: sinusoidal features -> 2-layer MLP.
        self.time_embed = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4), nn.SiLU(),
            nn.Linear(hidden_dim * 4, hidden_dim),
        )
        self.time_dim = hidden_dim
        # Global brain embedding is projected and summed with the timestep embedding.
        self.cond_proj = nn.Sequential(nn.Linear(cond_dim, hidden_dim), nn.SiLU(), nn.Linear(hidden_dim, hidden_dim))

        self.blocks = nn.ModuleList([
            DiTBlock1D(hidden_dim, num_heads, cond_dim) for _ in range(depth)
        ])
        self.final_norm = nn.LayerNorm(hidden_dim, eps=1e-6)
        self.output_proj = nn.Linear(hidden_dim, n_mels)
        # Zero-init the head so the model predicts zero velocity at init.
        nn.init.zeros_(self.output_proj.weight)
        nn.init.zeros_(self.output_proj.bias)

    def _sinusoidal_embed(self, t: torch.Tensor) -> torch.Tensor:
        """Sinusoidal features for timesteps ``t``: ``(B,) -> (B, 2 * (time_dim // 2))``.

        Produces exactly ``time_dim`` features when ``time_dim`` is even
        (one fewer when odd — keep ``hidden_dim`` even).
        """
        half = self.time_dim // 2
        freqs = torch.exp(-math.log(10000) * torch.arange(half, device=t.device).float() / half)
        args = t.view(-1, 1) * freqs.unsqueeze(0)
        return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        brain_global: torch.Tensor,
        brain_tokens: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Predict velocity for mel spectrogram flow matching.

        Args:
            x: ``(B, n_mels, T)`` noisy mel spectrogram, ``T <= seq_len``.
            t: ``(B,)`` timestep.
            brain_global: ``(B, cond_dim)`` global brain embedding.
            brain_tokens: ``(B, K, cond_dim)`` brain tokens for cross-attention.

        Returns:
            ``(B, n_mels, T)`` predicted velocity.
        """
        # (B, n_mels, T) → (B, T, n_mels) → (B, T, hidden)
        x = x.transpose(1, 2)
        # Positional embedding is truncated to the actual sequence length.
        x = self.input_proj(x) + self.pos_embed[:, :x.shape[1]]

        # Single conditioning vector per sample: timestep + global brain signal.
        t_emb = self.time_embed(self._sinusoidal_embed(t))
        c = t_emb + self.cond_proj(brain_global)

        for block in self.blocks:
            x = block(x, c, brain_tokens)

        x = self.final_norm(x)
        x = self.output_proj(x)  # (B, T, n_mels)
        return x.transpose(1, 2)  # (B, n_mels, T)
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
# ── Brain2Audio Pipeline ────────────────────────────────────────────────
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
class Brain2Audio(nn.Module):
    """Reconstruct audio from brain activity using 1D DiT + flow matching.

    Pipeline::

        fMRI → BrainEncoder → AudioDiT (flow matching) → mel spectrogram
        → Griffin-Lim → waveform

    Args:
        n_voxels: Number of fMRI voxels.
        n_mels: Number of mel frequency bins.
        audio_len: Number of mel spectrogram frames.
        hidden_dim: DiT width; also used as the brain conditioning width.
        depth: Number of DiT blocks.
        num_heads: Attention heads per block.
        sample_rate: Audio sample rate (for mel spectrogram computation).
    """

    def __init__(
        self,
        n_voxels: int = 1024,
        n_mels: int = 80,
        audio_len: int = 128,
        hidden_dim: int = 256,
        depth: int = 6,
        num_heads: int = 8,
        sample_rate: int = 16000,
    ):
        super().__init__()
        self.n_mels = n_mels
        self.audio_len = audio_len
        self.sample_rate = sample_rate

        # Brain conditioning shares the DiT width so no extra projection is needed.
        cond_dim = hidden_dim
        self.brain_encoder = BrainEncoder(
            n_voxels=n_voxels, cond_dim=cond_dim, n_tokens=16,
        )
        self.dit = AudioDiT(
            n_mels=n_mels, seq_len=audio_len, hidden_dim=hidden_dim,
            depth=depth, num_heads=num_heads, cond_dim=cond_dim,
        )
        self.flow_matcher = RectifiedFlowMatcher(FlowConfig())

        # Learned "null" embeddings used as the unconditional branch for
        # classifier-free guidance; the token count (16) matches n_tokens above.
        self.uncond_global = nn.Parameter(torch.zeros(1, cond_dim))
        self.uncond_tokens = nn.Parameter(torch.zeros(1, 16, cond_dim))

    def training_loss(self, mel: torch.Tensor, brain_data: BrainData) -> torch.Tensor:
        """Flow matching loss on mel spectrograms.

        Args:
            mel: Target mel spectrogram ``(B, n_mels, T)``.
            brain_data: Corresponding fMRI.

        Returns:
            Scalar loss tensor from the flow matcher.
        """
        brain_global, brain_tokens = self.brain_encoder(brain_data.voxels)
        return self.flow_matcher.compute_loss(self.dit, mel, brain_global, brain_tokens)

    @torch.no_grad()
    def reconstruct(
        self,
        brain_data: BrainData,
        num_steps: int = 50,
        cfg_scale: float = 3.0,
    ) -> ReconstructionResult:
        """Reconstruct an audio mel spectrogram from brain activity.

        Args:
            brain_data: fMRI sample(s) to decode.
            num_steps: ODE sampling steps.
            cfg_scale: Classifier-free guidance scale.

        Returns:
            ReconstructionResult whose ``output`` is a ``(B, n_mels, audio_len)``
            mel spectrogram; pass it to :meth:`mel_to_waveform` for audio.
        """
        B = brain_data.batch_size
        brain_global, brain_tokens = self.brain_encoder(brain_data.voxels)

        mel_shape = (B, self.n_mels, self.audio_len)
        # Sample with CFG: the learned null embeddings provide the
        # unconditional branch, expanded to the batch size.
        mel = self.flow_matcher.sample(
            self.dit, mel_shape, brain_global, brain_tokens,
            num_steps=num_steps, cfg_scale=cfg_scale,
            brain_global_uncond=self.uncond_global.expand(B, -1),
            brain_tokens_uncond=self.uncond_tokens.expand(B, -1, -1),
        )
        return ReconstructionResult(
            modality=Modality.AUDIO,
            output=mel,
            brain_condition=brain_global,
            n_steps=num_steps,
            cfg_scale=cfg_scale,
        )

    @staticmethod
    def mel_to_waveform(mel: torch.Tensor, n_fft: int = 1024, hop_length: int = 256) -> torch.Tensor:
        """Convert mel spectrogram to waveform via Griffin-Lim.

        Args:
            mel: ``(B, n_mels, T)`` mel spectrogram (linear scale, non-negative;
                values are clamped at 0 before inversion).
            n_fft: FFT size.
            hop_length: Hop length.

        Returns:
            ``(B, T * hop_length)`` waveform.
        """
        B, n_mels, T = mel.shape
        # Create simple mel filterbank inverse (pseudo-inverse approach).
        n_freqs = n_fft // 2 + 1
        # Approximate: project mel back to linear spectrogram using Gaussian
        # bumps over *linearly* spaced centers.
        # NOTE(review): this is not a true mel scale; adequate for round-tripping
        # this model's own spectrograms — confirm before using real mel inputs.
        mel_basis = torch.linspace(0, 1, n_mels, device=mel.device).unsqueeze(1)
        freq_basis = torch.linspace(0, 1, n_freqs, device=mel.device).unsqueeze(0)
        filterbank = torch.exp(-0.5 * ((mel_basis - freq_basis) / 0.05) ** 2)
        filterbank = filterbank / (filterbank.sum(dim=0, keepdim=True) + 1e-8)

        # Mel → linear spectrogram; the transpose acts as a cheap pseudo-inverse
        # of the row-normalized filterbank.
        fb_pinv = filterbank.T  # (n_freqs, n_mels)
        magnitude = torch.matmul(fb_pinv.unsqueeze(0), mel.clamp(min=0))  # (B, n_freqs, T)

        # Griffin-Lim: iterative phase estimation.
        window = torch.hann_window(n_fft, device=mel.device)
        out_length = T * hop_length
        # NOTE(review): randn gives Gaussian-distributed phases rather than
        # uniform in [0, 2π); Griffin-Lim iterations overwrite them anyway,
        # but torch.rand may have been intended — confirm.
        phase = torch.randn(B, n_freqs, T, device=mel.device) * 2 * math.pi
        for _ in range(32):
            stft = magnitude * torch.exp(1j * phase)
            waveform = torch.istft(
                stft, n_fft=n_fft, hop_length=hop_length,
                window=window, length=out_length,
            )
            new_stft = torch.stft(
                waveform, n_fft=n_fft, hop_length=hop_length,
                window=window, return_complex=True,
            )
            # Trim or pad to match magnitude shape (stft frame count can differ).
            nt = min(new_stft.shape[-1], T)
            phase = torch.zeros_like(magnitude)
            phase[:, :, :nt] = new_stft[:, :, :nt].angle()

        # Final synthesis with the estimated phase.
        stft = magnitude * torch.exp(1j * phase)
        waveform = torch.istft(
            stft, n_fft=n_fft, hop_length=hop_length,
            window=window, length=out_length,
        )
        return waveform
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
def build_brain2audio(
    n_voxels: int = 1024,
    n_mels: int = 80,
    audio_len: int = 128,
    hidden_dim: int = 256,
    depth: int = 6,
    num_heads: int = 8,
    sample_rate: int = 16000,
) -> Brain2Audio:
    """Build a Brain2Audio model with sensible defaults.

    Args:
        n_voxels: Number of fMRI voxels.
        n_mels: Number of mel frequency bins.
        audio_len: Number of mel spectrogram frames.
        hidden_dim: DiT width (also brain conditioning width).
        depth: Number of DiT blocks.
        num_heads: Attention heads per block (forwarded to the DiT).
        sample_rate: Audio sample rate.

    Returns:
        A freshly initialized :class:`Brain2Audio`.
    """
    # Forward every knob Brain2Audio exposes; defaults match its own, so
    # existing calls are unaffected.
    return Brain2Audio(
        n_voxels=n_voxels, n_mels=n_mels, audio_len=audio_len,
        hidden_dim=hidden_dim, depth=depth, num_heads=num_heads,
        sample_rate=sample_rate,
    )
|