tesorotools-python 0.0.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. tesorotools/__init__.py +6 -0
  2. tesorotools/artists/__init__.py +5 -0
  3. tesorotools/artists/barh_plot.py +310 -0
  4. tesorotools/artists/line_plot.py +245 -0
  5. tesorotools/artists/table.py +200 -0
  6. tesorotools/artists/type_curve.py +218 -0
  7. tesorotools/assets/README.md +5 -0
  8. tesorotools/assets/fonts/CabinetGrotesk-Black.otf +0 -0
  9. tesorotools/assets/fonts/CabinetGrotesk-Bold.otf +0 -0
  10. tesorotools/assets/fonts/CabinetGrotesk-Extrabold.otf +0 -0
  11. tesorotools/assets/fonts/CabinetGrotesk-Extralight.otf +0 -0
  12. tesorotools/assets/fonts/CabinetGrotesk-Light.otf +0 -0
  13. tesorotools/assets/fonts/CabinetGrotesk-Medium.otf +0 -0
  14. tesorotools/assets/fonts/CabinetGrotesk-Regular.otf +0 -0
  15. tesorotools/assets/fonts/CabinetGrotesk-Thin.otf +0 -0
  16. tesorotools/assets/fonts/README.md +1 -0
  17. tesorotools/assets/plots.yaml +43 -0
  18. tesorotools/assets/tesoro.mplstyle +21 -0
  19. tesorotools/convert.py +99 -0
  20. tesorotools/data_sources/README.md +14 -0
  21. tesorotools/data_sources/__init__.py +0 -0
  22. tesorotools/data_sources/debug.py +26 -0
  23. tesorotools/data_sources/lseg.py +117 -0
  24. tesorotools/database/__init__.py +0 -0
  25. tesorotools/database/push.py +70 -0
  26. tesorotools/dependencies/__init__.py +0 -0
  27. tesorotools/dependencies/functions.py +11 -0
  28. tesorotools/dependencies/node.py +34 -0
  29. tesorotools/dependencies/resolution.py +118 -0
  30. tesorotools/main.py +37 -0
  31. tesorotools/offsets/__init__.py +0 -0
  32. tesorotools/offsets/offsets.py +439 -0
  33. tesorotools/offsets/outliers.py +15 -0
  34. tesorotools/render/__init__.py +17 -0
  35. tesorotools/render/content/__init__.py +0 -0
  36. tesorotools/render/content/content.py +17 -0
  37. tesorotools/render/content/images.py +147 -0
  38. tesorotools/render/content/section.py +53 -0
  39. tesorotools/render/content/subtitle.py +53 -0
  40. tesorotools/render/content/table.py +308 -0
  41. tesorotools/render/content/text.py +23 -0
  42. tesorotools/render/content/title.py +40 -0
  43. tesorotools/render/report.py +31 -0
  44. tesorotools/utils/__init__.py +0 -0
  45. tesorotools/utils/config.py +35 -0
  46. tesorotools/utils/globals.py +14 -0
  47. tesorotools/utils/matplotlib.py +38 -0
  48. tesorotools/utils/series.py +40 -0
  49. tesorotools/utils/shortcuts.py +32 -0
  50. tesorotools/utils/template.py +126 -0
  51. tesorotools_python-0.0.18.dist-info/METADATA +16 -0
  52. tesorotools_python-0.0.18.dist-info/RECORD +53 -0
  53. tesorotools_python-0.0.18.dist-info/WHEEL +4 -0
@@ -0,0 +1,6 @@
1
+ from tesorotools.artists.line_plot import Format, Legend, LinePlot
2
+ from tesorotools.utils.config import TemplateLoader
3
+
4
# Register the custom YAML tags with TemplateLoader so template files can
# instantiate LinePlot, Format and Legend objects directly.
for _tag, _constructor in (
    ("!line_plot", LinePlot.from_yaml),
    ("!format", Format.from_yaml),
    ("!legend", Legend.from_yaml),
):
    TemplateLoader.add_constructor(_tag, _constructor)
