midas-diffract 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- midas_diffract/__init__.py +55 -0
- midas_diffract/forward.py +1559 -0
- midas_diffract/hkls.py +180 -0
- midas_diffract/losses.py +494 -0
- midas_diffract/optimize.py +248 -0
- midas_diffract-0.1.0.dist-info/METADATA +122 -0
- midas_diffract-0.1.0.dist-info/RECORD +10 -0
- midas_diffract-0.1.0.dist-info/WHEEL +5 -0
- midas_diffract-0.1.0.dist-info/licenses/LICENSE +31 -0
- midas_diffract-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1559 @@
|
|
|
1
|
+
"""Generic differentiable forward model for all HEDM modalities (NF, FF, pf).
|
|
2
|
+
|
|
3
|
+
The core Bragg geometry, omega solver, and detector projection are identical
|
|
4
|
+
across modalities. Modality differences (detector distance, output mode, scan
|
|
5
|
+
strategy) are handled via configuration, not subclassing.
|
|
6
|
+
|
|
7
|
+
Physics pipeline:
|
|
8
|
+
euler_angles, positions [, lattice_params]
|
|
9
|
+
-> orientation matrices (euler2mat)
|
|
10
|
+
-> G-vectors in crystal frame (calc_bragg_geometry)
|
|
11
|
+
-> [optional: strained G-vectors] (correct_hkls_latc)
|
|
12
|
+
-> omega solver (quadratic) (calc_bragg_geometry)
|
|
13
|
+
-> eta computation (calc_bragg_geometry)
|
|
14
|
+
-> position-dependent projection (project_to_detector)
|
|
15
|
+
-> validity filtering (project_to_detector)
|
|
16
|
+
-> SpotDescriptors
|
|
17
|
+
|
|
18
|
+
Output modes:
|
|
19
|
+
SpotDescriptors -> predict_images() [NF: Gaussian splatting]
|
|
20
|
+
SpotDescriptors -> predict_spot_coords() [FF/pf: angular coordinates]
|
|
21
|
+
|
|
22
|
+
Reference C code:
|
|
23
|
+
CorrectHKLsLatC: FF_HEDM/src/FitPosOrStrainsDoubleDataset.c:214-252
|
|
24
|
+
CalcDiffrSpots_Furnace: NF_HEDM/src/CalcDiffractionSpots.c:87-183
|
|
25
|
+
DisplacementSpots: NF_HEDM/src/SharedFuncsFit.c:269-292
|
|
26
|
+
Beam proximity filter: FF_HEDM/src/FitOrStrainsScanningOMP.c:1048-1058
|
|
27
|
+
"""
|
|
28
|
+
|
|
29
|
+
import math
|
|
30
|
+
from dataclasses import dataclass, field
|
|
31
|
+
from typing import Optional, Tuple
|
|
32
|
+
|
|
33
|
+
import torch
|
|
34
|
+
import torch.nn as nn
|
|
35
|
+
import torch.nn.functional as F
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
# ---------------------------------------------------------------------------
|
|
39
|
+
# Configuration data classes
|
|
40
|
+
# ---------------------------------------------------------------------------
|
|
41
|
+
|
|
42
|
+
@dataclass
class HEDMGeometry:
    """Detector and scan geometry configuration.

    Units: distances in micrometers, angles in degrees (converted internally).

    NF-HEDM uses multiple detector distances (nDistances=2-4), each with its
    own Lsd, y_BC, z_BC.  FF-HEDM and pf-HEDM use a single distance.
    When a list (or tuple) is provided for Lsd/y_BC/z_BC, each entry is one
    "layer" (detector distance).  A spot is valid only if it falls on the
    detector at **every** distance (the AllDistsFound logic in the C code).
    """
    Lsd: "float | list[float]"      # Sample-detector distance(s) (um)
    y_BC: "float | list[float]"     # Beam center y (pixels), per distance
    z_BC: "float | list[float]"     # Beam center z (pixels), per distance
    px: float                       # Pixel size (um) -- shared across distances
    omega_start: float              # Omega start (degrees)
    omega_step: float               # Omega step (degrees, may be negative)
    n_frames: int                   # Frames per distance (NrFilesPerDistance)
    n_pixels_y: int                 # Detector pixels in y
    n_pixels_z: int                 # Detector pixels in z
    min_eta: float                  # Minimum eta angle (degrees)
    wavelength: float = 0.0         # X-ray wavelength (Angstroms)
    # Detector tilts (degrees).  Applied only in NF mode (flip_y=False), which
    # compares predictions to raw detector images at pixel level.  FF and
    # pf-HEDM workflows apply a DetCor correction at peak-finding time, so
    # their centroids are already tilt- and distortion-corrected; the
    # forward model therefore ignores these fields when flip_y=True to avoid
    # double-correcting.
    tx: float = 0.0
    ty: float = 0.0
    tz: float = 0.0
    flip_y: bool = True             # FF/PF: True (DetHor = yBC - ydet/px).
                                    # NF: False (pixel = yBC + ydet/px).
                                    # Validated against C code conventions.

    @property
    def n_distances(self) -> int:
        """Number of detector distances (1 when ``Lsd`` is a scalar)."""
        # Route through _as_list so list, tuple and scalar inputs are all
        # counted by the same rule.
        return len(self._as_list("Lsd"))

    def _as_list(self, attr):
        """Return the attribute named ``attr`` as a list.

        A list is returned as-is (same object), a tuple is converted to a
        new list, and any other value is wrapped in a one-element list.
        """
        value = getattr(self, attr)
        if isinstance(value, list):
            return value
        if isinstance(value, tuple):
            # Previously a tuple was treated as a scalar and wrapped as
            # ``[(a, b)]``, which silently produced a single "distance".
            return list(value)
        return [value]
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
@dataclass
class ScanConfig:
    """Beam-translation description for multi-scan pf-HEDM.

    A pf-HEDM acquisition moves the pencil beam through a set of
    Y-positions and records one complete omega sweep at each of them.
    NF-HEDM and FF-HEDM keep the beam at a single position and therefore
    do NOT use this configuration.
    """
    # (S,) beam y-position of every scan, in um
    beam_positions: torch.Tensor
    # Beam height, in um
    beam_size: float
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
@dataclass
class TriVoxelConfig:
    """Equilateral-triangle voxel grid used by NF-HEDM.

    Every voxel is described by its center ``(x, y)``, an ``edge_length``
    and an up/down flag ``ud``.  Its three corners follow from:

    .. code-block:: text

        gs  = edge_length / 2
        dy1 = edge_length / sqrt(3)       (flipped if ud < 0)
        dy2 = -edge_length / (2*sqrt(3))

        V0 = (x,      y + dy1)
        V1 = (x - gs, y + dy2)
        V2 = (x + gs, y + dy2)

    Matches ``simulateNF.c`` lines 556-572.
    """
    # (N,) edge length of each voxel, in um
    edge_lengths: torch.Tensor
    # (N,) up/down flag of each voxel (+1 or -1)
    ud: torch.Tensor
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
@dataclass
class SpotDescriptors:
    """Everything the forward model reports about the predicted spots.

    Angles are stored in radians and pixel coordinates are fractional.
    Tensors follow the shape convention ``(..., K, M)``, with ``K = 2*N``
    (two omega solutions per position) and ``M`` the number of HKL
    reflections.

    In multi-distance NF-HEDM, ``y_pixel``, ``z_pixel`` and ``layer_valid``
    carry one additional leading ``D`` (n_distances) dimension, and
    ``valid`` combines the angular validity with ALL-distances-found.
    """
    # (..., K, M) radians
    omega: torch.Tensor
    # (..., K, M) radians
    eta: torch.Tensor
    # (..., K, M) radians
    two_theta: torch.Tensor
    # (D, ..., K, M) or (..., K, M) fractional pixel
    y_pixel: torch.Tensor
    # (D, ..., K, M) or (..., K, M) fractional pixel
    z_pixel: torch.Tensor
    # (..., K, M) fractional frame (same at all distances)
    frame_nr: torch.Tensor
    # (..., K, M) float mask (1 = valid at ALL distances)
    valid: torch.Tensor
    # (D, ..., K, M) per-distance validity
    layer_valid: Optional[torch.Tensor] = None
    # (..., S, K, M) per-beam-position validity (pf)
    scan_mask: Optional[torch.Tensor] = None
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
# ---------------------------------------------------------------------------
|
|
147
|
+
# Forward model
|
|
148
|
+
# ---------------------------------------------------------------------------
|
|
149
|
+
|
|
150
|
+
class HEDMForwardModel(nn.Module):
|
|
151
|
+
"""Generic differentiable forward model for NF / FF / pf-HEDM.
|
|
152
|
+
|
|
153
|
+
Parameters
|
|
154
|
+
----------
|
|
155
|
+
hkls : Tensor (M, 3)
|
|
156
|
+
Reciprocal-space G-vectors in Cartesian coordinates (1/Angstroms).
|
|
157
|
+
These are the *nominal* (unstrained) G-vectors, already transformed
|
|
158
|
+
through the B matrix for the reference lattice.
|
|
159
|
+
thetas : Tensor (M,)
|
|
160
|
+
Nominal Bragg angles in radians corresponding to ``hkls``.
|
|
161
|
+
geometry : HEDMGeometry
|
|
162
|
+
Detector / scan geometry.
|
|
163
|
+
hkls_int : Tensor (M, 3), optional
|
|
164
|
+
Integer Miller indices. Required for ``correct_hkls_latc`` (strain).
|
|
165
|
+
If None, strain correction is unavailable.
|
|
166
|
+
scan_config : ScanConfig, optional
|
|
167
|
+
Multi-scan configuration. None for single-scan (standard NF / FF).
|
|
168
|
+
device : torch.device
|
|
169
|
+
Target device.
|
|
170
|
+
"""
|
|
171
|
+
|
|
172
|
+
# Match the C code: both now use M_PI/180.0 (previously the C code
|
|
173
|
+
# used hardcoded 13-digit constants that caused precision loss).
|
|
174
|
+
DEG2RAD = math.pi / 180.0
|
|
175
|
+
RAD2DEG = 180.0 / math.pi
|
|
176
|
+
|
|
177
|
+
def __init__(
|
|
178
|
+
self,
|
|
179
|
+
hkls: torch.Tensor,
|
|
180
|
+
thetas: torch.Tensor,
|
|
181
|
+
geometry: HEDMGeometry,
|
|
182
|
+
hkls_int: Optional[torch.Tensor] = None,
|
|
183
|
+
scan_config: Optional[ScanConfig] = None,
|
|
184
|
+
device: torch.device = torch.device("cpu"),
|
|
185
|
+
):
|
|
186
|
+
super().__init__()
|
|
187
|
+
|
|
188
|
+
self.register_buffer("hkls", hkls.to(device).float())
|
|
189
|
+
self.register_buffer("thetas", thetas.to(device).float())
|
|
190
|
+
|
|
191
|
+
if hkls_int is not None:
|
|
192
|
+
self.register_buffer("hkls_int", hkls_int.to(device).float())
|
|
193
|
+
else:
|
|
194
|
+
self.hkls_int = None
|
|
195
|
+
|
|
196
|
+
# Geometry -- per-distance arrays stored as tensors for vectorised projection
|
|
197
|
+
Lsd_list = geometry._as_list("Lsd")
|
|
198
|
+
yBC_list = geometry._as_list("y_BC")
|
|
199
|
+
zBC_list = geometry._as_list("z_BC")
|
|
200
|
+
self.n_distances = len(Lsd_list)
|
|
201
|
+
self.register_buffer("_Lsd", torch.tensor(Lsd_list, dtype=torch.float32, device=device))
|
|
202
|
+
self.register_buffer("_y_BC", torch.tensor(yBC_list, dtype=torch.float32, device=device))
|
|
203
|
+
self.register_buffer("_z_BC", torch.tensor(zBC_list, dtype=torch.float32, device=device))
|
|
204
|
+
# Convenience aliases for single-distance (backward compat / simple access)
|
|
205
|
+
self.Lsd = Lsd_list[0]
|
|
206
|
+
self.y_BC = yBC_list[0]
|
|
207
|
+
self.z_BC = zBC_list[0]
|
|
208
|
+
self.px = geometry.px
|
|
209
|
+
self.omega_start = geometry.omega_start
|
|
210
|
+
self.omega_step = geometry.omega_step
|
|
211
|
+
self.n_frames = geometry.n_frames
|
|
212
|
+
self.n_pixels_y = geometry.n_pixels_y
|
|
213
|
+
self.n_pixels_z = geometry.n_pixels_z
|
|
214
|
+
self.min_eta = geometry.min_eta * self.DEG2RAD # store in radians
|
|
215
|
+
self.wavelength = geometry.wavelength
|
|
216
|
+
self.flip_y = geometry.flip_y
|
|
217
|
+
|
|
218
|
+
# Detector tilts (degrees). Stored as an nn.Parameter so they can be
|
|
219
|
+
# optimised via gradient descent (auto-calibration). NF mode
|
|
220
|
+
# (flip_y=False) applies them via _apply_nf_tilt; FF/pf mode
|
|
221
|
+
# (flip_y=True) ignores them because the experimental pipeline pre-
|
|
222
|
+
# corrects for detector tilts at peak-finding time.
|
|
223
|
+
#
|
|
224
|
+
# Composition: RotMatTilts = Rz(tz) @ Ry(ty) @ Rx(tx)
|
|
225
|
+
# Matches RotationTilts() in NF_HEDM/src/SharedFuncsFit.c:230-266.
|
|
226
|
+
self.tilts = nn.Parameter(
|
|
227
|
+
torch.tensor([geometry.tx, geometry.ty, geometry.tz],
|
|
228
|
+
dtype=torch.float64, device=device),
|
|
229
|
+
requires_grad=False,
|
|
230
|
+
)
|
|
231
|
+
self.tx = float(geometry.tx)
|
|
232
|
+
self.ty = float(geometry.ty)
|
|
233
|
+
self.tz = float(geometry.tz)
|
|
234
|
+
self._has_tilts = abs(self.tx) + abs(self.ty) + abs(self.tz) > 0.0
|
|
235
|
+
|
|
236
|
+
# Scan config
|
|
237
|
+
self.scan_config = scan_config
|
|
238
|
+
if scan_config is not None:
|
|
239
|
+
self.register_buffer(
|
|
240
|
+
"_beam_positions", scan_config.beam_positions.to(device).float()
|
|
241
|
+
)
|
|
242
|
+
self._beam_size = scan_config.beam_size
|
|
243
|
+
|
|
244
|
+
self.epsilon = 1e-7
|
|
245
|
+
|
|
246
|
+
# ------------------------------------------------------------------
|
|
247
|
+
# euler2mat (ZXZ convention)
|
|
248
|
+
# ------------------------------------------------------------------
|
|
249
|
+
|
|
250
|
+
@staticmethod
|
|
251
|
+
def euler2mat(euler_angles: torch.Tensor) -> torch.Tensor:
|
|
252
|
+
"""Convert ZXZ Euler angles to rotation matrices.
|
|
253
|
+
|
|
254
|
+
Parameters
|
|
255
|
+
----------
|
|
256
|
+
euler_angles : Tensor (..., 3)
|
|
257
|
+
Euler angles (phi1, Phi, phi2) in radians.
|
|
258
|
+
|
|
259
|
+
Returns
|
|
260
|
+
-------
|
|
261
|
+
Tensor (..., 3, 3)
|
|
262
|
+
Rotation matrices.
|
|
263
|
+
"""
|
|
264
|
+
c = torch.cos(euler_angles)
|
|
265
|
+
s = torch.sin(euler_angles)
|
|
266
|
+
|
|
267
|
+
c0, c1, c2 = c[..., 0], c[..., 1], c[..., 2]
|
|
268
|
+
s0, s1, s2 = s[..., 0], s[..., 1], s[..., 2]
|
|
269
|
+
|
|
270
|
+
# ZXZ rotation matrix: R = Rz(phi1) @ Rx(Phi) @ Rz(phi2)
|
|
271
|
+
# Verified element-by-element against nfhedm.py lines 114-120
|
|
272
|
+
R = torch.zeros(*euler_angles.shape[:-1], 3, 3,
|
|
273
|
+
dtype=euler_angles.dtype, device=euler_angles.device)
|
|
274
|
+
R[..., 0, 0] = c0 * c2 - s0 * c1 * s2
|
|
275
|
+
R[..., 0, 1] = -s0 * c1 * c2 - c0 * s2
|
|
276
|
+
R[..., 0, 2] = s0 * s1
|
|
277
|
+
R[..., 1, 0] = s0 * c2 + c0 * c1 * s2
|
|
278
|
+
R[..., 1, 1] = c0 * c1 * c2 - s0 * s2
|
|
279
|
+
R[..., 1, 2] = -c0 * s1
|
|
280
|
+
R[..., 2, 0] = s1 * s2
|
|
281
|
+
R[..., 2, 1] = s1 * c2
|
|
282
|
+
R[..., 2, 2] = c1
|
|
283
|
+
|
|
284
|
+
return HEDMForwardModel.orthogonalize(R)
|
|
285
|
+
|
|
286
|
+
# ------------------------------------------------------------------
|
|
287
|
+
# orthogonalize (SO(3) projection via one Newton-Schulz iteration)
|
|
288
|
+
# ------------------------------------------------------------------
|
|
289
|
+
|
|
290
|
+
@staticmethod
|
|
291
|
+
def orthogonalize(R: torch.Tensor) -> torch.Tensor:
|
|
292
|
+
"""Project a (..., 3, 3) matrix onto SO(3).
|
|
293
|
+
|
|
294
|
+
Guarantees ``R^T R = I`` and ``det(R) = +1`` (proper rotation).
|
|
295
|
+
Uses one Newton-Schulz iteration which is differentiable and
|
|
296
|
+
numerically stable (no SVD singularity at repeated singular values).
|
|
297
|
+
|
|
298
|
+
For matrices already near SO(3) (like those from ``euler2mat``),
|
|
299
|
+
a single iteration suffices (quadratic convergence).
|
|
300
|
+
|
|
301
|
+
Parameters
|
|
302
|
+
----------
|
|
303
|
+
R : Tensor (..., 3, 3)
|
|
304
|
+
|
|
305
|
+
Returns
|
|
306
|
+
-------
|
|
307
|
+
Tensor (..., 3, 3) -- orthogonal with det = +1.
|
|
308
|
+
"""
|
|
309
|
+
# Newton-Schulz iteration for polar decomposition:
|
|
310
|
+
# Q_{k+1} = 0.5 * Q_k * (3*I - Q_k^T * Q_k)
|
|
311
|
+
# Converges quadratically for matrices near SO(3).
|
|
312
|
+
# One iteration is sufficient for matrices from euler2mat
|
|
313
|
+
# (error < 1e-14 after one step for input error < 1e-7).
|
|
314
|
+
I = torch.eye(3, dtype=R.dtype, device=R.device)
|
|
315
|
+
Q = R
|
|
316
|
+
Q = 0.5 * Q @ (3.0 * I - Q.transpose(-1, -2) @ Q)
|
|
317
|
+
# Ensure det = +1 (the iteration preserves det sign for near-SO(3) input,
|
|
318
|
+
# but we guard against pathological cases)
|
|
319
|
+
det = torch.det(Q)
|
|
320
|
+
# Where det < 0, negate (flips to proper rotation)
|
|
321
|
+
sign = torch.where(det < 0, torch.tensor(-1.0, dtype=R.dtype, device=R.device),
|
|
322
|
+
torch.tensor(1.0, dtype=R.dtype, device=R.device))
|
|
323
|
+
return Q * sign.unsqueeze(-1).unsqueeze(-1)
|
|
324
|
+
|
|
325
|
+
# ------------------------------------------------------------------
|
|
326
|
+
# safe_arccos
|
|
327
|
+
# ------------------------------------------------------------------
|
|
328
|
+
|
|
329
|
+
def safe_arccos(self, x: torch.Tensor) -> torch.Tensor:
    """Return ``acos(x)`` after clamping ``x`` into ``[-1+eps, 1-eps]``.

    ``eps`` is ``self.epsilon``; the clamp keeps the argument strictly
    inside the open interval (-1, 1).
    """
    lower = -1.0 + self.epsilon
    upper = 1.0 - self.epsilon
    clamped = torch.clamp(x, lower, upper)
    return torch.acos(clamped)
|
|
332
|
+
|
|
333
|
+
# ------------------------------------------------------------------
|
|
334
|
+
# rotate_strain_sample_to_crystal (port of C RotateStrainSampleToCrystal)
|
|
335
|
+
# ------------------------------------------------------------------
|
|
336
|
+
|
|
337
|
+
@staticmethod
|
|
338
|
+
def rotate_strain_sample_to_crystal(
|
|
339
|
+
orientation_matrices: torch.Tensor,
|
|
340
|
+
strain_sample: torch.Tensor,
|
|
341
|
+
) -> torch.Tensor:
|
|
342
|
+
"""Rotate a symmetric infinitesimal strain from sample to crystal frame.
|
|
343
|
+
|
|
344
|
+
Port of ``RotateStrainSampleToCrystal`` from
|
|
345
|
+
``FF_HEDM/src/ForwardSimulationCompressed.c:399-419``:
|
|
346
|
+
eps_crystal = OM^T . eps_sample . OM, in Voigt notation
|
|
347
|
+
[eps_11, eps_12, eps_13, eps_22, eps_23, eps_33].
|
|
348
|
+
|
|
349
|
+
Parameters
|
|
350
|
+
----------
|
|
351
|
+
orientation_matrices : Tensor (..., 3, 3)
|
|
352
|
+
strain_sample : Tensor (..., 6)
|
|
353
|
+
|
|
354
|
+
Returns
|
|
355
|
+
-------
|
|
356
|
+
strain_crystal : Tensor (..., 6)
|
|
357
|
+
"""
|
|
358
|
+
e = strain_sample
|
|
359
|
+
S = torch.stack([
|
|
360
|
+
torch.stack([e[..., 0], e[..., 1], e[..., 2]], dim=-1),
|
|
361
|
+
torch.stack([e[..., 1], e[..., 3], e[..., 4]], dim=-1),
|
|
362
|
+
torch.stack([e[..., 2], e[..., 4], e[..., 5]], dim=-1),
|
|
363
|
+
], dim=-2)
|
|
364
|
+
OM = orientation_matrices
|
|
365
|
+
C = torch.matmul(torch.matmul(OM.transpose(-1, -2), S), OM)
|
|
366
|
+
return torch.stack([
|
|
367
|
+
C[..., 0, 0], C[..., 0, 1], C[..., 0, 2],
|
|
368
|
+
C[..., 1, 1], C[..., 1, 2], C[..., 2, 2],
|
|
369
|
+
], dim=-1)
|
|
370
|
+
|
|
371
|
+
# ------------------------------------------------------------------
|
|
372
|
+
# correct_hkls_latc (port of C CorrectHKLsLatC)
|
|
373
|
+
# ------------------------------------------------------------------
|
|
374
|
+
|
|
375
|
+
def correct_hkls_latc(
    self,
    lattice_params: torch.Tensor,
    strain: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Derive G-vectors and Bragg angles from (possibly strained) lattice parameters.

    A reciprocal-lattice B matrix is assembled from the six lattice
    parameters and applied to the integer Miller indices in
    ``self.hkls_int``.  The original author describes this as a faithful
    port of ``CorrectHKLsLatC``
    (``FF_HEDM/src/FitPosOrStrainsDoubleDataset.c:214-252``) with the
    optional strain path of ``CorrectHKLsLatCEpsilon``
    (``FF_HEDM/src/ForwardSimulationCompressed.c:423-475``).

    Parameters
    ----------
    lattice_params : Tensor (..., 6)
        ``[a, b, c, alpha, beta, gamma]`` in Angstroms and degrees.  The
        leading ``...`` dimensions allow per-voxel or per-grain values.
    strain : Tensor (..., 6), optional
        Crystal-frame symmetric infinitesimal strain in Voigt form
        ``[eps_11, eps_12, eps_13, eps_22, eps_23, eps_33]``.  When given,
        the B matrix is multiplied on the left by ``(I + eps)^{-1}``:
        ``B = (I + eps)^{-1} @ B0``.  Use
        :meth:`rotate_strain_sample_to_crystal` to convert a sample-frame
        strain first.

    Returns
    -------
    hkls_cart : Tensor (..., M, 3)
        G-vectors in Cartesian reciprocal space (1/Angstroms).
    thetas : Tensor (..., M)
        Bragg angles in radians, from ``asin(wavelength * |G| / 2)`` with
        the sine clamped into ``[-1+eps, 1-eps]``.

    Raises
    ------
    RuntimeError
        If ``hkls_int`` was not provided at construction.
    """
    if self.hkls_int is None:
        raise RuntimeError(
            "correct_hkls_latc requires integer Miller indices "
            "(pass hkls_int to the constructor)."
        )

    eps = self.epsilon
    lo, hi = -1.0 + eps, 1.0 - eps

    # Cell edges, and cell angles converted from degrees to radians.
    a, b, c = (lattice_params[..., k] for k in range(3))
    alpha, beta, gamma = (
        lattice_params[..., k] * self.DEG2RAD for k in range(3, 6)
    )
    sin_a, cos_a = torch.sin(alpha), torch.cos(alpha)
    sin_b, cos_b = torch.sin(beta), torch.cos(beta)
    sin_g, cos_g = torch.sin(gamma), torch.cos(gamma)

    def reciprocal_angle(numerator, denominator):
        # acos of a regularised ratio, clamped away from +/-1.
        return torch.acos(
            torch.clamp(numerator / (denominator + eps), lo, hi)
        )

    # C: GammaPr = acosd((CosA*CosB - CosG) / (SinA*SinB))
    gamma_pr = reciprocal_angle(cos_a * cos_b - cos_g, sin_a * sin_b)
    beta_pr = reciprocal_angle(cos_g * cos_a - cos_b, sin_g * sin_a)
    sin_beta_pr = torch.sin(beta_pr)

    # Cell volume (regularised) and reciprocal edge lengths.
    vol_reg = a * b * c * sin_a * sin_beta_pr * sin_g + eps
    a_pr = b * c * sin_a / vol_reg
    b_pr = c * a * sin_b / vol_reg
    c_pr = a * b * sin_g / vol_reg

    # Upper-triangular B matrix, listed row-major and folded to (..., 3, 3).
    nil = torch.zeros_like(a)
    B = torch.stack(
        [
            a_pr, b_pr * torch.cos(gamma_pr), c_pr * torch.cos(beta_pr),
            nil, b_pr * torch.sin(gamma_pr), -c_pr * sin_beta_pr * cos_a,
            nil, nil, c_pr * sin_beta_pr * sin_a,
        ],
        dim=-1,
    ).reshape(*a.shape, 3, 3)

    if strain is not None:
        # Symmetric (I + eps) from the Voigt components, then left-multiply
        # B by its inverse.
        e11, e12, e13, e22, e23, e33 = (strain[..., k] for k in range(6))
        unit = torch.ones_like(e11)
        F_mat = torch.stack(
            [
                unit + e11, e12, e13,
                e12, unit + e22, e23,
                e13, e23, unit + e33,
            ],
            dim=-1,
        ).reshape(*e11.shape, 3, 3)
        B = torch.linalg.inv(F_mat) @ B

    # G = B @ hkl for every reflection; hkls_int is cast to B's dtype.
    hkls_cart = torch.einsum(
        "...ij,mj->...mi", B, self.hkls_int.to(B.dtype)
    )

    # d-spacing is 1 / |G|; the norm is floored to avoid division by zero.
    g_norm = torch.norm(hkls_cart, dim=-1).clamp(min=eps)
    d_spacing = 1.0 / g_norm

    # Bragg's law: sin(theta) = wavelength / (2 d).
    sin_theta = (self.wavelength / (2.0 * d_spacing)).clamp(lo, hi)
    thetas = torch.asin(sin_theta)

    return hkls_cart, thetas
|
|
509
|
+
|
|
510
|
+
# ------------------------------------------------------------------
|
|
511
|
+
# calc_bragg_geometry (omega quadratic solver + eta)
|
|
512
|
+
# ------------------------------------------------------------------
|
|
513
|
+
|
|
514
|
+
def calc_bragg_geometry(
    self,
    orientation_matrices: torch.Tensor,
    hkls_cart: Optional[torch.Tensor] = None,
    thetas: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Map orientations and G-vectors to omega, eta and 2-theta.

    The G-vectors are rotated by each orientation, the diffraction
    condition ``-Gx*cos(w) + Gy*sin(w) = v`` (with ``v = sin(theta)*|G|``)
    is solved for omega as a quadratic in ``cos(w)``, and eta is taken
    from the omega-rotated G-vector.  The original author describes the
    solver as a port of ``NF_HEDM/src/CalcDiffractionSpots.c:87-183``.

    Parameters
    ----------
    orientation_matrices : Tensor (..., N, 3, 3)
        Rotation matrix of every voxel/grain.
    hkls_cart : Tensor (..., M, 3) or None
        Cartesian reciprocal-space G-vectors; ``self.hkls`` when None.
    thetas : Tensor (..., M) or None
        Bragg angles in radians; ``self.thetas`` when None.

    Returns
    -------
    omega : Tensor (..., 2N, M)
        The "+" root for all N positions followed by the "-" root;
        0 where no solution was accepted.
    eta : Tensor (..., 2N, M)
    two_theta : Tensor (..., 2N, M)
    valid : Tensor (..., 2N, M)
        float32 mask, 1 where a solution exists and eta is at least
        ``self.min_eta`` away from both 0 and pi.
    """
    work_dtype = orientation_matrices.dtype
    g_ref = (self.hkls if hkls_cart is None else hkls_cart).to(work_dtype)
    theta_ref = (self.thetas if thetas is None else thetas).to(work_dtype)
    eps = self.epsilon
    # Threshold for the Gy ~ 0 branch; the original author matches the C
    # code's ``almostzero`` here.
    tiny_gy = 1e-12

    # Rotated G-vectors, (..., N, M, 3).  ``g_ref`` may be shared (M, 3)
    # or carry its own leading batch dimensions.
    G_C = torch.einsum("...nij,...mj->...nmi", orientation_matrices, g_ref)
    Gx = G_C[..., 0]
    Gy = G_C[..., 1]

    # v uses the norm of the UNROTATED G-vectors (as the C code does).
    v = (torch.sin(theta_ref) * torch.norm(g_ref, dim=-1))
    v = v.unsqueeze(-2).expand_as(Gx)

    # Quadratic a*cos^2(w) + b*cos(w) + c = 0 with Gy^2 regularised by eps.
    gy_sq_reg = Gy * Gy + eps
    quad_a = 1.0 + Gx * Gx / gy_sq_reg
    quad_b = 2.0 * v * Gx / gy_sq_reg
    quad_c = v * v / gy_sq_reg - 1.0
    disc = quad_b * quad_b - 4.0 * quad_a * quad_c
    root = torch.sqrt(torch.abs(disc))
    two_a = 2.0 * quad_a
    cos_plus = (-quad_b + root) / two_a
    cos_minus = (-quad_b - root) / two_a

    def residual(w):
        # How far omega = w is from satisfying the diffraction condition.
        return torch.abs(-Gx * torch.cos(w) + Gy * torch.sin(w) - v)

    def resolve_sign(cos_w):
        # acos only fixes |omega|; keep whichever sign fits better.
        w = self.safe_arccos(cos_w)
        return torch.where(residual(w) < residual(-w), w, -w)

    omega_quad_p = resolve_sign(cos_plus)
    omega_quad_n = resolve_sign(cos_minus)

    # Gy ~ 0 branch: cos(w) = -v / Gx with the solution pair +w / -w.
    gy_zero = torch.abs(Gy) < tiny_gy
    cos_special = -v / (Gx + eps)
    special_ok = (
        (torch.abs(cos_special) <= 1.0) & gy_zero & (torch.abs(Gx) > eps)
    )
    omega_special = self.safe_arccos(cos_special)

    # The quadratic is accepted only away from Gy ~ 0, for a non-negative
    # discriminant and a cosine inside [-1, 1].
    disc_ok = (disc >= 0) & (~gy_zero)
    quad_p_ok = disc_ok & ((cos_plus >= -1.0) & (cos_plus <= 1.0))
    quad_n_ok = disc_ok & ((cos_minus >= -1.0) & (cos_minus <= 1.0))

    omega_p = torch.where(
        special_ok, omega_special,
        torch.where(quad_p_ok, omega_quad_p, torch.zeros_like(omega_quad_p)),
    )
    omega_n = torch.where(
        special_ok, -omega_special,
        torch.where(quad_n_ok, omega_quad_n, torch.zeros_like(omega_quad_n)),
    )

    # Stack both roots along the position axis: (..., 2N, M).
    all_omega = torch.cat([omega_p, omega_n], dim=-2)

    # Rz(omega) = [[cos, -sin, 0], [sin, cos, 0], [0, 0, 1]] per spot.
    cos_w = torch.cos(all_omega)
    sin_w = torch.sin(all_omega)
    nil = torch.zeros_like(cos_w)
    unit = torch.ones_like(cos_w)
    Omega_mat = torch.stack(
        [
            cos_w, -sin_w, nil,
            sin_w, cos_w, nil,
            nil, nil, unit,
        ],
        dim=-1,
    ).reshape(*all_omega.shape, 3, 3)

    # Rotate the (duplicated) G-vectors by omega.
    G_C_doubled = torch.cat([G_C, G_C], dim=-3)  # (..., 2N, M, 3)
    rotated = torch.einsum("...kmij,...kmj->...kmi", Omega_mat, G_C_doubled)
    rot_y = rotated[..., 1]
    rot_z = rotated[..., 2]

    # Eta: angle from +z in the y-z plane, signed opposite to y.
    radius_yz = torch.sqrt(rot_y * rot_y + rot_z * rot_z).clamp(min=eps)
    eta = self.safe_arccos(rot_z / radius_yz)
    eta = -torch.sign(rot_y) * eta

    # 2-theta, broadcast over the 2N axis.
    two_theta = (2.0 * theta_ref.unsqueeze(-2)).expand_as(all_omega)

    # A spot is valid if either solver branch accepted it...
    valid = torch.cat(
        [quad_p_ok | special_ok, quad_n_ok | special_ok], dim=-2
    ).float()
    # ...and eta keeps at least min_eta distance from 0 and from pi.
    abs_eta = torch.abs(eta)
    eta_ok = (abs_eta >= self.min_eta) & ((math.pi - abs_eta) >= self.min_eta)
    valid = valid * eta_ok.float()

    return all_omega, eta, two_theta, valid
|
|
683
|
+
|
|
684
|
+
# ------------------------------------------------------------------
|
|
685
|
+
# Tilt rotation matrix (RotationTilts in SharedFuncsFit.c:230-266)
|
|
686
|
+
# ------------------------------------------------------------------
|
|
687
|
+
|
|
688
|
+
@staticmethod
|
|
689
|
+
def _build_rot_tilts(tx_deg: float, ty_deg: float, tz_deg: float,
|
|
690
|
+
device: torch.device) -> torch.Tensor:
|
|
691
|
+
"""Build the 3x3 NF-style tilt rotation matrix Rz(tz) @ Ry(ty) @ Rx(tx).
|
|
692
|
+
|
|
693
|
+
Matches RotationTilts() in NF_HEDM/src/SharedFuncsFit.c:230-266.
|
|
694
|
+
"""
|
|
695
|
+
d2r = math.pi / 180.0
|
|
696
|
+
tx, ty, tz = tx_deg * d2r, ty_deg * d2r, tz_deg * d2r
|
|
697
|
+
cx, sx = math.cos(tx), math.sin(tx)
|
|
698
|
+
cy, sy = math.cos(ty), math.sin(ty)
|
|
699
|
+
cz, sz = math.cos(tz), math.sin(tz)
|
|
700
|
+
Rx = torch.tensor([[1, 0, 0], [0, cx, -sx], [0, sx, cx]], dtype=torch.float64)
|
|
701
|
+
Ry = torch.tensor([[cy, 0, sy], [0, 1, 0], [-sy, 0, cy]], dtype=torch.float64)
|
|
702
|
+
Rz = torch.tensor([[cz, -sz, 0], [sz, cz, 0], [0, 0, 1]], dtype=torch.float64)
|
|
703
|
+
# Composition matches NF C: Rz @ Ry @ Rx
|
|
704
|
+
return (Rz @ Ry @ Rx).to(device)
|
|
705
|
+
|
|
706
|
+
def _build_rot_tilts_from_param(self, dtype) -> torch.Tensor:
|
|
707
|
+
"""Build Rz(tz) @ Ry(ty) @ Rx(tx) from self.tilts (differentiable)."""
|
|
708
|
+
d2r = math.pi / 180.0
|
|
709
|
+
t = self.tilts.to(dtype) * d2r
|
|
710
|
+
tx_, ty_, tz_ = t[0], t[1], t[2]
|
|
711
|
+
cx, sx = torch.cos(tx_), torch.sin(tx_)
|
|
712
|
+
cy, sy = torch.cos(ty_), torch.sin(ty_)
|
|
713
|
+
cz, sz = torch.cos(tz_), torch.sin(tz_)
|
|
714
|
+
zero = torch.zeros((), dtype=dtype, device=t.device)
|
|
715
|
+
one = torch.ones((), dtype=dtype, device=t.device)
|
|
716
|
+
Rx = torch.stack([
|
|
717
|
+
torch.stack([one, zero, zero]),
|
|
718
|
+
torch.stack([zero, cx, -sx ]),
|
|
719
|
+
torch.stack([zero, sx, cx ]),
|
|
720
|
+
])
|
|
721
|
+
Ry = torch.stack([
|
|
722
|
+
torch.stack([cy, zero, sy ]),
|
|
723
|
+
torch.stack([zero, one, zero]),
|
|
724
|
+
torch.stack([-sy, zero, cy ]),
|
|
725
|
+
])
|
|
726
|
+
Rz = torch.stack([
|
|
727
|
+
torch.stack([cz, -sz, zero]),
|
|
728
|
+
torch.stack([sz, cz, zero]),
|
|
729
|
+
torch.stack([zero, zero, one]),
|
|
730
|
+
])
|
|
731
|
+
return Rz @ Ry @ Rx
|
|
732
|
+
|
|
733
|
+
def _apply_nf_tilt(self, ydet: torch.Tensor, zdet: torch.Tensor,
|
|
734
|
+
Lsd_val) -> "tuple[torch.Tensor, torch.Tensor]":
|
|
735
|
+
"""Apply the NF detector-tilt correction to lab-frame (ydet, zdet).
|
|
736
|
+
|
|
737
|
+
Ports the ray-plane intersection in
|
|
738
|
+
``NF_HEDM/src/SharedFuncsFit.c:947-958``: builds
|
|
739
|
+
P0 = RotMatTilts @ [-Lsd, 0, 0], P1 = RotMatTilts @ [0, ydet, zdet],
|
|
740
|
+
and returns the (y, z) coordinates where the line from P0 through P1
|
|
741
|
+
crosses the plane x = 0. Reduces to the identity when tilts are zero.
|
|
742
|
+
|
|
743
|
+
The rotation matrix is rebuilt from ``self.tilts`` on every call, so
|
|
744
|
+
if ``self.tilts.requires_grad`` is True the tilt parameters enter the
|
|
745
|
+
autograd graph and can be optimised via gradient descent
|
|
746
|
+
(auto-calibration). ``Lsd_val`` can be a Python scalar or a scalar tensor.
|
|
747
|
+
"""
|
|
748
|
+
if not self._has_tilts and not self.tilts.requires_grad:
|
|
749
|
+
return ydet, zdet
|
|
750
|
+
dtype = ydet.dtype
|
|
751
|
+
R = self._build_rot_tilts_from_param(dtype)
|
|
752
|
+
# P0 = -Lsd * R[:, 0] (3 scalars)
|
|
753
|
+
p0x = -Lsd_val * R[0, 0]
|
|
754
|
+
p0y = -Lsd_val * R[1, 0]
|
|
755
|
+
p0z = -Lsd_val * R[2, 0]
|
|
756
|
+
# P1 = ydet * R[:, 1] + zdet * R[:, 2] (pointwise on tensors)
|
|
757
|
+
P1x = ydet * R[0, 1] + zdet * R[0, 2]
|
|
758
|
+
P1y = ydet * R[1, 1] + zdet * R[1, 2]
|
|
759
|
+
P1z = ydet * R[2, 1] + zdet * R[2, 2]
|
|
760
|
+
ABCx = P1x - p0x
|
|
761
|
+
ABCy = P1y - p0y
|
|
762
|
+
ABCz = P1z - p0z
|
|
763
|
+
safe_denom = torch.where(
|
|
764
|
+
torch.abs(ABCx) < self.epsilon,
|
|
765
|
+
torch.full_like(ABCx, self.epsilon),
|
|
766
|
+
ABCx,
|
|
767
|
+
)
|
|
768
|
+
out_y = p0y - ABCy * p0x / safe_denom
|
|
769
|
+
out_z = p0z - ABCz * p0x / safe_denom
|
|
770
|
+
return out_y, out_z
|
|
771
|
+
|
|
772
|
+
@staticmethod
|
|
773
|
+
def _build_ff_tilt_rot(tx_deg: float, ty_deg: float, tz_deg: float,
|
|
774
|
+
device: torch.device) -> torch.Tensor:
|
|
775
|
+
"""Build the 3x3 FF-style tilt rotation matrix Rx(tx) @ Ry(ty) @ Rz(tz).
|
|
776
|
+
|
|
777
|
+
Matches CorrectTiltSpatialDistortion() in
|
|
778
|
+
FF_HEDM/src/ForwardSimulationCompressed.c:593-612.
|
|
779
|
+
Note the composition differs from the NF convention.
|
|
780
|
+
"""
|
|
781
|
+
d2r = math.pi / 180.0
|
|
782
|
+
tx, ty, tz = tx_deg * d2r, ty_deg * d2r, tz_deg * d2r
|
|
783
|
+
cx, sx = math.cos(tx), math.sin(tx)
|
|
784
|
+
cy, sy = math.cos(ty), math.sin(ty)
|
|
785
|
+
cz, sz = math.cos(tz), math.sin(tz)
|
|
786
|
+
Rx = torch.tensor([[1, 0, 0], [0, cx, -sx], [0, sx, cx]], dtype=torch.float64)
|
|
787
|
+
Ry = torch.tensor([[cy, 0, sy], [0, 1, 0], [-sy, 0, cy]], dtype=torch.float64)
|
|
788
|
+
Rz = torch.tensor([[cz, -sz, 0], [sz, cz, 0], [0, 0, 1]], dtype=torch.float64)
|
|
789
|
+
# Composition matches FF C: Rx @ Ry @ Rz
|
|
790
|
+
return (Rx @ Ry @ Rz).to(device)
|
|
791
|
+
|
|
792
|
+
# ------------------------------------------------------------------
|
|
793
|
+
# project_to_detector
|
|
794
|
+
# ------------------------------------------------------------------
|
|
795
|
+
|
|
796
|
+
def project_to_detector(
    self,
    omega: torch.Tensor,
    eta: torch.Tensor,
    two_theta: torch.Tensor,
    positions: torch.Tensor,
    valid: torch.Tensor,
) -> SpotDescriptors:
    """Position-dependent detector projection for one or more distances.

    Implements the geometry from ``SharedFuncsFit.c:DisplacementSpots``
    (lines 269-292). For multi-distance NF-HEDM, projects to each
    distance and requires spots to be on-detector at ALL distances
    (the ``AllDistsFound`` logic from ``CalcFracOverlap``, line 638).

    Parameters
    ----------
    omega : Tensor (..., 2N, M)
    eta : Tensor (..., 2N, M)
    two_theta : Tensor (..., 2N, M)
    positions : Tensor (N, 3) or (..., N, 3)
        Real-space positions [x, y, z] in micrometers.
    valid : Tensor (..., 2N, M)

    Returns
    -------
    SpotDescriptors
        ``y_pixel`` / ``z_pixel`` have shape (..., 2N, M) for a single
        distance and (D, ..., 2N, M) otherwise. ``valid`` is the product
        of the per-distance validities; ``layer_valid`` holds the
        per-distance validity (D, ..., 2N, M), or ``None`` when D == 1.
    """
    # Each position is repeated along the voxel axis so that it lines up
    # with the doubled (2N) axis of the angle tensors.
    pos_doubled = torch.cat([positions, positions], dim=-2)  # (..., 2N, 3)

    cos_w = torch.cos(omega)  # (..., 2N, M)
    sin_w = torch.sin(omega)

    px = pos_doubled[..., 0].unsqueeze(-1)  # (..., 2N, 1)
    py = pos_doubled[..., 1].unsqueeze(-1)
    pz = pos_doubled[..., 2].unsqueeze(-1)

    # Rotate (x, y, z) by omega about the z-axis.
    x_grain = px * cos_w - py * sin_w  # (..., 2N, M)
    y_grain = px * sin_w + py * cos_w
    z_grain = pz.expand_as(x_grain)

    tan_2th = torch.tan(two_theta)
    sin_eta = torch.sin(eta)
    cos_eta = torch.cos(eta)

    # Frame number (same at all distances -- omega doesn't change)
    frame_nr = (omega / self.DEG2RAD - self.omega_start) / self.omega_step
    frame_ok = (frame_nr >= 0) & (frame_nr < self.n_frames)

    # Per-distance constants, reshaped to (D, 1, ..., 1) so that they
    # broadcast against the (..., 2N, M) spot tensors.
    D = self.n_distances
    dtype = omega.dtype
    per_dist_shape = (D,) + (1,) * omega.ndim
    Lsd_d = self._Lsd.to(dtype).reshape(per_dist_shape)
    yBC_d = self._y_BC.to(dtype).reshape(per_dist_shape)
    zBC_d = self._z_BC.to(dtype).reshape(per_dist_shape)

    # ydet, zdet, y_pixel, z_pixel all get shape (D, ..., 2N, M)
    dist_d = Lsd_d - x_grain.unsqueeze(0)
    ydet_d = y_grain.unsqueeze(0) - dist_d * tan_2th.unsqueeze(0) * sin_eta.unsqueeze(0)
    zdet_d = z_grain.unsqueeze(0) + dist_d * tan_2th.unsqueeze(0) * cos_eta.unsqueeze(0)

    # Apply detector tilt -- NF mode only.
    #
    # Design note: FF and pf-HEDM experimental workflows apply a DetCor
    # correction at peak-finding time, so the per-spot centroids in
    # SpotMatrix.csv are already tilt- and distortion-corrected. A
    # differentiable forward model targeting FF/pf experimental data
    # therefore must NOT apply tilts -- doing so would double-correct.
    # NF-HEDM works at pixel level against raw detector images, with no
    # DetCor step, so the forward model MUST include tilts to produce
    # pixel predictions that match real NF measurements.
    #
    # The NF branch below ports the ray-plane intersection from
    # NF_HEDM/src/SharedFuncsFit.c:947-958 (composition Rz @ Ry @ Rx,
    # P0 = R @ [-Lsd, 0, 0]). The FF/pf path ignores tilts entirely.
    #
    # The guard mirrors the one inside ``_apply_nf_tilt``: tilts that are
    # numerically zero but trainable must still be routed through the
    # tilt correction, otherwise they never enter the autograd graph and
    # auto-calibration starting from zero tilt receives no gradient.
    tilts_trainable = bool(
        getattr(getattr(self, "tilts", None), "requires_grad", False)
    )
    if (not self.flip_y) and (self._has_tilts or tilts_trainable):
        # One call per distance keeps the per-Lsd handling explicit
        # (typical NF has 1-4 distances).
        Lsd_list = self._Lsd.to(dtype)
        tilted = [
            self._apply_nf_tilt(ydet_d[d], zdet_d[d], Lsd_list[d])
            for d in range(D)
        ]
        ydet_d = torch.stack([yz[0] for yz in tilted], dim=0)
        zdet_d = torch.stack([yz[1] for yz in tilted], dim=0)

    # FF/PF: y-axis on detector flipped (yBC - ydet/px), validated against C
    # NF: not flipped (yBC + ydet/px), validated against C
    y_sign = -1.0 if self.flip_y else 1.0
    y_pixel_d = yBC_d + y_sign * ydet_d / self.px  # (D, ..., 2N, M)
    z_pixel_d = zBC_d + zdet_d / self.px

    # Per-distance detector bounds
    layer_bounds_ok = (
        (y_pixel_d >= 0) & (y_pixel_d < self.n_pixels_y) &
        (z_pixel_d >= 0) & (z_pixel_d < self.n_pixels_z)
    )  # (D, ..., 2N, M)

    # Per-distance validity = angular valid & frame ok & detector bounds
    layer_valid = (
        valid.unsqueeze(0) * frame_ok.unsqueeze(0).float() * layer_bounds_ok.float()
    )

    # Overall valid = valid at ALL distances (AllDistsFound)
    all_dists_valid = layer_valid.prod(dim=0)  # (..., 2N, M)

    # For single-distance, squeeze out the D dimension for convenience
    if D == 1:
        y_pixel_out = y_pixel_d.squeeze(0)
        z_pixel_out = z_pixel_d.squeeze(0)
        layer_valid_out = None
    else:
        y_pixel_out = y_pixel_d
        z_pixel_out = z_pixel_d
        layer_valid_out = layer_valid

    return SpotDescriptors(
        omega=omega,
        eta=eta,
        two_theta=two_theta,
        y_pixel=y_pixel_out,
        z_pixel=z_pixel_out,
        frame_nr=frame_nr,
        valid=all_dists_valid,
        layer_valid=layer_valid_out,
    )
|
|
930
|
+
|
|
931
|
+
# ------------------------------------------------------------------
|
|
932
|
+
# forward (orchestrator)
|
|
933
|
+
# ------------------------------------------------------------------
|
|
934
|
+
|
|
935
|
+
def forward(
    self,
    euler_angles: torch.Tensor,
    positions: torch.Tensor,
    lattice_params: Optional[torch.Tensor] = None,
    strain: Optional[torch.Tensor] = None,
) -> SpotDescriptors:
    """Run the complete forward simulation.

    Steps: Euler angles -> orientation matrices, optional strained
    G-vectors / thetas, Bragg geometry, detector projection and, when a
    scan configuration is present, the beam-illumination filter.

    Parameters
    ----------
    euler_angles : Tensor (..., N, 3)
        Euler angles (phi1, Phi, phi2) in radians at each position.
    positions : Tensor (N, 3) or (N, 2) or (..., N, 3)
        Real-space positions in micrometers. If (N,2), z is padded to 0.
    lattice_params : Tensor (..., 6) or (..., N, 6), optional
        Strained lattice parameters [a,b,c,alpha,beta,gamma] in
        Angstroms/degrees. None = use nominal hkls/thetas (no strain).
    strain : Tensor (..., 6) or (..., N, 6), optional
        Crystal-frame symmetric infinitesimal strain in Voigt form
        [eps_11, eps_12, eps_13, eps_22, eps_23, eps_33]. Applied as
        B = (I + eps)^{-1} @ B0 in addition to any lattice-parameter
        strain expressed through ``lattice_params``. Requires
        ``lattice_params`` to be supplied.

    Returns
    -------
    SpotDescriptors

    Raises
    ------
    ValueError
        If ``strain`` is given while ``lattice_params`` is None.
    """
    # Two-column positions are accepted for backward compatibility; the
    # missing z coordinate is filled with zero.
    is_planar = positions.shape[-1] == 2
    if is_planar:
        positions = F.pad(positions, (0, 1), value=0.0)

    # Step 1: orientation matrices.
    orientation_matrices = self.euler2mat(euler_angles)

    # Step 2: strained G-vectors / thetas, only when a lattice is given.
    if lattice_params is None:
        if strain is not None:
            raise ValueError(
                "strain was supplied but lattice_params is None; strain "
                "requires a reference lattice to apply (I + eps)^{-1} @ B0."
            )
        strained_hkls, strained_thetas = None, None
    else:
        strained_hkls, strained_thetas = self.correct_hkls_latc(
            lattice_params, strain=strain
        )

    # Step 3: Bragg geometry.
    bragg = self.calc_bragg_geometry(
        orientation_matrices, strained_hkls, strained_thetas
    )
    omega, eta, two_theta, valid = bragg

    # Step 4: detector projection.
    spots = self.project_to_detector(omega, eta, two_theta, positions, valid)

    # Step 5: scan filter, only for multi-scan geometries.
    if self.scan_config is None:
        return spots
    return self.filter_by_scan(spots, positions)
|
|
995
|
+
|
|
996
|
+
# ------------------------------------------------------------------
|
|
997
|
+
# filter_by_scan (beam proximity for pf-HEDM)
|
|
998
|
+
# ------------------------------------------------------------------
|
|
999
|
+
|
|
1000
|
+
def filter_by_scan(
    self, spots: SpotDescriptors, positions: torch.Tensor
) -> SpotDescriptors:
    """Apply beam illumination filter for multi-scan geometry.

    For each spot, checks whether the omega-rotated y-position of the
    source voxel falls within the beam at each scan position.

    Ports ``FitOrStrainsScanningOMP.c:1050-1058``:
        yRot = posX * sin(omega) + posY * cos(omega)
        |yRot - beam_y[scan]| < beam_size / 2

    Parameters
    ----------
    spots : SpotDescriptors
    positions : Tensor (..., N, 3)

    Returns
    -------
    SpotDescriptors with ``scan_mask`` populated, shape (..., S, 2N, M)
    for S beam positions. All other fields, including ``layer_valid``,
    are carried over from ``spots``. When ``self.scan_config`` is None
    the input is returned unchanged.
    """
    if self.scan_config is None:
        return spots

    # Repeat positions so they line up with the doubled (2N) spot axis.
    pos_doubled = torch.cat([positions, positions], dim=-2)  # (..., 2N, 3)

    cos_w = torch.cos(spots.omega)  # (..., 2N, M)
    sin_w = torch.sin(spots.omega)

    px = pos_doubled[..., 0].unsqueeze(-1)  # (..., 2N, 1)
    py = pos_doubled[..., 1].unsqueeze(-1)

    # Omega-rotated y position
    y_rot = px * sin_w + py * cos_w  # (..., 2N, M)

    beam_y = self._beam_positions  # (S,)
    half_beam = self._beam_size / 2.0

    # |yRot - beam_y[s]| < half_beam, broadcast as
    # (..., 1, 2N, M) against (S, 1, 1) -> (..., S, 2N, M)
    y_rot_exp = y_rot.unsqueeze(-3)
    beam_y_exp = beam_y.reshape(-1, 1, 1)

    scan_mask = (torch.abs(y_rot_exp - beam_y_exp) < half_beam).float()
    # Combine with overall validity
    scan_mask = scan_mask * spots.valid.unsqueeze(-3)

    # Rebuild the descriptor with the new mask. ``layer_valid`` must be
    # forwarded explicitly, otherwise the per-distance validity computed
    # by the detector projection is silently reset to its default.
    return SpotDescriptors(
        omega=spots.omega,
        eta=spots.eta,
        two_theta=spots.two_theta,
        y_pixel=spots.y_pixel,
        z_pixel=spots.z_pixel,
        frame_nr=spots.frame_nr,
        valid=spots.valid,
        layer_valid=spots.layer_valid,
        scan_mask=scan_mask,
    )
|
|
1061
|
+
|
|
1062
|
+
# ------------------------------------------------------------------
|
|
1063
|
+
# predict_images (NF output mode: Gaussian splatting)
|
|
1064
|
+
# ------------------------------------------------------------------
|
|
1065
|
+
|
|
1066
|
+
@staticmethod
|
|
1067
|
+
def predict_images(
|
|
1068
|
+
spots: SpotDescriptors,
|
|
1069
|
+
n_frames: int,
|
|
1070
|
+
n_pixels_y: int,
|
|
1071
|
+
n_pixels_z: int,
|
|
1072
|
+
sigma: float = 1.0,
|
|
1073
|
+
radius: int = 3,
|
|
1074
|
+
) -> torch.Tensor:
|
|
1075
|
+
"""Gaussian-splat predicted spots onto detector grid.
|
|
1076
|
+
|
|
1077
|
+
This is the NF-HEDM output mode. Each valid spot is represented
|
|
1078
|
+
as a Gaussian blob on the (frame, y, z) detector volume.
|
|
1079
|
+
|
|
1080
|
+
Parameters
|
|
1081
|
+
----------
|
|
1082
|
+
spots : SpotDescriptors
|
|
1083
|
+
Output from ``forward()``.
|
|
1084
|
+
n_frames, n_pixels_y, n_pixels_z : int
|
|
1085
|
+
Grid dimensions.
|
|
1086
|
+
sigma : float
|
|
1087
|
+
Gaussian kernel sigma in pixels.
|
|
1088
|
+
radius : int
|
|
1089
|
+
Kernel radius in pixels.
|
|
1090
|
+
|
|
1091
|
+
Returns
|
|
1092
|
+
-------
|
|
1093
|
+
Tensor (..., n_frames, n_pixels_y, n_pixels_z)
|
|
1094
|
+
"""
|
|
1095
|
+
# Extract coordinates and mask
|
|
1096
|
+
frame_nr = spots.frame_nr # (..., K, M)
|
|
1097
|
+
y_pix = spots.y_pixel
|
|
1098
|
+
z_pix = spots.z_pixel
|
|
1099
|
+
valid = spots.valid
|
|
1100
|
+
|
|
1101
|
+
# Flatten batch dims for processing
|
|
1102
|
+
orig_shape = frame_nr.shape # (..., K, M)
|
|
1103
|
+
batch_shape = orig_shape[:-2]
|
|
1104
|
+
n_batch = 1
|
|
1105
|
+
for s in batch_shape:
|
|
1106
|
+
n_batch *= s
|
|
1107
|
+
KM = orig_shape[-2] * orig_shape[-1]
|
|
1108
|
+
|
|
1109
|
+
# Reshape to (B, KM, 3) where 3 = (frame, y, z)
|
|
1110
|
+
coords = torch.stack([frame_nr, y_pix, z_pix], dim=-1) # (..., K, M, 3)
|
|
1111
|
+
coords = coords.reshape(n_batch, KM, 3)
|
|
1112
|
+
mask = valid.reshape(n_batch, KM)
|
|
1113
|
+
|
|
1114
|
+
# Zero out invalid spots
|
|
1115
|
+
coords = coords * mask.unsqueeze(-1)
|
|
1116
|
+
|
|
1117
|
+
device = coords.device
|
|
1118
|
+
dtype = coords.dtype
|
|
1119
|
+
|
|
1120
|
+
grids = torch.zeros(n_batch, n_frames, n_pixels_y, n_pixels_z,
|
|
1121
|
+
dtype=dtype, device=device)
|
|
1122
|
+
gaussian_factor = -0.5 / (sigma ** 2)
|
|
1123
|
+
|
|
1124
|
+
# Filter non-zero coordinates
|
|
1125
|
+
batch_ids = torch.arange(n_batch, device=device).unsqueeze(1).expand(-1, KM)
|
|
1126
|
+
non_zero_mask = mask > 0.5
|
|
1127
|
+
non_zero_coords = coords[non_zero_mask]
|
|
1128
|
+
non_zero_batch = batch_ids[non_zero_mask]
|
|
1129
|
+
|
|
1130
|
+
if non_zero_coords.shape[0] == 0:
|
|
1131
|
+
return grids.reshape(*batch_shape, n_frames, n_pixels_y, n_pixels_z)
|
|
1132
|
+
|
|
1133
|
+
rounded = non_zero_coords.round().long()
|
|
1134
|
+
|
|
1135
|
+
# Neighborhood offsets
|
|
1136
|
+
offsets = torch.arange(-radius, radius + 1, device=device)
|
|
1137
|
+
oz, ox, oy = torch.meshgrid(offsets, offsets, offsets, indexing="ij")
|
|
1138
|
+
local_offsets = torch.stack([oz, ox, oy], dim=-1).reshape(-1, 3)
|
|
1139
|
+
n_local = local_offsets.shape[0]
|
|
1140
|
+
|
|
1141
|
+
expanded_centers = rounded.unsqueeze(1) # (P, 1, 3)
|
|
1142
|
+
neighbors = expanded_centers + local_offsets.unsqueeze(0) # (P, L, 3)
|
|
1143
|
+
|
|
1144
|
+
f_neigh = neighbors[..., 0].clamp(0, n_frames - 1)
|
|
1145
|
+
y_neigh = neighbors[..., 1].clamp(0, n_pixels_y - 1)
|
|
1146
|
+
z_neigh = neighbors[..., 2].clamp(0, n_pixels_z - 1)
|
|
1147
|
+
|
|
1148
|
+
distances = torch.sum(
|
|
1149
|
+
(neighbors.float() - non_zero_coords.unsqueeze(1)) ** 2, dim=-1
|
|
1150
|
+
)
|
|
1151
|
+
weights = torch.exp(distances * gaussian_factor)
|
|
1152
|
+
|
|
1153
|
+
# Flatten for scatter_add
|
|
1154
|
+
f_flat = f_neigh.flatten()
|
|
1155
|
+
y_flat = y_neigh.flatten()
|
|
1156
|
+
z_flat = z_neigh.flatten()
|
|
1157
|
+
w_flat = weights.flatten()
|
|
1158
|
+
|
|
1159
|
+
batch_flat = non_zero_batch.unsqueeze(1).expand(-1, n_local).flatten()
|
|
1160
|
+
|
|
1161
|
+
flat_idx = (batch_flat * (n_frames * n_pixels_y * n_pixels_z)
|
|
1162
|
+
+ f_flat * (n_pixels_y * n_pixels_z)
|
|
1163
|
+
+ y_flat * n_pixels_z
|
|
1164
|
+
+ z_flat)
|
|
1165
|
+
|
|
1166
|
+
grids.view(-1).scatter_add_(0, flat_idx, w_flat)
|
|
1167
|
+
|
|
1168
|
+
return grids.reshape(*batch_shape, n_frames, n_pixels_y, n_pixels_z)
|
|
1169
|
+
|
|
1170
|
+
# ------------------------------------------------------------------
|
|
1171
|
+
# predict_spot_coords (FF/pf output mode)
|
|
1172
|
+
# ------------------------------------------------------------------
|
|
1173
|
+
|
|
1174
|
+
@staticmethod
|
|
1175
|
+
def predict_spot_coords(
|
|
1176
|
+
spots: SpotDescriptors,
|
|
1177
|
+
space: str = "angular",
|
|
1178
|
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
1179
|
+
"""Extract spot coordinates for COM matching (FF/pf mode).
|
|
1180
|
+
|
|
1181
|
+
Parameters
|
|
1182
|
+
----------
|
|
1183
|
+
spots : SpotDescriptors
|
|
1184
|
+
space : str
|
|
1185
|
+
``"angular"``: return (2theta, eta, omega) in radians.
|
|
1186
|
+
``"detector"``: return (y_pixel, z_pixel, frame_nr).
|
|
1187
|
+
|
|
1188
|
+
Returns
|
|
1189
|
+
-------
|
|
1190
|
+
coords : Tensor (..., K, M, 3)
|
|
1191
|
+
valid : Tensor (..., K, M)
|
|
1192
|
+
"""
|
|
1193
|
+
if space == "angular":
|
|
1194
|
+
coords = torch.stack(
|
|
1195
|
+
[spots.two_theta, spots.eta, spots.omega], dim=-1
|
|
1196
|
+
)
|
|
1197
|
+
elif space == "detector":
|
|
1198
|
+
coords = torch.stack(
|
|
1199
|
+
[spots.y_pixel, spots.z_pixel, spots.frame_nr], dim=-1
|
|
1200
|
+
)
|
|
1201
|
+
else:
|
|
1202
|
+
raise ValueError(f"Unknown space: {space!r}. Use 'angular' or 'detector'.")
|
|
1203
|
+
return coords, spots.valid
|
|
1204
|
+
|
|
1205
|
+
# ------------------------------------------------------------------
|
|
1206
|
+
# Triangular voxel support (NF-HEDM)
|
|
1207
|
+
# ------------------------------------------------------------------
|
|
1208
|
+
|
|
1209
|
+
@staticmethod
|
|
1210
|
+
def tri_vertices(
|
|
1211
|
+
centers: torch.Tensor,
|
|
1212
|
+
edge_lengths: torch.Tensor,
|
|
1213
|
+
ud: torch.Tensor,
|
|
1214
|
+
) -> torch.Tensor:
|
|
1215
|
+
"""Compute 3 triangle vertices from voxel centres.
|
|
1216
|
+
|
|
1217
|
+
Matches ``simulateNF.c`` lines 556-572.
|
|
1218
|
+
|
|
1219
|
+
Parameters
|
|
1220
|
+
----------
|
|
1221
|
+
centers : Tensor (N, 2) or (N, 3)
|
|
1222
|
+
Voxel centres [x, y] or [x, y, z] in micrometers.
|
|
1223
|
+
edge_lengths : Tensor (N,)
|
|
1224
|
+
Edge length per voxel in micrometers.
|
|
1225
|
+
ud : Tensor (N,)
|
|
1226
|
+
Up/down flag (+1 or -1).
|
|
1227
|
+
|
|
1228
|
+
Returns
|
|
1229
|
+
-------
|
|
1230
|
+
Tensor (N, 3, 3)
|
|
1231
|
+
Vertices ``[V0, V1, V2]`` each of shape ``(3,)`` = ``[x, y, z]``.
|
|
1232
|
+
If input centres are 2D, z is set to 0.
|
|
1233
|
+
"""
|
|
1234
|
+
if centers.shape[-1] == 2:
|
|
1235
|
+
centers = F.pad(centers, (0, 1), value=0.0)
|
|
1236
|
+
|
|
1237
|
+
xs = centers[:, 0]
|
|
1238
|
+
ys = centers[:, 1]
|
|
1239
|
+
zs = centers[:, 2]
|
|
1240
|
+
|
|
1241
|
+
gs = edge_lengths / 2.0
|
|
1242
|
+
sqrt3 = math.sqrt(3.0)
|
|
1243
|
+
dy1 = edge_lengths / sqrt3
|
|
1244
|
+
dy2 = -edge_lengths / (2.0 * sqrt3)
|
|
1245
|
+
# flip if ud < 0
|
|
1246
|
+
sign = torch.sign(ud)
|
|
1247
|
+
dy1 = dy1 * sign
|
|
1248
|
+
dy2 = dy2 * sign
|
|
1249
|
+
|
|
1250
|
+
# V0 = (xs, ys+dy1, zs), V1 = (xs-gs, ys+dy2, zs), V2 = (xs+gs, ys+dy2, zs)
|
|
1251
|
+
V0 = torch.stack([xs, ys + dy1, zs], dim=-1)
|
|
1252
|
+
V1 = torch.stack([xs - gs, ys + dy2, zs], dim=-1)
|
|
1253
|
+
V2 = torch.stack([xs + gs, ys + dy2, zs], dim=-1)
|
|
1254
|
+
|
|
1255
|
+
return torch.stack([V0, V1, V2], dim=1) # (N, 3, 3)
|
|
1256
|
+
|
|
1257
|
+
@staticmethod
|
|
1258
|
+
def _c_round(x: float) -> int:
|
|
1259
|
+
"""C-compatible round(): half away from zero.
|
|
1260
|
+
|
|
1261
|
+
Python's ``round()`` uses banker's rounding (half to even),
|
|
1262
|
+
which differs from C's ``round()`` at ``.5`` boundaries.
|
|
1263
|
+
C's ``(int)round(x)`` = ``floor(x + 0.5)`` for x >= 0,
|
|
1264
|
+
``ceil(x - 0.5)`` for x < 0.
|
|
1265
|
+
"""
|
|
1266
|
+
return int(math.floor(x + 0.5)) if x >= 0 else int(math.ceil(x - 0.5))
|
|
1267
|
+
|
|
1268
|
+
@staticmethod
|
|
1269
|
+
def rasterize_triangle(v0y, v0z, v1y, v1z, v2y, v2z):
|
|
1270
|
+
"""Rasterize a single triangle on an integer pixel grid.
|
|
1271
|
+
|
|
1272
|
+
Matches ``CalcPixels2`` in ``SharedFuncsFit.c`` lines 308-370.
|
|
1273
|
+
Uses the edge-function rasterizer with a distance-based border
|
|
1274
|
+
(``distSq < 0.9801``) for edges, exactly as the C code does.
|
|
1275
|
+
|
|
1276
|
+
Parameters
|
|
1277
|
+
----------
|
|
1278
|
+
v0y, v0z, v1y, v1z, v2y, v2z : int
|
|
1279
|
+
Rounded integer pixel coordinates of the 3 vertices.
|
|
1280
|
+
|
|
1281
|
+
Returns
|
|
1282
|
+
-------
|
|
1283
|
+
list of (int, int)
|
|
1284
|
+
List of (y, z) integer pixel coordinates inside the triangle.
|
|
1285
|
+
"""
|
|
1286
|
+
min_y = min(v0y, v1y, v2y)
|
|
1287
|
+
max_y = max(v0y, v1y, v2y)
|
|
1288
|
+
min_z = min(v0z, v1z, v2z)
|
|
1289
|
+
max_z = max(v0z, v1z, v2z)
|
|
1290
|
+
|
|
1291
|
+
# Edge function coefficients (matching C variable names)
|
|
1292
|
+
A01 = v0z - v1z; B01 = v1y - v0y
|
|
1293
|
+
A12 = v1z - v2z; B12 = v2y - v1y
|
|
1294
|
+
A20 = v2z - v0z; B20 = v0y - v2y
|
|
1295
|
+
|
|
1296
|
+
def orient2d(ax, ay, bx, by, cx, cy):
|
|
1297
|
+
return (bx - ax) * (cy - ay) - (by - ay) * (cx - ax)
|
|
1298
|
+
|
|
1299
|
+
def dist_sq_to_edge(ax, ay, bx, by, px, py):
|
|
1300
|
+
num = (bx - ax) * (py - ay) - (by - ay) * (px - ax)
|
|
1301
|
+
den_sq = (ay - by) ** 2 + (bx - ax) ** 2
|
|
1302
|
+
if den_sq == 0:
|
|
1303
|
+
return 1e30
|
|
1304
|
+
return (num * num) / den_sq
|
|
1305
|
+
|
|
1306
|
+
pixels = []
|
|
1307
|
+
w0_row = orient2d(v1y, v1z, v2y, v2z, min_y, min_z)
|
|
1308
|
+
w1_row = orient2d(v2y, v2z, v0y, v0z, min_y, min_z)
|
|
1309
|
+
w2_row = orient2d(v0y, v0z, v1y, v1z, min_y, min_z)
|
|
1310
|
+
|
|
1311
|
+
for pz in range(min_z, max_z + 1):
|
|
1312
|
+
w0 = w0_row
|
|
1313
|
+
w1 = w1_row
|
|
1314
|
+
w2 = w2_row
|
|
1315
|
+
for py in range(min_y, max_y + 1):
|
|
1316
|
+
inside = (w0 >= 0 and w1 >= 0 and w2 >= 0)
|
|
1317
|
+
if not inside:
|
|
1318
|
+
# Check distance to each edge (0.9801 = 0.99^2)
|
|
1319
|
+
inside = (
|
|
1320
|
+
dist_sq_to_edge(v1y, v1z, v2y, v2z, py, pz) < 0.9801 or
|
|
1321
|
+
dist_sq_to_edge(v2y, v2z, v0y, v0z, py, pz) < 0.9801 or
|
|
1322
|
+
dist_sq_to_edge(v0y, v0z, v1y, v1z, py, pz) < 0.9801
|
|
1323
|
+
)
|
|
1324
|
+
if inside:
|
|
1325
|
+
pixels.append((py, pz))
|
|
1326
|
+
w0 += A12
|
|
1327
|
+
w1 += A20
|
|
1328
|
+
w2 += A01
|
|
1329
|
+
w0_row += B12
|
|
1330
|
+
w1_row += B20
|
|
1331
|
+
w2_row += B01
|
|
1332
|
+
|
|
1333
|
+
return pixels
|
|
1334
|
+
|
|
1335
|
+
def forward_nf_triangles(
    self,
    euler_angles: torch.Tensor,
    centers: torch.Tensor,
    tri_config: "TriVoxelConfig",
    lattice_params: Optional[torch.Tensor] = None,
    strain: Optional[torch.Tensor] = None,
) -> "list[tuple[int, int, int, int, int, float]]":
    """NF-HEDM forward simulation with triangular voxel rasterization.

    Matches the full ``simulateNF`` / ``CalcFracOverlap`` pipeline:
    compute Bragg geometry once per voxel, project the 3 triangle
    vertices, rasterize the detector-space triangle, and keep a pixel
    only if it lands on the detector at *every* distance.

    Parameters
    ----------
    euler_angles : Tensor (N, 3) in radians
    centers : Tensor (N, 2) or (N, 3) in micrometers
        Two-column input is zero-padded to three columns.
    tri_config : TriVoxelConfig
        Supplies ``edge_lengths`` (indexed per voxel) and ``ud``, which
        are forwarded to ``tri_vertices``.
    lattice_params : Tensor, optional
        When given, the HKLs and thetas are recomputed through
        ``correct_hkls_latc``; otherwise ``self.hkls`` is used.
    strain : Tensor, optional
        Forwarded to ``correct_hkls_latc``. Only allowed together with
        ``lattice_params``.

    Returns
    -------
    list of tuple
        One ``(vox_nr, dist_nr, frame, y_px, z_px, omega_deg)`` entry per
        rasterized pixel and detector distance. Hits are ordered by
        solution row ``k`` (``vox_nr = k % N``), then HKL, then pixel,
        then distance. For bit-level comparison, use
        ``predict_spotsinfo_bits`` instead.

    Raises
    ------
    ValueError
        If ``strain`` is supplied without ``lattice_params``.
    """
    if centers.shape[-1] == 2:
        centers = F.pad(centers, (0, 1), value=0.0)

    N = centers.shape[0]
    vertices = self.tri_vertices(
        centers, tri_config.edge_lengths, tri_config.ud
    )  # (N, 3, 3)

    # 1. Orientation matrices.
    #    C reads radians from the .mic file, converts to degrees, and
    #    Euler2OrientMat converts back via cosd()/sind(). The same
    #    rad -> deg -> rad round-trip is kept here (rather than simplified
    #    away) so the rounding of the angles follows the C code path.
    euler_deg_c = euler_angles * self.RAD2DEG
    euler_rad_c = euler_deg_c * self.DEG2RAD
    orientation_matrices = self.euler2mat(euler_rad_c)

    # 2. Optionally strained HKLs
    hkls_cart = thetas = None
    if lattice_params is not None:
        hkls_cart, thetas = self.correct_hkls_latc(lattice_params, strain=strain)
    elif strain is not None:
        raise ValueError(
            "strain was supplied but lattice_params is None; strain "
            "requires a reference lattice to apply (I + eps)^{-1} @ B0."
        )

    # 3. Bragg geometry (once per voxel, from center orientation).
    #    omega, two_theta, valid: (2N, M). The eta returned here is not
    #    used: it is recomputed below through C's degree-based chain.
    omega, _, two_theta, valid = self.calc_bragg_geometry(
        orientation_matrices, hkls_cart, thetas
    )

    # Recompute G_C for the eta recomputation below
    dtype = omega.dtype
    use_hkls = hkls_cart if hkls_cart is not None else self.hkls.to(dtype)
    G_C = torch.einsum("...nij,mj->...nmi", orientation_matrices, use_hkls)

    # 4. Frame number
    frame_nr = (omega / self.DEG2RAD - self.omega_start) / self.omega_step
    frame_ok = (frame_nr >= 0) & (frame_nr < self.n_frames)
    valid = valid * frame_ok.float()

    # 5. Project each vertex through DisplacementSpots for each distance
    D = self.n_distances
    Lsd_0 = self._Lsd[0].to(dtype)

    # ---------------------------------------------------------------
    # Replicate C's exact CalcOmega→CalcSpotPosition chain to avoid
    # float-precision differences at pixel boundaries.
    #
    # C CalcOmega stores omega in DEGREES: omega_deg = acos(...)*rad2deg
    # C then recomputes eta by RotateAroundZ(G, omega_DEG):
    #   internally: omega_rad2 = omega_deg * deg2rad  (roundtrip!)
    #   gw = Rz(omega_rad2) @ G
    #   eta_deg = CalcEtaAngle(gw[1], gw[2])
    # C CalcSpotPosition: eta_rad = eta_deg * deg2rad  (another roundtrip!)
    #   yl = -sin(eta_rad) * RingRadius
    #   zl =  cos(eta_rad) * RingRadius
    # C RingRadius = Lsd * tan(2 * deg2rad * Theta_deg)
    #
    # We replicate every degree conversion.
    # ---------------------------------------------------------------
    omega_deg = omega * self.RAD2DEG
    omega_rad_c = omega_deg * self.DEG2RAD  # C's deg2rad * omega_deg

    cos_w = torch.cos(omega_rad_c)
    sin_w = torch.sin(omega_rad_c)

    # Recompute eta via C's chain: rotate G by omega_deg, then CalcEtaAngle
    G_C_doubled = torch.cat([G_C, G_C], dim=-3)  # (2N, M, 3)
    # gw = Rz(omega_rad_c) @ G  (only the y and z components are needed)
    gw_y = G_C_doubled[..., 0] * sin_w + G_C_doubled[..., 1] * cos_w
    gw_z = G_C_doubled[..., 2]
    r_yz = torch.sqrt(gw_y * gw_y + gw_z * gw_z).clamp(min=self.epsilon)
    eta_c_rad = torch.acos(
        torch.clamp(gw_z / r_yz, -1.0 + self.epsilon, 1.0 - self.epsilon)
    )
    eta_c_rad = -torch.sign(gw_y) * eta_c_rad  # C: if (y > 0) alpha = -alpha
    # Convert to degrees then back (C's CalcSpotPosition chain)
    eta_c_deg = eta_c_rad * self.RAD2DEG
    eta_c_rad2 = eta_c_deg * self.DEG2RAD

    theta_deg = (two_theta / 2.0) * self.RAD2DEG
    ring_radius = Lsd_0 * torch.tan(2.0 * self.DEG2RAD * theta_deg)

    ythis = -torch.sin(eta_c_rad2) * ring_radius
    zthis = torch.cos(eta_c_rad2) * ring_radius

    # Project center reference. For non-zero tilts, apply the NF
    # ray-plane intersection (SharedFuncsFit.c:947-958).
    # YZSpotsTemp = outxyz/px + bc
    ybc_0 = self._y_BC[0].to(dtype)
    zbc_0 = self._z_BC[0].to(dtype)
    y_center_lab, z_center_lab = self._apply_nf_tilt(ythis, zthis, Lsd_0)
    y_center = y_center_lab / self.px + ybc_0  # (2N, M)
    z_center = z_center_lab / self.px + zbc_0

    # Project each vertex: DisplacementSpots for each of 3 vertices
    # vertices: (N, 3, 3) -> double to (2N, 3, 3)
    verts_doubled = torch.cat([vertices, vertices], dim=0)

    # For each vertex v, compute Displ_Y, Displ_Z:
    #   xa = vx*cos(w) - vy*sin(w), ya = vx*sin(w) + vy*cos(w)
    #   Displ_Y = ya + ythis*(1-xa/Lsd), Displ_Z = (1-xa/Lsd)*zthis
    vert_y_pixel = []
    vert_z_pixel = []
    for vi in range(3):
        vx = verts_doubled[:, vi, 0].unsqueeze(-1)  # (2N, 1)
        vy = verts_doubled[:, vi, 1].unsqueeze(-1)

        xa = vx * cos_w - vy * sin_w  # (2N, M)
        ya = vx * sin_w + vy * cos_w

        t = 1.0 - xa / Lsd_0
        displ_y = ya + ythis * t
        displ_z = t * zthis

        # Apply NF tilt (no-op when tilts are zero)
        displ_y_tilt, displ_z_tilt = self._apply_nf_tilt(displ_y, displ_z, Lsd_0)
        vert_y_pixel.append(displ_y_tilt / self.px + ybc_0)
        vert_z_pixel.append(displ_z_tilt / self.px + zbc_0)

    # 6. Compute relative offsets from center (matching C lines 584-586)
    #    YZSpots[k] = YZSpotsT[k] - YZSpotsTemp
    rel_y = [vyp - y_center for vyp in vert_y_pixel]  # 3 x (2N, M)
    rel_z = [vzp - z_center for vzp in vert_z_pixel]

    # 7. Rasterize and collect hits
    #    This is the per-spot loop (not vectorizable due to variable
    #    triangle sizes).
    K = 2 * N
    M = omega.shape[-1]
    cr = self._c_round

    # Per-distance constants do not depend on the spot or the pixel, so
    # they are gathered once: (Lsd_d / Lsd_0, y beam centre, z beam centre).
    layer_consts = []
    for d in range(D):
        Lsd_d = self._Lsd[d].item()
        layer_consts.append(
            (Lsd_d / Lsd_0, self._y_BC[d].item(), self._z_BC[d].item())
        )

    all_hits = []  # (vox_nr, dist_nr, frame, y_px, z_px, omega_deg)

    for k in range(K):
        vox_nr = k % N
        for m in range(M):
            if valid[k, m] < 0.5:
                continue

            # Relative triangle vertices in pixel coords
            ry = [rel_y[vi][k, m].item() for vi in range(3)]
            rz = [rel_z[vi][k, m].item() for vi in range(3)]

            gs_um = tri_config.edge_lengths[vox_nr].item()
            if gs_um > self.px:
                # Rasterize triangle (CalcPixels2 rounds vertices, line 312)
                pixels = self.rasterize_triangle(
                    cr(ry[0]), cr(rz[0]),
                    cr(ry[1]), cr(rz[1]),
                    cr(ry[2]), cr(rz[2]),
                )
            else:
                # Single center pixel (C line 596-599: (int)round(...))
                cy = cr((ry[0] + ry[1] + ry[2]) / 3.0)
                cz = cr((rz[0] + rz[1] + rz[2]) / 3.0)
                pixels = [(cy, cz)]

            if not pixels:
                continue

            yc = y_center[k, m].item()
            zc = z_center[k, m].item()
            frame = int(frame_nr[k, m].item())  # C uses (int) truncation
            ome_deg = omega[k, m].item() * self.RAD2DEG

            # Projected centre pixel at every distance (matching C lines
            # 605-613). It depends only on (k, m, d), so it is computed
            # once per spot instead of once per rasterized pixel. The
            # arithmetic is kept exactly as written to preserve rounding.
            layer_bases = []
            for scale, ybc_d, zbc_d in layer_consts:
                base_y = int(math.floor(
                    ((yc - ybc_0) * self.px * scale) / self.px + ybc_d
                ))
                base_z = int(math.floor(
                    ((zc - zbc_0) * self.px * scale) / self.px + zbc_d
                ))
                layer_bases.append((base_y, base_z))

            # Absolute pixel = center + offset; a pixel is recorded only
            # if it falls on the detector at ALL distances.
            for py_off, pz_off in pixels:
                layer_hits = []
                for d, (base_y, base_z) in enumerate(layer_bases):
                    my = base_y + py_off
                    mz = base_z + pz_off
                    if (my < 0 or my >= self.n_pixels_y
                            or mz < 0 or mz >= self.n_pixels_z):
                        break
                    layer_hits.append((vox_nr, d, frame, my, mz, ome_deg))
                else:
                    # No distance rejected this pixel.
                    all_hits.extend(layer_hits)

    return all_hits
|