atlas-ftag-tools 0.2.2__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {atlas_ftag_tools-0.2.2.dist-info → atlas_ftag_tools-0.2.3.dist-info}/METADATA +1 -1
- {atlas_ftag_tools-0.2.2.dist-info → atlas_ftag_tools-0.2.3.dist-info}/RECORD +11 -9
- {atlas_ftag_tools-0.2.2.dist-info → atlas_ftag_tools-0.2.3.dist-info}/WHEEL +1 -1
- ftag/__init__.py +1 -1
- ftag/cuts.py +3 -1
- ftag/flavour.py +7 -1
- ftag/labeller.py +88 -0
- ftag/mock.py +22 -12
- ftag/track_selector.py +53 -0
- {atlas_ftag_tools-0.2.2.dist-info → atlas_ftag_tools-0.2.3.dist-info}/entry_points.txt +0 -0
- {atlas_ftag_tools-0.2.2.dist-info → atlas_ftag_tools-0.2.3.dist-info}/top_level.txt +0 -0
@@ -1,13 +1,15 @@
|
|
1
|
-
ftag/__init__.py,sha256=
|
1
|
+
ftag/__init__.py,sha256=i5XMWwkyGmwIIIaO3veRyOnTsJ0qKyW1V6HsuYc9Dm4,629
|
2
2
|
ftag/cli_utils.py,sha256=w3TtQmUHSyAKChS3ewvOtcSDAUJAZGIIomaNi8f446U,298
|
3
|
-
ftag/cuts.py,sha256=
|
4
|
-
ftag/flavour.py,sha256=
|
3
|
+
ftag/cuts.py,sha256=9_ooLZHaO3SnIQBNxwbaPZn-qptGdKnB27FdKQGTiTY,2933
|
4
|
+
ftag/flavour.py,sha256=EMZZLyl6lSdvkfrYxHhMcSn3aqP_FU7OpCFkvZpTksU,3761
|
5
5
|
ftag/flavours.yaml,sha256=lFnVwjh_DwLhOc3mr5n6bSIWyHgxQvAXas4lEmEDncU,7520
|
6
6
|
ftag/git_check.py,sha256=Y-XqM80CVXZ5ZKrDdZcYOJt3X64uU6W3OP6Z0D7AZU0,1663
|
7
|
-
ftag/
|
7
|
+
ftag/labeller.py,sha256=uDygOhVGSNn96DWw8aErHpTtFsFX0RnxYYpy4g1FRog,2457
|
8
|
+
ftag/mock.py,sha256=FboI1Kq6aKZv43SpubiLkJvn4BSMDv2Fl2UniuhxspU,4502
|
8
9
|
ftag/region.py,sha256=ANv0dGI2W6NJqD9fp7EfqAUReH4FOjc1gwl_Qn8llcM,360
|
9
10
|
ftag/sample.py,sha256=TFXMhDkbPmjkms9-b-bINJ32T3bO86JcU70C0nY7wa8,2500
|
10
11
|
ftag/test_cli_utils.py,sha256=xa08vf6SEOow58SSFagYdAselb-dkNOVvWsWheMnW-g,1001
|
12
|
+
ftag/track_selector.py,sha256=WQ6lzY8n6pumIVfLMLgjlTkWUaIQu-8Dq3FKsmCsG_0,1736
|
11
13
|
ftag/transform.py,sha256=uEGGJSnqoKOzLYQv650XdK_kDNw4Aw-5dc60z9Dp_y0,3963
|
12
14
|
ftag/vds.py,sha256=nRViQZQIORB95nC7NZsW3KsSoGkLzEdOsuCViH5h8-U,3296
|
13
15
|
ftag/hdf5/__init__.py,sha256=LFDNxVOCp58SvLHwQhdT68Q-KBMS_i6jBrbXoRpHzbM,354
|
@@ -19,8 +21,8 @@ ftag/hdf5/h5writer.py,sha256=j3Fy8snkiVVfimiUz3rrZOhSV8OF27978Y9pk0QcTGM,5277
|
|
19
21
|
ftag/wps/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
20
22
|
ftag/wps/discriminant.py,sha256=kJFekUTPNIvCabJCon6OqOAQEzz5hj3XrWFFRLOgGOs,3836
|
21
23
|
ftag/wps/working_points.py,sha256=VTU6OD40ULAJQD0MlD1EZd33q8ociUvFX1YrhgJFvXc,9722
|
22
|
-
atlas_ftag_tools-0.2.
|
23
|
-
atlas_ftag_tools-0.2.
|
24
|
-
atlas_ftag_tools-0.2.
|
25
|
-
atlas_ftag_tools-0.2.
|
26
|
-
atlas_ftag_tools-0.2.
|
24
|
+
atlas_ftag_tools-0.2.3.dist-info/METADATA,sha256=XP7QCxOKz-QUKKx5ds8ldz13n_qfTVcP-SZciExdIr0,5169
|
25
|
+
atlas_ftag_tools-0.2.3.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
|
26
|
+
atlas_ftag_tools-0.2.3.dist-info/entry_points.txt,sha256=LfVLsZHQolqbPnwPgtmc5IQTh527BKkN2v-IpXWTNHw,137
|
27
|
+
atlas_ftag_tools-0.2.3.dist-info/top_level.txt,sha256=qiYQuKcAvMim-31FwkT3MTQu7WQm0s58tPAia5KKWqs,5
|
28
|
+
atlas_ftag_tools-0.2.3.dist-info/RECORD,,
|
ftag/__init__.py
CHANGED
ftag/cuts.py
CHANGED
@@ -79,7 +79,9 @@ class Cuts:
|
|
79
79
|
def ignore(self, variables: list[str]):
|
80
80
|
return Cuts(tuple(c for c in self if c.variable not in variables))
|
81
81
|
|
82
|
-
def __call__(self, array) -> CutsResult:
|
82
|
+
def __call__(self, array: np.ndarray) -> CutsResult:
|
83
|
+
if array.ndim == 2:
|
84
|
+
raise ValueError("This interface only supports jet selections")
|
83
85
|
keep = np.arange(len(array))
|
84
86
|
for cut in self.cuts:
|
85
87
|
idx = cut(array)
|
ftag/flavour.py
CHANGED
@@ -42,6 +42,9 @@ class Flavour:
|
|
42
42
|
def __str__(self) -> str:
|
43
43
|
return self.name
|
44
44
|
|
45
|
+
def __lt__(self, other) -> bool:
    """Order flavours alphabetically by their ``name`` (this is what makes ``sorted`` work)."""
    own_name = self.name
    other_name = other.name
    return own_name < other_name
