miss-alignment 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- miss_alignment/__init__.py +19 -0
- miss_alignment/_cli.py +13 -0
- miss_alignment/alignment/__init__.py +4 -0
- miss_alignment/alignment/correlation.py +112 -0
- miss_alignment/alignment/optimize_global.py +519 -0
- miss_alignment/alignment/optimize_iterative.py +248 -0
- miss_alignment/alignment/optimize_spline.py +404 -0
- miss_alignment/alignment/parallel.py +158 -0
- miss_alignment/alignment/statistics.py +245 -0
- miss_alignment/alignment/tilt_series.py +266 -0
- miss_alignment/alignment/utils.py +32 -0
- miss_alignment/config_template.yaml +51 -0
- miss_alignment/data/__init__.py +5 -0
- miss_alignment/data/_augmentation.py +110 -0
- miss_alignment/data/_reconstruction_worker.py +422 -0
- miss_alignment/data/io.py +210 -0
- miss_alignment/data/shift_generation.py +353 -0
- miss_alignment/data/training_datamodule.py +301 -0
- miss_alignment/data/training_dataset.py +150 -0
- miss_alignment/gradcam/__init__.py +0 -0
- miss_alignment/gradcam/gradcam.py +123 -0
- miss_alignment/models/__init__.py +31 -0
- miss_alignment/models/_compact.py +364 -0
- miss_alignment/models/_resnet.py +209 -0
- miss_alignment/models/models.py +523 -0
- miss_alignment/prepare_stacks.py +167 -0
- miss_alignment/preprocessing.py +151 -0
- miss_alignment/py.typed +5 -0
- miss_alignment/train.py +337 -0
- miss_alignment/utils.py +51 -0
- miss_alignment-0.1.4.dist-info/METADATA +90 -0
- miss_alignment-0.1.4.dist-info/RECORD +35 -0
- miss_alignment-0.1.4.dist-info/WHEEL +4 -0
- miss_alignment-0.1.4.dist-info/entry_points.txt +2 -0
- miss_alignment-0.1.4.dist-info/licenses/LICENSE +28 -0
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
"""She has a chaotic good alignment for tilt-series."""
|
|
2
|
+
|
|
3
|
+
from importlib.metadata import PackageNotFoundError, version
|
|
4
|
+
|
|
5
|
+
try:
    # Distribution metadata only exists once the package has been installed
    # (e.g. via pip); a bare source checkout falls back to a sentinel string.
    __version__ = version("miss_alignment")
except PackageNotFoundError:
    __version__ = "uninstalled"

__author__ = "Marten Chaillet"
__email__ = "martenchaillet@gmail.com"

# Public API of the package. ``cli`` and ``train_miss_align`` are bound by the
# relative imports further down in this module.
__all__ = [
    "__version__",
    "cli",
    "train_miss_align",
]
|
|
17
|
+
|
|
18
|
+
from ._cli import cli
|
|
19
|
+
from .train import train_miss_align
|
miss_alignment/_cli.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
from click import Context
|
|
2
|
+
import typer
|
|
3
|
+
from typer.core import TyperGroup
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class OrderCommands(TyperGroup):
    """Typer command group that reports sub-commands in registration order."""

    def list_commands(self, ctx: Context):
        """Return the registered command names in the order they were added.

        ``self.commands`` maps command names to command objects; iterating it
        yields names in insertion order. This avoids relying on the base-class
        ordering (click's ``Group`` sorts names alphabetically).
        """
        return [command_name for command_name in self.commands]
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# Root Typer application. Sub-commands are listed in registration order
# (via OrderCommands); running without arguments prints the help text, and
# shell-completion installation options are disabled.
cli = typer.Typer(
    cls=OrderCommands,
    no_args_is_help=True,
    add_completion=False,
)

