essreduce 25.1.1__py3-none-any.whl → 25.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -17,45 +17,21 @@ from scippneutron._utils import elem_unit
  from .to_events import to_events
  from .types import (
  DistanceResolution,
- FastestNeutron,
- FrameFoldedTimeOfArrival,
- FramePeriod,
  LookupTableRelativeErrorThreshold,
  Ltotal,
  LtotalRange,
- MaskedTimeOfFlightLookupTable,
- PivotTimeAtDetector,
  PulsePeriod,
  PulseStride,
  PulseStrideOffset,
  RawData,
  ResampledTofData,
  SimulationResults,
- TimeOfArrivalMinusPivotTimeModuloPeriod,
- TimeOfArrivalResolution,
  TimeOfFlightLookupTable,
+ TimeResolution,
  TofData,
- UnwrappedTimeOfArrival,
- UnwrappedTimeOfArrivalMinusPivotTime,
  )


- def frame_period(pulse_period: PulsePeriod, pulse_stride: PulseStride) -> FramePeriod:
- """
- Return the period of a frame, which is defined by the pulse period times the pulse
- stride.
-
- Parameters
- ----------
- pulse_period:
- Period of the source pulses, i.e., time between consecutive pulse starts.
- pulse_stride:
- Stride of used pulses. Usually 1, but may be a small integer when
- pulse-skipping.
- """
- return FramePeriod(pulse_period * pulse_stride)
-
-
  def extract_ltotal(da: RawData) -> Ltotal:
  """
  Extract the total length of the flight path from the source to the detector from the
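The FramePeriod type and its frame_period provider are removed from this module; the new compute_tof_lookup_table further down computes the frame period inline as pulse_period * pulse_stride. A minimal sketch of that relation, assuming a 14 Hz source (the PulsePeriod default is not part of this diff):

    import scipp as sc

    pulse_period = sc.scalar(1.0 / 14.0, unit='s')  # assumed 14 Hz source
    pulse_stride = 2                                # skip every other pulse
    frame_period = pulse_period * pulse_stride      # ~142.9 ms between used pulses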
@@ -70,80 +46,91 @@ def extract_ltotal(da: RawData) -> Ltotal:
  return Ltotal(da.coords["Ltotal"])


- def compute_tof_lookup_table(
+ def _mask_large_uncertainty(table: sc.DataArray, error_threshold: float):
+ """
+ Mask regions with large uncertainty with NaNs.
+ The values are modified in place in the input table.
+
+ Parameters
+ ----------
+ table:
+ Lookup table with time-of-flight as a function of distance and time-of-arrival.
+ error_threshold:
+ Threshold for the relative standard deviation (coefficient of variation) of the
+ projected time-of-flight above which values are masked.
+ """
+ # Finally, mask regions with large uncertainty with NaNs.
+ relative_error = sc.stddevs(table.data) / sc.values(table.data)
+ mask = relative_error > sc.scalar(error_threshold)
+ # Use numpy for indexing as table is 2D
+ table.values[mask.values] = np.nan
+
+
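The masking criterion in _mask_large_uncertainty above is the coefficient of variation (relative standard deviation) of the projected time-of-flight. A minimal sketch of the same operation on a toy one-dimensional table, with values and threshold chosen purely for illustration:

    import numpy as np
    import scipp as sc

    table = sc.DataArray(
        sc.array(dims=['x'], values=[10.0, 10.0], variances=[0.25, 4.0], unit='us')
    )
    relative_error = sc.stddevs(table.data) / sc.values(table.data)  # [0.05, 0.2]
    # With a threshold of 0.1, only the second value is replaced by NaN.
    table.values[(relative_error > sc.scalar(0.1)).values] = np.nan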
+ def _compute_mean_tof_in_distance_range(
  simulation: SimulationResults,
- ltotal_range: LtotalRange,
- distance_resolution: DistanceResolution,
- toa_resolution: TimeOfArrivalResolution,
- ) -> TimeOfFlightLookupTable:
+ distance_bins: sc.Variable,
+ time_bins: sc.Variable,
+ distance_unit: str,
+ time_unit: str,
+ frame_period: sc.Variable,
+ time_bins_half_width: sc.Variable,
+ ) -> sc.DataArray:
  """
- Compute a lookup table for time-of-flight as a function of distance and
- time-of-arrival.
+ Compute the mean time-of-flight inside event_time_offset bins for a given range of
+ distances.

  Parameters
  ----------
  simulation:
  Results of a time-of-flight simulation used to create a lookup table.
- The results should be a flat table with columns for time-of-arrival, speed,
- wavelength, and weight.
- ltotal_range:
- Range of total flight path lengths from the source to the detector.
- distance_resolution:
- Resolution of the distance axis in the lookup table.
- toa_resolution:
- Resolution of the time-of-arrival axis in the lookup table.
+ distance_bins:
+ Bin edges for the distance axis in the lookup table.
+ time_bins:
+ Bin edges for the event_time_offset axis in the lookup table.
+ distance_unit:
+ Unit of the distance axis.
+ time_unit:
+ Unit of the event_time_offset axis.
+ frame_period:
+ Period of a frame, i.e., the pulse period multiplied by the pulse stride.
+ time_bins_half_width:
+ Half the width of the time bins.
  """
- distance_unit = "m"
- res = distance_resolution.to(unit=distance_unit)
  simulation_distance = simulation.distance.to(unit=distance_unit)
-
- min_dist, max_dist = (
- x.to(unit=distance_unit) - simulation_distance for x in ltotal_range
- )
- # We need to bin the data below, to compute the weighted mean of the wavelength.
- # This results in data with bin edges.
- # However, the 2d interpolator expects bin centers.
- # We want to give the 2d interpolator a table that covers the requested range,
- # hence we need to extend the range by at least half a resolution in each direction.
- # Then, we make the choice that the resolution in distance is the quantity that
- # should be preserved. Because the difference between min and max distance is
- # not necessarily an integer multiple of the resolution, we need to add a pad to
- # ensure that the last bin is not cut off. We want the upper edge to be higher than
- # the maximum distance, hence we pad with an additional 1.5 x resolution.
- pad = 2.0 * res
- dist_edges = sc.array(
- dims=["distance"],
- values=np.arange((min_dist - pad).value, (max_dist + pad).value, res.value),
- unit=distance_unit,
- )
- distances = sc.midpoints(dist_edges)
-
- time_unit = simulation.time_of_arrival.unit
+ distances = sc.midpoints(distance_bins)
+ # Compute arrival and flight times for all neutrons
  toas = simulation.time_of_arrival + (distances / simulation.speed).to(
  unit=time_unit, copy=False
  )
-
- # Compute time-of-flight for all neutrons
- wavs = sc.broadcast(simulation.wavelength.to(unit="m"), sizes=toas.sizes).flatten(
- to="event"
- )
- dist = sc.broadcast(distances + simulation_distance, sizes=toas.sizes).flatten(
- to="event"
- )
- tofs = dist * sc.constants.m_n
- tofs *= wavs
- tofs /= sc.constants.h
+ dist = distances + simulation_distance
+ tofs = dist * (sc.constants.m_n / sc.constants.h) * simulation.wavelength

  data = sc.DataArray(
- data=sc.broadcast(simulation.weight, sizes=toas.sizes).flatten(to="event"),
+ data=sc.broadcast(simulation.weight, sizes=toas.sizes),
  coords={
- "toa": toas.flatten(to="event"),
+ "toa": toas,
  "tof": tofs.to(unit=time_unit, copy=False),
  "distance": dist,
  },
+ ).flatten(to="event")
+
+ # Add the event_time_offset coordinate to the data. We first operate on the
+ # frame period. The table will later be folded to the pulse period.
+ data.coords['event_time_offset'] = data.coords['toa'] % frame_period
+
+ # Because we staggered the mesh by half a bin width, we want the values above
+ # the last bin edge to wrap around to the first bin.
+ # Technically, those values should end up between -0.5*bin_width and 0, but
+ # a simple modulo also works here because even if they end up between 0 and
+ # 0.5*bin_width, we are (below) computing the mean between -0.5*bin_width and
+ # 0.5*bin_width and it yields the same result.
+ # data.coords['event_time_offset'] %= pulse_period - time_bins_half_width
+ data.coords['event_time_offset'] %= frame_period - time_bins_half_width
+
+ binned = data.bin(
+ distance=distance_bins + simulation_distance, event_time_offset=time_bins
  )

- binned = data.bin(distance=dist_edges + simulation_distance, toa=toa_resolution)
  # Weighted mean of tof inside each bin
  mean_tof = (
  binned.bins.data * binned.bins.coords["tof"]
@@ -154,188 +141,316 @@ def compute_tof_lookup_table(
  ).bins.sum() / binned.bins.sum()

  mean_tof.variances = variance.values
+ return mean_tof

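The wavelength-to-time-of-flight conversion above follows from the de Broglie relation: the neutron speed is h / (m_n * wavelength), so tof = Ltotal * m_n * wavelength / h. A quick numerical check with scipp, using an illustrative wavelength and flight path:

    import scipp as sc

    wav = sc.scalar(4.0, unit='angstrom').to(unit='m')
    ltotal = sc.scalar(30.0, unit='m')
    tof = (ltotal * sc.constants.m_n / sc.constants.h * wav).to(unit='ms')
    # ~30 ms for a 4 angstrom neutron over a 30 m flight path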
- # Convert coordinates to midpoints
- mean_tof.coords["toa"] = sc.midpoints(mean_tof.coords["toa"])
- mean_tof.coords["distance"] = sc.midpoints(mean_tof.coords["distance"])

- return TimeOfFlightLookupTable(mean_tof)
-
-
- def masked_tof_lookup_table(
- tof_lookup: TimeOfFlightLookupTable,
- error_threshold: LookupTableRelativeErrorThreshold,
- ) -> MaskedTimeOfFlightLookupTable:
+ def _fold_table_to_pulse_period(
+ table: sc.DataArray, pulse_period: sc.Variable, pulse_stride: int
+ ) -> sc.DataArray:
  """
- Mask regions of the lookup table where the variance of the projected time-of-flight
- is larger than a given threshold.
+ Fold the lookup table to the pulse period. We make sure the left and right edges of
+ the table wrap around the ``event_time_offset`` dimension.

  Parameters
  ----------
- tof_lookup:
- Lookup table giving time-of-flight as a function of distance and
- time-of-arrival.
- variance_threshold:
- Threshold for the variance of the projected time-of-flight above which regions
- are masked.
+ table:
+ Lookup table with time-of-flight as a function of distance and time-of-arrival.
+ pulse_period:
+ Period of the source pulses, i.e., time between consecutive pulse starts.
+ pulse_stride:
+ Stride of used pulses. Usually 1, but may be a small integer when
+ pulse-skipping.
  """
- relative_error = sc.stddevs(tof_lookup.data) / sc.values(tof_lookup.data)
- mask = relative_error > sc.scalar(error_threshold)
- out = tof_lookup.copy()
- # Use numpy for indexing as table is 2D
- out.values[mask.values] = np.nan
- return MaskedTimeOfFlightLookupTable(out)
-
+ size = table.sizes['event_time_offset']
+ if (size % pulse_stride) != 0:
+ raise ValueError(
+ "TimeOfFlightLookupTable: the number of time bins must be a multiple of "
+ f"the pulse stride, but got {size} time bins and a pulse stride of "
+ f"{pulse_stride}."
+ )

- def find_fastest_neutron(simulation: SimulationResults) -> FastestNeutron:
- """
- Find the fastest neutron in the simulation results.
- """
- ind = np.argmax(simulation.speed.values)
- return FastestNeutron(
- time_of_arrival=simulation.time_of_arrival[ind],
- speed=simulation.speed[ind],
- distance=simulation.distance,
+ size = size // pulse_stride
+ out = sc.concat([table, table['event_time_offset', 0]], dim='event_time_offset')
+ out = sc.concat(
+ [
+ out['event_time_offset', (i * size) : (i + 1) * size + 1]
+ for i in range(pulse_stride)
+ ],
+ dim='pulse',
  )
-
-
- def pivot_time_at_detector(
- fastest_neutron: FastestNeutron, ltotal: Ltotal
- ) -> PivotTimeAtDetector:
- """
- Compute the pivot time at the detector, i.e., the time of the start of the frame at
- the detector.
- The assumption here is that the fastest neutron in the simulation results is the one
- that arrives at the detector first.
- One could have an edge case where a slightly slower neutron which is born earlier
- could arrive at the detector first, but this edge case is most probably uncommon,
- and the difference in arrival times is likely to be small.
-
- Parameters
- ----------
- fastest_neutron:
- Properties of the fastest neutron in the simulation results.
- ltotal:
- Total length of the flight path from the source to the detector.
- """
- dist = ltotal - fastest_neutron.distance.to(unit=ltotal.unit)
- toa = fastest_neutron.time_of_arrival + (dist / fastest_neutron.speed).to(
- unit=fastest_neutron.time_of_arrival.unit, copy=False
+ return out.assign_coords(
+ event_time_offset=sc.concat(
+ [
+ table.coords['event_time_offset']['event_time_offset', :size],
+ pulse_period,
+ ],
+ 'event_time_offset',
+ )
  )
- return PivotTimeAtDetector(toa)


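A toy illustration of the folding performed by _fold_table_to_pulse_period above, assuming six event_time_offset bins and pulse_stride=2: a copy of the first column is appended so that each pulse row ends with a right edge that wraps around to the start of the table:

    import scipp as sc

    table = sc.DataArray(sc.arange('event_time_offset', 6.0))
    stride, size = 2, 6 // 2
    out = sc.concat([table, table['event_time_offset', 0]], dim='event_time_offset')
    folded = sc.concat(
        [out['event_time_offset', i * size : (i + 1) * size + 1] for i in range(stride)],
        dim='pulse',
    )
    # One row per pulse: [0, 1, 2, 3] and [3, 4, 5, 0]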
- def unwrapped_time_of_arrival(
- da: RawData, offset: PulseStrideOffset, pulse_period: PulsePeriod
- ) -> UnwrappedTimeOfArrival:
+ def compute_tof_lookup_table(
+ simulation: SimulationResults,
+ ltotal_range: LtotalRange,
+ distance_resolution: DistanceResolution,
+ time_resolution: TimeResolution,
+ pulse_period: PulsePeriod,
+ pulse_stride: PulseStride,
+ error_threshold: LookupTableRelativeErrorThreshold,
+ ) -> TimeOfFlightLookupTable:
  """
- Compute the unwrapped time of arrival of the neutron at the detector.
- For event data, this is essentially ``event_time_offset + event_time_zero``.
+ Compute a lookup table for time-of-flight as a function of distance and
+ time-of-arrival.

  Parameters
  ----------
- da:
- Raw detector data loaded from a NeXus file, e.g., NXdetector containing
- NXevent_data.
- offset:
- Integer offset of the first pulse in the stride (typically zero unless we are
- using pulse-skipping and the events do not begin with the first pulse in the
- stride).
+ simulation:
+ Results of a time-of-flight simulation used to create a lookup table.
+ The results should be a flat table with columns for time-of-arrival, speed,
+ wavelength, and weight.
+ ltotal_range:
+ Range of total flight path lengths from the source to the detector.
+ distance_resolution:
+ Resolution of the distance axis in the lookup table.
+ time_resolution:
+ Resolution (bin width) of the time-of-arrival axis in the lookup table. Should
+ be a scalar with a time unit.
  pulse_period:
  Period of the source pulses, i.e., time between consecutive pulse starts.
+ pulse_stride:
+ Stride of used pulses. Usually 1, but may be a small integer when
+ pulse-skipping.
+ error_threshold:
+ Threshold for the relative standard deviation (coefficient of variation) of the
+ projected time-of-flight above which values are masked.
  """
- if da.bins is None:
- # 'time_of_flight' is the canonical name in NXmonitor, but in some files, it
- # may be called 'tof'.
- key = next(iter(set(da.coords.keys()) & {"time_of_flight", "tof"}))
- toa = da.coords[key]
- else:
- # To unwrap the time of arrival, we want to add the event_time_zero to the
- # event_time_offset. However, we do not really care about the exact datetimes,
- # we just want to know the offsets with respect to the start of the run.
- # Hence we use the smallest event_time_zero as the time origin.
- time_zero = da.coords["event_time_zero"] - da.coords["event_time_zero"].min()
- coord = da.bins.coords["event_time_offset"]
- unit = elem_unit(coord)
- toa = (
- coord
- + time_zero.to(dtype=float, unit=unit, copy=False)
- - (offset * pulse_period).to(unit=unit, copy=False)
+ distance_unit = "m"
+ time_unit = simulation.time_of_arrival.unit
+ res = distance_resolution.to(unit=distance_unit)
+ pulse_period = pulse_period.to(unit=time_unit)
+ frame_period = pulse_period * pulse_stride
+
+ min_dist, max_dist = (
+ x.to(unit=distance_unit) - simulation.distance.to(unit=distance_unit)
+ for x in ltotal_range
+ )
+ # We need to bin the data below, to compute the weighted mean of the wavelength.
+ # This results in data with bin edges.
+ # However, the 2d interpolator expects bin centers.
+ # We want to give the 2d interpolator a table that covers the requested range,
+ # hence we need to extend the range by at least half a resolution in each direction.
+ # Then, we make the choice that the resolution in distance is the quantity that
+ # should be preserved. Because the difference between min and max distance is
+ # not necessarily an integer multiple of the resolution, we need to add a pad to
+ # ensure that the last bin is not cut off. We want the upper edge to be higher than
+ # the maximum distance, hence we pad with an additional 1.5 x resolution.
+ pad = 2.0 * res
+ distance_bins = sc.arange('distance', min_dist - pad, max_dist + pad, res)
+
+ # Create some time bins for event_time_offset.
+ # We want our final table to strictly cover the range [0, frame_period].
+ # However, binning the data associates mean values inside the bins to the bin
+ # centers. Instead, we stagger the mesh by half a bin width so we are computing
+ # values for the final mesh edges (the bilinear interpolation needs values on the
+ # edges/corners).
+ nbins = int(frame_period / time_resolution.to(unit=time_unit)) + 1
+ time_bins = sc.linspace(
+ 'event_time_offset', 0.0, frame_period.value, nbins + 1, unit=pulse_period.unit
+ )
+ time_bins_half_width = 0.5 * (time_bins[1] - time_bins[0])
+ time_bins -= time_bins_half_width
+
+ # To avoid a too large RAM usage, we compute the table in chunks, and piece them
+ # together at the end.
+ ndist = len(distance_bins) - 1
+ max_size = 2e7
+ total_size = ndist * len(simulation.time_of_arrival)
+ nchunks = total_size / max_size
+ chunk_size = int(ndist / nchunks) + 1
+ pieces = []
+ for i in range(int(nchunks) + 1):
+ dist_edges = distance_bins[i * chunk_size : (i + 1) * chunk_size + 1]
+
+ pieces.append(
+ _compute_mean_tof_in_distance_range(
+ simulation=simulation,
+ distance_bins=dist_edges,
+ time_bins=time_bins,
+ distance_unit=distance_unit,
+ time_unit=time_unit,
+ frame_period=frame_period,
+ time_bins_half_width=time_bins_half_width,
+ )
  )
- return UnwrappedTimeOfArrival(toa)

+ table = sc.concat(pieces, 'distance')
+ table.coords["distance"] = sc.midpoints(table.coords["distance"])
+ table.coords["event_time_offset"] = sc.midpoints(table.coords["event_time_offset"])

- def unwrapped_time_of_arrival_minus_frame_pivot_time(
- toa: UnwrappedTimeOfArrival, pivot_time: PivotTimeAtDetector
- ) -> UnwrappedTimeOfArrivalMinusPivotTime:
- """
- Compute the time of arrival of the neutron at the detector, unwrapped at the pulse
- period, minus the start time of the frame.
- We subtract the start time of the frame so that we can use a modulo operation to
- wrap the time of arrival at the frame period in the case of pulse-skipping.
+ table = _fold_table_to_pulse_period(
+ table=table, pulse_period=pulse_period, pulse_stride=pulse_stride
+ )

- Parameters
- ----------
- toa:
- Time of arrival of the neutron at the detector, unwrapped at the pulse period.
- pivot_time:
- Pivot time at the detector, i.e., the time of the start of the frame at the
- detector.
- """
- # Order of operation to preserve dimension order
- return UnwrappedTimeOfArrivalMinusPivotTime(
- -pivot_time.to(unit=elem_unit(toa), copy=False) + toa
+ # In-place masking for better performance
+ _mask_large_uncertainty(table, error_threshold)
+
+ return TimeOfFlightLookupTable(
+ table.transpose(('pulse', 'distance', 'event_time_offset'))
  )


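The staggered time-bin mesh built in compute_tof_lookup_table above can be illustrated with round numbers, assuming a 10 us frame period and a 2.5 us time resolution: after shifting the edges by half a bin width, the bin centres land exactly on the grid points the interpolator later uses:

    import scipp as sc

    frame_period = sc.scalar(10.0, unit='us')
    nbins = int(frame_period / sc.scalar(2.5, unit='us')) + 1   # 5 bins
    edges = sc.linspace('event_time_offset', 0.0, frame_period.value, nbins + 1, unit='us')
    half = 0.5 * (edges[1] - edges[0])
    edges -= half                      # [-1, 1, 3, 5, 7, 9] us
    centers = sc.midpoints(edges)      # [0, 2, 4, 6, 8] us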
- def time_of_arrival_minus_pivot_time_modulo_period(
- toa_minus_pivot_time: UnwrappedTimeOfArrivalMinusPivotTime,
- frame_period: FramePeriod,
- ) -> TimeOfArrivalMinusPivotTimeModuloPeriod:
- """
- Compute the time of arrival of the neutron at the detector, unwrapped at the pulse
- period, minus the start time of the frame, modulo the frame period.
+ def _make_tof_interpolator(
+ lookup: sc.DataArray, distance_unit: str, time_unit: str
+ ) -> Callable:
+ from scipy.interpolate import RegularGridInterpolator

- Parameters
- ----------
- toa_minus_pivot_time:
- Time of arrival of the neutron at the detector, unwrapped at the pulse period,
- minus the start time of the frame.
- frame_period:
- Period of the frame, i.e., time between the start of two consecutive frames.
- """
- return TimeOfArrivalMinusPivotTimeModuloPeriod(
- toa_minus_pivot_time
- % frame_period.to(unit=elem_unit(toa_minus_pivot_time), copy=False)
+ # TODO: to make use of multi-threading, we could write our own interpolator.
+ # This should be simple enough as we are making the bins linspace, so computing
+ # bin indices is fast.
+
+ # In the pulse dimension, it could be that for a given event_time_offset and
+ # distance, a tof value is finite in one pulse and NaN in the other.
+ # When using the bilinear interpolation, even if the value of the requested point is
+ # exactly 0 or 1 (in the case of pulse_stride=2), the interpolator will still
+ # use all 4 corners surrounding the point. This means that if one of the corners
+ # is NaN, the result will be NaN.
+ # Here, we use a trick where we duplicate the lookup values in the 'pulse' dimension
+ # so that the interpolator has values on bin edges for that dimension.
+ # The interpolator raises an error if axes coordinates are not strictly monotonic,
+ # so we cannot use e.g. [-0.5, 0.5, 0.5, 1.5] in the case of pulse_stride=2.
+ # Instead we use [-0.25, 0.25, 0.75, 1.25].
+ base_grid = np.arange(float(lookup.sizes["pulse"]))
+ return RegularGridInterpolator(
+ (
+ np.sort(np.concatenate([base_grid - 0.25, base_grid + 0.25])),
+ lookup.coords["distance"].to(unit=distance_unit, copy=False).values,
+ lookup.coords["event_time_offset"].to(unit=time_unit, copy=False).values,
+ ),
+ np.repeat(lookup.data.to(unit=time_unit, copy=False).values, 2, axis=0),
+ method="linear",
+ bounds_error=False,
+ fill_value=np.nan,
  )


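The pulse-axis trick in _make_tof_interpolator above can be reproduced with a small standalone scipy example (two pulses, two distances, values chosen for illustration): duplicating each pulse row at -0.25 and +0.25 around its index means a query at an integer pulse index never mixes in the other pulse's NaNs:

    import numpy as np
    from scipy.interpolate import RegularGridInterpolator

    values = np.array([[1.0, 1.0], [np.nan, np.nan]])   # (pulse, distance)
    interp = RegularGridInterpolator(
        (np.array([-0.25, 0.25, 0.75, 1.25]), np.array([0.0, 1.0])),
        np.repeat(values, 2, axis=0),
        method='linear',
        bounds_error=False,
        fill_value=np.nan,
    )
    result = interp([(0.0, 0.5), (1.0, 0.5)])   # -> [1.0, nan]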
- def time_of_arrival_folded_by_frame(
- toa: TimeOfArrivalMinusPivotTimeModuloPeriod,
- pivot_time: PivotTimeAtDetector,
- ) -> FrameFoldedTimeOfArrival:
- """
- The time of arrival of the neutron at the detector, folded by the frame period.
+ def _time_of_flight_data_histogram(
+ da: sc.DataArray,
+ lookup: sc.DataArray,
+ ltotal: sc.Variable,
+ pulse_period: sc.Variable,
+ ) -> sc.DataArray:
+ # In NeXus, 'time_of_flight' is the canonical name in NXmonitor, but in some files,
+ # it may be called 'tof'.
+ key = next(iter(set(da.coords.keys()) & {"time_of_flight", "tof"}))
+ eto_unit = da.coords[key].unit
+ pulse_period = pulse_period.to(unit=eto_unit)
+
+ # In histogram mode, because there is a wrap around at the end of the pulse, we
+ # need to insert a bin edge at that exact location to avoid having the last bin
+ # with one finite left edge and a NaN right edge (it becomes NaN as it would be
+ # outside the range of the lookup table).
+ new_bins = sc.sort(
+ sc.concat(
+ [da.coords[key], sc.scalar(0.0, unit=eto_unit), pulse_period], dim=key
+ ),
+ key=key,
+ )
+ rebinned = da.rebin({key: new_bins})
+ etos = rebinned.coords[key]
+
+ # In histogram mode, the lookup table cannot have a pulse dimension because we
+ # cannot know in the histogrammed data which pulse the events belong to.
+ # So we merge the pulse dimension in the lookup table. A quick way to do this
+ # is to take the mean of the data along the pulse dimension (there should
+ # only be regions that are NaN in one pulse and finite in the other).
+ merged = lookup.data.nanmean('pulse')
+ dim = merged.dims[0]
+ lookup = sc.DataArray(
+ data=merged.fold(dim=dim, sizes={'pulse': 1, dim: merged.sizes[dim]}),
+ coords={
+ 'pulse': sc.arange('pulse', 1.0),
+ 'distance': lookup.coords['distance'],
+ 'event_time_offset': lookup.coords['event_time_offset'],
+ },
+ )
+ pulse_index = sc.zeros(sizes=etos.sizes)

- Parameters
- ----------
- toa:
- Time of arrival of the neutron at the detector, unwrapped at the pulse period,
- minus the start time of the frame, modulo the frame period.
- pivot_time:
- Pivot time at the detector, i.e., the time of the start of the frame at the
- detector.
- """
- return FrameFoldedTimeOfArrival(
- toa + pivot_time.to(unit=elem_unit(toa), copy=False)
+ # Create 2D interpolator
+ interp = _make_tof_interpolator(
+ lookup, distance_unit=ltotal.unit, time_unit=eto_unit
+ )
+
+ # Compute time-of-flight of the bin edges using the interpolator
+ tofs = sc.array(
+ dims=etos.dims,
+ values=interp((pulse_index.values, ltotal.values, etos.values)),
+ unit=eto_unit,
+ )
+
+ return rebinned.assign_coords(tof=tofs)
+
+
+ def _time_of_flight_data_events(
+ da: sc.DataArray,
+ lookup: sc.DataArray,
+ ltotal: sc.Variable,
+ pulse_period: sc.Variable,
+ pulse_stride: int,
+ pulse_stride_offset: int,
+ ) -> sc.DataArray:
+ etos = da.bins.coords["event_time_offset"]
+ eto_unit = elem_unit(etos)
+ pulse_period = pulse_period.to(unit=eto_unit)
+ frame_period = pulse_period * pulse_stride
+
+ # TODO: Finding the `tmin` below will not work in the case where data is processed
+ # in chunks, as taking the minimum time in each chunk will lead to inconsistent
+ # pulse indices (this will be the case in live data, or when using the
+ # StreamProcessor). We could instead read it from the first chunk and store it?
+
+ # Compute a pulse index for every event: it is the index of the pulse within a
+ # frame period. When there is no pulse skipping, those are all zero. When there is
+ # pulse skipping, the index ranges from zero to pulse_stride - 1.
+ tmin = da.bins.coords['event_time_zero'].min()
+ pulse_index = (
+ (
+ (da.bins.coords['event_time_zero'] - tmin).to(unit=eto_unit)
+ + 0.5 * pulse_period
+ )
+ % frame_period
+ ) // pulse_period
+ # Apply the pulse_stride_offset
+ pulse_index += pulse_stride_offset
+ pulse_index %= pulse_stride
+
+ # Create 2D interpolator
+ interp = _make_tof_interpolator(
+ lookup, distance_unit=ltotal.unit, time_unit=eto_unit
+ )
+
+ # Operate on events (broadcast distances to all events)
+ ltotal = sc.bins_like(etos, ltotal).bins.constituents["data"]
+ etos = etos.bins.constituents["data"]
+ pulse_index = pulse_index.bins.constituents["data"]
+
+ # Compute time-of-flight for all neutrons using the interpolator
+ tofs = sc.array(
+ dims=etos.dims,
+ values=interp((pulse_index.values, ltotal.values, etos.values)),
+ unit=eto_unit,
  )

+ parts = da.bins.constituents
+ parts["data"] = tofs
+ return da.bins.assign_coords(tof=_bins_no_validate(**parts))
+

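A small numeric check of the pulse-index formula in _time_of_flight_data_events above, assuming a roughly 71.4 ms pulse period and pulse_stride=2 (frame period of about 142.8 ms); the offsets stand for event_time_zero minus the earliest event_time_zero in the run:

    import scipp as sc

    pulse_period = sc.scalar(71.4, unit='ms')
    frame_period = pulse_period * 2
    offsets = sc.array(dims=['event'], values=[0.0, 71.4, 142.8], unit='ms')
    pulse_index = ((offsets + 0.5 * pulse_period) % frame_period) // pulse_period
    # -> [0, 1, 0]: events from the first, second, and again the first pulse of a frame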
  def time_of_flight_data(
  da: RawData,
- lookup: MaskedTimeOfFlightLookupTable,
+ lookup: TimeOfFlightLookupTable,
  ltotal: Ltotal,
- toas: FrameFoldedTimeOfArrival,
+ pulse_period: PulsePeriod,
+ pulse_stride: PulseStride,
+ pulse_stride_offset: PulseStrideOffset,
  ) -> TofData:
  """
  Convert the time-of-arrival data to time-of-flight data using a lookup table.
@@ -351,39 +466,29 @@ def time_of_flight_data(
  arrival.
  ltotal:
  Total length of the flight path from the source to the detector.
- toas:
- Time of arrival of the neutron at the detector, folded by the frame period.
+ pulse_period:
+ Period of the source pulses, i.e., time between consecutive pulse starts.
+ pulse_stride:
+ Stride of used pulses. Usually 1, but may be a small integer when
+ pulse-skipping.
+ pulse_stride_offset:
+ When pulse-skipping, the offset of the first pulse in the stride. This is
+ typically zero but can be a small integer < pulse_stride.
  """
- from scipy.interpolate import RegularGridInterpolator

- # TODO: to make use of multi-threading, we could write our own interpolator.
- # This should be simple enough as we are making the bins linspace, so computing
- # bin indices is fast.
- f = RegularGridInterpolator(
- (
- lookup.coords["toa"].to(unit=elem_unit(toas), copy=False).values,
- lookup.coords["distance"].to(unit=ltotal.unit, copy=False).values,
- ),
- lookup.data.to(unit=elem_unit(toas), copy=False).values.T,
- method="linear",
- bounds_error=False,
- )
-
- if da.bins is not None:
- ltotal = sc.bins_like(toas, ltotal).bins.constituents["data"]
- toas = toas.bins.constituents["data"]
-
- tofs = sc.array(
- dims=toas.dims, values=f((toas.values, ltotal.values)), unit=elem_unit(toas)
- )
-
- if da.bins is not None:
- parts = da.bins.constituents
- parts["data"] = tofs
- out = da.bins.assign_coords(tof=_bins_no_validate(**parts))
+ if da.bins is None:
+ out = _time_of_flight_data_histogram(
+ da=da, lookup=lookup, ltotal=ltotal, pulse_period=pulse_period
+ )
  else:
- out = da.assign_coords(tof=tofs)
-
+ out = _time_of_flight_data_events(
+ da=da,
+ lookup=lookup,
+ ltotal=ltotal,
+ pulse_period=pulse_period,
+ pulse_stride=pulse_stride,
+ pulse_stride_offset=pulse_stride_offset,
+ )
  return TofData(out)


@@ -432,7 +537,7 @@ def default_parameters() -> dict:
  PulseStride: 1,
  PulseStrideOffset: 0,
  DistanceResolution: sc.scalar(0.1, unit="m"),
- TimeOfArrivalResolution: 500,
+ TimeResolution: sc.scalar(250.0, unit='us'),
  LookupTableRelativeErrorThreshold: 0.1,
  }

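For orientation, the new TimeResolution default of 250 us corresponds to roughly 286 time bins per frame when assuming a 14 Hz source (pulse period of about 71.4 ms) and no pulse skipping, following the nbins formula in compute_tof_lookup_table above:

    import scipp as sc

    pulse_period = sc.scalar(1.0 / 14.0, unit='s').to(unit='us')  # assumed 14 Hz source
    time_resolution = sc.scalar(250.0, unit='us')
    nbins = int(pulse_period / time_resolution) + 1               # 286 bins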
@@ -441,101 +546,4 @@ def providers() -> tuple[Callable]:
  """
  Providers of the time-of-flight workflow.
  """
- return (
- compute_tof_lookup_table,
- extract_ltotal,
- find_fastest_neutron,
- frame_period,
- masked_tof_lookup_table,
- pivot_time_at_detector,
- time_of_arrival_folded_by_frame,
- time_of_arrival_minus_pivot_time_modulo_period,
- time_of_flight_data,
- unwrapped_time_of_arrival,
- unwrapped_time_of_arrival_minus_frame_pivot_time,
- )
-
-
- class TofWorkflow:
- """
- Helper class to build a time-of-flight workflow and cache the expensive part of the
- computation: running the simulation and building the lookup table.
-
- Parameters
- ----------
- simulated_neutrons:
- Results of a time-of-flight simulation used to create a lookup table.
- The results should be a flat table with columns for time-of-arrival, speed,
- wavelength, and weight.
- ltotal_range:
- Range of total flight path lengths from the source to the detector.
- This is used to create the lookup table to compute the neutron
- time-of-flight.
- Note that the resulting table will extend slightly beyond this range, as the
- supplied range is not necessarily a multiple of the distance resolution.
- pulse_stride:
- Stride of used pulses. Usually 1, but may be a small integer when
- pulse-skipping.
- pulse_stride_offset:
- Integer offset of the first pulse in the stride (typically zero unless we
- are using pulse-skipping and the events do not begin with the first pulse in
- the stride).
- distance_resolution:
- Resolution of the distance axis in the lookup table.
- Should be a single scalar value with a unit of length.
- This is typically of the order of 1-10 cm.
- toa_resolution:
- Resolution of the time of arrival axis in the lookup table.
- Can be an integer (number of bins) or a sc.Variable (bin width).
- error_threshold:
- Threshold for the variance of the projected time-of-flight above which
- regions are masked.
- """
-
- def __init__(
- self,
- simulated_neutrons: SimulationResults,
- ltotal_range: LtotalRange,
- pulse_stride: PulseStride | None = None,
- pulse_stride_offset: PulseStrideOffset | None = None,
- distance_resolution: DistanceResolution | None = None,
- toa_resolution: TimeOfArrivalResolution | None = None,
- error_threshold: LookupTableRelativeErrorThreshold | None = None,
- ):
- import sciline as sl
-
- self.pipeline = sl.Pipeline(providers())
- self.pipeline[SimulationResults] = simulated_neutrons
- self.pipeline[LtotalRange] = ltotal_range
-
- params = default_parameters()
- self.pipeline[PulsePeriod] = params[PulsePeriod]
- self.pipeline[PulseStride] = pulse_stride or params[PulseStride]
- self.pipeline[PulseStrideOffset] = (
- pulse_stride_offset or params[PulseStrideOffset]
- )
- self.pipeline[DistanceResolution] = (
- distance_resolution or params[DistanceResolution]
- )
- self.pipeline[TimeOfArrivalResolution] = (
- toa_resolution or params[TimeOfArrivalResolution]
- )
- self.pipeline[LookupTableRelativeErrorThreshold] = (
- error_threshold or params[LookupTableRelativeErrorThreshold]
- )
-
- def cache_results(
- self,
- results=(SimulationResults, MaskedTimeOfFlightLookupTable, FastestNeutron),
- ) -> None:
- """
- Cache a list of (usually expensive to compute) intermediate results of the
- time-of-flight workflow.
-
- Parameters
- ----------
- results:
- List of results to cache.
- """
- for t in results:
- self.pipeline[t] = self.pipeline.compute(t)
+ return (compute_tof_lookup_table, extract_ltotal, time_of_flight_data)
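A hedged usage sketch of the reduced provider set, mirroring the wiring that the removed TofWorkflow helper used to perform; simulated_neutrons (a SimulationResults table) and detector (RawData) are hypothetical inputs produced elsewhere, and the module import path is not shown in this diff:

    import sciline as sl
    import scipp as sc

    pipeline = sl.Pipeline(providers())
    for key, value in default_parameters().items():
        pipeline[key] = value
    pipeline[SimulationResults] = simulated_neutrons
    pipeline[LtotalRange] = (sc.scalar(75.0, unit='m'), sc.scalar(80.0, unit='m'))
    pipeline[RawData] = detector
    da_tof = pipeline.compute(TofData)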