wums 0.1.7__tar.gz → 0.1.9__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.2
1
+ Metadata-Version: 2.4
2
2
  Name: wums
3
- Version: 0.1.7
3
+ Version: 0.1.9
4
4
  Summary: .
5
5
  Author-email: David Walter <david.walter@cern.ch>, Josh Bendavid <josh.bendavid@cern.ch>, Kenneth Long <kenneth.long@cern.ch>, Jan Eysermans <jan.eysermans@cern.ch>
6
6
  License: MIT
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "wums"
7
- version = "0.1.7"
7
+ version = "0.1.9"
8
8
  description = "."
9
9
  readme = { file = "README.md", content-type = "text/markdown" }
10
10
  license = { text = "MIT" }
@@ -38,4 +38,4 @@ all = ["plotting", "fitting", "pickling"]
38
38
  where = ["."]
39
39
 
40
40
  [tool.setuptools.package-data]
41
- "wums" = ["Templates/index.php"]
41
+ "wums" = ["Templates/index.php"]
@@ -4,6 +4,7 @@ from functools import reduce
4
4
 
5
5
  import hist
6
6
  import numpy as np
7
+ from scipy.interpolate import make_smoothing_spline
7
8
 
8
9
  from wums import logging
9
10
 
@@ -60,7 +61,7 @@ def broadcastSystHist(h1, h2, flow=True, by_ax_name=True):
60
61
  # move back to original order
61
62
  new_vals = np.moveaxis(new_vals, np.arange(len(moves)), list(moves.keys()))
62
63
 
63
- if new_vals.shape != h2.values(flow=flow).shape:
64
+ if new_vals.shape != s2:
64
65
  raise ValueError(
65
66
  f"Broadcast shape {new_vals.shape} (from h1.shape={h1.values(flow=flow).shape}, axes={h1.axes.name}) "
66
67
  f"does not match desired shape {h2.view(flow=flow).shape} (axes={h2.axes.name})"
@@ -193,10 +194,10 @@ def multiplyWithVariance(vals1, vals2, vars1=None, vars2=None):
193
194
  return outvals, outvars
194
195
 
195
196
 
196
- def multiplyHists(h1, h2, allowBroadcast=True, createNew=True, flow=True):
197
+ def multiplyHists(h1, h2, allowBroadcast=True, createNew=True, flow=True, broadcast_by_ax_name=True):
197
198
  if allowBroadcast:
198
- h1 = broadcastSystHist(h1, h2, flow=flow)
199
- h2 = broadcastSystHist(h2, h1, flow=flow)
199
+ h1 = broadcastSystHist(h1, h2, flow=flow, by_ax_name=broadcast_by_ax_name)
200
+ h2 = broadcastSystHist(h2, h1, flow=flow, by_ax_name=broadcast_by_ax_name)
200
201
 
201
202
  if (
202
203
  h1.storage_type == hist.storage.Double
@@ -233,6 +234,7 @@ def concatenateHists(h1, h2, allowBroadcast=True, by_ax_name=True, flow=False):
233
234
  h2 = broadcastSystHist(h2, h1, flow=flow, by_ax_name=by_ax_name)
234
235
 
235
236
  axes = []
237
+
236
238
  for ax1, ax2 in zip(h1.axes, h2.axes):
237
239
  if ax1 == ax2:
238
240
  axes.append(ax1)
@@ -263,7 +265,7 @@ def concatenateHists(h1, h2, allowBroadcast=True, by_ax_name=True, flow=False):
263
265
  )
264
266
  else:
265
267
  raise ValueError(
266
- f"Cannot concatenate hists with inconsistent axes: {ax1.name} and {ax2.name}"
268
+ f"Cannot concatenate hists with inconsistent axes: {ax1.name}: ({ax1.edges}) and {ax2.name}: ({ax2.edges})"
267
269
  )
268
270
 
269
271
  newh = hist.Hist(*axes, storage=h1.storage_type())
@@ -407,6 +409,10 @@ def normalize(h, scale=1e6, createNew=True, flow=True):
407
409
  return scaleHist(h, scale, createNew, flow)
408
410
 
409
411
 
412
def renameAxis(h, axis_name, new_name):
    """Rename the axis ``axis_name`` of histogram ``h`` to ``new_name`` in place.

    The histogram is modified directly; nothing is returned.
    """
    target_axis = h.axes[axis_name]
    # Write through the axis' __dict__ instead of assigning ``.name``;
    # presumably ``name`` is a read-only property on hist axes — TODO confirm.
    target_axis.__dict__["name"] = new_name
414
+
415
+
410
416
  def makeAbsHist(h, axis_name, rename=True):
411
417
  ax = h.axes[axis_name]
412
418
  axidx = list(h.axes).index(ax)
@@ -455,8 +461,8 @@ def compatibleBins(edges1, edges2):
455
461
 
456
462
  def rebinHistMultiAx(h, axes, edges=[], lows=[], highs=[]):
457
463
  # edges: lists of new edges or integers to merge bins, in case new edges are given the lows and highs will be ignored
458
- # lows: list of new lower boundaries
459
- # highs: list of new upper boundaries
464
+ # lows: list of new lower boundaries or bins
465
+ # highs: list of new upper boundaries or bins
460
466
 
461
467
  sel = {}
462
468
  for ax, low, high, rebin in itertools.zip_longest(axes, lows, highs, edges):
@@ -467,14 +473,21 @@ def rebinHistMultiAx(h, axes, edges=[], lows=[], highs=[]):
467
473
  h = rebinHist(h, ax, rebin)
468
474
  elif low is not None and high is not None:
469
475
  # in case high edge is upper edge of last bin we need to manually set the upper limit
470
- upper = hist.overflow if high == h.axes[ax].edges[-1] else complex(0, high)
471
- logger.info(f"Restricting the axis '{ax}' to range [{low}, {high}]")
476
+ # distinguish case with pure real or imaginary number (to select index or value, respectively)
477
+ # in the former case, force casting into integer
478
+ if isinstance(high, int):
479
+ upper = hist.overflow if high == h.axes[ax].size else high
480
+ elif isinstance(high, complex):
481
+ high_imag = high.imag
482
+ upper = hist.overflow if high_imag == h.axes[ax].edges[-1] else high
483
+ logger.info(f"Slicing the axis '{ax}' to [{low}, {upper}]")
472
484
  sel[ax] = slice(
473
- complex(0, low), upper, hist.rebin(rebin) if rebin else None
485
+ low, upper, hist.rebin(rebin) if rebin else None
474
486
  )
475
487
  elif type(rebin) == int and rebin > 1:
476
488
  logger.info(f"Rebinning the axis '{ax}' by [{rebin}]")
477
489
  sel[ax] = slice(None, None, hist.rebin(rebin))
490
+
478
491
  return h[sel] if len(sel) > 0 else h
479
492
 
480
493
 
@@ -499,9 +512,9 @@ def mirrorAxes(h, axes, flow=True):
499
512
  return h
500
513
 
501
514
 
502
- def disableAxisFlow(ax):
515
+ def disableAxisFlow(ax, under=False, over=False):
503
516
  if isinstance(ax, hist.axis.Integer):
504
- args = [ax.edges[0], ax.edges[-1]]
517
+ args = [int(ax.edges[0]), int(ax.edges[-1])]
505
518
  elif isinstance(ax, hist.axis.Regular):
506
519
  args = [ax.size, ax.edges[0], ax.edges[-1]]
507
520
  else:
@@ -510,17 +523,24 @@ def disableAxisFlow(ax):
510
523
  return type(ax)(
511
524
  *args,
512
525
  name=ax.name,
513
- overflow=False,
514
- underflow=False,
526
+ overflow=over,
527
+ underflow=under,
515
528
  circular=ax.traits.circular,
516
529
  )
517
530
 
518
531
 
519
- def disableFlow(h, axis_name):
520
- # disable the overflow and underflow bins of a single axes, while keeping the flow bins of other axes
532
+ def disableFlow(h, axis_name, under=False, over=False):
533
+ # axes_name can be either string or a list of strings with the axis name(s) to disable the flow
534
+ if not isinstance(axis_name, str):
535
+ for var in axis_name:
536
+ if var in h.axes.name:
537
+ h = disableFlow(h, var)
538
+ return h
539
+
540
+ # disable the overflow and underflow bins of a single axis, while keeping the flow bins of other axes
521
541
  ax = h.axes[axis_name]
522
542
  ax_idx = [a.name for a in h.axes].index(axis_name)
523
- new_ax = disableAxisFlow(ax)
543
+ new_ax = disableAxisFlow(ax, under=under, over=over)
524
544
  axes = list(h.axes)
525
545
  axes[ax_idx] = new_ax
526
546
  hnew = hist.Hist(*axes, name=h.name, storage=h.storage_type())
@@ -528,7 +548,7 @@ def disableFlow(h, axis_name):
528
548
  (
529
549
  slice(None)
530
550
  if i != ax_idx
531
- else slice(ax.traits.underflow, new_ax.size + ax.traits.underflow)
551
+ else slice(ax.traits.underflow * (not under), ax.size + ax.traits.underflow + ax.traits.overflow * over)
532
552
  )
533
553
  for i in range(len(axes))
534
554
  ]
@@ -773,9 +793,9 @@ def unrolledHist(h, obs=None, binwnorm=None, add_flow_bins=False):
773
793
 
774
794
  if binwnorm:
775
795
  edges = (
776
- plot_tools.extendEdgesByFlow(hproj) if add_flow_bins else hproj.axes.edges
796
+ extendEdgesByFlow(hproj) if add_flow_bins else hproj.axes.edges
777
797
  )
778
- binwidths = np.outer(*[np.diff(e.squeeze()) for e in edges]).flatten()
798
+ binwidths = np.array(list(itertools.product(*[np.diff(e.squeeze()) for e in edges]))).prod(axis=1)
779
799
  scale = binwnorm / binwidths
780
800
  else:
781
801
  scale = 1
@@ -1133,3 +1153,42 @@ def rssHistsMid(h, syst_axis, scale=1.0):
1133
1153
  hDown = addHists(hnom, hrss[{"downUpVar": -1j}], scale2=-1.0)
1134
1154
 
1135
1155
  return hUp, hDown
1156
+
1157
def smooth_hist(h, smooth_ax_name, exclude_axes=None, start_bin=0, end_bin=None):
    """Smooth the bin values of ``h`` along one axis with a smoothing spline.

    ``h`` is projected onto ``smooth_ax_name`` plus every axis not listed in
    ``exclude_axes``, summing over the excluded ones. For every bin of the
    remaining (non-smoothing) axes, the bin-width-normalized values along the
    smoothing axis in the bin range ``[start_bin, end_bin)`` are fitted with
    ``scipy.interpolate.make_smoothing_spline``. They are then replaced by the
    spline evaluated at the bin centers. Values outside that range are left as
    they are.

    Parameters
    ----------
    h : hist.Hist
        Input histogram; it is not modified.
    smooth_ax_name : str
        Name of the axis along which to smooth.
    exclude_axes : str or iterable of str, optional
        Axes that are not smoothed individually. The smoothing is done on the
        projection summed over them, and the per-bin ratio smoothed/unsmoothed
        is then applied to every slice of ``h`` along those axes. A single
        axis name may be given as a plain string.
    start_bin : int
        First bin index (along the smoothing axis, flow bins not counted) that
        is smoothed.
    end_bin : int, optional
        One past the last smoothed bin index; ``None`` (or ``0``) means up to
        and including the last bin.

    Returns
    -------
    hist.Hist
        A new histogram with the axes of ``h`` in their original order. Only
        bin values are changed; variances are not touched by this function.
    """
    if exclude_axes is None:
        exclude_axes = []
    elif isinstance(exclude_axes, str):
        # a single axis name (previously it was unpacked character by character)
        exclude_axes = [exclude_axes]
    else:
        exclude_axes = list(exclude_axes)

    smooth_ax = h.axes[smooth_ax_name]
    # put the smoothing axis first so the reshape below maps it onto dimension 0
    other_axes = [
        ax for ax in h.axes.name if ax != smooth_ax_name and ax not in exclude_axes
    ]
    hproj = h.project(smooth_ax_name, *other_axes)
    smoothh = hproj.copy()

    # Flatten all non-smoothing axes into one dimension: (n_smooth_bins, n_other)
    vals = smoothh.values().reshape(smooth_ax.size, -1)
    if not end_bin:
        end_bin = vals.shape[0]

    # Smooth the density rather than the raw content so unequal bin widths
    # do not distort the spline
    binw = np.diff(smoothh.axes[smooth_ax_name].edges)
    vals = vals / binw[:, np.newaxis]

    # the abscissa is the same for every column, so slice it only once
    centers = smooth_ax.centers[start_bin:end_bin]
    for b in range(vals.shape[1]):
        spl = make_smoothing_spline(centers, vals[start_bin:end_bin, b])
        vals[start_bin:end_bin, b] = spl(centers)

    # Undo the bin-width normalization
    vals = vals * binw[:, np.newaxis]

    smoothh.values()[...] = vals.reshape(smoothh.shape)

    if not exclude_axes:
        # reorder back to the axis order of the input histogram
        return smoothh.project(*h.axes.name)

    # Ratio smoothed/unsmoothed, in the axis order of h minus the excluded axes
    smoothfac = divideHists(smoothh, hproj).project(
        *[ax for ax in h.axes.name if ax not in exclude_axes]
    )
    # Insert length-1 dimensions at the excluded axes so the factor broadcasts
    indices = tuple(None if ax in exclude_axes else slice(None) for ax in h.axes.name)

    hnew = h.copy()
    hnew.values()[...] = h.values() * smoothfac.values()[indices]
    return hnew
