data-manipulation-utilities 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. data_manipulation_utilities-0.0.1.dist-info/METADATA +713 -0
  2. data_manipulation_utilities-0.0.1.dist-info/RECORD +45 -0
  3. data_manipulation_utilities-0.0.1.dist-info/WHEEL +5 -0
  4. data_manipulation_utilities-0.0.1.dist-info/entry_points.txt +6 -0
  5. data_manipulation_utilities-0.0.1.dist-info/top_level.txt +3 -0
  6. dmu/arrays/utilities.py +55 -0
  7. dmu/dataframe/dataframe.py +36 -0
  8. dmu/generic/utilities.py +69 -0
  9. dmu/logging/log_store.py +129 -0
  10. dmu/ml/cv_classifier.py +122 -0
  11. dmu/ml/cv_predict.py +152 -0
  12. dmu/ml/train_mva.py +257 -0
  13. dmu/ml/utilities.py +132 -0
  14. dmu/plotting/plotter.py +227 -0
  15. dmu/plotting/plotter_1d.py +113 -0
  16. dmu/plotting/plotter_2d.py +87 -0
  17. dmu/rdataframe/atr_mgr.py +79 -0
  18. dmu/rdataframe/utilities.py +72 -0
  19. dmu/rfile/rfprinter.py +91 -0
  20. dmu/rfile/utilities.py +34 -0
  21. dmu/stats/fitter.py +515 -0
  22. dmu/stats/function.py +314 -0
  23. dmu/stats/utilities.py +134 -0
  24. dmu/testing/utilities.py +119 -0
  25. dmu/text/transformer.py +182 -0
  26. dmu_data/__init__.py +0 -0
  27. dmu_data/ml/tests/train_mva.yaml +37 -0
  28. dmu_data/plotting/tests/2d.yaml +14 -0
  29. dmu_data/plotting/tests/fig_size.yaml +13 -0
  30. dmu_data/plotting/tests/high_stat.yaml +22 -0
  31. dmu_data/plotting/tests/name.yaml +14 -0
  32. dmu_data/plotting/tests/no_bounds.yaml +12 -0
  33. dmu_data/plotting/tests/simple.yaml +8 -0
  34. dmu_data/plotting/tests/title.yaml +14 -0
  35. dmu_data/plotting/tests/weights.yaml +13 -0
  36. dmu_data/text/transform.toml +4 -0
  37. dmu_data/text/transform.txt +6 -0
  38. dmu_data/text/transform_set.toml +8 -0
  39. dmu_data/text/transform_set.txt +6 -0
  40. dmu_data/text/transform_trf.txt +12 -0
  41. dmu_scripts/physics/check_truth.py +121 -0
  42. dmu_scripts/rfile/compare_root_files.py +299 -0
  43. dmu_scripts/rfile/print_trees.py +35 -0
  44. dmu_scripts/ssh/coned.py +168 -0
  45. dmu_scripts/text/transform_text.py +46 -0
data_manipulation_utilities-0.0.1.dist-info/RECORD ADDED
@@ -0,0 +1,45 @@
+ dmu/arrays/utilities.py,sha256=PKoYyybPptA2aU-V3KLnJXBudWxTXu4x1uGdIMQ49HY,1722
+ dmu/dataframe/dataframe.py,sha256=ZgRCw5hN18gOXGL9nHDc4eNi0P8lOIAIEILmbEiTlXw,1088
+ dmu/generic/utilities.py,sha256=0Xnq9t35wuebAqKxbyAiMk1ISB7IcXK4cFH25MT1fgw,1741
+ dmu/logging/log_store.py,sha256=v0tiNz-6ktT_afD5DuvCZ8Nmr82JKQOPli8hgd28P1Q,3960
+ dmu/ml/cv_classifier.py,sha256=n81m7i2M6Zq96AEd9EZGwXSrbG5m9jkS5RdeXvbsAXU,3712
+ dmu/ml/cv_predict.py,sha256=Bqxu-f6qquKJokFljhCzL_kiGcjLJLQFhVBD130fsyw,4893
+ dmu/ml/train_mva.py,sha256=d_n-A07DFweikz5nXap4OE_Mqx8VprFT7zbxmnQAbac,9638
+ dmu/ml/utilities.py,sha256=Nue7O9zi1QXgjGRPH6wnSAW9jusMQ2ZOSDJzBqJKIi0,3687
+ dmu/plotting/plotter.py,sha256=laa6Kl7P-ZOIhaOFBVjOH4XQ4kPCV7wBNvLIMBnyCwM,7181
+ dmu/plotting/plotter_1d.py,sha256=G-i94uzm2TjNaog1A4agAKar_G0qNdkAqIPCmzhe85Y,3660
+ dmu/plotting/plotter_2d.py,sha256=SWPKns-CfpUZHgBXvwm3gceH3k2eL_mKGXQ8sWpZJB0,2919
+ dmu/rdataframe/atr_mgr.py,sha256=FdhaQWVpsm4OOe1IRbm7rfrq8VenTNdORyI-lZ2Bs1M,2386
+ dmu/rdataframe/utilities.py,sha256=a31PdUz12sC2bx78LK6gvACh1M_eFaIVwuZEvOTcvcc,2084
+ dmu/rfile/rfprinter.py,sha256=vGdqyHT_GwGBhrY7KG63EAUGWEOqobz_5yTL6goXbfk,2722
+ dmu/rfile/utilities.py,sha256=XuYY7HuSBj46iSu3c60UYBHtI6KIPoJU_oofuhb-be0,945
+ dmu/stats/fitter.py,sha256=LDvFNyhgO0OzXN7aH3kfHe6LzuPqdQfPcKR_IegDcaU,18204
+ dmu/stats/function.py,sha256=yzi_Fvp_ASsFzbWFivIf-comquy21WoeY7is6dgY0Go,9491
+ dmu/stats/utilities.py,sha256=LQy4kd3xSXqpApcWuYfZxkGQyjowaXv2Wr1c4Bj-4ys,4523
+ dmu/testing/utilities.py,sha256=WbMM4e9Cn3-B-12Vr64mB5qTKkV32joStlRkD-48lG0,3460
+ dmu/text/transformer.py,sha256=4lrGknbAWRm0-rxbvgzOO-eR1-9bkYk61boJUEV3cQ0,6100
+ dmu_data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ dmu_data/ml/tests/train_mva.yaml,sha256=TCniCVpXMEFxZcHa8IIqollKA7ci4OkBnRznLEkXM9o,925
+ dmu_data/plotting/tests/2d.yaml,sha256=lTMNheK3DB8klp4O5QjMDwBI1A1Oh2_Wp2F2Ro9VQKM,282
+ dmu_data/plotting/tests/fig_size.yaml,sha256=7ROq49nwZ1A2EbPiySmu6n3G-Jq6YAOkc3d2X3YNZv0,294
+ dmu_data/plotting/tests/high_stat.yaml,sha256=bLglBLCZK6ft0xMhQ5OltxE76cWsBMPMjO6GG0OkDr8,522
+ dmu_data/plotting/tests/name.yaml,sha256=mkcPAVg8wBAmlSbSRQ1bcaMl4vOS6LXMtpqQeDrrtO4,312
+ dmu_data/plotting/tests/no_bounds.yaml,sha256=8e1QdphBjz-suDr857DoeUC2DXiy6SE-gvkORJQYv80,257
+ dmu_data/plotting/tests/simple.yaml,sha256=N_TvNBh_2dU0-VYgu_LMrtY0kV_hg2HxVuEoDlr1HX8,138
+ dmu_data/plotting/tests/title.yaml,sha256=bawKp9aGpeRrHzv69BOCbFX8sq9bb3Es9tdsPTE7jIk,333
+ dmu_data/plotting/tests/weights.yaml,sha256=RWQ1KxbCq-uO62WJ2AoY4h5Umc37zG35s-TpKnNMABI,312
+ dmu_data/text/transform.toml,sha256=R-832BZalzHZ6c5gD6jtT_Hj8BCsM5vxa1v6oeiwaP4,94
+ dmu_data/text/transform.txt,sha256=EX760da6Vkf-_EPxnQlC5hGSkfFhJCCGCD19NU-1Qto,44
+ dmu_data/text/transform_set.toml,sha256=Jeh7BTz82idqvbOQJtl9-ur56mZkzDn5WtvmIb48LoE,150
+ dmu_data/text/transform_set.txt,sha256=1KivMoP9LxPn9955QrRmOzjEqduEjhTetQ9MXykO5LY,46
+ dmu_data/text/transform_trf.txt,sha256=zxBRTgcSmX7RdqfmWF88W1YqbyNHa4Ccruf1MmnYv2A,74
+ dmu_scripts/physics/check_truth.py,sha256=b1P_Pa9ef6VcFtyY6Y9KS9Om9L-QrCBjDKp4dqca0PQ,3964
+ dmu_scripts/rfile/compare_root_files.py,sha256=T8lDnQxsRNMr37x1Y7YvWD8ySHrJOWZki7ZQynxXX9Q,9540
+ dmu_scripts/rfile/print_trees.py,sha256=Ze4Ccl_iUldl4eVEDVnYBoe4amqBT1fSBR1zN5WSztk,941
+ dmu_scripts/ssh/coned.py,sha256=lhilYNHWRCGxC-jtyJ3LQ4oUgWW33B2l1tYCcyHHsR0,4858
+ dmu_scripts/text/transform_text.py,sha256=9akj1LB0HAyopOvkLjNOJiptZw5XoOQLe17SlcrGMD0,1456
+ data_manipulation_utilities-0.0.1.dist-info/METADATA,sha256=chAPGy68TwWTweqqzGgXvrD6-xia1Bq9XDgvgS1qLEE,19714
+ data_manipulation_utilities-0.0.1.dist-info/WHEEL,sha256=R06PA3UVYHThwHvxuRWMqaGcr-PuniXahwjmQRFMEkY,91
+ data_manipulation_utilities-0.0.1.dist-info/entry_points.txt,sha256=1TIZDed651KuOH-DgaN5AoBdirKmrKE_oM1b6b7zTUU,270
+ data_manipulation_utilities-0.0.1.dist-info/top_level.txt,sha256=n_x5J6uWtSqy9mRImKtdA2V2NJNyU8Kn3u8DTOKJix0,25
+ data_manipulation_utilities-0.0.1.dist-info/RECORD,,
data_manipulation_utilities-0.0.1.dist-info/WHEEL ADDED
@@ -0,0 +1,5 @@
+ Wheel-Version: 1.0
+ Generator: setuptools (75.5.0)
+ Root-Is-Purelib: true
+ Tag: py3-none-any
+
data_manipulation_utilities-0.0.1.dist-info/entry_points.txt ADDED
@@ -0,0 +1,6 @@
+ [console_scripts]
+ check_truth = dmu_scripts.physics.check_truth:main
+ compare_root_files = dmu_scripts.rfile.compare_root_files:main
+ coned = dmu_scripts.ssh.coned:main
+ print_trees = dmu_scripts.rfile.print_trees:main
+ transform_text = dmu_scripts.text.transform_text:main
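Each console script maps a shell command onto a module's `main` function; after installing the wheel, running for instance `transform_text` is roughly equivalent to this Python sketch (argument handling is left to the script itself):

    # Rough equivalent of invoking `transform_text` from the shell
    from dmu_scripts.text.transform_text import main

    main()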
data_manipulation_utilities-0.0.1.dist-info/top_level.txt ADDED
@@ -0,0 +1,3 @@
+ dmu
+ dmu_data
+ dmu_scripts
dmu/arrays/utilities.py ADDED
@@ -0,0 +1,55 @@
+ '''
+ Module with utility functions to handle numpy arrays
+ '''
+
+ import math
+ import numpy
+
+ #-----------------------------------------------
+ def _check_ftimes(ftimes : float) -> None:
+     '''
+     Check that the floating-point scale factor makes sense
+     '''
+     if not isinstance(ftimes, float):
+         raise TypeError(f'Scaling factor is not a float, but: {ftimes}')
+
+     if ftimes <= 1.0:
+         raise ValueError(f'Scaling factor needs to be larger than 1.0, found: {ftimes}')
+ #-----------------------------------------------
+ def repeat_arr(arr_val : numpy.ndarray, ftimes : float) -> numpy.ndarray:
+     '''
+     Repeats the elements of an array a non-integer number of times.
+
+     arr_val: 1D array of objects
+     ftimes (float): Number of times to repeat them.
+     '''
+
+     _check_ftimes(ftimes)
+
+     a = math.floor(ftimes)
+     b = math.ceil(ftimes)
+     if numpy.isclose(a, b):
+         return numpy.repeat(arr_val, a)
+
+     # Split the data in arr_val randomly, such that one subset is repeated
+     # floor(ftimes) times and the other ceil(ftimes) times
+
+     # Probability that a given element belongs to the subset repeated "a" times
+     p      = b - ftimes
+     size_t = len(arr_val)
+     size_1 = int(p * size_t)
+
+     # Find subset to repeat "a" times
+     arr_ind_1 = numpy.random.choice(size_t, size=size_1, replace=False)
+     arr_val_1 = arr_val[arr_ind_1]
+
+     # Find subset to repeat "b" times
+     arr_ind_2 = numpy.setdiff1d(numpy.arange(size_t), arr_ind_1)
+     arr_val_2 = arr_val[arr_ind_2]
+
+     # Repeat each subset an integer number of times
+     arr_val_1 = numpy.repeat(arr_val_1, a)
+     arr_val_2 = numpy.repeat(arr_val_2, b)
+
+     return numpy.concatenate([arr_val_1, arr_val_2])
+ #-----------------------------------------------
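A minimal usage sketch of `repeat_arr`; the toy array and the 2.5 scale factor are illustrative, not from the package:

    import numpy

    from dmu.arrays.utilities import repeat_arr

    arr_inp = numpy.array([1, 2, 3, 4])
    # Each element is repeated 2 or 3 times, with the random split chosen so
    # that the output size is approximately 2.5 * len(arr_inp)
    arr_out = repeat_arr(arr_val=arr_inp, ftimes=2.5)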
dmu/dataframe/dataframe.py ADDED
@@ -0,0 +1,36 @@
+ '''
+ Class re-implementing dataframe, meant to be a thin layer on top of polars
+ '''
+
+ import polars as pl
+
+ # ------------------------------------------
+ class DataFrame(pl.DataFrame):
+     '''
+     Class reimplementing dataframes from polars
+     '''
+     # ------------------------------------------
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+     # ------------------------------------------
+     def define(self, name : str, expr : str):
+         '''
+         Defines a new column in the dataframe
+
+         name (str): Name of the new column
+         expr (str): Expression in terms of existing columns
+         '''
+         # Pad the column names with spaces, then replace each padded
+         # name with the corresponding polars expression
+         for col in self.columns:
+             expr = expr.replace(col, f' {col} ')
+
+         for col in self.columns:
+             expr = expr.replace(f' {col} ', f' pl.col("{col}") ')
+
+         try:
+             df = self.with_columns(eval(expr).alias(name))
+         except TypeError as exc:
+             raise TypeError(f'Cannot define {expr} -> {name}') from exc
+
+         return DataFrame(df)
+     # ------------------------------------------
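A minimal usage sketch of `define`; the column names and data are made up for illustration:

    from dmu.dataframe.dataframe import DataFrame

    df = DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]})
    # The string expression references existing columns by name and is
    # translated internally into polars expressions before evaluation
    df = df.define('c', 'a + b')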
dmu/generic/utilities.py ADDED
@@ -0,0 +1,69 @@
+ '''
+ Module containing generic utility functions
+ '''
+ import os
+ import time
+ import json
+ import inspect
+
+ from typing import Callable
+
+ from functools import wraps
+ from dmu.logging.log_store import LogStore
+
+ TIMER_ON=False
+
+ log = LogStore.add_logger('dmu:generic:utilities')
+
+ # --------------------------------
+ def _get_module_name(fun : Callable) -> str:
+     mod = inspect.getmodule(fun)
+     if mod is None:
+         raise ValueError(f'Cannot determine module name for function: {fun}')
+
+     return mod.__name__
+ # --------------------------------
+ def timeit(f):
+     '''
+     Decorator used to time functions. It is turned off by default and is turned on
+     through the module attribute, e.g.:
+
+     import dmu.generic.utilities as gut
+
+     gut.TIMER_ON = True
+
+     @gut.timeit
+     def fun():
+         ...
+     '''
+     @wraps(f)
+     def wrap(*args, **kw):
+         if not TIMER_ON:
+             result = f(*args, **kw)
+             return result
+
+         ts      = time.time()
+         result  = f(*args, **kw)
+         te      = time.time()
+         mod_nam = _get_module_name(f)
+         fun_nam = f.__name__
+         log.info(f'{mod_nam}.py:{fun_nam}; Time: {te-ts:.3f}s')
+
+         return result
+     return wrap
+ # --------------------------------
+ def dump_json(data, path : str, sort_keys : bool = False):
+     '''
+     Saves data as JSON
+
+     Parameters
+     data     : Dictionary, list, etc.
+     path     : Path to the JSON file where the data is saved
+     sort_keys: Passed to the sort_keys argument of json.dump
+     '''
+     # Guard against paths without a directory component
+     dir_name = os.path.dirname(path)
+     if dir_name:
+         os.makedirs(dir_name, exist_ok=True)
+
+     with open(path, 'w', encoding='utf-8') as ofile:
+         json.dump(data, ofile, indent=4, sort_keys=sort_keys)
+ # --------------------------------
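A short sketch exercising both utilities, following the docstring above; the timed function and the output path are illustrative:

    import dmu.generic.utilities as gut

    gut.TIMER_ON = True

    @gut.timeit
    def fun():
        return sum(range(1000))

    fun()                                          # logs the execution time
    gut.dump_json({'x': 1}, '/tmp/dmu/data.json')  # creates /tmp/dmu if needed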
dmu/logging/log_store.py ADDED
@@ -0,0 +1,129 @@
+ '''
+ Module holding LogStore
+ '''
+
+ import logging
+ import logzero
+
+ #------------------------------------------------------------
+ class StoreFormatter(logging.Formatter):
+     '''
+     Custom formatter
+     '''
+     LOG_COLORS = {
+         logging.DEBUG   : '\033[94m',   # Bright blue
+         logging.INFO    : '\033[37m',   # White
+         logging.WARNING : '\033[93m',   # Yellow
+         logging.ERROR   : '\033[91m',   # Red
+         logging.CRITICAL: '\033[1;91m', # Bold red
+     }
+
+     RESET_COLOR = '\033[0m' # Reset color to default
+
+     def format(self, record):
+         log_color = self.LOG_COLORS.get(record.levelno, self.RESET_COLOR)
+         message   = super().format(record)
+
+         return f'{log_color}{message}{self.RESET_COLOR}'
+ #------------------------------------------------------------
+ class LogStore:
+     '''
+     Class used to make loggers, set log levels, print loggers, etc.; an interface to logging/logzero
+     '''
+     #pylint: disable = invalid-name
+     d_logger      = {}
+     d_levels      = {}
+     log_level     = logging.INFO
+     is_configured = False
+     backend       = 'logging'
+     #--------------------------
+     @staticmethod
+     def add_logger(name=None):
+         '''
+         Uses the underlying logging library (logging, logzero, etc.) to make a logger
+
+         name (str): Name of logger
+         '''
+
+         if name is None:
+             raise ValueError('Logger name missing')
+
+         if name in LogStore.d_logger:
+             raise ValueError(f'Logger name {name} already found')
+
+         level = LogStore.log_level if name not in LogStore.d_levels else LogStore.d_levels[name]
+
+         if LogStore.backend == 'logging':
+             logger = LogStore._get_logging_logger(name, level)
+         elif LogStore.backend == 'logzero':
+             logger = LogStore._get_logzero_logger(name, level)
+         else:
+             raise ValueError(f'Invalid backend: {LogStore.backend}')
+
+         LogStore.d_logger[name] = logger
+
+         return logger
+     #--------------------------
+     @staticmethod
+     def _get_logzero_logger(name : str, level : int):
+         log = logzero.setup_logger(name=name)
+         log.setLevel(level)
+
+         return log
+     #--------------------------
+     @staticmethod
+     def _get_logging_logger(name : str, level : int):
+         logger = logging.getLogger(name=name)
+
+         logger.setLevel(level)
+
+         hnd = logging.StreamHandler()
+         hnd.setLevel(level)
+
+         fmt = StoreFormatter('%(asctime)s - %(filename)s:%(lineno)d - %(message)s', datefmt='%H:%M:%S')
+         hnd.setFormatter(fmt)
+
+         if logger.hasHandlers():
+             logger.handlers.clear()
+
+         logger.addHandler(hnd)
+
+         return logger
+     #--------------------------
+     @staticmethod
+     def set_level(name, value):
+         '''
+         Sets the level of a logger; if the logger does not exist yet, the level is stored and applied at creation.
+
+         Parameters:
+         -----------------
+         name (str): Name of logger
+         value (int): 10 debug, 20 info, 30 warning
+         '''
+
+         if name in LogStore.d_logger:
+             lgr = LogStore.d_logger[name]
+             lgr.handlers[0].setLevel(value)
+             lgr.setLevel(value)
+         else:
+             LogStore.d_levels[name] = value
+     #--------------------------
+     @staticmethod
+     def show_loggers():
+         '''
+         Prints loggers and their log levels in two columns
+         '''
+         print(80 * '-')
+         print(f'{"Name":<60}{"Level":<20}')
+         print(80 * '-')
+         for name, logger in LogStore.d_logger.items():
+             print(f'{name:<60}{logger.level:<20}')
+     #--------------------------
+     @staticmethod
+     def set_all_levels(level):
+         '''
+         Sets all loggers to this level (int)
+         '''
+         for name, logger in LogStore.d_logger.items():
+             logger.setLevel(level)
+             print(f'{name:<60}{"->":20}{logger.level:<20}')
+ #------------------------------------------------------------
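A minimal sketch of the intended interface, based on the methods above; the logger name `my_tool:module` is hypothetical:

    from dmu.logging.log_store import LogStore

    # Levels can be stored before the logger exists; they are applied at creation
    LogStore.set_level('my_tool:module', 10)

    log = LogStore.add_logger('my_tool:module')
    log.debug('visible, since the level was set to 10')

    LogStore.show_loggers()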
dmu/ml/cv_classifier.py ADDED
@@ -0,0 +1,122 @@
+ '''
+ Module holding cv_classifier class
+ '''
+
+ from sklearn.ensemble import GradientBoostingClassifier
+
+ from dmu.logging.log_store import LogStore
+ import dmu.ml.utilities as ut
+
+ log = LogStore.add_logger('dmu:ml:CVClassifier')
+
+ # ---------------------------------------
+ class CVSameData(Exception):
+     '''
+     Raised if a model is evaluated on a dataset where at least one of the
+     samples was also used for the training
+     '''
+ # ---------------------------------------
+ class CVClassifier(GradientBoostingClassifier):
+     '''
+     Derived class meant to implement features needed for cross-validation
+     '''
+     # pylint: disable = too-many-ancestors, abstract-method
+     # ----------------------------------
+     def __init__(self, cfg : dict | None = None):
+         '''
+         cfg (dict): Dictionary with configuration, especially the hyperparameters set in the `hyper` field
+         '''
+         if cfg is None:
+             raise ValueError('No configuration was passed')
+
+         self._cfg = cfg
+
+         d_hyp = self._cfg['training']['hyper']
+         super().__init__(**d_hyp)
+
+         self._s_hash    = set()
+         self._data      = {}
+         self._l_ft_name = None
+     # ----------------------------------
+     @property
+     def features(self):
+         '''
+         Returns list of feature names used in training dataset
+         '''
+         return self._l_ft_name
+     # ----------------------------------
+     @property
+     def hashes(self):
+         '''
+         Returns set with hashes of the training data
+         '''
+         return self._s_hash
+     # ----------------------------------
+     @property
+     def cfg(self):
+         '''
+         Returns dictionary with configuration
+         '''
+
+         return self._cfg
+     # ----------------------------------
+     def __str__(self):
+         nhash = len(self._s_hash)
+
+         msg  = 40 * '-' + '\n'
+         msg += f'{"Attribute":<20}{"Value":<20}\n'
+         msg += 40 * '-' + '\n'
+         msg += f'{"Hashes":<20}{nhash:<20}\n'
+         msg += 40 * '-'
+
+         return msg
+     # ----------------------------------
+     def fit(self, *args, **kwargs):
+         '''
+         Runs the training of the model
+         '''
+         log.debug('Fitting')
+
+         df_ft = args[0]
+         self._l_ft_name = list(df_ft.columns)
+
+         self._s_hash = ut.get_hashes(df_ft)
+         log.debug(f'Saving {len(self._s_hash)} hashes')
+
+         super().fit(*args, **kwargs)
+
+         return self
+     # ----------------------------------
+     def _check_hashes(self, df_ft):
+         '''
+         Checks that the hashes of the passed features do not intersect with the
+         hashes of the features used for the training.
+         Otherwise it raises a CVSameData exception.
+         '''
+
+         if len(self._s_hash) == 0:
+             raise ValueError('Found no hashes in model')
+
+         s_hash  = ut.get_hashes(df_ft)
+         s_inter = self._s_hash.intersection(s_hash)
+
+         nh1 = len(self._s_hash)
+         nh2 = len(s_hash)
+         nh3 = len(s_inter)
+
+         if nh3 > 0:
+             raise CVSameData(f'Found non-empty intersection of size: {nh1} ^ {nh2} = {nh3}')
+     # ----------------------------------
+     def predict_proba(self, X, on_training_ok=False):
+         '''
+         Takes pandas dataframe with features.
+         First checks the hashes to make sure none of the events/samples
+         used for the training of this model are in the prediction.
+
+         on_training_ok (bool): True if the dataset is expected to contain samples used for training, default is False
+         '''
+         if not on_training_ok:
+             self._check_hashes(X)
+
+         return super().predict_proba(X)
+ # ---------------------------------------
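A sketch of how the classifier might be constructed and trained. The configuration layout follows the `training.hyper` lookup in `__init__`; the hyperparameter values and the toy data are assumptions, and `dmu.ml.utilities.get_hashes` is assumed to accept a pandas DataFrame, as `fit` implies:

    import pandas as pnd

    from dmu.ml.cv_classifier import CVClassifier

    # Hyperparameters are forwarded to GradientBoostingClassifier
    cfg = {'training': {'hyper': {'n_estimators': 100, 'max_depth': 3}}}

    df_ft = pnd.DataFrame({'x': [0.1, 0.2, 0.9, 1.1], 'y': [1.0, 0.9, 0.2, 0.1]})
    l_lab = [0, 0, 1, 1]

    model = CVClassifier(cfg=cfg)
    model.fit(df_ft, l_lab)

    # Evaluating on the training sample raises CVSameData unless explicitly allowed
    arr_prb = model.predict_proba(df_ft, on_training_ok=True)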
dmu/ml/cv_predict.py ADDED
@@ -0,0 +1,152 @@
+ '''
+ Module holding CVPredict class
+ '''
+ from typing import Optional
+
+ import pandas as pnd
+ import numpy
+ import tqdm
+
+ from ROOT import RDataFrame
+
+ import dmu.ml.utilities as ut
+
+ from dmu.ml.cv_classifier  import CVClassifier
+ from dmu.logging.log_store import LogStore
+
+ log = LogStore.add_logger('dmu:ml:cv_predict')
+ # ---------------------------------------
+ class CVPredict:
+     '''
+     Class used to get classification probabilities from a ROOT
+     dataframe and a set of models. The models were trained with CVClassifier.
+     '''
+     def __init__(self, models : Optional[list] = None, rdf : Optional[RDataFrame] = None):
+         '''
+         Takes a list of CVClassifier models and a ROOT dataframe
+         '''
+
+         if models is None:
+             raise ValueError('No list of models passed')
+
+         if rdf is None:
+             raise ValueError('No ROOT dataframe passed')
+
+         self._l_model = models
+         self._rdf     = rdf
+
+         self._arr_patch : numpy.ndarray
+     # --------------------------------------------
+     def _get_df(self):
+         '''
+         Turns the ROOT dataframe into a pandas dataframe and returns it
+         '''
+         model  = self._l_model[0]
+         l_ft   = model.features
+         d_data = self._rdf.AsNumpy(l_ft)
+         df_ft  = pnd.DataFrame(d_data)
+         df_ft  = ut.patch_and_tag(df_ft)
+
+         if 'patched_indices' in df_ft.attrs:
+             self._arr_patch = df_ft.attrs['patched_indices']
+
+         nfeat = len(l_ft)
+         log.info(f'Found {nfeat} features')
+         for name in l_ft:
+             log.debug(name)
+
+         return df_ft
+     # --------------------------------------------
+     def _non_overlapping_hashes(self, model, df_ft):
+         '''
+         Returns True if the hashes of the model and the data do not overlap
+         '''
+
+         s_mod_hash = model.hashes
+         s_dff_hash = ut.get_hashes(df_ft)
+
+         s_int = s_mod_hash.intersection(s_dff_hash)
+         if len(s_int) == 0:
+             return True
+
+         return False
+     # --------------------------------------------
+     def _predict_with_overlap(self, df_ft : pnd.DataFrame) -> numpy.ndarray:
+         '''
+         Takes pandas dataframe with features.
+
+         Returns numpy array of prediction probabilities when there is an overlap
+         of data and model hashes.
+         '''
+         df_ft  = ut.index_with_hashes(df_ft)
+         d_prob = {}
+         ntotal = len(df_ft)
+         log.debug(30 * '-')
+         log.info(f'Total size: {ntotal}')
+         log.debug(30 * '-')
+         for model in tqdm.tqdm(self._l_model, ascii=' -'):
+             d_prob_tmp = self._evaluate_model(model, df_ft)
+             d_prob.update(d_prob_tmp)
+
+         ndata = len(df_ft)
+         nprob = len(d_prob)
+         if ndata != nprob:
+             log.warning(f'Dataset size ({ndata}) and probabilities size ({nprob}) differ, likely there are repeated entries')
+
+         l_prob = [ d_prob[hsh] for hsh in df_ft.index ]
+
+         return numpy.array(l_prob)
+     # --------------------------------------------
+     def _evaluate_model(self, model : CVClassifier, df_ft : pnd.DataFrame) -> dict[str, float]:
+         '''
+         Evaluates the dataset for one of the folds, by taking the model and the full dataset
+         '''
+         s_dat_hash = set(df_ft.index)
+         s_mod_hash = model.hashes
+
+         s_dif_hash = s_dat_hash - s_mod_hash
+
+         ndif = len(s_dif_hash)
+         ndat = len(s_dat_hash)
+         nmod = len(s_mod_hash)
+         log.debug(f'{ndif:<20}{"=":10}{ndat:<20}{"-":10}{nmod:<20}')
+
+         df_ft_group = df_ft.loc[df_ft.index.isin(s_dif_hash)]
+
+         l_prob = model.predict_proba(df_ft_group)
+         l_hash = list(df_ft_group.index)
+         d_prob = dict(zip(l_hash, l_prob))
+         nfeat  = len(df_ft_group)
+         nprob  = len(l_prob)
+         log.debug(f'{nfeat:<10}{"->":10}{nprob:<10}')
+
+         return d_prob
+     # --------------------------------------------
+     def _patch_probabilities(self, arr_prb : numpy.ndarray) -> numpy.ndarray:
+         if not hasattr(self, '_arr_patch'):
+             return arr_prb
+
+         nentries = len(self._arr_patch)
+         log.warning(f'Patching {nentries} probabilities')
+         arr_prb[self._arr_patch] = -1
+
+         return arr_prb
+     # --------------------------------------------
+     def predict(self) -> numpy.ndarray:
+         '''
+         Returns array of prediction probabilities for the signal category
+         '''
+         df_ft = self._get_df()
+         model = self._l_model[0]
+
+         if self._non_overlapping_hashes(model, df_ft):
+             log.debug('No intersecting hashes found between model and data')
+             arr_prb = model.predict_proba(df_ft)
+         else:
+             log.info('Intersecting hashes found between model and data')
+             arr_prb = self._predict_with_overlap(df_ft)
+
+         arr_prb = self._patch_probabilities(arr_prb)
+
+         return arr_prb
+ # ---------------------------------------
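A sketch of the expected workflow, assuming an environment where ROOT is available; the pickle file, tree name and ROOT file path are hypothetical stand-ins for real inputs:

    import pickle

    from ROOT import RDataFrame

    from dmu.ml.cv_predict import CVPredict

    # Hypothetical input: the list of per-fold CVClassifier models
    with open('models.pkl', 'rb') as ifile:
        l_model = pickle.load(ifile)

    rdf = RDataFrame('tree', 'file.root')

    # predict() handles train/predict overlaps fold by fold via the stored hashes
    cvp     = CVPredict(models=l_model, rdf=rdf)
    arr_prb = cvp.predict()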