reboost 0.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,208 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from pathlib import Path
5
+ from typing import Literal
6
+
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+ from lgdo import lh5
10
+ from matplotlib import colors, widgets
11
+ from numpy.typing import NDArray
12
+
13
+ from .create import list_optical_maps
14
+
15
+ log = logging.getLogger(__name__)
16
+
17
+
18
+ def _get_weights(viewdata: dict):
19
+ rolled = np.array([viewdata["axis"]] + [i for i in [2, 1, 0] if i != viewdata["axis"]])
20
+ weights = viewdata["weights"].transpose(*rolled)
21
+
22
+ extent_axes = [i for i in np.flip(rolled) if i != viewdata["axis"]]
23
+ extent = [viewdata["edges"][i] for i in extent_axes]
24
+ extent_d = [(e[1] - e[0]) / 2 for e in extent]
25
+ extent = [(e[0] - extent_d[i], e[-1] + extent_d[i]) for i, e in enumerate(extent)]
26
+
27
+ labels = [["x", "y", "z"][i] + " [m]" for i in extent_axes]
28
+
29
+ return weights, np.array(extent).flatten(), labels
30
+
31
+
32
+ def _slice_text(viewdata: dict) -> str:
33
+ axis_name = ["x", "y", "z"][viewdata["axis"]]
34
+ axis_pos = viewdata["edges"][viewdata["axis"]][viewdata["idx"]]
35
+ return f"{viewdata['detid']} | slice {axis_name} = {axis_pos:.2f} m"
36
+
37
+
38
def _process_key(event) -> None:
    """Handle a key press in the optical map viewer and redraw the figure.

    ``up``/``down`` change the slicing axis, ``right``/``left`` step through the
    slices, and ``c`` opens the detector selector. Afterwards, the slice index is
    clamped to the valid range of the (possibly new) slicing axis.
    """
    fig = event.canvas.figure
    state = fig.__reboost
    key = event.key

    if key == "c":
        _channel_selector(fig)
    elif key in ("left", "right"):
        state["idx"] += 1 if key == "right" else -1
    elif key == "up":
        state["axis"] = min(state["axis"] + 1, 2)
    elif key == "down":
        state["axis"] = max(state["axis"] - 1, 0)

    last = state["weights"].shape[state["axis"]] - 1
    state["idx"] = max(min(state["idx"], last), 0)

    _update_figure(fig)
57
+
58
+
59
def _update_figure(fig) -> None:
    """Refresh image, extent, axis labels and caption from the viewer state, then redraw."""
    state = fig.__reboost
    image_data, extent, (xlabel, ylabel) = _get_weights(state)

    main_ax = fig.axes[0]
    main_ax.texts[0].set_text(_slice_text(state))

    image = main_ax.images[0]
    image.set_array(image_data[state["idx"]])
    image.set_extent(extent)

    main_ax.set_xlabel(xlabel)
    main_ax.set_ylabel(ylabel)
    main_ax.set_anchor("C")
    fig.canvas.draw()
71
+
72
+
73
def _channel_selector(fig) -> None:
    """Overlay radio buttons listing the available detectors on the figure.

    Picking a detector other than the current one reloads its weights and edges
    (re-using the stored ``prepare_args``), removes the selector and redraws.
    """
    selector_ax = fig.add_axes([0.01, 0.01, 0.98, 0.98])
    detectors = fig.__reboost["available_dets"]
    current = detectors.index(fig.__reboost["detid"])
    selector = widgets.RadioButtons(selector_ax, detectors, active=current)

    def change_detector(label: str | None) -> None:
        state = fig.__reboost
        if state["detid"] == label:
            return
        state["detid"] = label
        new_edges, new_weights, _, _ = _prepare_data(*state["prepare_args"], label)
        state["weights"] = new_weights
        state["edges"] = new_edges
        selector.disconnect_events()
        selector_ax.remove()
        _update_figure(fig)

    selector.on_clicked(change_detector)
    fig.canvas.draw()
