umami-preprocessing 0.2.3__tar.gz → 0.2.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {umami_preprocessing-0.2.3 → umami_preprocessing-0.2.5}/PKG-INFO +11 -11
- {umami_preprocessing-0.2.3 → umami_preprocessing-0.2.5}/pyproject.toml +11 -11
- {umami_preprocessing-0.2.3 → umami_preprocessing-0.2.5}/umami_preprocessing.egg-info/PKG-INFO +11 -11
- umami_preprocessing-0.2.5/umami_preprocessing.egg-info/requires.txt +15 -0
- {umami_preprocessing-0.2.3 → umami_preprocessing-0.2.5}/upp/__init__.py +1 -1
- {umami_preprocessing-0.2.3 → umami_preprocessing-0.2.5}/upp/classes/preprocessing_config.py +34 -16
- umami_preprocessing-0.2.5/upp/logger.py +76 -0
- {umami_preprocessing-0.2.3 → umami_preprocessing-0.2.5}/upp/main.py +4 -4
- umami_preprocessing-0.2.5/upp/stages/merging.py +308 -0
- {umami_preprocessing-0.2.3 → umami_preprocessing-0.2.5}/upp/stages/normalisation.py +9 -2
- umami_preprocessing-0.2.5/upp/stages/plot.py +203 -0
- {umami_preprocessing-0.2.3 → umami_preprocessing-0.2.5}/upp/stages/resampling.py +99 -31
- umami_preprocessing-0.2.3/umami_preprocessing.egg-info/requires.txt +0 -15
- umami_preprocessing-0.2.3/upp/logger.py +0 -39
- umami_preprocessing-0.2.3/upp/stages/merging.py +0 -176
- umami_preprocessing-0.2.3/upp/stages/plot.py +0 -325
- {umami_preprocessing-0.2.3 → umami_preprocessing-0.2.5}/README.md +0 -0
- {umami_preprocessing-0.2.3 → umami_preprocessing-0.2.5}/setup.cfg +0 -0
- {umami_preprocessing-0.2.3 → umami_preprocessing-0.2.5}/umami_preprocessing.egg-info/SOURCES.txt +0 -0
- {umami_preprocessing-0.2.3 → umami_preprocessing-0.2.5}/umami_preprocessing.egg-info/dependency_links.txt +0 -0
- {umami_preprocessing-0.2.3 → umami_preprocessing-0.2.5}/umami_preprocessing.egg-info/entry_points.txt +0 -0
- {umami_preprocessing-0.2.3 → umami_preprocessing-0.2.5}/umami_preprocessing.egg-info/top_level.txt +0 -0
- {umami_preprocessing-0.2.3 → umami_preprocessing-0.2.5}/upp/classes/__init__.py +0 -0
- {umami_preprocessing-0.2.3 → umami_preprocessing-0.2.5}/upp/classes/components.py +0 -0
- {umami_preprocessing-0.2.3 → umami_preprocessing-0.2.5}/upp/classes/region.py +0 -0
- {umami_preprocessing-0.2.3 → umami_preprocessing-0.2.5}/upp/classes/resampling_config.py +0 -0
- {umami_preprocessing-0.2.3 → umami_preprocessing-0.2.5}/upp/classes/variable_config.py +0 -0
- {umami_preprocessing-0.2.3 → umami_preprocessing-0.2.5}/upp/stages/__init__.py +0 -0
- {umami_preprocessing-0.2.3 → umami_preprocessing-0.2.5}/upp/stages/hist.py +0 -0
- {umami_preprocessing-0.2.3 → umami_preprocessing-0.2.5}/upp/stages/interpolation.py +0 -0
- {umami_preprocessing-0.2.3 → umami_preprocessing-0.2.5}/upp/utils.py +0 -0
|
@@ -1,25 +1,25 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
2
|
Name: umami-preprocessing
|
|
3
|
-
Version: 0.2.3
|
|
3
|
+
Version: 0.2.5
|
|
4
4
|
Summary: Preprocessing for jet tagging
|
|
5
5
|
License: MIT
|
|
6
6
|
Project-URL: Homepage, https://github.com/umami-hep/umami-preprocessing
|
|
7
7
|
Requires-Python: <3.12,>=3.8
|
|
8
8
|
Description-Content-Type: text/markdown
|
|
9
|
+
Requires-Dist: atlas-ftag-tools==0.2.14
|
|
10
|
+
Requires-Dist: dotmap==1.3.30
|
|
11
|
+
Requires-Dist: puma-hep==0.4.9
|
|
9
12
|
Requires-Dist: pyyaml-include==1.3
|
|
10
|
-
Requires-Dist: PyYAML
|
|
13
|
+
Requires-Dist: PyYAML>=6.0.1
|
|
11
14
|
Requires-Dist: rich==12.6.0
|
|
12
|
-
Requires-Dist: scipy
|
|
13
|
-
Requires-Dist: puma-hep==0.4.2
|
|
14
|
-
Requires-Dist: atlas-ftag-tools==0.2.8
|
|
15
|
-
Requires-Dist: dotmap==1.3.30
|
|
15
|
+
Requires-Dist: scipy>=1.15.2
|
|
16
16
|
Provides-Extra: dev
|
|
17
|
-
Requires-Dist:
|
|
18
|
-
Requires-Dist: mypy==1.5.1; extra == "dev"
|
|
17
|
+
Requires-Dist: mypy==1.11.2; extra == "dev"
|
|
19
18
|
Requires-Dist: pre-commit==3.5.0; extra == "dev"
|
|
20
|
-
Requires-Dist: pytest>=
|
|
19
|
+
Requires-Dist: pytest-cov>=4.0.0; extra == "dev"
|
|
21
20
|
Requires-Dist: pytest-mock==3.11.1; extra == "dev"
|
|
22
|
-
Requires-Dist: pytest
|
|
21
|
+
Requires-Dist: pytest>=7.2.2; extra == "dev"
|
|
22
|
+
Requires-Dist: ruff==0.6.2; extra == "dev"
|
|
23
23
|
|
|
24
24
|
[](https://github.com/psf/black)
|
|
25
25
|
[](https://codecov.io/gh/umami-hep/umami-preprocessing)
|
|
@@ -7,23 +7,23 @@ readme = "README.md"
|
|
|
7
7
|
requires-python = "<3.12,>=3.8"
|
|
8
8
|
|
|
9
9
|
dependencies = [
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
10
|
+
"atlas-ftag-tools==0.2.14",
|
|
11
|
+
"dotmap==1.3.30",
|
|
12
|
+
"puma-hep==0.4.9",
|
|
13
|
+
"pyyaml-include==1.3",
|
|
14
|
+
"PyYAML>=6.0.1",
|
|
15
|
+
"rich==12.6.0",
|
|
16
|
+
"scipy>=1.15.2",
|
|
17
17
|
]
|
|
18
18
|
|
|
19
19
|
[project.optional-dependencies]
|
|
20
20
|
dev = [
|
|
21
|
-
"
|
|
22
|
-
"mypy==1.5.1",
|
|
21
|
+
"mypy==1.11.2",
|
|
23
22
|
"pre-commit==3.5.0",
|
|
24
|
-
"pytest>=
|
|
23
|
+
"pytest-cov>=4.0.0",
|
|
25
24
|
"pytest-mock==3.11.1",
|
|
26
|
-
"pytest
|
|
25
|
+
"pytest>=7.2.2",
|
|
26
|
+
"ruff==0.6.2",
|
|
27
27
|
]
|
|
28
28
|
|
|
29
29
|
[project.urls]
|
{umami_preprocessing-0.2.3 → umami_preprocessing-0.2.5}/umami_preprocessing.egg-info/PKG-INFO
RENAMED
|
@@ -1,25 +1,25 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
2
|
Name: umami-preprocessing
|
|
3
|
-
Version: 0.2.3
|
|
3
|
+
Version: 0.2.5
|
|
4
4
|
Summary: Preprocessing for jet tagging
|
|
5
5
|
License: MIT
|
|
6
6
|
Project-URL: Homepage, https://github.com/umami-hep/umami-preprocessing
|
|
7
7
|
Requires-Python: <3.12,>=3.8
|
|
8
8
|
Description-Content-Type: text/markdown
|
|
9
|
+
Requires-Dist: atlas-ftag-tools==0.2.14
|
|
10
|
+
Requires-Dist: dotmap==1.3.30
|
|
11
|
+
Requires-Dist: puma-hep==0.4.9
|
|
9
12
|
Requires-Dist: pyyaml-include==1.3
|
|
10
|
-
Requires-Dist: PyYAML
|
|
13
|
+
Requires-Dist: PyYAML>=6.0.1
|
|
11
14
|
Requires-Dist: rich==12.6.0
|
|
12
|
-
Requires-Dist: scipy
|
|
13
|
-
Requires-Dist: puma-hep==0.4.2
|
|
14
|
-
Requires-Dist: atlas-ftag-tools==0.2.8
|
|
15
|
-
Requires-Dist: dotmap==1.3.30
|
|
15
|
+
Requires-Dist: scipy>=1.15.2
|
|
16
16
|
Provides-Extra: dev
|
|
17
|
-
Requires-Dist:
|
|
18
|
-
Requires-Dist: mypy==1.5.1; extra == "dev"
|
|
17
|
+
Requires-Dist: mypy==1.11.2; extra == "dev"
|
|
19
18
|
Requires-Dist: pre-commit==3.5.0; extra == "dev"
|
|
20
|
-
Requires-Dist: pytest>=
|
|
19
|
+
Requires-Dist: pytest-cov>=4.0.0; extra == "dev"
|
|
21
20
|
Requires-Dist: pytest-mock==3.11.1; extra == "dev"
|
|
22
|
-
Requires-Dist: pytest
|
|
21
|
+
Requires-Dist: pytest>=7.2.2; extra == "dev"
|
|
22
|
+
Requires-Dist: ruff==0.6.2; extra == "dev"
|
|
23
23
|
|
|
24
24
|
[](https://github.com/psf/black)
|
|
25
25
|
[](https://codecov.io/gh/umami-hep/umami-preprocessing)
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
atlas-ftag-tools==0.2.14
|
|
2
|
+
dotmap==1.3.30
|
|
3
|
+
puma-hep==0.4.9
|
|
4
|
+
pyyaml-include==1.3
|
|
5
|
+
PyYAML>=6.0.1
|
|
6
|
+
rich==12.6.0
|
|
7
|
+
scipy>=1.15.2
|
|
8
|
+
|
|
9
|
+
[dev]
|
|
10
|
+
mypy==1.11.2
|
|
11
|
+
pre-commit==3.5.0
|
|
12
|
+
pytest-cov>=4.0.0
|
|
13
|
+
pytest-mock==3.11.1
|
|
14
|
+
pytest>=7.2.2
|
|
15
|
+
ruff==0.6.2
|
|
@@ -53,42 +53,56 @@ class PreprocessingConfig:
|
|
|
53
53
|
|
|
54
54
|
Parameters
|
|
55
55
|
----------
|
|
56
|
+
config_path : Path
|
|
57
|
+
Path to the config yaml file that is used. Does not need to be set in config.
|
|
58
|
+
split : Split
|
|
59
|
+
For which part the preprocessing is run. Either train, val or test. This needs
|
|
60
|
+
to be set as a command line argument when running the program. Does not need
|
|
61
|
+
to be set in config.
|
|
62
|
+
config : dict
|
|
63
|
+
Dict with the loaded config. Does not need to be set in config.
|
|
56
64
|
base_dir : Path
|
|
57
65
|
Base directory for all other paths.
|
|
58
|
-
ntuple_dir : Path
|
|
66
|
+
ntuple_dir : Path, optional
|
|
59
67
|
Directory containing the input h5 ntuples. If a relative path is given, it is
|
|
60
|
-
interpreted as relative to base_dir.
|
|
61
|
-
components_dir : Path
|
|
68
|
+
interpreted as relative to base_dir. By default Path("ntuples")
|
|
69
|
+
components_dir : Path, optional
|
|
62
70
|
Directory for intermediate component files. If a relative path is given, it is
|
|
63
|
-
interpreted as relative to base_dir.
|
|
64
|
-
out_dir : Path
|
|
71
|
+
interpreted as relative to base_dir. By default Path("components")
|
|
72
|
+
out_dir : Path, optional
|
|
65
73
|
Directory for output files. If a relative path is given, it is interpreted as
|
|
66
|
-
relative to base_dir.
|
|
67
|
-
out_fname : Path
|
|
68
|
-
Filename stem for the output files.
|
|
69
|
-
batch_size : int
|
|
74
|
+
relative to base_dir. By default Path("output")
|
|
75
|
+
out_fname : Path, optional
|
|
76
|
+
Filename stem for the output files. By default Path("pp_output.h5")
|
|
77
|
+
batch_size : int, optional
|
|
70
78
|
Batch size for the preprocessing. For each batch select
|
|
71
79
|
`sampling_fraction*batch_size_after_cuts`. It is recommended to choose high batch sizes
|
|
72
80
|
especially for the `countup` method to achieve best agreement of target and resampled
|
|
73
|
-
distributions.
|
|
74
|
-
num_jets_estimate : int
|
|
81
|
+
distributions. By default 100_000
|
|
82
|
+
num_jets_estimate : int, optional
|
|
75
83
|
Any of the following `num_jets_estimate_*` arguments that are not specified will default to this value.
|
|
76
84
|
Is equal to 1_000_000 by default.
|
|
77
|
-
num_jets_estimate_available : int
|
|
85
|
+
num_jets_estimate_available : int, optional
|
|
78
86
|
A subsample taken from the whole sample to estimate the number of jets after the cuts.
|
|
79
87
|
Please keep this number high in order to not get a Poisson error of more than 5%.
|
|
80
88
|
If time allows you can use -1 to get a precise number of jets and not just an estimate
|
|
81
89
|
although it will be slow for large datasets. Is equal to num_jets_estimate by default.
|
|
82
|
-
num_jets_estimate_hist : int
|
|
90
|
+
num_jets_estimate_hist : int, optional
|
|
83
91
|
Number of jets of each flavour that are used to construct histograms for probability
|
|
84
92
|
density function estimation. Larger numbers give a better quality estimate of the pdfs.
|
|
85
93
|
Is equal to num_jets_estimate by default.
|
|
86
|
-
num_jets_estimate_norm : int
|
|
94
|
+
num_jets_estimate_norm : int, optional
|
|
87
95
|
Number of jets of each flavour that are used to estimate shifting and scaling during
|
|
88
96
|
normalisation step. Larger numbers give better quality estimates.
|
|
89
97
|
Is equal to num_jets_estimate by default.
|
|
90
|
-
|
|
91
|
-
|
|
98
|
+
num_jets_estimate_plotting : int, optional
|
|
99
|
+
Number of jets of each flavour used for plotting the initial and the final resampling
|
|
100
|
+
variable distributions. Larger numbers give a better estimate of the full distributions.
|
|
101
|
+
Is equal to num_jets_estimate by default.
|
|
102
|
+
merge_test_samples : bool, optional
|
|
103
|
+
Merge the test samples of the different processes into one file. By default False.
|
|
104
|
+
jets_name : str, optional
|
|
105
|
+
Name of the jets dataset in the input file. By default "jets".
|
|
92
106
|
"""
|
|
93
107
|
|
|
94
108
|
config_path: Path
|
|
@@ -104,9 +118,11 @@ class PreprocessingConfig:
|
|
|
104
118
|
num_jets_estimate_available: int | None = None
|
|
105
119
|
num_jets_estimate_hist: int | None = None
|
|
106
120
|
num_jets_estimate_norm: int | None = None
|
|
121
|
+
num_jets_estimate_plotting: int | None = None
|
|
107
122
|
merge_test_samples: bool = False
|
|
108
123
|
jets_name: str = "jets"
|
|
109
124
|
flavour_config: Path | None = None
|
|
125
|
+
num_jets_per_output_file: int | None = None
|
|
110
126
|
|
|
111
127
|
def __post_init__(self):
|
|
112
128
|
# postprocess paths
|
|
@@ -117,6 +133,8 @@ class PreprocessingConfig:
|
|
|
117
133
|
self.num_jets_estimate_hist = self.num_jets_estimate
|
|
118
134
|
if self.num_jets_estimate_norm is None:
|
|
119
135
|
self.num_jets_estimate_norm = self.num_jets_estimate
|
|
136
|
+
if self.num_jets_estimate_plotting is None:
|
|
137
|
+
self.num_jets_estimate_plotting = self.num_jets_estimate
|
|
120
138
|
|
|
121
139
|
for field in dataclasses.fields(self):
|
|
122
140
|
if field.type == "Path" and field.name != "out_fname" and field.name != "base_dir":
|
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
import sys
|
|
5
|
+
from functools import partial
|
|
6
|
+
|
|
7
|
+
from rich.console import Console
|
|
8
|
+
from rich.logging import RichHandler
|
|
9
|
+
from rich.progress import (
|
|
10
|
+
BarColumn,
|
|
11
|
+
Progress,
|
|
12
|
+
TextColumn,
|
|
13
|
+
TimeElapsedColumn,
|
|
14
|
+
TimeRemainingColumn,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
# Detect if the program is executed in an interactive terminal
|
|
18
|
+
_IS_TTY = sys.stderr.isatty()
|
|
19
|
+
|
|
20
|
+
# One console object is reused everywhere so that Rich keeps a consistent idea
|
|
21
|
+
# of whether it may emit ANSI control codes / animations.
|
|
22
|
+
_console = Console(
|
|
23
|
+
width=100,
|
|
24
|
+
force_terminal=_IS_TTY,
|
|
25
|
+
force_interactive=_IS_TTY,
|
|
26
|
+
no_color=not _IS_TTY,
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
# Template for the progress bar
|
|
30
|
+
ProgressBar = partial(
|
|
31
|
+
Progress,
|
|
32
|
+
TextColumn("[task.description]{task.description}"),
|
|
33
|
+
BarColumn(),
|
|
34
|
+
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
|
|
35
|
+
TextColumn("•"),
|
|
36
|
+
TimeRemainingColumn(),
|
|
37
|
+
TextColumn("•"),
|
|
38
|
+
TimeElapsedColumn(),
|
|
39
|
+
refresh_per_second=1 if _IS_TTY else 0.05,
|
|
40
|
+
speed_estimate_period=30 if _IS_TTY else 120,
|
|
41
|
+
console=_console,
|
|
42
|
+
disable=not _IS_TTY,
|
|
43
|
+
transient=_IS_TTY,
|
|
44
|
+
)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
# Helper to set up the logger
|
|
48
|
+
def setup_logger(level: str = "INFO"):
|
|
49
|
+
"""Set up the logger.
|
|
50
|
+
|
|
51
|
+
Configure Rich logging so that colourful / interactive output is used when
|
|
52
|
+
the program is attached to a terminal and plain text is written when it is
|
|
53
|
+
executed under a batch system such as Slurm (where stdout / stderr are files).
|
|
54
|
+
"""
|
|
55
|
+
FORMAT = "%(message)s"
|
|
56
|
+
|
|
57
|
+
# In a batch job we create a console that never emits colour codes.
|
|
58
|
+
console = None
|
|
59
|
+
if not _IS_TTY:
|
|
60
|
+
console = Console(
|
|
61
|
+
width=120,
|
|
62
|
+
force_terminal=False,
|
|
63
|
+
force_interactive=False,
|
|
64
|
+
no_color=True,
|
|
65
|
+
)
|
|
66
|
+
|
|
67
|
+
handler = RichHandler(
|
|
68
|
+
show_time=False,
|
|
69
|
+
show_path=False,
|
|
70
|
+
markup=True,
|
|
71
|
+
rich_tracebacks=True,
|
|
72
|
+
console=console,
|
|
73
|
+
)
|
|
74
|
+
|
|
75
|
+
logging.basicConfig(level=level, format=FORMAT, handlers=[handler])
|
|
76
|
+
return logging
|
|
@@ -21,7 +21,7 @@ from upp.logger import setup_logger
|
|
|
21
21
|
from upp.stages.hist import create_histograms
|
|
22
22
|
from upp.stages.merging import Merging
|
|
23
23
|
from upp.stages.normalisation import Normalisation
|
|
24
|
-
from upp.stages.plot import
|
|
24
|
+
from upp.stages.plot import plot_resampling_dists
|
|
25
25
|
from upp.stages.resampling import Resampling
|
|
26
26
|
|
|
27
27
|
|
|
@@ -79,7 +79,7 @@ def run_pp(args) -> None:
|
|
|
79
79
|
# run the resampling
|
|
80
80
|
if args.resample:
|
|
81
81
|
resampling = Resampling(config)
|
|
82
|
-
resampling.run(region=args.region)
|
|
82
|
+
resampling.run(region=args.region, component=args.component)
|
|
83
83
|
|
|
84
84
|
# run the merging
|
|
85
85
|
if args.merge:
|
|
@@ -95,8 +95,8 @@ def run_pp(args) -> None:
|
|
|
95
95
|
if args.plot:
|
|
96
96
|
title = " Plotting "
|
|
97
97
|
log.info(f"[bold green]{title:-^100}")
|
|
98
|
-
|
|
99
|
-
|
|
98
|
+
plot_resampling_dists(config=config, stage="initial")
|
|
99
|
+
plot_resampling_dists(config=config, stage=args.split)
|
|
100
100
|
|
|
101
101
|
# print end info
|
|
102
102
|
end = datetime.now()
|
|
@@ -0,0 +1,308 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import logging as log
|
|
5
|
+
from copy import copy
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import TYPE_CHECKING
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
from ftag.hdf5 import H5Writer, join_structured_arrays
|
|
11
|
+
|
|
12
|
+
from upp.logger import ProgressBar
|
|
13
|
+
from upp.utils import path_append
|
|
14
|
+
|
|
15
|
+
if TYPE_CHECKING: # pragma: no cover
|
|
16
|
+
from upp.classes.components import Component, Components
|
|
17
|
+
from upp.classes.preprocessing_config import PreprocessingConfig
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class Merging:
|
|
21
|
+
"""Merging Class to merge different components/regions."""
|
|
22
|
+
|
|
23
|
+
def __init__(self, config: PreprocessingConfig):
|
|
24
|
+
"""Init the Merging class instance.
|
|
25
|
+
|
|
26
|
+
Parameters
|
|
27
|
+
----------
|
|
28
|
+
config : PreprocessingConfig
|
|
29
|
+
Loaded preprocessing config as a PreprocessingConfig instance
|
|
30
|
+
"""
|
|
31
|
+
self.config = config
|
|
32
|
+
self.components = config.components
|
|
33
|
+
self.variables = config.variables
|
|
34
|
+
self.batch_size = config.batch_size
|
|
35
|
+
self.jets_name = config.jets_name
|
|
36
|
+
self.rng = np.random.default_rng(42)
|
|
37
|
+
self.flavours = self.components.flavours
|
|
38
|
+
self.num_jets_per_output_file = config.num_jets_per_output_file
|
|
39
|
+
self.file_tag = "split"
|
|
40
|
+
|
|
41
|
+
def add_jet_flavour_label(self, jets: np.ndarray, component: Component) -> np.ndarray:
|
|
42
|
+
"""Add the jet flavour label to the jets.
|
|
43
|
+
|
|
44
|
+
If already present, jets will be returned without any changes.
|
|
45
|
+
|
|
46
|
+
Parameters
|
|
47
|
+
----------
|
|
48
|
+
jets : np.ndarray
|
|
49
|
+
Structured array with the jets and their variables
|
|
50
|
+
component : Component
|
|
51
|
+
Component instance whose flavour is used to determine the label
|
|
52
|
+
|
|
53
|
+
Returns
|
|
54
|
+
-------
|
|
55
|
+
np.ndarray
|
|
56
|
+
Structured array of the jets and their variables with the
|
|
57
|
+
"flavour_label" added.
|
|
58
|
+
"""
|
|
59
|
+
if "flavour_label" in jets.dtype.names:
|
|
60
|
+
return jets
|
|
61
|
+
int_label = self.flavours.index(component.flavour)
|
|
62
|
+
label_array = np.full(len(jets), int_label, dtype=[("flavour_label", "i4")])
|
|
63
|
+
|
|
64
|
+
return join_structured_arrays([jets, label_array])
|
|
65
|
+
|
|
66
|
+
def _open_writer(
|
|
67
|
+
self,
|
|
68
|
+
sample: str | None,
|
|
69
|
+
jets_in_file: int,
|
|
70
|
+
file_idx: int,
|
|
71
|
+
components: Components,
|
|
72
|
+
) -> None:
|
|
73
|
+
"""Create `self.writer` for the next output file and attach all static attributes.
|
|
74
|
+
|
|
75
|
+
Parameters
|
|
76
|
+
----------
|
|
77
|
+
sample
|
|
78
|
+
Sample name (``None`` for the "train/val test" merge).
|
|
79
|
+
jets_in_file
|
|
80
|
+
Capacity of the new file (= leading dimension of every dataset).
|
|
81
|
+
file_idx
|
|
82
|
+
Running part index (0, 1, 2, …); used only for the filename suffix.
|
|
83
|
+
components
|
|
84
|
+
The `Components` object we are currently merging needed for `jet_counts`, etc.
|
|
85
|
+
"""
|
|
86
|
+
# Construct the filename
|
|
87
|
+
fname = Path(self.config.out_fname)
|
|
88
|
+
|
|
89
|
+
if sample:
|
|
90
|
+
fname = path_append(fname, sample)
|
|
91
|
+
|
|
92
|
+
if self.num_jets_per_output_file is not None:
|
|
93
|
+
suffix = f"{self.file_tag}_{file_idx:03d}"
|
|
94
|
+
fname = fname.with_name(f"{fname.stem}_{suffix}{fname.suffix}")
|
|
95
|
+
|
|
96
|
+
# Adjust shapes to the capacity of this file
|
|
97
|
+
shapes = {name: (jets_in_file,) + shape[1:] for name, shape in self.base_shapes.items()}
|
|
98
|
+
|
|
99
|
+
# Instantiate an H5Writer
|
|
100
|
+
self.writer = H5Writer(
|
|
101
|
+
fname,
|
|
102
|
+
self.dtypes,
|
|
103
|
+
shapes,
|
|
104
|
+
add_flavour_label=self.jets_name,
|
|
105
|
+
jets_name=self.jets_name,
|
|
106
|
+
num_jets=jets_in_file,
|
|
107
|
+
)
|
|
108
|
+
|
|
109
|
+
# Copy the metadata attributes
|
|
110
|
+
self.writer.add_attr(
|
|
111
|
+
"flavour_label",
|
|
112
|
+
[f.name for f in self.flavours],
|
|
113
|
+
self.jets_name,
|
|
114
|
+
)
|
|
115
|
+
self.writer.add_attr("unique_jets", components.unique_jets)
|
|
116
|
+
self.writer.add_attr("jet_counts", json.dumps(components.jet_counts))
|
|
117
|
+
self.writer.add_attr("dsids", str(components.dsids))
|
|
118
|
+
self.writer.add_attr("config", json.dumps(self.config.config))
|
|
119
|
+
self.writer.add_attr("upp_hash", self.config.git_hash)
|
|
120
|
+
|
|
121
|
+
# Log for debugging
|
|
122
|
+
log.debug(f"Setup merge output at {self.writer.dst}")
|
|
123
|
+
|
|
124
|
+
def write_chunk(self, components: Components) -> int:
|
|
125
|
+
"""Read one chunk, merge and write it to disk.
|
|
126
|
+
|
|
127
|
+
Read one batch from every active component, merge them and write
|
|
128
|
+
them to disk. If the batch does not fit into the current file it is
|
|
129
|
+
split across files transparently.
|
|
130
|
+
|
|
131
|
+
Returns
|
|
132
|
+
-------
|
|
133
|
+
int
|
|
134
|
+
The number of jets that were consumed from the components
|
|
135
|
+
(== written to disk). When all components are exhausted the
|
|
136
|
+
function returns 0 so that the caller can stop its loop.
|
|
137
|
+
"""
|
|
138
|
+
# Init a merged dict
|
|
139
|
+
merged: dict[str, np.ndarray] = {}
|
|
140
|
+
|
|
141
|
+
# Loop over components
|
|
142
|
+
for component in components:
|
|
143
|
+
try:
|
|
144
|
+
# shallow copy because we will add a field
|
|
145
|
+
batch = copy(next(component.stream))
|
|
146
|
+
batch[self.jets_name] = self.add_jet_flavour_label(
|
|
147
|
+
jets=batch[self.jets_name], component=component
|
|
148
|
+
)
|
|
149
|
+
except StopIteration:
|
|
150
|
+
component.complete = True
|
|
151
|
+
|
|
152
|
+
if component.complete:
|
|
153
|
+
continue
|
|
154
|
+
|
|
155
|
+
# Merge this component's arrays into the running dict
|
|
156
|
+
for name, array in batch.items():
|
|
157
|
+
if name not in merged:
|
|
158
|
+
merged[name] = array
|
|
159
|
+
else:
|
|
160
|
+
merged[name] = np.concatenate([merged[name], array])
|
|
161
|
+
|
|
162
|
+
# Stop if there is nothing more to read
|
|
163
|
+
if all(c.complete for c in components):
|
|
164
|
+
return 0
|
|
165
|
+
|
|
166
|
+
# Apply track selections
|
|
167
|
+
for name in self.variables.variables:
|
|
168
|
+
if name == self.jets_name:
|
|
169
|
+
continue
|
|
170
|
+
if selector := self.variables.selectors.get(name):
|
|
171
|
+
merged[name] = selector(merged[name])
|
|
172
|
+
|
|
173
|
+
# Get the total length of jets from the batch and how much
|
|
174
|
+
# capacity is left in the file
|
|
175
|
+
merged_len = len(merged[self.jets_name])
|
|
176
|
+
capacity_left = self.writer.num_jets - self.writer.num_written
|
|
177
|
+
|
|
178
|
+
# Check if the capacity of the given file is already zero
|
|
179
|
+
if capacity_left == 0:
|
|
180
|
+
# close the filled file
|
|
181
|
+
self.writer.close()
|
|
182
|
+
|
|
183
|
+
# open the next one
|
|
184
|
+
self._file_idx += 1
|
|
185
|
+
remaining_total = self.total_jets - self.jets_written
|
|
186
|
+
|
|
187
|
+
# Quit writing when no jets are left to write
|
|
188
|
+
if remaining_total == 0:
|
|
189
|
+
return 0
|
|
190
|
+
|
|
191
|
+
next_file_size = (
|
|
192
|
+
min(self.num_jets_per_output_file, remaining_total)
|
|
193
|
+
if self.num_jets_per_output_file
|
|
194
|
+
else remaining_total
|
|
195
|
+
)
|
|
196
|
+
self._open_writer(
|
|
197
|
+
self._sample,
|
|
198
|
+
next_file_size,
|
|
199
|
+
self._file_idx,
|
|
200
|
+
self.current_components,
|
|
201
|
+
)
|
|
202
|
+
|
|
203
|
+
# Recompute free space in the freshly-opened file
|
|
204
|
+
capacity_left = self.writer.num_jets - self.writer.num_written
|
|
205
|
+
|
|
206
|
+
# Check if the whole batch fits into the file
|
|
207
|
+
if merged_len <= capacity_left:
|
|
208
|
+
# whole batch fits
|
|
209
|
+
self.writer.write(merged)
|
|
210
|
+
|
|
211
|
+
else:
|
|
212
|
+
# Write the *head* that still fits into the present file
|
|
213
|
+
head = {n: a[:capacity_left] for n, a in merged.items()}
|
|
214
|
+
self.writer.write(head)
|
|
215
|
+
self.writer.close()
|
|
216
|
+
|
|
217
|
+
# Open a fresh file sized for the remaining jets
|
|
218
|
+
self._file_idx += 1
|
|
219
|
+
remaining_total = self.total_jets - (self.jets_written + capacity_left)
|
|
220
|
+
next_file_size = (
|
|
221
|
+
min(self.num_jets_per_output_file, remaining_total)
|
|
222
|
+
if self.num_jets_per_output_file
|
|
223
|
+
else remaining_total
|
|
224
|
+
)
|
|
225
|
+
self._open_writer(self._sample, next_file_size, self._file_idx, self.current_components)
|
|
226
|
+
|
|
227
|
+
# Write the *tail* that goes into the new file
|
|
228
|
+
tail = {n: a[capacity_left:] for n, a in merged.items()}
|
|
229
|
+
self.writer.write(tail)
|
|
230
|
+
|
|
231
|
+
# Update the running count of written jets (the caller advances the progress bar)
|
|
232
|
+
self.jets_written += merged_len
|
|
233
|
+
return merged_len
|
|
234
|
+
|
|
235
|
+
def write_components(self, sample: str | None, components: Components) -> None:
|
|
236
|
+
"""
|
|
237
|
+
Merge *components* into one or more HDF5 files.
|
|
238
|
+
|
|
239
|
+
If ``self.num_jets_per_output_file`` is ``None`` the behaviour is identical to the
|
|
240
|
+
original implementation (exactly one output file). Otherwise the function
|
|
241
|
+
keeps opening new `H5Writer`s whenever the current file reaches that jet
|
|
242
|
+
limit. All heavy work (splitting batches, rolling files) is handled in
|
|
243
|
+
``self.write_chunk``.
|
|
244
|
+
"""
|
|
245
|
+
# Prepare every Component's reader
|
|
246
|
+
for component in components:
|
|
247
|
+
batch_size = self.batch_size * component.num_jets // components.num_jets + 1
|
|
248
|
+
component.setup_reader(
|
|
249
|
+
batch_size,
|
|
250
|
+
fname=component.out_path,
|
|
251
|
+
jets_name=self.jets_name,
|
|
252
|
+
)
|
|
253
|
+
component.stream = component.reader.stream(
|
|
254
|
+
self.variables.combined(),
|
|
255
|
+
component.reader.num_jets,
|
|
256
|
+
)
|
|
257
|
+
component.complete = False
|
|
258
|
+
|
|
259
|
+
# Cache dtype / base shapes once (re-used for every new file)
|
|
260
|
+
self.dtypes = components[0].reader.dtypes(self.variables.combined())
|
|
261
|
+
self.base_shapes = components[0].reader.shapes(components.num_jets, self.variables.keys())
|
|
262
|
+
|
|
263
|
+
# Bookkeeping shared with write_chunk
|
|
264
|
+
self.total_jets = components.num_jets
|
|
265
|
+
self.jets_written = 0
|
|
266
|
+
self._file_idx = 0
|
|
267
|
+
self._sample = sample
|
|
268
|
+
self.current_components = components
|
|
269
|
+
|
|
270
|
+
# decide capacity of the first file
|
|
271
|
+
first_file_size = (
|
|
272
|
+
min(self.num_jets_per_output_file, self.total_jets)
|
|
273
|
+
if self.num_jets_per_output_file
|
|
274
|
+
else self.total_jets
|
|
275
|
+
)
|
|
276
|
+
|
|
277
|
+
# Open the first output file
|
|
278
|
+
self._open_writer(sample, first_file_size, self._file_idx, components)
|
|
279
|
+
|
|
280
|
+
# Main merge loop (progress bar unchanged)
|
|
281
|
+
with ProgressBar() as progress:
|
|
282
|
+
task = progress.add_task(
|
|
283
|
+
f"[green]Merging {components.num_jets:,} jets...",
|
|
284
|
+
total=components.num_jets,
|
|
285
|
+
)
|
|
286
|
+
while True:
|
|
287
|
+
n = self.write_chunk(components)
|
|
288
|
+
if not n:
|
|
289
|
+
break
|
|
290
|
+
progress.update(task, advance=n)
|
|
291
|
+
|
|
292
|
+
# Close Writer
|
|
293
|
+
self.writer.close()
|
|
294
|
+
label = "merged" if sample is None else sample
|
|
295
|
+
log.info(f"[bold green]Finished merging {components.num_jets:,} {label} jets!")
|
|
296
|
+
|
|
297
|
+
def run(self):
|
|
298
|
+
"""Run merging of the components."""
|
|
299
|
+
title = " Running Merging "
|
|
300
|
+
log.info(f"[bold green]{title:-^100}")
|
|
301
|
+
|
|
302
|
+
if not self.config.is_test or self.config.merge_test_samples:
|
|
303
|
+
components = [(None, self.components)]
|
|
304
|
+
else:
|
|
305
|
+
components = self.components.groupby_sample()
|
|
306
|
+
|
|
307
|
+
for sample, comps in components:
|
|
308
|
+
self.write_components(sample, comps)
|
|
@@ -241,14 +241,21 @@ class Normalisation:
|
|
|
241
241
|
title = " Computing Normalisations "
|
|
242
242
|
log.info(f"[bold green]{title:-^100}")
|
|
243
243
|
|
|
244
|
+
# Get the correct output names if multiple output files were written
|
|
245
|
+
if self.config.num_jets_per_output_file:
|
|
246
|
+
fname = self.config.out_fname.parent / f"{self.config.out_fname.stem}*.h5"
|
|
247
|
+
|
|
248
|
+
else:
|
|
249
|
+
fname = self.config.out_fname
|
|
250
|
+
|
|
244
251
|
# Setup reader
|
|
245
252
|
reader = H5Reader(
|
|
246
|
-
|
|
253
|
+
fname,
|
|
247
254
|
self.config.batch_size,
|
|
248
255
|
precision="full",
|
|
249
256
|
jets_name=self.jets_name,
|
|
250
257
|
)
|
|
251
|
-
log.debug(f"Setup reader at: {
|
|
258
|
+
log.debug(f"Setup reader at: {fname}")
|
|
252
259
|
|
|
253
260
|
norm_dict = None
|
|
254
261
|
class_dict = None
|