miss-alignment 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. miss_alignment/__init__.py +19 -0
  2. miss_alignment/_cli.py +13 -0
  3. miss_alignment/alignment/__init__.py +4 -0
  4. miss_alignment/alignment/correlation.py +112 -0
  5. miss_alignment/alignment/optimize_global.py +519 -0
  6. miss_alignment/alignment/optimize_iterative.py +248 -0
  7. miss_alignment/alignment/optimize_spline.py +404 -0
  8. miss_alignment/alignment/parallel.py +158 -0
  9. miss_alignment/alignment/statistics.py +245 -0
  10. miss_alignment/alignment/tilt_series.py +266 -0
  11. miss_alignment/alignment/utils.py +32 -0
  12. miss_alignment/config_template.yaml +51 -0
  13. miss_alignment/data/__init__.py +5 -0
  14. miss_alignment/data/_augmentation.py +110 -0
  15. miss_alignment/data/_reconstruction_worker.py +422 -0
  16. miss_alignment/data/io.py +210 -0
  17. miss_alignment/data/shift_generation.py +353 -0
  18. miss_alignment/data/training_datamodule.py +301 -0
  19. miss_alignment/data/training_dataset.py +150 -0
  20. miss_alignment/gradcam/__init__.py +0 -0
  21. miss_alignment/gradcam/gradcam.py +123 -0
  22. miss_alignment/models/__init__.py +31 -0
  23. miss_alignment/models/_compact.py +364 -0
  24. miss_alignment/models/_resnet.py +209 -0
  25. miss_alignment/models/models.py +523 -0
  26. miss_alignment/prepare_stacks.py +167 -0
  27. miss_alignment/preprocessing.py +151 -0
  28. miss_alignment/py.typed +5 -0
  29. miss_alignment/train.py +337 -0
  30. miss_alignment/utils.py +51 -0
  31. miss_alignment-0.1.4.dist-info/METADATA +90 -0
  32. miss_alignment-0.1.4.dist-info/RECORD +35 -0
  33. miss_alignment-0.1.4.dist-info/WHEEL +4 -0
  34. miss_alignment-0.1.4.dist-info/entry_points.txt +2 -0
  35. miss_alignment-0.1.4.dist-info/licenses/LICENSE +28 -0
@@ -0,0 +1,19 @@
1
+ """She has a chaotic good alignment for tilt-series."""
2
+
3
+ from importlib.metadata import PackageNotFoundError, version
4
+
5
+ try:
6
+ __version__ = version("miss_alignment")
7
+ except PackageNotFoundError:
8
+ __version__ = "uninstalled"
9
+
10
+ __author__ = "Marten Chaillet"
11
+ __email__ = "martenchaillet@gmail.com"
12
+ __all__ = [
13
+ "__version__",
14
+ "cli",
15
+ "train_miss_align",
16
+ ]
17
+
18
+ from ._cli import cli
19
+ from .train import train_miss_align
miss_alignment/_cli.py ADDED
@@ -0,0 +1,13 @@
1
+ from click import Context
2
+ import typer
3
+ from typer.core import TyperGroup
4
+
5
+
6
class OrderCommands(TyperGroup):
    """Typer command group that lists commands in registration order."""

    def list_commands(self, ctx: Context):
        """Return the command names in the order they were registered."""
        # self.commands is a mapping; iterating it yields names in
        # insertion (registration) order.
        return [name for name in self.commands]
10
+
11
+
12
# Top-level Typer application; commands are registered onto this object
# elsewhere in the package. Listing order follows registration order.
cli = typer.Typer(cls=OrderCommands, add_completion=False, no_args_is_help=True)
# Shared kwargs for CLI options that must always prompt the user.
OPTION_PROMPT_KWARGS = {"prompt": True, "prompt_required": True}
@@ -0,0 +1,4 @@
1
# Public API of the alignment subpackage.
from .tilt_series import evaluate_tilt_series
from .parallel import run_alignment_parallel

