wums 0.1.8__py3-none-any.whl → 0.1.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wums/boostHistHelpers.py CHANGED
@@ -4,6 +4,7 @@ from functools import reduce
 
 import hist
 import numpy as np
+from scipy.interpolate import make_smoothing_spline
 
 from wums import logging
 
@@ -60,7 +61,7 @@ def broadcastSystHist(h1, h2, flow=True, by_ax_name=True):
     # move back to original order
     new_vals = np.moveaxis(new_vals, np.arange(len(moves)), list(moves.keys()))
 
-    if new_vals.shape != h2.values(flow=flow).shape:
+    if new_vals.shape != s2:
         raise ValueError(
             f"Broadcast shape {new_vals.shape} (from h1.shape={h1.values(flow=flow).shape}, axes={h1.axes.name}) "
             f"does not match desired shape {h2.view(flow=flow).shape} (axes={h2.axes.name})"
@@ -193,10 +194,10 @@ def multiplyWithVariance(vals1, vals2, vars1=None, vars2=None):
     return outvals, outvars
 
 
-def multiplyHists(h1, h2, allowBroadcast=True, createNew=True, flow=True):
+def multiplyHists(h1, h2, allowBroadcast=True, createNew=True, flow=True, broadcast_by_ax_name=True):
     if allowBroadcast:
-        h1 = broadcastSystHist(h1, h2, flow=flow)
-        h2 = broadcastSystHist(h2, h1, flow=flow)
+        h1 = broadcastSystHist(h1, h2, flow=flow, by_ax_name=broadcast_by_ax_name)
+        h2 = broadcastSystHist(h2, h1, flow=flow, by_ax_name=broadcast_by_ax_name)
 
     if (
         h1.storage_type == hist.storage.Double
@@ -233,6 +234,7 @@ def concatenateHists(h1, h2, allowBroadcast=True, by_ax_name=True, flow=False):
         h2 = broadcastSystHist(h2, h1, flow=flow, by_ax_name=by_ax_name)
 
     axes = []
+
     for ax1, ax2 in zip(h1.axes, h2.axes):
         if ax1 == ax2:
             axes.append(ax1)
@@ -263,7 +265,7 @@ def concatenateHists(h1, h2, allowBroadcast=True, by_ax_name=True, flow=False):
             )
         else:
             raise ValueError(
-                f"Cannot concatenate hists with inconsistent axes: {ax1.name} and {ax2.name}"
+                f"Cannot concatenate hists with inconsistent axes: {ax1.name}: ({ax1.edges}) and {ax2.name}: ({ax2.edges})"
             )
 
     newh = hist.Hist(*axes, storage=h1.storage_type())
@@ -407,6 +409,10 @@ def normalize(h, scale=1e6, createNew=True, flow=True):
     return scaleHist(h, scale, createNew, flow)
 
 
+def renameAxis(h, axis_name, new_name):
+    h.axes[axis_name].__dict__['name'] = new_name
+
+
 def makeAbsHist(h, axis_name, rename=True):
     ax = h.axes[axis_name]
     axidx = list(h.axes).index(ax)
@@ -455,8 +461,8 @@ def compatibleBins(edges1, edges2):
 
 def rebinHistMultiAx(h, axes, edges=[], lows=[], highs=[]):
     # edges: lists of new edges or integers to merge bins, in case new edges are given the lows and highs will be ignored
-    # lows: list of new lower boundaries
-    # highs: list of new upper boundaries
+    # lows: list of new lower boundaries or bins
+    # highs: list of new upper boundaries or bins
 
     sel = {}
     for ax, low, high, rebin in itertools.zip_longest(axes, lows, highs, edges):
@@ -467,14 +473,21 @@ def rebinHistMultiAx(h, axes, edges=[], lows=[], highs=[]):
            h = rebinHist(h, ax, rebin)
        elif low is not None and high is not None:
            # in case high edge is upper edge of last bin we need to manually set the upper limit
-           upper = hist.overflow if high == h.axes[ax].edges[-1] else complex(0, high)
-           logger.info(f"Restricting the axis '{ax}' to range [{low}, {high}]")
+           # distinguish case with pure real or imaginary number (to select index or value, respectively)
+           # in the former case, force casting into integer
+           if isinstance(high, int):
+               upper = hist.overflow if high == h.axes[ax].size else high
+           elif isinstance(high, complex):
+               high_imag = high.imag
+               upper = hist.overflow if high_imag == h.axes[ax].edges[-1] else high
+           logger.info(f"Slicing the axis '{ax}' to [{low}, {upper}]")
            sel[ax] = slice(
-               complex(0, low), upper, hist.rebin(rebin) if rebin else None
+               low, upper, hist.rebin(rebin) if rebin else None
            )
        elif type(rebin) == int and rebin > 1:
            logger.info(f"Rebinning the axis '{ax}' by [{rebin}]")
            sel[ax] = slice(None, None, hist.rebin(rebin))
+
     return h[sel] if len(sel) > 0 else h
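
As an illustration, a minimal sketch of the index-versus-value convention that the reworked rebinHistMultiAx above distinguishes (the histogram and axis are made up, and the calls assume the documented lows/highs behaviour):

    import hist
    from wums import boostHistHelpers as hh

    h = hist.Hist(hist.axis.Regular(60, 0.0, 60.0, name="pt"))  # hypothetical axis

    # plain integers are interpreted as bin indices: keep bins 0..29 of "pt"
    h_by_index = hh.rebinHistMultiAx(h, ["pt"], lows=[0], highs=[30])

    # imaginary numbers are interpreted as axis values: keep the pt range 10 to 40
    h_by_value = hh.rebinHistMultiAx(h, ["pt"], lows=[10j], highs=[40j])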
 
 
@@ -499,9 +512,9 @@ def mirrorAxes(h, axes, flow=True):
     return h
 
 
-def disableAxisFlow(ax):
+def disableAxisFlow(ax, under=False, over=False):
     if isinstance(ax, hist.axis.Integer):
-        args = [ax.edges[0], ax.edges[-1]]
+        args = [int(ax.edges[0]), int(ax.edges[-1])]
     elif isinstance(ax, hist.axis.Regular):
         args = [ax.size, ax.edges[0], ax.edges[-1]]
     else:
@@ -510,17 +523,24 @@ def disableAxisFlow(ax):
     return type(ax)(
         *args,
         name=ax.name,
-        overflow=False,
-        underflow=False,
+        overflow=over,
+        underflow=under,
         circular=ax.traits.circular,
     )
 
 
-def disableFlow(h, axis_name):
-    # disable the overflow and underflow bins of a single axes, while keeping the flow bins of other axes
+def disableFlow(h, axis_name, under=False, over=False):
+    # axes_name can be either string or a list of strings with the axis name(s) to disable the flow
+    if not isinstance(axis_name, str):
+        for var in axis_name:
+            if var in h.axes.name:
+                h = disableFlow(h, var)
+        return h
+
+    # disable the overflow and underflow bins of a single axis, while keeping the flow bins of other axes
     ax = h.axes[axis_name]
     ax_idx = [a.name for a in h.axes].index(axis_name)
-    new_ax = disableAxisFlow(ax)
+    new_ax = disableAxisFlow(ax, under=under, over=over)
     axes = list(h.axes)
     axes[ax_idx] = new_ax
     hnew = hist.Hist(*axes, name=h.name, storage=h.storage_type())
@@ -528,7 +548,7 @@ def disableFlow(h, axis_name):
         (
             slice(None)
             if i != ax_idx
-            else slice(ax.traits.underflow, new_ax.size + ax.traits.underflow)
+            else slice(ax.traits.underflow * (not under), ax.size + ax.traits.underflow + ax.traits.overflow * over)
         )
         for i in range(len(axes))
     ]
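
A brief sketch of the extended disableFlow (the histogram and axis names are illustrative): a list of axis names now disables both flow bins on every listed axis that the histogram actually has, while the new under/over flags control which flow bin survives for a single axis.

    import hist
    from wums import boostHistHelpers as hh

    h = hist.Hist(
        hist.axis.Regular(48, -2.4, 2.4, name="eta"),
        hist.axis.Regular(30, 26.0, 56.0, name="pt"),
    )

    h_noflow = hh.disableFlow(h, ["eta", "pt"])  # drop underflow and overflow of both axes
    h_noUF = hh.disableFlow(h, "pt", over=True)  # drop only the underflow of "pt", keep its overflow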
@@ -773,9 +793,9 @@ def unrolledHist(h, obs=None, binwnorm=None, add_flow_bins=False):
 
     if binwnorm:
         edges = (
-            plot_tools.extendEdgesByFlow(hproj) if add_flow_bins else hproj.axes.edges
+            extendEdgesByFlow(hproj) if add_flow_bins else hproj.axes.edges
         )
-        binwidths = np.outer(*[np.diff(e.squeeze()) for e in edges]).flatten()
+        binwidths = np.array(list(itertools.product(*[np.diff(e.squeeze()) for e in edges]))).prod(axis=1)
         scale = binwnorm / binwidths
     else:
         scale = 1
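
A standalone check of the bin-width computation above (made-up edges): np.outer only covers the two-axis case, whereas the itertools.product form gives the per-bin width product for any number of unrolled axes.

    import itertools
    import numpy as np

    edges = [np.array([0.0, 1.0, 3.0]), np.array([0.0, 2.0, 5.0]), np.array([0.0, 10.0])]
    widths = [np.diff(e) for e in edges]

    binwidths = np.array(list(itertools.product(*widths))).prod(axis=1)

    # same numbers as an explicit outer product over all three axes, flattened in C order
    expected = np.einsum("i,j,k->ijk", *widths).reshape(-1)
    assert np.allclose(binwidths, expected)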
@@ -1133,3 +1153,42 @@ def rssHistsMid(h, syst_axis, scale=1.0):
     hDown = addHists(hnom, hrss[{"downUpVar": -1j}], scale2=-1.0)
 
     return hUp, hDown
+
+def smooth_hist(h, smooth_ax_name, exclude_axes=[], start_bin=0, end_bin=None):
+
+    hnew = h.copy()
+
+    smooth_ax = h.axes[smooth_ax_name]
+    hproj = h.project(smooth_ax_name, *[ax for ax in h.axes.name if ax not in [smooth_ax_name, *exclude_axes]])
+    smoothh = hproj.copy()
+
+    # Reshape before looping over all other bins and smoothing along the relevant axis
+    vals = smoothh.values().reshape(smooth_ax.size, -1)
+    if not end_bin:
+        end_bin = vals.shape[0]
+
+    # Correct for bin width
+    binw = np.diff(smoothh.axes[smooth_ax_name].edges)
+    vals = (vals.T/binw).T
+
+    for b in range(vals.shape[-1]):
+        spl = make_smoothing_spline(smooth_ax.centers[start_bin:end_bin], vals[start_bin:end_bin,b])
+        vals[start_bin:end_bin,b] = spl(smooth_ax.centers[start_bin:end_bin])
+
+    #Recorrect for bin width
+    vals = (vals.T*binw).T
+
+    smoothh.values()[...] = vals.reshape(smoothh.shape)
+
+    if not exclude_axes:
+        return smoothh.project(*hnew.axes.name)
+
+    # If some axis has been excluded, broadcast it back
+    smoothfac = divideHists(smoothh, hproj).project(*[ax for ax in h.axes.name if ax not in exclude_axes])
+
+    # Broadcast over excluded axes
+    indices = tuple(
+        None if ax in exclude_axes else slice(None) for ax in h.axes.name
+    )
+    hnew.values()[...] = h.values() * smoothfac.values()[indices]
+    return hnew
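
For orientation, a minimal sketch of how the new smooth_hist helper might be called (the histogram, axis names and fill values are made up): it divides the contents by bin width, runs scipy's make_smoothing_spline along the chosen axis for each combination of the remaining bins, multiplies the bin width back, and, when axes are excluded, broadcasts the resulting smoothing factor back over them.

    import hist
    import numpy as np
    from wums import boostHistHelpers as hh

    rng = np.random.default_rng(0)
    h = hist.Hist(
        hist.axis.Regular(50, 0.0, 100.0, name="mt"),
        hist.axis.Regular(2, -2.0, 2.0, name="charge"),
    )
    h.fill(mt=rng.exponential(30.0, 10000), charge=rng.choice([-1.0, 1.0], 10000))

    # smooth along "mt" separately in every charge bin
    hsmooth = hh.smooth_hist(h, "mt")

    # smooth the charge-integrated "mt" shape and broadcast the smoothing
    # factor back over the excluded "charge" axis
    hsmooth_incl = hh.smooth_hist(h, "mt", exclude_axes=["charge"])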
wums/ioutils.py CHANGED
@@ -325,8 +325,12 @@ class H5Unpickler(pickle.Unpickler):
         raise pickle.UnpicklingError("unsupported persistent object")
 
 
-def pickle_dump_h5py(name, obj, h5out):
+def pickle_dump_h5py(name, obj, h5out, override=False):
     """Write an object to a new h5py group which will be created within the provided group."""
+    if name in h5out and override:
+        # If the group already exists, delete it first
+        del h5out[name]
+
     obj_group = h5out.create_group(name)
     try:
         obj_group.attrs["narf_h5py_pickle_protocol_version"] = CURRENT_PROTOCOL_VERSION
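
A short sketch of the new override flag (the file name, group name and payload are illustrative): when the target group already exists it is deleted and recreated, instead of h5py raising on the duplicate name.

    import h5py
    from wums import ioutils

    results = {"fit_parameters": [1.0, 2.0, 3.0]}

    with h5py.File("results.hdf5", "a") as h5out:
        ioutils.pickle_dump_h5py("results", results, h5out)
        # a second dump under the same name would fail without override=True
        ioutils.pickle_dump_h5py("results", results, h5out, override=True)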
wums/output_tools.py CHANGED
@@ -165,6 +165,12 @@ def make_meta_info_dict(
     return meta_data
 
 
+def encode_complex(obj):
+    if isinstance(obj, complex):
+        return {"real": obj.real, "imag": obj.imag}
+    raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")
+
+
 def write_logfile(
     outpath,
     logname,
@@ -181,7 +187,8 @@ def write_logfile(
             logf.write("\n" + "-" * 80 + "\n")
             if isinstance(v, dict):
                 logf.write(k)
-                logf.write(json.dumps(v, indent=5).replace("\\n", "\n"))
+                # String conversion needed for non-primitive types (e.g., complex number)
+                logf.write(json.dumps(v, default=encode_complex, indent=5).replace("\\n", "\n"))
             else:
                 logf.write(f"{k}: {v}\n")
 
@@ -202,3 +209,19 @@ def write_index_and_log(
 ):
     write_indexfile(outpath, template_dir)
     write_logfile(outpath, logname, args, analysis_meta_info)
+
+def write_lz4_pkl_output(
+    outfile, outfolder, output_dict, basedir, args=None, file_meta_data=None
+):
+    if not outfile.endswith(".pkl.lz4"):
+        outfile += ".pkl.lz4"
+
+    logger.info(f"Write file {outfile}")
+    result_dict = {
+        outfolder : output_dict,
+        "meta_data": make_meta_info_dict(args, wd=basedir),
+    }
+    if file_meta_data is not None:
+        result_dict["file_meta_data"] = file_meta_data
+    with lz4.frame.open(outfile, "wb") as f:
+        pickle.dump(result_dict, f, protocol=pickle.HIGHEST_PROTOCOL)
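
A standalone illustration of the encode_complex hook above (the dictionary keys are made up): json.dumps calls the default hook for any value it cannot serialize natively, so complex entries such as histogram slice bounds end up as {"real": ..., "imag": ...} in the log file.

    import json

    from wums.output_tools import encode_complex

    args = {"xlim": [0.0, 2.4], "rebin": 2, "upper_bound": complex(0, 60)}
    print(json.dumps(args, default=encode_complex, indent=5))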
wums/plot_tools.py CHANGED
@@ -1,12 +1,7 @@
-import datetime
-import json
+import importlib
+import inspect
 import math
-import pathlib
-import shutil
-import socket
-import sys
 import textwrap
-import importlib
 
 import hist
 import matplotlib as mpl
@@ -22,7 +17,7 @@ from matplotlib.patches import Polygon
 from matplotlib.ticker import StrMethodFormatter
 
 from wums import boostHistHelpers as hh
-from wums import ioutils, logging
+from wums import logging
 
 hep.style.use(hep.style.ROOT)
 
@@ -466,6 +461,7 @@ def addLegend(
     reverse=True,
     labelcolor=None,
     padding_loc="auto",
+    title=None,
 ):
     handles, labels = ax.get_legend_handles_labels()
     if extra_entries_first:
@@ -496,6 +492,7 @@ def addLegend(
     text_size = get_textsize(ax, text_size)
     handler_map = get_custom_handler_map(custom_handlers)
     leg = ax.legend(
+        title=title,
         handles=handles,
         labels=labels,
         prop={"size": text_size},
@@ -507,6 +504,8 @@ def addLegend(
         markerfirst=markerfirst,
         labelcolor=labelcolor,
     )
+    if title is not None:
+        leg.set_title(title, prop={"size": text_size})
 
     if extra_text is not None:
         if extra_text_loc is None:
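
A standalone matplotlib illustration of why the legend title is set again above (labels and sizes are arbitrary): the prop passed to ax.legend() only sizes the entry labels, so the title font has to be set explicitly via Legend.set_title with its own prop.

    import matplotlib.pyplot as plt

    fig, ax = plt.subplots()
    ax.plot([0, 1], [0, 1], label="signal")
    ax.plot([0, 1], [1, 0], label="background")

    leg = ax.legend(title="Processes", prop={"size": 14})
    leg.set_title("Processes", prop={"size": 14})  # sizes the title itself
    fig.savefig("legend_title.png")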
@@ -609,18 +608,18 @@ def add_decor(
     ax, title, label=None, lumi=None, loc=2, data=True, text_size=None, no_energy=False
 ):
     text_size = get_textsize(ax, text_size)
-
     if title in ["CMS", "ATLAS", "LHCb", "ALICE"]:
         module = getattr(hep, title.lower())
         make_text = module.text
         make_label = module.label
     else:
+
         def make_text(text=None, **kwargs):
             for key, value in dict(hep.rcParams.text._get_kwargs()).items():
                 if (
                     value is not None
                     and key not in kwargs
-                    and key in inspect.getfullargspec(label_base.exp_text).kwonlyargs
+                    and key in inspect.getfullargspec(hep.label.exp_text).kwonlyargs
                 ):
                     kwargs.setdefault(key, value)
             kwargs.setdefault("italic", (False, True, False))
@@ -632,7 +631,7 @@ def add_decor(
                 if (
                     value is not None
                     and key not in kwargs
-                    and key in inspect.getfullargspec(label_base.exp_text).kwonlyargs
+                    and key in inspect.getfullargspec(hep.label.exp_text).kwonlyargs
                 ):
                     kwargs.setdefault(key, value)
             kwargs.setdefault("italic", (False, True, False))
@@ -651,7 +650,7 @@ def add_decor(
         data=data,
         loc=loc,
     )
-
+ 
    # else:
    #     if loc==0:
    #         # above frame
@@ -666,7 +665,7 @@ def add_decor(
    #         x = 0.05
    #         y = 0.88
    #     elif loc==2:
-   #
+   # 
    #         ax.text(
    #             x,
    #             y,
@@ -678,12 +677,12 @@ def add_decor(
    #     if label is not None:
    #         ax.text(0.05, 0.80, label, transform=ax.transAxes, fontstyle="italic")
 
+
 def makeStackPlotWithRatio(
     histInfo,
     stackedProcs,
     histName="nominal",
     unstacked=None,
-    prefit=False,
     xlabel="",
     ylabel=None,
     rlabel="Data/Pred.",
@@ -726,12 +725,14 @@ def makeStackPlotWithRatio(
     alpha=0.7,
     legPos="upper right",
     leg_padding="auto",
+    lowerLeg=True,
     lowerLegCols=2,
     lowerLegPos="upper right",
     lower_panel_variations=0,
     lower_leg_padding="auto",
     scaleRatioUnstacked=[],
     subplotsizes=[4, 2],
+    x_vertLines_edges=[],
 ):
     add_ratio = not (no_stack or no_ratio)
     if ylabel is None:
@@ -799,9 +800,8 @@ def makeStackPlotWithRatio(
 
     opts = dict(stack=not no_stack, flow=flow)
     optsr = opts.copy() # no binwnorm for ratio axis
-    optsr["density"] = density
     if density:
-        opts["density"] = True
+        opts["density"] = True
     else:
         opts["binwnorm"] = binwnorm
 
@@ -822,13 +822,14 @@ def makeStackPlotWithRatio(
             for x in (data_hist.sum(), hh.sumHists(stack).sum())
         ]
         scale = vals[0] / vals[1]
-        unc = scale * (varis[0] / vals[0] ** 2 + varis[1] / vals[1] ** 2)**0.5
+        unc = scale * (varis[0] / vals[0] ** 2 + varis[1] / vals[1] ** 2) ** 0.5
         ndigits = -math.floor(math.log10(abs(unc))) + 1
         logger.info(
             f"Rescaling all processes by {round(scale,ndigits)} +/- {round(unc,ndigits)} to match data norm"
         )
         stack = [s * scale for s in stack]
 
+
     hep.histplot(
         stack,
         histtype="fill" if not no_fill else "step",
@@ -927,8 +928,7 @@ def makeStackPlotWithRatio(
 
     for i, (proc, style) in enumerate(zip(unstacked, linestyles)):
         unstack = histInfo[proc].hists[histName]
-        if proc not in to_read:
-            unstack = action(unstack)[select]
+        unstack = action(unstack)[select]
         if proc != "Data":
             unstack = unstack * scale
         if len(scaleRatioUnstacked) > i:
@@ -975,6 +975,11 @@ def makeStackPlotWithRatio(
             extra_labels.append(histInfo[proc].label)
         if ratio_to_data and proc == "Data" or not add_ratio:
             continue
+        if xlim:
+            unstack = unstack[complex(0, xlim[0]) : complex(0, xlim[1])]
+        if density:
+            unstack = hh.scaleHist(unstack, np.sum(ratio_ref.values())/np.sum(unstack.values()))
+
         stack_ratio = hh.divideHists(
             unstack,
             ratio_ref,
@@ -1001,6 +1006,14 @@ def makeStackPlotWithRatio(
         **optsr,
     )
 
+    if len(x_vertLines_edges):
+        h_inclusive = hh.sumHists(stack)
+        max_y = 1.05 * np.max(h_inclusive.values() + h_inclusive.variances() ** 0.5)
+        min_y = np.min(h_inclusive.values() - h_inclusive.variances() ** 0.5)
+        for x in x_vertLines_edges:
+            ax1.plot([x, x], [min_y, max_y], linestyle="--", color="black")
+            ax2.plot([x, x], [rrange[0], rrange[1]], linestyle="--", color="black")
+
     addLegend(
         ax1,
         nlegcols,
@@ -1010,7 +1023,7 @@ def makeStackPlotWithRatio(
         text_size=legtext_size,
         padding_loc=leg_padding,
     )
-    if add_ratio:
+    if add_ratio and lowerLeg:
         addLegend(
             ax2,
             lowerLegCols,
@@ -1038,7 +1051,7 @@ def makeStackPlotWithRatio(
 def makePlotWithRatioToRef(
     hists,
     labels,
-    colors,
+    colors=None,
     hists_ratio=None,
     midratio_idxs=None,
     linestyles=[],
@@ -1074,6 +1087,7 @@ def makePlotWithRatioToRef(
     cms_label=None,
     cutoff=1e-6,
     only_ratio=False,
+    ratio_legend=True,
     width_scale=1,
     automatic_scale=True,
     base_size=8,
@@ -1095,9 +1109,19 @@ def makePlotWithRatioToRef(
     elif select is not None:
         hists_ratio = [h[select] for h in hists_ratio]
 
-    if len(hists_ratio) != len(labels) or len(hists_ratio) != len(colors):
+    if colors is None:
+        colors = plt.rcParams["axes.prop_cycle"].by_key()["color"][: len(hists)]
+    if len(colors) < len(hists):
+        colors = (
+            colors
+            + plt.rcParams["axes.prop_cycle"].by_key()["color"][
+                : (len(hists) - len(colors))
+            ]
+        )
+
+    if len(hists_ratio) != len(labels):
         raise ValueError(
-            f"Number of hists ({len(hists_ratio)}), colors ({len(colors)}), and labels ({len(labels)}) must agree!"
+            f"Number of hists ({len(hists_ratio)}) and labels ({len(labels)}) must agree!"
         )
     ratio_hists = [
         hh.divideHists(
@@ -1232,7 +1256,7 @@ def makePlotWithRatioToRef(
         fill_between=fill_between,
         dataIdx=dataIdx,
         baseline=baseline,
-        add_legend=not only_ratio,
+        add_legend=ratio_legend and not only_ratio,
     )
     if midratio_hists:
         plotRatio(
@@ -1513,7 +1537,7 @@ def fix_axes(
     if noSci and not logy:
         redo_axis_ticks(ax1, "y")
     elif not logy:
-        ax1.ticklabel_format(style="sci", useMathText=True, axis="y", scilimits=(0, 0))
+        ax1.ticklabel_format(style="sci", useMathText=True, axis="y", scilimits=(-2, 2))
 
     if ratio_axes is not None:
         if not isinstance(ratio_axes, (list, tuple, np.ndarray)):
@@ -1556,6 +1580,8 @@ def redo_axis_ticks(ax, axlabel, no_labels=False):
     fixedloc = ticker.FixedLocator(
         autoloc.tick_values(*getattr(ax, f"get_{axlabel}lim")())
     )
+    if ax.get_xscale() == 'log':
+        fixedloc = ticker.LogLocator(base=10, numticks=5)
     getattr(ax, f"{axlabel}axis").set_major_locator(fixedloc)
     ticks = getattr(ax, f"get_{axlabel}ticks")()
     labels = [format_axis_num(x, ticks[-1]) for x in ticks] if not no_labels else []
@@ -1798,19 +1824,21 @@ def read_axis_label(x, labels, with_unit=True):
     return x
 
 
-def get_axis_label(config, default_keys=None, label=None, is_bin=False):
+def get_axis_label(config, default_keys=None, label=None, is_bin=False, with_unit=True):
     if label is not None:
         return label
 
     if default_keys is None:
         return "Bin index"
+    elif isinstance(default_keys, str):
+        default_keys = [default_keys]
 
     labels = getattr(config, "axis_labels", {})
 
-    if len(default_keys) == 1:
+    if len(default_keys) == 1:
         if is_bin:
             return f"{read_axis_label(default_keys[0], labels, False)} bin"
         else:
-            return read_axis_label(default_keys[0], labels)
+            return read_axis_label(default_keys[0], labels, with_unit)
     else:
         return f"({', '.join([read_axis_label(a, labels, False) for a in default_keys])}) bin"
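
A standalone matplotlib illustration of the log-axis branch added to redo_axis_ticks above (the data is arbitrary): on a logarithmic axis the ticks from a linear FixedLocator are replaced by a LogLocator limited to a handful of decade ticks.

    import matplotlib.pyplot as plt
    import numpy as np
    from matplotlib import ticker

    fig, ax = plt.subplots()
    x = np.logspace(0, 4, 100)
    ax.plot(x, np.sqrt(x))
    ax.set_xscale("log")

    ax.xaxis.set_major_locator(ticker.LogLocator(base=10, numticks=5))
    fig.savefig("log_ticks.png")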
@@ -1,6 +1,6 @@
-Metadata-Version: 2.2
+Metadata-Version: 2.4
 Name: wums
-Version: 0.1.8
+Version: 0.1.9
 Summary: .
 Author-email: David Walter <david.walter@cern.ch>, Josh Bendavid <josh.bendavid@cern.ch>, Kenneth Long <kenneth.long@cern.ch>, Jan Eysermans <jan.eysermans@cern.ch>
 License: MIT
@@ -0,0 +1,14 @@
+wums/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+wums/boostHistHelpers.py,sha256=KMZXVPLlx7_v0XgGEQg6mZmuzEOQ6AWO5Wd_mPaAPjg,41636
+wums/fitutils.py,sha256=sPCMJqZGdXvDfc8OxjOB-Bpf45GWHKxmKkDV3SlMUQs,38297
+wums/fitutilsjax.py,sha256=HE1AcIZmI6N_xIHo8OHCPaYkHSnND_B-vI4Gl3vaUmA,2659
+wums/ioutils.py,sha256=EyCOBin7ifStLtdgKl7J1_0VWB6RnXWtIWFge9x73Ow,12465
+wums/logging.py,sha256=L4514Xyq7L1z77Tkh8KE2HX88ZZ06o6SSRyQo96DbC0,4494
+wums/output_tools.py,sha256=89rQPOWpwGuzJK5ZQBDv38rmO9th1D2206QOe9PE-gY,7572
+wums/plot_tools.py,sha256=0olJuXAnuDDXJrbF9GF_ZaWBJ14ljOzF5ww19VGCG_g,54662
+wums/tfutils.py,sha256=9efkkvxH7VtwJN2yBS6_-P9dLKs3CXdxMFdrEBNsna8,2892
+wums/Templates/index.php,sha256=9EYmfc0ltMqr5oOdA4_BVIHdSbef5aA0ORoRZBEADVw,4348
+wums-0.1.9.dist-info/METADATA,sha256=TgZpEIJTjE_xZ3J_cbrXaD2EClyo3iEKS6MHl5lanLU,1784
+wums-0.1.9.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+wums-0.1.9.dist-info/top_level.txt,sha256=DCE1TVg7ySraosR3kYZkLIZ2w1Pwk2pVTdkqx6E-yRY,5
+wums-0.1.9.dist-info/RECORD,,
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: setuptools (75.8.2)
+Generator: setuptools (80.9.0)
 Root-Is-Purelib: true
 Tag: py3-none-any
 
@@ -0,0 +1 @@
+wums
@@ -1,90 +0,0 @@
-import wums.fitutils
-
-import tensorflow as tf
-
-import matplotlib.pyplot as plt
-
-import numpy as np
-import hist
-import math
-
-np.random.seed(1234)
-
-nevt = 100000
-
-rgaus = np.random.normal(size=(nevt,))
-
-print(rgaus.dtype)
-print(rgaus)
-
-axis0 = hist.axis.Regular(100, -5., 5.)
-
-htest = hist.Hist(axis0)
-htest.fill(rgaus)
-
-print(htest)
-
-
-quant_cdfvals = tf.constant([0.0, 1e-3, 0.02, 0.05, 0.10, 0.20, 0.30, 0.40, 0.50, 0.60, 0.70, 0.80, 0.90, 0.95, 0.98, 1.0-1e-3, 1.0], tf.float64)
-
-nquants = quant_cdfvals.shape.num_elements()
-
-def func_transform_cdf(quantile):
-    const_sqrt2 = tf.constant(math.sqrt(2.), quantile.dtype)
-    return 0.5*(1. + tf.math.erf(quantile/const_sqrt2))
-
-def func_transform_quantile(cdf):
-    const_sqrt2 = tf.constant(math.sqrt(2.), cdf.dtype)
-    return const_sqrt2*tf.math.erfinv(2*cdf - 1.)
-
-
-
-def func_cdf(xvals, xedges, parms, quant_cdfvals):
-    qparms = parms
-
-    cdf = narf.fitutils.func_cdf_for_quantile_fit(xvals, xedges, qparms, quant_cdfvals, transform = (func_transform_cdf, func_transform_quantile))
-
-    return cdf
-
-
-#this is just for plotting
-def func_pdf(h, parms):
-    dtype = tf.float64
-    xvals = [tf.constant(center, dtype=dtype) for center in h.axes.centers]
-    xedges = [tf.constant(edge, dtype=dtype) for edge in h.axes.edges]
-
-    tfparms = tf.constant(parms)
-
-    cdf = func_cdf(xvals, xedges, tfparms, quant_cdfvals)
-
-    pdf = cdf[1:] - cdf[:-1]
-    pdf = tf.maximum(pdf, tf.zeros_like(pdf))
-
-    return pdf
-
-nparms = nquants-1
-
-
-initial_parms = np.array([np.log(1./nparms)]*nparms)
-
-res = narf.fitutils.fit_hist(htest, func_cdf, initial_parms, mode="nll_bin_integrated", func_constraint=narf.fitutils.func_constraint_for_quantile_fit, args = (quant_cdfvals,))
-
-print(res)
-
-
-parmvals = res["x"]
-
-
-pdfvals = func_pdf(htest, parmvals)
-pdfvals *= htest.sum()/np.sum(pdfvals)
-
-#
-plot = plt.figure()
-plt.yscale("log")
-htest.plot()
-plt.plot(htest.axes[0].centers, pdfvals)
-# plt.show()
-plot.savefig("test.png")
-
-
-
@@ -1,323 +0,0 @@
-import wums.fitutils
-
-import tensorflow as tf
-
-import matplotlib.pyplot as plt
-
-import numpy as np
-import hist
-import math
-
-import onnx
-import tf2onnx
-
-np.random.seed(1234)
-
-nevt = 20000
-
-runiform = np.random.random((nevt,))
-rgaus = np.random.normal(size=(nevt,))
-
-data = np.stack([runiform, rgaus], axis=-1)
-
-# "pt"-dependent mean and sigma
-data[:,1] = -0.1 + 0.1*data[:,0] + (1. + 0.2*data[:,0])*data[:,1]
-
-
-# print(rgaus.dtype)
-# print(rgaus)
-
-axis0 = hist.axis.Regular(50, 0., 1., name="pt")
-axis1 = hist.axis.Regular(100, -5., 5., name="recoil")
-
-htest_data = hist.Hist(axis0, axis1)
-htest_mc = hist.Hist(axis0, axis1)
-
-# print("data.shape", data.shape)
-htest_data.fill(data[:nevt//2,0], data[:nevt//2, 1])
-htest_mc.fill(data[nevt//2:,0], data[nevt//2:, 1])
-
-
-
-
-quant_cdfvals = tf.constant([0.0, 1e-3, 0.02, 0.05, 0.10, 0.20, 0.30, 0.40, 0.50, 0.60, 0.70, 0.80, 0.90, 0.95, 0.98, 1.0-1e-3, 1.0], dtype = tf.float64)
-nquants = quant_cdfvals.shape.num_elements()
-
-print("nquants", nquants)
-
-#cdf is in terms of axis1, so shapes need to be compatible
-quant_cdfvals = quant_cdfvals[None, :]
-
-
-# get quantiles from histogram, e.g. to help initialize the parameters for the fit (not actually used here)
-
-# hist_quantiles, hist_quantile_errs = wums.fitutils.hist_to_quantiles(htest, quant_cdfvals, axis=1)
-#
-# print(hist_quantiles)
-# print(hist_quantile_errs)
-#
-# hist_qparms, hist_qparm_errs = wums.fitutils.quantiles_to_qparms(hist_quantiles, hist_quantile_errs)
-#
-# print(hist_qparms)
-# print(hist_qparm_errs)
-
-def parms_to_qparms(xvals, parms):
-
-    parms_2d = tf.reshape(parms, (-1, 2))
-    parms_const = parms_2d[:,0]
-    parms_slope = parms_2d[:,1]
-
-    #cdf is in terms of axis1, so shapes need to be compatible
-    parms_const = parms_const[None, :]
-    parms_slope = parms_slope[None, :]
-
-    qparms = parms_const + parms_slope*xvals[0]
-
-    return qparms
-
-
-def func_transform_cdf(quantile):
-    const_sqrt2 = tf.constant(math.sqrt(2.), quantile.dtype)
-    return 0.5*(1. + tf.math.erf(quantile/const_sqrt2))
-
-def func_transform_quantile(cdf):
-    const_sqrt2 = tf.constant(math.sqrt(2.), cdf.dtype)
-    return const_sqrt2*tf.math.erfinv(2.*cdf - 1.)
-
-# def func_transform_cdf(quantile):
-#     return tf.math.log(quantile/(1.-quantile))
-#
-# def func_transform_quantile(cdf):
-#     return tf.math.sigmoid(cdf)
-
-def func_cdf(xvals, xedges, parms):
-    qparms = parms_to_qparms(xvals, parms)
-    # return wums.fitutils.func_cdf_for_quantile_fit(xvals, xedges, qparms, quant_cdfvals, axis=1)
-
-    return wums.fitutils.func_cdf_for_quantile_fit(xvals, xedges, qparms, quant_cdfvals, axis=1, transform = (func_transform_cdf, func_transform_quantile))
-
-def func_constraint(xvals, xedges, parms):
-    qparms = parms_to_qparms(xvals, parms)
-    return wums.fitutils.func_constraint_for_quantile_fit(xvals, xedges, qparms)
-
-#this is just for plotting
-def func_pdf(h, parms):
-    dtype = tf.float64
-    xvals = [tf.constant(center, dtype=dtype) for center in h.axes.centers]
-    xedges = [tf.constant(edge, dtype=dtype) for edge in h.axes.edges]
-
-    tfparms = tf.constant(parms)
-
-    cdf = func_cdf(xvals, xedges, tfparms)
-
-    pdf = cdf[:,1:] - cdf[:,:-1]
-    pdf = tf.maximum(pdf, tf.zeros_like(pdf))
-
-    return pdf
-
-nparms = nquants-1
-
-
-# print("edges", htest.edges)
-
-# assert(0)
-
-initial_parms_const = np.array([np.log(1./nparms)]*nparms)
-initial_parms_slope = np.zeros_like(initial_parms_const)
-
-initial_parms = np.stack([initial_parms_const, initial_parms_slope], axis=-1)
-initial_parms = np.reshape(initial_parms, (-1,))
-
-res_data = wums.fitutils.fit_hist(htest_data, func_cdf, initial_parms, mode="nll_bin_integrated", norm_axes=[1])
-
-res_mc = wums.fitutils.fit_hist(htest_mc, func_cdf, initial_parms, mode="nll_bin_integrated", norm_axes=[1])
-
-print(res_data)
-
-
-parmvals_data = tf.constant(res_data["x"], tf.float64)
-parmvals_mc = tf.constant(res_mc["x"], tf.float64)
-
-hess_data = res_data["hess"]
-hess_mc = res_mc["hess"]
-
-def get_scaled_eigenvectors(hess, num_null = 2):
-    e,v = np.linalg.eigh(hess)
-
-    # remove the null eigenvectors
-    e = e[None, num_null:]
-    v = v[:, num_null:]
-
-    # scale the eigenvectors
-    vscaled = v/np.sqrt(e)
-
-    return vscaled
-
-vscaled_data = tf.constant(get_scaled_eigenvectors(hess_data), tf.float64)
-vscaled_mc = tf.constant(get_scaled_eigenvectors(hess_data), tf.float64)
-
-print("vscaled_data.shape", vscaled_data.shape)
-
-ut_flat = np.reshape(htest_data.axes.edges[1], (-1,))
-ut_low = tf.constant(ut_flat[0], tf.float64)
-ut_high = tf.constant(ut_flat[-1], tf.float64)
-
-def func_cdf_mc(pt, ut):
-    pts = tf.reshape(pt, (1,1))
-    uts = tf.reshape(ut, (1,1))
-
-    xvals = [pts, None]
-    xedges = [None, uts]
-
-    parms = parmvals_mc
-
-    qparms = parms_to_qparms(xvals, parms)
-
-    ut_axis = 1
-
-    quants = wums.fitutils.qparms_to_quantiles(qparms, x_low = ut_low, x_high = ut_high, axis = ut_axis)
-    spline_edges = xedges[ut_axis]
-
-    cdfvals = wums.fitutils.pchip_interpolate(quants, quant_cdfvals, spline_edges, axis=ut_axis)
-
-    return cdfvals
-
-def func_cdfinv_data(pt, quant):
-    pts = tf.reshape(pt, (1,1))
-    quant_outs = tf.reshape(quant, (1,1))
-
-    xvals = [pts, None]
-    xedges = [None, quant_outs]
-
-    parms = parmvals_data
-
-    qparms = parms_to_qparms(xvals, parms)
-
-    ut_axis = 1
-
-    quants = wums.fitutils.qparms_to_quantiles(qparms, x_low = ut_low, x_high = ut_high, axis = ut_axis)
-    spline_edges = xedges[ut_axis]
-
-    cdfinvvals = wums.fitutils.pchip_interpolate(quant_cdfvals, quants, spline_edges, axis=ut_axis)
-
-    return cdfinvvals
-
-def func_cdfinv_pdf_data(pt, quant):
-    with tf.GradientTape() as t:
-        t.watch(quant)
-        cdfinv = func_cdfinv_data(pt, quant)
-    pdfreciprocal = t.gradient(cdfinv, quant)
-    pdf = 1./pdfreciprocal
-    return cdfinv, pdf
-
-scalar_spec = tf.TensorSpec([], tf.float64)
-
-
-def transform_mc(pt, ut):
-    with tf.GradientTape(persistent=True) as t:
-        t.watch(parmvals_mc)
-        t.watch(parmvals_data)
-
-        cdf_mc = func_cdf_mc(pt, ut)
-        ut_transformed, pdf = func_cdfinv_pdf_data(pt, cdf_mc)
-
-        ut_transformed = tf.reshape(ut_transformed, [])
-        pdf = tf.reshape(pdf, [])
-
-    pdf_grad_mc = t.gradient(pdf, parmvals_mc)
-    pdf_grad_data = t.gradient(pdf, parmvals_data)
-
-    del t
-
-    weight_grad_mc = pdf_grad_mc/pdf
-    weight_grad_data = pdf_grad_data/pdf
-
-    weight_grad_mc = weight_grad_mc[None, :]
-    weight_grad_data = weight_grad_data[None, :]
-
-    weight_grad_mc_eig = weight_grad_mc @ vscaled_mc
-    weight_grad_data_eig = weight_grad_data @ vscaled_data
-
-    weight_grad_mc_eig = tf.reshape(weight_grad_mc_eig, [-1])
-    weight_grad_data_eig = tf.reshape(weight_grad_data_eig, [-1])
-
-    weight_grad_eig = tf.concat([weight_grad_mc_eig, weight_grad_data_eig], axis=0)
-
-    return ut_transformed, weight_grad_eig
-    # return ut_transformed
-
-@tf.function
-def transform_mc_simple(pt, ut):
-    cdf_mc = func_cdf_mc(pt, ut)
-    ut_transformed, pdf = func_cdfinv_pdf_data(pt, cdf_mc)
-
-    ut_transformed = tf.reshape(ut_transformed, [])
-
-    return ut_transformed
-
-
-
-pt_test = tf.constant(0.2, tf.float64)
-ut_test = tf.constant(1.0, tf.float64)
-
-ut, grad = transform_mc(pt_test, ut_test)
-# ut = transform_mc(pt_test, ut_test)
-
-print("shapes", ut.shape, grad.shape)
-
-print("ut", ut)
-print("grad", grad)
-
-input_signature = [scalar_spec, scalar_spec]
-
-class TestMod(tf.Module):
-
-    @tf.function(input_signature = [scalar_spec, scalar_spec])
-    def __call__(self, pt, ut):
-        return transform_mc(pt, ut)
-
-module = TestMod()
-# tf.saved_model.save(module, "test")
-
-concrete_function = module.__call__.get_concrete_function()
-
-# Convert the model
-converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_function], module)
-
-# converter = tf.lite.TFLiteConverter.from_saved_model("test") # path to the SavedModel directory
-converter.target_spec.supported_ops = [
-    tf.lite.OpsSet.TFLITE_BUILTINS, # enable TensorFlow Lite ops.
-    tf.lite.OpsSet.SELECT_TF_OPS # enable TensorFlow ops.
-]
-
-tflite_model = converter.convert()
-
-# print(tflite_model)
-
-# Save the model.
-with open('model.tflite', 'wb') as f:
-    f.write(tflite_model)
-
-
-# onnx_model, _ = tf2onnx.convert.from_function(transform_mc, input_signature)
-# onnx.save(onnx_model, "test.onnx")
-
-
-parmvals = res_data["x"]
-
-
-pdfvals = func_pdf(htest_data, parmvals)
-pdfvals *= htest_data.sum()/np.sum(pdfvals)
-
-
-# hplot = htest[5]
-
-plot = plt.figure()
-plt.yscale("log")
-htest_data[5,:].plot()
-plt.plot(htest_data.axes[1].centers, pdfvals[5])
-# plt.show()
-plot.savefig("test.png")
-
-
-
@@ -1,16 +0,0 @@
-scripts/test/testsplinepdf.py,sha256=sXnmDjEXiO0OIHAXLXU4UxTD4_nLwUpoojCecfjyT04,1964
-scripts/test/testsplinepdf2d.py,sha256=vGw9mq67f6aoymefLqv6CqF8teluva4Lx6tpbnC_NGU,8513
-wums/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-wums/boostHistHelpers.py,sha256=mgdPXAgmxriqoOhrhMctyZcfwEOPfV07V27CvGt2sk8,39260
-wums/fitutils.py,sha256=sPCMJqZGdXvDfc8OxjOB-Bpf45GWHKxmKkDV3SlMUQs,38297
-wums/fitutilsjax.py,sha256=HE1AcIZmI6N_xIHo8OHCPaYkHSnND_B-vI4Gl3vaUmA,2659
-wums/ioutils.py,sha256=ziyfQQ8CB3Ir2BJKJU3_a7YMF-Jd2nGXKoMQoJ2T8fo,12334
-wums/logging.py,sha256=L4514Xyq7L1z77Tkh8KE2HX88ZZ06o6SSRyQo96DbC0,4494
-wums/output_tools.py,sha256=SHcZqXAdqL9AkA57UF0b-R-U4u7rzDgL8Def4E-ulW0,6713
-wums/plot_tools.py,sha256=7GBQAO--wuP8aatkjy-ir1lQWpNrzMc1lSI6zSq3JXE,53502
-wums/tfutils.py,sha256=9efkkvxH7VtwJN2yBS6_-P9dLKs3CXdxMFdrEBNsna8,2892
-wums/Templates/index.php,sha256=9EYmfc0ltMqr5oOdA4_BVIHdSbef5aA0ORoRZBEADVw,4348
-wums-0.1.8.dist-info/METADATA,sha256=87fET64UzNDs6swv1-tcJWcuzVE5S3kEuLWOfy1JN6c,1784
-wums-0.1.8.dist-info/WHEEL,sha256=jB7zZ3N9hIM9adW7qlTAyycLYW9npaWKLRzaoVcLKcM,91
-wums-0.1.8.dist-info/top_level.txt,sha256=cGGeFZQ8IwVw-BhgxMCTu5zfkgQelfF1wEFFWGhycds,13
-wums-0.1.8.dist-info/RECORD,,
@@ -1,2 +0,0 @@
-scripts
-wums