atlas-ftag-tools 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: atlas-ftag-tools
3
- Version: 0.2.1
3
+ Version: 0.2.3
4
4
  Summary: ATLAS Flavour Tagging Tools
5
5
  Author: Sam Van Stroud, Philipp Gadow
6
6
  License: MIT
@@ -11,8 +11,8 @@ Requires-Dist: h5py >=3.0
11
11
  Requires-Dist: numpy
12
12
  Requires-Dist: PyYAML >=5.1
13
13
  Provides-Extra: dev
14
- Requires-Dist: ruff ==0.2.2 ; extra == 'dev'
15
- Requires-Dist: mypy ==1.5.1 ; extra == 'dev'
14
+ Requires-Dist: ruff ==0.6.2 ; extra == 'dev'
15
+ Requires-Dist: mypy ==1.11.2 ; extra == 'dev'
16
16
  Requires-Dist: pre-commit ==3.1.1 ; extra == 'dev'
17
17
  Requires-Dist: pytest ==7.2.2 ; extra == 'dev'
18
18
  Requires-Dist: pytest-cov ==4.0.0 ; extra == 'dev'
@@ -0,0 +1,28 @@
1
+ ftag/__init__.py,sha256=i5XMWwkyGmwIIIaO3veRyOnTsJ0qKyW1V6HsuYc9Dm4,629
2
+ ftag/cli_utils.py,sha256=w3TtQmUHSyAKChS3ewvOtcSDAUJAZGIIomaNi8f446U,298
3
+ ftag/cuts.py,sha256=9_ooLZHaO3SnIQBNxwbaPZn-qptGdKnB27FdKQGTiTY,2933
4
+ ftag/flavour.py,sha256=EMZZLyl6lSdvkfrYxHhMcSn3aqP_FU7OpCFkvZpTksU,3761
5
+ ftag/flavours.yaml,sha256=lFnVwjh_DwLhOc3mr5n6bSIWyHgxQvAXas4lEmEDncU,7520
6
+ ftag/git_check.py,sha256=Y-XqM80CVXZ5ZKrDdZcYOJt3X64uU6W3OP6Z0D7AZU0,1663
7
+ ftag/labeller.py,sha256=uDygOhVGSNn96DWw8aErHpTtFsFX0RnxYYpy4g1FRog,2457
8
+ ftag/mock.py,sha256=FboI1Kq6aKZv43SpubiLkJvn4BSMDv2Fl2UniuhxspU,4502
9
+ ftag/region.py,sha256=ANv0dGI2W6NJqD9fp7EfqAUReH4FOjc1gwl_Qn8llcM,360
10
+ ftag/sample.py,sha256=TFXMhDkbPmjkms9-b-bINJ32T3bO86JcU70C0nY7wa8,2500
11
+ ftag/test_cli_utils.py,sha256=xa08vf6SEOow58SSFagYdAselb-dkNOVvWsWheMnW-g,1001
12
+ ftag/track_selector.py,sha256=WQ6lzY8n6pumIVfLMLgjlTkWUaIQu-8Dq3FKsmCsG_0,1736
13
+ ftag/transform.py,sha256=uEGGJSnqoKOzLYQv650XdK_kDNw4Aw-5dc60z9Dp_y0,3963
14
+ ftag/vds.py,sha256=nRViQZQIORB95nC7NZsW3KsSoGkLzEdOsuCViH5h8-U,3296
15
+ ftag/hdf5/__init__.py,sha256=LFDNxVOCp58SvLHwQhdT68Q-KBMS_i6jBrbXoRpHzbM,354
16
+ ftag/hdf5/h5move.py,sha256=oYpRu0IDCIJIQ2ML52HBAdoyDxmKkHTeM9JdbPEgKfI,947
17
+ ftag/hdf5/h5reader.py,sha256=H_5Aw0lOyEzK_phMRhD-jR_OSCsXnCA3qJZnRvPqaRU,13569
18
+ ftag/hdf5/h5split.py,sha256=4Wy6Xc3J58MdD9aBaSZHf5ZcVFnJSkWsm42R5Pgo-R4,2448
19
+ ftag/hdf5/h5utils.py,sha256=-4zKTMtNCrDZr_9Ww7uzfsB7M7muBKpmm_1IkKJnHOI,3222
20
+ ftag/hdf5/h5writer.py,sha256=j3Fy8snkiVVfimiUz3rrZOhSV8OF27978Y9pk0QcTGM,5277
21
+ ftag/wps/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
22
+ ftag/wps/discriminant.py,sha256=kJFekUTPNIvCabJCon6OqOAQEzz5hj3XrWFFRLOgGOs,3836
23
+ ftag/wps/working_points.py,sha256=VTU6OD40ULAJQD0MlD1EZd33q8ociUvFX1YrhgJFvXc,9722
24
+ atlas_ftag_tools-0.2.3.dist-info/METADATA,sha256=XP7QCxOKz-QUKKx5ds8ldz13n_qfTVcP-SZciExdIr0,5169
25
+ atlas_ftag_tools-0.2.3.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
26
+ atlas_ftag_tools-0.2.3.dist-info/entry_points.txt,sha256=LfVLsZHQolqbPnwPgtmc5IQTh527BKkN2v-IpXWTNHw,137
27
+ atlas_ftag_tools-0.2.3.dist-info/top_level.txt,sha256=qiYQuKcAvMim-31FwkT3MTQu7WQm0s58tPAia5KKWqs,5
28
+ atlas_ftag_tools-0.2.3.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (70.3.0)
2
+ Generator: setuptools (75.1.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
ftag/__init__.py CHANGED
@@ -2,7 +2,7 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- __version__ = "v0.2.1"
5
+ __version__ = "v0.2.3"
6
6
 
7
7
 
8
8
  from ftag import hdf5
ftag/cuts.py CHANGED
@@ -79,7 +79,9 @@ class Cuts:
79
79
  def ignore(self, variables: list[str]):
80
80
  return Cuts(tuple(c for c in self if c.variable not in variables))
81
81
 
82
- def __call__(self, array) -> CutsResult:
82
+ def __call__(self, array: np.ndarray) -> CutsResult:
83
+ if array.ndim == 2:
84
+ raise ValueError("This interface only supports jet selections")
83
85
  keep = np.arange(len(array))
84
86
  for cut in self.cuts:
85
87
  idx = cut(array)
ftag/flavour.py CHANGED
@@ -42,6 +42,9 @@ class Flavour:
42
42
  def __str__(self) -> str:
43
43
  return self.name
44
44
 
45
+ def __lt__(self, other) -> bool:
46
+ return self.name < other.name
47
+
45
48
 
46
49
  @dataclass
47
50
  class FlavourContainer:
@@ -81,7 +84,10 @@ class FlavourContainer:
81
84
  return list(dict.fromkeys(f.category for f in self))
82
85
 
83
86
  def by_category(self, category: str) -> FlavourContainer:
84
- return FlavourContainer({k: v for k, v in self.flavours.items() if v.category == category})
87
+ f = FlavourContainer({k: v for k, v in self.flavours.items() if v.category == category})
88
+ if not f.flavours:
89
+ raise KeyError(f"No flavours with category '{category}' found")
90
+ return f
85
91
 
86
92
  def from_cuts(self, cuts: list | Cuts) -> Flavour:
87
93
  if isinstance(cuts, list):
ftag/flavours.yaml CHANGED
@@ -47,6 +47,28 @@
47
47
  colour: gold
48
48
  category: single-btag-extended
49
49
 
50
+ # single b-tagging (ghost association)
51
+ - name: ghostbjets
52
+ label: $b$-jets
53
+ cuts: ["HadronGhostTruthLabelID == 5"]
54
+ colour: tab:blue
55
+ category: single-btag-ghost
56
+ - name: ghostcjets
57
+ label: $c$-jets
58
+ cuts: ["HadronGhostTruthLabelID == 4"]
59
+ colour: tab:orange
60
+ category: single-btag-ghost
61
+ - name: ghostujets
62
+ label: Light-jets
63
+ cuts: ["HadronGhostTruthLabelID == 0"]
64
+ colour: tab:green
65
+ category: single-btag-ghost
66
+ - name: ghosttaujets
67
+ label: $\tau$-jets
68
+ cuts: ["HadronGhostTruthLabelID == 15"]
69
+ colour: tab:purple
70
+ category: single-btag-ghost
71
+
50
72
  # Xbb tagging
51
73
  - name: hbb
52
74
  label: $H \rightarrow b\bar{b}$
ftag/hdf5/h5reader.py CHANGED
@@ -60,7 +60,7 @@ class H5SingleReader:
60
60
  for var in array.dtype.names:
61
61
  isinf = np.isinf(array[var])
62
62
  isinf = isinf if name == self.jets_name else isinf.any(axis=-1)
63
- keep_idx = keep_idx & ~isinf
63
+ keep_idx &= ~isinf
64
64
  if num_inf := isinf.sum():
65
65
  log.warning(
66
66
  f"{num_inf} inf values detected for variable {var} in"
ftag/hdf5/h5writer.py CHANGED
@@ -41,7 +41,7 @@ class H5Writer:
41
41
  jets_name: str = "jets"
42
42
  add_flavour_label: bool = False
43
43
  compression: str = "lzf"
44
- precision: str | None = None
44
+ precision: str = "full"
45
45
  shuffle: bool = True
46
46
 
47
47
  def __post_init__(self):
@@ -51,6 +51,13 @@ class H5Writer:
51
51
  assert len(set(self.num_jets)) == 1, "Must have same number of jets per group"
52
52
  self.num_jets = self.num_jets[0]
53
53
 
54
+ if self.precision == "full":
55
+ self.fp_dtype = np.float32
56
+ elif self.precision == "half":
57
+ self.fp_dtype = np.float16
58
+ else:
59
+ raise ValueError(f"Invalid precision: {self.precision}")
60
+
54
61
  self.dst = Path(self.dst)
55
62
  self.dst.parent.mkdir(parents=True, exist_ok=True)
56
63
  self.file = h5py.File(self.dst, "w")
@@ -77,11 +84,17 @@ class H5Writer:
77
84
  if name == self.jets_name and self.add_flavour_label:
78
85
  dtype = np.dtype([*dtype.descr, ("flavour_label", "i4")])
79
86
 
87
+ # adjust dtype based on specified precision
88
+ dtype = np.dtype([
89
+ (field, self.fp_dtype if np.issubdtype(dt, np.floating) else dt)
90
+ for field, dt in dtype.descr
91
+ ])
92
+
80
93
  # optimal chunking is around 100 jets, only aply for track groups
81
94
  shape = self.shapes[name]
82
95
  chunks = (100,) + shape[1:] if shape[1:] else None
83
96
 
84
- # note: enabling the hd5 shuffle filter doesn't improve anything
97
+ # note: enabling the hd5 shuffle filter doesn't improve write performance
85
98
  self.file.create_dataset(
86
99
  name, dtype=dtype, shape=shape, compression=self.compression, chunks=chunks
87
100
  )
ftag/labeller.py ADDED
@@ -0,0 +1,88 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+
5
+ import numpy as np
6
+
7
+ from ftag import Flavours
8
+ from ftag.flavour import Flavour, FlavourContainer
9
+ from ftag.hdf5 import join_structured_arrays, structured_from_dict
10
+
11
+
12
@dataclass
class Labeller:
    """
    Defines a labelling scheme.

    Labels are the integers [0, ..., n - 1], where n is the number of flavours,
    and are assigned using the pre-defined selections (cuts) of each flavour.
    Flavours are sorted lexicographically by name, so label ``i`` corresponds
    to the ``i``-th flavour in that order.

    Parameters
    ----------
    labels : FlavourContainer | list[str | Flavour]
        The flavours to be used as labels.
    require_labels : bool
        Whether to require that all objects are labelled.
    """

    labels: FlavourContainer | list[str | Flavour]
    require_labels: bool = True

    def __post_init__(self) -> None:
        # Resolve every entry through the global ``Flavours`` registry and sort
        # the result (Flavour.__lt__ compares names) so the label <-> flavour
        # mapping is deterministic regardless of the input order.
        if isinstance(self.labels, FlavourContainer):
            self.labels = list(self.labels)
        self.labels = sorted([Flavours[label] for label in self.labels])

    def get_labels(self, array: np.ndarray) -> np.ndarray:
        """
        Returns the labels for the given array.

        Each object receives the index of the flavour whose cuts select it.
        If the selections of several flavours overlap, the flavour that comes
        later in ``self.labels`` takes precedence.

        Parameters
        ----------
        array : np.ndarray
            The array to label.

        Returns
        -------
        np.ndarray
            The integer labels. Objects not selected by any flavour are dropped,
            so when ``require_labels`` is ``False`` the result can be shorter
            than ``array`` and is then no longer aligned with it.

        Raises
        ------
        ValueError
            If the `require_labels` attribute is set to `True` and some objects were not labelled.
        """
        # -1 marks objects that have not been selected by any flavour
        labels = np.full(array.shape, -1, dtype=int)
        for i, label in enumerate(self.labels):
            labels[label.cuts(array).idx] = i

        unlabelled = labels == -1
        if self.require_labels and unlabelled.any():
            raise ValueError("Some objects were not labelled")

        return labels[~unlabelled]

    def add_labels(self, array: np.ndarray, label_name: str = "labels") -> np.ndarray:
        """
        Adds the labels to the given array.

        Parameters
        ----------
        array : np.ndarray
            The array to label.
        label_name : str
            The name of the label column.

        Returns
        -------
        np.ndarray
            The array with the labels added.

        Raises
        ------
        ValueError
            If the `require_labels` attribute is set to `False`.
        """
        # Without require_labels, get_labels may drop entries and the label
        # column would not line up with the rows of ``array``.
        if not self.require_labels:
            raise ValueError("Cannot add labels if require_labels is set to False")
        labels = self.get_labels(array)
        labels = structured_from_dict({label_name: labels})
        return join_structured_arrays([array, labels])
ftag/mock.py CHANGED
@@ -52,7 +52,6 @@ TRACK_VARS = [
52
52
 
53
53
 
54
54
  def softmax(x, axis=None):
55
- """Compute softmax values for each sets of scores in x."""
56
55
  e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
57
56
  return e_x / e_x.sum(axis=axis, keepdims=True)
58
57
 
@@ -85,12 +84,7 @@ def get_mock_scores(labels: np.ndarray, is_xbb: bool = False):
85
84
  return u2s(scores, dtype=np.dtype([(name, "f4") for name in cols]))
86
85
 
87
86
 
88
- def get_mock_file(
89
- num_jets=1000,
90
- fname: str | None = None,
91
- tracks_name: str = "tracks",
92
- num_tracks: int = 40,
93
- ) -> tuple[str, h5py.File]:
87
+ def mock_jets(num_jets=1000) -> np.ndarray:
94
88
  # setup jets
95
89
  rng = np.random.default_rng(42)
96
90
  jets_dtype = np.dtype(JET_VARS)
@@ -107,7 +101,26 @@ def get_mock_file(
107
101
  jets["R10TruthLabel_R22v1"] = rng.choice([1, 10, 11, 12], size=num_jets)
108
102
  scores = get_mock_scores(jets["HadronConeExclTruthLabelID"])
109
103
  xbb_scores = get_mock_scores(jets["R10TruthLabel_R22v1"], is_xbb=True)
110
- jets = join_structured_arrays([jets, scores, xbb_scores])
104
+ return join_structured_arrays([jets, scores, xbb_scores])
105
+
106
+
107
+ def mock_tracks(num_jets=1000, num_tracks=40) -> np.ndarray:
108
+ rng = np.random.default_rng(42)
109
+ tracks_dtype = np.dtype(TRACK_VARS)
110
+ tracks = u2s(rng.random((num_jets, num_tracks, len(TRACK_VARS))), tracks_dtype)
111
+ tracks["d0"] *= 5
112
+ valid = rng.choice([True, False], size=(num_jets, num_tracks))
113
+ valid = valid.astype(bool).view(dtype=np.dtype([("valid", bool)]))
114
+ return join_structured_arrays([tracks, valid])
115
+
116
+
117
+ def get_mock_file(
118
+ num_jets=1000,
119
+ fname: str | None = None,
120
+ tracks_name: str = "tracks",
121
+ num_tracks: int = 40,
122
+ ) -> tuple[str, h5py.File]:
123
+ jets = mock_jets(num_jets)
111
124
 
112
125
  # create a tempfile in a new folder
113
126
  if fname is None:
@@ -121,11 +134,7 @@ def get_mock_file(
121
134
 
122
135
  # setup tracks
123
136
  if tracks_name:
124
- tracks_dtype = np.dtype(TRACK_VARS)
125
- tracks = u2s(rng.random((num_jets, num_tracks, len(TRACK_VARS))), tracks_dtype)
126
- valid = rng.choice([True, False], size=(num_jets, num_tracks))
127
- valid = valid.astype(bool).view(dtype=np.dtype([("valid", bool)]))
128
- tracks = join_structured_arrays([tracks, valid])
137
+ tracks = mock_tracks(num_jets, num_tracks)
129
138
  f.create_dataset(tracks_name, data=tracks)
130
139
 
131
140
  return fname, f
ftag/track_selector.py ADDED
@@ -0,0 +1,53 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+
5
+ import numpy as np
6
+
7
+ from ftag import Cuts
8
+
9
+
10
@dataclass
class TrackSelector:
    """
    Apply track selections to a set of tracks stored in a structured numpy array.

    The array is assumed to have shape (n_jets, n_tracks), with the track
    features stored as structured fields. Tracks that are currently valid but
    fail any of the cuts are masked out: floating-point fields are set to NaN,
    integer fields to -1 and boolean fields (including the valid flag) to False.
    The shape of the array is left unchanged. The input array is modified in
    place and is also returned.

    Parameters
    ----------
    cuts : Cuts
        The cuts to apply to the tracks.
    valid_str : str
        The name of the boolean field in the tracks that indicates whether the
        track is valid.
    """

    cuts: Cuts
    valid_str: str = "valid"

    def __call__(self, tracks: np.ndarray) -> np.ndarray:
        valid = tracks[self.valid_str]

        # Mask of tracks to remove: valid tracks failing at least one cut.
        rm_idx = np.zeros_like(valid, dtype=bool)
        for cut in self.cuts.cuts:
            rm_idx[valid & ~cut(tracks)] = True

        # Overwrite every field of the removed tracks with a dtype-appropriate
        # sentinel value.
        for var in tracks.dtype.names:
            kind = tracks[var].dtype.type
            if issubclass(kind, np.floating):
                tracks[var][rm_idx] = np.nan
            elif issubclass(kind, np.integer):
                # NOTE(review): -1 is not representable in unsigned integer
                # fields — confirm such fields never reach this selector.
                tracks[var][rm_idx] = -1
            elif issubclass(kind, np.bool_):
                tracks[var][rm_idx] = False
            else:
                raise TypeError(f"Unknown dtype {tracks[var].dtype}")

        # Explicitly clear the valid flag. Index the field view first:
        # ``tracks[rm_idx][...] = ...`` would only write to a temporary copy.
        tracks[self.valid_str][rm_idx] = False

        return tracks
ftag/wps/discriminant.py CHANGED
@@ -100,6 +100,11 @@ def get_discriminant(
100
100
  -------
101
101
  np.ndarray
102
102
  Array of discriminant values.
103
+
104
+ Raises
105
+ ------
106
+ ValueError
107
+ If the signal flavour is not recognised.
103
108
  """
104
109
  tagger_funcs: dict[str, Callable] = {
105
110
  "bjets": btag_discriminant,
@@ -1,26 +0,0 @@
1
- ftag/__init__.py,sha256=aCC6idmHdlETdgAFN4PeESN67WpeVqo85Nledw9VkH4,629
2
- ftag/cli_utils.py,sha256=w3TtQmUHSyAKChS3ewvOtcSDAUJAZGIIomaNi8f446U,298
3
- ftag/cuts.py,sha256=a0BJj4cVRunc-hFLPloGvNoSFvRmZg2kVLv7sA0iAaI,2817
4
- ftag/flavour.py,sha256=qvgp4DarOdcQgjae_NWnd81k_YqdmFY74lOKky2lpb8,3568
5
- ftag/flavours.yaml,sha256=h0A2cw-je1oCe-rh5qhuL1BhDNhu2aLkf7W9y-Cpy3g,6959
6
- ftag/git_check.py,sha256=Y-XqM80CVXZ5ZKrDdZcYOJt3X64uU6W3OP6Z0D7AZU0,1663
7
- ftag/mock.py,sha256=9V6sAT4_t-rhR67q9KHaj1NKAeqU7lQjWxiOxEzk8Sw,4338
8
- ftag/region.py,sha256=ANv0dGI2W6NJqD9fp7EfqAUReH4FOjc1gwl_Qn8llcM,360
9
- ftag/sample.py,sha256=TFXMhDkbPmjkms9-b-bINJ32T3bO86JcU70C0nY7wa8,2500
10
- ftag/test_cli_utils.py,sha256=xa08vf6SEOow58SSFagYdAselb-dkNOVvWsWheMnW-g,1001
11
- ftag/transform.py,sha256=uEGGJSnqoKOzLYQv650XdK_kDNw4Aw-5dc60z9Dp_y0,3963
12
- ftag/vds.py,sha256=nRViQZQIORB95nC7NZsW3KsSoGkLzEdOsuCViH5h8-U,3296
13
- ftag/hdf5/__init__.py,sha256=LFDNxVOCp58SvLHwQhdT68Q-KBMS_i6jBrbXoRpHzbM,354
14
- ftag/hdf5/h5move.py,sha256=oYpRu0IDCIJIQ2ML52HBAdoyDxmKkHTeM9JdbPEgKfI,947
15
- ftag/hdf5/h5reader.py,sha256=et-_LXt942xegqc14bPapUgIO7MUfC2m04uJslLkXxI,13579
16
- ftag/hdf5/h5split.py,sha256=4Wy6Xc3J58MdD9aBaSZHf5ZcVFnJSkWsm42R5Pgo-R4,2448
17
- ftag/hdf5/h5utils.py,sha256=-4zKTMtNCrDZr_9Ww7uzfsB7M7muBKpmm_1IkKJnHOI,3222
18
- ftag/hdf5/h5writer.py,sha256=wVyurIgfSBtvZTX-v0v3R5-8JOwWK_yF1rUX-RewXzY,4826
19
- ftag/wps/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
20
- ftag/wps/discriminant.py,sha256=c2-bal124yY1rZjK87iQHb9RRo4pus6SVue6SfMKrRY,3749
21
- ftag/wps/working_points.py,sha256=VTU6OD40ULAJQD0MlD1EZd33q8ociUvFX1YrhgJFvXc,9722
22
- atlas_ftag_tools-0.2.1.dist-info/METADATA,sha256=KpKapru-RWOoq01J5sm04jE2brwpoPxDGVPTZoppl8c,5168
23
- atlas_ftag_tools-0.2.1.dist-info/WHEEL,sha256=Z4pYXqR_rTB7OWNDYFOm1qRk0RX6GFP2o8LgvP453Hk,91
24
- atlas_ftag_tools-0.2.1.dist-info/entry_points.txt,sha256=LfVLsZHQolqbPnwPgtmc5IQTh527BKkN2v-IpXWTNHw,137
25
- atlas_ftag_tools-0.2.1.dist-info/top_level.txt,sha256=qiYQuKcAvMim-31FwkT3MTQu7WQm0s58tPAia5KKWqs,5
26
- atlas_ftag_tools-0.2.1.dist-info/RECORD,,