umami-preprocessing 0.2.7__tar.gz → 0.3.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36) hide show
  1. {umami_preprocessing-0.2.7/umami_preprocessing.egg-info → umami_preprocessing-0.3.0}/PKG-INFO +3 -3
  2. {umami_preprocessing-0.2.7 → umami_preprocessing-0.3.0}/pyproject.toml +2 -2
  3. {umami_preprocessing-0.2.7 → umami_preprocessing-0.3.0/umami_preprocessing.egg-info}/PKG-INFO +3 -3
  4. {umami_preprocessing-0.2.7 → umami_preprocessing-0.3.0}/umami_preprocessing.egg-info/requires.txt +2 -2
  5. {umami_preprocessing-0.2.7 → umami_preprocessing-0.3.0}/upp/__init__.py +1 -1
  6. umami_preprocessing-0.3.0/upp/stages/merging.py +598 -0
  7. umami_preprocessing-0.2.7/upp/stages/merging.py +0 -322
  8. {umami_preprocessing-0.2.7 → umami_preprocessing-0.3.0}/LICENSE +0 -0
  9. {umami_preprocessing-0.2.7 → umami_preprocessing-0.3.0}/MANIFEST.in +0 -0
  10. {umami_preprocessing-0.2.7 → umami_preprocessing-0.3.0}/README.md +0 -0
  11. {umami_preprocessing-0.2.7 → umami_preprocessing-0.3.0}/setup.cfg +0 -0
  12. {umami_preprocessing-0.2.7 → umami_preprocessing-0.3.0}/umami_preprocessing.egg-info/SOURCES.txt +0 -0
  13. {umami_preprocessing-0.2.7 → umami_preprocessing-0.3.0}/umami_preprocessing.egg-info/dependency_links.txt +0 -0
  14. {umami_preprocessing-0.2.7 → umami_preprocessing-0.3.0}/umami_preprocessing.egg-info/entry_points.txt +0 -0
  15. {umami_preprocessing-0.2.7 → umami_preprocessing-0.3.0}/umami_preprocessing.egg-info/top_level.txt +0 -0
  16. {umami_preprocessing-0.2.7 → umami_preprocessing-0.3.0}/upp/classes/__init__.py +0 -0
  17. {umami_preprocessing-0.2.7 → umami_preprocessing-0.3.0}/upp/classes/components.py +0 -0
  18. {umami_preprocessing-0.2.7 → umami_preprocessing-0.3.0}/upp/classes/preprocessing_config.py +0 -0
  19. {umami_preprocessing-0.2.7 → umami_preprocessing-0.3.0}/upp/classes/region.py +0 -0
  20. {umami_preprocessing-0.2.7 → umami_preprocessing-0.3.0}/upp/classes/resampling_config.py +0 -0
  21. {umami_preprocessing-0.2.7 → umami_preprocessing-0.3.0}/upp/classes/reweight_config.py +0 -0
  22. {umami_preprocessing-0.2.7 → umami_preprocessing-0.3.0}/upp/classes/variable_config.py +0 -0
  23. {umami_preprocessing-0.2.7 → umami_preprocessing-0.3.0}/upp/main.py +0 -0
  24. {umami_preprocessing-0.2.7 → umami_preprocessing-0.3.0}/upp/stages/__init__.py +0 -0
  25. {umami_preprocessing-0.2.7 → umami_preprocessing-0.3.0}/upp/stages/hist.py +0 -0
  26. {umami_preprocessing-0.2.7 → umami_preprocessing-0.3.0}/upp/stages/interpolation.py +0 -0
  27. {umami_preprocessing-0.2.7 → umami_preprocessing-0.3.0}/upp/stages/normalisation.py +0 -0
  28. {umami_preprocessing-0.2.7 → umami_preprocessing-0.3.0}/upp/stages/plot.py +0 -0
  29. {umami_preprocessing-0.2.7 → umami_preprocessing-0.3.0}/upp/stages/resampling.py +0 -0
  30. {umami_preprocessing-0.2.7 → umami_preprocessing-0.3.0}/upp/stages/reweight.py +0 -0
  31. {umami_preprocessing-0.2.7 → umami_preprocessing-0.3.0}/upp/stages/rw_merge.py +0 -0
  32. {umami_preprocessing-0.2.7 → umami_preprocessing-0.3.0}/upp/stages/split_containers.py +0 -0
  33. {umami_preprocessing-0.2.7 → umami_preprocessing-0.3.0}/upp/utils/__init__.py +0 -0
  34. {umami_preprocessing-0.2.7 → umami_preprocessing-0.3.0}/upp/utils/check_input_samples.py +0 -0
  35. {umami_preprocessing-0.2.7 → umami_preprocessing-0.3.0}/upp/utils/logger.py +0 -0
  36. {umami_preprocessing-0.2.7 → umami_preprocessing-0.3.0}/upp/utils/tools.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: umami-preprocessing
3
- Version: 0.2.7
3
+ Version: 0.3.0
4
4
  Summary: ATLAS Flavour Tagging Preprocessing - Umami PreProcessing (UPP)
5
5
  Author: Alexander Froch
6
6
  License: MIT
@@ -9,10 +9,10 @@ Project-URL: Issue Tracker, https://github.com/umami-hep/umami-preprocessing/iss
9
9
  Requires-Python: <3.12,>=3.10
10
10
  Description-Content-Type: text/markdown
11
11
  License-File: LICENSE
12
- Requires-Dist: atlas-ftag-tools==0.2.17
12
+ Requires-Dist: atlas-ftag-tools==0.3.1
13
13
  Requires-Dist: dotmap>=1.3.30
14
14
  Requires-Dist: numpy>=2.2.6
15
- Requires-Dist: puma-hep==0.4.11
15
+ Requires-Dist: puma-hep==0.5.1
16
16
  Requires-Dist: pyyaml-include==1.3
17
17
  Requires-Dist: PyYAML>=6.0.2
18
18
  Requires-Dist: rich>=14.1.0
@@ -8,10 +8,10 @@ readme = "README.md"
8
8
  requires-python = ">=3.10,<3.12"
9
9
 
10
10
  dependencies = [
11
- "atlas-ftag-tools==0.2.17",
11
+ "atlas-ftag-tools==0.3.1",
12
12
  "dotmap>=1.3.30",
13
13
  "numpy>=2.2.6",
14
- "puma-hep==0.4.11",
14
+ "puma-hep==0.5.1",
15
15
  "pyyaml-include==1.3",
16
16
  "PyYAML>=6.0.2",
17
17
  "rich>=14.1.0",
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: umami-preprocessing
3
- Version: 0.2.7
3
+ Version: 0.3.0
4
4
  Summary: ATLAS Flavour Tagging Preprocessing - Umami PreProcessing (UPP)
5
5
  Author: Alexander Froch
6
6
  License: MIT
@@ -9,10 +9,10 @@ Project-URL: Issue Tracker, https://github.com/umami-hep/umami-preprocessing/iss
9
9
  Requires-Python: <3.12,>=3.10
10
10
  Description-Content-Type: text/markdown
11
11
  License-File: LICENSE
12
- Requires-Dist: atlas-ftag-tools==0.2.17
12
+ Requires-Dist: atlas-ftag-tools==0.3.1
13
13
  Requires-Dist: dotmap>=1.3.30
14
14
  Requires-Dist: numpy>=2.2.6
15
- Requires-Dist: puma-hep==0.4.11
15
+ Requires-Dist: puma-hep==0.5.1
16
16
  Requires-Dist: pyyaml-include==1.3
17
17
  Requires-Dist: PyYAML>=6.0.2
18
18
  Requires-Dist: rich>=14.1.0
@@ -1,7 +1,7 @@
1
- atlas-ftag-tools==0.2.17
1
+ atlas-ftag-tools==0.3.1
2
2
  dotmap>=1.3.30
3
3
  numpy>=2.2.6
4
- puma-hep==0.4.11
4
+ puma-hep==0.5.1
5
5
  pyyaml-include==1.3
6
6
  PyYAML>=6.0.2
7
7
  rich>=14.1.0
@@ -2,7 +2,7 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- __version__ = "v0.2.7"
5
+ __version__ = "v0.3.0"
6
6
 
7
7
  from . import classes, stages, utils
8
8
  from .main import run_pp
@@ -0,0 +1,598 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import logging as log
5
+ from copy import copy
6
+ from pathlib import Path
7
+ from typing import TYPE_CHECKING, cast
8
+
9
+ import h5py
10
+ import numpy as np
11
+ from ftag.hdf5 import H5Writer, join_structured_arrays
12
+
13
+ from upp.utils.logger import ProgressBar
14
+ from upp.utils.tools import path_append
15
+
16
+ if TYPE_CHECKING: # pragma: no cover
17
+ from upp.classes.components import Component, Components
18
+ from upp.classes.preprocessing_config import PreprocessingConfig
19
+
20
+
21
+ class Merging:
22
+ """Merging Class to merge different components/regions."""
23
+
24
def __init__(self, config: PreprocessingConfig):
    """Store the configuration handles and reset all merge bookkeeping.

    Parameters
    ----------
    config : PreprocessingConfig
        Preprocessing configuration; its components, variables, batch size,
        jets name and per-file jet limit are copied onto the instance.
    """
    # Handles taken straight from the configuration
    self.config = config
    self.jets_name = config.jets_name
    self.batch_size = config.batch_size
    self.variables = config.variables
    self.components = config.components
    self.flavours = self.components.flavours
    self.num_jets_per_output_file = config.num_jets_per_output_file

    # Fixed-seed generator and the tag that appears in split-file names
    self.rng = np.random.default_rng(42)
    self.file_tag = "split"

    # Resume switches (currently always on): continue after existing parts,
    # and delete a part that fails validation before re-writing it
    self.resume = True
    self.auto_fix_parts = True

    # Fast-forward state: the flag is set while already-written jets are
    # being skipped, the pending dict carries the unskipped remainder of a
    # batch across the fast-forward boundary
    self._fast_forwarding: bool = False
    self._ff_pending: dict[str, np.ndarray] | None = None

    # Filled per call of write_components and re-used for every new file
    self.dtypes: dict[str, np.dtype] = {}
    self.base_shapes: dict[str, tuple[int, ...]] = {}

    # Counters shared between write_components and write_chunk
    self._file_idx: int = 0
    self.total_jets: int = 0
    self.jets_written: int = 0

    # Name of the sample that is currently being merged
    self._sample: str | None = None

    # Placeholders; the real objects only exist once a merge has started
    self.current_components = cast("Components", None)
    self.writer = cast(H5Writer, None)
62
+
63
def add_jet_flavour_label(self, jets: np.ndarray, component: Component) -> np.ndarray:
    """Attach an integer ``flavour_label`` field to the jets.

    Jets that already carry a ``flavour_label`` field are handed back
    untouched.

    Parameters
    ----------
    jets : np.ndarray
        Structured array with the jets and their variables
    component : Component
        Component the jets were read from; the position of its flavour in
        ``self.flavours`` is used as the label value

    Returns
    -------
    np.ndarray
        Structured array of the jets and their variables including the
        "flavour_label" field.
    """
    if "flavour_label" not in jets.dtype.names:
        # One i4 label per jet, identical for the whole component
        flavour_idx = self.flavours.index(component.flavour)
        labels = np.full(len(jets), flavour_idx, dtype=[("flavour_label", "i4")])
        jets = join_structured_arrays([jets, labels])

    return jets
87
+
88
+ def _part_fname(self, sample: str | None, file_idx: int) -> Path:
89
+ """Construct the exact output filename for a given part index.
90
+
91
+ Parameters
92
+ ----------
93
+ sample : str | None
94
+ Name of the output sample
95
+ file_idx : int
96
+ Iterator number of the output file
97
+
98
+ Returns
99
+ -------
100
+ Path
101
+ Final path to the output file
102
+ """
103
+ # Get base path of the output
104
+ fname = Path(self.config.out_fname)
105
+
106
+ # Append the sample name to the file name
107
+ if sample:
108
+ fname = path_append(fname, sample)
109
+
110
+ # Define the suffix for the file (including the iterator number)
111
+ suffix = f"{self.file_tag}_{file_idx:03d}"
112
+
113
+ # Return the final path
114
+ return fname.with_name(f"{fname.stem}_{suffix}{fname.suffix}")
115
+
116
+ def _expected_rows_for_part(self, part_idx: int) -> int:
117
+ """Return the expected number of rows for part `part_idx` given total_jets and split size.
118
+
119
+ Parameters
120
+ ----------
121
+ part_idx : int
122
+ Iterator number of the file
123
+
124
+ Returns
125
+ -------
126
+ int
127
+ Expected number of rows for the given partial file
128
+ """
129
+ # Assert that the final output file will be splitted
130
+ assert self.num_jets_per_output_file is not None
131
+
132
+ # Remaining jets starting at this part
133
+ start = part_idx * int(self.num_jets_per_output_file)
134
+ remaining = max(0, self.total_jets - start)
135
+
136
+ return min(int(self.num_jets_per_output_file), remaining)
137
+
138
def _is_part_valid(self, sample: str | None, part_idx: int) -> bool:
    """Heuristically validate that a part file is complete and consistent.

    Checks:
      - File can be opened.
      - The jets dataset (``self.jets_name``) exists.
      - Every dataset named in ``self.base_shapes`` that is present in the
        file has the same first-dimension length as the jets dataset
        (datasets missing from the file are tolerated).
      - In split mode, that length equals the expected rows for this part.

    Parameters
    ----------
    sample : str | None
        Name of the sample
    part_idx : int
        Iterator number of the file

    Returns
    -------
    bool
        True if the partial file passed all checks, False otherwise.
    """
    fname = self._part_fname(sample, part_idx)

    try:
        with h5py.File(fname, "r") as f:
            # The jets dataset is the anchor every other dataset is compared to
            if self.jets_name not in f:
                log.warning(f"Missing dataset '{self.jets_name}' in {fname}")
                return False

            obs_len = f[self.jets_name].shape[0]

            # All expected datasets that are present should match obs_len
            for nm in self.base_shapes:
                if nm in f and f[nm].shape[0] != obs_len:
                    log.warning(f"Dataset '{nm}' len={f[nm].shape[0]} != {obs_len} in {fname}")
                    return False

            # Compare with expected rows for this part (if split mode)
            if self.num_jets_per_output_file is not None:
                exp_len = self._expected_rows_for_part(part_idx)
                if obs_len != exp_len:
                    log.warning(
                        f"Part {part_idx:03d} in {fname} has {obs_len} rows, "
                        f"expected {exp_len}."
                    )
                    return False

            return True

    except OSError as e:
        # Raised when the file cannot be opened, e.g. truncated/half-written
        log.warning(f"Failed to open {fname}: {e}")
        return False
213
+
214
+ def _detect_and_clean_completed_parts(self, sample: str | None) -> int:
215
+ """Detect valid and invalid parts and remove the invalid path.
216
+
217
+ Count contiguous **valid** parts; if the first invalid part is found and
218
+ `auto_fix_parts` is enabled, delete it so resume can overwrite it.
219
+
220
+ Parameters
221
+ ----------
222
+ sample : str | None
223
+ Name of the sample to use
224
+
225
+ Returns
226
+ -------
227
+ int
228
+ The index of the first missing/invalid part.
229
+ """
230
+ # Check that multiple output files should be created
231
+ if self.num_jets_per_output_file is None:
232
+ return 0
233
+
234
+ # Define a counter
235
+ idx = 0
236
+
237
+ # Loop over the files
238
+ while True:
239
+ # Get the name of the file
240
+ fname = self._part_fname(sample, idx)
241
+
242
+ # If the file doesn't exist, stop the loop
243
+ if not fname.exists():
244
+ break
245
+
246
+ # Validate the existing file
247
+ if not self._is_part_valid(sample, idx):
248
+ if self.auto_fix_parts:
249
+ try:
250
+ fname.unlink()
251
+ log.warning(
252
+ f"[bold yellow]Deleted corrupted part: {fname.name} "
253
+ f"(will be re-written)."
254
+ )
255
+ except OSError as e:
256
+ log.error(f"Could not delete corrupted part {fname}: {e}")
257
+
258
+ # Stop at the first invalid file (deleted or left as-is)
259
+ break
260
+
261
+ # Go to next file
262
+ idx += 1
263
+
264
+ # Return the idx number of the new file
265
+ return idx
266
+
267
+ class _NullWriter:
268
+ """A minimal writer that discards data while tracking how much would be written."""
269
+
270
+ def __init__(self, capacity: int):
271
+ self.num_jets = capacity
272
+ self.num_written = 0
273
+
274
+ def write(self, batch: dict[str, np.ndarray]) -> None:
275
+ """Count the number of jets that would be written.
276
+
277
+ Parameters
278
+ ----------
279
+ batch : dict[str, np.ndarray]
280
+ Dict with the batches
281
+ """
282
+ # advance by the leading dimension of any array (they are aligned)
283
+ if not batch:
284
+ return
285
+ any_arr = next(iter(batch.values()))
286
+ k = len(any_arr)
287
+ self.num_written = min(self.num_written + k, self.num_jets)
288
+
289
+ def add_attr(self, *args, **kwargs):
290
+ """Skip the attribute addition."""
291
+ pass
292
+
293
+ def close(self):
294
+ """Skip the close."""
295
+ pass
296
+
297
def _open_writer(
    self,
    sample: str | None,
    jets_in_file: int,
    file_idx: int,
    components: Components,
) -> None:
    """Create `self.writer` for the next output file and attach static attributes.

    Parameters
    ----------
    sample : str | None
        Sample name (``None`` for the "train/val test" merge).
    jets_in_file : int
        Capacity of the new file (= leading dimension of every dataset).
    file_idx : int
        Running part index (0, 1, 2, …); used only for the filename suffix.
    components : Components
        The `Components` object we are currently merging; its `unique_jets`,
        `jet_counts` and `dsids` are stored as file attributes.
    """
    if self.num_jets_per_output_file is not None:
        # Split mode: use the same helper as the resume validation so the
        # name that is written and the name that is checked cannot diverge
        fname = self._part_fname(sample, file_idx)
    else:
        # Single-file mode: configured name, plus the sample name if given
        fname = Path(self.config.out_fname)
        if sample:
            fname = path_append(fname, sample)

    # Adjust shapes to the capacity of this file
    shapes = {name: (jets_in_file,) + shape[1:] for name, shape in self.base_shapes.items()}

    # Instantiate an H5Writer
    self.writer = H5Writer(
        fname,
        self.dtypes,
        shapes,
        add_flavour_label=self.jets_name,
        jets_name=self.jets_name,
        num_jets=jets_in_file,
    )

    # Copy the metadata attributes
    self.writer.add_attr(
        "flavour_label",
        [f.name for f in self.flavours],
        self.jets_name,
    )
    self.writer.add_attr("unique_jets", components.unique_jets)
    self.writer.add_attr("jet_counts", json.dumps(components.jet_counts))
    self.writer.add_attr("dsids", str(components.dsids))
    self.writer.add_attr("config", json.dumps(self.config.config))
    self.writer.add_attr("upp_hash", self.config.git_hash)

    # Log for debugging
    log.debug(f"Setup merge output at {self.writer.dst}")
354
+
355
def write_chunk(self, components: Components) -> int:
    """Read one chunk, merge and write it to disk (or discard in fast-forward).

    Read one batch from every active component, merge them and write
    them to disk. If the batch does not fit into the current file it is
    spread over as many consecutive files as needed.

    Parameters
    ----------
    components : Components
        Components that are to be written.

    Returns
    -------
    int
        The number of jets that were written (or discarded while
        fast-forwarding) by this call. 0 means nothing is left to do, so
        the caller can stop its loop.
    """
    merged: dict[str, np.ndarray] = {}

    # 1) Use pending tail first (only set across fast-forward boundary)
    from_pending = self._ff_pending is not None
    if from_pending:
        merged = self._ff_pending
        self._ff_pending = None
    else:
        # 2) Otherwise, pull one batch from every active component
        for component in components:
            try:
                # shallow copy because we will add a field
                batch = copy(next(component.stream))
                batch[self.jets_name] = self.add_jet_flavour_label(
                    jets=batch[self.jets_name], component=component
                )
            except StopIteration:
                component.complete = True

            if component.complete:
                continue

            # Merge this component's arrays into the running dict
            for name, array in batch.items():
                if name not in merged:
                    merged[name] = array
                else:
                    merged[name] = np.concatenate([merged[name], array])

    # If nothing merged and all components are exhausted -> stop
    if not merged and all(c.complete for c in components):
        return 0

    # Apply track selections. A pending tail already went through the
    # selectors in the call that produced it, so it is not selected again.
    if not from_pending:
        for name in self.variables.variables:
            if name == self.jets_name:
                continue
            if selector := self.variables.selectors.get(name):
                merged[name] = selector(merged[name])

    merged_len = len(merged[self.jets_name])

    if self._fast_forwarding:
        # Limit consumption to the remaining discard quota
        capacity_left = self.writer.num_jets - self.writer.num_written
        if merged_len <= capacity_left:
            self.writer.write(merged)
            self.jets_written += merged_len
            return merged_len

        head = {n: a[:capacity_left] for n, a in merged.items()}
        tail = {n: a[capacity_left:] for n, a in merged.items()}
        self.writer.write(head)
        self._ff_pending = tail  # keep remainder for next iteration
        self.jets_written += capacity_left
        return capacity_left

    # Regular writing: place the batch slice by slice, rolling over to a new
    # file every time the current one is full
    offset = 0
    while offset < merged_len:
        capacity_left = self.writer.num_jets - self.writer.num_written

        if capacity_left == 0:
            remaining_total = self.total_jets - (self.jets_written + offset)

            # Quit writing when no jets are left to write. The current
            # writer is left open; the caller closes it.
            if remaining_total <= 0:
                break

            # Close the filled file and open the next one
            self.writer.close()
            self._file_idx += 1
            next_file_size = (
                min(self.num_jets_per_output_file, remaining_total)
                if self.num_jets_per_output_file
                else remaining_total
            )
            self._open_writer(
                self._sample,
                next_file_size,
                self._file_idx,
                self.current_components,
            )

            # Recompute free space in the freshly-opened file
            capacity_left = self.writer.num_jets - self.writer.num_written

        take = min(capacity_left, merged_len - offset)
        if take <= 0:
            # A writer without free space right after opening cannot make
            # progress; stop instead of looping forever
            break

        if offset == 0 and take == merged_len:
            # whole batch fits
            self.writer.write(merged)
        else:
            stop = offset + take
            self.writer.write({n: a[offset:stop] for n, a in merged.items()})

        offset += take

    # Counter used for the progress bar and the size of the next file
    self.jets_written += offset
    return offset
489
+
490
def write_components(self, sample: str | None, components: Components) -> None:
    """Merge *components* into one or more HDF5 files.

    If ``self.num_jets_per_output_file`` is ``None`` exactly one output file
    is written. Otherwise the function keeps opening new `H5Writer`s
    whenever the current file reaches that jet limit. All heavy work
    (splitting batches, rolling files) is handled in ``self.write_chunk``.

    When resuming is enabled and valid parts from an earlier run are found,
    the jets belonging to those parts are read and discarded before writing
    continues with the first missing part. If the valid parts already hold
    every jet, nothing is written at all.

    Parameters
    ----------
    sample : str | None
        Name of the sample
    components : Components
        Components that are to be written
    """
    # Prepare every Component's reader
    for component in components:
        batch_size = self.batch_size * component.num_jets // components.num_jets + 1
        component.setup_reader(
            batch_size,
            fname=component.out_path,
            jets_name=self.jets_name,
        )
        component.stream = component.reader.stream(
            self.variables.combined(),
            component.reader.num_jets,
        )
        component.complete = False

    # Cache dtype / base shapes once (re-used for every new file)
    self.dtypes = components[0].reader.dtypes(self.variables.combined())
    self.base_shapes = components[0].reader.shapes(components.num_jets, self.variables.keys())

    # Bookkeeping shared with write_chunk
    self.total_jets = components.num_jets
    self.jets_written = 0
    self._file_idx = 0
    self._sample = sample
    self.current_components = components

    # Never carry fast-forward state over from a previous sample
    self._ff_pending = None
    self._fast_forwarding = False

    # Auto-resume: detect contiguous valid parts; delete a corrupt last part if found
    split_size = self.num_jets_per_output_file
    resume_parts = 0
    if self.resume and isinstance(split_size, int):
        resume_parts = self._detect_and_clean_completed_parts(sample)

    if resume_parts and isinstance(split_size, int):
        to_discard = resume_parts * int(split_size)

        # The valid parts already cover every jet: opening another writer
        # would request a file with zero or negative capacity
        if to_discard >= self.total_jets:
            self._file_idx = resume_parts
            self.jets_written = self.total_jets
            log.info(
                f"[bold green]All {resume_parts} part(s) already complete; "
                f"nothing left to merge."
            )
            return

        log.info(
            f"[bold yellow]Resuming merge: found {resume_parts} completed part(s); "
            f"skipping first {to_discard:,} jets."
        )

        # Use a NullWriter to pre-consume data via the exact same logic
        self._fast_forwarding = True
        self.writer = self._NullWriter(to_discard)
        try:
            while self.jets_written < to_discard:
                consumed = self.write_chunk(components)
                if consumed == 0:
                    break
        finally:
            # Leave fast-forward mode even if reading fails
            self.writer.close()
            self._fast_forwarding = False

        # Align counters with the next missing part
        self._file_idx = resume_parts
        self.jets_written = to_discard

    # Decide capacity of the first real file
    remaining_total = self.total_jets - self.jets_written
    first_file_size = (
        min(self.num_jets_per_output_file, remaining_total)
        if self.num_jets_per_output_file
        else remaining_total
    )

    # Open the first output file
    self._open_writer(sample, first_file_size, self._file_idx, components)

    # Main merge loop with progress
    with ProgressBar() as progress:
        task = progress.add_task(
            f"[green]Merging {components.num_jets:,} jets...",
            total=components.num_jets,
        )
        # Show the jets that were skipped by the resume as already done
        if self.jets_written:
            progress.update(task, advance=self.jets_written)

        while True:
            n = self.write_chunk(components)
            if not n:
                break
            progress.update(task, advance=n)

    # Close Writer
    self.writer.close()
    label = "merged" if sample is None else sample
    log.info(f"[bold green]Finished merging {components.num_jets:,} {label} jets!")
586
+
587
def run(self):
    """Merge all components, either in one go or once per sample.

    Everything is merged together unless the configuration is a test
    configuration whose samples are not to be merged; in that case the
    components are grouped by sample and every group is merged on its own.
    """
    banner = " Running Merging "
    log.info(f"[bold green]{banner:-^100}")

    merge_everything = not self.config.is_test or self.config.merge_test_samples
    groups = [(None, self.components)] if merge_everything else self.components.groupby_sample()

    for sample_name, group in groups:
        self.write_components(sample_name, group)
@@ -1,322 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import json
4
- import logging as log
5
- from copy import copy
6
- from pathlib import Path
7
- from typing import TYPE_CHECKING, cast
8
-
9
- import numpy as np
10
- from ftag.hdf5 import H5Writer, join_structured_arrays
11
-
12
- from upp.utils.logger import ProgressBar
13
- from upp.utils.tools import path_append
14
-
15
- if TYPE_CHECKING: # pragma: no cover
16
- from upp.classes.components import Component, Components
17
- from upp.classes.preprocessing_config import PreprocessingConfig
18
-
19
-
20
class Merging:
    """Merge the different components/regions into the final output file(s).

    One batch is read from every component per step; the batches are
    concatenated and written through an ``H5Writer``. When
    ``num_jets_per_output_file`` is set, the output rolls over to a new file
    whenever the current one is full.
    """

    def __init__(self, config: PreprocessingConfig):
        """Store the configuration and initialise the bookkeeping state.

        Parameters
        ----------
        config : PreprocessingConfig
            Configuration providing the components, variables, batch size,
            jets name and the optional number of jets per output file.
        """
        self.config = config
        self.components = config.components
        self.variables = config.variables
        self.batch_size = config.batch_size
        self.jets_name = config.jets_name
        self.rng = np.random.default_rng(42)
        self.flavours = self.components.flavours
        self.num_jets_per_output_file = config.num_jets_per_output_file
        self.file_tag = "split"

        # Filled in write_components and re-used for every newly opened file
        self.dtypes: dict[str, np.dtype] = {}
        self.base_shapes: dict[str, tuple[int, ...]] = {}

        # Setup all the jet counters
        self._file_idx: int = 0
        self.total_jets: int = 0
        self.jets_written: int = 0

        # Setup the sample string
        self._sample: str | None = None

        # Use cast because we cannot init Components/H5Writer
        self.current_components = cast("Components", None)
        self.writer = cast(H5Writer, None)

    def add_jet_flavour_label(self, jets: np.ndarray, component: Component) -> np.ndarray:
        """Add the jet flavour label to the jets.

        If already present, jets will be returned without any changes.

        Parameters
        ----------
        jets : np.ndarray
            Structured array with the jets and their variables
        component : Component
            Component the jets were read from; the position of its flavour
            in ``self.flavours`` is used as the integer label.

        Returns
        -------
        np.ndarray
            Structured array of the jets and their variables with the
            "flavour_label" added.
        """
        if "flavour_label" in jets.dtype.names:
            return jets
        int_label = self.flavours.index(component.flavour)
        label_array = np.full(len(jets), int_label, dtype=[("flavour_label", "i4")])

        return join_structured_arrays([jets, label_array])

    def _next_file_size(self, remaining_total: int) -> int:
        """Return the capacity of the next output file.

        Parameters
        ----------
        remaining_total : int
            Number of jets that still have to be written.

        Returns
        -------
        int
            ``remaining_total`` capped at ``num_jets_per_output_file`` when
            that limit is set, otherwise ``remaining_total`` itself.
        """
        if self.num_jets_per_output_file:
            return min(self.num_jets_per_output_file, remaining_total)
        return remaining_total

    def _open_writer(
        self,
        sample: str | None,
        jets_in_file: int,
        file_idx: int,
        components: Components,
    ) -> None:
        """Create `self.writer` for the next output file and attach all static attributes.

        Parameters
        ----------
        sample : str | None
            Sample name (``None`` for the "train/val test" merge).
        jets_in_file : int
            Capacity of the new file (= leading dimension of every dataset).
        file_idx : int
            Running part index (0, 1, 2, …); used only for the filename suffix.
        components : Components
            The `Components` object we are currently merging needed for `jet_counts`, etc.
        """
        # Construct the filename
        fname = Path(self.config.out_fname)

        if sample:
            fname = path_append(fname, sample)

        if self.num_jets_per_output_file is not None:
            suffix = f"{self.file_tag}_{file_idx:03d}"
            fname = fname.with_name(f"{fname.stem}_{suffix}{fname.suffix}")

        # Adjust shapes to the capacity of this file
        shapes = {name: (jets_in_file,) + shape[1:] for name, shape in self.base_shapes.items()}

        # Instantiate an H5Writer
        self.writer = H5Writer(
            fname,
            self.dtypes,
            shapes,
            add_flavour_label=self.jets_name,
            jets_name=self.jets_name,
            num_jets=jets_in_file,
        )

        # Copy the metadata attributes
        self.writer.add_attr(
            "flavour_label",
            [f.name for f in self.flavours],
            self.jets_name,
        )
        self.writer.add_attr("unique_jets", components.unique_jets)
        self.writer.add_attr("jet_counts", json.dumps(components.jet_counts))
        self.writer.add_attr("dsids", str(components.dsids))
        self.writer.add_attr("config", json.dumps(self.config.config))
        self.writer.add_attr("upp_hash", self.config.git_hash)

        # Log for debugging
        log.debug(f"Setup merge output at {self.writer.dst}")

    def write_chunk(self, components: Components) -> int:
        """Read one chunk, merge and write it to disk.

        Read one batch from every active component, merge them and write
        them to disk. If the batch does not fit into the current file it is
        split across as many output files as needed.

        Parameters
        ----------
        components : Components
            Components that are to be written.

        Returns
        -------
        int
            The number of jets that were written to disk in this call.
            When all components are exhausted, or no output capacity is
            left, the function returns 0 so that the caller can stop its
            loop.
        """
        # Init a merged dict
        merged: dict[str, np.ndarray] = {}

        # Loop over components
        for component in components:
            try:
                # shallow copy because we will add a field
                batch = copy(next(component.stream))
            except StopIteration:
                component.complete = True

            if component.complete:
                continue

            batch[self.jets_name] = self.add_jet_flavour_label(
                jets=batch[self.jets_name], component=component
            )

            # Merge this component's arrays into the running dict
            for name, array in batch.items():
                if name not in merged:
                    merged[name] = array
                else:
                    merged[name] = np.concatenate([merged[name], array])

        # Stop if there is nothing more to read
        if all(c.complete for c in components):
            return 0

        # Apply track selections
        for name in self.variables.variables:
            if name == self.jets_name:
                continue
            if selector := self.variables.selectors.get(name):
                merged[name] = selector(merged[name])

        merged_len = len(merged[self.jets_name])

        # Place the batch slice by slice; a batch may span more than two
        # files when it is larger than the per-file jet limit.
        offset = 0
        while offset < merged_len:
            capacity_left = self.writer.num_jets - self.writer.num_written

            if capacity_left == 0:
                remaining_total = self.total_jets - (self.jets_written + offset)

                # Nothing left to allocate: keep the current writer open so
                # that the caller closes it exactly once.
                if remaining_total <= 0:
                    break

                # Close the filled file and open the next one
                self.writer.close()
                self._file_idx += 1
                self._open_writer(
                    self._sample,
                    self._next_file_size(remaining_total),
                    self._file_idx,
                    self.current_components,
                )
                continue

            take = min(capacity_left, merged_len - offset)
            if take == merged_len:
                # whole batch fits
                self.writer.write(merged)
            else:
                stop = offset + take
                self.writer.write({n: a[offset:stop] for n, a in merged.items()})
            offset += take

        # Used by the caller to advance the progress bar
        self.jets_written += offset
        return offset

    def write_components(self, sample: str | None, components: Components) -> None:
        """Merge *components* into one or more HDF5 files.

        If ``self.num_jets_per_output_file`` is ``None`` exactly one output
        file is written. Otherwise new `H5Writer`s are opened whenever the
        current file reaches that jet limit. All heavy work (splitting
        batches, rolling files) is handled in ``self.write_chunk``.

        Parameters
        ----------
        sample : str | None
            Sample name appended to the output filename, ``None`` for none.
        components : Components
            Components that are to be merged.
        """
        # Prepare every Component's reader
        for component in components:
            batch_size = self.batch_size * component.num_jets // components.num_jets + 1
            component.setup_reader(
                batch_size,
                fname=component.out_path,
                jets_name=self.jets_name,
            )
            component.stream = component.reader.stream(
                self.variables.combined(),
                component.reader.num_jets,
            )
            component.complete = False

        # Cache dtype / base shapes once (re-used for every new file)
        self.dtypes = components[0].reader.dtypes(self.variables.combined())
        self.base_shapes = components[0].reader.shapes(components.num_jets, self.variables.keys())

        # Bookkeeping shared with write_chunk
        self.total_jets = components.num_jets
        self.jets_written = 0
        self._file_idx = 0
        self._sample = sample
        self.current_components = components

        # Open the first output file
        first_file_size = self._next_file_size(self.total_jets)
        self._open_writer(sample, first_file_size, self._file_idx, components)

        # Main merge loop (progress bar unchanged)
        with ProgressBar() as progress:
            task = progress.add_task(
                f"[green]Merging {components.num_jets:,} jets...",
                total=components.num_jets,
            )
            while True:
                n = self.write_chunk(components)
                if not n:
                    break
                progress.update(task, advance=n)

        # Close Writer
        self.writer.close()
        label = "merged" if sample is None else sample
        log.info(f"[bold green]Finished merging {components.num_jets:,} {label} jets!")

    def run(self):
        """Run merging of the components."""
        title = " Running Merging "
        log.info(f"[bold green]{title:-^100}")

        if not self.config.is_test or self.config.merge_test_samples:
            components = [(None, self.components)]
        else:
            components = self.components.groupby_sample()

        for sample, comps in components:
            self.write_components(sample, comps)