atlas-ftag-tools 0.2.8__py3-none-any.whl → 0.2.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {atlas_ftag_tools-0.2.8.dist-info → atlas_ftag_tools-0.2.10.dist-info}/METADATA +4 -3
- {atlas_ftag_tools-0.2.8.dist-info → atlas_ftag_tools-0.2.10.dist-info}/RECORD +14 -12
- {atlas_ftag_tools-0.2.8.dist-info → atlas_ftag_tools-0.2.10.dist-info}/WHEEL +1 -1
- {atlas_ftag_tools-0.2.8.dist-info → atlas_ftag_tools-0.2.10.dist-info}/entry_points.txt +1 -1
- ftag/__init__.py +6 -5
- ftag/flavours.yaml +47 -4
- ftag/fraction_optimization.py +184 -0
- ftag/labels.py +10 -2
- ftag/mock.py +58 -17
- ftag/utils/__init__.py +24 -0
- ftag/utils/logging.py +123 -0
- ftag/utils/metrics.py +431 -0
- ftag/working_points.py +547 -0
- ftag/wps/__init__.py +0 -0
- ftag/wps/discriminant.py +0 -131
- ftag/wps/working_points.py +0 -316
- {atlas_ftag_tools-0.2.8.dist-info → atlas_ftag_tools-0.2.10.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
|
|
1
|
-
Metadata-Version: 2.
|
1
|
+
Metadata-Version: 2.4
|
2
2
|
Name: atlas-ftag-tools
|
3
|
-
Version: 0.2.
|
3
|
+
Version: 0.2.10
|
4
4
|
Summary: ATLAS Flavour Tagging Tools
|
5
5
|
Author: Sam Van Stroud, Philipp Gadow
|
6
6
|
License: MIT
|
@@ -8,8 +8,9 @@ Project-URL: Homepage, https://github.com/umami-hep/atlas-ftag-tools/
|
|
8
8
|
Requires-Python: <3.12,>=3.8
|
9
9
|
Description-Content-Type: text/markdown
|
10
10
|
Requires-Dist: h5py>=3.0
|
11
|
-
Requires-Dist: numpy
|
11
|
+
Requires-Dist: numpy>=2.2.3
|
12
12
|
Requires-Dist: PyYAML>=5.1
|
13
|
+
Requires-Dist: scipy>=1.15.2
|
13
14
|
Provides-Extra: dev
|
14
15
|
Requires-Dist: ruff==0.6.2; extra == "dev"
|
15
16
|
Requires-Dist: mypy==1.11.2; extra == "dev"
|
@@ -1,28 +1,30 @@
|
|
1
|
-
ftag/__init__.py,sha256=
|
1
|
+
ftag/__init__.py,sha256=v9emuK48Hhd-_TCiirfCNMsZSzk52frz1zEOgk9PViQ,787
|
2
2
|
ftag/cli_utils.py,sha256=w3TtQmUHSyAKChS3ewvOtcSDAUJAZGIIomaNi8f446U,298
|
3
3
|
ftag/cuts.py,sha256=9_ooLZHaO3SnIQBNxwbaPZn-qptGdKnB27FdKQGTiTY,2933
|
4
4
|
ftag/flavours.py,sha256=ShH4M2UjQZpZ_NlCctTm2q1tJbzYxjmGteioQ2GcqEU,114
|
5
|
-
ftag/flavours.yaml,sha256=
|
5
|
+
ftag/flavours.yaml,sha256=5Lo9KWe-2KzmGMbc7o_X9gzwUyTl0Q5uVHYExduZ6T4,9502
|
6
|
+
ftag/fraction_optimization.py,sha256=IlMEJe5fD0soX40f-LO4dYAYld2gMqgZRuBLctoPn9A,5566
|
6
7
|
ftag/git_check.py,sha256=Y-XqM80CVXZ5ZKrDdZcYOJt3X64uU6W3OP6Z0D7AZU0,1663
|
7
8
|
ftag/labeller.py,sha256=IXUgU9UBir39PxVWRKs5r5fqI66Tv0x7nJD3-RYpbrg,2780
|
8
|
-
ftag/labels.py,sha256=
|
9
|
-
ftag/mock.py,sha256=
|
9
|
+
ftag/labels.py,sha256=2nmcmrZD8mWQPxJsGiOgcLDhSVgWfS_cEzqsBV-Qy8o,4198
|
10
|
+
ftag/mock.py,sha256=P2D7nNKAz2jRBbmfpHTDj9sBVU9r7HGd0rpWZOJYZ90,5980
|
10
11
|
ftag/region.py,sha256=ANv0dGI2W6NJqD9fp7EfqAUReH4FOjc1gwl_Qn8llcM,360
|
11
12
|
ftag/sample.py,sha256=3N0FrRcu9l1sX8ohuGOHuMYGD0See6gMO4--7NzR2tE,2538
|
12
13
|
ftag/track_selector.py,sha256=fJNk_kIBQriBqV4CPT_3ReJbOUnavDDzO-u3EQlRuyk,2654
|
13
14
|
ftag/transform.py,sha256=uEGGJSnqoKOzLYQv650XdK_kDNw4Aw-5dc60z9Dp_y0,3963
|
14
15
|
ftag/vds.py,sha256=nRViQZQIORB95nC7NZsW3KsSoGkLzEdOsuCViH5h8-U,3296
|
16
|
+
ftag/working_points.py,sha256=RJws2jPMEDQDspCbXUZBifS1CCBmlMJ5ax0eMyDzCRA,15949
|
15
17
|
ftag/hdf5/__init__.py,sha256=LFDNxVOCp58SvLHwQhdT68Q-KBMS_i6jBrbXoRpHzbM,354
|
16
18
|
ftag/hdf5/h5move.py,sha256=oYpRu0IDCIJIQ2ML52HBAdoyDxmKkHTeM9JdbPEgKfI,947
|
17
19
|
ftag/hdf5/h5reader.py,sha256=i31pDAqmOSaxdeRhc4iSBlld8xJ0pmp4rNd7CugNzw0,13706
|
18
20
|
ftag/hdf5/h5split.py,sha256=4Wy6Xc3J58MdD9aBaSZHf5ZcVFnJSkWsm42R5Pgo-R4,2448
|
19
21
|
ftag/hdf5/h5utils.py,sha256=-4zKTMtNCrDZr_9Ww7uzfsB7M7muBKpmm_1IkKJnHOI,3222
|
20
22
|
ftag/hdf5/h5writer.py,sha256=9FkClV__UbBqmFsq_h2jwiZnbWVm8QFRL_4mDZZBbTs,5316
|
21
|
-
ftag/
|
22
|
-
ftag/
|
23
|
-
ftag/
|
24
|
-
atlas_ftag_tools-0.2.
|
25
|
-
atlas_ftag_tools-0.2.
|
26
|
-
atlas_ftag_tools-0.2.
|
27
|
-
atlas_ftag_tools-0.2.
|
28
|
-
atlas_ftag_tools-0.2.
|
23
|
+
ftag/utils/__init__.py,sha256=C0PgaA6Nk5WVpFqKhBhrHgj2mwsKJbSxoO6Cl67RsaI,544
|
24
|
+
ftag/utils/logging.py,sha256=54NaQiC9Bh4vSznSqzoPfR-7tj1PXfmoH7yKgv_ZHZk,3192
|
25
|
+
ftag/utils/metrics.py,sha256=zQI4nPeRDSyzqKpdOPmu0GU560xSWoW1wgL13rrja-I,12664
|
26
|
+
atlas_ftag_tools-0.2.10.dist-info/METADATA,sha256=VUhrtQML6_bUKlmZNFlUXxTTt5YBzNYupTrdlaF5IAw,5190
|
27
|
+
atlas_ftag_tools-0.2.10.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
|
28
|
+
atlas_ftag_tools-0.2.10.dist-info/entry_points.txt,sha256=b46bVP_O8Mg6aSdPmyjGgVkaXSdyXZMeKAsofh2IDeA,133
|
29
|
+
atlas_ftag_tools-0.2.10.dist-info/top_level.txt,sha256=qiYQuKcAvMim-31FwkT3MTQu7WQm0s58tPAia5KKWqs,5
|
30
|
+
atlas_ftag_tools-0.2.10.dist-info/RECORD,,
|
ftag/__init__.py
CHANGED
@@ -2,18 +2,18 @@
|
|
2
2
|
|
3
3
|
from __future__ import annotations
|
4
4
|
|
5
|
-
__version__ = "v0.2.
|
5
|
+
__version__ = "v0.2.10"
|
6
6
|
|
7
|
-
from ftag import hdf5
|
7
|
+
from ftag import hdf5, utils
|
8
8
|
from ftag.cuts import Cuts
|
9
9
|
from ftag.flavours import Flavours
|
10
|
+
from ftag.fraction_optimization import calculate_best_fraction_values
|
10
11
|
from ftag.labeller import Labeller
|
11
12
|
from ftag.labels import Label, LabelContainer
|
12
13
|
from ftag.mock import get_mock_file
|
13
14
|
from ftag.sample import Sample
|
14
15
|
from ftag.transform import Transform
|
15
|
-
from ftag.
|
16
|
-
from ftag.wps.working_points import get_working_points
|
16
|
+
from ftag.working_points import get_working_points
|
17
17
|
|
18
18
|
__all__ = [
|
19
19
|
"Cuts",
|
@@ -24,8 +24,9 @@ __all__ = [
|
|
24
24
|
"Sample",
|
25
25
|
"Transform",
|
26
26
|
"__version__",
|
27
|
-
"
|
27
|
+
"calculate_best_fraction_values",
|
28
28
|
"get_mock_file",
|
29
29
|
"get_working_points",
|
30
30
|
"hdf5",
|
31
|
+
"utils",
|
31
32
|
]
|
ftag/flavours.yaml
CHANGED
@@ -60,12 +60,24 @@
|
|
60
60
|
colour: tab:orange
|
61
61
|
category: single-btag-ghost
|
62
62
|
_px: pc
|
63
|
-
- name:
|
64
|
-
label:
|
65
|
-
cuts: ["HadronGhostTruthLabelID == 0"]
|
63
|
+
- name: ghostsjets
|
64
|
+
label: $s$-jets
|
65
|
+
cuts: ["HadronGhostTruthLabelID == 0", "PartonTruthLabelID == 3"]
|
66
|
+
colour: tab:red
|
67
|
+
category: single-btag-ghost
|
68
|
+
_px: ps
|
69
|
+
- name: ghostudjets
|
70
|
+
label: Light-quark-jets
|
71
|
+
cuts: ["HadronGhostTruthLabelID == 0", "PartonTruthLabelID <= 2"]
|
66
72
|
colour: tab:green
|
67
73
|
category: single-btag-ghost
|
68
|
-
_px:
|
74
|
+
_px: pud
|
75
|
+
- name: ghostgjets
|
76
|
+
label: Gluon-jets
|
77
|
+
cuts: ["HadronGhostTruthLabelID == 0", "PartonTruthLabelID == 21"]
|
78
|
+
colour: tab:gray
|
79
|
+
category: single-btag-ghost
|
80
|
+
_px: pg
|
69
81
|
- name: ghosttaujets
|
70
82
|
label: $\tau$-jets
|
71
83
|
cuts: ["HadronGhostTruthLabelID == 15"]
|
@@ -119,6 +131,21 @@
|
|
119
131
|
cuts: ["R10TruthLabel_R22v1 == 10", "GhostBHadronsFinalCount == 0", "GhostCHadronsFinalCount == 0"]
|
120
132
|
colour: "green"
|
121
133
|
category: xbb
|
134
|
+
- name: htauel
|
135
|
+
label: $H \rightarrow \tau e$
|
136
|
+
cuts: ["R10TruthLabel_R22v1 == 14"]
|
137
|
+
colour: "#b40612"
|
138
|
+
category: xbb
|
139
|
+
- name: htaumu
|
140
|
+
label: $H \rightarrow \tau\mu$
|
141
|
+
cuts: ["R10TruthLabel_R22v1 == 15"]
|
142
|
+
colour: "#b40657"
|
143
|
+
category: xbb
|
144
|
+
- name: htauhad
|
145
|
+
label: $H \rightarrow \tau\tau$
|
146
|
+
cuts: ["R10TruthLabel_R22v1 == 16"]
|
147
|
+
colour: "#b406a0"
|
148
|
+
category: xbb
|
122
149
|
|
123
150
|
# extended Xbb tagging
|
124
151
|
- name: tqqb
|
@@ -305,3 +332,19 @@
|
|
305
332
|
cuts: ["iffClass == 0"]
|
306
333
|
colour: tab:gray
|
307
334
|
category: isolation
|
335
|
+
# Trigger-Xbb tagging
|
336
|
+
- name: dRMatchedHbb
|
337
|
+
label: $H \rightarrow b\bar{b}$
|
338
|
+
cuts: ["HadronConeExclExtendedTruthLabelID == 55", "n_truth_higgs > 0", "n_truth_top == 0"]
|
339
|
+
colour: tab:blue
|
340
|
+
category: trigger-xbb
|
341
|
+
- name: dRMatchedTop
|
342
|
+
label: Inclusive Top
|
343
|
+
cuts: ["n_truth_higgs == 0", "n_truth_top > 0"]
|
344
|
+
colour: "#A300A3"
|
345
|
+
category: trigger-xbb
|
346
|
+
- name: dRMatchedQCD
|
347
|
+
label: QCD
|
348
|
+
cuts: ["n_truth_higgs == 0", "n_truth_top == 0"]
|
349
|
+
colour: "#38761D"
|
350
|
+
category: trigger-xbb
|
@@ -0,0 +1,184 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from typing import TYPE_CHECKING
|
4
|
+
|
5
|
+
import numpy as np
|
6
|
+
from scipy.optimize import minimize
|
7
|
+
|
8
|
+
from ftag import Flavours
|
9
|
+
from ftag.utils import calculate_rejection, get_discriminant, logger
|
10
|
+
|
11
|
+
if TYPE_CHECKING: # pragma: no cover
|
12
|
+
from ftag.labels import Label, LabelContainer
|
13
|
+
|
14
|
+
|
15
|
+
def convert_dict(
    fraction_values: dict | np.ndarray,
    backgrounds: LabelContainer,
) -> np.ndarray | dict:
    """Convert fraction values between their dict and array representations.

    Parameters
    ----------
    fraction_values : dict | np.ndarray
        Either a dict mapping each background's ``frac_str`` to its fraction value,
        or an array of fraction values ordered like ``backgrounds``.
    backgrounds : LabelContainer
        Background labels; their iteration order defines the array order.

    Returns
    -------
    np.ndarray | dict
        If a dict was given: an array of its values ordered like ``backgrounds``.
        If an array was given: a dict keyed by ``frac_str`` whose values are the
        array entries normalised to sum to one.

    Raises
    ------
    ValueError
        If ``fraction_values`` is neither a ``dict`` nor a ``np.ndarray``.
    """
    if isinstance(fraction_values, dict):
        return np.array([fraction_values[iter_bkg.frac_str] for iter_bkg in backgrounds])

    if isinstance(fraction_values, np.ndarray):
        n_values = len(fraction_values)
        # Compute the normalisation once instead of once per element
        total = float(np.sum(fraction_values))
        if total == 0:
            # An all-zero vector is reachable with optimizer bounds (0, 1); 0 / 0 would
            # give NaN fractions, so fall back to equal fractions instead.
            normalised = [1 / n_values for _ in range(n_values)]
        else:
            normalised = [float(frac_value / total) for frac_value in fraction_values]

        return dict(zip([iter_bkg.frac_str for iter_bkg in backgrounds], normalised))

    raise ValueError(
        f"Only input of type `dict` or `np.ndarray` are accepted! You gave {type(fraction_values)}"
    )
|
32
|
+
|
33
|
+
|
34
|
+
def get_bkg_norm_dict(
    jets: np.ndarray,
    tagger: str,
    signal: Label,
    flavours: LabelContainer,
    working_point: float,
) -> dict:
    """Compute the per-background rejection used to normalise the optimisation.

    For each background class, the discriminant is built with fraction values that
    give 0.01 to every other background and the remainder to the class itself. The
    rejection of that class at ``working_point`` is then stored.

    Parameters
    ----------
    jets : np.ndarray
        Jets array, passed to the label cuts and to ``get_discriminant``.
    tagger : str
        Name of the tagger whose outputs are used for the discriminant.
    signal : Label
        Signal label.
    flavours : LabelContainer
        Flavours from which the backgrounds of ``signal`` are taken.
    working_point : float
        Target signal efficiency at which the rejections are evaluated.

    Returns
    -------
    dict
        Mapping of background ``name`` to its rejection at ``working_point``.
    """
    bkg_rej_norm = {}

    # Get the background classes
    backgrounds = flavours.backgrounds(signal)

    # Signal mask is the same for every background
    is_signal = signal.cuts(jets).idx

    # Fraction given to each non-targeted background
    minor_frac = 0.01
    # The targeted background takes the remainder, so the fractions sum to one
    # (there are len(backgrounds) - 1 minor backgrounds, not len(backgrounds)).
    major_frac = 1 - minor_frac * (len(backgrounds) - 1)

    for bkg in backgrounds:
        # Fraction values that (nearly) maximise the rejection of this class
        frac_dict_bkg = {
            iter_bkg.frac_str: major_frac if iter_bkg == bkg else minor_frac
            for iter_bkg in backgrounds
        }

        disc = get_discriminant(
            jets=jets,
            tagger=tagger,
            signal=signal,
            flavours=flavours,
            fraction_values=frac_dict_bkg,
        )

        # Rejection of this background at the requested signal efficiency
        bkg_rej_norm[bkg.name] = calculate_rejection(
            sig_disc=disc[is_signal],
            bkg_disc=disc[bkg.cuts(jets).idx],
            target_eff=working_point,
        )

    return bkg_rej_norm
|
75
|
+
|
76
|
+
|
77
|
+
def calculate_rejection_sum(
    fraction_dict: dict | np.ndarray,
    jets: np.ndarray,
    tagger: str,
    signal: Label,
    flavours: LabelContainer,
    working_point: float,
    bkg_norm_dict: dict,
    rejection_weights: dict,
) -> float:
    """Objective for the fraction optimisation.

    Each background's rejection at ``working_point`` is divided by its entry in
    ``bkg_norm_dict`` and multiplied by its entry in ``rejection_weights``. The
    negative of the sum is returned so that a minimiser maximises it.

    Parameters
    ----------
    fraction_dict : dict | np.ndarray
        Fraction values, either keyed by ``frac_str`` or as an array ordered like the
        backgrounds of ``signal`` (arrays are converted with ``convert_dict``).
    jets : np.ndarray
        Jets array, passed to the label cuts and to ``get_discriminant``.
    tagger : str
        Name of the tagger whose outputs are used for the discriminant.
    signal : Label
        Signal label.
    flavours : LabelContainer
        Flavours from which the backgrounds of ``signal`` are taken.
    working_point : float
        Target signal efficiency at which the rejections are evaluated.
    bkg_norm_dict : dict
        Normalisation per background ``name``.
    rejection_weights : dict
        Weight per background ``name``.

    Returns
    -------
    float
        Negative weighted sum of the normalised background rejections.
    """
    backgrounds = flavours.backgrounds(signal)
    is_signal = signal.cuts(jets).idx

    # Minimisers hand over a plain array; map it back onto the frac_str keys
    if isinstance(fraction_dict, np.ndarray):
        fraction_dict = convert_dict(
            fraction_values=fraction_dict,
            backgrounds=backgrounds,
        )

    disc = get_discriminant(
        jets=jets,
        tagger=tagger,
        signal=signal,
        flavours=flavours,
        fraction_values=fraction_dict,
    )

    weighted_rejections = (
        calculate_rejection(
            sig_disc=disc[is_signal],
            bkg_disc=disc[bkg.cuts(jets).idx],
            target_eff=working_point,
        )
        / bkg_norm_dict[bkg.name]
        * rejection_weights[bkg.name]
        for bkg in backgrounds
    )

    # Negated so that minimising this value maximises the total rejection
    return -sum(weighted_rejections)
|
125
|
+
|
126
|
+
|
127
|
+
def calculate_best_fraction_values(
    jets: np.ndarray,
    tagger: str,
    signal: Label | str,
    flavours: LabelContainer,
    working_point: float,
    rejection_weights: dict | None = None,
    optimizer_method: str = "Powell",
) -> dict:
    """Find the fraction values that maximise the weighted background rejection.

    Each background's rejection is normalised by the value from
    ``get_bkg_norm_dict`` and weighted by ``rejection_weights``. The weighted sum is
    maximised with ``scipy.optimize.minimize``, starting from equal fractions.

    Parameters
    ----------
    jets : np.ndarray
        Jets array, passed to the label cuts and to ``get_discriminant``.
    tagger : str
        Name of the tagger whose outputs are used for the discriminant.
    signal : Label | str
        Signal label. A string is resolved through ``Flavours``.
    flavours : LabelContainer
        Flavours from which the backgrounds of ``signal`` are taken.
    working_point : float
        Target signal efficiency at which the rejections are evaluated.
    rejection_weights : dict | None, optional
        Weight per background ``name``. Backgrounds missing from the dict (or all of
        them, if None) get a weight of 1.
    optimizer_method : str, optional
        Method passed to ``scipy.optimize.minimize``, by default "Powell".

    Returns
    -------
    dict
        Mapping of each background ``frac_str`` to its optimised fraction value,
        normalised to sum to one.
    """
    logger.debug("Calculating best fraction values.")
    logger.debug(f"tagger: {tagger}")
    logger.debug(f"signal: {signal}")
    logger.debug(f"flavours: {flavours}")
    logger.debug(f"working_point: {working_point}")
    logger.debug(f"rejection_weights: {rejection_weights}")
    logger.debug(f"optimizer_method: {optimizer_method}")

    # Ensure Label instance
    if isinstance(signal, str):
        signal = Flavours[signal]

    backgrounds = flavours.backgrounds(signal)

    # Starting point: equal fraction for every background
    def_frac_dict = {iter_bkg.frac_str: 1 / len(backgrounds) for iter_bkg in backgrounds}

    # Every background defaults to weight 1. User-supplied weights override per class,
    # so a partial dict no longer raises a KeyError inside the objective. A new dict is
    # built so the caller's argument is not mutated.
    default_weights = {iter_bkg.name: 1 for iter_bkg in backgrounds}
    rejection_weights = {**default_weights, **(rejection_weights or {})}

    # Get the normalisation for all bkg rejections
    bkg_norm_dict = get_bkg_norm_dict(
        jets=jets,
        tagger=tagger,
        signal=signal,
        flavours=flavours,
        working_point=working_point,
    )

    # Minimise the negative weighted rejection sum
    result = minimize(
        fun=calculate_rejection_sum,
        x0=convert_dict(fraction_values=def_frac_dict, backgrounds=backgrounds),
        method=optimizer_method,
        bounds=[(0, 1)] * len(backgrounds),
        args=(jets, tagger, signal, flavours, working_point, bkg_norm_dict, rejection_weights),
    )

    # Convert the optimiser output back to a (normalised) frac_str dict
    final_frac_dict = convert_dict(fraction_values=result.x, backgrounds=backgrounds)

    logger.info(f"Minimization Success: {result.success}")
    logger.info("The following best fraction values were found:")
    for frac_str, frac_value in final_frac_dict.items():
        logger.info(f"{frac_str}: {round(frac_value, ndigits=3)}")

    return final_frac_dict
|
ftag/labels.py
CHANGED
@@ -62,6 +62,9 @@ class LabelContainer:
|
|
62
62
|
except KeyError as e:
|
63
63
|
raise KeyError(f"Label '{key}' not found") from e
|
64
64
|
|
65
|
+
def __len__(self) -> int:
    """Return the number of labels in the container."""
    # len() of a mapping already counts its keys; no keys view is needed
    return len(self.labels)
|
67
|
+
|
65
68
|
def __getattr__(self, name) -> Label:
|
66
69
|
return self[name]
|
67
70
|
|
@@ -120,8 +123,13 @@ class LabelContainer:
|
|
120
123
|
def from_list(cls, labels: list[Label]) -> LabelContainer:
|
121
124
|
return cls({f.name: f for f in labels})
|
122
125
|
|
123
|
-
def backgrounds(self,
|
124
|
-
bkg = [f for f in self if f.category ==
|
126
|
+
def backgrounds(self, signal: Label, only_signals: bool = True) -> LabelContainer:
    """Return the labels in the signal's category, excluding the signal itself.

    Parameters
    ----------
    signal : Label
        Signal label; candidate backgrounds must share its ``category``.
    only_signals : bool, optional
        If False, labels named ``"ujets"`` or ``"qcd"`` are excluded as well.
        By default True (no name-based exclusion).

    Returns
    -------
    LabelContainer
        Container built from the matching background labels, in iteration order.

    Raises
    ------
    TypeError
        If no background label is found for ``signal``.
    """
    # Names dropped only when only_signals is False
    excluded_names = set() if only_signals else {"ujets", "qcd"}

    bkg = [
        label
        for label in self
        if label.category == signal.category
        and label != signal
        and label.name not in excluded_names
    ]

    # NOTE(review): ValueError would be more conventional here, but TypeError is kept
    # because callers may already catch it.
    if not bkg:
        raise TypeError(
            "No background flavour could be found in the flavours for signal "
            f"flavour {signal.name}"
        )

    return LabelContainer.from_list(bkg)
|
ftag/mock.py
CHANGED
@@ -54,33 +54,74 @@ TRACK_VARS = [
|
|
54
54
|
]
|
55
55
|
|
56
56
|
|
57
|
-
def softmax(x, axis=None):
|
57
|
+
def softmax(x: np.ndarray, axis: int | None = None) -> np.ndarray:
    """Compute the softmax of ``x`` along ``axis``.

    The maximum along ``axis`` is subtracted before exponentiating so that large
    inputs do not overflow; this does not change the result.

    Parameters
    ----------
    x : np.ndarray
        Input values.
    axis : int | None, optional
        Axis along which the outputs sum to one. If None (the default), the softmax
        is taken over all elements.

    Returns
    -------
    np.ndarray
        Softmax output with the same shape as ``x``.
    """
    peak = np.max(x, axis=axis, keepdims=True)
    exponentials = np.exp(x - peak)
    normaliser = np.sum(exponentials, axis=axis, keepdims=True)
    return exponentials / normaliser
|
60
74
|
|
61
75
|
|
62
|
-
def get_mock_scores(labels: np.ndarray, is_xbb: bool = False):
|
63
|
-
means = [
|
64
|
-
[2, 0, 0, 0],
|
65
|
-
[0, 1, 0, 0],
|
66
|
-
[0, 0, 3.5, 0],
|
67
|
-
[0, 0, 0, 1],
|
68
|
-
]
|
76
|
+
def get_mock_scores(labels: np.ndarray, is_xbb: bool = False) -> np.ndarray:
|
69
77
|
if not is_xbb:
|
70
78
|
label_dict = {"u": 0, "c": 4, "b": 5, "tau": 15}
|
71
|
-
label_mapping = dict(zip(label_dict.values(), means))
|
72
|
-
else:
|
73
|
-
label_dict = {"hbb": 11, "hcc": 12, "top": 1, "qcd": 10}
|
74
|
-
label_mapping = dict(zip(label_dict.values(), means))
|
75
79
|
|
80
|
+
else:
|
81
|
+
label_dict = {
|
82
|
+
"hbb": 11,
|
83
|
+
"hcc": 12,
|
84
|
+
"top": 1,
|
85
|
+
"qcd": 10,
|
86
|
+
"htauel": 14,
|
87
|
+
"htaumu": 15,
|
88
|
+
"htauhad": 16,
|
89
|
+
}
|
90
|
+
|
91
|
+
# Set random seed
|
76
92
|
rng = np.random.default_rng(42)
|
77
|
-
|
78
|
-
|
79
|
-
|
93
|
+
|
94
|
+
# Set a list of possible means/scales
|
95
|
+
mean_scale_list = [1, 2, 2.5, 3.5]
|
96
|
+
|
97
|
+
# Get the number of classes
|
98
|
+
n_classes = len(label_dict)
|
99
|
+
|
100
|
+
# Init a scores array
|
101
|
+
scores = np.zeros((len(labels), n_classes))
|
102
|
+
|
103
|
+
# Generate means/scales
|
104
|
+
means = []
|
105
|
+
scales = []
|
106
|
+
for i in range(n_classes):
|
107
|
+
tmp_means = []
|
108
|
+
tmp_means = [
|
109
|
+
0 if j != i else mean_scale_list[np.random.randint(0, len(mean_scale_list))]
|
110
|
+
for j in range(n_classes)
|
111
|
+
]
|
112
|
+
means.append(tmp_means)
|
113
|
+
scales.append(mean_scale_list[np.random.randint(0, len(mean_scale_list))])
|
114
|
+
|
115
|
+
# Map the labels to the means
|
116
|
+
label_mapping = dict(zip(label_dict.values(), means))
|
117
|
+
|
118
|
+
# Generate random mock scores
|
80
119
|
for i, (label, count) in enumerate(zip(*np.unique(labels, return_counts=True))):
|
81
120
|
scores[labels == label] = rng.normal(
|
82
|
-
loc=label_mapping[label], scale=scales[i], size=(count,
|
121
|
+
loc=label_mapping[label], scale=scales[i], size=(count, n_classes)
|
83
122
|
)
|
123
|
+
|
124
|
+
# Pipe scores through softmax
|
84
125
|
scores = softmax(scores, axis=1)
|
85
126
|
name = "MockXbbTagger" if is_xbb else "MockTagger"
|
86
127
|
cols = [f"{name}_p{x}" for x in label_dict]
|
@@ -103,7 +144,7 @@ def mock_jets(num_jets=1000) -> np.ndarray:
|
|
103
144
|
jets["HadronConeExclTruthLabelID"] = rng.choice([0, 4, 5, 15], size=num_jets)
|
104
145
|
jets["GhostBHadronsFinalCount"] = rng.choice([0, 1, 2], size=num_jets)
|
105
146
|
jets["GhostCHadronsFinalCount"] = rng.choice([0, 1, 2], size=num_jets)
|
106
|
-
jets["R10TruthLabel_R22v1"] = rng.choice([1, 10, 11, 12], size=num_jets)
|
147
|
+
jets["R10TruthLabel_R22v1"] = rng.choice([1, 10, 11, 12, 14, 15, 16], size=num_jets)
|
107
148
|
scores = get_mock_scores(jets["HadronConeExclTruthLabelID"])
|
108
149
|
xbb_scores = get_mock_scores(jets["R10TruthLabel_R22v1"], is_xbb=True)
|
109
150
|
return join_structured_arrays([jets, scores, xbb_scores])
|
ftag/utils/__init__.py
ADDED
@@ -0,0 +1,24 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from ftag.utils.logging import logger, set_log_level
|
4
|
+
from ftag.utils.metrics import (
|
5
|
+
calculate_efficiency,
|
6
|
+
calculate_efficiency_error,
|
7
|
+
calculate_rejection,
|
8
|
+
calculate_rejection_error,
|
9
|
+
get_discriminant,
|
10
|
+
save_divide,
|
11
|
+
weighted_percentile,
|
12
|
+
)
|
13
|
+
|
14
|
+
__all__ = [
|
15
|
+
"calculate_efficiency",
|
16
|
+
"calculate_efficiency_error",
|
17
|
+
"calculate_rejection",
|
18
|
+
"calculate_rejection_error",
|
19
|
+
"get_discriminant",
|
20
|
+
"logger",
|
21
|
+
"save_divide",
|
22
|
+
"set_log_level",
|
23
|
+
"weighted_percentile",
|
24
|
+
]
|
ftag/utils/logging.py
ADDED
@@ -0,0 +1,123 @@
|
|
1
|
+
"""Configuration for logger of atlas-ftag-tools."""
|
2
|
+
|
3
|
+
from __future__ import annotations
|
4
|
+
|
5
|
+
import logging
|
6
|
+
import os
|
7
|
+
from typing import ClassVar
|
8
|
+
|
9
|
+
|
10
|
+
class CustomFormatter(logging.Formatter):
    """Colourised log formatter with a per-level format string.

    Adapted from
    https://stackoverflow.com/questions/384076/how-can-i-color-python-logging-output.

    DEBUG, ERROR and CRITICAL records use ``debugformat`` (timestamp plus source file
    and line). INFO and WARNING use the shorter ``date_format``. Each format is
    wrapped in an ANSI colour code and a reset code. Levels missing from ``formats``
    fall back to ``logging.Formatter``'s default format.
    """

    grey = "\x1b[38;21m"
    yellow = "\x1b[33;21m"
    green = "\x1b[32;21m"
    red = "\x1b[31;21m"
    bold_red = "\x1b[31;1m"
    reset = "\x1b[0m"
    debugformat = "%(asctime)s - %(levelname)s:%(name)s: %(message)s (%(filename)s:%(lineno)d)"
    # Despite its name, this is the message format used for INFO/WARNING (no date)
    date_format = "%(levelname)s:%(name)s: %(message)s"

    formats: ClassVar = {
        logging.DEBUG: grey + debugformat + reset,
        logging.INFO: green + date_format + reset,
        logging.WARNING: yellow + date_format + reset,
        logging.ERROR: red + debugformat + reset,
        logging.CRITICAL: bold_red + debugformat + reset,
    }

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # One Formatter per level number, built lazily on first use instead of
        # constructing a new one for every emitted record.
        self._formatters = {}

    def format(self, record):
        """Format ``record`` with the colourised format registered for its level."""
        formatter = self._formatters.get(record.levelno)
        if formatter is None:
            formatter = logging.Formatter(self.formats.get(record.levelno))
            self._formatters[record.levelno] = formatter
        return formatter.format(record)
|
38
|
+
|
39
|
+
|
40
|
+
def get_log_level(
    level: str,
):
    """Get the ``logging`` level constant for a level name.

    Parameters
    ----------
    level : str
        Log level name: CRITICAL, ERROR, WARNING, INFO, DEBUG or NOTSET. Matching is
        case-insensitive (e.g. ``"debug"`` is accepted).

    Returns
    -------
    int
        The matching ``logging`` level constant (e.g. ``logging.INFO``).

    Raises
    ------
    ValueError
        If ``level`` is not one of the accepted names.
    """
    log_levels = {
        "CRITICAL": logging.CRITICAL,
        "ERROR": logging.ERROR,
        "WARNING": logging.WARNING,
        "INFO": logging.INFO,
        "DEBUG": logging.DEBUG,
        "NOTSET": logging.NOTSET,
    }
    # Normalise case so values such as LOG_LEVEL=debug from the environment work.
    # Non-string inputs are looked up unchanged and therefore rejected below.
    key = level.upper() if isinstance(level, str) else level
    if key not in log_levels:
        raise ValueError(f"The 'DebugLevel' option {level} is not valid.")
    return log_levels[key]
|
71
|
+
|
72
|
+
|
73
|
+
def initialise_logger(
    log_level: str | None = None,
):
    """Configure and return the package logger ``"atlas-ftag-tools"``.

    Parameters
    ----------
    log_level : str, optional
        Logging level: CRITICAL, ERROR, WARNING, INFO, DEBUG or NOTSET. If None, the
        ``LOG_LEVEL`` environment variable is used, defaulting to ``"INFO"``.

    Returns
    -------
    logging.Logger
        The ``"atlas-ftag-tools"`` logger with the level applied, a colourised stream
        handler attached, and propagation to ancestor loggers disabled.
    """
    retrieved_log_level = get_log_level(
        os.environ.get("LOG_LEVEL", "INFO") if log_level is None else log_level
    )

    tools_logger = logging.getLogger("atlas-ftag-tools")
    tools_logger.setLevel(retrieved_log_level)

    # logging.getLogger returns the same object on every call, so attach the stream
    # handler only once; re-initialising would otherwise duplicate every message.
    # The handler is tagged with a name so later calls can find and update it.
    handler_name = "atlas-ftag-tools-stream"
    ch_handler = next(
        (handler for handler in tools_logger.handlers if handler.get_name() == handler_name),
        None,
    )
    if ch_handler is None:
        ch_handler = logging.StreamHandler()
        ch_handler.set_name(handler_name)
        ch_handler.setFormatter(CustomFormatter())
        tools_logger.addHandler(ch_handler)
    ch_handler.setLevel(retrieved_log_level)

    tools_logger.propagate = False
    return tools_logger
|
103
|
+
|
104
|
+
|
105
|
+
def set_log_level(
    tools_logger,
    log_level: str,
):
    """Change the level of a logger and of every handler attached to it.

    Parameters
    ----------
    tools_logger : logger
        Logger object whose own level and whose handlers' levels are updated.
    log_level : str
        Level name accepted by ``get_log_level`` (CRITICAL, ERROR, WARNING, INFO,
        DEBUG, NOTSET).

    Raises
    ------
    ValueError
        If ``log_level`` is not a valid level name; nothing is changed in that case.
    """
    numeric_level = get_log_level(log_level)
    # The logger comes first, then its handlers in attachment order
    for target in (tools_logger, *tools_logger.handlers):
        target.setLevel(numeric_level)
|
121
|
+
|
122
|
+
|
123
|
+
# Module-level package logger, configured once when this module is imported.
# ftag.utils re-exports it and other modules (e.g. fraction_optimization) import it.
logger = initialise_logger()
|