midas-transforms 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- midas_transforms/__init__.py +46 -0
- midas_transforms/__main__.py +6 -0
- midas_transforms/bin_data/__init__.py +10 -0
- midas_transforms/bin_data/core.py +448 -0
- midas_transforms/cli.py +175 -0
- midas_transforms/device.py +65 -0
- midas_transforms/fit_setup/__init__.py +10 -0
- midas_transforms/fit_setup/core.py +422 -0
- midas_transforms/fit_setup/refine.py +121 -0
- midas_transforms/fit_setup/transform.py +244 -0
- midas_transforms/io/__init__.py +4 -0
- midas_transforms/io/binary.py +59 -0
- midas_transforms/io/csv.py +138 -0
- midas_transforms/io/zarr_io.py +87 -0
- midas_transforms/merge/__init__.py +10 -0
- midas_transforms/merge/core.py +525 -0
- midas_transforms/params.py +493 -0
- midas_transforms/pipeline.py +197 -0
- midas_transforms/radius/__init__.py +10 -0
- midas_transforms/radius/core.py +350 -0
- midas_transforms-0.1.0.dist-info/METADATA +137 -0
- midas_transforms-0.1.0.dist-info/RECORD +25 -0
- midas_transforms-0.1.0.dist-info/WHEEL +5 -0
- midas_transforms-0.1.0.dist-info/entry_points.txt +6 -0
- midas_transforms-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
"""midas-transforms: Pure-Python/PyTorch FF-HEDM transforms.
|
|
2
|
+
|
|
3
|
+
Drop-in replacement for the four C binaries that sit between peak-fitting
|
|
4
|
+
(``midas-peakfit``) and indexing (``midas-index``):
|
|
5
|
+
|
|
6
|
+
- ``MergeOverlappingPeaksAllZarr`` -> ``midas_transforms.merge_overlapping_peaks``
|
|
7
|
+
- ``CalcRadiusAllZarr`` -> ``midas_transforms.calc_radius``
|
|
8
|
+
- ``FitSetupZarr`` -> ``midas_transforms.fit_setup``
|
|
9
|
+
- ``SaveBinData`` -> ``midas_transforms.bin_data``
|
|
10
|
+
|
|
11
|
+
Two equally-supported usage modes:
|
|
12
|
+
|
|
13
|
+
**Mode 1 - per-stage (round-trips through disk, like the C binaries):**
|
|
14
|
+
|
|
15
|
+
from midas_transforms import merge_overlapping_peaks, calc_radius, fit_setup, bin_data
|
|
16
|
+
merge_overlapping_peaks(zarr_path="...", result_folder="...", device="cuda")
|
|
17
|
+
calc_radius(result_folder="...", device="cuda")
|
|
18
|
+
fit_setup(result_folder="...", device="cuda")
|
|
19
|
+
bin_data(result_folder="...", device="cuda")
|
|
20
|
+
|
|
21
|
+
**Mode 2 - chained Pipeline (intermediates stay on GPU, only final outputs written):**
|
|
22
|
+
|
|
23
|
+
from midas_transforms import Pipeline
|
|
24
|
+
pipe = Pipeline.from_zarr(zarr_path, device="cuda")
|
|
25
|
+
result = pipe.run()
|
|
26
|
+
pipe.dump(out_dir)
|
|
27
|
+
|
|
28
|
+
See ``dev/implementation_plan.md`` for design and roadmap.
|
|
29
|
+
"""
|
|
30
|
+
|
|
31
|
+
__version__ = "0.1.0"
|
|
32
|
+
|
|
33
|
+
from .merge import merge_overlapping_peaks
|
|
34
|
+
from .radius import calc_radius
|
|
35
|
+
from .fit_setup import fit_setup
|
|
36
|
+
from .bin_data import bin_data
|
|
37
|
+
from .pipeline import Pipeline
|
|
38
|
+
|
|
39
|
+
__all__ = [
|
|
40
|
+
"merge_overlapping_peaks",
|
|
41
|
+
"calc_radius",
|
|
42
|
+
"fit_setup",
|
|
43
|
+
"bin_data",
|
|
44
|
+
"Pipeline",
|
|
45
|
+
"__version__",
|
|
46
|
+
]
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
"""bin_data — replaces the C ``SaveBinData`` binary.
|
|
2
|
+
|
|
3
|
+
Reads ``InputAll.csv``, ``InputAllExtraInfoFittingAll.csv``, ``paramstest.txt``;
|
|
4
|
+
writes ``Spots.bin``, ``ExtraInfo.bin``, and (unless ``NoSaveAll==1``)
|
|
5
|
+
``Data.bin`` + ``nData.bin``.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from .core import bin_data, BinDataResult
|
|
9
|
+
|
|
10
|
+
__all__ = ["bin_data", "BinDataResult"]
|
|
@@ -0,0 +1,448 @@
|
|
|
1
|
+
"""bin_data: drop-in replacement for ``SaveBinData``.
|
|
2
|
+
|
|
3
|
+
The C source is `FF_HEDM/src/SaveBinData.c` (341 LoC). The torch port:
|
|
4
|
+
|
|
5
|
+
1. Reads ``InputAll.csv`` (8 cols) and ``InputAllExtraInfoFittingAll.csv``
|
|
6
|
+
(18 cols) into tensors on ``device``.
|
|
7
|
+
2. Computes the per-spot ``RadiusDistIdeal = radius_obs - ring_radii[ring_nr]``.
|
|
8
|
+
3. Writes ``Spots.bin`` (Nx9 float64) and ``ExtraInfo.bin`` (Nx16 float64).
|
|
9
|
+
4. If ``NoSaveAll == 0``: builds the per-(ring, eta-bin, ome-bin) lookup
|
|
10
|
+
table and writes ``Data.bin`` (int32 ragged) and ``nData.bin`` (count/offset
|
|
11
|
+
pairs, int32).
|
|
12
|
+
|
|
13
|
+
The (eta, ome) bin assignment per spot uses the C margin formulae
|
|
14
|
+
(``SaveBinData.c:265-271``):
|
|
15
|
+
|
|
16
|
+
omemargin = MarginOme + 0.5 * StepSizeOrient / |sin(eta_deg)|
|
|
17
|
+
etamargin = rad2deg * atan(MarginEta / RingRadii[ring]) + 0.5 * StepSizeOrient
|
|
18
|
+
|
|
19
|
+
Then for each spot, all bins in ``[iEtaMin..iEtaMax] × [iOmeMin..iOmeMax]``
|
|
20
|
+
mod ``n_eta``, ``n_ome`` receive the spot's index.
|
|
21
|
+
|
|
22
|
+
The vectorised torch path emits one ``(spot_idx, ring, iEta, iOme)`` tuple
|
|
23
|
+
per (spot, eta-bin, ome-bin) triple, sorts by ``(ring, iEta, iOme, spot_idx)``,
|
|
24
|
+
then a ``scatter_add_`` count and a ``cumsum`` recover counts and offsets per bin. The
|
|
25
|
+
``spot_idx`` secondary key makes the output bit-stable across runs and
|
|
26
|
+
between CPU and GPU (the C version's order is implicitly the spot
|
|
27
|
+
iteration order, which is row order in InputAll.csv == ascending
|
|
28
|
+
``spot_idx`` for non-empty bins).
|
|
29
|
+
"""
|
|
30
|
+
|
|
31
|
+
from __future__ import annotations
|
|
32
|
+
|
|
33
|
+
import math
|
|
34
|
+
from dataclasses import dataclass, field
|
|
35
|
+
from pathlib import Path
|
|
36
|
+
from typing import Optional, Union
|
|
37
|
+
|
|
38
|
+
import numpy as np
|
|
39
|
+
import torch
|
|
40
|
+
|
|
41
|
+
from ..device import resolve_device, resolve_dtype
|
|
42
|
+
from ..io import binary as bio
|
|
43
|
+
from ..io import csv as csv_io
|
|
44
|
+
from ..params import ParamsTest, read_paramstest
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@dataclass
class BinDataResult:
    """Everything the bin_data stage produces, kept in memory.

    ``bin_data`` returns one of these whether or not it also wrote the
    ``*.bin`` files, so a chained consumer (e.g. ``Pipeline``) can pick the
    tensors up directly instead of re-reading them from disk.
    """

    # --- always populated -------------------------------------------------
    # (N, 9): the eight InputAll columns followed by RadiusDistIdeal.
    spots: torch.Tensor
    # (N, 16): the ExtraInfo columns after the two dummy CSV columns are dropped.
    extra_info: torch.Tensor

    # --- lookup table; left at the defaults when NoSaveAll == 1 -----------
    # (T,) int32 spot indices, grouped by bin.
    data: Optional[torch.Tensor] = None
    # (M, 2) int32 (count, offset) rows, M = n_ring_bins * n_eta_bins * n_ome_bins.
    ndata: Optional[torch.Tensor] = None
    n_ring_bins: int = 0
    n_eta_bins: int = 0
    n_ome_bins: int = 0

    # Parameters the stage ran with.
    paramstest: Optional[ParamsTest] = None
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
# ---------------------------------------------------------------------------
|
|
66
|
+
# Pure-tensor kernels
|
|
67
|
+
# ---------------------------------------------------------------------------
|
|
68
|
+
|
|
69
|
+
_DEG2RAD = math.pi / 180.0
|
|
70
|
+
_RAD2DEG = 180.0 / math.pi
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
# ---- libm.fma binding (for byte-exact y² + z² matching the C build) -----
|
|
74
|
+
|
|
75
|
+
def _resolve_libm_fma():
    """Return a vectorised libm ``fma`` if available, else ``None``.

    Used to reproduce clang's default ``-ffp-contract=on`` FMA fusion of
    ``y*y + z*z`` so that ``Spots.bin`` col 8 (RadiusDistIdeal) is byte-exact
    against the C ``SaveBinData`` output. Callers fall back to plain
    ``y*y + z*z`` when this returns ``None``.

    Several library names are tried in turn: whatever
    ``ctypes.util.find_library("m")`` reports, then the conventional glibc
    and macOS names. ``find_library`` can return ``None`` on slim Linux
    images (it shells out to ``ldconfig``/``gcc``), and a macOS-only fallback
    name would then silently disable the FMA path there.

    Returns
    -------
    numpy.vectorize or None
        Element-wise ``fma(a, b, c)`` producing ``float64``, or ``None`` when
        no candidate library could be loaded or none exports ``fma``
        (Windows, oddball libcs).
    """
    import ctypes
    import ctypes.util

    candidates = []
    for name in (ctypes.util.find_library("m"), "libm.so.6", "libm.dylib"):
        if name and name not in candidates:
            candidates.append(name)

    for name in candidates:
        try:
            fma = ctypes.CDLL(name).fma
        except (OSError, AttributeError):
            # Library missing, or present without an ``fma`` symbol: try the next.
            continue
        fma.restype = ctypes.c_double
        fma.argtypes = [ctypes.c_double, ctypes.c_double, ctypes.c_double]
        return np.vectorize(fma, otypes=[np.float64])
    return None
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
_FMA = _resolve_libm_fma()
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def _radius_dist_ideal_numpy(
    yl: np.ndarray, zl: np.ndarray, ring_nr: np.ndarray, ring_radii: np.ndarray,
) -> np.ndarray:
    """Return ``sqrt(y*y + z*z) - ring_radii[ring_nr]`` element-wise.

    This is the NumPy (byte-exact CPU) variant. When the module-level
    ``_FMA`` binding is present, ``y*y + z*z`` is evaluated as a single
    fused multiply-add ``fma(y, y, z*z)``; otherwise the plain two-step
    expression is used.

    Parameters
    ----------
    yl, zl
        Per-spot coordinates.
    ring_nr
        Integer index into ``ring_radii`` for each spot.
    ring_radii
        Lookup table of ideal radii, indexed by ring number.
    """
    z_squared = zl * zl
    if _FMA is None:
        sum_of_squares = yl * yl + z_squared
    else:
        sum_of_squares = _FMA(yl, yl, z_squared)
    observed_radius = np.sqrt(sum_of_squares)
    return observed_radius - ring_radii[ring_nr]
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def _build_ring_radii(p: ParamsTest, max_n_rings: int = 500) -> torch.Tensor:
    """Build a dense radius lookup table indexed by ring number.

    Pairs ``p.RingNumbers`` with ``p.RingRadii`` and stores each radius at
    its ring number's slot, mirroring ``SaveBinData.c:170-174``
    (``RingRadii[RingNumbers[i]] = RingRadiiUser[i]``). Slots with no
    configured ring stay ``0.0``; ring numbers outside
    ``[0, max_n_rings)`` are ignored.

    Returns
    -------
    torch.Tensor
        1-D float64 tensor of length ``max_n_rings``.
    """
    table = np.zeros(max_n_rings, dtype=np.float64)
    for ring_no, radius in zip(p.RingNumbers, p.RingRadii):
        if ring_no < 0 or ring_no >= max_n_rings:
            continue
        table[ring_no] = radius
    return torch.from_numpy(table)
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def _compute_radius_dist_ideal(
    spots_first8: torch.Tensor, ring_radii: torch.Tensor
) -> torch.Tensor:
    """Per spot: observed radius minus the ideal radius of the spot's ring.

    InputAll.csv columns are
        0=YLab, 1=ZLab, 2=Omega, 3=GrainRadius, 4=SpotID,
        5=RingNumber, 6=Eta, 7=Ttheta
    Only columns 0, 1 and 5 are read here. The result is
    ``sqrt(YLab^2 + ZLab^2) - ring_radii[RingNumber]`` and becomes the
    ninth column of ``Spots.bin``.

    This is the pure-torch counterpart of the C ``CalcDistanceIdealRing``
    recomputation; ``ring_radii`` is moved to the spots' device before the
    lookup.
    """
    y_lab, z_lab = spots_first8[:, 0], spots_first8[:, 1]
    observed = torch.sqrt(y_lab * y_lab + z_lab * z_lab)
    ring_index = spots_first8[:, 5].long()
    table = ring_radii.to(observed.device)
    return observed - table[ring_index]
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def _bin_assignment(
    spots_first8: torch.Tensor,
    ring_radii: torch.Tensor,
    margin_ome: float,
    margin_eta: float,
    eta_bin_size: float,
    ome_bin_size: float,
    step_size_orient: float,
):
    """Vectorised per-spot bin assignment.

    Every spot is expanded into one output row per (eta-bin, ome-bin) cell
    covered by its margins.

    Parameters
    ----------
    spots_first8
        2-D tensor; column 2 is read as omega, column 5 as the ring number
        and column 6 as eta (angles in degrees).
    ring_radii
        1-D table indexed by ring number. Spots whose ring has an entry
        ``<= 0`` are skipped.
    margin_ome, margin_eta, step_size_orient
        Inputs to the per-spot margins::

            ome_margin = margin_ome + 0.5 * step_size_orient / |sin(eta)|
            eta_margin = rad2deg * atan(margin_eta / ring_radius)
                         + 0.5 * step_size_orient
    eta_bin_size, ome_bin_size
        Bin widths in degrees.

    Returns
    -------
    (spot_idx, ring_nr, i_eta, i_ome)
        Four 1-D ``torch.long`` tensors of equal length. ``spot_idx`` is the
        row index into ``spots_first8``. The bin indices are *not* wrapped
        into ``[0, n_bins)`` here. All four are empty when no spot survives
        the ring filter.

    Notes (from SaveBinData.c:260-289):
      - Margins are spot-specific; the eta margin depends on ring radius and
        the ome margin depends on |sin(eta)|.
      - The C code does ``omemin = 180 + omega - omemargin`` and floors by
        bin size, generating an integer range [iOmeMin, iOmeMax]. We replicate.
      - Within a spot, rows are laid out eta-outer / ome-inner, matching the
        C ``for iEta0 ... for iOme0`` nesting.
    """
    device = spots_first8.device
    dtype = spots_first8.dtype

    def _no_rows():
        empty = torch.empty((0,), dtype=torch.long, device=device)
        return empty, empty, empty, empty

    omega = spots_first8[:, 2]
    ring_nr = spots_first8[:, 5].long()
    eta = spots_first8[:, 6]

    # Filter spots to those whose ring has a configured radius (>0).
    rrng = ring_radii.to(device=device, dtype=dtype)
    rad_for_ring = rrng[ring_nr]
    keep = rad_for_ring > 0
    if not keep.any():
        return _no_rows()

    spot_idx_all = torch.arange(spots_first8.shape[0], device=device)[keep]
    omega = omega[keep]
    eta = eta[keep]
    ring_nr = ring_nr[keep]
    rad_for_ring = rad_for_ring[keep]

    # Per-spot ome margin: avoid division by zero when sin(eta)=0 (eta=0 or 180).
    # NOTE(review): with the 1e-12 floor, a spot at eta == 0/180 gets an ome
    # margin of roughly 0.5 * step_size_orient * 1e12 degrees, i.e. an
    # enormous bin range and a correspondingly huge ``total`` below. Confirm
    # against the C reference whether such spots should be capped or skipped.
    sin_eta_abs = torch.abs(torch.sin(eta * _DEG2RAD))
    sin_eta_abs = torch.clamp(sin_eta_abs, min=1e-12)
    ome_margin_per = margin_ome + 0.5 * step_size_orient / sin_eta_abs
    eta_margin_per = _RAD2DEG * torch.atan(margin_eta / rad_for_ring) + 0.5 * step_size_orient

    # Shift by 180 degrees before flooring into bin indices.
    omemin = 180.0 + omega - ome_margin_per
    omemax = 180.0 + omega + ome_margin_per
    etamin = 180.0 + eta - eta_margin_per
    etamax = 180.0 + eta + eta_margin_per

    iome_min = torch.floor(omemin / ome_bin_size).long()
    iome_max = torch.floor(omemax / ome_bin_size).long()
    ieta_min = torch.floor(etamin / eta_bin_size).long()
    ieta_max = torch.floor(etamax / eta_bin_size).long()

    # Per-spot range count (inclusive ranges, never negative).
    n_eta_per = (ieta_max - ieta_min + 1).clamp(min=0)
    n_ome_per = (iome_max - iome_min + 1).clamp(min=0)
    n_pairs_per = n_eta_per * n_ome_per

    total = int(n_pairs_per.sum().item())
    if total == 0:
        return _no_rows()

    # Start offset of each spot's segment in the flattened output.
    seg_starts = torch.cumsum(n_pairs_per, dim=0) - n_pairs_per

    # Broadcast per-spot values to one entry per output row
    # (equivalent to ``np.repeat(values, n_pairs_per)``).
    out_spot_idx = torch.repeat_interleave(spot_idx_all, n_pairs_per)
    out_ring = torch.repeat_interleave(ring_nr, n_pairs_per)
    out_ieta_min = torch.repeat_interleave(ieta_min, n_pairs_per)
    out_iome_min = torch.repeat_interleave(iome_min, n_pairs_per)
    out_n_ome = torch.repeat_interleave(n_ome_per, n_pairs_per).clamp(min=1)

    # Position of each row inside its own spot's segment, decomposed as
    #   pos = i_eta * n_ome_per + i_ome      (eta outer, ome inner).
    pos = torch.arange(total, device=device) - torch.repeat_interleave(seg_starts, n_pairs_per)
    eta_off = pos // out_n_ome
    ome_off = pos - eta_off * out_n_ome

    out_ieta = out_ieta_min + eta_off
    out_iome = out_iome_min + ome_off

    return out_spot_idx, out_ring, out_ieta, out_iome
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
def _bin_to_data_ndata(
    out_spot_idx: torch.Tensor,
    out_ring: torch.Tensor,
    out_ieta: torch.Tensor,
    out_iome: torch.Tensor,
    n_ring_bins: int,
    n_eta_bins: int,
    n_ome_bins: int,
):
    """Group per-(spot, eta, ome) rows into the ``(Data, nData)`` lookup arrays.

    Bins are numbered ring-outer / eta-middle / ome-inner, the layout of
    ``SaveBinData.c:308-322``. Eta and ome indices are wrapped into
    ``[0, n_eta_bins)`` / ``[0, n_ome_bins)``; rows whose 0-based ring index
    falls outside ``[0, n_ring_bins)`` are discarded.

    Returns
    -------
    data
        int32, one spot index per surviving row, ordered by bin and then by
        spot index within a bin.
    ndata
        int32 of shape ``(n_ring_bins * n_eta_bins * n_ome_bins, 2)``; row
        ``b`` is ``(count, offset)`` where ``offset`` is the number of
        entries in all bins before ``b``.
    """
    device = out_spot_idx.device

    def _wrap(idx: torch.Tensor, n_bins: int) -> torch.Tensor:
        # Negative-aware modulo.
        return (idx % n_bins + n_bins) % n_bins

    # Ring numbers are 1-based in the C code (iRing = ringnr - 1).
    ring_axis = out_ring - 1
    in_range = (ring_axis >= 0) & (ring_axis < n_ring_bins)

    ring_axis = ring_axis[in_range]
    eta_axis = _wrap(out_ieta, n_eta_bins)[in_range]
    ome_axis = _wrap(out_iome, n_ome_bins)[in_range]
    spot_ids = out_spot_idx[in_range]

    # Flattened bin number.
    flat_bin = (ring_axis.long() * n_eta_bins + eta_axis.long()) * n_ome_bins + ome_axis.long()

    # One sort key ordering by bin first and spot index second; the stride
    # exceeds every spot index so the two parts cannot collide.
    if spot_ids.numel():
        stride = spot_ids.max().item() + 2
    else:
        stride = 1
    order = torch.argsort(flat_bin * stride + spot_ids, stable=True)

    # Histogram of rows per bin (independent of row order).
    n_bins_total = n_ring_bins * n_eta_bins * n_ome_bins
    counts = torch.zeros(n_bins_total, dtype=torch.int64, device=device).scatter_add(
        0, flat_bin, torch.ones_like(flat_bin)
    )
    # Exclusive running total: where each bin's entries start in ``data``.
    offsets = torch.cumsum(counts, dim=0) - counts

    ndata = torch.stack((counts, offsets), dim=1).to(torch.int32)
    data = spot_ids[order].to(torch.int32)
    return data, ndata
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
# ---------------------------------------------------------------------------
|
|
301
|
+
# Public entry point
|
|
302
|
+
# ---------------------------------------------------------------------------
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
def bin_data(
    result_folder: Union[str, Path] = ".",
    *,
    inputall_csv: Optional[Union[str, Path]] = None,
    inputall_extra_csv: Optional[Union[str, Path]] = None,
    paramstest_path: Optional[Union[str, Path]] = None,
    out_dir: Optional[Union[str, Path]] = None,
    paramstest: Optional[ParamsTest] = None,
    spots_inputall: Optional[np.ndarray] = None,
    extra_inputall: Optional[np.ndarray] = None,
    device: Optional[Union[str, torch.device]] = None,
    dtype: Optional[Union[str, torch.dtype]] = None,
    write: bool = True,
) -> BinDataResult:
    """Run the binning stage. Drop-in replacement for the C ``SaveBinData`` binary.

    Defaults match the C binary's argv-less convention: read ``InputAll.csv``,
    ``InputAllExtraInfoFittingAll.csv``, and ``paramstest.txt`` from the
    current directory; write outputs to the same directory.

    Parameters
    ----------
    result_folder
        Directory to read inputs from / write outputs to (when no override).
    inputall_csv, inputall_extra_csv, paramstest_path
        Optional input file overrides.
    out_dir
        Optional output directory override.
    paramstest, spots_inputall, extra_inputall
        Optional in-memory inputs (used by ``Pipeline``). Each one that is
        supplied suppresses the corresponding file read.
    device, dtype
        Torch device / dtype.
    write
        If ``False``, skip disk writes and return only the in-memory result.

    Returns
    -------
    BinDataResult
        ``data`` / ``ndata`` and the bin counts are only filled in when
        ``paramstest.NoSaveAll != 1``.

    Raises
    ------
    ValueError
        If the two inputs disagree on row count, if there are no spots, or
        if the extra-info input has neither 16 nor 18 columns.
    """
    rf = Path(result_folder)
    out_dir = Path(out_dir) if out_dir is not None else rf

    dev = resolve_device(device)
    dt = resolve_dtype(dev, dtype)

    # Inputs: anything not handed in is read from disk.
    if paramstest is None:
        ppath = paramstest_path if paramstest_path is not None else rf / "paramstest.txt"
        paramstest = read_paramstest(ppath)

    if spots_inputall is None:
        ipath = inputall_csv if inputall_csv is not None else rf / "InputAll.csv"
        spots_inputall = csv_io.read_inputall_csv(ipath)
    if extra_inputall is None:
        ipath = inputall_extra_csv if inputall_extra_csv is not None else rf / "InputAllExtraInfoFittingAll.csv"
        extra_inputall = csv_io.read_inputall_extra_csv(ipath)

    if spots_inputall.shape[0] != extra_inputall.shape[0]:
        raise ValueError(
            f"InputAll ({spots_inputall.shape[0]} rows) and "
            f"InputAllExtraInfoFittingAll ({extra_inputall.shape[0]} rows) "
            "must agree on row count."
        )

    n_spots = spots_inputall.shape[0]
    if n_spots == 0:
        raise ValueError("No spots in InputAll.csv. Aborting.")

    # Move to device.
    spots_t = torch.from_numpy(spots_inputall.astype(np.float64)).to(device=dev, dtype=dt)
    extra_t = torch.from_numpy(extra_inputall.astype(np.float64)).to(device=dev, dtype=dt)

    # Compute Spots.bin layout (cols 0-7 + RadiusDistIdeal).
    # One ring-radius table serves both paths: the float64 CPU original
    # feeds the NumPy kernel, the device/dtype copy feeds the torch kernels.
    ring_radii_cpu = _build_ring_radii(paramstest)
    ring_radii = ring_radii_cpu.to(device=dev, dtype=dt)
    # On the byte-exact CPU path, replicate the C compiler's FMA-fused
    # ``y*y + z*z`` (clang -O3 emits a fused-multiply-add on default
    # contraction). Without FMA, plain ``y*y + z*z`` rounds twice and
    # differs by 1 ULP in ~1% of spots from the C output.
    if dev.type == "cpu" and dt == torch.float64:
        rad_dist_np = _radius_dist_ideal_numpy(
            spots_inputall[:, 0].astype(np.float64),
            spots_inputall[:, 1].astype(np.float64),
            spots_inputall[:, 5].astype(np.int64),
            ring_radii_cpu.numpy(),
        )
        rad_dist = torch.from_numpy(rad_dist_np)
    else:
        rad_dist = _compute_radius_dist_ideal(spots_t, ring_radii)
    spots_out = torch.cat([spots_t, rad_dist.unsqueeze(1)], dim=1)  # (N, 9)

    # ExtraInfo.bin is 16 cols; drop CSV cols 14 and 15 (the C version's dummy0/dummy1).
    # See SaveBinData.c — sscanf maps CSV[16, 17] to AllSpots[14, 15].
    if extra_t.shape[1] == 18:
        extra_out = torch.cat([extra_t[:, :14], extra_t[:, 16:18]], dim=1)
    elif extra_t.shape[1] == 16:
        extra_out = extra_t
    else:
        raise ValueError(
            f"InputAllExtraInfoFittingAll must have 16 or 18 cols, got {extra_t.shape[1]}"
        )

    if write:
        # The binary files are always float64, whatever dtype was computed in.
        bio.write_spots_bin(out_dir / "Spots.bin", spots_out.detach().cpu().numpy().astype(np.float64))
        bio.write_extrainfo_bin(out_dir / "ExtraInfo.bin", extra_out.detach().cpu().numpy().astype(np.float64))

    # NoSaveAll == 1: the lookup table is neither built nor written.
    if paramstest.NoSaveAll == 1:
        return BinDataResult(
            spots=spots_out, extra_info=extra_out,
            paramstest=paramstest,
        )

    # Determine bin counts.
    n_ring_bins = paramstest.highest_ring_no
    n_eta_bins = math.ceil(360.0 / paramstest.EtaBinSize)
    n_ome_bins = math.ceil(360.0 / paramstest.OmeBinSize)

    out_spot_idx, out_ring, out_ieta, out_iome = _bin_assignment(
        spots_t,
        ring_radii,
        margin_ome=paramstest.MarginOme,
        margin_eta=paramstest.MarginEta,
        eta_bin_size=paramstest.EtaBinSize,
        ome_bin_size=paramstest.OmeBinSize,
        step_size_orient=paramstest.StepSizeOrient,
    )
    data, ndata = _bin_to_data_ndata(
        out_spot_idx, out_ring, out_ieta, out_iome,
        n_ring_bins=n_ring_bins, n_eta_bins=n_eta_bins, n_ome_bins=n_ome_bins,
    )

    if write:
        bio.write_data_ndata_bin(
            out_dir / "Data.bin", out_dir / "nData.bin",
            data.detach().cpu().numpy().astype(np.int32),
            ndata.detach().cpu().numpy().astype(np.int32),
        )

    return BinDataResult(
        spots=spots_out, extra_info=extra_out,
        data=data, ndata=ndata,
        n_ring_bins=n_ring_bins, n_eta_bins=n_eta_bins, n_ome_bins=n_ome_bins,
        paramstest=paramstest,
    )
|
midas_transforms/cli.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
"""CLI dispatch for ``midas-transforms``.
|
|
2
|
+
|
|
3
|
+
The umbrella command ``midas-transforms <stage> [args]`` plus four
|
|
4
|
+
sub-CLIs that mirror the C-binary argv contracts:
|
|
5
|
+
|
|
6
|
+
``midas-merge-peaks <zarr_path>``
|
|
7
|
+
``midas-calc-radius <zarr_path>``
|
|
8
|
+
``midas-fit-setup <zarr_path>``
|
|
9
|
+
``midas-bin-data`` (no positional args, reads from the cwd)
|
|
10
|
+
|
|
11
|
+
Each sub-CLI is a thin wrapper around the corresponding library function.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
import argparse
|
|
17
|
+
import sys
|
|
18
|
+
from pathlib import Path
|
|
19
|
+
from typing import List, Optional
|
|
20
|
+
|
|
21
|
+
from . import __version__
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _common_argparser(prog: str, description: str) -> argparse.ArgumentParser:
    """Create a parser pre-loaded with the options every sub-CLI shares.

    The shared options are ``--device`` (cpu / cuda / mps), ``--dtype``
    (float32 / float64) — both defaulting to ``None`` — and ``--version``.
    Callers add their own stage-specific arguments to the returned parser.
    """
    parser = argparse.ArgumentParser(prog=prog, description=description)
    shared_choices = (
        ("--device", ["cpu", "cuda", "mps"]),
        ("--dtype", ["float32", "float64"]),
    )
    for flag, allowed in shared_choices:
        parser.add_argument(flag, choices=allowed, default=None)
    parser.add_argument("--version", action="version", version=f"midas-transforms {__version__}")
    return parser
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def merge_main(argv: Optional[List[str]] = None) -> int:
    """Entry point for ``midas-merge-peaks <zarr_path>``.

    Parses the command line, reads the parameters stored in the Zarr
    archive, and calls ``merge_overlapping_peaks`` with ``write=True``.
    ``--overlap-length`` overrides the archive's ``OverlapLength`` when
    given. Always returns ``0``; failures surface as exceptions.
    """
    parser = _common_argparser(
        "midas-merge-peaks",
        "Frame-by-frame mutual-nearest merge of consolidated peakfit output.",
    )
    parser.add_argument("zarr_path", help="Path to the MIDAS Zarr archive (.zip).")
    parser.add_argument(
        "--result-folder",
        default=None,
        help="Override the result folder (default: directory of zarr_path).",
    )
    parser.add_argument(
        "--allpeaks-ps-bin",
        default=None,
        help="Override the AllPeaks_PS.bin path "
             "(default: <result-folder>/Temp/AllPeaks_PS.bin).",
    )
    parser.add_argument(
        "--overlap-length",
        type=float,
        default=None,
        help="Centroid distance threshold in px (default: from Zarr params, fallback 2.0).",
    )
    ns = parser.parse_args(argv)

    # Imported lazily, after argument parsing.
    from .merge import merge_overlapping_peaks
    from .params import read_zarr_params

    if ns.result_folder:
        result_dir = Path(ns.result_folder)
    else:
        result_dir = Path(ns.zarr_path).parent

    zarr_params = read_zarr_params(ns.zarr_path)

    # An explicit CLI value wins over the archive's setting.
    if ns.overlap_length is not None:
        overlap_length = ns.overlap_length
    else:
        overlap_length = zarr_params.OverlapLength

    merge_overlapping_peaks(
        zarr_path=ns.zarr_path,
        allpeaks_ps_bin=ns.allpeaks_ps_bin,
        result_folder=result_dir,
        overlap_length=overlap_length,
        skip_frame=zarr_params.SkipFrame,
        use_maxima_positions=bool(zarr_params.UseMaximaPositions),
        end_nr=zarr_params.EndNr if zarr_params.EndNr > 0 else None,
        device=ns.device,
        dtype=ns.dtype,
        write=True,
    )
    print(f"midas-merge-peaks {__version__}: wrote Result_*.csv and MergeMap.csv to {result_dir}", file=sys.stderr)
    return 0
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def radius_main(argv: Optional[List[str]] = None) -> int:
    """Entry point for ``midas-calc-radius <zarr_path>``.

    Parses the command line, reads the parameters stored in the Zarr
    archive, and calls ``calc_radius`` with ``write=True``. The result
    folder defaults to the directory containing ``zarr_path``. Always
    returns ``0``; failures surface as exceptions.
    """
    parser = _common_argparser("midas-calc-radius", "Per-spot ring/Bragg/grain-volume calculation.")
    parser.add_argument("zarr_path", help="Path to the MIDAS Zarr archive (.zip).")
    parser.add_argument("--result-folder", default=None)
    ns = parser.parse_args(argv)

    # Imported lazily, after argument parsing.
    from .params import read_zarr_params
    from .radius import calc_radius

    if ns.result_folder:
        result_dir = Path(ns.result_folder)
    else:
        result_dir = Path(ns.zarr_path).parent

    zarr_params = read_zarr_params(ns.zarr_path)
    calc_radius(
        result_folder=result_dir,
        zarr_params=zarr_params,
        # A non-positive EndNr is passed on as "not set".
        end_nr=zarr_params.EndNr if zarr_params.EndNr > 0 else None,
        device=ns.device,
        dtype=ns.dtype,
        write=True,
    )
    print(f"midas-calc-radius {__version__}: wrote Radius_*.csv to {result_dir}", file=sys.stderr)
    return 0
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def fit_setup_main(argv: Optional[List[str]] = None) -> int:
    """Entry point for ``midas-fit-setup <zarr_path>``.

    Parses the command line, reads the parameters stored in the Zarr
    archive, and calls ``fit_setup`` with ``write=True``. The geometry
    refine runs only when the archive has ``DoFit == 1`` and ``--no-fit``
    was not given. Always returns ``0``; failures surface as exceptions.
    """
    parser = _common_argparser(
        "midas-fit-setup",
        "Per-spot tilt+distortion+wedge correction, filtering, and paramstest.txt writer.",
    )
    parser.add_argument("zarr_path", help="Path to the MIDAS Zarr archive (.zip).")
    parser.add_argument("--result-folder", default=None)
    parser.add_argument("--no-fit", action="store_true", help="Force DoFit=0 (skip the geometry refine).")
    ns = parser.parse_args(argv)

    # Imported lazily, after argument parsing.
    from .fit_setup import fit_setup
    from .params import read_zarr_params

    if ns.result_folder:
        result_dir = Path(ns.result_folder)
    else:
        result_dir = Path(ns.zarr_path).parent

    zarr_params = read_zarr_params(ns.zarr_path)

    # --no-fit overrides whatever the archive says.
    if ns.no_fit:
        run_refine = False
    else:
        run_refine = zarr_params.DoFit == 1

    fit_setup(
        result_folder=result_dir,
        zarr_params=zarr_params,
        end_nr=zarr_params.EndNr if zarr_params.EndNr > 0 else None,
        do_fit=run_refine,
        device=ns.device,
        dtype=ns.dtype,
        write=True,
    )
    print(f"midas-fit-setup {__version__}: wrote InputAll.csv et al to {result_dir}", file=sys.stderr)
    return 0
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def bin_data_main(argv: Optional[List[str]] = None) -> int:
    """Entry point for ``midas-bin-data`` (no positional arguments).

    Runs ``bin_data`` on ``--result-folder`` (default: the current
    directory) with ``write=True`` and reports on stderr which files were
    produced. ``Data.bin`` / ``nData.bin`` are only listed when the stage
    actually built the lookup table, i.e. when the returned result carries
    ``data``; with ``NoSaveAll == 1`` they are not written. Always returns
    ``0``; failures surface as exceptions.
    """
    p = _common_argparser("midas-bin-data", "Bin spots into Spots.bin / ExtraInfo.bin / Data.bin / nData.bin.")
    p.add_argument("--result-folder", default=".")
    args = p.parse_args(argv)

    # Imported lazily, after argument parsing.
    from .bin_data import bin_data

    result = bin_data(
        result_folder=args.result_folder,
        device=args.device, dtype=args.dtype, write=True,
    )

    written = ["Spots.bin", "ExtraInfo.bin"]
    if result.data is not None:
        written += ["Data.bin", "nData.bin"]
    print(f"midas-bin-data {__version__}: wrote {' / '.join(written)} to {args.result_folder}", file=sys.stderr)
    return 0
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def main(argv: Optional[List[str]] = None) -> int:
    """Umbrella command: ``midas-transforms <stage> [args]``.

    Only the first argument is interpreted here; everything after it is
    handed untouched to the selected stage's own CLI, which does the real
    argument parsing.

    Parameters
    ----------
    argv
        Argument list without the program name; defaults to ``sys.argv[1:]``.

    Returns
    -------
    int
        The stage's return code, or ``2`` (after printing usage to stderr)
        when no stage or an unknown stage was given.

    Raises
    ------
    SystemExit
        With code ``0`` for ``-h`` / ``--help`` / ``-V`` / ``--version``
        (raised by argparse after printing).
    """
    parser = argparse.ArgumentParser(
        prog="midas-transforms",
        description="Pure-Python/PyTorch FF-HEDM transforms (merge / radius / fit-setup / bin-data).",
    )
    # ``-V`` is registered alongside ``--version`` so both spellings work.
    parser.add_argument("-V", "--version", action="version", version=f"midas-transforms {__version__}")
    # The sub-parsers only exist so the stage names show up in the usage
    # text; they are never used to parse a stage's arguments.
    sub = parser.add_subparsers(dest="stage", required=True)
    for stage_name in ("merge-peaks", "calc-radius", "fit-setup", "bin-data", "pipeline"):
        sub.add_parser(stage_name, add_help=False)

    # Parse only the first positional, dispatch the rest.
    if argv is None:
        argv = sys.argv[1:]
    if not argv:
        parser.print_help(sys.stderr)
        return 2

    stage, rest = argv[0], argv[1:]
    if stage in ("-h", "--help", "-V", "--version"):
        # argparse prints the help / version text and raises SystemExit(0).
        parser.parse_args([stage])
        return 0

    # Built after the early exits above; the names are resolved at call time.
    handlers = {
        "merge-peaks": merge_main,
        "calc-radius": radius_main,
        "fit-setup": fit_setup_main,
        "bin-data": bin_data_main,
        "pipeline": pipeline_main,
    }
    handler = handlers.get(stage)
    if handler is None:
        print(f"midas-transforms: unknown stage {stage!r}", file=sys.stderr)
        parser.print_help(sys.stderr)
        return 2
    return handler(rest)
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def pipeline_main(argv: Optional[List[str]] = None) -> int:
    """Entry point for ``midas-transforms pipeline <zarr_path>``.

    Builds a ``Pipeline`` from the Zarr archive, runs it, and dumps the
    outputs to ``--out-dir`` (default: the directory containing
    ``zarr_path``). Always returns ``0``; failures surface as exceptions.
    """
    parser = _common_argparser(
        "midas-transforms pipeline",
        "Run all four stages on-device with no disk round-trips between them.",
    )
    parser.add_argument("zarr_path", help="Path to the MIDAS Zarr archive (.zip).")
    parser.add_argument("--out-dir", default=None, help="Output directory (default: dir of zarr_path).")
    parser.add_argument("--allpeaks-ps-bin", default=None)
    ns = parser.parse_args(argv)

    # Imported lazily, after argument parsing.
    from .pipeline import Pipeline

    pipeline = Pipeline.from_zarr(
        ns.zarr_path,
        allpeaks_ps_bin=ns.allpeaks_ps_bin,
        device=ns.device,
        dtype=ns.dtype,
    )
    pipeline.run()

    if ns.out_dir:
        out_dir = Path(ns.out_dir)
    else:
        out_dir = Path(ns.zarr_path).parent
    pipeline.dump(out_dir)
    print(f"midas-transforms pipeline {__version__}: wrote 9 files to {out_dir}", file=sys.stderr)
    return 0
|