atlas-ftag-tools 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {atlas_ftag_tools-0.1.3.dist-info → atlas_ftag_tools-0.1.5.dist-info}/METADATA +7 -3
- atlas_ftag_tools-0.1.5.dist-info/RECORD +19 -0
- ftag/__init__.py +1 -1
- ftag/cuts.py +1 -1
- ftag/flavours.yaml +7 -2
- ftag/hdf5/h5reader.py +101 -17
- ftag/mock.py +1 -0
- ftag/vds.py +23 -10
- atlas_ftag_tools-0.1.3.dist-info/RECORD +0 -19
- {atlas_ftag_tools-0.1.3.dist-info → atlas_ftag_tools-0.1.5.dist-info}/WHEEL +0 -0
- {atlas_ftag_tools-0.1.3.dist-info → atlas_ftag_tools-0.1.5.dist-info}/entry_points.txt +0 -0
- {atlas_ftag_tools-0.1.3.dist-info → atlas_ftag_tools-0.1.5.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: atlas-ftag-tools
|
3
|
-
Version: 0.1.
|
3
|
+
Version: 0.1.5
|
4
4
|
Summary: ATLAS Flavour Tagging Tools
|
5
5
|
Author: Sam Van Stroud, Philipp Gadow
|
6
6
|
License: MIT
|
@@ -39,7 +39,7 @@ If you want to use this package without modification, you can install from [pypi
|
|
39
39
|
pip install atlas-ftag-tools
|
40
40
|
```
|
41
41
|
|
42
|
-
To additionally install the development dependencies (for formatting and linting)
|
42
|
+
To additionally install the development dependencies (for formatting and linting) use
|
43
43
|
```bash
|
44
44
|
pip install atlas-ftag-tools[dev]
|
45
45
|
```
|
@@ -58,10 +58,11 @@ Include development dependencies with
|
|
58
58
|
python -m pip install -e ".[dev]"
|
59
59
|
```
|
60
60
|
|
61
|
-
You can set up pre-commit hooks with
|
61
|
+
You can set up and run pre-commit hooks with
|
62
62
|
|
63
63
|
```bash
|
64
64
|
pre-commit install
|
65
|
+
pre-commit run --all-files
|
65
66
|
```
|
66
67
|
|
67
68
|
To run the tests you can use the `pytest` or `coverage` command, for example
|
@@ -75,6 +76,9 @@ Running `coverage report` will display the test coverage.
|
|
75
76
|
|
76
77
|
# Usage
|
77
78
|
|
79
|
+
Please see the [example notebook](ftag/example.ipynb) for full usage.
|
80
|
+
Additional functionality is also documented below.
|
81
|
+
|
78
82
|
## Create virtual file
|
79
83
|
|
80
84
|
This package contains a script to easily merge a set of H5 files.
|
@@ -0,0 +1,19 @@
|
|
1
|
+
ftag/__init__.py,sha256=XBQEZpFSnGyihB9F3eGOvB_5YknggY_L6fzwYszXLuQ,543
|
2
|
+
ftag/cuts.py,sha256=lCnyHd4kbrt3CMXGE1ASCgaa07o1qOBn6GQek6lClVQ,2734
|
3
|
+
ftag/flavour.py,sha256=sEelvHNLWmHsecQQrmRc8ktwykMMHnGX8ePDRrqQkuo,2460
|
4
|
+
ftag/flavours.yaml,sha256=VrOGD5FUhMVPIW31whY-nSqNv98AcnLsPmPGmAcCg3w,3287
|
5
|
+
ftag/mock.py,sha256=HUyYOPsRtkmzjLRNF2zs0kpVUrTRIHTsnIyDlXIZArU,3627
|
6
|
+
ftag/region.py,sha256=-WxdC0Gy9zz3zEJ2pN779RcxXPG-QEROuMwMoP-Qs0g,353
|
7
|
+
ftag/sample.py,sha256=uVNyxFYMMtkP-o2tjQatpo8mIH4ZNNe3mSFEPebYh_E,2622
|
8
|
+
ftag/vds.py,sha256=8b5-zqDELUmxdO5Txdowe3v7XGS1pKgO20bhzUQqCxU,2945
|
9
|
+
ftag/hdf5/__init__.py,sha256=A_a_4IUlZ2mSiDcfrZKBdja_3iTrUHvADM2lWx6g66g,325
|
10
|
+
ftag/hdf5/h5reader.py,sha256=1_iyYfWI1ht1-p9vBBpGhw47ZKola_KhWxbrywoB-Jg,11751
|
11
|
+
ftag/hdf5/h5utils.py,sha256=GKduv9b6JRSBirRdmNgGcmsINCMTj54kH4RQqxrM1t8,2363
|
12
|
+
ftag/hdf5/h5writer.py,sha256=_N-DJSX283r-XsGczvLFA4_qaK4BkFkdKZAusHEvRjU,2919
|
13
|
+
ftag/wps/discriminant.py,sha256=86ISONTuIjqTJO1A27oqkoCgDjAQinofiYNdcjfdkIk,1380
|
14
|
+
ftag/wps/working_points.py,sha256=487NsQGGY2Qt4q8mXxKABMFa-YLsbrhkPLcYVdebeVk,4950
|
15
|
+
atlas_ftag_tools-0.1.5.dist-info/METADATA,sha256=Uc4Z2zAMD7jsSKoV6o2LJwfm2X0KEWYRingGT_msE4I,4182
|
16
|
+
atlas_ftag_tools-0.1.5.dist-info/WHEEL,sha256=pkctZYzUS4AYVn6dJ-7367OJZivF2e8RA9b_ZBjif18,92
|
17
|
+
atlas_ftag_tools-0.1.5.dist-info/entry_points.txt,sha256=UKbRbwA9DxfsTPRBIVVDz3u15WdzhzgRKwXXSAXuQqc,73
|
18
|
+
atlas_ftag_tools-0.1.5.dist-info/top_level.txt,sha256=qiYQuKcAvMim-31FwkT3MTQu7WQm0s58tPAia5KKWqs,5
|
19
|
+
atlas_ftag_tools-0.1.5.dist-info/RECORD,,
|
ftag/__init__.py
CHANGED
ftag/cuts.py
CHANGED
@@ -20,7 +20,7 @@ OPERATORS = {
|
|
20
20
|
"notin": lambda x, y: ~np.isin(x, y),
|
21
21
|
}
|
22
22
|
|
23
|
-
for i in range(2,
|
23
|
+
for i in range(2, 101):
|
24
24
|
OPERATORS[f"%{i}=="] = functools.partial(lambda x, y, i: (x % i) == y, i=i)
|
25
25
|
OPERATORS[f"%{i}<="] = functools.partial(lambda x, y, i: (x % i) <= y, i=i)
|
26
26
|
OPERATORS[f"%{i}>="] = functools.partial(lambda x, y, i: (x % i) >= y, i=i)
|
ftag/flavours.yaml
CHANGED
@@ -49,12 +49,12 @@
|
|
49
49
|
|
50
50
|
# Xbb tagging
|
51
51
|
- name: hbb
|
52
|
-
label:
|
52
|
+
label: $H \rightarrow b\bar{b}$
|
53
53
|
cuts: ["R10TruthLabel_R22v1 == 11"]
|
54
54
|
colour: tab:blue
|
55
55
|
category: xbb
|
56
56
|
- name: hcc
|
57
|
-
label:
|
57
|
+
label: $H \rightarrow c\bar{c}$
|
58
58
|
cuts: ["R10TruthLabel_R22v1 == 12"]
|
59
59
|
colour: "#B45F06"
|
60
60
|
category: xbb
|
@@ -63,6 +63,11 @@
|
|
63
63
|
cuts: ["R10TruthLabel_R22v1 == 1"]
|
64
64
|
colour: "#A300A3"
|
65
65
|
category: xbb
|
66
|
+
- name: inclusive_top
|
67
|
+
label: Inclusive Top
|
68
|
+
cuts: ["R10TruthLabel_R22v1 in (1,6,7)"]
|
69
|
+
colour: "#A300A3"
|
70
|
+
category: xbb
|
66
71
|
- name: qcd
|
67
72
|
label: QCD
|
68
73
|
cuts: ["R10TruthLabel_R22v1 == 10"]
|
ftag/hdf5/h5reader.py
CHANGED
@@ -26,9 +26,10 @@ class H5SingleReader:
|
|
26
26
|
|
27
27
|
def __post_init__(self) -> None:
|
28
28
|
self.sample = Sample(self.fname)
|
29
|
-
|
29
|
+
fname = self.sample.virtual_file()
|
30
|
+
if len(fname) != 1:
|
30
31
|
raise ValueError("H5SingleReader should only read a single file")
|
31
|
-
self.fname =
|
32
|
+
self.fname = fname[0]
|
32
33
|
|
33
34
|
@cached_property
|
34
35
|
def num_jets(self) -> int:
|
@@ -57,21 +58,27 @@ class H5SingleReader:
|
|
57
58
|
isinf = np.isinf(array[var])
|
58
59
|
keep_idx = keep_idx & ~isinf.any(axis=-1)
|
59
60
|
if num_inf := isinf.sum():
|
60
|
-
log.
|
61
|
+
log.warning(
|
61
62
|
f"{num_inf} inf values detected for variable {var} in"
|
62
63
|
f" {name} array. Removing the affected jets."
|
63
64
|
)
|
64
65
|
return {name: array[keep_idx] for name, array in data.items()}
|
65
66
|
|
66
67
|
def stream(
|
67
|
-
self,
|
68
|
+
self,
|
69
|
+
variables: dict | None = None,
|
70
|
+
num_jets: int | None = None,
|
71
|
+
cuts: Cuts | None = None,
|
68
72
|
) -> Generator:
|
69
73
|
if num_jets is None:
|
70
74
|
num_jets = self.num_jets
|
75
|
+
|
71
76
|
if num_jets > self.num_jets:
|
72
|
-
|
73
|
-
f"{num_jets:,} jets requested but only {self.num_jets:,} available in {self.fname}"
|
77
|
+
log.warning(
|
78
|
+
f"{num_jets:,} jets requested but only {self.num_jets:,} available in {self.fname}."
|
79
|
+
" Set to maximum available number!"
|
74
80
|
)
|
81
|
+
num_jets = self.num_jets
|
75
82
|
|
76
83
|
if variables is None:
|
77
84
|
variables = {self.jets_name: None}
|
@@ -131,6 +138,9 @@ class H5Reader:
|
|
131
138
|
Weights for different input datasets, by default None
|
132
139
|
do_remove_inf : bool, optional
|
133
140
|
Remove jets with inf values, by default False
|
141
|
+
equal_jets : bool, optional
|
142
|
+
Take the same number of jets (weighted) from each sample, by default True
|
143
|
+
If False, use all jets in each sample.
|
134
144
|
"""
|
135
145
|
|
136
146
|
fname: Path | str | list[Path | str]
|
@@ -140,8 +150,16 @@ class H5Reader:
|
|
140
150
|
shuffle: bool = True
|
141
151
|
weights: list[float] | None = None
|
142
152
|
do_remove_inf: bool = False
|
153
|
+
equal_jets: bool = True
|
143
154
|
|
144
155
|
def __post_init__(self) -> None:
|
156
|
+
if not self.equal_jets:
|
157
|
+
log.warning(
|
158
|
+
"equal_jets is set to False, which will result in different number of jets taken"
|
159
|
+
" from each sample. Be aware that this can affect the resampling, so make sure you"
|
160
|
+
" know what you are doing."
|
161
|
+
)
|
162
|
+
|
145
163
|
if isinstance(self.fname, (str, Path)):
|
146
164
|
self.fname = [self.fname]
|
147
165
|
|
@@ -191,10 +209,14 @@ class H5Reader:
|
|
191
209
|
Generator
|
192
210
|
Generator of batches of selected jets.
|
193
211
|
"""
|
212
|
+
# Check if number of jets is given, if not, set to maximum available
|
194
213
|
if num_jets is None:
|
195
214
|
num_jets = self.num_jets
|
215
|
+
|
216
|
+
# Check if variables is given, if not, set to all
|
196
217
|
if variables is None:
|
197
218
|
variables = {self.jets_name: None}
|
219
|
+
|
198
220
|
if self.jets_name not in variables or variables[self.jets_name] is not None:
|
199
221
|
jet_vars = variables.get(self.jets_name, [])
|
200
222
|
variables[self.jets_name] = list(jet_vars) + (cuts.variables if cuts else [])
|
@@ -207,12 +229,25 @@ class H5Reader:
|
|
207
229
|
|
208
230
|
rng = np.random.default_rng(42)
|
209
231
|
while True:
|
210
|
-
# yeild from each stream
|
211
232
|
samples = []
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
233
|
+
# Track which streams have been exhausted
|
234
|
+
streams_done = [False] * len(streams)
|
235
|
+
|
236
|
+
# for each unexhausted stream, get the next sample
|
237
|
+
for i, stream in enumerate(streams):
|
238
|
+
if not streams_done[i]:
|
239
|
+
try:
|
240
|
+
samples.append(next(stream))
|
241
|
+
|
242
|
+
# if equal_jets is True, we can stop when any stream is done
|
243
|
+
# otherwise if sample is exhausted, mark it as done
|
244
|
+
except StopIteration:
|
245
|
+
if self.equal_jets:
|
246
|
+
return
|
247
|
+
streams_done[i] = True
|
248
|
+
|
249
|
+
# if equal_jets is False, we need to keep going until all streams are done
|
250
|
+
if all(streams_done):
|
216
251
|
return
|
217
252
|
|
218
253
|
# combine samples and shuffle
|
@@ -222,23 +257,72 @@ class H5Reader:
|
|
222
257
|
rng.shuffle(idx)
|
223
258
|
data = {name: array[idx] for name, array in data.items()}
|
224
259
|
|
225
|
-
#
|
260
|
+
# yield batch
|
226
261
|
yield data
|
227
262
|
|
228
263
|
def load(
|
229
264
|
self, variables: dict | None = None, num_jets: int | None = None, cuts: Cuts | None = None
|
230
265
|
) -> dict:
|
266
|
+
"""Load multiple batches of selected jets into memory.
|
267
|
+
|
268
|
+
Parameters
|
269
|
+
----------
|
270
|
+
variables : dict | None, optional
|
271
|
+
Dictionary of variables to load for each group, by default use all jet variables.
|
272
|
+
num_jets : int | None, optional
|
273
|
+
Total number of selected jets to load, by default all.
|
274
|
+
cuts : Cuts | None, optional
|
275
|
+
Selection cuts to apply, by default None
|
276
|
+
|
277
|
+
Returns
|
278
|
+
-------
|
279
|
+
dict
|
280
|
+
Dictionary of arrays for each group.
|
281
|
+
"""
|
282
|
+
# handle default arguments
|
283
|
+
if num_jets == -1:
|
284
|
+
num_jets = self.num_jets
|
231
285
|
if variables is None:
|
232
286
|
variables = {self.jets_name: None}
|
287
|
+
|
288
|
+
# get data from each sample
|
233
289
|
data: dict[str, list] = {name: [] for name in variables}
|
234
|
-
for
|
235
|
-
for name, array in
|
290
|
+
for batch in self.stream(variables, num_jets, cuts):
|
291
|
+
for name, array in batch.items():
|
236
292
|
if name in data:
|
237
293
|
data[name].append(array)
|
294
|
+
|
295
|
+
# concatenate batches
|
238
296
|
return {name: np.concatenate(array) for name, array in data.items()}
|
239
297
|
|
240
298
|
def estimate_available_jets(self, cuts: Cuts, num: int = 1_000_000) -> int:
|
241
|
-
"""Estimate the number of jets available after selection cuts
|
242
|
-
|
243
|
-
|
299
|
+
"""Estimate the number of jets available after selection cuts (round down).
|
300
|
+
|
301
|
+
Parameters
|
302
|
+
----------
|
303
|
+
cuts : Cuts
|
304
|
+
Selection cuts to apply.
|
305
|
+
num : int, optional
|
306
|
+
Number of jets to use for the estimation, by default 1_000_000.
|
307
|
+
|
308
|
+
Returns
|
309
|
+
-------
|
310
|
+
int
|
311
|
+
Estimated number of jets available after selection cuts,
|
312
|
+
rounded down to nearest thousand.
|
313
|
+
"""
|
314
|
+
# if equal jets is True, available jets is based on the smallest sample
|
315
|
+
if self.equal_jets:
|
316
|
+
num_jets = []
|
317
|
+
for r in self.readers:
|
318
|
+
stream = r.stream({self.jets_name: cuts.variables}, num)
|
319
|
+
all_jets = np.concatenate([batch[self.jets_name] for batch in stream])
|
320
|
+
frac_selected = len(cuts(all_jets).values) / len(all_jets)
|
321
|
+
num_jets.append(frac_selected * r.num_jets)
|
322
|
+
estimated_num_jets = min(num_jets) * len(self.readers)
|
323
|
+
# otherwise, available jets is based on all samples
|
324
|
+
else:
|
325
|
+
all_jets = self.load({self.jets_name: cuts.variables}, num)[self.jets_name]
|
326
|
+
frac_selected = len(cuts(all_jets).values) / len(all_jets)
|
327
|
+
estimated_num_jets = frac_selected * self.num_jets
|
244
328
|
return math.floor(estimated_num_jets / 1_000) * 1_000
|
ftag/mock.py
CHANGED
@@ -92,6 +92,7 @@ def get_mock_file(num_jets=1000, tracks_name: str = "tracks", num_tracks: int =
|
|
92
92
|
fname = NamedTemporaryFile(suffix=".h5", dir=mkdtemp()).name
|
93
93
|
f = h5py.File(fname, "w")
|
94
94
|
f.create_dataset("jets", data=jets)
|
95
|
+
f.attrs["test"] = "test"
|
95
96
|
|
96
97
|
# setup tracks
|
97
98
|
if tracks_name:
|
ftag/vds.py
CHANGED
@@ -1,11 +1,22 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
+
import argparse
|
3
4
|
import glob
|
4
5
|
from pathlib import Path
|
5
6
|
|
6
7
|
import h5py
|
7
8
|
|
8
9
|
|
10
|
+
def parse_args(args):
|
11
|
+
parser = argparse.ArgumentParser(
|
12
|
+
description="Create a lightweight wrapper around a set of h5 files"
|
13
|
+
)
|
14
|
+
parser.add_argument("pattern", type=Path, help="quotes-enclosed glob pattern of files to merge")
|
15
|
+
parser.add_argument("output", type=Path, help="path to output virtual file")
|
16
|
+
args = parser.parse_args(args)
|
17
|
+
return args
|
18
|
+
|
19
|
+
|
9
20
|
def get_virtual_layout(fnames: list[str], group: str):
|
10
21
|
# get sources
|
11
22
|
sources = []
|
@@ -58,20 +69,22 @@ def create_virtual_file(
|
|
58
69
|
for group in h5py.File(fnames[0]):
|
59
70
|
layout = get_virtual_layout(fnames, group)
|
60
71
|
f.create_virtual_dataset(group, layout)
|
72
|
+
attrs_dict: dict = {}
|
73
|
+
for fname in fnames:
|
74
|
+
with h5py.File(fname) as g:
|
75
|
+
for name, value in g[group].attrs.items():
|
76
|
+
if name not in attrs_dict:
|
77
|
+
attrs_dict[name] = []
|
78
|
+
attrs_dict[name].append(value)
|
79
|
+
for name, value in attrs_dict.items():
|
80
|
+
if len(value) > 0:
|
81
|
+
f[group].attrs[name] = value[0]
|
61
82
|
|
62
83
|
return out_fname
|
63
84
|
|
64
85
|
|
65
|
-
def main():
|
66
|
-
|
67
|
-
|
68
|
-
parser = argparse.ArgumentParser(
|
69
|
-
description="Create a lightweight wrapper around a set of h5 files"
|
70
|
-
)
|
71
|
-
parser.add_argument("pattern", type=Path, help="quotes-enclosed glob pattern of files to merge")
|
72
|
-
parser.add_argument("output", type=Path, help="path to output virtual file")
|
73
|
-
args = parser.parse_args()
|
74
|
-
|
86
|
+
def main(args=None):
|
87
|
+
args = parse_args(args)
|
75
88
|
print(f"Globbing {args.pattern}...")
|
76
89
|
create_virtual_file(args.pattern, args.output, overwrite=True)
|
77
90
|
with h5py.File(args.output) as f:
|
@@ -1,19 +0,0 @@
|
|
1
|
-
ftag/__init__.py,sha256=3VQyLgnMa0A0325TNda80-4qGbPPnQmkrQZq1-klRcA,543
|
2
|
-
ftag/cuts.py,sha256=Ge4WXLPg3WNgGxg-g7oIgCbbNFcKZonvkyskU0fDuDg,2733
|
3
|
-
ftag/flavour.py,sha256=sEelvHNLWmHsecQQrmRc8ktwykMMHnGX8ePDRrqQkuo,2460
|
4
|
-
ftag/flavours.yaml,sha256=S4WoB_n2uqvjo8_mlvNA1wKUwz9aFLhpyXtWsR8uR80,3121
|
5
|
-
ftag/mock.py,sha256=Y__r5zToQLqrBg7T1a5RF_ten_gwBHIqgQOtj2DhIhU,3598
|
6
|
-
ftag/region.py,sha256=-WxdC0Gy9zz3zEJ2pN779RcxXPG-QEROuMwMoP-Qs0g,353
|
7
|
-
ftag/sample.py,sha256=uVNyxFYMMtkP-o2tjQatpo8mIH4ZNNe3mSFEPebYh_E,2622
|
8
|
-
ftag/vds.py,sha256=FmpP31YiSKBvh6TRIMWr-_aJHAkQs0Trhmqh2KLfT64,2402
|
9
|
-
ftag/hdf5/__init__.py,sha256=A_a_4IUlZ2mSiDcfrZKBdja_3iTrUHvADM2lWx6g66g,325
|
10
|
-
ftag/hdf5/h5reader.py,sha256=ayKX3xUiyV42avsCZQhcTYuNLPgJ3NQCS1qUjSggcKQ,8659
|
11
|
-
ftag/hdf5/h5utils.py,sha256=GKduv9b6JRSBirRdmNgGcmsINCMTj54kH4RQqxrM1t8,2363
|
12
|
-
ftag/hdf5/h5writer.py,sha256=_N-DJSX283r-XsGczvLFA4_qaK4BkFkdKZAusHEvRjU,2919
|
13
|
-
ftag/wps/discriminant.py,sha256=86ISONTuIjqTJO1A27oqkoCgDjAQinofiYNdcjfdkIk,1380
|
14
|
-
ftag/wps/working_points.py,sha256=487NsQGGY2Qt4q8mXxKABMFa-YLsbrhkPLcYVdebeVk,4950
|
15
|
-
atlas_ftag_tools-0.1.3.dist-info/METADATA,sha256=stGxR0B4fZIJyJFpOO3vtJR9ytUrMeDHIP6qafWPDzI,4023
|
16
|
-
atlas_ftag_tools-0.1.3.dist-info/WHEEL,sha256=pkctZYzUS4AYVn6dJ-7367OJZivF2e8RA9b_ZBjif18,92
|
17
|
-
atlas_ftag_tools-0.1.3.dist-info/entry_points.txt,sha256=UKbRbwA9DxfsTPRBIVVDz3u15WdzhzgRKwXXSAXuQqc,73
|
18
|
-
atlas_ftag_tools-0.1.3.dist-info/top_level.txt,sha256=qiYQuKcAvMim-31FwkT3MTQu7WQm0s58tPAia5KKWqs,5
|
19
|
-
atlas_ftag_tools-0.1.3.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|