mrzerocore 0.4.3__cp37-abi3-musllinux_1_2_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41) hide show
  1. MRzeroCore/__init__.py +22 -0
  2. MRzeroCore/_prepass.abi3.so +0 -0
  3. MRzeroCore/phantom/brainweb/.gitignore +1 -0
  4. MRzeroCore/phantom/brainweb/__init__.py +192 -0
  5. MRzeroCore/phantom/brainweb/brainweb_data.json +92 -0
  6. MRzeroCore/phantom/brainweb/brainweb_data_sources.txt +74 -0
  7. MRzeroCore/phantom/brainweb/output/.gitkeep +0 -0
  8. MRzeroCore/phantom/custom_voxel_phantom.py +240 -0
  9. MRzeroCore/phantom/nifti_phantom.py +210 -0
  10. MRzeroCore/phantom/sim_data.py +200 -0
  11. MRzeroCore/phantom/tissue_dict.py +269 -0
  12. MRzeroCore/phantom/voxel_grid_phantom.py +610 -0
  13. MRzeroCore/pulseq/exporter.py +374 -0
  14. MRzeroCore/pulseq/exporter_v2.py +650 -0
  15. MRzeroCore/pulseq/helpers.py +228 -0
  16. MRzeroCore/pulseq/pulseq_exporter.py +553 -0
  17. MRzeroCore/pulseq/pulseq_loader/__init__.py +66 -0
  18. MRzeroCore/pulseq/pulseq_loader/adc.py +48 -0
  19. MRzeroCore/pulseq/pulseq_loader/helpers.py +75 -0
  20. MRzeroCore/pulseq/pulseq_loader/pulse.py +80 -0
  21. MRzeroCore/pulseq/pulseq_loader/pulseq_file/__init__.py +235 -0
  22. MRzeroCore/pulseq/pulseq_loader/pulseq_file/adc.py +68 -0
  23. MRzeroCore/pulseq/pulseq_loader/pulseq_file/block.py +98 -0
  24. MRzeroCore/pulseq/pulseq_loader/pulseq_file/definitons.py +68 -0
  25. MRzeroCore/pulseq/pulseq_loader/pulseq_file/gradient.py +70 -0
  26. MRzeroCore/pulseq/pulseq_loader/pulseq_file/helpers.py +156 -0
  27. MRzeroCore/pulseq/pulseq_loader/pulseq_file/rf.py +91 -0
  28. MRzeroCore/pulseq/pulseq_loader/pulseq_file/trap.py +69 -0
  29. MRzeroCore/pulseq/pulseq_loader/spoiler.py +33 -0
  30. MRzeroCore/reconstruction.py +104 -0
  31. MRzeroCore/sequence.py +747 -0
  32. MRzeroCore/simulation/isochromat_sim.py +254 -0
  33. MRzeroCore/simulation/main_pass.py +286 -0
  34. MRzeroCore/simulation/pre_pass.py +192 -0
  35. MRzeroCore/simulation/sig_to_mrd.py +362 -0
  36. MRzeroCore/util.py +884 -0
  37. MRzeroCore.libs/libgcc_s-39080030.so.1 +0 -0
  38. mrzerocore-0.4.3.dist-info/METADATA +121 -0
  39. mrzerocore-0.4.3.dist-info/RECORD +41 -0
  40. mrzerocore-0.4.3.dist-info/WHEEL +4 -0
  41. mrzerocore-0.4.3.dist-info/licenses/LICENSE +661 -0
@@ -0,0 +1,610 @@
1
+ from __future__ import annotations
2
+ from typing import Literal, Optional, Dict
3
+ from warnings import warn
4
+ from scipy import io
5
+ import numpy as np
6
+ import torch
7
+ import matplotlib.pyplot as plt
8
+ from .sim_data import SimData
9
+ from ..util import imshow
10
+
11
+
12
def sigmoid(trajectory: torch.Tensor, nyquist: torch.Tensor) -> torch.Tensor:
    """Differentiable approximation of the sinc voxel dephasing function.

    The true dephasing function of a sinc-shaped voxel (in real space) is a
    box function with the FFT-conforming extent [-nyquist, nyquist[. A box is
    not differentiable, so its edges are approximated by a narrow sigmoid
    placed at ±(nyquist + 0.5). The difference is negligible at usual
    Nyquist frequencies.

    The per-axis factors are multiplied along ``dim=1``.
    """
    # Positive inside the pass band, negative outside of it
    margin = nyquist - trajectory.abs() + 0.5
    # The factor 100 makes the sigmoid edge steep
    per_axis = torch.sigmoid(margin * 100)
    return torch.prod(per_axis, dim=1)
