atlas-ftag-tools 0.2.11__py3-none-any.whl → 0.2.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {atlas_ftag_tools-0.2.11.dist-info → atlas_ftag_tools-0.2.12.dist-info}/METADATA +1 -1
- {atlas_ftag_tools-0.2.11.dist-info → atlas_ftag_tools-0.2.12.dist-info}/RECORD +13 -13
- ftag/__init__.py +1 -1
- ftag/flavours.yaml +1 -1
- ftag/hdf5/h5reader.py +17 -4
- ftag/hdf5/h5utils.py +10 -1
- ftag/hdf5/h5writer.py +79 -33
- ftag/labeller.py +1 -1
- ftag/mock.py +2 -2
- {atlas_ftag_tools-0.2.11.dist-info → atlas_ftag_tools-0.2.12.dist-info}/WHEEL +0 -0
- {atlas_ftag_tools-0.2.11.dist-info → atlas_ftag_tools-0.2.12.dist-info}/entry_points.txt +0 -0
- {atlas_ftag_tools-0.2.11.dist-info → atlas_ftag_tools-0.2.12.dist-info}/licenses/LICENSE +0 -0
- {atlas_ftag_tools-0.2.11.dist-info → atlas_ftag_tools-0.2.12.dist-info}/top_level.txt +0 -0
@@ -1,14 +1,14 @@
|
|
1
|
-
atlas_ftag_tools-0.2.
|
2
|
-
ftag/__init__.py,sha256=
|
1
|
+
atlas_ftag_tools-0.2.12.dist-info/licenses/LICENSE,sha256=R4o6bZfajQ1KxwcIeavTC00qYTdL33YGNe1hzfV53gM,11349
|
2
|
+
ftag/__init__.py,sha256=CU1RjEu6pHq11LQ2kAy9YDittMHXB51fNWvuy1NFr7o,748
|
3
3
|
ftag/cli_utils.py,sha256=w3TtQmUHSyAKChS3ewvOtcSDAUJAZGIIomaNi8f446U,298
|
4
4
|
ftag/cuts.py,sha256=9_ooLZHaO3SnIQBNxwbaPZn-qptGdKnB27FdKQGTiTY,2933
|
5
5
|
ftag/flavours.py,sha256=ShH4M2UjQZpZ_NlCctTm2q1tJbzYxjmGteioQ2GcqEU,114
|
6
|
-
ftag/flavours.yaml,sha256=
|
6
|
+
ftag/flavours.yaml,sha256=b86gXX_FMIewLK7_pr0bNgz7RJ84fDf9nAvmjB4J-Ks,9920
|
7
7
|
ftag/fraction_optimization.py,sha256=IlMEJe5fD0soX40f-LO4dYAYld2gMqgZRuBLctoPn9A,5566
|
8
8
|
ftag/git_check.py,sha256=Y-XqM80CVXZ5ZKrDdZcYOJt3X64uU6W3OP6Z0D7AZU0,1663
|
9
|
-
ftag/labeller.py,sha256=
|
9
|
+
ftag/labeller.py,sha256=6tKLG0SrBijMIZdzWjGdQU9qN_dkRV4eELLX_8YQvTQ,2772
|
10
10
|
ftag/labels.py,sha256=2nmcmrZD8mWQPxJsGiOgcLDhSVgWfS_cEzqsBV-Qy8o,4198
|
11
|
-
ftag/mock.py,sha256=
|
11
|
+
ftag/mock.py,sha256=syysvzLsBHU8aw7Uy5g3G4HB6LnSmHlGa2BfeXv5mQ4,5970
|
12
12
|
ftag/region.py,sha256=ANv0dGI2W6NJqD9fp7EfqAUReH4FOjc1gwl_Qn8llcM,360
|
13
13
|
ftag/sample.py,sha256=3N0FrRcu9l1sX8ohuGOHuMYGD0See6gMO4--7NzR2tE,2538
|
14
14
|
ftag/track_selector.py,sha256=fJNk_kIBQriBqV4CPT_3ReJbOUnavDDzO-u3EQlRuyk,2654
|
@@ -18,15 +18,15 @@ ftag/working_points.py,sha256=RJws2jPMEDQDspCbXUZBifS1CCBmlMJ5ax0eMyDzCRA,15949
|
|
18
18
|
ftag/hdf5/__init__.py,sha256=8yzVQITge-HKkBQQ60eJwWmWDycYZjgVs-qVg4ShVr0,385
|
19
19
|
ftag/hdf5/h5add_col.py,sha256=htS5wn4Tm4S3U6mrJ8s24VUnbI7o28Z6Ll-J_V68xTA,12558
|
20
20
|
ftag/hdf5/h5move.py,sha256=oYpRu0IDCIJIQ2ML52HBAdoyDxmKkHTeM9JdbPEgKfI,947
|
21
|
-
ftag/hdf5/h5reader.py,sha256=
|
21
|
+
ftag/hdf5/h5reader.py,sha256=NbHohY3RSicM3qnX_0Y1TfGAaDg3wgfEjYlGaWDJmug,14268
|
22
22
|
ftag/hdf5/h5split.py,sha256=4Wy6Xc3J58MdD9aBaSZHf5ZcVFnJSkWsm42R5Pgo-R4,2448
|
23
|
-
ftag/hdf5/h5utils.py,sha256
|
24
|
-
ftag/hdf5/h5writer.py,sha256=
|
23
|
+
ftag/hdf5/h5utils.py,sha256=EbCLOF_j1EBFwD95Z3QvlNpBpNkaZoqySZybhVja67U,3542
|
24
|
+
ftag/hdf5/h5writer.py,sha256=SMurvZ8FPvqieZUaYRX2SBu-jIyZ6Fx8IasUrEOxIvM,7185
|
25
25
|
ftag/utils/__init__.py,sha256=U3YyLY77-FzxRUbudxciieDoy_mnLlY3OfBquA3PnTE,524
|
26
26
|
ftag/utils/logging.py,sha256=54NaQiC9Bh4vSznSqzoPfR-7tj1PXfmoH7yKgv_ZHZk,3192
|
27
27
|
ftag/utils/metrics.py,sha256=zQI4nPeRDSyzqKpdOPmu0GU560xSWoW1wgL13rrja-I,12664
|
28
|
-
atlas_ftag_tools-0.2.
|
29
|
-
atlas_ftag_tools-0.2.
|
30
|
-
atlas_ftag_tools-0.2.
|
31
|
-
atlas_ftag_tools-0.2.
|
32
|
-
atlas_ftag_tools-0.2.
|
28
|
+
atlas_ftag_tools-0.2.12.dist-info/METADATA,sha256=bGfabVRARSL6PZTsDqen30IkQVqVGw8Tg9lMCnzY-5w,2152
|
29
|
+
atlas_ftag_tools-0.2.12.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
30
|
+
atlas_ftag_tools-0.2.12.dist-info/entry_points.txt,sha256=acr7WwxMIJ3x2I7AheNxNnpWE7sS8XE9MA1eUJGcU5A,169
|
31
|
+
atlas_ftag_tools-0.2.12.dist-info/top_level.txt,sha256=qiYQuKcAvMim-31FwkT3MTQu7WQm0s58tPAia5KKWqs,5
|
32
|
+
atlas_ftag_tools-0.2.12.dist-info/RECORD,,
|
ftag/__init__.py
CHANGED
ftag/flavours.yaml
CHANGED
@@ -113,7 +113,7 @@
|
|
113
113
|
category: xbb
|
114
114
|
- name: qcdnonbb
|
115
115
|
label: $\mathrm{QCD} \rightarrow \mathrm{non-} b \bar{b}$
|
116
|
-
cuts: ["R10TruthLabel_R22v1 == 10", "GhostBHadronsFinalCount
|
116
|
+
cuts: ["R10TruthLabel_R22v1 == 10", "GhostBHadronsFinalCount < 2"]
|
117
117
|
colour: "silver"
|
118
118
|
category: xbb
|
119
119
|
- name: qcdbx
|
ftag/hdf5/h5reader.py
CHANGED
@@ -74,10 +74,12 @@ class H5SingleReader:
|
|
74
74
|
num_jets: int | None = None,
|
75
75
|
cuts: Cuts | None = None,
|
76
76
|
start: int = 0,
|
77
|
+
skip_batches: int = 0,
|
77
78
|
) -> Generator:
|
78
79
|
if num_jets is None:
|
79
80
|
num_jets = self.num_jets
|
80
|
-
|
81
|
+
if skip_batches > 0:
|
82
|
+
assert not self.shuffle, "Cannot skip batches if shuffle is True"
|
81
83
|
if num_jets > self.num_jets:
|
82
84
|
log.warning(
|
83
85
|
f"{num_jets:,} jets requested but only {self.num_jets:,} available in {self.fname}."
|
@@ -97,7 +99,8 @@ class H5SingleReader:
|
|
97
99
|
indices = list(range(start, self.num_jets + start, self.batch_size))
|
98
100
|
if self.shuffle:
|
99
101
|
self.rng.shuffle(indices)
|
100
|
-
|
102
|
+
if skip_batches > 0:
|
103
|
+
indices = indices[skip_batches:]
|
101
104
|
# loop over batches and read file
|
102
105
|
for low in indices:
|
103
106
|
for name in variables:
|
@@ -176,7 +179,12 @@ class H5Reader:
|
|
176
179
|
|
177
180
|
# calculate batch sizes
|
178
181
|
if self.weights is None:
|
179
|
-
|
182
|
+
rows_per_file = [
|
183
|
+
H5SingleReader(f, jets_name=self.jets_name).num_jets for f in self.fname
|
184
|
+
]
|
185
|
+
num_total = sum(rows_per_file)
|
186
|
+
self.weights = [num / num_total for num in rows_per_file]
|
187
|
+
|
180
188
|
self.batch_sizes = [int(w * self.batch_size) for w in self.weights]
|
181
189
|
|
182
190
|
# create readers
|
@@ -233,6 +241,7 @@ class H5Reader:
|
|
233
241
|
num_jets: int | None = None,
|
234
242
|
cuts: Cuts | None = None,
|
235
243
|
start: int = 0,
|
244
|
+
skip_batches: int = 0,
|
236
245
|
) -> Generator:
|
237
246
|
"""Generate batches of selected jets.
|
238
247
|
|
@@ -246,6 +255,8 @@ class H5Reader:
|
|
246
255
|
Selection cuts to apply, by default None
|
247
256
|
start : int, optional
|
248
257
|
Starting index of the first jet to read, by default 0
|
258
|
+
skip_batches : int, optional
|
259
|
+
Number of batches to skip, by default 0
|
249
260
|
|
250
261
|
Yields
|
251
262
|
------
|
@@ -266,7 +277,9 @@ class H5Reader:
|
|
266
277
|
|
267
278
|
# get streams for selected jets from each reader
|
268
279
|
streams = [
|
269
|
-
r.stream(
|
280
|
+
r.stream(
|
281
|
+
variables, int(r.num_jets / self.num_jets * num_jets), cuts, start, skip_batches
|
282
|
+
)
|
270
283
|
for r in self.readers
|
271
284
|
]
|
272
285
|
|
ftag/hdf5/h5utils.py
CHANGED
@@ -13,6 +13,7 @@ def get_dtype(
|
|
13
13
|
variables: list[str] | None = None,
|
14
14
|
precision: str | None = None,
|
15
15
|
transform: Transform | None = None,
|
16
|
+
full_precision_vars: list[str] | None = None,
|
16
17
|
) -> np.dtype:
|
17
18
|
"""Return a dtype based on an existing dataset and requested variables.
|
18
19
|
|
@@ -26,6 +27,8 @@ def get_dtype(
|
|
26
27
|
Precision to cast floats to, "half" or "full", by default None
|
27
28
|
transform : Transform | None, optional
|
28
29
|
Transform to apply to variables names, by default None
|
30
|
+
full_precision_vars : list[str] | None, optional
|
31
|
+
List of variables to keep in full precision, by default None
|
29
32
|
|
30
33
|
Returns
|
31
34
|
-------
|
@@ -39,6 +42,8 @@ def get_dtype(
|
|
39
42
|
"""
|
40
43
|
if variables is None:
|
41
44
|
variables = ds.dtype.names
|
45
|
+
if full_precision_vars is None:
|
46
|
+
full_precision_vars = []
|
42
47
|
|
43
48
|
if (missing := set(variables) - set(ds.dtype.names)) and transform is not None:
|
44
49
|
variables = transform.map_variable_names(ds.name, variables, inverse=True)
|
@@ -50,7 +55,10 @@ def get_dtype(
|
|
50
55
|
|
51
56
|
dtype = [(n, x) for n, x in ds.dtype.descr if n in variables]
|
52
57
|
if precision:
|
53
|
-
dtype = [
|
58
|
+
dtype = [
|
59
|
+
(n, cast_dtype(x, precision)) if n not in full_precision_vars else (n, x)
|
60
|
+
for n, x in dtype
|
61
|
+
]
|
54
62
|
|
55
63
|
return np.dtype(dtype)
|
56
64
|
|
@@ -78,6 +86,7 @@ def cast_dtype(typestr: str, precision: str) -> np.dtype:
|
|
78
86
|
t = np.dtype(typestr)
|
79
87
|
if t.kind != "f":
|
80
88
|
return t
|
89
|
+
|
81
90
|
if precision == "half":
|
82
91
|
return np.dtype("f2")
|
83
92
|
if precision == "full":
|
ftag/hdf5/h5writer.py
CHANGED
@@ -47,18 +47,28 @@ class H5Writer:
|
|
47
47
|
precision: str = "full"
|
48
48
|
full_precision_vars: list[str] | None = None
|
49
49
|
shuffle: bool = True
|
50
|
+
num_jets: int | None = None # Allow dynamic mode by defaulting to None
|
50
51
|
|
51
52
|
def __post_init__(self):
|
52
53
|
self.num_written = 0
|
53
54
|
self.rng = np.random.default_rng(42)
|
54
|
-
|
55
|
-
|
56
|
-
|
55
|
+
|
56
|
+
# Infer number of jets from shapes if not explicitly passed
|
57
|
+
inferred_num_jets = [shape[0] for shape in self.shapes.values()]
|
58
|
+
if self.num_jets is None:
|
59
|
+
assert len(set(inferred_num_jets)) == 1, "Shapes must agree in first dimension"
|
60
|
+
self.fixed_mode = False
|
61
|
+
else:
|
62
|
+
self.fixed_mode = True
|
63
|
+
for name in self.shapes:
|
64
|
+
self.shapes[name] = (self.num_jets,) + self.shapes[name][1:]
|
57
65
|
|
58
66
|
if self.precision == "full":
|
59
67
|
self.fp_dtype = np.float32
|
60
68
|
elif self.precision == "half":
|
61
69
|
self.fp_dtype = np.float16
|
70
|
+
elif self.precision is None:
|
71
|
+
self.fp_dtype = None
|
62
72
|
else:
|
63
73
|
raise ValueError(f"Invalid precision: {self.precision}")
|
64
74
|
|
@@ -71,16 +81,34 @@ class H5Writer:
|
|
71
81
|
self.create_ds(name, dtype)
|
72
82
|
|
73
83
|
@classmethod
|
74
|
-
def from_file(
|
84
|
+
def from_file(
|
85
|
+
cls, source: Path, num_jets: int | None = 0, variables=None, **kwargs
|
86
|
+
) -> H5Writer:
|
75
87
|
with h5py.File(source, "r") as f:
|
76
88
|
dtypes = {name: ds.dtype for name, ds in f.items()}
|
77
89
|
shapes = {name: ds.shape for name, ds in f.items()}
|
78
|
-
|
90
|
+
|
91
|
+
if variables:
|
92
|
+
new_dtye = {}
|
93
|
+
new_shape = {}
|
94
|
+
for name, ds in f.items():
|
95
|
+
if name not in variables:
|
96
|
+
continue
|
97
|
+
new_dtye[name] = ftag.hdf5.get_dtype(
|
98
|
+
ds,
|
99
|
+
variables=variables[name],
|
100
|
+
precision=kwargs.get("precision"),
|
101
|
+
full_precision_vars=kwargs.get("full_precision_vars"),
|
102
|
+
)
|
103
|
+
new_shape[name] = ds.shape
|
104
|
+
dtypes = new_dtye
|
105
|
+
shapes = new_shape
|
106
|
+
if num_jets != 0:
|
79
107
|
shapes = {name: (num_jets,) + shape[1:] for name, shape in shapes.items()}
|
80
108
|
compression = [ds.compression for ds in f.values()]
|
81
109
|
assert len(set(compression)) == 1, "Must have same compression for all groups"
|
82
110
|
compression = compression[0]
|
83
|
-
if compression not in kwargs:
|
111
|
+
if "compression" not in kwargs:
|
84
112
|
kwargs["compression"] = compression
|
85
113
|
return cls(dtypes=dtypes, shapes=shapes, **kwargs)
|
86
114
|
|
@@ -88,36 +116,47 @@ class H5Writer:
|
|
88
116
|
if name == self.jets_name and self.add_flavour_label and "flavour_label" not in dtype.names:
|
89
117
|
dtype = np.dtype([*dtype.descr, ("flavour_label", "i4")])
|
90
118
|
|
91
|
-
|
92
|
-
|
93
|
-
#
|
119
|
+
fp_vars = self.full_precision_vars or []
|
120
|
+
# If no precision is defined, or the field is in full_precision_vars, or its non-float,
|
121
|
+
# keep it at the original dtype
|
94
122
|
dtype = np.dtype([
|
95
123
|
(
|
96
124
|
field,
|
97
|
-
|
98
|
-
|
99
|
-
|
125
|
+
(
|
126
|
+
self.fp_dtype
|
127
|
+
if (self.fp_dtype and field not in fp_vars and np.issubdtype(dt, np.floating))
|
128
|
+
else dt
|
129
|
+
),
|
100
130
|
)
|
101
131
|
for field, dt in dtype.descr
|
102
132
|
])
|
103
133
|
|
104
|
-
# optimal chunking is around 100 jets, only aply for track groups
|
105
134
|
shape = self.shapes[name]
|
106
135
|
chunks = (100,) + shape[1:] if shape[1:] else None
|
107
136
|
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
137
|
+
if self.fixed_mode:
|
138
|
+
self.file.create_dataset(
|
139
|
+
name, dtype=dtype, shape=shape, compression=self.compression, chunks=chunks
|
140
|
+
)
|
141
|
+
else:
|
142
|
+
maxshape = (None,) + shape[1:]
|
143
|
+
self.file.create_dataset(
|
144
|
+
name,
|
145
|
+
dtype=dtype,
|
146
|
+
shape=(0,) + shape[1:],
|
147
|
+
maxshape=maxshape,
|
148
|
+
compression=self.compression,
|
149
|
+
chunks=chunks,
|
150
|
+
)
|
112
151
|
|
113
152
|
def close(self) -> None:
|
114
|
-
|
115
|
-
written = len(
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
153
|
+
if self.fixed_mode:
|
154
|
+
written = len(self.file[self.jets_name])
|
155
|
+
if self.num_written != written:
|
156
|
+
raise ValueError(
|
157
|
+
f"Attempted to close file {self.dst} when only {self.num_written:,} out of"
|
158
|
+
f" {written:,} jets have been written"
|
159
|
+
)
|
121
160
|
self.file.close()
|
122
161
|
|
123
162
|
def get_attr(self, name, group=None):
|
@@ -137,18 +176,25 @@ class H5Writer:
|
|
137
176
|
for attr_name, value in ds.attrs.items():
|
138
177
|
self.add_attr(attr_name, value, group=name)
|
139
178
|
|
140
|
-
def write(self, data: dict[str, np.
|
141
|
-
|
142
|
-
|
143
|
-
f"Attempted to write more jets than expected: {total:,} > {self.num_jets:,}"
|
144
|
-
)
|
145
|
-
idx = np.arange(len(data[self.jets_name]))
|
179
|
+
def write(self, data: dict[str, np.ndarray]) -> None:
|
180
|
+
batch_size = len(data[self.jets_name])
|
181
|
+
idx = np.arange(batch_size)
|
146
182
|
if self.shuffle:
|
147
183
|
self.rng.shuffle(idx)
|
148
184
|
data = {name: array[idx] for name, array in data.items()}
|
149
185
|
|
150
186
|
low = self.num_written
|
151
|
-
high = low +
|
187
|
+
high = low + batch_size
|
188
|
+
|
189
|
+
if self.fixed_mode and high > self.num_jets:
|
190
|
+
raise ValueError(
|
191
|
+
f"Attempted to write more jets than expected: {high:,} > {self.num_jets:,}"
|
192
|
+
)
|
193
|
+
|
152
194
|
for group in self.dtypes:
|
153
|
-
self.file[group]
|
154
|
-
|
195
|
+
ds = self.file[group]
|
196
|
+
if not self.fixed_mode:
|
197
|
+
ds.resize((high,) + ds.shape[1:])
|
198
|
+
ds[low:high] = data[group]
|
199
|
+
|
200
|
+
self.num_written += batch_size
|
ftag/labeller.py
CHANGED
@@ -30,7 +30,7 @@ class Labeller:
|
|
30
30
|
def __post_init__(self) -> None:
|
31
31
|
if isinstance(self.labels, LabelContainer):
|
32
32
|
self.labels = list(self.labels)
|
33
|
-
self.labels =
|
33
|
+
self.labels = [Flavours[label] for label in self.labels]
|
34
34
|
|
35
35
|
@property
|
36
36
|
def variables(self) -> list[str]:
|
ftag/mock.py
CHANGED
@@ -106,11 +106,11 @@ def get_mock_scores(labels: np.ndarray, is_xbb: bool = False) -> np.ndarray:
|
|
106
106
|
for i in range(n_classes):
|
107
107
|
tmp_means = []
|
108
108
|
tmp_means = [
|
109
|
-
0 if j != i else mean_scale_list[
|
109
|
+
0 if j != i else mean_scale_list[rng.integers(0, len(mean_scale_list))]
|
110
110
|
for j in range(n_classes)
|
111
111
|
]
|
112
112
|
means.append(tmp_means)
|
113
|
-
scales.append(mean_scale_list[
|
113
|
+
scales.append(mean_scale_list[rng.integers(0, len(mean_scale_list))])
|
114
114
|
|
115
115
|
# Map the labels to the means
|
116
116
|
label_mapping = dict(zip(label_dict.values(), means))
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|