atlas-ftag-tools 0.2.13__py3-none-any.whl → 0.2.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: atlas-ftag-tools
3
- Version: 0.2.13
3
+ Version: 0.2.14
4
4
  Summary: ATLAS Flavour Tagging Tools
5
5
  Author: Sam Van Stroud, Philipp Gadow
6
6
  License: MIT
@@ -1,7 +1,7 @@
1
- atlas_ftag_tools-0.2.13.dist-info/licenses/LICENSE,sha256=R4o6bZfajQ1KxwcIeavTC00qYTdL33YGNe1hzfV53gM,11349
2
- ftag/__init__.py,sha256=UdYmO_mROM7jvqpPUMbnaQxdCrlR8O0KLlhatyMnapw,748
1
+ atlas_ftag_tools-0.2.14.dist-info/licenses/LICENSE,sha256=R4o6bZfajQ1KxwcIeavTC00qYTdL33YGNe1hzfV53gM,11349
2
+ ftag/__init__.py,sha256=N224CzevGVd0J9m-ak4UuIuCUXc5Cf6Op9O993TqX80,748
3
3
  ftag/cli_utils.py,sha256=w3TtQmUHSyAKChS3ewvOtcSDAUJAZGIIomaNi8f446U,298
4
- ftag/cuts.py,sha256=9_ooLZHaO3SnIQBNxwbaPZn-qptGdKnB27FdKQGTiTY,2933
4
+ ftag/cuts.py,sha256=r1g0vHJtafwEsCPlss685v9a5YDCSSMf-AXsqHrlxV0,3511
5
5
  ftag/flavours.py,sha256=ShH4M2UjQZpZ_NlCctTm2q1tJbzYxjmGteioQ2GcqEU,114
6
6
  ftag/flavours.yaml,sha256=b86gXX_FMIewLK7_pr0bNgz7RJ84fDf9nAvmjB4J-Ks,9920
7
7
  ftag/fraction_optimization.py,sha256=IlMEJe5fD0soX40f-LO4dYAYld2gMqgZRuBLctoPn9A,5566
@@ -18,15 +18,15 @@ ftag/working_points.py,sha256=RJws2jPMEDQDspCbXUZBifS1CCBmlMJ5ax0eMyDzCRA,15949
18
18
  ftag/hdf5/__init__.py,sha256=8yzVQITge-HKkBQQ60eJwWmWDycYZjgVs-qVg4ShVr0,385
19
19
  ftag/hdf5/h5add_col.py,sha256=htS5wn4Tm4S3U6mrJ8s24VUnbI7o28Z6Ll-J_V68xTA,12558
20
20
  ftag/hdf5/h5move.py,sha256=oYpRu0IDCIJIQ2ML52HBAdoyDxmKkHTeM9JdbPEgKfI,947
21
- ftag/hdf5/h5reader.py,sha256=NbHohY3RSicM3qnX_0Y1TfGAaDg3wgfEjYlGaWDJmug,14268
21
+ ftag/hdf5/h5reader.py,sha256=kDZykPPGS2ABixp3rBPhG-6TTRTN3VKKBKzFV_sv-Eg,18542
22
22
  ftag/hdf5/h5split.py,sha256=4Wy6Xc3J58MdD9aBaSZHf5ZcVFnJSkWsm42R5Pgo-R4,2448
23
23
  ftag/hdf5/h5utils.py,sha256=EbCLOF_j1EBFwD95Z3QvlNpBpNkaZoqySZybhVja67U,3542
24
24
  ftag/hdf5/h5writer.py,sha256=SMurvZ8FPvqieZUaYRX2SBu-jIyZ6Fx8IasUrEOxIvM,7185
25
25
  ftag/utils/__init__.py,sha256=U3YyLY77-FzxRUbudxciieDoy_mnLlY3OfBquA3PnTE,524
26
26
  ftag/utils/logging.py,sha256=54NaQiC9Bh4vSznSqzoPfR-7tj1PXfmoH7yKgv_ZHZk,3192
27
27
  ftag/utils/metrics.py,sha256=zQI4nPeRDSyzqKpdOPmu0GU560xSWoW1wgL13rrja-I,12664
28
- atlas_ftag_tools-0.2.13.dist-info/METADATA,sha256=ZpQ5GggkLyizsv9uHEOvIlzRqPmC-4tNaoaMgV6unF4,2153
29
- atlas_ftag_tools-0.2.13.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
30
- atlas_ftag_tools-0.2.13.dist-info/entry_points.txt,sha256=acr7WwxMIJ3x2I7AheNxNnpWE7sS8XE9MA1eUJGcU5A,169
31
- atlas_ftag_tools-0.2.13.dist-info/top_level.txt,sha256=qiYQuKcAvMim-31FwkT3MTQu7WQm0s58tPAia5KKWqs,5
32
- atlas_ftag_tools-0.2.13.dist-info/RECORD,,
28
+ atlas_ftag_tools-0.2.14.dist-info/METADATA,sha256=ZUQJ5F0_oCFY-MtQ3-jChGevsFcIui11nQYJxF0tWjw,2153
29
+ atlas_ftag_tools-0.2.14.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
30
+ atlas_ftag_tools-0.2.14.dist-info/entry_points.txt,sha256=acr7WwxMIJ3x2I7AheNxNnpWE7sS8XE9MA1eUJGcU5A,169
31
+ atlas_ftag_tools-0.2.14.dist-info/top_level.txt,sha256=qiYQuKcAvMim-31FwkT3MTQu7WQm0s58tPAia5KKWqs,5
32
+ atlas_ftag_tools-0.2.14.dist-info/RECORD,,
ftag/__init__.py CHANGED
@@ -2,7 +2,7 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- __version__ = "v0.2.13"
5
+ __version__ = "v0.2.14"
6
6
 
7
7
  from . import hdf5, utils
8
8
  from .cuts import Cuts
ftag/cuts.py CHANGED
@@ -42,11 +42,28 @@ class Cut:
42
42
  @property
43
43
  def value(self) -> int | float:
44
44
  if isinstance(self._value, str):
45
+ txt = self._value.strip().lower()
46
+ if txt == "nan":
47
+ return float("nan")
48
+ if txt in {"+inf", "inf"}:
49
+ return float("inf")
50
+ if txt == "-inf":
51
+ return float("-inf")
45
52
  return literal_eval(self._value)
46
53
  return self._value
47
54
 
48
55
  def __call__(self, array):
49
- return OPERATORS[self.operator](array[self.variable], self.value)
56
+ col = array[self.variable]
57
+ val = self.value
58
+
59
+ if isinstance(val, float) and np.isnan(val):
60
+ if self.operator == "!=":
61
+ return ~np.isnan(col)
62
+ if self.operator == "==":
63
+ return np.isnan(col)
64
+ raise ValueError("'nan' only makes sense with '==' or '!=' operators")
65
+
66
+ return OPERATORS[self.operator](col, val)
50
67
 
51
68
  def __str__(self) -> str:
52
69
  return f"{self.variable} {self.operator} {self.value}"
ftag/hdf5/h5reader.py CHANGED
@@ -68,6 +68,37 @@ class H5SingleReader:
68
68
  )
69
69
  return {name: array[keep_idx] for name, array in data.items()}
70
70
 
