midas-diffract 0.2.0__tar.gz → 0.6.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (25)
  1. {midas_diffract-0.2.0 → midas_diffract-0.6.0}/PKG-INFO +2 -1
  2. {midas_diffract-0.2.0 → midas_diffract-0.6.0}/midas_diffract/__init__.py +1 -1
  3. {midas_diffract-0.2.0 → midas_diffract-0.6.0}/midas_diffract/forward.py +202 -46
  4. {midas_diffract-0.2.0 → midas_diffract-0.6.0}/midas_diffract/simulate_panel_zarrs.py +12 -16
  5. {midas_diffract-0.2.0 → midas_diffract-0.6.0}/midas_diffract.egg-info/PKG-INFO +2 -1
  6. {midas_diffract-0.2.0 → midas_diffract-0.6.0}/midas_diffract.egg-info/requires.txt +1 -0
  7. {midas_diffract-0.2.0 → midas_diffract-0.6.0}/pyproject.toml +2 -1
  8. {midas_diffract-0.2.0 → midas_diffract-0.6.0}/tests/test_forward.py +174 -0
  9. {midas_diffract-0.2.0 → midas_diffract-0.6.0}/LICENSE +0 -0
  10. {midas_diffract-0.2.0 → midas_diffract-0.6.0}/README.md +0 -0
  11. {midas_diffract-0.2.0 → midas_diffract-0.6.0}/midas_diffract/hkls.py +0 -0
  12. {midas_diffract-0.2.0 → midas_diffract-0.6.0}/midas_diffract/losses.py +0 -0
  13. {midas_diffract-0.2.0 → midas_diffract-0.6.0}/midas_diffract/optimize.py +0 -0
  14. {midas_diffract-0.2.0 → midas_diffract-0.6.0}/midas_diffract.egg-info/SOURCES.txt +0 -0
  15. {midas_diffract-0.2.0 → midas_diffract-0.6.0}/midas_diffract.egg-info/dependency_links.txt +0 -0
  16. {midas_diffract-0.2.0 → midas_diffract-0.6.0}/midas_diffract.egg-info/top_level.txt +0 -0
  17. {midas_diffract-0.2.0 → midas_diffract-0.6.0}/setup.cfg +0 -0
  18. {midas_diffract-0.2.0 → midas_diffract-0.6.0}/tests/test_c_comparison.py +0 -0
  19. {midas_diffract-0.2.0 → midas_diffract-0.6.0}/tests/test_distortion_layer.py +0 -0
  20. {midas_diffract-0.2.0 → midas_diffract-0.6.0}/tests/test_hkls.py +0 -0
  21. {midas_diffract-0.2.0 → midas_diffract-0.6.0}/tests/test_losses.py +0 -0
  22. {midas_diffract-0.2.0 → midas_diffract-0.6.0}/tests/test_multi_detector.py +0 -0
  23. {midas_diffract-0.2.0 → midas_diffract-0.6.0}/tests/test_strain_tensor.py +0 -0
  24. {midas_diffract-0.2.0 → midas_diffract-0.6.0}/tests/test_tilts.py +0 -0
  25. {midas_diffract-0.2.0 → midas_diffract-0.6.0}/tests/test_wedge.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: midas-diffract
3
- Version: 0.2.0
3
+ Version: 0.6.0
4
4
  Summary: End-to-end differentiable forward model for High-Energy Diffraction Microscopy (FF, NF, pf-HEDM)
5
5
  Author-email: Hemant Sharma <hsharma@anl.gov>
6
6
  License-Expression: BSD-3-Clause
@@ -19,6 +19,7 @@ License-File: LICENSE
19
19
  Requires-Dist: numpy>=1.22
20
20
  Requires-Dist: torch>=2.0
21
21
  Requires-Dist: midas-distortion>=0.2.0
22
+ Requires-Dist: midas-stress>=0.8.0
22
23
  Provides-Extra: dev
23
24
  Requires-Dist: pytest>=7.0; extra == "dev"
24
25
  Requires-Dist: pytest-cov; extra == "dev"
@@ -28,7 +28,7 @@ Quick start
28
28
  loss.backward()
29
29
  """
30
30
 
31
- __version__ = "0.2.0"
31
+ __version__ = "0.6.0"
32
32
 
33
33
  from .forward import (
34
34
  HEDMForwardModel,
@@ -27,6 +27,7 @@ Reference C code:
27
27
  """
28
28
 
29
29
  import math
30
+ import warnings
30
31
  from dataclasses import dataclass, field
31
32
  from typing import Optional, Tuple
32
33
 
@@ -34,6 +35,16 @@ import torch
34
35
  import torch.nn as nn
35
36
  import torch.nn.functional as F
36
37
 
38
+ # Canonical orientation + strain-frame primitives. midas_stress is the single
39
+ # source of truth for this math (Bunge ZXZ orientation algebra, sample<->crystal
40
+ # strain rotation); its torch backend is differentiable end-to-end and
41
+ # device-portable, so the forward model delegates rather than re-porting.
42
+ # NOTE: midas_stress's Voigt is Voigt-MANDEL (sqrt2 on shears); this model uses
43
+ # PLAIN-Voigt / raw 3x3 strain, so we only delegate the *rotation*, never the
44
+ # Voigt packing (see rotate_strain_sample_to_crystal / correct_hkls_latc).
45
+ from midas_stress.orientation import euler_to_orient_mat as _ms_euler_to_orient_mat
46
+ from midas_stress.tensor import strain_lab_to_grain as _ms_strain_lab_to_grain
47
+
37
48
 
38
49
  # ---------------------------------------------------------------------------
