midas-diffract 0.1.2__tar.gz → 0.4.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {midas_diffract-0.1.2 → midas_diffract-0.4.0}/PKG-INFO +3 -1
- {midas_diffract-0.1.2 → midas_diffract-0.4.0}/midas_diffract/__init__.py +1 -1
- {midas_diffract-0.1.2 → midas_diffract-0.4.0}/midas_diffract/forward.py +129 -41
- {midas_diffract-0.1.2 → midas_diffract-0.4.0}/midas_diffract.egg-info/PKG-INFO +3 -1
- {midas_diffract-0.1.2 → midas_diffract-0.4.0}/midas_diffract.egg-info/SOURCES.txt +1 -0
- {midas_diffract-0.1.2 → midas_diffract-0.4.0}/midas_diffract.egg-info/requires.txt +2 -0
- {midas_diffract-0.1.2 → midas_diffract-0.4.0}/pyproject.toml +3 -1
- midas_diffract-0.4.0/tests/test_distortion_layer.py +78 -0
- {midas_diffract-0.1.2 → midas_diffract-0.4.0}/tests/test_forward.py +90 -0
- {midas_diffract-0.1.2 → midas_diffract-0.4.0}/LICENSE +0 -0
- {midas_diffract-0.1.2 → midas_diffract-0.4.0}/README.md +0 -0
- {midas_diffract-0.1.2 → midas_diffract-0.4.0}/midas_diffract/hkls.py +0 -0
- {midas_diffract-0.1.2 → midas_diffract-0.4.0}/midas_diffract/losses.py +0 -0
- {midas_diffract-0.1.2 → midas_diffract-0.4.0}/midas_diffract/optimize.py +0 -0
- {midas_diffract-0.1.2 → midas_diffract-0.4.0}/midas_diffract/simulate_panel_zarrs.py +0 -0
- {midas_diffract-0.1.2 → midas_diffract-0.4.0}/midas_diffract.egg-info/dependency_links.txt +0 -0
- {midas_diffract-0.1.2 → midas_diffract-0.4.0}/midas_diffract.egg-info/top_level.txt +0 -0
- {midas_diffract-0.1.2 → midas_diffract-0.4.0}/setup.cfg +0 -0
- {midas_diffract-0.1.2 → midas_diffract-0.4.0}/tests/test_c_comparison.py +0 -0
- {midas_diffract-0.1.2 → midas_diffract-0.4.0}/tests/test_hkls.py +0 -0
- {midas_diffract-0.1.2 → midas_diffract-0.4.0}/tests/test_losses.py +0 -0
- {midas_diffract-0.1.2 → midas_diffract-0.4.0}/tests/test_multi_detector.py +0 -0
- {midas_diffract-0.1.2 → midas_diffract-0.4.0}/tests/test_strain_tensor.py +0 -0
- {midas_diffract-0.1.2 → midas_diffract-0.4.0}/tests/test_tilts.py +0 -0
- {midas_diffract-0.1.2 → midas_diffract-0.4.0}/tests/test_wedge.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: midas-diffract
|
|
3
|
-
Version: 0.1.2
|
|
3
|
+
Version: 0.4.0
|
|
4
4
|
Summary: End-to-end differentiable forward model for High-Energy Diffraction Microscopy (FF, NF, pf-HEDM)
|
|
5
5
|
Author-email: Hemant Sharma <hsharma@anl.gov>
|
|
6
6
|
License-Expression: BSD-3-Clause
|
|
@@ -18,6 +18,8 @@ Description-Content-Type: text/markdown
|
|
|
18
18
|
License-File: LICENSE
|
|
19
19
|
Requires-Dist: numpy>=1.22
|
|
20
20
|
Requires-Dist: torch>=2.0
|
|
21
|
+
Requires-Dist: midas-distortion>=0.2.0
|
|
22
|
+
Requires-Dist: midas-stress>=0.8.0
|
|
21
23
|
Provides-Extra: dev
|
|
22
24
|
Requires-Dist: pytest>=7.0; extra == "dev"
|
|
23
25
|
Requires-Dist: pytest-cov; extra == "dev"
|
|
@@ -34,6 +34,16 @@ import torch
|
|
|
34
34
|
import torch.nn as nn
|
|
35
35
|
import torch.nn.functional as F
|
|
36
36
|
|
|
37
|
+
# Canonical orientation + strain-frame primitives. midas_stress is the single
|
|
38
|
+
# source of truth for this math (Bunge ZXZ orientation algebra, sample<->crystal
|
|
39
|
+
# strain rotation); its torch backend is differentiable end-to-end and
|
|
40
|
+
# device-portable, so the forward model delegates rather than re-porting.
|
|
41
|
+
# NOTE: midas_stress's Voigt is Voigt-MANDEL (sqrt2 on shears); this model uses
|
|
42
|
+
# PLAIN-Voigt / raw 3x3 strain, so we only delegate the *rotation*, never the
|
|
43
|
+
# Voigt packing (see rotate_strain_sample_to_crystal / correct_hkls_latc).
|
|
44
|
+
from midas_stress.orientation import euler_to_orient_mat as _ms_euler_to_orient_mat
|
|
45
|
+
from midas_stress.tensor import strain_lab_to_grain as _ms_strain_lab_to_grain
|
|
46
|
+
|
|
37
47
|
|
|
38
48
|
# ---------------------------------------------------------------------------
|
|
39
49
|
# Configuration data classes
|
|
@@ -89,6 +99,16 @@ class HEDMGeometry:
|
|
|
89
99
|
# Default False preserves the existing
|
|
90
100
|
# behaviour: NF applies tilts, FF skips.
|
|
91
101
|
# Set True for raw multi-panel simulation.
|
|
102
|
+
# Radial detector distortion (canonical midas_distortion v2 model). Like
|
|
103
|
+
# tilts, this maps an IDEAL prediction to the RAW detector position and is
|
|
104
|
+
# OFF by default -- the FF/pf experimental pipeline pre-corrects distortion
|
|
105
|
+
# at peak-finding time (transforms), so the forward must NOT re-apply it for
|
|
106
|
+
# the indexer/fit-grain. Raw-pixel-patch consumers (pf_odf, grain_odf) that
|
|
107
|
+
# never go through transforms set ``apply_distortion=True`` and supply the
|
|
108
|
+
# calibrated v2 coefficients to predict in the raw frame.
|
|
109
|
+
apply_distortion: bool = False
|
|
110
|
+
p_distortion: "list[float] | None" = None # 15 v2 coeffs (midas_distortion P_COEF_NAMES order); None/zeros => no-op
|
|
111
|
+
rho_d: "float | None" = None # distortion radius normalization (um); None => resolve from detector corner
|
|
92
112
|
multi_mode: str = "layered" # "layered" (default): NF semantics --
|
|
93
113
|
# spot must land on the detector at
|
|
94
114
|
# EVERY distance (AllDistsFound).
|
|
@@ -342,6 +362,25 @@ class HEDMForwardModel(nn.Module):
|
|
|
342
362
|
|
|
343
363
|
# Multi-detector / multi-panel configuration
|
|
344
364
|
self.apply_tilts = bool(geometry.apply_tilts)
|
|
365
|
+
|
|
366
|
+
# Radial detector distortion (canonical midas_distortion v2 model),
|
|
367
|
+
# applied ideal->raw in project_to_detector when apply_distortion=True.
|
|
368
|
+
# OFF by default => identity => indexer/fit-grain output unchanged.
|
|
369
|
+
self.apply_distortion = bool(getattr(geometry, "apply_distortion", False))
|
|
370
|
+
p_dist = getattr(geometry, "p_distortion", None)
|
|
371
|
+
if p_dist is not None:
|
|
372
|
+
self.p_distortion = nn.Parameter(
|
|
373
|
+
torch.as_tensor(p_dist, dtype=torch.float64, device=device),
|
|
374
|
+
requires_grad=False,
|
|
375
|
+
)
|
|
376
|
+
self._has_distortion = bool(
|
|
377
|
+
torch.any(torch.abs(self.p_distortion.detach()) > 0.0).item()
|
|
378
|
+
)
|
|
379
|
+
else:
|
|
380
|
+
self.p_distortion = None
|
|
381
|
+
self._has_distortion = False
|
|
382
|
+
self.rho_d = getattr(geometry, "rho_d", None)
|
|
383
|
+
|
|
345
384
|
if geometry.multi_mode not in ("layered", "panel"):
|
|
346
385
|
raise ValueError(
|
|
347
386
|
f"Unknown multi_mode {geometry.multi_mode!r}; "
|
|
@@ -408,7 +447,13 @@ class HEDMForwardModel(nn.Module):
|
|
|
408
447
|
|
|
409
448
|
@staticmethod
|
|
410
449
|
def euler2mat(euler_angles: torch.Tensor) -> torch.Tensor:
|
|
411
|
-
"""Convert ZXZ Euler angles to rotation matrices.
|
|
450
|
+
"""Convert ZXZ (Bunge) Euler angles to crystal->sample rotation matrices.
|
|
451
|
+
|
|
452
|
+
Delegates to ``midas_stress.orientation.euler_to_orient_mat`` -- the
|
|
453
|
+
canonical orientation primitive -- so the convention can never drift
|
|
454
|
+
from the rest of MIDAS. midas_stress's torch backend is differentiable
|
|
455
|
+
and vmap-safe; the result is identical to the former in-line ZXZ build
|
|
456
|
+
(R = Rz(phi1) @ Rx(Phi) @ Rz(phi2)) to ~1e-16.
|
|
412
457
|
|
|
413
458
|
Parameters
|
|
414
459
|
----------
|
|
@@ -418,35 +463,14 @@ class HEDMForwardModel(nn.Module):
|
|
|
418
463
|
Returns
|
|
419
464
|
-------
|
|
420
465
|
Tensor (..., 3, 3)
|
|
421
|
-
Rotation matrices.
|
|
466
|
+
Rotation matrices (crystal->sample), orthogonalized onto SO(3).
|
|
422
467
|
"""
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
# ZXZ rotation matrix: R = Rz(phi1) @ Rx(Phi) @ Rz(phi2)
|
|
430
|
-
# Verified element-by-element against nfhedm.py lines 114-120.
|
|
431
|
-
# Built via torch.stack rather than indexed assignment so the function
|
|
432
|
-
# composes with torch.func.vmap (in-place writes block vmap).
|
|
433
|
-
row0 = torch.stack([
|
|
434
|
-
c0 * c2 - s0 * c1 * s2,
|
|
435
|
-
-s0 * c1 * c2 - c0 * s2,
|
|
436
|
-
s0 * s1,
|
|
437
|
-
], dim=-1)
|
|
438
|
-
row1 = torch.stack([
|
|
439
|
-
s0 * c2 + c0 * c1 * s2,
|
|
440
|
-
c0 * c1 * c2 - s0 * s2,
|
|
441
|
-
-c0 * s1,
|
|
442
|
-
], dim=-1)
|
|
443
|
-
row2 = torch.stack([
|
|
444
|
-
s1 * s2,
|
|
445
|
-
s1 * c2,
|
|
446
|
-
c1,
|
|
447
|
-
], dim=-1)
|
|
448
|
-
R = torch.stack([row0, row1, row2], dim=-2)
|
|
449
|
-
|
|
468
|
+
if not isinstance(euler_angles, torch.Tensor):
|
|
469
|
+
euler_angles = torch.as_tensor(euler_angles)
|
|
470
|
+
R = _ms_euler_to_orient_mat(euler_angles) # (..., 9), torch
|
|
471
|
+
R = R.reshape(*R.shape[:-1], 3, 3)
|
|
472
|
+
# midas_stress already returns a proper rotation; orthogonalize keeps the
|
|
473
|
+
# historical "exactly on SO(3)" guarantee and is idempotent here.
|
|
450
474
|
return HEDMForwardModel.orthogonalize(R)
|
|
451
475
|
|
|
452
476
|
# ------------------------------------------------------------------
|
|
@@ -496,6 +520,28 @@ class HEDMForwardModel(nn.Module):
|
|
|
496
520
|
"""Numerically stable arccos: clamp to [-1+eps, 1-eps]."""
|
|
497
521
|
return torch.acos(torch.clamp(x, -1.0 + self.epsilon, 1.0 - self.epsilon))
|
|
498
522
|
|
|
523
|
+
# ------------------------------------------------------------------
|
|
524
|
+
# strain_as_voigt (accept full 3x3 tensor OR plain-Voigt 6-vector)
|
|
525
|
+
# ------------------------------------------------------------------
|
|
526
|
+
|
|
527
|
+
@staticmethod
|
|
528
|
+
def strain_as_voigt(strain: torch.Tensor) -> torch.Tensor:
|
|
529
|
+
"""Normalize a strain input to PLAIN-Voigt [e11,e12,e13,e22,e23,e33].
|
|
530
|
+
|
|
531
|
+
Accepts either a plain-Voigt ``(..., 6)`` tensor (returned unchanged) or
|
|
532
|
+
a full symmetric ``(..., 3, 3)`` strain tensor. The 3x3 path is
|
|
533
|
+
convention-free -- the natural way to hand a strain field straight from
|
|
534
|
+
a tensor source (e.g. a midas_stress strain field) into the forward
|
|
535
|
+
model without picking a Voigt/Mandel packing. The off-diagonals are
|
|
536
|
+
taken as TRUE tensor components (no factor of 2).
|
|
537
|
+
"""
|
|
538
|
+
if strain.dim() >= 2 and strain.shape[-1] == 3 and strain.shape[-2] == 3:
|
|
539
|
+
return torch.stack([
|
|
540
|
+
strain[..., 0, 0], strain[..., 0, 1], strain[..., 0, 2],
|
|
541
|
+
strain[..., 1, 1], strain[..., 1, 2], strain[..., 2, 2],
|
|
542
|
+
], dim=-1)
|
|
543
|
+
return strain
|
|
544
|
+
|
|
499
545
|
# ------------------------------------------------------------------
|
|
500
546
|
# rotate_strain_sample_to_crystal (port of C RotateStrainSampleToCrystal)
|
|
501
547
|
# ------------------------------------------------------------------
|
|
@@ -507,19 +553,27 @@ class HEDMForwardModel(nn.Module):
|
|
|
507
553
|
) -> torch.Tensor:
|
|
508
554
|
"""Rotate a symmetric infinitesimal strain from sample to crystal frame.
|
|
509
555
|
|
|
510
|
-
|
|
556
|
+
Matches ``RotateStrainSampleToCrystal`` from
|
|
511
557
|
``FF_HEDM/src/ForwardSimulationCompressed.c:399-419``:
|
|
512
|
-
eps_crystal = OM^T . eps_sample . OM
|
|
513
|
-
|
|
558
|
+
eps_crystal = OM^T . eps_sample . OM.
|
|
559
|
+
|
|
560
|
+
The rotation is delegated to ``midas_stress.tensor.strain_lab_to_grain``
|
|
561
|
+
(the canonical sample/lab -> crystal/grain strain transform; bit-identical
|
|
562
|
+
to the former in-line ``OM^T S OM``). PLAIN-Voigt pack/unpack is kept here
|
|
563
|
+
on purpose -- midas_stress's Voigt is Mandel (sqrt2 on shears), which must
|
|
564
|
+
not touch the forward model's strain input.
|
|
514
565
|
|
|
515
566
|
Parameters
|
|
516
567
|
----------
|
|
517
568
|
orientation_matrices : Tensor (..., 3, 3)
|
|
569
|
+
Crystal->sample matrices (as returned by :meth:`euler2mat`).
|
|
518
570
|
strain_sample : Tensor (..., 6)
|
|
571
|
+
PLAIN-Voigt [eps_11, eps_12, eps_13, eps_22, eps_23, eps_33].
|
|
519
572
|
|
|
520
573
|
Returns
|
|
521
574
|
-------
|
|
522
575
|
strain_crystal : Tensor (..., 6)
|
|
576
|
+
PLAIN-Voigt, same layout.
|
|
523
577
|
"""
|
|
524
578
|
e = strain_sample
|
|
525
579
|
S = torch.stack([
|
|
@@ -527,8 +581,7 @@ class HEDMForwardModel(nn.Module):
|
|
|
527
581
|
torch.stack([e[..., 1], e[..., 3], e[..., 4]], dim=-1),
|
|
528
582
|
torch.stack([e[..., 2], e[..., 4], e[..., 5]], dim=-1),
|
|
529
583
|
], dim=-2)
|
|
530
|
-
|
|
531
|
-
C = torch.matmul(torch.matmul(OM.transpose(-1, -2), S), OM)
|
|
584
|
+
C = _ms_strain_lab_to_grain(S, orientation_matrices) # OM^T S OM
|
|
532
585
|
return torch.stack([
|
|
533
586
|
C[..., 0, 0], C[..., 0, 1], C[..., 0, 2],
|
|
534
587
|
C[..., 1, 1], C[..., 1, 2], C[..., 2, 2],
|
|
@@ -558,10 +611,12 @@ class HEDMForwardModel(nn.Module):
|
|
|
558
611
|
lattice_params : Tensor (..., 6)
|
|
559
612
|
[a, b, c, alpha, beta, gamma] in Angstroms and degrees.
|
|
560
613
|
The ``...`` dimensions allow per-voxel or per-grain parameters.
|
|
561
|
-
strain : Tensor (..., 6), optional
|
|
562
|
-
Crystal-frame symmetric infinitesimal strain
|
|
563
|
-
[eps_11, eps_12, eps_13, eps_22, eps_23, eps_33]
|
|
564
|
-
|
|
614
|
+
strain : Tensor (..., 6) or (..., 3, 3), optional
|
|
615
|
+
Crystal-frame symmetric infinitesimal strain, either PLAIN-Voigt
|
|
616
|
+
[eps_11, eps_12, eps_13, eps_22, eps_23, eps_33] or a full symmetric
|
|
617
|
+
3x3 tensor (normalized via :meth:`strain_as_voigt`; off-diagonals are
|
|
618
|
+
true tensor components, no factor of 2). When supplied, the
|
|
619
|
+
reciprocal lattice is post-multiplied by (I + eps)^{-1}:
|
|
565
620
|
B = (I + eps)^{-1} @ B0. Use :meth:`rotate_strain_sample_to_crystal`
|
|
566
621
|
to convert a sample-frame strain into the crystal frame.
|
|
567
622
|
|
|
@@ -642,6 +697,7 @@ class HEDMForwardModel(nn.Module):
|
|
|
642
697
|
# Voigt layout matches C CorrectHKLsLatCEpsilon:
|
|
643
698
|
# eps = [eps_11, eps_12, eps_13, eps_22, eps_23, eps_33]
|
|
644
699
|
if strain is not None:
|
|
700
|
+
strain = self.strain_as_voigt(strain) # accept full 3x3 too
|
|
645
701
|
e11 = strain[..., 0]
|
|
646
702
|
e12 = strain[..., 1]
|
|
647
703
|
e13 = strain[..., 2]
|
|
@@ -1163,6 +1219,37 @@ class HEDMForwardModel(nn.Module):
|
|
|
1163
1219
|
ydet_d = torch.stack(out_y, dim=0)
|
|
1164
1220
|
zdet_d = torch.stack(out_z, dim=0)
|
|
1165
1221
|
|
|
1222
|
+
# Ideal->raw radial distortion (canonical midas_distortion v2 model),
|
|
1223
|
+
# applied on the BC-relative detector-plane coords (um) before pixel
|
|
1224
|
+
# conversion -- mirrors midas_calibrate_v2.forward.geometry. Gated OFF
|
|
1225
|
+
# by default so the indexer/fit-grain (ideal frame) are byte-unchanged;
|
|
1226
|
+
# raw-patch consumers (pf_odf, grain_odf) opt in. R is BC-relative, so
|
|
1227
|
+
# the distortion is frame-flip invariant; the eta convention (and any
|
|
1228
|
+
# phase offset between the calibration frame and this frame) is the one
|
|
1229
|
+
# thing to validate empirically (see implementation_plan_distortion_layer).
|
|
1230
|
+
if self.apply_distortion and self._has_distortion:
|
|
1231
|
+
from midas_distortion import apply_distortion as _apply_dist, \
|
|
1232
|
+
v2_term_layout as _v2_terms, resolve_rho_d_um as _resolve_rho_d
|
|
1233
|
+
eps = torch.tensor(1e-9, dtype=ydet_d.dtype, device=ydet_d.device)
|
|
1234
|
+
R = torch.sqrt(ydet_d * ydet_d + zdet_d * zdet_d).clamp(min=eps)
|
|
1235
|
+
# eta convention matches calibrate_v2 forward: atan2(-y, z), degrees.
|
|
1236
|
+
eta_deg_d = self.RAD2DEG * torch.atan2(-ydet_d, zdet_d)
|
|
1237
|
+
# resolve_rho_d_um passes a supplied rho_d through, or computes the
|
|
1238
|
+
# max BC-relative corner distance (um) when None.
|
|
1239
|
+
rho_d_val, _rho_how = _resolve_rho_d(
|
|
1240
|
+
self.rho_d,
|
|
1241
|
+
NrPixelsY=self.n_pixels_y, NrPixelsZ=self.n_pixels_z,
|
|
1242
|
+
BC_y=float(self._y_BC.reshape(-1)[0]),
|
|
1243
|
+
BC_z=float(self._z_BC.reshape(-1)[0]),
|
|
1244
|
+
pxY=self.px,
|
|
1245
|
+
)
|
|
1246
|
+
rho_d_t = torch.as_tensor(float(rho_d_val), dtype=R.dtype, device=R.device)
|
|
1247
|
+
p_v2 = self.p_distortion.to(R.dtype)
|
|
1248
|
+
R_corr = _apply_dist(R, eta_deg_d, p_v2, rho_d_t, terms=_v2_terms())
|
|
1249
|
+
scale = R_corr / R
|
|
1250
|
+
ydet_d = ydet_d * scale
|
|
1251
|
+
zdet_d = zdet_d * scale
|
|
1252
|
+
|
|
1166
1253
|
# FF/PF: y-axis on detector flipped (yBC - ydet/px), validated against C
|
|
1167
1254
|
# NF: not flipped (yBC + ydet/px), validated against C
|
|
1168
1255
|
y_sign = -1.0 if self.flip_y else 1.0
|
|
@@ -1251,9 +1338,10 @@ class HEDMForwardModel(nn.Module):
|
|
|
1251
1338
|
lattice_params : Tensor (..., 6) or (..., N, 6), optional
|
|
1252
1339
|
Strained lattice parameters [a,b,c,alpha,beta,gamma] in
|
|
1253
1340
|
Angstroms/degrees. None = use nominal hkls/thetas (no strain).
|
|
1254
|
-
strain : Tensor (..., 6)
|
|
1255
|
-
Crystal-frame symmetric infinitesimal strain
|
|
1256
|
-
[eps_11, eps_12, eps_13, eps_22, eps_23, eps_33]
|
|
1341
|
+
strain : Tensor (..., 6), (..., N, 6), or (..., 3, 3), optional
|
|
1342
|
+
Crystal-frame symmetric infinitesimal strain, either PLAIN-Voigt
|
|
1343
|
+
[eps_11, eps_12, eps_13, eps_22, eps_23, eps_33] or a full symmetric
|
|
1344
|
+
3x3 tensor (see :meth:`strain_as_voigt`). Applied as
|
|
1257
1345
|
B = (I + eps)^{-1} @ B0 in addition to any lattice-parameter
|
|
1258
1346
|
strain expressed through ``lattice_params``. Requires
|
|
1259
1347
|
``lattice_params`` to be supplied.
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: midas-diffract
|
|
3
|
-
Version: 0.1.2
|
|
3
|
+
Version: 0.4.0
|
|
4
4
|
Summary: End-to-end differentiable forward model for High-Energy Diffraction Microscopy (FF, NF, pf-HEDM)
|
|
5
5
|
Author-email: Hemant Sharma <hsharma@anl.gov>
|
|
6
6
|
License-Expression: BSD-3-Clause
|
|
@@ -18,6 +18,8 @@ Description-Content-Type: text/markdown
|
|
|
18
18
|
License-File: LICENSE
|
|
19
19
|
Requires-Dist: numpy>=1.22
|
|
20
20
|
Requires-Dist: torch>=2.0
|
|
21
|
+
Requires-Dist: midas-distortion>=0.2.0
|
|
22
|
+
Requires-Dist: midas-stress>=0.8.0
|
|
21
23
|
Provides-Extra: dev
|
|
22
24
|
Requires-Dist: pytest>=7.0; extra == "dev"
|
|
23
25
|
Requires-Dist: pytest-cov; extra == "dev"
|
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "midas-diffract"
|
|
7
|
-
version = "0.1.2"
|
|
7
|
+
version = "0.4.0"
|
|
8
8
|
description = "End-to-end differentiable forward model for High-Energy Diffraction Microscopy (FF, NF, pf-HEDM)"
|
|
9
9
|
readme = "README.md"
|
|
10
10
|
license = "BSD-3-Clause"
|
|
@@ -26,6 +26,8 @@ classifiers = [
|
|
|
26
26
|
dependencies = [
|
|
27
27
|
"numpy>=1.22",
|
|
28
28
|
"torch>=2.0",
|
|
29
|
+
"midas-distortion>=0.2.0",
|
|
30
|
+
"midas-stress>=0.8.0",
|
|
29
31
|
]
|
|
30
32
|
|
|
31
33
|
[project.optional-dependencies]
|
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
"""Tests for the gated ideal->raw radial-distortion layer (midas_distortion v2).
|
|
2
|
+
|
|
3
|
+
Guarantees:
|
|
4
|
+
1. default (apply_distortion=False) is BYTE-IDENTICAL to the pre-layer forward
|
|
5
|
+
-- proves the indexer/fit-grain (ideal-frame) path is untouched.
|
|
6
|
+
2. apply_distortion=True with zero coeffs is also identity.
|
|
7
|
+
3. nonzero coeffs shift predicted spots (ideal->raw) by a sane, radius-growing
|
|
8
|
+
amount.
|
|
9
|
+
4. the layer is differentiable w.r.t. the distortion coefficients.
|
|
10
|
+
"""
|
|
11
|
+
import numpy as np
|
|
12
|
+
import pytest
|
|
13
|
+
import torch
|
|
14
|
+
|
|
15
|
+
from midas_diffract.forward import HEDMForwardModel, HEDMGeometry
|
|
16
|
+
|
|
17
|
+
_GD = dict(
|
|
18
|
+
Lsd=752000.0, y_BC=695.0, z_BC=874.0, px=172.0,
|
|
19
|
+
omega_start=180.0, omega_step=-0.25, n_frames=1440,
|
|
20
|
+
n_pixels_y=1679, n_pixels_z=1679, min_eta=6.0, wavelength=0.172979,
|
|
21
|
+
)
|
|
22
|
+
# A representative calibrated v2 coefficient set (Bucsek Pilatus 2M CeO2).
|
|
23
|
+
_COEFFS = [0.00707, -0.01, 0.00624, 0.01, -34.76, 0.00234, 81.47,
|
|
24
|
+
-0.00369, -12.29, -0.00727, -5.29, -0.00863, -1.51, -0.00446, -7.79]
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _model(**extra):
|
|
28
|
+
rng = np.random.default_rng(0)
|
|
29
|
+
hk = torch.tensor(rng.standard_normal((40, 3)), dtype=torch.float64)
|
|
30
|
+
th = torch.tensor(np.abs(rng.standard_normal(40)) * 0.04 + 0.04, dtype=torch.float64)
|
|
31
|
+
hi = torch.tensor(rng.integers(-3, 4, (40, 3)), dtype=torch.float64)
|
|
32
|
+
g = HEDMGeometry(**_GD, **extra)
|
|
33
|
+
return HEDMForwardModel(hk, th, g, hkls_int=hi).double()
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
_EUL = torch.tensor([[0.3, 0.5, 0.2]], dtype=torch.float64)
|
|
37
|
+
_POS = torch.zeros(1, 3, dtype=torch.float64)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def test_default_off_is_unchanged():
|
|
41
|
+
"""No distortion fields == apply_distortion=False == zero coeffs (identity)."""
|
|
42
|
+
o0 = _model().forward(_EUL, _POS)
|
|
43
|
+
o_off = _model(apply_distortion=False, p_distortion=[0.0] * 15).forward(_EUL, _POS)
|
|
44
|
+
o_zero = _model(apply_distortion=True, p_distortion=[0.0] * 15, rho_d=2.0e5).forward(_EUL, _POS)
|
|
45
|
+
assert torch.equal(o0.y_pixel, o_off.y_pixel)
|
|
46
|
+
assert torch.equal(o0.z_pixel, o_off.z_pixel)
|
|
47
|
+
# zero coeffs with the layer ACTIVE must also be a no-op (D == 1).
|
|
48
|
+
assert torch.allclose(o0.y_pixel, o_zero.y_pixel, atol=1e-9)
|
|
49
|
+
assert torch.allclose(o0.z_pixel, o_zero.z_pixel, atol=1e-9)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def test_distortion_shifts_spots():
|
|
53
|
+
o0 = _model().forward(_EUL, _POS)
|
|
54
|
+
o2 = _model(apply_distortion=True, p_distortion=_COEFFS, rho_d=2.0e5).forward(_EUL, _POS)
|
|
55
|
+
vm = (o0.valid > 0.5) & (o2.valid > 0.5)
|
|
56
|
+
assert int(vm.sum()) > 5
|
|
57
|
+
dy = (o2.y_pixel - o0.y_pixel)[vm].abs()
|
|
58
|
+
assert float(dy.median()) > 0.05 # measurable
|
|
59
|
+
assert float(dy.max()) < 50.0 # but physical (sub-module)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def test_distortion_is_differentiable():
|
|
63
|
+
m = _model(apply_distortion=True, p_distortion=_COEFFS, rho_d=2.0e5)
|
|
64
|
+
m.p_distortion.requires_grad_(True)
|
|
65
|
+
out = m.forward(_EUL, _POS)
|
|
66
|
+
loss = (out.y_pixel * out.valid).sum() + (out.z_pixel * out.valid).sum()
|
|
67
|
+
loss.backward()
|
|
68
|
+
assert m.p_distortion.grad is not None
|
|
69
|
+
assert float(m.p_distortion.grad.norm()) > 0.0
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
|
73
|
+
def test_distortion_cpu_cuda_parity():
|
|
74
|
+
o_cpu = _model(apply_distortion=True, p_distortion=_COEFFS, rho_d=2.0e5).forward(_EUL, _POS)
|
|
75
|
+
m = _model(apply_distortion=True, p_distortion=_COEFFS, rho_d=2.0e5).to("cuda")
|
|
76
|
+
o_gpu = m.forward(_EUL.cuda(), _POS.cuda())
|
|
77
|
+
assert torch.allclose(o_cpu.y_pixel, o_gpu.y_pixel.cpu(), atol=1e-7)
|
|
78
|
+
assert torch.allclose(o_cpu.z_pixel, o_gpu.z_pixel.cpu(), atol=1e-7)
|
|
@@ -1111,5 +1111,95 @@ class TestCrossValidation:
|
|
|
1111
1111
|
import torch.nn.functional as F # for test_2d_position_backward_compat
|
|
1112
1112
|
|
|
1113
1113
|
|
|
1114
|
+
# ===================================================================
|
|
1115
|
+
# Test: midas_stress is the canonical orientation/strain-frame source
|
|
1116
|
+
# ===================================================================
|
|
1117
|
+
|
|
1118
|
+
class TestMidasStressDelegation:
|
|
1119
|
+
"""Guards that euler2mat / rotate_strain_sample_to_crystal delegate to
|
|
1120
|
+
midas_stress and never silently re-port (convention drift)."""
|
|
1121
|
+
|
|
1122
|
+
def test_euler2mat_matches_midas_stress(self):
|
|
1123
|
+
from midas_stress.orientation import euler_to_orient_mat
|
|
1124
|
+
torch.manual_seed(0)
|
|
1125
|
+
eul = torch.rand(7, 3, dtype=torch.float64) * 2 * math.pi
|
|
1126
|
+
R = HEDMForwardModel.euler2mat(eul)
|
|
1127
|
+
R_ms = euler_to_orient_mat(eul).reshape(7, 3, 3)
|
|
1128
|
+
torch.testing.assert_close(R, R_ms, atol=1e-12, rtol=0)
|
|
1129
|
+
|
|
1130
|
+
def test_euler2mat_vmap_safe(self):
|
|
1131
|
+
from torch.func import vmap
|
|
1132
|
+
eul = torch.rand(5, 3, dtype=torch.float64) * 2 * math.pi
|
|
1133
|
+
R = vmap(HEDMForwardModel.euler2mat)(eul)
|
|
1134
|
+
assert R.shape == (5, 3, 3)
|
|
1135
|
+
torch.testing.assert_close(R, HEDMForwardModel.euler2mat(eul), atol=1e-12, rtol=0)
|
|
1136
|
+
|
|
1137
|
+
def test_euler2mat_differentiable(self):
|
|
1138
|
+
eul = torch.rand(3, 3, dtype=torch.float64, requires_grad=True)
|
|
1139
|
+
HEDMForwardModel.euler2mat(eul).pow(2).sum().backward()
|
|
1140
|
+
assert eul.grad is not None and torch.all(torch.isfinite(eul.grad))
|
|
1141
|
+
|
|
1142
|
+
def test_rotate_strain_matches_midas_stress(self):
|
|
1143
|
+
from midas_stress.tensor import strain_lab_to_grain
|
|
1144
|
+
torch.manual_seed(1)
|
|
1145
|
+
OM = HEDMForwardModel.euler2mat(torch.rand(4, 3, dtype=torch.float64))
|
|
1146
|
+
v = torch.rand(4, 6, dtype=torch.float64) * 1e-3 # plain Voigt
|
|
1147
|
+
out = HEDMForwardModel.rotate_strain_sample_to_crystal(OM, v)
|
|
1148
|
+
# reference: build 3x3, rotate via midas_stress, repack plain Voigt
|
|
1149
|
+
S = torch.zeros(4, 3, 3, dtype=torch.float64)
|
|
1150
|
+
S[:, 0, 0], S[:, 1, 1], S[:, 2, 2] = v[:, 0], v[:, 3], v[:, 5]
|
|
1151
|
+
S[:, 0, 1] = S[:, 1, 0] = v[:, 1]
|
|
1152
|
+
S[:, 0, 2] = S[:, 2, 0] = v[:, 2]
|
|
1153
|
+
S[:, 1, 2] = S[:, 2, 1] = v[:, 4]
|
|
1154
|
+
C = strain_lab_to_grain(S, OM)
|
|
1155
|
+
ref = torch.stack([C[:, 0, 0], C[:, 0, 1], C[:, 0, 2],
|
|
1156
|
+
C[:, 1, 1], C[:, 1, 2], C[:, 2, 2]], dim=-1)
|
|
1157
|
+
torch.testing.assert_close(out, ref, atol=1e-14, rtol=0)
|
|
1158
|
+
|
|
1159
|
+
|
|
1160
|
+
# ===================================================================
|
|
1161
|
+
# Test: full 3x3 strain tensor entry point (convention-free)
|
|
1162
|
+
# ===================================================================
|
|
1163
|
+
|
|
1164
|
+
class TestStrainTensorInput:
|
|
1165
|
+
"""A full symmetric 3x3 strain must give the same result as the
|
|
1166
|
+
equivalent plain-Voigt 6-vector (NOT Voigt-Mandel)."""
|
|
1167
|
+
|
|
1168
|
+
def test_strain_as_voigt_roundtrip(self):
|
|
1169
|
+
v = torch.tensor([1e-3, 2e-4, 3e-4, -5e-4, 1e-4, 2e-4])
|
|
1170
|
+
S = torch.tensor([[1e-3, 2e-4, 3e-4],
|
|
1171
|
+
[2e-4, -5e-4, 1e-4],
|
|
1172
|
+
[3e-4, 1e-4, 2e-4]])
|
|
1173
|
+
torch.testing.assert_close(HEDMForwardModel.strain_as_voigt(S), v)
|
|
1174
|
+
# plain Voigt passes through unchanged
|
|
1175
|
+
torch.testing.assert_close(HEDMForwardModel.strain_as_voigt(v), v)
|
|
1176
|
+
|
|
1177
|
+
def test_3x3_strain_equals_voigt_in_correct_hkls_latc(self, nf_geometry, device):
|
|
1178
|
+
model, _, _ = make_model_with_cubic_iron(nf_geometry, device)
|
|
1179
|
+
latc = torch.tensor([2.87, 2.87, 2.87, 90., 90., 90.])
|
|
1180
|
+
v = torch.tensor([1e-3, 2e-4, 3e-4, -5e-4, 1e-4, 2e-4])
|
|
1181
|
+
S = torch.tensor([[1e-3, 2e-4, 3e-4],
|
|
1182
|
+
[2e-4, -5e-4, 1e-4],
|
|
1183
|
+
[3e-4, 1e-4, 2e-4]])
|
|
1184
|
+
g_v, t_v = model.correct_hkls_latc(latc, strain=v)
|
|
1185
|
+
g_S, t_S = model.correct_hkls_latc(latc, strain=S)
|
|
1186
|
+
torch.testing.assert_close(g_v, g_S)
|
|
1187
|
+
torch.testing.assert_close(t_v, t_S)
|
|
1188
|
+
|
|
1189
|
+
def test_3x3_strain_equals_voigt_in_forward(self, nf_geometry, device):
|
|
1190
|
+
model, _, _ = make_model_with_cubic_iron(nf_geometry, device)
|
|
1191
|
+
eul = torch.tensor([[0.3, 0.5, 1.1]])
|
|
1192
|
+
pos = torch.zeros(1, 3)
|
|
1193
|
+
latc = torch.tensor([2.87, 2.87, 2.87, 90., 90., 90.]).expand(1, 6)
|
|
1194
|
+
v = torch.tensor([[1e-3, 2e-4, 3e-4, -5e-4, 1e-4, 2e-4]])
|
|
1195
|
+
S = torch.tensor([[[1e-3, 2e-4, 3e-4],
|
|
1196
|
+
[2e-4, -5e-4, 1e-4],
|
|
1197
|
+
[3e-4, 1e-4, 2e-4]]])
|
|
1198
|
+
sp_v = model(eul, pos, lattice_params=latc, strain=v)
|
|
1199
|
+
sp_S = model(eul, pos, lattice_params=latc, strain=S)
|
|
1200
|
+
torch.testing.assert_close(sp_v.two_theta, sp_S.two_theta)
|
|
1201
|
+
torch.testing.assert_close(sp_v.omega, sp_S.omega)
|
|
1202
|
+
|
|
1203
|
+
|
|
1114
1204
|
if __name__ == "__main__":
|
|
1115
1205
|
pytest.main([__file__, "-v"])
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|