desdeo 1.2__py3-none-any.whl → 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- desdeo/__init__.py +8 -8
- desdeo/api/README.md +73 -0
- desdeo/api/__init__.py +15 -0
- desdeo/api/app.py +40 -0
- desdeo/api/config.py +69 -0
- desdeo/api/config.toml +53 -0
- desdeo/api/db.py +25 -0
- desdeo/api/db_init.py +79 -0
- desdeo/api/db_models.py +164 -0
- desdeo/api/malaga_db_init.py +27 -0
- desdeo/api/models/__init__.py +66 -0
- desdeo/api/models/archive.py +34 -0
- desdeo/api/models/preference.py +90 -0
- desdeo/api/models/problem.py +507 -0
- desdeo/api/models/reference_point_method.py +18 -0
- desdeo/api/models/session.py +46 -0
- desdeo/api/models/state.py +96 -0
- desdeo/api/models/user.py +51 -0
- desdeo/api/routers/_NAUTILUS.py +245 -0
- desdeo/api/routers/_NAUTILUS_navigator.py +233 -0
- desdeo/api/routers/_NIMBUS.py +762 -0
- desdeo/api/routers/__init__.py +5 -0
- desdeo/api/routers/problem.py +110 -0
- desdeo/api/routers/reference_point_method.py +117 -0
- desdeo/api/routers/session.py +76 -0
- desdeo/api/routers/test.py +16 -0
- desdeo/api/routers/user_authentication.py +366 -0
- desdeo/api/schema.py +94 -0
- desdeo/api/tests/__init__.py +0 -0
- desdeo/api/tests/conftest.py +59 -0
- desdeo/api/tests/test_models.py +701 -0
- desdeo/api/tests/test_routes.py +216 -0
- desdeo/api/utils/database.py +274 -0
- desdeo/api/utils/logger.py +29 -0
- desdeo/core.py +27 -0
- desdeo/emo/__init__.py +29 -0
- desdeo/emo/hooks/archivers.py +172 -0
- desdeo/emo/methods/EAs.py +418 -0
- desdeo/emo/methods/__init__.py +0 -0
- desdeo/emo/methods/bases.py +59 -0
- desdeo/emo/operators/__init__.py +1 -0
- desdeo/emo/operators/crossover.py +780 -0
- desdeo/emo/operators/evaluator.py +118 -0
- desdeo/emo/operators/generator.py +356 -0
- desdeo/emo/operators/mutation.py +1053 -0
- desdeo/emo/operators/selection.py +1036 -0
- desdeo/emo/operators/termination.py +178 -0
- desdeo/explanations/__init__.py +6 -0
- desdeo/explanations/explainer.py +100 -0
- desdeo/explanations/utils.py +90 -0
- desdeo/mcdm/__init__.py +19 -0
- desdeo/mcdm/nautili.py +345 -0
- desdeo/mcdm/nautilus.py +477 -0
- desdeo/mcdm/nautilus_navigator.py +655 -0
- desdeo/mcdm/nimbus.py +417 -0
- desdeo/mcdm/pareto_navigator.py +269 -0
- desdeo/mcdm/reference_point_method.py +116 -0
- desdeo/problem/__init__.py +79 -0
- desdeo/problem/evaluator.py +561 -0
- desdeo/problem/gurobipy_evaluator.py +562 -0
- desdeo/problem/infix_parser.py +341 -0
- desdeo/problem/json_parser.py +944 -0
- desdeo/problem/pyomo_evaluator.py +468 -0
- desdeo/problem/schema.py +1808 -0
- desdeo/problem/simulator_evaluator.py +298 -0
- desdeo/problem/sympy_evaluator.py +244 -0
- desdeo/problem/testproblems/__init__.py +73 -0
- desdeo/problem/testproblems/binh_and_korn_problem.py +88 -0
- desdeo/problem/testproblems/dtlz2_problem.py +102 -0
- desdeo/problem/testproblems/forest_problem.py +275 -0
- desdeo/problem/testproblems/knapsack_problem.py +163 -0
- desdeo/problem/testproblems/mcwb_problem.py +831 -0
- desdeo/problem/testproblems/mixed_variable_dimenrions_problem.py +83 -0
- desdeo/problem/testproblems/momip_problem.py +172 -0
- desdeo/problem/testproblems/nimbus_problem.py +143 -0
- desdeo/problem/testproblems/pareto_navigator_problem.py +89 -0
- desdeo/problem/testproblems/re_problem.py +492 -0
- desdeo/problem/testproblems/river_pollution_problem.py +434 -0
- desdeo/problem/testproblems/rocket_injector_design_problem.py +140 -0
- desdeo/problem/testproblems/simple_problem.py +351 -0
- desdeo/problem/testproblems/simulator_problem.py +92 -0
- desdeo/problem/testproblems/spanish_sustainability_problem.py +945 -0
- desdeo/problem/testproblems/zdt_problem.py +271 -0
- desdeo/problem/utils.py +245 -0
- desdeo/tools/GenerateReferencePoints.py +181 -0
- desdeo/tools/__init__.py +102 -0
- desdeo/tools/generics.py +145 -0
- desdeo/tools/gurobipy_solver_interfaces.py +258 -0
- desdeo/tools/indicators_binary.py +11 -0
- desdeo/tools/indicators_unary.py +375 -0
- desdeo/tools/interaction_schema.py +38 -0
- desdeo/tools/intersection.py +54 -0
- desdeo/tools/iterative_pareto_representer.py +99 -0
- desdeo/tools/message.py +234 -0
- desdeo/tools/ng_solver_interfaces.py +199 -0
- desdeo/tools/non_dominated_sorting.py +133 -0
- desdeo/tools/patterns.py +281 -0
- desdeo/tools/proximal_solver.py +99 -0
- desdeo/tools/pyomo_solver_interfaces.py +464 -0
- desdeo/tools/reference_vectors.py +462 -0
- desdeo/tools/scalarization.py +3138 -0
- desdeo/tools/scipy_solver_interfaces.py +454 -0
- desdeo/tools/score_bands.py +464 -0
- desdeo/tools/utils.py +320 -0
- desdeo/utopia_stuff/__init__.py +0 -0
- desdeo/utopia_stuff/data/1.json +15 -0
- desdeo/utopia_stuff/data/2.json +13 -0
- desdeo/utopia_stuff/data/3.json +15 -0
- desdeo/utopia_stuff/data/4.json +17 -0
- desdeo/utopia_stuff/data/5.json +15 -0
- desdeo/utopia_stuff/from_json.py +40 -0
- desdeo/utopia_stuff/reinit_user.py +38 -0
- desdeo/utopia_stuff/utopia_db_init.py +212 -0
- desdeo/utopia_stuff/utopia_problem.py +403 -0
- desdeo/utopia_stuff/utopia_problem_old.py +415 -0
- desdeo/utopia_stuff/utopia_reference_solutions.py +79 -0
- desdeo-2.0.0.dist-info/LICENSE +21 -0
- desdeo-2.0.0.dist-info/METADATA +168 -0
- desdeo-2.0.0.dist-info/RECORD +120 -0
- {desdeo-1.2.dist-info → desdeo-2.0.0.dist-info}/WHEEL +1 -1
- desdeo-1.2.dist-info/METADATA +0 -16
- desdeo-1.2.dist-info/RECORD +0 -4
|
@@ -0,0 +1,464 @@
|
|
|
1
|
+
"""Use the auto_SCORE function to generate the SCORE bands visualization.
|
|
2
|
+
|
|
3
|
+
This module contains the functions which generate SCORE bands visualizations. It also contains functions to calculate
|
|
4
|
+
the order and positions of the objective axes, as well as a heatmap of correlation matrix.
|
|
5
|
+
|
|
6
|
+
This file is just copied from the old SCORE bands repo.
|
|
7
|
+
It is very much out of date and is missing documentation.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
import pandas as pd
|
|
12
|
+
import plotly.figure_factory as ff
|
|
13
|
+
import plotly.graph_objects as go
|
|
14
|
+
from matplotlib import cm
|
|
15
|
+
from scipy.stats import pearsonr
|
|
16
|
+
from sklearn.cluster import DBSCAN
|
|
17
|
+
from sklearn.metrics import silhouette_score
|
|
18
|
+
from sklearn.mixture import GaussianMixture
|
|
19
|
+
from sklearn.preprocessing import StandardScaler
|
|
20
|
+
from tsp_solver.greedy import solve_tsp
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def _gaussianmixtureclusteringwithBIC(data: pd.DataFrame):
    """Cluster ``data`` with a Gaussian mixture model selected by BIC.

    Fits GMMs with 1-10 components (capped below the number of samples) for
    every covariance type and keeps the model with the lowest Bayesian
    Information Criterion.

    BUG FIX: the previous implementation minimized ``gmm.score(data)`` (mean
    log-likelihood, where *higher* is better), which selected the
    worst-fitting model; the actual ``gmm.bic(data)`` call was commented out.
    The BIC (lower is better) is now used, matching the function name.

    Args:
        data (pd.DataFrame): Objective vectors, one per row.

    Returns:
        np.ndarray: Cluster label for each row of ``data``.

    Raises:
        ValueError: If ``data`` has fewer than two rows, so no mixture model
            can be fitted.
    """
    data = StandardScaler().fit_transform(data)
    lowest_bic = np.inf
    bic = []
    n_components_range = range(1, min(11, len(data)))
    cv_types = ["spherical", "tied", "diag", "full"]
    best_gmm = None  # guards against an empty n_components_range
    for cv_type in cv_types:
        for n_components in n_components_range:
            # Fit a Gaussian mixture with EM
            gmm = GaussianMixture(n_components=n_components, covariance_type=cv_type)
            gmm.fit(data)
            # BIC: lower values indicate a better model.
            bic.append(gmm.bic(data))
            if bic[-1] < lowest_bic:
                lowest_bic = bic[-1]
                best_gmm = gmm

    if best_gmm is None:
        raise ValueError("Not enough samples to fit a Gaussian mixture model.")
    return best_gmm.predict(data)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def _gaussianmixtureclusteringwithsilhouette(data: pd.DataFrame):
    """Cluster ``data`` with a Gaussian mixture model selected by silhouette score.

    Every combination of component count (1-10, capped below the sample
    count) and covariance type is tried; the labelling with the highest
    cosine silhouette score is kept. Configurations for which the silhouette
    score is undefined (e.g. a single cluster) are skipped.

    Args:
        data (pd.DataFrame): Objective vectors, one per row.

    Returns:
        Cluster labels for each row of ``data`` (an empty list if no valid
        labelling was found).
    """
    scaled = StandardScaler().fit_transform(data)
    winner_score = -np.inf
    winner_labels = []
    for covariance in ["spherical", "tied", "diag", "full"]:
        for components in range(1, min(11, len(data))):
            # Fit a Gaussian mixture with EM for this configuration.
            mixture = GaussianMixture(n_components=components, covariance_type=covariance)
            candidate = mixture.fit_predict(scaled)
            try:
                candidate_score = silhouette_score(scaled, candidate, metric="cosine")
            except ValueError:
                # Silhouette is undefined here (e.g. one cluster): treat as worst.
                candidate_score = -np.inf
            if candidate_score > winner_score:
                winner_score = candidate_score
                winner_labels = candidate
    return winner_labels
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def _DBSCANClustering(data: pd.DataFrame):
    """Cluster ``data`` with DBSCAN, tuning ``eps`` via the silhouette score.

    Twenty ``eps`` values in [0.01, 1] are tried; the labelling with the
    highest cosine silhouette score is returned. If no ``eps`` yields a valid
    silhouette score (e.g. a single cluster for every value), all points are
    assigned to cluster 1.

    Change: removed the unused ``core_samples_mask`` computation (dead code
    left over from the scikit-learn DBSCAN example).

    Args:
        data (pd.DataFrame): Objective vectors, one per row.

    Returns:
        Cluster labels for each row of ``data``. DBSCAN labels noise
        points with -1.
    """
    X = StandardScaler().fit_transform(data)
    eps_options = np.linspace(0.01, 1, 20)
    best_score = -np.inf
    best_labels = [1] * len(X)  # fallback: everything in a single cluster
    for eps_option in eps_options:
        labels = DBSCAN(eps=eps_option, min_samples=10, metric="cosine").fit(X).labels_
        try:
            score = silhouette_score(X, labels, metric="cosine")
        except ValueError:
            # Silhouette is undefined for a single cluster: treat as worst.
            score = -np.inf
        if score > best_score:
            best_score = score
            best_labels = labels
    return best_labels
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def cluster(data: pd.DataFrame, algorithm: str = "DBSCAN", score: str = "silhoutte"):
    """Cluster the data with the chosen algorithm and model-selection score.

    Args:
        data (pd.DataFrame): Objective vectors, one per row.
        algorithm (str, optional): Either "DBSCAN" or "GMM". Defaults to "DBSCAN".
        score (str, optional): Model-selection score used when ``algorithm``
            is "GMM": either "silhoutte" (sic — this misspelling is the
            established spelling in this API) or "BIC". Ignored for "DBSCAN".
            Defaults to "silhoutte".

    Returns:
        Cluster labels for each row of ``data``.

    Raises:
        ValueError: If ``score`` or ``algorithm`` is not a supported option.
            (Previously raised with no message; messages added.)
    """
    if score not in ("silhoutte", "BIC"):
        raise ValueError(f"Unknown score {score!r}; expected 'silhoutte' or 'BIC'.")
    if algorithm not in ("GMM", "DBSCAN"):
        raise ValueError(f"Unknown algorithm {algorithm!r}; expected 'GMM' or 'DBSCAN'.")
    if algorithm == "DBSCAN":
        return _DBSCANClustering(data)
    if score == "silhoutte":
        return _gaussianmixtureclusteringwithsilhouette(data)
    return _gaussianmixtureclusteringwithBIC(data)
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def SCORE_bands(
    data: pd.DataFrame,
    axis_signs: np.ndarray = None,
    color_groups: list | np.ndarray = None,
    axis_positions: np.ndarray = None,
    solutions: bool = True,
    bands: bool = False,
    medians: bool = False,
    quantile: float = 0.25,
) -> go.Figure:
    """Generate SCORE bands figure from the provided data.

    Args:
        data (pd.DataFrame): Pandas dataframe where each column represents an objective and each row is an objective
            vector. The column names are displayed as the objective names in the generated figure. Each element in the
            dataframe must be numeric.

        color_groups (list | np.ndarray, optional): List or numpy array of the same length as the number of
            objective vectors. The elements should be a contiguous set of integers starting at 1. The element value
            represents the Cluster ID of the corresponding objective vector. Defaults to None (though this behaviour
            is not fully tested yet).

        axis_positions (np.ndarray, optional): 1-D numpy array of the same length as the number of objectives. The
            value represents the horizontal position of the corresponding objective axes. The value of the first and
            last element should be 0 and 1 respectively, and all intermediate values should lie between 0 and 1.
            Defaults to None, in which case all axes are positioned equidistant.

        axis_signs (np.ndarray, optional): 1-D numpy array of the same length as the number of objectives. Each
            element can either be 1 or -1. A value of -1 flips the objective in the SCORE bands visualization. This
            feature is experimental and should be ignored for now. Defaults to None.

        solutions (bool, optional): Show or hide individual solutions. Defaults to True.

        bands (bool, optional): Show or hide cluster bands. Defaults to False.

        medians (bool, optional): Show or hide cluster medians. Defaults to False.

        quantile (float, optional): The quantile value to calculate the band. The band represents the range between
            (quantile) and (1 - quantile) quantiles of the objective values. Defaults to 0.25.

    Returns:
        go.Figure: SCORE bands plot.
    """
    # Initial trace visibility: individual solutions start hidden but can be
    # toggled from the legend; band traces are drawn visible.
    show_solutions = "legendonly"
    bands_visible = True
    # BUG FIX: show_medians was previously only assigned under the band/median
    # flags, so a call with medians=True and bands=False reached
    # `visible=show_medians` with the name unbound (NameError). It is now
    # initialized unconditionally.
    show_medians = "legendonly"
    if medians:
        show_medians = True
    column_names = data.columns
    num_columns = len(column_names)
    if axis_positions is None:
        axis_positions = np.linspace(0, 1, num_columns)
    if axis_signs is None:
        axis_signs = np.ones_like(axis_positions)
    if color_groups is None:
        # NOTE(review): this broadcasts the string "continuous" as the group
        # of every row; the groupby below then yields a non-integer
        # cluster_id, which breaks `colorscale(cluster_id - 1)`. The
        # color_groups=None path is flagged as not fully tested.
        color_groups = "continuous"
        colorscale = cm.get_cmap("viridis")
    elif isinstance(color_groups, (np.ndarray, list)):
        groups = list(np.unique(color_groups))
        # Qualitative colormaps: "Accent" offers 8 distinct colors; fall back
        # to "tab20" when there are more clusters.
        if len(groups) <= 8:
            colorscale = cm.get_cmap("Accent", len(groups))
        else:
            colorscale = cm.get_cmap("tab20", len(groups))
    data = data * axis_signs
    num_labels = 6  # tick labels per objective axis

    # Scaling the objective values between 0 and 1.
    scaled_data = data - data.min(axis=0)
    scaled_data = scaled_data / scaled_data.max(axis=0)
    # Original (unscaled) ranges; used to compute the axis tick labels.
    scales = pd.DataFrame([data.min(axis=0), data.max(axis=0)], index=["min", "max"]) * axis_signs

    fig = go.Figure()
    fig.update_xaxes(showticklabels=False, showgrid=False, zeroline=False)
    fig.update_yaxes(showticklabels=False, showgrid=False, zeroline=False)
    fig.update_layout(plot_bgcolor="rgba(0,0,0,0)")

    scaled_data.insert(0, "group", value=color_groups)
    for cluster_id, solns in scaled_data.groupby("group"):
        # TODO: Many things here are very inefficient. Improve when free.
        num_solns = len(solns)

        r, g, b, a = colorscale(cluster_id - 1)  # Needed as cluster numbering starts at 1
        # NOTE(review): matplotlib colormaps return r, g, b in [0, 1], while a
        # CSS rgba() string expects 0-255 channels — verify the rendered
        # colors are as intended.
        a = 0.6
        a_soln = 0.6
        color_bands = f"rgba({r}, {g}, {b}, {a})"
        color_soln = f"rgba({r}, {g}, {b}, {a_soln})"

        low = solns.drop("group", axis=1).quantile(quantile)
        high = solns.drop("group", axis=1).quantile(1 - quantile)
        median = solns.drop("group", axis=1).median()

        if bands is True:
            # lower bound of the band
            fig.add_scatter(
                x=axis_positions,
                y=low,
                line={"color": color_bands},
                name=f"{int(100 - 200 * quantile)}% band: Cluster {cluster_id}; {num_solns} Solutions ",
                mode="lines",
                legendgroup=f"{int(100 - 200 * quantile)}% band: Cluster {cluster_id}",
                showlegend=True,
                line_shape="spline",
                hovertext=f"Cluster {cluster_id}",
                visible=bands_visible,
            )
            # upper bound of the band; fills down to the lower-bound trace
            fig.add_scatter(
                x=axis_positions,
                y=high,
                line={"color": color_bands},
                name=f"Cluster {cluster_id}",
                fillcolor=color_bands,
                mode="lines",
                legendgroup=f"{int(100 - 200 * quantile)}% band: Cluster {cluster_id}",
                showlegend=False,
                line_shape="spline",
                fill="tonexty",
                hovertext=f"Cluster {cluster_id}",
                visible=bands_visible,
            )
        if medians is True:
            # median
            fig.add_scatter(
                x=axis_positions,
                y=median,
                line={"color": color_bands},
                name=f"Median: Cluster {cluster_id}",
                mode="lines+markers",
                marker={"line": {"color": "Black", "width": 2}},
                legendgroup=f"Median: Cluster {cluster_id}",
                showlegend=True,
                visible=show_medians,
            )
        if solutions is True:
            # individual solutions; only the first trace of each cluster gets
            # a legend entry
            legend = True
            for _, soln in solns.drop("group", axis=1).iterrows():
                fig.add_scatter(
                    x=axis_positions,
                    y=soln,
                    line={"color": color_soln},
                    name=f"Solutions: Cluster {cluster_id} ",
                    legendgroup=f"Solutions: Cluster {cluster_id}",
                    showlegend=legend,
                    visible=show_solutions,
                )
                legend = False
    # Axis lines
    for i, col_name in enumerate(column_names):
        label_text = np.linspace(scales[col_name]["min"], scales[col_name]["max"], num_labels)
        heights = np.linspace(0, 1, num_labels)
        # Choose a common power-of-ten scale factor for the tick labels.
        scale_factors = []
        for current_label in label_text:
            try:
                with np.errstate(divide="ignore"):
                    scale_factors.append(int(np.floor(np.log10(np.abs(current_label)))))
            except OverflowError:
                # log10(0) is -inf; int(-inf) raises, so zero labels are skipped.
                pass

        # Guard: if every label was 0, scale_factors is empty and
        # np.median([]) would be nan (int(nan) raises). Use 0 instead.
        scale_factor = int(np.median(scale_factors)) if scale_factors else 0
        if scale_factor == -1 or scale_factor == 1:
            scale_factor = 0

        # TODO: This sometimes doesn't generate the correct label text. Check with datasets where objs lie between (0,1).
        label_text = label_text / 10 ** (scale_factor)
        label_text = ["{:.1f}".format(value) for value in label_text]
        scale_factor_text = f"e{scale_factor}" if scale_factor != 0 else ""

        # Bottom axis label
        fig.add_scatter(
            x=[axis_positions[i]],
            y=[heights[0]],
            text=[label_text[0] + scale_factor_text],
            textposition="bottom center",
            mode="text",
            line={"color": "black"},
            showlegend=False,
        )
        # Top axis label
        fig.add_scatter(
            x=[axis_positions[i]],
            y=[heights[-1]],
            text=[label_text[-1] + scale_factor_text],
            textposition="top center",
            mode="text",
            line={"color": "black"},
            showlegend=False,
        )
        label_text[0] = ""
        label_text[-1] = ""
        # Intermediate axes labels
        fig.add_scatter(
            x=[axis_positions[i]] * num_labels,
            y=heights,
            text=label_text,
            textposition="middle left",
            mode="markers+lines+text",
            line={"color": "black"},
            showlegend=False,
        )

        # Objective name above the axis
        fig.add_scatter(
            x=[axis_positions[i]],
            y=[1.10],
            text=f"{col_name}",
            textfont={"size": 20},
            mode="text",
            showlegend=False,
        )
    fig.update_layout(font_size=18)
    fig.update_layout(legend={"orientation": "h", "yanchor": "top", "font": {"size": 24}})
    return fig
|
|
330
|
+
|
|
331
|
+
|
|
332
|
+
def annotated_heatmap(correlation_matrix: np.ndarray, col_names: list, order: list | np.ndarray) -> go.Figure:
    """Create an annotated heatmap of the objective correlation matrix.

    The rows and columns are reordered to match the objective order used in
    the SCORE bands plot (rows in reversed order so the diagonal runs
    bottom-left to top-right).

    Args:
        correlation_matrix (np.ndarray): 2-D square array of correlation values between pairs of objectives.
        col_names (list): Objective names.
        order (list | np.ndarray): Order in which the objectives are shown in SCORE bands.

    Returns:
        go.Figure: The heatmap.
    """
    corr_frame = pd.DataFrame(correlation_matrix, index=col_names, columns=col_names)
    # NOTE(review): col_names[order] relies on fancy indexing, i.e. col_names
    # being a pd.Index or ndarray; a plain Python list would raise here —
    # confirm what callers actually pass.
    column_order = col_names[order]
    row_order = col_names[order[::-1]]
    corr_frame = corr_frame[column_order].loc[row_order]
    # Round to two decimal places so the annotations stay readable.
    corr_frame = np.rint(corr_frame * 100) / 100
    fig = ff.create_annotated_heatmap(
        corr_frame.to_numpy(),
        x=list(corr_frame.columns),
        y=list(corr_frame.index),
        annotation_text=corr_frame.astype(str).to_numpy(),
    )
    fig.update_layout(title="Pearson correlation coefficients")
    return fig
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
def order_objectives(data: pd.DataFrame, use_absolute_corr: bool = False):
    """Calculate the order of objectives.

    Also returns the correlation matrix. The order is obtained by solving a
    travelling salesperson problem over the (negated) pairwise correlations,
    so highly correlated objectives end up adjacent.

    Performance fix: ``data.to_numpy()`` was previously re-evaluated inside a
    double comprehension (O(k^2) conversions for k objectives); the array is
    now converted once.

    Args:
        data (pd.DataFrame): Data to be visualized.
        use_absolute_corr (bool, optional): Use absolute value of the correlation to calculate order.
            Defaults to False.

    Returns:
        tuple: The first element is the correlation matrix. The second element is the order of the objectives.
    """
    # Calculating correlations.
    # Pearson's coefficient is used instead of Spearman's, which was found
    # better in some cases.
    values = data.to_numpy()
    n_objectives = len(data.columns)
    corr = np.asarray(
        [[pearsonr(values[:, i], values[:, j])[0] for j in range(n_objectives)] for i in range(n_objectives)]
    )
    # Axes order: solving TSP over the negated correlations.
    distances = np.abs(corr) if use_absolute_corr else corr
    obj_order = solve_tsp(-distances)
    return corr, obj_order
|
|
382
|
+
|
|
383
|
+
|
|
384
|
+
def calculate_axes_positions(data, obj_order, corr, dist_parameter, distance_formula: int = 1):
    """Compute the horizontal axis positions and axis signs for SCORE bands.

    Adjacent-axis distances are derived from the correlation between
    consecutive objectives (in ``obj_order``), padded by ``dist_parameter``
    and normalized so the axes span [0, 1].

    Args:
        data (pd.DataFrame): Data to be visualized; columns are objectives.
        obj_order: Ordering of the objective indices (e.g. from ``order_objectives``).
        corr (np.ndarray): Pairwise correlation matrix of the objectives.
        dist_parameter (float): Minimum spacing added between adjacent axes.
        distance_formula (int, optional): 1 uses ``1 - corr``; 2 uses
            ``1 / (|corr| + 1)``. Defaults to 1.

    Returns:
        tuple: (data with columns reordered, cumulative normalized axis
        positions starting at 0 and ending at 1, axis signs from the
        cumulative product of correlation signs).

    Raises:
        ValueError: If ``distance_formula`` is neither 1 nor 2.
    """
    # Consecutive objective pairs in display order.
    pairs = np.asarray(list(zip(obj_order[:-1], obj_order[1:])))
    adjacent_corr = corr[pairs[:, 0], pairs[:, 1]]
    if distance_formula == 1:
        lengths = 1 - adjacent_corr  # TODO Make this formula available to the user
    elif distance_formula == 2:
        lengths = 1 / (np.abs(adjacent_corr) + 1)  # Reciprocal for reverse
    else:
        raise ValueError("distance_formula should be either 1 or 2 (int)")
    # Pad with the minimum distance, then normalize so positions span [0, 1].
    lengths = lengths + dist_parameter
    lengths = lengths / sum(lengths)
    positions = np.cumsum(np.append(0, lengths))
    # Axis signs (normalizing negative correlations).
    signs = np.cumprod(np.sign(np.hstack((1, adjacent_corr))))
    return data.iloc[:, obj_order], positions, signs
|
|
402
|
+
|
|
403
|
+
|
|
404
|
+
def auto_SCORE(
    data: pd.DataFrame,
    solutions: bool = True,
    bands: bool = True,
    medians: bool = False,
    dist_parameter: float = 0.05,
    use_absolute_corr: bool = False,
    distance_formula: int = 1,
    flip_axes: bool = False,
    clustering_algorithm: str = "DBSCAN",
    clustering_score: str = "silhoutte",
    quantile: float = 0.05,
):
    """Generate the SCORE Bands visualization for a dataset with predefined values for the hyperparameters.

    Args:
        data (pd.DataFrame): Dataframe of objective values. The column names should be the objective names. Each row
            should be an objective vector.

        solutions (bool, optional): Show or hide individual solutions. Defaults to True.
        bands (bool, optional): Show or hide the cluster bands. Defaults to True.
        medians (bool, optional): Show or hide the cluster medians. Defaults to False.
        dist_parameter (float, optional): Change the relative distances between the objective axes. Increase this
            value if objectives are placed too close together. Decrease this value if the objectives are equidistant
            in a problem with objective clusters. Defaults to 0.05.
        use_absolute_corr (bool, optional): Use absolute value of the correlation to calculate the placement of axes.
            Defaults to False.
        distance_formula (int, optional): The value should be 1 or 2. Check the paper for details. Defaults to 1.
        flip_axes (bool, optional): Do not use this option. Defaults to False.
        clustering_algorithm (str, optional): Currently supported options: "GMM" and "DBSCAN". Defaults to "DBSCAN".
        clustering_score (str, optional): If "GMM" is chosen for clustering algorithm, the scoring mechanism can be
            either "silhoutte" or "BIC". Defaults to "silhoutte".
        quantile (float, optional): The cluster band spans the (quantile) to (1 - quantile) range of each cluster's
            objective values. Defaults to 0.05.

    Returns:
        tuple: (SCORE bands figure, correlation matrix, objective order, cluster labels, axis positions).
    """
    # Calculating correlations and axes positions
    corr, obj_order = order_objectives(data, use_absolute_corr=use_absolute_corr)

    ordered_data, axis_dist, axis_signs = calculate_axes_positions(
        data,
        obj_order,
        corr,
        dist_parameter=dist_parameter,
        distance_formula=distance_formula,
    )
    if not flip_axes:
        # Axis flipping is experimental; drop the computed signs by default.
        axis_signs = None
    groups = cluster(ordered_data, algorithm=clustering_algorithm, score=clustering_score)
    # np.asarray also covers the clustering fallback that returns a plain
    # list, for which `list - int` would raise a TypeError.
    groups = np.asarray(groups) - np.min(groups) + 1  # translate minimum cluster id to 1
    fig1 = SCORE_bands(
        ordered_data,
        color_groups=groups,
        axis_positions=axis_dist,
        axis_signs=axis_signs,
        solutions=solutions,
        bands=bands,
        medians=medians,
        # BUG FIX: quantile was hard-coded to 0.05 here, silently ignoring
        # the function parameter.
        quantile=quantile,
    )
    return fig1, corr, obj_order, groups, axis_dist
|