reboost 0.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,325 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from typing import NamedTuple
5
+
6
+ import awkward as ak
7
+ import numba
8
+ import numpy as np
9
+ import pygeomoptics.scintillate as sc
10
+ from lgdo import lh5
11
+ from lgdo.types import Histogram
12
+ from numba import njit
13
+ from numpy.typing import NDArray
14
+ from pygeomoptics import fibers, lar, pen
15
+
16
+ from .numba_pdg import numba_pdgid_funcs
17
+
18
log: logging.Logger = logging.getLogger(name=__name__)


# Index of the combined ("any channel") probability map in the stacked weights
# array. It is negative on purpose: it addresses the last slot, counted from the
# end, independent of the number of per-channel maps stored before it.
OPTMAP_ANY_CH: int = -1
22
+
23
+
24
class OptmapForConvolve(NamedTuple):
    """A loaded optmap for convolving.

    ``dets[k]`` is a channel name and ``detidx[k]`` is the index of that
    channel's probability map along the first axis of ``weights``.
    """

    # names of the channels contained in the map.
    dets: NDArray
    # first-axis indices into ``weights``, one per entry of ``dets``.
    detidx: NDArray
    # bin edges of the map, one array per spatial axis (a tuple, not a single array).
    edges: tuple[NDArray, ...]
    # stacked probability maps; negative entries mark voxels without statistics.
    weights: NDArray
31
+
32
+
33
def open_optmap(optmap_fn: str) -> OptmapForConvolve:
    """Load the combined and all per-channel probability maps of an optical map file.

    Parameters
    ----------
    optmap_fn
        path to the LH5 optical map file.

    Returns
    -------
    OptmapForConvolve
        ``dets`` holds the channel names as an array (``["all"]`` if the file has
        no per-channel maps), ``detidx`` the matching first-axis indices into
        ``weights`` (``OPTMAP_ANY_CH`` for the combined map), ``edges`` the bin
        edges of each axis and ``weights`` the stacked probability maps.

    Raises
    ------
    RuntimeError
        if the file contains the no longer supported ``_hitcounts_exp`` entry.
    ValueError
        if, in any voxel, the per-channel probabilities sum up to less than the
        combined probability.
    """
    # check the exponent from the optical map file; done first, so that no map
    # data is read from an unsupported file.
    if "_hitcounts_exp" in lh5.ls(optmap_fn):
        msg = "found _hitcounts_exp which is not supported any more"
        raise RuntimeError(msg)

    dets = lh5.ls(optmap_fn, "/channels/")
    # len() works whether lh5.ls returns a list or an array.
    detidx = np.arange(0, len(dets))

    optmap_all = lh5.read("/all/prob", optmap_fn)
    assert isinstance(optmap_all, Histogram)
    optmap_edges = tuple([b.edges for b in optmap_all.binning])

    # slots 0, ..., len(detidx)-1 hold the channel maps, slot OPTMAP_ANY_CH (negative,
    # i.e. counted from the end) the combined map. Pre-fill with a negative value (the
    # "no statistics" marker) instead of leaving uninitialized memory in slots that are
    # never written.
    ow = np.full(
        (detidx.shape[0] + 2, *optmap_all.weights.nda.shape), -1.0, dtype=np.float64
    )
    ow[OPTMAP_ANY_CH] = optmap_all.weights.nda
    for i, nt in zip(detidx, dets, strict=True):
        optmap = lh5.read(f"/{nt}/prob", optmap_fn)
        assert isinstance(optmap, Histogram)
        ow[i] = optmap.weights.nda

    # if we have any individual channels registered, the sum is potentially larger than the
    # probability to find _any_ hit.
    if len(detidx) != 0:
        map_sum = np.sum(ow[0:-2], axis=0, where=(ow[0:-2] >= 0))
        assert not np.any(map_sum < 0)

        # give this check some numerical slack. The *signed* difference must be tested;
        # its absolute value can never be below a negative threshold.
        any_has_stats = ow[OPTMAP_ANY_CH] >= 0
        if np.any(map_sum[any_has_stats] - ow[OPTMAP_ANY_CH][any_has_stats] < -1e-15):
            msg = "optical map does not fulfill relation sum(p_i) >= p_any"
            raise ValueError(msg)
    else:
        detidx = np.array([OPTMAP_ANY_CH])
        dets = np.array(["all"])

    # strip the group prefix. Return an array (as open_optmap_single does), so that
    # element-wise comparisons like ``optmap.dets == det`` work for callers.
    dets = np.array([d.replace("/channels/", "") for d in dets])

    return OptmapForConvolve(dets, detidx, optmap_edges, ow)
