midas-diffract 0.4.0__tar.gz → 0.6.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (25)
  1. {midas_diffract-0.4.0 → midas_diffract-0.6.0}/PKG-INFO +1 -1
  2. {midas_diffract-0.4.0 → midas_diffract-0.6.0}/midas_diffract/__init__.py +1 -1
  3. {midas_diffract-0.4.0 → midas_diffract-0.6.0}/midas_diffract/forward.py +133 -5
  4. {midas_diffract-0.4.0 → midas_diffract-0.6.0}/midas_diffract/simulate_panel_zarrs.py +12 -16
  5. {midas_diffract-0.4.0 → midas_diffract-0.6.0}/midas_diffract.egg-info/PKG-INFO +1 -1
  6. {midas_diffract-0.4.0 → midas_diffract-0.6.0}/pyproject.toml +1 -1
  7. {midas_diffract-0.4.0 → midas_diffract-0.6.0}/tests/test_forward.py +84 -0
  8. {midas_diffract-0.4.0 → midas_diffract-0.6.0}/LICENSE +0 -0
  9. {midas_diffract-0.4.0 → midas_diffract-0.6.0}/README.md +0 -0
  10. {midas_diffract-0.4.0 → midas_diffract-0.6.0}/midas_diffract/hkls.py +0 -0
  11. {midas_diffract-0.4.0 → midas_diffract-0.6.0}/midas_diffract/losses.py +0 -0
  12. {midas_diffract-0.4.0 → midas_diffract-0.6.0}/midas_diffract/optimize.py +0 -0
  13. {midas_diffract-0.4.0 → midas_diffract-0.6.0}/midas_diffract.egg-info/SOURCES.txt +0 -0
  14. {midas_diffract-0.4.0 → midas_diffract-0.6.0}/midas_diffract.egg-info/dependency_links.txt +0 -0
  15. {midas_diffract-0.4.0 → midas_diffract-0.6.0}/midas_diffract.egg-info/requires.txt +0 -0
  16. {midas_diffract-0.4.0 → midas_diffract-0.6.0}/midas_diffract.egg-info/top_level.txt +0 -0
  17. {midas_diffract-0.4.0 → midas_diffract-0.6.0}/setup.cfg +0 -0
  18. {midas_diffract-0.4.0 → midas_diffract-0.6.0}/tests/test_c_comparison.py +0 -0
  19. {midas_diffract-0.4.0 → midas_diffract-0.6.0}/tests/test_distortion_layer.py +0 -0
  20. {midas_diffract-0.4.0 → midas_diffract-0.6.0}/tests/test_hkls.py +0 -0
  21. {midas_diffract-0.4.0 → midas_diffract-0.6.0}/tests/test_losses.py +0 -0
  22. {midas_diffract-0.4.0 → midas_diffract-0.6.0}/tests/test_multi_detector.py +0 -0
  23. {midas_diffract-0.4.0 → midas_diffract-0.6.0}/tests/test_strain_tensor.py +0 -0
  24. {midas_diffract-0.4.0 → midas_diffract-0.6.0}/tests/test_tilts.py +0 -0
  25. {midas_diffract-0.4.0 → midas_diffract-0.6.0}/tests/test_wedge.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: midas-diffract
3
- Version: 0.4.0
3
+ Version: 0.6.0
4
4
  Summary: End-to-end differentiable forward model for High-Energy Diffraction Microscopy (FF, NF, pf-HEDM)
5
5
  Author-email: Hemant Sharma <hsharma@anl.gov>
6
6
  License-Expression: BSD-3-Clause
@@ -28,7 +28,7 @@ Quick start
28
28
  loss.backward()
29
29
  """
30
30
 
31
- __version__ = "0.4.0"
31
+ __version__ = "0.6.0"
32
32
 
33
33
  from .forward import (
34
34
  HEDMForwardModel,
@@ -27,6 +27,7 @@ Reference C code:
27
27
  """
28
28
 
29
29
  import math
30
+ import warnings
30
31
  from dataclasses import dataclass, field
31
32
  from typing import Optional, Tuple
32
33
 
@@ -738,6 +739,7 @@ class HEDMForwardModel(nn.Module):
738
739
  orientation_matrices: torch.Tensor,
739
740
  hkls_cart: Optional[torch.Tensor] = None,
740
741
  thetas: Optional[torch.Tensor] = None,
742
+ per_grain: bool = False,
741
743
  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
742
744
  """Core Bragg geometry: orientations + G-vectors -> angles.
743
745
 
@@ -779,14 +781,32 @@ class HEDMForwardModel(nn.Module):
779
781
  # batch, (b) per-voxel hkls_cart shape (..., M, 3) for strained
780
782
  # rendering. Both flow through the same einsum via leading-dim
781
783
  # broadcasting on the second arg.
782
- G_C = torch.einsum("...nij,...mj->...nmi", orientation_matrices, hkls_cart)
784
+ #
785
+ # per_grain=True: ELEMENT-WISE pairing of grain i's orientation with
786
+ # grain i's strained hkls -- NO orientation x strain cross-product.
787
+ # Requires orientation_matrices (N,3,3); hkls_cart (N,M,3) per-grain or
788
+ # (M,3) shared. Used by forward_per_grain() for the O(N*M) fast path.
789
+ if per_grain:
790
+ if hkls_cart.dim() == 2: # (M,3) shared lattice
791
+ G_C = torch.einsum("nij,mj->nmi", orientation_matrices, hkls_cart)
792
+ else: # (N,M,3) per-grain strain
793
+ G_C = torch.einsum("nij,nmj->nmi", orientation_matrices, hkls_cart)
794
+ else:
795
+ G_C = torch.einsum("...nij,...mj->...nmi", orientation_matrices, hkls_cart)
783
796
 
784
797
  # v = sin(theta)*|G| -- C precomputes Gs from the UNROTATED G-vector norm
785
798
  # (rotation preserves norm in exact arithmetic but not in float64).
786
799
  # Match C: use |hkls_cart| (pre-rotation), not |R @ hkls_cart|.
787
800
  len_hkl = torch.norm(hkls_cart, dim=-1) # (M,) or (..., M)
788
801
  v_no_wedge = torch.sin(thetas) * len_hkl # (M,) or (..., M)
789
- v_no_wedge = v_no_wedge.unsqueeze(-2).expand_as(G_C[..., 0]) # (..., N, M)
802
+ # Broadcast v to G_C's (..., N, M) grid. Layout-agnostic: in the
803
+ # cross-product layout v gains an orientation axis at -2; in the
804
+ # per_grain / shared layouts it already matches. Numerically identical
805
+ # to the former ``unsqueeze(-2).expand_as`` for those existing paths.
806
+ gc0 = G_C[..., 0]
807
+ while v_no_wedge.dim() < gc0.dim():
808
+ v_no_wedge = v_no_wedge.unsqueeze(-2)
809
+ v_no_wedge = v_no_wedge.expand_as(gc0)
790
810
 
791
811
  # ---- Wedge: rigorous geometric formulation -------------------
792
812
  # The rotation axis tilts from z to n_hat = (sin W, 0, cos W).
@@ -900,9 +920,16 @@ class HEDMForwardModel(nn.Module):
900
920
  eta = self.safe_arccos(Gz_lab / r_yz)
901
921
  eta = -torch.sign(Gy_lab) * eta
902
922
 
903
- # 2*theta (broadcast thetas to match 2N dimension)
904
- two_theta_single = 2.0 * thetas.unsqueeze(-2) # (..., 1, M) or (1, M)
905
- two_theta = two_theta_single.expand_as(all_omega)
923
+ # 2*theta -- broadcast to the single-branch (pre-cat) grid, then double
924
+ # along the grain axis to match all_omega's 2N. Layout-agnostic (handles
925
+ # the per_grain layout where the 2N axis is grain-doubled) and
926
+ # numerically identical to the former expand for the cross-product path
927
+ # (thetas does not depend on the orientation axis).
928
+ tt = 2.0 * thetas
929
+ while tt.dim() < omega_p.dim():
930
+ tt = tt.unsqueeze(-2)
931
+ tt = tt.expand_as(omega_p)
932
+ two_theta = torch.cat([tt, tt], dim=-2)
906
933
 
907
934
  # Validity mask
908
935
  valid_p = disc_valid & coswp_valid
@@ -1354,6 +1381,11 @@ class HEDMForwardModel(nn.Module):
1354
1381
  if positions.shape[-1] == 2:
1355
1382
  positions = F.pad(positions, (0, 1), value=0.0)
1356
1383
 
1384
+ # Footgun guard: per-grain lattice/strain + N>1 orientations forms an
1385
+ # N x N orientation x strain cross-product (output (N, 2N, M)); callers
1386
+ # simulating a fixed polycrystal almost always want only the diagonal.
1387
+ self._warn_if_cross_product(euler_angles, lattice_params, strain)
1388
+
1357
1389
  # 1. Orientation matrices
1358
1390
  orientation_matrices = self.euler2mat(euler_angles)
1359
1391
 
@@ -1382,6 +1414,102 @@ class HEDMForwardModel(nn.Module):
1382
1414
 
1383
1415
  return spots
1384
1416
 
1417
+ @staticmethod
1418
+ def _warn_if_cross_product(euler_angles, lattice_params, strain):
1419
+ """Warn when forward() would form an orientation x strain cross-product.
1420
+
1421
+ Fires only when there are N>1 orientations AND lattice_params/strain
1422
+ carry a matching per-grain axis -- the case where the (N, 2N, M) output
1423
+ is an N x N cross-product and the caller likely wanted the diagonal.
1424
+ Shared lattice/strain (no grain axis) is the correct (2N, M) path and
1425
+ does not warn.
1426
+ """
1427
+ n_orient = euler_angles.shape[-2] if euler_angles.dim() >= 2 else 1
1428
+ if n_orient <= 1:
1429
+ return
1430
+
1431
+ def _has_grain_axis(t):
1432
+ if t is None:
1433
+ return False
1434
+ if t.shape[-1] == 6 and t.dim() >= 2: # Voigt lattice or strain
1435
+ return t.shape[-2] == n_orient
1436
+ if tuple(t.shape[-2:]) == (3, 3) and t.dim() >= 3: # full-tensor strain
1437
+ return t.shape[-3] == n_orient
1438
+ return False
1439
+
1440
+ if _has_grain_axis(lattice_params) or _has_grain_axis(strain):
1441
+ warnings.warn(
1442
+ f"forward() called with {n_orient} orientations and per-grain "
1443
+ "lattice_params/strain forms an orientation x strain "
1444
+ f"cross-product (output shape ({n_orient}, {2 * n_orient}, M)); "
1445
+ "only the diagonal [i, i] and [i, i+N] is physical. For a fixed "
1446
+ "polycrystal use forward_per_grain() (element-wise, O(N*M)), "
1447
+ "or index the diagonal of this output.",
1448
+ stacklevel=3,
1449
+ )
1450
+
1451
def forward_per_grain(
    self,
    euler_angles: torch.Tensor,
    positions: torch.Tensor,
    lattice_params: Optional[torch.Tensor] = None,
    strain: Optional[torch.Tensor] = None,
) -> SpotDescriptors:
    """Element-wise per-grain forward simulation -- the fast path.

    Grain ``i`` is simulated with orientation ``i``, lattice/strain ``i``
    and position ``i``, WITHOUT the orientation x strain cross-product that
    :meth:`forward` forms when both are per-grain. The output has leading
    shape ``(2N, M)`` (the two omega branches doubled along the grain axis),
    which is the diagonal of :meth:`forward`'s ``(N, 2N, M)`` output.

    The element-wise einsum differs from the cross-product einsum only in
    floating-point reduction order, so values agree with that diagonal to
    round-off (~1e-12 in float64) rather than bit-for-bit.

    Cost is O(N*M) rather than O(N^2 * M), matching the algorithm of the C
    reference ``ForwardSimulationCompressed.c``. Fully differentiable and
    device-portable; for pure forward simulation wrap the call in
    ``torch.inference_mode()``.

    Parameters
    ----------
    euler_angles : Tensor (N, 3)
        Bunge ZXZ Euler angles (radians), one per grain. No leading batch.
    positions : Tensor (N, 3) or (N, 2)
        Real-space grain positions (micrometers). A 2-column input is
        zero-padded to 3 columns.
    lattice_params : Tensor (6,) or (N, 6), optional
        Shared or per-grain lattice [a,b,c,alpha,beta,gamma] (Ang/deg).
    strain : Tensor (6,), (N, 6), (3, 3), or (N, 3, 3), optional
        Shared or per-grain crystal-frame strain (plain-Voigt or full 3x3).
        Requires ``lattice_params``.

    Returns
    -------
    SpotDescriptors with leading shape ``(2N, M)``: grain ``i``'s two omega
    solutions live at axis-(-2) indices ``i`` and ``i + N``.

    Raises
    ------
    ValueError
        If ``euler_angles`` is not of shape ``(N, 3)``, or if ``strain`` is
        given without ``lattice_params``.
    """
    # The per-grain einsum in calc_bragg_geometry pairs axis 0 of the
    # orientation stack with the grain axis, so a leading batch (or a single
    # unbatched (3,) triple) would fail there with an opaque einsum error.
    if euler_angles.dim() != 2 or euler_angles.shape[-1] != 3:
        raise ValueError(
            "forward_per_grain expects euler_angles of shape (N, 3) with no "
            f"leading batch dimension; got {tuple(euler_angles.shape)}."
        )
    if lattice_params is None and strain is not None:
        raise ValueError(
            "strain was supplied but lattice_params is None; strain "
            "requires a reference lattice to apply (I + eps)^{-1} @ B0."
        )

    if positions.shape[-1] == 2:
        positions = F.pad(positions, (0, 1), value=0.0)

    orientation_matrices = self.euler2mat(euler_angles)  # (N, 3, 3)

    # None -> calc_bragg_geometry falls back to the model's nominal hkls.
    hkls_cart = None
    thetas = None
    if lattice_params is not None:
        hkls_cart, thetas = self.correct_hkls_latc(lattice_params, strain=strain)

    omega, eta, two_theta, valid = self.calc_bragg_geometry(
        orientation_matrices, hkls_cart, thetas, per_grain=True
    )
    spots = self.project_to_detector(omega, eta, two_theta, positions, valid)

    if self.scan_config is not None:
        spots = self.filter_by_scan(spots, positions)

    return spots
1512
+
1385
1513
  # ------------------------------------------------------------------
1386
1514
  # filter_by_scan (beam proximity for pf-HEDM)
1387
1515
  # ------------------------------------------------------------------
@@ -613,18 +613,13 @@ def simulate_panel_zarrs(
613
613
  f"BC=({panel.y_bc:.1f},{panel.z_bc:.1f}), "
614
614
  f"tx={panel.tx:.2f}° ty={panel.ty:.2f}° tz={panel.tz:.2f}°")
615
615
  with torch.no_grad():
616
- spots = model(eulers_t, positions_t,
617
- lattice_params=latc, strain=strain_t)
618
-
619
- valid_np = (spots.valid > 0.5).cpu().numpy()
620
- # Strain × orientation diagonal mask
621
- G_strain, Kdim, Mdim = valid_np.shape
622
- if G_strain == n_grains and Kdim == 2 * n_grains:
623
- diag_mask = np.zeros((G_strain, Kdim, Mdim), dtype=bool)
624
- for gi in range(G_strain):
625
- diag_mask[gi, gi, :] = True
626
- diag_mask[gi, gi + n_grains, :] = True
627
- valid_np = valid_np & diag_mask
616
+ # Per-grain fast path: O(N*M), no orientation x strain cross-product.
617
+ # Output is (2N, M) -- already the per-grain diagonal, no mask needed.
618
+ spots = model.forward_per_grain(eulers_t, positions_t,
619
+ lattice_params=latc, strain=strain_t)
620
+
621
+ valid_np = (spots.valid > 0.5).cpu().numpy() # (2N, M)
622
+ Kdim, Mdim = valid_np.shape # Kdim == 2 * n_grains
628
623
 
629
624
  # First-wins: drop spots that an earlier panel already took
630
625
  # (avoids double-counting in the rare overlap).
@@ -652,10 +647,11 @@ def simulate_panel_zarrs(
652
647
  y_pix = spots.y_pixel.cpu().numpy()
653
648
  z_pix = spots.z_pixel.cpu().numpy()
654
649
  frame_nr = spots.frame_nr.cpu().numpy()
655
- grain_ids = np.broadcast_to(np.arange(G_strain)[:, None, None],
656
- (G_strain, Kdim, Mdim))
657
- hkl_ids = np.broadcast_to(np.arange(Mdim)[None, None, :],
658
- (G_strain, Kdim, Mdim))
650
+ # (2N, M): row k holds grain (k % n_grains); rows [0:N]/[N:2N] are the
651
+ # two omega branches.
652
+ grain_ids = np.broadcast_to((np.arange(Kdim) % n_grains)[:, None],
653
+ (Kdim, Mdim))
654
+ hkl_ids = np.broadcast_to(np.arange(Mdim)[None, :], (Kdim, Mdim))
659
655
 
660
656
  rec = {
661
657
  "grain_id": grain_ids.reshape(-1)[flat_idx].astype(np.int32),
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: midas-diffract
3
- Version: 0.4.0
3
+ Version: 0.6.0
4
4
  Summary: End-to-end differentiable forward model for High-Energy Diffraction Microscopy (FF, NF, pf-HEDM)
5
5
  Author-email: Hemant Sharma <hsharma@anl.gov>
6
6
  License-Expression: BSD-3-Clause
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "midas-diffract"
7
- version = "0.4.0"
7
+ version = "0.6.0"
8
8
  description = "End-to-end differentiable forward model for High-Energy Diffraction Microscopy (FF, NF, pf-HEDM)"
9
9
  readme = "README.md"
10
10
  license = "BSD-3-Clause"
@@ -1201,5 +1201,89 @@ class TestStrainTensorInput:
1201
1201
  torch.testing.assert_close(sp_v.omega, sp_S.omega)
1202
1202
 
1203
1203
 
1204
+ # ===================================================================
1205
+ # Test: forward_per_grain (fast path) == diagonal of forward()
1206
+ # ===================================================================
1207
+
1208
class TestForwardPerGrain:
    """The O(N*M) per-grain path must equal the diagonal of forward()'s
    O(N^2*M) orientation x strain cross-product, so the gradient and
    gradient-free callers agree.

    The two paths differ only in floating-point reduction order, so values
    are compared with a tight float64 tolerance (1e-9) rather than
    bit-for-bit; validity flags must match exactly.
    """

    def _setup(self, nf_geometry, device, N=6, strained=True):
        """Build a model and N random float64 grains on ``device``.

        Random values are drawn on the CPU under a fixed seed and then moved
        to ``device``, so the inputs are identical whatever device is used.
        """
        model, _, _ = make_model_with_cubic_iron(nf_geometry, device)
        torch.manual_seed(7)
        euler = torch.rand(N, 3, dtype=torch.float64) * 2 * math.pi
        pos = torch.arange(N * 3, dtype=torch.float64).reshape(N, 3) * 5.0
        latc = torch.tensor([2.87, 2.87, 2.87, 90., 90., 90.],
                            dtype=torch.float64).expand(N, 6)
        strain = None
        if strained:
            s = torch.randn(N, 3, 3, dtype=torch.float64) * 1e-3
            strain = ((s + s.transpose(-1, -2)) / 2).to(device)
        return (model, euler.to(device), pos.to(device), latc.to(device),
                strain, N)

    def test_per_grain_equals_forward_diagonal_strained(self, nf_geometry, device):
        # Compare grain by grain (no global sort, which would reorder
        # near-degenerate spots): forward() diagonal rows [gi, gi] and
        # [gi, gi+N] vs forward_per_grain() rows [gi] and [gi+N].
        # Only valid entries are compared: invalid entries carry no physical
        # meaning, and comparing them via a NaN-propagating max() would let a
        # single NaN hide every mismatch in its row.
        model, euler, pos, latc, strain, N = self._setup(nf_geometry, device)
        spf = model(euler, pos, lattice_params=latc, strain=strain)  # (N,2N,M)
        spg = model.forward_per_grain(
            euler, pos, lattice_params=latc, strain=strain)  # (2N,M)

        vf = spf.valid > 0.5
        vg = spg.valid > 0.5
        diag_rows = [(gi, k) for gi in range(N) for k in (gi, gi + N)]

        for gi, k in diag_rows:
            # validity must agree exactly on the diagonal
            assert torch.equal(vf[gi, k], vg[k]), (
                f"valid mismatch at grain {gi}, row {k}")

        for fld in ("two_theta", "eta", "omega", "y_pixel", "z_pixel", "frame_nr"):
            full = getattr(spf, fld)
            fast = getattr(spg, fld)
            for gi, k in diag_rows:
                mask = vg[k]
                torch.testing.assert_close(
                    fast[k][mask], full[gi, k][mask], rtol=0.0, atol=1e-9,
                    msg=f"per-grain vs forward-diagonal mismatch in {fld} "
                        f"(grain {gi}, row {k})",
                )

    def test_per_grain_shape_is_2N(self, nf_geometry, device):
        model, euler, pos, latc, strain, N = self._setup(nf_geometry, device)
        sp = model.forward_per_grain(euler, pos, lattice_params=latc, strain=strain)
        assert sp.valid.shape[0] == 2 * N

    def test_per_grain_shared_lattice_no_strain(self, nf_geometry, device):
        # nominal (no lattice) path must also work and avoid cross-product
        model, euler, pos, _, _, N = self._setup(nf_geometry, device, strained=False)
        sp = model.forward_per_grain(euler, pos)
        assert sp.valid.shape[0] == 2 * N

    def test_per_grain_differentiable(self, nf_geometry, device):
        model, euler, pos, latc, strain, N = self._setup(nf_geometry, device)
        euler = euler.clone().requires_grad_(True)
        sp = model.forward_per_grain(euler, pos, lattice_params=latc, strain=strain)
        # omega/eta/pixels depend on orientation (two_theta does not)
        (sp.omega * sp.valid).sum().backward()
        assert euler.grad is not None and torch.all(torch.isfinite(euler.grad))

    def test_per_grain_fp32(self, nf_geometry, device):
        model, euler, pos, latc, strain, N = self._setup(nf_geometry, device)
        sp = model.forward_per_grain(euler.float(), pos.float(),
                                     lattice_params=latc.float(),
                                     strain=strain.float())
        # The model may promote to float64 internally; accept either dtype,
        # but the per-grain layout must still hold.
        assert sp.valid.dtype in (torch.float32, torch.float64)
        assert sp.valid.shape[0] == 2 * N

    def test_forward_warns_on_cross_product(self, nf_geometry, device):
        model, euler, pos, latc, strain, N = self._setup(nf_geometry, device)
        with pytest.warns(UserWarning, match="cross-product"):
            model(euler, pos, lattice_params=latc, strain=strain)

    def test_forward_no_warn_shared_lattice(self, nf_geometry, device):
        import warnings as _w
        model, euler, pos, _, _, N = self._setup(nf_geometry, device, strained=False)
        shared = torch.tensor([2.87, 2.87, 2.87, 90., 90., 90.],
                              dtype=torch.float64).to(device)
        with _w.catch_warnings():
            _w.simplefilter("error")  # any cross-product warning becomes an error
            model(euler, pos, lattice_params=shared)  # shared -> (2N,M), no warn
1286
+
1287
+
1204
1288
  if __name__ == "__main__":
1205
1289
  pytest.main([__file__, "-v"])
File without changes
File without changes
File without changes