39
50
  # Configuration data classes
@@ -437,7 +448,13 @@ class HEDMForwardModel(nn.Module):
437
448
 
438
449
  @staticmethod
439
450
  def euler2mat(euler_angles: torch.Tensor) -> torch.Tensor:
440
- """Convert ZXZ Euler angles to rotation matrices.
451
+ """Convert ZXZ (Bunge) Euler angles to crystal->sample rotation matrices.
452
+
453
+ Delegates to ``midas_stress.orientation.euler_to_orient_mat`` -- the
454
+ canonical orientation primitive -- so the convention can never drift
455
+ from the rest of MIDAS. midas_stress's torch backend is differentiable
456
+ and vmap-safe; the result is identical to the former in-line ZXZ build
457
+ (R = Rz(phi1) @ Rx(Phi) @ Rz(phi2)) to ~1e-16.
441
458
 
442
459
  Parameters
443
460
  ----------
@@ -447,35 +464,14 @@ class HEDMForwardModel(nn.Module):
447
464
  Returns
448
465
  -------
449
466
  Tensor (..., 3, 3)
450
- Rotation matrices.
467
+ Rotation matrices (crystal->sample), orthogonalized onto SO(3).
451
468
  """
452
- c = torch.cos(euler_angles)
453
- s = torch.sin(euler_angles)
454
-
455
- c0, c1, c2 = c[..., 0], c[..., 1], c[..., 2]
456
- s0, s1, s2 = s[..., 0], s[..., 1], s[..., 2]
457
-
458
- # ZXZ rotation matrix: R = Rz(phi1) @ Rx(Phi) @ Rz(phi2)
459
- # Verified element-by-element against nfhedm.py lines 114-120.
460
- # Built via torch.stack rather than indexed assignment so the function
461
- # composes with torch.func.vmap (in-place writes block vmap).
462
- row0 = torch.stack([
463
- c0 * c2 - s0 * c1 * s2,
464
- -s0 * c1 * c2 - c0 * s2,
465
- s0 * s1,
466
- ], dim=-1)
467
- row1 = torch.stack([
468
- s0 * c2 + c0 * c1 * s2,
469
- c0 * c1 * c2 - s0 * s2,
470
- -c0 * s1,
471
- ], dim=-1)
472
- row2 = torch.stack([
473
- s1 * s2,
474
- s1 * c2,
475
- c1,
476
- ], dim=-1)
477
- R = torch.stack([row0, row1, row2], dim=-2)
478
-
469
+ if not isinstance(euler_angles, torch.Tensor):
470
+ euler_angles = torch.as_tensor(euler_angles)
471
+ R = _ms_euler_to_orient_mat(euler_angles) # (..., 9), torch
472
+ R = R.reshape(*R.shape[:-1], 3, 3)
473
+ # midas_stress already returns a proper rotation; orthogonalize keeps the
474
+ # historical "exactly on SO(3)" guarantee and is idempotent here.
479
475
  return HEDMForwardModel.orthogonalize(R)
480
476
 
481
477
  # ------------------------------------------------------------------
@@ -525,6 +521,28 @@ class HEDMForwardModel(nn.Module):
525
521
  """Numerically stable arccos: clamp to [-1+eps, 1-eps]."""
526
522
  return torch.acos(torch.clamp(x, -1.0 + self.epsilon, 1.0 - self.epsilon))
527
523
 
524
+ # ------------------------------------------------------------------
525
+ # strain_as_voigt (accept full 3x3 tensor OR plain-Voigt 6-vector)
526
+ # ------------------------------------------------------------------
527
+
528
@staticmethod
def strain_as_voigt(strain: torch.Tensor) -> torch.Tensor:
    """Normalize a strain input to PLAIN-Voigt [e11,e12,e13,e22,e23,e33].

    Accepts either a plain-Voigt ``(..., 6)`` tensor (returned unchanged) or
    a full ``(..., 3, 3)`` strain tensor. The 3x3 path is convention-free --
    the natural way to hand a strain field straight from a tensor source
    (e.g. a midas_stress strain field) into the forward model without
    picking a Voigt/Mandel packing. The off-diagonals are taken as TRUE
    tensor components (no factor of 2).

    Parameters
    ----------
    strain : Tensor (..., 6) or (..., 3, 3), or array-like
        Non-tensor inputs (NumPy arrays, nested lists) are converted with
        ``torch.as_tensor``, mirroring :meth:`euler2mat`.

    Returns
    -------
    Tensor (..., 6)
        Plain-Voigt strain. For a 3x3 input the shear entries are the
        symmetric part ``(S_ij + S_ji) / 2``; for an already-symmetric
        floating-point tensor this equals reading the upper triangle
        exactly.

    Raises
    ------
    ValueError
        If the trailing shape is neither ``(6,)`` nor ``(3, 3)``.
    """
    if not isinstance(strain, torch.Tensor):
        strain = torch.as_tensor(strain)

    if strain.dim() >= 2 and tuple(strain.shape[-2:]) == (3, 3):
        # Symmetric part for the shears, so a slightly asymmetric input is
        # not silently reduced to its upper triangle. For symmetric input
        # 0.5 * (a + a) == a exactly in floating point.
        def _sym(i: int, j: int) -> torch.Tensor:
            return 0.5 * (strain[..., i, j] + strain[..., j, i])

        return torch.stack([
            strain[..., 0, 0], _sym(0, 1), _sym(0, 2),
            strain[..., 1, 1], _sym(1, 2), strain[..., 2, 2],
        ], dim=-1)

    if strain.dim() >= 1 and strain.shape[-1] == 6:
        return strain

    raise ValueError(
        "strain must have trailing shape (6,) (plain Voigt) or (3, 3) "
        f"(full tensor); got {tuple(strain.shape)}"
    )
545
+
528
546
  # ------------------------------------------------------------------
529
547
  # rotate_strain_sample_to_crystal (port of C RotateStrainSampleToCrystal)
530
548
  # ------------------------------------------------------------------
@@ -536,19 +554,27 @@ class HEDMForwardModel(nn.Module):
536
554
  ) -> torch.Tensor:
537
555
  """Rotate a symmetric infinitesimal strain from sample to crystal frame.
538
556
 
539
- Port of ``RotateStrainSampleToCrystal`` from
557
+ Matches ``RotateStrainSampleToCrystal`` from
540
558
  ``FF_HEDM/src/ForwardSimulationCompressed.c:399-419``:
541
- eps_crystal = OM^T . eps_sample . OM, in Voigt notation
542
- [eps_11, eps_12, eps_13, eps_22, eps_23, eps_33].
559
+ eps_crystal = OM^T . eps_sample . OM.
560
+
561
+ The rotation is delegated to ``midas_stress.tensor.strain_lab_to_grain``
562
+ (the canonical sample/lab -> crystal/grain strain transform; bit-identical
563
+ to the former in-line ``OM^T S OM``). PLAIN-Voigt pack/unpack is kept here
564
+ on purpose -- midas_stress's Voigt is Mandel (sqrt2 on shears), which must
565
+ not touch the forward model's strain input.
543
566
 
544
567
  Parameters
545
568
  ----------
546
569
  orientation_matrices : Tensor (..., 3, 3)
570
+ Crystal->sample matrices (as returned by :meth:`euler2mat`).
547
571
  strain_sample : Tensor (..., 6)
572
+ PLAIN-Voigt [eps_11, eps_12, eps_13, eps_22, eps_23, eps_33].
548
573
 
549
574
  Returns
550
575
  -------
551
576
  strain_crystal : Tensor (..., 6)
577
+ PLAIN-Voigt, same layout.
552
578
  """
553
579
  e = strain_sample
554
580
  S = torch.stack([
@@ -556,8 +582,7 @@ class HEDMForwardModel(nn.Module):
556
582
  torch.stack([e[..., 1], e[..., 3], e[..., 4]], dim=-1),
557
583
  torch.stack([e[..., 2], e[..., 4], e[..., 5]], dim=-1),
558
584
  ], dim=-2)
559
- OM = orientation_matrices
560
- C = torch.matmul(torch.matmul(OM.transpose(-1, -2), S), OM)
585
+ C = _ms_strain_lab_to_grain(S, orientation_matrices) # OM^T S OM
561
586
  return torch.stack([
562
587
  C[..., 0, 0], C[..., 0, 1], C[..., 0, 2],
563
588
  C[..., 1, 1], C[..., 1, 2], C[..., 2, 2],
@@ -587,10 +612,12 @@ class HEDMForwardModel(nn.Module):
587
612
  lattice_params : Tensor (..., 6)
588
613
  [a, b, c, alpha, beta, gamma] in Angstroms and degrees.
589
614
  The ``...`` dimensions allow per-voxel or per-grain parameters.
590
- strain : Tensor (..., 6), optional
591
- Crystal-frame symmetric infinitesimal strain in Voigt form
592
- [eps_11, eps_12, eps_13, eps_22, eps_23, eps_33]. When supplied,
593
- the reciprocal lattice is post-multiplied by (I + eps)^{-1}:
615
+ strain : Tensor (..., 6) or (..., 3, 3), optional
616
+ Crystal-frame symmetric infinitesimal strain, either PLAIN-Voigt
617
+ [eps_11, eps_12, eps_13, eps_22, eps_23, eps_33] or a full symmetric
618
+ 3x3 tensor (normalized via :meth:`strain_as_voigt`; off-diagonals are
619
+ true tensor components, no factor of 2). When supplied, the
620
+ reciprocal lattice is post-multiplied by (I + eps)^{-1}:
594
621
  B = (I + eps)^{-1} @ B0. Use :meth:`rotate_strain_sample_to_crystal`
595
622
  to convert a sample-frame strain into the crystal frame.
596
623
 
@@ -671,6 +698,7 @@ class HEDMForwardModel(nn.Module):
671
698
  # Voigt layout matches C CorrectHKLsLatCEpsilon:
672
699
  # eps = [eps_11, eps_12, eps_13, eps_22, eps_23, eps_33]
673
700
  if strain is not None:
701
+ strain = self.strain_as_voigt(strain) # accept full 3x3 too
674
702
  e11 = strain[..., 0]
675
703
  e12 = strain[..., 1]
676
704
  e13 = strain[..., 2]
@@ -711,6 +739,7 @@ class HEDMForwardModel(nn.Module):
711
739
  orientation_matrices: torch.Tensor,
712
740
  hkls_cart: Optional[torch.Tensor] = None,
713
741
  thetas: Optional[torch.Tensor] = None,
742
+ per_grain: bool = False,
714
743
  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
715
744
  """Core Bragg geometry: orientations + G-vectors -> angles.
716
745
 
@@ -752,14 +781,32 @@ class HEDMForwardModel(nn.Module):
752
781
  # batch, (b) per-voxel hkls_cart shape (..., M, 3) for strained
753
782
  # rendering. Both flow through the same einsum via leading-dim
754
783
  # broadcasting on the second arg.
