geometallurgy 0.4.12-py3-none-any.whl → 0.4.13-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. elphick/geomet/__init__.py +11 -11
  2. elphick/geomet/base.py +1133 -1133
  3. elphick/geomet/block_model.py +319 -358
  4. elphick/geomet/config/__init__.py +1 -1
  5. elphick/geomet/config/config_read.py +39 -39
  6. elphick/geomet/config/flowsheet_example_partition.yaml +31 -31
  7. elphick/geomet/config/flowsheet_example_simple.yaml +25 -25
  8. elphick/geomet/config/mc_config.yml +35 -35
  9. elphick/geomet/data/downloader.py +39 -39
  10. elphick/geomet/data/register.csv +12 -12
  11. elphick/geomet/datasets/__init__.py +2 -2
  12. elphick/geomet/datasets/datasets.py +47 -47
  13. elphick/geomet/datasets/downloader.py +40 -40
  14. elphick/geomet/datasets/register.csv +12 -12
  15. elphick/geomet/datasets/sample_data.py +196 -196
  16. elphick/geomet/extras.py +35 -35
  17. elphick/geomet/flowsheet/__init__.py +1 -1
  18. elphick/geomet/flowsheet/flowsheet.py +1216 -1216
  19. elphick/geomet/flowsheet/loader.py +99 -99
  20. elphick/geomet/flowsheet/operation.py +256 -256
  21. elphick/geomet/flowsheet/stream.py +39 -39
  22. elphick/geomet/interval_sample.py +641 -641
  23. elphick/geomet/io.py +379 -379
  24. elphick/geomet/plot.py +147 -147
  25. elphick/geomet/sample.py +28 -28
  26. elphick/geomet/utils/amenability.py +49 -49
  27. elphick/geomet/utils/block_model_converter.py +93 -93
  28. elphick/geomet/utils/components.py +136 -136
  29. elphick/geomet/utils/data.py +49 -49
  30. elphick/geomet/utils/estimates.py +108 -108
  31. elphick/geomet/utils/interp.py +193 -193
  32. elphick/geomet/utils/interp2.py +134 -134
  33. elphick/geomet/utils/layout.py +72 -72
  34. elphick/geomet/utils/moisture.py +61 -61
  35. elphick/geomet/utils/output.html +617 -0
  36. elphick/geomet/utils/pandas.py +378 -378
  37. elphick/geomet/utils/parallel.py +29 -29
  38. elphick/geomet/utils/partition.py +63 -63
  39. elphick/geomet/utils/size.py +51 -51
  40. elphick/geomet/utils/timer.py +80 -80
  41. elphick/geomet/utils/viz.py +56 -56
  42. elphick/geomet/validate.py.hide +176 -176
  43. {geometallurgy-0.4.12.dist-info → geometallurgy-0.4.13.dist-info}/LICENSE +21 -21
  44. {geometallurgy-0.4.12.dist-info → geometallurgy-0.4.13.dist-info}/METADATA +7 -5
  45. geometallurgy-0.4.13.dist-info/RECORD +49 -0
  46. {geometallurgy-0.4.12.dist-info → geometallurgy-0.4.13.dist-info}/WHEEL +1 -1
  47. geometallurgy-0.4.12.dist-info/RECORD +0 -48
  48. {geometallurgy-0.4.12.dist-info → geometallurgy-0.4.13.dist-info}/entry_points.txt +0 -0
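
To reproduce this comparison locally, here is a minimal sketch (assuming pip, Python 3.9+, and network access; the directory layout and the chosen target file are arbitrary, not part of the package) that downloads both wheels from PyPI and diffs one module between them:

import difflib
import subprocess
import sys
import zipfile
from pathlib import Path

# download each wheel from PyPI into its own directory (dependencies not needed)
for version in ("0.4.12", "0.4.13"):
    subprocess.run(
        [sys.executable, "-m", "pip", "download", f"geometallurgy=={version}",
         "--no-deps", "--only-binary=:all:", "-d", f"wheels/{version}"],
        check=True,
    )

def read_member(version: str, member: str) -> list[str]:
    # read a single file out of the downloaded wheel (a wheel is a zip archive)
    wheel = next(Path(f"wheels/{version}").glob("*.whl"))
    with zipfile.ZipFile(wheel) as zf:
        return zf.read(member).decode("utf-8").splitlines(keepends=True)

member = "elphick/geomet/flowsheet/flowsheet.py"
diff = difflib.unified_diff(
    read_member("0.4.12", member),
    read_member("0.4.13", member),
    fromfile=f"0.4.12/{member}",
    tofile=f"0.4.13/{member}",
)
print("".join(diff) or "no textual differences found")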
elphick/geomet/flowsheet/flowsheet.py
@@ -1,1216 +1,1216 @@
1
- import copy
2
- import json
3
- import logging
4
- import uuid
5
- from pathlib import Path
6
- from typing import Dict, List, Optional, Tuple, Union, TypeVar, TYPE_CHECKING
7
- import re
8
-
9
- import matplotlib
10
- import matplotlib.cm as cm
11
- import networkx as nx
12
- import numpy as np
13
- import pandas as pd
14
- import plotly.graph_objects as go
15
- import seaborn as sns
16
- import yaml
17
- from matplotlib import pyplot as plt
18
- from matplotlib.colors import ListedColormap, LinearSegmentedColormap
19
- from networkx.algorithms.dag import is_directed_acyclic_graph
20
- from plotly.subplots import make_subplots
21
-
22
- from elphick.geomet import Sample
23
- from elphick.geomet.base import MC
24
- from elphick.geomet.config.config_read import get_column_config
25
- from elphick.geomet.flowsheet.operation import NodeType, OP, PartitionOperation, Operation
26
- from elphick.geomet.plot import parallel_plot, comparison_plot
27
- from elphick.geomet.utils.layout import digraph_linear_layout
28
- from elphick.geomet.flowsheet.loader import streams_from_dataframe
29
-
30
- # if TYPE_CHECKING:
31
- from elphick.geomet.flowsheet.stream import Stream
32
-
33
- # generic type variable, used for type hinting that plays nicely with subclasses
34
- FS = TypeVar('FS', bound='Flowsheet')
35
-
36
-
37
- class Flowsheet:
38
- def __init__(self, name: str = 'Flowsheet'):
39
- self.graph: nx.DiGraph = nx.DiGraph(name=name)
40
- self._logger: logging.Logger = logging.getLogger(__class__.__name__)
41
-
42
- @property
43
- def name(self) -> str:
44
- return self.graph.name
45
-
46
- @name.setter
47
- def name(self, value: str):
48
- self.graph.name = value
49
-
50
- @property
51
- def healthy(self) -> bool:
52
- return self.all_nodes_healthy and self.all_streams_healthy
53
-
54
- @property
55
- def all_nodes_healthy(self) -> bool:
56
- bal_vals: List = [self.graph.nodes[n]['mc'].is_balanced for n in self.graph.nodes]
57
- bal_vals = [bv for bv in bal_vals if bv is not None]
58
- return all(bal_vals)
59
-
60
- @property
61
- def all_streams_healthy(self) -> bool:
62
- """Check if all streams are healthy"""
63
- # account for the fact that some edges may not have an mc object
64
- if not all([d['mc'] for u, v, d in self.graph.edges(data=True)]):
65
- return False
66
- return all([self.graph.edges[u, v]['mc'].status.ok for u, v in self.graph.edges])
67
-
68
- @classmethod
69
- def from_objects(cls, objects: list[MC],
70
- name: Optional[str] = 'Flowsheet') -> FS:
71
- """Instantiate from a list of objects
72
-
73
- This method is only suitable for objects that have the `_nodes` property set, such as objects that have
74
- been created from math operations, which preserve relationships between objects (via the nodes property)
75
-
76
- Args:
77
- objects: List of MassComposition objects, such as Sample, IntervalSample, Stream or BlockModel
78
- name: name of the flowsheet/network
79
-
80
- Returns:
81
-
82
- """
83
- from elphick.geomet.flowsheet.operation import Operation
84
- cls._check_indexes(objects)
85
- bunch_of_edges: list = []
86
- for stream in objects:
87
- if stream.nodes is None:
88
- raise KeyError(f'Stream {stream.name} does not have the node property set')
89
- nodes = stream.nodes
90
-
91
- # add the objects to the edges
92
- bunch_of_edges.append((nodes[0], nodes[1], {'mc': stream, 'name': stream.name}))
93
-
94
- graph = nx.DiGraph(name=name)
95
- graph.add_edges_from(bunch_of_edges)
96
- operation_objects: dict = {}
97
- for node in graph.nodes:
98
- operation_objects[node] = Operation(name=node)
99
- nx.set_node_attributes(graph, operation_objects, 'mc')
100
-
101
- for node in graph.nodes:
102
- operation_objects[node].inputs = [graph.get_edge_data(e[0], e[1])['mc'] for e in graph.in_edges(node)]
103
- operation_objects[node].outputs = [graph.get_edge_data(e[0], e[1])['mc'] for e in graph.out_edges(node)]
104
-
105
- graph = nx.convert_node_labels_to_integers(graph)
106
- # update the temporary nodes on the mc object property to match the renumbered integers
107
- for node1, node2, data in graph.edges(data=True):
108
- data['mc'].nodes = [node1, node2]
109
- # update the node names after renumbering
110
- for node in graph.nodes:
111
- graph.nodes[node]['mc'].name = str(node)
112
- obj = cls()
113
- obj.graph = graph
114
- return obj
115
-
116
- @classmethod
117
- def from_dataframe(cls, df: pd.DataFrame, name: Optional[str] = 'Flowsheet',
118
- mc_name_col: Optional[str] = None, n_jobs: int = 1) -> FS:
119
- """Instantiate from a DataFrame
120
-
121
- Args:
122
- df: The DataFrame
123
- name: name of the network
124
- mc_name_col: The column specified contains the names of objects to create.
125
- If None the DataFrame is assumed to be wide and the mc objects will be extracted from column prefixes.
126
- n_jobs: The number of parallel jobs to run. If -1, will use all available cores.
127
-
128
- Returns:
129
- Flowsheet: An instance of the Flowsheet class initialized from the provided DataFrame.
130
-
131
- """
132
- streams: list[Sample] = streams_from_dataframe(df=df, mc_name_col=mc_name_col, n_jobs=n_jobs)
133
- return cls().from_objects(objects=streams, name=name)
134
-
135
- @classmethod
136
- def from_dict(cls, config: dict) -> FS:
137
- """Create a flowsheet from a dictionary
138
-
139
- Args:
140
- config: dictionary containing the flowsheet configuration
141
-
142
- Returns:
143
- A Flowsheet object with no data on the edges
144
- """
145
-
146
- from elphick.geomet.flowsheet.operation import Operation
147
-
148
- if 'FLOWSHEET' not in config:
149
- raise ValueError("Dictionary does not contain 'FLOWSHEET' root node")
150
-
151
- flowsheet_config = config['FLOWSHEET']
152
-
153
- # create the Stream objects
154
- bunch_of_edges: list = []
155
- for stream, stream_config in flowsheet_config['streams'].items():
156
- bunch_of_edges.append(
157
- (stream_config['node_in'], stream_config['node_out'], {'mc': None, 'name': stream_config['name']}))
158
-
159
- graph = nx.DiGraph(name=flowsheet_config['flowsheet']['name'])
160
- graph.add_edges_from(bunch_of_edges)
161
- operation_objects: dict = {}
162
- for node in graph.nodes:
163
- # create the correct type of node object
164
- if node in flowsheet_config['operations']:
165
- operation_type = flowsheet_config['operations'][node].get('type', 'Operation')
166
- if operation_type == 'PartitionOperation':
167
- # get the output stream names from the graph
168
- output_stream_names = [d['name'] for u, v, d in graph.out_edges(node, data=True)]
169
- node_config = flowsheet_config['operations'][node]
170
- node_config['output_stream_names'] = output_stream_names
171
- operation_objects[node] = PartitionOperation.from_dict(node_config)
172
- else:
173
- operation_objects[node] = Operation.from_dict(flowsheet_config['operations'][node])
174
-
175
- # set the input and output streams on the operation object for the selected node
176
- operation_objects[node].inputs = [graph.get_edge_data(e[0], e[1])['mc'] for e in graph.in_edges(node)]
177
- operation_objects[node].outputs = [graph.get_edge_data(e[0], e[1])['mc'] for e in graph.out_edges(node)]
178
-
179
- nx.set_node_attributes(graph, operation_objects, 'mc')
180
-
181
- graph = nx.convert_node_labels_to_integers(graph)
182
-
183
- obj = cls()
184
- obj.graph = graph
185
-
186
- return obj
187
-
188
- # @classmethod
189
- # def from_dict_todo(cls, config: dict) -> FS:
190
- # TODO: This method is not yet implemented - fails because the Operations do not have inputs or outputs set.
191
- # flowsheet = cls()
192
- #
193
- # # Process streams
194
- # for stream_name, stream_data in config['FLOWSHEET']['streams'].items():
195
- # stream = Stream.from_dict(stream_data)
196
- # flowsheet.add_stream(stream)
197
- #
198
- # # Process operations
199
- # for operation_name, operation_data in config['FLOWSHEET']['operations'].items():
200
- # operation_type = operation_data.get('type', 'Operation')
201
- # if operation_type == 'PartitionOperation':
202
- # operation = PartitionOperation.from_dict(operation_data)
203
- # else:
204
- # operation = Operation.from_dict(operation_data)
205
- # flowsheet.add_operation(operation)
206
- #
207
- # return flowsheet
208
-
209
- @classmethod
210
- def from_yaml(cls, file_path: Path) -> FS:
211
- """Create a flowsheet from yaml
212
-
213
- Args:
214
- file_path: path to the yaml file
215
-
216
- Returns:
217
- A Flowsheet object with no data on the edges
218
- """
219
- with open(file_path, 'r') as file:
220
- config = yaml.safe_load(file)
221
-
222
- return cls.from_dict(config)
223
-
224
- @classmethod
225
- def from_json(cls, file_path: Path) -> FS:
226
- """Create a flowsheet from json
227
-
228
- Args:
229
- file_path: path to the json file
230
-
231
- Returns:
232
- A Flowsheet object with no data on the edges
233
- """
234
- with open(file_path, 'r') as file:
235
- config = json.load(file)
236
-
237
- return cls.from_dict(config)
238
-
239
- def add_stream(self, stream: 'Stream'):
240
- """Add a stream to the flowsheet."""
241
- self.graph.add_edge(stream.nodes[0], stream.nodes[1], mc=stream, name=stream.name)
242
-
243
- def add_operation(self, operation: 'Operation'):
244
- """Add an operation to the flowsheet."""
245
- self.graph.add_node(operation.name, mc=operation)
246
-
247
- def unhealthy_stream_records(self) -> pd.DataFrame:
248
- """Return on unhealthy streams
249
-
250
- Return the records for all streams that are not healthy.
251
- Returns:
252
- DataFrame: A DataFrame containing the unhealthy stream records
253
- """
254
- unhealthy_edges = [e for e in self.graph.edges if not self.graph.edges[e]['mc'].status.ok]
255
- unhealthy_data: pd.DataFrame = pd.concat(
256
- [self.graph.edges[e]['mc'].status.oor.assign(stream=self.graph.edges[e]['mc'].name) for e in
257
- unhealthy_edges], axis=1)
258
- # move the last column to the front
259
- unhealthy_data = unhealthy_data[[unhealthy_data.columns[-1]] + list(unhealthy_data.columns[:-1])]
260
-
261
- # append the flowsheet records for additional context
262
- records: pd.DataFrame = self.to_dataframe()
263
- records = records.unstack(level='name').swaplevel(axis=1).sort_index(axis=1, level=0, sort_remaining=False)
264
- records.columns = [f"{col[0]}_{col[1]}" for col in records.columns]
265
-
266
- result = unhealthy_data.merge(records, left_index=True, right_index=True, how='left')
267
- return result
268
-
269
- def unhealthy_node_records(self) -> pd.DataFrame:
270
- """Return unhealthy nodes
271
-
272
- Return the records for all nodes that are not healthy.
273
- Returns:
274
- DataFrame: A DataFrame containing the unhealthy node records
275
- """
276
- unhealthy_nodes = [n for n in self.graph.nodes if
277
- self.graph.nodes[n]['mc'].node_type == NodeType.BALANCE and not self.graph.nodes[n][
278
- 'mc'].is_balanced]
279
- unhealthy_data: pd.DataFrame = pd.concat(
280
- [self.graph.nodes[n]['mc'].unbalanced_records.assign(node=self.graph.nodes[n]['mc'].name) for n in
281
- unhealthy_nodes], axis=1)
282
- # move the last column to the front
283
- unhealthy_data = unhealthy_data[[unhealthy_data.columns[-1]] + list(unhealthy_data.columns[:-1])]
284
-
285
- # todo: append the streams around the node
286
-
287
- return unhealthy_data
288
-
289
- def copy_without_stream_data(self):
290
- """Copy without stream data"""
291
- new_flowsheet = Flowsheet(name=self.name)
292
- new_graph = nx.DiGraph()
293
-
294
- # Copy nodes with Operation objects
295
- for node, data in self.graph.nodes(data=True):
296
- new_data = data.copy()
297
- new_graph.add_node(node, **new_data)
298
-
299
- # Copy edges with mc attribute set to None
300
- for u, v, data in self.graph.edges(data=True):
301
- new_data = {k: (None if k == 'mc' else copy.deepcopy(v)) for k, v in data.items()}
302
- new_graph.add_edge(u, v, **new_data)
303
-
304
- new_flowsheet.graph = new_graph
305
- return new_flowsheet
306
-
307
- def solve(self):
308
- """Solve missing streams"""
309
-
310
- if not is_directed_acyclic_graph(self.graph):
311
- self._logger.error("Graph is not a Directed Acyclic Graph (DAG), so cannot be solved.")
312
- self._logger.debug(f"Graph nodes: {self.graph.nodes(data=True)}")
313
- self._logger.debug(f"Graph edges: {self.graph.edges(data=True)}")
314
- raise ValueError("Graph is not a Directed Acyclic Graph (DAG), so cannot be solved.")
315
-
316
- # Check the number of missing mc's on edges in the network
317
- missing_count: int = sum([1 for u, v, d in self.graph.edges(data=True) if d['mc'] is None])
318
- prev_missing_count = missing_count + 1 # Initialize with a value greater than missing_count
319
-
320
- while 0 < missing_count < prev_missing_count:
321
- prev_missing_count = missing_count
322
- for node in nx.topological_sort(self.graph):
323
- if self.graph.nodes[node]['mc'].node_type == NodeType.BALANCE:
324
- if self.graph.nodes[node]['mc'].has_empty_input:
325
- mc: MC = self.graph.nodes[node]['mc'].solve()
326
- # copy the solved object to the empty input edges
327
- for predecessor in self.graph.predecessors(node):
328
- edge_data = self.graph.get_edge_data(predecessor, node)
329
- if edge_data and edge_data['mc'] is None:
330
- edge_data['mc'] = mc
331
- edge_data['mc'].name = edge_data['name']
332
- self.set_operation_data(predecessor)
333
-
334
- if self.graph.nodes[node]['mc'].has_empty_output:
335
- # There are two cases to be managed, 1. a single output missing,
336
- # 2. a partition operation that returns two outputs
337
- if isinstance(self.graph.nodes[node]['mc'], PartitionOperation):
338
- partition_stream: str = self.graph.nodes[node]['mc'].partition['partition_stream']
339
- mc1, mc2 = self.graph.nodes[node]['mc'].solve()
340
- # copy the solved object to the empty output edges
341
- for successor in self.graph.successors(node):
342
- edge_data = self.graph.get_edge_data(node, successor)
343
- if edge_data and edge_data['mc'] is None:
344
- edge_data['mc'] = mc1 if edge_data['name'] == partition_stream else mc2
345
- edge_data['mc'].name = edge_data['name']
346
- self.set_operation_data(successor)
347
-
348
- else:
349
- mc: MC = self.graph.nodes[node]['mc'].solve()
350
- # copy the solved object to the empty output edges
351
- for successor in self.graph.successors(node):
352
- edge_data = self.graph.get_edge_data(node, successor)
353
- if edge_data and edge_data['mc'] is None:
354
- edge_data['mc'] = mc
355
- edge_data['mc'].name = edge_data['name']
356
- self.set_operation_data(successor)
357
- self.set_operation_data(node)
358
-
359
- missing_count: int = sum([1 for u, v, d in self.graph.edges(data=True) if d['mc'] is None])
360
- self._logger.info(f"Missing count: {missing_count}")
361
-
362
- if missing_count > 0:
363
- self._logger.error(f"Failed to solve the flowsheet. Missing count: {missing_count}")
364
- raise ValueError(
365
- f"Failed to solve the flowsheet. Some streams are still missing. Missing count: {missing_count}")
366
-
367
- def query(self, expr: str, stream_name: Optional[str] = None, inplace=False) -> 'Flowsheet':
368
- """Reduce the Flowsheet Stream records with a query
369
-
370
- Args:
371
- expr: The query string to apply to all streams. The query is applied in place. The LHS of the
372
- expression requires a prefix that defines the stream name e.g. stream_name.var_name > 0.5
373
- stream_name: The name of the stream to apply the query to. If None, the query is applied to the
374
- first input stream.
375
- inplace: If True, apply the query in place on the same object, otherwise return a new instance.
376
-
377
- Returns:
378
- A Flowsheet object where the stream records conform to the query
379
- """
380
- if stream_name is None:
381
- input_stream: MC = self.get_input_streams()[0]
382
- else:
383
- input_stream: MC = self.get_stream_by_name(name=stream_name)
384
- filtered_index: pd.Index = input_stream.data.query(expr).index
385
- return self._filter(filtered_index, inplace)
386
-
387
- def filter_by_index(self, index: pd.Index, inplace: bool = False) -> 'Flowsheet':
388
- """Filter the Flowsheet Stream records by a given index.
389
-
390
- Args:
391
- index: The index to filter the data.
392
- inplace: If True, apply the filter in place on the same object, otherwise return a new instance.
393
-
394
- Returns:
395
- A Flowsheet object where the stream records are filtered by the given index.
396
- """
397
- return self._filter(index, inplace)
398
-
399
- def _filter(self, index: pd.Index, inplace: bool = False) -> 'Flowsheet':
400
- """Private method to filter the Flowsheet Stream records by a given index.
401
-
402
- Args:
403
- index: The index to filter the data.
404
- inplace: If True, apply the filter in place on the same object, otherwise return a new instance.
405
-
406
- Returns:
407
- A Flowsheet object where the stream records are filtered by the given index.
408
- """
409
- if inplace:
410
- for u, v, d in self.graph.edges(data=True):
411
- if d.get('mc') is not None:
412
- d.get('mc').filter_by_index(index)
413
- return self
414
- else:
415
- obj: Flowsheet = self.copy_without_stream_data()
416
- for u, v, d in self.graph.edges(data=True):
417
- if d.get('mc') is not None:
418
- mc: MC = d.get('mc')
419
- mc_new = mc.__class__(name=mc.name)
420
- # Copy each attribute
421
- for attr, value in mc.__dict__.items():
422
- if attr in ['_mass_data', '_supplementary_data'] and value is not None:
423
- value = value.loc[index]
424
- setattr(mc_new, attr, copy.deepcopy(value))
425
- mc_new.aggregate = mc_new.weight_average()
426
- obj.graph[u][v]['mc'] = mc_new
427
- return obj
428
-
429
- def get_input_streams(self) -> list[MC]:
430
- """Get the input (feed) streams (edge objects)
431
-
432
- Returns:
433
- List of MassComposition-like objects
434
- """
435
-
436
- # Create a dictionary that maps node names to their degrees
437
- degrees = {n: d for n, d in self.graph.degree()}
438
-
439
- res: list[MC] = [d['mc'] for u, v, d in self.graph.edges(data=True) if degrees[u] == 1]
440
- if not res:
441
- raise ValueError("No input streams found")
442
- return res
443
-
444
- def get_output_streams(self) -> list[MC]:
445
- """Get the output (product) streams (edge objects)
446
-
447
- Returns:
448
- List of MassComposition-like objects
449
- """
450
-
451
- # Create a dictionary that maps node names to their degrees
452
- degrees = {n: d for n, d in self.graph.degree()}
453
-
454
- res: list[MC] = [d['mc'] for u, v, d in self.graph.edges(data=True) if degrees[v] == 1]
455
- if not res:
456
- raise ValueError("No output streams found")
457
- return res
458
-
459
- @staticmethod
460
- def _check_indexes(streams):
461
-
462
- list_of_indexes = [s._mass_data.index for s in streams]
463
- types_of_indexes = [type(i) for i in list_of_indexes]
464
- # check the index types are consistent
465
- if len(set(types_of_indexes)) != 1:
466
- raise KeyError("stream index types are not consistent")
467
-
468
- def plot(self, orientation: str = 'horizontal') -> plt.Figure:
469
- """Plot the network with matplotlib
470
-
471
- Args:
472
- orientation: 'horizontal'|'vertical' network layout
473
-
474
- Returns:
475
-
476
- """
477
-
478
- hf, ax = plt.subplots()
479
- # pos = nx.spring_layout(self, seed=1234)
480
- pos = digraph_linear_layout(self.graph, orientation=orientation)
481
-
482
- edge_labels: Dict = {}
483
- edge_colors: List = []
484
- node_colors: List = []
485
-
486
- for node1, node2, data in self.graph.edges(data=True):
487
- edge_labels[(node1, node2)] = data['mc'].name if data['mc'] is not None else data['name']
488
- if data['mc'] and data['mc'].status.ok:
489
- edge_colors.append('gray')
490
- else:
491
- edge_colors.append('red')
492
-
493
- for n in self.graph.nodes:
494
- if self.graph.nodes[n]['mc'].node_type == NodeType.BALANCE:
495
- if self.graph.nodes[n]['mc'].is_balanced:
496
- node_colors.append('green')
497
- else:
498
- node_colors.append('red')
499
- else:
500
- node_colors.append('gray')
501
-
502
- nx.draw(self.graph, pos=pos, ax=ax, with_labels=True, font_weight='bold',
503
- node_color=node_colors, edge_color=edge_colors)
504
-
505
- nx.draw_networkx_edge_labels(self.graph, pos=pos, ax=ax, edge_labels=edge_labels, font_color='black')
506
- ax.set_title(self._plot_title(html=False), fontsize=10)
507
-
508
- return hf
509
-
510
- def _plot_title(self, html: bool = True, compact: bool = False):
511
- # title = self.name
512
- title = (f"{self.name}<br><sup>Nodes Healthy: "
513
- f"<span style='color: {'red' if not self.all_nodes_healthy else 'black'}'>{self.all_nodes_healthy}</span>, "
514
- f"Streams Healthy: "
515
- f"<span style='color: {'red' if not self.all_streams_healthy else 'black'}'>{self.all_streams_healthy}</span></sup>")
516
- # if compact:
517
- # title = title.replace("<br><br>", "<br>").replace("<br>Edge", ", Edge")
518
- # if not self.edge_status[0]:
519
- # title = title.replace("</sup>", "") + f", {self.edge_status[1]}</sup>"
520
- if not html:
521
- title = title.replace('<br><br>', '\n').replace('<br>', '\n').replace('<sup>', '').replace('</sup>', '')
522
- title = re.sub(r'<span style=.*?>(.*?)</span>', r'\1', title)
523
- return title
524
-
525
- def report(self, apply_formats: bool = False) -> pd.DataFrame:
526
- """Summary Report
527
-
528
- Total Mass and weight averaged composition
529
- Returns:
530
-
531
- """
532
- chunks: List[pd.DataFrame] = []
533
- for n, nbrs in self.graph.adj.items():
534
- for nbr, eattr in nbrs.items():
535
- if eattr['mc'] is None or eattr['mc'].data.empty:
536
- edge_name: str = eattr['mc']['name'] if eattr['mc'] is not None else eattr['name']
537
- raise KeyError(f"Cannot generate report on empty dataset: {edge_name}")
538
- chunks.append(eattr['mc'].aggregate.assign(name=eattr['mc'].name))
539
- rpt: pd.DataFrame = pd.concat(chunks, axis='index').set_index('name')
540
- if apply_formats:
541
- fmts: Dict = self._get_column_formats(rpt.columns)
542
- for k, v in fmts.items():
543
- rpt[k] = rpt[k].apply((v.replace('%', '{:,') + '}').format)
544
- return rpt
545
-
546
- def _get_column_formats(self, columns: List[str], strip_percent: bool = False) -> Dict[str, str]:
547
- """
548
-
549
- Args:
550
- columns: The columns to lookup format strings for
551
- strip_percent: If True remove the leading % symbol from the format (for plotly tables)
552
-
553
- Returns:
554
-
555
- """
556
- strm = self.get_input_streams()[0]
557
- d_format: dict = get_column_config(config_dict=strm.config, var_map=strm.variable_map, config_key='format')
558
-
559
- if strip_percent:
560
- d_format = {k: v.strip('%') for k, v in d_format.items()}
561
-
562
- return d_format
563
-
564
- def plot_balance(self, facet_col_wrap: int = 3,
565
- color: Optional[str] = 'node') -> go.Figure:
566
- """Plot input versus output across all nodes in the network
567
-
568
- Args:
569
- facet_col_wrap: the number of subplots per row before wrapping
570
- color: The optional variable to color by. If None color will be by Node
571
-
572
- Returns:
573
-
574
- """
575
- # prepare the data
576
- chunks_in: List = []
577
- chunks_out: List = []
578
- for n in self.graph.nodes:
579
- if self.graph.nodes[n]['mc'].node_type == NodeType.BALANCE:
580
- chunks_in.append(self.graph.nodes[n]['mc'].add('in').assign(**{'direction': 'in', 'node': n}))
581
- chunks_out.append(self.graph.nodes[n]['mc'].add('out').assign(**{'direction': 'out', 'node': n}))
582
- df_in: pd.DataFrame = pd.concat(chunks_in)
583
- index_names = ['direction', 'node'] + df_in.index.names
584
- df_in = df_in.reset_index().melt(id_vars=index_names)
585
- df_out: pd.DataFrame = pd.concat(chunks_out).reset_index().melt(id_vars=index_names)
586
- df_plot: pd.DataFrame = pd.concat([df_in, df_out])
587
- df_plot = df_plot.set_index(index_names + ['variable'], append=True).unstack(['direction'])
588
- df_plot.columns = df_plot.columns.droplevel(0)
589
- df_plot.reset_index(level=list(np.arange(-1, -len(index_names) - 1, -1)), inplace=True)
590
- df_plot['node'] = pd.Categorical(df_plot['node'])
591
-
592
- # plot
593
- fig = comparison_plot(data=df_plot,
594
- x='in', y='out',
595
- facet_col_wrap=facet_col_wrap,
596
- color=color)
597
- return fig
598
-
599
- def plot_network(self, orientation: str = 'horizontal') -> go.Figure:
600
- """Plot the network with plotly
601
-
602
- Args:
603
- orientation: 'horizontal'|'vertical' network layout
604
-
605
- Returns:
606
-
607
- """
608
- # pos = nx.spring_layout(self, seed=1234)
609
- pos = digraph_linear_layout(self.graph, orientation=orientation)
610
-
611
- edge_traces, node_trace, edge_annotation_trace = self._get_scatter_node_edges(pos)
612
- title = self._plot_title()
613
-
614
- fig = go.Figure(data=[*edge_traces, node_trace, edge_annotation_trace],
615
- layout=go.Layout(
616
- title=title,
617
- titlefont_size=16,
618
- showlegend=False,
619
- hovermode='closest',
620
- margin=dict(b=20, l=5, r=5, t=40),
621
- xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
622
- yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
623
- paper_bgcolor='rgba(0,0,0,0)',
624
- plot_bgcolor='rgba(0,0,0,0)'
625
- ),
626
- )
627
- # for k, d_args in edge_annotations.items():
628
- # fig.add_annotation(x=d_args['pos'][0], y=d_args['pos'][1], text=k, textangle=d_args['angle'])
629
-
630
- return fig
631
-
632
- def plot_sankey(self,
633
- width_var: str = 'mass_dry',
634
- color_var: Optional[str] = None,
635
- edge_colormap: Optional[str] = 'copper_r',
636
- vmin: Optional[float] = None,
637
- vmax: Optional[float] = None,
638
- ) -> go.Figure:
639
- """Plot the Network as a sankey
640
-
641
- Args:
642
- width_var: The variable that determines the sankey width
643
- color_var: The optional variable that determines the sankey edge color
644
- edge_colormap: The optional colormap. Used with color_var.
645
- vmin: The value that maps to the minimum color
646
- vmax: The value that maps to the maximum color
647
-
648
- Returns:
649
-
650
- """
651
- # Create a mapping of node names to indices, and the integer nodes
652
- node_indices = {node: index for index, node in enumerate(self.graph.nodes)}
653
- int_graph = nx.relabel_nodes(self.graph, node_indices)
654
-
655
- # Generate the sankey diagram arguments using the new graph with integer nodes
656
- d_sankey = self._generate_sankey_args(int_graph, color_var, edge_colormap, width_var, vmin, vmax)
657
-
658
- # Create the sankey diagram
659
- node, link = self._get_sankey_node_link_dicts(d_sankey)
660
- fig = go.Figure(data=[go.Sankey(node=node, link=link)])
661
- title = self._plot_title()
662
- fig.update_layout(title_text=title, font_size=10)
663
- return fig
664
-
665
- def table_plot(self,
666
- plot_type: str = 'sankey',
667
- cols_exclude: Optional[List] = None,
668
- table_pos: str = 'left',
669
- table_area: float = 0.4,
670
- table_header_color: str = 'cornflowerblue',
671
- table_odd_color: str = 'whitesmoke',
672
- table_even_color: str = 'lightgray',
673
- sankey_width_var: str = 'mass_dry',
674
- sankey_color_var: Optional[str] = None,
675
- sankey_edge_colormap: Optional[str] = 'copper_r',
676
- sankey_vmin: Optional[float] = None,
677
- sankey_vmax: Optional[float] = None,
678
- network_orientation: Optional[str] = 'horizontal'
679
- ) -> go.Figure:
680
- """Plot with table of edge averages
681
-
682
- Args:
683
- plot_type: The type of plot ['sankey', 'network']
684
- cols_exclude: List of columns to exclude from the table
685
- table_pos: Position of the table ['left', 'right', 'top', 'bottom']
686
- table_area: The proportion of width or height to allocate to the table [0, 1]
687
- table_header_color: Color of the table header
688
- table_odd_color: Color of the odd table rows
689
- table_even_color: Color of the even table rows
690
- sankey_width_var: If plot_type is sankey, the variable that determines the sankey width
691
- sankey_color_var: If plot_type is sankey, the optional variable that determines the sankey edge color
692
- sankey_edge_colormap: If plot_type is sankey, the optional colormap. Used with sankey_color_var.
693
- sankey_vmin: The value that maps to the minimum color
694
- sankey_vmax: The value that maps to the maximum color
695
- network_orientation: The orientation of the network layout 'vertical'|'horizontal'
696
-
697
- Returns:
698
-
699
- """
700
-
701
- valid_plot_types: List[str] = ['sankey', 'network']
702
- if plot_type not in valid_plot_types:
703
- raise ValueError(f'The supplied plot_type is not in {valid_plot_types}')
704
-
705
- valid_table_pos: List[str] = ['top', 'bottom', 'left', 'right']
706
- if table_pos not in valid_table_pos:
707
- raise ValueError(f'The supplied table_pos is not in {valid_table_pos}')
708
-
709
- d_subplot, d_table, d_plot = self._get_position_kwargs(table_pos, table_area, plot_type)
710
-
711
- fig = make_subplots(**d_subplot, print_grid=False)
712
-
713
- df: pd.DataFrame = self.report().reset_index()
714
- if cols_exclude:
715
- df = df[[col for col in df.columns if col not in cols_exclude]]
716
- fmt: List[str] = ['%s'] + list(self._get_column_formats(df.columns, strip_percent=True).values())
717
- column_widths = [2] + [1] * (len(df.columns) - 1)
718
-
719
- fig.add_table(
720
- header=dict(values=list(df.columns),
721
- fill_color=table_header_color,
722
- align='center',
723
- font=dict(color='black', size=12)),
724
- columnwidth=column_widths,
725
- cells=dict(values=df.transpose().values.tolist(),
726
- align='left', format=fmt,
727
- fill_color=[
728
- [table_odd_color if i % 2 == 0 else table_even_color for i in range(len(df))] * len(
729
- df.columns)]),
730
- **d_table)
731
-
732
- if plot_type == 'sankey':
733
- # Create a mapping of node names to indices, and the integer nodes
734
- node_indices = {node: index for index, node in enumerate(self.graph.nodes)}
735
- int_graph = nx.relabel_nodes(self.graph, node_indices)
736
-
737
- # Generate the sankey diagram arguments using the new graph with integer nodes
738
- d_sankey = self._generate_sankey_args(int_graph, sankey_color_var,
739
- sankey_edge_colormap,
740
- sankey_width_var,
741
- sankey_vmin,
742
- sankey_vmax)
743
- node, link = self._get_sankey_node_link_dicts(d_sankey)
744
- fig.add_trace(go.Sankey(node=node, link=link), **d_plot)
745
-
746
- elif plot_type == 'network':
747
- # pos = nx.spring_layout(self, seed=1234)
748
- pos = digraph_linear_layout(self.graph, orientation=network_orientation)
749
-
750
- edge_traces, node_trace, edge_annotation_trace = self._get_scatter_node_edges(pos)
751
- fig.add_traces(data=[*edge_traces, node_trace, edge_annotation_trace], **d_plot)
752
-
753
- fig.update_layout(showlegend=False, hovermode='closest',
754
- xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
755
- yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
756
- paper_bgcolor='rgba(0,0,0,0)',
757
- plot_bgcolor='rgba(0,0,0,0)'
758
- )
759
-
760
- title = self._plot_title(compact=True)
761
- fig.update_layout(title_text=title, font_size=12)
762
-
763
- return fig
764
-
765
- def to_dataframe(self, stream_names: Optional[list[str]] = None, tidy: bool = True,
766
- as_mass: bool = False) -> pd.DataFrame:
767
- """Return a tidy dataframe
768
-
769
- Adds the mc name to the index so indexes are unique.
770
-
771
- Args:
772
- stream_names: Optional List of names of Stream/MassComposition objects (network edges) for export
773
- tidy: If True, the data will be returned in a tidy format, otherwise wide
774
- as_mass: If True, the mass data will be returned instead of the mass-composition data
775
-
776
- Returns:
777
-
778
- """
779
- chunks: List[pd.DataFrame] = []
780
- for u, v, data in self.graph.edges(data=True):
781
- if (stream_names is None) or ((stream_names is not None) and (data['mc'].name in stream_names)):
782
- if as_mass:
783
- chunks.append(data['mc'].mass_data.assign(name=data['mc'].name))
784
- else:
785
- chunks.append(data['mc'].data.assign(name=data['mc'].name))
786
-
787
- results: pd.DataFrame = pd.concat(chunks, axis='index').set_index('name', append=True)
788
- if not tidy: # wide format
789
- results = results.unstack(level='name')
790
- column_order: list[str] = [f'{name}_{attr}' for name in results.columns.levels[1] for attr in
791
- results.columns.levels[0]]
792
- results.columns = [f'{col[1]}_{col[0]}' for col in results.columns]
793
- results = results[column_order]
794
-
795
- return results
796
-
797
- def plot_parallel(self,
798
- names: Optional[str] = None,
799
- color: Optional[str] = None,
800
- vars_include: Optional[List[str]] = None,
801
- vars_exclude: Optional[List[str]] = None,
802
- title: Optional[str] = None,
803
- include_dims: Optional[Union[bool, List[str]]] = True,
804
- plot_interval_edges: bool = False) -> go.Figure:
805
- """Create an interactive parallel plot
806
-
807
- Useful to explore multidimensional data like mass-composition data
808
-
809
- Args:
810
- names: Optional List of Names to plot
811
- color: Optional color variable
812
- vars_include: Optional List of variables to include in the plot
813
- vars_exclude: Optional List of variables to exclude in the plot
814
- title: Optional plot title
815
- include_dims: Optional boolean or list of dimension to include in the plot. True will show all dims.
816
- plot_interval_edges: If True, interval edges will be plotted instead of interval mid
817
-
818
- Returns:
819
-
820
- """
821
- df: pd.DataFrame = self.to_dataframe(stream_names=names)
822
-
823
- if not title and hasattr(self, 'name'):
824
- title = self.name
825
-
826
- fig = parallel_plot(data=df, color=color, vars_include=vars_include, vars_exclude=vars_exclude, title=title,
827
- include_dims=include_dims, plot_interval_edges=plot_interval_edges)
828
- return fig
829
-
830
- def _generate_sankey_args(self, int_graph, color_var, edge_colormap, width_var, v_min, v_max):
831
- rpt: pd.DataFrame = self.report()
832
- if color_var is not None:
833
- cmap = sns.color_palette(edge_colormap, as_cmap=True)
834
- rpt: pd.DataFrame = self.report()
835
- if not v_min:
836
- v_min = np.floor(rpt[color_var].min())
837
- if not v_max:
838
- v_max = np.ceil(rpt[color_var].max())
839
-
840
- # run the report for the hover data
841
- d_custom_data: Dict = self._rpt_to_html(df=rpt)
842
- source: List = []
843
- target: List = []
844
- value: List = []
845
- edge_custom_data = []
846
- edge_color: List = []
847
- edge_labels: List = []
848
- node_colors: List = []
849
- node_labels: List = []
850
-
851
- for n in int_graph.nodes:
852
- node_labels.append(int_graph.nodes[n]['mc'].name)
853
-
854
- if int_graph.nodes[n]['mc'].node_type == NodeType.BALANCE:
855
- if int_graph.nodes[n]['mc'].is_balanced:
856
- node_colors.append('green')
857
- else:
858
- node_colors.append('red')
859
- else:
860
- node_colors.append('blue')
861
-
862
- for u, v, data in int_graph.edges(data=True):
863
- edge_labels.append(data['mc'].name)
864
- source.append(u)
865
- target.append(v)
866
- value.append(float(data['mc'].aggregate[width_var].iloc[0]))
867
- edge_custom_data.append(d_custom_data[data['mc'].name])
868
-
869
- if color_var is not None:
870
- val: float = float(data['mc'].aggregate[color_var].iloc[0])
871
- str_color: str = f'rgba{self._color_from_float(v_min, v_max, val, cmap)}'
872
- edge_color.append(str_color)
873
- else:
874
- edge_color: Optional[str] = None
875
-
876
- d_sankey: Dict = {'node_color': node_colors,
877
- 'edge_color': edge_color,
878
- 'edge_custom_data': edge_custom_data,
879
- 'edge_labels': edge_labels,
880
- 'labels': node_labels,
881
- 'source': source,
882
- 'target': target,
883
- 'value': value}
884
-
885
- return d_sankey
886
-
887
- @staticmethod
888
- def _get_sankey_node_link_dicts(d_sankey: Dict):
889
- node: Dict = dict(
890
- pad=15,
891
- thickness=20,
892
- line=dict(color="black", width=0.5),
893
- label=d_sankey['labels'],
894
- color=d_sankey['node_color'],
895
- customdata=d_sankey['labels']
896
- )
897
- link: Dict = dict(
898
- source=d_sankey['source'], # indices correspond to labels, eg A1, A2, A1, B1, ...
899
- target=d_sankey['target'],
900
- value=d_sankey['value'],
901
- color=d_sankey['edge_color'],
902
- label=d_sankey['edge_labels'], # over-written by hover template
903
- customdata=d_sankey['edge_custom_data'],
904
- hovertemplate='<b><i>%{label}</i></b><br />Source: %{source.customdata}<br />'
905
- 'Target: %{target.customdata}<br />%{customdata}'
906
- )
907
- return node, link
908
-
909
- def _get_scatter_node_edges(self, pos):
910
- # edges
911
- edge_color_map: Dict = {True: 'grey', False: 'red'}
912
- edge_annotations: Dict = {}
913
-
914
- edge_traces = []
915
- for u, v, data in self.graph.edges(data=True):
916
- x0, y0 = pos[u]
917
- x1, y1 = pos[v]
918
- edge_annotations[data['mc'].name] = {'pos': np.mean([pos[u], pos[v]], axis=0)}
919
- edge_traces.append(go.Scatter(x=[x0, x1], y=[y0, y1],
920
- line=dict(width=2, color=edge_color_map[data['mc'].status.ok]),
921
- hoverinfo='none',
922
- mode='lines+markers',
923
- text=str(data['mc'].name),
924
- marker=dict(
925
- symbol="arrow",
926
- color=edge_color_map[data['mc'].status.ok],
927
- size=16,
928
- angleref="previous",
929
- standoff=15)
930
- ))
931
-
932
- # nodes
933
- node_color_map: Dict = {None: 'grey', True: 'green', False: 'red'}
934
- node_x = []
935
- node_y = []
936
- node_color = []
937
- node_text = []
938
- node_label = []
939
- for node in self.graph.nodes():
940
- x, y = pos[node]
941
- node_x.append(x)
942
- node_y.append(y)
943
- node_color.append(node_color_map[self.graph.nodes[node]['mc'].is_balanced])
944
- node_text.append(node)
945
- node_label.append(self.graph.nodes[node]['mc'].name)
946
- node_trace = go.Scatter(
947
- x=node_x, y=node_y,
948
- mode='markers+text',
949
- hoverinfo='none',
950
- marker=dict(
951
- color=node_color,
952
- size=30,
953
- line_width=2),
954
- text=node_text,
955
- customdata=node_label,
956
- hovertemplate='%{customdata}<extra></extra>')
957
-
958
- # edge annotations
959
- edge_labels = list(edge_annotations.keys())
960
- edge_label_x = [edge_annotations[k]['pos'][0] for k, v in edge_annotations.items()]
961
- edge_label_y = [edge_annotations[k]['pos'][1] for k, v in edge_annotations.items()]
962
-
963
- edge_annotation_trace = go.Scatter(
964
- x=edge_label_x, y=edge_label_y,
965
- mode='markers',
966
- hoverinfo='text',
967
- marker=dict(
968
- color='grey',
969
- size=3,
970
- line_width=1),
971
- text=edge_labels)
972
-
973
- return edge_traces, node_trace, edge_annotation_trace
974
-
975
- @staticmethod
976
- def _get_position_kwargs(table_pos, table_area, plot_type):
977
- """Helper to manage location dependencies
978
-
979
- Args:
980
- table_pos: position of the table: left|right|top|bottom
981
- table_area: fraction of the plot to assign to the table [0, 1]
982
-
983
- Returns:
984
-
985
- """
986
- name_type_map: Dict = {'sankey': 'sankey', 'network': 'xy'}
987
- specs = [[{"type": 'table'}, {"type": name_type_map[plot_type]}]]
988
-
989
- widths: Optional[List[float]] = [table_area, 1.0 - table_area]
990
- subplot_kwargs: Dict = {'rows': 1, 'cols': 2, 'specs': specs}
991
- table_kwargs: Dict = {'row': 1, 'col': 1}
992
- plot_kwargs: Dict = {'row': 1, 'col': 2}
993
-
994
- if table_pos == 'left':
995
- subplot_kwargs['column_widths'] = widths
996
- elif table_pos == 'right':
997
- subplot_kwargs['column_widths'] = widths[::-1]
998
- subplot_kwargs['specs'] = [[{"type": name_type_map[plot_type]}, {"type": 'table'}]]
999
- table_kwargs['col'] = 2
1000
- plot_kwargs['col'] = 1
1001
- else:
1002
- subplot_kwargs['rows'] = 2
1003
- subplot_kwargs['cols'] = 1
1004
- table_kwargs['col'] = 1
1005
- plot_kwargs['col'] = 1
1006
- if table_pos == 'top':
1007
- subplot_kwargs['row_heights'] = widths
1008
- subplot_kwargs['specs'] = [[{"type": 'table'}], [{"type": name_type_map[plot_type]}]]
1009
- table_kwargs['row'] = 1
1010
- plot_kwargs['row'] = 2
1011
- elif table_pos == 'bottom':
1012
- subplot_kwargs['row_heights'] = widths[::-1]
1013
- subplot_kwargs['specs'] = [[{"type": name_type_map[plot_type]}], [{"type": 'table'}]]
1014
- table_kwargs['row'] = 2
1015
- plot_kwargs['row'] = 1
1016
-
1017
- if plot_type == 'network': # different arguments for different plots
1018
- plot_kwargs = {f'{k}s': v for k, v in plot_kwargs.items()}
1019
-
1020
- return subplot_kwargs, table_kwargs, plot_kwargs
1021
-
1022
- def _rpt_to_html(self, df: pd.DataFrame) -> Dict:
1023
- custom_data: Dict = {}
1024
- fmts: Dict = self._get_column_formats(df.columns)
1025
- for i, row in df.iterrows():
1026
- str_data: str = '<br />'
1027
- for k, v in dict(row).items():
1028
- str_data += f"{k}: {v:{fmts[k][1:]}}<br />"
1029
- custom_data[i] = str_data
1030
- return custom_data
1031
-
1032
- @staticmethod
1033
- def _color_from_float(vmin: float, vmax: float, val: float,
1034
- cmap: Union[ListedColormap, LinearSegmentedColormap]) -> Tuple[float, float, float]:
1035
- if isinstance(cmap, ListedColormap):
1036
- color_index: int = int((val - vmin) / ((vmax - vmin) / 256.0))
1037
- color_index = min(max(0, color_index), 255)
1038
- color_rgba = tuple(cmap.colors[color_index])
1039
- elif isinstance(cmap, LinearSegmentedColormap):
1040
- norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax)
1041
- m = cm.ScalarMappable(norm=norm, cmap=cmap)
1042
- r, g, b, a = m.to_rgba(val, bytes=True)
1043
- color_rgba = int(r), int(g), int(b), int(a)
1044
- else:
1045
- raise NotImplementedError("Unrecognised colormap type")
1046
-
1047
- return color_rgba
1048
-
1049
- def set_node_names(self, node_names: Dict[int, str]):
1050
- """Set the names of network nodes with a Dict
1051
- """
1052
- for node in node_names.keys():
1053
- if ('mc' in self.graph.nodes[node].keys()) and (node in node_names.keys()):
1054
- self.graph.nodes[node]['mc'].name = node_names[node]
1055
-
1056
-
1057
- def set_stream_data(self, stream_data: dict[str, Optional[MC]]):
1058
- """Set the data (MassComposition) of network edges (streams) with a Dict"""
1059
- for stream_name, stream_data in stream_data.items():
1060
- stream_found = False
1061
- nodes_to_refresh = set()
1062
- for u, v, data in self.graph.edges(data=True):
1063
- if 'mc' in data.keys() and (data['mc'].name if data['mc'] is not None else data['name']) == stream_name:
1064
- self._logger.info(f'Setting data on stream {stream_name}')
1065
- data['mc'] = stream_data
1066
- stream_found = True
1067
- nodes_to_refresh.update([u, v])
1068
- if not stream_found:
1069
- self._logger.warning(f'Stream {stream_name} not found in graph')
1070
- else:
1071
- # refresh the node status
1072
- for node in nodes_to_refresh:
1073
- self.graph.nodes[node]['mc'].inputs = [self.graph.get_edge_data(e[0], e[1])['mc'] for e in
1074
- self.graph.in_edges(node)]
1075
- self.graph.nodes[node]['mc'].outputs = [self.graph.get_edge_data(e[0], e[1])['mc'] for e in
1076
- self.graph.out_edges(node)]
1077
-
1078
-
1079
- def set_operation_data(self, node):
1080
- """Set the input and output data for a node.
1081
- Uses the data on the edges (streams) connected to the node to refresh the data and check for node balance.
1082
- """
1083
- node_data: Operation = self.graph.nodes[node]['mc']
1084
- node_data.inputs = [self.graph.get_edge_data(e[0], e[1])['mc'] for e in self.graph.in_edges(node)]
1085
- node_data.outputs = [self.graph.get_edge_data(e[0], e[1])['mc'] for e in self.graph.out_edges(node)]
1086
- node_data.check_balance()
1087
-
1088
-
1089
- def streams_to_dict(self) -> Dict[str, MC]:
1090
- """Export the Stream objects to a Dict
1091
-
1092
- Returns:
1093
- A dictionary keyed by name containing MassComposition objects
1094
-
1095
- """
1096
- streams: Dict[str, MC] = {}
1097
- for u, v, data in self.graph.edges(data=True):
1098
- if 'mc' in data.keys():
1099
- streams[data['mc'].name] = data['mc']
1100
- return streams
1101
-
1102
-
1103
- def nodes_to_dict(self) -> Dict[int, OP]:
1104
- """Export the MCNode objects to a Dict
1105
-
1106
- Returns:
1107
- A dictionary keyed by integer containing MCNode objects
1108
-
1109
- """
1110
- nodes: Dict[int, OP] = {}
1111
- for node in self.graph.nodes.keys():
1112
- if 'mc' in self.graph.nodes[node].keys():
1113
- nodes[node] = self.graph.nodes[node]['mc']
1114
- return nodes
1115
-
1116
-
1117
- def set_nodes(self, stream: str, nodes: Tuple[int, int]):
1118
- mc: MC = self.get_stream_by_name(stream)
1119
- mc._nodes = nodes
1120
- self._update_graph(mc)
1121
-
1122
-
1123
- def reset_nodes(self, stream: Optional[str] = None):
1124
- """Reset stream nodes to break relationships
1125
-
1126
- Args:
1127
- stream: The optional stream (edge) within the network.
1128
- If None all streams nodes on the network will be reset.
1129
-
1130
-
1131
- Returns:
1132
-
1133
- """
1134
- if stream is None:
1135
- streams: Dict[str, MC] = self.streams_to_dict()
1136
- for k, v in streams.items():
1137
- streams[k] = v.set_nodes([uuid.uuid4(), uuid.uuid4()])
1138
- self.graph = Flowsheet(name=self.name).from_objects(objects=list(streams.values())).graph
1139
- else:
1140
- mc: MC = self.get_stream_by_name(stream)
1141
- mc.set_nodes([uuid.uuid4(), uuid.uuid4()])
1142
- self._update_graph(mc)
1143
-
1144
-
1145
- def _update_graph(self, mc: MC):
1146
- """Update the graph with an existing stream object
1147
-
1148
- Args:
1149
- mc: The stream object
1150
-
1151
- Returns:
1152
-
1153
- """
1154
- # brutal approach - rebuild from streams
1155
- strms: List[Union[Stream, MC]] = []
1156
- for u, v, a in self.graph.edges(data=True):
1157
- if a.get('mc') and a['mc'].name == mc.name:
1158
- strms.append(mc)
1159
- else:
1160
- strms.append(a['mc'])
1161
- self.graph = Flowsheet(name=self.name).from_objects(objects=strms).graph
1162
-
1163
-
1164
- def get_stream_by_name(self, name: str) -> MC:
1165
- """Get the Stream object from the network by its name
1166
-
1167
- Args:
1168
- name: The string name of the Stream object stored on an edge in the network.
1169
-
1170
- Returns:
1171
-
1172
- """
1173
-
1174
- res: Optional[Union[Stream, MC]] = None
1175
- for u, v, a in self.graph.edges(data=True):
1176
- if a.get('mc') and a['mc'].name == name:
1177
- res = a['mc']
1178
-
1179
- if not res:
1180
- raise ValueError(f"The specified name: {name} is not found on the network.")
1181
-
1182
- return res
1183
-
1184
-
1185
- def set_stream_parent(self, stream: str, parent: str):
1186
- mc: MC = self.get_stream_by_name(stream)
1187
- mc.set_parent_node(self.get_stream_by_name(parent))
1188
- self._update_graph(mc)
1189
-
1190
-
1191
- def set_stream_child(self, stream: str, child: str):
1192
- mc: MC = self.get_stream_by_name(stream)
1193
- mc.set_child_node(self.get_stream_by_name(child))
1194
- self._update_graph(mc)
1195
-
1196
-
1197
- def reset_stream_nodes(self, stream: Optional[str] = None):
1198
- """Reset stream nodes to break relationships
1199
-
1200
- Args:
1201
- stream: The optional stream (edge) within the network.
1202
- If None all streams nodes on the network will be reset.
1203
-
1204
-
1205
- Returns:
1206
-
1207
- """
1208
- if stream is None:
1209
- streams: Dict[str, MC] = self.streams_to_dict()
1210
- for k, v in streams.items():
1211
- streams[k] = v.set_nodes([uuid.uuid4(), uuid.uuid4()])
1212
- self.graph = Flowsheet(name=self.name).from_objects(objects=list(streams.values())).graph
1213
- else:
1214
- mc: MC = self.get_stream_by_name(stream)
1215
- mc.set_nodes([uuid.uuid4(), uuid.uuid4()])
1216
- self._update_graph(mc)
1
+ import copy
2
+ import json
3
+ import logging
4
+ import uuid
5
+ from pathlib import Path
6
+ from typing import Dict, List, Optional, Tuple, Union, TypeVar, TYPE_CHECKING
7
+ import re
8
+
9
+ import matplotlib
10
+ import matplotlib.cm as cm
11
+ import networkx as nx
12
+ import numpy as np
13
+ import pandas as pd
14
+ import plotly.graph_objects as go
15
+ import seaborn as sns
16
+ import yaml
17
+ from matplotlib import pyplot as plt
18
+ from matplotlib.colors import ListedColormap, LinearSegmentedColormap
19
+ from networkx.algorithms.dag import is_directed_acyclic_graph
20
+ from plotly.subplots import make_subplots
21
+
22
+ from elphick.geomet import Sample
23
+ from elphick.geomet.base import MC
24
+ from elphick.geomet.config.config_read import get_column_config
25
+ from elphick.geomet.flowsheet.operation import NodeType, OP, PartitionOperation, Operation
26
+ from elphick.geomet.plot import parallel_plot, comparison_plot
27
+ from elphick.geomet.utils.layout import digraph_linear_layout
28
+ from elphick.geomet.flowsheet.loader import streams_from_dataframe
29
+
30
+ # if TYPE_CHECKING:
31
+ from elphick.geomet.flowsheet.stream import Stream
32
+
33
+ # generic type variable, used for type hinting that play nicely with subclasses
34
+ FS = TypeVar('FS', bound='Flowsheet')
35
+
36
+
37
+ class Flowsheet:
38
+ def __init__(self, name: str = 'Flowsheet'):
39
+ self.graph: nx.DiGraph = nx.DiGraph(name=name)
40
+ self._logger: logging.Logger = logging.getLogger(__class__.__name__)
41
+
42
+ @property
43
+ def name(self) -> str:
44
+ return self.graph.name
45
+
46
+ @name.setter
47
+ def name(self, value: str):
48
+ self.graph.name = value
49
+
50
+ @property
51
+ def healthy(self) -> bool:
52
+ return self.all_nodes_healthy and self.all_streams_healthy
53
+
54
+ @property
55
+ def all_nodes_healthy(self) -> bool:
56
+ bal_vals: List = [self.graph.nodes[n]['mc'].is_balanced for n in self.graph.nodes]
57
+ bal_vals = [bv for bv in bal_vals if bv is not None]
58
+ return all(bal_vals)
59
+
60
+ @property
61
+ def all_streams_healthy(self) -> bool:
62
+ """Check if all streams are healthy"""
63
+ # account for the fact that some edges may not have an mc object
64
+ if not all([d['mc'] for u, v, d in self.graph.edges(data=True)]):
65
+ return False
66
+ return all([self.graph.edges[u, v]['mc'].status.ok for u, v in self.graph.edges])
67
+
68
+ @classmethod
69
+ def from_objects(cls, objects: list[MC],
70
+ name: Optional[str] = 'Flowsheet') -> FS:
71
+ """Instantiate from a list of objects
72
+
73
+ This method is only suitable for objects that have the `_nodes` property set, such as objects that have
74
+ been created from math operations, which preserve relationships between objects (via the nodes property)
75
+
76
+ Args:
77
+ objects: List of MassComposition objects, such as Sample, IntervalSample, Stream or BlockModel
78
+ name: name of the flowsheet/network
79
+
80
+ Returns:
81
+
82
+ """
83
+ from elphick.geomet.flowsheet.operation import Operation
84
+ cls._check_indexes(objects)
85
+ bunch_of_edges: list = []
86
+ for stream in objects:
87
+ if stream.nodes is None:
88
+ raise KeyError(f'Stream {stream.name} does not have the node property set')
89
+ nodes = stream.nodes
90
+
91
+ # add the objects to the edges
92
+ bunch_of_edges.append((nodes[0], nodes[1], {'mc': stream, 'name': stream.name}))
93
+
94
+ graph = nx.DiGraph(name=name)
95
+ graph.add_edges_from(bunch_of_edges)
96
+ operation_objects: dict = {}
97
+ for node in graph.nodes:
98
+ operation_objects[node] = Operation(name=node)
99
+ nx.set_node_attributes(graph, operation_objects, 'mc')
100
+
101
+ for node in graph.nodes:
102
+ operation_objects[node].inputs = [graph.get_edge_data(e[0], e[1])['mc'] for e in graph.in_edges(node)]
103
+ operation_objects[node].outputs = [graph.get_edge_data(e[0], e[1])['mc'] for e in graph.out_edges(node)]
104
+
105
+ graph = nx.convert_node_labels_to_integers(graph)
106
+ # update the temporary nodes on the mc object property to match the renumbered integers
107
+ for node1, node2, data in graph.edges(data=True):
108
+ data['mc'].nodes = [node1, node2]
109
+ # update the node names after renumbering
110
+ for node in graph.nodes:
111
+ graph.nodes[node]['mc'].name = str(node)
112
+ obj = cls()
113
+ obj.graph = graph
114
+ return obj
115
+
116
+ @classmethod
117
+ def from_dataframe(cls, df: pd.DataFrame, name: Optional[str] = 'Flowsheet',
118
+ mc_name_col: Optional[str] = None, n_jobs: int = 1) -> FS:
119
+ """Instantiate from a DataFrame
120
+
121
+ Args:
122
+ df: The DataFrame
123
+ name: name of the network
124
+ mc_name_col: The column specified contains the names of objects to create.
125
+ If None the DataFrame is assumed to be wide and the mc objects will be extracted from column prefixes.
126
+ n_jobs: The number of parallel jobs to run. If -1, will use all available cores.
127
+
128
+ Returns:
129
+ Flowsheet: An instance of the Flowsheet class initialized from the provided DataFrame.
130
+
131
+ """
132
+ streams: list[Sample] = streams_from_dataframe(df=df, mc_name_col=mc_name_col, n_jobs=n_jobs)
133
+ return cls().from_objects(objects=streams, name=name)
134
+
135
+ @classmethod
136
+ def from_dict(cls, config: dict) -> FS:
137
+ """Create a flowsheet from a dictionary
138
+
139
+ Args:
140
+ config: dictionary containing the flowsheet configuration
141
+
142
+ Returns:
143
+ A Flowsheet object with no data on the edges
144
+ """
145
+
146
+ from elphick.geomet.flowsheet.operation import Operation
147
+
148
+ if 'FLOWSHEET' not in config:
149
+ raise ValueError("Dictionary does not contain 'FLOWSHEET' root node")
150
+
151
+ flowsheet_config = config['FLOWSHEET']
152
+
153
+ # create the Stream objects
154
+ bunch_of_edges: list = []
155
+ for stream, stream_config in flowsheet_config['streams'].items():
156
+ bunch_of_edges.append(
157
+ (stream_config['node_in'], stream_config['node_out'], {'mc': None, 'name': stream_config['name']}))
158
+
159
+ graph = nx.DiGraph(name=flowsheet_config['flowsheet']['name'])
160
+ graph.add_edges_from(bunch_of_edges)
161
+ operation_objects: dict = {}
162
+ for node in graph.nodes:
163
+ # create the correct type of node object
164
+ if node in flowsheet_config['operations']:
165
+ operation_type = flowsheet_config['operations'][node].get('type', 'Operation')
166
+ if operation_type == 'PartitionOperation':
167
+ # get the output stream names from the graph
168
+ output_stream_names = [d['name'] for u, v, d in graph.out_edges(node, data=True)]
169
+ node_config = flowsheet_config['operations'][node]
170
+ node_config['output_stream_names'] = output_stream_names
171
+ operation_objects[node] = PartitionOperation.from_dict(node_config)
172
+ else:
173
+ operation_objects[node] = Operation.from_dict(flowsheet_config['operations'][node])
174
+
175
+ # set the input and output streams on the operation object for the selected node
176
+ operation_objects[node].inputs = [graph.get_edge_data(e[0], e[1])['mc'] for e in graph.in_edges(node)]
177
+ operation_objects[node].outputs = [graph.get_edge_data(e[0], e[1])['mc'] for e in graph.out_edges(node)]
178
+
179
+ nx.set_node_attributes(graph, operation_objects, 'mc')
180
+
181
+ graph = nx.convert_node_labels_to_integers(graph)
182
+
183
+ obj = cls()
184
+ obj.graph = graph
185
+
186
+ return obj
187
+
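The keys read above imply a configuration shape like the following minimal sketch (stream and node names are hypothetical; the exact keys expected by Operation.from_dict are defined in operation.py, and only a name key is assumed here):

    from elphick.geomet.flowsheet.flowsheet import Flowsheet  # assumed import path

    config = {
        'FLOWSHEET': {
            'flowsheet': {'name': 'Split'},
            'streams': {
                'feed': {'name': 'feed', 'node_in': 'feed_node', 'node_out': 'splitter'},
                'lump': {'name': 'lump', 'node_in': 'splitter', 'node_out': 'lump_node'},
                'fines': {'name': 'fines', 'node_in': 'splitter', 'node_out': 'fines_node'},
            },
            'operations': {
                'feed_node': {'name': 'feed_node'},   # keys for Operation.from_dict are assumed
                'splitter': {'name': 'splitter'},
                'lump_node': {'name': 'lump_node'},
                'fines_node': {'name': 'fines_node'},
            },
        }
    }
    fs = Flowsheet.from_dict(config)  # edges carry no data until set_stream_data / solve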
188
+ # @classmethod
189
+ # def from_dict_todo(cls, config: dict) -> FS:
190
+ # TODO: This method is not yet implemented - fails because the Operations do not have inputs or outputs set.
191
+ # flowsheet = cls()
192
+ #
193
+ # # Process streams
194
+ # for stream_name, stream_data in config['FLOWSHEET']['streams'].items():
195
+ # stream = Stream.from_dict(stream_data)
196
+ # flowsheet.add_stream(stream)
197
+ #
198
+ # # Process operations
199
+ # for operation_name, operation_data in config['FLOWSHEET']['operations'].items():
200
+ # operation_type = operation_data.get('type', 'Operation')
201
+ # if operation_type == 'PartitionOperation':
202
+ # operation = PartitionOperation.from_dict(operation_data)
203
+ # else:
204
+ # operation = Operation.from_dict(operation_data)
205
+ # flowsheet.add_operation(operation)
206
+ #
207
+ # return flowsheet
208
+
209
+ @classmethod
210
+ def from_yaml(cls, file_path: Path) -> FS:
211
+ """Create a flowsheet from yaml
212
+
213
+ Args:
214
+ file_path: path to the yaml file
215
+
216
+ Returns:
217
+ A Flowsheet object with no data on the edges
218
+ """
219
+ with open(file_path, 'r') as file:
220
+ config = yaml.safe_load(file)
221
+
222
+ return cls.from_dict(config)
223
+
224
+ @classmethod
225
+ def from_json(cls, file_path: Path) -> FS:
226
+ """Create a flowsheet from json
227
+
228
+ Args:
229
+ file_path: path to the json file
230
+
231
+ Returns:
232
+ A Flowsheet object with no data on the edges
233
+ """
234
+ with open(file_path, 'r') as file:
235
+ config = json.load(file)
236
+
237
+ return cls.from_dict(config)
238
+
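A sketch of the file-based constructors (file paths are hypothetical; the files must contain the FLOWSHEET structure consumed by from_dict above):

    from pathlib import Path
    from elphick.geomet.flowsheet.flowsheet import Flowsheet  # assumed import path

    fs_from_yaml = Flowsheet.from_yaml(Path('my_flowsheet.yaml'))
    fs_from_json = Flowsheet.from_json(Path('my_flowsheet.json'))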
239
+ def add_stream(self, stream: 'Stream'):
240
+ """Add a stream to the flowsheet."""
241
+ self.graph.add_edge(stream.nodes[0], stream.nodes[1], mc=stream, name=stream.name)
242
+
243
+ def add_operation(self, operation: 'Operation'):
244
+ """Add an operation to the flowsheet."""
245
+ self.graph.add_node(operation.name, mc=operation)
246
+
247
+ def unhealthy_stream_records(self) -> pd.DataFrame:
248
+ """Return on unhealthy streams
249
+
250
+ Return the records for all streams that are not healthy.
251
+ Returns:
252
+ DataFrame: A DataFrame containing the unhealthy stream records
253
+ """
254
+ unhealthy_edges = [e for e in self.graph.edges if not self.graph.edges[e]['mc'].status.ok]
255
+ unhealthy_data: pd.DataFrame = pd.concat(
256
+ [self.graph.edges[e]['mc'].status.oor.assign(stream=self.graph.edges[e]['mc'].name) for e in
257
+ unhealthy_edges], axis=1)
258
+ # move the last column to the front
259
+ unhealthy_data = unhealthy_data[[unhealthy_data.columns[-1]] + list(unhealthy_data.columns[:-1])]
260
+
261
+ # append the flowsheet records for additional context
262
+ records: pd.DataFrame = self.to_dataframe()
263
+ records = records.unstack(level='name').swaplevel(axis=1).sort_index(axis=1, level=0, sort_remaining=False)
264
+ records.columns = [f"{col[0]}_{col[1]}" for col in records.columns]
265
+
266
+ result = unhealthy_data.merge(records, left_index=True, right_index=True, how='left')
267
+ return result
268
+
269
+ def unhealthy_node_records(self) -> pd.DataFrame:
270
+ """Return unhealthy nodes
271
+
272
+ Return the records for all nodes that are not healthy.
273
+ Returns:
274
+ DataFrame: A DataFrame containing the unhealthy node records
275
+ """
276
+ unhealthy_nodes = [n for n in self.graph.nodes if
277
+ self.graph.nodes[n]['mc'].node_type == NodeType.BALANCE and not self.graph.nodes[n][
278
+ 'mc'].is_balanced]
279
+ unhealthy_data: pd.DataFrame = pd.concat(
280
+ [self.graph.nodes[n]['mc'].unbalanced_records.assign(node=self.graph.nodes[n]['mc'].name) for n in
281
+ unhealthy_nodes], axis=1)
282
+ # move the last column to the front
283
+ unhealthy_data = unhealthy_data[[unhealthy_data.columns[-1]] + list(unhealthy_data.columns[:-1])]
284
+
285
+ # todo: append the streams around the node
286
+
287
+ return unhealthy_data
288
+
289
+ def copy_without_stream_data(self):
290
+ """Copy without stream data"""
291
+ new_flowsheet = Flowsheet(name=self.name)
292
+ new_graph = nx.DiGraph()
293
+
294
+ # Copy nodes with Operation objects
295
+ for node, data in self.graph.nodes(data=True):
296
+ new_data = data.copy()
297
+ new_graph.add_node(node, **new_data)
298
+
299
+ # Copy edges with mc attribute set to None
300
+ for u, v, data in self.graph.edges(data=True):
301
+ new_data = {k: (None if k == 'mc' else copy.deepcopy(v)) for k, v in data.items()}
302
+ new_graph.add_edge(u, v, **new_data)
303
+
304
+ new_flowsheet.graph = new_graph
305
+ return new_flowsheet
306
+
307
+ def solve(self):
308
+ """Solve missing streams"""
309
+
310
+ if not is_directed_acyclic_graph(self.graph):
311
+ self._logger.error("Graph is not a Directed Acyclic Graph (DAG), so cannot be solved.")
312
+ self._logger.debug(f"Graph nodes: {self.graph.nodes(data=True)}")
313
+ self._logger.debug(f"Graph edges: {self.graph.edges(data=True)}")
314
+ raise ValueError("Graph is not a Directed Acyclic Graph (DAG), so cannot be solved.")
315
+
316
+ # Check the number of missing mc's on edges in the network
317
+ missing_count: int = sum([1 for u, v, d in self.graph.edges(data=True) if d['mc'] is None])
318
+ prev_missing_count = missing_count + 1 # Initialize with a value greater than missing_count
319
+
320
+ while 0 < missing_count < prev_missing_count:
321
+ prev_missing_count = missing_count
322
+ for node in nx.topological_sort(self.graph):
323
+ if self.graph.nodes[node]['mc'].node_type == NodeType.BALANCE:
324
+ if self.graph.nodes[node]['mc'].has_empty_input:
325
+ mc: MC = self.graph.nodes[node]['mc'].solve()
326
+ # copy the solved object to the empty input edges
327
+ for predecessor in self.graph.predecessors(node):
328
+ edge_data = self.graph.get_edge_data(predecessor, node)
329
+ if edge_data and edge_data['mc'] is None:
330
+ edge_data['mc'] = mc
331
+ edge_data['mc'].name = edge_data['name']
332
+ self.set_operation_data(predecessor)
333
+
334
+ if self.graph.nodes[node]['mc'].has_empty_output:
335
+ # There are two cases to manage: 1. a single missing output,
336
+ # 2. a partition operation that returns two outputs
337
+ if isinstance(self.graph.nodes[node]['mc'], PartitionOperation):
338
+ partition_stream: str = self.graph.nodes[node]['mc'].partition['partition_stream']
339
+ mc1, mc2 = self.graph.nodes[node]['mc'].solve()
340
+ # copy the solved object to the empty output edges
341
+ for successor in self.graph.successors(node):
342
+ edge_data = self.graph.get_edge_data(node, successor)
343
+ if edge_data and edge_data['mc'] is None:
344
+ edge_data['mc'] = mc1 if edge_data['name'] == partition_stream else mc2
345
+ edge_data['mc'].name = edge_data['name']
346
+ self.set_operation_data(successor)
347
+
348
+ else:
349
+ mc: MC = self.graph.nodes[node]['mc'].solve()
350
+ # copy the solved object to the empty output edges
351
+ for successor in self.graph.successors(node):
352
+ edge_data = self.graph.get_edge_data(node, successor)
353
+ if edge_data and edge_data['mc'] is None:
354
+ edge_data['mc'] = mc
355
+ edge_data['mc'].name = edge_data['name']
356
+ self.set_operation_data(successor)
357
+ self.set_operation_data(node)
358
+
359
+ missing_count: int = sum([1 for u, v, d in self.graph.edges(data=True) if d['mc'] is None])
360
+ self._logger.info(f"Missing count: {missing_count}")
361
+
362
+ if missing_count > 0:
363
+ self._logger.error(f"Failed to solve the flowsheet. Missing count: {missing_count}")
364
+ raise ValueError(
365
+ f"Failed to solve the flowsheet. Some streams are still missing. Missing count: {missing_count}")
366
+
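solve repeatedly balances nodes in topological order, copying solved objects onto the empty edges. A hedged end-to-end sketch, assuming fs was built via from_dict/from_yaml (so every edge starts as None) and feed_sample is an existing MassComposition-like object; the stream name 'feed' is hypothetical:

    fs.set_stream_data({'feed': feed_sample})  # known streams get data
    fs.solve()       # raises ValueError if the network stays under-determined,
                     # e.g. two unknown outputs around a plain (non-partition) node
    print(fs.report())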
367
+ def query(self, expr: str, stream_name: Optional[str] = None, inplace=False) -> 'Flowsheet':
368
+ """Reduce the Flowsheet Stream records with a query
369
+
370
+ Args:
371
+ expr: The pandas query expression evaluated against the nominated stream, e.g. 'var_name > 0.5'.
372
+ The index of the matching records is then used to filter every stream in the flowsheet.
373
+ stream_name: The name of the stream to apply the query to. If None, the query is applied to the
374
+ first input stream.
375
+ inplace: If True, apply the query in place on the same object, otherwise return a new instance.
376
+
377
+ Returns:
378
+ A Flowsheet object where the stream records conform to the query
379
+ """
380
+ if stream_name is None:
381
+ input_stream: MC = self.get_input_streams()[0]
382
+ else:
383
+ input_stream: MC = self.get_stream_by_name(name=stream_name)
384
+ filtered_index: pd.Index = input_stream.data.query(expr).index
385
+ return self._filter(filtered_index, inplace)
386
+
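A hedged example of query ('mass_dry' and 'feed' are illustrative names); the expression is evaluated on one stream and the surviving record index is applied to every stream:

    fs_subset = fs.query(expr='mass_dry > 10.0', stream_name='feed')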
387
+ def filter_by_index(self, index: pd.Index, inplace: bool = False) -> 'Flowsheet':
388
+ """Filter the Flowsheet Stream records by a given index.
389
+
390
+ Args:
391
+ index: The index to filter the data.
392
+ inplace: If True, apply the filter in place on the same object, otherwise return a new instance.
393
+
394
+ Returns:
395
+ A Flowsheet object where the stream records are filtered by the given index.
396
+ """
397
+ return self._filter(index, inplace)
398
+
399
+ def _filter(self, index: pd.Index, inplace: bool = False) -> 'Flowsheet':
400
+ """Private method to filter the Flowsheet Stream records by a given index.
401
+
402
+ Args:
403
+ index: The index to filter the data.
404
+ inplace: If True, apply the filter in place on the same object, otherwise return a new instance.
405
+
406
+ Returns:
407
+ A Flowsheet object where the stream records are filtered by the given index.
408
+ """
409
+ if inplace:
410
+ for u, v, d in self.graph.edges(data=True):
411
+ if d.get('mc') is not None:
412
+ d.get('mc').filter_by_index(index)
413
+ return self
414
+ else:
415
+ obj: Flowsheet = self.copy_without_stream_data()
416
+ for u, v, d in self.graph.edges(data=True):
417
+ if d.get('mc') is not None:
418
+ mc: MC = d.get('mc')
419
+ mc_new = mc.__class__(name=mc.name)
420
+ # Copy each attribute
421
+ for attr, value in mc.__dict__.items():
422
+ if attr in ['_mass_data', '_supplementary_data'] and value is not None:
423
+ value = value.loc[index]
424
+ setattr(mc_new, attr, copy.deepcopy(value))
425
+ mc_new.aggregate = mc_new.weight_average()
426
+ obj.graph[u][v]['mc'] = mc_new
427
+ return obj
428
+
429
+ def get_input_streams(self) -> list[MC]:
430
+ """Get the input (feed) streams (edge objects)
431
+
432
+ Returns:
433
+ List of MassComposition-like objects
434
+ """
435
+
436
+ # Create a dictionary that maps node names to their degrees
437
+ degrees = {n: d for n, d in self.graph.degree()}
438
+
439
+ res: list[MC] = [d['mc'] for u, v, d in self.graph.edges(data=True) if degrees[u] == 1]
440
+ if not res:
441
+ raise ValueError("No input streams found")
442
+ return res
443
+
444
+ def get_output_streams(self) -> list[MC]:
445
+ """Get the output (product) streams (edge objects)
446
+
447
+ Returns:
448
+ List of MassComposition-like objects
449
+ """
450
+
451
+ # Create a dictionary that maps node names to their degrees
452
+ degrees = {n: d for n, d in self.graph.degree()}
453
+
454
+ res: list[MC] = [d['mc'] for u, v, d in self.graph.edges(data=True) if degrees[v] == 1]
455
+ if not res:
456
+ raise ValueError("No output streams found")
457
+ return res
458
+
459
+ @staticmethod
460
+ def _check_indexes(streams):
461
+
462
+ list_of_indexes = [s._mass_data.index for s in streams]
463
+ types_of_indexes = [type(i) for i in list_of_indexes]
464
+ # check the index types are consistent
465
+ if len(set(types_of_indexes)) != 1:
466
+ raise KeyError("stream index types are not consistent")
467
+
468
+ def plot(self, orientation: str = 'horizontal') -> plt.Figure:
469
+ """Plot the network with matplotlib
470
+
471
+ Args:
472
+ orientation: 'horizontal'|'vertical' network layout
473
+
474
+ Returns:
475
+
476
+ """
477
+
478
+ hf, ax = plt.subplots()
479
+ # pos = nx.spring_layout(self, seed=1234)
480
+ pos = digraph_linear_layout(self.graph, orientation=orientation)
481
+
482
+ edge_labels: Dict = {}
483
+ edge_colors: List = []
484
+ node_colors: List = []
485
+
486
+ for node1, node2, data in self.graph.edges(data=True):
487
+ edge_labels[(node1, node2)] = data['mc'].name if data['mc'] is not None else data['name']
488
+ if data['mc'] and data['mc'].status.ok:
489
+ edge_colors.append('gray')
490
+ else:
491
+ edge_colors.append('red')
492
+
493
+ for n in self.graph.nodes:
494
+ if self.graph.nodes[n]['mc'].node_type == NodeType.BALANCE:
495
+ if self.graph.nodes[n]['mc'].is_balanced:
496
+ node_colors.append('green')
497
+ else:
498
+ node_colors.append('red')
499
+ else:
500
+ node_colors.append('gray')
501
+
502
+ nx.draw(self.graph, pos=pos, ax=ax, with_labels=True, font_weight='bold',
503
+ node_color=node_colors, edge_color=edge_colors)
504
+
505
+ nx.draw_networkx_edge_labels(self.graph, pos=pos, ax=ax, edge_labels=edge_labels, font_color='black')
506
+ ax.set_title(self._plot_title(html=False), fontsize=10)
507
+
508
+ return hf
509
+
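A sketch of the matplotlib network plot (the output path is hypothetical):

    hf = fs.plot(orientation='vertical')
    hf.savefig('flowsheet.png')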
510
+ def _plot_title(self, html: bool = True, compact: bool = False):
511
+ # title = self.name
512
+ title = (f"{self.name}<br><sup>Nodes Healthy: "
513
+ f"<span style='color: {'red' if not self.all_nodes_healthy else 'black'}'>{self.all_nodes_healthy}</span>, "
514
+ f"Streams Healthy: "
515
+ f"<span style='color: {'red' if not self.all_streams_healthy else 'black'}'>{self.all_streams_healthy}</span></sup>")
516
+ # if compact:
517
+ # title = title.replace("<br><br>", "<br>").replace("<br>Edge", ", Edge")
518
+ # if not self.edge_status[0]:
519
+ # title = title.replace("</sup>", "") + f", {self.edge_status[1]}</sup>"
520
+ if not html:
521
+ title = title.replace('<br><br>', '\n').replace('<br>', '\n').replace('<sup>', '').replace('</sup>', '')
522
+ title = re.sub(r'<span style=.*?>(.*?)</span>', r'\1', title)
523
+ return title
524
+
525
+ def report(self, apply_formats: bool = False) -> pd.DataFrame:
526
+ """Summary Report
527
+
528
+ Total mass and weight-averaged composition per stream
529
+ Returns:
530
+
531
+ """
532
+ chunks: List[pd.DataFrame] = []
533
+ for n, nbrs in self.graph.adj.items():
534
+ for nbr, eattr in nbrs.items():
535
+ if eattr['mc'] is None or eattr['mc'].data.empty:
536
+ edge_name: str = eattr['mc'].name if eattr['mc'] is not None else eattr['name']
537
+ raise KeyError(f"Cannot generate report on empty dataset: {edge_name}")
538
+ chunks.append(eattr['mc'].aggregate.assign(name=eattr['mc'].name))
539
+ rpt: pd.DataFrame = pd.concat(chunks, axis='index').set_index('name')
540
+ if apply_formats:
541
+ fmts: Dict = self._get_column_formats(rpt.columns)
542
+ for k, v in fmts.items():
543
+ rpt[k] = rpt[k].apply((v.replace('%', '{:,') + '}').format)
544
+ return rpt
545
+
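report collapses each stream to one row of total mass and weight-averaged composition; a short sketch assuming fs is a populated Flowsheet:

    rpt = fs.report()                        # DataFrame indexed by stream name
    rpt_fmt = fs.report(apply_formats=True)  # values rendered with the configured format strings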
546
+ def _get_column_formats(self, columns: List[str], strip_percent: bool = False) -> Dict[str, str]:
547
+ """
548
+
549
+ Args:
550
+ columns: The columns to lookup format strings for
551
+ strip_percent: If True remove the leading % symbol from the format (for plotly tables)
552
+
553
+ Returns:
554
+
555
+ """
556
+ strm = self.get_input_streams()[0]
557
+ d_format: dict = get_column_config(config_dict=strm.config, var_map=strm.variable_map, config_key='format')
558
+
559
+ if strip_percent:
560
+ d_format = {k: v.strip('%') for k, v in d_format.items()}
561
+
562
+ return d_format
563
+
564
+ def plot_balance(self, facet_col_wrap: int = 3,
565
+ color: Optional[str] = 'node') -> go.Figure:
566
+ """Plot input versus output across all nodes in the network
567
+
568
+ Args:
569
+ facet_col_wrap: the number of subplots per row before wrapping
570
+ color: The optional variable to color by. If None color will be by Node
571
+
572
+ Returns:
573
+
574
+ """
575
+ # prepare the data
576
+ chunks_in: List = []
577
+ chunks_out: List = []
578
+ for n in self.graph.nodes:
579
+ if self.graph.nodes[n]['mc'].node_type == NodeType.BALANCE:
580
+ chunks_in.append(self.graph.nodes[n]['mc'].add('in').assign(**{'direction': 'in', 'node': n}))
581
+ chunks_out.append(self.graph.nodes[n]['mc'].add('out').assign(**{'direction': 'out', 'node': n}))
582
+ df_in: pd.DataFrame = pd.concat(chunks_in)
583
+ index_names = ['direction', 'node'] + df_in.index.names
584
+ df_in = df_in.reset_index().melt(id_vars=index_names)
585
+ df_out: pd.DataFrame = pd.concat(chunks_out).reset_index().melt(id_vars=index_names)
586
+ df_plot: pd.DataFrame = pd.concat([df_in, df_out])
587
+ df_plot = df_plot.set_index(index_names + ['variable'], append=True).unstack(['direction'])
588
+ df_plot.columns = df_plot.columns.droplevel(0)
589
+ df_plot.reset_index(level=list(np.arange(-1, -len(index_names) - 1, -1)), inplace=True)
590
+ df_plot['node'] = pd.Categorical(df_plot['node'])
591
+
592
+ # plot
593
+ fig = comparison_plot(data=df_plot,
594
+ x='in', y='out',
595
+ facet_col_wrap=facet_col_wrap,
596
+ color=color)
597
+ return fig
598
+
599
+ def plot_network(self, orientation: str = 'horizontal') -> go.Figure:
600
+ """Plot the network with plotly
601
+
602
+ Args:
603
+ orientation: 'horizontal'|'vertical' network layout
604
+
605
+ Returns:
606
+
607
+ """
608
+ # pos = nx.spring_layout(self, seed=1234)
609
+ pos = digraph_linear_layout(self.graph, orientation=orientation)
610
+
611
+ edge_traces, node_trace, edge_annotation_trace = self._get_scatter_node_edges(pos)
612
+ title = self._plot_title()
613
+
614
+ fig = go.Figure(data=[*edge_traces, node_trace, edge_annotation_trace],
615
+ layout=go.Layout(
616
+ title=title,
617
+ titlefont_size=16,
618
+ showlegend=False,
619
+ hovermode='closest',
620
+ margin=dict(b=20, l=5, r=5, t=40),
621
+ xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
622
+ yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
623
+ paper_bgcolor='rgba(0,0,0,0)',
624
+ plot_bgcolor='rgba(0,0,0,0)'
625
+ ),
626
+ )
627
+ # for k, d_args in edge_annotations.items():
628
+ # fig.add_annotation(x=d_args['pos'][0], y=d_args['pos'][1], text=k, textangle=d_args['angle'])
629
+
630
+ return fig
631
+
632
+ def plot_sankey(self,
633
+ width_var: str = 'mass_dry',
634
+ color_var: Optional[str] = None,
635
+ edge_colormap: Optional[str] = 'copper_r',
636
+ vmin: Optional[float] = None,
637
+ vmax: Optional[float] = None,
638
+ ) -> go.Figure:
639
+ """Plot the Network as a sankey
640
+
641
+ Args:
642
+ width_var: The variable that determines the sankey width
643
+ color_var: The optional variable that determines the sankey edge color
644
+ edge_colormap: The optional colormap. Used with color_var.
645
+ vmin: The value that maps to the minimum color
646
+ vmax: The value that maps to the maximum color
647
+
648
+ Returns:
649
+
650
+ """
651
+ # Create a mapping of node names to indices, and the integer nodes
652
+ node_indices = {node: index for index, node in enumerate(self.graph.nodes)}
653
+ int_graph = nx.relabel_nodes(self.graph, node_indices)
654
+
655
+ # Generate the sankey diagram arguments using the new graph with integer nodes
656
+ d_sankey = self._generate_sankey_args(int_graph, color_var, edge_colormap, width_var, vmin, vmax)
657
+
658
+ # Create the sankey diagram
659
+ node, link = self._get_sankey_node_link_dicts(d_sankey)
660
+ fig = go.Figure(data=[go.Sankey(node=node, link=link)])
661
+ title = self._plot_title()
662
+ fig.update_layout(title_text=title, font_size=10)
663
+ return fig
664
+
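A sketch of plot_sankey; 'fe' is a hypothetical composition column used to colour the links:

    fig = fs.plot_sankey(width_var='mass_dry', color_var='fe')
    fig.show()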
665
+ def table_plot(self,
666
+ plot_type: str = 'sankey',
667
+ cols_exclude: Optional[List] = None,
668
+ table_pos: str = 'left',
669
+ table_area: float = 0.4,
670
+ table_header_color: str = 'cornflowerblue',
671
+ table_odd_color: str = 'whitesmoke',
672
+ table_even_color: str = 'lightgray',
673
+ sankey_width_var: str = 'mass_dry',
674
+ sankey_color_var: Optional[str] = None,
675
+ sankey_edge_colormap: Optional[str] = 'copper_r',
676
+ sankey_vmin: Optional[float] = None,
677
+ sankey_vmax: Optional[float] = None,
678
+ network_orientation: Optional[str] = 'horizontal'
679
+ ) -> go.Figure:
680
+ """Plot with table of edge averages
681
+
682
+ Args:
683
+ plot_type: The type of plot ['sankey', 'network']
684
+ cols_exclude: List of columns to exclude from the table
685
+ table_pos: Position of the table ['left', 'right', 'top', 'bottom']
686
+ table_area: The proportion of width or height to allocate to the table [0, 1]
687
+ table_header_color: Color of the table header
688
+ table_odd_color: Color of the odd table rows
689
+ table_even_color: Color of the even table rows
690
+ sankey_width_var: If plot_type is sankey, the variable that determines the sankey width
691
+ sankey_color_var: If plot_type is sankey, the optional variable that determines the sankey edge color
692
+ sankey_edge_colormap: If plot_type is sankey, the optional colormap. Used with sankey_color_var.
693
+ sankey_vmin: The value that maps to the minimum color
694
+ sankey_vmax: The value that maps to the maximum color
695
+ network_orientation: The orientation of the network layout 'vertical'|'horizontal'
696
+
697
+ Returns:
698
+
699
+ """
700
+
701
+ valid_plot_types: List[str] = ['sankey', 'network']
702
+ if plot_type not in valid_plot_types:
703
+ raise ValueError(f'The supplied plot_type is not in {valid_plot_types}')
704
+
705
+ valid_table_pos: List[str] = ['top', 'bottom', 'left', 'right']
706
+ if table_pos not in valid_table_pos:
707
+ raise ValueError(f'The supplied table_pos is not in {valid_table_pos}')
708
+
709
+ d_subplot, d_table, d_plot = self._get_position_kwargs(table_pos, table_area, plot_type)
710
+
711
+ fig = make_subplots(**d_subplot, print_grid=False)
712
+
713
+ df: pd.DataFrame = self.report().reset_index()
714
+ if cols_exclude:
715
+ df = df[[col for col in df.columns if col not in cols_exclude]]
716
+ fmt: List[str] = ['%s'] + list(self._get_column_formats(df.columns, strip_percent=True).values())
717
+ column_widths = [2] + [1] * (len(df.columns) - 1)
718
+
719
+ fig.add_table(
720
+ header=dict(values=list(df.columns),
721
+ fill_color=table_header_color,
722
+ align='center',
723
+ font=dict(color='black', size=12)),
724
+ columnwidth=column_widths,
725
+ cells=dict(values=df.transpose().values.tolist(),
726
+ align='left', format=fmt,
727
+ fill_color=[
728
+ [table_odd_color if i % 2 == 0 else table_even_color for i in range(len(df))] * len(
729
+ df.columns)]),
730
+ **d_table)
731
+
732
+ if plot_type == 'sankey':
733
+ # Create a mapping of node names to indices, and the integer nodes
734
+ node_indices = {node: index for index, node in enumerate(self.graph.nodes)}
735
+ int_graph = nx.relabel_nodes(self.graph, node_indices)
736
+
737
+ # Generate the sankey diagram arguments using the new graph with integer nodes
738
+ d_sankey = self._generate_sankey_args(int_graph, sankey_color_var,
739
+ sankey_edge_colormap,
740
+ sankey_width_var,
741
+ sankey_vmin,
742
+ sankey_vmax)
743
+ node, link = self._get_sankey_node_link_dicts(d_sankey)
744
+ fig.add_trace(go.Sankey(node=node, link=link), **d_plot)
745
+
746
+ elif plot_type == 'network':
747
+ # pos = nx.spring_layout(self, seed=1234)
748
+ pos = digraph_linear_layout(self.graph, orientation=network_orientation)
749
+
750
+ edge_traces, node_trace, edge_annotation_trace = self._get_scatter_node_edges(pos)
751
+ fig.add_traces(data=[*edge_traces, node_trace, edge_annotation_trace], **d_plot)
752
+
753
+ fig.update_layout(showlegend=False, hovermode='closest',
754
+ xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
755
+ yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
756
+ paper_bgcolor='rgba(0,0,0,0)',
757
+ plot_bgcolor='rgba(0,0,0,0)'
758
+ )
759
+
760
+ title = self._plot_title(compact=True)
761
+ fig.update_layout(title_text=title, font_size=12)
762
+
763
+ return fig
764
+
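A sketch of table_plot, pairing the summary table with the network layout:

    fig = fs.table_plot(plot_type='network', table_pos='bottom', table_area=0.3)
    fig.show()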
765
+ def to_dataframe(self, stream_names: Optional[list[str]] = None, tidy: bool = True,
766
+ as_mass: bool = False) -> pd.DataFrame:
767
+ """Return a tidy dataframe
768
+
769
+ Adds the mc name to the index so indexes are unique.
770
+
771
+ Args:
772
+ stream_names: Optional List of names of Stream/MassComposition objects (network edges) for export
773
+ tidy: If True, the data will be returned in a tidy format, otherwise wide
774
+ as_mass: If True, the mass data will be returned instead of the mass-composition data
775
+
776
+ Returns:
777
+
778
+ """
779
+ chunks: List[pd.DataFrame] = []
780
+ for u, v, data in self.graph.edges(data=True):
781
+ if (stream_names is None) or ((stream_names is not None) and (data['mc'].name in stream_names)):
782
+ if as_mass:
783
+ chunks.append(data['mc'].mass_data.assign(name=data['mc'].name))
784
+ else:
785
+ chunks.append(data['mc'].data.assign(name=data['mc'].name))
786
+
787
+ results: pd.DataFrame = pd.concat(chunks, axis='index').set_index('name', append=True)
788
+ if not tidy: # wide format
789
+ results = results.unstack(level='name')
790
+ column_order: list[str] = [f'{name}_{attr}' for name in results.columns.levels[1] for attr in
791
+ results.columns.levels[0]]
792
+ results.columns = [f'{col[1]}_{col[0]}' for col in results.columns]
793
+ results = results[column_order]
794
+
795
+ return results
796
+
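to_dataframe exports the edge data either tidy (the stream name is appended to the index) or wide (stream-prefixed columns); a short sketch:

    df_tidy = fs.to_dataframe()               # long format with a 'name' index level
    df_wide = fs.to_dataframe(tidy=False)     # columns named <stream>_<attribute>
    df_mass = fs.to_dataframe(as_mass=True)   # mass rather than mass-composition values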
797
+ def plot_parallel(self,
798
+ names: Optional[str] = None,
799
+ color: Optional[str] = None,
800
+ vars_include: Optional[List[str]] = None,
801
+ vars_exclude: Optional[List[str]] = None,
802
+ title: Optional[str] = None,
803
+ include_dims: Optional[Union[bool, List[str]]] = True,
804
+ plot_interval_edges: bool = False) -> go.Figure:
805
+ """Create an interactive parallel plot
806
+
807
+ Useful to explore multidimensional data like mass-composition data
808
+
809
+ Args:
810
+ names: Optional List of Names to plot
811
+ color: Optional color variable
812
+ vars_include: Optional List of variables to include in the plot
813
+ vars_exclude: Optional List of variables to exclude in the plot
814
+ title: Optional plot title
815
+ include_dims: Optional boolean or list of dimensions to include in the plot. True will show all dims.
816
+ plot_interval_edges: If True, interval edges will be plotted instead of interval mid-points
817
+
818
+ Returns:
819
+
820
+ """
821
+ df: pd.DataFrame = self.to_dataframe(stream_names=names)
822
+
823
+ if not title and hasattr(self, 'name'):
824
+ title = self.name
825
+
826
+ fig = parallel_plot(data=df, color=color, vars_include=vars_include, vars_exclude=vars_exclude, title=title,
827
+ include_dims=include_dims, plot_interval_edges=plot_interval_edges)
828
+ return fig
829
+
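A sketch of plot_parallel (the variable names are illustrative):

    fig = fs.plot_parallel(color='fe', vars_exclude=['mass_wet'])
    fig.show()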
830
+ def _generate_sankey_args(self, int_graph, color_var, edge_colormap, width_var, v_min, v_max):
831
+ rpt: pd.DataFrame = self.report()
832
+ if color_var is not None:
833
+ cmap = sns.color_palette(edge_colormap, as_cmap=True)
834
+ rpt: pd.DataFrame = self.report()
835
+ if not v_min:
836
+ v_min = np.floor(rpt[color_var].min())
837
+ if not v_max:
838
+ v_max = np.ceil(rpt[color_var].max())
839
+
840
+ # run the report for the hover data
841
+ d_custom_data: Dict = self._rpt_to_html(df=rpt)
842
+ source: List = []
843
+ target: List = []
844
+ value: List = []
845
+ edge_custom_data = []
846
+ edge_color: List = []
847
+ edge_labels: List = []
848
+ node_colors: List = []
849
+ node_labels: List = []
850
+
851
+ for n in int_graph.nodes:
852
+ node_labels.append(int_graph.nodes[n]['mc'].name)
853
+
854
+ if int_graph.nodes[n]['mc'].node_type == NodeType.BALANCE:
855
+ if int_graph.nodes[n]['mc'].is_balanced:
856
+ node_colors.append('green')
857
+ else:
858
+ node_colors.append('red')
859
+ else:
860
+ node_colors.append('blue')
861
+
862
+ for u, v, data in int_graph.edges(data=True):
863
+ edge_labels.append(data['mc'].name)
864
+ source.append(u)
865
+ target.append(v)
866
+ value.append(float(data['mc'].aggregate[width_var].iloc[0]))
867
+ edge_custom_data.append(d_custom_data[data['mc'].name])
868
+
869
+ if color_var is not None:
870
+ val: float = float(data['mc'].aggregate[color_var].iloc[0])
871
+ str_color: str = f'rgba{self._color_from_float(v_min, v_max, val, cmap)}'
872
+ edge_color.append(str_color)
873
+ else:
874
+ edge_color: Optional[str] = None
875
+
876
+ d_sankey: Dict = {'node_color': node_colors,
877
+ 'edge_color': edge_color,
878
+ 'edge_custom_data': edge_custom_data,
879
+ 'edge_labels': edge_labels,
880
+ 'labels': node_labels,
881
+ 'source': source,
882
+ 'target': target,
883
+ 'value': value}
884
+
885
+ return d_sankey
886
+
887
+ @staticmethod
888
+ def _get_sankey_node_link_dicts(d_sankey: Dict):
889
+ node: Dict = dict(
890
+ pad=15,
891
+ thickness=20,
892
+ line=dict(color="black", width=0.5),
893
+ label=d_sankey['labels'],
894
+ color=d_sankey['node_color'],
895
+ customdata=d_sankey['labels']
896
+ )
897
+ link: Dict = dict(
898
+ source=d_sankey['source'], # indices correspond to labels, eg A1, A2, A1, B1, ...
899
+ target=d_sankey['target'],
900
+ value=d_sankey['value'],
901
+ color=d_sankey['edge_color'],
902
+ label=d_sankey['edge_labels'], # over-written by hover template
903
+ customdata=d_sankey['edge_custom_data'],
904
+ hovertemplate='<b><i>%{label}</i></b><br />Source: %{source.customdata}<br />'
905
+ 'Target: %{target.customdata}<br />%{customdata}'
906
+ )
907
+ return node, link
908
+
909
+ def _get_scatter_node_edges(self, pos):
910
+ # edges
911
+ edge_color_map: Dict = {True: 'grey', False: 'red'}
912
+ edge_annotations: Dict = {}
913
+
914
+ edge_traces = []
915
+ for u, v, data in self.graph.edges(data=True):
916
+ x0, y0 = pos[u]
917
+ x1, y1 = pos[v]
918
+ edge_annotations[data['mc'].name] = {'pos': np.mean([pos[u], pos[v]], axis=0)}
919
+ edge_traces.append(go.Scatter(x=[x0, x1], y=[y0, y1],
920
+ line=dict(width=2, color=edge_color_map[data['mc'].status.ok]),
921
+ hoverinfo='none',
922
+ mode='lines+markers',
923
+ text=str(data['mc'].name),
924
+ marker=dict(
925
+ symbol="arrow",
926
+ color=edge_color_map[data['mc'].status.ok],
927
+ size=16,
928
+ angleref="previous",
929
+ standoff=15)
930
+ ))
931
+
932
+ # nodes
933
+ node_color_map: Dict = {None: 'grey', True: 'green', False: 'red'}
934
+ node_x = []
935
+ node_y = []
936
+ node_color = []
937
+ node_text = []
938
+ node_label = []
939
+ for node in self.graph.nodes():
940
+ x, y = pos[node]
941
+ node_x.append(x)
942
+ node_y.append(y)
943
+ node_color.append(node_color_map[self.graph.nodes[node]['mc'].is_balanced])
944
+ node_text.append(node)
945
+ node_label.append(self.graph.nodes[node]['mc'].name)
946
+ node_trace = go.Scatter(
947
+ x=node_x, y=node_y,
948
+ mode='markers+text',
949
+ hoverinfo='none',
950
+ marker=dict(
951
+ color=node_color,
952
+ size=30,
953
+ line_width=2),
954
+ text=node_text,
955
+ customdata=node_label,
956
+ hovertemplate='%{customdata}<extra></extra>')
957
+
958
+ # edge annotations
959
+ edge_labels = list(edge_annotations.keys())
960
+ edge_label_x = [edge_annotations[k]['pos'][0] for k, v in edge_annotations.items()]
961
+ edge_label_y = [edge_annotations[k]['pos'][1] for k, v in edge_annotations.items()]
962
+
963
+ edge_annotation_trace = go.Scatter(
964
+ x=edge_label_x, y=edge_label_y,
965
+ mode='markers',
966
+ hoverinfo='text',
967
+ marker=dict(
968
+ color='grey',
969
+ size=3,
970
+ line_width=1),
971
+ text=edge_labels)
972
+
973
+ return edge_traces, node_trace, edge_annotation_trace
974
+
975
+ @staticmethod
976
+ def _get_position_kwargs(table_pos, table_area, plot_type):
977
+ """Helper to manage location dependencies
978
+
979
+ Args:
980
+ table_pos: position of the table: left|right|top|bottom
981
+ table_area: fraction of the plot to assign to the table [0, 1]
982
+
983
+ Returns:
984
+
985
+ """
986
+ name_type_map: Dict = {'sankey': 'sankey', 'network': 'xy'}
987
+ specs = [[{"type": 'table'}, {"type": name_type_map[plot_type]}]]
988
+
989
+ widths: Optional[List[float]] = [table_area, 1.0 - table_area]
990
+ subplot_kwargs: Dict = {'rows': 1, 'cols': 2, 'specs': specs}
991
+ table_kwargs: Dict = {'row': 1, 'col': 1}
992
+ plot_kwargs: Dict = {'row': 1, 'col': 2}
993
+
994
+ if table_pos == 'left':
995
+ subplot_kwargs['column_widths'] = widths
996
+ elif table_pos == 'right':
997
+ subplot_kwargs['column_widths'] = widths[::-1]
998
+ subplot_kwargs['specs'] = [[{"type": name_type_map[plot_type]}, {"type": 'table'}]]
999
+ table_kwargs['col'] = 2
1000
+ plot_kwargs['col'] = 1
1001
+ else:
1002
+ subplot_kwargs['rows'] = 2
1003
+ subplot_kwargs['cols'] = 1
1004
+ table_kwargs['col'] = 1
1005
+ plot_kwargs['col'] = 1
1006
+ if table_pos == 'top':
1007
+ subplot_kwargs['row_heights'] = widths
1008
+ subplot_kwargs['specs'] = [[{"type": 'table'}], [{"type": name_type_map[plot_type]}]]
1009
+ table_kwargs['row'] = 1
1010
+ plot_kwargs['row'] = 2
1011
+ elif table_pos == 'bottom':
1012
+ subplot_kwargs['row_heights'] = widths[::-1]
1013
+ subplot_kwargs['specs'] = [[{"type": name_type_map[plot_type]}], [{"type": 'table'}]]
1014
+ table_kwargs['row'] = 2
1015
+ plot_kwargs['row'] = 1
1016
+
1017
+ if plot_type == 'network': # different arguments for different plots
1018
+ plot_kwargs = {f'{k}s': v for k, v in plot_kwargs.items()}
1019
+
1020
+ return subplot_kwargs, table_kwargs, plot_kwargs
1021
+
1022
+ def _rpt_to_html(self, df: pd.DataFrame) -> Dict:
1023
+ custom_data: Dict = {}
1024
+ fmts: Dict = self._get_column_formats(df.columns)
1025
+ for i, row in df.iterrows():
1026
+ str_data: str = '<br />'
1027
+ for k, v in dict(row).items():
1028
+ str_data += f"{k}: {v:{fmts[k][1:]}}<br />"
1029
+ custom_data[i] = str_data
1030
+ return custom_data
1031
+
1032
+ @staticmethod
1033
+ def _color_from_float(vmin: float, vmax: float, val: float,
1034
+ cmap: Union[ListedColormap, LinearSegmentedColormap]) -> Tuple[float, float, float]:
1035
+ if isinstance(cmap, ListedColormap):
1036
+ color_index: int = int((val - vmin) / ((vmax - vmin) / 256.0))
1037
+ color_index = min(max(0, color_index), 255)
1038
+ color_rgba = tuple(cmap.colors[color_index])
1039
+ elif isinstance(cmap, LinearSegmentedColormap):
1040
+ norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax)
1041
+ m = cm.ScalarMappable(norm=norm, cmap=cmap)
1042
+ r, g, b, a = m.to_rgba(val, bytes=True)
1043
+ color_rgba = int(r), int(g), int(b), int(a)
1044
+ else:
1045
+ raise NotImplementedError("Unrecognised colormap type")
1046
+
1047
+ return color_rgba
1048
+
1049
+ def set_node_names(self, node_names: Dict[int, str]):
1050
+ """Set the names of network nodes with a Dict
1051
+ """
1052
+ for node in node_names.keys():
1053
+ if ('mc' in self.graph.nodes[node].keys()) and (node in node_names.keys()):
1054
+ self.graph.nodes[node]['mc'].name = node_names[node]
1055
+
1056
+
1057
+ def set_stream_data(self, stream_data: dict[str, Optional[MC]]):
1058
+ """Set the data (MassComposition) of network edges (streams) with a Dict"""
1059
+ for stream_name, new_mc in stream_data.items():
1060
+ stream_found = False
1061
+ nodes_to_refresh = set()
1062
+ for u, v, data in self.graph.edges(data=True):
1063
+ if 'mc' in data.keys() and (data['mc'].name if data['mc'] is not None else data['name']) == stream_name:
1064
+ self._logger.info(f'Setting data on stream {stream_name}')
1065
+ data['mc'] = new_mc
1066
+ stream_found = True
1067
+ nodes_to_refresh.update([u, v])
1068
+ if not stream_found:
1069
+ self._logger.warning(f'Stream {stream_name} not found in graph')
1070
+ else:
1071
+ # refresh the node status
1072
+ for node in nodes_to_refresh:
1073
+ self.graph.nodes[node]['mc'].inputs = [self.graph.get_edge_data(e[0], e[1])['mc'] for e in
1074
+ self.graph.in_edges(node)]
1075
+ self.graph.nodes[node]['mc'].outputs = [self.graph.get_edge_data(e[0], e[1])['mc'] for e in
1076
+ self.graph.out_edges(node)]
1077
+
1078
+
1079
+ def set_operation_data(self, node):
1080
+ """Set the input and output data for a node.
1081
+ Uses the data on the edges (streams) connected to the node to refresh the data and check for node balance.
1082
+ """
1083
+ node_data: Operation = self.graph.nodes[node]['mc']
1084
+ node_data.inputs = [self.graph.get_edge_data(e[0], e[1])['mc'] for e in self.graph.in_edges(node)]
1085
+ node_data.outputs = [self.graph.get_edge_data(e[0], e[1])['mc'] for e in self.graph.out_edges(node)]
1086
+ node_data.check_balance()
1087
+
1088
+
1089
+ def streams_to_dict(self) -> Dict[str, MC]:
1090
+ """Export the Stream objects to a Dict
1091
+
1092
+ Returns:
1093
+ A dictionary keyed by name containing MassComposition objects
1094
+
1095
+ """
1096
+ streams: Dict[str, MC] = {}
1097
+ for u, v, data in self.graph.edges(data=True):
1098
+ if 'mc' in data.keys():
1099
+ streams[data['mc'].name] = data['mc']
1100
+ return streams
1101
+
1102
+
1103
+ def nodes_to_dict(self) -> Dict[int, OP]:
1104
+ """Export the MCNode objects to a Dict
1105
+
1106
+ Returns:
1107
+ A dictionary keyed by integer containing Operation objects
1108
+
1109
+ """
1110
+ nodes: Dict[int, OP] = {}
1111
+ for node in self.graph.nodes.keys():
1112
+ if 'mc' in self.graph.nodes[node].keys():
1113
+ nodes[node] = self.graph.nodes[node]['mc']
1114
+ return nodes
1115
+
1116
+
1117
+ def set_nodes(self, stream: str, nodes: Tuple[int, int]):
1118
+ mc: MC = self.get_stream_by_name(stream)
1119
+ mc._nodes = nodes
1120
+ self._update_graph(mc)
1121
+
1122
+
1123
+ def reset_nodes(self, stream: Optional[str] = None):
1124
+ """Reset stream nodes to break relationships
1125
+
1126
+ Args:
1127
+ stream: The optional stream (edge) within the network.
1128
+ If None, all stream nodes on the network will be reset.
1129
+
1130
+
1131
+ Returns:
1132
+
1133
+ """
1134
+ if stream is None:
1135
+ streams: Dict[str, MC] = self.streams_to_dict()
1136
+ for k, v in streams.items():
1137
+ streams[k] = v.set_nodes([uuid.uuid4(), uuid.uuid4()])
1138
+ self.graph = Flowsheet(name=self.name).from_objects(objects=list(streams.values())).graph
1139
+ else:
1140
+ mc: MC = self.get_stream_by_name(stream)
1141
+ mc.set_nodes([uuid.uuid4(), uuid.uuid4()])
1142
+ self._update_graph(mc)
1143
+
1144
+
1145
+ def _update_graph(self, mc: MC):
1146
+ """Update the graph with an existing stream object
1147
+
1148
+ Args:
1149
+ mc: The stream object
1150
+
1151
+ Returns:
1152
+
1153
+ """
1154
+ # brutal approach - rebuild from streams
1155
+ strms: List[Union[Stream, MC]] = []
1156
+ for u, v, a in self.graph.edges(data=True):
1157
+ if a.get('mc') and a['mc'].name == mc.name:
1158
+ strms.append(mc)
1159
+ else:
1160
+ strms.append(a['mc'])
1161
+ self.graph = Flowsheet(name=self.name).from_objects(objects=strms).graph
1162
+
1163
+
1164
+ def get_stream_by_name(self, name: str) -> MC:
1165
+ """Get the Stream object from the network by its name
1166
+
1167
+ Args:
1168
+ name: The string name of the Stream object stored on an edge in the network.
1169
+
1170
+ Returns:
1171
+
1172
+ """
1173
+
1174
+ res: Optional[Union[Stream, MC]] = None
1175
+ for u, v, a in self.graph.edges(data=True):
1176
+ if a.get('mc') and a['mc'].name == name:
1177
+ res = a['mc']
1178
+
1179
+ if not res:
1180
+ raise ValueError(f"The specified name: {name} is not found on the network.")
1181
+
1182
+ return res
1183
+
1184
+
1185
+ def set_stream_parent(self, stream: str, parent: str):
1186
+ mc: MC = self.get_stream_by_name(stream)
1187
+ mc.set_parent_node(self.get_stream_by_name(parent))
1188
+ self._update_graph(mc)
1189
+
1190
+
1191
+ def set_stream_child(self, stream: str, child: str):
1192
+ mc: MC = self.get_stream_by_name(stream)
1193
+ mc.set_child_node(self.get_stream_by_name(child))
1194
+ self._update_graph(mc)
1195
+
1196
+
1197
+ def reset_stream_nodes(self, stream: Optional[str] = None):
1198
+ """Reset stream nodes to break relationships
1199
+
1200
+ Args:
1201
+ stream: The optional stream (edge) within the network.
1202
+ If None, all stream nodes on the network will be reset.
1203
+
1204
+
1205
+ Returns:
1206
+
1207
+ """
1208
+ if stream is None:
1209
+ streams: Dict[str, MC] = self.streams_to_dict()
1210
+ for k, v in streams.items():
1211
+ streams[k] = v.set_nodes([uuid.uuid4(), uuid.uuid4()])
1212
+ self.graph = Flowsheet(name=self.name).from_objects(objects=list(streams.values())).graph
1213
+ else:
1214
+ mc: MC = self.get_stream_by_name(stream)
1215
+ mc.set_nodes([uuid.uuid4(), uuid.uuid4()])
1216
+ self._update_graph(mc)