atlas-ftag-tools 0.2.11__py3-none-any.whl → 0.2.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: atlas-ftag-tools
3
- Version: 0.2.11
3
+ Version: 0.2.12
4
4
  Summary: ATLAS Flavour Tagging Tools
5
5
  Author: Sam Van Stroud, Philipp Gadow
6
6
  License: MIT
@@ -1,14 +1,14 @@
1
- atlas_ftag_tools-0.2.11.dist-info/licenses/LICENSE,sha256=R4o6bZfajQ1KxwcIeavTC00qYTdL33YGNe1hzfV53gM,11349
2
- ftag/__init__.py,sha256=BGQ1MtuhqCHFXRAh9S9f_ZnOCLWB5RA0ZtL9lW2tofs,748
1
+ atlas_ftag_tools-0.2.12.dist-info/licenses/LICENSE,sha256=R4o6bZfajQ1KxwcIeavTC00qYTdL33YGNe1hzfV53gM,11349
2
+ ftag/__init__.py,sha256=CU1RjEu6pHq11LQ2kAy9YDittMHXB51fNWvuy1NFr7o,748
3
3
  ftag/cli_utils.py,sha256=w3TtQmUHSyAKChS3ewvOtcSDAUJAZGIIomaNi8f446U,298
4
4
  ftag/cuts.py,sha256=9_ooLZHaO3SnIQBNxwbaPZn-qptGdKnB27FdKQGTiTY,2933
5
5
  ftag/flavours.py,sha256=ShH4M2UjQZpZ_NlCctTm2q1tJbzYxjmGteioQ2GcqEU,114
6
- ftag/flavours.yaml,sha256=CrVTJKndHeL15LT2nkjPodi6Ck9mk_oUtdRby6X_Rcc,9921
6
+ ftag/flavours.yaml,sha256=b86gXX_FMIewLK7_pr0bNgz7RJ84fDf9nAvmjB4J-Ks,9920
7
7
  ftag/fraction_optimization.py,sha256=IlMEJe5fD0soX40f-LO4dYAYld2gMqgZRuBLctoPn9A,5566
8
8
  ftag/git_check.py,sha256=Y-XqM80CVXZ5ZKrDdZcYOJt3X64uU6W3OP6Z0D7AZU0,1663
9
- ftag/labeller.py,sha256=IXUgU9UBir39PxVWRKs5r5fqI66Tv0x7nJD3-RYpbrg,2780
9
+ ftag/labeller.py,sha256=6tKLG0SrBijMIZdzWjGdQU9qN_dkRV4eELLX_8YQvTQ,2772
10
10
  ftag/labels.py,sha256=2nmcmrZD8mWQPxJsGiOgcLDhSVgWfS_cEzqsBV-Qy8o,4198
11
- ftag/mock.py,sha256=P2D7nNKAz2jRBbmfpHTDj9sBVU9r7HGd0rpWZOJYZ90,5980
11
+ ftag/mock.py,sha256=syysvzLsBHU8aw7Uy5g3G4HB6LnSmHlGa2BfeXv5mQ4,5970
12
12
  ftag/region.py,sha256=ANv0dGI2W6NJqD9fp7EfqAUReH4FOjc1gwl_Qn8llcM,360
13
13
  ftag/sample.py,sha256=3N0FrRcu9l1sX8ohuGOHuMYGD0See6gMO4--7NzR2tE,2538
14
14
  ftag/track_selector.py,sha256=fJNk_kIBQriBqV4CPT_3ReJbOUnavDDzO-u3EQlRuyk,2654
@@ -18,15 +18,15 @@ ftag/working_points.py,sha256=RJws2jPMEDQDspCbXUZBifS1CCBmlMJ5ax0eMyDzCRA,15949
18
18
  ftag/hdf5/__init__.py,sha256=8yzVQITge-HKkBQQ60eJwWmWDycYZjgVs-qVg4ShVr0,385
19
19
  ftag/hdf5/h5add_col.py,sha256=htS5wn4Tm4S3U6mrJ8s24VUnbI7o28Z6Ll-J_V68xTA,12558
20
20
  ftag/hdf5/h5move.py,sha256=oYpRu0IDCIJIQ2ML52HBAdoyDxmKkHTeM9JdbPEgKfI,947
21
- ftag/hdf5/h5reader.py,sha256=i31pDAqmOSaxdeRhc4iSBlld8xJ0pmp4rNd7CugNzw0,13706
21
+ ftag/hdf5/h5reader.py,sha256=NbHohY3RSicM3qnX_0Y1TfGAaDg3wgfEjYlGaWDJmug,14268
22
22
  ftag/hdf5/h5split.py,sha256=4Wy6Xc3J58MdD9aBaSZHf5ZcVFnJSkWsm42R5Pgo-R4,2448
23
- ftag/hdf5/h5utils.py,sha256=-4zKTMtNCrDZr_9Ww7uzfsB7M7muBKpmm_1IkKJnHOI,3222
24
- ftag/hdf5/h5writer.py,sha256=2gBztierWdwZIqcFItoYz8oua_7hphOI8mbDg7xBdPs,5784
23
+ ftag/hdf5/h5utils.py,sha256=EbCLOF_j1EBFwD95Z3QvlNpBpNkaZoqySZybhVja67U,3542
24
+ ftag/hdf5/h5writer.py,sha256=SMurvZ8FPvqieZUaYRX2SBu-jIyZ6Fx8IasUrEOxIvM,7185
25
25
  ftag/utils/__init__.py,sha256=U3YyLY77-FzxRUbudxciieDoy_mnLlY3OfBquA3PnTE,524
26
26
  ftag/utils/logging.py,sha256=54NaQiC9Bh4vSznSqzoPfR-7tj1PXfmoH7yKgv_ZHZk,3192
27
27
  ftag/utils/metrics.py,sha256=zQI4nPeRDSyzqKpdOPmu0GU560xSWoW1wgL13rrja-I,12664
28
- atlas_ftag_tools-0.2.11.dist-info/METADATA,sha256=DVmllPN7YQNNmyDcTs3hEGo8mX8ogSReXq9gs6MwUR0,2152
29
- atlas_ftag_tools-0.2.11.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
30
- atlas_ftag_tools-0.2.11.dist-info/entry_points.txt,sha256=acr7WwxMIJ3x2I7AheNxNnpWE7sS8XE9MA1eUJGcU5A,169
31
- atlas_ftag_tools-0.2.11.dist-info/top_level.txt,sha256=qiYQuKcAvMim-31FwkT3MTQu7WQm0s58tPAia5KKWqs,5
32
- atlas_ftag_tools-0.2.11.dist-info/RECORD,,
28
+ atlas_ftag_tools-0.2.12.dist-info/METADATA,sha256=bGfabVRARSL6PZTsDqen30IkQVqVGw8Tg9lMCnzY-5w,2152
29
+ atlas_ftag_tools-0.2.12.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
30
+ atlas_ftag_tools-0.2.12.dist-info/entry_points.txt,sha256=acr7WwxMIJ3x2I7AheNxNnpWE7sS8XE9MA1eUJGcU5A,169
31
+ atlas_ftag_tools-0.2.12.dist-info/top_level.txt,sha256=qiYQuKcAvMim-31FwkT3MTQu7WQm0s58tPAia5KKWqs,5
32
+ atlas_ftag_tools-0.2.12.dist-info/RECORD,,
ftag/__init__.py CHANGED
@@ -2,7 +2,7 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- __version__ = "v0.2.11"
5
+ __version__ = "v0.2.12"
6
6
 
7
7
  from . import hdf5, utils
8
8
  from .cuts import Cuts
ftag/flavours.yaml CHANGED
@@ -113,7 +113,7 @@
113
113
  category: xbb
114
114
  - name: qcdnonbb
115
115
  label: $\mathrm{QCD} \rightarrow \mathrm{non-} b \bar{b}$
116
- cuts: ["R10TruthLabel_R22v1 == 10", "GhostBHadronsFinalCount != 2"]
116
+ cuts: ["R10TruthLabel_R22v1 == 10", "GhostBHadronsFinalCount < 2"]
117
117
  colour: "silver"
118
118
  category: xbb
119
119
  - name: qcdbx
ftag/hdf5/h5reader.py CHANGED
@@ -74,10 +74,12 @@ class H5SingleReader:
74
74
  num_jets: int | None = None,
75
75
  cuts: Cuts | None = None,
76
76
  start: int = 0,
77
+ skip_batches: int = 0,
77
78
  ) -> Generator:
78
79
  if num_jets is None:
79
80
  num_jets = self.num_jets
80
-
81
+ if skip_batches > 0:
82
+ assert not self.shuffle, "Cannot skip batches if shuffle is True"
81
83
  if num_jets > self.num_jets:
82
84
  log.warning(
83
85
  f"{num_jets:,} jets requested but only {self.num_jets:,} available in {self.fname}."
@@ -97,7 +99,8 @@ class H5SingleReader:
97
99
  indices = list(range(start, self.num_jets + start, self.batch_size))
98
100
  if self.shuffle:
99
101
  self.rng.shuffle(indices)
100
-
102
+ if skip_batches > 0:
103
+ indices = indices[skip_batches:]
101
104
  # loop over batches and read file
102
105
  for low in indices:
103
106
  for name in variables:
@@ -176,7 +179,12 @@ class H5Reader:
176
179
 
177
180
  # calculate batch sizes
178
181
  if self.weights is None:
179
- self.weights = [1 / len(self.fname)] * len(self.fname)
182
+ rows_per_file = [
183
+ H5SingleReader(f, jets_name=self.jets_name).num_jets for f in self.fname
184
+ ]
185
+ num_total = sum(rows_per_file)
186
+ self.weights = [num / num_total for num in rows_per_file]
187
+
180
188
  self.batch_sizes = [int(w * self.batch_size) for w in self.weights]
181
189
 
182
190
  # create readers
@@ -233,6 +241,7 @@ class H5Reader:
233
241
  num_jets: int | None = None,
234
242
  cuts: Cuts | None = None,
235
243
  start: int = 0,
244
+ skip_batches: int = 0,
236
245
  ) -> Generator:
237
246
  """Generate batches of selected jets.
238
247
 
@@ -246,6 +255,8 @@ class H5Reader:
246
255
  Selection cuts to apply, by default None
247
256
  start : int, optional
248
257
  Starting index of the first jet to read, by default 0
258
+ skip_batches : int, optional
259
+ Number of batches to skip, by default 0
249
260
 
250
261
  Yields
251
262
  ------
@@ -266,7 +277,9 @@ class H5Reader:
266
277
 
267
278
  # get streams for selected jets from each reader
268
279
  streams = [
269
- r.stream(variables, int(r.num_jets / self.num_jets * num_jets), cuts, start)
280
+ r.stream(
281
+ variables, int(r.num_jets / self.num_jets * num_jets), cuts, start, skip_batches
282
+ )
270
283
  for r in self.readers
271
284
  ]
272
285
 
ftag/hdf5/h5utils.py CHANGED
@@ -13,6 +13,7 @@ def get_dtype(
13
13
  variables: list[str] | None = None,
14
14
  precision: str | None = None,
15
15
  transform: Transform | None = None,
16
+ full_precision_vars: list[str] | None = None,
16
17
  ) -> np.dtype:
17
18
  """Return a dtype based on an existing dataset and requested variables.
18
19
 
@@ -26,6 +27,8 @@ def get_dtype(
26
27
  Precision to cast floats to, "half" or "full", by default None
27
28
  transform : Transform | None, optional
28
29
  Transform to apply to variables names, by default None
30
+ full_precision_vars : list[str] | None, optional
31
+ List of variables to keep in full precision, by default None
29
32
 
30
33
  Returns
31
34
  -------
@@ -39,6 +42,8 @@ def get_dtype(
39
42
  """
40
43
  if variables is None:
41
44
  variables = ds.dtype.names
45
+ if full_precision_vars is None:
46
+ full_precision_vars = []
42
47
 
43
48
  if (missing := set(variables) - set(ds.dtype.names)) and transform is not None:
44
49
  variables = transform.map_variable_names(ds.name, variables, inverse=True)
@@ -50,7 +55,10 @@ def get_dtype(
50
55
 
51
56
  dtype = [(n, x) for n, x in ds.dtype.descr if n in variables]
52
57
  if precision:
53
- dtype = [(n, cast_dtype(x, precision)) for n, x in dtype]
58
+ dtype = [
59
+ (n, cast_dtype(x, precision)) if n not in full_precision_vars else (n, x)
60
+ for n, x in dtype
61
+ ]
54
62
 
55
63
  return np.dtype(dtype)
56
64
 
@@ -78,6 +86,7 @@ def cast_dtype(typestr: str, precision: str) -> np.dtype:
78
86
  t = np.dtype(typestr)
79
87
  if t.kind != "f":
80
88
  return t
89
+
81
90
  if precision == "half":
82
91
  return np.dtype("f2")
83
92
  if precision == "full":
ftag/hdf5/h5writer.py CHANGED
@@ -47,18 +47,28 @@ class H5Writer:
47
47
  precision: str = "full"
48
48
  full_precision_vars: list[str] | None = None
49
49
  shuffle: bool = True
50
+ num_jets: int | None = None # Allow dynamic mode by defaulting to None
50
51
 
51
52
  def __post_init__(self):
52
53
  self.num_written = 0
53
54
  self.rng = np.random.default_rng(42)
54
- self.num_jets = [shape[0] for shape in self.shapes.values()]
55
- assert len(set(self.num_jets)) == 1, "Must have same number of jets per group"
56
- self.num_jets = self.num_jets[0]
55
+
56
+ # Infer number of jets from shapes if not explicitly passed
57
+ inferred_num_jets = [shape[0] for shape in self.shapes.values()]
58
+ if self.num_jets is None:
59
+ assert len(set(inferred_num_jets)) == 1, "Shapes must agree in first dimension"
60
+ self.fixed_mode = False
61
+ else:
62
+ self.fixed_mode = True
63
+ for name in self.shapes:
64
+ self.shapes[name] = (self.num_jets,) + self.shapes[name][1:]
57
65
 
58
66
  if self.precision == "full":
59
67
  self.fp_dtype = np.float32
60
68
  elif self.precision == "half":
61
69
  self.fp_dtype = np.float16
70
+ elif self.precision is None:
71
+ self.fp_dtype = None
62
72
  else:
63
73
  raise ValueError(f"Invalid precision: {self.precision}")
64
74
 
@@ -71,16 +81,34 @@ class H5Writer:
71
81
  self.create_ds(name, dtype)
72
82
 
73
83
  @classmethod
74
- def from_file(cls, source: Path, num_jets: int | None = None, **kwargs) -> H5Writer:
84
+ def from_file(
85
+ cls, source: Path, num_jets: int | None = 0, variables=None, **kwargs
86
+ ) -> H5Writer:
75
87
  with h5py.File(source, "r") as f:
76
88
  dtypes = {name: ds.dtype for name, ds in f.items()}
77
89
  shapes = {name: ds.shape for name, ds in f.items()}
78
- if num_jets is not None:
90
+
91
+ if variables:
92
+ new_dtye = {}
93
+ new_shape = {}
94
+ for name, ds in f.items():
95
+ if name not in variables:
96
+ continue
97
+ new_dtye[name] = ftag.hdf5.get_dtype(
98
+ ds,
99
+ variables=variables[name],
100
+ precision=kwargs.get("precision"),
101
+ full_precision_vars=kwargs.get("full_precision_vars"),
102
+ )
103
+ new_shape[name] = ds.shape
104
+ dtypes = new_dtye
105
+ shapes = new_shape
106
+ if num_jets != 0:
79
107
  shapes = {name: (num_jets,) + shape[1:] for name, shape in shapes.items()}
80
108
  compression = [ds.compression for ds in f.values()]
81
109
  assert len(set(compression)) == 1, "Must have same compression for all groups"
82
110
  compression = compression[0]
83
- if compression not in kwargs:
111
+ if "compression" not in kwargs:
84
112
  kwargs["compression"] = compression
85
113
  return cls(dtypes=dtypes, shapes=shapes, **kwargs)
86
114
 
@@ -88,36 +116,47 @@ class H5Writer:
88
116
  if name == self.jets_name and self.add_flavour_label and "flavour_label" not in dtype.names:
89
117
  dtype = np.dtype([*dtype.descr, ("flavour_label", "i4")])
90
118
 
91
- # adjust dtype based on specified precision
92
- full_precision_vars = [] if self.full_precision_vars is None else self.full_precision_vars
93
- # If the field is in full_precision_vars, use the full precision dtype
119
+ fp_vars = self.full_precision_vars or []
120
+ # If no precision is defined, or the field is in full_precision_vars, or its non-float,
121
+ # keep it at the original dtype
94
122
  dtype = np.dtype([
95
123
  (
96
124
  field,
97
- self.fp_dtype
98
- if field not in full_precision_vars and np.issubdtype(dt, np.floating)
99
- else dt,
125
+ (
126
+ self.fp_dtype
127
+ if (self.fp_dtype and field not in fp_vars and np.issubdtype(dt, np.floating))
128
+ else dt
129
+ ),
100
130
  )
101
131
  for field, dt in dtype.descr
102
132
  ])
103
133
 
104
- # optimal chunking is around 100 jets, only aply for track groups
105
134
  shape = self.shapes[name]
106
135
  chunks = (100,) + shape[1:] if shape[1:] else None
107
136
 
108
- # note: enabling the hd5 shuffle filter doesn't improve write performance
109
- self.file.create_dataset(
110
- name, dtype=dtype, shape=shape, compression=self.compression, chunks=chunks
111
- )
137
+ if self.fixed_mode:
138
+ self.file.create_dataset(
139
+ name, dtype=dtype, shape=shape, compression=self.compression, chunks=chunks
140
+ )
141
+ else:
142
+ maxshape = (None,) + shape[1:]
143
+ self.file.create_dataset(
144
+ name,
145
+ dtype=dtype,
146
+ shape=(0,) + shape[1:],
147
+ maxshape=maxshape,
148
+ compression=self.compression,
149
+ chunks=chunks,
150
+ )
112
151
 
113
152
  def close(self) -> None:
114
- with h5py.File(self.dst) as f:
115
- written = len(f[self.jets_name])
116
- if self.num_written != written:
117
- raise ValueError(
118
- f"Attemped to close file {self.dst} when only {self.num_written:,} out of"
119
- f" {written:,} jets have been written"
120
- )
153
+ if self.fixed_mode:
154
+ written = len(self.file[self.jets_name])
155
+ if self.num_written != written:
156
+ raise ValueError(
157
+ f"Attempted to close file {self.dst} when only {self.num_written:,} out of"
158
+ f" {written:,} jets have been written"
159
+ )
121
160
  self.file.close()
122
161
 
123
162
  def get_attr(self, name, group=None):
@@ -137,18 +176,25 @@ class H5Writer:
137
176
  for attr_name, value in ds.attrs.items():
138
177
  self.add_attr(attr_name, value, group=name)
139
178
 
140
- def write(self, data: dict[str, np.array]) -> None:
141
- if (total := self.num_written + len(data[self.jets_name])) > self.num_jets:
142
- raise ValueError(
143
- f"Attempted to write more jets than expected: {total:,} > {self.num_jets:,}"
144
- )
145
- idx = np.arange(len(data[self.jets_name]))
179
+ def write(self, data: dict[str, np.ndarray]) -> None:
180
+ batch_size = len(data[self.jets_name])
181
+ idx = np.arange(batch_size)
146
182
  if self.shuffle:
147
183
  self.rng.shuffle(idx)
148
184
  data = {name: array[idx] for name, array in data.items()}
149
185
 
150
186
  low = self.num_written
151
- high = low + len(idx)
187
+ high = low + batch_size
188
+
189
+ if self.fixed_mode and high > self.num_jets:
190
+ raise ValueError(
191
+ f"Attempted to write more jets than expected: {high:,} > {self.num_jets:,}"
192
+ )
193
+
152
194
  for group in self.dtypes:
153
- self.file[group][low:high] = data[group]
154
- self.num_written += len(idx)
195
+ ds = self.file[group]
196
+ if not self.fixed_mode:
197
+ ds.resize((high,) + ds.shape[1:])
198
+ ds[low:high] = data[group]
199
+
200
+ self.num_written += batch_size
ftag/labeller.py CHANGED
@@ -30,7 +30,7 @@ class Labeller:
30
30
  def __post_init__(self) -> None:
31
31
  if isinstance(self.labels, LabelContainer):
32
32
  self.labels = list(self.labels)
33
- self.labels = sorted([Flavours[label] for label in self.labels])
33
+ self.labels = [Flavours[label] for label in self.labels]
34
34
 
35
35
  @property
36
36
  def variables(self) -> list[str]:
ftag/mock.py CHANGED
@@ -106,11 +106,11 @@ def get_mock_scores(labels: np.ndarray, is_xbb: bool = False) -> np.ndarray:
106
106
  for i in range(n_classes):
107
107
  tmp_means = []
108
108
  tmp_means = [
109
- 0 if j != i else mean_scale_list[np.random.randint(0, len(mean_scale_list))]
109
+ 0 if j != i else mean_scale_list[rng.integers(0, len(mean_scale_list))]
110
110
  for j in range(n_classes)
111
111
  ]
112
112
  means.append(tmp_means)
113
- scales.append(mean_scale_list[np.random.randint(0, len(mean_scale_list))])
113
+ scales.append(mean_scale_list[rng.integers(0, len(mean_scale_list))])
114
114
 
115
115
  # Map the labels to the means
116
116
  label_mapping = dict(zip(label_dict.values(), means))