midas-transforms 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,46 @@
1
+ """midas-transforms: Pure-Python/PyTorch FF-HEDM transforms.
2
+
3
+ Drop-in replacement for the four C binaries that sit between peak-fitting
4
+ (``midas-peakfit``) and indexing (``midas-index``):
5
+
6
+ - ``MergeOverlappingPeaksAllZarr`` -> ``midas_transforms.merge_overlapping_peaks``
7
+ - ``CalcRadiusAllZarr`` -> ``midas_transforms.calc_radius``
8
+ - ``FitSetupZarr`` -> ``midas_transforms.fit_setup``
9
+ - ``SaveBinData`` -> ``midas_transforms.bin_data``
10
+
11
+ Two equally-supported usage modes:
12
+
13
+ **Mode 1 - per-stage (round-trips through disk, like the C binaries):**
14
+
15
+ from midas_transforms import merge_overlapping_peaks, calc_radius, fit_setup, bin_data
16
+ merge_overlapping_peaks(zarr_path="...", result_folder="...", device="cuda")
17
+ calc_radius(result_folder="...", device="cuda")
18
+ fit_setup(result_folder="...", device="cuda")
19
+ bin_data(result_folder="...", device="cuda")
20
+
21
+ **Mode 2 - chained Pipeline (intermediates stay on GPU, only final outputs written):**
22
+
23
+ from midas_transforms import Pipeline
24
+ pipe = Pipeline.from_zarr(zarr_path, device="cuda")
25
+ result = pipe.run()
26
+ pipe.dump(out_dir)
27
+
28
+ See ``dev/implementation_plan.md`` for design and roadmap.
29
+ """
30
+
31
__version__ = "0.1.0"

# Public API: one callable per replaced C binary, plus the on-device driver.
# NOTE: ``from .bin_data import bin_data`` (and ``.fit_setup``) rebinds the
# package attribute to the function, shadowing the subpackage attribute of the
# same name; the subpackage itself stays importable via its dotted path.
from .merge import merge_overlapping_peaks
from .radius import calc_radius
from .fit_setup import fit_setup
from .bin_data import bin_data
from .pipeline import Pipeline

__all__ = [
    # Per-stage entry points (disk round-trip, mirror the C binaries).
    "merge_overlapping_peaks",
    "calc_radius",
    "fit_setup",
    "bin_data",
    # Chained driver that keeps intermediates on-device.
    "Pipeline",
    # Package metadata.
    "__version__",
]
@@ -0,0 +1,6 @@
1
"""Module entry point so ``python -m midas_transforms <stage> [args]`` works.

Delegates to ``midas_transforms.cli.main`` and uses its integer return value
as the process exit status.
"""

from .cli import main

if __name__ == "__main__":
    _status = main()
    raise SystemExit(_status)
@@ -0,0 +1,10 @@
1
"""bin_data — replaces the C ``SaveBinData`` binary.

Reads ``InputAll.csv``, ``InputAllExtraInfoFittingAll.csv``, ``paramstest.txt``;
writes ``Spots.bin``, ``ExtraInfo.bin``, and (unless ``NoSaveAll==1``)
``Data.bin`` + ``nData.bin``.

This package only re-exports the stage function and its in-memory result
type from ``.core``.
"""

from .core import BinDataResult, bin_data