del _tag, _constructor
@@ -0,0 +1,5 @@
1
+ import matplotlib.style
2
+
3
+ from ..utils.globals import STYLE_SHEET
4
+
5
# Apply the package style sheet globally as soon as the artists package is imported.
matplotlib.style.use(style=STYLE_SHEET)
@@ -0,0 +1,310 @@
1
+ from enum import Enum
2
+ from pathlib import Path
3
+ from typing import Any
4
+
5
+ import matplotlib.pyplot as plt
6
+ import pandas as pd
7
+ from matplotlib.container import BarContainer
8
+ from matplotlib.ticker import FuncFormatter
9
+
10
+ from tesorotools.offsets.offsets import Difference, FloatingOffset, Stat
11
+ from tesorotools.offsets.outliers import flag_outliers
12
+ from tesorotools.utils.globals import DEBUG
13
+ from tesorotools.utils.matplotlib import (
14
+ PLOT_CONFIG,
15
+ format_annotation,
16
+ load_fonts,
17
+ )
18
+
19
# Sections of the shared plot configuration used by the horizontal bar charts.
BARH_CONFIG: dict[str, Any] = PLOT_CONFIG["barh"]  # bar-specific settings
AX_CONFIG: dict[str, Any] = PLOT_CONFIG["ax"]  # spine / baseline settings
FIG_CONFIG: dict[str, Any] = PLOT_CONFIG["figure"]  # kwargs for plt.figure

# NOTE(review): presumably registers the bundled fonts with matplotlib;
# runs once at import time.
load_fonts()
24
+
25
+
26
class Column(Enum):
    """Names of the series assembled for a horizontal bar chart.

    Each member's value is used as a ``pd.Series`` name and therefore as the
    column label once the series are concatenated into a DataFrame.
    """

    VALUE = "value"  # bar widths
    AXIS = "axis"  # optional level appended to each y tick label
    DEVIATION = "deviation"  # outlier flags produced by flag_outliers
    COLOR = "color"  # matplotlib color spec per bar ("C0", "C1", ...)
    ALPHA = "alpha"  # per-bar opacity
32
+
33
+
34
def _style_spines(
    ax: plt.Axes,
    decimals: int,
    units: str,
    *,
    color: str,
    linewidth: str,
):
    """Apply the house style to the frame and x axis of a bar chart.

    Args:
        ax: Axes to style in place.
        decimals: Decimals passed to ``format_annotation`` for x tick labels.
        units: Units passed to ``format_annotation`` for x tick labels.
        color: Color used for every spine and every tick mark.
        linewidth: Line width applied to every spine.
    """

    def _tick_label(value, _position):
        # FuncFormatter calls this with (tick value, tick position).
        return format_annotation(value, decimals, units)

    ax.grid(visible=True, axis="x")
    for spine in ax.spines.values():
        spine.set_color(color)
        spine.set_linewidth(linewidth)
    ax.xaxis.set_major_formatter(FuncFormatter(_tick_label))
    ax.tick_params(axis="both", which="major")
    # Tint the tick marks of both axes with the spine color.
    for tick_line in [*ax.get_xticklines(), *ax.get_yticklines()]:
        tick_line.set_markeredgecolor(color)
54
+
55
+
56
def _style_baseline(ax: plt.Axes, **baseline_config):
    """Make sure x = 0 is inside the x limits and mark it.

    If zero ends up exactly on the left or right limit, that spine is
    recolored with ``baseline_config["color"]``. Otherwise a vertical line is
    drawn at x = 0 with all of ``baseline_config`` as ``axvline`` kwargs.
    """
    edge_color: str = baseline_config["color"]
    low, high = ax.get_xlim()
    ax.set_xlim(left=min(0, low), right=max(0, high))
    low, high = ax.get_xlim()

    if low == 0:
        spine_name = "left"
    elif high == 0:
        spine_name = "right"
    else:
        spine_name = None

    if spine_name is None:
        ax.axvline(x=0, **baseline_config)
    else:
        ax.spines[spine_name].set_edgecolor(edge_color)
