atlas-ftag-tools 0.2.8__py3-none-any.whl → 0.2.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {atlas_ftag_tools-0.2.8.dist-info → atlas_ftag_tools-0.2.10.dist-info}/METADATA +4 -3
- {atlas_ftag_tools-0.2.8.dist-info → atlas_ftag_tools-0.2.10.dist-info}/RECORD +14 -12
- {atlas_ftag_tools-0.2.8.dist-info → atlas_ftag_tools-0.2.10.dist-info}/WHEEL +1 -1
- {atlas_ftag_tools-0.2.8.dist-info → atlas_ftag_tools-0.2.10.dist-info}/entry_points.txt +1 -1
- ftag/__init__.py +6 -5
- ftag/flavours.yaml +47 -4
- ftag/fraction_optimization.py +184 -0
- ftag/labels.py +10 -2
- ftag/mock.py +58 -17
- ftag/utils/__init__.py +24 -0
- ftag/utils/logging.py +123 -0
- ftag/utils/metrics.py +431 -0
- ftag/working_points.py +547 -0
- ftag/wps/__init__.py +0 -0
- ftag/wps/discriminant.py +0 -131
- ftag/wps/working_points.py +0 -316
- {atlas_ftag_tools-0.2.8.dist-info → atlas_ftag_tools-0.2.10.dist-info}/top_level.txt +0 -0
ftag/working_points.py
ADDED
@@ -0,0 +1,547 @@
|
|
1
|
+
"""Calculate tagger working points."""
|
2
|
+
|
3
|
+
from __future__ import annotations
|
4
|
+
|
5
|
+
import argparse
|
6
|
+
import sys
|
7
|
+
from pathlib import Path
|
8
|
+
from typing import TYPE_CHECKING
|
9
|
+
|
10
|
+
import numpy as np
|
11
|
+
import yaml
|
12
|
+
|
13
|
+
from ftag import Flavours
|
14
|
+
from ftag.cli_utils import HelpFormatter
|
15
|
+
from ftag.cuts import Cuts
|
16
|
+
from ftag.hdf5 import H5Reader
|
17
|
+
from ftag.utils import get_discriminant
|
18
|
+
|
19
|
+
if TYPE_CHECKING: # pragma: no cover
|
20
|
+
from collections.abc import Sequence
|
21
|
+
|
22
|
+
from ftag.labels import Label, LabelContainer
|
23
|
+
|
24
|
+
|
25
|
+
def parse_args(args: Sequence[str] | None = None) -> argparse.Namespace:
    """Parse the input arguments into a Namespace.

    Parsing happens in two stages. A pre-parser first reads only ``--category``
    and ``--signal``. These decide which flavours are backgrounds, and therefore
    which fraction-value options (one per non-signal flavour, named after
    ``flav.frac_str``) are added to the full parser. The full parser then
    parses the complete argument list.

    Parameters
    ----------
    args : Sequence[str] | None, optional
        Sequence of string inputs to the script. If None, argparse falls back
        to ``sys.argv[1:]``. By default None.

    Returns
    -------
    argparse.Namespace
        Namespace with the parsed arguments. ``signal`` is replaced by the
        corresponding ``Flavours`` entry.

    Raises
    ------
    ValueError
        When both --effs and --disc_cuts are provided
    ValueError
        When neither --effs nor --disc_cuts are provided
    ValueError
        When the number of fraction values is not consistent
    ValueError
        When the sum of fraction values for a tagger is not equal to one
    """
    # Define the pre-parser which only checks --category/--signal
    pre_parser = argparse.ArgumentParser(add_help=False)
    pre_parser.add_argument(
        "-c",
        "--category",
        default="single-btag",
        type=str,
        help="Label category to use for the working point calculation",
    )
    pre_parser.add_argument(
        "-s",
        "--signal",
        default="bjets",
        type=str,
        help="Signal flavour which is to be used",
    )

    # Parse only --category/--signal and ignore for now all other args
    pre_args, _ = pre_parser.parse_known_args(args=args)

    # Create the "real" parser
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=HelpFormatter,
    )

    # Re-declare --category/--signal so the help is correctly shown and the
    # final parse below picks up the values given on the command line
    parser.add_argument(
        "-c",
        "--category",
        default="single-btag",
        type=str,
        help="Label category to use for the working point calculation",
    )
    parser.add_argument(
        "-s",
        "--signal",
        default="bjets",
        type=str,
        help="Signal flavour which is to be used",
    )

    # Check which label category was chosen and load the corresponding flavours
    flavours = Flavours.by_category(pre_args.category)

    # All flavours besides the signal need a fraction value per tagger
    background_flavours = [flav for flav in flavours if flav.name != pre_args.signal]
    for flav in background_flavours:
        parser.add_argument(
            f"--{flav.frac_str}",
            nargs="+",
            required=True,
            type=float,
            help=f"{flav.frac_str} value(s) for each tagger",
        )

    # Adding the other arguments
    parser.add_argument(
        "--ttbar",
        required=True,
        type=Path,
        help="Path to ttbar sample (supports globbing)",
    )
    parser.add_argument(
        "--zprime",
        required=False,
        type=Path,
        help="Path to zprime (supports globbing). WPs from ttbar will be reused for zprime",
    )
    parser.add_argument(
        "-t",
        "--tagger",
        nargs="+",
        required=True,
        type=str,
        help="tagger name(s)",
    )
    parser.add_argument(
        "-e",
        "--effs",
        nargs="+",
        type=float,
        help="Efficiency working point(s). If -r is specified, values should be 1/efficiency",
    )
    parser.add_argument(
        "-r",
        "--rejection",
        default=None,
        help="Use rejection of specified background class to determine working points",
    )
    parser.add_argument(
        "-d",
        "--disc_cuts",
        nargs="+",
        type=float,
        help="D_x value(s) to calculate efficiency at",
    )
    parser.add_argument(
        "-n",
        "--num_jets",
        default=1_000_000,
        type=int,
        help="Use this many jets (post selection)",
    )
    # type=str (not list): with type=list every cut string would be split
    # into a list of single characters
    parser.add_argument(
        "--ttbar_cuts",
        nargs="+",
        default=["pt > 20e3"],
        type=str,
        help="Selection to apply to ttbar (|eta| < 2.5 is always applied)",
    )
    parser.add_argument(
        "--zprime_cuts",
        nargs="+",
        default=["pt > 250e3"],
        type=str,
        help="Selection to apply to zprime (|eta| < 2.5 is always applied)",
    )
    parser.add_argument(
        "-o",
        "--outfile",
        type=Path,
        help="Save results to yaml instead of printing",
    )

    # Final parse of ALL arguments. Parsing only the pre-parser's leftovers
    # would drop --category/--signal and silently reset them to their defaults.
    parsed_args = parser.parse_args(args)

    # Define the signal as an instance of Flavours
    parsed_args.signal = Flavours[parsed_args.signal]

    # Check that only --effs or --disc_cuts is given
    if parsed_args.effs and parsed_args.disc_cuts:
        raise ValueError("Cannot specify both --effs and --disc_cuts")
    if not parsed_args.effs and not parsed_args.disc_cuts:
        raise ValueError("Must specify either --effs or --disc_cuts")

    # Check that every fraction option has one value per tagger
    n_taggers = len(parsed_args.tagger)
    for flav in background_flavours:
        if len(getattr(parsed_args, flav.frac_str)) != n_taggers:
            raise ValueError(f"Number of {flav.frac_str} values must match number of taggers")

    # Check that all fraction value combinations add up to one
    for tagger_idx, tagger in enumerate(parsed_args.tagger):
        fraction_value_sum = sum(
            getattr(parsed_args, flav.frac_str)[tagger_idx] for flav in background_flavours
        )

        # Round the value to take machine precision into account
        fraction_value_sum = np.round(fraction_value_sum, 8)
        if fraction_value_sum != 1:
            raise ValueError(
                "Sum of the fraction values must be one! You gave "
                f"{fraction_value_sum} for tagger {tagger}"
            )

    return parsed_args
|
215
|
+
|
216
|
+
|
217
|
+
def get_fxs_from_args(args: argparse.Namespace, flavours: LabelContainer) -> list:
    """Collect the per-tagger fraction values from the parsed CLI arguments.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed arguments. ``args.signal.name`` identifies the signal flavour and
        ``args.tagger`` lists the taggers. Every non-signal flavour has an
        attribute named after its ``frac_str`` holding one value per tagger.
    flavours : LabelContainer
        LabelContainer instance of the labels that are used

    Returns
    -------
    list
        One dict per tagger, in the order of ``args.tagger``. Each dict maps
        every background ``frac_str`` to that tagger's fraction value.
    """
    arg_values = vars(args)

    # frac_str -> sequence of values (one per tagger); the signal is skipped
    values_per_frac = {
        flav.frac_str: arg_values[flav.frac_str]
        for flav in flavours
        if flav.name != args.signal.name
    }

    return [
        {frac_str: values[tagger_idx] for frac_str, values in values_per_frac.items()}
        for tagger_idx, _ in enumerate(args.tagger)
    ]
|
241
|
+
|
242
|
+
|
243
|
+
def get_eff_rej(
    jets: np.ndarray,
    disc: np.ndarray,
    wp: float,
    flavours: LabelContainer,
) -> dict:
    """Calculate the efficiency/rejection for each flavour.

    The efficiency of a flavour is the fraction of its jets whose discriminant
    is strictly above ``wp``. The rejection is the inverse efficiency. Both are
    rounded to 3 significant digits.

    Parameters
    ----------
    jets : np.ndarray
        Loaded jets
    disc : np.ndarray
        Discriminant values of the jets (aligned with ``jets``)
    wp : float
        Discriminant cut value of the working point
    flavours : LabelContainer
        LabelContainer instance of the flavours used

    Returns
    -------
    dict
        ``{"eff": {flavour_name: eff}, "rej": {flavour_name: rej}}``. A flavour
        with zero efficiency gets an infinite rejection. A flavour without any
        jets gets NaN for both values.
    """
    out: dict[str, dict] = {"eff": {}, "rej": {}}

    for flav in flavours:
        flav_disc = disc[flav.cuts(jets).idx]
        n_jets = len(flav_disc)

        if n_jets == 0:
            # No jets of this flavour: efficiency/rejection are undefined
            eff = rej = float("nan")
        else:
            # count_nonzero runs in C; builtin sum() would iterate in Python
            eff = np.count_nonzero(flav_disc > wp) / n_jets
            rej = 1 / eff if eff > 0 else float("inf")

        out["eff"][flav.name] = float(f"{eff:.3g}")
        out["rej"][flav.name] = float(f"{rej:.3g}")

    return out