23
+
24
+
25
def sinc(trajectory: torch.Tensor, size: torch.Tensor) -> torch.Tensor:
    """Box voxel (real space) dephasing function.

    ``size`` describes the total extent of the box shape. The per-axis sinc
    factors are multiplied along ``dim=1``.
    """
    scaled = trajectory * size
    return torch.sinc(scaled).prod(dim=1)
31
+
32
+
33
def identity(trajectory: torch.Tensor) -> torch.Tensor:
    """Point voxel (real space) dephasing function.

    There is no dephasing: the result is a tensor of ones shaped like the
    first column of ``trajectory``.
    """
    first_column = trajectory[:, 0]
    return torch.ones_like(first_column)
39
+
40
+
41
def generate_B0_B1(PD):
    """Generate a somewhat plausible B0 and B1 map for the given PD map.

    The maps were visually fitted to look similar to the
    numerical_brain_cropped phantom. They are evaluated on a grid spanning
    [-1, 1] along each of the first three dimensions of ``PD``.

    Returns
    -------
    (B0, B1)
        ``B0`` is shifted so that its PD-weighted average is 0 and ``B1`` is
        scaled so that its PD-weighted average is 1.
    """
    axes = [torch.linspace(-1, 1, PD.shape[dim]) for dim in range(3)]
    x_pos, y_pos, z_pos = torch.meshgrid(*axes, indexing="ij")

    # B1: gaussian-like profile centered in the volume
    B1 = torch.exp(-(0.4*x_pos**2 + 0.2*y_pos**2 + 0.3*z_pos**2))

    # B0: difference of two peaks, centered at y = 0.7
    dist2 = (0.4*x_pos**2 + 0.2*(y_pos - 0.7)**2 + 0.3*z_pos**2)
    B0 = 7 / (0.05 + dist2) - 45 / (0.3 + dist2)

    # Normalize such that the weighted average is 0 (B0) or 1 (B1)
    weight = PD / PD.sum()
    B0_mean = (B0 * weight).sum()
    B0 -= B0_mean
    B1_mean = (B1 * weight).sum()
    B1 /= B1_mean
    return B0, B1
