edges 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of edges might be problematic. Click here for more details.

Files changed (66)
  1. edges/__init__.py +9 -2
  2. edges/data/AWARE 2.0_Country_all_yearly.json +8 -1
  3. edges/data/AWARE 2.0_Country_irri_yearly.json +8 -1
  4. edges/data/AWARE 2.0_Country_non_irri_yearly.json +8 -1
  5. edges/data/AWARE 2.0_Country_unspecified_yearly.json +8 -1
  6. edges/data/GeoPolRisk_paired_2024.json +7 -0
  7. edges/data/ImpactWorld+ 2.1_Freshwater acidification_damage.json +8 -1
  8. edges/data/ImpactWorld+ 2.1_Freshwater acidification_midpoint.json +8 -1
  9. edges/data/ImpactWorld+ 2.1_Freshwater ecotoxicity, long term_damage.json +8 -1
  10. edges/data/ImpactWorld+ 2.1_Freshwater ecotoxicity, short term_damage.json +8 -1
  11. edges/data/ImpactWorld+ 2.1_Freshwater ecotoxicity_midpoint.json +8 -1
  12. edges/data/ImpactWorld+ 2.1_Freshwater eutrophication_damage.json +8 -1
  13. edges/data/ImpactWorld+ 2.1_Freshwater eutrophication_midpoint.json +8 -1
  14. edges/data/ImpactWorld+ 2.1_Land occupation, biodiversity_damage.json +8 -1
  15. edges/data/ImpactWorld+ 2.1_Land occupation, biodiversity_midpoint.json +8 -1
  16. edges/data/ImpactWorld+ 2.1_Land transformation, biodiversity_damage.json +8 -1
  17. edges/data/ImpactWorld+ 2.1_Land transformation, biodiversity_midpoint.json +8 -1
  18. edges/data/ImpactWorld+ 2.1_Marine ecotoxicity, long term_damage.json +8 -1
  19. edges/data/ImpactWorld+ 2.1_Marine ecotoxicity, short term_damage.json +8 -1
  20. edges/data/ImpactWorld+ 2.1_Marine eutrophication_damage.json +8 -1
  21. edges/data/ImpactWorld+ 2.1_Marine eutrophication_midpoint.json +8 -1
  22. edges/data/ImpactWorld+ 2.1_Particulate matter formation_damage.json +8 -1
  23. edges/data/ImpactWorld+ 2.1_Particulate matter formation_midpoint.json +8 -1
  24. edges/data/ImpactWorld+ 2.1_Photochemical ozone formation, ecosystem quality_damage.json +8 -1
  25. edges/data/ImpactWorld+ 2.1_Photochemical ozone formation, human health_damage.json +8 -1
  26. edges/data/ImpactWorld+ 2.1_Photochemical ozone formation_midpoint.json +8 -1
  27. edges/data/ImpactWorld+ 2.1_Terrestrial acidification_damage.json +8 -1
  28. edges/data/ImpactWorld+ 2.1_Terrestrial acidification_midpoint.json +8 -1
  29. edges/data/ImpactWorld+ 2.1_Terrestrial ecotoxicity, long term_damage.json +8 -1
  30. edges/data/ImpactWorld+ 2.1_Terrestrial ecotoxicity, short term_damage.json +8 -1
  31. edges/data/ImpactWorld+ 2.1_Thermally polluted water_damage.json +8 -1
  32. edges/data/ImpactWorld+ 2.1_Water availability, freshwater ecosystem_damage.json +8 -1
  33. edges/data/ImpactWorld+ 2.1_Water availability, human health_damage.json +8 -1
  34. edges/data/ImpactWorld+ 2.1_Water availability, terrestrial ecosystem_damage.json +8 -1
  35. edges/data/ImpactWorld+ 2.1_Water scarcity_midpoint.json +8 -1
  36. edges/data/LCC 1.0_2023.json +8 -1
  37. edges/data/RELICS_copper_primary.json +44 -0
  38. edges/data/RELICS_copper_secondary.json +42 -0
  39. edges/data/SCP_1.0.json +4 -1
  40. edges/edgelcia.py +2113 -816
  41. edges/flow_matching.py +344 -130
  42. edges/georesolver.py +61 -2
  43. edges/supply_chain.py +2052 -0
  44. edges/uncertainty.py +37 -8
  45. {edges-1.0.1.dist-info → edges-1.0.3.dist-info}/METADATA +5 -2
  46. edges-1.0.3.dist-info/RECORD +57 -0
  47. edges/data/GeoPolRisk_elementary flows_2024.json +0 -877
  48. edges/data/ImpactWorld+ 2.1_Freshwater ecotoxicity, long term_midpoint.json +0 -5
  49. edges/data/ImpactWorld+ 2.1_Freshwater ecotoxicity, short term_midpoint.json +0 -5
  50. edges/data/ImpactWorld+ 2.1_Freshwater ecotoxicity_damage.json +0 -0
  51. edges/data/ImpactWorld+ 2.1_Marine ecotoxicity, long term_midpoint.json +0 -5
  52. edges/data/ImpactWorld+ 2.1_Marine ecotoxicity, short term_midpoint.json +0 -5
  53. edges/data/ImpactWorld+ 2.1_Photochemical ozone formation, ecosystem quality_midpoint.json +0 -5
  54. edges/data/ImpactWorld+ 2.1_Photochemical ozone formation, human health_midpoint.json +0 -5
  55. edges/data/ImpactWorld+ 2.1_Photochemical ozone formation_damage.json +0 -5
  56. edges/data/ImpactWorld+ 2.1_Terrestrial ecotoxicity, long term_midpoint.json +0 -5
  57. edges/data/ImpactWorld+ 2.1_Terrestrial ecotoxicity, short term_midpoint.json +0 -5
  58. edges/data/ImpactWorld+ 2.1_Thermally polluted water_midpoint.json +0 -5
  59. edges/data/ImpactWorld+ 2.1_Water availability, freshwater ecosystem_midpoint.json +0 -5
  60. edges/data/ImpactWorld+ 2.1_Water availability, human health_midpoint.json +0 -5
  61. edges/data/ImpactWorld+ 2.1_Water availability, terrestrial ecosystem_midpoint.json +0 -5
  62. edges/data/ImpactWorld+ 2.1_Water scarcity_damage.json +0 -5
  63. edges/data/RELICS_copper.json +0 -22
  64. edges-1.0.1.dist-info/RECORD +0 -71
  65. {edges-1.0.1.dist-info → edges-1.0.3.dist-info}/WHEEL +0 -0
  66. {edges-1.0.1.dist-info → edges-1.0.3.dist-info}/top_level.txt +0 -0
edges/supply_chain.py ADDED
@@ -0,0 +1,2052 @@
1
+ # edge_supply_chain.py
2
+ from __future__ import annotations
3
+
4
+ from dataclasses import dataclass, asdict
5
+ from io import StringIO
6
+ import textwrap as _tw
7
+ import collections
8
+
9
+ import math
10
+ import time
11
+ from typing import Optional, Sequence, Tuple, Dict, Any, List
12
+
13
+ import pandas as pd
14
+ import plotly.graph_objects as go
15
+
16
+ import os
17
+ from pathlib import Path
18
+ import html
19
+ import re
20
+ import plotly.io as pio
21
+
22
+
23
+ try:
24
+ from bw2data.backends.peewee import Activity
25
+ except ImportError: # bw2data >= 4.0
26
+ from bw2data.backends import Activity
27
+
28
+ from .edgelcia import EdgeLCIA
29
+
30
+ from bw2data import __version__ as bw2data_version
31
+
32
# bw2data may report its version either as a string ("4.0.1") or as a tuple;
# normalise to a tuple of ints so it can be compared numerically.
if isinstance(bw2data_version, str):
    bw2data_version = tuple(map(int, bw2data_version.split(".")))

# Brightway 2.5 corresponds to bw2data >= 4.0.0.
is_bw25 = bw2data_version >= (4, 0, 0)
39
+
40
+
41
+ # --- helpers for labels
42
def truncate_one_line(text: str, max_chars: int) -> str:
    """Return *text* stripped and truncated to at most *max_chars* characters.

    ``None`` maps to the empty string. When truncation occurs, the last kept
    character is replaced by an ellipsis so the result still fits within
    *max_chars*.

    Parameters
    ----------
    text : str
        Value to shorten; coerced with ``str()`` and whitespace-stripped.
    max_chars : int
        Maximum length of the returned string.

    Returns
    -------
    str
        The (possibly truncated) single-line label.
    """
    if text is None:
        return ""
    s = str(text).strip()
    if len(s) <= max_chars:
        return s
    # Fix: with max_chars <= 0 the original returned a bare "…" (length 1),
    # which exceeded the requested maximum; return an empty string instead.
    if max_chars <= 0:
        return ""
    return s[: max_chars - 1] + "…"
49
+
50
+
51
def make_label_two_lines(name: str, location: str, name_chars: int) -> str:
    """Line 1: truncated name (single line). Line 2: full location, never truncated."""
    # Build the first line: coerce, strip, and truncate to name_chars with a
    # trailing ellipsis when the name is too long.
    head = str(name or "").strip()
    if len(head) > name_chars:
        head = head[: max(0, name_chars - 1)] + "…"

    # Second line: the location verbatim (stripped); missing values become "".
    if location is None or pd.isna(location):
        tail = ""
    else:
        tail = str(location).strip()

    return f"{head}\n{tail}" if tail else head