74
+
75
+
76
def open_optmap_single(optmap_fn: str, spm_det: str) -> OptmapForConvolve:
    """Load the probability map of exactly one channel from an optical map file.

    Parameters
    ----------
    optmap_fn
        path to the LH5 optical map file.
    spm_det
        channel name, read from ``/channels/<spm_det>/prob``; the special value
        ``"all"`` selects the combined map at ``/all/prob`` instead.

    Returns
    -------
    OptmapForConvolve
        with a single entry: ``dets == [spm_det]``, ``detidx == [0]`` and a
        weights array with one slot along the first axis.

    Raises
    ------
    RuntimeError
        if the file contains the no longer supported ``_hitcounts_exp`` entry.
    """
    if "_hitcounts_exp" in lh5.ls(optmap_fn):
        msg = "found _hitcounts_exp which is not supported any more"
        raise RuntimeError(msg)

    # the combined map is a top-level group; per-channel maps live below /channels.
    h5_path = spm_det if spm_det == "all" else f"channels/{spm_det}"
    hist = lh5.read(f"/{h5_path}/prob", optmap_fn)
    assert isinstance(hist, Histogram)

    map_weights = hist.weights.nda
    stacked = np.empty((1, *map_weights.shape), dtype=np.float64)
    stacked[0] = map_weights
    edges = tuple(axis.edges for axis in hist.binning)

    return OptmapForConvolve(np.array([spm_det]), np.array([0]), edges, stacked)
90
+
91
+
92
def iterate_stepwise_depositions_pois(
    edep_hits: ak.Array,
    optmap: OptmapForConvolve,
    scint_mat_params: sc.ComputedScintParams,
    det: str,
    map_scaling: float = 1,
    map_scaling_sigma: float = 0,
    rng: np.random.Generator | None = None,
):
    """Sample detected photon times of one channel from step-wise photon counts.

    For every step, the expected number of detected photons is the step's
    ``num_scint_ph`` times the (scaled) optical-map probability of the voxel
    containing the step. The detected count is drawn from a Poisson distribution
    and converted to photon times.

    Parameters
    ----------
    edep_hits
        already reshaped (jagged) step data; the fields ``particle``,
        ``num_scint_ph``, ``xloc``, ``yloc``, ``zloc`` and ``time`` are read.
    optmap
        optical map, e.g. as returned by :func:`open_optmap` or
        :func:`open_optmap_single`.
    scint_mat_params
        precomputed scintillation parameters of the material.
    det
        name of the channel to use; must be contained in ``optmap.dets``.
    map_scaling
        factor applied to all map probabilities.
    map_scaling_sigma
        if positive, the scaling factor is drawn per hit from a normal
        distribution with this standard deviation around ``map_scaling``.
    rng
        random generator; a new default generator is created if omitted.

    Returns
    -------
    awkward array with one list per hit, holding the times of all photons
    detected in that hit.

    Raises
    ------
    ValueError
        if ``edep_hits`` is not reshaped or ``det`` is not contained in the map.
    """
    if edep_hits.particle.ndim == 1:
        msg = "the pe processors only support already reshaped output"
        raise ValueError(msg)

    # position of the channel in ``optmap.dets`` (which may be a list or an array).
    det_pos = np.flatnonzero(np.asarray(optmap.dets) == det)
    if det_pos.size == 0:
        msg = f"channel {det} not available in optical map (contains {optmap.dets})"
        raise ValueError(msg)
    # translate the position into the index of the map along the first axis of
    # ``optmap.weights``. The two differ e.g. for the combined map, which is stored
    # at OPTMAP_ANY_CH and not at position 0.
    detidx = int(np.asarray(optmap.detidx)[det_pos[0]])

    rng = np.random.default_rng() if rng is None else rng
    res, output_list = _iterate_stepwise_depositions_pois(
        edep_hits,
        rng,
        detidx,
        map_scaling,
        map_scaling_sigma,
        optmap.edges,
        optmap.weights,
        scint_mat_params,
    )

    # convert the numba result back into an awkward array.
    builder = ak.ArrayBuilder()
    for r in output_list:
        with builder.list():
            for a in r:
                builder.extend(a)

    if res["det_no_stats"] > 0:
        log.warning(
            "had edep out in voxels without stats: %d",
            res["det_no_stats"],
        )
    if res["oob"] > 0:
        log.warning(
            "had edep out of map bounds: %d (%.2f%%)",
            res["oob"],
            (res["oob"] / (res["ib"] + res["oob"])) * 100,
        )
    # logging arguments are evaluated eagerly: guard against empty input / no
    # processed photons, which would otherwise raise ZeroDivisionError.
    vuv_primary = res["vuv_primary"]
    detected_pct = (res["hits"] / vuv_primary) * 100 if vuv_primary > 0 else 0.0
    log.debug(
        "VUV_primary %d ->hits %d (%.2f %% primaries detected in this channel)",
        vuv_primary,
        res["hits"],
        detected_pct,
    )
    return builder.snapshot()