|
279
|
+
|
280
|
+
|
281
|
+
def get_rej_eff_at_disc(
    jets: np.ndarray,
    tagger: str,
    signal: Label,
    disc_cuts: list,
    flavours: LabelContainer,
    fraction_values: dict,
) -> dict:
    """Calculate the efficiency/rejection at certain discriminant values.

    Parameters
    ----------
    jets : np.ndarray
        Loaded jets used
    tagger : str
        Name of the tagger
    signal : Label
        Label instance of the signal flavour
    disc_cuts : list
        List of discriminant cut values for which the efficiency/rejection is calculated
    flavours : LabelContainer
        LabelContainer instance of the flavours that are used
    fraction_values : dict
        Fraction values of the background flavours, passed on to ``get_discriminant``

    Returns
    -------
    dict
        ``{str(cut): {"eff": {str(flav): eff}, "rej": {str(flav): rej}}}``.
        Values are rounded to 3 significant digits. A flavour with zero
        efficiency gets an infinite rejection. A flavour without jets gets NaN
        for both values.
    """
    # Calculate discriminants
    disc = get_discriminant(
        jets=jets,
        tagger=tagger,
        signal=signal,
        flavours=flavours,
        fraction_values=fraction_values,
    )

    # The flavour selections do not depend on the cut value: evaluate them once
    flav_discs = [(str(flav), disc[flav.cuts(jets).idx]) for flav in flavours]

    ref_eff_dict: dict[str, dict] = {}
    for dcut in disc_cuts:
        effs: dict[str, float] = {}
        rejs: dict[str, float] = {}

        for flav_name, flav_disc in flav_discs:
            n_jets = len(flav_disc)
            if n_jets == 0:
                eff = rej = float("nan")
            else:
                eff = np.count_nonzero(flav_disc > dcut) / n_jets
                # Avoid ZeroDivisionError when no jet passes the cut
                rej = 1 / eff if eff > 0 else float("inf")

            effs[flav_name] = float(f"{eff:.3g}")
            # Rejection from the unrounded efficiency, rounded like get_eff_rej
            rejs[flav_name] = float(f"{rej:.3g}")

        ref_eff_dict[str(dcut)] = {"eff": effs, "rej": rejs}

    return ref_eff_dict
|
333
|
+
|
334
|
+
|
335
|
+
def setup_common_parts(
    args: argparse.Namespace,
) -> tuple[np.ndarray, np.ndarray | None, LabelContainer]:
    """Load the jets from the files and setup the taggers.

    Parameters
    ----------
    args : argparse.Namespace
        Input arguments from the argparser

    Returns
    -------
    tuple[np.ndarray, np.ndarray | None, LabelContainer]
        The ttbar jets, the zprime jets (None if ``args.zprime`` is not set)
        and the flavours of ``args.category``.
    """
    # Get the used flavours
    flavours = Flavours.by_category(args.category)

    # |eta| < 2.5 is always applied on top of the user-defined selections
    default_cuts = Cuts.from_list(["eta > -2.5", "eta < 2.5"])
    ttbar_cuts = Cuts.from_list(args.ttbar_cuts) + default_cuts
    zprime_cuts = Cuts.from_list(args.zprime_cuts) + default_cuts

    reader = H5Reader(args.ttbar)
    jet_vars = reader.dtypes()["jets"].names

    # Variables needed by the flavour label cuts, followed by the tagger
    # probabilities that exist in the ttbar file
    requested = [var for flav in flavours for var in flav.cuts.variables]
    requested += [
        f"{tagger}_{flav.px}"
        for tagger in args.tagger
        for flav in flavours
        if f"{tagger}_{flav.px}" in jet_vars
    ]
    # dict.fromkeys de-duplicates while keeping a deterministic order
    # (a set of strings iterates in a run-dependent order)
    all_vars = list(dict.fromkeys(requested))

    # Load ttbar jets
    ttbar_jets = reader.load({"jets": all_vars}, args.num_jets, cuts=ttbar_cuts)["jets"]

    # Load zprime jets if needed
    zprime_jets = None
    if args.zprime:
        zprime_reader = H5Reader(args.zprime)
        zprime_jets = zprime_reader.load({"jets": all_vars}, args.num_jets, cuts=zprime_cuts)[
            "jets"
        ]

    return ttbar_jets, zprime_jets, flavours
|
384
|
+
|
385
|
+
|
386
|
+
def get_working_points(args: argparse.Namespace) -> dict | None:
    """Calculate the working points.

    For every tagger, the discriminant is computed on the ttbar jets. For each
    requested working point, the cut value is the ``100 - eff`` percentile of
    the discriminant of the reference flavour: the signal, or the
    ``--rejection`` background, whose efficiency is then ``100 / value``.
    Efficiencies and rejections of all flavours are evaluated at that cut on
    ttbar, and also on zprime jets if given.

    Parameters
    ----------
    args : argparse.Namespace
        Input arguments from the argparser

    Returns
    -------
    dict | None
        Dict with the working points. If args.outfile is given, the function
        stores the resulting dict in a yaml file at args.outfile and returns None.
    """
    # Load the jets and flavours and get the fraction values
    ttbar_jets, zprime_jets, flavours = setup_common_parts(args=args)
    fraction_values = get_fxs_from_args(args=args, flavours=flavours)

    out = {}

    for i, tagger in enumerate(args.tagger):
        out[tagger] = {"signal": str(args.signal), **fraction_values[i]}

        # Calculate the ttbar discriminant
        disc = get_discriminant(
            jets=ttbar_jets,
            tagger=tagger,
            signal=args.signal,
            flavours=flavours,
            fraction_values=fraction_values[i],
        )

        # The zprime discriminant does not depend on the working point,
        # so compute it once per tagger instead of once per working point
        zprime_disc = None
        if args.zprime:
            zprime_disc = get_discriminant(
                jets=zprime_jets,
                tagger=tagger,
                signal=args.signal,
                flavours=flavours,
                fraction_values=fraction_values[i],
            )

        for eff in args.effs:
            d = out[tagger][f"{eff:.0f}"] = {}

            # Reference flavour and its efficiency (in %) that define the cut
            ref_flavour = args.signal
            ref_eff = eff
            if args.rejection:
                # With -r the given values are rejections of the background
                ref_eff = 100 / eff
                ref_flavour = args.rejection

            # Calculate the discriminant value of the working point
            ref_disc = disc[flavours[ref_flavour].cuts(ttbar_jets).idx]
            wp = d["cut_value"] = round(float(np.percentile(ref_disc, 100 - ref_eff)), 3)

            # Calculate efficiency and rejection for each flavour
            d["ttbar"] = get_eff_rej(
                jets=ttbar_jets,
                disc=disc,
                wp=wp,
                flavours=flavours,
            )

            # Reuse the ttbar cut value for zprime
            if zprime_disc is not None:
                d["zprime"] = get_eff_rej(
                    jets=zprime_jets,
                    disc=zprime_disc,
                    wp=wp,
                    flavours=flavours,
                )

    if args.outfile:
        with open(args.outfile, "w", encoding="utf-8") as f:
            yaml.dump(out, f, sort_keys=False)
        return None

    return out
|
464
|
+
|
465
|
+
|
466
|
+
def get_efficiencies(args: argparse.Namespace) -> dict | None:
    """Evaluate efficiencies/rejections at fixed discriminant cut values.

    For every tagger, the efficiencies and rejections of all flavours are
    computed at each value of ``args.disc_cuts`` on the ttbar jets. They are
    also computed on the zprime jets when ``args.zprime`` is set.

    Parameters
    ----------
    args : argparse.Namespace
        Input arguments from the argparser

    Returns
    -------
    dict | None
        Per-tagger results. When ``args.outfile`` is given, the results are
        written to that file as YAML and None is returned instead.
    """
    ttbar_jets, zprime_jets, flavours = setup_common_parts(args=args)
    fraction_values = get_fxs_from_args(args=args, flavours=flavours)

    # Samples to evaluate: ttbar always, zprime only when a path was given
    samples = [("ttbar", ttbar_jets)]
    if args.zprime:
        samples.append(("zprime", zprime_jets))

    results = {}
    for tagger, tagger_fracs in zip(args.tagger, fraction_values):
        tagger_out = {"signal": str(args.signal), **tagger_fracs}
        for sample_name, sample_jets in samples:
            tagger_out[sample_name] = get_rej_eff_at_disc(
                jets=sample_jets,
                tagger=tagger,
                signal=args.signal,
                disc_cuts=args.disc_cuts,
                flavours=flavours,
                fraction_values=tagger_fracs,
            )
        results[tagger] = tagger_out

    if not args.outfile:
        return results

    with open(args.outfile, "w") as f:
        yaml.dump(results, f, sort_keys=False)
    return None
|
515
|
+
|
516
|
+
|
517
|
+
def main(args: Sequence[str] | None = None) -> dict | None:
    """Main function to run working point calculation.

    Parameters
    ----------
    args : Sequence[str] | None, optional
        Input arguments. If None, argparse reads ``sys.argv[1:]``, so the
        function can be called without arguments (as console-script entry
        points do). By default None.

    Returns
    -------
    dict | None
        The output dict with the calculated values. When --outfile
        was given, the return value is None
    """
    parsed_args = parse_args(args=args)

    if parsed_args.effs:
        out = get_working_points(args=parsed_args)
    else:
        # parse_args guarantees --disc_cuts is set whenever --effs is not
        out = get_efficiencies(args=parsed_args)

    if out:
        print(yaml.dump(out, sort_keys=False))
        return out

    return None