755
- G_C = torch.einsum("...nij,...mj->...nmi", orientation_matrices, hkls_cart)
784
+ #
785
+ # per_grain=True: ELEMENT-WISE pairing of grain i's orientation with
786
+ # grain i's strained hkls -- NO orientation x strain cross-product.
787
+ # Requires orientation_matrices (N,3,3); hkls_cart (N,M,3) per-grain or
788
+ # (M,3) shared. Used by forward_per_grain() for the O(N*M) fast path.
789
+ if per_grain:
790
+ if hkls_cart.dim() == 2: # (M,3) shared lattice
791
+ G_C = torch.einsum("nij,mj->nmi", orientation_matrices, hkls_cart)
792
+ else: # (N,M,3) per-grain strain
793
+ G_C = torch.einsum("nij,nmj->nmi", orientation_matrices, hkls_cart)
794
+ else:
795
+ G_C = torch.einsum("...nij,...mj->...nmi", orientation_matrices, hkls_cart)
756
796
 
757
797
  # v = sin(theta)*|G| -- C precomputes Gs from the UNROTATED G-vector norm
758
798
  # (rotation preserves norm in exact arithmetic but not in float64).
759
799
  # Match C: use |hkls_cart| (pre-rotation), not |R @ hkls_cart|.
760
800
  len_hkl = torch.norm(hkls_cart, dim=-1) # (M,) or (..., M)
761
801
  v_no_wedge = torch.sin(thetas) * len_hkl # (M,) or (..., M)
762
- v_no_wedge = v_no_wedge.unsqueeze(-2).expand_as(G_C[..., 0]) # (..., N, M)
802
+ # Broadcast v to G_C's (..., N, M) grid. Layout-agnostic: in the
803
+ # cross-product layout v gains an orientation axis at -2; in the
804
+ # per_grain / shared layouts it already matches. Numerically identical
805
+ # to the former ``unsqueeze(-2).expand_as`` for those existing paths.
806
+ gc0 = G_C[..., 0]
807
+ while v_no_wedge.dim() < gc0.dim():
808
+ v_no_wedge = v_no_wedge.unsqueeze(-2)
809
+ v_no_wedge = v_no_wedge.expand_as(gc0)
763
810
 
764
811
  # ---- Wedge: rigorous geometric formulation -------------------
765
812
  # The rotation axis tilts from z to n_hat = (sin W, 0, cos W).
@@ -873,9 +920,16 @@ class HEDMForwardModel(nn.Module):
873
920
  eta = self.safe_arccos(Gz_lab / r_yz)
874
921
  eta = -torch.sign(Gy_lab) * eta
875
922
 
876
- # 2*theta (broadcast thetas to match 2N dimension)
877
- two_theta_single = 2.0 * thetas.unsqueeze(-2) # (..., 1, M) or (1, M)
878
- two_theta = two_theta_single.expand_as(all_omega)
923
+ # 2*theta -- broadcast to the single-branch (pre-cat) grid, then double
924
+ # along the grain axis to match all_omega's 2N. Layout-agnostic (handles
925
+ # the per_grain layout where the 2N axis is grain-doubled) and
926
+ # numerically identical to the former expand for the cross-product path
927
+ # (thetas does not depend on the orientation axis).
928
+ tt = 2.0 * thetas
929
+ while tt.dim() < omega_p.dim():
930
+ tt = tt.unsqueeze(-2)
931
+ tt = tt.expand_as(omega_p)
932
+ two_theta = torch.cat([tt, tt], dim=-2)
879
933
 
880
934
  # Validity mask
881
935
  valid_p = disc_valid & coswp_valid
@@ -1311,9 +1365,10 @@ class HEDMForwardModel(nn.Module):
1311
1365
  lattice_params : Tensor (..., 6) or (..., N, 6), optional
1312
1366
  Strained lattice parameters [a,b,c,alpha,beta,gamma] in
1313
1367
  Angstroms/degrees. None = use nominal hkls/thetas (no strain).
1314
- strain : Tensor (..., 6) or (..., N, 6), optional
1315
- Crystal-frame symmetric infinitesimal strain in Voigt form
1316
- [eps_11, eps_12, eps_13, eps_22, eps_23, eps_33]. Applied as
1368
+ strain : Tensor (..., 6), (..., N, 6), or (..., 3, 3), optional
1369
+ Crystal-frame symmetric infinitesimal strain, either PLAIN-Voigt
1370
+ [eps_11, eps_12, eps_13, eps_22, eps_23, eps_33] or a full symmetric
1371
+ 3x3 tensor (see :meth:`strain_as_voigt`). Applied as
1317
1372
  B = (I + eps)^{-1} @ B0 in addition to any lattice-parameter
1318
1373
  strain expressed through ``lattice_params``. Requires
1319
1374
  ``lattice_params`` to be supplied.
@@ -1326,6 +1381,11 @@ class HEDMForwardModel(nn.Module):
1326
1381
  if positions.shape[-1] == 2:
1327
1382
  positions = F.pad(positions, (0, 1), value=0.0)
1328
1383
 
1384
+ # Footgun guard: per-grain lattice/strain + N>1 orientations forms an
1385
+ # N x N orientation x strain cross-product (output (N, 2N, M)); callers
1386
+ # simulating a fixed polycrystal almost always want only the diagonal.
1387
+ self._warn_if_cross_product(euler_angles, lattice_params, strain)
1388
+
1329
1389
  # 1. Orientation matrices
1330
1390
  orientation_matrices = self.euler2mat(euler_angles)
1331
1391
 
@@ -1354,6 +1414,102 @@ class HEDMForwardModel(nn.Module):
1354
1414
 
1355
1415
  return spots
1356
1416
 
1417
@staticmethod
def _warn_if_cross_product(euler_angles, lattice_params, strain):
    """Warn when forward() would form an orientation x strain cross-product.

    Fires only when there are N>1 orientations AND lattice_params/strain
    carry a matching per-grain axis -- the case where the (N, 2N, M) output
    is an N x N cross-product and the caller likely wanted the diagonal.
    Shared lattice/strain (no grain axis) is the correct (2N, M) path and
    does not warn.

    Only shapes are inspected, so tensors, NumPy arrays and nested
    sequences are all accepted. This matters because forward() calls this
    guard before :meth:`euler2mat` has converted ``euler_angles`` to a
    tensor.

    Parameters
    ----------
    euler_angles : Tensor or array-like, (..., N, 3) or (3,)
    lattice_params : Tensor or array-like (..., 6), or None
    strain : Tensor or array-like (..., 6) / (..., 3, 3), or None
    """
    def _shape(t):
        # Shape only; ``torch.as_tensor`` is used just for objects without
        # a ``.shape`` attribute (e.g. nested lists).
        if t is None:
            return None
        if hasattr(t, "shape"):
            return tuple(t.shape)
        return tuple(torch.as_tensor(t).shape)

    euler_shape = _shape(euler_angles)
    n_orient = euler_shape[-2] if euler_shape and len(euler_shape) >= 2 else 1
    if n_orient <= 1:
        return

    def _has_grain_axis(t):
        shape = _shape(t)
        if shape is None:
            return False
        if len(shape) >= 2 and shape[-1] == 6:  # Voigt lattice or strain
            return shape[-2] == n_orient
        if len(shape) >= 3 and shape[-2:] == (3, 3):  # full-tensor strain
            return shape[-3] == n_orient
        return False

    if _has_grain_axis(lattice_params) or _has_grain_axis(strain):
        warnings.warn(
            f"forward() called with {n_orient} orientations and per-grain "
            "lattice_params/strain forms an orientation x strain "
            f"cross-product (output shape ({n_orient}, {2 * n_orient}, M)); "
            "only the diagonal [i, i] and [i, i+N] is physical. For a fixed "
            "polycrystal use forward_per_grain() (element-wise, O(N*M)), "
            "or index the diagonal of this output.",
            # NOTE(review): stacklevel=3 targets forward()'s caller when
            # forward() is invoked directly; through nn.Module.__call__ the
            # extra dispatch frames may attribute it inside torch -- verify.
            stacklevel=3,
        )
1450
+
1451
def forward_per_grain(
    self,
    euler_angles: torch.Tensor,
    positions: torch.Tensor,
    lattice_params: Optional[torch.Tensor] = None,
    strain: Optional[torch.Tensor] = None,
) -> SpotDescriptors:
    """Element-wise per-grain forward simulation -- the fast path.

    Grain ``i`` is simulated with orientation ``i``, lattice/strain ``i``
    and position ``i``, WITHOUT the orientation x strain cross-product that
    :meth:`forward` forms when both are per-grain. The output has leading
    shape ``(2N, M)`` (the two omega branches doubled along the grain axis),
    which is the diagonal of :meth:`forward`'s ``(N, 2N, M)`` output.

    Cost is O(N*M) rather than O(N^2 * M), matching the algorithm of the C
    reference ``ForwardSimulationCompressed.c``. Fully differentiable and
    device-portable; for pure forward simulation wrap the call in
    ``torch.inference_mode()``.

    Parameters
    ----------
    euler_angles : Tensor (N, 3)
        Bunge ZXZ Euler angles (radians), one per grain. No leading batch.
        Array-likes are converted with ``torch.as_tensor``.
    positions : Tensor (N, 3) or (N, 2)
        Real-space grain positions (micrometers). Array-likes are converted
        with ``torch.as_tensor``.
    lattice_params : Tensor (6,) or (N, 6), optional
        Shared or per-grain lattice [a,b,c,alpha,beta,gamma] (Ang/deg).
    strain : Tensor (6,), (N, 6), (3, 3), or (N, 3, 3), optional
        Shared or per-grain crystal-frame strain (plain-Voigt or full 3x3).

    Returns
    -------
    SpotDescriptors with leading shape ``(2N, M)``: grain ``i``'s two omega
    solutions live at axis-(-2) indices ``i`` and ``i + N``.

    Raises
    ------
    ValueError
        If ``euler_angles`` is not ``(N, 3)``, or if ``strain`` is supplied
        without ``lattice_params``.
    """
    if not isinstance(euler_angles, torch.Tensor):
        euler_angles = torch.as_tensor(euler_angles)
    if not isinstance(positions, torch.Tensor):
        # F.pad below (and the detector projection) need a tensor.
        positions = torch.as_tensor(positions)

    if euler_angles.dim() != 2 or euler_angles.shape[-1] != 3:
        # The per-grain einsum in calc_bragg_geometry pairs axis 0 of the
        # orientations with axis 0 of hkls_cart, so any other rank fails
        # there with an opaque einsum error; report it clearly instead.
        raise ValueError(
            "forward_per_grain expects euler_angles of shape (N, 3); got "
            f"{tuple(euler_angles.shape)}. Use forward() for batched input."
        )

    if lattice_params is None and strain is not None:
        raise ValueError(
            "strain was supplied but lattice_params is None; strain "
            "requires a reference lattice to apply (I + eps)^{-1} @ B0."
        )

    if positions.shape[-1] == 2:
        # Append a zero third coordinate to 2-D positions.
        positions = F.pad(positions, (0, 1), value=0.0)

    orientation_matrices = self.euler2mat(euler_angles)  # (N, 3, 3)

    # None -> calc_bragg_geometry uses the model's nominal hkls/thetas.
    hkls_cart = None
    thetas = None
    if lattice_params is not None:
        hkls_cart, thetas = self.correct_hkls_latc(lattice_params, strain=strain)

    omega, eta, two_theta, valid = self.calc_bragg_geometry(
        orientation_matrices, hkls_cart, thetas, per_grain=True
    )
    spots = self.project_to_detector(omega, eta, two_theta, positions, valid)

    if self.scan_config is not None:
        spots = self.filter_by_scan(spots, positions)

    return spots
1512
+
1357
1513
  # ------------------------------------------------------------------
1358
1514
  # filter_by_scan (beam proximity for pf-HEDM)
1359
1515
  # ------------------------------------------------------------------
@@ -613,18 +613,13 @@ def simulate_panel_zarrs(
613
613
  f"BC=({panel.y_bc:.1f},{panel.z_bc:.1f}), "
614
614
  f"tx={panel.tx:.2f}° ty={panel.ty:.2f}° tz={panel.tz:.2f}°")
615
615
  with torch.no_grad():
616
- spots = model(eulers_t, positions_t,
617
- lattice_params=latc, strain=strain_t)
618
-
619
- valid_np = (spots.valid > 0.5).cpu().numpy()
620
- # Strain × orientation diagonal mask
621
- G_strain, Kdim, Mdim = valid_np.shape
622
- if G_strain == n_grains and Kdim == 2 * n_grains:
623
- diag_mask = np.zeros((G_strain, Kdim, Mdim), dtype=bool)
624
- for gi in range(G_strain):
625
- diag_mask[gi, gi, :] = True
626
- diag_mask[gi, gi + n_grains, :] = True
627
- valid_np = valid_np & diag_mask
616
+ # Per-grain fast path: O(N*M), no orientation x strain cross-product.
617
+ # Output is (2N, M) -- already the per-grain diagonal, no mask needed.
618
+ spots = model.forward_per_grain(eulers_t, positions_t,
619
+ lattice_params=latc, strain=strain_t)
620
+
621
+ valid_np = (spots.valid > 0.5).cpu().numpy() # (2N, M)
622
+ Kdim, Mdim = valid_np.shape # Kdim == 2 * n_grains
628
623
 
629
624
  # First-wins: drop spots that an earlier panel already took
630
625
  # (avoids double-counting in the rare overlap).
@@ -652,10 +647,11 @@ def simulate_panel_zarrs(
652
647
  y_pix = spots.y_pixel.cpu().numpy()
653
648
  z_pix = spots.z_pixel.cpu().numpy()
654
649
  frame_nr = spots.frame_nr.cpu().numpy()
655
- grain_ids = np.broadcast_to(np.arange(G_strain)[:, None, None],
656
- (G_strain, Kdim, Mdim))
657
- hkl_ids = np.broadcast_to(np.arange(Mdim)[None, None, :],
658
- (G_strain, Kdim, Mdim))
650
+ # (2N, M): row k holds grain (k % n_grains); rows [0:N]/[N:2N] are the
651
+ # two omega branches.
652
+ grain_ids = np.broadcast_to((np.arange(Kdim) % n_grains)[:, None],
653
+ (Kdim, Mdim))
654
+ hkl_ids = np.broadcast_to(np.arange(Mdim)[None, :], (Kdim, Mdim))
659
655
 
660
656
  rec = {
661
657
  "grain_id": grain_ids.reshape(-1)[flat_idx].astype(np.int32),
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: midas-diffract
3
- Version: 0.2.0
3
+ Version: 0.6.0
4
4
  Summary: End-to-end differentiable forward model for High-Energy Diffraction Microscopy (FF, NF, pf-HEDM)
5
5
  Author-email: Hemant Sharma <hsharma@anl.gov>
6
6
  License-Expression: BSD-3-Clause
@@ -19,6 +19,7 @@ License-File: LICENSE
19
19
  Requires-Dist: numpy>=1.22
20
20
  Requires-Dist: torch>=2.0
21
21
  Requires-Dist: midas-distortion>=0.2.0
22
+ Requires-Dist: midas-stress>=0.8.0
22
23
  Provides-Extra: dev
23
24
  Requires-Dist: pytest>=7.0; extra == "dev"
24
25
  Requires-Dist: pytest-cov; extra == "dev"
@@ -1,6 +1,7 @@
1
1
  numpy>=1.22
2
2
  torch>=2.0
3
3
  midas-distortion>=0.2.0
4
+ midas-stress>=0.8.0
4
5
 
5
6
  [dev]
6
7
  pytest>=7.0
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "midas-diffract"
7
- version = "0.2.0"
7
+ version = "0.6.0"
8
8
  description = "End-to-end differentiable forward model for High-Energy Diffraction Microscopy (FF, NF, pf-HEDM)"
9
9
  readme = "README.md"
10
10
  license = "BSD-3-Clause"
@@ -27,6 +27,7 @@ dependencies = [
27
27
  "numpy>=1.22",
28
28
  "torch>=2.0",
29
29
  "midas-distortion>=0.2.0",
30
+ "midas-stress>=0.8.0",
30
31
  ]
31
32
 
32
33
  [project.optional-dependencies]
@@ -1111,5 +1111,179 @@ class TestCrossValidation:
1111
1111
  import torch.nn.functional as F # for test_2d_position_backward_compat
1112
1112
 
1113
1113
 
1114
+ # ===================================================================
1115
+ # Test: midas_stress is the canonical orientation/strain-frame source
1116
+ # ===================================================================
1117
+
1118
class TestMidasStressDelegation:
    """Guards that euler2mat / rotate_strain_sample_to_crystal delegate to
    midas_stress and never silently re-port (convention drift).

    Every test that draws random inputs seeds torch's RNG first so that a
    failure can be reproduced exactly."""

    def test_euler2mat_matches_midas_stress(self):
        from midas_stress.orientation import euler_to_orient_mat
        torch.manual_seed(0)
        eul = torch.rand(7, 3, dtype=torch.float64) * 2 * math.pi
        R = HEDMForwardModel.euler2mat(eul)
        R_ms = euler_to_orient_mat(eul).reshape(7, 3, 3)
        torch.testing.assert_close(R, R_ms, atol=1e-12, rtol=0)

    def test_euler2mat_vmap_safe(self):
        from torch.func import vmap
        torch.manual_seed(2)
        eul = torch.rand(5, 3, dtype=torch.float64) * 2 * math.pi
        R = vmap(HEDMForwardModel.euler2mat)(eul)
        assert R.shape == (5, 3, 3)
        torch.testing.assert_close(R, HEDMForwardModel.euler2mat(eul), atol=1e-12, rtol=0)

    def test_euler2mat_differentiable(self):
        torch.manual_seed(3)
        eul = torch.rand(3, 3, dtype=torch.float64, requires_grad=True)
        HEDMForwardModel.euler2mat(eul).pow(2).sum().backward()
        assert eul.grad is not None and torch.all(torch.isfinite(eul.grad))

    def test_rotate_strain_matches_midas_stress(self):
        from midas_stress.tensor import strain_lab_to_grain
        torch.manual_seed(1)
        OM = HEDMForwardModel.euler2mat(torch.rand(4, 3, dtype=torch.float64))
        v = torch.rand(4, 6, dtype=torch.float64) * 1e-3  # plain Voigt
        out = HEDMForwardModel.rotate_strain_sample_to_crystal(OM, v)
        # reference: build 3x3, rotate via midas_stress, repack plain Voigt
        S = torch.zeros(4, 3, 3, dtype=torch.float64)
        S[:, 0, 0], S[:, 1, 1], S[:, 2, 2] = v[:, 0], v[:, 3], v[:, 5]
        S[:, 0, 1] = S[:, 1, 0] = v[:, 1]
        S[:, 0, 2] = S[:, 2, 0] = v[:, 2]
        S[:, 1, 2] = S[:, 2, 1] = v[:, 4]
        C = strain_lab_to_grain(S, OM)
        ref = torch.stack([C[:, 0, 0], C[:, 0, 1], C[:, 0, 2],
                           C[:, 1, 1], C[:, 1, 2], C[:, 2, 2]], dim=-1)
        torch.testing.assert_close(out, ref, atol=1e-14, rtol=0)
1158
+
1159
+
1160
+ # ===================================================================
1161
+ # Test: full 3x3 strain tensor entry point (convention-free)
1162
+ # ===================================================================
1163
+
1164
class TestStrainTensorInput:
    """Checks that one strain state, written either as a full symmetric 3x3
    tensor or as the matching plain-Voigt 6-vector (NOT Voigt-Mandel),
    yields identical results wherever the model accepts strain."""

    @staticmethod
    def _voigt():
        # Plain Voigt [e11, e12, e13, e22, e23, e33]; shears are true
        # tensor components.
        return torch.tensor([1e-3, 2e-4, 3e-4, -5e-4, 1e-4, 2e-4])

    @staticmethod
    def _tensor():
        # The same strain state as a full symmetric 3x3 tensor.
        return torch.tensor([[1e-3, 2e-4, 3e-4],
                             [2e-4, -5e-4, 1e-4],
                             [3e-4, 1e-4, 2e-4]])

    @staticmethod
    def _cubic_fe():
        return torch.tensor([2.87, 2.87, 2.87, 90., 90., 90.])

    def test_strain_as_voigt_roundtrip(self):
        voigt, full = self._voigt(), self._tensor()
        torch.testing.assert_close(HEDMForwardModel.strain_as_voigt(full), voigt)
        # an input that is already plain Voigt comes back untouched
        torch.testing.assert_close(HEDMForwardModel.strain_as_voigt(voigt), voigt)

    def test_3x3_strain_equals_voigt_in_correct_hkls_latc(self, nf_geometry, device):
        model, _, _ = make_model_with_cubic_iron(nf_geometry, device)
        latc = self._cubic_fe()
        g_voigt, t_voigt = model.correct_hkls_latc(latc, strain=self._voigt())
        g_full, t_full = model.correct_hkls_latc(latc, strain=self._tensor())
        torch.testing.assert_close(g_voigt, g_full)
        torch.testing.assert_close(t_voigt, t_full)

    def test_3x3_strain_equals_voigt_in_forward(self, nf_geometry, device):
        model, _, _ = make_model_with_cubic_iron(nf_geometry, device)
        eul = torch.tensor([[0.3, 0.5, 1.1]])
        pos = torch.zeros(1, 3)
        latc = self._cubic_fe().expand(1, 6)
        results = {
            label: model(eul, pos, lattice_params=latc, strain=eps)
            for label, eps in (("voigt", self._voigt().unsqueeze(0)),
                               ("full", self._tensor().unsqueeze(0)))
        }
        for fld in ("two_theta", "omega"):
            torch.testing.assert_close(getattr(results["voigt"], fld),
                                       getattr(results["full"], fld))
1202
+
1203
+
1204
+ # ===================================================================
1205
+ # Test: forward_per_grain (fast path) == diagonal of forward()
1206
+ # ===================================================================
1207
+
1208
class TestForwardPerGrain:
    """The O(N*M) per-grain path must reproduce the diagonal of forward()'s
    O(N^2*M) orientation x strain cross-product: validity masks agree
    exactly and values agree to fp64 reduction-order round-off (< 1e-9),
    so the gradient and gradient-free callers agree."""

    def _setup(self, nf_geometry, device, N=6, strained=True):
        model, _, _ = make_model_with_cubic_iron(nf_geometry, device)
        torch.manual_seed(7)
        euler = torch.rand(N, 3, dtype=torch.float64) * 2 * math.pi
        pos = torch.arange(N * 3, dtype=torch.float64).reshape(N, 3) * 5.0
        latc = torch.tensor([2.87, 2.87, 2.87, 90., 90., 90.],
                            dtype=torch.float64).expand(N, 6)
        if strained:
            s = torch.randn(N, 3, 3, dtype=torch.float64) * 1e-3
            strain = (s + s.transpose(-1, -2)) / 2
        else:
            strain = None
        return model, euler, pos, latc, strain, N

    def test_per_grain_equals_forward_diagonal_strained(self, nf_geometry, device):
        # Compare per-grain (no global sort, which would reorder near-degenerate
        # spots): forward() diagonal rows [gi,gi] & [gi,gi+N] vs
        # forward_per_grain() rows [gi] & [gi+N]. Element-wise einsum differs
        # from the cross-product einsum only by fp64 reduction order (~1e-12).
        model, euler, pos, latc, strain, N = self._setup(nf_geometry, device)
        spf = model(euler, pos, lattice_params=latc, strain=strain)  # (N,2N,M)
        spg = model.forward_per_grain(euler, pos, lattice_params=latc, strain=strain)  # (2N,M)

        vf = spf.valid > 0.5
        vg = spg.valid > 0.5
        max_diff = 0.0
        for fld in ("two_theta", "eta", "omega", "y_pixel", "z_pixel", "frame_nr"):
            # Not named ``F``: that would shadow torch.nn.functional.
            full_vals = getattr(spf, fld)
            grain_vals = getattr(spg, fld)
            for gi in range(N):
                for kf, kg in ((gi, gi), (gi + N, gi + N)):
                    # validity must agree exactly on the diagonal
                    assert torch.equal(vf[gi, kf], vg[kg])
                    diff = (full_vals[gi, kf] - grain_vals[kg]).abs().max()
                    max_diff = max(max_diff, float(diff))
        assert max_diff < 1e-9, f"per-grain vs forward-diagonal max diff {max_diff:.2e}"

    def test_per_grain_shape_is_2N(self, nf_geometry, device):
        model, euler, pos, latc, strain, N = self._setup(nf_geometry, device)
        sp = model.forward_per_grain(euler, pos, lattice_params=latc, strain=strain)
        assert sp.valid.shape[0] == 2 * N

    def test_per_grain_shared_lattice_no_strain(self, nf_geometry, device):
        # Both the nominal (no lattice) path and a single shared lattice
        # must work and avoid the cross-product: leading dim is 2N.
        model, euler, pos, _, _, N = self._setup(nf_geometry, device, strained=False)
        sp_nominal = model.forward_per_grain(euler, pos)
        assert sp_nominal.valid.shape[0] == 2 * N
        shared = torch.tensor([2.87, 2.87, 2.87, 90., 90., 90.], dtype=torch.float64)
        sp_shared = model.forward_per_grain(euler, pos, lattice_params=shared)
        assert sp_shared.valid.shape[0] == 2 * N

    def test_per_grain_differentiable(self, nf_geometry, device):
        model, euler, pos, latc, strain, N = self._setup(nf_geometry, device)
        euler = euler.clone().requires_grad_(True)
        sp = model.forward_per_grain(euler, pos, lattice_params=latc, strain=strain)
        # omega/eta/pixels depend on orientation (two_theta does not)
        (sp.omega * sp.valid).sum().backward()
        assert euler.grad is not None and torch.all(torch.isfinite(euler.grad))

    def test_per_grain_fp32(self, nf_geometry, device):
        model, euler, pos, latc, strain, N = self._setup(nf_geometry, device)
        sp = model.forward_per_grain(euler.float(), pos.float(),
                                     lattice_params=latc.float(), strain=strain.float())
        # The fp32 path must run and give a (2N, M) result whose valid spots
        # carry finite angles (a dtype check alone proves almost nothing).
        assert sp.valid.shape[0] == 2 * N
        mask = sp.valid > 0.5
        assert torch.isfinite(sp.omega[mask]).all()
        assert torch.isfinite(sp.two_theta[mask]).all()

    def test_forward_warns_on_cross_product(self, nf_geometry, device):
        model, euler, pos, latc, strain, N = self._setup(nf_geometry, device)
        with pytest.warns(UserWarning, match="cross-product"):
            model(euler, pos, lattice_params=latc, strain=strain)

    def test_forward_no_warn_shared_lattice(self, nf_geometry, device):
        import warnings as _w
        model, euler, pos, _, _, N = self._setup(nf_geometry, device, strained=False)
        shared = torch.tensor([2.87, 2.87, 2.87, 90., 90., 90.], dtype=torch.float64)
        with _w.catch_warnings(record=True) as caught:
            _w.simplefilter("always")
            model(euler, pos, lattice_params=shared)  # shared -> (2N,M), no warn
        # Only the cross-product guard is under test; unrelated library
        # warnings (e.g. deprecations) must not fail this test.
        assert not any("cross-product" in str(w.message) for w in caught)
1286
+
1287
+
1114
1288
  if __name__ == "__main__":
1115
1289
  pytest.main([__file__, "-v"])
File without changes
File without changes
File without changes