atlas-ftag-tools 0.2.7__py3-none-any.whl → 0.2.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {atlas_ftag_tools-0.2.7.dist-info → atlas_ftag_tools-0.2.9.dist-info}/METADATA +11 -11
- {atlas_ftag_tools-0.2.7.dist-info → atlas_ftag_tools-0.2.9.dist-info}/RECORD +10 -10
- {atlas_ftag_tools-0.2.7.dist-info → atlas_ftag_tools-0.2.9.dist-info}/WHEEL +1 -1
- ftag/__init__.py +1 -1
- ftag/flavours.yaml +32 -5
- ftag/mock.py +58 -17
- ftag/wps/discriminant.py +45 -92
- ftag/wps/working_points.py +394 -163
- {atlas_ftag_tools-0.2.7.dist-info → atlas_ftag_tools-0.2.9.dist-info}/entry_points.txt +0 -0
- {atlas_ftag_tools-0.2.7.dist-info → atlas_ftag_tools-0.2.9.dist-info}/top_level.txt +0 -0
@@ -1,23 +1,23 @@
|
|
1
|
-
Metadata-Version: 2.
|
1
|
+
Metadata-Version: 2.2
|
2
2
|
Name: atlas-ftag-tools
|
3
|
-
Version: 0.2.7
|
3
|
+
Version: 0.2.9
|
4
4
|
Summary: ATLAS Flavour Tagging Tools
|
5
5
|
Author: Sam Van Stroud, Philipp Gadow
|
6
6
|
License: MIT
|
7
7
|
Project-URL: Homepage, https://github.com/umami-hep/atlas-ftag-tools/
|
8
8
|
Requires-Python: <3.12,>=3.8
|
9
9
|
Description-Content-Type: text/markdown
|
10
|
-
Requires-Dist: h5py
|
10
|
+
Requires-Dist: h5py>=3.0
|
11
11
|
Requires-Dist: numpy
|
12
|
-
Requires-Dist: PyYAML
|
12
|
+
Requires-Dist: PyYAML>=5.1
|
13
13
|
Provides-Extra: dev
|
14
|
-
Requires-Dist: ruff
|
15
|
-
Requires-Dist: mypy
|
16
|
-
Requires-Dist: pre-commit
|
17
|
-
Requires-Dist: pytest
|
18
|
-
Requires-Dist: pytest-cov
|
19
|
-
Requires-Dist:
|
20
|
-
Requires-Dist: ipykernel
|
14
|
+
Requires-Dist: ruff==0.6.2; extra == "dev"
|
15
|
+
Requires-Dist: mypy==1.11.2; extra == "dev"
|
16
|
+
Requires-Dist: pre-commit==3.1.1; extra == "dev"
|
17
|
+
Requires-Dist: pytest==7.2.2; extra == "dev"
|
18
|
+
Requires-Dist: pytest-cov==4.0.0; extra == "dev"
|
19
|
+
Requires-Dist: pytest_notebook==0.10.0; extra == "dev"
|
20
|
+
Requires-Dist: ipykernel==6.21.3; extra == "dev"
|
21
21
|
|
22
22
|
[](https://github.com/psf/black)
|
23
23
|
[](https://badge.fury.io/py/atlas-ftag-tools)
|
@@ -1,12 +1,12 @@
|
|
1
|
-
ftag/__init__.py,sha256=
|
1
|
+
ftag/__init__.py,sha256=YRug5UslRbNoQACbEhdenDS6wXmsmeLjlz4JaKP6eHs,737
|
2
2
|
ftag/cli_utils.py,sha256=w3TtQmUHSyAKChS3ewvOtcSDAUJAZGIIomaNi8f446U,298
|
3
3
|
ftag/cuts.py,sha256=9_ooLZHaO3SnIQBNxwbaPZn-qptGdKnB27FdKQGTiTY,2933
|
4
4
|
ftag/flavours.py,sha256=ShH4M2UjQZpZ_NlCctTm2q1tJbzYxjmGteioQ2GcqEU,114
|
5
|
-
ftag/flavours.yaml,sha256=
|
5
|
+
ftag/flavours.yaml,sha256=87xBvLkMDkicuRMaXtxcao8gjEAgvlTbgjAzpvx4YFM,9021
|
6
6
|
ftag/git_check.py,sha256=Y-XqM80CVXZ5ZKrDdZcYOJt3X64uU6W3OP6Z0D7AZU0,1663
|
7
7
|
ftag/labeller.py,sha256=IXUgU9UBir39PxVWRKs5r5fqI66Tv0x7nJD3-RYpbrg,2780
|
8
8
|
ftag/labels.py,sha256=C7IylPTnc32dFXq8C2Ks2wuljYK3WaY2EsPLGrhtXy8,3932
|
9
|
-
ftag/mock.py,sha256=
|
9
|
+
ftag/mock.py,sha256=P2D7nNKAz2jRBbmfpHTDj9sBVU9r7HGd0rpWZOJYZ90,5980
|
10
10
|
ftag/region.py,sha256=ANv0dGI2W6NJqD9fp7EfqAUReH4FOjc1gwl_Qn8llcM,360
|
11
11
|
ftag/sample.py,sha256=3N0FrRcu9l1sX8ohuGOHuMYGD0See6gMO4--7NzR2tE,2538
|
12
12
|
ftag/track_selector.py,sha256=fJNk_kIBQriBqV4CPT_3ReJbOUnavDDzO-u3EQlRuyk,2654
|
@@ -19,10 +19,10 @@ ftag/hdf5/h5split.py,sha256=4Wy6Xc3J58MdD9aBaSZHf5ZcVFnJSkWsm42R5Pgo-R4,2448
|
|
19
19
|
ftag/hdf5/h5utils.py,sha256=-4zKTMtNCrDZr_9Ww7uzfsB7M7muBKpmm_1IkKJnHOI,3222
|
20
20
|
ftag/hdf5/h5writer.py,sha256=9FkClV__UbBqmFsq_h2jwiZnbWVm8QFRL_4mDZZBbTs,5316
|
21
21
|
ftag/wps/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
22
|
-
ftag/wps/discriminant.py,sha256=
|
23
|
-
ftag/wps/working_points.py,sha256=
|
24
|
-
atlas_ftag_tools-0.2.
|
25
|
-
atlas_ftag_tools-0.2.
|
26
|
-
atlas_ftag_tools-0.2.
|
27
|
-
atlas_ftag_tools-0.2.
|
28
|
-
atlas_ftag_tools-0.2.
|
22
|
+
ftag/wps/discriminant.py,sha256=GKa0zZlLREdm0mCYSbcWXITYe3VEn3PXOBQiPg5WvgM,2521
|
23
|
+
ftag/wps/working_points.py,sha256=jXyikB-bf73EaYFkngjE977-Ytvb9nDTqIdHxWW6WQQ,15960
|
24
|
+
atlas_ftag_tools-0.2.9.dist-info/METADATA,sha256=lXC-e0iHMDtvJH8h3i7PcCEKh4_CFz5vlqdGXKSEoV4,5153
|
25
|
+
atlas_ftag_tools-0.2.9.dist-info/WHEEL,sha256=beeZ86-EfXScwlR_HKu4SllMC9wUEj_8Z_4FJ3egI2w,91
|
26
|
+
atlas_ftag_tools-0.2.9.dist-info/entry_points.txt,sha256=LfVLsZHQolqbPnwPgtmc5IQTh527BKkN2v-IpXWTNHw,137
|
27
|
+
atlas_ftag_tools-0.2.9.dist-info/top_level.txt,sha256=qiYQuKcAvMim-31FwkT3MTQu7WQm0s58tPAia5KKWqs,5
|
28
|
+
atlas_ftag_tools-0.2.9.dist-info/RECORD,,
|
ftag/__init__.py
CHANGED
ftag/flavours.yaml
CHANGED
@@ -60,12 +60,24 @@
|
|
60
60
|
colour: tab:orange
|
61
61
|
category: single-btag-ghost
|
62
62
|
_px: pc
|
63
|
-
- name:
|
64
|
-
label:
|
65
|
-
cuts: ["HadronGhostTruthLabelID == 0"]
|
63
|
+
- name: ghostsjets
|
64
|
+
label: $s$-jets
|
65
|
+
cuts: ["HadronGhostTruthLabelID == 0", "PartonTruthLabelID == 3"]
|
66
|
+
colour: tab:red
|
67
|
+
category: single-btag-ghost
|
68
|
+
_px: ps
|
69
|
+
- name: ghostudjets
|
70
|
+
label: Light-quark-jets
|
71
|
+
cuts: ["HadronGhostTruthLabelID == 0", "PartonTruthLabelID <= 2"]
|
66
72
|
colour: tab:green
|
67
73
|
category: single-btag-ghost
|
68
|
-
_px:
|
74
|
+
_px: pud
|
75
|
+
- name: ghostgjets
|
76
|
+
label: Gluon-jets
|
77
|
+
cuts: ["HadronGhostTruthLabelID == 0", "PartonTruthLabelID == 21"]
|
78
|
+
colour: tab:gray
|
79
|
+
category: single-btag-ghost
|
80
|
+
_px: pg
|
69
81
|
- name: ghosttaujets
|
70
82
|
label: $\tau$-jets
|
71
83
|
cuts: ["HadronGhostTruthLabelID == 15"]
|
@@ -119,6 +131,21 @@
|
|
119
131
|
cuts: ["R10TruthLabel_R22v1 == 10", "GhostBHadronsFinalCount == 0", "GhostCHadronsFinalCount == 0"]
|
120
132
|
colour: "green"
|
121
133
|
category: xbb
|
134
|
+
- name: htauel
|
135
|
+
label: $H \rightarrow \tau e$
|
136
|
+
cuts: ["R10TruthLabel_R22v1 == 14"]
|
137
|
+
colour: "#b40612"
|
138
|
+
category: xbb
|
139
|
+
- name: htaumu
|
140
|
+
label: $H \rightarrow \tau\mu$
|
141
|
+
cuts: ["R10TruthLabel_R22v1 == 15"]
|
142
|
+
colour: "#b40657"
|
143
|
+
category: xbb
|
144
|
+
- name: htauhad
|
145
|
+
label: $H \rightarrow \tau\tau$
|
146
|
+
cuts: ["R10TruthLabel_R22v1 == 16"]
|
147
|
+
colour: "#b406a0"
|
148
|
+
category: xbb
|
122
149
|
|
123
150
|
# extended Xbb tagging
|
124
151
|
- name: tqqb
|
@@ -272,7 +299,7 @@
|
|
272
299
|
category: isolation
|
273
300
|
- name: npxall
|
274
301
|
label: non-prompt lepton
|
275
|
-
cuts: ["iffClass notin (2,3,4,11)"]
|
302
|
+
cuts: ["iffClass notin (0,1,2,3,4,11)"]
|
276
303
|
colour: "#264653"
|
277
304
|
category: isolation
|
278
305
|
- name: npxtau
|
ftag/mock.py
CHANGED
@@ -54,33 +54,74 @@ TRACK_VARS = [
|
|
54
54
|
]
|
55
55
|
|
56
56
|
|
57
|
-
def softmax(x, axis=None):
|
57
|
+
def softmax(x: np.ndarray, axis: int | None = None) -> np.ndarray:
|
58
|
+
"""Softmax function for numpy arrays.
|
59
|
+
|
60
|
+
Parameters
|
61
|
+
----------
|
62
|
+
x : np.ndarray
|
63
|
+
Input array for the softmax
|
64
|
+
axis : int | None, optional
|
65
|
+
Axis along which the softmax is calculated, by default None
|
66
|
+
|
67
|
+
Returns
|
68
|
+
-------
|
69
|
+
np.ndarray
|
70
|
+
Output array with the softmax output
|
71
|
+
"""
|
58
72
|
e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
|
59
73
|
return e_x / e_x.sum(axis=axis, keepdims=True)
|
60
74
|
|
61
75
|
|
62
|
-
def get_mock_scores(labels: np.ndarray, is_xbb: bool = False):
|
63
|
-
means = [
|
64
|
-
[2, 0, 0, 0],
|
65
|
-
[0, 1, 0, 0],
|
66
|
-
[0, 0, 3.5, 0],
|
67
|
-
[0, 0, 0, 1],
|
68
|
-
]
|
76
|
+
def get_mock_scores(labels: np.ndarray, is_xbb: bool = False) -> np.ndarray:
|
69
77
|
if not is_xbb:
|
70
78
|
label_dict = {"u": 0, "c": 4, "b": 5, "tau": 15}
|
71
|
-
label_mapping = dict(zip(label_dict.values(), means))
|
72
|
-
else:
|
73
|
-
label_dict = {"hbb": 11, "hcc": 12, "top": 1, "qcd": 10}
|
74
|
-
label_mapping = dict(zip(label_dict.values(), means))
|
75
79
|
|
80
|
+
else:
|
81
|
+
label_dict = {
|
82
|
+
"hbb": 11,
|
83
|
+
"hcc": 12,
|
84
|
+
"top": 1,
|
85
|
+
"qcd": 10,
|
86
|
+
"htauel": 14,
|
87
|
+
"htaumu": 15,
|
88
|
+
"htauhad": 16,
|
89
|
+
}
|
90
|
+
|
91
|
+
# Set random seed
|
76
92
|
rng = np.random.default_rng(42)
|
77
|
-
|
78
|
-
|
79
|
-
|
93
|
+
|
94
|
+
# Set a list of possible means/scales
|
95
|
+
mean_scale_list = [1, 2, 2.5, 3.5]
|
96
|
+
|
97
|
+
# Get the number of classes
|
98
|
+
n_classes = len(label_dict)
|
99
|
+
|
100
|
+
# Init a scores array
|
101
|
+
scores = np.zeros((len(labels), n_classes))
|
102
|
+
|
103
|
+
# Generate means/scales
|
104
|
+
means = []
|
105
|
+
scales = []
|
106
|
+
for i in range(n_classes):
|
107
|
+
tmp_means = []
|
108
|
+
tmp_means = [
|
109
|
+
0 if j != i else mean_scale_list[np.random.randint(0, len(mean_scale_list))]
|
110
|
+
for j in range(n_classes)
|
111
|
+
]
|
112
|
+
means.append(tmp_means)
|
113
|
+
scales.append(mean_scale_list[np.random.randint(0, len(mean_scale_list))])
|
114
|
+
|
115
|
+
# Map the labels to the means
|
116
|
+
label_mapping = dict(zip(label_dict.values(), means))
|
117
|
+
|
118
|
+
# Generate random mock scores
|
80
119
|
for i, (label, count) in enumerate(zip(*np.unique(labels, return_counts=True))):
|
81
120
|
scores[labels == label] = rng.normal(
|
82
|
-
loc=label_mapping[label], scale=scales[i], size=(count,
|
121
|
+
loc=label_mapping[label], scale=scales[i], size=(count, n_classes)
|
83
122
|
)
|
123
|
+
|
124
|
+
# Pipe scores through softmax
|
84
125
|
scores = softmax(scores, axis=1)
|
85
126
|
name = "MockXbbTagger" if is_xbb else "MockTagger"
|
86
127
|
cols = [f"{name}_p{x}" for x in label_dict]
|
@@ -103,7 +144,7 @@ def mock_jets(num_jets=1000) -> np.ndarray:
|
|
103
144
|
jets["HadronConeExclTruthLabelID"] = rng.choice([0, 4, 5, 15], size=num_jets)
|
104
145
|
jets["GhostBHadronsFinalCount"] = rng.choice([0, 1, 2], size=num_jets)
|
105
146
|
jets["GhostCHadronsFinalCount"] = rng.choice([0, 1, 2], size=num_jets)
|
106
|
-
jets["R10TruthLabel_R22v1"] = rng.choice([1, 10, 11, 12], size=num_jets)
|
147
|
+
jets["R10TruthLabel_R22v1"] = rng.choice([1, 10, 11, 12, 14, 15, 16], size=num_jets)
|
107
148
|
scores = get_mock_scores(jets["HadronConeExclTruthLabelID"])
|
108
149
|
xbb_scores = get_mock_scores(jets["R10TruthLabel_R22v1"], is_xbb=True)
|
109
150
|
return join_structured_arrays([jets, scores, xbb_scores])
|
ftag/wps/discriminant.py
CHANGED
@@ -1,22 +1,22 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
-
from typing import
|
3
|
+
from typing import TYPE_CHECKING
|
4
4
|
|
5
5
|
import numpy as np
|
6
6
|
|
7
|
-
|
8
|
-
from ftag.labels import Label,
|
7
|
+
if TYPE_CHECKING: # pragma: no cover
|
8
|
+
from ftag.labels import Label, LabelContainer
|
9
9
|
|
10
10
|
|
11
|
-
def discriminant(
|
11
|
+
def get_discriminant(
|
12
12
|
jets: np.ndarray,
|
13
13
|
tagger: str,
|
14
14
|
signal: Label,
|
15
|
-
|
15
|
+
flavours: LabelContainer,
|
16
|
+
fraction_values: dict[str, float],
|
16
17
|
epsilon: float = 1e-10,
|
17
18
|
) -> np.ndarray:
|
18
|
-
"""
|
19
|
-
Get the tagging discriminant.
|
19
|
+
"""Calculate the tagging discriminant for a given tagger.
|
20
20
|
|
21
21
|
Calculated as the logarithm of the ratio of a specified signal probability
|
22
22
|
to a weighted sum ofbackground probabilities.
|
@@ -24,108 +24,61 @@ def discriminant(
|
|
24
24
|
Parameters
|
25
25
|
----------
|
26
26
|
jets : np.ndarray
|
27
|
-
|
27
|
+
Structured array of jets containing tagger outputs
|
28
28
|
tagger : str
|
29
|
-
Name of the tagger
|
30
|
-
signal :
|
31
|
-
|
32
|
-
|
33
|
-
Dict
|
34
|
-
If a fraction is None, it is calculated as (1 - sum of provided fractions).
|
29
|
+
Name of the tagger
|
30
|
+
signal : Label
|
31
|
+
Signal flavour (bjets/cjets or hbb/hcc)
|
32
|
+
fraction_values : dict
|
33
|
+
Dict with the fraction values for the background classes for the given tagger
|
35
34
|
epsilon : float, optional
|
36
|
-
|
35
|
+
Small number to avoid division by zero, by default 1e-10
|
37
36
|
|
38
37
|
Returns
|
39
38
|
-------
|
40
39
|
np.ndarray
|
41
|
-
|
40
|
+
Array of discriminant values.
|
42
41
|
|
43
42
|
Raises
|
44
43
|
------
|
45
44
|
ValueError
|
46
|
-
If
|
45
|
+
If the signal flavour is not recognised.
|
47
46
|
"""
|
47
|
+
# Init the denominator
|
48
48
|
denominator = 0.0
|
49
|
-
for d, fx in fxs.items():
|
50
|
-
name = f"{tagger}_{d}"
|
51
|
-
if fx > 0 and name not in jets.dtype.names:
|
52
|
-
raise ValueError(f"Nonzero fx for {d}, but '{name}' not found in input array.")
|
53
|
-
denominator += jets[name] * fx if name in jets.dtype.names else 0
|
54
|
-
signal_field = f"{tagger}_{signal.px}"
|
55
|
-
if signal_field not in jets.dtype.names:
|
56
|
-
signal_field = f"{tagger}_p{remove_suffix(signal.name, 'jets')}"
|
57
|
-
return np.log((jets[signal_field] + epsilon) / (denominator + epsilon))
|
58
|
-
|
59
|
-
|
60
|
-
def tautag_dicriminant(jets, tagger, fb, fc, epsilon=1e-10):
|
61
|
-
fxs = {"pb": fb, "pc": fc, "pu": 1 - fb - fc}
|
62
|
-
return discriminant(jets, tagger, Flavours.taujets, fxs, epsilon=epsilon)
|
63
|
-
|
64
|
-
|
65
|
-
def btag_discriminant(jets, tagger, fc, ftau=0, epsilon=1e-10):
|
66
|
-
fxs = {"pc": fc, "ptau": ftau, "pu": 1 - fc - ftau}
|
67
|
-
return discriminant(jets, tagger, Flavours.bjets, fxs, epsilon=epsilon)
|
68
|
-
|
69
49
|
|
70
|
-
|
71
|
-
|
72
|
-
|
50
|
+
# Loop over background flavours
|
51
|
+
for flav in flavours:
|
52
|
+
# Skip signal flavour for denominator
|
53
|
+
if flav == signal:
|
54
|
+
continue
|
73
55
|
|
56
|
+
# Get the probability name of the tagger/flavour combo + fraction value
|
57
|
+
prob_name = f"{tagger}_{flav.px}"
|
58
|
+
fraction_value = fraction_values[flav.frac_str]
|
74
59
|
|
75
|
-
|
76
|
-
|
77
|
-
|
60
|
+
# If fraction_value for the given flavour is zero, skip it
|
61
|
+
if fraction_value == 0:
|
62
|
+
continue
|
78
63
|
|
64
|
+
# Check that the probability value for the flavour is available
|
65
|
+
if fraction_value > 0 and prob_name not in jets.dtype.names:
|
66
|
+
raise ValueError(
|
67
|
+
f"Nonzero fraction value for {flav.name}, but '{prob_name}' "
|
68
|
+
"not found in input array."
|
69
|
+
)
|
79
70
|
|
80
|
-
|
81
|
-
|
82
|
-
return discriminant(jets, tagger, Flavours.hbb, fxs, epsilon=epsilon)
|
83
|
-
|
84
|
-
|
85
|
-
def hcc_discriminant(jets, tagger, ftop=0.25, fhbb=0.3, epsilon=1e-10):
|
86
|
-
fxs = {"phbb": fhbb, "ptop": ftop, "pqcd": 1 - ftop - fhbb}
|
87
|
-
return discriminant(jets, tagger, Flavours.hcc, fxs, epsilon=epsilon)
|
88
|
-
|
89
|
-
|
90
|
-
def get_discriminant(
|
91
|
-
jets: np.ndarray, tagger: str, signal: Label | str, epsilon: float = 1e-10, **fxs
|
92
|
-
):
|
93
|
-
"""Calculate the b-tag or c-tag discriminant for a given tagger.
|
71
|
+
# Update denominator
|
72
|
+
denominator += jets[prob_name] * fraction_value if prob_name in jets.dtype.names else 0
|
94
73
|
|
95
|
-
|
96
|
-
|
97
|
-
jets : np.ndarray
|
98
|
-
Structured array of jets containing tagger outputs
|
99
|
-
tagger : str
|
100
|
-
Name of the tagger
|
101
|
-
signal : Label
|
102
|
-
Signal flavour (bjets/cjets or hbb/hcc)
|
103
|
-
epsilon : float, optional
|
104
|
-
Small number to avoid division by zero, by default 1e-10
|
105
|
-
**fxs : dict
|
106
|
-
Fractions for the different background flavours.
|
74
|
+
# Calculate numerator
|
75
|
+
signal_field = f"{tagger}_{signal.px}"
|
107
76
|
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
77
|
+
# Check that the probability of the signal is available
|
78
|
+
if signal_field not in jets.dtype.names:
|
79
|
+
raise ValueError(
|
80
|
+
f"No signal probability value(s) found for tagger {tagger}. "
|
81
|
+
f"Missing variable: {signal_field}"
|
82
|
+
)
|
112
83
|
|
113
|
-
|
114
|
-
------
|
115
|
-
ValueError
|
116
|
-
If the signal flavour is not recognised.
|
117
|
-
"""
|
118
|
-
tagger_funcs: dict[str, Callable] = {
|
119
|
-
"bjets": btag_discriminant,
|
120
|
-
"cjets": ctag_discriminant,
|
121
|
-
"taujets": tautag_dicriminant,
|
122
|
-
"hbb": hbb_discriminant,
|
123
|
-
"hcc": hcc_discriminant,
|
124
|
-
"ghostbjets": ghostbtag_discriminant,
|
125
|
-
}
|
126
|
-
|
127
|
-
if str(signal) not in tagger_funcs:
|
128
|
-
raise ValueError(f"Signal flavour must be one of {list(tagger_funcs.keys())}, not {signal}")
|
129
|
-
|
130
|
-
func: Callable = tagger_funcs[str(Flavours[signal])]
|
131
|
-
return func(jets, tagger, **fxs, epsilon=epsilon)
|
84
|
+
return np.log((jets[signal_field] + epsilon) / (denominator + epsilon))
|
ftag/wps/working_points.py
CHANGED
@@ -3,7 +3,9 @@
|
|
3
3
|
from __future__ import annotations
|
4
4
|
|
5
5
|
import argparse
|
6
|
+
import sys
|
6
7
|
from pathlib import Path
|
8
|
+
from typing import TYPE_CHECKING
|
7
9
|
|
8
10
|
import numpy as np
|
9
11
|
import yaml
|
@@ -14,30 +16,109 @@ from ftag.cuts import Cuts
|
|
14
16
|
from ftag.hdf5 import H5Reader
|
15
17
|
from ftag.wps.discriminant import get_discriminant
|
16
18
|
|
19
|
+
if TYPE_CHECKING: # pragma: no cover
|
20
|
+
from collections.abc import Sequence
|
21
|
+
|
22
|
+
from ftag.labels import Label, LabelContainer
|
23
|
+
|
24
|
+
|
25
|
+
def parse_args(args: Sequence[str]) -> argparse.Namespace:
|
26
|
+
"""Parse the input arguments into a Namespace.
|
27
|
+
|
28
|
+
Parameters
|
29
|
+
----------
|
30
|
+
args : Sequence[str] | None
|
31
|
+
Sequence of string inputs to the script
|
32
|
+
|
33
|
+
Returns
|
34
|
+
-------
|
35
|
+
argparse.Namespace
|
36
|
+
Namespace with the parsed arguments
|
37
|
+
|
38
|
+
Raises
|
39
|
+
------
|
40
|
+
ValueError
|
41
|
+
When both --effs and --disc_cuts are provided
|
42
|
+
ValueError
|
43
|
+
When neither --effs nor --disc_cuts are provided
|
44
|
+
ValueError
|
45
|
+
When the number of fraction values is not conistent
|
46
|
+
ValueError
|
47
|
+
When the sum of fraction values for a tagger is not equal to one
|
48
|
+
"""
|
49
|
+
# Define the pre-parser which checks the --category
|
50
|
+
pre_parser = argparse.ArgumentParser(add_help=False)
|
51
|
+
pre_parser.add_argument(
|
52
|
+
"-c",
|
53
|
+
"--category",
|
54
|
+
default="single-btag",
|
55
|
+
type=str,
|
56
|
+
help="Label category to use for the working point calculation",
|
57
|
+
)
|
58
|
+
|
59
|
+
pre_parser.add_argument(
|
60
|
+
"-s",
|
61
|
+
"--signal",
|
62
|
+
default="bjets",
|
63
|
+
type=str,
|
64
|
+
help="Signal flavour which is to be used",
|
65
|
+
)
|
66
|
+
|
67
|
+
# Parse only --category/--signal and ignore for now all other args
|
68
|
+
pre_args, remaining_argv = pre_parser.parse_known_args(args=args)
|
17
69
|
|
18
|
-
|
70
|
+
# Create the "real" parser
|
19
71
|
parser = argparse.ArgumentParser(
|
20
72
|
description=__doc__,
|
21
73
|
formatter_class=HelpFormatter,
|
22
74
|
)
|
75
|
+
|
76
|
+
# Add --category/--signal so the help is correctly shown
|
77
|
+
parser.add_argument(
|
78
|
+
"-c",
|
79
|
+
"--category",
|
80
|
+
default="single-btag",
|
81
|
+
type=str,
|
82
|
+
help="Label category to use for the working point calculation",
|
83
|
+
)
|
84
|
+
parser.add_argument(
|
85
|
+
"-s",
|
86
|
+
"--signal",
|
87
|
+
default="bjets",
|
88
|
+
type=str,
|
89
|
+
help="Signal flavour which is to be used",
|
90
|
+
)
|
91
|
+
|
92
|
+
# Check which label category was chosen and load the corresponding flavours
|
93
|
+
flavours = Flavours.by_category(pre_args.category)
|
94
|
+
|
95
|
+
# Build the fraction value arguments for all classes (besides signal)
|
96
|
+
for flav in flavours:
|
97
|
+
# Skip signal
|
98
|
+
if flav.name == pre_args.signal:
|
99
|
+
continue
|
100
|
+
|
101
|
+
# Built fraction values for all background classes
|
102
|
+
parser.add_argument(
|
103
|
+
f"--{flav.frac_str}",
|
104
|
+
nargs="+",
|
105
|
+
required=True,
|
106
|
+
type=float,
|
107
|
+
help=f"{flav.frac_str} value(s) for each tagger",
|
108
|
+
)
|
109
|
+
|
110
|
+
# # Adding the other arguments
|
23
111
|
parser.add_argument(
|
24
112
|
"--ttbar",
|
25
113
|
required=True,
|
26
114
|
type=Path,
|
27
|
-
help="
|
115
|
+
help="Path to ttbar sample (supports globbing)",
|
28
116
|
)
|
29
117
|
parser.add_argument(
|
30
118
|
"--zprime",
|
31
119
|
required=False,
|
32
120
|
type=Path,
|
33
|
-
help="
|
34
|
-
)
|
35
|
-
parser.add_argument(
|
36
|
-
"-e",
|
37
|
-
"--effs",
|
38
|
-
nargs="+",
|
39
|
-
type=float,
|
40
|
-
help="efficiency working point(s). If -r is specified, values should be 1/efficiency",
|
121
|
+
help="Path to zprime (supports globbing). WPs from ttbar will be reused for zprime",
|
41
122
|
)
|
42
123
|
parser.add_argument(
|
43
124
|
"-t",
|
@@ -48,19 +129,17 @@ def parse_args(args):
|
|
48
129
|
help="tagger name(s)",
|
49
130
|
)
|
50
131
|
parser.add_argument(
|
51
|
-
"-
|
52
|
-
"--
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
help='signal flavour ("bjets" or "cjets" for b-tagging, "hbb" or "hcc" for Xbb)',
|
132
|
+
"-e",
|
133
|
+
"--effs",
|
134
|
+
nargs="+",
|
135
|
+
type=float,
|
136
|
+
help="Efficiency working point(s). If -r is specified, values should be 1/efficiency",
|
57
137
|
)
|
58
138
|
parser.add_argument(
|
59
139
|
"-r",
|
60
140
|
"--rejection",
|
61
141
|
default=None,
|
62
|
-
|
63
|
-
help="use rejection of specified background class to determine working points",
|
142
|
+
help="Use rejection of specified background class to determine working points",
|
64
143
|
)
|
65
144
|
parser.add_argument(
|
66
145
|
"-d",
|
@@ -74,219 +153,357 @@ def parse_args(args):
|
|
74
153
|
"--num_jets",
|
75
154
|
default=1_000_000,
|
76
155
|
type=int,
|
77
|
-
help="
|
156
|
+
help="Use this many jets (post selection)",
|
78
157
|
)
|
79
158
|
parser.add_argument(
|
80
159
|
"--ttbar_cuts",
|
81
160
|
nargs="+",
|
82
161
|
default=["pt > 20e3"],
|
83
162
|
type=list,
|
84
|
-
help="
|
163
|
+
help="Selection to apply to ttbar (|eta| < 2.5 is always applied)",
|
85
164
|
)
|
86
165
|
parser.add_argument(
|
87
166
|
"--zprime_cuts",
|
88
167
|
nargs="+",
|
89
168
|
default=["pt > 250e3"],
|
90
169
|
type=list,
|
91
|
-
help="
|
170
|
+
help="Selection to apply to zprime (|eta| < 2.5 is always applied)",
|
92
171
|
)
|
93
172
|
parser.add_argument(
|
94
173
|
"-o",
|
95
174
|
"--outfile",
|
96
175
|
type=Path,
|
97
|
-
help="
|
98
|
-
)
|
99
|
-
parser.add_argument(
|
100
|
-
"--xbb",
|
101
|
-
action="store_true",
|
102
|
-
help="Enable Xbb tagging which expects two fx values ftop and fhcc/fhbb for each tagger",
|
103
|
-
)
|
104
|
-
parser.add_argument(
|
105
|
-
"--fb",
|
106
|
-
nargs="+",
|
107
|
-
type=float,
|
108
|
-
help="fb value(s) for each tagger",
|
109
|
-
)
|
110
|
-
parser.add_argument(
|
111
|
-
"--fc",
|
112
|
-
nargs="+",
|
113
|
-
type=float,
|
114
|
-
help="fc value(s) for each tagger",
|
115
|
-
)
|
116
|
-
parser.add_argument(
|
117
|
-
"--ftau",
|
118
|
-
nargs="+",
|
119
|
-
type=float,
|
120
|
-
help="ftau value(s) for each tagger",
|
121
|
-
)
|
122
|
-
parser.add_argument(
|
123
|
-
"--ftop",
|
124
|
-
nargs="+",
|
125
|
-
type=float,
|
126
|
-
help="ftop value(s) for each tagger",
|
176
|
+
help="Save results to yaml instead of printing",
|
127
177
|
)
|
128
|
-
parser.add_argument(
|
129
|
-
"--fhbb",
|
130
|
-
nargs="+",
|
131
|
-
type=float,
|
132
|
-
help="fhbb value(s) for each tagger",
|
133
|
-
)
|
134
|
-
parser.add_argument(
|
135
|
-
"--fhcc",
|
136
|
-
nargs="+",
|
137
|
-
type=float,
|
138
|
-
help="fhcc value(s) for each tagger",
|
139
|
-
)
|
140
|
-
args = parser.parse_args(args)
|
141
178
|
|
142
|
-
|
179
|
+
# Final parse of all arguments
|
180
|
+
parsed_args = parser.parse_args(remaining_argv)
|
181
|
+
|
182
|
+
# Define the signal as an instance of Flavours
|
183
|
+
parsed_args.signal = Flavours[parsed_args.signal]
|
143
184
|
|
144
|
-
|
185
|
+
# Check that only --effs or --disc_cuts is given
|
186
|
+
if parsed_args.effs and parsed_args.disc_cuts:
|
145
187
|
raise ValueError("Cannot specify both --effs and --disc_cuts")
|
146
|
-
if not
|
188
|
+
if not parsed_args.effs and not parsed_args.disc_cuts:
|
147
189
|
raise ValueError("Must specify either --effs or --disc_cuts")
|
148
190
|
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
if
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
191
|
+
# Check that all fraction values have the same length
|
192
|
+
for flav in flavours:
|
193
|
+
if flav.name != parsed_args.signal.name and len(getattr(parsed_args, flav.frac_str)) != len(
|
194
|
+
parsed_args.tagger
|
195
|
+
):
|
196
|
+
raise ValueError(f"Number of {flav.frac_str} values must match number of taggers")
|
197
|
+
|
198
|
+
# Check that all fraction value combinations add up to one
|
199
|
+
for tagger_idx in range(len(parsed_args.tagger)):
|
200
|
+
fraction_value_sum = 0
|
201
|
+
for flav in flavours:
|
202
|
+
if flav.name != parsed_args.signal.name:
|
203
|
+
fraction_value_sum += getattr(parsed_args, flav.frac_str)[tagger_idx]
|
204
|
+
|
205
|
+
# Round the value to take machine precision into account
|
206
|
+
fraction_value_sum = np.round(fraction_value_sum, 8)
|
207
|
+
|
208
|
+
# Check it's equal to one
|
209
|
+
if fraction_value_sum != 1:
|
210
|
+
raise ValueError(
|
211
|
+
"Sum of the fraction values must be one! You gave "
|
212
|
+
f"{fraction_value_sum} for tagger {parsed_args.tagger[tagger_idx]}"
|
213
|
+
)
|
214
|
+
return parsed_args
|
215
|
+
|
216
|
+
|
217
|
+
def get_fxs_from_args(args: argparse.Namespace, flavours: LabelContainer) -> list:
|
218
|
+
"""Get the fraction values for each tagger from the argparsed inputs.
|
219
|
+
|
220
|
+
Parameters
|
221
|
+
----------
|
222
|
+
args : argparse.Namespace
|
223
|
+
Input arguments parsed by the argparser
|
224
|
+
flavours : LabelContainer
|
225
|
+
LabelContainer instance of the labels that are used
|
226
|
+
|
227
|
+
Returns
|
228
|
+
-------
|
229
|
+
list
|
230
|
+
List of dicts with the fraction values. Each dict is for one tagger.
|
231
|
+
"""
|
232
|
+
# Init the fraction_dict dict
|
233
|
+
fraction_dict = {}
|
234
|
+
|
235
|
+
# Add the fraction values to the dict
|
236
|
+
for flav in flavours:
|
237
|
+
if flav.name != args.signal.name:
|
238
|
+
fraction_dict[flav.frac_str] = vars(args)[flav.frac_str]
|
239
|
+
|
240
|
+
return [{k: v[i] for k, v in fraction_dict.items()} for i in range(len(args.tagger))]
|
241
|
+
|
242
|
+
|
243
|
+
def get_eff_rej(
|
244
|
+
jets: np.ndarray,
|
245
|
+
disc: np.ndarray,
|
246
|
+
wp: float,
|
247
|
+
flavours: LabelContainer,
|
248
|
+
) -> dict:
|
249
|
+
"""Calculate the efficiency/rejection for each flavour.
|
250
|
+
|
251
|
+
Parameters
|
252
|
+
----------
|
253
|
+
jets : np.ndarray
|
254
|
+
Loaded jets
|
255
|
+
disc : np.ndarray
|
256
|
+
Discriminant values of the jets
|
257
|
+
wp : float
|
258
|
+
Working point that is used
|
259
|
+
flavours : LabelContainer
|
260
|
+
LabelContainer instance of the flavours used
|
261
|
+
|
262
|
+
Returns
|
263
|
+
-------
|
264
|
+
dict
|
265
|
+
Dict with the efficiency/rejection values for each flavour
|
266
|
+
"""
|
267
|
+
# Init an out dict
|
268
|
+
out: dict[str, dict] = {"eff": {}, "rej": {}}
|
269
|
+
|
270
|
+
# Loop over the flavours
|
271
|
+
for flav in flavours:
|
272
|
+
# Calculate discriminant values and efficiencies/rejections
|
273
|
+
flav_disc = disc[flav.cuts(jets).idx]
|
274
|
+
eff = sum(flav_disc > wp) / len(flav_disc)
|
275
|
+
out["eff"][flav.name] = float(f"{eff:.3g}")
|
276
|
+
out["rej"][flav.name] = float(f"{1 / eff:.3g}")
|
277
|
+
|
197
278
|
return out
|
198
279
|
|
199
280
|
|
200
|
-
def get_rej_eff_at_disc(
|
201
|
-
|
202
|
-
|
203
|
-
|
281
|
+
def get_rej_eff_at_disc(
|
282
|
+
jets: np.ndarray,
|
283
|
+
tagger: str,
|
284
|
+
signal: Label,
|
285
|
+
disc_cuts: list,
|
286
|
+
flavours: LabelContainer,
|
287
|
+
fraction_values: dict,
|
288
|
+
) -> dict:
|
289
|
+
"""Calculate the efficiency/rejection at a certain discriminant values.
|
290
|
+
|
291
|
+
Parameters
|
292
|
+
----------
|
293
|
+
jets : np.ndarray
|
294
|
+
Loaded jets used
|
295
|
+
tagger : str
|
296
|
+
Name of the tagger
|
297
|
+
signal : Label
|
298
|
+
Label instance of the signal flavour
|
299
|
+
disc_cuts : list
|
300
|
+
List of discriminant cut values for which the efficiency/rejection is calculated
|
301
|
+
flavours : LabelContainer
|
302
|
+
LabelContainer instance of the flavours that are used
|
303
|
+
|
304
|
+
Returns
|
305
|
+
-------
|
306
|
+
dict
|
307
|
+
Dict with the discriminant cut values and their respective efficiencies/rejections
|
308
|
+
"""
|
309
|
+
# Calculate discriminants
|
310
|
+
disc = get_discriminant(
|
311
|
+
jets=jets,
|
312
|
+
tagger=tagger,
|
313
|
+
signal=signal,
|
314
|
+
flavours=flavours,
|
315
|
+
fraction_values=fraction_values,
|
316
|
+
)
|
317
|
+
|
318
|
+
# Init out dict
|
319
|
+
ref_eff_dict: dict[str, dict] = {}
|
320
|
+
|
321
|
+
# Loop over the disc cut values
|
204
322
|
for dcut in disc_cuts:
|
205
|
-
|
206
|
-
|
207
|
-
|
323
|
+
ref_eff_dict[str(dcut)] = {"eff": {}, "rej": {}}
|
324
|
+
|
325
|
+
# Loop over the flavours
|
326
|
+
for flav in flavours:
|
327
|
+
e_discs = disc[flav.cuts(jets).idx]
|
208
328
|
eff = sum(e_discs > dcut) / len(e_discs)
|
209
|
-
|
210
|
-
|
211
|
-
|
329
|
+
ref_eff_dict[str(dcut)]["eff"][str(flav)] = float(f"{eff:.3g}")
|
330
|
+
ref_eff_dict[str(dcut)]["rej"][str(flav)] = 1 / float(f"{eff:.3g}")
|
331
|
+
|
332
|
+
return ref_eff_dict
|
333
|
+
|
334
|
+
|
335
|
+
def setup_common_parts(
|
336
|
+
args: argparse.Namespace,
|
337
|
+
) -> tuple[np.ndarray, np.ndarray | None, LabelContainer]:
|
338
|
+
"""Load the jets from the files and setup the taggers.
|
339
|
+
|
340
|
+
Parameters
|
341
|
+
----------
|
342
|
+
args : argparse.Namespace
|
343
|
+
Input arguments from the argparser
|
212
344
|
|
345
|
+
Returns
|
346
|
+
-------
|
347
|
+
tuple[dict, dict | None, list]
|
348
|
+
Outputs the ttbar jets, the zprime jets (if wanted, else None), and the flavours used.
|
349
|
+
"""
|
350
|
+
# Get the used flavours
|
351
|
+
flavours = Flavours.by_category(args.category)
|
213
352
|
|
214
|
-
|
215
|
-
flavs = Flavours.by_category("single-btag") if not args.xbb else Flavours.by_category("xbb")
|
353
|
+
# Get the cuts for the samples
|
216
354
|
default_cuts = Cuts.from_list(["eta > -2.5", "eta < 2.5"])
|
217
355
|
ttbar_cuts = Cuts.from_list(args.ttbar_cuts) + default_cuts
|
218
356
|
zprime_cuts = Cuts.from_list(args.zprime_cuts) + default_cuts
|
219
357
|
|
220
|
-
#
|
221
|
-
all_vars = list(set(sum((flav.cuts.variables for flav in
|
358
|
+
# Prepare the loading of the jets
|
359
|
+
all_vars = list(set(sum((flav.cuts.variables for flav in flavours), [])))
|
222
360
|
reader = H5Reader(args.ttbar)
|
223
361
|
jet_vars = reader.dtypes()["jets"].names
|
362
|
+
|
363
|
+
# Create for all taggers the fraction values
|
224
364
|
for tagger in args.tagger:
|
225
|
-
all_vars += [
|
365
|
+
all_vars += [
|
366
|
+
f"{tagger}_{flav.px}" for flav in flavours if (f"{tagger}_{flav.px}" in jet_vars)
|
367
|
+
]
|
368
|
+
|
369
|
+
# Load ttbar jets
|
370
|
+
ttbar_jets = reader.load({"jets": all_vars}, args.num_jets, cuts=ttbar_cuts)["jets"]
|
371
|
+
zprime_jets = None
|
226
372
|
|
227
|
-
#
|
228
|
-
jets = reader.load({"jets": all_vars}, args.num_jets, cuts=ttbar_cuts)["jets"]
|
229
|
-
zp_jets = None
|
373
|
+
# Load zprime jets if needed
|
230
374
|
if args.zprime:
|
231
|
-
|
232
|
-
|
375
|
+
zprime_reader = H5Reader(args.zprime)
|
376
|
+
zprime_jets = zprime_reader.load({"jets": all_vars}, args.num_jets, cuts=zprime_cuts)[
|
377
|
+
"jets"
|
378
|
+
]
|
379
|
+
|
380
|
+
else:
|
381
|
+
zprime_jets = None
|
382
|
+
|
383
|
+
return ttbar_jets, zprime_jets, flavours
|
233
384
|
|
234
|
-
return jets, zp_jets, flavs
|
235
385
|
|
386
|
+
def get_working_points(args: argparse.Namespace) -> dict | None:
|
387
|
+
"""Calculate the working points.
|
236
388
|
|
237
|
-
|
238
|
-
|
239
|
-
|
389
|
+
Parameters
|
390
|
+
----------
|
391
|
+
args : argparse.Namespace
|
392
|
+
Input arguments from the argparser
|
240
393
|
|
241
|
-
|
394
|
+
Returns
|
395
|
+
-------
|
396
|
+
dict | None
|
397
|
+
Dict with the working points. If args.outfile is given, the function returns None and
|
398
|
+
stored the resulting dict in a yaml file in args.outfile.
|
399
|
+
"""
|
400
|
+
# Load the jets and flavours and get the fraction values
|
401
|
+
ttbar_jets, zprime_jets, flavours = setup_common_parts(args=args)
|
402
|
+
fraction_values = get_fxs_from_args(args=args, flavours=flavours)
|
403
|
+
|
404
|
+
# Init an out dict
|
242
405
|
out = {}
|
406
|
+
|
407
|
+
# Loop over taggers
|
243
408
|
for i, tagger in enumerate(args.tagger):
|
244
|
-
#
|
245
|
-
out[tagger] = {"signal": str(args.signal), **
|
246
|
-
disc = get_discriminant(
|
409
|
+
# Calculate discriminant
|
410
|
+
out[tagger] = {"signal": str(args.signal), **fraction_values[i]}
|
411
|
+
disc = get_discriminant(
|
412
|
+
jets=ttbar_jets,
|
413
|
+
tagger=tagger,
|
414
|
+
signal=args.signal,
|
415
|
+
flavours=flavours,
|
416
|
+
fraction_values=fraction_values[i],
|
417
|
+
)
|
247
418
|
|
248
|
-
#
|
419
|
+
# Loop over efficiency working points
|
249
420
|
for eff in args.effs:
|
250
421
|
d = out[tagger][f"{eff:.0f}"] = {}
|
251
422
|
|
423
|
+
# Set the working point
|
252
424
|
wp_flavour = args.signal
|
253
425
|
if args.rejection:
|
254
426
|
eff = 100 / eff # noqa: PLW2901
|
255
427
|
wp_flavour = args.rejection
|
256
428
|
|
257
|
-
|
429
|
+
# Calculate the discriminant value of the working point
|
430
|
+
wp_disc = disc[flavours[wp_flavour].cuts(ttbar_jets).idx]
|
258
431
|
wp = d["cut_value"] = round(float(np.percentile(wp_disc, 100 - eff)), 3)
|
259
432
|
|
260
|
-
#
|
261
|
-
d["ttbar"] = get_eff_rej(
|
433
|
+
# Calculate efficiency and rejection for each flavour
|
434
|
+
d["ttbar"] = get_eff_rej(
|
435
|
+
jets=ttbar_jets,
|
436
|
+
disc=disc,
|
437
|
+
wp=wp,
|
438
|
+
flavours=flavours,
|
439
|
+
)
|
262
440
|
|
263
441
|
# calculate for zprime
|
264
442
|
if args.zprime:
|
265
|
-
|
266
|
-
|
443
|
+
zprime_disc = get_discriminant(
|
444
|
+
jets=zprime_jets,
|
445
|
+
tagger=tagger,
|
446
|
+
signal=args.signal,
|
447
|
+
flavours=flavours,
|
448
|
+
fraction_values=fraction_values[i],
|
449
|
+
)
|
450
|
+
d["zprime"] = get_eff_rej(
|
451
|
+
jets=zprime_jets,
|
452
|
+
disc=zprime_disc,
|
453
|
+
wp=wp,
|
454
|
+
flavours=flavours,
|
455
|
+
)
|
267
456
|
|
268
457
|
if args.outfile:
|
269
458
|
with open(args.outfile, "w") as f:
|
270
459
|
yaml.dump(out, f, sort_keys=False)
|
271
460
|
return None
|
461
|
+
|
272
462
|
else:
|
273
463
|
return out
|
274
464
|
|
275
465
|
|
276
|
-
def get_efficiencies(args
|
277
|
-
|
278
|
-
|
466
|
+
def get_efficiencies(args: argparse.Namespace) -> dict | None:
|
467
|
+
"""Calculate the efficiencies for the given jets.
|
468
|
+
|
469
|
+
Parameters
|
470
|
+
----------
|
471
|
+
args : argparse.Namespace
|
472
|
+
Input arguments from the argparser
|
279
473
|
|
474
|
+
Returns
|
475
|
+
-------
|
476
|
+
dict | None
|
477
|
+
Dict with the efficiencies. If args.outfile is given, the function returns None and
|
478
|
+
stored the resulting dict in a yaml file in args.outfile.
|
479
|
+
"""
|
480
|
+
# Load the jets and flavours and get the fraction values
|
481
|
+
ttbar_jets, zprime_jets, flavours = setup_common_parts(args=args)
|
482
|
+
fraction_values = get_fxs_from_args(args=args, flavours=flavours)
|
483
|
+
|
484
|
+
# Init an out dict
|
280
485
|
out = {}
|
486
|
+
|
487
|
+
# Loop over the taggers
|
281
488
|
for i, tagger in enumerate(args.tagger):
|
282
|
-
out[tagger] = {"signal": str(args.signal), **
|
489
|
+
out[tagger] = {"signal": str(args.signal), **fraction_values[i]}
|
283
490
|
|
284
491
|
out[tagger]["ttbar"] = get_rej_eff_at_disc(
|
285
|
-
jets,
|
492
|
+
jets=ttbar_jets,
|
493
|
+
tagger=tagger,
|
494
|
+
signal=args.signal,
|
495
|
+
disc_cuts=args.disc_cuts,
|
496
|
+
flavours=flavours,
|
497
|
+
fraction_values=fraction_values[i],
|
286
498
|
)
|
287
499
|
if args.zprime:
|
288
500
|
out[tagger]["zprime"] = get_rej_eff_at_disc(
|
289
|
-
|
501
|
+
jets=zprime_jets,
|
502
|
+
tagger=tagger,
|
503
|
+
signal=args.signal,
|
504
|
+
disc_cuts=args.disc_cuts,
|
505
|
+
flavours=flavours,
|
506
|
+
fraction_values=fraction_values[i],
|
290
507
|
)
|
291
508
|
|
292
509
|
if args.outfile:
|
@@ -297,13 +514,27 @@ def get_efficiencies(args=None):
|
|
297
514
|
return out
|
298
515
|
|
299
516
|
|
300
|
-
def main(args
|
301
|
-
|
517
|
+
def main(args: Sequence[str]) -> dict | None:
|
518
|
+
"""Main function to run working point calculation.
|
519
|
+
|
520
|
+
Parameters
|
521
|
+
----------
|
522
|
+
args : Sequence[str] | None, optional
|
523
|
+
Input arguments, by default None
|
524
|
+
|
525
|
+
Returns
|
526
|
+
-------
|
527
|
+
dict | None
|
528
|
+
The output dict with the calculated values. When --outfile
|
529
|
+
was given, the return value is None
|
530
|
+
"""
|
531
|
+
parsed_args = parse_args(args=args)
|
532
|
+
|
533
|
+
if parsed_args.effs:
|
534
|
+
out = get_working_points(args=parsed_args)
|
302
535
|
|
303
|
-
|
304
|
-
out =
|
305
|
-
elif args.disc_cuts:
|
306
|
-
out = get_efficiencies(args)
|
536
|
+
elif parsed_args.disc_cuts:
|
537
|
+
out = get_efficiencies(args=parsed_args)
|
307
538
|
|
308
539
|
if out:
|
309
540
|
print(yaml.dump(out, sort_keys=False))
|
@@ -312,5 +543,5 @@ def main(args=None):
|
|
312
543
|
return None
|
313
544
|
|
314
545
|
|
315
|
-
if __name__ == "__main__":
|
316
|
-
main()
|
546
|
+
if __name__ == "__main__": # pragma: no cover
|
547
|
+
main(args=sys.argv[1:])
|
File without changes
|
File without changes
|