|
544
|
+
|
545
|
+
|
546
|
+
if __name__ == "__main__":  # pragma: no cover
    # Run as a script: hand everything after the program name to main()
    cli_arguments = sys.argv[1:]
    main(args=cli_arguments)
|
ftag/wps/__init__.py
DELETED
File without changes
|
ftag/wps/discriminant.py
DELETED
@@ -1,131 +0,0 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
|
3
|
-
from typing import Callable
|
4
|
-
|
5
|
-
import numpy as np
|
6
|
-
|
7
|
-
from ftag import Flavours
|
8
|
-
from ftag.labels import Label, remove_suffix
|
9
|
-
|
10
|
-
|
11
|
-
def discriminant(
    jets: np.ndarray,
    tagger: str,
    signal: Label,
    fxs: dict[str, float],
    epsilon: float = 1e-10,
) -> np.ndarray:
    """
    Get the tagging discriminant.

    Calculated as the logarithm of the ratio of the signal probability to a
    weighted sum of background probabilities:
    ``log((p_signal + epsilon) / (sum_d fx_d * p_d + epsilon))``.

    Parameters
    ----------
    jets : np.ndarray
        Structured jet array containing tagger scores.
    tagger : str
        Name of the tagger, used to construct field names (``{tagger}_{name}``).
    signal : Label
        Signal flavour. Its score is read from ``{tagger}_{signal.px}``. If that
        field is missing, ``{tagger}_p`` + ``remove_suffix(signal.name, "jets")``
        is used instead.
    fxs : dict[str, float]
        Dict of background probability names and their fractions. All
        fractions must be given explicitly as numbers. A background whose field
        is absent from ``jets`` contributes nothing, which is only allowed when
        its fraction is not positive.
    epsilon : float, optional
        A small value added to numerator and denominator to prevent division by
        zero and log(0), by default 1e-10.

    Returns
    -------
    np.ndarray
        The tagger discriminant values for the jets.

    Raises
    ------
    ValueError
        If a positive fraction is specified for a background that is not present
        in the input array.
    """
    available = jets.dtype.names
    denominator = 0.0
    for d, fx in fxs.items():
        name = f"{tagger}_{d}"
        if name in available:
            denominator += jets[name] * fx
        elif fx > 0:
            raise ValueError(f"Nonzero fx for {d}, but '{name}' not found in input array.")

    signal_field = f"{tagger}_{signal.px}"
    if signal_field not in available:
        # Fallback field naming for the signal score
        signal_field = f"{tagger}_p{remove_suffix(signal.name, 'jets')}"
    return np.log((jets[signal_field] + epsilon) / (denominator + epsilon))
|
58
|
-
|
59
|
-
|
60
|
-
def tautag_dicriminant(jets, tagger, fb, fc, epsilon=1e-10):
    """Discriminant with ``Flavours.taujets`` as signal.

    The backgrounds are ``pb``/``pc``/``pu``; the ``pu`` fraction is
    ``1 - fb - fc``. The misspelt name is kept because ``get_discriminant``
    refers to it.
    """
    light_fraction = 1 - fb - fc
    background_fractions = {"pb": fb, "pc": fc, "pu": light_fraction}
    return discriminant(jets, tagger, Flavours.taujets, background_fractions, epsilon=epsilon)
|
63
|
-
|
64
|
-
|
65
|
-
def btag_discriminant(jets, tagger, fc, ftau=0, epsilon=1e-10):
    """Discriminant with ``Flavours.bjets`` as signal.

    The backgrounds are ``pc``/``ptau``/``pu``; the ``pu`` fraction is
    ``1 - fc - ftau``.
    """
    light_fraction = 1 - fc - ftau
    background_fractions = {"pc": fc, "ptau": ftau, "pu": light_fraction}
    return discriminant(jets, tagger, Flavours.bjets, background_fractions, epsilon=epsilon)
|
68
|
-
|
69
|
-
|
70
|
-
def ghostbtag_discriminant(jets, tagger, fc, ftau=0, epsilon=1e-10):
    """Discriminant with ``Flavours.ghostbjets`` as signal.

    The backgrounds are ``pghostc``/``pghosttau``/``pghostu``; the ``pghostu``
    fraction is ``1 - fc - ftau``.
    """
    light_fraction = 1 - fc - ftau
    background_fractions = {"pghostc": fc, "pghosttau": ftau, "pghostu": light_fraction}
    return discriminant(jets, tagger, Flavours.ghostbjets, background_fractions, epsilon=epsilon)
|
73
|
-
|
74
|
-
|
75
|
-
def ctag_discriminant(jets, tagger, fb, ftau=0, epsilon=1e-10):
    """Discriminant with ``Flavours.cjets`` as signal.

    The backgrounds are ``pb``/``ptau``/``pu``; the ``pu`` fraction is
    ``1 - fb - ftau``.
    """
    light_fraction = 1 - fb - ftau
    background_fractions = {"pb": fb, "ptau": ftau, "pu": light_fraction}
    return discriminant(jets, tagger, Flavours.cjets, background_fractions, epsilon=epsilon)
|
78
|
-
|
79
|
-
|
80
|
-
def hbb_discriminant(jets, tagger, ftop=0.25, fhcc=0.02, epsilon=1e-10):
    """Discriminant with ``Flavours.hbb`` as signal.

    The backgrounds are ``phcc``/``ptop``/``pqcd``; the ``pqcd`` fraction is
    ``1 - ftop - fhcc``.
    """
    qcd_fraction = 1 - ftop - fhcc
    background_fractions = {"phcc": fhcc, "ptop": ftop, "pqcd": qcd_fraction}
    return discriminant(jets, tagger, Flavours.hbb, background_fractions, epsilon=epsilon)
|
83
|
-
|
84
|
-
|
85
|
-
def hcc_discriminant(jets, tagger, ftop=0.25, fhbb=0.3, epsilon=1e-10):
    """Discriminant with ``Flavours.hcc`` as signal.

    The backgrounds are ``phbb``/``ptop``/``pqcd``; the ``pqcd`` fraction is
    ``1 - ftop - fhbb``.
    """
    qcd_fraction = 1 - ftop - fhbb
    background_fractions = {"phbb": fhbb, "ptop": ftop, "pqcd": qcd_fraction}
    return discriminant(jets, tagger, Flavours.hcc, background_fractions, epsilon=epsilon)
|
88
|
-
|
89
|
-
|
90
|
-
def get_discriminant(
    jets: np.ndarray, tagger: str, signal: Label | str, epsilon: float = 1e-10, **fxs
):
    """Compute the discriminant of ``tagger`` for the given signal flavour.

    Dispatches on ``str(signal)`` to the flavour-specific discriminant
    function (bjets, cjets, taujets, hbb, hcc, ghostbjets).

    Parameters
    ----------
    jets : np.ndarray
        Structured array of jets containing tagger outputs
    tagger : str
        Name of the tagger
    signal : Label | str
        Signal flavour, as a Label or its name
    epsilon : float, optional
        Small number to avoid division by zero, by default 1e-10
    **fxs : dict
        Background fractions, forwarded as keyword arguments to the selected
        flavour-specific function.

    Returns
    -------
    np.ndarray
        Array of discriminant values.

    Raises
    ------
    ValueError
        If the signal flavour is not recognised.
    """
    dispatch: dict[str, Callable] = {
        "bjets": btag_discriminant,
        "cjets": ctag_discriminant,
        "taujets": tautag_dicriminant,
        "hbb": hbb_discriminant,
        "hcc": hcc_discriminant,
        "ghostbjets": ghostbtag_discriminant,
    }

    signal_name = str(signal)
    if signal_name not in dispatch:
        raise ValueError(f"Signal flavour must be one of {list(dispatch.keys())}, not {signal}")

    # Resolve through Flavours so that a str or Label both map to the same key
    selected = dispatch[str(Flavours[signal])]
    return selected(jets, tagger, **fxs, epsilon=epsilon)
|