cortexflowx 0.1.0__tar.gz → 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33) hide show
  1. {cortexflowx-0.1.0 → cortexflowx-0.2.0}/CHANGELOG.md +16 -0
  2. {cortexflowx-0.1.0 → cortexflowx-0.2.0}/PKG-INFO +3 -3
  3. {cortexflowx-0.1.0 → cortexflowx-0.2.0}/README.md +2 -2
  4. {cortexflowx-0.1.0 → cortexflowx-0.2.0}/pyproject.toml +1 -1
  5. {cortexflowx-0.1.0 → cortexflowx-0.2.0}/src/cortexflow/brain2audio.py +32 -5
  6. {cortexflowx-0.1.0 → cortexflowx-0.2.0}/src/cortexflow/brain2img.py +24 -6
  7. {cortexflowx-0.1.0 → cortexflowx-0.2.0}/src/cortexflow/brain2text.py +52 -10
  8. {cortexflowx-0.1.0 → cortexflowx-0.2.0}/src/cortexflow/dit.py +14 -1
  9. {cortexflowx-0.1.0 → cortexflowx-0.2.0}/tests/test_brain2audio.py +20 -0
  10. {cortexflowx-0.1.0 → cortexflowx-0.2.0}/tests/test_brain2img.py +21 -0
  11. {cortexflowx-0.1.0 → cortexflowx-0.2.0}/tests/test_brain2text.py +48 -0
  12. {cortexflowx-0.1.0 → cortexflowx-0.2.0}/.github/workflows/ci.yml +0 -0
  13. {cortexflowx-0.1.0 → cortexflowx-0.2.0}/.gitignore +0 -0
  14. {cortexflowx-0.1.0 → cortexflowx-0.2.0}/CITATION.cff +0 -0
  15. {cortexflowx-0.1.0 → cortexflowx-0.2.0}/LICENSE +0 -0
  16. {cortexflowx-0.1.0 → cortexflowx-0.2.0}/examples/brain2audio_demo.py +0 -0
  17. {cortexflowx-0.1.0 → cortexflowx-0.2.0}/examples/brain2img_demo.py +0 -0
  18. {cortexflowx-0.1.0 → cortexflowx-0.2.0}/examples/brain2text_demo.py +0 -0
  19. {cortexflowx-0.1.0 → cortexflowx-0.2.0}/src/cortexflow/__init__.py +0 -0
  20. {cortexflowx-0.1.0 → cortexflowx-0.2.0}/src/cortexflow/_types.py +0 -0
  21. {cortexflowx-0.1.0 → cortexflowx-0.2.0}/src/cortexflow/brain_encoder.py +0 -0
  22. {cortexflowx-0.1.0 → cortexflowx-0.2.0}/src/cortexflow/flow_matching.py +0 -0
  23. {cortexflowx-0.1.0 → cortexflowx-0.2.0}/src/cortexflow/training.py +0 -0
  24. {cortexflowx-0.1.0 → cortexflowx-0.2.0}/src/cortexflow/vae.py +0 -0
  25. {cortexflowx-0.1.0 → cortexflowx-0.2.0}/tests/conftest.py +0 -0
  26. {cortexflowx-0.1.0 → cortexflowx-0.2.0}/tests/test_brain_encoder.py +0 -0
  27. {cortexflowx-0.1.0 → cortexflowx-0.2.0}/tests/test_dit.py +0 -0
  28. {cortexflowx-0.1.0 → cortexflowx-0.2.0}/tests/test_flow_matching.py +0 -0
  29. {cortexflowx-0.1.0 → cortexflowx-0.2.0}/tests/test_init.py +0 -0
  30. {cortexflowx-0.1.0 → cortexflowx-0.2.0}/tests/test_integration.py +0 -0
  31. {cortexflowx-0.1.0 → cortexflowx-0.2.0}/tests/test_training.py +0 -0
  32. {cortexflowx-0.1.0 → cortexflowx-0.2.0}/tests/test_types.py +0 -0
  33. {cortexflowx-0.1.0 → cortexflowx-0.2.0}/tests/test_vae.py +0 -0
@@ -2,6 +2,22 @@
2
2
 
3
3
  All notable changes to CortexFlow will be documented in this file.
4
4
 
5
+ ## [0.2.0] - 2026-04-12
6
+
7
+ ### Added
8
+
9
+ - **Semantic diversity**: All pipelines now support a `num_samples` parameter to generate multiple diverse reconstructions per brain input. Each sample uses independent noise (image/audio) or independent random draws (text), producing semantically varied outputs.
10
+ - **Nucleus (top-p) sampling** for Brain2Text: the `top_p` parameter enables nucleus filtering alongside top-k, giving finer control over text generation diversity.
11
+ - When `num_samples > 1`, output shapes become `(B, num_samples, ...)` for image/audio; text metadata returns grouped lists per brain input.
12
+
13
+ ## [0.1.1] - 2025-06-26
14
+
15
+ ### Fixed
16
+
17
+ - **DiT zero-init**: `_initialize_weights()` no longer overwrites critical zero-initialized gating parameters (AdaLN modulation, final layer projection). Output is now exactly zero at initialization for stable training.
18
+ - **Brain2Text BOS mismatch**: Training now prepends a BOS token so the model learns to predict from BOS context, matching inference behavior.
19
+ - **Brain2Text empty output**: `reconstruct()` now correctly skips the BOS token when decoding generated sequences to text.
20
+
5
21
  ## [0.1.0] - 2025-06-25
6
22
 
7
23
  ### Added
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: cortexflowx
3
- Version: 0.1.0
3
+ Version: 0.2.0
4
4
  Summary: Brain-to-image/audio/text reconstruction using Diffusion Transformers and Flow Matching. Decode what someone saw, heard, or thought from fMRI.
5
5
  Project-URL: Homepage, https://github.com/stef41/cortexflow
6
6
  Project-URL: Repository, https://github.com/stef41/cortexflow
@@ -65,12 +65,12 @@ fMRI voxels
65
65
  ## Installation
66
66
 
67
67
  ```bash
68
- pip install cortexflow
68
+ pip install cortexflowx
69
69
  ```
70
70
 
71
71
  With audio support:
72
72
  ```bash
73
- pip install cortexflow[audio]
73
+ pip install cortexflowx[audio]
74
74
  ```
75
75
 
76
76
  ## Quick Start
@@ -30,12 +30,12 @@ fMRI voxels
30
30
  ## Installation
31
31
 
32
32
  ```bash
33
- pip install cortexflow
33
+ pip install cortexflowx
34
34
  ```
35
35
 
36
36
  With audio support:
37
37
  ```bash
38
- pip install cortexflow[audio]
38
+ pip install cortexflowx[audio]
39
39
  ```
40
40
 
41
41
  ## Quick Start
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
6
  name = "cortexflowx"
7
- version = "0.1.0"
7
+ version = "0.2.0"
8
8
  description = "Brain-to-image/audio/text reconstruction using Diffusion Transformers and Flow Matching. Decode what someone saw, heard, or thought from fMRI."
9
9
  readme = "README.md"
10
10
  license = {text = "Apache-2.0"}
@@ -211,25 +211,52 @@ class Brain2Audio(nn.Module):
211
211
  brain_data: BrainData,
212
212
  num_steps: int = 50,
213
213
  cfg_scale: float = 3.0,
214
+ num_samples: int = 1,
214
215
  ) -> ReconstructionResult:
215
- """Reconstruct audio mel spectrogram from brain activity."""
216
+ """Reconstruct audio mel spectrogram from brain activity.
217
+
218
+ Args:
219
+ brain_data: fMRI data to decode.
220
+ num_steps: Number of ODE solver steps.
221
+ cfg_scale: Classifier-free guidance scale.
222
+ num_samples: Number of diverse samples per brain input.
223
+ Each sample uses independent noise, producing semantically
224
+ varied reconstructions. Output shape becomes
225
+ ``(B, num_samples, n_mels, T)`` when ``num_samples > 1``.
226
+
227
+ Returns:
228
+ ReconstructionResult with the decoded mel spectrogram(s).
229
+ """
216
230
  B = brain_data.batch_size
217
231
  device = brain_data.voxels.device
218
232
  brain_global, brain_tokens = self.brain_encoder(brain_data.voxels)
219
233
 
220
- mel_shape = (B, self.n_mels, self.audio_len)
234
+ # Repeat conditioning for multiple samples per input
235
+ if num_samples > 1:
236
+ brain_global = brain_global.repeat_interleave(num_samples, dim=0)
237
+ brain_tokens = brain_tokens.repeat_interleave(num_samples, dim=0)
238
+
239
+ BN = B * num_samples
240
+
241
+ mel_shape = (BN, self.n_mels, self.audio_len)
221
242
  mel = self.flow_matcher.sample(
222
243
  self.dit, mel_shape, brain_global, brain_tokens,
223
244
  num_steps=num_steps, cfg_scale=cfg_scale,
224
- brain_global_uncond=self.uncond_global.expand(B, -1),
225
- brain_tokens_uncond=self.uncond_tokens.expand(B, -1, -1),
245
+ brain_global_uncond=self.uncond_global.expand(BN, -1),
246
+ brain_tokens_uncond=self.uncond_tokens.expand(BN, -1, -1),
226
247
  )
248
+
249
+ # Reshape to (B, num_samples, n_mels, T) when generating multiple
250
+ if num_samples > 1:
251
+ mel = mel.view(B, num_samples, self.n_mels, self.audio_len)
252
+
227
253
  return ReconstructionResult(
228
254
  modality=Modality.AUDIO,
229
255
  output=mel,
230
- brain_condition=brain_global,
256
+ brain_condition=brain_global[:B],
231
257
  n_steps=num_steps,
232
258
  cfg_scale=cfg_scale,
259
+ metadata={"num_samples": num_samples},
233
260
  )
234
261
 
235
262
  @staticmethod
@@ -154,6 +154,7 @@ class Brain2Image(nn.Module):
154
154
  brain_data: BrainData,
155
155
  num_steps: int = 50,
156
156
  cfg_scale: float = 4.0,
157
+ num_samples: int = 1,
157
158
  ) -> ReconstructionResult:
158
159
  """Reconstruct an image from brain activity.
159
160
 
@@ -161,9 +162,13 @@ class Brain2Image(nn.Module):
161
162
  brain_data: fMRI data to decode.
162
163
  num_steps: Number of ODE solver steps.
163
164
  cfg_scale: Classifier-free guidance scale.
165
+ num_samples: Number of diverse samples per brain input.
166
+ Each sample uses independent noise, producing semantically
167
+ varied reconstructions. Output shape becomes
168
+ ``(B, num_samples, C, H, W)`` when ``num_samples > 1``.
164
169
 
165
170
  Returns:
166
- ReconstructionResult with the decoded image.
171
+ ReconstructionResult with the decoded image(s).
167
172
  """
168
173
  B = brain_data.batch_size
169
174
  device = brain_data.voxels.device
@@ -171,12 +176,19 @@ class Brain2Image(nn.Module):
171
176
  # Encode brain
172
177
  brain_global, brain_tokens = self.encode_brain(brain_data)
173
178
 
179
+ # Repeat conditioning for multiple samples per input
180
+ if num_samples > 1:
181
+ brain_global = brain_global.repeat_interleave(num_samples, dim=0)
182
+ brain_tokens = brain_tokens.repeat_interleave(num_samples, dim=0)
183
+
184
+ BN = B * num_samples
185
+
174
186
  # Unconditional embeddings for CFG
175
- uncond_global = self.uncond_global.expand(B, -1)
176
- uncond_tokens = self.uncond_tokens.expand(B, -1, -1)
187
+ uncond_global = self.uncond_global.expand(BN, -1)
188
+ uncond_tokens = self.uncond_tokens.expand(BN, -1, -1)
177
189
 
178
- # Sample latents via flow matching
179
- latent_shape = (B, self._latent_channels, self._latent_size, self._latent_size)
190
+ # Sample latents via flow matching (each gets independent noise)
191
+ latent_shape = (BN, self._latent_channels, self._latent_size, self._latent_size)
180
192
  z = self.flow_matcher.sample(
181
193
  self.dit,
182
194
  shape=latent_shape,
@@ -192,12 +204,18 @@ class Brain2Image(nn.Module):
192
204
  images = self.vae.decode(z)
193
205
  images = images.clamp(0, 1)
194
206
 
207
+ # Reshape to (B, num_samples, C, H, W) when generating multiple
208
+ if num_samples > 1:
209
+ C, H, W = images.shape[1:]
210
+ images = images.view(B, num_samples, C, H, W)
211
+
195
212
  return ReconstructionResult(
196
213
  modality=Modality.IMAGE,
197
214
  output=images,
198
- brain_condition=brain_global,
215
+ brain_condition=brain_global[:B],
199
216
  n_steps=num_steps,
200
217
  cfg_scale=cfg_scale,
218
+ metadata={"num_samples": num_samples},
201
219
  )
202
220
 
203
221
 
@@ -190,16 +190,22 @@ class Brain2Text(nn.Module):
190
190
 
191
191
  Args:
192
192
  text_tokens: ``(B, T)`` target token IDs (byte-level).
193
+ Raw text bytes — BOS is prepended automatically.
193
194
  brain_data: Corresponding fMRI data.
194
195
 
195
196
  Returns:
196
197
  Scalar cross-entropy loss.
197
198
  """
199
+ B, T = text_tokens.shape
198
200
  _, brain_tokens = self.brain_encoder(brain_data.voxels)
199
201
 
200
- # Input: [BOS, t1, t2, ..., t_{n-1}], Target: [t1, t2, ..., t_n]
201
- input_tokens = text_tokens[:, :-1]
202
- target_tokens = text_tokens[:, 1:]
202
+ # Prepend BOS so the model learns to predict from BOS context
203
+ # Input: [BOS, t1, t2, ..., t_{T-1}], Target: [t1, t2, ..., t_T]
204
+ bos = torch.full(
205
+ (B, 1), self.bos_token, dtype=torch.long, device=text_tokens.device
206
+ )
207
+ input_tokens = torch.cat([bos, text_tokens[:, :-1]], dim=1)
208
+ target_tokens = text_tokens
203
209
 
204
210
  logits = self.decoder(input_tokens, brain_tokens)
205
211
  return F.cross_entropy(
@@ -215,6 +221,8 @@ class Brain2Text(nn.Module):
215
221
  max_len: int | None = None,
216
222
  temperature: float = 0.8,
217
223
  top_k: int = 50,
224
+ top_p: float = 0.0,
225
+ num_samples: int = 1,
218
226
  ) -> ReconstructionResult:
219
227
  """Reconstruct text from brain activity via autoregressive decoding.
220
228
 
@@ -222,7 +230,16 @@ class Brain2Text(nn.Module):
222
230
  brain_data: fMRI data to decode.
223
231
  max_len: Maximum generation length.
224
232
  temperature: Sampling temperature.
225
- top_k: Top-k filtering.
233
+ top_k: Top-k filtering (0 to disable).
234
+ top_p: Nucleus sampling threshold (0.0 to disable). When set,
235
+ only the smallest set of tokens with cumulative probability
236
+ >= ``top_p`` are kept. Promotes semantic diversity.
237
+ num_samples: Number of diverse samples per brain input.
238
+ Each sample decodes independently with different random
239
+ draws, producing semantically varied texts. When
240
+ ``num_samples > 1``, ``metadata["texts"]`` is a list of
241
+ lists: ``texts[i]`` contains ``num_samples`` strings for
242
+ brain input *i*.
226
243
 
227
244
  Returns:
228
245
  ReconstructionResult with generated text as metadata.
@@ -233,8 +250,14 @@ class Brain2Text(nn.Module):
233
250
 
234
251
  _, brain_tokens = self.brain_encoder(brain_data.voxels)
235
252
 
253
+ # Repeat conditioning for multiple samples per input
254
+ if num_samples > 1:
255
+ brain_tokens = brain_tokens.repeat_interleave(num_samples, dim=0)
256
+
257
+ BN = B * num_samples
258
+
236
259
  # Start with BOS token
237
- generated = torch.full((B, 1), self.bos_token, dtype=torch.long, device=device)
260
+ generated = torch.full((BN, 1), self.bos_token, dtype=torch.long, device=device)
238
261
 
239
262
  for _ in range(gen_len - 1):
240
263
  logits = self.decoder(generated, brain_tokens)
@@ -242,10 +265,20 @@ class Brain2Text(nn.Module):
242
265
 
243
266
  # Top-k filtering
244
267
  if top_k > 0:
245
- topk_vals, _ = next_logits.topk(top_k, dim=-1)
268
+ topk_vals, _ = next_logits.topk(min(top_k, next_logits.size(-1)), dim=-1)
246
269
  threshold = topk_vals[:, -1].unsqueeze(-1)
247
270
  next_logits = next_logits.masked_fill(next_logits < threshold, float("-inf"))
248
271
 
272
+ # Nucleus (top-p) filtering
273
+ if top_p > 0.0:
274
+ sorted_logits, sorted_idx = next_logits.sort(dim=-1, descending=True)
275
+ cum_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
276
+ # Remove tokens with cumulative probability above top_p
277
+ mask = cum_probs - sorted_logits.softmax(dim=-1) >= top_p
278
+ sorted_logits[mask] = float("-inf")
279
+ # Scatter back
280
+ next_logits = sorted_logits.scatter(1, sorted_idx, sorted_logits)
281
+
249
282
  probs = F.softmax(next_logits, dim=-1)
250
283
  next_token = torch.multinomial(probs, num_samples=1)
251
284
  generated = torch.cat([generated, next_token], dim=1)
@@ -254,14 +287,23 @@ class Brain2Text(nn.Module):
254
287
  if (next_token == 0).all():
255
288
  break
256
289
 
257
- # Decode to text
258
- texts = [self.tokens_to_text(generated[i]) for i in range(B)]
290
+ # Decode to text (skip BOS token at position 0)
291
+ if num_samples > 1:
292
+ # Group samples: texts[i] = list of num_samples strings
293
+ texts = []
294
+ for i in range(B):
295
+ group = []
296
+ for s in range(num_samples):
297
+ group.append(self.tokens_to_text(generated[i * num_samples + s, 1:]))
298
+ texts.append(group)
299
+ else:
300
+ texts = [self.tokens_to_text(generated[i, 1:]) for i in range(B)]
259
301
 
260
302
  return ReconstructionResult(
261
303
  modality=Modality.TEXT,
262
304
  output=generated,
263
- brain_condition=brain_tokens.mean(dim=1),
264
- metadata={"texts": texts},
305
+ brain_condition=brain_tokens[:B].mean(dim=1),
306
+ metadata={"texts": texts, "num_samples": num_samples},
265
307
  )
266
308
 
267
309
 
@@ -330,7 +330,11 @@ class DiffusionTransformer(nn.Module):
330
330
  self._initialize_weights()
331
331
 
332
332
  def _initialize_weights(self) -> None:
333
- """Initialize weights following DiT conventions."""
333
+ """Initialize weights following DiT conventions.
334
+
335
+ Global Xavier init first, then re-apply zero-init on gating
336
+ parameters (AdaLN modulation, final layer) for stable training.
337
+ """
334
338
 
335
339
  def _init(m: nn.Module) -> None:
336
340
  if isinstance(m, nn.Linear):
@@ -344,6 +348,15 @@ class DiffusionTransformer(nn.Module):
344
348
 
345
349
  self.apply(_init)
346
350
 
351
+ # Re-apply zero-init on gating parameters (overwritten by global init)
352
+ for block in self.blocks:
353
+ nn.init.zeros_(block.adaLN_modulation[-1].weight)
354
+ nn.init.zeros_(block.adaLN_modulation[-1].bias)
355
+ nn.init.zeros_(self.final_layer.adaLN[-1].weight)
356
+ nn.init.zeros_(self.final_layer.adaLN[-1].bias)
357
+ nn.init.zeros_(self.final_layer.proj.weight)
358
+ nn.init.zeros_(self.final_layer.proj.bias)
359
+
347
360
  def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
348
361
  """Reshape patch tokens back to spatial latent maps.
349
362
 
@@ -91,6 +91,26 @@ class TestBrain2Audio:
91
91
  assert result.output.shape[0] == 1
92
92
 
93
93
 
94
+ def test_reconstruct_num_samples(self, model, brain_data):
95
+ model.eval()
96
+ result = model.reconstruct(brain_data, num_steps=2, num_samples=3)
97
+ assert result.output.shape == (BATCH, 3, N_MELS, AUDIO_LEN)
98
+ assert result.metadata["num_samples"] == 3
99
+
100
+ def test_diverse_audio_samples_differ(self, model, brain_data):
101
+ """Multiple audio samples from same brain input should differ."""
102
+ model.eval()
103
+ result = model.reconstruct(brain_data, num_steps=2, num_samples=2)
104
+ sample_0 = result.output[:, 0]
105
+ sample_1 = result.output[:, 1]
106
+ assert not torch.allclose(sample_0, sample_1), "Diverse samples should differ"
107
+
108
+ def test_reconstruct_num_samples_1(self, model, brain_data):
109
+ model.eval()
110
+ result = model.reconstruct(brain_data, num_steps=2, num_samples=1)
111
+ assert result.output.shape == (BATCH, N_MELS, AUDIO_LEN)
112
+
113
+
94
114
  class TestMelToWaveform:
95
115
  def test_output_shape(self):
96
116
  mel = torch.rand(1, N_MELS, 16).abs() + 0.01
@@ -79,6 +79,27 @@ class TestBrain2Image:
79
79
  assert result.output.shape[0] == 1
80
80
 
81
81
 
82
+ def test_reconstruct_num_samples(self, model, brain_data):
83
+ model.eval()
84
+ result = model.reconstruct(brain_data, num_steps=2, num_samples=3)
85
+ assert result.output.shape == (BATCH, 3, 3, IMG_SIZE, IMG_SIZE)
86
+ assert result.metadata["num_samples"] == 3
87
+
88
+ def test_diverse_samples_differ(self, model, brain_data):
89
+ """Multiple samples from same brain input should differ."""
90
+ model.eval()
91
+ result = model.reconstruct(brain_data, num_steps=2, num_samples=2)
92
+ sample_0 = result.output[:, 0]
93
+ sample_1 = result.output[:, 1]
94
+ assert not torch.allclose(sample_0, sample_1), "Diverse samples should differ"
95
+
96
+ def test_reconstruct_num_samples_1(self, model, brain_data):
97
+ """num_samples=1 should behave like the default."""
98
+ model.eval()
99
+ result = model.reconstruct(brain_data, num_steps=2, num_samples=1)
100
+ assert result.output.shape == (BATCH, 3, IMG_SIZE, IMG_SIZE)
101
+
102
+
82
103
  class TestBuildBrain2Img:
83
104
  def test_default_build(self):
84
105
  model = build_brain2img(n_voxels=64, img_size=8, hidden_dim=16, depth=1, num_heads=4)
@@ -142,6 +142,54 @@ class TestBrain2Text:
142
142
  assert len(result.metadata["texts"]) == 1
143
143
 
144
144
 
145
+ def test_reconstruct_num_samples(self, model, brain_data):
146
+ model.eval()
147
+ result = model.reconstruct(brain_data, max_len=8, num_samples=3)
148
+ texts = result.metadata["texts"]
149
+ assert len(texts) == BATCH
150
+ for group in texts:
151
+ assert isinstance(group, list)
152
+ assert len(group) == 3
153
+ for t in group:
154
+ assert isinstance(t, str)
155
+
156
+ def test_diverse_text_samples_differ(self, model, brain_data):
157
+ """Multiple text samples from same brain input should differ."""
158
+ model.eval()
159
+ # High temperature for max diversity
160
+ result = model.reconstruct(
161
+ brain_data, max_len=8, temperature=1.5, num_samples=4,
162
+ )
163
+ texts = result.metadata["texts"]
164
+ # At least one brain input should produce non-identical samples
165
+ any_differ = False
166
+ for group in texts:
167
+ if len(set(group)) > 1:
168
+ any_differ = True
169
+ break
170
+ assert any_differ, "At high temperature, diverse samples should differ"
171
+
172
+ def test_reconstruct_top_p(self, model, brain_data):
173
+ model.eval()
174
+ result = model.reconstruct(brain_data, max_len=8, top_p=0.9, top_k=0)
175
+ assert result.output.shape[0] == BATCH
176
+ assert len(result.metadata["texts"]) == BATCH
177
+
178
+ def test_reconstruct_top_p_and_top_k(self, model, brain_data):
179
+ model.eval()
180
+ result = model.reconstruct(brain_data, max_len=8, top_p=0.9, top_k=20)
181
+ assert result.output.shape[0] == BATCH
182
+
183
+ def test_reconstruct_num_samples_1(self, model, brain_data):
184
+ """num_samples=1 should return flat list of strings."""
185
+ model.eval()
186
+ result = model.reconstruct(brain_data, max_len=8, num_samples=1)
187
+ texts = result.metadata["texts"]
188
+ assert len(texts) == BATCH
189
+ for t in texts:
190
+ assert isinstance(t, str)
191
+
192
+
145
193
  class TestBuildBrain2Text:
146
194
  def test_default_build(self):
147
195
  model = build_brain2text(n_voxels=64, max_len=16, hidden_dim=16, depth=1)
File without changes
File without changes
File without changes