__all__ = ["bin_data", "BinDataResult"]
@@ -0,0 +1,448 @@
1
+ """bin_data: drop-in replacement for ``SaveBinData``.
2
+
3
+ The C source is `FF_HEDM/src/SaveBinData.c` (341 LoC). The torch port:
4
+
5
+ 1. Reads ``InputAll.csv`` (8 cols) and ``InputAllExtraInfoFittingAll.csv``
6
+ (18 cols) into tensors on ``device``.
7
+ 2. Computes the per-spot ``RadiusDistIdeal = radius_obs - ring_radii[ring_nr]``.
8
+ 3. Writes ``Spots.bin`` (Nx9 float64) and ``ExtraInfo.bin`` (Nx16 float64).
9
+ 4. If ``NoSaveAll == 0``: builds the per-(ring, eta-bin, ome-bin) lookup
10
+ table and writes ``Data.bin`` (int32 ragged) and ``nData.bin`` (count/offset
11
+ pairs, int32).
12
+
13
+ The (eta, ome) bin assignment per spot uses the C margin formulae
14
+ (``SaveBinData.c:265-271``):
15
+
16
+ omemargin = MarginOme + 0.5 * StepSizeOrient / |sin(eta_deg)|
17
+ etamargin = rad2deg * atan(MarginEta / RingRadii[ring]) + 0.5 * StepSizeOrient
18
+
19
+ Then for each spot, all bins in ``[iEtaMin..iEtaMax] × [iOmeMin..iOmeMax]``
20
+ mod ``n_eta``, ``n_ome`` receive the spot's index.
21
+
22
+ The vectorised torch path emits one ``(spot_idx, ring, iEta, iOme)`` tuple
23
+ per (spot, eta-bin, ome-bin) triple, sorts by ``(ring, iEta, iOme, spot_idx)``,
24
+ then counts the entries per bin and takes an exclusive cumulative sum to recover counts and offsets per bin. The
25
+ ``spot_idx`` secondary key makes the output bit-stable across runs and
26
+ between CPU and GPU (the C version's order is implicitly the spot
27
+ iteration order, which is row order in InputAll.csv == ascending
28
+ ``spot_idx`` for non-empty bins).
29
+ """
30
+
31
+ from __future__ import annotations
32
+
33
+ import math
34
+ from dataclasses import dataclass, field
35
+ from pathlib import Path
36
+ from typing import Optional, Union
37
+
38
+ import numpy as np
39
+ import torch
40
+
41
+ from ..device import resolve_device, resolve_dtype
42
+ from ..io import binary as bio
43
+ from ..io import csv as csv_io
44
+ from ..params import ParamsTest, read_paramstest
45
+
46
+
47
@dataclass
class BinDataResult:
    """Output of the bin_data stage, held in memory with tensors on ``device``.

    ``Pipeline`` hands this object to a downstream consumer (e.g.
    ``midas-index``) instead of round-tripping through the ``*.bin`` files.
    """

    # Spots.bin contents: InputAll cols 0-7 plus RadiusDistIdeal -> (N, 9).
    # dtype follows the stage's compute dtype (float64 on the byte-exact path).
    spots: torch.Tensor
    # ExtraInfo.bin contents -> (N, 16).
    extra_info: torch.Tensor
    # Data.bin: flat int32 spot indices grouped by bin; None when NoSaveAll==1.
    data: Optional[torch.Tensor] = None
    # nData.bin: one int32 (count, offset) row per (ring, eta, ome) bin,
    # i.e. n_ring_bins * n_eta_bins * n_ome_bins rows; None when NoSaveAll==1.
    ndata: Optional[torch.Tensor] = None
    # Bin-grid dimensions; stay 0 when the lookup table was not built.
    n_ring_bins: int = 0
    n_eta_bins: int = 0
    n_ome_bins: int = 0
    # Parameters the stage ran with.
    paramstest: Optional[ParamsTest] = field(default=None)
63
+
64
+
65
+ # ---------------------------------------------------------------------------
66
+ # Pure-tensor kernels
67
+ # ---------------------------------------------------------------------------
68
+
69
+ _DEG2RAD = math.pi / 180.0
70
+ _RAD2DEG = 180.0 / math.pi
71
+
72
+
73
+ # ---- libm.fma binding (for byte-exact y² + z² matching the C build) -----
74
+
75
+ def _resolve_libm_fma():
76
+ """Return a vectorised libm ``fma`` if available, else ``None``.
77
+
78
+ Used to reproduce clang's default ``-ffp-contract=on`` FMA fusion of
79
+ ``y*y + z*z`` so that ``Spots.bin`` col 8 (RadiusDistIdeal) is byte-exact
80
+ against the C ``SaveBinData`` output. Falls back to plain ``y*y + z*z``
81
+ when libm or its ``fma`` symbol is unavailable (Windows, oddball libcs).
82
+ """
83
+ import ctypes
84
+ import ctypes.util
85
+ name = ctypes.util.find_library("m") or "libm.dylib"
86
+ try:
87
+ libm = ctypes.CDLL(name)
88
+ libm.fma.restype = ctypes.c_double
89
+ libm.fma.argtypes = [ctypes.c_double, ctypes.c_double, ctypes.c_double]
90
+ except (OSError, AttributeError):
91
+ return None
92
+ return np.vectorize(libm.fma, otypes=[np.float64])
93
+
94
+
95
+ _FMA = _resolve_libm_fma()
96
+
97
+
98
def _radius_dist_ideal_numpy(
    yl: np.ndarray, zl: np.ndarray, ring_nr: np.ndarray, ring_radii: np.ndarray,
) -> np.ndarray:
    """Element-wise ``sqrt(yl**2 + zl**2) - ring_radii[ring_nr]`` in NumPy.

    This is the byte-exact CPU path. When the module-level ``_FMA`` is
    available, the sum of squares is evaluated as ``fma(yl, yl, zl*zl)``,
    which is what the C ``SaveBinData`` expression compiles to under clang's
    default FP contraction. Otherwise it is the ordinary twice-rounded
    ``yl*yl + zl*zl``.
    """
    z_sq = zl * zl
    sum_sq = yl * yl + z_sq if _FMA is None else _FMA(yl, yl, z_sq)
    return np.sqrt(sum_sq) - ring_radii[ring_nr]
