essreduce 25.4.1__py3-none-any.whl → 25.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -15,6 +15,7 @@ from scipp.constants import g
15
15
  from scipp.core import label_based_index_to_positional_index
16
16
  from scippneutron.chopper import extract_chopper_from_nexus
17
17
 
18
+ from ..utils import prune_type_vars
18
19
  from . import _nexus_loader as nexus
19
20
  from .types import (
20
21
  AllNeXusComponents,
@@ -718,33 +719,6 @@ def GenericNeXusWorkflow(
718
719
  wf[PreopenNeXusFile] = PreopenNeXusFile(False)
719
720
 
720
721
  if run_types is not None or monitor_types is not None:
721
- _prune_type_vars(wf, run_types=run_types, monitor_types=monitor_types)
722
+ prune_type_vars(wf, run_types=run_types, monitor_types=monitor_types)
722
723
 
723
724
  return wf
724
-
725
-
726
- def _prune_type_vars(
727
- workflow: sciline.Pipeline,
728
- *,
729
- run_types: Iterable[sciline.typing.Key] | None,
730
- monitor_types: Iterable[sciline.typing.Key] | None,
731
- ) -> None:
732
- # Remove all nodes that use a run type or monitor types that is
733
- # not listed in the function arguments.
734
- excluded_run_types = _excluded_type_args(RunType, run_types)
735
- excluded_monitor_types = _excluded_type_args(MonitorType, monitor_types)
736
- excluded_types = excluded_run_types | excluded_monitor_types
737
-
738
- graph = workflow.underlying_graph
739
- to_remove = [
740
- node for node in graph if excluded_types & set(getattr(node, "__args__", set()))
741
- ]
742
- graph.remove_nodes_from(to_remove)
743
-
744
-
745
- def _excluded_type_args(
746
- type_var: Any, keep: Iterable[sciline.typing.Key] | None
747
- ) -> set[sciline.typing.Key]:
748
- if keep is None:
749
- return set()
750
- return set(type_var.__constraints__) - set(keep)
@@ -6,42 +6,54 @@ Utilities for computing real neutron time-of-flight from chopper settings and
6
6
  neutron time-of-arrival at the detectors.
7
7
  """
8
8
 
9
- from .eto_to_tof import default_parameters, providers, resample_tof_data
9
+ from .eto_to_tof import (
10
+ default_parameters,
11
+ providers,
12
+ resample_detector_time_of_flight_data,
13
+ resample_monitor_time_of_flight_data,
14
+ )
10
15
  from .simulation import simulate_beamline
11
16
  from .to_events import to_events
12
17
  from .types import (
18
+ DetectorLtotal,
19
+ DetectorTofData,
13
20
  DistanceResolution,
14
21
  LookupTableRelativeErrorThreshold,
15
- Ltotal,
16
22
  LtotalRange,
23
+ MonitorLtotal,
24
+ MonitorTofData,
17
25
  PulsePeriod,
18
26
  PulseStride,
19
27
  PulseStrideOffset,
20
- RawData,
21
- ResampledTofData,
28
+ ResampledDetectorTofData,
29
+ ResampledMonitorTofData,
22
30
  SimulationResults,
23
31
  TimeOfFlightLookupTable,
24
32
  TimeResolution,
25
- TofData,
26
33
  )
34
+ from .workflow import GenericTofWorkflow
27
35
 
28
36
  __all__ = [
37
+ "DetectorLtotal",
38
+ "DetectorTofData",
29
39
  "DistanceResolution",
40
+ "GenericTofWorkflow",
30
41
  "LookupTableRelativeErrorThreshold",
31
- "Ltotal",
32
42
  "LtotalRange",
43
+ "MonitorLtotal",
44
+ "MonitorTofData",
33
45
  "PulsePeriod",
34
46
  "PulseStride",
35
47
  "PulseStrideOffset",
36
- "RawData",
37
- "ResampledTofData",
48
+ "ResampledDetectorTofData",
49
+ "ResampledMonitorTofData",
38
50
  "SimulationResults",
39
51
  "TimeOfFlightLookupTable",
40
52
  "TimeResolution",
41
- "TofData",
42
53
  "default_parameters",
43
54
  "providers",
44
- "resample_tof_data",
55
+ "resample_detector_time_of_flight_data",
56
+ "resample_monitor_time_of_flight_data",
45
57
  "simulate_beamline",
46
58
  "to_events",
47
59
  ]
@@ -11,6 +11,7 @@ from collections.abc import Callable
11
11
 
12
12
  import numpy as np
13
13
  import scipp as sc
14
+ import scippneutron as scn
14
15
  from scipp._scipp.core import _bins_no_validate
15
16
  from scippneutron._utils import elem_unit
16
17
 
@@ -18,21 +19,32 @@ try:
18
19
  from .interpolator_numba import Interpolator as InterpolatorImpl
19
20
  except ImportError:
20
21
  from .interpolator_scipy import Interpolator as InterpolatorImpl
22
+
23
+ from ..nexus.types import (
24
+ CalibratedBeamline,
25
+ CalibratedMonitor,
26
+ DetectorData,
27
+ MonitorData,
28
+ MonitorType,
29
+ RunType,
30
+ )
21
31
  from .to_events import to_events
22
32
  from .types import (
33
+ DetectorLtotal,
34
+ DetectorTofData,
23
35
  DistanceResolution,
24
36
  LookupTableRelativeErrorThreshold,
25
- Ltotal,
26
37
  LtotalRange,
38
+ MonitorLtotal,
39
+ MonitorTofData,
27
40
  PulsePeriod,
28
41
  PulseStride,
29
42
  PulseStrideOffset,
30
- RawData,
31
- ResampledTofData,
43
+ ResampledDetectorTofData,
44
+ ResampledMonitorTofData,
32
45
  SimulationResults,
33
46
  TimeOfFlightLookupTable,
34
47
  TimeResolution,
35
- TofData,
36
48
  )
37
49
 
38
50
 
@@ -84,7 +96,7 @@ def _compute_mean_tof_in_distance_range(
84
96
  frame_period:
85
97
  Period of the source pulses, i.e., time between consecutive pulse starts.
86
98
  time_bins_half_width:
87
- Half the width of the time bins.
99
+ Half width of the time bins in the event_time_offset axis.
88
100
  """
89
101
  simulation_distance = simulation.distance.to(unit=distance_unit)
90
102
  distances = sc.midpoints(distance_bins)
@@ -104,8 +116,7 @@ def _compute_mean_tof_in_distance_range(
104
116
  },
105
117
  ).flatten(to="event")
106
118
 
107
- # Add the event_time_offset coordinate to the data. We first operate on the
108
- # frame period. The table will later be folded to the pulse period.
119
+ # Add the event_time_offset coordinate, wrapped to the frame_period
109
120
  data.coords['event_time_offset'] = data.coords['toa'] % frame_period
110
121
 
111
122
  # Because we staggered the mesh by half a bin width, we want the values above
@@ -134,51 +145,6 @@ def _compute_mean_tof_in_distance_range(
134
145
  return mean_tof
135
146
 
136
147
 
137
- def _fold_table_to_pulse_period(
138
- table: sc.DataArray, pulse_period: sc.Variable, pulse_stride: int
139
- ) -> sc.DataArray:
140
- """
141
- Fold the lookup table to the pulse period. We make sure the left and right edges of
142
- the table wrap around the ``event_time_offset`` dimension.
143
-
144
- Parameters
145
- ----------
146
- table:
147
- Lookup table with time-of-flight as a function of distance and time-of-arrival.
148
- pulse_period:
149
- Period of the source pulses, i.e., time between consecutive pulse starts.
150
- pulse_stride:
151
- Stride of used pulses. Usually 1, but may be a small integer when
152
- pulse-skipping.
153
- """
154
- size = table.sizes['event_time_offset']
155
- if (size % pulse_stride) != 0:
156
- raise ValueError(
157
- "TimeOfFlightLookupTable: the number of time bins must be a multiple of "
158
- f"the pulse stride, but got {size} time bins and a pulse stride of "
159
- f"{pulse_stride}."
160
- )
161
-
162
- size = size // pulse_stride
163
- out = sc.concat([table, table['event_time_offset', 0]], dim='event_time_offset')
164
- out = sc.concat(
165
- [
166
- out['event_time_offset', (i * size) : (i + 1) * size + 1]
167
- for i in range(pulse_stride)
168
- ],
169
- dim='pulse',
170
- )
171
- return out.assign_coords(
172
- event_time_offset=sc.concat(
173
- [
174
- table.coords['event_time_offset']['event_time_offset', :size],
175
- pulse_period,
176
- ],
177
- 'event_time_offset',
178
- )
179
- )
180
-
181
-
182
148
  def compute_tof_lookup_table(
183
149
  simulation: SimulationResults,
184
150
  ltotal_range: LtotalRange,
@@ -212,6 +178,43 @@ def compute_tof_lookup_table(
212
178
  error_threshold:
213
179
  Threshold for the relative standard deviation (coefficient of variation) of the
214
180
  projected time-of-flight above which values are masked.
181
+
182
+ Notes
183
+ -----
184
+
185
+ Below are some details about the binning and wrapping around frame period in the
186
+ time dimension.
187
+
188
+ We have some simulated ``toa`` (events) from a Tof/McStas simulation.
189
+ Those are absolute ``toa``, unwrapped.
190
+ First we compute the usual ``event_time_offset = toa % frame_period``.
191
+
192
+ Now, we want to ensure periodic boundaries. If we make a bin centered around 0,
193
+ and a bin centered around 71ms: the first bin will use events between 0 and
194
+ ``0.5 * dt`` (where ``dt`` is the bin width).
195
+ The last bin will use events between ``frame_period - 0.5*dt`` and
196
+ ``frame_period + 0.5 * dt``. So when we compute the mean inside those two bins,
197
+ they will not yield the same results.
198
+ It is as if the first bin is missing the events it should have between
199
+ ``-0.5 * dt`` and 0 (because of the modulo we computed above).
200
+
201
+ To fix this, we do not make a last bin around 71ms (the bins stop at
202
+ ``frame_period - 0.5*dt``). Instead, we compute modulo a second time,
203
+ but this time using ``event_time_offset %= (frame_period - 0.5*dt)``.
204
+ (we cannot directly do ``event_time_offset = toa % (frame_period - 0.5*dt)`` in a
205
+ single step because it would introduce a gradual shift,
206
+ as the pulse number increases).
207
+
208
+ This second modulo effectively takes all the events that would have gone in the
209
+ last bin (between ``frame_period - 0.5*dt`` and ``frame_period``) and puts them in
210
+ the first bin. Instead of placing them between ``-0.5*dt`` and 0,
211
+ it places them between 0 and ``0.5*dt``, but this does not really matter,
212
+ because we then take the mean inside the first bin.
213
+ Whether the events are on the left or right side of zero does not matter.
214
+
215
+ Finally, we make a copy of the left edge, and append it to the right of the table,
216
+ thus ensuring that the values on the right edge are strictly the same as on the
217
+ left edge.
215
218
  """
216
219
  distance_unit = "m"
217
220
  time_unit = simulation.time_of_arrival.unit
@@ -276,16 +279,31 @@ def compute_tof_lookup_table(
276
279
  table.coords["distance"] = sc.midpoints(table.coords["distance"])
277
280
  table.coords["event_time_offset"] = sc.midpoints(table.coords["event_time_offset"])
278
281
 
279
- table = _fold_table_to_pulse_period(
280
- table=table, pulse_period=pulse_period, pulse_stride=pulse_stride
282
+ # Copy the left edge to the right to create periodic boundary conditions
283
+ table = sc.DataArray(
284
+ data=sc.concat(
285
+ [table.data, table.data['event_time_offset', 0]], dim='event_time_offset'
286
+ ),
287
+ coords={
288
+ "distance": table.coords["distance"],
289
+ "event_time_offset": sc.concat(
290
+ [table.coords["event_time_offset"], frame_period],
291
+ dim='event_time_offset',
292
+ ),
293
+ "pulse_period": pulse_period,
294
+ "pulse_stride": sc.scalar(pulse_stride, unit=None),
295
+ "distance_resolution": table.coords["distance"][1]
296
+ - table.coords["distance"][0],
297
+ "time_resolution": table.coords["event_time_offset"][1]
298
+ - table.coords["event_time_offset"][0],
299
+ "error_threshold": sc.scalar(error_threshold),
300
+ },
281
301
  )
282
302
 
283
303
  # In-place masking for better performance
284
304
  _mask_large_uncertainty(table, error_threshold)
285
305
 
286
- return TimeOfFlightLookupTable(
287
- table.transpose(('pulse', 'distance', 'event_time_offset'))
288
- )
306
+ return TimeOfFlightLookupTable(table)
289
307
 
290
308
 
291
309
  class TofInterpolator:
@@ -293,22 +311,6 @@ class TofInterpolator:
293
311
  self._distance_unit = distance_unit
294
312
  self._time_unit = time_unit
295
313
 
296
- # In the pulse dimension, it could be that for a given event_time_offset and
297
- # distance, a tof value is finite in one pulse and NaN in the other.
298
- # When using the bilinear interpolation, even if the value of the requested
299
- # point is exactly 0 or 1 (in the case of pulse_stride=2), the interpolator
300
- # will still use all 4 corners surrounding the point. This means that if one of
301
- # the corners is NaN, the result will be NaN.
302
- # Here, we use a trick where we duplicate the lookup values in the 'pulse'
303
- # dimension so that the interpolator has values on bin edges for that dimension.
304
- # The interpolator raises an error if axes coordinates are not strictly
305
- # monotonic, so we cannot use e.g. [-0.5, 0.5, 0.5, 1.5] in the case of
306
- # pulse_stride=2. Instead we use [-0.25, 0.25, 0.75, 1.25].
307
- base_grid = np.arange(float(lookup.sizes["pulse"]))
308
- self._pulse_edges = np.sort(
309
- np.concatenate([base_grid - 0.25, base_grid + 0.25])
310
- )
311
-
312
314
  self._time_edges = (
313
315
  lookup.coords["event_time_offset"]
314
316
  .to(unit=self._time_unit, copy=False)
@@ -321,23 +323,16 @@ class TofInterpolator:
321
323
  self._interpolator = InterpolatorImpl(
322
324
  time_edges=self._time_edges,
323
325
  distance_edges=self._distance_edges,
324
- pulse_edges=self._pulse_edges,
325
- values=np.repeat(
326
- lookup.data.to(unit=self._time_unit, copy=False).values, 2, axis=0
327
- ),
326
+ values=lookup.data.to(unit=self._time_unit, copy=False).values,
328
327
  )
329
328
 
330
329
  def __call__(
331
330
  self,
332
- pulse_index: sc.Variable,
333
331
  ltotal: sc.Variable,
334
332
  event_time_offset: sc.Variable,
333
+ pulse_period: sc.Variable,
334
+ pulse_index: sc.Variable | None = None,
335
335
  ) -> sc.Variable:
336
- if pulse_index.unit not in ("", None):
337
- raise sc.UnitError(
338
- "pulse_index must have unit dimensionless or None, "
339
- f"but got unit: {pulse_index.unit}."
340
- )
341
336
  if ltotal.unit != self._distance_unit:
342
337
  raise sc.UnitError(
343
338
  f"ltotal must have unit: {self._distance_unit}, "
@@ -349,31 +344,30 @@ class TofInterpolator:
349
344
  f"but got unit: {event_time_offset.unit}."
350
345
  )
351
346
  out_dims = event_time_offset.dims
352
- pulse_index = pulse_index.values
353
347
  ltotal = ltotal.values
354
348
  event_time_offset = event_time_offset.values
355
349
 
356
350
  return sc.array(
357
351
  dims=out_dims,
358
352
  values=self._interpolator(
359
- times=event_time_offset, distances=ltotal, pulse_indices=pulse_index
353
+ times=event_time_offset,
354
+ distances=ltotal,
355
+ pulse_index=pulse_index.values if pulse_index is not None else None,
356
+ pulse_period=pulse_period.value,
360
357
  ),
361
358
  unit=self._time_unit,
362
359
  )
363
360
 
364
361
 
365
362
  def _time_of_flight_data_histogram(
366
- da: sc.DataArray,
367
- lookup: sc.DataArray,
368
- ltotal: sc.Variable,
369
- pulse_period: sc.Variable,
363
+ da: sc.DataArray, lookup: sc.DataArray, ltotal: sc.Variable
370
364
  ) -> sc.DataArray:
371
365
  # In NeXus, 'time_of_flight' is the canonical name in NXmonitor, but in some files,
372
366
  # it may be called 'tof'.
373
367
  key = next(iter(set(da.coords.keys()) & {"time_of_flight", "tof"}))
374
368
  raw_eto = da.coords[key].to(dtype=float, copy=False)
375
369
  eto_unit = raw_eto.unit
376
- pulse_period = pulse_period.to(unit=eto_unit)
370
+ pulse_period = lookup.coords["pulse_period"].to(unit=eto_unit)
377
371
 
378
372
  # In histogram mode, because there is a wrap around at the end of the pulse, we
379
373
  # need to insert a bin edge at that exact location to avoid having the last bin
@@ -386,31 +380,14 @@ def _time_of_flight_data_histogram(
386
380
  rebinned = da.rebin({key: new_bins})
387
381
  etos = rebinned.coords[key]
388
382
 
389
- # In histogram mode, the lookup table cannot have a pulse dimension because we
390
- # cannot know in the histogrammed data which pulse the events belong to.
391
- # So we merge the pulse dimension in the lookup table. A quick way to do this
392
- # is to take the mean of the data along the pulse dimension (there should
393
- # only be regions that are NaN in one pulse and finite in the other).
394
- merged = lookup.data.nanmean('pulse')
395
- dim = merged.dims[0]
396
- lookup = sc.DataArray(
397
- data=merged.fold(dim=dim, sizes={'pulse': 1, dim: merged.sizes[dim]}),
398
- coords={
399
- 'pulse': sc.arange('pulse', 1.0),
400
- 'distance': lookup.coords['distance'],
401
- 'event_time_offset': lookup.coords['event_time_offset'],
402
- },
403
- )
404
- pulse_index = sc.zeros(sizes=etos.sizes)
405
-
406
383
  # Create linear interpolator
407
384
  interp = TofInterpolator(lookup, distance_unit=ltotal.unit, time_unit=eto_unit)
408
385
 
409
386
  # Compute time-of-flight of the bin edges using the interpolator
410
387
  tofs = interp(
411
- pulse_index=pulse_index,
412
388
  ltotal=ltotal.broadcast(sizes=etos.sizes),
413
389
  event_time_offset=etos,
390
+ pulse_period=pulse_period,
414
391
  )
415
392
 
416
393
  return rebinned.assign_coords(tof=tofs)
@@ -420,6 +397,7 @@ def _guess_pulse_stride_offset(
420
397
  pulse_index: sc.Variable,
421
398
  ltotal: sc.Variable,
422
399
  event_time_offset: sc.Variable,
400
+ pulse_period: sc.Variable,
423
401
  pulse_stride: int,
424
402
  interp: TofInterpolator,
425
403
  ) -> int:
@@ -446,6 +424,8 @@ def _guess_pulse_stride_offset(
446
424
  Total length of the flight path from the source to the detector for each event.
447
425
  event_time_offset:
448
426
  Time of arrival of the neutron at the detector for each event.
427
+ pulse_period:
428
+ Period of the source pulses, i.e., time between consecutive pulse starts.
449
429
  pulse_stride:
450
430
  Stride of used pulses.
451
431
  interp:
@@ -469,7 +449,12 @@ def _guess_pulse_stride_offset(
469
449
  )
470
450
  for i in range(pulse_stride):
471
451
  pulse_inds = (pulse_index + i) % pulse_stride
472
- tofs[i] = interp(pulse_index=pulse_inds, ltotal=ltotal, event_time_offset=etos)
452
+ tofs[i] = interp(
453
+ ltotal=ltotal,
454
+ event_time_offset=etos,
455
+ pulse_index=pulse_inds,
456
+ pulse_period=pulse_period,
457
+ )
473
458
  # Find the entry in the list with the least number of nan values
474
459
  return sorted(tofs, key=lambda x: sc.isnan(tofs[x]).sum())[0]
475
460
 
@@ -478,8 +463,6 @@ def _time_of_flight_data_events(
478
463
  da: sc.DataArray,
479
464
  lookup: sc.DataArray,
480
465
  ltotal: sc.Variable,
481
- pulse_period: sc.Variable,
482
- pulse_stride: int,
483
466
  pulse_stride_offset: int,
484
467
  ) -> sc.DataArray:
485
468
  etos = da.bins.coords["event_time_offset"].to(dtype=float, copy=False)
@@ -492,20 +475,21 @@ def _time_of_flight_data_events(
492
475
  ltotal = sc.bins_like(etos, ltotal).bins.constituents["data"]
493
476
  etos = etos.bins.constituents["data"]
494
477
 
495
- # Compute a pulse index for every event: it is the index of the pulse within a
496
- # frame period. When there is no pulse skipping, those are all zero. When there is
497
- # pulse skipping, the index ranges from zero to pulse_stride - 1.
498
- if pulse_stride == 1:
499
- pulse_index = sc.zeros(sizes=etos.sizes)
500
- else:
478
+ pulse_index = None
479
+ pulse_period = lookup.coords["pulse_period"].to(unit=eto_unit)
480
+ pulse_stride = lookup.coords["pulse_stride"].value
481
+
482
+ if pulse_stride > 1:
483
+ # Compute a pulse index for every event: it is the index of the pulse within a
484
+ # frame period. The index ranges from zero to pulse_stride - 1.
501
485
  etz_unit = 'ns'
502
486
  etz = (
503
487
  da.bins.coords["event_time_zero"]
504
488
  .bins.constituents["data"]
505
489
  .to(unit=etz_unit, copy=False)
506
490
  )
507
- pulse_period = pulse_period.to(unit=etz_unit, dtype=int)
508
- frame_period = pulse_period * pulse_stride
491
+ pulse_period_ns = pulse_period.to(unit=etz_unit, dtype=int)
492
+ frame_period = pulse_period_ns * pulse_stride
509
493
  # Define a common reference time using epoch as a base, but making sure that it
510
494
  # is aligned with the pulse_period and the frame_period.
511
495
  # We need to use a global reference time instead of simply taking the minimum
@@ -513,17 +497,17 @@ def _time_of_flight_data_events(
513
497
  # may not be the first event of the first pulse for all chunks. This would lead
514
498
  # to inconsistent pulse indices.
515
499
  epoch = sc.datetime(0, unit=etz_unit)
516
- diff_to_epoch = (etz.min() - epoch) % pulse_period
500
+ diff_to_epoch = (etz.min() - epoch) % pulse_period_ns
517
501
  # Here we offset the reference by half a pulse period to avoid errors from
518
502
  # fluctuations in the event_time_zeros in the data. They are triggered by the
519
503
  # neutron source, and may not always be exactly separated by the pulse period.
520
504
  # While fluctuations will exist, they will be small, and offsetting the times
521
505
  # by half a pulse period is a simple enough fix.
522
- reference = epoch + diff_to_epoch - (pulse_period // 2)
506
+ reference = epoch + diff_to_epoch - (pulse_period_ns // 2)
523
507
  # Use in-place operations to avoid large allocations
524
508
  pulse_index = etz - reference
525
509
  pulse_index %= frame_period
526
- pulse_index //= pulse_period
510
+ pulse_index //= pulse_period_ns
527
511
 
528
512
  # Apply the pulse_stride_offset
529
513
  if pulse_stride_offset is None:
@@ -531,6 +515,7 @@ def _time_of_flight_data_events(
531
515
  pulse_index=pulse_index,
532
516
  ltotal=ltotal,
533
517
  event_time_offset=etos,
518
+ pulse_period=pulse_period,
534
519
  pulse_stride=pulse_stride,
535
520
  interp=interp,
536
521
  )
@@ -538,21 +523,86 @@ def _time_of_flight_data_events(
538
523
  pulse_index %= pulse_stride
539
524
 
540
525
  # Compute time-of-flight for all neutrons using the interpolator
541
- tofs = interp(pulse_index=pulse_index, ltotal=ltotal, event_time_offset=etos)
526
+ tofs = interp(
527
+ ltotal=ltotal,
528
+ event_time_offset=etos,
529
+ pulse_index=pulse_index,
530
+ pulse_period=pulse_period,
531
+ )
542
532
 
543
533
  parts = da.bins.constituents
544
534
  parts["data"] = tofs
545
535
  return da.bins.assign_coords(tof=_bins_no_validate(**parts))
546
536
 
547
537
 
548
- def time_of_flight_data(
549
- da: RawData,
538
+ def detector_ltotal_from_straight_line_approximation(
539
+ detector_beamline: CalibratedBeamline[RunType],
540
+ ) -> DetectorLtotal[RunType]:
541
+ """
542
+ Compute Ltotal for the detector pixels.
543
+ This is a naive straight-line approximation to Ltotal based on basic component
544
+ positions.
545
+
546
+ Parameters
547
+ ----------
548
+ detector_beamline:
549
+ Beamline data for the detector that contains the positions necessary to compute
550
+ the straight-line approximation to Ltotal (source, sample, and detector
551
+ positions).
552
+ """
553
+ graph = scn.conversion.graph.beamline.beamline(scatter=True)
554
+ return DetectorLtotal[RunType](
555
+ detector_beamline.transform_coords(
556
+ "Ltotal", graph=graph, keep_intermediate=False
557
+ ).coords["Ltotal"]
558
+ )
559
+
560
+
561
+ def monitor_ltotal_from_straight_line_approximation(
562
+ monitor_beamline: CalibratedMonitor[RunType, MonitorType],
563
+ ) -> MonitorLtotal[RunType, MonitorType]:
564
+ """
565
+ Compute Ltotal for the monitor.
566
+ This is a naive straight-line approximation to Ltotal based on basic component
567
+ positions.
568
+
569
+ Parameters
570
+ ----------
571
+ monitor_beamline:
572
+ Beamline data for the monitor that contains the positions necessary to compute
573
+ the straight-line approximation to Ltotal (source and monitor positions).
574
+ """
575
+ graph = scn.conversion.graph.beamline.beamline(scatter=False)
576
+ return MonitorLtotal[RunType, MonitorType](
577
+ monitor_beamline.transform_coords(
578
+ "Ltotal", graph=graph, keep_intermediate=False
579
+ ).coords["Ltotal"]
580
+ )
581
+
582
+
583
+ def _compute_tof_data(
584
+ da: sc.DataArray,
585
+ lookup: sc.DataArray,
586
+ ltotal: sc.Variable,
587
+ pulse_stride_offset: int,
588
+ ) -> sc.DataArray:
589
+ if da.bins is None:
590
+ return _time_of_flight_data_histogram(da=da, lookup=lookup, ltotal=ltotal)
591
+ else:
592
+ return _time_of_flight_data_events(
593
+ da=da,
594
+ lookup=lookup,
595
+ ltotal=ltotal,
596
+ pulse_stride_offset=pulse_stride_offset,
597
+ )
598
+
599
+
600
+ def detector_time_of_flight_data(
601
+ detector_data: DetectorData[RunType],
550
602
  lookup: TimeOfFlightLookupTable,
551
- ltotal: Ltotal,
552
- pulse_period: PulsePeriod,
553
- pulse_stride: PulseStride,
603
+ ltotal: DetectorLtotal[RunType],
554
604
  pulse_stride_offset: PulseStrideOffset,
555
- ) -> TofData:
605
+ ) -> DetectorTofData[RunType]:
556
606
  """
557
607
  Convert the time-of-arrival data to time-of-flight data using a lookup table.
558
608
  The output data will have a time-of-flight coordinate.
@@ -567,33 +617,55 @@ def time_of_flight_data(
567
617
  arrival.
568
618
  ltotal:
569
619
  Total length of the flight path from the source to the detector.
570
- pulse_period:
571
- Period of the source pulses, i.e., time between consecutive pulse starts.
572
- pulse_stride:
573
- Stride of used pulses. Usually 1, but may be a small integer when
574
- pulse-skipping.
575
620
  pulse_stride_offset:
576
621
  When pulse-skipping, the offset of the first pulse in the stride. This is
577
622
  typically zero but can be a small integer < pulse_stride.
578
623
  """
579
-
580
- if da.bins is None:
581
- out = _time_of_flight_data_histogram(
582
- da=da, lookup=lookup, ltotal=ltotal, pulse_period=pulse_period
624
+ return DetectorTofData[RunType](
625
+ _compute_tof_data(
626
+ da=detector_data,
627
+ lookup=lookup,
628
+ ltotal=ltotal,
629
+ pulse_stride_offset=pulse_stride_offset,
583
630
  )
584
- else:
585
- out = _time_of_flight_data_events(
586
- da=da,
631
+ )
632
+
633
+
634
+ def monitor_time_of_flight_data(
635
+ monitor_data: MonitorData[RunType, MonitorType],
636
+ lookup: TimeOfFlightLookupTable,
637
+ ltotal: MonitorLtotal[RunType, MonitorType],
638
+ pulse_stride_offset: PulseStrideOffset,
639
+ ) -> MonitorTofData[RunType, MonitorType]:
640
+ """
641
+ Convert the time-of-arrival data to time-of-flight data using a lookup table.
642
+ The output data will have a time-of-flight coordinate.
643
+
644
+ Parameters
645
+ ----------
646
+ da:
647
+ Raw monitor data loaded from a NeXus file, e.g., NXmonitor containing
648
+ NXevent_data.
649
+ lookup:
650
+ Lookup table giving time-of-flight as a function of distance and time of
651
+ arrival.
652
+ ltotal:
653
+ Total length of the flight path from the source to the monitor.
654
+ pulse_stride_offset:
655
+ When pulse-skipping, the offset of the first pulse in the stride. This is
656
+ typically zero but can be a small integer < pulse_stride.
657
+ """
658
+ return MonitorTofData[RunType, MonitorType](
659
+ _compute_tof_data(
660
+ da=monitor_data,
587
661
  lookup=lookup,
588
662
  ltotal=ltotal,
589
- pulse_period=pulse_period,
590
- pulse_stride=pulse_stride,
591
663
  pulse_stride_offset=pulse_stride_offset,
592
664
  )
593
- return TofData(out)
665
+ )
594
666
 
595
667
 
596
- def resample_tof_data(da: TofData) -> ResampledTofData:
668
+ def _resample_tof_data(da: sc.DataArray) -> sc.DataArray:
597
669
  """
598
670
  Histogrammed data that has been converted to `tof` will typically have
599
671
  unsorted bin edges (due to either wrapping of `time_of_flight` or wavelength
@@ -626,13 +698,29 @@ def resample_tof_data(da: TofData) -> ResampledTofData:
626
698
  coord = da.coords["tof"]
627
699
  bin_width = (coord[dim, 1:] - coord[dim, :-1]).nanmedian()
628
700
  rehist = events.hist(tof=bin_width)
629
- return ResampledTofData(
630
- rehist.assign_coords(
631
- {key: var for key, var in da.coords.items() if dim not in var.dims}
632
- )
701
+ return rehist.assign_coords(
702
+ {key: var for key, var in da.coords.items() if dim not in var.dims}
633
703
  )
634
704
 
635
705
 
706
+ def resample_detector_time_of_flight_data(
707
+ da: DetectorTofData[RunType],
708
+ ) -> ResampledDetectorTofData[RunType]:
709
+ """
710
+ Resample the detector time-of-flight data to ensure that the bin edges are sorted.
711
+ """
712
+ return ResampledDetectorTofData(_resample_tof_data(da))
713
+
714
+
715
+ def resample_monitor_time_of_flight_data(
716
+ da: MonitorTofData[RunType, MonitorType],
717
+ ) -> ResampledMonitorTofData[RunType, MonitorType]:
718
+ """
719
+ Resample the monitor time-of-flight data to ensure that the bin edges are sorted.
720
+ """
721
+ return ResampledMonitorTofData(_resample_tof_data(da))
722
+
723
+
636
724
  def default_parameters() -> dict:
637
725
  """
638
726
  Default parameters of the time-of-flight workflow.
@@ -651,4 +739,10 @@ def providers() -> tuple[Callable]:
651
739
  """
652
740
  Providers of the time-of-flight workflow.
653
741
  """
654
- return (compute_tof_lookup_table, time_of_flight_data)
742
+ return (
743
+ compute_tof_lookup_table,
744
+ detector_time_of_flight_data,
745
+ monitor_time_of_flight_data,
746
+ detector_ltotal_from_straight_line_approximation,
747
+ monitor_ltotal_from_straight_line_approximation,
748
+ )
@@ -28,9 +28,9 @@ class FakeBeamline:
28
28
  import math
29
29
 
30
30
  import tof as tof_pkg
31
- from tof.facilities.ess_pulse import pulse
31
+ from tof.facilities.ess_pulse import frequency as ess_frequency
32
32
 
33
- self.frequency = pulse.frequency
33
+ self.frequency = ess_frequency
34
34
  self.npulses = math.ceil((run_length * self.frequency).to(unit="").value)
35
35
  self.events_per_pulse = events_per_pulse
36
36
  if source_position is None:
@@ -8,111 +8,82 @@ from numba import njit, prange
8
8
  def interpolate(
9
9
  x: np.ndarray,
10
10
  y: np.ndarray,
11
- z: np.ndarray,
12
11
  values: np.ndarray,
13
12
  xp: np.ndarray,
14
13
  yp: np.ndarray,
15
- zp: np.ndarray,
14
+ xoffset: np.ndarray | None,
15
+ deltax: float,
16
16
  fill_value: float,
17
17
  out: np.ndarray,
18
18
  ):
19
19
  """
20
- Linear interpolation of data on a 3D regular grid.
20
+ Linear interpolation of data on a 2D regular grid.
21
21
 
22
22
  Parameters
23
23
  ----------
24
24
  x:
25
- 1D array of grid edges along the x-axis. They must be linspaced.
25
+ 1D array of grid edges along the x-axis (size nx). They must be linspaced.
26
26
  y:
27
- 1D array of grid edges along the y-axis. They must be linspaced.
28
- z:
29
- 1D array of grid edges along the z-axis. They must be linspaced.
27
+ 1D array of grid edges along the y-axis (size ny). They must be linspaced.
30
28
  values:
31
- 3D array of values on the grid. The shape must be (nz, ny, nx).
29
+ 2D array of values on the grid. The shape must be (ny, nx).
32
30
  xp:
33
31
  1D array of x-coordinates where to interpolate (size N).
34
32
  yp:
35
33
  1D array of y-coordinates where to interpolate (size N).
36
- zp:
37
- 1D array of z-coordinates where to interpolate (size N).
34
+ xoffset:
35
+ 1D array of integer offsets to apply to the x-coordinates (size N).
36
+ deltax:
37
+ Multiplier to apply to the integer offsets (i.e. the step size).
38
38
  fill_value:
39
39
  Value to use for points outside of the grid.
40
40
  out:
41
41
  1D array where the interpolated values will be stored (size N).
42
42
  """
43
- if not (len(xp) == len(yp) == len(zp) == len(out)):
43
+ if not (len(xp) == len(yp) == len(out)):
44
44
  raise ValueError("Interpolator: all input arrays must have the same size.")
45
45
 
46
46
  nx = len(x)
47
47
  ny = len(y)
48
- nz = len(z)
49
48
  npoints = len(xp)
50
49
  xmin = x[0]
51
50
  xmax = x[nx - 1]
52
51
  ymin = y[0]
53
52
  ymax = y[ny - 1]
54
- zmin = z[0]
55
- zmax = z[nz - 1]
56
53
  dx = x[1] - xmin
57
54
  dy = y[1] - ymin
58
- dz = z[1] - zmin
59
55
 
60
56
  one_over_dx = 1.0 / dx
61
57
  one_over_dy = 1.0 / dy
62
- one_over_dz = 1.0 / dz
63
- norm = one_over_dx * one_over_dy * one_over_dz
58
+ norm = one_over_dx * one_over_dy
64
59
 
65
60
  for i in prange(npoints):
66
- xx = xp[i]
61
+ xx = xp[i] + (xoffset[i] * deltax if xoffset is not None else 0.0)
67
62
  yy = yp[i]
68
- zz = zp[i]
69
-
70
- if (
71
- (xx < xmin)
72
- or (xx > xmax)
73
- or (yy < ymin)
74
- or (yy > ymax)
75
- or (zz < zmin)
76
- or (zz > zmax)
77
- ):
63
+
64
+ if (xx < xmin) or (xx > xmax) or (yy < ymin) or (yy > ymax):
78
65
  out[i] = fill_value
79
66
 
80
67
  else:
81
68
  ix = nx - 2 if xx == xmax else int((xx - xmin) * one_over_dx)
82
69
  iy = ny - 2 if yy == ymax else int((yy - ymin) * one_over_dy)
83
- iz = nz - 2 if zz == zmax else int((zz - zmin) * one_over_dz)
84
70
 
85
71
  x1 = x[ix]
86
72
  x2 = x[ix + 1]
87
73
  y1 = y[iy]
88
74
  y2 = y[iy + 1]
89
- z1 = z[iz]
90
- z2 = z[iz + 1]
91
-
92
- a111 = values[iz, iy, ix]
93
- a211 = values[iz, iy, ix + 1]
94
- a121 = values[iz, iy + 1, ix]
95
- a221 = values[iz, iy + 1, ix + 1]
96
- a112 = values[iz + 1, iy, ix]
97
- a212 = values[iz + 1, iy, ix + 1]
98
- a122 = values[iz + 1, iy + 1, ix]
99
- a222 = values[iz + 1, iy + 1, ix + 1]
75
+
76
+ a11 = values[iy, ix]
77
+ a21 = values[iy, ix + 1]
78
+ a12 = values[iy + 1, ix]
79
+ a22 = values[iy + 1, ix + 1]
100
80
 
101
81
  x2mxx = x2 - xx
102
82
  xxmx1 = xx - x1
103
- y2myy = y2 - yy
104
- yymy1 = yy - y1
83
+
105
84
  out[i] = (
106
- (z2 - zz)
107
- * (
108
- y2myy * (x2mxx * a111 + xxmx1 * a211)
109
- + yymy1 * (x2mxx * a121 + xxmx1 * a221)
110
- )
111
- + (zz - z1)
112
- * (
113
- y2myy * (x2mxx * a112 + xxmx1 * a212)
114
- + yymy1 * (x2mxx * a122 + xxmx1 * a222)
115
- )
85
+ (y2 - yy) * (x2mxx * a11 + xxmx1 * a21)
86
+ + (yy - y1) * (x2mxx * a12 + xxmx1 * a22)
116
87
  ) * norm
117
88
 
118
89
 
@@ -121,12 +92,11 @@ class Interpolator:
121
92
  self,
122
93
  time_edges: np.ndarray,
123
94
  distance_edges: np.ndarray,
124
- pulse_edges: np.ndarray,
125
95
  values: np.ndarray,
126
96
  fill_value: float = np.nan,
127
97
  ):
128
98
  """
129
- Interpolator for 3D regular grid data (Numba implementation).
99
+ Interpolator for 2D regular grid data (Numba implementation).
130
100
 
131
101
  Parameters
132
102
  ----------
@@ -134,31 +104,32 @@ class Interpolator:
134
104
  1D array of time edges.
135
105
  distance_edges:
136
106
  1D array of distance edges.
137
- pulse_edges:
138
- 1D array of pulse edges.
139
107
  values:
140
- 3D array of values on the grid. The shape must be (nz, ny, nx).
108
+ 2D array of values on the grid. The shape must be (ny, nx).
141
109
  fill_value:
142
110
  Value to use for points outside of the grid.
143
111
  """
144
112
  self.time_edges = time_edges
145
113
  self.distance_edges = distance_edges
146
- self.pulse_edges = pulse_edges
147
114
  self.values = values
148
115
  self.fill_value = fill_value
149
116
 
150
117
  def __call__(
151
- self, times: np.ndarray, distances: np.ndarray, pulse_indices: np.ndarray
118
+ self,
119
+ times: np.ndarray,
120
+ distances: np.ndarray,
121
+ pulse_period: float = 0.0,
122
+ pulse_index: np.ndarray | None = None,
152
123
  ) -> np.ndarray:
153
124
  out = np.empty_like(times)
154
125
  interpolate(
155
126
  x=self.time_edges,
156
127
  y=self.distance_edges,
157
- z=self.pulse_edges,
158
128
  values=self.values,
159
129
  xp=times,
160
130
  yp=distances,
161
- zp=pulse_indices,
131
+ xoffset=pulse_index,
132
+ deltax=pulse_period,
162
133
  fill_value=self.fill_value,
163
134
  out=out,
164
135
  )
@@ -9,7 +9,6 @@ class Interpolator:
9
9
  self,
10
10
  time_edges: np.ndarray,
11
11
  distance_edges: np.ndarray,
12
- pulse_edges: np.ndarray,
13
12
  values: np.ndarray,
14
13
  method: str = "linear",
15
14
  bounds_error: bool = False,
@@ -17,18 +16,16 @@ class Interpolator:
17
16
  **kwargs,
18
17
  ):
19
18
  """
20
- Interpolator for 3D regular grid data (SciPy implementation).
19
+ Interpolator for 2D regular grid data (SciPy implementation).
21
20
 
22
21
  Parameters
23
22
  ----------
24
23
  time_edges:
25
- 1D array of time edges.
24
+ 1D array of time edges (length N_time).
26
25
  distance_edges:
27
- 1D array of distance edges.
28
- pulse_edges:
29
- 1D array of pulse edges.
26
+ 1D array of distance edges (length N_dist).
30
27
  values:
31
- 3D array of values on the grid. The shape must be (nz, ny, nx).
28
+ 2D array of values on the grid. The shape must be (N_dist, N_time).
32
29
  method:
33
30
  Method of interpolation. Default is "linear".
34
31
  bounds_error:
@@ -42,11 +39,7 @@ class Interpolator:
42
39
  from scipy.interpolate import RegularGridInterpolator
43
40
 
44
41
  self._interp = RegularGridInterpolator(
45
- (
46
- pulse_edges,
47
- distance_edges,
48
- time_edges,
49
- ),
42
+ (distance_edges, time_edges),
50
43
  values,
51
44
  method=method,
52
45
  bounds_error=bounds_error,
@@ -55,6 +48,12 @@ class Interpolator:
55
48
  )
56
49
 
57
50
  def __call__(
58
- self, times: np.ndarray, distances: np.ndarray, pulse_indices: np.ndarray
51
+ self,
52
+ times: np.ndarray,
53
+ distances: np.ndarray,
54
+ pulse_period: float = 0.0,
55
+ pulse_index: np.ndarray | None = None,
59
56
  ) -> np.ndarray:
60
- return self._interp((pulse_indices, distances, times))
57
+ if pulse_index is not None:
58
+ times = times + (pulse_index * pulse_period)
59
+ return self._interp((distances, times))
@@ -4,12 +4,10 @@
4
4
  from dataclasses import dataclass
5
5
  from typing import NewType
6
6
 
7
+ import sciline as sl
7
8
  import scipp as sc
8
9
 
9
- Ltotal = NewType("Ltotal", sc.Variable)
10
- """
11
- Total length of the flight path from the source to the detector.
12
- """
10
+ from ..nexus.types import MonitorType, RunType
13
11
 
14
12
 
15
13
  @dataclass
@@ -107,30 +105,58 @@ When pulse-skipping, the offset of the first pulse in the stride. This is typica
107
105
  zero but can be a small integer < pulse_stride. If None, a guess is made.
108
106
  """
109
107
 
110
- RawData = NewType("RawData", sc.DataArray)
111
- """
112
- Raw detector data loaded from a NeXus file, e.g., NXdetector containing NXevent_data.
113
- """
114
108
 
115
- TofData = NewType("TofData", sc.DataArray)
116
- """
117
- Detector data with time-of-flight coordinate.
118
- """
109
class DetectorLtotal(sl.Scope[RunType, sc.Variable], sc.Variable):
    """
    Total flight-path length of neutrons from the source to the detector.

    Equal to ``L1 + L2``; one value per run type.
    """
119
111
 
120
- ResampledTofData = NewType("ResampledTofData", sc.DataArray)
121
- """
122
- Histogrammed detector data with time-of-flight coordinate, that has been resampled.
123
112
 
124
- Histogrammed data that has been converted to `tof` will typically have
125
- unsorted bin edges (due to either wrapping of `time_of_flight` or wavelength
126
- overlap between subframes).
127
- We thus resample the data to ensure that the bin edges are sorted.
128
- It makes use of the ``to_events`` helper which generates a number of events in each
129
- bin with a uniform distribution. The new events are then histogrammed using a set of
130
- sorted bin edges to yield a new histogram with sorted bin edges.
113
class MonitorLtotal(sl.Scope[RunType, MonitorType, sc.Variable], sc.Variable):
    """
    Total flight-path length of neutrons from the source to a monitor.

    One value per combination of run type and monitor type.
    """
131
115
 
132
- WARNING:
133
- This function is highly experimental, has limitations and should be used with
134
- caution. It is a workaround to the issue that rebinning data with unsorted bin
135
- edges is not supported in scipp.
136
- """
116
+
117
class DetectorTofData(sl.Scope[RunType, sc.DataArray], sc.DataArray):
    """
    Detector data that carries a time-of-flight coordinate, per run type.
    """
119
+
120
+
121
class MonitorTofData(sl.Scope[RunType, MonitorType, sc.DataArray], sc.DataArray):
    """
    Monitor data that carries a time-of-flight coordinate, per run and monitor type.
    """
123
+
124
+
125
class ResampledDetectorTofData(sl.Scope[RunType, sc.DataArray], sc.DataArray):
    """
    Histogrammed detector data with a time-of-flight coordinate, resampled onto
    sorted bin edges.

    Histogrammed data that has been converted to ``tof`` typically has unsorted
    bin edges, due to either wrapping of ``time_of_flight`` or wavelength overlap
    between subframes. The data is therefore resampled so that the bin edges are
    sorted. The resampling uses the ``to_events`` helper, which generates a number
    of events in each bin with a uniform distribution. These events are then
    histogrammed with a set of sorted bin edges, yielding a new histogram with
    sorted bin edges.

    WARNING:
    This resampling is highly experimental, has limitations and should be used
    with caution. It is a workaround for the fact that rebinning data with
    unsorted bin edges is not supported in scipp.
    """
142
+
143
+
144
class ResampledMonitorTofData(
    sl.Scope[RunType, MonitorType, sc.DataArray], sc.DataArray
):
    """
    Histogrammed monitor data with a time-of-flight coordinate, resampled onto
    sorted bin edges.

    Histogrammed data that has been converted to ``tof`` typically has unsorted
    bin edges, due to either wrapping of ``time_of_flight`` or wavelength overlap
    between subframes. The data is therefore resampled so that the bin edges are
    sorted. The resampling uses the ``to_events`` helper, which generates a number
    of events in each bin with a uniform distribution. These events are then
    histogrammed with a set of sorted bin edges, yielding a new histogram with
    sorted bin edges.

    WARNING:
    This resampling is highly experimental, has limitations and should be used
    with caution. It is a workaround for the fact that rebinning data with
    unsorted bin edges is not supported in scipp.
    """
@@ -0,0 +1,61 @@
1
+ # SPDX-License-Identifier: BSD-3-Clause
2
+ # Copyright (c) 2025 Scipp contributors (https://github.com/scipp)
3
+ from collections.abc import Iterable
4
+
5
+ import sciline
6
+
7
+ from ..nexus import GenericNeXusWorkflow
8
+ from ..utils import prune_type_vars
9
+ from .eto_to_tof import default_parameters, providers
10
+
11
+
12
def GenericTofWorkflow(
    *,
    run_types: Iterable[sciline.typing.Key] | None = None,
    monitor_types: Iterable[sciline.typing.Key] | None = None,
) -> sciline.Pipeline:
    """
    Generic workflow for computing the neutron time-of-flight for detector and monitor
    data.
    This workflow builds on the ``GenericNeXusWorkflow`` and computes time-of-flight
    from a lookup table that is created from the chopper settings, detector Ltotal and
    the neutron time-of-arrival.

    It is possible to limit which run types and monitor types
    are supported by the returned workflow.
    This is useful to reduce the size of the workflow and make it easier to inspect.
    Make sure to add *all* required run types and monitor types when using this feature.

    Attention
    ---------
    Filtering by run type and monitor type does not work with nested type vars.
    E.g., if you have a type like ``Outer[Inner[RunType]]``, this type and its
    provider will be removed.

    Parameters
    ----------
    run_types:
        List of run types to include in the workflow. If not provided, all run types
        are included.
        Must be a possible value of :class:`ess.reduce.nexus.types.RunType`.
        Any iterable is accepted (including one-shot iterables such as generators);
        it is materialized once before use.
    monitor_types:
        List of monitor types to include in the workflow. If not provided, all monitor
        types are included.
        Must be a possible value of :class:`ess.reduce.nexus.types.MonitorType`.
        Any iterable is accepted (including one-shot iterables such as generators);
        it is materialized once before use.

    Returns
    -------
    :
        The workflow.
    """
    # Both selections are consumed twice below: once inside GenericNeXusWorkflow
    # and once by the final prune_type_vars call. Materialize them so a generator
    # is not exhausted by the first pass. An exhausted iterable would look like an
    # empty selection to the second pass and prune *every* run/monitor type.
    if run_types is not None:
        run_types = tuple(run_types)
    if monitor_types is not None:
        monitor_types = tuple(monitor_types)

    wf = GenericNeXusWorkflow(run_types=run_types, monitor_types=monitor_types)

    # These providers are inserted after the NeXus workflow has already been
    # pruned, so they can reintroduce nodes for excluded types; prune again below.
    for provider in providers():
        wf.insert(provider)
    for key, value in default_parameters().items():
        wf[key] = value

    if run_types is not None or monitor_types is not None:
        prune_type_vars(wf, run_types=run_types, monitor_types=monitor_types)

    return wf
ess/reduce/utils.py ADDED
@@ -0,0 +1,36 @@
1
+ # SPDX-License-Identifier: BSD-3-Clause
2
+ # Copyright (c) 2025 Scipp contributors (https://github.com/scipp)
3
+
4
+ from collections.abc import Iterable
5
+ from typing import Any
6
+
7
+ import sciline
8
+
9
+ from .nexus.types import MonitorType, RunType
10
+
11
+
12
def prune_type_vars(
    workflow: sciline.Pipeline,
    *,
    run_types: Iterable[sciline.typing.Key] | None,
    monitor_types: Iterable[sciline.typing.Key] | None,
) -> None:
    """
    Drop graph nodes that depend on run or monitor types which are not kept.

    A node is removed when one of its direct type arguments (``node.__args__``)
    meets one of two conditions:

    - it is a constraint of ``RunType`` missing from ``run_types``, or
    - it is a constraint of ``MonitorType`` missing from ``monitor_types``.

    Only the top-level type arguments of each node are inspected. Nodes without
    ``__args__`` are never removed.

    Parameters
    ----------
    workflow:
        Pipeline whose underlying graph is modified in place.
    run_types:
        Run types to keep. ``None`` keeps all run types.
    monitor_types:
        Monitor types to keep. ``None`` keeps all monitor types.
    """
    unwanted = excluded_type_args(RunType, run_types) | excluded_type_args(
        MonitorType, monitor_types
    )

    graph = workflow.underlying_graph
    # Collect first, then remove, so the graph is not mutated while iterating it.
    doomed = []
    for node in graph:
        node_args = set(getattr(node, "__args__", set()))
        if not unwanted.isdisjoint(node_args):
            doomed.append(node)
    graph.remove_nodes_from(doomed)
29
+
30
+
31
def excluded_type_args(
    type_var: Any, keep: Iterable[sciline.typing.Key] | None
) -> set[sciline.typing.Key]:
    """
    Return the constraints of ``type_var`` that are not listed in ``keep``.

    Parameters
    ----------
    type_var:
        A constrained type variable; its ``__constraints__`` are the candidates.
    keep:
        Constraints to retain. ``None`` means "keep everything", so nothing
        is excluded.

    Returns
    -------
    :
        A new set of the constraints to exclude.
    """
    if keep is None:
        return set()
    candidates = set(type_var.__constraints__)
    candidates.difference_update(keep)
    return candidates
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: essreduce
3
- Version: 25.4.1
3
+ Version: 25.5.0
4
4
  Summary: Common data reduction tools for the ESS facility
5
5
  Author: Scipp contributors
6
6
  License: BSD 3-Clause License
@@ -61,7 +61,7 @@ Requires-Dist: numba; extra == "test"
61
61
  Requires-Dist: pooch; extra == "test"
62
62
  Requires-Dist: pytest; extra == "test"
63
63
  Requires-Dist: scipy>=1.7.0; extra == "test"
64
- Requires-Dist: tof>=25.01.2; extra == "test"
64
+ Requires-Dist: tof>=25.05.0; extra == "test"
65
65
  Dynamic: license-file
66
66
 
67
67
  [![Contributor Covenant](https://img.shields.io/badge/Contributor%20Covenant-2.1-4baaaa.svg)](CODE_OF_CONDUCT.md)
@@ -6,6 +6,7 @@ ess/reduce/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
6
  ess/reduce/streaming.py,sha256=TBttQV5WdSpUKh38J0pdv53seMWtUFswxd6-ltaZb_M,17403
7
7
  ess/reduce/ui.py,sha256=zmorAbDwX1cU3ygDT--OP58o0qU7OBcmJz03jPeYSLA,10884
8
8
  ess/reduce/uncertainty.py,sha256=LR4O6ApB6Z-W9gC_XW0ajupl8yFG-du0eee1AX_R-gk,6990
9
+ ess/reduce/utils.py,sha256=RBAfJRNil6JjVF-jPaxeL0ssEEfPBBQEZ3ObEorpDLo,1132
9
10
  ess/reduce/workflow.py,sha256=sL34T_2Cjl_8iFlegujxI9VyOUwo6erVC8pOXnfWgYw,3060
10
11
  ess/reduce/live/__init__.py,sha256=jPQVhihRVNtEDrE20PoKkclKV2aBF1lS7cCHootgFgI,204
11
12
  ess/reduce/live/raw.py,sha256=66qV0G2rP8gK5tXuk-syTlDLE2jT3ehfmSnET7Xzfd0,24392
@@ -16,16 +17,17 @@ ess/reduce/nexus/_nexus_loader.py,sha256=5N48AMJx1AaFZb6WZPPbVKUlXyFMVVtZrn7Bae5
16
17
  ess/reduce/nexus/json_generator.py,sha256=ME2Xn8L7Oi3uHJk9ZZdCRQTRX-OV_wh9-DJn07Alplk,2529
17
18
  ess/reduce/nexus/json_nexus.py,sha256=QrVc0p424nZ5dHX9gebAJppTw6lGZq9404P_OFl1giA,10282
18
19
  ess/reduce/nexus/types.py,sha256=vTQD4oQ5JKBHAYy9LWFICSo-dhVi3wX5IinMgjRDtF8,9806
19
- ess/reduce/nexus/workflow.py,sha256=EiD6-58eGwoN5fbo47UTZy_oYFitCbwlIH-xqDOSp4c,24326
20
+ ess/reduce/nexus/workflow.py,sha256=zrBQGNLUxmvqXewe9uNUg9aP43_glfFD6nh5VGAtBK4,23456
20
21
  ess/reduce/scripts/grow_nexus.py,sha256=hET3h06M0xlJd62E3palNLFvJMyNax2kK4XyJcOhl-I,3387
21
- ess/reduce/time_of_flight/__init__.py,sha256=TSHfyoROwFhM2k3jHzamw3zeb0OQOaiuvgCgDEPEQ_g,1097
22
- ess/reduce/time_of_flight/eto_to_tof.py,sha256=Nq2gx7aejoZ_ExLTr9I6KZMqDxCKAx1PpGHslpNXkKU,25271
23
- ess/reduce/time_of_flight/fakes.py,sha256=REyHkJsSSq2_l5UOtpsv2aKkhCuro_i3KpVsxxITbW0,4470
24
- ess/reduce/time_of_flight/interpolator_numba.py,sha256=AgB2R8iw-IOb3YXLWTQVBflhWq5qgb7aqfvDExwLRW8,4682
25
- ess/reduce/time_of_flight/interpolator_scipy.py,sha256=sRJj2ncBiUMv6g9h-MJzI9xyY0Ir0degpAv6FIeSMBw,1834
22
+ ess/reduce/time_of_flight/__init__.py,sha256=v86c3zNTMMqZoR9eHaK0Q-JnzsbOI6XsBGI3mgy2CiU,1469
23
+ ess/reduce/time_of_flight/eto_to_tof.py,sha256=ckXoSrltXdciYwipyUkF-DVtbsz2_XSLZvX2qJ_d8Bs,28238
24
+ ess/reduce/time_of_flight/fakes.py,sha256=0gtbSX3ZQilaM4ZP5dMr3fqbnhpyoVsZX2YEb8GgREE,4489
25
+ ess/reduce/time_of_flight/interpolator_numba.py,sha256=wh2YS3j2rOu30v1Ok3xNHcwS7t8eEtZyZvbfXOCtgrQ,3835
26
+ ess/reduce/time_of_flight/interpolator_scipy.py,sha256=_InoAPuMm2qhJKZQBAHOGRFqtvvuQ8TStoN7j_YgS4M,1853
26
27
  ess/reduce/time_of_flight/simulation.py,sha256=cIF_nWkLQlcWUCW2_wvWBU2ocg_8CSfOnfkoqdLdUgs,2923
27
28
  ess/reduce/time_of_flight/to_events.py,sha256=w9mHpnWd3vwN2ouob-GK_1NPrTjCaOzPuC2QuEey-m0,4342
28
- ess/reduce/time_of_flight/types.py,sha256=Iv1XGLbrZ9bD4CPAVhsIPkAaB46YC7l7yf5XweljLqk,5047
29
+ ess/reduce/time_of_flight/types.py,sha256=xhziZQaCB4XAxvVopHHp2DZSBj7PUt-xgPzEDpni05g,6321
30
+ ess/reduce/time_of_flight/workflow.py,sha256=-g9IyAz7sNrgL-5RZLUTlfjTb2YFej1Xig6GiC7c1bI,2156
29
31
  ess/reduce/widgets/__init__.py,sha256=SoSHBv8Dc3QXV9HUvPhjSYWMwKTGYZLpsWwsShIO97Q,5325
30
32
  ess/reduce/widgets/_base.py,sha256=_wN3FOlXgx_u0c-A_3yyoIH-SdUvDENGgquh9S-h5GI,4852
31
33
  ess/reduce/widgets/_binedges_widget.py,sha256=ZCQsGjYHnJr9GFUn7NjoZc1CdsnAzm_fMzyF-fTKKVY,2785
@@ -38,9 +40,9 @@ ess/reduce/widgets/_spinner.py,sha256=2VY4Fhfa7HMXox2O7UbofcdKsYG-AJGrsgGJB85nDX
38
40
  ess/reduce/widgets/_string_widget.py,sha256=iPAdfANyXHf-nkfhgkyH6gQDklia0LebLTmwi3m-iYQ,1482
39
41
  ess/reduce/widgets/_switchable_widget.py,sha256=fjKz99SKLhIF1BLgGVBSKKn3Lu_jYBwDYGeAjbJY3Q8,2390
40
42
  ess/reduce/widgets/_vector_widget.py,sha256=aTaBqCFHZQhrIoX6-sSqFWCPePEW8HQt5kUio8jP1t8,1203
41
- essreduce-25.4.1.dist-info/licenses/LICENSE,sha256=nVEiume4Qj6jMYfSRjHTM2jtJ4FGu0g-5Sdh7osfEYw,1553
42
- essreduce-25.4.1.dist-info/METADATA,sha256=_E84IwG_gnTsMoorflvf6T4K5oJB7IjpsUUTB4bhVh8,3768
43
- essreduce-25.4.1.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
44
- essreduce-25.4.1.dist-info/entry_points.txt,sha256=PMZOIYzCifHMTe4pK3HbhxUwxjFaZizYlLD0td4Isb0,66
45
- essreduce-25.4.1.dist-info/top_level.txt,sha256=0JxTCgMKPLKtp14wb1-RKisQPQWX7i96innZNvHBr-s,4
46
- essreduce-25.4.1.dist-info/RECORD,,
43
+ essreduce-25.5.0.dist-info/licenses/LICENSE,sha256=nVEiume4Qj6jMYfSRjHTM2jtJ4FGu0g-5Sdh7osfEYw,1553
44
+ essreduce-25.5.0.dist-info/METADATA,sha256=yfoZMb19ayIQyCRk5_WPEuvrWGAYApWIs4Wr-69nwO8,3768
45
+ essreduce-25.5.0.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
46
+ essreduce-25.5.0.dist-info/entry_points.txt,sha256=PMZOIYzCifHMTe4pK3HbhxUwxjFaZizYlLD0td4Isb0,66
47
+ essreduce-25.5.0.dist-info/top_level.txt,sha256=0JxTCgMKPLKtp14wb1-RKisQPQWX7i96innZNvHBr-s,4
48
+ essreduce-25.5.0.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (78.1.0)
2
+ Generator: setuptools (80.3.1)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5