atlas-ftag-tools 0.2.10__py3-none-any.whl → 0.2.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ftag/hdf5/h5add_col.py ADDED
@@ -0,0 +1,391 @@
1
+ # Utils to take an input h5 file, and append one or more columns to it
2
+ from __future__ import annotations
3
+
4
+ import argparse
5
+ import importlib.util
6
+ from pathlib import Path
7
+ from typing import Callable
8
+
9
+ import h5py
10
+ import numpy as np
11
+
12
+ from ftag.hdf5.h5reader import H5Reader
13
+ from ftag.hdf5.h5writer import H5Writer
14
+
15
+
16
def merge_dicts(dicts: list[dict[str, dict[str, np.ndarray]]]) -> dict[str, dict[str, np.ndarray]]:
    """Combine several ``{group: {variable: array}}`` mappings into a single one.

    Groups that appear in more than one input are unioned; a variable name may
    only be supplied once per group.

    Example
    -------
    Merging::

        {"jets": {"pt": np.array([1, 2, 3]), "eta": np.array([4, 5, 6])}}
        {"jets": {"phi": np.array([7, 8, 9]), "energy": np.array([10, 11, 12])}}

    yields::

        {
            "jets": {
                "pt": np.array([1, 2, 3]),
                "eta": np.array([4, 5, 6]),
                "phi": np.array([7, 8, 9]),
                "energy": np.array([10, 11, 12]),
            }
        }

    Parameters
    ----------
    dicts : list[dict[str, dict[str, np.ndarray]]]
        Dictionaries to merge, each mapping a group name to a mapping of
        variable name to array.

    Returns
    -------
    dict[str, dict[str, np.ndarray]]
        One dictionary holding, for every group seen, all of its variables.
        The arrays are stored by reference, not copied.

    Raises
    ------
    ValueError
        If the same variable is provided twice for the same group.
    """
    combined: dict[str, dict[str, np.ndarray]] = {}
    for mapping in dicts:
        for group, variables in mapping.items():
            # setdefault registers the group even when it brings no variables
            target = combined.setdefault(group, {})
            for variable, data in variables.items():
                if variable in target:
                    raise ValueError(f"Variable {variable} already exists in group {group}.")
                target[variable] = data
    return combined
91
+
92
+
93
def get_shape(num_jets: int, batch: dict[str, np.ndarray]) -> dict[str, tuple[int, ...]]:
    """Build the per-group output shapes expected by the H5Writer.

    Parameters
    ----------
    num_jets : int
        Total number of jets that will be written; used as the leading dimension.
    batch : dict[str, np.ndarray]
        A single batch, keyed by group name.

    Returns
    -------
    dict[str, tuple[int, ...]]
        For every group, ``(num_jets,)`` followed by the trailing dimensions of
        that group's batch array.
    """
    # shape[1:] is the empty tuple for 1-d arrays, so no special case is needed
    return {key: (num_jets,) + values.shape[1:] for key, values in batch.items()}
116
+
117
+
118
def get_all_groups(file: Path | str) -> dict[str, None]:
    """List the top-level entries of an h5 file as a ``{name: None}`` mapping.

    Parameters
    ----------
    file : Path | str
        Path to the h5 file, opened read-only.

    Returns
    -------
    dict[str, None]
        Every top-level key of the file mapped to None, so that the result can
        be handed to ``H5Reader.stream`` to request all groups.
    """
    with h5py.File(file, "r") as h5file:
        return dict.fromkeys(h5file.keys())