58
+
59
+
60
class VoxelGridPhantom:
    """Class for using typical phantoms like those provided by BrainWeb.

    The data is assumed to be defined by a uniform cartesian grid of samples.
    As it is bandwidth limited, we assume that there is no signal above the
    Nyquist frequency. This leads to the usage of sinc-shaped voxels.

    Attributes
    ----------
    PD : torch.Tensor
        (sx, sy, sz) tensor containing the Proton Density
    T1 : torch.Tensor
        (sx, sy, sz) tensor containing the T1 relaxation
    T2 : torch.Tensor
        (sx, sy, sz) tensor containing the T2 relaxation
    T2dash : torch.Tensor
        (sx, sy, sz) tensor containing the T2' dephasing
    D : torch.Tensor
        (sx, sy, sz) tensor containing the Diffusion coefficient
    B0 : torch.Tensor
        (sx, sy, sz) tensor containing the B0 inhomogeneities
    B1 : torch.Tensor
        (coil_count, sx, sy, sz) tensor of RF coil profiles
    coil_sens : torch.Tensor
        (coil_count, sx, sy, sz) tensor of coil sensitivities
    size : torch.Tensor
        Size of the data, in meters.
    phantom_motion
        Stored as given and forwarded to :class:`SimData` by :meth:`build`
    voxel_motion
        Stored as given and forwarded to :class:`SimData` by :meth:`build`
    tissue_masks : Dict[str, torch.Tensor]
        Segmentation masks for different tissues. The keys are the tissue
        names. Empty if no masks were provided.
    """

    def __init__(
        self,
        PD: torch.Tensor,
        T1: torch.Tensor,
        T2: torch.Tensor,
        T2dash: torch.Tensor,
        D: torch.Tensor,
        B0: torch.Tensor,
        B1: torch.Tensor,
        coil_sens: torch.Tensor,
        size: torch.Tensor,
        phantom_motion=None,
        voxel_motion=None,
        tissue_masks: Optional[Dict[str, torch.Tensor]] = None,
    ) -> None:
        """Set the phantom attributes to the provided parameters.

        The maps are converted with ``torch.as_tensor``: real-valued maps and
        ``size`` to float32, ``B1`` and ``coil_sens`` to complex64. No other
        processing is done. You probably want to use :meth:`load` to load a
        phantom instead.
        """
        self.PD = torch.as_tensor(PD, dtype=torch.float32)
        self.T1 = torch.as_tensor(T1, dtype=torch.float32)
        self.T2 = torch.as_tensor(T2, dtype=torch.float32)
        self.T2dash = torch.as_tensor(T2dash, dtype=torch.float32)
        self.D = torch.as_tensor(D, dtype=torch.float32)
        self.B0 = torch.as_tensor(B0, dtype=torch.float32)
        self.B1 = torch.as_tensor(B1, dtype=torch.complex64)
        # Always store a dict so that other methods can iterate over it
        self.tissue_masks = {} if tissue_masks is None else tissue_masks
        self.coil_sens = torch.as_tensor(coil_sens, dtype=torch.complex64)
        self.size = torch.as_tensor(size, dtype=torch.float32)

        self.phantom_motion = phantom_motion
        self.voxel_motion = voxel_motion

    def build(self, PD_threshold: float = 1e-6,
              voxel_shape: Literal["sinc", "box", "point"] = "sinc"
              ) -> SimData:
        """Build a :class:`SimData` instance for simulation.

        Arguments
        ---------
        PD_threshold : float
            Only voxels with a proton density above this value are kept.
        voxel_shape : str
            Selects the dephasing function handed to :class:`SimData`:
            ``"sinc"`` uses :func:`sigmoid`, ``"box"`` uses :func:`sinc` and
            ``"point"`` uses :func:`identity`.

        Raises
        ------
        ValueError
            If ``voxel_shape`` is not one of the supported values.
        """
        mask = self.PD > PD_threshold

        shape = torch.tensor(mask.shape)
        # Voxel positions: centered FFT sample positions scaled by the size
        pos_x, pos_y, pos_z = torch.meshgrid(
            self.size[0] *
            torch.fft.fftshift(torch.fft.fftfreq(
                int(shape[0]), device=self.PD.device)),
            self.size[1] *
            torch.fft.fftshift(torch.fft.fftfreq(
                int(shape[1]), device=self.PD.device)),
            self.size[2] *
            torch.fft.fftshift(torch.fft.fftfreq(
                int(shape[2]), device=self.PD.device)),
            indexing="ij"
        )

        voxel_pos = torch.stack([
            pos_x[mask].flatten(),
            pos_y[mask].flatten(),
            pos_z[mask].flatten()
        ], dim=1)

        if voxel_shape == "box":
            def dephasing_func(t, n): return sinc(t, 0.5 / n)
        elif voxel_shape == "sinc":
            def dephasing_func(t, n): return sigmoid(t, n)
        elif voxel_shape == "point":
            def dephasing_func(t, _): return identity(t)
        else:
            raise ValueError(f"Unsupported voxel shape '{voxel_shape}'")

        return SimData(
            self.PD[mask],
            self.T1[mask],
            self.T2[mask],
            self.T2dash[mask],
            self.D[mask],
            self.B0[mask],
            self.B1[:, mask],
            self.coil_sens[:, mask],
            self.size,
            voxel_pos,
            # Nyquist frequency: half the sample count per unit length
            torch.as_tensor(shape, device=self.PD.device) / 2 / self.size,
            dephasing_func,
            # Lets SimData scatter the masked voxels back onto the grid
            recover_func=lambda data: recover(mask, data),
            phantom_motion=self.phantom_motion,
            voxel_motion=self.voxel_motion,
            tissue_masks=self.tissue_masks
        )

    @classmethod
    def brainweb(cls, file_name: str) -> VoxelGridPhantom:
        """Deprecated alias of :meth:`load`."""
        warn(
            "brainweb() will be removed in a future version, "
            "use load() instead",
            DeprecationWarning,
            stacklevel=2,  # attribute the warning to the caller
        )
        return cls.load(file_name)

    @classmethod
    def load(cls, file_name: str) -> VoxelGridPhantom:
        """Load a phantom from data produced by `generate_maps.py`.

        The file is read with ``np.load`` and must contain the arrays
        ``T1_map``, ``T2_map``, ``T2dash_map``, ``PD_map`` and ``D_map``.
        Optional entries:

        * ``B0_map`` and ``B1_map``: if either is missing, both maps are
          generated with :func:`generate_B0_B1`
        * ``FOV``: phantom size, defaults to ``[0.192, 0.192, 0.192]``
        * every entry whose key starts with ``tissue_`` is loaded as a
          tissue mask

        The coil sensitivity is set to a single coil of ones.
        """
        with np.load(file_name) as data:
            T1 = torch.tensor(data['T1_map'])
            T2 = torch.tensor(data['T2_map'])
            T2dash = torch.tensor(data['T2dash_map'])
            PD = torch.tensor(data['PD_map'])
            D = torch.tensor(data['D_map'])
            try:
                B0 = torch.tensor(data['B0_map'])
                B1 = torch.tensor(data['B1_map'])
            except KeyError:
                B0, B1 = generate_B0_B1(PD)
            try:
                size = torch.tensor(data['FOV'], dtype=torch.float)
            except KeyError:
                size = torch.tensor([0.192, 0.192, 0.192])

            tissue_masks = {
                key: torch.tensor(mask)
                for key, mask in data.items()
                if key.startswith("tissue_")
            }
        if B1.ndim == 3:
            # Add coil-dimension
            B1 = B1[None, ...]

        return cls(
            PD, T1, T2, T2dash, D, B0, B1,
            torch.ones(1, *PD.shape), size,
            tissue_masks=tissue_masks
        )

    @classmethod
    def load_mat(
        cls,
        file_name: str,
        T2dash: float | torch.Tensor = 0.03,
        D: float | torch.Tensor = 1.0,
        size=(0.2, 0.2, 8e-3)
    ) -> VoxelGridPhantom:
        """Load a :class:`VoxelGridPhantom` from a .mat file.

        The file must contain exactly one array, of which the last dimension
        must have size 5. This dimension is assumed to specify (in that order):

        * Proton density
        * T1
        * T2
        * B0
        * B1

        All data is per-voxel, multiple coils are not yet supported.
        A 2D map of shape [x, y, 5] is expanded to a single slice.

        Parameters
        ----------
        file_name : str
            Name of the matlab .mat file to be loaded
        T2dash : float | torch.Tensor, optional
            T2dash map; a number is set uniformly for all voxels,
            by default 0.03
        D : float | torch.Tensor, optional
            Diffusion map; a number is set uniformly for all voxels,
            by default 1
        size : sequence of float, optional
            Size of the phantom, by default (0.2, 0.2, 8e-3)

        Returns
        -------
        VoxelGridPhantom
            A new :class:`VoxelGridPhantom` containing the loaded data.

        Raises
        ------
        Exception
            The loaded file does not contain the expected data.
        """
        data = _load_tensor_from_mat(file_name)

        if data.ndim < 2 or data.shape[-1] != 5:
            raise Exception(
                f"Expected a tensor with shape [..., 5], "
                f"but got {list(data.shape)}"
            )

        if data.ndim == 3:
            # Expand to 3D: [x, y, i] -> [x, y, z, i]
            data = data.unsqueeze(2)

        # Accept ints as well: a plain number (e.g. D=1) would otherwise be
        # stored as a 0-dim tensor instead of a per-voxel map
        if isinstance(T2dash, (int, float)):
            T2dash = torch.full_like(data[..., 0], T2dash)
        if isinstance(D, (int, float)):
            D = torch.full_like(data[..., 0], D)

        return cls(
            data[..., 0],  # PD
            data[..., 1],  # T1
            data[..., 2],  # T2
            T2dash,
            D,
            data[..., 3],  # B0
            data[..., 4][None, ...],  # B1
            coil_sens=torch.ones(1, *data.shape[:-1]),
            size=torch.as_tensor(size),
        )

    def slices(self, slices: list[int]) -> VoxelGridPhantom:
        """Generate a copy that only contains the selected slice(s).

        Slices are selected along the last (3rd spatial) dimension. The
        motion attributes are not carried over to the copy.

        Parameters
        ----------
        slices : list[int] or int
            The selected slice(s), each in ``[0, self.PD.shape[2])``

        Returns
        -------
        VoxelGridPhantom
            A new instance containing the selected slice(s).

        Raises
        ------
        IndexError
            If any of the selected slices does not exist.
        """
        if isinstance(slices, (int, np.integer)):
            indices = [int(slices)]
        else:
            indices = [int(s) for s in slices]

        # Check every index individually (the previous assert only compared
        # a boolean against the slice count)
        depth = self.PD.shape[2]
        invalid = [s for s in indices if not 0 <= s < depth]
        if invalid:
            raise IndexError(
                f"slice(s) {invalid} out of range for a phantom "
                f"with {depth} slice(s)"
            )

        def select(tensor: torch.Tensor):
            return tensor[..., indices].view(
                *self.PD.shape[:2], len(indices)
            )

        def select_multicoil(tensor: torch.Tensor):
            coils = tensor.shape[0]
            return tensor[..., indices].view(
                coils, *self.PD.shape[:2], len(indices)
            )

        return VoxelGridPhantom(
            select(self.PD),
            select(self.T1),
            select(self.T2),
            select(self.T2dash),
            select(self.D),
            select(self.B0),
            select_multicoil(self.B1),
            select_multicoil(self.coil_sens),
            self.size.clone(),
            tissue_masks={
                key: mask[..., indices]
                for key, mask in self.tissue_masks.items()
            },
        )

    def scale_fft(self, x: int, y: int, z: int) -> VoxelGridPhantom:
        """This is experimental, shows strong ringing and is not recommended

        Downscales all maps to the resolution (x, y, z) by cropping the
        center of their Fourier transform and returning the magnitude of the
        inverse transform.
        """
        # This function currently only supports downscaling
        assert x <= self.PD.shape[0]
        assert y <= self.PD.shape[1]
        assert z <= self.PD.shape[2]

        # Normalize signal, otherwise magnitude changes with scaling
        norm = (
            (x / self.PD.shape[0]) *
            (y / self.PD.shape[1]) *
            (z / self.PD.shape[2])
        )
        # Center for FT
        cx = self.PD.shape[0] // 2
        cy = self.PD.shape[1] // 2
        cz = self.PD.shape[2] // 2

        def scale(tensor: torch.Tensor) -> torch.Tensor:
            FT = torch.fft.fftshift(torch.fft.fftn(tensor))
            FT = FT[
                cx - x // 2:cx + (x+1) // 2,
                cy - y // 2:cy + (y+1) // 2,
                cz - z // 2:cz + (z+1) // 2
            ] * norm
            return torch.fft.ifftn(torch.fft.ifftshift(FT)).abs()

        return VoxelGridPhantom(
            scale(self.PD),
            scale(self.T1),
            scale(self.T2),
            scale(self.T2dash),
            scale(self.D),
            scale(self.B0),
            # NOTE(review): squeeze() + 3D crop appears to assume a single
            # coil - confirm before using with multi-coil data
            scale(self.B1.squeeze()).unsqueeze(0),
            scale(self.coil_sens.squeeze()).unsqueeze(0),
            self.size.clone(),
            tissue_masks={
                key: scale(mask) for key, mask in self.tissue_masks.items()
            }
        )

    def interpolate(self, x: int, y: int, z: int) -> VoxelGridPhantom:
        """Return a resized copy of this :class:`VoxelGridPhantom`.

        All maps are resampled with torch.nn.functional.interpolate in
        'trilinear' mode (real and imaginary part separately for B1 and
        coil_sens). Tissue masks are converted to float and resampled in
        'area' mode. The motion attributes are not carried over to the copy.

        Parameters
        ----------
        x : int
            The new resolution along the 1st dimension
        y : int
            The new resolution along the 2nd dimension
        z : int
            The new resolution along the 3rd dimension

        Returns
        -------
        VoxelGridPhantom
            A new :class:`VoxelGridPhantom` containing resized tensors.
        """
        def resample(tensor: torch.Tensor) -> torch.Tensor:
            # Introduce additional dimensions: mini-batch and channels
            return torch.nn.functional.interpolate(
                tensor[None, None, ...], size=(x, y, z), mode='trilinear'
            )[0, 0, ...]

        def resample_multicoil(tensor: torch.Tensor) -> torch.Tensor:
            coils = tensor.shape[0]
            # Allocate on the device of the input so that the result does
            # not silently end up on the default device
            output = torch.zeros(
                coils, x, y, z, dtype=tensor.dtype, device=tensor.device
            )
            for i in range(coils):
                re = resample(torch.real(tensor[i, ...]))
                im = resample(torch.imag(tensor[i, ...]))
                output[i, ...] = re + 1j * im

            return output

        def resample_masks(tensors: Dict) -> Dict:
            return {
                key: torch.nn.functional.interpolate(
                    mask[None, None, ...].float(),
                    size=(x, y, z), mode='area'
                )[0, 0, ...]
                for key, mask in tensors.items()
            }

        return VoxelGridPhantom(
            resample(self.PD),
            resample(self.T1),
            resample(self.T2),
            resample(self.T2dash),
            resample(self.D),
            resample(self.B0),
            resample_multicoil(self.B1),
            resample_multicoil(self.coil_sens),
            self.size.clone(),
            tissue_masks=resample_masks(self.tissue_masks)
        )

    def plot(self, plot_masks=False, plot_slice="center",
             time_unit='s') -> None:
        """
        Print and plot all data stored in this phantom.

        Parameters
        ----------
        plot_masks : bool
            Also plot the tissue masks stored in this phantom
        plot_slice : str | int
            If int, the specified slice is plotted. "center" plots the center
            slice and "all" passes all slices to ``imshow``.
        time_unit : str
            Time unit used for the T1, T2, and T2' maps (default: 's').
            'ms' scales the maps by 1000, any other value leaves them
            unscaled.

        Raises
        ------
        ValueError
            If ``plot_slice`` is neither 'all', 'center' nor an integer.
        """
        print("VoxelGridPhantom")
        print(f"size = {self.size}")
        # Center slice
        if plot_slice == "center":
            s = self.PD.shape[2] // 2
        elif plot_slice == "all":
            s = slice(None)
        elif isinstance(plot_slice, int):
            s = plot_slice
        else:
            raise ValueError(
                "expected plot_slice to be 'all', 'center' or an integer")
        # Warn if we only print a part of all data
        if self.coil_sens.shape[0] > 1:
            print(f"Plotting 1st of {self.coil_sens.shape[0]} coil sens maps")
        if self.B1.shape[0] > 1:
            print(f"Plotting 1st of {self.B1.shape[0]} B1 maps")
        if self.PD.shape[2] > 1:
            print(f"Plotting slice {s} / {self.PD.shape[2]}")

        # Get time unit scaling factor
        time_factor = 1000 if time_unit == 'ms' else 1

        # Determine the number of subplots needed
        num_plots = 9  # Base number of plots without masks
        if plot_masks:
            num_plots += len(self.tissue_masks)

        # Calculate the grid size based on the number of plots
        cols = 3
        rows = int(np.ceil(num_plots / cols))

        plt.figure(figsize=(12, rows * 3))

        # Plot the basic maps
        plt.subplot(rows, cols, 1)
        plt.title("PD")
        imshow(self.PD[:, :, s], vmin=0)
        plt.colorbar()

        plt.subplot(rows, cols, 2)
        plt.title("T1 (%s)" % time_unit)
        imshow(self.T1[:, :, s]*time_factor, vmin=0)
        plt.colorbar()

        plt.subplot(rows, cols, 3)
        plt.title("T2 (%s)" % time_unit)
        imshow(self.T2[:, :, s]*time_factor, vmin=0)
        plt.colorbar()

        plt.subplot(rows, cols, 4)
        plt.title("T2' (%s)" % time_unit)
        imshow(self.T2dash[:, :, s]*time_factor, vmin=0)
        plt.colorbar()

        plt.subplot(rows, cols, 5)
        plt.title("D")
        imshow(self.D[:, :, s], vmin=0)
        plt.colorbar()

        # Subplot 6 is left empty, the field maps start on the next row
        plt.subplot(rows, cols, 7)
        plt.title("B0")
        imshow(self.B0[:, :, s])
        plt.colorbar()

        plt.subplot(rows, cols, 8)
        plt.title("B1")
        imshow(torch.abs(self.B1[0, :, :, s]))
        plt.colorbar()

        plt.subplot(rows, cols, 9)
        plt.title("coil sens")
        imshow(torch.abs(self.coil_sens[0, :, :, s]), vmin=0)
        plt.colorbar()

        # Conditionally plot masks if plot_masks is True
        if plot_masks:
            for i, (key, mask) in enumerate(self.tissue_masks.items()):
                plt.subplot(rows, cols, 10 + i)
                plt.title(key)
                # Show the same slice as for the maps above
                # (assumes masks are (sx, sy, sz) like PD - TODO confirm)
                imshow(mask[:, :, s])
                plt.colorbar()

        plt.tight_layout()
        plt.show()

    def plot3D(self, data2print: int = 0) -> None:
        """Print and plot all slices of one selected map of this phantom.

        Parameters
        ----------
        data2print : int
            Index of the map to plot, in the order
            PD, T1, T2, T2', D, B0, B1, coil sens
        """
        # `util` is not bound at module level (only `imshow` is imported
        # from it), so import the module here.
        # NOTE(review): assumes ..util provides `plot3D` - confirm
        from .. import util

        print("VoxelGridPhantom")
        print(f"size = {self.size}")
        print()

        label = ['PD', 'T1', 'T2', "T2'", "D", "B0", "B1", "coil sens"]

        tensors = [
            self.PD, self.T1, self.T2, self.T2dash, self.D, self.B0,
            self.B1.squeeze(0), self.coil_sens
        ]

        # Warn if we only print a part of all data
        print(f"Plotting {label[data2print]}")

        tensor = tensors[data2print].squeeze(0)

        util.plot3D(tensor, figsize=(20, 5))
        plt.title(label[data2print])
        plt.show()
