reboost 0.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
reboost/optmap/create.py ADDED
@@ -0,0 +1,423 @@
+ from __future__ import annotations
+
+ import copy
+ import gc
+ import logging
+ import multiprocessing as mp
+ from collections.abc import Callable
+ from pathlib import Path
+ from typing import Literal
+
+ import numpy as np
+ from lgdo import Histogram, lh5
+ from numba import njit
+ from numpy.typing import NDArray
+
+ from ..log_utils import setup_log
+ from .evt import (
+     generate_optmap_evt,
+     get_optical_detectors_from_geom,
+ )
+ from .optmap import OpticalMap
+
+ log = logging.getLogger(__name__)
+
+
+ def _optmaps_for_channels(
+     all_det_ids: dict[int, str],
+     settings,
+     chfilter: tuple[str | int] | Literal["*"] = (),
+     use_shmem: bool = False,
+ ):
+     if chfilter != "*":
+         chfilter = [str(ch) for ch in chfilter]  # normalize types
+         optmap_det_ids = {det: name for det, name in all_det_ids.items() if str(det) in chfilter}
+     else:
+         optmap_det_ids = all_det_ids
+
+     log.info("creating empty optmaps")
+     optmap_count = len(optmap_det_ids) + 1
+     optmaps = [
+         OpticalMap("all" if i == 0 else list(optmap_det_ids.values())[i - 1], settings, use_shmem)
+         for i in range(optmap_count)
+     ]
+
+     return all_det_ids, optmaps, optmap_det_ids
+
+
+ @njit(cache=True)
+ def _compute_hit_maps(hitcounts, optmap_count, ch_idx_to_optmap):
+     mask = np.zeros((hitcounts.shape[0], optmap_count), dtype=np.bool_)
+     counts = hitcounts.sum(axis=1)
+     for idx in range(hitcounts.shape[0]):
+         if counts[idx] == 0:
+             continue
+
+         for ch_idx in range(hitcounts.shape[1]):
+             c = hitcounts[idx, ch_idx]
+             if c > 0:  # detected
+                 mask[idx, 0] = True
+                 mask_idx = ch_idx_to_optmap[ch_idx]
+                 if mask_idx > 0:
+                     mask[idx, mask_idx] = True
+     return mask
+
+
+ def _fill_hit_maps(optmaps: list[OpticalMap], loc, hitcounts: NDArray, ch_idx_to_map_idx):
+     masks = _compute_hit_maps(hitcounts, len(optmaps), ch_idx_to_map_idx)
+
+     for i in range(len(optmaps)):
+         locm = loc[masks[:, i]]
+         optmaps[i].fill_hits(locm)
+
+
+ def _create_optical_maps_process_init(optmaps, log_level) -> None:
+     # need to use shared global state. passing the shared memory arrays via "normal" arguments to
+     # the worker function is not supported...
+     global _shared_optmaps  # noqa: PLW0603
+     _shared_optmaps = optmaps
+
+     # set up logging in the worker process.
+     setup_log(log_level, multiproc=True)
+
+
+ def _create_optical_maps_process(
+     optmap_events_fn, buffer_len, all_det_ids, ch_idx_to_map_idx
+ ) -> bool:
+     log.info("started worker task for %s", optmap_events_fn)
+     x = _create_optical_maps_chunk(
+         optmap_events_fn,
+         buffer_len,
+         all_det_ids,
+         _shared_optmaps,
+         ch_idx_to_map_idx,
+     )
+     log.info("finished worker task for %s", optmap_events_fn)
+     return x
+
+
+ def _create_optical_maps_chunk(
+     optmap_events_fn, buffer_len, all_det_ids, optmaps, ch_idx_to_map_idx
+ ) -> bool:
+     cols = [str(c) for c in all_det_ids]
+     optmap_events_it = generate_optmap_evt(optmap_events_fn, cols, buffer_len)
+
+     for it_count, events_lgdo in enumerate(optmap_events_it):
+         optmap_events = events_lgdo.view_as("pd")
+         hitcounts = optmap_events[cols].to_numpy()
+         loc = optmap_events[["xloc", "yloc", "zloc"]].to_numpy()
+
+         log.debug("filling vertex histogram (%d)", it_count)
+         optmaps[0].fill_vertex(loc)
+
+         log.debug("filling hits histogram (%d)", it_count)
+         _fill_hit_maps(optmaps, loc, hitcounts, ch_idx_to_map_idx)
+
+     # commit the final part of the hits to the maps.
+     for i in range(len(optmaps)):
+         optmaps[i].fill_hits_flush()
+     gc.collect()
+
+     return True
+
+
+ def create_optical_maps(
+     optmap_events_fn: list[str],
+     settings,
+     buffer_len: int = int(5e6),
+     chfilter: tuple[str | int] | Literal["*"] = (),
+     output_lh5_fn: str | None = None,
+     after_save: Callable[[int, str, OpticalMap]] | None = None,
+     check_after_create: bool = False,
+     n_procs: int | None = 1,
+     geom_fn: str | None = None,
+ ) -> None:
+     """Create optical maps.
+
+     Parameters
+     ----------
+     optmap_events_fn
+         list of paths to lh5 files, which can either be stp files from remage or "optmap-evt"
+         files with a table ``/optmap_evt`` with columns ``{x,y,z}loc`` and one column (with
+         numeric header) for each SiPM channel.
+     chfilter
+         tuple of detector ids that will be included in the resulting optmap. Those have to match
+         the column names in ``optmap_events_fn``.
+     n_procs
+         number of processors, ``1`` for sequential mode, or ``None`` to use all processors.
+     """
+     if len(optmap_events_fn) == 0:
+         msg = "no input files specified"
+         raise ValueError(msg)
+
+     use_shmem = n_procs is None or n_procs > 1
+
+     optmap_evt_columns = get_optical_detectors_from_geom(geom_fn)
+
+     all_det_ids, optmaps, optmap_det_ids = _optmaps_for_channels(
+         optmap_evt_columns, settings, chfilter=chfilter, use_shmem=use_shmem
+     )
+
+     # indices for later use in _compute_hit_maps.
+     ch_idx_to_map_idx = np.array(
+         [
+             list(optmap_det_ids.keys()).index(d) + 1 if d in optmap_det_ids else -1
+             for d in all_det_ids
+         ]
+     )
+     assert np.sum(ch_idx_to_map_idx > 0) == len(optmaps) - 1
+
+     log.info(
+         "creating optical map groups: %s",
+         ", ".join(["all", *[str(t) for t in optmap_det_ids.items()]]),
+     )
+
+     q = []
+
+     # sequential mode.
+     if not use_shmem:
+         for fn in optmap_events_fn:
+             q.append(
+                 _create_optical_maps_chunk(fn, buffer_len, all_det_ids, optmaps, ch_idx_to_map_idx)
+             )
+     else:
+         ctx = mp.get_context("forkserver")
+         for i in range(len(optmaps)):
+             optmaps[i]._mp_preinit(ctx, vertex=(i == 0))
+
+         # note: errors thrown in the initializer will make the main process hang in an endless loop.
+         # unfortunately, we cannot pass the objects later, as they contain shmem/array handles.
+         pool = ctx.Pool(
+             n_procs,
+             initializer=_create_optical_maps_process_init,
+             initargs=(optmaps, log.getEffectiveLevel()),
+             maxtasksperchild=1,  # re-create worker after each task, to avoid leaking memory.
+         )
+
+         pool_results = []
+         for fn in optmap_events_fn:
+             r = pool.apply_async(
+                 _create_optical_maps_process,
+                 args=(fn, buffer_len, all_det_ids, ch_idx_to_map_idx),
+             )
+             pool_results.append((r, fn))
+
+         pool.close()
+         for r, fn in pool_results:
+             try:
+                 q.append(r.get())
+             except BaseException as e:
+                 msg = f"error while processing file {fn}"
+                 raise RuntimeError(msg) from e  # re-throw errors of workers.
+         log.debug("got all worker results")
+         pool.join()
+         log.info("joined worker process pool")
+
+     if len(q) != len(optmap_events_fn):
+         log.error("got %d results for %d files", len(q), len(optmap_events_fn))
+
+     # all maps share the same vertex histogram.
+     for i in range(1, len(optmaps)):
+         optmaps[i].h_vertex = optmaps[0].h_vertex
+
+     log.info("computing probability and storing to %s", output_lh5_fn)
+     for i in range(len(optmaps)):
+         optmaps[i].create_probability()
+         if check_after_create:
+             optmaps[i].check_histograms()
+         group = "all" if i == 0 else "channels/" + list(optmap_det_ids.values())[i - 1]
+         if output_lh5_fn is not None:
+             optmaps[i].write_lh5(lh5_file=output_lh5_fn, group=group)
+
+         if after_save is not None:
+             after_save(i, group, optmaps[i])
+
+         optmaps[i] = None  # clear some memory.
+
+
+ def list_optical_maps(lh5_file: str) -> list[str]:
+     maps = list(lh5.ls(lh5_file, "/channels/"))
+     if "all" in lh5.ls(lh5_file):
+         maps.append("all")
+     return maps
+
+
+ def _merge_optical_maps_process(
+     d: str,
+     map_l5_files: list[str],
+     output_lh5_fn: str,
+     settings,
+     check_after_create: bool = False,
+     write_part_file: bool = False,
+ ) -> str:
+     log.info("merging optical map group: %s", d)
+     merged_map = OpticalMap.create_empty(d, settings)
+     merged_nr_gen = merged_map.h_vertex
+     merged_nr_det = merged_map.h_hits
+
+     all_edges = None
+     for optmap_fn in map_l5_files:
+         nr_det = lh5.read(f"/{d}/_nr_det", optmap_fn)
+         assert isinstance(nr_det, Histogram)
+         nr_gen = lh5.read(f"/{d}/_nr_gen", optmap_fn)
+         assert isinstance(nr_gen, Histogram)
+
+         assert OpticalMap._edges_eq(nr_det.binning, nr_gen.binning)
+         if all_edges is not None and not OpticalMap._edges_eq(nr_det.binning, all_edges):
+             msg = "edges of input optical maps differ"
+             raise ValueError(msg)
+         all_edges = nr_det.binning
+
+         # now that we have validated that the map dimensions are equal, add up the actual data
+         # (in counts).
+         merged_nr_det += nr_det.weights.nda
+         merged_nr_gen += nr_gen.weights.nda
+
+     merged_map.create_probability()
+     if check_after_create:
+         merged_map.check_histograms(include_prefix=True)
+
+     if write_part_file:
+         d_for_tmp = d.replace("/", "_")
+         output_lh5_fn = f"{output_lh5_fn}_{d_for_tmp}.mappart.lh5"
+     wo_mode = "overwrite_file" if write_part_file else "write_safe"
+     merged_map.write_lh5(lh5_file=output_lh5_fn, group=d, wo_mode=wo_mode)
+
+     return output_lh5_fn
+
+
+ def merge_optical_maps(
+     map_l5_files: list[str],
+     output_lh5_fn: str,
+     settings,
+     check_after_create: bool = False,
+     n_procs: int | None = 1,
+ ) -> None:
+     """Merge optical maps from multiple files.
+
+     Parameters
+     ----------
+     n_procs
+         number of processors, ``1`` for sequential mode, or ``None`` to use all processors.
+     """
+     # verify that we have the same maps in all files.
+     all_det_ntuples = None
+     for optmap_fn in map_l5_files:
+         det_ntuples = list_optical_maps(optmap_fn)
+         if all_det_ntuples is not None and det_ntuples != all_det_ntuples:
+             msg = "available optical maps in input files differ"
+             raise ValueError(msg)
+         all_det_ntuples = det_ntuples
+
+     log.info("merging optical map groups: %s", ", ".join(all_det_ntuples))
+
+     use_mp = (n_procs is None or n_procs > 1) and len(all_det_ntuples) > 1
+
+     if not use_mp:
+         # sequential mode: merge maps one-by-one.
+         for d in all_det_ntuples:
+             _merge_optical_maps_process(
+                 d, map_l5_files, output_lh5_fn, settings, check_after_create, use_mp
+             )
+     else:
+         ctx = mp.get_context("forkserver")
+
+         # note: errors thrown in the initializer will make the main process hang in an endless loop.
+         pool = ctx.Pool(
+             n_procs,
+             initializer=_create_optical_maps_process_init,
+             initargs=(None, log.getEffectiveLevel()),
+             maxtasksperchild=1,  # re-create worker after each task, to avoid leaking memory.
+         )
+
+         pool_results = []
+
+         # merge maps in workers.
+         for d in all_det_ntuples:
+             r = pool.apply_async(
+                 _merge_optical_maps_process,
+                 args=(d, map_l5_files, output_lh5_fn, settings, check_after_create, use_mp),
+             )
+             pool_results.append((r, d))
+
+         pool.close()
+         q = []
+         for r, d in pool_results:
+             try:
+                 q.append((d, r.get()))
+             except BaseException as e:
+                 msg = f"error while processing map {d}"
+                 raise RuntimeError(msg) from e  # re-throw errors of workers.
+
+         log.debug("got all worker results")
+         pool.join()
+         log.info("joined worker process pool")
+
+         # transfer to the actual output file.
+         for d, part_fn in q:
+             assert isinstance(part_fn, str)
+             for h_name in ("_nr_det", "_nr_gen", "prob", "prob_unc"):
+                 obj = f"/{d}/{h_name}"
+                 log.info("transfer %s from %s", obj, part_fn)
+                 h = lh5.read(obj, part_fn)
+                 assert isinstance(h, Histogram)
+                 lh5.write(h, obj, output_lh5_fn, wo_mode="write_safe")
+             Path(part_fn).unlink()
+
+
+ def check_optical_map(map_l5_file: str):
+     """Run a health check on the map file.
+
+     This checks for consistency, and outputs details on map statistics.
+     """
+     if "_hitcounts_exp" in lh5.ls(map_l5_file):
+         log.error("found _hitcounts_exp which is not supported any more")
+         return
+
+     all_binning = None
+     for submap in list_optical_maps(map_l5_file):
+         try:
+             om = OpticalMap.load_from_file(map_l5_file, submap)
+         except Exception:
+             log.exception("error while loading optical map %s", submap)
+             continue
+         om.check_histograms(include_prefix=True)
+
+         if all_binning is not None and not OpticalMap._edges_eq(om.binning, all_binning):
+             log.error("edges of optical map %s differ", submap)
+         else:
+             all_binning = om.binning
+
+
+ def rebin_optical_maps(map_l5_file: str, output_lh5_file: str, factor: int):
+     """Rebin the optical map by an integral factor.
+
+     .. note::
+
+         the factor has to divide the bin counts on all axes.
+     """
+     if not isinstance(factor, int) or factor <= 1:
+         msg = f"invalid rebin factor {factor}"
+         raise ValueError(msg)
+
+     def _rebin_map(large: NDArray, factor: int) -> NDArray:
+         factor = np.full(3, factor, dtype=int)
+         sh = np.column_stack([np.array(large.shape) // factor, factor]).ravel()
+         return large.reshape(sh).sum(axis=(1, 3, 5))
+
+     for submap in list_optical_maps(map_l5_file):
+         log.info("rebinning optical map group: %s", submap)
+
+         om = OpticalMap.load_from_file(map_l5_file, submap)
+
+         settings = om.get_settings()
+         if not all(b % factor == 0 for b in settings["bins"]):
+             msg = f"invalid factor {factor}, not a divisor"
+             raise ValueError(msg)
+         settings = copy.copy(settings)
+         settings["bins"] = [b // factor for b in settings["bins"]]
+
+         om_new = OpticalMap.create_empty(om.name, settings)
+         om_new.h_vertex = _rebin_map(om.h_vertex, factor)
+         om_new.h_hits = _rebin_map(om.h_hits, factor)
+         om_new.create_probability()
+         om_new.write_lh5(lh5_file=output_lh5_file, group=submap, wo_mode="write_safe")
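
Taken together, the functions above form a small pipeline: create_optical_maps fills one hit/vertex histogram pair per selected channel (plus an "all" map), merge_optical_maps combines maps produced by independent jobs, and rebin_optical_maps coarsens a finished map. The following is a minimal usage sketch, not part of the package: the file names are hypothetical, the module path follows the inferred create.py location above, and the settings layout is an assumption (only the "bins" key is confirmed by the rebinning code).

# Hedged usage sketch -- not part of the released wheel.
from reboost.optmap.create import (
    create_optical_maps,
    merge_optical_maps,
    rebin_optical_maps,
)

settings = {
    "range_in_m": [[0, 1], [0, 1], [0, 3]],  # assumed key for the map extent
    "bins": [100, 100, 300],  # grounded: rebin_optical_maps reads settings["bins"]
}

# build per-channel maps (plus the "all" map) from two simulation files,
# using all available processors via the forkserver pool shown above.
create_optical_maps(
    ["sim-run0.lh5", "sim-run1.lh5"],
    settings,
    chfilter="*",
    output_lh5_fn="optmap-run01.lh5",
    n_procs=None,
    geom_fn="geometry.gdml",
)

# merge maps produced by independent campaigns, then coarsen by a factor of 2
# (the factor must divide the bin count on every axis).
merge_optical_maps(["optmap-run01.lh5", "optmap-run23.lh5"], "optmap-all.lh5", settings)
rebin_optical_maps("optmap-all.lh5", "optmap-all-coarse.lh5", factor=2)
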
reboost/optmap/evt.py ADDED
@@ -0,0 +1,141 @@
+ from __future__ import annotations
+
+ import logging
+ from collections import OrderedDict
+ from collections.abc import Generator, Iterable
+ from pathlib import Path
+
+ import numpy as np
+ from lgdo import lh5
+ from lgdo.lh5 import LH5Iterator
+ from lgdo.types import Table
+
+ log = logging.getLogger(__name__)
+
+ EVT_TABLE_NAME = "optmap_evt"
+
+
+ def generate_optmap_evt(
+     lh5_in_file: str, detectors: Iterable[str | int], buffer_len: int = int(5e6)
+ ) -> Generator[Table, None, None]:
+     """Create a faster map for lookup of the hits in each detector, for each primary event."""
+     log.info("reading file %s", lh5_in_file)
+
+     vert_it = LH5Iterator(lh5_in_file, "vtx", buffer_len=buffer_len)
+     opti_it = LH5Iterator(lh5_in_file, "stp/optical", buffer_len=buffer_len)
+
+     if len(detectors) == 0:
+         msg = "detector array cannot be empty for optmap-evt building"
+         raise ValueError(msg)
+     detectors = [str(d) for d in detectors]
+     for d in detectors:
+         if not d.isnumeric():
+             log.warning("Detector ID %s is not numeric.", d)
+
+     vert_df = None
+     vert_df_bounds = None
+     hits_expected = 0
+     had_last_chunk = False
+
+     def _store_vert_df(last_chunk: bool) -> Generator[Table, None, None]:
+         nonlocal vert_df, had_last_chunk
+         if vert_df is None:
+             return
+
+         # sanity check that we did process all hits.
+         hits_sum = 0
+         for d in detectors:
+             hits_sum += np.sum(vert_df[d])
+         assert hits_sum == hits_expected
+
+         yield Table(vert_df)
+         had_last_chunk = last_chunk
+         vert_df = None
+
+     # helper function for a "windowed join". while iterating the optical hits, we have to
+     # make sure that we always have the correct combined vertex/hit output table available.
+     #
+     # This function follows the assumption that the output event ids are at least "somewhat"
+     # monotonic, i.e. later chunks do not contain lower evtids than the previous chunk(s).
+     # Going back is not implemented.
+     def _ensure_vert_df(vert_it: LH5Iterator, evtid: int) -> Generator[Table, None, None]:
+         nonlocal vert_df, vert_df_bounds, hits_expected
+
+         # skipping multiple chunks is possible in sparsely populated simulations.
+         while vert_df_bounds is None or evtid > vert_df_bounds[1] or evtid < vert_df_bounds[0]:
+             if vert_df_bounds is not None and vert_df is not None:
+                 if evtid < vert_df_bounds[0]:
+                     msg = "non-monotonic evtid encountered, but cannot go back"
+                     raise KeyError(msg)
+                 if evtid >= vert_df_bounds[0] and evtid <= vert_df_bounds[1]:
+                     return  # vert_df already contains the given evtid.
+
+             # here, evtid > vert_df_bounds[1] (or vert_df_bounds is still None). We need to fetch
+             # the next event table chunk.
+
+             # we might have filled a dataframe, save it to disk.
+             yield from _store_vert_df(last_chunk=False)
+
+             # read the next vertex chunk into memory.
+             vert_df = next(vert_it).view_as("pd")
+
+             # prepare vertex coordinates.
+             vert_df = vert_df.set_index("evtid", drop=True).drop(["n_part", "time"], axis=1)
+             vert_df_bounds = [vert_df.index.min(), vert_df.index.max()]
+             hits_expected = 0
+             # add columns for all detectors.
+             for d in detectors:
+                 vert_df[d] = hit_count_type(0)
+
+     log.info("prepare evt table")
+     # use the smaller integer type uint8 to spare RAM when storing hit counts.
+     hit_count_type = np.uint8
+     for opti_it_count, opti_lgdo in enumerate(opti_it):
+         opti_df = opti_lgdo.view_as("pd")
+
+         log.info("build evt table (%d)", opti_it_count)
+
+         for t in opti_df[["evtid", "det_uid"]].itertuples(name=None, index=False):
+             yield from _ensure_vert_df(vert_it, t[0])
+             vert_df.loc[t[0], str(t[1])] += 1
+             hits_expected += 1
+
+     yield from _store_vert_df(last_chunk=True)  # store the last chunk.
+
+     assert had_last_chunk, "did not reach last chunk in optmap-evt building"
+
+
+ def build_optmap_evt(
+     lh5_in_file: str, lh5_out_file: str, detectors: Iterable[str | int], buffer_len: int = int(5e6)
+ ) -> None:
+     """Create a faster map for lookup of the hits in each detector, for each primary event."""
+     lh5_out_file = Path(lh5_out_file)
+     lh5_out_file_tmp = lh5_out_file.with_stem(".evt-tmp." + lh5_out_file.stem)
+     if lh5_out_file_tmp.exists():
+         msg = f"temporary output file {lh5_out_file_tmp} already exists"
+         raise RuntimeError(msg)
+
+     for vert_it_count, chunk in enumerate(generate_optmap_evt(lh5_in_file, detectors, buffer_len)):
+         log.info("store evt file %s (%d)", lh5_out_file_tmp, vert_it_count)
+         lh5.write(Table(chunk), name=EVT_TABLE_NAME, lh5_file=lh5_out_file_tmp, wo_mode="append")
+
+     # after finishing the output file, rename it to the actual output file name.
+     if lh5_out_file.exists():
+         msg = f"output file {lh5_out_file} already exists after writing tmp output file"
+         raise RuntimeError(msg)
+     lh5_out_file_tmp.rename(lh5_out_file)
+
+
+ def get_optical_detectors_from_geom(geom_fn) -> dict[int, str]:
+     import pyg4ometry
+     import pygeomtools
+
+     geom_registry = pyg4ometry.gdml.Reader(geom_fn).getRegistry()
+     detectors = pygeomtools.get_all_sensvols(geom_registry)
+     return OrderedDict(
+         [(d.uid, name) for name, d in detectors.items() if d.detector_type == "optical"]
+     )
+
+
+ def read_optmap_evt(lh5_file: str, buffer_len: int = int(5e6)) -> LH5Iterator:
+     return LH5Iterator(lh5_file, EVT_TABLE_NAME, buffer_len=buffer_len)
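
As a closing illustration, a sketch of how the evt-building helpers in this file might be driven. This is a hedged example, not part of the released wheel: the file names and detector uids are hypothetical, and the input file must contain the "vtx" and "stp/optical" tables that generate_optmap_evt iterates.

# Hedged usage sketch -- not part of the released wheel.
from reboost.optmap.evt import build_optmap_evt, read_optmap_evt

detectors = [1001, 1002, 1003]  # numeric SiPM det_uid values (hypothetical)

# write the per-event hit-count table to /optmap_evt, via a temporary file
# that is renamed only after all chunks were stored.
build_optmap_evt("sim-run0.lh5", "evt-run0.lh5", detectors)

# stream the table back in chunks with the default buffer of 5e6 rows.
for chunk in read_optmap_evt("evt-run0.lh5"):
    df = chunk.view_as("pd")
    print(df[["xloc", "yloc", "zloc"]].head())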