fpfind 3.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
fpfind/lib/utils.py ADDED
@@ -0,0 +1,716 @@
1
+ import argparse
2
+ import enum
3
+ import pathlib
4
+ import re
5
+ import warnings
6
+ from dataclasses import dataclass
7
+ from multiprocessing import Process, Queue
8
+ from typing import Optional, Tuple, Union
9
+
10
+ import numpy as np
11
+ import scipy.fft
12
+ from numpy.typing import NDArray
13
+
14
+ import fpfind.lib._logging as logging
15
+ from fpfind import NP_PRECISEFLOAT
16
+ from fpfind.lib.constants import TSRES, PeakFindingFailed
17
+ from fpfind.lib.parse_epochs import date2epoch, epoch2int, int2epoch, read_T1, read_T2
18
+ from fpfind.lib.parse_timestamps import read_a1_kth_timestamp
19
+ from fpfind.lib.typing import (
20
+ Complex_,
21
+ DetectorArray,
22
+ Float,
23
+ Integer,
24
+ PathLike,
25
+ TimestampArray,
26
+ )
27
+
28
# Optional fast-histogramming backend: when the external S15lib package is
# installed, its 'histogram' routine is re-exported from this module.
# Otherwise the name is simply left undefined.
try:
    from S15lib.g2lib.g2lib import histogram  # noqa: F401 # type: ignore
except ModuleNotFoundError:  # S15lib is an optional dependency
    pass

# Logger pair for the "fpfind" namespace, from the package's logging helper.
logger, log = logging.get_logger("fpfind")
35
+
36
+
37
def get_overlap(*arrays):
    """Truncates every array on the right to the shortest common length.

    Handy when comparing timestamp streams of unequal length, e.g. raw
    timestamps against the individual epochs produced by the chopper.

    Returns:
        Tuple of (common length, list of truncated arrays).
    """
    shortest = min(len(array) for array in arrays)
    truncated = [array[:shortest] for array in arrays]
    return shortest, truncated
46
+
47
+
48
@dataclass
class PeakStatistics:
    """Summary statistics of a cross-correlation histogram peak.

    Attributes:
        signal: Histogram bins attributed to the peak.
        background: Histogram bins used to estimate the background level.

    Derived properties return None (or 0, for 'significance') instead of
    raising when the bins they need are missing or degenerate.
    """

    signal: NDArray[np.number]
    background: NDArray[np.number]

    def has_signal(self):
        """Returns True if at least one signal bin is present."""
        return len(self.signal) != 0

    def has_background(self):
        """Returns True if at least one background bin is present."""
        return len(self.background) != 0

    def has_data(self):
        """Returns True if both signal and background bins are present."""
        return self.has_signal() and self.has_background()

    @property
    def max(self):
        """Height of the tallest signal bin, or None without signal."""
        if not self.has_signal():
            return None
        return np.max(self.signal)

    @property
    def mean(self):
        """Mean background level, or None without background."""
        if not self.has_background():
            return None
        return np.mean(self.background)  # type: ignore (not checked, TODO)

    @property
    def stdev(self):
        """Standard deviation of the background, or None without background."""
        if not self.has_background():
            return None
        return np.std(self.background)  # type: ignore (not checked, TODO)

    @property
    def total(self):
        """Signal sum minus the mean background expected over the signal bins."""
        if not self.has_data():
            return None
        return sum(self.signal) - len(self.signal) * self.mean  # type: ignore (None accounted for)

    @property
    def significance(self):
        """Peak height above the background mean, in background stdevs.

        Returns 0 if the background has no spread or data is missing.
        """
        if self.stdev == 0:
            return 0
        if not self.has_data():
            return 0
        return round((self.max - self.mean) / self.stdev, 3)  # type: ignore (None accounted for)

    @property
    def significance_raw(self):
        """Peak significance computed over signal and background combined.

        Returns None if the background has no spread, or if the combined
        bins are empty or have no spread.
        """
        if self.stdev == 0:
            return None
        # np.hstack takes a single sequence of arrays; passing the arrays as
        # two positional arguments raises TypeError.
        full = np.hstack((self.signal, self.background))
        if len(full) == 0:
            return None
        full_stdev = np.std(full)
        if full_stdev == 0:
            return None
        return (np.max(full) - np.mean(full)) / full_stdev

    @property
    def significance2(self):
        """Significance of the background-subtracted peak total.

        The background is regrouped into consecutive chunks of 'len(signal)'
        bins, so that its spread is comparable with the summed signal.
        Returns None when this cannot be computed.
        """
        if self.stdev == 0:
            return None
        # 'total' needs both signal and background; without background it
        # is None and the division below would raise TypeError.
        if not self.has_data():
            return None
        width = len(self.signal)
        length = (len(self.background) // width) * width
        if length == 0:  # background shorter than a single group
            return None
        rebinned = np.sum(self.background[:length].reshape(-1, width), axis=1)
        stdev = np.std(rebinned)
        if stdev == 0:
            return None
        return round(self.total / stdev, 3)  # type: ignore (None accounted for)

    @property
    def g2(self):
        """Ratio of peak height to background mean, or None if undefined."""
        if self.mean is None or self.max is None:
            return None
        if self.mean == 0:
            return None
        return self.max / self.mean
124
+
125
+
126
def get_statistics(
    hist: NDArray[np.number],
    resolution: Optional[Float] = None,
    center: Optional[Float] = None,
    window: Float = 0.0,
) -> PeakStatistics:
    """Returns peak statistics of a cross-correlation histogram.

    Without 'resolution', the whole histogram serves as both signal and
    background. Otherwise the signal is a symmetric window of bins around the
    peak, and the background consists of the bins between the window edges
    and the midpoints towards either end of the histogram.

    Args:
        hist: Timing histogram to analyze.
        resolution: Resolution of the histogram (time per bin).
        center: Timing center of the peak, if known beforehand. Negative
            values are counted backwards from the end of the histogram.
            If omitted, the tallest bin is used.
        window: Desired timing window width to exclude from background mean calculation.

    Raises:
        ValueError: If 'center' is supplied without 'resolution'.
    """
    # Fallback to simple statistics, if no other arguments supplied
    if resolution is None:
        if center is None:
            return PeakStatistics(hist, hist)
        raise ValueError("Resolution must be supplied if 'center' is supplied.")

    # Locate the center bin. Cast to int: a float bin index (from float
    # 'center'/'resolution') would make the slicing below raise TypeError.
    if center is None:
        bin_center = int(np.argmax(hist))
    else:
        bin_center = int(np.abs(center) // resolution)
        if center < 0:
            # Wrap around the end; the modulo keeps a zero offset at bin 0
            # instead of pointing one past the last bin.
            bin_center = (len(hist) - bin_center) % len(hist)

    # Retrieve size of symmetrical window
    num_windowbins_onesided = int(np.ceil(window / 2 / resolution))
    bin_offset_left = max(0, bin_center - num_windowbins_onesided)
    bin_offset_right = min(len(hist), bin_center + num_windowbins_onesided)

    # Avoid tails of the cross-correlation by taking only half of the spectrum
    # Use-case when same timestamp is used to obtain histogram, resulting in deadtime
    bin_offset_left_bg = bin_offset_left // 2
    bin_offset_right_bg = (len(hist) + bin_offset_right) // 2

    # Signal is the window itself; background flanks it on both sides
    signal = hist[bin_offset_left : bin_offset_right + 1]
    background = np.hstack(
        (
            hist[bin_offset_left_bg:bin_offset_left],
            hist[bin_offset_right + 1 : bin_offset_right_bg + 1],
        )
    )
    return PeakStatistics(signal, background)
174
+
175
+
176
def generate_fft(
    arr: TimestampArray,
    num_bins: Integer,
    time_res: Float,
) -> NDArray[np.complex128]:
    """Returns the real-input FFT of the wrapped timestamp histogram.

    Timestamps are binned with width 'time_res' and wrapped modulo
    'num_bins' bins. 'scipy.fft.rfft' is then applied, so only the
    non-negative-frequency half of the (symmetric) spectrum is returned.

    Args:
        arr: The timestamp series.
        num_bins: The number of bins in the time/frequency domain.
        time_res: The size of each timing bin, in the same units as 'arr'.

    Raises:
        ValueError: If 'arr' is empty.

    Note:
        Not cacheable in practice, since np.ndarray inputs are mutable.
    """
    if not len(arr):
        raise ValueError("Array is empty!")
    wrapped_bins = np.int32((arr // time_res) % num_bins)
    counts = np.bincount(wrapped_bins, minlength=num_bins)
    spectrum = scipy.fft.rfft(counts)
    return spectrum  # type: ignore (dispatchable)
198
+
199
+
200
def get_xcorr(
    afft: NDArray[Complex_],
    bfft: NDArray[Complex_],
    filter: Optional[NDArray[np.number]] = None,
) -> NDArray[np.float64]:
    """Returns the magnitude of the circular cross-correlation.

    'afft' and 'bfft' are real-input spectra (e.g. from 'scipy.fft.rfft').
    Their product 'conj(afft) * bfft' is optionally multiplied by 'filter'
    in frequency space before being transformed back with 'scipy.fft.irfft'
    (default output length).

    Note:
        Conjugating an FFT corresponds to time-reversing the original series,
        which turns the product of spectra into a correlation.
    """
    spectrum = np.conjugate(afft) * bfft
    spectrum = spectrum if filter is None else spectrum * filter
    return np.abs(scipy.fft.irfft(spectrum))  # type: ignore (dispatchable)
216
+
217
+
218
def convert_histogram_fft(hist: NDArray[np.number], time_bins: NDArray[np.number]):
    """Reorders a circular FFT histogram so that zero delay sits in the middle.

    The histogram is shifted like 'np.fft.fftshift', and a matching signed
    time axis is built from the non-negative 'time_bins'. With 'N = len(hist)',
    bins 'ceil(N/2) .. N-1' are interpreted as negative delays. This also
    holds for odd N, where the previous split produced an off-centre
    histogram and a time axis one element shorter than the histogram.

    Args:
        hist: Circular histogram, index 0 corresponding to zero delay.
        time_bins: Non-negative time of each bin; assumed to have the same
            length as 'hist'.

    Returns:
        Tuple of (shifted histogram, signed time axis), both of length N.
    """
    num_neg = len(hist) // 2  # bins representing negative delays
    num_pos = len(hist) - num_neg  # zero and positive delays
    hist = np.hstack((hist[num_pos:], hist[:num_pos]))
    time_bins = np.hstack((-np.flip(time_bins[1 : num_neg + 1]), time_bins[:num_pos]))
    return hist, time_bins
224
+
225
+
226
def get_timing_delay_fft(
    hist: NDArray[np.number],
    time_bins: NDArray[np.number],
    include_negative: bool = False,
) -> Tuple[np.signedinteger, np.signedinteger]:
    """Returns the two candidate delays of the histogram peak.

    In a circular histogram, a peak at index 'p' reads either as the direct
    delay 'time_bins[p]' or as the wrapped negative delay
    '-time_bins[N - p]' (bin 0 maps onto itself). The candidate whose index
    is strictly closer to zero comes first.

    Args:
        hist: Timing histogram.
        time_bins: Time bins.
        include_negative: Accepted for interface compatibility; not used.

    Example:
        >>> get_timing_delay_fft([1,3,0,1], [2,4,6,8])
        (4, -8)
    """
    peak_index = np.argmax(hist)
    wrapped_index = 0 if peak_index == 0 else len(hist) - peak_index
    direct = time_bins[peak_index]
    wrapped = -time_bins[wrapped_index]
    if np.abs(peak_index) < np.abs(wrapped_index):
        return direct, wrapped
    return wrapped, direct
252
+
253
+
254
def get_delay_at_index_fft(
    time_bins: NDArray[np.number], pos: Integer
) -> Tuple[np.number, np.number]:
    """Returns both delay interpretations of bin 'pos' in a circular histogram.

    The pair is '(direct, wrapped)' = '(time_bins[pos], -time_bins[N - pos])',
    ordered so that the one whose index is strictly closer to zero comes
    first. Bin 0 is its own mirror.
    """
    assert pos >= 0
    mirror = 0 if pos == 0 else len(time_bins) - pos  # corner case: bin 0
    direct = time_bins[pos]
    wrapped = -time_bins[mirror]
    if np.abs(pos) < np.abs(mirror):
        return direct, wrapped
    return wrapped, direct
263
+
264
+
265
def get_top_k_delays_fft(
    hist: NDArray[np.number], time_bins: NDArray[np.number], k: Integer
):
    """Returns delay candidates for the 'k' tallest histogram bins.

    For k > 1, each entry is '(height, delay, alternate_delay)', sorted by
    descending height, where the delay pair comes from
    'get_delay_at_index_fft'.

    NOTE(review): for k == 1 the single entry is only
    '(delay, alternate_delay)', without the height. Callers presumably rely
    on this; confirm before unifying the two formats.
    """
    assert k >= 1
    if k == 1:
        peak = np.argmax(hist)
        return [get_delay_at_index_fft(time_bins, peak)]

    candidates = np.argpartition(hist, -k)[-k:]
    heights = hist[candidates]
    order = np.argsort(heights)[::-1]  # tallest first
    return [
        (height, *get_delay_at_index_fft(time_bins, index))
        for index, height in zip(candidates[order], heights[order])
    ]
280
+
281
+
282
def slice_timestamps(
    ts: TimestampArray,
    start: Union[Float, None] = None,
    duration: Union[Float, None] = None,
) -> TimestampArray:
    """Returns timestamps shifted to a new origin, optionally truncated.

    Args:
        ts: Timestamp array, presumably sorted ascending (the last element is
            treated as the latest timestamp) - TODO confirm with callers.
        start: New origin; earlier timestamps are discarded. Defaults to the
            first timestamp.
        duration: If given, only timestamps strictly before 'duration'
            (relative to the new origin) are kept.

    Returns:
        The shifted/sliced timestamps. The input array is not modified in place.

    Warns:
        UserWarning: If 'duration' is given and no data remains, or if
            'duration' extends beyond the available data.
    """
    dtype = ts.dtype
    if start is not None:
        ts = ts - start  # note: 'ts -= start' is in-place
        ts = ts[ts >= 0]
    elif len(ts) != 0:
        ts = ts - ts[0]
    # An empty array without 'start' falls through here instead of raising
    # IndexError on 'ts[0]', so the empty-data handling below applies.

    if duration is not None:
        if len(ts) == 0:
            warnings.warn("No data available.")
            return np.array([], dtype=dtype)
        if duration >= ts[-1]:
            warnings.warn(
                f"Desired duration ({duration * 1e-9:.3f} s) is longer "
                f"than available data ({ts[-1] * 1e-9:.3f} s)"
            )
        ts = ts[ts < duration]
    return ts
304
+
305
+
306
def histogram_fft(  # noqa: PLR0913
    alice: TimestampArray,
    bob: TimestampArray,
    num_bins: Integer,
    resolution: Float = 1,
    num_wraps: Integer = 1,
    acq_start: Optional[Float] = None,
    freq_corr: Float = 0.0,
    filter: Optional[NDArray[np.number]] = None,
    statistics: bool = False,
    center: Optional[Float] = None,
    window: Float = 0.0,
):
    """Returns the cross-correlation histogram of two timestamp arrays.

    Both arrays are shifted so that the later of their first timestamps
    (plus 'acq_start') becomes zero. 'bob' is then scaled by
    '1 + freq_corr', and the wrapped FFT histograms are cross-correlated.

    Args:
        alice: First timestamp array; not modified.
        bob: Second timestamp array; not modified.
        num_bins: Number of histogram bins.
        resolution: Width of each bin, in the units of the timestamps.
        num_wraps: Number of histogram spans expected to fit in the data;
            only used to warn about insufficient data.
        acq_start: Starting timing, relative to first common timestamp.
        freq_corr: Relative frequency correction applied to 'bob'.
        filter: Optional filter in frequency-space.
        statistics: If True, additionally return timestamp counts and the
            peak statistics.
        center: Forwarded to 'get_statistics'.
        window: Forwarded to 'get_statistics'.

    Returns:
        '(histogram, bins)', or '(histogram, bins, alen, blen, stats)' when
        'statistics' is True.
    """
    if not isinstance(num_wraps, (int, np.integer)):
        warnings.warn(
            "Number of wraps is not an integer - "
            "statistical significance will be lower."
        )

    duration = num_wraps * num_bins * resolution
    first_timestamp = max(alice[0], bob[0])
    last_timestamp = min(alice[-1], bob[-1])
    if first_timestamp + duration > last_timestamp:
        warnings.warn(
            f"Desired duration of timestamps ({duration} ns) "
            f"exceeds available data ({last_timestamp - first_timestamp} ns)."
        )

    # Normalize timestamps into new arrays. The previous in-place '-=' mutated
    # the caller's arrays, and fails for integer arrays with a float offset.
    acq_start = 0 if acq_start is None else acq_start
    offset = first_timestamp + acq_start
    alice = alice - offset
    bob = (bob - offset) * (1 + freq_corr)

    # generate_fft() returns only the spectrum; unpacking it into two values
    # (as before) fails.
    afft = generate_fft(alice, num_bins, resolution)
    bfft = generate_fft(bob, num_bins, resolution)
    result = get_xcorr(afft, bfft, filter)
    bins = np.arange(num_bins) * resolution
    if statistics:
        # NOTE(review): 'alen'/'blen' were previously unpacked from
        # generate_fft(); the timestamp counts are assumed to be what was
        # intended - confirm with callers.
        alen, blen = len(alice), len(bob)
        return (
            result,
            bins,
            alen,
            blen,
            get_statistics(result, resolution, center, window),
        )
    return result, bins
360
+
361
+
362
+ # https://stackoverflow.com/a/23941599
363
+ class ArgparseCustomFormatter(argparse.RawDescriptionHelpFormatter):
364
+ RAW_INDICATOR = "rawtext|"
365
+
366
+ def _format_action_invocation(self, action):
367
+ if not action.option_strings:
368
+ _ = self._metavar_formatter(action, action.dest)(1)
369
+ print(action, _)
370
+ (metavar,) = _
371
+ return metavar
372
+ else:
373
+ parts = []
374
+ # if the Optional doesn't take a value, format is:
375
+ # -s, --long
376
+ if action.nargs == 0:
377
+ parts.extend(action.option_strings)
378
+
379
+ # if the Optional takes a value, format is:
380
+ # -s ARGS, --long ARGS
381
+ # change to
382
+ # -s, --long ARGS
383
+ else:
384
+ default = action.dest.upper()
385
+ args_string = self._format_args(action, default)
386
+ for option_string in action.option_strings:
387
+ # parts.append('%s %s' % (option_string, args_string))
388
+ parts.append("%s" % option_string)
389
+ parts[-1] += " %s" % args_string
390
+ return ", ".join(parts)
391
+
392
+ def _split_lines(self, text, width):
393
+ marker = ArgparseCustomFormatter.RAW_INDICATOR
394
+ if text.startswith(marker):
395
+ return text[len(marker) :].splitlines()
396
+ return super()._split_lines(text, width)
397
+
398
+
399
def get_first_overlapping_epoch(
    dir1: PathLike,
    dir2: PathLike,
    first_epoch: Optional[Union[str, bytes]] = None,
    return_length: bool = False,
):
    """Returns the smallest epoch name present in both directories.

    Every entry name in each directory is converted with 'epoch2int'. When
    'first_epoch' is supplied, smaller epochs are ignored. If no common
    epoch exists, the epoch is None.

    If 'return_length' is True, the return value is a tuple of the epoch name
    and the number of contiguous overlapping epochs starting from said epoch
    (0 when there is no overlap).
    """

    def _epochints(directory):
        # Integer epochs of every entry found in 'directory'
        return {epoch2int(path.name) for path in pathlib.Path(directory).glob("*")}

    common = _epochints(dir1) & _epochints(dir2)
    if first_epoch is not None:
        lower_bound = epoch2int(first_epoch)
        common = {value for value in common if value >= lower_bound}

    min_epoch = None
    num_epochs = 0
    if common:
        start = min(common)
        min_epoch = int2epoch(start)
        # Count the unbroken run start, start+1, ... present in both dirs
        while start + num_epochs in common:
            num_epochs += 1

    if return_length:
        return min_epoch, num_epochs
    return min_epoch
434
+
435
+
436
def iterate_epochs(
    epoch: Union[str, bytes], length: Optional[int] = None, step: int = 1
):
    """Generates consecutive epoch names, starting at 'epoch'.

    Being a generator, it should be consumed, e.g. `list(iterate_epochs(...))`.
    With 'length' left as None the stream never ends. 'step' changes the
    spacing between emitted epochs but not how many are emitted.

    Args:
        epoch: Starting epoch, in hex.
        length: Total number of epochs to emit, or None for an endless stream.
        step: Separation between consecutive epochs.

    Example:
        >>> for epoch in iterate_epochs("bbbbaaa0", length=3, step=2):
        ...     print(epoch)
        bbbbaaa0
        bbbbaaa2
        bbbbaaa4
    """
    base = epoch2int(epoch)
    if length is not None:
        offsets = range(0, length * step, step)
        yield from (int2epoch(base + offset) for offset in offsets)
        return

    # Endless stream
    while True:
        yield int2epoch(base)
        base += step
466
+
467
+
468
def get_timestamp(
    dirname: PathLike,
    file_type: str,
    first_epoch: Union[str, bytes],
    skip_epoch: int,
    num_of_epochs: int,
):
    """Returns the concatenated timestamps of consecutive epoch files.

    Reads 'num_of_epochs' files from 'dirname', starting 'skip_epoch' epochs
    after 'first_epoch'. 'read_T1' is used when 'file_type' is "T1", and
    'read_T2' otherwise. Only the first element of each reader's result
    (the timestamps) is kept.
    """
    epochdir = pathlib.Path(dirname)
    reader = read_T1 if file_type == "T1" else read_T2  # loop-invariant
    start = epoch2int(first_epoch) + skip_epoch
    # Gather per-epoch chunks and concatenate once. Repeated np.append copies
    # the whole accumulated array every iteration (quadratic in total size).
    # The empty seed keeps the NP_PRECISEFLOAT dtype promotion of np.append.
    chunks = [np.array([], dtype=NP_PRECISEFLOAT)]
    for i in range(num_of_epochs):
        chunks.append(np.ravel(reader(epochdir / int2epoch(start + i))[0]))
    return np.concatenate(chunks)
482
+
483
+
484
def get_timestamp_pattern(
    dirname: PathLike,
    file_type: str,
    first_epoch: Union[str, bytes],
    skip_epoch: int,
    num_of_epochs: int,
) -> Tuple[TimestampArray, DetectorArray]:
    """Returns concatenated timestamps and detector patterns of epoch files.

    Reads 'num_of_epochs' files from 'dirname', starting 'skip_epoch' epochs
    after 'first_epoch'. 'read_T1' is used when 'file_type' is "T1", and
    'read_T2' otherwise. Each reader result is unpacked as
    '(timestamps, patterns)'.
    """
    epochdir = pathlib.Path(dirname)
    reader = read_T1 if file_type == "T1" else read_T2
    start = epoch2int(first_epoch) + skip_epoch
    # Concatenate once at the end instead of np.append per epoch (quadratic).
    # The empty seeds preserve the dtype promotion of the original approach.
    ts_chunks = [np.array([], dtype=NP_PRECISEFLOAT)]
    pattern_chunks = [np.array([], dtype=np.uint32)]
    for i in range(num_of_epochs):
        ts, ps = reader(epochdir / int2epoch(start + i))
        ts_chunks.append(np.ravel(ts))
        pattern_chunks.append(np.ravel(ps))
    return np.concatenate(ts_chunks), np.concatenate(pattern_chunks)
501
+
502
+
503
def normalize_timestamps(
    *T: TimestampArray, skip: float = 0.0, preserve_relative: bool = True
):
    """Shifts all timestamp arrays onto a common zero reference.

    The common origin is the latest of the arrays' first timestamps, plus
    'skip'. Timestamps before the origin are dropped by 'slice_timestamps'.

    Args:
        T: Timestamp arrays.
        skip: Duration to skip, in seconds (converted to ns).
        preserve_relative: If False, each array is first shifted to start at
            zero on its own, discarding the relative offsets between arrays.

    Returns:
        Tuple of the normalized arrays, in input order.
    """
    arrays = T if preserve_relative else tuple(map(slice_timestamps, T))
    origin = max(ts[0] for ts in arrays) + skip * 1e9  # units of ns
    return tuple(slice_timestamps(ts, origin) for ts in arrays)
519
+
520
+
521
def parse_docstring_description(docstring: Optional[str]) -> str:
    """Extracts the description part of a docstring, e.g. for CLI help.

    Text from the first section header onwards (a newline, a run of letters,
    digits or whitespace, then ':' and a newline, as in 'Args:' or
    'Changelog:') is dropped. The first run of newlines becomes a paragraph
    break, and every later run of newlines becomes a single space.
    """
    if docstring is None:
        return ""

    # Keep only the text preceding the first annotated section
    description = re.split(r"\n[a-zA-Z0-9\s]+:\n", docstring)[0]

    marker = "~~~PLACEHOLDER~~~"
    description = re.sub(r"\n+", marker, description, count=1)
    description = re.sub(r"\n+", " ", description)
    return description.replace(marker, "\n\n")
534
+
535
+
536
def timestamp2epoch(
    filename: PathLike,
    resolution: TSRES = TSRES.PS4,
    legacy: bool = False,
    full: bool = False,
) -> str:
    """Returns an epoch name derived from the first timestamp in 'filename'.

    The bits above the lowest 17 of the current epoch ('date2epoch()') are
    combined with the first timestamp shifted right by 37 bits.

    NOTE(review): when 'full' is True, a second timestamp read is performed,
    but its result is never used, so 'full' currently has no effect on the
    returned epoch. Confirm the intended behaviour.
    """
    first_values, _ = read_a1_kth_timestamp(
        filename,
        [0],
        legacy,
        resolution=resolution,
        fractional=False,
    )
    (first_timestamp,) = first_values
    current_epoch = epoch2int(date2epoch())  # get current epoch (32-bit)
    # epoch_header = epoch >> 17 # increments every ~20h (first 15 bits)

    if full:
        # Result is discarded (see NOTE in the docstring)
        read_a1_kth_timestamp(
            filename, [0], legacy=True, resolution=TSRES.PS125, fractional=False
        )

    header_bits = current_epoch & ~((1 << 17) - 1)  # clear the lowest 17 bits
    return int2epoch(header_bits | (int(first_timestamp) >> 37))  # 54 - 17
558
+
559
+
560
def xcorr(abs: NDArray[np.number], bbs: NDArray[np.number]):
    """Cross-correlates two binned histograms using real FFTs.

    Both histograms are transformed with 'scipy.fft.rfft' (first 'abs', then
    'bbs') and combined by 'get_xcorr'.
    """
    spectra = [scipy.fft.rfft(bins) for bins in (abs, bbs)]  # type: ignore (dispatchable)
    return get_xcorr(*spectra)
566
+
567
+
568
def _xcorr_worker(q, abs, bbs):
    """Child-process target: runs 'xcorr' and reports the outcome through 'q'.

    Defined at module level because nested functions cannot be pickled,
    which the 'spawn'/'forkserver' start methods require for the target.
    Puts '(True, result)' on success and '(False, exception)' on failure.
    """
    try:
        q.put((True, xcorr(abs, bbs)))
    except Exception as e:  # forward, so the parent does not block forever
        q.put((False, e))


def xcorr_process(abs: NDArray[np.number], bbs: NDArray[np.number]):
    """Runs 'xcorr()' in a separate process for interruptible FFT.

    On KeyboardInterrupt the child is terminated and the interrupt re-raised.
    An exception raised by 'xcorr' in the child is re-raised in the parent.
    The child process is always joined.

    Note:
        Activate this with caution, because spawning a separate process incurs a penalty
        of pickling and piping inputs/outputs between processes. Rough estimates yield
        20% increase in runtime, e.g. 17.6(2)s -> 20.2(2) for N = 2^27 bins.
    """
    q = Queue()
    p = Process(target=_xcorr_worker, args=(q, abs, bbs))
    p.start()
    try:
        # Read before joining: a child with a large queued result cannot
        # exit until the data is consumed.
        succeeded, payload = q.get()
    except KeyboardInterrupt:  # SIGINT
        p.terminate()
        raise
    finally:
        p.join()
    if not succeeded:
        raise payload
    return payload
590
+
591
+
592
def generate_fft_bins(
    arr: TimestampArray,
    num_bins: Integer,
    time_res: Float,
) -> NDArray[np.signedinteger]:
    """Returns the circular histogram of timestamps, as input for an FFT.

    Each timestamp 't' is counted in bin '(t // time_res) % num_bins'.

    Raises:
        ValueError: If 'arr' is empty.
    """
    if not len(arr):
        raise ValueError("Array is empty!")
    wrapped = np.int32((arr // time_res) % num_bins)
    counts = np.bincount(wrapped, minlength=num_bins)
    return counts
601
+
602
+
603
def histogram_fft2(
    ats: TimestampArray,
    bts: TimestampArray,
    start: Optional[Float],
    duration: Optional[Float],
    N: Integer,
    r: Float,
    interruptible: bool = False,
):
    """Convenience wrapper producing a cross-correlation histogram.

    Both streams are cut with 'slice_timestamps(ts, start, duration)', binned
    into 'N' circular bins of width 'r', and cross-correlated. With
    'interruptible' True, the correlation runs via 'xcorr_process'.

    Returns:
        Tuple '(xs, ys)' of bin offsets 'k * r' and correlation magnitudes.
    """
    a_sliced, b_sliced = (slice_timestamps(ts, start, duration) for ts in (ats, bts))
    a_bins = generate_fft_bins(a_sliced, N, r)
    b_bins = generate_fft_bins(b_sliced, N, r)
    correlate = xcorr_process if interruptible else xcorr
    ys = correlate(a_bins, b_bins)

    # Signed integers needed for negative delays; widen past int32 max
    dtype = np.int64 if N * r > 2147483647 else np.int32
    xs = np.arange(N, dtype=dtype) * r
    return xs, ys
625
+
626
+
627
def fold_histogram(
    xs: NDArray[np.number], ys: NDArray[np.number], binning_factor: int = 2
) -> Tuple[NDArray[np.number], NDArray[np.number]]:
    """Merges every 'binning_factor' adjacent histogram bins into one.

    The x-axis keeps the first x-value of each group, while the y-values of
    each group are summed. 'len(ys)' must be divisible by 'binning_factor'
    (otherwise the reshape raises ValueError).
    """
    coarse_xs = xs[::binning_factor]
    coarse_ys = ys.reshape(-1, binning_factor).sum(axis=1)
    return coarse_xs, coarse_ys
633
+
634
+
635
class CoarseHistogramStrategy(enum.Enum):
    """How 'histogram_fft3' shrinks a histogram whose span is too long."""

    RESOLUTION = 1  # finer bins, later folded back together
    BINS = 2  # fewer bins at the original resolution
638
+
639
+
640
def histogram_fft3(
    ats: TimestampArray,
    bts: TimestampArray,
    start: Optional[Float],
    duration: Float,
    num_bins: Integer,
    resolution: Float,
    max_duration: Float,
    strategy: CoarseHistogramStrategy = CoarseHistogramStrategy.RESOLUTION,
    interruptible: bool = False,
):
    """Histogram with max duration limit.

    When the histogram span 'num_bins * resolution' exceeds 'max_duration',
    it is shrunk by a power-of-two 'factor' according to 'strategy':

    - RESOLUTION: bins are made 'factor' times finer, then every 'factor'
      adjacent bins are folded back together.
    - BINS: 'num_bins' is divided by 'factor', keeping the resolution.

    The timestamp 'duration' is also capped at 'max_duration'.

    Raises:
        NotImplementedError: If shrinking is needed and 'strategy' is unknown.
    """
    ceil = lambda v: int(np.ceil(v))  # noqa: E731 (using lambda cleaner)
    factor = ceil(resolution * num_bins / max_duration)
    # Better equal to 2^k for FFT/reshape; clamp to >= 1, since log2 of a
    # non-positive factor is undefined.
    factor = 1 << ceil(np.log2(max(factor, 1)))
    duration = min(duration, max_duration)

    def _histogram(n, r):
        # Arguments are assembled per call, so the adjusted bin count or
        # resolution actually reaches histogram_fft2. Previously a
        # precomputed args tuple silently used the unadjusted values.
        return histogram_fft2(ats, bts, start, duration, n, r, interruptible)

    if factor == 1:
        xs, ys = _histogram(num_bins, resolution)

    elif strategy is CoarseHistogramStrategy.RESOLUTION:
        xs, ys = _histogram(num_bins, resolution / factor)
        xs, ys = fold_histogram(xs, ys, factor)

    elif strategy is CoarseHistogramStrategy.BINS:
        xs, ys = _histogram(int(num_bins // factor), resolution)

    else:
        raise NotImplementedError

    return xs, ys
674
+
675
+
676
+ def match_dts(dt1s_early, dt1s_late, ddt_window):
677
+ dt1s_early = sorted(dt1s_early)
678
+ dt1s_late = sorted(dt1s_late)
679
+
680
+ i = j = 0 # two-pointer method
681
+ best = (None, (0, 0, 0), (0, 0, 0)) # min_ddt, early, late
682
+ while i < len(dt1s_early) and j < len(dt1s_late):
683
+ dt1_early = dt1s_early[i]
684
+ dt1_late = dt1s_late[j]
685
+ ddt = abs(dt1_early[0] - dt1_late[0])
686
+ if best[0] is None or (ddt < best[0] and ddt < ddt_window):
687
+ best = (ddt, dt1_early, dt1_late)
688
+
689
+ if dt1_early[0] < dt1_late[0]:
690
+ i += 1
691
+ else:
692
+ j += 1
693
+
694
+ if best[0] is None:
695
+ raise PeakFindingFailed("No coincident dt")
696
+
697
+ # Valid point found
698
+ # Approximate the resolvable resolution by finding the
699
+ # time differences found using the smallest resolution
700
+ def resolve(best_pt, pts) -> Tuple[int, int, float]:
701
+ left = best_pt[0] - best_pt[1]
702
+ right = best_pt[0] + best_pt[1]
703
+ assert len(pts) > 0
704
+ for pt in pts:
705
+ if left <= pt[0] <= right:
706
+ break
707
+ return pt # pyright: ignore[reportPossiblyUnboundVariable]
708
+
709
+ dt1_early = resolve(best[1], dt1s_early)
710
+ dt1_late = resolve(best[2], dt1s_late)
711
+
712
+ # Pass control back with expected variables
713
+ r = min([dt1_early[1], dt1_late[1]])
714
+ dt1 = dt1_early[0]
715
+ _dt1 = dt1_late[0]
716
+ return dt1, _dt1, r