atlas-ftag-tools 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,11 +1,11 @@
 Metadata-Version: 2.4
 Name: atlas-ftag-tools
-Version: 0.2.11
+Version: 0.2.13
 Summary: ATLAS Flavour Tagging Tools
 Author: Sam Van Stroud, Philipp Gadow
 License: MIT
 Project-URL: Homepage, https://github.com/umami-hep/atlas-ftag-tools/
-Requires-Python: <3.12,>=3.8
+Requires-Python: <3.12,>=3.10
 Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: h5py>=3.0
@@ -1,32 +1,32 @@
-atlas_ftag_tools-0.2.11.dist-info/licenses/LICENSE,sha256=R4o6bZfajQ1KxwcIeavTC00qYTdL33YGNe1hzfV53gM,11349
-ftag/__init__.py,sha256=BGQ1MtuhqCHFXRAh9S9f_ZnOCLWB5RA0ZtL9lW2tofs,748
+atlas_ftag_tools-0.2.13.dist-info/licenses/LICENSE,sha256=R4o6bZfajQ1KxwcIeavTC00qYTdL33YGNe1hzfV53gM,11349
+ftag/__init__.py,sha256=UdYmO_mROM7jvqpPUMbnaQxdCrlR8O0KLlhatyMnapw,748
 ftag/cli_utils.py,sha256=w3TtQmUHSyAKChS3ewvOtcSDAUJAZGIIomaNi8f446U,298
 ftag/cuts.py,sha256=9_ooLZHaO3SnIQBNxwbaPZn-qptGdKnB27FdKQGTiTY,2933
 ftag/flavours.py,sha256=ShH4M2UjQZpZ_NlCctTm2q1tJbzYxjmGteioQ2GcqEU,114
-ftag/flavours.yaml,sha256=CrVTJKndHeL15LT2nkjPodi6Ck9mk_oUtdRby6X_Rcc,9921
+ftag/flavours.yaml,sha256=b86gXX_FMIewLK7_pr0bNgz7RJ84fDf9nAvmjB4J-Ks,9920
 ftag/fraction_optimization.py,sha256=IlMEJe5fD0soX40f-LO4dYAYld2gMqgZRuBLctoPn9A,5566
 ftag/git_check.py,sha256=Y-XqM80CVXZ5ZKrDdZcYOJt3X64uU6W3OP6Z0D7AZU0,1663
-ftag/labeller.py,sha256=IXUgU9UBir39PxVWRKs5r5fqI66Tv0x7nJD3-RYpbrg,2780
+ftag/labeller.py,sha256=6tKLG0SrBijMIZdzWjGdQU9qN_dkRV4eELLX_8YQvTQ,2772
 ftag/labels.py,sha256=2nmcmrZD8mWQPxJsGiOgcLDhSVgWfS_cEzqsBV-Qy8o,4198
-ftag/mock.py,sha256=P2D7nNKAz2jRBbmfpHTDj9sBVU9r7HGd0rpWZOJYZ90,5980
+ftag/mock.py,sha256=syysvzLsBHU8aw7Uy5g3G4HB6LnSmHlGa2BfeXv5mQ4,5970
 ftag/region.py,sha256=ANv0dGI2W6NJqD9fp7EfqAUReH4FOjc1gwl_Qn8llcM,360
 ftag/sample.py,sha256=3N0FrRcu9l1sX8ohuGOHuMYGD0See6gMO4--7NzR2tE,2538
 ftag/track_selector.py,sha256=fJNk_kIBQriBqV4CPT_3ReJbOUnavDDzO-u3EQlRuyk,2654
 ftag/transform.py,sha256=uEGGJSnqoKOzLYQv650XdK_kDNw4Aw-5dc60z9Dp_y0,3963
-ftag/vds.py,sha256=wqj1cA6mIJ4enk8inkearo7ccTw5KCbvuNo2oon51fc,4565
+ftag/vds.py,sha256=l6b54naOK7z0gZjvvtIAQv2Ky4X1w1yLrisZZZYqvbY,11259
 ftag/working_points.py,sha256=RJws2jPMEDQDspCbXUZBifS1CCBmlMJ5ax0eMyDzCRA,15949
 ftag/hdf5/__init__.py,sha256=8yzVQITge-HKkBQQ60eJwWmWDycYZjgVs-qVg4ShVr0,385
 ftag/hdf5/h5add_col.py,sha256=htS5wn4Tm4S3U6mrJ8s24VUnbI7o28Z6Ll-J_V68xTA,12558
 ftag/hdf5/h5move.py,sha256=oYpRu0IDCIJIQ2ML52HBAdoyDxmKkHTeM9JdbPEgKfI,947
-ftag/hdf5/h5reader.py,sha256=i31pDAqmOSaxdeRhc4iSBlld8xJ0pmp4rNd7CugNzw0,13706
+ftag/hdf5/h5reader.py,sha256=NbHohY3RSicM3qnX_0Y1TfGAaDg3wgfEjYlGaWDJmug,14268
 ftag/hdf5/h5split.py,sha256=4Wy6Xc3J58MdD9aBaSZHf5ZcVFnJSkWsm42R5Pgo-R4,2448
-ftag/hdf5/h5utils.py,sha256=-4zKTMtNCrDZr_9Ww7uzfsB7M7muBKpmm_1IkKJnHOI,3222
-ftag/hdf5/h5writer.py,sha256=2gBztierWdwZIqcFItoYz8oua_7hphOI8mbDg7xBdPs,5784
+ftag/hdf5/h5utils.py,sha256=EbCLOF_j1EBFwD95Z3QvlNpBpNkaZoqySZybhVja67U,3542
+ftag/hdf5/h5writer.py,sha256=SMurvZ8FPvqieZUaYRX2SBu-jIyZ6Fx8IasUrEOxIvM,7185
 ftag/utils/__init__.py,sha256=U3YyLY77-FzxRUbudxciieDoy_mnLlY3OfBquA3PnTE,524
 ftag/utils/logging.py,sha256=54NaQiC9Bh4vSznSqzoPfR-7tj1PXfmoH7yKgv_ZHZk,3192
 ftag/utils/metrics.py,sha256=zQI4nPeRDSyzqKpdOPmu0GU560xSWoW1wgL13rrja-I,12664
-atlas_ftag_tools-0.2.11.dist-info/METADATA,sha256=DVmllPN7YQNNmyDcTs3hEGo8mX8ogSReXq9gs6MwUR0,2152
-atlas_ftag_tools-0.2.11.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-atlas_ftag_tools-0.2.11.dist-info/entry_points.txt,sha256=acr7WwxMIJ3x2I7AheNxNnpWE7sS8XE9MA1eUJGcU5A,169
-atlas_ftag_tools-0.2.11.dist-info/top_level.txt,sha256=qiYQuKcAvMim-31FwkT3MTQu7WQm0s58tPAia5KKWqs,5
-atlas_ftag_tools-0.2.11.dist-info/RECORD,,
+atlas_ftag_tools-0.2.13.dist-info/METADATA,sha256=ZpQ5GggkLyizsv9uHEOvIlzRqPmC-4tNaoaMgV6unF4,2153
+atlas_ftag_tools-0.2.13.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+atlas_ftag_tools-0.2.13.dist-info/entry_points.txt,sha256=acr7WwxMIJ3x2I7AheNxNnpWE7sS8XE9MA1eUJGcU5A,169
+atlas_ftag_tools-0.2.13.dist-info/top_level.txt,sha256=qiYQuKcAvMim-31FwkT3MTQu7WQm0s58tPAia5KKWqs,5
+atlas_ftag_tools-0.2.13.dist-info/RECORD,,
ftag/__init__.py CHANGED
@@ -2,7 +2,7 @@

 from __future__ import annotations

-__version__ = "v0.2.11"
+__version__ = "v0.2.13"

 from . import hdf5, utils
 from .cuts import Cuts
ftag/flavours.yaml CHANGED
@@ -113,7 +113,7 @@
   category: xbb
 - name: qcdnonbb
   label: $\mathrm{QCD} \rightarrow \mathrm{non-} b \bar{b}$
-  cuts: ["R10TruthLabel_R22v1 == 10", "GhostBHadronsFinalCount != 2"]
+  cuts: ["R10TruthLabel_R22v1 == 10", "GhostBHadronsFinalCount < 2"]
   colour: "silver"
   category: xbb
 - name: qcdbx
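
Note on the change above: the old "!=" cut also let jets with three or more ghost-matched b-hadrons into qcdnonbb, while "<" now caps the category at one. A minimal sketch of the difference using the package's Cuts helper on a made-up structured array (all field values below are illustrative only):

    import numpy as np
    from ftag import Cuts

    # hypothetical QCD jets with 0, 1, 2, and 3 ghost-matched b-hadrons
    jets = np.array(
        [(10, 0), (10, 1), (10, 2), (10, 3)],
        dtype=[("R10TruthLabel_R22v1", "i4"), ("GhostBHadronsFinalCount", "i4")],
    )
    old = Cuts.from_list(["R10TruthLabel_R22v1 == 10", "GhostBHadronsFinalCount != 2"])
    new = Cuts.from_list(["R10TruthLabel_R22v1 == 10", "GhostBHadronsFinalCount < 2"])
    print(old(jets).idx)  # [0 1 3] -> the 3-b jet passed the old "!=" cut
    print(new(jets).idx)  # [0 1]   -> "<" now rejects it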
ftag/hdf5/h5reader.py CHANGED
@@ -74,10 +74,12 @@ class H5SingleReader:
         num_jets: int | None = None,
         cuts: Cuts | None = None,
         start: int = 0,
+        skip_batches: int = 0,
     ) -> Generator:
         if num_jets is None:
             num_jets = self.num_jets
-
+        if skip_batches > 0:
+            assert not self.shuffle, "Cannot skip batches if shuffle is True"
         if num_jets > self.num_jets:
             log.warning(
                 f"{num_jets:,} jets requested but only {self.num_jets:,} available in {self.fname}."
@@ -97,7 +99,8 @@ class H5SingleReader:
         indices = list(range(start, self.num_jets + start, self.batch_size))
         if self.shuffle:
             self.rng.shuffle(indices)
-
+        if skip_batches > 0:
+            indices = indices[skip_batches:]
         # loop over batches and read file
         for low in indices:
             for name in variables:
@@ -176,7 +179,12 @@ class H5Reader:

         # calculate batch sizes
         if self.weights is None:
-            self.weights = [1 / len(self.fname)] * len(self.fname)
+            rows_per_file = [
+                H5SingleReader(f, jets_name=self.jets_name).num_jets for f in self.fname
+            ]
+            num_total = sum(rows_per_file)
+            self.weights = [num / num_total for num in rows_per_file]
+
         self.batch_sizes = [int(w * self.batch_size) for w in self.weights]

         # create readers
@@ -233,6 +241,7 @@ class H5Reader:
         num_jets: int | None = None,
         cuts: Cuts | None = None,
         start: int = 0,
+        skip_batches: int = 0,
     ) -> Generator:
         """Generate batches of selected jets.

@@ -246,6 +255,8 @@ class H5Reader:
            Selection cuts to apply, by default None
        start : int, optional
            Starting index of the first jet to read, by default 0
+       skip_batches : int, optional
+           Number of batches to skip, by default 0

        Yields
        ------
@@ -266,7 +277,9 @@ class H5Reader:

        # get streams for selected jets from each reader
        streams = [
-           r.stream(variables, int(r.num_jets / self.num_jets * num_jets), cuts, start)
+           r.stream(
+               variables, int(r.num_jets / self.num_jets * num_jets), cuts, start, skip_batches
+           )
            for r in self.readers
        ]
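Taken together, these hunks change H5Reader in two user-visible ways: default per-file weights are now proportional to each file's jet count rather than uniform, and a stream can resume part-way through by skipping whole batches (only with shuffling disabled, per the new assert). A minimal usage sketch, with hypothetical file names:

    from ftag.hdf5 import H5Reader

    # per-file weights now default to each file's share of the total jets
    reader = H5Reader(["sample_a.h5", "sample_b.h5"], batch_size=10_000, shuffle=False)

    # resume an earlier pass by skipping the first 5 batches of each stream
    for batch in reader.stream({"jets": ["pt", "eta"]}, skip_batches=5):
        jets = batch["jets"]  # structured array for this batch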
ftag/hdf5/h5utils.py CHANGED
@@ -13,6 +13,7 @@ def get_dtype(
     variables: list[str] | None = None,
     precision: str | None = None,
     transform: Transform | None = None,
+    full_precision_vars: list[str] | None = None,
 ) -> np.dtype:
     """Return a dtype based on an existing dataset and requested variables.

@@ -26,6 +27,8 @@
         Precision to cast floats to, "half" or "full", by default None
     transform : Transform | None, optional
         Transform to apply to variables names, by default None
+    full_precision_vars : list[str] | None, optional
+        List of variables to keep in full precision, by default None

     Returns
     -------
@@ -39,6 +42,8 @@
     """
     if variables is None:
         variables = ds.dtype.names
+    if full_precision_vars is None:
+        full_precision_vars = []

     if (missing := set(variables) - set(ds.dtype.names)) and transform is not None:
         variables = transform.map_variable_names(ds.name, variables, inverse=True)
@@ -50,7 +55,10 @@

     dtype = [(n, x) for n, x in ds.dtype.descr if n in variables]
     if precision:
-        dtype = [(n, cast_dtype(x, precision)) for n, x in dtype]
+        dtype = [
+            (n, cast_dtype(x, precision)) if n not in full_precision_vars else (n, x)
+            for n, x in dtype
+        ]

     return np.dtype(dtype)

@@ -78,6 +86,7 @@ def cast_dtype(typestr: str, precision: str) -> np.dtype:
     t = np.dtype(typestr)
     if t.kind != "f":
         return t
+
     if precision == "half":
         return np.dtype("f2")
     if precision == "full":
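
The new full_precision_vars argument lets callers downcast floats wholesale while exempting selected fields. A minimal sketch, assuming a hypothetical file sample.h5 whose jets dataset stores float32 fields pt and eta:

    import h5py
    from ftag.hdf5 import get_dtype

    with h5py.File("sample.h5", "r") as f:
        # cast floats to half precision, but keep pt at its stored width
        dtype = get_dtype(
            f["jets"],
            variables=["pt", "eta"],
            precision="half",
            full_precision_vars=["pt"],
        )
    # for float32 inputs this yields [("pt", "<f4"), ("eta", "<f2")]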
ftag/hdf5/h5writer.py CHANGED
@@ -47,18 +47,28 @@ class H5Writer:
     precision: str = "full"
     full_precision_vars: list[str] | None = None
     shuffle: bool = True
+    num_jets: int | None = None  # Allow dynamic mode by defaulting to None

     def __post_init__(self):
         self.num_written = 0
         self.rng = np.random.default_rng(42)
-        self.num_jets = [shape[0] for shape in self.shapes.values()]
-        assert len(set(self.num_jets)) == 1, "Must have same number of jets per group"
-        self.num_jets = self.num_jets[0]
+
+        # Infer number of jets from shapes if not explicitly passed
+        inferred_num_jets = [shape[0] for shape in self.shapes.values()]
+        if self.num_jets is None:
+            assert len(set(inferred_num_jets)) == 1, "Shapes must agree in first dimension"
+            self.fixed_mode = False
+        else:
+            self.fixed_mode = True
+            for name in self.shapes:
+                self.shapes[name] = (self.num_jets,) + self.shapes[name][1:]

         if self.precision == "full":
             self.fp_dtype = np.float32
         elif self.precision == "half":
             self.fp_dtype = np.float16
+        elif self.precision is None:
+            self.fp_dtype = None
         else:
             raise ValueError(f"Invalid precision: {self.precision}")

@@ -71,16 +81,34 @@ class H5Writer:
             self.create_ds(name, dtype)

     @classmethod
-    def from_file(cls, source: Path, num_jets: int | None = None, **kwargs) -> H5Writer:
+    def from_file(
+        cls, source: Path, num_jets: int | None = 0, variables=None, **kwargs
+    ) -> H5Writer:
         with h5py.File(source, "r") as f:
             dtypes = {name: ds.dtype for name, ds in f.items()}
             shapes = {name: ds.shape for name, ds in f.items()}
-            if num_jets is not None:
+
+            if variables:
+                new_dtye = {}
+                new_shape = {}
+                for name, ds in f.items():
+                    if name not in variables:
+                        continue
+                    new_dtye[name] = ftag.hdf5.get_dtype(
+                        ds,
+                        variables=variables[name],
+                        precision=kwargs.get("precision"),
+                        full_precision_vars=kwargs.get("full_precision_vars"),
+                    )
+                    new_shape[name] = ds.shape
+                dtypes = new_dtye
+                shapes = new_shape
+            if num_jets != 0:
                 shapes = {name: (num_jets,) + shape[1:] for name, shape in shapes.items()}
             compression = [ds.compression for ds in f.values()]
             assert len(set(compression)) == 1, "Must have same compression for all groups"
             compression = compression[0]
-            if compression not in kwargs:
+            if "compression" not in kwargs:
                 kwargs["compression"] = compression
         return cls(dtypes=dtypes, shapes=shapes, **kwargs)

@@ -88,36 +116,47 @@ class H5Writer:
         if name == self.jets_name and self.add_flavour_label and "flavour_label" not in dtype.names:
             dtype = np.dtype([*dtype.descr, ("flavour_label", "i4")])

-        # adjust dtype based on specified precision
-        full_precision_vars = [] if self.full_precision_vars is None else self.full_precision_vars
-        # If the field is in full_precision_vars, use the full precision dtype
+        fp_vars = self.full_precision_vars or []
+        # If no precision is defined, or the field is in full_precision_vars, or its non-float,
+        # keep it at the original dtype
         dtype = np.dtype([
             (
                 field,
-                self.fp_dtype
-                if field not in full_precision_vars and np.issubdtype(dt, np.floating)
-                else dt,
+                (
+                    self.fp_dtype
+                    if (self.fp_dtype and field not in fp_vars and np.issubdtype(dt, np.floating))
+                    else dt
+                ),
             )
             for field, dt in dtype.descr
         ])

-        # optimal chunking is around 100 jets, only aply for track groups
         shape = self.shapes[name]
         chunks = (100,) + shape[1:] if shape[1:] else None

-        # note: enabling the hd5 shuffle filter doesn't improve write performance
-        self.file.create_dataset(
-            name, dtype=dtype, shape=shape, compression=self.compression, chunks=chunks
-        )
+        if self.fixed_mode:
+            self.file.create_dataset(
+                name, dtype=dtype, shape=shape, compression=self.compression, chunks=chunks
+            )
+        else:
+            maxshape = (None,) + shape[1:]
+            self.file.create_dataset(
+                name,
+                dtype=dtype,
+                shape=(0,) + shape[1:],
+                maxshape=maxshape,
+                compression=self.compression,
+                chunks=chunks,
+            )

     def close(self) -> None:
-        with h5py.File(self.dst) as f:
-            written = len(f[self.jets_name])
-            if self.num_written != written:
-                raise ValueError(
-                    f"Attemped to close file {self.dst} when only {self.num_written:,} out of"
-                    f" {written:,} jets have been written"
-                )
+        if self.fixed_mode:
+            written = len(self.file[self.jets_name])
+            if self.num_written != written:
+                raise ValueError(
+                    f"Attempted to close file {self.dst} when only {self.num_written:,} out of"
+                    f" {written:,} jets have been written"
+                )
         self.file.close()

     def get_attr(self, name, group=None):
@@ -137,18 +176,25 @@ class H5Writer:
         for attr_name, value in ds.attrs.items():
             self.add_attr(attr_name, value, group=name)

-    def write(self, data: dict[str, np.array]) -> None:
-        if (total := self.num_written + len(data[self.jets_name])) > self.num_jets:
-            raise ValueError(
-                f"Attempted to write more jets than expected: {total:,} > {self.num_jets:,}"
-            )
-        idx = np.arange(len(data[self.jets_name]))
+    def write(self, data: dict[str, np.ndarray]) -> None:
+        batch_size = len(data[self.jets_name])
+        idx = np.arange(batch_size)
         if self.shuffle:
             self.rng.shuffle(idx)
             data = {name: array[idx] for name, array in data.items()}

         low = self.num_written
-        high = low + len(idx)
+        high = low + batch_size
+
+        if self.fixed_mode and high > self.num_jets:
+            raise ValueError(
+                f"Attempted to write more jets than expected: {high:,} > {self.num_jets:,}"
+            )
+
         for group in self.dtypes:
-            self.file[group][low:high] = data[group]
-        self.num_written += len(idx)
+            ds = self.file[group]
+            if not self.fixed_mode:
+                ds.resize((high,) + ds.shape[1:])
+            ds[low:high] = data[group]
+
+        self.num_written += batch_size
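
With num_jets now an explicit field, H5Writer supports two modes: pass num_jets for the original fixed mode, where datasets are pre-allocated and close() verifies that every slot was filled, or leave it unset for dynamic mode, where datasets start empty with unlimited maxshape and grow via ds.resize() on each write() call. A minimal dynamic-mode sketch; the output path, group, and field names are made up:

    import numpy as np
    from ftag.hdf5 import H5Writer

    writer = H5Writer(
        dst="out.h5",
        dtypes={"jets": np.dtype([("pt", "f4"), ("eta", "f4")])},
        shapes={"jets": (0,)},  # first dim is only a placeholder in dynamic mode
        shuffle=False,
    )
    for _ in range(3):
        writer.write({"jets": np.zeros(100, dtype=writer.dtypes["jets"])})
    writer.close()  # no completeness check without a fixed num_jets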
ftag/labeller.py CHANGED
@@ -30,7 +30,7 @@ class Labeller:
     def __post_init__(self) -> None:
         if isinstance(self.labels, LabelContainer):
             self.labels = list(self.labels)
-        self.labels = sorted([Flavours[label] for label in self.labels])
+        self.labels = [Flavours[label] for label in self.labels]

     @property
     def variables(self) -> list[str]:
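
Dropping the sorted() call means Labeller now preserves the caller's label order instead of re-sorting the flavours, so code that relied on the old implicit canonical ordering must now pass labels pre-sorted. A sketch of the new behaviour; the constructor call and the .name attribute are assumed rather than shown in this diff:

    from ftag.labeller import Labeller

    labeller = Labeller(["ujets", "cjets", "bjets"])
    # 0.2.11 re-sorted these internally; 0.2.13 keeps the given order,
    # so a flavour_label of 0 now refers to ujets here
    print([f.name for f in labeller.labels])  # ['ujets', 'cjets', 'bjets']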
ftag/mock.py CHANGED
@@ -106,11 +106,11 @@ def get_mock_scores(labels: np.ndarray, is_xbb: bool = False) -> np.ndarray:
     for i in range(n_classes):
         tmp_means = []
         tmp_means = [
-            0 if j != i else mean_scale_list[np.random.randint(0, len(mean_scale_list))]
+            0 if j != i else mean_scale_list[rng.integers(0, len(mean_scale_list))]
             for j in range(n_classes)
         ]
         means.append(tmp_means)
-        scales.append(mean_scale_list[np.random.randint(0, len(mean_scale_list))])
+        scales.append(mean_scale_list[rng.integers(0, len(mean_scale_list))])

     # Map the labels to the means
     label_mapping = dict(zip(label_dict.values(), means))
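
Replacing np.random.randint with rng.integers moves the mock-score generation onto a seeded numpy Generator, so the output no longer reads from (or perturbs) the global np.random state. A minimal sketch using the package's mock helper, whose (fname, file) return signature is assumed from the project README:

    import numpy as np
    from ftag.mock import get_mock_file

    np.random.seed(0)  # no longer influences the generated scores
    fname, f = get_mock_file(num_jets=1_000)
    jets = f["jets"][:]  # mock jets with scores drawn via the seeded generator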
ftag/vds.py CHANGED
@@ -8,114 +8,334 @@ import sys
 from pathlib import Path

 import h5py
+import numpy as np


-def parse_args(args):
+def parse_args(args=None):
     parser = argparse.ArgumentParser(
-        description="Create a lightweight wrapper around a set of h5 files"
+        description="Create a lightweight HDF5 wrapper (virtual datasets + "
+        "summed cutBookkeeper counts) around a set of .h5 files"
+    )
+    parser.add_argument(
+        "pattern",
+        type=Path,
+        help="quotes-enclosed glob pattern of files to merge, "
+        "or a regex if --use_regex is given",
     )
-    parser.add_argument("pattern", type=Path, help="quotes-enclosed glob pattern of files to merge")
     parser.add_argument("output", type=Path, help="path to output virtual file")
-    parser.add_argument("--use_regex", help="if provided pattern is a regex", action="store_true")
-    parser.add_argument("--regex_path", type=str, required="--regex" in sys.argv, default=None)
+    parser.add_argument(
+        "--use_regex",
+        action="store_true",
+        help="treat PATTERN as a regular expression instead of a glob",
+    )
+    parser.add_argument(
+        "--regex_path",
+        type=str,
+        required="--use_regex" in (args or sys.argv),
+        default=None,
+        help="directory whose entries the regex is applied to "
+        "(defaults to the current working directory)",
+    )
     return parser.parse_args(args)


-def get_virtual_layout(fnames: list[str], group: str):
-    # get sources
+def get_virtual_layout(fnames: list[str], group: str) -> h5py.VirtualLayout:
+    """Concatenate group from multiple files into a single VirtualDataset.
+
+    Parameters
+    ----------
+    fnames : list[str]
+        List with the file names
+    group : str
+        Name of the group that is concatenated
+
+    Returns
+    -------
+    h5py.VirtualLayout
+        Virtual layout of the new virtual dataset
+    """
     sources = []
     total = 0
+
+    # Loop over the input files
     for fname in fnames:
-        with h5py.File(fname) as f:
-            vsource = h5py.VirtualSource(f[group])
-            total += vsource.shape[0]
-            sources.append(vsource)
+        with h5py.File(fname, "r") as f:
+            # Get the file and append its length
+            vsrc = h5py.VirtualSource(f[group])
+            total += vsrc.shape[0]
+            sources.append(vsrc)

-    # define layout of the vds
-    with h5py.File(fnames[0]) as f:
+    # Define the layout of the output vds
+    with h5py.File(fnames[0], "r") as f:
         dtype = f[group].dtype
         shape = f[group].shape
+
+    # Update the shape finalize the output layout
     shape = (total,) + shape[1:]
     layout = h5py.VirtualLayout(shape=shape, dtype=dtype)

-    # fill the vds
+    # Fill the vds
     idx = 0
-    for source in sources:
-        length = source.shape[0]
-        layout[idx : idx + length] = source
+    for vsrc in sources:
+        length = vsrc.shape[0]
+        layout[idx : idx + length] = vsrc
         idx += length

     return layout


-def glob_re(pattern, regex_path):
+def glob_re(pattern: str | None, regex_path: str | None) -> list[str] | None:
+    """Return list of filenames that match REGEX pattern inside regex_path.
+
+    Parameters
+    ----------
+    pattern : str
+        Pattern for the input files
+    regex_path : str
+        Regex path for the input files
+
+    Returns
+    -------
+    list[str]
+        List of the file basenames that matched the regex pattern
+    """
+    if pattern is None or regex_path is None:
+        return None
+
     return list(filter(re.compile(pattern).match, os.listdir(regex_path)))


-def regex_files_from_dir(reg_matched_fnames, regex_path):
+def regex_files_from_dir(
+    reg_matched_fnames: list[str] | None,
+    regex_path: str | None,
+) -> list[str] | None:
+    """Turn a list of basenames into full paths; dive into sub-dirs if needed.
+
+    Parameters
+    ----------
+    reg_matched_fnames : list[str]
+        List of the regex matched file names
+    regex_path : str
+        Regex path for the input files
+
+    Returns
+    -------
+    list[str]
+        List of file paths (as strings) that matched the regex and any subsequent
+        globbing inside matched directories.
+    """
+    if reg_matched_fnames is None or regex_path is None:
+        return None
+
     parent_dir = regex_path or str(Path.cwd())
-    full_paths = [parent_dir + "/" + fname for fname in reg_matched_fnames]
-    paths_to_glob = [fname + "/*.h5" if Path(fname).is_dir() else fname for fname in full_paths]
-    nested_fnames = [glob.glob(fname) for fname in paths_to_glob]
+    full_paths = [Path(parent_dir) / fname for fname in reg_matched_fnames]
+    paths_to_glob = [str(fp / "*.h5") if fp.is_dir() else str(fp) for fp in full_paths]
+    nested_fnames = [glob.glob(p) for p in paths_to_glob]
     return sum(nested_fnames, [])


+def sum_counts_once(counts: np.ndarray) -> np.ndarray:
+    """Reduce the arrays in the counts dataset for one file to a scalar via summation.
+
+    Parameters
+    ----------
+    counts : np.ndarray
+        Array from the h5py dataset (counts) from the cutBookkeeper groups
+
+    Returns
+    -------
+    np.ndarray
+        Array with the summed variables for the file
+    """
+    dtype = counts.dtype
+    summed = np.zeros((), dtype=dtype)
+    for field in dtype.names:
+        summed[field] = counts[field].sum()
+    return summed
+
+
+def check_subgroups(fnames: list[str], group_name: str = "cutBookkeeper") -> list[str]:
+    """Check which subgroups are available for the bookkeeper.
+
+    Find the intersection of sub-group names that have a 'counts' dataset
+    in every input file. (Using the intersection makes the script robust
+    even if a few files are missing a variation.)
+
+    Parameters
+    ----------
+    fnames : list[str]
+        List of the input files
+    group_name : str, optional
+        Group name in the h5 files of the bookkeeper, by default "cutBookkeeper"
+
+    Returns
+    -------
+    set[str]
+        Returns the files with common sub-groups
+
+    Raises
+    ------
+    KeyError
+        When a file does not have a bookkeeper
+    ValueError
+        When no common bookkeeper sub-groups were found
+    """
+    common: set[str] | None = None
+    for fname in fnames:
+        with h5py.File(fname, "r") as f:
+            if group_name not in f:
+                raise KeyError(f"{fname} has no '{group_name}' group")
+            these = {
+                name
+                for name, item in f[group_name].items()
+                if isinstance(item, h5py.Group) and "counts" in item
+            }
+        common = these if common is None else common & these
+    if not common:
+        raise ValueError("No common cutBookkeeper sub-groups with 'counts' found")
+    return sorted(common)
+
+
+def aggregate_cutbookkeeper(
+    fnames: list[str],
+    group_name: str = "cutBookkeeper",
+) -> dict[str, np.ndarray] | None:
+    """Aggregate the cutBookkeeper in the input files.
+
+    For every input file:
+        For every sub-group (nominal, sysUp, sysDown, …):
+            1. Sum the 4-entry record array inside each file into 1 record
+            1. Add those records from all files together into grand total
+    Returns a dict {subgroup_name: scalar-record-array}
+
+    Parameters
+    ----------
+    fnames : list[str]
+        List of the input files
+
+    Returns
+    -------
+    dict[str, np.ndarray] | None
+        Dict with the accumulated cutBookkeeper groups. If the cut bookkeeper
+        is not in the files, return None.
+    """
+    if any(group_name not in h5py.File(f, "r") for f in fnames):
+        return None
+
+    subgroups = check_subgroups(fnames, group_name=group_name)
+
+    # initialise an accumulator per subgroup (dtype taken from 1st file)
+    accum: dict[str, np.ndarray] = {}
+    with h5py.File(fnames[0], "r") as f0:
+        for sg in subgroups:
+            dtype = f0[f"{group_name}/{sg}/counts"].dtype
+            accum[sg] = np.zeros((), dtype=dtype)
+
+    # add each files contribution field-wise
+    for fname in fnames:
+        with h5py.File(fname, "r") as f:
+            for sg in subgroups:
+                per_file = sum_counts_once(f[f"{group_name}/{sg}/counts"][()])
+                for fld in accum[sg].dtype.names:
+                    accum[sg][fld] += per_file[fld]
+
+    return accum
+
+
 def create_virtual_file(
     pattern: Path | str,
-    out_fname: Path | None = None,
+    out_fname: Path | str | None = None,
     use_regex: bool = False,
     regex_path: str | None = None,
     overwrite: bool = False,
-):
-    # get list of filenames
+    bookkeeper_name: str = "cutBookkeeper",
+) -> Path:
+    """Create the virtual dataset file for the given inputs.
+
+    Parameters
+    ----------
+    pattern : Path | str
+        Pattern of the input files used. Wildcard is supported
+    out_fname : Path | str | None, optional
+        Output path to which the virtual dataset file is written. By default None
+    use_regex : bool, optional
+        If you want to use regex instead of glob, by default False
+    regex_path : str | None, optional
+        Regex logic used to define the input files, by default None
+    overwrite : bool, optional
+        Decide, if an existing output file is overwritten, by default False
+    bookkeeper_name : str, optional
+        Name of the cut bookkeeper in the h5 files.
+
+    Returns
+    -------
+    Path
+        Path object of the path to which the output file is written
+
+    Raises
+    ------
+    FileNotFoundError
+        If not input files were found for the given pattern
+    ValueError
+        If no output file is given and the input comes from multiple directories
+    """
+    # Get list of filenames
     pattern_str = str(pattern)
-    if use_regex:
-        reg_matched_fnames = glob_re(pattern_str, regex_path)
-        print("reg matched fnames: ", reg_matched_fnames)
-        fnames = regex_files_from_dir(reg_matched_fnames, regex_path)
+
+    # Use regex to find input files else use glob
+    if use_regex is True:
+        matched = glob_re(pattern_str, regex_path)
+        fnames = regex_files_from_dir(matched, regex_path)
     else:
         fnames = glob.glob(pattern_str)
+
+    # Throw error if no input files were found
     if not fnames:
-        raise FileNotFoundError(f"No files matched pattern {pattern}")
-    print("Files to merge to vds: ", fnames)
+        raise FileNotFoundError(f"No files matched pattern {pattern!r}")

-    # infer output path if not given
+    # Infer output path if not given
     if out_fname is None:
-        assert len({Path(fname).parent for fname in fnames}) == 1
+        if len({Path(f).parent for f in fnames}) != 1:
+            raise ValueError("Give --output when files reside in multiple dirs")
         out_fname = Path(fnames[0]).parent / "vds" / "vds.h5"
     else:
         out_fname = Path(out_fname)

-    # check if file already exists
+    # If overwrite is not active and a file exists, stop here
     if not overwrite and out_fname.is_file():
         return out_fname

-    # identify common groups across all files
+    # Identify common groups across all files
     common_groups: set[str] = set()
     for fname in fnames:
-        with h5py.File(fname) as f:
+        with h5py.File(fname, "r") as f:
             groups = set(f.keys())
-            common_groups = groups if not common_groups else common_groups.intersection(groups)
-
-    if not common_groups:
-        raise ValueError("No common groups found across files")
-
-    # create virtual file
-    out_fname.parent.mkdir(exist_ok=True)
-    with h5py.File(out_fname, "w") as f:
-        for group in common_groups:
-            layout = get_virtual_layout(fnames, group)
-            f.create_virtual_dataset(group, layout)
-            attrs_dict: dict = {}
-            for fname in fnames:
-                with h5py.File(fname) as g:
-                    for name, value in g[group].attrs.items():
-                        if name not in attrs_dict:
-                            attrs_dict[name] = []
-                        attrs_dict[name].append(value)
-            for name, value in attrs_dict.items():
-                if len(value) > 0:
-                    f[group].attrs[name] = value[0]
+            common_groups = groups if not common_groups else common_groups & groups
+
+    # Ditch the bookkeeper. We will process it separately
+    common_groups.discard("cutBookkeeper")
+
+    # Check that the directory of the output file exists
+    out_fname.parent.mkdir(parents=True, exist_ok=True)
+
+    # Build the output file
+    with h5py.File(out_fname, "w") as fout:
+        # Build "standard" groups
+        for gname in sorted(common_groups):
+            layout = get_virtual_layout(fnames, gname)
+            fout.create_virtual_dataset(gname, layout)
+
+            # Copy first-file attributes to VDS root object
+            with h5py.File(fnames[0], "r") as f0:
+                for k, v in f0[gname].attrs.items():
+                    fout[gname].attrs[k] = v
+
+        # Build the cutBookkeeper
+        counts_total = aggregate_cutbookkeeper(fnames=fnames, group_name=bookkeeper_name)
+        if counts_total is not None:
+            for sg, record in counts_total.items():
+                grp = fout.require_group(f"{bookkeeper_name}/{sg}")
+                grp.create_dataset("counts", data=record, shape=(), dtype=record.dtype)

     return out_fname

@@ -123,19 +343,20 @@ def create_virtual_file(
 def main(args=None) -> None:
     args = parse_args(args)
     matching_mode = "Applying regex to" if args.use_regex else "Globbing"
-    print(f"{matching_mode} {args.pattern}...")
-    create_virtual_file(
-        args.pattern,
-        args.output,
+    print(f"{matching_mode} {args.pattern} ...")
+    out_path = create_virtual_file(
+        pattern=args.pattern,
+        out_fname=args.output,
         use_regex=args.use_regex,
         regex_path=args.regex_path,
         overwrite=True,
     )
-    with h5py.File(args.output) as f:
+
+    with h5py.File(out_path, "r") as f:
         key = next(iter(f.keys()))
-        num = len(f[key])
-        print(f"Virtual dataset '{key}' has {num:,} entries")
-    print(f"Saved virtual file to {args.output.resolve()}")
+        print(f"Virtual dataset '{key}' has {len(f[key]):,} entries")
+
+    print(f"Saved virtual file to {out_path.resolve()}")


 if __name__ == "__main__":
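
The rewrite keeps the module's core flow, turning every group common to all inputs into one virtual dataset, but it now sums cutBookkeeper/<subgroup>/counts across files into scalar records instead of concatenating them, and create_virtual_file returns the output path. A usage sketch with hypothetical file names, assuming the inputs carry a jets group:

    import h5py
    from ftag.vds import create_virtual_file

    out = create_virtual_file("samples/*.h5", "samples/vds/combined.h5", overwrite=True)
    with h5py.File(out, "r") as f:
        print(len(f["jets"]))  # total entries across all source files
        if "cutBookkeeper" in f:
            print(f["cutBookkeeper/nominal/counts"][()])  # summed counts record

The same merge runs from the command line via main(), e.g. python -m ftag.vds "samples/*.h5" samples/vds/combined.h5 (the package also ships a console entry point, per entry_points.txt).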