data-manipulation-utilities 0.0.1 (data_manipulation_utilities-0.0.1-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_manipulation_utilities-0.0.1.dist-info/METADATA +713 -0
- data_manipulation_utilities-0.0.1.dist-info/RECORD +45 -0
- data_manipulation_utilities-0.0.1.dist-info/WHEEL +5 -0
- data_manipulation_utilities-0.0.1.dist-info/entry_points.txt +6 -0
- data_manipulation_utilities-0.0.1.dist-info/top_level.txt +3 -0
- dmu/arrays/utilities.py +55 -0
- dmu/dataframe/dataframe.py +36 -0
- dmu/generic/utilities.py +69 -0
- dmu/logging/log_store.py +129 -0
- dmu/ml/cv_classifier.py +122 -0
- dmu/ml/cv_predict.py +152 -0
- dmu/ml/train_mva.py +257 -0
- dmu/ml/utilities.py +132 -0
- dmu/plotting/plotter.py +227 -0
- dmu/plotting/plotter_1d.py +113 -0
- dmu/plotting/plotter_2d.py +87 -0
- dmu/rdataframe/atr_mgr.py +79 -0
- dmu/rdataframe/utilities.py +72 -0
- dmu/rfile/rfprinter.py +91 -0
- dmu/rfile/utilities.py +34 -0
- dmu/stats/fitter.py +515 -0
- dmu/stats/function.py +314 -0
- dmu/stats/utilities.py +134 -0
- dmu/testing/utilities.py +119 -0
- dmu/text/transformer.py +182 -0
- dmu_data/__init__.py +0 -0
- dmu_data/ml/tests/train_mva.yaml +37 -0
- dmu_data/plotting/tests/2d.yaml +14 -0
- dmu_data/plotting/tests/fig_size.yaml +13 -0
- dmu_data/plotting/tests/high_stat.yaml +22 -0
- dmu_data/plotting/tests/name.yaml +14 -0
- dmu_data/plotting/tests/no_bounds.yaml +12 -0
- dmu_data/plotting/tests/simple.yaml +8 -0
- dmu_data/plotting/tests/title.yaml +14 -0
- dmu_data/plotting/tests/weights.yaml +13 -0
- dmu_data/text/transform.toml +4 -0
- dmu_data/text/transform.txt +6 -0
- dmu_data/text/transform_set.toml +8 -0
- dmu_data/text/transform_set.txt +6 -0
- dmu_data/text/transform_trf.txt +12 -0
- dmu_scripts/physics/check_truth.py +121 -0
- dmu_scripts/rfile/compare_root_files.py +299 -0
- dmu_scripts/rfile/print_trees.py +35 -0
- dmu_scripts/ssh/coned.py +168 -0
- dmu_scripts/text/transform_text.py +46 -0
data_manipulation_utilities-0.0.1.dist-info/RECORD
ADDED
@@ -0,0 +1,45 @@
+dmu/arrays/utilities.py,sha256=PKoYyybPptA2aU-V3KLnJXBudWxTXu4x1uGdIMQ49HY,1722
+dmu/dataframe/dataframe.py,sha256=ZgRCw5hN18gOXGL9nHDc4eNi0P8lOIAIEILmbEiTlXw,1088
+dmu/generic/utilities.py,sha256=0Xnq9t35wuebAqKxbyAiMk1ISB7IcXK4cFH25MT1fgw,1741
+dmu/logging/log_store.py,sha256=v0tiNz-6ktT_afD5DuvCZ8Nmr82JKQOPli8hgd28P1Q,3960
+dmu/ml/cv_classifier.py,sha256=n81m7i2M6Zq96AEd9EZGwXSrbG5m9jkS5RdeXvbsAXU,3712
+dmu/ml/cv_predict.py,sha256=Bqxu-f6qquKJokFljhCzL_kiGcjLJLQFhVBD130fsyw,4893
+dmu/ml/train_mva.py,sha256=d_n-A07DFweikz5nXap4OE_Mqx8VprFT7zbxmnQAbac,9638
+dmu/ml/utilities.py,sha256=Nue7O9zi1QXgjGRPH6wnSAW9jusMQ2ZOSDJzBqJKIi0,3687
+dmu/plotting/plotter.py,sha256=laa6Kl7P-ZOIhaOFBVjOH4XQ4kPCV7wBNvLIMBnyCwM,7181
+dmu/plotting/plotter_1d.py,sha256=G-i94uzm2TjNaog1A4agAKar_G0qNdkAqIPCmzhe85Y,3660
+dmu/plotting/plotter_2d.py,sha256=SWPKns-CfpUZHgBXvwm3gceH3k2eL_mKGXQ8sWpZJB0,2919
+dmu/rdataframe/atr_mgr.py,sha256=FdhaQWVpsm4OOe1IRbm7rfrq8VenTNdORyI-lZ2Bs1M,2386
+dmu/rdataframe/utilities.py,sha256=a31PdUz12sC2bx78LK6gvACh1M_eFaIVwuZEvOTcvcc,2084
+dmu/rfile/rfprinter.py,sha256=vGdqyHT_GwGBhrY7KG63EAUGWEOqobz_5yTL6goXbfk,2722
+dmu/rfile/utilities.py,sha256=XuYY7HuSBj46iSu3c60UYBHtI6KIPoJU_oofuhb-be0,945
+dmu/stats/fitter.py,sha256=LDvFNyhgO0OzXN7aH3kfHe6LzuPqdQfPcKR_IegDcaU,18204
+dmu/stats/function.py,sha256=yzi_Fvp_ASsFzbWFivIf-comquy21WoeY7is6dgY0Go,9491
+dmu/stats/utilities.py,sha256=LQy4kd3xSXqpApcWuYfZxkGQyjowaXv2Wr1c4Bj-4ys,4523
+dmu/testing/utilities.py,sha256=WbMM4e9Cn3-B-12Vr64mB5qTKkV32joStlRkD-48lG0,3460
+dmu/text/transformer.py,sha256=4lrGknbAWRm0-rxbvgzOO-eR1-9bkYk61boJUEV3cQ0,6100
+dmu_data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+dmu_data/ml/tests/train_mva.yaml,sha256=TCniCVpXMEFxZcHa8IIqollKA7ci4OkBnRznLEkXM9o,925
+dmu_data/plotting/tests/2d.yaml,sha256=lTMNheK3DB8klp4O5QjMDwBI1A1Oh2_Wp2F2Ro9VQKM,282
+dmu_data/plotting/tests/fig_size.yaml,sha256=7ROq49nwZ1A2EbPiySmu6n3G-Jq6YAOkc3d2X3YNZv0,294
+dmu_data/plotting/tests/high_stat.yaml,sha256=bLglBLCZK6ft0xMhQ5OltxE76cWsBMPMjO6GG0OkDr8,522
+dmu_data/plotting/tests/name.yaml,sha256=mkcPAVg8wBAmlSbSRQ1bcaMl4vOS6LXMtpqQeDrrtO4,312
+dmu_data/plotting/tests/no_bounds.yaml,sha256=8e1QdphBjz-suDr857DoeUC2DXiy6SE-gvkORJQYv80,257
+dmu_data/plotting/tests/simple.yaml,sha256=N_TvNBh_2dU0-VYgu_LMrtY0kV_hg2HxVuEoDlr1HX8,138
+dmu_data/plotting/tests/title.yaml,sha256=bawKp9aGpeRrHzv69BOCbFX8sq9bb3Es9tdsPTE7jIk,333
+dmu_data/plotting/tests/weights.yaml,sha256=RWQ1KxbCq-uO62WJ2AoY4h5Umc37zG35s-TpKnNMABI,312
+dmu_data/text/transform.toml,sha256=R-832BZalzHZ6c5gD6jtT_Hj8BCsM5vxa1v6oeiwaP4,94
+dmu_data/text/transform.txt,sha256=EX760da6Vkf-_EPxnQlC5hGSkfFhJCCGCD19NU-1Qto,44
+dmu_data/text/transform_set.toml,sha256=Jeh7BTz82idqvbOQJtl9-ur56mZkzDn5WtvmIb48LoE,150
+dmu_data/text/transform_set.txt,sha256=1KivMoP9LxPn9955QrRmOzjEqduEjhTetQ9MXykO5LY,46
+dmu_data/text/transform_trf.txt,sha256=zxBRTgcSmX7RdqfmWF88W1YqbyNHa4Ccruf1MmnYv2A,74
+dmu_scripts/physics/check_truth.py,sha256=b1P_Pa9ef6VcFtyY6Y9KS9Om9L-QrCBjDKp4dqca0PQ,3964
+dmu_scripts/rfile/compare_root_files.py,sha256=T8lDnQxsRNMr37x1Y7YvWD8ySHrJOWZki7ZQynxXX9Q,9540
+dmu_scripts/rfile/print_trees.py,sha256=Ze4Ccl_iUldl4eVEDVnYBoe4amqBT1fSBR1zN5WSztk,941
+dmu_scripts/ssh/coned.py,sha256=lhilYNHWRCGxC-jtyJ3LQ4oUgWW33B2l1tYCcyHHsR0,4858
+dmu_scripts/text/transform_text.py,sha256=9akj1LB0HAyopOvkLjNOJiptZw5XoOQLe17SlcrGMD0,1456
+data_manipulation_utilities-0.0.1.dist-info/METADATA,sha256=chAPGy68TwWTweqqzGgXvrD6-xia1Bq9XDgvgS1qLEE,19714
+data_manipulation_utilities-0.0.1.dist-info/WHEEL,sha256=R06PA3UVYHThwHvxuRWMqaGcr-PuniXahwjmQRFMEkY,91
+data_manipulation_utilities-0.0.1.dist-info/entry_points.txt,sha256=1TIZDed651KuOH-DgaN5AoBdirKmrKE_oM1b6b7zTUU,270
+data_manipulation_utilities-0.0.1.dist-info/top_level.txt,sha256=n_x5J6uWtSqy9mRImKtdA2V2NJNyU8Kn3u8DTOKJix0,25
+data_manipulation_utilities-0.0.1.dist-info/RECORD,,
data_manipulation_utilities-0.0.1.dist-info/entry_points.txt
ADDED
@@ -0,0 +1,6 @@
+[console_scripts]
+check_truth = dmu_scripts.physics.check_truth:main
+compare_root_files = dmu_scripts.rfile.compare_root_files:main
+coned = dmu_scripts.ssh.coned:main
+print_trees = dmu_scripts.rfile.print_trees:main
+transform_text = dmu_scripts.text.transform_text:main
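Each console_scripts entry makes pip generate a small executable wrapper that imports the named module and calls the function after the colon. As a rough sketch (the module and function names come straight from the entries above), running transform_text from the shell is approximately equivalent to:

    # Approximate body of the pip-generated `transform_text` wrapper
    import sys
    from dmu_scripts.text.transform_text import main

    sys.exit(main())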
dmu/arrays/utilities.py
ADDED
@@ -0,0 +1,55 @@
+'''
+Module with utility functions to handle numpy arrays
+'''
+
+import math
+import numpy
+
+#-----------------------------------------------
+def _check_ftimes(ftimes : float) -> None:
+    '''
+    Check that the floating-point scale factor makes sense
+    '''
+    if not isinstance(ftimes, float):
+        raise TypeError(f'Scaling factor is not a float, but: {ftimes}')
+
+    if ftimes <= 1.0:
+        raise ValueError(f'Scaling factor needs to be larger than 1.0, found: {ftimes}')
+#-----------------------------------------------
+def repeat_arr(arr_val : numpy.ndarray, ftimes : float) -> numpy.ndarray:
+    '''
+    Will repeat elements in an array a non-integer number of times.
+
+    arr_val: 1D array of objects
+    ftimes (float): Number of times to repeat it.
+    '''
+
+    _check_ftimes(ftimes)
+
+    a = math.floor(ftimes)
+    b = math.ceil(ftimes)
+    if numpy.isclose(a, b):
+        return numpy.repeat(arr_val, a)
+
+    # Split the data in arr_val randomly, such that one subset gets repeated
+    # floor(ftimes) times and the other ceil(ftimes) times
+
+    # Probability that a given element belongs to the subset repeated "a" times
+    p      = b - ftimes
+    size_t = len(arr_val)
+    size_1 = int(p * size_t)
+
+    # Find subset to repeat "a" times
+    arr_ind_1 = numpy.random.choice(size_t, size=size_1, replace=False)
+    arr_val_1 = arr_val[arr_ind_1]
+
+    # Find subset to repeat "b" times
+    arr_ind_2 = numpy.setdiff1d(numpy.arange(size_t), arr_ind_1)
+    arr_val_2 = arr_val[arr_ind_2]
+
+    # Repeat them an integer number of times
+    arr_val_1 = numpy.repeat(arr_val_1, a)
+    arr_val_2 = numpy.repeat(arr_val_2, b)
+
+    return numpy.concatenate([arr_val_1, arr_val_2])
+#-----------------------------------------------
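A minimal usage sketch of repeat_arr, using only the code above; the input values are illustrative:

    import numpy
    from dmu.arrays.utilities import repeat_arr

    arr = numpy.array([1, 2, 3, 4])

    # p = ceil(2.5) - 2.5 = 0.5, so 2 random elements are repeated twice
    # and the other 2 three times: 2*2 + 2*3 = 10 output entries
    out = repeat_arr(arr, ftimes=2.5)
    print(len(out))   # 10; which elements get the extra repetition is random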
dmu/dataframe/dataframe.py
ADDED
@@ -0,0 +1,36 @@
+'''
+Class re-implementing dataframe, meant to be a thin layer on top of polars
+'''
+
+import polars as pl
+
+# ------------------------------------------
+class DataFrame(pl.DataFrame):
+    '''
+    Class reimplementing dataframes from polars
+    '''
+    # ------------------------------------------
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+    # ------------------------------------------
+    def define(self, name : str, expr : str):
+        '''
+        Will define a new column by rewriting column names in `expr` to pl.col(...) calls and evaluating the result
+
+        name (str): Name of new column
+        expr (str): Expression depending on existing columns
+        '''
+
+        for col in self.columns:
+            expr = expr.replace(col, f' {col} ')
+
+        for col in self.columns:
+            expr = expr.replace(f' {col} ', f' pl.col("{col}") ')
+
+        try:
+            df = self.with_columns(eval(expr).alias(name))
+        except TypeError as exc:
+            raise TypeError(f'Cannot define {expr} -> {name}') from exc
+
+        return DataFrame(df)
+    # ------------------------------------------
dmu/generic/utilities.py
ADDED
@@ -0,0 +1,69 @@
+'''
+Module containing generic utility functions
+'''
+import os
+import time
+import json
+import inspect
+
+from typing import Callable
+
+from functools import wraps
+from dmu.logging.log_store import LogStore
+
+TIMER_ON=False
+
+log = LogStore.add_logger('dmu:generic:utilities')
+
+# --------------------------------
+def _get_module_name( fun : Callable) -> str:
+    mod = inspect.getmodule(fun)
+    if mod is None:
+        raise ValueError(f'Cannot determine module name for function: {fun}')
+
+    return mod.__name__
+# --------------------------------
+def timeit(f):
+    '''
+    Decorator used to time functions; it is turned off by default and can be turned on with:
+
+    import dmu.generic.utilities as gut
+    from dmu.generic.utilities import timeit
+
+    gut.TIMER_ON=True
+
+    @timeit
+    def fun():
+        ...
+    '''
+    @wraps(f)
+    def wrap(*args, **kw):
+        if not TIMER_ON:
+            result = f(*args, **kw)
+            return result
+
+        ts      = time.time()
+        result  = f(*args, **kw)
+        te      = time.time()
+        mod_nam = _get_module_name(f)
+        fun_nam = f.__name__
+        log.info(f'{mod_nam}.py:{fun_nam}; Time: {te-ts:.3f}s')
+
+        return result
+    return wrap
+# --------------------------------
+def dump_json(data, path : str, sort_keys : bool = False):
+    '''
+    Saves data as JSON
+
+    Parameters
+    data     : dictionary, list, etc
+    path     : Path to JSON file where to save it
+    sort_keys: Will set sort_keys argument of json.dump function
+    '''
+    dir_name = os.path.dirname(path)
+    if dir_name: os.makedirs(dir_name, exist_ok=True) # guard: dirname is '' for bare filenames
+
+    with open(path, 'w', encoding='utf-8') as ofile:
+        json.dump(data, ofile, indent=4, sort_keys=sort_keys)
+# --------------------------------
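A usage sketch for both utilities; the flag has to be set on the module object (rebinding an imported TIMER_ON name would not affect the flag the decorator reads), and the output path is illustrative:

    import dmu.generic.utilities as gut

    gut.TIMER_ON = True          # module-level flag read by the decorator

    @gut.timeit
    def fun():
        return sum(range(1000))

    fun()                        # logs something like 'utilities.py:fun; Time: 0.000s'

    gut.dump_json({'x': 1}, '/tmp/dmu_example/data.json')   # creates the directory if needed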
dmu/logging/log_store.py
ADDED
@@ -0,0 +1,129 @@
+'''
+Module holding LogStore
+'''
+
+import logging
+import logzero
+
+#------------------------------------------------------------
+class StoreFormater(logging.Formatter):
+    '''
+    Custom formatter
+    '''
+    LOG_COLORS = {
+        logging.DEBUG   : '\033[94m',  # Light blue
+        logging.INFO    : '\033[37m',  # White
+        logging.WARNING : '\033[93m',  # Yellow
+        logging.ERROR   : '\033[91m',  # Red
+        logging.CRITICAL: '\033[1;91m' # Bold red
+    }
+
+    RESET_COLOR = '\033[0m' # Reset color to default
+
+    def format(self, record):
+        log_color = self.LOG_COLORS.get(record.levelno, self.RESET_COLOR)
+        message   = super().format(record)
+
+        return f'{log_color}{message}{self.RESET_COLOR}'
+#------------------------------------------------------------
+class LogStore:
+    '''
+    Class used to make loggers, set log levels, print loggers, etc.; an interface to logging/logzero.
+    '''
+    #pylint: disable = invalid-name
+    d_logger      = {}
+    d_levels      = {}
+    log_level     = logging.INFO
+    is_configured = False
+    backend       = 'logging'
+    #--------------------------
+    @staticmethod
+    def add_logger(name=None):
+        '''
+        Will use the underlying logging library (logging, logzero, etc.) to make the logger
+
+        name (str): Name of logger
+        '''
+
+        if name is None:
+            raise ValueError('Logger name missing')
+
+        if name in LogStore.d_logger:
+            raise ValueError(f'Logger name {name} already found')
+
+        level = LogStore.log_level if name not in LogStore.d_levels else LogStore.d_levels[name]
+
+        if   LogStore.backend == 'logging':
+            logger = LogStore._get_logging_logger(name, level)
+        elif LogStore.backend == 'logzero':
+            logger = LogStore._get_logzero_logger(name, level)
+        else:
+            raise ValueError(f'Invalid backend: {LogStore.backend}')
+
+        LogStore.d_logger[name] = logger
+
+        return logger
+    #--------------------------
+    @staticmethod
+    def _get_logzero_logger(name : str, level : int):
+        log = logzero.setup_logger(name=name)
+        log.setLevel(level)
+
+        return log
+    #--------------------------
+    @staticmethod
+    def _get_logging_logger(name : str, level : int):
+        logger = logging.getLogger(name=name)
+
+        logger.setLevel(level)
+
+        hnd = logging.StreamHandler()
+        hnd.setLevel(level)
+
+        fmt = StoreFormater('%(asctime)s - %(filename)s:%(lineno)d - %(message)s', datefmt='%H:%M:%S')
+        hnd.setFormatter(fmt)
+
+        if logger.hasHandlers():
+            logger.handlers.clear()
+
+        logger.addHandler(hnd)
+
+        return logger
+    #--------------------------
+    @staticmethod
+    def set_level(name, value):
+        '''
+        Will set the level of a logger; if it is not present yet, the level is stored and applied on creation.
+        Parameters:
+        -----------------
+        name (str): Name of logger
+        value (int): 10 debug, 20 info, 30 warning
+        '''
+
+        if name in LogStore.d_logger:
+            lgr = LogStore.d_logger[name]
+            lgr.handlers[0].setLevel(value)
+            lgr.setLevel(value)
+        else:
+            LogStore.d_levels[name] = value
+    #--------------------------
+    @staticmethod
+    def show_loggers():
+        '''
+        Will print loggers and log levels in two columns
+        '''
+        print(80 * '-')
+        print(f'{"Name":<60}{"Level":<20}')
+        print(80 * '-')
+        for name, logger in LogStore.d_logger.items():
+            print(f'{name:<60}{logger.level:<20}')
+    #--------------------------
+    @staticmethod
+    def set_all_levels(level):
+        '''
+        Will set all loggers to this level (int)
+        '''
+        for name, logger in LogStore.d_logger.items():
+            logger.setLevel(level)
+            print(f'{name:<60}{"->":20}{logger.level:<20}')
+#------------------------------------------------------------
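A usage sketch of the LogStore interface above; the logger name is illustrative:

    from dmu.logging.log_store import LogStore

    # Levels can be stored before the logger exists; they are applied on creation
    LogStore.set_level('my:tool', 10)            # 10 = DEBUG

    log = LogStore.add_logger('my:tool')
    log.debug('Visible, the level was pre-set to DEBUG')

    LogStore.show_loggers()                      # two-column table of names and levels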
dmu/ml/cv_classifier.py
ADDED
@@ -0,0 +1,122 @@
+'''
+Module holding cv_classifier class
+'''
+
+from sklearn.ensemble import GradientBoostingClassifier
+
+from dmu.logging.log_store import LogStore
+import dmu.ml.utilities as ut
+
+log = LogStore.add_logger('dmu:ml:CVClassifier')
+
+# ---------------------------------------
+class CVSameData(Exception):
+    '''
+    Will be raised if a model is being evaluated with a dataset such that at least one of the
+    samples was also used for the training
+    '''
+# ---------------------------------------
+class CVClassifier(GradientBoostingClassifier):
+    '''
+    Derived class meant to implement features needed for cross-validation
+    '''
+    # pylint: disable = too-many-ancestors, abstract-method
+    # ----------------------------------
+    def __init__(self, cfg : dict | None = None):
+        '''
+        cfg (dict) : Dictionary with configuration, especially the hyperparameters set in the `hyper` field
+        '''
+        if cfg is None:
+            raise ValueError('No configuration was passed')
+
+        self._cfg = cfg
+
+        d_hyp = self._cfg['training']['hyper']
+        super().__init__(**d_hyp)
+
+        self._s_hash    = set()
+        self._data      = {}
+        self._l_ft_name = None
+    # ----------------------------------
+    @property
+    def features(self):
+        '''
+        Returns list of feature names used in training dataset
+        '''
+        return self._l_ft_name
+    # ----------------------------------
+    @property
+    def hashes(self):
+        '''
+        Will return set with hashes of training data
+        '''
+        return self._s_hash
+    # ----------------------------------
+    @property
+    def cfg(self):
+        '''
+        Will return dictionary with configuration
+        '''
+
+        return self._cfg
+    # ----------------------------------
+    def __str__(self):
+        nhash = len(self._s_hash)
+
+        msg  = 40 * '-' + '\n'
+        msg += f'{"Attribute":<20}{"Value":<20}\n'
+        msg += 40 * '-' + '\n'
+        msg += f'{"Hashes":<20}{nhash:<20}\n'
+        msg += 40 * '-'
+
+        return msg
+    # ----------------------------------
+    def fit(self, *args, **kwargs):
+        '''
+        Runs the training of the model
+        '''
+        log.debug('Fitting')
+
+        df_ft           = args[0]
+        self._l_ft_name = list(df_ft.columns)
+
+        self._s_hash = ut.get_hashes(df_ft)
+        log.debug(f'Saving {len(self._s_hash)} hashes')
+
+        super().fit(*args, **kwargs)
+
+        return self
+    # ----------------------------------
+    def _check_hashes(self, df_ft):
+        '''
+        Will check that the hashes of the passed features do not intersect with the
+        hashes of the features used for the training.
+        Otherwise it will raise a CVSameData exception
+        '''
+
+        if len(self._s_hash) == 0:
+            raise ValueError('Found no hashes in model')
+
+        s_hash  = ut.get_hashes(df_ft)
+        s_inter = self._s_hash.intersection(s_hash)
+
+        nh1 = len(self._s_hash)
+        nh2 = len(     s_hash)
+        nh3 = len(    s_inter)
+
+        if nh3 > 0:
+            raise CVSameData(f'Found non-empty intersection of size: {nh1} ^ {nh2} = {nh3}')
+    # ----------------------------------
+    def predict_proba(self, X, on_training_ok=False):
+        '''
+        Takes pandas dataframe with features.
+        Will first check hashes to make sure none of the events/samples
+        used for the training of this model are in the prediction
+
+        on_training_ok (bool): True if the dataset is expected to contain samples used for training, default is False
+        '''
+        if not on_training_ok:
+            self._check_hashes(X)
+
+        return super().predict_proba(X)
+# ---------------------------------------
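A usage sketch of CVClassifier; the configuration is hypothetical, but since `hyper` is forwarded verbatim to GradientBoostingClassifier, any of its keyword arguments can go there:

    import pandas as pnd
    from dmu.ml.cv_classifier import CVClassifier

    cfg   = {'training': {'hyper': {'n_estimators': 100, 'max_depth': 3}}}
    model = CVClassifier(cfg=cfg)

    df_ft = pnd.DataFrame({'x': [0., 1., 2., 3.], 'y': [3., 2., 1., 0.]})
    model.fit(df_ft, [0, 0, 1, 1])

    # model.predict_proba(df_ft) would raise CVSameData here, since these
    # rows were hashed during fit; the flag below skips the check
    arr_prb = model.predict_proba(df_ft, on_training_ok=True)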
dmu/ml/cv_predict.py
ADDED
@@ -0,0 +1,152 @@
+'''
+Module holding CVPredict class
+'''
+from typing import Optional
+
+import pandas as pnd
+import numpy
+import tqdm
+
+from ROOT import RDataFrame
+
+import dmu.ml.utilities as ut
+from dmu.ml.cv_classifier import CVClassifier
+
+from dmu.logging.log_store import LogStore
+
+log = LogStore.add_logger('dmu:ml:cv_predict')
+# ---------------------------------------
+class CVPredict:
+    '''
+    Class used to get classification probabilities from a ROOT
+    dataframe and a set of models. The models were trained with CVClassifier
+    '''
+    def __init__(self, models : Optional[list] = None, rdf : Optional[RDataFrame] = None):
+        '''
+        Will take a list of CVClassifier models and a ROOT dataframe
+        '''
+
+        if models is None:
+            raise ValueError('No list of models passed')
+
+        if rdf is None:
+            raise ValueError('No ROOT dataframe passed')
+
+        self._l_model = models
+        self._rdf     = rdf
+
+        self._arr_patch : numpy.ndarray
+    # --------------------------------------------
+    def _get_df(self):
+        '''
+        Will turn the ROOT dataframe into a pandas dataframe and return it
+        '''
+        model  = self._l_model[0]
+        l_ft   = model.features
+        d_data = self._rdf.AsNumpy(l_ft)
+        df_ft  = pnd.DataFrame(d_data)
+        df_ft  = ut.patch_and_tag(df_ft)
+
+        if 'patched_indices' in df_ft.attrs:
+            self._arr_patch = df_ft.attrs['patched_indices']
+
+        nfeat = len(l_ft)
+        log.info(f'Found {nfeat} features')
+        for name in l_ft:
+            log.debug(name)
+
+        return df_ft
+    # --------------------------------------------
+    def _non_overlapping_hashes(self, model, df_ft):
+        '''
+        Will return True if hashes of model and data do not overlap
+        '''
+
+        s_mod_hash = model.hashes
+        s_dff_hash = ut.get_hashes(df_ft)
+
+        s_int = s_mod_hash.intersection(s_dff_hash)
+        if len(s_int) == 0:
+            return True
+
+        return False
+    # --------------------------------------------
+    def _predict_with_overlap(self, df_ft : pnd.DataFrame) -> numpy.ndarray:
+        '''
+        Takes pandas dataframe with features.
+
+        Will return numpy array of prediction probabilities when there is an overlap
+        of data and model hashes
+        '''
+        df_ft  = ut.index_with_hashes(df_ft)
+        d_prob = {}
+        ntotal = len(df_ft)
+        log.debug(30 * '-')
+        log.info(f'Total size: {ntotal}')
+        log.debug(30 * '-')
+        for model in tqdm.tqdm(self._l_model, ascii=' -'):
+            d_prob_tmp = self._evaluate_model(model, df_ft)
+            d_prob.update(d_prob_tmp)
+
+        ndata = len(df_ft)
+        nprob = len(d_prob)
+        if ndata != nprob:
+            log.warning(f'Dataset size ({ndata}) and probabilities size ({nprob}) differ, likely there are repeated entries')
+
+        l_prob = [ d_prob[hsh] for hsh in df_ft.index ]
+
+        return numpy.array(l_prob)
+    # --------------------------------------------
+    def _evaluate_model(self, model : CVClassifier, df_ft : pnd.DataFrame) -> dict[str, numpy.ndarray]:
+        '''
+        Evaluate one fold's model on only the entries whose hashes were not used to train it
+        '''
+        s_dat_hash = set(df_ft.index)
+        s_mod_hash = model.hashes
+
+        s_dif_hash = s_dat_hash - s_mod_hash
+
+        ndif = len(s_dif_hash)
+        ndat = len(s_dat_hash)
+        nmod = len(s_mod_hash)
+        log.debug(f'{ndif:<20}{"=":10}{ndat:<20}{"-":10}{nmod:<20}')
+
+        df_ft_group = df_ft.loc[df_ft.index.isin(s_dif_hash)]
+
+        l_prob = model.predict_proba(df_ft_group)
+        l_hash = list(df_ft_group.index)
+        d_prob = dict(zip(l_hash, l_prob))
+        nfeat  = len(df_ft_group)
+        nprob  = len(l_prob)
+        log.debug(f'{nfeat:<10}{"->":10}{nprob:<10}')
+
+        return d_prob
+    # --------------------------------------------
+    def _patch_probabilities(self, arr_prb : numpy.ndarray) -> numpy.ndarray:
+        if not hasattr(self, '_arr_patch'):
+            return arr_prb
+
+        nentries = len(self._arr_patch)
+        log.warning(f'Patching {nentries} probabilities')
+        arr_prb[self._arr_patch] = -1
+
+        return arr_prb
+    # --------------------------------------------
+    def predict(self) -> numpy.ndarray:
+        '''
+        Will return array of prediction probabilities for the signal category
+        '''
+        df_ft = self._get_df()
+        model = self._l_model[0]
+
+        if self._non_overlapping_hashes(model, df_ft):
+            log.debug('No intersecting hashes found between model and data')
+            arr_prb = model.predict_proba(df_ft)
+        else:
+            log.info('Intersecting hashes found between model and data')
+            arr_prb = self._predict_with_overlap(df_ft)
+
+        arr_prb = self._patch_probabilities(arr_prb)
+
+        return arr_prb
+# ---------------------------------------
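A usage sketch of CVPredict; the file name, tree name and model list are placeholders for whatever the per-fold training produced:

    from ROOT import RDataFrame
    from dmu.ml.cv_predict import CVPredict

    rdf     = RDataFrame('tree', 'input.root')
    l_model = [...]              # per-fold CVClassifier instances, e.g. loaded from disk

    cvp     = CVPredict(models=l_model, rdf=rdf)
    arr_prb = cvp.predict()      # one probability row per entry; patched entries are set to -1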