wums 0.1.6__py3-none-any.whl → 0.1.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wums/fitutilsjax.py ADDED
@@ -0,0 +1,86 @@
1
+ import numpy as np
2
+ import scipy
3
+ import jax
4
+ import jax.numpy as jnp
5
+
6
def chisqloss(xvals, yvals, yvariances, func, parms):
    """Chi-square of the model ``func(xvals, parms)`` against ``yvals``.

    Returns ``sum((func(xvals, parms) - yvals)**2 / yvariances)`` computed with ``jnp.sum``.
    """
    residuals = func(xvals, parms) - yvals
    weighted_sq = residuals**2 / yvariances
    return jnp.sum(weighted_sq)


# jitted ``(loss, d loss / d parms)``; ``func`` (positional argument 3) is a static argument,
# so it must be hashable
chisqloss_grad = jax.jit(jax.value_and_grad(chisqloss, argnums=4), static_argnums=3)
10
+
11
def _chisqloss_grad_hess(xvals, yvals, yvariances, func, parms):
    """Evaluate the chi-square loss, its gradient and its Hessian with respect to ``parms``.

    The Hessian is the forward-mode Jacobian (``jax.jacfwd``) of the reverse-mode gradient
    (``jax.grad``). Returns the tuple ``(loss, grad, hess)``.
    """
    def objective(p):
        return chisqloss(xvals, yvals, yvariances, func, p)

    gradient_fn = jax.grad(objective)
    hessian_fn = jax.jacfwd(gradient_fn)
    return objective(parms), gradient_fn(parms), hessian_fn(parms)


# jitted entry point; ``func`` (positional argument 3) is static and must therefore be hashable
chisqloss_grad_hess = jax.jit(_chisqloss_grad_hess, static_argnums=3)
25
+
26
def _chisqloss_hessp(xvals, yvals, yvariances, func, parms, p):
    """Hessian-vector product of the chi-square loss at ``parms`` along direction ``p``.

    Computed as the JVP (``jax.jvp``) of the reverse-mode gradient, i.e. forward-over-reverse
    differentiation, so only the tangent output of the JVP is returned.
    """
    grad_fn = jax.grad(lambda q: chisqloss(xvals, yvals, yvariances, func, q))
    _, hvp = jax.jvp(grad_fn, (parms,), (p,))
    return hvp


# jitted entry point; ``func`` (positional argument 3) is static and must therefore be hashable
chisqloss_hessp = jax.jit(_chisqloss_hessp, static_argnums=3)
35
+
36
def fit_hist_jax(hist, func, parmvals, max_iter=5, edmtol=1e-5):
    """Fit ``func`` to a histogram by minimizing a chi-square loss.

    The loss ``sum((func(xvals, parms) - values)**2 / variances)`` is built from the bin
    centers, values and variances of ``hist``. It is minimized with
    ``scipy.optimize.minimize(method="trust-krylov")``, using jax for the gradient and the
    Hessian-vector products. After each minimization the full Hessian is evaluated to get
    its eigenvalues, the covariance ``2 * inv(hess)`` and the estimated distance to minimum
    (EDM). If the fit is not converged (EDM >= ``edmtol`` or NaN, or the smallest Hessian
    eigenvalue <= 0), the minimization is restarted from the last result, up to ``max_iter``
    attempts in total.

    Args:
        hist: histogram providing ``axes.centers``, ``values()`` and ``variances()``
            (presumably a ``hist.Hist`` with variance-tracking storage -- TODO confirm
            ``variances()`` cannot be None for the inputs used).
        func: model ``func(xvals, parms)`` where ``xvals`` is the list of per-axis bin-center
            arrays. It is passed as a static argument to jitted helpers, so it must be hashable.
        parmvals: initial parameter values.
        max_iter: maximum number of minimization attempts (must be >= 1).
        edmtol: EDM threshold below which the minimization is considered converged.

    Returns:
        dict with keys ``x`` (best-fit parameters), ``cov`` (``2 * inv(hess)``), ``status``
        (0 if EDM < ``edmtol`` and not NaN, else 1), ``covstatus`` (0 if the smallest
        Hessian eigenvalue is positive, else 1), ``hess_eigvals``, ``edmval`` and ``chisqval``.

    Raises:
        ValueError: if ``max_iter`` < 1.
        numpy.linalg.LinAlgError: if the Hessian at a minimization result cannot be inverted.
    """
    # ``import scipy`` alone does not load the ``optimize`` submodule on SciPy versions
    # without lazy submodule loading, so import it explicitly before use.
    import scipy.optimize

    if max_iter < 1:
        raise ValueError(f"max_iter must be >= 1, got {max_iter}")

    xvals = [jnp.array(center) for center in hist.axes.centers]
    yvals = jnp.array(hist.values())
    yvariances = jnp.array(hist.variances())

    def scipy_loss(parmvals):
        # scipy works with numpy arrays / python floats: convert at the boundary
        parms = jnp.array(parmvals)
        loss, grad = chisqloss_grad(xvals, yvals, yvariances, func, parms)
        return np.asarray(loss).item(), np.asarray(grad)

    def scipy_hessp(parmvals, p):
        parms = jnp.array(parmvals)
        tangent = jnp.array(p)
        hessp = chisqloss_hessp(xvals, yvals, yvariances, func, parms, tangent)
        return np.asarray(hessp)

    x0 = parmvals
    for _ in range(max_iter):
        res = scipy.optimize.minimize(scipy_loss, x0, method="trust-krylov", jac=True, hessp=scipy_hessp)

        parms = jnp.array(res.x)
        loss, grad, hess = chisqloss_grad_hess(xvals, yvals, yvariances, func, parms)
        loss, grad, hess = np.asarray(loss).item(), np.asarray(grad), np.asarray(hess)

        # eigvalsh returns eigenvalues in ascending order, so eigvals[0] is the smallest one
        eigvals = np.linalg.eigvalsh(hess)
        cov = 2.*np.linalg.inv(hess)

        gradv = grad[:, np.newaxis]
        edmval = 0.5*gradv.transpose()@cov@gradv
        edmval = edmval[0][0]

        # the NaN check rejects an undefined EDM (NaN compares False anyway, kept explicit)
        edm_ok = edmval < edmtol and not np.isnan(edmval)
        hess_posdef = eigvals[0] > 0.
        if edm_ok and hess_posdef:
            break

        # Restart the next attempt from the current point. Restarting from the original
        # ``parmvals`` would just repeat the identical (deterministic) minimization.
        x0 = res.x

    status = 0 if edm_ok else 1
    covstatus = 0 if hess_posdef else 1

    res = { "x" : res.x,
            "cov" : cov,
            "status" : status,
            "covstatus" : covstatus,
            "hess_eigvals" : eigvals,
            "edmval" : edmval,
            "chisqval" : loss }

    return res
wums/logging.py CHANGED
@@ -42,19 +42,19 @@ def set_logging_level(log, verbosity):
42
42
  log.setLevel(logging_verboseLevel[max(0, min(4, verbosity))])
43
43
 
44
44
 
45
- def setup_logger(basefile, verbosity=3, no_colors=False, initName="wremnants"):
45
+ def setup_logger(basefile, verbosity=3, no_colors=False, initName="wums"):
46
46
 
47
47
  setup_func = setup_base_logger if no_colors else setup_color_logger
48
48
  logger = setup_func(os.path.basename(basefile), verbosity, initName)
49
49
  # count messages of base logger
50
- base_logger = logging.getLogger("wremnants")
50
+ base_logger = logging.getLogger("wums")
51
51
  add_logging_counter(base_logger)
52
52
  # stop total time
53
53
  add_time_info("Total time")
54
54
  return logger
55
55
 
56
56
 
57
- def setup_color_logger(name, verbosity, initName="wremnants"):
57
+ def setup_color_logger(name, verbosity, initName="wums"):
58
58
  base_logger = logging.getLogger(initName)
59
59
  # set console handler
60
60
  ch = logging.StreamHandler()
@@ -65,14 +65,14 @@ def setup_color_logger(name, verbosity, initName="wremnants"):
65
65
  return base_logger.getChild(name)
66
66
 
67
67
 
68
- def setup_base_logger(name, verbosity, initName="wremnants"):
68
+ def setup_base_logger(name, verbosity, initName="wums"):
69
69
  logging.basicConfig(format="%(levelname)s: %(message)s")
70
70
  base_logger = logging.getLogger(initName)
71
71
  set_logging_level(base_logger, verbosity)
72
72
  return base_logger.getChild(name)
73
73
 
74
74
 
75
- def child_logger(name, initName="wremnants"):
75
+ def child_logger(name, initName="wums"):
76
76
  # count messages of child logger
77
77
  logger = logging.getLogger(initName).getChild(name)
78
78
  add_logging_counter(logger)
@@ -110,7 +110,7 @@ def print_logging_count(logger, verbosity=logging.WARNING):
110
110
  )
111
111
 
112
112
 
113
- def add_time_info(tag, logger=logging.getLogger("wremnants")):
113
+ def add_time_info(tag, logger=logging.getLogger("wums")):
114
114
  if not hasattr(logger, "times"):
115
115
  logger.times = {}
116
116
  logger.times[tag] = time.time()
@@ -125,7 +125,7 @@ def print_time_info(logger):
125
125
 
126
126
 
127
127
  def summary(verbosity=logging.WARNING, extended=True):
128
- base_logger = logging.getLogger("wremnants")
128
+ base_logger = logging.getLogger("wums")
129
129
 
130
130
  base_logger.info(f"--------------------------------------")
131
131
  base_logger.info(f"----------- logger summary -----------")
@@ -141,5 +141,5 @@ def summary(verbosity=logging.WARNING, extended=True):
141
141
  # Iterate through all child loggers and print their names, levels, and counts
142
142
  all_loggers = logging.Logger.manager.loggerDict
143
143
  for logger_name, logger_obj in all_loggers.items():
144
- if logger_name.startswith("wremnants."):
144
+ if logger_name.startswith("wums."):
145
145
  print_logging_count(logger_obj, verbosity=verbosity)
wums/plot_tools.py CHANGED
@@ -683,7 +683,6 @@ def makeStackPlotWithRatio(
683
683
  stackedProcs,
684
684
  histName="nominal",
685
685
  unstacked=None,
686
- fitresult=None,
687
686
  prefit=False,
688
687
  xlabel="",
689
688
  ylabel=None,
@@ -757,11 +756,6 @@ def makeStackPlotWithRatio(
757
756
  if xlim:
758
757
  h = h[complex(0, xlim[0]) : complex(0, xlim[1])]
759
758
 
760
- # If plotting from combine, apply the action to the underlying hist.
761
- # Don't do this for the generic case, as it screws up the ability to make multiple plots
762
- if fitresult:
763
- histInfo[k].hists[histName] = h
764
-
765
759
  if k != "Data":
766
760
  stack.append(h)
767
761
  else:
@@ -803,67 +797,6 @@ def makeStackPlotWithRatio(
803
797
  ratio_axes = None
804
798
  ax2 = None
805
799
 
806
- if fitresult:
807
- import uproot
808
-
809
- combine_result = uproot.open(fitresult)
810
-
811
- fittype = "prefit" if prefit else "postfit"
812
-
813
- # set histograms to prefit/postfit values
814
- for p in to_read:
815
-
816
- hname = f"expproc_{p}_{fittype}" if p != "Data" else "obs"
817
- vals = combine_result[hname].to_hist().values()
818
- if len(histInfo[p].hists[histName].values()) != len(vals):
819
- raise ValueError(
820
- f"The size of the combine histogram ({(vals.shape)}) is not consistent with the xlim or input hist ({histInfo[p].hists[histName].shape})"
821
- )
822
-
823
- histInfo[p].hists[histName].values()[...] = vals
824
- if p == "Data":
825
- histInfo[p].hists[histName].variances()[...] = vals
826
-
827
- # for postfit uncertaity bands
828
- axis = histInfo[to_read[0]].hists[histName].axes[0].edges
829
-
830
- # need to divide by bin width
831
- binwidth = axis[1:] - axis[:-1]
832
- hexp = combine_result[f"expfull_{fittype}"].to_hist()
833
- if hexp.storage_type != hist.storage.Weight:
834
- raise ValueError(
835
- f"Did not find uncertainties in {fittype} hist. Make sure you run combinetf with --computeHistErrors!"
836
- )
837
- nom = hexp.values() / binwidth
838
- std = np.sqrt(hexp.variances()) / binwidth
839
-
840
- hatchstyle = "///"
841
- ax1.fill_between(
842
- axis,
843
- np.append(nom + std, (nom + std)[-1]),
844
- np.append(nom - std, (nom - std)[-1]),
845
- step="post",
846
- facecolor="none",
847
- zorder=2,
848
- hatch=hatchstyle,
849
- edgecolor="k",
850
- linewidth=0.0,
851
- label="Uncertainty",
852
- )
853
-
854
- if add_ratio:
855
- ax2.fill_between(
856
- axis,
857
- np.append((nom + std) / nom, ((nom + std) / nom)[-1]),
858
- np.append((nom - std) / nom, ((nom - std) / nom)[-1]),
859
- step="post",
860
- facecolor="none",
861
- zorder=2,
862
- hatch=hatchstyle,
863
- edgecolor="k",
864
- linewidth=0.0,
865
- )
866
-
867
800
  opts = dict(stack=not no_stack, flow=flow)
868
801
  optsr = opts.copy() # no binwnorm for ratio axis
869
802
  optsr["density"] = density
@@ -994,7 +927,7 @@ def makeStackPlotWithRatio(
994
927
 
995
928
  for i, (proc, style) in enumerate(zip(unstacked, linestyles)):
996
929
  unstack = histInfo[proc].hists[histName]
997
- if not fitresult or proc not in to_read:
930
+ if proc not in to_read:
998
931
  unstack = action(unstack)[select]
999
932
  if proc != "Data":
1000
933
  unstack = unstack * scale
wums/tfutils.py ADDED
@@ -0,0 +1,81 @@
1
+ import tensorflow as tf
2
+
3
def function_to_tflite(funcs, input_signatures, func_names=""):
    """Convert one or more functions to a single tflite model.

    Python dynamic execution (``exec``) is used to build a ``tf.Module`` whose methods take
    positional inputs named ``input_{func:05d}_{arg:05d}`` and return dicts keyed
    ``output_{func:05d}_{out:05d}``. Inputs and outputs are therefore alphabetically
    ordered, since this is apparently the only way to prevent tflite from scrambling them.

    Args:
        funcs: a callable, or a list/tuple of callables to export.
        input_signatures: the ``tf.function`` input signature of a single callable, or one
            signature per callable when ``funcs`` is a list/tuple.
        func_names: exported method/signature name of a single callable, or one name per
            callable. An empty string means "use ``func.__name__``". When ``funcs`` is a
            list/tuple, a single string (e.g. the ``""`` default) is applied to every callable.

    Returns:
        The serialized model returned by ``TFLiteConverter.convert()``.

    Raises:
        ValueError: if the numbers of callables, signatures and names differ, if a name is
            not a valid Python identifier (names are spliced into generated source), or if
            two callables resolve to the same name.
    """
    import keyword

    if isinstance(funcs, (list, tuple)):
        funcs = list(funcs)
        input_signatures = list(input_signatures)
        if isinstance(func_names, str):
            # Broadcast a bare string to every function. Indexing into the string itself
            # raised IndexError for "" and picked single characters otherwise.
            func_names = [func_names] * len(funcs)
        else:
            func_names = list(func_names)
    else:
        funcs = [funcs]
        input_signatures = [input_signatures]
        func_names = [func_names]

    if len(input_signatures) != len(funcs) or len(func_names) != len(funcs):
        raise ValueError(
            f"Got {len(funcs)} functions but {len(input_signatures)} input signatures and {len(func_names)} names"
        )

    func_names = [func.__name__ if name == "" else name for func, name in zip(funcs, func_names)]

    for name in func_names:
        # names become method names in the generated source below
        if not isinstance(name, str) or not name.isidentifier() or keyword.iskeyword(name):
            raise ValueError(f"Function name {name!r} is not a valid Python identifier")
    if len(set(func_names)) != len(func_names):
        # duplicate method names would silently overwrite each other in the generated class
        raise ValueError(f"Function names must be unique, got {func_names}")

    def wrapped_func(iif, *args):
        outputs = funcs[iif](*args)

        if not isinstance(outputs, tuple):
            outputs = (outputs,)

        # zero-padded keys sort alphabetically in (function, output) index order
        return {f"output_{iif:05d}_{i:05d}": output for i, output in enumerate(outputs)}

    # zero-padded argument names sort alphabetically in (function, input) index order
    arg_string = [
        ", ".join(f"input_{iif:05d}_{i:05d}" for i in range(len(input_signature)))
        for iif, input_signature in enumerate(input_signatures)
    ]

    # Indentation inside these literals only needs to be consistent for the generated source.
    def_string = ""
    def_string += "def make_module(wrapped_func, input_signatures):\n"
    def_string += "    class Export_Module(tf.Module):\n"
    for i, func in enumerate(funcs):
        def_string += f"        @tf.function(input_signature = input_signatures[{i}])\n"
        def_string += f"        def {func_names[i]}(self, {arg_string[i]}):\n"
        def_string += f"            return wrapped_func({i}, {arg_string[i]})\n"
    def_string += "    return Export_Module"

    ldict = {}
    # module globals are passed so that the generated code can resolve ``tf``
    exec(def_string, globals(), ldict)

    make_module = ldict["make_module"]
    Export_Module = make_module(wrapped_func, input_signatures)

    module = Export_Module()
    concrete_functions = [getattr(module, func_name).get_concrete_function() for func_name in func_names]
    converter = tf.lite.TFLiteConverter.from_concrete_functions(concrete_functions, module)

    # enable TensorFlow ops and DISABLE builtin TFLite ops since these apparently slow things down
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.SELECT_TF_OPS  # enable TensorFlow ops.
    ]

    converter._experimental_allow_all_select_tf_ops = True

    tflite_model = converter.convert()

    # print the resulting interface so the input/output ordering can be checked
    test_interp = tf.lite.Interpreter(model_content = tflite_model)
    print(test_interp.get_input_details())
    print(test_interp.get_output_details())
    print(test_interp.get_signature_list())

    return tflite_model
69
+
70
+
71
+
72
def function_to_saved_model(func, input_signature, output):
    """Export ``func`` as a TensorFlow SavedModel written to ``output``.

    ``func`` becomes the ``__call__`` method of a ``tf.Module``, traced by ``tf.function``
    with the given ``input_signature``. The module instance is then written with
    ``tf.saved_model.save``.
    """
    trace = tf.function(input_signature=input_signature)

    class Export_Module(tf.Module):
        # ``*args`` is kept as-is: the varargs name can influence the exported input names
        @trace
        def __call__(self, *args):
            return func(*args)

    exported = Export_Module()
    tf.saved_model.save(exported, output)
@@ -0,0 +1,54 @@
1
+ Metadata-Version: 2.2
2
+ Name: wums
3
+ Version: 0.1.8
4
+ Summary: .
5
+ Author-email: David Walter <david.walter@cern.ch>, Josh Bendavid <josh.bendavid@cern.ch>, Kenneth Long <kenneth.long@cern.ch>, Jan Eysermans <jan.eysermans@cern.ch>
6
+ License: MIT
7
+ Project-URL: Homepage, https://github.com/WMass/wums
8
+ Classifier: Programming Language :: Python :: 3
9
+ Classifier: Programming Language :: Python :: 3.8
10
+ Classifier: License :: OSI Approved :: MIT License
11
+ Classifier: Operating System :: OS Independent
12
+ Requires-Python: >=3.8
13
+ Description-Content-Type: text/markdown
14
+ Requires-Dist: hist
15
+ Requires-Dist: numpy
16
+ Provides-Extra: plotting
17
+ Requires-Dist: matplotlib; extra == "plotting"
18
+ Requires-Dist: mplhep; extra == "plotting"
19
+ Provides-Extra: fitting
20
+ Requires-Dist: tensorflow; extra == "fitting"
21
+ Requires-Dist: jax; extra == "fitting"
22
+ Requires-Dist: scipy; extra == "fitting"
23
+ Provides-Extra: pickling
24
+ Requires-Dist: boost_histogram; extra == "pickling"
25
+ Requires-Dist: h5py; extra == "pickling"
26
+ Requires-Dist: hdf5plugin; extra == "pickling"
27
+ Requires-Dist: lz4; extra == "pickling"
28
+ Provides-Extra: all
29
+ Requires-Dist: plotting; extra == "all"
30
+ Requires-Dist: fitting; extra == "all"
31
+ Requires-Dist: pickling; extra == "all"
32
+
33
+ # WUMS: Wremnants Utilities, Modules, and other Stuff
34
+
35
+ As the name suggests, this is a collection of different things, all python based:
36
+ - Fitting with tensorflow or jax
37
+ - Custom pickling h5py objects
38
+ - Plotting functionality
39
+
40
+ ## Install
41
+
42
+ The `wums` package can be pip installed with minimal dependencies:
43
+ ```bash
44
+ pip install wums
45
+ ```
46
+ Different dependencies can be added with `plotting`, `fitting`, `pickling` to use the corresponding scripts.
47
+ For example, one can install with
48
+ ```bash
49
+ pip install wums[plotting,fitting]
50
+ ```
51
+ Or all dependencies with
52
+ ```bash
53
+ pip install wums[all]
54
+ ```
@@ -0,0 +1,16 @@
1
+ scripts/test/testsplinepdf.py,sha256=sXnmDjEXiO0OIHAXLXU4UxTD4_nLwUpoojCecfjyT04,1964
2
+ scripts/test/testsplinepdf2d.py,sha256=vGw9mq67f6aoymefLqv6CqF8teluva4Lx6tpbnC_NGU,8513
3
+ wums/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
+ wums/boostHistHelpers.py,sha256=mgdPXAgmxriqoOhrhMctyZcfwEOPfV07V27CvGt2sk8,39260
5
+ wums/fitutils.py,sha256=sPCMJqZGdXvDfc8OxjOB-Bpf45GWHKxmKkDV3SlMUQs,38297
6
+ wums/fitutilsjax.py,sha256=HE1AcIZmI6N_xIHo8OHCPaYkHSnND_B-vI4Gl3vaUmA,2659
7
+ wums/ioutils.py,sha256=ziyfQQ8CB3Ir2BJKJU3_a7YMF-Jd2nGXKoMQoJ2T8fo,12334
8
+ wums/logging.py,sha256=L4514Xyq7L1z77Tkh8KE2HX88ZZ06o6SSRyQo96DbC0,4494
9
+ wums/output_tools.py,sha256=SHcZqXAdqL9AkA57UF0b-R-U4u7rzDgL8Def4E-ulW0,6713
10
+ wums/plot_tools.py,sha256=7GBQAO--wuP8aatkjy-ir1lQWpNrzMc1lSI6zSq3JXE,53502
11
+ wums/tfutils.py,sha256=9efkkvxH7VtwJN2yBS6_-P9dLKs3CXdxMFdrEBNsna8,2892
12
+ wums/Templates/index.php,sha256=9EYmfc0ltMqr5oOdA4_BVIHdSbef5aA0ORoRZBEADVw,4348
13
+ wums-0.1.8.dist-info/METADATA,sha256=87fET64UzNDs6swv1-tcJWcuzVE5S3kEuLWOfy1JN6c,1784
14
+ wums-0.1.8.dist-info/WHEEL,sha256=jB7zZ3N9hIM9adW7qlTAyycLYW9npaWKLRzaoVcLKcM,91
15
+ wums-0.1.8.dist-info/top_level.txt,sha256=cGGeFZQ8IwVw-BhgxMCTu5zfkgQelfF1wEFFWGhycds,13
16
+ wums-0.1.8.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.8.0)
2
+ Generator: setuptools (75.8.2)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -0,0 +1,2 @@
1
+ scripts
2
+ wums
@@ -1,29 +0,0 @@
1
- Metadata-Version: 2.2
2
- Name: wums
3
- Version: 0.1.6
4
- Summary: .
5
- Author-email: David Walter <david.walter@cern.ch>, Josh Bendavid <josh.bendavid@cern.ch>, Kenneth Long <kenneth.long@cern.ch>
6
- License: MIT
7
- Project-URL: Homepage, https://github.com/WMass/wums
8
- Classifier: Programming Language :: Python :: 3
9
- Classifier: Programming Language :: Python :: 3.8
10
- Classifier: License :: OSI Approved :: MIT License
11
- Classifier: Operating System :: OS Independent
12
- Requires-Python: >=3.8
13
- Description-Content-Type: text/markdown
14
- Requires-Dist: boost_histogram
15
- Requires-Dist: h5py
16
- Requires-Dist: hdf5plugin
17
- Requires-Dist: hist
18
- Requires-Dist: lz4
19
- Requires-Dist: matplotlib
20
- Requires-Dist: mplhep
21
- Requires-Dist: numpy
22
- Requires-Dist: uproot
23
-
24
- # WUMS: Wremnants Utilities, Modules, and other Stuff
25
-
26
- The `wums` package can be pip installed:
27
- ```bash
28
- pip install wums
29
- ```
@@ -1,11 +0,0 @@
1
- wums/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- wums/boostHistHelpers.py,sha256=F4SwQEVjNObFscfs0qrJEyOHYNKqUCmusW8HIF1o-0c,38993
3
- wums/ioutils.py,sha256=ziyfQQ8CB3Ir2BJKJU3_a7YMF-Jd2nGXKoMQoJ2T8fo,12334
4
- wums/logging.py,sha256=zNnLVJUwG3HMvr9NeXmiheX07VmsnSt8cQ6R4q4XBk4,4534
5
- wums/output_tools.py,sha256=SHcZqXAdqL9AkA57UF0b-R-U4u7rzDgL8Def4E-ulW0,6713
6
- wums/plot_tools.py,sha256=4iPx9Nr9y8c3p4ovy8XOS-xU_w11OyQEjISKkygxqcA,55918
7
- wums/Templates/index.php,sha256=9EYmfc0ltMqr5oOdA4_BVIHdSbef5aA0ORoRZBEADVw,4348
8
- wums-0.1.6.dist-info/METADATA,sha256=pTmIMc-rth2X53tju6Ef8WbJDda2zbr8isEqUpeqhDo,843
9
- wums-0.1.6.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
10
- wums-0.1.6.dist-info/top_level.txt,sha256=DCE1TVg7ySraosR3kYZkLIZ2w1Pwk2pVTdkqx6E-yRY,5
11
- wums-0.1.6.dist-info/RECORD,,
@@ -1 +0,0 @@
1
- wums