atlas-ftag-tools 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: atlas-ftag-tools
3
- Version: 0.1.3
3
+ Version: 0.1.5
4
4
  Summary: ATLAS Flavour Tagging Tools
5
5
  Author: Sam Van Stroud, Philipp Gadow
6
6
  License: MIT
@@ -39,7 +39,7 @@ If you want to use this package without modification, you can install from [pypi
39
39
  pip install atlas-ftag-tools
40
40
  ```
41
41
 
42
- To additionally install the development dependencies (for formatting and linting) rn
42
+ To additionally install the development dependencies (for formatting and linting) use
43
43
  ```bash
44
44
  pip install atlas-ftag-tools[dev]
45
45
  ```
@@ -58,10 +58,11 @@ Include development dependencies with
58
58
  python -m pip install -e ".[dev]"
59
59
  ```
60
60
 
61
- You can set up pre-commit hooks with
61
+ You can set up and run pre-commit hooks with
62
62
 
63
63
  ```bash
64
64
  pre-commit install
65
+ pre-commit run --all-files
65
66
  ```
66
67
 
67
68
  To run the tests you can use the `pytest` or `coverage` command, for example
@@ -75,6 +76,9 @@ Running `coverage report` will display the test coverage.
75
76
 
76
77
  # Usage
77
78
 
79
+ Please see the [example notebook](ftag/example.ipynb) for full usage.
80
+ Additional functionality is also documented below.
81
+
78
82
  ## Create virtual file
79
83
 
80
84
  This package contains a script to easily merge a set of H5 files.
@@ -0,0 +1,19 @@
1
+ ftag/__init__.py,sha256=XBQEZpFSnGyihB9F3eGOvB_5YknggY_L6fzwYszXLuQ,543
2
+ ftag/cuts.py,sha256=lCnyHd4kbrt3CMXGE1ASCgaa07o1qOBn6GQek6lClVQ,2734
3
+ ftag/flavour.py,sha256=sEelvHNLWmHsecQQrmRc8ktwykMMHnGX8ePDRrqQkuo,2460
4
+ ftag/flavours.yaml,sha256=VrOGD5FUhMVPIW31whY-nSqNv98AcnLsPmPGmAcCg3w,3287
5
+ ftag/mock.py,sha256=HUyYOPsRtkmzjLRNF2zs0kpVUrTRIHTsnIyDlXIZArU,3627
6
+ ftag/region.py,sha256=-WxdC0Gy9zz3zEJ2pN779RcxXPG-QEROuMwMoP-Qs0g,353
7
+ ftag/sample.py,sha256=uVNyxFYMMtkP-o2tjQatpo8mIH4ZNNe3mSFEPebYh_E,2622
8
+ ftag/vds.py,sha256=8b5-zqDELUmxdO5Txdowe3v7XGS1pKgO20bhzUQqCxU,2945
9
+ ftag/hdf5/__init__.py,sha256=A_a_4IUlZ2mSiDcfrZKBdja_3iTrUHvADM2lWx6g66g,325
10
+ ftag/hdf5/h5reader.py,sha256=1_iyYfWI1ht1-p9vBBpGhw47ZKola_KhWxbrywoB-Jg,11751
11
+ ftag/hdf5/h5utils.py,sha256=GKduv9b6JRSBirRdmNgGcmsINCMTj54kH4RQqxrM1t8,2363
12
+ ftag/hdf5/h5writer.py,sha256=_N-DJSX283r-XsGczvLFA4_qaK4BkFkdKZAusHEvRjU,2919
13
+ ftag/wps/discriminant.py,sha256=86ISONTuIjqTJO1A27oqkoCgDjAQinofiYNdcjfdkIk,1380
14
+ ftag/wps/working_points.py,sha256=487NsQGGY2Qt4q8mXxKABMFa-YLsbrhkPLcYVdebeVk,4950
15
+ atlas_ftag_tools-0.1.5.dist-info/METADATA,sha256=Uc4Z2zAMD7jsSKoV6o2LJwfm2X0KEWYRingGT_msE4I,4182
16
+ atlas_ftag_tools-0.1.5.dist-info/WHEEL,sha256=pkctZYzUS4AYVn6dJ-7367OJZivF2e8RA9b_ZBjif18,92
17
+ atlas_ftag_tools-0.1.5.dist-info/entry_points.txt,sha256=UKbRbwA9DxfsTPRBIVVDz3u15WdzhzgRKwXXSAXuQqc,73
18
+ atlas_ftag_tools-0.1.5.dist-info/top_level.txt,sha256=qiYQuKcAvMim-31FwkT3MTQu7WQm0s58tPAia5KKWqs,5
19
+ atlas_ftag_tools-0.1.5.dist-info/RECORD,,
ftag/__init__.py CHANGED
@@ -1,7 +1,7 @@
1
1
  """atlas-ftag-tools - Common tools for ATLAS flavour tagging software."""
2
2
 
3
3
 
4
- __version__ = "v0.1.3"
4
+ __version__ = "v0.1.5"
5
5
 
6
6
 
7
7
  import ftag.hdf5 as hdf5
ftag/cuts.py CHANGED
@@ -20,7 +20,7 @@ OPERATORS = {
20
20
  "notin": lambda x, y: ~np.isin(x, y),
21
21
  }
22
22
 
23
- for i in range(2, 20):
23
+ for i in range(2, 101):
24
24
  OPERATORS[f"%{i}=="] = functools.partial(lambda x, y, i: (x % i) == y, i=i)
25
25
  OPERATORS[f"%{i}<="] = functools.partial(lambda x, y, i: (x % i) <= y, i=i)
26
26
  OPERATORS[f"%{i}>="] = functools.partial(lambda x, y, i: (x % i) >= y, i=i)
ftag/flavours.yaml CHANGED
@@ -49,12 +49,12 @@
49
49
 
50
50
  # Xbb tagging
51
51
  - name: hbb
52
- label: Hbb
52
+ label: $H \rightarrow b\bar{b}$
53
53
  cuts: ["R10TruthLabel_R22v1 == 11"]
54
54
  colour: tab:blue
55
55
  category: xbb
56
56
  - name: hcc
57
- label: Hcc
57
+ label: $H \rightarrow c\bar{c}$
58
58
  cuts: ["R10TruthLabel_R22v1 == 12"]
59
59
  colour: "#B45F06"
60
60
  category: xbb
@@ -63,6 +63,11 @@
63
63
  cuts: ["R10TruthLabel_R22v1 == 1"]
64
64
  colour: "#A300A3"
65
65
  category: xbb
66
+ - name: inclusive_top
67
+ label: Inclusive Top
68
+ cuts: ["R10TruthLabel_R22v1 in (1,6,7)"]
69
+ colour: "#A300A3"
70
+ category: xbb
66
71
  - name: qcd
67
72
  label: QCD
68
73
  cuts: ["R10TruthLabel_R22v1 == 10"]
ftag/hdf5/h5reader.py CHANGED
@@ -26,9 +26,10 @@ class H5SingleReader:
26
26
 
27
27
  def __post_init__(self) -> None:
28
28
  self.sample = Sample(self.fname)
29
- if len(self.sample.virtual_file()) != 1:
29
+ fname = self.sample.virtual_file()
30
+ if len(fname) != 1:
30
31
  raise ValueError("H5SingleReader should only read a single file")
31
- self.fname = self.sample.virtual_file()[0]
32
+ self.fname = fname[0]
32
33
 
33
34
  @cached_property
34
35
  def num_jets(self) -> int:
@@ -57,21 +58,27 @@ class H5SingleReader:
57
58
  isinf = np.isinf(array[var])
58
59
  keep_idx = keep_idx & ~isinf.any(axis=-1)
59
60
  if num_inf := isinf.sum():
60
- log.warn(
61
+ log.warning(
61
62
  f"{num_inf} inf values detected for variable {var} in"
62
63
  f" {name} array. Removing the affected jets."
63
64
  )
64
65
  return {name: array[keep_idx] for name, array in data.items()}
65
66
 
66
67
  def stream(
67
- self, variables: dict | None = None, num_jets: int | None = None, cuts: Cuts | None = None
68
+ self,
69
+ variables: dict | None = None,
70
+ num_jets: int | None = None,
71
+ cuts: Cuts | None = None,
68
72
  ) -> Generator:
69
73
  if num_jets is None:
70
74
  num_jets = self.num_jets
75
+
71
76
  if num_jets > self.num_jets:
72
- raise ValueError(
73
- f"{num_jets:,} jets requested but only {self.num_jets:,} available in {self.fname}"
77
+ log.warning(
78
+ f"{num_jets:,} jets requested but only {self.num_jets:,} available in {self.fname}."
79
+ " Set to maximum available number!"
74
80
  )
81
+ num_jets = self.num_jets
75
82
 
76
83
  if variables is None:
77
84
  variables = {self.jets_name: None}
@@ -131,6 +138,9 @@ class H5Reader:
131
138
  Weights for different input datasets, by default None
132
139
  do_remove_inf : bool, optional
133
140
  Remove jets with inf values, by default False
141
+ equal_jets : bool, optional
142
+ Take the same number of jets (weighted) from each sample, by default True.
143
+ If False, use all jets in each sample.
134
144
  """
135
145
 
136
146
  fname: Path | str | list[Path | str]
@@ -140,8 +150,16 @@ class H5Reader:
140
150
  shuffle: bool = True
141
151
  weights: list[float] | None = None
142
152
  do_remove_inf: bool = False
153
+ equal_jets: bool = True
143
154
 
144
155
  def __post_init__(self) -> None:
156
+ if not self.equal_jets:
157
+ log.warning(
158
+ "equal_jets is set to False, which will result in different number of jets taken"
159
+ " from each sample. Be aware that this can affect the resampling, so make sure you"
160
+ " know what you are doing."
161
+ )
162
+
145
163
  if isinstance(self.fname, (str, Path)):
146
164
  self.fname = [self.fname]
147
165
 
@@ -191,10 +209,14 @@ class H5Reader:
191
209
  Generator
192
210
  Generator of batches of selected jets.
193
211
  """
212
+ # Check if number of jets is given, if not, set to maximum available
194
213
  if num_jets is None:
195
214
  num_jets = self.num_jets
215
+
216
+ # Check if variables are given, if not, set to all
196
217
  if variables is None:
197
218
  variables = {self.jets_name: None}
219
+
198
220
  if self.jets_name not in variables or variables[self.jets_name] is not None:
199
221
  jet_vars = variables.get(self.jets_name, [])
200
222
  variables[self.jets_name] = list(jet_vars) + (cuts.variables if cuts else [])
@@ -207,12 +229,25 @@ class H5Reader:
207
229
 
208
230
  rng = np.random.default_rng(42)
209
231
  while True:
210
- # yeild from each stream
211
232
  samples = []
212
- for stream in streams:
213
- try:
214
- samples.append(next(stream))
215
- except StopIteration:
233
+ # Track which streams have been exhausted
234
+ streams_done = [False] * len(streams)
235
+
236
+ # for each unexhausted stream, get the next sample
237
+ for i, stream in enumerate(streams):
238
+ if not streams_done[i]:
239
+ try:
240
+ samples.append(next(stream))
241
+
242
+ # if equal_jets is True, we can stop when any stream is done
243
+ # otherwise if sample is exhausted, mark it as done
244
+ except StopIteration:
245
+ if self.equal_jets:
246
+ return
247
+ streams_done[i] = True
248
+
249
+ # if equal_jets is False, we need to keep going until all streams are done
250
+ if all(streams_done):
216
251
  return
217
252
 
218
253
  # combine samples and shuffle
@@ -222,23 +257,72 @@ class H5Reader:
222
257
  rng.shuffle(idx)
223
258
  data = {name: array[idx] for name, array in data.items()}
224
259
 
225
- # select
260
+ # yield batch
226
261
  yield data
227
262
 
228
263
  def load(
229
264
  self, variables: dict | None = None, num_jets: int | None = None, cuts: Cuts | None = None
230
265
  ) -> dict:
266
+ """Load multiple batches of selected jets into memory.
267
+
268
+ Parameters
269
+ ----------
270
+ variables : dict | None, optional
271
+ Dictionary of variables to load for each group, by default use all jet variables.
272
+ num_jets : int | None, optional
273
+ Total number of selected jets to load, by default all.
274
+ cuts : Cuts | None, optional
275
+ Selection cuts to apply, by default None
276
+
277
+ Returns
278
+ -------
279
+ dict
280
+ Dictionary of arrays for each group.
281
+ """
282
+ # handle default arguments
283
+ if num_jets == -1:
284
+ num_jets = self.num_jets
231
285
  if variables is None:
232
286
  variables = {self.jets_name: None}
287
+
288
+ # get data from each sample
233
289
  data: dict[str, list] = {name: [] for name in variables}
234
- for sample in self.stream(variables, num_jets, cuts):
235
- for name, array in sample.items():
290
+ for batch in self.stream(variables, num_jets, cuts):
291
+ for name, array in batch.items():
236
292
  if name in data:
237
293
  data[name].append(array)
294
+
295
+ # concatenate batches
238
296
  return {name: np.concatenate(array) for name, array in data.items()}
239
297
 
240
298
  def estimate_available_jets(self, cuts: Cuts, num: int = 1_000_000) -> int:
241
- """Estimate the number of jets available after selection cuts, rounded down."""
242
- all_jets = self.load({self.jets_name: cuts.variables}, num)[self.jets_name]
243
- estimated_num_jets = len(cuts(all_jets).values) / len(all_jets) * self.num_jets
299
+ """Estimate the number of jets available after selection cuts (round down).
300
+
301
+ Parameters
302
+ ----------
303
+ cuts : Cuts
304
+ Selection cuts to apply.
305
+ num : int, optional
306
+ Number of jets to use for the estimation, by default 1_000_000.
307
+
308
+ Returns
309
+ -------
310
+ int
311
+ Estimated number of jets available after selection cuts,
312
+ rounded down to nearest thousand.
313
+ """
314
+ # if equal jets is True, available jets is based on the smallest sample
315
+ if self.equal_jets:
316
+ num_jets = []
317
+ for r in self.readers:
318
+ stream = r.stream({self.jets_name: cuts.variables}, num)
319
+ all_jets = np.concatenate([batch[self.jets_name] for batch in stream])
320
+ frac_selected = len(cuts(all_jets).values) / len(all_jets)
321
+ num_jets.append(frac_selected * r.num_jets)
322
+ estimated_num_jets = min(num_jets) * len(self.readers)
323
+ # otherwise, available jets is based on all samples
324
+ else:
325
+ all_jets = self.load({self.jets_name: cuts.variables}, num)[self.jets_name]
326
+ frac_selected = len(cuts(all_jets).values) / len(all_jets)
327
+ estimated_num_jets = frac_selected * self.num_jets
244
328
  return math.floor(estimated_num_jets / 1_000) * 1_000
ftag/mock.py CHANGED
@@ -92,6 +92,7 @@ def get_mock_file(num_jets=1000, tracks_name: str = "tracks", num_tracks: int =
92
92
  fname = NamedTemporaryFile(suffix=".h5", dir=mkdtemp()).name
93
93
  f = h5py.File(fname, "w")
94
94
  f.create_dataset("jets", data=jets)
95
+ f.attrs["test"] = "test"
95
96
 
96
97
  # setup tracks
97
98
  if tracks_name:
ftag/vds.py CHANGED
@@ -1,11 +1,22 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import argparse
3
4
  import glob
4
5
  from pathlib import Path
5
6
 
6
7
  import h5py
7
8
 
8
9
 
10
+ def parse_args(args):
11
+ parser = argparse.ArgumentParser(
12
+ description="Create a lightweight wrapper around a set of h5 files"
13
+ )
14
+ parser.add_argument("pattern", type=Path, help="quotes-enclosed glob pattern of files to merge")
15
+ parser.add_argument("output", type=Path, help="path to output virtual file")
16
+ args = parser.parse_args(args)
17
+ return args
18
+
19
+
9
20
  def get_virtual_layout(fnames: list[str], group: str):
10
21
  # get sources
11
22
  sources = []
@@ -58,20 +69,22 @@ def create_virtual_file(
58
69
  for group in h5py.File(fnames[0]):
59
70
  layout = get_virtual_layout(fnames, group)
60
71
  f.create_virtual_dataset(group, layout)
72
+ attrs_dict: dict = {}
73
+ for fname in fnames:
74
+ with h5py.File(fname) as g:
75
+ for name, value in g[group].attrs.items():
76
+ if name not in attrs_dict:
77
+ attrs_dict[name] = []
78
+ attrs_dict[name].append(value)
79
+ for name, value in attrs_dict.items():
80
+ if len(value) > 0:
81
+ f[group].attrs[name] = value[0]
61
82
 
62
83
  return out_fname
63
84
 
64
85
 
65
- def main():
66
- import argparse
67
-
68
- parser = argparse.ArgumentParser(
69
- description="Create a lightweight wrapper around a set of h5 files"
70
- )
71
- parser.add_argument("pattern", type=Path, help="quotes-enclosed glob pattern of files to merge")
72
- parser.add_argument("output", type=Path, help="path to output virtual file")
73
- args = parser.parse_args()
74
-
86
+ def main(args=None):
87
+ args = parse_args(args)
75
88
  print(f"Globbing {args.pattern}...")
76
89
  create_virtual_file(args.pattern, args.output, overwrite=True)
77
90
  with h5py.File(args.output) as f:
@@ -1,19 +0,0 @@
1
- ftag/__init__.py,sha256=3VQyLgnMa0A0325TNda80-4qGbPPnQmkrQZq1-klRcA,543
2
- ftag/cuts.py,sha256=Ge4WXLPg3WNgGxg-g7oIgCbbNFcKZonvkyskU0fDuDg,2733
3
- ftag/flavour.py,sha256=sEelvHNLWmHsecQQrmRc8ktwykMMHnGX8ePDRrqQkuo,2460
4
- ftag/flavours.yaml,sha256=S4WoB_n2uqvjo8_mlvNA1wKUwz9aFLhpyXtWsR8uR80,3121
5
- ftag/mock.py,sha256=Y__r5zToQLqrBg7T1a5RF_ten_gwBHIqgQOtj2DhIhU,3598
6
- ftag/region.py,sha256=-WxdC0Gy9zz3zEJ2pN779RcxXPG-QEROuMwMoP-Qs0g,353
7
- ftag/sample.py,sha256=uVNyxFYMMtkP-o2tjQatpo8mIH4ZNNe3mSFEPebYh_E,2622
8
- ftag/vds.py,sha256=FmpP31YiSKBvh6TRIMWr-_aJHAkQs0Trhmqh2KLfT64,2402
9
- ftag/hdf5/__init__.py,sha256=A_a_4IUlZ2mSiDcfrZKBdja_3iTrUHvADM2lWx6g66g,325
10
- ftag/hdf5/h5reader.py,sha256=ayKX3xUiyV42avsCZQhcTYuNLPgJ3NQCS1qUjSggcKQ,8659
11
- ftag/hdf5/h5utils.py,sha256=GKduv9b6JRSBirRdmNgGcmsINCMTj54kH4RQqxrM1t8,2363
12
- ftag/hdf5/h5writer.py,sha256=_N-DJSX283r-XsGczvLFA4_qaK4BkFkdKZAusHEvRjU,2919
13
- ftag/wps/discriminant.py,sha256=86ISONTuIjqTJO1A27oqkoCgDjAQinofiYNdcjfdkIk,1380
14
- ftag/wps/working_points.py,sha256=487NsQGGY2Qt4q8mXxKABMFa-YLsbrhkPLcYVdebeVk,4950
15
- atlas_ftag_tools-0.1.3.dist-info/METADATA,sha256=stGxR0B4fZIJyJFpOO3vtJR9ytUrMeDHIP6qafWPDzI,4023
16
- atlas_ftag_tools-0.1.3.dist-info/WHEEL,sha256=pkctZYzUS4AYVn6dJ-7367OJZivF2e8RA9b_ZBjif18,92
17
- atlas_ftag_tools-0.1.3.dist-info/entry_points.txt,sha256=UKbRbwA9DxfsTPRBIVVDz3u15WdzhzgRKwXXSAXuQqc,73
18
- atlas_ftag_tools-0.1.3.dist-info/top_level.txt,sha256=qiYQuKcAvMim-31FwkT3MTQu7WQm0s58tPAia5KKWqs,5
19
- atlas_ftag_tools-0.1.3.dist-info/RECORD,,