desdeo-1.2-py3-none-any.whl → desdeo-2.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- desdeo/__init__.py +8 -8
- desdeo/adm/ADMAfsar.py +551 -0
- desdeo/adm/ADMChen.py +414 -0
- desdeo/adm/BaseADM.py +119 -0
- desdeo/adm/__init__.py +11 -0
- desdeo/api/README.md +73 -0
- desdeo/api/__init__.py +15 -0
- desdeo/api/app.py +50 -0
- desdeo/api/config.py +90 -0
- desdeo/api/config.toml +64 -0
- desdeo/api/db.py +27 -0
- desdeo/api/db_init.py +85 -0
- desdeo/api/db_models.py +164 -0
- desdeo/api/malaga_db_init.py +27 -0
- desdeo/api/models/__init__.py +266 -0
- desdeo/api/models/archive.py +23 -0
- desdeo/api/models/emo.py +128 -0
- desdeo/api/models/enautilus.py +69 -0
- desdeo/api/models/gdm/gdm_aggregate.py +139 -0
- desdeo/api/models/gdm/gdm_base.py +69 -0
- desdeo/api/models/gdm/gdm_score_bands.py +114 -0
- desdeo/api/models/gdm/gnimbus.py +138 -0
- desdeo/api/models/generic.py +104 -0
- desdeo/api/models/generic_states.py +401 -0
- desdeo/api/models/nimbus.py +158 -0
- desdeo/api/models/preference.py +128 -0
- desdeo/api/models/problem.py +717 -0
- desdeo/api/models/reference_point_method.py +18 -0
- desdeo/api/models/session.py +49 -0
- desdeo/api/models/state.py +463 -0
- desdeo/api/models/user.py +52 -0
- desdeo/api/models/utopia.py +25 -0
- desdeo/api/routers/_EMO.backup +309 -0
- desdeo/api/routers/_NAUTILUS.py +245 -0
- desdeo/api/routers/_NAUTILUS_navigator.py +233 -0
- desdeo/api/routers/_NIMBUS.py +765 -0
- desdeo/api/routers/__init__.py +5 -0
- desdeo/api/routers/emo.py +497 -0
- desdeo/api/routers/enautilus.py +237 -0
- desdeo/api/routers/gdm/gdm_aggregate.py +234 -0
- desdeo/api/routers/gdm/gdm_base.py +420 -0
- desdeo/api/routers/gdm/gdm_score_bands/gdm_score_bands_manager.py +398 -0
- desdeo/api/routers/gdm/gdm_score_bands/gdm_score_bands_routers.py +377 -0
- desdeo/api/routers/gdm/gnimbus/gnimbus_manager.py +698 -0
- desdeo/api/routers/gdm/gnimbus/gnimbus_routers.py +591 -0
- desdeo/api/routers/generic.py +233 -0
- desdeo/api/routers/nimbus.py +705 -0
- desdeo/api/routers/problem.py +307 -0
- desdeo/api/routers/reference_point_method.py +93 -0
- desdeo/api/routers/session.py +100 -0
- desdeo/api/routers/test.py +16 -0
- desdeo/api/routers/user_authentication.py +520 -0
- desdeo/api/routers/utils.py +187 -0
- desdeo/api/routers/utopia.py +230 -0
- desdeo/api/schema.py +100 -0
- desdeo/api/tests/__init__.py +0 -0
- desdeo/api/tests/conftest.py +151 -0
- desdeo/api/tests/test_enautilus.py +330 -0
- desdeo/api/tests/test_models.py +1179 -0
- desdeo/api/tests/test_routes.py +1075 -0
- desdeo/api/utils/_database.py +263 -0
- desdeo/api/utils/_logger.py +29 -0
- desdeo/api/utils/database.py +36 -0
- desdeo/api/utils/emo_database.py +40 -0
- desdeo/core.py +34 -0
- desdeo/emo/__init__.py +159 -0
- desdeo/emo/hooks/archivers.py +188 -0
- desdeo/emo/methods/EAs.py +541 -0
- desdeo/emo/methods/__init__.py +0 -0
- desdeo/emo/methods/bases.py +12 -0
- desdeo/emo/methods/templates.py +111 -0
- desdeo/emo/operators/__init__.py +1 -0
- desdeo/emo/operators/crossover.py +1282 -0
- desdeo/emo/operators/evaluator.py +114 -0
- desdeo/emo/operators/generator.py +459 -0
- desdeo/emo/operators/mutation.py +1224 -0
- desdeo/emo/operators/scalar_selection.py +202 -0
- desdeo/emo/operators/selection.py +1778 -0
- desdeo/emo/operators/termination.py +286 -0
- desdeo/emo/options/__init__.py +108 -0
- desdeo/emo/options/algorithms.py +435 -0
- desdeo/emo/options/crossover.py +164 -0
- desdeo/emo/options/generator.py +131 -0
- desdeo/emo/options/mutation.py +260 -0
- desdeo/emo/options/repair.py +61 -0
- desdeo/emo/options/scalar_selection.py +66 -0
- desdeo/emo/options/selection.py +127 -0
- desdeo/emo/options/templates.py +383 -0
- desdeo/emo/options/termination.py +143 -0
- desdeo/explanations/__init__.py +6 -0
- desdeo/explanations/explainer.py +100 -0
- desdeo/explanations/utils.py +90 -0
- desdeo/gdm/__init__.py +22 -0
- desdeo/gdm/gdmtools.py +45 -0
- desdeo/gdm/score_bands.py +114 -0
- desdeo/gdm/voting_rules.py +50 -0
- desdeo/mcdm/__init__.py +41 -0
- desdeo/mcdm/enautilus.py +338 -0
- desdeo/mcdm/gnimbus.py +484 -0
- desdeo/mcdm/nautili.py +345 -0
- desdeo/mcdm/nautilus.py +477 -0
- desdeo/mcdm/nautilus_navigator.py +656 -0
- desdeo/mcdm/nimbus.py +417 -0
- desdeo/mcdm/pareto_navigator.py +269 -0
- desdeo/mcdm/reference_point_method.py +186 -0
- desdeo/problem/__init__.py +83 -0
- desdeo/problem/evaluator.py +561 -0
- desdeo/problem/external/__init__.py +18 -0
- desdeo/problem/external/core.py +356 -0
- desdeo/problem/external/pymoo_provider.py +266 -0
- desdeo/problem/external/runtime.py +44 -0
- desdeo/problem/gurobipy_evaluator.py +562 -0
- desdeo/problem/infix_parser.py +341 -0
- desdeo/problem/json_parser.py +944 -0
- desdeo/problem/pyomo_evaluator.py +487 -0
- desdeo/problem/schema.py +1829 -0
- desdeo/problem/simulator_evaluator.py +348 -0
- desdeo/problem/sympy_evaluator.py +244 -0
- desdeo/problem/testproblems/__init__.py +88 -0
- desdeo/problem/testproblems/benchmarks_server.py +120 -0
- desdeo/problem/testproblems/binh_and_korn_problem.py +88 -0
- desdeo/problem/testproblems/cake_problem.py +185 -0
- desdeo/problem/testproblems/dmitry_forest_problem_discrete.py +71 -0
- desdeo/problem/testproblems/dtlz2_problem.py +102 -0
- desdeo/problem/testproblems/forest_problem.py +283 -0
- desdeo/problem/testproblems/knapsack_problem.py +163 -0
- desdeo/problem/testproblems/mcwb_problem.py +831 -0
- desdeo/problem/testproblems/mixed_variable_dimenrions_problem.py +83 -0
- desdeo/problem/testproblems/momip_problem.py +172 -0
- desdeo/problem/testproblems/multi_valued_constraints.py +119 -0
- desdeo/problem/testproblems/nimbus_problem.py +143 -0
- desdeo/problem/testproblems/pareto_navigator_problem.py +89 -0
- desdeo/problem/testproblems/re_problem.py +492 -0
- desdeo/problem/testproblems/river_pollution_problems.py +440 -0
- desdeo/problem/testproblems/rocket_injector_design_problem.py +140 -0
- desdeo/problem/testproblems/simple_problem.py +351 -0
- desdeo/problem/testproblems/simulator_problem.py +92 -0
- desdeo/problem/testproblems/single_objective.py +289 -0
- desdeo/problem/testproblems/spanish_sustainability_problem.py +945 -0
- desdeo/problem/testproblems/zdt_problem.py +274 -0
- desdeo/problem/utils.py +245 -0
- desdeo/tools/GenerateReferencePoints.py +181 -0
- desdeo/tools/__init__.py +120 -0
- desdeo/tools/desc_gen.py +22 -0
- desdeo/tools/generics.py +165 -0
- desdeo/tools/group_scalarization.py +3090 -0
- desdeo/tools/gurobipy_solver_interfaces.py +258 -0
- desdeo/tools/indicators_binary.py +117 -0
- desdeo/tools/indicators_unary.py +362 -0
- desdeo/tools/interaction_schema.py +38 -0
- desdeo/tools/intersection.py +54 -0
- desdeo/tools/iterative_pareto_representer.py +99 -0
- desdeo/tools/message.py +265 -0
- desdeo/tools/ng_solver_interfaces.py +199 -0
- desdeo/tools/non_dominated_sorting.py +134 -0
- desdeo/tools/patterns.py +283 -0
- desdeo/tools/proximal_solver.py +99 -0
- desdeo/tools/pyomo_solver_interfaces.py +477 -0
- desdeo/tools/reference_vectors.py +229 -0
- desdeo/tools/scalarization.py +2065 -0
- desdeo/tools/scipy_solver_interfaces.py +454 -0
- desdeo/tools/score_bands.py +627 -0
- desdeo/tools/utils.py +388 -0
- desdeo/tools/visualizations.py +67 -0
- desdeo/utopia_stuff/__init__.py +0 -0
- desdeo/utopia_stuff/data/1.json +15 -0
- desdeo/utopia_stuff/data/2.json +13 -0
- desdeo/utopia_stuff/data/3.json +15 -0
- desdeo/utopia_stuff/data/4.json +17 -0
- desdeo/utopia_stuff/data/5.json +15 -0
- desdeo/utopia_stuff/from_json.py +40 -0
- desdeo/utopia_stuff/reinit_user.py +38 -0
- desdeo/utopia_stuff/utopia_db_init.py +212 -0
- desdeo/utopia_stuff/utopia_problem.py +403 -0
- desdeo/utopia_stuff/utopia_problem_old.py +415 -0
- desdeo/utopia_stuff/utopia_reference_solutions.py +79 -0
- desdeo-2.1.0.dist-info/METADATA +186 -0
- desdeo-2.1.0.dist-info/RECORD +180 -0
- {desdeo-1.2.dist-info → desdeo-2.1.0.dist-info}/WHEEL +1 -1
- desdeo-2.1.0.dist-info/licenses/LICENSE +21 -0
- desdeo-1.2.dist-info/METADATA +0 -16
- desdeo-1.2.dist-info/RECORD +0 -4
desdeo/tools/score_bands.py
@@ -0,0 +1,627 @@
"""Generate SCORE bands visualizations.

This module contains the functions which generate SCORE bands visualizations. It also contains functions to calculate
the order and positions of the objective axes, as well as a heatmap of the correlation matrix.

To create a SCORE bands visualization, use the `score_json` function to generate the data for the visualization, and
then use the `plot_score` function to generate the figure. You can also pass the result of `score_json` to other
frontends for visualization.
"""

from copy import deepcopy
from enum import Enum
from typing import Literal
from warnings import warn

import numpy as np
import plotly.figure_factory as ff
import plotly.graph_objects as go
import polars as pl
from matplotlib import cm
from pydantic import BaseModel, ConfigDict, Field
from scipy.stats import pearsonr
from sklearn.cluster import DBSCAN
from sklearn.metrics import silhouette_score
from sklearn.mixture import GaussianMixture
from sklearn.preprocessing import StandardScaler
from tsp_solver.greedy import solve_tsp


class GMMOptions(BaseModel):
    """Options for Gaussian Mixture Model clustering algorithm."""

    model_config = ConfigDict(use_attribute_docstrings=True)

    name: str = Field(default="GMM")
    """Gaussian Mixture Model clustering algorithm."""
    scoring_method: Literal["BIC", "silhouette"] = Field(default="silhouette")
    """Scoring method to use for GMM. Either "BIC" or "silhouette". Defaults to "silhouette".
    This option determines how the number of clusters is chosen."""


class DBSCANOptions(BaseModel):
    """Options for DBSCAN clustering algorithm."""

    model_config = ConfigDict(use_attribute_docstrings=True)

    name: str = Field(default="DBSCAN")
    """DBSCAN clustering algorithm."""


class KMeansOptions(BaseModel):
    """Options for KMeans clustering algorithm."""

    model_config = ConfigDict(use_attribute_docstrings=True)

    name: str = Field(default="KMeans")
    """KMeans clustering algorithm."""
    n_clusters: int = Field(default=5)
    """Number of clusters to use. Defaults to 5."""


class DimensionClusterOptions(BaseModel):
    """Options for clustering by one of the objectives/decision variables."""

    model_config = ConfigDict(use_attribute_docstrings=True)

    name: str = Field(default="DimensionCluster")
    """Clustering by one of the dimensions."""
    dimension_name: str
    """Dimension to use for clustering."""
    n_clusters: int = Field(default=5)
    """Number of clusters to use. Defaults to 5."""
    kind: Literal["EqualWidth", "EqualFrequency"] = Field(default="EqualWidth")
    """Kind of clustering to use. Either "EqualWidth", which divides the dimension range into equal width intervals,
    or "EqualFrequency", which divides the dimension values into intervals with equal number of solutions.
    Defaults to "EqualWidth"."""


class CustomClusterOptions(BaseModel):
    """Options for custom clustering provided by the user."""

    model_config = ConfigDict(use_attribute_docstrings=True)

    name: str = Field(default="Custom")
    """Custom user-provided clusters."""
    clusters: list[int]
    """List of cluster IDs (one for each solution) indicating the cluster to which each solution belongs."""


ClusteringOptions = GMMOptions | DBSCANOptions | KMeansOptions | DimensionClusterOptions | CustomClusterOptions


class DistanceFormula(int, Enum):
    """Distance formulas supported by SCORE bands. See the paper for details."""

    FORMULA_1 = 1
    FORMULA_2 = 2


class SCOREBandsConfig(BaseModel):
    """Configuration options for SCORE bands visualization."""

    model_config = ConfigDict(use_attribute_docstrings=True)

    dimensions: list[str] | None = Field(default=None)
    """List of variable/objective names (i.e., column names in the data) to include in the visualization.
    If None, all columns in the data are used. Defaults to None."""
    descriptive_names: dict[str, str] | None = Field(default=None)
    """Optional dictionary mapping dimensions to descriptive names for display in the visualization.
    If None, the original dimension names are used. Defaults to None."""
    units: dict[str, str] | None = Field(default=None)
    """Optional dictionary mapping dimensions to their units for display in the visualization.
    If None, no units are displayed. Defaults to None."""
    axis_positions: dict[str, float] | None = Field(default=None)
    """Dictionary mapping objective names to their positions on the axes in the SCORE bands visualization. The first
    objective is at position 0.0, and the last objective is at position 1.0. Use this option if you want to
    manually set the axis positions. If None, the axis positions are calculated automatically based on correlations.
    Defaults to None."""
    clustering_algorithm: ClusteringOptions = Field(
        default=DBSCANOptions(),
    )
    """Clustering algorithm to use. Supported options: "GMM", "DBSCAN", "KMeans", "DimensionCluster", and "Custom".
    Defaults to "DBSCAN"."""
    distance_formula: DistanceFormula = Field(default=DistanceFormula.FORMULA_1)
    """Distance formula to use. The value should be 1 or 2. Check the paper for details. Defaults to 1."""
    distance_parameter: float = Field(default=0.05)
    """Change the relative distances between the objective axes. Increase this value if objectives are placed too close
    together. Decrease this value if the objectives are equidistant in a problem with objective clusters. Defaults
    to 0.05."""
    use_absolute_correlations: bool = Field(default=False)
    """Whether to use absolute value of the correlation to calculate the placement of axes. Defaults to False."""
    include_solutions: bool = Field(default=False)
    """Whether to include individual solutions. Defaults to False. If True, the size of the resulting figure may be
    very large for datasets with many solutions. Moreover, the individual traces are hidden by default, but can be
    viewed interactively in the figure."""
    include_medians: bool = Field(default=False)
    """Whether to include cluster medians. Defaults to False. If True, the median traces are hidden by default, but
    can be viewed interactively in the figure."""
    interval_size: float = Field(default=0.95)
    """The size (as a fraction) of the interval to use for the bands. Defaults to 0.95, meaning that 95% of the
    middle solutions in a cluster will be included in the band. The rest will be considered outliers."""
    scales: dict[str, tuple[float, float]] | None = Field(default=None)
    """Optional dictionary specifying the min and max values for each objective. The keys should be the
    objective names (i.e., column names in the data), and the values should be tuples of (min, max).
    If not provided, the min and max will be calculated from the data."""


class SCOREBandsResult(BaseModel):
    """Pydantic/JSON model for representing SCORE Bands."""

    model_config = ConfigDict(use_attribute_docstrings=True)

    options: SCOREBandsConfig
    """Configuration options used to generate the SCORE bands."""
    ordered_dimensions: list[str]
    """List of variable/objective names (i.e., column names in the data),
    ordered according to their placement in the SCORE bands visualization."""
    clusters: list[int]
    """List of cluster IDs (one for each solution) indicating the cluster to which each solution belongs."""
    axis_positions: dict[str, float]
    """Dictionary mapping objective names to their positions on the axes in the SCORE bands visualization. The first
    objective is at position 0.0, and the last objective is at position 1.0."""
    bands: dict[int, dict[str, tuple[float, float]]]
    """Dictionary mapping cluster IDs to dictionaries of objective names and their corresponding band
    extremes (min, max)."""
    medians: dict[int, dict[str, float]]
    """Dictionary mapping cluster IDs to dictionaries of objective names and their corresponding median values."""
    cardinalities: dict[int, int]
    """Dictionary mapping cluster IDs to the number of solutions in each cluster."""


def _gaussianmixtureclusteringwithBIC(data: pl.DataFrame) -> np.ndarray:
    """Cluster the data using Gaussian Mixture Model with BIC scoring."""
    data_copy = data.to_numpy()
    data_copy = StandardScaler().fit_transform(data_copy)
    lowest_bic = np.inf
    bic = []
    n_components_range = range(1, min(11, len(data_copy)))
    cv_types: list[Literal["full", "tied", "diag", "spherical"]] = ["spherical", "tied", "diag", "full"]
    for cv_type in cv_types:
        for n_components in n_components_range:
            # Fit a Gaussian mixture with EM
            gmm = GaussianMixture(n_components=n_components, covariance_type=cv_type)
            gmm.fit(data_copy)
            bic.append(gmm.bic(data_copy))  # BIC: lower is better
            if bic[-1] < lowest_bic:
                lowest_bic = bic[-1]
                best_gmm = gmm

    return best_gmm.predict(data_copy)


def _gaussianmixtureclusteringwithsilhouette(data: pl.DataFrame) -> np.ndarray:
    """Cluster the data using Gaussian Mixture Model with silhouette scoring."""
    X = StandardScaler().fit_transform(data.to_numpy())
    best_score = -np.inf
    best_labels = np.ones(len(data))
    n_components_range = range(1, min(11, len(data)))
    cv_types: list[Literal["full", "tied", "diag", "spherical"]] = ["spherical", "tied", "diag", "full"]
    for cv_type in cv_types:
        for n_components in n_components_range:
            # Fit a Gaussian mixture with EM
            gmm = GaussianMixture(n_components=n_components, covariance_type=cv_type)
            labels = gmm.fit_predict(X)
            try:
                score = silhouette_score(X, labels, metric="cosine")
            except ValueError:
                score = -np.inf
            if score > best_score:
                best_score = score
                best_labels = labels
    return best_labels


def _DBSCANClustering(data: pl.DataFrame) -> np.ndarray:
    """Cluster the data using DBSCAN with silhouette scoring to choose eps."""
    X = StandardScaler().fit_transform(data.to_numpy())
    eps_options = np.linspace(0.01, 1, 20)
    best_score = -np.inf
    best_labels = np.ones(len(data))
    for eps_option in eps_options:
        db = DBSCAN(eps=eps_option, min_samples=10, metric="cosine").fit(X)
        core_samples_mask = np.zeros_like(db.labels_, dtype=bool)
        core_samples_mask[db.core_sample_indices_] = True
        labels = db.labels_
        try:
            score = silhouette_score(X, labels, metric="cosine")
        except ValueError:
            score = -np.inf
        if score > best_score:
            best_score = score
            best_labels = labels
    return best_labels


def cluster_by_dimension(data: pl.DataFrame, options: DimensionClusterOptions) -> np.ndarray:
    """Cluster the data by a specific dimension."""
    if options.dimension_name not in data.columns:
        raise ValueError(f"Objective '{options.dimension_name}' not found in data.")

    # Select the dimension column for clustering
    dimension = data[options.dimension_name]

    # Perform clustering based on the specified method
    if options.kind == "EqualWidth":
        min_val: float = dimension.min()
        max_val: float = dimension.max()
        SMALL_VALUE = 1e-8
        thresholds = np.linspace(
            min_val * (1 - SMALL_VALUE),  # Ensure the minimum value is included in the first cluster
            max_val * (1 + SMALL_VALUE),  # Ensure the maximum value is included in the last cluster
            options.n_clusters + 1,
        )
        return np.digitize(dimension.to_numpy(), thresholds)  # Cluster IDs start at 1
    elif options.kind == "EqualFrequency":
        levels: list[float] = [dimension.quantile(i / options.n_clusters) for i in range(1, options.n_clusters)]
        thresholds = [-np.inf] + levels + [np.inf]
        return np.digitize(dimension.to_numpy(), thresholds)  # Cluster IDs start at 1
    raise ValueError(f"Unknown clustering kind: {options.kind}")


def cluster(data: pl.DataFrame, options: ClusteringOptions) -> np.ndarray:
    """Cluster the data using the specified clustering algorithm and options."""
    if isinstance(options, DimensionClusterOptions):
        return cluster_by_dimension(data, options)
    if isinstance(options, KMeansOptions):
        from sklearn.cluster import KMeans

        X = StandardScaler().fit_transform(data.to_numpy())
        kmeans = KMeans(n_clusters=options.n_clusters, random_state=0).fit(X)
        return kmeans.labels_
    if isinstance(options, DBSCANOptions):
        return _DBSCANClustering(data)
    if isinstance(options, GMMOptions):
        if options.scoring_method == "silhouette":
            return _gaussianmixtureclusteringwithsilhouette(data)
        if options.scoring_method == "BIC":
            return _gaussianmixtureclusteringwithBIC(data)
    if isinstance(options, CustomClusterOptions):
        if len(options.clusters) != len(data):
            raise ValueError("Length of custom clusters must match number of solutions in data.")
        return np.array(options.clusters)
    raise ValueError(f"Unknown clustering algorithm: {options}")


def annotated_heatmap(correlation_matrix: np.ndarray, col_names: list, order: list | np.ndarray) -> go.Figure:
    """Create an annotated heatmap of the correlation matrix.

    Args:
        correlation_matrix (np.ndarray): 2-D square array of correlation values between pairs of objectives.
        col_names (list): Objective names.
        order (list | np.ndarray): Order in which the objectives are shown in SCORE bands.

    Returns:
        go.Figure: The heatmap.
    """
    names = np.asarray(col_names)
    order = np.asarray(order)
    # Reorder columns by the axis order and rows by the reversed axis order.
    corr = np.asarray(correlation_matrix)[np.ix_(order[::-1], order)]
    corr = np.rint(corr * 100) / 100  # Round to two decimals to keep the heatmap readable
    fig = ff.create_annotated_heatmap(
        corr,
        x=list(names[order]),
        y=list(names[order[::-1]]),
        annotation_text=corr.astype(str),
    )
    fig.update_layout(title="Pearson correlation coefficients")
    return fig


def order_dimensions(data: pl.DataFrame, use_absolute_corr: bool = False):
    """Calculate the order of the objectives.

    Also returns the correlation matrix.

    Args:
        data (pl.DataFrame): Data to be visualized.
        use_absolute_corr (bool, optional): Use absolute value of the correlation to calculate order. Defaults to False.

    Returns:
        tuple: The first element is the correlation matrix. The second element is the order of the objectives.
    """
    # Calculating correlations. Pearson's coefficient works better than Spearman's in some cases.
    corr = np.asarray(
        [
            [pearsonr(data.to_numpy()[:, i], data.to_numpy()[:, j])[0] for j in range(len(data.columns))]
            for i in range(len(data.columns))
        ]
    )
    # Axes order: solve a TSP over the negated correlations so that highly correlated axes end up adjacent.
    distances = corr
    if use_absolute_corr:
        distances = np.abs(distances)
    obj_order = solve_tsp(-distances)
    return corr, obj_order


def calculate_axes_positions(
    dimension_order: list[int],
    corr: np.ndarray,
    dist_parameter: float,
    distance_formula: DistanceFormula = DistanceFormula.FORMULA_1,
) -> np.ndarray:
    """Calculate the position of the axes for the SCORE bands visualization based on correlations.

    Args:
        dimension_order (list[int]): Order of the variables to be plotted.
        corr (np.ndarray): Correlation (Pearson) matrix.
        dist_parameter (float): Change the relative distances between the axes. Increase this value if the axes are
            placed too close together. Decrease this value if the axes are equidistant.
        distance_formula (DistanceFormula, optional): The value should be 1 or 2. Check the paper for details.
            Defaults to DistanceFormula.FORMULA_1.

    Returns:
        np.ndarray: Positions of the axes in the range [0, 1].
    """
    # Distances between adjacent axes
    order = np.asarray(list(zip(dimension_order[:-1], dimension_order[1:], strict=True)))
    axis_len = corr[order[:, 0], order[:, 1]]
    if distance_formula == DistanceFormula.FORMULA_1:
        axis_len = 1 - axis_len
    elif distance_formula == DistanceFormula.FORMULA_2:
        axis_len = 1 / (np.abs(axis_len) + 1)  # Reciprocal for reverse
    else:
        # Should never reach here
        raise ValueError("distance_formula should be either 1 or 2 (int)")
    axis_len = axis_len + dist_parameter
    axis_len = axis_len / sum(axis_len)
    return np.cumsum(np.append(0, axis_len))


def score_json(
    data: pl.DataFrame,
    options: SCOREBandsConfig,
) -> SCOREBandsResult:
    """Generate the SCORE Bands data for a given dataset and configuration options.

    Args:
        data (pl.DataFrame): Dataframe of variable (decision or objective) values.
            The column names should be the names of the variables to be plotted. Each row should be a solution.
        options (SCOREBandsConfig): Configuration options for generating the SCORE bands.

    Returns:
        SCOREBandsResult: The result containing all relevant data for the SCORE bands visualization.
    """
    options = deepcopy(options)
    # Calculating correlations and axes positions
    if options.dimensions is None:
        options.dimensions = data.columns
    data_copy = data.select([pl.col(col) for col in options.dimensions])

    if options.axis_positions is None:
        corr, dimension_order = order_dimensions(data_copy, use_absolute_corr=options.use_absolute_correlations)

        axis_dist = calculate_axes_positions(
            dimension_order,
            corr,
            dist_parameter=options.distance_parameter,
            distance_formula=options.distance_formula,
        )

        ordered_dimension_names = [data_copy.columns[i] for i in dimension_order]
        axis_positions = {name: axis_dist[i] for i, name in enumerate(ordered_dimension_names)}
    else:
        axis_positions = options.axis_positions
        ordered_dimension_names = sorted(axis_positions.keys(), key=axis_positions.get)

    clusters = cluster(data_copy, options.clustering_algorithm)

    if min(clusters) <= 0:
        clusters = clusters - np.min(clusters) + 1  # translate minimum to 1.

    # Sanity check: cluster IDs should be contiguous integers starting at 1 and ending at the number of clusters
    unique_clusters = np.unique(clusters)
    max_cluster_id = max(clusters)
    if not all(i in unique_clusters for i in range(1, max_cluster_id + 1)):
        warn(
            """Cluster IDs are not contiguous integers starting at 1.
            This may cause issues with the color mapping in the visualization.""",
            category=UserWarning,
            stacklevel=2,
        )

    cluster_column_name = "cluster"
    if cluster_column_name in data_copy.columns:
        cluster_column_name = "cluster_id"

    data_copy = data_copy.with_columns(pl.Series(cluster_column_name, clusters))
    grouped = data_copy.group_by(cluster_column_name)
    min_percentile = (1 - options.interval_size) / 2
    max_percentile = 1 - min_percentile
    mins = grouped.quantile(min_percentile)
    maxs = grouped.quantile(max_percentile)
    medians = grouped.median()
    frequencies = grouped.len()
    bands_dict = {
        cluster_id: {
            col_name: (
                mins.filter(pl.col(cluster_column_name) == cluster_id)[col_name][0],
                maxs.filter(pl.col(cluster_column_name) == cluster_id)[col_name][0],
            )
            for col_name in ordered_dimension_names
        }
        for cluster_id in mins[cluster_column_name].to_list()
    }
    medians_dict = {
        cluster_id: {
            col_name: medians.filter(pl.col(cluster_column_name) == cluster_id)[col_name][0]
            for col_name in ordered_dimension_names
        }
        for cluster_id in medians[cluster_column_name].to_list()
    }
    frequencies_dict = {
        cluster_id: frequencies.filter(pl.col(cluster_column_name) == cluster_id)["len"][0]
        for cluster_id in frequencies[cluster_column_name].to_list()
    }

    if options.scales is None:
        scales: dict[str, tuple[float, float]] = {
            dimension: (data_copy[dimension].min(), data_copy[dimension].max()) for dimension in ordered_dimension_names
        }
        options.scales = scales
    return SCOREBandsResult(
        options=options,
        ordered_dimensions=ordered_dimension_names,
        clusters=clusters.tolist(),
        axis_positions=axis_positions,
        bands=bands_dict,
        medians=medians_dict,
        cardinalities=frequencies_dict,
    )


def plot_score(data: pl.DataFrame, result: SCOREBandsResult) -> go.Figure:
    """Generate the SCORE Bands figure from the SCOREBandsResult data.

    Args:
        data (pl.DataFrame): Dataframe of objective values. The column names should be the objective names. Each row
            should be an objective vector.
        result (SCOREBandsResult): The result containing all relevant data for the SCORE bands visualization.

    Returns:
        go.Figure: The SCORE bands plot.
    """
    column_names = result.ordered_dimensions

    clusters = np.sort(np.unique(result.clusters))

    if len(clusters) <= 8:
        colorscale = cm.get_cmap("Accent", len(clusters))
    else:
        colorscale = cm.get_cmap("tab20", len(clusters))

    if result.options.scales is None:
        raise ValueError("Scales must be provided in the SCOREBandsResult to plot the figure.")

    scale_min = pl.DataFrame({name: result.options.scales[name][0] for name in result.options.scales})
    scale_max = pl.DataFrame({name: result.options.scales[name][1] for name in result.options.scales})

    scaled_data = (data[column_names] - scale_min) / (scale_max - scale_min)

    fig = go.Figure()
    fig.update_xaxes(showticklabels=False, showgrid=False, zeroline=False)
    fig.update_yaxes(showticklabels=False, showgrid=False, zeroline=False)
    fig.update_layout(plot_bgcolor="rgba(0,0,0,0)")

    cluster_column_name = "cluster"
    if cluster_column_name in scaled_data.columns:
        cluster_column_name = "cluster_id"
    scaled_data = scaled_data.with_columns(pl.Series(cluster_column_name, result.clusters))

    if result.options.descriptive_names is None:
        descriptive_names = {name: name for name in column_names}
    else:
        descriptive_names = result.options.descriptive_names
    if result.options.units is None:
        units = {name: "" for name in column_names}
    else:
        units = result.options.units

    num_ticks = 6
    # Add axes
    for i, col_name in enumerate(column_names):
        label_text = np.linspace(result.options.scales[col_name][0], result.options.scales[col_name][1], num_ticks)
        label_text = ["{:.5g}".format(i) for i in label_text]
        heights = np.linspace(0, 1, num_ticks)
        # Axis lines
        fig.add_scatter(
            x=[result.axis_positions[col_name]] * num_ticks,
            y=heights,
            text=label_text,
            textposition="middle left",
            mode="markers+lines+text",
            line={"color": "black"},
            showlegend=False,
        )
        # Column name
        fig.add_scatter(
            x=[result.axis_positions[col_name]],
            y=[1.20],
            text=f"{descriptive_names[col_name]}",
            textfont={"size": 20},
            mode="text",
            showlegend=False,
        )
        # Units
        fig.add_scatter(
            x=[result.axis_positions[col_name]],
            y=[1.10],
            text=f"{units[col_name]}",
            textfont={"size": 12},
            mode="text",
            showlegend=False,
        )
    # Add bands
    for cluster_id in sorted(result.bands.keys()):
        r, g, b, a = colorscale(cluster_id - 1)  # Needed as cluster numbering starts at 1
        a = 0.6
        # Matplotlib colormaps return floats in [0, 1]; Plotly expects rgb channels in [0, 255].
        color_bands = f"rgba({int(r * 255)}, {int(g * 255)}, {int(b * 255)}, {a})"

        lows = [
            (result.bands[cluster_id][col_name][0] - result.options.scales[col_name][0])
            / (result.options.scales[col_name][1] - result.options.scales[col_name][0])
            for col_name in column_names
        ]
        highs = [
            (result.bands[cluster_id][col_name][1] - result.options.scales[col_name][0])
            / (result.options.scales[col_name][1] - result.options.scales[col_name][0])
            for col_name in column_names
        ]
        medians = [
            (result.medians[cluster_id][col_name] - result.options.scales[col_name][0])
            / (result.options.scales[col_name][1] - result.options.scales[col_name][0])
            for col_name in column_names
        ]

        # lower bound of the band
        fig.add_scatter(
            x=[result.axis_positions[col_name] for col_name in column_names],
            y=lows,
            line={"color": color_bands},
            name=f"{int(100 * result.options.interval_size)}% band: Cluster {cluster_id}; "
            f"{result.cardinalities[cluster_id]} Solutions",
            mode="lines",
            legendgroup=f"{int(100 * result.options.interval_size)}% band: Cluster {cluster_id}",
            showlegend=True,
            line_shape="spline",
            hovertext=f"Cluster {cluster_id}",
        )
        # upper bound of the band
        fig.add_scatter(
            x=[result.axis_positions[col_name] for col_name in column_names],
            y=highs,
            line={"color": color_bands},
            name=f"Cluster {cluster_id}",
            fillcolor=color_bands,
            mode="lines",
            legendgroup=f"{int(100 * result.options.interval_size)}% band: Cluster {cluster_id}",
            showlegend=False,
            line_shape="spline",
            fill="tonexty",
            hovertext=f"Cluster {cluster_id}",
        )

        if result.options.include_medians:
            # median
            fig.add_scatter(
                x=[result.axis_positions[col_name] for col_name in column_names],
                y=medians,
                line={"color": color_bands},
                name=f"Median: Cluster {cluster_id}",
                mode="lines+markers",
                marker={"line": {"color": "Black", "width": 2}},
                legendgroup=f"Median: Cluster {cluster_id}",
                showlegend=True,
            )
    fig.update_layout(font_size=18)
    fig.update_layout(legend={"orientation": "h", "yanchor": "top"})
    return fig
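
For orientation, here is a minimal usage sketch of the score_bands module added above. It is not part of the diff; the data, column names, and option values are illustrative, and it only assumes the public API shown in desdeo/tools/score_bands.py (SCOREBandsConfig, KMeansOptions, score_json, plot_score).

# Illustrative sketch only: example data and options, not taken from the package.
import polars as pl

from desdeo.tools.score_bands import KMeansOptions, SCOREBandsConfig, plot_score, score_json

# Each row is a solution; each column is an objective.
data = pl.DataFrame(
    {
        "f1": [0.1, 0.4, 0.9, 0.3, 0.6],
        "f2": [0.8, 0.5, 0.2, 0.7, 0.4],
        "f3": [0.3, 0.9, 0.6, 0.1, 0.5],
    }
)

config = SCOREBandsConfig(
    clustering_algorithm=KMeansOptions(n_clusters=2),  # DBSCAN is the default
    include_medians=True,
)

result = score_json(data, config)  # clusters, axis order and positions, bands, medians
fig = plot_score(data, result)  # Plotly figure with one band per cluster
fig.show()

Because SCOREBandsResult is a Pydantic model, result can also be serialized (e.g., with result.model_dump_json()) and handed to another frontend instead of plot_score, as the module docstring suggests.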