atlas-ftag-tools 0.2.10__py3-none-any.whl → 0.2.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- atlas_ftag_tools-0.2.12.dist-info/METADATA +53 -0
- atlas_ftag_tools-0.2.12.dist-info/RECORD +32 -0
- {atlas_ftag_tools-0.2.10.dist-info → atlas_ftag_tools-0.2.12.dist-info}/WHEEL +1 -1
- {atlas_ftag_tools-0.2.10.dist-info → atlas_ftag_tools-0.2.12.dist-info}/entry_points.txt +1 -0
- atlas_ftag_tools-0.2.12.dist-info/licenses/LICENSE +201 -0
- ftag/__init__.py +11 -11
- ftag/flavours.yaml +18 -13
- ftag/hdf5/__init__.py +5 -3
- ftag/hdf5/h5add_col.py +391 -0
- ftag/hdf5/h5reader.py +17 -4
- ftag/hdf5/h5utils.py +10 -1
- ftag/hdf5/h5writer.py +86 -29
- ftag/labeller.py +1 -1
- ftag/mock.py +2 -2
- ftag/utils/__init__.py +2 -2
- ftag/vds.py +39 -4
- atlas_ftag_tools-0.2.10.dist-info/METADATA +0 -151
- atlas_ftag_tools-0.2.10.dist-info/RECORD +0 -30
- {atlas_ftag_tools-0.2.10.dist-info → atlas_ftag_tools-0.2.12.dist-info}/top_level.txt +0 -0
ftag/hdf5/h5add_col.py
ADDED
@@ -0,0 +1,391 @@
|
|
1
|
+
# Utils to take an input h5 file, and append one or more columns to it
|
2
|
+
from __future__ import annotations
|
3
|
+
|
4
|
+
import argparse
|
5
|
+
import importlib.util
|
6
|
+
from pathlib import Path
|
7
|
+
from typing import Callable
|
8
|
+
|
9
|
+
import h5py
|
10
|
+
import numpy as np
|
11
|
+
|
12
|
+
from ftag.hdf5.h5reader import H5Reader
|
13
|
+
from ftag.hdf5.h5writer import H5Writer
|
14
|
+
|
15
|
+
|
16
|
+
def merge_dicts(dicts: list[dict[str, dict[str, np.ndarray]]]) -> dict[str, dict[str, np.ndarray]]:
    """Combine several ``{group: {variable: array}}`` mappings into a single mapping.

    Groups that appear in more than one input are merged variable-by-variable.
    Arrays are not copied; the merged mapping references the input arrays.

    Example
    -------
    >>> a = {"jets": {"pt": [1, 2, 3], "eta": [4, 5, 6]}}
    >>> b = {"jets": {"phi": [7, 8, 9]}}
    >>> merge_dicts([a, b])
    {'jets': {'pt': [1, 2, 3], 'eta': [4, 5, 6], 'phi': [7, 8, 9]}}

    Parameters
    ----------
    dicts : list[dict[str, dict[str, np.ndarray]]]
        Mappings to merge, each of the form ``{group: {variable: array}}``.

    Returns
    -------
    dict[str, dict[str, np.ndarray]]
        A mapping of the same ``{group: {variable: array}}`` form holding every
        variable from every input. Groups and variables keep first-seen order.

    Raises
    ------
    ValueError
        If the same variable name is provided twice for the same group.
    """
    result: dict[str, dict[str, np.ndarray]] = {}
    for mapping in dicts:
        for group_name, group_vars in mapping.items():
            # Create the group even if this input contributes no variables to it.
            target = result.setdefault(group_name, {})
            for var_name, array in group_vars.items():
                if var_name in target:
                    raise ValueError(f"Variable {var_name} already exists in group {group_name}.")
                target[var_name] = array
    return result
|
91
|
+
|
92
|
+
|
93
|
+
def get_shape(num_jets: int, batch: dict[str, np.ndarray]) -> dict[str, tuple[int, ...]]:
    """Build the per-group output shapes needed to construct an H5Writer.

    The leading (row) dimension of every array in ``batch`` is replaced by
    ``num_jets``. Any trailing dimensions are kept as they are.

    Parameters
    ----------
    num_jets : int
        Total number of rows the output datasets should hold.
    batch : dict[str, np.ndarray]
        One batch of data, keyed by group name.

    Returns
    -------
    dict[str, tuple[int, ...]]
        Mapping of group name to output shape, in the same order as ``batch``.
    """
    # For 1-D (or 0-D) arrays shape[1:] is empty, which yields (num_jets,).
    return {name: (num_jets, *array.shape[1:]) for name, array in batch.items()}
|
116
|
+
|
117
|
+
|
118
|
+
def get_all_groups(file: Path | str) -> dict[str, None]:
    """List the top-level members of an h5 file as a reader-ready variables mapping.

    Parameters
    ----------
    file : Path | str
        Path to the h5 file.

    Returns
    -------
    dict[str, None]
        Every top-level member name of the file mapped to ``None``. The ``None``
        value is intended to make ``H5Reader.stream`` read all variables of
        each group.
    """
    with h5py.File(file, "r") as h5file:
        # Iterating an h5py File/Group yields its member names.
        return {name: None for name in h5file}
|
135
|
+
|
136
|
+
|
137
|
+
def _validate_new_columns(
    batch: dict[str, np.ndarray],
    to_append: dict[str, dict[str, np.ndarray]],
    output_groups: list[str],
) -> None:
    """Check that every new column can be appended to its target group.

    Raises
    ------
    ValueError
        If a target group is not an output group, a column name already exists
        in that group of the batch, or a column's shape differs from the group's.
    """
    for group, new_vars in to_append.items():
        if group not in output_groups:
            raise ValueError(f"Trying to output to {group} but only {output_groups} are allowed")
        target = batch[group]
        for name, values in new_vars.items():
            if name in target.dtype.names:
                raise ValueError(
                    f"Trying to append {name} to {group} but it already exists in batch"
                )
            # np.shape also accepts plain sequences returned by append functions.
            if np.shape(values) != target.shape:
                raise ValueError(
                    f"Trying to append {name} to {group} but the shape is not correct"
                )


def _combine_batch(
    batch: dict[str, np.ndarray],
    to_append: dict[str, dict[str, np.ndarray]],
    output_groups: list[str],
) -> dict[str, np.ndarray]:
    """Select the output groups from ``batch`` and append any new columns to them."""
    # ``numpy.lib.recfunctions`` is a submodule that ``import numpy`` does not
    # necessarily load, so import it explicitly rather than relying on
    # ``np.lib.recfunctions`` happening to exist.
    from numpy.lib import recfunctions as rfn

    to_write: dict[str, np.ndarray] = {}
    for group, array in batch.items():
        if group not in output_groups:
            continue
        new_vars = to_append.get(group)
        if new_vars:
            to_write[group] = rfn.append_fields(
                array,
                list(new_vars.keys()),
                list(new_vars.values()),
                usemask=False,
            )
        else:
            to_write[group] = array
    return to_write