|
47
|
+
|
45
48
|
|
46
49
|
@dataclass
|
47
50
|
class FlavourContainer:
|
@@ -81,7 +84,10 @@ class FlavourContainer:
|
|
81
84
|
return list(dict.fromkeys(f.category for f in self))
|
82
85
|
|
83
86
|
def by_category(self, category: str) -> FlavourContainer:
    """
    Return a new container holding only the flavours of the given category.

    Parameters
    ----------
    category : str
        The category to select.

    Returns
    -------
    FlavourContainer
        A container with the matching flavours, keyed as in ``self.flavours``.

    Raises
    ------
    KeyError
        If no flavour in this container belongs to ``category``.
    """
    selected = {}
    for key, flavour in self.flavours.items():
        if flavour.category == category:
            selected[key] = flavour
    subset = FlavourContainer(selected)
    if len(subset.flavours) == 0:
        raise KeyError(f"No flavours with category '{category}' found")
    return subset
|
85
91
|
|
86
92
|
def from_cuts(self, cuts: list | Cuts) -> Flavour:
|
87
93
|
if isinstance(cuts, list):
|
ftag/labeller.py
ADDED
@@ -0,0 +1,88 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from dataclasses import dataclass
|
4
|
+
|
5
|
+
import numpy as np
|
6
|
+
|
7
|
+
from ftag import Flavours
|
8
|
+
from ftag.flavour import Flavour, FlavourContainer
|
9
|
+
from ftag.hdf5 import join_structured_arrays, structured_from_dict
|
10
|
+
|
11
|
+
|
12
|
+
@dataclass
class Labeller:
    """
    Defines a labelling scheme.

    Labels are the integers [0, ..., n - 1], where n is the number of flavours.
    They are assigned using each flavour's pre-defined selection cuts. Flavours
    are sorted (by name, via ``Flavour.__lt__``) on construction, so a
    flavour's integer label is its position in that sorted order.

    Parameters
    ----------
    labels : FlavourContainer | list[str | Flavour]
        The labels to use.
    require_labels : bool
        Whether to require that all objects are labelled.
    """

    labels: FlavourContainer | list[str | Flavour]
    require_labels: bool = True

    def __post_init__(self) -> None:
        # Normalise to a sorted list of Flavour objects so that the integer
        # label of each flavour does not depend on the input order.
        # NOTE(review): relies on `Flavours[...]` accepting both flavour names
        # and Flavour objects — confirm against FlavourContainer.__getitem__.
        if isinstance(self.labels, FlavourContainer):
            self.labels = list(self.labels)
        self.labels = sorted([Flavours[label] for label in self.labels])

    def get_labels(self, array: np.ndarray) -> np.ndarray:
        """
        Returns the labels for the given array.

        Parameters
        ----------
        array : np.ndarray
            The array to label.

        Returns
        -------
        np.ndarray
            The labels for the given array. If ``require_labels`` is ``False``,
            objects not selected by any flavour are dropped, so the result can
            be shorter than ``array``.

        Raises
        ------
        ValueError
            If the `require_labels` attribute is set to `True` and some objects were not labelled.
        """
        # -1 marks objects that no flavour selection has matched (yet)
        labels = np.full(np.shape(array), -1, dtype=int)
        for i, label in enumerate(self.labels):
            # if selections overlap, the flavour processed last (later in the
            # sorted order) overwrites earlier assignments
            labels[label.cuts(array).idx] = i

        if self.require_labels and -1 in labels:
            raise ValueError("Some objects were not labelled")

        return labels[labels != -1]

    def add_labels(self, array: np.ndarray, label_name: str = "labels") -> np.ndarray:
        """
        Adds the labels to the given array.

        Parameters
        ----------
        array : np.ndarray
            The array to label.
        label_name : str
            The name of the label column.

        Returns
        -------
        np.ndarray
            The array with the labels added.

        Raises
        ------
        ValueError
            If the `require_labels` attribute is set to `False`.
        """
        # without require_labels the label array may be shorter than `array`,
        # so the two could not be joined row-by-row
        if not self.require_labels:
            raise ValueError("Cannot add labels if require_labels is set to False")
        labels = self.get_labels(array)
        labels = structured_from_dict({label_name: labels})
        return join_structured_arrays([array, labels])