# Shared keyword arguments for options that should prompt the user for a
# value (presumably spread into ``typer.Option(...)`` calls elsewhere).
OPTION_PROMPT_KWARGS = {
    "prompt": True,
    "prompt_required": True,
}
|
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from warpylib.ops import rescale
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def calculate_cross_correlation(
|
|
6
|
+
a: torch.Tensor,
|
|
7
|
+
b: torch.Tensor,
|
|
8
|
+
) -> torch.Tensor:
|
|
9
|
+
"""
|
|
10
|
+
Calculate the 3D cross correlation between volumes of the same size.
|
|
11
|
+
|
|
12
|
+
The position of the maximum relative to the center of the volume gives a shift.
|
|
13
|
+
This is the shift that when applied to `b` best aligns it to `a`.
|
|
14
|
+
|
|
15
|
+
Parameters
|
|
16
|
+
----------
|
|
17
|
+
a : torch.Tensor
|
|
18
|
+
First 3D volume with shape (..., D, H, W)
|
|
19
|
+
b : torch.Tensor
|
|
20
|
+
Second 3D volume with shape (..., D, H, W)
|
|
21
|
+
|
|
22
|
+
Returns
|
|
23
|
+
-------
|
|
24
|
+
torch.Tensor
|
|
25
|
+
3D cross-correlation volume
|
|
26
|
+
"""
|
|
27
|
+
a = (a - a.mean()) / a.std()
|
|
28
|
+
b = (b - b.mean()) / b.std()
|
|
29
|
+
d, h, w = a.shape[-3:]
|
|
30
|
+
fta = torch.fft.rfftn(a, dim=(-3, -2, -1))
|
|
31
|
+
ftb = torch.fft.rfftn(b, dim=(-3, -2, -1))
|
|
32
|
+
result = fta * torch.conj(ftb)
|
|
33
|
+
result = torch.fft.irfftn(result, dim=(-3, -2, -1), s=(d, h, w))
|
|
34
|
+
result = torch.fft.ifftshift(result, dim=(-3, -2, -1))
|
|
35
|
+
result /= d * h * w # normalize the result
|
|
36
|
+
return result
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def get_shift_from_correlation_image(
|
|
40
|
+
correlation_image: torch.Tensor,
|
|
41
|
+
patch_size: int = 16,
|
|
42
|
+
upsample_size: int = 1024,
|
|
43
|
+
) -> torch.Tensor:
|
|
44
|
+
"""
|
|
45
|
+
Extract shift from 3D correlation volume.
|
|
46
|
+
|
|
47
|
+
The shift should be applied to img2 to align with img1.
|
|
48
|
+
Uses Fourier upsampling for sub-voxel accuracy: extracts a region around the
|
|
49
|
+
integer peak, upsamples it using bandwidth-limited Fourier rescaling, and finds
|
|
50
|
+
the peak position in the upsampled volume.
|
|
51
|
+
|
|
52
|
+
Parameters
|
|
53
|
+
----------
|
|
54
|
+
correlation_image : torch.Tensor
|
|
55
|
+
3D correlation volume
|
|
56
|
+
patch_size : int
|
|
57
|
+
Size of the cubic region to extract around the integer peak (must be even).
|
|
58
|
+
Default is 16.
|
|
59
|
+
upsample_size : int
|
|
60
|
+
Size to upsample the extracted region to (must be even). Default is 512.
|
|
61
|
+
|
|
62
|
+
Returns
|
|
63
|
+
-------
|
|
64
|
+
torch.Tensor
|
|
65
|
+
3D shift vector [z, y, x]
|
|
66
|
+
"""
|
|
67
|
+
dtype, device = correlation_image.dtype, correlation_image.device
|
|
68
|
+
shape = torch.tensor(correlation_image.shape, device=device, dtype=dtype)
|
|
69
|
+
center = torch.div(shape, 2, rounding_mode="floor")
|
|
70
|
+
|
|
71
|
+
# Find integer peak location
|
|
72
|
+
flat_idx = torch.argmax(correlation_image)
|
|
73
|
+
peak_coords = torch.tensor(
|
|
74
|
+
torch.unravel_index(flat_idx, correlation_image.shape),
|
|
75
|
+
device=device,
|
|
76
|
+
dtype=dtype,
|
|
77
|
+
)
|
|
78
|
+
|
|
79
|
+
half_patch = patch_size // 2
|
|
80
|
+
|
|
81
|
+
# Check if we can extract a full patch around the peak
|
|
82
|
+
if torch.any(peak_coords < half_patch) or torch.any(
|
|
83
|
+
peak_coords >= shape - half_patch
|
|
84
|
+
):
|
|
85
|
+
return peak_coords - center
|
|
86
|
+
|
|
87
|
+
# Extract patch around peak
|
|
88
|
+
pz, py, px = peak_coords.int().tolist()
|
|
89
|
+
patch = correlation_image[
|
|
90
|
+
pz - half_patch : pz + half_patch,
|
|
91
|
+
py - half_patch : py + half_patch,
|
|
92
|
+
px - half_patch : px + half_patch,
|
|
93
|
+
]
|
|
94
|
+
|
|
95
|
+
# Upsample using Fourier rescaling
|
|
96
|
+
upsampled = rescale(patch, size=(upsample_size, upsample_size, upsample_size))
|
|
97
|
+
|
|
98
|
+
# Find peak in upsampled volume
|
|
99
|
+
up_flat_idx = torch.argmax(upsampled)
|
|
100
|
+
up_peak_coords = torch.tensor(
|
|
101
|
+
torch.unravel_index(up_flat_idx, upsampled.shape),
|
|
102
|
+
device=device,
|
|
103
|
+
dtype=dtype,
|
|
104
|
+
)
|
|
105
|
+
|
|
106
|
+
# Convert upsampled peak position back to original coordinates
|
|
107
|
+
upsample_factor = upsample_size / patch_size
|
|
108
|
+
up_center = upsample_size / 2
|
|
109
|
+
offset = (up_peak_coords - up_center) / upsample_factor
|
|
110
|
+
subpixel_peak = peak_coords + offset
|
|
111
|
+
|
|
112
|
+
return subpixel_peak - center
|
|
@@ -0,0 +1,519 @@
|
|
|
1
|
+
"""Global shift optimization for tilt series alignment.
|
|
2
|
+
|
|
3
|
+
This module provides the core optimization function for per-tilt shifts
|
|
4
|
+
and warping grids (2D and 3D).
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import math
|
|
8
|
+
|
|
9
|
+
import einops
|
|
10
|
+
import torch
|
|
11
|
+
from warpylib import TiltSeries
|
|
12
|
+
from warpylib.cubic_grid import CubicGrid
|
|
13
|
+
from torch_affine_utils.transforms_3d import Ry, Rz
|
|
14
|
+
from torch_affine_utils.transforms_2d import R
|
|
15
|
+
|
|
16
|
+
from miss_alignment.models import MissAlignment
|
|
17
|
+
from miss_alignment.alignment.utils import project_volume_shift_to_image_alignment
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class AlignmentNanError(Exception):
    """Signal that NaN values appeared during alignment optimization.

    Raised from the optimization closure when a parameter or the loss becomes
    NaN, so that ``optimize_shifts`` can reset the tilt series and retry.
    """
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def _grid_attribute_names(
    setting: str | tuple[int, int] | tuple[int, int, int, int],
) -> tuple[str, ...]:
    """Return the names of the TiltSeries grids that ``setting`` optimizes."""
    if setting == "global":
        return ()
    if len(setting) == 2:
        return ("grid_movement_x", "grid_movement_y")
    if len(setting) == 4:
        return ("grid_volume_warp_x", "grid_volume_warp_y", "grid_volume_warp_z")
    return ()


def _snapshot_grid(grid) -> dict | None:
    """Copy a grid's constructor arguments, or return None if it has no values."""
    # The full state (dimensions, values, margins) is needed because the
    # optimization resizes the grid, changing its structure.
    if not hasattr(grid, "values"):
        return None
    return {
        "dimensions": grid.dimensions,
        "values": grid.values.clone(),
        "margins": grid.margins,
    }


def _restore_tilt_series_state(
    tilt_series: TiltSeries,
    offset_y: torch.Tensor,
    offset_x: torch.Tensor,
    grid_snapshots: dict[str, dict | None],
) -> None:
    """Reset tilt-axis offsets and snapshotted grids of ``tilt_series`` in place."""
    tilt_series.tilt_axis_offset_y = offset_y.clone()
    tilt_series.tilt_axis_offset_x = offset_x.clone()
    for name, snapshot in grid_snapshots.items():
        if snapshot is None:
            continue
        setattr(
            tilt_series,
            name,
            CubicGrid(
                dimensions=snapshot["dimensions"],
                values=snapshot["values"].clone(),
                margins=snapshot["margins"],
            ),
        )


def optimize_shifts(
    model: MissAlignment,
    tilt_series: TiltSeries,
    images: torch.Tensor,
    pixel_size: float,
    positions: torch.Tensor,
    setting: str | tuple[int, int] | tuple[int, int, int, int] = "global",
    patch_size: int = 96,
    batch_size: int = 16,
    apply_ctf: bool = True,
    device: str | torch.device = "cpu",
    max_retries: int = 3,
):
    """Find shifts to optimize model score.

    Parameters
    ----------
    model : MissAlignment
        Trained model for scoring reconstructions.
    tilt_series : TiltSeries
        Tilt series to optimize (modified in place).
    images : torch.Tensor
        Preprocessed tilt images.
    pixel_size : float
        Pixel size in Angstroms.
    positions : torch.Tensor
        3D positions to reconstruct and evaluate.
    setting : str | tuple
        Type of alignment to run:
        - 'global': optimizes a single shift per image
        - tuple(int, int) e.g. (3, 3): a single 2D field per tilt image
        - tuple(int, int, int, int) e.g. (3, 3, 2, 10): a volume warp grid
    patch_size : int
        Size of reconstruction patches.
    batch_size : int
        Batch size for reconstruction.
    apply_ctf : bool
        Whether to apply CTF correction.
    device : str | torch.device
        Device to run optimization on.
    max_retries : int
        Maximum number of optimization attempts when NaN is encountered.
        This is the total number of attempts; a value <= 0 makes no attempt.

    Returns
    -------
    tuple[TiltSeries, list[float]]
        Optimized tilt series and list of loss values. If every attempt hits
        NaN, the tilt series is restored to its input alignment, moved to the
        CPU (as on success), and the loss list is ``[inf]``.

    Raises
    ------
    ValueError
        Propagated from the inner optimization, e.g. for an invalid
        ``setting`` or non-positive model precision.

    Notes
    -----
    ``torch.set_float32_matmul_precision`` is set to "highest" for the
    duration of the call and restored on every exit path, including errors.
    """
    # Snapshot everything the optimization mutates, so each retry starts from
    # the caller's state and a total failure can be rolled back.
    original_tilt_axis_offset_y = tilt_series.tilt_axis_offset_y.clone()
    original_tilt_axis_offset_x = tilt_series.tilt_axis_offset_x.clone()
    grid_snapshots = {
        name: _snapshot_grid(getattr(tilt_series, name))
        for name in _grid_attribute_names(setting)
    }

    # Use highest precision for optimization to avoid NaN issues.
    # Training uses "medium" and 16-mixed, but optimization needs full precision.
    original_precision = torch.get_float32_matmul_precision()
    torch.set_float32_matmul_precision("highest")

    try:
        retries_left = max_retries
        while retries_left > 0:
            try:
                return _optimize_shifts_inner(
                    model=model,
                    tilt_series=tilt_series,
                    images=images,
                    pixel_size=pixel_size,
                    positions=positions,
                    setting=setting,
                    patch_size=patch_size,
                    batch_size=batch_size,
                    apply_ctf=apply_ctf,
                    device=device,
                    original_precision=original_precision,
                )
            except AlignmentNanError:
                retries_left -= 1
                if retries_left > 0:
                    # Reset tilt series to original state before retry
                    _restore_tilt_series_state(
                        tilt_series,
                        original_tilt_axis_offset_y,
                        original_tilt_axis_offset_x,
                        grid_snapshots,
                    )
                    print(f"Retrying optimization... (retries left: {retries_left})")

        # All attempts failed: restore the original alignment and return a
        # failure loss. The inner call moved modules to ``device`` in place,
        # so move them back like the success path does.
        _restore_tilt_series_state(
            tilt_series,
            original_tilt_axis_offset_y,
            original_tilt_axis_offset_x,
            grid_snapshots,
        )
        tilt_series.to("cpu")
        model.to("cpu")
        return tilt_series, [float("inf")]
    finally:
        # Never leak the global matmul precision change, even on unexpected errors.
        torch.set_float32_matmul_precision(original_precision)
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
def _optimize_shifts_inner(
    model: MissAlignment,
    tilt_series: TiltSeries,
    images: torch.Tensor,
    pixel_size: float,
    positions: torch.Tensor,
    setting: str | tuple[int, int] | tuple[int, int, int, int],
    patch_size: int,
    batch_size: int,
    apply_ctf: bool,
    device: str | torch.device,
    original_precision: str,
):
    """Run one L-BFGS alignment attempt; may raise AlignmentNanError.

    ``tilt_series`` and ``model`` are moved to ``device`` in place and, on
    success, moved back to the CPU. On success the float32 matmul precision is
    reset to ``original_precision``. On an exception it is left untouched, so
    that a retrying caller keeps running at full precision.

    The objective minimized is the precision-weighted mean model score over
    all ``positions``, with the precisions treated as constants (detached).

    Returns
    -------
    tuple[TiltSeries, list[float]]
        Optimized tilt series and the loss value of every closure evaluation.

    Raises
    ------
    AlignmentNanError
        If NaN appears in the alignment parameters or in the loss.
    ValueError
        If ``setting`` is invalid, or if the model's total precision is <= 0.
    """
    # Validate before touching any state; a stray string must not be mistaken
    # for a grid size just because its length happens to be 2 or 4.
    if isinstance(setting, str):
        if setting != "global":
            raise ValueError(f"Invalid setting for alignment optimization: {setting}")
    elif len(setting) not in (2, 4):
        raise ValueError(f"Invalid setting for alignment optimization: {setting}")

    # move all modules to device in place
    tilt_series.to(device)
    model.to(device)
    model.freeze()
    model.eval()
    images = images.to(device)

    if setting == "global":
        # Per-tilt shifts are optimized as increments on the current alignment.
        initial_tilt_axis_offset_y = tilt_series.tilt_axis_offset_y.clone()
        initial_tilt_axis_offset_x = tilt_series.tilt_axis_offset_x.clone()

        # Tilt closest to zero degrees, used to recenter the result afterwards.
        zero_tilt_idx = tilt_series.angles.abs().argmin()
        initial_zero_tilt_shift_y = initial_tilt_axis_offset_y[zero_tilt_idx].clone()
        initial_zero_tilt_shift_x = initial_tilt_axis_offset_x[zero_tilt_idx].clone()

        shifts_y = torch.zeros_like(
            initial_tilt_axis_offset_y, requires_grad=True, device=device
        )
        shifts_x = torch.zeros_like(
            initial_tilt_axis_offset_x, requires_grad=True, device=device
        )
        parameters = [shifts_y, shifts_x]
    elif len(setting) == 2:  # TODO add case of starting from existent grid
        # One 2D movement grid per tilt image; its values receive gradients.
        grid_dims = [setting[0], setting[1], tilt_series.n_tilts]

        tilt_series.grid_movement_x = tilt_series.grid_movement_x.resize(
            new_size=grid_dims
        ).to(device)
        leaf_variable_x = tilt_series.grid_movement_x.values.requires_grad_(True)
        tilt_series.grid_movement_x = CubicGrid(grid_dims, leaf_variable_x)

        tilt_series.grid_movement_y = tilt_series.grid_movement_y.resize(
            new_size=grid_dims
        ).to(device)
        leaf_variable_y = tilt_series.grid_movement_y.values.requires_grad_(True)
        tilt_series.grid_movement_y = CubicGrid(grid_dims, leaf_variable_y)

        parameters = [leaf_variable_x, leaf_variable_y]
    else:  # len(setting) == 4; TODO add case of starting from existent grid
        tilt_series.grid_volume_warp_x = tilt_series.grid_volume_warp_x.resize(
            new_size=setting
        ).to(device)
        leaf_variable_x = tilt_series.grid_volume_warp_x.values.requires_grad_(True)
        tilt_series.grid_volume_warp_x = CubicGrid(setting, leaf_variable_x)

        tilt_series.grid_volume_warp_y = tilt_series.grid_volume_warp_y.resize(
            new_size=setting
        ).to(device)
        leaf_variable_y = tilt_series.grid_volume_warp_y.values.requires_grad_(True)
        tilt_series.grid_volume_warp_y = CubicGrid(setting, leaf_variable_y)

        tilt_series.grid_volume_warp_z = tilt_series.grid_volume_warp_z.resize(
            new_size=setting
        ).to(device)
        leaf_variable_z = tilt_series.grid_volume_warp_z.values.requires_grad_(True)
        tilt_series.grid_volume_warp_z = CubicGrid(setting, leaf_variable_z)

        parameters = [leaf_variable_x, leaf_variable_y, leaf_variable_z]

    alignment_optimizer = torch.optim.LBFGS(
        parameters,
        line_search_fn="strong_wolfe",
    )

    # Loss value of every closure evaluation (L-BFGS may evaluate many times).
    loss_values = []

    # Determine device type for autocast
    device_type = "cuda" if str(device).startswith("cuda") else "cpu"

    def closure():
        alignment_optimizer.zero_grad()

        # Abort the attempt if the line search produced NaN parameters.
        if any(torch.isnan(p).any() for p in parameters):
            raise AlignmentNanError("NaN in alignment parameters")

        # update the alignments
        if setting == "global":
            tilt_series.tilt_axis_offset_y = initial_tilt_axis_offset_y + shifts_y
            tilt_series.tilt_axis_offset_x = initial_tilt_axis_offset_x + shifts_x

        n_batches = math.ceil(positions.shape[0] / batch_size)
        total_weighted_score = 0.0
        total_precision = 0.0

        # Disable autocast to ensure full precision during optimization
        with torch.amp.autocast(device_type=device_type, enabled=False):
            # Gradient accumulation: process each batch separately
            for b in range(n_batches):
                batch_positions = positions[b * batch_size : (b + 1) * batch_size]

                # reconstruct subvolumes for this batch
                subvolumes = tilt_series.reconstruct_subvolumes_single(
                    tilt_data=images,
                    coords=batch_positions.to(device),
                    pixel_size=pixel_size,
                    size=patch_size,
                    apply_ctf=apply_ctf,
                    oversampling=2.0,
                )

                # ensure normalization per subvolume
                mean = einops.reduce(subvolumes, "n d h w -> n 1 1 1", reduction="mean")
                std = torch.std(subvolumes, dim=(-3, -2, -1), keepdim=True)
                # Add epsilon to prevent division by zero (which causes NaN precision)
                eps = 1e-8
                subvolumes = (subvolumes - mean) / (std + eps)

                # change channel to batch dimension
                subvolumes = einops.rearrange(subvolumes, "b d h w -> b 1 d h w")

                # Get score and (log) precision for this batch; precisions are
                # detached so they act as constant weights.
                batch_scores, batch_log_precisions = model(subvolumes)
                batch_precisions = batch_log_precisions.exp().detach()
                batch_weighted_score = (batch_scores * batch_precisions).sum()

                # Accumulate raw gradients of sum(score * precision); they are
                # normalized by the total precision once it is known.
                batch_weighted_score.backward()

                total_weighted_score += batch_weighted_score.item()
                total_precision += batch_precisions.sum().item()

        if total_precision <= 0:
            raise ValueError(
                f"Total precision is {total_precision}, which is <= 0. "
                "This indicates a problem with the model precision outputs."
            )
        # Precision-weighted average score
        avg_score = total_weighted_score / total_precision

        # Check if loss is NaN and raise error
        if math.isnan(avg_score):
            raise AlignmentNanError("Loss value is NaN")

        # Make the gradient that of the returned loss. L-BFGS with strong-Wolfe
        # line search compares loss values with gradient directional
        # derivatives, so the two must describe the same objective.
        with torch.no_grad():
            for p in parameters:
                if p.grad is not None:
                    p.grad.div_(total_precision)

        loss_values.append(avg_score)
        return avg_score

    # A single LBFGS.step runs up to ``max_iter`` (default 20) inner iterations.
    alignment_optimizer.step(closure)

    if setting == "global":
        # remove gradients and finalize global shifts
        tilt_series.tilt_axis_offset_y = initial_tilt_axis_offset_y + shifts_y.detach()
        tilt_series.tilt_axis_offset_x = initial_tilt_axis_offset_x + shifts_x.detach()

        # Recenter: the optimization may have moved the zero-tilt image. Undo
        # that displacement by treating it as an in-plane 3D shift of the
        # volume and subtracting its projection from every tilt.
        dtype = tilt_series.angles.dtype
        delta_2d = torch.stack(
            [
                (
                    tilt_series.tilt_axis_offset_y[zero_tilt_idx]
                    - initial_zero_tilt_shift_y
                ).to(dtype),
                (
                    tilt_series.tilt_axis_offset_x[zero_tilt_idx]
                    - initial_zero_tilt_shift_x
                ).to(dtype),
            ]
        ).to(device=device)
        inverse_rotation = torch.linalg.inv(
            R(tilt_series.tilt_axis_angles, yx=True)[zero_tilt_idx, :2, :2]
        )
        delta_2d = inverse_rotation @ delta_2d

        # 3D shift with z=0 (in ZYX order)
        shift_3d = torch.cat([delta_2d.new_zeros(1), delta_2d])

        # Compute projection matrices from tilt angles
        r0 = Ry(-tilt_series.angles, zyx=True)
        r1 = Rz(tilt_series.tilt_axis_angles, zyx=True)
        rotation_matrices = r1 @ r0
        projection_matrices = rotation_matrices[..., 1:3, :3]

        # Project the 3D shift to 2D shifts for all tilts
        shifts_2d = project_volume_shift_to_image_alignment(
            shift_3d, projection_matrices
        )

        # Apply the correction: subtract the projected delta shift from all tilts
        tilt_series.tilt_axis_offset_y -= shifts_2d[:, 0]
        tilt_series.tilt_axis_offset_x -= shifts_2d[:, 1]
    elif len(setting) == 2:
        # remove gradients
        tilt_series.grid_movement_x.values = tilt_series.grid_movement_x.values.detach()
        tilt_series.grid_movement_y.values = tilt_series.grid_movement_y.values.detach()
    else:
        # remove gradients
        tilt_series.grid_volume_warp_x.values = (
            tilt_series.grid_volume_warp_x.values.detach()
        )
        tilt_series.grid_volume_warp_y.values = (
            tilt_series.grid_volume_warp_y.values.detach()
        )
        tilt_series.grid_volume_warp_z.values = (
            tilt_series.grid_volume_warp_z.values.detach()
        )

    # move back because they were modified in-place
    tilt_series.to("cpu")
    model.to("cpu")

    # Restore original precision setting
    torch.set_float32_matmul_precision(original_precision)

    return tilt_series, loss_values
|