146
+
147
+
148
def iterate_stepwise_depositions_scintillate(
    edep_hits: ak.Array,
    scint_mat_params: sc.ComputedScintParams,
    rng: np.random.Generator | None = None,
    mode: str = "no-fano",
):
    """Sample the number of scintillation photons of every step.

    Parameters
    ----------
    edep_hits
        already reshaped (jagged) step data; ``particle`` and ``edep`` are read.
    scint_mat_params
        precomputed scintillation parameters of the material.
    rng
        random generator; a new default generator is created if omitted.
    mode
        ``"no-fano"`` selects the ``"poisson"`` emission term model; any other
        value selects ``"normal_fano"``.

    Returns
    -------
    awkward array with one list per hit, holding one photon count per step.

    Raises
    ------
    ValueError
        if ``edep_hits`` is not reshaped.
    """
    if edep_hits.particle.ndim == 1:
        msg = "the pe processors only support already reshaped output"
        raise ValueError(msg)

    if rng is None:
        rng = np.random.default_rng()
    photons_per_hit = _iterate_stepwise_depositions_scintillate(
        edep_hits, rng, scint_mat_params, mode
    )

    # the numba kernel returns nested lists; rebuild the jagged awkward layout.
    array_builder = ak.ArrayBuilder()
    for step_counts in photons_per_hit:
        with array_builder.list():
            array_builder.extend(step_counts)
    return array_builder.snapshot()
168
+
169
+
170
# numba-callable PDG helpers (is_nucleus, charge), created once at import time.
_pdg_func = numba_pdgid_funcs()


@njit
def _pdgid_to_particle(pdgid: int) -> sc.ParticleIndex:
    """Classify a PDG id into the particle index used for scintillation.

    Alpha, deuteron and triton are recognized by their nuclear PDG codes (for
    either sign of the id). Any other nucleus is treated as an ion. Everything
    else falls back to the electron index.
    """
    magnitude = abs(pdgid)
    if magnitude == 1000020040:  # He-4 nucleus
        index = sc.PARTICLE_INDEX_ALPHA
    elif magnitude == 1000010020:  # H-2 nucleus
        index = sc.PARTICLE_INDEX_DEUTERON
    elif magnitude == 1000010030:  # H-3 nucleus
        index = sc.PARTICLE_INDEX_TRITON
    elif _pdg_func.is_nucleus(pdgid):
        index = sc.PARTICLE_INDEX_ION
    else:
        index = sc.PARTICLE_INDEX_ELECTRON
    return index
185
+
186
+
187
__counts_per_bin_key_type = numba.types.UniTuple(numba.types.int64, 3)


