mrzerocore 0.4.3__cp37-abi3-musllinux_1_2_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- MRzeroCore/__init__.py +22 -0
- MRzeroCore/_prepass.abi3.so +0 -0
- MRzeroCore/phantom/brainweb/.gitignore +1 -0
- MRzeroCore/phantom/brainweb/__init__.py +192 -0
- MRzeroCore/phantom/brainweb/brainweb_data.json +92 -0
- MRzeroCore/phantom/brainweb/brainweb_data_sources.txt +74 -0
- MRzeroCore/phantom/brainweb/output/.gitkeep +0 -0
- MRzeroCore/phantom/custom_voxel_phantom.py +240 -0
- MRzeroCore/phantom/nifti_phantom.py +210 -0
- MRzeroCore/phantom/sim_data.py +200 -0
- MRzeroCore/phantom/tissue_dict.py +269 -0
- MRzeroCore/phantom/voxel_grid_phantom.py +610 -0
- MRzeroCore/pulseq/exporter.py +374 -0
- MRzeroCore/pulseq/exporter_v2.py +650 -0
- MRzeroCore/pulseq/helpers.py +228 -0
- MRzeroCore/pulseq/pulseq_exporter.py +553 -0
- MRzeroCore/pulseq/pulseq_loader/__init__.py +66 -0
- MRzeroCore/pulseq/pulseq_loader/adc.py +48 -0
- MRzeroCore/pulseq/pulseq_loader/helpers.py +75 -0
- MRzeroCore/pulseq/pulseq_loader/pulse.py +80 -0
- MRzeroCore/pulseq/pulseq_loader/pulseq_file/__init__.py +235 -0
- MRzeroCore/pulseq/pulseq_loader/pulseq_file/adc.py +68 -0
- MRzeroCore/pulseq/pulseq_loader/pulseq_file/block.py +98 -0
- MRzeroCore/pulseq/pulseq_loader/pulseq_file/definitons.py +68 -0
- MRzeroCore/pulseq/pulseq_loader/pulseq_file/gradient.py +70 -0
- MRzeroCore/pulseq/pulseq_loader/pulseq_file/helpers.py +156 -0
- MRzeroCore/pulseq/pulseq_loader/pulseq_file/rf.py +91 -0
- MRzeroCore/pulseq/pulseq_loader/pulseq_file/trap.py +69 -0
- MRzeroCore/pulseq/pulseq_loader/spoiler.py +33 -0
- MRzeroCore/reconstruction.py +104 -0
- MRzeroCore/sequence.py +747 -0
- MRzeroCore/simulation/isochromat_sim.py +254 -0
- MRzeroCore/simulation/main_pass.py +286 -0
- MRzeroCore/simulation/pre_pass.py +192 -0
- MRzeroCore/simulation/sig_to_mrd.py +362 -0
- MRzeroCore/util.py +884 -0
- MRzeroCore.libs/libgcc_s-39080030.so.1 +0 -0
- mrzerocore-0.4.3.dist-info/METADATA +121 -0
- mrzerocore-0.4.3.dist-info/RECORD +41 -0
- mrzerocore-0.4.3.dist-info/WHEEL +4 -0
- mrzerocore-0.4.3.dist-info/licenses/LICENSE +661 -0
|
@@ -0,0 +1,610 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from typing import Literal, Optional, Dict
|
|
3
|
+
from warnings import warn
|
|
4
|
+
from scipy import io
|
|
5
|
+
import numpy as np
|
|
6
|
+
import torch
|
|
7
|
+
import matplotlib.pyplot as plt
|
|
8
|
+
from .sim_data import SimData
|
|
9
|
+
from ..util import imshow
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def sigmoid(trajectory: torch.Tensor, nyquist: torch.Tensor) -> torch.Tensor:
    """Differentiable approximation of the sinc voxel dephasing function.

    A sinc-shaped voxel (in real space) has a box function as its true
    dephasing function, covering the FFT-conform range [-nyquist, nyquist[.
    A box is not differentiable, so its edges are approximated by a narrow
    sigmoid placed at ±(nyquist + 0.5). The difference is negligible at
    usual nyquist frequencies.

    The per-axis factors are multiplied along ``dim=1`` of ``trajectory``.
    """
    # Positive inside the pass band, negative outside of it
    margin = nyquist - trajectory.abs() + 0.5
    # The factor 100 makes the transition narrow
    per_axis = torch.sigmoid(margin * 100)
    return per_axis.prod(dim=1)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def sinc(trajectory: torch.Tensor, size: torch.Tensor) -> torch.Tensor:
    """Dephasing function of a box-shaped voxel (in real space).

    ``size`` describes the total extent of the box. The per-axis sinc
    factors are multiplied along ``dim=1`` of ``trajectory``.
    """
    scaled = trajectory * size
    return torch.sinc(scaled).prod(dim=1)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def identity(trajectory: torch.Tensor) -> torch.Tensor:
    """Dephasing function of a point-shaped voxel (in real space).

    A point voxel does not dephase, so the result is one for every row of
    ``trajectory``.
    """
    first_axis = trajectory.select(1, 0)
    return torch.ones_like(first_axis)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def generate_B0_B1(PD):
    """Generate a somewhat plausible B0 and B1 map for the given PD map.

    The maps are smooth analytical functions, visually fitted to look
    similar to the numerical_brain_cropped phantom.

    Parameters
    ----------
    PD : torch.Tensor
        Proton density map; it is indexed as a 3D tensor (sx, sy, sz).

    Returns
    -------
    (torch.Tensor, torch.Tensor)
        ``(B0, B1)``, both with the shape of ``PD``. ``B0`` is shifted so
        that its PD-weighted average is 0, ``B1`` is scaled so that its
        PD-weighted average is 1.
    """
    # Create the coordinate grids on the device of PD, otherwise combining
    # them with the PD-derived weights fails for non-CPU phantoms.
    x_pos, y_pos, z_pos = torch.meshgrid(
        torch.linspace(-1, 1, PD.shape[0], device=PD.device),
        torch.linspace(-1, 1, PD.shape[1], device=PD.device),
        torch.linspace(-1, 1, PD.shape[2], device=PD.device),
        indexing="ij"
    )
    B1 = torch.exp(-(0.4*x_pos**2 + 0.2*y_pos**2 + 0.3*z_pos**2))
    # B0 is centered off-axis (y = 0.7), B1 at the origin
    dist2 = (0.4*x_pos**2 + 0.2*(y_pos - 0.7)**2 + 0.3*z_pos**2)
    B0 = 7 / (0.05 + dist2) - 45 / (0.3 + dist2)
    # Normalize such that the weighted average is 0 or 1
    weight = PD / PD.sum()
    B0 -= (B0 * weight).sum()
    B1 /= (B1 * weight).sum()
    return B0, B1
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class VoxelGridPhantom:
    """Class for using typical phantoms like those provided by BrainWeb.

    The data is assumed to be defined by a uniform cartesian grid of samples.
    As it is bandwidth limited, we assume that there is no signal above the
    Nyquist frequency. This leads to the usage of sinc-shaped voxels.

    Attributes
    ----------
    PD : torch.Tensor
        (sx, sy, sz) tensor containing the Proton Density
    T1 : torch.Tensor
        (sx, sy, sz) tensor containing the T1 relaxation
    T2 : torch.Tensor
        (sx, sy, sz) tensor containing the T2 relaxation
    T2dash : torch.Tensor
        (sx, sy, sz) tensor containing the T2' dephasing
    D : torch.Tensor
        (sx, sy, sz) tensor containing the Diffusion coefficient
    B0 : torch.Tensor
        (sx, sy, sz) tensor containing the B0 inhomogeneities
    B1 : torch.Tensor
        (coil_count, sx, sy, sz) tensor of RF coil profiles
    coil_sens : torch.Tensor
        (coil_count, sx, sy, sz) tensor of coil sensitivities
    size : torch.Tensor
        Size of the data, in meters.
    tissue_masks : Dict[str, torch.Tensor]
        Segmentation masks for different tissues. The keys are the tissue
        names. Empty if no masks were provided.
    phantom_motion, voxel_motion
        Stored as given and forwarded unchanged to :class:`SimData` by
        :meth:`build`.
    """

    def __init__(
        self,
        PD: torch.Tensor,
        T1: torch.Tensor,
        T2: torch.Tensor,
        T2dash: torch.Tensor,
        D: torch.Tensor,
        B0: torch.Tensor,
        B1: torch.Tensor,
        coil_sens: torch.Tensor,
        size: torch.Tensor,
        phantom_motion=None,
        voxel_motion=None,
        tissue_masks: Optional[Dict[str, torch.Tensor]] = None,
    ) -> None:
        """Set the phantom attributes to the provided parameters.

        This function does no cloning nor contains any other functionality.
        The maps are converted to float32 tensors, ``B1`` and ``coil_sens``
        to complex64 tensors. ``tissue_masks=None`` is stored as an empty
        dict. You probably want to use :meth:`load` to load a phantom
        instead.
        """
        self.PD = torch.as_tensor(PD, dtype=torch.float32)
        self.T1 = torch.as_tensor(T1, dtype=torch.float32)
        self.T2 = torch.as_tensor(T2, dtype=torch.float32)
        self.T2dash = torch.as_tensor(T2dash, dtype=torch.float32)
        self.D = torch.as_tensor(D, dtype=torch.float32)
        self.B0 = torch.as_tensor(B0, dtype=torch.float32)
        self.B1 = torch.as_tensor(B1, dtype=torch.complex64)
        self.tissue_masks = tissue_masks
        if self.tissue_masks is None:
            self.tissue_masks = {}
        self.coil_sens = torch.as_tensor(coil_sens, dtype=torch.complex64)
        self.size = torch.as_tensor(size, dtype=torch.float32)

        self.phantom_motion = phantom_motion
        self.voxel_motion = voxel_motion

    def build(self, PD_threshold: float = 1e-6,
              voxel_shape: Literal["sinc", "box", "point"] = "sinc"
              ) -> SimData:
        """Build a :class:`SimData` instance for simulation.

        Arguments
        ---------
        PD_threshold : float
            All voxels with a proton density below this value are ignored.
        voxel_shape : "sinc" | "box" | "point"
            Selects the dephasing function handed to :class:`SimData`:
            "sinc" uses :func:`sigmoid`, "box" uses :func:`sinc` and
            "point" uses :func:`identity`.

        Returns
        -------
        SimData
            Contains only the voxels selected by the PD threshold.

        Raises
        ------
        ValueError
            If ``voxel_shape`` is not one of the supported values.
        """
        mask = self.PD > PD_threshold

        shape = torch.tensor(mask.shape)
        # fftshift(fftfreq(n)) is an ascending grid in [-0.5, 0.5[, scaling
        # it by the phantom size yields voxel positions centered around 0
        pos_x, pos_y, pos_z = torch.meshgrid(
            self.size[0] *
            torch.fft.fftshift(torch.fft.fftfreq(
                int(shape[0]), device=self.PD.device)),
            self.size[1] *
            torch.fft.fftshift(torch.fft.fftfreq(
                int(shape[1]), device=self.PD.device)),
            self.size[2] *
            torch.fft.fftshift(torch.fft.fftfreq(
                int(shape[2]), device=self.PD.device)),
            indexing="ij"
        )

        voxel_pos = torch.stack([
            pos_x[mask].flatten(),
            pos_y[mask].flatten(),
            pos_z[mask].flatten()
        ], dim=1)

        if voxel_shape == "box":
            # 0.5 / nyquist = size / shape, the extent of a single voxel
            def dephasing_func(t, n): return sinc(t, 0.5 / n)
        elif voxel_shape == "sinc":
            def dephasing_func(t, n): return sigmoid(t, n)
        elif voxel_shape == "point":
            def dephasing_func(t, _): return identity(t)
        else:
            raise ValueError(f"Unsupported voxel shape '{voxel_shape}'")

        return SimData(
            self.PD[mask],
            self.T1[mask],
            self.T2[mask],
            self.T2dash[mask],
            self.D[mask],
            self.B0[mask],
            self.B1[:, mask],
            self.coil_sens[:, mask],
            self.size,
            voxel_pos,
            # Nyquist frequency: half the sample count per unit length
            torch.as_tensor(shape, device=self.PD.device) / 2 / self.size,
            dephasing_func,
            recover_func=lambda data: recover(mask, data),
            phantom_motion=self.phantom_motion,
            voxel_motion=self.voxel_motion,
            tissue_masks=self.tissue_masks
        )

    @classmethod
    def brainweb(cls, file_name: str) -> VoxelGridPhantom:
        """Deprecated alias of :meth:`load`."""
        # stacklevel=2 attributes the warning to the caller, not this module
        warn(
            "brainweb() will be removed in a future version, use load() instead",
            DeprecationWarning,
            stacklevel=2
        )
        return cls.load(file_name)

    @classmethod
    def load(cls, file_name: str) -> VoxelGridPhantom:
        """Load a phantom from data produced by `generate_maps.py`.

        The file is read with :func:`numpy.load` and must contain the arrays
        ``T1_map``, ``T2_map``, ``T2dash_map``, ``PD_map`` and ``D_map``.
        If ``B0_map`` or ``B1_map`` is missing, both maps are generated by
        :func:`generate_B0_B1`. If ``FOV`` is missing, a size of 0.192 along
        all axes is used. Every array whose name starts with ``tissue_`` is
        stored as a tissue mask. The coil sensitivity is set to one.
        """
        with np.load(file_name) as data:
            T1 = torch.tensor(data['T1_map'])
            T2 = torch.tensor(data['T2_map'])
            T2dash = torch.tensor(data['T2dash_map'])
            PD = torch.tensor(data['PD_map'])
            D = torch.tensor(data['D_map'])
            try:
                B0 = torch.tensor(data['B0_map'])
                B1 = torch.tensor(data['B1_map'])
            except KeyError:
                B0, B1 = generate_B0_B1(PD)
            try:
                size = torch.tensor(data['FOV'], dtype=torch.float)
            except KeyError:
                size = torch.tensor([0.192, 0.192, 0.192])

            # Iterate the names only, so that just the masks are loaded here
            tissue_masks = {
                key: torch.tensor(data[key])
                for key in data.files
                if key.startswith("tissue_")
            }
            if B1.ndim == 3:
                # Add coil-dimension
                B1 = B1[None, ...]

        return cls(
            PD, T1, T2, T2dash, D, B0, B1,
            torch.ones(1, *PD.shape), size,
            tissue_masks=tissue_masks
        )

    @classmethod
    def load_mat(
        cls,
        file_name: str,
        T2dash: float | torch.Tensor = 0.03,
        D: float | torch.Tensor = 1.0,
        size = [0.2, 0.2, 8e-3]
    ) -> VoxelGridPhantom:
        """Load a :class:`VoxelGridPhantom` from a .mat file.

        The file must contain exactly one array, of which the last dimension
        must have size 5. This dimension is assumed to specify (in that order):

        * Proton density
        * T1
        * T2
        * B0
        * B1

        All data is per-voxel, multiple coils are not yet supported.
        A 3D array [x, y, i] is expanded to [x, y, 1, i].

        Parameters
        ----------
        file_name : str
            Name of the matlab .mat file to be loaded
        T2dash : float | torch.Tensor, optional
            T2dash value set uniformly for all voxels, by default 0.03.
            A tensor is passed to the constructor as-is.
        D : float | torch.Tensor, optional
            Diffusion value set uniformly for all voxels, by default 1.
            A tensor is passed to the constructor as-is.
        size : optional
            Size of the phantom, by default [0.2, 0.2, 8e-3]

        Returns
        -------
        VoxelGridPhantom
            A new instance containing the loaded data.

        Raises
        ------
        Exception
            The loaded file does not contain the expected data.
        """
        data = _load_tensor_from_mat(file_name)

        if data.ndim < 2 or data.shape[-1] != 5:
            raise Exception(
                f"Expected a tensor with shape [..., 5], "
                f"but got {list(data.shape)}"
            )

        if data.ndim == 3:
            # Expand to 3D: [x, y, i] -> [x, y, z, i]
            data = data.unsqueeze(2)

        # Plain ints (e.g. D=1) must be expanded to full maps as well,
        # otherwise they end up as 0-dim tensors in the phantom
        if isinstance(T2dash, (int, float)):
            T2dash = torch.full_like(data[..., 0], T2dash)
        if isinstance(D, (int, float)):
            D = torch.full_like(data[..., 0], D)

        return cls(
            data[..., 0],  # PD
            data[..., 1],  # T1
            data[..., 2],  # T2
            T2dash,
            D,
            data[..., 3],  # B0
            data[..., 4][None, ...],  # B1
            coil_sens=torch.ones(1, *data.shape[:-1]),
            size=torch.as_tensor(size),
        )

    def slices(self, slices: list[int]) -> VoxelGridPhantom:
        """Generate a copy that only contains the selected slice(s).

        The slices are selected along the last (3rd spatial) dimension. The
        size is copied unchanged; phantom and voxel motion are not carried
        over to the returned phantom.

        Parameters
        ----------
        slices : list[int] or int
            The selected slice(s)

        Returns
        -------
        VoxelGridPhantom
            A new instance containing the selected slice(s).

        Raises
        ------
        IndexError
            If a slice index is not in ``[0, PD.shape[2][``.
        """
        if isinstance(slices, int):
            slices = [slices]
        else:
            slices = list(slices)

        depth = self.PD.shape[2]
        for index in slices:
            if not 0 <= index < depth:
                raise IndexError(
                    f"slice index {index} is out of range "
                    f"for a phantom with {depth} slice(s)"
                )

        def select(tensor: torch.Tensor):
            # Indexing with a list keeps the slice dimension, also for
            # maps with a leading coil dimension
            return tensor[..., slices]

        return VoxelGridPhantom(
            select(self.PD),
            select(self.T1),
            select(self.T2),
            select(self.T2dash),
            select(self.D),
            select(self.B0),
            select(self.B1),
            select(self.coil_sens),
            self.size.clone(),
            tissue_masks={
                key: select(mask) for key, mask in self.tissue_masks.items()
            },
        )

    def scale_fft(self, x: int, y: int, z: int) -> VoxelGridPhantom:
        """This is experimental, shows strong ringing and is not recommended

        Downscales all maps by cropping their centered spectrum to the size
        (x, y, z). Only the magnitude of the result is kept. Phantom and
        voxel motion are not carried over to the returned phantom.
        """
        # This function currently only supports downscaling
        assert x <= self.PD.shape[0]
        assert y <= self.PD.shape[1]
        assert z <= self.PD.shape[2]

        # Normalize signal, otherwise magnitude changes with scaling
        norm = (
            (x / self.PD.shape[0]) *
            (y / self.PD.shape[1]) *
            (z / self.PD.shape[2])
        )
        # Center for FT
        cx = self.PD.shape[0] // 2
        cy = self.PD.shape[1] // 2
        cz = self.PD.shape[2] // 2

        def scale(map: torch.Tensor) -> torch.Tensor:
            FT = torch.fft.fftshift(torch.fft.fftn(map))
            FT = FT[
                cx - x // 2:cx + (x+1) // 2,
                cy - y // 2:cy + (y+1) // 2,
                cz - z // 2:cz + (z+1) // 2
            ] * norm
            return torch.fft.ifftn(torch.fft.ifftshift(FT)).abs()

        def scale_multicoil(tensor: torch.Tensor) -> torch.Tensor:
            # Scale every coil on its own; an argument-less squeeze() would
            # also drop spatial dimensions of size 1
            return torch.stack([scale(coil) for coil in tensor])

        return VoxelGridPhantom(
            scale(self.PD),
            scale(self.T1),
            scale(self.T2),
            scale(self.T2dash),
            scale(self.D),
            scale(self.B0),
            scale_multicoil(self.B1),
            scale_multicoil(self.coil_sens),
            self.size.clone(),
            tissue_masks={
                key: scale(mask) for key, mask in self.tissue_masks.items()
            }
        )

    def interpolate(self, x: int, y: int, z: int) -> VoxelGridPhantom:
        """Return a resized copy of this :class:`VoxelGridPhantom` instance.

        The maps are resampled with torch.nn.functional.interpolate in
        'trilinear' mode (real and imaginary part separately for B1 and
        coil_sens), the tissue masks in 'area' mode. The size is copied
        unchanged; phantom and voxel motion are not carried over.

        Parameters
        ----------
        x : int
            The new resolution along the 1st dimension
        y : int
            The new resolution along the 2nd dimension
        z : int
            The new resolution along the 3rd dimension

        Returns
        -------
        VoxelGridPhantom
            A new instance containing resized tensors.
        """
        def resample(tensor: torch.Tensor) -> torch.Tensor:
            # Introduce additional dimensions: mini-batch and channels
            return torch.nn.functional.interpolate(
                tensor[None, None, ...], size=(x, y, z), mode='trilinear'
            )[0, 0, ...]

        def resample_multicoil(tensor: torch.Tensor) -> torch.Tensor:
            coils = tensor.shape[0]
            # Allocate on the device of the input to avoid a device mismatch
            output = torch.zeros(
                coils, x, y, z, dtype=tensor.dtype, device=tensor.device
            )
            for i in range(coils):
                re = resample(torch.real(tensor[i, ...]))
                im = resample(torch.imag(tensor[i, ...]))
                output[i, ...] = re + 1j * im

            return output

        def resample_masks(tensors: Dict) -> Optional[Dict]:
            output = {}
            for key, mask in tensors.items():
                # Interpolate the mask
                interpolated_mask = torch.nn.functional.interpolate(
                    mask[None, None, ...].float(), size=(x, y, z), mode='area'
                )[0, 0, ...]
                # Store the result
                output[key] = interpolated_mask

            return output

        return VoxelGridPhantom(
            resample(self.PD),
            resample(self.T1),
            resample(self.T2),
            resample(self.T2dash),
            resample(self.D),
            resample(self.B0),
            resample_multicoil(self.B1),
            resample_multicoil(self.coil_sens),
            self.size.clone(),
            tissue_masks=resample_masks(self.tissue_masks)
        )

    def plot(self, plot_masks=False, plot_slice="center", time_unit='s') -> None:
        """
        Print and plot all data stored in this phantom.

        Parameters
        ----------
        plot_masks : bool
            Plot tissue masks stored in this phantom (assumes they exist)
        plot_slice : str | int
            If int, the specified slice is plotted. "center" plots the center
            slice and "all" plots all slices as a grid.
        time_unit : str
            Time unit to use for T1, T2, and T2' maps (default: 's'). Supported 's' and 'ms'.

        Raises
        ------
        ValueError
            If ``plot_slice`` is neither 'all', 'center' nor an integer.
        """
        print("VoxelGridPhantom")
        print(f"size = {self.size}")
        # Center slice
        if plot_slice == "center":
            s = self.PD.shape[2] // 2
        elif plot_slice == "all":
            s = slice(None)
        elif isinstance(plot_slice, int):
            s = plot_slice
        else:
            raise ValueError("expected plot_slice to be 'all', 'center' or an integer")
        # Warn if we only print a part of all data
        if self.coil_sens.shape[0] > 1:
            print(f"Plotting 1st of {self.coil_sens.shape[0]} coil sens maps")
        if self.B1.shape[0] > 1:
            print(f"Plotting 1st of {self.B1.shape[0]} B1 maps")
        if self.PD.shape[2] > 1:
            print(f"Plotting slice {s} / {self.PD.shape[2]}")

        # Get time unit scaling factor
        time_factor = 1000 if time_unit == 'ms' else 1

        # Determine the number of subplots needed
        num_plots = 9  # Base number of plots without masks
        if plot_masks:
            num_masks = len(self.tissue_masks)
            num_plots += num_masks

        # Calculate the grid size based on the number of plots
        cols = 3
        rows = int(np.ceil(num_plots / cols))

        plt.figure(figsize=(12, rows * 3))

        # Plot the basic maps
        plt.subplot(rows, cols, 1)
        plt.title("PD")
        imshow(self.PD[:, :, s], vmin=0)
        plt.colorbar()

        plt.subplot(rows, cols, 2)
        plt.title("T1 (%s)" % time_unit)
        imshow(self.T1[:, :, s]*time_factor, vmin=0)
        plt.colorbar()

        plt.subplot(rows, cols, 3)
        plt.title("T2 (%s)" % time_unit)
        imshow(self.T2[:, :, s]*time_factor, vmin=0)
        plt.colorbar()

        plt.subplot(rows, cols, 4)
        plt.title("T2' (%s)" % time_unit)
        imshow(self.T2dash[:, :, s]*time_factor, vmin=0)
        plt.colorbar()

        plt.subplot(rows, cols, 5)
        plt.title("D")
        imshow(self.D[:, :, s], vmin=0)
        plt.colorbar()

        # Grid position 6 is left empty, the field maps start a new row
        plt.subplot(rows, cols, 7)
        plt.title("B0")
        imshow(self.B0[:, :, s])
        plt.colorbar()

        plt.subplot(rows, cols, 8)
        plt.title("B1")
        imshow(torch.abs(self.B1[0, :, :, s]))
        plt.colorbar()

        plt.subplot(rows, cols, 9)
        plt.title("coil sens")
        imshow(torch.abs(self.coil_sens[0, :, :, s]), vmin=0)
        plt.colorbar()

        # Conditionally plot masks if plot_masks is True
        if plot_masks:
            for i, (key, mask) in enumerate(self.tissue_masks.items()):
                plt.subplot(rows, cols, 10 + i)
                plt.title(key)
                # Show the same slice selection as for the maps above
                imshow(mask[:, :, s])
                plt.colorbar()

        plt.tight_layout()
        plt.show()

    def plot3D(self, data2print: int = 0) -> None:
        """Print and plot all slices of one selected data stored in this phantom.

        Parameters
        ----------
        data2print : int
            Index of the plotted data in the order PD, T1, T2, T2', D, B0,
            B1, coil sens (default: 0, which selects PD).
        """
        # Only `imshow` is imported from ..util at module level, so the
        # module itself is imported here to make `util.plot3D` resolvable.
        # NOTE(review): assumes ..util provides plot3D - confirm
        from .. import util

        print("VoxelGridPhantom")
        print(f"size = {self.size}")
        print()

        label = ['PD', 'T1', 'T2', "T2'", "D", "B0", "B1", "coil sens"]

        tensors = [
            self.PD, self.T1, self.T2, self.T2dash, self.D, self.B0,
            self.B1.squeeze(0), self.coil_sens
        ]

        print(f"Plotting {label[data2print]}")

        # Drops the coil dimension of single-coil maps
        tensor = tensors[data2print].squeeze(0)

        util.plot3D(tensor, figsize=(20, 5))
        plt.title(label[data2print])
        plt.show()
