cortexflowx 0.1.1__tar.gz → 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33) hide show
  1. {cortexflowx-0.1.1 → cortexflowx-0.2.0}/CHANGELOG.md +8 -0
  2. {cortexflowx-0.1.1 → cortexflowx-0.2.0}/PKG-INFO +1 -1
  3. {cortexflowx-0.1.1 → cortexflowx-0.2.0}/pyproject.toml +1 -1
  4. {cortexflowx-0.1.1 → cortexflowx-0.2.0}/src/cortexflow/brain2audio.py +32 -5
  5. {cortexflowx-0.1.1 → cortexflowx-0.2.0}/src/cortexflow/brain2img.py +24 -6
  6. {cortexflowx-0.1.1 → cortexflowx-0.2.0}/src/cortexflow/brain2text.py +42 -6
  7. {cortexflowx-0.1.1 → cortexflowx-0.2.0}/tests/test_brain2audio.py +20 -0
  8. {cortexflowx-0.1.1 → cortexflowx-0.2.0}/tests/test_brain2img.py +21 -0
  9. {cortexflowx-0.1.1 → cortexflowx-0.2.0}/tests/test_brain2text.py +48 -0
  10. {cortexflowx-0.1.1 → cortexflowx-0.2.0}/.github/workflows/ci.yml +0 -0
  11. {cortexflowx-0.1.1 → cortexflowx-0.2.0}/.gitignore +0 -0
  12. {cortexflowx-0.1.1 → cortexflowx-0.2.0}/CITATION.cff +0 -0
  13. {cortexflowx-0.1.1 → cortexflowx-0.2.0}/LICENSE +0 -0
  14. {cortexflowx-0.1.1 → cortexflowx-0.2.0}/README.md +0 -0
  15. {cortexflowx-0.1.1 → cortexflowx-0.2.0}/examples/brain2audio_demo.py +0 -0
  16. {cortexflowx-0.1.1 → cortexflowx-0.2.0}/examples/brain2img_demo.py +0 -0
  17. {cortexflowx-0.1.1 → cortexflowx-0.2.0}/examples/brain2text_demo.py +0 -0
  18. {cortexflowx-0.1.1 → cortexflowx-0.2.0}/src/cortexflow/__init__.py +0 -0
  19. {cortexflowx-0.1.1 → cortexflowx-0.2.0}/src/cortexflow/_types.py +0 -0
  20. {cortexflowx-0.1.1 → cortexflowx-0.2.0}/src/cortexflow/brain_encoder.py +0 -0
  21. {cortexflowx-0.1.1 → cortexflowx-0.2.0}/src/cortexflow/dit.py +0 -0
  22. {cortexflowx-0.1.1 → cortexflowx-0.2.0}/src/cortexflow/flow_matching.py +0 -0
  23. {cortexflowx-0.1.1 → cortexflowx-0.2.0}/src/cortexflow/training.py +0 -0
  24. {cortexflowx-0.1.1 → cortexflowx-0.2.0}/src/cortexflow/vae.py +0 -0
  25. {cortexflowx-0.1.1 → cortexflowx-0.2.0}/tests/conftest.py +0 -0
  26. {cortexflowx-0.1.1 → cortexflowx-0.2.0}/tests/test_brain_encoder.py +0 -0
  27. {cortexflowx-0.1.1 → cortexflowx-0.2.0}/tests/test_dit.py +0 -0
  28. {cortexflowx-0.1.1 → cortexflowx-0.2.0}/tests/test_flow_matching.py +0 -0
  29. {cortexflowx-0.1.1 → cortexflowx-0.2.0}/tests/test_init.py +0 -0
  30. {cortexflowx-0.1.1 → cortexflowx-0.2.0}/tests/test_integration.py +0 -0
  31. {cortexflowx-0.1.1 → cortexflowx-0.2.0}/tests/test_training.py +0 -0
  32. {cortexflowx-0.1.1 → cortexflowx-0.2.0}/tests/test_types.py +0 -0
  33. {cortexflowx-0.1.1 → cortexflowx-0.2.0}/tests/test_vae.py +0 -0
@@ -2,6 +2,14 @@
2
2
 
3
3
  All notable changes to CortexFlow will be documented in this file.
4
4
 
5
+ ## [0.2.0] - 2026-04-12
6
+
7
+ ### Added
8
+
9
+ - **Semantic diversity**: All pipelines now support a `num_samples` parameter to generate multiple diverse reconstructions per brain input. Each sample uses independent noise (image/audio) or independent random draws (text), producing semantically varied outputs.
10
+ - **Nucleus (top-p) sampling** for Brain2Text: `top_p` parameter enables nucleus filtering alongside top-k, giving finer control over text generation diversity.
11
+ - When `num_samples > 1`, output shapes become `(B, num_samples, ...)` for image/audio; for text, `metadata["texts"]` becomes a list of lists, where `texts[i]` holds the `num_samples` strings generated for brain input `i`.
12
+
5
13
  ## [0.1.1] - 2025-06-26
6
14
 
7
15
  ### Fixed
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: cortexflowx
3
- Version: 0.1.1
3
+ Version: 0.2.0
4
4
  Summary: Brain-to-image/audio/text reconstruction using Diffusion Transformers and Flow Matching. Decode what someone saw, heard, or thought from fMRI.
5
5
  Project-URL: Homepage, https://github.com/stef41/cortexflow
6
6
  Project-URL: Repository, https://github.com/stef41/cortexflow
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
6
  name = "cortexflowx"
7
- version = "0.1.1"
7
+ version = "0.2.0"
8
8
  description = "Brain-to-image/audio/text reconstruction using Diffusion Transformers and Flow Matching. Decode what someone saw, heard, or thought from fMRI."
9
9
  readme = "README.md"
10
10
  license = {text = "Apache-2.0"}
@@ -211,25 +211,52 @@ class Brain2Audio(nn.Module):
211
211
  brain_data: BrainData,
212
212
  num_steps: int = 50,
213
213
  cfg_scale: float = 3.0,
214
+ num_samples: int = 1,
214
215
  ) -> ReconstructionResult:
215
- """Reconstruct audio mel spectrogram from brain activity."""
216
+ """Reconstruct audio mel spectrogram from brain activity.
217
+
218
+ Args:
219
+ brain_data: fMRI data to decode.
220
+ num_steps: Number of ODE solver steps.
221
+ cfg_scale: Classifier-free guidance scale.
222
+ num_samples: Number of diverse samples per brain input.
223
+ Each sample uses independent noise, producing semantically
224
+ varied reconstructions. Output shape becomes
225
+ ``(B, num_samples, n_mels, T)`` when ``num_samples > 1``.
226
+
227
+ Returns:
228
+ ReconstructionResult with the decoded mel spectrogram(s).
229
+ """
216
230
  B = brain_data.batch_size
217
231
  device = brain_data.voxels.device
218
232
  brain_global, brain_tokens = self.brain_encoder(brain_data.voxels)
219
233
 
220
- mel_shape = (B, self.n_mels, self.audio_len)
234
+ # Repeat conditioning for multiple samples per input
235
+ if num_samples > 1:
236
+ brain_global = brain_global.repeat_interleave(num_samples, dim=0)
237
+ brain_tokens = brain_tokens.repeat_interleave(num_samples, dim=0)
238
+
239
+ BN = B * num_samples
240
+
241
+ mel_shape = (BN, self.n_mels, self.audio_len)
221
242
  mel = self.flow_matcher.sample(
222
243
  self.dit, mel_shape, brain_global, brain_tokens,
223
244
  num_steps=num_steps, cfg_scale=cfg_scale,
224
- brain_global_uncond=self.uncond_global.expand(B, -1),
225
- brain_tokens_uncond=self.uncond_tokens.expand(B, -1, -1),
245
+ brain_global_uncond=self.uncond_global.expand(BN, -1),
246
+ brain_tokens_uncond=self.uncond_tokens.expand(BN, -1, -1),
226
247
  )
248
+
249
+ # Reshape to (B, num_samples, n_mels, T) when generating multiple
250
+ if num_samples > 1:
251
+ mel = mel.view(B, num_samples, self.n_mels, self.audio_len)
252
+
227
253
  return ReconstructionResult(
228
254
  modality=Modality.AUDIO,
229
255
  output=mel,
230
- brain_condition=brain_global,
256
+ brain_condition=brain_global[:B],
231
257
  n_steps=num_steps,
232
258
  cfg_scale=cfg_scale,
259
+ metadata={"num_samples": num_samples},
233
260
  )
234
261
 
235
262
  @staticmethod
@@ -154,6 +154,7 @@ class Brain2Image(nn.Module):
154
154
  brain_data: BrainData,
155
155
  num_steps: int = 50,
156
156
  cfg_scale: float = 4.0,
157
+ num_samples: int = 1,
157
158
  ) -> ReconstructionResult:
158
159
  """Reconstruct an image from brain activity.
159
160
 
@@ -161,9 +162,13 @@ class Brain2Image(nn.Module):
161
162
  brain_data: fMRI data to decode.
162
163
  num_steps: Number of ODE solver steps.
163
164
  cfg_scale: Classifier-free guidance scale.
165
+ num_samples: Number of diverse samples per brain input.
166
+ Each sample uses independent noise, producing semantically
167
+ varied reconstructions. Output shape becomes
168
+ ``(B, num_samples, C, H, W)`` when ``num_samples > 1``.
164
169
 
165
170
  Returns:
166
- ReconstructionResult with the decoded image.
171
+ ReconstructionResult with the decoded image(s).
167
172
  """
168
173
  B = brain_data.batch_size
169
174
  device = brain_data.voxels.device
@@ -171,12 +176,19 @@ class Brain2Image(nn.Module):
171
176
  # Encode brain
172
177
  brain_global, brain_tokens = self.encode_brain(brain_data)
173
178
 
179
+ # Repeat conditioning for multiple samples per input
180
+ if num_samples > 1:
181
+ brain_global = brain_global.repeat_interleave(num_samples, dim=0)
182
+ brain_tokens = brain_tokens.repeat_interleave(num_samples, dim=0)
183
+
184
+ BN = B * num_samples
185
+
174
186
  # Unconditional embeddings for CFG
175
- uncond_global = self.uncond_global.expand(B, -1)
176
- uncond_tokens = self.uncond_tokens.expand(B, -1, -1)
187
+ uncond_global = self.uncond_global.expand(BN, -1)
188
+ uncond_tokens = self.uncond_tokens.expand(BN, -1, -1)
177
189
 
178
- # Sample latents via flow matching
179
- latent_shape = (B, self._latent_channels, self._latent_size, self._latent_size)
190
+ # Sample latents via flow matching (each gets independent noise)
191
+ latent_shape = (BN, self._latent_channels, self._latent_size, self._latent_size)
180
192
  z = self.flow_matcher.sample(
181
193
  self.dit,
182
194
  shape=latent_shape,
@@ -192,12 +204,18 @@ class Brain2Image(nn.Module):
192
204
  images = self.vae.decode(z)
193
205
  images = images.clamp(0, 1)
194
206
 
207
+ # Reshape to (B, num_samples, C, H, W) when generating multiple
208
+ if num_samples > 1:
209
+ C, H, W = images.shape[1:]
210
+ images = images.view(B, num_samples, C, H, W)
211
+
195
212
  return ReconstructionResult(
196
213
  modality=Modality.IMAGE,
197
214
  output=images,
198
- brain_condition=brain_global,
215
+ brain_condition=brain_global[:B],
199
216
  n_steps=num_steps,
200
217
  cfg_scale=cfg_scale,
218
+ metadata={"num_samples": num_samples},
201
219
  )
202
220
 
203
221
 
@@ -221,6 +221,8 @@ class Brain2Text(nn.Module):
221
221
  max_len: int | None = None,
222
222
  temperature: float = 0.8,
223
223
  top_k: int = 50,
224
+ top_p: float = 0.0,
225
+ num_samples: int = 1,
224
226
  ) -> ReconstructionResult:
225
227
  """Reconstruct text from brain activity via autoregressive decoding.
226
228
 
@@ -228,7 +230,16 @@ class Brain2Text(nn.Module):
228
230
  brain_data: fMRI data to decode.
229
231
  max_len: Maximum generation length.
230
232
  temperature: Sampling temperature.
231
- top_k: Top-k filtering.
233
+ top_k: Top-k filtering (0 to disable).
234
+ top_p: Nucleus sampling threshold (0.0 to disable). When set,
235
+ only the smallest set of tokens with cumulative probability
236
+ >= ``top_p`` is kept. Promotes semantic diversity.
237
+ num_samples: Number of diverse samples per brain input.
238
+ Each sample decodes independently with different random
239
+ draws, producing semantically varied texts. When
240
+ ``num_samples > 1``, ``metadata["texts"]`` is a list of
241
+ lists: ``texts[i]`` contains ``num_samples`` strings for
242
+ brain input *i*.
232
243
 
233
244
  Returns:
234
245
  ReconstructionResult with generated text as metadata.
@@ -239,8 +250,14 @@ class Brain2Text(nn.Module):
239
250
 
240
251
  _, brain_tokens = self.brain_encoder(brain_data.voxels)
241
252
 
253
+ # Repeat conditioning for multiple samples per input
254
+ if num_samples > 1:
255
+ brain_tokens = brain_tokens.repeat_interleave(num_samples, dim=0)
256
+
257
+ BN = B * num_samples
258
+
242
259
  # Start with BOS token
243
- generated = torch.full((B, 1), self.bos_token, dtype=torch.long, device=device)
260
+ generated = torch.full((BN, 1), self.bos_token, dtype=torch.long, device=device)
244
261
 
245
262
  for _ in range(gen_len - 1):
246
263
  logits = self.decoder(generated, brain_tokens)
@@ -248,10 +265,20 @@ class Brain2Text(nn.Module):
248
265
 
249
266
  # Top-k filtering
250
267
  if top_k > 0:
251
- topk_vals, _ = next_logits.topk(top_k, dim=-1)
268
+ topk_vals, _ = next_logits.topk(min(top_k, next_logits.size(-1)), dim=-1)
252
269
  threshold = topk_vals[:, -1].unsqueeze(-1)
253
270
  next_logits = next_logits.masked_fill(next_logits < threshold, float("-inf"))
254
271
 
272
+ # Nucleus (top-p) filtering
273
+ if top_p > 0.0:
274
+ sorted_logits, sorted_idx = next_logits.sort(dim=-1, descending=True)
275
+ cum_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
276
+ # Remove tokens with cumulative probability above top_p
277
+ mask = cum_probs - sorted_logits.softmax(dim=-1) >= top_p
278
+ sorted_logits[mask] = float("-inf")
279
+ # Scatter back
280
+ next_logits = sorted_logits.scatter(1, sorted_idx, sorted_logits)
281
+
255
282
  probs = F.softmax(next_logits, dim=-1)
256
283
  next_token = torch.multinomial(probs, num_samples=1)
257
284
  generated = torch.cat([generated, next_token], dim=1)
@@ -261,13 +288,22 @@ class Brain2Text(nn.Module):
261
288
  break
262
289
 
263
290
  # Decode to text (skip BOS token at position 0)
264
- texts = [self.tokens_to_text(generated[i, 1:]) for i in range(B)]
291
+ if num_samples > 1:
292
+ # Group samples: texts[i] = list of num_samples strings
293
+ texts = []
294
+ for i in range(B):
295
+ group = []
296
+ for s in range(num_samples):
297
+ group.append(self.tokens_to_text(generated[i * num_samples + s, 1:]))
298
+ texts.append(group)
299
+ else:
300
+ texts = [self.tokens_to_text(generated[i, 1:]) for i in range(B)]
265
301
 
266
302
  return ReconstructionResult(
267
303
  modality=Modality.TEXT,
268
304
  output=generated,
269
- brain_condition=brain_tokens.mean(dim=1),
270
- metadata={"texts": texts},
305
+ brain_condition=brain_tokens[:B].mean(dim=1),
306
+ metadata={"texts": texts, "num_samples": num_samples},
271
307
  )
272
308
 
273
309
 
@@ -91,6 +91,26 @@ class TestBrain2Audio:
91
91
  assert result.output.shape[0] == 1
92
92
 
93
93
 
94
+ def test_reconstruct_num_samples(self, model, brain_data):
95
+ model.eval()
96
+ result = model.reconstruct(brain_data, num_steps=2, num_samples=3)
97
+ assert result.output.shape == (BATCH, 3, N_MELS, AUDIO_LEN)
98
+ assert result.metadata["num_samples"] == 3
99
+
100
+ def test_diverse_audio_samples_differ(self, model, brain_data):
101
+ """Multiple audio samples from same brain input should differ."""
102
+ model.eval()
103
+ result = model.reconstruct(brain_data, num_steps=2, num_samples=2)
104
+ sample_0 = result.output[:, 0]
105
+ sample_1 = result.output[:, 1]
106
+ assert not torch.allclose(sample_0, sample_1), "Diverse samples should differ"
107
+
108
+ def test_reconstruct_num_samples_1(self, model, brain_data):
109
+ model.eval()
110
+ result = model.reconstruct(brain_data, num_steps=2, num_samples=1)
111
+ assert result.output.shape == (BATCH, N_MELS, AUDIO_LEN)
112
+
113
+
94
114
  class TestMelToWaveform:
95
115
  def test_output_shape(self):
96
116
  mel = torch.rand(1, N_MELS, 16).abs() + 0.01
@@ -79,6 +79,27 @@ class TestBrain2Image:
79
79
  assert result.output.shape[0] == 1
80
80
 
81
81
 
82
+ def test_reconstruct_num_samples(self, model, brain_data):
83
+ model.eval()
84
+ result = model.reconstruct(brain_data, num_steps=2, num_samples=3)
85
+ assert result.output.shape == (BATCH, 3, 3, IMG_SIZE, IMG_SIZE)
86
+ assert result.metadata["num_samples"] == 3
87
+
88
+ def test_diverse_samples_differ(self, model, brain_data):
89
+ """Multiple samples from same brain input should differ."""
90
+ model.eval()
91
+ result = model.reconstruct(brain_data, num_steps=2, num_samples=2)
92
+ sample_0 = result.output[:, 0]
93
+ sample_1 = result.output[:, 1]
94
+ assert not torch.allclose(sample_0, sample_1), "Diverse samples should differ"
95
+
96
+ def test_reconstruct_num_samples_1(self, model, brain_data):
97
+ """num_samples=1 should behave like the default."""
98
+ model.eval()
99
+ result = model.reconstruct(brain_data, num_steps=2, num_samples=1)
100
+ assert result.output.shape == (BATCH, 3, IMG_SIZE, IMG_SIZE)
101
+
102
+
82
103
  class TestBuildBrain2Img:
83
104
  def test_default_build(self):
84
105
  model = build_brain2img(n_voxels=64, img_size=8, hidden_dim=16, depth=1, num_heads=4)
@@ -142,6 +142,54 @@ class TestBrain2Text:
142
142
  assert len(result.metadata["texts"]) == 1
143
143
 
144
144
 
145
+ def test_reconstruct_num_samples(self, model, brain_data):
146
+ model.eval()
147
+ result = model.reconstruct(brain_data, max_len=8, num_samples=3)
148
+ texts = result.metadata["texts"]
149
+ assert len(texts) == BATCH
150
+ for group in texts:
151
+ assert isinstance(group, list)
152
+ assert len(group) == 3
153
+ for t in group:
154
+ assert isinstance(t, str)
155
+
156
+ def test_diverse_text_samples_differ(self, model, brain_data):
157
+ """Multiple text samples from same brain input should differ."""
158
+ model.eval()
159
+ # High temperature for max diversity
160
+ result = model.reconstruct(
161
+ brain_data, max_len=8, temperature=1.5, num_samples=4,
162
+ )
163
+ texts = result.metadata["texts"]
164
+ # At least one brain input should produce non-identical samples
165
+ any_differ = False
166
+ for group in texts:
167
+ if len(set(group)) > 1:
168
+ any_differ = True
169
+ break
170
+ assert any_differ, "At high temperature, diverse samples should differ"
171
+
172
+ def test_reconstruct_top_p(self, model, brain_data):
173
+ model.eval()
174
+ result = model.reconstruct(brain_data, max_len=8, top_p=0.9, top_k=0)
175
+ assert result.output.shape[0] == BATCH
176
+ assert len(result.metadata["texts"]) == BATCH
177
+
178
+ def test_reconstruct_top_p_and_top_k(self, model, brain_data):
179
+ model.eval()
180
+ result = model.reconstruct(brain_data, max_len=8, top_p=0.9, top_k=20)
181
+ assert result.output.shape[0] == BATCH
182
+
183
+ def test_reconstruct_num_samples_1(self, model, brain_data):
184
+ """num_samples=1 should return flat list of strings."""
185
+ model.eval()
186
+ result = model.reconstruct(brain_data, max_len=8, num_samples=1)
187
+ texts = result.metadata["texts"]
188
+ assert len(texts) == BATCH
189
+ for t in texts:
190
+ assert isinstance(t, str)
191
+
192
+
145
193
  class TestBuildBrain2Text:
146
194
  def test_default_build(self):
147
195
  model = build_brain2text(n_voxels=64, max_len=16, hidden_dim=16, depth=1)
File without changes
File without changes
File without changes
File without changes