atlas-ftag-tools 0.2.9__py3-none-any.whl → 0.2.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.2
1
+ Metadata-Version: 2.4
2
2
  Name: atlas-ftag-tools
3
- Version: 0.2.9
3
+ Version: 0.2.10
4
4
  Summary: ATLAS Flavour Tagging Tools
5
5
  Author: Sam Van Stroud, Philipp Gadow
6
6
  License: MIT
@@ -8,8 +8,9 @@ Project-URL: Homepage, https://github.com/umami-hep/atlas-ftag-tools/
8
8
  Requires-Python: <3.12,>=3.8
9
9
  Description-Content-Type: text/markdown
10
10
  Requires-Dist: h5py>=3.0
11
- Requires-Dist: numpy
11
+ Requires-Dist: numpy>=2.2.3
12
12
  Requires-Dist: PyYAML>=5.1
13
+ Requires-Dist: scipy>=1.15.2
13
14
  Provides-Extra: dev
14
15
  Requires-Dist: ruff==0.6.2; extra == "dev"
15
16
  Requires-Dist: mypy==1.11.2; extra == "dev"
@@ -1,28 +1,30 @@
1
- ftag/__init__.py,sha256=YRug5UslRbNoQACbEhdenDS6wXmsmeLjlz4JaKP6eHs,737
1
+ ftag/__init__.py,sha256=v9emuK48Hhd-_TCiirfCNMsZSzk52frz1zEOgk9PViQ,787
2
2
  ftag/cli_utils.py,sha256=w3TtQmUHSyAKChS3ewvOtcSDAUJAZGIIomaNi8f446U,298
3
3
  ftag/cuts.py,sha256=9_ooLZHaO3SnIQBNxwbaPZn-qptGdKnB27FdKQGTiTY,2933
4
4
  ftag/flavours.py,sha256=ShH4M2UjQZpZ_NlCctTm2q1tJbzYxjmGteioQ2GcqEU,114
5
- ftag/flavours.yaml,sha256=87xBvLkMDkicuRMaXtxcao8gjEAgvlTbgjAzpvx4YFM,9021
5
+ ftag/flavours.yaml,sha256=5Lo9KWe-2KzmGMbc7o_X9gzwUyTl0Q5uVHYExduZ6T4,9502
6
+ ftag/fraction_optimization.py,sha256=IlMEJe5fD0soX40f-LO4dYAYld2gMqgZRuBLctoPn9A,5566
6
7
  ftag/git_check.py,sha256=Y-XqM80CVXZ5ZKrDdZcYOJt3X64uU6W3OP6Z0D7AZU0,1663
7
8
  ftag/labeller.py,sha256=IXUgU9UBir39PxVWRKs5r5fqI66Tv0x7nJD3-RYpbrg,2780
8
- ftag/labels.py,sha256=C7IylPTnc32dFXq8C2Ks2wuljYK3WaY2EsPLGrhtXy8,3932
9
+ ftag/labels.py,sha256=2nmcmrZD8mWQPxJsGiOgcLDhSVgWfS_cEzqsBV-Qy8o,4198
9
10
  ftag/mock.py,sha256=P2D7nNKAz2jRBbmfpHTDj9sBVU9r7HGd0rpWZOJYZ90,5980
10
11
  ftag/region.py,sha256=ANv0dGI2W6NJqD9fp7EfqAUReH4FOjc1gwl_Qn8llcM,360
11
12
  ftag/sample.py,sha256=3N0FrRcu9l1sX8ohuGOHuMYGD0See6gMO4--7NzR2tE,2538
12
13
  ftag/track_selector.py,sha256=fJNk_kIBQriBqV4CPT_3ReJbOUnavDDzO-u3EQlRuyk,2654
13
14
  ftag/transform.py,sha256=uEGGJSnqoKOzLYQv650XdK_kDNw4Aw-5dc60z9Dp_y0,3963
14
15
  ftag/vds.py,sha256=nRViQZQIORB95nC7NZsW3KsSoGkLzEdOsuCViH5h8-U,3296
16
+ ftag/working_points.py,sha256=RJws2jPMEDQDspCbXUZBifS1CCBmlMJ5ax0eMyDzCRA,15949
15
17
  ftag/hdf5/__init__.py,sha256=LFDNxVOCp58SvLHwQhdT68Q-KBMS_i6jBrbXoRpHzbM,354
16
18
  ftag/hdf5/h5move.py,sha256=oYpRu0IDCIJIQ2ML52HBAdoyDxmKkHTeM9JdbPEgKfI,947
17
19
  ftag/hdf5/h5reader.py,sha256=i31pDAqmOSaxdeRhc4iSBlld8xJ0pmp4rNd7CugNzw0,13706
18
20
  ftag/hdf5/h5split.py,sha256=4Wy6Xc3J58MdD9aBaSZHf5ZcVFnJSkWsm42R5Pgo-R4,2448
19
21
  ftag/hdf5/h5utils.py,sha256=-4zKTMtNCrDZr_9Ww7uzfsB7M7muBKpmm_1IkKJnHOI,3222
20
22
  ftag/hdf5/h5writer.py,sha256=9FkClV__UbBqmFsq_h2jwiZnbWVm8QFRL_4mDZZBbTs,5316
