atlas-ftag-tools 0.2.3__py3-none-any.whl → 0.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: atlas-ftag-tools
3
- Version: 0.2.3
3
+ Version: 0.2.4
4
4
  Summary: ATLAS Flavour Tagging Tools
5
5
  Author: Sam Van Stroud, Philipp Gadow
6
6
  License: MIT
@@ -1,28 +1,28 @@
1
- ftag/__init__.py,sha256=i5XMWwkyGmwIIIaO3veRyOnTsJ0qKyW1V6HsuYc9Dm4,629
1
+ ftag/__init__.py,sha256=Mx2Emsw4TM1YL0wTQMHK36EaQE_ImeV6ukiz1X5BZAU,629
2
2
  ftag/cli_utils.py,sha256=w3TtQmUHSyAKChS3ewvOtcSDAUJAZGIIomaNi8f446U,298
3
3
  ftag/cuts.py,sha256=9_ooLZHaO3SnIQBNxwbaPZn-qptGdKnB27FdKQGTiTY,2933
4
4
  ftag/flavour.py,sha256=EMZZLyl6lSdvkfrYxHhMcSn3aqP_FU7OpCFkvZpTksU,3761
5
5
  ftag/flavours.yaml,sha256=lFnVwjh_DwLhOc3mr5n6bSIWyHgxQvAXas4lEmEDncU,7520
6
6
  ftag/git_check.py,sha256=Y-XqM80CVXZ5ZKrDdZcYOJt3X64uU6W3OP6Z0D7AZU0,1663
7
7
  ftag/labeller.py,sha256=uDygOhVGSNn96DWw8aErHpTtFsFX0RnxYYpy4g1FRog,2457
8
- ftag/mock.py,sha256=FboI1Kq6aKZv43SpubiLkJvn4BSMDv2Fl2UniuhxspU,4502
8
+ ftag/mock.py,sha256=_oy-r3eLllFy33NAoZaKfAx-Rp2vrCdrGj3UsTMks94,4740
9
9
  ftag/region.py,sha256=ANv0dGI2W6NJqD9fp7EfqAUReH4FOjc1gwl_Qn8llcM,360
10
10
  ftag/sample.py,sha256=TFXMhDkbPmjkms9-b-bINJ32T3bO86JcU70C0nY7wa8,2500
11
11
  ftag/test_cli_utils.py,sha256=xa08vf6SEOow58SSFagYdAselb-dkNOVvWsWheMnW-g,1001
12
- ftag/track_selector.py,sha256=WQ6lzY8n6pumIVfLMLgjlTkWUaIQu-8Dq3FKsmCsG_0,1736
12
+ ftag/track_selector.py,sha256=piSYAN_IkOsrXxKXjXbJpMSseUig5P2BJW5mCwsMUDM,2535
13
13
  ftag/transform.py,sha256=uEGGJSnqoKOzLYQv650XdK_kDNw4Aw-5dc60z9Dp_y0,3963
14
14
  ftag/vds.py,sha256=nRViQZQIORB95nC7NZsW3KsSoGkLzEdOsuCViH5h8-U,3296
15
15
  ftag/hdf5/__init__.py,sha256=LFDNxVOCp58SvLHwQhdT68Q-KBMS_i6jBrbXoRpHzbM,354
16
16
  ftag/hdf5/h5move.py,sha256=oYpRu0IDCIJIQ2ML52HBAdoyDxmKkHTeM9JdbPEgKfI,947
17
- ftag/hdf5/h5reader.py,sha256=H_5Aw0lOyEzK_phMRhD-jR_OSCsXnCA3qJZnRvPqaRU,13569
17
+ ftag/hdf5/h5reader.py,sha256=i31pDAqmOSaxdeRhc4iSBlld8xJ0pmp4rNd7CugNzw0,13706
18
18
  ftag/hdf5/h5split.py,sha256=4Wy6Xc3J58MdD9aBaSZHf5ZcVFnJSkWsm42R5Pgo-R4,2448
19
19
  ftag/hdf5/h5utils.py,sha256=-4zKTMtNCrDZr_9Ww7uzfsB7M7muBKpmm_1IkKJnHOI,3222
20
20
  ftag/hdf5/h5writer.py,sha256=j3Fy8snkiVVfimiUz3rrZOhSV8OF27978Y9pk0QcTGM,5277
21
21
  ftag/wps/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
22
22
  ftag/wps/discriminant.py,sha256=kJFekUTPNIvCabJCon6OqOAQEzz5hj3XrWFFRLOgGOs,3836
23
23
  ftag/wps/working_points.py,sha256=VTU6OD40ULAJQD0MlD1EZd33q8ociUvFX1YrhgJFvXc,9722
24
- atlas_ftag_tools-0.2.3.dist-info/METADATA,sha256=XP7QCxOKz-QUKKx5ds8ldz13n_qfTVcP-SZciExdIr0,5169
25
- atlas_ftag_tools-0.2.3.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
26
- atlas_ftag_tools-0.2.3.dist-info/entry_points.txt,sha256=LfVLsZHQolqbPnwPgtmc5IQTh527BKkN2v-IpXWTNHw,137
27
- atlas_ftag_tools-0.2.3.dist-info/top_level.txt,sha256=qiYQuKcAvMim-31FwkT3MTQu7WQm0s58tPAia5KKWqs,5
28
- atlas_ftag_tools-0.2.3.dist-info/RECORD,,
24
+ atlas_ftag_tools-0.2.4.dist-info/METADATA,sha256=f4aCu6JmItUBp5EmTzbrqhC5-Wsy7uiOiBO9yufyacQ,5169
25
+ atlas_ftag_tools-0.2.4.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
26
+ atlas_ftag_tools-0.2.4.dist-info/entry_points.txt,sha256=LfVLsZHQolqbPnwPgtmc5IQTh527BKkN2v-IpXWTNHw,137
27
+ atlas_ftag_tools-0.2.4.dist-info/top_level.txt,sha256=qiYQuKcAvMim-31FwkT3MTQu7WQm0s58tPAia5KKWqs,5
28
+ atlas_ftag_tools-0.2.4.dist-info/RECORD,,
ftag/__init__.py CHANGED
@@ -2,7 +2,7 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- __version__ = "v0.2.3"
5
+ __version__ = "v0.2.4"
6
6
 
7
7
 
8
8
  from ftag import hdf5
ftag/hdf5/h5reader.py CHANGED
@@ -150,8 +150,13 @@ class H5Reader:
150
150
  transform : Transform | None, optional
151
151
  Transform to apply to data, by default None
152
152
  equal_jets : bool, optional
153
- Take the same number of jets (weighted) from each sample, by default True
154
- If False, use all jets in each sample.
153
+ Take the same number of jets (weighted) from each sample, by default False.
154
+ This is useful when you specify a list of DSIDs for the sample and they are
155
+ qualitatively different, and you want to ensure that you always return batches
156
+ with jets from all DSIDs. This is used for example in the QCD resampling for Xbb.
157
+ If False, use all jets in each sample, allowing for the full available statistics
158
+ to be used. Useful for example if you have multiple ttbar samples and you want to
159
+ use all available jets from each sample.
155
160
  """
156
161
 
157
162
  fname: Path | str | list[Path | str]
@@ -162,17 +167,10 @@ class H5Reader:
162
167
  weights: list[float] | None = None
163
168
  do_remove_inf: bool = False
164
169
  transform: Transform | None = None
165
- equal_jets: bool = True
170
+ equal_jets: bool = False
166
171
 
167
172
  def __post_init__(self) -> None:
168
173
  self.rng = np.random.default_rng(42)
169
- if not self.equal_jets:
170
- log.warning(
171
- "equal_jets is set to False, which will result in different number of jets taken"
172
- " from each sample. Be aware that this can affect the resampling, so make sure you"
173
- " know what you are doing."
174
- )
175
-
176
174
  if isinstance(self.fname, (str, Path)):
177
175
  self.fname = [self.fname]
178
176
 
@@ -283,8 +281,8 @@ class H5Reader:
283
281
  try:
284
282
  samples.append(next(stream))
285
283
 
286
- # if equal_jets is True, we can stop when any stream is done
287
- # otherwise if sample is exhausted, mark it as done
284
+ # if equal_jets is True, stop when any sample is done
285
+ # otherwise if stream is exhausted, mark it as such and continue
288
286
  except StopIteration:
289
287
  if self.equal_jets:
290
288
  return
ftag/mock.py CHANGED
@@ -109,6 +109,11 @@ def mock_tracks(num_jets=1000, num_tracks=40) -> np.ndarray:
109
109
  tracks_dtype = np.dtype(TRACK_VARS)
110
110
  tracks = u2s(rng.random((num_jets, num_tracks, len(TRACK_VARS))), tracks_dtype)
111
111
  tracks["d0"] *= 5
112
+
113
+ # for the shared hits, add some reasonable integer values
114
+ tracks["numberOfPixelSharedHits"] = rng.integers(0, 3, size=(num_jets, num_tracks))
115
+ tracks["numberOfSCTSharedHits"] = rng.integers(0, 3, size=(num_jets, num_tracks))
116
+
112
117
  valid = rng.choice([True, False], size=(num_jets, num_tracks))
113
118
  valid = valid.astype(bool).view(dtype=np.dtype([("valid", bool)]))
114
119
  return join_structured_arrays([tracks, valid])
ftag/track_selector.py CHANGED
@@ -4,7 +4,7 @@ from dataclasses import dataclass
4
4
 
5
5
  import numpy as np
6
6
 
7
- from ftag import Cuts
7
+ from ftag.cuts import Cut, Cuts
8
8
 
9
9
 
10
10
  @dataclass
@@ -34,7 +34,8 @@ class TrackSelector:
34
34
  # apply the cuts
35
35
  for cut in self.cuts.cuts:
36
36
  # remove valid track indices that do not pass the selection
37
- rm_idx[tracks[self.valid_str] & ~cut(tracks)] = True
37
+ keep_idx = self._nshared_cut(cut, tracks) if cut.variable == "NSHARED" else cut(tracks)
38
+ rm_idx[tracks[self.valid_str] & ~keep_idx] = True
38
39
 
39
40
  # set the values of the tracks that do not pass the cuts to
40
41
  for var in tracks.dtype.names:
@@ -51,3 +52,19 @@ class TrackSelector:
51
52
  tracks[rm_idx][self.valid_str] = False
52
53
 
53
54
  return tracks
55
+
56
+ def _nshared_cut(self, cut: Cut, tracks: np.ndarray) -> np.ndarray:
57
+ # hack to apply the FTAG shared hit cut, which requires an intermediate step
58
+ if cut.variable == "NSHARED" and "NSHARED" in tracks.dtype.names:
59
+ raise ValueError("NSHARED is a reserved variable name")
60
+
61
+ # compute the number of shared modules: pixel shared hits plus half the SCT shared hits
62
+ n_pix_shared = tracks["numberOfPixelSharedHits"]
63
+ n_sct_shared = tracks["numberOfSCTSharedHits"]
64
+ n_module_shared = n_pix_shared + n_sct_shared / 2
65
+
66
+ # convert n_module_shared to structured array
67
+ n_module_shared = n_module_shared.view(dtype=[(cut.variable, n_module_shared.dtype)])
68
+
69
+ # select
70
+ return cut(n_module_shared)