|
ftag/mock.py
CHANGED
@@ -84,12 +84,7 @@ def get_mock_scores(labels: np.ndarray, is_xbb: bool = False):
|
|
84
84
|
return u2s(scores, dtype=np.dtype([(name, "f4") for name in cols]))
|
85
85
|
|
86
86
|
|
87
|
-
def
|
88
|
-
num_jets=1000,
|
89
|
-
fname: str | None = None,
|
90
|
-
tracks_name: str = "tracks",
|
91
|
-
num_tracks: int = 40,
|
92
|
-
) -> tuple[str, h5py.File]:
|
87
|
+
def mock_jets(num_jets=1000) -> np.ndarray:
|
93
88
|
# setup jets
|
94
89
|
rng = np.random.default_rng(42)
|
95
90
|
jets_dtype = np.dtype(JET_VARS)
|
@@ -106,7 +101,26 @@ def get_mock_file(
|
|
106
101
|
jets["R10TruthLabel_R22v1"] = rng.choice([1, 10, 11, 12], size=num_jets)
|
107
102
|
scores = get_mock_scores(jets["HadronConeExclTruthLabelID"])
|
108
103
|
xbb_scores = get_mock_scores(jets["R10TruthLabel_R22v1"], is_xbb=True)
|
109
|
-
|
104
|
+
return join_structured_arrays([jets, scores, xbb_scores])
|
105
|
+
|
106
|
+
|
107
|
+
def mock_tracks(num_jets=1000, num_tracks=40) -> np.ndarray:
    """
    Build a reproducible structured array of mock tracks.

    The result has shape ``(num_jets, num_tracks)``, with one field per entry
    in ``TRACK_VARS`` plus a boolean ``valid`` field. Feature values come from
    ``rng.random`` (seed 42) cast to the ``TRACK_VARS`` dtype, with ``d0``
    multiplied by 5. ``valid`` is a random True/False per track.
    """
    generator = np.random.default_rng(42)
    raw_values = generator.random((num_jets, num_tracks, len(TRACK_VARS)))
    track_array = u2s(raw_values, np.dtype(TRACK_VARS))
    track_array["d0"] *= 5

    # per-track validity flag, reinterpreted as a single-field structured array
    validity = generator.choice([True, False], size=(num_jets, num_tracks))
    validity = validity.astype(bool).view(dtype=np.dtype([("valid", bool)]))

    return join_structured_arrays([track_array, validity])
|
115
|
+
|
116
|
+
|
117
|
+
def get_mock_file(
|
118
|
+
num_jets=1000,
|
119
|
+
fname: str | None = None,
|
120
|
+
tracks_name: str = "tracks",
|
121
|
+
num_tracks: int = 40,
|
122
|
+
) -> tuple[str, h5py.File]:
|
123
|
+
jets = mock_jets(num_jets)
|
110
124
|
|
111
125
|
# create a tempfile in a new folder
|
112
126
|
if fname is None:
|
@@ -120,11 +134,7 @@ def get_mock_file(
|
|
120
134
|
|
121
135
|
# setup tracks
|
122
136
|
if tracks_name:
|
123
|
-
|
124
|
-
tracks = u2s(rng.random((num_jets, num_tracks, len(TRACK_VARS))), tracks_dtype)
|
125
|
-
valid = rng.choice([True, False], size=(num_jets, num_tracks))
|
126
|
-
valid = valid.astype(bool).view(dtype=np.dtype([("valid", bool)]))
|
127
|
-
tracks = join_structured_arrays([tracks, valid])
|
137
|
+
tracks = mock_tracks(num_jets, num_tracks)
|
128
138
|
f.create_dataset(tracks_name, data=tracks)
|
129
139
|
|
130
140
|
return fname, f
|
ftag/track_selector.py
ADDED
@@ -0,0 +1,53 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from dataclasses import dataclass
|
4
|
+
|
5
|
+
import numpy as np
|
6
|
+
|
7
|
+
from ftag import Cuts
|
8
|
+
|
9
|
+
|
10
|
+
@dataclass
class TrackSelector:
    """
    Apply track selections to a set of tracks stored in a structured numpy array.

    The array is assumed to be a structured array of shape (n_jets, n_tracks),
    with one field per track feature. Applying cuts will NaN out the tracks that
    do not pass the cuts, but leave the shape of the array unchanged. For
    rejected tracks, integer fields are set to -1 and boolean fields to False.

    Note that the input array is modified in place (and also returned).

    Parameters
    ----------
    cuts : Cuts
        The cuts to apply to the tracks
    valid_str : str
        The name of the field in the tracks that indicates whether the track is
        valid. Only valid tracks are considered for rejection.
    """

    cuts: Cuts
    valid_str: str = "valid"

    def __call__(self, tracks: np.ndarray) -> np.ndarray:
        """
        Apply the cuts to ``tracks`` in place.

        Parameters
        ----------
        tracks : np.ndarray
            Structured array of tracks containing the ``valid_str`` field.

        Returns
        -------
        np.ndarray
            The same array, with tracks failing any cut reset to default values
            and flagged as not valid.

        Raises
        ------
        TypeError
            If a field's dtype is not floating, integer or boolean.
        """
        # Interpret the valid flag as a boolean mask. With a non-bool flag
        # (e.g. int8), `flag & ~mask` would be an integer array, and indexing
        # with it would be fancy integer indexing rather than masking.
        valid = np.asarray(tracks[self.valid_str], dtype=bool)

        # mark valid tracks that fail at least one cut
        rm_idx = np.zeros_like(valid, dtype=bool)
        for cut in self.cuts.cuts:
            rm_idx[valid & ~cut(tracks)] = True

        # reset every field of the rejected tracks to a dtype-appropriate value
        for var in tracks.dtype.names:
            kind = tracks[var].dtype.type
            if issubclass(kind, np.floating):
                tracks[var][rm_idx] = np.nan
            elif issubclass(kind, np.integer):
                tracks[var][rm_idx] = -1
            elif issubclass(kind, np.bool_):
                tracks[var][rm_idx] = False
            else:
                raise TypeError(f"Unknown dtype {tracks[var].dtype}")

        # Explicitly flag the rejected tracks as invalid. This matters when the
        # flag is integer-typed, since the loop above set it to -1. Select the
        # field first: `tracks[rm_idx][...]` would write into the temporary copy
        # created by boolean indexing and leave `tracks` unchanged.
        tracks[self.valid_str][rm_idx] = False

        return tracks
|
File without changes
|
File without changes
|