atlas-ftag-tools 0.2.9__py3-none-any.whl → 0.2.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- atlas_ftag_tools-0.2.11.dist-info/METADATA +53 -0
- atlas_ftag_tools-0.2.11.dist-info/RECORD +32 -0
- {atlas_ftag_tools-0.2.9.dist-info → atlas_ftag_tools-0.2.11.dist-info}/WHEEL +1 -1
- {atlas_ftag_tools-0.2.9.dist-info → atlas_ftag_tools-0.2.11.dist-info}/entry_points.txt +2 -1
- atlas_ftag_tools-0.2.11.dist-info/licenses/LICENSE +201 -0
- ftag/__init__.py +13 -12
- ftag/flavours.yaml +33 -12
- ftag/fraction_optimization.py +184 -0
- ftag/hdf5/__init__.py +5 -3
- ftag/hdf5/h5add_col.py +391 -0
- ftag/hdf5/h5writer.py +12 -1
- ftag/labels.py +10 -2
- ftag/utils/__init__.py +24 -0
- ftag/utils/logging.py +123 -0
- ftag/utils/metrics.py +431 -0
- ftag/vds.py +39 -4
- ftag/{wps/working_points.py → working_points.py} +1 -1
- atlas_ftag_tools-0.2.9.dist-info/METADATA +0 -150
- atlas_ftag_tools-0.2.9.dist-info/RECORD +0 -28
- ftag/wps/__init__.py +0 -0
- ftag/wps/discriminant.py +0 -84
- {atlas_ftag_tools-0.2.9.dist-info → atlas_ftag_tools-0.2.11.dist-info}/top_level.txt +0 -0
ftag/hdf5/h5add_col.py
ADDED
@@ -0,0 +1,391 @@
|
|
1
|
+
# Utils to take an input h5 file, and append one or more columns to it
|
2
|
+
from __future__ import annotations
|
3
|
+
|
4
|
+
import argparse
|
5
|
+
import importlib.util
|
6
|
+
from pathlib import Path
|
7
|
+
from typing import Callable
|
8
|
+
|
9
|
+
import h5py
|
10
|
+
import numpy as np
|
11
|
+
|
12
|
+
from ftag.hdf5.h5reader import H5Reader
|
13
|
+
from ftag.hdf5.h5writer import H5Writer
|
14
|
+
|
15
|
+
|
16
|
+
def merge_dicts(dicts: list[dict[str, dict[str, np.ndarray]]]) -> dict[str, dict[str, np.ndarray]]:
    """Combine several ``{group: {variable: array}}`` mappings into one.

    Groups appearing in more than one input are merged, so the result holds the
    union of their variables. Group and variable order follow first appearance,
    and the arrays themselves are stored as-is (not copied).

    E.g.

        dict1 = {"jets": {"pt": np.array([1, 2, 3]), "eta": np.array([4, 5, 6])}}
        dict2 = {"jets": {"phi": np.array([7, 8, 9]), "energy": np.array([10, 11, 12])}}

        merge_dicts([dict1, dict2]) -> {
            "jets": {
                "pt": np.array([1, 2, 3]),
                "eta": np.array([4, 5, 6]),
                "phi": np.array([7, 8, 9]),
                "energy": np.array([10, 11, 12]),
            }
        }

    Parameters
    ----------
    dicts : list[dict[str, dict[str, np.ndarray]]]
        Dictionaries to merge. Each maps a group name to a mapping of
        variable name -> array.

    Returns
    -------
    dict[str, dict[str, np.ndarray]]
        A new dictionary of the same two-level form containing every variable
        of every input.

    Raises
    ------
    ValueError
        If the same variable appears in the same group in more than one input.
    """
    result: dict[str, dict[str, np.ndarray]] = {}
    for mapping in dicts:
        for group_name, columns in mapping.items():
            # Create the group on first sight (even if it brings no variables).
            target = result.setdefault(group_name, {})
            for column_name, values in columns.items():
                if column_name in target:
                    raise ValueError(
                        f"Variable {column_name} already exists in group {group_name}."
                    )
                target[column_name] = values
    return result
