atlas-ftag-tools 0.2.9__py3-none-any.whl → 0.2.10__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.
- {atlas_ftag_tools-0.2.9.dist-info → atlas_ftag_tools-0.2.10.dist-info}/METADATA +4 -3
- {atlas_ftag_tools-0.2.9.dist-info → atlas_ftag_tools-0.2.10.dist-info}/RECORD +13 -11
- {atlas_ftag_tools-0.2.9.dist-info → atlas_ftag_tools-0.2.10.dist-info}/WHEEL +1 -1
- {atlas_ftag_tools-0.2.9.dist-info → atlas_ftag_tools-0.2.10.dist-info}/entry_points.txt +1 -1
- ftag/__init__.py +6 -5
- ftag/flavours.yaml +16 -0
- ftag/fraction_optimization.py +184 -0
- ftag/labels.py +10 -2
- ftag/utils/__init__.py +24 -0
- ftag/utils/logging.py +123 -0
- ftag/utils/metrics.py +431 -0
- ftag/{wps/working_points.py → working_points.py} +1 -1
- ftag/wps/__init__.py +0 -0
- ftag/wps/discriminant.py +0 -84
- {atlas_ftag_tools-0.2.9.dist-info → atlas_ftag_tools-0.2.10.dist-info}/top_level.txt +0 -0

{atlas_ftag_tools-0.2.9.dist-info → atlas_ftag_tools-0.2.10.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
-Metadata-Version: 2.
+Metadata-Version: 2.4
 Name: atlas-ftag-tools
-Version: 0.2.9
+Version: 0.2.10
 Summary: ATLAS Flavour Tagging Tools
 Author: Sam Van Stroud, Philipp Gadow
 License: MIT
@@ -8,8 +8,9 @@ Project-URL: Homepage, https://github.com/umami-hep/atlas-ftag-tools/
 Requires-Python: <3.12,>=3.8
 Description-Content-Type: text/markdown
 Requires-Dist: h5py>=3.0
-Requires-Dist: numpy
+Requires-Dist: numpy>=2.2.3
 Requires-Dist: PyYAML>=5.1
+Requires-Dist: scipy>=1.15.2
 Provides-Extra: dev
 Requires-Dist: ruff==0.6.2; extra == "dev"
 Requires-Dist: mypy==1.11.2; extra == "dev"

{atlas_ftag_tools-0.2.9.dist-info → atlas_ftag_tools-0.2.10.dist-info}/RECORD
CHANGED

@@ -1,28 +1,30 @@
-ftag/__init__.py,sha256=
+ftag/__init__.py,sha256=v9emuK48Hhd-_TCiirfCNMsZSzk52frz1zEOgk9PViQ,787
 ftag/cli_utils.py,sha256=w3TtQmUHSyAKChS3ewvOtcSDAUJAZGIIomaNi8f446U,298
 ftag/cuts.py,sha256=9_ooLZHaO3SnIQBNxwbaPZn-qptGdKnB27FdKQGTiTY,2933
 ftag/flavours.py,sha256=ShH4M2UjQZpZ_NlCctTm2q1tJbzYxjmGteioQ2GcqEU,114
-ftag/flavours.yaml,sha256=
+ftag/flavours.yaml,sha256=5Lo9KWe-2KzmGMbc7o_X9gzwUyTl0Q5uVHYExduZ6T4,9502
+ftag/fraction_optimization.py,sha256=IlMEJe5fD0soX40f-LO4dYAYld2gMqgZRuBLctoPn9A,5566
 ftag/git_check.py,sha256=Y-XqM80CVXZ5ZKrDdZcYOJt3X64uU6W3OP6Z0D7AZU0,1663
 ftag/labeller.py,sha256=IXUgU9UBir39PxVWRKs5r5fqI66Tv0x7nJD3-RYpbrg,2780
-ftag/labels.py,sha256=
+ftag/labels.py,sha256=2nmcmrZD8mWQPxJsGiOgcLDhSVgWfS_cEzqsBV-Qy8o,4198
 ftag/mock.py,sha256=P2D7nNKAz2jRBbmfpHTDj9sBVU9r7HGd0rpWZOJYZ90,5980
 ftag/region.py,sha256=ANv0dGI2W6NJqD9fp7EfqAUReH4FOjc1gwl_Qn8llcM,360
 ftag/sample.py,sha256=3N0FrRcu9l1sX8ohuGOHuMYGD0See6gMO4--7NzR2tE,2538
 ftag/track_selector.py,sha256=fJNk_kIBQriBqV4CPT_3ReJbOUnavDDzO-u3EQlRuyk,2654
 ftag/transform.py,sha256=uEGGJSnqoKOzLYQv650XdK_kDNw4Aw-5dc60z9Dp_y0,3963
 ftag/vds.py,sha256=nRViQZQIORB95nC7NZsW3KsSoGkLzEdOsuCViH5h8-U,3296
+ftag/working_points.py,sha256=RJws2jPMEDQDspCbXUZBifS1CCBmlMJ5ax0eMyDzCRA,15949
 ftag/hdf5/__init__.py,sha256=LFDNxVOCp58SvLHwQhdT68Q-KBMS_i6jBrbXoRpHzbM,354
 ftag/hdf5/h5move.py,sha256=oYpRu0IDCIJIQ2ML52HBAdoyDxmKkHTeM9JdbPEgKfI,947
 ftag/hdf5/h5reader.py,sha256=i31pDAqmOSaxdeRhc4iSBlld8xJ0pmp4rNd7CugNzw0,13706
 ftag/hdf5/h5split.py,sha256=4Wy6Xc3J58MdD9aBaSZHf5ZcVFnJSkWsm42R5Pgo-R4,2448
 ftag/hdf5/h5utils.py,sha256=-4zKTMtNCrDZr_9Ww7uzfsB7M7muBKpmm_1IkKJnHOI,3222
 ftag/hdf5/h5writer.py,sha256=9FkClV__UbBqmFsq_h2jwiZnbWVm8QFRL_4mDZZBbTs,5316
-ftag/
-ftag/
-ftag/
-atlas_ftag_tools-0.2.
-atlas_ftag_tools-0.2.
-atlas_ftag_tools-0.2.
-atlas_ftag_tools-0.2.
-atlas_ftag_tools-0.2.
+ftag/utils/__init__.py,sha256=C0PgaA6Nk5WVpFqKhBhrHgj2mwsKJbSxoO6Cl67RsaI,544
+ftag/utils/logging.py,sha256=54NaQiC9Bh4vSznSqzoPfR-7tj1PXfmoH7yKgv_ZHZk,3192
+ftag/utils/metrics.py,sha256=zQI4nPeRDSyzqKpdOPmu0GU560xSWoW1wgL13rrja-I,12664
+atlas_ftag_tools-0.2.10.dist-info/METADATA,sha256=VUhrtQML6_bUKlmZNFlUXxTTt5YBzNYupTrdlaF5IAw,5190
+atlas_ftag_tools-0.2.10.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
+atlas_ftag_tools-0.2.10.dist-info/entry_points.txt,sha256=b46bVP_O8Mg6aSdPmyjGgVkaXSdyXZMeKAsofh2IDeA,133
+atlas_ftag_tools-0.2.10.dist-info/top_level.txt,sha256=qiYQuKcAvMim-31FwkT3MTQu7WQm0s58tPAia5KKWqs,5
+atlas_ftag_tools-0.2.10.dist-info/RECORD,,

ftag/__init__.py
CHANGED

@@ -2,18 +2,18 @@
 
 from __future__ import annotations
 
-__version__ = "v0.2.9"
+__version__ = "v0.2.10"
 
-from ftag import hdf5
+from ftag import hdf5, utils
 from ftag.cuts import Cuts
 from ftag.flavours import Flavours
+from ftag.fraction_optimization import calculate_best_fraction_values
 from ftag.labeller import Labeller
 from ftag.labels import Label, LabelContainer
 from ftag.mock import get_mock_file
 from ftag.sample import Sample
 from ftag.transform import Transform
-from ftag.
-from ftag.wps.working_points import get_working_points
+from ftag.working_points import get_working_points
 
 __all__ = [
     "Cuts",
@@ -24,8 +24,9 @@ __all__ = [
     "Sample",
     "Transform",
     "__version__",
-    "
+    "calculate_best_fraction_values",
     "get_mock_file",
     "get_working_points",
     "hdf5",
+    "utils",
 ]

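For downstream code, the practical effect of this hunk is a set of new import paths: `get_discriminant` now lives in `ftag.utils`, the working-point helpers moved from `ftag.wps.working_points` to `ftag.working_points`, and the new fraction optimiser is re-exported at the package root. A minimal sketch of the 0.2.10 import paths, taken from the hunks in this diff:

import ftag  # noqa: F401  # package root re-exports shown in the hunk above

from ftag import calculate_best_fraction_values, utils
from ftag.utils import get_discriminant            # previously ftag/wps/discriminant.py
from ftag.working_points import get_working_points  # previously ftag/wps/working_points.py
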
ftag/flavours.yaml
CHANGED

@@ -332,3 +332,19 @@
   cuts: ["iffClass == 0"]
   colour: tab:gray
   category: isolation
+# Trigger-Xbb tagging
+- name: dRMatchedHbb
+  label: $H \rightarrow b\bar{b}$
+  cuts: ["HadronConeExclExtendedTruthLabelID == 55", "n_truth_higgs > 0", "n_truth_top == 0"]
+  colour: tab:blue
+  category: trigger-xbb
+- name: dRMatchedTop
+  label: Inclusive Top
+  cuts: ["n_truth_higgs == 0", "n_truth_top > 0"]
+  colour: "#A300A3"
+  category: trigger-xbb
+- name: dRMatchedQCD
+  label: QCD
+  cuts: ["n_truth_higgs == 0", "n_truth_top == 0"]
+  colour: "#38761D"
+  category: trigger-xbb

ftag/fraction_optimization.py
ADDED

@@ -0,0 +1,184 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import numpy as np
+from scipy.optimize import minimize
+
+from ftag import Flavours
+from ftag.utils import calculate_rejection, get_discriminant, logger
+
+if TYPE_CHECKING:  # pragma: no cover
+    from ftag.labels import Label, LabelContainer
+
+
+def convert_dict(
+    fraction_values: dict | np.ndarray,
+    backgrounds: LabelContainer,
+) -> np.ndarray | dict:
+    if isinstance(fraction_values, dict):
+        return np.array([fraction_values[iter_bkg.frac_str] for iter_bkg in backgrounds])
+
+    if isinstance(fraction_values, np.ndarray):
+        fraction_values = [
+            float(frac_value / np.sum(fraction_values)) for frac_value in fraction_values
+        ]
+
+        return dict(zip([iter_bkg.frac_str for iter_bkg in backgrounds], fraction_values))
+
+    raise ValueError(
+        f"Only input of type `dict` or `np.ndarray` are accepted! You gave {type(fraction_values)}"
+    )
+
+
+def get_bkg_norm_dict(
+    jets: np.ndarray,
+    tagger: str,
+    signal: Label,
+    flavours: LabelContainer,
+    working_point: float,
+) -> dict:
+    # Init a dict for the bkg rejection norm values
+    bkg_rej_norm = {}
+
+    # Get the background classes
+    backgrounds = flavours.backgrounds(signal)
+
+    # Define a bool array if the jet is signal
+    is_signal = signal.cuts(jets).idx
+
+    # Loop over backgrounds
+    for bkg in backgrounds:
+        # Get the fraction value dict to maximize rejection for given class
+        frac_dict_bkg = {
+            iter_bkg.frac_str: 1 - (0.01 * len(backgrounds)) if iter_bkg == bkg else 0.01
+            for iter_bkg in backgrounds
+        }
+
+        # Calculate the disc value using the new fraction dict
+        disc = get_discriminant(
+            jets=jets,
+            tagger=tagger,
+            signal=signal,
+            flavours=flavours,
+            fraction_values=frac_dict_bkg,
+        )
+
+        # Calculate the discriminant
+        bkg_rej_norm[bkg.name] = calculate_rejection(
+            sig_disc=disc[is_signal],
+            bkg_disc=disc[bkg.cuts(jets).idx],
+            target_eff=working_point,
+        )
+
+    return bkg_rej_norm
+
+
+def calculate_rejection_sum(
+    fraction_dict: dict | np.ndarray,
+    jets: np.ndarray,
+    tagger: str,
+    signal: Label,
+    flavours: LabelContainer,
+    working_point: float,
+    bkg_norm_dict: dict,
+    rejection_weights: dict,
+) -> float:
+    # Get the background classes
+    backgrounds = flavours.backgrounds(signal)
+
+    # Define a bool array if the jet is signal
+    is_signal = signal.cuts(jets).idx
+
+    # Check that the fraction dict is a dict
+    if isinstance(fraction_dict, np.ndarray):
+        fraction_dict = convert_dict(
+            fraction_values=fraction_dict,
+            backgrounds=backgrounds,
+        )
+
+    # Calculate discriminant
+    disc = get_discriminant(
+        jets=jets,
+        tagger=tagger,
+        signal=signal,
+        flavours=flavours,
+        fraction_values=fraction_dict,
+    )
+
+    # Init a dict to which the bkg rejs are added
+    sum_bkg_rej = 0
+
+    # Loop over the backgrounds and calculate the rejections
+    for iter_bkg in backgrounds:
+        sum_bkg_rej += (
+            calculate_rejection(
+                sig_disc=disc[is_signal],
+                bkg_disc=disc[iter_bkg.cuts(jets).idx],
+                target_eff=working_point,
+            )
+            / bkg_norm_dict[iter_bkg.name]
+        ) * rejection_weights[iter_bkg.name]
+
+    # Return the negative sum to enable minimizer
+    return -1 * sum_bkg_rej
+
+
+def calculate_best_fraction_values(
+    jets: np.ndarray,
+    tagger: str,
+    signal: Label,
+    flavours: LabelContainer,
+    working_point: float,
+    rejection_weights: dict | None = None,
+    optimizer_method: str = "Powell",
+) -> dict:
+    logger.debug("Calculating best fraction values.")
+    logger.debug(f"tagger: {tagger}")
+    logger.debug(f"signal: {signal}")
+    logger.debug(f"flavours: {flavours}")
+    logger.debug(f"working_point: {working_point}")
+    logger.debug(f"rejection_weights: {rejection_weights}")
+    logger.debug(f"optimizer_method: {optimizer_method}")
+
+    # Ensure Label instance
+    if isinstance(signal, str):
+        signal = Flavours[signal]
+
+    # Get the background classes
+    backgrounds = flavours.backgrounds(signal)
+
+    # Define a default fraction dict
+    def_frac_dict = {iter_bkg.frac_str: 1 / len(backgrounds) for iter_bkg in backgrounds}
+
+    # Define rejection weights if not set
+    if rejection_weights is None:
+        rejection_weights = {iter_bkg.name: 1 for iter_bkg in backgrounds}
+
+    # Get the normalisation for all bkg rejections
+    bkg_norm_dict = get_bkg_norm_dict(
+        jets=jets,
+        tagger=tagger,
+        signal=signal,
+        flavours=flavours,
+        working_point=working_point,
+    )
+
+    # Get the best fraction values combination
+    result = minimize(
+        fun=calculate_rejection_sum,
+        x0=convert_dict(fraction_values=def_frac_dict, backgrounds=backgrounds),
+        method=optimizer_method,
+        bounds=[(0, 1)] * len(backgrounds),
+        args=(jets, tagger, signal, flavours, working_point, bkg_norm_dict, rejection_weights),
+    )
+
+    # Get the final fraction dict
+    final_frac_dict = convert_dict(fraction_values=result.x, backgrounds=backgrounds)
+
+    logger.info(f"Minimization Success: {result.success}")
+    logger.info("The following best fraction values were found:")
+    for frac_str, frac_value in final_frac_dict.items():
+        logger.info(f"{frac_str}: {round(frac_value, ndigits=3)}")
+
+    return final_frac_dict

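A rough usage sketch of the new optimiser, built only from the signatures shown above. Everything not visible in this diff is an assumption for illustration: the tagger name `MyTagger`, the probability fields `MyTagger_pb/pc/pu`, and the truth-label field and values are hypothetical, the toy scores are random, and the restriction to a b/c/light container is just to keep the required fields small, so the resulting fractions carry no physics meaning.

import numpy as np

from ftag import Flavours, calculate_best_fraction_values
from ftag.labels import LabelContainer

rng = np.random.default_rng(42)
n_jets = 20_000

# Hypothetical structured jet array with truth labels and per-jet tagger probabilities
jets = np.zeros(
    n_jets,
    dtype=[
        ("HadronConeExclTruthLabelID", "i4"),
        ("MyTagger_pb", "f4"),
        ("MyTagger_pc", "f4"),
        ("MyTagger_pu", "f4"),
    ],
)
jets["HadronConeExclTruthLabelID"] = rng.choice([0, 4, 5], size=n_jets)  # light, c, b
probs = rng.dirichlet([1.0, 1.0, 1.0], size=n_jets)  # toy (pb, pc, pu) scores per jet
jets["MyTagger_pb"], jets["MyTagger_pc"], jets["MyTagger_pu"] = probs.T

# Restrict the container to b-, c- and light-jets so only their fraction values are optimised
flavours = LabelContainer.from_list([Flavours["bjets"], Flavours["cjets"], Flavours["ujets"]])

best_fracs = calculate_best_fraction_values(
    jets=jets,
    tagger="MyTagger",
    signal=Flavours["bjets"],
    flavours=flavours,
    working_point=0.70,
)
print(best_fracs)  # dict keyed by the backgrounds' frac_str values, e.g. {"fc": ..., "fu": ...}
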
ftag/labels.py
CHANGED

@@ -62,6 +62,9 @@ class LabelContainer:
         except KeyError as e:
             raise KeyError(f"Label '{key}' not found") from e
 
+    def __len__(self) -> int:
+        return len(self.labels.keys())
+
     def __getattr__(self, name) -> Label:
         return self[name]
 
@@ -120,8 +123,13 @@ class LabelContainer:
     def from_list(cls, labels: list[Label]) -> LabelContainer:
         return cls({f.name: f for f in labels})
 
-    def backgrounds(self,
-        bkg = [f for f in self if f.category ==
+    def backgrounds(self, signal: Label, only_signals: bool = True) -> LabelContainer:
+        bkg = [f for f in self if f.category == signal.category and f != signal]
         if not only_signals:
             bkg = [f for f in bkg if f.name not in {"ujets", "qcd"}]
+        if len(bkg) == 0:
+            raise TypeError(
+                "No background flavour could be found in the flavours for signal "
+                f"flavour {signal.name}"
+            )
         return LabelContainer.from_list(bkg)

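For reference, a small sketch of how the updated container API reads after this change; `Flavours` is the default container shipped with the package, and the `"bjets"` label name is an assumption:

from ftag import Flavours

# backgrounds() now takes the signal Label and returns the same-category backgrounds
bkg = Flavours.backgrounds(Flavours["bjets"])
print(len(bkg))  # LabelContainer now supports len()
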
ftag/utils/__init__.py
ADDED

@@ -0,0 +1,24 @@
+from __future__ import annotations
+
+from ftag.utils.logging import logger, set_log_level
+from ftag.utils.metrics import (
+    calculate_efficiency,
+    calculate_efficiency_error,
+    calculate_rejection,
+    calculate_rejection_error,
+    get_discriminant,
+    save_divide,
+    weighted_percentile,
+)
+
+__all__ = [
+    "calculate_efficiency",
+    "calculate_efficiency_error",
+    "calculate_rejection",
+    "calculate_rejection_error",
+    "get_discriminant",
+    "logger",
+    "save_divide",
+    "set_log_level",
+    "weighted_percentile",
+]

ftag/utils/logging.py
ADDED

@@ -0,0 +1,123 @@
+"""Configuration for logger of atlas-ftag-tools."""
+
+from __future__ import annotations
+
+import logging
+import os
+from typing import ClassVar
+
+
+class CustomFormatter(logging.Formatter):
+    """
+    Logging Formatter to add colours and count warning / errors using implementation
+    from
+    https://stackoverflow.com/questions/384076/how-can-i-color-python-logging-output.
+    """
+
+    grey = "\x1b[38;21m"
+    yellow = "\x1b[33;21m"
+    green = "\x1b[32;21m"
+    red = "\x1b[31;21m"
+    bold_red = "\x1b[31;1m"
+    reset = "\x1b[0m"
+    debugformat = "%(asctime)s - %(levelname)s:%(name)s: %(message)s (%(filename)s:%(lineno)d)"
+    date_format = "%(levelname)s:%(name)s: %(message)s"
+
+    formats: ClassVar = {
+        logging.DEBUG: grey + debugformat + reset,
+        logging.INFO: green + date_format + reset,
+        logging.WARNING: yellow + date_format + reset,
+        logging.ERROR: red + debugformat + reset,
+        logging.CRITICAL: bold_red + debugformat + reset,
+    }
+
+    def format(self, record):
+        log_fmt = self.formats.get(record.levelno)
+        formatter = logging.Formatter(log_fmt)
+        return formatter.format(record)
+
+
+def get_log_level(
+    level: str,
+):
+    """Get logging levels with string key.
+
+    Parameters
+    ----------
+    level : str
+        Log level as string.
+
+    Returns
+    -------
+    logging level
+        logging object with log level info
+
+    Raises
+    ------
+    ValueError
+        If non-valid option is given
+    """
+    log_levels = {
+        "CRITICAL": logging.CRITICAL,
+        "ERROR": logging.ERROR,
+        "WARNING": logging.WARNING,
+        "INFO": logging.INFO,
+        "DEBUG": logging.DEBUG,
+        "NOTSET": logging.NOTSET,
+    }
+    if level not in log_levels:
+        raise ValueError(f"The 'DebugLevel' option {level} is not valid.")
+    return log_levels[level]
+
+
+def initialise_logger(
+    log_level: str | None = None,
+):
+    """Initialise.
+
+    Parameters
+    ----------
+    log_level : str, optional
+        Logging level defining the verbose level. Accepted values are:
+        CRITICAL, ERROR, WARNING, INFO, DEBUG, NOTSET, by default None
+        If the log_level is not set, it will be set to info
+
+    Returns
+    -------
+    logger
+        logger object with new level set
+    """
+    retrieved_log_level = get_log_level(
+        os.environ.get("LOG_LEVEL", "INFO") if log_level is None else log_level
+    )
+
+    tools_logger = logging.getLogger("atlas-ftag-tools")
+    tools_logger.setLevel(retrieved_log_level)
+    ch_handler = logging.StreamHandler()
+    ch_handler.setLevel(retrieved_log_level)
+    ch_handler.setFormatter(CustomFormatter())
+
+    tools_logger.addHandler(ch_handler)
+    tools_logger.propagate = False
+    return tools_logger
+
+
+def set_log_level(
+    tools_logger,
+    log_level: str,
+):
+    """Setting log level.
+
+    Parameters
+    ----------
+    tools_logger : logger
+        logger object
+    log_level : str
+        Logging level corresponding CRITICAL, ERROR, WARNING, INFO, DEBUG, NOTSET
+    """
+    tools_logger.setLevel(get_log_level(log_level))
+    for handler in tools_logger.handlers:
+        handler.setLevel(get_log_level(log_level))
+
+
+logger = initialise_logger()

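The module creates the package-level `logger` at import time; the exported `set_log_level` helper can then raise or lower its verbosity. A minimal sketch using only the names exported in the diff above:

from ftag.utils import logger, set_log_level

set_log_level(logger, "DEBUG")  # accepts CRITICAL, ERROR, WARNING, INFO, DEBUG, NOTSET
logger.debug("debug messages are now printed with the coloured formatter")
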
ftag/utils/metrics.py
ADDED

@@ -0,0 +1,431 @@
+"""Tools for metrics module."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import numpy as np
+from scipy.ndimage import gaussian_filter1d
+
+from ftag.utils import logger
+
+if TYPE_CHECKING:  # pragma: no cover
+    from ftag.labels import Label, LabelContainer
+
+
+def save_divide(
+    numerator: np.ndarray | float,
+    denominator: np.ndarray | float,
+    default: float = 1.0,
+):
+    """Save divide for denominator equal to 0.
+
+    Division using numpy divide function returning default value in cases where
+    denominator is 0.
+
+    Parameters
+    ----------
+    numerator: np.ndarray | float,
+        Numerator in the ratio calculation.
+    denominator: np.ndarray | float,
+        Denominator in the ratio calculation.
+    default: float
+        Default value which is returned if denominator is 0.
+
+    Returns
+    -------
+    ratio: np.ndarray | float
+        Result of the division
+    """
+    logger.debug("Calculating save division.")
+    logger.debug("numerator: %s", numerator)
+    logger.debug("denominator: %s", denominator)
+    logger.debug("default: %s", default)
+
+    if isinstance(numerator, (int, float, np.number)) and isinstance(
+        denominator, (int, float, np.number)
+    ):
+        output_shape = 1
+    else:
+        try:
+            output_shape = denominator.shape
+        except AttributeError:
+            output_shape = numerator.shape
+
+    ratio = np.divide(
+        numerator,
+        denominator,
+        out=np.ones(
+            output_shape,
+            dtype=float,
+        )
+        * default,
+        where=(denominator != 0),
+    )
+    if output_shape == 1:
+        return float(ratio)
+    return ratio
+
+
+def weighted_percentile(
+    arr: np.ndarray,
+    percentile: np.ndarray,
+    weights: np.ndarray = None,
+):
+    """Calculate weighted percentile.
+
+    Implementation according to https://stackoverflow.com/a/29677616/11509698
+    (https://en.wikipedia.org/wiki/Percentile#The_weighted_percentile_method)
+
+    Parameters
+    ----------
+    arr : np.ndarray
+        Data array
+    percentile : np.ndarray
+        Percentile array
+    weights : np.ndarray
+        Weights array, by default None
+
+    Returns
+    -------
+    np.ndarray
+        Weighted percentile array
+    """
+    logger.debug("Calculating weighted percentile.")
+    logger.debug("arr: %s", arr)
+    logger.debug("percentile: %s", percentile)
+    logger.debug("weights: %s", weights)
+
+    # Set weights to one if no weights are given
+    if weights is None:
+        weights = np.ones_like(arr)
+
+    # Set dtype to float64 if the weights are too large
+    dtype = np.float64 if np.sum(weights) > 1000000 else np.float32
+
+    # Get an array sorting and sort the array and the weights
+    ix = np.argsort(arr)
+    arr = arr[ix]
+    weights = weights[ix]
+
+    # Return the cumulative sum
+    cdf = np.cumsum(weights, dtype=dtype) - 0.5 * weights
+    cdf -= cdf[0]
+    cdf /= cdf[-1]
+
+    # Return the linear interpolation
+    return np.interp(percentile, cdf, arr)
+
+
+def calculate_efficiency(
+    sig_disc: np.ndarray,
+    bkg_disc: np.ndarray,
+    target_eff: float | list | np.ndarray,
+    return_cuts: bool = False,
+    sig_weights: np.ndarray = None,
+    bkg_weights: np.ndarray = None,
+):
+    """Calculate efficiency.
+
+    Parameters
+    ----------
+    sig_disc : np.ndarray
+        Signal discriminant
+    bkg_disc : np.ndarray
+        Background discriminant
+    target_eff : float or list or np.ndarray
+        Working point which is used for discriminant calculation
+    return_cuts : bool
+        Specifies if the cut values corresponding to the provided WPs are returned.
+        If target_eff is a float, only one cut value will be returned. If target_eff
+        is an array, target_eff is an array as well.
+    sig_weights : np.ndarray
+        Weights for signal events
+    bkg_weights : np.ndarray
+        Weights for background events
+
+    Returns
+    -------
+    eff : float or np.ndarray
+        Efficiency.
+        Return float if target_eff is a float, else np.ndarray
+    cutvalue : float or np.ndarray
+        Cutvalue if return_cuts is True.
+        Return float if target_eff is a float, else np.ndarray
+    """
+    logger.debug("Calculating efficiency.")
+    logger.debug("sig_disc: %s", sig_disc)
+    logger.debug("bkg_disc: %s", bkg_disc)
+    logger.debug("target_eff: %s", target_eff)
+    logger.debug("return_cuts: %s", return_cuts)
+    logger.debug("sig_weights: %s", sig_weights)
+    logger.debug("bkg_weights: %s", bkg_weights)
+
+    # float | np.ndarray for both target_eff and the returned values
+    return_float = False
+    if isinstance(target_eff, float):
+        return_float = True
+
+    # Flatten the target efficiencies
+    target_eff = np.asarray([target_eff]).flatten()
+
+    # Get the cutvalue for the given target efficiency
+    cutvalue = weighted_percentile(arr=sig_disc, percentile=1.0 - target_eff, weights=sig_weights)
+
+    # Sort the cutvalues to get the correct order
+    sorted_args = np.argsort(1 - target_eff)
+
+    # Get the histogram for the backgrounds
+    hist, _ = np.histogram(bkg_disc, (-np.inf, *cutvalue[sorted_args], np.inf), weights=bkg_weights)
+
+    # Calculate the efficiencies for the calculated cut values
+    eff = hist[::-1].cumsum()[-2::-1] / hist.sum()
+    eff = eff[sorted_args]
+
+    # Ensure that a float is returned if float was given
+    if return_float:
+        eff = eff[0]
+        cutvalue = cutvalue[0]
+
+    # Also return the cuts if wanted
+    if return_cuts:
+        return eff, cutvalue
+
+    return eff
+
+
+def calculate_rejection(
+    sig_disc: np.ndarray,
+    bkg_disc: np.ndarray,
+    target_eff,
+    return_cuts: bool = False,
+    sig_weights: np.ndarray = None,
+    bkg_weights: np.ndarray = None,
+    smooth: bool = False,
+):
+    """Calculate rejection.
+
+    Parameters
+    ----------
+    sig_disc : np.ndarray
+        Signal discriminant
+    bkg_disc : np.ndarray
+        Background discriminant
+    target_eff : float or list
+        Working point which is used for discriminant calculation
+    return_cuts : bool
+        Specifies if the cut values corresponding to the provided WPs are returned.
+        If target_eff is a float, only one cut value will be returned. If target_eff
+        is an array, target_eff is an array as well.
+    sig_weights : np.ndarray
+        Weights for signal events, by default None
+    bkg_weights : np.ndarray
+        Weights for background events, by default None
+
+    Returns
+    -------
+    rej : float or np.ndarray
+        Rejection.
+        If target_eff is a float, a float is returned if it's a list a np.ndarray
+    cut_value : float or np.ndarray
+        Cutvalue if return_cuts is True.
+        If target_eff is a float, a float is returned if it's a list a np.ndarray
+    """
+    logger.debug("Calculating rejection.")
+    logger.debug("sig_disc: %s", sig_disc)
+    logger.debug("bkg_disc: %s", bkg_disc)
+    logger.debug("target_eff: %s", target_eff)
+    logger.debug("return_cuts: %s", return_cuts)
+    logger.debug("sig_weights: %s", sig_weights)
+    logger.debug("bkg_weights: %s", bkg_weights)
+    logger.debug("smooth: %s", smooth)
+
+    # Calculate the efficiency
+    eff = calculate_efficiency(
+        sig_disc=sig_disc,
+        bkg_disc=bkg_disc,
+        target_eff=target_eff,
+        return_cuts=return_cuts,
+        sig_weights=sig_weights,
+        bkg_weights=bkg_weights,
+    )
+
+    # Invert the efficiency to get a rejection
+    rej = save_divide(1, eff[0] if return_cuts else eff, np.inf)
+
+    # Smooth out the rejection if wanted
+    if smooth:
+        rej = gaussian_filter1d(rej, sigma=1, radius=2, mode="nearest")
+
+    # Return also the cut values if wanted
+    if return_cuts:
+        return rej, eff[1]
+
+    return rej
+
+
+def calculate_efficiency_error(
+    arr: np.ndarray,
+    n_counts: int,
+    suppress_zero_divison_error: bool = False,
+    norm: bool = False,
+) -> np.ndarray:
+    """Calculate statistical efficiency uncertainty.
+
+    Parameters
+    ----------
+    arr : numpy.array
+        Efficiency values
+    n_counts : int
+        Number of used statistics to calculate efficiency
+    suppress_zero_divison_error : bool
+        Not raising Error for zero division
+    norm : bool, optional
+        If True, normed (relative) error is being calculated, by default False
+
+    Returns
+    -------
+    numpy.array
+        Efficiency uncertainties
+
+    Raises
+    ------
+    ValueError
+        If n_counts <=0
+
+    Notes
+    -----
+    This method uses binomial errors as described in section 2.2 of
+    https://inspirehep.net/files/57287ac8e45a976ab423f3dd456af694
+    """
+    logger.debug("Calculating efficiency error.")
+    logger.debug("arr: %s", arr)
+    logger.debug("n_counts: %i", n_counts)
+    logger.debug("suppress_zero_divison_error: %s", suppress_zero_divison_error)
+    logger.debug("norm: %s", norm)
+    if np.any(n_counts <= 0) and not suppress_zero_divison_error:
+        raise ValueError(f"You passed as argument `N` {n_counts} but it has to be larger 0.")
+    if norm:
+        return np.sqrt(arr * (1 - arr) / n_counts) / arr
+    return np.sqrt(arr * (1 - arr) / n_counts)
+
+
+def calculate_rejection_error(
+    arr: np.ndarray,
+    n_counts: int,
+    norm: bool = False,
+) -> np.ndarray:
+    """Calculate the rejection uncertainties.
+
+    Parameters
+    ----------
+    arr : numpy.array
+        Rejection values
+    n_counts : int
+        Number of used statistics to calculate rejection
+    norm : bool, optional
+        If True, normed (relative) error is being calculated, by default False
+
+    Returns
+    -------
+    numpy.array
+        Rejection uncertainties
+
+    Raises
+    ------
+    ValueError
+        If n_counts <=0
+    ValueError
+        If any rejection value is 0
+
+    Notes
+    -----
+    Special case of `eff_err()`
+    """
+    logger.debug("Calculating rejection error.")
+    logger.debug("arr: %s", arr)
+    logger.debug("n_counts: %i", n_counts)
+    logger.debug("norm: %s", norm)
+    if np.any(n_counts <= 0):
+        raise ValueError(f"You passed as argument `n_counts` {n_counts} but it has to be larger 0.")
+    if np.any(arr == 0):
+        raise ValueError("One rejection value is 0, cannot calculate error.")
+    if norm:
+        return np.power(arr, 2) * calculate_efficiency_error(1 / arr, n_counts) / arr
+    return np.power(arr, 2) * calculate_efficiency_error(1 / arr, n_counts)
+
+
+def get_discriminant(
+    jets: np.ndarray,
+    tagger: str,
+    signal: Label,
+    flavours: LabelContainer,
+    fraction_values: dict[str, float],
+    epsilon: float = 1e-10,
+) -> np.ndarray:
+    """Calculate the tagging discriminant for a given tagger.
+
+    Calculated as the logarithm of the ratio of a specified signal probability
+    to a weighted sum ofbackground probabilities.
+
+    Parameters
+    ----------
+    jets : np.ndarray
+        Structured array of jets containing tagger outputs
+    tagger : str
+        Name of the tagger
+    signal : Label
+        Signal flavour (bjets/cjets or hbb/hcc)
+    fraction_values : dict
+        Dict with the fraction values for the background classes for the given tagger
+    epsilon : float, optional
+        Small number to avoid division by zero, by default 1e-10
+
+    Returns
+    -------
+    np.ndarray
+        Array of discriminant values.
+
+    Raises
+    ------
+    ValueError
+        If the signal flavour is not recognised.
+    """
+    # Init the denominator
+    denominator = 0.0
+
+    # Loop over background flavours
+    for flav in flavours:
+        # Skip signal flavour for denominator
+        if flav == signal:
+            continue
+
+        # Get the probability name of the tagger/flavour combo + fraction value
+        prob_name = f"{tagger}_{flav.px}"
+        fraction_value = fraction_values[flav.frac_str]
+
+        # If fraction_value for the given flavour is zero, skip it
+        if fraction_value == 0:
+            continue
+
+        # Check that the probability value for the flavour is available
+        if fraction_value > 0 and prob_name not in jets.dtype.names:
+            raise ValueError(
+                f"Nonzero fraction value for {flav.name}, but '{prob_name}' "
+                "not found in input array."
+            )
+
+        # Update denominator
+        denominator += jets[prob_name] * fraction_value if prob_name in jets.dtype.names else 0
+
+    # Calculate numerator
+    signal_field = f"{tagger}_{signal.px}"
+
+    # Check that the probability of the signal is available
+    if signal_field not in jets.dtype.names:
+        raise ValueError(
+            f"No signal probability value(s) found for tagger {tagger}. "
+            f"Missing variable: {signal_field}"
+        )
+
+    return np.log((jets[signal_field] + epsilon) / (denominator + epsilon))

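A short sketch of the new metrics helpers on toy discriminant arrays; the Gaussian toy scores and the 70% working point are arbitrary choices for illustration, not package defaults:

import numpy as np

from ftag.utils import calculate_efficiency, calculate_rejection, calculate_rejection_error

rng = np.random.default_rng(0)
sig_disc = rng.normal(2.0, 1.0, 50_000)  # toy signal discriminant values
bkg_disc = rng.normal(0.0, 1.0, 50_000)  # toy background discriminant values

# Background efficiency and cut value at a 70% signal-efficiency working point
bkg_eff, cut = calculate_efficiency(sig_disc, bkg_disc, target_eff=0.7, return_cuts=True)

# Corresponding background rejection (1 / background efficiency) and its binomial error
rej = calculate_rejection(sig_disc, bkg_disc, target_eff=0.7)
rej_err = calculate_rejection_error(np.atleast_1d(rej), n_counts=len(bkg_disc))

print(f"bkg eff {bkg_eff:.3f} at cut {cut:.2f} -> rejection {rej:.1f} +/- {rej_err[0]:.1f}")
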
ftag/{wps/working_points.py → working_points.py}
RENAMED

@@ -14,7 +14,7 @@ from ftag import Flavours
 from ftag.cli_utils import HelpFormatter
 from ftag.cuts import Cuts
 from ftag.hdf5 import H5Reader
-from ftag.
+from ftag.utils import get_discriminant
 
 if TYPE_CHECKING:  # pragma: no cover
     from collections.abc import Sequence

ftag/wps/__init__.py
DELETED
File without changes

ftag/wps/discriminant.py
DELETED

@@ -1,84 +0,0 @@
-from __future__ import annotations
-
-from typing import TYPE_CHECKING
-
-import numpy as np
-
-if TYPE_CHECKING:  # pragma: no cover
-    from ftag.labels import Label, LabelContainer
-
-
-def get_discriminant(
-    jets: np.ndarray,
-    tagger: str,
-    signal: Label,
-    flavours: LabelContainer,
-    fraction_values: dict[str, float],
-    epsilon: float = 1e-10,
-) -> np.ndarray:
-    """Calculate the tagging discriminant for a given tagger.
-
-    Calculated as the logarithm of the ratio of a specified signal probability
-    to a weighted sum ofbackground probabilities.
-
-    Parameters
-    ----------
-    jets : np.ndarray
-        Structured array of jets containing tagger outputs
-    tagger : str
-        Name of the tagger
-    signal : Label
-        Signal flavour (bjets/cjets or hbb/hcc)
-    fraction_values : dict
-        Dict with the fraction values for the background classes for the given tagger
-    epsilon : float, optional
-        Small number to avoid division by zero, by default 1e-10
-
-    Returns
-    -------
-    np.ndarray
-        Array of discriminant values.
-
-    Raises
-    ------
-    ValueError
-        If the signal flavour is not recognised.
-    """
-    # Init the denominator
-    denominator = 0.0
-
-    # Loop over background flavours
-    for flav in flavours:
-        # Skip signal flavour for denominator
-        if flav == signal:
-            continue
-
-        # Get the probability name of the tagger/flavour combo + fraction value
-        prob_name = f"{tagger}_{flav.px}"
-        fraction_value = fraction_values[flav.frac_str]
-
-        # If fraction_value for the given flavour is zero, skip it
-        if fraction_value == 0:
-            continue
-
-        # Check that the probability value for the flavour is available
-        if fraction_value > 0 and prob_name not in jets.dtype.names:
-            raise ValueError(
-                f"Nonzero fraction value for {flav.name}, but '{prob_name}' "
-                "not found in input array."
-            )
-
-        # Update denominator
-        denominator += jets[prob_name] * fraction_value if prob_name in jets.dtype.names else 0
-
-    # Calculate numerator
-    signal_field = f"{tagger}_{signal.px}"
-
-    # Check that the probability of the signal is available
-    if signal_field not in jets.dtype.names:
-        raise ValueError(
-            f"No signal probability value(s) found for tagger {tagger}. "
-            f"Missing variable: {signal_field}"
-        )
-
-    return np.log((jets[signal_field] + epsilon) / (denominator + epsilon))

|