atlas-ftag-tools 0.2.9__py3-none-any.whl → 0.2.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ftag/wps/discriminant.py DELETED
@@ -1,84 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from typing import TYPE_CHECKING
4
-
5
- import numpy as np
6
-
7
- if TYPE_CHECKING: # pragma: no cover
8
- from ftag.labels import Label, LabelContainer
9
-
10
-
11
def get_discriminant(
    jets: np.ndarray,
    tagger: str,
    signal: Label,
    flavours: LabelContainer,
    fraction_values: dict[str, float],
    epsilon: float = 1e-10,
) -> np.ndarray:
    """Calculate the tagging discriminant for a given tagger.

    Calculated as the logarithm of the ratio of the signal probability to a
    fraction-weighted sum of background probabilities::

        D = log((p_signal + epsilon) / (sum_i f_i * p_i + epsilon))

    where the sum runs over every flavour in ``flavours`` except ``signal``.

    Parameters
    ----------
    jets : np.ndarray
        Structured array of jets containing tagger outputs. Probability
        fields are looked up by name as ``f"{tagger}_{flavour.px}"``.
    tagger : str
        Name of the tagger
    signal : Label
        Signal flavour (bjets/cjets or hbb/hcc)
    flavours : LabelContainer
        Flavours to iterate over; every entry that is not ``signal`` is
        treated as a background flavour.
    fraction_values : dict[str, float]
        Fraction value per background flavour, keyed by ``flavour.frac_str``.
        Background flavours whose fraction is exactly zero are skipped and
        do not need a probability field in ``jets``.
    epsilon : float, optional
        Small number added to numerator and denominator to avoid division
        by zero and log(0), by default 1e-10

    Returns
    -------
    np.ndarray
        Array of discriminant values.

    Raises
    ------
    KeyError
        If a background flavour has no entry in ``fraction_values``.
    ValueError
        If a background flavour with a nonzero fraction value, or the signal
        flavour, has no probability field in ``jets``.
    """
    # Field names of the structured array. For a plain (unstructured) array
    # ``dtype.names`` is None, which would make the ``in`` checks below raise
    # an opaque TypeError; treat it as "no fields" instead.
    field_names = jets.dtype.names or ()

    # Init the denominator (becomes an array after the first contribution)
    denominator = 0.0

    # Loop over background flavours
    for flav in flavours:
        # Skip signal flavour for denominator
        if flav == signal:
            continue

        # Get the probability name of the tagger/flavour combo + fraction value
        prob_name = f"{tagger}_{flav.px}"
        fraction_value = fraction_values[flav.frac_str]

        # A zero-weighted flavour contributes nothing, so its probability
        # field is not required to be present.
        if fraction_value == 0:
            continue

        # Any nonzero fraction (including a negative one) needs its
        # probability field; dropping it silently would change the
        # discriminant without any warning.
        if prob_name not in field_names:
            raise ValueError(
                f"Nonzero fraction value for {flav.name}, but '{prob_name}' "
                "not found in input array."
            )

        # Update denominator
        denominator += jets[prob_name] * fraction_value

    # Calculate numerator
    signal_field = f"{tagger}_{signal.px}"

    # Check that the probability of the signal is available
    if signal_field not in field_names:
        raise ValueError(
            f"No signal probability value(s) found for tagger {tagger}. "
            f"Missing variable: {signal_field}"
        )

    return np.log((jets[signal_field] + epsilon) / (denominator + epsilon))