atlas-ftag-tools 0.2.8__tar.gz → 0.2.9__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36) hide show
  1. {atlas_ftag_tools-0.2.8 → atlas_ftag_tools-0.2.9}/PKG-INFO +1 -1
  2. {atlas_ftag_tools-0.2.8 → atlas_ftag_tools-0.2.9}/atlas_ftag_tools.egg-info/PKG-INFO +1 -1
  3. {atlas_ftag_tools-0.2.8 → atlas_ftag_tools-0.2.9}/ftag/__init__.py +1 -1
  4. {atlas_ftag_tools-0.2.8 → atlas_ftag_tools-0.2.9}/ftag/flavours.yaml +31 -4
  5. {atlas_ftag_tools-0.2.8 → atlas_ftag_tools-0.2.9}/ftag/mock.py +58 -17
  6. atlas_ftag_tools-0.2.9/ftag/wps/discriminant.py +84 -0
  7. atlas_ftag_tools-0.2.9/ftag/wps/working_points.py +547 -0
  8. {atlas_ftag_tools-0.2.8 → atlas_ftag_tools-0.2.9}/pyproject.toml +2 -1
  9. atlas_ftag_tools-0.2.8/ftag/wps/discriminant.py +0 -131
  10. atlas_ftag_tools-0.2.8/ftag/wps/working_points.py +0 -316
  11. {atlas_ftag_tools-0.2.8 → atlas_ftag_tools-0.2.9}/MANIFEST.in +0 -0
  12. {atlas_ftag_tools-0.2.8 → atlas_ftag_tools-0.2.9}/README.md +0 -0
  13. {atlas_ftag_tools-0.2.8 → atlas_ftag_tools-0.2.9}/atlas_ftag_tools.egg-info/SOURCES.txt +0 -0
  14. {atlas_ftag_tools-0.2.8 → atlas_ftag_tools-0.2.9}/atlas_ftag_tools.egg-info/dependency_links.txt +0 -0
  15. {atlas_ftag_tools-0.2.8 → atlas_ftag_tools-0.2.9}/atlas_ftag_tools.egg-info/entry_points.txt +0 -0
  16. {atlas_ftag_tools-0.2.8 → atlas_ftag_tools-0.2.9}/atlas_ftag_tools.egg-info/requires.txt +0 -0
  17. {atlas_ftag_tools-0.2.8 → atlas_ftag_tools-0.2.9}/atlas_ftag_tools.egg-info/top_level.txt +0 -0
  18. {atlas_ftag_tools-0.2.8 → atlas_ftag_tools-0.2.9}/ftag/cli_utils.py +0 -0
  19. {atlas_ftag_tools-0.2.8 → atlas_ftag_tools-0.2.9}/ftag/cuts.py +0 -0
  20. {atlas_ftag_tools-0.2.8 → atlas_ftag_tools-0.2.9}/ftag/flavours.py +0 -0
  21. {atlas_ftag_tools-0.2.8 → atlas_ftag_tools-0.2.9}/ftag/git_check.py +0 -0
  22. {atlas_ftag_tools-0.2.8 → atlas_ftag_tools-0.2.9}/ftag/hdf5/__init__.py +0 -0
  23. {atlas_ftag_tools-0.2.8 → atlas_ftag_tools-0.2.9}/ftag/hdf5/h5move.py +0 -0
  24. {atlas_ftag_tools-0.2.8 → atlas_ftag_tools-0.2.9}/ftag/hdf5/h5reader.py +0 -0
  25. {atlas_ftag_tools-0.2.8 → atlas_ftag_tools-0.2.9}/ftag/hdf5/h5split.py +0 -0
  26. {atlas_ftag_tools-0.2.8 → atlas_ftag_tools-0.2.9}/ftag/hdf5/h5utils.py +0 -0
  27. {atlas_ftag_tools-0.2.8 → atlas_ftag_tools-0.2.9}/ftag/hdf5/h5writer.py +0 -0
  28. {atlas_ftag_tools-0.2.8 → atlas_ftag_tools-0.2.9}/ftag/labeller.py +0 -0
  29. {atlas_ftag_tools-0.2.8 → atlas_ftag_tools-0.2.9}/ftag/labels.py +0 -0
  30. {atlas_ftag_tools-0.2.8 → atlas_ftag_tools-0.2.9}/ftag/region.py +0 -0
  31. {atlas_ftag_tools-0.2.8 → atlas_ftag_tools-0.2.9}/ftag/sample.py +0 -0
  32. {atlas_ftag_tools-0.2.8 → atlas_ftag_tools-0.2.9}/ftag/track_selector.py +0 -0
  33. {atlas_ftag_tools-0.2.8 → atlas_ftag_tools-0.2.9}/ftag/transform.py +0 -0
  34. {atlas_ftag_tools-0.2.8 → atlas_ftag_tools-0.2.9}/ftag/vds.py +0 -0
  35. {atlas_ftag_tools-0.2.8 → atlas_ftag_tools-0.2.9}/ftag/wps/__init__.py +0 -0
  36. {atlas_ftag_tools-0.2.8 → atlas_ftag_tools-0.2.9}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: atlas-ftag-tools
3
- Version: 0.2.8
3
+ Version: 0.2.9
4
4
  Summary: ATLAS Flavour Tagging Tools
5
5
  Author: Sam Van Stroud, Philipp Gadow
6
6
  License: MIT
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: atlas-ftag-tools
3
- Version: 0.2.8
3
+ Version: 0.2.9
4
4
  Summary: ATLAS Flavour Tagging Tools
5
5
  Author: Sam Van Stroud, Philipp Gadow
6
6
  License: MIT
@@ -2,7 +2,7 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- __version__ = "v0.2.8"
5
+ __version__ = "v0.2.9"
6
6
 
7
7
  from ftag import hdf5
8
8
  from ftag.cuts import Cuts
@@ -60,12 +60,24 @@
60
60
  colour: tab:orange
61
61
  category: single-btag-ghost
62
62
  _px: pc
63
- - name: ghostujets
64
- label: Light-jets
65
- cuts: ["HadronGhostTruthLabelID == 0"]
63
+ - name: ghostsjets
64
+ label: $s$-jets
65
+ cuts: ["HadronGhostTruthLabelID == 0", "PartonTruthLabelID == 3"]
66
+ colour: tab:red
67
+ category: single-btag-ghost
68
+ _px: ps
69
+ - name: ghostudjets
70
+ label: Light-quark-jets
71
+ cuts: ["HadronGhostTruthLabelID == 0", "PartonTruthLabelID <= 2"]
66
72
  colour: tab:green
67
73
  category: single-btag-ghost
68
- _px: pu
74
+ _px: pud
75
+ - name: ghostgjets
76
+ label: Gluon-jets
77
+ cuts: ["HadronGhostTruthLabelID == 0", "PartonTruthLabelID == 21"]
78
+ colour: tab:gray
79
+ category: single-btag-ghost
80
+ _px: pg
69
81
  - name: ghosttaujets
70
82
  label: $\tau$-jets
71
83
  cuts: ["HadronGhostTruthLabelID == 15"]
@@ -119,6 +131,21 @@
119
131
  cuts: ["R10TruthLabel_R22v1 == 10", "GhostBHadronsFinalCount == 0", "GhostCHadronsFinalCount == 0"]
