atlas-ftag-tools 0.2.2__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: atlas-ftag-tools
3
- Version: 0.2.2
3
+ Version: 0.2.3
4
4
  Summary: ATLAS Flavour Tagging Tools
5
5
  Author: Sam Van Stroud, Philipp Gadow
6
6
  License: MIT
@@ -1,13 +1,15 @@
1
- ftag/__init__.py,sha256=7IKOa65yKaQWsx6-s7VVQs4t1NQ9hyVAOlj-U5m-VBk,629
1
+ ftag/__init__.py,sha256=i5XMWwkyGmwIIIaO3veRyOnTsJ0qKyW1V6HsuYc9Dm4,629
2
2
  ftag/cli_utils.py,sha256=w3TtQmUHSyAKChS3ewvOtcSDAUJAZGIIomaNi8f446U,298
3
- ftag/cuts.py,sha256=a0BJj4cVRunc-hFLPloGvNoSFvRmZg2kVLv7sA0iAaI,2817
4
- ftag/flavour.py,sha256=qvgp4DarOdcQgjae_NWnd81k_YqdmFY74lOKky2lpb8,3568
3
+ ftag/cuts.py,sha256=9_ooLZHaO3SnIQBNxwbaPZn-qptGdKnB27FdKQGTiTY,2933
4
+ ftag/flavour.py,sha256=EMZZLyl6lSdvkfrYxHhMcSn3aqP_FU7OpCFkvZpTksU,3761
5
5
  ftag/flavours.yaml,sha256=lFnVwjh_DwLhOc3mr5n6bSIWyHgxQvAXas4lEmEDncU,7520
6
6
  ftag/git_check.py,sha256=Y-XqM80CVXZ5ZKrDdZcYOJt3X64uU6W3OP6Z0D7AZU0,1663
7
- ftag/mock.py,sha256=QAm0ti6FWDCRtIyay4yozbGNNATDQbq5b1uc8uVhi2s,4275
7
+ ftag/labeller.py,sha256=uDygOhVGSNn96DWw8aErHpTtFsFX0RnxYYpy4g1FRog,2457
8
+ ftag/mock.py,sha256=FboI1Kq6aKZv43SpubiLkJvn4BSMDv2Fl2UniuhxspU,4502
8
9
  ftag/region.py,sha256=ANv0dGI2W6NJqD9fp7EfqAUReH4FOjc1gwl_Qn8llcM,360
9
10
  ftag/sample.py,sha256=TFXMhDkbPmjkms9-b-bINJ32T3bO86JcU70C0nY7wa8,2500
10
11
  ftag/test_cli_utils.py,sha256=xa08vf6SEOow58SSFagYdAselb-dkNOVvWsWheMnW-g,1001
12
+ ftag/track_selector.py,sha256=WQ6lzY8n6pumIVfLMLgjlTkWUaIQu-8Dq3FKsmCsG_0,1736
11
13
  ftag/transform.py,sha256=uEGGJSnqoKOzLYQv650XdK_kDNw4Aw-5dc60z9Dp_y0,3963
12
14
  ftag/vds.py,sha256=nRViQZQIORB95nC7NZsW3KsSoGkLzEdOsuCViH5h8-U,3296
13
15
  ftag/hdf5/__init__.py,sha256=LFDNxVOCp58SvLHwQhdT68Q-KBMS_i6jBrbXoRpHzbM,354
@@ -19,8 +21,8 @@ ftag/hdf5/h5writer.py,sha256=j3Fy8snkiVVfimiUz3rrZOhSV8OF27978Y9pk0QcTGM,5277
19
21
  ftag/wps/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
20
22
  ftag/wps/discriminant.py,sha256=kJFekUTPNIvCabJCon6OqOAQEzz5hj3XrWFFRLOgGOs,3836
21
23
  ftag/wps/working_points.py,sha256=VTU6OD40ULAJQD0MlD1EZd33q8ociUvFX1YrhgJFvXc,9722
