atlas-ftag-tools 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {atlas_ftag_tools-0.2.11.dist-info → atlas_ftag_tools-0.2.13.dist-info}/METADATA +2 -2
- {atlas_ftag_tools-0.2.11.dist-info → atlas_ftag_tools-0.2.13.dist-info}/RECORD +14 -14
- ftag/__init__.py +1 -1
- ftag/flavours.yaml +1 -1
- ftag/hdf5/h5reader.py +17 -4
- ftag/hdf5/h5utils.py +10 -1
- ftag/hdf5/h5writer.py +79 -33
- ftag/labeller.py +1 -1
- ftag/mock.py +2 -2
- ftag/vds.py +286 -65
- {atlas_ftag_tools-0.2.11.dist-info → atlas_ftag_tools-0.2.13.dist-info}/WHEEL +0 -0
- {atlas_ftag_tools-0.2.11.dist-info → atlas_ftag_tools-0.2.13.dist-info}/entry_points.txt +0 -0
- {atlas_ftag_tools-0.2.11.dist-info → atlas_ftag_tools-0.2.13.dist-info}/licenses/LICENSE +0 -0
- {atlas_ftag_tools-0.2.11.dist-info → atlas_ftag_tools-0.2.13.dist-info}/top_level.txt +0 -0
@@ -1,11 +1,11 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: atlas-ftag-tools
|
3
|
-
Version: 0.2.
|
3
|
+
Version: 0.2.13
|
4
4
|
Summary: ATLAS Flavour Tagging Tools
|
5
5
|
Author: Sam Van Stroud, Philipp Gadow
|
6
6
|
License: MIT
|
7
7
|
Project-URL: Homepage, https://github.com/umami-hep/atlas-ftag-tools/
|
8
|
-
Requires-Python: <3.12,>=3.
|
8
|
+
Requires-Python: <3.12,>=3.10
|
9
9
|
Description-Content-Type: text/markdown
|
10
10
|
License-File: LICENSE
|
11
11
|
Requires-Dist: h5py>=3.0
|
@@ -1,32 +1,32 @@
|
|
1
|
-
atlas_ftag_tools-0.2.
|
2
|
-
ftag/__init__.py,sha256=
|
1
|
+
atlas_ftag_tools-0.2.13.dist-info/licenses/LICENSE,sha256=R4o6bZfajQ1KxwcIeavTC00qYTdL33YGNe1hzfV53gM,11349
|
2
|
+
ftag/__init__.py,sha256=UdYmO_mROM7jvqpPUMbnaQxdCrlR8O0KLlhatyMnapw,748
|
3
3
|
ftag/cli_utils.py,sha256=w3TtQmUHSyAKChS3ewvOtcSDAUJAZGIIomaNi8f446U,298
|
4
4
|
ftag/cuts.py,sha256=9_ooLZHaO3SnIQBNxwbaPZn-qptGdKnB27FdKQGTiTY,2933
|
5
5
|
ftag/flavours.py,sha256=ShH4M2UjQZpZ_NlCctTm2q1tJbzYxjmGteioQ2GcqEU,114
|
6
|
-
ftag/flavours.yaml,sha256=
|
6
|
+
ftag/flavours.yaml,sha256=b86gXX_FMIewLK7_pr0bNgz7RJ84fDf9nAvmjB4J-Ks,9920
|
7
7
|
ftag/fraction_optimization.py,sha256=IlMEJe5fD0soX40f-LO4dYAYld2gMqgZRuBLctoPn9A,5566
|
8
8
|
ftag/git_check.py,sha256=Y-XqM80CVXZ5ZKrDdZcYOJt3X64uU6W3OP6Z0D7AZU0,1663
|
9
|
-
ftag/labeller.py,sha256=
|
9
|
+
ftag/labeller.py,sha256=6tKLG0SrBijMIZdzWjGdQU9qN_dkRV4eELLX_8YQvTQ,2772
|
10
10
|
ftag/labels.py,sha256=2nmcmrZD8mWQPxJsGiOgcLDhSVgWfS_cEzqsBV-Qy8o,4198
|
11
|
-
ftag/mock.py,sha256=
|
11
|
+
ftag/mock.py,sha256=syysvzLsBHU8aw7Uy5g3G4HB6LnSmHlGa2BfeXv5mQ4,5970
|
12
12
|
ftag/region.py,sha256=ANv0dGI2W6NJqD9fp7EfqAUReH4FOjc1gwl_Qn8llcM,360
|
13
13
|
ftag/sample.py,sha256=3N0FrRcu9l1sX8ohuGOHuMYGD0See6gMO4--7NzR2tE,2538
|
14
14
|
ftag/track_selector.py,sha256=fJNk_kIBQriBqV4CPT_3ReJbOUnavDDzO-u3EQlRuyk,2654
|
15
15
|
ftag/transform.py,sha256=uEGGJSnqoKOzLYQv650XdK_kDNw4Aw-5dc60z9Dp_y0,3963
|
16
|
-
ftag/vds.py,sha256=
|
16
|
+
ftag/vds.py,sha256=l6b54naOK7z0gZjvvtIAQv2Ky4X1w1yLrisZZZYqvbY,11259
|
17
17
|
ftag/working_points.py,sha256=RJws2jPMEDQDspCbXUZBifS1CCBmlMJ5ax0eMyDzCRA,15949
|
18
18
|
ftag/hdf5/__init__.py,sha256=8yzVQITge-HKkBQQ60eJwWmWDycYZjgVs-qVg4ShVr0,385
|
19
19
|
ftag/hdf5/h5add_col.py,sha256=htS5wn4Tm4S3U6mrJ8s24VUnbI7o28Z6Ll-J_V68xTA,12558
|
20
20
|
ftag/hdf5/h5move.py,sha256=oYpRu0IDCIJIQ2ML52HBAdoyDxmKkHTeM9JdbPEgKfI,947
|
21
|
-
ftag/hdf5/h5reader.py,sha256=
|
21
|
+
ftag/hdf5/h5reader.py,sha256=NbHohY3RSicM3qnX_0Y1TfGAaDg3wgfEjYlGaWDJmug,14268
|
22
22
|
ftag/hdf5/h5split.py,sha256=4Wy6Xc3J58MdD9aBaSZHf5ZcVFnJSkWsm42R5Pgo-R4,2448
|
23
|
-
ftag/hdf5/h5utils.py,sha256
|
24
|
-
ftag/hdf5/h5writer.py,sha256=
|
23
|
+
ftag/hdf5/h5utils.py,sha256=EbCLOF_j1EBFwD95Z3QvlNpBpNkaZoqySZybhVja67U,3542
|
24
|
+
ftag/hdf5/h5writer.py,sha256=SMurvZ8FPvqieZUaYRX2SBu-jIyZ6Fx8IasUrEOxIvM,7185
|
25
25
|
ftag/utils/__init__.py,sha256=U3YyLY77-FzxRUbudxciieDoy_mnLlY3OfBquA3PnTE,524
|
26
26
|
ftag/utils/logging.py,sha256=54NaQiC9Bh4vSznSqzoPfR-7tj1PXfmoH7yKgv_ZHZk,3192
|
27
27
|
ftag/utils/metrics.py,sha256=zQI4nPeRDSyzqKpdOPmu0GU560xSWoW1wgL13rrja-I,12664
|
28
|
-
atlas_ftag_tools-0.2.
|
29
|
-
atlas_ftag_tools-0.2.
|
30
|
-
atlas_ftag_tools-0.2.
|
31
|
-
atlas_ftag_tools-0.2.
|
32
|
-
atlas_ftag_tools-0.2.
|
28
|
+
atlas_ftag_tools-0.2.13.dist-info/METADATA,sha256=ZpQ5GggkLyizsv9uHEOvIlzRqPmC-4tNaoaMgV6unF4,2153
|
29
|
+
atlas_ftag_tools-0.2.13.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
30
|
+
atlas_ftag_tools-0.2.13.dist-info/entry_points.txt,sha256=acr7WwxMIJ3x2I7AheNxNnpWE7sS8XE9MA1eUJGcU5A,169
|
31
|
+
atlas_ftag_tools-0.2.13.dist-info/top_level.txt,sha256=qiYQuKcAvMim-31FwkT3MTQu7WQm0s58tPAia5KKWqs,5
|
32
|
+
atlas_ftag_tools-0.2.13.dist-info/RECORD,,
|
ftag/__init__.py
CHANGED
ftag/flavours.yaml
CHANGED
@@ -113,7 +113,7 @@
|
|
113
113
|
category: xbb
|
114
114
|
- name: qcdnonbb
|
115
115
|
label: $\mathrm{QCD} \rightarrow \mathrm{non-} b \bar{b}$
|
116
|
-
cuts: ["R10TruthLabel_R22v1 == 10", "GhostBHadronsFinalCount
|
116
|
+
cuts: ["R10TruthLabel_R22v1 == 10", "GhostBHadronsFinalCount < 2"]
|
117
117
|
colour: "silver"
|
118
118
|
category: xbb
|
119
119
|
- name: qcdbx
|
ftag/hdf5/h5reader.py
CHANGED
@@ -74,10 +74,12 @@ class H5SingleReader:
|
|
74
74
|
num_jets: int | None = None,
|
75
75
|
cuts: Cuts | None = None,
|
76
76
|
start: int = 0,
|
77
|
+
skip_batches: int = 0,
|
77
78
|
) -> Generator:
|
78
79
|
if num_jets is None:
|
79
80
|
num_jets = self.num_jets
|
80
|
-
|
81
|
+
if skip_batches > 0:
|
82
|
+
assert not self.shuffle, "Cannot skip batches if shuffle is True"
|
81
83
|
if num_jets > self.num_jets:
|
82
84
|
log.warning(
|
83
85
|
f"{num_jets:,} jets requested but only {self.num_jets:,} available in {self.fname}."
|
@@ -97,7 +99,8 @@ class H5SingleReader:
|
|
97
99
|
indices = list(range(start, self.num_jets + start, self.batch_size))
|
98
100
|
if self.shuffle:
|
99
101
|
self.rng.shuffle(indices)
|
100
|
-
|
102
|
+
if skip_batches > 0:
|
103
|
+
indices = indices[skip_batches:]
|
101
104
|
# loop over batches and read file
|
102
105
|
for low in indices:
|
103
106
|
for name in variables:
|
@@ -176,7 +179,12 @@ class H5Reader:
|
|
176
179
|
|
177
180
|
# calculate batch sizes
|
178
181
|
if self.weights is None:
|
179
|
-
|
182
|
+
rows_per_file = [
|
183
|
+
H5SingleReader(f, jets_name=self.jets_name).num_jets for f in self.fname
|
184
|
+
]
|
185
|
+
num_total = sum(rows_per_file)
|
186
|
+
self.weights = [num / num_total for num in rows_per_file]
|
187
|
+
|
180
188
|
self.batch_sizes = [int(w * self.batch_size) for w in self.weights]
|
181
189
|
|
182
190
|
# create readers
|
@@ -233,6 +241,7 @@ class H5Reader:
|
|
233
241
|
num_jets: int | None = None,
|
234
242
|
cuts: Cuts | None = None,
|
235
243
|
start: int = 0,
|
244
|
+
skip_batches: int = 0,
|
236
245
|
) -> Generator:
|
237
246
|
"""Generate batches of selected jets.
|
238
247
|
|
@@ -246,6 +255,8 @@ class H5Reader:
|
|
246
255
|
Selection cuts to apply, by default None
|
247
256
|
start : int, optional
|
248
257
|
Starting index of the first jet to read, by default 0
|
258
|
+
skip_batches : int, optional
|
259
|
+
Number of batches to skip, by default 0
|
249
260
|
|
250
261
|
Yields
|
251
262
|
------
|
@@ -266,7 +277,9 @@ class H5Reader:
|
|
266
277
|
|
267
278
|
# get streams for selected jets from each reader
|
268
279
|
streams = [
|
269
|
-
r.stream(
|
280
|
+
r.stream(
|
281
|
+
variables, int(r.num_jets / self.num_jets * num_jets), cuts, start, skip_batches
|
282
|
+
)
|
270
283
|
for r in self.readers
|
271
284
|
]
|
272
285
|
|
ftag/hdf5/h5utils.py
CHANGED
@@ -13,6 +13,7 @@ def get_dtype(
|
|
13
13
|
variables: list[str] | None = None,
|
14
14
|
precision: str | None = None,
|
15
15
|
transform: Transform | None = None,
|
16
|
+
full_precision_vars: list[str] | None = None,
|
16
17
|
) -> np.dtype:
|
17
18
|
"""Return a dtype based on an existing dataset and requested variables.
|
18
19
|
|
@@ -26,6 +27,8 @@ def get_dtype(
|
|
26
27
|
Precision to cast floats to, "half" or "full", by default None
|
27
28
|
transform : Transform | None, optional
|
28
29
|
Transform to apply to variables names, by default None
|
30
|
+
full_precision_vars : list[str] | None, optional
|
31
|
+
List of variables to keep in full precision, by default None
|
29
32
|
|
30
33
|
Returns
|
31
34
|
-------
|
@@ -39,6 +42,8 @@ def get_dtype(
|
|
39
42
|
"""
|
40
43
|
if variables is None:
|
41
44
|
variables = ds.dtype.names
|
45
|
+
if full_precision_vars is None:
|
46
|
+
full_precision_vars = []
|
42
47
|
|
43
48
|
if (missing := set(variables) - set(ds.dtype.names)) and transform is not None:
|
44
49
|
variables = transform.map_variable_names(ds.name, variables, inverse=True)
|
@@ -50,7 +55,10 @@ def get_dtype(
|
|
50
55
|
|
51
56
|
dtype = [(n, x) for n, x in ds.dtype.descr if n in variables]
|
52
57
|
if precision:
|
53
|
-
dtype = [
|
58
|
+
dtype = [
|
59
|
+
(n, cast_dtype(x, precision)) if n not in full_precision_vars else (n, x)
|
60
|
+
for n, x in dtype
|
61
|
+
]
|
54
62
|
|
55
63
|
return np.dtype(dtype)
|
56
64
|
|
@@ -78,6 +86,7 @@ def cast_dtype(typestr: str, precision: str) -> np.dtype:
|
|
78
86
|
t = np.dtype(typestr)
|
79
87
|
if t.kind != "f":
|
80
88
|
return t
|
89
|
+
|
81
90
|
if precision == "half":
|
82
91
|
return np.dtype("f2")
|
83
92
|
if precision == "full":
|
ftag/hdf5/h5writer.py
CHANGED
@@ -47,18 +47,28 @@ class H5Writer:
|
|
47
47
|
precision: str = "full"
|
48
48
|
full_precision_vars: list[str] | None = None
|
49
49
|
shuffle: bool = True
|
50
|
+
num_jets: int | None = None # Allow dynamic mode by defaulting to None
|
50
51
|
|
51
52
|
def __post_init__(self):
|
52
53
|
self.num_written = 0
|
53
54
|
self.rng = np.random.default_rng(42)
|
54
|
-
|
55
|
-
|
56
|
-
|
55
|
+
|
56
|
+
# Infer number of jets from shapes if not explicitly passed
|
57
|
+
inferred_num_jets = [shape[0] for shape in self.shapes.values()]
|
58
|
+
if self.num_jets is None:
|
59
|
+
assert len(set(inferred_num_jets)) == 1, "Shapes must agree in first dimension"
|
60
|
+
self.fixed_mode = False
|
61
|
+
else:
|
62
|
+
self.fixed_mode = True
|
63
|
+
for name in self.shapes:
|
64
|
+
self.shapes[name] = (self.num_jets,) + self.shapes[name][1:]
|
57
65
|
|
58
66
|
if self.precision == "full":
|
59
67
|
self.fp_dtype = np.float32
|
60
68
|
elif self.precision == "half":
|
61
69
|
self.fp_dtype = np.float16
|
70
|
+
elif self.precision is None:
|
71
|
+
self.fp_dtype = None
|
62
72
|
else:
|
63
73
|
raise ValueError(f"Invalid precision: {self.precision}")
|
64
74
|
|
@@ -71,16 +81,34 @@ class H5Writer:
|
|
71
81
|
self.create_ds(name, dtype)
|
72
82
|
|
73
83
|
@classmethod
|
74
|
-
def from_file(
|
84
|
+
def from_file(
|
85
|
+
cls, source: Path, num_jets: int | None = 0, variables=None, **kwargs
|
86
|
+
) -> H5Writer:
|
75
87
|
with h5py.File(source, "r") as f:
|
76
88
|
dtypes = {name: ds.dtype for name, ds in f.items()}
|
77
89
|
shapes = {name: ds.shape for name, ds in f.items()}
|
78
|
-
|
90
|
+
|
91
|
+
if variables:
|
92
|
+
new_dtye = {}
|
93
|
+
new_shape = {}
|
94
|
+
for name, ds in f.items():
|
95
|
+
if name not in variables:
|
96
|
+
continue
|
97
|
+
new_dtye[name] = ftag.hdf5.get_dtype(
|
98
|
+
ds,
|
99
|
+
variables=variables[name],
|
100
|
+
precision=kwargs.get("precision"),
|
101
|
+
full_precision_vars=kwargs.get("full_precision_vars"),
|
102
|
+
)
|
103
|
+
new_shape[name] = ds.shape
|
104
|
+
dtypes = new_dtye
|
105
|
+
shapes = new_shape
|
106
|
+
if num_jets != 0:
|
79
107
|
shapes = {name: (num_jets,) + shape[1:] for name, shape in shapes.items()}
|
80
108
|
compression = [ds.compression for ds in f.values()]
|
81
109
|
assert len(set(compression)) == 1, "Must have same compression for all groups"
|
82
110
|
compression = compression[0]
|
83
|
-
if compression not in kwargs:
|
111
|
+
if "compression" not in kwargs:
|
84
112
|
kwargs["compression"] = compression
|
85
113
|
return cls(dtypes=dtypes, shapes=shapes, **kwargs)
|
86
114
|
|
@@ -88,36 +116,47 @@ class H5Writer:
|
|
88
116
|
if name == self.jets_name and self.add_flavour_label and "flavour_label" not in dtype.names:
|
89
117
|
dtype = np.dtype([*dtype.descr, ("flavour_label", "i4")])
|
90
118
|
|
91
|
-
|
92
|
-
|
93
|
-
#
|
119
|
+
fp_vars = self.full_precision_vars or []
|
120
|
+
# If no precision is defined, or the field is in full_precision_vars, or its non-float,
|
121
|
+
# keep it at the original dtype
|
94
122
|
dtype = np.dtype([
|
95
123
|
(
|
96
124
|
field,
|
97
|
-
|
98
|
-
|
99
|
-
|
125
|
+
(
|
126
|
+
self.fp_dtype
|
127
|
+
if (self.fp_dtype and field not in fp_vars and np.issubdtype(dt, np.floating))
|
128
|
+
else dt
|
129
|
+
),
|
100
130
|
)
|
101
131
|
for field, dt in dtype.descr
|
102
132
|
])
|
103
133
|
|
104
|
-
# optimal chunking is around 100 jets, only aply for track groups
|
105
134
|
shape = self.shapes[name]
|
106
135
|
chunks = (100,) + shape[1:] if shape[1:] else None
|
107
136
|
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
137
|
+
if self.fixed_mode:
|
138
|
+
self.file.create_dataset(
|
139
|
+
name, dtype=dtype, shape=shape, compression=self.compression, chunks=chunks
|
140
|
+
)
|
141
|
+
else:
|
142
|
+
maxshape = (None,) + shape[1:]
|
143
|
+
self.file.create_dataset(
|
144
|
+
name,
|
145
|
+
dtype=dtype,
|
146
|
+
shape=(0,) + shape[1:],
|
147
|
+
maxshape=maxshape,
|
148
|
+
compression=self.compression,
|
149
|
+
chunks=chunks,
|
150
|
+
)
|
112
151
|
|
113
152
|
def close(self) -> None:
|
114
|
-
|
115
|
-
written = len(
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
153
|
+
if self.fixed_mode:
|
154
|
+
written = len(self.file[self.jets_name])
|
155
|
+
if self.num_written != written:
|
156
|
+
raise ValueError(
|
157
|
+
f"Attempted to close file {self.dst} when only {self.num_written:,} out of"
|
158
|
+
f" {written:,} jets have been written"
|
159
|
+
)
|
121
160
|
self.file.close()
|
122
161
|
|
123
162
|
def get_attr(self, name, group=None):
|
@@ -137,18 +176,25 @@ class H5Writer:
|
|
137
176
|
for attr_name, value in ds.attrs.items():
|
138
177
|
self.add_attr(attr_name, value, group=name)
|
139
178
|
|
140
|
-
def write(self, data: dict[str, np.
|
141
|
-
|
142
|
-
|
143
|
-
f"Attempted to write more jets than expected: {total:,} > {self.num_jets:,}"
|
144
|
-
)
|
145
|
-
idx = np.arange(len(data[self.jets_name]))
|
179
|
+
def write(self, data: dict[str, np.ndarray]) -> None:
|
180
|
+
batch_size = len(data[self.jets_name])
|
181
|
+
idx = np.arange(batch_size)
|
146
182
|
if self.shuffle:
|
147
183
|
self.rng.shuffle(idx)
|
148
184
|
data = {name: array[idx] for name, array in data.items()}
|
149
185
|
|
150
186
|
low = self.num_written
|
151
|
-
high = low +
|
187
|
+
high = low + batch_size
|
188
|
+
|
189
|
+
if self.fixed_mode and high > self.num_jets:
|
190
|
+
raise ValueError(
|
191
|
+
f"Attempted to write more jets than expected: {high:,} > {self.num_jets:,}"
|
192
|
+
)
|
193
|
+
|
152
194
|
for group in self.dtypes:
|
153
|
-
self.file[group]
|
154
|
-
|
195
|
+
ds = self.file[group]
|
196
|
+
if not self.fixed_mode:
|
197
|
+
ds.resize((high,) + ds.shape[1:])
|
198
|
+
ds[low:high] = data[group]
|
199
|
+
|
200
|
+
self.num_written += batch_size
|
ftag/labeller.py
CHANGED
@@ -30,7 +30,7 @@ class Labeller:
|
|
30
30
|
def __post_init__(self) -> None:
|
31
31
|
if isinstance(self.labels, LabelContainer):
|
32
32
|
self.labels = list(self.labels)
|
33
|
-
self.labels =
|
33
|
+
self.labels = [Flavours[label] for label in self.labels]
|
34
34
|
|
35
35
|
@property
|
36
36
|
def variables(self) -> list[str]:
|
ftag/mock.py
CHANGED
@@ -106,11 +106,11 @@ def get_mock_scores(labels: np.ndarray, is_xbb: bool = False) -> np.ndarray:
|
|
106
106
|
for i in range(n_classes):
|
107
107
|
tmp_means = []
|
108
108
|
tmp_means = [
|
109
|
-
0 if j != i else mean_scale_list[
|
109
|
+
0 if j != i else mean_scale_list[rng.integers(0, len(mean_scale_list))]
|
110
110
|
for j in range(n_classes)
|
111
111
|
]
|
112
112
|
means.append(tmp_means)
|
113
|
-
scales.append(mean_scale_list[
|
113
|
+
scales.append(mean_scale_list[rng.integers(0, len(mean_scale_list))])
|
114
114
|
|
115
115
|
# Map the labels to the means
|
116
116
|
label_mapping = dict(zip(label_dict.values(), means))
|
ftag/vds.py
CHANGED
@@ -8,114 +8,334 @@ import sys
|
|
8
8
|
from pathlib import Path
|
9
9
|
|
10
10
|
import h5py
|
11
|
+
import numpy as np
|
11
12
|
|
12
13
|
|
13
|
-
def parse_args(args):
|
14
|
+
def parse_args(args=None):
|
14
15
|
parser = argparse.ArgumentParser(
|
15
|
-
description="Create a lightweight wrapper
|
16
|
+
description="Create a lightweight HDF5 wrapper (virtual datasets + "
|
17
|
+
"summed cutBookkeeper counts) around a set of .h5 files"
|
18
|
+
)
|
19
|
+
parser.add_argument(
|
20
|
+
"pattern",
|
21
|
+
type=Path,
|
22
|
+
help="quotes-enclosed glob pattern of files to merge, "
|
23
|
+
"or a regex if --use_regex is given",
|
16
24
|
)
|
17
|
-
parser.add_argument("pattern", type=Path, help="quotes-enclosed glob pattern of files to merge")
|
18
25
|
parser.add_argument("output", type=Path, help="path to output virtual file")
|
19
|
-
parser.add_argument(
|
20
|
-
|
26
|
+
parser.add_argument(
|
27
|
+
"--use_regex",
|
28
|
+
action="store_true",
|
29
|
+
help="treat PATTERN as a regular expression instead of a glob",
|
30
|
+
)
|
31
|
+
parser.add_argument(
|
32
|
+
"--regex_path",
|
33
|
+
type=str,
|
34
|
+
required="--use_regex" in (args or sys.argv),
|
35
|
+
default=None,
|
36
|
+
help="directory whose entries the regex is applied to "
|
37
|
+
"(defaults to the current working directory)",
|
38
|
+
)
|
21
39
|
return parser.parse_args(args)
|
22
40
|
|
23
41
|
|
24
|
-
def get_virtual_layout(fnames: list[str], group: str):
|
25
|
-
|
42
|
+
def get_virtual_layout(fnames: list[str], group: str) -> h5py.VirtualLayout:
|
43
|
+
"""Concatenate group from multiple files into a single VirtualDataset.
|
44
|
+
|
45
|
+
Parameters
|
46
|
+
----------
|
47
|
+
fnames : list[str]
|
48
|
+
List with the file names
|
49
|
+
group : str
|
50
|
+
Name of the group that is concatenated
|
51
|
+
|
52
|
+
Returns
|
53
|
+
-------
|
54
|
+
h5py.VirtualLayout
|
55
|
+
Virtual layout of the new virtual dataset
|
56
|
+
"""
|
26
57
|
sources = []
|
27
58
|
total = 0
|
59
|
+
|
60
|
+
# Loop over the input files
|
28
61
|
for fname in fnames:
|
29
|
-
with h5py.File(fname) as f:
|
30
|
-
|
31
|
-
|
32
|
-
|
62
|
+
with h5py.File(fname, "r") as f:
|
63
|
+
# Get the file and append its length
|
64
|
+
vsrc = h5py.VirtualSource(f[group])
|
65
|
+
total += vsrc.shape[0]
|
66
|
+
sources.append(vsrc)
|
33
67
|
|
34
|
-
#
|
35
|
-
with h5py.File(fnames[0]) as f:
|
68
|
+
# Define the layout of the output vds
|
69
|
+
with h5py.File(fnames[0], "r") as f:
|
36
70
|
dtype = f[group].dtype
|
37
71
|
shape = f[group].shape
|
72
|
+
|
73
|
+
# Update the shape finalize the output layout
|
38
74
|
shape = (total,) + shape[1:]
|
39
75
|
layout = h5py.VirtualLayout(shape=shape, dtype=dtype)
|
40
76
|
|
41
|
-
#
|
77
|
+
# Fill the vds
|
42
78
|
idx = 0
|
43
|
-
for
|
44
|
-
length =
|
45
|
-
layout[idx : idx + length] =
|
79
|
+
for vsrc in sources:
|
80
|
+
length = vsrc.shape[0]
|
81
|
+
layout[idx : idx + length] = vsrc
|
46
82
|
idx += length
|
47
83
|
|
48
84
|
return layout
|
49
85
|
|
50
86
|
|
51
|
-
def glob_re(pattern, regex_path):
|
87
|
+
def glob_re(pattern: str | None, regex_path: str | None) -> list[str] | None:
|
88
|
+
"""Return list of filenames that match REGEX pattern inside regex_path.
|
89
|
+
|
90
|
+
Parameters
|
91
|
+
----------
|
92
|
+
pattern : str
|
93
|
+
Pattern for the input files
|
94
|
+
regex_path : str
|
95
|
+
Regex path for the input files
|
96
|
+
|
97
|
+
Returns
|
98
|
+
-------
|
99
|
+
list[str]
|
100
|
+
List of the file basenames that matched the regex pattern
|
101
|
+
"""
|
102
|
+
if pattern is None or regex_path is None:
|
103
|
+
return None
|
104
|
+
|
52
105
|
return list(filter(re.compile(pattern).match, os.listdir(regex_path)))
|
53
106
|
|
54
107
|
|
55
|
-
def regex_files_from_dir(
|
108
|
+
def regex_files_from_dir(
|
109
|
+
reg_matched_fnames: list[str] | None,
|
110
|
+
regex_path: str | None,
|
111
|
+
) -> list[str] | None:
|
112
|
+
"""Turn a list of basenames into full paths; dive into sub-dirs if needed.
|
113
|
+
|
114
|
+
Parameters
|
115
|
+
----------
|
116
|
+
reg_matched_fnames : list[str]
|
117
|
+
List of the regex matched file names
|
118
|
+
regex_path : str
|
119
|
+
Regex path for the input files
|
120
|
+
|
121
|
+
Returns
|
122
|
+
-------
|
123
|
+
list[str]
|
124
|
+
List of file paths (as strings) that matched the regex and any subsequent
|
125
|
+
globbing inside matched directories.
|
126
|
+
"""
|
127
|
+
if reg_matched_fnames is None or regex_path is None:
|
128
|
+
return None
|
129
|
+
|
56
130
|
parent_dir = regex_path or str(Path.cwd())
|
57
|
-
full_paths = [parent_dir
|
58
|
-
paths_to_glob = [
|
59
|
-
nested_fnames = [glob.glob(
|
131
|
+
full_paths = [Path(parent_dir) / fname for fname in reg_matched_fnames]
|
132
|
+
paths_to_glob = [str(fp / "*.h5") if fp.is_dir() else str(fp) for fp in full_paths]
|
133
|
+
nested_fnames = [glob.glob(p) for p in paths_to_glob]
|
60
134
|
return sum(nested_fnames, [])
|
61
135
|
|
62
136
|
|
137
|
+
def sum_counts_once(counts: np.ndarray) -> np.ndarray:
|
138
|
+
"""Reduce the arrays in the counts dataset for one file to a scalar via summation.
|
139
|
+
|
140
|
+
Parameters
|
141
|
+
----------
|
142
|
+
counts : np.ndarray
|
143
|
+
Array from the h5py dataset (counts) from the cutBookkeeper groups
|
144
|
+
|
145
|
+
Returns
|
146
|
+
-------
|
147
|
+
np.ndarray
|
148
|
+
Array with the summed variables for the file
|
149
|
+
"""
|
150
|
+
dtype = counts.dtype
|
151
|
+
summed = np.zeros((), dtype=dtype)
|
152
|
+
for field in dtype.names:
|
153
|
+
summed[field] = counts[field].sum()
|
154
|
+
return summed
|
155
|
+
|
156
|
+
|
157
|
+
def check_subgroups(fnames: list[str], group_name: str = "cutBookkeeper") -> list[str]:
|
158
|
+
"""Check which subgroups are available for the bookkeeper.
|
159
|
+
|
160
|
+
Find the intersection of sub-group names that have a 'counts' dataset
|
161
|
+
in every input file. (Using the intersection makes the script robust
|
162
|
+
even if a few files are missing a variation.)
|
163
|
+
|
164
|
+
Parameters
|
165
|
+
----------
|
166
|
+
fnames : list[str]
|
167
|
+
List of the input files
|
168
|
+
group_name : str, optional
|
169
|
+
Group name in the h5 files of the bookkeeper, by default "cutBookkeeper"
|
170
|
+
|
171
|
+
Returns
|
172
|
+
-------
|
173
|
+
set[str]
|
174
|
+
Returns the files with common sub-groups
|
175
|
+
|
176
|
+
Raises
|
177
|
+
------
|
178
|
+
KeyError
|
179
|
+
When a file does not have a bookkeeper
|
180
|
+
ValueError
|
181
|
+
When no common bookkeeper sub-groups were found
|
182
|
+
"""
|
183
|
+
common: set[str] | None = None
|
184
|
+
for fname in fnames:
|
185
|
+
with h5py.File(fname, "r") as f:
|
186
|
+
if group_name not in f:
|
187
|
+
raise KeyError(f"{fname} has no '{group_name}' group")
|
188
|
+
these = {
|
189
|
+
name
|
190
|
+
for name, item in f[group_name].items()
|
191
|
+
if isinstance(item, h5py.Group) and "counts" in item
|
192
|
+
}
|
193
|
+
common = these if common is None else common & these
|
194
|
+
if not common:
|
195
|
+
raise ValueError("No common cutBookkeeper sub-groups with 'counts' found")
|
196
|
+
return sorted(common)
|
197
|
+
|
198
|
+
|
199
|
+
def aggregate_cutbookkeeper(
|
200
|
+
fnames: list[str],
|
201
|
+
group_name: str = "cutBookkeeper",
|
202
|
+
) -> dict[str, np.ndarray] | None:
|
203
|
+
"""Aggregate the cutBookkeeper in the input files.
|
204
|
+
|
205
|
+
For every input file:
|
206
|
+
For every sub-group (nominal, sysUp, sysDown, …):
|
207
|
+
1. Sum the 4-entry record array inside each file into 1 record
|
208
|
+
1. Add those records from all files together into grand total
|
209
|
+
Returns a dict {subgroup_name: scalar-record-array}
|
210
|
+
|
211
|
+
Parameters
|
212
|
+
----------
|
213
|
+
fnames : list[str]
|
214
|
+
List of the input files
|
215
|
+
|
216
|
+
Returns
|
217
|
+
-------
|
218
|
+
dict[str, np.ndarray] | None
|
219
|
+
Dict with the accumulated cutBookkeeper groups. If the cut bookkeeper
|
220
|
+
is not in the files, return None.
|
221
|
+
"""
|
222
|
+
if any(group_name not in h5py.File(f, "r") for f in fnames):
|
223
|
+
return None
|
224
|
+
|
225
|
+
subgroups = check_subgroups(fnames, group_name=group_name)
|
226
|
+
|
227
|
+
# initialise an accumulator per subgroup (dtype taken from 1st file)
|
228
|
+
accum: dict[str, np.ndarray] = {}
|
229
|
+
with h5py.File(fnames[0], "r") as f0:
|
230
|
+
for sg in subgroups:
|
231
|
+
dtype = f0[f"{group_name}/{sg}/counts"].dtype
|
232
|
+
accum[sg] = np.zeros((), dtype=dtype)
|
233
|
+
|
234
|
+
# add each files contribution field-wise
|
235
|
+
for fname in fnames:
|
236
|
+
with h5py.File(fname, "r") as f:
|
237
|
+
for sg in subgroups:
|
238
|
+
per_file = sum_counts_once(f[f"{group_name}/{sg}/counts"][()])
|
239
|
+
for fld in accum[sg].dtype.names:
|
240
|
+
accum[sg][fld] += per_file[fld]
|
241
|
+
|
242
|
+
return accum
|
243
|
+
|
244
|
+
|
63
245
|
def create_virtual_file(
|
64
246
|
pattern: Path | str,
|
65
|
-
out_fname: Path | None = None,
|
247
|
+
out_fname: Path | str | None = None,
|
66
248
|
use_regex: bool = False,
|
67
249
|
regex_path: str | None = None,
|
68
250
|
overwrite: bool = False,
|
69
|
-
|
70
|
-
|
251
|
+
bookkeeper_name: str = "cutBookkeeper",
|
252
|
+
) -> Path:
|
253
|
+
"""Create the virtual dataset file for the given inputs.
|
254
|
+
|
255
|
+
Parameters
|
256
|
+
----------
|
257
|
+
pattern : Path | str
|
258
|
+
Pattern of the input files used. Wildcard is supported
|
259
|
+
out_fname : Path | str | None, optional
|
260
|
+
Output path to which the virtual dataset file is written. By default None
|
261
|
+
use_regex : bool, optional
|
262
|
+
If you want to use regex instead of glob, by default False
|
263
|
+
regex_path : str | None, optional
|
264
|
+
Regex logic used to define the input files, by default None
|
265
|
+
overwrite : bool, optional
|
266
|
+
Decide, if an existing output file is overwritten, by default False
|
267
|
+
bookkeeper_name : str, optional
|
268
|
+
Name of the cut bookkeeper in the h5 files.
|
269
|
+
|
270
|
+
Returns
|
271
|
+
-------
|
272
|
+
Path
|
273
|
+
Path object of the path to which the output file is written
|
274
|
+
|
275
|
+
Raises
|
276
|
+
------
|
277
|
+
FileNotFoundError
|
278
|
+
If not input files were found for the given pattern
|
279
|
+
ValueError
|
280
|
+
If no output file is given and the input comes from multiple directories
|
281
|
+
"""
|
282
|
+
# Get list of filenames
|
71
283
|
pattern_str = str(pattern)
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
284
|
+
|
285
|
+
# Use regex to find input files else use glob
|
286
|
+
if use_regex is True:
|
287
|
+
matched = glob_re(pattern_str, regex_path)
|
288
|
+
fnames = regex_files_from_dir(matched, regex_path)
|
76
289
|
else:
|
77
290
|
fnames = glob.glob(pattern_str)
|
291
|
+
|
292
|
+
# Throw error if no input files were found
|
78
293
|
if not fnames:
|
79
|
-
raise FileNotFoundError(f"No files matched pattern {pattern}")
|
80
|
-
print("Files to merge to vds: ", fnames)
|
294
|
+
raise FileNotFoundError(f"No files matched pattern {pattern!r}")
|
81
295
|
|
82
|
-
#
|
296
|
+
# Infer output path if not given
|
83
297
|
if out_fname is None:
|
84
|
-
|
298
|
+
if len({Path(f).parent for f in fnames}) != 1:
|
299
|
+
raise ValueError("Give --output when files reside in multiple dirs")
|
85
300
|
out_fname = Path(fnames[0]).parent / "vds" / "vds.h5"
|
86
301
|
else:
|
87
302
|
out_fname = Path(out_fname)
|
88
303
|
|
89
|
-
#
|
304
|
+
# If overwrite is not active and a file exists, stop here
|
90
305
|
if not overwrite and out_fname.is_file():
|
91
306
|
return out_fname
|
92
307
|
|
93
|
-
#
|
308
|
+
# Identify common groups across all files
|
94
309
|
common_groups: set[str] = set()
|
95
310
|
for fname in fnames:
|
96
|
-
with h5py.File(fname) as f:
|
311
|
+
with h5py.File(fname, "r") as f:
|
97
312
|
groups = set(f.keys())
|
98
|
-
common_groups = groups if not common_groups else common_groups
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
#
|
104
|
-
out_fname.parent.mkdir(exist_ok=True)
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
313
|
+
common_groups = groups if not common_groups else common_groups & groups
|
314
|
+
|
315
|
+
# Ditch the bookkeeper. We will process it separately
|
316
|
+
common_groups.discard("cutBookkeeper")
|
317
|
+
|
318
|
+
# Check that the directory of the output file exists
|
319
|
+
out_fname.parent.mkdir(parents=True, exist_ok=True)
|
320
|
+
|
321
|
+
# Build the output file
|
322
|
+
with h5py.File(out_fname, "w") as fout:
|
323
|
+
# Build "standard" groups
|
324
|
+
for gname in sorted(common_groups):
|
325
|
+
layout = get_virtual_layout(fnames, gname)
|
326
|
+
fout.create_virtual_dataset(gname, layout)
|
327
|
+
|
328
|
+
# Copy first-file attributes to VDS root object
|
329
|
+
with h5py.File(fnames[0], "r") as f0:
|
330
|
+
for k, v in f0[gname].attrs.items():
|
331
|
+
fout[gname].attrs[k] = v
|
332
|
+
|
333
|
+
# Build the cutBookkeeper
|
334
|
+
counts_total = aggregate_cutbookkeeper(fnames=fnames, group_name=bookkeeper_name)
|
335
|
+
if counts_total is not None:
|
336
|
+
for sg, record in counts_total.items():
|
337
|
+
grp = fout.require_group(f"{bookkeeper_name}/{sg}")
|
338
|
+
grp.create_dataset("counts", data=record, shape=(), dtype=record.dtype)
|
119
339
|
|
120
340
|
return out_fname
|
121
341
|
|
@@ -123,19 +343,20 @@ def create_virtual_file(
|
|
123
343
|
def main(args=None) -> None:
|
124
344
|
args = parse_args(args)
|
125
345
|
matching_mode = "Applying regex to" if args.use_regex else "Globbing"
|
126
|
-
print(f"{matching_mode} {args.pattern}...")
|
127
|
-
create_virtual_file(
|
128
|
-
args.pattern,
|
129
|
-
args.output,
|
346
|
+
print(f"{matching_mode} {args.pattern} ...")
|
347
|
+
out_path = create_virtual_file(
|
348
|
+
pattern=args.pattern,
|
349
|
+
out_fname=args.output,
|
130
350
|
use_regex=args.use_regex,
|
131
351
|
regex_path=args.regex_path,
|
132
352
|
overwrite=True,
|
133
353
|
)
|
134
|
-
|
354
|
+
|
355
|
+
with h5py.File(out_path, "r") as f:
|
135
356
|
key = next(iter(f.keys()))
|
136
|
-
|
137
|
-
|
138
|
-
print(f"Saved virtual file to {
|
357
|
+
print(f"Virtual dataset '{key}' has {len(f[key]):,} entries")
|
358
|
+
|
359
|
+
print(f"Saved virtual file to {out_path.resolve()}")
|
139
360
|
|
140
361
|
|
141
362
|
if __name__ == "__main__":
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|