120
132
  colour: "green"
121
133
  category: xbb
134
+ - name: htauel
135
+ label: $H \rightarrow \tau e$
136
+ cuts: ["R10TruthLabel_R22v1 == 14"]
137
+ colour: "#b40612"
138
+ category: xbb
139
+ - name: htaumu
140
+ label: $H \rightarrow \tau\mu$
141
+ cuts: ["R10TruthLabel_R22v1 == 15"]
142
+ colour: "#b40657"
143
+ category: xbb
144
+ - name: htauhad
145
+ label: $H \rightarrow \tau\tau$
146
+ cuts: ["R10TruthLabel_R22v1 == 16"]
147
+ colour: "#b406a0"
148
+ category: xbb
122
149
 
123
150
  # extended Xbb tagging
124
151
  - name: tqqb
@@ -54,33 +54,74 @@ TRACK_VARS = [
54
54
  ]
55
55
 
56
56
 
57
- def softmax(x, axis=None):
57
+ def softmax(x: np.ndarray, axis: int | None = None) -> np.ndarray:
58
+ """Softmax function for numpy arrays.
59
+
60
+ Parameters
61
+ ----------
62
+ x : np.ndarray
63
+ Input array for the softmax
64
+ axis : int | None, optional
65
+ Axis along which the softmax is calculated, by default None
66
+
67
+ Returns
68
+ -------
69
+ np.ndarray
70
+ Output array with the softmax output
71
+ """
58
72
  e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
59
73
  return e_x / e_x.sum(axis=axis, keepdims=True)
60
74
 
61
75
 
62
- def get_mock_scores(labels: np.ndarray, is_xbb: bool = False):
63
- means = [
64
- [2, 0, 0, 0],
65
- [0, 1, 0, 0],
66
- [0, 0, 3.5, 0],
67
- [0, 0, 0, 1],
68
- ]
76
+ def get_mock_scores(labels: np.ndarray, is_xbb: bool = False) -> np.ndarray:
69
77
  if not is_xbb:
70
78
  label_dict = {"u": 0, "c": 4, "b": 5, "tau": 15}
71
- label_mapping = dict(zip(label_dict.values(), means))
72
- else:
73
- label_dict = {"hbb": 11, "hcc": 12, "top": 1, "qcd": 10}
74
- label_mapping = dict(zip(label_dict.values(), means))
75
79
 
80
+ else:
81
+ label_dict = {
82
+ "hbb": 11,
83
+ "hcc": 12,
84
+ "top": 1,
85
+ "qcd": 10,
86
+ "htauel": 14,
87
+ "htaumu": 15,
88
+ "htauhad": 16,
89
+ }
90
+
91
+ # Set random seed
76
92
  rng = np.random.default_rng(42)
77
- nclass = len(label_dict)
78
- scores = np.zeros((len(labels), nclass))
79
- scales = [1, 2.5, 5, 1]
93
+
94
+ # Set a list of possible means/scales
95
+ mean_scale_list = [1, 2, 2.5, 3.5]
96
+
97
+ # Get the number of classes
98
+ n_classes = len(label_dict)
99
+
100
+ # Init a scores array
101
+ scores = np.zeros((len(labels), n_classes))
102
+
103
+ # Generate means/scales
104
+ means = []
105
+ scales = []
106
+ for i in range(n_classes):
107
+ tmp_means = []
108
+ tmp_means = [
109
+ 0 if j != i else mean_scale_list[np.random.randint(0, len(mean_scale_list))]
110
+ for j in range(n_classes)
111
+ ]
112
+ means.append(tmp_means)
113
+ scales.append(mean_scale_list[np.random.randint(0, len(mean_scale_list))])
114
+
115
+ # Map the labels to the means
116
+ label_mapping = dict(zip(label_dict.values(), means))
117
+
118
+ # Generate random mock scores
80
119
  for i, (label, count) in enumerate(zip(*np.unique(labels, return_counts=True))):
81
120
  scores[labels == label] = rng.normal(
82
- loc=label_mapping[label], scale=scales[i], size=(count, nclass)
121
+ loc=label_mapping[label], scale=scales[i], size=(count, n_classes)
83
122
  )
123
+
124
+ # Pipe scores through softmax
84
125
  scores = softmax(scores, axis=1)
85
126
  name = "MockXbbTagger" if is_xbb else "MockTagger"
86
127
  cols = [f"{name}_p{x}" for x in label_dict]
@@ -103,7 +144,7 @@ def mock_jets(num_jets=1000) -> np.ndarray:
103
144
  jets["HadronConeExclTruthLabelID"] = rng.choice([0, 4, 5, 15], size=num_jets)
104
145
  jets["GhostBHadronsFinalCount"] = rng.choice([0, 1, 2], size=num_jets)
105
146
  jets["GhostCHadronsFinalCount"] = rng.choice([0, 1, 2], size=num_jets)
106
- jets["R10TruthLabel_R22v1"] = rng.choice([1, 10, 11, 12], size=num_jets)
147
+ jets["R10TruthLabel_R22v1"] = rng.choice([1, 10, 11, 12, 14, 15, 16], size=num_jets)
107
148
  scores = get_mock_scores(jets["HadronConeExclTruthLabelID"])
108
149
  xbb_scores = get_mock_scores(jets["R10TruthLabel_R22v1"], is_xbb=True)
109
150
  return join_structured_arrays([jets, scores, xbb_scores])
@@ -0,0 +1,84 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING
4
+
5
+ import numpy as np
6
+
7
+ if TYPE_CHECKING: # pragma: no cover
8
+ from ftag.labels import Label, LabelContainer
9
+
10
+
11
+ def get_discriminant(
12
+ jets: np.ndarray,
13
+ tagger: str,
14
+ signal: Label,
15
+ flavours: LabelContainer,
16
+ fraction_values: dict[str, float],
17
+ epsilon: float = 1e-10,
18
+ ) -> np.ndarray:
19
+ """Calculate the tagging discriminant for a given tagger.
20
+
21
+ Calculated as the logarithm of the ratio of a specified signal probability
22
+ to a weighted sum ofbackground probabilities.
23
+
24
+ Parameters
25
+ ----------
26
+ jets : np.ndarray
27
+ Structured array of jets containing tagger outputs
28
+ tagger : str
29
+ Name of the tagger
30
+ signal : Label
31
+ Signal flavour (bjets/cjets or hbb/hcc)
32
+ fraction_values : dict
33
+ Dict with the fraction values for the background classes for the given tagger
34
+ epsilon : float, optional
35
+ Small number to avoid division by zero, by default 1e-10
36
+
37
+ Returns
38
+ -------
39
+ np.ndarray
40
+ Array of discriminant values.
41
+
42
+ Raises
43
+ ------
44
+ ValueError
45
+ If the signal flavour is not recognised.
46
+ """
47
+ # Init the denominator
48
+ denominator = 0.0
49
+
50
+ # Loop over background flavours
51
+ for flav in flavours:
52
+ # Skip signal flavour for denominator
53
+ if flav == signal:
54
+ continue
55
+
56
+ # Get the probability name of the tagger/flavour combo + fraction value
57
+ prob_name = f"{tagger}_{flav.px}"
58
+ fraction_value = fraction_values[flav.frac_str]
59
+
60
+ # If fraction_value for the given flavour is zero, skip it
61
+ if fraction_value == 0:
62
+ continue
63
+
64
+ # Check that the probability value for the flavour is available
65
+ if fraction_value > 0 and prob_name not in jets.dtype.names:
66
+ raise ValueError(
67
+ f"Nonzero fraction value for {flav.name}, but '{prob_name}' "
68
+ "not found in input array."
69
+ )
70
+
71
+ # Update denominator
72
+ denominator += jets[prob_name] * fraction_value if prob_name in jets.dtype.names else 0
73
+
74
+ # Calculate numerator
75
+ signal_field = f"{tagger}_{signal.px}"
76
+
77
+ # Check that the probability of the signal is available
78
+ if signal_field not in jets.dtype.names:
79
+ raise ValueError(
80
+ f"No signal probability value(s) found for tagger {tagger}. "
81
+ f"Missing variable: {signal_field}"
82
+ )
83
+
84
+ return np.log((jets[signal_field] + epsilon) / (denominator + epsilon))
@@ -0,0 +1,547 @@
1
+ """Calculate tagger working points."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import sys
7
+ from pathlib import Path
8
+ from typing import TYPE_CHECKING
9
+
10
+ import numpy as np
11
+ import yaml
12
+
13
+ from ftag import Flavours
14
+ from ftag.cli_utils import HelpFormatter
15
+ from ftag.cuts import Cuts
16
+ from ftag.hdf5 import H5Reader
17
+ from ftag.wps.discriminant import get_discriminant
18
+
19
+ if TYPE_CHECKING: # pragma: no cover
20
+ from collections.abc import Sequence
21
+
22
+ from ftag.labels import Label, LabelContainer
23
+
24
+
25
+ def parse_args(args: Sequence[str]) -> argparse.Namespace:
26
+ """Parse the input arguments into a Namespace.
27
+
28
+ Parameters
29
+ ----------
30
+ args : Sequence[str] | None
31
+ Sequence of string inputs to the script
32
+
33
+ Returns
34
+ -------
35
+ argparse.Namespace
36
+ Namespace with the parsed arguments
37
+
38
+ Raises
39
+ ------
40
+ ValueError
41
+ When both --effs and --disc_cuts are provided
42
+ ValueError
43
+ When neither --effs nor --disc_cuts are provided
44
+ ValueError
45
+ When the number of fraction values is not conistent
46
+ ValueError
47
+ When the sum of fraction values for a tagger is not equal to one
48
+ """
49
+ # Define the pre-parser which checks the --category
50
+ pre_parser = argparse.ArgumentParser(add_help=False)
51
+ pre_parser.add_argument(
52
+ "-c",
53
+ "--category",
54
+ default="single-btag",
55
+ type=str,
56
+ help="Label category to use for the working point calculation",
57
+ )
58
+
59
+ pre_parser.add_argument(
60
+ "-s",
61
+ "--signal",
62
+ default="bjets",
63
+ type=str,
64
+ help="Signal flavour which is to be used",
65
+ )
66
+
67
+ # Parse only --category/--signal and ignore for now all other args
68
+ pre_args, remaining_argv = pre_parser.parse_known_args(args=args)
69
+
70
+ # Create the "real" parser
71
+ parser = argparse.ArgumentParser(
72
+ description=__doc__,
73
+ formatter_class=HelpFormatter,
74
+ )
75
+
76
+ # Add --category/--signal so the help is correctly shown
77
+ parser.add_argument(
78
+ "-c",
79
+ "--category",
80
+ default="single-btag",
81
+ type=str,
82
+ help="Label category to use for the working point calculation",
83
+ )
84
+ parser.add_argument(
85
+ "-s",
86
+ "--signal",
87
+ default="bjets",
88
+ type=str,
89
+ help="Signal flavour which is to be used",
90
+ )
91
+
92
+ # Check which label category was chosen and load the corresponding flavours
93
+ flavours = Flavours.by_category(pre_args.category)
94
+
95
+ # Build the fraction value arguments for all classes (besides signal)
96
+ for flav in flavours:
97
+ # Skip signal
98
+ if flav.name == pre_args.signal:
99
+ continue
100
+
101
+ # Built fraction values for all background classes
102
+ parser.add_argument(
103
+ f"--{flav.frac_str}",
104
+ nargs="+",
105
+ required=True,
106
+ type=float,
107
+ help=f"{flav.frac_str} value(s) for each tagger",
108
+ )
109
+
110
+ # # Adding the other arguments
111
+ parser.add_argument(
112
+ "--ttbar",
113
+ required=True,
114
+ type=Path,
115
+ help="Path to ttbar sample (supports globbing)",
116
+ )
117
+ parser.add_argument(
118
+ "--zprime",
119
+ required=False,
120
+ type=Path,
121
+ help="Path to zprime (supports globbing). WPs from ttbar will be reused for zprime",
122
+ )
123
+ parser.add_argument(
124
+ "-t",
125
+ "--tagger",
126
+ nargs="+",
127
+ required=True,
128
+ type=str,
129
+ help="tagger name(s)",
130
+ )
131
+ parser.add_argument(
132
+ "-e",
133
+ "--effs",
134
+ nargs="+",
135
+ type=float,
136
+ help="Efficiency working point(s). If -r is specified, values should be 1/efficiency",
137
+ )
138
+ parser.add_argument(
139
+ "-r",
140
+ "--rejection",
141
+ default=None,
142
+ help="Use rejection of specified background class to determine working points",
143
+ )
144
+ parser.add_argument(
145
+ "-d",
146
+ "--disc_cuts",
147
+ nargs="+",
148
+ type=float,
149
+ help="D_x value(s) to calculate efficiency at",
150
+ )
151
+ parser.add_argument(
152
+ "-n",
153
+ "--num_jets",
154
+ default=1_000_000,
155
+ type=int,
156
+ help="Use this many jets (post selection)",
157
+ )
158
+ parser.add_argument(
159
+ "--ttbar_cuts",
160
+ nargs="+",
161
+ default=["pt > 20e3"],
162
+ type=list,
163
+ help="Selection to apply to ttbar (|eta| < 2.5 is always applied)",
164
+ )
165
+ parser.add_argument(
166
+ "--zprime_cuts",
167
+ nargs="+",
168
+ default=["pt > 250e3"],
169
+ type=list,
170
+ help="Selection to apply to zprime (|eta| < 2.5 is always applied)",
171
+ )
172
+ parser.add_argument(
173
+ "-o",
174
+ "--outfile",
175
+ type=Path,
176
+ help="Save results to yaml instead of printing",
177
+ )
178
+
179
+ # Final parse of all arguments
180
+ parsed_args = parser.parse_args(remaining_argv)
181
+
182
+ # Define the signal as an instance of Flavours
183
+ parsed_args.signal = Flavours[parsed_args.signal]
184
+
185
+ # Check that only --effs or --disc_cuts is given
186
+ if parsed_args.effs and parsed_args.disc_cuts:
187
+ raise ValueError("Cannot specify both --effs and --disc_cuts")
188
+ if not parsed_args.effs and not parsed_args.disc_cuts:
189
+ raise ValueError("Must specify either --effs or --disc_cuts")
190
+
191
+ # Check that all fraction values have the same length
192
+ for flav in flavours:
193
+ if flav.name != parsed_args.signal.name and len(getattr(parsed_args, flav.frac_str)) != len(
194
+ parsed_args.tagger
195
+ ):
196
+ raise ValueError(f"Number of {flav.frac_str} values must match number of taggers")
197
+
198
+ # Check that all fraction value combinations add up to one
199
+ for tagger_idx in range(len(parsed_args.tagger)):
200
+ fraction_value_sum = 0
201
+ for flav in flavours:
202
+ if flav.name != parsed_args.signal.name:
203
+ fraction_value_sum += getattr(parsed_args, flav.frac_str)[tagger_idx]
204
+
205
+ # Round the value to take machine precision into account
206
+ fraction_value_sum = np.round(fraction_value_sum, 8)
207
+
208
+ # Check it's equal to one
209
+ if fraction_value_sum != 1:
210
+ raise ValueError(
211
+ "Sum of the fraction values must be one! You gave "
212
+ f"{fraction_value_sum} for tagger {parsed_args.tagger[tagger_idx]}"
213
+ )
214
+ return parsed_args
215
+
216
+
217
+ def get_fxs_from_args(args: argparse.Namespace, flavours: LabelContainer) -> list:
218
+ """Get the fraction values for each tagger from the argparsed inputs.
219
+
220
+ Parameters
221
+ ----------
222
+ args : argparse.Namespace
223
+ Input arguments parsed by the argparser
224
+ flavours : LabelContainer
225
+ LabelContainer instance of the labels that are used
226
+
227
+ Returns
228
+ -------
229
+ list
230
+ List of dicts with the fraction values. Each dict is for one tagger.
231
+ """
232
+ # Init the fraction_dict dict
233
+ fraction_dict = {}
234
+
235
+ # Add the fraction values to the dict
236
+ for flav in flavours:
237
+ if flav.name != args.signal.name:
238
+ fraction_dict[flav.frac_str] = vars(args)[flav.frac_str]
239
+
240
+ return [{k: v[i] for k, v in fraction_dict.items()} for i in range(len(args.tagger))]
241
+
242
+
243
+ def get_eff_rej(
244
+ jets: np.ndarray,
245
+ disc: np.ndarray,
246
+ wp: float,
247
+ flavours: LabelContainer,
248
+ ) -> dict:
249
+ """Calculate the efficiency/rejection for each flavour.
250
+
251
+ Parameters
252
+ ----------
253
+ jets : np.ndarray
254
+ Loaded jets
255
+ disc : np.ndarray
256
+ Discriminant values of the jets
257
+ wp : float
258
+ Working point that is used
259
+ flavours : LabelContainer
260
+ LabelContainer instance of the flavours used
261
+
262
+ Returns
263
+ -------
264
+ dict
265
+ Dict with the efficiency/rejection values for each flavour
266
+ """
267
+ # Init an out dict
268
+ out: dict[str, dict] = {"eff": {}, "rej": {}}
269
+
270
+ # Loop over the flavours
271
+ for flav in flavours:
272
+ # Calculate discriminant values and efficiencies/rejections
273
+ flav_disc = disc[flav.cuts(jets).idx]
274
+ eff = sum(flav_disc > wp) / len(flav_disc)
275
+ out["eff"][flav.name] = float(f"{eff:.3g}")
276
+ out["rej"][flav.name] = float(f"{1 / eff:.3g}")
277
+
278
+ return out
279
+
280
+
281
+ def get_rej_eff_at_disc(
282
+ jets: np.ndarray,
283
+ tagger: str,
284
+ signal: Label,
285
+ disc_cuts: list,
286
+ flavours: LabelContainer,
287
+ fraction_values: dict,
288
+ ) -> dict:
289
+ """Calculate the efficiency/rejection at a certain discriminant values.
290
+
291
+ Parameters
292
+ ----------
293
+ jets : np.ndarray
294
+ Loaded jets used
295
+ tagger : str
296
+ Name of the tagger
297
+ signal : Label
298
+ Label instance of the signal flavour
299
+ disc_cuts : list
300
+ List of discriminant cut values for which the efficiency/rejection is calculated
301
+ flavours : LabelContainer
302
+ LabelContainer instance of the flavours that are used
303
+
304
+ Returns
305
+ -------
306
+ dict
307
+ Dict with the discriminant cut values and their respective efficiencies/rejections
308
+ """
309
+ # Calculate discriminants
310
+ disc = get_discriminant(
311
+ jets=jets,
312
+ tagger=tagger,
313
+ signal=signal,
314
+ flavours=flavours,
315
+ fraction_values=fraction_values,
316
+ )
317
+
318
+ # Init out dict
319
+ ref_eff_dict: dict[str, dict] = {}
320
+
321
+ # Loop over the disc cut values
322
+ for dcut in disc_cuts:
323
+ ref_eff_dict[str(dcut)] = {"eff": {}, "rej": {}}
324
+
325
+ # Loop over the flavours
326
+ for flav in flavours:
327
+ e_discs = disc[flav.cuts(jets).idx]
328
+ eff = sum(e_discs > dcut) / len(e_discs)
329
+ ref_eff_dict[str(dcut)]["eff"][str(flav)] = float(f"{eff:.3g}")
330
+ ref_eff_dict[str(dcut)]["rej"][str(flav)] = 1 / float(f"{eff:.3g}")
331
+
332
+ return ref_eff_dict
333
+
334
+
335
+ def setup_common_parts(
336
+ args: argparse.Namespace,
337
+ ) -> tuple[np.ndarray, np.ndarray | None, LabelContainer]:
338
+ """Load the jets from the files and setup the taggers.
339
+
340
+ Parameters
341
+ ----------
342
+ args : argparse.Namespace
343
+ Input arguments from the argparser
344
+
345
+ Returns
346
+ -------
347
+ tuple[dict, dict | None, list]
348
+ Outputs the ttbar jets, the zprime jets (if wanted, else None), and the flavours used.
349
+ """
350
+ # Get the used flavours
351
+ flavours = Flavours.by_category(args.category)
352
+
353
+ # Get the cuts for the samples
354
+ default_cuts = Cuts.from_list(["eta > -2.5", "eta < 2.5"])
355
+ ttbar_cuts = Cuts.from_list(args.ttbar_cuts) + default_cuts
356
+ zprime_cuts = Cuts.from_list(args.zprime_cuts) + default_cuts
357
+
358
+ # Prepare the loading of the jets
359
+ all_vars = list(set(sum((flav.cuts.variables for flav in flavours), [])))
360
+ reader = H5Reader(args.ttbar)
361
+ jet_vars = reader.dtypes()["jets"].names
362
+
363
+ # Create for all taggers the fraction values
364
+ for tagger in args.tagger:
365
+ all_vars += [
366
+ f"{tagger}_{flav.px}" for flav in flavours if (f"{tagger}_{flav.px}" in jet_vars)
367
+ ]
368
+
369
+ # Load ttbar jets
370
+ ttbar_jets = reader.load({"jets": all_vars}, args.num_jets, cuts=ttbar_cuts)["jets"]
371
+ zprime_jets = None
372
+
373
+ # Load zprime jets if needed
374
+ if args.zprime:
375
+ zprime_reader = H5Reader(args.zprime)
376
+ zprime_jets = zprime_reader.load({"jets": all_vars}, args.num_jets, cuts=zprime_cuts)[
377
+ "jets"
378
+ ]
379
+
380
+ else:
381
+ zprime_jets = None
382
+
383
+ return ttbar_jets, zprime_jets, flavours
384
+
385
+
386
+ def get_working_points(args: argparse.Namespace) -> dict | None:
387
+ """Calculate the working points.
388
+
389
+ Parameters
390
+ ----------
391
+ args : argparse.Namespace
392
+ Input arguments from the argparser
393
+
394
+ Returns
395
+ -------
396
+ dict | None
397
+ Dict with the working points. If args.outfile is given, the function returns None and
398
+ stored the resulting dict in a yaml file in args.outfile.
399
+ """
400
+ # Load the jets and flavours and get the fraction values
401
+ ttbar_jets, zprime_jets, flavours = setup_common_parts(args=args)
402
+ fraction_values = get_fxs_from_args(args=args, flavours=flavours)
403
+
404
+ # Init an out dict
405
+ out = {}
406
+
407
+ # Loop over taggers
408
+ for i, tagger in enumerate(args.tagger):
409
+ # Calculate discriminant
410
+ out[tagger] = {"signal": str(args.signal), **fraction_values[i]}
411
+ disc = get_discriminant(
412
+ jets=ttbar_jets,
413
+ tagger=tagger,
414
+ signal=args.signal,
415
+ flavours=flavours,
416
+ fraction_values=fraction_values[i],
417
+ )
418
+
419
+ # Loop over efficiency working points
420
+ for eff in args.effs:
421
+ d = out[tagger][f"{eff:.0f}"] = {}
422
+
423
+ # Set the working point
424
+ wp_flavour = args.signal
425
+ if args.rejection:
426
+ eff = 100 / eff # noqa: PLW2901
427
+ wp_flavour = args.rejection
428
+
429
+ # Calculate the discriminant value of the working point
430
+ wp_disc = disc[flavours[wp_flavour].cuts(ttbar_jets).idx]
431
+ wp = d["cut_value"] = round(float(np.percentile(wp_disc, 100 - eff)), 3)
432
+
433
+ # Calculate efficiency and rejection for each flavour
434
+ d["ttbar"] = get_eff_rej(
435
+ jets=ttbar_jets,
436
+ disc=disc,
437
+ wp=wp,
438
+ flavours=flavours,
439
+ )
440
+
441
+ # calculate for zprime
442
+ if args.zprime:
443
+ zprime_disc = get_discriminant(
444
+ jets=zprime_jets,
445
+ tagger=tagger,
446
+ signal=args.signal,
447
+ flavours=flavours,
448
+ fraction_values=fraction_values[i],
449
+ )
450
+ d["zprime"] = get_eff_rej(
451
+ jets=zprime_jets,
452
+ disc=zprime_disc,
453
+ wp=wp,
454
+ flavours=flavours,
455
+ )
456
+
457
+ if args.outfile:
458
+ with open(args.outfile, "w") as f:
459
+ yaml.dump(out, f, sort_keys=False)
460
+ return None
461
+
462
+ else:
463
+ return out
464
+
465
+
466
+ def get_efficiencies(args: argparse.Namespace) -> dict | None:
467
+ """Calculate the efficiencies for the given jets.
468
+
469
+ Parameters
470
+ ----------
471
+ args : argparse.Namespace
472
+ Input arguments from the argparser
473
+
474
+ Returns
475
+ -------
476
+ dict | None
477
+ Dict with the efficiencies. If args.outfile is given, the function returns None and
478
+ stored the resulting dict in a yaml file in args.outfile.
479
+ """
480
+ # Load the jets and flavours and get the fraction values
481
+ ttbar_jets, zprime_jets, flavours = setup_common_parts(args=args)
482
+ fraction_values = get_fxs_from_args(args=args, flavours=flavours)
483
+
484
+ # Init an out dict
485
+ out = {}
486
+
487
+ # Loop over the taggers
488
+ for i, tagger in enumerate(args.tagger):
489
+ out[tagger] = {"signal": str(args.signal), **fraction_values[i]}
490
+
491
+ out[tagger]["ttbar"] = get_rej_eff_at_disc(
492
+ jets=ttbar_jets,
493
+ tagger=tagger,
494
+ signal=args.signal,
495
+ disc_cuts=args.disc_cuts,
496
+ flavours=flavours,
497
+ fraction_values=fraction_values[i],
498
+ )
499
+ if args.zprime:
500
+ out[tagger]["zprime"] = get_rej_eff_at_disc(
501
+ jets=zprime_jets,
502
+ tagger=tagger,
503
+ signal=args.signal,
504
+ disc_cuts=args.disc_cuts,
505
+ flavours=flavours,
506
+ fraction_values=fraction_values[i],
507
+ )
508
+
509
+ if args.outfile:
510
+ with open(args.outfile, "w") as f:
511
+ yaml.dump(out, f, sort_keys=False)
512
+ return None
513
+ else:
514
+ return out
515
+
516
+
517
+ def main(args: Sequence[str]) -> dict | None:
518
+ """Main function to run working point calculation.
519
+
520
+ Parameters
521
+ ----------
522
+ args : Sequence[str] | None, optional
523
+ Input arguments, by default None
524
+
525
+ Returns
526
+ -------
527
+ dict | None
528
+ The output dict with the calculated values. When --outfile
529
+ was given, the return value is None
530
+ """
531
+ parsed_args = parse_args(args=args)
532
+
533
+ if parsed_args.effs:
534
+ out = get_working_points(args=parsed_args)
535
+
536
+ elif parsed_args.disc_cuts:
537
+ out = get_efficiencies(args=parsed_args)
538
+
539
+ if out:
540
+ print(yaml.dump(out, sort_keys=False))
541
+ return out
542
+
543
+ return None
544
+
545
+
546
+ if __name__ == "__main__": # pragma: no cover
547
+ main(args=sys.argv[1:])
@@ -57,7 +57,8 @@ lint.ignore = [
57
57
  "ANN001", "ANN002", "ANN003", "ANN101", "ANN201", "ANN202", "ANN204",
58
58
  "T201", "PLW1514", "PTH123", "RUF017", "PLR6301", "ISC001", "S307",
59
59
  "PT027", "NPY002", "PT009", "PLW1641", "PLR0904", "N817", "S603", "PD011",
60
- "S113", "TCH", "PT011", "PLR1702", "S108", "PTH207", "S607", "E203", "SIM115"
60
+ "S113", "TCH", "PT011", "PLR1702", "S108", "PTH207", "S607", "E203", "SIM115", "PLR0913",
61
+ "PLR0917"
61
62
  ]
62
63
 
63
64
  [tool.ruff.lint.flake8-pytest-style]
@@ -1,131 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from typing import Callable
4
-
5
- import numpy as np
6
-
7
- from ftag import Flavours
8
- from ftag.labels import Label, remove_suffix
9
-
10
-
11
- def discriminant(
12
- jets: np.ndarray,
13
- tagger: str,
14
- signal: Label,
15
- fxs: dict[str, float],
16
- epsilon: float = 1e-10,
17
- ) -> np.ndarray:
18
- """
19
- Get the tagging discriminant.
20
-
21
- Calculated as the logarithm of the ratio of a specified signal probability
22
- to a weighted sum ofbackground probabilities.
23
-
24
- Parameters
25
- ----------
26
- jets : np.ndarray
27
- Structed jet array containing tagger scores.
28
- tagger : str
29
- Name of the tagger, used to construct field names.
30
- signal : str
31
- Type of signal.
32
- fxs : dict[str, float]
33
- Dict of background probability names and their fractions.
34
- If a fraction is None, it is calculated as (1 - sum of provided fractions).
35
- epsilon : float, optional
36
- A small value added to probabilities to prevent division by zero, by default 1e-10.
37
-
38
- Returns
39
- -------
40
- np.ndarray
41
- The tagger discriminant values for the jets.
42
-
43
- Raises
44
- ------
45
- ValueError
46
- If a fraction is specified for a denominator that is not present in the input array.
47
- """
48
- denominator = 0.0
49
- for d, fx in fxs.items():
50
- name = f"{tagger}_{d}"
51
- if fx > 0 and name not in jets.dtype.names:
52
- raise ValueError(f"Nonzero fx for {d}, but '{name}' not found in input array.")
53
- denominator += jets[name] * fx if name in jets.dtype.names else 0
54
- signal_field = f"{tagger}_{signal.px}"
55
- if signal_field not in jets.dtype.names:
56
- signal_field = f"{tagger}_p{remove_suffix(signal.name, 'jets')}"
57
- return np.log((jets[signal_field] + epsilon) / (denominator + epsilon))
58
-
59
-
60
- def tautag_dicriminant(jets, tagger, fb, fc, epsilon=1e-10):
61
- fxs = {"pb": fb, "pc": fc, "pu": 1 - fb - fc}
62
- return discriminant(jets, tagger, Flavours.taujets, fxs, epsilon=epsilon)
63
-
64
-
65
- def btag_discriminant(jets, tagger, fc, ftau=0, epsilon=1e-10):
66
- fxs = {"pc": fc, "ptau": ftau, "pu": 1 - fc - ftau}
67
- return discriminant(jets, tagger, Flavours.bjets, fxs, epsilon=epsilon)
68
-
69
-
70
- def ghostbtag_discriminant(jets, tagger, fc, ftau=0, epsilon=1e-10):
71
- fxs = {"pghostc": fc, "pghosttau": ftau, "pghostu": 1 - fc - ftau}
72
- return discriminant(jets, tagger, Flavours.ghostbjets, fxs, epsilon=epsilon)
73
-
74
-
75
- def ctag_discriminant(jets, tagger, fb, ftau=0, epsilon=1e-10):
76
- fxs = {"pb": fb, "ptau": ftau, "pu": 1 - fb - ftau}
77
- return discriminant(jets, tagger, Flavours.cjets, fxs, epsilon=epsilon)
78
-
79
-
80
- def hbb_discriminant(jets, tagger, ftop=0.25, fhcc=0.02, epsilon=1e-10):
81
- fxs = {"phcc": fhcc, "ptop": ftop, "pqcd": 1 - ftop - fhcc}
82
- return discriminant(jets, tagger, Flavours.hbb, fxs, epsilon=epsilon)
83
-
84
-
85
- def hcc_discriminant(jets, tagger, ftop=0.25, fhbb=0.3, epsilon=1e-10):
86
- fxs = {"phbb": fhbb, "ptop": ftop, "pqcd": 1 - ftop - fhbb}
87
- return discriminant(jets, tagger, Flavours.hcc, fxs, epsilon=epsilon)
88
-
89
-
90
- def get_discriminant(
91
- jets: np.ndarray, tagger: str, signal: Label | str, epsilon: float = 1e-10, **fxs
92
- ):
93
- """Calculate the b-tag or c-tag discriminant for a given tagger.
94
-
95
- Parameters
96
- ----------
97
- jets : np.ndarray
98
- Structured array of jets containing tagger outputs
99
- tagger : str
100
- Name of the tagger
101
- signal : Label
102
- Signal flavour (bjets/cjets or hbb/hcc)
103
- epsilon : float, optional
104
- Small number to avoid division by zero, by default 1e-10
105
- **fxs : dict
106
- Fractions for the different background flavours.
107
-
108
- Returns
109
- -------
110
- np.ndarray
111
- Array of discriminant values.
112
-
113
- Raises
114
- ------
115
- ValueError
116
- If the signal flavour is not recognised.
117
- """
118
- tagger_funcs: dict[str, Callable] = {
119
- "bjets": btag_discriminant,
120
- "cjets": ctag_discriminant,
121
- "taujets": tautag_dicriminant,
122
- "hbb": hbb_discriminant,
123
- "hcc": hcc_discriminant,
124
- "ghostbjets": ghostbtag_discriminant,
125
- }
126
-
127
- if str(signal) not in tagger_funcs:
128
- raise ValueError(f"Signal flavour must be one of {list(tagger_funcs.keys())}, not {signal}")
129
-
130
- func: Callable = tagger_funcs[str(Flavours[signal])]
131
- return func(jets, tagger, **fxs, epsilon=epsilon)
@@ -1,316 +0,0 @@
1
- """Calculate tagger working points."""
2
-
3
- from __future__ import annotations
4
-
5
- import argparse
6
- from pathlib import Path
7
-
8
- import numpy as np
9
- import yaml
10
-
11
- from ftag import Flavours
12
- from ftag.cli_utils import HelpFormatter
13
- from ftag.cuts import Cuts
14
- from ftag.hdf5 import H5Reader
15
- from ftag.wps.discriminant import get_discriminant
16
-
17
-
18
def parse_args(args):
    """Parse and validate the command line options of the working point script.

    Parameters
    ----------
    args : list[str] | None
        Argument strings to parse. ``None`` lets argparse read ``sys.argv``.

    Returns
    -------
    argparse.Namespace
        The parsed options. ``signal`` is replaced by the matching entry of
        ``Flavours`` and, outside of Xbb mode, a missing ``ftau`` is filled
        with one ``0.0`` per tagger.

    Raises
    ------
    ValueError
        If ``--effs`` and ``--disc_cuts`` are both given or both missing, if
        the fraction options do not fit the chosen signal / Xbb mode, or if a
        fraction option does not have one value per tagger.
    """
    fx_names = ("fb", "fc", "ftau", "ftop", "fhbb", "fhcc")

    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=HelpFormatter,
    )
    parser.add_argument(
        "--ttbar",
        required=True,
        type=Path,
        help="path to ttbar sample (supports globbing)",
    )
    parser.add_argument(
        "--zprime",
        required=False,
        type=Path,
        help="path to zprime (supports globbing). WPs from ttbar will be reused for zprime",
    )
    parser.add_argument(
        "-e",
        "--effs",
        nargs="+",
        type=float,
        help="efficiency working point(s). If -r is specified, values should be 1/efficiency",
    )
    parser.add_argument(
        "-t",
        "--tagger",
        nargs="+",
        required=True,
        type=str,
        help="tagger name(s)",
    )
    parser.add_argument(
        "-s",
        "--signal",
        default="bjets",
        choices=["bjets", "cjets", "hbb", "hcc"],
        type=str,
        help='signal flavour ("bjets" or "cjets" for b-tagging, "hbb" or "hcc" for Xbb)',
    )
    parser.add_argument(
        "-r",
        "--rejection",
        default=None,
        choices=["ujets", "cjets", "bjets", "hbb", "hcc", "top", "qcd"],
        help="use rejection of specified background class to determine working points",
    )
    parser.add_argument(
        "-d",
        "--disc_cuts",
        nargs="+",
        type=float,
        help="D_x value(s) to calculate efficiency at",
    )
    parser.add_argument(
        "-n",
        "--num_jets",
        default=1_000_000,
        type=int,
        help="use this many jets (post selection)",
    )
    # Each command line token is one cut expression. The converter must be `str`:
    # with `type=list` argparse would split every expression into single characters.
    parser.add_argument(
        "--ttbar_cuts",
        nargs="+",
        default=["pt > 20e3"],
        type=str,
        help="selection to apply to ttbar (|eta| < 2.5 is always applied)",
    )
    parser.add_argument(
        "--zprime_cuts",
        nargs="+",
        default=["pt > 250e3"],
        type=str,
        help="selection to apply to zprime (|eta| < 2.5 is always applied)",
    )
    parser.add_argument(
        "-o",
        "--outfile",
        type=Path,
        help="save results to yaml instead of printing",
    )
    parser.add_argument(
        "--xbb",
        action="store_true",
        help="Enable Xbb tagging which expects two fx values ftop and fhcc/fhbb for each tagger",
    )
    # All fraction options share the same shape: one float per tagger.
    for fx in fx_names:
        parser.add_argument(
            f"--{fx}",
            nargs="+",
            type=float,
            help=f"{fx} value(s) for each tagger",
        )
    args = parser.parse_args(args)

    # Keep the raw choice for the checks below, so they do not depend on how the
    # flavour object compares against plain strings.
    signal_name = args.signal
    args.signal = Flavours[args.signal]

    if args.effs and args.disc_cuts:
        raise ValueError("Cannot specify both --effs and --disc_cuts")
    if not args.effs and not args.disc_cuts:
        raise ValueError("Must specify either --effs or --disc_cuts")

    if args.xbb:
        if signal_name not in {"hbb", "hcc"}:
            raise ValueError("Xbb tagging only supports hbb or hcc signal flavours")
        if args.fb or args.fc or args.ftau:
            raise ValueError("For Xbb tagging, fb, fc and ftau should not be specified")
        if not args.ftop:
            raise ValueError("For Xbb tagging, ftop should be specified")
        if signal_name == "hbb" and not args.fhcc:
            raise ValueError("For hbb tagging, fhcc should be specified")
        if signal_name == "hcc" and not args.fhbb:
            raise ValueError("For hcc tagging, fhbb should be specified")
    else:
        if args.ftop or args.fhbb or args.fhcc:
            raise ValueError("For single-b tagging, ftop, fhbb and fhcc should not be specified")
        if signal_name == "bjets" and not args.fc:
            raise ValueError("For bjets tagging, fc should be specified")
        if signal_name == "cjets" and not args.fb:
            raise ValueError("For cjets tagging, fb should be specified")
        if args.ftau is None:
            args.ftau = [0.0] * len(args.tagger)

    for fx in fx_names:
        values = getattr(args, fx)
        if values and len(values) != len(args.tagger):
            raise ValueError(f"Number of {fx} values must match number of taggers")

    return args
175
-
176
-
177
def get_fxs_from_args(args):
    """Build one dict of fraction keyword arguments per tagger.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed options. Read attributes: ``signal``, ``tagger`` and the two
        fraction lists that belong to the signal (``fc``/``ftau`` for bjets,
        ``fb``/``ftau`` for cjets, ``ftop``/``fhcc`` for hbb, ``ftop``/``fhbb``
        for hcc).

    Returns
    -------
    list[dict]
        Entry ``i`` maps each relevant fraction name to its ``i``-th value, so
        it can be unpacked as keyword arguments for tagger ``i``.

    Raises
    ------
    ValueError
        If the signal flavour is not supported or a required fraction list is
        missing.
    """
    fx_names_by_signal = {
        "bjets": ("fc", "ftau"),
        "cjets": ("fb", "ftau"),
        "hbb": ("ftop", "fhcc"),
        "hcc": ("ftop", "fhbb"),
    }

    # Dispatch on the flavour name so labels and plain strings are treated alike.
    signal = str(args.signal)
    if signal not in fx_names_by_signal:
        raise ValueError(f"Signal flavour must be one of {list(fx_names_by_signal)}, not {signal}")

    fxs = {}
    for name in fx_names_by_signal[signal]:
        values = getattr(args, name, None)
        if values is None:
            raise ValueError(f"{name} must be specified for {signal} tagging")
        fxs[name] = values

    return [{name: values[i] for name, values in fxs.items()} for i in range(len(args.tagger))]
188
-
189
-
190
def get_eff_rej(jets, disc, wp, flavs):
    """Compute efficiency and rejection of every flavour at one discriminant cut.

    Parameters
    ----------
    jets
        Jets the discriminant values belong to; passed to each flavour's
        ``cuts`` to select the jets of that flavour.
    disc : np.ndarray
        Discriminant value per jet, indexable by ``flav.cuts(jets).idx``.
    wp : float
        Cut value; a jet passes when its discriminant is strictly greater.
    flavs
        Iterable of flavours. Each must provide ``cuts(jets).idx`` and a
        string form, which is used as the key in the result.

    Returns
    -------
    dict
        ``{"eff": {flavour: value}, "rej": {flavour: value}}`` with values
        rounded to three significant digits. The rejection is ``inf`` when no
        jet of the flavour passes; both values are ``nan`` when the flavour
        has no jets at all.
    """
    out = {"eff": {}, "rej": {}}
    for flav in flavs:
        flav_disc = disc[flav.cuts(jets).idx]
        n_jets = len(flav_disc)
        if n_jets == 0:
            # Nothing to measure for this flavour; avoid dividing by zero.
            eff = rej = float("nan")
        else:
            eff = float(np.count_nonzero(flav_disc > wp)) / n_jets
            rej = 1 / eff if eff > 0 else float("inf")
        name = str(flav)
        # Round trip through a format string to keep three significant digits.
        out["eff"][name] = float(f"{eff:.3g}")
        out["rej"][name] = float(f"{rej:.3g}")
    return out
198
-
199
-
200
def get_rej_eff_at_disc(jets, tagger, signal, disc_cuts, **fxs):
    """Compute per-flavour efficiency and rejection at fixed discriminant cuts.

    Parameters
    ----------
    jets
        Jets to evaluate; passed to ``get_discriminant`` and to each
        flavour's ``cuts``.
    tagger : str
        Name of the tagger.
    signal
        Signal flavour used for the discriminant.
    disc_cuts
        Iterable of discriminant cut values; a jet passes when its
        discriminant is strictly greater than the cut.
    **fxs
        Fraction keyword arguments forwarded to ``get_discriminant``.

    Returns
    -------
    dict
        ``{str(cut): {"eff": {flavour: value}, "rej": {flavour: value}}}``.
        The efficiency is rounded to three significant digits and the
        rejection is the reciprocal of that rounded value. The rejection is
        ``inf`` when the rounded efficiency is zero; both are ``nan`` when a
        flavour has no jets.
    """
    disc = get_discriminant(jets, tagger, signal, **fxs)

    # NOTE(review): hbb/hcc signals are assumed to belong to the "xbb" flavour
    # category, mirroring the category names used in setup_common_parts.
    category = "xbb" if str(signal) in {"hbb", "hcc"} else "single-btag"

    # The flavour selection does not depend on the cut, so do it only once.
    flav_discs = {str(f): disc[f.cuts(jets).idx] for f in Flavours.by_category(category)}

    d = {}
    for dcut in disc_cuts:
        effs = {}
        rejs = {}
        for name, flav_disc in flav_discs.items():
            n_jets = len(flav_disc)
            if n_jets == 0:
                eff = float("nan")
            else:
                eff = float(f"{np.count_nonzero(flav_disc > dcut) / n_jets:.3g}")
            effs[name] = eff
            # 1 / nan stays nan; a vanishing efficiency means infinite rejection.
            rejs[name] = 1 / eff if eff != 0 else float("inf")
        d[str(dcut)] = {"eff": effs, "rej": rejs}
    return d
212
-
213
-
214
def setup_common_parts(args):
    """Load the jets needed by both the working point and efficiency modes.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed options. Read attributes: ``xbb``, ``ttbar``, ``zprime``,
        ``ttbar_cuts``, ``zprime_cuts``, ``tagger`` and ``num_jets``.

    Returns
    -------
    tuple
        ``(jets, zp_jets, flavs)``: the jets loaded from the ttbar sample, the
        jets loaded from the zprime sample (``None`` when no zprime path was
        given) and the flavours of the selected category.
    """
    category = "xbb" if args.xbb else "single-btag"
    flavs = Flavours.by_category(category)

    # The eta window is appended to whatever selection the user asked for.
    eta_cuts = Cuts.from_list(["eta > -2.5", "eta < 2.5"])
    ttbar_cuts = Cuts.from_list(args.ttbar_cuts) + eta_cuts
    zprime_cuts = Cuts.from_list(args.zprime_cuts) + eta_cuts

    # Variables needed to evaluate the flavour selections, without duplicates.
    all_vars = list({var for flav in flavs for var in flav.cuts.variables})

    # Add the tagger outputs, but only those actually stored in the ttbar file.
    reader = H5Reader(args.ttbar)
    available = reader.dtypes()["jets"].names
    for tagger in args.tagger:
        for flav in flavs:
            output_name = f"{tagger}_{flav.px}"
            if output_name in available:
                all_vars.append(output_name)

    jets = reader.load({"jets": all_vars}, args.num_jets, cuts=ttbar_cuts)["jets"]

    zp_jets = None
    if args.zprime:
        zp_reader = H5Reader(args.zprime)
        zp_jets = zp_reader.load({"jets": all_vars}, args.num_jets, cuts=zprime_cuts)["jets"]

    return jets, zp_jets, flavs
235
-
236
-
237
def get_working_points(args=None):
    """Derive discriminant cut values for the requested working points.

    For every tagger and every value in ``args.effs`` the cut is taken as a
    percentile of the discriminant of a reference flavour in the ttbar
    sample, rounded to three decimals. The reference flavour is the signal,
    or the flavour named by ``args.rejection``; in the latter case the values
    in ``args.effs`` are rejections and are converted with ``100 / value``.
    Efficiency and rejection of all flavours are then evaluated at that cut
    for ttbar and, if a zprime sample was given, for zprime.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed options as returned by ``parse_args``.

    Returns
    -------
    dict | None
        Nested results per tagger and working point, or ``None`` when the
        results were written to ``args.outfile`` instead.
    """
    jets, zp_jets, flavs = setup_common_parts(args)
    fxs = get_fxs_from_args(args)

    # The flavour whose discriminant distribution defines the cut is the same
    # for every tagger and working point.
    wp_flavour = args.rejection if args.rejection else args.signal

    out = {}
    for tagger, tagger_fxs in zip(args.tagger, fxs):
        out[tagger] = {"signal": str(args.signal), **tagger_fxs}
        disc = get_discriminant(jets, tagger, args.signal, **tagger_fxs)
        wp_disc = disc[flavs[wp_flavour].cuts(jets).idx]

        # The zprime discriminant does not depend on the working point, so it
        # is computed once per tagger instead of once per working point.
        zp_disc = None
        if args.zprime:
            zp_disc = get_discriminant(zp_jets, tagger, args.signal, **tagger_fxs)

        for eff in args.effs:
            # NOTE: values that format to the same integer share one key, so a
            # later working point overwrites an earlier one.
            d = out[tagger][f"{eff:.0f}"] = {}

            # In rejection mode the given value is a rejection, i.e. 1/efficiency.
            target_eff = 100 / eff if args.rejection else eff
            wp = d["cut_value"] = round(float(np.percentile(wp_disc, 100 - target_eff)), 3)

            # calculate eff and rej for each flavour
            d["ttbar"] = get_eff_rej(jets, disc, wp, flavs)

            # the ttbar cut value is reused for zprime
            if args.zprime:
                d["zprime"] = get_eff_rej(zp_jets, zp_disc, wp, flavs)

    if args.outfile:
        with open(args.outfile, "w") as f:
            yaml.dump(out, f, sort_keys=False)
        return None

    return out
274
-
275
-
276
def get_efficiencies(args=None):
    """Evaluate efficiencies and rejections at the requested discriminant cuts.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed options as returned by ``parse_args``; ``disc_cuts`` holds the
        cut values to evaluate.

    Returns
    -------
    dict | None
        Results per tagger for ttbar (and zprime if a zprime sample was
        given), or ``None`` when they were written to ``args.outfile``.
    """
    jets, zp_jets, _ = setup_common_parts(args)
    per_tagger_fxs = get_fxs_from_args(args)

    results = {}
    for tagger, tagger_fxs in zip(args.tagger, per_tagger_fxs):
        entry = results[tagger] = {"signal": str(args.signal), **tagger_fxs}
        entry["ttbar"] = get_rej_eff_at_disc(
            jets, tagger, args.signal, args.disc_cuts, **tagger_fxs
        )
        if args.zprime:
            entry["zprime"] = get_rej_eff_at_disc(
                zp_jets, tagger, args.signal, args.disc_cuts, **tagger_fxs
            )

    if not args.outfile:
        return results

    with open(args.outfile, "w") as f:
        yaml.dump(results, f, sort_keys=False)
    return None
298
-
299
-
300
def main(args=None):
    """Run the working point or efficiency calculation from the command line.

    Parameters
    ----------
    args : list[str] | None
        Argument strings handed to ``parse_args``.

    Returns
    -------
    dict | None
        The results, which are also printed as yaml, or ``None`` when the
        calculation returned nothing (results written to a file).
    """
    parsed = parse_args(args)

    # parse_args guarantees that exactly one of the two modes is selected.
    if parsed.effs:
        results = get_working_points(parsed)
    elif parsed.disc_cuts:
        results = get_efficiencies(parsed)

    if not results:
        return None

    print(yaml.dump(results, sort_keys=False))
    return results
313
-
314
-
315
# Run the command line interface only when this file is executed as a script.
if __name__ == "__main__":
    main()