reboost-0.3.0-py3-none-any.whl → reboost-0.4.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
reboost/__init__.py CHANGED
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
-from reboost import build_hit, core, iterator, math, shape
-from reboost._version import version as __version__
+from . import build_hit, core, iterator, math, shape
+from ._version import version as __version__
 
 __all__ = [
     "__version__",
reboost/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE
 
-__version__ = version = '0.3.0'
-__version_tuple__ = version_tuple = (0, 3, 0)
+__version__ = version = '0.4.2'
+__version_tuple__ = version_tuple = (0, 4, 2)
reboost/build_glm.py CHANGED
@@ -9,7 +9,7 @@ from lgdo import Array, Table, lh5
 from lgdo.lh5 import LH5Iterator, LH5Store
 from numpy.typing import ArrayLike
 
-from reboost import utils
+from . import utils
 
 log = logging.getLogger(__name__)
 
@@ -35,6 +35,7 @@ def get_glm_rows(stp_evtids: ArrayLike, vert: ArrayLike, *, start_row: int = 0)
         output = ak.Array({"evtid": vert})
         output["n_rows"] = np.array([0] * len(vert), dtype=float)
         output["start_row"] = np.array([np.nan] * len(vert), dtype=float)
+
         return output
 
     if not isinstance(stp_evtids, np.ndarray):
@@ -182,6 +183,7 @@ def get_stp_evtids(
 def build_glm(
     stp_files: str | list[str],
     glm_files: str | list[str] | None,
+    lh5_groups: list | None = None,
     *,
     out_table_name: str = "glm",
     id_name: str = "g4_evtid",
@@ -225,7 +227,11 @@
         log.info(msg)
 
         # loop over the lh5_tables
-        lh5_table_list = list(lh5.ls(stp_file, "stp/"))
+        lh5_table_list = [
+            det
+            for det in lh5.ls(stp_file, "stp/")
+            if lh5_groups is None or det.split("/")[1] in lh5_groups
+        ]
 
         # get rows in the table
         if files.glm[file_idx] is None:
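
Note: build_glm() gains an lh5_groups argument, used above to filter the tables returned by lh5.ls(stp_file, "stp/"). A minimal sketch of the new call; the file and detector names are hypothetical:

    from reboost.build_glm import build_glm

    # only the stp/det001 table is scanned; with the default
    # lh5_groups=None every table under stp/ is processed as before
    build_glm("sim_stp.lh5", "sim_glm.lh5", lh5_groups=["det001"])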
reboost/build_hit.py CHANGED
@@ -20,7 +20,7 @@ A :func:`build_hit` to parse the following configuration file:
 
 # this is a list of included detectors (part of the processing group)
 detector_mapping:
-  - output: OBJECTS.lmeta.channglmap(on=ARGS.timestamp)
+  - output: OBJECTS.lmeta.channelmap(on=ARGS.timestamp)
       .group('system').geds
       .group('analysis.status').on
       .map('name').keys()
@@ -153,6 +153,12 @@ A :func:`build_hit` to parse the following configuration file:
       )
 
     pe_times: ak.concatenate([HITS.pe_times_lar, HITS.pe_times_pen], axis=-1)
+
+# can list here some lh5 objects that should just be forwarded to the
+# output file, without any processing
+forward:
+  - /vtx
+  - /some/dataset
 """
 
 from __future__ import annotations
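
Note: the new top-level "forward" key shown in the docstring above lists LH5 objects that build_hit copies verbatim from the stp file into the hit file. Per the implementation further down, a single path is also accepted, and forwarding an object that reboost has already processed raises a RuntimeError. A hypothetical snippet for a dict-based config; the paths are taken from the docstring example:

    config["forward"] = "/vtx"                     # a single object ...
    config["forward"] = ["/vtx", "/some/dataset"]  # ... or a list of them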
@@ -166,12 +172,11 @@ import awkward as ak
 import dbetto
 from dbetto import AttrsDict
 from lgdo import lh5
-from lgdo.types import Struct
-
-from reboost.iterator import GLMIterator
-from reboost.profile import ProfileDict
+from lgdo.lh5.exceptions import LH5EncodeError
 
 from . import core, utils
+from .iterator import GLMIterator
+from .profile import ProfileDict
 
 log = logging.getLogger(__name__)
 
@@ -225,19 +230,19 @@ def build_hit(
     # get the arguments
     if not isinstance(args, AttrsDict):
         args = AttrsDict(args)
+
     time_dict = ProfileDict()
 
     # get the global objects
-    global_objects = AttrsDict(
-        core.get_global_objects(
-            expressions=config.get("objects", {}), local_dict={"ARGS": args}, time_dict=time_dict
-        )
+    global_objects = core.get_global_objects(
+        expressions=config.get("objects", {}), local_dict={"ARGS": args}, time_dict=time_dict
     )
 
     # get the input files
     files = utils.get_file_dict(stp_files=stp_files, glm_files=glm_files, hit_files=hit_files)
 
     output_tables = {}
+
     # iterate over files
     for file_idx, (stp_file, glm_file) in enumerate(zip(files.stp, files.glm)):
         msg = (
@@ -257,21 +262,14 @@ def build_hit(
         time_dict[proc_name] = ProfileDict()
 
         # extract the output detectors and the mapping to input detectors
-        detectors_mapping = utils.merge_dicts(
-            [
-                core.get_detectors_mapping(
-                    mapping["output"],
-                    input_detector_name=mapping.get("input", None),
-                    objects=global_objects,
-                )
-                for mapping in proc_group.get("detector_mapping")
-            ]
+        detectors_mapping = core.get_detector_mapping(
+            proc_group.get("detector_mapping"), global_objects
         )
 
         # loop over detectors
         for in_det_idx, (in_detector, out_detectors) in enumerate(detectors_mapping.items()):
             msg = f"... processing {in_detector} (to {out_detectors})"
-            log.info(msg)
+            log.debug(msg)
 
             # get detector objects
             det_objects = core.get_detector_objects(
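
Note: merging several detector mappings moved from build_hit into core.get_detector_mapping, which feeds each entry through get_one_detector_mapping and merges the results. A sketch of a direct call; the detector and object names are invented for illustration:

    mapping = core.get_detector_mapping(
        [
            # fan one input table out to two output tables
            {"input": "det001", "output": ["det001_hit", "det001_aux"]},
            # or derive output names from a global object expression
            {"output": "OBJECTS.chmap.map('name').keys()"},
        ],
        global_objects,
    )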
@@ -283,18 +281,19 @@
             )
 
             # begin iterating over the glm
-            glm_it = GLMIterator(
+            iterator = GLMIterator(
                 glm_file,
                 stp_file,
                 lh5_group=in_detector,
                 start_row=start_evtid,
                 stp_field=in_field,
                 n_rows=n_evtid,
-                read_vertices=False,
                 buffer=buffer,
                 time_dict=time_dict[proc_name],
+                reshaped_files="hit_table_layout" not in proc_group,
             )
-            for stps, _, chunk_idx, _ in glm_it:
+
+            for stps, chunk_idx, _ in iterator:
                 # converting to awkward
                 if stps is None:
                     continue
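
Note: GLMIterator now yields three-tuples (the read_vertices flag is gone) and takes a reshaped_files switch, which build_hit derives from the absence of a hit_table_layout step. A minimal sketch of the new loop, mirroring the call above; file and table names are hypothetical and defaults are assumed for the remaining keywords:

    from reboost.iterator import GLMIterator

    it = GLMIterator(
        "sim_glm.lh5",
        "sim_stp.lh5",
        lh5_group="det001",
        start_row=0,
        buffer=10000,
        reshaped_files=True,  # no hit_table_layout step in this group
    )

    for stps, chunk_idx, _ in it:
        if stps is None:  # no steps in this chunk
            continue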
@@ -342,7 +341,7 @@ def build_hit(
                         time_dict=time_dict[proc_name],
                         name=field,
                     )
-                    hit_table.add_field(field, col)
+                    core.add_field_with_nesting(hit_table, field, col)
 
                 # remove unwanted fields
                 if "outputs" in proc_group:
@@ -353,46 +352,56 @@ def build_hit(
                 # assign units in the output table
                 hit_table = utils.assign_units(hit_table, attrs)
 
-                new_hit_file = (file_idx == 0) or (
-                    files.hit[file_idx] != files.hit[file_idx - 1]
-                )
-
-                wo_mode = utils.get_wo_mode(
-                    group=group_idx,
-                    out_det=out_det_idx,
-                    in_det=in_det_idx,
-                    chunk=chunk_idx,
-                    new_hit_file=new_hit_file,
-                    overwrite=overwrite,
-                )
-
                 # now write
                 if files.hit[file_idx] is not None:
-                    if time_dict is not None:
-                        start_time = time.time()
-
-                    if wo_mode != "a":
-                        lh5.write(
-                            Struct({out_detector: hit_table}),
-                            out_field,
-                            files.hit[file_idx],
-                            wo_mode=wo_mode,
-                        )
-                    else:
-                        lh5.write(
-                            hit_table,
-                            f"{out_field}/{out_detector}",
-                            files.hit[file_idx],
-                            wo_mode=wo_mode,
-                        )
-                    if time_dict is not None:
-                        time_dict[proc_name].update_field("write", start_time)
+                    # get modes to write with
+                    new_hit_file = (file_idx == 0) or (
+                        files.hit[file_idx] != files.hit[file_idx - 1]
+                    )
+
+                    wo_mode = utils.get_wo_mode(
+                        group=group_idx,
+                        out_det=out_det_idx,
+                        in_det=in_det_idx,
+                        chunk=chunk_idx,
+                        new_hit_file=new_hit_file,
+                        overwrite=overwrite,
+                    )
+                    # write the file
+                    utils.write_lh5(
+                        hit_table,
+                        files.hit[file_idx],
+                        time_dict[proc_name],
+                        out_field=out_field,
+                        out_detector=out_detector,
+                        wo_mode=wo_mode,
+                    )
 
                 else:
                     output_tables[out_detector] = core.merge(
                         hit_table, output_tables[out_detector]
                     )
 
+        # forward some data, if requested
+        # possible improvement: iterate over data if it's a lot
+        if "forward" in config and files.hit[file_idx] is not None:
+            obj_list = config["forward"]
+
+            if not isinstance(obj_list, list):
+                obj_list = [obj_list]
+
+            for obj in obj_list:
+                try:
+                    lh5.write(
+                        lh5.read(obj, stp_file),
+                        obj,
+                        files.hit[file_idx],
+                        wo_mode="write_safe",
+                    )
+                except LH5EncodeError as e:
+                    msg = f"cannot forward object {obj} as it has been already processed by reboost"
+                    raise RuntimeError(msg) from e
+
     # return output table or nothing
     log.info(time_dict)
 
reboost/build_tcm.py CHANGED
@@ -6,7 +6,7 @@ import re
 import awkward as ak
 from lgdo import Table, lh5
 
-from reboost.shape import group
+from .shape import group
 
 log = logging.getLogger(__name__)
 
reboost/cli.py CHANGED
@@ -5,11 +5,10 @@ import logging
 
 import dbetto
 
-from reboost.build_glm import build_glm
-from reboost.build_hit import build_hit
-from reboost.utils import _check_input_file, _check_output_file, get_file_list
-
+from .build_glm import build_glm
+from .build_hit import build_hit
 from .log_utils import setup_log
+from .utils import _check_input_file, _check_output_file, get_file_list
 
 log = logging.getLogger(__name__)
 
@@ -88,7 +87,8 @@ def cli(args=None) -> None:
     hit_parser.add_argument(
         "--glm-file",
         type=str,
-        required=True,
+        required=False,
+        default=None,
         help="glm file to process, if multithreaded this will be appended with _t$idx",
     )
     hit_parser.add_argument(
@@ -159,7 +159,9 @@ def cli(args=None) -> None:
        hit_files = get_file_list(args.hit_file, threads=args.threads)
 
        _check_input_file(parser, stp_files)
-       _check_input_file(parser, glm_files)
+
+       if args.glm_file is not None:
+           _check_input_file(parser, glm_files)
 
        if args.overwrite is False:
            _check_output_file(parser, hit_files)
@@ -174,8 +176,8 @@ def cli(args=None) -> None:
        msg += f" n_evtid: {args.n_evtid}\n"
        msg += f" in_field: {args.in_field}\n"
        msg += f" out_field: {args.out_field}\n"
-       msg += f" buffer: {args.buffer}"
-       msg += f" overwrite: {args.overwrite}"
+       msg += f" buffer: {args.buffer} \n"
+       msg += f" overwrite: {args.overwrite} \n"
 
        log.info(msg)
 
reboost/core.py CHANGED
@@ -8,9 +8,8 @@ import awkward as ak
 from dbetto import AttrsDict
 from lgdo.types import LGDO, Table
 
-from reboost.profile import ProfileDict
-
 from . import utils
+from .profile import ProfileDict
 
 log = logging.getLogger(__name__)
 
@@ -119,7 +118,7 @@
 
 def get_global_objects(
     expressions: dict[str, str], *, local_dict: dict, time_dict: dict | None = None
-) -> dict:
+) -> AttrsDict:
     """Extract global objects used in the processing.
 
     Parameters
@@ -141,19 +140,42 @@
 
     msg = f"Getting global objects with {expressions.keys()} and {local_dict}"
     log.info(msg)
+    res = {}
+
+    for obj_name, expression in expressions.items():
+        res[obj_name] = evaluate_object(
+            expression, local_dict=local_dict | {"OBJECTS": AttrsDict(res)}
+        )
 
-    res = AttrsDict(
-        {
-            obj_name: evaluate_object(expression, local_dict=local_dict)
-            for obj_name, expression in expressions.items()
-        }
-    )
     if time_dict is not None:
         time_dict.update_field(name="global_objects", time_start=time_start)
-    return res
 
+    return AttrsDict(res)
+
+
+def get_detector_mapping(detector_mapping: dict, global_objects: AttrsDict) -> dict:
+    """Get all the detector mapping using :func:`get_one_detector_mapping`.
 
-def get_detectors_mapping(
+    Parameters
+    ----------
+    detector_mapping
+        dictionary of detector mapping
+    global_objects
+        dictionary of global objects to use in evaluating the mapping.
+    """
+    return utils.merge_dicts(
+        [
+            get_one_detector_mapping(
+                mapping["output"],
+                input_detector_name=mapping.get("input", None),
+                objects=global_objects,
+            )
+            for mapping in detector_mapping
+        ]
+    )
+
+
+def get_one_detector_mapping(
     output_detector_expression: str | list,
     objects: AttrsDict | None = None,
     input_detector_name: str | None = None,
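
Note: get_global_objects now evaluates the expressions one at a time and exposes the partially built result as OBJECTS, so a later object can be constructed from an earlier one (relying on dict insertion order). A hypothetical call; the object names and expressions are illustrative:

    objs = core.get_global_objects(
        expressions={
            # evaluated first
            "lmeta": "LegendMetadata(ARGS.metadata_path)",
            # evaluated second, may already reference OBJECTS.lmeta
            "chmap": "OBJECTS.lmeta.channelmap(on=ARGS.timestamp)",
        },
        local_dict={"ARGS": args},  # args: AttrsDict of user arguments
    )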
@@ -342,6 +364,55 @@ def evaluate_hit_table_layout(
     return res
 
 
+def add_field_with_nesting(tab: Table, col: str, field: LGDO) -> Table:
+    """Add a field handling the nesting."""
+    subfields = col.strip("/").split("___")
+    tab_next = tab
+
+    for level in subfields:
+        # if we are at the end, just add the field
+        if level == subfields[-1]:
+            tab_next.add_field(level, field)
+            break
+
+        if not level:
+            msg = f"invalid field name '{field}'"
+            raise RuntimeError(msg)
+
+        # otherwise, increase nesting
+        if level not in tab:
+            tab_next.add_field(level, Table(size=len(tab)))
+            tab_next = tab[level]
+        else:
+            tab_next = tab[level]
+
+    return tab
+
+
+def _get_table_keys(tab: Table):
+    """Get keys in a table."""
+    existing_cols = list(tab.keys())
+    output_cols = []
+    for col in existing_cols:
+        if isinstance(tab[col], Table):
+            output_cols.extend(
+                [f"{col}___{col_second}" for col_second in _get_table_keys(tab[col])]
+            )
+        else:
+            output_cols.append(col)
+
+    return output_cols
+
+
+def _remove_col(field: str, tab: Table):
+    """Remove column accounting for nesting."""
+    if "___" in field:
+        base_name, sub_field = field.split("___", 1)[0], field.split("___", 1)[1]
+        _remove_col(sub_field, tab[base_name])
+    else:
+        tab.remove_column(field, delete=True)
+
+
 def remove_columns(tab: Table, outputs: list) -> Table:
     """Remove columns from the table not found in the outputs.
 
@@ -356,11 +427,10 @@ def remove_columns(tab: Table, outputs: list) -> Table:
     -------
         the table with columns removed.
     """
-    existing_cols = list(tab.keys())
-    for col in existing_cols:
-        if col not in outputs:
-            tab.remove_column(col, delete=True)
-
+    cols = _get_table_keys(tab)
+    for col_unrename in cols:
+        if col_unrename not in outputs:
+            _remove_col(col_unrename, tab)
     return tab
 
 
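
Note: a triple underscore in a column name now acts as a nesting separator throughout core.py: add_field_with_nesting stores "trigger___time" as a "time" field inside a "trigger" sub-table, _get_table_keys flattens nested tables back to such names, and remove_columns matches against them. A minimal sketch; the column names are invented for illustration:

    import numpy as np
    from lgdo.types import Array, Table

    tab = Table(size=3)
    add_field_with_nesting(tab, "trigger___time", Array(np.array([1.0, 2.0, 3.0])))
    add_field_with_nesting(tab, "trigger___id", Array(np.array([7, 8, 9])))

    # keeps trigger/time and drops trigger/id via _remove_col
    remove_columns(tab, outputs=["trigger___time"])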