midas-diffract 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- midas_diffract/__init__.py +55 -0
- midas_diffract/forward.py +1559 -0
- midas_diffract/hkls.py +180 -0
- midas_diffract/losses.py +494 -0
- midas_diffract/optimize.py +248 -0
- midas_diffract-0.1.0.dist-info/METADATA +122 -0
- midas_diffract-0.1.0.dist-info/RECORD +10 -0
- midas_diffract-0.1.0.dist-info/WHEEL +5 -0
- midas_diffract-0.1.0.dist-info/licenses/LICENSE +31 -0
- midas_diffract-0.1.0.dist-info/top_level.txt +1 -0
midas_diffract/hkls.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
1
|
+
"""Build forward-model reflection lists from ``midas-hkls`` outputs.
|
|
2
|
+
|
|
3
|
+
``HEDMForwardModel`` consumes three tensors that are conventionally supplied
|
|
4
|
+
by ``GetHKLList`` (the C tool in MIDAS):
|
|
5
|
+
|
|
6
|
+
* ``hkls_int`` -- (M, 3) integer Miller indices, **one row per spot**
|
|
7
|
+
(i.e. all symmetry-equivalent variants of each ASU
|
|
8
|
+
representative are enumerated)
|
|
9
|
+
* ``hkls_cart`` -- (M, 3) reference Cartesian G-vectors in 1/Angstroms
|
|
10
|
+
* ``thetas`` -- (M,) reference Bragg angles in radians
|
|
11
|
+
|
|
12
|
+
This module produces the same triplet from the pure-Python ``midas-hkls``
|
|
13
|
+
package, so users do not need the MIDAS C build to drive the forward model.
|
|
14
|
+
|
|
15
|
+
Example
|
|
16
|
+
-------
|
|
17
|
+
from midas_hkls import SpaceGroup, Lattice
|
|
18
|
+
import midas_diffract as md
|
|
19
|
+
|
|
20
|
+
sg = SpaceGroup.from_number(225) # FCC (Cu/Au/Ni)
|
|
21
|
+
lat = Lattice.for_system("cubic", a=4.08)
|
|
22
|
+
hkls_cart, thetas, hkls_int = md.hkls_for_forward_model(
|
|
23
|
+
sg, lat, wavelength_A=0.172979, two_theta_max_deg=15.0,
|
|
24
|
+
)
|
|
25
|
+
model = md.HEDMForwardModel(
|
|
26
|
+
hkls=hkls_cart, thetas=thetas, geometry=geom, hkls_int=hkls_int,  # geom: your detector geometry (built elsewhere)
|
|
27
|
+
)
|
|
28
|
+
"""
|
|
29
|
+
from __future__ import annotations
|
|
30
|
+
|
|
31
|
+
from math import cos, pi, sin
|
|
32
|
+
from typing import TYPE_CHECKING, Optional, Tuple
|
|
33
|
+
|
|
34
|
+
import numpy as np
|
|
35
|
+
import torch
|
|
36
|
+
|
|
37
|
+
if TYPE_CHECKING:
|
|
38
|
+
from midas_hkls import Lattice, SpaceGroup
|
|
39
|
+
|
|
40
|
+
DEG2RAD = pi / 180.0  # degrees -> radians factor; applied to the lattice angles in _cartesian_B_matrix
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def _cartesian_B_matrix(latc: "tuple[float, float, float, float, float, float]") -> np.ndarray:
    """Return the reference reciprocal-lattice basis as a Cartesian 3x3 matrix.

    Column ``j`` of the result is the j-th reciprocal basis vector
    (a*, b*, c*) in an upper-triangular frame, so a Miller-index row vector
    ``hkl`` maps to a Cartesian G-vector via ``hkl @ B.T``.

    The author documents this as matching the B-matrix convention of
    ``HEDMForwardModel.correct_hkls_latc`` and of the C routine
    ``CorrectHKLsLatC`` (``FF_HEDM/src/FitPosOrStrainsDoubleDataset.c``).
    For that reason every arithmetic step, including the ``1e-7`` guards on
    the denominators and the ``arccos`` clipping, is evaluated in the same
    order as that convention, which keeps results bit-compatible.

    Parameters
    ----------
    latc : tuple of 6 floats
        ``(a, b, c, alpha, beta, gamma)``: direct-lattice edge lengths and
        inter-axial angles in degrees. Exactly six values are required.

    Returns
    -------
    np.ndarray, shape (3, 3)
    """
    a, b, c, al_deg, be_deg, ga_deg = latc
    alpha, beta, gamma = (ang * DEG2RAD for ang in (al_deg, be_deg, ga_deg))

    sa, ca = sin(alpha), cos(alpha)
    sb, cb = sin(beta), cos(beta)
    sg, cg = sin(gamma), cos(gamma)

    # Small guard shared by the clip bounds and the divisions.
    tiny = 1e-7
    lo, hi = -1 + tiny, 1 - tiny

    # Reciprocal angles gamma* and beta* from the standard cosine relations.
    gamma_star = np.arccos(np.clip((ca * cb - cg) / (sa * sb + tiny), lo, hi))
    beta_star = np.arccos(np.clip((cg * ca - cb) / (sg * sa + tiny), lo, hi))
    sin_beta_star = np.sin(beta_star)

    # Cell volume and reciprocal lengths a*, b*, c*.
    volume = a * b * c * sa * sin_beta_star * sg
    a_star = b * c * sa / (volume + tiny)
    b_star = c * a * sb / (volume + tiny)
    c_star = a * b * sg / (volume + tiny)

    top = [a_star, b_star * np.cos(gamma_star), c_star * np.cos(beta_star)]
    middle = [0.0, b_star * np.sin(gamma_star), -c_star * sin_beta_star * ca]
    bottom = [0.0, 0.0, c_star * sin_beta_star * sa]
    return np.array([top, middle, bottom])
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def hkls_for_forward_model(
    space_group: "SpaceGroup",
    lattice: "Lattice",
    *,
    wavelength_A: float,
    two_theta_max_deg: Optional[float] = None,
    d_min: Optional[float] = None,
    expand_equivalents: bool = True,
    dtype: torch.dtype = torch.float64,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Return ``(hkls_cart, thetas, hkls_int)`` ready for ``HEDMForwardModel``.

    Reflections are enumerated with :func:`midas_hkls.generate_hkls`, which
    yields ASU representatives. By default each representative is expanded
    with ``space_group.equivalent_hkls`` so that every detector spot gets its
    own row. Cartesian G-vectors are then formed with ``_cartesian_B_matrix``,
    and Bragg angles follow from ``sin(theta) = |G| * lambda / 2``.

    Parameters
    ----------
    space_group, lattice
        Objects from the ``midas-hkls`` package. ``lattice`` must expose
        ``a, b, c, alpha, beta, gamma``, with the angles in degrees.
    wavelength_A : float
        X-ray wavelength in Angstroms.
    two_theta_max_deg, d_min
        Enumeration cutoffs, forwarded unchanged to ``generate_hkls``. The
        original docs state that at least one must be supplied.
    expand_equivalents : bool, default True
        True gives one row per equivalent reflection (the per-spot layout of
        ``GetHKLList``). False gives only the ASU representatives, which is
        useful for diagnostics.
    dtype : torch.dtype, default torch.float64
        dtype of all three returned tensors.

    Returns
    -------
    hkls_cart : Tensor (M, 3)
        Cartesian G-vectors in 1/Angstroms.
    thetas : Tensor (M,)
        Bragg angles in radians.
    hkls_int : Tensor (M, 3)
        Integer Miller indices, stored as floating point.

    Raises
    ------
    ImportError
        If ``midas-hkls`` is not installed.
    ValueError
        If no reflections are generated, or if any ``|G|*lambda/2`` exceeds 1.
    """
    try:
        from midas_hkls import generate_hkls  # type: ignore
    except ImportError as exc:
        raise ImportError(
            "midas_diffract.hkls requires the optional 'midas-hkls' package. "
            "Install with: pip install midas-hkls"
        ) from exc

    reflections = generate_hkls(
        space_group,
        lattice,
        wavelength_A=wavelength_A,
        two_theta_max_deg=two_theta_max_deg,
        d_min=d_min,
    )
    if not reflections:
        raise ValueError(
            "midas_hkls.generate_hkls returned no reflections; check "
            "wavelength / cutoff arguments."
        )

    # Each row is one spot: either the whole equivalent set of each ASU
    # representative, or just the representative itself.
    if expand_equivalents:
        miller = [
            hkl
            for ref in reflections
            for hkl in space_group.equivalent_hkls(ref.h, ref.k, ref.l)
        ]
    else:
        miller = [(ref.h, ref.k, ref.l) for ref in reflections]
    hkl_rows = np.asarray(miller, dtype=np.float64)

    lattice_constants = (
        lattice.a, lattice.b, lattice.c,
        lattice.alpha, lattice.beta, lattice.gamma,
    )
    g_cart = hkl_rows @ _cartesian_B_matrix(lattice_constants).T  # (M, 3) in 1/A

    sin_theta = np.linalg.norm(g_cart, axis=-1) * wavelength_A / 2.0
    beyond = sin_theta > 1.0
    if beyond.any():
        raise ValueError(
            f"{int(beyond.sum())} reflections fall outside the Bragg cutoff (|G|*lambda/2 > 1) "
            "for the requested cutoff -- tighten two_theta_max_deg / d_min."
        )
    theta = np.arcsin(sin_theta)

    return tuple(torch.tensor(arr, dtype=dtype) for arr in (g_cart, theta, hkl_rows))
|
midas_diffract/losses.py
ADDED
|
@@ -0,0 +1,494 @@
|
|
|
1
|
+
"""Loss functions and spot matching utilities for HEDM optimization.
|
|
2
|
+
|
|
3
|
+
Two output modes:
|
|
4
|
+
NF-HEDM: Image comparison losses (NCC, L2, log-ratio)
|
|
5
|
+
FF/pf-HEDM: Spot coordinate matching losses (L2, angular, Huber)
|
|
6
|
+
|
|
7
|
+
Also provides SpotAssigner for non-differentiable spot-to-spot matching
|
|
8
|
+
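used in the FF/pf optimization loop, plus differentiable stress/strain helpers (Voigt-Mandel conversion, Hooke's law, volume-average stress constraint).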
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import math
|
|
12
|
+
from typing import Optional, Tuple
|
|
13
|
+
|
|
14
|
+
import torch
|
|
15
|
+
import torch.nn as nn
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
# ---------------------------------------------------------------------------
|
|
19
|
+
# Image comparison losses (NF-HEDM)
|
|
20
|
+
# ---------------------------------------------------------------------------
|
|
21
|
+
|
|
22
|
+
class ImageComparisonLoss(nn.Module):
    """Loss for comparing predicted vs observed detector images.

    Used in NF-HEDM where the forward model produces full predicted images
    via Gaussian splatting and we compare to observed detector images.

    Parameters
    ----------
    mode : str
        ``"ncc"`` : Normalized Cross-Correlation (scale-invariant, recommended).
        ``"l2"`` : Mean Squared Error.
        ``"log_ratio"`` : Log-ratio loss (marginalizes unknown scale factor).

    Notes
    -----
    When a binary ``mask`` is passed to :meth:`forward`, masked-out pixels
    are excluded from every statistic. For ``"l2"`` the mean is taken over
    included pixels only. For ``"log_ratio"`` both the scale estimate ``mu``
    and the final mean use included pixels only. Without this, zeroed
    pixels would contribute ``log(eps) - log(eps) = 0`` terms that pull
    ``mu`` toward zero and then penalize it.
    """

    _MODES = ("ncc", "l2", "log_ratio")

    def __init__(self, mode: str = "ncc"):
        super().__init__()
        if mode not in self._MODES:
            raise ValueError(f"Unknown mode: {mode!r}")
        self.mode = mode

    def forward(
        self,
        pred: torch.Tensor,
        obs: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute image comparison loss.

        Parameters
        ----------
        pred : Tensor (..., H, W) or (..., F, H, W)
            Predicted images.
        obs : Tensor (same shape as pred)
            Observed images.
        mask : Tensor (same shape, or broadcastable to it), optional
            Binary mask. 1 = include pixel, 0 = ignore.

        Returns
        -------
        Scalar loss tensor.

        Raises
        ------
        ValueError
            If ``self.mode`` is not a supported mode.
        """
        weight = None
        if mask is not None:
            pred = pred * mask
            obs = obs * mask
            # Expand the mask to the full image shape so that its sum counts
            # the included pixels. This becomes the normalizer for the means.
            shape = torch.broadcast_shapes(pred.shape, obs.shape)
            weight = torch.broadcast_to(mask, shape).to(pred.dtype)

        if self.mode == "ncc":
            # Zeroed pixels add nothing to the dot product or the norms, so
            # the mask needs no further handling here.
            return self._ncc_loss(pred, obs)
        if self.mode == "l2":
            return self._l2_loss(pred, obs, weight)
        if self.mode == "log_ratio":
            return self._log_ratio_loss(pred, obs, weight=weight)
        raise ValueError(f"Unknown mode: {self.mode!r}")

    @staticmethod
    def _ncc_loss(pred: torch.Tensor, obs: torch.Tensor) -> torch.Tensor:
        """Normalized Cross-Correlation loss (1 - NCC).

        NCC = sum(pred * obs) / (||pred|| * ||obs||), computed over all
        elements flattened together. It is not mean-subtracted.
        Loss = 1 - NCC (so 0 = perfect match).
        """
        p_flat = pred.flatten()
        o_flat = obs.flatten()

        dot = torch.sum(p_flat * o_flat)
        # Clamp so that an all-zero image does not divide by zero.
        norm_p = torch.norm(p_flat).clamp(min=1e-12)
        norm_o = torch.norm(o_flat).clamp(min=1e-12)
        return 1.0 - dot / (norm_p * norm_o)

    @staticmethod
    def _l2_loss(
        pred: torch.Tensor,
        obs: torch.Tensor,
        weight: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Mean Squared Error loss.

        Without ``weight``, the mean is over all elements. With ``weight``
        (an already-applied binary mask), the sum of squared errors is
        divided by the number of included pixels. An all-zero mask gives 0.
        """
        sq = (pred - obs) ** 2
        if weight is None:
            return torch.mean(sq)
        # Masked entries are already zero in sq (pred and obs were masked).
        return sq.sum() / weight.sum().clamp(min=1e-12)

    @staticmethod
    def _log_ratio_loss(
        pred: torch.Tensor,
        obs: torch.Tensor,
        eps: float = 1e-6,
        weight: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Log-ratio loss: ||log(pred+eps) - log(obs+eps) - mu||^2.

        Analytically marginalizes out the unknown global scaling factor
        by subtracting the mean log-ratio (mu). When ``weight`` is given,
        both mu and the final mean are weighted averages, so masked-out
        pixels do not bias the scale estimate.
        """
        diff = torch.log(pred + eps) - torch.log(obs + eps)
        if weight is None:
            mu = torch.mean(diff)
            return torch.mean((diff - mu) ** 2)
        total = weight.sum().clamp(min=1e-12)
        mu = torch.sum(weight * diff) / total
        return torch.sum(weight * (diff - mu) ** 2) / total
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
# ---------------------------------------------------------------------------
|
|
117
|
+
# Spot coordinate matching losses (FF/pf-HEDM)
|
|
118
|
+
# ---------------------------------------------------------------------------
|
|
119
|
+
|
|
120
|
+
class SpotMatchingLoss(nn.Module):
    """Loss for matching predicted spot coordinates to observed spot COMs.

    Used in FF/pf-HEDM where the forward model predicts spot coordinates
    (2theta, eta, omega) and we compare to observed center-of-mass positions.

    The assignment of predicted-to-observed spots is done externally by
    ``SpotAssigner`` (non-differentiable). Given fixed assignments, this
    loss is fully differentiable w.r.t. predicted coordinates. When no spots
    are matched (``N_matched == 0``), the loss is a graph-connected zero
    rather than the NaN that ``torch.mean`` of an empty tensor would give.

    Parameters
    ----------
    metric : str
        ``"l2"`` : Sum of squared coordinate differences per spot.
        ``"huber"`` : Smooth L1 (robust to outliers), summed per spot.
        ``"angular"``: Currently the same squared difference as ``"l2"``.
        Angular differences are not wrapped.
    weights : Tensor (3,), optional
        Per-coordinate weights for [2theta, eta, omega]. They multiply the
        coordinate differences before the metric, for every metric.
        Default: no weighting.
    """

    def __init__(
        self,
        metric: str = "l2",
        weights: Optional[torch.Tensor] = None,
    ):
        super().__init__()
        if metric not in ("l2", "huber", "angular"):
            raise ValueError(f"Unknown metric: {metric!r}")
        self.metric = metric
        if weights is not None:
            # Registered as a buffer so it follows .to(device) with the module.
            self.register_buffer("weights", weights.float())
        else:
            self.weights = None

    def forward(
        self,
        pred_coords: torch.Tensor,
        obs_coords: torch.Tensor,
        spot_weights: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute spot matching loss.

        Parameters
        ----------
        pred_coords : Tensor (N_matched, 3)
            Predicted spot coordinates.
        obs_coords : Tensor (N_matched, 3)
            Observed spot coordinates (same order as pred).
        spot_weights : Tensor (N_matched,), optional
            Per-spot weights (e.g., intensity-based), multiplied into the
            per-spot loss before averaging.

        Returns
        -------
        Scalar loss tensor: the mean per-spot loss, or 0 if there are no spots.
        """
        diff = pred_coords - obs_coords

        if self.weights is not None:
            diff = diff * self.weights.unsqueeze(0)

        if self.metric == "huber":
            per_spot = torch.sum(
                torch.nn.functional.smooth_l1_loss(
                    diff, torch.zeros_like(diff), reduction="none"
                ),
                dim=-1,
            )
        else:
            # "l2" and "angular" share the squared-difference form.
            per_spot = torch.sum(diff ** 2, dim=-1)

        if spot_weights is not None:
            per_spot = per_spot * spot_weights

        if per_spot.numel() == 0:
            # SpotAssigner returns empty tensors when nothing matches. The
            # sum of an empty tensor is 0 and stays attached to the graph.
            return per_spot.sum()
        return torch.mean(per_spot)
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
# ---------------------------------------------------------------------------
|
|
200
|
+
# Spot assignment (non-differentiable)
|
|
201
|
+
# ---------------------------------------------------------------------------
|
|
202
|
+
|
|
203
|
+
class SpotAssigner:
    """Assign predicted spots to nearest observed spots.

    This is a non-differentiable operation used in the FF/pf-HEDM
    optimization loop: run periodically to update assignments, then
    use ``SpotMatchingLoss`` with fixed assignments for gradient steps.

    Matches by nearest neighbor (Euclidean ``torch.cdist``) in
    (2theta, eta, omega) space, optionally restricted to the same ring
    number (HKL family). Several predicted spots may map to the same
    observed spot.

    Parameters
    ----------
    obs_coords : Tensor (N_obs, 3)
        Observed spot coordinates (2theta, eta, omega) in radians.
    obs_ring_numbers : Tensor (N_obs,), optional
        Ring number for each observed spot. If provided, matching is
        restricted to same-ring spots only.
    """

    def __init__(
        self,
        obs_coords: torch.Tensor,
        obs_ring_numbers: Optional[torch.Tensor] = None,
    ):
        self.obs_coords = obs_coords
        self.obs_ring_numbers = obs_ring_numbers

    @torch.no_grad()
    def assign(
        self,
        pred_coords: torch.Tensor,
        pred_valid: torch.Tensor,
        pred_ring_numbers: Optional[torch.Tensor] = None,
        max_distance: float = 0.1,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Find nearest observed spot for each valid predicted spot.

        Parameters
        ----------
        pred_coords : Tensor (..., K, M, 3)
            Predicted spot coordinates from ``predict_spot_coords``.
        pred_valid : Tensor (..., K, M)
            Validity mask (entries > 0.5 are treated as valid).
        pred_ring_numbers : Tensor (M,), optional
            Ring number per HKL. If provided together with
            ``obs_ring_numbers``, matching is restricted to the same ring.
            Cross-ring pairs are never matched, whatever ``max_distance`` is.
        max_distance : float
            Maximum matching distance in radians. Pairs at or beyond this are
            rejected.

        Returns
        -------
        pred_matched : Tensor (N_matched, 3)
            Matched predicted coordinates (computed under ``no_grad``;
            index-aligned with obs_matched).
        obs_matched : Tensor (N_matched, 3)
            Matched observed coordinates.
        pred_indices : Tensor (N_matched,) of long
            Flat indices into ``pred_coords.reshape(-1, 3)`` (for gradient routing).

        All three are empty (N_matched == 0) when there are no valid
        predictions, no observed spots, or no pair within ``max_distance``.
        """
        flat_coords = pred_coords.reshape(-1, 3)
        flat_valid = pred_valid.reshape(-1)

        valid_idx = torch.nonzero(flat_valid > 0.5, as_tuple=False).squeeze(-1)
        # cdist/min would fail on a zero-length observed axis, so bail out early.
        if valid_idx.numel() == 0 or self.obs_coords.shape[0] == 0:
            return self._empty(flat_coords, self.obs_coords)

        valid_coords = flat_coords[valid_idx]  # (V, 3)
        dists = torch.cdist(valid_coords, self.obs_coords)  # (V, N_obs)

        if pred_ring_numbers is not None and self.obs_ring_numbers is not None:
            # Flattening (..., K, M) keeps M fastest, so per-HKL ring numbers
            # tile once per leading (…, K) slot.
            n_hkl = pred_ring_numbers.shape[0]
            n_sets = flat_valid.shape[0] // n_hkl if n_hkl > 0 else 0
            if n_sets > 0:
                valid_rings = pred_ring_numbers.repeat(n_sets)[valid_idx]  # (V,)
                ring_mismatch = (
                    valid_rings.unsqueeze(1) != self.obs_ring_numbers.unsqueeze(0)
                )
                # +inf (not a finite penalty) so that no max_distance can
                # admit a cross-ring match.
                dists = dists.masked_fill(ring_mismatch, float("inf"))

        min_dists, nn_idx = dists.min(dim=1)  # (V,), (V,)

        keep = min_dists < max_distance
        if not keep.any():
            return self._empty(flat_coords, self.obs_coords)

        return valid_coords[keep], self.obs_coords[nn_idx[keep]], valid_idx[keep]

    @staticmethod
    def _empty(
        pred_like: torch.Tensor, obs_like: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Build the empty (pred, obs, indices) triple in the inputs' dtype/device."""
        return (
            pred_like.new_zeros((0, 3)),
            obs_like.new_zeros((0, 3)),
            torch.zeros(0, dtype=torch.long, device=pred_like.device),
        )
|
|
306
|
+
|
|
307
|
+
|
|
308
|
+
# ---------------------------------------------------------------------------
|
|
309
|
+
# Differentiable stress/strain (PyTorch)
|
|
310
|
+
# ---------------------------------------------------------------------------
|
|
311
|
+
|
|
312
|
+
def tensor_to_voigt(T: torch.Tensor) -> torch.Tensor:
    """Pack a symmetric 3x3 tensor into a Voigt-Mandel 6-vector.

    The three normal components come first. They are followed by the shear
    components (yz, xz, xy), each scaled by sqrt(2) so that the 6-vector
    inner product matches the tensor double contraction. Built from
    ``torch.stack``, so gradients flow through.

    Parameters
    ----------
    T : Tensor (..., 3, 3)

    Returns
    -------
    Tensor (..., 6) -- [xx, yy, zz, sqrt2*yz, sqrt2*xz, sqrt2*xy]
    """
    root2 = math.sqrt(2.0)
    normal = [T[..., i, i] for i in range(3)]
    shear = [root2 * T[..., p, q] for p, q in ((1, 2), (0, 2), (0, 1))]
    return torch.stack(normal + shear, dim=-1)
|
|
330
|
+
|
|
331
|
+
|
|
332
|
+
def voigt_to_tensor(v: torch.Tensor) -> torch.Tensor:
    """Unpack a Voigt-Mandel 6-vector into a symmetric 3x3 tensor.

    This is the inverse of the ``[xx, yy, zz, sqrt2*yz, sqrt2*xz, sqrt2*xy]``
    packing: the shear entries are divided by sqrt(2) and mirrored across
    the diagonal. Built from ``torch.stack``, so gradients flow through.

    Parameters
    ----------
    v : Tensor (..., 6)

    Returns
    -------
    Tensor (..., 3, 3)
    """
    inv_root2 = 1.0 / math.sqrt(2.0)
    d0, d1, d2 = (v[..., k] for k in range(3))
    s12, s02, s01 = (v[..., k] * inv_root2 for k in range(3, 6))
    rows = (
        (d0, s01, s02),
        (s01, d1, s12),
        (s02, s12, d2),
    )
    return torch.stack([torch.stack(row, dim=-1) for row in rows], dim=-2)
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
def cubic_stiffness_tensor(
    C11: float, C12: float, C44: float,
    dtype: torch.dtype = torch.float64,
    device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Build the 6x6 cubic stiffness matrix in Voigt-Mandel notation.

    The upper-left 3x3 block has ``C11`` on the diagonal and ``C12``
    elsewhere. The shear block is diagonal with ``2 * C44``; the factor 2
    comes from the Mandel sqrt(2) scaling of shear components. All other
    entries are zero.

    Parameters
    ----------
    C11, C12, C44 : float
        Independent elastic constants in GPa.
    dtype, device
        Placement of the returned tensor.

    Returns
    -------
    Tensor (6, 6)
    """
    stiff = torch.zeros(6, 6, dtype=dtype, device=device)
    for i in range(3):
        stiff[i, i] = C11
        stiff[i + 3, i + 3] = 2.0 * C44  # Mandel convention
        for j in range(3):
            if j != i:
                stiff[i, j] = C12
    return stiff
|
|
377
|
+
|
|
378
|
+
|
|
379
|
+
def rotation_voigt_mandel(U: torch.Tensor) -> torch.Tensor:
    """Build the 6x6 Voigt-Mandel matrix of a 3x3 rotation.

    For a symmetric tensor ``E`` packed as ``[xx, yy, zz, sqrt2*yz, sqrt2*xz,
    sqrt2*xy]``, the returned matrix ``M`` satisfies
    ``M @ pack(E) == pack(U @ E @ U.T)``. In other words:
        {eps_rotated} = M @ {eps_original}
    Entries are products of ``U`` elements, so gradients flow through.

    Parameters
    ----------
    U : Tensor (..., 3, 3) rotation matrix

    Returns
    -------
    Tensor (..., 6, 6)
    """
    root2 = math.sqrt(2.0)
    shear = ((1, 2), (0, 2), (0, 1))

    out = torch.zeros(*U.shape[:-2], 6, 6, dtype=U.dtype, device=U.device)

    for row in range(3):
        # Normal -> normal.
        for col in range(3):
            out[..., row, col] = U[..., row, col] ** 2
        for k, (p, q) in enumerate(shear):
            # Shear -> normal (row is a normal component).
            out[..., row, 3 + k] = root2 * U[..., row, p] * U[..., row, q]
            # Normal -> shear (row doubles as the normal column index here).
            out[..., 3 + k, row] = root2 * U[..., p, row] * U[..., q, row]

    # Shear -> shear.
    for k, (r1, r2) in enumerate(shear):
        for m, (c1, c2) in enumerate(shear):
            out[..., 3 + k, 3 + m] = (
                U[..., r1, c1] * U[..., r2, c2]
                + U[..., r1, c2] * U[..., r2, c1]
            )

    return out
|
|
422
|
+
|
|
423
|
+
|
|
424
|
+
def hooke_stress(
    strain: torch.Tensor,
    stiffness: torch.Tensor,
    orient: Optional[torch.Tensor] = None,
    frame: str = "lab",
) -> torch.Tensor:
    """Differentiable Hooke's law: strain -> stress.

    Parameters
    ----------
    strain : Tensor (..., 3, 3) or (..., 6)
        Strain tensor (Voigt-Mandel or full 3x3). An input whose last two
        dims are both 3 is treated as a full tensor; anything else is
        treated as Voigt-Mandel.
    stiffness : Tensor (6, 6)
        Single-crystal stiffness in Voigt-Mandel notation, crystal frame.
    orient : Tensor (..., 3, 3), optional
        Orientation matrix. Required for ``frame="lab"``.
    frame : str
        ``"grain"``: strain and output in grain frame.
        ``"lab"``: strain in lab frame. The stiffness is rotated as
        ``C_lab = M^T C M`` with ``M = rotation_voigt_mandel(orient)``.

    Returns
    -------
    Tensor (..., 3, 3) stress tensor.

    Raises
    ------
    ValueError
        If ``frame`` is not ``"grain"`` or ``"lab"``, or if ``orient`` is
        missing for the lab frame.
    """
    # Reject unknown frames explicitly. Previously any value other than
    # "grain" (e.g. a typo such as "Lab") was silently treated as "lab".
    if frame not in ("grain", "lab"):
        raise ValueError(f"frame must be 'grain' or 'lab', got {frame!r}")

    if strain.shape[-1] == 3 and strain.shape[-2] == 3:
        eps_v = tensor_to_voigt(strain)
    else:
        eps_v = strain

    if frame == "grain":
        sig_v = eps_v @ stiffness.T
        return voigt_to_tensor(sig_v)

    if orient is None:
        raise ValueError("orient required for lab-frame computation")

    M = rotation_voigt_mandel(orient)  # (..., 6, 6)
    C_lab = M.transpose(-1, -2) @ stiffness @ M  # (..., 6, 6)
    sig_v = (C_lab @ eps_v.unsqueeze(-1)).squeeze(-1)
    return voigt_to_tensor(sig_v)
|
|
465
|
+
|
|
466
|
+
|
|
467
|
+
def volume_average_stress_constraint(
    stresses: torch.Tensor,
    volumes: torch.Tensor,
    applied_stress: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Shift grain stresses so that their volume-weighted mean equals a target.

    Enforces: sum(V_g * sigma_g) / V_total = sigma_applied (FF-1).
    It does this by adding the same correction ``target - weighted_mean`` to
    every grain. Only tensor arithmetic is used, so the operation is
    differentiable.

    Parameters
    ----------
    stresses : Tensor (N, 3, 3)
    volumes : Tensor (N,)
    applied_stress : Tensor (3, 3), optional
        Target average stress. Defaults to zero.

    Returns
    -------
    Tensor (N, 3, 3) corrected stresses.
    """
    target = (
        torch.zeros(3, 3, dtype=stresses.dtype, device=stresses.device)
        if applied_stress is None
        else applied_stress
    )

    fractions = volumes / volumes.sum()
    mean_stress = (fractions.unsqueeze(-1).unsqueeze(-1) * stresses).sum(dim=0)
    correction = target - mean_stress
    return stresses + correction.unsqueeze(0)
|