__all__ = ["evaluate_tilt_series", "run_alignment_parallel"]
@@ -0,0 +1,112 @@
1
+ import torch
2
+ from warpylib.ops import rescale
3
+
4
+
5
def calculate_cross_correlation(
    a: torch.Tensor,
    b: torch.Tensor,
) -> torch.Tensor:
    """
    Calculate the 3D cross correlation between volumes of the same size.

    The position of the maximum relative to the center of the volume gives a shift.
    This is the shift that when applied to `b` best aligns it to `a`.

    Parameters
    ----------
    a : torch.Tensor
        First 3D volume with shape (..., D, H, W)
    b : torch.Tensor
        Second 3D volume with shape (..., D, H, W)

    Returns
    -------
    torch.Tensor
        3D cross-correlation volume
    """
    spatial_dims = (-3, -2, -1)
    depth, height, width = a.shape[-3:]

    # Standardize both volumes (zero mean, unit std over all elements).
    a = (a - a.mean()) / a.std()
    b = (b - b.mean()) / b.std()

    # Correlation theorem: multiply spectrum of `a` by the conjugate
    # spectrum of `b`, then invert.
    product = torch.fft.rfftn(a, dim=spatial_dims) * torch.conj(
        torch.fft.rfftn(b, dim=spatial_dims)
    )
    correlation = torch.fft.irfftn(
        product, dim=spatial_dims, s=(depth, height, width)
    )

    # Move the zero-lag term to the center of the volume and normalize
    # by the number of elements.
    correlation = torch.fft.ifftshift(correlation, dim=spatial_dims)
    return correlation / (depth * height * width)
