data-manipulation-utilities 0.2.7__py3-none-any.whl → 0.2.8.dev714__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. {data_manipulation_utilities-0.2.7.dist-info → data_manipulation_utilities-0.2.8.dev714.dist-info}/METADATA +641 -44
  2. data_manipulation_utilities-0.2.8.dev714.dist-info/RECORD +93 -0
  3. {data_manipulation_utilities-0.2.7.dist-info → data_manipulation_utilities-0.2.8.dev714.dist-info}/WHEEL +1 -1
  4. {data_manipulation_utilities-0.2.7.dist-info → data_manipulation_utilities-0.2.8.dev714.dist-info}/entry_points.txt +1 -0
  5. dmu/__init__.py +0 -0
  6. dmu/generic/hashing.py +34 -8
  7. dmu/generic/utilities.py +164 -11
  8. dmu/logging/log_store.py +34 -2
  9. dmu/logging/messages.py +96 -0
  10. dmu/ml/cv_classifier.py +3 -3
  11. dmu/ml/cv_diagnostics.py +3 -0
  12. dmu/ml/cv_performance.py +58 -0
  13. dmu/ml/cv_predict.py +149 -46
  14. dmu/ml/train_mva.py +482 -100
  15. dmu/ml/utilities.py +29 -10
  16. dmu/pdataframe/utilities.py +28 -3
  17. dmu/plotting/fwhm.py +2 -2
  18. dmu/plotting/matrix.py +1 -1
  19. dmu/plotting/plotter.py +23 -3
  20. dmu/plotting/plotter_1d.py +96 -32
  21. dmu/plotting/plotter_2d.py +5 -0
  22. dmu/rdataframe/utilities.py +54 -3
  23. dmu/rfile/ddfgetter.py +102 -0
  24. dmu/stats/fit_stats.py +129 -0
  25. dmu/stats/fitter.py +55 -22
  26. dmu/stats/gof_calculator.py +7 -0
  27. dmu/stats/model_factory.py +153 -62
  28. dmu/stats/parameters.py +100 -0
  29. dmu/stats/utilities.py +443 -12
  30. dmu/stats/wdata.py +187 -0
  31. dmu/stats/zfit.py +17 -0
  32. dmu/stats/zfit_plotter.py +147 -36
  33. dmu/testing/utilities.py +102 -24
  34. dmu/workflow/__init__.py +0 -0
  35. dmu/workflow/cache.py +266 -0
  36. dmu_data/ml/tests/train_mva.yaml +9 -7
  37. dmu_data/ml/tests/train_mva_def.yaml +75 -0
  38. dmu_data/ml/tests/train_mva_with_diagnostics.yaml +10 -5
  39. dmu_data/ml/tests/train_mva_with_preffix.yaml +58 -0
  40. dmu_data/plotting/tests/2d.yaml +5 -5
  41. dmu_data/plotting/tests/line.yaml +15 -0
  42. dmu_data/plotting/tests/styling.yaml +8 -1
  43. dmu_data/rfile/friends.yaml +13 -0
  44. dmu_data/stats/fitter/test_simple.yaml +28 -0
  45. dmu_data/stats/kde_optimizer/control.json +1 -0
  46. dmu_data/stats/kde_optimizer/signal.json +1 -0
  47. dmu_data/stats/parameters/data.yaml +178 -0
  48. dmu_data/tests/config.json +6 -0
  49. dmu_data/tests/config.yaml +4 -0
  50. dmu_data/tests/pdf_to_tex.txt +34 -0
  51. dmu_scripts/kerberos/check_expiration +21 -0
  52. dmu_scripts/kerberos/convert_certificate +22 -0
  53. dmu_scripts/ml/compare_classifiers.py +85 -0
  54. data_manipulation_utilities-0.2.7.dist-info/RECORD +0 -69
  55. {data_manipulation_utilities-0.2.7.data → data_manipulation_utilities-0.2.8.dev714.data}/scripts/publish +0 -0
  56. {data_manipulation_utilities-0.2.7.dist-info → data_manipulation_utilities-0.2.8.dev714.dist-info}/top_level.txt +0 -0
data_manipulation_utilities-0.2.8.dev714.dist-info/RECORD ADDED
@@ -0,0 +1,93 @@
1
+ data_manipulation_utilities-0.2.8.dev714.data/scripts/publish,sha256=-3K_Y2_4CfWCV50rPB8CRuhjxDu7xMGswinRwPovgLs,1976
2
+ dmu/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
+ dmu/arrays/utilities.py,sha256=PKoYyybPptA2aU-V3KLnJXBudWxTXu4x1uGdIMQ49HY,1722
4
+ dmu/generic/hashing.py,sha256=QR5Gbv6-ANvi5hL232UNMrw9DONpU27BWTynXGxQLGU,1806
5
+ dmu/generic/utilities.py,sha256=0tT93vF_x0q8STRrTD0GvBEpALz-mqE-vJyen4zWCO8,6861
6
+ dmu/generic/version_management.py,sha256=j0ImlAq6SVNjTh3xRsF6G7DSoyr1w8kTRY84dNriGRE,3750
7
+ dmu/logging/log_store.py,sha256=eRSy8Y4fuiDFJK02Z6fq67XQzOrhQ7GMr2LvvJQbJ40,5172
8
+ dmu/logging/messages.py,sha256=Oj3O5EO2KOPtffyVq2P7RPzjpoXtxZ6yXO5HwTftVcM,2903
9
+ dmu/ml/cv_classifier.py,sha256=6rjezMahwL-WzLGKU-fzMzNxJZAGbM7YAbhaZVcJ3F0,4258
10
+ dmu/ml/cv_diagnostics.py,sha256=PLh41mSVE8Kagp9KcuRDN_7tDL9MjPxQzuewY8jDnNo,7600
11
+ dmu/ml/cv_performance.py,sha256=q9sLxIx7GP-dand3tnhHCBJnT6xqssNdRYv_TVjYWUM,1910
12
+ dmu/ml/cv_predict.py,sha256=0sc_OqwOewKvipcMyi3QqkgG30nkpZZjE-SOhHWHMd0,10778
13
+ dmu/ml/train_mva.py,sha256=7KAFX_zOx8MGbYx62U81JbdBkrZvqclSSkgmYvWX-60,34861
14
+ dmu/ml/utilities.py,sha256=A9j3tBh-jfaFdwwLUleo1QnttfawN7XDiQRh4VTvqVY,4597
15
+ dmu/pdataframe/utilities.py,sha256=xl6iLVKUccqVXYjuHsDUZ6UrCKQPw1k8D-f6407Yq30,2742
16
+ dmu/plotting/fwhm.py,sha256=4e8n6624pxWLcOOtayCQ_hDSSMKU21-3UsdmbkX1ojk,1949
17
+ dmu/plotting/matrix.py,sha256=s_5W8O3yXF3u8OX3f4J4hCoxIVZt1TF8S-qJsFBh2Go,5005
18
+ dmu/plotting/plotter.py,sha256=oc_n9ug0JPaQZycrW_TJkgNxjr0LHNrVJcijqmiLUR4,8136
19
+ dmu/plotting/plotter_1d.py,sha256=Kyoyh-QyZLXXqX19wqEDUWCD1nJEvEonGp9nlgEaoZE,10936
20
+ dmu/plotting/plotter_2d.py,sha256=dXC-7Rsquibe5cn7622ryoKpuv7KCAmouIIXwQ_VEFM,3172
21
+ dmu/plotting/utilities.py,sha256=SI9dvtZq2gr-PXVz71KE4o0i09rZOKgqJKD1jzf6KXk,1167
22
+ dmu/rdataframe/atr_mgr.py,sha256=FdhaQWVpsm4OOe1IRbm7rfrq8VenTNdORyI-lZ2Bs1M,2386
23
+ dmu/rdataframe/utilities.py,sha256=cY1Na8HbJ7kB2dwmBagRdsRyCA4ZT_vyIU86ewREj2Y,5322
24
+ dmu/rfile/ddfgetter.py,sha256=0jfNzpv72_NQUKOK5SBsn289rUqVt2BMvuL-Ro5oY7I,3316
25
+ dmu/rfile/rfprinter.py,sha256=mp5jd-oCJAnuokbdmGyL9i6tK2lY72jEfROuBIZ_ums,3941
26
+ dmu/rfile/utilities.py,sha256=XuYY7HuSBj46iSu3c60UYBHtI6KIPoJU_oofuhb-be0,945
27
+ dmu/stats/fit_stats.py,sha256=wzkQT9U32ljGe4azUj1Fj0ECF3zmnH2Ncn0O-_Pl1zQ,4070
28
+ dmu/stats/fitter.py,sha256=rm_fwjkq-0LSjXB_gt3y6BnHoK8Xvd4gHYwKBUJaItQ,19603
29
+ dmu/stats/function.py,sha256=yzi_Fvp_ASsFzbWFivIf-comquy21WoeY7is6dgY0Go,9491
30
+ dmu/stats/gof_calculator.py,sha256=63zNJJGKPy-j_hPNPfu9qNlhrHjYIgJOyL8-VDtbwuI,4894
31
+ dmu/stats/minimizers.py,sha256=db9R2G0SOV-k0BKi6m4EyB_yp6AtZdP23_28B0315oo,7094
32
+ dmu/stats/model_factory.py,sha256=0_o5OmiX0cNhp9_cNqBOYfasBgKlQkQPiy5nqi9qQKA,18966
33
+ dmu/stats/parameters.py,sha256=9lycexTT5ZcxXciiQY9HoJV8O1ahrTEkagd7dYXcfj8,3224
34
+ dmu/stats/utilities.py,sha256=7_tr1j-dl3lLNpxIMWruZs4yUtlNuUTknwGMERpfLhs,17338
35
+ dmu/stats/wdata.py,sha256=IbjZFU9SHTLSYfaBgqamDvqy1K7-3-SaKbU4bGsamK0,6799
36
+ dmu/stats/zfit.py,sha256=aSZj_4IHi9IBthfqlNJeA8YSoMmXO5WipgiKnXKGbnM,286
37
+ dmu/stats/zfit_models.py,sha256=SI61KJ-OG1UAabDICU1iTh6JPKM3giR2ErDraRjkCV8,1842
38
+ dmu/stats/zfit_plotter.py,sha256=gbN5KxhJcP4ItCi98c-fj5_UtvVWL_NA9jkTHiRjvnE,23854
39
+ dmu/testing/utilities.py,sha256=WYlz7Ve5lQjuWhhNL4gWe6_qcByBLV762Lhrc6A0P9E,7421
40
+ dmu/text/transformer.py,sha256=4lrGknbAWRm0-rxbvgzOO-eR1-9bkYk61boJUEV3cQ0,6100
41
+ dmu/workflow/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
42
+ dmu/workflow/cache.py,sha256=CtkGwxuF4UJlD55SmUJcRgWYLsbZOyUvYLI8oTVzk_g,8768
43
+ dmu_data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
44
+ dmu_data/ml/tests/diagnostics_from_file.yaml,sha256=quvXOPkRducnBsctyape_Rn5_aqMEpPo6nO_UweMORo,404
45
+ dmu_data/ml/tests/diagnostics_from_model.yaml,sha256=rtCQlmGS9ld2xoQJEE35nA07yfRMklEfQEW0w3gRv2A,261
46
+ dmu_data/ml/tests/diagnostics_multiple_methods.yaml,sha256=w8Fpmr7kX1Jsb_h6LL2hiuYKf5lYpckFCpYKzWetbA0,265
47
+ dmu_data/ml/tests/diagnostics_overlay.yaml,sha256=ZVOsxLL8_JQtf41n8Ct-M9Ch10xBwHK54q1fttWPDlE,866
48
+ dmu_data/ml/tests/train_mva.yaml,sha256=KArbTkaj6FqerrUhlkgyBde_4DfkpVza6kCMgMQPi9g,1388
49
+ dmu_data/ml/tests/train_mva_def.yaml,sha256=UyPMo-9nshoB8BHxm9E6S0xd9ngRARdgUq6vnuMlhwI,1765
50
+ dmu_data/ml/tests/train_mva_with_diagnostics.yaml,sha256=-2KKIJ8CiNgMlgpCXkmZRdPEo-sJmAqr01vizfeqkj0,2098
51
+ dmu_data/ml/tests/train_mva_with_preffix.yaml,sha256=Q9SsJSXGbkHWGBvMZIkTZlKNUz5ZcSVBscrKgeMWBvE,1386
52
+ dmu_data/plotting/tests/2d.yaml,sha256=40wKQmNbIabZ7CI8-2QnD6mG1a_B7vEcPdzvehHkseY,520
53
+ dmu_data/plotting/tests/fig_size.yaml,sha256=7ROq49nwZ1A2EbPiySmu6n3G-Jq6YAOkc3d2X3YNZv0,294
54
+ dmu_data/plotting/tests/high_stat.yaml,sha256=bLglBLCZK6ft0xMhQ5OltxE76cWsBMPMjO6GG0OkDr8,522
55
+ dmu_data/plotting/tests/legend.yaml,sha256=wGpj58ig-GOlqbWoN894zrCet2Fj9f5QtY0rig_UC-c,213
56
+ dmu_data/plotting/tests/line.yaml,sha256=EERDeTctbauwqAvmKFXC4Ot3Tgx-8kcIniGbepXwsKs,305
57
+ dmu_data/plotting/tests/name.yaml,sha256=mkcPAVg8wBAmlSbSRQ1bcaMl4vOS6LXMtpqQeDrrtO4,312
58
+ dmu_data/plotting/tests/no_bounds.yaml,sha256=8e1QdphBjz-suDr857DoeUC2DXiy6SE-gvkORJQYv80,257
59
+ dmu_data/plotting/tests/normalized.yaml,sha256=Y0eKtyV5pvlSxvqfsLjytYtv8xYF3HZ5WEdCJdeHGQI,193
60
+ dmu_data/plotting/tests/plug_fwhm.yaml,sha256=xl5LXc9Nt66anM-HOXAxCtlaxWNM7zzIXf1Y6U8M4Wg,449
61
+ dmu_data/plotting/tests/plug_stats.yaml,sha256=ROO8soYXBbZIFYZcGngA_K5XHgIAFCmuAGfZCJgMmd0,384
62
+ dmu_data/plotting/tests/simple.yaml,sha256=Xc59Pjfb3BKMicLVBxODVqomHFupcb5GvefKbKHCQWQ,195
63
+ dmu_data/plotting/tests/stats.yaml,sha256=fSZjoV-xPnukpCH2OAXsz_SNPjI113qzDg8Ln3spaaA,165
64
+ dmu_data/plotting/tests/styling.yaml,sha256=ZglA4fG6gr5Q_K2VinwVDPjIitiFizCzxr-KsHw2ERI,370
65
+ dmu_data/plotting/tests/title.yaml,sha256=bawKp9aGpeRrHzv69BOCbFX8sq9bb3Es9tdsPTE7jIk,333
66
+ dmu_data/plotting/tests/weights.yaml,sha256=RWQ1KxbCq-uO62WJ2AoY4h5Umc37zG35s-TpKnNMABI,312
67
+ dmu_data/rfile/friends.yaml,sha256=sEGKFKK0q1U6b9qlfHUFBLZW0FeruR1t2LCOo6Ck1Rg,264
68
+ dmu_data/stats/fitter/test_simple.yaml,sha256=lBw6igBT57BZnuG3GgoxOiXTMFHfs5LchbI3Ubb8Qz0,1549
69
+ dmu_data/stats/kde_optimizer/control.json,sha256=EiArsHUAHBmzw4gmaNyOOW1ziYtNhdelIAqc3EH0K_M,1327616
70
+ dmu_data/stats/kde_optimizer/signal.json,sha256=MocwnYizcKki4dlxEIsWwE8HzY-ZBQaUo-lrCR5N3Tw,1327616
71
+ dmu_data/stats/parameters/data.yaml,sha256=lNmuolhUQmwB6sxHQvBRm-Kz5MUW_H1qAouynzBiWvs,2087
72
+ dmu_data/tests/config.json,sha256=QSfx-irgPV-BHAVe1Xe1dgiVkZGPp0fxb9OhXeVaEBg,60
73
+ dmu_data/tests/config.yaml,sha256=rFTk9PSFOgEVEcGDxr4K9vFIUrCVhbEMUoj683Py1AQ,38
74
+ dmu_data/tests/pdf_to_tex.txt,sha256=yzzH1L7P2SOFrVxS737Ykg1SlcD0jhrrBwQGsui2oAQ,3854
75
+ dmu_data/text/transform.toml,sha256=R-832BZalzHZ6c5gD6jtT_Hj8BCsM5vxa1v6oeiwaP4,94
76
+ dmu_data/text/transform.txt,sha256=EX760da6Vkf-_EPxnQlC5hGSkfFhJCCGCD19NU-1Qto,44
77
+ dmu_data/text/transform_set.toml,sha256=Jeh7BTz82idqvbOQJtl9-ur56mZkzDn5WtvmIb48LoE,150
78
+ dmu_data/text/transform_set.txt,sha256=1KivMoP9LxPn9955QrRmOzjEqduEjhTetQ9MXykO5LY,46
79
+ dmu_data/text/transform_trf.txt,sha256=zxBRTgcSmX7RdqfmWF88W1YqbyNHa4Ccruf1MmnYv2A,74
80
+ dmu_scripts/git/publish,sha256=-3K_Y2_4CfWCV50rPB8CRuhjxDu7xMGswinRwPovgLs,1976
81
+ dmu_scripts/kerberos/check_expiration,sha256=PRJopcyFSeiAHdWpLEZp6mu_OKctUdIJj0HZfC0EWxg,308
82
+ dmu_scripts/kerberos/convert_certificate,sha256=_4k4fmxpK-MbSLkkRYEPLQc9twfYBqOIiYZqL9yAXKE,445
83
+ dmu_scripts/ml/compare_classifiers.py,sha256=XuHdcVyDLFGoKfvfv6YrgIavRpjpMrnBSqUnlliD7ew,2312
84
+ dmu_scripts/physics/check_truth.py,sha256=b1P_Pa9ef6VcFtyY6Y9KS9Om9L-QrCBjDKp4dqca0PQ,3964
85
+ dmu_scripts/rfile/compare_root_files.py,sha256=T8lDnQxsRNMr37x1Y7YvWD8ySHrJOWZki7ZQynxXX9Q,9540
86
+ dmu_scripts/rfile/print_trees.py,sha256=Ze4Ccl_iUldl4eVEDVnYBoe4amqBT1fSBR1zN5WSztk,941
87
+ dmu_scripts/ssh/coned.py,sha256=lhilYNHWRCGxC-jtyJ3LQ4oUgWW33B2l1tYCcyHHsR0,4858
88
+ dmu_scripts/text/transform_text.py,sha256=9akj1LB0HAyopOvkLjNOJiptZw5XoOQLe17SlcrGMD0,1456
89
+ data_manipulation_utilities-0.2.8.dev714.dist-info/METADATA,sha256=M5n-tPUt3o_0kY4viuQj6lbP4JQxWhpxkSnWCW29PFg,50263
90
+ data_manipulation_utilities-0.2.8.dev714.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
91
+ data_manipulation_utilities-0.2.8.dev714.dist-info/entry_points.txt,sha256=-02cr8ibY6L_reX-_Owz2N7OUQyTAwydRIvLr9kKZK0,332
92
+ data_manipulation_utilities-0.2.8.dev714.dist-info/top_level.txt,sha256=n_x5J6uWtSqy9mRImKtdA2V2NJNyU8Kn3u8DTOKJix0,25
93
+ data_manipulation_utilities-0.2.8.dev714.dist-info/RECORD,,
data_manipulation_utilities-0.2.8.dev714.dist-info/WHEEL CHANGED
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (78.1.0)
2
+ Generator: setuptools (80.9.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
data_manipulation_utilities-0.2.8.dev714.dist-info/entry_points.txt CHANGED
@@ -1,5 +1,6 @@
1
1
  [console_scripts]
2
2
  check_truth = dmu_scripts.physics.check_truth:main
3
+ compare_classifiers = dmu_scripts.ml.compare_classifiers:main
3
4
  compare_root_files = dmu_scripts.rfile.compare_root_files:main
4
5
  coned = dmu_scripts.ssh.coned:main
5
6
  print_trees = dmu_scripts.rfile.print_trees:main
dmu/__init__.py ADDED
File without changes
dmu/generic/hashing.py CHANGED
@@ -2,6 +2,7 @@
2
2
  Module with functions needed to provide hashes
3
3
  '''
4
4
 
5
+ import os
5
6
  import json
6
7
  import hashlib
7
8
  from typing import Any
@@ -12,12 +13,10 @@ from dmu.logging.log_store import LogStore
12
13
  log=LogStore.add_logger('dmu:generic.hashing')
13
14
  # ------------------------------------
14
15
  def _object_to_string(obj : Any) -> str:
15
- try:
16
- string = json.dumps(obj)
17
- except Exception as exc:
18
- raise ValueError(f'Cannot hash object: {obj}') from exc
16
+ def default_encoder(x):
17
+ raise TypeError(f"Unserializable type: {type(x)}")
19
18
 
20
- return string
19
+ return json.dumps(obj, sort_keys=True, default=default_encoder)
21
20
  # ------------------------------------
22
21
  def _dataframe_to_hash(df : pnd.DataFrame) -> str:
23
22
  sr_hash = pnd.util.hash_pandas_object(df, index=True)
@@ -29,16 +28,43 @@ def _dataframe_to_hash(df : pnd.DataFrame) -> str:
29
28
  # ------------------------------------
30
29
  def hash_object(obj : Any) -> str:
31
30
  '''
32
- Function taking a python object and returning
31
+ Function taking a python object and returning
33
32
  a string representing the hash
34
33
  '''
35
34
 
36
35
  if isinstance(obj, pnd.DataFrame):
37
- return _dataframe_to_hash(df=obj)
36
+ value = _dataframe_to_hash(df=obj)
37
+ value = value[:10]
38
+
39
+ return value
38
40
 
39
41
  string = _object_to_string(obj=obj)
40
42
  string_bin = string.encode('utf-8')
41
43
  hsh = hashlib.sha256(string_bin)
44
+ value = hsh.hexdigest()
45
+ value = value[:10]
46
+
47
+ return value
48
+ # ------------------------------------
49
+ def hash_file(path : str) -> str:
50
+ '''
51
+ Parameters
52
+ ----------------
53
+ path: Path to file whose content has to be hashed
54
+
55
+ Returns
56
+ ----------------
57
+ A string representing the hash
58
+ '''
59
+ if not os.path.isfile(path):
60
+ raise FileNotFoundError(f'Cannot find: {path}')
61
+
62
+ h = hashlib.sha256()
63
+ with open(path, 'rb') as f:
64
+ for chunk in iter(lambda: f.read(8192), b''):
65
+ h.update(chunk)
66
+
67
+ value = h.hexdigest()
42
68
 
43
- return hsh.hexdigest()
69
+ return value[:10]
44
70
  # ------------------------------------
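The hunk above truncates the hashes returned by hash_object to 10 characters, switches the serialization to json.dumps with sort_keys=True, and adds a hash_file helper that streams a file in 8192-byte chunks. Below is a minimal usage sketch, not taken from the package documentation; the import path matches the one used in the dmu/generic/utilities.py diff further down, and the temporary file only exists to give hash_file something to read.

import os
import tempfile

from dmu.generic import hashing

# hash_object serializes with json.dumps(sort_keys=True), so key order does not matter,
# and returns the first 10 characters of the SHA-256 hex digest
h1 = hashing.hash_object({'a': 1, 'b': 2})
h2 = hashing.hash_object({'b': 2, 'a': 1})
assert h1 == h2
assert len(h1) == 10

# hash_file streams the file in 8192-byte chunks and truncates the digest the same way;
# a missing path raises FileNotFoundError
with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False) as ofile:
    ofile.write('some text')
    file_path = ofile.name

print(hashing.hash_file(path=file_path))
os.remove(file_path)

Note that objects json cannot serialize now surface a TypeError from the default encoder, whereas the old code wrapped the failure in a ValueError, so callers catching the previous exception may need adjusting.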
dmu/generic/utilities.py CHANGED
@@ -4,17 +4,69 @@ Module containing generic utility functions
4
4
  import os
5
5
  import time
6
6
  import json
7
+ import pickle
7
8
  import inspect
8
-
9
- from typing import Callable
10
-
9
+ from importlib.resources import files
10
+ from typing import Callable, Any
11
11
  from functools import wraps
12
+ from contextlib import contextmanager
13
+
14
+ import yaml
15
+ from omegaconf import OmegaConf, DictConfig
16
+ from dmu.generic import hashing
17
+ from dmu.generic import utilities as gut
12
18
  from dmu.logging.log_store import LogStore
13
19
 
14
20
  TIMER_ON=False
15
21
 
16
22
  log = LogStore.add_logger('dmu:generic:utilities')
23
+ # --------------------------------
24
+ class BlockStyleDumper(yaml.SafeDumper):
25
+ '''
26
+ Class needed to specify proper indentation when
27
+ dumping data to YAML files
28
+ '''
29
+ def increase_indent(self, flow=False, indentless=False):
30
+ return super().increase_indent(flow=flow, indentless=False)
31
+ # ---------------------------------
32
+ def load_data(package : str, fpath : str) -> Any:
33
+ '''
34
+ This function will load a YAML or JSON file from a data package
35
+
36
+ Parameters
37
+ ---------------------
38
+ package: Data package, e.g. `dmu_data`
39
+ path : Path to YAML/JSON file, relative to the data package
40
+
41
+ Returns
42
+ ---------------------
43
+ Dictionary or whatever structure the file is holding
44
+ '''
45
+
46
+ cpath = files(package).joinpath(fpath)
47
+ cpath = str(cpath)
48
+ data = load_json(cpath)
49
+
50
+ return data
51
+ # --------------------------------
52
+ def load_conf(package : str, fpath : str) -> DictConfig:
53
+ '''
54
+ This function will load a YAML or JSON file from a data package
55
+
56
+ Parameters
57
+ ---------------------
58
+ package: Data package, e.g. `dmu_data`
59
+ path : Path to YAML/JSON file, relative to the data package
60
+
61
+ Returns
62
+ ---------------------
63
+ DictConfig class from the OmegaConf package
64
+ '''
65
+
66
+ cpath = files(package).joinpath(fpath)
67
+ cfg = OmegaConf.load(cpath)
17
68
 
69
+ return cfg
18
70
  # --------------------------------
19
71
  def _get_module_name( fun : Callable) -> str:
20
72
  mod = inspect.getmodule(fun)
@@ -28,7 +80,7 @@ def timeit(f):
28
80
  Decorator used to time functions, it is turned off by default, can be turned on with:
29
81
 
30
82
  from dmu.generic.utilities import TIMER_ON
31
- from dmu.generic.utilities import timeit
83
+ from dmu.generic.utilities import timeit
32
84
 
33
85
  TIMER_ON=True
34
86
 
@@ -54,29 +106,130 @@ def timeit(f):
54
106
  # --------------------------------
55
107
  def dump_json(data, path : str, sort_keys : bool = False) -> None:
56
108
  '''
57
- Saves data as JSON
109
+ Saves data as JSON or YAML, depending on the extension, supported .json, .yaml, .yml
58
110
 
59
111
  Parameters
60
112
  data : dictionary, list, etc
61
- path : Path to JSON file where to save it
62
- sort_keys: Will set sort_keys argument of json.dump function
113
+ path : Path to output file where to save it
114
+ sort_keys: Will set sort_keys argument of json.dump function
63
115
  '''
64
116
  dir_name = os.path.dirname(path)
65
117
  os.makedirs(dir_name, exist_ok=True)
66
118
 
67
119
  with open(path, 'w', encoding='utf-8') as ofile:
68
- json.dump(data, ofile, indent=4, sort_keys=sort_keys)
120
+ if path.endswith('.json'):
121
+ json.dump(data, ofile, indent=4, sort_keys=sort_keys)
122
+ return
123
+
124
+ if path.endswith('.yaml') or path.endswith('.yml'):
125
+ yaml.dump(data, ofile, Dumper=BlockStyleDumper, sort_keys=sort_keys)
126
+ return
127
+
128
+ raise NotImplementedError(f'Cannot deduce format from extension in path: {path}')
69
129
  # --------------------------------
70
130
  def load_json(path : str):
71
131
  '''
72
- Loads data from JSON
132
+ Loads data from JSON or YAML, depending on extension of files, supported .json, .yaml, .yml
73
133
 
74
134
  Parameters
75
- path : Path to JSON file where data is saved
135
+ path : Path to outut file where data is saved
76
136
  '''
77
137
 
78
138
  with open(path, encoding='utf-8') as ofile:
79
- data = json.load(ofile)
139
+ if path.endswith('.json'):
140
+ data = json.load(ofile)
141
+ return data
142
+
143
+ if path.endswith('.yaml') or path.endswith('.yml'):
144
+ data = yaml.safe_load(ofile)
145
+ return data
146
+
147
+ raise NotImplementedError(f'Cannot deduce format from extension in path: {path}')
148
+ # --------------------------------
149
+ def dump_pickle(data, path : str) -> None:
150
+ '''
151
+ Saves data as pickle file
152
+
153
+ Parameters
154
+ data : dictionary, list, etc
155
+ path : Path to output file where to save it
156
+ '''
157
+ dir_name = os.path.dirname(path)
158
+ os.makedirs(dir_name, exist_ok=True)
159
+
160
+ with open(path, 'wb') as ofile:
161
+ pickle.dump(data, ofile)
162
+ # --------------------------------
163
+ def load_pickle(path : str) -> None:
164
+ '''
165
+ loads data file
166
+
167
+ Parameters
168
+ path : Path to output file where to save it
169
+ '''
170
+ with open(path, 'rb') as ofile:
171
+ data = pickle.load(ofile)
80
172
 
81
173
  return data
82
174
  # --------------------------------
175
+ @contextmanager
176
+ def silent_import():
177
+ '''
178
+ In charge of suppressing messages
179
+ of imported modules
180
+ '''
181
+ saved_stdout_fd = os.dup(1)
182
+ saved_stderr_fd = os.dup(2)
183
+
184
+ with open(os.devnull, 'w', encoding='utf-8') as devnull:
185
+ os.dup2(devnull.fileno(), 1)
186
+ os.dup2(devnull.fileno(), 2)
187
+ try:
188
+ yield
189
+ finally:
190
+ os.dup2(saved_stdout_fd, 1)
191
+ os.dup2(saved_stderr_fd, 2)
192
+ os.close(saved_stdout_fd)
193
+ os.close(saved_stderr_fd)
194
+ # --------------------------------
195
+ # Caching
196
+ # --------------------------------
197
+ def cache_data(obj : Any, hash_obj : Any) -> None:
198
+ '''
199
+ Will save data to a text file using a name from a hash
200
+
201
+ Parameters
202
+ -----------
203
+ obj : Object that can be saved to a text file, e.g. list, number, dictionary
204
+ hash_obj : Object that can be used to get hash e.g. immutable
205
+ '''
206
+ try:
207
+ json.dumps(obj)
208
+ except Exception as exc:
209
+ raise ValueError('Object is not JSON serializable') from exc
210
+
211
+ val = hashing.hash_object(hash_obj)
212
+ path = f'/tmp/dmu/cache/{val}.json'
213
+ gut.dump_json(obj, path)
214
+ # --------------------------------
215
+ def load_cached(hash_obj : Any, on_fail : Any = None) -> Any:
216
+ '''
217
+ Loads data corresponding to hash from hash_obj
218
+
219
+ Parameters
220
+ ---------------
221
+ hash_obj: Object used to calculate hash, which is in the file name
222
+ on_fail : Value returned if no data was found.
223
+ By default None, and it will just raise a FileNotFoundError
224
+ '''
225
+ val = hashing.hash_object(hash_obj)
226
+ path = f'/tmp/dmu/cache/{val}.json'
227
+ if os.path.isfile(path):
228
+ data = gut.load_json(path)
229
+ return data
230
+
231
+ if on_fail is not None:
232
+ return on_fail
233
+
234
+ raise FileNotFoundError(f'Cannot find cached data at: {path}')
235
+ # --------------------------------
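The additions above turn dump_json and load_json into extension-dispatched writers and readers (.json, .yaml, .yml) and introduce data-package loaders, pickle helpers, a silent_import context manager and a small /tmp based cache keyed on hash_object. The sketch below assumes only that /tmp is writable; the /tmp/dmu_example path and the hash keys are made up, while the two dmu_data files are taken from the RECORD listed at the top of this diff.

from dmu.generic import utilities as gut

# Extension-based dispatch: .json goes through json.dump, .yaml/.yml through a
# block-style YAML dumper, any other extension raises NotImplementedError
gut.dump_json({'x': [1, 2, 3]}, '/tmp/dmu_example/data.yaml')
data = gut.load_json('/tmp/dmu_example/data.yaml')

# Loaders for files shipped inside a data package; both files appear in the RECORD above
d_cfg = gut.load_data(package='dmu_data', fpath='tests/config.json')
cfg   = gut.load_conf(package='dmu_data', fpath='tests/config.yaml')

# cache_data stores a JSON-serializable object in /tmp/dmu/cache/<hash>.json, where
# <hash> comes from hashing.hash_object(hash_obj); load_cached reads it back and can
# return on_fail instead of raising FileNotFoundError when nothing was cached
gut.cache_data([1, 2, 3], hash_obj=('example_task', 1))
l_val = gut.load_cached(hash_obj=('example_task', 1))
l_dft = gut.load_cached(hash_obj=('missing', 'key'), on_fail=[])

# silent_import points the stdout and stderr file descriptors at /dev/null for the
# duration of the block, muting banners printed by noisy third-party imports
with gut.silent_import():
    pass  # noisy imports would go here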
dmu/logging/log_store.py CHANGED
@@ -1,10 +1,11 @@
1
1
  '''
2
2
  Module holding LogStore
3
3
  '''
4
-
5
4
  import logging
6
- from logging import Logger
5
+ import contextlib
6
+ from typing import Union
7
7
 
8
+ from logging import Logger
8
9
  import logzero
9
10
 
10
11
  #------------------------------------------------------------
@@ -40,6 +41,36 @@ class LogStore:
40
41
  backend = 'logging'
41
42
  #--------------------------
42
43
  @staticmethod
44
+ @contextlib.contextmanager
45
+ def level(name : str, lvl : int) -> None:
46
+ '''
47
+ Context manager used to set the logging level of a given logger
48
+
49
+ Parameters
50
+ ------------------
51
+ name : Name of logger
52
+ lvl : Integer representing logging level
53
+ '''
54
+ log = LogStore.get_logger(name=name)
55
+ if log is None:
56
+ raise ValueError(f'Cannot find logger {name}')
57
+
58
+ old_lvl = log.getEffectiveLevel()
59
+
60
+ LogStore.set_level(name, lvl)
61
+ try:
62
+ yield
63
+ finally:
64
+ LogStore.set_level(name, old_lvl)
65
+ #--------------------------
66
+ @staticmethod
67
+ def get_logger(name : str) -> Union[Logger,None]:
68
+ '''
69
+ Returns logger for a given name or None, if no logger found for that name
70
+ '''
71
+ return LogStore.d_logger.get(name)
72
+ #--------------------------
73
+ @staticmethod
43
74
  def add_logger(name : str, exists_ok : bool = False) -> Logger:
44
75
  '''
45
76
  Will use underlying logging library logzero/logging, etc to make logger
@@ -78,6 +109,7 @@ class LogStore:
78
109
  @staticmethod
79
110
  def _get_logging_logger(name : str, level : int) -> Logger:
80
111
  logger = logging.getLogger(name=name)
112
+ logger.propagate = False
81
113
 
82
114
  logger.setLevel(level)
83
115
 
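The log_store.py changes add a LogStore.level context manager that temporarily overrides a logger's level and restores it afterwards, a LogStore.get_logger lookup, and propagate=False for loggers created with the plain logging backend. A short sketch follows; the logger name is made up for the example.

import logging

from dmu.logging.log_store import LogStore

log = LogStore.add_logger('dmu_example:demo')   # hypothetical logger name

# Messages below WARNING are dropped inside the block; the previous effective level
# is restored on exit, even if an exception is raised
with LogStore.level('dmu_example:demo', logging.WARNING):
    log.info('suppressed inside the block')

log.info('back at the original level')

# get_logger returns the registered Logger instance, or None for unknown names
assert LogStore.get_logger('never_registered') is None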
dmu/logging/messages.py ADDED
@@ -0,0 +1,96 @@
1
+ '''
2
+ Module containing code meant to deal with logging of
3
+ third party tools
4
+ '''
5
+ import os
6
+ import sys
7
+ import time
8
+ import threading
9
+ from io import StringIO
10
+ from contextlib import contextmanager
11
+ from dmu.logging.log_store import LogStore
12
+
13
+ log = LogStore.add_logger('dmu:logging:messages')
14
+ # --------------------------------
15
+ class FilteredStderr:
16
+ '''
17
+ This class is meant to be used to filter the messages
18
+ in the error stream by substrings
19
+ '''
20
+ # --------------------------------
21
+ def __init__(
22
+ self,
23
+ banned_substrings : list[str],
24
+ capture_stream : StringIO):
25
+ '''
26
+ Parameters
27
+ -------------
28
+ banned_substrings : List of substrings that, if found in error message, will drop error
29
+ capture_stream : Used to store error stream filtered messages, expected to be sys.__stderr__
30
+ '''
31
+ self._banned = banned_substrings
32
+ self._capture_stream = capture_stream
33
+ # --------------------------------
34
+ def write(self, message : str):
35
+ '''
36
+ Should allow filtering error messages
37
+ '''
38
+ if not any(bad in message for bad in self._banned):
39
+ # This will make it to the error messages
40
+ self._capture_stream.write(message)
41
+ # --------------------------------
42
+ def flush(self):
43
+ '''
44
+ Should override the error stream's flush method
45
+ '''
46
+ self._capture_stream.flush()
47
+ # --------------------------------
48
+ @contextmanager
49
+ def filter_stderr(
50
+ banned_substrings : list[str],
51
+ capture_stream : StringIO|None=None):
52
+ '''
53
+ This contextmanager will suppress error messages
54
+
55
+ Parameters
56
+ -----------------
57
+ banned_substrings : List of substrings that need to be found in error messages
58
+ in order for them to be suppressed
59
+ capture_stream : Buffer needed to run tests, not needed for normal use
60
+ '''
61
+ if capture_stream is None:
62
+ capture_stream = sys.__stderr__
63
+
64
+ read_fd, write_fd = os.pipe()
65
+ saved_fd = os.dup(2)
66
+
67
+ os.dup2(write_fd, 2)
68
+ os.close(write_fd)
69
+
70
+ filtered = FilteredStderr(banned_substrings, capture_stream)
71
+ reader_finished = threading.Event()
72
+
73
+ def reader():
74
+ try:
75
+ with os.fdopen(read_fd, 'r', buffering=1) as pipe:
76
+ while True:
77
+ line = pipe.readline()
78
+ if not line:
79
+ break
80
+ filtered.write(line)
81
+ filtered.flush()
82
+ finally:
83
+ reader_finished.set()
84
+
85
+ thread = threading.Thread(target=reader, daemon=True)
86
+ thread.start()
87
+
88
+ try:
89
+ yield
90
+ finally:
91
+ os.dup2(saved_fd, 2)
92
+ os.close(saved_fd)
93
+
94
+ time.sleep(0.1)
95
+ reader_finished.wait(timeout=1.0)
96
+ # --------------------------------
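dmu/logging/messages.py is a new module: filter_stderr redirects file descriptor 2 into a pipe, a reader thread scans it line by line, and lines containing any of the banned substrings are dropped while everything else is forwarded, by default to sys.__stderr__ or to the capture_stream used in tests. A minimal sketch with made-up messages and substrings:

import os

from dmu.logging.messages import filter_stderr

# Anything written to fd 2 inside the block is filtered line by line
with filter_stderr(banned_substrings=['harmless warning']):
    os.write(2, b'some harmless warning from a dependency\n')  # dropped
    os.write(2, b'a real error message\n')                     # forwarded to stderr

Because the filtering happens at the file-descriptor level, it also catches messages emitted by C or C++ libraries that bypass Python's sys.stderr object.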
dmu/ml/cv_classifier.py CHANGED
@@ -37,17 +37,17 @@ class CVClassifier(GradientBoostingClassifier):
37
37
 
38
38
  self._s_hash = set()
39
39
  self._data = {}
40
- self._l_ft_name = None
40
+ self._l_ft_name : list[str]
41
41
  # ----------------------------------
42
42
  @property
43
- def features(self):
43
+ def features(self) -> list[str]:
44
44
  '''
45
45
  Returns list of feature names used in training dataset
46
46
  '''
47
47
  return self._l_ft_name
48
48
  # ----------------------------------
49
49
  @property
50
- def hashes(self):
50
+ def hashes(self) -> set[str]:
51
51
  '''
52
52
  Will return set with hashes of training data
53
53
  '''
dmu/ml/cv_diagnostics.py CHANGED
@@ -186,6 +186,9 @@ class CVDiagnostics:
186
186
  raise NotImplementedError(f'Correlation coefficient {method} not implemented')
187
187
  # -------------------------
188
188
  def _plot_cutflow(self) -> None:
189
+ '''
190
+ Plot the 'mass' column for different values of working point
191
+ '''
189
192
  if 'overlay' not in self._cfg['correlations']['target']:
190
193
  log.debug('Not plotting cutflow of target distribution')
191
194
  return
dmu/ml/cv_performance.py ADDED
@@ -0,0 +1,58 @@
1
+ '''
2
+ This module contains the class CVPerformance
3
+ '''
4
+ # pylint: disable=too-many-positional-arguments, too-many-arguments
5
+
6
+ from ROOT import RDataFrame
7
+ from dmu.ml.cv_classifier import CVClassifier
8
+ from dmu.ml.cv_predict import CVPredict
9
+ from dmu.ml.train_mva import TrainMva
10
+ from dmu.logging.log_store import LogStore
11
+
12
+ log=LogStore.add_logger('dmu:ml:cv_performance')
13
+ # -----------------------------------------------------
14
+ class CVPerformance:
15
+ '''
16
+ This class is meant to:
17
+
18
+ - Compare the classifier performance, through the ROC curve, of a model, for a given background and signal sample
19
+ '''
20
+ # ---------------------------
21
+ def plot_roc(
22
+ self,
23
+ name : str,
24
+ color : str,
25
+ sig : RDataFrame,
26
+ bkg : RDataFrame,
27
+ model : list[CVClassifier] ) -> float:
28
+ '''
29
+ Method in charge of picking up model and data and plotting ROC curve
30
+
31
+ Parameters
32
+ --------------------------
33
+ name : Label of combination, used for plots
34
+ sig : ROOT dataframe storing signal samples
35
+ bkg : ROOT dataframe storing background samples
36
+ model: List of instances of the CVClassifier
37
+
38
+ Returns
39
+ --------------------------
40
+ Area under the ROC curve
41
+ '''
42
+ log.info(f'Loading {name}')
43
+
44
+ cvp_sig = CVPredict(models=model, rdf=sig)
45
+ arr_sig = cvp_sig.predict()
46
+
47
+ cvp_bkg = CVPredict(models=model, rdf=bkg)
48
+ arr_bkg = cvp_bkg.predict()
49
+
50
+ _, _, auc = TrainMva.plot_roc_from_prob(
51
+ arr_sig_prb=arr_sig,
52
+ arr_bkg_prb=arr_bkg,
53
+ kind = name,
54
+ color = color, # This should allow the function to pick kind
55
+ ifold = 999) # for the label
56
+
57
+ return auc
58
+ # -----------------------------------------------------
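dmu/ml/cv_performance.py adds a small wrapper that scores signal and background ROOT dataframes with a set of cross-validation classifiers and draws the corresponding ROC curve through TrainMva.plot_roc_from_prob. The sketch below is illustrative only: the ROOT files, the tree name and the pickled CVClassifier folds are placeholders, and how the trained folds are actually persisted is not shown in this diff.

import glob
import pickle

from ROOT import RDataFrame
from dmu.ml.cv_performance import CVPerformance

# Hypothetical inputs: trees holding the training features and one classifier per fold
rdf_sig = RDataFrame('tree', 'signal.root')
rdf_bkg = RDataFrame('tree', 'background.root')

l_model = []
for path in sorted(glob.glob('models/model_*.pkl')):
    with open(path, 'rb') as ifile:
        l_model.append(pickle.load(ifile))

cvp = CVPerformance()
auc = cvp.plot_roc(
    name ='baseline',
    color='blue',
    sig  =rdf_sig,
    bkg  =rdf_bkg,
    model=l_model)

print(f'AUC: {auc:.3f}')

plot_roc returns the area under the curve; judging from the snippet above it delegates the drawing to TrainMva.plot_roc_from_prob, so saving or displaying the resulting figure is presumably left to the caller.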