56
+
57
+
58
+ def _is_market_name(val: Any) -> bool:
59
+ """True if the activity 'name' looks like a market node."""
60
+ if pd.isna(val):
61
+ return False
62
+ s = str(val).strip().lower()
63
+ return s.startswith("market for ") or s.startswith("market group for ")
64
+
65
+
66
+ # --- Multi-method (multi-impact) HTML export ---------------------------------
67
def save_sankey_html_multi(
    label_to_df: Dict[str, pd.DataFrame],
    path: str,
    *,
    title: str = "Supply chain Sankey — multiple impact categories",
    offline: bool = True,
    auto_open: bool = True,
    plot_kwargs: Optional[Dict[str, Any]] = None,
    modebar_remove: tuple = ("lasso2d", "select2d"),
) -> str:
    """
    Save several Sankey figures (one per impact category) into a single tabbed HTML.

    Parameters
    ----------
    label_to_df : {label: DataFrame}
        Keys are tab labels (e.g., method names); values are the dataframes to plot.
    path : str
        Output file path; '.html' will be appended if missing.
    title : str
        Browser tab title.
    offline : bool
        If True, embed plotly.js once into the file. Otherwise load from CDN.
    auto_open : bool
        If True, open the file in a browser after writing.
    plot_kwargs : dict
        Extra kwargs forwarded to sankey_from_supply_df (e.g., width_max, height_max).
    modebar_remove : tuple
        Modebar buttons to remove in each figure.

    Returns
    -------
    str : the file path written.
    """

    plot_kwargs = plot_kwargs or {}
    if not path.lower().endswith(".html"):
        path += ".html"
    # Create the parent directory if needed so the final write cannot fail on it.
    Path(os.path.dirname(path) or ".").mkdir(parents=True, exist_ok=True)

    # Build one figure per label
    pieces: List[tuple[str, str]] = []  # (tab label, rendered figure HTML snippet)
    include = "cdn" if not offline else True
    config = {"displaylogo": False, "modeBarButtonsToRemove": list(modebar_remove)}

    def _slug(s: str) -> str:
        # Reduce an arbitrary label to a safe HTML id fragment:
        # whitespace -> "-", then drop anything outside [A-Za-z0-9-_].
        s2 = re.sub(r"\s+", "-", s.strip())
        s2 = re.sub(r"[^A-Za-z0-9\-_]", "", s2)
        return s2 or "tab"

    first = True
    for label, df in label_to_df.items():
        fig = sankey_from_supply_df(df, **plot_kwargs)
        # include plotly.js only once
        html_snippet = pio.to_html(
            fig,
            include_plotlyjs=(include if first else False),
            full_html=False,
            config=config,
        )
        pieces.append((label, html_snippet))
        first = False

    # Simple tab UI (CSS+JS) and body with all figures (hidden except first)
    # We wrap each snippet in a container <div class="tab-pane"> and switch display via JS.
    tabs_html = []
    panes_html = []
    for i, (label, snippet) in enumerate(pieces):
        tab_id = f"tab-{_slug(label)}"
        active = "active" if i == 0 else ""
        tabs_html.append(
            f'<button class="tab-btn {active}" onclick="showTab(\'{tab_id}\', this)">{html.escape(label)}</button>'
        )
        panes_html.append(
            f'<div id="{tab_id}" class="tab-pane {active}">{snippet}</div>'
        )

    # Full page: sticky tab bar on top, one hidden pane per figure, and a small
    # script that toggles the "active" class and asks Plotly to resize the
    # revealed plot (hidden plots can be laid out with a stale size).
    full = f"""<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8"/>
<title>{html.escape(title)}</title>
<style>
body {{ font-family: system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial, sans-serif; margin: 0; }}
.tabs {{ position: sticky; top: 0; background: #fafafa; border-bottom: 1px solid #eee; padding: 8px; z-index: 10; display: flex; flex-wrap: wrap; gap: 6px; }}
.tab-btn {{ border: 1px solid #ddd; background: #fff; border-radius: 6px; padding: 6px 10px; cursor: pointer; }}
.tab-btn.active {{ background: #0d6efd; color: white; border-color: #0d6efd; }}
.tab-pane {{ display: none; padding: 8px; }}
.tab-pane.active {{ display: block; }}
</style>
</head>
<body>
<div class="tabs">
{''.join(tabs_html)}
</div>
{''.join(panes_html)}
<script>
function showTab(id, btn) {{
document.querySelectorAll('.tab-pane').forEach(p => p.classList.remove('active'));
document.querySelectorAll('.tab-btn').forEach(b => b.classList.remove('active'));
const el = document.getElementById(id);
if (el) el.classList.add('active');
if (btn) btn.classList.add('active');
// Force Plotly to resize when switching tabs (in case container size changed)
if (window.Plotly && el) {{
el.querySelectorAll('.js-plotly-plot').forEach(plot => {{
try {{ window.Plotly.Plots.resize(plot); }} catch(e) {{}}
}});
}}
}}
// Ensure first tab active on load
(function() {{
const firstPane = document.querySelector('.tab-pane');
const firstBtn = document.querySelector('.tab-btn');
if (firstPane) firstPane.classList.add('active');
if (firstBtn) firstBtn.classList.add('active');
}})();
</script>
</body>
</html>"""

    with open(path, "w", encoding="utf-8") as f:
        f.write(full)

    if auto_open:
        import webbrowser

        webbrowser.open(f"file://{os.path.abspath(path)}")
    return path
196
+
197
+
198
+ def sankey_from_supply_df(
199
+ df: pd.DataFrame,
200
+ *,
201
+ col_level: str = "level",
202
+ col_id: str = "activity_key",
203
+ col_parent: str = "parent_key",
204
+ col_name: str = "name",
205
+ col_location: str = "location",
206
+ col_score: str = "score",
207
+ col_amount: str = "amount",
208
+ wrap_chars: int = 18,
209
+ max_label_lines: int = 2,
210
+ add_toggle: bool = True,
211
+ base_height: int = 380,
212
+ per_level_px: int = 110,
213
+ per_node_px: int = 6,
214
+ height_min: int = 460,
215
+ height_max: int = 1200,
216
+ auto_width: bool = False,
217
+ per_level_width: int = 250,
218
+ per_node_width: int = 2,
219
+ width_min: int = 900,
220
+ width_max: Optional[int] = None,
221
+ node_thickness: int = 18,
222
+ node_pad: int = 12,
223
+ lock_x_by_level: bool = True,
224
+ balance_mode: str = "none",
225
+ palette: Sequence[str] = (
226
+ "#636EFA",
227
+ "#EF553B",
228
+ "#00CC96",
229
+ "#AB63FA",
230
+ "#FFA15A",
231
+ "#19D3F3",
232
+ "#FF6692",
233
+ "#B6E880",
234
+ "#FF97FF",
235
+ "#FECB52",
236
+ ),
237
+ # Category colors
238
+ color_direct: str = "#E53935",
239
+ color_below: str = "#FB8C00",
240
+ color_loss: str = "#FDD835",
241
+ color_other: str = "#9E9E9E",
242
+ col_ref_product: str = "reference product",
243
+ enable_highlight: bool = True,
244
+ highlight_top_k: int = 25,
245
+ highlight_alpha_on: float = 0.9,
246
+ highlight_alpha_off: float = 0.08,
247
+ node_instance_mode: str = "merge", # "merge" | "by_parent" | "by_child_level"
248
+ ) -> go.Figure:
249
+ """Sankey with last-level specials, untruncated hover labels, per-parent outgoing balancing, and tidy UI."""
250
+ if df.empty:
251
+ raise ValueError("Empty DataFrame")
252
+
253
+ df = df.copy()
254
+
255
+ for c in [col_level, col_name, col_score, col_id, col_parent]:
256
+ if c not in df.columns:
257
+ raise KeyError(f"Missing required column '{c}' in df")
258
+
259
+ if col_location not in df.columns:
260
+ df[col_location] = ""
261
+ else:
262
+ df[col_location] = df[col_location].apply(
263
+ lambda x: "" if (pd.isna(x) or x is None) else str(x)
264
+ )
265
+
266
+ # Root total for %
267
+ try:
268
+ root = df.loc[df[col_level] == df[col_level].min()].iloc[0]
269
+ total_root_score = float(root[col_score])
270
+ except Exception:
271
+ total_root_score = float(df[col_score].abs().max())
272
+
273
+ # Helpers
274
+
275
+ def _rgba_with_alpha(c: str, a: float) -> str:
276
+ c = str(c)
277
+ if c.startswith("rgba("):
278
+ # replace alpha
279
+ parts = c[5:-1].split(",")
280
+ if len(parts) >= 4:
281
+ parts = [p.strip() for p in parts[:3]] + [f"{a:.3f}"]
282
+ return f"rgba({','.join(parts)})"
283
+ if c.startswith("#"):
284
+ return hex_to_rgba(c, a)
285
+ # fallback: try to parse "rgb(r,g,b)"
286
+ if c.startswith("rgb("):
287
+ parts = c[4:-1].split(",")
288
+ if len(parts) == 3:
289
+ parts = [p.strip() for p in parts]
290
+ return f"rgba({parts[0]},{parts[1]},{parts[2]},{a:.3f})"
291
+ # last resort: force to grey w/ alpha
292
+ return f"rgba(150,150,150,{a:.3f})"
293
+
294
+ def _append_tag_to_label(lbl: str, tag: str) -> str:
295
+ """Append a short tag to the *last* line of a 1–2 line label."""
296
+ parts = lbl.split("\n")
297
+ if not parts:
298
+ return tag
299
+ parts[-1] = f"{parts[-1]} {tag}"
300
+ return "\n".join(parts)
301
+
302
+ def _normalize_special(raw: Any) -> Optional[str]:
303
+ if pd.isna(raw):
304
+ return None
305
+ s = str(raw).strip().lower()
306
+ if not s:
307
+ return None
308
+ # Direct emissions variants
309
+ if s.startswith("direct emissions"):
310
+ # accept "direct emissions", "direct emissions/res. use", etc.
311
+ return "direct emissions"
312
+ # Below cutoff variants
313
+ if s in {"activities below cutoff", "below cutoff"}:
314
+ return "activities below cutoff"
315
+ # Loss
316
+ if s == "loss":
317
+ return "loss"
318
+ return None
319
+
320
+ special_names = {
321
+ "direct emissions": "Direct emissions/Res. use",
322
+ "activities below cutoff": "Activities below cutoff",
323
+ "loss": "Loss",
324
+ }
325
+ SPECIAL_NODE_COLOR = {
326
+ "direct emissions": color_direct,
327
+ "activities below cutoff": color_below,
328
+ "loss": color_loss,
329
+ }
330
+
331
+ def is_special(nm: Any) -> bool:
332
+ return _normalize_special(nm) in special_names
333
+
334
+ def special_key(row) -> Optional[Tuple[str, str]]:
335
+ nm = _normalize_special(row[col_name])
336
+ return (nm, "__GLOBAL__") if nm else None
337
+
338
+ def special_label(nm: str) -> str:
339
+ return special_names[nm]
340
+
341
+ def fallback_key(idx, r):
342
+ ak = r.get(col_id)
343
+ if pd.notna(ak):
344
+ return ak
345
+ return (r.get(col_name), r.get(col_location), r.get("unit"), int(idx))
346
+
347
+ def hex_to_rgba(h: str, a: float) -> str:
348
+ h = h.lstrip("#")
349
+ if len(h) == 3:
350
+ h = "".join([c * 2 for c in h])
351
+ r = int(h[0:2], 16)
352
+ g = int(h[2:4], 16)
353
+ b = int(h[4:6], 16)
354
+ return f"rgba({r},{g},{b},{a})"
355
+
356
+ def palette_cycle(i: int, base: Sequence[str]) -> str:
357
+ return base[i % len(base)]
358
+
359
+ # Collect full ref-product text per node index for hover
360
+ node_full_refprod: Dict[int, str] = {}
361
+
362
+ def _row_refprod(row) -> str:
363
+ """Best-effort: prefer explicit column; else try second item of activity_key tuple."""
364
+ # explicit column
365
+ if col_ref_product in df.columns:
366
+ val = row.get(col_ref_product, None)
367
+ if pd.notna(val) and val is not None and str(val).strip():
368
+ return str(val).strip()
369
+ # infer from activity_key tuple (name, reference product, location)
370
+ ak = row.get(col_id, None)
371
+ if (
372
+ isinstance(ak, tuple)
373
+ and len(ak) >= 2
374
+ and pd.notna(ak[1])
375
+ and ak[1] is not None
376
+ ):
377
+ s = str(ak[1]).strip()
378
+ if s:
379
+ return s
380
+ return ""
381
+
382
+ df["_is_special"] = df[col_name].apply(is_special)
383
+
384
+ # Columns (specials live in the *last real* level)
385
+ # Columns: compute the column set from NON-SPECIAL nodes
386
+ # so specials don't create their own extra column.
387
+ # (We compute df["_is_special"] above with is_special/normalizer.)
388
+ levels_all = sorted(int(l) for l in df[col_level].unique())
389
+
390
+ non_special_mask = ~df["_is_special"]
391
+ levels_real = (
392
+ sorted(int(l) for l in df.loc[non_special_mask, col_level].unique())
393
+ if non_special_mask.any()
394
+ else []
395
+ )
396
+
397
+ levels = levels_real if levels_real else levels_all
398
+
399
+ col_index = {L: i for i, L in enumerate(levels)}
400
+ ncols = len(levels)
401
+ max_real_level = levels[-1] # last level among REGULAR nodes
402
+ last_col = col_index[max_real_level]
403
+ level_to_color = {lvl: i for i, lvl in enumerate(levels)}
404
+
405
+ # Build nodes (visible truncated labels + full label for hover)
406
+ labels_vis: List[str] = []
407
+ colors_full: List[str] = []
408
+ x_full: List[float] = []
409
+ key_to_idx: Dict[Any, int] = {}
410
+ node_full_name: Dict[int, str] = {}
411
+ node_full_loc: Dict[int, str] = {}
412
+
413
+ node_key_by_idx: Dict[Any, Any] = {}
414
+
415
+ def _node_key_for_row(idx, r) -> Any:
416
+ if r["_is_special"]:
417
+ return special_key(r) # keep specials global
418
+
419
+ base = fallback_key(idx, r) # usually the activity_key tuple
420
+ mode = (node_instance_mode or "merge").lower()
421
+
422
+ if mode == "merge":
423
+ return base
424
+
425
+ elif mode == "by_parent":
426
+ # one instance per (activity, parent)
427
+ return "by_parent", base, r.get(col_parent)
428
+
429
+ elif mode == "by_child_level":
430
+ # one instance per (activity, child level) – useful if the same activity appears at multiple levels
431
+ return "by_level", base, int(r[col_level])
432
+
433
+ else:
434
+ # fallback to old behavior if an unknown value is passed
435
+ return base
436
+
437
+ for i, r in df.sort_values([col_level]).iterrows():
438
+ L = int(r[col_level])
439
+ full_name = str(r[col_name]) if pd.notna(r[col_name]) else ""
440
+ full_loc = (
441
+ str(r.get(col_location, "")) if pd.notna(r.get(col_location, "")) else ""
442
+ )
443
+
444
+ if r["_is_special"]:
445
+ key = special_key(r) # e.g., ("direct emissions","__GLOBAL__")
446
+ label_disp = special_label(key[0])
447
+ x_val = last_col / max(1, (ncols - 1)) if ncols > 1 else 0.0
448
+ color = SPECIAL_NODE_COLOR.get(key[0], palette[0])
449
+ else:
450
+ key = _node_key_for_row(i, r)
451
+ label_disp = make_label_two_lines(full_name, full_loc, wrap_chars)
452
+ L_eff = L if L in col_index else max_real_level
453
+ x_val = col_index[L_eff] / max(1, (ncols - 1)) if ncols > 1 else 0.0
454
+ color = palette[level_to_color.get(L_eff, 0) % len(palette)]
455
+
456
+ vis_lbl = wrap_label(label_disp, wrap_chars, max_label_lines)
457
+
458
+ if key not in key_to_idx:
459
+ idx = len(labels_vis)
460
+ key_to_idx[key] = idx
461
+ labels_vis.append(vis_lbl)
462
+ colors_full.append(color)
463
+ x_full.append(float(max(0.0, min(0.999, x_val))))
464
+ node_full_name[idx] = full_name
465
+ node_full_loc[idx] = full_loc
466
+ node_full_refprod[idx] = "" if r["_is_special"] else _row_refprod(r)
467
+ node_key_by_idx[i] = key
468
+
469
+ df["_node_key"] = pd.Series(
470
+ (node_key_by_idx.get(i) for i in df.index),
471
+ index=df.index,
472
+ dtype="object",
473
+ )
474
+
475
+ rowid_to_nodeidx = {}
476
+ if "row_id" in df.columns:
477
+ for _, r in df.iterrows():
478
+ rid = r.get("row_id", None)
479
+ if rid is not None and not pd.isna(rid):
480
+ rowid_to_nodeidx[int(rid)] = key_to_idx[r["_node_key"]]
481
+
482
+ # --- Per-node score/share for node hover ------------------------------------
483
+ node_score_by_idx = collections.defaultdict(float)
484
+ node_share_by_idx = collections.defaultdict(float)
485
+
486
+ accumulate_for_merge = (node_instance_mode or "merge").lower() == "merge"
487
+ for _, r in df.iterrows():
488
+ idx = key_to_idx.get(r["_node_key"])
489
+ if idx is None:
490
+ continue
491
+ sc = float(r.get(col_score) or 0.0)
492
+ sh = (
493
+ float(r.get("share_of_total"))
494
+ if "share_of_total" in df.columns and pd.notna(r.get("share_of_total"))
495
+ else ((sc / total_root_score) if total_root_score else 0.0)
496
+ )
497
+ if r["_is_special"] or accumulate_for_merge:
498
+ node_score_by_idx[idx] += sc
499
+ node_share_by_idx[idx] += sh
500
+ else:
501
+ node_score_by_idx[idx] = sc
502
+ node_share_by_idx[idx] = sh
503
+
504
+ # Build link rows (always include below-cutoff)
505
+ def link_rows():
506
+ out = []
507
+
508
+ # Fast lookup by row_id if present
509
+ df_by_rowid = None
510
+ if "row_id" in df.columns:
511
+ df_by_rowid = {
512
+ int(rr["row_id"]): rr
513
+ for _, rr in df.iterrows()
514
+ if rr.get("row_id") is not None and not pd.isna(rr.get("row_id"))
515
+ }
516
+
517
+ for _, r in df.iterrows():
518
+ s_idx, prow = None, None
519
+
520
+ # Preferred: wire by parent_row_id (exact instance)
521
+ pri = r.get("parent_row_id", None)
522
+ if pri is not None and not pd.isna(pri) and df_by_rowid is not None:
523
+ pri = int(pri)
524
+ prow = df_by_rowid.get(pri)
525
+ if prow is not None:
526
+ s_idx = rowid_to_nodeidx.get(pri)
527
+
528
+ # Fallback: wire by parent_key (may merge instances)
529
+ if s_idx is None:
530
+ pid = r.get(col_parent)
531
+ if pd.isna(pid) or pid is None:
532
+ continue
533
+ prows = df.loc[df[col_id] == pid]
534
+ if prows.empty:
535
+ continue
536
+ prow = prows.iloc[0]
537
+ parent_key = prow["_node_key"]
538
+ s_idx = key_to_idx.get(parent_key)
539
+
540
+ # Target (child)
541
+ t_idx = key_to_idx.get(r["_node_key"])
542
+ if s_idx is None or t_idx is None:
543
+ continue
544
+
545
+ v = float(r[col_score] or 0.0)
546
+ if v == 0:
547
+ continue
548
+
549
+ out.append((s_idx, t_idx, v, prow, r))
550
+ return out
551
+
552
+ rows_all = link_rows()
553
+
554
+ # --- adjacency for highlight ---------------------------------------------
555
+ from collections import defaultdict, deque
556
+
557
+ # rows_all: list of (s_idx, t_idx, v_signed, prow, crow)
558
+ children = defaultdict(list)
559
+ parents = defaultdict(list)
560
+ for li, (s_idx, t_idx, _v, _prow, _crow) in enumerate(rows_all):
561
+ children[s_idx].append(t_idx)
562
+ parents[t_idx].append(s_idx)
563
+
564
+ from collections import defaultdict as _dd
565
+
566
+ # Indices of special global nodes
567
+ _special_idx = {
568
+ idx
569
+ for key, idx in key_to_idx.items()
570
+ if isinstance(key, tuple) and len(key) == 2 and key[1] == "__GLOBAL__"
571
+ }
572
+
573
+ # Group by full (name, location) for non-special nodes
574
+ _label_groups = _dd(list)
575
+ for i in range(len(labels_vis)):
576
+ if i in _special_idx:
577
+ continue
578
+ key = (node_full_name.get(i, "").strip(), node_full_loc.get(i, "").strip())
579
+ _label_groups[key].append(i)
580
+
581
+ _instance_info = {}
582
+ for (_nm, _loc), idxs in _label_groups.items():
583
+ if len(idxs) <= 1:
584
+ continue
585
+ for pos, idx in enumerate(sorted(idxs), start=1):
586
+ pars = parents.get(idx, [])
587
+ if len(pars) == 1:
588
+ p_name = (node_full_name.get(pars[0], "") or "").strip()
589
+ short_parent = truncate_one_line(p_name, 16) or "parent"
590
+ tag = f"⟵ {short_parent}"
591
+ else:
592
+ tag = f"[{pos}]"
593
+ labels_vis[idx] = _append_tag_to_label(labels_vis[idx], tag)
594
+ _instance_info[idx] = (pos, len(idxs))
595
+
596
+ def _descendants(root: int) -> set[int]:
597
+ out, q = {root}, deque([root])
598
+ while q:
599
+ u = q.popleft()
600
+ for v in children.get(u, ()):
601
+ if v not in out:
602
+ out.add(v)
603
+ q.append(v)
604
+ return out
605
+
606
+ def _ancestors(root: int) -> set[int]:
607
+ out, q = {root}, deque([root])
608
+ while q:
609
+ v = q.popleft()
610
+ for u in parents.get(v, ()):
611
+ if u not in out:
612
+ out.add(u)
613
+ q.append(u)
614
+ return out
615
+
616
+ # rank links by |value| (absolute contribution), keep top-K as candidates
617
+ link_abs = [abs(v) for (_s, _t, v, _prow, _crow) in rows_all]
618
+ order_links = sorted(range(len(rows_all)), key=lambda i: -link_abs[i])
619
+ topK_idx = order_links[: min(highlight_top_k, len(order_links))]
620
+
621
+ # Incident magnitudes (for ordering/spacing)
622
+
623
+ magnitude = collections.defaultdict(float)
624
+ for s, t, v, _, _ in rows_all:
625
+ a = abs(v)
626
+ magnitude[s] += a
627
+ magnitude[t] += a
628
+
629
+ # Group nodes by column
630
+ def col_from_x(xv: float) -> int:
631
+ return int(round(xv * max(1, (ncols - 1))))
632
+
633
+ nodes_by_col: Dict[int, List[int]] = {c: [] for c in range(ncols)}
634
+ for k, idx in key_to_idx.items():
635
+ c = col_from_x(x_full[idx])
636
+ nodes_by_col[c].append(idx)
637
+
638
+ from collections import Counter
639
+
640
+ parent_locs = [(prow.get(col_location, "") or "—") for _, _, _, prow, _ in rows_all]
641
+ loc_counts = Counter(parent_locs)
642
+ unique_locs_sorted = [k for (k, _) in loc_counts.most_common()]
643
+ loc_to_color: Dict[str, str] = {
644
+ loc: palette_cycle(i, palette) for i, loc in enumerate(unique_locs_sorted)
645
+ }
646
+ MAX_LOC_LEGEND = 8
647
+
648
+ # Hover (full labels)
649
+ def _fmt_pct(x) -> str:
650
+ try:
651
+ v = 100.0 * float(x)
652
+ except Exception:
653
+ return "0%"
654
+ if v != 0 and abs(v) < 0.01:
655
+ return "<0.01%"
656
+ return f"{v:.2f}%"
657
+
658
+ def _rp_from_index_or_row(node_idx: int, row) -> str:
659
+ """Prefer the per-node cache; if empty, re-infer from the row."""
660
+ rp = (node_full_refprod.get(node_idx) or "").strip()
661
+ if not rp:
662
+ # reuse the same logic you used to build node_full_refprod
663
+ try:
664
+ # try explicit column if present
665
+ if col_ref_product in df.columns:
666
+ val = row.get(col_ref_product, None)
667
+ if pd.notna(val) and val is not None and str(val).strip():
668
+ return str(val).strip()
669
+ except Exception:
670
+ pass
671
+ # fallback to activity_key tuple
672
+ ak = row.get(col_id, None)
673
+ if (
674
+ isinstance(ak, tuple)
675
+ and len(ak) >= 2
676
+ and pd.notna(ak[1])
677
+ and ak[1] is not None
678
+ ):
679
+ return str(ak[1]).strip()
680
+ return rp
681
+
682
+ def make_hover_link(s_idx: int, t_idx: int, v_signed: float, prow, crow) -> str:
683
+ # % of total
684
+ rel_total = (abs(v_signed) / abs(total_root_score)) if total_root_score else 0.0
685
+
686
+ parent_loc = prow.get(col_location, "") or "—"
687
+ child_key = crow["_node_key"]
688
+ child_loc = (
689
+ "—"
690
+ if (
691
+ isinstance(child_key, tuple)
692
+ and len(child_key) == 2
693
+ and child_key[1] == "__GLOBAL__"
694
+ )
695
+ else (crow.get(col_location, "") or "—")
696
+ )
697
+
698
+ parent_name = node_full_name.get(s_idx, "")
699
+ child_name = node_full_name.get(t_idx, "")
700
+
701
+ # --- Reference products (read from the per-node cache, not the row) ---
702
+ parent_rp = _rp_from_index_or_row(s_idx, prow)
703
+ child_rp = _rp_from_index_or_row(t_idx, crow)
704
+
705
+ # If the child is the special "below cutoff" node, use the summary string
706
+ nm_special = None
707
+ if (
708
+ isinstance(child_key, tuple)
709
+ and len(child_key) == 2
710
+ and child_key[1] == "__GLOBAL__"
711
+ ):
712
+ nm_special = child_key[
713
+ 0
714
+ ] # "direct emissions" | "activities below cutoff" | "loss"
715
+ if not child_rp and nm_special == "activities below cutoff":
716
+ child_rp = (crow.get("collapsed_ref_products") or "").strip()
717
+
718
+ extra_lines = []
719
+ if parent_rp:
720
+ extra_lines.append(f"<br><i>Parent ref product:</i> {parent_rp}")
721
+ if child_rp:
722
+ label = "Child ref product" + (
723
+ "(s)" if nm_special == "activities below cutoff" else ""
724
+ )
725
+ extra_lines.append(f"<br><i>{label}:</i> {child_rp}")
726
+
727
+ amt = crow.get(col_amount, None)
728
+ amt_line = (
729
+ f"<br>Raw amount: {amt:,.5g}"
730
+ if (amt is not None and not pd.isna(amt))
731
+ else ""
732
+ )
733
+
734
+ return (
735
+ f"<b>{child_name}</b> ← <b>{parent_name}</b>"
736
+ f"<br><i>Child location:</i> {child_loc}"
737
+ f"<br><i>Parent location:</i> {parent_loc}"
738
+ f"<br>Flow: {v_signed:,.5g}"
739
+ f"<br>Contribution of total: {_fmt_pct(rel_total)}"
740
+ f"{amt_line}" + "".join(extra_lines)
741
+ )
742
+
743
+ node_hoverdata = []
744
+ for i in range(len(labels_vis)):
745
+ parts = [f"<b>{node_full_name.get(i,'')}</b>"]
746
+
747
+ rp = (node_full_refprod.get(i, "") or "").strip()
748
+ if rp:
749
+ parts.append(f"<i>Ref. product:</i> {rp}")
750
+
751
+ loc = (node_full_loc.get(i, "") or "").strip()
752
+ if loc:
753
+ parts.append(loc)
754
+
755
+ # Add node score and share
756
+ sc = node_score_by_idx.get(i, None)
757
+ if sc is not None:
758
+ parts.append(f"<i>Node score:</i> {sc:,.6g}")
759
+ if total_root_score:
760
+ parts.append(f"<i>Share of total:</i> {_fmt_pct(sc/total_root_score)}")
761
+
762
+ # If this node label was disambiguated, show the instance number
763
+ inst = _instance_info.get(i)
764
+ if inst:
765
+ parts.append(f"<i>Instance:</i> #{inst[0]} of {inst[1]}")
766
+
767
+ node_hoverdata.append("<br>".join(parts))
768
+
769
+ # ---------- Forward pass scaling: make outgoing == actually-received incoming ----------
770
+ balance_mode = str(balance_mode).lower()
771
+
772
+ # base (unscaled) absolute link widths
773
+ base_vals = [abs(v) for (_s, _t, v, _pr, _cr) in rows_all]
774
+
775
+ # index outgoing links per source, incoming links per target
776
+ from collections import defaultdict
777
+
778
+ out_links = defaultdict(list)
779
+ in_links = defaultdict(list)
780
+ for li, (s_idx, t_idx, _v, _pr, _cr) in enumerate(rows_all):
781
+ out_links[s_idx].append(li)
782
+ in_links[t_idx].append(li)
783
+
784
+ # unscaled outgoing sum per node
785
+ out_abs = {
786
+ node: sum(base_vals[li] for li in out_links.get(node, ()))
787
+ for node in range(len(labels_vis))
788
+ }
789
+
790
+ # incoming widths after upstream scaling (initialize zeros)
791
+ incoming_scaled = [0.0] * len(labels_vis)
792
+ out_scale = [1.0] * len(labels_vis) # default
793
+
794
+ # process nodes by column from left to right so parents go first
795
+ cols_sorted = sorted(nodes_by_col.keys())
796
+ for col in cols_sorted:
797
+ for node in nodes_by_col[col]:
798
+ out_sum = out_abs.get(node, 0.0)
799
+ in_sum = incoming_scaled[node]
800
+
801
+ if out_sum > 0:
802
+ if balance_mode == "match" and in_sum > 0:
803
+ out_scale[node] = in_sum / out_sum
804
+ elif balance_mode == "cap" and in_sum > 0:
805
+ out_scale[node] = min(1.0, in_sum / out_sum)
806
+ else:
807
+ out_scale[node] = 1.0
808
+ else:
809
+ out_scale[node] = 1.0
810
+
811
+ # propagate scaled outgoing to children
812
+ for li in out_links.get(node, ()):
813
+ s_idx, t_idx, _v, _pr, _cr = rows_all[li]
814
+ incoming_scaled[t_idx] += base_vals[li] * out_scale[node]
815
+
816
+ # Build links for both color modes, applying the per-parent scale to outgoing widths
817
+ def links_category(rows):
818
+ src, tgt, val, colr, hov = [], [], [], [], []
819
+ for li, (s_idx, t_idx, v_signed, prow, crow) in enumerate(rows):
820
+ ck = crow["_node_key"]
821
+ nm = ck[0] if (isinstance(ck, tuple)) else ""
822
+ if nm == "direct emissions":
823
+ c = hex_to_rgba(color_direct, 0.55)
824
+ elif nm == "activities below cutoff":
825
+ c = hex_to_rgba(color_below, 0.55)
826
+ elif nm == "loss":
827
+ c = hex_to_rgba(color_loss, 0.55)
828
+ else:
829
+ c = hex_to_rgba(color_other, 0.40)
830
+
831
+ v = base_vals[li] * out_scale[s_idx]
832
+ src.append(s_idx)
833
+ tgt.append(t_idx)
834
+ val.append(v)
835
+ colr.append(c)
836
+ hov.append(make_hover_link(s_idx, t_idx, v_signed, prow, crow))
837
+ return dict(source=src, target=tgt, value=val, color=colr, customdata=hov)
838
+
839
+ def links_by_parentloc(rows):
840
+ src, tgt, val, colr, hov = [], [], [], [], []
841
+ for li, (s_idx, t_idx, v_signed, prow, crow) in enumerate(rows):
842
+ base = loc_to_color.get(prow.get(col_location, "") or "—", color_other)
843
+ c = hex_to_rgba(base, 0.60)
844
+ v = base_vals[li] * out_scale[s_idx] # <--- scaled width
845
+ src.append(s_idx)
846
+ tgt.append(t_idx)
847
+ val.append(v)
848
+ colr.append(c)
849
+ hov.append(make_hover_link(s_idx, t_idx, v_signed, prow, crow))
850
+ return dict(source=src, target=tgt, value=val, color=colr, customdata=hov)
851
+
852
+ links_cat = links_category(rows_all)
853
+ links_loc = links_by_parentloc(rows_all)
854
+
855
+ # ----- Build "hide specials" variants (transparent color ONLY; keep values) -----
856
+ is_special_target = []
857
+ for _s, _t, _v, _prow, crow in rows_all:
858
+ ck = crow["_node_key"]
859
+ is_special_target.append(
860
+ isinstance(ck, tuple) and len(ck) == 2 and ck[1] == "__GLOBAL__"
861
+ )
862
+
863
+ def _hide_colors(colors, mask):
864
+ return [_rgba_with_alpha(c, 0.0) if m else c for c, m in zip(colors, mask)]
865
+
866
+ cat_cols_hide = _hide_colors(links_cat["color"], is_special_target)
867
+ loc_cols_hide = _hide_colors(links_loc["color"], is_special_target)
868
+
869
+ # (optional) also mute hover on hidden links:
870
+ cat_hover_hide = [
871
+ ("" if m else h) for h, m in zip(links_cat["customdata"], is_special_target)
872
+ ]
873
+ loc_hover_hide = [
874
+ ("" if m else h) for h, m in zip(links_loc["customdata"], is_special_target)
875
+ ]
876
+
877
+ # --- base color arrays for restyling --------------------------------------
878
+ node_colors_base = [_rgba_with_alpha(c, 1.0) for c in colors_full]
879
+ link_colors_cat_base = list(links_cat["color"])
880
+ link_colors_loc_base = list(links_loc["color"])
881
+
882
+ def _make_highlight_state(link_i: int):
883
+ """Return (node_colors, link_colors_cat, link_colors_loc) for a selected link."""
884
+ s_idx, t_idx, _v, _prow, _crow = rows_all[link_i]
885
+
886
+ # upstream: all ancestors of source; downstream: all descendants of target
887
+ up_nodes = _ancestors(s_idx)
888
+ down_nodes = _descendants(t_idx)
889
+ on_nodes = up_nodes | down_nodes | {s_idx, t_idx}
890
+
891
+ # choose links that stay within the upstream DAG or the downstream subtree
892
+ on_links_mask = [False] * len(rows_all)
893
+ for j, (sj, tj, _vj, _pr, _cr) in enumerate(rows_all):
894
+ if (
895
+ (sj in up_nodes and tj in up_nodes)
896
+ or (sj in down_nodes and tj in down_nodes)
897
+ or (j == link_i)
898
+ ):
899
+ on_links_mask[j] = True
900
+
901
+ # nodes: keep hue, change alpha
902
+ node_cols = [
903
+ _rgba_with_alpha(
904
+ colors_full[i],
905
+ highlight_alpha_on if i in on_nodes else highlight_alpha_off,
906
+ )
907
+ for i in range(len(colors_full))
908
+ ]
909
+ # links: keep hue, change alpha
910
+ link_cols_cat = [
911
+ _rgba_with_alpha(
912
+ link_colors_cat_base[j],
913
+ highlight_alpha_on if on_links_mask[j] else highlight_alpha_off,
914
+ )
915
+ for j in range(len(rows_all))
916
+ ]
917
+ link_cols_loc = [
918
+ _rgba_with_alpha(
919
+ link_colors_loc_base[j],
920
+ highlight_alpha_on if on_links_mask[j] else highlight_alpha_off,
921
+ )
922
+ for j in range(len(rows_all))
923
+ ]
924
+ return node_cols, link_cols_cat, link_cols_loc
925
+
926
+ # ---- Top/bottom domain for nodes (keep nodes away from menus) -------------
927
+ needs_top_bar = add_toggle or (enable_highlight and len(rows_all) > 0)
928
+ top_margin = 156 if needs_top_bar else (132 if add_toggle else 56)
929
+ bottom_margin = 8
930
+ top_dom, bot_dom = 0.04, 0.96 # y-domain used by the sankey traces
931
+ dom_span = bot_dom - top_dom
932
+
933
+ # ---- Layout numbers (top/bottom margins + domain) ----
934
+ needs_top_bar = add_toggle or (enable_highlight and len(rows_all) > 0)
935
+ top_margin = 156 if needs_top_bar else (132 if add_toggle else 56)
936
+ bottom_margin = 8
937
+ top_dom, bot_dom = 0.04, 0.96
938
+ dom_span = bot_dom - top_dom
939
+
940
+ # Soft height heuristic
941
+ n_nodes_total = sum(len(v) for v in nodes_by_col.values())
942
+ est_h_soft = int(
943
+ base_height
944
+ + per_level_px * (len(levels) - 1)
945
+ + per_node_px * math.sqrt(max(1, n_nodes_total))
946
+ )
947
+ est_h = min(height_max, max(height_min, est_h_soft))
948
+ pane_h = max(1.0, est_h - (top_margin + bottom_margin))
949
+ px_per_dom = pane_h * dom_span
950
+ pad_dom = node_pad / px_per_dom
951
+ # --- Node rectangle thickness (pixels) and its domain equivalent ---
952
+ th_eff = int(node_thickness) # pixels: what Plotly will actually draw
953
+ th_norm = th_eff / max(
954
+ 1e-9, px_per_dom
955
+ ) # domain units: min height each node occupies
956
+
957
+ # --- Use the same scaled link values Plotly will render ---
958
+ scaled_vals = [
959
+ base_vals[li] * out_scale[s_idx]
960
+ for li, (s_idx, t_idx, _v, _pr, _cr) in enumerate(rows_all)
961
+ ]
962
+ incoming = collections.defaultdict(float)
963
+ outgoing = collections.defaultdict(float)
964
+ for li, (s_idx, t_idx, _v, _pr, _cr) in enumerate(rows_all):
965
+ v = scaled_vals[li]
966
+ outgoing[s_idx] += v
967
+ incoming[t_idx] += v
968
+
969
+ # raw value-height per node (same units as link values)
970
+ h_raw = [0.0] * len(labels_vis)
971
+ for i in range(len(labels_vis)):
972
+ h_raw[i] = max(incoming.get(i, 0.0), outgoing.get(i, 0.0), 1e-12)
973
+
974
+ # per-column sums (raw)
975
+ col_sum_raw = {c: sum(h_raw[i] for i in idxs) for c, idxs in nodes_by_col.items()}
976
+
977
+ # global values→domain scale (limiting column sets the scale)
978
+ if nodes_by_col:
979
+ S_dom_candidates = []
980
+ for c, idxs in nodes_by_col.items():
981
+ total = col_sum_raw.get(c, 0.0)
982
+ if total > 0:
983
+ n = len(idxs)
984
+ S_dom_candidates.append((dom_span - max(0, n - 1) * pad_dom) / total)
985
+ S_dom = max(0.0, min(S_dom_candidates) if S_dom_candidates else 1.0)
986
+ else:
987
+ S_dom = 1.0
988
+
989
+ # node heights in domain units
990
+ h_dom = [h * S_dom for h in h_raw]
991
+ # Use at least the visual rectangle height when packing to avoid overlap
992
+ h_draw_dom = [max(h, th_norm) for h in h_dom]
993
+
994
+ def pack_column_tops(order, lo, hi):
995
+ if not order:
996
+ return []
997
+ avail = max(1e-9, hi - lo)
998
+ n = len(order)
999
+ total_h = sum(h_draw_dom[i] for i in order)
1000
+ total_with_pad = total_h + max(0, n - 1) * pad_dom
1001
+ pad_eff = (
1002
+ pad_dom
1003
+ if total_with_pad <= avail
1004
+ else max(0.0, (avail - total_h) / max(1, n - 1))
1005
+ )
1006
+ slack = avail - (total_h + max(0, n - 1) * pad_eff)
1007
+
1008
+ # give bigger nodes a bit more breathing room
1009
+ weights = [max(1e-12, h_dom[i]) for i in order]
1010
+ wsum = sum(weights)
1011
+ gaps = [slack * (w / wsum) for w in weights]
1012
+
1013
+ ytops, cur = [], lo
1014
+ for k, i in enumerate(order):
1015
+ cur += gaps[k]
1016
+ ytops.append(cur)
1017
+ cur += h_draw_dom[i] + pad_eff
1018
+ # clamp
1019
+ return [max(lo, min(hi - h_draw_dom[i], y)) for y, i in zip(ytops, order)]
1020
+
1021
+ def _tie_break_key(i: int) -> tuple:
1022
+ return (-magnitude[i], labels_vis[i], i)
1023
+
1024
+ # build y_top using actual heights; pin specials first in last col
1025
+ y_top = [0.5] * len(labels_vis)
1026
+ special_order_keys = [
1027
+ ("direct emissions", "__GLOBAL__"),
1028
+ ("activities below cutoff", "__GLOBAL__"),
1029
+ ("loss", "__GLOBAL__"),
1030
+ ]
1031
+ special_indices = [key_to_idx[k] for k in special_order_keys if k in key_to_idx]
1032
+
1033
+ for c, idxs in nodes_by_col.items():
1034
+ if not idxs:
1035
+ continue
1036
+ lo, hi = top_dom, bot_dom
1037
+ if c == last_col:
1038
+ ordered_rest = sorted(
1039
+ [i for i in idxs if i not in special_indices], key=_tie_break_key
1040
+ )
1041
+ ordered = special_indices + ordered_rest
1042
+ else:
1043
+ ordered = sorted(idxs, key=_tie_break_key)
1044
+ y_col = pack_column_tops(ordered, lo, hi)
1045
+ for i, y in zip(ordered, y_col):
1046
+ y_top[i] = y
1047
+
1048
+ # numerical guard
1049
+ EPS = 1e-6
1050
+ for c, idxs in nodes_by_col.items():
1051
+ col = sorted(idxs, key=lambda i: y_top[i])
1052
+ for k in range(1, len(col)):
1053
+ prev, cur = col[k - 1], col[k]
1054
+ min_top = y_top[prev] + h_draw_dom[prev] - EPS
1055
+ if y_top[cur] < min_top:
1056
+ y_top[cur] = min_top
1057
+ overflow = (y_top[col[-1]] + h_draw_dom[col[-1]] - bot_dom) if col else 0.0
1058
+ if overflow > 0:
1059
+ for i in col:
1060
+ y_top[i] = max(top_dom, y_top[i] - overflow)
1061
+
1062
+ # Traces (two sankeys)
1063
+ th_eff = int(node_thickness)
1064
+
1065
+ def make_trace(link_dict: Dict[str, list]) -> go.Sankey:
1066
+ node_dict = dict(
1067
+ pad=node_pad,
1068
+ thickness=th_eff,
1069
+ label=labels_vis,
1070
+ color=colors_full,
1071
+ customdata=node_hoverdata,
1072
+ hovertemplate="%{customdata}<extra></extra>",
1073
+ )
1074
+ arrangement = "fixed" if lock_x_by_level else "snap"
1075
+ if lock_x_by_level:
1076
+ node_dict["x"] = x_full
1077
+ node_dict["y"] = y_top # TOP coords in domain units
1078
+ return go.Sankey(
1079
+ arrangement=arrangement,
1080
+ domain=dict(
1081
+ x=[0, 1], y=[top_dom, bot_dom]
1082
+ ), # <--- keep nodes inside this band
1083
+ node=node_dict,
1084
+ link=dict(
1085
+ source=link_dict["source"],
1086
+ target=link_dict["target"],
1087
+ value=link_dict["value"],
1088
+ color=link_dict["color"],
1089
+ customdata=link_dict["customdata"],
1090
+ hovertemplate="%{customdata}<extra></extra>",
1091
+ ),
1092
+ )
1093
+
1094
+ fig = go.Figure(data=[make_trace(links_cat), make_trace(links_loc)])
1095
+ fig.data[0].visible = True
1096
+ fig.data[1].visible = False
1097
+
1098
+ # Legends
1099
+ legend_cat = [
1100
+ go.Scatter(
1101
+ x=[None],
1102
+ y=[None],
1103
+ mode="markers",
1104
+ marker=dict(size=10, color=color_direct),
1105
+ name="Direct emissions/Res. use",
1106
+ showlegend=True,
1107
+ hoverinfo="skip",
1108
+ ),
1109
+ go.Scatter(
1110
+ x=[None],
1111
+ y=[None],
1112
+ mode="markers",
1113
+ marker=dict(size=10, color=color_below),
1114
+ name="Activities below cutoff",
1115
+ showlegend=True,
1116
+ hoverinfo="skip",
1117
+ ),
1118
+ go.Scatter(
1119
+ x=[None],
1120
+ y=[None],
1121
+ mode="markers",
1122
+ marker=dict(size=10, color=color_loss),
1123
+ name="Loss",
1124
+ showlegend=True,
1125
+ hoverinfo="skip",
1126
+ ),
1127
+ go.Scatter(
1128
+ x=[None],
1129
+ y=[None],
1130
+ mode="markers",
1131
+ marker=dict(size=10, color=color_other),
1132
+ name="Other flows",
1133
+ showlegend=True,
1134
+ hoverinfo="skip",
1135
+ ),
1136
+ ]
1137
+ top_locs = unique_locs_sorted[:MAX_LOC_LEGEND]
1138
+ legend_loc = [
1139
+ go.Scatter(
1140
+ x=[None],
1141
+ y=[None],
1142
+ mode="markers",
1143
+ marker=dict(size=10, color=loc_to_color[loc]),
1144
+ name=f"{loc}",
1145
+ showlegend=True,
1146
+ hoverinfo="skip",
1147
+ )
1148
+ for loc in top_locs
1149
+ ]
1150
+ if len(unique_locs_sorted) > MAX_LOC_LEGEND:
1151
+ legend_loc.append(
1152
+ go.Scatter(
1153
+ x=[None],
1154
+ y=[None],
1155
+ mode="markers",
1156
+ marker=dict(size=10, color=color_other),
1157
+ name="Other locations",
1158
+ showlegend=True,
1159
+ hoverinfo="skip",
1160
+ )
1161
+ )
1162
+ for tr in legend_cat + legend_loc:
1163
+ fig.add_trace(tr)
1164
+
1165
+ cat_legend_count = len(legend_cat)
1166
+ loc_legend_count = len(legend_loc)
1167
+
1168
+ def vis_array(mode: str) -> List[bool]:
1169
+ base = [mode == "cat", mode == "loc"]
1170
+ cat_leg = [mode == "cat"] * cat_legend_count
1171
+ loc_leg = [mode == "loc"] * loc_legend_count
1172
+ return base + cat_leg + loc_leg
1173
+
1174
+ # Apply initial vis
1175
+ for i, v in enumerate(vis_array("cat")):
1176
+ fig.data[i].visible = v
1177
+
1178
+ # ---------------- Place the top controls without overlap ----------------
1179
+ # ---------- Build highlight dropdown buttons ----------
1180
+ highlight_buttons = []
1181
+ if enable_highlight and len(rows_all) > 0:
1182
+ # Reset option
1183
+ highlight_buttons.append(
1184
+ dict(
1185
+ label="Highlight: None",
1186
+ method="restyle",
1187
+ args=[
1188
+ {
1189
+ "node.color": [node_colors_base, node_colors_base],
1190
+ "link.color": [link_colors_cat_base, link_colors_loc_base],
1191
+ },
1192
+ [0, 1],
1193
+ ],
1194
+ )
1195
+ )
1196
+ # Top-K links
1197
+ for rank, li in enumerate(topK_idx, start=1):
1198
+ s_idx, t_idx, _v_signed, _prow, _crow = rows_all[li]
1199
+ parent_name = node_full_name.get(s_idx, "")
1200
+ child_name = node_full_name.get(t_idx, "")
1201
+ label_txt = f"#{rank} {child_name} ← {parent_name}"
1202
+
1203
+ node_cols, link_cols_cat, link_cols_loc = _make_highlight_state(li)
1204
+ highlight_buttons.append(
1205
+ dict(
1206
+ label=label_txt[:80],
1207
+ method="restyle",
1208
+ args=[
1209
+ {
1210
+ "node.color": [node_cols, node_cols],
1211
+ "link.color": [link_cols_cat, link_cols_loc],
1212
+ },
1213
+ [0, 1],
1214
+ ],
1215
+ )
1216
+ )
1217
+
1218
+ menus = []
1219
+
1220
+ # Left: color-mode buttons
1221
+ if add_toggle:
1222
+ menus.append(
1223
+ dict(
1224
+ type="buttons",
1225
+ direction="left",
1226
+ x=0.01,
1227
+ xanchor="left", # left edge
1228
+ y=1.28,
1229
+ yanchor="top", # above plot area
1230
+ pad=dict(l=6, r=6, t=2, b=2),
1231
+ buttons=[
1232
+ dict(
1233
+ label="Color: Category",
1234
+ method="update",
1235
+ args=[{"visible": vis_array("cat")}],
1236
+ ),
1237
+ dict(
1238
+ label="Color: Parent location",
1239
+ method="update",
1240
+ args=[{"visible": vis_array("loc")}],
1241
+ ),
1242
+ ],
1243
+ )
1244
+ )
1245
+
1246
+ # Right: highlight dropdown (only if enabled and we have links)
1247
+ if enable_highlight and len(rows_all) > 0:
1248
+ menus.append(
1249
+ dict(
1250
+ type="dropdown",
1251
+ direction="down",
1252
+ x=0.99,
1253
+ xanchor="right", # right edge
1254
+ y=1.28,
1255
+ yanchor="top",
1256
+ showactive=True,
1257
+ pad=dict(l=6, r=6, t=2, b=2),
1258
+ buttons=highlight_buttons,
1259
+ )
1260
+ )
1261
+
1262
+ # Right/center-left: flows toggle (show/hide links to special nodes)
1263
+ menus.append(
1264
+ dict(
1265
+ type="buttons",
1266
+ direction="left",
1267
+ x=0.32,
1268
+ xanchor="left",
1269
+ y=1.28,
1270
+ yanchor="top",
1271
+ pad=dict(l=6, r=6, t=2, b=2),
1272
+ buttons=[
1273
+ dict(
1274
+ label="Flows: Show specials",
1275
+ method="restyle",
1276
+ args=[
1277
+ {
1278
+ "link.color": [links_cat["color"], links_loc["color"]],
1279
+ # Optional: also restore hover text
1280
+ "link.customdata": [
1281
+ links_cat["customdata"],
1282
+ links_loc["customdata"],
1283
+ ],
1284
+ },
1285
+ [0, 1],
1286
+ ],
1287
+ ),
1288
+ dict(
1289
+ label="Flows: Hide specials",
1290
+ method="restyle",
1291
+ args=[
1292
+ {
1293
+ "link.color": [cat_cols_hide, loc_cols_hide],
1294
+ # Optional: blank hover on hidden links
1295
+ "link.customdata": [cat_hover_hide, loc_hover_hide],
1296
+ },
1297
+ [0, 1],
1298
+ ],
1299
+ ),
1300
+ ],
1301
+ )
1302
+ )
1303
+
1304
+ fig.update_layout(updatemenus=menus)
1305
+
1306
+ # ---------------- Layout/margins (extra top space for the controls) -----------
1307
+ needs_top_bar = add_toggle or (enable_highlight and len(rows_all) > 0)
1308
+ top_margin = 156 if needs_top_bar else (132 if add_toggle else 56)
1309
+
1310
+ # ---------- Width & autosize ----------
1311
+ if auto_width:
1312
+ est_w, autosize_flag = None, True
1313
+ else:
1314
+ raw_w = per_level_width * len(levels) + per_node_width * math.sqrt(
1315
+ max(1, n_nodes_total)
1316
+ )
1317
+ if width_max is not None:
1318
+ raw_w = min(width_max, raw_w)
1319
+ est_w, autosize_flag = max(width_min, int(raw_w)), False
1320
+
1321
+ fig.update_layout(
1322
+ height=est_h,
1323
+ width=est_w,
1324
+ autosize=autosize_flag,
1325
+ margin=dict(l=8, r=8, t=top_margin, b=8),
1326
+ paper_bgcolor="rgba(0,0,0,0)",
1327
+ plot_bgcolor="rgba(0,0,0,0)",
1328
+ xaxis=dict(visible=False),
1329
+ yaxis=dict(visible=False),
1330
+ legend=dict(
1331
+ orientation="h",
1332
+ yanchor="bottom",
1333
+ y=1.10, # slightly lower so it won't collide with menus
1334
+ xanchor="center",
1335
+ x=0.5,
1336
+ bgcolor="rgba(0,0,0,0)",
1337
+ ),
1338
+ )
1339
+
1340
+ compact = (est_w is not None) and (est_w < 1100)
1341
+ if compact and menus:
1342
+ for m in menus:
1343
+ m.update(x=0.5, xanchor="center")
1344
+ y_base, y_step = 1.34, 0.08
1345
+ for i, m in enumerate(menus):
1346
+ m.update(y=y_base - i * y_step)
1347
+ top_margin = max(top_margin, 200)
1348
+ fig.update_layout(margin=dict(l=8, r=8, t=top_margin, b=8))
1349
+
1350
+ return fig
1351
+
1352
+
1353
def wrap_label(text: str, max_chars: int, max_lines: int) -> str:
    """Wrap *text* onto at most ``max_lines`` lines of width ``max_chars``.

    Words and hyphenated tokens are never split across lines. When the
    wrapped text does not fit the line budget, it is cut off and the final
    line is marked with an ellipsis.
    """
    if not text:
        return ""

    stripped = str(text).strip()
    if not stripped:
        return ""

    wrapped = _tw.wrap(
        stripped, width=max_chars, break_long_words=False, break_on_hyphens=False
    )
    if len(wrapped) > max_lines:
        wrapped = wrapped[:max_lines]
        last = wrapped[-1]
        # Trim one column when the last line is already full, so the
        # ellipsis still fits within max_chars.
        if len(last) >= max_chars:
            wrapped[-1] = last[: max_chars - 1] + "…"
        else:
            wrapped[-1] = last + "…"
    return "\n".join(wrapped)
1371
+
1372
+
1373
def save_sankey_html(
    fig: go.Figure,
    path: str,
    *,
    title: str = "Supply chain Sankey",
    offline: bool = True,
    auto_open: bool = True,
    modebar_remove: tuple = ("lasso2d", "select2d"),
) -> str:
    """
    Save a Plotly Sankey figure as a standalone HTML file.

    Parameters
    ----------
    fig : go.Figure
        Figure returned by sankey_from_supply_df(...) or SupplyChain.plot_sankey(...).
    path : str
        Output file path. '.html' will be added if missing.
    title : str
        Kept for backward compatibility. NOTE(review): ``plotly.io.write_html``
        has no ``title`` argument, so this value is currently not embedded in
        the document — confirm whether a custom <title> is still wanted.
    offline : bool
        If True, embed plotly.js inside the HTML (bigger file, fully offline).
        If False, load plotly.js from CDN (smaller file).
    auto_open : bool
        If True, open the file in a browser after writing.
    modebar_remove : tuple
        Modebar buttons to remove.

    Returns
    -------
    str
        The (possibly extended) file path that was written.
    """
    if not path.lower().endswith(".html"):
        path += ".html"

    # Make sure the target directory exists before writing.
    out_dir = os.path.dirname(path) or "."
    Path(out_dir).mkdir(parents=True, exist_ok=True)

    include = True if offline else "cdn"
    config = {
        "displaylogo": False,
        "modeBarButtonsToRemove": list(modebar_remove),
        # you can add "toImageButtonOptions":{"scale":2} if you want bigger PNG exports
    }

    # Keep figure layout as-is; just write it out.
    # The original code wrapped this call in ``try/except TypeError`` with a
    # byte-identical call in the fallback branch (the 'title' kwarg was never
    # actually passed), so that dead fallback has been removed.
    pio.write_html(
        fig,
        file=path,
        include_plotlyjs=include,
        full_html=True,
        auto_open=auto_open,
        config=config,
    )
    return path
1441
+
1442
+
1443
def save_html_multi_methods_for_activity(
    activity: Activity,
    methods: Sequence[tuple],
    path: str,
    *,
    amount: float = 1.0,
    level: int = 3,
    cutoff: float = 0.01,
    cutoff_basis: str = "total",
    scenario: str | None = None,
    scenario_idx: int | str = 0,
    use_distributions: bool = False,
    iterations: int = 100,
    random_seed: int | None = None,
    collapse_markets: bool = False,
    plot_kwargs: Optional[Dict[str, Any]] = None,
    offline: bool = False,
    auto_open: bool = False,
    label_fn=lambda m: " / ".join(str(x) for x in m),
) -> str:
    """
    Compute one Sankey per impact method and save them into a single tabbed HTML.

    For each method, a SupplyChain is built with the same traversal settings,
    bootstrapped, and calculated; the resulting DataFrames are keyed by
    ``label_fn(method)`` and handed to ``save_sankey_html_multi``.

    Usage:
        save_html_multi_methods_for_activity(
            activity, methods, "outputs/multi_impact.html",
            level=3, cutoff=0.01, collapse_markets=True,
            plot_kwargs=dict(width_max=1800, height_max=800),
            offline=False, auto_open=True
        )
    """
    dfs_by_label: Dict[str, pd.DataFrame] = {}
    for method in methods:
        chain = SupplyChain(
            activity=activity,
            method=method,
            amount=amount,
            level=level,
            cutoff=cutoff,
            cutoff_basis=cutoff_basis,
            scenario=scenario,
            scenario_idx=scenario_idx,
            use_distributions=use_distributions,
            iterations=iterations,
            random_seed=random_seed,
            collapse_markets=collapse_markets,
        )
        chain.bootstrap()
        breakdown, _total, _ref_amount = chain.calculate()
        dfs_by_label[label_fn(method)] = breakdown

    return save_sankey_html_multi(
        dfs_by_label,
        path,
        plot_kwargs=plot_kwargs or {},
        offline=offline,
        auto_open=auto_open,
        title="Multi-impact Sankey",
    )
1502
+
1503
+
1504
@dataclass
class SupplyChainRow:
    """One row of the recursive supply-chain breakdown.

    Rows are produced during traversal and converted to DataFrame records
    via ``dataclasses.asdict``.
    """

    level: int  # depth in the traversal (0 = root activity)
    share_of_total: float  # node score divided by the root total score
    score: float  # impact score of this node for its traversed amount
    amount: float  # demanded amount of the activity (NaN when unknown)
    name: str | None  # activity name, or a synthetic label (e.g. "loss")
    location: str | None  # activity location, if any
    unit: str | None  # activity unit, if any
    activity_key: Tuple[str, str, str] | None  # identifying key of this activity (None for synthetic rows)
    parent_key: Tuple[str, str, str] | None  # identifying key of the parent activity (None for the root)
    collapsed_ref_products: str | None = None  # set when market suppliers were collapsed
    row_id: int | None = None  # unique per-traversal row identifier
    parent_row_id: int | None = None  # row_id of the parent row (None for the root)
1518
+
1519
+
1520
+ class SupplyChain:
1521
+
1522
    def __init__(
        self,
        activity: Activity,
        method: tuple,
        *,
        amount: float = 1.0,
        level: int = 3,
        cutoff: float = 0.01,
        cutoff_basis: str = "total",
        scenario: str | None = None,
        scenario_idx: int | str = 0,
        use_distributions: bool = False,
        iterations: int = 100,
        random_seed: int | None = None,
        collapse_markets: bool = False,
        debug: bool = False,
        dbg_max_prints: int = 2000,
        market_top_k: int = 60,
    ):
        """Set up a recursive supply-chain impact traversal.

        Parameters
        ----------
        activity : Activity
            Root Brightway2 activity to analyse.
        method : tuple
            Impact assessment method identifier.
        amount : float
            Reference amount of the root activity.
        level : int
            Maximum traversal depth.
        cutoff : float
            Relative score threshold below which branches are grouped.
        cutoff_basis : str
            'total' (cutoff vs. root score) or 'parent' (vs. parent score).
        scenario, scenario_idx
            Scenario selection forwarded to EdgeLCIA.
        use_distributions, iterations, random_seed
            Monte-Carlo settings forwarded to EdgeLCIA.
        collapse_markets : bool
            If True, market activities are flattened into final suppliers.
        debug, dbg_max_prints
            Accepted for backward compatibility; not stored by this
            constructor (currently unused here).
        market_top_k : int
            Maximum number of market suppliers kept when expanding markets.
        """
        if not isinstance(activity, Activity):
            raise TypeError("`activity` must be a Brightway2 Activity.")

        self.root = activity
        self.method = method
        # Sign of the reference amount is flipped for waste processes
        # (see _is_waste_process, defined elsewhere in this class).
        self.amount = float(amount) * (
            -1.0 if self._is_waste_process(activity) else 1.0
        )
        self.level = int(level)
        self.cutoff = float(cutoff)
        self.cutoff_basis = str(cutoff_basis).lower()
        if self.cutoff_basis not in {"total", "parent"}:
            raise ValueError("cutoff_basis must be 'total' or 'parent'")

        self.scenario = scenario
        self.scenario_idx = scenario_idx
        self.collapse_markets = bool(collapse_markets)

        # EdgeLCIA instance drives all LCI/LCIA computations for this chain.
        self.elcia = EdgeLCIA(
            demand={activity: self.amount},
            method=method,
            use_distributions=use_distributions,
            iterations=iterations,
            random_seed=random_seed,
            scenario=scenario,
        )

        # Lazily filled by bootstrap(); caches avoid re-scoring activities
        # and re-flattening markets during the recursive walk.
        self._total_score: Optional[float] = None
        self._unit_score_cache: Dict[Any, float] = {}
        self._market_flat_cache: Dict[Any, List[Tuple[Activity, float]]] = {}

        self.market_top_k = int(market_top_k)

        # Monotonic counter used to assign row_id / parent_row_id.
        self._row_counter = 0
1575
+
1576
+ def _next_row_id(self) -> int:
1577
+ rid = self._row_counter
1578
+ self._row_counter += 1
1579
+ return rid
1580
+
1581
+ @staticmethod
1582
+ def _short_act(act: Activity) -> str:
1583
+ try:
1584
+ nm = str(act.get("name"))
1585
+ except Exception:
1586
+ nm = "<?>"
1587
+ try:
1588
+ loc = act.get("location")
1589
+ except Exception:
1590
+ loc = None
1591
+ locs = f" [{loc}]" if loc else ""
1592
+ return nm + locs
1593
+
1594
+ @staticmethod
1595
+ def _is_market_name(val: Any) -> bool:
1596
+ """Return True if an activity name looks like an ecoinvent market."""
1597
+ if val is None or (isinstance(val, float) and pd.isna(val)):
1598
+ return False
1599
+ s = str(val).strip().lower()
1600
+ return s.startswith("market for ") or s.startswith("market group for ")
1601
+
1602
+ def _flatten_market_suppliers(
1603
+ self, market_act: Activity
1604
+ ) -> List[Tuple[Activity, float]]:
1605
+ """Flatten a MARKET into final suppliers with per-unit coefficients. Cached."""
1606
+ mk = self._act_cache_key(market_act)
1607
+ hit = self._market_flat_cache.get(mk)
1608
+ if hit is not None:
1609
+ return hit
1610
+
1611
+ t0 = time.perf_counter()
1612
+ out_pairs: List[Tuple[Activity, float]] = []
1613
+ nodes_visited = 0
1614
+ edges_traversed = 0
1615
+
1616
+ def _dfs(act: Activity, coef: float, path: set):
1617
+ nonlocal nodes_visited, edges_traversed
1618
+ nodes_visited += 1
1619
+ ak = self._act_cache_key(act)
1620
+ if ak in path:
1621
+ out_pairs.append((act, coef))
1622
+ return
1623
+ if not _is_market_name(act.get("name")):
1624
+ out_pairs.append((act, coef))
1625
+ return
1626
+ sups = list(act.technosphere())
1627
+ if not sups:
1628
+ out_pairs.append((act, coef))
1629
+ return
1630
+ path.add(ak)
1631
+ for ex in sups:
1632
+ sup = ex.input
1633
+ amt = float(ex["amount"])
1634
+ edges_traversed += 1
1635
+ _dfs(sup, coef * amt, path)
1636
+ path.remove(ak)
1637
+
1638
+ _dfs(market_act, 1.0, set())
1639
+
1640
+ # aggregate duplicates
1641
+ from collections import defaultdict
1642
+
1643
+ agg: Dict[Any, float] = defaultdict(float)
1644
+ key2act: Dict[Any, Activity] = {}
1645
+ for s, c in out_pairs:
1646
+ k = self._act_cache_key(s)
1647
+ agg[k] += c
1648
+ key2act[k] = s
1649
+
1650
+ flat = [(key2act[k], agg[k]) for k in agg]
1651
+ self._market_flat_cache[mk] = flat
1652
+
1653
+ return flat
1654
+
1655
+ def _score_per_unit(self, act: Activity) -> float:
1656
+ """Memoized unit score for an activity (uses current scenario/flags)."""
1657
+ k = self._act_cache_key(act)
1658
+ hit = self._unit_score_cache.get(k)
1659
+ if hit is not None:
1660
+ return hit
1661
+ t0 = time.perf_counter()
1662
+
1663
+ self.elcia.redo_lcia(
1664
+ demand={act.id if is_bw25 else act: 1.0},
1665
+ scenario_idx=self.scenario_idx,
1666
+ scenario=self.scenario,
1667
+ recompute_score=True,
1668
+ )
1669
+ s = float(self.elcia.score or 0.0)
1670
+ dt = time.perf_counter() - t0
1671
+ self._unit_score_cache[k] = s
1672
+
1673
+ return s
1674
+
1675
+ @staticmethod
1676
+ def _act_cache_key(act: Activity) -> Any:
1677
+ # Prefer unique, stable identifiers if available
1678
+ for attr in ("id", "key"):
1679
+ if hasattr(act, attr):
1680
+ return getattr(act, attr)
1681
+ # Fallback: use db/code if present, else your tuple key
1682
+ try:
1683
+ return (act["database"], act["code"])
1684
+ except Exception:
1685
+ return (act["name"], act.get("reference product"), act.get("location"))
1686
+
1687
    def bootstrap(self) -> float:
        """Run the initial EdgeLCIA pipeline on the root demand.

        Builds the characterization matrix, then computes and stores the
        total score in ``self._total_score``.

        Returns
        -------
        float
            The total impact score of the root demand.
        """
        # Standard pipeline on root demand: inventory, CF strategies,
        # CF evaluation for the selected scenario, then impact assessment.
        self.elcia.lci()
        self.elcia.apply_strategies()

        self.elcia.evaluate_cfs(scenario_idx=self.scenario_idx, scenario=self.scenario)
        self.elcia.lcia()
        self._total_score = float(self.elcia.score or 0.0)
        return self._total_score
1700
+
1701
+ def calculate(self) -> tuple[pd.DataFrame, float, float]:
1702
+ """
1703
+ Recursively traverse the technosphere, returning (df, total_score, reference_amount).
1704
+ Call `bootstrap()` first for best performance/coverage.
1705
+ """
1706
+ if self._total_score is None:
1707
+ self.bootstrap()
1708
+ rows = self._walk(self.root, self.amount, level=0, parent=None)
1709
+ df = pd.DataFrame([asdict(r) for r in rows])
1710
+ return df, float(self._total_score or 0.0), self.amount
1711
+
1712
+ def as_text(self, df: pd.DataFrame) -> StringIO:
1713
+ """Pretty text view of the breakdown."""
1714
+ buf = StringIO()
1715
+ if df.empty:
1716
+ buf.write("No contributions (total score is 0?)\n")
1717
+ return buf
1718
+ view = df[
1719
+ ["level", "share_of_total", "score", "amount", "name", "location", "unit"]
1720
+ ].copy()
1721
+ view["share_of_total"] = (view["share_of_total"] * 100).round(2)
1722
+ view["score"] = view["score"].astype(float).round(6)
1723
+ view["amount"] = view["amount"].astype(float)
1724
+ with pd.option_context("display.max_colwidth", 60):
1725
+ buf.write(view.to_string(index=False))
1726
+ return buf
1727
+
1728
+ # ---------- Internals ----------------------------------------------------
1729
+
1730
+ def _walk(
1731
+ self, act, amount, level, parent, _precomputed_score=None, _parent_row_id=None
1732
+ ):
1733
+ """Traverse one node with lazy market expansion (expand only above-cutoff, top-K)."""
1734
+ indent = " " * level
1735
+
1736
+ # --- Node score ---
1737
+ t0 = time.perf_counter()
1738
+ if level == 0:
1739
+ node_score = float(self._total_score or 0.0)
1740
+
1741
+ else:
1742
+ if _precomputed_score is None:
1743
+ self.elcia.redo_lcia(
1744
+ demand={(act.id if is_bw25 else act): amount},
1745
+ scenario_idx=self.scenario_idx,
1746
+ scenario=self.scenario,
1747
+ recompute_score=True,
1748
+ )
1749
+ node_score = float(self.elcia.score or 0.0)
1750
+ dt = time.perf_counter() - t0
1751
+
1752
+ else:
1753
+ node_score = float(_precomputed_score)
1754
+
1755
+ total = float(self._total_score or 0.0)
1756
+ share = (node_score / total) if total != 0 else 0.0
1757
+ cur_key = self._key(act)
1758
+
1759
+ # Cycle guard
1760
+ if parent is not None and cur_key == parent:
1761
+ return [
1762
+ SupplyChainRow(
1763
+ level=level,
1764
+ share_of_total=share,
1765
+ score=node_score,
1766
+ amount=float(amount),
1767
+ name="loss",
1768
+ location=None,
1769
+ unit=None,
1770
+ activity_key=None,
1771
+ parent_key=parent,
1772
+ )
1773
+ ]
1774
+
1775
+ rid = self._next_row_id()
1776
+ rows: List[SupplyChainRow] = [
1777
+ SupplyChainRow(
1778
+ level=level,
1779
+ share_of_total=share,
1780
+ score=node_score,
1781
+ amount=float(amount),
1782
+ name=act["name"],
1783
+ location=act.get("location"),
1784
+ unit=act.get("unit"),
1785
+ activity_key=cur_key,
1786
+ parent_key=parent,
1787
+ row_id=rid, # <---
1788
+ parent_row_id=_parent_row_id, # <---
1789
+ )
1790
+ ]
1791
+
1792
+ # Depth limit
1793
+ if level >= self.level:
1794
+ return rows
1795
+
1796
+ # Treat unknown-amount nodes as terminals
1797
+ if isinstance(amount, float) and math.isnan(amount):
1798
+ if node_score != 0.0:
1799
+ rows.append(
1800
+ SupplyChainRow(
1801
+ level=level + 1,
1802
+ share_of_total=(node_score / total) if total else 0.0,
1803
+ score=node_score,
1804
+ amount=float("nan"),
1805
+ name="Direct emissions/Res. use",
1806
+ location=None,
1807
+ unit=None,
1808
+ activity_key=None,
1809
+ parent_key=cur_key,
1810
+ row_id=self._next_row_id(),
1811
+ parent_row_id=rid,
1812
+ )
1813
+ )
1814
+ return rows
1815
+
1816
+ # ----------------------------------------------------------------------
1817
+ # 1) Collect children WITHOUT expanding markets; aggregate & score once.
1818
+ # ----------------------------------------------------------------------
1819
+ from collections import defaultdict
1820
+
1821
+ agg_amounts: Dict[Any, float] = defaultdict(float)
1822
+ key_to_act: Dict[Any, Activity] = {}
1823
+
1824
+ def _add_child(a: Activity, amt: float):
1825
+ k = self._act_cache_key(a)
1826
+ agg_amounts[k] += amt
1827
+ key_to_act[k] = a
1828
+
1829
+ exs = list(act.technosphere())
1830
+ for exc in exs:
1831
+ ch = exc.input
1832
+ ch_amt = amount * float(exc["amount"])
1833
+ _add_child(ch, ch_amt)
1834
+
1835
+ # Score each unique child ONCE with unit scores
1836
+ children: List[Tuple[Activity, float, float]] = []
1837
+ t_score0 = time.perf_counter()
1838
+ for k, amt in agg_amounts.items():
1839
+ a = key_to_act[k]
1840
+ unit = self._score_per_unit(a)
1841
+ children.append((a, amt, unit * amt))
1842
+ dt_score = time.perf_counter() - t_score0
1843
+
1844
+ if not children:
1845
+ # Leaf → all is direct emissions
1846
+ if node_score != 0.0:
1847
+ rows.append(
1848
+ SupplyChainRow(
1849
+ level=level + 1,
1850
+ share_of_total=(node_score / total) if total else 0.0,
1851
+ score=node_score,
1852
+ amount=float("nan"),
1853
+ name="Direct emissions/Res. use",
1854
+ location=None,
1855
+ unit=None,
1856
+ activity_key=None,
1857
+ parent_key=cur_key,
1858
+ row_id=self._next_row_id(),
1859
+ parent_row_id=rid,
1860
+ )
1861
+ )
1862
+ return rows
1863
+
1864
+ # --- Cutoff split (track BOTH above and below) -----------------------
1865
+ denom_parent = abs(node_score)
1866
+ denom_total = abs(total)
1867
+ denom_for_cutoff = (
1868
+ denom_parent
1869
+ if (self.cutoff_basis == "parent" and denom_parent > 0)
1870
+ else denom_total
1871
+ )
1872
+
1873
+ # Keep explicit lists; we’ll need `below` later to summarize ref products
1874
+ above: List[Tuple[Activity, float, float]] = []
1875
+ below: List[Tuple[Activity, float, float]] = []
1876
+
1877
+ for ch, ch_amt, ch_score in children:
1878
+ rel = (abs(ch_score) / denom_for_cutoff) if denom_for_cutoff > 0 else 0.0
1879
+ if rel >= self.cutoff:
1880
+ above.append((ch, ch_amt, ch_score))
1881
+ else:
1882
+ below.append((ch, ch_amt, ch_score))
1883
+
1884
+ # --- Lazy market expansion (only for above-cutoff markets) ----------
1885
+ if self.collapse_markets and above:
1886
+ above_final: List[Tuple[Activity, float, float]] = []
1887
+ below_extra: List[Tuple[Activity, float, float]] = []
1888
+ K = max(0, self.market_top_k)
1889
+
1890
+ for ch, ch_amt, ch_score in above:
1891
+ # Non-market stays as-is
1892
+ if not self._is_market_name(ch.get("name")):
1893
+ above_final.append((ch, ch_amt, ch_score))
1894
+ continue
1895
+
1896
+ # Expand market into suppliers
1897
+ t_flat = time.perf_counter()
1898
+ flat = self._flatten_market_suppliers(ch) # [(sup_act, coef_per_unit)]
1899
+ dt_flat = time.perf_counter() - t_flat
1900
+
1901
+ # Rank candidates; compute scores only for top-K
1902
+ flat_sorted = sorted(flat, key=lambda t: abs(t[1]), reverse=True)
1903
+
1904
+ promoted_scores = 0.0
1905
+ promoted_cnt = 0
1906
+ tested_cnt = 0
1907
+
1908
+ for sup, coef in flat_sorted[:K]:
1909
+ sup_amt = ch_amt * coef
1910
+ unit = self._score_per_unit(sup)
1911
+ sup_score = unit * sup_amt
1912
+ tested_cnt += 1
1913
+ rel = (
1914
+ (abs(sup_score) / denom_for_cutoff)
1915
+ if denom_for_cutoff > 0
1916
+ else 0.0
1917
+ )
1918
+ if rel >= self.cutoff:
1919
+ above_final.append((sup, sup_amt, sup_score))
1920
+ promoted_scores += sup_score
1921
+ promoted_cnt += 1
1922
+ else:
1923
+ below_extra.append((sup, sup_amt, sup_score))
1924
+
1925
+ residual = ch_score - promoted_scores
1926
+
1927
+ if promoted_cnt == 0:
1928
+ # No supplier cleared the global cutoff → keep the market itself visible
1929
+ # (don’t demote the whole thing into "below cutoff")
1930
+ above_final.append((ch, ch_amt, ch_score))
1931
+ else:
1932
+ # We promoted some suppliers. Decide what to do with the residual:
1933
+ # if the residual itself is big enough, keep it visible; else send to 'below'.
1934
+ rel_resid = (
1935
+ (abs(residual) / denom_for_cutoff)
1936
+ if denom_for_cutoff > 0
1937
+ else 0.0
1938
+ )
1939
+ if rel_resid >= self.cutoff:
1940
+ # Use 0.0 (not NaN) so recursion yields direct = node_score and shows a direct-emissions link
1941
+ above_final.append((ch, 0.0, residual))
1942
+ elif abs(residual) > 0:
1943
+ below_extra.append(
1944
+ (ch, 0.0, residual)
1945
+ ) # harmless either way (we don't recurse into "below")
1946
+
1947
+ # Replace above with expanded set; extend below with what fell short
1948
+ above = above_final
1949
+ below.extend(below_extra)
1950
+
1951
+ # --- Balance & specials ---------------------------------------------
1952
+ sum_above = sum(cs for _, _, cs in above)
1953
+ sum_below = sum(cs for _, _, cs in below)
1954
+ direct = node_score - (sum_above + sum_below)
1955
+
1956
+ if abs(direct) > 0.0:
1957
+ rows.append(
1958
+ SupplyChainRow(
1959
+ level=level + 1,
1960
+ share_of_total=(direct / total) if total else 0.0,
1961
+ score=direct,
1962
+ amount=float("nan"),
1963
+ name="Direct emissions/Res. use",
1964
+ location=None,
1965
+ unit=None,
1966
+ activity_key=None,
1967
+ parent_key=cur_key,
1968
+ row_id=self._next_row_id(),
1969
+ parent_row_id=rid,
1970
+ )
1971
+ )
1972
+
1973
+ if abs(sum_below) > 0.0:
1974
+ # Build a compact summary of ref products among below-cutoff children
1975
+ from collections import defaultdict
1976
+
1977
+ agg_rp = defaultdict(float)
1978
+ for ch, _amt, cs in below:
1979
+ rp = ch.get("reference product") or ""
1980
+ agg_rp[rp] += abs(cs)
1981
+
1982
+ TOPN = 6
1983
+ total_abs = sum(agg_rp.values()) or 0.0
1984
+ items = sorted(agg_rp.items(), key=lambda kv: kv[1], reverse=True)
1985
+ if total_abs > 0:
1986
+ parts = [
1987
+ f"{(k or '—')} ({v/total_abs*100:.1f}%)" for k, v in items[:TOPN]
1988
+ ]
1989
+ else:
1990
+ parts = [(k or "—") for k, _ in items[:TOPN]]
1991
+ more = max(0, len(items) - TOPN)
1992
+ if more:
1993
+ parts.append(f"+{more} more")
1994
+ rp_summary = ", ".join(parts)
1995
+
1996
+ rows.append(
1997
+ SupplyChainRow(
1998
+ level=level + 1,
1999
+ share_of_total=(sum_below / total) if total else 0.0,
2000
+ score=sum_below,
2001
+ amount=float("nan"),
2002
+ name="activities below cutoff",
2003
+ location=None,
2004
+ unit=None,
2005
+ activity_key=None,
2006
+ parent_key=cur_key,
2007
+ collapsed_ref_products=rp_summary,
2008
+ row_id=self._next_row_id(),
2009
+ parent_row_id=rid,
2010
+ )
2011
+ )
2012
+
2013
+ # --- Recurse into the final above-cutoff set ------------------------
2014
+ max_list = 6
2015
+ for idx, (ch, ch_amt, ch_score) in enumerate(above):
2016
+ rows.extend(
2017
+ self._walk(
2018
+ ch,
2019
+ ch_amt,
2020
+ level=level + 1,
2021
+ parent=cur_key,
2022
+ _precomputed_score=ch_score,
2023
+ _parent_row_id=rid,
2024
+ )
2025
+ )
2026
+
2027
+ return rows
2028
+
2029
+ # ---------- Small helpers ------------------------------------------------
2030
+
2031
+ @staticmethod
2032
+ def _is_waste_process(activity: Activity) -> bool:
2033
+ for exc in activity.production():
2034
+ if exc["amount"] < 0:
2035
+ return True
2036
+ return False
2037
+
2038
+ @staticmethod
2039
+ def _key(a: Activity) -> Tuple[str, str, str]:
2040
+ return a["name"], a.get("reference product"), a.get("location")
2041
+
2042
+ def plot_sankey(self, df: pd.DataFrame, **kwargs):
2043
+ """Convenience method: EdgeSupplyChainScorer.plot_sankey(df, ...)."""
2044
+ return sankey_from_supply_df(df, **kwargs)
2045
+
2046
+ def save_html(self, df: pd.DataFrame, path: str, **plot_kwargs) -> str:
2047
+ """
2048
+ Build the Sankey from `df` with plot kwargs, then save to HTML.
2049
+ Returns the final file path.
2050
+ """
2051
+ fig = self.plot_sankey(df, **plot_kwargs)
2052
+ return save_sankey_html(fig, path)