data-manipulation-utilities 0.2.7__py3-none-any.whl → 0.2.8.dev720__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {data_manipulation_utilities-0.2.7.dist-info → data_manipulation_utilities-0.2.8.dev720.dist-info}/METADATA +669 -42
- data_manipulation_utilities-0.2.8.dev720.dist-info/RECORD +45 -0
- {data_manipulation_utilities-0.2.7.dist-info → data_manipulation_utilities-0.2.8.dev720.dist-info}/WHEEL +1 -2
- data_manipulation_utilities-0.2.8.dev720.dist-info/entry_points.txt +8 -0
- dmu/generic/hashing.py +34 -8
- dmu/generic/utilities.py +164 -11
- dmu/logging/log_store.py +34 -2
- dmu/logging/messages.py +96 -0
- dmu/ml/cv_classifier.py +3 -3
- dmu/ml/cv_diagnostics.py +3 -0
- dmu/ml/cv_performance.py +58 -0
- dmu/ml/cv_predict.py +149 -46
- dmu/ml/train_mva.py +482 -100
- dmu/ml/utilities.py +29 -10
- dmu/pdataframe/utilities.py +28 -3
- dmu/plotting/fwhm.py +2 -2
- dmu/plotting/matrix.py +1 -1
- dmu/plotting/plotter.py +23 -3
- dmu/plotting/plotter_1d.py +96 -32
- dmu/plotting/plotter_2d.py +5 -0
- dmu/rdataframe/utilities.py +54 -3
- dmu/rfile/ddfgetter.py +102 -0
- dmu/stats/fit_stats.py +129 -0
- dmu/stats/fitter.py +55 -22
- dmu/stats/gof_calculator.py +7 -0
- dmu/stats/model_factory.py +153 -62
- dmu/stats/parameters.py +100 -0
- dmu/stats/utilities.py +443 -12
- dmu/stats/wdata.py +187 -0
- dmu/stats/zfit.py +17 -0
- dmu/stats/zfit_plotter.py +147 -36
- dmu/testing/utilities.py +102 -24
- dmu/workflow/__init__.py +0 -0
- dmu/workflow/cache.py +266 -0
- data_manipulation_utilities-0.2.7.data/scripts/publish +0 -89
- data_manipulation_utilities-0.2.7.dist-info/RECORD +0 -69
- data_manipulation_utilities-0.2.7.dist-info/entry_points.txt +0 -6
- data_manipulation_utilities-0.2.7.dist-info/top_level.txt +0 -3
- dmu_data/ml/tests/diagnostics_from_file.yaml +0 -13
- dmu_data/ml/tests/diagnostics_from_model.yaml +0 -10
- dmu_data/ml/tests/diagnostics_multiple_methods.yaml +0 -10
- dmu_data/ml/tests/diagnostics_overlay.yaml +0 -33
- dmu_data/ml/tests/train_mva.yaml +0 -58
- dmu_data/ml/tests/train_mva_with_diagnostics.yaml +0 -82
- dmu_data/plotting/tests/2d.yaml +0 -24
- dmu_data/plotting/tests/fig_size.yaml +0 -13
- dmu_data/plotting/tests/high_stat.yaml +0 -22
- dmu_data/plotting/tests/legend.yaml +0 -12
- dmu_data/plotting/tests/name.yaml +0 -14
- dmu_data/plotting/tests/no_bounds.yaml +0 -12
- dmu_data/plotting/tests/normalized.yaml +0 -9
- dmu_data/plotting/tests/plug_fwhm.yaml +0 -24
- dmu_data/plotting/tests/plug_stats.yaml +0 -19
- dmu_data/plotting/tests/simple.yaml +0 -9
- dmu_data/plotting/tests/stats.yaml +0 -9
- dmu_data/plotting/tests/styling.yaml +0 -11
- dmu_data/plotting/tests/title.yaml +0 -14
- dmu_data/plotting/tests/weights.yaml +0 -13
- dmu_data/text/transform.toml +0 -4
- dmu_data/text/transform.txt +0 -6
- dmu_data/text/transform_set.toml +0 -8
- dmu_data/text/transform_set.txt +0 -6
- dmu_data/text/transform_trf.txt +0 -12
- dmu_scripts/git/publish +0 -89
- dmu_scripts/physics/check_truth.py +0 -121
- dmu_scripts/rfile/compare_root_files.py +0 -299
- dmu_scripts/rfile/print_trees.py +0 -35
- dmu_scripts/ssh/coned.py +0 -168
- dmu_scripts/text/transform_text.py +0 -46
- {dmu_data → dmu}/__init__.py +0 -0
@@ -0,0 +1,45 @@
|
|
1
|
+
dmu/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
|
+
dmu/arrays/utilities.py,sha256=PKoYyybPptA2aU-V3KLnJXBudWxTXu4x1uGdIMQ49HY,1722
|
3
|
+
dmu/generic/hashing.py,sha256=QR5Gbv6-ANvi5hL232UNMrw9DONpU27BWTynXGxQLGU,1806
|
4
|
+
dmu/generic/utilities.py,sha256=0tT93vF_x0q8STRrTD0GvBEpALz-mqE-vJyen4zWCO8,6861
|
5
|
+
dmu/generic/version_management.py,sha256=j0ImlAq6SVNjTh3xRsF6G7DSoyr1w8kTRY84dNriGRE,3750
|
6
|
+
dmu/logging/log_store.py,sha256=eRSy8Y4fuiDFJK02Z6fq67XQzOrhQ7GMr2LvvJQbJ40,5172
|
7
|
+
dmu/logging/messages.py,sha256=Oj3O5EO2KOPtffyVq2P7RPzjpoXtxZ6yXO5HwTftVcM,2903
|
8
|
+
dmu/ml/cv_classifier.py,sha256=6rjezMahwL-WzLGKU-fzMzNxJZAGbM7YAbhaZVcJ3F0,4258
|
9
|
+
dmu/ml/cv_diagnostics.py,sha256=PLh41mSVE8Kagp9KcuRDN_7tDL9MjPxQzuewY8jDnNo,7600
|
10
|
+
dmu/ml/cv_performance.py,sha256=q9sLxIx7GP-dand3tnhHCBJnT6xqssNdRYv_TVjYWUM,1910
|
11
|
+
dmu/ml/cv_predict.py,sha256=0sc_OqwOewKvipcMyi3QqkgG30nkpZZjE-SOhHWHMd0,10778
|
12
|
+
dmu/ml/train_mva.py,sha256=7KAFX_zOx8MGbYx62U81JbdBkrZvqclSSkgmYvWX-60,34861
|
13
|
+
dmu/ml/utilities.py,sha256=A9j3tBh-jfaFdwwLUleo1QnttfawN7XDiQRh4VTvqVY,4597
|
14
|
+
dmu/pdataframe/utilities.py,sha256=xl6iLVKUccqVXYjuHsDUZ6UrCKQPw1k8D-f6407Yq30,2742
|
15
|
+
dmu/plotting/fwhm.py,sha256=4e8n6624pxWLcOOtayCQ_hDSSMKU21-3UsdmbkX1ojk,1949
|
16
|
+
dmu/plotting/matrix.py,sha256=s_5W8O3yXF3u8OX3f4J4hCoxIVZt1TF8S-qJsFBh2Go,5005
|
17
|
+
dmu/plotting/plotter.py,sha256=oc_n9ug0JPaQZycrW_TJkgNxjr0LHNrVJcijqmiLUR4,8136
|
18
|
+
dmu/plotting/plotter_1d.py,sha256=Kyoyh-QyZLXXqX19wqEDUWCD1nJEvEonGp9nlgEaoZE,10936
|
19
|
+
dmu/plotting/plotter_2d.py,sha256=dXC-7Rsquibe5cn7622ryoKpuv7KCAmouIIXwQ_VEFM,3172
|
20
|
+
dmu/plotting/utilities.py,sha256=SI9dvtZq2gr-PXVz71KE4o0i09rZOKgqJKD1jzf6KXk,1167
|
21
|
+
dmu/rdataframe/atr_mgr.py,sha256=FdhaQWVpsm4OOe1IRbm7rfrq8VenTNdORyI-lZ2Bs1M,2386
|
22
|
+
dmu/rdataframe/utilities.py,sha256=cY1Na8HbJ7kB2dwmBagRdsRyCA4ZT_vyIU86ewREj2Y,5322
|
23
|
+
dmu/rfile/ddfgetter.py,sha256=0jfNzpv72_NQUKOK5SBsn289rUqVt2BMvuL-Ro5oY7I,3316
|
24
|
+
dmu/rfile/rfprinter.py,sha256=mp5jd-oCJAnuokbdmGyL9i6tK2lY72jEfROuBIZ_ums,3941
|
25
|
+
dmu/rfile/utilities.py,sha256=XuYY7HuSBj46iSu3c60UYBHtI6KIPoJU_oofuhb-be0,945
|
26
|
+
dmu/stats/fit_stats.py,sha256=wzkQT9U32ljGe4azUj1Fj0ECF3zmnH2Ncn0O-_Pl1zQ,4070
|
27
|
+
dmu/stats/fitter.py,sha256=rm_fwjkq-0LSjXB_gt3y6BnHoK8Xvd4gHYwKBUJaItQ,19603
|
28
|
+
dmu/stats/function.py,sha256=yzi_Fvp_ASsFzbWFivIf-comquy21WoeY7is6dgY0Go,9491
|
29
|
+
dmu/stats/gof_calculator.py,sha256=63zNJJGKPy-j_hPNPfu9qNlhrHjYIgJOyL8-VDtbwuI,4894
|
30
|
+
dmu/stats/minimizers.py,sha256=db9R2G0SOV-k0BKi6m4EyB_yp6AtZdP23_28B0315oo,7094
|
31
|
+
dmu/stats/model_factory.py,sha256=0_o5OmiX0cNhp9_cNqBOYfasBgKlQkQPiy5nqi9qQKA,18966
|
32
|
+
dmu/stats/parameters.py,sha256=9lycexTT5ZcxXciiQY9HoJV8O1ahrTEkagd7dYXcfj8,3224
|
33
|
+
dmu/stats/utilities.py,sha256=7_tr1j-dl3lLNpxIMWruZs4yUtlNuUTknwGMERpfLhs,17338
|
34
|
+
dmu/stats/wdata.py,sha256=IbjZFU9SHTLSYfaBgqamDvqy1K7-3-SaKbU4bGsamK0,6799
|
35
|
+
dmu/stats/zfit.py,sha256=aSZj_4IHi9IBthfqlNJeA8YSoMmXO5WipgiKnXKGbnM,286
|
36
|
+
dmu/stats/zfit_models.py,sha256=SI61KJ-OG1UAabDICU1iTh6JPKM3giR2ErDraRjkCV8,1842
|
37
|
+
dmu/stats/zfit_plotter.py,sha256=gbN5KxhJcP4ItCi98c-fj5_UtvVWL_NA9jkTHiRjvnE,23854
|
38
|
+
dmu/testing/utilities.py,sha256=WYlz7Ve5lQjuWhhNL4gWe6_qcByBLV762Lhrc6A0P9E,7421
|
39
|
+
dmu/text/transformer.py,sha256=4lrGknbAWRm0-rxbvgzOO-eR1-9bkYk61boJUEV3cQ0,6100
|
40
|
+
dmu/workflow/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
41
|
+
dmu/workflow/cache.py,sha256=CtkGwxuF4UJlD55SmUJcRgWYLsbZOyUvYLI8oTVzk_g,8768
|
42
|
+
data_manipulation_utilities-0.2.8.dev720.dist-info/METADATA,sha256=RuHltvo8DQctnGYdFssfMv92oU6b7tgn3haFZ2HVk0E,51153
|
43
|
+
data_manipulation_utilities-0.2.8.dev720.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
44
|
+
data_manipulation_utilities-0.2.8.dev720.dist-info/entry_points.txt,sha256=M0C8_u9B_xSmyfemdPwdIBh9QuPIkjhEpG060Y5_Pjw,321
|
45
|
+
data_manipulation_utilities-0.2.8.dev720.dist-info/RECORD,,
|
@@ -0,0 +1,8 @@
|
|
1
|
+
[console_scripts]
|
2
|
+
check_truth=dmu_scripts.physics.check_truth:main
|
3
|
+
compare_classifiers=dmu_scripts.ml.compare_classifiers:main
|
4
|
+
compare_root_files=dmu_scripts.rfile.compare_root_files:main
|
5
|
+
coned=dmu_scripts.ssh.coned:main
|
6
|
+
print_trees=dmu_scripts.rfile.print_trees:main
|
7
|
+
transform_text=dmu_scripts.text.transform_text:main
|
8
|
+
|
dmu/generic/hashing.py
CHANGED
@@ -2,6 +2,7 @@
|
|
2
2
|
Module with functions needed to provide hashes
|
3
3
|
'''
|
4
4
|
|
5
|
+
import os
|
5
6
|
import json
|
6
7
|
import hashlib
|
7
8
|
from typing import Any
|
@@ -12,12 +13,10 @@ from dmu.logging.log_store import LogStore
|
|
12
13
|
log=LogStore.add_logger('dmu:generic.hashing')
|
13
14
|
# ------------------------------------
|
14
15
|
def _object_to_string(obj : Any) -> str:
|
15
|
-
|
16
|
-
|
17
|
-
except Exception as exc:
|
18
|
-
raise ValueError(f'Cannot hash object: {obj}') from exc
|
16
|
+
def default_encoder(x):
|
17
|
+
raise TypeError(f"Unserializable type: {type(x)}")
|
19
18
|
|
20
|
-
return
|
19
|
+
return json.dumps(obj, sort_keys=True, default=default_encoder)
|
21
20
|
# ------------------------------------
|
22
21
|
def _dataframe_to_hash(df : pnd.DataFrame) -> str:
|
23
22
|
sr_hash = pnd.util.hash_pandas_object(df, index=True)
|
@@ -29,16 +28,43 @@ def _dataframe_to_hash(df : pnd.DataFrame) -> str:
|
|
29
28
|
# ------------------------------------
|
30
29
|
def hash_object(obj : Any) -> str:
|
31
30
|
'''
|
32
|
-
Function taking a python object and returning
|
31
|
+
Function taking a python object and returning
|
33
32
|
a string representing the hash
|
34
33
|
'''
|
35
34
|
|
36
35
|
if isinstance(obj, pnd.DataFrame):
|
37
|
-
|
36
|
+
value = _dataframe_to_hash(df=obj)
|
37
|
+
value = value[:10]
|
38
|
+
|
39
|
+
return value
|
38
40
|
|
39
41
|
string = _object_to_string(obj=obj)
|
40
42
|
string_bin = string.encode('utf-8')
|
41
43
|
hsh = hashlib.sha256(string_bin)
|
44
|
+
value = hsh.hexdigest()
|
45
|
+
value = value[:10]
|
46
|
+
|
47
|
+
return value
|
48
|
+
# ------------------------------------
|
49
|
+
def hash_file(path : str) -> str:
|
50
|
+
'''
|
51
|
+
Parameters
|
52
|
+
----------------
|
53
|
+
path: Path to file whose content has to be hashed
|
54
|
+
|
55
|
+
Returns
|
56
|
+
----------------
|
57
|
+
A string representing the hash
|
58
|
+
'''
|
59
|
+
if not os.path.isfile(path):
|
60
|
+
raise FileNotFoundError(f'Cannot find: {path}')
|
61
|
+
|
62
|
+
h = hashlib.sha256()
|
63
|
+
with open(path, 'rb') as f:
|
64
|
+
for chunk in iter(lambda: f.read(8192), b''):
|
65
|
+
h.update(chunk)
|
66
|
+
|
67
|
+
value = h.hexdigest()
|
42
68
|
|
43
|
-
return
|
69
|
+
return value[:10]
|
44
70
|
# ------------------------------------
|
dmu/generic/utilities.py
CHANGED
@@ -4,17 +4,69 @@ Module containing generic utility functions
|
|
4
4
|
import os
|
5
5
|
import time
|
6
6
|
import json
|
7
|
+
import pickle
|
7
8
|
import inspect
|
8
|
-
|
9
|
-
from typing
|
10
|
-
|
9
|
+
from importlib.resources import files
|
10
|
+
from typing import Callable, Any
|
11
11
|
from functools import wraps
|
12
|
+
from contextlib import contextmanager
|
13
|
+
|
14
|
+
import yaml
|
15
|
+
from omegaconf import OmegaConf, DictConfig
|
16
|
+
from dmu.generic import hashing
|
17
|
+
from dmu.generic import utilities as gut
|
12
18
|
from dmu.logging.log_store import LogStore
|
13
19
|
|
14
20
|
TIMER_ON=False
|
15
21
|
|
16
22
|
log = LogStore.add_logger('dmu:generic:utilities')
|
23
|
+
# --------------------------------
|
24
|
+
class BlockStyleDumper(yaml.SafeDumper):
|
25
|
+
'''
|
26
|
+
Class needed to specify proper indentation when
|
27
|
+
dumping data to YAML files
|
28
|
+
'''
|
29
|
+
def increase_indent(self, flow=False, indentless=False):
|
30
|
+
return super().increase_indent(flow=flow, indentless=False)
|
31
|
+
# ---------------------------------
|
32
|
+
def load_data(package : str, fpath : str) -> Any:
|
33
|
+
'''
|
34
|
+
This function will load a YAML or JSON file from a data package
|
35
|
+
|
36
|
+
Parameters
|
37
|
+
---------------------
|
38
|
+
package: Data package, e.g. `dmu_data`
|
39
|
+
fpath : Path to YAML/JSON file, relative to the data package
|
40
|
+
|
41
|
+
Returns
|
42
|
+
---------------------
|
43
|
+
Dictionary or whatever structure the file is holding
|
44
|
+
'''
|
45
|
+
|
46
|
+
cpath = files(package).joinpath(fpath)
|
47
|
+
cpath = str(cpath)
|
48
|
+
data = load_json(cpath)
|
49
|
+
|
50
|
+
return data
|
51
|
+
# --------------------------------
|
52
|
+
def load_conf(package : str, fpath : str) -> DictConfig:
|
53
|
+
'''
|
54
|
+
This function will load a YAML or JSON file from a data package
|
55
|
+
|
56
|
+
Parameters
|
57
|
+
---------------------
|
58
|
+
package: Data package, e.g. `dmu_data`
|
59
|
+
fpath : Path to YAML/JSON file, relative to the data package
|
60
|
+
|
61
|
+
Returns
|
62
|
+
---------------------
|
63
|
+
DictConfig class from the OmegaConf package
|
64
|
+
'''
|
65
|
+
|
66
|
+
cpath = files(package).joinpath(fpath)
|
67
|
+
cfg = OmegaConf.load(cpath)
|
17
68
|
|
69
|
+
return cfg
|
18
70
|
# --------------------------------
|
19
71
|
def _get_module_name( fun : Callable) -> str:
|
20
72
|
mod = inspect.getmodule(fun)
|
@@ -28,7 +80,7 @@ def timeit(f):
|
|
28
80
|
Decorator used to time functions, it is turned off by default, can be turned on with:
|
29
81
|
|
30
82
|
from dmu.generic.utilities import TIMER_ON
|
31
|
-
from dmu.generic.utilities import timeit
|
83
|
+
from dmu.generic.utilities import timeit
|
32
84
|
|
33
85
|
TIMER_ON=True
|
34
86
|
|
@@ -54,29 +106,130 @@ def timeit(f):
|
|
54
106
|
# --------------------------------
|
55
107
|
def dump_json(data, path : str, sort_keys : bool = False) -> None:
|
56
108
|
'''
|
57
|
-
Saves data as JSON
|
109
|
+
Saves data as JSON or YAML, depending on the file extension; supported extensions are .json, .yaml and .yml
|
58
110
|
|
59
111
|
Parameters
|
60
112
|
data : dictionary, list, etc
|
61
|
-
path : Path to
|
62
|
-
sort_keys: Will set sort_keys argument of json.dump function
|
113
|
+
path : Path to output file where to save it
|
114
|
+
sort_keys: Will set sort_keys argument of json.dump function
|
63
115
|
'''
|
64
116
|
dir_name = os.path.dirname(path)
|
65
117
|
os.makedirs(dir_name, exist_ok=True)
|
66
118
|
|
67
119
|
with open(path, 'w', encoding='utf-8') as ofile:
|
68
|
-
|
120
|
+
if path.endswith('.json'):
|
121
|
+
json.dump(data, ofile, indent=4, sort_keys=sort_keys)
|
122
|
+
return
|
123
|
+
|
124
|
+
if path.endswith('.yaml') or path.endswith('.yml'):
|
125
|
+
yaml.dump(data, ofile, Dumper=BlockStyleDumper, sort_keys=sort_keys)
|
126
|
+
return
|
127
|
+
|
128
|
+
raise NotImplementedError(f'Cannot deduce format from extension in path: {path}')
|
69
129
|
# --------------------------------
|
70
130
|
def load_json(path : str):
|
71
131
|
'''
|
72
|
-
Loads data from JSON
|
132
|
+
Loads data from JSON or YAML, depending on the file extension; supported extensions are .json, .yaml and .yml
|
73
133
|
|
74
134
|
Parameters
|
75
|
-
path : Path to
|
135
|
+
path : Path to file from which the data is loaded
|
76
136
|
'''
|
77
137
|
|
78
138
|
with open(path, encoding='utf-8') as ofile:
|
79
|
-
|
139
|
+
if path.endswith('.json'):
|
140
|
+
data = json.load(ofile)
|
141
|
+
return data
|
142
|
+
|
143
|
+
if path.endswith('.yaml') or path.endswith('.yml'):
|
144
|
+
data = yaml.safe_load(ofile)
|
145
|
+
return data
|
146
|
+
|
147
|
+
raise NotImplementedError(f'Cannot deduce format from extension in path: {path}')
|
148
|
+
# --------------------------------
|
149
|
+
def dump_pickle(data, path : str) -> None:
|
150
|
+
'''
|
151
|
+
Saves data as pickle file
|
152
|
+
|
153
|
+
Parameters
|
154
|
+
data : dictionary, list, etc
|
155
|
+
path : Path to output file where to save it
|
156
|
+
'''
|
157
|
+
dir_name = os.path.dirname(path)
|
158
|
+
os.makedirs(dir_name, exist_ok=True)
|
159
|
+
|
160
|
+
with open(path, 'wb') as ofile:
|
161
|
+
pickle.dump(data, ofile)
|
162
|
+
# --------------------------------
|
163
|
+
def load_pickle(path : str) -> None:
|
164
|
+
'''
|
165
|
+
Loads data from a pickle file
|
166
|
+
|
167
|
+
Parameters
|
168
|
+
path : Path to pickle file from which the data is loaded
|
169
|
+
'''
|
170
|
+
with open(path, 'rb') as ofile:
|
171
|
+
data = pickle.load(ofile)
|
80
172
|
|
81
173
|
return data
|
82
174
|
# --------------------------------
|
175
|
+
@contextmanager
|
176
|
+
def silent_import():
|
177
|
+
'''
|
178
|
+
In charge of suppressing messages
|
179
|
+
of imported modules
|
180
|
+
'''
|
181
|
+
saved_stdout_fd = os.dup(1)
|
182
|
+
saved_stderr_fd = os.dup(2)
|
183
|
+
|
184
|
+
with open(os.devnull, 'w', encoding='utf-8') as devnull:
|
185
|
+
os.dup2(devnull.fileno(), 1)
|
186
|
+
os.dup2(devnull.fileno(), 2)
|
187
|
+
try:
|
188
|
+
yield
|
189
|
+
finally:
|
190
|
+
os.dup2(saved_stdout_fd, 1)
|
191
|
+
os.dup2(saved_stderr_fd, 2)
|
192
|
+
os.close(saved_stdout_fd)
|
193
|
+
os.close(saved_stderr_fd)
|
194
|
+
# --------------------------------
|
195
|
+
# Caching
|
196
|
+
# --------------------------------
|
197
|
+
def cache_data(obj : Any, hash_obj : Any) -> None:
|
198
|
+
'''
|
199
|
+
Will save data to a text file using a name from a hash
|
200
|
+
|
201
|
+
Parameters
|
202
|
+
-----------
|
203
|
+
obj : Object that can be saved to a text file, e.g. list, number, dictionary
|
204
|
+
hash_obj : Object used to compute the hash that names the cached file, e.g. an immutable object
|
205
|
+
'''
|
206
|
+
try:
|
207
|
+
json.dumps(obj)
|
208
|
+
except Exception as exc:
|
209
|
+
raise ValueError('Object is not JSON serializable') from exc
|
210
|
+
|
211
|
+
val = hashing.hash_object(hash_obj)
|
212
|
+
path = f'/tmp/dmu/cache/{val}.json'
|
213
|
+
gut.dump_json(obj, path)
|
214
|
+
# --------------------------------
|
215
|
+
def load_cached(hash_obj : Any, on_fail : Any = None) -> Any:
|
216
|
+
'''
|
217
|
+
Loads data corresponding to hash from hash_obj
|
218
|
+
|
219
|
+
Parameters
|
220
|
+
---------------
|
221
|
+
hash_obj: Object used to calculate hash, which is in the file name
|
222
|
+
on_fail : Value returned if no data was found.
|
223
|
+
By default None, in which case a FileNotFoundError is raised instead
|
224
|
+
'''
|
225
|
+
val = hashing.hash_object(hash_obj)
|
226
|
+
path = f'/tmp/dmu/cache/{val}.json'
|
227
|
+
if os.path.isfile(path):
|
228
|
+
data = gut.load_json(path)
|
229
|
+
return data
|
230
|
+
|
231
|
+
if on_fail is not None:
|
232
|
+
return on_fail
|
233
|
+
|
234
|
+
raise FileNotFoundError(f'Cannot find cached data at: {path}')
|
235
|
+
# --------------------------------
|
dmu/logging/log_store.py
CHANGED
@@ -1,10 +1,11 @@
|
|
1
1
|
'''
|
2
2
|
Module holding LogStore
|
3
3
|
'''
|
4
|
-
|
5
4
|
import logging
|
6
|
-
|
5
|
+
import contextlib
|
6
|
+
from typing import Union
|
7
7
|
|
8
|
+
from logging import Logger
|
8
9
|
import logzero
|
9
10
|
|
10
11
|
#------------------------------------------------------------
|
@@ -40,6 +41,36 @@ class LogStore:
|
|
40
41
|
backend = 'logging'
|
41
42
|
#--------------------------
|
42
43
|
@staticmethod
|
44
|
+
@contextlib.contextmanager
|
45
|
+
def level(name : str, lvl : int) -> None:
|
46
|
+
'''
|
47
|
+
Context manager used to set the logging level of a given logger
|
48
|
+
|
49
|
+
Parameters
|
50
|
+
------------------
|
51
|
+
name : Name of logger
|
52
|
+
lvl : Integer representing logging level
|
53
|
+
'''
|
54
|
+
log = LogStore.get_logger(name=name)
|
55
|
+
if log is None:
|
56
|
+
raise ValueError(f'Cannot find logger {name}')
|
57
|
+
|
58
|
+
old_lvl = log.getEffectiveLevel()
|
59
|
+
|
60
|
+
LogStore.set_level(name, lvl)
|
61
|
+
try:
|
62
|
+
yield
|
63
|
+
finally:
|
64
|
+
LogStore.set_level(name, old_lvl)
|
65
|
+
#--------------------------
|
66
|
+
@staticmethod
|
67
|
+
def get_logger(name : str) -> Union[Logger,None]:
|
68
|
+
'''
|
69
|
+
Returns logger for a given name or None, if no logger found for that name
|
70
|
+
'''
|
71
|
+
return LogStore.d_logger.get(name)
|
72
|
+
#--------------------------
|
73
|
+
@staticmethod
|
43
74
|
def add_logger(name : str, exists_ok : bool = False) -> Logger:
|
44
75
|
'''
|
45
76
|
Will use underlying logging library logzero/logging, etc to make logger
|
@@ -78,6 +109,7 @@ class LogStore:
|
|
78
109
|
@staticmethod
|
79
110
|
def _get_logging_logger(name : str, level : int) -> Logger:
|
80
111
|
logger = logging.getLogger(name=name)
|
112
|
+
logger.propagate = False
|
81
113
|
|
82
114
|
logger.setLevel(level)
|
83
115
|
|
dmu/logging/messages.py
ADDED
@@ -0,0 +1,96 @@
|
|
1
|
+
'''
|
2
|
+
Module containing code meant to deal with logging of
|
3
|
+
third party tools
|
4
|
+
'''
|
5
|
+
import os
|
6
|
+
import sys
|
7
|
+
import time
|
8
|
+
import threading
|
9
|
+
from io import StringIO
|
10
|
+
from contextlib import contextmanager
|
11
|
+
from dmu.logging.log_store import LogStore
|
12
|
+
|
13
|
+
log = LogStore.add_logger('dmu:logging:messages')
|
14
|
+
# --------------------------------
|
15
|
+
class FilteredStderr:
|
16
|
+
'''
|
17
|
+
This class is meant to be used to filter the messages
|
18
|
+
in the error stream by substrings
|
19
|
+
'''
|
20
|
+
# --------------------------------
|
21
|
+
def __init__(
|
22
|
+
self,
|
23
|
+
banned_substrings : list[str],
|
24
|
+
capture_stream : StringIO):
|
25
|
+
'''
|
26
|
+
Parameters
|
27
|
+
-------------
|
28
|
+
banned_substrings : List of substrings that, if found in error message, will drop error
|
29
|
+
capture_stream : Used to store error stream filtered messages, expected to be sys.__stderr__
|
30
|
+
'''
|
31
|
+
self._banned = banned_substrings
|
32
|
+
self._capture_stream = capture_stream
|
33
|
+
# --------------------------------
|
34
|
+
def write(self, message : str):
|
35
|
+
'''
|
36
|
+
Should allow filtering error messages
|
37
|
+
'''
|
38
|
+
if not any(bad in message for bad in self._banned):
|
39
|
+
# This will make it to the error messages
|
40
|
+
self._capture_stream.write(message)
|
41
|
+
# --------------------------------
|
42
|
+
def flush(self):
|
43
|
+
'''
|
44
|
+
Should override the error stream's flush method
|
45
|
+
'''
|
46
|
+
self._capture_stream.flush()
|
47
|
+
# --------------------------------
|
48
|
+
@contextmanager
|
49
|
+
def filter_stderr(
|
50
|
+
banned_substrings : list[str],
|
51
|
+
capture_stream : StringIO|None=None):
|
52
|
+
'''
|
53
|
+
This context manager will suppress error messages that contain any of the banned substrings
|
54
|
+
|
55
|
+
Parameters
|
56
|
+
-----------------
|
57
|
+
banned_substrings : List of substrings that need to be found in error messages
|
58
|
+
in order for them to be suppressed
|
59
|
+
capture_stream : Buffer needed to run tests, not needed for normal use
|
60
|
+
'''
|
61
|
+
if capture_stream is None:
|
62
|
+
capture_stream = sys.__stderr__
|
63
|
+
|
64
|
+
read_fd, write_fd = os.pipe()
|
65
|
+
saved_fd = os.dup(2)
|
66
|
+
|
67
|
+
os.dup2(write_fd, 2)
|
68
|
+
os.close(write_fd)
|
69
|
+
|
70
|
+
filtered = FilteredStderr(banned_substrings, capture_stream)
|
71
|
+
reader_finished = threading.Event()
|
72
|
+
|
73
|
+
def reader():
|
74
|
+
try:
|
75
|
+
with os.fdopen(read_fd, 'r', buffering=1) as pipe:
|
76
|
+
while True:
|
77
|
+
line = pipe.readline()
|
78
|
+
if not line:
|
79
|
+
break
|
80
|
+
filtered.write(line)
|
81
|
+
filtered.flush()
|
82
|
+
finally:
|
83
|
+
reader_finished.set()
|
84
|
+
|
85
|
+
thread = threading.Thread(target=reader, daemon=True)
|
86
|
+
thread.start()
|
87
|
+
|
88
|
+
try:
|
89
|
+
yield
|
90
|
+
finally:
|
91
|
+
os.dup2(saved_fd, 2)
|
92
|
+
os.close(saved_fd)
|
93
|
+
|
94
|
+
time.sleep(0.1)
|
95
|
+
reader_finished.wait(timeout=1.0)
|
96
|
+
# --------------------------------
|
dmu/ml/cv_classifier.py
CHANGED
@@ -37,17 +37,17 @@ class CVClassifier(GradientBoostingClassifier):
|
|
37
37
|
|
38
38
|
self._s_hash = set()
|
39
39
|
self._data = {}
|
40
|
-
self._l_ft_name
|
40
|
+
self._l_ft_name : list[str]
|
41
41
|
# ----------------------------------
|
42
42
|
@property
|
43
|
-
def features(self):
|
43
|
+
def features(self) -> list[str]:
|
44
44
|
'''
|
45
45
|
Returns list of feature names used in training dataset
|
46
46
|
'''
|
47
47
|
return self._l_ft_name
|
48
48
|
# ----------------------------------
|
49
49
|
@property
|
50
|
-
def hashes(self):
|
50
|
+
def hashes(self) -> set[str]:
|
51
51
|
'''
|
52
52
|
Will return set with hashes of training data
|
53
53
|
'''
|
dmu/ml/cv_diagnostics.py
CHANGED
@@ -186,6 +186,9 @@ class CVDiagnostics:
|
|
186
186
|
raise NotImplementedError(f'Correlation coefficient {method} not implemented')
|
187
187
|
# -------------------------
|
188
188
|
def _plot_cutflow(self) -> None:
|
189
|
+
'''
|
190
|
+
Plot the 'mass' column for different values of working point
|
191
|
+
'''
|
189
192
|
if 'overlay' not in self._cfg['correlations']['target']:
|
190
193
|
log.debug('Not plotting cutflow of target distribution')
|
191
194
|
return
|
dmu/ml/cv_performance.py
ADDED
@@ -0,0 +1,58 @@
|
|
1
|
+
'''
|
2
|
+
This module contains the class CVPerformance
|
3
|
+
'''
|
4
|
+
# pylint: disable=too-many-positional-arguments, too-many-arguments
|
5
|
+
|
6
|
+
from ROOT import RDataFrame
|
7
|
+
from dmu.ml.cv_classifier import CVClassifier
|
8
|
+
from dmu.ml.cv_predict import CVPredict
|
9
|
+
from dmu.ml.train_mva import TrainMva
|
10
|
+
from dmu.logging.log_store import LogStore
|
11
|
+
|
12
|
+
log=LogStore.add_logger('dmu:ml:cv_performance')
|
13
|
+
# -----------------------------------------------------
|
14
|
+
class CVPerformance:
|
15
|
+
'''
|
16
|
+
This class is meant to:
|
17
|
+
|
18
|
+
- Compare the classifier performance, through the ROC curve, of a model, for a given background and signal sample
|
19
|
+
'''
|
20
|
+
# ---------------------------
|
21
|
+
def plot_roc(
|
22
|
+
self,
|
23
|
+
name : str,
|
24
|
+
color : str,
|
25
|
+
sig : RDataFrame,
|
26
|
+
bkg : RDataFrame,
|
27
|
+
model : list[CVClassifier] ) -> float:
|
28
|
+
'''
|
29
|
+
Method in charge of picking up model and data and plotting ROC curve
|
30
|
+
|
31
|
+
Parameters
|
32
|
+
--------------------------
|
33
|
+
name : Label of combination, used for plots
|
34
|
+
sig : ROOT dataframe storing signal samples
|
35
|
+
bkg : ROOT dataframe storing background samples
|
36
|
+
model: List of instances of the CVClassifier
|
37
|
+
|
38
|
+
Returns
|
39
|
+
--------------------------
|
40
|
+
Area under the ROC curve
|
41
|
+
'''
|
42
|
+
log.info(f'Loading {name}')
|
43
|
+
|
44
|
+
cvp_sig = CVPredict(models=model, rdf=sig)
|
45
|
+
arr_sig = cvp_sig.predict()
|
46
|
+
|
47
|
+
cvp_bkg = CVPredict(models=model, rdf=bkg)
|
48
|
+
arr_bkg = cvp_bkg.predict()
|
49
|
+
|
50
|
+
_, _, auc = TrainMva.plot_roc_from_prob(
|
51
|
+
arr_sig_prb=arr_sig,
|
52
|
+
arr_bkg_prb=arr_bkg,
|
53
|
+
kind = name,
|
54
|
+
color = color, # This should allow the function to pick kind
|
55
|
+
ifold = 999) # for the label
|
56
|
+
|
57
|
+
return auc
|
58
|
+
# -----------------------------------------------------
|