def h5_add_column(
    input_file: str | Path,
    output_file: str | Path | None,
    append_function: Callable | list[Callable],
    num_jets: int = -1,
    input_groups: list[str] | None = None,
    output_groups: list[str] | None = None,
    reader_kwargs: dict | None = None,
    writer_kwargs: dict | None = None,
    overwrite: bool = False,
) -> None:
    """Appends one or more columns to one or more groups in an h5 file.

    The input file is streamed batch by batch (unshuffled) with H5Reader. Each
    batch is passed to every append function, the returned columns are added
    to their groups, and the result is written to a new file with H5Writer.

    Parameters
    ----------
    input_file : str | Path
        Input h5 file to read from.
    output_file : str | Path | None
        Output h5 file to write to. If None, ``<input stem>_additional<suffix>``
        next to the input file is used.
    append_function : callable | list[callable]
        A function, or list of functions, which take a batch from H5Reader and
        return a dictionary of the form:
        {
            group1 : {
                new_column1 : data,
                new_column2 : data,
            },
            group2 : {
                new_column3 : data,
                new_column4 : data,
            },
            ...
        }
    num_jets : int, optional
        Number of jets to read from the input file. If negative (or None), reads
        all jets. By default -1.
    input_groups : list[str] | None, optional
        List of groups to read from the input file. If None, reads all groups.
        By default None.
    output_groups : list[str] | None, optional
        List of groups to write to the output file. If None, writes all groups.
        By default None. Note that this is a subset of the input groups, and must
        include all groups that the append functions wish to write to.
    reader_kwargs : dict, optional
        Additional arguments to pass to the H5Reader. By default None.
    writer_kwargs : dict, optional
        Additional arguments to pass to the H5Writer. By default None. If it
        has no "precision" key, "full" is used. The caller's dict is not modified.
    overwrite : bool, optional
        If True, overwrite the output file if it exists. If False, raise a
        FileExistsError instead. By default False.

    Raises
    ------
    FileNotFoundError
        If the input file does not exist.
    FileExistsError
        If the output file exists and overwrite is False.
    ValueError
        If the output file is the input file, a new variable already exists,
        a new variable's shape is incorrect, or an output group is not among
        the input groups.
    """
    input_file = Path(input_file)
    if not input_file.exists():
        raise FileNotFoundError(f"Input file {input_file} does not exist.")

    # Derive the default name *before* the existence check, so the default
    # output file is protected by `overwrite` like an explicit one.
    if output_file is None:
        output_file = input_file.with_name(f"{input_file.stem}_additional{input_file.suffix}")
    output_file = Path(output_file)
    if output_file.resolve() == input_file.resolve():
        raise ValueError(f"Output file {output_file} must differ from the input file.")
    if output_file.exists() and not overwrite:
        raise FileExistsError(
            f"Output file {output_file} already exists. Please choose a different name."
        )

    # Copy the kwargs so the defaults added below never leak into the caller's dicts.
    reader_kwargs = dict(reader_kwargs or {})
    writer_kwargs = dict(writer_kwargs or {})
    writer_kwargs.setdefault("precision", "full")

    append_functions = append_function if isinstance(append_function, list) else [append_function]

    reader = H5Reader(input_file, shuffle=False, **reader_kwargs)
    njets = reader.num_jets if num_jets is None or num_jets < 0 else num_jets

    input_variables = (
        get_all_groups(input_file) if input_groups is None else dict.fromkeys(input_groups)
    )
    if output_groups is None:
        output_groups = list(input_variables)
    if any(group not in input_variables for group in output_groups):
        raise ValueError(
            f"Output groups {output_groups} not in input groups {list(input_variables)}"
        )

    # Ceiling division; at least 1 so the progress percentage is always defined.
    num_batches = max(1, -(-njets // reader.batch_size))

    writer = None
    try:
        for i, batch in enumerate(reader.stream(input_variables, num_jets=njets)):
            if (i + 1) % 10 == 0:
                print(
                    f"Processing batch {i + 1}/{num_batches} ({(i + 1) / num_batches * 100:.2f}%)"
                )

            to_append = merge_dicts([af(batch) for af in append_functions])
            _validate_new_columns(batch, to_append, output_groups)
            to_write = _combine_batch(batch, to_append, output_groups)

            if writer is None:
                # Output dtypes/shapes are only known once the first batch has
                # been augmented with the new columns.
                writer = H5Writer(
                    output_file,
                    dtypes={key: array.dtype for key, array in to_write.items()},
                    shapes=get_shape(njets, to_write),
                    shuffle=False,
                    **writer_kwargs,
                )
            writer.write(to_write)
    except BaseException:
        # Release the HDF5 handle without H5Writer.close()'s completeness check,
        # which would otherwise mask the original error.
        if writer is not None:
            writer.file.close()
        raise

    if writer is not None:
        writer.close()
|
279
|
+
|
280
|
+
|
281
|
+
def parse_append_function(func_path: str | Path) -> Callable:
    """Attempts to load the function specified by func_path.

    The function should be specified as 'path/to/file.py:function_name'. The
    string is split on its *last* colon, so file paths that themselves contain
    colons (e.g. Windows drive letters) are supported.

    Note: the target file is executed as a Python module, so only pass trusted paths.

    Parameters
    ----------
    func_path : str | Path
        Path to the function to load. Should be of the form
        'path/to/file.py:function_name'.

    Returns
    -------
    Callable
        The function specified by func_path.

    Raises
    ------
    ValueError
        If the function path is not of the form 'path/to/file.py:function_name'.
    FileNotFoundError
        If the file does not exist.
    ImportError
        If the file cannot be imported.
    AttributeError
        If the function does not exist in the file.
    """
    func_path = str(func_path)
    file_str, sep, func_name = func_path.rpartition(":")
    if not sep or not file_str or not func_name:
        raise ValueError(
            "Function should be specified as 'path/to/file.py:function_name', "
            f"got {func_path!r}"
        )

    file_path = Path(file_str).resolve()
    if not file_path.is_file():
        raise FileNotFoundError(f"No such file: {file_path}")

    module_name = file_path.stem  # Just the filename without extension

    spec = importlib.util.spec_from_file_location(module_name, str(file_path))
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load spec for {file_path}")

    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    if not hasattr(module, func_name):
        raise AttributeError(f"Module {module_name} has no attribute {func_name}")

    return getattr(module, func_name)
|
331
|
+
|
332
|
+
|
333
|
+
def _json_dict(value: str) -> dict:
    """Parse a command-line value as a JSON object (used as an argparse ``type``)."""
    import json

    try:
        parsed = json.loads(value)
    except json.JSONDecodeError as err:
        raise argparse.ArgumentTypeError(f"invalid JSON {value!r}: {err}") from err
    if not isinstance(parsed, dict):
        raise argparse.ArgumentTypeError(f"expected a JSON object, got {value!r}")
    return parsed


def get_args(args):
    """Parse command-line arguments for the h5 column-appending tool.

    Parameters
    ----------
    args : list[str] | None
        Argument list to parse; None means ``sys.argv[1:]`` (argparse default).

    Returns
    -------
    argparse.Namespace
        Parsed arguments. ``reader_kwargs``/``writer_kwargs`` are dicts parsed
        from JSON (e.g. ``--reader_kwargs '{"batch_size": 10000}'``) or None.
    """
    parser = argparse.ArgumentParser(description="Append columns to an h5 file.")
    parser.add_argument("--input", "-i", type=str, required=True, help="Input h5 file")
    parser.add_argument(
        "--append_function",
        type=str,
        nargs="+",
        help="Function to append to the h5 file. Can be a list of functions.",
        required=True,
    )
    parser.add_argument("--output", type=str, help="Output h5 file")
    parser.add_argument(
        "--num_jets", type=int, default=-1, help="Number of jets to read from the input file"
    )
    parser.add_argument(
        "--input_groups",
        type=str,
        nargs="+",
        default=None,
        help="List of groups to read from the input file",
    )
    parser.add_argument(
        "--output_groups",
        type=str,
        nargs="+",
        default=None,
        help="List of groups to write to the output file",
    )
    # `type=dict` would call dict("<string>") and always fail; parse JSON instead.
    parser.add_argument(
        "--reader_kwargs",
        type=_json_dict,
        default=None,
        help="Additional arguments for H5Reader, as a JSON object",
    )
    parser.add_argument(
        "--writer_kwargs",
        type=_json_dict,
        default=None,
        help="Additional arguments for H5Writer, as a JSON object",
    )
    parser.add_argument(
        "--overwrite", action="store_true", help="Overwrite the output file if it exists"
    )

    return parser.parse_args(args)
|
372
|
+
|
373
|
+
|
374
|
+
def main(args=None):
    """Command-line entry point: parse arguments, load the append functions and run h5_add_column.

    Parameters
    ----------
    args : list[str] | None, optional
        Arguments to parse instead of the process command line.
    """
    parsed = get_args(args)

    # Strings of the form 'file.py:func' are loaded; anything else is used as-is.
    functions = [
        func if not isinstance(func, str) else parse_append_function(func)
        for func in parsed.append_function
    ]

    h5_add_column(
        parsed.input,
        parsed.output,
        functions,
        num_jets=parsed.num_jets,
        input_groups=parsed.input_groups,
        output_groups=parsed.output_groups,
        reader_kwargs=parsed.reader_kwargs,
        writer_kwargs=parsed.writer_kwargs,
        overwrite=parsed.overwrite,
    )
|
ftag/hdf5/h5reader.py
CHANGED
@@ -74,10 +74,12 @@ class H5SingleReader:
|
|
74
74
|
num_jets: int | None = None,
|
75
75
|
cuts: Cuts | None = None,
|
76
76
|
start: int = 0,
|
77
|
+
skip_batches: int = 0,
|
77
78
|
) -> Generator:
|
78
79
|
if num_jets is None:
|
79
80
|
num_jets = self.num_jets
|
80
|
-
|
81
|
+
if skip_batches > 0:
|
82
|
+
assert not self.shuffle, "Cannot skip batches if shuffle is True"
|
81
83
|
if num_jets > self.num_jets:
|
82
84
|
log.warning(
|
83
85
|
f"{num_jets:,} jets requested but only {self.num_jets:,} available in {self.fname}."
|
@@ -97,7 +99,8 @@ class H5SingleReader:
|
|
97
99
|
indices = list(range(start, self.num_jets + start, self.batch_size))
|
98
100
|
if self.shuffle:
|
99
101
|
self.rng.shuffle(indices)
|
100
|
-
|
102
|
+
if skip_batches > 0:
|
103
|
+
indices = indices[skip_batches:]
|
101
104
|
# loop over batches and read file
|
102
105
|
for low in indices:
|
103
106
|
for name in variables:
|
@@ -176,7 +179,12 @@ class H5Reader:
|
|
176
179
|
|
177
180
|
# calculate batch sizes
|
178
181
|
if self.weights is None:
|
179
|
-
|
182
|
+
rows_per_file = [
|
183
|
+
H5SingleReader(f, jets_name=self.jets_name).num_jets for f in self.fname
|
184
|
+
]
|
185
|
+
num_total = sum(rows_per_file)
|
186
|
+
self.weights = [num / num_total for num in rows_per_file]
|
187
|
+
|
180
188
|
self.batch_sizes = [int(w * self.batch_size) for w in self.weights]
|
181
189
|
|
182
190
|
# create readers
|
@@ -233,6 +241,7 @@ class H5Reader:
|
|
233
241
|
num_jets: int | None = None,
|
234
242
|
cuts: Cuts | None = None,
|
235
243
|
start: int = 0,
|
244
|
+
skip_batches: int = 0,
|
236
245
|
) -> Generator:
|
237
246
|
"""Generate batches of selected jets.
|
238
247
|
|
@@ -246,6 +255,8 @@ class H5Reader:
|
|
246
255
|
Selection cuts to apply, by default None
|
247
256
|
start : int, optional
|
248
257
|
Starting index of the first jet to read, by default 0
|
258
|
+
skip_batches : int, optional
|
259
|
+
Number of batches to skip, by default 0
|
249
260
|
|
250
261
|
Yields
|
251
262
|
------
|
@@ -266,7 +277,9 @@ class H5Reader:
|
|
266
277
|
|
267
278
|
# get streams for selected jets from each reader
|
268
279
|
streams = [
|
269
|
-
r.stream(
|
280
|
+
r.stream(
|
281
|
+
variables, int(r.num_jets / self.num_jets * num_jets), cuts, start, skip_batches
|
282
|
+
)
|
270
283
|
for r in self.readers
|
271
284
|
]
|
272
285
|
|
ftag/hdf5/h5utils.py
CHANGED
@@ -13,6 +13,7 @@ def get_dtype(
|
|
13
13
|
variables: list[str] | None = None,
|
14
14
|
precision: str | None = None,
|
15
15
|
transform: Transform | None = None,
|
16
|
+
full_precision_vars: list[str] | None = None,
|
16
17
|
) -> np.dtype:
|
17
18
|
"""Return a dtype based on an existing dataset and requested variables.
|
18
19
|
|
@@ -26,6 +27,8 @@ def get_dtype(
|
|
26
27
|
Precision to cast floats to, "half" or "full", by default None
|
27
28
|
transform : Transform | None, optional
|
28
29
|
Transform to apply to variables names, by default None
|
30
|
+
full_precision_vars : list[str] | None, optional
|
31
|
+
List of variables to keep in full precision, by default None
|
29
32
|
|
30
33
|
Returns
|
31
34
|
-------
|
@@ -39,6 +42,8 @@ def get_dtype(
|
|
39
42
|
"""
|
40
43
|
if variables is None:
|
41
44
|
variables = ds.dtype.names
|
45
|
+
if full_precision_vars is None:
|
46
|
+
full_precision_vars = []
|
42
47
|
|
43
48
|
if (missing := set(variables) - set(ds.dtype.names)) and transform is not None:
|
44
49
|
variables = transform.map_variable_names(ds.name, variables, inverse=True)
|
@@ -50,7 +55,10 @@ def get_dtype(
|
|
50
55
|
|
51
56
|
dtype = [(n, x) for n, x in ds.dtype.descr if n in variables]
|
52
57
|
if precision:
|
53
|
-
dtype = [
|
58
|
+
dtype = [
|
59
|
+
(n, cast_dtype(x, precision)) if n not in full_precision_vars else (n, x)
|
60
|
+
for n, x in dtype
|
61
|
+
]
|
54
62
|
|
55
63
|
return np.dtype(dtype)
|
56
64
|
|
@@ -78,6 +86,7 @@ def cast_dtype(typestr: str, precision: str) -> np.dtype:
|
|
78
86
|
t = np.dtype(typestr)
|
79
87
|
if t.kind != "f":
|
80
88
|
return t
|
89
|
+
|
81
90
|
if precision == "half":
|
82
91
|
return np.dtype("f2")
|
83
92
|
if precision == "full":
|
ftag/hdf5/h5writer.py
CHANGED
@@ -31,8 +31,11 @@ class H5Writer:
|
|
31
31
|
Compression algorithm to use. Default is "lzf".
|
32
32
|
precision : str | None, optional
|
33
33
|
Precision to use. Default is None.
|
34
|
+
full_precision_vars : list[str] | None, optional
|
35
|
+
List of variables to store in full precision. Default is None.
|
34
36
|
shuffle : bool, optional
|
35
37
|
Whether to shuffle the jets before writing. Default is True.
|
38
|
+
|
36
39
|
"""
|
37
40
|
|
38
41
|
dst: Path | str
|
@@ -42,19 +45,30 @@ class H5Writer:
|
|
42
45
|
add_flavour_label: bool = False
|
43
46
|
compression: str = "lzf"
|
44
47
|
precision: str = "full"
|
48
|
+
full_precision_vars: list[str] | None = None
|
45
49
|
shuffle: bool = True
|
50
|
+
num_jets: int | None = None # Allow dynamic mode by defaulting to None
|
46
51
|
|
47
52
|
def __post_init__(self):
|
48
53
|
self.num_written = 0
|
49
54
|
self.rng = np.random.default_rng(42)
|
50
|
-
|
51
|
-
|
52
|
-
|
55
|
+
|
56
|
+
# Infer number of jets from shapes if not explicitly passed
|
57
|
+
inferred_num_jets = [shape[0] for shape in self.shapes.values()]
|
58
|
+
if self.num_jets is None:
|
59
|
+
assert len(set(inferred_num_jets)) == 1, "Shapes must agree in first dimension"
|
60
|
+
self.fixed_mode = False
|
61
|
+
else:
|
62
|
+
self.fixed_mode = True
|
63
|
+
for name in self.shapes:
|
64
|
+
self.shapes[name] = (self.num_jets,) + self.shapes[name][1:]
|
53
65
|
|
54
66
|
if self.precision == "full":
|
55
67
|
self.fp_dtype = np.float32
|
56
68
|
elif self.precision == "half":
|
57
69
|
self.fp_dtype = np.float16
|
70
|
+
elif self.precision is None:
|
71
|
+
self.fp_dtype = None
|
58
72
|
else:
|
59
73
|
raise ValueError(f"Invalid precision: {self.precision}")
|
60
74
|
|
@@ -67,16 +81,34 @@ class H5Writer:
|
|
67
81
|
self.create_ds(name, dtype)
|
68
82
|
|
69
83
|
@classmethod
|
70
|
-
def from_file(
|
84
|
+
def from_file(
|
85
|
+
cls, source: Path, num_jets: int | None = 0, variables=None, **kwargs
|
86
|
+
) -> H5Writer:
|
71
87
|
with h5py.File(source, "r") as f:
|
72
88
|
dtypes = {name: ds.dtype for name, ds in f.items()}
|
73
89
|
shapes = {name: ds.shape for name, ds in f.items()}
|
74
|
-
|
90
|
+
|
91
|
+
if variables:
|
92
|
+
new_dtye = {}
|
93
|
+
new_shape = {}
|
94
|
+
for name, ds in f.items():
|
95
|
+
if name not in variables:
|
96
|
+
continue
|
97
|
+
new_dtye[name] = ftag.hdf5.get_dtype(
|
98
|
+
ds,
|
99
|
+
variables=variables[name],
|
100
|
+
precision=kwargs.get("precision"),
|
101
|
+
full_precision_vars=kwargs.get("full_precision_vars"),
|
102
|
+
)
|
103
|
+
new_shape[name] = ds.shape
|
104
|
+
dtypes = new_dtye
|
105
|
+
shapes = new_shape
|
106
|
+
if num_jets != 0:
|
75
107
|
shapes = {name: (num_jets,) + shape[1:] for name, shape in shapes.items()}
|
76
108
|
compression = [ds.compression for ds in f.values()]
|
77
109
|
assert len(set(compression)) == 1, "Must have same compression for all groups"
|
78
110
|
compression = compression[0]
|
79
|
-
if compression not in kwargs:
|
111
|
+
if "compression" not in kwargs:
|
80
112
|
kwargs["compression"] = compression
|
81
113
|
return cls(dtypes=dtypes, shapes=shapes, **kwargs)
|
82
114
|
|
@@ -84,29 +116,47 @@ class H5Writer:
|
|
84
116
|
if name == self.jets_name and self.add_flavour_label and "flavour_label" not in dtype.names:
|
85
117
|
dtype = np.dtype([*dtype.descr, ("flavour_label", "i4")])
|
86
118
|
|
87
|
-
|
119
|
+
fp_vars = self.full_precision_vars or []
|
120
|
+
# If no precision is defined, or the field is in full_precision_vars, or its non-float,
|
121
|
+
# keep it at the original dtype
|
88
122
|
dtype = np.dtype([
|
89
|
-
(
|
123
|
+
(
|
124
|
+
field,
|
125
|
+
(
|
126
|
+
self.fp_dtype
|
127
|
+
if (self.fp_dtype and field not in fp_vars and np.issubdtype(dt, np.floating))
|
128
|
+
else dt
|
129
|
+
),
|
130
|
+
)
|
90
131
|
for field, dt in dtype.descr
|
91
132
|
])
|
92
133
|
|
93
|
-
# optimal chunking is around 100 jets, only aply for track groups
|
94
134
|
shape = self.shapes[name]
|
95
135
|
chunks = (100,) + shape[1:] if shape[1:] else None
|
96
136
|
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
137
|
+
if self.fixed_mode:
|
138
|
+
self.file.create_dataset(
|
139
|
+
name, dtype=dtype, shape=shape, compression=self.compression, chunks=chunks
|
140
|
+
)
|
141
|
+
else:
|
142
|
+
maxshape = (None,) + shape[1:]
|
143
|
+
self.file.create_dataset(
|
144
|
+
name,
|
145
|
+
dtype=dtype,
|
146
|
+
shape=(0,) + shape[1:],
|
147
|
+
maxshape=maxshape,
|
148
|
+
compression=self.compression,
|
149
|
+
chunks=chunks,
|
150
|
+
)
|
101
151
|
|
102
152
|
def close(self) -> None:
|
103
|
-
|
104
|
-
written = len(
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
153
|
+
if self.fixed_mode:
|
154
|
+
written = len(self.file[self.jets_name])
|
155
|
+
if self.num_written != written:
|
156
|
+
raise ValueError(
|
157
|
+
f"Attempted to close file {self.dst} when only {self.num_written:,} out of"
|
158
|
+
f" {written:,} jets have been written"
|
159
|
+
)
|
110
160
|
self.file.close()
|
111
161
|
|
112
162
|
def get_attr(self, name, group=None):
|
@@ -126,18 +176,25 @@ class H5Writer:
|
|
126
176
|
for attr_name, value in ds.attrs.items():
|
127
177
|
self.add_attr(attr_name, value, group=name)
|
128
178
|
|
129
|
-
def write(self, data: dict[str, np.
|
130
|
-
|
131
|
-
|
132
|
-
f"Attempted to write more jets than expected: {total:,} > {self.num_jets:,}"
|
133
|
-
)
|
134
|
-
idx = np.arange(len(data[self.jets_name]))
|
179
|
+
def write(self, data: dict[str, np.ndarray]) -> None:
|
180
|
+
batch_size = len(data[self.jets_name])
|
181
|
+
idx = np.arange(batch_size)
|
135
182
|
if self.shuffle:
|
136
183
|
self.rng.shuffle(idx)
|
137
184
|
data = {name: array[idx] for name, array in data.items()}
|
138
185
|
|
139
186
|
low = self.num_written
|
140
|
-
high = low +
|
187
|
+
high = low + batch_size
|
188
|
+
|
189
|
+
if self.fixed_mode and high > self.num_jets:
|
190
|
+
raise ValueError(
|
191
|
+
f"Attempted to write more jets than expected: {high:,} > {self.num_jets:,}"
|
192
|
+
)
|
193
|
+
|
141
194
|
for group in self.dtypes:
|
142
|
-
self.file[group]
|
143
|
-
|
195
|
+
ds = self.file[group]
|
196
|
+
if not self.fixed_mode:
|
197
|
+
ds.resize((high,) + ds.shape[1:])
|
198
|
+
ds[low:high] = data[group]
|
199
|
+
|
200
|
+
self.num_written += batch_size
|
ftag/labeller.py
CHANGED
@@ -30,7 +30,7 @@ class Labeller:
|
|
30
30
|
def __post_init__(self) -> None:
|
31
31
|
if isinstance(self.labels, LabelContainer):
|
32
32
|
self.labels = list(self.labels)
|
33
|
-
self.labels =
|
33
|
+
self.labels = [Flavours[label] for label in self.labels]
|
34
34
|
|
35
35
|
@property
|
36
36
|
def variables(self) -> list[str]:
|