cortexflowx 0.1.0__tar.gz → 0.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {cortexflowx-0.1.0 → cortexflowx-0.2.0}/CHANGELOG.md +16 -0
- {cortexflowx-0.1.0 → cortexflowx-0.2.0}/PKG-INFO +3 -3
- {cortexflowx-0.1.0 → cortexflowx-0.2.0}/README.md +2 -2
- {cortexflowx-0.1.0 → cortexflowx-0.2.0}/pyproject.toml +1 -1
- {cortexflowx-0.1.0 → cortexflowx-0.2.0}/src/cortexflow/brain2audio.py +32 -5
- {cortexflowx-0.1.0 → cortexflowx-0.2.0}/src/cortexflow/brain2img.py +24 -6
- {cortexflowx-0.1.0 → cortexflowx-0.2.0}/src/cortexflow/brain2text.py +52 -10
- {cortexflowx-0.1.0 → cortexflowx-0.2.0}/src/cortexflow/dit.py +14 -1
- {cortexflowx-0.1.0 → cortexflowx-0.2.0}/tests/test_brain2audio.py +20 -0
- {cortexflowx-0.1.0 → cortexflowx-0.2.0}/tests/test_brain2img.py +21 -0
- {cortexflowx-0.1.0 → cortexflowx-0.2.0}/tests/test_brain2text.py +48 -0
- {cortexflowx-0.1.0 → cortexflowx-0.2.0}/.github/workflows/ci.yml +0 -0
- {cortexflowx-0.1.0 → cortexflowx-0.2.0}/.gitignore +0 -0
- {cortexflowx-0.1.0 → cortexflowx-0.2.0}/CITATION.cff +0 -0
- {cortexflowx-0.1.0 → cortexflowx-0.2.0}/LICENSE +0 -0
- {cortexflowx-0.1.0 → cortexflowx-0.2.0}/examples/brain2audio_demo.py +0 -0
- {cortexflowx-0.1.0 → cortexflowx-0.2.0}/examples/brain2img_demo.py +0 -0
- {cortexflowx-0.1.0 → cortexflowx-0.2.0}/examples/brain2text_demo.py +0 -0
- {cortexflowx-0.1.0 → cortexflowx-0.2.0}/src/cortexflow/__init__.py +0 -0
- {cortexflowx-0.1.0 → cortexflowx-0.2.0}/src/cortexflow/_types.py +0 -0
- {cortexflowx-0.1.0 → cortexflowx-0.2.0}/src/cortexflow/brain_encoder.py +0 -0
- {cortexflowx-0.1.0 → cortexflowx-0.2.0}/src/cortexflow/flow_matching.py +0 -0
- {cortexflowx-0.1.0 → cortexflowx-0.2.0}/src/cortexflow/training.py +0 -0
- {cortexflowx-0.1.0 → cortexflowx-0.2.0}/src/cortexflow/vae.py +0 -0
- {cortexflowx-0.1.0 → cortexflowx-0.2.0}/tests/conftest.py +0 -0
- {cortexflowx-0.1.0 → cortexflowx-0.2.0}/tests/test_brain_encoder.py +0 -0
- {cortexflowx-0.1.0 → cortexflowx-0.2.0}/tests/test_dit.py +0 -0
- {cortexflowx-0.1.0 → cortexflowx-0.2.0}/tests/test_flow_matching.py +0 -0
- {cortexflowx-0.1.0 → cortexflowx-0.2.0}/tests/test_init.py +0 -0
- {cortexflowx-0.1.0 → cortexflowx-0.2.0}/tests/test_integration.py +0 -0
- {cortexflowx-0.1.0 → cortexflowx-0.2.0}/tests/test_training.py +0 -0
- {cortexflowx-0.1.0 → cortexflowx-0.2.0}/tests/test_types.py +0 -0
- {cortexflowx-0.1.0 → cortexflowx-0.2.0}/tests/test_vae.py +0 -0
CHANGELOG.md

```diff
@@ -2,6 +2,22 @@
 
 All notable changes to CortexFlow will be documented in this file.
 
+## [0.2.0] - 2026-04-12
+
+### Added
+
+- **Semantic diversity**: All pipelines now support `num_samples` parameter to generate multiple diverse reconstructions per brain input. Each sample uses independent noise (image/audio) or independent random draws (text), producing semantically varied outputs.
+- **Nucleus (top-p) sampling** for Brain2Text: `top_p` parameter enables nucleus filtering alongside top-k, giving finer control over text generation diversity.
+- When `num_samples > 1`, output shapes become `(B, num_samples, ...)` for image/audio; text metadata returns grouped lists per brain input.
+
+## [0.1.1] - 2025-06-26
+
+### Fixed
+
+- **DiT zero-init**: `_initialize_weights()` no longer overwrites critical zero-initialized gating parameters (AdaLN modulation, final layer projection). Output is now exactly zero at initialization for stable training.
+- **Brain2Text BOS mismatch**: Training now prepends BOS token so the model learns to predict from BOS context, matching inference behavior.
+- **Brain2Text empty output**: `reconstruct()` now correctly skips the BOS token when decoding generated sequences to text.
+
 ## [0.1.0] - 2025-06-25
 
 ### Added
```
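Taken together, the 0.2.0 changes read from the call site as sketched below. This is a minimal sketch, not package documentation: the builder keyword arguments are copied from the test hunks further down, while the top-level `cortexflow` imports and the `BrainData(voxels=...)` constructor are assumptions this diff does not confirm.

```python
# Hedged sketch of the 0.2.0 additions (num_samples, top_p).
# Assumed: cortexflow exports these names and BrainData takes voxels=...;
# builder kwargs are the ones used in the tests below.
import torch
from cortexflow import BrainData, build_brain2img, build_brain2text

img_model = build_brain2img(n_voxels=64, img_size=8, hidden_dim=16, depth=1, num_heads=4)
img_model.eval()
brain = BrainData(voxels=torch.randn(2, 64))  # assumed constructor

# num_samples > 1: one independent noise draw per sample, grouped output
res = img_model.reconstruct(brain, num_steps=50, num_samples=3)
print(res.output.shape)             # torch.Size([2, 3, 3, 8, 8])
print(res.metadata["num_samples"])  # 3

txt_model = build_brain2text(n_voxels=64, max_len=16, hidden_dim=16, depth=1)
txt_model.eval()

# top_p nucleus filtering plus grouped text samples
res = txt_model.reconstruct(brain, max_len=16, top_p=0.9, top_k=0, num_samples=2)
print(res.metadata["texts"])        # list of lists: texts[i] holds 2 strings
```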
PKG-INFO

````diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: cortexflowx
-Version: 0.1.0
+Version: 0.2.0
 Summary: Brain-to-image/audio/text reconstruction using Diffusion Transformers and Flow Matching. Decode what someone saw, heard, or thought from fMRI.
 Project-URL: Homepage, https://github.com/stef41/cortexflow
 Project-URL: Repository, https://github.com/stef41/cortexflow
@@ -65,12 +65,12 @@ fMRI voxels
 ## Installation
 
 ```bash
-pip install
+pip install cortexflowx
 ```
 
 With audio support:
 ```bash
-pip install
+pip install cortexflowx[audio]
 ```
 
 ## Quick Start
````
pyproject.toml

```diff
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "cortexflowx"
-version = "0.1.0"
+version = "0.2.0"
 description = "Brain-to-image/audio/text reconstruction using Diffusion Transformers and Flow Matching. Decode what someone saw, heard, or thought from fMRI."
 readme = "README.md"
 license = {text = "Apache-2.0"}
```
src/cortexflow/brain2audio.py

```diff
@@ -211,25 +211,52 @@ class Brain2Audio(nn.Module):
         brain_data: BrainData,
         num_steps: int = 50,
         cfg_scale: float = 3.0,
+        num_samples: int = 1,
     ) -> ReconstructionResult:
-        """Reconstruct audio mel spectrogram from brain activity."""
+        """Reconstruct audio mel spectrogram from brain activity.
+
+        Args:
+            brain_data: fMRI data to decode.
+            num_steps: Number of ODE solver steps.
+            cfg_scale: Classifier-free guidance scale.
+            num_samples: Number of diverse samples per brain input.
+                Each sample uses independent noise, producing semantically
+                varied reconstructions. Output shape becomes
+                ``(B, num_samples, n_mels, T)`` when ``num_samples > 1``.
+
+        Returns:
+            ReconstructionResult with the decoded mel spectrogram(s).
+        """
         B = brain_data.batch_size
         device = brain_data.voxels.device
         brain_global, brain_tokens = self.brain_encoder(brain_data.voxels)
 
-        mel_shape = (B, self.n_mels, self.audio_len)
+        # Repeat conditioning for multiple samples per input
+        if num_samples > 1:
+            brain_global = brain_global.repeat_interleave(num_samples, dim=0)
+            brain_tokens = brain_tokens.repeat_interleave(num_samples, dim=0)
+
+        BN = B * num_samples
+
+        mel_shape = (BN, self.n_mels, self.audio_len)
         mel = self.flow_matcher.sample(
             self.dit, mel_shape, brain_global, brain_tokens,
             num_steps=num_steps, cfg_scale=cfg_scale,
-            brain_global_uncond=self.uncond_global.expand(B, -1),
-            brain_tokens_uncond=self.uncond_tokens.expand(B, -1, -1),
+            brain_global_uncond=self.uncond_global.expand(BN, -1),
+            brain_tokens_uncond=self.uncond_tokens.expand(BN, -1, -1),
         )
+
+        # Reshape to (B, num_samples, n_mels, T) when generating multiple
+        if num_samples > 1:
+            mel = mel.view(B, num_samples, self.n_mels, self.audio_len)
+
         return ReconstructionResult(
             modality=Modality.AUDIO,
             output=mel,
-            brain_condition=brain_global,
+            brain_condition=brain_global[:B],
             n_steps=num_steps,
             cfg_scale=cfg_scale,
+            metadata={"num_samples": num_samples},
         )
 
     @staticmethod
```
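The batching pattern in this hunk recurs in brain2img.py and brain2text.py below, so it is worth seeing in isolation. A standalone toy illustration in plain torch, not package code:

```python
# Conditioning is duplicated per sample with repeat_interleave, generation
# runs at batch size B * num_samples, and the flat batch is folded back to
# (B, num_samples, ...). Interleaving keeps all samples of input i adjacent,
# which is what makes the final view() group rows correctly.
import torch

B, num_samples, n_mels, T = 2, 3, 4, 6
cond = torch.randn(B, 8)                               # one row per brain input
cond_rep = cond.repeat_interleave(num_samples, dim=0)  # (6, 8): rows 0-2 are input 0
assert torch.equal(cond_rep[0], cond_rep[2])           # samples share conditioning

mel = torch.randn(B * num_samples, n_mels, T)          # independent noise per sample
mel = mel.view(B, num_samples, n_mels, T)              # fold back: (B, num_samples, ...)
assert mel.shape == (2, 3, 4, 6)
```

The same adjacency is what lets the text pipeline index `generated[i * num_samples + s]` when grouping decoded strings per brain input.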
src/cortexflow/brain2img.py

```diff
@@ -154,6 +154,7 @@ class Brain2Image(nn.Module):
         brain_data: BrainData,
         num_steps: int = 50,
         cfg_scale: float = 4.0,
+        num_samples: int = 1,
     ) -> ReconstructionResult:
         """Reconstruct an image from brain activity.
 
@@ -161,9 +162,13 @@ class Brain2Image(nn.Module):
             brain_data: fMRI data to decode.
             num_steps: Number of ODE solver steps.
             cfg_scale: Classifier-free guidance scale.
+            num_samples: Number of diverse samples per brain input.
+                Each sample uses independent noise, producing semantically
+                varied reconstructions. Output shape becomes
+                ``(B, num_samples, C, H, W)`` when ``num_samples > 1``.
 
         Returns:
-            ReconstructionResult with the decoded image.
+            ReconstructionResult with the decoded image(s).
         """
         B = brain_data.batch_size
         device = brain_data.voxels.device
@@ -171,12 +176,19 @@ class Brain2Image(nn.Module):
         # Encode brain
         brain_global, brain_tokens = self.encode_brain(brain_data)
 
+        # Repeat conditioning for multiple samples per input
+        if num_samples > 1:
+            brain_global = brain_global.repeat_interleave(num_samples, dim=0)
+            brain_tokens = brain_tokens.repeat_interleave(num_samples, dim=0)
+
+        BN = B * num_samples
+
         # Unconditional embeddings for CFG
-        uncond_global = self.uncond_global.expand(B, -1)
-        uncond_tokens = self.uncond_tokens.expand(B, -1, -1)
+        uncond_global = self.uncond_global.expand(BN, -1)
+        uncond_tokens = self.uncond_tokens.expand(BN, -1, -1)
 
-        # Sample latents via flow matching
-        latent_shape = (B, self._latent_channels, self._latent_size, self._latent_size)
+        # Sample latents via flow matching (each gets independent noise)
+        latent_shape = (BN, self._latent_channels, self._latent_size, self._latent_size)
         z = self.flow_matcher.sample(
             self.dit,
             shape=latent_shape,
@@ -192,12 +204,18 @@ class Brain2Image(nn.Module):
         images = self.vae.decode(z)
         images = images.clamp(0, 1)
 
+        # Reshape to (B, num_samples, C, H, W) when generating multiple
+        if num_samples > 1:
+            C, H, W = images.shape[1:]
+            images = images.view(B, num_samples, C, H, W)
+
         return ReconstructionResult(
             modality=Modality.IMAGE,
             output=images,
-            brain_condition=brain_global,
+            brain_condition=brain_global[:B],
             n_steps=num_steps,
             cfg_scale=cfg_scale,
+            metadata={"num_samples": num_samples},
         )
 
 
```
src/cortexflow/brain2text.py

```diff
@@ -190,16 +190,22 @@ class Brain2Text(nn.Module):
 
         Args:
             text_tokens: ``(B, T)`` target token IDs (byte-level).
+                Raw text bytes — BOS is prepended automatically.
             brain_data: Corresponding fMRI data.
 
         Returns:
             Scalar cross-entropy loss.
         """
+        B, T = text_tokens.shape
         _, brain_tokens = self.brain_encoder(brain_data.voxels)
 
-        # Teacher forcing: shift inputs/targets by one
-        input_tokens = text_tokens[:, :-1]
-        target_tokens = text_tokens[:, 1:]
+        # Prepend BOS so the model learns to predict from BOS context
+        # Input: [BOS, t1, t2, ..., t_{T-1}], Target: [t1, t2, ..., t_T]
+        bos = torch.full(
+            (B, 1), self.bos_token, dtype=torch.long, device=text_tokens.device
+        )
+        input_tokens = torch.cat([bos, text_tokens[:, :-1]], dim=1)
+        target_tokens = text_tokens
 
         logits = self.decoder(input_tokens, brain_tokens)
         return F.cross_entropy(
@@ -215,6 +221,8 @@ class Brain2Text(nn.Module):
         max_len: int | None = None,
         temperature: float = 0.8,
         top_k: int = 50,
+        top_p: float = 0.0,
+        num_samples: int = 1,
     ) -> ReconstructionResult:
         """Reconstruct text from brain activity via autoregressive decoding.
 
@@ -222,7 +230,16 @@ class Brain2Text(nn.Module):
             brain_data: fMRI data to decode.
             max_len: Maximum generation length.
             temperature: Sampling temperature.
-            top_k: Top-k filtering.
+            top_k: Top-k filtering (0 to disable).
+            top_p: Nucleus sampling threshold (0.0 to disable). When set,
+                only the smallest set of tokens with cumulative probability
+                >= ``top_p`` are kept. Promotes semantic diversity.
+            num_samples: Number of diverse samples per brain input.
+                Each sample decodes independently with different random
+                draws, producing semantically varied texts. When
+                ``num_samples > 1``, ``metadata["texts"]`` is a list of
+                lists: ``texts[i]`` contains ``num_samples`` strings for
+                brain input *i*.
 
         Returns:
             ReconstructionResult with generated text as metadata.
@@ -233,8 +250,14 @@ class Brain2Text(nn.Module):
 
         _, brain_tokens = self.brain_encoder(brain_data.voxels)
 
+        # Repeat conditioning for multiple samples per input
+        if num_samples > 1:
+            brain_tokens = brain_tokens.repeat_interleave(num_samples, dim=0)
+
+        BN = B * num_samples
+
         # Start with BOS token
-        generated = torch.full((B, 1), self.bos_token, dtype=torch.long, device=device)
+        generated = torch.full((BN, 1), self.bos_token, dtype=torch.long, device=device)
 
         for _ in range(gen_len - 1):
             logits = self.decoder(generated, brain_tokens)
@@ -242,10 +265,20 @@ class Brain2Text(nn.Module):
 
             # Top-k filtering
             if top_k > 0:
-                topk_vals, _ = next_logits.topk(top_k, dim=-1)
+                topk_vals, _ = next_logits.topk(min(top_k, next_logits.size(-1)), dim=-1)
                 threshold = topk_vals[:, -1].unsqueeze(-1)
                 next_logits = next_logits.masked_fill(next_logits < threshold, float("-inf"))
 
+            # Nucleus (top-p) filtering
+            if top_p > 0.0:
+                sorted_logits, sorted_idx = next_logits.sort(dim=-1, descending=True)
+                cum_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
+                # Remove tokens with cumulative probability above top_p
+                mask = cum_probs - sorted_logits.softmax(dim=-1) >= top_p
+                sorted_logits[mask] = float("-inf")
+                # Scatter back
+                next_logits = sorted_logits.scatter(1, sorted_idx, sorted_logits)
+
             probs = F.softmax(next_logits, dim=-1)
             next_token = torch.multinomial(probs, num_samples=1)
             generated = torch.cat([generated, next_token], dim=1)
@@ -254,14 +287,23 @@ class Brain2Text(nn.Module):
             if (next_token == 0).all():
                 break
 
-        # Decode to text
-        texts = [self.tokens_to_text(generated[i]) for i in range(B)]
+        # Decode to text (skip BOS token at position 0)
+        if num_samples > 1:
+            # Group samples: texts[i] = list of num_samples strings
+            texts = []
+            for i in range(B):
+                group = []
+                for s in range(num_samples):
+                    group.append(self.tokens_to_text(generated[i * num_samples + s, 1:]))
+                texts.append(group)
+        else:
+            texts = [self.tokens_to_text(generated[i, 1:]) for i in range(B)]
 
         return ReconstructionResult(
             modality=Modality.TEXT,
             output=generated,
-            brain_condition=brain_tokens.mean(dim=1),
-            metadata={"texts": texts},
+            brain_condition=brain_tokens[:B].mean(dim=1),
+            metadata={"texts": texts, "num_samples": num_samples},
         )
 
 
```
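One subtlety in the nucleus filter above: the mask is `cum_probs - probs >= top_p` rather than `cum_probs > top_p`, so the token that crosses the threshold is kept and the retained set always reaches cumulative mass of at least `top_p`. A standalone numeric check with toy logits, not package code:

```python
# With sorted probs ~ [0.49, 0.30, 0.18, 0.02] and top_p = 0.6, the mask
# comes out [False, False, True, True]: the 0.30 token that crosses the
# 0.6 boundary survives, so the kept mass (~0.79) is >= top_p.
import torch

logits = torch.tensor([[2.0, 1.5, 1.0, -1.0]])
sorted_logits, sorted_idx = logits.sort(dim=-1, descending=True)
probs = sorted_logits.softmax(dim=-1)
cum_probs = probs.cumsum(dim=-1)
mask = cum_probs - probs >= 0.6      # drop only tokens past the boundary
sorted_logits[mask] = float("-inf")
next_logits = sorted_logits.scatter(1, sorted_idx, sorted_logits)  # unsort
print(mask)         # tensor([[False, False,  True,  True]])
print(next_logits)  # tensor([[2.0000, 1.5000,    -inf,    -inf]])
```

A naive `cum_probs > top_p` mask would keep only the 0.49 token here, retaining less probability mass than requested.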
src/cortexflow/dit.py

```diff
@@ -330,7 +330,11 @@ class DiffusionTransformer(nn.Module):
         self._initialize_weights()
 
     def _initialize_weights(self) -> None:
-        """Initialize weights following DiT conventions."""
+        """Initialize weights following DiT conventions.
+
+        Global Xavier init first, then re-apply zero-init on gating
+        parameters (AdaLN modulation, final layer) for stable training.
+        """
 
         def _init(m: nn.Module) -> None:
             if isinstance(m, nn.Linear):
@@ -344,6 +348,15 @@ class DiffusionTransformer(nn.Module):
 
         self.apply(_init)
 
+        # Re-apply zero-init on gating parameters (overwritten by global init)
+        for block in self.blocks:
+            nn.init.zeros_(block.adaLN_modulation[-1].weight)
+            nn.init.zeros_(block.adaLN_modulation[-1].bias)
+        nn.init.zeros_(self.final_layer.adaLN[-1].weight)
+        nn.init.zeros_(self.final_layer.adaLN[-1].bias)
+        nn.init.zeros_(self.final_layer.proj.weight)
+        nn.init.zeros_(self.final_layer.proj.bias)
+
     def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
         """Reshape patch tokens back to spatial latent maps.
 
```
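The 0.1.1 fix in this hunk is an ordering issue: `self.apply(_init)` visits every `nn.Linear`, including the gating layers that were deliberately zero-initialized beforehand, so the zeros must be re-applied afterwards. A minimal standalone repro with a plain `nn.Linear`, illustrative rather than the package's classes:

```python
# apply() runs the global initializer over every submodule, clobbering an
# earlier zero-init; re-applying the zeros last restores exact-zero output.
import torch.nn as nn

proj = nn.Linear(4, 4)
nn.init.zeros_(proj.weight)   # intended zero-init (gating / final projection)
nn.init.zeros_(proj.bias)

def _init(m: nn.Module) -> None:
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)  # global init, as in _initialize_weights

proj.apply(_init)
print(proj.weight.abs().sum().item())  # > 0: the zeros were overwritten

nn.init.zeros_(proj.weight)            # the fix: zero-init re-applied last
print(proj.weight.abs().sum().item())  # 0.0, so the layer's output is exactly zero
```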
tests/test_brain2audio.py

```diff
@@ -91,6 +91,26 @@ class TestBrain2Audio:
         assert result.output.shape[0] == 1
 
 
+    def test_reconstruct_num_samples(self, model, brain_data):
+        model.eval()
+        result = model.reconstruct(brain_data, num_steps=2, num_samples=3)
+        assert result.output.shape == (BATCH, 3, N_MELS, AUDIO_LEN)
+        assert result.metadata["num_samples"] == 3
+
+    def test_diverse_audio_samples_differ(self, model, brain_data):
+        """Multiple audio samples from same brain input should differ."""
+        model.eval()
+        result = model.reconstruct(brain_data, num_steps=2, num_samples=2)
+        sample_0 = result.output[:, 0]
+        sample_1 = result.output[:, 1]
+        assert not torch.allclose(sample_0, sample_1), "Diverse samples should differ"
+
+    def test_reconstruct_num_samples_1(self, model, brain_data):
+        model.eval()
+        result = model.reconstruct(brain_data, num_steps=2, num_samples=1)
+        assert result.output.shape == (BATCH, N_MELS, AUDIO_LEN)
+
+
 class TestMelToWaveform:
     def test_output_shape(self):
         mel = torch.rand(1, N_MELS, 16).abs() + 0.01
```
tests/test_brain2img.py

```diff
@@ -79,6 +79,27 @@ class TestBrain2Image:
         assert result.output.shape[0] == 1
 
 
+    def test_reconstruct_num_samples(self, model, brain_data):
+        model.eval()
+        result = model.reconstruct(brain_data, num_steps=2, num_samples=3)
+        assert result.output.shape == (BATCH, 3, 3, IMG_SIZE, IMG_SIZE)
+        assert result.metadata["num_samples"] == 3
+
+    def test_diverse_samples_differ(self, model, brain_data):
+        """Multiple samples from same brain input should differ."""
+        model.eval()
+        result = model.reconstruct(brain_data, num_steps=2, num_samples=2)
+        sample_0 = result.output[:, 0]
+        sample_1 = result.output[:, 1]
+        assert not torch.allclose(sample_0, sample_1), "Diverse samples should differ"
+
+    def test_reconstruct_num_samples_1(self, model, brain_data):
+        """num_samples=1 should behave like the default."""
+        model.eval()
+        result = model.reconstruct(brain_data, num_steps=2, num_samples=1)
+        assert result.output.shape == (BATCH, 3, IMG_SIZE, IMG_SIZE)
+
+
 class TestBuildBrain2Img:
     def test_default_build(self):
         model = build_brain2img(n_voxels=64, img_size=8, hidden_dim=16, depth=1, num_heads=4)
```
tests/test_brain2text.py

```diff
@@ -142,6 +142,54 @@ class TestBrain2Text:
         assert len(result.metadata["texts"]) == 1
 
 
+    def test_reconstruct_num_samples(self, model, brain_data):
+        model.eval()
+        result = model.reconstruct(brain_data, max_len=8, num_samples=3)
+        texts = result.metadata["texts"]
+        assert len(texts) == BATCH
+        for group in texts:
+            assert isinstance(group, list)
+            assert len(group) == 3
+            for t in group:
+                assert isinstance(t, str)
+
+    def test_diverse_text_samples_differ(self, model, brain_data):
+        """Multiple text samples from same brain input should differ."""
+        model.eval()
+        # High temperature for max diversity
+        result = model.reconstruct(
+            brain_data, max_len=8, temperature=1.5, num_samples=4,
+        )
+        texts = result.metadata["texts"]
+        # At least one brain input should produce non-identical samples
+        any_differ = False
+        for group in texts:
+            if len(set(group)) > 1:
+                any_differ = True
+                break
+        assert any_differ, "At high temperature, diverse samples should differ"
+
+    def test_reconstruct_top_p(self, model, brain_data):
+        model.eval()
+        result = model.reconstruct(brain_data, max_len=8, top_p=0.9, top_k=0)
+        assert result.output.shape[0] == BATCH
+        assert len(result.metadata["texts"]) == BATCH
+
+    def test_reconstruct_top_p_and_top_k(self, model, brain_data):
+        model.eval()
+        result = model.reconstruct(brain_data, max_len=8, top_p=0.9, top_k=20)
+        assert result.output.shape[0] == BATCH
+
+    def test_reconstruct_num_samples_1(self, model, brain_data):
+        """num_samples=1 should return flat list of strings."""
+        model.eval()
+        result = model.reconstruct(brain_data, max_len=8, num_samples=1)
+        texts = result.metadata["texts"]
+        assert len(texts) == BATCH
+        for t in texts:
+            assert isinstance(t, str)
+
+
 class TestBuildBrain2Text:
     def test_default_build(self):
         model = build_brain2text(n_voxels=64, max_len=16, hidden_dim=16, depth=1)
```