135
+
136
+
137
def h5_add_column(
    input_file: str | Path,
    output_file: str | Path | None,
    append_function: Callable | list[Callable],
    num_jets: int = -1,
    input_groups: list[str] | None = None,
    output_groups: list[str] | None = None,
    reader_kwargs: dict | None = None,
    writer_kwargs: dict | None = None,
    overwrite: bool = False,
) -> None:
    """Appends one or more columns to one or more groups in an h5 file.

    The input file is streamed batch by batch, every append function is called
    on each batch, and the returned columns are attached to the matching
    structured arrays before the batch is written to the output file.

    Parameters
    ----------
    input_file : str | Path
        Input h5 file to read from.
    output_file : str | Path | None
        Output h5 file to write to. If None, the output is placed next to the
        input file with ``_additional`` inserted before the file suffix
        (``sample.h5`` -> ``sample_additional.h5``).
    append_function : callable | list[callable]
        A function, or list of functions, which take a batch from H5Reader and
        return a dictionary of the form::

            {
                group1: {new_column1: data, new_column2: data},
                group2: {new_column3: data, new_column4: data},
                ...
            }

    num_jets : int, optional
        Number of jets to read from the input file. If -1, reads all jets. By default -1.
    input_groups : list[str] | None, optional
        List of groups to read from the input file. If None, reads all groups. By default None.
    output_groups : list[str] | None, optional
        List of groups to write to the output file. If None, writes all input groups.
        Must be a subset of the input groups and must include every group that
        the append functions write to. By default None.
    reader_kwargs : dict, optional
        Additional arguments to pass to the H5Reader. By default None.
    writer_kwargs : dict, optional
        Additional arguments to pass to the H5Writer. ``precision`` defaults to
        ``"full"`` when not given. The passed dictionary is not modified.
        By default None.
    overwrite : bool, optional
        If True, an existing output file may be overwritten. If False (default),
        an existing output file raises a FileExistsError.

    Raises
    ------
    FileNotFoundError
        If the input file does not exist.
    FileExistsError
        If the output file (explicit or derived) exists and overwrite is False.
    ValueError
        If a new variable already exists, its shape does not match the group,
        an append function targets a group that is not an output group, or an
        output group is not among the input groups.
    """
    # ``numpy.lib.recfunctions`` is a submodule; import it explicitly rather than
    # relying on some other package having loaded it as a side effect.
    from numpy.lib import recfunctions as rfn

    input_file = Path(input_file)
    if not input_file.exists():
        raise FileNotFoundError(f"Input file {input_file} does not exist.")

    # Resolve the default output path *before* the existence check so that the
    # derived file is protected by ``overwrite`` as well. Using stem/suffix keeps
    # the result distinct from the input for any extension, not only ".h5".
    if output_file is None:
        output_file = input_file.with_name(f"{input_file.stem}_additional{input_file.suffix}")
    else:
        output_file = Path(output_file)
    if output_file.exists() and not overwrite:
        raise FileExistsError(
            f"Output file {output_file} already exists. Please choose a different name."
        )

    # Work on copies so the caller's dictionaries are never mutated.
    reader_kwargs = dict(reader_kwargs or {})
    writer_kwargs = dict(writer_kwargs or {})
    writer_kwargs.setdefault("precision", "full")

    if not isinstance(append_function, list):
        append_function = [append_function]

    reader = H5Reader(input_file, shuffle=False, **reader_kwargs)
    njets = reader.num_jets if num_jets == -1 else num_jets

    input_variables = (
        get_all_groups(input_file) if input_groups is None else dict.fromkeys(input_groups)
    )
    if output_groups is None:
        output_groups = list(input_variables)

    # Explicit raise instead of ``assert``: asserts vanish under ``python -O`` and
    # the documented contract for this case is a ValueError.
    if any(group not in input_variables for group in output_groups):
        raise ValueError(
            f"Output groups {output_groups} not in input groups {list(input_variables)}"
        )

    # Ceiling division; only used for the progress message below.
    num_batches = max(1, -(-njets // reader.batch_size))

    writer = None
    for i, batch in enumerate(reader.stream(input_variables, num_jets=njets)):
        if (i + 1) % 10 == 0:
            print(f"Processing batch {i + 1}/{num_batches} ({(i + 1) / num_batches * 100:.2f}%)")

        to_append = merge_dicts([af(batch) for af in append_function])
        for group, new_columns in to_append.items():
            if group not in output_groups:
                raise ValueError(
                    f"Trying to output to {group} but only {output_groups} are allowed"
                )
            for name, values in new_columns.items():
                if name in batch[group].dtype.names:
                    raise ValueError(
                        f"Trying to append {name} to {group} but it already exists in batch"
                    )
                if values.shape != batch[group].shape:
                    raise ValueError(
                        f"Trying to append {name} to {group} but the shape is not correct"
                    )

        to_write = {}
        for group, array in batch.items():
            if group not in output_groups:
                continue
            new_columns = to_append.get(group)
            if new_columns:
                to_write[group] = rfn.append_fields(
                    array,
                    list(new_columns.keys()),
                    list(new_columns.values()),
                    usemask=False,
                )
            else:
                # Nothing to add for this group (absent or empty mapping).
                to_write[group] = array

        # The writer is created lazily because the output dtypes are only known
        # once the first batch has been extended.
        if writer is None:
            writer = H5Writer(
                output_file,
                dtypes={group: array.dtype for group, array in to_write.items()},
                shapes=get_shape(njets, to_write),
                shuffle=False,
                **writer_kwargs,
            )

        writer.write(to_write)

    # Close the output so the data is flushed deterministically instead of
    # whenever the writer happens to be garbage collected.
    if writer is not None:
        writer.close()
279
+
280
+
281
def parse_append_function(func_path: str | Path) -> Callable:
    """Load the function specified by func_path.

    The function should be specified as 'path/to/file.py:function_name'. The
    string is split on its *last* colon, so the file part may itself contain
    colons (e.g. a Windows drive letter such as ``C:\\funcs.py:my_func``).

    Note that the target file is executed in order to import it, so only point
    this at trusted files.

    Parameters
    ----------
    func_path : str | Path
        Path to the function to load. Should be of the form 'path/to/file.py:function_name'.

    Returns
    -------
    Callable
        The attribute named ``function_name`` of the loaded module.

    Raises
    ------
    ValueError
        If the function path contains no ':' separator.
    FileNotFoundError
        If the file does not exist.
    ImportError
        If the file cannot be imported.
    AttributeError
        If the function does not exist in the file.
    """
    func_path = str(func_path)

    # rpartition splits on the last colon only; a plain split(":") would fail to
    # unpack whenever the file path itself contains a colon.
    file_str, separator, func_name = func_path.rpartition(":")
    if not separator:
        raise ValueError(
            "Function should be specified as 'path/to/file.py:function_name'"
            f", got {func_path!r}"
        )

    file_path = Path(file_str).resolve()
    if not file_path.is_file():
        raise FileNotFoundError(f"No such file: {file_path}")

    module_name = file_path.stem  # Just the filename without extension

    spec = importlib.util.spec_from_file_location(module_name, str(file_path))
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load spec for {file_path}")

    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    if not hasattr(module, func_name):
        raise AttributeError(f"Module {module_name} has no attribute {func_name}")

    return getattr(module, func_name)
331
+
332
+
333
def get_args(args):
    """Parse the command line arguments of the append-column tool.

    Parameters
    ----------
    args : list[str] | None
        Argument strings to parse. If None, argparse falls back to ``sys.argv``.

    Returns
    -------
    argparse.Namespace
        Parsed arguments: ``input``, ``append_function``, ``output``,
        ``num_jets``, ``input_groups``, ``output_groups``, ``reader_kwargs``,
        ``writer_kwargs`` and ``overwrite``.
    """
    import json

    def json_dict(text: str) -> dict:
        # ``type=dict`` cannot work on the command line: dict("...") always
        # raises, so every supplied value was rejected. Parse a JSON object
        # instead, e.g. --reader_kwargs '{"batch_size": 1000}'.
        value = json.loads(text)
        if not isinstance(value, dict):
            raise argparse.ArgumentTypeError(f"expected a JSON object, got {text!r}")
        return value

    parser = argparse.ArgumentParser(description="Append columns to an h5 file.")
    parser.add_argument("--input", "-i", type=str, required=True, help="Input h5 file")
    parser.add_argument(
        "--append_function",
        type=str,
        nargs="+",
        help="Function to append to the h5 file. Can be a list of functions.",
        required=True,
    )
    parser.add_argument("--output", type=str, help="Output h5 file")
    parser.add_argument(
        "--num_jets", type=int, default=-1, help="Number of jets to read from the input file"
    )
    parser.add_argument(
        "--input_groups",
        type=str,
        nargs="+",
        default=None,
        help="List of groups to read from the input file",
    )
    parser.add_argument(
        "--output_groups",
        type=str,
        nargs="+",
        default=None,
        help="List of groups to write to the output file",
    )
    parser.add_argument(
        "--reader_kwargs",
        type=json_dict,
        default=None,
        help="Additional arguments for H5Reader, given as a JSON object",
    )
    parser.add_argument(
        "--writer_kwargs",
        type=json_dict,
        default=None,
        help="Additional arguments for H5Writer, given as a JSON object",
    )
    parser.add_argument(
        "--overwrite", action="store_true", help="Overwrite the output file if it exists"
    )

    return parser.parse_args(args)
372
+
373
+
374
def main(args=None):
    """Command line entry point: parse arguments, load the functions, run the append.

    Parameters
    ----------
    args : list[str] | None, optional
        Argument strings forwarded to ``get_args``. By default None.
    """
    parsed = get_args(args)

    # Strings are resolved to callables; anything else is passed through as is.
    loaded = [
        spec if not isinstance(spec, str) else parse_append_function(spec)
        for spec in parsed.append_function
    ]

    h5_add_column(
        parsed.input,
        parsed.output,
        loaded,
        num_jets=parsed.num_jets,
        input_groups=parsed.input_groups,
        output_groups=parsed.output_groups,
        reader_kwargs=parsed.reader_kwargs,
        writer_kwargs=parsed.writer_kwargs,
        overwrite=parsed.overwrite,
    )
ftag/hdf5/h5reader.py CHANGED
@@ -74,10 +74,12 @@ class H5SingleReader:
74
74
  num_jets: int | None = None,
75
75
  cuts: Cuts | None = None,
76
76
  start: int = 0,
77
+ skip_batches: int = 0,
77
78
  ) -> Generator:
78
79
  if num_jets is None:
79
80
  num_jets = self.num_jets
80
-
81
+ if skip_batches > 0:
82
+ assert not self.shuffle, "Cannot skip batches if shuffle is True"
81
83
  if num_jets > self.num_jets:
82
84
  log.warning(
83
85
  f"{num_jets:,} jets requested but only {self.num_jets:,} available in {self.fname}."
@@ -97,7 +99,8 @@ class H5SingleReader:
97
99
  indices = list(range(start, self.num_jets + start, self.batch_size))
98
100
  if self.shuffle:
99
101
  self.rng.shuffle(indices)
100
-
102
+ if skip_batches > 0:
103
+ indices = indices[skip_batches:]
101
104
  # loop over batches and read file
102
105
  for low in indices:
103
106
  for name in variables:
@@ -176,7 +179,12 @@ class H5Reader:
176
179
 
177
180
  # calculate batch sizes
178
181
  if self.weights is None:
179
- self.weights = [1 / len(self.fname)] * len(self.fname)
182
+ rows_per_file = [
183
+ H5SingleReader(f, jets_name=self.jets_name).num_jets for f in self.fname
184
+ ]
185
+ num_total = sum(rows_per_file)
186
+ self.weights = [num / num_total for num in rows_per_file]
187
+
180
188
  self.batch_sizes = [int(w * self.batch_size) for w in self.weights]
181
189
 
182
190
  # create readers
@@ -233,6 +241,7 @@ class H5Reader:
233
241
  num_jets: int | None = None,
234
242
  cuts: Cuts | None = None,
235
243
  start: int = 0,
244
+ skip_batches: int = 0,
236
245
  ) -> Generator:
237
246
  """Generate batches of selected jets.
238
247
 
@@ -246,6 +255,8 @@ class H5Reader:
246
255
  Selection cuts to apply, by default None
247
256
  start : int, optional
248
257
  Starting index of the first jet to read, by default 0
258
+ skip_batches : int, optional
259
+ Number of batches to skip, by default 0
249
260
 
250
261
  Yields
251
262
  ------
@@ -266,7 +277,9 @@ class H5Reader:
266
277
 
267
278
  # get streams for selected jets from each reader
268
279
  streams = [
269
- r.stream(variables, int(r.num_jets / self.num_jets * num_jets), cuts, start)
280
+ r.stream(
281
+ variables, int(r.num_jets / self.num_jets * num_jets), cuts, start, skip_batches
282
+ )
270
283
  for r in self.readers
271
284
  ]
272
285
 
ftag/hdf5/h5utils.py CHANGED
@@ -13,6 +13,7 @@ def get_dtype(
13
13
  variables: list[str] | None = None,
14
14
  precision: str | None = None,
15
15
  transform: Transform | None = None,
16
+ full_precision_vars: list[str] | None = None,
16
17
  ) -> np.dtype:
17
18
  """Return a dtype based on an existing dataset and requested variables.
18
19
 
@@ -26,6 +27,8 @@ def get_dtype(
26
27
  Precision to cast floats to, "half" or "full", by default None
27
28
  transform : Transform | None, optional
28
29
  Transform to apply to variables names, by default None
30
+ full_precision_vars : list[str] | None, optional
31
+ List of variables to keep in full precision, by default None
29
32
 
30
33
  Returns
31
34
  -------
@@ -39,6 +42,8 @@ def get_dtype(
39
42
  """
40
43
  if variables is None:
41
44
  variables = ds.dtype.names
45
+ if full_precision_vars is None:
46
+ full_precision_vars = []
42
47
 
43
48
  if (missing := set(variables) - set(ds.dtype.names)) and transform is not None:
44
49
  variables = transform.map_variable_names(ds.name, variables, inverse=True)
@@ -50,7 +55,10 @@ def get_dtype(
50
55
 
51
56
  dtype = [(n, x) for n, x in ds.dtype.descr if n in variables]
52
57
  if precision:
53
- dtype = [(n, cast_dtype(x, precision)) for n, x in dtype]
58
+ dtype = [
59
+ (n, cast_dtype(x, precision)) if n not in full_precision_vars else (n, x)
60
+ for n, x in dtype
61
+ ]
54
62
 
55
63
  return np.dtype(dtype)
56
64
 
@@ -78,6 +86,7 @@ def cast_dtype(typestr: str, precision: str) -> np.dtype:
78
86
  t = np.dtype(typestr)
79
87
  if t.kind != "f":
80
88
  return t
89
+
81
90
  if precision == "half":
82
91
  return np.dtype("f2")
83
92
  if precision == "full":
ftag/hdf5/h5writer.py CHANGED
@@ -31,8 +31,11 @@ class H5Writer:
31
31
  Compression algorithm to use. Default is "lzf".
32
32
  precision : str | None, optional
33
33
  Precision to use. Default is None.
34
+ full_precision_vars : list[str] | None, optional
35
+ List of variables to store in full precision. Default is None.
34
36
  shuffle : bool, optional
35
37
  Whether to shuffle the jets before writing. Default is True.
38
+
36
39
  """
37
40
 
38
41
  dst: Path | str
@@ -42,19 +45,30 @@ class H5Writer:
42
45
  add_flavour_label: bool = False
43
46
  compression: str = "lzf"
44
47
  precision: str = "full"
48
+ full_precision_vars: list[str] | None = None
45
49
  shuffle: bool = True
50
+ num_jets: int | None = None # Allow dynamic mode by defaulting to None
46
51
 
47
52
  def __post_init__(self):
48
53
  self.num_written = 0
49
54
  self.rng = np.random.default_rng(42)
50
- self.num_jets = [shape[0] for shape in self.shapes.values()]
51
- assert len(set(self.num_jets)) == 1, "Must have same number of jets per group"
52
- self.num_jets = self.num_jets[0]
55
+
56
+ # Infer number of jets from shapes if not explicitly passed
57
+ inferred_num_jets = [shape[0] for shape in self.shapes.values()]
58
+ if self.num_jets is None:
59
+ assert len(set(inferred_num_jets)) == 1, "Shapes must agree in first dimension"
60
+ self.fixed_mode = False
61
+ else:
62
+ self.fixed_mode = True
63
+ for name in self.shapes:
64
+ self.shapes[name] = (self.num_jets,) + self.shapes[name][1:]
53
65
 
54
66
  if self.precision == "full":
55
67
  self.fp_dtype = np.float32
56
68
  elif self.precision == "half":
57
69
  self.fp_dtype = np.float16
70
+ elif self.precision is None:
71
+ self.fp_dtype = None
58
72
  else:
59
73
  raise ValueError(f"Invalid precision: {self.precision}")
60
74
 
@@ -67,16 +81,34 @@ class H5Writer:
67
81
  self.create_ds(name, dtype)
68
82
 
69
83
  @classmethod
70
- def from_file(cls, source: Path, num_jets: int | None = None, **kwargs) -> H5Writer:
84
+ def from_file(
85
+ cls, source: Path, num_jets: int | None = 0, variables=None, **kwargs
86
+ ) -> H5Writer:
71
87
  with h5py.File(source, "r") as f:
72
88
  dtypes = {name: ds.dtype for name, ds in f.items()}
73
89
  shapes = {name: ds.shape for name, ds in f.items()}
74
- if num_jets is not None:
90
+
91
+ if variables:
92
+ new_dtye = {}
93
+ new_shape = {}
94
+ for name, ds in f.items():
95
+ if name not in variables:
96
+ continue
97
+ new_dtye[name] = ftag.hdf5.get_dtype(
98
+ ds,
99
+ variables=variables[name],
100
+ precision=kwargs.get("precision"),
101
+ full_precision_vars=kwargs.get("full_precision_vars"),
102
+ )
103
+ new_shape[name] = ds.shape
104
+ dtypes = new_dtye
105
+ shapes = new_shape
106
+ if num_jets != 0:
75
107
  shapes = {name: (num_jets,) + shape[1:] for name, shape in shapes.items()}
76
108
  compression = [ds.compression for ds in f.values()]
77
109
  assert len(set(compression)) == 1, "Must have same compression for all groups"
78
110
  compression = compression[0]
79
- if compression not in kwargs:
111
+ if "compression" not in kwargs:
80
112
  kwargs["compression"] = compression
81
113
  return cls(dtypes=dtypes, shapes=shapes, **kwargs)
82
114
 
@@ -84,29 +116,47 @@ class H5Writer:
84
116
  if name == self.jets_name and self.add_flavour_label and "flavour_label" not in dtype.names:
85
117
  dtype = np.dtype([*dtype.descr, ("flavour_label", "i4")])
86
118
 
87
- # adjust dtype based on specified precision
119
+ fp_vars = self.full_precision_vars or []
120
+ # If no precision is defined, or the field is in full_precision_vars, or its non-float,
121
+ # keep it at the original dtype
88
122
  dtype = np.dtype([
89
- (field, self.fp_dtype if np.issubdtype(dt, np.floating) else dt)
123
+ (
124
+ field,
125
+ (
126
+ self.fp_dtype
127
+ if (self.fp_dtype and field not in fp_vars and np.issubdtype(dt, np.floating))
128
+ else dt
129
+ ),
130
+ )
90
131
  for field, dt in dtype.descr
91
132
  ])
92
133
 
93
- # optimal chunking is around 100 jets, only aply for track groups
94
134
  shape = self.shapes[name]
95
135
  chunks = (100,) + shape[1:] if shape[1:] else None
96
136
 
97
- # note: enabling the hd5 shuffle filter doesn't improve write performance
98
- self.file.create_dataset(
99
- name, dtype=dtype, shape=shape, compression=self.compression, chunks=chunks
100
- )
137
+ if self.fixed_mode:
138
+ self.file.create_dataset(
139
+ name, dtype=dtype, shape=shape, compression=self.compression, chunks=chunks
140
+ )
141
+ else:
142
+ maxshape = (None,) + shape[1:]
143
+ self.file.create_dataset(
144
+ name,
145
+ dtype=dtype,
146
+ shape=(0,) + shape[1:],
147
+ maxshape=maxshape,
148
+ compression=self.compression,
149
+ chunks=chunks,
150
+ )
101
151
 
102
152
  def close(self) -> None:
103
- with h5py.File(self.dst) as f:
104
- written = len(f[self.jets_name])
105
- if self.num_written != written:
106
- raise ValueError(
107
- f"Attemped to close file {self.dst} when only {self.num_written:,} out of"
108
- f" {written:,} jets have been written"
109
- )
153
+ if self.fixed_mode:
154
+ written = len(self.file[self.jets_name])
155
+ if self.num_written != written:
156
+ raise ValueError(
157
+ f"Attempted to close file {self.dst} when only {self.num_written:,} out of"
158
+ f" {written:,} jets have been written"
159
+ )
110
160
  self.file.close()
111
161
 
112
162
  def get_attr(self, name, group=None):
@@ -126,18 +176,25 @@ class H5Writer:
126
176
  for attr_name, value in ds.attrs.items():
127
177
  self.add_attr(attr_name, value, group=name)
128
178
 
129
- def write(self, data: dict[str, np.array]) -> None:
130
- if (total := self.num_written + len(data[self.jets_name])) > self.num_jets:
131
- raise ValueError(
132
- f"Attempted to write more jets than expected: {total:,} > {self.num_jets:,}"
133
- )
134
- idx = np.arange(len(data[self.jets_name]))
179
+ def write(self, data: dict[str, np.ndarray]) -> None:
180
+ batch_size = len(data[self.jets_name])
181
+ idx = np.arange(batch_size)
135
182
  if self.shuffle:
136
183
  self.rng.shuffle(idx)
137
184
  data = {name: array[idx] for name, array in data.items()}
138
185
 
139
186
  low = self.num_written
140
- high = low + len(idx)
187
+ high = low + batch_size
188
+
189
+ if self.fixed_mode and high > self.num_jets:
190
+ raise ValueError(
191
+ f"Attempted to write more jets than expected: {high:,} > {self.num_jets:,}"
192
+ )
193
+
141
194
  for group in self.dtypes:
142
- self.file[group][low:high] = data[group]
143
- self.num_written += len(idx)
195
+ ds = self.file[group]
196
+ if not self.fixed_mode:
197
+ ds.resize((high,) + ds.shape[1:])
198
+ ds[low:high] = data[group]
199
+
200
+ self.num_written += batch_size
ftag/labeller.py CHANGED
@@ -30,7 +30,7 @@ class Labeller:
30
30
  def __post_init__(self) -> None:
31
31
  if isinstance(self.labels, LabelContainer):
32
32
  self.labels = list(self.labels)
33
- self.labels = sorted([Flavours[label] for label in self.labels])
33
+ self.labels = [Flavours[label] for label in self.labels]
34
34
 
35
35
  @property
36
36
  def variables(self) -> list[str]: