midas-diffract 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,248 @@
1
+ """Single-grain orientation + lattice-parameter recovery via gradient descent.
2
+
3
+ Three-phase L-BFGS schedule (orientation -> lattice -> joint), with
4
+ nearest-neighbour spot association at every step. Designed for the
5
+ companion-paper FF-HEDM single-grain demo, but works for any geometry
6
+ the underlying ``HEDMForwardModel`` supports.
7
+
8
+ Quick start
9
+ -----------
10
+ import midas_diffract as md
11
+ result = md.optimize_single_grain(
12
+ model,
13
+ observed_spots=obs_angular, # (N, 3): (2theta, eta, omega) in rad
14
+ init_euler=init_euler_rad, # (3,) Bunge angles in rad
15
+ init_lattice=init_latc, # (6,) [a, b, c, alpha, beta, gamma]
16
+ position=torch.zeros(3), # grain centroid in lab frame (um)
17
+ loss=md.SpotMatchingLoss(metric="l2"),
18
+ )
19
+ print(result["misori_deg"], result["lattice"])
20
+
21
+ The function does not assume any particular weighting; pass a
22
+ :class:`midas_diffract.SpotMatchingLoss` configured with whatever
23
+ ``weights=`` you need (typically derived from measurement resolution).
24
+ """
25
+ from __future__ import annotations
26
+
27
+ import math
28
+ from typing import Any, Callable, Dict, Optional
29
+
30
+ import torch
31
+
32
+ from .forward import HEDMForwardModel
33
+ from .losses import SpotMatchingLoss
34
+
35
# Degree <-> radian conversion factors: multiply an angle by DEG2RAD to go
# from degrees to radians, and by RAD2DEG to go back.
DEG2RAD = math.radians(1.0)
RAD2DEG = math.degrees(1.0)
37
+
38
+
39
+ def _associate(pred_valid: torch.Tensor, observed: torch.Tensor,
40
+ max_dist: float) -> "tuple[torch.Tensor, torch.Tensor]":
41
+ """Nearest-neighbour observed->predicted association.
42
+
43
+ Returns ``(pred_matched, obs_matched)``. Both have leading dim equal to
44
+ the number of observed spots whose nearest predicted neighbour is
45
+ within ``max_dist``.
46
+ """
47
+ dists = torch.cdist(observed, pred_valid)
48
+ min_dists, nn_idx = dists.min(dim=1)
49
+ keep = min_dists < max_dist
50
+ return pred_valid[nn_idx[keep]], observed[keep]
51
+
52
+
53
def optimize_single_grain(
    model: HEDMForwardModel,
    observed_spots: torch.Tensor,
    init_euler: torch.Tensor,
    init_lattice: torch.Tensor,
    position: Optional[torch.Tensor] = None,
    *,
    loss: Optional[SpotMatchingLoss] = None,
    max_match_distance: float = 0.5,
    min_matches: int = 5,
    phase1_steps: int = 15,
    phase2_steps: int = 15,
    phase3_steps: int = 10,
    lbfgs_max_iter: int = 20,
    convergence_misori_deg: float = 1e-3,
    convergence_lattice_err: float = 1e-5,
    verbose: bool = False,
) -> Dict[str, Any]:
    """Recover Bunge Euler angles and lattice parameters for a single grain.

    Three-phase L-BFGS schedule:
      1. Orientation only (Euler angles) -- eta/omega-sensitive.
      2. Lattice parameters only -- 2theta-sensitive.
      3. Joint refinement.

    Every loss evaluation re-runs the forward model and keeps the predicted
    spots flagged valid. It then pairs each observed spot with its nearest
    predicted spot and applies ``loss`` to the matched pairs.

    Parameters
    ----------
    model : HEDMForwardModel
        Pre-built forward model. Its ``hkls`` and ``thetas`` set the
        reflection list against which the grain is fit.
    observed_spots : Tensor (N, 3)
        Angular coordinates of observed spots in radians:
        ``(2theta, eta, omega)``. Use the ``angular`` output of
        :meth:`HEDMForwardModel.predict_spot_coords` for synthetic data.
    init_euler : Tensor (3,)
        Initial Bunge Euler angles in radians. The tensor is detached before
        use, so it may be a non-leaf tensor; it is never modified.
    init_lattice : Tensor (6,)
        Initial lattice parameters ``[a, b, c, alpha, beta, gamma]``,
        in Angstroms / degrees. The tensor is detached before use.
    position : Tensor (3,), optional
        Grain centroid in the lab frame (microns). Defaults to the origin.
    loss : SpotMatchingLoss, optional
        Loss object. Defaults to ``SpotMatchingLoss(metric="l2")``.
    max_match_distance : float
        Discard observed spots whose nearest predicted neighbour is
        farther than this in the angular metric (radians).
    min_matches : int
        If fewer than this many spots are matched in a loss evaluation,
        return a constant sentinel loss (1e6) that carries no gradient.
        This prevents L-BFGS from diverging through a near-empty match set.
    phase1_steps, phase2_steps, phase3_steps : int
        Outer-loop step counts per phase (0 skips the phase). Each step is
        one L-BFGS call with up to ``lbfgs_max_iter`` inner iterations.
    lbfgs_max_iter : int
        ``max_iter`` passed to every :class:`torch.optim.LBFGS`.
    convergence_misori_deg : float
        Phase-1 early exit. Phase 1 stops once a single L-BFGS call (after
        the first) rotates the grain by less than this many degrees.
    convergence_lattice_err : float
        Phase-2 early exit. Phase 2 stops once the loss changes by less than
        this between consecutive L-BFGS calls, provided a, b or c has moved
        away from ``init_lattice``. Phase 3 has no early exit and always
        runs ``phase3_steps`` steps.
    verbose : bool
        If True, print a per-step progress table.

    Returns
    -------
    dict with keys:
      ``euler_rad`` -- (3,) recovered Euler angles in radians.
      ``euler_deg`` -- (3,) same in degrees.
      ``lattice`` -- (6,) recovered lattice parameters.
      ``misori_deg`` -- rotation angle (deg) between the recovered
          orientation and ``init_euler``. This is not an error against
          ground truth; use :func:`evaluate_recovery` for that.
      ``loss_history`` -- one float per phase that ran. Each value is the
          loss returned by that phase's last ``LBFGS.step`` call; PyTorch
          returns the loss evaluated at the start of that step.
    """
    if loss is None:
        loss = SpotMatchingLoss(metric="l2")
    if position is None:
        position = torch.zeros(3, dtype=init_euler.dtype, device=init_euler.device)

    # Work on detached copies. torch.optim only accepts leaf tensors, and
    # gradients must never flow back into the caller's initial guesses.
    init_euler = init_euler.detach()
    init_lattice = init_lattice.detach()

    pos = position.unsqueeze(0)
    R_init = HEDMForwardModel.euler2mat(init_euler).detach()

    opt_euler = init_euler.clone().requires_grad_(True)
    opt_latc = init_lattice.clone().requires_grad_(False)
    loss_history: list = []

    def sentinel_loss() -> torch.Tensor:
        # Constant loss with no graph behind it. The parameters' gradients
        # are left zeroed (or None, which L-BFGS treats as zero), so no
        # descent direction is derived from a near-empty match set.
        return torch.tensor(
            1e6, dtype=opt_euler.dtype, device=opt_euler.device,
            requires_grad=True,
        )

    def make_closure(params):
        """Build the L-BFGS closure that back-propagates into ``params``."""
        def closure():
            for p in params:
                if p.grad is not None:
                    p.grad.zero_()
            spots = model(opt_euler.unsqueeze(0), pos, lattice_params=opt_latc)
            coords, valid = HEDMForwardModel.predict_spot_coords(
                spots, space="angular"
            )
            # Flatten to rows of 3 coordinates and keep only entries the
            # forward model flags as valid.
            pred_flat = coords.squeeze().reshape(-1, 3)
            valid_flat = valid.squeeze().reshape(-1)
            pred_valid = pred_flat[valid_flat > 0.5]
            if pred_valid.shape[0] == 0:
                return sentinel_loss()
            pred_match, obs_match = _associate(
                pred_valid, observed_spots, max_match_distance
            )
            if pred_match.shape[0] < min_matches:
                return sentinel_loss()
            l = loss(pred_match, obs_match)
            l.backward()
            return l
        return closure

    def angle_between_deg(R_a: torch.Tensor, R_b: torch.Tensor) -> float:
        """Rotation angle (deg) of ``R_a^T @ R_b``, accurate near zero.

        Uses atan2(sin, cos) instead of acos(cos). acos loses nearly all
        precision for angles close to 0, which is exactly the regime of the
        convergence checks (e.g. 1e-3 deg).
        """
        with torch.no_grad():
            D = R_a.T @ R_b
            cos_t = (torch.trace(D) - 1) / 2
            # Axial vector of the skew-symmetric part; its norm is 2*sin(angle).
            skew = torch.stack(
                (D[2, 1] - D[1, 2], D[0, 2] - D[2, 0], D[1, 0] - D[0, 1])
            )
            sin_t = 0.5 * torch.linalg.vector_norm(skew)
            return torch.atan2(sin_t, cos_t).item() * RAD2DEG

    def current_misori_deg() -> float:
        # Rotation of the current orientation relative to the initial guess.
        with torch.no_grad():
            return angle_between_deg(R_init, HEDMForwardModel.euler2mat(opt_euler))

    def lattice_shift() -> float:
        # Largest absolute change of a, b, c relative to the initial guess.
        return (opt_latc.detach()[:3] - init_lattice[:3]).abs().max().item()

    def log(step: int, l: torch.Tensor) -> None:
        if not verbose:
            return
        misori = current_misori_deg()
        lat_err = lattice_shift()
        print(f"{step:5d} {l.item():12.6e} {misori:12.6f} {lat_err:10.6f}")

    if verbose:
        print(f"{'Step':>5} {'Loss':>12} {'dMisori(deg)':>12} {'dLat':>10}")
        print("-" * 55)
        print("--- Phase 1: Orientation ---")

    # ---- Phase 1: orientation only -------------------------------------
    optimizer = torch.optim.LBFGS(
        [opt_euler], lr=1.0, max_iter=lbfgs_max_iter,
        line_search_fn="strong_wolfe",
    )
    l = None
    for step in range(phase1_steps):
        with torch.no_grad():
            R_prev = HEDMForwardModel.euler2mat(opt_euler)
        l = optimizer.step(make_closure([opt_euler]))
        log(step, l)
        # Converged once one L-BFGS call barely rotates the grain any more
        # (step-to-step change, not total displacement from the start).
        with torch.no_grad():
            step_misori = angle_between_deg(
                R_prev, HEDMForwardModel.euler2mat(opt_euler)
            )
        if step > 0 and step_misori < convergence_misori_deg:
            break
    if l is not None:
        loss_history.append(float(l.detach()))

    # ---- Phase 2: lattice parameters only -------------------------------
    if verbose:
        print("--- Phase 2: Lattice parameters ---")
    opt_euler.requires_grad_(False)
    opt_latc.requires_grad_(True)
    optimizer = torch.optim.LBFGS(
        [opt_latc], lr=1.0, max_iter=lbfgs_max_iter,
        line_search_fn="strong_wolfe",
    )
    l = None
    prev_loss: Optional[float] = None
    for step in range(phase2_steps):
        l = optimizer.step(make_closure([opt_latc]))
        log(step + phase1_steps, l)
        cur_loss = float(l.detach())
        # Converged once the loss stops changing between consecutive phase-2
        # L-BFGS calls, after the lattice has actually moved.
        if (
            prev_loss is not None
            and lattice_shift() > 0
            and abs(cur_loss - prev_loss) < convergence_lattice_err
        ):
            break
        prev_loss = cur_loss
    if l is not None:
        loss_history.append(float(l.detach()))

    # ---- Phase 3: joint refinement (no early exit) ----------------------
    if verbose:
        print("--- Phase 3: Joint refinement ---")
    opt_euler.requires_grad_(True)
    opt_latc.requires_grad_(True)
    optimizer = torch.optim.LBFGS(
        [opt_euler, opt_latc], lr=0.5, max_iter=lbfgs_max_iter,
        line_search_fn="strong_wolfe",
    )
    l = None
    for step in range(phase3_steps):
        l = optimizer.step(make_closure([opt_euler, opt_latc]))
        log(step + phase1_steps + phase2_steps, l)
    if l is not None:
        loss_history.append(float(l.detach()))

    return {
        "euler_rad": opt_euler.detach().clone(),
        "euler_deg": opt_euler.detach().clone() * RAD2DEG,
        "lattice": opt_latc.detach().clone(),
        "misori_deg": current_misori_deg(),
        "loss_history": loss_history,
    }