@@ -325,8 +325,12 @@ class H5Unpickler(pickle.Unpickler):
325
325
  raise pickle.UnpicklingError("unsupported persistent object")
326
326
 
327
327
 
328
- def pickle_dump_h5py(name, obj, h5out):
328
+ def pickle_dump_h5py(name, obj, h5out, override=False):
329
329
  """Write an object to a new h5py group which will be created within the provided group."""
330
+ if name in h5out and override:
331
+ # If the group already exists, delete it first
332
+ del h5out[name]
333
+
330
334
  obj_group = h5out.create_group(name)
331
335
  try:
332
336
  obj_group.attrs["narf_h5py_pickle_protocol_version"] = CURRENT_PROTOCOL_VERSION
@@ -165,6 +165,12 @@ def make_meta_info_dict(
165
165
  return meta_data
166
166
 
167
167
 
168
def encode_complex(obj):
    """``default=`` hook for :func:`json.dumps` that serializes complex numbers.

    A complex value becomes ``{"real": ..., "imag": ...}``. For any other
    object ``TypeError`` is raised, which is how a json ``default`` hook
    signals that it cannot handle its argument.
    """
    if not isinstance(obj, complex):
        raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")
    return {"real": obj.real, "imag": obj.imag}
172
+
173
+
168
174
  def write_logfile(
169
175
  outpath,
170
176
  logname,
@@ -181,7 +187,8 @@ def write_logfile(
181
187
  logf.write("\n" + "-" * 80 + "\n")
182
188
  if isinstance(v, dict):
183
189
  logf.write(k)
184
- logf.write(json.dumps(v, indent=5).replace("\\n", "\n"))
190
+ # String conversion needed for non-primitive types (e.g., complex number)
191
+ logf.write(json.dumps(v, default=encode_complex, indent=5).replace("\\n", "\n"))
185
192
  else:
186
193
  logf.write(f"{k}: {v}\n")
187
194
 
@@ -202,3 +209,19 @@ def write_index_and_log(
202
209
  ):
203
210
  write_indexfile(outpath, template_dir)
204
211
  write_logfile(outpath, logname, args, analysis_meta_info)
212
+
213
def write_lz4_pkl_output(
    outfile, outfolder, output_dict, basedir, args=None, file_meta_data=None
):
    """Pickle ``output_dict`` together with meta data into an lz4-compressed file.

    Parameters
    ----------
    outfile : str
        Output path; ``".pkl.lz4"`` is appended unless it already ends with it.
    outfolder : hashable
        Key under which ``output_dict`` is stored in the pickled dictionary.
    output_dict : object
        Payload to store.
    basedir :
        Passed as ``wd`` to ``make_meta_info_dict``.
    args : optional
        Passed to ``make_meta_info_dict``; presumably the parsed command-line
        arguments — TODO confirm.
    file_meta_data : optional
        Stored under the key ``"file_meta_data"`` when not ``None``.

    The file holds one dict ``{outfolder: output_dict, "meta_data": ...}``
    (plus ``"file_meta_data"`` if given), written with
    ``pickle.HIGHEST_PROTOCOL`` through ``lz4.frame.open``.
    """
    if not outfile.endswith(".pkl.lz4"):
        outfile = outfile + ".pkl.lz4"

    logger.info(f"Write file {outfile}")

    payload = {outfolder: output_dict}
    payload["meta_data"] = make_meta_info_dict(args, wd=basedir)
    if file_meta_data is not None:
        payload["file_meta_data"] = file_meta_data

    with lz4.frame.open(outfile, "wb") as handle:
        pickle.dump(payload, handle, protocol=pickle.HIGHEST_PROTOCOL)
@@ -1,12 +1,7 @@
1
- import datetime
2
- import json
1
+ import importlib
2
+ import inspect
3
3
  import math
4
- import pathlib
5
- import shutil
6
- import socket
7
- import sys
8
4
  import textwrap
9
- import importlib
10
5
 
11
6
  import hist
12
7
  import matplotlib as mpl
@@ -22,7 +17,7 @@ from matplotlib.patches import Polygon
22
17
  from matplotlib.ticker import StrMethodFormatter
23
18
 
24
19
  from wums import boostHistHelpers as hh
25
- from wums import ioutils, logging
20
+ from wums import logging
26
21
 
27
22
  hep.style.use(hep.style.ROOT)
28
23
 
@@ -466,6 +461,7 @@ def addLegend(
466
461
  reverse=True,
467
462
  labelcolor=None,
468
463
  padding_loc="auto",
464
+ title=None,
469
465
  ):
470
466
  handles, labels = ax.get_legend_handles_labels()
471
467
  if extra_entries_first:
@@ -496,6 +492,7 @@ def addLegend(
496
492
  text_size = get_textsize(ax, text_size)
497
493
  handler_map = get_custom_handler_map(custom_handlers)
498
494
  leg = ax.legend(
495
+ title=title,
499
496
  handles=handles,
500
497
  labels=labels,
501
498
  prop={"size": text_size},
@@ -507,6 +504,8 @@ def addLegend(
507
504
  markerfirst=markerfirst,
508
505
  labelcolor=labelcolor,
509
506
  )
507
+ if title is not None:
508
+ leg.set_title(title, prop={"size": text_size})
510
509
 
511
510
  if extra_text is not None:
512
511
  if extra_text_loc is None:
@@ -609,18 +608,18 @@ def add_decor(
609
608
  ax, title, label=None, lumi=None, loc=2, data=True, text_size=None, no_energy=False
610
609
  ):
611
610
  text_size = get_textsize(ax, text_size)
612
-
613
611
  if title in ["CMS", "ATLAS", "LHCb", "ALICE"]:
614
612
  module = getattr(hep, title.lower())
615
613
  make_text = module.text
616
614
  make_label = module.label
617
615
  else:
616
+
618
617
  def make_text(text=None, **kwargs):
619
618
  for key, value in dict(hep.rcParams.text._get_kwargs()).items():
620
619
  if (
621
620
  value is not None
622
621
  and key not in kwargs
623
- and key in inspect.getfullargspec(label_base.exp_text).kwonlyargs
622
+ and key in inspect.getfullargspec(hep.label.exp_text).kwonlyargs
624
623
  ):
625
624
  kwargs.setdefault(key, value)
626
625
  kwargs.setdefault("italic", (False, True, False))
@@ -632,7 +631,7 @@ def add_decor(
632
631
  if (
633
632
  value is not None
634
633
  and key not in kwargs
635
- and key in inspect.getfullargspec(label_base.exp_text).kwonlyargs
634
+ and key in inspect.getfullargspec(hep.label.exp_text).kwonlyargs
636
635
  ):
637
636
  kwargs.setdefault(key, value)
638
637
  kwargs.setdefault("italic", (False, True, False))
@@ -651,7 +650,7 @@ def add_decor(
651
650
  data=data,
652
651
  loc=loc,
653
652
  )
654
-
653
+
655
654
  # else:
656
655
  # if loc==0:
657
656
  # # above frame
@@ -666,7 +665,7 @@ def add_decor(
666
665
  # x = 0.05
667
666
  # y = 0.88
668
667
  # elif loc==2:
669
- # #
668
+ # #
670
669
  # ax.text(
671
670
  # x,
672
671
  # y,
@@ -678,13 +677,12 @@ def add_decor(
678
677
  # if label is not None:
679
678
  # ax.text(0.05, 0.80, label, transform=ax.transAxes, fontstyle="italic")
680
679
 
680
+
681
681
  def makeStackPlotWithRatio(
682
682
  histInfo,
683
683
  stackedProcs,
684
684
  histName="nominal",
685
685
  unstacked=None,
686
- fitresult=None,
687
- prefit=False,
688
686
  xlabel="",
689
687
  ylabel=None,
690
688
  rlabel="Data/Pred.",
@@ -727,12 +725,14 @@ def makeStackPlotWithRatio(
727
725
  alpha=0.7,
728
726
  legPos="upper right",
729
727
  leg_padding="auto",
728
+ lowerLeg=True,
730
729
  lowerLegCols=2,
731
730
  lowerLegPos="upper right",
732
731
  lower_panel_variations=0,
733
732
  lower_leg_padding="auto",
734
733
  scaleRatioUnstacked=[],
735
734
  subplotsizes=[4, 2],
735
+ x_vertLines_edges=[],
736
736
  ):
737
737
  add_ratio = not (no_stack or no_ratio)
738
738
  if ylabel is None:
@@ -757,11 +757,6 @@ def makeStackPlotWithRatio(
757
757
  if xlim:
758
758
  h = h[complex(0, xlim[0]) : complex(0, xlim[1])]
759
759
 
760
- # If plotting from combine, apply the action to the underlying hist.
761
- # Don't do this for the generic case, as it screws up the ability to make multiple plots
762
- if fitresult:
763
- histInfo[k].hists[histName] = h
764
-
765
760
  if k != "Data":
766
761
  stack.append(h)
767
762
  else:
@@ -803,72 +798,10 @@ def makeStackPlotWithRatio(
803
798
  ratio_axes = None
804
799
  ax2 = None
805
800
 
806
- if fitresult:
807
- import uproot
808
-
809
- combine_result = uproot.open(fitresult)
810
-
811
- fittype = "prefit" if prefit else "postfit"
812
-
813
- # set histograms to prefit/postfit values
814
- for p in to_read:
815
-
816
- hname = f"expproc_{p}_{fittype}" if p != "Data" else "obs"
817
- vals = combine_result[hname].to_hist().values()
818
- if len(histInfo[p].hists[histName].values()) != len(vals):
819
- raise ValueError(
820
- f"The size of the combine histogram ({(vals.shape)}) is not consistent with the xlim or input hist ({histInfo[p].hists[histName].shape})"
821
- )
822
-
823
- histInfo[p].hists[histName].values()[...] = vals
824
- if p == "Data":
825
- histInfo[p].hists[histName].variances()[...] = vals
826
-
827
- # for postfit uncertaity bands
828
- axis = histInfo[to_read[0]].hists[histName].axes[0].edges
829
-
830
- # need to divide by bin width
831
- binwidth = axis[1:] - axis[:-1]
832
- hexp = combine_result[f"expfull_{fittype}"].to_hist()
833
- if hexp.storage_type != hist.storage.Weight:
834
- raise ValueError(
835
- f"Did not find uncertainties in {fittype} hist. Make sure you run combinetf with --computeHistErrors!"
836
- )
837
- nom = hexp.values() / binwidth
838
- std = np.sqrt(hexp.variances()) / binwidth
839
-
840
- hatchstyle = "///"
841
- ax1.fill_between(
842
- axis,
843
- np.append(nom + std, (nom + std)[-1]),
844
- np.append(nom - std, (nom - std)[-1]),
845
- step="post",
846
- facecolor="none",
847
- zorder=2,
848
- hatch=hatchstyle,
849
- edgecolor="k",
850
- linewidth=0.0,
851
- label="Uncertainty",
852
- )
853
-
854
- if add_ratio:
855
- ax2.fill_between(
856
- axis,
857
- np.append((nom + std) / nom, ((nom + std) / nom)[-1]),
858
- np.append((nom - std) / nom, ((nom - std) / nom)[-1]),
859
- step="post",
860
- facecolor="none",
861
- zorder=2,
862
- hatch=hatchstyle,
863
- edgecolor="k",
864
- linewidth=0.0,
865
- )
866
-
867
801
  opts = dict(stack=not no_stack, flow=flow)
868
802
  optsr = opts.copy() # no binwnorm for ratio axis
869
- optsr["density"] = density
870
803
  if density:
871
- opts["density"] = True
804
+ opts["density"] = True
872
805
  else:
873
806
  opts["binwnorm"] = binwnorm
874
807
 
@@ -889,13 +822,14 @@ def makeStackPlotWithRatio(
889
822
  for x in (data_hist.sum(), hh.sumHists(stack).sum())
890
823
  ]
891
824
  scale = vals[0] / vals[1]
892
- unc = scale * (varis[0] / vals[0] ** 2 + varis[1] / vals[1] ** 2)**0.5
825
+ unc = scale * (varis[0] / vals[0] ** 2 + varis[1] / vals[1] ** 2) ** 0.5
893
826
  ndigits = -math.floor(math.log10(abs(unc))) + 1
894
827
  logger.info(
895
828
  f"Rescaling all processes by {round(scale,ndigits)} +/- {round(unc,ndigits)} to match data norm"
896
829
  )
897
830
  stack = [s * scale for s in stack]
898
831
 
832
+
899
833
  hep.histplot(
900
834
  stack,
901
835
  histtype="fill" if not no_fill else "step",
@@ -994,8 +928,7 @@ def makeStackPlotWithRatio(
994
928
 
995
929
  for i, (proc, style) in enumerate(zip(unstacked, linestyles)):
996
930
  unstack = histInfo[proc].hists[histName]
997
- if not fitresult or proc not in to_read:
998
- unstack = action(unstack)[select]
931
+ unstack = action(unstack)[select]
999
932
  if proc != "Data":
1000
933
  unstack = unstack * scale
1001
934
  if len(scaleRatioUnstacked) > i:
@@ -1042,6 +975,11 @@ def makeStackPlotWithRatio(
1042
975
  extra_labels.append(histInfo[proc].label)
1043
976
  if ratio_to_data and proc == "Data" or not add_ratio:
1044
977
  continue
978
+ if xlim:
979
+ unstack = unstack[complex(0, xlim[0]) : complex(0, xlim[1])]
980
+ if density:
981
+ unstack = hh.scaleHist(unstack, np.sum(ratio_ref.values())/np.sum(unstack.values()))
982
+
1045
983
  stack_ratio = hh.divideHists(
1046
984
  unstack,
1047
985
  ratio_ref,
@@ -1068,6 +1006,14 @@ def makeStackPlotWithRatio(
1068
1006
  **optsr,
1069
1007
  )
1070
1008
 
1009
+ if len(x_vertLines_edges):
1010
+ h_inclusive = hh.sumHists(stack)
1011
+ max_y = 1.05 * np.max(h_inclusive.values() + h_inclusive.variances() ** 0.5)
1012
+ min_y = np.min(h_inclusive.values() - h_inclusive.variances() ** 0.5)
1013
+ for x in x_vertLines_edges:
1014
+ ax1.plot([x, x], [min_y, max_y], linestyle="--", color="black")
1015
+ ax2.plot([x, x], [rrange[0], rrange[1]], linestyle="--", color="black")
1016
+
1071
1017
  addLegend(
1072
1018
  ax1,
1073
1019
  nlegcols,
@@ -1077,7 +1023,7 @@ def makeStackPlotWithRatio(
1077
1023
  text_size=legtext_size,
1078
1024
  padding_loc=leg_padding,
1079
1025
  )
1080
- if add_ratio:
1026
+ if add_ratio and lowerLeg:
1081
1027
  addLegend(
1082
1028
  ax2,
1083
1029
  lowerLegCols,
@@ -1105,7 +1051,7 @@ def makeStackPlotWithRatio(
1105
1051
  def makePlotWithRatioToRef(
1106
1052
  hists,
1107
1053
  labels,
1108
- colors,
1054
+ colors=None,
1109
1055
  hists_ratio=None,
1110
1056
  midratio_idxs=None,
1111
1057
  linestyles=[],
@@ -1141,6 +1087,7 @@ def makePlotWithRatioToRef(
1141
1087
  cms_label=None,
1142
1088
  cutoff=1e-6,
1143
1089
  only_ratio=False,
1090
+ ratio_legend=True,
1144
1091
  width_scale=1,
1145
1092
  automatic_scale=True,
1146
1093
  base_size=8,
@@ -1162,9 +1109,19 @@ def makePlotWithRatioToRef(
1162
1109
  elif select is not None:
1163
1110
  hists_ratio = [h[select] for h in hists_ratio]
1164
1111
 
1165
- if len(hists_ratio) != len(labels) or len(hists_ratio) != len(colors):
1112
+ if colors is None:
1113
+ colors = plt.rcParams["axes.prop_cycle"].by_key()["color"][: len(hists)]
1114
+ if len(colors) < len(hists):
1115
+ colors = (
1116
+ colors
1117
+ + plt.rcParams["axes.prop_cycle"].by_key()["color"][
1118
+ : (len(hists) - len(colors))
1119
+ ]
1120
+ )
1121
+
1122
+ if len(hists_ratio) != len(labels):
1166
1123
  raise ValueError(
1167
- f"Number of hists ({len(hists_ratio)}), colors ({len(colors)}), and labels ({len(labels)}) must agree!"
1124
+ f"Number of hists ({len(hists_ratio)}) and labels ({len(labels)}) must agree!"
1168
1125
  )
1169
1126
  ratio_hists = [
1170
1127
  hh.divideHists(
@@ -1299,7 +1256,7 @@ def makePlotWithRatioToRef(
1299
1256
  fill_between=fill_between,
1300
1257
  dataIdx=dataIdx,
1301
1258
  baseline=baseline,
1302
- add_legend=not only_ratio,
1259
+ add_legend=ratio_legend and not only_ratio,
1303
1260
  )
1304
1261
  if midratio_hists:
1305
1262
  plotRatio(
@@ -1580,7 +1537,7 @@ def fix_axes(
1580
1537
  if noSci and not logy:
1581
1538
  redo_axis_ticks(ax1, "y")
1582
1539
  elif not logy:
1583
- ax1.ticklabel_format(style="sci", useMathText=True, axis="y", scilimits=(0, 0))
1540
+ ax1.ticklabel_format(style="sci", useMathText=True, axis="y", scilimits=(-2, 2))
1584
1541
 
1585
1542
  if ratio_axes is not None:
1586
1543
  if not isinstance(ratio_axes, (list, tuple, np.ndarray)):
@@ -1623,6 +1580,8 @@ def redo_axis_ticks(ax, axlabel, no_labels=False):
1623
1580
  fixedloc = ticker.FixedLocator(
1624
1581
  autoloc.tick_values(*getattr(ax, f"get_{axlabel}lim")())
1625
1582
  )
1583
+ if ax.get_xscale() == 'log':
1584
+ fixedloc = ticker.LogLocator(base=10, numticks=5)
1626
1585
  getattr(ax, f"{axlabel}axis").set_major_locator(fixedloc)
1627
1586
  ticks = getattr(ax, f"get_{axlabel}ticks")()
1628
1587
  labels = [format_axis_num(x, ticks[-1]) for x in ticks] if not no_labels else []
@@ -1865,19 +1824,21 @@ def read_axis_label(x, labels, with_unit=True):
1865
1824
  return x
1866
1825
 
1867
1826
 
1868
- def get_axis_label(config, default_keys=None, label=None, is_bin=False):
1827
+ def get_axis_label(config, default_keys=None, label=None, is_bin=False, with_unit=True):
1869
1828
  if label is not None:
1870
1829
  return label
1871
1830
 
1872
1831
  if default_keys is None:
1873
1832
  return "Bin index"
1833
+ elif isinstance(default_keys, str):
1834
+ default_keys = [default_keys]
1874
1835
 
1875
1836
  labels = getattr(config, "axis_labels", {})
1876
1837
 
1877
- if len(default_keys) == 1:
1838
+ if len(default_keys) == 1:
1878
1839
  if is_bin:
1879
1840
  return f"{read_axis_label(default_keys[0], labels, False)} bin"
1880
1841
  else:
1881
- return read_axis_label(default_keys[0], labels)
1842
+ return read_axis_label(default_keys[0], labels, with_unit)
1882
1843
  else:
1883
1844
  return f"({', '.join([read_axis_label(a, labels, False) for a in default_keys])}) bin"
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.2
1
+ Metadata-Version: 2.4
2
2
  Name: wums
3
- Version: 0.1.7
3
+ Version: 0.1.9
4
4
  Summary: .
5
5
  Author-email: David Walter <david.walter@cern.ch>, Josh Bendavid <josh.bendavid@cern.ch>, Kenneth Long <kenneth.long@cern.ch>, Jan Eysermans <jan.eysermans@cern.ch>
6
6
  License: MIT
@@ -1,3 +1,2 @@
1
1
  dist
2
- env
3
2
  wums
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes