atlas-ftag-tools 0.2.7__py3-none-any.whl → 0.2.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,23 +1,23 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.2
2
2
  Name: atlas-ftag-tools
3
- Version: 0.2.7
3
+ Version: 0.2.9
4
4
  Summary: ATLAS Flavour Tagging Tools
5
5
  Author: Sam Van Stroud, Philipp Gadow
6
6
  License: MIT
7
7
  Project-URL: Homepage, https://github.com/umami-hep/atlas-ftag-tools/
8
8
  Requires-Python: <3.12,>=3.8
9
9
  Description-Content-Type: text/markdown
10
- Requires-Dist: h5py >=3.0
10
+ Requires-Dist: h5py>=3.0
11
11
  Requires-Dist: numpy
12
- Requires-Dist: PyYAML >=5.1
12
+ Requires-Dist: PyYAML>=5.1
13
13
  Provides-Extra: dev
14
- Requires-Dist: ruff ==0.6.2 ; extra == 'dev'
15
- Requires-Dist: mypy ==1.11.2 ; extra == 'dev'
16
- Requires-Dist: pre-commit ==3.1.1 ; extra == 'dev'
17
- Requires-Dist: pytest ==7.2.2 ; extra == 'dev'
18
- Requires-Dist: pytest-cov ==4.0.0 ; extra == 'dev'
19
- Requires-Dist: pytest-notebook ==0.10.0 ; extra == 'dev'
20
- Requires-Dist: ipykernel ==6.21.3 ; extra == 'dev'
14
+ Requires-Dist: ruff==0.6.2; extra == "dev"
15
+ Requires-Dist: mypy==1.11.2; extra == "dev"
16
+ Requires-Dist: pre-commit==3.1.1; extra == "dev"
17
+ Requires-Dist: pytest==7.2.2; extra == "dev"
18
+ Requires-Dist: pytest-cov==4.0.0; extra == "dev"
19
+ Requires-Dist: pytest_notebook==0.10.0; extra == "dev"
20
+ Requires-Dist: ipykernel==6.21.3; extra == "dev"
21
21
 
22
22
  [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
23
23
  [![PyPI version](https://badge.fury.io/py/atlas-ftag-tools.svg)](https://badge.fury.io/py/atlas-ftag-tools)
@@ -1,12 +1,12 @@
1
- ftag/__init__.py,sha256=k5qBmtC7Ieh0trgm2Ba9Qj_6A2wQSpmAmXo2iIOAaI0,737
1
+ ftag/__init__.py,sha256=YRug5UslRbNoQACbEhdenDS6wXmsmeLjlz4JaKP6eHs,737
2
2
  ftag/cli_utils.py,sha256=w3TtQmUHSyAKChS3ewvOtcSDAUJAZGIIomaNi8f446U,298
3
3
  ftag/cuts.py,sha256=9_ooLZHaO3SnIQBNxwbaPZn-qptGdKnB27FdKQGTiTY,2933
4
4
  ftag/flavours.py,sha256=ShH4M2UjQZpZ_NlCctTm2q1tJbzYxjmGteioQ2GcqEU,114
5
- ftag/flavours.yaml,sha256=E_vpn38qJ3-Tygg2aHlH4wkn_rR1On_lMeaG8OemHCQ,8285
5
+ ftag/flavours.yaml,sha256=87xBvLkMDkicuRMaXtxcao8gjEAgvlTbgjAzpvx4YFM,9021
6
6
  ftag/git_check.py,sha256=Y-XqM80CVXZ5ZKrDdZcYOJt3X64uU6W3OP6Z0D7AZU0,1663
7
7
  ftag/labeller.py,sha256=IXUgU9UBir39PxVWRKs5r5fqI66Tv0x7nJD3-RYpbrg,2780
8
8
  ftag/labels.py,sha256=C7IylPTnc32dFXq8C2Ks2wuljYK3WaY2EsPLGrhtXy8,3932
9
- ftag/mock.py,sha256=Eyj3tkkaSSnqvS3G6NS7fq8sB__Nx8YE9-OM2_lpdoQ,4992
9
+ ftag/mock.py,sha256=P2D7nNKAz2jRBbmfpHTDj9sBVU9r7HGd0rpWZOJYZ90,5980
10
10
  ftag/region.py,sha256=ANv0dGI2W6NJqD9fp7EfqAUReH4FOjc1gwl_Qn8llcM,360
11
11
  ftag/sample.py,sha256=3N0FrRcu9l1sX8ohuGOHuMYGD0See6gMO4--7NzR2tE,2538
12
12
  ftag/track_selector.py,sha256=fJNk_kIBQriBqV4CPT_3ReJbOUnavDDzO-u3EQlRuyk,2654
@@ -19,10 +19,10 @@ ftag/hdf5/h5split.py,sha256=4Wy6Xc3J58MdD9aBaSZHf5ZcVFnJSkWsm42R5Pgo-R4,2448
19
19
  ftag/hdf5/h5utils.py,sha256=-4zKTMtNCrDZr_9Ww7uzfsB7M7muBKpmm_1IkKJnHOI,3222
20
20
  ftag/hdf5/h5writer.py,sha256=9FkClV__UbBqmFsq_h2jwiZnbWVm8QFRL_4mDZZBbTs,5316
21
21
  ftag/wps/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
22
- ftag/wps/discriminant.py,sha256=VJdZlJJUwaTeyxmIDEk23rQSAuvWs6wDA3XRjDI6-_c,4277
23
- ftag/wps/working_points.py,sha256=cvStSpP8Cbb_FWM8v59tFsscUvdeqi831tLn5BiHUEg,9741
24
- atlas_ftag_tools-0.2.7.dist-info/METADATA,sha256=oo5m85dK467AWuK-L8xaIbLDVmRUO3r7vA_J1vgR5b8,5169
25
- atlas_ftag_tools-0.2.7.dist-info/WHEEL,sha256=R06PA3UVYHThwHvxuRWMqaGcr-PuniXahwjmQRFMEkY,91
26
- atlas_ftag_tools-0.2.7.dist-info/entry_points.txt,sha256=LfVLsZHQolqbPnwPgtmc5IQTh527BKkN2v-IpXWTNHw,137
27
- atlas_ftag_tools-0.2.7.dist-info/top_level.txt,sha256=qiYQuKcAvMim-31FwkT3MTQu7WQm0s58tPAia5KKWqs,5
28
- atlas_ftag_tools-0.2.7.dist-info/RECORD,,
22
+ ftag/wps/discriminant.py,sha256=GKa0zZlLREdm0mCYSbcWXITYe3VEn3PXOBQiPg5WvgM,2521
23
+ ftag/wps/working_points.py,sha256=jXyikB-bf73EaYFkngjE977-Ytvb9nDTqIdHxWW6WQQ,15960
24
+ atlas_ftag_tools-0.2.9.dist-info/METADATA,sha256=lXC-e0iHMDtvJH8h3i7PcCEKh4_CFz5vlqdGXKSEoV4,5153
25
+ atlas_ftag_tools-0.2.9.dist-info/WHEEL,sha256=beeZ86-EfXScwlR_HKu4SllMC9wUEj_8Z_4FJ3egI2w,91
26
+ atlas_ftag_tools-0.2.9.dist-info/entry_points.txt,sha256=LfVLsZHQolqbPnwPgtmc5IQTh527BKkN2v-IpXWTNHw,137
27
+ atlas_ftag_tools-0.2.9.dist-info/top_level.txt,sha256=qiYQuKcAvMim-31FwkT3MTQu7WQm0s58tPAia5KKWqs,5
28
+ atlas_ftag_tools-0.2.9.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.5.0)
2
+ Generator: setuptools (76.1.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
ftag/__init__.py CHANGED
@@ -2,7 +2,7 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- __version__ = "v0.2.7"
5
+ __version__ = "v0.2.9"
6
6
 
7
7
  from ftag import hdf5
8
8
  from ftag.cuts import Cuts
ftag/flavours.yaml CHANGED
@@ -60,12 +60,24 @@
60
60
  colour: tab:orange
61
61
  category: single-btag-ghost
62
62
  _px: pc
63
- - name: ghostujets
64
- label: Light-jets
65
- cuts: ["HadronGhostTruthLabelID == 0"]
63
+ - name: ghostsjets
64
+ label: $s$-jets
65
+ cuts: ["HadronGhostTruthLabelID == 0", "PartonTruthLabelID == 3"]
66
+ colour: tab:red
67
+ category: single-btag-ghost
68
+ _px: ps
69
+ - name: ghostudjets
70
+ label: Light-quark-jets
71
+ cuts: ["HadronGhostTruthLabelID == 0", "PartonTruthLabelID <= 2"]
66
72
  colour: tab:green
67
73
  category: single-btag-ghost
68
- _px: pu
74
+ _px: pud
75
+ - name: ghostgjets
76
+ label: Gluon-jets
77
+ cuts: ["HadronGhostTruthLabelID == 0", "PartonTruthLabelID == 21"]
78
+ colour: tab:gray
79
+ category: single-btag-ghost
80
+ _px: pg
69
81
  - name: ghosttaujets
70
82
  label: $\tau$-jets
71
83
  cuts: ["HadronGhostTruthLabelID == 15"]
@@ -119,6 +131,21 @@
119
131
  cuts: ["R10TruthLabel_R22v1 == 10", "GhostBHadronsFinalCount == 0", "GhostCHadronsFinalCount == 0"]
120
132
  colour: "green"
121
133
  category: xbb
134
+ - name: htauel
135
+ label: $H \rightarrow \tau e$
136
+ cuts: ["R10TruthLabel_R22v1 == 14"]
137
+ colour: "#b40612"
138
+ category: xbb
139
+ - name: htaumu
140
+ label: $H \rightarrow \tau\mu$
141
+ cuts: ["R10TruthLabel_R22v1 == 15"]
142
+ colour: "#b40657"
143
+ category: xbb
144
+ - name: htauhad
145
+ label: $H \rightarrow \tau\tau$
146
+ cuts: ["R10TruthLabel_R22v1 == 16"]
147
+ colour: "#b406a0"
148
+ category: xbb
122
149
 
123
150
  # extended Xbb tagging
124
151
  - name: tqqb
@@ -272,7 +299,7 @@
272
299
  category: isolation
273
300
  - name: npxall
274
301
  label: non-prompt lepton
275
- cuts: ["iffClass notin (2,3,4,11)"]
302
+ cuts: ["iffClass notin (0,1,2,3,4,11)"]
276
303
  colour: "#264653"
277
304
  category: isolation
278
305
  - name: npxtau
ftag/mock.py CHANGED
@@ -54,33 +54,74 @@ TRACK_VARS = [
54
54
  ]
55
55
 
56
56
 
57
- def softmax(x, axis=None):
57
+ def softmax(x: np.ndarray, axis: int | None = None) -> np.ndarray:
58
+ """Softmax function for numpy arrays.
59
+
60
+ Parameters
61
+ ----------
62
+ x : np.ndarray
63
+ Input array for the softmax
64
+ axis : int | None, optional
65
+ Axis along which the softmax is calculated, by default None
66
+
67
+ Returns
68
+ -------
69
+ np.ndarray
70
+ Output array with the softmax output
71
+ """
58
72
  e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
59
73
  return e_x / e_x.sum(axis=axis, keepdims=True)
60
74
 
61
75
 
62
- def get_mock_scores(labels: np.ndarray, is_xbb: bool = False):
63
- means = [
64
- [2, 0, 0, 0],
65
- [0, 1, 0, 0],
66
- [0, 0, 3.5, 0],
67
- [0, 0, 0, 1],
68
- ]
76
+ def get_mock_scores(labels: np.ndarray, is_xbb: bool = False) -> np.ndarray:
69
77
  if not is_xbb:
70
78
  label_dict = {"u": 0, "c": 4, "b": 5, "tau": 15}
71
- label_mapping = dict(zip(label_dict.values(), means))
72
- else:
73
- label_dict = {"hbb": 11, "hcc": 12, "top": 1, "qcd": 10}
74
- label_mapping = dict(zip(label_dict.values(), means))
75
79
 
80
+ else:
81
+ label_dict = {
82
+ "hbb": 11,
83
+ "hcc": 12,
84
+ "top": 1,
85
+ "qcd": 10,
86
+ "htauel": 14,
87
+ "htaumu": 15,
88
+ "htauhad": 16,
89
+ }
90
+
91
+ # Set random seed
76
92
  rng = np.random.default_rng(42)
77
- nclass = len(label_dict)
78
- scores = np.zeros((len(labels), nclass))
79
- scales = [1, 2.5, 5, 1]
93
+
94
+ # Set a list of possible means/scales
95
+ mean_scale_list = [1, 2, 2.5, 3.5]
96
+
97
+ # Get the number of classes
98
+ n_classes = len(label_dict)
99
+
100
+ # Init a scores array
101
+ scores = np.zeros((len(labels), n_classes))
102
+
103
+ # Generate means/scales
104
+ means = []
105
+ scales = []
106
+ for i in range(n_classes):
107
+ tmp_means = []
108
+ tmp_means = [
109
+ 0 if j != i else mean_scale_list[np.random.randint(0, len(mean_scale_list))]
110
+ for j in range(n_classes)
111
+ ]
112
+ means.append(tmp_means)
113
+ scales.append(mean_scale_list[np.random.randint(0, len(mean_scale_list))])
114
+
115
+ # Map the labels to the means
116
+ label_mapping = dict(zip(label_dict.values(), means))
117
+
118
+ # Generate random mock scores
80
119
  for i, (label, count) in enumerate(zip(*np.unique(labels, return_counts=True))):
81
120
  scores[labels == label] = rng.normal(
82
- loc=label_mapping[label], scale=scales[i], size=(count, nclass)
121
+ loc=label_mapping[label], scale=scales[i], size=(count, n_classes)
83
122
  )
123
+
124
+ # Pipe scores through softmax
84
125
  scores = softmax(scores, axis=1)
85
126
  name = "MockXbbTagger" if is_xbb else "MockTagger"
86
127
  cols = [f"{name}_p{x}" for x in label_dict]
@@ -103,7 +144,7 @@ def mock_jets(num_jets=1000) -> np.ndarray:
103
144
  jets["HadronConeExclTruthLabelID"] = rng.choice([0, 4, 5, 15], size=num_jets)
104
145
  jets["GhostBHadronsFinalCount"] = rng.choice([0, 1, 2], size=num_jets)
105
146
  jets["GhostCHadronsFinalCount"] = rng.choice([0, 1, 2], size=num_jets)
106
- jets["R10TruthLabel_R22v1"] = rng.choice([1, 10, 11, 12], size=num_jets)
147
+ jets["R10TruthLabel_R22v1"] = rng.choice([1, 10, 11, 12, 14, 15, 16], size=num_jets)
107
148
  scores = get_mock_scores(jets["HadronConeExclTruthLabelID"])
108
149
  xbb_scores = get_mock_scores(jets["R10TruthLabel_R22v1"], is_xbb=True)
109
150
  return join_structured_arrays([jets, scores, xbb_scores])
ftag/wps/discriminant.py CHANGED
@@ -1,22 +1,22 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import Callable
3
+ from typing import TYPE_CHECKING
4
4
 
5
5
  import numpy as np
6
6
 
7
- from ftag import Flavours
8
- from ftag.labels import Label, remove_suffix
7
+ if TYPE_CHECKING: # pragma: no cover
8
+ from ftag.labels import Label, LabelContainer
9
9
 
10
10
 
11
- def discriminant(
11
+ def get_discriminant(
12
12
  jets: np.ndarray,
13
13
  tagger: str,
14
14
  signal: Label,
15
- fxs: dict[str, float],
15
+ flavours: LabelContainer,
16
+ fraction_values: dict[str, float],
16
17
  epsilon: float = 1e-10,
17
18
  ) -> np.ndarray:
18
- """
19
- Get the tagging discriminant.
19
+ """Calculate the tagging discriminant for a given tagger.
20
20
 
21
21
  Calculated as the logarithm of the ratio of a specified signal probability
22
22
  to a weighted sum ofbackground probabilities.
@@ -24,108 +24,61 @@ def discriminant(
24
24
  Parameters
25
25
  ----------
26
26
  jets : np.ndarray
27
- Structed jet array containing tagger scores.
27
+ Structured array of jets containing tagger outputs
28
28
  tagger : str
29
- Name of the tagger, used to construct field names.
30
- signal : str
31
- Type of signal.
32
- fxs : dict[str, float]
33
- Dict of background probability names and their fractions.
34
- If a fraction is None, it is calculated as (1 - sum of provided fractions).
29
+ Name of the tagger
30
+ signal : Label
31
+ Signal flavour (bjets/cjets or hbb/hcc)
32
+ fraction_values : dict
33
+ Dict with the fraction values for the background classes for the given tagger
35
34
  epsilon : float, optional
36
- A small value added to probabilities to prevent division by zero, by default 1e-10.
35
+ Small number to avoid division by zero, by default 1e-10
37
36
 
38
37
  Returns
39
38
  -------
40
39
  np.ndarray
41
- The tagger discriminant values for the jets.
40
+ Array of discriminant values.
42
41
 
43
42
  Raises
44
43
  ------
45
44
  ValueError
46
- If a fraction is specified for a denominator that is not present in the input array.
45
+ If the signal flavour is not recognised.
47
46
  """
47
+ # Init the denominator
48
48
  denominator = 0.0
49
- for d, fx in fxs.items():
50
- name = f"{tagger}_{d}"
51
- if fx > 0 and name not in jets.dtype.names:
52
- raise ValueError(f"Nonzero fx for {d}, but '{name}' not found in input array.")
53
- denominator += jets[name] * fx if name in jets.dtype.names else 0
54
- signal_field = f"{tagger}_{signal.px}"
55
- if signal_field not in jets.dtype.names:
56
- signal_field = f"{tagger}_p{remove_suffix(signal.name, 'jets')}"
57
- return np.log((jets[signal_field] + epsilon) / (denominator + epsilon))
58
-
59
-
60
- def tautag_dicriminant(jets, tagger, fb, fc, epsilon=1e-10):
61
- fxs = {"pb": fb, "pc": fc, "pu": 1 - fb - fc}
62
- return discriminant(jets, tagger, Flavours.taujets, fxs, epsilon=epsilon)
63
-
64
-
65
- def btag_discriminant(jets, tagger, fc, ftau=0, epsilon=1e-10):
66
- fxs = {"pc": fc, "ptau": ftau, "pu": 1 - fc - ftau}
67
- return discriminant(jets, tagger, Flavours.bjets, fxs, epsilon=epsilon)
68
-
69
49
 
70
- def ghostbtag_discriminant(jets, tagger, fc, ftau=0, epsilon=1e-10):
71
- fxs = {"pghostc": fc, "pghosttau": ftau, "pghostu": 1 - fc - ftau}
72
- return discriminant(jets, tagger, Flavours.ghostbjets, fxs, epsilon=epsilon)
50
+ # Loop over background flavours
51
+ for flav in flavours:
52
+ # Skip signal flavour for denominator
53
+ if flav == signal:
54
+ continue
73
55
 
56
+ # Get the probability name of the tagger/flavour combo + fraction value
57
+ prob_name = f"{tagger}_{flav.px}"
58
+ fraction_value = fraction_values[flav.frac_str]
74
59
 
75
- def ctag_discriminant(jets, tagger, fb, ftau=0, epsilon=1e-10):
76
- fxs = {"pb": fb, "ptau": ftau, "pu": 1 - fb - ftau}
77
- return discriminant(jets, tagger, Flavours.cjets, fxs, epsilon=epsilon)
60
+ # If fraction_value for the given flavour is zero, skip it
61
+ if fraction_value == 0:
62
+ continue
78
63
 
64
+ # Check that the probability value for the flavour is available
65
+ if fraction_value > 0 and prob_name not in jets.dtype.names:
66
+ raise ValueError(
67
+ f"Nonzero fraction value for {flav.name}, but '{prob_name}' "
68
+ "not found in input array."
69
+ )
79
70
 
80
- def hbb_discriminant(jets, tagger, ftop=0.25, fhcc=0.02, epsilon=1e-10):
81
- fxs = {"phcc": fhcc, "ptop": ftop, "pqcd": 1 - ftop - fhcc}
82
- return discriminant(jets, tagger, Flavours.hbb, fxs, epsilon=epsilon)
83
-
84
-
85
- def hcc_discriminant(jets, tagger, ftop=0.25, fhbb=0.3, epsilon=1e-10):
86
- fxs = {"phbb": fhbb, "ptop": ftop, "pqcd": 1 - ftop - fhbb}
87
- return discriminant(jets, tagger, Flavours.hcc, fxs, epsilon=epsilon)
88
-
89
-
90
- def get_discriminant(
91
- jets: np.ndarray, tagger: str, signal: Label | str, epsilon: float = 1e-10, **fxs
92
- ):
93
- """Calculate the b-tag or c-tag discriminant for a given tagger.
71
+ # Update denominator
72
+ denominator += jets[prob_name] * fraction_value if prob_name in jets.dtype.names else 0
94
73
 
95
- Parameters
96
- ----------
97
- jets : np.ndarray
98
- Structured array of jets containing tagger outputs
99
- tagger : str
100
- Name of the tagger
101
- signal : Label
102
- Signal flavour (bjets/cjets or hbb/hcc)
103
- epsilon : float, optional
104
- Small number to avoid division by zero, by default 1e-10
105
- **fxs : dict
106
- Fractions for the different background flavours.
74
+ # Calculate numerator
75
+ signal_field = f"{tagger}_{signal.px}"
107
76
 
108
- Returns
109
- -------
110
- np.ndarray
111
- Array of discriminant values.
77
+ # Check that the probability of the signal is available
78
+ if signal_field not in jets.dtype.names:
79
+ raise ValueError(
80
+ f"No signal probability value(s) found for tagger {tagger}. "
81
+ f"Missing variable: {signal_field}"
82
+ )
112
83
 
113
- Raises
114
- ------
115
- ValueError
116
- If the signal flavour is not recognised.
117
- """
118
- tagger_funcs: dict[str, Callable] = {
119
- "bjets": btag_discriminant,
120
- "cjets": ctag_discriminant,
121
- "taujets": tautag_dicriminant,
122
- "hbb": hbb_discriminant,
123
- "hcc": hcc_discriminant,
124
- "ghostbjets": ghostbtag_discriminant,
125
- }
126
-
127
- if str(signal) not in tagger_funcs:
128
- raise ValueError(f"Signal flavour must be one of {list(tagger_funcs.keys())}, not {signal}")
129
-
130
- func: Callable = tagger_funcs[str(Flavours[signal])]
131
- return func(jets, tagger, **fxs, epsilon=epsilon)
84
+ return np.log((jets[signal_field] + epsilon) / (denominator + epsilon))
@@ -3,7 +3,9 @@
3
3
  from __future__ import annotations
4
4
 
5
5
  import argparse
6
+ import sys
6
7
  from pathlib import Path
8
+ from typing import TYPE_CHECKING
7
9
 
8
10
  import numpy as np
9
11
  import yaml
@@ -14,30 +16,109 @@ from ftag.cuts import Cuts
14
16
  from ftag.hdf5 import H5Reader
15
17
  from ftag.wps.discriminant import get_discriminant
16
18
 
19
+ if TYPE_CHECKING: # pragma: no cover
20
+ from collections.abc import Sequence
21
+
22
+ from ftag.labels import Label, LabelContainer
23
+
24
+
25
+ def parse_args(args: Sequence[str]) -> argparse.Namespace:
26
+ """Parse the input arguments into a Namespace.
27
+
28
+ Parameters
29
+ ----------
30
+ args : Sequence[str] | None
31
+ Sequence of string inputs to the script
32
+
33
+ Returns
34
+ -------
35
+ argparse.Namespace
36
+ Namespace with the parsed arguments
37
+
38
+ Raises
39
+ ------
40
+ ValueError
41
+ When both --effs and --disc_cuts are provided
42
+ ValueError
43
+ When neither --effs nor --disc_cuts are provided
44
+ ValueError
45
+ When the number of fraction values is not conistent
46
+ ValueError
47
+ When the sum of fraction values for a tagger is not equal to one
48
+ """
49
+ # Define the pre-parser which checks the --category
50
+ pre_parser = argparse.ArgumentParser(add_help=False)
51
+ pre_parser.add_argument(
52
+ "-c",
53
+ "--category",
54
+ default="single-btag",
55
+ type=str,
56
+ help="Label category to use for the working point calculation",
57
+ )
58
+
59
+ pre_parser.add_argument(
60
+ "-s",
61
+ "--signal",
62
+ default="bjets",
63
+ type=str,
64
+ help="Signal flavour which is to be used",
65
+ )
66
+
67
+ # Parse only --category/--signal and ignore for now all other args
68
+ pre_args, remaining_argv = pre_parser.parse_known_args(args=args)
17
69
 
18
- def parse_args(args):
70
+ # Create the "real" parser
19
71
  parser = argparse.ArgumentParser(
20
72
  description=__doc__,
21
73
  formatter_class=HelpFormatter,
22
74
  )
75
+
76
+ # Add --category/--signal so the help is correctly shown
77
+ parser.add_argument(
78
+ "-c",
79
+ "--category",
80
+ default="single-btag",
81
+ type=str,
82
+ help="Label category to use for the working point calculation",
83
+ )
84
+ parser.add_argument(
85
+ "-s",
86
+ "--signal",
87
+ default="bjets",
88
+ type=str,
89
+ help="Signal flavour which is to be used",
90
+ )
91
+
92
+ # Check which label category was chosen and load the corresponding flavours
93
+ flavours = Flavours.by_category(pre_args.category)
94
+
95
+ # Build the fraction value arguments for all classes (besides signal)
96
+ for flav in flavours:
97
+ # Skip signal
98
+ if flav.name == pre_args.signal:
99
+ continue
100
+
101
+ # Built fraction values for all background classes
102
+ parser.add_argument(
103
+ f"--{flav.frac_str}",
104
+ nargs="+",
105
+ required=True,
106
+ type=float,
107
+ help=f"{flav.frac_str} value(s) for each tagger",
108
+ )
109
+
110
+ # # Adding the other arguments
23
111
  parser.add_argument(
24
112
  "--ttbar",
25
113
  required=True,
26
114
  type=Path,
27
- help="path to ttbar sample (supports globbing)",
115
+ help="Path to ttbar sample (supports globbing)",
28
116
  )
29
117
  parser.add_argument(
30
118
  "--zprime",
31
119
  required=False,
32
120
  type=Path,
33
- help="path to zprime (supports globbing). WPs from ttbar will be reused for zprime",
34
- )
35
- parser.add_argument(
36
- "-e",
37
- "--effs",
38
- nargs="+",
39
- type=float,
40
- help="efficiency working point(s). If -r is specified, values should be 1/efficiency",
121
+ help="Path to zprime (supports globbing). WPs from ttbar will be reused for zprime",
41
122
  )
42
123
  parser.add_argument(
43
124
  "-t",
@@ -48,19 +129,17 @@ def parse_args(args):
48
129
  help="tagger name(s)",
49
130
  )
50
131
  parser.add_argument(
51
- "-s",
52
- "--signal",
53
- default="bjets",
54
- choices=["bjets", "cjets", "hbb", "hcc"],
55
- type=str,
56
- help='signal flavour ("bjets" or "cjets" for b-tagging, "hbb" or "hcc" for Xbb)',
132
+ "-e",
133
+ "--effs",
134
+ nargs="+",
135
+ type=float,
136
+ help="Efficiency working point(s). If -r is specified, values should be 1/efficiency",
57
137
  )
58
138
  parser.add_argument(
59
139
  "-r",
60
140
  "--rejection",
61
141
  default=None,
62
- choices=["ujets", "cjets", "bjets", "hbb", "hcc", "top", "qcd"],
63
- help="use rejection of specified background class to determine working points",
142
+ help="Use rejection of specified background class to determine working points",
64
143
  )
65
144
  parser.add_argument(
66
145
  "-d",
@@ -74,219 +153,357 @@ def parse_args(args):
74
153
  "--num_jets",
75
154
  default=1_000_000,
76
155
  type=int,
77
- help="use this many jets (post selection)",
156
+ help="Use this many jets (post selection)",
78
157
  )
79
158
  parser.add_argument(
80
159
  "--ttbar_cuts",
81
160
  nargs="+",
82
161
  default=["pt > 20e3"],
83
162
  type=list,
84
- help="selection to apply to ttbar (|eta| < 2.5 is always applied)",
163
+ help="Selection to apply to ttbar (|eta| < 2.5 is always applied)",
85
164
  )
86
165
  parser.add_argument(
87
166
  "--zprime_cuts",
88
167
  nargs="+",
89
168
  default=["pt > 250e3"],
90
169
  type=list,
91
- help="selection to apply to zprime (|eta| < 2.5 is always applied)",
170
+ help="Selection to apply to zprime (|eta| < 2.5 is always applied)",
92
171
  )
93
172
  parser.add_argument(
94
173
  "-o",
95
174
  "--outfile",
96
175
  type=Path,
97
- help="save results to yaml instead of printing",
98
- )
99
- parser.add_argument(
100
- "--xbb",
101
- action="store_true",
102
- help="Enable Xbb tagging which expects two fx values ftop and fhcc/fhbb for each tagger",
103
- )
104
- parser.add_argument(
105
- "--fb",
106
- nargs="+",
107
- type=float,
108
- help="fb value(s) for each tagger",
109
- )
110
- parser.add_argument(
111
- "--fc",
112
- nargs="+",
113
- type=float,
114
- help="fc value(s) for each tagger",
115
- )
116
- parser.add_argument(
117
- "--ftau",
118
- nargs="+",
119
- type=float,
120
- help="ftau value(s) for each tagger",
121
- )
122
- parser.add_argument(
123
- "--ftop",
124
- nargs="+",
125
- type=float,
126
- help="ftop value(s) for each tagger",
176
+ help="Save results to yaml instead of printing",
127
177
  )
128
- parser.add_argument(
129
- "--fhbb",
130
- nargs="+",
131
- type=float,
132
- help="fhbb value(s) for each tagger",
133
- )
134
- parser.add_argument(
135
- "--fhcc",
136
- nargs="+",
137
- type=float,
138
- help="fhcc value(s) for each tagger",
139
- )
140
- args = parser.parse_args(args)
141
178
 
142
- args.signal = Flavours[args.signal]
179
+ # Final parse of all arguments
180
+ parsed_args = parser.parse_args(remaining_argv)
181
+
182
+ # Define the signal as an instance of Flavours
183
+ parsed_args.signal = Flavours[parsed_args.signal]
143
184
 
144
- if args.effs and args.disc_cuts:
185
+ # Check that only --effs or --disc_cuts is given
186
+ if parsed_args.effs and parsed_args.disc_cuts:
145
187
  raise ValueError("Cannot specify both --effs and --disc_cuts")
146
- if not args.effs and not args.disc_cuts:
188
+ if not parsed_args.effs and not parsed_args.disc_cuts:
147
189
  raise ValueError("Must specify either --effs or --disc_cuts")
148
190
 
149
- if args.xbb:
150
- if args.signal not in {Flavours.hbb, Flavours.hcc}:
151
- raise ValueError("Xbb tagging only supports hbb or hcc signal flavours")
152
- if args.fb or args.fc or args.ftau:
153
- raise ValueError("For Xbb tagging, fb, fc and ftau should not be specified")
154
- if not args.ftop:
155
- raise ValueError("For Xbb tagging, ftop should be specified")
156
- if args.signal == "hbb" and not args.fhcc:
157
- raise ValueError("For hbb tagging, fhcc should be specified")
158
- if args.signal == "hcc" and not args.fhbb:
159
- raise ValueError("For hcc tagging, fhbb should be specified")
160
- else:
161
- if args.ftop or args.fhbb or args.fhcc:
162
- raise ValueError("For single-b tagging, ftop, fhbb and fhcc should not be specified")
163
- if args.signal == "bjets" and not args.fc:
164
- raise ValueError("For bjets tagging, fc should be specified")
165
- if args.signal == "cjets" and not args.fb:
166
- raise ValueError("For cjets tagging, fb should be specified")
167
- if args.ftau is None:
168
- args.ftau = [0.0] * len(args.tagger)
169
-
170
- for fx in ["fb", "fc", "ftau", "ftop", "fhbb", "fhcc"]:
171
- if getattr(args, fx) and len(getattr(args, fx)) != len(args.tagger):
172
- raise ValueError(f"Number of {fx} values must match number of taggers")
173
-
174
- return args
175
-
176
-
177
- def get_fxs_from_args(args):
178
- if args.signal == Flavours.bjets:
179
- fxs = {"fc": args.fc, "ftau": args.ftau}
180
- elif args.signal == Flavours.cjets:
181
- fxs = {"fb": args.fb, "ftau": args.ftau}
182
- elif args.signal == Flavours.hbb:
183
- fxs = {"ftop": args.ftop, "fhcc": args.fhcc}
184
- elif args.signal == Flavours.hcc:
185
- fxs = {"ftop": args.ftop, "fhbb": args.fhbb}
186
- assert fxs is not None
187
- return [{k: v[i] for k, v in fxs.items()} for i in range(len(args.tagger))]
188
-
189
-
190
- def get_eff_rej(jets, disc, wp, flavs):
191
- out = {"eff": {}, "rej": {}}
192
- for bkg in list(flavs):
193
- bkg_disc = disc[bkg.cuts(jets).idx]
194
- eff = sum(bkg_disc > wp) / len(bkg_disc)
195
- out["eff"][str(bkg)] = float(f"{eff:.3g}")
196
- out["rej"][str(bkg)] = float(f"{1 / eff:.3g}")
191
+ # Check that all fraction values have the same length
192
+ for flav in flavours:
193
+ if flav.name != parsed_args.signal.name and len(getattr(parsed_args, flav.frac_str)) != len(
194
+ parsed_args.tagger
195
+ ):
196
+ raise ValueError(f"Number of {flav.frac_str} values must match number of taggers")
197
+
198
+ # Check that all fraction value combinations add up to one
199
+ for tagger_idx in range(len(parsed_args.tagger)):
200
+ fraction_value_sum = 0
201
+ for flav in flavours:
202
+ if flav.name != parsed_args.signal.name:
203
+ fraction_value_sum += getattr(parsed_args, flav.frac_str)[tagger_idx]
204
+
205
+ # Round the value to take machine precision into account
206
+ fraction_value_sum = np.round(fraction_value_sum, 8)
207
+
208
+ # Check it's equal to one
209
+ if fraction_value_sum != 1:
210
+ raise ValueError(
211
+ "Sum of the fraction values must be one! You gave "
212
+ f"{fraction_value_sum} for tagger {parsed_args.tagger[tagger_idx]}"
213
+ )
214
+ return parsed_args
215
+
216
+
217
+ def get_fxs_from_args(args: argparse.Namespace, flavours: LabelContainer) -> list:
218
+ """Get the fraction values for each tagger from the argparsed inputs.
219
+
220
+ Parameters
221
+ ----------
222
+ args : argparse.Namespace
223
+ Input arguments parsed by the argparser
224
+ flavours : LabelContainer
225
+ LabelContainer instance of the labels that are used
226
+
227
+ Returns
228
+ -------
229
+ list
230
+ List of dicts with the fraction values. Each dict is for one tagger.
231
+ """
232
+ # Init the fraction_dict dict
233
+ fraction_dict = {}
234
+
235
+ # Add the fraction values to the dict
236
+ for flav in flavours:
237
+ if flav.name != args.signal.name:
238
+ fraction_dict[flav.frac_str] = vars(args)[flav.frac_str]
239
+
240
+ return [{k: v[i] for k, v in fraction_dict.items()} for i in range(len(args.tagger))]
241
+
242
+
243
+ def get_eff_rej(
244
+ jets: np.ndarray,
245
+ disc: np.ndarray,
246
+ wp: float,
247
+ flavours: LabelContainer,
248
+ ) -> dict:
249
+ """Calculate the efficiency/rejection for each flavour.
250
+
251
+ Parameters
252
+ ----------
253
+ jets : np.ndarray
254
+ Loaded jets
255
+ disc : np.ndarray
256
+ Discriminant values of the jets
257
+ wp : float
258
+ Working point that is used
259
+ flavours : LabelContainer
260
+ LabelContainer instance of the flavours used
261
+
262
+ Returns
263
+ -------
264
+ dict
265
+ Dict with the efficiency/rejection values for each flavour
266
+ """
267
+ # Init an out dict
268
+ out: dict[str, dict] = {"eff": {}, "rej": {}}
269
+
270
+ # Loop over the flavours
271
+ for flav in flavours:
272
+ # Calculate discriminant values and efficiencies/rejections
273
+ flav_disc = disc[flav.cuts(jets).idx]
274
+ eff = sum(flav_disc > wp) / len(flav_disc)
275
+ out["eff"][flav.name] = float(f"{eff:.3g}")
276
+ out["rej"][flav.name] = float(f"{1 / eff:.3g}")
277
+
197
278
  return out
198
279
 
199
280
 
200
- def get_rej_eff_at_disc(jets, tagger, signal, disc_cuts, **fxs):
201
- disc = get_discriminant(jets, tagger, signal, **fxs)
202
- d = {}
203
- flavs = Flavours.by_category("single-btag")
281
+ def get_rej_eff_at_disc(
282
+ jets: np.ndarray,
283
+ tagger: str,
284
+ signal: Label,
285
+ disc_cuts: list,
286
+ flavours: LabelContainer,
287
+ fraction_values: dict,
288
+ ) -> dict:
289
+ """Calculate the efficiency/rejection at a certain discriminant values.
290
+
291
+ Parameters
292
+ ----------
293
+ jets : np.ndarray
294
+ Loaded jets used
295
+ tagger : str
296
+ Name of the tagger
297
+ signal : Label
298
+ Label instance of the signal flavour
299
+ disc_cuts : list
300
+ List of discriminant cut values for which the efficiency/rejection is calculated
301
+ flavours : LabelContainer
302
+ LabelContainer instance of the flavours that are used
303
+
304
+ Returns
305
+ -------
306
+ dict
307
+ Dict with the discriminant cut values and their respective efficiencies/rejections
308
+ """
309
+ # Calculate discriminants
310
+ disc = get_discriminant(
311
+ jets=jets,
312
+ tagger=tagger,
313
+ signal=signal,
314
+ flavours=flavours,
315
+ fraction_values=fraction_values,
316
+ )
317
+
318
+ # Init out dict
319
+ ref_eff_dict: dict[str, dict] = {}
320
+
321
+ # Loop over the disc cut values
204
322
  for dcut in disc_cuts:
205
- d[str(dcut)] = {"eff": {}, "rej": {}}
206
- for f in flavs:
207
- e_discs = disc[f.cuts(jets).idx]
323
+ ref_eff_dict[str(dcut)] = {"eff": {}, "rej": {}}
324
+
325
+ # Loop over the flavours
326
+ for flav in flavours:
327
+ e_discs = disc[flav.cuts(jets).idx]
208
328
  eff = sum(e_discs > dcut) / len(e_discs)
209
- d[str(dcut)]["eff"][str(f)] = float(f"{eff:.3g}")
210
- d[str(dcut)]["rej"][str(f)] = 1 / float(f"{eff:.3g}")
211
- return d
329
+ ref_eff_dict[str(dcut)]["eff"][str(flav)] = float(f"{eff:.3g}")
330
+ ref_eff_dict[str(dcut)]["rej"][str(flav)] = 1 / float(f"{eff:.3g}")
331
+
332
+ return ref_eff_dict
333
+
334
+
335
def setup_common_parts(
    args: argparse.Namespace,
) -> tuple[np.ndarray, np.ndarray | None, LabelContainer]:
    """Load the jets from the files and setup the taggers.

    Parameters
    ----------
    args : argparse.Namespace
        Input arguments from the argparser. Reads ``category``, ``ttbar``,
        ``ttbar_cuts``, ``zprime``, ``zprime_cuts``, ``tagger`` and ``num_jets``.

    Returns
    -------
    tuple[np.ndarray, np.ndarray | None, LabelContainer]
        The ttbar jets, the zprime jets (``None`` if ``args.zprime`` is not
        set), and the flavours of ``args.category``.
    """
    # Get the used flavours
    flavours = Flavours.by_category(args.category)

    # Get the cuts for the samples; the -2.5 < eta < 2.5 cut is always applied
    # on top of the user-supplied cuts
    default_cuts = Cuts.from_list(["eta > -2.5", "eta < 2.5"])
    ttbar_cuts = Cuts.from_list(args.ttbar_cuts) + default_cuts
    zprime_cuts = Cuts.from_list(args.zprime_cuts) + default_cuts

    # Variables needed to evaluate the flavour label cuts
    all_vars = [var for flav in flavours for var in flav.cuts.variables]

    # Add, for every tagger, the per-flavour output variables that exist in the file
    reader = H5Reader(args.ttbar)
    jet_vars = set(reader.dtypes()["jets"].names)
    for tagger in args.tagger:
        all_vars += [
            f"{tagger}_{flav.px}" for flav in flavours if f"{tagger}_{flav.px}" in jet_vars
        ]

    # De-duplicate while keeping a deterministic (first-seen) order; this also
    # protects against the same tagger being requested twice
    all_vars = list(dict.fromkeys(all_vars))

    # Load ttbar jets
    ttbar_jets = reader.load({"jets": all_vars}, args.num_jets, cuts=ttbar_cuts)["jets"]

    # Load zprime jets only if a zprime file was given
    zprime_jets = None
    if args.zprime:
        zprime_reader = H5Reader(args.zprime)
        zprime_jets = zprime_reader.load(
            {"jets": all_vars}, args.num_jets, cuts=zprime_cuts
        )["jets"]

    return ttbar_jets, zprime_jets, flavours
235
385
 
386
def get_working_points(args: argparse.Namespace) -> dict | None:
    """Calculate the working points.

    For every tagger in ``args.tagger`` and every value in ``args.effs``, the
    discriminant cut value is derived from the ttbar jets. The efficiencies and
    rejections at that cut are then evaluated on the ttbar jets and, if
    ``args.zprime`` is set, on the zprime jets.

    Parameters
    ----------
    args : argparse.Namespace
        Input arguments from the argparser

    Returns
    -------
    dict | None
        Dict with the working points. If args.outfile is given, the dict is
        stored as a yaml file in args.outfile and the function returns None.
    """
    # Load the jets and flavours and get the fraction values
    ttbar_jets, zprime_jets, flavours = setup_common_parts(args=args)
    fraction_values = get_fxs_from_args(args=args, flavours=flavours)

    # Init an out dict
    out = {}

    # Loop over taggers
    for i, tagger in enumerate(args.tagger):
        # Calculate the ttbar discriminant
        out[tagger] = {"signal": str(args.signal), **fraction_values[i]}
        disc = get_discriminant(
            jets=ttbar_jets,
            tagger=tagger,
            signal=args.signal,
            flavours=flavours,
            fraction_values=fraction_values[i],
        )

        # The zprime discriminant does not depend on the working point, so it
        # is computed once per tagger rather than once per efficiency value
        zprime_disc = None
        if args.zprime:
            zprime_disc = get_discriminant(
                jets=zprime_jets,
                tagger=tagger,
                signal=args.signal,
                flavours=flavours,
                fraction_values=fraction_values[i],
            )

        # Loop over efficiency working points
        for eff in args.effs:
            d = out[tagger][f"{eff:.0f}"] = {}

            # Set the working point. In rejection mode the requested value is a
            # rejection of the args.rejection flavour, converted here to an
            # efficiency in percent (without rebinding the loop variable)
            wp_flavour = args.signal
            wp_eff = eff
            if args.rejection:
                wp_eff = 100 / eff
                wp_flavour = args.rejection

            # Calculate the discriminant value of the working point
            wp_disc = disc[flavours[wp_flavour].cuts(ttbar_jets).idx]
            wp = d["cut_value"] = round(float(np.percentile(wp_disc, 100 - wp_eff)), 3)

            # Calculate efficiency and rejection for each flavour
            d["ttbar"] = get_eff_rej(
                jets=ttbar_jets,
                disc=disc,
                wp=wp,
                flavours=flavours,
            )

            # Calculate for zprime using the cut value derived on ttbar
            if args.zprime:
                d["zprime"] = get_eff_rej(
                    jets=zprime_jets,
                    disc=zprime_disc,
                    wp=wp,
                    flavours=flavours,
                )

    if args.outfile:
        with open(args.outfile, "w", encoding="utf-8") as f:
            yaml.dump(out, f, sort_keys=False)
        return None

    return out
274
464
 
275
465
 
276
- def get_efficiencies(args=None):
277
- jets, zp_jets, _ = setup_common_parts(args)
278
- fxs = get_fxs_from_args(args)
466
+ def get_efficiencies(args: argparse.Namespace) -> dict | None:
467
+ """Calculate the efficiencies for the given jets.
468
+
469
+ Parameters
470
+ ----------
471
+ args : argparse.Namespace
472
+ Input arguments from the argparser
279
473
 
474
+ Returns
475
+ -------
476
+ dict | None
477
+ Dict with the efficiencies. If args.outfile is given, the function returns None and
478
+ stored the resulting dict in a yaml file in args.outfile.
479
+ """
480
+ # Load the jets and flavours and get the fraction values
481
+ ttbar_jets, zprime_jets, flavours = setup_common_parts(args=args)
482
+ fraction_values = get_fxs_from_args(args=args, flavours=flavours)
483
+
484
+ # Init an out dict
280
485
  out = {}
486
+
487
+ # Loop over the taggers
281
488
  for i, tagger in enumerate(args.tagger):
282
- out[tagger] = {"signal": str(args.signal), **fxs[i]}
489
+ out[tagger] = {"signal": str(args.signal), **fraction_values[i]}
283
490
 
284
491
  out[tagger]["ttbar"] = get_rej_eff_at_disc(
285
- jets, tagger, args.signal, args.disc_cuts, **fxs[i]
492
+ jets=ttbar_jets,
493
+ tagger=tagger,
494
+ signal=args.signal,
495
+ disc_cuts=args.disc_cuts,
496
+ flavours=flavours,
497
+ fraction_values=fraction_values[i],
286
498
  )
287
499
  if args.zprime:
288
500
  out[tagger]["zprime"] = get_rej_eff_at_disc(
289
- zp_jets, tagger, args.signal, args.disc_cuts, **fxs[i]
501
+ jets=zprime_jets,
502
+ tagger=tagger,
503
+ signal=args.signal,
504
+ disc_cuts=args.disc_cuts,
505
+ flavours=flavours,
506
+ fraction_values=fraction_values[i],
290
507
  )
291
508
 
292
509
  if args.outfile:
@@ -297,13 +514,27 @@ def get_efficiencies(args=None):
297
514
  return out
298
515
 
299
516
 
300
- def main(args=None):
301
- args = parse_args(args)
517
+ def main(args: Sequence[str]) -> dict | None:
518
+ """Main function to run working point calculation.
519
+
520
+ Parameters
521
+ ----------
522
+ args : Sequence[str] | None, optional
523
+ Input arguments, by default None
524
+
525
+ Returns
526
+ -------
527
+ dict | None
528
+ The output dict with the calculated values. When --outfile
529
+ was given, the return value is None
530
+ """
531
+ parsed_args = parse_args(args=args)
532
+
533
+ if parsed_args.effs:
534
+ out = get_working_points(args=parsed_args)
302
535
 
303
- if args.effs:
304
- out = get_working_points(args)
305
- elif args.disc_cuts:
306
- out = get_efficiencies(args)
536
+ elif parsed_args.disc_cuts:
537
+ out = get_efficiencies(args=parsed_args)
307
538
 
308
539
  if out:
309
540
  print(yaml.dump(out, sort_keys=False))
@@ -312,5 +543,5 @@ def main(args=None):
312
543
  return None
313
544
 
314
545
 
315
if __name__ == "__main__":  # pragma: no cover
    # Hand everything after the program name over to main().
    cli_args = sys.argv[1:]
    main(args=cli_args)