atlas-ftag-tools 0.1.4__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: atlas-ftag-tools
3
- Version: 0.1.4
3
+ Version: 0.1.5
4
4
  Summary: ATLAS Flavour Tagging Tools
5
5
  Author: Sam Van Stroud, Philipp Gadow
6
6
  License: MIT
@@ -1,19 +1,19 @@
1
- ftag/__init__.py,sha256=yf92K1TFG1_KK30N_6FgdGMh-arYNnl4YKQEXPmrJOk,543
1
+ ftag/__init__.py,sha256=XBQEZpFSnGyihB9F3eGOvB_5YknggY_L6fzwYszXLuQ,543
2
2
  ftag/cuts.py,sha256=lCnyHd4kbrt3CMXGE1ASCgaa07o1qOBn6GQek6lClVQ,2734
3
3
  ftag/flavour.py,sha256=sEelvHNLWmHsecQQrmRc8ktwykMMHnGX8ePDRrqQkuo,2460
4
- ftag/flavours.yaml,sha256=woPpF8hDycjv_McKbHVqQQE072_P50f9KVNNckEbFKA,3245
4
+ ftag/flavours.yaml,sha256=VrOGD5FUhMVPIW31whY-nSqNv98AcnLsPmPGmAcCg3w,3287
5
5
  ftag/mock.py,sha256=HUyYOPsRtkmzjLRNF2zs0kpVUrTRIHTsnIyDlXIZArU,3627
6
6
  ftag/region.py,sha256=-WxdC0Gy9zz3zEJ2pN779RcxXPG-QEROuMwMoP-Qs0g,353
7
7
  ftag/sample.py,sha256=uVNyxFYMMtkP-o2tjQatpo8mIH4ZNNe3mSFEPebYh_E,2622
8
8
  ftag/vds.py,sha256=8b5-zqDELUmxdO5Txdowe3v7XGS1pKgO20bhzUQqCxU,2945
9
9
  ftag/hdf5/__init__.py,sha256=A_a_4IUlZ2mSiDcfrZKBdja_3iTrUHvADM2lWx6g66g,325
10
- ftag/hdf5/h5reader.py,sha256=PlLv3VkGGywAbo8dpbLdwnXW2NTTHTlepFfG8nE00J8,8723
10
+ ftag/hdf5/h5reader.py,sha256=1_iyYfWI1ht1-p9vBBpGhw47ZKola_KhWxbrywoB-Jg,11751
11
11
  ftag/hdf5/h5utils.py,sha256=GKduv9b6JRSBirRdmNgGcmsINCMTj54kH4RQqxrM1t8,2363
12
12
  ftag/hdf5/h5writer.py,sha256=_N-DJSX283r-XsGczvLFA4_qaK4BkFkdKZAusHEvRjU,2919
13
13
  ftag/wps/discriminant.py,sha256=86ISONTuIjqTJO1A27oqkoCgDjAQinofiYNdcjfdkIk,1380
14
14
  ftag/wps/working_points.py,sha256=487NsQGGY2Qt4q8mXxKABMFa-YLsbrhkPLcYVdebeVk,4950
15
- atlas_ftag_tools-0.1.4.dist-info/METADATA,sha256=DiUJuY2MIGmxugt723jC3_FmWXW61wVSS_jRyic3K1w,4182
16
- atlas_ftag_tools-0.1.4.dist-info/WHEEL,sha256=pkctZYzUS4AYVn6dJ-7367OJZivF2e8RA9b_ZBjif18,92
17
- atlas_ftag_tools-0.1.4.dist-info/entry_points.txt,sha256=UKbRbwA9DxfsTPRBIVVDz3u15WdzhzgRKwXXSAXuQqc,73
18
- atlas_ftag_tools-0.1.4.dist-info/top_level.txt,sha256=qiYQuKcAvMim-31FwkT3MTQu7WQm0s58tPAia5KKWqs,5
19
- atlas_ftag_tools-0.1.4.dist-info/RECORD,,
15
+ atlas_ftag_tools-0.1.5.dist-info/METADATA,sha256=Uc4Z2zAMD7jsSKoV6o2LJwfm2X0KEWYRingGT_msE4I,4182
16
+ atlas_ftag_tools-0.1.5.dist-info/WHEEL,sha256=pkctZYzUS4AYVn6dJ-7367OJZivF2e8RA9b_ZBjif18,92
17
+ atlas_ftag_tools-0.1.5.dist-info/entry_points.txt,sha256=UKbRbwA9DxfsTPRBIVVDz3u15WdzhzgRKwXXSAXuQqc,73
18
+ atlas_ftag_tools-0.1.5.dist-info/top_level.txt,sha256=qiYQuKcAvMim-31FwkT3MTQu7WQm0s58tPAia5KKWqs,5
19
+ atlas_ftag_tools-0.1.5.dist-info/RECORD,,
ftag/__init__.py CHANGED
@@ -1,7 +1,7 @@
1
1
  """atlas-ftag-tools - Common tools for ATLAS flavour tagging software."""
2
2
 
3
3
 
4
- __version__ = "v0.1.4"
4
+ __version__ = "v0.1.5"
5
5
 
6
6
 
7
7
  import ftag.hdf5 as hdf5
ftag/flavours.yaml CHANGED
@@ -49,12 +49,12 @@
49
49
 
50
50
  # Xbb tagging
51
51
  - name: hbb
52
- label: Hbb
52
+ label: $H \rightarrow b\bar{b}$
53
53
  cuts: ["R10TruthLabel_R22v1 == 11"]
54
54
  colour: tab:blue
55
55
  category: xbb
56
56
  - name: hcc
57
- label: Hcc
57
+ label: $H \rightarrow c\bar{c}$
58
58
  cuts: ["R10TruthLabel_R22v1 == 12"]
59
59
  colour: "#B45F06"
60
60
  category: xbb
ftag/hdf5/h5reader.py CHANGED
@@ -26,9 +26,10 @@ class H5SingleReader:
26
26
 
27
27
  def __post_init__(self) -> None:
28
28
  self.sample = Sample(self.fname)
29
- if len(self.sample.virtual_file()) != 1:
29
+ fname = self.sample.virtual_file()
30
+ if len(fname) != 1:
30
31
  raise ValueError("H5SingleReader should only read a single file")
31
- self.fname = self.sample.virtual_file()[0]
32
+ self.fname = fname[0]
32
33
 
33
34
  @cached_property
34
35
  def num_jets(self) -> int:
@@ -57,21 +58,27 @@ class H5SingleReader:
57
58
  isinf = np.isinf(array[var])
58
59
  keep_idx = keep_idx & ~isinf.any(axis=-1)
59
60
  if num_inf := isinf.sum():
60
- log.warn(
61
+ log.warning(
61
62
  f"{num_inf} inf values detected for variable {var} in"
62
63
  f" {name} array. Removing the affected jets."
63
64
  )
64
65
  return {name: array[keep_idx] for name, array in data.items()}
65
66
 
66
67
  def stream(
67
- self, variables: dict | None = None, num_jets: int | None = None, cuts: Cuts | None = None
68
+ self,
69
+ variables: dict | None = None,
70
+ num_jets: int | None = None,
71
+ cuts: Cuts | None = None,
68
72
  ) -> Generator:
69
73
  if num_jets is None:
70
74
  num_jets = self.num_jets
75
+
71
76
  if num_jets > self.num_jets:
72
- raise ValueError(
73
- f"{num_jets:,} jets requested but only {self.num_jets:,} available in {self.fname}"
77
+ log.warning(
78
+ f"{num_jets:,} jets requested but only {self.num_jets:,} available in {self.fname}."
79
+ " Set to maximum available number!"
74
80
  )
81
+ num_jets = self.num_jets
75
82
 
76
83
  if variables is None:
77
84
  variables = {self.jets_name: None}
@@ -131,6 +138,9 @@ class H5Reader:
131
138
  Weights for different input datasets, by default None
132
139
  do_remove_inf : bool, optional
133
140
  Remove jets with inf values, by default False
141
+ equal_jets : bool, optional
142
+ Take the same number of jets (weighted) from each sample, by default True
143
+ If False, use all jets in each sample.
134
144
  """
135
145
 
136
146
  fname: Path | str | list[Path | str]
@@ -140,8 +150,16 @@ class H5Reader:
140
150
  shuffle: bool = True
141
151
  weights: list[float] | None = None
142
152
  do_remove_inf: bool = False
153
+ equal_jets: bool = True
143
154
 
144
155
  def __post_init__(self) -> None:
156
+ if not self.equal_jets:
157
+ log.warning(
158
+ "equal_jets is set to False, which will result in different number of jets taken"
159
+ " from each sample. Be aware that this can affect the resampling, so make sure you"
160
+ " know what you are doing."
161
+ )
162
+
145
163
  if isinstance(self.fname, (str, Path)):
146
164
  self.fname = [self.fname]
147
165
 
@@ -191,10 +209,14 @@ class H5Reader:
191
209
  Generator
192
210
  Generator of batches of selected jets.
193
211
  """
212
+ # Check if number of jets is given, if not, set to maximum available
194
213
  if num_jets is None:
195
214
  num_jets = self.num_jets
215
+
216
+ # Check if variables are given, if not, set to all
196
217
  if variables is None:
197
218
  variables = {self.jets_name: None}
219
+
198
220
  if self.jets_name not in variables or variables[self.jets_name] is not None:
199
221
  jet_vars = variables.get(self.jets_name, [])
200
222
  variables[self.jets_name] = list(jet_vars) + (cuts.variables if cuts else [])
@@ -207,12 +229,25 @@ class H5Reader:
207
229
 
208
230
  rng = np.random.default_rng(42)
209
231
  while True:
210
- # yeild from each stream
211
232
  samples = []
212
- for stream in streams:
213
- try:
214
- samples.append(next(stream))
215
- except StopIteration:
233
+ # Track which streams have been exhausted
234
+ streams_done = [False] * len(streams)
235
+
236
+ # for each unexhausted stream, get the next sample
237
+ for i, stream in enumerate(streams):
238
+ if not streams_done[i]:
239
+ try:
240
+ samples.append(next(stream))
241
+
242
+ # if equal_jets is True, we can stop when any stream is done
243
+ # otherwise if sample is exhausted, mark it as done
244
+ except StopIteration:
245
+ if self.equal_jets:
246
+ return
247
+ streams_done[i] = True
248
+
249
+ # if equal_jets is False, we need to keep going until all streams are done
250
+ if all(streams_done):
216
251
  return
217
252
 
218
253
  # combine samples and shuffle
@@ -222,25 +257,72 @@ class H5Reader:
222
257
  rng.shuffle(idx)
223
258
  data = {name: array[idx] for name, array in data.items()}
224
259
 
225
- # select
260
+ # yield batch
226
261
  yield data
227
262
 
228
263
  def load(
229
264
  self, variables: dict | None = None, num_jets: int | None = None, cuts: Cuts | None = None
230
265
  ) -> dict:
266
+ """Load multiple batches of selected jets into memory.
267
+
268
+ Parameters
269
+ ----------
270
+ variables : dict | None, optional
271
+ Dictionary of variables to load for each group, by default use all jet variables.
272
+ num_jets : int | None, optional
273
+ Total number of selected jets to load, by default all.
274
+ cuts : Cuts | None, optional
275
+ Selection cuts to apply, by default None
276
+
277
+ Returns
278
+ -------
279
+ dict
280
+ Dictionary of arrays for each group.
281
+ """
282
+ # handle default arguments
231
283
  if num_jets == -1:
232
284
  num_jets = self.num_jets
233
285
  if variables is None:
234
286
  variables = {self.jets_name: None}
287
+
288
+ # get data from each sample
235
289
  data: dict[str, list] = {name: [] for name in variables}
236
- for sample in self.stream(variables, num_jets, cuts):
237
- for name, array in sample.items():
290
+ for batch in self.stream(variables, num_jets, cuts):
291
+ for name, array in batch.items():
238
292
  if name in data:
239
293
  data[name].append(array)
294
+
295
+ # concatenate batches
240
296
  return {name: np.concatenate(array) for name, array in data.items()}
241
297
 
242
298
  def estimate_available_jets(self, cuts: Cuts, num: int = 1_000_000) -> int:
243
- """Estimate the number of jets available after selection cuts, rounded down."""
244
- all_jets = self.load({self.jets_name: cuts.variables}, num)[self.jets_name]
245
- estimated_num_jets = len(cuts(all_jets).values) / len(all_jets) * self.num_jets
299
+ """Estimate the number of jets available after selection cuts (round down).
300
+
301
+ Parameters
302
+ ----------
303
+ cuts : Cuts
304
+ Selection cuts to apply.
305
+ num : int, optional
306
+ Number of jets to use for the estimation, by default 1_000_000.
307
+
308
+ Returns
309
+ -------
310
+ int
311
+ Estimated number of jets available after selection cuts,
312
+ rounded down to nearest thousand.
313
+ """
314
+ # if equal jets is True, available jets is based on the smallest sample
315
+ if self.equal_jets:
316
+ num_jets = []
317
+ for r in self.readers:
318
+ stream = r.stream({self.jets_name: cuts.variables}, num)
319
+ all_jets = np.concatenate([batch[self.jets_name] for batch in stream])
320
+ frac_selected = len(cuts(all_jets).values) / len(all_jets)
321
+ num_jets.append(frac_selected * r.num_jets)
322
+ estimated_num_jets = min(num_jets) * len(self.readers)
323
+ # otherwise, available jets is based on all samples
324
+ else:
325
+ all_jets = self.load({self.jets_name: cuts.variables}, num)[self.jets_name]
326
+ frac_selected = len(cuts(all_jets).values) / len(all_jets)
327
+ estimated_num_jets = frac_selected * self.num_jets
246
328
  return math.floor(estimated_num_jets / 1_000) * 1_000