midas-diffract 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
midas_diffract/hkls.py ADDED
@@ -0,0 +1,180 @@
1
+ """Build forward-model reflection lists from ``midas-hkls`` outputs.
2
+
3
+ ``HEDMForwardModel`` consumes three tensors that are conventionally supplied
4
+ by ``GetHKLList`` (the C tool in MIDAS):
5
+
6
+ * ``hkls_int`` -- (M, 3) integer Miller indices, **one row per spot**
7
+ (i.e. all symmetry-equivalent variants of each ASU
8
+ representative are enumerated)
9
+ * ``hkls_cart`` -- (M, 3) reference Cartesian G-vectors in 1/Angstroms
10
+ * ``thetas`` -- (M,) reference Bragg angles in radians
11
+
12
+ This module produces the same triplet from the pure-Python ``midas-hkls``
13
+ package, so users do not need the MIDAS C build to drive the forward model.
14
+
15
+ Example
16
+ -------
17
+ from midas_hkls import SpaceGroup, Lattice
18
+ import midas_diffract as md
19
+
20
+ sg = SpaceGroup.from_number(225) # FCC (Cu/Au/Ni)
21
+ lat = Lattice.for_system("cubic", a=4.08)
22
+ hkls_cart, thetas, hkls_int = md.hkls_for_forward_model(
23
+ sg, lat, wavelength_A=0.172979, two_theta_max_deg=15.0,
24
+ )
25
+ model = md.HEDMForwardModel(
26
+ hkls=hkls_cart, thetas=thetas, geometry=geom, hkls_int=hkls_int,
27
+ )
28
+ """
29
+ from __future__ import annotations
30
+
31
+ from math import cos, pi, sin
32
+ from typing import TYPE_CHECKING, Optional, Tuple
33
+
34
+ import numpy as np
35
+ import torch
36
+
37
+ if TYPE_CHECKING:
38
+ from midas_hkls import Lattice, SpaceGroup
39
+
40
+ DEG2RAD = pi / 180.0
41
+
42
+
43
+ def _cartesian_B_matrix(latc: "tuple[float, float, float, float, float, float]") -> np.ndarray:
44
+ """Reference reciprocal-lattice basis in Cartesian coords (column = a*, b*, c*).
45
+
46
+ Mirrors the B-matrix convention in
47
+ :meth:`midas_diffract.forward.HEDMForwardModel.correct_hkls_latc`, which
48
+ in turn is the C convention from ``CorrectHKLsLatC`` in
49
+ ``FF_HEDM/src/FitPosOrStrainsDoubleDataset.c:214-252``. Keeping the
50
+ convention bit-aligned guarantees that ``hkls_cart = B @ hkls_int^T``
51
+ here matches the model's strain path, so passing ``lattice_params=`` at
52
+ forward time recomputes the same numbers up to floating-point error.
53
+ """
54
+ a, b, c, alpha_d, beta_d, gamma_d = latc
55
+ alpha = alpha_d * DEG2RAD
56
+ beta = beta_d * DEG2RAD
57
+ gamma = gamma_d * DEG2RAD
58
+ sin_a, cos_a = sin(alpha), cos(alpha)
59
+ sin_b, cos_b = sin(beta), cos(beta)
60
+ sin_g, cos_g = sin(gamma), cos(gamma)
61
+
62
+ eps = 1e-7
63
+ gamma_pr = np.arccos(np.clip(
64
+ (cos_a * cos_b - cos_g) / (sin_a * sin_b + eps), -1 + eps, 1 - eps,
65
+ ))
66
+ beta_pr = np.arccos(np.clip(
67
+ (cos_g * cos_a - cos_b) / (sin_g * sin_a + eps), -1 + eps, 1 - eps,
68
+ ))
69
+ sin_beta_pr = np.sin(beta_pr)
70
+
71
+ vol = a * b * c * sin_a * sin_beta_pr * sin_g
72
+ a_pr = b * c * sin_a / (vol + eps)
73
+ b_pr = c * a * sin_b / (vol + eps)
74
+ c_pr = a * b * sin_g / (vol + eps)
75
+
76
+ B = np.array([
77
+ [a_pr, b_pr * np.cos(gamma_pr), c_pr * np.cos(beta_pr)],
78
+ [0.0, b_pr * np.sin(gamma_pr), -c_pr * sin_beta_pr * cos_a],
79
+ [0.0, 0.0, c_pr * sin_beta_pr * sin_a],
80
+ ])
81
+ return B
82
+
83
+
84
def hkls_for_forward_model(
    space_group: "SpaceGroup",
    lattice: "Lattice",
    *,
    wavelength_A: float,
    two_theta_max_deg: Optional[float] = None,
    d_min: Optional[float] = None,
    expand_equivalents: bool = True,
    dtype: torch.dtype = torch.float64,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build (``hkls_cart``, ``thetas``, ``hkls_int``) for ``HEDMForwardModel``.

    Wraps :func:`midas_hkls.generate_hkls`, which returns ASU
    representatives, and (by default) expands each one through
    ``space_group.equivalent_hkls`` so that every detector spot is
    enumerated. Cartesian G-vectors are then computed with
    ``_cartesian_B_matrix``, whose convention matches the forward model's
    strain-recompute path. Bragg angles follow from
    ``sin(theta) = |G| * wavelength_A / 2``.

    Parameters
    ----------
    space_group, lattice
        From the ``midas-hkls`` package. ``lattice`` must expose
        ``a, b, c`` (Angstrom) and ``alpha, beta, gamma`` (degrees).
    wavelength_A : float
        X-ray wavelength in Angstroms; must be finite and positive.
    two_theta_max_deg, d_min
        Cutoff for reflection enumeration. At least one must be supplied.
        See :func:`midas_hkls.generate_hkls`.
    expand_equivalents : bool, default True
        If True, return one row per symmetry-equivalent reflection (matches
        ``GetHKLList`` output and is what the forward model expects). If
        False, return only ASU representatives, which is useful for diagnostics.
    dtype : torch.dtype
        Output tensor dtype. Defaults to float64.

    Returns
    -------
    hkls_cart : Tensor (M, 3)
        Cartesian reciprocal-space G-vectors in 1/Angstroms.
    thetas : Tensor (M,)
        Bragg angles in radians.
    hkls_int : Tensor (M, 3)
        Integer Miller indices (one row per spot), stored as floats so they
        can go through ``torch.matmul`` inside the model.

    Raises
    ------
    ValueError
        If neither cutoff is given, if ``wavelength_A`` is not a finite
        positive number, if ``generate_hkls`` returns no reflections, or if
        any reflection has ``|G| * wavelength_A / 2 > 1``.
    ImportError
        If the optional ``midas-hkls`` package is not installed.
    """
    # Check cheap arguments first so misuse is reported clearly, even when
    # the optional dependency is absent. Unless generate_hkls rejects it
    # itself, a non-positive wavelength would make ``s`` below non-positive
    # and yield meaningless non-positive "Bragg angles".
    if two_theta_max_deg is None and d_min is None:
        raise ValueError(
            "hkls_for_forward_model needs a reflection cutoff: pass "
            "two_theta_max_deg and/or d_min."
        )
    if not (np.isfinite(wavelength_A) and wavelength_A > 0):
        raise ValueError(
            f"wavelength_A must be a finite positive number, got {wavelength_A!r}"
        )

    try:
        from midas_hkls import generate_hkls  # type: ignore
    except ImportError as exc:
        raise ImportError(
            "midas_diffract.hkls requires the optional 'midas-hkls' package. "
            "Install with: pip install midas-hkls"
        ) from exc

    refs = generate_hkls(
        space_group,
        lattice,
        wavelength_A=wavelength_A,
        two_theta_max_deg=two_theta_max_deg,
        d_min=d_min,
    )
    if not refs:
        raise ValueError(
            "midas_hkls.generate_hkls returned no reflections; check "
            "wavelength / cutoff arguments."
        )

    if expand_equivalents:
        rows = [
            hkl
            for r in refs
            for hkl in space_group.equivalent_hkls(r.h, r.k, r.l)
        ]
    else:
        rows = [(r.h, r.k, r.l) for r in refs]
    hkls_int_np = np.asarray(rows, dtype=np.float64)

    B = _cartesian_B_matrix(
        (lattice.a, lattice.b, lattice.c,
         lattice.alpha, lattice.beta, lattice.gamma)
    )
    G_cart = hkls_int_np @ B.T  # (M, 3) Cartesian G in 1/A

    # Bragg's law: sin(theta) = |G| * lambda / 2 (|G| = 1/d, no 2*pi factor).
    s = np.linalg.norm(G_cart, axis=-1) * wavelength_A / 2.0
    if np.any(s > 1.0):
        bad = int(np.sum(s > 1.0))
        raise ValueError(
            f"{bad} reflections fall outside the Bragg cutoff (|G|*lambda/2 > 1) "
            "for the requested cutoff -- tighten two_theta_max_deg / d_min."
        )
    thetas_np = np.arcsin(s)

    return (
        torch.tensor(G_cart, dtype=dtype),
        torch.tensor(thetas_np, dtype=dtype),
        torch.tensor(hkls_int_np, dtype=dtype),
    )
@@ -0,0 +1,494 @@
1
+ """Loss functions and spot matching utilities for HEDM optimization.
2
+
3
+ Two output modes:
4
+ NF-HEDM: Image comparison losses (NCC, L2, log-ratio)
5
+ FF/pf-HEDM: Spot coordinate matching losses (L2, angular, Huber)
6
+
7
+ Also provides SpotAssigner for non-differentiable spot-to-spot matching
8
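+ used in the FF/pf optimization loop, plus differentiable stress/strain helpers (Voigt-Mandel conversions, cubic stiffness, Hooke's law).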
9
+ """
10
+
11
+ import math
12
+ from typing import Optional, Tuple
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+
17
+
18
+ # ---------------------------------------------------------------------------
19
+ # Image comparison losses (NF-HEDM)
20
+ # ---------------------------------------------------------------------------
21
+
22
class ImageComparisonLoss(nn.Module):
    """Loss for comparing predicted vs observed detector images.

    Used in NF-HEDM, where the forward model produces full predicted images
    via Gaussian splatting that are compared to observed detector images.

    Parameters
    ----------
    mode : str
        ``"ncc"`` : 1 - normalised cross-correlation (scale-invariant,
        recommended). Computed globally over all elements and not
        mean-centred, i.e. 1 - cosine similarity.
        ``"l2"`` : Mean squared error over all elements.
        ``"log_ratio"`` : Log-ratio loss (marginalises an unknown global
        scale factor).
    """

    def __init__(self, mode: str = "ncc"):
        super().__init__()
        if mode not in ("ncc", "l2", "log_ratio"):
            raise ValueError(f"Unknown mode: {mode!r}")
        self.mode = mode

    def forward(
        self,
        pred: torch.Tensor,
        obs: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute the image comparison loss.

        Parameters
        ----------
        pred : Tensor (..., H, W) or (..., F, H, W)
            Predicted images.
        obs : Tensor (same shape as pred)
            Observed images.
        mask : Tensor (same shape), optional
            Binary mask. 1 = include pixel, 0 = ignore. Masked pixels are
            zeroed in both images. For ``"ncc"`` this removes them entirely.
            For ``"l2"`` they add zero error but still count in the mean's
            denominator. For ``"log_ratio"`` they are excluded from both the
            scale estimate and the loss.

        Returns
        -------
        Scalar loss tensor.
        """
        if mask is not None:
            pred = pred * mask
            obs = obs * mask

        if self.mode == "ncc":
            return self._ncc_loss(pred, obs)
        if self.mode == "l2":
            return self._l2_loss(pred, obs)
        if self.mode == "log_ratio":
            # The log-ratio needs the mask itself: zeroed pixels would
            # otherwise enter the mean log-ratio and bias the scale estimate.
            return self._log_ratio_loss(pred, obs, mask=mask)
        raise ValueError(f"Unknown mode: {self.mode!r}")

    @staticmethod
    def _ncc_loss(pred: torch.Tensor, obs: torch.Tensor) -> torch.Tensor:
        """Normalised cross-correlation loss (1 - NCC), over all elements.

        NCC = sum(pred * obs) / (||pred|| * ||obs||)
        Loss = 1 - NCC (so 0 = perfect match up to a positive scale)
        """
        p_flat = pred.flatten()
        o_flat = obs.flatten()

        dot = torch.sum(p_flat * o_flat)
        # Clamp the norms so an all-zero image does not divide by zero.
        norm_p = torch.norm(p_flat).clamp(min=1e-12)
        norm_o = torch.norm(o_flat).clamp(min=1e-12)
        ncc = dot / (norm_p * norm_o)
        return 1.0 - ncc

    @staticmethod
    def _l2_loss(pred: torch.Tensor, obs: torch.Tensor) -> torch.Tensor:
        """Mean squared error loss."""
        return torch.mean((pred - obs) ** 2)

    @staticmethod
    def _log_ratio_loss(
        pred: torch.Tensor,
        obs: torch.Tensor,
        eps: float = 1e-6,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Log-ratio loss: mean of (log(pred+eps) - log(obs+eps) - mu)^2.

        Analytically marginalises an unknown global scale factor by
        subtracting the mean log-ratio ``mu``. When ``mask`` is given, both
        ``mu`` and the final mean are mask-weighted averages, so ignored
        pixels influence neither.
        """
        diff = torch.log(pred + eps) - torch.log(obs + eps)
        if mask is None:
            mu = torch.mean(diff)
            return torch.mean((diff - mu) ** 2)

        diff, w = torch.broadcast_tensors(diff, mask.to(diff.dtype))
        # Clamp so a fully-masked image yields 0 instead of NaN.
        w_sum = w.sum().clamp(min=1e-12)
        mu = (w * diff).sum() / w_sum
        return (w * (diff - mu) ** 2).sum() / w_sum
114
+
115
+
116
+ # ---------------------------------------------------------------------------
117
+ # Spot coordinate matching losses (FF/pf-HEDM)
118
+ # ---------------------------------------------------------------------------
119
+
120
class SpotMatchingLoss(nn.Module):
    """Loss for matching predicted spot coordinates to observed spot COMs.

    Used in FF/pf-HEDM where the forward model predicts spot coordinates
    (2theta, eta, omega) and these are compared to observed centre-of-mass
    positions.

    The assignment of predicted-to-observed spots is done externally by
    ``SpotAssigner`` (non-differentiable). Given fixed assignments, this
    loss is fully differentiable w.r.t. predicted coordinates.

    Parameters
    ----------
    metric : str
        ``"l2"`` : squared Euclidean distance (sum of squared differences).
        ``"huber"`` : per-coordinate smooth L1 (robust to outliers).
        ``"angular"``: like ``"l2"``, but each coordinate difference is first
        wrapped into (-pi, pi], so angles on either side of the +/-pi seam
        (e.g. eta) are compared by their true angular separation.
    weights : Tensor (3,), optional
        Per-coordinate weights for [2theta, eta, omega], multiplied onto the
        (wrapped, for ``"angular"``) differences. Default: unweighted.
    huber_beta : float, default 1.0
        Transition point of the smooth-L1 ``"huber"`` metric, in the units
        of the coordinates (radians). Passed as ``beta`` to
        ``torch.nn.functional.smooth_l1_loss``.
    """

    def __init__(
        self,
        metric: str = "l2",
        weights: Optional[torch.Tensor] = None,
        huber_beta: float = 1.0,
    ):
        super().__init__()
        if metric not in ("l2", "huber", "angular"):
            raise ValueError(f"Unknown metric: {metric!r}")
        if huber_beta < 0:
            raise ValueError(f"huber_beta must be non-negative, got {huber_beta!r}")
        self.metric = metric
        self.huber_beta = float(huber_beta)
        if weights is not None:
            self.register_buffer("weights", weights.float())
        else:
            self.weights = None

    def forward(
        self,
        pred_coords: torch.Tensor,
        obs_coords: torch.Tensor,
        spot_weights: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute the spot matching loss.

        Parameters
        ----------
        pred_coords : Tensor (N_matched, 3)
            Predicted spot coordinates.
        obs_coords : Tensor (N_matched, 3)
            Observed spot coordinates (same order as pred).
        spot_weights : Tensor (N_matched,), optional
            Per-spot weights (e.g., intensity-based).

        Returns
        -------
        Scalar loss tensor: the mean per-spot loss, or 0 (still attached to
        the graph) when there are no matched spots.
        """
        diff = pred_coords - obs_coords

        if self.metric == "angular":
            # atan2(sin, cos) wraps into (-pi, pi] with unit gradient and,
            # unlike a modulo shift by pi, keeps full precision for small diffs.
            diff = torch.atan2(torch.sin(diff), torch.cos(diff))

        if self.weights is not None:
            diff = diff * self.weights.unsqueeze(0)

        if self.metric == "huber":
            per_spot = torch.sum(
                torch.nn.functional.smooth_l1_loss(
                    diff,
                    torch.zeros_like(diff),
                    reduction="none",
                    beta=self.huber_beta,
                ),
                dim=-1,
            )
        else:  # "l2" and "angular"
            per_spot = torch.sum(diff ** 2, dim=-1)

        if spot_weights is not None:
            per_spot = per_spot * spot_weights

        if per_spot.numel() == 0:
            # The mean of an empty tensor is NaN; an empty sum is 0.
            return per_spot.sum()
        return torch.mean(per_spot)
197
+
198
+
199
+ # ---------------------------------------------------------------------------
200
+ # Spot assignment (non-differentiable)
201
+ # ---------------------------------------------------------------------------
202
+
203
class SpotAssigner:
    """Assign predicted spots to nearest observed spots.

    This is a non-differentiable operation used in the FF/pf-HEDM
    optimization loop: run it periodically to update assignments, then
    use ``SpotMatchingLoss`` with fixed assignments for gradient steps.

    Matches by nearest neighbour (Euclidean, via ``torch.cdist``) in
    (2theta, eta, omega) space, optionally restricted to the same ring
    number (HKL family). Angular wrap-around is not taken into account.

    Parameters
    ----------
    obs_coords : Tensor (N_obs, 3)
        Observed spot coordinates (2theta, eta, omega) in radians.
    obs_ring_numbers : Tensor (N_obs,), optional
        Ring number for each observed spot. If provided, matching is
        restricted to same-ring spots only.
    """

    def __init__(
        self,
        obs_coords: torch.Tensor,
        obs_ring_numbers: Optional[torch.Tensor] = None,
    ):
        self.obs_coords = obs_coords
        self.obs_ring_numbers = obs_ring_numbers

    @staticmethod
    def _empty(
        like: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return the empty (pred, obs, index) triple, matching ``like``'s dtype/device."""
        empty = like.new_zeros((0, 3))
        return empty, empty, torch.zeros(0, dtype=torch.long, device=like.device)

    @torch.no_grad()
    def assign(
        self,
        pred_coords: torch.Tensor,
        pred_valid: torch.Tensor,
        pred_ring_numbers: Optional[torch.Tensor] = None,
        max_distance: float = 0.1,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Find the nearest observed spot for each valid predicted spot.

        Parameters
        ----------
        pred_coords : Tensor (..., K, M, 3)
            Predicted spot coordinates from ``predict_spot_coords``.
        pred_valid : Tensor (..., K, M)
            Validity mask; entries > 0.5 count as valid.
        pred_ring_numbers : Tensor (M,), optional
            Ring number per HKL (the last spot axis). If provided together
            with ``obs_ring_numbers``, matching is restricted to the same ring.
        max_distance : float
            Maximum matching distance in radians. Pairs at or beyond this are
            rejected. Cross-ring pairs are always rejected, even for
            ``max_distance=float("inf")``.

        Returns
        -------
        pred_matched : Tensor (N_matched, 3)
            Matched predicted coordinates (computed under ``no_grad``,
            index-aligned with ``obs_matched``).
        obs_matched : Tensor (N_matched, 3)
            Matched observed coordinates.
        pred_indices : Tensor (N_matched,) of long
            Indices into the flattened ``pred_coords.reshape(-1, 3)``
            (for gradient routing).

        Raises
        ------
        ValueError
            If ``pred_ring_numbers`` cannot tile the flattened spot axis
            (its length must divide the number of predicted spots).
        """
        flat_coords = pred_coords.reshape(-1, 3)
        flat_valid = pred_valid.reshape(-1)

        valid_idx = torch.nonzero(flat_valid > 0.5, as_tuple=False).squeeze(-1)
        if valid_idx.numel() == 0 or self.obs_coords.shape[0] == 0:
            # Nothing to match (min over an empty obs axis would also raise).
            return self._empty(flat_coords)

        valid_coords = flat_coords[valid_idx]  # (V, 3)
        dists = torch.cdist(valid_coords, self.obs_coords)  # (V, N_obs)

        if pred_ring_numbers is not None and self.obs_ring_numbers is not None:
            n_hkl = pred_ring_numbers.shape[0]
            if n_hkl == 0 or flat_valid.shape[0] % n_hkl != 0:
                raise ValueError(
                    f"pred_ring_numbers has length {n_hkl}, which does not tile "
                    f"{flat_valid.shape[0]} predicted spots (expected the last "
                    "spot axis of pred_coords)."
                )
            # Flat index = k * M + m, so the HKL of each valid spot is idx % M.
            valid_rings = pred_ring_numbers[valid_idx % n_hkl]  # (V,)
            ring_mismatch = (
                valid_rings.unsqueeze(1) != self.obs_ring_numbers.unsqueeze(0)
            )
            # inf (not a finite penalty) so no max_distance can accept them.
            dists = dists.masked_fill(ring_mismatch, float("inf"))

        min_dists, nn_idx = dists.min(dim=1)  # (V,), (V,)

        keep = min_dists < max_distance
        if not keep.any():
            return self._empty(flat_coords)

        pred_matched = valid_coords[keep]
        obs_matched = self.obs_coords[nn_idx[keep]]
        pred_indices = valid_idx[keep]
        return pred_matched, obs_matched, pred_indices
306
+
307
+
308
+ # ---------------------------------------------------------------------------
309
+ # Differentiable stress/strain (PyTorch)
310
+ # ---------------------------------------------------------------------------
311
+
312
def tensor_to_voigt(T: torch.Tensor) -> torch.Tensor:
    """Pack a symmetric 3x3 tensor into a Voigt-Mandel 6-vector.

    Shear components are scaled by sqrt(2) (Mandel convention), so the
    Euclidean norm of the 6-vector equals the Frobenius norm of ``T``.
    Built from indexing, scalar multiplication and ``torch.stack``, so
    gradients flow through.

    Parameters
    ----------
    T : Tensor (..., 3, 3)
        Only the upper triangle is read.

    Returns
    -------
    Tensor (..., 6) -- [xx, yy, zz, sqrt2*yz, sqrt2*xz, sqrt2*xy]
    """
    root2 = math.sqrt(2.0)
    normal = [T[..., k, k] for k in range(3)]
    shear = [root2 * T[..., i, j] for i, j in ((1, 2), (0, 2), (0, 1))]
    return torch.stack(normal + shear, dim=-1)
330
+
331
+
332
def voigt_to_tensor(v: torch.Tensor) -> torch.Tensor:
    """Unpack a Voigt-Mandel 6-vector into a symmetric 3x3 tensor.

    Inverse of the Mandel packing ``[xx, yy, zz, sqrt2*yz, sqrt2*xz,
    sqrt2*xy]``: the three shear entries are divided by sqrt(2) and mirrored
    across the diagonal. Built from indexing and ``torch.stack``, so
    gradients flow through.

    Parameters
    ----------
    v : Tensor (..., 6)

    Returns
    -------
    Tensor (..., 3, 3)
    """
    inv_root2 = 1.0 / math.sqrt(2.0)
    xx, yy, zz = v[..., 0], v[..., 1], v[..., 2]
    yz, xz, xy = (v[..., k] * inv_root2 for k in (3, 4, 5))
    layout = (
        (xx, xy, xz),
        (xy, yy, yz),
        (xz, yz, zz),
    )
    return torch.stack([torch.stack(row, dim=-1) for row in layout], dim=-2)
354
+
355
+
356
def cubic_stiffness_tensor(
    C11: float, C12: float, C44: float,
    dtype: torch.dtype = torch.float64,
    device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Return the 6x6 cubic-crystal stiffness matrix in Voigt-Mandel form.

    Layout: C11 on the normal diagonal, C12 on the off-diagonal entries of
    the upper-left 3x3 block, and 2*C44 on the shear diagonal (the factor 2
    comes from the Mandel sqrt(2) scaling of shear components). All other
    entries are zero.

    Parameters
    ----------
    C11, C12, C44 : float
        Independent elastic constants in GPa.
    dtype, device
        Placement of the returned tensor.

    Returns
    -------
    Tensor (6, 6)
    """
    stiffness = torch.zeros(6, 6, dtype=dtype, device=device)
    stiffness[:3, :3] = C12  # normal-normal coupling; diagonal overwritten below
    for k in range(3):
        stiffness[k, k] = C11
        stiffness[k + 3, k + 3] = 2.0 * C44  # Mandel convention
    return stiffness
377
+
378
+
379
def rotation_voigt_mandel(U: torch.Tensor) -> torch.Tensor:
    """Build the 6x6 Voigt-Mandel representation of a 3x3 rotation.

    For a symmetric tensor ``E`` and ``E_rot = U @ E @ U^T``, the returned
    matrix ``M`` satisfies ``mandel(E_rot) = M @ mandel(E)``, where
    ``mandel`` orders components as ``[xx, yy, zz, sqrt2*yz, sqrt2*xz,
    sqrt2*xy]``. Every entry is a polynomial in the entries of ``U``, so the
    result is differentiable w.r.t. ``U``.

    Parameters
    ----------
    U : Tensor (..., 3, 3)
        Rotation matrix (or a batch of them).

    Returns
    -------
    Tensor (..., 6, 6)
        Same dtype and device as ``U``.
    """
    root2 = math.sqrt(2.0)
    shear_pairs = ((1, 2), (0, 2), (0, 1))
    out = U.new_zeros(tuple(U.shape[:-2]) + (6, 6))

    for row in range(6):
        for col in range(6):
            if row < 3 and col < 3:
                # normal component <- normal component
                entry = U[..., row, col] ** 2
            elif row < 3:
                # normal component <- shear component
                p, q = shear_pairs[col - 3]
                entry = root2 * U[..., row, p] * U[..., row, q]
            elif col < 3:
                # shear component <- normal component
                p, q = shear_pairs[row - 3]
                entry = root2 * U[..., p, col] * U[..., q, col]
            else:
                # shear component <- shear component
                r1, r2 = shear_pairs[row - 3]
                c1, c2 = shear_pairs[col - 3]
                entry = (
                    U[..., r1, c1] * U[..., r2, c2]
                    + U[..., r1, c2] * U[..., r2, c1]
                )
            out[..., row, col] = entry

    return out
422
+
423
+
424
def hooke_stress(
    strain: torch.Tensor,
    stiffness: torch.Tensor,
    orient: Optional[torch.Tensor] = None,
    frame: str = "lab",
) -> torch.Tensor:
    """Differentiable Hooke's law: strain -> stress.

    Parameters
    ----------
    strain : Tensor (..., 3, 3) or (..., 6)
        Strain tensor, either full 3x3 or a Voigt-Mandel 6-vector.
    stiffness : Tensor (6, 6)
        Single-crystal stiffness in Voigt-Mandel notation, crystal frame.
        It is cast to the promoted dtype of the inputs and moved to the
        strain's device. For example, the float64 default output of
        ``cubic_stiffness_tensor`` therefore works with float32 strains.
    orient : Tensor (..., 3, 3), optional
        Orientation matrix. Required for ``frame="lab"``. The lab branch
        evaluates ``sigma = M^T C M eps`` with
        ``M = rotation_voigt_mandel(orient)``. Given that helper's definition
        (``mandel(U E U^T) = M mandel(E)``), the crystal-frame strain is
        ``orient @ eps_lab @ orient^T``.
        NOTE(review): confirm this matches the orientation-matrix convention
        used by the forward model.
    frame : {"grain", "lab"}
        ``"grain"``: strain and output in grain frame.
        ``"lab"``: strain in lab frame; transform, apply C, transform back.

    Returns
    -------
    Tensor (..., 3, 3)
        Stress tensor, in the same frame as the input strain.

    Raises
    ------
    ValueError
        If ``frame`` is not ``"grain"`` or ``"lab"``, or if ``orient`` is
        missing for ``frame="lab"``.
    """
    if frame not in ("grain", "lab"):
        raise ValueError(f"Unknown frame: {frame!r} (expected 'grain' or 'lab')")

    if strain.ndim >= 2 and strain.shape[-1] == 3 and strain.shape[-2] == 3:
        eps_v = tensor_to_voigt(strain)
    else:
        eps_v = strain

    if frame == "grain":
        # matmul rejects mixed dtypes, so promote both operands first.
        dt = torch.promote_types(eps_v.dtype, stiffness.dtype)
        C = stiffness.to(device=eps_v.device, dtype=dt)
        sig_v = eps_v.to(dt) @ C.T
        return voigt_to_tensor(sig_v)

    if orient is None:
        raise ValueError("orient required for lab-frame computation")

    dt = torch.promote_types(
        torch.promote_types(eps_v.dtype, stiffness.dtype), orient.dtype
    )
    eps_v = eps_v.to(dt)
    C = stiffness.to(device=eps_v.device, dtype=dt)

    M = rotation_voigt_mandel(orient.to(dt))  # (..., 6, 6)
    Mt = M.transpose(-1, -2)
    C_lab = Mt @ C @ M  # (..., 6, 6)
    sig_v = (C_lab @ eps_v.unsqueeze(-1)).squeeze(-1)
    return voigt_to_tensor(sig_v)
465
+
466
+
467
def volume_average_stress_constraint(
    stresses: torch.Tensor,
    volumes: torch.Tensor,
    applied_stress: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Shift grain stresses so their volume-weighted mean equals a target (FF-1).

    Enforces ``sum(V_g * sigma_g) / V_total = sigma_applied`` by adding the
    same correction ``sigma_applied - <sigma>_V`` to every grain. Built only
    from differentiable tensor arithmetic.

    Parameters
    ----------
    stresses : Tensor (N, 3, 3)
        Per-grain stress tensors.
    volumes : Tensor (N,)
        Per-grain volumes (any consistent unit).
    applied_stress : Tensor (3, 3), optional
        Target macroscopic stress. Defaults to zero (traction-free sample).

    Returns
    -------
    Tensor (N, 3, 3)
        Corrected stresses.
    """
    target = (
        torch.zeros(3, 3, dtype=stresses.dtype, device=stresses.device)
        if applied_stress is None
        else applied_stress
    )
    fractions = volumes / volumes.sum()
    mean_stress = torch.sum(fractions[:, None, None] * stresses, dim=0)
    correction = target - mean_stress
    return stresses + correction.unsqueeze(0)