67
+
68
+
69
+ def _collect_series(
70
+ blocks: dict[str, Any] | None, series: dict[str, str] | None
71
+ ) -> dict[str, str]:
72
+ if series is None and blocks is None:
73
+ raise ValueError("blocks and series cannot be both missing")
74
+ if series is None and blocks is not None:
75
+ return _collect_block_series(blocks)
76
+ else:
77
+ return series
78
+
79
+
80
+ def _collect_block_series(blocks: dict[str, Any]) -> dict[str, str]:
81
+ series = {}
82
+ for _, block_cfg in blocks.items():
83
+ series = series | block_cfg["series"]
84
+ return series
85
+
86
+
87
def _infer_colors(
    value_series: pd.Series, blocks: dict[str, Any] | None
) -> pd.Series:
    """Pick a matplotlib color cycle entry for every bar.

    With ``blocks``: every series of the i-th block gets ``"C{i}"``.
    Without ``blocks``: values >= 0 get ``"C0"`` and values < 0 get ``"C1"``.
    Entries that match neither rule (e.g. NaN values) stay missing.
    """
    colors: pd.Series = pd.Series(
        index=value_series.index, name=Column.COLOR.value, dtype=str
    )
    if blocks is None:
        colors[value_series >= 0] = "C0"
        colors[value_series < 0] = "C1"
        return colors

    for position, block_cfg in enumerate(blocks.values()):
        members: dict[str, str] = block_cfg["series"]
        colors.loc[members.keys()] = f"C{position}"
    return colors
101
+
102
+
103
def _highlight_series(
    alias: dict[str, str], value_series: pd.Series
) -> pd.Series:
    """Return the per-bar opacity series.

    Bars whose alias label ends with ``*`` get
    ``BARH_CONFIG["highlight_factor"]``. Every other bar gets 1.
    """
    highlighted = [key for key, label in alias.items() if label.endswith("*")]
    alpha_series: pd.Series = pd.Series(
        index=value_series.index, name=Column.ALPHA.value
    )
    alpha_series.loc[:] = 1
    alpha_series.loc[highlighted] = BARH_CONFIG["highlight_factor"]
    return alpha_series
113
+
114
+
115
def _format_yaxis(
    alias: dict[str, str],
    axis_format: dict[str, Any],
    value_series: pd.Series,
    axis_series: pd.Series | None,
) -> pd.Series:
    """Relabel ``value_series`` with the display labels for the y axis.

    Index keys are replaced by their alias with every ``*`` removed. When
    ``axis_format`` is given (it must then provide ``decimals`` and
    ``units``), the formatted ``axis_series`` value is appended in
    parentheses: ``"<label> (<formatted axis value>)"``.
    """
    label_of = {key: label.replace("*", "") for key, label in alias.items()}
    labelled = value_series.rename(label_of)
    if axis_format is None:
        return labelled

    decimals: int = axis_format["decimals"]
    units: str = axis_format["units"]
    axis_text = axis_series.rename(label_of).apply(
        lambda level: format_annotation(level, decimals, units)
    )
    return labelled.rename(lambda label: f"{label} ({axis_text.loc[label]})")
134
+
135
+
136
def _annotate(
    fig: plt.Figure,
    ax: plt.Axes,
    bar_container: BarContainer,
    *,
    decimals: int,
    units: str,
):
    """Write each bar's formatted value next to it and widen the limits.

    After the labels are added the canvas is asked to draw, and each label's
    window extent is converted to data coordinates and added to the data
    limits before autoscaling. Presumably this keeps the labels inside the
    axes.
    """

    def _bar_text(value):
        return format_annotation(value, decimals, units)

    labels = ax.bar_label(
        container=bar_container,
        fmt=_bar_text,
        padding=BARH_CONFIG["padding"],
    )

    fig.canvas.draw_idle()
    for label in labels:
        extent_in_data = label.get_window_extent().transformed(
            ax.transData.inverted()
        )
        ax.update_datalim(extent_in_data.corners())
    ax.autoscale_view()
