atlas-ftag-tools 0.2.13__py3-none-any.whl → 0.2.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {atlas_ftag_tools-0.2.13.dist-info → atlas_ftag_tools-0.2.14.dist-info}/METADATA +1 -1
- {atlas_ftag_tools-0.2.13.dist-info → atlas_ftag_tools-0.2.14.dist-info}/RECORD +9 -9
- ftag/__init__.py +1 -1
- ftag/cuts.py +18 -1
- ftag/hdf5/h5reader.py +141 -12
- {atlas_ftag_tools-0.2.13.dist-info → atlas_ftag_tools-0.2.14.dist-info}/WHEEL +0 -0
- {atlas_ftag_tools-0.2.13.dist-info → atlas_ftag_tools-0.2.14.dist-info}/entry_points.txt +0 -0
- {atlas_ftag_tools-0.2.13.dist-info → atlas_ftag_tools-0.2.14.dist-info}/licenses/LICENSE +0 -0
- {atlas_ftag_tools-0.2.13.dist-info → atlas_ftag_tools-0.2.14.dist-info}/top_level.txt +0 -0
@@ -1,7 +1,7 @@
|
|
1
|
-
atlas_ftag_tools-0.2.
|
2
|
-
ftag/__init__.py,sha256=
|
1
|
+
atlas_ftag_tools-0.2.14.dist-info/licenses/LICENSE,sha256=R4o6bZfajQ1KxwcIeavTC00qYTdL33YGNe1hzfV53gM,11349
|
2
|
+
ftag/__init__.py,sha256=N224CzevGVd0J9m-ak4UuIuCUXc5Cf6Op9O993TqX80,748
|
3
3
|
ftag/cli_utils.py,sha256=w3TtQmUHSyAKChS3ewvOtcSDAUJAZGIIomaNi8f446U,298
|
4
|
-
ftag/cuts.py,sha256=
|
4
|
+
ftag/cuts.py,sha256=r1g0vHJtafwEsCPlss685v9a5YDCSSMf-AXsqHrlxV0,3511
|
5
5
|
ftag/flavours.py,sha256=ShH4M2UjQZpZ_NlCctTm2q1tJbzYxjmGteioQ2GcqEU,114
|
6
6
|
ftag/flavours.yaml,sha256=b86gXX_FMIewLK7_pr0bNgz7RJ84fDf9nAvmjB4J-Ks,9920
|
7
7
|
ftag/fraction_optimization.py,sha256=IlMEJe5fD0soX40f-LO4dYAYld2gMqgZRuBLctoPn9A,5566
|
@@ -18,15 +18,15 @@ ftag/working_points.py,sha256=RJws2jPMEDQDspCbXUZBifS1CCBmlMJ5ax0eMyDzCRA,15949
|
|
18
18
|
ftag/hdf5/__init__.py,sha256=8yzVQITge-HKkBQQ60eJwWmWDycYZjgVs-qVg4ShVr0,385
|
19
19
|
ftag/hdf5/h5add_col.py,sha256=htS5wn4Tm4S3U6mrJ8s24VUnbI7o28Z6Ll-J_V68xTA,12558
|
20
20
|
ftag/hdf5/h5move.py,sha256=oYpRu0IDCIJIQ2ML52HBAdoyDxmKkHTeM9JdbPEgKfI,947
|
21
|
-
ftag/hdf5/h5reader.py,sha256=
|
21
|
+
ftag/hdf5/h5reader.py,sha256=kDZykPPGS2ABixp3rBPhG-6TTRTN3VKKBKzFV_sv-Eg,18542
|
22
22
|
ftag/hdf5/h5split.py,sha256=4Wy6Xc3J58MdD9aBaSZHf5ZcVFnJSkWsm42R5Pgo-R4,2448
|
23
23
|
ftag/hdf5/h5utils.py,sha256=EbCLOF_j1EBFwD95Z3QvlNpBpNkaZoqySZybhVja67U,3542
|
24
24
|
ftag/hdf5/h5writer.py,sha256=SMurvZ8FPvqieZUaYRX2SBu-jIyZ6Fx8IasUrEOxIvM,7185
|
25
25
|
ftag/utils/__init__.py,sha256=U3YyLY77-FzxRUbudxciieDoy_mnLlY3OfBquA3PnTE,524
|
26
26
|
ftag/utils/logging.py,sha256=54NaQiC9Bh4vSznSqzoPfR-7tj1PXfmoH7yKgv_ZHZk,3192
|
27
27
|
ftag/utils/metrics.py,sha256=zQI4nPeRDSyzqKpdOPmu0GU560xSWoW1wgL13rrja-I,12664
|
28
|
-
atlas_ftag_tools-0.2.
|
29
|
-
atlas_ftag_tools-0.2.
|
30
|
-
atlas_ftag_tools-0.2.
|
31
|
-
atlas_ftag_tools-0.2.
|
32
|
-
atlas_ftag_tools-0.2.
|
28
|
+
atlas_ftag_tools-0.2.14.dist-info/METADATA,sha256=ZUQJ5F0_oCFY-MtQ3-jChGevsFcIui11nQYJxF0tWjw,2153
|
29
|
+
atlas_ftag_tools-0.2.14.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
30
|
+
atlas_ftag_tools-0.2.14.dist-info/entry_points.txt,sha256=acr7WwxMIJ3x2I7AheNxNnpWE7sS8XE9MA1eUJGcU5A,169
|
31
|
+
atlas_ftag_tools-0.2.14.dist-info/top_level.txt,sha256=qiYQuKcAvMim-31FwkT3MTQu7WQm0s58tPAia5KKWqs,5
|
32
|
+
atlas_ftag_tools-0.2.14.dist-info/RECORD,,
|
ftag/__init__.py
CHANGED
ftag/cuts.py
CHANGED
@@ -42,11 +42,28 @@ class Cut:
|
|
42
42
|
@property
|
43
43
|
def value(self) -> int | float:
|
44
44
|
if isinstance(self._value, str):
|
45
|
+
txt = self._value.strip().lower()
|
46
|
+
if txt == "nan":
|
47
|
+
return float("nan")
|
48
|
+
if txt in {"+inf", "inf"}:
|
49
|
+
return float("inf")
|
50
|
+
if txt == "-inf":
|
51
|
+
return float("-inf")
|
45
52
|
return literal_eval(self._value)
|
46
53
|
return self._value
|
47
54
|
|
48
55
|
def __call__(self, array):
|
49
|
-
return OPERATORS[self.operator](array[self.variable], self.value)
56
|
+
col = array[self.variable]
|
57
|
+
val = self.value
|
58
|
+
|
59
|
+
if isinstance(val, float) and np.isnan(val):
|
60
|
+
if self.operator == "!=":
|
61
|
+
return ~np.isnan(col)
|
62
|
+
if self.operator == "==":
|
63
|
+
return np.isnan(col)
|
64
|
+
raise ValueError("'nan' only makes sense with '==' or '!=' operators")
|
65
|
+
|
66
|
+
return OPERATORS[self.operator](col, val)
|
50
67
|
|
51
68
|
def __str__(self) -> str:
|
52
69
|
return f"{self.variable} {self.operator} {self.value}"
|
ftag/hdf5/h5reader.py
CHANGED
@@ -68,6 +68,37 @@ class H5SingleReader:
|
|
68
68
|
)
|
69
69
|
return {name: array[keep_idx] for name, array in data.items()}
|
70
70
|
|
71
|
+
def _process_batch(self, data: dict, cuts: Cuts | None = None) -> dict:
|
72
|
+
"""Apply cuts and transformations to the batch.
|
73
|
+
|
74
|
+
Parameters
|
75
|
+
----------
|
76
|
+
data : dict
|
77
|
+
Dictionary of arrays for each group.
|
78
|
+
cuts : Cuts | None, optional
|
79
|
+
Selection cuts to apply, by default None
|
80
|
+
|
81
|
+
Returns
|
82
|
+
-------
|
83
|
+
dict
|
84
|
+
Processed data dictionary with arrays for each group. After applying cuts,
|
85
|
+
(optional) removal of infs, and (optional) transformation.
|
86
|
+
"""
|
87
|
+
# apply selections
|
88
|
+
if cuts:
|
89
|
+
idx = cuts(data[self.jets_name]).idx
|
90
|
+
data = {name: array[idx] for name, array in data.items()}
|
91
|
+
|
92
|
+
# check for inf and remove
|
93
|
+
if self.do_remove_inf:
|
94
|
+
data = self.remove_inf(data)
|
95
|
+
|
96
|
+
# apply transform
|
97
|
+
if self.transform:
|
98
|
+
data = self.transform(data)
|
99
|
+
|
100
|
+
return data
|
101
|
+
|
71
102
|
def stream(
|
72
103
|
self,
|
73
104
|
variables: dict | None = None,
|
@@ -106,18 +137,8 @@ class H5SingleReader:
|
|
106
137
|
for name in variables:
|
107
138
|
data[name] = self.read_chunk(f[name], arrays[name], low)
|
108
139
|
|
109
|
-
# apply selections
|
110
|
-
if cuts:
111
|
-
idx = cuts(data[self.jets_name]).idx
|
112
|
-
data = {name: array[idx] for name, array in data.items()}
|
113
|
-
|
114
|
-
# check for inf and remove
|
115
|
-
if self.do_remove_inf:
|
116
|
-
data = self.remove_inf(data)
|
117
|
-
|
118
|
-
# apply transform
|
119
|
-
if self.transform:
|
120
|
-
data = self.transform(data)
|
140
|
+
# Apply cuts and transformations
|
141
|
+
data = self._process_batch(data, cuts)
|
121
142
|
|
122
143
|
# check for completion
|
123
144
|
total += len(data[self.jets_name])
|
@@ -129,6 +150,58 @@ class H5SingleReader:
|
|
129
150
|
|
130
151
|
yield data
|
131
152
|
|
153
|
+
def get_batch_reader(
|
154
|
+
self,
|
155
|
+
variables: dict | None = None,
|
156
|
+
cuts: Cuts | None = None,
|
157
|
+
):
|
158
|
+
"""Get a function to read batches of selected jets.
|
159
|
+
|
160
|
+
Parameters
|
161
|
+
----------
|
162
|
+
variables : dict | None, optional
|
163
|
+
Dictionary of variables to for each group, by default use all jet variables.
|
164
|
+
cuts : Cuts | None, optional
|
165
|
+
Selection cuts to apply, by default None
|
166
|
+
|
167
|
+
Returns
|
168
|
+
-------
|
169
|
+
function
|
170
|
+
Function that takes an index and returns a batch of selected jets.
|
171
|
+
"""
|
172
|
+
if variables is None:
|
173
|
+
variables = {self.jets_name: None}
|
174
|
+
h5 = h5py.File(self.fname, "r")
|
175
|
+
arrays = {name: self.empty(h5[name], var) for name, var in variables.items()}
|
176
|
+
# nonlocal data
|
177
|
+
data = {name: self.empty(h5[name], var) for name, var in variables.items()}
|
178
|
+
|
179
|
+
def get_batch(idx: int) -> dict | None:
|
180
|
+
"""Get a batch of data from the HDF5 file.
|
181
|
+
|
182
|
+
Parameters
|
183
|
+
----------
|
184
|
+
idx : int
|
185
|
+
Index of the batch to read.
|
186
|
+
|
187
|
+
Returns
|
188
|
+
-------
|
189
|
+
dict | None
|
190
|
+
Dictionary of arrays for each group, or None if no more batches are available.
|
191
|
+
"""
|
192
|
+
low = idx * self.batch_size
|
193
|
+
if low >= self.num_jets:
|
194
|
+
return None
|
195
|
+
|
196
|
+
for name in variables:
|
197
|
+
data[name] = self.read_chunk(h5[name], arrays[name], low)
|
198
|
+
|
199
|
+
data_out = {name: array.copy() for name, array in data.items()}
|
200
|
+
|
201
|
+
return self._process_batch(data_out, cuts)
|
202
|
+
|
203
|
+
return get_batch
|
204
|
+
|
132
205
|
|
133
206
|
@dataclass
|
134
207
|
class H5Reader:
|
@@ -315,6 +388,62 @@ class H5Reader:
|
|
315
388
|
# yield batch
|
316
389
|
yield data
|
317
390
|
|
391
|
+
def get_batch_reader(
|
392
|
+
self, variables: dict | None = None, cuts: Cuts | None = None, shuffle: bool = True
|
393
|
+
):
|
394
|
+
"""Get a function to read batches of selected jets.
|
395
|
+
|
396
|
+
Parameters
|
397
|
+
----------
|
398
|
+
variables : dict | None, optional
|
399
|
+
Dictionary of variables to for each group, by default use all jet variables.
|
400
|
+
cuts : Cuts | None, optional
|
401
|
+
Selection cuts to apply, by default None
|
402
|
+
shuffle : bool, optional
|
403
|
+
Read batches in a shuffled order, by default True
|
404
|
+
|
405
|
+
Returns
|
406
|
+
-------
|
407
|
+
function
|
408
|
+
Function that takes an index and returns a batch of selected jets.
|
409
|
+
"""
|
410
|
+
if variables is None:
|
411
|
+
variables = {self.jets_name: None}
|
412
|
+
|
413
|
+
# create batch readers for each sample
|
414
|
+
batch_readers = [r.get_batch_reader(variables, cuts) for r in self.readers]
|
415
|
+
|
416
|
+
def get_batch(idx: int) -> dict | None:
|
417
|
+
"""Get a batch of data from the HDF5 files.
|
418
|
+
|
419
|
+
Parameters
|
420
|
+
----------
|
421
|
+
idx : int
|
422
|
+
Index of the batch to read.
|
423
|
+
|
424
|
+
Returns
|
425
|
+
-------
|
426
|
+
dict | None
|
427
|
+
Dictionary of arrays for each group, or None if no more batches are available.
|
428
|
+
"""
|
429
|
+
assert idx >= 0, "Index must be non-negative"
|
430
|
+
if idx * self.batch_size >= self.num_jets:
|
431
|
+
return None
|
432
|
+
# get a batch from each sample
|
433
|
+
samples = [br(idx) for br in batch_readers]
|
434
|
+
samples = [s for s in samples if s is not None]
|
435
|
+
if len(samples) == 0:
|
436
|
+
return None
|
437
|
+
# combine samples and shuffle
|
438
|
+
data = {name: np.concatenate([s[name] for s in samples]) for name in variables}
|
439
|
+
if shuffle:
|
440
|
+
idx = np.arange(len(data[self.jets_name]))
|
441
|
+
self.rng.shuffle(idx)
|
442
|
+
data = {name: array[idx] for name, array in data.items()}
|
443
|
+
return data
|
444
|
+
|
445
|
+
return get_batch
|
446
|
+
|
318
447
|
def load(
|
319
448
|
self, variables: dict | None = None, num_jets: int | None = None, cuts: Cuts | None = None
|
320
449
|
) -> dict:
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|