midas-diffract 0.4.0__tar.gz → 0.6.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {midas_diffract-0.4.0 → midas_diffract-0.6.0}/PKG-INFO +1 -1
- {midas_diffract-0.4.0 → midas_diffract-0.6.0}/midas_diffract/__init__.py +1 -1
- {midas_diffract-0.4.0 → midas_diffract-0.6.0}/midas_diffract/forward.py +133 -5
- {midas_diffract-0.4.0 → midas_diffract-0.6.0}/midas_diffract/simulate_panel_zarrs.py +12 -16
- {midas_diffract-0.4.0 → midas_diffract-0.6.0}/midas_diffract.egg-info/PKG-INFO +1 -1
- {midas_diffract-0.4.0 → midas_diffract-0.6.0}/pyproject.toml +1 -1
- {midas_diffract-0.4.0 → midas_diffract-0.6.0}/tests/test_forward.py +84 -0
- {midas_diffract-0.4.0 → midas_diffract-0.6.0}/LICENSE +0 -0
- {midas_diffract-0.4.0 → midas_diffract-0.6.0}/README.md +0 -0
- {midas_diffract-0.4.0 → midas_diffract-0.6.0}/midas_diffract/hkls.py +0 -0
- {midas_diffract-0.4.0 → midas_diffract-0.6.0}/midas_diffract/losses.py +0 -0
- {midas_diffract-0.4.0 → midas_diffract-0.6.0}/midas_diffract/optimize.py +0 -0
- {midas_diffract-0.4.0 → midas_diffract-0.6.0}/midas_diffract.egg-info/SOURCES.txt +0 -0
- {midas_diffract-0.4.0 → midas_diffract-0.6.0}/midas_diffract.egg-info/dependency_links.txt +0 -0
- {midas_diffract-0.4.0 → midas_diffract-0.6.0}/midas_diffract.egg-info/requires.txt +0 -0
- {midas_diffract-0.4.0 → midas_diffract-0.6.0}/midas_diffract.egg-info/top_level.txt +0 -0
- {midas_diffract-0.4.0 → midas_diffract-0.6.0}/setup.cfg +0 -0
- {midas_diffract-0.4.0 → midas_diffract-0.6.0}/tests/test_c_comparison.py +0 -0
- {midas_diffract-0.4.0 → midas_diffract-0.6.0}/tests/test_distortion_layer.py +0 -0
- {midas_diffract-0.4.0 → midas_diffract-0.6.0}/tests/test_hkls.py +0 -0
- {midas_diffract-0.4.0 → midas_diffract-0.6.0}/tests/test_losses.py +0 -0
- {midas_diffract-0.4.0 → midas_diffract-0.6.0}/tests/test_multi_detector.py +0 -0
- {midas_diffract-0.4.0 → midas_diffract-0.6.0}/tests/test_strain_tensor.py +0 -0
- {midas_diffract-0.4.0 → midas_diffract-0.6.0}/tests/test_tilts.py +0 -0
- {midas_diffract-0.4.0 → midas_diffract-0.6.0}/tests/test_wedge.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: midas-diffract
|
|
3
|
-
Version: 0.4.0
|
|
3
|
+
Version: 0.6.0
|
|
4
4
|
Summary: End-to-end differentiable forward model for High-Energy Diffraction Microscopy (FF, NF, pf-HEDM)
|
|
5
5
|
Author-email: Hemant Sharma <hsharma@anl.gov>
|
|
6
6
|
License-Expression: BSD-3-Clause
|
|
@@ -27,6 +27,7 @@ Reference C code:
|
|
|
27
27
|
"""
|
|
28
28
|
|
|
29
29
|
import math
|
|
30
|
+
import warnings
|
|
30
31
|
from dataclasses import dataclass, field
|
|
31
32
|
from typing import Optional, Tuple
|
|
32
33
|
|
|
@@ -738,6 +739,7 @@ class HEDMForwardModel(nn.Module):
|
|
|
738
739
|
orientation_matrices: torch.Tensor,
|
|
739
740
|
hkls_cart: Optional[torch.Tensor] = None,
|
|
740
741
|
thetas: Optional[torch.Tensor] = None,
|
|
742
|
+
per_grain: bool = False,
|
|
741
743
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
742
744
|
"""Core Bragg geometry: orientations + G-vectors -> angles.
|
|
743
745
|
|
|
@@ -779,14 +781,32 @@ class HEDMForwardModel(nn.Module):
|
|
|
779
781
|
# batch, (b) per-voxel hkls_cart shape (..., M, 3) for strained
|
|
780
782
|
# rendering. Both flow through the same einsum via leading-dim
|
|
781
783
|
# broadcasting on the second arg.
|
|
782
|
-
|
|
784
|
+
#
|
|
785
|
+
# per_grain=True: ELEMENT-WISE pairing of grain i's orientation with
|
|
786
|
+
# grain i's strained hkls -- NO orientation x strain cross-product.
|
|
787
|
+
# Requires orientation_matrices (N,3,3); hkls_cart (N,M,3) per-grain or
|
|
788
|
+
# (M,3) shared. Used by forward_per_grain() for the O(N*M) fast path.
|
|
789
|
+
if per_grain:
|
|
790
|
+
if hkls_cart.dim() == 2: # (M,3) shared lattice
|
|
791
|
+
G_C = torch.einsum("nij,mj->nmi", orientation_matrices, hkls_cart)
|
|
792
|
+
else: # (N,M,3) per-grain strain
|
|
793
|
+
G_C = torch.einsum("nij,nmj->nmi", orientation_matrices, hkls_cart)
|
|
794
|
+
else:
|
|
795
|
+
G_C = torch.einsum("...nij,...mj->...nmi", orientation_matrices, hkls_cart)
|
|
783
796
|
|
|
784
797
|
# v = sin(theta)*|G| -- C precomputes Gs from the UNROTATED G-vector norm
|
|
785
798
|
# (rotation preserves norm in exact arithmetic but not in float64).
|
|
786
799
|
# Match C: use |hkls_cart| (pre-rotation), not |R @ hkls_cart|.
|
|
787
800
|
len_hkl = torch.norm(hkls_cart, dim=-1) # (M,) or (..., M)
|
|
788
801
|
v_no_wedge = torch.sin(thetas) * len_hkl # (M,) or (..., M)
|
|
789
|
-
|
|
802
|
+
# Broadcast v to G_C's (..., N, M) grid. Layout-agnostic: in the
|
|
803
|
+
# cross-product layout v gains an orientation axis at -2; in the
|
|
804
|
+
# per_grain / shared layouts it already matches. Numerically identical
|
|
805
|
+
# to the former ``unsqueeze(-2).expand_as`` for those existing paths.
|
|
806
|
+
gc0 = G_C[..., 0]
|
|
807
|
+
while v_no_wedge.dim() < gc0.dim():
|
|
808
|
+
v_no_wedge = v_no_wedge.unsqueeze(-2)
|
|
809
|
+
v_no_wedge = v_no_wedge.expand_as(gc0)
|
|
790
810
|
|
|
791
811
|
# ---- Wedge: rigorous geometric formulation -------------------
|
|
792
812
|
# The rotation axis tilts from z to n_hat = (sin W, 0, cos W).
|
|
@@ -900,9 +920,16 @@ class HEDMForwardModel(nn.Module):
|
|
|
900
920
|
eta = self.safe_arccos(Gz_lab / r_yz)
|
|
901
921
|
eta = -torch.sign(Gy_lab) * eta
|
|
902
922
|
|
|
903
|
-
# 2*theta
|
|
904
|
-
|
|
905
|
-
|
|
923
|
+
# 2*theta -- broadcast to the single-branch (pre-cat) grid, then double
|
|
924
|
+
# along the grain axis to match all_omega's 2N. Layout-agnostic (handles
|
|
925
|
+
# the per_grain layout where the 2N axis is grain-doubled) and
|
|
926
|
+
# numerically identical to the former expand for the cross-product path
|
|
927
|
+
# (thetas does not depend on the orientation axis).
|
|
928
|
+
tt = 2.0 * thetas
|
|
929
|
+
while tt.dim() < omega_p.dim():
|
|
930
|
+
tt = tt.unsqueeze(-2)
|
|
931
|
+
tt = tt.expand_as(omega_p)
|
|
932
|
+
two_theta = torch.cat([tt, tt], dim=-2)
|
|
906
933
|
|
|
907
934
|
# Validity mask
|
|
908
935
|
valid_p = disc_valid & coswp_valid
|
|
@@ -1354,6 +1381,11 @@ class HEDMForwardModel(nn.Module):
|
|
|
1354
1381
|
if positions.shape[-1] == 2:
|
|
1355
1382
|
positions = F.pad(positions, (0, 1), value=0.0)
|
|
1356
1383
|
|
|
1384
|
+
# Footgun guard: per-grain lattice/strain + N>1 orientations forms an
|
|
1385
|
+
# N x N orientation x strain cross-product (output (N, 2N, M)); callers
|
|
1386
|
+
# simulating a fixed polycrystal almost always want only the diagonal.
|
|
1387
|
+
self._warn_if_cross_product(euler_angles, lattice_params, strain)
|
|
1388
|
+
|
|
1357
1389
|
# 1. Orientation matrices
|
|
1358
1390
|
orientation_matrices = self.euler2mat(euler_angles)
|
|
1359
1391
|
|
|
@@ -1382,6 +1414,102 @@ class HEDMForwardModel(nn.Module):
|
|
|
1382
1414
|
|
|
1383
1415
|
return spots
|
|
1384
1416
|
|
|
1417
|
+
@staticmethod
|
|
1418
|
+
def _warn_if_cross_product(euler_angles, lattice_params, strain):
|
|
1419
|
+
"""Warn when forward() would form an orientation x strain cross-product.
|
|
1420
|
+
|
|
1421
|
+
Fires only when there are N>1 orientations AND lattice_params/strain
|
|
1422
|
+
carry a matching per-grain axis -- the case where the (N, 2N, M) output
|
|
1423
|
+
is an N x N cross-product and the caller likely wanted the diagonal.
|
|
1424
|
+
Shared lattice/strain (no grain axis) is the correct (2N, M) path and
|
|
1425
|
+
does not warn.
|
|
1426
|
+
"""
|
|
1427
|
+
n_orient = euler_angles.shape[-2] if euler_angles.dim() >= 2 else 1
|
|
1428
|
+
if n_orient <= 1:
|
|
1429
|
+
return
|
|
1430
|
+
|
|
1431
|
+
def _has_grain_axis(t):
|
|
1432
|
+
if t is None:
|
|
1433
|
+
return False
|
|
1434
|
+
if t.shape[-1] == 6 and t.dim() >= 2: # Voigt lattice or strain
|
|
1435
|
+
return t.shape[-2] == n_orient
|
|
1436
|
+
if tuple(t.shape[-2:]) == (3, 3) and t.dim() >= 3: # full-tensor strain
|
|
1437
|
+
return t.shape[-3] == n_orient
|
|
1438
|
+
return False
|
|
1439
|
+
|
|
1440
|
+
if _has_grain_axis(lattice_params) or _has_grain_axis(strain):
|
|
1441
|
+
warnings.warn(
|
|
1442
|
+
f"forward() called with {n_orient} orientations and per-grain "
|
|
1443
|
+
"lattice_params/strain forms an orientation x strain "
|
|
1444
|
+
f"cross-product (output shape ({n_orient}, {2 * n_orient}, M)); "
|
|
1445
|
+
"only the diagonal [i, i] and [i, i+N] is physical. For a fixed "
|
|
1446
|
+
"polycrystal use forward_per_grain() (element-wise, O(N*M)), "
|
|
1447
|
+
"or index the diagonal of this output.",
|
|
1448
|
+
stacklevel=3,
|
|
1449
|
+
)
|
|
1450
|
+
|
|
1451
|
+
def forward_per_grain(
|
|
1452
|
+
self,
|
|
1453
|
+
euler_angles: torch.Tensor,
|
|
1454
|
+
positions: torch.Tensor,
|
|
1455
|
+
lattice_params: Optional[torch.Tensor] = None,
|
|
1456
|
+
strain: Optional[torch.Tensor] = None,
|
|
1457
|
+
) -> SpotDescriptors:
|
|
1458
|
+
"""Element-wise per-grain forward simulation -- the fast path.
|
|
1459
|
+
|
|
1460
|
+
Grain ``i`` is simulated with orientation ``i``, lattice/strain ``i``
|
|
1461
|
+
and position ``i``, WITHOUT the orientation x strain cross-product that
|
|
1462
|
+
:meth:`forward` forms when both are per-grain. The output has leading
|
|
1463
|
+
shape ``(2N, M)`` (the two omega branches doubled along the grain axis),
|
|
1464
|
+
which is exactly the diagonal of :meth:`forward`'s ``(N, 2N, M)`` output
|
|
1465
|
+
-- so gradient and gradient-free callers agree bit-for-bit.
|
|
1466
|
+
|
|
1467
|
+
Cost is O(N*M) rather than O(N^2 * M), matching the algorithm of the C
|
|
1468
|
+
reference ``ForwardSimulationCompressed.c``. Fully differentiable and
|
|
1469
|
+
device-portable; for pure forward simulation wrap the call in
|
|
1470
|
+
``torch.inference_mode()``.
|
|
1471
|
+
|
|
1472
|
+
Parameters
|
|
1473
|
+
----------
|
|
1474
|
+
euler_angles : Tensor (N, 3)
|
|
1475
|
+
Bunge ZXZ Euler angles (radians), one per grain. No leading batch.
|
|
1476
|
+
positions : Tensor (N, 3) or (N, 2)
|
|
1477
|
+
Real-space grain positions (micrometers).
|
|
1478
|
+
lattice_params : Tensor (6,) or (N, 6), optional
|
|
1479
|
+
Shared or per-grain lattice [a,b,c,alpha,beta,gamma] (Ang/deg).
|
|
1480
|
+
strain : Tensor (6,), (N, 6), (3, 3), or (N, 3, 3), optional
|
|
1481
|
+
Shared or per-grain crystal-frame strain (plain-Voigt or full 3x3).
|
|
1482
|
+
|
|
1483
|
+
Returns
|
|
1484
|
+
-------
|
|
1485
|
+
SpotDescriptors with leading shape ``(2N, M)``: grain ``i``'s two omega
|
|
1486
|
+
solutions live at axis-(-2) indices ``i`` and ``i + N``.
|
|
1487
|
+
"""
|
|
1488
|
+
if positions.shape[-1] == 2:
|
|
1489
|
+
positions = F.pad(positions, (0, 1), value=0.0)
|
|
1490
|
+
|
|
1491
|
+
orientation_matrices = self.euler2mat(euler_angles) # (N, 3, 3)
|
|
1492
|
+
|
|
1493
|
+
hkls_cart = None
|
|
1494
|
+
thetas = None
|
|
1495
|
+
if lattice_params is not None:
|
|
1496
|
+
hkls_cart, thetas = self.correct_hkls_latc(lattice_params, strain=strain)
|
|
1497
|
+
elif strain is not None:
|
|
1498
|
+
raise ValueError(
|
|
1499
|
+
"strain was supplied but lattice_params is None; strain "
|
|
1500
|
+
"requires a reference lattice to apply (I + eps)^{-1} @ B0."
|
|
1501
|
+
)
|
|
1502
|
+
|
|
1503
|
+
omega, eta, two_theta, valid = self.calc_bragg_geometry(
|
|
1504
|
+
orientation_matrices, hkls_cart, thetas, per_grain=True
|
|
1505
|
+
)
|
|
1506
|
+
spots = self.project_to_detector(omega, eta, two_theta, positions, valid)
|
|
1507
|
+
|
|
1508
|
+
if self.scan_config is not None:
|
|
1509
|
+
spots = self.filter_by_scan(spots, positions)
|
|
1510
|
+
|
|
1511
|
+
return spots
|
|
1512
|
+
|
|
1385
1513
|
# ------------------------------------------------------------------
|
|
1386
1514
|
# filter_by_scan (beam proximity for pf-HEDM)
|
|
1387
1515
|
# ------------------------------------------------------------------
|
|
@@ -613,18 +613,13 @@ def simulate_panel_zarrs(
|
|
|
613
613
|
f"BC=({panel.y_bc:.1f},{panel.z_bc:.1f}), "
|
|
614
614
|
f"tx={panel.tx:.2f}° ty={panel.ty:.2f}° tz={panel.tz:.2f}°")
|
|
615
615
|
with torch.no_grad():
|
|
616
|
-
|
|
617
|
-
|
|
618
|
-
|
|
619
|
-
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
diag_mask = np.zeros((G_strain, Kdim, Mdim), dtype=bool)
|
|
624
|
-
for gi in range(G_strain):
|
|
625
|
-
diag_mask[gi, gi, :] = True
|
|
626
|
-
diag_mask[gi, gi + n_grains, :] = True
|
|
627
|
-
valid_np = valid_np & diag_mask
|
|
616
|
+
# Per-grain fast path: O(N*M), no orientation x strain cross-product.
|
|
617
|
+
# Output is (2N, M) -- already the per-grain diagonal, no mask needed.
|
|
618
|
+
spots = model.forward_per_grain(eulers_t, positions_t,
|
|
619
|
+
lattice_params=latc, strain=strain_t)
|
|
620
|
+
|
|
621
|
+
valid_np = (spots.valid > 0.5).cpu().numpy() # (2N, M)
|
|
622
|
+
Kdim, Mdim = valid_np.shape # Kdim == 2 * n_grains
|
|
628
623
|
|
|
629
624
|
# First-wins: drop spots that an earlier panel already took
|
|
630
625
|
# (avoids double-counting in the rare overlap).
|
|
@@ -652,10 +647,11 @@ def simulate_panel_zarrs(
|
|
|
652
647
|
y_pix = spots.y_pixel.cpu().numpy()
|
|
653
648
|
z_pix = spots.z_pixel.cpu().numpy()
|
|
654
649
|
frame_nr = spots.frame_nr.cpu().numpy()
|
|
655
|
-
|
|
656
|
-
|
|
657
|
-
|
|
658
|
-
|
|
650
|
+
# (2N, M): row k holds grain (k % n_grains); rows [0:N]/[N:2N] are the
|
|
651
|
+
# two omega branches.
|
|
652
|
+
grain_ids = np.broadcast_to((np.arange(Kdim) % n_grains)[:, None],
|
|
653
|
+
(Kdim, Mdim))
|
|
654
|
+
hkl_ids = np.broadcast_to(np.arange(Mdim)[None, :], (Kdim, Mdim))
|
|
659
655
|
|
|
660
656
|
rec = {
|
|
661
657
|
"grain_id": grain_ids.reshape(-1)[flat_idx].astype(np.int32),
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: midas-diffract
|
|
3
|
-
Version: 0.4.0
|
|
3
|
+
Version: 0.6.0
|
|
4
4
|
Summary: End-to-end differentiable forward model for High-Energy Diffraction Microscopy (FF, NF, pf-HEDM)
|
|
5
5
|
Author-email: Hemant Sharma <hsharma@anl.gov>
|
|
6
6
|
License-Expression: BSD-3-Clause
|
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "midas-diffract"
|
|
7
|
-
version = "0.4.0"
|
|
7
|
+
version = "0.6.0"
|
|
8
8
|
description = "End-to-end differentiable forward model for High-Energy Diffraction Microscopy (FF, NF, pf-HEDM)"
|
|
9
9
|
readme = "README.md"
|
|
10
10
|
license = "BSD-3-Clause"
|
|
@@ -1201,5 +1201,89 @@ class TestStrainTensorInput:
|
|
|
1201
1201
|
torch.testing.assert_close(sp_v.omega, sp_S.omega)
|
|
1202
1202
|
|
|
1203
1203
|
|
|
1204
|
+
# ===================================================================
|
|
1205
|
+
# Test: forward_per_grain (fast path) == diagonal of forward()
|
|
1206
|
+
# ===================================================================
|
|
1207
|
+
|
|
1208
|
+
class TestForwardPerGrain:
|
|
1209
|
+
"""The O(N*M) per-grain path must equal the diagonal of forward()'s
|
|
1210
|
+
O(N^2*M) orientation x strain cross-product, bit-for-bit, so the
|
|
1211
|
+
gradient and gradient-free callers agree."""
|
|
1212
|
+
|
|
1213
|
+
def _setup(self, nf_geometry, device, N=6, strained=True):
|
|
1214
|
+
model, _, _ = make_model_with_cubic_iron(nf_geometry, device)
|
|
1215
|
+
torch.manual_seed(7)
|
|
1216
|
+
euler = torch.rand(N, 3, dtype=torch.float64) * 2 * math.pi
|
|
1217
|
+
pos = torch.arange(N * 3, dtype=torch.float64).reshape(N, 3) * 5.0
|
|
1218
|
+
latc = torch.tensor([2.87, 2.87, 2.87, 90., 90., 90.],
|
|
1219
|
+
dtype=torch.float64).expand(N, 6)
|
|
1220
|
+
if strained:
|
|
1221
|
+
s = torch.randn(N, 3, 3, dtype=torch.float64) * 1e-3
|
|
1222
|
+
strain = (s + s.transpose(-1, -2)) / 2
|
|
1223
|
+
else:
|
|
1224
|
+
strain = None
|
|
1225
|
+
return model, euler, pos, latc, strain, N
|
|
1226
|
+
|
|
1227
|
+
def test_per_grain_equals_forward_diagonal_strained(self, nf_geometry, device):
|
|
1228
|
+
# Compare per-grain (no global sort, which would reorder near-degenerate
|
|
1229
|
+
# spots): forward() diagonal rows [gi,gi] & [gi,gi+N] vs
|
|
1230
|
+
# forward_per_grain() rows [gi] & [gi+N]. Element-wise einsum differs
|
|
1231
|
+
# from the cross-product einsum only by fp64 reduction order (~1e-12).
|
|
1232
|
+
model, euler, pos, latc, strain, N = self._setup(nf_geometry, device)
|
|
1233
|
+
spf = model(euler, pos, lattice_params=latc, strain=strain) # (N,2N,M)
|
|
1234
|
+
spg = model.forward_per_grain(euler, pos, lattice_params=latc, strain=strain) # (2N,M)
|
|
1235
|
+
|
|
1236
|
+
vf = spf.valid > 0.5
|
|
1237
|
+
vg = spg.valid > 0.5
|
|
1238
|
+
max_diff = 0.0
|
|
1239
|
+
for fld in ("two_theta", "eta", "omega", "y_pixel", "z_pixel", "frame_nr"):
|
|
1240
|
+
F = getattr(spf, fld)
|
|
1241
|
+
G = getattr(spg, fld)
|
|
1242
|
+
for gi in range(N):
|
|
1243
|
+
for kf, kg in ((gi, gi), (gi + N, gi + N)):
|
|
1244
|
+
# validity must agree exactly on the diagonal
|
|
1245
|
+
assert torch.equal(vf[gi, kf], vg[kg])
|
|
1246
|
+
max_diff = max(max_diff, float((F[gi, kf] - G[kg]).abs().max()))
|
|
1247
|
+
assert max_diff < 1e-9, f"per-grain vs forward-diagonal max diff {max_diff:.2e}"
|
|
1248
|
+
|
|
1249
|
+
def test_per_grain_shape_is_2N(self, nf_geometry, device):
|
|
1250
|
+
model, euler, pos, latc, strain, N = self._setup(nf_geometry, device)
|
|
1251
|
+
sp = model.forward_per_grain(euler, pos, lattice_params=latc, strain=strain)
|
|
1252
|
+
assert sp.valid.shape[0] == 2 * N
|
|
1253
|
+
|
|
1254
|
+
def test_per_grain_shared_lattice_no_strain(self, nf_geometry, device):
|
|
1255
|
+
# nominal (no lattice) path must also work and avoid cross-product
|
|
1256
|
+
model, euler, pos, latc, _, N = self._setup(nf_geometry, device, strained=False)
|
|
1257
|
+
sp = model.forward_per_grain(euler, pos)
|
|
1258
|
+
assert sp.valid.shape[0] == 2 * N
|
|
1259
|
+
|
|
1260
|
+
def test_per_grain_differentiable(self, nf_geometry, device):
|
|
1261
|
+
model, euler, pos, latc, strain, N = self._setup(nf_geometry, device)
|
|
1262
|
+
euler = euler.clone().requires_grad_(True)
|
|
1263
|
+
sp = model.forward_per_grain(euler, pos, lattice_params=latc, strain=strain)
|
|
1264
|
+
# omega/eta/pixels depend on orientation (two_theta does not)
|
|
1265
|
+
(sp.omega * sp.valid).sum().backward()
|
|
1266
|
+
assert euler.grad is not None and torch.all(torch.isfinite(euler.grad))
|
|
1267
|
+
|
|
1268
|
+
def test_per_grain_fp32(self, nf_geometry, device):
|
|
1269
|
+
model, euler, pos, latc, strain, N = self._setup(nf_geometry, device)
|
|
1270
|
+
sp = model.forward_per_grain(euler.float(), pos.float(),
|
|
1271
|
+
lattice_params=latc.float(), strain=strain.float())
|
|
1272
|
+
assert sp.valid.dtype == torch.float32 or sp.valid.dtype == torch.float64
|
|
1273
|
+
|
|
1274
|
+
def test_forward_warns_on_cross_product(self, nf_geometry, device):
|
|
1275
|
+
model, euler, pos, latc, strain, N = self._setup(nf_geometry, device)
|
|
1276
|
+
with pytest.warns(UserWarning, match="cross-product"):
|
|
1277
|
+
model(euler, pos, lattice_params=latc, strain=strain)
|
|
1278
|
+
|
|
1279
|
+
def test_forward_no_warn_shared_lattice(self, nf_geometry, device):
|
|
1280
|
+
import warnings as _w
|
|
1281
|
+
model, euler, pos, _, _, N = self._setup(nf_geometry, device, strained=False)
|
|
1282
|
+
shared = torch.tensor([2.87, 2.87, 2.87, 90., 90., 90.], dtype=torch.float64)
|
|
1283
|
+
with _w.catch_warnings():
|
|
1284
|
+
_w.simplefilter("error") # any cross-product warning becomes an error
|
|
1285
|
+
model(euler, pos, lattice_params=shared) # shared -> (2N,M), no warn
|
|
1286
|
+
|
|
1287
|
+
|
|
1204
1288
|
if __name__ == "__main__":
|
|
1205
1289
|
pytest.main([__file__, "-v"])
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|