midas-diffract 0.1.2__tar.gz → 0.4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (25)
  1. {midas_diffract-0.1.2 → midas_diffract-0.4.0}/PKG-INFO +3 -1
  2. {midas_diffract-0.1.2 → midas_diffract-0.4.0}/midas_diffract/__init__.py +1 -1
  3. {midas_diffract-0.1.2 → midas_diffract-0.4.0}/midas_diffract/forward.py +129 -41
  4. {midas_diffract-0.1.2 → midas_diffract-0.4.0}/midas_diffract.egg-info/PKG-INFO +3 -1
  5. {midas_diffract-0.1.2 → midas_diffract-0.4.0}/midas_diffract.egg-info/SOURCES.txt +1 -0
  6. {midas_diffract-0.1.2 → midas_diffract-0.4.0}/midas_diffract.egg-info/requires.txt +2 -0
  7. {midas_diffract-0.1.2 → midas_diffract-0.4.0}/pyproject.toml +3 -1
  8. midas_diffract-0.4.0/tests/test_distortion_layer.py +78 -0
  9. {midas_diffract-0.1.2 → midas_diffract-0.4.0}/tests/test_forward.py +90 -0
  10. {midas_diffract-0.1.2 → midas_diffract-0.4.0}/LICENSE +0 -0
  11. {midas_diffract-0.1.2 → midas_diffract-0.4.0}/README.md +0 -0
  12. {midas_diffract-0.1.2 → midas_diffract-0.4.0}/midas_diffract/hkls.py +0 -0
  13. {midas_diffract-0.1.2 → midas_diffract-0.4.0}/midas_diffract/losses.py +0 -0
  14. {midas_diffract-0.1.2 → midas_diffract-0.4.0}/midas_diffract/optimize.py +0 -0
  15. {midas_diffract-0.1.2 → midas_diffract-0.4.0}/midas_diffract/simulate_panel_zarrs.py +0 -0
  16. {midas_diffract-0.1.2 → midas_diffract-0.4.0}/midas_diffract.egg-info/dependency_links.txt +0 -0
  17. {midas_diffract-0.1.2 → midas_diffract-0.4.0}/midas_diffract.egg-info/top_level.txt +0 -0
  18. {midas_diffract-0.1.2 → midas_diffract-0.4.0}/setup.cfg +0 -0
  19. {midas_diffract-0.1.2 → midas_diffract-0.4.0}/tests/test_c_comparison.py +0 -0
  20. {midas_diffract-0.1.2 → midas_diffract-0.4.0}/tests/test_hkls.py +0 -0
  21. {midas_diffract-0.1.2 → midas_diffract-0.4.0}/tests/test_losses.py +0 -0
  22. {midas_diffract-0.1.2 → midas_diffract-0.4.0}/tests/test_multi_detector.py +0 -0
  23. {midas_diffract-0.1.2 → midas_diffract-0.4.0}/tests/test_strain_tensor.py +0 -0
  24. {midas_diffract-0.1.2 → midas_diffract-0.4.0}/tests/test_tilts.py +0 -0
  25. {midas_diffract-0.1.2 → midas_diffract-0.4.0}/tests/test_wedge.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: midas-diffract
3
- Version: 0.1.2
3
+ Version: 0.4.0
4
4
  Summary: End-to-end differentiable forward model for High-Energy Diffraction Microscopy (FF, NF, pf-HEDM)
5
5
  Author-email: Hemant Sharma <hsharma@anl.gov>
6
6
  License-Expression: BSD-3-Clause
@@ -18,6 +18,8 @@ Description-Content-Type: text/markdown
18
18
  License-File: LICENSE
19
19
  Requires-Dist: numpy>=1.22
20
20
  Requires-Dist: torch>=2.0
21
+ Requires-Dist: midas-distortion>=0.2.0
22
+ Requires-Dist: midas-stress>=0.8.0
21
23
  Provides-Extra: dev
22
24
  Requires-Dist: pytest>=7.0; extra == "dev"
23
25
  Requires-Dist: pytest-cov; extra == "dev"
@@ -28,7 +28,7 @@ Quick start
28
28
  loss.backward()
29
29
  """
30
30
 
31
- __version__ = "0.1.2"
31
+ __version__ = "0.4.0"
32
32
 
33
33
  from .forward import (
34
34
  HEDMForwardModel,
@@ -34,6 +34,16 @@ import torch
34
34
  import torch.nn as nn
35
35
  import torch.nn.functional as F
36
36
 
37
+ # Canonical orientation + strain-frame primitives. midas_stress is the single
38
+ # source of truth for this math (Bunge ZXZ orientation algebra, sample<->crystal
39
+ # strain rotation); its torch backend is differentiable end-to-end and
40
+ # device-portable, so the forward model delegates rather than re-porting.
41
+ # NOTE: midas_stress's Voigt is Voigt-MANDEL (sqrt2 on shears); this model uses
42
+ # PLAIN-Voigt / raw 3x3 strain, so we only delegate the *rotation*, never the
43
+ # Voigt packing (see rotate_strain_sample_to_crystal / correct_hkls_latc).
44
+ from midas_stress.orientation import euler_to_orient_mat as _ms_euler_to_orient_mat
45
+ from midas_stress.tensor import strain_lab_to_grain as _ms_strain_lab_to_grain
46
+
37
47
 
38
48
  # ---------------------------------------------------------------------------
39
49
  # Configuration data classes
@@ -89,6 +99,16 @@ class HEDMGeometry:
89
99
  # Default False preserves the existing
90
100
  # behaviour: NF applies tilts, FF skips.
91
101
  # Set True for raw multi-panel simulation.
102
+ # Radial detector distortion (canonical midas_distortion v2 model). Like
103
+ # tilts, this maps an IDEAL prediction to the RAW detector position and is
104
+ # OFF by default -- the FF/pf experimental pipeline pre-corrects distortion
105
+ # at peak-finding time (transforms), so the forward must NOT re-apply it for
106
+ # the indexer/fit-grain. Raw-pixel-patch consumers (pf_odf, grain_odf) that
107
+ # never go through transforms set ``apply_distortion=True`` and supply the
108
+ # calibrated v2 coefficients to predict in the raw frame.
109
+ apply_distortion: bool = False
110
+ p_distortion: "list[float] | None" = None # 15 v2 coeffs (midas_distortion P_COEF_NAMES order); None/zeros => no-op
111
+ rho_d: "float | None" = None # distortion radius normalization (um); None => resolve from detector corner
92
112
  multi_mode: str = "layered" # "layered" (default): NF semantics --
93
113
  # spot must land on the detector at
94
114
  # EVERY distance (AllDistsFound).
@@ -342,6 +362,25 @@ class HEDMForwardModel(nn.Module):
342
362
 
343
363
  # Multi-detector / multi-panel configuration
344
364
  self.apply_tilts = bool(geometry.apply_tilts)
365
+
366
+ # Radial detector distortion (canonical midas_distortion v2 model),
367
+ # applied ideal->raw in project_to_detector when apply_distortion=True.
368
+ # OFF by default => identity => indexer/fit-grain output unchanged.
369
+ self.apply_distortion = bool(getattr(geometry, "apply_distortion", False))
370
+ p_dist = getattr(geometry, "p_distortion", None)
371
+ if p_dist is not None:
372
+ self.p_distortion = nn.Parameter(
373
+ torch.as_tensor(p_dist, dtype=torch.float64, device=device),
374
+ requires_grad=False,
375
+ )
376
+ self._has_distortion = bool(
377
+ torch.any(torch.abs(self.p_distortion.detach()) > 0.0).item()
378
+ )
379
+ else:
380
+ self.p_distortion = None
381
+ self._has_distortion = False
382
+ self.rho_d = getattr(geometry, "rho_d", None)
383
+
345
384
  if geometry.multi_mode not in ("layered", "panel"):
346
385
  raise ValueError(
347
386
  f"Unknown multi_mode {geometry.multi_mode!r}; "
@@ -408,7 +447,13 @@ class HEDMForwardModel(nn.Module):
408
447
 
409
448
  @staticmethod
410
449
  def euler2mat(euler_angles: torch.Tensor) -> torch.Tensor:
411
- """Convert ZXZ Euler angles to rotation matrices.
450
+ """Convert ZXZ (Bunge) Euler angles to crystal->sample rotation matrices.
451
+
452
+ Delegates to ``midas_stress.orientation.euler_to_orient_mat`` -- the
453
+ canonical orientation primitive -- so the convention can never drift
454
+ from the rest of MIDAS. midas_stress's torch backend is differentiable
455
+ and vmap-safe; the result is identical to the former in-line ZXZ build
456
+ (R = Rz(phi1) @ Rx(Phi) @ Rz(phi2)) to ~1e-16.
412
457
 
413
458
  Parameters
414
459
  ----------
@@ -418,35 +463,14 @@ class HEDMForwardModel(nn.Module):
418
463
  Returns
419
464
  -------
420
465
  Tensor (..., 3, 3)
421
- Rotation matrices.
466
+ Rotation matrices (crystal->sample), orthogonalized onto SO(3).
422
467
  """
423
- c = torch.cos(euler_angles)
424
- s = torch.sin(euler_angles)
425
-
426
- c0, c1, c2 = c[..., 0], c[..., 1], c[..., 2]
427
- s0, s1, s2 = s[..., 0], s[..., 1], s[..., 2]
428
-
429
- # ZXZ rotation matrix: R = Rz(phi1) @ Rx(Phi) @ Rz(phi2)
430
- # Verified element-by-element against nfhedm.py lines 114-120.
431
- # Built via torch.stack rather than indexed assignment so the function
432
- # composes with torch.func.vmap (in-place writes block vmap).
433
- row0 = torch.stack([
434
- c0 * c2 - s0 * c1 * s2,
435
- -s0 * c1 * c2 - c0 * s2,
436
- s0 * s1,
437
- ], dim=-1)
438
- row1 = torch.stack([
439
- s0 * c2 + c0 * c1 * s2,
440
- c0 * c1 * c2 - s0 * s2,
441
- -c0 * s1,
442
- ], dim=-1)
443
- row2 = torch.stack([
444
- s1 * s2,
445
- s1 * c2,
446
- c1,
447
- ], dim=-1)
448
- R = torch.stack([row0, row1, row2], dim=-2)
449
-
468
+ if not isinstance(euler_angles, torch.Tensor):
469
+ euler_angles = torch.as_tensor(euler_angles)
470
+ R = _ms_euler_to_orient_mat(euler_angles) # (..., 9), torch
471
+ R = R.reshape(*R.shape[:-1], 3, 3)
472
+ # midas_stress already returns a proper rotation; orthogonalize keeps the
473
+ # historical "exactly on SO(3)" guarantee and is idempotent here.
450
474
  return HEDMForwardModel.orthogonalize(R)
451
475
 
452
476
  # ------------------------------------------------------------------
@@ -496,6 +520,28 @@ class HEDMForwardModel(nn.Module):
496
520
  """Numerically stable arccos: clamp to [-1+eps, 1-eps]."""
497
521
  return torch.acos(torch.clamp(x, -1.0 + self.epsilon, 1.0 - self.epsilon))
498
522
 
523
+ # ------------------------------------------------------------------
524
+ # strain_as_voigt (accept full 3x3 tensor OR plain-Voigt 6-vector)
525
+ # ------------------------------------------------------------------
526
+
527
+ @staticmethod
528
+ def strain_as_voigt(strain: torch.Tensor) -> torch.Tensor:
529
+ """Normalize a strain input to PLAIN-Voigt [e11,e12,e13,e22,e23,e33].
530
+
531
+ Accepts either a plain-Voigt ``(..., 6)`` tensor (returned unchanged) or
532
+ a full symmetric ``(..., 3, 3)`` strain tensor. The 3x3 path is
533
+ convention-free -- the natural way to hand a strain field straight from
534
+ a tensor source (e.g. a midas_stress strain field) into the forward
535
+ model without picking a Voigt/Mandel packing. The off-diagonals are
536
+ taken as TRUE tensor components (no factor of 2).
537
+ """
538
+ if strain.dim() >= 2 and strain.shape[-1] == 3 and strain.shape[-2] == 3:
539
+ return torch.stack([
540
+ strain[..., 0, 0], strain[..., 0, 1], strain[..., 0, 2],
541
+ strain[..., 1, 1], strain[..., 1, 2], strain[..., 2, 2],
542
+ ], dim=-1)
543
+ return strain
544
+
499
545
  # ------------------------------------------------------------------
500
546
  # rotate_strain_sample_to_crystal (port of C RotateStrainSampleToCrystal)
501
547
  # ------------------------------------------------------------------
@@ -507,19 +553,27 @@ class HEDMForwardModel(nn.Module):
507
553
  ) -> torch.Tensor:
508
554
  """Rotate a symmetric infinitesimal strain from sample to crystal frame.
509
555
 
510
- Port of ``RotateStrainSampleToCrystal`` from
556
+ Matches ``RotateStrainSampleToCrystal`` from
511
557
  ``FF_HEDM/src/ForwardSimulationCompressed.c:399-419``:
512
- eps_crystal = OM^T . eps_sample . OM, in Voigt notation
513
- [eps_11, eps_12, eps_13, eps_22, eps_23, eps_33].
558
+ eps_crystal = OM^T . eps_sample . OM.
559
+
560
+ The rotation is delegated to ``midas_stress.tensor.strain_lab_to_grain``
561
+ (the canonical sample/lab -> crystal/grain strain transform; bit-identical
562
+ to the former in-line ``OM^T S OM``). PLAIN-Voigt pack/unpack is kept here
563
+ on purpose -- midas_stress's Voigt is Mandel (sqrt2 on shears), which must
564
+ not touch the forward model's strain input.
514
565
 
515
566
  Parameters
516
567
  ----------
517
568
  orientation_matrices : Tensor (..., 3, 3)
569
+ Crystal->sample matrices (as returned by :meth:`euler2mat`).
518
570
  strain_sample : Tensor (..., 6)
571
+ PLAIN-Voigt [eps_11, eps_12, eps_13, eps_22, eps_23, eps_33].
519
572
 
520
573
  Returns
521
574
  -------
522
575
  strain_crystal : Tensor (..., 6)
576
+ PLAIN-Voigt, same layout.
523
577
  """
524
578
  e = strain_sample
525
579
  S = torch.stack([
@@ -527,8 +581,7 @@ class HEDMForwardModel(nn.Module):
527
581
  torch.stack([e[..., 1], e[..., 3], e[..., 4]], dim=-1),
528
582
  torch.stack([e[..., 2], e[..., 4], e[..., 5]], dim=-1),
529
583
  ], dim=-2)
530
- OM = orientation_matrices
531
- C = torch.matmul(torch.matmul(OM.transpose(-1, -2), S), OM)
584
+ C = _ms_strain_lab_to_grain(S, orientation_matrices) # OM^T S OM
532
585
  return torch.stack([
533
586
  C[..., 0, 0], C[..., 0, 1], C[..., 0, 2],
534
587
  C[..., 1, 1], C[..., 1, 2], C[..., 2, 2],
@@ -558,10 +611,12 @@ class HEDMForwardModel(nn.Module):
558
611
  lattice_params : Tensor (..., 6)
559
612
  [a, b, c, alpha, beta, gamma] in Angstroms and degrees.
560
613
  The ``...`` dimensions allow per-voxel or per-grain parameters.
561
- strain : Tensor (..., 6), optional
562
- Crystal-frame symmetric infinitesimal strain in Voigt form
563
- [eps_11, eps_12, eps_13, eps_22, eps_23, eps_33]. When supplied,
564
- the reciprocal lattice is post-multiplied by (I + eps)^{-1}:
614
+ strain : Tensor (..., 6) or (..., 3, 3), optional
615
+ Crystal-frame symmetric infinitesimal strain, either PLAIN-Voigt
616
+ [eps_11, eps_12, eps_13, eps_22, eps_23, eps_33] or a full symmetric
617
+ 3x3 tensor (normalized via :meth:`strain_as_voigt`; off-diagonals are
618
+ true tensor components, no factor of 2). When supplied, the
619
+ reciprocal lattice is post-multiplied by (I + eps)^{-1}:
565
620
  B = (I + eps)^{-1} @ B0. Use :meth:`rotate_strain_sample_to_crystal`
566
621
  to convert a sample-frame strain into the crystal frame.
567
622
 
@@ -642,6 +697,7 @@ class HEDMForwardModel(nn.Module):
642
697
  # Voigt layout matches C CorrectHKLsLatCEpsilon:
643
698
  # eps = [eps_11, eps_12, eps_13, eps_22, eps_23, eps_33]
644
699
  if strain is not None:
700
+ strain = self.strain_as_voigt(strain) # accept full 3x3 too
645
701
  e11 = strain[..., 0]
646
702
  e12 = strain[..., 1]
647
703
  e13 = strain[..., 2]
@@ -1163,6 +1219,37 @@ class HEDMForwardModel(nn.Module):
1163
1219
  ydet_d = torch.stack(out_y, dim=0)
1164
1220
  zdet_d = torch.stack(out_z, dim=0)
1165
1221
 
1222
+ # Ideal->raw radial distortion (canonical midas_distortion v2 model),
1223
+ # applied on the BC-relative detector-plane coords (um) before pixel
1224
+ # conversion -- mirrors midas_calibrate_v2.forward.geometry. Gated OFF
1225
+ # by default so the indexer/fit-grain (ideal frame) are byte-unchanged;
1226
+ # raw-patch consumers (pf_odf, grain_odf) opt in. R is BC-relative, so
1227
+ # the distortion is frame-flip invariant; the eta convention (and any
1228
+ # phase offset between the calibration frame and this frame) is the one
1229
+ # thing to validate empirically (see implementation_plan_distortion_layer).
1230
+ if self.apply_distortion and self._has_distortion:
1231
+ from midas_distortion import apply_distortion as _apply_dist, \
1232
+ v2_term_layout as _v2_terms, resolve_rho_d_um as _resolve_rho_d
1233
+ eps = torch.tensor(1e-9, dtype=ydet_d.dtype, device=ydet_d.device)
1234
+ R = torch.sqrt(ydet_d * ydet_d + zdet_d * zdet_d).clamp(min=eps)
1235
+ # eta convention matches calibrate_v2 forward: atan2(-y, z), degrees.
1236
+ eta_deg_d = self.RAD2DEG * torch.atan2(-ydet_d, zdet_d)
1237
+ # resolve_rho_d_um passes a supplied rho_d through, or computes the
1238
+ # max BC-relative corner distance (um) when None.
1239
+ rho_d_val, _rho_how = _resolve_rho_d(
1240
+ self.rho_d,
1241
+ NrPixelsY=self.n_pixels_y, NrPixelsZ=self.n_pixels_z,
1242
+ BC_y=float(self._y_BC.reshape(-1)[0]),
1243
+ BC_z=float(self._z_BC.reshape(-1)[0]),
1244
+ pxY=self.px,
1245
+ )
1246
+ rho_d_t = torch.as_tensor(float(rho_d_val), dtype=R.dtype, device=R.device)
1247
+ p_v2 = self.p_distortion.to(R.dtype)
1248
+ R_corr = _apply_dist(R, eta_deg_d, p_v2, rho_d_t, terms=_v2_terms())
1249
+ scale = R_corr / R
1250
+ ydet_d = ydet_d * scale
1251
+ zdet_d = zdet_d * scale
1252
+
1166
1253
  # FF/PF: y-axis on detector flipped (yBC - ydet/px), validated against C
1167
1254
  # NF: not flipped (yBC + ydet/px), validated against C
1168
1255
  y_sign = -1.0 if self.flip_y else 1.0
@@ -1251,9 +1338,10 @@ class HEDMForwardModel(nn.Module):
1251
1338
  lattice_params : Tensor (..., 6) or (..., N, 6), optional
1252
1339
  Strained lattice parameters [a,b,c,alpha,beta,gamma] in
1253
1340
  Angstroms/degrees. None = use nominal hkls/thetas (no strain).
1254
- strain : Tensor (..., 6) or (..., N, 6), optional
1255
- Crystal-frame symmetric infinitesimal strain in Voigt form
1256
- [eps_11, eps_12, eps_13, eps_22, eps_23, eps_33]. Applied as
1341
+ strain : Tensor (..., 6), (..., N, 6), or (..., 3, 3), optional
1342
+ Crystal-frame symmetric infinitesimal strain, either PLAIN-Voigt
1343
+ [eps_11, eps_12, eps_13, eps_22, eps_23, eps_33] or a full symmetric
1344
+ 3x3 tensor (see :meth:`strain_as_voigt`). Applied as
1257
1345
  B = (I + eps)^{-1} @ B0 in addition to any lattice-parameter
1258
1346
  strain expressed through ``lattice_params``. Requires
1259
1347
  ``lattice_params`` to be supplied.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: midas-diffract
3
- Version: 0.1.2
3
+ Version: 0.4.0
4
4
  Summary: End-to-end differentiable forward model for High-Energy Diffraction Microscopy (FF, NF, pf-HEDM)
5
5
  Author-email: Hemant Sharma <hsharma@anl.gov>
6
6
  License-Expression: BSD-3-Clause
@@ -18,6 +18,8 @@ Description-Content-Type: text/markdown
18
18
  License-File: LICENSE
19
19
  Requires-Dist: numpy>=1.22
20
20
  Requires-Dist: torch>=2.0
21
+ Requires-Dist: midas-distortion>=0.2.0
22
+ Requires-Dist: midas-stress>=0.8.0
21
23
  Provides-Extra: dev
22
24
  Requires-Dist: pytest>=7.0; extra == "dev"
23
25
  Requires-Dist: pytest-cov; extra == "dev"
@@ -13,6 +13,7 @@ midas_diffract.egg-info/dependency_links.txt
13
13
  midas_diffract.egg-info/requires.txt
14
14
  midas_diffract.egg-info/top_level.txt
15
15
  tests/test_c_comparison.py
16
+ tests/test_distortion_layer.py
16
17
  tests/test_forward.py
17
18
  tests/test_hkls.py
18
19
  tests/test_losses.py
@@ -1,5 +1,7 @@
1
1
  numpy>=1.22
2
2
  torch>=2.0
3
+ midas-distortion>=0.2.0
4
+ midas-stress>=0.8.0
3
5
 
4
6
  [dev]
5
7
  pytest>=7.0
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "midas-diffract"
7
- version = "0.1.2"
7
+ version = "0.4.0"
8
8
  description = "End-to-end differentiable forward model for High-Energy Diffraction Microscopy (FF, NF, pf-HEDM)"
9
9
  readme = "README.md"
10
10
  license = "BSD-3-Clause"
@@ -26,6 +26,8 @@ classifiers = [
26
26
  dependencies = [
27
27
  "numpy>=1.22",
28
28
  "torch>=2.0",
29
+ "midas-distortion>=0.2.0",
30
+ "midas-stress>=0.8.0",
29
31
  ]
30
32
 
31
33
  [project.optional-dependencies]
@@ -0,0 +1,78 @@
1
+ """Tests for the gated ideal->raw radial-distortion layer (midas_distortion v2).
2
+
3
+ Guarantees:
4
+ 1. default (apply_distortion=False) is BYTE-IDENTICAL to the pre-layer forward
5
+ -- proves the indexer/fit-grain (ideal-frame) path is untouched.
6
+ 2. apply_distortion=True with zero coeffs is also identity.
7
+ 3. nonzero coeffs shift predicted spots (ideal->raw) by a sane, radius-growing
8
+ amount.
9
+ 4. the layer is differentiable w.r.t. the distortion coefficients.
10
+ """
11
+ import numpy as np
12
+ import pytest
13
+ import torch
14
+
15
+ from midas_diffract.forward import HEDMForwardModel, HEDMGeometry
16
+
17
+ _GD = dict(
18
+ Lsd=752000.0, y_BC=695.0, z_BC=874.0, px=172.0,
19
+ omega_start=180.0, omega_step=-0.25, n_frames=1440,
20
+ n_pixels_y=1679, n_pixels_z=1679, min_eta=6.0, wavelength=0.172979,
21
+ )
22
+ # A representative calibrated v2 coefficient set (Bucsek Pilatus 2M CeO2).
23
+ _COEFFS = [0.00707, -0.01, 0.00624, 0.01, -34.76, 0.00234, 81.47,
24
+ -0.00369, -12.29, -0.00727, -5.29, -0.00863, -1.51, -0.00446, -7.79]
25
+
26
+
27
+ def _model(**extra):
28
+ rng = np.random.default_rng(0)
29
+ hk = torch.tensor(rng.standard_normal((40, 3)), dtype=torch.float64)
30
+ th = torch.tensor(np.abs(rng.standard_normal(40)) * 0.04 + 0.04, dtype=torch.float64)
31
+ hi = torch.tensor(rng.integers(-3, 4, (40, 3)), dtype=torch.float64)
32
+ g = HEDMGeometry(**_GD, **extra)
33
+ return HEDMForwardModel(hk, th, g, hkls_int=hi).double()
34
+
35
+
36
+ _EUL = torch.tensor([[0.3, 0.5, 0.2]], dtype=torch.float64)
37
+ _POS = torch.zeros(1, 3, dtype=torch.float64)
38
+
39
+
40
+ def test_default_off_is_unchanged():
41
+ """No distortion fields == apply_distortion=False == zero coeffs (identity)."""
42
+ o0 = _model().forward(_EUL, _POS)
43
+ o_off = _model(apply_distortion=False, p_distortion=[0.0] * 15).forward(_EUL, _POS)
44
+ o_zero = _model(apply_distortion=True, p_distortion=[0.0] * 15, rho_d=2.0e5).forward(_EUL, _POS)
45
+ assert torch.equal(o0.y_pixel, o_off.y_pixel)
46
+ assert torch.equal(o0.z_pixel, o_off.z_pixel)
47
+ # zero coeffs with the layer ACTIVE must also be a no-op (D == 1).
48
+ assert torch.allclose(o0.y_pixel, o_zero.y_pixel, atol=1e-9)
49
+ assert torch.allclose(o0.z_pixel, o_zero.z_pixel, atol=1e-9)
50
+
51
+
52
+ def test_distortion_shifts_spots():
53
+ o0 = _model().forward(_EUL, _POS)
54
+ o2 = _model(apply_distortion=True, p_distortion=_COEFFS, rho_d=2.0e5).forward(_EUL, _POS)
55
+ vm = (o0.valid > 0.5) & (o2.valid > 0.5)
56
+ assert int(vm.sum()) > 5
57
+ dy = (o2.y_pixel - o0.y_pixel)[vm].abs()
58
+ assert float(dy.median()) > 0.05 # measurable
59
+ assert float(dy.max()) < 50.0 # but physical (sub-module)
60
+
61
+
62
+ def test_distortion_is_differentiable():
63
+ m = _model(apply_distortion=True, p_distortion=_COEFFS, rho_d=2.0e5)
64
+ m.p_distortion.requires_grad_(True)
65
+ out = m.forward(_EUL, _POS)
66
+ loss = (out.y_pixel * out.valid).sum() + (out.z_pixel * out.valid).sum()
67
+ loss.backward()
68
+ assert m.p_distortion.grad is not None
69
+ assert float(m.p_distortion.grad.norm()) > 0.0
70
+
71
+
72
+ @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
73
+ def test_distortion_cpu_cuda_parity():
74
+ o_cpu = _model(apply_distortion=True, p_distortion=_COEFFS, rho_d=2.0e5).forward(_EUL, _POS)
75
+ m = _model(apply_distortion=True, p_distortion=_COEFFS, rho_d=2.0e5).to("cuda")
76
+ o_gpu = m.forward(_EUL.cuda(), _POS.cuda())
77
+ assert torch.allclose(o_cpu.y_pixel, o_gpu.y_pixel.cpu(), atol=1e-7)
78
+ assert torch.allclose(o_cpu.z_pixel, o_gpu.z_pixel.cpu(), atol=1e-7)
@@ -1111,5 +1111,95 @@ class TestCrossValidation:
1111
1111
  import torch.nn.functional as F # for test_2d_position_backward_compat
1112
1112
 
1113
1113
 
1114
+ # ===================================================================
1115
+ # Test: midas_stress is the canonical orientation/strain-frame source
1116
+ # ===================================================================
1117
+
1118
+ class TestMidasStressDelegation:
1119
+ """Guards that euler2mat / rotate_strain_sample_to_crystal delegate to
1120
+ midas_stress and never silently re-port (convention drift)."""
1121
+
1122
+ def test_euler2mat_matches_midas_stress(self):
1123
+ from midas_stress.orientation import euler_to_orient_mat
1124
+ torch.manual_seed(0)
1125
+ eul = torch.rand(7, 3, dtype=torch.float64) * 2 * math.pi
1126
+ R = HEDMForwardModel.euler2mat(eul)
1127
+ R_ms = euler_to_orient_mat(eul).reshape(7, 3, 3)
1128
+ torch.testing.assert_close(R, R_ms, atol=1e-12, rtol=0)
1129
+
1130
+ def test_euler2mat_vmap_safe(self):
1131
+ from torch.func import vmap
1132
+ eul = torch.rand(5, 3, dtype=torch.float64) * 2 * math.pi
1133
+ R = vmap(HEDMForwardModel.euler2mat)(eul)
1134
+ assert R.shape == (5, 3, 3)
1135
+ torch.testing.assert_close(R, HEDMForwardModel.euler2mat(eul), atol=1e-12, rtol=0)
1136
+
1137
+ def test_euler2mat_differentiable(self):
1138
+ eul = torch.rand(3, 3, dtype=torch.float64, requires_grad=True)
1139
+ HEDMForwardModel.euler2mat(eul).pow(2).sum().backward()
1140
+ assert eul.grad is not None and torch.all(torch.isfinite(eul.grad))
1141
+
1142
+ def test_rotate_strain_matches_midas_stress(self):
1143
+ from midas_stress.tensor import strain_lab_to_grain
1144
+ torch.manual_seed(1)
1145
+ OM = HEDMForwardModel.euler2mat(torch.rand(4, 3, dtype=torch.float64))
1146
+ v = torch.rand(4, 6, dtype=torch.float64) * 1e-3 # plain Voigt
1147
+ out = HEDMForwardModel.rotate_strain_sample_to_crystal(OM, v)
1148
+ # reference: build 3x3, rotate via midas_stress, repack plain Voigt
1149
+ S = torch.zeros(4, 3, 3, dtype=torch.float64)
1150
+ S[:, 0, 0], S[:, 1, 1], S[:, 2, 2] = v[:, 0], v[:, 3], v[:, 5]
1151
+ S[:, 0, 1] = S[:, 1, 0] = v[:, 1]
1152
+ S[:, 0, 2] = S[:, 2, 0] = v[:, 2]
1153
+ S[:, 1, 2] = S[:, 2, 1] = v[:, 4]
1154
+ C = strain_lab_to_grain(S, OM)
1155
+ ref = torch.stack([C[:, 0, 0], C[:, 0, 1], C[:, 0, 2],
1156
+ C[:, 1, 1], C[:, 1, 2], C[:, 2, 2]], dim=-1)
1157
+ torch.testing.assert_close(out, ref, atol=1e-14, rtol=0)
1158
+
1159
+
1160
+ # ===================================================================
1161
+ # Test: full 3x3 strain tensor entry point (convention-free)
1162
+ # ===================================================================
1163
+
1164
+ class TestStrainTensorInput:
1165
+ """A full symmetric 3x3 strain must give the same result as the
1166
+ equivalent plain-Voigt 6-vector (NOT Voigt-Mandel)."""
1167
+
1168
+ def test_strain_as_voigt_roundtrip(self):
1169
+ v = torch.tensor([1e-3, 2e-4, 3e-4, -5e-4, 1e-4, 2e-4])
1170
+ S = torch.tensor([[1e-3, 2e-4, 3e-4],
1171
+ [2e-4, -5e-4, 1e-4],
1172
+ [3e-4, 1e-4, 2e-4]])
1173
+ torch.testing.assert_close(HEDMForwardModel.strain_as_voigt(S), v)
1174
+ # plain Voigt passes through unchanged
1175
+ torch.testing.assert_close(HEDMForwardModel.strain_as_voigt(v), v)
1176
+
1177
+ def test_3x3_strain_equals_voigt_in_correct_hkls_latc(self, nf_geometry, device):
1178
+ model, _, _ = make_model_with_cubic_iron(nf_geometry, device)
1179
+ latc = torch.tensor([2.87, 2.87, 2.87, 90., 90., 90.])
1180
+ v = torch.tensor([1e-3, 2e-4, 3e-4, -5e-4, 1e-4, 2e-4])
1181
+ S = torch.tensor([[1e-3, 2e-4, 3e-4],
1182
+ [2e-4, -5e-4, 1e-4],
1183
+ [3e-4, 1e-4, 2e-4]])
1184
+ g_v, t_v = model.correct_hkls_latc(latc, strain=v)
1185
+ g_S, t_S = model.correct_hkls_latc(latc, strain=S)
1186
+ torch.testing.assert_close(g_v, g_S)
1187
+ torch.testing.assert_close(t_v, t_S)
1188
+
1189
+ def test_3x3_strain_equals_voigt_in_forward(self, nf_geometry, device):
1190
+ model, _, _ = make_model_with_cubic_iron(nf_geometry, device)
1191
+ eul = torch.tensor([[0.3, 0.5, 1.1]])
1192
+ pos = torch.zeros(1, 3)
1193
+ latc = torch.tensor([2.87, 2.87, 2.87, 90., 90., 90.]).expand(1, 6)
1194
+ v = torch.tensor([[1e-3, 2e-4, 3e-4, -5e-4, 1e-4, 2e-4]])
1195
+ S = torch.tensor([[[1e-3, 2e-4, 3e-4],
1196
+ [2e-4, -5e-4, 1e-4],
1197
+ [3e-4, 1e-4, 2e-4]]])
1198
+ sp_v = model(eul, pos, lattice_params=latc, strain=v)
1199
+ sp_S = model(eul, pos, lattice_params=latc, strain=S)
1200
+ torch.testing.assert_close(sp_v.two_theta, sp_S.two_theta)
1201
+ torch.testing.assert_close(sp_v.omega, sp_S.omega)
1202
+
1203
+
1114
1204
  if __name__ == "__main__":
1115
1205
  pytest.main([__file__, "-v"])
File without changes
File without changes
File without changes