atlas-ftag-tools 0.2.8__py3-none-any.whl → 0.2.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.2
1
+ Metadata-Version: 2.4
2
2
  Name: atlas-ftag-tools
3
- Version: 0.2.8
3
+ Version: 0.2.10
4
4
  Summary: ATLAS Flavour Tagging Tools
5
5
  Author: Sam Van Stroud, Philipp Gadow
6
6
  License: MIT
@@ -8,8 +8,9 @@ Project-URL: Homepage, https://github.com/umami-hep/atlas-ftag-tools/
8
8
  Requires-Python: <3.12,>=3.8
9
9
  Description-Content-Type: text/markdown
10
10
  Requires-Dist: h5py>=3.0
11
- Requires-Dist: numpy
11
+ Requires-Dist: numpy>=2.2.3
12
12
  Requires-Dist: PyYAML>=5.1
13
+ Requires-Dist: scipy>=1.15.2
13
14
  Provides-Extra: dev
14
15
  Requires-Dist: ruff==0.6.2; extra == "dev"
15
16
  Requires-Dist: mypy==1.11.2; extra == "dev"
@@ -1,28 +1,30 @@
1
- ftag/__init__.py,sha256=QQHtJR1oF0VAd1zVUKsPimr4TLlVf4ymtbKBcaUPra0,737
1
+ ftag/__init__.py,sha256=v9emuK48Hhd-_TCiirfCNMsZSzk52frz1zEOgk9PViQ,787
2
2
  ftag/cli_utils.py,sha256=w3TtQmUHSyAKChS3ewvOtcSDAUJAZGIIomaNi8f446U,298
3
3
  ftag/cuts.py,sha256=9_ooLZHaO3SnIQBNxwbaPZn-qptGdKnB27FdKQGTiTY,2933
4
4
  ftag/flavours.py,sha256=ShH4M2UjQZpZ_NlCctTm2q1tJbzYxjmGteioQ2GcqEU,114
5
- ftag/flavours.yaml,sha256=r_5c6SOwYWBCN4qOaz7ZLLEBmbod-ErTDsGGBXJiyZA,8289
5
+ ftag/flavours.yaml,sha256=5Lo9KWe-2KzmGMbc7o_X9gzwUyTl0Q5uVHYExduZ6T4,9502
6
+ ftag/fraction_optimization.py,sha256=IlMEJe5fD0soX40f-LO4dYAYld2gMqgZRuBLctoPn9A,5566
6
7
  ftag/git_check.py,sha256=Y-XqM80CVXZ5ZKrDdZcYOJt3X64uU6W3OP6Z0D7AZU0,1663
7
8
  ftag/labeller.py,sha256=IXUgU9UBir39PxVWRKs5r5fqI66Tv0x7nJD3-RYpbrg,2780
8
- ftag/labels.py,sha256=C7IylPTnc32dFXq8C2Ks2wuljYK3WaY2EsPLGrhtXy8,3932
9
- ftag/mock.py,sha256=Eyj3tkkaSSnqvS3G6NS7fq8sB__Nx8YE9-OM2_lpdoQ,4992
9
+ ftag/labels.py,sha256=2nmcmrZD8mWQPxJsGiOgcLDhSVgWfS_cEzqsBV-Qy8o,4198
10
+ ftag/mock.py,sha256=P2D7nNKAz2jRBbmfpHTDj9sBVU9r7HGd0rpWZOJYZ90,5980
10
11
  ftag/region.py,sha256=ANv0dGI2W6NJqD9fp7EfqAUReH4FOjc1gwl_Qn8llcM,360
11
12
  ftag/sample.py,sha256=3N0FrRcu9l1sX8ohuGOHuMYGD0See6gMO4--7NzR2tE,2538
12
13
  ftag/track_selector.py,sha256=fJNk_kIBQriBqV4CPT_3ReJbOUnavDDzO-u3EQlRuyk,2654
13
14
  ftag/transform.py,sha256=uEGGJSnqoKOzLYQv650XdK_kDNw4Aw-5dc60z9Dp_y0,3963
14
15
  ftag/vds.py,sha256=nRViQZQIORB95nC7NZsW3KsSoGkLzEdOsuCViH5h8-U,3296
