atlas-ftag-tools 0.2.2__py3-none-any.whl → 0.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: atlas-ftag-tools
3
- Version: 0.2.2
3
+ Version: 0.2.4
4
4
  Summary: ATLAS Flavour Tagging Tools
5
5
  Author: Sam Van Stroud, Philipp Gadow
6
6
  License: MIT
@@ -1,26 +1,28 @@
1
- ftag/__init__.py,sha256=7IKOa65yKaQWsx6-s7VVQs4t1NQ9hyVAOlj-U5m-VBk,629
1
+ ftag/__init__.py,sha256=Mx2Emsw4TM1YL0wTQMHK36EaQE_ImeV6ukiz1X5BZAU,629
2
2
  ftag/cli_utils.py,sha256=w3TtQmUHSyAKChS3ewvOtcSDAUJAZGIIomaNi8f446U,298
3
- ftag/cuts.py,sha256=a0BJj4cVRunc-hFLPloGvNoSFvRmZg2kVLv7sA0iAaI,2817
4
- ftag/flavour.py,sha256=qvgp4DarOdcQgjae_NWnd81k_YqdmFY74lOKky2lpb8,3568
3
+ ftag/cuts.py,sha256=9_ooLZHaO3SnIQBNxwbaPZn-qptGdKnB27FdKQGTiTY,2933
4
+ ftag/flavour.py,sha256=EMZZLyl6lSdvkfrYxHhMcSn3aqP_FU7OpCFkvZpTksU,3761
5
5
  ftag/flavours.yaml,sha256=lFnVwjh_DwLhOc3mr5n6bSIWyHgxQvAXas4lEmEDncU,7520
6
6
  ftag/git_check.py,sha256=Y-XqM80CVXZ5ZKrDdZcYOJt3X64uU6W3OP6Z0D7AZU0,1663
7
- ftag/mock.py,sha256=QAm0ti6FWDCRtIyay4yozbGNNATDQbq5b1uc8uVhi2s,4275
7
+ ftag/labeller.py,sha256=uDygOhVGSNn96DWw8aErHpTtFsFX0RnxYYpy4g1FRog,2457
8
+ ftag/mock.py,sha256=_oy-r3eLllFy33NAoZaKfAx-Rp2vrCdrGj3UsTMks94,4740
8
9
  ftag/region.py,sha256=ANv0dGI2W6NJqD9fp7EfqAUReH4FOjc1gwl_Qn8llcM,360
9
10
  ftag/sample.py,sha256=TFXMhDkbPmjkms9-b-bINJ32T3bO86JcU70C0nY7wa8,2500
10
11
  ftag/test_cli_utils.py,sha256=xa08vf6SEOow58SSFagYdAselb-dkNOVvWsWheMnW-g,1001
12
+ ftag/track_selector.py,sha256=piSYAN_IkOsrXxKXjXbJpMSseUig5P2BJW5mCwsMUDM,2535
11
13
  ftag/transform.py,sha256=uEGGJSnqoKOzLYQv650XdK_kDNw4Aw-5dc60z9Dp_y0,3963
12
14
  ftag/vds.py,sha256=nRViQZQIORB95nC7NZsW3KsSoGkLzEdOsuCViH5h8-U,3296
13
15
  ftag/hdf5/__init__.py,sha256=LFDNxVOCp58SvLHwQhdT68Q-KBMS_i6jBrbXoRpHzbM,354
14
16
  ftag/hdf5/h5move.py,sha256=oYpRu0IDCIJIQ2ML52HBAdoyDxmKkHTeM9JdbPEgKfI,947
15
- ftag/hdf5/h5reader.py,sha256=H_5Aw0lOyEzK_phMRhD-jR_OSCsXnCA3qJZnRvPqaRU,13569
17
+ ftag/hdf5/h5reader.py,sha256=i31pDAqmOSaxdeRhc4iSBlld8xJ0pmp4rNd7CugNzw0,13706
16
18
  ftag/hdf5/h5split.py,sha256=4Wy6Xc3J58MdD9aBaSZHf5ZcVFnJSkWsm42R5Pgo-R4,2448
17
19
  ftag/hdf5/h5utils.py,sha256=-4zKTMtNCrDZr_9Ww7uzfsB7M7muBKpmm_1IkKJnHOI,3222
18
20
  ftag/hdf5/h5writer.py,sha256=j3Fy8snkiVVfimiUz3rrZOhSV8OF27978Y9pk0QcTGM,5277
19
21
  ftag/wps/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
20
22
  ftag/wps/discriminant.py,sha256=kJFekUTPNIvCabJCon6OqOAQEzz5hj3XrWFFRLOgGOs,3836
21
23
  ftag/wps/working_points.py,sha256=VTU6OD40ULAJQD0MlD1EZd33q8ociUvFX1YrhgJFvXc,9722
22
- atlas_ftag_tools-0.2.2.dist-info/METADATA,sha256=y2fq23cqtkaoUQxEiCrdoTVuBcG154yjo4k4cwf8P-A,5169
23
- atlas_ftag_tools-0.2.2.dist-info/WHEEL,sha256=UvcQYKBHoFqaQd6LKyqHw9fxEolWLQnlzP0h_LgJAfI,91
24
- atlas_ftag_tools-0.2.2.dist-info/entry_points.txt,sha256=LfVLsZHQolqbPnwPgtmc5IQTh527BKkN2v-IpXWTNHw,137
25
- atlas_ftag_tools-0.2.2.dist-info/top_level.txt,sha256=qiYQuKcAvMim-31FwkT3MTQu7WQm0s58tPAia5KKWqs,5
26
- atlas_ftag_tools-0.2.2.dist-info/RECORD,,
24
+ atlas_ftag_tools-0.2.4.dist-info/METADATA,sha256=f4aCu6JmItUBp5EmTzbrqhC5-Wsy7uiOiBO9yufyacQ,5169
25
+ atlas_ftag_tools-0.2.4.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
26
+ atlas_ftag_tools-0.2.4.dist-info/entry_points.txt,sha256=LfVLsZHQolqbPnwPgtmc5IQTh527BKkN2v-IpXWTNHw,137
27
+ atlas_ftag_tools-0.2.4.dist-info/top_level.txt,sha256=qiYQuKcAvMim-31FwkT3MTQu7WQm0s58tPAia5KKWqs,5
28
+ atlas_ftag_tools-0.2.4.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (74.0.0)
2
+ Generator: setuptools (75.1.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
ftag/__init__.py CHANGED
@@ -2,7 +2,7 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- __version__ = "v0.2.2"
5
+ __version__ = "v0.2.4"
6
6
 
7
7
 
8
8
  from ftag import hdf5
ftag/cuts.py CHANGED
@@ -79,7 +79,9 @@ class Cuts:
79
79
  def ignore(self, variables: list[str]):
80
80
  return Cuts(tuple(c for c in self if c.variable not in variables))
81
81
 
82
- def __call__(self, array) -> CutsResult:
82
+ def __call__(self, array: np.ndarray) -> CutsResult:
83
+ if array.ndim == 2:
84
+ raise ValueError("This interface only supports jet selections")
83
85
  keep = np.arange(len(array))
84
86
  for cut in self.cuts:
85
87
  idx = cut(array)
ftag/flavour.py CHANGED
@@ -42,6 +42,9 @@ class Flavour:
42
42
  def __str__(self) -> str:
43
43
  return self.name
44
44
 
45
+ def __lt__(self, other) -> bool:
46
+ return self.name < other.name
47
+
45
48
 
46
49
  @dataclass
47
50
  class FlavourContainer:
@@ -81,7 +84,10 @@ class FlavourContainer:
81
84
  return list(dict.fromkeys(f.category for f in self))
82
85
 
83
86
  def by_category(self, category: str) -> FlavourContainer:
84
- return FlavourContainer({k: v for k, v in self.flavours.items() if v.category == category})
87
+ f = FlavourContainer({k: v for k, v in self.flavours.items() if v.category == category})
88
+ if not f.flavours:
89
+ raise KeyError(f"No flavours with category '{category}' found")
90
+ return f
85
91
 
86
92
  def from_cuts(self, cuts: list | Cuts) -> Flavour:
87
93
  if isinstance(cuts, list):
ftag/hdf5/h5reader.py CHANGED
@@ -150,8 +150,13 @@ class H5Reader:
150
150
  transform : Transform | None, optional
151
151
  Transform to apply to data, by default None
152
152
  equal_jets : bool, optional
153
- Take the same number of jets (weighted) from each sample, by default True
154
- If False, use all jets in each sample.
153
+ Take the same number of jets (weighted) from each sample, by default True.
154
+ This is useful when you specify a list of DSIDs for the sample and they are
155
+ qualitatively different, and you want to ensure that you always return batches
156
+ with jets from all DSIDs. This is used for example in the QCD resampling for Xbb.
157
+ If False, use all jets in each sample, allowing for the full available statistics
158
+ to be used. Useful for example if you have multiple ttbar samples and you want to
159
+ use all available jets from each sample.
155
160
  """
156
161
 
157
162
  fname: Path | str | list[Path | str]
@@ -162,17 +167,10 @@ class H5Reader:
162
167
  weights: list[float] | None = None
163
168
  do_remove_inf: bool = False
164
169
  transform: Transform | None = None
165
- equal_jets: bool = True
170
+ equal_jets: bool = False
166
171
 
167
172
  def __post_init__(self) -> None:
168
173
  self.rng = np.random.default_rng(42)
169
- if not self.equal_jets:
170
- log.warning(
171
- "equal_jets is set to False, which will result in different number of jets taken"
172
- " from each sample. Be aware that this can affect the resampling, so make sure you"
173
- " know what you are doing."
174
- )
175
-
176
174
  if isinstance(self.fname, (str, Path)):
177
175
  self.fname = [self.fname]
178
176
 
@@ -283,8 +281,8 @@ class H5Reader:
283
281
  try:
284
282
  samples.append(next(stream))
285
283
 
286
- # if equal_jets is True, we can stop when any stream is done
287
- # otherwise if sample is exhausted, mark it as done
284
+ # if equal_jets is True, stop when any sample is done
285
+ # otherwise if stream is exhausted, mark it as such and continue
288
286
  except StopIteration:
289
287
  if self.equal_jets:
290
288
  return
ftag/labeller.py ADDED
@@ -0,0 +1,88 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+
5
+ import numpy as np
6
+
7
+ from ftag import Flavours
8
+ from ftag.flavour import Flavour, FlavourContainer
9
+ from ftag.hdf5 import join_structured_arrays, structured_from_dict
10
+
11
+
12
@dataclass
class Labeller:
    """
    Defines a labelling scheme.

    Labels are the integers ``0, ..., len(labels) - 1`` and are assigned using the
    pre-defined selection (``cuts``) of each flavour.

    Parameters
    ----------
    labels : FlavourContainer | list[str | Flavour]
        The labels to use. They are resolved through ``Flavours`` and sorted, so
        label ``i`` corresponds to the ``i``-th flavour in sorted order.
    require_labels : bool
        Whether to require that all objects are labelled.
    """

    labels: FlavourContainer | list[str | Flavour]
    require_labels: bool = True

    def __post_init__(self) -> None:
        # normalise to a sorted list of Flavour objects so label indices are deterministic
        if isinstance(self.labels, FlavourContainer):
            self.labels = list(self.labels)
        self.labels = sorted([Flavours[label] for label in self.labels])

    def get_labels(self, array: np.ndarray) -> np.ndarray:
        """
        Returns the labels for the given array.

        Each object receives the index of the flavour whose selection it passes.
        If an object passes the selections of several flavours, the flavour that
        comes last in ``self.labels`` wins.

        Parameters
        ----------
        array : np.ndarray
            The array to label.

        Returns
        -------
        np.ndarray
            The integer labels. If ``require_labels`` is ``False``, unlabelled
            objects are dropped, so the result may be shorter than ``array``
            and is then no longer aligned with it.

        Raises
        ------
        ValueError
            If the `require_labels` attribute is set to `True` and some objects were not labelled.
        """
        # -1 marks "not (yet) labelled"
        labels = np.full(array.shape, -1, dtype=int)
        for i, label in enumerate(self.labels):
            labels[label.cuts(array).idx] = i

        unlabelled = labels == -1
        if self.require_labels and unlabelled.any():
            raise ValueError("Some objects were not labelled")

        return labels[~unlabelled]

    def add_labels(self, array: np.ndarray, label_name: str = "labels") -> np.ndarray:
        """
        Adds the labels to the given array.

        Parameters
        ----------
        array : np.ndarray
            The array to label.
        label_name : str
            The name of the label column.

        Returns
        -------
        np.ndarray
            The array with the labels added as the field ``label_name``.

        Raises
        ------
        ValueError
            If the `require_labels` attribute is set to `False`.
        """
        # with require_labels=True, get_labels returns exactly one label per object,
        # so the labels stay aligned with the rows of `array`
        if not self.require_labels:
            raise ValueError("Cannot add labels if require_labels is set to False")
        labels = self.get_labels(array)
        labels = structured_from_dict({label_name: labels})
        return join_structured_arrays([array, labels])
ftag/mock.py CHANGED
@@ -84,12 +84,7 @@ def get_mock_scores(labels: np.ndarray, is_xbb: bool = False):
84
84
  return u2s(scores, dtype=np.dtype([(name, "f4") for name in cols]))
85
85
 
86
86
 
87
- def get_mock_file(
88
- num_jets=1000,
89
- fname: str | None = None,
90
- tracks_name: str = "tracks",
91
- num_tracks: int = 40,
92
- ) -> tuple[str, h5py.File]:
87
+ def mock_jets(num_jets=1000) -> np.ndarray:
93
88
  # setup jets
94
89
  rng = np.random.default_rng(42)
95
90
  jets_dtype = np.dtype(JET_VARS)
@@ -106,7 +101,31 @@ def get_mock_file(
106
101
  jets["R10TruthLabel_R22v1"] = rng.choice([1, 10, 11, 12], size=num_jets)
107
102
  scores = get_mock_scores(jets["HadronConeExclTruthLabelID"])
108
103
  xbb_scores = get_mock_scores(jets["R10TruthLabel_R22v1"], is_xbb=True)
109
- jets = join_structured_arrays([jets, scores, xbb_scores])
104
+ return join_structured_arrays([jets, scores, xbb_scores])
105
+
106
+
107
def mock_tracks(num_jets=1000, num_tracks=40) -> np.ndarray:
    """
    Build a reproducible mock structured array of tracks.

    There is one record per (jet, track) slot. The ``TRACK_VARS`` fields are
    filled from a seeded uniform random generator, with ``d0`` scaled by 5.
    The shared-hit counts are random integers in [0, 3), and a random boolean
    ``valid`` field is appended.
    """
    generator = np.random.default_rng(42)
    shape = (num_jets, num_tracks)

    feature_dtype = np.dtype(TRACK_VARS)
    raw = generator.random((*shape, len(TRACK_VARS)))
    tracks = u2s(raw, feature_dtype)
    tracks["d0"] *= 5

    # shared hits should look like small integer counts, not uniform floats
    for field in ("numberOfPixelSharedHits", "numberOfSCTSharedHits"):
        tracks[field] = generator.integers(0, 3, size=shape)

    valid_flags = generator.choice([True, False], size=shape).astype(bool)
    valid = valid_flags.view(dtype=np.dtype([("valid", bool)]))
    return join_structured_arrays([tracks, valid])
120
+
121
+
122
+ def get_mock_file(
123
+ num_jets=1000,
124
+ fname: str | None = None,
125
+ tracks_name: str = "tracks",
126
+ num_tracks: int = 40,
127
+ ) -> tuple[str, h5py.File]:
128
+ jets = mock_jets(num_jets)
110
129
 
111
130
  # create a tempfile in a new folder
112
131
  if fname is None:
@@ -120,11 +139,7 @@ def get_mock_file(
120
139
 
121
140
  # setup tracks
122
141
  if tracks_name:
123
- tracks_dtype = np.dtype(TRACK_VARS)
124
- tracks = u2s(rng.random((num_jets, num_tracks, len(TRACK_VARS))), tracks_dtype)
125
- valid = rng.choice([True, False], size=(num_jets, num_tracks))
126
- valid = valid.astype(bool).view(dtype=np.dtype([("valid", bool)]))
127
- tracks = join_structured_arrays([tracks, valid])
142
+ tracks = mock_tracks(num_jets, num_tracks)
128
143
  f.create_dataset(tracks_name, data=tracks)
129
144
 
130
145
  return fname, f
ftag/track_selector.py ADDED
@@ -0,0 +1,70 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+
5
+ import numpy as np
6
+
7
+ from ftag.cuts import Cut, Cuts
8
+
9
+
10
@dataclass
class TrackSelector:
    """
    Apply track selections to a set of tracks stored in a structured numpy array.

    The array is assumed to be a structured array of shape (n_jets, n_tracks),
    with one field per track feature. Applying the cuts masks out the valid
    tracks that fail any cut, but leaves the shape of the array unchanged.
    Floating point fields are set to NaN, integer fields to -1, and boolean
    fields (including the validity flag) to False.

    The special cut variable ``"NSHARED"`` applies the FTAG shared-module cut,
    computed as ``numberOfPixelSharedHits + numberOfSCTSharedHits / 2``.

    The input array is modified in place and also returned.

    Parameters
    ----------
    cuts : Cuts
        The cuts to apply to the tracks.
    valid_str : str
        The name of the boolean field in the tracks that indicates whether the
        track is valid.
    """

    cuts: Cuts
    valid_str: str = "valid"

    def __call__(self, tracks: np.ndarray) -> np.ndarray:
        """
        Apply the cuts to ``tracks`` in place.

        Parameters
        ----------
        tracks : np.ndarray
            Structured array of tracks containing ``valid_str`` and every
            variable used by the cuts.

        Returns
        -------
        np.ndarray
            The same array, with the tracks failing any cut masked out.

        Raises
        ------
        TypeError
            If a field is not of floating point, integer or boolean dtype.
        ValueError
            If an ``"NSHARED"`` cut is requested but the tracks already contain
            an ``"NSHARED"`` field.
        """
        valid = tracks[self.valid_str]
        # mask of tracks to remove; nothing is removed before any cut is applied
        rm_idx = np.zeros_like(valid, dtype=bool)

        for cut in self.cuts.cuts:
            if cut.variable == "NSHARED":
                keep_idx = self._nshared_cut(cut, tracks)
            else:
                keep_idx = cut(tracks)
            # only valid tracks that fail the selection are removed
            rm_idx[valid & ~keep_idx] = True

        # overwrite the removed tracks with a fill value appropriate for each dtype
        for var in tracks.dtype.names:
            kind = tracks[var].dtype.type
            if issubclass(kind, np.floating):
                tracks[var][rm_idx] = np.nan
            elif issubclass(kind, np.integer):
                # NOTE(review): unsigned integer fields cannot represent -1
                tracks[var][rm_idx] = -1
            elif issubclass(kind, np.bool_):
                tracks[var][rm_idx] = False
            else:
                raise TypeError(f"Unknown dtype {tracks[var].dtype}")

        # Explicitly mark the removed tracks as invalid (already done by the loop above).
        # The field must be indexed first: `tracks[rm_idx]` is a copy, so assigning
        # into it would silently have no effect.
        tracks[self.valid_str][rm_idx] = False

        return tracks

    def _nshared_cut(self, cut: Cut, tracks: np.ndarray) -> np.ndarray:
        """
        Apply the FTAG shared-module cut, which requires an intermediate variable.

        The number of shared modules is computed as
        ``numberOfPixelSharedHits + numberOfSCTSharedHits / 2``. It is then wrapped
        in a single-field structured array named after ``cut.variable``, so that
        ``cut`` can be evaluated on it.
        """
        if cut.variable == "NSHARED" and "NSHARED" in tracks.dtype.names:
            raise ValueError("NSHARED is a reserved variable name")

        n_pix_shared = tracks["numberOfPixelSharedHits"]
        n_sct_shared = tracks["numberOfSCTSharedHits"]
        n_module_shared = n_pix_shared + n_sct_shared / 2

        # wrap in a structured array so the cut can look up its variable by name
        n_module_shared = n_module_shared.view(dtype=[(cut.variable, n_module_shared.dtype)])

        return cut(n_module_shared)