atlas-ftag-tools 0.2.9__py3-none-any.whl → 0.2.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ftag/hdf5/h5add_col.py ADDED
@@ -0,0 +1,391 @@
1
+ # Utils to take an input h5 file, and append one or more columns to it
2
+ from __future__ import annotations
3
+
4
+ import argparse
5
+ import importlib.util
6
+ from pathlib import Path
7
+ from typing import Callable
8
+
9
+ import h5py
10
+ import numpy as np
11
+
12
+ from ftag.hdf5.h5reader import H5Reader
13
+ from ftag.hdf5.h5writer import H5Writer
14
+
15
+
16
def merge_dicts(dicts: list[dict[str, dict[str, np.ndarray]]]) -> dict[str, dict[str, np.ndarray]]:
    """Combine several ``{group: {variable: array}}`` mappings into one.

    Groups that appear in more than one input are merged variable by variable,
    e.g.

        [{"jets": {"pt": a}}, {"jets": {"phi": b}}]

    becomes

        {"jets": {"pt": a, "phi": b}}

    The arrays themselves are not copied; the inner dictionaries of the result
    are new objects, so the inputs are left untouched.

    Parameters
    ----------
    dicts : list[dict[str, dict[str, np.ndarray]]]
        Mappings to merge, each of the form ``{group: {variable: array}}``.

    Returns
    -------
    dict[str, dict[str, np.ndarray]]
        A single mapping holding every variable of every group.

    Raises
    ------
    ValueError
        If the same variable is provided twice for the same group.
    """
    combined: dict[str, dict[str, np.ndarray]] = {}
    for source in dicts:
        for group, variables in source.items():
            # Create the group on first sight, then fill it in place.
            target = combined.setdefault(group, {})
            for variable, data in variables.items():
                if variable in target:
                    raise ValueError(f"Variable {variable} already exists in group {group}.")
                target[variable] = data
    return combined
91
+
92
+
93
def get_shape(num_jets: int, batch: dict[str, np.ndarray]) -> dict[str, tuple[int, ...]]:
    """Build the per-group output shapes expected by the H5Writer.

    For every array in the batch the leading (jet) dimension is replaced by
    ``num_jets`` while all trailing dimensions are kept as they are.

    Parameters
    ----------
    num_jets : int
        Total number of jets that will be written.
    batch : dict[str, np.ndarray]
        One batch of data, keyed by group name.

    Returns
    -------
    dict[str, tuple[int, ...]]
        Output shape for each group in the batch.
    """
    # shape[1:] is empty for 1-D arrays, so a single expression covers every rank.
    return {name: (num_jets, *array.shape[1:]) for name, array in batch.items()}
116
+
117
+
118
def get_all_groups(file: Path | str) -> dict[str, None]:
    """List the top-level entries of an h5 file as a ``{name: None}`` mapping.

    Parameters
    ----------
    file : Path | str
        Path to the h5 file, opened read-only.

    Returns
    -------
    dict[str, None]
        Every top-level key of the file mapped to None, which is the form
        passed to ``H5Reader.stream`` to request all groups.
    """
    with h5py.File(file, "r") as h5file:
        return {name: None for name in h5file.keys()}
135
+
136
+
137
def h5_add_column(
    input_file: str | Path,
    output_file: str | Path | None,
    append_function: Callable | list[Callable],
    num_jets: int = -1,
    input_groups: list[str] | None = None,
    output_groups: list[str] | None = None,
    reader_kwargs: dict | None = None,
    writer_kwargs: dict | None = None,
    overwrite: bool = False,
) -> None:
    """Appends one or more columns to one or more groups in an h5 file.

    The input file is streamed batch by batch; every append function is called
    on each batch and the returned arrays are added as new fields to the
    corresponding group before the batch is written to the output file.

    Parameters
    ----------
    input_file : str | Path
        Input h5 file to read from.
    output_file : str | Path | None
        Output h5 file to write to. If None, the output is written next to the
        input file as ``<stem>_additional<suffix>`` (``x.h5`` -> ``x_additional.h5``).
    append_function : callable | list[callable]
        A function, or list of functions, which take a batch from H5Reader and returns a dictionary
        of the form:
            {
                group1 : {
                    new_column1 : data,
                    new_column2 : data,
                },
                group2 : {
                    new_column3 : data,
                    new_column4 : data,
                },
                ...
            }
    num_jets : int, optional
        Number of jets to read from the input file. If -1, reads all jets. By default -1.
    input_groups : list[str] | None, optional
        List of groups to read from the input file. If None, reads all groups. By default None.
    output_groups : list[str] | None, optional
        List of groups to write to the output file. If None, writes all groups. By default None.
        Note that this is a subset of the input groups, and must include all groups that the
        append functions wish to write to.
    reader_kwargs : dict, optional
        Additional arguments to pass to the H5Reader. By default None.
    writer_kwargs : dict, optional
        Additional arguments to pass to the H5Writer. By default None. The dictionary is
        copied, so the caller's object is never modified. ``precision`` defaults to "full"
        when not given.
    overwrite : bool, optional
        If True, an existing output file is overwritten. If False (default), a
        FileExistsError is raised when the output file already exists.

    Raises
    ------
    FileNotFoundError
        If the input file does not exist.
    FileExistsError
        If the output file exists and overwrite is False.
    ValueError
        If the output file is the input file, the new variable already exists, its shape
        is incorrect, or an output group is not in the input groups.
    """
    # `import numpy` alone does not guarantee that the recfunctions submodule is loaded,
    # so import it explicitly instead of relying on `np.lib.recfunctions`.
    from numpy.lib import recfunctions as rfn

    input_file = Path(input_file)
    if not input_file.exists():
        raise FileNotFoundError(f"Input file {input_file} does not exist.")

    # Resolve the default output path *before* the existence check so that the
    # derived path is protected by `overwrite` as well.
    if output_file is None:
        output_file = input_file.with_name(f"{input_file.stem}_additional{input_file.suffix}")
    else:
        output_file = Path(output_file)

    if output_file.exists() and not overwrite:
        raise FileExistsError(
            f"Output file {output_file} already exists. Please choose a different name."
        )
    # Only reachable with overwrite=True: writing onto the file being read would destroy it.
    if output_file.resolve() == input_file.resolve():
        raise ValueError(f"Output file {output_file} must be different from the input file.")

    # Work on copies so the caller's dictionaries are never mutated.
    reader_kwargs = dict(reader_kwargs) if reader_kwargs else {}
    writer_kwargs = dict(writer_kwargs) if writer_kwargs else {}
    writer_kwargs.setdefault("precision", "full")

    if isinstance(append_function, (list, tuple)):
        append_functions = list(append_function)
    else:
        append_functions = [append_function]

    input_variables = (
        get_all_groups(input_file) if input_groups is None else dict.fromkeys(input_groups)
    )
    if output_groups is None:
        output_groups = list(input_variables.keys())

    # Explicit exception rather than `assert`, which is stripped under `python -O`.
    missing = [group for group in output_groups if group not in input_variables]
    if missing:
        raise ValueError(
            f"Output groups {missing} not in input groups {list(input_variables.keys())}"
        )

    reader = H5Reader(input_file, shuffle=False, **reader_kwargs)
    njets = reader.num_jets if num_jets == -1 else num_jets

    # Ceiling division: an exact multiple of the batch size needs no extra batch.
    num_batches = max(1, -(-njets // reader.batch_size))

    # The writer is created lazily because the output dtypes are only known
    # once the first batch has been extended with the new columns.
    # NOTE(review): the writer is never explicitly closed or finalised here; confirm
    # that H5Writer does not need an explicit call after the last write.
    writer = None
    for i, batch in enumerate(reader.stream(input_variables, num_jets=njets)):
        if (i + 1) % 10 == 0:
            print(f"Processing batch {i + 1}/{num_batches} ({(i + 1) / num_batches * 100:.2f}%)")

        to_append = merge_dicts([af(batch) for af in append_functions])
        for k, newvars in to_append.items():
            if k not in output_groups:
                raise ValueError(f"Trying to output to {k} but only {output_groups} are allowed")
            existing = batch[k].dtype.names or ()
            for newkey, newval in newvars.items():
                if newkey in existing:
                    raise ValueError(
                        f"Trying to append {newkey} to {k} but it already exists in batch"
                    )
                if newval.shape != batch[k].shape:
                    raise ValueError(
                        f"Trying to append {newkey} to {k} but the shape is not correct"
                    )

        to_write = {}
        for key, str_array in batch.items():
            if key not in output_groups:
                continue
            new_columns = to_append.get(key)
            if new_columns:
                to_write[key] = rfn.append_fields(
                    str_array,
                    list(new_columns.keys()),
                    list(new_columns.values()),
                    usemask=False,
                )
            else:
                # Nothing to add for this group: pass the batch through unchanged.
                to_write[key] = str_array

        if writer is None:
            writer = H5Writer(
                output_file,
                dtypes={key: str_array.dtype for key, str_array in to_write.items()},
                shapes=get_shape(njets, to_write),
                shuffle=False,
                **writer_kwargs,
            )

        writer.write(to_write)
279
+
280
+
281
def parse_append_function(func_path: str) -> Callable:
    """Attempts to load the function specified by func_path.
    The function should be specified as 'path/to/file.py:function_name'.

    The string is split on its *last* colon, so file paths that themselves
    contain colons (e.g. Windows drive letters) are handled correctly.

    Parameters
    ----------
    func_path : str
        Path to the function to load. Should be of the form 'path/to/file.py:function_name'.

    Returns
    -------
    Callable
        The function specified by func_path.

    Raises
    ------
    ValueError
        If the function path is not of the form 'path/to/file.py:function_name'.
    FileNotFoundError
        If the file does not exist.
    ImportError
        If the file cannot be imported.
    AttributeError
        If the function does not exist in the file.
    TypeError
        If the named attribute exists but is not callable.
    """
    if isinstance(func_path, Path):
        func_path = str(func_path)

    # rpartition keeps any colon inside the file path intact.
    file_str, separator, func_name = func_path.rpartition(":")
    if not separator or not file_str or not func_name:
        raise ValueError(
            "Function should be specified as 'path/to/file.py:function_name', "
            f"got {func_path!r}"
        )

    file_path = Path(file_str).resolve()
    if not file_path.is_file():
        raise FileNotFoundError(f"No such file: {file_path}")

    module_name = file_path.stem  # Just the filename without extension

    spec = importlib.util.spec_from_file_location(module_name, str(file_path))
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load spec for {file_path}")

    # Executes the user's file; the module is not registered in sys.modules.
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    if not hasattr(module, func_name):
        raise AttributeError(f"Module {module_name} has no attribute {func_name}")

    func = getattr(module, func_name)
    if not callable(func):
        raise TypeError(f"Attribute {func_name} of module {module_name} is not callable")
    return func
331
+
332
+
333
def _parse_json_dict(text: str) -> dict:
    """Convert a command line string holding a JSON object into a dict."""
    import json

    try:
        value = json.loads(text)
    except json.JSONDecodeError as err:
        raise argparse.ArgumentTypeError(f"expected a JSON object, got {text!r}") from err
    if not isinstance(value, dict):
        raise argparse.ArgumentTypeError(f"expected a JSON object, got {text!r}")
    return value


def get_args(args):
    """Parse the command line arguments of the append-column script.

    Parameters
    ----------
    args : list[str] | None
        Arguments to parse. If None, argparse reads them from ``sys.argv``.

    Returns
    -------
    argparse.Namespace
        Parsed arguments: ``input``, ``append_function``, ``output``, ``num_jets``,
        ``input_groups``, ``output_groups``, ``reader_kwargs``, ``writer_kwargs``
        and ``overwrite``.
    """
    parser = argparse.ArgumentParser(description="Append columns to an h5 file.")
    parser.add_argument("--input", "-i", type=str, required=True, help="Input h5 file")
    parser.add_argument(
        "--append_function",
        type=str,
        nargs="+",
        help="Function to append to the h5 file. Can be a list of functions.",
        required=True,
    )
    parser.add_argument("--output", type=str, help="Output h5 file")
    parser.add_argument(
        "--num_jets", type=int, default=-1, help="Number of jets to read from the input file"
    )
    parser.add_argument(
        "--input_groups",
        type=str,
        nargs="+",
        default=None,
        help="List of groups to read from the input file",
    )
    parser.add_argument(
        "--output_groups",
        type=str,
        nargs="+",
        default=None,
        help="List of groups to write to the output file",
    )
    # `type=dict` cannot work here: argparse would call dict("<string>"), which
    # always fails. The keyword arguments are therefore given as a JSON object.
    parser.add_argument(
        "--reader_kwargs",
        type=_parse_json_dict,
        default=None,
        help="Additional arguments for H5Reader, given as a JSON object",
    )
    parser.add_argument(
        "--writer_kwargs",
        type=_parse_json_dict,
        default=None,
        help="Additional arguments for H5Writer, given as a JSON object",
    )
    parser.add_argument(
        "--overwrite", action="store_true", help="Overwrite the output file if it exists"
    )

    return parser.parse_args(args)
372
+
373
+
374
def main(args=None):
    """Command line entry point: parse the arguments and append the columns.

    Parameters
    ----------
    args : list[str] | None, optional
        Arguments to parse; passed straight to ``get_args``. By default None.
    """
    parsed = get_args(args)

    # Entries given as strings are resolved to callables; anything else is used as is.
    loaded_functions = []
    for func_path in parsed.append_function:
        if isinstance(func_path, str):
            loaded_functions.append(parse_append_function(func_path))
        else:
            loaded_functions.append(func_path)

    h5_add_column(
        parsed.input,
        parsed.output,
        loaded_functions,
        num_jets=parsed.num_jets,
        input_groups=parsed.input_groups,
        output_groups=parsed.output_groups,
        reader_kwargs=parsed.reader_kwargs,
        writer_kwargs=parsed.writer_kwargs,
        overwrite=parsed.overwrite,
    )
ftag/hdf5/h5writer.py CHANGED
@@ -31,8 +31,11 @@ class H5Writer:
31
31
  Compression algorithm to use. Default is "lzf".
32
32
  precision : str | None, optional
33
33
  Precision to use. Default is None.
34
+ full_precision_vars : list[str] | None, optional
35
+ List of variables to store in full precision. Default is None.
34
36
  shuffle : bool, optional
35
37
  Whether to shuffle the jets before writing. Default is True.
38
+
36
39
  """
37
40
 
38
41
  dst: Path | str
@@ -42,6 +45,7 @@ class H5Writer:
42
45
  add_flavour_label: bool = False
43
46
  compression: str = "lzf"
44
47
  precision: str = "full"
48
+ full_precision_vars: list[str] | None = None
45
49
  shuffle: bool = True
46
50
 
47
51
  def __post_init__(self):
@@ -85,8 +89,15 @@ class H5Writer:
85
89
  dtype = np.dtype([*dtype.descr, ("flavour_label", "i4")])
86
90
 
87
91
  # adjust dtype based on specified precision
92
+ full_precision_vars = [] if self.full_precision_vars is None else self.full_precision_vars
93
+ # If the field is in full_precision_vars, use the full precision dtype
88
94
  dtype = np.dtype([
89
- (field, self.fp_dtype if np.issubdtype(dt, np.floating) else dt)
95
+ (
96
+ field,
97
+ self.fp_dtype
98
+ if field not in full_precision_vars and np.issubdtype(dt, np.floating)
99
+ else dt,
100
+ )
90
101
  for field, dt in dtype.descr
91
102
  ])
92
103
 
ftag/labels.py CHANGED
@@ -62,6 +62,9 @@ class LabelContainer:
62
62
  except KeyError as e:
63
63
  raise KeyError(f"Label '{key}' not found") from e
64
64
 
65
+ def __len__(self) -> int:
66
+ return len(self.labels.keys())
67
+
65
68
  def __getattr__(self, name) -> Label:
66
69
  return self[name]
67
70
 
@@ -120,8 +123,13 @@ class LabelContainer:
120
123
  def from_list(cls, labels: list[Label]) -> LabelContainer:
121
124
  return cls({f.name: f for f in labels})
122
125
 
123
- def backgrounds(self, label: Label, only_signals: bool = True) -> LabelContainer:
124
- bkg = [f for f in self if f.category == label.category and f != label]
126
+ def backgrounds(self, signal: Label, only_signals: bool = True) -> LabelContainer:
127
+ bkg = [f for f in self if f.category == signal.category and f != signal]
125
128
  if not only_signals:
126
129
  bkg = [f for f in bkg if f.name not in {"ujets", "qcd"}]
130
+ if len(bkg) == 0:
131
+ raise TypeError(
132
+ "No background flavour could be found in the flavours for signal "
133
+ f"flavour {signal.name}"
134
+ )
127
135
  return LabelContainer.from_list(bkg)
ftag/utils/__init__.py ADDED
@@ -0,0 +1,24 @@
1
+ from __future__ import annotations
2
+
3
+ from .logging import logger, set_log_level
4
+ from .metrics import (
5
+ calculate_efficiency,
6
+ calculate_efficiency_error,
7
+ calculate_rejection,
8
+ calculate_rejection_error,
9
+ get_discriminant,
10
+ save_divide,
11
+ weighted_percentile,
12
+ )
13
+
14
+ __all__ = [
15
+ "calculate_efficiency",
16
+ "calculate_efficiency_error",
17
+ "calculate_rejection",
18
+ "calculate_rejection_error",
19
+ "get_discriminant",
20
+ "logger",
21
+ "save_divide",
22
+ "set_log_level",
23
+ "weighted_percentile",
24
+ ]
ftag/utils/logging.py ADDED
@@ -0,0 +1,123 @@
1
+ """Configuration for logger of atlas-ftag-tools."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ import os
7
+ from typing import ClassVar
8
+
9
+
10
class CustomFormatter(logging.Formatter):
    """
    Logging Formatter that colours each message according to its level, based on
    https://stackoverflow.com/questions/384076/how-can-i-color-python-logging-output.

    DEBUG, ERROR and CRITICAL records use the verbose layout (timestamp, file name
    and line number); INFO and WARNING records use the short layout. Records with
    any other level fall back to the default layout of ``logging.Formatter``.
    """

    # ANSI escape sequences used to colour the output.
    grey = "\x1b[38;21m"
    yellow = "\x1b[33;21m"
    green = "\x1b[32;21m"
    red = "\x1b[31;21m"
    bold_red = "\x1b[31;1m"
    reset = "\x1b[0m"
    # Verbose layout including timestamp and source location.
    debugformat = "%(asctime)s - %(levelname)s:%(name)s: %(message)s (%(filename)s:%(lineno)d)"
    # Short layout; despite its name this is a message layout, not a date format.
    date_format = "%(levelname)s:%(name)s: %(message)s"

    # Format string used for each standard logging level.
    formats: ClassVar = {
        logging.DEBUG: grey + debugformat + reset,
        logging.INFO: green + date_format + reset,
        logging.WARNING: yellow + date_format + reset,
        logging.ERROR: red + debugformat + reset,
        logging.CRITICAL: bold_red + debugformat + reset,
    }

    def format(self, record):
        """Return the record rendered with the layout registered for its level."""
        log_fmt = self.formats.get(record.levelno)
        # Reuse one Formatter per format string instead of building a new one for
        # every record. Keying on the string keeps subclasses that override
        # `formats` correct.
        cache = self.__dict__.setdefault("_formatter_cache", {})
        formatter = cache.get(log_fmt)
        if formatter is None:
            formatter = cache[log_fmt] = logging.Formatter(log_fmt)
        return formatter.format(record)
38
+
39
+
40
def get_log_level(
    level: str,
) -> int:
    """Get logging levels with string key.

    The lookup ignores case and surrounding whitespace, so ``"info"`` and
    ``" INFO "`` are accepted as well as ``"INFO"``.

    Parameters
    ----------
    level : str
        Log level as string. One of CRITICAL, ERROR, WARNING, INFO, DEBUG, NOTSET.

    Returns
    -------
    int
        The numeric level constant of the ``logging`` module.

    Raises
    ------
    ValueError
        If non-valid option is given
    """
    log_levels = {
        "CRITICAL": logging.CRITICAL,
        "ERROR": logging.ERROR,
        "WARNING": logging.WARNING,
        "INFO": logging.INFO,
        "DEBUG": logging.DEBUG,
        "NOTSET": logging.NOTSET,
    }
    # Normalise so that values such as LOG_LEVEL=info do not get rejected.
    key = level.strip().upper() if isinstance(level, str) else level
    if key not in log_levels:
        raise ValueError(f"The 'DebugLevel' option {level} is not valid.")
    return log_levels[key]
71
+
72
+
73
def initialise_logger(
    log_level: str | None = None,
):
    """Configure and return the "atlas-ftag-tools" logger.

    The function can be called more than once: a stream handler carrying the
    ``CustomFormatter`` is attached only if the logger does not have one yet,
    otherwise the level of the existing handler(s) is updated. This avoids every
    message being printed once per call.

    Parameters
    ----------
    log_level : str, optional
        Logging level defining the verbose level. Accepted values are:
        CRITICAL, ERROR, WARNING, INFO, DEBUG, NOTSET, by default None
        If the log_level is not set, the ``LOG_LEVEL`` environment variable is
        used, and INFO if that is not set either.

    Returns
    -------
    logger
        logger object with new level set
    """
    retrieved_log_level = get_log_level(
        os.environ.get("LOG_LEVEL", "INFO") if log_level is None else log_level
    )

    tools_logger = logging.getLogger("atlas-ftag-tools")
    tools_logger.setLevel(retrieved_log_level)

    own_handlers = [
        handler
        for handler in tools_logger.handlers
        if isinstance(handler.formatter, CustomFormatter)
    ]
    if own_handlers:
        for handler in own_handlers:
            handler.setLevel(retrieved_log_level)
    else:
        ch_handler = logging.StreamHandler()
        ch_handler.setLevel(retrieved_log_level)
        ch_handler.setFormatter(CustomFormatter())
        tools_logger.addHandler(ch_handler)

    # Messages are handled here only and not passed on to ancestor loggers.
    tools_logger.propagate = False
    return tools_logger
103
+
104
+
105
def set_log_level(
    tools_logger,
    log_level: str,
):
    """Apply a new log level to a logger and to all handlers attached to it.

    Parameters
    ----------
    tools_logger : logger
        logger object
    log_level : str
        Logging level corresponding CRITICAL, ERROR, WARNING, INFO, DEBUG, NOTSET
    """
    # Translate the name once and reuse the numeric level everywhere.
    numeric_level = get_log_level(log_level)
    tools_logger.setLevel(numeric_level)
    for attached_handler in tools_logger.handlers:
        attached_handler.setLevel(numeric_level)
121
+
122
+
123
# Package-wide logger, created once when this module is first imported.
+ logger = initialise_logger()