158
+
159
+
160
def _plot_barh_chart(
    out_file: Path,
    standard_dict: dict[Column, pd.Series | None],
    alias: dict[str, str],
    sorted: bool,
    format: dict,
    annot_format: dict,
    axis_format: dict | None = None,
    blocks: dict | None = None,
    **kwargs,
):
    """Draw one horizontal bar chart and save it to ``out_file``.

    Args:
        out_file: Destination image path passed to ``fig.savefig``.
        standard_dict: Series keyed by ``Column``; VALUE gives the bar widths
            and AXIS the optional level appended to tick labels.
            (NOTE(review): DEVIATION is not used here yet.)
        alias: ``{series key: label}``; a trailing ``*`` on a label marks
            the bar as highlighted.
        sorted: When true, bars are ordered by ascending value.
        format: ``decimals`` / ``units`` for the x tick labels.
        annot_format: ``decimals`` / ``units`` for the per-bar labels.
        axis_format: Optional format for the AXIS value in tick labels.
        blocks: Optional block configuration used to color bars per block.
        **kwargs: Other chart-config keys, accepted and ignored.
    """
    value_series: pd.Series = standard_dict[Column.VALUE]
    color_series: pd.Series = _infer_colors(value_series, blocks)
    alpha_series: pd.Series = _highlight_series(alias, value_series)

    # Relabel the bars. The color/alpha series are re-indexed positionally,
    # so they stay aligned with the renamed values.
    axis_series = standard_dict[Column.AXIS]
    value_series = _format_yaxis(alias, axis_format, value_series, axis_series)
    color_series.index = value_series.index
    alpha_series.index = value_series.index

    data: pd.DataFrame = pd.concat(
        [value_series, color_series, alpha_series], axis=1
    )
    if sorted:
        data = data.sort_values(by=Column.VALUE.value)

    fig = plt.figure(**FIG_CONFIG)
    try:
        ax = fig.add_subplot()
        bar_container: BarContainer = ax.barh(
            y=data.index,
            width=data[Column.VALUE.value],
            color=data[Column.COLOR.value],
        )
        for bar, alpha in zip(bar_container, data[Column.ALPHA.value]):
            bar.set_alpha(alpha)

        _annotate(fig, ax, bar_container, **annot_format)
        _style_spines(ax, **format, **AX_CONFIG["spines"])
        _style_baseline(ax, **AX_CONFIG["baseline"])

        fig.savefig(out_file)
    finally:
        # pyplot keeps every figure alive until it is closed. Release it so
        # batch plotting does not accumulate figures.
        plt.close(fig)
