midas-diffract 0.2.0__tar.gz → 0.6.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {midas_diffract-0.2.0 → midas_diffract-0.6.0}/PKG-INFO +2 -1
- {midas_diffract-0.2.0 → midas_diffract-0.6.0}/midas_diffract/__init__.py +1 -1
- {midas_diffract-0.2.0 → midas_diffract-0.6.0}/midas_diffract/forward.py +202 -46
- {midas_diffract-0.2.0 → midas_diffract-0.6.0}/midas_diffract/simulate_panel_zarrs.py +12 -16
- {midas_diffract-0.2.0 → midas_diffract-0.6.0}/midas_diffract.egg-info/PKG-INFO +2 -1
- {midas_diffract-0.2.0 → midas_diffract-0.6.0}/midas_diffract.egg-info/requires.txt +1 -0
- {midas_diffract-0.2.0 → midas_diffract-0.6.0}/pyproject.toml +2 -1
- {midas_diffract-0.2.0 → midas_diffract-0.6.0}/tests/test_forward.py +174 -0
- {midas_diffract-0.2.0 → midas_diffract-0.6.0}/LICENSE +0 -0
- {midas_diffract-0.2.0 → midas_diffract-0.6.0}/README.md +0 -0
- {midas_diffract-0.2.0 → midas_diffract-0.6.0}/midas_diffract/hkls.py +0 -0
- {midas_diffract-0.2.0 → midas_diffract-0.6.0}/midas_diffract/losses.py +0 -0
- {midas_diffract-0.2.0 → midas_diffract-0.6.0}/midas_diffract/optimize.py +0 -0
- {midas_diffract-0.2.0 → midas_diffract-0.6.0}/midas_diffract.egg-info/SOURCES.txt +0 -0
- {midas_diffract-0.2.0 → midas_diffract-0.6.0}/midas_diffract.egg-info/dependency_links.txt +0 -0
- {midas_diffract-0.2.0 → midas_diffract-0.6.0}/midas_diffract.egg-info/top_level.txt +0 -0
- {midas_diffract-0.2.0 → midas_diffract-0.6.0}/setup.cfg +0 -0
- {midas_diffract-0.2.0 → midas_diffract-0.6.0}/tests/test_c_comparison.py +0 -0
- {midas_diffract-0.2.0 → midas_diffract-0.6.0}/tests/test_distortion_layer.py +0 -0
- {midas_diffract-0.2.0 → midas_diffract-0.6.0}/tests/test_hkls.py +0 -0
- {midas_diffract-0.2.0 → midas_diffract-0.6.0}/tests/test_losses.py +0 -0
- {midas_diffract-0.2.0 → midas_diffract-0.6.0}/tests/test_multi_detector.py +0 -0
- {midas_diffract-0.2.0 → midas_diffract-0.6.0}/tests/test_strain_tensor.py +0 -0
- {midas_diffract-0.2.0 → midas_diffract-0.6.0}/tests/test_tilts.py +0 -0
- {midas_diffract-0.2.0 → midas_diffract-0.6.0}/tests/test_wedge.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: midas-diffract
|
|
3
|
-
Version: 0.2.0
|
|
3
|
+
Version: 0.6.0
|
|
4
4
|
Summary: End-to-end differentiable forward model for High-Energy Diffraction Microscopy (FF, NF, pf-HEDM)
|
|
5
5
|
Author-email: Hemant Sharma <hsharma@anl.gov>
|
|
6
6
|
License-Expression: BSD-3-Clause
|
|
@@ -19,6 +19,7 @@ License-File: LICENSE
|
|
|
19
19
|
Requires-Dist: numpy>=1.22
|
|
20
20
|
Requires-Dist: torch>=2.0
|
|
21
21
|
Requires-Dist: midas-distortion>=0.2.0
|
|
22
|
+
Requires-Dist: midas-stress>=0.8.0
|
|
22
23
|
Provides-Extra: dev
|
|
23
24
|
Requires-Dist: pytest>=7.0; extra == "dev"
|
|
24
25
|
Requires-Dist: pytest-cov; extra == "dev"
|
|
@@ -27,6 +27,7 @@ Reference C code:
|
|
|
27
27
|
"""
|
|
28
28
|
|
|
29
29
|
import math
|
|
30
|
+
import warnings
|
|
30
31
|
from dataclasses import dataclass, field
|
|
31
32
|
from typing import Optional, Tuple
|
|
32
33
|
|
|
@@ -34,6 +35,16 @@ import torch
|
|
|
34
35
|
import torch.nn as nn
|
|
35
36
|
import torch.nn.functional as F
|
|
36
37
|
|
|
38
|
+
# Canonical orientation + strain-frame primitives. midas_stress is the single
|
|
39
|
+
# source of truth for this math (Bunge ZXZ orientation algebra, sample<->crystal
|
|
40
|
+
# strain rotation); its torch backend is differentiable end-to-end and
|
|
41
|
+
# device-portable, so the forward model delegates rather than re-porting.
|
|
42
|
+
# NOTE: midas_stress's Voigt is Voigt-MANDEL (sqrt2 on shears); this model uses
|
|
43
|
+
# PLAIN-Voigt / raw 3x3 strain, so we only delegate the *rotation*, never the
|
|
44
|
+
# Voigt packing (see rotate_strain_sample_to_crystal / correct_hkls_latc).
|
|
45
|
+
from midas_stress.orientation import euler_to_orient_mat as _ms_euler_to_orient_mat
|
|
46
|
+
from midas_stress.tensor import strain_lab_to_grain as _ms_strain_lab_to_grain
|
|
47
|
+
|
|
37
48
|
|
|
38
49
|
# ---------------------------------------------------------------------------
|
|
39
50
|
# Configuration data classes
|
|
@@ -437,7 +448,13 @@ class HEDMForwardModel(nn.Module):
|
|
|
437
448
|
|
|
438
449
|
@staticmethod
|
|
439
450
|
def euler2mat(euler_angles: torch.Tensor) -> torch.Tensor:
|
|
440
|
-
"""Convert ZXZ Euler angles to rotation matrices.
|
|
451
|
+
"""Convert ZXZ (Bunge) Euler angles to crystal->sample rotation matrices.
|
|
452
|
+
|
|
453
|
+
Delegates to ``midas_stress.orientation.euler_to_orient_mat`` -- the
|
|
454
|
+
canonical orientation primitive -- so the convention can never drift
|
|
455
|
+
from the rest of MIDAS. midas_stress's torch backend is differentiable
|
|
456
|
+
and vmap-safe; the result is identical to the former in-line ZXZ build
|
|
457
|
+
(R = Rz(phi1) @ Rx(Phi) @ Rz(phi2)) to ~1e-16.
|
|
441
458
|
|
|
442
459
|
Parameters
|
|
443
460
|
----------
|
|
@@ -447,35 +464,14 @@ class HEDMForwardModel(nn.Module):
|
|
|
447
464
|
Returns
|
|
448
465
|
-------
|
|
449
466
|
Tensor (..., 3, 3)
|
|
450
|
-
Rotation matrices.
|
|
467
|
+
Rotation matrices (crystal->sample), orthogonalized onto SO(3).
|
|
451
468
|
"""
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
# ZXZ rotation matrix: R = Rz(phi1) @ Rx(Phi) @ Rz(phi2)
|
|
459
|
-
# Verified element-by-element against nfhedm.py lines 114-120.
|
|
460
|
-
# Built via torch.stack rather than indexed assignment so the function
|
|
461
|
-
# composes with torch.func.vmap (in-place writes block vmap).
|
|
462
|
-
row0 = torch.stack([
|
|
463
|
-
c0 * c2 - s0 * c1 * s2,
|
|
464
|
-
-s0 * c1 * c2 - c0 * s2,
|
|
465
|
-
s0 * s1,
|
|
466
|
-
], dim=-1)
|
|
467
|
-
row1 = torch.stack([
|
|
468
|
-
s0 * c2 + c0 * c1 * s2,
|
|
469
|
-
c0 * c1 * c2 - s0 * s2,
|
|
470
|
-
-c0 * s1,
|
|
471
|
-
], dim=-1)
|
|
472
|
-
row2 = torch.stack([
|
|
473
|
-
s1 * s2,
|
|
474
|
-
s1 * c2,
|
|
475
|
-
c1,
|
|
476
|
-
], dim=-1)
|
|
477
|
-
R = torch.stack([row0, row1, row2], dim=-2)
|
|
478
|
-
|
|
469
|
+
if not isinstance(euler_angles, torch.Tensor):
|
|
470
|
+
euler_angles = torch.as_tensor(euler_angles)
|
|
471
|
+
R = _ms_euler_to_orient_mat(euler_angles) # (..., 9), torch
|
|
472
|
+
R = R.reshape(*R.shape[:-1], 3, 3)
|
|
473
|
+
# midas_stress already returns a proper rotation; orthogonalize keeps the
|
|
474
|
+
# historical "exactly on SO(3)" guarantee and is idempotent here.
|
|
479
475
|
return HEDMForwardModel.orthogonalize(R)
|
|
480
476
|
|
|
481
477
|
# ------------------------------------------------------------------
|
|
@@ -525,6 +521,28 @@ class HEDMForwardModel(nn.Module):
|
|
|
525
521
|
"""Numerically stable arccos: clamp to [-1+eps, 1-eps]."""
|
|
526
522
|
return torch.acos(torch.clamp(x, -1.0 + self.epsilon, 1.0 - self.epsilon))
|
|
527
523
|
|
|
524
|
+
# ------------------------------------------------------------------
|
|
525
|
+
# strain_as_voigt (accept full 3x3 tensor OR plain-Voigt 6-vector)
|
|
526
|
+
# ------------------------------------------------------------------
|
|
527
|
+
|
|
528
|
+
@staticmethod
|
|
529
|
+
def strain_as_voigt(strain: torch.Tensor) -> torch.Tensor:
|
|
530
|
+
"""Normalize a strain input to PLAIN-Voigt [e11,e12,e13,e22,e23,e33].
|
|
531
|
+
|
|
532
|
+
Accepts either a plain-Voigt ``(..., 6)`` tensor (returned unchanged) or
|
|
533
|
+
a full symmetric ``(..., 3, 3)`` strain tensor. The 3x3 path is
|
|
534
|
+
convention-free -- the natural way to hand a strain field straight from
|
|
535
|
+
a tensor source (e.g. a midas_stress strain field) into the forward
|
|
536
|
+
model without picking a Voigt/Mandel packing. The off-diagonals are
|
|
537
|
+
taken as TRUE tensor components (no factor of 2).
|
|
538
|
+
"""
|
|
539
|
+
if strain.dim() >= 2 and strain.shape[-1] == 3 and strain.shape[-2] == 3:
|
|
540
|
+
return torch.stack([
|
|
541
|
+
strain[..., 0, 0], strain[..., 0, 1], strain[..., 0, 2],
|
|
542
|
+
strain[..., 1, 1], strain[..., 1, 2], strain[..., 2, 2],
|
|
543
|
+
], dim=-1)
|
|
544
|
+
return strain
|
|
545
|
+
|
|
528
546
|
# ------------------------------------------------------------------
|
|
529
547
|
# rotate_strain_sample_to_crystal (port of C RotateStrainSampleToCrystal)
|
|
530
548
|
# ------------------------------------------------------------------
|
|
@@ -536,19 +554,27 @@ class HEDMForwardModel(nn.Module):
|
|
|
536
554
|
) -> torch.Tensor:
|
|
537
555
|
"""Rotate a symmetric infinitesimal strain from sample to crystal frame.
|
|
538
556
|
|
|
539
|
-
|
|
557
|
+
Matches ``RotateStrainSampleToCrystal`` from
|
|
540
558
|
``FF_HEDM/src/ForwardSimulationCompressed.c:399-419``:
|
|
541
|
-
eps_crystal = OM^T . eps_sample . OM
|
|
542
|
-
|
|
559
|
+
eps_crystal = OM^T . eps_sample . OM.
|
|
560
|
+
|
|
561
|
+
The rotation is delegated to ``midas_stress.tensor.strain_lab_to_grain``
|
|
562
|
+
(the canonical sample/lab -> crystal/grain strain transform; bit-identical
|
|
563
|
+
to the former in-line ``OM^T S OM``). PLAIN-Voigt pack/unpack is kept here
|
|
564
|
+
on purpose -- midas_stress's Voigt is Mandel (sqrt2 on shears), which must
|
|
565
|
+
not touch the forward model's strain input.
|
|
543
566
|
|
|
544
567
|
Parameters
|
|
545
568
|
----------
|
|
546
569
|
orientation_matrices : Tensor (..., 3, 3)
|
|
570
|
+
Crystal->sample matrices (as returned by :meth:`euler2mat`).
|
|
547
571
|
strain_sample : Tensor (..., 6)
|
|
572
|
+
PLAIN-Voigt [eps_11, eps_12, eps_13, eps_22, eps_23, eps_33].
|
|
548
573
|
|
|
549
574
|
Returns
|
|
550
575
|
-------
|
|
551
576
|
strain_crystal : Tensor (..., 6)
|
|
577
|
+
PLAIN-Voigt, same layout.
|
|
552
578
|
"""
|
|
553
579
|
e = strain_sample
|
|
554
580
|
S = torch.stack([
|
|
@@ -556,8 +582,7 @@ class HEDMForwardModel(nn.Module):
|
|
|
556
582
|
torch.stack([e[..., 1], e[..., 3], e[..., 4]], dim=-1),
|
|
557
583
|
torch.stack([e[..., 2], e[..., 4], e[..., 5]], dim=-1),
|
|
558
584
|
], dim=-2)
|
|
559
|
-
|
|
560
|
-
C = torch.matmul(torch.matmul(OM.transpose(-1, -2), S), OM)
|
|
585
|
+
C = _ms_strain_lab_to_grain(S, orientation_matrices) # OM^T S OM
|
|
561
586
|
return torch.stack([
|
|
562
587
|
C[..., 0, 0], C[..., 0, 1], C[..., 0, 2],
|
|
563
588
|
C[..., 1, 1], C[..., 1, 2], C[..., 2, 2],
|
|
@@ -587,10 +612,12 @@ class HEDMForwardModel(nn.Module):
|
|
|
587
612
|
lattice_params : Tensor (..., 6)
|
|
588
613
|
[a, b, c, alpha, beta, gamma] in Angstroms and degrees.
|
|
589
614
|
The ``...`` dimensions allow per-voxel or per-grain parameters.
|
|
590
|
-
strain : Tensor (..., 6), optional
|
|
591
|
-
Crystal-frame symmetric infinitesimal strain
|
|
592
|
-
[eps_11, eps_12, eps_13, eps_22, eps_23, eps_33]
|
|
593
|
-
|
|
615
|
+
strain : Tensor (..., 6) or (..., 3, 3), optional
|
|
616
|
+
Crystal-frame symmetric infinitesimal strain, either PLAIN-Voigt
|
|
617
|
+
[eps_11, eps_12, eps_13, eps_22, eps_23, eps_33] or a full symmetric
|
|
618
|
+
3x3 tensor (normalized via :meth:`strain_as_voigt`; off-diagonals are
|
|
619
|
+
true tensor components, no factor of 2). When supplied, the
|
|
620
|
+
reciprocal lattice is post-multiplied by (I + eps)^{-1}:
|
|
594
621
|
B = (I + eps)^{-1} @ B0. Use :meth:`rotate_strain_sample_to_crystal`
|
|
595
622
|
to convert a sample-frame strain into the crystal frame.
|
|
596
623
|
|
|
@@ -671,6 +698,7 @@ class HEDMForwardModel(nn.Module):
|
|
|
671
698
|
# Voigt layout matches C CorrectHKLsLatCEpsilon:
|
|
672
699
|
# eps = [eps_11, eps_12, eps_13, eps_22, eps_23, eps_33]
|
|
673
700
|
if strain is not None:
|
|
701
|
+
strain = self.strain_as_voigt(strain) # accept full 3x3 too
|
|
674
702
|
e11 = strain[..., 0]
|
|
675
703
|
e12 = strain[..., 1]
|
|
676
704
|
e13 = strain[..., 2]
|
|
@@ -711,6 +739,7 @@ class HEDMForwardModel(nn.Module):
|
|
|
711
739
|
orientation_matrices: torch.Tensor,
|
|
712
740
|
hkls_cart: Optional[torch.Tensor] = None,
|
|
713
741
|
thetas: Optional[torch.Tensor] = None,
|
|
742
|
+
per_grain: bool = False,
|
|
714
743
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
715
744
|
"""Core Bragg geometry: orientations + G-vectors -> angles.
|
|
716
745
|
|
|
@@ -752,14 +781,32 @@ class HEDMForwardModel(nn.Module):
|
|
|
752
781
|
# batch, (b) per-voxel hkls_cart shape (..., M, 3) for strained
|
|
753
782
|
# rendering. Both flow through the same einsum via leading-dim
|
|
754
783
|
# broadcasting on the second arg.
|
|
755
|
-
|
|
784
|
+
#
|
|
785
|
+
# per_grain=True: ELEMENT-WISE pairing of grain i's orientation with
|
|
786
|
+
# grain i's strained hkls -- NO orientation x strain cross-product.
|
|
787
|
+
# Requires orientation_matrices (N,3,3); hkls_cart (N,M,3) per-grain or
|
|
788
|
+
# (M,3) shared. Used by forward_per_grain() for the O(N*M) fast path.
|
|
789
|
+
if per_grain:
|
|
790
|
+
if hkls_cart.dim() == 2: # (M,3) shared lattice
|
|
791
|
+
G_C = torch.einsum("nij,mj->nmi", orientation_matrices, hkls_cart)
|
|
792
|
+
else: # (N,M,3) per-grain strain
|
|
793
|
+
G_C = torch.einsum("nij,nmj->nmi", orientation_matrices, hkls_cart)
|
|
794
|
+
else:
|
|
795
|
+
G_C = torch.einsum("...nij,...mj->...nmi", orientation_matrices, hkls_cart)
|
|
756
796
|
|
|
757
797
|
# v = sin(theta)*|G| -- C precomputes Gs from the UNROTATED G-vector norm
|
|
758
798
|
# (rotation preserves norm in exact arithmetic but not in float64).
|
|
759
799
|
# Match C: use |hkls_cart| (pre-rotation), not |R @ hkls_cart|.
|
|
760
800
|
len_hkl = torch.norm(hkls_cart, dim=-1) # (M,) or (..., M)
|
|
761
801
|
v_no_wedge = torch.sin(thetas) * len_hkl # (M,) or (..., M)
|
|
762
|
-
|
|
802
|
+
# Broadcast v to G_C's (..., N, M) grid. Layout-agnostic: in the
|
|
803
|
+
# cross-product layout v gains an orientation axis at -2; in the
|
|
804
|
+
# per_grain / shared layouts it already matches. Numerically identical
|
|
805
|
+
# to the former ``unsqueeze(-2).expand_as`` for those existing paths.
|
|
806
|
+
gc0 = G_C[..., 0]
|
|
807
|
+
while v_no_wedge.dim() < gc0.dim():
|
|
808
|
+
v_no_wedge = v_no_wedge.unsqueeze(-2)
|
|
809
|
+
v_no_wedge = v_no_wedge.expand_as(gc0)
|
|
763
810
|
|
|
764
811
|
# ---- Wedge: rigorous geometric formulation -------------------
|
|
765
812
|
# The rotation axis tilts from z to n_hat = (sin W, 0, cos W).
|
|
@@ -873,9 +920,16 @@ class HEDMForwardModel(nn.Module):
|
|
|
873
920
|
eta = self.safe_arccos(Gz_lab / r_yz)
|
|
874
921
|
eta = -torch.sign(Gy_lab) * eta
|
|
875
922
|
|
|
876
|
-
# 2*theta
|
|
877
|
-
|
|
878
|
-
|
|
923
|
+
# 2*theta -- broadcast to the single-branch (pre-cat) grid, then double
|
|
924
|
+
# along the grain axis to match all_omega's 2N. Layout-agnostic (handles
|
|
925
|
+
# the per_grain layout where the 2N axis is grain-doubled) and
|
|
926
|
+
# numerically identical to the former expand for the cross-product path
|
|
927
|
+
# (thetas does not depend on the orientation axis).
|
|
928
|
+
tt = 2.0 * thetas
|
|
929
|
+
while tt.dim() < omega_p.dim():
|
|
930
|
+
tt = tt.unsqueeze(-2)
|
|
931
|
+
tt = tt.expand_as(omega_p)
|
|
932
|
+
two_theta = torch.cat([tt, tt], dim=-2)
|
|
879
933
|
|
|
880
934
|
# Validity mask
|
|
881
935
|
valid_p = disc_valid & coswp_valid
|
|
@@ -1311,9 +1365,10 @@ class HEDMForwardModel(nn.Module):
|
|
|
1311
1365
|
lattice_params : Tensor (..., 6) or (..., N, 6), optional
|
|
1312
1366
|
Strained lattice parameters [a,b,c,alpha,beta,gamma] in
|
|
1313
1367
|
Angstroms/degrees. None = use nominal hkls/thetas (no strain).
|
|
1314
|
-
strain : Tensor (..., 6)
|
|
1315
|
-
Crystal-frame symmetric infinitesimal strain
|
|
1316
|
-
[eps_11, eps_12, eps_13, eps_22, eps_23, eps_33]
|
|
1368
|
+
strain : Tensor (..., 6), (..., N, 6), or (..., 3, 3), optional
|
|
1369
|
+
Crystal-frame symmetric infinitesimal strain, either PLAIN-Voigt
|
|
1370
|
+
[eps_11, eps_12, eps_13, eps_22, eps_23, eps_33] or a full symmetric
|
|
1371
|
+
3x3 tensor (see :meth:`strain_as_voigt`). Applied as
|
|
1317
1372
|
B = (I + eps)^{-1} @ B0 in addition to any lattice-parameter
|
|
1318
1373
|
strain expressed through ``lattice_params``. Requires
|
|
1319
1374
|
``lattice_params`` to be supplied.
|
|
@@ -1326,6 +1381,11 @@ class HEDMForwardModel(nn.Module):
|
|
|
1326
1381
|
if positions.shape[-1] == 2:
|
|
1327
1382
|
positions = F.pad(positions, (0, 1), value=0.0)
|
|
1328
1383
|
|
|
1384
|
+
# Footgun guard: per-grain lattice/strain + N>1 orientations forms an
|
|
1385
|
+
# N x N orientation x strain cross-product (output (N, 2N, M)); callers
|
|
1386
|
+
# simulating a fixed polycrystal almost always want only the diagonal.
|
|
1387
|
+
self._warn_if_cross_product(euler_angles, lattice_params, strain)
|
|
1388
|
+
|
|
1329
1389
|
# 1. Orientation matrices
|
|
1330
1390
|
orientation_matrices = self.euler2mat(euler_angles)
|
|
1331
1391
|
|
|
@@ -1354,6 +1414,102 @@ class HEDMForwardModel(nn.Module):
|
|
|
1354
1414
|
|
|
1355
1415
|
return spots
|
|
1356
1416
|
|
|
1417
|
+
@staticmethod
|
|
1418
|
+
def _warn_if_cross_product(euler_angles, lattice_params, strain):
|
|
1419
|
+
"""Warn when forward() would form an orientation x strain cross-product.
|
|
1420
|
+
|
|
1421
|
+
Fires only when there are N>1 orientations AND lattice_params/strain
|
|
1422
|
+
carry a matching per-grain axis -- the case where the (N, 2N, M) output
|
|
1423
|
+
is an N x N cross-product and the caller likely wanted the diagonal.
|
|
1424
|
+
Shared lattice/strain (no grain axis) is the correct (2N, M) path and
|
|
1425
|
+
does not warn.
|
|
1426
|
+
"""
|
|
1427
|
+
n_orient = euler_angles.shape[-2] if euler_angles.dim() >= 2 else 1
|
|
1428
|
+
if n_orient <= 1:
|
|
1429
|
+
return
|
|
1430
|
+
|
|
1431
|
+
def _has_grain_axis(t):
|
|
1432
|
+
if t is None:
|
|
1433
|
+
return False
|
|
1434
|
+
if t.shape[-1] == 6 and t.dim() >= 2: # Voigt lattice or strain
|
|
1435
|
+
return t.shape[-2] == n_orient
|
|
1436
|
+
if tuple(t.shape[-2:]) == (3, 3) and t.dim() >= 3: # full-tensor strain
|
|
1437
|
+
return t.shape[-3] == n_orient
|
|
1438
|
+
return False
|
|
1439
|
+
|
|
1440
|
+
if _has_grain_axis(lattice_params) or _has_grain_axis(strain):
|
|
1441
|
+
warnings.warn(
|
|
1442
|
+
f"forward() called with {n_orient} orientations and per-grain "
|
|
1443
|
+
"lattice_params/strain forms an orientation x strain "
|
|
1444
|
+
f"cross-product (output shape ({n_orient}, {2 * n_orient}, M)); "
|
|
1445
|
+
"only the diagonal [i, i] and [i, i+N] is physical. For a fixed "
|
|
1446
|
+
"polycrystal use forward_per_grain() (element-wise, O(N*M)), "
|
|
1447
|
+
"or index the diagonal of this output.",
|
|
1448
|
+
stacklevel=3,
|
|
1449
|
+
)
|
|
1450
|
+
|
|
1451
|
+
def forward_per_grain(
|
|
1452
|
+
self,
|
|
1453
|
+
euler_angles: torch.Tensor,
|
|
1454
|
+
positions: torch.Tensor,
|
|
1455
|
+
lattice_params: Optional[torch.Tensor] = None,
|
|
1456
|
+
strain: Optional[torch.Tensor] = None,
|
|
1457
|
+
) -> SpotDescriptors:
|
|
1458
|
+
"""Element-wise per-grain forward simulation -- the fast path.
|
|
1459
|
+
|
|
1460
|
+
Grain ``i`` is simulated with orientation ``i``, lattice/strain ``i``
|
|
1461
|
+
and position ``i``, WITHOUT the orientation x strain cross-product that
|
|
1462
|
+
:meth:`forward` forms when both are per-grain. The output has leading
|
|
1463
|
+
shape ``(2N, M)`` (the two omega branches doubled along the grain axis),
|
|
1464
|
+
which is exactly the diagonal of :meth:`forward`'s ``(N, 2N, M)`` output
|
|
1465
|
+
-- so gradient and gradient-free callers agree bit-for-bit.
|
|
1466
|
+
|
|
1467
|
+
Cost is O(N*M) rather than O(N^2 * M), matching the algorithm of the C
|
|
1468
|
+
reference ``ForwardSimulationCompressed.c``. Fully differentiable and
|
|
1469
|
+
device-portable; for pure forward simulation wrap the call in
|
|
1470
|
+
``torch.inference_mode()``.
|
|
1471
|
+
|
|
1472
|
+
Parameters
|
|
1473
|
+
----------
|
|
1474
|
+
euler_angles : Tensor (N, 3)
|
|
1475
|
+
Bunge ZXZ Euler angles (radians), one per grain. No leading batch.
|
|
1476
|
+
positions : Tensor (N, 3) or (N, 2)
|
|
1477
|
+
Real-space grain positions (micrometers).
|
|
1478
|
+
lattice_params : Tensor (6,) or (N, 6), optional
|
|
1479
|
+
Shared or per-grain lattice [a,b,c,alpha,beta,gamma] (Ang/deg).
|
|
1480
|
+
strain : Tensor (6,), (N, 6), (3, 3), or (N, 3, 3), optional
|
|
1481
|
+
Shared or per-grain crystal-frame strain (plain-Voigt or full 3x3).
|
|
1482
|
+
|
|
1483
|
+
Returns
|
|
1484
|
+
-------
|
|
1485
|
+
SpotDescriptors with leading shape ``(2N, M)``: grain ``i``'s two omega
|
|
1486
|
+
solutions live at axis-(-2) indices ``i`` and ``i + N``.
|
|
1487
|
+
"""
|
|
1488
|
+
if positions.shape[-1] == 2:
|
|
1489
|
+
positions = F.pad(positions, (0, 1), value=0.0)
|
|
1490
|
+
|
|
1491
|
+
orientation_matrices = self.euler2mat(euler_angles) # (N, 3, 3)
|
|
1492
|
+
|
|
1493
|
+
hkls_cart = None
|
|
1494
|
+
thetas = None
|
|
1495
|
+
if lattice_params is not None:
|
|
1496
|
+
hkls_cart, thetas = self.correct_hkls_latc(lattice_params, strain=strain)
|
|
1497
|
+
elif strain is not None:
|
|
1498
|
+
raise ValueError(
|
|
1499
|
+
"strain was supplied but lattice_params is None; strain "
|
|
1500
|
+
"requires a reference lattice to apply (I + eps)^{-1} @ B0."
|
|
1501
|
+
)
|
|
1502
|
+
|
|
1503
|
+
omega, eta, two_theta, valid = self.calc_bragg_geometry(
|
|
1504
|
+
orientation_matrices, hkls_cart, thetas, per_grain=True
|
|
1505
|
+
)
|
|
1506
|
+
spots = self.project_to_detector(omega, eta, two_theta, positions, valid)
|
|
1507
|
+
|
|
1508
|
+
if self.scan_config is not None:
|
|
1509
|
+
spots = self.filter_by_scan(spots, positions)
|
|
1510
|
+
|
|
1511
|
+
return spots
|
|
1512
|
+
|
|
1357
1513
|
# ------------------------------------------------------------------
|
|
1358
1514
|
# filter_by_scan (beam proximity for pf-HEDM)
|
|
1359
1515
|
# ------------------------------------------------------------------
|
|
@@ -613,18 +613,13 @@ def simulate_panel_zarrs(
|
|
|
613
613
|
f"BC=({panel.y_bc:.1f},{panel.z_bc:.1f}), "
|
|
614
614
|
f"tx={panel.tx:.2f}° ty={panel.ty:.2f}° tz={panel.tz:.2f}°")
|
|
615
615
|
with torch.no_grad():
|
|
616
|
-
|
|
617
|
-
|
|
618
|
-
|
|
619
|
-
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
diag_mask = np.zeros((G_strain, Kdim, Mdim), dtype=bool)
|
|
624
|
-
for gi in range(G_strain):
|
|
625
|
-
diag_mask[gi, gi, :] = True
|
|
626
|
-
diag_mask[gi, gi + n_grains, :] = True
|
|
627
|
-
valid_np = valid_np & diag_mask
|
|
616
|
+
# Per-grain fast path: O(N*M), no orientation x strain cross-product.
|
|
617
|
+
# Output is (2N, M) -- already the per-grain diagonal, no mask needed.
|
|
618
|
+
spots = model.forward_per_grain(eulers_t, positions_t,
|
|
619
|
+
lattice_params=latc, strain=strain_t)
|
|
620
|
+
|
|
621
|
+
valid_np = (spots.valid > 0.5).cpu().numpy() # (2N, M)
|
|
622
|
+
Kdim, Mdim = valid_np.shape # Kdim == 2 * n_grains
|
|
628
623
|
|
|
629
624
|
# First-wins: drop spots that an earlier panel already took
|
|
630
625
|
# (avoids double-counting in the rare overlap).
|
|
@@ -652,10 +647,11 @@ def simulate_panel_zarrs(
|
|
|
652
647
|
y_pix = spots.y_pixel.cpu().numpy()
|
|
653
648
|
z_pix = spots.z_pixel.cpu().numpy()
|
|
654
649
|
frame_nr = spots.frame_nr.cpu().numpy()
|
|
655
|
-
|
|
656
|
-
|
|
657
|
-
|
|
658
|
-
|
|
650
|
+
# (2N, M): row k holds grain (k % n_grains); rows [0:N]/[N:2N] are the
|
|
651
|
+
# two omega branches.
|
|
652
|
+
grain_ids = np.broadcast_to((np.arange(Kdim) % n_grains)[:, None],
|
|
653
|
+
(Kdim, Mdim))
|
|
654
|
+
hkl_ids = np.broadcast_to(np.arange(Mdim)[None, :], (Kdim, Mdim))
|
|
659
655
|
|
|
660
656
|
rec = {
|
|
661
657
|
"grain_id": grain_ids.reshape(-1)[flat_idx].astype(np.int32),
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: midas-diffract
|
|
3
|
-
Version: 0.2.0
|
|
3
|
+
Version: 0.6.0
|
|
4
4
|
Summary: End-to-end differentiable forward model for High-Energy Diffraction Microscopy (FF, NF, pf-HEDM)
|
|
5
5
|
Author-email: Hemant Sharma <hsharma@anl.gov>
|
|
6
6
|
License-Expression: BSD-3-Clause
|
|
@@ -19,6 +19,7 @@ License-File: LICENSE
|
|
|
19
19
|
Requires-Dist: numpy>=1.22
|
|
20
20
|
Requires-Dist: torch>=2.0
|
|
21
21
|
Requires-Dist: midas-distortion>=0.2.0
|
|
22
|
+
Requires-Dist: midas-stress>=0.8.0
|
|
22
23
|
Provides-Extra: dev
|
|
23
24
|
Requires-Dist: pytest>=7.0; extra == "dev"
|
|
24
25
|
Requires-Dist: pytest-cov; extra == "dev"
|
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "midas-diffract"
|
|
7
|
-
version = "0.2.0"
|
|
7
|
+
version = "0.6.0"
|
|
8
8
|
description = "End-to-end differentiable forward model for High-Energy Diffraction Microscopy (FF, NF, pf-HEDM)"
|
|
9
9
|
readme = "README.md"
|
|
10
10
|
license = "BSD-3-Clause"
|
|
@@ -27,6 +27,7 @@ dependencies = [
|
|
|
27
27
|
"numpy>=1.22",
|
|
28
28
|
"torch>=2.0",
|
|
29
29
|
"midas-distortion>=0.2.0",
|
|
30
|
+
"midas-stress>=0.8.0",
|
|
30
31
|
]
|
|
31
32
|
|
|
32
33
|
[project.optional-dependencies]
|
|
@@ -1111,5 +1111,179 @@ class TestCrossValidation:
|
|
|
1111
1111
|
import torch.nn.functional as F # for test_2d_position_backward_compat
|
|
1112
1112
|
|
|
1113
1113
|
|
|
1114
|
+
# ===================================================================
|
|
1115
|
+
# Test: midas_stress is the canonical orientation/strain-frame source
|
|
1116
|
+
# ===================================================================
|
|
1117
|
+
|
|
1118
|
+
class TestMidasStressDelegation:
|
|
1119
|
+
"""Guards that euler2mat / rotate_strain_sample_to_crystal delegate to
|
|
1120
|
+
midas_stress and never silently re-port (convention drift)."""
|
|
1121
|
+
|
|
1122
|
+
def test_euler2mat_matches_midas_stress(self):
|
|
1123
|
+
from midas_stress.orientation import euler_to_orient_mat
|
|
1124
|
+
torch.manual_seed(0)
|
|
1125
|
+
eul = torch.rand(7, 3, dtype=torch.float64) * 2 * math.pi
|
|
1126
|
+
R = HEDMForwardModel.euler2mat(eul)
|
|
1127
|
+
R_ms = euler_to_orient_mat(eul).reshape(7, 3, 3)
|
|
1128
|
+
torch.testing.assert_close(R, R_ms, atol=1e-12, rtol=0)
|
|
1129
|
+
|
|
1130
|
+
def test_euler2mat_vmap_safe(self):
|
|
1131
|
+
from torch.func import vmap
|
|
1132
|
+
eul = torch.rand(5, 3, dtype=torch.float64) * 2 * math.pi
|
|
1133
|
+
R = vmap(HEDMForwardModel.euler2mat)(eul)
|
|
1134
|
+
assert R.shape == (5, 3, 3)
|
|
1135
|
+
torch.testing.assert_close(R, HEDMForwardModel.euler2mat(eul), atol=1e-12, rtol=0)
|
|
1136
|
+
|
|
1137
|
+
def test_euler2mat_differentiable(self):
|
|
1138
|
+
eul = torch.rand(3, 3, dtype=torch.float64, requires_grad=True)
|
|
1139
|
+
HEDMForwardModel.euler2mat(eul).pow(2).sum().backward()
|
|
1140
|
+
assert eul.grad is not None and torch.all(torch.isfinite(eul.grad))
|
|
1141
|
+
|
|
1142
|
+
def test_rotate_strain_matches_midas_stress(self):
|
|
1143
|
+
from midas_stress.tensor import strain_lab_to_grain
|
|
1144
|
+
torch.manual_seed(1)
|
|
1145
|
+
OM = HEDMForwardModel.euler2mat(torch.rand(4, 3, dtype=torch.float64))
|
|
1146
|
+
v = torch.rand(4, 6, dtype=torch.float64) * 1e-3 # plain Voigt
|
|
1147
|
+
out = HEDMForwardModel.rotate_strain_sample_to_crystal(OM, v)
|
|
1148
|
+
# reference: build 3x3, rotate via midas_stress, repack plain Voigt
|
|
1149
|
+
S = torch.zeros(4, 3, 3, dtype=torch.float64)
|
|
1150
|
+
S[:, 0, 0], S[:, 1, 1], S[:, 2, 2] = v[:, 0], v[:, 3], v[:, 5]
|
|
1151
|
+
S[:, 0, 1] = S[:, 1, 0] = v[:, 1]
|
|
1152
|
+
S[:, 0, 2] = S[:, 2, 0] = v[:, 2]
|
|
1153
|
+
S[:, 1, 2] = S[:, 2, 1] = v[:, 4]
|
|
1154
|
+
C = strain_lab_to_grain(S, OM)
|
|
1155
|
+
ref = torch.stack([C[:, 0, 0], C[:, 0, 1], C[:, 0, 2],
|
|
1156
|
+
C[:, 1, 1], C[:, 1, 2], C[:, 2, 2]], dim=-1)
|
|
1157
|
+
torch.testing.assert_close(out, ref, atol=1e-14, rtol=0)
|
|
1158
|
+
|
|
1159
|
+
|
|
1160
|
+
# ===================================================================
|
|
1161
|
+
# Test: full 3x3 strain tensor entry point (convention-free)
|
|
1162
|
+
# ===================================================================
|
|
1163
|
+
|
|
1164
|
+
class TestStrainTensorInput:
|
|
1165
|
+
"""A full symmetric 3x3 strain must give the same result as the
|
|
1166
|
+
equivalent plain-Voigt 6-vector (NOT Voigt-Mandel)."""
|
|
1167
|
+
|
|
1168
|
+
def test_strain_as_voigt_roundtrip(self):
|
|
1169
|
+
v = torch.tensor([1e-3, 2e-4, 3e-4, -5e-4, 1e-4, 2e-4])
|
|
1170
|
+
S = torch.tensor([[1e-3, 2e-4, 3e-4],
|
|
1171
|
+
[2e-4, -5e-4, 1e-4],
|
|
1172
|
+
[3e-4, 1e-4, 2e-4]])
|
|
1173
|
+
torch.testing.assert_close(HEDMForwardModel.strain_as_voigt(S), v)
|
|
1174
|
+
# plain Voigt passes through unchanged
|
|
1175
|
+
torch.testing.assert_close(HEDMForwardModel.strain_as_voigt(v), v)
|
|
1176
|
+
|
|
1177
|
+
def test_3x3_strain_equals_voigt_in_correct_hkls_latc(self, nf_geometry, device):
|
|
1178
|
+
model, _, _ = make_model_with_cubic_iron(nf_geometry, device)
|
|
1179
|
+
latc = torch.tensor([2.87, 2.87, 2.87, 90., 90., 90.])
|
|
1180
|
+
v = torch.tensor([1e-3, 2e-4, 3e-4, -5e-4, 1e-4, 2e-4])
|
|
1181
|
+
S = torch.tensor([[1e-3, 2e-4, 3e-4],
|
|
1182
|
+
[2e-4, -5e-4, 1e-4],
|
|
1183
|
+
[3e-4, 1e-4, 2e-4]])
|
|
1184
|
+
g_v, t_v = model.correct_hkls_latc(latc, strain=v)
|
|
1185
|
+
g_S, t_S = model.correct_hkls_latc(latc, strain=S)
|
|
1186
|
+
torch.testing.assert_close(g_v, g_S)
|
|
1187
|
+
torch.testing.assert_close(t_v, t_S)
|
|
1188
|
+
|
|
1189
|
+
def test_3x3_strain_equals_voigt_in_forward(self, nf_geometry, device):
|
|
1190
|
+
model, _, _ = make_model_with_cubic_iron(nf_geometry, device)
|
|
1191
|
+
eul = torch.tensor([[0.3, 0.5, 1.1]])
|
|
1192
|
+
pos = torch.zeros(1, 3)
|
|
1193
|
+
latc = torch.tensor([2.87, 2.87, 2.87, 90., 90., 90.]).expand(1, 6)
|
|
1194
|
+
v = torch.tensor([[1e-3, 2e-4, 3e-4, -5e-4, 1e-4, 2e-4]])
|
|
1195
|
+
S = torch.tensor([[[1e-3, 2e-4, 3e-4],
|
|
1196
|
+
[2e-4, -5e-4, 1e-4],
|
|
1197
|
+
[3e-4, 1e-4, 2e-4]]])
|
|
1198
|
+
sp_v = model(eul, pos, lattice_params=latc, strain=v)
|
|
1199
|
+
sp_S = model(eul, pos, lattice_params=latc, strain=S)
|
|
1200
|
+
torch.testing.assert_close(sp_v.two_theta, sp_S.two_theta)
|
|
1201
|
+
torch.testing.assert_close(sp_v.omega, sp_S.omega)
|
|
1202
|
+
|
|
1203
|
+
|
|
1204
|
+
# ===================================================================
|
|
1205
|
+
# Test: forward_per_grain (fast path) == diagonal of forward()
|
|
1206
|
+
# ===================================================================
|
|
1207
|
+
|
|
1208
|
+
class TestForwardPerGrain:
|
|
1209
|
+
"""The O(N*M) per-grain path must equal the diagonal of forward()'s
|
|
1210
|
+
O(N^2*M) orientation x strain cross-product, bit-for-bit, so the
|
|
1211
|
+
gradient and gradient-free callers agree."""
|
|
1212
|
+
|
|
1213
|
+
def _setup(self, nf_geometry, device, N=6, strained=True):
|
|
1214
|
+
model, _, _ = make_model_with_cubic_iron(nf_geometry, device)
|
|
1215
|
+
torch.manual_seed(7)
|
|
1216
|
+
euler = torch.rand(N, 3, dtype=torch.float64) * 2 * math.pi
|
|
1217
|
+
pos = torch.arange(N * 3, dtype=torch.float64).reshape(N, 3) * 5.0
|
|
1218
|
+
latc = torch.tensor([2.87, 2.87, 2.87, 90., 90., 90.],
|
|
1219
|
+
dtype=torch.float64).expand(N, 6)
|
|
1220
|
+
if strained:
|
|
1221
|
+
s = torch.randn(N, 3, 3, dtype=torch.float64) * 1e-3
|
|
1222
|
+
strain = (s + s.transpose(-1, -2)) / 2
|
|
1223
|
+
else:
|
|
1224
|
+
strain = None
|
|
1225
|
+
return model, euler, pos, latc, strain, N
|
|
1226
|
+
|
|
1227
|
+
def test_per_grain_equals_forward_diagonal_strained(self, nf_geometry, device):
|
|
1228
|
+
# Compare per-grain (no global sort, which would reorder near-degenerate
|
|
1229
|
+
# spots): forward() diagonal rows [gi,gi] & [gi,gi+N] vs
|
|
1230
|
+
# forward_per_grain() rows [gi] & [gi+N]. Element-wise einsum differs
|
|
1231
|
+
# from the cross-product einsum only by fp64 reduction order (~1e-12).
|
|
1232
|
+
model, euler, pos, latc, strain, N = self._setup(nf_geometry, device)
|
|
1233
|
+
spf = model(euler, pos, lattice_params=latc, strain=strain) # (N,2N,M)
|
|
1234
|
+
spg = model.forward_per_grain(euler, pos, lattice_params=latc, strain=strain) # (2N,M)
|
|
1235
|
+
|
|
1236
|
+
vf = spf.valid > 0.5
|
|
1237
|
+
vg = spg.valid > 0.5
|
|
1238
|
+
max_diff = 0.0
|
|
1239
|
+
for fld in ("two_theta", "eta", "omega", "y_pixel", "z_pixel", "frame_nr"):
|
|
1240
|
+
F = getattr(spf, fld)
|
|
1241
|
+
G = getattr(spg, fld)
|
|
1242
|
+
for gi in range(N):
|
|
1243
|
+
for kf, kg in ((gi, gi), (gi + N, gi + N)):
|
|
1244
|
+
# validity must agree exactly on the diagonal
|
|
1245
|
+
assert torch.equal(vf[gi, kf], vg[kg])
|
|
1246
|
+
max_diff = max(max_diff, float((F[gi, kf] - G[kg]).abs().max()))
|
|
1247
|
+
assert max_diff < 1e-9, f"per-grain vs forward-diagonal max diff {max_diff:.2e}"
|
|
1248
|
+
|
|
1249
|
+
def test_per_grain_shape_is_2N(self, nf_geometry, device):
|
|
1250
|
+
model, euler, pos, latc, strain, N = self._setup(nf_geometry, device)
|
|
1251
|
+
sp = model.forward_per_grain(euler, pos, lattice_params=latc, strain=strain)
|
|
1252
|
+
assert sp.valid.shape[0] == 2 * N
|
|
1253
|
+
|
|
1254
|
+
def test_per_grain_shared_lattice_no_strain(self, nf_geometry, device):
|
|
1255
|
+
# nominal (no lattice) path must also work and avoid cross-product
|
|
1256
|
+
model, euler, pos, latc, _, N = self._setup(nf_geometry, device, strained=False)
|
|
1257
|
+
sp = model.forward_per_grain(euler, pos)
|
|
1258
|
+
assert sp.valid.shape[0] == 2 * N
|
|
1259
|
+
|
|
1260
|
+
def test_per_grain_differentiable(self, nf_geometry, device):
|
|
1261
|
+
model, euler, pos, latc, strain, N = self._setup(nf_geometry, device)
|
|
1262
|
+
euler = euler.clone().requires_grad_(True)
|
|
1263
|
+
sp = model.forward_per_grain(euler, pos, lattice_params=latc, strain=strain)
|
|
1264
|
+
# omega/eta/pixels depend on orientation (two_theta does not)
|
|
1265
|
+
(sp.omega * sp.valid).sum().backward()
|
|
1266
|
+
assert euler.grad is not None and torch.all(torch.isfinite(euler.grad))
|
|
1267
|
+
|
|
1268
|
+
def test_per_grain_fp32(self, nf_geometry, device):
|
|
1269
|
+
model, euler, pos, latc, strain, N = self._setup(nf_geometry, device)
|
|
1270
|
+
sp = model.forward_per_grain(euler.float(), pos.float(),
|
|
1271
|
+
lattice_params=latc.float(), strain=strain.float())
|
|
1272
|
+
assert sp.valid.dtype == torch.float32 or sp.valid.dtype == torch.float64
|
|
1273
|
+
|
|
1274
|
+
def test_forward_warns_on_cross_product(self, nf_geometry, device):
|
|
1275
|
+
model, euler, pos, latc, strain, N = self._setup(nf_geometry, device)
|
|
1276
|
+
with pytest.warns(UserWarning, match="cross-product"):
|
|
1277
|
+
model(euler, pos, lattice_params=latc, strain=strain)
|
|
1278
|
+
|
|
1279
|
+
def test_forward_no_warn_shared_lattice(self, nf_geometry, device):
|
|
1280
|
+
import warnings as _w
|
|
1281
|
+
model, euler, pos, _, _, N = self._setup(nf_geometry, device, strained=False)
|
|
1282
|
+
shared = torch.tensor([2.87, 2.87, 2.87, 90., 90., 90.], dtype=torch.float64)
|
|
1283
|
+
with _w.catch_warnings():
|
|
1284
|
+
_w.simplefilter("error") # any cross-product warning becomes an error
|
|
1285
|
+
model(euler, pos, lattice_params=shared) # shared -> (2N,M), no warn
|
|
1286
|
+
|
|
1287
|
+
|
|
1114
1288
|
if __name__ == "__main__":
|
|
1115
1289
|
pytest.main([__file__, "-v"])
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|