wums 0.1.5__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wums/fitutilsjax.py ADDED
@@ -0,0 +1,86 @@
1
+ import numpy as np
2
+ import scipy
3
+ import jax
4
+ import jax.numpy as jnp
5
+
6
def chisqloss(xvals, yvals, yvariances, func, parms):
    """Return the chi-square sum((func(xvals, parms) - yvals)**2 / yvariances).

    ``func(xvals, parms)`` gives the model prediction that is compared to the
    observed ``yvals`` with per-point variances ``yvariances``.
    """
    residuals = func(xvals, parms) - yvals
    return jnp.sum(residuals**2 / yvariances)


# Jitted ``(loss, gradient w.r.t. parms)`` of ``chisqloss``.  ``func`` is passed
# as a static argument (positional index 3), so it must be hashable.
chisqloss_grad = jax.jit(
    jax.value_and_grad(chisqloss, argnums = 4),
    static_argnums = 3,
)
10
+
11
def _chisqloss_grad_hess(xvals, yvals, yvariances, func, parms):
    """Return ``(loss, gradient, hessian)`` of ``chisqloss`` with respect to ``parms``.

    The gradient is taken in reverse mode (``jax.grad``) and the Hessian as the
    forward-mode Jacobian (``jax.jacfwd``) of that gradient function.
    """
    def _loss_at(p):
        return chisqloss(xvals, yvals, yvariances, func, p)

    grad_at = jax.grad(_loss_at)
    hess_at = jax.jacfwd(grad_at)

    return _loss_at(parms), grad_at(parms), hess_at(parms)


# Jitted version; ``func`` (positional index 3) is static and must be hashable.
chisqloss_grad_hess = jax.jit(_chisqloss_grad_hess, static_argnums = 3)
25
+
26
def _chisqloss_hessp(xvals, yvals, yvariances, func, parms, p):
    """Return the Hessian-vector product ``H(parms) @ p`` of ``chisqloss``.

    Computed forward-over-reverse: the JVP of the gradient function along the
    tangent ``p``, so the full Hessian is never built explicitly.
    """
    grad_at = jax.grad(lambda q: chisqloss(xvals, yvals, yvariances, func, q))
    _, hvp = jax.jvp(grad_at, (parms,), (p,))
    return hvp


# Jitted version; ``func`` (positional index 3) is static and must be hashable.
chisqloss_hessp = jax.jit(_chisqloss_hessp, static_argnums = 3)
35
+
36
def fit_hist_jax(hist, func, parmvals, max_iter = 5, edmtol = 1e-5):
    """Chi-square fit of ``func`` to the bin contents of ``hist`` using JAX derivatives.

    Each round minimizes the chi-square with ``scipy.optimize.minimize``
    (method ``"trust-krylov"``), using the jitted loss/gradient and
    Hessian-vector product of this module. The full Hessian is then evaluated
    at the minimum to obtain the covariance and the EDM. If the fit is not
    converged (EDM >= ``edmtol`` or Hessian not positive definite), the next
    round restarts from the previous best-fit point, up to ``max_iter`` rounds.

    Parameters
    ----------
    hist : histogram-like
        Object providing ``axes.centers``, ``values()`` and ``variances()``
        (presumably a ``hist.Hist``; assumes variances are available — TODO
        confirm for storages without variances).
    func : callable
        ``func(xvals, parms)`` returning the predicted bin contents, where
        ``xvals`` is the list of axis bin-center arrays. It is a static
        argument of ``jax.jit`` and must therefore be hashable.
    parmvals : array-like
        Initial parameter values.
    max_iter : int
        Maximum number of minimization rounds (must be >= 1).
    edmtol : float
        Convergence threshold on the EDM.

    Returns
    -------
    dict
        ``x`` (best-fit parameters), ``cov`` (2 * inverse Hessian),
        ``status`` (0 if EDM < ``edmtol``, else 1), ``covstatus`` (0 if the
        smallest Hessian eigenvalue is > 0, else 1), ``hess_eigvals``,
        ``edmval`` and ``chisqval`` (loss at the best-fit point).

    Raises
    ------
    ValueError
        If ``max_iter`` < 1.
    numpy.linalg.LinAlgError
        If the Hessian at the minimum is singular.
    """
    # On older SciPy releases a bare ``import scipy`` does not make the
    # ``scipy.optimize`` submodule available, so import it explicitly.
    import scipy.optimize

    if max_iter < 1:
        # with no round the fit results below would never be defined
        raise ValueError(f"max_iter must be at least 1, got {max_iter}")

    xvals = [jnp.array(center) for center in hist.axes.centers]
    yvals = jnp.array(hist.values())
    yvariances = jnp.array(hist.variances())

    def scipy_loss(parmvals):
        # objective for jac=True: returns (loss, gradient) as numpy float64
        parms = jnp.array(parmvals)
        loss, grad = chisqloss_grad(xvals, yvals, yvariances, func, parms)
        return np.asarray(loss).item(), np.asarray(grad, dtype=np.float64)

    def scipy_hessp(parmvals, p):
        # Hessian-vector product H(parmvals) @ p used by trust-krylov
        parms = jnp.array(parmvals)
        tangent = jnp.array(p)
        hessp = chisqloss_hessp(xvals, yvals, yvariances, func, parms, tangent)
        return np.asarray(hessp, dtype=np.float64)

    x0 = parmvals
    for _ in range(max_iter):
        res = scipy.optimize.minimize(
            scipy_loss, x0, method = "trust-krylov", jac = True, hessp = scipy_hessp
        )

        parms = jnp.array(res.x)
        loss, grad, hess = chisqloss_grad_hess(xvals, yvals, yvariances, func, parms)
        loss, grad, hess = np.asarray(loss).item(), np.asarray(grad), np.asarray(hess)

        # eigvalsh returns ascending eigenvalues: eigvals[0] > 0 <=> positive definite
        eigvals = np.linalg.eigvalsh(hess)
        # covariance for a chi-square loss: 2 * H^-1
        cov = 2. * np.linalg.inv(hess)

        # estimated distance to minimum: 0.5 * g^T cov g
        edmval = 0.5 * grad @ cov @ grad

        # a NaN EDM compares False and is therefore treated as not converged
        edm_ok = edmval < edmtol
        hess_posdef = eigvals[0] > 0.

        if edm_ok and hess_posdef:
            break

        # Not converged: continue from the current best point. Restarting from
        # the initial values would deterministically repeat the same fit.
        x0 = res.x

    return {
        "x": res.x,
        "cov": cov,
        "status": 0 if edm_ok else 1,
        "covstatus": 0 if hess_posdef else 1,
        "hess_eigvals": eigvals,
        "edmval": edmval,
        "chisqval": loss,
    }
wums/logging.py CHANGED
@@ -42,19 +42,19 @@ def set_logging_level(log, verbosity):
42
42
  log.setLevel(logging_verboseLevel[max(0, min(4, verbosity))])
43
43
 
44
44
 
45
- def setup_logger(basefile, verbosity=3, no_colors=False, initName="wremnants"):
45
+ def setup_logger(basefile, verbosity=3, no_colors=False, initName="wums"):
46
46
 
47
47
  setup_func = setup_base_logger if no_colors else setup_color_logger
48
48
  logger = setup_func(os.path.basename(basefile), verbosity, initName)
49
49
  # count messages of base logger
50
- base_logger = logging.getLogger("wremnants")
50
+ base_logger = logging.getLogger("wums")
51
51
  add_logging_counter(base_logger)
52
52
  # stop total time
53
53
  add_time_info("Total time")
54
54
  return logger
55
55
 
56
56
 
57
- def setup_color_logger(name, verbosity, initName="wremnants"):
57
+ def setup_color_logger(name, verbosity, initName="wums"):
58
58
  base_logger = logging.getLogger(initName)
59
59
  # set console handler
60
60
  ch = logging.StreamHandler()
@@ -65,14 +65,14 @@ def setup_color_logger(name, verbosity, initName="wremnants"):
65
65
  return base_logger.getChild(name)
66
66
 
67
67
 
68
- def setup_base_logger(name, verbosity, initName="wremnants"):
68
+ def setup_base_logger(name, verbosity, initName="wums"):
69
69
  logging.basicConfig(format="%(levelname)s: %(message)s")
70
70
  base_logger = logging.getLogger(initName)
71
71
  set_logging_level(base_logger, verbosity)
72
72
  return base_logger.getChild(name)
73
73
 
74
74
 
75
- def child_logger(name, initName="wremnants"):
75
+ def child_logger(name, initName="wums"):
76
76
  # count messages of child logger
77
77
  logger = logging.getLogger(initName).getChild(name)
78
78
  add_logging_counter(logger)
@@ -110,7 +110,7 @@ def print_logging_count(logger, verbosity=logging.WARNING):
110
110
  )
111
111
 
112
112
 
113
- def add_time_info(tag, logger=logging.getLogger("wremnants")):
113
+ def add_time_info(tag, logger=logging.getLogger("wums")):
114
114
  if not hasattr(logger, "times"):
115
115
  logger.times = {}
116
116
  logger.times[tag] = time.time()
@@ -125,7 +125,7 @@ def print_time_info(logger):
125
125
 
126
126
 
127
127
  def summary(verbosity=logging.WARNING, extended=True):
128
- base_logger = logging.getLogger("wremnants")
128
+ base_logger = logging.getLogger("wums")
129
129
 
130
130
  base_logger.info(f"--------------------------------------")
131
131
  base_logger.info(f"----------- logger summary -----------")
@@ -141,5 +141,5 @@ def summary(verbosity=logging.WARNING, extended=True):
141
141
  # Iterate through all child loggers and print their names, levels, and counts
142
142
  all_loggers = logging.Logger.manager.loggerDict
143
143
  for logger_name, logger_obj in all_loggers.items():
144
- if logger_name.startswith("wremnants."):
144
+ if logger_name.startswith("wums."):
145
145
  print_logging_count(logger_obj, verbosity=verbosity)
wums/plot_tools.py CHANGED
@@ -601,14 +601,48 @@ def wrap_text(
601
601
  )
602
602
 
603
603
 
604
- def add_cms_decor(
605
- ax, label=None, lumi=None, loc=2, data=True, text_size=None, no_energy=False
604
+ def add_cms_decor(ax, *args, **kwargs):
605
+ add_decor(ax, "CMS", *args, **kwargs)
606
+
607
+
608
+ def add_decor(
609
+ ax, title, label=None, lumi=None, loc=2, data=True, text_size=None, no_energy=False
606
610
  ):
607
611
  text_size = get_textsize(ax, text_size)
612
+
613
+ if title in ["CMS", "ATLAS", "LHCb", "ALICE"]:
614
+ module = getattr(hep, title.lower())
615
+ make_text = module.text
616
+ make_label = module.label
617
+ else:
618
+ def make_text(text=None, **kwargs):
619
+ for key, value in dict(hep.rcParams.text._get_kwargs()).items():
620
+ if (
621
+ value is not None
622
+ and key not in kwargs
623
+ and key in inspect.getfullargspec(label_base.exp_text).kwonlyargs
624
+ ):
625
+ kwargs.setdefault(key, value)
626
+ kwargs.setdefault("italic", (False, True, False))
627
+ kwargs.setdefault("exp", title)
628
+ return hep.label.exp_text(text=text, **kwargs)
629
+
630
+ def make_label(**kwargs):
631
+ for key, value in dict(hep.rcParams.text._get_kwargs()).items():
632
+ if (
633
+ value is not None
634
+ and key not in kwargs
635
+ and key in inspect.getfullargspec(label_base.exp_text).kwonlyargs
636
+ ):
637
+ kwargs.setdefault(key, value)
638
+ kwargs.setdefault("italic", (False, True, False))
639
+ kwargs.setdefault("exp", title)
640
+ return hep.label.exp_label(**kwargs)
641
+
608
642
  if no_energy:
609
- hep.cms.text(ax=ax, text=label, loc=loc, fontsize=text_size)
643
+ make_text(ax=ax, text=label, loc=loc, fontsize=text_size)
610
644
  else:
611
- hep.cms.label(
645
+ make_label(
612
646
  ax=ax,
613
647
  lumi=lumi,
614
648
  lumi_format="{0:.3g}",
@@ -617,7 +651,32 @@ def add_cms_decor(
617
651
  data=data,
618
652
  loc=loc,
619
653
  )
620
-
654
+
655
+ # else:
656
+ # if loc==0:
657
+ # # above frame
658
+ # x = 0.0
659
+ # y = 1.0
660
+ # elif loc==1:
661
+ # # in frame
662
+ # x = 0.05
663
+ # y = 0.88
664
+ # elif loc==2:
665
+ # # upper left, label below title
666
+ # x = 0.05
667
+ # y = 0.88
668
+ # elif loc==2:
669
+ # #
670
+ # ax.text(
671
+ # x,
672
+ # y,
673
+ # args.title,
674
+ # transform=ax1.transAxes,
675
+ # fontweight="bold",
676
+ # fontsize=1.2 * text_size,
677
+ # )
678
+ # if label is not None:
679
+ # ax.text(0.05, 0.80, label, transform=ax.transAxes, fontstyle="italic")
621
680
 
622
681
  def makeStackPlotWithRatio(
623
682
  histInfo,
@@ -826,11 +885,11 @@ def makeStackPlotWithRatio(
826
885
  for x in (data_hist.sum(), hh.sumHists(stack).sum())
827
886
  ]
828
887
  varis = [
829
- x.variance if hasattr(x, "variance") else x ** 0.5
888
+ x.variance if hasattr(x, "variance") else x**0.5
830
889
  for x in (data_hist.sum(), hh.sumHists(stack).sum())
831
890
  ]
832
891
  scale = vals[0] / vals[1]
833
- unc = scale * (varis[0] / vals[0] ** 2 + varis[1] / vals[1] ** 2) ** 0.5
892
+ unc = scale * (varis[0] / vals[0] ** 2 + varis[1] / vals[1] ** 2)**0.5
834
893
  ndigits = -math.floor(math.log10(abs(unc))) + 1
835
894
  logger.info(
836
895
  f"Rescaling all processes by {round(scale,ndigits)} +/- {round(unc,ndigits)} to match data norm"
wums/tfutils.py ADDED
@@ -0,0 +1,81 @@
1
+ import tensorflow as tf
2
+
3
def function_to_tflite(funcs, input_signatures, func_names="", verbose=True):
    """Convert one or more functions to a TFLite model.

    Python dynamic execution (``exec``) is used to generate an exporting
    ``tf.Module`` whose method arguments are named ``input_<func>_<arg>`` and
    whose outputs are returned in a dict keyed ``output_<func>_<out>``. The
    zero-padded indices make alphabetical order equal positional order, since
    this is apparently the only way to prevent TFLite from scrambling inputs
    and outputs.

    Parameters
    ----------
    funcs : callable or list of callables
        Function(s) to export. Each is called with its input tensors and may
        return a single output or a tuple of outputs.
    input_signatures : signature or list of signatures
        Input signature passed to ``tf.function`` for each function (a single
        signature when ``funcs`` is not a list).
    func_names : str or list of str
        Exported method name(s). An empty string (or ``None``) means "use the
        function's ``__name__``". A single string given with a list of
        functions is applied to every function.
    verbose : bool
        If True (default), load the converted model in ``tf.lite.Interpreter``
        and print its input, output and signature details.

    Returns
    -------
    The serialized model returned by ``TFLiteConverter.convert()``.

    Raises
    ------
    ValueError
        If no function is given, if the number of signatures or names does not
        match the number of functions, or if the exported names are not unique,
        valid Python identifiers.
    """
    import keyword

    if not isinstance(funcs, list):
        funcs = [funcs]
        input_signatures = [input_signatures]
        func_names = [func_names]
    elif func_names is None or isinstance(func_names, str):
        # a single (possibly empty) name for a list of functions; previously
        # this indexed into the string and raised IndexError
        func_names = [func_names] * len(funcs)

    if not funcs:
        raise ValueError("At least one function is required")
    if len(input_signatures) != len(funcs):
        raise ValueError(
            f"Got {len(input_signatures)} input signatures for {len(funcs)} functions"
        )
    if len(func_names) != len(funcs):
        raise ValueError(
            f"Got {len(func_names)} function names for {len(funcs)} functions"
        )

    # empty (or None) name -> fall back to the function's own __name__
    func_names = [name if name else func.__name__ for func, name in zip(funcs, func_names)]

    # The names are spliced into generated source code below, so they must be
    # unique, valid Python identifiers (e.g. a lambda's "<lambda>" is not).
    for name in func_names:
        if not isinstance(name, str) or not name.isidentifier() or keyword.iskeyword(name):
            raise ValueError(f"Function name {name!r} is not a valid Python identifier")
    if len(set(func_names)) != len(func_names):
        raise ValueError(f"Function names must be unique, got {func_names}")

    def wrapped_func(iif, *args):
        outputs = funcs[iif](*args)
        if not isinstance(outputs, tuple):
            outputs = (outputs,)
        # zero-padded indices: alphabetical key order == positional output order
        return {f"output_{iif:05d}_{i:05d}": output for i, output in enumerate(outputs)}

    # per-function argument lists "input_<func>_<arg>, ..." (alphabetical == positional)
    arg_strings = [
        ", ".join(f"input_{iif:05d}_{i:05d}" for i in range(len(input_signature)))
        for iif, input_signature in enumerate(input_signatures)
    ]

    source_lines = [
        "def make_module(wrapped_func, input_signatures):",
        "    class Export_Module(tf.Module):",
    ]
    for i, (name, args) in enumerate(zip(func_names, arg_strings)):
        source_lines += [
            f"        @tf.function(input_signature = input_signatures[{i}])",
            f"        def {name}(self, {args}):",
            f"            return wrapped_func({i}, {args})",
        ]
    source_lines.append("    return Export_Module")

    ldict = {}
    # the module globals provide ``tf`` to the generated code
    exec("\n".join(source_lines) + "\n", globals(), ldict)

    make_module = ldict["make_module"]
    Export_Module = make_module(wrapped_func, input_signatures)

    module = Export_Module()
    concrete_functions = [getattr(module, name).get_concrete_function() for name in func_names]
    converter = tf.lite.TFLiteConverter.from_concrete_functions(concrete_functions, module)

    # enable TensorFlow ops and DISABLE builtin TFLite ops since these apparently slow things down
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.SELECT_TF_OPS  # enable TensorFlow ops.
    ]

    converter._experimental_allow_all_select_tf_ops = True

    tflite_model = converter.convert()

    if verbose:
        # load the converted model once and report its interface
        test_interp = tf.lite.Interpreter(model_content = tflite_model)
        print(test_interp.get_input_details())
        print(test_interp.get_output_details())
        print(test_interp.get_signature_list())

    return tflite_model
69
+
70
+
71
+
72
def function_to_saved_model(func, input_signature, output):
    """Export ``func`` as a TensorFlow SavedModel written to ``output``.

    ``func`` is exposed as the ``__call__`` method of a ``tf.Module`` and traced
    by ``tf.function`` with the given ``input_signature``.
    """
    class Export_Module(tf.Module):
        @tf.function(input_signature = input_signature)
        def __call__(self, *args):
            return func(*args)

    tf.saved_model.save(Export_Module(), output)
@@ -0,0 +1,54 @@
1
+ Metadata-Version: 2.2
2
+ Name: wums
3
+ Version: 0.1.7
4
+ Summary: .
5
+ Author-email: David Walter <david.walter@cern.ch>, Josh Bendavid <josh.bendavid@cern.ch>, Kenneth Long <kenneth.long@cern.ch>, Jan Eysermans <jan.eysermans@cern.ch>
6
+ License: MIT
7
+ Project-URL: Homepage, https://github.com/WMass/wums
8
+ Classifier: Programming Language :: Python :: 3
9
+ Classifier: Programming Language :: Python :: 3.8
10
+ Classifier: License :: OSI Approved :: MIT License
11
+ Classifier: Operating System :: OS Independent
12
+ Requires-Python: >=3.8
13
+ Description-Content-Type: text/markdown
14
+ Requires-Dist: hist
15
+ Requires-Dist: numpy
16
+ Provides-Extra: plotting
17
+ Requires-Dist: matplotlib; extra == "plotting"
18
+ Requires-Dist: mplhep; extra == "plotting"
19
+ Provides-Extra: fitting
20
+ Requires-Dist: tensorflow; extra == "fitting"
21
+ Requires-Dist: jax; extra == "fitting"
22
+ Requires-Dist: scipy; extra == "fitting"
23
+ Provides-Extra: pickling
24
+ Requires-Dist: boost_histogram; extra == "pickling"
25
+ Requires-Dist: h5py; extra == "pickling"
26
+ Requires-Dist: hdf5plugin; extra == "pickling"
27
+ Requires-Dist: lz4; extra == "pickling"
28
+ Provides-Extra: all
29
+ Requires-Dist: wums[plotting]; extra == "all"
30
+ Requires-Dist: wums[fitting]; extra == "all"
31
+ Requires-Dist: wums[pickling]; extra == "all"
32
+
33
+ # WUMS: Wremnants Utilities, Modules, and other Stuff
34
+
35
+ As the name suggests, this is a collection of different things, all Python-based:
36
+ - Fitting with tensorflow or jax
37
+ - Custom pickling h5py objects
38
+ - Plotting functionality
39
+
40
+ ## Install
41
+
42
+ The `wums` package can be pip installed with minimal dependencies:
43
+ ```bash
44
+ pip install wums
45
+ ```
46
+ Optional dependency groups (extras) `plotting`, `fitting`, and `pickling` can be added to use the corresponding modules.
47
+ For example, one can install with
48
+ ```bash
49
+ pip install wums[plotting,fitting]
50
+ ```
51
+ Or all dependencies with
52
+ ```bash
53
+ pip install wums[all]
54
+ ```
@@ -0,0 +1,14 @@
1
+ wums/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ wums/boostHistHelpers.py,sha256=mgdPXAgmxriqoOhrhMctyZcfwEOPfV07V27CvGt2sk8,39260
3
+ wums/fitutils.py,sha256=sPCMJqZGdXvDfc8OxjOB-Bpf45GWHKxmKkDV3SlMUQs,38297
4
+ wums/fitutilsjax.py,sha256=HE1AcIZmI6N_xIHo8OHCPaYkHSnND_B-vI4Gl3vaUmA,2659
5
+ wums/ioutils.py,sha256=ziyfQQ8CB3Ir2BJKJU3_a7YMF-Jd2nGXKoMQoJ2T8fo,12334
6
+ wums/logging.py,sha256=L4514Xyq7L1z77Tkh8KE2HX88ZZ06o6SSRyQo96DbC0,4494
7
+ wums/output_tools.py,sha256=SHcZqXAdqL9AkA57UF0b-R-U4u7rzDgL8Def4E-ulW0,6713
8
+ wums/plot_tools.py,sha256=4iPx9Nr9y8c3p4ovy8XOS-xU_w11OyQEjISKkygxqcA,55918
9
+ wums/tfutils.py,sha256=9efkkvxH7VtwJN2yBS6_-P9dLKs3CXdxMFdrEBNsna8,2892
10
+ wums/Templates/index.php,sha256=9EYmfc0ltMqr5oOdA4_BVIHdSbef5aA0ORoRZBEADVw,4348
11
+ wums-0.1.7.dist-info/METADATA,sha256=GrQyVuatMvHdallbstH7YdiACEMLIo5isHyugfFawW8,1784
12
+ wums-0.1.7.dist-info/WHEEL,sha256=jB7zZ3N9hIM9adW7qlTAyycLYW9npaWKLRzaoVcLKcM,91
13
+ wums-0.1.7.dist-info/top_level.txt,sha256=DCE1TVg7ySraosR3kYZkLIZ2w1Pwk2pVTdkqx6E-yRY,5
14
+ wums-0.1.7.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.8.0)
2
+ Generator: setuptools (75.8.2)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -1,29 +0,0 @@
1
- Metadata-Version: 2.2
2
- Name: wums
3
- Version: 0.1.5
4
- Summary: .
5
- Author-email: David Walter <david.walter@cern.ch>, Josh Bendavid <josh.bendavid@cern.ch>, Kenneth Long <kenneth.long@cern.ch>
6
- License: MIT
7
- Project-URL: Homepage, https://github.com/WMass/wums
8
- Classifier: Programming Language :: Python :: 3
9
- Classifier: Programming Language :: Python :: 3.8
10
- Classifier: License :: OSI Approved :: MIT License
11
- Classifier: Operating System :: OS Independent
12
- Requires-Python: >=3.8
13
- Description-Content-Type: text/markdown
14
- Requires-Dist: boost_histogram
15
- Requires-Dist: h5py
16
- Requires-Dist: hdf5plugin
17
- Requires-Dist: hist
18
- Requires-Dist: lz4
19
- Requires-Dist: matplotlib
20
- Requires-Dist: mplhep
21
- Requires-Dist: numpy
22
- Requires-Dist: uproot
23
-
24
- # WUMS: Wremnants Utilities, Modules, and other Stuff
25
-
26
- The `wums` package can be pip installed:
27
- ```bash
28
- pip install wums
29
- ```
@@ -1,11 +0,0 @@
1
- wums/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- wums/boostHistHelpers.py,sha256=F4SwQEVjNObFscfs0qrJEyOHYNKqUCmusW8HIF1o-0c,38993
3
- wums/ioutils.py,sha256=ziyfQQ8CB3Ir2BJKJU3_a7YMF-Jd2nGXKoMQoJ2T8fo,12334
4
- wums/logging.py,sha256=zNnLVJUwG3HMvr9NeXmiheX07VmsnSt8cQ6R4q4XBk4,4534
5
- wums/output_tools.py,sha256=SHcZqXAdqL9AkA57UF0b-R-U4u7rzDgL8Def4E-ulW0,6713
6
- wums/plot_tools.py,sha256=a-sf0gy2xNbcHUcUBUmjkY5lq2RBm6bia0CyyPP1UDI,53920
7
- wums/Templates/index.php,sha256=9EYmfc0ltMqr5oOdA4_BVIHdSbef5aA0ORoRZBEADVw,4348
8
- wums-0.1.5.dist-info/METADATA,sha256=S7OCqqlGWHYGf50hrqw8fx7NlWliCFv-rVjFW9ggpEc,843
9
- wums-0.1.5.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
10
- wums-0.1.5.dist-info/top_level.txt,sha256=DCE1TVg7ySraosR3kYZkLIZ2w1Pwk2pVTdkqx6E-yRY,5
11
- wums-0.1.5.dist-info/RECORD,,