563
+
564
+
565
def recover(mask, sim_data: SimData) -> VoxelGridPhantom:
    """Provided to :class:`SimData` to reverse the ``build()``

    Scatters the per-voxel tensors of ``sim_data`` back onto the full grid
    described by the boolean ``mask``; all voxels outside the mask are zero.
    """
    mask = mask.to(sim_data.device)

    def to_full(sparse):
        # Only (voxels,) and (coils, voxels) tensors are supported
        assert sparse.ndim < 3
        if sparse.ndim != 2:
            dense = torch.zeros(mask.shape, device=mask.device)
            dense[mask] = sparse
            return dense

        # Multi-coil data keeps its leading coil dimension
        dense = torch.zeros(
            [sparse.shape[0], *mask.shape],
            dtype=sparse.dtype, device=mask.device)
        dense[:, mask] = sparse
        return dense

    map_names = (
        "PD", "T1", "T2", "T2dash", "D", "B0", "B1", "coil_sens"
    )
    full_maps = [to_full(getattr(sim_data, name)) for name in map_names]

    return VoxelGridPhantom(*full_maps, sim_data.size)
592
+
593
+
594
def _load_tensor_from_mat(file_name: str) -> torch.Tensor:
    """Load the single array stored in a .mat file as a float tensor.

    Keys that start and end with a double underscore are skipped.

    Raises
    ------
    Exception
        If the file contains no variables or not exactly one array.
    """
    contents = io.loadmat(file_name)

    variables = {
        name: value for name, value in contents.items()
        if not (name.startswith('__') and name.endswith('__'))
    }
    if not variables:
        raise Exception("The loaded mat file does not contain any variables")

    candidates = [
        value for value in variables.values()
        if isinstance(value, np.ndarray)
    ]
    if len(candidates) != 1:
        raise Exception("The loaded mat file must contain exactly one array")

    (array,) = candidates
    return torch.from_numpy(array).float()