|
|
563
|
+
|
|
564
|
+
|
|
565
|
+
def recover(mask, sim_data: SimData) -> VoxelGridPhantom:
    """Provided to :class:`SimData` to reverse the ``build()``

    Scatters the per-voxel data of ``sim_data`` back onto the full grid
    described by ``mask``; all voxels outside of the mask are set to zero.

    Parameters
    ----------
    mask : torch.Tensor
        Boolean mask that was used to select the voxels in ``build()``
    sim_data : SimData
        The simulation data to convert back

    Returns
    -------
    VoxelGridPhantom
        A new phantom built from the maps and the size of ``sim_data``.
    """

    mask = mask.to(sim_data.device)

    def to_full(sparse):
        assert sparse.ndim < 3
        if sparse.ndim == 2:
            # Leading dimension (e.g. coils) is kept: [n, *grid]
            full = torch.zeros(
                [sparse.shape[0], *mask.shape],
                dtype=sparse.dtype, device=mask.device)
            full[:, mask] = sparse
        else:
            # Keep the dtype of the data, as in the branch above; the masked
            # assignment requires matching dtypes
            full = torch.zeros(
                mask.shape, dtype=sparse.dtype, device=mask.device)
            full[mask] = sparse
        return full

    return VoxelGridPhantom(
        to_full(sim_data.PD),
        to_full(sim_data.T1),
        to_full(sim_data.T2),
        to_full(sim_data.T2dash),
        to_full(sim_data.D),
        to_full(sim_data.B0),
        to_full(sim_data.B1),
        to_full(sim_data.coil_sens),
        sim_data.size
    )
|
|
592
|
+
|
|
593
|
+
|
|
594
|
+
def _load_tensor_from_mat(file_name: str) -> torch.Tensor:
    """Load the single array stored in a .mat file as a float tensor.

    Entries whose name starts and ends with a double underscore are
    skipped. Raises an ``Exception`` if no other entry is left or if the
    remaining entries do not contain exactly one numpy array.
    """
    contents = io.loadmat(file_name)

    variables = {
        name: value for name, value in contents.items()
        if not name.startswith('__') or not name.endswith('__')
    }
    if not variables:
        raise Exception("The loaded mat file does not contain any variables")

    found = [
        value for value in variables.values()
        if isinstance(value, np.ndarray)
    ]
    if len(found) != 1:
        raise Exception("The loaded mat file must contain exactly one array")

    (array,) = found
    return torch.from_numpy(array).float()
|