midas-diffract 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1559 @@
1
+ """Generic differentiable forward model for all HEDM modalities (NF, FF, pf).
2
+
3
+ The core Bragg geometry, omega solver, and detector projection are identical
4
+ across modalities. Modality differences (detector distance, output mode, scan
5
+ strategy) are handled via configuration, not subclassing.
6
+
7
+ Physics pipeline:
8
+ euler_angles, positions [, lattice_params]
9
+ -> orientation matrices (euler2mat)
10
+ -> G-vectors in crystal frame (calc_bragg_geometry)
11
+ -> [optional: strained G-vectors] (correct_hkls_latc)
12
+ -> omega solver (quadratic) (calc_bragg_geometry)
13
+ -> eta computation (calc_bragg_geometry)
14
+ -> position-dependent projection (project_to_detector)
15
+ -> validity filtering (project_to_detector)
16
+ -> SpotDescriptors
17
+
18
+ Output modes:
19
+ SpotDescriptors -> predict_images() [NF: Gaussian splatting]
20
+ SpotDescriptors -> predict_spot_coords() [FF/pf: angular coordinates]
21
+
22
+ Reference C code:
23
+ CorrectHKLsLatC: FF_HEDM/src/FitPosOrStrainsDoubleDataset.c:214-252
24
+ CalcDiffrSpots_Furnace: NF_HEDM/src/CalcDiffractionSpots.c:87-183
25
+ DisplacementSpots: NF_HEDM/src/SharedFuncsFit.c:269-292
26
+ Beam proximity filter: FF_HEDM/src/FitOrStrainsScanningOMP.c:1048-1058
27
+ """
28
+
29
+ import math
30
+ from dataclasses import dataclass, field
31
+ from typing import Optional, Tuple
32
+
33
+ import torch
34
+ import torch.nn as nn
35
+ import torch.nn.functional as F
36
+
37
+
38
+ # ---------------------------------------------------------------------------
39
+ # Configuration data classes
40
+ # ---------------------------------------------------------------------------
41
+
42
@dataclass
class HEDMGeometry:
    """Detector and scan geometry configuration.

    Units: distances in micrometers, angles in degrees (converted internally).

    NF-HEDM uses multiple detector distances (nDistances=2-4), each with its
    own Lsd, y_BC, z_BC. FF-HEDM and pf-HEDM use a single distance.
    When sequences (lists or tuples) are provided for Lsd/y_BC/z_BC, each
    entry is one "layer" (detector distance). A spot is valid only if it
    falls on the detector at **every** distance (the AllDistsFound logic in
    the C code).
    """
    Lsd: "float | list[float]"       # Sample-detector distance(s) (um)
    y_BC: "float | list[float]"      # Beam center y (pixels), per distance
    z_BC: "float | list[float]"      # Beam center z (pixels), per distance
    px: float                        # Pixel size (um) -- shared across distances
    omega_start: float               # Omega start (degrees)
    omega_step: float                # Omega step (degrees, may be negative)
    n_frames: int                    # Frames per distance (NrFilesPerDistance)
    n_pixels_y: int                  # Detector pixels in y
    n_pixels_z: int                  # Detector pixels in z
    min_eta: float                   # Minimum eta angle (degrees)
    wavelength: float = 0.0          # X-ray wavelength (Angstroms)
    # Detector tilts (degrees). Applied only in NF mode (flip_y=False), which
    # compares predictions to raw detector images at pixel level. FF and
    # pf-HEDM workflows apply a DetCor correction at peak-finding time, so
    # their centroids are already tilt- and distortion-corrected; the
    # forward model therefore ignores these fields when flip_y=True to avoid
    # double-correcting.
    tx: float = 0.0
    ty: float = 0.0
    tz: float = 0.0
    flip_y: bool = True              # FF/PF: True (DetHor = yBC - ydet/px).
                                     # NF: False (pixel = yBC + ydet/px).
                                     # Validated against C code conventions.

    @property
    def n_distances(self) -> int:
        """Number of detector distances: ``len(Lsd)`` for a sequence, else 1."""
        return len(self._as_list("Lsd"))

    def _as_list(self, attr):
        """Return the attribute named ``attr`` as a list.

        A list is returned as-is (same object), a tuple is converted to a
        list, and any other value is wrapped in a one-element list. Tuples
        are unpacked rather than wrapped so that ``Lsd=(a, b)`` describes two
        distances instead of a single nested entry.
        """
        v = getattr(self, attr)
        if isinstance(v, list):
            return v
        if isinstance(v, tuple):
            return list(v)
        return [v]
85
+
86
+
87
@dataclass
class ScanConfig:
    """Beam-translation scan description for pf-HEDM.

    A pf-HEDM measurement moves the pencil beam through a series of
    Y-positions and records one complete omega sweep at each of them.
    NF-HEDM and FF-HEDM keep the beam fixed and therefore do not use this
    configuration.
    """
    # (S,) beam y-position of every scan, in micrometers
    beam_positions: torch.Tensor
    # Beam height, in micrometers
    beam_size: float
97
+
98
+
99
@dataclass
class TriVoxelConfig:
    """Per-voxel description of the NF-HEDM triangular grid.

    Every NF-HEDM voxel is an equilateral triangle given by its center
    ``(x, y)``, its ``edge_length`` and an orientation flag ``ud``
    (pointing up or down). The vertices follow from:

    .. code-block:: text

        gs  = edge_length / 2
        dy1 = edge_length / sqrt(3)       (flipped if ud < 0)
        dy2 = -edge_length / (2*sqrt(3))

        V0 = (x,      y + dy1)
        V1 = (x - gs, y + dy2)
        V2 = (x + gs, y + dy2)

    Matches ``simulateNF.c`` lines 556-572.
    """
    # (N,) edge length of each voxel, in micrometers
    edge_lengths: torch.Tensor
    # (N,) orientation flag of each voxel: +1 (up) or -1 (down)
    ud: torch.Tensor
121
+
122
+
123
@dataclass
class SpotDescriptors:
    """Container for everything the forward model predicts about spots.

    Angles are stored in radians and pixel coordinates are fractional.
    Tensors follow the shape convention ``(..., K, M)`` with ``K = 2*N``
    (each position yields two omega solutions) and ``M`` the number of
    HKL reflections.

    In multi-distance NF-HEDM, ``y_pixel``, ``z_pixel`` and ``layer_valid``
    carry an additional leading ``D`` (n_distances) axis, and ``valid``
    merges the angular validity with the found-at-ALL-distances criterion.
    """
    # (..., K, M) radians
    omega: torch.Tensor
    # (..., K, M) radians
    eta: torch.Tensor
    # (..., K, M) radians
    two_theta: torch.Tensor
    # (D, ..., K, M) or (..., K, M) fractional pixel
    y_pixel: torch.Tensor
    # (D, ..., K, M) or (..., K, M) fractional pixel
    z_pixel: torch.Tensor
    # (..., K, M) fractional frame (same at all distances)
    frame_nr: torch.Tensor
    # (..., K, M) float mask (1=valid at ALL distances)
    valid: torch.Tensor
    # (D, ..., K, M) per-distance validity
    layer_valid: Optional[torch.Tensor] = None
    # (..., S, K, M) per-beam-position validity (pf)
    scan_mask: Optional[torch.Tensor] = None
144
+
145
+
146
+ # ---------------------------------------------------------------------------
147
+ # Forward model
148
+ # ---------------------------------------------------------------------------
149
+
150
+ class HEDMForwardModel(nn.Module):
151
+ """Generic differentiable forward model for NF / FF / pf-HEDM.
152
+
153
+ Parameters
154
+ ----------
155
+ hkls : Tensor (M, 3)
156
+ Reciprocal-space G-vectors in Cartesian coordinates (1/Angstroms).
157
+ These are the *nominal* (unstrained) G-vectors, already transformed
158
+ through the B matrix for the reference lattice.
159
+ thetas : Tensor (M,)
160
+ Nominal Bragg angles in radians corresponding to ``hkls``.
161
+ geometry : HEDMGeometry
162
+ Detector / scan geometry.
163
+ hkls_int : Tensor (M, 3), optional
164
+ Integer Miller indices. Required for ``correct_hkls_latc`` (strain).
165
+ If None, strain correction is unavailable.
166
+ scan_config : ScanConfig, optional
167
+ Multi-scan configuration. None for single-scan (standard NF / FF).
168
+ device : torch.device
169
+ Target device.
170
+ """
171
+
172
+ # Match the C code: both now use M_PI/180.0 (previously the C code
173
+ # used hardcoded 13-digit constants that caused precision loss).
174
+ DEG2RAD = math.pi / 180.0
175
+ RAD2DEG = 180.0 / math.pi
176
+
177
+ def __init__(
178
+ self,
179
+ hkls: torch.Tensor,
180
+ thetas: torch.Tensor,
181
+ geometry: HEDMGeometry,
182
+ hkls_int: Optional[torch.Tensor] = None,
183
+ scan_config: Optional[ScanConfig] = None,
184
+ device: torch.device = torch.device("cpu"),
185
+ ):
186
+ super().__init__()
187
+
188
+ self.register_buffer("hkls", hkls.to(device).float())
189
+ self.register_buffer("thetas", thetas.to(device).float())
190
+
191
+ if hkls_int is not None:
192
+ self.register_buffer("hkls_int", hkls_int.to(device).float())
193
+ else:
194
+ self.hkls_int = None
195
+
196
+ # Geometry -- per-distance arrays stored as tensors for vectorised projection
197
+ Lsd_list = geometry._as_list("Lsd")
198
+ yBC_list = geometry._as_list("y_BC")
199
+ zBC_list = geometry._as_list("z_BC")
200
+ self.n_distances = len(Lsd_list)
201
+ self.register_buffer("_Lsd", torch.tensor(Lsd_list, dtype=torch.float32, device=device))
202
+ self.register_buffer("_y_BC", torch.tensor(yBC_list, dtype=torch.float32, device=device))
203
+ self.register_buffer("_z_BC", torch.tensor(zBC_list, dtype=torch.float32, device=device))
204
+ # Convenience aliases for single-distance (backward compat / simple access)
205
+ self.Lsd = Lsd_list[0]
206
+ self.y_BC = yBC_list[0]
207
+ self.z_BC = zBC_list[0]
208
+ self.px = geometry.px
209
+ self.omega_start = geometry.omega_start
210
+ self.omega_step = geometry.omega_step
211
+ self.n_frames = geometry.n_frames
212
+ self.n_pixels_y = geometry.n_pixels_y
213
+ self.n_pixels_z = geometry.n_pixels_z
214
+ self.min_eta = geometry.min_eta * self.DEG2RAD # store in radians
215
+ self.wavelength = geometry.wavelength
216
+ self.flip_y = geometry.flip_y
217
+
218
+ # Detector tilts (degrees). Stored as an nn.Parameter so they can be
219
+ # optimised via gradient descent (auto-calibration). NF mode
220
+ # (flip_y=False) applies them via _apply_nf_tilt; FF/pf mode
221
+ # (flip_y=True) ignores them because the experimental pipeline pre-
222
+ # corrects for detector tilts at peak-finding time.
223
+ #
224
+ # Composition: RotMatTilts = Rz(tz) @ Ry(ty) @ Rx(tx)
225
+ # Matches RotationTilts() in NF_HEDM/src/SharedFuncsFit.c:230-266.
226
+ self.tilts = nn.Parameter(
227
+ torch.tensor([geometry.tx, geometry.ty, geometry.tz],
228
+ dtype=torch.float64, device=device),
229
+ requires_grad=False,
230
+ )
231
+ self.tx = float(geometry.tx)
232
+ self.ty = float(geometry.ty)
233
+ self.tz = float(geometry.tz)
234
+ self._has_tilts = abs(self.tx) + abs(self.ty) + abs(self.tz) > 0.0
235
+
236
+ # Scan config
237
+ self.scan_config = scan_config
238
+ if scan_config is not None:
239
+ self.register_buffer(
240
+ "_beam_positions", scan_config.beam_positions.to(device).float()
241
+ )
242
+ self._beam_size = scan_config.beam_size
243
+
244
+ self.epsilon = 1e-7
245
+
246
+ # ------------------------------------------------------------------
247
+ # euler2mat (ZXZ convention)
248
+ # ------------------------------------------------------------------
249
+
250
+ @staticmethod
251
+ def euler2mat(euler_angles: torch.Tensor) -> torch.Tensor:
252
+ """Convert ZXZ Euler angles to rotation matrices.
253
+
254
+ Parameters
255
+ ----------
256
+ euler_angles : Tensor (..., 3)
257
+ Euler angles (phi1, Phi, phi2) in radians.
258
+
259
+ Returns
260
+ -------
261
+ Tensor (..., 3, 3)
262
+ Rotation matrices.
263
+ """
264
+ c = torch.cos(euler_angles)
265
+ s = torch.sin(euler_angles)
266
+
267
+ c0, c1, c2 = c[..., 0], c[..., 1], c[..., 2]
268
+ s0, s1, s2 = s[..., 0], s[..., 1], s[..., 2]
269
+
270
+ # ZXZ rotation matrix: R = Rz(phi1) @ Rx(Phi) @ Rz(phi2)
271
+ # Verified element-by-element against nfhedm.py lines 114-120
272
+ R = torch.zeros(*euler_angles.shape[:-1], 3, 3,
273
+ dtype=euler_angles.dtype, device=euler_angles.device)
274
+ R[..., 0, 0] = c0 * c2 - s0 * c1 * s2
275
+ R[..., 0, 1] = -s0 * c1 * c2 - c0 * s2
276
+ R[..., 0, 2] = s0 * s1
277
+ R[..., 1, 0] = s0 * c2 + c0 * c1 * s2
278
+ R[..., 1, 1] = c0 * c1 * c2 - s0 * s2
279
+ R[..., 1, 2] = -c0 * s1
280
+ R[..., 2, 0] = s1 * s2
281
+ R[..., 2, 1] = s1 * c2
282
+ R[..., 2, 2] = c1
283
+
284
+ return HEDMForwardModel.orthogonalize(R)
285
+
286
+ # ------------------------------------------------------------------
287
+ # orthogonalize (SO(3) projection via Newton-Schulz iteration)
288
+ # ------------------------------------------------------------------
289
+
290
+ @staticmethod
291
+ def orthogonalize(R: torch.Tensor) -> torch.Tensor:
292
+ """Project a (..., 3, 3) matrix onto SO(3).
293
+
294
+ Guarantees ``R^T R = I`` and ``det(R) = +1`` (proper rotation).
295
+ Uses one Newton-Schulz iteration which is differentiable and
296
+ numerically stable (no SVD singularity at repeated singular values).
297
+
298
+ For matrices already near SO(3) (like those from ``euler2mat``),
299
+ a single iteration suffices (quadratic convergence).
300
+
301
+ Parameters
302
+ ----------
303
+ R : Tensor (..., 3, 3)
304
+
305
+ Returns
306
+ -------
307
+ Tensor (..., 3, 3) -- orthogonal with det = +1.
308
+ """
309
+ # Newton-Schulz iteration for polar decomposition:
310
+ # Q_{k+1} = 0.5 * Q_k * (3*I - Q_k^T * Q_k)
311
+ # Converges quadratically for matrices near SO(3).
312
+ # One iteration is sufficient for matrices from euler2mat
313
+ # (error < 1e-14 after one step for input error < 1e-7).
314
+ I = torch.eye(3, dtype=R.dtype, device=R.device)
315
+ Q = R
316
+ Q = 0.5 * Q @ (3.0 * I - Q.transpose(-1, -2) @ Q)
317
+ # Ensure det = +1 (the iteration preserves det sign for near-SO(3) input,
318
+ # but we guard against pathological cases)
319
+ det = torch.det(Q)
320
+ # Where det < 0, negate (flips to proper rotation)
321
+ sign = torch.where(det < 0, torch.tensor(-1.0, dtype=R.dtype, device=R.device),
322
+ torch.tensor(1.0, dtype=R.dtype, device=R.device))
323
+ return Q * sign.unsqueeze(-1).unsqueeze(-1)
324
+
325
+ # ------------------------------------------------------------------
326
+ # safe_arccos
327
+ # ------------------------------------------------------------------
328
+
329
+ def safe_arccos(self, x: torch.Tensor) -> torch.Tensor:
330
+ """Numerically stable arccos: clamp to [-1+eps, 1-eps]."""
331
+ return torch.acos(torch.clamp(x, -1.0 + self.epsilon, 1.0 - self.epsilon))
332
+
333
+ # ------------------------------------------------------------------
334
+ # rotate_strain_sample_to_crystal (port of C RotateStrainSampleToCrystal)
335
+ # ------------------------------------------------------------------
336
+
337
+ @staticmethod
338
+ def rotate_strain_sample_to_crystal(
339
+ orientation_matrices: torch.Tensor,
340
+ strain_sample: torch.Tensor,
341
+ ) -> torch.Tensor:
342
+ """Rotate a symmetric infinitesimal strain from sample to crystal frame.
343
+
344
+ Port of ``RotateStrainSampleToCrystal`` from
345
+ ``FF_HEDM/src/ForwardSimulationCompressed.c:399-419``:
346
+ eps_crystal = OM^T . eps_sample . OM, in Voigt notation
347
+ [eps_11, eps_12, eps_13, eps_22, eps_23, eps_33].
348
+
349
+ Parameters
350
+ ----------
351
+ orientation_matrices : Tensor (..., 3, 3)
352
+ strain_sample : Tensor (..., 6)
353
+
354
+ Returns
355
+ -------
356
+ strain_crystal : Tensor (..., 6)
357
+ """
358
+ e = strain_sample
359
+ S = torch.stack([
360
+ torch.stack([e[..., 0], e[..., 1], e[..., 2]], dim=-1),
361
+ torch.stack([e[..., 1], e[..., 3], e[..., 4]], dim=-1),
362
+ torch.stack([e[..., 2], e[..., 4], e[..., 5]], dim=-1),
363
+ ], dim=-2)
364
+ OM = orientation_matrices
365
+ C = torch.matmul(torch.matmul(OM.transpose(-1, -2), S), OM)
366
+ return torch.stack([
367
+ C[..., 0, 0], C[..., 0, 1], C[..., 0, 2],
368
+ C[..., 1, 1], C[..., 1, 2], C[..., 2, 2],
369
+ ], dim=-1)
370
+
371
+ # ------------------------------------------------------------------
372
+ # correct_hkls_latc (port of C CorrectHKLsLatC)
373
+ # ------------------------------------------------------------------
374
+
375
+ def correct_hkls_latc(
376
+ self,
377
+ lattice_params: torch.Tensor,
378
+ strain: Optional[torch.Tensor] = None,
379
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
380
+ """Compute strained reciprocal-space G-vectors and Bragg angles.
381
+
382
+ Builds the reciprocal lattice B matrix from lattice parameters
383
+ and transforms integer Miller indices to Cartesian G-vectors.
384
+
385
+ Faithfully ports ``CorrectHKLsLatC`` from
386
+ ``FF_HEDM/src/FitPosOrStrainsDoubleDataset.c:214-252``, with the
387
+ optional crystal-frame strain path from ``CorrectHKLsLatCEpsilon``
388
+ in ``FF_HEDM/src/ForwardSimulationCompressed.c:423-475``.
389
+
390
+ Parameters
391
+ ----------
392
+ lattice_params : Tensor (..., 6)
393
+ [a, b, c, alpha, beta, gamma] in Angstroms and degrees.
394
+ The ``...`` dimensions allow per-voxel or per-grain parameters.
395
+ strain : Tensor (..., 6), optional
396
+ Crystal-frame symmetric infinitesimal strain in Voigt form
397
+ [eps_11, eps_12, eps_13, eps_22, eps_23, eps_33]. When supplied,
398
+ the reciprocal lattice is post-multiplied by (I + eps)^{-1}:
399
+ B = (I + eps)^{-1} @ B0. Use :meth:`rotate_strain_sample_to_crystal`
400
+ to convert a sample-frame strain into the crystal frame.
401
+
402
+ Returns
403
+ -------
404
+ hkls_cart : Tensor (..., M, 3)
405
+ G-vectors in Cartesian reciprocal space (1/Angstroms).
406
+ thetas : Tensor (..., M)
407
+ Bragg angles in radians.
408
+
409
+ Raises
410
+ ------
411
+ RuntimeError
412
+ If ``hkls_int`` was not provided at construction.
413
+ """
414
+ if self.hkls_int is None:
415
+ raise RuntimeError(
416
+ "correct_hkls_latc requires integer Miller indices "
417
+ "(pass hkls_int to the constructor)."
418
+ )
419
+
420
+ a = lattice_params[..., 0]
421
+ b = lattice_params[..., 1]
422
+ c = lattice_params[..., 2]
423
+ # Angles in degrees -> radians
424
+ alpha = lattice_params[..., 3] * self.DEG2RAD
425
+ beta = lattice_params[..., 4] * self.DEG2RAD
426
+ gamma = lattice_params[..., 5] * self.DEG2RAD
427
+
428
+ sin_a = torch.sin(alpha)
429
+ cos_a = torch.cos(alpha)
430
+ sin_b = torch.sin(beta)
431
+ cos_b = torch.cos(beta)
432
+ sin_g = torch.sin(gamma)
433
+ cos_g = torch.cos(gamma)
434
+
435
+ # Reciprocal lattice angles
436
+ # C: GammaPr = acosd((CosA*CosB - CosG) / (SinA*SinB))
437
+ gamma_pr = torch.acos(
438
+ torch.clamp((cos_a * cos_b - cos_g) / (sin_a * sin_b + self.epsilon),
439
+ -1.0 + self.epsilon, 1.0 - self.epsilon)
440
+ )
441
+ beta_pr = torch.acos(
442
+ torch.clamp((cos_g * cos_a - cos_b) / (sin_g * sin_a + self.epsilon),
443
+ -1.0 + self.epsilon, 1.0 - self.epsilon)
444
+ )
445
+ sin_beta_pr = torch.sin(beta_pr)
446
+
447
+ # Volume and reciprocal lengths
448
+ vol = a * b * c * sin_a * sin_beta_pr * sin_g
449
+ a_pr = b * c * sin_a / (vol + self.epsilon)
450
+ b_pr = c * a * sin_b / (vol + self.epsilon)
451
+ c_pr = a * b * sin_g / (vol + self.epsilon)
452
+
453
+ # Build B matrix (..., 3, 3)
454
+ zeros = torch.zeros_like(a)
455
+ # Row 0
456
+ B00 = a_pr
457
+ B01 = b_pr * torch.cos(gamma_pr)
458
+ B02 = c_pr * torch.cos(beta_pr)
459
+ # Row 1
460
+ B10 = zeros
461
+ B11 = b_pr * torch.sin(gamma_pr)
462
+ B12 = -c_pr * sin_beta_pr * cos_a
463
+ # Row 2
464
+ B20 = zeros
465
+ B21 = zeros
466
+ B22 = c_pr * sin_beta_pr * sin_a
467
+
468
+ # Stack into (..., 3, 3)
469
+ B = torch.stack([
470
+ torch.stack([B00, B01, B02], dim=-1),
471
+ torch.stack([B10, B11, B12], dim=-1),
472
+ torch.stack([B20, B21, B22], dim=-1),
473
+ ], dim=-2)
474
+
475
+ # Optional crystal-frame strain: B = (I + eps)^{-1} @ B0
476
+ # Voigt layout matches C CorrectHKLsLatCEpsilon:
477
+ # eps = [eps_11, eps_12, eps_13, eps_22, eps_23, eps_33]
478
+ if strain is not None:
479
+ e11 = strain[..., 0]
480
+ e12 = strain[..., 1]
481
+ e13 = strain[..., 2]
482
+ e22 = strain[..., 3]
483
+ e23 = strain[..., 4]
484
+ e33 = strain[..., 5]
485
+ one = torch.ones_like(e11)
486
+ F_mat = torch.stack([
487
+ torch.stack([one + e11, e12, e13 ], dim=-1),
488
+ torch.stack([e12, one + e22, e23 ], dim=-1),
489
+ torch.stack([e13, e23, one + e33], dim=-1),
490
+ ], dim=-2)
491
+ F_inv = torch.linalg.inv(F_mat)
492
+ B = torch.matmul(F_inv, B)
493
+
494
+ # G_cart = B @ hkls_int^T => (..., M, 3)
495
+ # hkls_int is (M, 3), B is (..., 3, 3) -- match dtype
496
+ hkls_cart = torch.einsum("...ij,mj->...mi", B, self.hkls_int.to(B.dtype))
497
+
498
+ # d-spacing = 1 / |G_cart|
499
+ g_norm = torch.norm(hkls_cart, dim=-1).clamp(min=self.epsilon)
500
+ d_spacing = 1.0 / g_norm
501
+
502
+ # Bragg angle: theta = arcsin(wavelength / (2*d))
503
+ sin_theta = (self.wavelength / (2.0 * d_spacing)).clamp(
504
+ -1.0 + self.epsilon, 1.0 - self.epsilon
505
+ )
506
+ thetas = torch.asin(sin_theta)
507
+
508
+ return hkls_cart, thetas
509
+
510
+ # ------------------------------------------------------------------
511
+ # calc_bragg_geometry (omega quadratic solver + eta)
512
+ # ------------------------------------------------------------------
513
+
514
+ def calc_bragg_geometry(
515
+ self,
516
+ orientation_matrices: torch.Tensor,
517
+ hkls_cart: Optional[torch.Tensor] = None,
518
+ thetas: Optional[torch.Tensor] = None,
519
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
520
+ """Core Bragg geometry: orientations + G-vectors -> angles.
521
+
522
+ Solves the omega quadratic from the diffraction condition and
523
+ computes the eta azimuthal angle.
524
+
525
+ Ports the quadratic solver from
526
+ ``NF_HEDM/src/CalcDiffractionSpots.c:87-183``.
527
+
528
+ Parameters
529
+ ----------
530
+ orientation_matrices : Tensor (..., N, 3, 3)
531
+ Rotation matrices for each voxel/grain.
532
+ hkls_cart : Tensor (..., M, 3) or None
533
+ G-vectors in Cartesian reciprocal space. If None, uses
534
+ the nominal ``self.hkls``.
535
+ thetas : Tensor (..., M) or None
536
+ Bragg angles in radians. If None, uses ``self.thetas``.
537
+
538
+ Returns
539
+ -------
540
+ omega : Tensor (..., 2N, M) -- two solutions (+/-) per position
541
+ eta : Tensor (..., 2N, M)
542
+ two_theta : Tensor (..., 2N, M)
543
+ valid : Tensor (..., 2N, M) float mask (1=valid, 0=invalid)
544
+ """
545
+ dtype = orientation_matrices.dtype
546
+ if hkls_cart is None:
547
+ hkls_cart = self.hkls.to(dtype) # (M, 3)
548
+ else:
549
+ hkls_cart = hkls_cart.to(dtype)
550
+ if thetas is None:
551
+ thetas = self.thetas.to(dtype) # (M,)
552
+ else:
553
+ thetas = thetas.to(dtype)
554
+
555
+ # G_C = R @ hkls^T => (..., N, M, 3)
556
+ # Two cases supported: (a) hkls_cart shape (M, 3) shared across the
557
+ # batch, (b) per-voxel hkls_cart shape (..., M, 3) for strained
558
+ # rendering. Both flow through the same einsum via leading-dim
559
+ # broadcasting on the second arg.
560
+ G_C = torch.einsum("...nij,...mj->...nmi", orientation_matrices, hkls_cart)
561
+
562
+ # v = sin(theta)*|G| -- C precomputes Gs from the UNROTATED G-vector norm
563
+ # (rotation preserves norm in exact arithmetic but not in float64).
564
+ # Match C: use |hkls_cart| (pre-rotation), not |R @ hkls_cart|.
565
+ len_hkl = torch.norm(hkls_cart, dim=-1) # (M,) or (..., M)
566
+ v = torch.sin(thetas) * len_hkl # (M,) or (..., M)
567
+ v = v.unsqueeze(-2).expand_as(G_C[..., 0]) # (..., N, M)
568
+
569
+ # Extract components
570
+ Gx = G_C[..., 0] # (..., N, M)
571
+ Gy = G_C[..., 1]
572
+ Gz = G_C[..., 2]
573
+
574
+ # Quadratic solver for omega
575
+ # -Gx*cos(w) + Gy*sin(w) = v
576
+ # Rearranged: a*cos^2(w) + b*cos(w) + c = 0
577
+ # C uses almostzero=1e-12 for the Gy≈0 branch (see
578
+ # NF_HEDM/src/CalcDiffractionSpots.c:96 and
579
+ # FF_HEDM/src/ForwardSimulationCompressed.c:168). Match exactly.
580
+ almostzero = 1e-12
581
+ x2 = Gx * Gx
582
+ y2 = Gy * Gy
583
+ a = 1.0 + x2 / (y2 + self.epsilon)
584
+ b_coeff = 2.0 * v * Gx / (y2 + self.epsilon)
585
+ c_coeff = v * v / (y2 + self.epsilon) - 1.0
586
+ discriminant = b_coeff * b_coeff - 4.0 * a * c_coeff
587
+
588
+ sqrt_disc = torch.sqrt(torch.abs(discriminant))
589
+
590
+ coswp = (-b_coeff + sqrt_disc) / (2.0 * a)
591
+ coswn = (-b_coeff - sqrt_disc) / (2.0 * a)
592
+
593
+ wap = self.safe_arccos(coswp)
594
+ wan = self.safe_arccos(coswn)
595
+ wbp = -wap
596
+ wbn = -wan
597
+
598
+ # Select correct branch: the one satisfying -Gx*cos(w)+Gy*sin(w)=v
599
+ eqap = -Gx * torch.cos(wap) + Gy * torch.sin(wap)
600
+ eqbp = -Gx * torch.cos(wbp) + Gy * torch.sin(wbp)
601
+ eqan = -Gx * torch.cos(wan) + Gy * torch.sin(wan)
602
+ eqbn = -Gx * torch.cos(wbn) + Gy * torch.sin(wbn)
603
+
604
+ Dap = torch.abs(eqap - v)
605
+ Dbp = torch.abs(eqbp - v)
606
+ Dan = torch.abs(eqan - v)
607
+ Dbn = torch.abs(eqbn - v)
608
+
609
+ all_wp = torch.where(Dap < Dbp, wap, wbp)
610
+ all_wn = torch.where(Dan < Dbn, wan, wbn)
611
+
612
+ # Special case: Gy ~ 0 (C uses almostzero=1e-12)
613
+ # C code (CalcDiffractionSpots.c:97-106):
614
+ # cosome1 = -v / x;
615
+ # if (|cosome1| <= 1) { ome = acos(cosome1); solutions: +ome, -ome }
616
+ gy_zero = torch.abs(Gy) < almostzero
617
+ cosome_special = -v / (Gx + self.epsilon)
618
+ cosome_special_valid = (torch.abs(cosome_special) <= 1.0) & gy_zero & (torch.abs(Gx) > self.epsilon)
619
+ special_w = self.safe_arccos(cosome_special) # positive omega solution
620
+ # Two solutions: +ome and -ome
621
+ special_wp = special_w # positive
622
+ special_wn = -special_w # negative
623
+
624
+ # When |Gy| < almostzero, use the special case; otherwise use the quadratic
625
+ disc_valid = (discriminant >= 0) & (~gy_zero)
626
+ coswp_valid = (coswp >= -1.0) & (coswp <= 1.0)
627
+ coswn_valid = (coswn >= -1.0) & (coswn <= 1.0)
628
+
629
+ omega_p = torch.where(cosome_special_valid, special_wp,
630
+ torch.where(disc_valid & coswp_valid, all_wp,
631
+ torch.zeros_like(all_wp)))
632
+ omega_n = torch.where(cosome_special_valid, special_wn,
633
+ torch.where(disc_valid & coswn_valid, all_wn,
634
+ torch.zeros_like(all_wn)))
635
+
636
+ # Concatenate two solutions: (..., 2N, M)
637
+ all_omega = torch.cat([omega_p, omega_n], dim=-2)
638
+
639
+ # Build omega rotation matrix for each spot
640
+ cos_w = torch.cos(all_omega)
641
+ sin_w = torch.sin(all_omega)
642
+
643
+ # Construct Rz(omega): rotation around z-axis
644
+ # [[cos, -sin, 0], [sin, cos, 0], [0, 0, 1]]
645
+ Omega_mat = torch.zeros(*all_omega.shape, 3, 3,
646
+ dtype=all_omega.dtype, device=all_omega.device)
647
+ Omega_mat[..., 0, 0] = cos_w
648
+ Omega_mat[..., 0, 1] = -sin_w
649
+ Omega_mat[..., 1, 0] = sin_w
650
+ Omega_mat[..., 1, 1] = cos_w
651
+ Omega_mat[..., 2, 2] = 1.0
652
+
653
+ # Rotate G_C by omega: nrot = Omega @ G_C
654
+ # G_C is (..., N, M, 3); double along N dim (dim=-3)
655
+ G_C_doubled = torch.cat([G_C, G_C], dim=-3) # (..., 2N, M, 3)
656
+ nrot = torch.einsum("...kmij,...kmj->...kmi", Omega_mat, G_C_doubled)
657
+ nrot_y = nrot[..., 1] # (..., 2N, M)
658
+ nrot_z = nrot[..., 2]
659
+
660
+ # Eta angle
661
+ r_yz = torch.sqrt(nrot_y * nrot_y + nrot_z * nrot_z).clamp(min=self.epsilon)
662
+ eta = self.safe_arccos(nrot_z / r_yz)
663
+ eta = -torch.sign(nrot_y) * eta
664
+
665
+ # 2*theta (broadcast thetas to match 2N dimension)
666
+ two_theta_single = 2.0 * thetas.unsqueeze(-2) # (..., 1, M) or (1, M)
667
+ two_theta = two_theta_single.expand_as(all_omega)
668
+
669
+ # Validity mask
670
+ valid_p = disc_valid & coswp_valid
671
+ valid_n = disc_valid & coswn_valid
672
+ # For gy_zero special case, valid only if cosome is in [-1, 1]
673
+ valid_p = valid_p | cosome_special_valid
674
+ valid_n = valid_n | cosome_special_valid
675
+ valid = torch.cat([valid_p, valid_n], dim=-2).float()
676
+
677
+ # Eta bounds
678
+ eta_ok = (torch.abs(eta) >= self.min_eta) & \
679
+ ((math.pi - torch.abs(eta)) >= self.min_eta)
680
+ valid = valid * eta_ok.float()
681
+
682
+ return all_omega, eta, two_theta, valid
683
+
684
+ # ------------------------------------------------------------------
685
+ # Tilt rotation matrix (RotationTilts in SharedFuncsFit.c:230-266)
686
+ # ------------------------------------------------------------------
687
+
688
+ @staticmethod
689
+ def _build_rot_tilts(tx_deg: float, ty_deg: float, tz_deg: float,
690
+ device: torch.device) -> torch.Tensor:
691
+ """Build the 3x3 NF-style tilt rotation matrix Rz(tz) @ Ry(ty) @ Rx(tx).
692
+
693
+ Matches RotationTilts() in NF_HEDM/src/SharedFuncsFit.c:230-266.
694
+ """
695
+ d2r = math.pi / 180.0
696
+ tx, ty, tz = tx_deg * d2r, ty_deg * d2r, tz_deg * d2r
697
+ cx, sx = math.cos(tx), math.sin(tx)
698
+ cy, sy = math.cos(ty), math.sin(ty)
699
+ cz, sz = math.cos(tz), math.sin(tz)
700
+ Rx = torch.tensor([[1, 0, 0], [0, cx, -sx], [0, sx, cx]], dtype=torch.float64)
701
+ Ry = torch.tensor([[cy, 0, sy], [0, 1, 0], [-sy, 0, cy]], dtype=torch.float64)
702
+ Rz = torch.tensor([[cz, -sz, 0], [sz, cz, 0], [0, 0, 1]], dtype=torch.float64)
703
+ # Composition matches NF C: Rz @ Ry @ Rx
704
+ return (Rz @ Ry @ Rx).to(device)
705
+
706
+ def _build_rot_tilts_from_param(self, dtype) -> torch.Tensor:
707
+ """Build Rz(tz) @ Ry(ty) @ Rx(tx) from self.tilts (differentiable)."""
708
+ d2r = math.pi / 180.0
709
+ t = self.tilts.to(dtype) * d2r
710
+ tx_, ty_, tz_ = t[0], t[1], t[2]
711
+ cx, sx = torch.cos(tx_), torch.sin(tx_)
712
+ cy, sy = torch.cos(ty_), torch.sin(ty_)
713
+ cz, sz = torch.cos(tz_), torch.sin(tz_)
714
+ zero = torch.zeros((), dtype=dtype, device=t.device)
715
+ one = torch.ones((), dtype=dtype, device=t.device)
716
+ Rx = torch.stack([
717
+ torch.stack([one, zero, zero]),
718
+ torch.stack([zero, cx, -sx ]),
719
+ torch.stack([zero, sx, cx ]),
720
+ ])
721
+ Ry = torch.stack([
722
+ torch.stack([cy, zero, sy ]),
723
+ torch.stack([zero, one, zero]),
724
+ torch.stack([-sy, zero, cy ]),
725
+ ])
726
+ Rz = torch.stack([
727
+ torch.stack([cz, -sz, zero]),
728
+ torch.stack([sz, cz, zero]),
729
+ torch.stack([zero, zero, one]),
730
+ ])
731
+ return Rz @ Ry @ Rx
732
+
733
+ def _apply_nf_tilt(self, ydet: torch.Tensor, zdet: torch.Tensor,
734
+ Lsd_val) -> "tuple[torch.Tensor, torch.Tensor]":
735
+ """Apply the NF detector-tilt correction to lab-frame (ydet, zdet).
736
+
737
+ Ports the ray-plane intersection in
738
+ ``NF_HEDM/src/SharedFuncsFit.c:947-958``: builds
739
+ P0 = RotMatTilts @ [-Lsd, 0, 0], P1 = RotMatTilts @ [0, ydet, zdet],
740
+ and returns the (y, z) coordinates where the line from P0 through P1
741
+ crosses the plane x = 0. Reduces to the identity when tilts are zero.
742
+
743
+ The rotation matrix is rebuilt from ``self.tilts`` on every call, so
744
+ if ``self.tilts.requires_grad`` is True the tilt parameters enter the
745
+ autograd graph and can be optimised via gradient descent
746
+ (auto-calibration). ``Lsd_val`` can be a Python scalar or a scalar tensor.
747
+ """
748
+ if not self._has_tilts and not self.tilts.requires_grad:
749
+ return ydet, zdet
750
+ dtype = ydet.dtype
751
+ R = self._build_rot_tilts_from_param(dtype)
752
+ # P0 = -Lsd * R[:, 0] (3 scalars)
753
+ p0x = -Lsd_val * R[0, 0]
754
+ p0y = -Lsd_val * R[1, 0]
755
+ p0z = -Lsd_val * R[2, 0]
756
+ # P1 = ydet * R[:, 1] + zdet * R[:, 2] (pointwise on tensors)
757
+ P1x = ydet * R[0, 1] + zdet * R[0, 2]
758
+ P1y = ydet * R[1, 1] + zdet * R[1, 2]
759
+ P1z = ydet * R[2, 1] + zdet * R[2, 2]
760
+ ABCx = P1x - p0x
761
+ ABCy = P1y - p0y
762
+ ABCz = P1z - p0z
763
+ safe_denom = torch.where(
764
+ torch.abs(ABCx) < self.epsilon,
765
+ torch.full_like(ABCx, self.epsilon),
766
+ ABCx,
767
+ )
768
+ out_y = p0y - ABCy * p0x / safe_denom
769
+ out_z = p0z - ABCz * p0x / safe_denom
770
+ return out_y, out_z
771
+
772
+ @staticmethod
773
+ def _build_ff_tilt_rot(tx_deg: float, ty_deg: float, tz_deg: float,
774
+ device: torch.device) -> torch.Tensor:
775
+ """Build the 3x3 FF-style tilt rotation matrix Rx(tx) @ Ry(ty) @ Rz(tz).
776
+
777
+ Matches CorrectTiltSpatialDistortion() in
778
+ FF_HEDM/src/ForwardSimulationCompressed.c:593-612.
779
+ Note the composition differs from the NF convention.
780
+ """
781
+ d2r = math.pi / 180.0
782
+ tx, ty, tz = tx_deg * d2r, ty_deg * d2r, tz_deg * d2r
783
+ cx, sx = math.cos(tx), math.sin(tx)
784
+ cy, sy = math.cos(ty), math.sin(ty)
785
+ cz, sz = math.cos(tz), math.sin(tz)
786
+ Rx = torch.tensor([[1, 0, 0], [0, cx, -sx], [0, sx, cx]], dtype=torch.float64)
787
+ Ry = torch.tensor([[cy, 0, sy], [0, 1, 0], [-sy, 0, cy]], dtype=torch.float64)
788
+ Rz = torch.tensor([[cz, -sz, 0], [sz, cz, 0], [0, 0, 1]], dtype=torch.float64)
789
+ # Composition matches FF C: Rx @ Ry @ Rz
790
+ return (Rx @ Ry @ Rz).to(device)
791
+
792
+ # ------------------------------------------------------------------
793
+ # project_to_detector
794
+ # ------------------------------------------------------------------
795
+
796
def project_to_detector(
    self,
    omega: torch.Tensor,
    eta: torch.Tensor,
    two_theta: torch.Tensor,
    positions: torch.Tensor,
    valid: torch.Tensor,
) -> SpotDescriptors:
    """Position-dependent detector projection for one or more distances.

    Implements the geometry from ``SharedFuncsFit.c:DisplacementSpots``
    (lines 269-292). For multi-distance NF-HEDM, projects to each
    distance and requires spots to be on-detector at ALL distances
    (the ``AllDistsFound`` logic from ``CalcFracOverlap``, line 638).

    Parameters
    ----------
    omega : Tensor (..., 2N, M)
    eta : Tensor (..., 2N, M)
    two_theta : Tensor (..., 2N, M)
    positions : Tensor (N, 3) or (..., N, 3)
        Real-space positions [x, y, z] in micrometers.
    valid : Tensor (..., 2N, M)

    Returns
    -------
    SpotDescriptors
        ``y_pixel`` / ``z_pixel`` have shape ``(..., 2N, M)`` for a single
        distance and ``(D, ..., 2N, M)`` otherwise.  ``valid`` is the
        product of the per-distance validities; ``layer_valid`` holds the
        per-distance validities and is ``None`` for a single distance.
    """
    # The position rows are duplicated so that they line up with the 2N
    # rows of the angular inputs.
    pos_doubled = torch.cat([positions, positions], dim=-2)  # (..., 2N, 3)

    cos_w = torch.cos(omega)  # (..., 2N, M)
    sin_w = torch.sin(omega)

    px = pos_doubled[..., 0].unsqueeze(-1)  # (..., 2N, 1)
    py = pos_doubled[..., 1].unsqueeze(-1)
    pz = pos_doubled[..., 2].unsqueeze(-1)

    # Omega-rotated position: rotate (x, y, z) by omega around the z-axis.
    x_grain = px * cos_w - py * sin_w  # (..., 2N, M)
    y_grain = px * sin_w + py * cos_w
    z_grain = pz.expand_as(x_grain)

    tan_2th = torch.tan(two_theta)
    sin_eta = torch.sin(eta)
    cos_eta = torch.cos(eta)

    # Frame number (same at all distances -- omega doesn't change)
    frame_nr = (omega / self.DEG2RAD - self.omega_start) / self.omega_step
    frame_ok = (frame_nr >= 0) & (frame_nr < self.n_frames)

    # _Lsd, _y_BC, _z_BC are (D,) tensors; reshape them to
    # (D, 1, ..., 1) so they broadcast against (..., 2N, M).
    D = self.n_distances
    dtype = omega.dtype
    bcast_shape = (D,) + (1,) * omega.ndim
    Lsd_d = self._Lsd.to(dtype).reshape(bcast_shape)
    yBC_d = self._y_BC.to(dtype).reshape(bcast_shape)
    zBC_d = self._z_BC.to(dtype).reshape(bcast_shape)

    # ydet, zdet, y_pixel, z_pixel all get shape (D, ..., 2N, M)
    dist_d = Lsd_d - x_grain.unsqueeze(0)
    ydet_d = y_grain.unsqueeze(0) - dist_d * tan_2th.unsqueeze(0) * sin_eta.unsqueeze(0)
    zdet_d = z_grain.unsqueeze(0) + dist_d * tan_2th.unsqueeze(0) * cos_eta.unsqueeze(0)

    # Apply detector tilt -- NF mode only.
    #
    # Design note: FF and pf-HEDM experimental workflows apply a DetCor
    # correction at peak-finding time, so the per-spot centroids in
    # SpotMatrix.csv are already tilt- and distortion-corrected. A
    # differentiable forward model targeting FF/pf experimental data
    # therefore must NOT apply tilts -- doing so would double-correct.
    # NF-HEDM works at pixel level against raw detector images, with no
    # DetCor step, so the forward model MUST include tilts to produce
    # pixel predictions that match real NF measurements.
    #
    # The NF branch below ports the ray-plane intersection from
    # NF_HEDM/src/SharedFuncsFit.c:947-958 (composition Rz @ Ry @ Rx,
    # P0 = R @ [-Lsd, 0, 0]). The FF/pf path ignores tilts entirely.
    #
    # The gate mirrors the early-return test inside ``_apply_nf_tilt``:
    # tilts count as active when they are non-zero OR when the tilt
    # parameter is trainable.  Testing ``_has_tilts`` alone would skip the
    # call for trainable tilts that start at zero, leaving them outside
    # the autograd graph on this path.
    tilt_active = self._has_tilts or self.tilts.requires_grad
    if (not self.flip_y) and tilt_active:
        # Per-distance tilt application. Use a loop to keep per-Lsd
        # handling explicit (typical NF has 1-4 distances).
        Lsd_list = self._Lsd.to(dtype)
        out_y = []
        out_z = []
        for d in range(D):
            yd, zd = self._apply_nf_tilt(ydet_d[d], zdet_d[d], Lsd_list[d])
            out_y.append(yd)
            out_z.append(zd)
        ydet_d = torch.stack(out_y, dim=0)
        zdet_d = torch.stack(out_z, dim=0)

    # FF/PF: y-axis on detector flipped (yBC - ydet/px), validated against C
    # NF: not flipped (yBC + ydet/px), validated against C
    y_sign = -1.0 if self.flip_y else 1.0
    y_pixel_d = yBC_d + y_sign * ydet_d / self.px  # (D, ..., 2N, M)
    z_pixel_d = zBC_d + zdet_d / self.px

    # Per-distance detector bounds
    layer_bounds_ok = (
        (y_pixel_d >= 0) & (y_pixel_d < self.n_pixels_y) &
        (z_pixel_d >= 0) & (z_pixel_d < self.n_pixels_z)
    )  # (D, ..., 2N, M)

    # Per-distance validity = angular valid & frame ok & detector bounds
    layer_valid = valid.unsqueeze(0) * frame_ok.unsqueeze(0).float() * layer_bounds_ok.float()

    # Overall valid = valid at ALL distances (AllDistsFound)
    all_dists_valid = layer_valid.prod(dim=0)  # (..., 2N, M)

    # For single-distance, squeeze out the D dimension for convenience
    if D == 1:
        y_pixel_out = y_pixel_d.squeeze(0)
        z_pixel_out = z_pixel_d.squeeze(0)
        layer_valid_out = None
    else:
        y_pixel_out = y_pixel_d
        z_pixel_out = z_pixel_d
        layer_valid_out = layer_valid

    return SpotDescriptors(
        omega=omega,
        eta=eta,
        two_theta=two_theta,
        y_pixel=y_pixel_out,
        z_pixel=z_pixel_out,
        frame_nr=frame_nr,
        valid=all_dists_valid,
        layer_valid=layer_valid_out,
    )
930
+
931
+ # ------------------------------------------------------------------
932
+ # forward (orchestrator)
933
+ # ------------------------------------------------------------------
934
+
935
def forward(
    self,
    euler_angles: torch.Tensor,
    positions: torch.Tensor,
    lattice_params: Optional[torch.Tensor] = None,
    strain: Optional[torch.Tensor] = None,
) -> SpotDescriptors:
    """Run the complete simulation: orientations -> Bragg angles -> detector.

    Parameters
    ----------
    euler_angles : Tensor (..., N, 3)
        Euler angles (phi1, Phi, phi2) in radians at each position.
    positions : Tensor (N, 3) or (N, 2) or (..., N, 3)
        Real-space positions in micrometers. A trailing dimension of 2 is
        zero-padded to 3 (z = 0).
    lattice_params : Tensor (..., 6) or (..., N, 6), optional
        Strained lattice parameters [a,b,c,alpha,beta,gamma] in
        Angstroms/degrees. ``None`` means the nominal hkls/thetas are used.
    strain : Tensor (..., 6) or (..., N, 6), optional
        Crystal-frame symmetric infinitesimal strain in Voigt form
        [eps_11, eps_12, eps_13, eps_22, eps_23, eps_33], forwarded to
        ``correct_hkls_latc`` together with ``lattice_params``.

    Returns
    -------
    SpotDescriptors
        The projected spots; additionally passed through
        ``filter_by_scan`` when a scan configuration is present.

    Raises
    ------
    ValueError
        If ``strain`` is given without ``lattice_params``.
    """
    # Legacy callers may hand in (x, y) only; append z = 0.
    if positions.shape[-1] == 2:
        positions = F.pad(positions, (0, 1), value=0.0)

    # Step 1: Euler angles -> orientation matrices.
    orient = self.euler2mat(euler_angles)

    # Step 2: strained reciprocal vectors / Bragg angles, when requested.
    if lattice_params is None:
        if strain is not None:
            raise ValueError(
                "strain was supplied but lattice_params is None; strain "
                "requires a reference lattice to apply (I + eps)^{-1} @ B0."
            )
        g_vectors, bragg_thetas = None, None
    else:
        g_vectors, bragg_thetas = self.correct_hkls_latc(
            lattice_params, strain=strain
        )

    # Step 3: diffraction angles and their angular validity.
    omega, eta, two_theta, valid = self.calc_bragg_geometry(
        orient, g_vectors, bragg_thetas
    )

    # Step 4: map onto the detector(s).
    projected = self.project_to_detector(omega, eta, two_theta, positions, valid)

    # Step 5: beam-illumination mask, only for multi-scan set-ups.
    if self.scan_config is None:
        return projected
    return self.filter_by_scan(projected, positions)
995
+
996
+ # ------------------------------------------------------------------
997
+ # filter_by_scan (beam proximity for pf-HEDM)
998
+ # ------------------------------------------------------------------
999
+
1000
def filter_by_scan(
    self, spots: SpotDescriptors, positions: torch.Tensor
) -> SpotDescriptors:
    """Apply beam illumination filter for multi-scan geometry.

    For each spot, checks whether the omega-rotated y-position of the
    source voxel falls within the beam at each scan position.

    Ports ``FitOrStrainsScanningOMP.c:1050-1058``:
        yRot = posX * sin(omega) + posY * cos(omega)
        |yRot - beam_y[scan]| < beam_size / 2

    Parameters
    ----------
    spots : SpotDescriptors
    positions : Tensor (..., N, 3)

    Returns
    -------
    SpotDescriptors
        A new descriptor carrying every field of ``spots`` (including
        ``layer_valid``) plus ``scan_mask`` of shape ``(..., S, 2N, M)``,
        already multiplied by ``spots.valid``.  ``spots`` itself is
        returned unchanged when no scan configuration is set.
    """
    if self.scan_config is None:
        return spots

    # Duplicate the position rows to line up with the 2N spot rows.
    pos_doubled = torch.cat([positions, positions], dim=-2)  # (..., 2N, 3)

    cos_w = torch.cos(spots.omega)  # (..., 2N, M)
    sin_w = torch.sin(spots.omega)

    px = pos_doubled[..., 0].unsqueeze(-1)  # (..., 2N, 1)
    py = pos_doubled[..., 1].unsqueeze(-1)

    # Omega-rotated y position
    y_rot = px * sin_w + py * cos_w  # (..., 2N, M)

    # beam_positions: (S,)
    beam_y = self._beam_positions
    half_beam = self._beam_size / 2.0

    # |yRot - beam_y[s]| < half_beam, broadcast as
    # (..., 1, 2N, M) against (S, 1, 1) -> (..., S, 2N, M)
    y_rot_exp = y_rot.unsqueeze(-3)
    beam_y_exp = beam_y.reshape(-1, 1, 1)

    scan_mask = (torch.abs(y_rot_exp - beam_y_exp) < half_beam).float()
    # Combine with overall validity
    scan_mask = scan_mask * spots.valid.unsqueeze(-3)

    # Carry layer_valid across as well: leaving it out would silently
    # discard the per-distance validity computed by project_to_detector.
    return SpotDescriptors(
        omega=spots.omega,
        eta=spots.eta,
        two_theta=spots.two_theta,
        y_pixel=spots.y_pixel,
        z_pixel=spots.z_pixel,
        frame_nr=spots.frame_nr,
        valid=spots.valid,
        layer_valid=spots.layer_valid,
        scan_mask=scan_mask,
    )
1061
+
1062
+ # ------------------------------------------------------------------
1063
+ # predict_images (NF output mode: Gaussian splatting)
1064
+ # ------------------------------------------------------------------
1065
+
1066
+ @staticmethod
1067
+ def predict_images(
1068
+ spots: SpotDescriptors,
1069
+ n_frames: int,
1070
+ n_pixels_y: int,
1071
+ n_pixels_z: int,
1072
+ sigma: float = 1.0,
1073
+ radius: int = 3,
1074
+ ) -> torch.Tensor:
1075
+ """Gaussian-splat predicted spots onto detector grid.
1076
+
1077
+ This is the NF-HEDM output mode. Each valid spot is represented
1078
+ as a Gaussian blob on the (frame, y, z) detector volume.
1079
+
1080
+ Parameters
1081
+ ----------
1082
+ spots : SpotDescriptors
1083
+ Output from ``forward()``.
1084
+ n_frames, n_pixels_y, n_pixels_z : int
1085
+ Grid dimensions.
1086
+ sigma : float
1087
+ Gaussian kernel sigma in pixels.
1088
+ radius : int
1089
+ Kernel radius in pixels.
1090
+
1091
+ Returns
1092
+ -------
1093
+ Tensor (..., n_frames, n_pixels_y, n_pixels_z)
1094
+ """
1095
+ # Extract coordinates and mask
1096
+ frame_nr = spots.frame_nr # (..., K, M)
1097
+ y_pix = spots.y_pixel
1098
+ z_pix = spots.z_pixel
1099
+ valid = spots.valid
1100
+
1101
+ # Flatten batch dims for processing
1102
+ orig_shape = frame_nr.shape # (..., K, M)
1103
+ batch_shape = orig_shape[:-2]
1104
+ n_batch = 1
1105
+ for s in batch_shape:
1106
+ n_batch *= s
1107
+ KM = orig_shape[-2] * orig_shape[-1]
1108
+
1109
+ # Reshape to (B, KM, 3) where 3 = (frame, y, z)
1110
+ coords = torch.stack([frame_nr, y_pix, z_pix], dim=-1) # (..., K, M, 3)
1111
+ coords = coords.reshape(n_batch, KM, 3)
1112
+ mask = valid.reshape(n_batch, KM)
1113
+
1114
+ # Zero out invalid spots
1115
+ coords = coords * mask.unsqueeze(-1)
1116
+
1117
+ device = coords.device
1118
+ dtype = coords.dtype
1119
+
1120
+ grids = torch.zeros(n_batch, n_frames, n_pixels_y, n_pixels_z,
1121
+ dtype=dtype, device=device)
1122
+ gaussian_factor = -0.5 / (sigma ** 2)
1123
+
1124
+ # Filter non-zero coordinates
1125
+ batch_ids = torch.arange(n_batch, device=device).unsqueeze(1).expand(-1, KM)
1126
+ non_zero_mask = mask > 0.5
1127
+ non_zero_coords = coords[non_zero_mask]
1128
+ non_zero_batch = batch_ids[non_zero_mask]
1129
+
1130
+ if non_zero_coords.shape[0] == 0:
1131
+ return grids.reshape(*batch_shape, n_frames, n_pixels_y, n_pixels_z)
1132
+
1133
+ rounded = non_zero_coords.round().long()
1134
+
1135
+ # Neighborhood offsets
1136
+ offsets = torch.arange(-radius, radius + 1, device=device)
1137
+ oz, ox, oy = torch.meshgrid(offsets, offsets, offsets, indexing="ij")
1138
+ local_offsets = torch.stack([oz, ox, oy], dim=-1).reshape(-1, 3)
1139
+ n_local = local_offsets.shape[0]
1140
+
1141
+ expanded_centers = rounded.unsqueeze(1) # (P, 1, 3)
1142
+ neighbors = expanded_centers + local_offsets.unsqueeze(0) # (P, L, 3)
1143
+
1144
+ f_neigh = neighbors[..., 0].clamp(0, n_frames - 1)
1145
+ y_neigh = neighbors[..., 1].clamp(0, n_pixels_y - 1)
1146
+ z_neigh = neighbors[..., 2].clamp(0, n_pixels_z - 1)
1147
+
1148
+ distances = torch.sum(
1149
+ (neighbors.float() - non_zero_coords.unsqueeze(1)) ** 2, dim=-1
1150
+ )
1151
+ weights = torch.exp(distances * gaussian_factor)
1152
+
1153
+ # Flatten for scatter_add
1154
+ f_flat = f_neigh.flatten()
1155
+ y_flat = y_neigh.flatten()
1156
+ z_flat = z_neigh.flatten()
1157
+ w_flat = weights.flatten()
1158
+
1159
+ batch_flat = non_zero_batch.unsqueeze(1).expand(-1, n_local).flatten()
1160
+
1161
+ flat_idx = (batch_flat * (n_frames * n_pixels_y * n_pixels_z)
1162
+ + f_flat * (n_pixels_y * n_pixels_z)
1163
+ + y_flat * n_pixels_z
1164
+ + z_flat)
1165
+
1166
+ grids.view(-1).scatter_add_(0, flat_idx, w_flat)
1167
+
1168
+ return grids.reshape(*batch_shape, n_frames, n_pixels_y, n_pixels_z)
1169
+
1170
+ # ------------------------------------------------------------------
1171
+ # predict_spot_coords (FF/pf output mode)
1172
+ # ------------------------------------------------------------------
1173
+
1174
+ @staticmethod
1175
+ def predict_spot_coords(
1176
+ spots: SpotDescriptors,
1177
+ space: str = "angular",
1178
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1179
+ """Extract spot coordinates for COM matching (FF/pf mode).
1180
+
1181
+ Parameters
1182
+ ----------
1183
+ spots : SpotDescriptors
1184
+ space : str
1185
+ ``"angular"``: return (2theta, eta, omega) in radians.
1186
+ ``"detector"``: return (y_pixel, z_pixel, frame_nr).
1187
+
1188
+ Returns
1189
+ -------
1190
+ coords : Tensor (..., K, M, 3)
1191
+ valid : Tensor (..., K, M)
1192
+ """
1193
+ if space == "angular":
1194
+ coords = torch.stack(
1195
+ [spots.two_theta, spots.eta, spots.omega], dim=-1
1196
+ )
1197
+ elif space == "detector":
1198
+ coords = torch.stack(
1199
+ [spots.y_pixel, spots.z_pixel, spots.frame_nr], dim=-1
1200
+ )
1201
+ else:
1202
+ raise ValueError(f"Unknown space: {space!r}. Use 'angular' or 'detector'.")
1203
+ return coords, spots.valid
1204
+
1205
+ # ------------------------------------------------------------------
1206
+ # Triangular voxel support (NF-HEDM)
1207
+ # ------------------------------------------------------------------
1208
+
1209
+ @staticmethod
1210
+ def tri_vertices(
1211
+ centers: torch.Tensor,
1212
+ edge_lengths: torch.Tensor,
1213
+ ud: torch.Tensor,
1214
+ ) -> torch.Tensor:
1215
+ """Compute 3 triangle vertices from voxel centres.
1216
+
1217
+ Matches ``simulateNF.c`` lines 556-572.
1218
+
1219
+ Parameters
1220
+ ----------
1221
+ centers : Tensor (N, 2) or (N, 3)
1222
+ Voxel centres [x, y] or [x, y, z] in micrometers.
1223
+ edge_lengths : Tensor (N,)
1224
+ Edge length per voxel in micrometers.
1225
+ ud : Tensor (N,)
1226
+ Up/down flag (+1 or -1).
1227
+
1228
+ Returns
1229
+ -------
1230
+ Tensor (N, 3, 3)
1231
+ Vertices ``[V0, V1, V2]`` each of shape ``(3,)`` = ``[x, y, z]``.
1232
+ If input centres are 2D, z is set to 0.
1233
+ """
1234
+ if centers.shape[-1] == 2:
1235
+ centers = F.pad(centers, (0, 1), value=0.0)
1236
+
1237
+ xs = centers[:, 0]
1238
+ ys = centers[:, 1]
1239
+ zs = centers[:, 2]
1240
+
1241
+ gs = edge_lengths / 2.0
1242
+ sqrt3 = math.sqrt(3.0)
1243
+ dy1 = edge_lengths / sqrt3
1244
+ dy2 = -edge_lengths / (2.0 * sqrt3)
1245
+ # flip if ud < 0
1246
+ sign = torch.sign(ud)
1247
+ dy1 = dy1 * sign
1248
+ dy2 = dy2 * sign
1249
+
1250
+ # V0 = (xs, ys+dy1, zs), V1 = (xs-gs, ys+dy2, zs), V2 = (xs+gs, ys+dy2, zs)
1251
+ V0 = torch.stack([xs, ys + dy1, zs], dim=-1)
1252
+ V1 = torch.stack([xs - gs, ys + dy2, zs], dim=-1)
1253
+ V2 = torch.stack([xs + gs, ys + dy2, zs], dim=-1)
1254
+
1255
+ return torch.stack([V0, V1, V2], dim=1) # (N, 3, 3)
1256
+
1257
+ @staticmethod
1258
+ def _c_round(x: float) -> int:
1259
+ """C-compatible round(): half away from zero.
1260
+
1261
+ Python's ``round()`` uses banker's rounding (half to even),
1262
+ which differs from C's ``round()`` at ``.5`` boundaries.
1263
+ C's ``(int)round(x)`` = ``floor(x + 0.5)`` for x >= 0,
1264
+ ``ceil(x - 0.5)`` for x < 0.
1265
+ """
1266
+ return int(math.floor(x + 0.5)) if x >= 0 else int(math.ceil(x - 0.5))
1267
+
1268
+ @staticmethod
1269
+ def rasterize_triangle(v0y, v0z, v1y, v1z, v2y, v2z):
1270
+ """Rasterize a single triangle on an integer pixel grid.
1271
+
1272
+ Matches ``CalcPixels2`` in ``SharedFuncsFit.c`` lines 308-370.
1273
+ Uses the edge-function rasterizer with a distance-based border
1274
+ (``distSq < 0.9801``) for edges, exactly as the C code does.
1275
+
1276
+ Parameters
1277
+ ----------
1278
+ v0y, v0z, v1y, v1z, v2y, v2z : int
1279
+ Rounded integer pixel coordinates of the 3 vertices.
1280
+
1281
+ Returns
1282
+ -------
1283
+ list of (int, int)
1284
+ List of (y, z) integer pixel coordinates inside the triangle.
1285
+ """
1286
+ min_y = min(v0y, v1y, v2y)
1287
+ max_y = max(v0y, v1y, v2y)
1288
+ min_z = min(v0z, v1z, v2z)
1289
+ max_z = max(v0z, v1z, v2z)
1290
+
1291
+ # Edge function coefficients (matching C variable names)
1292
+ A01 = v0z - v1z; B01 = v1y - v0y
1293
+ A12 = v1z - v2z; B12 = v2y - v1y
1294
+ A20 = v2z - v0z; B20 = v0y - v2y
1295
+
1296
+ def orient2d(ax, ay, bx, by, cx, cy):
1297
+ return (bx - ax) * (cy - ay) - (by - ay) * (cx - ax)
1298
+
1299
+ def dist_sq_to_edge(ax, ay, bx, by, px, py):
1300
+ num = (bx - ax) * (py - ay) - (by - ay) * (px - ax)
1301
+ den_sq = (ay - by) ** 2 + (bx - ax) ** 2
1302
+ if den_sq == 0:
1303
+ return 1e30
1304
+ return (num * num) / den_sq
1305
+
1306
+ pixels = []
1307
+ w0_row = orient2d(v1y, v1z, v2y, v2z, min_y, min_z)
1308
+ w1_row = orient2d(v2y, v2z, v0y, v0z, min_y, min_z)
1309
+ w2_row = orient2d(v0y, v0z, v1y, v1z, min_y, min_z)
1310
+
1311
+ for pz in range(min_z, max_z + 1):
1312
+ w0 = w0_row
1313
+ w1 = w1_row
1314
+ w2 = w2_row
1315
+ for py in range(min_y, max_y + 1):
1316
+ inside = (w0 >= 0 and w1 >= 0 and w2 >= 0)
1317
+ if not inside:
1318
+ # Check distance to each edge (0.9801 = 0.99^2)
1319
+ inside = (
1320
+ dist_sq_to_edge(v1y, v1z, v2y, v2z, py, pz) < 0.9801 or
1321
+ dist_sq_to_edge(v2y, v2z, v0y, v0z, py, pz) < 0.9801 or
1322
+ dist_sq_to_edge(v0y, v0z, v1y, v1z, py, pz) < 0.9801
1323
+ )
1324
+ if inside:
1325
+ pixels.append((py, pz))
1326
+ w0 += A12
1327
+ w1 += A20
1328
+ w2 += A01
1329
+ w0_row += B12
1330
+ w1_row += B20
1331
+ w2_row += B01
1332
+
1333
+ return pixels
1334
+
1335
def forward_nf_triangles(
    self,
    euler_angles: torch.Tensor,
    centers: torch.Tensor,
    tri_config: TriVoxelConfig,
    lattice_params: Optional[torch.Tensor] = None,
    strain: Optional[torch.Tensor] = None,
) -> list:
    """NF-HEDM forward simulation with triangular voxel rasterization.

    Per the in-file notes this follows the ``simulateNF`` /
    ``CalcFracOverlap`` pipeline: Bragg geometry is computed once per
    voxel, the 3 triangle vertices are projected at the first detector
    distance, the detector-space triangle is rasterized relative to the
    projected voxel centre, and every rasterized pixel is then rescaled to
    each detector distance.  A pixel is kept only if it lands inside the
    detector at ALL distances.

    The per-spot rasterization runs in plain Python loops and pulls
    scalars out with ``.item()``, so the returned hits carry no autograd
    information.

    Parameters
    ----------
    euler_angles : Tensor (N, 3) in radians
    centers : Tensor (N, 2) or (N, 3) in micrometers
        A trailing dimension of 2 is zero-padded to 3.
    tri_config : TriVoxelConfig
        Supplies ``edge_lengths`` and ``ud`` (read per voxel).
    lattice_params : Tensor, optional
        Forwarded to ``correct_hkls_latc``; ``None`` uses ``self.hkls``.
    strain : Tensor, optional
        Forwarded to ``correct_hkls_latc``; requires ``lattice_params``.

    Returns
    -------
    list of tuple
        One ``(vox_nr, dist_nr, frame, y_px, z_px, omega_deg)`` entry per
        rasterized pixel and detector distance.  NOTE: despite the
        historical ``SpotDescriptors`` annotation, no ``SpotDescriptors``
        object is built here.

    Raises
    ------
    ValueError
        If ``strain`` is given without ``lattice_params``.
    """
    if centers.shape[-1] == 2:
        centers = F.pad(centers, (0, 1), value=0.0)

    N = centers.shape[0]
    vertices = self.tri_vertices(
        centers, tri_config.edge_lengths, tri_config.ud
    )  # (N, 3, 3)

    # 1. Compute orientation matrices.
    # C: reads radians from .mic, multiplies by rad2deg, then
    # Euler2OrientMat uses cosd()=cos(deg2rad*x). With both C and Python
    # now using M_PI/180, the roundtrip is lossless and we can use
    # radians directly.
    # The rad -> deg -> rad round trip is nevertheless kept so the values
    # fed to euler2mat go through the same two multiplications.
    euler_deg_c = euler_angles * self.RAD2DEG
    euler_rad_c = euler_deg_c * self.DEG2RAD
    orientation_matrices = self.euler2mat(euler_rad_c)

    # 2. Optionally strained HKLs
    hkls_cart = thetas = None
    if lattice_params is not None:
        hkls_cart, thetas = self.correct_hkls_latc(lattice_params, strain=strain)
    elif strain is not None:
        raise ValueError(
            "strain was supplied but lattice_params is None; strain "
            "requires a reference lattice to apply (I + eps)^{-1} @ B0."
        )

    # 3. Bragg geometry (once per voxel, from center orientation)
    # NOTE: the ``eta`` returned here is not used below; eta is recomputed
    # through the degree round-trip chain further down.
    omega, eta, two_theta, valid = self.calc_bragg_geometry(
        orientation_matrices, hkls_cart, thetas
    )
    # omega, eta, two_theta, valid: (2N, M)

    # Recompute G_C for eta recomputation below
    dtype = omega.dtype
    use_hkls = hkls_cart if hkls_cart is not None else self.hkls.to(dtype)
    G_C = torch.einsum("...nij,mj->...nmi", orientation_matrices, use_hkls)

    # 4. Frame number (fractional); spots outside [0, n_frames) are
    # invalidated here, before any projection work.
    frame_nr = (omega / self.DEG2RAD - self.omega_start) / self.omega_step
    frame_ok = (frame_nr >= 0) & (frame_nr < self.n_frames)
    valid = valid * frame_ok.float()

    # 5. Project each vertex through DisplacementSpots for each distance
    D = self.n_distances
    dtype = omega.dtype
    Lsd_0 = self._Lsd[0].to(dtype)

    # ---------------------------------------------------------------
    # Replicate C's exact CalcOmega→CalcSpotPosition chain to avoid
    # float-precision differences at pixel boundaries.
    #
    # C CalcOmega stores omega in DEGREES: omega_deg = acos(...)*rad2deg
    # C then recomputes eta by RotateAroundZ(G, omega_DEG):
    #   internally: omega_rad2 = omega_deg * deg2rad  (roundtrip!)
    #   gw = Rz(omega_rad2) @ G
    #   eta_deg = CalcEtaAngle(gw[1], gw[2])
    # C CalcSpotPosition: eta_rad = eta_deg * deg2rad  (another roundtrip!)
    #   yl = -sin(eta_rad) * RingRadius
    #   zl =  cos(eta_rad) * RingRadius
    # C RingRadius = Lsd * tan(2 * deg2rad * Theta_deg)
    #
    # We replicate every degree conversion.
    # ---------------------------------------------------------------
    omega_deg = omega * self.RAD2DEG
    omega_rad_c = omega_deg * self.DEG2RAD  # C's deg2rad * omega_deg

    cos_w = torch.cos(omega_rad_c)
    sin_w = torch.sin(omega_rad_c)

    # Recompute eta via C's chain: rotate G by omega_deg, then CalcEtaAngle
    G_C_doubled = torch.cat([G_C, G_C], dim=-3)  # (2N, M, 3)
    # gw = Rz(omega_rad_c) @ G  (only the y and z components are needed)
    gw_y = G_C_doubled[..., 0] * sin_w + G_C_doubled[..., 1] * cos_w
    gw_z = G_C_doubled[..., 2]
    # Clamp the radius away from zero and the cosine away from +/-1 so the
    # division and acos stay finite.
    r_yz = torch.sqrt(gw_y * gw_y + gw_z * gw_z).clamp(min=self.epsilon)
    eta_c_rad = torch.acos(torch.clamp(gw_z / r_yz,
                                       -1.0 + self.epsilon, 1.0 - self.epsilon))
    eta_c_rad = -torch.sign(gw_y) * eta_c_rad  # C: if (y > 0) alpha = -alpha
    # Convert to degrees then back (C's CalcSpotPosition chain)
    eta_c_deg = eta_c_rad * self.RAD2DEG
    eta_c_rad2 = eta_c_deg * self.DEG2RAD

    theta_deg = (two_theta / 2.0) * self.RAD2DEG
    ring_radius = Lsd_0 * torch.tan(2.0 * self.DEG2RAD * theta_deg)

    sin_eta_c = torch.sin(eta_c_rad2)
    cos_eta_c = torch.cos(eta_c_rad2)
    ythis = -sin_eta_c * ring_radius
    zthis = cos_eta_c * ring_radius

    # Project center reference. For non-zero tilts, apply the NF
    # ray-plane intersection (SharedFuncsFit.c:947-958).
    # YZSpotsTemp = outxyz/px + bc
    ybc_0 = self._y_BC[0].to(dtype)
    zbc_0 = self._z_BC[0].to(dtype)
    Lsd_0_scalar = self._Lsd[0].to(dtype)  # same value as Lsd_0 above
    y_center_lab, z_center_lab = self._apply_nf_tilt(ythis, zthis, Lsd_0_scalar)
    y_center = y_center_lab / self.px + ybc_0  # (2N, M)
    z_center = z_center_lab / self.px + zbc_0

    # Project each vertex: DisplacementSpots for each of 3 vertices
    # vertices: (N, 3, 3) -> double to (2N, 3, 3)
    verts_doubled = torch.cat([vertices, vertices], dim=0)  # (2N, 3, 3)

    # For each vertex v, compute Displ_Y, Displ_Z:
    #   xa = vx*cos(w) - vy*sin(w),  ya = vx*sin(w) + vy*cos(w)
    #   Displ_Y = ya + ythis*(1-xa/Lsd),  Displ_Z = (1-xa/Lsd)*zthis
    # (the vertex z component does not enter these expressions)
    vert_y_pixel = []
    vert_z_pixel = []
    for vi in range(3):
        vx = verts_doubled[:, vi, 0].unsqueeze(-1)  # (2N, 1)
        vy = verts_doubled[:, vi, 1].unsqueeze(-1)

        xa = vx * cos_w - vy * sin_w  # (2N, M)
        ya = vx * sin_w + vy * cos_w

        t = 1.0 - xa / Lsd_0
        displ_y = ya + ythis * t
        displ_z = t * zthis

        # Apply NF tilt (no-op when tilts are zero)
        displ_y_tilt, displ_z_tilt = self._apply_nf_tilt(
            displ_y, displ_z, Lsd_0_scalar
        )
        yp = displ_y_tilt / self.px + ybc_0
        zp = displ_z_tilt / self.px + zbc_0
        vert_y_pixel.append(yp)
        vert_z_pixel.append(zp)

    # 6. Compute relative offsets from center (matching C lines 584-586)
    # YZSpots[k] = YZSpotsT[k] - YZSpotsTemp
    rel_y = [vyp - y_center for vyp in vert_y_pixel]  # 3 x (2N, M)
    rel_z = [vzp - z_center for vzp in vert_z_pixel]

    # 7. Rasterize and collect hits
    # This is the per-spot loop (not vectorizable due to variable triangle sizes)
    K = 2 * N
    M = omega.shape[-1]

    all_hits = []  # list of (vox_nr, dist_nr, frame, y_px, z_px, omega_deg)

    for k in range(K):
        for m in range(M):
            if valid[k, m] < 0.5:
                continue

            # Relative triangle vertices in pixel coords
            ry = [rel_y[vi][k, m].item() for vi in range(3)]
            rz = [rel_z[vi][k, m].item() for vi in range(3)]

            # Rows k and k + N belong to the same voxel, hence k % N.
            gs_um = tri_config.edge_lengths[k % N].item()
            cr = self._c_round
            if gs_um > self.px:
                # Rasterize triangle (CalcPixels2 rounds vertices, line 312)
                v0y, v0z = cr(ry[0]), cr(rz[0])
                v1y, v1z = cr(ry[1]), cr(rz[1])
                v2y, v2z = cr(ry[2]), cr(rz[2])
                pixels = self.rasterize_triangle(v0y, v0z, v1y, v1z, v2y, v2z)
            else:
                # Single center pixel (C line 596-599: (int)round(...))
                cy = cr((ry[0] + ry[1] + ry[2]) / 3.0)
                cz = cr((rz[0] + rz[1] + rz[2]) / 3.0)
                pixels = [(cy, cz)]

            # Absolute pixel = center + offset, check all distances
            yc = y_center[k, m].item()
            zc = z_center[k, m].item()
            frame = int(frame_nr[k, m].item())  # C uses (int) truncation
            ome_deg = omega[k, m].item() * self.RAD2DEG

            for (py_off, pz_off) in pixels:
                all_dists_ok = True
                layer_pixels = []
                for d in range(D):
                    Lsd_d = self._Lsd[d].item()
                    ybc_d = self._y_BC[d].item()
                    zbc_d = self._z_BC[d].item()
                    # Scale to this distance (matching C lines 605-613):
                    # the centre offset from the first beam centre is
                    # scaled by Lsd_d / Lsd_0, floored, then the raster
                    # offset is added unscaled.
                    my = int(math.floor(
                        ((yc - ybc_0) * self.px * (Lsd_d / Lsd_0)) / self.px
                        + ybc_d
                    )) + py_off
                    mz = int(math.floor(
                        ((zc - zbc_0) * self.px * (Lsd_d / Lsd_0)) / self.px
                        + zbc_d
                    )) + pz_off
                    if my < 0 or my >= self.n_pixels_y or mz < 0 or mz >= self.n_pixels_z:
                        # Off-detector at one distance rejects the pixel
                        # at every distance.
                        all_dists_ok = False
                        break
                    layer_pixels.append((d, frame, my, mz))

                if all_dists_ok:
                    vox_nr = k % N
                    for (d, fr, my, mz) in layer_pixels:
                        all_hits.append((vox_nr, d, fr, my, mz, ome_deg))

    return all_hits