reboost 0.6.2__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
reboost/optmap/create.py CHANGED
@@ -4,21 +4,19 @@ import copy
4
4
  import gc
5
5
  import logging
6
6
  import multiprocessing as mp
7
+ from collections.abc import Callable
7
8
  from pathlib import Path
8
- from typing import Callable, Literal
9
+ from typing import Literal
9
10
 
10
11
  import numpy as np
11
- import scipy.optimize
12
- from lgdo import Array, Histogram, Scalar, lh5
12
+ from lgdo import Histogram, lh5
13
13
  from numba import njit
14
14
  from numpy.typing import NDArray
15
15
 
16
16
  from ..log_utils import setup_log
17
17
  from .evt import (
18
- EVT_TABLE_NAME,
19
18
  generate_optmap_evt,
20
19
  get_optical_detectors_from_geom,
21
- read_optmap_evt,
22
20
  )
23
21
  from .optmap import OpticalMap
24
22
 
@@ -26,23 +24,21 @@ log = logging.getLogger(__name__)
26
24
 
27
25
 
28
26
  def _optmaps_for_channels(
29
- optmap_evt_columns: list[str],
27
+ all_det_ids: dict[int, str],
30
28
  settings,
31
29
  chfilter: tuple[str | int] | Literal["*"] = (),
32
30
  use_shmem: bool = False,
33
31
  ):
34
- all_det_ids = [ch_id for ch_id in optmap_evt_columns if ch_id.isnumeric()]
35
-
36
32
  if chfilter != "*":
37
33
  chfilter = [str(ch) for ch in chfilter] # normalize types
38
- optmap_det_ids = [det for det in all_det_ids if str(det) in chfilter]
34
+ optmap_det_ids = {det: name for det, name in all_det_ids.items() if str(det) in chfilter}
39
35
  else:
40
36
  optmap_det_ids = all_det_ids
41
37
 
42
38
  log.info("creating empty optmaps")
43
39
  optmap_count = len(optmap_det_ids) + 1
44
40
  optmaps = [
45
- OpticalMap("all" if i == 0 else optmap_det_ids[i - 1], settings, use_shmem)
41
+ OpticalMap("all" if i == 0 else list(optmap_det_ids.values())[i - 1], settings, use_shmem)
46
42
  for i in range(optmap_count)
47
43
  ]
48
44
 
@@ -75,34 +71,6 @@ def _fill_hit_maps(optmaps: list[OpticalMap], loc, hitcounts: NDArray, ch_idx_to
75
71
  optmaps[i].fill_hits(locm)
76
72
 
77
73
 
78
- def _count_multi_ph_detection(hitcounts) -> NDArray:
79
- hits_per_primary = hitcounts.sum(axis=1)
80
- bins = np.arange(0, hits_per_primary.max() + 1.5) - 0.5
81
- return np.histogram(hits_per_primary, bins)[0]
82
-
83
-
84
- def _fit_multi_ph_detection(hits_per_primary) -> float:
85
- if len(hits_per_primary) <= 2: # have only 0 and 1 hits, can't fit (and also don't need to).
86
- return np.inf
87
-
88
- x = np.arange(0, len(hits_per_primary))
89
- popt, pcov = scipy.optimize.curve_fit(
90
- lambda x, p0, k: p0 * np.exp(-k * x), x[1:], hits_per_primary[1:]
91
- )
92
- best_fit_exponent = popt[1]
93
-
94
- log.info(
95
- "p(> 1 detected photon)/p(1 detected photon) = %f",
96
- sum(hits_per_primary[2:]) / hits_per_primary[1],
97
- )
98
- log.info(
99
- "p(> 1 detected photon)/p(<=1 detected photon) = %f",
100
- sum(hits_per_primary[2:]) / sum(hits_per_primary[0:2]),
101
- )
102
-
103
- return best_fit_exponent
104
-
105
-
106
74
  def _create_optical_maps_process_init(optmaps, log_level) -> None:
107
75
  # need to use shared global state. passing the shared memory arrays via "normal" arguments to
108
76
  # the worker function is not supported...
@@ -114,34 +82,29 @@ def _create_optical_maps_process_init(optmaps, log_level) -> None:
114
82
 
115
83
 
116
84
  def _create_optical_maps_process(
117
- optmap_events_fn, buffer_len, is_stp_file, all_det_ids, ch_idx_to_map_idx
118
- ) -> None:
85
+ optmap_events_fn, buffer_len, all_det_ids, ch_idx_to_map_idx
86
+ ) -> bool:
119
87
  log.info("started worker task for %s", optmap_events_fn)
120
88
  x = _create_optical_maps_chunk(
121
89
  optmap_events_fn,
122
90
  buffer_len,
123
- is_stp_file,
124
91
  all_det_ids,
125
92
  _shared_optmaps,
126
93
  ch_idx_to_map_idx,
127
94
  )
128
95
  log.info("finished worker task for %s", optmap_events_fn)
129
- return tuple(int(i) for i in x)
96
+ return x
130
97
 
131
98
 
132
99
  def _create_optical_maps_chunk(
133
- optmap_events_fn, buffer_len, is_stp_file, all_det_ids, optmaps, ch_idx_to_map_idx
134
- ) -> None:
135
- if not is_stp_file:
136
- optmap_events_it = read_optmap_evt(optmap_events_fn, buffer_len)
137
- else:
138
- optmap_events_it = generate_optmap_evt(optmap_events_fn, all_det_ids, buffer_len)
100
+ optmap_events_fn, buffer_len, all_det_ids, optmaps, ch_idx_to_map_idx
101
+ ) -> bool:
102
+ cols = [str(c) for c in all_det_ids]
103
+ optmap_events_it = generate_optmap_evt(optmap_events_fn, cols, buffer_len)
139
104
 
140
- hits_per_primary = np.zeros(10, dtype=np.int64)
141
- hits_per_primary_len = 0
142
105
  for it_count, events_lgdo in enumerate(optmap_events_it):
143
106
  optmap_events = events_lgdo.view_as("pd")
144
- hitcounts = optmap_events[all_det_ids].to_numpy()
107
+ hitcounts = optmap_events[cols].to_numpy()
145
108
  loc = optmap_events[["xloc", "yloc", "zloc"]].to_numpy()
146
109
 
147
110
  log.debug("filling vertex histogram (%d)", it_count)
@@ -149,23 +112,19 @@ def _create_optical_maps_chunk(
149
112
 
150
113
  log.debug("filling hits histogram (%d)", it_count)
151
114
  _fill_hit_maps(optmaps, loc, hitcounts, ch_idx_to_map_idx)
152
- hpp = _count_multi_ph_detection(hitcounts)
153
- hits_per_primary_len = max(hits_per_primary_len, len(hpp))
154
- hits_per_primary[0 : len(hpp)] += hpp
155
115
 
156
116
  # commit the final part of the hits to the maps.
157
117
  for i in range(len(optmaps)):
158
118
  optmaps[i].fill_hits_flush()
159
119
  gc.collect()
160
120
 
161
- return hits_per_primary[0:hits_per_primary_len]
121
+ return True
162
122
 
163
123
 
164
124
  def create_optical_maps(
165
125
  optmap_events_fn: list[str],
166
126
  settings,
167
127
  buffer_len: int = int(5e6),
168
- is_stp_file: bool = True,
169
128
  chfilter: tuple[str | int] | Literal["*"] = (),
170
129
  output_lh5_fn: str | None = None,
171
130
  after_save: Callable[[int, str, OpticalMap]] | None = None,
@@ -181,8 +140,6 @@ def create_optical_maps(
181
140
  list of filenames to lh5 files, that can either be stp files from remage or "optmap-evt"
182
141
  files with a table ``/optmap_evt`` with columns ``{x,y,z}loc`` and one column (with numeric
183
142
  header) for each SiPM channel.
184
- is_stp_file
185
- if true, do convert a remage output file (stp file) on-the-fly to an optmap-evt file.
186
143
  chfilter
187
144
  tuple of detector ids that will be included in the resulting optmap. Those have to match
188
145
  the column names in ``optmap_events_fn``.
@@ -195,12 +152,7 @@ def create_optical_maps(
195
152
 
196
153
  use_shmem = n_procs is None or n_procs > 1
197
154
 
198
- if not is_stp_file:
199
- optmap_evt_columns = list(
200
- lh5.read(EVT_TABLE_NAME, optmap_events_fn[0], start_row=0, n_rows=1).keys()
201
- ) # peek into the (first) file to find column names.
202
- else:
203
- optmap_evt_columns = [str(i) for i in get_optical_detectors_from_geom(geom_fn)]
155
+ optmap_evt_columns = get_optical_detectors_from_geom(geom_fn)
204
156
 
205
157
  all_det_ids, optmaps, optmap_det_ids = _optmaps_for_channels(
206
158
  optmap_evt_columns, settings, chfilter=chfilter, use_shmem=use_shmem
@@ -208,11 +160,17 @@ def create_optical_maps(
208
160
 
209
161
  # indices for later use in _compute_hit_maps.
210
162
  ch_idx_to_map_idx = np.array(
211
- [optmap_det_ids.index(d) + 1 if d in optmap_det_ids else -1 for d in all_det_ids]
163
+ [
164
+ list(optmap_det_ids.keys()).index(d) + 1 if d in optmap_det_ids else -1
165
+ for d in all_det_ids
166
+ ]
212
167
  )
213
168
  assert np.sum(ch_idx_to_map_idx > 0) == len(optmaps) - 1
214
169
 
215
- log.info("creating optical map groups: %s", ", ".join(["all", *optmap_det_ids]))
170
+ log.info(
171
+ "creating optical map groups: %s",
172
+ ", ".join(["all", *[str(t) for t in optmap_det_ids.items()]]),
173
+ )
216
174
 
217
175
  q = []
218
176
 
@@ -220,9 +178,7 @@ def create_optical_maps(
220
178
  if not use_shmem:
221
179
  for fn in optmap_events_fn:
222
180
  q.append(
223
- _create_optical_maps_chunk(
224
- fn, buffer_len, is_stp_file, all_det_ids, optmaps, ch_idx_to_map_idx
225
- )
181
+ _create_optical_maps_chunk(fn, buffer_len, all_det_ids, optmaps, ch_idx_to_map_idx)
226
182
  )
227
183
  else:
228
184
  ctx = mp.get_context("forkserver")
@@ -242,14 +198,14 @@ def create_optical_maps(
242
198
  for fn in optmap_events_fn:
243
199
  r = pool.apply_async(
244
200
  _create_optical_maps_process,
245
- args=(fn, buffer_len, is_stp_file, all_det_ids, ch_idx_to_map_idx),
201
+ args=(fn, buffer_len, all_det_ids, ch_idx_to_map_idx),
246
202
  )
247
203
  pool_results.append((r, fn))
248
204
 
249
205
  pool.close()
250
206
  for r, fn in pool_results:
251
207
  try:
252
- q.append(np.array(r.get()))
208
+ q.append(r.get())
253
209
  except BaseException as e:
254
210
  msg = f"error while processing file {fn}"
255
211
  raise RuntimeError(msg) from e # re-throw errors of workers.
@@ -257,17 +213,8 @@ def create_optical_maps(
257
213
  pool.join()
258
214
  log.info("joined worker process pool")
259
215
 
260
- # merge hitcounts.
261
216
  if len(q) != len(optmap_events_fn):
262
217
  log.error("got %d results for %d files", len(q), len(optmap_events_fn))
263
- hits_per_primary = np.zeros(10, dtype=np.int64)
264
- hits_per_primary_len = 0
265
- for hitcounts in q:
266
- hits_per_primary[0 : len(hitcounts)] += hitcounts
267
- hits_per_primary_len = max(hits_per_primary_len, len(hitcounts))
268
-
269
- hits_per_primary = hits_per_primary[0:hits_per_primary_len]
270
- hits_per_primary_exponent = _fit_multi_ph_detection(hits_per_primary)
271
218
 
272
219
  # all maps share the same vertex histogram.
273
220
  for i in range(1, len(optmaps)):
@@ -278,7 +225,7 @@ def create_optical_maps(
278
225
  optmaps[i].create_probability()
279
226
  if check_after_create:
280
227
  optmaps[i].check_histograms()
281
- group = "all" if i == 0 else "_" + optmap_det_ids[i - 1]
228
+ group = "all" if i == 0 else "channels/" + list(optmap_det_ids.values())[i - 1]
282
229
  if output_lh5_fn is not None:
283
230
  optmaps[i].write_lh5(lh5_file=output_lh5_fn, group=group)
284
231
 
@@ -287,14 +234,12 @@ def create_optical_maps(
287
234
 
288
235
  optmaps[i] = None # clear some memory.
289
236
 
290
- if output_lh5_fn is not None:
291
- lh5.write(Array(hits_per_primary), "_hitcounts", lh5_file=output_lh5_fn)
292
- lh5.write(Scalar(hits_per_primary_exponent), "_hitcounts_exp", lh5_file=output_lh5_fn)
293
-
294
237
 
295
238
  def list_optical_maps(lh5_file: str) -> list[str]:
296
- maps = lh5.ls(lh5_file)
297
- return [m for m in maps if m not in ("_hitcounts", "_hitcounts_exp")]
239
+ maps = list(lh5.ls(lh5_file, "/channels/"))
240
+ if "all" in lh5.ls(lh5_file):
241
+ maps.append("all")
242
+ return maps
298
243
 
299
244
 
300
245
  def _merge_optical_maps_process(
@@ -312,9 +257,9 @@ def _merge_optical_maps_process(
312
257
 
313
258
  all_edges = None
314
259
  for optmap_fn in map_l5_files:
315
- nr_det = lh5.read(f"/{d}/nr_det", optmap_fn)
260
+ nr_det = lh5.read(f"/{d}/_nr_det", optmap_fn)
316
261
  assert isinstance(nr_det, Histogram)
317
- nr_gen = lh5.read(f"/{d}/nr_gen", optmap_fn)
262
+ nr_gen = lh5.read(f"/{d}/_nr_gen", optmap_fn)
318
263
  assert isinstance(nr_gen, Histogram)
319
264
 
320
265
  assert OpticalMap._edges_eq(nr_det.binning, nr_gen.binning)
@@ -332,7 +277,8 @@ def _merge_optical_maps_process(
332
277
  merged_map.check_histograms(include_prefix=True)
333
278
 
334
279
  if write_part_file:
335
- output_lh5_fn = f"{output_lh5_fn}_{d}.mappart.lh5"
280
+ d_for_tmp = d.replace("/", "_")
281
+ output_lh5_fn = f"{output_lh5_fn}_{d_for_tmp}.mappart.lh5"
336
282
  wo_mode = "overwrite_file" if write_part_file else "write_safe"
337
283
  merged_map.write_lh5(lh5_file=output_lh5_fn, group=d, wo_mode=wo_mode)
338
284
 
@@ -409,7 +355,7 @@ def merge_optical_maps(
409
355
  # transfer to actual output file.
410
356
  for d, part_fn in q:
411
357
  assert isinstance(part_fn, str)
412
- for h_name in ("nr_det", "nr_gen", "p_det", "p_det_err"):
358
+ for h_name in ("_nr_det", "_nr_gen", "prob", "prob_unc"):
413
359
  obj = f"/{d}/{h_name}"
414
360
  log.info("transfer %s from %s", obj, part_fn)
415
361
  h = lh5.read(obj, part_fn)
@@ -417,43 +363,19 @@ def merge_optical_maps(
417
363
  lh5.write(h, obj, output_lh5_fn, wo_mode="write_safe")
418
364
  Path(part_fn).unlink()
419
365
 
420
- # merge hitcounts.
421
- hits_per_primary = np.zeros(10, dtype=np.int64)
422
- hits_per_primary_len = 0
423
- for optmap_fn in map_l5_files:
424
- if "_hitcounts" not in lh5.ls(optmap_fn):
425
- log.warning("skipping _hitcounts calculations, missing in file %s", optmap_fn)
426
- return
427
- hitcounts = lh5.read("/_hitcounts", optmap_fn)
428
- assert isinstance(hitcounts, Array)
429
- hits_per_primary[0 : len(hitcounts)] += hitcounts
430
- hits_per_primary_len = max(hits_per_primary_len, len(hitcounts))
431
-
432
- hits_per_primary = hits_per_primary[0:hits_per_primary_len]
433
- lh5.write(Array(hits_per_primary), "_hitcounts", lh5_file=output_lh5_fn)
434
-
435
- # re-calculate hitcounts exponent.
436
- hits_per_primary_exponent = _fit_multi_ph_detection(hits_per_primary)
437
- lh5.write(Scalar(hits_per_primary_exponent), "_hitcounts_exp", lh5_file=output_lh5_fn)
438
-
439
366
 
440
367
  def check_optical_map(map_l5_file: str):
441
368
  """Run a health check on the map file.
442
369
 
443
370
  This checks for consistency, and outputs details on map statistics.
444
371
  """
445
- if "_hitcounts_exp" not in lh5.ls(map_l5_file):
446
- log.info("no _hitcounts_exp found")
447
- elif lh5.read("_hitcounts_exp", lh5_file=map_l5_file).value != np.inf:
372
+ if (
373
+ "_hitcounts_exp" in lh5.ls(map_l5_file)
374
+ and lh5.read("_hitcounts_exp", lh5_file=map_l5_file).value != np.inf
375
+ ):
448
376
  log.error("unexpected _hitcounts_exp not equal to positive infinity")
449
377
  return
450
378
 
451
- if "_hitcounts" not in lh5.ls(map_l5_file):
452
- log.info("no _hitcounts found")
453
- elif lh5.read("_hitcounts", lh5_file=map_l5_file).nda.shape != (2,):
454
- log.error("unexpected _hitcounts shape")
455
- return
456
-
457
379
  all_binning = None
458
380
  for submap in list_optical_maps(map_l5_file):
459
381
  try:
@@ -502,8 +424,3 @@ def rebin_optical_maps(map_l5_file: str, output_lh5_file: str, factor: int):
502
424
  om_new.h_hits = _rebin_map(om.h_hits, factor)
503
425
  om_new.create_probability()
504
426
  om_new.write_lh5(lh5_file=output_lh5_file, group=submap, wo_mode="write_safe")
505
-
506
- # just copy hitcounts exponent.
507
- for dset in ("_hitcounts_exp", "_hitcounts"):
508
- if dset in lh5.ls(map_l5_file):
509
- lh5.write(lh5.read(dset, lh5_file=map_l5_file), dset, lh5_file=output_lh5_file)
reboost/optmap/evt.py CHANGED
@@ -1,6 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import logging
4
+ from collections import OrderedDict
4
5
  from collections.abc import Generator, Iterable
5
6
  from pathlib import Path
6
7
 
@@ -125,13 +126,15 @@ def build_optmap_evt(
125
126
  lh5_out_file_tmp.rename(lh5_out_file)
126
127
 
127
128
 
128
- def get_optical_detectors_from_geom(geom_fn) -> list[int]:
129
+ def get_optical_detectors_from_geom(geom_fn) -> dict[int, str]:
129
130
  import pyg4ometry
130
131
  import pygeomtools
131
132
 
132
133
  geom_registry = pyg4ometry.gdml.Reader(geom_fn).getRegistry()
133
134
  detectors = pygeomtools.get_all_sensvols(geom_registry)
134
- return [d.uid for d in detectors.values() if d.detector_type == "optical"]
135
+ return OrderedDict(
136
+ [(d.uid, name) for name, d in detectors.items() if d.detector_type == "optical"]
137
+ )
135
138
 
136
139
 
137
140
  def read_optmap_evt(lh5_file: str, buffer_len: int = int(5e6)) -> LH5Iterator:
reboost/optmap/mapview.py CHANGED
@@ -92,14 +92,16 @@ def _channel_selector(fig) -> None:
92
92
  def _read_data(
93
93
  optmap_fn: str,
94
94
  detid: str = "all",
95
- histogram_choice: str = "p_det",
95
+ histogram_choice: str = "prob",
96
96
  ) -> tuple[tuple[NDArray], NDArray]:
97
- histogram = histogram_choice if histogram_choice != "p_det_err_rel" else "p_det"
97
+ histogram = histogram_choice if histogram_choice != "prob_unc_rel" else "prob"
98
+ detid = f"channels/{detid}" if detid != all and not detid.startswith("channels/") else detid
99
+
98
100
  optmap_all = lh5.read(f"/{detid}/{histogram}", optmap_fn)
99
101
  optmap_edges = tuple([b.edges for b in optmap_all.binning])
100
102
  optmap_weights = optmap_all.weights.nda.copy()
101
- if histogram_choice == "p_det_err_rel":
102
- optmap_err = lh5.read(f"/{detid}/p_det_err", optmap_fn)
103
+ if histogram_choice == "prob_unc_rel":
104
+ optmap_err = lh5.read(f"/{detid}/prob_unc", optmap_fn)
103
105
  divmask = optmap_weights > 0
104
106
  optmap_weights[divmask] = optmap_err.weights.nda[divmask] / optmap_weights[divmask]
105
107
  optmap_weights[~divmask] = -1
@@ -112,13 +114,13 @@ def _prepare_data(
112
114
  divide_fn: str | None = None,
113
115
  cmap_min: float | Literal["auto"] = 1e-4,
114
116
  cmap_max: float | Literal["auto"] = 1e-2,
115
- histogram_choice: str = "p_det",
117
+ histogram_choice: str = "prob",
116
118
  detid: str = "all",
117
119
  ) -> tuple[tuple[NDArray], NDArray]:
118
120
  optmap_edges, optmap_weights = _read_data(optmap_fn, detid, histogram_choice)
119
121
 
120
122
  if divide_fn is not None:
121
- divide_edges, divide_map = _read_data(divide_fn, detid, histogram_choice)
123
+ _, divide_map = _read_data(divide_fn, detid, histogram_choice)
122
124
  divmask = divide_map > 0
123
125
  optmap_weights[divmask] = optmap_weights[divmask] / divide_map[divmask]
124
126
  optmap_weights[~divmask] = -1
@@ -158,7 +160,7 @@ def view_optmap(
158
160
  start_axis: int = 2,
159
161
  cmap_min: float | Literal["auto"] = 1e-4,
160
162
  cmap_max: float | Literal["auto"] = 1e-2,
161
- histogram_choice: str = "p_det",
163
+ histogram_choice: str = "prob",
162
164
  title: str | None = None,
163
165
  ) -> None:
164
166
  available_dets = list_optical_maps(optmap_fn)
reboost/optmap/optmap.py CHANGED
@@ -8,7 +8,7 @@ import multiprocessing as mp
8
8
  from collections.abc import Mapping
9
9
 
10
10
  import numpy as np
11
- from lgdo import Histogram, lh5
11
+ from lgdo import Histogram, Struct, lh5
12
12
  from numpy.typing import NDArray
13
13
 
14
14
  log = logging.getLogger(__name__)
@@ -66,10 +66,10 @@ class OpticalMap:
66
66
  raise RuntimeError(msg)
67
67
  return h.weights.nda, h.binning
68
68
 
69
- om.h_vertex, bin_nr_gen = read_hist("nr_gen", lh5_file, group=group)
70
- om.h_hits, bin_nr_det = read_hist("nr_det", lh5_file, group=group)
71
- om.h_prob, bin_p_det = read_hist("p_det", lh5_file, group=group)
72
- om.h_prob_uncert, bin_p_det_err = read_hist("p_det_err", lh5_file, group=group)
69
+ om.h_vertex, bin_nr_gen = read_hist("_nr_gen", lh5_file, group=group)
70
+ om.h_hits, bin_nr_det = read_hist("_nr_det", lh5_file, group=group)
71
+ om.h_prob, bin_p_det = read_hist("prob", lh5_file, group=group)
72
+ om.h_prob_uncert, bin_p_det_err = read_hist("prob_unc", lh5_file, group=group)
73
73
 
74
74
  for bins in (bin_nr_det, bin_p_det, bin_p_det_err):
75
75
  if not OpticalMap._edges_eq(bin_nr_gen, bins):
@@ -112,7 +112,7 @@ class OpticalMap:
112
112
  idx = np.zeros(xyz.shape[1], np.int64) # bin indices for flattened array
113
113
  oor_mask = np.ones(xyz.shape[1], np.bool_) # mask to remove out of range values
114
114
  dims = range(xyz.shape[0])
115
- for col, ax, s, dim in zip(xyz, self.binning, self._single_stride, dims):
115
+ for col, ax, s, dim in zip(xyz, self.binning, self._single_stride, dims, strict=True):
116
116
  assert ax.is_range
117
117
  assert ax.closedleft
118
118
  oor_mask &= (ax.first <= col) & (col < ax.last)
@@ -227,18 +227,17 @@ class OpticalMap:
227
227
 
228
228
  def write_hist(h: NDArray, name: str, fn: str, group: str, wo_mode: str):
229
229
  lh5.write(
230
- Histogram(self._nda(h), self.binning),
231
- name,
230
+ Struct({name: Histogram(self._nda(h), self.binning)}),
231
+ group,
232
232
  fn,
233
- group=group,
234
233
  wo_mode=wo_mode,
235
234
  )
236
235
 
237
236
  # only use the passed wo_mode for the first file.
238
- write_hist(self.h_vertex, "nr_gen", lh5_file, group, wo_mode)
239
- write_hist(self.h_hits, "nr_det", lh5_file, group, "write_safe")
240
- write_hist(self.h_prob, "p_det", lh5_file, group, "write_safe")
241
- write_hist(self.h_prob_uncert, "p_det_err", lh5_file, group, "write_safe")
237
+ write_hist(self.h_vertex, "_nr_gen", lh5_file, group, wo_mode)
238
+ write_hist(self.h_hits, "_nr_det", lh5_file, group, "append_column")
239
+ write_hist(self.h_prob, "prob", lh5_file, group, "append_column")
240
+ write_hist(self.h_prob_uncert, "prob_unc", lh5_file, group, "append_column")
242
241
 
243
242
  def get_settings(self) -> dict:
244
243
  """Get the binning settings that were used to create this optical map instance."""
@@ -326,4 +325,4 @@ class OpticalMap:
326
325
  assert all(isinstance(b, np.ndarray) for b in e1)
327
326
  assert all(isinstance(b, np.ndarray) for b in e2)
328
327
 
329
- return len(e1) == len(e2) and all(np.all(x1 == x2) for x1, x2 in zip(e1, e2))
328
+ return all(np.all(x1 == x2) for x1, x2 in zip(e1, e2, strict=True))
reboost/shape/cluster.py CHANGED
@@ -97,9 +97,9 @@ def cluster_by_step_length(
97
97
 
98
98
  pos = np.vstack(
99
99
  [
100
- ak.flatten(pos_x).to_numpy(),
101
- ak.flatten(pos_y).to_numpy(),
102
- ak.flatten(pos_z).to_numpy(),
100
+ ak.flatten(pos_x).to_numpy().astype(np.float64),
101
+ ak.flatten(pos_y).to_numpy().astype(np.float64),
102
+ ak.flatten(pos_z).to_numpy().astype(np.float64),
103
103
  ]
104
104
  ).T
105
105
 
@@ -164,7 +164,7 @@ def cluster_by_distance_numba(
164
164
  return np.sqrt(np.sum((a - b) ** 2))
165
165
 
166
166
  n = len(local_index)
167
- out = np.zeros(n, dtype=numba.int32)
167
+ out = np.zeros((n,), dtype=numba.int32)
168
168
 
169
169
  trackid_prev = -1
170
170
  pos_prev = np.zeros(3, dtype=numba.float64)
@@ -0,0 +1,5 @@
1
+ from __future__ import annotations
2
+
3
+ from .pe import detected_photoelectrons, emitted_scintillation_photons, load_optmap
4
+
5
+ __all__ = ["detected_photoelectrons", "emitted_scintillation_photons", "load_optmap"]