umami-preprocessing 0.2.6__tar.gz → 0.2.7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {umami_preprocessing-0.2.6/umami_preprocessing.egg-info → umami_preprocessing-0.2.7}/PKG-INFO +22 -15
- {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7}/pyproject.toml +21 -14
- {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7/umami_preprocessing.egg-info}/PKG-INFO +22 -15
- {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7}/umami_preprocessing.egg-info/SOURCES.txt +4 -0
- umami_preprocessing-0.2.7/umami_preprocessing.egg-info/requires.txt +20 -0
- {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7}/upp/__init__.py +1 -1
- {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7}/upp/classes/components.py +3 -2
- {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7}/upp/classes/preprocessing_config.py +64 -12
- umami_preprocessing-0.2.7/upp/classes/reweight_config.py +78 -0
- {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7}/upp/main.py +81 -5
- {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7}/upp/stages/normalisation.py +4 -2
- {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7}/upp/stages/resampling.py +2 -1
- umami_preprocessing-0.2.7/upp/stages/reweight.py +465 -0
- umami_preprocessing-0.2.7/upp/stages/rw_merge.py +314 -0
- umami_preprocessing-0.2.7/upp/stages/split_containers.py +386 -0
- umami_preprocessing-0.2.6/umami_preprocessing.egg-info/requires.txt +0 -15
- {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7}/LICENSE +0 -0
- {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7}/MANIFEST.in +0 -0
- {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7}/README.md +0 -0
- {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7}/setup.cfg +0 -0
- {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7}/umami_preprocessing.egg-info/dependency_links.txt +0 -0
- {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7}/umami_preprocessing.egg-info/entry_points.txt +0 -0
- {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7}/umami_preprocessing.egg-info/top_level.txt +0 -0
- {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7}/upp/classes/__init__.py +0 -0
- {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7}/upp/classes/region.py +0 -0
- {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7}/upp/classes/resampling_config.py +0 -0
- {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7}/upp/classes/variable_config.py +0 -0
- {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7}/upp/stages/__init__.py +0 -0
- {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7}/upp/stages/hist.py +0 -0
- {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7}/upp/stages/interpolation.py +0 -0
- {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7}/upp/stages/merging.py +0 -0
- {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7}/upp/stages/plot.py +0 -0
- {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7}/upp/utils/__init__.py +0 -0
- {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7}/upp/utils/check_input_samples.py +0 -0
- {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7}/upp/utils/logger.py +0 -0
- {umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7}/upp/utils/tools.py +0 -0
{umami_preprocessing-0.2.6/umami_preprocessing.egg-info → umami_preprocessing-0.2.7}/PKG-INFO
RENAMED
|
@@ -1,26 +1,33 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: umami-preprocessing
|
|
3
|
-
Version: 0.2.6
|
|
4
|
-
Summary: Preprocessing
|
|
3
|
+
Version: 0.2.7
|
|
4
|
+
Summary: ATLAS Flavour Tagging Preprocessing - Umami PreProcessing (UPP)
|
|
5
|
+
Author: Alexander Froch
|
|
5
6
|
License: MIT
|
|
6
7
|
Project-URL: Homepage, https://github.com/umami-hep/umami-preprocessing
|
|
7
|
-
|
|
8
|
+
Project-URL: Issue Tracker, https://github.com/umami-hep/umami-preprocessing/issues
|
|
9
|
+
Requires-Python: <3.12,>=3.10
|
|
8
10
|
Description-Content-Type: text/markdown
|
|
9
11
|
License-File: LICENSE
|
|
10
|
-
Requires-Dist: atlas-ftag-tools==0.2.
|
|
11
|
-
Requires-Dist: dotmap
|
|
12
|
-
Requires-Dist:
|
|
12
|
+
Requires-Dist: atlas-ftag-tools==0.2.17
|
|
13
|
+
Requires-Dist: dotmap>=1.3.30
|
|
14
|
+
Requires-Dist: numpy>=2.2.6
|
|
15
|
+
Requires-Dist: puma-hep==0.4.11
|
|
13
16
|
Requires-Dist: pyyaml-include==1.3
|
|
14
|
-
Requires-Dist: PyYAML>=6.0.
|
|
15
|
-
Requires-Dist: rich
|
|
16
|
-
Requires-Dist: scipy>=1.15.
|
|
17
|
+
Requires-Dist: PyYAML>=6.0.2
|
|
18
|
+
Requires-Dist: rich>=14.1.0
|
|
19
|
+
Requires-Dist: scipy>=1.15.3
|
|
17
20
|
Provides-Extra: dev
|
|
18
|
-
Requires-Dist:
|
|
19
|
-
Requires-Dist:
|
|
20
|
-
Requires-Dist:
|
|
21
|
-
Requires-Dist:
|
|
22
|
-
Requires-Dist:
|
|
23
|
-
Requires-Dist:
|
|
21
|
+
Requires-Dist: coverage>=7.10.6; extra == "dev"
|
|
22
|
+
Requires-Dist: ipykernel>=6.30.1; extra == "dev"
|
|
23
|
+
Requires-Dist: mypy>=1.18.1; extra == "dev"
|
|
24
|
+
Requires-Dist: pre-commit>=4.3.0; extra == "dev"
|
|
25
|
+
Requires-Dist: pydoclint>=0.7.3; extra == "dev"
|
|
26
|
+
Requires-Dist: pytest_notebook>=0.10.0; extra == "dev"
|
|
27
|
+
Requires-Dist: pytest-cov>=7.0.0; extra == "dev"
|
|
28
|
+
Requires-Dist: pytest-randomly>=4.0.1; extra == "dev"
|
|
29
|
+
Requires-Dist: pytest>=8.4.2; extra == "dev"
|
|
30
|
+
Requires-Dist: ruff>=0.13.0; extra == "dev"
|
|
24
31
|
Dynamic: license-file
|
|
25
32
|
|
|
26
33
|
[](https://github.com/psf/black)
|
|
@@ -1,33 +1,40 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "umami-preprocessing"
|
|
3
|
-
description = "Preprocessing
|
|
3
|
+
description = "ATLAS Flavour Tagging Preprocessing - Umami PreProcessing (UPP)"
|
|
4
|
+
authors = [{name="Alexander Froch"}]
|
|
4
5
|
dynamic = ["version"]
|
|
5
6
|
license = {text = "MIT"}
|
|
6
7
|
readme = "README.md"
|
|
7
|
-
requires-python = "
|
|
8
|
+
requires-python = ">=3.10,<3.12"
|
|
8
9
|
|
|
9
10
|
dependencies = [
|
|
10
|
-
"atlas-ftag-tools==0.2.
|
|
11
|
-
"dotmap
|
|
12
|
-
"
|
|
11
|
+
"atlas-ftag-tools==0.2.17",
|
|
12
|
+
"dotmap>=1.3.30",
|
|
13
|
+
"numpy>=2.2.6",
|
|
14
|
+
"puma-hep==0.4.11",
|
|
13
15
|
"pyyaml-include==1.3",
|
|
14
|
-
"PyYAML>=6.0.
|
|
15
|
-
"rich
|
|
16
|
-
"scipy>=1.15.
|
|
16
|
+
"PyYAML>=6.0.2",
|
|
17
|
+
"rich>=14.1.0",
|
|
18
|
+
"scipy>=1.15.3",
|
|
17
19
|
]
|
|
18
20
|
|
|
19
21
|
[project.optional-dependencies]
|
|
20
22
|
dev = [
|
|
21
|
-
"
|
|
22
|
-
"
|
|
23
|
-
"
|
|
24
|
-
"
|
|
25
|
-
"
|
|
26
|
-
"
|
|
23
|
+
"coverage>=7.10.6",
|
|
24
|
+
"ipykernel>=6.30.1",
|
|
25
|
+
"mypy>=1.18.1",
|
|
26
|
+
"pre-commit>=4.3.0",
|
|
27
|
+
"pydoclint>=0.7.3",
|
|
28
|
+
"pytest_notebook>=0.10.0",
|
|
29
|
+
"pytest-cov>=7.0.0",
|
|
30
|
+
"pytest-randomly>=4.0.1",
|
|
31
|
+
"pytest>=8.4.2",
|
|
32
|
+
"ruff>=0.13.0",
|
|
27
33
|
]
|
|
28
34
|
|
|
29
35
|
[project.urls]
|
|
30
36
|
"Homepage" = "https://github.com/umami-hep/umami-preprocessing"
|
|
37
|
+
"Issue Tracker" = "https://github.com/umami-hep/umami-preprocessing/issues"
|
|
31
38
|
|
|
32
39
|
[project.scripts]
|
|
33
40
|
preprocess = "upp.main:main"
|
{umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7/umami_preprocessing.egg-info}/PKG-INFO
RENAMED
|
@@ -1,26 +1,33 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: umami-preprocessing
|
|
3
|
-
Version: 0.2.6
|
|
4
|
-
Summary: Preprocessing
|
|
3
|
+
Version: 0.2.7
|
|
4
|
+
Summary: ATLAS Flavour Tagging Preprocessing - Umami PreProcessing (UPP)
|
|
5
|
+
Author: Alexander Froch
|
|
5
6
|
License: MIT
|
|
6
7
|
Project-URL: Homepage, https://github.com/umami-hep/umami-preprocessing
|
|
7
|
-
|
|
8
|
+
Project-URL: Issue Tracker, https://github.com/umami-hep/umami-preprocessing/issues
|
|
9
|
+
Requires-Python: <3.12,>=3.10
|
|
8
10
|
Description-Content-Type: text/markdown
|
|
9
11
|
License-File: LICENSE
|
|
10
|
-
Requires-Dist: atlas-ftag-tools==0.2.
|
|
11
|
-
Requires-Dist: dotmap
|
|
12
|
-
Requires-Dist:
|
|
12
|
+
Requires-Dist: atlas-ftag-tools==0.2.17
|
|
13
|
+
Requires-Dist: dotmap>=1.3.30
|
|
14
|
+
Requires-Dist: numpy>=2.2.6
|
|
15
|
+
Requires-Dist: puma-hep==0.4.11
|
|
13
16
|
Requires-Dist: pyyaml-include==1.3
|
|
14
|
-
Requires-Dist: PyYAML>=6.0.
|
|
15
|
-
Requires-Dist: rich
|
|
16
|
-
Requires-Dist: scipy>=1.15.
|
|
17
|
+
Requires-Dist: PyYAML>=6.0.2
|
|
18
|
+
Requires-Dist: rich>=14.1.0
|
|
19
|
+
Requires-Dist: scipy>=1.15.3
|
|
17
20
|
Provides-Extra: dev
|
|
18
|
-
Requires-Dist:
|
|
19
|
-
Requires-Dist:
|
|
20
|
-
Requires-Dist:
|
|
21
|
-
Requires-Dist:
|
|
22
|
-
Requires-Dist:
|
|
23
|
-
Requires-Dist:
|
|
21
|
+
Requires-Dist: coverage>=7.10.6; extra == "dev"
|
|
22
|
+
Requires-Dist: ipykernel>=6.30.1; extra == "dev"
|
|
23
|
+
Requires-Dist: mypy>=1.18.1; extra == "dev"
|
|
24
|
+
Requires-Dist: pre-commit>=4.3.0; extra == "dev"
|
|
25
|
+
Requires-Dist: pydoclint>=0.7.3; extra == "dev"
|
|
26
|
+
Requires-Dist: pytest_notebook>=0.10.0; extra == "dev"
|
|
27
|
+
Requires-Dist: pytest-cov>=7.0.0; extra == "dev"
|
|
28
|
+
Requires-Dist: pytest-randomly>=4.0.1; extra == "dev"
|
|
29
|
+
Requires-Dist: pytest>=8.4.2; extra == "dev"
|
|
30
|
+
Requires-Dist: ruff>=0.13.0; extra == "dev"
|
|
24
31
|
Dynamic: license-file
|
|
25
32
|
|
|
26
33
|
[](https://github.com/psf/black)
|
{umami_preprocessing-0.2.6 → umami_preprocessing-0.2.7}/umami_preprocessing.egg-info/SOURCES.txt
RENAMED
|
@@ -15,6 +15,7 @@ upp/classes/components.py
|
|
|
15
15
|
upp/classes/preprocessing_config.py
|
|
16
16
|
upp/classes/region.py
|
|
17
17
|
upp/classes/resampling_config.py
|
|
18
|
+
upp/classes/reweight_config.py
|
|
18
19
|
upp/classes/variable_config.py
|
|
19
20
|
upp/stages/__init__.py
|
|
20
21
|
upp/stages/hist.py
|
|
@@ -23,6 +24,9 @@ upp/stages/merging.py
|
|
|
23
24
|
upp/stages/normalisation.py
|
|
24
25
|
upp/stages/plot.py
|
|
25
26
|
upp/stages/resampling.py
|
|
27
|
+
upp/stages/reweight.py
|
|
28
|
+
upp/stages/rw_merge.py
|
|
29
|
+
upp/stages/split_containers.py
|
|
26
30
|
upp/utils/__init__.py
|
|
27
31
|
upp/utils/check_input_samples.py
|
|
28
32
|
upp/utils/logger.py
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
atlas-ftag-tools==0.2.17
|
|
2
|
+
dotmap>=1.3.30
|
|
3
|
+
numpy>=2.2.6
|
|
4
|
+
puma-hep==0.4.11
|
|
5
|
+
pyyaml-include==1.3
|
|
6
|
+
PyYAML>=6.0.2
|
|
7
|
+
rich>=14.1.0
|
|
8
|
+
scipy>=1.15.3
|
|
9
|
+
|
|
10
|
+
[dev]
|
|
11
|
+
coverage>=7.10.6
|
|
12
|
+
ipykernel>=6.30.1
|
|
13
|
+
mypy>=1.18.1
|
|
14
|
+
pre-commit>=4.3.0
|
|
15
|
+
pydoclint>=0.7.3
|
|
16
|
+
pytest_notebook>=0.10.0
|
|
17
|
+
pytest-cov>=7.0.0
|
|
18
|
+
pytest-randomly>=4.0.1
|
|
19
|
+
pytest>=8.4.2
|
|
20
|
+
ruff>=0.13.0
|
|
@@ -336,6 +336,7 @@ class Components:
|
|
|
336
336
|
pattern=pattern,
|
|
337
337
|
ntuple_dir=config.ntuple_dir,
|
|
338
338
|
name=component["sample"]["name"],
|
|
339
|
+
skip_checks=config.skip_checks,
|
|
339
340
|
)
|
|
340
341
|
|
|
341
342
|
# Create the Component instances for the different flavours
|
|
@@ -360,7 +361,7 @@ class Components:
|
|
|
360
361
|
components = cls(component_list)
|
|
361
362
|
|
|
362
363
|
# Check the flavour ratios
|
|
363
|
-
if config.sampl_cfg.method is not None:
|
|
364
|
+
if config.sampl_cfg and config.sampl_cfg.method is not None:
|
|
364
365
|
components.check_flavour_ratios()
|
|
365
366
|
|
|
366
367
|
return components
|
|
@@ -510,7 +511,7 @@ class Components:
|
|
|
510
511
|
def __getitem__(self, index):
|
|
511
512
|
if isinstance(index, int):
|
|
512
513
|
return self.components[index]
|
|
513
|
-
if isinstance(index,
|
|
514
|
+
if isinstance(index, str | Label):
|
|
514
515
|
return self.components[self.flavours.index(index)]
|
|
515
516
|
|
|
516
517
|
def __len__(self):
|
|
@@ -20,6 +20,7 @@ from yamlinclude import YamlIncludeConstructor
|
|
|
20
20
|
from upp import __version__
|
|
21
21
|
from upp.classes.components import Components
|
|
22
22
|
from upp.classes.resampling_config import ResamplingConfig
|
|
23
|
+
from upp.classes.reweight_config import ReweightConfig
|
|
23
24
|
from upp.classes.variable_config import VariableConfig
|
|
24
25
|
from upp.utils.tools import path_append
|
|
25
26
|
|
|
@@ -114,6 +115,8 @@ class PreprocessingConfig:
|
|
|
114
115
|
than this number, the final h5 output files are split into multiple smaller
|
|
115
116
|
files with this number of jets per file. By default None which produces one
|
|
116
117
|
huge output file.
|
|
118
|
+
skip_checks : bool, optional
|
|
119
|
+
Skip checks for the input files. This is used for grid submission
|
|
117
120
|
skip_config_copy : bool, optional
|
|
118
121
|
Decide, if the config copying is skipped or not. By default False
|
|
119
122
|
"""
|
|
@@ -137,6 +140,7 @@ class PreprocessingConfig:
|
|
|
137
140
|
flavour_config: Path | None = None
|
|
138
141
|
flavour_category: str = "standard"
|
|
139
142
|
num_jets_per_output_file: int | None = None
|
|
143
|
+
skip_checks: bool = False
|
|
140
144
|
skip_config_copy: bool = False
|
|
141
145
|
|
|
142
146
|
def __post_init__(self):
|
|
@@ -154,11 +158,10 @@ class PreprocessingConfig:
|
|
|
154
158
|
for field in dataclasses.fields(self):
|
|
155
159
|
if field.type == "Path" and field.name != "out_fname" and field.name != "base_dir":
|
|
156
160
|
setattr(self, field.name, self.get_path(Path(getattr(self, field.name))))
|
|
157
|
-
if not self.ntuple_dir.exists():
|
|
161
|
+
if not self.ntuple_dir.exists() and not self.skip_checks:
|
|
158
162
|
raise FileNotFoundError(f"Path {self.ntuple_dir} does not exist")
|
|
159
163
|
self.components_dir = self.components_dir / self.split
|
|
160
164
|
self.out_fname = self.out_dir / path_append(self.out_fname, self.split)
|
|
161
|
-
|
|
162
165
|
# Define the content of the flavour label container
|
|
163
166
|
if self.flavour_config:
|
|
164
167
|
self.flavour_cont = LabelContainer.from_yaml(
|
|
@@ -177,12 +180,15 @@ class PreprocessingConfig:
|
|
|
177
180
|
"flavours! If you want to use your own flavour config yaml file, please "
|
|
178
181
|
"provide flavour_config!"
|
|
179
182
|
)
|
|
180
|
-
|
|
181
183
|
# configure classes
|
|
182
|
-
sampl_cfg
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
184
|
+
if sampl_cfg := self.config.get("resampling", None):
|
|
185
|
+
sampl_cfg = copy(sampl_cfg)
|
|
186
|
+
if self.is_test:
|
|
187
|
+
sampl_cfg["method"] = None
|
|
188
|
+
self.sampl_cfg = ResamplingConfig(**sampl_cfg)
|
|
189
|
+
else:
|
|
190
|
+
self.sampl_cfg = None
|
|
191
|
+
|
|
186
192
|
self.components = Components.from_config(self)
|
|
187
193
|
|
|
188
194
|
# get track selectors
|
|
@@ -196,13 +202,21 @@ class PreprocessingConfig:
|
|
|
196
202
|
self.variables = VariableConfig(
|
|
197
203
|
self.config["variables"], self.jets_name, self.is_test, selectors
|
|
198
204
|
)
|
|
199
|
-
self.
|
|
200
|
-
|
|
201
|
-
|
|
205
|
+
if self.sampl_cfg is not None:
|
|
206
|
+
self.variables = self.variables.add_jet_vars(
|
|
207
|
+
list(self.config["resampling"]["variables"].keys()), "labels"
|
|
208
|
+
)
|
|
202
209
|
self.transform = (
|
|
203
210
|
Transform(**self.config["transform"]) if "transform" in self.config else None
|
|
204
211
|
)
|
|
205
212
|
|
|
213
|
+
self.rw_config = (
|
|
214
|
+
ReweightConfig(
|
|
215
|
+
**self.config["reweighting"],
|
|
216
|
+
)
|
|
217
|
+
if "reweighting" in self.config
|
|
218
|
+
else None
|
|
219
|
+
)
|
|
206
220
|
# reproducibility
|
|
207
221
|
self.git_hash = get_git_hash(Path(__file__).parent)
|
|
208
222
|
if self.git_hash is None:
|
|
@@ -214,7 +228,13 @@ class PreprocessingConfig:
|
|
|
214
228
|
self.copy_config()
|
|
215
229
|
|
|
216
230
|
@classmethod
|
|
217
|
-
def from_file(
|
|
231
|
+
def from_file(
|
|
232
|
+
cls,
|
|
233
|
+
config_path: Path,
|
|
234
|
+
split: Split,
|
|
235
|
+
skip_checks: bool = False,
|
|
236
|
+
skip_config_copy: bool = False,
|
|
237
|
+
):
|
|
218
238
|
if not config_path.exists():
|
|
219
239
|
raise FileNotFoundError(f"{config_path} does not exist - check your --config arg")
|
|
220
240
|
with open(config_path) as file:
|
|
@@ -225,6 +245,7 @@ class PreprocessingConfig:
|
|
|
225
245
|
config=config,
|
|
226
246
|
skip_config_copy=skip_config_copy,
|
|
227
247
|
**config["global"],
|
|
248
|
+
skip_checks=skip_checks,
|
|
228
249
|
)
|
|
229
250
|
|
|
230
251
|
def get_path(self, path: Path):
|
|
@@ -238,7 +259,7 @@ class PreprocessingConfig:
|
|
|
238
259
|
def global_cuts(self):
|
|
239
260
|
cuts_list = self.config["global_cuts"].get("common", [])
|
|
240
261
|
cuts_list += self.config["global_cuts"][self.split]
|
|
241
|
-
if not self.is_test:
|
|
262
|
+
if not self.is_test and self.config.get("resampling", None) is not None:
|
|
242
263
|
for resampling_var, cfg in self.config["resampling"]["variables"].items():
|
|
243
264
|
cuts_list.append([resampling_var, ">", cfg["bins"][0][0]])
|
|
244
265
|
cuts_list.append([resampling_var, "<", cfg["bins"][-1][1]])
|
|
@@ -357,3 +378,34 @@ class PreprocessingConfig:
|
|
|
357
378
|
f"Option value {option} is not supported! "
|
|
358
379
|
"Only resampled and resampled_scaled_shuffled are."
|
|
359
380
|
)
|
|
381
|
+
|
|
382
|
+
# Static because otherwise config paths end up getting messed up
|
|
383
|
+
@staticmethod
|
|
384
|
+
def get_input_files_with_split_components(config_path):
|
|
385
|
+
"""Return a nested dictionary of the form.
|
|
386
|
+
|
|
387
|
+
{
|
|
388
|
+
"container_1" : {
|
|
389
|
+
"train_bjets" : list[str] : -> which are cuts
|
|
390
|
+
"test_bjets" : list[str] : -> which are cuts
|
|
391
|
+
...
|
|
392
|
+
"test_cjets" : list[str] : -> which are cuts
|
|
393
|
+
|
|
394
|
+
},
|
|
395
|
+
"container_2" : ...
|
|
396
|
+
}
|
|
397
|
+
which represents all splits and components to be run for each input container
|
|
398
|
+
"""
|
|
399
|
+
containers_with_splits = {}
|
|
400
|
+
|
|
401
|
+
for split in ["train", "val", "test"]:
|
|
402
|
+
split_config = PreprocessingConfig.from_file(config_path, split, skip_checks=True)
|
|
403
|
+
|
|
404
|
+
for component in split_config.components.components:
|
|
405
|
+
for container in component.sample.pattern:
|
|
406
|
+
if container not in containers_with_splits:
|
|
407
|
+
containers_with_splits[container] = {}
|
|
408
|
+
component_cuts = component.cuts
|
|
409
|
+
containers_with_splits[container][f"{split}_{component.name}"] = component_cuts
|
|
410
|
+
|
|
411
|
+
return containers_with_splits
|
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import Callable
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass
|
|
10
|
+
class ReweightConfig:
|
|
11
|
+
# Number of jets to estimate, if None, use the global num jets estimate
|
|
12
|
+
num_jets_estimate: None | int = None
|
|
13
|
+
merge_num_proc: int = 1 # Number of processes to use for merging
|
|
14
|
+
reweights: list[SingleReweightConfig] = field(default_factory=list)
|
|
15
|
+
|
|
16
|
+
def __post_init__(self):
|
|
17
|
+
if self.num_jets_estimate is not None and self.num_jets_estimate <= 0:
|
|
18
|
+
raise ValueError("num_jets_estimate must be a positive integer or None")
|
|
19
|
+
|
|
20
|
+
parsed_reweights = []
|
|
21
|
+
for rw in self.reweights:
|
|
22
|
+
parsed_reweights.append(SingleReweightConfig(**rw))
|
|
23
|
+
self.reweights = parsed_reweights
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@dataclass
|
|
27
|
+
class SingleReweightConfig:
|
|
28
|
+
group: str # The group our variables in the h5 file are in
|
|
29
|
+
reweight_vars: list[str] # The variables we want to reweight
|
|
30
|
+
bins: list[np.ndarray] # The bins we want to use for the reweighting
|
|
31
|
+
class_var: str # The variable which contains the label we resample over, e.g. flavour
|
|
32
|
+
class_target: int | tuple | str | None = None
|
|
33
|
+
add_overflow: bool = True # Whether to add overflow bins
|
|
34
|
+
|
|
35
|
+
target_hist_func: Callable | None = None
|
|
36
|
+
target_hist_func_name: str | None = None
|
|
37
|
+
|
|
38
|
+
# TODO - this is the same as in resampling, maybe can cleanup
|
|
39
|
+
def get_bins_x(self, bins_x, upscale=1):
|
|
40
|
+
flat_bins = []
|
|
41
|
+
for i, sub_bins_x in enumerate(bins_x):
|
|
42
|
+
start, stop, nbins = sub_bins_x
|
|
43
|
+
b = np.linspace(start, stop, nbins * upscale + 1)
|
|
44
|
+
if i > 0:
|
|
45
|
+
b = b[1:]
|
|
46
|
+
flat_bins.append(b)
|
|
47
|
+
if self.add_overflow:
|
|
48
|
+
flat_bins = [np.array([-np.inf])] + flat_bins + [np.array([np.inf])]
|
|
49
|
+
return np.concatenate(flat_bins)
|
|
50
|
+
|
|
51
|
+
@property
|
|
52
|
+
def flat_bins(self):
|
|
53
|
+
return [self.get_bins_x(self.bins[k]) for k in self.reweight_vars]
|
|
54
|
+
|
|
55
|
+
def __post_init__(self):
|
|
56
|
+
if isinstance(self.class_target, str) and self.class_target not in [
|
|
57
|
+
"mean",
|
|
58
|
+
"min",
|
|
59
|
+
"max",
|
|
60
|
+
"uniform",
|
|
61
|
+
]:
|
|
62
|
+
raise ValueError("class_target must be either 'mean', 'min', 'max' or an integer")
|
|
63
|
+
|
|
64
|
+
if self.target_hist_func is not None and self.target_hist_func_name is None:
|
|
65
|
+
self.target_hist_func_name = self.target_hist_func.__name__
|
|
66
|
+
|
|
67
|
+
def __repr__(self):
|
|
68
|
+
target_str = "target_"
|
|
69
|
+
if self.target_hist_func_name is not None:
|
|
70
|
+
target_str += f"{self.target_hist_func_name}_"
|
|
71
|
+
if self.class_target is not None:
|
|
72
|
+
if isinstance(self.class_target, list | tuple):
|
|
73
|
+
target_str += "_".join(map(str, self.class_target))
|
|
74
|
+
else:
|
|
75
|
+
target_str += f"{self.class_target}_{self.class_var}"
|
|
76
|
+
else:
|
|
77
|
+
target_str += "none"
|
|
78
|
+
return f"weight_{self.group}_{'_'.join(self.reweight_vars)}_{target_str}"
|
|
@@ -23,6 +23,9 @@ from upp.stages.merging import Merging
|
|
|
23
23
|
from upp.stages.normalisation import Normalisation
|
|
24
24
|
from upp.stages.plot import plot_resampling_dists
|
|
25
25
|
from upp.stages.resampling import Resampling
|
|
26
|
+
from upp.stages.reweight import Reweight
|
|
27
|
+
from upp.stages.rw_merge import RWMerge
|
|
28
|
+
from upp.stages.split_containers import SplitContainers
|
|
26
29
|
from upp.utils.check_input_samples import run_input_sample_check
|
|
27
30
|
from upp.utils.logger import setup_logger
|
|
28
31
|
|
|
@@ -116,6 +119,28 @@ def parse_args(args: Any) -> argparse.Namespace:
|
|
|
116
119
|
default=None,
|
|
117
120
|
help="Component which is processed during --prep",
|
|
118
121
|
)
|
|
122
|
+
parser.add_argument(
|
|
123
|
+
"--split-components",
|
|
124
|
+
action="store_true",
|
|
125
|
+
default=False,
|
|
126
|
+
help="Split containers into components",
|
|
127
|
+
)
|
|
128
|
+
parser.add_argument(
|
|
129
|
+
"--reweight", "--rw", action="store_true", default=False, help="Run the reweighting stage"
|
|
130
|
+
)
|
|
131
|
+
parser.add_argument(
|
|
132
|
+
"--rw-merge", "--rwm", action="store_true", default=False, help="Run the reweighting stage"
|
|
133
|
+
)
|
|
134
|
+
parser.add_argument(
|
|
135
|
+
"--rw-merge-idx",
|
|
136
|
+
"--rwm-idx",
|
|
137
|
+
type=str,
|
|
138
|
+
default=None,
|
|
139
|
+
help=(
|
|
140
|
+
"Commar seperated pair of indices representing the range of output "
|
|
141
|
+
"files to create, e.g '0,10' will create files 0 to 9"
|
|
142
|
+
),
|
|
143
|
+
)
|
|
119
144
|
parser.add_argument(
|
|
120
145
|
"--region",
|
|
121
146
|
default=None,
|
|
@@ -126,10 +151,37 @@ def parse_args(args: Any) -> argparse.Namespace:
|
|
|
126
151
|
action="store_true",
|
|
127
152
|
help="Skip the inital input sample check",
|
|
128
153
|
)
|
|
154
|
+
parser.add_argument(
|
|
155
|
+
"--grid", action="store_true", help="Use when running the split stage on the grid. "
|
|
156
|
+
)
|
|
157
|
+
parser.add_argument(
|
|
158
|
+
"--container",
|
|
159
|
+
default=None,
|
|
160
|
+
type=str,
|
|
161
|
+
help="Container to use during the 'split-containers' stage. "
|
|
162
|
+
"If not specified, all containers in the config will be used.",
|
|
163
|
+
)
|
|
164
|
+
parser.add_argument(
|
|
165
|
+
"--files",
|
|
166
|
+
default=None,
|
|
167
|
+
help="comma-separated list of files to use during the 'split-containers' stage ",
|
|
168
|
+
)
|
|
129
169
|
|
|
130
170
|
args = parser.parse_args(args)
|
|
131
171
|
d = vars(args)
|
|
132
|
-
ignore = [
|
|
172
|
+
ignore = [
|
|
173
|
+
"config",
|
|
174
|
+
"split",
|
|
175
|
+
"component",
|
|
176
|
+
"region",
|
|
177
|
+
"container",
|
|
178
|
+
"files",
|
|
179
|
+
"grid",
|
|
180
|
+
"split_components",
|
|
181
|
+
"reweight",
|
|
182
|
+
"rw_merge",
|
|
183
|
+
"rw_merge_idx",
|
|
184
|
+
]
|
|
133
185
|
if not any(v for a, v in d.items() if a not in ignore):
|
|
134
186
|
for v in d:
|
|
135
187
|
if v not in ignore and d[v] is None:
|
|
@@ -151,10 +203,24 @@ def run_pp(args: argparse.Namespace) -> None:
|
|
|
151
203
|
log.info("[bold green]Starting preprocessing...")
|
|
152
204
|
start = datetime.now()
|
|
153
205
|
log.info(f"Start time: {start.strftime('%Y-%m-%d %H:%M:%S')}")
|
|
154
|
-
|
|
155
206
|
# load config
|
|
156
|
-
config = PreprocessingConfig.from_file(args.config, args.split)
|
|
207
|
+
config = PreprocessingConfig.from_file(args.config, args.split, skip_checks=args.grid)
|
|
157
208
|
|
|
209
|
+
if args.split_components:
|
|
210
|
+
log.info("Splitting containers...")
|
|
211
|
+
split = SplitContainers(args.config)
|
|
212
|
+
split.run(
|
|
213
|
+
args.container,
|
|
214
|
+
args.files,
|
|
215
|
+
"." if args.grid else None, # Use current directory when running on the grid
|
|
216
|
+
)
|
|
217
|
+
# If we aren't running on the grid, we create the metadata after splitting
|
|
218
|
+
if not args.grid:
|
|
219
|
+
split.create_meta_data()
|
|
220
|
+
if args.reweight:
|
|
221
|
+
log.info("Running reweighting...")
|
|
222
|
+
reweight = Reweight(config)
|
|
223
|
+
reweight.run()
|
|
158
224
|
# create virtual datasets and pdf files
|
|
159
225
|
if args.prep:
|
|
160
226
|
# Check the input samples sizes
|
|
@@ -180,6 +246,16 @@ def run_pp(args: argparse.Namespace) -> None:
|
|
|
180
246
|
if args.merge:
|
|
181
247
|
merging = Merging(config)
|
|
182
248
|
merging.run()
|
|
249
|
+
if args.rw_merge:
|
|
250
|
+
if args.rw_merge_idx:
|
|
251
|
+
rw_merge_idx = args.rw_merge_idx
|
|
252
|
+
assert "," in rw_merge_idx, "rw-merge-idx must be a comma-separated pair of indices"
|
|
253
|
+
rw_merge_idx = tuple(map(int, rw_merge_idx.split(",")))
|
|
254
|
+
assert len(rw_merge_idx) == 2, "rw-merge-idx must be a pair of indices"
|
|
255
|
+
else:
|
|
256
|
+
rw_merge_idx = None
|
|
257
|
+
rw_merge = RWMerge(config, rw_merge_idx)
|
|
258
|
+
rw_merge.run()
|
|
183
259
|
|
|
184
260
|
# run the normalisation
|
|
185
261
|
if args.norm and args.split == "train":
|
|
@@ -209,10 +285,10 @@ def main(args: Any | None = None) -> None:
|
|
|
209
285
|
d = vars(args)
|
|
210
286
|
for split in ["train", "val", "test"]:
|
|
211
287
|
d["split"] = split
|
|
212
|
-
log.info(f"[bold blue]{'-'*100}")
|
|
288
|
+
log.info(f"[bold blue]{'-' * 100}")
|
|
213
289
|
title = f" {args.split} "
|
|
214
290
|
log.info(f"[bold blue]{title:-^100}")
|
|
215
|
-
log.info(f"[bold blue]{'-'*100}")
|
|
291
|
+
log.info(f"[bold blue]{'-' * 100}")
|
|
216
292
|
run_pp(args)
|
|
217
293
|
else:
|
|
218
294
|
run_pp(args)
|
|
@@ -191,7 +191,7 @@ class Normalisation:
|
|
|
191
191
|
"Class dict A has arrays of different lengths for the same"
|
|
192
192
|
" variable. This should not happen."
|
|
193
193
|
)
|
|
194
|
-
counts_A = dict(zip(*class_dict_A[name][v]))
|
|
194
|
+
counts_A = dict(zip(*class_dict_A[name][v], strict=False))
|
|
195
195
|
counts[i] += counts_A.get(label, 0)
|
|
196
196
|
var[v] = (labels, counts)
|
|
197
197
|
return class_dict_B
|
|
@@ -233,9 +233,11 @@ class Normalisation:
|
|
|
233
233
|
"""Run the normalisation calculation."""
|
|
234
234
|
title = " Computing Normalisations "
|
|
235
235
|
log.info(f"[bold green]{title:-^100}")
|
|
236
|
+
if self.config.rw_config is not None:
|
|
237
|
+
fname = str(self.config.out_fname).replace(".h5", "_vds.h5")
|
|
236
238
|
|
|
237
239
|
# Get the correct output names if multiple output files were written
|
|
238
|
-
|
|
240
|
+
elif self.config.num_jets_per_output_file:
|
|
239
241
|
fname = self.config.out_fname.parent / f"{self.config.out_fname.stem}*.h5"
|
|
240
242
|
|
|
241
243
|
else:
|
|
@@ -15,7 +15,8 @@ from upp.stages.interpolation import subdivide_bins, upscale_array_regionally
|
|
|
15
15
|
from upp.utils.logger import ProgressBar
|
|
16
16
|
|
|
17
17
|
if TYPE_CHECKING: # pragma: no cover
|
|
18
|
-
from
|
|
18
|
+
from collections.abc import Generator
|
|
19
|
+
from typing import Any
|
|
19
20
|
|
|
20
21
|
from rich.progress import Progress
|
|
21
22
|
|