triggerflow 0.3.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. trigger_dataset/__init__.py +0 -0
  2. trigger_dataset/core.py +88 -0
  3. trigger_loader/__init__.py +0 -0
  4. trigger_loader/cluster_manager.py +107 -0
  5. trigger_loader/loader.py +154 -0
  6. trigger_loader/processor.py +212 -0
  7. triggerflow/__init__.py +0 -0
  8. triggerflow/cli.py +122 -0
  9. triggerflow/core.py +617 -0
  10. triggerflow/interfaces/__init__.py +0 -0
  11. triggerflow/interfaces/uGT.py +187 -0
  12. triggerflow/mlflow_wrapper.py +270 -0
  13. triggerflow/starter/.gitignore +143 -0
  14. triggerflow/starter/README.md +0 -0
  15. triggerflow/starter/cookiecutter.json +5 -0
  16. triggerflow/starter/prompts.yml +9 -0
  17. triggerflow/starter/{{ cookiecutter.repo_name }}/.dvcignore +3 -0
  18. triggerflow/starter/{{ cookiecutter.repo_name }}/.gitignore +143 -0
  19. triggerflow/starter/{{ cookiecutter.repo_name }}/.gitlab-ci.yml +56 -0
  20. triggerflow/starter/{{ cookiecutter.repo_name }}/README.md +29 -0
  21. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/README.md +26 -0
  22. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/catalog.yml +84 -0
  23. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters.yml +0 -0
  24. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters_compile.yml +14 -0
  25. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters_data_processing.yml +8 -0
  26. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters_load_data.yml +5 -0
  27. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters_model_training.yml +9 -0
  28. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/base/parameters_model_validation.yml +5 -0
  29. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/catalog.yml +90 -0
  30. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters.yml +0 -0
  31. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters_compile.yml +14 -0
  32. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters_data_processing.yml +8 -0
  33. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters_load_data.yml +5 -0
  34. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters_model_training.yml +9 -0
  35. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/local/parameters_model_validation.yml +5 -0
  36. triggerflow/starter/{{ cookiecutter.repo_name }}/conf/logging.yml +43 -0
  37. triggerflow/starter/{{ cookiecutter.repo_name }}/data/01_raw/.gitkeep +0 -0
  38. triggerflow/starter/{{ cookiecutter.repo_name }}/data/01_raw/condor_config.json +11 -0
  39. triggerflow/starter/{{ cookiecutter.repo_name }}/data/01_raw/cuda_config.json +4 -0
  40. triggerflow/starter/{{ cookiecutter.repo_name }}/data/01_raw/samples.json +24 -0
  41. triggerflow/starter/{{ cookiecutter.repo_name }}/data/01_raw/settings.json +8 -0
  42. triggerflow/starter/{{ cookiecutter.repo_name }}/data/01_raw/test.root +0 -0
  43. triggerflow/starter/{{ cookiecutter.repo_name }}/data/02_loaded/.gitkeep +0 -0
  44. triggerflow/starter/{{ cookiecutter.repo_name }}/data/03_preprocessed/.gitkeep +0 -0
  45. triggerflow/starter/{{ cookiecutter.repo_name }}/data/04_models/.gitkeep +0 -0
  46. triggerflow/starter/{{ cookiecutter.repo_name }}/data/05_validation/.gitkeep +0 -0
  47. triggerflow/starter/{{ cookiecutter.repo_name }}/data/06_compile/.gitkeep +0 -0
  48. triggerflow/starter/{{ cookiecutter.repo_name }}/data/07_reporting/.gitkeep +0 -0
  49. triggerflow/starter/{{ cookiecutter.repo_name }}/dvc.yaml +7 -0
  50. triggerflow/starter/{{ cookiecutter.repo_name }}/environment.yml +23 -0
  51. triggerflow/starter/{{ cookiecutter.repo_name }}/pyproject.toml +50 -0
  52. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/__init__.py +3 -0
  53. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/__main__.py +25 -0
  54. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/datasets/any_object.py +20 -0
  55. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/datasets/base_dataset.py +137 -0
  56. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/datasets/base_loader.py +101 -0
  57. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/datasets/meta_dataset.py +49 -0
  58. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/datasets/{{ cookiecutter.python_package }}_dataset.py +35 -0
  59. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/datasets/{{ cookiecutter.python_package }}_loader.py +32 -0
  60. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/models/__init__.py +0 -0
  61. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/models/base_model.py +155 -0
  62. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/models/{{ cookiecutter.python_package }}_model.py +16 -0
  63. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipeline_registry.py +17 -0
  64. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/compile/__init__.py +10 -0
  65. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/compile/nodes.py +70 -0
  66. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/compile/pipeline.py +20 -0
  67. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/data_processing/__init__.py +10 -0
  68. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/data_processing/nodes.py +41 -0
  69. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/data_processing/pipeline.py +28 -0
  70. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/load_data/__init__.py +10 -0
  71. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/load_data/nodes.py +13 -0
  72. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/load_data/pipeline.py +20 -0
  73. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_training/__init__.py +10 -0
  74. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_training/nodes.py +48 -0
  75. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_training/pipeline.py +24 -0
  76. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_validation/__init__.py +10 -0
  77. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_validation/nodes.py +31 -0
  78. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/pipelines/model_validation/pipeline.py +24 -0
  79. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/settings.py +46 -0
  80. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/utils/__init__.py +0 -0
  81. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/utils/metric.py +4 -0
  82. triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/utils/plotting.py +598 -0
  83. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/__init__.py +0 -0
  84. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/__init__.py +0 -0
  85. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/compile/__init__.py +0 -0
  86. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/compile/test_pipeline.py +9 -0
  87. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/data_processing/__init__.py +0 -0
  88. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/data_processing/test_pipeline.py +9 -0
  89. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/load_data/__init__.py +0 -0
  90. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/load_data/test_pipeline.py +9 -0
  91. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/model_training/__init__.py +0 -0
  92. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/model_training/test_pipeline.py +9 -0
  93. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/model_validation/__init__.py +0 -0
  94. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/model_validation/test_pipeline.py +9 -0
  95. triggerflow/starter/{{ cookiecutter.repo_name }}/tests/test_run.py +27 -0
  96. triggerflow/templates/build_ugt.tcl +46 -0
  97. triggerflow/templates/data_types.h +524 -0
  98. triggerflow/templates/makefile +28 -0
  99. triggerflow/templates/makefile_version +15 -0
  100. triggerflow/templates/model-gt.cpp +104 -0
  101. triggerflow/templates/model_template.cpp +63 -0
  102. triggerflow/templates/scales.h +20 -0
  103. triggerflow-0.3.4.dist-info/METADATA +206 -0
  104. triggerflow-0.3.4.dist-info/RECORD +107 -0
  105. triggerflow-0.3.4.dist-info/WHEEL +5 -0
  106. triggerflow-0.3.4.dist-info/entry_points.txt +2 -0
  107. triggerflow-0.3.4.dist-info/top_level.txt +3 -0
triggerflow/starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/utils/plotting.py
@@ -0,0 +1,598 @@
+ import numpy as np
+ import awkward as ak
+ from sklearn.metrics import roc_curve
+ from scipy import interpolate
+ import matplotlib
+ import matplotlib.pyplot as plt
+ import mplhep as hep
+ import shap
+
+
+ def set_global_plotting_settings():
+     plt.style.use(hep.style.ROOT)
+
+     # there is a super strange hack here:
+     # you need to create a plot, and then only the second (and following) ones have the correct text sizes
+     # if there is any better solution for this, let me know!
+     fig = plt.figure()
+     matplotlib.rcParams.update({"font.size": 26})
+     matplotlib.rcParams.update({"figure.facecolor": "white"})
+
+     plt.close()
+
+
+ # a collection of plotting functions
+
+
+ # helper function to have a central definition of the total L1 rate
+ def totalMinBiasRate():
+     LHCfreq = 11245.6  # LHC revolution frequency in Hz
+     nCollBunch = 2544  # number of colliding bunches
+
+     return LHCfreq * nCollBunch / 1e3  # in kHz
+
+
+ # helper function to get best triggers for a signal
+ def getBestTriggers(bits, n_best):
+     count_fires = ak.to_pandas(ak.sum(bits, axis=0))
+     best_columns = count_fires.T[0].sort_values(ascending=False)[:n_best].index.values
+
+     return best_columns
+
+
+ # a helper wrapper around the sklearn roc_curve method to handle multiple inputs at the same time
+ # expected return is fpr, tpr and thr
+ # where in the case of kfolding, the tpr will be a list with [mean, std]
+ def roc_curve_handlekfold(y_true, y_pred, weights):
+     if isinstance(y_true, list) and isinstance(y_pred, list):
+         # first calculating ROC curve for all examples
+         fprs = []
+         tprs = []
+         thrs = []
+         for y_true_one, y_pred_one, weights_one in zip(y_true, y_pred, weights):
+             fpr_this, tpr_this, thr_this = roc_curve(
+                 y_true_one,
+                 y_pred_one,
+                 drop_intermediate=False,
+                 sample_weight=weights_one,
+             )
+
+             fprs.append(fpr_this)
+             tprs.append(tpr_this)
+             thrs.append(thr_this)
+
+         # ensuring that we have the same thresholds everywhere
+         new_tprs = []
+         for fpr, tpr in zip(fprs, tprs):
+             func = interpolate.interp1d(
+                 fpr, tpr, kind="linear", bounds_error=False, fill_value=0
+             )
+             new_tprs.append(func(fprs[0]))  # just using the first as our baseline
+
+         new_tprs = np.asarray(new_tprs)
+
+         # now we know that these are in agreement, we can get mean and std
+         tpr_mean = np.mean(new_tprs, axis=0)
+         tpr_std = np.std(new_tprs, axis=0)
+
+         return fprs[0], [tpr_mean, tpr_std], thrs[0]
+
+     else:
+         return roc_curve(y_true, y_pred, drop_intermediate=False, sample_weight=weights)
+
+
+ # function to calculate the pure rate
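+ # "pure" here is meant relative to a reference trigger ("other"): events that "other" already
+ # accepts get their prediction score set to 0, so the resulting curve shows only the additional
+ # rate (FPRpure) and/or additional efficiency (TPRpure) this trigger provides on top of "other".
+ # Inputs may be single arrays or per-fold lists (kfold).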
+ def roc_curve_pure(
+     y_true, y_pred, other, FPRpure=True, TPRpure=True, verbosity=0, weights=None
+ ):
+     assert len(y_true) == len(y_pred) == len(other)
+     if (not FPRpure) and (not TPRpure):
+         raise Exception(
+             "pureROC: make at least one of FPR or TPR pure, or don't use this."
+         )
+
+     # all following stuff needs to also work if we are kfolding!
+     if (
+         isinstance(y_true, list)
+         and isinstance(y_pred, list)
+         and isinstance(other, list)
+     ):
+         y_pred_copy = []
+
+         for y_true_one, y_pred_one, other_one in zip(y_true, y_pred, other):
+             y_pred_copy_one = y_pred_one.copy()
+
+             # first, we assume everything that is already triggered by "other" is not triggered here -> setting to 0
+             if FPRpure:
+                 mask = other_one & (y_true_one == 0)
+                 y_pred_copy_one[mask] = 0
+             if TPRpure:
+                 mask = other_one & (y_true_one == 1)
+                 y_pred_copy_one[mask] = 0
+
+             y_pred_copy.append(y_pred_copy_one)
+
+     else:
+         # we'll construct a new prediction that gives the "pure FPR" and "pure TPR", whatever that is :D
+         y_pred_copy = y_pred.copy()
+
+         # first, we assume everything that is already triggered by "other" is not triggered here -> setting to 0
+         if FPRpure:
+             mask = other & (y_true == 0)
+             y_pred_copy[mask] = 0
+         if TPRpure:
+             mask = other & (y_true == 1)
+             y_pred_copy[mask] = 0
+
+     fpr_pure, tpr_pure, thr_pure = roc_curve_handlekfold(
+         y_true, y_pred_copy, weights=weights
+     )
+
+     return fpr_pure, tpr_pure, thr_pure
+
+
+ # Flexible function to plot ROC curves
+ # Minimal input are y_values for model prediction and the corresponding truth values
+ # FPR and TPR modes allow two main options: pure and rate / total (only for FPR / TPR)
+ # for this to work, a pureReference must be passed!
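+ # FPRmode / TPRmode are matched by substring: "rate" in FPRmode converts the background efficiency
+ # into an L1 rate in kHz via totalMinBiasRate(); "pure" in either mode goes through roc_curve_pure
+ # and needs a pureReference; "total" in TPRmode additionally adds the pureReference's own signal
+ # efficiency to the curve. y_true / y_pred may be per-fold lists, in which case the mean TPR is
+ # drawn with an optional std band (drawBandIfKfold).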
+ def plotROC(
+     y_true,
+     y_pred,
+     weights=None,
+     FPRmode="normal",
+     TPRmode="normal",
+     pureReference=None,
+     drawBandIfKfold=True,
+     ax=None,
+     verbosity=0,
+     **kwargs,
+ ):
+     if not ax:
+         f, ax = plt.subplots()
+
+     if "total" in TPRmode and "pure" not in TPRmode:
+         if verbosity > 0:
+             print(
+                 "You want to plot a total efficiency. This always needs to be handled 'pure' for the NN relative to the trigger to combine with, so the option 'pure' is implicitly added. You can suppress this warning by adding 'pure' to your TPRmode manually."
+             )
+         TPRmode += " pure"
+
+     # ensure that a pureReference is given if a mode requires pure
+     if ("pure" in FPRmode) or ("pure" in TPRmode):
+         if weights is None and isinstance(
+             y_true, list
+         ):  # Necessary to handle kfold. Need a None for each fold
+             weights_None = [None for _ in y_true]
+             fpr, tpr, thr = roc_curve_pure(
+                 y_true,
+                 y_pred,
+                 pureReference,
+                 FPRpure=("pure" in FPRmode),
+                 TPRpure=("pure" in TPRmode),
+                 weights=weights_None,
+             )
+         # if pureReference == None: raise Exception("If FPR mode or TPR mode contains pure, please pass a pureReference!")
+         else:
+             fpr, tpr, thr = roc_curve_pure(
+                 y_true,
+                 y_pred,
+                 pureReference,
+                 FPRpure=("pure" in FPRmode),
+                 TPRpure=("pure" in TPRmode),
+                 weights=weights,
+             )
+     else:
+         if weights is None and isinstance(
+             y_true, list
+         ):  # Necessary to handle kfold. Need a None for each fold
+             weights_None = [None for _ in y_true]
+             fpr, tpr, thr = roc_curve_handlekfold(y_true, y_pred, weights=weights_None)
+         else:
+             fpr, tpr, thr = roc_curve_handlekfold(y_true, y_pred, weights=weights)
+
+     # if we want to consider rate, scale the x axis
+     if "rate" in FPRmode:
+         fpr *= totalMinBiasRate()
+
+     # handle potential total rate
+     if "total" in TPRmode:
+         if isinstance(y_true, list) and isinstance(pureReference, list):
+             efficiencies = []
+             for y_true_one, pureReference_one in zip(y_true, pureReference):
+                 efficiencies.append(
+                     np.count_nonzero(pureReference_one[y_true_one == 1])
+                     / len(pureReference_one[y_true_one == 1])
+                 )
+             efficiency = np.mean(np.asarray(efficiencies))
+
+             tpr[0] += efficiency
+
+         else:
+             efficiency = np.count_nonzero(pureReference[y_true == 1]) / len(
+                 pureReference[y_true == 1]
+             )
+             tpr += efficiency
+
+         if verbosity > 0:
+             print(
+                 "Added a cutbased efficiency of " + str(efficiency) + " to the curve."
+             )
+
+     # now do the plotting
+     if isinstance(tpr, list):  # kfold
+         same_color = ax.plot(fpr, tpr[0], **kwargs)[0].get_color()
+         if drawBandIfKfold:
+             ax.fill_between(
+                 fpr, tpr[0] + tpr[1], tpr[0] - tpr[1], alpha=0.5, color=same_color
+             )
+     else:
+         ax.plot(fpr, tpr, **kwargs)
+
+     # and styling
+     ylabel = "Signal efficiency"
+     if "rate" in FPRmode:
+         ax.set_xlim(0, 20)
+         ax.set_ylim(0, 1)
+         xlabel = "L1 rate [kHz]"
+     else:
+         ax.set_xscale("log")
+         ax.set_yscale("log")
+         ax.plot([0, 1], [0, 1], "k--")
+         xlabel = "Background efficiency"
+
+     if "pure" in FPRmode:
+         xlabel += " (pure)"
+     if "pure" in TPRmode:
+         ylabel += " (pure)"
+
+     ax.set_xlabel(xlabel)
+     ax.set_ylabel(ylabel)
+
+     ax.grid(True)
+
+
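+ # plot the working point of a single (boolean) trigger decision: the bits are split by truth label,
+ # one (background efficiency, signal efficiency) point is computed (optionally weighted) and then
+ # drawn via plotROCpoint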
+ def plotTrigger(
+     bits,
+     y_test,
+     weights=None,
+     mode="point",
+     FPRmode="normal",
+     TPRmode="normal",
+     pureReference=None,
+     ax=None,
+     verbosity=0,
+     **kwargs,
+ ):
+     # first, split datasets into parts
+     y_test = np.asarray(y_test).astype(bool)
+     bits_sig = bits[y_test]
+     bits_bkg = bits[np.logical_not(y_test)]
+     if weights is not None:
+         weights_sig = weights[y_test]
+         weights_bkg = weights[np.logical_not(y_test)]
+     assert len(bits) == len(bits_sig) + len(bits_bkg)
+
+     # calculating fpr and tpr
+     if weights is None:
+         fpr = ak.sum(bits_bkg) / len(bits_bkg)
+         tpr = ak.sum(bits_sig) / len(bits_sig)
+     else:
+         fpr = ak.sum(np.array(weights_bkg)[bits_bkg]) / np.sum(weights_bkg)
+         tpr = ak.sum(np.array(weights_sig)[bits_sig]) / np.sum(weights_sig)
+
+     plotROCpoint(
+         fpr,
+         tpr,
+         mode=mode,
+         FPRmode=FPRmode,
+         TPRmode=TPRmode,
+         pureReference=pureReference,
+         ax=ax,
+         verbosity=verbosity,
+         **kwargs,
+     )
+
+
+ # stolen from https://stackoverflow.com/questions/29321835/is-it-possible-to-get-color-gradients-under-curve-in-matplotlib
+ # with some changes. Will be used in the function below
+ import matplotlib.colors as mcolors
+ from matplotlib.patches import Polygon
+
+
+ def gradient_fill(x, y, fill_color=None, ax=None, **kwargs):
+     """
+     Plot a line with a linear alpha gradient filled beneath it.
+
+     Parameters
+     ----------
+     x, y : array-like
+         The data values of the line.
+     fill_color : a matplotlib color specifier (string, tuple) or None
+         The color for the fill. If None, the color of the line will be used.
+     ax : a matplotlib Axes instance
+         The axes to plot on. If None, the current pyplot axes will be used.
+     Additional arguments are passed on to matplotlib's ``plot`` function.
+
+     Returns
+     -------
+     line : a Line2D instance
+         The line plotted.
+     im : an AxesImage instance
+         The transparent gradient clipped to just the area beneath the curve.
+     """
+     if ax is None:
+         ax = plt.gca()
+
+     (line,) = ax.plot(x, y, **kwargs)
+     if fill_color is None:
+         fill_color = line.get_color()
+
+     zorder = line.get_zorder()
+     alpha = line.get_alpha()
+     alpha = 1.0 if alpha is None else alpha
+
+     z = np.empty((100, 1, 4), dtype=float)
+     rgb = mcolors.colorConverter.to_rgb(fill_color)
+     z[:, :, :3] = rgb
+     z[:, :, -1] = np.linspace(0, alpha, 100)[:, None]
+
+     xmin, xmax, ymin, ymax = x.min(), x.max(), y.min(), y.max()
+     ymin -= 0.2
+     im = ax.imshow(
+         z, aspect="auto", extent=[xmin, xmax, ymin, ymax], origin="lower", zorder=zorder
+     )
+
+     xy = np.column_stack([x, y])
+     xy = np.vstack([[xmin, ymin], xy, [xmax, ymin], [xmin, ymin]])
+     clip_path = Polygon(xy, facecolor="none", edgecolor="none", closed=True)
+     ax.add_patch(clip_path)
+     im.set_clip_path(clip_path)
+
+     return line, im
+
+
+ # Method to plot the result of a single trigger
+ def plotROCpoint(
+     fpr,
+     tpr,
+     mode="point",
+     FPRmode="normal",
+     TPRmode="normal",
+     pureReference=None,
+     ax=None,
+     verbosity=0,
+     **kwargs,
+ ):
+     if not ax:
+         f, ax = plt.subplots()
+
+     if "rate" in FPRmode:
+         fpr *= totalMinBiasRate()
+
+     if ("pure" in FPRmode) or ("pure" in TPRmode):
+         raise Exception("Pure rate/efficiency not yet implemented for trigger plotting")
+
+     # plotting...
+     if mode == "line":
+         gradient_fill(np.asarray([0, 9999999]), np.asarray([tpr, tpr]), ax=ax, **kwargs)
+     else:
+         ax.plot(fpr, tpr, "o", **kwargs)
+
+     # and styling
+     xlabel = ""
+     ylabel = "Signal efficiency"
+     if "rate" in FPRmode:
+         ax.set_xlim(0, 20)
+         ax.set_ylim(0, 1)
+         xlabel = "L1 rate [kHz]"
+     else:
+         ax.set_xscale("log")
+         ax.set_yscale("log")
+         ax.plot([0, 1], [0, 1], "k--")
+         xlabel = "Background efficiency"
+
+     if "pure" in FPRmode:
+         xlabel += " (pure)"
+     if "pure" in TPRmode:
+         ylabel += " (pure)"
+
+     ax.set_xlabel(xlabel)
+     ax.set_ylabel(ylabel)
+
+     ax.grid(True)
+
+
+ # histogram plotting function, for example to compare situation before/after trigger
+ def plotHist(
+     data,
+     bins=10,
+     weights=None,
+     interval=None,
+     logy=False,
+     logx=False,
+     density=False,
+     ax=None,
+     divide_by_bin_width=False,
+     verbosity=0,
+     **kwargs,
+ ):
+     if not ax:
+         fig, ax = plt.subplots()
+
+     # convert awkward array to numpy array, if needed
+     try:
+         data_np = ak.flatten(data).to_numpy()
+     except Exception:
+         data_np = data
+
+     # creating histogram from data
+     hist_data, hist_edges = np.histogram(
+         data_np, bins=bins, range=interval, density=density, weights=weights
+     )
+
+     if divide_by_bin_width:
+         bin_widths = np.diff(hist_edges)
+         assert hist_data.shape == bin_widths.shape
+         hist_data = hist_data / bin_widths
+
+     # plotting the histogram
+     hep.histplot(hist_data, hist_edges, ax=ax, **kwargs)
+     if logy:
+         ax.set_yscale("log")
+     if logx:
+         ax.set_xscale("log")
+
+     ax.set_ylabel("events")
+     if density:
+         ax.set_ylabel("events / total events")
+
+
+ def plotEfficiency(
+     efficiency, bins, error, logy=False, logx=False, ax=None, verbosity=0, **kwargs
+ ):
+     if not ax:
+         fig, ax = plt.subplots()
+
+     # plotting the histogram
+     hep.histplot(efficiency, bins, yerr=error, ax=ax, **kwargs)
+     if logy:
+         ax.set_yscale("log")
+     if logx:
+         ax.set_xscale("log")
+
+     ax.set_ylabel("efficiency")
+
+
+ def plotRateVsLumi(triggerResults, runInfo, interval=100, verbosity=0, **kwargs):
+     raise Exception("Rate vs. lumi is not implemented yet!")
+
+
+ def plotStability(
+     triggerResults, runInfo=None, interval=100, ax=None, verbosity=0, **kwargs
+ ):
+     # a function to plot the stability of some trigger
+
+     if not runInfo:
+         # just do a stability plot with a fixed interval
+         # if a runInfo object is passed, call plotRateVsLumi instead
+
+         # we'll average the dataframe entries. first, convert to numpy
+         np_triggerResults = triggerResults.to_numpy()
+
+         # for the following to work, we need to assure that the length is divisible by the desired interval
+         nLastElements = len(triggerResults) % interval
+         nBlocks = int(len(np_triggerResults) / interval)
+
+         if verbosity > 1:
+             print(
+                 "Found "
+                 + str(nBlocks)
+                 + " blocks with "
+                 + str(nLastElements)
+                 + " leftover events"
+             )
+
+         # split events to handle leftovers later
+         # (some code improvement might be possible here)
+         np_lastElements = np_triggerResults[len(np_triggerResults) - nLastElements :]
+         np_triggerResults = np_triggerResults[: len(np_triggerResults) - nLastElements]
+
+         # block the results
+         np_triggerResults_blocked = np_triggerResults.reshape(nBlocks, interval)
+
+         # calculate passed events along axis
+         np_passedEvents = np.sum(np_triggerResults_blocked.astype("int"), axis=1)
+         if nLastElements > 0:
+             np_passedEventsLeftovers = np.sum(np_lastElements)
+
+         # calculate efficiency & get bin edges
+         np_efficiency = np_passedEvents / interval
+         binEdges = np.arange(0, (len(np_efficiency) + 1) * interval, interval)
+
+         # handle if there are leftovers
+         if nLastElements > 0:
+             np_efficiency = np.append(
+                 np_efficiency, np_passedEventsLeftovers / nLastElements
+             )
+             binEdges = np.append(binEdges, len(triggerResults))
+
+         if not ax:
+             fig, ax = plt.subplots(figsize=(20, 8))
+
+         # plotting
+         hep.histplot(
+             np_efficiency, binEdges, label=triggerResults.name, ax=ax, **kwargs
+         )
+
+         # styling
+         ax.set_ylabel("efficiency")
+         ax.set_xlabel("event")
+         ax.set_xlim(0, len(triggerResults))
+         ax.ticklabel_format(style="plain")
+
+         return np_efficiency
+     else:
+         return plotRateVsLumi(triggerResults, runInfo, interval=interval, **kwargs)
+
+
+ # plotting function for a generic training history
+ def plotTrainingHistory(history, metrics=["loss", "accuracy"], f=None, axs=None):
+     # creating the plot
+     if not f and not axs:
+         f, axs = plt.subplots(
+             len(metrics), 1, figsize=(12, 4 * len(metrics)), sharex=True
+         )
+         if len(metrics) == 1:
+             axs = [axs]
+         plt.subplots_adjust(wspace=0, hspace=0)
+
+     # labeling
+     hep.cms.label("private work", data=False, ax=axs[0])
+
+     for i in range(len(metrics)):
+         metric = metrics[i]
+         ax = axs[i]
+         ax.set_ylabel(metric)
+
+         if isinstance(history, list):  # handle kfold
+             for foldi in range(len(history)):
+                 ax.plot(history[foldi].history[metric], color="C{}".format(foldi))
+                 ax.plot(
+                     history[foldi].history["val_" + metric],
+                     color="C{}".format(foldi),
+                     linestyle="--",
+                 )
+
+             (la2,) = ax.plot([0, 0], [0, 0], color="Grey")
+             (lb2,) = ax.plot([0, 0], [0, 0], color="Grey", linestyle="--")
+             ax.legend([la2, lb2], ["training", "validation"])
+         else:
+             ax.plot(history.history[metric], label="training")
+             ax.plot(history.history["val_" + metric], label="validation")
+             ax.legend()
+
+     axs[-1].set_xlabel("Epoch")
+
+     return f, axs
+
+
+ # plotting function for shapley feature importance
+ # make sure that len(x) > 1100 if you use the default n_* values!
+ def plot_feature_importance(
+     model, x, n_fit=1000, n_explain=100, feature_names=None, show=False
+ ):
+     explainer = shap.DeepExplainer(model, x[0:n_fit])
+
+     # explain the first 100 predictions
+     # explaining each prediction requires 2 * background dataset size runs
+     shap_values = explainer.shap_values(x[n_fit : n_fit + n_explain])
+     shap.summary_plot(
+         shap_values, feature_names=feature_names, max_display=x.shape[1], show=show
+     )
+
+ def get_dummy():
+     return plt.subplots()
triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/compile/test_pipeline.py
@@ -0,0 +1,9 @@
+ """
+ This is a boilerplate test file for pipeline 'compile'
+ generated using Kedro 1.0.0
+ Please add your pipeline tests here.
+
+ Kedro recommends using `pytest` framework, more info about it can be found
+ in the official documentation:
+ https://docs.pytest.org/en/latest/getting-started.html
+ """
triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/data_processing/test_pipeline.py
@@ -0,0 +1,9 @@
+ """
+ This is a boilerplate test file for pipeline 'data_processing'
+ generated using Kedro 1.0.0
+ Please add your pipeline tests here.
+
+ Kedro recommends using `pytest` framework, more info about it can be found
+ in the official documentation:
+ https://docs.pytest.org/en/latest/getting-started.html
+ """
triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/load_data/test_pipeline.py
@@ -0,0 +1,9 @@
+ """
+ This is a boilerplate test file for pipeline 'load_data'
+ generated using Kedro 1.0.0
+ Please add your pipeline tests here.
+
+ Kedro recommends using `pytest` framework, more info about it can be found
+ in the official documentation:
+ https://docs.pytest.org/en/latest/getting-started.html
+ """
triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/model_training/test_pipeline.py
@@ -0,0 +1,9 @@
+ """
+ This is a boilerplate test file for pipeline 'model_training'
+ generated using Kedro 1.0.0
+ Please add your pipeline tests here.
+
+ Kedro recommends using `pytest` framework, more info about it can be found
+ in the official documentation:
+ https://docs.pytest.org/en/latest/getting-started.html
+ """
triggerflow/starter/{{ cookiecutter.repo_name }}/tests/pipelines/model_validation/test_pipeline.py
@@ -0,0 +1,9 @@
+ """
+ This is a boilerplate test file for pipeline 'model_validation'
+ generated using Kedro 1.0.0
+ Please add your pipeline tests here.
+
+ Kedro recommends using `pytest` framework, more info about it can be found
+ in the official documentation:
+ https://docs.pytest.org/en/latest/getting-started.html
+ """
triggerflow/starter/{{ cookiecutter.repo_name }}/tests/test_run.py
@@ -0,0 +1,27 @@
+ """
+ This module contains example tests for a Kedro project.
+ Tests should be placed in ``src/tests``, in modules that mirror your
+ project's structure, and in files named test_*.py.
+ """
+
+ import pytest
+ from pathlib import Path
+ from kedro.framework.session import KedroSession
+ from kedro.framework.startup import bootstrap_project
+
+ # The tests below are here for the demonstration purpose
+ # and should be replaced with the ones testing the project
+ # functionality
+
+
+ class TestKedroRun:
+     def test_kedro_run_no_pipeline(self):
+         # This example test expects a pipeline run failure, since
+         # the default project template contains no pipelines.
+         bootstrap_project(Path.cwd())
+
+         with pytest.raises(Exception) as excinfo:
+             with KedroSession.create(project_path=Path.cwd()) as session:
+                 session.run()
+
+         assert "Pipeline contains no nodes" in str(excinfo.value)