# - run with NUMBA_FULL_TRACEBACKS=1 NUMBA_BOUNDSCHECK=1 for testing/checking
# - cache=True does not work with outer prange, i.e. loading the cached file fails (numba bug?)
# - the output dictionary is not threadsafe, so parallel=True is not working with it.
@njit(parallel=False, nogil=True, cache=True)
def _iterate_stepwise_depositions_pois(
    edep_hits,
    rng,
    detidx: int,
    map_scaling: float,
    map_scaling_sigma: float,
    optmap_edges,
    optmap_weights,
    scint_mat_params: sc.ComputedScintParams,
):
    """Numba kernel of ``iterate_stepwise_depositions_pois``.

    Returns
    -------
    stats
        counters: ``oob`` (steps outside the map), ``ib`` (steps inside),
        ``vuv_primary`` (summed ``num_scint_ph`` of steps in voxels with
        statistics), ``hits`` (summed sampled detected photons) and
        ``det_no_stats`` (steps in voxels without statistics).
    output_list
        per hit, a list with one array of photon times per step that lies in a
        voxel with statistics.
    """
    pdgid_map = {}
    oob = ib = ph_cnt = ph_det2 = det_no_stats = 0  # for statistics
    output_list = []

    for rowid in range(len(edep_hits)):  # iterate hits
        hit = edep_hits[rowid]
        hit_output = []

        map_scaling_evt = map_scaling
        if map_scaling_sigma > 0:
            map_scaling_evt = rng.normal(loc=map_scaling, scale=map_scaling_sigma)
        # a detection probability cannot become negative: clamp the (possibly
        # gaussian-drawn) scaling, otherwise negative scalings would flip the sign
        # of the map values and make rng.poisson fail or sample from no-stats voxels.
        if map_scaling_evt < 0:
            map_scaling_evt = 0.0

        assert len(hit.particle) == len(hit.num_scint_ph)
        # iterate steps inside the hit
        for si in range(len(hit.particle)):
            loc = np.array([hit.xloc[si], hit.yloc[si], hit.zloc[si]])
            # coordinates -> bins of the optical map.
            bins = np.empty(3, dtype=np.int64)
            for j in range(3):
                bins[j] = np.digitize(loc[j], optmap_edges[j])
                # normalize all out-of-bounds bins just to one end.
                if bins[j] == optmap_edges[j].shape[0]:
                    bins[j] = 0

            # note: subtract 1 from bins, to account for np.digitize output.
            cur_bins = (bins[0] - 1, bins[1] - 1, bins[2] - 1)
            if cur_bins[0] == -1 or cur_bins[1] == -1 or cur_bins[2] == -1:
                oob += 1
                continue  # out-of-bounds of optmap
            ib += 1

            # get probabilities from map. Negative map values mark voxels without
            # statistics; test the raw value, before the scaling is applied.
            map_p = optmap_weights[detidx, cur_bins[0], cur_bins[1], cur_bins[2]]
            if map_p < 0.0:
                det_no_stats += 1
                continue
            detp = map_p * map_scaling_evt

            pois_cnt = rng.poisson(lam=hit.num_scint_ph[si] * detp)
            ph_cnt += hit.num_scint_ph[si]
            ph_det2 += pois_cnt

            # get the particle information.
            particle = hit.particle[si]
            if particle not in pdgid_map:
                pdgid_map[particle] = (_pdgid_to_particle(particle), _pdg_func.charge(particle))
            part, _charge = pdgid_map[particle]

            # get time spectrum.
            # note: we assume "immediate" propagation after scintillation.
            scint_times = sc.scintillate_times(scint_mat_params, part, pois_cnt, rng) + hit.time[si]

            hit_output.append(scint_times)

        output_list.append(hit_output)

    stats = {
        "oob": oob,
        "ib": ib,
        "vuv_primary": ph_cnt,
        "hits": ph_det2,
        "det_no_stats": det_no_stats,
    }
    return stats, output_list
267
+
268
+
269
# - run with NUMBA_FULL_TRACEBACKS=1 NUMBA_BOUNDSCHECK=1 for testing/checking
# - cache=True does not work with outer prange, i.e. loading the cached file fails (numba bug?)
@njit(parallel=False, nogil=True, cache=True)
def _iterate_stepwise_depositions_scintillate(
    edep_hits, rng, scint_mat_params: sc.ComputedScintParams, mode: str
):
    """Numba kernel of ``iterate_stepwise_depositions_scintillate``.

    Returns one list per hit, holding the sampled number of scintillation
    photons of every step of that hit.
    """
    # the emission term model only depends on ``mode``: select it once.
    emission_model = "poisson" if mode == "no-fano" else "normal_fano"
    particle_cache = {}
    photons_per_hit = []

    for hit_idx in range(len(edep_hits)):
        hit = edep_hits[hit_idx]
        n_steps = len(hit.particle)
        step_photons = []

        for step in range(n_steps):
            pdgid = hit.particle[step]
            # the per-PDG-id lookups are identical for all steps: compute them once.
            if pdgid not in particle_cache:
                particle_cache[pdgid] = (_pdgid_to_particle(pdgid), _pdg_func.charge(pdgid))
            particle_index, _charge = particle_cache[pdgid]

            step_photons.append(
                sc.scintillate_numphot(
                    scint_mat_params,
                    particle_index,
                    hit.edep[step],
                    rng,
                    emission_term_model=emission_model,
                )
            )

        assert len(step_photons) == n_steps
        photons_per_hit.append(step_photons)

    return photons_per_hit
304
+
305
+
306
def _get_scint_params(material: str):
    """Return precomputed scintillation parameters for a material.

    ``material`` is either one of the built-in names ``"lar"``, ``"pen"`` or
    ``"fiber"``, or a non-string sequence whose elements are passed unchanged
    as positional arguments to ``sc.precompute_scintillation_params``.

    Raises
    ------
    ValueError
        for an unknown material name.
    """
    if not isinstance(material, str):
        # custom material: forward the caller-supplied arguments as they are.
        return sc.precompute_scintillation_params(*material)

    if material == "lar":
        params, time_constants = lar.lar_scintillation_params(), lar.lar_lifetimes().as_tuple()
    elif material == "pen":
        params, time_constants = pen.pen_scintillation_params(), (pen.pen_scint_timeconstant(),)
    elif material == "fiber":
        params, time_constants = (
            fibers.fiber_core_scintillation_params(),
            (fibers.fiber_wls_timeconstant(),),
        )
    else:
        msg = f"unknown material {material} for scintillation"
        raise ValueError(msg)

    return sc.precompute_scintillation_params(params, time_constants)