16
+ ftag/working_points.py,sha256=RJws2jPMEDQDspCbXUZBifS1CCBmlMJ5ax0eMyDzCRA,15949
15
17
  ftag/hdf5/__init__.py,sha256=LFDNxVOCp58SvLHwQhdT68Q-KBMS_i6jBrbXoRpHzbM,354
16
18
  ftag/hdf5/h5move.py,sha256=oYpRu0IDCIJIQ2ML52HBAdoyDxmKkHTeM9JdbPEgKfI,947
17
19
  ftag/hdf5/h5reader.py,sha256=i31pDAqmOSaxdeRhc4iSBlld8xJ0pmp4rNd7CugNzw0,13706
18
20
  ftag/hdf5/h5split.py,sha256=4Wy6Xc3J58MdD9aBaSZHf5ZcVFnJSkWsm42R5Pgo-R4,2448
19
21
  ftag/hdf5/h5utils.py,sha256=-4zKTMtNCrDZr_9Ww7uzfsB7M7muBKpmm_1IkKJnHOI,3222
20
22
  ftag/hdf5/h5writer.py,sha256=9FkClV__UbBqmFsq_h2jwiZnbWVm8QFRL_4mDZZBbTs,5316
21
- ftag/wps/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
22
- ftag/wps/discriminant.py,sha256=VJdZlJJUwaTeyxmIDEk23rQSAuvWs6wDA3XRjDI6-_c,4277
23
- ftag/wps/working_points.py,sha256=cvStSpP8Cbb_FWM8v59tFsscUvdeqi831tLn5BiHUEg,9741
24
- atlas_ftag_tools-0.2.8.dist-info/METADATA,sha256=PhA8ikzMnWQUOSQxW9PzmxzteUaxCwutISL0JR4UWZY,5153
25
- atlas_ftag_tools-0.2.8.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
26
- atlas_ftag_tools-0.2.8.dist-info/entry_points.txt,sha256=LfVLsZHQolqbPnwPgtmc5IQTh527BKkN2v-IpXWTNHw,137
27
- atlas_ftag_tools-0.2.8.dist-info/top_level.txt,sha256=qiYQuKcAvMim-31FwkT3MTQu7WQm0s58tPAia5KKWqs,5
28
- atlas_ftag_tools-0.2.8.dist-info/RECORD,,
23
+ ftag/utils/__init__.py,sha256=C0PgaA6Nk5WVpFqKhBhrHgj2mwsKJbSxoO6Cl67RsaI,544
24
+ ftag/utils/logging.py,sha256=54NaQiC9Bh4vSznSqzoPfR-7tj1PXfmoH7yKgv_ZHZk,3192
25
+ ftag/utils/metrics.py,sha256=zQI4nPeRDSyzqKpdOPmu0GU560xSWoW1wgL13rrja-I,12664
26
+ atlas_ftag_tools-0.2.10.dist-info/METADATA,sha256=VUhrtQML6_bUKlmZNFlUXxTTt5YBzNYupTrdlaF5IAw,5190
27
+ atlas_ftag_tools-0.2.10.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
28
+ atlas_ftag_tools-0.2.10.dist-info/entry_points.txt,sha256=b46bVP_O8Mg6aSdPmyjGgVkaXSdyXZMeKAsofh2IDeA,133
29
+ atlas_ftag_tools-0.2.10.dist-info/top_level.txt,sha256=qiYQuKcAvMim-31FwkT3MTQu7WQm0s58tPAia5KKWqs,5
30
+ atlas_ftag_tools-0.2.10.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.8.0)
2
+ Generator: setuptools (78.1.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -2,4 +2,4 @@
2
2
  h5move = ftag.hdf5.h5move:main
3
3
  h5split = ftag.hdf5.h5split:main
4
4
  vds = ftag.vds:main
5
- wps = ftag.wps.working_points:main
5
+ wps = ftag.working_points:main
ftag/__init__.py CHANGED
@@ -2,18 +2,18 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- __version__ = "v0.2.8"
5
+ __version__ = "v0.2.10"
6
6
 
7
- from ftag import hdf5
7
+ from ftag import hdf5, utils
8
8
  from ftag.cuts import Cuts
9
9
  from ftag.flavours import Flavours
10
+ from ftag.fraction_optimization import calculate_best_fraction_values
10
11
  from ftag.labeller import Labeller
11
12
  from ftag.labels import Label, LabelContainer
12
13
  from ftag.mock import get_mock_file
13
14
  from ftag.sample import Sample
14
15
  from ftag.transform import Transform
15
- from ftag.wps.discriminant import get_discriminant
16
- from ftag.wps.working_points import get_working_points
16
+ from ftag.working_points import get_working_points
17
17
 
18
18
  __all__ = [
19
19
  "Cuts",
@@ -24,8 +24,9 @@ __all__ = [
24
24
  "Sample",
25
25
  "Transform",
26
26
  "__version__",
27
- "get_discriminant",
27
+ "calculate_best_fraction_values",
28
28
  "get_mock_file",
29
29
  "get_working_points",
30
30
  "hdf5",
31
+ "utils",
31
32
  ]
ftag/flavours.yaml CHANGED
@@ -60,12 +60,24 @@
60
60
  colour: tab:orange
61
61
  category: single-btag-ghost
62
62
  _px: pc
63
- - name: ghostujets
64
- label: Light-jets
65
- cuts: ["HadronGhostTruthLabelID == 0"]
63
+ - name: ghostsjets
64
+ label: $s$-jets
65
+ cuts: ["HadronGhostTruthLabelID == 0", "PartonTruthLabelID == 3"]
66
+ colour: tab:red
67
+ category: single-btag-ghost
68
+ _px: ps
69
+ - name: ghostudjets
70
+ label: Light-quark-jets
71
+ cuts: ["HadronGhostTruthLabelID == 0", "PartonTruthLabelID <= 2"]
66
72
  colour: tab:green
67
73
  category: single-btag-ghost
68
- _px: pu
74
+ _px: pud
75
+ - name: ghostgjets
76
+ label: Gluon-jets
77
+ cuts: ["HadronGhostTruthLabelID == 0", "PartonTruthLabelID == 21"]
78
+ colour: tab:gray
79
+ category: single-btag-ghost
80
+ _px: pg
69
81
  - name: ghosttaujets
70
82
  label: $\tau$-jets
71
83
  cuts: ["HadronGhostTruthLabelID == 15"]
@@ -119,6 +131,21 @@
119
131
  cuts: ["R10TruthLabel_R22v1 == 10", "GhostBHadronsFinalCount == 0", "GhostCHadronsFinalCount == 0"]
120
132
  colour: "green"
121
133
  category: xbb
134
+ - name: htauel
135
+ label: $H \rightarrow \tau e$
136
+ cuts: ["R10TruthLabel_R22v1 == 14"]
137
+ colour: "#b40612"
138
+ category: xbb
139
+ - name: htaumu
140
+ label: $H \rightarrow \tau\mu$
141
+ cuts: ["R10TruthLabel_R22v1 == 15"]
142
+ colour: "#b40657"
143
+ category: xbb
144
+ - name: htauhad
145
+ label: $H \rightarrow \tau\tau$
146
+ cuts: ["R10TruthLabel_R22v1 == 16"]
147
+ colour: "#b406a0"
148
+ category: xbb
122
149
 
123
150
  # extended Xbb tagging
124
151
  - name: tqqb
@@ -305,3 +332,19 @@
305
332
  cuts: ["iffClass == 0"]
306
333
  colour: tab:gray
307
334
  category: isolation
335
+ # Trigger-Xbb tagging
336
+ - name: dRMatchedHbb
337
+ label: $H \rightarrow b\bar{b}$
338
+ cuts: ["HadronConeExclExtendedTruthLabelID == 55", "n_truth_higgs > 0", "n_truth_top == 0"]
339
+ colour: tab:blue
340
+ category: trigger-xbb
341
+ - name: dRMatchedTop
342
+ label: Inclusive Top
343
+ cuts: ["n_truth_higgs == 0", "n_truth_top > 0"]
344
+ colour: "#A300A3"
345
+ category: trigger-xbb
346
+ - name: dRMatchedQCD
347
+ label: QCD
348
+ cuts: ["n_truth_higgs == 0", "n_truth_top == 0"]
349
+ colour: "#38761D"
350
+ category: trigger-xbb
@@ -0,0 +1,184 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING
4
+
5
+ import numpy as np
6
+ from scipy.optimize import minimize
7
+
8
+ from ftag import Flavours
9
+ from ftag.utils import calculate_rejection, get_discriminant, logger
10
+
11
+ if TYPE_CHECKING: # pragma: no cover
12
+ from ftag.labels import Label, LabelContainer
13
+
14
+
15
def convert_dict(
    fraction_values: dict | np.ndarray,
    backgrounds: LabelContainer,
) -> np.ndarray | dict:
    """Convert background fraction values between dict and array form.

    Parameters
    ----------
    fraction_values : dict | np.ndarray
        Either a dict mapping each background's ``frac_str`` to its fraction
        value, or an array of fraction values ordered like ``backgrounds``.
    backgrounds : LabelContainer
        Background labels; their iteration order defines the array order.

    Returns
    -------
    np.ndarray | dict
        For dict input: an array of the values, ordered like ``backgrounds``.
        For array input: a dict mapping ``frac_str`` to the value divided by the
        sum of all values (uniform fractions if that sum is zero).

    Raises
    ------
    ValueError
        If ``fraction_values`` is neither a dict nor an ``np.ndarray``.
    """
    if isinstance(fraction_values, dict):
        return np.array([fraction_values[iter_bkg.frac_str] for iter_bkg in backgrounds])

    if isinstance(fraction_values, np.ndarray):
        frac_strs = [iter_bkg.frac_str for iter_bkg in backgrounds]

        # Compute the normalisation once instead of once per element
        total = float(np.sum(fraction_values))
        if total == 0 and len(fraction_values) > 0:
            # An all-zero vector (reachable at the optimizer's lower bounds) has no
            # normalisation; treat it as "no preference" instead of returning NaNs
            normalised = [1 / len(fraction_values)] * len(fraction_values)
        else:
            normalised = [float(frac_value / total) for frac_value in fraction_values]

        return dict(zip(frac_strs, normalised))

    raise ValueError(
        f"Only input of type `dict` or `np.ndarray` are accepted! You gave {type(fraction_values)}"
    )
32
+
33
+
34
def get_bkg_norm_dict(
    jets: np.ndarray,
    tagger: str,
    signal: Label,
    flavours: LabelContainer,
    working_point: float,
) -> dict:
    """Compute, per background class, the rejection used to normalise the objective.

    For each background, a fraction dict is built that gives that background a
    fraction of ``1 - 0.01 * n_backgrounds`` and every other background ``0.01``.
    The discriminant built from those fractions is then used to compute that
    background's rejection at the ``working_point`` signal efficiency.

    Parameters
    ----------
    jets : np.ndarray
        Jets used for the discriminant and the label cuts.
    tagger : str
        Name of the tagger whose outputs build the discriminant.
    signal : Label
        Signal label.
    flavours : LabelContainer
        Labels from which the backgrounds of ``signal`` are taken.
    working_point : float
        Target signal efficiency for the rejection calculation.

    Returns
    -------
    dict
        Maps each background ``name`` to its rejection.
    """
    backgrounds = flavours.backgrounds(signal)
    dominant_fraction = 1 - (0.01 * len(backgrounds))

    # Selector for signal jets (whatever ``cuts(...).idx`` returns is used as an index)
    signal_selection = signal.cuts(jets).idx

    norms = {}
    for target_bkg in backgrounds:
        fractions = {}
        for other_bkg in backgrounds:
            fractions[other_bkg.frac_str] = (
                dominant_fraction if other_bkg == target_bkg else 0.01
            )

        discriminant = get_discriminant(
            jets=jets,
            tagger=tagger,
            signal=signal,
            flavours=flavours,
            fraction_values=fractions,
        )

        norms[target_bkg.name] = calculate_rejection(
            sig_disc=discriminant[signal_selection],
            bkg_disc=discriminant[target_bkg.cuts(jets).idx],
            target_eff=working_point,
        )

    return norms
75
+
76
+
77
def calculate_rejection_sum(
    fraction_dict: dict | np.ndarray,
    jets: np.ndarray,
    tagger: str,
    signal: Label,
    flavours: LabelContainer,
    working_point: float,
    bkg_norm_dict: dict,
    rejection_weights: dict,
) -> float:
    """Objective for the fraction optimisation.

    Returns the negative of the weighted sum of normalised background
    rejections, so that a minimizer maximises the rejection.

    Parameters
    ----------
    fraction_dict : dict | np.ndarray
        Fraction values. An array (as passed by ``scipy.optimize.minimize``) is
        converted to a dict via ``convert_dict``, which normalises it by its sum.
    jets : np.ndarray
        Jets used for the discriminant and the label cuts.
    tagger : str
        Name of the tagger whose outputs build the discriminant.
    signal : Label
        Signal label.
    flavours : LabelContainer
        Labels from which the backgrounds of ``signal`` are taken.
    working_point : float
        Target signal efficiency for the rejection calculation.
    bkg_norm_dict : dict
        Per-background ``name`` normalisation, e.g. from ``get_bkg_norm_dict``.
    rejection_weights : dict
        Per-background ``name`` weight in the sum.

    Returns
    -------
    float
        ``-sum(rejection / norm * weight)`` over all backgrounds.
    """
    backgrounds = flavours.backgrounds(signal)
    is_signal = signal.cuts(jets).idx

    # The minimizer passes arrays; the discriminant expects a fraction dict
    if isinstance(fraction_dict, np.ndarray):
        fraction_dict = convert_dict(
            fraction_values=fraction_dict,
            backgrounds=backgrounds,
        )

    disc = get_discriminant(
        jets=jets,
        tagger=tagger,
        signal=signal,
        flavours=flavours,
        fraction_values=fraction_dict,
    )

    # The signal discriminant values are the same for every background and this
    # function is evaluated many times by the optimizer: select them only once
    sig_disc = disc[is_signal]

    # Running total of the weighted, normalised rejections
    sum_bkg_rej = 0
    for iter_bkg in backgrounds:
        rejection = calculate_rejection(
            sig_disc=sig_disc,
            bkg_disc=disc[iter_bkg.cuts(jets).idx],
            target_eff=working_point,
        )
        sum_bkg_rej += (rejection / bkg_norm_dict[iter_bkg.name]) * rejection_weights[
            iter_bkg.name
        ]

    # Return the negative sum to enable minimizer
    return -1 * sum_bkg_rej
125
+
126
+
127
def calculate_best_fraction_values(
    jets: np.ndarray,
    tagger: str,
    signal: Label,
    flavours: LabelContainer,
    working_point: float,
    rejection_weights: dict | None = None,
    optimizer_method: str = "Powell",
) -> dict:
    """Find background fraction values that maximise the weighted, normalised rejection.

    Parameters
    ----------
    jets : np.ndarray
        Jets used for the discriminant and the label cuts.
    tagger : str
        Name of the tagger whose outputs build the discriminant.
    signal : Label
        Signal label. A string is also accepted and looked up in ``Flavours``.
    flavours : LabelContainer
        Labels from which the backgrounds of ``signal`` are taken.
    working_point : float
        Target signal efficiency at which the rejections are evaluated.
    rejection_weights : dict | None, optional
        Weight per background ``name`` in the objective. Backgrounds that are
        not listed (or all of them, if None) get a weight of 1.
    optimizer_method : str, optional
        Method passed to ``scipy.optimize.minimize``, by default "Powell".

    Returns
    -------
    dict
        Maps each background's ``frac_str`` to its optimised fraction value,
        normalised by the sum of all fraction values.
    """
    logger.debug("Calculating best fraction values.")
    logger.debug(f"tagger: {tagger}")
    logger.debug(f"signal: {signal}")
    logger.debug(f"flavours: {flavours}")
    logger.debug(f"working_point: {working_point}")
    logger.debug(f"rejection_weights: {rejection_weights}")
    logger.debug(f"optimizer_method: {optimizer_method}")

    # Ensure Label instance
    if isinstance(signal, str):
        signal = Flavours[signal]

    backgrounds = flavours.backgrounds(signal)

    # Start the optimisation from equal fractions for all backgrounds
    def_frac_dict = {iter_bkg.frac_str: 1 / len(backgrounds) for iter_bkg in backgrounds}

    # Every background defaults to weight 1; merging (instead of only handling
    # None) keeps a partially specified dict from raising a KeyError inside the
    # objective function. The caller's dict is not modified.
    default_weights = {iter_bkg.name: 1 for iter_bkg in backgrounds}
    rejection_weights = {**default_weights, **(rejection_weights or {})}

    # Get the normalisation for all bkg rejections
    bkg_norm_dict = get_bkg_norm_dict(
        jets=jets,
        tagger=tagger,
        signal=signal,
        flavours=flavours,
        working_point=working_point,
    )

    # Get the best fraction values combination
    result = minimize(
        fun=calculate_rejection_sum,
        x0=convert_dict(fraction_values=def_frac_dict, backgrounds=backgrounds),
        method=optimizer_method,
        bounds=[(0, 1)] * len(backgrounds),
        args=(jets, tagger, signal, flavours, working_point, bkg_norm_dict, rejection_weights),
    )

    # Get the final fraction dict
    final_frac_dict = convert_dict(fraction_values=result.x, backgrounds=backgrounds)

    logger.info(f"Minimization Success: {result.success}")
    if not result.success:
        # Make an unsuccessful optimisation visible beyond the info level
        logger.warning(f"Minimizer reported: {result.message}")
    logger.info("The following best fraction values were found:")
    for frac_str, frac_value in final_frac_dict.items():
        logger.info(f"{frac_str}: {round(frac_value, ndigits=3)}")

    return final_frac_dict
ftag/labels.py CHANGED
@@ -62,6 +62,9 @@ class LabelContainer:
62
62
  except KeyError as e:
63
63
  raise KeyError(f"Label '{key}' not found") from e
64
64
 
65
+ def __len__(self) -> int:
66
+ return len(self.labels.keys())
67
+
65
68
  def __getattr__(self, name) -> Label:
66
69
  return self[name]
67
70
 
@@ -120,8 +123,13 @@ class LabelContainer:
120
123
  def from_list(cls, labels: list[Label]) -> LabelContainer:
121
124
  return cls({f.name: f for f in labels})
122
125
 
123
def backgrounds(self, signal: Label, only_signals: bool = True) -> LabelContainer:
    """Return the labels in ``signal``'s category, excluding ``signal`` itself.

    Parameters
    ----------
    signal : Label
        Label whose category defines the candidate backgrounds.
    only_signals : bool, optional
        When False, labels named "ujets" or "qcd" are dropped from the result.
        NOTE(review): the flag name reads inverted relative to this behaviour;
        confirm the intent.

    Returns
    -------
    LabelContainer
        Container built from the remaining labels.

    Raises
    ------
    TypeError
        If no background label remains.
    """
    candidates = [f for f in self if f.category == signal.category and f != signal]

    if not only_signals:
        excluded_names = {"ujets", "qcd"}
        candidates = [f for f in candidates if f.name not in excluded_names]

    if not candidates:
        raise TypeError(
            "No background flavour could be found in the flavours for signal "
            f"flavour {signal.name}"
        )

    return LabelContainer.from_list(candidates)
ftag/mock.py CHANGED
@@ -54,33 +54,74 @@ TRACK_VARS = [
54
54
  ]
55
55
 
56
56
 
57
def softmax(x: np.ndarray, axis: int | None = None) -> np.ndarray:
    """Compute the softmax of ``x`` along ``axis``.

    The maximum along ``axis`` is subtracted before exponentiating so that large
    inputs do not overflow.

    Parameters
    ----------
    x : np.ndarray
        Input values.
    axis : int | None, optional
        Axis along which the softmax is normalised; None normalises over all
        elements. By default None.

    Returns
    -------
    np.ndarray
        Array with the same shape as ``x`` whose values sum to one along ``axis``.
    """
    shifted = x - np.max(x, axis=axis, keepdims=True)
    exponentials = np.exp(shifted)
    normaliser = exponentials.sum(axis=axis, keepdims=True)
    return exponentials / normaliser
60
74
 
61
75
 
62
- def get_mock_scores(labels: np.ndarray, is_xbb: bool = False):
63
- means = [
64
- [2, 0, 0, 0],
65
- [0, 1, 0, 0],
66
- [0, 0, 3.5, 0],
67
- [0, 0, 0, 1],
68
- ]
76
+ def get_mock_scores(labels: np.ndarray, is_xbb: bool = False) -> np.ndarray:
69
77
  if not is_xbb:
70
78
  label_dict = {"u": 0, "c": 4, "b": 5, "tau": 15}
71
- label_mapping = dict(zip(label_dict.values(), means))
72
- else:
73
- label_dict = {"hbb": 11, "hcc": 12, "top": 1, "qcd": 10}
74
- label_mapping = dict(zip(label_dict.values(), means))
75
79
 
80
+ else:
81
+ label_dict = {
82
+ "hbb": 11,
83
+ "hcc": 12,
84
+ "top": 1,
85
+ "qcd": 10,
86
+ "htauel": 14,
87
+ "htaumu": 15,
88
+ "htauhad": 16,
89
+ }
90
+
91
+ # Set random seed
76
92
  rng = np.random.default_rng(42)
77
- nclass = len(label_dict)
78
- scores = np.zeros((len(labels), nclass))
79
- scales = [1, 2.5, 5, 1]
93
+
94
+ # Set a list of possible means/scales
95
+ mean_scale_list = [1, 2, 2.5, 3.5]
96
+
97
+ # Get the number of classes
98
+ n_classes = len(label_dict)
99
+
100
+ # Init a scores array
101
+ scores = np.zeros((len(labels), n_classes))
102
+
103
+ # Generate means/scales
104
+ means = []
105
+ scales = []
106
+ for i in range(n_classes):
107
+ tmp_means = []
108
+ tmp_means = [
109
+ 0 if j != i else mean_scale_list[np.random.randint(0, len(mean_scale_list))]
110
+ for j in range(n_classes)
111
+ ]
112
+ means.append(tmp_means)
113
+ scales.append(mean_scale_list[np.random.randint(0, len(mean_scale_list))])
114
+
115
+ # Map the labels to the means
116
+ label_mapping = dict(zip(label_dict.values(), means))
117
+
118
+ # Generate random mock scores
80
119
  for i, (label, count) in enumerate(zip(*np.unique(labels, return_counts=True))):
81
120
  scores[labels == label] = rng.normal(
82
- loc=label_mapping[label], scale=scales[i], size=(count, nclass)
121
+ loc=label_mapping[label], scale=scales[i], size=(count, n_classes)
83
122
  )
123
+
124
+ # Pipe scores through softmax
84
125
  scores = softmax(scores, axis=1)
85
126
  name = "MockXbbTagger" if is_xbb else "MockTagger"
86
127
  cols = [f"{name}_p{x}" for x in label_dict]
@@ -103,7 +144,7 @@ def mock_jets(num_jets=1000) -> np.ndarray:
103
144
  jets["HadronConeExclTruthLabelID"] = rng.choice([0, 4, 5, 15], size=num_jets)
104
145
  jets["GhostBHadronsFinalCount"] = rng.choice([0, 1, 2], size=num_jets)
105
146
  jets["GhostCHadronsFinalCount"] = rng.choice([0, 1, 2], size=num_jets)
106
- jets["R10TruthLabel_R22v1"] = rng.choice([1, 10, 11, 12], size=num_jets)
147
+ jets["R10TruthLabel_R22v1"] = rng.choice([1, 10, 11, 12, 14, 15, 16], size=num_jets)
107
148
  scores = get_mock_scores(jets["HadronConeExclTruthLabelID"])
108
149
  xbb_scores = get_mock_scores(jets["R10TruthLabel_R22v1"], is_xbb=True)
109
150
  return join_structured_arrays([jets, scores, xbb_scores])
ftag/utils/__init__.py ADDED
@@ -0,0 +1,24 @@
1
+ from __future__ import annotations
2
+
3
+ from ftag.utils.logging import logger, set_log_level
4
+ from ftag.utils.metrics import (
5
+ calculate_efficiency,
6
+ calculate_efficiency_error,
7
+ calculate_rejection,
8
+ calculate_rejection_error,
9
+ get_discriminant,
10
+ save_divide,
11
+ weighted_percentile,
12
+ )
13
+
14
# Public API of ftag.utils: logging helpers and metric functions, listed alphabetically
__all__ = [
    "calculate_efficiency",
    "calculate_efficiency_error",
    "calculate_rejection",
    "calculate_rejection_error",
    "get_discriminant",
    "logger",
    "save_divide",
    "set_log_level",
    "weighted_percentile",
]
ftag/utils/logging.py ADDED
@@ -0,0 +1,123 @@
1
+ """Configuration for logger of atlas-ftag-tools."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ import os
7
+ from typing import ClassVar
8
+
9
+
10
class CustomFormatter(logging.Formatter):
    """Formatter that colours each log line with an ANSI escape chosen by level.

    Based on
    https://stackoverflow.com/questions/384076/how-can-i-color-python-logging-output.
    DEBUG, ERROR and CRITICAL records use ``debugformat`` (timestamp plus source
    location); INFO and WARNING records use the shorter ``date_format``.
    """

    # ANSI colour escape sequences
    grey = "\x1b[38;21m"
    yellow = "\x1b[33;21m"
    green = "\x1b[32;21m"
    red = "\x1b[31;21m"
    bold_red = "\x1b[31;1m"
    reset = "\x1b[0m"
    debugformat = "%(asctime)s - %(levelname)s:%(name)s: %(message)s (%(filename)s:%(lineno)d)"
    # Despite its name, this format contains no timestamp
    date_format = "%(levelname)s:%(name)s: %(message)s"

    # Level number -> full (coloured) format string
    formats: ClassVar = {
        logging.DEBUG: grey + debugformat + reset,
        logging.INFO: green + date_format + reset,
        logging.WARNING: yellow + date_format + reset,
        logging.ERROR: red + debugformat + reset,
        logging.CRITICAL: bold_red + debugformat + reset,
    }

    def format(self, record):
        # Unknown levels map to None, i.e. logging's default "%(message)s" format
        level_format = self.formats.get(record.levelno)
        return logging.Formatter(level_format).format(record)
38
+
39
+
40
def get_log_level(
    level: str,
):
    """Translate a level name into the matching ``logging`` level constant.

    Matching ignores case and surrounding whitespace, so ``"info"`` (e.g. from the
    ``LOG_LEVEL`` environment variable) is accepted like ``"INFO"``.

    Parameters
    ----------
    level : str
        Log level name: CRITICAL, ERROR, WARNING, INFO, DEBUG or NOTSET.

    Returns
    -------
    int
        The corresponding ``logging`` level constant.

    Raises
    ------
    ValueError
        If ``level`` is not one of the accepted names.
    """
    log_levels = {
        "CRITICAL": logging.CRITICAL,
        "ERROR": logging.ERROR,
        "WARNING": logging.WARNING,
        "INFO": logging.INFO,
        "DEBUG": logging.DEBUG,
        "NOTSET": logging.NOTSET,
    }
    key = level.strip().upper() if isinstance(level, str) else level
    if key not in log_levels:
        raise ValueError(f"The 'DebugLevel' option {level} is not valid.")
    return log_levels[key]
71
+
72
+
73
def initialise_logger(
    log_level: str | None = None,
):
    """Configure and return the ``atlas-ftag-tools`` logger.

    Parameters
    ----------
    log_level : str, optional
        Logging level: CRITICAL, ERROR, WARNING, INFO, DEBUG or NOTSET. If None,
        the ``LOG_LEVEL`` environment variable is used, falling back to INFO.

    Returns
    -------
    logging.Logger
        The ``atlas-ftag-tools`` logger with the requested level, a coloured
        stream handler and propagation to parent loggers disabled.
    """
    retrieved_log_level = get_log_level(
        os.environ.get("LOG_LEVEL", "INFO") if log_level is None else log_level
    )

    tools_logger = logging.getLogger("atlas-ftag-tools")
    tools_logger.setLevel(retrieved_log_level)

    # logging.getLogger returns the same object on every call, so unconditionally
    # adding a handler would print each message once per call. Reuse the coloured
    # stream handler installed by a previous call if there is one.
    ch_handler = next(
        (
            handler
            for handler in tools_logger.handlers
            if isinstance(handler, logging.StreamHandler)
            and isinstance(handler.formatter, CustomFormatter)
        ),
        None,
    )
    if ch_handler is None:
        ch_handler = logging.StreamHandler()
        ch_handler.setLevel(retrieved_log_level)
        ch_handler.setFormatter(CustomFormatter())
        tools_logger.addHandler(ch_handler)
    else:
        ch_handler.setLevel(retrieved_log_level)

    tools_logger.propagate = False
    return tools_logger
103
+
104
+
105
def set_log_level(
    tools_logger,
    log_level: str,
):
    """Apply a log level to a logger and to every handler attached to it.

    Parameters
    ----------
    tools_logger : logging.Logger
        Logger to update.
    log_level : str
        Logging level: CRITICAL, ERROR, WARNING, INFO, DEBUG or NOTSET.

    Raises
    ------
    ValueError
        If ``log_level`` is rejected by ``get_log_level``; nothing is changed then.
    """
    numeric_level = get_log_level(log_level)
    # The logger first, then each handler, all with the same level
    for target in (tools_logger, *tools_logger.handlers):
        target.setLevel(numeric_level)
121
+
122
+
123
# Package-wide logger, configured at import time. No explicit level is passed,
# so it comes from the LOG_LEVEL environment variable (falling back to INFO).
logger = initialise_logger(log_level=None)