112
+
113
+
114
+ def _build_ring_radii(p: ParamsTest, max_n_rings: int = 500) -> torch.Tensor:
115
+ """Return a 1-D tensor ``[max_n_rings]`` of radii indexed by ring number.
116
+
117
+ Mirrors ``SaveBinData.c:170-174`` — ``RingRadii[RingNumbers[i]] = RingRadiiUser[i]``
118
+ with everything else zeroed.
119
+ """
120
+ out = np.zeros(max_n_rings, dtype=np.float64)
121
+ for r, rad in zip(p.RingNumbers, p.RingRadii):
122
+ if 0 <= r < max_n_rings:
123
+ out[r] = rad
124
+ return torch.from_numpy(out)
125
+
126
+
127
+ def _compute_radius_dist_ideal(
128
+ spots_first8: torch.Tensor, ring_radii: torch.Tensor
129
+ ) -> torch.Tensor:
130
+ """Per spot: distance from observed radius (col 8 of InputAll, after Ttheta) to ideal ring radius.
131
+
132
+ InputAll.csv columns are
133
+ 0=YLab, 1=ZLab, 2=Omega, 3=GrainRadius, 4=SpotID,
134
+ 5=RingNumber, 6=Eta, 7=Ttheta
135
+ The radius-distance is ``sqrt(YLab^2 + ZLab^2) - RingRadii[RingNumber]``.
136
+
137
+ The C version (``CalcDistanceIdealRing``) recomputes this per spot before
138
+ writing; we replicate the same formula.
139
+ """
140
+ yl = spots_first8[:, 0]
141
+ zl = spots_first8[:, 1]
142
+ ring_nr = spots_first8[:, 5].long()
143
+ radius = torch.sqrt(yl * yl + zl * zl)
144
+ ideal = ring_radii.to(radius.device)[ring_nr]
145
+ return radius - ideal
146
+
147
+
148
def _bin_assignment(
    spots_first8: torch.Tensor,
    ring_radii: torch.Tensor,
    margin_ome: float,
    margin_eta: float,
    eta_bin_size: float,
    ome_bin_size: float,
    step_size_orient: float,
):
    """Vectorised per-spot (eta, ome) bin assignment.

    Each spot whose ring has a configured radius (> 0) is expanded into one
    output row per (eta-bin, ome-bin) pair its margin window touches.

    Returns
    -------
    (spot_idx, ring_nr, ieta, iome)
        Four equal-length 1-D int64 tensors, one entry per
        (spot, eta-bin, ome-bin).
        - Rows for a spot are contiguous, and spots appear in ascending row
          order.
        - Within a spot the layout is iEta-outer / iOme-inner, as in the C
          loop ``for iEta0 ... for iOme0``.
        - ``ieta`` / ``iome`` are *unwrapped* and may fall outside
          ``[0, n_bins)``; ``_bin_to_data_ndata`` applies the modulo.
        - All four tensors are empty when nothing is binned.

    Notes (from SaveBinData.c:260-289):
      - Margins are spot-specific; the eta margin depends on ring radius and
        the ome margin depends on |sin(eta)|.
      - The C code does ``omemin = 180 + omega - omemargin`` and floors by
        bin size, generating an integer range [iOmeMin, iOmeMax]. We replicate.
    """
    device = spots_first8.device
    dtype = spots_first8.dtype

    omega = spots_first8[:, 2]
    ring_nr = spots_first8[:, 5].long()
    eta = spots_first8[:, 6]

    # Keep only spots whose ring has a configured radius (> 0).
    rrng = ring_radii.to(device=device, dtype=dtype)
    rad_for_ring = rrng[ring_nr]
    keep = rad_for_ring > 0
    if not keep.any():
        empty = torch.empty((0,), dtype=torch.long, device=device)
        return empty, empty, empty, empty

    spot_idx_all = torch.arange(spots_first8.shape[0], device=device)[keep]
    omega = omega[keep]
    eta = eta[keep]
    ring_nr = ring_nr[keep]
    rad_for_ring = rad_for_ring[keep]

    # Per-spot margins. |sin(eta)| is clamped away from zero so eta = 0/180
    # does not divide by zero.
    # NOTE(review): a clamped |sin(eta)| of 1e-12 makes ome_margin_per
    # enormous, and the number of emitted rows grows with it; confirm such
    # etas cannot reach this stage.
    sin_eta_abs = torch.clamp(torch.abs(torch.sin(eta * _DEG2RAD)), min=1e-12)
    ome_margin_per = margin_ome + 0.5 * step_size_orient / sin_eta_abs
    eta_margin_per = _RAD2DEG * torch.atan(margin_eta / rad_for_ring) + 0.5 * step_size_orient

    omemin = 180.0 + omega - ome_margin_per
    omemax = 180.0 + omega + ome_margin_per
    etamin = 180.0 + eta - eta_margin_per
    etamax = 180.0 + eta + eta_margin_per

    iome_min = torch.floor(omemin / ome_bin_size).long()
    iome_max = torch.floor(omemax / ome_bin_size).long()
    ieta_min = torch.floor(etamin / eta_bin_size).long()
    ieta_max = torch.floor(etamax / eta_bin_size).long()

    # Number of eta bins, ome bins, and (eta, ome) pairs each spot covers.
    n_eta_per = (ieta_max - ieta_min + 1).clamp(min=0)
    n_ome_per = (iome_max - iome_min + 1).clamp(min=0)
    n_pairs_per = n_eta_per * n_ome_per

    total = int(n_pairs_per.sum().item())
    if total == 0:
        empty = torch.empty((0,), dtype=torch.long, device=device)
        return empty, empty, empty, empty

    # Expand per-spot values to one row per (spot, eta-bin, ome-bin);
    # ``repeat_interleave`` is the torch counterpart of ``np.repeat``.
    seg_starts = torch.cumsum(n_pairs_per, dim=0) - n_pairs_per
    out_spot_idx = torch.repeat_interleave(spot_idx_all, n_pairs_per)
    out_ring = torch.repeat_interleave(ring_nr, n_pairs_per)
    out_ieta_min = torch.repeat_interleave(ieta_min, n_pairs_per)
    out_iome_min = torch.repeat_interleave(iome_min, n_pairs_per)
    # A spot only emits rows when n_eta * n_ome > 0, so every row has
    # n_ome >= 1; the clamp just keeps the division obviously safe.
    out_n_ome = torch.repeat_interleave(n_ome_per, n_pairs_per).clamp(min=1)

    # Offset of each row inside its spot's segment, split as
    # pos = eta_off * n_ome + ome_off (iEta outer, iOme inner, as in C).
    pos = torch.arange(total, device=device) - torch.repeat_interleave(seg_starts, n_pairs_per)
    eta_off = pos // out_n_ome
    ome_off = pos - eta_off * out_n_ome

    return out_spot_idx, out_ring, out_ieta_min + eta_off, out_iome_min + ome_off
242
+
243
+
244
+ def _bin_to_data_ndata(
245
+ out_spot_idx: torch.Tensor,
246
+ out_ring: torch.Tensor,
247
+ out_ieta: torch.Tensor,
248
+ out_iome: torch.Tensor,
249
+ n_ring_bins: int,
250
+ n_eta_bins: int,
251
+ n_ome_bins: int,
252
+ ):
253
+ """Convert per-(spot, eta, ome) triples to ``(Data, nData)`` arrays.
254
+
255
+ Layout matches ``SaveBinData.c:308-322`` — ring-major, eta-major, ome-major.
256
+ Bins wrap modulo n_eta / n_ome.
257
+ """
258
+ device = out_spot_idx.device
259
+
260
+ # Modulo wrap (negative-aware).
261
+ ieta_mod = (out_ieta % n_eta_bins + n_eta_bins) % n_eta_bins
262
+ iome_mod = (out_iome % n_ome_bins + n_ome_bins) % n_ome_bins
263
+
264
+ # iRing in C is `ringnr - 1`; ring-bin axis is [0, HighestRingNo).
265
+ iring = out_ring - 1
266
+
267
+ # Drop entries whose ring index is out of range. (Defensive; ring_nr is
268
+ # filtered upstream.)
269
+ mask = (iring >= 0) & (iring < n_ring_bins)
270
+ iring = iring[mask]
271
+ ieta_mod = ieta_mod[mask]
272
+ iome_mod = iome_mod[mask]
273
+ out_spot_idx = out_spot_idx[mask]
274
+
275
+ # Composite bin id for sorting (ring outer, eta middle, ome inner).
276
+ bin_id = (iring.long() * n_eta_bins + ieta_mod.long()) * n_ome_bins + iome_mod.long()
277
+ # Stable sort by (bin_id, spot_idx) for deterministic insertion order.
278
+ composite = bin_id * (out_spot_idx.max().item() + 2 if out_spot_idx.numel() else 1) + out_spot_idx
279
+ order = torch.argsort(composite, stable=True)
280
+ sorted_bin_id = bin_id[order]
281
+ sorted_spot_idx = out_spot_idx[order]
282
+
283
+ # Counts per bin: scatter add ones.
284
+ total_bins = n_ring_bins * n_eta_bins * n_ome_bins
285
+ counts = torch.zeros(total_bins, dtype=torch.int64, device=device)
286
+ counts.scatter_add_(0, sorted_bin_id, torch.ones_like(sorted_bin_id))
287
+
288
+ # Offsets are cumulative by visit order. Since bins are in
289
+ # ring-major / eta-major / ome-major order in the output, the offset
290
+ # for each bin is the running total of all preceding bins.
291
+ offsets = torch.zeros(total_bins, dtype=torch.int64, device=device)
292
+ offsets[1:] = torch.cumsum(counts[:-1], dim=0)
293
+
294
+ # Pack ndata as (count, offset) per bin.
295
+ ndata = torch.stack([counts, offsets], dim=1).to(torch.int32)
296
+ data = sorted_spot_idx.to(torch.int32)
297
+ return data, ndata
298
+
299
+
300
+ # ---------------------------------------------------------------------------
301
+ # Public entry point
302
+ # ---------------------------------------------------------------------------
303
+
304
+
305
def _radius_dist_for_spots(
    spots_np: np.ndarray,
    spots_t: torch.Tensor,
    paramstest: ParamsTest,
    ring_radii: torch.Tensor,
) -> torch.Tensor:
    """Return the RadiusDistIdeal column (``Spots.bin`` col 8) for every spot.

    On CPU/float64 this replicates the C compiler's FMA-fused ``y*y + z*z``
    (clang -O3 emits a fused multiply-add on default contraction) through the
    NumPy helper. Without FMA, plain ``y*y + z*z`` rounds twice and differs
    by 1 ULP in ~1% of spots from the C output. Any other device/dtype uses
    the tensor helper.
    """
    if spots_t.device.type == "cpu" and spots_t.dtype == torch.float64:
        # Same table as ``ring_radii``, taken on the host in float64.
        ring_radii_np = _build_ring_radii(paramstest).numpy()
        rad_dist_np = _radius_dist_ideal_numpy(
            spots_np[:, 0].astype(np.float64),
            spots_np[:, 1].astype(np.float64),
            spots_np[:, 5].astype(np.int64),
            ring_radii_np,
        )
        return torch.from_numpy(rad_dist_np)
    return _compute_radius_dist_ideal(spots_t, ring_radii)


