umami-preprocessing 0.2.3__tar.gz → 0.2.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. {umami_preprocessing-0.2.3 → umami_preprocessing-0.2.5}/PKG-INFO +11 -11
  2. {umami_preprocessing-0.2.3 → umami_preprocessing-0.2.5}/pyproject.toml +11 -11
  3. {umami_preprocessing-0.2.3 → umami_preprocessing-0.2.5}/umami_preprocessing.egg-info/PKG-INFO +11 -11
  4. umami_preprocessing-0.2.5/umami_preprocessing.egg-info/requires.txt +15 -0
  5. {umami_preprocessing-0.2.3 → umami_preprocessing-0.2.5}/upp/__init__.py +1 -1
  6. {umami_preprocessing-0.2.3 → umami_preprocessing-0.2.5}/upp/classes/preprocessing_config.py +34 -16
  7. umami_preprocessing-0.2.5/upp/logger.py +76 -0
  8. {umami_preprocessing-0.2.3 → umami_preprocessing-0.2.5}/upp/main.py +4 -4
  9. umami_preprocessing-0.2.5/upp/stages/merging.py +308 -0
  10. {umami_preprocessing-0.2.3 → umami_preprocessing-0.2.5}/upp/stages/normalisation.py +9 -2
  11. umami_preprocessing-0.2.5/upp/stages/plot.py +203 -0
  12. {umami_preprocessing-0.2.3 → umami_preprocessing-0.2.5}/upp/stages/resampling.py +99 -31
  13. umami_preprocessing-0.2.3/umami_preprocessing.egg-info/requires.txt +0 -15
  14. umami_preprocessing-0.2.3/upp/logger.py +0 -39
  15. umami_preprocessing-0.2.3/upp/stages/merging.py +0 -176
  16. umami_preprocessing-0.2.3/upp/stages/plot.py +0 -325
  17. {umami_preprocessing-0.2.3 → umami_preprocessing-0.2.5}/README.md +0 -0
  18. {umami_preprocessing-0.2.3 → umami_preprocessing-0.2.5}/setup.cfg +0 -0
  19. {umami_preprocessing-0.2.3 → umami_preprocessing-0.2.5}/umami_preprocessing.egg-info/SOURCES.txt +0 -0
  20. {umami_preprocessing-0.2.3 → umami_preprocessing-0.2.5}/umami_preprocessing.egg-info/dependency_links.txt +0 -0
  21. {umami_preprocessing-0.2.3 → umami_preprocessing-0.2.5}/umami_preprocessing.egg-info/entry_points.txt +0 -0
  22. {umami_preprocessing-0.2.3 → umami_preprocessing-0.2.5}/umami_preprocessing.egg-info/top_level.txt +0 -0
  23. {umami_preprocessing-0.2.3 → umami_preprocessing-0.2.5}/upp/classes/__init__.py +0 -0
  24. {umami_preprocessing-0.2.3 → umami_preprocessing-0.2.5}/upp/classes/components.py +0 -0
  25. {umami_preprocessing-0.2.3 → umami_preprocessing-0.2.5}/upp/classes/region.py +0 -0
  26. {umami_preprocessing-0.2.3 → umami_preprocessing-0.2.5}/upp/classes/resampling_config.py +0 -0
  27. {umami_preprocessing-0.2.3 → umami_preprocessing-0.2.5}/upp/classes/variable_config.py +0 -0
  28. {umami_preprocessing-0.2.3 → umami_preprocessing-0.2.5}/upp/stages/__init__.py +0 -0
  29. {umami_preprocessing-0.2.3 → umami_preprocessing-0.2.5}/upp/stages/hist.py +0 -0
  30. {umami_preprocessing-0.2.3 → umami_preprocessing-0.2.5}/upp/stages/interpolation.py +0 -0
  31. {umami_preprocessing-0.2.3 → umami_preprocessing-0.2.5}/upp/utils.py +0 -0
@@ -1,25 +1,25 @@
1
- Metadata-Version: 2.2
1
+ Metadata-Version: 2.4
2
2
  Name: umami-preprocessing
3
- Version: 0.2.3
3
+ Version: 0.2.5
4
4
  Summary: Preprocessing for jet tagging
5
5
  License: MIT
6
6
  Project-URL: Homepage, https://github.com/umami-hep/umami-preprocessing
7
7
  Requires-Python: <3.12,>=3.8
8
8
  Description-Content-Type: text/markdown
9
+ Requires-Dist: atlas-ftag-tools==0.2.14
10
+ Requires-Dist: dotmap==1.3.30
11
+ Requires-Dist: puma-hep==0.4.9
9
12
  Requires-Dist: pyyaml-include==1.3
10
- Requires-Dist: PyYAML==6.0.1
13
+ Requires-Dist: PyYAML>=6.0.1
11
14
  Requires-Dist: rich==12.6.0
12
- Requires-Dist: scipy==1.10.1
13
- Requires-Dist: puma-hep==0.4.2
14
- Requires-Dist: atlas-ftag-tools==0.2.8
15
- Requires-Dist: dotmap==1.3.30
15
+ Requires-Dist: scipy>=1.15.2
16
16
  Provides-Extra: dev
17
- Requires-Dist: ruff==0.1.6; extra == "dev"
18
- Requires-Dist: mypy==1.5.1; extra == "dev"
17
+ Requires-Dist: mypy==1.11.2; extra == "dev"
19
18
  Requires-Dist: pre-commit==3.5.0; extra == "dev"
20
- Requires-Dist: pytest>=7.0.1; extra == "dev"
19
+ Requires-Dist: pytest-cov>=4.0.0; extra == "dev"
21
20
  Requires-Dist: pytest-mock==3.11.1; extra == "dev"
22
- Requires-Dist: pytest-cov>=3.0.0; extra == "dev"
21
+ Requires-Dist: pytest>=7.2.2; extra == "dev"
22
+ Requires-Dist: ruff==0.6.2; extra == "dev"
23
23
 
24
24
  [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
25
25
  [![codecov](https://codecov.io/gh/umami-hep/umami-preprocessing/graph/badge.svg?token=K8MJI20UZO)](https://codecov.io/gh/umami-hep/umami-preprocessing)
@@ -7,23 +7,23 @@ readme = "README.md"
7
7
  requires-python = "<3.12,>=3.8"
8
8
 
9
9
  dependencies = [
10
- "pyyaml-include==1.3",
11
- "PyYAML==6.0.1",
12
- "rich==12.6.0",
13
- "scipy==1.10.1",
14
- "puma-hep==0.4.2",
15
- "atlas-ftag-tools==0.2.8",
16
- "dotmap==1.3.30"
10
+ "atlas-ftag-tools==0.2.14",
11
+ "dotmap==1.3.30",
12
+ "puma-hep==0.4.9",
13
+ "pyyaml-include==1.3",
14
+ "PyYAML>=6.0.1",
15
+ "rich==12.6.0",
16
+ "scipy>=1.15.2",
17
17
  ]
18
18
 
19
19
  [project.optional-dependencies]
20
20
  dev = [
21
- "ruff==0.1.6",
22
- "mypy==1.5.1",
21
+ "mypy==1.11.2",
23
22
  "pre-commit==3.5.0",
24
- "pytest>=7.0.1",
23
+ "pytest-cov>=4.0.0",
25
24
  "pytest-mock==3.11.1",
26
- "pytest-cov>=3.0.0",
25
+ "pytest>=7.2.2",
26
+ "ruff==0.6.2",
27
27
  ]
28
28
 
29
29
  [project.urls]
@@ -1,25 +1,25 @@
1
- Metadata-Version: 2.2
1
+ Metadata-Version: 2.4
2
2
  Name: umami-preprocessing
3
- Version: 0.2.3
3
+ Version: 0.2.5
4
4
  Summary: Preprocessing for jet tagging
5
5
  License: MIT
6
6
  Project-URL: Homepage, https://github.com/umami-hep/umami-preprocessing
7
7
  Requires-Python: <3.12,>=3.8
8
8
  Description-Content-Type: text/markdown
9
+ Requires-Dist: atlas-ftag-tools==0.2.14
10
+ Requires-Dist: dotmap==1.3.30
11
+ Requires-Dist: puma-hep==0.4.9
9
12
  Requires-Dist: pyyaml-include==1.3
10
- Requires-Dist: PyYAML==6.0.1
13
+ Requires-Dist: PyYAML>=6.0.1
11
14
  Requires-Dist: rich==12.6.0
12
- Requires-Dist: scipy==1.10.1
13
- Requires-Dist: puma-hep==0.4.2
14
- Requires-Dist: atlas-ftag-tools==0.2.8
15
- Requires-Dist: dotmap==1.3.30
15
+ Requires-Dist: scipy>=1.15.2
16
16
  Provides-Extra: dev
17
- Requires-Dist: ruff==0.1.6; extra == "dev"
18
- Requires-Dist: mypy==1.5.1; extra == "dev"
17
+ Requires-Dist: mypy==1.11.2; extra == "dev"
19
18
  Requires-Dist: pre-commit==3.5.0; extra == "dev"
20
- Requires-Dist: pytest>=7.0.1; extra == "dev"
19
+ Requires-Dist: pytest-cov>=4.0.0; extra == "dev"
21
20
  Requires-Dist: pytest-mock==3.11.1; extra == "dev"
22
- Requires-Dist: pytest-cov>=3.0.0; extra == "dev"
21
+ Requires-Dist: pytest>=7.2.2; extra == "dev"
22
+ Requires-Dist: ruff==0.6.2; extra == "dev"
23
23
 
24
24
  [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
25
25
  [![codecov](https://codecov.io/gh/umami-hep/umami-preprocessing/graph/badge.svg?token=K8MJI20UZO)](https://codecov.io/gh/umami-hep/umami-preprocessing)
@@ -0,0 +1,15 @@
1
+ atlas-ftag-tools==0.2.14
2
+ dotmap==1.3.30
3
+ puma-hep==0.4.9
4
+ pyyaml-include==1.3
5
+ PyYAML>=6.0.1
6
+ rich==12.6.0
7
+ scipy>=1.15.2
8
+
9
+ [dev]
10
+ mypy==1.11.2
11
+ pre-commit==3.5.0
12
+ pytest-cov>=4.0.0
13
+ pytest-mock==3.11.1
14
+ pytest>=7.2.2
15
+ ruff==0.6.2
@@ -2,4 +2,4 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- __version__ = "v0.2.3"
5
+ __version__ = "v0.2.5"
@@ -53,42 +53,56 @@ class PreprocessingConfig:
53
53
 
54
54
  Parameters
55
55
  ----------
56
+ config_path : Path
57
+ Path to the config yaml file that is used. Does not need to be set in config.
58
+ split : Split
59
+ For which part the preprocessing is run. Either train, val or test. This needs
60
+ to be set as a command line argument when running the programm. Does not need
61
+ to be set in config.
62
+ config : dict
63
+ Dict with the loaded config. Does not need to be set in config.
56
64
  base_dir : Path
57
65
  Base directory for all other paths.
58
- ntuple_dir : Path
66
+ ntuple_dir : Path, optional
59
67
  Directory containing the input h5 ntuples. If a relative path is given, it is
60
- interpreted as relative to base_dir.
61
- components_dir : Path
68
+ interpreted as relative to base_dir. By default Path("ntuples")
69
+ components_dir : Path, optional
62
70
  Directory for intermediate component files. If a relative path is given, it is
63
- interpreted as relative to base_dir.
64
- out_dir : Path
71
+ interpreted as relative to base_dir. By default Path("components")
72
+ out_dir : Path, optional
65
73
  Directory for output files. If a relative path is given, it is interpreted as
66
- relative to base_dir.
67
- out_fname : Path
68
- Filename stem for the output files.
69
- batch_size : int
74
+ relative to base_dir. By default Path("output")
75
+ out_fname : Path, optional
76
+ Filename stem for the output files. By default Path("pp_output.h5")
77
+ batch_size : int, optional
70
78
  Batch size for the preprocessing. For each batch select
71
79
  `sampling_fraction*batch_size_after_cuts`. It is recommended to choose high batch sizes
72
80
  especially to the `countup` method to achive best agreement of target and resampled
73
- distributions.
74
- num_jets_estimate : int
81
+ distributions. By default 100_000
82
+ num_jets_estimate : int, optional
75
83
  Any of the further three arguments that are not specified will default to this value
76
84
  Is equal to 1_000_000 by default.
77
- num_jets_estimate_available : int | None
85
+ num_jets_estimate_available : int, optional
78
86
  A sabsample taken from the whole sample to estimate the number of jets after the cuts.
79
87
  Please keep this number high in order to not get poisson error of more then 5%.
80
88
  If time allows you can use -1 to get a precise number of jets and not just an estimate
81
89
  although it will be slow for large datasets. Is equal to num_jets_estimate by default.
82
- num_jets_estimate_hist : int
90
+ num_jets_estimate_hist : int, optional
83
91
  Number of jets of each flavour that are used to construct histograms for probability
84
92
  density function estimation. Larger numbers give a better quality estmate of the pdfs.
85
93
  Is equal to num_jets_estimate by default.
86
- num_jets_estimate_norm : int
94
+ num_jets_estimate_norm : int, optional
87
95
  Number of jets of each flavour that are used to estimate shifting and scaling during
88
96
  normalisation step. Larger numbers give a better quality estmates.
89
97
  Is equal to num_jets_estimate by default.
90
- jets_name : str
91
- Name of the jets dataset in the input file.
98
+ num_jets_estimate_plotting : int, optional
99
+ Number of jets of each flavour used for plotting the initial and the final resampling
100
+ variable distributions. Larger numbers give a better estimate of the full distributions.
101
+ Is equal to num_jets_estimate by default.
102
+ merge_test_samples : bool, optional
103
+ Merge the test samples of the different processes into one file. By default False.
104
+ jets_name : str, optional
105
+ Name of the jets dataset in the input file. By default "jets".
92
106
  """
93
107
 
94
108
  config_path: Path
@@ -104,9 +118,11 @@ class PreprocessingConfig:
104
118
  num_jets_estimate_available: int | None = None
105
119
  num_jets_estimate_hist: int | None = None
106
120
  num_jets_estimate_norm: int | None = None
121
+ num_jets_estimate_plotting: int | None = None
107
122
  merge_test_samples: bool = False
108
123
  jets_name: str = "jets"
109
124
  flavour_config: Path | None = None
125
+ num_jets_per_output_file: int | None = None
110
126
 
111
127
  def __post_init__(self):
112
128
  # postprocess paths
@@ -117,6 +133,8 @@ class PreprocessingConfig:
117
133
  self.num_jets_estimate_hist = self.num_jets_estimate
118
134
  if self.num_jets_estimate_norm is None:
119
135
  self.num_jets_estimate_norm = self.num_jets_estimate
136
+ if self.num_jets_estimate_plotting is None:
137
+ self.num_jets_estimate_plotting = self.num_jets_estimate
120
138
 
121
139
  for field in dataclasses.fields(self):
122
140
  if field.type == "Path" and field.name != "out_fname" and field.name != "base_dir":
@@ -0,0 +1,76 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ import sys
5
+ from functools import partial
6
+
7
+ from rich.console import Console
8
+ from rich.logging import RichHandler
9
+ from rich.progress import (
10
+ BarColumn,
11
+ Progress,
12
+ TextColumn,
13
+ TimeElapsedColumn,
14
+ TimeRemainingColumn,
15
+ )
16
+
17
+ # Detect if the program is executed in an interactive terminal
18
+ _IS_TTY = sys.stderr.isatty()
19
+
20
+ # One console object is reused everywhere so that Rich keeps a consistent idea
21
+ # of whether it may emit ANSI control codes / animations.
22
+ _console = Console(
23
+ width=100,
24
+ force_terminal=_IS_TTY,
25
+ force_interactive=_IS_TTY,
26
+ no_color=not _IS_TTY,
27
+ )
28
+
29
+ # Template for the progress bar
30
+ ProgressBar = partial(
31
+ Progress,
32
+ TextColumn("[task.description]{task.description}"),
33
+ BarColumn(),
34
+ TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
35
+ TextColumn("•"),
36
+ TimeRemainingColumn(),
37
+ TextColumn("•"),
38
+ TimeElapsedColumn(),
39
+ refresh_per_second=1 if _IS_TTY else 0.05,
40
+ speed_estimate_period=30 if _IS_TTY else 120,
41
+ console=_console,
42
+ disable=not _IS_TTY,
43
+ transient=_IS_TTY,
44
+ )
45
+
46
+
47
+ # Helper for setup the logger
48
+ def setup_logger(level: str = "INFO"):
49
+ """Set up the logger.
50
+
51
+ Configure Rich logging so that colourful / interactive output is used when
52
+ the program is attached to a terminal and plain text is written when it is
53
+ executed under a batch system such as Slurm (where stdout / stderr are files).
54
+ """
55
+ FORMAT = "%(message)s"
56
+
57
+ # In a batch job we create a console that never emits colour codes.
58
+ console = None
59
+ if not _IS_TTY:
60
+ console = Console(
61
+ width=120,
62
+ force_terminal=False,
63
+ force_interactive=False,
64
+ no_color=True,
65
+ )
66
+
67
+ handler = RichHandler(
68
+ show_time=False,
69
+ show_path=False,
70
+ markup=True,
71
+ rich_tracebacks=True,
72
+ console=console,
73
+ )
74
+
75
+ logging.basicConfig(level=level, format=FORMAT, handlers=[handler])
76
+ return logging
@@ -21,7 +21,7 @@ from upp.logger import setup_logger
21
21
  from upp.stages.hist import create_histograms
22
22
  from upp.stages.merging import Merging
23
23
  from upp.stages.normalisation import Normalisation
24
- from upp.stages.plot import plot_initial_resampling_dists, plot_resampled_dists
24
+ from upp.stages.plot import plot_resampling_dists
25
25
  from upp.stages.resampling import Resampling
26
26
 
27
27
 
@@ -79,7 +79,7 @@ def run_pp(args) -> None:
79
79
  # run the resampling
80
80
  if args.resample:
81
81
  resampling = Resampling(config)
82
- resampling.run(region=args.region)
82
+ resampling.run(region=args.region, component=args.component)
83
83
 
84
84
  # run the merging
85
85
  if args.merge:
@@ -95,8 +95,8 @@ def run_pp(args) -> None:
95
95
  if args.plot:
96
96
  title = " Plotting "
97
97
  log.info(f"[bold green]{title:-^100}")
98
- plot_initial_resampling_dists(config=config)
99
- plot_resampled_dists(config=config, stage=args.split)
98
+ plot_resampling_dists(config=config, stage="initial")
99
+ plot_resampling_dists(config=config, stage=args.split)
100
100
 
101
101
  # print end info
102
102
  end = datetime.now()
@@ -0,0 +1,308 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import logging as log
5
+ from copy import copy
6
+ from pathlib import Path
7
+ from typing import TYPE_CHECKING
8
+
9
+ import numpy as np
10
+ from ftag.hdf5 import H5Writer, join_structured_arrays
11
+
12
+ from upp.logger import ProgressBar
13
+ from upp.utils import path_append
14
+
15
+ if TYPE_CHECKING: # pragma: no cover
16
+ from upp.classes.components import Component, Components
17
+ from upp.classes.preprocessing_config import PreprocessingConfig
18
+
19
+
20
+ class Merging:
21
+ """Merging Class to merge different components/regions."""
22
+
23
+ def __init__(self, config: PreprocessingConfig):
24
+ """Init the Merging class instance.
25
+
26
+ Parameters
27
+ ----------
28
+ config : PreprocessingConfig
29
+ Loaded preprocessing config as a PreprocessingConfig instance
30
+ """
31
+ self.config = config
32
+ self.components = config.components
33
+ self.variables = config.variables
34
+ self.batch_size = config.batch_size
35
+ self.jets_name = config.jets_name
36
+ self.rng = np.random.default_rng(42)
37
+ self.flavours = self.components.flavours
38
+ self.num_jets_per_output_file = config.num_jets_per_output_file
39
+ self.file_tag = "split"
40
+
41
+ def add_jet_flavour_label(self, jets: np.ndarray, component: Component) -> np.ndarray:
42
+ """Add the jet flavour label to the jets.
43
+
44
+ If already present, jets will be returned without any changes.
45
+
46
+ Parameters
47
+ ----------
48
+ jets : np.ndarray
49
+ Structured array of with the jets and their variables
50
+ component : Component
51
+ Component instance of the
52
+
53
+ Returns
54
+ -------
55
+ np.ndarray
56
+ Structured array of the jets and their variables with the
57
+ "flavour_label" added.
58
+ """
59
+ if "flavour_label" in jets.dtype.names:
60
+ return jets
61
+ int_label = self.flavours.index(component.flavour)
62
+ label_array = np.full(len(jets), int_label, dtype=[("flavour_label", "i4")])
63
+
64
+ return join_structured_arrays([jets, label_array])
65
+
66
+ def _open_writer(
67
+ self,
68
+ sample: str | None,
69
+ jets_in_file: int,
70
+ file_idx: int,
71
+ components: Components,
72
+ ) -> None:
73
+ """Create `self.writer` for the next output file and attach all static attributes.
74
+
75
+ Parameters
76
+ ----------
77
+ sample
78
+ Sample name (``None`` for the "train/val test" merge).
79
+ jets_in_file
80
+ Capacity of the new file (= leading dimension of every dataset).
81
+ file_idx
82
+ Running part index (0, 1, 2, …); used only for the filename suffix.
83
+ components
84
+ The `Components` object we are currently merging needed for `jet_counts`, etc.
85
+ """
86
+ # Construct the filename
87
+ fname = Path(self.config.out_fname)
88
+
89
+ if sample:
90
+ fname = path_append(fname, sample)
91
+
92
+ if self.num_jets_per_output_file is not None:
93
+ suffix = f"{self.file_tag}_{file_idx:03d}"
94
+ fname = fname.with_name(f"{fname.stem}_{suffix}{fname.suffix}")
95
+
96
+ # Adjust shapes to the capacity of this file
97
+ shapes = {name: (jets_in_file,) + shape[1:] for name, shape in self.base_shapes.items()}
98
+
99
+ # Instantiate an H5Writer
100
+ self.writer = H5Writer(
101
+ fname,
102
+ self.dtypes,
103
+ shapes,
104
+ add_flavour_label=self.jets_name,
105
+ jets_name=self.jets_name,
106
+ num_jets=jets_in_file,
107
+ )
108
+
109
+ # Copy the metadata attributes
110
+ self.writer.add_attr(
111
+ "flavour_label",
112
+ [f.name for f in self.flavours],
113
+ self.jets_name,
114
+ )
115
+ self.writer.add_attr("unique_jets", components.unique_jets)
116
+ self.writer.add_attr("jet_counts", json.dumps(components.jet_counts))
117
+ self.writer.add_attr("dsids", str(components.dsids))
118
+ self.writer.add_attr("config", json.dumps(self.config.config))
119
+ self.writer.add_attr("upp_hash", self.config.git_hash)
120
+
121
+ # Log for debugging
122
+ log.debug(f"Setup merge output at {self.writer.dst}")
123
+
124
+ def write_chunk(self, components: Components) -> int:
125
+ """Read one chunk, merge and write it to disk.
126
+
127
+ Read one batch from every active component, merge them and write
128
+ them to disk. If the batch does not fit into the current file it is
129
+ split across files transparently.
130
+
131
+ Returns
132
+ -------
133
+ int
134
+ The number of jets that were consumed from the components
135
+ (== written to disk). When all components are exhausted the
136
+ function returns 0 so that the caller can stop its loop.
137
+ """
138
+ # Init a merged dict
139
+ merged: dict[str, np.ndarray] = {}
140
+
141
+ # Loop over components
142
+ for component in components:
143
+ try:
144
+ # shallow copy because we will add a field
145
+ batch = copy(next(component.stream))
146
+ batch[self.jets_name] = self.add_jet_flavour_label(
147
+ jets=batch[self.jets_name], component=component
148
+ )
149
+ except StopIteration:
150
+ component.complete = True
151
+
152
+ if component.complete:
153
+ continue
154
+
155
+ # Merge this component's arrays into the running dict
156
+ for name, array in batch.items():
157
+ if name not in merged:
158
+ merged[name] = array
159
+ else:
160
+ merged[name] = np.concatenate([merged[name], array])
161
+
162
+ # Stop if there is nothing more to read
163
+ if all(c.complete for c in components):
164
+ return 0
165
+
166
+ # Apply track selections
167
+ for name in self.variables.variables:
168
+ if name == self.jets_name:
169
+ continue
170
+ if selector := self.variables.selectors.get(name):
171
+ merged[name] = selector(merged[name])
172
+
173
+ # Get the total length of jets from the batch and how much
174
+ # capacity is left in the file
175
+ merged_len = len(merged[self.jets_name])
176
+ capacity_left = self.writer.num_jets - self.writer.num_written
177
+
178
+ # Check if the capacity of the given file is already zero
179
+ if capacity_left == 0:
180
+ # close the filled file
181
+ self.writer.close()
182
+
183
+ # open the next one
184
+ self._file_idx += 1
185
+ remaining_total = self.total_jets - self.jets_written
186
+
187
+ # Quit writing when no jets are left to write
188
+ if remaining_total == 0:
189
+ return 0
190
+
191
+ next_file_size = (
192
+ min(self.num_jets_per_output_file, remaining_total)
193
+ if self.num_jets_per_output_file
194
+ else remaining_total
195
+ )
196
+ self._open_writer(
197
+ self._sample,
198
+ next_file_size,
199
+ self._file_idx,
200
+ self.current_components,
201
+ )
202
+
203
+ # Recompute free space in the freshly-opened file
204
+ capacity_left = self.writer.num_jets - self.writer.num_written
205
+
206
+ # Check if the whole batch fits into the file
207
+ if merged_len <= capacity_left:
208
+ # whole batch fits
209
+ self.writer.write(merged)
210
+
211
+ else:
212
+ # Write the *head* that still fits into the present file
213
+ head = {n: a[:capacity_left] for n, a in merged.items()}
214
+ self.writer.write(head)
215
+ self.writer.close()
216
+
217
+ # Open a fresh file sized for the remaining jets
218
+ self._file_idx += 1
219
+ remaining_total = self.total_jets - (self.jets_written + capacity_left)
220
+ next_file_size = (
221
+ min(self.num_jets_per_output_file, remaining_total)
222
+ if self.num_jets_per_output_file
223
+ else remaining_total
224
+ )
225
+ self._open_writer(self._sample, next_file_size, self._file_idx, self.current_components)
226
+
227
+ # Write the *tail* that goes into the new file
228
+ tail = {n: a[capacity_left:] for n, a in merged.items()}
229
+ self.writer.write(tail)
230
+
231
+ # Updating the progress-bar
232
+ self.jets_written += merged_len
233
+ return merged_len
234
+
235
+ def write_components(self, sample: str | None, components: Components) -> None:
236
+ """
237
+ Merge *components* into one or more HDF5 files.
238
+
239
+ If ``self.num_jets_per_output_file`` is ``None`` the behaviour is identical to the
240
+ original implementation (exactly one output file). Otherwise the function
241
+ keeps opening new `H5Writer`s whenever the current file reaches that jet
242
+ limit. All heavy work (splitting batches, rolling files) is handled in
243
+ ``self.write_chunk``.
244
+ """
245
+ # Prepare every Component's reader
246
+ for component in components:
247
+ batch_size = self.batch_size * component.num_jets // components.num_jets + 1
248
+ component.setup_reader(
249
+ batch_size,
250
+ fname=component.out_path,
251
+ jets_name=self.jets_name,
252
+ )
253
+ component.stream = component.reader.stream(
254
+ self.variables.combined(),
255
+ component.reader.num_jets,
256
+ )
257
+ component.complete = False
258
+
259
+ # Cache dtype / base shapes once (re-used for every new file)
260
+ self.dtypes = components[0].reader.dtypes(self.variables.combined())
261
+ self.base_shapes = components[0].reader.shapes(components.num_jets, self.variables.keys())
262
+
263
+ # Bookkeeping shared with write_chunk
264
+ self.total_jets = components.num_jets
265
+ self.jets_written = 0
266
+ self._file_idx = 0
267
+ self._sample = sample
268
+ self.current_components = components
269
+
270
+ # decide capacity of the first file
271
+ first_file_size = (
272
+ min(self.num_jets_per_output_file, self.total_jets)
273
+ if self.num_jets_per_output_file
274
+ else self.total_jets
275
+ )
276
+
277
+ # Open the first output file
278
+ self._open_writer(sample, first_file_size, self._file_idx, components)
279
+
280
+ # Main merge loop (progress bar unchanged)
281
+ with ProgressBar() as progress:
282
+ task = progress.add_task(
283
+ f"[green]Merging {components.num_jets:,} jets...",
284
+ total=components.num_jets,
285
+ )
286
+ while True:
287
+ n = self.write_chunk(components)
288
+ if not n:
289
+ break
290
+ progress.update(task, advance=n)
291
+
292
+ # Close Writer
293
+ self.writer.close()
294
+ label = "merged" if sample is None else sample
295
+ log.info(f"[bold green]Finished merging {components.num_jets:,} {label} jets!")
296
+
297
+ def run(self):
298
+ """Run merging of the components."""
299
+ title = " Running Merging "
300
+ log.info(f"[bold green]{title:-^100}")
301
+
302
+ if not self.config.is_test or self.config.merge_test_samples:
303
+ components = [(None, self.components)]
304
+ else:
305
+ components = self.components.groupby_sample()
306
+
307
+ for sample, comps in components:
308
+ self.write_components(sample, comps)
@@ -241,14 +241,21 @@ class Normalisation:
241
241
  title = " Computing Normalisations "
242
242
  log.info(f"[bold green]{title:-^100}")
243
243
 
244
+ # Get the correct output names if multiple output files were written
245
+ if self.config.num_jets_per_output_file:
246
+ fname = self.config.out_fname.parent / f"{self.config.out_fname.stem}*.h5"
247
+
248
+ else:
249
+ fname = self.config.out_fname
250
+
244
251
  # Setup reader
245
252
  reader = H5Reader(
246
- self.config.out_fname,
253
+ fname,
247
254
  self.config.batch_size,
248
255
  precision="full",
249
256
  jets_name=self.jets_name,
250
257
  )
251
- log.debug(f"Setup reader at: {self.config.out_fname}")
258
+ log.debug(f"Setup reader at: {fname}")
252
259
 
253
260
  norm_dict = None
254
261
  class_dict = None