umami-preprocessing 0.0.6__tar.gz → 0.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {umami-preprocessing-0.0.6 → umami_preprocessing-0.2.0}/PKG-INFO +9 -10
- {umami-preprocessing-0.0.6 → umami_preprocessing-0.2.0}/pyproject.toml +11 -19
- {umami-preprocessing-0.0.6 → umami_preprocessing-0.2.0}/umami_preprocessing.egg-info/PKG-INFO +9 -10
- umami_preprocessing-0.2.0/umami_preprocessing.egg-info/requires.txt +15 -0
- {umami-preprocessing-0.0.6 → umami_preprocessing-0.2.0}/upp/__init__.py +1 -1
- {umami-preprocessing-0.0.6 → umami_preprocessing-0.2.0}/upp/classes/components.py +28 -16
- {umami-preprocessing-0.0.6 → umami_preprocessing-0.2.0}/upp/classes/preprocessing_config.py +52 -14
- {umami-preprocessing-0.0.6 → umami_preprocessing-0.2.0}/upp/classes/variable_config.py +4 -1
- {umami-preprocessing-0.0.6 → umami_preprocessing-0.2.0}/upp/main.py +18 -20
- {umami-preprocessing-0.0.6 → umami_preprocessing-0.2.0}/upp/stages/hist.py +9 -5
- {umami-preprocessing-0.0.6 → umami_preprocessing-0.2.0}/upp/stages/interpolation.py +4 -4
- {umami-preprocessing-0.0.6 → umami_preprocessing-0.2.0}/upp/stages/merging.py +12 -2
- {umami-preprocessing-0.0.6 → umami_preprocessing-0.2.0}/upp/stages/normalisation.py +5 -3
- {umami-preprocessing-0.0.6 → umami_preprocessing-0.2.0}/upp/stages/plot.py +32 -14
- {umami-preprocessing-0.0.6 → umami_preprocessing-0.2.0}/upp/stages/resampling.py +13 -12
- umami-preprocessing-0.0.6/umami_preprocessing.egg-info/requires.txt +0 -16
- {umami-preprocessing-0.0.6 → umami_preprocessing-0.2.0}/README.md +0 -0
- {umami-preprocessing-0.0.6 → umami_preprocessing-0.2.0}/setup.cfg +0 -0
- {umami-preprocessing-0.0.6 → umami_preprocessing-0.2.0}/umami_preprocessing.egg-info/SOURCES.txt +0 -0
- {umami-preprocessing-0.0.6 → umami_preprocessing-0.2.0}/umami_preprocessing.egg-info/dependency_links.txt +0 -0
- {umami-preprocessing-0.0.6 → umami_preprocessing-0.2.0}/umami_preprocessing.egg-info/entry_points.txt +0 -0
- {umami-preprocessing-0.0.6 → umami_preprocessing-0.2.0}/umami_preprocessing.egg-info/top_level.txt +0 -0
- {umami-preprocessing-0.0.6 → umami_preprocessing-0.2.0}/upp/classes/__init__.py +0 -0
- {umami-preprocessing-0.0.6 → umami_preprocessing-0.2.0}/upp/classes/region.py +0 -0
- {umami-preprocessing-0.0.6 → umami_preprocessing-0.2.0}/upp/classes/resampling_config.py +0 -0
- {umami-preprocessing-0.0.6 → umami_preprocessing-0.2.0}/upp/logger.py +0 -0
- {umami-preprocessing-0.0.6 → umami_preprocessing-0.2.0}/upp/stages/__init__.py +0 -0
- {umami-preprocessing-0.0.6 → umami_preprocessing-0.2.0}/upp/utils.py +0 -0
|
@@ -1,26 +1,25 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
1
|
+
Metadata-Version: 2.2
|
|
2
2
|
Name: umami-preprocessing
|
|
3
|
-
Version: 0.0
|
|
3
|
+
Version: 0.2.0
|
|
4
4
|
Summary: Preprocessing for jet tagging
|
|
5
5
|
License: MIT
|
|
6
6
|
Project-URL: Homepage, https://github.com/umami-hep/umami-preprocessing
|
|
7
|
-
Requires-Python:
|
|
7
|
+
Requires-Python: <3.12,>=3.8
|
|
8
8
|
Description-Content-Type: text/markdown
|
|
9
9
|
Requires-Dist: pyyaml-include==1.3
|
|
10
10
|
Requires-Dist: PyYAML==6.0.1
|
|
11
11
|
Requires-Dist: rich==12.6.0
|
|
12
12
|
Requires-Dist: scipy==1.10.1
|
|
13
|
-
Requires-Dist: puma-hep==0.
|
|
14
|
-
Requires-Dist: atlas-ftag-tools==0.
|
|
13
|
+
Requires-Dist: puma-hep==0.4.1
|
|
14
|
+
Requires-Dist: atlas-ftag-tools==0.2.7
|
|
15
15
|
Requires-Dist: dotmap==1.3.30
|
|
16
16
|
Provides-Extra: dev
|
|
17
|
-
Requires-Dist:
|
|
18
|
-
Requires-Dist: ruff==0.0.289; extra == "dev"
|
|
17
|
+
Requires-Dist: ruff==0.1.6; extra == "dev"
|
|
19
18
|
Requires-Dist: mypy==1.5.1; extra == "dev"
|
|
20
|
-
Requires-Dist: pre-commit==3.
|
|
21
|
-
Requires-Dist: pytest
|
|
19
|
+
Requires-Dist: pre-commit==3.5.0; extra == "dev"
|
|
20
|
+
Requires-Dist: pytest>=7.0.1; extra == "dev"
|
|
22
21
|
Requires-Dist: pytest-mock==3.11.1; extra == "dev"
|
|
23
|
-
Requires-Dist: pytest-cov
|
|
22
|
+
Requires-Dist: pytest-cov>=3.0.0; extra == "dev"
|
|
24
23
|
|
|
25
24
|
[](https://github.com/psf/black)
|
|
26
25
|
[](https://codecov.io/gh/umami-hep/umami-preprocessing)
|
|
@@ -4,27 +4,26 @@ description = "Preprocessing for jet tagging"
|
|
|
4
4
|
dynamic = ["version"]
|
|
5
5
|
license = {text = "MIT"}
|
|
6
6
|
readme = "README.md"
|
|
7
|
-
requires-python = "
|
|
7
|
+
requires-python = "<3.12,>=3.8"
|
|
8
8
|
|
|
9
9
|
dependencies = [
|
|
10
10
|
"pyyaml-include==1.3",
|
|
11
11
|
"PyYAML==6.0.1",
|
|
12
12
|
"rich==12.6.0",
|
|
13
13
|
"scipy==1.10.1",
|
|
14
|
-
"puma-hep==0.
|
|
15
|
-
"atlas-ftag-tools==0.
|
|
14
|
+
"puma-hep==0.4.1",
|
|
15
|
+
"atlas-ftag-tools==0.2.7",
|
|
16
16
|
"dotmap==1.3.30"
|
|
17
17
|
]
|
|
18
18
|
|
|
19
19
|
[project.optional-dependencies]
|
|
20
20
|
dev = [
|
|
21
|
-
"
|
|
22
|
-
"ruff==0.0.289",
|
|
21
|
+
"ruff==0.1.6",
|
|
23
22
|
"mypy==1.5.1",
|
|
24
|
-
"pre-commit==3.
|
|
25
|
-
"pytest
|
|
23
|
+
"pre-commit==3.5.0",
|
|
24
|
+
"pytest>=7.0.1",
|
|
26
25
|
"pytest-mock==3.11.1",
|
|
27
|
-
"pytest-cov
|
|
26
|
+
"pytest-cov>=3.0.0",
|
|
28
27
|
]
|
|
29
28
|
|
|
30
29
|
[project.urls]
|
|
@@ -44,24 +43,17 @@ version = {attr = "upp.__version__"}
|
|
|
44
43
|
requires = ["setuptools>=62"]
|
|
45
44
|
build-backend = "setuptools.build_meta"
|
|
46
45
|
|
|
47
|
-
[tool.black]
|
|
48
|
-
line-length = 100
|
|
49
|
-
preview = "True"
|
|
50
|
-
|
|
51
46
|
[tool.ruff]
|
|
52
|
-
select = ["I", "E", "W", "F", "B", "UP", "ARG", "SIM", "TID", "RUF", "D2", "D3", "D4"]
|
|
53
|
-
ignore = ["D211", "D213", "RUF005"]
|
|
47
|
+
lint.select = ["I", "E", "W", "F", "B", "UP", "ARG", "SIM", "TID", "RUF", "D2", "D3", "D4"]
|
|
48
|
+
lint.ignore = ["D211", "D213", "RUF005"]
|
|
54
49
|
line-length = 100
|
|
55
50
|
|
|
56
|
-
[tool.ruff.isort]
|
|
51
|
+
[tool.ruff.lint.isort]
|
|
57
52
|
required-imports = ["from __future__ import annotations"]
|
|
58
53
|
|
|
59
|
-
[tool.ruff.pydocstyle]
|
|
54
|
+
[tool.ruff.lint.pydocstyle]
|
|
60
55
|
convention = "numpy" # Accepts: "google", "numpy", or "pep257".
|
|
61
56
|
|
|
62
|
-
[mypy]
|
|
63
|
-
ignore_missing_imports = "True"
|
|
64
|
-
|
|
65
57
|
[tool.pytest.ini_options]
|
|
66
58
|
log_cli_level = "debug"
|
|
67
59
|
filterwarnings = ["ignore::DeprecationWarning"]
|
{umami-preprocessing-0.0.6 → umami_preprocessing-0.2.0}/umami_preprocessing.egg-info/PKG-INFO
RENAMED
|
@@ -1,26 +1,25 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
1
|
+
Metadata-Version: 2.2
|
|
2
2
|
Name: umami-preprocessing
|
|
3
|
-
Version: 0.0
|
|
3
|
+
Version: 0.2.0
|
|
4
4
|
Summary: Preprocessing for jet tagging
|
|
5
5
|
License: MIT
|
|
6
6
|
Project-URL: Homepage, https://github.com/umami-hep/umami-preprocessing
|
|
7
|
-
Requires-Python:
|
|
7
|
+
Requires-Python: <3.12,>=3.8
|
|
8
8
|
Description-Content-Type: text/markdown
|
|
9
9
|
Requires-Dist: pyyaml-include==1.3
|
|
10
10
|
Requires-Dist: PyYAML==6.0.1
|
|
11
11
|
Requires-Dist: rich==12.6.0
|
|
12
12
|
Requires-Dist: scipy==1.10.1
|
|
13
|
-
Requires-Dist: puma-hep==0.
|
|
14
|
-
Requires-Dist: atlas-ftag-tools==0.
|
|
13
|
+
Requires-Dist: puma-hep==0.4.1
|
|
14
|
+
Requires-Dist: atlas-ftag-tools==0.2.7
|
|
15
15
|
Requires-Dist: dotmap==1.3.30
|
|
16
16
|
Provides-Extra: dev
|
|
17
|
-
Requires-Dist:
|
|
18
|
-
Requires-Dist: ruff==0.0.289; extra == "dev"
|
|
17
|
+
Requires-Dist: ruff==0.1.6; extra == "dev"
|
|
19
18
|
Requires-Dist: mypy==1.5.1; extra == "dev"
|
|
20
|
-
Requires-Dist: pre-commit==3.
|
|
21
|
-
Requires-Dist: pytest
|
|
19
|
+
Requires-Dist: pre-commit==3.5.0; extra == "dev"
|
|
20
|
+
Requires-Dist: pytest>=7.0.1; extra == "dev"
|
|
22
21
|
Requires-Dist: pytest-mock==3.11.1; extra == "dev"
|
|
23
|
-
Requires-Dist: pytest-cov
|
|
22
|
+
Requires-Dist: pytest-cov>=3.0.0; extra == "dev"
|
|
24
23
|
|
|
25
24
|
[](https://github.com/psf/black)
|
|
26
25
|
[](https://codecov.io/gh/umami-hep/umami-preprocessing)
|
|
@@ -5,7 +5,7 @@ from dataclasses import dataclass
|
|
|
5
5
|
from pathlib import Path
|
|
6
6
|
|
|
7
7
|
import numpy as np
|
|
8
|
-
from ftag import Cuts,
|
|
8
|
+
from ftag import Cuts, Label, Sample
|
|
9
9
|
from ftag.hdf5 import H5Reader, H5Writer
|
|
10
10
|
|
|
11
11
|
from upp.classes.region import Region
|
|
@@ -16,26 +16,28 @@ from upp.stages.hist import Hist
|
|
|
16
16
|
class Component:
|
|
17
17
|
region: Region
|
|
18
18
|
sample: Sample
|
|
19
|
-
flavour:
|
|
19
|
+
flavour: Label
|
|
20
20
|
global_cuts: Cuts
|
|
21
21
|
dirname: Path
|
|
22
22
|
num_jets: int
|
|
23
|
-
|
|
24
|
-
equal_jets: bool
|
|
23
|
+
num_jets_estimate_available: int
|
|
24
|
+
equal_jets: bool
|
|
25
25
|
|
|
26
26
|
def __post_init__(self):
|
|
27
27
|
self.hist = Hist(self.dirname.parent.parent / "hists" / f"hist_{self.name}.h5")
|
|
28
28
|
|
|
29
|
-
def setup_reader(self, batch_size, fname=None, **kwargs):
|
|
29
|
+
def setup_reader(self, batch_size, jets_name="jets", fname=None, **kwargs):
|
|
30
30
|
if fname is None:
|
|
31
31
|
fname = self.sample.path
|
|
32
|
-
self.reader = H5Reader(
|
|
32
|
+
self.reader = H5Reader(
|
|
33
|
+
fname, batch_size, jets_name=jets_name, equal_jets=self.equal_jets, **kwargs
|
|
34
|
+
)
|
|
33
35
|
log.debug(f"Setup component reader at: {fname}")
|
|
34
36
|
|
|
35
|
-
def setup_writer(self, variables):
|
|
37
|
+
def setup_writer(self, variables, jets_name="jets"):
|
|
36
38
|
dtypes = self.reader.dtypes(variables.combined())
|
|
37
39
|
shapes = self.reader.shapes(self.num_jets, variables.keys())
|
|
38
|
-
self.writer = H5Writer(self.out_path, dtypes, shapes)
|
|
40
|
+
self.writer = H5Writer(self.out_path, dtypes, shapes, jets_name=jets_name)
|
|
39
41
|
log.debug(f"Setup component writer at: {self.out_path}")
|
|
40
42
|
|
|
41
43
|
@property
|
|
@@ -61,7 +63,10 @@ class Component:
|
|
|
61
63
|
self, num_req, sampling_frac=None, cuts=None, silent=False, raise_error=True
|
|
62
64
|
):
|
|
63
65
|
# Check if num_jets jets are available after the cuts and sampling fraction
|
|
64
|
-
|
|
66
|
+
num_est = (
|
|
67
|
+
None if self.num_jets_estimate_available <= 0 else self.num_jets_estimate_available
|
|
68
|
+
)
|
|
69
|
+
total = self.reader.estimate_available_jets(cuts, num_est)
|
|
65
70
|
available = total
|
|
66
71
|
if sampling_frac:
|
|
67
72
|
available = int(total * sampling_frac)
|
|
@@ -77,11 +82,17 @@ class Component:
|
|
|
77
82
|
|
|
78
83
|
if not silent:
|
|
79
84
|
log.debug(f"Sampling fraction {sampling_frac}")
|
|
80
|
-
log.info(
|
|
85
|
+
log.info(
|
|
86
|
+
f"Estimated {available:,} {self} jets available - {num_req:,} requested"
|
|
87
|
+
f"({self.reader.num_jets:,} in {self.sample})"
|
|
88
|
+
)
|
|
81
89
|
|
|
82
90
|
def get_auto_sampling_frac(self, num_jets, cuts=None, silent=False):
|
|
83
|
-
|
|
84
|
-
|
|
91
|
+
num_est = (
|
|
92
|
+
None if self.num_jets_estimate_available <= 0 else self.num_jets_estimate_available
|
|
93
|
+
)
|
|
94
|
+
total = self.reader.estimate_available_jets(cuts, num_est)
|
|
95
|
+
auto_sampling_frac = round(1.1 * num_jets / total, 3) # 1.1 is a tolerance factor
|
|
85
96
|
if not silent:
|
|
86
97
|
log.debug(f"optimal sampling fraction {auto_sampling_frac:.3f}")
|
|
87
98
|
return auto_sampling_frac
|
|
@@ -102,6 +113,7 @@ class Components:
|
|
|
102
113
|
def from_config(cls, pp_cfg):
|
|
103
114
|
components = []
|
|
104
115
|
for c in pp_cfg.config["components"]:
|
|
116
|
+
assert "equal_jets" not in c, "equal_jets flag should be set in the sample config"
|
|
105
117
|
region_cuts = Cuts.empty() if pp_cfg.is_test else Cuts.from_list(c["region"]["cuts"])
|
|
106
118
|
region = Region(c["region"]["name"], region_cuts + pp_cfg.global_cuts)
|
|
107
119
|
pattern = c["sample"]["pattern"]
|
|
@@ -119,11 +131,11 @@ class Components:
|
|
|
119
131
|
Component(
|
|
120
132
|
region,
|
|
121
133
|
sample,
|
|
122
|
-
|
|
134
|
+
pp_cfg.flavour_cont[name],
|
|
123
135
|
pp_cfg.global_cuts,
|
|
124
136
|
pp_cfg.components_dir,
|
|
125
137
|
num_jets,
|
|
126
|
-
pp_cfg.
|
|
138
|
+
pp_cfg.num_jets_estimate_available,
|
|
127
139
|
equal_jets,
|
|
128
140
|
)
|
|
129
141
|
)
|
|
@@ -193,7 +205,7 @@ class Components:
|
|
|
193
205
|
|
|
194
206
|
@property
|
|
195
207
|
def dsids(self):
|
|
196
|
-
return list(set(sum([c.sample.dsid for c in self], [])))
|
|
208
|
+
return list(set(sum([c.sample.dsid for c in self], []))) # noqa: RUF017
|
|
197
209
|
|
|
198
210
|
def groupby_region(self):
|
|
199
211
|
return [(r, Components([c for c in self if c.region == r])) for r in self.regions]
|
|
@@ -207,7 +219,7 @@ class Components:
|
|
|
207
219
|
def __getitem__(self, index):
|
|
208
220
|
if isinstance(index, int):
|
|
209
221
|
return self.components[index]
|
|
210
|
-
if isinstance(index, (str,
|
|
222
|
+
if isinstance(index, (str, Label)):
|
|
211
223
|
return self.components[self.flavours.index(index)]
|
|
212
224
|
|
|
213
225
|
def __len__(self):
|
|
@@ -6,12 +6,14 @@ import logging as log
|
|
|
6
6
|
from copy import copy
|
|
7
7
|
from dataclasses import dataclass
|
|
8
8
|
from pathlib import Path
|
|
9
|
-
from subprocess import CalledProcessError, check_output
|
|
10
9
|
from typing import Literal
|
|
11
10
|
|
|
12
11
|
import yaml
|
|
13
12
|
from dotmap import DotMap
|
|
14
13
|
from ftag import Cuts
|
|
14
|
+
from ftag.git_check import get_git_hash
|
|
15
|
+
from ftag.labels import LabelContainer
|
|
16
|
+
from ftag.track_selector import TrackSelector
|
|
15
17
|
from ftag.transform import Transform
|
|
16
18
|
from yamlinclude import YamlIncludeConstructor
|
|
17
19
|
|
|
@@ -42,6 +44,7 @@ class PreprocessingConfig:
|
|
|
42
44
|
For example:
|
|
43
45
|
```yaml
|
|
44
46
|
global:
|
|
47
|
+
jets_name: jets
|
|
45
48
|
batch_size: 1_000_000
|
|
46
49
|
num_jets_estimate: 5_000_000
|
|
47
50
|
base_dir: /my/stuff/
|
|
@@ -69,8 +72,21 @@ class PreprocessingConfig:
|
|
|
69
72
|
especially to the `countup` method to achieve best agreement of target and resampled
|
|
70
73
|
distributions.
|
|
71
74
|
num_jets_estimate : int
|
|
75
|
+
Any of the further three arguments that are not specified will default to this value
|
|
76
|
+
Is equal to 1_000_000 by default.
|
|
77
|
+
num_jets_estimate_available : int | None
|
|
78
|
+
A subsample taken from the whole sample to estimate the number of jets after the cuts.
|
|
79
|
+
Please keep this number high in order to not get a Poisson error of more than 5%.
|
|
80
|
+
If time allows you can use -1 to get a precise number of jets and not just an estimate
|
|
81
|
+
although it will be slow for large datasets. Is equal to num_jets_estimate by default.
|
|
82
|
+
num_jets_estimate_hist : int
|
|
72
83
|
Number of jets of each flavour that are used to construct histograms for probability
|
|
73
84
|
density function estimation. Larger numbers give a better quality estimate of the pdfs.
|
|
85
|
+
Is equal to num_jets_estimate by default.
|
|
86
|
+
num_jets_estimate_norm : int
|
|
87
|
+
Number of jets of each flavour that are used to estimate shifting and scaling during
|
|
88
|
+
normalisation step. Larger numbers give better quality estimates.
|
|
89
|
+
Is equal to num_jets_estimate by default.
|
|
74
90
|
jets_name : str
|
|
75
91
|
Name of the jets dataset in the input file.
|
|
76
92
|
"""
|
|
@@ -85,24 +101,50 @@ class PreprocessingConfig:
|
|
|
85
101
|
out_fname: Path = Path("pp_output.h5")
|
|
86
102
|
batch_size: int = 100_000
|
|
87
103
|
num_jets_estimate: int = 1_000_000
|
|
104
|
+
num_jets_estimate_available: int | None = None
|
|
105
|
+
num_jets_estimate_hist: int | None = None
|
|
106
|
+
num_jets_estimate_norm: int | None = None
|
|
88
107
|
merge_test_samples: bool = False
|
|
89
108
|
jets_name: str = "jets"
|
|
109
|
+
flavour_config: Path | None = None
|
|
90
110
|
|
|
91
111
|
def __post_init__(self):
|
|
92
112
|
# postprocess paths
|
|
113
|
+
if self.num_jets_estimate:
|
|
114
|
+
if self.num_jets_estimate_available is None:
|
|
115
|
+
self.num_jets_estimate_available = max(self.num_jets_estimate, int(1e6))
|
|
116
|
+
if self.num_jets_estimate_hist is None:
|
|
117
|
+
self.num_jets_estimate_hist = self.num_jets_estimate
|
|
118
|
+
if self.num_jets_estimate_norm is None:
|
|
119
|
+
self.num_jets_estimate_norm = self.num_jets_estimate
|
|
120
|
+
|
|
93
121
|
for field in dataclasses.fields(self):
|
|
94
|
-
if field.type == "Path" and field.name != "out_fname":
|
|
122
|
+
if field.type == "Path" and field.name != "out_fname" and field.name != "base_dir":
|
|
95
123
|
setattr(self, field.name, self.get_path(Path(getattr(self, field.name))))
|
|
96
124
|
if not self.ntuple_dir.exists():
|
|
97
125
|
raise FileNotFoundError(f"Path {self.ntuple_dir} does not exist")
|
|
98
126
|
self.components_dir = self.components_dir / self.split
|
|
99
127
|
self.out_fname = self.out_dir / path_append(self.out_fname, self.split)
|
|
128
|
+
self.flavour_cont = LabelContainer.from_yaml(self.flavour_config)
|
|
100
129
|
|
|
101
130
|
# configure classes
|
|
102
131
|
sampl_cfg = copy(self.config["resampling"])
|
|
103
|
-
self.
|
|
132
|
+
if self.is_test:
|
|
133
|
+
sampl_cfg["method"] = None
|
|
134
|
+
self.sampl_cfg = ResamplingConfig(**sampl_cfg)
|
|
104
135
|
self.components = Components.from_config(self)
|
|
105
|
-
|
|
136
|
+
|
|
137
|
+
# get track selectors
|
|
138
|
+
vc = self.config["variables"]
|
|
139
|
+
selectors = {}
|
|
140
|
+
for name, groups in vc.items():
|
|
141
|
+
if selection := groups.get("selection", None):
|
|
142
|
+
selectors[name] = TrackSelector(Cuts.from_list(selection))
|
|
143
|
+
|
|
144
|
+
# configure variables
|
|
145
|
+
self.variables = VariableConfig(
|
|
146
|
+
self.config["variables"], self.jets_name, self.is_test, selectors
|
|
147
|
+
)
|
|
106
148
|
self.variables = self.variables.add_jet_vars(
|
|
107
149
|
list(self.config["resampling"]["variables"].keys()), "labels"
|
|
108
150
|
)
|
|
@@ -110,17 +152,13 @@ class PreprocessingConfig:
|
|
|
110
152
|
Transform(**self.config["transform"]) if "transform" in self.config else None
|
|
111
153
|
)
|
|
112
154
|
|
|
113
|
-
#
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
["git", "rev-parse", "--short", "HEAD"], cwd=Path(__file__).parent
|
|
117
|
-
)
|
|
118
|
-
self.git_hash = git_hash.decode("ascii").strip()
|
|
119
|
-
self.config["pp_git_hash"] = self.git_hash
|
|
120
|
-
except CalledProcessError:
|
|
121
|
-
log.warning("Could not get git hash")
|
|
155
|
+
# reproducibility
|
|
156
|
+
self.git_hash = get_git_hash(Path(__file__).parent)
|
|
157
|
+
if self.git_hash is None:
|
|
122
158
|
self.git_hash = __version__
|
|
123
|
-
|
|
159
|
+
self.config["upp_hash"] = self.git_hash
|
|
160
|
+
|
|
161
|
+
# copy config
|
|
124
162
|
self.copy_config()
|
|
125
163
|
|
|
126
164
|
@classmethod
|
|
@@ -3,12 +3,15 @@ from __future__ import annotations
|
|
|
3
3
|
from copy import deepcopy
|
|
4
4
|
from dataclasses import dataclass
|
|
5
5
|
|
|
6
|
+
from ftag.track_selector import TrackSelector
|
|
7
|
+
|
|
6
8
|
|
|
7
9
|
@dataclass(frozen=True)
|
|
8
10
|
class VariableConfig:
|
|
9
11
|
variables: dict[str, dict[str, list[str]]]
|
|
10
12
|
jets_name: str = "jets"
|
|
11
13
|
keep_all: bool = False
|
|
14
|
+
selectors: dict[str, TrackSelector] | None = None
|
|
12
15
|
|
|
13
16
|
def __post_init__(self):
|
|
14
17
|
for track_vars in self.tracks.values():
|
|
@@ -33,7 +36,7 @@ class VariableConfig:
|
|
|
33
36
|
|
|
34
37
|
def add_jet_vars(self, variables: list[str], kind: str = "inputs") -> VariableConfig:
|
|
35
38
|
"""Return a new VariableConfig instance."""
|
|
36
|
-
vc = VariableConfig(deepcopy(self.variables), self.jets_name, self.keep_all)
|
|
39
|
+
vc = VariableConfig(deepcopy(self.variables), self.jets_name, self.keep_all, self.selectors)
|
|
37
40
|
vc.jets[kind] = list(dict.fromkeys(vc.jets[kind] + variables))
|
|
38
41
|
return vc
|
|
39
42
|
|
|
@@ -8,11 +8,13 @@ To run without certain stages, include the corresponding negative flag.
|
|
|
8
8
|
Note that all stages are required to run the pipeline. If you want to disable resampling,
|
|
9
9
|
you need to set method: none in your config file.
|
|
10
10
|
"""
|
|
11
|
+
|
|
11
12
|
from __future__ import annotations
|
|
12
13
|
|
|
13
14
|
import argparse
|
|
14
15
|
from datetime import datetime
|
|
15
|
-
|
|
16
|
+
|
|
17
|
+
from ftag.cli_utils import HelpFormatter, valid_path
|
|
16
18
|
|
|
17
19
|
from upp.classes.preprocessing_config import PreprocessingConfig
|
|
18
20
|
from upp.logger import setup_logger
|
|
@@ -23,30 +25,24 @@ from upp.stages.plot import plot_initial_resampling_dists, plot_resampled_dists
|
|
|
23
25
|
from upp.stages.resampling import Resampling
|
|
24
26
|
|
|
25
27
|
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
parser = argparse.ArgumentParser(
|
|
32
|
-
description=__doc__,
|
|
33
|
-
formatter_class=HelpFormatter,
|
|
34
|
-
)
|
|
35
|
-
parser.add_argument("--config", required=True, type=Path, help="Path to config file")
|
|
36
|
-
parser.add_argument("--prep", action=abool, default=None, help="Estimate and write PDFs")
|
|
28
|
+
def parse_args(args):
|
|
29
|
+
_st = "store_true"
|
|
30
|
+
parser = argparse.ArgumentParser(description=__doc__, formatter_class=HelpFormatter)
|
|
31
|
+
parser.add_argument("--config", required=True, type=valid_path, help="Path to config file")
|
|
32
|
+
parser.add_argument("--prep", action=_st, default=None, help="Estimate and write PDFs")
|
|
37
33
|
parser.add_argument("--no-prep", dest="prep", action="store_false")
|
|
38
|
-
parser.add_argument("--resample", action=
|
|
34
|
+
parser.add_argument("--resample", action=_st, default=None, help="Run resampling")
|
|
39
35
|
parser.add_argument("--no-resample", dest="resample", action="store_false")
|
|
40
|
-
parser.add_argument("--merge", action=
|
|
36
|
+
parser.add_argument("--merge", action=_st, default=None, help="Run merging")
|
|
41
37
|
parser.add_argument("--no-merge", dest="merge", action="store_false")
|
|
42
|
-
parser.add_argument("--norm", action=
|
|
38
|
+
parser.add_argument("--norm", action=_st, default=None, help="Compute normalisations")
|
|
43
39
|
parser.add_argument("--no-norm", dest="norm", action="store_false")
|
|
44
|
-
parser.add_argument("--plot", action=
|
|
40
|
+
parser.add_argument("--plot", action=_st, default=None, help="Plot output distributions")
|
|
45
41
|
parser.add_argument("--no-plot", dest="plot", action="store_false")
|
|
46
42
|
splits = ["train", "val", "test", "all"]
|
|
47
43
|
parser.add_argument("--split", default="train", choices=splits, help="Which file to produce")
|
|
48
44
|
|
|
49
|
-
args = parser.parse_args()
|
|
45
|
+
args = parser.parse_args(args)
|
|
50
46
|
d = vars(args)
|
|
51
47
|
ignore = ["config", "split"]
|
|
52
48
|
if not any(v for a, v in d.items() if a not in ignore):
|
|
@@ -65,7 +61,7 @@ def run_pp(args) -> None:
|
|
|
65
61
|
log.info(f"Start time: {start.strftime('%Y-%m-%d %H:%M:%S')}")
|
|
66
62
|
|
|
67
63
|
# load config
|
|
68
|
-
config = PreprocessingConfig.from_file(
|
|
64
|
+
config = PreprocessingConfig.from_file(args.config, args.split)
|
|
69
65
|
|
|
70
66
|
# create virtual datasets and pdf files
|
|
71
67
|
if args.prep and args.split == "train":
|
|
@@ -88,6 +84,8 @@ def run_pp(args) -> None:
|
|
|
88
84
|
|
|
89
85
|
# make plots
|
|
90
86
|
if args.plot:
|
|
87
|
+
title = " Plotting "
|
|
88
|
+
log.info(f"[bold green]{title:-^100}")
|
|
91
89
|
plot_initial_resampling_dists(config=config)
|
|
92
90
|
plot_resampled_dists(config=config, stage=args.split)
|
|
93
91
|
|
|
@@ -99,8 +97,8 @@ def run_pp(args) -> None:
|
|
|
99
97
|
log.info(f"Elapsed time: {str(end - start).split('.')[0]}")
|
|
100
98
|
|
|
101
99
|
|
|
102
|
-
def main() -> None:
|
|
103
|
-
args = parse_args()
|
|
100
|
+
def main(args=None) -> None:
|
|
101
|
+
args = parse_args(args)
|
|
104
102
|
log = setup_logger()
|
|
105
103
|
|
|
106
104
|
if args.split == "all":
|
|
@@ -130,19 +130,23 @@ def create_histograms(config) -> None:
|
|
|
130
130
|
title = " Writing PDFs "
|
|
131
131
|
log.info(f"[bold green]{title:-^100}")
|
|
132
132
|
|
|
133
|
-
log.info(f"[bold green]Estimating PDFs using {config.
|
|
133
|
+
log.info(f"[bold green]Estimating PDFs using {config.num_jets_estimate_hist:,} jets...")
|
|
134
134
|
sampl_vars = config.sampl_cfg.vars
|
|
135
135
|
for c in config.components:
|
|
136
|
-
log.info(f"Estimating PDF
|
|
137
|
-
c.setup_reader(config.batch_size)
|
|
136
|
+
log.info(f"Estimating {c} PDF using {config.num_jets_estimate_hist:,} samples...")
|
|
137
|
+
c.setup_reader(config.batch_size, config.jets_name)
|
|
138
138
|
cuts_no_split = c.cuts.ignore(["eventNumber"])
|
|
139
|
+
|
|
140
|
+
###
|
|
141
|
+
# TODO: return the number of jets here and pass to the next function to get started
|
|
142
|
+
###
|
|
139
143
|
c.check_num_jets(
|
|
140
|
-
config.
|
|
144
|
+
config.num_jets_estimate_hist,
|
|
141
145
|
cuts=cuts_no_split,
|
|
142
146
|
silent=False,
|
|
143
147
|
raise_error=False,
|
|
144
148
|
)
|
|
145
|
-
jets = c.get_jets(sampl_vars, config.
|
|
149
|
+
jets = c.get_jets(sampl_vars, config.num_jets_estimate_hist, cuts_no_split)
|
|
146
150
|
c.hist.write_hist(jets, sampl_vars, config.sampl_cfg.flat_bins)
|
|
147
151
|
|
|
148
152
|
log.info(f"[bold green]Saved to {config.components[0].hist.path.parent}/")
|
|
@@ -70,7 +70,7 @@ def upscale_array(
|
|
|
70
70
|
def upscale_array_regionally(
|
|
71
71
|
array: np.array,
|
|
72
72
|
upscl: int,
|
|
73
|
-
|
|
73
|
+
num_bins: list,
|
|
74
74
|
order: int = 3,
|
|
75
75
|
mode: str = "nearest",
|
|
76
76
|
positive: bool = True,
|
|
@@ -83,7 +83,7 @@ def upscale_array_regionally(
|
|
|
83
83
|
array to be upscaled
|
|
84
84
|
upscl : int
|
|
85
85
|
upscaling factor
|
|
86
|
-
|
|
86
|
+
num_bins : list
|
|
87
87
|
list of lists of region lengths in each dimension,
|
|
88
88
|
region lengths should sum to the length of the array in that dimension
|
|
89
89
|
order : int, optional
|
|
@@ -99,10 +99,10 @@ def upscale_array_regionally(
|
|
|
99
99
|
Array that is upscaled by a factor of upscl
|
|
100
100
|
"""
|
|
101
101
|
up_array = np.empty(shape=[ds * upscl for ds in array.shape])
|
|
102
|
-
starts = [np.cumsum([0] + regionlengths)[:-1] for regionlengths in
|
|
102
|
+
starts = [np.cumsum([0] + regionlengths)[:-1] for regionlengths in num_bins]
|
|
103
103
|
starts_grid = np.meshgrid(*starts)
|
|
104
104
|
starts_grid = [starts_grid[i].flatten() for i in range(len(starts_grid))]
|
|
105
|
-
finishes = [np.cumsum(regionlengths) for regionlengths in
|
|
105
|
+
finishes = [np.cumsum(regionlengths) for regionlengths in num_bins]
|
|
106
106
|
finishes_grid = np.meshgrid(*finishes)
|
|
107
107
|
finishes_grid = [finishes_grid[i].flatten() for i in range(len(finishes_grid))]
|
|
108
108
|
d = len(array.shape)
|
|
@@ -17,11 +17,13 @@ class Merging:
|
|
|
17
17
|
self.components = config.components
|
|
18
18
|
self.variables = config.variables
|
|
19
19
|
self.batch_size = config.batch_size
|
|
20
|
-
self.jets_name =
|
|
20
|
+
self.jets_name = config.jets_name
|
|
21
21
|
self.rng = np.random.default_rng(42)
|
|
22
22
|
self.flavours = self.components.flavours
|
|
23
23
|
|
|
24
24
|
def add_jet_flavour_label(self, jets, component):
|
|
25
|
+
if "flavour_label" in jets.dtype.names:
|
|
26
|
+
return jets
|
|
25
27
|
int_label = self.flavours.index(component.flavour)
|
|
26
28
|
label_array = np.full(len(jets), int_label, dtype=[("flavour_label", "i4")])
|
|
27
29
|
return join_structured_arrays([jets, label_array])
|
|
@@ -49,6 +51,13 @@ class Merging:
|
|
|
49
51
|
if all(c.complete for c in components):
|
|
50
52
|
return False
|
|
51
53
|
|
|
54
|
+
# apply track selections
|
|
55
|
+
for name in self.variables.variables:
|
|
56
|
+
if name == self.jets_name:
|
|
57
|
+
continue
|
|
58
|
+
if selector := self.variables.selectors.get(name):
|
|
59
|
+
merged[name] = selector(merged[name])
|
|
60
|
+
|
|
52
61
|
# write
|
|
53
62
|
self.writer.write(merged)
|
|
54
63
|
return len(merged[self.jets_name])
|
|
@@ -57,7 +66,7 @@ class Merging:
|
|
|
57
66
|
# setup inputs
|
|
58
67
|
for c in components:
|
|
59
68
|
batch_size = self.batch_size * c.num_jets // components.num_jets + 1
|
|
60
|
-
c.setup_reader(batch_size, fname=c.out_path)
|
|
69
|
+
c.setup_reader(batch_size, fname=c.out_path, jets_name=self.jets_name)
|
|
61
70
|
c.stream = c.reader.stream(self.variables.combined(), c.reader.num_jets)
|
|
62
71
|
c.complete = False
|
|
63
72
|
|
|
@@ -70,6 +79,7 @@ class Merging:
|
|
|
70
79
|
components[0].reader.dtypes(self.variables.combined()),
|
|
71
80
|
components[0].reader.shapes(components.num_jets, self.variables.keys()),
|
|
72
81
|
add_flavour_label=self.jets_name,
|
|
82
|
+
jets_name=self.jets_name,
|
|
73
83
|
)
|
|
74
84
|
self.writer.add_attr("flavour_label", [f.name for f in self.flavours], self.jets_name)
|
|
75
85
|
self.writer.add_attr("unique_jets", components.unique_jets)
|
|
@@ -16,7 +16,7 @@ class Normalisation:
|
|
|
16
16
|
self.components = config.components
|
|
17
17
|
self.variables = config.variables
|
|
18
18
|
self.jets_name = self.ppc.jets_name
|
|
19
|
-
self.num_jets = config.
|
|
19
|
+
self.num_jets = config.num_jets_estimate_norm
|
|
20
20
|
self.norm_fname = config.out_dir / config.config.get("norm_fname", "norm_dict.yaml")
|
|
21
21
|
self.class_fname = config.out_dir / config.config.get("class_fname", "class_dict.yaml")
|
|
22
22
|
|
|
@@ -62,7 +62,7 @@ class Normalisation:
|
|
|
62
62
|
return combined
|
|
63
63
|
|
|
64
64
|
def get_class_dict(self, batch):
|
|
65
|
-
ignore = ["VertexIndex", "ftagTruthParentBarcode", "barcode"]
|
|
65
|
+
ignore = ["VertexIndex", "ftagTruthParentBarcode", "barcode", "eventNumber", "jetFoldHash"]
|
|
66
66
|
class_dict = {k: {} for k in self.variables}
|
|
67
67
|
for name, array in batch.items():
|
|
68
68
|
if name != self.variables.jets_name:
|
|
@@ -118,7 +118,9 @@ class Normalisation:
|
|
|
118
118
|
log.info(f"[bold green]{title:-^100}")
|
|
119
119
|
|
|
120
120
|
# setup reader
|
|
121
|
-
reader = H5Reader(
|
|
121
|
+
reader = H5Reader(
|
|
122
|
+
self.ppc.out_fname, self.ppc.batch_size, precision="full", jets_name=self.jets_name
|
|
123
|
+
)
|
|
122
124
|
log.debug(f"Setup reader at: {self.ppc.out_fname}")
|
|
123
125
|
|
|
124
126
|
norm_dict = None
|
|
@@ -5,6 +5,7 @@ from pathlib import Path
|
|
|
5
5
|
|
|
6
6
|
from ftag import Flavours
|
|
7
7
|
from ftag.hdf5 import H5Reader
|
|
8
|
+
from ftag.labels import LabelContainer
|
|
8
9
|
from puma import Histogram, HistogramPlot
|
|
9
10
|
|
|
10
11
|
from upp.utils import path_append
|
|
@@ -14,6 +15,7 @@ def load_jets(
|
|
|
14
15
|
paths: str | list,
|
|
15
16
|
variable: str,
|
|
16
17
|
flavour_label="flavour_label",
|
|
18
|
+
jets_name="jets",
|
|
17
19
|
) -> dict:
|
|
18
20
|
"""
|
|
19
21
|
Load the variables and labels from the jets in a given file(s).
|
|
@@ -28,15 +30,18 @@ def load_jets(
|
|
|
28
30
|
flavour_label : str, optional
|
|
29
31
|
Name of the flavour label variable which is used for the labels,
|
|
30
32
|
by default "flavour_label"
|
|
33
|
+
jets_name: str, optional
|
|
34
|
+
Name of the jet dataset / the global objects
|
|
35
|
+
by default "jets"
|
|
31
36
|
|
|
32
37
|
Returns
|
|
33
38
|
-------
|
|
34
39
|
dict
|
|
35
40
|
Dict with the loaded variable and labels.
|
|
36
41
|
"""
|
|
37
|
-
variables = {
|
|
38
|
-
reader = H5Reader(paths, batch_size=1000)
|
|
39
|
-
df = reader.load(variables, num_jets=10000)[
|
|
42
|
+
variables = {jets_name: [flavour_label, variable]}
|
|
43
|
+
reader = H5Reader(paths, batch_size=1000, jets_name=jets_name)
|
|
44
|
+
df = reader.load(variables, num_jets=10000)[jets_name]
|
|
40
45
|
return df
|
|
41
46
|
|
|
42
47
|
|
|
@@ -45,8 +50,10 @@ def make_hist(
|
|
|
45
50
|
flavours: list,
|
|
46
51
|
variable: str,
|
|
47
52
|
in_paths: str | list,
|
|
53
|
+
jets_name: str = "jets",
|
|
48
54
|
bins_range: tuple | None = None,
|
|
49
55
|
suffix: str = "",
|
|
56
|
+
flavour_cont: LabelContainer = Flavours,
|
|
50
57
|
) -> None:
|
|
51
58
|
"""
|
|
52
59
|
Create and plot the histogram and save it to disk.
|
|
@@ -64,6 +71,9 @@ def make_hist(
|
|
|
64
71
|
Variable that is to be histogrammed and plotted.
|
|
65
72
|
in_paths : str
|
|
66
73
|
Path to the files from which the jets are loaded.
|
|
74
|
+
jets_name: str, optional
|
|
75
|
+
Name of the jet dataset / the global objects
|
|
76
|
+
by default "jets"
|
|
67
77
|
bins_range : tuple, optional
|
|
68
78
|
bins_range argument from puma.HistogramPlot,
|
|
69
79
|
by default None
|
|
@@ -72,11 +82,11 @@ def make_hist(
|
|
|
72
82
|
output name, by default "".
|
|
73
83
|
"""
|
|
74
84
|
# Load the variable from the jets
|
|
75
|
-
df = load_jets(in_paths, variable)
|
|
85
|
+
df = load_jets(in_paths, variable, jets_name=jets_name)
|
|
76
86
|
|
|
77
87
|
# Setup the histogram
|
|
78
88
|
plot = HistogramPlot(
|
|
79
|
-
ylabel="Normalised Number of
|
|
89
|
+
ylabel=f"Normalised Number of {jets_name}",
|
|
80
90
|
atlas_second_tag="$\\sqrt{s}=13$ TeV",
|
|
81
91
|
xlabel=variable,
|
|
82
92
|
bins=50,
|
|
@@ -94,8 +104,8 @@ def make_hist(
|
|
|
94
104
|
plot.add(
|
|
95
105
|
Histogram(
|
|
96
106
|
df[df["flavour_label"] == label_value][variable],
|
|
97
|
-
label=
|
|
98
|
-
colour=
|
|
107
|
+
label=flavour_cont[label_string].label,
|
|
108
|
+
colour=flavour_cont[label_string].colour,
|
|
99
109
|
)
|
|
100
110
|
)
|
|
101
111
|
|
|
@@ -118,6 +128,7 @@ def make_hist_initial(
|
|
|
118
128
|
flavours: list,
|
|
119
129
|
variable: str,
|
|
120
130
|
in_paths_list: str | list,
|
|
131
|
+
jets_name: str = "jets",
|
|
121
132
|
bins_range: tuple | None = None,
|
|
122
133
|
suffix: str = "",
|
|
123
134
|
jets_to_plot: int = -1,
|
|
@@ -125,7 +136,7 @@ def make_hist_initial(
|
|
|
125
136
|
suffixes: list | None = None,
|
|
126
137
|
out_format: str = "png",
|
|
127
138
|
) -> None:
|
|
128
|
-
"""Make
|
|
139
|
+
"""Make initial distribution plots.
|
|
129
140
|
|
|
130
141
|
Plot the initial distribution of the given variable
|
|
131
142
|
for multiple different samples (like ttbar, zpext, etc.)
|
|
@@ -145,6 +156,9 @@ def make_hist_initial(
|
|
|
145
156
|
in_paths_list : str | list
|
|
146
157
|
String or list of strings with the paths to the files
|
|
147
158
|
from which the jets are loaded.
|
|
159
|
+
jets_name: str, optional
|
|
160
|
+
Name of the jet dataset / the global objects
|
|
161
|
+
by default "jets"
|
|
148
162
|
bins_range : tuple, optional
|
|
149
163
|
bins_range argument from puma.HistogramPlot,
|
|
150
164
|
by default None
|
|
@@ -163,7 +177,7 @@ def make_hist_initial(
|
|
|
163
177
|
"""
|
|
164
178
|
# Setup the histogram
|
|
165
179
|
plot = HistogramPlot(
|
|
166
|
-
ylabel="Normalised Number of
|
|
180
|
+
ylabel=f"Normalised Number of {jets_name}",
|
|
167
181
|
atlas_second_tag="$\\sqrt{s}=13$ TeV",
|
|
168
182
|
xlabel=variable,
|
|
169
183
|
bins=100,
|
|
@@ -187,7 +201,7 @@ def make_hist_initial(
|
|
|
187
201
|
# Loop over the different samples
|
|
188
202
|
for i, in_paths in enumerate(in_paths_list):
|
|
189
203
|
# Load jets from the file
|
|
190
|
-
reader = H5Reader(in_paths, batch_size=10000)
|
|
204
|
+
reader = H5Reader(in_paths, batch_size=10000, jets_name=jets_name)
|
|
191
205
|
|
|
192
206
|
# Loop over the flavours
|
|
193
207
|
for flavour in flavours:
|
|
@@ -197,12 +211,10 @@ def make_hist_initial(
|
|
|
197
211
|
plot.add(
|
|
198
212
|
Histogram(
|
|
199
213
|
reader.load(
|
|
200
|
-
{
|
|
214
|
+
{jets_name: [variable]},
|
|
201
215
|
num_jets=jets_to_plot,
|
|
202
216
|
cuts=flavour.cuts,
|
|
203
|
-
)[
|
|
204
|
-
"jets"
|
|
205
|
-
][variable],
|
|
217
|
+
)[jets_name][variable],
|
|
206
218
|
label=flavour.label + " " + suffixes[i],
|
|
207
219
|
colour=flavour.colour,
|
|
208
220
|
linestyle=linestiles[i],
|
|
@@ -250,6 +262,7 @@ def plot_initial_resampling_dists(config) -> None:
|
|
|
250
262
|
flavours=config.components.flavours,
|
|
251
263
|
variable=var,
|
|
252
264
|
in_paths_list=paths,
|
|
265
|
+
jets_name=config.jets_name,
|
|
253
266
|
jets_to_plot=100000,
|
|
254
267
|
out_dir=config.out_dir / "plots",
|
|
255
268
|
suffixes=suffixes,
|
|
@@ -260,6 +273,7 @@ def plot_initial_resampling_dists(config) -> None:
|
|
|
260
273
|
flavours=config.components.flavours,
|
|
261
274
|
variable=var,
|
|
262
275
|
in_paths_list=paths,
|
|
276
|
+
jets_name=config.jets_name,
|
|
263
277
|
bins_range=(0, 500e3),
|
|
264
278
|
suffix="low",
|
|
265
279
|
jets_to_plot=100000,
|
|
@@ -293,15 +307,19 @@ def plot_resampled_dists(config, stage: str) -> None:
|
|
|
293
307
|
make_hist(
|
|
294
308
|
stage=stage,
|
|
295
309
|
flavours=config.components.flavours,
|
|
310
|
+
flavour_cont=config.flavour_cont,
|
|
296
311
|
variable=var,
|
|
297
312
|
in_paths=paths,
|
|
313
|
+
jets_name=config.jets_name,
|
|
298
314
|
)
|
|
299
315
|
if "pt" in var:
|
|
300
316
|
make_hist(
|
|
301
317
|
stage=stage,
|
|
302
318
|
flavours=config.components.flavours,
|
|
319
|
+
flavour_cont=config.flavour_cont,
|
|
303
320
|
variable=var,
|
|
304
321
|
in_paths=paths,
|
|
322
|
+
jets_name=config.jets_name,
|
|
305
323
|
bins_range=(0, 500e3),
|
|
306
324
|
suffix="low",
|
|
307
325
|
)
|
|
@@ -38,14 +38,14 @@ class Resampling:
|
|
|
38
38
|
self.components = config.components
|
|
39
39
|
self.variables = config.variables
|
|
40
40
|
self.batch_size = config.batch_size
|
|
41
|
-
self.
|
|
42
|
-
self.num_jets_estimate = config.num_jets_estimate
|
|
41
|
+
self.jets_name = config.jets_name
|
|
43
42
|
self.upscale_pdf = config.sampl_cfg.upscale_pdf or 1
|
|
44
|
-
self.
|
|
43
|
+
self.num_bins = self.get_num_bins_from_config()
|
|
45
44
|
self.methods_map = {
|
|
46
45
|
"pdf": self.pdf_select_func,
|
|
47
46
|
"countup": self.countup_select_func,
|
|
48
47
|
"none": None,
|
|
48
|
+
None: None,
|
|
49
49
|
}
|
|
50
50
|
if self.config.method not in self.methods_map:
|
|
51
51
|
raise ValueError(
|
|
@@ -97,7 +97,7 @@ class Resampling:
|
|
|
97
97
|
num_samples = int(len(jets) * component.sampling_fraction)
|
|
98
98
|
ratios = safe_divide(self.target.hist.pbin, component.hist.pbin)
|
|
99
99
|
if self.upscale_pdf > 1:
|
|
100
|
-
ratios = upscale_array_regionally(ratios, self.upscale_pdf, self.
|
|
100
|
+
ratios = upscale_array_regionally(ratios, self.upscale_pdf, self.num_bins)
|
|
101
101
|
probs = ratios[binnumbers]
|
|
102
102
|
idx = random.choices(np.arange(len(jets)), weights=probs, k=num_samples)
|
|
103
103
|
return idx
|
|
@@ -123,7 +123,7 @@ class Resampling:
|
|
|
123
123
|
|
|
124
124
|
# apply sampling
|
|
125
125
|
idx = np.arange(len(batch_out[self.variables.jets_name]))
|
|
126
|
-
if c != self.target and
|
|
126
|
+
if c != self.target and self.select_func:
|
|
127
127
|
idx = self.select_func(batch_out[self.variables.jets_name], c)
|
|
128
128
|
if len(idx) == 0:
|
|
129
129
|
continue
|
|
@@ -177,6 +177,7 @@ class Resampling:
|
|
|
177
177
|
reader = H5Reader(
|
|
178
178
|
sample.path,
|
|
179
179
|
self.batch_size,
|
|
180
|
+
jets_name=self.jets_name,
|
|
180
181
|
equal_jets=equal_jets_flag,
|
|
181
182
|
transform=self.transform,
|
|
182
183
|
)
|
|
@@ -242,8 +243,8 @@ class Resampling:
|
|
|
242
243
|
# setup i/o
|
|
243
244
|
for c in self.components:
|
|
244
245
|
# just used for the writer configuration
|
|
245
|
-
c.setup_reader(self.batch_size, transform=self.transform)
|
|
246
|
-
c.setup_writer(self.variables)
|
|
246
|
+
c.setup_reader(self.batch_size, jets_name=self.jets_name, transform=self.transform)
|
|
247
|
+
c.setup_writer(self.variables, jets_name=self.jets_name)
|
|
247
248
|
|
|
248
249
|
# set sampling fraction if needed
|
|
249
250
|
self.set_component_sampling_fractions()
|
|
@@ -254,7 +255,7 @@ class Resampling:
|
|
|
254
255
|
f" {self.config.sampling_fraction}..."
|
|
255
256
|
)
|
|
256
257
|
for c in self.components:
|
|
257
|
-
frac = c.sampling_fraction if
|
|
258
|
+
frac = c.sampling_fraction if self.select_func else 1
|
|
258
259
|
c.check_num_jets(c.num_jets, sampling_frac=frac, cuts=c.cuts)
|
|
259
260
|
|
|
260
261
|
# run resampling
|
|
@@ -268,7 +269,7 @@ class Resampling:
|
|
|
268
269
|
log.info(f"[bold green]Estimated unqiue jets: {unique:,.0f}")
|
|
269
270
|
log.info(f"[bold green]Saved to {self.components.out_dir}/")
|
|
270
271
|
|
|
271
|
-
def
|
|
272
|
+
def get_num_bins_from_config(self) -> list[list[int]]:
|
|
272
273
|
"""Get the lengths of the binning regions in each variable from the config.
|
|
273
274
|
|
|
274
275
|
Returns
|
|
@@ -276,7 +277,7 @@ class Resampling:
|
|
|
276
277
|
typing.List[typing.List[int]]
|
|
277
278
|
lengths of the binning regions in each variable from the config
|
|
278
279
|
"""
|
|
279
|
-
|
|
280
|
+
num_bins = []
|
|
280
281
|
for row in self.config.bins.values():
|
|
281
|
-
|
|
282
|
-
return
|
|
282
|
+
num_bins.append([sub[-1] for sub in row])
|
|
283
|
+
return num_bins
|
|
@@ -1,16 +0,0 @@
|
|
|
1
|
-
pyyaml-include==1.3
|
|
2
|
-
PyYAML==6.0.1
|
|
3
|
-
rich==12.6.0
|
|
4
|
-
scipy==1.10.1
|
|
5
|
-
puma-hep==0.3.0
|
|
6
|
-
atlas-ftag-tools==0.1.10
|
|
7
|
-
dotmap==1.3.30
|
|
8
|
-
|
|
9
|
-
[dev]
|
|
10
|
-
black==23.9.1
|
|
11
|
-
ruff==0.0.289
|
|
12
|
-
mypy==1.5.1
|
|
13
|
-
pre-commit==3.1.1
|
|
14
|
-
pytest==7.2.2
|
|
15
|
-
pytest-mock==3.11.1
|
|
16
|
-
pytest-cov==4.0.0
|
|
File without changes
|
|
File without changes
|
{umami-preprocessing-0.0.6 → umami_preprocessing-0.2.0}/umami_preprocessing.egg-info/SOURCES.txt
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
{umami-preprocessing-0.0.6 → umami_preprocessing-0.2.0}/umami_preprocessing.egg-info/top_level.txt
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|