def _extra_info_columns(extra_t: torch.Tensor) -> torch.Tensor:
    """Select the 16 ``ExtraInfo.bin`` columns from 16- or 18-column CSV data.

    For 18 columns, CSV cols 14 and 15 (the C version's dummy0/dummy1) are
    dropped; ``SaveBinData.c``'s sscanf maps CSV[16, 17] to AllSpots[14, 15].
    16-column input is returned unchanged.
    """
    if extra_t.shape[1] == 18:
        return torch.cat([extra_t[:, :14], extra_t[:, 16:18]], dim=1)
    if extra_t.shape[1] == 16:
        return extra_t
    raise ValueError(
        f"InputAllExtraInfoFittingAll must have 16 or 18 cols, got {extra_t.shape[1]}"
    )


def bin_data(
    result_folder: Union[str, Path] = ".",
    *,
    inputall_csv: Optional[Union[str, Path]] = None,
    inputall_extra_csv: Optional[Union[str, Path]] = None,
    paramstest_path: Optional[Union[str, Path]] = None,
    out_dir: Optional[Union[str, Path]] = None,
    paramstest: Optional[ParamsTest] = None,
    spots_inputall: Optional[np.ndarray] = None,
    extra_inputall: Optional[np.ndarray] = None,
    device: Optional[Union[str, torch.device]] = None,
    dtype: Optional[Union[str, torch.dtype]] = None,
    write: bool = True,
) -> BinDataResult:
    """Run the binning stage. Drop-in replacement for the C ``SaveBinData`` binary.

    Defaults match the C binary's argv-less convention: read ``InputAll.csv``,
    ``InputAllExtraInfoFittingAll.csv``, and ``paramstest.txt`` from the
    current directory; write outputs to the same directory.

    Parameters
    ----------
    result_folder
        Directory to read inputs from / write outputs to (when no override).
    inputall_csv, inputall_extra_csv, paramstest_path
        Optional input file overrides.
    out_dir
        Optional output directory override; created if missing when writing.
    paramstest, spots_inputall, extra_inputall
        Optional in-memory inputs (used by ``Pipeline``); these take
        precedence over the corresponding files.
    device, dtype
        Torch device / dtype.
    write
        If ``False``, skip disk writes and return only the in-memory result.

    Returns
    -------
    BinDataResult
        ``data`` / ``ndata`` and the bin dimensions are only filled in when
        ``NoSaveAll != 1``.

    Raises
    ------
    ValueError
        - The two CSVs disagree on row count.
        - There are no spots.
        - The ExtraInfo data does not have 16 or 18 columns.
        - The lookup table is requested while ``EtaBinSize`` or
          ``OmeBinSize`` is not positive. This is checked before any file
          is written.
    """
    rf = Path(result_folder)
    out_dir = Path(out_dir) if out_dir is not None else rf

    dev = resolve_device(device)
    dt = resolve_dtype(dev, dtype)

    # Inputs: in-memory values win, then explicit paths, then result_folder.
    if paramstest is None:
        ppath = paramstest_path if paramstest_path is not None else rf / "paramstest.txt"
        paramstest = read_paramstest(ppath)

    if spots_inputall is None:
        ipath = inputall_csv if inputall_csv is not None else rf / "InputAll.csv"
        spots_inputall = csv_io.read_inputall_csv(ipath)
    if extra_inputall is None:
        ipath = inputall_extra_csv if inputall_extra_csv is not None else rf / "InputAllExtraInfoFittingAll.csv"
        extra_inputall = csv_io.read_inputall_extra_csv(ipath)

    if spots_inputall.shape[0] != extra_inputall.shape[0]:
        raise ValueError(
            f"InputAll ({spots_inputall.shape[0]} rows) and "
            f"InputAllExtraInfoFittingAll ({extra_inputall.shape[0]} rows) "
            "must agree on row count."
        )

    n_spots = spots_inputall.shape[0]
    if n_spots == 0:
        raise ValueError("No spots in InputAll.csv. Aborting.")

    # Validate the bin grid up front so a bad paramstest cannot leave a
    # half-written output directory (Spots.bin present, Data.bin missing).
    build_lookup = paramstest.NoSaveAll != 1
    if build_lookup and (paramstest.EtaBinSize <= 0 or paramstest.OmeBinSize <= 0):
        raise ValueError(
            f"EtaBinSize ({paramstest.EtaBinSize}) and OmeBinSize "
            f"({paramstest.OmeBinSize}) must be positive."
        )

    # Move to device.
    spots_t = torch.from_numpy(spots_inputall.astype(np.float64)).to(device=dev, dtype=dt)
    extra_t = torch.from_numpy(extra_inputall.astype(np.float64)).to(device=dev, dtype=dt)

    # Spots.bin layout: InputAll cols 0-7 + RadiusDistIdeal -> (N, 9).
    ring_radii = _build_ring_radii(paramstest).to(device=dev, dtype=dt)
    rad_dist = _radius_dist_for_spots(spots_inputall, spots_t, paramstest, ring_radii)
    spots_out = torch.cat([spots_t, rad_dist.unsqueeze(1)], dim=1)
    extra_out = _extra_info_columns(extra_t)

    if write:
        out_dir.mkdir(parents=True, exist_ok=True)
        bio.write_spots_bin(out_dir / "Spots.bin", spots_out.detach().cpu().numpy().astype(np.float64))
        bio.write_extrainfo_bin(out_dir / "ExtraInfo.bin", extra_out.detach().cpu().numpy().astype(np.float64))

    if not build_lookup:
        return BinDataResult(
            spots=spots_out, extra_info=extra_out,
            paramstest=paramstest,
        )

    # Determine bin counts.
    n_ring_bins = paramstest.highest_ring_no
    n_eta_bins = math.ceil(360.0 / paramstest.EtaBinSize)
    n_ome_bins = math.ceil(360.0 / paramstest.OmeBinSize)

    out_spot_idx, out_ring, out_ieta, out_iome = _bin_assignment(
        spots_t,
        ring_radii,
        margin_ome=paramstest.MarginOme,
        margin_eta=paramstest.MarginEta,
        eta_bin_size=paramstest.EtaBinSize,
        ome_bin_size=paramstest.OmeBinSize,
        step_size_orient=paramstest.StepSizeOrient,
    )
    data, ndata = _bin_to_data_ndata(
        out_spot_idx, out_ring, out_ieta, out_iome,
        n_ring_bins=n_ring_bins, n_eta_bins=n_eta_bins, n_ome_bins=n_ome_bins,
    )

    if write:
        bio.write_data_ndata_bin(
            out_dir / "Data.bin", out_dir / "nData.bin",
            data.detach().cpu().numpy().astype(np.int32),
            ndata.detach().cpu().numpy().astype(np.int32),
        )

    return BinDataResult(
        spots=spots_out, extra_info=extra_out,
        data=data, ndata=ndata,
        n_ring_bins=n_ring_bins, n_eta_bins=n_eta_bins, n_ome_bins=n_ome_bins,
        paramstest=paramstest,
    )
@@ -0,0 +1,175 @@
1
+ """CLI dispatch for ``midas-transforms``.
2
+
3
+ The umbrella command ``midas-transforms <stage> [args]`` plus four
4
+ sub-CLIs that mirror the C-binary argv contracts:
5
+
6
+ ``midas-merge-peaks <zarr_path>``
7
+ ``midas-calc-radius <zarr_path>``
8
+ ``midas-fit-setup <zarr_path>``
9
+ ``midas-bin-data`` (no positional args, reads from the cwd)
10
+
11
+ Each sub-CLI is a thin wrapper around the corresponding library function.
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import argparse
17
+ import sys
18
+ from pathlib import Path
19
+ from typing import List, Optional
20
+
21
+ from . import __version__
22
+
23
+
24
def _common_argparser(prog: str, description: str) -> argparse.ArgumentParser:
    """Create a parser that already carries the options shared by every stage.

    Adds three options:
    - ``--device`` (cpu/cuda/mps), default ``None``, left for the library to
      resolve.
    - ``--dtype`` (float32/float64), default ``None``, also left for the
      library to resolve.
    - ``--version``.
    """
    parser = argparse.ArgumentParser(prog=prog, description=description)
    shared_options = (
        ("--device", {"choices": ["cpu", "cuda", "mps"], "default": None}),
        ("--dtype", {"choices": ["float32", "float64"], "default": None}),
        ("--version", {"action": "version", "version": f"midas-transforms {__version__}"}),
    )
    for flag, options in shared_options:
        parser.add_argument(flag, **options)
    return parser
30
+
31
+
32
def merge_main(argv: Optional[List[str]] = None) -> int:
    """``midas-merge-peaks``: CLI wrapper around ``merge_overlapping_peaks``.

    The result folder defaults to the directory containing ``zarr_path``.
    The overlap length comes from ``--overlap-length`` when given, otherwise
    from the Zarr parameters. Skip-frame, maxima-position and end-frame
    settings always come from the Zarr parameters. Returns 0 on success.
    """
    parser = _common_argparser("midas-merge-peaks", "Frame-by-frame mutual-nearest merge of consolidated peakfit output.")
    parser.add_argument("zarr_path", help="Path to the MIDAS Zarr archive (.zip).")
    parser.add_argument(
        "--result-folder",
        default=None,
        help="Override the result folder (default: directory of zarr_path).",
    )
    parser.add_argument(
        "--allpeaks-ps-bin",
        default=None,
        help="Override the AllPeaks_PS.bin path (default: <result-folder>/Temp/AllPeaks_PS.bin).",
    )
    parser.add_argument(
        "--overlap-length",
        type=float,
        default=None,
        help="Centroid distance threshold in px (default: from Zarr params, fallback 2.0).",
    )
    args = parser.parse_args(argv)

    # Deferred imports keep ``--help`` / ``--version`` fast.
    from .merge import merge_overlapping_peaks
    from .params import read_zarr_params

    if args.result_folder:
        result_folder = Path(args.result_folder)
    else:
        result_folder = Path(args.zarr_path).parent
    zarr_params = read_zarr_params(args.zarr_path)
    if args.overlap_length is None:
        overlap = zarr_params.OverlapLength
    else:
        overlap = args.overlap_length

    merge_overlapping_peaks(
        zarr_path=args.zarr_path,
        allpeaks_ps_bin=args.allpeaks_ps_bin,
        result_folder=result_folder,
        overlap_length=overlap,
        skip_frame=zarr_params.SkipFrame,
        use_maxima_positions=bool(zarr_params.UseMaximaPositions),
        end_nr=zarr_params.EndNr if zarr_params.EndNr > 0 else None,
        device=args.device,
        dtype=args.dtype,
        write=True,
    )
    print(
        f"midas-merge-peaks {__version__}: wrote Result_*.csv and MergeMap.csv to {result_folder}",
        file=sys.stderr,
    )
    return 0
62
+
63
+
64
def radius_main(argv: Optional[List[str]] = None) -> int:
    """``midas-calc-radius``: CLI wrapper around ``calc_radius``.

    Reads stage parameters from the Zarr archive. The result folder defaults
    to the archive's directory. Returns 0 on success.
    """
    parser = _common_argparser("midas-calc-radius", "Per-spot ring/Bragg/grain-volume calculation.")
    parser.add_argument("zarr_path", help="Path to the MIDAS Zarr archive (.zip).")
    parser.add_argument("--result-folder", default=None)
    args = parser.parse_args(argv)

    from .params import read_zarr_params
    from .radius import calc_radius

    if args.result_folder:
        result_folder = Path(args.result_folder)
    else:
        result_folder = Path(args.zarr_path).parent
    zarr_params = read_zarr_params(args.zarr_path)
    calc_radius(
        result_folder=result_folder,
        zarr_params=zarr_params,
        end_nr=zarr_params.EndNr if zarr_params.EndNr > 0 else None,
        device=args.device,
        dtype=args.dtype,
        write=True,
    )
    print(f"midas-calc-radius {__version__}: wrote Radius_*.csv to {result_folder}", file=sys.stderr)
    return 0
81
+
82
+
83
def fit_setup_main(argv: Optional[List[str]] = None) -> int:
    """``midas-fit-setup``: CLI wrapper around ``fit_setup``.

    The geometry refine runs when the Zarr parameters say ``DoFit == 1``,
    unless ``--no-fit`` forces it off. Returns 0 on success.
    """
    parser = _common_argparser("midas-fit-setup", "Per-spot tilt+distortion+wedge correction, filtering, and paramstest.txt writer.")
    parser.add_argument("zarr_path", help="Path to the MIDAS Zarr archive (.zip).")
    parser.add_argument("--result-folder", default=None)
    parser.add_argument("--no-fit", action="store_true", help="Force DoFit=0 (skip the geometry refine).")
    args = parser.parse_args(argv)

    from .fit_setup import fit_setup
    from .params import read_zarr_params

    if args.result_folder:
        result_folder = Path(args.result_folder)
    else:
        result_folder = Path(args.zarr_path).parent
    zarr_params = read_zarr_params(args.zarr_path)
    # ``--no-fit`` short-circuits, so DoFit is only consulted when allowed.
    do_fit = (not args.no_fit) and zarr_params.DoFit == 1
    fit_setup(
        result_folder=result_folder,
        zarr_params=zarr_params,
        end_nr=zarr_params.EndNr if zarr_params.EndNr > 0 else None,
        do_fit=do_fit,
        device=args.device,
        dtype=args.dtype,
        write=True,
    )
    print(f"midas-fit-setup {__version__}: wrote InputAll.csv et al to {result_folder}", file=sys.stderr)
    return 0
103
+
104
+
105
def bin_data_main(argv: Optional[List[str]] = None) -> int:
    """``midas-bin-data``: CLI wrapper around ``bin_data``.

    Takes no positional arguments; inputs are read from (and outputs written
    to) ``--result-folder``, which defaults to the current directory.
    Returns 0 on success.
    """
    parser = _common_argparser("midas-bin-data", "Bin spots into Spots.bin / ExtraInfo.bin / Data.bin / nData.bin.")
    parser.add_argument("--result-folder", default=".")
    args = parser.parse_args(argv)

    from .bin_data import bin_data

    folder = args.result_folder
    bin_data(result_folder=folder, device=args.device, dtype=args.dtype, write=True)
    print(
        f"midas-bin-data {__version__}: wrote Spots.bin / ExtraInfo.bin / Data.bin / nData.bin to {folder}",
        file=sys.stderr,
    )
    return 0
117
+
118
+
119
def main(argv: Optional[List[str]] = None) -> int:
    """Umbrella command: ``midas-transforms <stage> [args]``.

    The first argument selects the stage (``merge-peaks``, ``calc-radius``,
    ``fit-setup``, ``bin-data`` or ``pipeline``). Everything after it is
    forwarded unchanged to that stage's own parser.

    Returns
    -------
    int
        One of:
        - the stage's exit code;
        - 0 for ``-h/--help`` and ``-V/--version``;
        - 2, after printing usage to stderr, when no stage or an unknown
          stage is given.
    """
    parser = argparse.ArgumentParser(
        prog="midas-transforms",
        description="Pure-Python/PyTorch FF-HEDM transforms (merge / radius / fit-setup / bin-data).",
    )
    # Register ``-V`` too: it is accepted below, so ``--help`` should list it.
    parser.add_argument("-V", "--version", action="version", version=f"midas-transforms {__version__}")
    sub = parser.add_subparsers(dest="stage", required=True)

    # Stage name -> sub-CLI. The dict is built when main() runs, so
    # ``pipeline_main`` (defined later in this module) already exists.
    stages = {
        "merge-peaks": merge_main,
        "calc-radius": radius_main,
        "fit-setup": fit_setup_main,
        "bin-data": bin_data_main,
        "pipeline": pipeline_main,
    }
    # Subparsers exist only so ``--help`` lists the stages; dispatch is manual
    # so each stage's own parser sees its arguments untouched.
    for stage_name in stages:
        sub.add_parser(stage_name, add_help=False)

    if argv is None:
        argv = sys.argv[1:]
    if not argv:
        parser.print_help(sys.stderr)
        return 2

    stage, rest = argv[0], argv[1:]
    if stage in ("-h", "--help"):
        parser.print_help(sys.stdout)
        return 0
    if stage in ("-V", "--version"):
        # Same text argparse's version action prints, but returned as a code
        # instead of raising SystemExit from inside a library call.
        print(f"midas-transforms {__version__}")
        return 0

    handler = stages.get(stage)
    if handler is None:
        print(f"midas-transforms: unknown stage {stage!r}", file=sys.stderr)
        parser.print_help(sys.stderr)
        return 2
    return handler(rest)
156
+
157
+
158
def pipeline_main(argv: Optional[List[str]] = None) -> int:
    """``midas-transforms pipeline``: run all stages through ``Pipeline``.

    Builds the pipeline from the Zarr archive, runs it, then dumps the
    outputs. The output directory defaults to the archive's directory.
    Returns 0 on success.
    """
    parser = _common_argparser("midas-transforms pipeline", "Run all four stages on-device with no disk round-trips between them.")
    parser.add_argument("zarr_path", help="Path to the MIDAS Zarr archive (.zip).")
    parser.add_argument("--out-dir", default=None, help="Output directory (default: dir of zarr_path).")
    parser.add_argument("--allpeaks-ps-bin", default=None)
    args = parser.parse_args(argv)

    from .pipeline import Pipeline

    pipeline = Pipeline.from_zarr(
        args.zarr_path,
        allpeaks_ps_bin=args.allpeaks_ps_bin,
        device=args.device,
        dtype=args.dtype,
    )
    pipeline.run()
    if args.out_dir:
        destination = Path(args.out_dir)
    else:
        destination = Path(args.zarr_path).parent
    pipeline.dump(destination)
    print(f"midas-transforms pipeline {__version__}: wrote 9 files to {destination}", file=sys.stderr)
    return 0