227
+
228
+
229
def evaluate_recovery(
    result: Dict[str, Any],
    gt_euler: torch.Tensor,
    gt_lattice: torch.Tensor,
) -> Dict[str, Any]:
    """Evaluate a recovery against ground truth.

    Useful in unit tests and the demo notebooks.

    Parameters
    ----------
    result : dict
        Output of :func:`optimize_single_grain`. Only the ``"euler_rad"``
        and ``"lattice"`` entries are read.
    gt_euler : Tensor
        Ground-truth Euler angles in radians, in the form accepted by
        ``HEDMForwardModel.euler2mat``.
    gt_lattice : Tensor
        Ground-truth lattice parameters, same layout as ``result["lattice"]``.

    Returns
    -------
    dict with keys:
      ``misori_deg`` -- rotation angle (deg) of ``R_gt^T @ R_rec``. No
          crystal-symmetry reduction is applied, so symmetry-equivalent
          orientations are not treated as identical.
      ``lattice_max_err`` -- largest element-wise absolute lattice error.
      ``lattice_errors`` -- element-wise absolute lattice errors
          (detached tensor copy).
    """
    R_gt = HEDMForwardModel.euler2mat(gt_euler)
    R_rec = HEDMForwardModel.euler2mat(result["euler_rad"])
    D = R_gt.T @ R_rec
    # Rotation angle via atan2(sin, cos) rather than acos(cos). acos is
    # ill-conditioned near 0, so the small misorientations of a good
    # recovery would be rounded to 0 or badly overestimated, especially in
    # float32.
    cos_t = (torch.trace(D) - 1) / 2
    # Axial vector of the skew-symmetric part; its norm is 2*sin(angle).
    skew = torch.stack((D[2, 1] - D[1, 2], D[0, 2] - D[2, 0], D[1, 0] - D[0, 1]))
    sin_t = 0.5 * torch.linalg.vector_norm(skew)
    misori = torch.atan2(sin_t, cos_t).item() * RAD2DEG
    lat_err = (result["lattice"] - gt_lattice).abs()
    return {
        "misori_deg": misori,
        "lattice_max_err": lat_err.max().item(),
        "lattice_errors": lat_err.detach().clone(),
    }
@@ -0,0 +1,122 @@
1
+ Metadata-Version: 2.4
2
+ Name: midas-diffract
3
+ Version: 0.1.0
4
+ Summary: End-to-end differentiable forward model for High-Energy Diffraction Microscopy (FF, NF, pf-HEDM)
5
+ Author: Simon Zhang, Nina Andrejevic, Mathew Cherukara
6
+ Author-email: Hemant Sharma <hsharma@anl.gov>
7
+ License-Expression: BSD-3-Clause
8
+ Project-URL: Homepage, https://github.com/marinerhemant/MIDAS
9
+ Project-URL: Documentation, https://github.com/marinerhemant/MIDAS
10
+ Project-URL: Issues, https://github.com/marinerhemant/MIDAS/issues
11
+ Keywords: HEDM,3DXRD,differentiable physics,PyTorch,forward model,crystallography,polycrystal,grain
12
+ Classifier: Development Status :: 4 - Beta
13
+ Classifier: Intended Audience :: Science/Research
14
+ Classifier: Programming Language :: Python :: 3
15
+ Classifier: Topic :: Scientific/Engineering :: Physics
16
+ Requires-Python: >=3.9
17
+ Description-Content-Type: text/markdown
18
+ License-File: LICENSE
19
+ Requires-Dist: numpy>=1.22
20
+ Requires-Dist: torch>=2.0
21
+ Provides-Extra: dev
22
+ Requires-Dist: pytest>=7.0; extra == "dev"
23
+ Requires-Dist: pytest-cov; extra == "dev"
24
+ Provides-Extra: hkls
25
+ Requires-Dist: midas-hkls>=0.1.0; extra == "hkls"
26
+ Dynamic: license-file
27
+
28
+ # midas-diffract
29
+
30
+ End-to-end differentiable forward model for High-Energy Diffraction Microscopy (HEDM), covering far-field (FF), near-field (NF), and point-focused (pf-HEDM) geometries. Pixel-exact agreement with the canonical C reference simulators in [MIDAS](https://github.com/marinerhemant/MIDAS).
31
+
32
+ Companion paper: Sharma, Zhang, Andrejevic & Cherukara, *An End-to-End Differentiable Forward Model for High-Energy Diffraction Microscopy*, IUCrJ (in preparation, 2026).
33
+
34
+ ## Installation
35
+
36
+ ```bash
37
+ pip install midas-diffract # core forward model + losses + optimizer
38
+ pip install "midas-diffract[hkls]" # also installs midas-hkls for the (quotes stop zsh from globbing [hkls])
39
+ # pure-Python reflection-list helper
40
+ ```
41
+
42
+ Optional PyTorch CUDA or MPS back-ends are used automatically if available.
43
+
44
+ ## Quick start
45
+
46
+ ```python
47
+ import torch
48
+ import midas_diffract as md
49
+ from midas_hkls import Lattice, SpaceGroup # optional, see [hkls]
50
+
51
+ # Detector + scan geometry
52
+ geom = md.HEDMGeometry(
53
+ Lsd=1_000_000.0, # um
54
+ y_BC=1024.0, z_BC=1024.0,
55
+ px=200.0,
56
+ omega_start=0.0, omega_step=0.25, n_frames=1440,
57
+ n_pixels_y=2048, n_pixels_z=2048,
58
+ min_eta=6.0,
59
+ wavelength=0.172979, # Angstroms
60
+ )
61
+
62
+ # Reflection list: either compute from a SpaceGroup + Lattice via the
63
+ # midas-hkls helper, or supply (hkls_cart, thetas, hkls_int) yourself
64
+ # (e.g. parsed from MIDAS GetHKLList output).
65
+ sg = SpaceGroup.from_number(225) # FCC
66
+ lat = Lattice.for_system("cubic", a=4.08) # Au
67
+ hkls_cart, thetas, hkls_int = md.hkls_for_forward_model(
68
+ sg, lat, wavelength_A=geom.wavelength, two_theta_max_deg=15.0,
69
+ )
70
+
71
+ model = md.HEDMForwardModel(
72
+ hkls=hkls_cart, thetas=thetas, geometry=geom, hkls_int=hkls_int,
73
+ )
74
+
75
+ # Forward pass: grain state -> predicted spots. All inputs are leaves
76
+ # of the autograd graph.
77
+ euler = torch.deg2rad(torch.tensor([[45.0, 30.0, 60.0]])).requires_grad_()
78
+ pos = torch.tensor([[0.0, 0.0, 0.0]], requires_grad=True)
79
+ latc = torch.tensor([4.08, 4.08, 4.08, 90.0, 90.0, 90.0], requires_grad=True)
80
+ spots = model(euler, pos, lattice_params=latc)
81
+
82
+ # Scalar loss -> gradients w.r.t. orientation, position, lattice
83
+ loss = ((spots.omega * spots.valid) ** 2).sum()
84
+ loss.backward()
85
+ ```
86
+
87
+ ## Output modes
88
+
89
+ - `md.HEDMForwardModel.predict_spot_coords(spots, space="angular")` — returns
90
+ `(2θ, η, ω)` in radians for each valid reflection (FF and pf-HEDM).
91
+ - `md.HEDMForwardModel.predict_spot_coords(spots, space="detector")` — returns
92
+ `(y_pixel, z_pixel, frame_nr)` in fractional units (FF and pf-HEDM).
93
+ - `md.HEDMForwardModel.predict_images(spots, ...)` — renders a differentiable
94
+ 3D detector volume via Gaussian splatting (NF-HEDM output mode).
95
+
96
+ ## Validation
97
+
98
+ The forward model has been validated to pixel-exact agreement against the
99
+ canonical C simulators `ForwardSimulationCompressed` and `simulateNF` in the
100
+ MIDAS distribution. See the companion paper and the MIDAS repository
101
+ `fwd_sim/paper/` directory for reproducibility scripts.
102
+
103
+ ## Scope
104
+
105
+ `midas-diffract` v0.1.0 is deliberately focused on the forward model and its
106
+ gradient chain. The following capabilities build on this substrate and are
107
+ released separately as they mature:
108
+
109
+ - Sub-voxel grain mixtures
110
+ - Physics-informed regularisation
111
+ - Bayesian uncertainty quantification via HMC / variational inference
112
+ - Temporal 4D-HEDM tracking
113
+ - Coupling to differentiable crystal plasticity (JAX-FEM)
114
+ - EM spot ownership for ambiguous FF patterns
115
+
116
+ ## Citation
117
+
118
+ If you use `midas-diffract` in published work, please cite the companion paper.
119
+
120
+ ## Licence
121
+
122
+ BSD-3-Clause.
@@ -0,0 +1,10 @@
1
+ midas_diffract/__init__.py,sha256=YzApb7RC00sN8nQyWGnABXWR6tdsnV0EXTDfwg4acCA,1709
2
+ midas_diffract/forward.py,sha256=S0VXDsZmUVuhxOA1WfihH7xFhZTU_aggRgoASJNkNhk,63110
3
+ midas_diffract/hkls.py,sha256=yihZkshLhTEaSF11hjN5m1EzbZAZo5wHv9DX6nWFERE,6496
4
+ midas_diffract/losses.py,sha256=QTpSq-I6N_BBc1i0ocom1JGe4HPcA5fbTt2vJegLHJs,16109
5
+ midas_diffract/optimize.py,sha256=kU2lh5ekM0honOIxg1n3-GwKckcoStovU97epHzKKUQ,9406
6
+ midas_diffract-0.1.0.dist-info/licenses/LICENSE,sha256=aWLB2q1q7h9xTi8OBqYGyjcAS_ik-mDFbRx-yoVZxeI,1603
7
+ midas_diffract-0.1.0.dist-info/METADATA,sha256=SxdtyEnavHcNq9JTha8_SURsqNxUg-mabPSelElgsCs,4751
8
+ midas_diffract-0.1.0.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
9
+ midas_diffract-0.1.0.dist-info/top_level.txt,sha256=M7QaYqJ9qjhJINmex9O9LMEAhW85zTvR3kpdwG_1oSE,15
10
+ midas_diffract-0.1.0.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (82.0.1)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
@@ -0,0 +1,31 @@
1
+ BSD 3-Clause License
2
+
3
+ Copyright (c) 2026, UChicago Argonne, LLC, operator of Argonne National
4
+ Laboratory, and the midas-diffract authors.
5
+ All rights reserved.
6
+
7
+ Redistribution and use in source and binary forms, with or without
8
+ modification, are permitted provided that the following conditions are met:
9
+
10
+ 1. Redistributions of source code must retain the above copyright notice,
11
+ this list of conditions and the following disclaimer.
12
+
13
+ 2. Redistributions in binary form must reproduce the above copyright notice,
14
+ this list of conditions and the following disclaimer in the documentation
15
+ and/or other materials provided with the distribution.
16
+
17
+ 3. Neither the name of the copyright holder nor the names of its
18
+ contributors may be used to endorse or promote products derived from
19
+ this software without specific prior written permission.
20
+
21
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
22
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
24
+ ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
25
+ LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
26
+ CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
27
+ SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
28
+ INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
29
+ CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
30
+ ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
31
+ POSSIBILITY OF SUCH DAMAGE.
@@ -0,0 +1 @@
1
+ midas_diffract