37
+
38
+
39
+ def get_shift_from_correlation_image(
40
+ correlation_image: torch.Tensor,
41
+ patch_size: int = 16,
42
+ upsample_size: int = 1024,
43
+ ) -> torch.Tensor:
44
+ """
45
+ Extract shift from 3D correlation volume.
46
+
47
+ The shift should be applied to img2 to align with img1.
48
+ Uses Fourier upsampling for sub-voxel accuracy: extracts a region around the
49
+ integer peak, upsamples it using bandwidth-limited Fourier rescaling, and finds
50
+ the peak position in the upsampled volume.
51
+
52
+ Parameters
53
+ ----------
54
+ correlation_image : torch.Tensor
55
+ 3D correlation volume
56
+ patch_size : int
57
+ Size of the cubic region to extract around the integer peak (must be even).
58
+ Default is 16.
59
+ upsample_size : int
60
+ Size to upsample the extracted region to (must be even). Default is 512.
61
+
62
+ Returns
63
+ -------
64
+ torch.Tensor
65
+ 3D shift vector [z, y, x]
66
+ """
67
+ dtype, device = correlation_image.dtype, correlation_image.device
68
+ shape = torch.tensor(correlation_image.shape, device=device, dtype=dtype)
69
+ center = torch.div(shape, 2, rounding_mode="floor")
70
+
71
+ # Find integer peak location
72
+ flat_idx = torch.argmax(correlation_image)
73
+ peak_coords = torch.tensor(
74
+ torch.unravel_index(flat_idx, correlation_image.shape),
75
+ device=device,
76
+ dtype=dtype,
77
+ )
78
+
79
+ half_patch = patch_size // 2
80
+
81
+ # Check if we can extract a full patch around the peak
82
+ if torch.any(peak_coords < half_patch) or torch.any(
83
+ peak_coords >= shape - half_patch
84
+ ):
85
+ return peak_coords - center
86
+
87
+ # Extract patch around peak
88
+ pz, py, px = peak_coords.int().tolist()
89
+ patch = correlation_image[
90
+ pz - half_patch : pz + half_patch,
91
+ py - half_patch : py + half_patch,
92
+ px - half_patch : px + half_patch,
93
+ ]
94
+
95
+ # Upsample using Fourier rescaling
96
+ upsampled = rescale(patch, size=(upsample_size, upsample_size, upsample_size))
97
+
98
+ # Find peak in upsampled volume
99
+ up_flat_idx = torch.argmax(upsampled)
100
+ up_peak_coords = torch.tensor(
101
+ torch.unravel_index(up_flat_idx, upsampled.shape),
102
+ device=device,
103
+ dtype=dtype,
104
+ )
105
+
106
+ # Convert upsampled peak position back to original coordinates
107
+ upsample_factor = upsample_size / patch_size
108
+ up_center = upsample_size / 2
109
+ offset = (up_peak_coords - up_center) / upsample_factor
110
+ subpixel_peak = peak_coords + offset
111
+
112
+ return subpixel_peak - center
@@ -0,0 +1,519 @@
1
+ """Global shift optimization for tilt series alignment.
2
+
3
+ This module provides the core optimization function for per-tilt shifts
4
+ and warping grids (2D and 3D).
5
+ """
6
+
7
+ import math
8
+
9
+ import einops
10
+ import torch
11
+ from warpylib import TiltSeries
12
+ from warpylib.cubic_grid import CubicGrid
13
+ from torch_affine_utils.transforms_3d import Ry, Rz
14
+ from torch_affine_utils.transforms_2d import R
15
+
16
+ from miss_alignment.models import MissAlignment
17
+ from miss_alignment.alignment.utils import project_volume_shift_to_image_alignment
18
+
19
+
20
+ class AlignmentNanError(Exception):
21
+ """Raised when NaN values are detected during alignment optimization."""
22
+
23
+ pass
24
+
25
+
26
def optimize_shifts(
    model: MissAlignment,
    tilt_series: TiltSeries,
    images: torch.Tensor,
    pixel_size: float,
    positions: torch.Tensor,
    setting: str | tuple[int, int] | tuple[int, int, int, int] = "global",
    patch_size: int = 96,
    batch_size: int = 16,
    apply_ctf: bool = True,
    device: str | torch.device = "cpu",
    max_retries: int = 3,
) -> tuple[TiltSeries, list[float]]:
    """Find shifts to optimize model score.

    Wraps `_optimize_shifts_inner` with a snapshot/retry loop: the tilt-series
    alignment state is snapshotted up front, and if the inner optimization
    raises `AlignmentNanError` the state is restored and the run retried, up
    to `max_retries` attempts. On total failure the original state is
    restored and a loss of `inf` is returned.

    Parameters
    ----------
    model : MissAlignment
        Trained model for scoring reconstructions.
    tilt_series : TiltSeries
        Tilt series to optimize. NOTE: mutated in place by the optimization.
    images : torch.Tensor
        Preprocessed tilt images.
    pixel_size : float
        Pixel size in Angstroms.
    positions : torch.Tensor
        3D positions to reconstruct and evaluate.
    setting : str | tuple
        Type of alignment to run:
        - 'global': optimizes a single shift per image
        - tuple(int, int) e.g. (3, 3): a single 2D field per tilt image
        - tuple(int, int, int, int) e.g. (3, 3, 2, 10): a volume warp grid
    patch_size : int
        Size of reconstruction patches.
    batch_size : int
        Batch size for reconstruction.
    apply_ctf : bool
        Whether to apply CTF correction.
    device : str | torch.device
        Device to run optimization on.
    max_retries : int
        Maximum number of retry attempts if NaN is encountered.

    Returns
    -------
    tuple[TiltSeries, list[float]]
        Optimized tilt series and list of loss values.
    """
    # Store original tilt series state in case all retries fail
    original_tilt_axis_offset_y = tilt_series.tilt_axis_offset_y.clone()
    original_tilt_axis_offset_x = tilt_series.tilt_axis_offset_x.clone()

    # Store original grid states if applicable.
    # We need to store complete grid state (dimensions, values, margins)
    # because resize() changes the grid structure.
    if setting != "global" and len(setting) == 2:
        # 2D movement grids (one field per tilt image)
        has_grid_x = hasattr(tilt_series.grid_movement_x, "values")
        has_grid_y = hasattr(tilt_series.grid_movement_y, "values")
        original_grid_x = (
            {
                "dimensions": tilt_series.grid_movement_x.dimensions,
                "values": tilt_series.grid_movement_x.values.clone(),
                "margins": tilt_series.grid_movement_x.margins,
            }
            if has_grid_x
            else None
        )
        original_grid_y = (
            {
                "dimensions": tilt_series.grid_movement_y.dimensions,
                "values": tilt_series.grid_movement_y.values.clone(),
                "margins": tilt_series.grid_movement_y.margins,
            }
            if has_grid_y
            else None
        )
    elif setting != "global" and len(setting) == 4:
        # 3D volume-warp grids
        has_grid_x = hasattr(tilt_series.grid_volume_warp_x, "values")
        has_grid_y = hasattr(tilt_series.grid_volume_warp_y, "values")
        has_grid_z = hasattr(tilt_series.grid_volume_warp_z, "values")
        original_grid_x = (
            {
                "dimensions": tilt_series.grid_volume_warp_x.dimensions,
                "values": tilt_series.grid_volume_warp_x.values.clone(),
                "margins": tilt_series.grid_volume_warp_x.margins,
            }
            if has_grid_x
            else None
        )
        original_grid_y = (
            {
                "dimensions": tilt_series.grid_volume_warp_y.dimensions,
                "values": tilt_series.grid_volume_warp_y.values.clone(),
                "margins": tilt_series.grid_volume_warp_y.margins,
            }
            if has_grid_y
            else None
        )
        original_grid_z = (
            {
                "dimensions": tilt_series.grid_volume_warp_z.dimensions,
                "values": tilt_series.grid_volume_warp_z.values.clone(),
                "margins": tilt_series.grid_volume_warp_z.margins,
            }
            if has_grid_z
            else None
        )

    # Use highest precision for optimization to avoid NaN issues.
    # Training uses "medium" and 16-mixed, but optimization needs full
    # precision. The inner function restores the setting on success; the
    # failure path below restores it otherwise.
    original_precision = torch.get_float32_matmul_precision()
    torch.set_float32_matmul_precision("highest")

    # Retry loop: a successful attempt returns directly from inside the try.
    retries_left = max_retries
    while retries_left > 0:
        try:
            return _optimize_shifts_inner(
                model=model,
                tilt_series=tilt_series,
                images=images,
                pixel_size=pixel_size,
                positions=positions,
                setting=setting,
                patch_size=patch_size,
                batch_size=batch_size,
                apply_ctf=apply_ctf,
                device=device,
                original_precision=original_precision,
            )
        except AlignmentNanError:
            retries_left -= 1
            if retries_left > 0:
                # Reset tilt series to original state before retry
                # (clones keep the snapshots intact for later retries).
                tilt_series.tilt_axis_offset_y = original_tilt_axis_offset_y.clone()
                tilt_series.tilt_axis_offset_x = original_tilt_axis_offset_x.clone()

                if setting != "global" and len(setting) == 2:
                    if original_grid_x is not None:
                        tilt_series.grid_movement_x = CubicGrid(
                            dimensions=original_grid_x["dimensions"],
                            values=original_grid_x["values"].clone(),
                            margins=original_grid_x["margins"],
                        )
                    if original_grid_y is not None:
                        tilt_series.grid_movement_y = CubicGrid(
                            dimensions=original_grid_y["dimensions"],
                            values=original_grid_y["values"].clone(),
                            margins=original_grid_y["margins"],
                        )
                elif setting != "global" and len(setting) == 4:
                    if original_grid_x is not None:
                        tilt_series.grid_volume_warp_x = CubicGrid(
                            dimensions=original_grid_x["dimensions"],
                            values=original_grid_x["values"].clone(),
                            margins=original_grid_x["margins"],
                        )
                    if original_grid_y is not None:
                        tilt_series.grid_volume_warp_y = CubicGrid(
                            dimensions=original_grid_y["dimensions"],
                            values=original_grid_y["values"].clone(),
                            margins=original_grid_y["margins"],
                        )
                    if original_grid_z is not None:
                        tilt_series.grid_volume_warp_z = CubicGrid(
                            dimensions=original_grid_z["dimensions"],
                            values=original_grid_z["values"].clone(),
                            margins=original_grid_z["margins"],
                        )
                print(f"Retrying optimization... (retries left: {retries_left})")

    # All retries failed, restore original state and return failure.
    # Stored tensors are handed back directly (no clone) since no further
    # attempts will mutate them.
    tilt_series.tilt_axis_offset_y = original_tilt_axis_offset_y
    tilt_series.tilt_axis_offset_x = original_tilt_axis_offset_x

    if setting != "global" and len(setting) == 2:
        if original_grid_x is not None:
            tilt_series.grid_movement_x = CubicGrid(
                dimensions=original_grid_x["dimensions"],
                values=original_grid_x["values"],
                margins=original_grid_x["margins"],
            )
        if original_grid_y is not None:
            tilt_series.grid_movement_y = CubicGrid(
                dimensions=original_grid_y["dimensions"],
                values=original_grid_y["values"],
                margins=original_grid_y["margins"],
            )
    elif setting != "global" and len(setting) == 4:
        if original_grid_x is not None:
            tilt_series.grid_volume_warp_x = CubicGrid(
                dimensions=original_grid_x["dimensions"],
                values=original_grid_x["values"],
                margins=original_grid_x["margins"],
            )
        if original_grid_y is not None:
            tilt_series.grid_volume_warp_y = CubicGrid(
                dimensions=original_grid_y["dimensions"],
                values=original_grid_y["values"],
                margins=original_grid_y["margins"],
            )
        if original_grid_z is not None:
            tilt_series.grid_volume_warp_z = CubicGrid(
                dimensions=original_grid_z["dimensions"],
                values=original_grid_z["values"],
                margins=original_grid_z["margins"],
            )

    # Restore original precision setting
    torch.set_float32_matmul_precision(original_precision)

    # Return original tilt series with failure loss
    return tilt_series, [float("inf")]
