reboost 0.6.2__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,17 +2,17 @@ from __future__ import annotations
2
2
 
3
3
  import logging
4
4
  import re
5
+ from typing import NamedTuple
5
6
 
7
+ import awkward as ak
6
8
  import legendoptics.scintillate as sc
7
9
  import numba
8
10
  import numpy as np
9
- import pint
10
- from legendoptics import lar
11
+ from legendoptics import fibers, lar, pen
11
12
  from lgdo import lh5
12
- from lgdo.lh5 import LH5Iterator
13
13
  from lgdo.types import Array, Histogram, Table
14
- from numba import njit, prange
15
- from numpy.lib.recfunctions import structured_to_unstructured
14
+ from numba import njit
15
+ from numpy.typing import NDArray
16
16
 
17
17
  from .numba_pdg import numba_pdgid_funcs
18
18
 
@@ -23,22 +23,31 @@ OPTMAP_ANY_CH = -1
23
23
  OPTMAP_SUM_CH = -2
24
24
 
25
25
 
26
- def open_optmap(optmap_fn: str):
26
class OptmapForConvolve(NamedTuple):
    """A loaded optical map, prepared for convolving energy depositions.

    The field order ``(detids, detidx, edges, weights)`` is positional and must be
    kept, as instances are built and unpacked positionally.
    """

    # detector uids of the channels held by this map.
    detids: NDArray
    # position of each channel along the first axis of ``weights``.
    detidx: NDArray
    # per-axis bin edges of the map histogram: a tuple with one edge array per
    # spatial axis (not a single array).
    edges: tuple[NDArray, ...]
    # detection probabilities; the first axis selects the channel, the remaining
    # axes are the map bins. Negative entries mark bins without statistics. Maps
    # from ``open_optmap`` carry extra slots addressed with negative indices
    # (e.g. ``OPTMAP_ANY_CH``).
    weights: NDArray
33
+
34
+
35
+ def open_optmap(optmap_fn: str) -> OptmapForConvolve:
27
36
  maps = lh5.ls(optmap_fn)
28
37
  # only accept _<number> (/all is read separately)
29
38
  det_ntuples = [m for m in maps if re.match(r"_\d+$", m)]
30
39
  detids = np.array([int(m.lstrip("_")) for m in det_ntuples])
31
40
  detidx = np.arange(0, detids.shape[0])
32
41
 
33
- optmap_all = lh5.read("/all/p_det", optmap_fn)
42
+ optmap_all = lh5.read("/all/prob", optmap_fn)
34
43
  assert isinstance(optmap_all, Histogram)
35
44
  optmap_edges = tuple([b.edges for b in optmap_all.binning])
36
45
 
37
46
  ow = np.empty((detidx.shape[0] + 2, *optmap_all.weights.nda.shape), dtype=np.float64)
38
47
  # 0, ..., len(detidx)-1 AND OPTMAP_ANY_CH might be negative.
39
48
  ow[OPTMAP_ANY_CH] = optmap_all.weights.nda
40
- for i, nt in zip(detidx, det_ntuples):
41
- optmap = lh5.read(f"/{nt}/p_det", optmap_fn)
49
+ for i, nt in zip(detidx, det_ntuples, strict=True):
50
+ optmap = lh5.read(f"/{nt}/prob", optmap_fn)
42
51
  assert isinstance(optmap, Histogram)
43
52
  ow[i] = optmap.weights.nda
44
53
 
@@ -69,38 +78,67 @@ def open_optmap(optmap_fn: str):
69
78
  if np.isfinite(optmap_multi_det_exp):
70
79
  msg = f"found finite _hitcounts_exp {optmap_multi_det_exp} which is not supported any more"
71
80
  raise RuntimeError(msg)
72
- except KeyError: # the _hitcounts_exp might not be always present.
81
+ except lh5.exceptions.LH5DecodeError: # the _hitcounts_exp might not be always present.
73
82
  pass
74
83
 
75
- return detids, detidx, optmap_edges, ow
84
+ return OptmapForConvolve(detids, detidx, optmap_edges, ow)
76
85
 
77
86
 
78
- def iterate_stepwise_depositions(
79
- edep_df: np.rec.recarray,
80
- optmap_for_convolve,
87
def open_optmap_single(optmap_fn: str, spm_det_uid: int) -> OptmapForConvolve:
    """Load the optical map of one detector channel for convolving.

    Parameters
    ----------
    optmap_fn
        path of the LH5 optical map file.
    spm_det_uid
        detector uid; its histogram is read from ``/_<uid>/prob``.

    Returns
    -------
    an :class:`OptmapForConvolve` that holds only this channel, at index 0.

    Raises
    ------
    RuntimeError
        if the file stores a finite ``_hitcounts_exp`` (no longer supported).

    Note: a missing ``/_<uid>/prob`` is not caught here; presumably ``lh5.read``
    raises ``lh5.exceptions.LH5DecodeError`` for it, as for ``_hitcounts_exp``.
    """
    try:
        # check the exponent from the optical map file
        hitcounts_exp = lh5.read("/_hitcounts_exp", optmap_fn).value
    except lh5.exceptions.LH5DecodeError:
        # the _hitcounts_exp might not be always present.
        pass
    else:
        assert isinstance(hitcounts_exp, float)
        if np.isfinite(hitcounts_exp):
            msg = f"found finite _hitcounts_exp {hitcounts_exp} which is not supported any more"
            raise RuntimeError(msg)

    histogram = lh5.read(f"/_{spm_det_uid}/prob", optmap_fn)
    assert isinstance(histogram, Histogram)

    edges = tuple(axis.edges for axis in histogram.binning)
    weights = np.empty((1, *histogram.weights.nda.shape), dtype=np.float64)
    weights[0] = histogram.weights.nda

    return OptmapForConvolve(np.array([spm_det_uid]), np.array([0]), edges, weights)
105
+
106
+
107
+ def iterate_stepwise_depositions_pois(
108
+ edep_hits: ak.Array,
109
+ optmap: OptmapForConvolve,
81
110
  scint_mat_params: sc.ComputedScintParams,
82
- rng: np.random.Generator = None,
83
- dist: str = "poisson",
84
- mode: str = "no-fano",
111
+ det_uid: int,
112
+ map_scaling: float = 1,
113
+ map_scaling_sigma: float = 0,
114
+ rng: np.random.Generator | None = None,
85
115
  ):
86
- # those np functions are not supported by numba, but needed for efficient array access below.
87
- if "xloc_pre" in edep_df.dtype.names:
88
- x0 = structured_to_unstructured(edep_df[["xloc_pre", "yloc_pre", "zloc_pre"]], np.float64)
89
- x1 = structured_to_unstructured(
90
- edep_df[["xloc_post", "yloc_post", "zloc_post"]], np.float64
91
- )
92
- else:
93
- x0 = structured_to_unstructured(edep_df[["xloc", "yloc", "zloc"]], np.float64)
94
- x1 = None
116
+ if edep_hits.particle.ndim == 1:
117
+ msg = "the pe processors only support already reshaped output"
118
+ raise ValueError(msg)
95
119
 
96
120
  rng = np.random.default_rng() if rng is None else rng
97
- output_map, res = _iterate_stepwise_depositions(
98
- edep_df, x0, x1, rng, *optmap_for_convolve, scint_mat_params, dist, mode
121
+ res, output_list = _iterate_stepwise_depositions_pois(
122
+ edep_hits,
123
+ rng,
124
+ np.where(optmap.detids == det_uid)[0][0],
125
+ map_scaling,
126
+ map_scaling_sigma,
127
+ optmap.edges,
128
+ optmap.weights,
129
+ scint_mat_params,
99
130
  )
100
- if res["any_no_stats"] > 0 or res["det_no_stats"] > 0:
131
+
132
+ # convert the numba result back into an awkward array.
133
+ builder = ak.ArrayBuilder()
134
+ for r in output_list:
135
+ with builder.list():
136
+ for a in r:
137
+ builder.extend(a)
138
+
139
+ if res["det_no_stats"] > 0:
101
140
  log.warning(
102
- "had edep out in voxels without stats: %d (%.2f%%)",
103
- res["any_no_stats"],
141
+ "had edep out in voxels without stats: %d",
104
142
  res["det_no_stats"],
105
143
  )
106
144
  if res["oob"] > 0:
@@ -110,14 +148,34 @@ def iterate_stepwise_depositions(
110
148
  (res["oob"] / (res["ib"] + res["oob"])) * 100,
111
149
  )
112
150
  log.debug(
113
- "VUV_primary %d ->hits_any %d ->hits %d (%.2f %% primaries detected)",
151
+ "VUV_primary %d ->hits %d (%.2f %% primaries detected in this channel)",
114
152
  res["vuv_primary"],
115
- res["hits_any"],
116
153
  res["hits"],
117
- (res["hits_any"] / res["vuv_primary"]) * 100,
154
+ (res["hits"] / res["vuv_primary"]) * 100,
118
155
  )
119
- log.debug("hits/hits_any %.2f", res["hits"] / res["hits_any"])
120
- return output_map
156
+ return builder.snapshot()
157
+
158
+
159
def iterate_stepwise_depositions_scintillate(
    edep_hits: ak.Array,
    scint_mat_params: sc.ComputedScintParams,
    rng: np.random.Generator | None = None,
    mode: str = "no-fano",
):
    """Sample the number of scintillation photons for every step of every hit.

    Parameters
    ----------
    edep_hits
        hits with per-step (jagged) fields; they must already be reshaped, i.e.
        ``particle`` must not be a flat array.
    scint_mat_params
        precomputed scintillation parameters of the material.
    rng
        random generator; a fresh default generator is created when omitted.
    mode
        ``"no-fano"`` selects the ``"poisson"`` emission-term model; any other
        value selects ``"normal_fano"``.

    Returns
    -------
    jagged array with one photon count per step, grouped per hit.

    Raises
    ------
    ValueError
        if ``edep_hits`` is not reshaped into hits.
    """
    if edep_hits.particle.ndim == 1:
        msg = "the pe processors only support already reshaped output"
        raise ValueError(msg)

    if rng is None:
        rng = np.random.default_rng()

    per_hit_counts = _iterate_stepwise_depositions_scintillate(
        edep_hits, rng, scint_mat_params, mode
    )

    # the numba kernel returns nested lists; rebuild them as an awkward array.
    builder = ak.ArrayBuilder()
    for step_counts in per_hit_counts:
        with builder.list():
            builder.extend(step_counts)
    return builder.snapshot()
121
179
 
122
180
 
123
181
  _pdg_func = numba_pdgid_funcs()
@@ -144,178 +202,116 @@ __counts_per_bin_key_type = numba.types.UniTuple(numba.types.int64, 3)
144
202
  # - cache=True does not work with outer prange, i.e. loading the cached file fails (numba bug?)
145
203
  # - the output dictionary is not threadsafe, so parallel=True is not working with it.
146
204
  @njit(parallel=False, nogil=True, cache=True)
147
- def _iterate_stepwise_depositions(
148
- edep_df,
149
- x0,
150
- x1,
205
+ def _iterate_stepwise_depositions_pois(
206
+ edep_hits,
151
207
  rng,
152
- detids,
153
- detidx,
208
+ detidx: int,
209
+ map_scaling: float,
210
+ map_scaling_sigma: float,
154
211
  optmap_edges,
155
212
  optmap_weights,
156
213
  scint_mat_params: sc.ComputedScintParams,
157
- dist: str,
158
- mode: str,
159
214
  ):
160
215
  pdgid_map = {}
161
- output_map = {}
162
- oob = ib = ph_cnt = ph_det = ph_det2 = any_no_stats = det_no_stats = 0 # for statistics
163
- for rowid in prange(edep_df.shape[0]):
164
- # if rowid % 100000 == 0:
165
- # print(rowid)
166
- t = edep_df[rowid]
167
-
168
- # get the particle information.
169
- if t.particle not in pdgid_map:
170
- pdgid_map[t.particle] = (_pdgid_to_particle(t.particle), _pdg_func.charge(t.particle))
171
-
172
- # do the scintillation.
173
- part, charge = pdgid_map[t.particle]
174
-
175
- # if we have both pre and post step points use them
176
- # else pass as None
177
-
178
- scint_times = sc.scintillate(
179
- scint_mat_params,
180
- x0[rowid],
181
- x1[rowid] if x1 is not None else None,
182
- t.v_pre if x1 is not None else None,
183
- t.v_post if x1 is not None else None,
184
- t.time,
185
- part,
186
- charge,
187
- t.edep,
188
- rng,
189
- emission_term_model=("poisson" if mode == "no-fano" else "normal_fano"),
190
- )
191
- if scint_times.shape[0] == 0: # short-circuit if we have no photons at all.
192
- continue
193
- ph_cnt += scint_times.shape[0]
194
-
195
- # coordinates -> bins of the optical map.
196
- bins = np.empty((scint_times.shape[0], 3), dtype=np.int64)
197
- for j in range(3):
198
- bins[:, j] = np.digitize(scint_times[:, j + 1], optmap_edges[j])
199
- # normalize all out-of-bounds bins just to one end.
200
- bins[:, j][bins[:, j] == optmap_edges[j].shape[0]] = 0
201
-
202
- # there are _much_ less unique bins, unfortunately np.unique(..., axis=n) does not work
203
- # with numba; also np.sort(..., axis=n) also does not work.
204
-
205
- counts_per_bin = numba.typed.Dict.empty(
206
- key_type=__counts_per_bin_key_type,
207
- value_type=np.int64,
208
- )
216
+ oob = ib = ph_cnt = ph_det2 = det_no_stats = 0 # for statistics
217
+ output_list = []
218
+
219
+ for rowid in range(len(edep_hits)): # iterate hits
220
+ hit = edep_hits[rowid]
221
+ hit_output = []
222
+
223
+ map_scaling_evt = map_scaling
224
+ if map_scaling_sigma > 0:
225
+ map_scaling_evt = rng.normal(loc=map_scaling, scale=map_scaling_sigma)
226
+
227
+ assert len(hit.particle) == len(hit.num_scint_ph)
228
+ # iterate steps inside the hit
229
+ for si in range(len(hit.particle)):
230
+ loc = np.array([hit.xloc[si], hit.yloc[si], hit.zloc[si]])
231
+ # coordinates -> bins of the optical map.
232
+ bins = np.empty(3, dtype=np.int64)
233
+ for j in range(3):
234
+ bins[j] = np.digitize(loc[j], optmap_edges[j])
235
+ # normalize all out-of-bounds bins just to one end.
236
+ if bins[j] == optmap_edges[j].shape[0]:
237
+ bins[j] = 0
209
238
 
210
- # get probabilities from map.
211
- hitcount = np.zeros((detidx.shape[0], bins.shape[0]), dtype=np.int64)
212
- for j in prange(bins.shape[0]):
213
239
  # note: subtract 1 from bins, to account for np.digitize output.
214
- cur_bins = (bins[j, 0] - 1, bins[j, 1] - 1, bins[j, 2] - 1)
240
+ cur_bins = (bins[0] - 1, bins[1] - 1, bins[2] - 1)
215
241
  if cur_bins[0] == -1 or cur_bins[1] == -1 or cur_bins[2] == -1:
216
242
  oob += 1
217
243
  continue # out-of-bounds of optmap
218
244
  ib += 1
219
245
 
220
- px_any = optmap_weights[OPTMAP_ANY_CH, cur_bins[0], cur_bins[1], cur_bins[2]]
221
- if px_any < 0.0:
222
- any_no_stats += 1
223
- continue
224
- if px_any == 0.0:
246
+ # get probabilities from map.
247
+ detp = optmap_weights[detidx, cur_bins[0], cur_bins[1], cur_bins[2]] * map_scaling_evt
248
+ if detp < 0.0:
249
+ det_no_stats += 1
225
250
  continue
226
251
 
227
- if dist == "multinomial":
228
- if rng.uniform() >= px_any:
229
- continue
230
- ph_det += 1
231
- # we detect this energy deposition; we should at least get one photon out here!
232
-
233
- detsel_size = 1
234
-
235
- px_sum = optmap_weights[OPTMAP_SUM_CH, cur_bins[0], cur_bins[1], cur_bins[2]]
236
- assert px_sum >= 0.0 # should not be negative.
237
- detp = np.empty(detidx.shape, dtype=np.float64)
238
- had_det_no_stats = 0
239
- for d in detidx:
240
- # normalize so that sum(detp) = 1
241
- detp[d] = optmap_weights[d, cur_bins[0], cur_bins[1], cur_bins[2]] / px_sum
242
- if detp[d] < 0.0:
243
- had_det_no_stats = 1
244
- detp[d] = 0.0
245
- det_no_stats += had_det_no_stats
246
-
247
- # should be equivalent to rng.choice(detidx, size=detsel_size, p=detp)
248
- detsel = detidx[
249
- np.searchsorted(np.cumsum(detp), rng.random(size=(detsel_size,)), side="right")
250
- ]
251
- for d in detsel:
252
- hitcount[d, j] += 1
253
- ph_det2 += detsel.shape[0]
254
-
255
- elif dist == "poisson":
256
- # store the photon count in each bin, to sample them all at once below.
257
- if cur_bins not in counts_per_bin:
258
- counts_per_bin[cur_bins] = 1
259
- else:
260
- counts_per_bin[cur_bins] += 1
261
-
262
- else:
263
- msg = "unknown distribution"
264
- raise RuntimeError(msg)
265
-
266
- if dist == "poisson":
267
- for j, (cur_bins, ph_counts_to_poisson) in enumerate(counts_per_bin.items()):
268
- had_det_no_stats = 0
269
- had_any = 0
270
- for d in detidx:
271
- detp = optmap_weights[d, cur_bins[0], cur_bins[1], cur_bins[2]]
272
- if detp < 0.0:
273
- had_det_no_stats = 1
274
- continue
275
- pois_cnt = rng.poisson(lam=ph_counts_to_poisson * detp)
276
- hitcount[d, j] += pois_cnt
277
- ph_det2 += pois_cnt
278
- had_any = 1
279
- ph_det += had_any
280
- det_no_stats += had_det_no_stats
281
-
282
- assert scint_times.shape[0] >= hitcount.shape[1] # TODO: use the right assertion here.
283
- out_hits_len = np.sum(hitcount)
284
- if out_hits_len > 0:
285
- out_times = np.empty(out_hits_len, dtype=np.float64)
286
- out_det = np.empty(out_hits_len, dtype=np.int64)
287
- out_idx = 0
288
- for d in detidx:
289
- hc_d_plane_max = np.max(hitcount[d, :])
290
- # untangle the hitcount array in "planes" that only contain the given number of hits per
291
- # channel. example: assume a "histogram" of hits per channel:
292
- # x | | <-- this is plane 2 with 1 hit ("max plane")
293
- # x | | x <-- this is plane 1 with 2 hits
294
- # ch: 1 | 2 | 3
295
- for hc_d_plane_cnt in range(1, hc_d_plane_max + 1):
296
- hc_d_plane = hitcount[d, :] >= hc_d_plane_cnt
297
- hc_d_plane_len = np.sum(hc_d_plane)
298
- if hc_d_plane_len == 0:
299
- continue
300
-
301
- # note: we assume "immediate" propagation after scintillation. Here, a single timestamp
302
- # might be coipied to output/"detected" twice.
303
- out_times[out_idx : out_idx + hc_d_plane_len] = scint_times[hc_d_plane, 0]
304
- out_det[out_idx : out_idx + hc_d_plane_len] = detids[d]
305
- out_idx += hc_d_plane_len
306
- assert out_idx == out_hits_len # ensure that all of out_{det,times} is filled.
307
- output_map[np.int64(rowid)] = (t.evtid, out_det, out_times)
252
+ pois_cnt = rng.poisson(lam=hit.num_scint_ph[si] * detp)
253
+ ph_cnt += hit.num_scint_ph[si]
254
+ ph_det2 += pois_cnt
255
+
256
+ # get the particle information.
257
+ particle = hit.particle[si]
258
+ if particle not in pdgid_map:
259
+ pdgid_map[particle] = (_pdgid_to_particle(particle), _pdg_func.charge(particle))
260
+ part, _charge = pdgid_map[particle]
261
+
262
+ # get time spectrum.
263
+ # note: we assume "immediate" propagation after scintillation.
264
+ scint_times = sc.scintillate_times(scint_mat_params, part, pois_cnt, rng) + hit.time[si]
265
+
266
+ hit_output.append(scint_times)
267
+
268
+ output_list.append(hit_output)
308
269
 
309
270
  stats = {
310
271
  "oob": oob,
311
272
  "ib": ib,
312
273
  "vuv_primary": ph_cnt,
313
- "hits_any": ph_det,
314
274
  "hits": ph_det2,
315
- "any_no_stats": any_no_stats,
316
275
  "det_no_stats": det_no_stats,
317
276
  }
318
- return output_map, stats
277
+ return stats, output_list
278
+
279
+
280
# - run with NUMBA_FULL_TRACEBACKS=1 NUMBA_BOUNDSCHECK=1 for testing/checking
# - cache=True does not work with outer prange, i.e. loading the cached file fails (numba bug?)
@njit(parallel=False, nogil=True, cache=True)
def _iterate_stepwise_depositions_scintillate(
    edep_hits, rng, scint_mat_params: sc.ComputedScintParams, mode: str
):
    """Sample the number of scintillation photons for each step of each hit.

    ``mode == "no-fano"`` selects the ``"poisson"`` emission-term model; any other
    value selects ``"normal_fano"``.

    Returns a list with, for each hit, a list of photon counts (one per step).
    """
    # the emission model depends only on ``mode``: resolve it once, not per step.
    emission_term_model = "poisson" if mode == "no-fano" else "normal_fano"

    pdgid_map = {}
    output_list = []

    for rowid in range(len(edep_hits)):  # iterate hits
        hit = edep_hits[rowid]
        hit_output = []

        # iterate steps inside the hit
        for si in range(len(hit.particle)):
            # get the particle information (cached per pdg id).
            particle = hit.particle[si]
            if particle not in pdgid_map:
                pdgid_map[particle] = (_pdgid_to_particle(particle), _pdg_func.charge(particle))
            part, _charge = pdgid_map[particle]

            # do the scintillation.
            num_phot = sc.scintillate_numphot(
                scint_mat_params,
                part,
                hit.edep[si],
                rng,
                emission_term_model=emission_term_model,
            )
            hit_output.append(num_phot)

        assert len(hit_output) == len(hit.particle)
        output_list.append(hit_output)

    return output_list
319
315
 
320
316
 
321
317
  def get_output_table(output_map):
@@ -338,58 +334,35 @@ def get_output_table(output_map):
338
334
  return ph_count_o, tbl
339
335
 
340
336
 
341
- def convolve(
342
- map_file: str,
343
- edep_file: str,
344
- edep_path: str,
345
- material: str,
346
- output_file: str | None = None,
347
- buffer_len: int = int(1e6),
348
- dist_mode: str = "poisson+no-fano",
349
- ):
350
- if material not in ["lar", "pen"]:
351
- msg = f"unknown material {material} for scintillation"
352
- raise ValueError(msg)
337
def _reflatten_scint_vov(arr: ak.Array) -> ak.Array:
    """Flatten a hit-grouped record array back to one row per step.

    Jagged fields are flattened by one level. Flat (per-hit) fields are repeated
    once per step of their hit, using the per-hit lengths of ``edep``. If no
    field is jagged, ``arr`` is returned unchanged.
    """
    names = ak.fields(arr)
    if not any(arr[name].ndim != 1 for name in names):
        return arr

    steps_per_hit = ak.num(arr["edep"]).to_numpy()

    columns = {}
    for name in names:
        column = arr[name]
        if column.ndim > 1:
            columns[name] = ak.flatten(column)
        else:
            columns[name] = np.repeat(column.to_numpy(), steps_per_hit)
    return ak.Array(columns)
353
347
 
348
+
349
+ def _get_scint_params(material: str):
354
350
  if material == "lar":
355
- scint_mat_params = sc.precompute_scintillation_params(
351
+ return sc.precompute_scintillation_params(
356
352
  lar.lar_scintillation_params(),
357
353
  lar.lar_lifetimes().as_tuple(),
358
354
  )
359
- elif material == "pen":
360
- scint_mat_params = sc.precompute_scintillation_params(
361
- lar.pen_scintillation_params(),
362
- (1 * pint.get_application_registry().ns), # dummy!
363
- )
364
-
365
- # special handling of distributions and flags.
366
- dist, mode = dist_mode.split("+")
367
- if (
368
- dist not in ("multinomial", "poisson")
369
- or mode not in ("", "no-fano")
370
- or (dist == "poisson" and mode != "no-fano")
371
- ):
372
- msg = f"unsupported statistical distribution {dist_mode} for scintillation emission"
373
- raise ValueError(msg)
374
-
375
- log.info("opening map %s", map_file)
376
- optmap_for_convolve = open_optmap(map_file)
377
-
378
- log.info("opening energy deposition hit output %s", edep_file)
379
- it = LH5Iterator(edep_file, edep_path, buffer_len=buffer_len)
380
-
381
- for it_count, edep_lgdo in enumerate(it):
382
- edep_df = edep_lgdo.view_as("pd").to_records()
383
-
384
- log.info("start event processing (%d)", it_count)
385
- output_map = iterate_stepwise_depositions(
386
- edep_df, optmap_for_convolve, scint_mat_params, dist=dist, mode=mode
355
+ if material == "pen":
356
+ return sc.precompute_scintillation_params(
357
+ pen.pen_scintillation_params(),
358
+ (pen.pen_scint_timeconstant(),),
387
359
  )
388
-
389
- log.info("store output photon hits (%d)", it_count)
390
- ph_count_o, tbl = get_output_table(output_map)
391
- log.debug(
392
- "output photons: %d energy depositions -> %d photons", len(output_map), ph_count_o
360
+ if material == "fiber":
361
+ return sc.precompute_scintillation_params(
362
+ fibers.fiber_core_scintillation_params(),
363
+ (fibers.fiber_wls_timeconstant(),),
393
364
  )
394
- if output_file is not None:
395
- lh5.write(tbl, "optical", lh5_file=output_file, group="stp", wo_mode="append")
365
+ if isinstance(material, str):
366
+ msg = f"unknown material {material} for scintillation"
367
+ raise ValueError(msg)
368
+ return sc.precompute_scintillation_params(*material)