atlas-ftag-tools 0.2.8__py3-none-any.whl → 0.2.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,316 +0,0 @@
1
- """Calculate tagger working points."""
2
-
3
- from __future__ import annotations
4
-
5
- import argparse
6
- from pathlib import Path
7
-
8
- import numpy as np
9
- import yaml
10
-
11
- from ftag import Flavours
12
- from ftag.cli_utils import HelpFormatter
13
- from ftag.cuts import Cuts
14
- from ftag.hdf5 import H5Reader
15
- from ftag.wps.discriminant import get_discriminant
16
-
17
-
18
def parse_args(args):
    """Parse and validate the command line arguments.

    Parameters
    ----------
    args : list[str] | None
        Argument strings handed to ``ArgumentParser.parse_args``.

    Returns
    -------
    argparse.Namespace
        The parsed arguments. ``signal`` is replaced by the result of
        ``Flavours[<signal name>]`` and, for single-b tagging, a missing
        ``ftau`` is filled with ``0.0`` for every tagger.

    Raises
    ------
    ValueError
        If the combination of options is inconsistent (see the messages below).
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=HelpFormatter,
    )
    parser.add_argument(
        "--ttbar",
        required=True,
        type=Path,
        help="path to ttbar sample (supports globbing)",
    )
    parser.add_argument(
        "--zprime",
        required=False,
        type=Path,
        help="path to zprime (supports globbing). WPs from ttbar will be reused for zprime",
    )
    parser.add_argument(
        "-e",
        "--effs",
        nargs="+",
        type=float,
        help="efficiency working point(s). If -r is specified, values should be 1/efficiency",
    )
    parser.add_argument(
        "-t",
        "--tagger",
        nargs="+",
        required=True,
        type=str,
        help="tagger name(s)",
    )
    parser.add_argument(
        "-s",
        "--signal",
        default="bjets",
        choices=["bjets", "cjets", "hbb", "hcc"],
        type=str,
        help='signal flavour ("bjets" or "cjets" for b-tagging, "hbb" or "hcc" for Xbb)',
    )
    parser.add_argument(
        "-r",
        "--rejection",
        default=None,
        choices=["ujets", "cjets", "bjets", "hbb", "hcc", "top", "qcd"],
        help="use rejection of specified background class to determine working points",
    )
    parser.add_argument(
        "-d",
        "--disc_cuts",
        nargs="+",
        type=float,
        help="D_x value(s) to calculate efficiency at",
    )
    parser.add_argument(
        "-n",
        "--num_jets",
        default=1_000_000,
        type=int,
        help="use this many jets (post selection)",
    )
    # The cut options must use type=str: with nargs="+", argparse applies `type`
    # to each individual argument, so type=list would split every cut string
    # into a list of single characters.
    parser.add_argument(
        "--ttbar_cuts",
        nargs="+",
        default=["pt > 20e3"],
        type=str,
        help="selection to apply to ttbar (|eta| < 2.5 is always applied)",
    )
    parser.add_argument(
        "--zprime_cuts",
        nargs="+",
        default=["pt > 250e3"],
        type=str,
        help="selection to apply to zprime (|eta| < 2.5 is always applied)",
    )
    parser.add_argument(
        "-o",
        "--outfile",
        type=Path,
        help="save results to yaml instead of printing",
    )
    parser.add_argument(
        "--xbb",
        action="store_true",
        help="Enable Xbb tagging which expects two fx values ftop and fhcc/fhbb for each tagger",
    )
    # One option per fraction parameter; all share the same shape.
    fx_names = ("fb", "fc", "ftau", "ftop", "fhbb", "fhcc")
    for fx in fx_names:
        parser.add_argument(
            f"--{fx}",
            nargs="+",
            type=float,
            help=f"{fx} value(s) for each tagger",
        )
    args = parser.parse_args(args)

    args.signal = Flavours[args.signal]

    if args.effs and args.disc_cuts:
        raise ValueError("Cannot specify both --effs and --disc_cuts")
    if not args.effs and not args.disc_cuts:
        raise ValueError("Must specify either --effs or --disc_cuts")

    # From here on `args.signal` is the looked-up flavour, so it is compared
    # against `Flavours.<name>` (as get_fxs_from_args does) rather than against
    # the raw option strings.
    if args.xbb:
        if args.signal not in {Flavours.hbb, Flavours.hcc}:
            raise ValueError("Xbb tagging only supports hbb or hcc signal flavours")
        if args.fb or args.fc or args.ftau:
            raise ValueError("For Xbb tagging, fb, fc and ftau should not be specified")
        if not args.ftop:
            raise ValueError("For Xbb tagging, ftop should be specified")
        if args.signal == Flavours.hbb and not args.fhcc:
            raise ValueError("For hbb tagging, fhcc should be specified")
        if args.signal == Flavours.hcc and not args.fhbb:
            raise ValueError("For hcc tagging, fhbb should be specified")
    else:
        # hbb/hcc need ftop, which is forbidden without --xbb, so this
        # combination could only fail later with an opaque error.
        if args.signal in {Flavours.hbb, Flavours.hcc}:
            raise ValueError("hbb and hcc signal flavours require --xbb")
        if args.ftop or args.fhbb or args.fhcc:
            raise ValueError("For single-b tagging, ftop, fhbb and fhcc should not be specified")
        if args.signal == Flavours.bjets and not args.fc:
            raise ValueError("For bjets tagging, fc should be specified")
        if args.signal == Flavours.cjets and not args.fb:
            raise ValueError("For cjets tagging, fb should be specified")
        if args.ftau is None:
            args.ftau = [0.0] * len(args.tagger)

    for fx in fx_names:
        values = getattr(args, fx)
        if values and len(values) != len(args.tagger):
            raise ValueError(f"Number of {fx} values must match number of taggers")

    return args
175
-
176
-
177
def get_fxs_from_args(args):
    """Build the per-tagger fraction keyword arguments for the chosen signal.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed arguments; ``signal``, ``tagger`` and the fraction lists
        relevant to the signal are read.

    Returns
    -------
    list[dict]
        One dict per entry of ``args.tagger``, mapping each fraction name to
        the value at that tagger's position.

    Raises
    ------
    ValueError
        If ``args.signal`` is none of bjets, cjets, hbb or hcc.
    """
    if args.signal == Flavours.bjets:
        fxs = {"fc": args.fc, "ftau": args.ftau}
    elif args.signal == Flavours.cjets:
        fxs = {"fb": args.fb, "ftau": args.ftau}
    elif args.signal == Flavours.hbb:
        fxs = {"ftop": args.ftop, "fhcc": args.fhcc}
    elif args.signal == Flavours.hcc:
        fxs = {"ftop": args.ftop, "fhbb": args.fhbb}
    else:
        # An explicit error instead of an assert: the assert would never be
        # reached (the name would be unbound) and is stripped under -O anyway.
        raise ValueError(f"Unsupported signal flavour: {args.signal}")
    return [{name: values[i] for name, values in fxs.items()} for i in range(len(args.tagger))]
188
-
189
-
190
def get_eff_rej(jets, disc, wp, flavs):
    """Compute efficiency and rejection of every flavour at a discriminant cut.

    Parameters
    ----------
    jets
        Jets passed to each flavour's ``cuts`` to obtain that flavour's selection.
    disc
        Discriminant values, indexed with ``flav.cuts(jets).idx``.
    wp
        Cut value; a jet passes if its discriminant is strictly greater.
    flavs
        Iterable of flavours; ``str(flav)`` is used as the output key.

    Returns
    -------
    dict
        ``{"eff": {flavour: value}, "rej": {flavour: value}}`` with values
        rounded to three significant figures. A flavour with no passing jets
        has efficiency ``0.0`` and rejection ``inf``; a flavour with no
        selected jets at all has ``nan`` for both.
    """
    out = {"eff": {}, "rej": {}}
    for bkg in flavs:
        bkg_disc = disc[bkg.cuts(jets).idx]
        n_total = len(bkg_disc)
        if n_total == 0:
            # No jets of this flavour: the efficiency is undefined.
            eff = rej = float("nan")
        else:
            # Count in NumPy instead of summing booleans in a Python loop.
            eff = np.count_nonzero(bkg_disc > wp) / n_total
            rej = 1 / eff if eff > 0 else float("inf")
        out["eff"][str(bkg)] = float(f"{eff:.3g}")
        out["rej"][str(bkg)] = float(f"{rej:.3g}")
    return out
198
-
199
-
200
def get_rej_eff_at_disc(jets, tagger, signal, disc_cuts, **fxs):
    """Compute per-flavour efficiency and rejection at fixed discriminant cuts.

    Parameters
    ----------
    jets
        Jets used to compute the discriminant and the flavour selections.
    tagger
        Tagger name forwarded to ``get_discriminant``.
    signal
        Signal flavour forwarded to ``get_discriminant``.
    disc_cuts
        Iterable of cut values; a jet passes if its discriminant is strictly
        greater than the cut.
    **fxs
        Fraction values forwarded to ``get_discriminant``.

    Returns
    -------
    dict
        ``{str(cut): {"eff": {flavour: value}, "rej": {flavour: value}}}``.
        Efficiencies are rounded to three significant figures and the
        rejection is the inverse of the rounded efficiency. A rounded
        efficiency of zero gives a rejection of ``inf``; a flavour with no
        selected jets gives ``nan`` for both.
    """
    disc = get_discriminant(jets, tagger, signal, **fxs)
    # NOTE(review): the flavour list is fixed to the "single-btag" category here,
    # even when the caller is running in Xbb mode - confirm this is intended.
    flavs = Flavours.by_category("single-btag")
    # The flavour selection does not depend on the cut value, so do it once.
    flav_discs = [(str(f), disc[f.cuts(jets).idx]) for f in flavs]

    d = {}
    for dcut in disc_cuts:
        effs = {}
        rejs = {}
        for name, f_disc in flav_discs:
            n_total = len(f_disc)
            if n_total == 0:
                # No jets of this flavour: the efficiency is undefined.
                eff = float("nan")
            else:
                eff = float(f"{np.count_nonzero(f_disc > dcut) / n_total:.3g}")
            effs[name] = eff
            if eff > 0:
                rejs[name] = 1 / eff
            elif eff == 0:
                # Nothing passes the cut: avoid a ZeroDivisionError.
                rejs[name] = float("inf")
            else:
                rejs[name] = float("nan")
        d[str(dcut)] = {"eff": effs, "rej": rejs}
    return d
212
-
213
-
214
def setup_common_parts(args):
    """Select the flavours and load the ttbar (and optionally zprime) jets.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed arguments; ``xbb``, ``ttbar``, ``zprime``, ``ttbar_cuts``,
        ``zprime_cuts``, ``tagger`` and ``num_jets`` are read.

    Returns
    -------
    tuple
        ``(jets, zp_jets, flavs)`` where ``zp_jets`` is ``None`` when no
        zprime sample was given.
    """
    category = "xbb" if args.xbb else "single-btag"
    flavs = Flavours.by_category(category)

    # the eta window is added on top of the sample-specific selection
    eta_cuts = Cuts.from_list(["eta > -2.5", "eta < 2.5"])
    ttbar_cuts = Cuts.from_list(args.ttbar_cuts) + eta_cuts
    zprime_cuts = Cuts.from_list(args.zprime_cuts) + eta_cuts

    # variables needed by the flavour cuts, without duplicates
    all_vars = list({var for flav in flavs for var in flav.cuts.variables})

    # add the tagger outputs that are present in the ttbar file
    reader = H5Reader(args.ttbar)
    jet_vars = reader.dtypes()["jets"].names
    for tagger in args.tagger:
        for flav in flavs:
            prob_var = f"{tagger}_{flav.px}"
            if prob_var in jet_vars:
                all_vars.append(prob_var)

    jets = reader.load({"jets": all_vars}, args.num_jets, cuts=ttbar_cuts)["jets"]
    if not args.zprime:
        return jets, None, flavs

    zp_reader = H5Reader(args.zprime)
    zp_jets = zp_reader.load({"jets": all_vars}, args.num_jets, cuts=zprime_cuts)["jets"]
    return jets, zp_jets, flavs
235
-
236
-
237
def get_working_points(args=None):
    """Derive discriminant cut values for the requested working points.

    For every tagger the cut value of each working point is taken from the
    ttbar jets, and the efficiency and rejection of every flavour is evaluated
    at that cut on ttbar and, if given, on zprime.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed arguments; ``tagger``, ``signal``, ``effs``, ``rejection``,
        ``zprime`` and ``outfile`` are read here.

    Returns
    -------
    dict | None
        The results keyed by tagger name, or ``None`` when they were written
        to ``args.outfile`` instead.
    """
    jets, zp_jets, flavs = setup_common_parts(args)
    fxs = get_fxs_from_args(args)

    # With -r the working points are defined on the chosen background flavour,
    # otherwise on the signal flavour.
    wp_flavour = args.rejection if args.rejection else args.signal

    out = {}
    for i, tagger in enumerate(args.tagger):
        out[tagger] = {"signal": str(args.signal), **fxs[i]}

        # The discriminants do not depend on the working point, so compute them
        # once per tagger. The zprime call uses the same signal argument as the
        # ttbar call.
        disc = get_discriminant(jets, tagger, args.signal, **fxs[i])
        zp_disc = None
        if args.zprime:
            zp_disc = get_discriminant(zp_jets, tagger, args.signal, **fxs[i])

        wp_disc = disc[flavs[wp_flavour].cuts(jets).idx]

        for eff in args.effs:
            # the output key is the value as given on the command line
            d = out[tagger][f"{eff:.0f}"] = {}

            # with -r the given value is a rejection; convert it to a percentage
            target_eff = 100 / eff if args.rejection else eff
            wp = d["cut_value"] = round(float(np.percentile(wp_disc, 100 - target_eff)), 3)

            d["ttbar"] = get_eff_rej(jets, disc, wp, flavs)
            if args.zprime:
                # the cut value derived on ttbar is reused for zprime
                d["zprime"] = get_eff_rej(zp_jets, zp_disc, wp, flavs)

    if args.outfile:
        with open(args.outfile, "w") as f:
            yaml.dump(out, f, sort_keys=False)
        return None
    return out
274
-
275
-
276
def get_efficiencies(args=None):
    """Evaluate efficiencies and rejections at the requested discriminant cuts.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed arguments; ``tagger``, ``signal``, ``disc_cuts``, ``zprime``
        and ``outfile`` are read here.

    Returns
    -------
    dict | None
        The results keyed by tagger name, or ``None`` when they were written
        to ``args.outfile`` instead.
    """
    ttbar_jets, zprime_jets, _ = setup_common_parts(args)
    per_tagger_fxs = get_fxs_from_args(args)

    results = {}
    for tagger, tagger_fxs in zip(args.tagger, per_tagger_fxs):
        entry = {"signal": str(args.signal), **tagger_fxs}
        entry["ttbar"] = get_rej_eff_at_disc(
            ttbar_jets, tagger, args.signal, args.disc_cuts, **tagger_fxs
        )
        if args.zprime:
            entry["zprime"] = get_rej_eff_at_disc(
                zprime_jets, tagger, args.signal, args.disc_cuts, **tagger_fxs
            )
        results[tagger] = entry

    # without an output file the caller gets the results back directly
    if not args.outfile:
        return results

    with open(args.outfile, "w") as stream:
        yaml.dump(results, stream, sort_keys=False)
    return None
298
-
299
-
300
def main(args=None):
    """Run the tool: parse the arguments, compute the results and print them.

    Parameters
    ----------
    args : list[str] | None
        Argument strings forwarded to ``parse_args``.

    Returns
    -------
    dict | None
        The computed results, or ``None`` when nothing was returned by the
        calculation (e.g. the results were written to a file).
    """
    parsed = parse_args(args)

    # --effs selects the working point calculation, --disc_cuts the
    # fixed-cut efficiency calculation
    if parsed.effs:
        results = get_working_points(parsed)
    elif parsed.disc_cuts:
        results = get_efficiencies(parsed)

    if not results:
        return None

    print(yaml.dump(results, sort_keys=False))
    return results


if __name__ == "__main__":
    main()