90
+
91
+
92
def _read_data(
    optmap_fn: str,
    detid: str = "all",
    histogram_choice: str = "prob",
) -> tuple[tuple[NDArray], NDArray]:
    """Read the bin edges and a copy of the weights of one optical map histogram.

    Parameters
    ----------
    optmap_fn
        LH5 file containing the optical map(s).
    detid
        ``"all"`` for the map stored at ``/all``. Any other name is looked up
        below ``/channels/`` (the prefix is added if missing).
    histogram_choice
        name of the histogram to read (e.g. ``prob`` or ``prob_unc``), or
        ``prob_unc_rel`` for the relative uncertainty ``prob_unc / prob``.

    Returns
    -------
    the tuple of per-axis bin edges and the (copied) weights array. For
    ``prob_unc_rel``, cells without a positive probability are set to -1.
    """
    # the relative uncertainty is derived from the `prob` and `prob_unc` histograms.
    histogram = "prob" if histogram_choice == "prob_unc_rel" else histogram_choice
    # compare against the string "all" (the original compared against the builtin
    # `all`, which always differs and therefore mapped "all" to "channels/all").
    if detid != "all" and not detid.startswith("channels/"):
        detid = f"channels/{detid}"

    optmap_all = lh5.read(f"/{detid}/{histogram}", optmap_fn)
    optmap_edges = tuple(b.edges for b in optmap_all.binning)
    optmap_weights = optmap_all.weights.nda.copy()
    if histogram_choice == "prob_unc_rel":
        optmap_err = lh5.read(f"/{detid}/prob_unc", optmap_fn)
        divmask = optmap_weights > 0
        optmap_weights[divmask] = optmap_err.weights.nda[divmask] / optmap_weights[divmask]
        optmap_weights[~divmask] = -1

    return optmap_edges, optmap_weights
110
+
111
+
112
def _prepare_data(
    optmap_fn: str,
    divide_fn: str | None = None,
    cmap_min: float | Literal["auto"] = 1e-4,
    cmap_max: float | Literal["auto"] = 1e-2,
    histogram_choice: str = "prob",
    detid: str = "all",
) -> tuple[tuple[NDArray], NDArray, float, float]:
    """Load map weights for plotting and determine the color-scale limits.

    Parameters
    ----------
    optmap_fn
        LH5 file containing the optical map.
    divide_fn
        optional second map file. If given, the weights are divided cell-wise by
        its weights; cells where the divisor is not positive are set to -1.
    cmap_min, cmap_max
        color-scale limits, or ``"auto"`` to derive them from the data.
    histogram_choice
        histogram to read, see :func:`_read_data`.
    detid
        detector name, see :func:`_read_data`.

    Returns
    -------
    bin edges, prepared weights, and the resolved lower and upper color limits.
    """
    optmap_edges, optmap_weights = _read_data(optmap_fn, detid, histogram_choice)

    if divide_fn is not None:
        _, divide_map = _read_data(divide_fn, detid, histogram_choice)
        divmask = divide_map > 0
        optmap_weights[divmask] = optmap_weights[divmask] / divide_map[divmask]
        optmap_weights[~divmask] = -1

    positive = optmap_weights[optmap_weights > 0]
    if cmap_min == "auto":
        # a map without any positive cell has no minimum to derive a limit from;
        # `.min()` on the empty selection would raise.
        cmap_min = max(1e-10, positive.min()) if positive.size > 0 else 1e-10
    if cmap_max == "auto":
        cmap_max = optmap_weights.max() if positive.size > 0 else cmap_min

    lower_count = np.sum((optmap_weights > 0) & (optmap_weights < cmap_min))
    if lower_count > 0:
        log.warning(
            "%d cells are non-zero and lower than the current colorbar minimum %.2e",
            lower_count,
            cmap_min,
        )
    higher_count = np.sum(optmap_weights > cmap_max)
    if higher_count > 0:
        log.warning(
            "%d cells are non-zero and higher than the current colorbar maximum %.2e",
            higher_count,
            cmap_max,
        )

    # set zero/close-to-zero values to a very small, but nonzero value strictly below
    # the colorbar minimum. This means we can style those cells using the `under`
    # style and can style `bad` (i.e. everything < 0 after this re-assignment) in a
    # different way. Dividing by 10 keeps the value below vmin even if cmap_min is
    # itself <= 1e-10 (a value equal to vmin would not be drawn as `under`).
    optmap_weights[(optmap_weights >= 0) & (optmap_weights < 1e-50)] = min(1e-10, cmap_min / 10)

    return optmap_edges, optmap_weights, cmap_min, cmap_max
154
+
155
+
156
def view_optmap(
    optmap_fn: str,
    detid: str = "all",
    divide_fn: str | None = None,
    start_axis: int = 2,
    cmap_min: float | Literal["auto"] = 1e-4,
    cmap_max: float | Literal["auto"] = 1e-2,
    histogram_choice: str = "prob",
    title: str | None = None,
) -> None:
    """Interactively display 2D slices of an optical map with matplotlib.

    Keyboard controls (see :func:`_process_key`): ``up``/``down`` change the
    slicing axis, ``left``/``right`` step through slices, and ``c`` opens a
    detector selector.

    Parameters
    ----------
    optmap_fn
        path of the LH5 file containing the optical map(s). This is a single
        path: it is passed to :func:`list_optical_maps` and used for the default
        title.
    detid
        detector map to show initially (``"all"`` for the combined map).
    divide_fn
        optional second map file; the shown weights are divided by its weights.
    start_axis
        initial slicing axis (0=x, 1=y, 2=z).
    cmap_min, cmap_max
        color-scale limits or ``"auto"``, see :func:`_prepare_data`.
    histogram_choice
        histogram to display, see :func:`_read_data`.
    title
        figure title. Defaults to the file stem(s) of ``optmap_fn`` (and ``divide_fn``).
    """
    available_dets = list_optical_maps(optmap_fn)

    prepare_args = (optmap_fn, divide_fn, cmap_min, cmap_max, histogram_choice)
    edges, weights, cmap_min, cmap_max = _prepare_data(*prepare_args, detid)

    fig = plt.figure(figsize=(10, 10))
    fig.canvas.mpl_connect("key_press_event", _process_key)
    # number of bins along the starting axis (edges have one more entry than bins).
    start_axis_len = edges[start_axis].shape[0] - 1
    # viewer state shared with the key and widget callbacks.
    fig.__reboost = {
        "axis": start_axis,
        "weights": weights,
        "detid": detid,
        "edges": edges,
        "idx": min(int(start_axis_len / 2), start_axis_len - 1),
        "available_dets": available_dets,
        "prepare_args": prepare_args,
    }

    # bad (< 0, no statistics): white; under (zero cells): gray; over: red.
    cmap = plt.cm.plasma.with_extremes(bad="w", under="gray", over="red")
    weights, extent, labels = _get_weights(fig.__reboost)
    plt.imshow(
        weights[fig.__reboost["idx"]],
        norm=colors.LogNorm(vmin=cmap_min, vmax=cmap_max),
        aspect=1,
        interpolation="none",
        cmap=cmap,
        extent=extent,
        origin="lower",
    )

    if title is None:
        title = Path(optmap_fn).stem
        if divide_fn is not None:
            title += " / " + Path(divide_fn).stem

    plt.suptitle(title)

    plt.xlabel(labels[0])
    plt.ylabel(labels[1])

    plt.text(0, 1.02, _slice_text(fig.__reboost), transform=fig.axes[0].transAxes)
    plt.colorbar()
    plt.show()
@@ -0,0 +1,26 @@
1
+ from __future__ import annotations
2
+
3
+ import importlib.util
4
+ import types
5
+
6
+ from numba import njit
7
+
8
+
9
def numba_pdgid_funcs():
    """Load a numba-optimized copy of the scikit-hep/particle package.

    A separate module object is created from the spec of
    ``particle.pdgid.functions`` and executed. Its ``_digit`` helper is replaced
    by a local implementation, and every plain Python function in the module's
    namespace is wrapped with ``njit(..., cache=True)``.

    Returns
    -------
    the patched module copy.

    Raises
    ------
    ModuleNotFoundError
        if ``particle.pdgid.functions`` cannot be found or has no loader.
    """
    spec_pdg = importlib.util.find_spec("particle.pdgid.functions")
    if spec_pdg is None or spec_pdg.loader is None:
        msg = "could not find module particle.pdgid.functions (is scikit-hep/particle installed?)"
        raise ModuleNotFoundError(msg)
    pdg_func = importlib.util.module_from_spec(spec_pdg)
    spec_pdg.loader.exec_module(pdg_func)

    # replacement for the module's `_digit` helper; NOTE(review): presumably kept
    # minimal so numba can compile it — confirm it matches the upstream semantics.
    def _digit2(pdgid, loc: int) -> int:
        e = 10 ** (loc - 1)
        return (pdgid // e % 10) if pdgid >= e else 0

    pdg_func._digit = _digit2

    # iterate over a snapshot, as the loop rebinds attributes of the same module.
    for fname, f in list(pdg_func.__dict__.items()):
        if not isinstance(f, types.FunctionType):
            continue
        setattr(pdg_func, fname, njit(f, cache=True))

    return pdg_func
@@ -0,0 +1,328 @@
1
+ from __future__ import annotations
2
+
3
+ import contextlib
4
+ import ctypes
5
+ import logging
6
+ import math
7
+ import multiprocessing as mp
8
+ from collections.abc import Mapping
9
+
10
+ import numpy as np
11
+ from lgdo import Histogram, Struct, lh5
12
+ from numpy.typing import NDArray
13
+
14
+ log = logging.getLogger(__name__)
15
+
16
+
17
+ class OpticalMap:
18
+ def __init__(self, name: str, settings: Mapping[str, str], use_shmem: bool = False):
19
+ self.settings = settings
20
+ self.name = name
21
+ self.use_shmem = use_shmem
22
+
23
+ self.h_vertex = None
24
+ self.h_hits = None
25
+ self.h_prob = None
26
+ self.h_prob_uncert = None
27
+
28
+ self.binning = None
29
+
30
+ self.__fill_hits_buf = None
31
+
32
+ if settings is not None:
33
+ self._single_shape = tuple(self.settings["bins"])
34
+ self._single_stride = None
35
+
36
+ binedge_attrs = {"units": "m"}
37
+ bins = self.settings["bins"]
38
+ bounds = self.settings["range_in_m"]
39
+ self.binning = [
40
+ Histogram.Axis(
41
+ None,
42
+ bounds[i][0],
43
+ bounds[i][1],
44
+ (bounds[i][1] - bounds[i][0]) / bins[i],
45
+ True,
46
+ binedge_attrs,
47
+ )
48
+ for i in range(3)
49
+ ]
50
+
51
+ @staticmethod
52
+ def create_empty(name: str, settings: Mapping[str, str]) -> OpticalMap:
53
+ om = OpticalMap(name, settings)
54
+ om.h_vertex = om._prepare_hist()
55
+ om.h_hits = om._prepare_hist()
56
+ return om
57
+
58
+ @staticmethod
59
+ def load_from_file(lh5_file: str, group: str) -> OpticalMap:
60
+ om = OpticalMap(group, None)
61
+
62
+ def read_hist(name: str, fn: str, group: str = "all"):
63
+ h = lh5.read(f"/{group}/{name}", lh5_file=fn)
64
+ if not isinstance(h, Histogram):
65
+ msg = f"encountered invalid optical map while reading /{group}/{name} in {fn}"
66
+ raise RuntimeError(msg)
67
+ return h.weights.nda, h.binning
68
+
69
+ om.h_vertex, bin_nr_gen = read_hist("_nr_gen", lh5_file, group=group)
70
+ om.h_hits, bin_nr_det = read_hist("_nr_det", lh5_file, group=group)
71
+ om.h_prob, bin_p_det = read_hist("prob", lh5_file, group=group)
72
+ om.h_prob_uncert, bin_p_det_err = read_hist("prob_unc", lh5_file, group=group)
73
+
74
+ for bins in (bin_nr_det, bin_p_det, bin_p_det_err):
75
+ if not OpticalMap._edges_eq(bin_nr_gen, bins):
76
+ msg = "edges of optical map histograms differ"
77
+ raise RuntimeError(msg)
78
+
79
+ om.binning = bin_nr_gen
80
+ return om
81
+
82
+ def _prepare_hist(self) -> np.ndarray:
83
+ """Prepare an empty histogram with the parameters global to this map instance."""
84
+ if self.use_shmem:
85
+ assert mp.current_process().name == "MainProcess"
86
+ a = self._mp_man.Array(ctypes.c_double, math.prod(self._single_shape))
87
+ nda = self._nda(a)
88
+ nda.fill(0)
89
+ else:
90
+ a = np.zeros(shape=self._single_shape, dtype=np.float64)
91
+ nda = a
92
+ stride = [s // nda.dtype.itemsize for s in nda.strides]
93
+ if self._single_stride is None:
94
+ self._single_stride = stride
95
+ assert self._single_stride == stride
96
+ return a
97
+
98
+ def _fill_histogram(
99
+ self,
100
+ h: NDArray | mp.sharedctypes.SynchronizedArray,
101
+ xyz: NDArray,
102
+ for_hits: bool = False,
103
+ ) -> None:
104
+ assert xyz.shape[1] == 3
105
+ xyz = xyz.T
106
+
107
+ # use as much pre-allocated memory as possible.
108
+ if self.__fill_hits_buf is None:
109
+ self.__fill_hits_buf = np.empty(5000, np.int64)
110
+ self.__fill_hits_pos = 0
111
+
112
+ idx = np.zeros(xyz.shape[1], np.int64) # bin indices for flattened array
113
+ oor_mask = np.ones(xyz.shape[1], np.bool_) # mask to remove out of range values
114
+ dims = range(xyz.shape[0])
115
+ for col, ax, s, dim in zip(xyz, self.binning, self._single_stride, dims, strict=True):
116
+ assert ax.is_range
117
+ assert ax.closedleft
118
+ oor_mask &= (ax.first <= col) & (col < ax.last)
119
+ idx_s = np.floor((col.astype(np.float64) - ax.first) / ax.step).astype(np.int64)
120
+ assert np.all(idx_s[oor_mask] < self._single_shape[dim])
121
+ idx += s * idx_s
122
+
123
+ idx = idx[oor_mask]
124
+ if idx.shape[0] == 0:
125
+ return
126
+
127
+ if for_hits and idx.shape[0] < self.__fill_hits_buf.shape[0]:
128
+ # special path for the typically small number of hits.
129
+ # this circumvents a memory leak in _fill_histogram_buf when called with varying and
130
+ # small shapes of the idx array.
131
+ end = self.__fill_hits_pos + idx.shape[0]
132
+ if end >= self.__fill_hits_buf.shape[0]:
133
+ # flush the old buffer to the map, as the new data does not fit.
134
+ self._fill_histogram_buf(h, self.__fill_hits_buf[0 : self.__fill_hits_pos])
135
+ self.__fill_hits_pos = 0
136
+ end = idx.shape[0]
137
+ self.__fill_hits_buf[self.__fill_hits_pos : end] = idx
138
+ self.__fill_hits_pos = end
139
+ else:
140
+ # here we assume a uniform size of idx, so that we do not hit the memory leak.
141
+ self._fill_histogram_buf(h, idx)
142
+
143
+ def _fill_histogram_buf(
144
+ self,
145
+ h: NDArray | mp.sharedctypes.SynchronizedArray,
146
+ idx: NDArray,
147
+ ) -> None:
148
+ # increment bin contents
149
+ with self._lock_nda(h)():
150
+ np.add.at(self._nda(h).reshape(-1), idx, 1)
151
+
152
+ def _nda(self, h: NDArray | mp.sharedctypes.SynchronizedArray) -> NDArray:
153
+ if not self.use_shmem:
154
+ return h
155
+ return np.ndarray(self._single_shape, dtype=np.float64, buffer=h.get_obj())
156
+
157
+ def _lock_nda(self, h: NDArray | mp.sharedctypes.SynchronizedArray):
158
+ if not self.use_shmem:
159
+ return contextlib.nullcontext
160
+ return h.get_lock
161
+
162
+ def _mp_preinit(self, mp_man: mp.context.BaseContext, vertex: bool) -> None:
163
+ self._mp_man = mp_man
164
+ if self.h_vertex is None and vertex:
165
+ self.h_vertex = self._prepare_hist()
166
+ if self.h_hits is None:
167
+ self.h_hits = self._prepare_hist()
168
+
169
+ def fill_vertex(self, loc: NDArray) -> None:
170
+ """Fill map with a chunk of hit coordinates."""
171
+ if self.h_vertex is None:
172
+ self.h_vertex = self._prepare_hist()
173
+ self._fill_histogram(self.h_vertex, loc)
174
+
175
+ def fill_hits(self, loc: NDArray) -> None:
176
+ """Fill map with a chunk of hit coordinates.
177
+
178
+ .. note::
179
+
180
+ For performance reasons, this function is buffered and does not
181
+ directly write to the map array. Use :meth:`.fill_hits_flush` to
182
+ flush the remaining hits in the buffer to this map.
183
+ """
184
+ if self.h_hits is None:
185
+ self.h_hits = self._prepare_hist()
186
+ self._fill_histogram(self.h_hits, loc, for_hits=True)
187
+
188
+ def fill_hits_flush(self) -> None:
189
+ """Commit all remaining hit coordinates in the buffer."""
190
+ if self.h_hits is None or self.__fill_hits_pos <= 0:
191
+ return
192
+ self._fill_histogram_buf(self.h_hits, self.__fill_hits_buf[0 : self.__fill_hits_pos])
193
+ self.__fill_hits_buf = None
194
+
195
+ def _divide_hist(self, h1: NDArray, h2: NDArray) -> tuple[NDArray, NDArray]:
196
+ """Calculate the ratio (and its standard error) from two histograms."""
197
+ h1 = self._nda(h1)
198
+ h2 = self._nda(h2)
199
+
200
+ ratio_0 = self._prepare_hist()
201
+ ratio_err_0 = self._prepare_hist()
202
+ ratio, ratio_err = self._nda(ratio_0), self._nda(ratio_err_0)
203
+
204
+ ratio[:] = np.divide(h1, h2, where=(h2 != 0))
205
+ ratio[h2 == 0] = -1 # -1 denotes no statistics.
206
+
207
+ if np.any(ratio > 1):
208
+ msg = "encountered cell(s) with more hits than primaries"
209
+ raise RuntimeError(msg)
210
+
211
+ # compute uncertainty according to Bernoulli statistics.
212
+ # TODO: this does not make sense for ratio==1
213
+ ratio_err[h2 != 0] = np.sqrt((ratio[h2 != 0]) * (1 - ratio[h2 != 0]) / h2[h2 != 0])
214
+ ratio_err[h2 == 0] = -1 # -1 denotes no statistics.
215
+
216
+ return ratio_0, ratio_err_0
217
+
218
+ def create_probability(self) -> None:
219
+ """Compute probability map (and map uncertainty) from vertex and hit map."""
220
+ self.h_prob, self.h_prob_uncert = self._divide_hist(self.h_hits, self.h_vertex)
221
+
222
+ def write_lh5(self, lh5_file: str, group: str = "all", wo_mode: str = "write_safe") -> None:
223
+ """Write this map to a LH5 file."""
224
+ if wo_mode not in ("write_safe", "overwrite_file"):
225
+ msg = f"invalid wo_mode {wo_mode} for optical map"
226
+ raise ValueError(msg)
227
+
228
+ def write_hist(h: NDArray, name: str, fn: str, group: str, wo_mode: str):
229
+ lh5.write(
230
+ Struct({name: Histogram(self._nda(h), self.binning)}),
231
+ group,
232
+ fn,
233
+ wo_mode=wo_mode,
234
+ )
235
+
236
+ # only use the passed wo_mode for the first file.
237
+ write_hist(self.h_vertex, "_nr_gen", lh5_file, group, wo_mode)
238
+ write_hist(self.h_hits, "_nr_det", lh5_file, group, "append_column")
239
+ write_hist(self.h_prob, "prob", lh5_file, group, "append_column")
240
+ write_hist(self.h_prob_uncert, "prob_unc", lh5_file, group, "append_column")
241
+
242
+ def get_settings(self) -> dict:
243
+ """Get the binning settings that were used to create this optical map instance."""
244
+ if self.settings is not None:
245
+ return self.settings
246
+
247
+ range_in_m = []
248
+ bins = []
249
+ for b in self.binning:
250
+ if not b.is_range:
251
+ msg = "cannot get binning settings for variable binning map"
252
+ raise RuntimeError(msg)
253
+ if b.get_binedgeattrs().get("units") != "m":
254
+ msg = "invalid units. can only work with optical maps in meter"
255
+ raise RuntimeError(msg)
256
+ range_in_m.append([b.first, b.last])
257
+ bins.append(b.nbins)
258
+
259
+ return {"range_in_m": np.array(range_in_m), "bins": np.array(bins)}
260
+
261
+ def check_histograms(self, include_prefix: bool = False) -> None:
262
+ log_prefix = "" if not include_prefix else self.name + " - "
263
+
264
+ def _warn(fmt: str, *args):
265
+ log.warning("%s" + fmt, log_prefix, *args) # noqa: G003
266
+
267
+ h_vertex = self._nda(self.h_vertex)
268
+ h_prob = self._nda(self.h_prob)
269
+ h_prob_uncert = self._nda(self.h_prob_uncert)
270
+
271
+ ncells = h_vertex.shape[0] * h_vertex.shape[1] * h_vertex.shape[2]
272
+
273
+ missing_v = np.sum(h_vertex <= 0) # bins without vertices.
274
+ if missing_v > 0:
275
+ _warn("%d missing_v %.2f %%", missing_v, missing_v / ncells * 100)
276
+
277
+ missing_p = np.sum(h_prob <= 0) # bins without hist.
278
+ if missing_p > 0:
279
+ _warn("%d missing_p %.2f %%", missing_p, missing_p / ncells * 100)
280
+
281
+ non_phys = np.sum(h_prob > 1) # non-physical events with probability > 1.
282
+ if non_phys > 0:
283
+ _warn(
284
+ "%d voxels (%.2f %%) with non-physical probability (p>1)",
285
+ non_phys,
286
+ non_phys / ncells * 100,
287
+ )
288
+
289
+ # warnings on insufficient statistics.
290
+ large_error = np.sum(h_prob_uncert > 0.01 * h_prob)
291
+ if large_error > 0:
292
+ _warn(
293
+ "%d voxels (%.2f %%) with large relative statistical uncertainty (> 1 %%)",
294
+ large_error,
295
+ large_error / ncells * 100,
296
+ )
297
+
298
+ primaries_low_stats_th = 100
299
+ low_stat_zero = np.sum((h_vertex < primaries_low_stats_th) & (h_prob == 0))
300
+ if low_stat_zero > 0:
301
+ _warn(
302
+ "%d voxels (%.2f %%) with non reliable probability estimate (p=0 and primaries < %d)",
303
+ low_stat_zero,
304
+ low_stat_zero / ncells * 100,
305
+ primaries_low_stats_th,
306
+ )
307
+ low_stat_one = np.sum((h_vertex < primaries_low_stats_th) & (h_prob == 1))
308
+ if low_stat_one > 0:
309
+ _warn(
310
+ "%d voxels (%.2f %%) with non reliable probability estimate (p=1 and primaries < %d)",
311
+ low_stat_one,
312
+ low_stat_one / ncells * 100,
313
+ primaries_low_stats_th,
314
+ )
315
+
316
+ @staticmethod
317
+ def _edges_eq(
318
+ e1: tuple[NDArray] | tuple[Histogram.Axis], e2: tuple[NDArray] | tuple[Histogram.Axis]
319
+ ) -> bool:
320
+ """Compare edge-tuples for two histograms."""
321
+ if isinstance(e1[0], Histogram.Axis):
322
+ e1 = tuple([b.edges for b in e1])
323
+ if isinstance(e2[0], Histogram.Axis):
324
+ e2 = tuple([b.edges for b in e2])
325
+ assert all(isinstance(b, np.ndarray) for b in e1)
326
+ assert all(isinstance(b, np.ndarray) for b in e2)
327
+
328
+ return all(np.all(x1 == x2) for x1, x2 in zip(e1, e2, strict=True))
reboost/profile.py ADDED
@@ -0,0 +1,82 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ import time
5
+
6
+ from dbetto import AttrsDict
7
+
8
+ log = logging.getLogger(__name__)
9
+
10
+
11
class ProfileDict(AttrsDict):
    """Accumulated wall-clock timings, optionally nested via "/"-separated keys."""

    def update_field(self, name: str, time_start: float) -> None:
        """Add the time elapsed since ``time_start`` to the entry ``name``.

        Parameters
        ----------
        name
            key of the entry. Each "/" descends into a sub-dictionary, which is
            created if missing. A missing leaf entry is initialized with the
            elapsed time; an existing one is incremented by it.
        time_start
            start time of the measured block, compared against :func:`time.time`.
        """
        *parents, leaf = name.split("/")

        elapsed = time.time() - time_start

        node = self
        for part in parents:
            if part not in node:
                node[part] = {}
            node = node[part]

        if leaf in node:
            node[leaf] = node[leaf] + elapsed
        else:
            node[leaf] = elapsed

    def __repr__(self):
        return "ProfileDict(" + f"{dict(self)}" + ")"

    def __str__(self):
        """Return a human-readable profiling summary."""
        header = "\nReboost post processing took: \n"
        return header + self._format(self, indent=1)

    def _format(self, data: ProfileDict, indent: int = 1) -> str:
        """Render ``data`` as an indented bullet list, recursing into sub-dictionaries.

        Parameters
        ----------
        data
            mapping of names to timings (or to nested mappings).
        indent
            number of leading spaces for this level; nested levels add two.

        Returns
        -------
        the formatted text. Float timings are rounded to one decimal, and values
        not above zero after rounding are shown as "< 0.1".
        """
        pad = " " * indent
        parts = []

        for key, value in data.items():
            if isinstance(value, dict):
                parts.append(f"{pad}- {key}:\n")
                parts.append(self._format(value, indent + 2))
                continue

            shown = round(value, 1) if isinstance(value, float) else value
            shown_str = f"{shown}".rjust(7) if shown > 0 else "< 0.1".rjust(7)
            parts.append(f"{pad}- {key}".ljust(25) + f": {shown_str} s\n")

        return "".join(parts)
File without changes