atlas-ftag-tools 0.2.2__py3-none-any.whl → 0.2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {atlas_ftag_tools-0.2.2.dist-info → atlas_ftag_tools-0.2.4.dist-info}/METADATA +1 -1
- {atlas_ftag_tools-0.2.2.dist-info → atlas_ftag_tools-0.2.4.dist-info}/RECORD +12 -10
- {atlas_ftag_tools-0.2.2.dist-info → atlas_ftag_tools-0.2.4.dist-info}/WHEEL +1 -1
- ftag/__init__.py +1 -1
- ftag/cuts.py +3 -1
- ftag/flavour.py +7 -1
- ftag/hdf5/h5reader.py +10 -12
- ftag/labeller.py +88 -0
- ftag/mock.py +27 -12
- ftag/track_selector.py +70 -0
- {atlas_ftag_tools-0.2.2.dist-info → atlas_ftag_tools-0.2.4.dist-info}/entry_points.txt +0 -0
- {atlas_ftag_tools-0.2.2.dist-info → atlas_ftag_tools-0.2.4.dist-info}/top_level.txt +0 -0
@@ -1,26 +1,28 @@
|
|
1
|
-
ftag/__init__.py,sha256=
|
1
|
+
ftag/__init__.py,sha256=Mx2Emsw4TM1YL0wTQMHK36EaQE_ImeV6ukiz1X5BZAU,629
|
2
2
|
ftag/cli_utils.py,sha256=w3TtQmUHSyAKChS3ewvOtcSDAUJAZGIIomaNi8f446U,298
|
3
|
-
ftag/cuts.py,sha256=
|
4
|
-
ftag/flavour.py,sha256=
|
3
|
+
ftag/cuts.py,sha256=9_ooLZHaO3SnIQBNxwbaPZn-qptGdKnB27FdKQGTiTY,2933
|
4
|
+
ftag/flavour.py,sha256=EMZZLyl6lSdvkfrYxHhMcSn3aqP_FU7OpCFkvZpTksU,3761
|
5
5
|
ftag/flavours.yaml,sha256=lFnVwjh_DwLhOc3mr5n6bSIWyHgxQvAXas4lEmEDncU,7520
|
6
6
|
ftag/git_check.py,sha256=Y-XqM80CVXZ5ZKrDdZcYOJt3X64uU6W3OP6Z0D7AZU0,1663
|
7
|
-
ftag/
|
7
|
+
ftag/labeller.py,sha256=uDygOhVGSNn96DWw8aErHpTtFsFX0RnxYYpy4g1FRog,2457
|
8
|
+
ftag/mock.py,sha256=_oy-r3eLllFy33NAoZaKfAx-Rp2vrCdrGj3UsTMks94,4740
|
8
9
|
ftag/region.py,sha256=ANv0dGI2W6NJqD9fp7EfqAUReH4FOjc1gwl_Qn8llcM,360
|
9
10
|
ftag/sample.py,sha256=TFXMhDkbPmjkms9-b-bINJ32T3bO86JcU70C0nY7wa8,2500
|
10
11
|
ftag/test_cli_utils.py,sha256=xa08vf6SEOow58SSFagYdAselb-dkNOVvWsWheMnW-g,1001
|
12
|
+
ftag/track_selector.py,sha256=piSYAN_IkOsrXxKXjXbJpMSseUig5P2BJW5mCwsMUDM,2535
|
11
13
|
ftag/transform.py,sha256=uEGGJSnqoKOzLYQv650XdK_kDNw4Aw-5dc60z9Dp_y0,3963
|
12
14
|
ftag/vds.py,sha256=nRViQZQIORB95nC7NZsW3KsSoGkLzEdOsuCViH5h8-U,3296
|
13
15
|
ftag/hdf5/__init__.py,sha256=LFDNxVOCp58SvLHwQhdT68Q-KBMS_i6jBrbXoRpHzbM,354
|
14
16
|
ftag/hdf5/h5move.py,sha256=oYpRu0IDCIJIQ2ML52HBAdoyDxmKkHTeM9JdbPEgKfI,947
|
15
|
-
ftag/hdf5/h5reader.py,sha256=
|
17
|
+
ftag/hdf5/h5reader.py,sha256=i31pDAqmOSaxdeRhc4iSBlld8xJ0pmp4rNd7CugNzw0,13706
|
16
18
|
ftag/hdf5/h5split.py,sha256=4Wy6Xc3J58MdD9aBaSZHf5ZcVFnJSkWsm42R5Pgo-R4,2448
|
17
19
|
ftag/hdf5/h5utils.py,sha256=-4zKTMtNCrDZr_9Ww7uzfsB7M7muBKpmm_1IkKJnHOI,3222
|
18
20
|
ftag/hdf5/h5writer.py,sha256=j3Fy8snkiVVfimiUz3rrZOhSV8OF27978Y9pk0QcTGM,5277
|
19
21
|
ftag/wps/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
20
22
|
ftag/wps/discriminant.py,sha256=kJFekUTPNIvCabJCon6OqOAQEzz5hj3XrWFFRLOgGOs,3836
|
21
23
|
ftag/wps/working_points.py,sha256=VTU6OD40ULAJQD0MlD1EZd33q8ociUvFX1YrhgJFvXc,9722
|
22
|
-
atlas_ftag_tools-0.2.
|
23
|
-
atlas_ftag_tools-0.2.
|
24
|
-
atlas_ftag_tools-0.2.
|
25
|
-
atlas_ftag_tools-0.2.
|
26
|
-
atlas_ftag_tools-0.2.
|
24
|
+
atlas_ftag_tools-0.2.4.dist-info/METADATA,sha256=f4aCu6JmItUBp5EmTzbrqhC5-Wsy7uiOiBO9yufyacQ,5169
|
25
|
+
atlas_ftag_tools-0.2.4.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
|
26
|
+
atlas_ftag_tools-0.2.4.dist-info/entry_points.txt,sha256=LfVLsZHQolqbPnwPgtmc5IQTh527BKkN2v-IpXWTNHw,137
|
27
|
+
atlas_ftag_tools-0.2.4.dist-info/top_level.txt,sha256=qiYQuKcAvMim-31FwkT3MTQu7WQm0s58tPAia5KKWqs,5
|
28
|
+
atlas_ftag_tools-0.2.4.dist-info/RECORD,,
|
ftag/__init__.py
CHANGED
ftag/cuts.py
CHANGED
@@ -79,7 +79,9 @@ class Cuts:
|
|
79
79
|
def ignore(self, variables: list[str]):
|
80
80
|
return Cuts(tuple(c for c in self if c.variable not in variables))
|
81
81
|
|
82
|
-
def __call__(self, array) -> CutsResult:
|
82
|
+
def __call__(self, array: np.ndarray) -> CutsResult:
|
83
|
+
if array.ndim == 2:
|
84
|
+
raise ValueError("This interface only supports jet selections")
|
83
85
|
keep = np.arange(len(array))
|
84
86
|
for cut in self.cuts:
|
85
87
|
idx = cut(array)
|
ftag/flavour.py
CHANGED
@@ -42,6 +42,9 @@ class Flavour:
|
|
42
42
|
def __str__(self) -> str:
|
43
43
|
return self.name
|
44
44
|
|
45
|
+
def __lt__(self, other) -> bool:
|
46
|
+
return self.name < other.name
|
47
|
+
|
45
48
|
|
46
49
|
@dataclass
|
47
50
|
class FlavourContainer:
|
@@ -81,7 +84,10 @@ class FlavourContainer:
|
|
81
84
|
return list(dict.fromkeys(f.category for f in self))
|
82
85
|
|
83
86
|
def by_category(self, category: str) -> FlavourContainer:
|
84
|
-
|
87
|
+
f = FlavourContainer({k: v for k, v in self.flavours.items() if v.category == category})
|
88
|
+
if not f.flavours:
|
89
|
+
raise KeyError(f"No flavours with category '{category}' found")
|
90
|
+
return f
|
85
91
|
|
86
92
|
def from_cuts(self, cuts: list | Cuts) -> Flavour:
|
87
93
|
if isinstance(cuts, list):
|
ftag/hdf5/h5reader.py
CHANGED
@@ -150,8 +150,13 @@ class H5Reader:
|
|
150
150
|
transform : Transform | None, optional
|
151
151
|
Transform to apply to data, by default None
|
152
152
|
equal_jets : bool, optional
|
153
|
-
Take the same number of jets (weighted) from each sample, by default True
|
154
|
-
|
153
|
+
Take the same number of jets (weighted) from each sample, by default False.
|
154
|
+
This is useful when you specify a list of DSIDs for the sample and they are
|
155
|
+
qualitatively different, and you want to ensure that you always return batches
|
156
|
+
with jets from all DSIDs. This is used for example in the QCD resampling for Xbb.
|
157
|
+
If False, use all jets in each sample, allowing for the full available statistics
|
158
|
+
to be used. Useful for example if you have multiple ttbar samples and you want to
|
159
|
+
use all available jets from each sample.
|
155
160
|
"""
|
156
161
|
|
157
162
|
fname: Path | str | list[Path | str]
|
@@ -162,17 +167,10 @@ class H5Reader:
|
|
162
167
|
weights: list[float] | None = None
|
163
168
|
do_remove_inf: bool = False
|
164
169
|
transform: Transform | None = None
|
165
|
-
equal_jets: bool =
|
170
|
+
equal_jets: bool = False
|
166
171
|
|
167
172
|
def __post_init__(self) -> None:
|
168
173
|
self.rng = np.random.default_rng(42)
|
169
|
-
if not self.equal_jets:
|
170
|
-
log.warning(
|
171
|
-
"equal_jets is set to False, which will result in different number of jets taken"
|
172
|
-
" from each sample. Be aware that this can affect the resampling, so make sure you"
|
173
|
-
" know what you are doing."
|
174
|
-
)
|
175
|
-
|
176
174
|
if isinstance(self.fname, (str, Path)):
|
177
175
|
self.fname = [self.fname]
|
178
176
|
|
@@ -283,8 +281,8 @@ class H5Reader:
|
|
283
281
|
try:
|
284
282
|
samples.append(next(stream))
|
285
283
|
|
286
|
-
# if equal_jets is True,
|
287
|
-
# otherwise if
|
284
|
+
# if equal_jets is True, stop when any sample is done
|
285
|
+
# otherwise if stream is exhausted, mark it as such and continue
|
288
286
|
except StopIteration:
|
289
287
|
if self.equal_jets:
|
290
288
|
return
|
ftag/labeller.py
ADDED
@@ -0,0 +1,88 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from dataclasses import dataclass
|
4
|
+
|
5
|
+
import numpy as np
|
6
|
+
|
7
|
+
from ftag import Flavours
|
8
|
+
from ftag.flavour import Flavour, FlavourContainer
|
9
|
+
from ftag.hdf5 import join_structured_arrays, structured_from_dict
|
10
|
+
|
11
|
+
|
12
|
+
@dataclass
|
13
|
+
class Labeller:
|
14
|
+
"""
|
15
|
+
Defines a labelling scheme.
|
16
|
+
|
17
|
+
Labels are [0, ..., n] and are assigned using pre-defined selections.
|
18
|
+
|
19
|
+
Parameters
|
20
|
+
----------
|
21
|
+
labels : FlavourContainer | list[str | Flavour]
|
22
|
+
The labels to be used.
|
23
|
+
require_labels : bool
|
24
|
+
Whether to require that all objects are labelled.
|
25
|
+
"""
|
26
|
+
|
27
|
+
labels: FlavourContainer | list[str | Flavour]
|
28
|
+
require_labels: bool = True
|
29
|
+
|
30
|
+
def __post_init__(self) -> None:
|
31
|
+
if isinstance(self.labels, FlavourContainer):
|
32
|
+
self.labels = list(self.labels)
|
33
|
+
self.labels = sorted([Flavours[label] for label in self.labels])
|
34
|
+
|
35
|
+
def get_labels(self, array: np.ndarray) -> np.ndarray:
|
36
|
+
"""
|
37
|
+
Returns the labels for the given array.
|
38
|
+
|
39
|
+
Parameters
|
40
|
+
----------
|
41
|
+
array : np.ndarray
|
42
|
+
The array to label.
|
43
|
+
|
44
|
+
Returns
|
45
|
+
-------
|
46
|
+
np.ndarray
|
47
|
+
The labels for the given array.
|
48
|
+
|
49
|
+
Raises
|
50
|
+
------
|
51
|
+
ValueError
|
52
|
+
If the `require_labels` attribute is set to `True` and some objects were not labelled.
|
53
|
+
"""
|
54
|
+
labels = -1 * np.ones_like(array, dtype=int)
|
55
|
+
for i, label in enumerate(self.labels):
|
56
|
+
labels[label.cuts(array).idx] = i
|
57
|
+
|
58
|
+
if self.require_labels and -1 in labels:
|
59
|
+
raise ValueError("Some objects were not labelled")
|
60
|
+
|
61
|
+
return labels[labels != -1]
|
62
|
+
|
63
|
+
def add_labels(self, array: np.ndarray, label_name: str = "labels") -> np.ndarray:
|
64
|
+
"""
|
65
|
+
Adds the labels to the given array.
|
66
|
+
|
67
|
+
Parameters
|
68
|
+
----------
|
69
|
+
array : np.ndarray
|
70
|
+
The array to label.
|
71
|
+
label_name : str
|
72
|
+
The name of the label column.
|
73
|
+
|
74
|
+
Returns
|
75
|
+
-------
|
76
|
+
np.ndarray
|
77
|
+
The array with the labels added.
|
78
|
+
|
79
|
+
Raises
|
80
|
+
------
|
81
|
+
ValueError
|
82
|
+
If the `require_labels` attribute is set to `False`.
|
83
|
+
"""
|
84
|
+
if not self.require_labels:
|
85
|
+
raise ValueError("Cannot add labels if require_labels is set to False")
|
86
|
+
labels = self.get_labels(array)
|
87
|
+
labels = structured_from_dict({label_name: labels})
|
88
|
+
return join_structured_arrays([array, labels])
|
ftag/mock.py
CHANGED
@@ -84,12 +84,7 @@ def get_mock_scores(labels: np.ndarray, is_xbb: bool = False):
|
|
84
84
|
return u2s(scores, dtype=np.dtype([(name, "f4") for name in cols]))
|
85
85
|
|
86
86
|
|
87
|
-
def
|
88
|
-
num_jets=1000,
|
89
|
-
fname: str | None = None,
|
90
|
-
tracks_name: str = "tracks",
|
91
|
-
num_tracks: int = 40,
|
92
|
-
) -> tuple[str, h5py.File]:
|
87
|
+
def mock_jets(num_jets=1000) -> np.ndarray:
|
93
88
|
# setup jets
|
94
89
|
rng = np.random.default_rng(42)
|
95
90
|
jets_dtype = np.dtype(JET_VARS)
|
@@ -106,7 +101,31 @@ def get_mock_file(
|
|
106
101
|
jets["R10TruthLabel_R22v1"] = rng.choice([1, 10, 11, 12], size=num_jets)
|
107
102
|
scores = get_mock_scores(jets["HadronConeExclTruthLabelID"])
|
108
103
|
xbb_scores = get_mock_scores(jets["R10TruthLabel_R22v1"], is_xbb=True)
|
109
|
-
|
104
|
+
return join_structured_arrays([jets, scores, xbb_scores])
|
105
|
+
|
106
|
+
|
107
|
+
def mock_tracks(num_jets=1000, num_tracks=40) -> np.ndarray:
|
108
|
+
rng = np.random.default_rng(42)
|
109
|
+
tracks_dtype = np.dtype(TRACK_VARS)
|
110
|
+
tracks = u2s(rng.random((num_jets, num_tracks, len(TRACK_VARS))), tracks_dtype)
|
111
|
+
tracks["d0"] *= 5
|
112
|
+
|
113
|
+
# for the shared hits, add some reasonable integer values
|
114
|
+
tracks["numberOfPixelSharedHits"] = rng.integers(0, 3, size=(num_jets, num_tracks))
|
115
|
+
tracks["numberOfSCTSharedHits"] = rng.integers(0, 3, size=(num_jets, num_tracks))
|
116
|
+
|
117
|
+
valid = rng.choice([True, False], size=(num_jets, num_tracks))
|
118
|
+
valid = valid.astype(bool).view(dtype=np.dtype([("valid", bool)]))
|
119
|
+
return join_structured_arrays([tracks, valid])
|
120
|
+
|
121
|
+
|
122
|
+
def get_mock_file(
|
123
|
+
num_jets=1000,
|
124
|
+
fname: str | None = None,
|
125
|
+
tracks_name: str = "tracks",
|
126
|
+
num_tracks: int = 40,
|
127
|
+
) -> tuple[str, h5py.File]:
|
128
|
+
jets = mock_jets(num_jets)
|
110
129
|
|
111
130
|
# create a tempfile in a new folder
|
112
131
|
if fname is None:
|
@@ -120,11 +139,7 @@ def get_mock_file(
|
|
120
139
|
|
121
140
|
# setup tracks
|
122
141
|
if tracks_name:
|
123
|
-
|
124
|
-
tracks = u2s(rng.random((num_jets, num_tracks, len(TRACK_VARS))), tracks_dtype)
|
125
|
-
valid = rng.choice([True, False], size=(num_jets, num_tracks))
|
126
|
-
valid = valid.astype(bool).view(dtype=np.dtype([("valid", bool)]))
|
127
|
-
tracks = join_structured_arrays([tracks, valid])
|
142
|
+
tracks = mock_tracks(num_jets, num_tracks)
|
128
143
|
f.create_dataset(tracks_name, data=tracks)
|
129
144
|
|
130
145
|
return fname, f
|
ftag/track_selector.py
ADDED
@@ -0,0 +1,70 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from dataclasses import dataclass
|
4
|
+
|
5
|
+
import numpy as np
|
6
|
+
|
7
|
+
from ftag.cuts import Cut, Cuts
|
8
|
+
|
9
|
+
|
10
|
+
@dataclass
|
11
|
+
class TrackSelector:
|
12
|
+
"""
|
13
|
+
Apply track selections to a set of tracks stored in a structured numpy array.
|
14
|
+
|
15
|
+
The array is assumed to be a structured array of shape (n_jets, n_tracks), with one named field per track variable.
|
16
|
+
Applying cuts will NaN out the tracks that do not pass the cuts,
|
17
|
+
but leave the shape of the array unchanged.
|
18
|
+
|
19
|
+
Parameters
|
20
|
+
----------
|
21
|
+
cuts : Cuts
|
22
|
+
The cuts to apply to the tracks
|
23
|
+
valid_str : str
|
24
|
+
The name of the boolean field in the tracks that indicates whether the track is valid.
|
25
|
+
"""
|
26
|
+
|
27
|
+
cuts: Cuts
|
28
|
+
valid_str: str = "valid"
|
29
|
+
|
30
|
+
def __call__(self, tracks: np.ndarray) -> np.ndarray:
|
31
|
+
# boolean mask of tracks to remove; all False before any cuts are applied
|
32
|
+
rm_idx = np.zeros_like(tracks[self.valid_str], dtype=bool)
|
33
|
+
|
34
|
+
# apply the cuts
|
35
|
+
for cut in self.cuts.cuts:
|
36
|
+
# remove valid track indices that do not pass the selection
|
37
|
+
keep_idx = self._nshared_cut(cut, tracks) if cut.variable == "NSHARED" else cut(tracks)
|
38
|
+
rm_idx[tracks[self.valid_str] & ~keep_idx] = True
|
39
|
+
|
40
|
+
# set the values of the tracks that do not pass the cuts to a dtype-specific fill value (NaN for floats, -1 for integers, False for bools)
|
41
|
+
for var in tracks.dtype.names:
|
42
|
+
if issubclass(tracks[var].dtype.type, np.floating):
|
43
|
+
tracks[var][rm_idx] = np.nan
|
44
|
+
elif issubclass(tracks[var].dtype.type, np.integer):
|
45
|
+
tracks[var][rm_idx] = -1
|
46
|
+
elif issubclass(tracks[var].dtype.type, np.bool_):
|
47
|
+
tracks[var][rm_idx] = False
|
48
|
+
else:
|
49
|
+
raise TypeError(f"Unknown dtype {tracks[var].dtype}")
|
50
|
+
|
51
|
+
# specifically set the valid flag to false (even though it's already false by now)
|
52
|
+
tracks[rm_idx][self.valid_str] = False
|
53
|
+
|
54
|
+
return tracks
|
55
|
+
|
56
|
+
def _nshared_cut(self, cut: Cut, tracks: np.ndarray) -> np.ndarray:
|
57
|
+
# hack to apply the FTAG shared hit cut, which requires an intermediate step
|
58
|
+
if cut.variable == "NSHARED" and "NSHARED" in tracks.dtype.names:
|
59
|
+
raise ValueError("NSHARED is a reserved variable name")
|
60
|
+
|
61
|
+
# compute
|
62
|
+
n_pix_shared = tracks["numberOfPixelSharedHits"]
|
63
|
+
n_sct_shared = tracks["numberOfSCTSharedHits"]
|
64
|
+
n_module_shared = n_pix_shared + n_sct_shared / 2
|
65
|
+
|
66
|
+
# convert n_module_shared to structured array
|
67
|
+
n_module_shared = n_module_shared.view(dtype=[(cut.variable, n_module_shared.dtype)])
|
68
|
+
|
69
|
+
# select
|
70
|
+
return cut(n_module_shared)
|
File without changes
|
File without changes
|