71
+ def _process_batch(self, data: dict, cuts: Cuts | None = None) -> dict:
72
+ """Apply cuts and transformations to the batch.
73
+
74
+ Parameters
75
+ ----------
76
+ data : dict
77
+ Dictionary of arrays for each group.
78
+ cuts : Cuts | None, optional
79
+ Selection cuts to apply, by default None
80
+
81
+ Returns
82
+ -------
83
+ dict
84
+ Processed data dictionary with arrays for each group. After applying cuts,
85
+ (optional) removal of infs, and (optional) transformation.
86
+ """
87
+ # apply selections
88
+ if cuts:
89
+ idx = cuts(data[self.jets_name]).idx
90
+ data = {name: array[idx] for name, array in data.items()}
91
+
92
+ # check for inf and remove
93
+ if self.do_remove_inf:
94
+ data = self.remove_inf(data)
95
+
96
+ # apply transform
97
+ if self.transform:
98
+ data = self.transform(data)
99
+
100
+ return data
101
+
71
102
  def stream(
72
103
  self,
73
104
  variables: dict | None = None,
@@ -106,18 +137,8 @@ class H5SingleReader:
106
137
  for name in variables:
107
138
  data[name] = self.read_chunk(f[name], arrays[name], low)
108
139
 
109
- # apply selections
110
- if cuts:
111
- idx = cuts(data[self.jets_name]).idx
112
- data = {name: array[idx] for name, array in data.items()}
113
-
114
- # check for inf and remove
115
- if self.do_remove_inf:
116
- data = self.remove_inf(data)
117
-
118
- # apply transform
119
- if self.transform:
120
- data = self.transform(data)
140
+ # Apply cuts and transformations
141
+ data = self._process_batch(data, cuts)
121
142
 
122
143
  # check for completion
123
144
  total += len(data[self.jets_name])
@@ -129,6 +150,58 @@ class H5SingleReader:
129
150
 
130
151
  yield data
131
152
 
153
+ def get_batch_reader(
154
+ self,
155
+ variables: dict | None = None,
156
+ cuts: Cuts | None = None,
157
+ ):
158
+ """Get a function to read batches of selected jets.
159
+
160
+ Parameters
161
+ ----------
162
+ variables : dict | None, optional
163
+ Dictionary of variables to for each group, by default use all jet variables.
164
+ cuts : Cuts | None, optional
165
+ Selection cuts to apply, by default None
166
+
167
+ Returns
168
+ -------
169
+ function
170
+ Function that takes an index and returns a batch of selected jets.
171
+ """
172
+ if variables is None:
173
+ variables = {self.jets_name: None}
174
+ h5 = h5py.File(self.fname, "r")
175
+ arrays = {name: self.empty(h5[name], var) for name, var in variables.items()}
176
+ # nonlocal data
177
+ data = {name: self.empty(h5[name], var) for name, var in variables.items()}
178
+
179
+ def get_batch(idx: int) -> dict | None:
180
+ """Get a batch of data from the HDF5 file.
181
+
182
+ Parameters
183
+ ----------
184
+ idx : int
185
+ Index of the batch to read.
186
+
187
+ Returns
188
+ -------
189
+ dict | None
190
+ Dictionary of arrays for each group, or None if no more batches are available.
191
+ """
192
+ low = idx * self.batch_size
193
+ if low >= self.num_jets:
194
+ return None
195
+
196
+ for name in variables:
197
+ data[name] = self.read_chunk(h5[name], arrays[name], low)
198
+
199
+ data_out = {name: array.copy() for name, array in data.items()}
200
+
201
+ return self._process_batch(data_out, cuts)
202
+
203
+ return get_batch
204
+
132
205
 
133
206
  @dataclass
134
207
  class H5Reader:
@@ -315,6 +388,62 @@ class H5Reader:
315
388
  # yield batch
316
389
  yield data
317
390
 
391
+ def get_batch_reader(
392
+ self, variables: dict | None = None, cuts: Cuts | None = None, shuffle: bool = True
393
+ ):
394
+ """Get a function to read batches of selected jets.
395
+
396
+ Parameters
397
+ ----------
398
+ variables : dict | None, optional
399
+ Dictionary of variables to for each group, by default use all jet variables.
400
+ cuts : Cuts | None, optional
401
+ Selection cuts to apply, by default None
402
+ shuffle : bool, optional
403
+ Read batches in a shuffled order, by default True
404
+
405
+ Returns
406
+ -------
407
+ function
408
+ Function that takes an index and returns a batch of selected jets.
409
+ """
410
+ if variables is None:
411
+ variables = {self.jets_name: None}
412
+
413
+ # create batch readers for each sample
414
+ batch_readers = [r.get_batch_reader(variables, cuts) for r in self.readers]
415
+
416
+ def get_batch(idx: int) -> dict | None:
417
+ """Get a batch of data from the HDF5 files.
418
+
419
+ Parameters
420
+ ----------
421
+ idx : int
422
+ Index of the batch to read.
423
+
424
+ Returns
425
+ -------
426
+ dict | None
427
+ Dictionary of arrays for each group, or None if no more batches are available.
428
+ """
429
+ assert idx >= 0, "Index must be non-negative"
430
+ if idx * self.batch_size >= self.num_jets:
431
+ return None
432
+ # get a batch from each sample
433
+ samples = [br(idx) for br in batch_readers]
434
+ samples = [s for s in samples if s is not None]
435
+ if len(samples) == 0:
436
+ return None
437
+ # combine samples and shuffle
438
+ data = {name: np.concatenate([s[name] for s in samples]) for name in variables}
439
+ if shuffle:
440
+ idx = np.arange(len(data[self.jets_name]))
441
+ self.rng.shuffle(idx)
442
+ data = {name: array[idx] for name, array in data.items()}
443
+ return data
444
+
445
+ return get_batch
446
+
318
447
  def load(
319
448
  self, variables: dict | None = None, num_jets: int | None = None, cuts: Cuts | None = None
320
449
  ) -> dict: