atlas-ftag-tools 0.1.16__py3-none-any.whl → 0.1.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: atlas-ftag-tools
3
- Version: 0.1.16
3
+ Version: 0.1.17
4
4
  Summary: ATLAS Flavour Tagging Tools
5
5
  Author: Sam Van Stroud, Philipp Gadow
6
6
  License: MIT
@@ -87,11 +87,12 @@ The script is `working_points.py` and can be run after installing this package w
87
87
  ```
88
88
  wps \
89
89
  --ttbar "path/to/ttbar/*.h5" \
90
- --tagger GN120220509 \
91
- --fx 0.1
90
+ --tagger GN2v01 \
91
+ --fc 0.1
92
92
  ```
93
93
 
94
- Both the `--tagger` and `--fx` options accept a list if you want to get the WPs for multiple taggers.
94
+ Both the `--tagger` and `--fc` options accept a list if you want to get the WPs for multiple taggers.
95
+ If you are doing c-tagging or xbb-tagging, dedicated fx arguments are available (you can find them all with `-h`).
95
96
 
96
97
  If you want to use the `ttbar` WPs to get the efficiencies and rejections for the `zprime` sample, you can add `--zprime "path/to/zprime/*.h5"` to the command.
97
98
  Note that a default selection of $p_T > 250 ~GeV$ is applied to jets in the `zprime` sample.
@@ -110,7 +111,7 @@ The script `working_points.py` can be run after intalling this package as follow
110
111
  ```
111
112
  wps \
112
113
  --ttbar "path/to/ttbar/*.h5" \
113
- --tagger GN120220509 \
114
+ --tagger GN2v01 \
114
115
  --fx 0.1
115
116
  --disc_cuts 1.0 1.5
116
117
  ```
@@ -1,23 +1,24 @@
1
- ftag/__init__.py,sha256=PN9ZbZHVfxreetH3HnhEbZf1G-J4Up-5si-dQ4rrezg,629
1
+ ftag/__init__.py,sha256=BZi1ffhTco63EC7njY1GtrGgoWXLbBQ-V5ptdfwvjRE,629
2
2
  ftag/cuts.py,sha256=RYAfK3MkEhYhlKQFWQTKu72ZrUwlExFeT8IWLSIgeTU,2798
3
- ftag/flavour.py,sha256=BLifDbJCoszPzgrU5X3Txff9fTMuVGEjUeU0OtOfllc,2701
3
+ ftag/flavour.py,sha256=LbvTQzPKHaGFO5-6XHE-VOV1OgTBy93gYsVcRDMkM64,3570
4
4
  ftag/flavours.yaml,sha256=9ifKyz1_VoHlOaWuf3JEqMLSYyLFedYJf9x1D6dCTnM,5335
5
5
  ftag/git_check.py,sha256=TvF502eqDrYzhI-SgruVolx1BPJi-J0mswc4pmgaYY8,1621
6
- ftag/mock.py,sha256=T2YGeuCkgDEzKPVYwU7Vnq5TNIDrx9eIDVrAQt8s0Mw,4865
6
+ ftag/mock.py,sha256=DZ72l8vi7h5SzgwjZklXdrZK1TI7Np7VApNQMNOVyxU,5321
7
7
  ftag/region.py,sha256=-WxdC0Gy9zz3zEJ2pN779RcxXPG-QEROuMwMoP-Qs0g,353
8
8
  ftag/sample.py,sha256=cd-rNHsEY2aWSZdy3V4bOKi3aDMtHTCpjXS8Hl9zwUY,2597
9
9
  ftag/transform.py,sha256=uEGGJSnqoKOzLYQv650XdK_kDNw4Aw-5dc60z9Dp_y0,3963
10
- ftag/vds.py,sha256=PG-NCJdpmK5X2l8i8YWBHx6CY8vosShfqp36facKaYM,3383
10
+ ftag/vds.py,sha256=VuN62n-8JC2t-79vlcwKYJsLGRb0If0CVk_PXK3yLyA,3288
11
11
  ftag/hdf5/__init__.py,sha256=pZva2TI8nvpBwoawcm_ucVZbGsQJW_u8GGoJgt5mKEw,354
12
12
  ftag/hdf5/h5move.py,sha256=1XxiJZ96DYSp8JF0ry3lbRSQaFX72DjUV5vBA6hYw-0,873
13
13
  ftag/hdf5/h5reader.py,sha256=et-_LXt942xegqc14bPapUgIO7MUfC2m04uJslLkXxI,13579
14
14
  ftag/hdf5/h5split.py,sha256=BlhpsUlqBSDCjVRWuyEq1OImyzwp7VyVkDrCz7pvQKc,2508
15
15
  ftag/hdf5/h5utils.py,sha256=wjbAmFY5GoFkWW_AvEKTPbwYMFroHKKFuIcehd91dhM,3222
16
16
  ftag/hdf5/h5writer.py,sha256=5jm3vSk4m77lFSUyWm-i_y_USzQRVoKpLL8F_cii65Q,4826
17
- ftag/wps/discriminant.py,sha256=0YmI3-ieSWReO_uY4-3Sc_85hLVpoCHQ7LfuU1SC_Sg,2318
18
- ftag/wps/working_points.py,sha256=pLoe8RaVmbtiGM9TxMtWocMohBWrY6JcMCLE_e3XtVY,8033
19
- atlas_ftag_tools-0.1.16.dist-info/METADATA,sha256=H547FpfWZ6plg5PXm7E9SJhppKGWM8iKM_EpFJaUNJs,5064
20
- atlas_ftag_tools-0.1.16.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
21
- atlas_ftag_tools-0.1.16.dist-info/entry_points.txt,sha256=LfVLsZHQolqbPnwPgtmc5IQTh527BKkN2v-IpXWTNHw,137
22
- atlas_ftag_tools-0.1.16.dist-info/top_level.txt,sha256=qiYQuKcAvMim-31FwkT3MTQu7WQm0s58tPAia5KKWqs,5
23
- atlas_ftag_tools-0.1.16.dist-info/RECORD,,
17
+ ftag/wps/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
18
+ ftag/wps/discriminant.py,sha256=J2Cst9slZCLoHZYTeltQSmN1Uoa8GUX529rwNlHmyfI,3519
19
+ ftag/wps/working_points.py,sha256=H602ikwrLVoe2Vq13DPDYSxBpeJzN-2sIKZUeAYL3Pc,9694
20
+ atlas_ftag_tools-0.1.17.dist-info/METADATA,sha256=EUSiZ7BKKcjPdD1M0qT1rZkGdv9UnF3iXySdYjOkpTs,5169
21
+ atlas_ftag_tools-0.1.17.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
22
+ atlas_ftag_tools-0.1.17.dist-info/entry_points.txt,sha256=LfVLsZHQolqbPnwPgtmc5IQTh527BKkN2v-IpXWTNHw,137
23
+ atlas_ftag_tools-0.1.17.dist-info/top_level.txt,sha256=qiYQuKcAvMim-31FwkT3MTQu7WQm0s58tPAia5KKWqs,5
24
+ atlas_ftag_tools-0.1.17.dist-info/RECORD,,
ftag/__init__.py CHANGED
@@ -1,7 +1,7 @@
1
1
  """atlas-ftag-tools - Common tools for ATLAS flavour tagging software."""
2
2
  from __future__ import annotations
3
3
 
4
- __version__ = "v0.1.16"
4
+ __version__ = "v0.1.17"
5
5
 
6
6
 
7
7
  from ftag import hdf5
ftag/flavour.py CHANGED
@@ -35,6 +35,10 @@ class Flavour:
35
35
  def rej_str(self) -> str:
36
36
  return self.label.replace("jets", "jet") + " rejection"
37
37
 
38
+ @property
39
+ def frac_str(self) -> str:
40
+ return "f" + self.name.replace("jets", "jet")
41
+
38
42
  def __str__(self) -> str:
39
43
  return self.name
40
44
 
@@ -62,6 +66,13 @@ class FlavourContainer:
62
66
  flavour = flavour.name
63
67
  return flavour in self.flavours
64
68
 
69
+ def __eq__(self, other) -> bool:
70
+ if isinstance(other, FlavourContainer):
71
+ return self.flavours == other.flavours
72
+ if isinstance(other, list) and all(isinstance(f, str) for f in other):
73
+ return {f.name for f in self} == set(other)
74
+ return False
75
+
65
76
  def __repr__(self) -> str:
66
77
  return f"{self.__class__.__name__}({', '.join([f.name for f in self])})"
67
78
 
@@ -95,5 +106,15 @@ class FlavourContainer:
95
106
 
96
107
  return cls(flavours_dict)
97
108
 
109
+ @classmethod
110
+ def from_list(cls, flavours: list[Flavour]) -> FlavourContainer:
111
+ return cls({f.name: f for f in flavours})
112
+
113
+ def backgrounds(self, flavour: Flavour, keep_possible_signals: bool = True) -> FlavourContainer:
114
+ bkg = [f for f in self if f.category == flavour.category and f != flavour]
115
+ if not keep_possible_signals:
116
+ bkg = [f for f in bkg if f.name not in ["ujets", "qcd"]]
117
+ return FlavourContainer.from_list(bkg)
118
+
98
119
 
99
120
  Flavours = FlavourContainer.from_yaml()
ftag/mock.py CHANGED
@@ -57,18 +57,31 @@ def softmax(x, axis=None):
57
57
  return e_x / e_x.sum(axis=axis, keepdims=True)
58
58
 
59
59
 
60
- def get_mock_scores(labels: np.ndarray):
60
+ def get_mock_scores(labels: np.ndarray, inc_tau: bool = False):
61
61
  rng = np.random.default_rng(42)
62
- scores = np.zeros((len(labels), 3))
62
+ nclass = 3 + inc_tau
63
+ scores = np.zeros((len(labels), nclass))
64
+
63
65
  for label, count in zip(*np.unique(labels, return_counts=True)):
64
- if label in (0, 15):
65
- scores[labels == label] = rng.normal(loc=[2, 0, 0], scale=1, size=(count, 3))
66
+ if label == 0:
67
+ scores[labels == label] = rng.normal(
68
+ loc=[2, 0, 0] + [0] * inc_tau, scale=1, size=(count, nclass)
69
+ )
66
70
  elif label == 4:
67
- scores[labels == label] = rng.normal(loc=[0, 1, 0], scale=2.5, size=(count, 3))
71
+ scores[labels == label] = rng.normal(
72
+ loc=[0, 1, 0] + [0] * inc_tau, scale=2.5, size=(count, nclass)
73
+ )
68
74
  elif label == 5:
69
- scores[labels == label] = rng.normal(loc=[0, 0, 3.5], scale=5, size=(count, 3))
75
+ scores[labels == label] = rng.normal(
76
+ loc=[0, 0, 3.5] + [0] * inc_tau, scale=5, size=(count, nclass)
77
+ )
78
+ elif label == 15:
79
+ scores[labels == label] = rng.normal(
80
+ loc=[0, 0, 0] + [1] * inc_tau, scale=1, size=(count, nclass)
81
+ )
70
82
  scores = softmax(scores, axis=1)
71
- cols = [f"MockTagger_p{x}" for x in ["u", "c", "b"]]
83
+ cols = [f"MockTagger_p{x}" for x in ["u", "c", "b"]] + (["MockTagger_ptau"] if inc_tau else [])
84
+
72
85
  return u2s(scores, dtype=np.dtype([(name, "f4") for name in cols]))
73
86
 
74
87
 
@@ -94,6 +107,7 @@ def get_mock_file(
94
107
  fname: str | None = None,
95
108
  tracks_name: str = "tracks",
96
109
  num_tracks: int = 40,
110
+ inc_tau: bool = False,
97
111
  ) -> tuple[str, h5py.File]:
98
112
  # setup jets
99
113
  rng = np.random.default_rng(42)
@@ -109,7 +123,7 @@ def get_mock_file(
109
123
  jets["n_truth_promptLepton"] = 0
110
124
 
111
125
  # add tagger scores
112
- scores = get_mock_scores(jets["HadronConeExclTruthLabelID"])
126
+ scores = get_mock_scores(jets["HadronConeExclTruthLabelID"], inc_tau=inc_tau)
113
127
  xbb_scores = get_mock_xbb_scores(jets["R10TruthLabel_R22v1"])
114
128
  jets = join_structured_arrays([jets, scores, xbb_scores])
115
129
 
ftag/vds.py CHANGED
@@ -72,10 +72,6 @@ def create_virtual_file(
72
72
  if not common_groups:
73
73
  raise ValueError("No common groups found across files")
74
74
 
75
- print("Common groups found:")
76
- for group in common_groups:
77
- print(f" {group}")
78
-
79
75
  # create virtual file
80
76
  out_fname.parent.mkdir(exist_ok=True)
81
77
  with h5py.File(out_fname, "w") as f:
ftag/wps/__init__.py ADDED
File without changes
ftag/wps/discriminant.py CHANGED
@@ -1,42 +1,80 @@
1
1
  from __future__ import annotations
2
2
 
3
+ from typing import Callable
4
+
3
5
  import numpy as np
4
6
 
5
7
  from ftag.flavour import Flavour, Flavours
6
8
 
7
9
 
8
- def btag_discriminant(jets, tagger, fc=0.1, epsilon=1e-10):
9
- pb, pc, pu = (jets[f"{tagger}_pb"], jets[f"{tagger}_pc"], jets[f"{tagger}_pu"])
10
- return np.log((pb + epsilon) / ((1.0 - fc) * pu + fc * pc + epsilon))
10
+ def discriminant(
11
+ jets: np.ndarray,
12
+ tagger: str,
13
+ signal: Flavour,
14
+ fxs: dict[str, float],
15
+ epsilon: float = 1e-10,
16
+ ) -> np.ndarray:
17
+ """
18
+ Get the tagging discriminant.
11
19
 
20
+ Calculated as the logarithm of the ratio of a specified signal probability
21
+ to a weighted sum of background probabilities.
12
22
 
13
- def ctag_discriminant(jets, tagger, fb=0.2, epsilon=1e-10):
14
- pb, pc, pu = (jets[f"{tagger}_pb"], jets[f"{tagger}_pc"], jets[f"{tagger}_pu"])
15
- return np.log((pc + epsilon) / ((1.0 - fb) * pu + fb * pb + epsilon))
23
+ Parameters
24
+ ----------
25
+ jets : np.ndarray
26
+ Structured jet array containing tagger scores.
27
+ tagger : str
28
+ Name of the tagger, used to construct field names.
29
+ signal : str
30
+ Type of signal.
31
+ fxs : dict[str, float]
32
+ Dict of background probability names and their fractions.
33
+ If a fraction is None, it is calculated as (1 - sum of provided fractions).
34
+ epsilon : float, optional
35
+ A small value added to probabilities to prevent division by zero, by default 1e-10.
36
+
37
+ Returns
38
+ -------
39
+ np.ndarray
40
+ The tagger discriminant values for the jets.
41
+
42
+ Raises
43
+ ------
44
+ ValueError
45
+ If a fraction is specified for a denominator that is not present in the input array.
46
+ """
47
+ denominator = 0.0
48
+ for d, fx in fxs.items():
49
+ name = f"{tagger}_{d}"
50
+ if fx > 0 and name not in jets.dtype.names:
51
+ raise ValueError(f"Nonzero fx for {d}, but '{name}' not found in input array.")
52
+ denominator += jets[name] * fx if name in jets.dtype.names else 0
53
+ return np.log((jets[f"{tagger}_{signal.px}"] + epsilon) / (denominator + epsilon))
54
+
55
+
56
+ def btag_discriminant(jets, tagger, fc, ftau=0, epsilon=1e-10):
57
+ fxs = {"pc": fc, "ptau": ftau, "pu": 1 - fc - ftau}
58
+ return discriminant(jets, tagger, Flavours.bjets, fxs, epsilon=epsilon)
59
+
60
+
61
+ def ctag_discriminant(jets, tagger, fb, ftau=0, epsilon=1e-10):
62
+ fxs = {"pb": fb, "ptau": ftau, "pu": 1 - fb - ftau}
63
+ return discriminant(jets, tagger, Flavours.cjets, fxs, epsilon=epsilon)
16
64
 
17
65
 
18
66
  def hbb_discriminant(jets, tagger, ftop=0.25, fhcc=0.02, epsilon=1e-10):
19
- phbb = jets[f"{tagger}_phbb"]
20
- phcc = jets[f"{tagger}_phcc"]
21
- ptop = jets[f"{tagger}_ptop"]
22
- pqcd = jets[f"{tagger}_pqcd"]
23
- return np.log(phbb / (ftop * ptop + fhcc * phcc + (1 - ftop - fhcc) * pqcd + epsilon))
67
+ fxs = {"phcc": fhcc, "ptop": ftop, "pqcd": 1 - ftop - fhcc}
68
+ return discriminant(jets, tagger, Flavours.hbb, fxs, epsilon=epsilon)
24
69
 
25
70
 
26
71
  def hcc_discriminant(jets, tagger, ftop=0.25, fhbb=0.3, epsilon=1e-10):
27
- phbb = jets[f"{tagger}_phbb"]
28
- phcc = jets[f"{tagger}_phcc"]
29
- ptop = jets[f"{tagger}_ptop"]
30
- pqcd = jets[f"{tagger}_pqcd"]
31
- return np.log(phcc / (ftop * ptop + fhbb * phbb + (1 - ftop - fhbb) * pqcd + epsilon))
72
+ fxs = {"phbb": fhbb, "ptop": ftop, "pqcd": 1 - ftop - fhbb}
73
+ return discriminant(jets, tagger, Flavours.hcc, fxs, epsilon=epsilon)
32
74
 
33
75
 
34
76
  def get_discriminant(
35
- jets: np.ndarray,
36
- tagger: str,
37
- signal: Flavour | str,
38
- fx: float | tuple[float, ...],
39
- epsilon: float = 1e-10,
77
+ jets: np.ndarray, tagger: str, signal: Flavour | str, epsilon: float = 1e-10, **fxs
40
78
  ):
41
79
  """Calculate the b-tag or c-tag discriminant for a given tagger.
42
80
 
@@ -48,27 +86,25 @@ def get_discriminant(
48
86
  Name of the tagger
49
87
  signal : Flavour
50
88
  Signal flavour (bjets/cjets or hbb/hcc)
51
- fx : float, optional
52
- Value fb or fc (fhbb or fhcc and ftop for Xbb)
53
89
  epsilon : float, optional
54
90
  Small number to avoid division by zero, by default 1e-10
91
+ **fxs : dict
92
+ Fractions for the different background flavours.
55
93
 
56
94
  Returns
57
95
  -------
58
96
  np.ndarray
59
97
  Array of discriminant values.
60
98
  """
61
- if not isinstance(fx, tuple | list):
62
- fx = (fx,)
63
- tagger_funcs = {
99
+ tagger_funcs: dict[str, Callable] = {
64
100
  "bjets": btag_discriminant,
65
101
  "cjets": ctag_discriminant,
66
102
  "hbb": hbb_discriminant,
67
103
  "hcc": hcc_discriminant,
68
104
  }
69
105
 
70
- func = tagger_funcs.get(str(Flavours[signal]), None)
71
- if func is None:
72
- raise ValueError(f"Signal flavour must be among {list(tagger_funcs.keys())}, not {signal}")
106
+ if str(signal) not in tagger_funcs:
107
+ raise ValueError(f"Signal flavour must be one of {list(tagger_funcs.keys())}, not {signal}")
73
108
 
74
- return func(jets, tagger, *fx, epsilon) # type: ignore
109
+ func: Callable = tagger_funcs[str(Flavours[signal])]
110
+ return func(jets, tagger, **fxs, epsilon=epsilon)
@@ -12,12 +12,11 @@ from ftag.hdf5 import H5Reader
12
12
  from ftag.wps.discriminant import get_discriminant
13
13
 
14
14
 
15
- def parse_args(args):
15
+ def parse_args(args): # noqa: PLR0912
16
16
  parser = argparse.ArgumentParser(
17
17
  description="Calculate tagger working points",
18
18
  formatter_class=argparse.ArgumentDefaultsHelpFormatter,
19
19
  )
20
-
21
20
  parser.add_argument(
22
21
  "--ttbar",
23
22
  required=True,
@@ -45,14 +44,6 @@ def parse_args(args):
45
44
  type=str,
46
45
  help="tagger name(s)",
47
46
  )
48
- parser.add_argument(
49
- "-f",
50
- "--fx",
51
- nargs="+",
52
- required=True,
53
- type=float,
54
- help="fb or fc value(s) for each tagger",
55
- )
56
47
  parser.add_argument(
57
48
  "-s",
58
49
  "--signal",
@@ -107,8 +98,90 @@ def parse_args(args):
107
98
  action="store_true",
108
99
  help="Enable Xbb tagging which expects two fx values ftop and fhcc/fhbb for each tagger",
109
100
  )
101
+ parser.add_argument(
102
+ "--fb",
103
+ nargs="+",
104
+ type=float,
105
+ help="fb value(s) for each tagger",
106
+ )
107
+ parser.add_argument(
108
+ "--fc",
109
+ nargs="+",
110
+ type=float,
111
+ help="fc value(s) for each tagger",
112
+ )
113
+ parser.add_argument(
114
+ "--ftau",
115
+ nargs="+",
116
+ type=float,
117
+ help="ftau value(s) for each tagger",
118
+ )
119
+ parser.add_argument(
120
+ "--ftop",
121
+ nargs="+",
122
+ type=float,
123
+ help="ftop value(s) for each tagger",
124
+ )
125
+ parser.add_argument(
126
+ "--fhbb",
127
+ nargs="+",
128
+ type=float,
129
+ help="fhbb value(s) for each tagger",
130
+ )
131
+ parser.add_argument(
132
+ "--fhcc",
133
+ nargs="+",
134
+ type=float,
135
+ help="fhcc value(s) for each tagger",
136
+ )
137
+ args = parser.parse_args(args)
110
138
 
111
- return parser.parse_args(args)
139
+ args.signal = Flavours[args.signal]
140
+
141
+ if args.effs and args.disc_cuts:
142
+ raise ValueError("Cannot specify both --effs and --disc_cuts")
143
+ if not args.effs and not args.disc_cuts:
144
+ raise ValueError("Must specify either --effs or --disc_cuts")
145
+
146
+ if args.xbb:
147
+ if args.signal not in [Flavours.hbb, Flavours.hcc]:
148
+ raise ValueError("Xbb tagging only supports hbb or hcc signal flavours")
149
+ if args.fb or args.fc or args.ftau:
150
+ raise ValueError("For Xbb tagging, fb, fc and ftau should not be specified")
151
+ if not args.ftop:
152
+ raise ValueError("For Xbb tagging, ftop should be specified")
153
+ if args.signal == "hbb" and not args.fhcc:
154
+ raise ValueError("For hbb tagging, fhcc should be specified")
155
+ if args.signal == "hcc" and not args.fhbb:
156
+ raise ValueError("For hcc tagging, fhbb should be specified")
157
+ else:
158
+ if args.ftop or args.fhbb or args.fhcc:
159
+ raise ValueError("For single-b tagging, ftop, fhbb and fhcc should not be specified")
160
+ if args.signal == "bjets" and not args.fc:
161
+ raise ValueError("For bjets tagging, fc should be specified")
162
+ if args.signal == "cjets" and not args.fb:
163
+ raise ValueError("For cjets tagging, fb should be specified")
164
+ if args.ftau is None:
165
+ args.ftau = [0.0] * len(args.tagger)
166
+
167
+ for fx in ["fb", "fc", "ftau", "ftop", "fhbb", "fhcc"]:
168
+ if getattr(args, fx) and len(getattr(args, fx)) != len(args.tagger):
169
+ raise ValueError(f"Number of {fx} values must match number of taggers")
170
+
171
+ return args
172
+
173
+
174
+ def get_fxs_from_args(args):
175
+ if args.signal == Flavours.bjets:
176
+ fxs = {"fc": args.fc, "ftau": args.ftau}
177
+ elif args.signal == Flavours.cjets:
178
+ fxs = {"fb": args.fb, "ftau": args.ftau}
179
+ elif args.signal == Flavours.hbb:
180
+ fxs = {"ftop": args.ftop, "fhcc": args.fhcc}
181
+ elif args.signal == Flavours.hcc:
182
+ fxs = {"ftop": args.ftop, "fhbb": args.fhbb}
183
+ assert fxs is not None
184
+ return [{k: v[i] for k, v in fxs.items()} for i in range(len(args.tagger))]
112
185
 
113
186
 
114
187
  def get_eff_rej(jets, disc, wp, flavs):
@@ -121,40 +194,53 @@ def get_eff_rej(jets, disc, wp, flavs):
121
194
  return out
122
195
 
123
196
 
124
- def get_working_points(args=None):
125
- if args.xbb:
126
- if len(args.fx) != 2 * len(args.tagger):
127
- raise ValueError("For Xbb tagging, each tagger must have two fx values")
128
- fx_values = list(zip(args.fx[::2], args.fx[1::2]))
129
- else:
130
- if len(args.fx) != len(args.tagger):
131
- raise ValueError("Number of taggers must match number of fx values")
132
- fx_values = [(fx,) for fx in args.fx]
197
+ def get_rej_eff_at_disc(jets, tagger, signal, disc_cuts, **fxs):
198
+ disc = get_discriminant(jets, tagger, signal, **fxs)
199
+ d = {}
200
+ flavs = Flavours.by_category("single-btag")
201
+ for dcut in disc_cuts:
202
+ d[str(dcut)] = {"eff": {}, "rej": {}}
203
+ for f in flavs:
204
+ e_discs = disc[f.cuts(jets).idx]
205
+ eff = sum(e_discs > dcut) / len(e_discs)
206
+ d[str(dcut)]["eff"][str(f)] = float(f"{eff:.3g}")
207
+ d[str(dcut)]["rej"][str(f)] = 1 / float(f"{eff:.3g}")
208
+ return d
133
209
 
134
- # setup cuts and variables
135
- flavs = Flavours.by_category("single-btag") if not args.xbb else Flavours.by_category("xbb")
136
210
 
211
+ def setup_common_parts(args):
212
+ flavs = Flavours.by_category("single-btag") if not args.xbb else Flavours.by_category("xbb")
137
213
  default_cuts = Cuts.from_list(["eta > -2.5", "eta < 2.5"])
138
214
  ttbar_cuts = Cuts.from_list(args.ttbar_cuts) + default_cuts
139
215
  zprime_cuts = Cuts.from_list(args.zprime_cuts) + default_cuts
216
+
217
+ # prepare to load jets
140
218
  all_vars = next(iter(flavs)).cuts.variables
219
+ reader = H5Reader(args.ttbar)
220
+ jet_vars = reader.dtypes()["jets"].names
141
221
  for tagger in args.tagger:
142
- all_vars += [f"{tagger}_{f.px}" for f in flavs if "tau" not in f.px]
222
+ all_vars += [f"{tagger}_{f.px}" for f in flavs if (f"{tagger}_{f.px}" in jet_vars)]
143
223
 
144
224
  # load jets
145
- reader = H5Reader(args.ttbar)
146
225
  jets = reader.load({"jets": all_vars}, args.num_jets, cuts=ttbar_cuts)["jets"]
226
+ zp_jets = None
147
227
  if args.zprime:
148
228
  zp_reader = H5Reader(args.zprime)
149
229
  zp_jets = zp_reader.load({"jets": all_vars}, args.num_jets, cuts=zprime_cuts)["jets"]
150
230
 
231
+ return jets, zp_jets, flavs
232
+
233
+
234
+ def get_working_points(args=None):
235
+ jets, zp_jets, flavs = setup_common_parts(args)
236
+ fxs = get_fxs_from_args(args)
237
+
151
238
  # loop over taggers
152
239
  out = {}
153
- for tagger, fx in zip(args.tagger, fx_values):
154
- out[tagger] = {"signal": args.signal, "fx": fx}
155
-
240
+ for i, tagger in enumerate(args.tagger):
156
241
  # calculate discriminant
157
- disc = get_discriminant(jets, tagger, args.signal, fx)
242
+ out[tagger] = {"signal": str(args.signal), **fxs[i]}
243
+ disc = get_discriminant(jets, tagger, args.signal, **fxs[i])
158
244
 
159
245
  # loop over efficiency working points
160
246
  for eff in args.effs:
@@ -173,7 +259,7 @@ def get_working_points(args=None):
173
259
 
174
260
  # calculate for zprime
175
261
  if args.zprime:
176
- zp_disc = get_discriminant(zp_jets, tagger, Flavours[args.signal], fx)
262
+ zp_disc = get_discriminant(zp_jets, tagger, Flavours[args.signal], **fxs[i])
177
263
  d["zprime"] = get_eff_rej(zp_jets, zp_disc, wp, flavs)
178
264
 
179
265
  if args.outfile:
@@ -184,50 +270,20 @@ def get_working_points(args=None):
184
270
  return out
185
271
 
186
272
 
187
- def get_rej_eff_at_disc(jets, tagger, signal, fx, disc_cuts):
188
- disc = get_discriminant(jets, tagger, signal, fx)
189
- d = {}
190
- flavs = Flavours.by_category("single-btag")
191
- for dcut in disc_cuts:
192
- d[str(dcut)] = {"eff": {}, "rej": {}}
193
- for f in flavs:
194
- e_discs = disc[f.cuts(jets).idx]
195
- eff = sum(e_discs > dcut) / len(e_discs)
196
- d[str(dcut)]["eff"][str(f)] = float(f"{eff:.3g}")
197
- d[str(dcut)]["rej"][str(f)] = 1 / float(f"{eff:.3g}")
198
- return d
199
-
200
-
201
273
  def get_efficiencies(args=None):
202
- if len(args.tagger) != len(args.fx):
203
- raise ValueError("Must provide fb/fc for each tagger")
204
-
205
- fx_values = [(fx,) for fx in args.fx]
206
- # setup cuts and variables
207
- flavs = Flavours.by_category("single-btag")
208
- default_cuts = Cuts.from_list(["eta > -2.5", "eta < 2.5"])
209
- ttbar_cuts = Cuts.from_list(args.ttbar_cuts) + default_cuts
210
- zprime_cuts = Cuts.from_list(args.zprime_cuts) + default_cuts
211
- all_vars = next(iter(flavs)).cuts.variables
212
- for tagger in args.tagger:
213
- all_vars += [f"{tagger}_{f.px}" for f in flavs if "tau" not in f.px]
274
+ jets, zp_jets, flavs = setup_common_parts(args)
275
+ fxs = get_fxs_from_args(args)
214
276
 
215
- # load jets
216
- reader = H5Reader(args.ttbar)
217
- jets = reader.load({"jets": all_vars}, args.num_jets, cuts=ttbar_cuts)["jets"]
218
- if args.zprime:
219
- zp_reader = H5Reader(args.zprime)
220
- zp_jets = zp_reader.load({"jets": all_vars}, args.num_jets, cuts=zprime_cuts)["jets"]
221
-
222
- # loop over taggers
223
277
  out = {}
224
- for tagger, fx in zip(args.tagger, fx_values):
225
- out[tagger] = {"signal": args.signal, "fx": fx}
278
+ for i, tagger in enumerate(args.tagger):
279
+ out[tagger] = {"signal": str(args.signal), **fxs[i]}
226
280
 
227
- out[tagger]["ttbar"] = get_rej_eff_at_disc(jets, tagger, args.signal, fx, args.disc_cuts)
281
+ out[tagger]["ttbar"] = get_rej_eff_at_disc(
282
+ jets, tagger, args.signal, args.disc_cuts, **fxs[i]
283
+ )
228
284
  if args.zprime:
229
285
  out[tagger]["zprime"] = get_rej_eff_at_disc(
230
- zp_jets, tagger, args.signal, fx, args.disc_cuts
286
+ zp_jets, tagger, args.signal, args.disc_cuts, **fxs[i]
231
287
  )
232
288
 
233
289
  if args.outfile:
@@ -241,18 +297,15 @@ def get_efficiencies(args=None):
241
297
  def main(args=None):
242
298
  args = parse_args(args)
243
299
 
244
- if args.effs and args.disc_cuts:
245
- raise ValueError("Cannot specify both --effs and --disc_cuts")
246
-
247
300
  if args.effs:
248
301
  out = get_working_points(args)
249
302
  elif args.disc_cuts:
250
303
  out = get_efficiencies(args)
251
- else:
252
- raise ValueError("Must specify either --effs or --disc_cuts")
304
+
253
305
  if out:
254
306
  print(yaml.dump(out, sort_keys=False))
255
307
  return out
308
+
256
309
  return None
257
310
 
258
311