atlas-ftag-tools 0.2.9__py3-none-any.whl → 0.2.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- atlas_ftag_tools-0.2.11.dist-info/METADATA +53 -0
- atlas_ftag_tools-0.2.11.dist-info/RECORD +32 -0
- {atlas_ftag_tools-0.2.9.dist-info → atlas_ftag_tools-0.2.11.dist-info}/WHEEL +1 -1
- {atlas_ftag_tools-0.2.9.dist-info → atlas_ftag_tools-0.2.11.dist-info}/entry_points.txt +2 -1
- atlas_ftag_tools-0.2.11.dist-info/licenses/LICENSE +201 -0
- ftag/__init__.py +13 -12
- ftag/flavours.yaml +33 -12
- ftag/fraction_optimization.py +184 -0
- ftag/hdf5/__init__.py +5 -3
- ftag/hdf5/h5add_col.py +391 -0
- ftag/hdf5/h5writer.py +12 -1
- ftag/labels.py +10 -2
- ftag/utils/__init__.py +24 -0
- ftag/utils/logging.py +123 -0
- ftag/utils/metrics.py +431 -0
- ftag/vds.py +39 -4
- ftag/{wps/working_points.py → working_points.py} +1 -1
- atlas_ftag_tools-0.2.9.dist-info/METADATA +0 -150
- atlas_ftag_tools-0.2.9.dist-info/RECORD +0 -28
- ftag/wps/__init__.py +0 -0
- ftag/wps/discriminant.py +0 -84
- {atlas_ftag_tools-0.2.9.dist-info → atlas_ftag_tools-0.2.11.dist-info}/top_level.txt +0 -0
ftag/wps/discriminant.py
DELETED
@@ -1,84 +0,0 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
|
3
|
-
from typing import TYPE_CHECKING
|
4
|
-
|
5
|
-
import numpy as np
|
6
|
-
|
7
|
-
if TYPE_CHECKING: # pragma: no cover
|
8
|
-
from ftag.labels import Label, LabelContainer
|
9
|
-
|
10
|
-
|
11
|
-
def get_discriminant(
    jets: np.ndarray,
    tagger: str,
    signal: Label,
    flavours: LabelContainer,
    fraction_values: dict[str, float],
    epsilon: float = 1e-10,
) -> np.ndarray:
    """Calculate the tagging discriminant for a given tagger.

    Calculated as the logarithm of the ratio of the signal probability to a
    fraction-weighted sum of background probabilities::

        D = log((p_signal + epsilon) / (sum_k f_k * p_k + epsilon))

    where the sum runs over every flavour in ``flavours`` other than ``signal``.

    Parameters
    ----------
    jets : np.ndarray
        Structured array of jets containing tagger outputs. Probability
        fields are looked up by name as ``f"{tagger}_{flavour.px}"``.
    tagger : str
        Name of the tagger
    signal : Label
        Signal flavour (e.g. bjets/cjets or hbb/hcc)
    flavours : LabelContainer
        Flavours to iterate over; every entry other than ``signal`` is treated
        as a background class.
    fraction_values : dict[str, float]
        Fraction values for the background classes, keyed by each flavour's
        ``frac_str``. Backgrounds whose fraction is zero are skipped entirely.
    epsilon : float, optional
        Small number added to both numerator and denominator to avoid
        division by zero, by default 1e-10

    Returns
    -------
    np.ndarray
        Array of discriminant values, with the same shape as ``jets``.

    Raises
    ------
    ValueError
        If a background flavour has a positive fraction value but its
        probability field is not present in ``jets``, or if the signal
        probability field is not present in ``jets``.
    KeyError
        If a background flavour's ``frac_str`` is missing from
        ``fraction_values``.
    """
    # ``dtype.names`` is None for non-structured arrays; normalise it to an
    # empty tuple so the membership checks below raise the intended, clear
    # ValueError instead of a TypeError.
    available = jets.dtype.names or ()

    # Init the denominator (a scalar until the first background is added,
    # after which it broadcasts to a per-jet array)
    denominator = 0.0

    # Loop over background flavours
    for flav in flavours:
        # The signal flavour only contributes to the numerator
        if flav == signal:
            continue

        # Get the probability name of the tagger/flavour combo + fraction value
        prob_name = f"{tagger}_{flav.px}"
        fraction_value = fraction_values[flav.frac_str]

        # A zero fraction removes this background, so its probability field
        # is not required to be present
        if fraction_value == 0:
            continue

        if prob_name not in available:
            # A positive fraction needs the probability to be available
            if fraction_value > 0:
                raise ValueError(
                    f"Nonzero fraction value for {flav.name}, but '{prob_name}' "
                    "not found in input array."
                )
            # NOTE(review): a negative fraction with a missing field is
            # silently ignored; kept for backward compatibility.
            continue

        # Update denominator
        denominator += jets[prob_name] * fraction_value

    # Calculate numerator
    signal_field = f"{tagger}_{signal.px}"

    # Check that the probability of the signal is available
    if signal_field not in available:
        raise ValueError(
            f"No signal probability value(s) found for tagger {tagger}. "
            f"Missing variable: {signal_field}"
        )

    return np.log((jets[signal_field] + epsilon) / (denominator + epsilon))
|
File without changes
|