desdeo 2.0.0__py3-none-any.whl → 2.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (126)
  1. desdeo/adm/ADMAfsar.py +551 -0
  2. desdeo/adm/ADMChen.py +414 -0
  3. desdeo/adm/BaseADM.py +119 -0
  4. desdeo/adm/__init__.py +11 -0
  5. desdeo/api/__init__.py +6 -6
  6. desdeo/api/app.py +38 -28
  7. desdeo/api/config.py +65 -44
  8. desdeo/api/config.toml +23 -12
  9. desdeo/api/db.py +10 -8
  10. desdeo/api/db_init.py +12 -6
  11. desdeo/api/models/__init__.py +220 -20
  12. desdeo/api/models/archive.py +16 -27
  13. desdeo/api/models/emo.py +128 -0
  14. desdeo/api/models/enautilus.py +69 -0
  15. desdeo/api/models/gdm/gdm_aggregate.py +139 -0
  16. desdeo/api/models/gdm/gdm_base.py +69 -0
  17. desdeo/api/models/gdm/gdm_score_bands.py +114 -0
  18. desdeo/api/models/gdm/gnimbus.py +138 -0
  19. desdeo/api/models/generic.py +104 -0
  20. desdeo/api/models/generic_states.py +401 -0
  21. desdeo/api/models/nimbus.py +158 -0
  22. desdeo/api/models/preference.py +44 -6
  23. desdeo/api/models/problem.py +274 -64
  24. desdeo/api/models/session.py +4 -1
  25. desdeo/api/models/state.py +419 -52
  26. desdeo/api/models/user.py +7 -6
  27. desdeo/api/models/utopia.py +25 -0
  28. desdeo/api/routers/_EMO.backup +309 -0
  29. desdeo/api/routers/_NIMBUS.py +6 -3
  30. desdeo/api/routers/emo.py +497 -0
  31. desdeo/api/routers/enautilus.py +237 -0
  32. desdeo/api/routers/gdm/gdm_aggregate.py +234 -0
  33. desdeo/api/routers/gdm/gdm_base.py +420 -0
  34. desdeo/api/routers/gdm/gdm_score_bands/gdm_score_bands_manager.py +398 -0
  35. desdeo/api/routers/gdm/gdm_score_bands/gdm_score_bands_routers.py +377 -0
  36. desdeo/api/routers/gdm/gnimbus/gnimbus_manager.py +698 -0
  37. desdeo/api/routers/gdm/gnimbus/gnimbus_routers.py +591 -0
  38. desdeo/api/routers/generic.py +233 -0
  39. desdeo/api/routers/nimbus.py +705 -0
  40. desdeo/api/routers/problem.py +201 -4
  41. desdeo/api/routers/reference_point_method.py +20 -44
  42. desdeo/api/routers/session.py +50 -26
  43. desdeo/api/routers/user_authentication.py +180 -26
  44. desdeo/api/routers/utils.py +187 -0
  45. desdeo/api/routers/utopia.py +230 -0
  46. desdeo/api/schema.py +10 -4
  47. desdeo/api/tests/conftest.py +94 -2
  48. desdeo/api/tests/test_enautilus.py +330 -0
  49. desdeo/api/tests/test_models.py +550 -72
  50. desdeo/api/tests/test_routes.py +902 -43
  51. desdeo/api/utils/_database.py +263 -0
  52. desdeo/api/utils/database.py +28 -266
  53. desdeo/api/utils/emo_database.py +40 -0
  54. desdeo/core.py +7 -0
  55. desdeo/emo/__init__.py +154 -24
  56. desdeo/emo/hooks/archivers.py +18 -2
  57. desdeo/emo/methods/EAs.py +128 -5
  58. desdeo/emo/methods/bases.py +9 -56
  59. desdeo/emo/methods/templates.py +111 -0
  60. desdeo/emo/operators/crossover.py +544 -42
  61. desdeo/emo/operators/evaluator.py +10 -14
  62. desdeo/emo/operators/generator.py +127 -24
  63. desdeo/emo/operators/mutation.py +212 -41
  64. desdeo/emo/operators/scalar_selection.py +202 -0
  65. desdeo/emo/operators/selection.py +956 -214
  66. desdeo/emo/operators/termination.py +124 -16
  67. desdeo/emo/options/__init__.py +108 -0
  68. desdeo/emo/options/algorithms.py +435 -0
  69. desdeo/emo/options/crossover.py +164 -0
  70. desdeo/emo/options/generator.py +131 -0
  71. desdeo/emo/options/mutation.py +260 -0
  72. desdeo/emo/options/repair.py +61 -0
  73. desdeo/emo/options/scalar_selection.py +66 -0
  74. desdeo/emo/options/selection.py +127 -0
  75. desdeo/emo/options/templates.py +383 -0
  76. desdeo/emo/options/termination.py +143 -0
  77. desdeo/gdm/__init__.py +22 -0
  78. desdeo/gdm/gdmtools.py +45 -0
  79. desdeo/gdm/score_bands.py +114 -0
  80. desdeo/gdm/voting_rules.py +50 -0
  81. desdeo/mcdm/__init__.py +23 -1
  82. desdeo/mcdm/enautilus.py +338 -0
  83. desdeo/mcdm/gnimbus.py +484 -0
  84. desdeo/mcdm/nautilus_navigator.py +7 -6
  85. desdeo/mcdm/reference_point_method.py +70 -0
  86. desdeo/problem/__init__.py +5 -1
  87. desdeo/problem/external/__init__.py +18 -0
  88. desdeo/problem/external/core.py +356 -0
  89. desdeo/problem/external/pymoo_provider.py +266 -0
  90. desdeo/problem/external/runtime.py +44 -0
  91. desdeo/problem/infix_parser.py +2 -2
  92. desdeo/problem/pyomo_evaluator.py +25 -6
  93. desdeo/problem/schema.py +69 -48
  94. desdeo/problem/simulator_evaluator.py +65 -15
  95. desdeo/problem/testproblems/__init__.py +26 -11
  96. desdeo/problem/testproblems/benchmarks_server.py +120 -0
  97. desdeo/problem/testproblems/cake_problem.py +185 -0
  98. desdeo/problem/testproblems/dmitry_forest_problem_discrete.py +71 -0
  99. desdeo/problem/testproblems/forest_problem.py +77 -69
  100. desdeo/problem/testproblems/multi_valued_constraints.py +119 -0
  101. desdeo/problem/testproblems/{river_pollution_problem.py → river_pollution_problems.py} +28 -22
  102. desdeo/problem/testproblems/single_objective.py +289 -0
  103. desdeo/problem/testproblems/zdt_problem.py +4 -1
  104. desdeo/tools/__init__.py +39 -21
  105. desdeo/tools/desc_gen.py +22 -0
  106. desdeo/tools/generics.py +22 -2
  107. desdeo/tools/group_scalarization.py +3090 -0
  108. desdeo/tools/indicators_binary.py +107 -1
  109. desdeo/tools/indicators_unary.py +3 -16
  110. desdeo/tools/message.py +33 -2
  111. desdeo/tools/non_dominated_sorting.py +4 -3
  112. desdeo/tools/patterns.py +9 -7
  113. desdeo/tools/pyomo_solver_interfaces.py +48 -35
  114. desdeo/tools/reference_vectors.py +118 -351
  115. desdeo/tools/scalarization.py +340 -1413
  116. desdeo/tools/score_bands.py +491 -328
  117. desdeo/tools/utils.py +117 -49
  118. desdeo/tools/visualizations.py +67 -0
  119. desdeo/utopia_stuff/utopia_problem.py +1 -1
  120. desdeo/utopia_stuff/utopia_problem_old.py +1 -1
  121. {desdeo-2.0.0.dist-info → desdeo-2.1.0.dist-info}/METADATA +46 -28
  122. desdeo-2.1.0.dist-info/RECORD +180 -0
  123. {desdeo-2.0.0.dist-info → desdeo-2.1.0.dist-info}/WHEEL +1 -1
  124. desdeo-2.0.0.dist-info/RECORD +0 -120
  125. /desdeo/api/utils/{logger.py → _logger.py} +0 -0
  126. {desdeo-2.0.0.dist-info → desdeo-2.1.0.dist-info/licenses}/LICENSE +0 -0
@@ -3,15 +3,22 @@
 This module contains the functions which generate SCORE bands visualizations. It also contains functions to calculate
 the order and positions of the objective axes, as well as a heatmap of the correlation matrix.

-This file is just copied from the old SCORE bands repo.
-It is very much out of date and is missing documentation.
+To run the SCORE bands visualization, use the `score_json` function to generate the data for the visualization, and then
+use the `plot_score` function to generate the figure. You can also pass the result of `score_json` to other frontends
+for visualization.
 """

+from copy import deepcopy
+from enum import Enum
+from typing import Literal
+from warnings import warn
+
 import numpy as np
-import pandas as pd
 import plotly.figure_factory as ff
 import plotly.graph_objects as go
+import polars as pl
 from matplotlib import cm
+from pydantic import BaseModel, ConfigDict, Field
 from scipy.stats import pearsonr
 from sklearn.cluster import DBSCAN
 from sklearn.metrics import silhouette_score
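The new module docstring describes a two-step workflow: `score_json` computes the visualization data, `plot_score` renders it. A minimal sketch of that workflow, assuming the names are importable from `desdeo.tools.score_bands` (the diff shows the definitions in that file, not the package exports):

```python
import numpy as np
import polars as pl

# Assumed import path; the diff shows desdeo/tools/score_bands.py but not its exports.
from desdeo.tools.score_bands import KMeansOptions, SCOREBandsConfig, plot_score, score_json

rng = np.random.default_rng(seed=1)
# Each column is a dimension (objective), each row a solution.
data = pl.DataFrame({f"f{i + 1}": rng.normal(size=200) for i in range(4)})

# KMeans is chosen here because it always yields the requested number of clusters.
config = SCOREBandsConfig(clustering_algorithm=KMeansOptions(n_clusters=3))
result = score_json(data, config)  # clusters, axis order, bands, medians
fig = plot_score(data, result)     # Plotly figure built from the result
fig.show()
```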
@@ -20,32 +27,179 @@ from sklearn.preprocessing import StandardScaler
 from tsp_solver.greedy import solve_tsp


-def _gaussianmixtureclusteringwithBIC(data: pd.DataFrame):
-    data = StandardScaler().fit_transform(data)
+class GMMOptions(BaseModel):
+    """Options for the Gaussian Mixture Model clustering algorithm."""
+
+    model_config = ConfigDict(use_attribute_docstrings=True)
+
+    name: str = Field(default="GMM")
+    """Gaussian Mixture Model clustering algorithm."""
+    scoring_method: Literal["BIC", "silhouette"] = Field(default="silhouette")
+    """Scoring method to use for GMM. Either "BIC" or "silhouette". Defaults to "silhouette".
+    This option determines how the number of clusters is chosen."""
+
+
+class DBSCANOptions(BaseModel):
+    """Options for the DBSCAN clustering algorithm."""
+
+    model_config = ConfigDict(use_attribute_docstrings=True)
+
+    name: str = Field(default="DBSCAN")
+    """DBSCAN clustering algorithm."""
+
+
+class KMeansOptions(BaseModel):
+    """Options for the KMeans clustering algorithm."""
+
+    model_config = ConfigDict(use_attribute_docstrings=True)
+
+    name: str = Field(default="KMeans")
+    """KMeans clustering algorithm."""
+    n_clusters: int = Field(default=5)
+    """Number of clusters to use. Defaults to 5."""
+
+
+class DimensionClusterOptions(BaseModel):
+    """Options for clustering by one of the objectives/decision variables."""
+
+    model_config = ConfigDict(use_attribute_docstrings=True)
+
+    name: str = Field(default="DimensionCluster")
+    """Clustering by one of the dimensions."""
+    dimension_name: str
+    """Dimension to use for clustering."""
+    n_clusters: int = Field(default=5)
+    """Number of clusters to use. Defaults to 5."""
+    kind: Literal["EqualWidth", "EqualFrequency"] = Field(default="EqualWidth")
+    """Kind of clustering to use. Either "EqualWidth", which divides the dimension range into equal-width intervals,
+    or "EqualFrequency", which divides the dimension values into intervals with an equal number of solutions.
+    Defaults to "EqualWidth"."""
+
+
+class CustomClusterOptions(BaseModel):
+    """Options for custom clustering provided by the user."""
+
+    model_config = ConfigDict(use_attribute_docstrings=True)
+
+    name: str = Field(default="Custom")
+    """Custom user-provided clusters."""
+    clusters: list[int]
+    """List of cluster IDs (one for each solution) indicating the cluster to which each solution belongs."""
+
+
+ClusteringOptions = GMMOptions | DBSCANOptions | KMeansOptions | DimensionClusterOptions | CustomClusterOptions
+
+
+class DistanceFormula(int, Enum):
+    """Distance formulas supported by SCORE bands. See the paper for details."""
+
+    FORMULA_1 = 1
+    FORMULA_2 = 2
+
+
+class SCOREBandsConfig(BaseModel):
+    """Configuration options for the SCORE bands visualization."""
+
+    model_config = ConfigDict(use_attribute_docstrings=True)
+
+    dimensions: list[str] | None = Field(default=None)
+    """List of variable/objective names (i.e., column names in the data) to include in the visualization.
+    If None, all columns in the data are used. Defaults to None."""
+    descriptive_names: dict[str, str] | None = Field(default=None)
+    """Optional dictionary mapping dimensions to descriptive names for display in the visualization.
+    If None, the original dimension names are used. Defaults to None."""
+    units: dict[str, str] | None = Field(default=None)
+    """Optional dictionary mapping dimensions to their units for display in the visualization.
+    If None, no units are displayed. Defaults to None."""
+    axis_positions: dict[str, float] | None = Field(default=None)
+    """Dictionary mapping objective names to their positions on the axes in the SCORE bands visualization. The first
+    objective is at position 0.0, and the last objective is at position 1.0. Use this option if you want to
+    manually set the axis positions. If None, the axis positions are calculated automatically based on correlations.
+    Defaults to None."""
+    clustering_algorithm: ClusteringOptions = Field(
+        default=DBSCANOptions(),
+    )
+    """Clustering algorithm to use. Currently supported options: "GMM", "DBSCAN", "KMeans", "DimensionCluster",
+    and "Custom". Defaults to "DBSCAN"."""
+    distance_formula: DistanceFormula = Field(default=DistanceFormula.FORMULA_1)
+    """Distance formula to use. The value should be 1 or 2. Check the paper for details. Defaults to 1."""
+    distance_parameter: float = Field(default=0.05)
+    """Change the relative distances between the objective axes. Increase this value if objectives are placed too close
+    together. Decrease this value if the objectives are equidistant in a problem with objective clusters. Defaults
+    to 0.05."""
+    use_absolute_correlations: bool = Field(default=False)
+    """Whether to use the absolute value of the correlation to calculate the placement of axes. Defaults to False."""
+    include_solutions: bool = Field(default=False)
+    """Whether to include individual solutions. Defaults to False. If True, the size of the resulting figure may be
+    very large for datasets with many solutions. Moreover, the individual traces are hidden by default, but can be
+    viewed interactively in the figure."""
+    include_medians: bool = Field(default=False)
+    """Whether to include cluster medians. Defaults to False. If True, the median traces are hidden by default, but
+    can be viewed interactively in the figure."""
+    interval_size: float = Field(default=0.95)
+    """The size (as a fraction) of the interval to use for the bands. Defaults to 0.95, meaning that 95% of the
+    middle solutions in a cluster will be included in the band. The rest will be considered outliers."""
+    scales: dict[str, tuple[float, float]] | None = Field(default=None)
+    """Optional dictionary specifying the min and max values for each objective. The keys should be the
+    objective names (i.e., column names in the data), and the values should be tuples of (min, max).
+    If not provided, the min and max will be calculated from the data."""
+
+
+class SCOREBandsResult(BaseModel):
+    """Pydantic/JSON model for representing SCORE bands."""
+
+    model_config = ConfigDict(use_attribute_docstrings=True)
+
+    options: SCOREBandsConfig
+    """Configuration options used to generate the SCORE bands."""
+    ordered_dimensions: list[str]
+    """List of variable/objective names (i.e., column names in the data),
+    ordered according to their placement in the SCORE bands visualization."""
+    clusters: list[int]
+    """List of cluster IDs (one for each solution) indicating the cluster to which each solution belongs."""
+    axis_positions: dict[str, float]
+    """Dictionary mapping objective names to their positions on the axes in the SCORE bands visualization. The first
+    objective is at position 0.0, and the last objective is at position 1.0."""
+    bands: dict[int, dict[str, tuple[float, float]]]
+    """Dictionary mapping cluster IDs to dictionaries of objective names and their corresponding band
+    extremes (min, max)."""
+    medians: dict[int, dict[str, float]]
+    """Dictionary mapping cluster IDs to dictionaries of objective names and their corresponding median values."""
+    cardinalities: dict[int, int]
+    """Dictionary mapping cluster IDs to the number of solutions in each cluster."""
+
+
+def _gaussianmixtureclusteringwithBIC(data: pl.DataFrame) -> np.ndarray:
+    """Cluster the data using a Gaussian Mixture Model with BIC scoring."""
+    data_copy = data.to_numpy()
+    data_copy = StandardScaler().fit_transform(data_copy)
     lowest_bic = np.inf
     bic = []
-    n_components_range = range(1, min(11, len(data)))
-    cv_types = ["spherical", "tied", "diag", "full"]
+    n_components_range = range(1, min(11, len(data_copy)))
+    cv_types: list[Literal["full", "tied", "diag", "spherical"]] = ["spherical", "tied", "diag", "full"]
     for cv_type in cv_types:
         for n_components in n_components_range:
             # Fit a Gaussian mixture with EM
             gmm = GaussianMixture(n_components=n_components, covariance_type=cv_type)
-            gmm.fit(data)
-            bic.append(gmm.score(data))
+            gmm.fit(data_copy)
+            bic.append(gmm.score(data_copy))
             # bic.append(gmm.bic(data))
             if bic[-1] < lowest_bic:
                 lowest_bic = bic[-1]
                 best_gmm = gmm

-    return best_gmm.predict(data)
+    return best_gmm.predict(data_copy)


-def _gaussianmixtureclusteringwithsilhouette(data: pd.DataFrame):
-    X = StandardScaler().fit_transform(data)
+def _gaussianmixtureclusteringwithsilhouette(data: pl.DataFrame) -> np.ndarray:
+    """Cluster the data using a Gaussian Mixture Model with silhouette scoring."""
+    X = StandardScaler().fit_transform(data.to_numpy())
     best_score = -np.inf
-    best_labels = []
+    best_labels = np.ones(len(data))
     n_components_range = range(1, min(11, len(data)))
-    cv_types = ["spherical", "tied", "diag", "full"]
+    cv_types: list[Literal["full", "tied", "diag", "spherical"]] = ["spherical", "tied", "diag", "full"]
     for cv_type in cv_types:
         for n_components in n_components_range:
             # Fit a Gaussian mixture with EM
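The option classes above are plain pydantic models, so a configuration can be built programmatically and serialized to JSON — convenient for the web API work elsewhere in this release. A small sketch using only the fields defined above (the import path is an assumption, as before):

```python
from desdeo.tools.score_bands import (  # assumed import path
    CustomClusterOptions,
    DimensionClusterOptions,
    SCOREBandsConfig,
)

# Bin solutions into quartile-style groups along one column.
config = SCOREBandsConfig(
    clustering_algorithm=DimensionClusterOptions(
        dimension_name="f1",  # required field: the column to bin on
        n_clusters=4,
        kind="EqualFrequency",
    ),
    interval_size=0.9,  # bands cover the middle 90% of each cluster
)

# Or bypass clustering entirely with precomputed labels (one ID per solution).
manual = SCOREBandsConfig(clustering_algorithm=CustomClusterOptions(clusters=[1, 1, 2, 2]))

payload = config.model_dump_json()  # serialize, e.g. for a web frontend
```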
@@ -62,11 +216,12 @@ def _gaussianmixtureclusteringwithsilhouette(data: pd.DataFrame):
     return best_labels


-def _DBSCANClustering(data: pd.DataFrame):
-    X = StandardScaler().fit_transform(data)
+def _DBSCANClustering(data: pl.DataFrame) -> np.ndarray:
+    """Cluster the data using DBSCAN with silhouette scoring to choose eps."""
+    X = StandardScaler().fit_transform(data.to_numpy())
     eps_options = np.linspace(0.01, 1, 20)
     best_score = -np.inf
-    best_labels = [1] * len(X)
+    best_labels = np.ones(len(data))
     for eps_option in eps_options:
         db = DBSCAN(eps=eps_option, min_samples=10, metric="cosine").fit(X)
         core_samples_mask = np.zeros_like(db.labels_, dtype=bool)
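`_DBSCANClustering` picks `eps` by sweeping 20 candidate values and keeping the labeling with the best silhouette score. A standalone sketch of that selection pattern with scikit-learn, guarding the degenerate single-cluster case that `silhouette_score` rejects (the hunk cuts off before showing how the desdeo function handles it):

```python
import numpy as np
from sklearn.cluster import DBSCAN
from sklearn.metrics import silhouette_score
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(42)
X = StandardScaler().fit_transform(rng.normal(size=(300, 4)))

best_score, best_labels = -np.inf, np.ones(len(X))
for eps in np.linspace(0.01, 1, 20):
    labels = DBSCAN(eps=eps, min_samples=10, metric="cosine").fit(X).labels_
    # silhouette_score requires at least two distinct labels (noise is label -1).
    if len(set(labels)) < 2:
        continue
    score = silhouette_score(X, labels)
    if score > best_score:
        best_score, best_labels = score, labels
```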
@@ -83,250 +238,54 @@ def _DBSCANClustering(data: pd.DataFrame):
     return best_labels


-def cluster(data: pd.DataFrame, algorithm: str = "DBSCAN", score: str = "silhoutte"):
-    if not (score == "silhoutte" or score == "BIC"):
-        raise ValueError()
-    if not (algorithm == "GMM" or algorithm == "DBSCAN"):
-        raise ValueError()
-    if algorithm == "DBSCAN":
-        return _DBSCANClustering(data)
-    if score == "silhoutte":
-        return _gaussianmixtureclusteringwithsilhouette(data)
-    else:
-        return _gaussianmixtureclusteringwithBIC(data)
-
-
-def SCORE_bands(
-    data: pd.DataFrame,
-    axis_signs: np.ndarray = None,
-    color_groups: list | np.ndarray = None,
-    axis_positions: np.ndarray = None,
-    solutions: bool = True,
-    bands: bool = False,
-    medians: bool = False,
-    quantile: float = 0.25,
-) -> go.Figure:
-    """Generate SCORE bands figure from the provided data.
-
-    Args:
-        data (pd.DataFrame): Pandas dataframe where each column represents an objective and each row is an objective
-            vector. The column names are displayed as the objective names in the generated figure. Each element in the
-            dataframe must be numeric.
-
-        color_groups (Union[List, np.ndarray], optional): List or numpy array of the same length as the number of
-            objective vectors. The elements should be contiguous set of integers starting at 1. The element value represents
-            the Cluster ID of the corresponding objective vector. Defaults to None (though this behaviour is not fully
-            tested yet).
-
-        axis_positions (np.ndarray, optional): 1-D numpy array of the same length as the number of objectives. The value
-            represents the horizontal position of the corresponding objective axes. The value of the first and last element
-            should be 0 and 1 respectively, and all intermediate values should lie between 0 and 1.
-            Defaults to None, in which case all axes are positioned equidistant.
-
-        axis_signs (np.ndarray, optional): 1-D Numpy array of the same length as the number of objectives. Each element
-            can either be 1 or -1. A value of -1 flips the objective in the SCORE bands visualization. This feature is
-            experimental and should be ignored for now. Defaults to None.
-
-        solutions (bool, optional): Show or hide individual solutions. Defaults to True.
-
-        bands (bool, optional): Show or hide cluster bands. Defaults to False.
-
-        medians (bool, optional): Show or hide cluster medians. Defaults to False.
-
-        quantile (float, optional): The quantile value to calculate the band. The band represents the range between
-            (quantile) and (1 - quantile) quantiles of the objective values. Defaults to 0.25.
-
-    Returns:
-        go.Figure: SCORE bands plot.
-
-    """
-    # show on render
-    show_solutions = "legendonly"
-    bands_visible = True
-    if bands:
-        show_medians = "legendonly"
-    if medians:
-        show_medians = True
-    # pio.templates.default = "simple_white"
-    column_names = data.columns
-    num_columns = len(column_names)
-    if axis_positions is None:
-        axis_positions = np.linspace(0, 1, num_columns)
-    if axis_signs is None:
-        axis_signs = np.ones_like(axis_positions)
-    if color_groups is None:
-        color_groups = "continuous"
-        colorscale = cm.get_cmap("viridis")
-    elif isinstance(color_groups, (np.ndarray, list)):
-        groups = list(np.unique(color_groups))
-        if len(groups) <= 8:
-            colorscale = cm.get_cmap("Accent", len(groups))
-            # print(len(groups))
-            # print("hi!")
-        else:
-            colorscale = cm.get_cmap("tab20", len(groups))
-            # colorscale = cm.get_cmap("viridis_r", len(groups))
-    data = data * axis_signs
-    num_labels = 6
-
-    # Scaling the objective values between 0 and 1.
-    scaled_data = data - data.min(axis=0)
-    scaled_data = scaled_data / scaled_data.max(axis=0)
-    scales = pd.DataFrame([data.min(axis=0), data.max(axis=0)], index=["min", "max"]) * axis_signs
+def cluster_by_dimension(data: pl.DataFrame, options: DimensionClusterOptions) -> np.ndarray:
+    """Cluster the data by a specific dimension."""
+    if options.dimension_name not in data.columns:
+        raise ValueError(f"Objective '{options.dimension_name}' not found in data.")

-    fig = go.Figure()
-    fig.update_xaxes(showticklabels=False, showgrid=False, zeroline=False)
-    fig.update_yaxes(showticklabels=False, showgrid=False, zeroline=False)
-    fig.update_layout(plot_bgcolor="rgba(0,0,0,0)")
-
-    scaled_data.insert(0, "group", value=color_groups)
-    for cluster_id, solns in scaled_data.groupby("group"):
-        # TODO: Many things here are very inefficient. Improve when free.
-        num_solns = len(solns)
-
-        r, g, b, a = colorscale(cluster_id - 1)  # Needed as cluster numbering starts at 1
-        a = 0.6
-        a_soln = 0.6
-        color_bands = f"rgba({r}, {g}, {b}, {a})"
-        color_soln = f"rgba({r}, {g}, {b}, {a_soln})"
-
-        low = solns.drop("group", axis=1).quantile(quantile)
-        high = solns.drop("group", axis=1).quantile(1 - quantile)
-        median = solns.drop("group", axis=1).median()
-
-        if bands is True:
-            # lower bound of the band
-            fig.add_scatter(
-                x=axis_positions,
-                y=low,
-                line={"color": color_bands},
-                name=f"{int(100 - 200 * quantile)}% band: Cluster {cluster_id}; {num_solns} Solutions ",
-                mode="lines",
-                legendgroup=f"{int(100 - 200 * quantile)}% band: Cluster {cluster_id}",
-                showlegend=True,
-                line_shape="spline",
-                hovertext=f"Cluster {cluster_id}",
-                visible=bands_visible,
-            )
-            # upper bound of the band
-            fig.add_scatter(
-                x=axis_positions,
-                y=high,
-                line={"color": color_bands},
-                name=f"Cluster {cluster_id}",
-                fillcolor=color_bands,
-                mode="lines",
-                legendgroup=f"{int(100 - 200 * quantile)}% band: Cluster {cluster_id}",
-                showlegend=False,
-                line_shape="spline",
-                fill="tonexty",
-                hovertext=f"Cluster {cluster_id}",
-                visible=bands_visible,
-            )
-        if medians is True:
-            # median
-            fig.add_scatter(
-                x=axis_positions,
-                y=median,
-                line={"color": color_bands},
-                name=f"Median: Cluster {cluster_id}",
-                mode="lines+markers",
-                marker={"line": {"color": "Black", "width": 2}},
-                legendgroup=f"Median: Cluster {cluster_id}",
-                showlegend=True,
-                visible=show_medians,
-            )
-        if solutions is True:
-            # individual solutions
-            legend = True
-            for _, soln in solns.drop("group", axis=1).iterrows():
-                fig.add_scatter(
-                    x=axis_positions,
-                    y=soln,
-                    line={"color": color_soln},
-                    name=f"Solutions: Cluster {cluster_id} ",
-                    legendgroup=f"Solutions: Cluster {cluster_id}",
-                    showlegend=legend,
-                    visible=show_solutions,
-                )
-                legend = False
-    # Axis lines
-    for i, col_name in enumerate(column_names):
-        # better = "Upper" if axis_signs[i] == -1 else "Lower"
-        label_text = np.linspace(scales[col_name]["min"], scales[col_name]["max"], num_labels)
-        # label_text = ["{:.3g}".format(i) for i in label_text]
-        heights = np.linspace(0, 1, num_labels)
-        scale_factors = []
-        for current_label in label_text:
-            try:
-                with np.errstate(divide="ignore"):
-                    scale_factors.append(int(np.floor(np.log10(np.abs(current_label)))))
-            except OverflowError:
-                pass
-
-        scale_factor = int(np.median(scale_factors))
-        if scale_factor == -1 or scale_factor == 1:
-            scale_factor = 0
-
-        # TODO: This sometimes doesn't generate the correct label text. Check with datasets where objs lie between (0,1).
-        label_text = label_text / 10 ** (scale_factor)
-        label_text = ["{:.1f}".format(i) for i in label_text]
-        scale_factor_text = f"e{scale_factor}" if scale_factor != 0 else ""
-
-        # Bottom axis label
-        fig.add_scatter(
-            x=[axis_positions[i]],
-            y=[heights[0]],
-            text=[label_text[0] + scale_factor_text],
-            textposition="bottom center",
-            mode="text",
-            line={"color": "black"},
-            showlegend=False,
-        )
-        # Top axis label
-        fig.add_scatter(
-            x=[axis_positions[i]],
-            y=[heights[-1]],
-            text=[label_text[-1] + scale_factor_text],
-            textposition="top center",
-            mode="text",
-            line={"color": "black"},
-            showlegend=False,
-        )
-        label_text[0] = ""
-        label_text[-1] = ""
-        # Intermediate axes labels
-        fig.add_scatter(
-            x=[axis_positions[i]] * num_labels,
-            y=heights,
-            text=label_text,
-            textposition="middle left",
-            mode="markers+lines+text",
-            line={"color": "black"},
-            showlegend=False,
-        )
+    # Select the dimension column for clustering
+    dimension = data[options.dimension_name]

-        fig.add_scatter(
-            x=[axis_positions[i]],
-            y=[1.10],
-            text=f"{col_name}",
-            textfont={"size": 20},
-            mode="text",
-            showlegend=False,
+    # Perform clustering based on the specified method
+    if options.kind == "EqualWidth":
+        min_val: float = dimension.min()
+        max_val: float = dimension.max()
+        SMALL_VALUE = 1e-8
+        thresholds = np.linspace(
+            min_val * (1 - SMALL_VALUE),  # Ensure the minimum value is included in the first cluster
+            max_val * (1 + SMALL_VALUE),  # Ensure the maximum value is included in the last cluster
+            options.n_clusters + 1,
         )
-        """fig.add_scatter(
-            x=[axis_positions[i]], y=[1.1], text=better, mode="text", showlegend=False,
-        )
-        fig.add_scatter(
-            x=[axis_positions[i]],
-            y=[1.05],
-            text="is better",
-            mode="text",
-            showlegend=False,
-        )"""
-    fig.update_layout(font_size=18)
-    fig.update_layout(legend={"orientation": "h", "yanchor": "top", "font": {"size": 24}})
-    return fig
+        return np.digitize(dimension.to_numpy(), thresholds)  # Cluster IDs start at 1
+    elif options.kind == "EqualFrequency":
+        levels: list[float] = [dimension.quantile(i / options.n_clusters) for i in range(1, options.n_clusters)]
+        thresholds = [-np.inf] + levels + [np.inf]
+        return np.digitize(dimension.to_numpy(), thresholds)  # Cluster IDs start at 1
+    raise ValueError(f"Unknown clustering kind: {options.kind}")
+
+
+def cluster(data: pl.DataFrame, options: ClusteringOptions) -> np.ndarray:
+    """Cluster the data using the specified clustering algorithm and options."""
+    if isinstance(options, DimensionClusterOptions):
+        return cluster_by_dimension(data, options)
+    if isinstance(options, KMeansOptions):
+        from sklearn.cluster import KMeans
+
+        X = StandardScaler().fit_transform(data.to_numpy())
+        kmeans = KMeans(n_clusters=options.n_clusters, random_state=0).fit(X)
+        return kmeans.labels_
+    if isinstance(options, DBSCANOptions):
+        return _DBSCANClustering(data)
+    if isinstance(options, GMMOptions):
+        if options.scoring_method == "silhouette":
+            return _gaussianmixtureclusteringwithsilhouette(data)
+        if options.scoring_method == "BIC":
+            return _gaussianmixtureclusteringwithBIC(data)
+    if isinstance(options, CustomClusterOptions):
+        if len(options.clusters) != len(data):
+            raise ValueError("Length of custom clusters must match number of solutions in data.")
+        return np.array(options.clusters)
+    raise ValueError(f"Unknown clustering algorithm: {options}")


 def annotated_heatmap(correlation_matrix: np.ndarray, col_names: list, order: list | np.ndarray) -> go.Figure:
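The new `cluster_by_dimension` reduces to `np.digitize` against a threshold vector: equal-width thresholds come from `np.linspace` over a slightly widened value range, equal-frequency thresholds from quantiles padded with ±inf. A worked sketch of the two threshold constructions in plain numpy, mirroring the logic above outside the polars wrapper:

```python
import numpy as np

values = np.array([0.0, 1.0, 2.0, 3.0, 4.0, 10.0])
n_clusters = 3
eps = 1e-8  # widen the range so min/max fall inside the outer bins

# EqualWidth: bin edges evenly spaced over the (widened) value range.
edges = np.linspace(values.min() * (1 - eps), values.max() * (1 + eps), n_clusters + 1)
print(np.digitize(values, edges))  # [1 1 1 1 2 3] -- cluster IDs start at 1

# EqualFrequency: interior edges at quantiles, padded with +/- inf.
quantiles = [np.quantile(values, i / n_clusters) for i in range(1, n_clusters)]
edges = np.array([-np.inf, *quantiles, np.inf])
print(np.digitize(values, edges))  # [1 1 2 2 3 3] -- equal counts per bin
```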
@@ -340,7 +299,7 @@ def annotated_heatmap(correlation_matrix: np.ndarray, col_names: list, order: li
     Returns:
         go.Figure: The heatmap
     """  # noqa: D212, D213, D406, D407
-    corr = pd.DataFrame(correlation_matrix, index=col_names, columns=col_names)
+    corr = pl.DataFrame(correlation_matrix, index=col_names, columns=col_names)
     corr = corr[col_names[order]].loc[col_names[order[::-1]]]
     corr = np.rint(corr * 100) / 100  # Take upto two significant figures only to make heatmap readable.
     fig = ff.create_annotated_heatmap(
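The reordering line above selects columns in axis order and rows in reverse order, so the heatmap's diagonal follows the axis placement. The same permutation expressed in plain numpy (an equivalent sketch; the polars port of this indexing is only partially visible in the hunk):

```python
import numpy as np

corr = np.array([
    [1.0, 0.8, -0.3],
    [0.8, 1.0, 0.1],
    [-0.3, 0.1, 1.0],
])
order = np.array([2, 0, 1])  # axis order produced by the ordering step

# Rows in reversed axis order, columns in axis order (numpy fancy indexing).
reordered = corr[np.ix_(order[::-1], order)]
reordered = np.rint(reordered * 100) / 100  # two decimals for readable annotations
```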
@@ -353,13 +312,13 @@ def annotated_heatmap(correlation_matrix: np.ndarray, col_names: list, order: li
     return fig


-def order_objectives(data: pd.DataFrame, use_absolute_corr: bool = False):
+def order_dimensions(data: pl.DataFrame, use_absolute_corr: bool = False):
     """Calculate the order of objectives.

     Also returns the correlation matrix.

     Args:
-        data (pd.DataFrame): Data to be visualized.
+        data (pl.DataFrame): Data to be visualized.
         use_absolute_corr (bool, optional): Use absolute value of the correlation to calculate order. Defaults to False.

     Returns:
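`order_dimensions` (renamed from `order_objectives`) returns the correlation matrix together with an axis order. Its body is not shown in this hunk, but the module imports (`pearsonr`, `solve_tsp`) indicate the order comes from a TSP tour over correlation-based distances, as in the SCORE bands paper. A hedged sketch of that idea — `order_by_correlation` is a hypothetical stand-in, not the desdeo implementation:

```python
import numpy as np
from scipy.stats import pearsonr
from tsp_solver.greedy import solve_tsp


def order_by_correlation(data: np.ndarray) -> tuple[np.ndarray, list[int]]:
    """Order columns so that highly correlated ones sit next to each other (sketch)."""
    n = data.shape[1]
    corr = np.ones((n, n))
    for i in range(n):
        for j in range(i + 1, n):
            corr[i, j] = corr[j, i] = pearsonr(data[:, i], data[:, j])[0]
    # High correlation == short distance; solve_tsp returns a visiting order.
    order = solve_tsp(1 - corr)
    return corr, order
```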
@@ -381,84 +340,288 @@ def order_objectives(data: pd.DataFrame, use_absolute_corr: bool = False):
     return corr, obj_order


-def calculate_axes_positions(data, obj_order, corr, dist_parameter, distance_formula: int = 1):
+def calculate_axes_positions(
+    dimension_order: list[int],
+    corr: np.ndarray,
+    dist_parameter: float,
+    distance_formula: DistanceFormula = DistanceFormula.FORMULA_1,
+) -> np.ndarray:
+    """Calculate the position of the axes for the SCORE bands visualization based on correlations.
+
+    Args:
+        dimension_order (list[int]): Order of the variables to be plotted.
+        corr (np.ndarray): Correlation (pearson) matrix.
+        dist_parameter (float): Change the relative distances between the axes. Increase this value if the axes are
+            placed too close together. Decrease this value if the axes are equidistant.
+        distance_formula (DistanceFormula, optional): The value should be 1 or 2. Check the paper for details.
+            Defaults to DistanceFormula.FORMULA_1.
+
+    Returns:
+        np.ndarray: Positions of the axes in the range [0, 1].
+    """
     # axes positions
-    order = np.asarray(list((zip(obj_order[:-1], obj_order[1:]))))
+    order = np.asarray(list(zip(dimension_order[:-1], dimension_order[1:], strict=True)))
     axis_len = corr[order[:, 0], order[:, 1]]
-    if distance_formula == 1:
-        axis_len = 1 - axis_len  # TODO Make this formula available to the user
-    elif distance_formula == 2:
+    if distance_formula == DistanceFormula.FORMULA_1:
+        axis_len = 1 - axis_len
+    elif distance_formula == DistanceFormula.FORMULA_2:
         axis_len = 1 / (np.abs(axis_len) + 1)  # Reciprocal for reverse
     else:
+        # Should never reach here
         raise ValueError("distance_formula should be either 1 or 2 (int)")
-    # axis_len = np.abs(axis_len)
-    # axis_len = axis_len / sum(axis_len) #TODO Changed
-    axis_len = axis_len + dist_parameter  # Minimum distance between axes
+    axis_len = axis_len + dist_parameter
     axis_len = axis_len / sum(axis_len)
-    axis_dist = np.cumsum(np.append(0, axis_len))
-    # Axis signs (normalizing negative correlations)
-    axis_signs = np.cumprod(np.sign(np.hstack((1, corr[order[:, 0], order[:, 1]]))))
-    return data.iloc[:, obj_order], axis_dist, axis_signs
-
-
-def auto_SCORE(
-    data: pd.DataFrame,
-    solutions: bool = True,
-    bands: bool = True,
-    medians: bool = False,
-    dist_parameter: float = 0.05,
-    use_absolute_corr: bool = False,
-    distance_formula: int = 1,
-    flip_axes: bool = False,
-    clustering_algorithm: str = "DBSCAN",
-    clustering_score: str = "silhoutte",
-    quantile: float = 0.05,
-):
-    """Generate the SCORE Bands visualization for a dataset with predefined values for the hyperparameters.
+    return np.cumsum(np.append(0, axis_len))
+
+
+def score_json(
+    data: pl.DataFrame,
+    options: SCOREBandsConfig,
+) -> SCOREBandsResult:
+    """Generate the SCORE Bands data for a given dataset and configuration options.

     Args:
-        data (pd.DataFrame): Dataframe of objective values. The column names should be the objective names. Each row
-            should be an objective vector.
+        data (pl.DataFrame): Dataframe of variable (decision or objective) values.
+            The column names should be the names of the variables to be plotted. Each row should be a solution.

-        solutions (bool, optional): Show or hide individual solutions. Defaults to True.
-        bands (bool, optional): Show or hide the cluster bands. Defaults to True.
-        medians (bool, optional): Show or hide the cluster medians. Defaults to False.
-        dist_parameter (float, optional): Change the relative distances between the objective axes. Increase this value
-            if objectives are placed too close together. Decrease this value if the objectives are equidistant in a problem
-            with objective clusters. Defaults to 0.05.
-        use_absolute_corr (bool, optional): Use absolute value of the correlation to calculate the placement of axes.
-            Defaults to False.
-        distance_formula (int, optional): The value should be 1 or 2. Check the paper for details. Defaults to 1.
-        flip_axes (bool, optional): Do not use this option. Defaults to False.
-        clustering_algorithm (str, optional): Currently supported options: "GMM" and "DBSCAN". Defaults to "DBSCAN".
-        clustering_score (str, optional): If "GMM" is chosen for clustering algorithm, the scoring mechanism can be
-            either "silhoutte" or "BIC". Defaults to "silhoutte".
+        options (SCOREBandsConfig): Configuration options for generating the SCORE bands.

     Returns:
-        _type_: _description_
+        SCOREBandsResult: The result containing all relevant data for the SCORE bands visualization.
     """
+    options = deepcopy(options)
     # Calculating correlations and axes positions
-    corr, obj_order = order_objectives(data, use_absolute_corr=use_absolute_corr)
-
-    ordered_data, axis_dist, axis_signs = calculate_axes_positions(
-        data,
-        obj_order,
-        corr,
-        dist_parameter=dist_parameter,
-        distance_formula=distance_formula,
-    )
-    if not flip_axes:
-        axis_signs = None
-    groups = cluster(ordered_data, algorithm=clustering_algorithm, score=clustering_score)
-    groups = groups - np.min(groups) + 1  # translate minimum to 1.
-    fig1 = SCORE_bands(
-        ordered_data,
-        color_groups=groups,
-        axis_positions=axis_dist,
-        axis_signs=axis_signs,
-        solutions=solutions,
-        bands=bands,
-        medians=medians,
-        quantile=0.05,
+    if options.dimensions is None:
+        options.dimensions = data.columns
+    data_copy = data.select([pl.col(col) for col in options.dimensions])
+
+    if options.axis_positions is None:
+        corr, dimension_order = order_dimensions(data_copy, use_absolute_corr=options.use_absolute_correlations)
+
+        axis_dist = calculate_axes_positions(
+            dimension_order,
+            corr,
+            dist_parameter=options.distance_parameter,
+            distance_formula=options.distance_formula,
+        )
+
+        ordered_dimension_names = [data_copy.columns[i] for i in dimension_order]
+        axis_positions = {name: axis_dist[i] for i, name in enumerate(ordered_dimension_names)}
+    else:
+        axis_positions = options.axis_positions
+        ordered_dimension_names = sorted(axis_positions.keys(), key=axis_positions.get)
+
+    clusters = cluster(data_copy, options.clustering_algorithm)
+
+    if min(clusters) <= 0:
+        clusters = clusters - np.min(clusters) + 1  # translate minimum to 1.
+
+    # Sanity check: all cluster IDs should be contiguous integers starting at 1, ending at the number of clusters
+    unique_clusters = np.unique(clusters)
+    max_cluster_id = max(clusters)
+    if not all(i in unique_clusters for i in range(1, max_cluster_id + 1)):
+        warn(
+            """Cluster IDs are not contiguous integers starting at 1.
+            This may cause issues with the color mapping in the visualization.""",
+            category=UserWarning,
+            stacklevel=2,
+        )
+
+    cluster_column_name = "cluster"
+    if cluster_column_name in data_copy.columns:
+        cluster_column_name = "cluster_id"
+
+    data_copy = data_copy.with_columns(pl.Series(cluster_column_name, clusters))
+    grouped = data_copy.group_by(cluster_column_name)
+    min_percentile = (1 - options.interval_size) / 2
+    max_percentile = 1 - min_percentile
+    mins = grouped.quantile(min_percentile)
+    maxs = grouped.quantile(max_percentile)
+    medians = grouped.median()
+    frequencies = grouped.len()
+    bands_dict = {
+        cluster_id: {
+            col_name: (
+                mins.filter(pl.col(cluster_column_name) == cluster_id)[col_name][0],
+                maxs.filter(pl.col(cluster_column_name) == cluster_id)[col_name][0],
+            )
+            for col_name in ordered_dimension_names
+        }
+        for cluster_id in mins[cluster_column_name].to_list()
+    }
+    medians_dict = {
+        cluster_id: {
+            col_name: medians.filter(pl.col(cluster_column_name) == cluster_id)[col_name][0]
+            for col_name in ordered_dimension_names
+        }
+        for cluster_id in medians[cluster_column_name].to_list()
+    }
+    frequencies_dict = {
+        cluster_id: frequencies.filter(pl.col(cluster_column_name) == cluster_id)["len"][0]
+        for cluster_id in frequencies[cluster_column_name].to_list()
+    }
+
+    if options.scales is None:
+        scales: dict[str, tuple[float, float]] = {
+            dimension: (data_copy[dimension].min(), data_copy[dimension].max()) for dimension in ordered_dimension_names
+        }
+        options.scales = scales
+    return SCOREBandsResult(
+        options=options,
+        ordered_dimensions=ordered_dimension_names,
+        clusters=clusters.tolist(),
+        axis_positions=axis_positions,
+        bands=bands_dict,
+        medians=medians_dict,
+        cardinalities=frequencies_dict,
     )
-    return fig1, corr, obj_order, groups, axis_dist
+
+
+def plot_score(data: pl.DataFrame, result: SCOREBandsResult) -> go.Figure:
+    """Generate the SCORE Bands figure from the SCOREBandsResult data.
+
+    Args:
+        data (pl.DataFrame): Dataframe of objective values. The column names should be the objective names. Each row
+            should be an objective vector.
+        result (SCOREBandsResult): The result containing all relevant data for the SCORE bands visualization.
+
+    Returns:
+        go.Figure: The SCORE bands plot.
+    """
+    column_names = result.ordered_dimensions
+
+    clusters = np.sort(np.unique(result.clusters))
+
+    if len(clusters) <= 8:
+        colorscale = cm.get_cmap("Accent", len(clusters))
+    else:
+        colorscale = cm.get_cmap("tab20", len(clusters))
+
+    if result.options.scales is None:
+        raise ValueError("Scales must be provided in the SCOREBandsResult to plot the figure.")
+
+    scale_min = pl.DataFrame({name: result.options.scales[name][0] for name in result.options.scales})
+    scale_max = pl.DataFrame({name: result.options.scales[name][1] for name in result.options.scales})
+
+    scaled_data = (data[column_names] - scale_min) / (scale_max - scale_min)
+
+    fig = go.Figure()
+    fig.update_xaxes(showticklabels=False, showgrid=False, zeroline=False)
+    fig.update_yaxes(showticklabels=False, showgrid=False, zeroline=False)
+    fig.update_layout(plot_bgcolor="rgba(0,0,0,0)")
+
+    cluster_column_name = "cluster"
+    if cluster_column_name in scaled_data.columns:
+        cluster_column_name = "cluster_id"
+    scaled_data = scaled_data.with_columns(pl.Series(cluster_column_name, result.clusters))
+
+    if result.options.descriptive_names is None:
+        descriptive_names = {name: name for name in column_names}
+    else:
+        descriptive_names = result.options.descriptive_names
+    if result.options.units is None:
+        units = {name: "" for name in column_names}
+    else:
+        units = result.options.units
+
+    num_ticks = 6
+    # Add axes
+    for i, col_name in enumerate(column_names):
+        label_text = np.linspace(result.options.scales[col_name][0], result.options.scales[col_name][1], num_ticks)
+        label_text = ["{:.5g}".format(i) for i in label_text]
+        # label_text[0] = "<<"
+        # label_text[-1] = ">>"
+        heights = np.linspace(0, 1, num_ticks)
+        # Axis lines
+        fig.add_scatter(
+            x=[result.axis_positions[col_name]] * num_ticks,
+            y=heights,
+            text=label_text,
+            textposition="middle left",
+            mode="markers+lines+text",
+            line={"color": "black"},
+            showlegend=False,
+        )
+        # Column name
+        fig.add_scatter(
+            x=[result.axis_positions[col_name]],
+            y=[1.20],
+            text=f"{descriptive_names[col_name]}",
+            textfont={"size": 20},
+            mode="text",
+            showlegend=False,
+        )
+        # Units
+        fig.add_scatter(
+            x=[result.axis_positions[col_name]],
+            y=[1.10],
+            text=f"{units[col_name]}",
+            textfont={"size": 12},
+            mode="text",
+            showlegend=False,
+        )
+    # Add bands
+    for cluster_id in sorted(result.bands.keys()):
+        r, g, b, a = colorscale(cluster_id - 1)  # Needed as cluster numbering starts at 1
+        a = 0.6
+        color_bands = f"rgba({r}, {g}, {b}, {a})"
+        color_soln = f"rgba({r}, {g}, {b}, {a})"
+
+        lows = [
+            (result.bands[cluster_id][col_name][0] - result.options.scales[col_name][0])
+            / (result.options.scales[col_name][1] - result.options.scales[col_name][0])
+            for col_name in column_names
+        ]
+        highs = [
+            (result.bands[cluster_id][col_name][1] - result.options.scales[col_name][0])
+            / (result.options.scales[col_name][1] - result.options.scales[col_name][0])
+            for col_name in column_names
+        ]
+        medians = [
+            (result.medians[cluster_id][col_name] - result.options.scales[col_name][0])
+            / (result.options.scales[col_name][1] - result.options.scales[col_name][0])
+            for col_name in column_names
+        ]
+
+        fig.add_scatter(
+            x=[result.axis_positions[col_name] for col_name in column_names],
+            y=lows,
+            line={"color": color_bands},
+            name=f"{int(100 * result.options.interval_size)}% band: Cluster {cluster_id}; "
+            f"{result.cardinalities[cluster_id]} Solutions ",
+            mode="lines",
+            legendgroup=f"{int(100 * result.options.interval_size)}% band: Cluster {cluster_id}",
+            showlegend=True,
+            line_shape="spline",
+            hovertext=f"Cluster {cluster_id}",
+        )
+        # upper bound of the band
+        fig.add_scatter(
+            x=[result.axis_positions[col_name] for col_name in column_names],
+            y=highs,
+            line={"color": color_bands},
+            name=f"Cluster {cluster_id}",
+            fillcolor=color_bands,
+            mode="lines",
+            legendgroup=f"{int(100 * result.options.interval_size)}% band: Cluster {cluster_id}",
+            showlegend=False,
+            line_shape="spline",
+            fill="tonexty",
+            hovertext=f"Cluster {cluster_id}",
+        )
+
+        if result.options.include_medians:
+            # median
+            fig.add_scatter(
+                x=[result.axis_positions[col_name] for col_name in column_names],
+                y=medians,
+                line={"color": color_bands},
+                name=f"Median: Cluster {cluster_id}",
+                mode="lines+markers",
+                marker={"line": {"color": "Black", "width": 2}},
+                legendgroup=f"Median: Cluster {cluster_id}",
+                showlegend=True,
+            )
+    fig.update_layout(font_size=18)
+    fig.update_layout(legend={"orientation": "h", "yanchor": "top"})
+    return fig
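To make the axis-position arithmetic in `calculate_axes_positions` concrete: with `FORMULA_1`, each gap between neighbouring axes is `1 - corr`, padded by `dist_parameter` and normalized, so strongly correlated neighbours end up close together. A worked check with three axes (the correlation values are hypothetical):

```python
import numpy as np

# Correlations between consecutive axes in plot order.
neighbour_corr = np.array([0.9, 0.1])
dist_parameter = 0.05

gaps = 1 - neighbour_corr      # FORMULA_1: [0.1, 0.9]
gaps = gaps + dist_parameter   # minimum gap: [0.15, 0.95]
gaps = gaps / gaps.sum()       # normalize:   [0.1364, 0.8636]
positions = np.cumsum(np.append(0, gaps))
print(positions)               # [0.0, 0.1364, 1.0] -- first axis at 0, last at 1
```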