21
- ftag/wps/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
22
- ftag/wps/discriminant.py,sha256=GKa0zZlLREdm0mCYSbcWXITYe3VEn3PXOBQiPg5WvgM,2521
23
- ftag/wps/working_points.py,sha256=jXyikB-bf73EaYFkngjE977-Ytvb9nDTqIdHxWW6WQQ,15960
24
- atlas_ftag_tools-0.2.9.dist-info/METADATA,sha256=lXC-e0iHMDtvJH8h3i7PcCEKh4_CFz5vlqdGXKSEoV4,5153
25
- atlas_ftag_tools-0.2.9.dist-info/WHEEL,sha256=beeZ86-EfXScwlR_HKu4SllMC9wUEj_8Z_4FJ3egI2w,91
26
- atlas_ftag_tools-0.2.9.dist-info/entry_points.txt,sha256=LfVLsZHQolqbPnwPgtmc5IQTh527BKkN2v-IpXWTNHw,137
27
- atlas_ftag_tools-0.2.9.dist-info/top_level.txt,sha256=qiYQuKcAvMim-31FwkT3MTQu7WQm0s58tPAia5KKWqs,5
28
- atlas_ftag_tools-0.2.9.dist-info/RECORD,,
23
+ ftag/utils/__init__.py,sha256=C0PgaA6Nk5WVpFqKhBhrHgj2mwsKJbSxoO6Cl67RsaI,544
24
+ ftag/utils/logging.py,sha256=54NaQiC9Bh4vSznSqzoPfR-7tj1PXfmoH7yKgv_ZHZk,3192
25
+ ftag/utils/metrics.py,sha256=zQI4nPeRDSyzqKpdOPmu0GU560xSWoW1wgL13rrja-I,12664
26
+ atlas_ftag_tools-0.2.10.dist-info/METADATA,sha256=VUhrtQML6_bUKlmZNFlUXxTTt5YBzNYupTrdlaF5IAw,5190
27
+ atlas_ftag_tools-0.2.10.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
28
+ atlas_ftag_tools-0.2.10.dist-info/entry_points.txt,sha256=b46bVP_O8Mg6aSdPmyjGgVkaXSdyXZMeKAsofh2IDeA,133
29
+ atlas_ftag_tools-0.2.10.dist-info/top_level.txt,sha256=qiYQuKcAvMim-31FwkT3MTQu7WQm0s58tPAia5KKWqs,5
30
+ atlas_ftag_tools-0.2.10.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (76.1.0)
2
+ Generator: setuptools (78.1.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -2,4 +2,4 @@
2
2
  h5move = ftag.hdf5.h5move:main
3
3
  h5split = ftag.hdf5.h5split:main
4
4
  vds = ftag.vds:main
5
- wps = ftag.wps.working_points:main
5
+ wps = ftag.working_points:main
ftag/__init__.py CHANGED
@@ -2,18 +2,18 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- __version__ = "v0.2.9"
5
+ __version__ = "v0.2.10"
6
6
 
7
- from ftag import hdf5
7
+ from ftag import hdf5, utils
8
8
  from ftag.cuts import Cuts
9
9
  from ftag.flavours import Flavours
10
+ from ftag.fraction_optimization import calculate_best_fraction_values
10
11
  from ftag.labeller import Labeller
11
12
  from ftag.labels import Label, LabelContainer
12
13
  from ftag.mock import get_mock_file
13
14
  from ftag.sample import Sample
14
15
  from ftag.transform import Transform
15
- from ftag.wps.discriminant import get_discriminant
16
- from ftag.wps.working_points import get_working_points
16
+ from ftag.working_points import get_working_points
17
17
 
18
18
  __all__ = [
19
19
  "Cuts",
@@ -24,8 +24,9 @@ __all__ = [
24
24
  "Sample",
25
25
  "Transform",
26
26
  "__version__",
27
- "get_discriminant",
27
+ "calculate_best_fraction_values",
28
28
  "get_mock_file",
29
29
  "get_working_points",
30
30
  "hdf5",
31
+ "utils",
31
32
  ]
ftag/flavours.yaml CHANGED
@@ -332,3 +332,19 @@
332
332
  cuts: ["iffClass == 0"]
333
333
  colour: tab:gray
334
334
  category: isolation
335
+ # Trigger-Xbb tagging
336
+ - name: dRMatchedHbb
337
+ label: $H \rightarrow b\bar{b}$
338
+ cuts: ["HadronConeExclExtendedTruthLabelID == 55", "n_truth_higgs > 0", "n_truth_top == 0"]
339
+ colour: tab:blue
340
+ category: trigger-xbb
341
+ - name: dRMatchedTop
342
+ label: Inclusive Top
343
+ cuts: ["n_truth_higgs == 0", "n_truth_top > 0"]
344
+ colour: "#A300A3"
345
+ category: trigger-xbb
346
+ - name: dRMatchedQCD
347
+ label: QCD
348
+ cuts: ["n_truth_higgs == 0", "n_truth_top == 0"]
349
+ colour: "#38761D"
350
+ category: trigger-xbb
@@ -0,0 +1,184 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING
4
+
5
+ import numpy as np
6
+ from scipy.optimize import minimize
7
+
8
+ from ftag import Flavours
9
+ from ftag.utils import calculate_rejection, get_discriminant, logger
10
+
11
+ if TYPE_CHECKING: # pragma: no cover
12
+ from ftag.labels import Label, LabelContainer
13
+
14
+
15
+ def convert_dict(
16
+ fraction_values: dict | np.ndarray,
17
+ backgrounds: LabelContainer,
18
+ ) -> np.ndarray | dict:
19
+ if isinstance(fraction_values, dict):
20
+ return np.array([fraction_values[iter_bkg.frac_str] for iter_bkg in backgrounds])
21
+
22
+ if isinstance(fraction_values, np.ndarray):
23
+ fraction_values = [
24
+ float(frac_value / np.sum(fraction_values)) for frac_value in fraction_values
25
+ ]
26
+
27
+ return dict(zip([iter_bkg.frac_str for iter_bkg in backgrounds], fraction_values))
28
+
29
+ raise ValueError(
30
+ f"Only input of type `dict` or `np.ndarray` are accepted! You gave {type(fraction_values)}"
31
+ )
32
+
33
+
34
+ def get_bkg_norm_dict(
35
+ jets: np.ndarray,
36
+ tagger: str,
37
+ signal: Label,
38
+ flavours: LabelContainer,
39
+ working_point: float,
40
+ ) -> dict:
41
+ # Init a dict for the bkg rejection norm values
42
+ bkg_rej_norm = {}
43
+
44
+ # Get the background classes
45
+ backgrounds = flavours.backgrounds(signal)
46
+
47
+ # Define a bool array if the jet is signal
48
+ is_signal = signal.cuts(jets).idx
49
+
50
+ # Loop over backgrounds
51
+ for bkg in backgrounds:
52
+ # Get the fraction value dict to maximize rejection for given class
53
+ frac_dict_bkg = {
54
+ iter_bkg.frac_str: 1 - (0.01 * len(backgrounds)) if iter_bkg == bkg else 0.01
55
+ for iter_bkg in backgrounds
56
+ }
57
+
58
+ # Calculate the disc value using the new fraction dict
59
+ disc = get_discriminant(
60
+ jets=jets,
61
+ tagger=tagger,
62
+ signal=signal,
63
+ flavours=flavours,
64
+ fraction_values=frac_dict_bkg,
65
+ )
66
+
67
+ # Calculate the discriminant
68
+ bkg_rej_norm[bkg.name] = calculate_rejection(
69
+ sig_disc=disc[is_signal],
70
+ bkg_disc=disc[bkg.cuts(jets).idx],
71
+ target_eff=working_point,
72
+ )
73
+
74
+ return bkg_rej_norm
75
+
76
+
77
+ def calculate_rejection_sum(
78
+ fraction_dict: dict | np.ndarray,
79
+ jets: np.ndarray,
80
+ tagger: str,
81
+ signal: Label,
82
+ flavours: LabelContainer,
83
+ working_point: float,
84
+ bkg_norm_dict: dict,
85
+ rejection_weights: dict,
86
+ ) -> float:
87
+ # Get the background classes
88
+ backgrounds = flavours.backgrounds(signal)
89
+
90
+ # Define a bool array if the jet is signal
91
+ is_signal = signal.cuts(jets).idx
92
+
93
+ # Check that the fraction dict is a dict
94
+ if isinstance(fraction_dict, np.ndarray):
95
+ fraction_dict = convert_dict(
96
+ fraction_values=fraction_dict,
97
+ backgrounds=backgrounds,
98
+ )
99
+
100
+ # Calculate discriminant
101
+ disc = get_discriminant(
102
+ jets=jets,
103
+ tagger=tagger,
104
+ signal=signal,
105
+ flavours=flavours,
106
+ fraction_values=fraction_dict,
107
+ )
108
+
109
+ # Init a dict to which the bkg rejs are added
110
+ sum_bkg_rej = 0
111
+
112
+ # Loop over the backgrounds and calculate the rejections
113
+ for iter_bkg in backgrounds:
114
+ sum_bkg_rej += (
115
+ calculate_rejection(
116
+ sig_disc=disc[is_signal],
117
+ bkg_disc=disc[iter_bkg.cuts(jets).idx],
118
+ target_eff=working_point,
119
+ )
120
+ / bkg_norm_dict[iter_bkg.name]
121
+ ) * rejection_weights[iter_bkg.name]
122
+
123
+ # Return the negative sum to enable minimizer
124
+ return -1 * sum_bkg_rej
125
+
126
+
127
+ def calculate_best_fraction_values(
128
+ jets: np.ndarray,
129
+ tagger: str,
130
+ signal: Label,
131
+ flavours: LabelContainer,
132
+ working_point: float,
133
+ rejection_weights: dict | None = None,
134
+ optimizer_method: str = "Powell",
135
+ ) -> dict:
136
+ logger.debug("Calculating best fraction values.")
137
+ logger.debug(f"tagger: {tagger}")
138
+ logger.debug(f"signal: {signal}")
139
+ logger.debug(f"flavours: {flavours}")
140
+ logger.debug(f"working_point: {working_point}")
141
+ logger.debug(f"rejection_weights: {rejection_weights}")
142
+ logger.debug(f"optimizer_method: {optimizer_method}")
143
+
144
+ # Ensure Label instance
145
+ if isinstance(signal, str):
146
+ signal = Flavours[signal]
147
+
148
+ # Get the background classes
149
+ backgrounds = flavours.backgrounds(signal)
150
+
151
+ # Define a default fraction dict
152
+ def_frac_dict = {iter_bkg.frac_str: 1 / len(backgrounds) for iter_bkg in backgrounds}
153
+
154
+ # Define rejection weights if not set
155
+ if rejection_weights is None:
156
+ rejection_weights = {iter_bkg.name: 1 for iter_bkg in backgrounds}
157
+
158
+ # Get the normalisation for all bkg rejections
159
+ bkg_norm_dict = get_bkg_norm_dict(
160
+ jets=jets,
161
+ tagger=tagger,
162
+ signal=signal,
163
+ flavours=flavours,
164
+ working_point=working_point,
165
+ )
166
+
167
+ # Get the best fraction values combination
168
+ result = minimize(
169
+ fun=calculate_rejection_sum,
170
+ x0=convert_dict(fraction_values=def_frac_dict, backgrounds=backgrounds),
171
+ method=optimizer_method,
172
+ bounds=[(0, 1)] * len(backgrounds),
173
+ args=(jets, tagger, signal, flavours, working_point, bkg_norm_dict, rejection_weights),
174
+ )
175
+
176
+ # Get the final fraction dict
177
+ final_frac_dict = convert_dict(fraction_values=result.x, backgrounds=backgrounds)
178
+
179
+ logger.info(f"Minimization Success: {result.success}")
180
+ logger.info("The following best fraction values were found:")
181
+ for frac_str, frac_value in final_frac_dict.items():
182
+ logger.info(f"{frac_str}: {round(frac_value, ndigits=3)}")
183
+
184
+ return final_frac_dict
ftag/labels.py CHANGED
@@ -62,6 +62,9 @@ class LabelContainer:
62
62
  except KeyError as e:
63
63
  raise KeyError(f"Label '{key}' not found") from e
64
64
 
65
+ def __len__(self) -> int:
66
+ return len(self.labels.keys())
67
+
65
68
  def __getattr__(self, name) -> Label:
66
69
  return self[name]
67
70
 
@@ -120,8 +123,13 @@ class LabelContainer:
120
123
  def from_list(cls, labels: list[Label]) -> LabelContainer:
121
124
  return cls({f.name: f for f in labels})
122
125
 
123
- def backgrounds(self, label: Label, only_signals: bool = True) -> LabelContainer:
124
- bkg = [f for f in self if f.category == label.category and f != label]
126
+ def backgrounds(self, signal: Label, only_signals: bool = True) -> LabelContainer:
127
+ bkg = [f for f in self if f.category == signal.category and f != signal]
125
128
  if not only_signals:
126
129
  bkg = [f for f in bkg if f.name not in {"ujets", "qcd"}]
130
+ if len(bkg) == 0:
131
+ raise TypeError(
132
+ "No background flavour could be found in the flavours for signal "
133
+ f"flavour {signal.name}"
134
+ )
127
135
  return LabelContainer.from_list(bkg)
ftag/utils/__init__.py ADDED
@@ -0,0 +1,24 @@
1
+ from __future__ import annotations
2
+
3
+ from ftag.utils.logging import logger, set_log_level
4
+ from ftag.utils.metrics import (
5
+ calculate_efficiency,
6
+ calculate_efficiency_error,
7
+ calculate_rejection,
8
+ calculate_rejection_error,
9
+ get_discriminant,
10
+ save_divide,
11
+ weighted_percentile,
12
+ )
13
+
14
+ __all__ = [
15
+ "calculate_efficiency",
16
+ "calculate_efficiency_error",
17
+ "calculate_rejection",
18
+ "calculate_rejection_error",
19
+ "get_discriminant",
20
+ "logger",
21
+ "save_divide",
22
+ "set_log_level",
23
+ "weighted_percentile",
24
+ ]
ftag/utils/logging.py ADDED
@@ -0,0 +1,123 @@
1
+ """Configuration for logger of atlas-ftag-tools."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ import os
7
+ from typing import ClassVar
8
+
9
+
10
+ class CustomFormatter(logging.Formatter):
11
+ """
12
+ Logging Formatter to add colours and count warning / errors using implementation
13
+ from
14
+ https://stackoverflow.com/questions/384076/how-can-i-color-python-logging-output.
15
+ """
16
+
17
+ grey = "\x1b[38;21m"
18
+ yellow = "\x1b[33;21m"
19
+ green = "\x1b[32;21m"
20
+ red = "\x1b[31;21m"
21
+ bold_red = "\x1b[31;1m"
22
+ reset = "\x1b[0m"
23
+ debugformat = "%(asctime)s - %(levelname)s:%(name)s: %(message)s (%(filename)s:%(lineno)d)"
24
+ date_format = "%(levelname)s:%(name)s: %(message)s"
25
+
26
+ formats: ClassVar = {
27
+ logging.DEBUG: grey + debugformat + reset,
28
+ logging.INFO: green + date_format + reset,
29
+ logging.WARNING: yellow + date_format + reset,
30
+ logging.ERROR: red + debugformat + reset,
31
+ logging.CRITICAL: bold_red + debugformat + reset,
32
+ }
33
+
34
+ def format(self, record):
35
+ log_fmt = self.formats.get(record.levelno)
36
+ formatter = logging.Formatter(log_fmt)
37
+ return formatter.format(record)
38
+
39
+
40
+ def get_log_level(
41
+ level: str,
42
+ ):
43
+ """Get logging levels with string key.
44
+
45
+ Parameters
46
+ ----------
47
+ level : str
48
+ Log level as string.
49
+
50
+ Returns
51
+ -------
52
+ logging level
53
+ logging object with log level info
54
+
55
+ Raises
56
+ ------
57
+ ValueError
58
+ If non-valid option is given
59
+ """
60
+ log_levels = {
61
+ "CRITICAL": logging.CRITICAL,
62
+ "ERROR": logging.ERROR,
63
+ "WARNING": logging.WARNING,
64
+ "INFO": logging.INFO,
65
+ "DEBUG": logging.DEBUG,
66
+ "NOTSET": logging.NOTSET,
67
+ }
68
+ if level not in log_levels:
69
+ raise ValueError(f"The 'DebugLevel' option {level} is not valid.")
70
+ return log_levels[level]
71
+
72
+
73
+ def initialise_logger(
74
+ log_level: str | None = None,
75
+ ):
76
+ """Initialise.
77
+
78
+ Parameters
79
+ ----------
80
+ log_level : str, optional
81
+ Logging level defining the verbose level. Accepted values are:
82
+ CRITICAL, ERROR, WARNING, INFO, DEBUG, NOTSET, by default None
83
+ If the log_level is not set, it will be set to info
84
+
85
+ Returns
86
+ -------
87
+ logger
88
+ logger object with new level set
89
+ """
90
+ retrieved_log_level = get_log_level(
91
+ os.environ.get("LOG_LEVEL", "INFO") if log_level is None else log_level
92
+ )
93
+
94
+ tools_logger = logging.getLogger("atlas-ftag-tools")
95
+ tools_logger.setLevel(retrieved_log_level)
96
+ ch_handler = logging.StreamHandler()
97
+ ch_handler.setLevel(retrieved_log_level)
98
+ ch_handler.setFormatter(CustomFormatter())
99
+
100
+ tools_logger.addHandler(ch_handler)
101
+ tools_logger.propagate = False
102
+ return tools_logger
103
+
104
+
105
+ def set_log_level(
106
+ tools_logger,
107
+ log_level: str,
108
+ ):
109
+ """Setting log level.
110
+
111
+ Parameters
112
+ ----------
113
+ tools_logger : logger
114
+ logger object
115
+ log_level : str
116
+ Logging level corresponding CRITICAL, ERROR, WARNING, INFO, DEBUG, NOTSET
117
+ """
118
+ tools_logger.setLevel(get_log_level(log_level))
119
+ for handler in tools_logger.handlers:
120
+ handler.setLevel(get_log_level(log_level))
121
+
122
+
123
+ logger = initialise_logger()
ftag/utils/metrics.py ADDED
@@ -0,0 +1,431 @@
1
+ """Tools for metrics module."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING
6
+
7
+ import numpy as np
8
+ from scipy.ndimage import gaussian_filter1d
9
+
10
+ from ftag.utils import logger
11
+
12
+ if TYPE_CHECKING: # pragma: no cover
13
+ from ftag.labels import Label, LabelContainer
14
+
15
+
16
+ def save_divide(
17
+ numerator: np.ndarray | float,
18
+ denominator: np.ndarray | float,
19
+ default: float = 1.0,
20
+ ):
21
+ """Save divide for denominator equal to 0.
22
+
23
+ Division using numpy divide function returning default value in cases where
24
+ denominator is 0.
25
+
26
+ Parameters
27
+ ----------
28
+ numerator: np.ndarray | float,
29
+ Numerator in the ratio calculation.
30
+ denominator: np.ndarray | float,
31
+ Denominator in the ratio calculation.
32
+ default: float
33
+ Default value which is returned if denominator is 0.
34
+
35
+ Returns
36
+ -------
37
+ ratio: np.ndarray | float
38
+ Result of the division
39
+ """
40
+ logger.debug("Calculating save division.")
41
+ logger.debug("numerator: %s", numerator)
42
+ logger.debug("denominator: %s", denominator)
43
+ logger.debug("default: %s", default)
44
+
45
+ if isinstance(numerator, (int, float, np.number)) and isinstance(
46
+ denominator, (int, float, np.number)
47
+ ):
48
+ output_shape = 1
49
+ else:
50
+ try:
51
+ output_shape = denominator.shape
52
+ except AttributeError:
53
+ output_shape = numerator.shape
54
+
55
+ ratio = np.divide(
56
+ numerator,
57
+ denominator,
58
+ out=np.ones(
59
+ output_shape,
60
+ dtype=float,
61
+ )
62
+ * default,
63
+ where=(denominator != 0),
64
+ )
65
+ if output_shape == 1:
66
+ return float(ratio)
67
+ return ratio
68
+
69
+
70
+ def weighted_percentile(
71
+ arr: np.ndarray,
72
+ percentile: np.ndarray,
73
+ weights: np.ndarray = None,
74
+ ):
75
+ """Calculate weighted percentile.
76
+
77
+ Implementation according to https://stackoverflow.com/a/29677616/11509698
78
+ (https://en.wikipedia.org/wiki/Percentile#The_weighted_percentile_method)
79
+
80
+ Parameters
81
+ ----------
82
+ arr : np.ndarray
83
+ Data array
84
+ percentile : np.ndarray
85
+ Percentile array
86
+ weights : np.ndarray
87
+ Weights array, by default None
88
+
89
+ Returns
90
+ -------
91
+ np.ndarray
92
+ Weighted percentile array
93
+ """
94
+ logger.debug("Calculating weighted percentile.")
95
+ logger.debug("arr: %s", arr)
96
+ logger.debug("percentile: %s", percentile)
97
+ logger.debug("weights: %s", weights)
98
+
99
+ # Set weights to one if no weights are given
100
+ if weights is None:
101
+ weights = np.ones_like(arr)
102
+
103
+ # Set dtype to float64 if the weights are too large
104
+ dtype = np.float64 if np.sum(weights) > 1000000 else np.float32
105
+
106
+ # Get an array sorting and sort the array and the weights
107
+ ix = np.argsort(arr)
108
+ arr = arr[ix]
109
+ weights = weights[ix]
110
+
111
+ # Return the cumulative sum
112
+ cdf = np.cumsum(weights, dtype=dtype) - 0.5 * weights
113
+ cdf -= cdf[0]
114
+ cdf /= cdf[-1]
115
+
116
+ # Return the linear interpolation
117
+ return np.interp(percentile, cdf, arr)
118
+
119
+
120
+ def calculate_efficiency(
121
+ sig_disc: np.ndarray,
122
+ bkg_disc: np.ndarray,
123
+ target_eff: float | list | np.ndarray,
124
+ return_cuts: bool = False,
125
+ sig_weights: np.ndarray = None,
126
+ bkg_weights: np.ndarray = None,
127
+ ):
128
+ """Calculate efficiency.
129
+
130
+ Parameters
131
+ ----------
132
+ sig_disc : np.ndarray
133
+ Signal discriminant
134
+ bkg_disc : np.ndarray
135
+ Background discriminant
136
+ target_eff : float or list or np.ndarray
137
+ Working point which is used for discriminant calculation
138
+ return_cuts : bool
139
+ Specifies if the cut values corresponding to the provided WPs are returned.
140
+ If target_eff is a float, only one cut value will be returned. If target_eff
141
+ is an array, target_eff is an array as well.
142
+ sig_weights : np.ndarray
143
+ Weights for signal events
144
+ bkg_weights : np.ndarray
145
+ Weights for background events
146
+
147
+ Returns
148
+ -------
149
+ eff : float or np.ndarray
150
+ Efficiency.
151
+ Return float if target_eff is a float, else np.ndarray
152
+ cutvalue : float or np.ndarray
153
+ Cutvalue if return_cuts is True.
154
+ Return float if target_eff is a float, else np.ndarray
155
+ """
156
+ logger.debug("Calculating efficiency.")
157
+ logger.debug("sig_disc: %s", sig_disc)
158
+ logger.debug("bkg_disc: %s", bkg_disc)
159
+ logger.debug("target_eff: %s", target_eff)
160
+ logger.debug("return_cuts: %s", return_cuts)
161
+ logger.debug("sig_weights: %s", sig_weights)
162
+ logger.debug("bkg_weights: %s", bkg_weights)
163
+
164
+ # float | np.ndarray for both target_eff and the returned values
165
+ return_float = False
166
+ if isinstance(target_eff, float):
167
+ return_float = True
168
+
169
+ # Flatten the target efficiencies
170
+ target_eff = np.asarray([target_eff]).flatten()
171
+
172
+ # Get the cutvalue for the given target efficiency
173
+ cutvalue = weighted_percentile(arr=sig_disc, percentile=1.0 - target_eff, weights=sig_weights)
174
+
175
+ # Sort the cutvalues to get the correct order
176
+ sorted_args = np.argsort(1 - target_eff)
177
+
178
+ # Get the histogram for the backgrounds
179
+ hist, _ = np.histogram(bkg_disc, (-np.inf, *cutvalue[sorted_args], np.inf), weights=bkg_weights)
180
+
181
+ # Calculate the efficiencies for the calculated cut values
182
+ eff = hist[::-1].cumsum()[-2::-1] / hist.sum()
183
+ eff = eff[sorted_args]
184
+
185
+ # Ensure that a float is returned if float was given
186
+ if return_float:
187
+ eff = eff[0]
188
+ cutvalue = cutvalue[0]
189
+
190
+ # Also return the cuts if wanted
191
+ if return_cuts:
192
+ return eff, cutvalue
193
+
194
+ return eff
195
+
196
+
197
+ def calculate_rejection(
198
+ sig_disc: np.ndarray,
199
+ bkg_disc: np.ndarray,
200
+ target_eff,
201
+ return_cuts: bool = False,
202
+ sig_weights: np.ndarray = None,
203
+ bkg_weights: np.ndarray = None,
204
+ smooth: bool = False,
205
+ ):
206
+ """Calculate rejection.
207
+
208
+ Parameters
209
+ ----------
210
+ sig_disc : np.ndarray
211
+ Signal discriminant
212
+ bkg_disc : np.ndarray
213
+ Background discriminant
214
+ target_eff : float or list
215
+ Working point which is used for discriminant calculation
216
+ return_cuts : bool
217
+ Specifies if the cut values corresponding to the provided WPs are returned.
218
+ If target_eff is a float, only one cut value will be returned. If target_eff
219
+ is an array, target_eff is an array as well.
220
+ sig_weights : np.ndarray
221
+ Weights for signal events, by default None
222
+ bkg_weights : np.ndarray
223
+ Weights for background events, by default None
224
+
225
+ Returns
226
+ -------
227
+ rej : float or np.ndarray
228
+ Rejection.
229
+ If target_eff is a float, a float is returned if it's a list a np.ndarray
230
+ cut_value : float or np.ndarray
231
+ Cutvalue if return_cuts is True.
232
+ If target_eff is a float, a float is returned if it's a list a np.ndarray
233
+ """
234
+ logger.debug("Calculating rejection.")
235
+ logger.debug("sig_disc: %s", sig_disc)
236
+ logger.debug("bkg_disc: %s", bkg_disc)
237
+ logger.debug("target_eff: %s", target_eff)
238
+ logger.debug("return_cuts: %s", return_cuts)
239
+ logger.debug("sig_weights: %s", sig_weights)
240
+ logger.debug("bkg_weights: %s", bkg_weights)
241
+ logger.debug("smooth: %s", smooth)
242
+
243
+ # Calculate the efficiency
244
+ eff = calculate_efficiency(
245
+ sig_disc=sig_disc,
246
+ bkg_disc=bkg_disc,
247
+ target_eff=target_eff,
248
+ return_cuts=return_cuts,
249
+ sig_weights=sig_weights,
250
+ bkg_weights=bkg_weights,
251
+ )
252
+
253
+ # Invert the efficiency to get a rejection
254
+ rej = save_divide(1, eff[0] if return_cuts else eff, np.inf)
255
+
256
+ # Smooth out the rejection if wanted
257
+ if smooth:
258
+ rej = gaussian_filter1d(rej, sigma=1, radius=2, mode="nearest")
259
+
260
+ # Return also the cut values if wanted
261
+ if return_cuts:
262
+ return rej, eff[1]
263
+
264
+ return rej
265
+
266
+
267
+ def calculate_efficiency_error(
268
+ arr: np.ndarray,
269
+ n_counts: int,
270
+ suppress_zero_divison_error: bool = False,
271
+ norm: bool = False,
272
+ ) -> np.ndarray:
273
+ """Calculate statistical efficiency uncertainty.
274
+
275
+ Parameters
276
+ ----------
277
+ arr : numpy.array
278
+ Efficiency values
279
+ n_counts : int
280
+ Number of used statistics to calculate efficiency
281
+ suppress_zero_divison_error : bool
282
+ Not raising Error for zero division
283
+ norm : bool, optional
284
+ If True, normed (relative) error is being calculated, by default False
285
+
286
+ Returns
287
+ -------
288
+ numpy.array
289
+ Efficiency uncertainties
290
+
291
+ Raises
292
+ ------
293
+ ValueError
294
+ If n_counts <=0
295
+
296
+ Notes
297
+ -----
298
+ This method uses binomial errors as described in section 2.2 of
299
+ https://inspirehep.net/files/57287ac8e45a976ab423f3dd456af694
300
+ """
301
+ logger.debug("Calculating efficiency error.")
302
+ logger.debug("arr: %s", arr)
303
+ logger.debug("n_counts: %i", n_counts)
304
+ logger.debug("suppress_zero_divison_error: %s", suppress_zero_divison_error)
305
+ logger.debug("norm: %s", norm)
306
+ if np.any(n_counts <= 0) and not suppress_zero_divison_error:
307
+ raise ValueError(f"You passed as argument `N` {n_counts} but it has to be larger 0.")
308
+ if norm:
309
+ return np.sqrt(arr * (1 - arr) / n_counts) / arr
310
+ return np.sqrt(arr * (1 - arr) / n_counts)
311
+
312
+
313
+ def calculate_rejection_error(
314
+ arr: np.ndarray,
315
+ n_counts: int,
316
+ norm: bool = False,
317
+ ) -> np.ndarray:
318
+ """Calculate the rejection uncertainties.
319
+
320
+ Parameters
321
+ ----------
322
+ arr : numpy.array
323
+ Rejection values
324
+ n_counts : int
325
+ Number of used statistics to calculate rejection
326
+ norm : bool, optional
327
+ If True, normed (relative) error is being calculated, by default False
328
+
329
+ Returns
330
+ -------
331
+ numpy.array
332
+ Rejection uncertainties
333
+
334
+ Raises
335
+ ------
336
+ ValueError
337
+ If n_counts <=0
338
+ ValueError
339
+ If any rejection value is 0
340
+
341
+ Notes
342
+ -----
343
+ Special case of `eff_err()`
344
+ """
345
+ logger.debug("Calculating rejection error.")
346
+ logger.debug("arr: %s", arr)
347
+ logger.debug("n_counts: %i", n_counts)
348
+ logger.debug("norm: %s", norm)
349
+ if np.any(n_counts <= 0):
350
+ raise ValueError(f"You passed as argument `n_counts` {n_counts} but it has to be larger 0.")
351
+ if np.any(arr == 0):
352
+ raise ValueError("One rejection value is 0, cannot calculate error.")
353
+ if norm:
354
+ return np.power(arr, 2) * calculate_efficiency_error(1 / arr, n_counts) / arr
355
+ return np.power(arr, 2) * calculate_efficiency_error(1 / arr, n_counts)
356
+
357
+
358
+ def get_discriminant(
359
+ jets: np.ndarray,
360
+ tagger: str,
361
+ signal: Label,
362
+ flavours: LabelContainer,
363
+ fraction_values: dict[str, float],
364
+ epsilon: float = 1e-10,
365
+ ) -> np.ndarray:
366
+ """Calculate the tagging discriminant for a given tagger.
367
+
368
+ Calculated as the logarithm of the ratio of a specified signal probability
369
+ to a weighted sum ofbackground probabilities.
370
+
371
+ Parameters
372
+ ----------
373
+ jets : np.ndarray
374
+ Structured array of jets containing tagger outputs
375
+ tagger : str
376
+ Name of the tagger
377
+ signal : Label
378
+ Signal flavour (bjets/cjets or hbb/hcc)
379
+ fraction_values : dict
380
+ Dict with the fraction values for the background classes for the given tagger
381
+ epsilon : float, optional
382
+ Small number to avoid division by zero, by default 1e-10
383
+
384
+ Returns
385
+ -------
386
+ np.ndarray
387
+ Array of discriminant values.
388
+
389
+ Raises
390
+ ------
391
+ ValueError
392
+ If the signal flavour is not recognised.
393
+ """
394
+ # Init the denominator
395
+ denominator = 0.0
396
+
397
+ # Loop over background flavours
398
+ for flav in flavours:
399
+ # Skip signal flavour for denominator
400
+ if flav == signal:
401
+ continue
402
+
403
+ # Get the probability name of the tagger/flavour combo + fraction value
404
+ prob_name = f"{tagger}_{flav.px}"
405
+ fraction_value = fraction_values[flav.frac_str]
406
+
407
+ # If fraction_value for the given flavour is zero, skip it
408
+ if fraction_value == 0:
409
+ continue
410
+
411
+ # Check that the probability value for the flavour is available
412
+ if fraction_value > 0 and prob_name not in jets.dtype.names:
413
+ raise ValueError(
414
+ f"Nonzero fraction value for {flav.name}, but '{prob_name}' "
415
+ "not found in input array."
416
+ )
417
+
418
+ # Update denominator
419
+ denominator += jets[prob_name] * fraction_value if prob_name in jets.dtype.names else 0
420
+
421
+ # Calculate numerator
422
+ signal_field = f"{tagger}_{signal.px}"
423
+
424
+ # Check that the probability of the signal is available
425
+ if signal_field not in jets.dtype.names:
426
+ raise ValueError(
427
+ f"No signal probability value(s) found for tagger {tagger}. "
428
+ f"Missing variable: {signal_field}"
429
+ )
430
+
431
+ return np.log((jets[signal_field] + epsilon) / (denominator + epsilon))
@@ -14,7 +14,7 @@ from ftag import Flavours
14
14
  from ftag.cli_utils import HelpFormatter
15
15
  from ftag.cuts import Cuts
16
16
  from ftag.hdf5 import H5Reader
17
- from ftag.wps.discriminant import get_discriminant
17
+ from ftag.utils import get_discriminant
18
18
 
19
19
  if TYPE_CHECKING: # pragma: no cover
20
20
  from collections.abc import Sequence
ftag/wps/__init__.py DELETED
File without changes
ftag/wps/discriminant.py DELETED
@@ -1,84 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from typing import TYPE_CHECKING
4
-
5
- import numpy as np
6
-
7
- if TYPE_CHECKING: # pragma: no cover
8
- from ftag.labels import Label, LabelContainer
9
-
10
-
11
- def get_discriminant(
12
- jets: np.ndarray,
13
- tagger: str,
14
- signal: Label,
15
- flavours: LabelContainer,
16
- fraction_values: dict[str, float],
17
- epsilon: float = 1e-10,
18
- ) -> np.ndarray:
19
- """Calculate the tagging discriminant for a given tagger.
20
-
21
- Calculated as the logarithm of the ratio of a specified signal probability
22
- to a weighted sum ofbackground probabilities.
23
-
24
- Parameters
25
- ----------
26
- jets : np.ndarray
27
- Structured array of jets containing tagger outputs
28
- tagger : str
29
- Name of the tagger
30
- signal : Label
31
- Signal flavour (bjets/cjets or hbb/hcc)
32
- fraction_values : dict
33
- Dict with the fraction values for the background classes for the given tagger
34
- epsilon : float, optional
35
- Small number to avoid division by zero, by default 1e-10
36
-
37
- Returns
38
- -------
39
- np.ndarray
40
- Array of discriminant values.
41
-
42
- Raises
43
- ------
44
- ValueError
45
- If the signal flavour is not recognised.
46
- """
47
- # Init the denominator
48
- denominator = 0.0
49
-
50
- # Loop over background flavours
51
- for flav in flavours:
52
- # Skip signal flavour for denominator
53
- if flav == signal:
54
- continue
55
-
56
- # Get the probability name of the tagger/flavour combo + fraction value
57
- prob_name = f"{tagger}_{flav.px}"
58
- fraction_value = fraction_values[flav.frac_str]
59
-
60
- # If fraction_value for the given flavour is zero, skip it
61
- if fraction_value == 0:
62
- continue
63
-
64
- # Check that the probability value for the flavour is available
65
- if fraction_value > 0 and prob_name not in jets.dtype.names:
66
- raise ValueError(
67
- f"Nonzero fraction value for {flav.name}, but '{prob_name}' "
68
- "not found in input array."
69
- )
70
-
71
- # Update denominator
72
- denominator += jets[prob_name] * fraction_value if prob_name in jets.dtype.names else 0
73
-
74
- # Calculate numerator
75
- signal_field = f"{tagger}_{signal.px}"
76
-
77
- # Check that the probability of the signal is available
78
- if signal_field not in jets.dtype.names:
79
- raise ValueError(
80
- f"No signal probability value(s) found for tagger {tagger}. "
81
- f"Missing variable: {signal_field}"
82
- )
83
-
84
- return np.log((jets[signal_field] + epsilon) / (denominator + epsilon))