22
- atlas_ftag_tools-0.2.2.dist-info/METADATA,sha256=y2fq23cqtkaoUQxEiCrdoTVuBcG154yjo4k4cwf8P-A,5169
23
- atlas_ftag_tools-0.2.2.dist-info/WHEEL,sha256=UvcQYKBHoFqaQd6LKyqHw9fxEolWLQnlzP0h_LgJAfI,91
24
- atlas_ftag_tools-0.2.2.dist-info/entry_points.txt,sha256=LfVLsZHQolqbPnwPgtmc5IQTh527BKkN2v-IpXWTNHw,137
25
- atlas_ftag_tools-0.2.2.dist-info/top_level.txt,sha256=qiYQuKcAvMim-31FwkT3MTQu7WQm0s58tPAia5KKWqs,5
26
- atlas_ftag_tools-0.2.2.dist-info/RECORD,,
24
+ atlas_ftag_tools-0.2.3.dist-info/METADATA,sha256=XP7QCxOKz-QUKKx5ds8ldz13n_qfTVcP-SZciExdIr0,5169
25
+ atlas_ftag_tools-0.2.3.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
26
+ atlas_ftag_tools-0.2.3.dist-info/entry_points.txt,sha256=LfVLsZHQolqbPnwPgtmc5IQTh527BKkN2v-IpXWTNHw,137
27
+ atlas_ftag_tools-0.2.3.dist-info/top_level.txt,sha256=qiYQuKcAvMim-31FwkT3MTQu7WQm0s58tPAia5KKWqs,5
28
+ atlas_ftag_tools-0.2.3.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (74.0.0)
2
+ Generator: setuptools (75.1.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
ftag/__init__.py CHANGED
@@ -2,7 +2,7 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- __version__ = "v0.2.2"
5
+ __version__ = "v0.2.3"
6
6
 
7
7
 
8
8
  from ftag import hdf5
ftag/cuts.py CHANGED
@@ -79,7 +79,9 @@ class Cuts:
79
79
  def ignore(self, variables: list[str]):
80
80
  return Cuts(tuple(c for c in self if c.variable not in variables))
81
81
 
82
- def __call__(self, array) -> CutsResult:
82
+ def __call__(self, array: np.ndarray) -> CutsResult:
83
+ if array.ndim == 2:
84
+ raise ValueError("This interface only supports jet selections")
83
85
  keep = np.arange(len(array))
84
86
  for cut in self.cuts:
85
87
  idx = cut(array)
ftag/flavour.py CHANGED
@@ -42,6 +42,9 @@ class Flavour:
42
42
  def __str__(self) -> str:
43
43
  return self.name
44
44
 
45
+ def __lt__(self, other) -> bool:
46
+ return self.name < other.name
47
+
45
48
 
46
49
  @dataclass
47
50
  class FlavourContainer:
@@ -81,7 +84,10 @@ class FlavourContainer:
81
84
  return list(dict.fromkeys(f.category for f in self))
82
85
 
83
86
  def by_category(self, category: str) -> FlavourContainer:
84
- return FlavourContainer({k: v for k, v in self.flavours.items() if v.category == category})
87
+ f = FlavourContainer({k: v for k, v in self.flavours.items() if v.category == category})
88
+ if not f.flavours:
89
+ raise KeyError(f"No flavours with category '{category}' found")
90
+ return f
85
91
 
86
92
  def from_cuts(self, cuts: list | Cuts) -> Flavour:
87
93
  if isinstance(cuts, list):
ftag/labeller.py ADDED
@@ -0,0 +1,88 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+
5
+ import numpy as np
6
+
7
+ from ftag import Flavours
8
+ from ftag.flavour import Flavour, FlavourContainer
9
+ from ftag.hdf5 import join_structured_arrays, structured_from_dict
10
+
11
+
12
+ @dataclass
13
+ class Labeller:
14
+ """
15
+ Defines a labelling scheme.
16
+
17
+ Labels are [0, ..., n] and are assigned using pre-defined selections.
18
+
19
+ Parameters
20
+ ----------
21
+ labels : FlavourContainer | list[str | Flavour]
22
+ The labels to be use.
23
+ require_labels : bool
24
+ Whether to require that all objects are labelled.
25
+ """
26
+
27
+ labels: FlavourContainer | list[str | Flavour]
28
+ require_labels: bool = True
29
+
30
+ def __post_init__(self) -> None:
31
+ if isinstance(self.labels, FlavourContainer):
32
+ self.labels = list(self.labels)
33
+ self.labels = sorted([Flavours[label] for label in self.labels])
34
+
35
+ def get_labels(self, array: np.ndarray) -> np.ndarray:
36
+ """
37
+ Returns the labels for the given array.
38
+
39
+ Parameters
40
+ ----------
41
+ array : np.ndarray
42
+ The array to label.
43
+
44
+ Returns
45
+ -------
46
+ np.ndarray
47
+ The labels for the given array.
48
+
49
+ Raises
50
+ ------
51
+ ValueError
52
+ If the `require_labels` attribute is set to `True` and some objects were not labelled.
53
+ """
54
+ labels = -1 * np.ones_like(array, dtype=int)
55
+ for i, label in enumerate(self.labels):
56
+ labels[label.cuts(array).idx] = i
57
+
58
+ if self.require_labels and -1 in labels:
59
+ raise ValueError("Some objects were not labelled")
60
+
61
+ return labels[labels != -1]
62
+
63
+ def add_labels(self, array: np.ndarray, label_name: str = "labels") -> np.ndarray:
64
+ """
65
+ Adds the labels to the given array.
66
+
67
+ Parameters
68
+ ----------
69
+ array : np.ndarray
70
+ The array to label.
71
+ label_name : str
72
+ The name of the label column.
73
+
74
+ Returns
75
+ -------
76
+ np.ndarray
77
+ The array with the labels added.
78
+
79
+ Raises
80
+ ------
81
+ ValueError
82
+ If the `require_labels` attribute is set to `False`.
83
+ """
84
+ if not self.require_labels:
85
+ raise ValueError("Cannot add labels if require_labels is set to False")
86
+ labels = self.get_labels(array)
87
+ labels = structured_from_dict({label_name: labels})
88
+ return join_structured_arrays([array, labels])
ftag/mock.py CHANGED
@@ -84,12 +84,7 @@ def get_mock_scores(labels: np.ndarray, is_xbb: bool = False):
84
84
  return u2s(scores, dtype=np.dtype([(name, "f4") for name in cols]))
85
85
 
86
86
 
87
- def get_mock_file(
88
- num_jets=1000,
89
- fname: str | None = None,
90
- tracks_name: str = "tracks",
91
- num_tracks: int = 40,
92
- ) -> tuple[str, h5py.File]:
87
+ def mock_jets(num_jets=1000) -> np.ndarray:
93
88
  # setup jets
94
89
  rng = np.random.default_rng(42)
95
90
  jets_dtype = np.dtype(JET_VARS)
@@ -106,7 +101,26 @@ def get_mock_file(
106
101
  jets["R10TruthLabel_R22v1"] = rng.choice([1, 10, 11, 12], size=num_jets)
107
102
  scores = get_mock_scores(jets["HadronConeExclTruthLabelID"])
108
103
  xbb_scores = get_mock_scores(jets["R10TruthLabel_R22v1"], is_xbb=True)
109
- jets = join_structured_arrays([jets, scores, xbb_scores])
104
+ return join_structured_arrays([jets, scores, xbb_scores])
105
+
106
+
107
def mock_tracks(num_jets=1000, num_tracks=40) -> np.ndarray:
    """
    Generate a reproducible mock track array.

    Uniform random values are drawn for every variable in ``TRACK_VARS`` using
    a generator seeded with 42, the ``d0`` field is scaled by a factor of 5 and
    a randomly filled boolean ``valid`` field is appended.

    Parameters
    ----------
    num_jets : int
        Number of jets, i.e. the size of the first axis.
    num_tracks : int
        Number of tracks per jet, i.e. the size of the second axis.

    Returns
    -------
    np.ndarray
        The joined structured array holding the track variables and ``valid``.
    """
    generator = np.random.default_rng(42)
    per_track_shape = (num_jets, num_tracks)

    # draw the track variables first so the random stream matches the seed layout
    raw_values = generator.random((*per_track_shape, len(TRACK_VARS)))
    track_fields = u2s(raw_values, np.dtype(TRACK_VARS))
    track_fields["d0"] *= 5

    # reinterpret the boolean flags as a single-field structured array
    flags = generator.choice([True, False], size=per_track_shape)
    valid_field = flags.astype(bool).view(dtype=np.dtype([("valid", bool)]))

    return join_structured_arrays([track_fields, valid_field])
115
+
116
+
117
+ def get_mock_file(
118
+ num_jets=1000,
119
+ fname: str | None = None,
120
+ tracks_name: str = "tracks",
121
+ num_tracks: int = 40,
122
+ ) -> tuple[str, h5py.File]:
123
+ jets = mock_jets(num_jets)
110
124
 
111
125
  # create a tempfile in a new folder
112
126
  if fname is None:
@@ -120,11 +134,7 @@ def get_mock_file(
120
134
 
121
135
  # setup tracks
122
136
  if tracks_name:
123
- tracks_dtype = np.dtype(TRACK_VARS)
124
- tracks = u2s(rng.random((num_jets, num_tracks, len(TRACK_VARS))), tracks_dtype)
125
- valid = rng.choice([True, False], size=(num_jets, num_tracks))
126
- valid = valid.astype(bool).view(dtype=np.dtype([("valid", bool)]))
127
- tracks = join_structured_arrays([tracks, valid])
137
+ tracks = mock_tracks(num_jets, num_tracks)
128
138
  f.create_dataset(tracks_name, data=tracks)
129
139
 
130
140
  return fname, f
ftag/track_selector.py ADDED
@@ -0,0 +1,53 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+
5
+ import numpy as np
6
+
7
+ from ftag import Cuts
8
+
9
+
10
+ @dataclass
11
+ class TrackSelector:
12
+ """
13
+ Apply track selections to a set of tracks stored in a structured numpy array.
14
+
15
+ The array is assumed to have shape (n_jets, n_tracks, n_features).
16
+ Applying cuts will NaN out the tracks that do not pass the cuts,
17
+ but leave the shape of the array unchanged.
18
+
19
+ Parameters
20
+ ----------
21
+ cuts : Cuts
22
+ The cuts to apply to the tracks
23
+ valid_str : str
24
+ The name of the field in the tracks that indicates whether the track is
25
+ """
26
+
27
+ cuts: Cuts
28
+ valid_str: str = "valid"
29
+
30
+ def __call__(self, tracks: np.ndarray) -> np.ndarray:
31
+ # get a bool array for all tracks passing before any cuts
32
+ rm_idx = np.zeros_like(tracks[self.valid_str], dtype=bool)
33
+
34
+ # apply the cuts
35
+ for cut in self.cuts.cuts:
36
+ # remove valid track indices that do not pass the selection
37
+ rm_idx[tracks[self.valid_str] & ~cut(tracks)] = True
38
+
39
+ # set the values of the tracks that do not pass the cuts to
40
+ for var in tracks.dtype.names:
41
+ if issubclass(tracks[var].dtype.type, np.floating):
42
+ tracks[var][rm_idx] = np.nan
43
+ elif issubclass(tracks[var].dtype.type, np.integer):
44
+ tracks[var][rm_idx] = -1
45
+ elif issubclass(tracks[var].dtype.type, np.bool_):
46
+ tracks[var][rm_idx] = False
47
+ else:
48
+ raise TypeError(f"Unknown dtype {tracks[var].dtype}")
49
+
50
+ # specifically set the valid flag to false (even though it's already false by now)
51
+ tracks[rm_idx][self.valid_str] = False
52
+
53
+ return tracks