207
+
208
+
209
def _normalize_from_flash(
    flash: pd.DataFrame,
    axis: bool,
    *,
    date: str | pd.Timestamp | None,
    offset: str,
    difference: str,
    deviations: bool,
    units_bar: str,
    units_axis: str,
) -> dict[Column, pd.Series | None]:
    """Extract the series a bar chart needs from a flash frame.

    ``flash`` rows are addressed by a 4-level key
    ``(date, offset, difference, stat)`` and its columns are the chart's
    series.

    Args:
        flash: The flash frame, already restricted to the chart's series.
        axis: When true, also extract the undifferenced level
            (``FloatingOffset.NO`` / ``Difference.NO``) as the AXIS series.
        date: Row date to read; defaults to the latest date in level 0.
        offset: Raw ``FloatingOffset`` value.
        difference: Raw ``Difference`` value.
        deviations: When true, compute outlier flags with ``flag_outliers``
            from the value, rolling average and rolling std rows.
        units_bar: Units of the bars; ``"p.b."`` triggers a x100 rescale of
            absolute differences.
        units_axis: Units of the AXIS series; ``"p.b."`` triggers a x100
            rescale (see note below).

    Returns:
        ``{Column.VALUE: ..., Column.AXIS: ... or None,
        Column.DEVIATION: ... or None}``.
    """
    date = (
        flash.index.get_level_values(level=0).max()
        if date is None
        else pd.to_datetime(date)
    )
    offset = FloatingOffset(offset)
    difference = Difference(difference)

    # value column
    values_series: pd.Series = flash.loc[
        (date, offset.value, difference.value, Stat.VALUE.value),
        :,
    ].copy()
    values_series.name = Column.VALUE.value
    # NOTE(review): presumably relative differences are fractions shown as
    # percent, and "p.b." absolute differences are shown in basis points.
    if difference is Difference.REL:
        values_series = values_series * 100
    if difference is Difference.ABS and units_bar == "p.b.":
        values_series = values_series * 100

    # axis column
    axis_series: pd.Series | None = None
    if axis:
        axis_series = flash.loc[
            (
                date,
                FloatingOffset.NO.value,
                Difference.NO.value,
                Stat.VALUE.value,
            ),
            :,
        ].copy()
        # NOTE(review): this series is the undifferenced level, yet its
        # rescale depends on the *bar's* difference mode. Confirm this is
        # intended, and not meant to depend on units_axis alone.
        if difference is Difference.ABS and units_axis == "p.b.":
            axis_series = axis_series * 100
        axis_series.name = Column.AXIS.value

    # deviations column
    deviations_series: pd.Series | None = None
    if deviations:
        deviations_df: pd.DataFrame = flash.loc[
            (
                date,
                offset.value,
                difference.value,
                [Stat.VALUE.value, Stat.ROLL_AVG.value, Stat.ROLL_STD.value],
            ),
            :,
        ].T.copy()
        # Keep only the stat level as column labels for flag_outliers.
        deviations_df.columns = deviations_df.columns.get_level_values(level=-1)
        deviations_df.columns.name = None
        deviations_series = flag_outliers(deviations_df)
        deviations_series.name = Column.DEVIATION.value

    return {
        Column.VALUE: values_series,
        Column.AXIS: axis_series,
        Column.DEVIATION: deviations_series,
    }
289
+
290
+
291
def plot_barh_charts_from_flash(
    out_path: Path, flash: pd.DataFrame, config_dicts: dict[str, dict]
):
    """Render one bar chart per entry of ``config_dicts`` into ``out_path``.

    Each chart is saved as ``out_path / f"{name}.png"``. A chart config
    needs ``flash`` (kwargs for the flash lookup), ``format`` (with
    ``units``), and either ``series`` or ``blocks``. ``axis_format`` is
    optional; it may be absent or explicitly null.
    """
    for name, config in config_dicts.items():
        blocks: dict[str, Any] | None = config.get("blocks", None)
        series: dict[str, str] | None = config.get("series", None)
        alias = _collect_series(blocks, series)
        trimmed_flash: pd.DataFrame = flash.loc[:, list(alias)]
        flash_config: dict[str, Any] = config["flash"]
        axis_format: dict[str, Any] | None = config.get("axis_format", None)
        axis = axis_format is not None
        # Derive the axis units from the value already read, so that an
        # explicit `axis_format: null` behaves like a missing key instead of
        # raising TypeError on None["units"].
        units_axis = axis_format["units"] if axis else ""
        standard_dict: dict[Column, pd.Series | None] = _normalize_from_flash(
            trimmed_flash,
            axis,
            **flash_config,
            units_bar=config["format"]["units"],
            units_axis=units_axis,
        )
        out_file = out_path / f"{name}.png"
        _plot_barh_chart(out_file, standard_dict, alias, **config)
@@ -0,0 +1,245 @@
1
+ import datetime
2
+ import locale
3
+ from pathlib import Path
4
+ from typing import Any, Self
5
+
6
+ import matplotlib.pyplot as plt
7
+ import pandas as pd
8
+ from matplotlib.ticker import FuncFormatter
9
+ from yaml.nodes import MappingNode
10
+
11
+ from tesorotools.utils.config import TemplateLoader
12
+
13
# Adopt the user's locale. setlocale raises locale.Error when the environment
# names a locale that is not installed (common in minimal containers). Fall
# back to the current locale with a warning instead of failing the import.
try:
    locale.setlocale(locale.LC_ALL, "")
except locale.Error as _locale_err:
    import warnings

    warnings.warn(
        f"Could not set the locale from the environment: {_locale_err}",
        RuntimeWarning,
        stacklevel=2,
    )
    del _locale_err
14
+
15
+ from tesorotools.utils.globals import DEBUG
16
+ from tesorotools.utils.matplotlib import (
17
+ PLOT_CONFIG,
18
+ format_annotation,
19
+ load_fonts,
20
+ )
21
+
22
# NOTE(review): presumably registers the bundled fonts with matplotlib;
# runs once at import time.
load_fonts()

# Sections of the shared plot configuration used by the line charts.
LINE_PLOT_CONFIG: dict[str, Any] = PLOT_CONFIG["line"]  # line-specific settings
AX_CONFIG: dict[str, Any] = PLOT_CONFIG["ax"]  # spine / baseline settings
FIG_CONFIG: dict[str, Any] = PLOT_CONFIG["figure"]  # kwargs for plt.figure
27
+
28
+
29
def _style_spines(
    ax: plt.Axes,
    decimals: int,
    units: str,
    *,
    color: str,
    linewidth: str,
):
    """Apply the house style to the frame and y axis of a line chart.

    Horizontal grid lines are drawn, the y tick labels move to the right
    side and are formatted with ``format_annotation``, the x label is
    cleared, minor ticks are hidden, and spines and tick marks take
    ``color``.
    """

    def _tick_label(value, _position):
        # FuncFormatter calls this with (tick value, tick position).
        return format_annotation(value, decimals, units)

    ax.grid(visible=True, axis="y")
    for spine in ax.spines.values():
        spine.set_color(color)
        spine.set_linewidth(linewidth)
    ax.yaxis.tick_right()
    ax.yaxis.set_major_formatter(FuncFormatter(_tick_label))
    ax.set_xlabel("")

    ax.tick_params(which="minor", size=0, width=0)
    ax.tick_params(axis="both", which="major")
    for tick_line in [*ax.get_xticklines(), *ax.get_yticklines()]:
        tick_line.set_markeredgecolor(color)
53
+
54
+
55
def _style_baseline(ax: plt.Axes, reference: float = 0, **baseline_config):
    """Make sure ``reference`` is inside the y limits and mark it.

    If ``reference`` ends up exactly on the bottom or top limit, that spine
    is recolored with ``baseline_config["color"]``. Otherwise a horizontal
    line is drawn at ``reference`` with all of ``baseline_config`` as
    ``axhline`` kwargs.
    """
    edge_color: str = baseline_config["color"]
    low, high = ax.get_ylim()
    ax.set_ylim(bottom=min(reference, low), top=max(reference, high))
    low, high = ax.get_ylim()

    if low == reference:
        spine_name = "bottom"
    elif high == reference:
        spine_name = "top"
    else:
        spine_name = None

    if spine_name is None:
        ax.axhline(y=reference, **baseline_config)
    else:
        ax.spines[spine_name].set_edgecolor(edge_color)
66
+
67
+
68
def plot_line_chart(
    out_name: Path,
    data: pd.DataFrame,
    *,
    base_100: bool,
    annotate: bool,
    format: dict[str, Any],
    **kwargs,
):
    """Draw every column of ``data`` as a line and save the figure.

    Args:
        out_name: Destination image path passed to ``fig.savefig``.
        data: One column per line; the index is the x axis.
        base_100: Rebase every column so its first row equals 100. The
            baseline is then drawn at 100 instead of 0.
        annotate: Accepted for config compatibility; not implemented yet.
        format: ``decimals`` / ``units`` for the y tick labels. With
            ``units == "p.b."`` the data is multiplied by 100.
        **kwargs: Other chart-config keys. An optional ``legend`` mapping
            may provide ``ncol``; otherwise ``LINE_PLOT_CONFIG["ncol"]`` is
            used.
    """
    if base_100:
        data = data / data.iloc[0, :] * 100
    if format["units"] == "p.b.":
        data = data * 100

    legend_cfg = kwargs.get("legend")
    ncol = legend_cfg["ncol"] if legend_cfg is not None else LINE_PLOT_CONFIG["ncol"]

    fig = plt.figure(**FIG_CONFIG)
    try:
        ax = fig.add_subplot()
        data.plot(ax=ax)
        # `annotate` is not implemented yet.

        reference = 100 if base_100 else 0
        _style_spines(ax, **format, **AX_CONFIG["spines"])
        _style_baseline(ax, reference, **AX_CONFIG["baseline"])
        ax.legend(
            loc="upper center",
            bbox_to_anchor=(0.5, LINE_PLOT_CONFIG["legend_sep"]),
            ncol=ncol,
        )

        fig.savefig(out_name)
    finally:
        # pyplot keeps every figure alive until it is closed. Release it so
        # batch plotting does not accumulate figures.
        plt.close(fig)
101
+
102
+
103
def plot_line_charts(
    out_path: Path, data: pd.DataFrame, config_dicts: dict[str, Any]
):
    """Render one line chart per entry of ``config_dicts`` into ``out_path``.

    Each chart is saved as ``out_path / f"{name}.png"``. A chart config
    needs ``series`` (``{column: legend label}``) plus the keyword arguments
    of ``plot_line_chart``. ``start_date`` / ``end_date`` are optional;
    a missing or null bound leaves that side of the date slice open (the
    end defaults to the last index value).
    """
    for name, config in config_dicts.items():
        # pd.to_datetime(None) returns None, i.e. an open slice start.
        start_date: pd.Timestamp | None = pd.to_datetime(config.get("start_date"))
        end_date_str: str | None = config.get("end_date")
        end_date: pd.Timestamp = (
            data.index.max()
            if end_date_str is None
            else pd.to_datetime(end_date_str)
        )
        series: dict[str, str] = config["series"]
        trimmed_data: pd.DataFrame = data.loc[
            slice(start_date, end_date), list(series)
        ]
        trimmed_data = trimmed_data.rename(columns=series)
        out_name: Path = out_path / f"{name}.png"
        plot_line_chart(out_name, trimmed_data, **config)
121
+
122
+
123
class Format:
    """Number formatting for tick labels: units suffix and decimal places."""

    def __init__(self, units: str = "", decimals: int = 0):
        self.units = units
        self.decimals = decimals

    @classmethod
    def from_yaml(cls, loader: TemplateLoader, node: MappingNode) -> Self:
        """Build a Format from a YAML mapping node.

        ``flatten_mapping`` is called first (PyYAML's merge-key handling).
        An ``id`` key, if present, is discarded because it is not a
        constructor argument. A node without ``id`` is accepted instead of
        raising KeyError.
        """
        loader.flatten_mapping(node)
        format_cfg: dict[str, Any] = loader.construct_mapping(node, deep=True)
        format_cfg.pop("id", None)
        return cls(**format_cfg)
134
+
135
+
136
class Legend:
    """Legend layout options for a line chart.

    Attributes:
        ncol: Number of legend columns.
        sep: Vertical anchor offset for the legend. NOTE(review): in this
            version LinePlot.plot reads LINE_PLOT_CONFIG["legend_sep"]
            instead of this value.
    """

    def __init__(self, ncol: int = 5, sep: float = -0.125):
        self.ncol = ncol
        self.sep = sep

    @classmethod
    def from_yaml(cls, loader: TemplateLoader, node: MappingNode) -> Self:
        """Build a Legend from a YAML mapping node.

        An ``id`` key, if present, is discarded because it is not a
        constructor argument. A node without ``id`` is accepted instead of
        raising KeyError.
        """
        legend_cfg: dict[str, Any] = loader.construct_mapping(node, deep=True)
        legend_cfg.pop("id", None)
        return cls(**legend_cfg)
