epyt-flow 0.10.0__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. epyt_flow/VERSION +1 -1
  2. epyt_flow/data/benchmarks/gecco_water_quality.py +2 -2
  3. epyt_flow/data/benchmarks/leakdb.py +40 -5
  4. epyt_flow/data/benchmarks/water_usage.py +4 -3
  5. epyt_flow/data/networks.py +27 -14
  6. epyt_flow/gym/__init__.py +0 -3
  7. epyt_flow/gym/scenario_control_env.py +11 -13
  8. epyt_flow/rest_api/scenario/control_handlers.py +118 -0
  9. epyt_flow/rest_api/scenario/event_handlers.py +114 -1
  10. epyt_flow/rest_api/scenario/handlers.py +33 -0
  11. epyt_flow/rest_api/server.py +14 -2
  12. epyt_flow/serialization.py +1 -0
  13. epyt_flow/simulation/__init__.py +0 -1
  14. epyt_flow/simulation/backend/__init__.py +1 -0
  15. epyt_flow/simulation/backend/my_epyt.py +1056 -0
  16. epyt_flow/simulation/events/actuator_events.py +7 -1
  17. epyt_flow/simulation/events/quality_events.py +3 -1
  18. epyt_flow/simulation/scada/scada_data.py +716 -5
  19. epyt_flow/simulation/scenario_config.py +1 -40
  20. epyt_flow/simulation/scenario_simulator.py +645 -119
  21. epyt_flow/simulation/sensor_config.py +18 -2
  22. epyt_flow/topology.py +24 -7
  23. epyt_flow/uncertainty/model_uncertainty.py +80 -62
  24. epyt_flow/uncertainty/sensor_noise.py +15 -4
  25. epyt_flow/uncertainty/uncertainties.py +71 -18
  26. epyt_flow/uncertainty/utils.py +40 -13
  27. epyt_flow/utils.py +45 -1
  28. epyt_flow/visualization/__init__.py +2 -0
  29. epyt_flow/visualization/scenario_visualizer.py +1240 -0
  30. epyt_flow/visualization/visualization_utils.py +738 -0
  31. {epyt_flow-0.10.0.dist-info → epyt_flow-0.12.0.dist-info}/METADATA +15 -4
  32. {epyt_flow-0.10.0.dist-info → epyt_flow-0.12.0.dist-info}/RECORD +35 -36
  33. {epyt_flow-0.10.0.dist-info → epyt_flow-0.12.0.dist-info}/WHEEL +1 -1
  34. epyt_flow/gym/control_gyms.py +0 -47
  35. epyt_flow/metrics.py +0 -466
  36. epyt_flow/models/__init__.py +0 -2
  37. epyt_flow/models/event_detector.py +0 -31
  38. epyt_flow/models/sensor_interpolation_detector.py +0 -118
  39. epyt_flow/simulation/scada/advanced_control.py +0 -138
  40. epyt_flow/simulation/scenario_visualizer.py +0 -1307
  41. {epyt_flow-0.10.0.dist-info → epyt_flow-0.12.0.dist-info/licenses}/LICENSE +0 -0
  42. {epyt_flow-0.10.0.dist-info → epyt_flow-0.12.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,738 @@
1
+ """
2
+ Module provides helper functions and data management classes for visualizing
3
+ scenarios.
4
+ """
5
import inspect
from dataclasses import dataclass, field
from typing import Optional, Union, List, Tuple

import matplotlib as mpl
import matplotlib.pyplot as plt
import networkx.drawing.nx_pylab as nxp
import numpy as np
from scipy.interpolate import CubicSpline

from ..serialization import COLOR_SCHEMES_ID, JsonSerializable, serializable
from ..simulation.scada.scada_data import ScadaData
17
+
18
# Named reductions accepted as the `statistic` argument of the `add_frame`
# methods below; each one is applied to the data with `axis=0`.
stat_funcs = dict(
    mean=np.mean,
    min=np.min,
    max=np.max,
)
24
+
25
+
26
@dataclass
class JunctionObject:
    """
    Represents a junction component (e.g. nodes, tanks, reservoirs, ...) in a
    water distribution network and manages all relevant attributes for drawing.

    Attributes
    ----------
    nodelist : `list`
        List of all nodes in WDN pertaining to this component type.
    pos : `dict`
        A dictionary mapping nodes to their coordinates in the correct format
        for drawing.
    node_shape : :class:`matplotlib.path.Path` or None
        A shape representing the object, if none, the networkx default circle
        is used.
    node_size : `int`, default = 10
        The size of each node.
    node_color : `str` or `list`, default = 'k'
        If `string`: the color for all nodes, if `list`: a list of lists
        containing a numerical value for each node per frame, which will be
        used for coloring.
    interpolated : `bool`, default = False
        Set to True, if node_colors are interpolated for smoother animation.
    """
    nodelist: list
    pos: dict
    node_shape: mpl.path.Path = None
    node_size: int = 10
    node_color: Union[str, list] = 'k'
    interpolated: bool = False

    @staticmethod
    def _select_frame(frames, frame_number: int):
        """
        Returns ``frames[frame_number]``; an index at or past the end falls
        back to the last frame (negative indices are passed through).
        """
        if frame_number >= len(frames):
            frame_number = -1
        return frames[frame_number]

    @staticmethod
    def _compute_statistic(values, statistic: str, pit):
        """
        Reduces `values` along axis 0 with one of `stat_funcs`, or selects
        the entry at index `pit` along axis 0 for the 'time_step' statistic.

        Raises
        ------
        ValueError
            If `statistic` is unknown or `pit` is missing for 'time_step'.
        """
        if statistic in stat_funcs:
            return stat_funcs[statistic](values, axis=0)
        if statistic == 'time_step':
            if not pit and pit != 0:
                raise ValueError(
                    'Please input point in time (pit) parameter when selecting'
                    ' time_step statistic')
            return np.take(values, pit, axis=0)
        raise ValueError(
            'Statistic parameter must be mean, min, max or time_step')

    @staticmethod
    def _group_into_intervals(stat_values, intervals):
        """
        Maps `stat_values` to group indices.

        If `intervals` is `None` the values are returned unchanged. If it is
        a number, it is truncated to an int and used as the number of equally
        wide groups between the minimum and maximum of `stat_values`. The
        result then holds indices in ``0 .. intervals - 1``. If it is a
        sequence, it lists boundary points and ``np.digitize(...) - 1`` is
        returned.

        Raises
        ------
        ValueError
            If `intervals` has an unsupported type or asks for < 1 group.
        """
        if intervals is None:
            return stat_values
        if isinstance(intervals, (int, float, np.integer, np.floating)):
            num_groups = int(intervals)
            if num_groups < 1:
                raise ValueError(
                    'Number of interval groups must be at least 1')
            boundaries = np.linspace(stat_values.min(), stat_values.max(),
                                     num_groups + 1)
            # np.digitize puts values equal to the last boundary (the maximum)
            # into an extra bin; clip so that exactly `num_groups` groups exist.
            return np.clip(np.digitize(stat_values, boundaries) - 1,
                           0, num_groups - 1)
        if isinstance(intervals, (list, tuple, np.ndarray)):
            return np.digitize(stat_values, intervals) - 1
        raise ValueError(
            'Intervals must be either number of groups or list of interval'
            ' boundary points')

    def add_frame(self, statistic: str, values: np.ndarray,
                  pit: int, intervals: Union[int, List[Union[int, float]]]):
        """
        Adds a new frame of node_color based on a given statistic.

        Parameters
        ----------
        statistic : `str`
            The statistic to calculate for the data. Can be 'mean', 'min',
            'max' or 'time_step'.
        values : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
            The node values over time as extracted from the scada data. The
            statistic is computed along axis 0.
        pit : `int`
            The point in time for the 'time_step' statistic.
        intervals : `int`, `list[int]` or `list[float]`
            If provided, the data will be grouped into intervals. It can be an
            integer specifying the number of groups or a list of boundary
            points.

        Raises
        ------
        ValueError
            If interval, pit or statistic is not correctly provided.
        """
        stat_values = self._compute_statistic(values, statistic, pit)
        stat_values = self._group_into_intervals(stat_values, intervals)

        # One value per node; zip stops at the shorter of nodelist/values.
        frame = [value for _, value in zip(self.nodelist, stat_values)]

        if isinstance(self.node_color, str):
            # First call: replace the fixed color by a list of frames and
            # start tracking the value range over all frames.
            self.node_color = []
            self.vmin = float('inf')
            self.vmax = float('-inf')

        self.node_color.append(frame)
        # `default=` keeps an empty frame (e.g. empty nodelist) from raising.
        self.vmin = min(self.vmin, min(frame, default=self.vmin))
        self.vmax = max(self.vmax, max(frame, default=self.vmax))

    def get_frame(self, frame_number: int = 0):
        """
        Returns all attributes necessary for networkx to draw the specified
        frame.

        Parameters
        ----------
        frame_number : `int`, default = 0
            The frame whose parameters should be returned. Default is 0, this
            is also used if only 1 frame exists (e.g. for plots, not
            animations). Numbers past the last frame select the last frame.

        Returns
        -------
        valid_params : `dict`
            A dictionary containing all attributes that function as parameters
            for `networkx.drawing.nx_pylab.draw_networkx_nodes() <https://networkx.org/documentation/stable/reference/generated/networkx.drawing.nx_pylab.draw_networkx_nodes.html#draw-networkx-nodes>`_.
        """
        attributes = vars(self).copy()

        if not isinstance(self.node_color, str):
            frames = (self.node_color_inter if self.interpolated
                      else self.node_color)
            attributes['node_color'] = self._select_frame(frames, frame_number)

        sig = inspect.signature(nxp.draw_networkx_nodes)
        return {
            key: value for key, value in attributes.items()
            if key in sig.parameters and value is not None
        }

    def get_frame_mask(self, mask, color):
        """
        Returns all attributes necessary for networkx to draw the specified
        frame mask. Meaning covering all masked junction objects with the
        default value.

        Parameters
        ----------
        mask: `np.ndarray`
            An array consisting of 0 and 1, where 0 means no sensor. Nodes
            without sensor are to be masked.
        color:
            The default color of masked nodes.

        Returns
        -------
        valid_params : `dict`
            A dictionary containing all attributes that function as parameters
            for `networkx.drawing.nx_pylab.draw_networkx_nodes() <https://networkx.org/documentation/stable/reference/generated/networkx.drawing.nx_pylab.draw_networkx_nodes.html#draw-networkx-nodes>`_.
        """
        attributes = vars(self).copy()

        # Keep only the nodes whose mask entry is falsy (no sensor).
        attributes['nodelist'] = [node for node, flag in
                                  zip(self.nodelist, mask) if not flag]
        attributes['node_color'] = color

        sig = inspect.signature(nxp.draw_networkx_nodes)
        return {
            key: value for key, value in attributes.items()
            if key in sig.parameters and value is not None
        }

    def interpolate(self, num_inter_frames: int):
        """
        Interpolates node_color values for smoother animations.

        Stores the result in `node_color_inter` (one row per interpolated
        frame) and sets `interpolated` to True. Does nothing if there is a
        fixed color or fewer than two frames.

        Parameters
        ----------
        num_inter_frames : `int`
            The number of total frames after interpolation.
        """
        if isinstance(self.node_color, str) or len(self.node_color) <= 1:
            return

        frames = np.asarray(self.node_color)
        steps = frames.shape[0]

        x_axis = np.linspace(0, steps - 1, steps)
        new_x_axis = np.linspace(0, steps - 1, num_inter_frames)

        # One cubic spline per node, fitted along the frame axis (axis 0).
        self.node_color_inter = CubicSpline(x_axis, frames, axis=0)(new_x_axis)
        self.interpolated = True

    def add_attributes(self, attributes: dict):
        """
        Adds the given attributes dict as class attributes.

        Parameters
        ----------
        attributes : `dict`
            Attributes dict, which is to be added as class attributes.
        """
        for key, value in attributes.items():
            setattr(self, key, value)
233
+
234
+
235
@dataclass
class EdgeObject:
    """
    Represents an edge component (pipes) in a water distribution network and
    manages all relevant attributes for drawing.

    Attributes
    ----------
    edgelist : `list`
        List of all edges in WDN pertaining to this component type.
    pos : `dict`
        A dictionary mapping pipes to their coordinates in the correct format
        for drawing.
    edge_color : `str` or `list`, default = 'k'
        If `string`: the color for all edges, if `list`: a list of lists
        containing a numerical value for each edge per frame, which will be
        used for coloring.
    interpolated : `dict`, default = {}
        Filled with interpolated frames (keyed by 'edge_color' / 'width') if
        the interpolation method is called. Each instance owns its own dict.
    """
    edgelist: list
    pos: dict
    edge_color: Union[str, list] = 'k'
    # Per-instance dict: a plain `= {}` class attribute would be shared (and
    # mutated) by every EdgeObject. Excluded from repr/eq as before.
    interpolated: dict = field(default_factory=dict, compare=False,
                               repr=False)

    @staticmethod
    def _select_frame(frames, frame_number: int):
        """
        Returns ``frames[frame_number]``; an index at or past the end falls
        back to the last frame (negative indices are passed through).
        """
        if frame_number >= len(frames):
            frame_number = -1
        return frames[frame_number]

    @staticmethod
    def _compute_statistic(values, statistic: str, pit):
        """
        Reduces `values` along axis 0 with one of `stat_funcs`, or selects
        the entry at index `pit` along axis 0 for the 'time_step' statistic.

        Raises
        ------
        ValueError
            If `statistic` is unknown or `pit` is missing for 'time_step'.
        """
        if statistic in stat_funcs:
            return stat_funcs[statistic](values, axis=0)
        if statistic == 'time_step':
            if not pit and pit != 0:
                raise ValueError(
                    'Please input point in time (pit) parameter when selecting'
                    ' time_step statistic')
            return np.take(values, pit, axis=0)
        raise ValueError(
            'Statistic parameter must be mean, min, max or time_step')

    @staticmethod
    def _group_into_intervals(stat_values, intervals):
        """
        Maps `stat_values` to group indices.

        If `intervals` is `None` the values are returned unchanged. If it is
        a number, it is truncated to an int and used as the number of equally
        wide groups between the minimum and maximum of `stat_values`. The
        result then holds indices in ``0 .. intervals - 1``. If it is a
        sequence, it lists boundary points and ``np.digitize(...) - 1`` is
        returned.

        Raises
        ------
        ValueError
            If `intervals` has an unsupported type or asks for < 1 group.
        """
        if intervals is None:
            return stat_values
        if isinstance(intervals, (int, float, np.integer, np.floating)):
            num_groups = int(intervals)
            if num_groups < 1:
                raise ValueError(
                    'Number of interval groups must be at least 1')
            boundaries = np.linspace(stat_values.min(), stat_values.max(),
                                     num_groups + 1)
            # np.digitize puts values equal to the last boundary (the maximum)
            # into an extra bin; clip so that exactly `num_groups` groups exist.
            return np.clip(np.digitize(stat_values, boundaries) - 1,
                           0, num_groups - 1)
        if isinstance(intervals, (list, tuple, np.ndarray)):
            return np.digitize(stat_values, intervals) - 1
        raise ValueError(
            'Intervals must be either number of groups or list of interval'
            ' boundary points')

    def rescale_widths(self, line_widths: Tuple[int, int] = (1, 2)):
        """
        Rescales all edge widths to the given interval.

        The minimum/maximum are taken over all frames, so widths stay
        comparable between frames.

        Parameters
        ----------
        line_widths : `Tuple[int]`, default = (1, 2)
            Min and max value, to which the edge widths should be scaled.

        Raises
        ------
        AttributeError
            If no edge width attribute exists yet.
        """
        if not hasattr(self, 'width'):
            raise AttributeError(
                'Please call add_frame with edge_param=width before rescaling'
                ' the widths.')

        vmin = min(min(frame) for frame in self.width)
        vmax = max(max(frame) for frame in self.width)

        self.width = [
            self.__rescale(frame, line_widths, values_min_max=(vmin, vmax))
            for frame in self.width
        ]

    def _get_link_values(self, scada_data, parameter: str, species: str,
                         use_sensor_data: bool):
        """
        Returns the raw per-link data for `parameter` (before any statistic
        is applied). When sensor data is used, also stores the matching
        sensor mask in `self.mask`.

        Raises
        ------
        ValueError
            If `parameter` is unknown or `species` is missing for
            'bulk_species_concentration'.
        """
        if parameter == 'flow_rate':
            if use_sensor_data:
                values, mask = scada_data.get_data_flows_as_edge_features()
                # NOTE(review): edge-feature arrays appear to hold two columns
                # per link; only every second column is kept — confirm
                # against ScadaData.
                self.mask = mask[::2]
                return values[:, ::2]
            return scada_data.flow_data_raw

        if parameter == 'link_quality':
            if use_sensor_data:
                values, mask = \
                    scada_data.get_data_links_quality_as_edge_features()
                self.mask = mask[::2]
                return values[:, ::2]
            return scada_data.link_quality_data_raw

        if parameter == 'custom_data':
            # The caller passes the raw values array directly via scada_data.
            return scada_data

        if parameter == 'bulk_species_concentration':
            if not species:
                raise ValueError('Species must be given when using '
                                 'bulk_species_concentration')
            if use_sensor_data:
                values, mask = \
                    scada_data.get_data_bulk_species_concentrations_as_edge_features()
                species_idx = scada_data.sensor_config.bulk_species.index(
                    species)
                self.mask = mask[::2, species_idx]
                return values[:, ::2, species_idx]
            raw = scada_data.bulk_species_link_concentration_raw
            species_idx = scada_data.sensor_config.bulk_species.index(species)
            return raw[:, species_idx, :]

        raise ValueError('Parameter must be flow_rate, link_quality, '
                         'bulk_species_concentration, diameter or '
                         'custom_data.')

    def _append_frame(self, edge_param: str, frame: list):
        """
        Appends `frame` to `width` (for 'edge_width') or to `edge_color`
        (otherwise), in the latter case also widening the tracked color
        range `edge_vmin`/`edge_vmax`.
        """
        if edge_param == 'edge_width':
            self.width.append(frame)
        else:
            self.edge_color.append(frame)
            # `default=` keeps an empty frame from raising.
            self.edge_vmin = min(self.edge_vmin,
                                 min(frame, default=self.edge_vmin))
            self.edge_vmax = max(self.edge_vmax,
                                 max(frame, default=self.edge_vmax))

    def add_frame(
            self, topology, edge_param: str,
            scada_data: Optional[ScadaData],
            parameter: str = 'flow_rate', statistic: str = 'mean',
            pit: Optional[Union[int, Tuple[int]]] = None,
            species: str = None,
            intervals: Optional[Union[int, List[Union[int, float]]]] = None,
            use_sensor_data: bool = None):
        """
        Adds a new frame of edge_color or edge width based on the given data
        and statistic.

        Parameters
        ----------
        topology : :class:`~epyt_flow.topology.NetworkTopology`
            Topology object retrieved from the scenario, containing the
            structure of the water distribution network. Only used for
            parameter 'diameter'.
        edge_param : `str`
            'edge_width' to add a width frame, 'edge_color' to add a color
            frame.
        scada_data : :class:`~epyt_flow.simulation.scada.scada_data.ScadaData`
            SCADA data created by the :class:`~epyt_flow.simulation.scenario_simulator.ScenarioSimulator`
            instance, is used to retrieve data for the next frame. For
            parameter 'custom_data' this is the raw data array itself.
        parameter : `str`, default = 'flow_rate'
            The link data to visualize. Options are 'flow_rate',
            'link_quality', 'bulk_species_concentration', 'diameter' or
            'custom_data'.
        statistic : `str`, default = 'mean'
            The statistic to calculate for the data. Can be 'mean', 'min',
            'max' or 'time_step'. Ignored for 'diameter'.
        pit : `int`
            The point in time for the 'time_step' statistic.
        species: `str`, optional
            Key of the species. Necessary only for parameter
            'bulk_species_concentration'.
        intervals : `int`, `list[int]` or `list[float]`
            If provided, the data will be grouped into intervals. It can be an
            integer specifying the number of groups or a list of boundary
            points. Ignored for 'diameter'.
        use_sensor_data : `bool`, optional
            If `True`, instead of using raw simulation data, the data recorded
            by the corresponding sensors in the system is used for the
            visualization. Note: Not all components may have a sensor attached
            and sensors may be subject to sensor faults or noise.

        Raises
        ------
        ValueError
            If parameter, species, interval, pit or statistic is not set
            correctly.
        """
        if edge_param == 'edge_width' and not hasattr(self, 'width'):
            self.width = []
        elif edge_param == 'edge_color' and isinstance(self.edge_color, str):
            # First color frame: replace the fixed color by a list of frames.
            self.edge_color = []
            self.edge_vmin = float('inf')
            self.edge_vmax = float('-inf')

        if parameter == 'diameter':
            # Static per-link property: no statistic or interval grouping.
            frame = [topology.get_link_info(link[0])['diameter']
                     for link in topology.get_all_links()]
            self._append_frame(edge_param, frame)
            return

        values = self._get_link_values(scada_data, parameter, species,
                                       use_sensor_data)
        stat_values = self._compute_statistic(values, statistic, pit)
        stat_values = self._group_into_intervals(stat_values, intervals)

        self._append_frame(edge_param, list(stat_values))

    def get_frame(self, frame_number: int = 0):
        """
        Returns all attributes necessary for networkx to draw the specified
        frame.

        Parameters
        ----------
        frame_number : `int`, default = 0
            The frame whose parameters should be returned. Default is 0, this
            is also used if only 1 frame exists (e.g. for plots, not
            animations). Numbers past the last frame select the last frame.

        Returns
        -------
        valid_params : `dict`
            A dictionary containing all attributes that function as parameters
            for `networkx.drawing.nx_pylab.draw_networkx_edges() <https://networkx.org/documentation/stable/reference/generated/networkx.drawing.nx_pylab.draw_networkx_edges.html#draw-networkx-edges>`_.
        """
        attributes = vars(self).copy()

        # Interpolated frames take precedence over the raw ones. Each
        # attribute is clamped independently.
        if not isinstance(self.edge_color, str):
            frames = self.interpolated.get('edge_color', self.edge_color)
            attributes['edge_color'] = self._select_frame(frames, frame_number)

        if hasattr(self, 'width'):
            frames = self.interpolated.get('width', self.width)
            attributes['width'] = self._select_frame(frames, frame_number)

        sig = inspect.signature(nxp.draw_networkx_edges)
        return {
            key: value for key, value in attributes.items()
            if key in sig.parameters and value is not None
        }

    def get_frame_mask(self, frame_number: int = 0, color='k'):
        """
        Returns all attributes necessary for networkx to draw the specified
        frame mask, i.e. the edges whose `self.mask` entry is falsy, drawn in
        `color`.

        Parameters
        ----------
        frame_number : `int`, default = 0
            The frame whose parameters should be returned. Default is 0, this
            is also used if only 1 frame exists (e.g. for plots, not
            animations). Numbers past the last frame select the last frame.
        color:
            The default color of masked edges.

        Returns
        -------
        valid_params : `dict`
            A dictionary containing all attributes that function as parameters
            for `networkx.drawing.nx_pylab.draw_networkx_edges() <https://networkx.org/documentation/stable/reference/generated/networkx.drawing.nx_pylab.draw_networkx_edges.html#draw-networkx-edges>`_.
        """
        attributes = vars(self).copy()

        attributes['edgelist'] = [edge for edge, flag in
                                  zip(self.edgelist, self.mask) if not flag]
        attributes['edge_color'] = color

        if hasattr(self, 'width'):
            frames = self.interpolated.get('width', self.width)
            widths = self._select_frame(frames, frame_number)
            # Keep the widths aligned with the filtered edgelist.
            attributes['width'] = [width for width, flag in
                                   zip(widths, self.mask) if not flag]

        sig = inspect.signature(nxp.draw_networkx_edges)
        return {
            key: value for key, value in attributes.items()
            if key in sig.parameters and value is not None
        }

    def interpolate(self, num_inter_frames: int):
        """
        Interpolates edge_color and width values for smoother animations.

        Results are stored in this instance's `interpolated` dict under
        'edge_color' / 'width'. Targets with a fixed color or fewer than two
        frames are skipped.

        Parameters
        ----------
        num_inter_frames : `int`
            The number of total frames after interpolation.
        """
        targets = {'edge_color': self.edge_color}
        if hasattr(self, 'width'):
            targets['width'] = self.width

        for name, frames in targets.items():
            if isinstance(frames, str) or len(frames) <= 1:
                continue

            data = np.asarray(frames)
            steps = data.shape[0]

            x_axis = np.linspace(0, steps - 1, steps)
            new_x_axis = np.linspace(0, steps - 1, num_inter_frames)

            # One cubic spline per edge, fitted along the frame axis (axis 0).
            self.interpolated[name] = CubicSpline(x_axis, data,
                                                  axis=0)(new_x_axis)

    def add_attributes(self, attributes: dict):
        """
        Adds the given attributes dict as class attributes.

        Parameters
        ----------
        attributes : `dict`
            Attributes dict, which is to be added as class attributes.
        """
        for key, value in attributes.items():
            setattr(self, key, value)

    def __rescale(self, values: Union[np.ndarray, list],
                  scale_min_max: Union[List, Tuple[int]],
                  values_min_max: Union[
                      List, Tuple[int, int]] = None) -> np.ndarray:
        """
        Rescales the given values to a new range.

        This method rescales an array of values to fit within a specified
        minimum and maximum scale range. Optionally, the minimum and maximum
        of the input values can be manually provided; otherwise, they are
        automatically determined from the data. If all values are equal, every
        value is mapped to the lower bound of the output range. If `self.mask`
        exists, entries whose mask is not 1 are set to 1.0.

        Parameters
        ----------
        values : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_ or `list`
            The array of numerical values to be rescaled.
        scale_min_max : `list` or `tuple`
            A list or tuple containing two elements: the minimum and maximum
            values of the desired output range.
        values_min_max : `list` or `tuple`, optional
            A list or tuple containing two elements: the minimum and maximum
            values of the input data. If not provided, they are computed from
            the input `values`. Default is `None`.

        Returns
        -------
        rescaled_values : `numpy.ndarray`
            The values rescaled to the range specified by `scale_min_max`.
        """
        if not values_min_max:
            min_val, max_val = min(values), max(values)
        else:
            min_val, max_val = values_min_max

        lower, upper = scale_min_max[0], scale_min_max[1]
        data = np.asarray(values, dtype=float)
        value_range = max_val - min_val

        if value_range == 0:
            # Constant input: avoid division by zero.
            rescaled_widths = np.full(data.shape, float(lower))
        else:
            rescaled_widths = lower + (data - min_val) / value_range * (
                upper - lower)

        if hasattr(self, 'mask'):
            rescaled_widths = np.where(self.mask == 1, rescaled_widths, 1.0)

        return rescaled_widths
616
+
617
+
618
@serializable(COLOR_SCHEMES_ID, ".epyt_flow_color_scheme")
class ColorScheme(JsonSerializable):
    """
    Colors used by :class:`~epyt_flow.visualization.ScenarioVisualizer` to
    draw the individual network component types.
    """

    def __init__(self, pipe_color: str, node_color: str, pump_color: str,
                 tank_color: str, reservoir_color: str,
                 valve_color: str) -> None:
        """
        Stores one color per network component type.

        Every color must be given in one of the string formats matplotlib
        understands:
        https://matplotlib.org/stable/users/explain/colors/colors.html#color-formats

        Parameters
        ----------
        pipe_color : str
            Color for pipes.
        node_color : str
            Color for nodes (junctions).
        pump_color : str
            Color for pumps.
        tank_color : str
            Color for tanks.
        reservoir_color : str
            Color for reservoirs.
        valve_color : str
            Color for valves.
        """
        self.pipe_color = pipe_color
        self.node_color = node_color
        self.pump_color = pump_color
        self.tank_color = tank_color
        self.reservoir_color = reservoir_color
        self.valve_color = valve_color
        super().__init__()

    def get_attributes(self) -> dict:
        """
        Collects the attributes to be serialized.

        Returns
        -------
        attr : `dict`
            The parent's attributes merged with every public, non-callable
            instance attribute of this object.
        """
        own_attributes = {}
        for name, value in self.__dict__.items():
            # Skip private/dunder entries and anything callable.
            if name.startswith("_") or callable(value):
                continue
            own_attributes[name] = value
        return super().get_attributes() | own_attributes

    def __eq__(self, other: any) -> bool:
        """
        Compares two color schemes component by component.

        Parameters
        ----------
        other : :class:`~epyt_flow.visualization.visualization_utils.ColorScheme`
            The color scheme to compare against.

        Returns
        -------
        bool
            True if `other` is a ColorScheme with identical colors for every
            component, False otherwise.
        """
        if not isinstance(other, ColorScheme):
            return False
        compared = ("pipe_color", "node_color", "pump_color", "tank_color",
                    "reservoir_color", "valve_color")
        return all(getattr(self, name) == getattr(other, name)
                   for name in compared)

    def __str__(self) -> str:
        """
        Builds a human-readable description of this color scheme.

        Returns
        -------
        str
            All six component colors in constructor-argument form.
        """
        description = (f"ColorScheme(pipe_color={self.pipe_color}, "
                       f"node_color={self.node_color}, "
                       f"pump_color={self.pump_color}, "
                       f"tank_color={self.tank_color}, "
                       f"reservoir_color={self.reservoir_color}, "
                       f"valve_color={self.valve_color})")
        return description
711
+
712
+
713
# Predefined color schemes.

# Blue pipes/nodes, magenta pumps, cyan tanks, green reservoirs, black valves.
epanet_colors = ColorScheme(
    pipe_color="#0403ee", node_color="#0403ee",
    pump_color="#fe00ff", tank_color="#02fffd",
    reservoir_color="#00ff00", valve_color="#000000",
)

# Dark pipes/nodes with amber pumps, grey-teal tanks, dark green reservoirs
# and rust-red valves.
epyt_flow_colors = ColorScheme(
    pipe_color="#29222f", node_color="#29222f",
    pump_color="#d79233", tank_color="#607b80",
    reservoir_color="#33483d", valve_color="#a3320b",
)

# Every component drawn in black.
black_colors = ColorScheme(
    pipe_color="#000000", node_color="#000000",
    pump_color="#000000", tank_color="#000000",
    reservoir_color="#000000", valve_color="#000000",
)