|
91
|
+
|
92
|
+
|
93
|
+
def get_shape(num_jets: int, batch: dict[str, np.ndarray]) -> dict[str, tuple[int, ...]]:
    """Compute the full-size dataset shapes to hand to the H5Writer.

    For every array in ``batch`` the leading axis is replaced by ``num_jets``
    while all trailing dimensions are kept, so a 1D array maps to
    ``(num_jets,)`` and an ``(n, k)`` array maps to ``(num_jets, k)``.

    Parameters
    ----------
    num_jets : int
        Total number of jets that will be written.
    batch : dict[str, np.ndarray]
        One batch of data, keyed by group name.

    Returns
    -------
    dict[str, tuple[int, ...]]
        Output shape for each key of ``batch``.
    """
    return {name: (num_jets, *array.shape[1:]) for name, array in batch.items()}
|
116
|
+
|
117
|
+
|
118
|
+
def get_all_groups(file: Path | str) -> dict[str, None]:
    """Collect the names of all top-level entries of an h5 file.

    Parameters
    ----------
    file : Path | str
        Path to the h5 file.

    Returns
    -------
    dict[str, None]
        Every top-level name in the file mapped to ``None``, the form intended
        to be passed to ``H5Reader.stream`` so that all groups are read.
    """
    with h5py.File(file, "r") as h5_file:
        names = list(h5_file.keys())
    return {name: None for name in names}
|
135
|
+
|
136
|
+
|
137
|
+
def _default_output_path(input_file: Path) -> Path:
    """Return ``<stem>_additional<suffix>`` next to ``input_file``."""
    # Built from stem/suffix instead of str.replace(".h5", ...): inputs without a
    # literal ".h5" (e.g. "x.hdf5") would otherwise map back onto the input itself.
    return input_file.with_name(f"{input_file.stem}_additional{input_file.suffix}")


def _append_to_batch(
    batch: dict[str, np.ndarray],
    to_append: dict[str, dict[str, np.ndarray]],
    output_groups: list[str],
) -> dict[str, np.ndarray]:
    """Validate new columns and return the output groups of ``batch`` with them appended."""
    # ``numpy.lib.recfunctions`` is not guaranteed to be loaded by ``import numpy``,
    # so import it explicitly rather than relying on ``np.lib.recfunctions``.
    from numpy.lib import recfunctions as rfn

    for group, new_columns in to_append.items():
        if group not in output_groups:
            raise ValueError(f"Trying to output to {group} but only {output_groups} are allowed")
        existing = batch[group]
        for name, values in new_columns.items():
            if name in existing.dtype.names:
                raise ValueError(
                    f"Trying to append {name} to {group} but it already exists in batch"
                )
            if values.shape != existing.shape:
                raise ValueError(
                    f"Trying to append {name} to {group} but the shape is not correct"
                )

    to_write: dict[str, np.ndarray] = {}
    for group, array in batch.items():
        if group not in output_groups:
            continue
        new_columns = to_append.get(group)
        if not new_columns:
            to_write[group] = array
            continue
        combined = rfn.append_fields(
            array,
            list(new_columns.keys()),
            list(new_columns.values()),
            usemask=False,
        )
        # append_fields flattens its inputs; restore the original shape so that
        # multi-dimensional groups keep their trailing axes.
        to_write[group] = combined.reshape(array.shape)
    return to_write


def h5_add_column(
    input_file: str | Path,
    output_file: str | Path | None,
    append_function: Callable | list[Callable],
    num_jets: int = -1,
    input_groups: list[str] | None = None,
    output_groups: list[str] | None = None,
    reader_kwargs: dict | None = None,
    writer_kwargs: dict | None = None,
    overwrite: bool = False,
) -> None:
    """Appends one or more columns to one or more groups in an h5 file.

    Parameters
    ----------
    input_file : str | Path
        Input h5 file to read from.
    output_file : str | Path | None
        Output h5 file to write to. If None, ``<input stem>_additional<input suffix>``
        next to the input file is used.
    append_function : callable | list[callable]
        A function, or list of functions, which take a batch from H5Reader and return a
        dictionary of the form:
            {
                group1 : {
                    new_column1 : data,
                    new_column2 : data,
                },
                group2 : {
                    new_column3 : data,
                    new_column4 : data,
                },
                ...
            }
        Each ``data`` array must have the same shape as the batch array of its group.
    num_jets : int, optional
        Number of jets to read from the input file. If -1, reads all jets. By default -1.
    input_groups : list[str] | None, optional
        List of groups to read from the input file. If None, reads all groups. By default None.
    output_groups : list[str] | None, optional
        List of groups to write to the output file. If None, writes all groups. By default None.
        Note that this is a subset of the input groups, and must include all groups that the
        append functions wish to write to.
    reader_kwargs : dict, optional
        Additional arguments to pass to the H5Reader. By default None. Not modified.
    writer_kwargs : dict, optional
        Additional arguments to pass to the H5Writer. By default None. Not modified.
        ``precision`` defaults to ``"full"`` unless given here.
    overwrite : bool, optional
        If True, an existing output file is overwritten. If False (default), a
        FileExistsError is raised when the output file (including the default one) exists.

    Raises
    ------
    FileNotFoundError
        If the input file does not exist.
    FileExistsError
        If the output file exists and overwrite is False.
    ValueError
        If the output file is the input file, a new variable already exists, a shape is
        incorrect, or an output group is not among the input groups.
    """
    input_file = Path(input_file)
    if not input_file.exists():
        raise FileNotFoundError(f"Input file {input_file} does not exist.")

    # Resolve the default *before* the existence check so it is protected too.
    output_file = _default_output_path(input_file) if output_file is None else Path(output_file)
    if output_file.exists() and not overwrite:
        raise FileExistsError(
            f"Output file {output_file} already exists. Please choose a different name."
        )
    if output_file.resolve() == input_file.resolve():
        # Writing would truncate the file while it is still being read.
        raise ValueError(f"Output file {output_file} must differ from the input file.")

    # Copy so the caller's dictionaries are never mutated.
    reader_kwargs = dict(reader_kwargs or {})
    writer_kwargs = {"precision": "full", **(writer_kwargs or {})}
    append_functions = append_function if isinstance(append_function, list) else [append_function]

    reader = H5Reader(input_file, shuffle=False, **reader_kwargs)
    njets = reader.num_jets if num_jets == -1 else num_jets

    input_variables = (
        get_all_groups(input_file) if input_groups is None else dict.fromkeys(input_groups)
    )
    if output_groups is None:
        output_groups = list(input_variables.keys())
    if any(group not in input_variables for group in output_groups):
        raise ValueError(
            f"Output groups {output_groups} not in input groups {input_variables.keys()}"
        )

    # Ceiling division; only used for progress reporting.
    num_batches = max(1, -(-njets // reader.batch_size))
    writer = None
    for i, batch in enumerate(reader.stream(input_variables, num_jets=njets)):
        if (i + 1) % 10 == 0:
            print(f"Processing batch {i + 1}/{num_batches} ({(i + 1) / num_batches * 100:.2f}%)")

        to_append = merge_dicts([af(batch) for af in append_functions])
        to_write = _append_to_batch(batch, to_append, output_groups)

        if writer is None:
            # The output layout is only known once the first combined batch exists.
            writer = H5Writer(
                output_file,
                dtypes={key: array.dtype for key, array in to_write.items()},
                shapes=get_shape(njets, to_write),
                shuffle=False,
                **writer_kwargs,
            )
        writer.write(to_write)

    if writer is not None:
        writer.close()
|
279
|
+
|
280
|
+
|
281
|
+
def parse_append_function(func_path: str | Path) -> Callable:
    """Attempts to load the function specified by func_path.

    The function should be specified as 'path/to/file.py:function_name'. The text
    after the *last* colon is the function name, so file paths that themselves
    contain colons (e.g. Windows drive letters) are supported.

    Note: the target file is executed as a Python module, so only pass trusted files.

    Parameters
    ----------
    func_path : str | Path
        Path to the function to load. Should be of the form 'path/to/file.py:function_name'.

    Returns
    -------
    Callable
        The function specified by func_path.

    Raises
    ------
    ValueError
        If the function path is not of the form 'path/to/file.py:function_name'.
    FileNotFoundError
        If the file does not exist.
    ImportError
        If the file cannot be imported.
    AttributeError
        If the function does not exist in the file.
    """
    func_path = str(func_path)
    file_str, separator, func_name = func_path.rpartition(":")
    if not separator or not file_str or not func_name.isidentifier():
        raise ValueError(
            "Function should be specified as 'path/to/file.py:function_name', "
            f"got {func_path!r}"
        )

    file_path = Path(file_str).resolve()
    if not file_path.is_file():
        raise FileNotFoundError(f"No such file: {file_path}")

    module_name = file_path.stem  # Just the filename without extension

    spec = importlib.util.spec_from_file_location(module_name, str(file_path))
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load spec for {file_path}")

    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    if not hasattr(module, func_name):
        raise AttributeError(f"Module {module_name} has no attribute {func_name}")

    return getattr(module, func_name)
|
331
|
+
|
332
|
+
|
333
|
+
def _json_dict(value: str) -> dict:
    """argparse ``type=`` converter: parse a JSON object given on the command line."""
    import json

    try:
        parsed = json.loads(value)
    except json.JSONDecodeError as err:
        raise argparse.ArgumentTypeError(f"invalid JSON {value!r}: {err}") from err
    if not isinstance(parsed, dict):
        raise argparse.ArgumentTypeError(f"expected a JSON object, got {value!r}")
    return parsed


def get_args(args):
    """Parse command-line arguments for the column-appending tool.

    Parameters
    ----------
    args : list[str] | None
        Arguments to parse; None means ``sys.argv[1:]`` (argparse default).

    Returns
    -------
    argparse.Namespace
        Parsed options. ``reader_kwargs``/``writer_kwargs`` are dicts parsed from
        JSON objects (or None when not given).
    """
    parser = argparse.ArgumentParser(description="Append columns to an h5 file.")
    parser.add_argument("--input", "-i", type=str, required=True, help="Input h5 file")
    parser.add_argument(
        "--append_function",
        type=str,
        nargs="+",
        help="Function to append to the h5 file. Can be a list of functions.",
        required=True,
    )
    parser.add_argument("--output", type=str, help="Output h5 file")
    parser.add_argument(
        "--num_jets", type=int, default=-1, help="Number of jets to read from the input file"
    )
    parser.add_argument(
        "--input_groups",
        type=str,
        nargs="+",
        default=None,
        help="List of groups to read from the input file",
    )
    parser.add_argument(
        "--output_groups",
        type=str,
        nargs="+",
        default=None,
        help="List of groups to write to the output file",
    )
    # `type=dict` can never parse a command-line string, so accept JSON objects instead.
    parser.add_argument(
        "--reader_kwargs",
        type=_json_dict,
        default=None,
        help='Additional arguments for H5Reader as a JSON object, e.g. \'{"batch_size": 100000}\'',
    )
    parser.add_argument(
        "--writer_kwargs",
        type=_json_dict,
        default=None,
        help='Additional arguments for H5Writer as a JSON object, e.g. \'{"precision": "half"}\'',
    )
    parser.add_argument(
        "--overwrite", action="store_true", help="Overwrite the output file if it exists"
    )

    return parser.parse_args(args)
|
372
|
+
|
373
|
+
|
374
|
+
def main(args=None):
    """Command-line entry point: parse ``args`` and run ``h5_add_column``.

    Parameters
    ----------
    args : list[str] | None, optional
        Command-line arguments; None lets argparse read ``sys.argv``.
    """
    options = get_args(args)

    # String specs ("file.py:func") are loaded from disk; anything else is used as-is.
    functions = []
    for spec in options.append_function:
        if isinstance(spec, str):
            spec = parse_append_function(spec)
        functions.append(spec)

    h5_add_column(
        options.input,
        options.output,
        functions,
        num_jets=options.num_jets,
        input_groups=options.input_groups,
        output_groups=options.output_groups,
        reader_kwargs=options.reader_kwargs,
        writer_kwargs=options.writer_kwargs,
        overwrite=options.overwrite,
    )
|
ftag/hdf5/h5writer.py
CHANGED
@@ -31,8 +31,11 @@ class H5Writer:
|
|
31
31
|
Compression algorithm to use. Default is "lzf".
|
32
32
|
precision : str | None, optional
|
33
33
|
Precision to use. Default is None.
|
34
|
+
full_precision_vars : list[str] | None, optional
|
35
|
+
List of variables to store in full precision. Default is None.
|
34
36
|
shuffle : bool, optional
|
35
37
|
Whether to shuffle the jets before writing. Default is True.
|
38
|
+
|
36
39
|
"""
|
37
40
|
|
38
41
|
dst: Path | str
|
@@ -42,6 +45,7 @@ class H5Writer:
|
|
42
45
|
add_flavour_label: bool = False
|
43
46
|
compression: str = "lzf"
|
44
47
|
precision: str = "full"
|
48
|
+
full_precision_vars: list[str] | None = None
|
45
49
|
shuffle: bool = True
|
46
50
|
|
47
51
|
def __post_init__(self):
|
@@ -85,8 +89,15 @@ class H5Writer:
|
|
85
89
|
dtype = np.dtype([*dtype.descr, ("flavour_label", "i4")])
|
86
90
|
|
87
91
|
# adjust dtype based on specified precision
|
92
|
+
full_precision_vars = [] if self.full_precision_vars is None else self.full_precision_vars
|
93
|
+
# If the field is in full_precision_vars, use the full precision dtype
|
88
94
|
dtype = np.dtype([
|
89
|
-
(
|
95
|
+
(
|
96
|
+
field,
|
97
|
+
self.fp_dtype
|
98
|
+
if field not in full_precision_vars and np.issubdtype(dt, np.floating)
|
99
|
+
else dt,
|
100
|
+
)
|
90
101
|
for field, dt in dtype.descr
|
91
102
|
])
|
92
103
|
|
ftag/labels.py
CHANGED
@@ -62,6 +62,9 @@ class LabelContainer:
|
|
62
62
|
except KeyError as e:
|
63
63
|
raise KeyError(f"Label '{key}' not found") from e
|
64
64
|
|
65
|
+
def __len__(self) -> int:
|
66
|
+
return len(self.labels.keys())
|
67
|
+
|
65
68
|
def __getattr__(self, name) -> Label:
|
66
69
|
return self[name]
|
67
70
|
|
@@ -120,8 +123,13 @@ class LabelContainer:
|
|
120
123
|
def from_list(cls, labels: list[Label]) -> LabelContainer:
|
121
124
|
return cls({f.name: f for f in labels})
|
122
125
|
|
123
|
-
def backgrounds(self,
|
124
|
-
bkg = [f for f in self if f.category ==
|
126
|
+
def backgrounds(self, signal: Label, only_signals: bool = True) -> LabelContainer:
|
127
|
+
bkg = [f for f in self if f.category == signal.category and f != signal]
|
125
128
|
if not only_signals:
|
126
129
|
bkg = [f for f in bkg if f.name not in {"ujets", "qcd"}]
|
130
|
+
if len(bkg) == 0:
|
131
|
+
raise TypeError(
|
132
|
+
"No background flavour could be found in the flavours for signal "
|
133
|
+
f"flavour {signal.name}"
|
134
|
+
)
|
127
135
|
return LabelContainer.from_list(bkg)
|
ftag/utils/__init__.py
ADDED
@@ -0,0 +1,24 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from .logging import logger, set_log_level
|
4
|
+
from .metrics import (
|
5
|
+
calculate_efficiency,
|
6
|
+
calculate_efficiency_error,
|
7
|
+
calculate_rejection,
|
8
|
+
calculate_rejection_error,
|
9
|
+
get_discriminant,
|
10
|
+
save_divide,
|
11
|
+
weighted_percentile,
|
12
|
+
)
|
13
|
+
|
14
|
+
__all__ = [
|
15
|
+
"calculate_efficiency",
|
16
|
+
"calculate_efficiency_error",
|
17
|
+
"calculate_rejection",
|
18
|
+
"calculate_rejection_error",
|
19
|
+
"get_discriminant",
|
20
|
+
"logger",
|
21
|
+
"save_divide",
|
22
|
+
"set_log_level",
|
23
|
+
"weighted_percentile",
|
24
|
+
]
|
ftag/utils/logging.py
ADDED
@@ -0,0 +1,123 @@
|
|
1
|
+
"""Configuration for logger of atlas-ftag-tools."""
|
2
|
+
|
3
|
+
from __future__ import annotations
|
4
|
+
|
5
|
+
import logging
|
6
|
+
import os
|
7
|
+
from typing import ClassVar
|
8
|
+
|
9
|
+
|
10
|
+
class CustomFormatter(logging.Formatter):
    """Colour log records according to their level.

    Based on the implementation from
    https://stackoverflow.com/questions/384076/how-can-i-color-python-logging-output.
    DEBUG, ERROR and CRITICAL records use the verbose layout (timestamp plus
    source file and line); INFO and WARNING use the compact ``LEVEL:name: message``
    layout. Records at any other level use the plain ``logging.Formatter``
    default layout, without colour.
    """

    # ANSI escape sequences
    grey = "\x1b[38;21m"
    yellow = "\x1b[33;21m"
    green = "\x1b[32;21m"
    red = "\x1b[31;21m"
    bold_red = "\x1b[31;1m"
    reset = "\x1b[0m"
    # Verbose layout including timestamp and source location
    debugformat = "%(asctime)s - %(levelname)s:%(name)s: %(message)s (%(filename)s:%(lineno)d)"
    # NOTE: despite its name this is the compact message layout, not a date format
    date_format = "%(levelname)s:%(name)s: %(message)s"

    formats: ClassVar = {
        logging.DEBUG: f"{grey}{debugformat}{reset}",
        logging.INFO: f"{green}{date_format}{reset}",
        logging.WARNING: f"{yellow}{date_format}{reset}",
        logging.ERROR: f"{red}{debugformat}{reset}",
        logging.CRITICAL: f"{bold_red}{debugformat}{reset}",
    }

    def format(self, record):
        layout = self.formats.get(record.levelno)
        # A None layout (unknown level) makes logging.Formatter fall back to "%(message)s".
        return logging.Formatter(layout).format(record)
|
38
|
+
|
39
|
+
|
40
|
+
def get_log_level(
    level: str,
) -> int:
    """Get logging levels with string key.

    Parameters
    ----------
    level : str
        Log level as string. Matching ignores case and surrounding whitespace,
        so e.g. ``"debug"`` and ``" INFO "`` are accepted.

    Returns
    -------
    int
        The corresponding ``logging`` level constant.

    Raises
    ------
    ValueError
        If non-valid option is given
    """
    log_levels = {
        "CRITICAL": logging.CRITICAL,
        "ERROR": logging.ERROR,
        "WARNING": logging.WARNING,
        "INFO": logging.INFO,
        "DEBUG": logging.DEBUG,
        "NOTSET": logging.NOTSET,
    }
    # Normalise so that e.g. LOG_LEVEL=debug in the environment works.
    key = level.strip().upper() if isinstance(level, str) else level
    if key not in log_levels:
        raise ValueError(
            f"The 'DebugLevel' option {level} is not valid. Choose from {list(log_levels)}."
        )
    return log_levels[key]
|
71
|
+
|
72
|
+
|
73
|
+
def initialise_logger(
    log_level: str | None = None,
) -> logging.Logger:
    """Initialise the ``atlas-ftag-tools`` logger.

    Parameters
    ----------
    log_level : str, optional
        Logging level defining the verbose level. Accepted values are:
        CRITICAL, ERROR, WARNING, INFO, DEBUG, NOTSET, by default None.
        If None, the ``LOG_LEVEL`` environment variable is used, falling back to INFO.

    Returns
    -------
    logging.Logger
        Logger object with new level set.

    Notes
    -----
    Safe to call repeatedly: the coloured stream handler is attached on the first
    call and re-used afterwards, so records are not emitted more than once.
    """
    retrieved_log_level = get_log_level(
        os.environ.get("LOG_LEVEL", "INFO") if log_level is None else log_level
    )

    tools_logger = logging.getLogger("atlas-ftag-tools")
    tools_logger.setLevel(retrieved_log_level)

    # getLogger returns the same object each time; re-use a handler added by an
    # earlier call instead of stacking another one (which duplicates every line).
    ch_handler = next(
        (h for h in tools_logger.handlers if isinstance(h.formatter, CustomFormatter)),
        None,
    )
    if ch_handler is None:
        ch_handler = logging.StreamHandler()
        ch_handler.setFormatter(CustomFormatter())
        tools_logger.addHandler(ch_handler)
    ch_handler.setLevel(retrieved_log_level)

    tools_logger.propagate = False
    return tools_logger
|
103
|
+
|
104
|
+
|
105
|
+
def set_log_level(
    tools_logger,
    log_level: str,
):
    """Change the level of a logger and of every handler attached to it.

    Parameters
    ----------
    tools_logger : logging.Logger
        Logger whose level (and whose handlers' levels) should change.
    log_level : str
        Level name accepted by ``get_log_level`` (CRITICAL, ERROR, WARNING,
        INFO, DEBUG, NOTSET).

    Raises
    ------
    ValueError
        Propagated from ``get_log_level`` for an unrecognised level name; nothing
        is changed in that case.
    """
    level = get_log_level(log_level)
    tools_logger.setLevel(level)
    for attached in tools_logger.handlers:
        attached.setLevel(level)
|
121
|
+
|
122
|
+
|
123
|
+
# Package-wide logger, configured at import time (level from $LOG_LEVEL, default INFO).
logger = initialise_logger(log_level=None)
|