146
+
147
+
148
class LinePlot:
    """A line chart described by configuration and rendered to a PNG.

    The data file is read eagerly in ``__init__``. ``plot`` selects the
    configured series and date range, rescales them, draws and styles the
    chart, and saves it to ``out_path``.
    """

    def __init__(
        self,
        out_path: Path,
        data_path: Path,
        series: dict[str, str],
        scale: float = 1,
        start_date: datetime.datetime | None = None,
        end_date: datetime.datetime | None = None,
        base_100: bool = False,
        annotate: bool = False,
        baseline: bool = False,
        format: Format | None = None,
        legend: Legend | None = None,
    ) -> None:
        """
        Args:
            out_path: Output image; must have a ``.png`` suffix.
            data_path: Input data; must have a ``.feather`` suffix.
            series: ``{column: legend label}`` of the lines to draw.
            scale: Factor applied to the selected data before plotting.
            start_date: First date to plot (default: first index value).
            end_date: Last date to plot (default: last index value).
            base_100: Rebase every line so its first point equals 100.
            annotate: Not implemented yet.
            baseline: Mark the reference level (100 if ``base_100``,
                else 0).
            format: Tick-label format; ``Format()`` is used when None.
            legend: Legend layout; the legend is hidden when None.

        Raises:
            ValueError: if either path has the wrong suffix.
        """
        if out_path.suffix != ".png":
            raise ValueError(f"The out file {out_path} should be a .png file")
        self.out_path = out_path

        if data_path.suffix != ".feather":
            raise ValueError(
                f"The data file {data_path} must be a .feather file"
            )
        self.data = pd.read_feather(data_path)

        self.base_100 = base_100
        self.annotate = annotate  # unused for the moment
        self.format = format
        self.start_date = start_date
        self.end_date = end_date
        self.series = series
        self.legend = legend
        self.baseline = baseline
        self.scale = scale

    @classmethod
    def from_yaml(cls, loader: TemplateLoader, node: MappingNode) -> Self:
        """Build a LinePlot from a YAML mapping node.

        ``out_path`` / ``data_path`` strings are converted to ``Path``. An
        ``id`` key, if present, is discarded; a node without one is accepted.
        """
        line_plot_cfg: dict[str, Any] = loader.construct_mapping(
            node, deep=True
        )
        line_plot_cfg.pop("id", None)
        line_plot_cfg["out_path"] = Path(line_plot_cfg["out_path"])
        line_plot_cfg["data_path"] = Path(line_plot_cfg["data_path"])
        return cls(**line_plot_cfg)

    def plot(self) -> plt.Axes:
        """Draw the chart, save it to ``out_path`` and return its Axes."""
        start_date: pd.Timestamp = (
            self.data.index.min()
            if self.start_date is None
            else pd.to_datetime(self.start_date)
        )
        end_date: pd.Timestamp = (
            self.data.index.max()
            if self.end_date is None
            else pd.to_datetime(self.end_date)
        )

        plot_data: pd.DataFrame = self.data.loc[
            slice(start_date, end_date), list(self.series)
        ]
        plot_data = plot_data.rename(columns=self.series) * self.scale

        if self.base_100:  # maybe more flexible in the future
            plot_data = plot_data / plot_data.iloc[0, :] * 100

        fig = plt.figure(**FIG_CONFIG)
        ax = fig.add_subplot()
        plot_data.plot(ax=ax)
        # `annotate` is not implemented yet.

        # `format` is optional in the constructor; fall back to the Format
        # defaults instead of failing with AttributeError on None.
        tick_format = self.format if self.format is not None else Format()
        _style_spines(
            ax,
            decimals=tick_format.decimals,
            units=tick_format.units,
            **AX_CONFIG["spines"],
        )
        if self.baseline:
            reference = 100 if self.base_100 else 0
            _style_baseline(ax, reference, **AX_CONFIG["baseline"])

        if self.legend is not None:
            # NOTE(review): Legend.sep is not used; the vertical offset comes
            # from LINE_PLOT_CONFIG["legend_sep"]. Confirm which should win.
            ax.legend(
                loc="upper center",
                bbox_to_anchor=(0.5, LINE_PLOT_CONFIG["legend_sep"]),
                ncol=self.legend.ncol,
            )
        else:
            ax.legend().set_visible(False)

        fig.savefig(self.out_path)
        return ax