+
240
+
241
def _optimize_shifts_inner(
    model: MissAlignment,
    tilt_series: TiltSeries,
    images: torch.Tensor,
    pixel_size: float,
    positions: torch.Tensor,
    setting: str | tuple[int, int] | tuple[int, int, int, int],
    patch_size: int,
    batch_size: int,
    apply_ctf: bool,
    device: str | torch.device,
    original_precision: str,
) -> tuple[TiltSeries, list[float]]:
    """Inner optimization function that can raise AlignmentNanError.

    Runs LBFGS over the alignment parameters selected by `setting`, scoring
    model-reconstructed subvolumes at `positions`. Mutates `tilt_series` in
    place and moves it (and `model`) back to CPU before returning.

    Raises
    ------
    AlignmentNanError
        If NaN appears in the parameters or in the loss during optimization.

    Returns
    -------
    tuple[TiltSeries, list[float]]
        Optimized tilt series and list of loss values.
    """
    # move all modules to device in place
    tilt_series.to(device)
    model.to(device)
    model.freeze()
    model.eval()
    # move images to device
    images = images.to(device)

    parameters = None
    if setting == "global":
        # store the initial tilt_series alignment
        initial_tilt_axis_offset_y = tilt_series.tilt_axis_offset_y.clone()
        initial_tilt_axis_offset_x = tilt_series.tilt_axis_offset_x.clone()

        # Find the index of the tilt closest to zero degrees for recentering
        zero_tilt_idx = tilt_series.angles.abs().argmin()
        initial_zero_tilt_shift_y = initial_tilt_axis_offset_y[zero_tilt_idx].clone()
        initial_zero_tilt_shift_x = initial_tilt_axis_offset_x[zero_tilt_idx].clone()

        # create the alignment parameters
        # NOTE(review): both tensors are shaped after offset_x — presumably
        # offset_y has the same shape; confirm.
        shifts_y = torch.zeros_like(
            initial_tilt_axis_offset_x,
            requires_grad=True,
            device=device,
        )
        shifts_x = torch.zeros_like(
            initial_tilt_axis_offset_x,
            requires_grad=True,
            device=device,
        )
        parameters = [shifts_y, shifts_x]
    elif len(setting) == 2:  # TODO add case of starting from existent grid
        # movement grids - these should receive gradients
        grid_dims = [setting[0], setting[1], tilt_series.n_tilts]

        # Resize, then rebuild the grid around a leaf tensor so LBFGS can
        # optimize the values directly.
        tilt_series.grid_movement_x = tilt_series.grid_movement_x.resize(
            new_size=grid_dims
        ).to(device)
        leaf_variable_x = tilt_series.grid_movement_x.values.requires_grad_(True)
        tilt_series.grid_movement_x = CubicGrid(grid_dims, leaf_variable_x)

        tilt_series.grid_movement_y = tilt_series.grid_movement_y.resize(
            new_size=grid_dims
        ).to(device)
        leaf_variable_y = tilt_series.grid_movement_y.values.requires_grad_(True)
        tilt_series.grid_movement_y = CubicGrid(grid_dims, leaf_variable_y)

        parameters = [leaf_variable_x, leaf_variable_y]
    elif len(setting) == 4:  # TODO add case of starting from existent grid
        # Same leaf-tensor rebuild as above, for the three volume-warp grids.
        tilt_series.grid_volume_warp_x = tilt_series.grid_volume_warp_x.resize(
            new_size=setting
        ).to(device)
        leaf_variable_x = tilt_series.grid_volume_warp_x.values.requires_grad_(True)
        tilt_series.grid_volume_warp_x = CubicGrid(setting, leaf_variable_x)

        tilt_series.grid_volume_warp_y = tilt_series.grid_volume_warp_y.resize(
            new_size=setting
        ).to(device)
        leaf_variable_y = tilt_series.grid_volume_warp_y.values.requires_grad_(True)
        tilt_series.grid_volume_warp_y = CubicGrid(setting, leaf_variable_y)

        tilt_series.grid_volume_warp_z = tilt_series.grid_volume_warp_z.resize(
            new_size=setting
        ).to(device)
        leaf_variable_z = tilt_series.grid_volume_warp_z.values.requires_grad_(True)
        tilt_series.grid_volume_warp_z = CubicGrid(setting, leaf_variable_z)

        parameters = [
            leaf_variable_x,
            leaf_variable_y,
            leaf_variable_z,
        ]
    else:
        raise ValueError(f"Invalid setting for alignment optimization: {setting}")

    alignment_optimizer = torch.optim.LBFGS(
        parameters,
        line_search_fn="strong_wolfe",
    )

    # Initialize list to store loss values
    loss_values = []

    # Determine device type for autocast
    device_type = "cuda" if str(device).startswith("cuda") else "cpu"

    def closure():
        # LBFGS closure: re-evaluates loss and gradients; may be called
        # multiple times per optimizer step during line search.
        alignment_optimizer.zero_grad()

        # Check for NaN in parameters before computing loss.
        # If found, abort this attempt by raising AlignmentNanError so the
        # outer retry loop can restore state and try again.
        nan_in_params = False
        if setting == "global":
            if torch.isnan(shifts_x).any() or torch.isnan(shifts_y).any():
                nan_in_params = True
        elif len(setting) == 2:
            if torch.isnan(leaf_variable_x).any() or torch.isnan(leaf_variable_y).any():
                nan_in_params = True
        elif len(setting) == 4:
            if (
                torch.isnan(leaf_variable_x).any()
                or torch.isnan(leaf_variable_y).any()
                or torch.isnan(leaf_variable_z).any()
            ):
                nan_in_params = True

        if nan_in_params:
            raise AlignmentNanError

        # update the alignments
        if setting == "global":
            tilt_series.tilt_axis_offset_y = initial_tilt_axis_offset_y + shifts_y
            tilt_series.tilt_axis_offset_x = initial_tilt_axis_offset_x + shifts_x

        batches = int(math.ceil(positions.shape[0] / batch_size))
        total_samples = positions.shape[0]
        total_weighted_score = 0.0
        total_precision = 0.0

        # Disable autocast to ensure full precision during optimization
        with torch.amp.autocast(device_type=device_type, enabled=False):
            # Use gradient accumulation: process each batch separately
            for b in range(batches):
                if b == batches - 1:
                    batch_positions = positions[b * batch_size :]
                else:
                    batch_positions = positions[b * batch_size : (b + 1) * batch_size]

                current_batch_size = batch_positions.shape[0]

                # reconstruct subvolumes for this batch
                subvolumes = tilt_series.reconstruct_subvolumes_single(
                    tilt_data=images,
                    coords=batch_positions.to(device),
                    pixel_size=pixel_size,
                    size=patch_size,
                    apply_ctf=apply_ctf,
                    oversampling=2.0,
                )

                # ensure normalization per subvolume
                mean = einops.reduce(subvolumes, "n d h w -> n 1 1 1", reduction="mean")
                std = torch.std(subvolumes, dim=(-3, -2, -1), keepdim=True)
                # Add epsilon to prevent division by zero (which causes NaN precision)
                eps = 1e-8
                subvolumes = (subvolumes - mean) / (std + eps)

                # change channel to batch dimension
                subvolumes = einops.rearrange(subvolumes, "b d h w -> b 1 d h w")

                # Get score and precision for this batch
                batch_scores, batch_log_precisions = model(subvolumes)

                # Detach precisions: they weight the loss but do not receive
                # gradients themselves.
                batch_precisions = batch_log_precisions.exp().detach()

                # Precision-weighted average score for this batch
                batch_weighted_score = (batch_scores * batch_precisions).sum()
                batch_precision_sum = batch_precisions.sum()

                # Weight by batch size for proper gradient accumulation
                weighted_loss = batch_weighted_score * (
                    current_batch_size / total_samples
                )

                # Backward pass for this batch (gradients accumulate)
                weighted_loss.backward()

                # Accumulate for precision-weighted average
                total_weighted_score += batch_weighted_score.item()
                total_precision += batch_precision_sum.item()

        # Precision-weighted average score
        if total_precision <= 0:
            raise ValueError(
                f"Total precision is {total_precision}, which is <= 0. "
                "This indicates a problem with the model precision outputs."
            )
        avg_score = total_weighted_score / total_precision

        # Check if loss is NaN and raise error
        if math.isnan(avg_score):
            raise AlignmentNanError("Loss value is NaN")

        loss_values.append(avg_score)

        return avg_score

    # Single optimizer step: LBFGS with strong-Wolfe line search evaluates
    # the closure repeatedly within one step.
    n_iters = 1
    for x in range(n_iters):
        alignment_optimizer.step(closure)

    if setting == "global":
        # remove gradients and finalize global shifts
        tilt_series.tilt_axis_offset_y = initial_tilt_axis_offset_y + shifts_y.detach()
        tilt_series.tilt_axis_offset_x = initial_tilt_axis_offset_x + shifts_x.detach()

        # Recenter alignment: set the shift at zero tilt to match initial zero tilt.
        # Get the current shift at the zero tilt
        current_zero_tilt_shift_y = tilt_series.tilt_axis_offset_y[zero_tilt_idx]
        current_zero_tilt_shift_x = tilt_series.tilt_axis_offset_x[zero_tilt_idx]

        # Calculate the difference from initial to current at zero tilt
        delta_shift_y = current_zero_tilt_shift_y - initial_zero_tilt_shift_y
        delta_shift_x = current_zero_tilt_shift_x - initial_zero_tilt_shift_x

        # Rotate the 2D delta into the untilted frame using the inverse of
        # the zero-tilt in-plane rotation.
        delta_shift_2d = torch.tensor(
            [delta_shift_y, delta_shift_x],
            device=device,
            dtype=tilt_series.angles.dtype,
        )
        m_2d = R(tilt_series.tilt_axis_angles, yx=True)
        m_2d = torch.linalg.inv(m_2d[zero_tilt_idx, :2, :2])
        delta_shift_2d = m_2d @ einops.rearrange(delta_shift_2d, "x -> x 1")
        delta_shift_y, delta_shift_x = delta_shift_2d[0], delta_shift_2d[1]

        # Create a 3D shift tensor with z=0 (in ZYX order)
        shift_3d = torch.tensor(
            [0.0, delta_shift_y, delta_shift_x],
            device=device,
            dtype=tilt_series.angles.dtype,
        )

        # Compute projection matrices from tilt angles
        r0 = Ry(-tilt_series.angles, zyx=True)
        r1 = Rz(tilt_series.tilt_axis_angles, zyx=True)
        rotation_matrices = r1 @ r0
        projection_matrices = rotation_matrices[..., 1:3, :3]

        # Project the 3D shift to 2D shifts for all tilts
        shifts_2d = project_volume_shift_to_image_alignment(
            shift_3d, projection_matrices
        )

        # Apply the correction: subtract the projected delta shift from all tilts
        tilt_series.tilt_axis_offset_y -= shifts_2d[:, 0]
        tilt_series.tilt_axis_offset_x -= shifts_2d[:, 1]
    elif len(setting) == 2:
        # remove gradients
        tilt_series.grid_movement_x.values = tilt_series.grid_movement_x.values.detach()
        tilt_series.grid_movement_y.values = tilt_series.grid_movement_y.values.detach()
    elif len(setting) == 4:
        # remove gradients
        tilt_series.grid_volume_warp_x.values = (
            tilt_series.grid_volume_warp_x.values.detach()
        )
        tilt_series.grid_volume_warp_y.values = (
            tilt_series.grid_volume_warp_y.values.detach()
        )
        tilt_series.grid_volume_warp_z.values = (
            tilt_series.grid_volume_warp_z.values.detach()
        )
    # move back because they were modified in-place
    tilt_series.to("cpu")
    model.to("cpu")

    # Restore original precision setting
    torch.set_float32_matmul_precision(original_precision)

    return tilt_series, loss_values