workbench 0.8.219__py3-none-any.whl → 0.8.231__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- workbench/__init__.py +1 -0
- workbench/algorithms/dataframe/__init__.py +2 -0
- workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
- workbench/algorithms/dataframe/fingerprint_proximity.py +190 -31
- workbench/algorithms/dataframe/projection_2d.py +8 -2
- workbench/algorithms/dataframe/proximity.py +3 -0
- workbench/algorithms/dataframe/smart_aggregator.py +161 -0
- workbench/algorithms/sql/column_stats.py +0 -1
- workbench/algorithms/sql/correlations.py +0 -1
- workbench/algorithms/sql/descriptive_stats.py +0 -1
- workbench/api/feature_set.py +0 -1
- workbench/api/meta.py +0 -1
- workbench/cached/cached_meta.py +0 -1
- workbench/cached/cached_model.py +37 -7
- workbench/core/artifacts/endpoint_core.py +12 -2
- workbench/core/artifacts/feature_set_core.py +238 -225
- workbench/core/cloud_platform/cloud_meta.py +0 -1
- workbench/core/transforms/features_to_model/features_to_model.py +2 -8
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +2 -0
- workbench/model_script_utils/model_script_utils.py +30 -0
- workbench/model_script_utils/uq_harness.py +0 -1
- workbench/model_scripts/chemprop/chemprop.template +196 -68
- workbench/model_scripts/chemprop/generated_model_script.py +197 -72
- workbench/model_scripts/chemprop/model_script_utils.py +30 -0
- workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +0 -1
- workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +0 -1
- workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +0 -1
- workbench/model_scripts/pytorch_model/generated_model_script.py +52 -34
- workbench/model_scripts/pytorch_model/model_script_utils.py +30 -0
- workbench/model_scripts/pytorch_model/pytorch.template +47 -29
- workbench/model_scripts/pytorch_model/uq_harness.py +0 -1
- workbench/model_scripts/script_generation.py +0 -1
- workbench/model_scripts/xgb_model/generated_model_script.py +3 -3
- workbench/model_scripts/xgb_model/model_script_utils.py +30 -0
- workbench/model_scripts/xgb_model/uq_harness.py +0 -1
- workbench/scripts/ml_pipeline_sqs.py +71 -2
- workbench/themes/dark/custom.css +85 -8
- workbench/themes/dark/plotly.json +6 -6
- workbench/themes/light/custom.css +172 -64
- workbench/themes/light/plotly.json +9 -9
- workbench/themes/midnight_blue/custom.css +82 -29
- workbench/themes/midnight_blue/plotly.json +1 -1
- workbench/utils/aws_utils.py +0 -1
- workbench/utils/chem_utils/mol_descriptors.py +0 -1
- workbench/utils/chem_utils/projections.py +16 -6
- workbench/utils/chem_utils/vis.py +137 -27
- workbench/utils/clientside_callbacks.py +41 -0
- workbench/utils/markdown_utils.py +57 -0
- workbench/utils/model_utils.py +0 -1
- workbench/utils/pipeline_utils.py +0 -1
- workbench/utils/plot_utils.py +52 -36
- workbench/utils/theme_manager.py +95 -30
- workbench/web_interface/components/experiments/outlier_plot.py +0 -1
- workbench/web_interface/components/model_plot.py +2 -0
- workbench/web_interface/components/plugin_unit_test.py +0 -1
- workbench/web_interface/components/plugins/ag_table.py +2 -4
- workbench/web_interface/components/plugins/confusion_matrix.py +3 -6
- workbench/web_interface/components/plugins/model_details.py +10 -6
- workbench/web_interface/components/plugins/scatter_plot.py +184 -85
- workbench/web_interface/components/settings_menu.py +185 -0
- workbench/web_interface/page_views/main_page.py +0 -1
- {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/METADATA +34 -41
- {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/RECORD +67 -69
- {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/WHEEL +1 -1
- workbench/themes/quartz/base_css.url +0 -1
- workbench/themes/quartz/custom.css +0 -117
- workbench/themes/quartz/plotly.json +0 -642
- workbench/themes/quartz_dark/base_css.url +0 -1
- workbench/themes/quartz_dark/custom.css +0 -131
- workbench/themes/quartz_dark/plotly.json +0 -642
- {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/entry_points.txt +0 -0
- {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/licenses/LICENSE +0 -0
- {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/top_level.txt +0 -0
|
@@ -2,32 +2,122 @@
|
|
|
2
2
|
|
|
3
3
|
import logging
|
|
4
4
|
import base64
|
|
5
|
-
import re
|
|
6
5
|
from typing import Optional, Tuple
|
|
7
6
|
from rdkit import Chem
|
|
8
7
|
from rdkit.Chem import AllChem, Draw
|
|
9
8
|
from rdkit.Chem.Draw import rdMolDraw2D
|
|
9
|
+
from dash import html
|
|
10
|
+
|
|
11
|
+
# Workbench Imports
|
|
12
|
+
from workbench.utils.color_utils import is_dark
|
|
10
13
|
|
|
11
14
|
# Set up the logger
|
|
12
15
|
log = logging.getLogger("workbench")
|
|
13
16
|
|
|
14
17
|
|
|
15
|
-
def
|
|
16
|
-
|
|
18
|
+
def molecule_hover_tooltip(
|
|
19
|
+
smiles: str, mol_id: str = None, width: int = 300, height: int = 200, background: str = None
|
|
20
|
+
) -> list:
|
|
21
|
+
"""Generate a molecule hover tooltip from a SMILES string.
|
|
22
|
+
|
|
23
|
+
This function creates a visually appealing tooltip with a dark background
|
|
24
|
+
that displays the molecule ID at the top and structure below when hovering
|
|
25
|
+
over scatter plot points.
|
|
17
26
|
|
|
18
27
|
Args:
|
|
19
|
-
|
|
28
|
+
smiles: SMILES string representing the molecule
|
|
29
|
+
mol_id: Optional molecule ID to display at the top of the tooltip
|
|
30
|
+
width: Width of the molecule image in pixels (default: 300)
|
|
31
|
+
height: Height of the molecule image in pixels (default: 200)
|
|
32
|
+
background: Optional background color (if None, uses dark gray)
|
|
20
33
|
|
|
21
34
|
Returns:
|
|
22
|
-
|
|
35
|
+
list: A list containing an html.Div with the ID header and molecule SVG,
|
|
36
|
+
or an html.Div with an error message if rendering fails
|
|
23
37
|
"""
|
|
24
|
-
|
|
25
|
-
if not match:
|
|
26
|
-
log.warning(f"Invalid color format: {color}, defaulting to dark")
|
|
27
|
-
return True # Default to dark mode on error
|
|
38
|
+
try:
|
|
28
39
|
|
|
29
|
-
|
|
30
|
-
|
|
40
|
+
# Use provided background or default to dark gray
|
|
41
|
+
if background is None:
|
|
42
|
+
background = "rgba(64, 64, 64, 1)"
|
|
43
|
+
|
|
44
|
+
# Generate the SVG image from SMILES (base64 encoded data URI)
|
|
45
|
+
img = svg_from_smiles(smiles, width, height, background=background)
|
|
46
|
+
|
|
47
|
+
if img is None:
|
|
48
|
+
log.warning(f"Could not render molecule for SMILES: {smiles}")
|
|
49
|
+
return [
|
|
50
|
+
html.Div(
|
|
51
|
+
"Invalid SMILES",
|
|
52
|
+
className="custom-tooltip",
|
|
53
|
+
style={
|
|
54
|
+
"padding": "10px",
|
|
55
|
+
"color": "rgb(255, 140, 140)",
|
|
56
|
+
"width": f"{width}px",
|
|
57
|
+
"height": f"{height}px",
|
|
58
|
+
"display": "flex",
|
|
59
|
+
"alignItems": "center",
|
|
60
|
+
"justifyContent": "center",
|
|
61
|
+
},
|
|
62
|
+
)
|
|
63
|
+
]
|
|
64
|
+
|
|
65
|
+
# Build the tooltip with ID header and molecule image
|
|
66
|
+
children = []
|
|
67
|
+
|
|
68
|
+
# Add ID header if provided
|
|
69
|
+
if mol_id is not None:
|
|
70
|
+
# Set text color based on background brightness
|
|
71
|
+
text_color = "rgb(200, 200, 200)" if is_dark(background) else "rgb(60, 60, 60)"
|
|
72
|
+
children.append(
|
|
73
|
+
html.Div(
|
|
74
|
+
str(mol_id),
|
|
75
|
+
style={
|
|
76
|
+
"textAlign": "center",
|
|
77
|
+
"padding": "8px",
|
|
78
|
+
"color": text_color,
|
|
79
|
+
"fontSize": "14px",
|
|
80
|
+
"fontWeight": "bold",
|
|
81
|
+
"borderBottom": "1px solid rgba(128, 128, 128, 0.5)",
|
|
82
|
+
},
|
|
83
|
+
)
|
|
84
|
+
)
|
|
85
|
+
|
|
86
|
+
# Add molecule image
|
|
87
|
+
children.append(
|
|
88
|
+
html.Img(
|
|
89
|
+
src=img,
|
|
90
|
+
style={"padding": "0px", "margin": "0px", "display": "block"},
|
|
91
|
+
width=str(width),
|
|
92
|
+
height=str(height),
|
|
93
|
+
)
|
|
94
|
+
)
|
|
95
|
+
|
|
96
|
+
return [
|
|
97
|
+
html.Div(
|
|
98
|
+
children,
|
|
99
|
+
className="custom-tooltip",
|
|
100
|
+
style={"padding": "0px", "margin": "0px"},
|
|
101
|
+
)
|
|
102
|
+
]
|
|
103
|
+
|
|
104
|
+
except ImportError as e:
|
|
105
|
+
log.error(f"RDKit not available for molecule rendering: {e}")
|
|
106
|
+
return [
|
|
107
|
+
html.Div(
|
|
108
|
+
"RDKit not installed",
|
|
109
|
+
className="custom-tooltip",
|
|
110
|
+
style={
|
|
111
|
+
"padding": "10px",
|
|
112
|
+
"color": "rgb(255, 195, 140)",
|
|
113
|
+
"width": f"{width}px",
|
|
114
|
+
"height": f"{height}px",
|
|
115
|
+
"display": "flex",
|
|
116
|
+
"alignItems": "center",
|
|
117
|
+
"justifyContent": "center",
|
|
118
|
+
},
|
|
119
|
+
)
|
|
120
|
+
]
|
|
31
121
|
|
|
32
122
|
|
|
33
123
|
def _rgba_to_tuple(rgba: str) -> Tuple[float, float, float, float]:
|
|
@@ -75,7 +165,13 @@ def _configure_draw_options(options: Draw.MolDrawOptions, background: str) -> No
|
|
|
75
165
|
options: RDKit drawing options object
|
|
76
166
|
background: Background color string
|
|
77
167
|
"""
|
|
78
|
-
|
|
168
|
+
try:
|
|
169
|
+
if is_dark(background):
|
|
170
|
+
rdMolDraw2D.SetDarkMode(options)
|
|
171
|
+
# Light backgrounds use RDKit defaults (no action needed)
|
|
172
|
+
except ValueError:
|
|
173
|
+
# Default to dark mode if color format is invalid
|
|
174
|
+
log.warning(f"Invalid color format: {background}, defaulting to dark mode")
|
|
79
175
|
rdMolDraw2D.SetDarkMode(options)
|
|
80
176
|
options.setBackgroundColour(_rgba_to_tuple(background))
|
|
81
177
|
|
|
@@ -137,7 +233,7 @@ def svg_from_smiles(
|
|
|
137
233
|
drawer.DrawMolecule(mol)
|
|
138
234
|
drawer.FinishDrawing()
|
|
139
235
|
|
|
140
|
-
# Encode SVG
|
|
236
|
+
# Encode SVG as base64 data URI
|
|
141
237
|
svg = drawer.GetDrawingText()
|
|
142
238
|
encoded_svg = base64.b64encode(svg.encode("utf-8")).decode("utf-8")
|
|
143
239
|
return f"data:image/svg+xml;base64,{encoded_svg}"
|
|
@@ -222,7 +318,7 @@ if __name__ == "__main__":
|
|
|
222
318
|
# Test 6: Color parsing functions
|
|
223
319
|
print("\n6. Testing color utility functions...")
|
|
224
320
|
test_colors = [
|
|
225
|
-
("invalid_color",
|
|
321
|
+
("invalid_color", None, (0.25, 0.25, 0.25, 1.0)), # Should raise ValueError
|
|
226
322
|
("rgba(255, 255, 255, 1)", False, (1.0, 1.0, 1.0, 1.0)),
|
|
227
323
|
("rgba(0, 0, 0, 1)", True, (0.0, 0.0, 0.0, 1.0)),
|
|
228
324
|
("rgba(64, 64, 64, 0.5)", True, (0.251, 0.251, 0.251, 0.5)),
|
|
@@ -230,24 +326,38 @@ if __name__ == "__main__":
|
|
|
230
326
|
]
|
|
231
327
|
|
|
232
328
|
for color, expected_dark, expected_tuple in test_colors:
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
329
|
+
try:
|
|
330
|
+
is_dark_result = is_dark(color)
|
|
331
|
+
if expected_dark is None:
|
|
332
|
+
print(f" ✗ is_dark('{color[:20]}...'): Expected ValueError but got {is_dark_result}")
|
|
333
|
+
else:
|
|
334
|
+
dark_status = "✓" if is_dark_result == expected_dark else "✗"
|
|
335
|
+
print(f" {dark_status} is_dark('{color[:20]}...'): {is_dark_result} == {expected_dark}")
|
|
336
|
+
except ValueError:
|
|
337
|
+
if expected_dark is None:
|
|
338
|
+
print(f" ✓ is_dark('{color[:20]}...'): Correctly raised ValueError")
|
|
339
|
+
else:
|
|
340
|
+
print(f" ✗ is_dark('{color[:20]}...'): Unexpected ValueError")
|
|
238
341
|
|
|
342
|
+
tuple_result = _rgba_to_tuple(color)
|
|
239
343
|
# Check tuple values with tolerance for floating point
|
|
240
344
|
tuple_match = all(abs(a - b) < 0.01 for a, b in zip(tuple_result, expected_tuple))
|
|
241
345
|
tuple_status = "✓" if tuple_match else "✗"
|
|
242
346
|
print(f" {tuple_status} rgba_to_tuple('{color[:20]}...'): matches expected")
|
|
243
347
|
|
|
244
|
-
# Test the
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
348
|
+
# Test the tooltip generation in a simple Dash app
|
|
349
|
+
from dash import Dash
|
|
350
|
+
|
|
351
|
+
app = Dash(__name__)
|
|
352
|
+
app.layout = html.Div(
|
|
353
|
+
[
|
|
354
|
+
html.Div("Tooltip Preview:", style={"color": "white", "marginBottom": "20px"}),
|
|
355
|
+
*molecule_hover_tooltip("CC(=O)OC1=CC=CC=C1C(=O)O", mol_id="Aspirin", background="rgba(200, 30, 30, 1)"),
|
|
356
|
+
],
|
|
357
|
+
style={"background": "#1a1a1a", "padding": "50px"},
|
|
358
|
+
)
|
|
359
|
+
|
|
360
|
+
if __name__ == "__main__":
|
|
361
|
+
app.run(debug=True)
|
|
252
362
|
|
|
253
363
|
print("\n✅ All tests completed!")
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
"""Clientside JavaScript callbacks for Dash components.
|
|
2
|
+
|
|
3
|
+
These functions return JavaScript code strings for use with Dash's clientside_callback.
|
|
4
|
+
Using clientside callbacks avoids server round-trips for simple UI interactions.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def circle_overlay_callback(circle_data_uri: str) -> str:
|
|
9
|
+
"""Returns JS function for circle overlay on scatter plot hover.
|
|
10
|
+
|
|
11
|
+
Args:
|
|
12
|
+
circle_data_uri: Base64-encoded SVG data URI for the circle overlay
|
|
13
|
+
|
|
14
|
+
Returns:
|
|
15
|
+
JavaScript function string for use with clientside_callback
|
|
16
|
+
"""
|
|
17
|
+
return f"""
|
|
18
|
+
function(hoverData) {{
|
|
19
|
+
if (!hoverData) {{
|
|
20
|
+
return [false, window.dash_clientside.no_update, window.dash_clientside.no_update];
|
|
21
|
+
}}
|
|
22
|
+
var bbox = hoverData.points[0].bbox;
|
|
23
|
+
var centerX = (bbox.x0 + bbox.x1) / 2;
|
|
24
|
+
var centerY = (bbox.y0 + bbox.y1) / 2;
|
|
25
|
+
var adjustedBbox = {{
|
|
26
|
+
x0: centerX - 50,
|
|
27
|
+
x1: centerX + 50,
|
|
28
|
+
y0: centerY - 162,
|
|
29
|
+
y1: centerY - 62
|
|
30
|
+
}};
|
|
31
|
+
var imgElement = {{
|
|
32
|
+
type: 'Img',
|
|
33
|
+
namespace: 'dash_html_components',
|
|
34
|
+
props: {{
|
|
35
|
+
src: '{circle_data_uri}',
|
|
36
|
+
style: {{width: '100px', height: '100px'}}
|
|
37
|
+
}}
|
|
38
|
+
}};
|
|
39
|
+
return [true, adjustedBbox, [imgElement]];
|
|
40
|
+
}}
|
|
41
|
+
"""
|
|
@@ -185,6 +185,63 @@ def dict_to_collapsible_html(data: dict, title: str = None, collapse_all: bool =
|
|
|
185
185
|
return result
|
|
186
186
|
|
|
187
187
|
|
|
188
|
+
def df_to_html_table(df, round_digits: int = 2, margin_bottom: int = 30) -> str:
|
|
189
|
+
"""Convert a DataFrame to a compact styled HTML table (horizontal layout).
|
|
190
|
+
|
|
191
|
+
Args:
|
|
192
|
+
df: DataFrame with metrics (can be single or multi-row)
|
|
193
|
+
round_digits: Number of decimal places to round to (default: 2)
|
|
194
|
+
margin_bottom: Bottom margin in pixels (default: 30)
|
|
195
|
+
|
|
196
|
+
Returns:
|
|
197
|
+
str: HTML table string
|
|
198
|
+
"""
|
|
199
|
+
# Handle index: reset if named (keeps as column), otherwise drop
|
|
200
|
+
if df.index.name:
|
|
201
|
+
df = df.reset_index()
|
|
202
|
+
else:
|
|
203
|
+
df = df.reset_index(drop=True)
|
|
204
|
+
|
|
205
|
+
# Round numeric columns
|
|
206
|
+
df = df.round(round_digits)
|
|
207
|
+
|
|
208
|
+
# Table styles
|
|
209
|
+
container_style = f"display: flex; justify-content: center; margin-top: 10px; margin-bottom: {margin_bottom}px;"
|
|
210
|
+
table_style = "border-collapse: collapse; width: 100%; font-size: 15px;"
|
|
211
|
+
header_style = (
|
|
212
|
+
"background: linear-gradient(to bottom, #4a4a4a 0%, #2d2d2d 100%); "
|
|
213
|
+
"color: white; padding: 4px 8px; text-align: center;"
|
|
214
|
+
)
|
|
215
|
+
cell_style = "padding: 3px 8px; text-align: center; border-bottom: 1px solid #444;"
|
|
216
|
+
|
|
217
|
+
# Build the HTML table (wrapped in centered container)
|
|
218
|
+
html = f'<div style="{container_style}"><table style="{table_style}">'
|
|
219
|
+
|
|
220
|
+
# Header row
|
|
221
|
+
html += "<tr>"
|
|
222
|
+
for col in df.columns:
|
|
223
|
+
html += f'<th style="{header_style}">{col}</th>'
|
|
224
|
+
html += "</tr>"
|
|
225
|
+
|
|
226
|
+
# Data rows
|
|
227
|
+
for _, row in df.iterrows():
|
|
228
|
+
html += "<tr>"
|
|
229
|
+
for val in row:
|
|
230
|
+
# Format value: integers without decimal, floats rounded
|
|
231
|
+
if isinstance(val, float):
|
|
232
|
+
if val == int(val):
|
|
233
|
+
formatted_val = int(val)
|
|
234
|
+
else:
|
|
235
|
+
formatted_val = round(val, round_digits)
|
|
236
|
+
else:
|
|
237
|
+
formatted_val = val
|
|
238
|
+
html += f'<td style="{cell_style}">{formatted_val}</td>'
|
|
239
|
+
html += "</tr>"
|
|
240
|
+
|
|
241
|
+
html += "</table></div>"
|
|
242
|
+
return html
|
|
243
|
+
|
|
244
|
+
|
|
188
245
|
if __name__ == "__main__":
|
|
189
246
|
"""Exercise the Markdown Utilities"""
|
|
190
247
|
from workbench.api.model import Model
|
workbench/utils/model_utils.py
CHANGED
workbench/utils/plot_utils.py
CHANGED
|
@@ -1,14 +1,17 @@
|
|
|
1
1
|
"""Plot Utilities for Workbench"""
|
|
2
2
|
|
|
3
|
+
import logging
|
|
3
4
|
import numpy as np
|
|
4
5
|
import pandas as pd
|
|
5
6
|
import plotly.graph_objects as go
|
|
6
7
|
|
|
8
|
+
log = logging.getLogger("workbench")
|
|
9
|
+
|
|
7
10
|
|
|
8
11
|
# For approximating beeswarm effect
|
|
9
12
|
def beeswarm_offsets(values, point_size=0.05, precision=2, max_offset=0.3):
|
|
10
13
|
"""
|
|
11
|
-
Generate
|
|
14
|
+
Generate beeswarm offsets using random jitter with collision avoidance.
|
|
12
15
|
|
|
13
16
|
Args:
|
|
14
17
|
values: Array of positions to be adjusted
|
|
@@ -22,42 +25,55 @@ def beeswarm_offsets(values, point_size=0.05, precision=2, max_offset=0.3):
|
|
|
22
25
|
values = np.asarray(values)
|
|
23
26
|
rounded = np.round(values, precision)
|
|
24
27
|
offsets = np.zeros_like(values, dtype=float)
|
|
25
|
-
|
|
26
|
-
# Sort indices by original values
|
|
27
|
-
sorted_idx = np.argsort(values)
|
|
28
|
+
rng = np.random.default_rng(42) # Fixed seed for reproducibility
|
|
28
29
|
|
|
29
30
|
for val in np.unique(rounded):
|
|
30
31
|
# Get indices belonging to this group
|
|
31
|
-
|
|
32
|
+
group_mask = rounded == val
|
|
33
|
+
group_idx = np.where(group_mask)[0]
|
|
32
34
|
|
|
33
35
|
if len(group_idx) > 1:
|
|
34
36
|
# Track occupied positions for collision detection
|
|
35
37
|
occupied = []
|
|
36
38
|
|
|
37
39
|
for idx in group_idx:
|
|
38
|
-
#
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
#
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
40
|
+
# Try random positions, starting near center and expanding outward
|
|
41
|
+
best_offset = 0
|
|
42
|
+
found = False
|
|
43
|
+
|
|
44
|
+
# First point goes to center
|
|
45
|
+
if not occupied:
|
|
46
|
+
found = True
|
|
47
|
+
else:
|
|
48
|
+
# Try random positions with increasing spread
|
|
49
|
+
for attempt in range(50):
|
|
50
|
+
# Gradually increase the range of random offsets
|
|
51
|
+
spread = min(max_offset, point_size * (1 + attempt * 0.5))
|
|
52
|
+
offset = rng.uniform(-spread, spread)
|
|
53
|
+
|
|
54
|
+
# Check for collision with occupied positions
|
|
55
|
+
if not any(abs(offset - pos) < point_size for pos in occupied):
|
|
56
|
+
best_offset = offset
|
|
57
|
+
found = True
|
|
58
|
+
break
|
|
59
|
+
|
|
60
|
+
# If no free position found after attempts, find the least crowded spot
|
|
61
|
+
if not found:
|
|
62
|
+
# Try a grid of positions and pick one with most space
|
|
63
|
+
candidates = np.linspace(-max_offset, max_offset, 20)
|
|
64
|
+
rng.shuffle(candidates)
|
|
65
|
+
for candidate in candidates:
|
|
66
|
+
if not any(abs(candidate - pos) < point_size * 0.8 for pos in occupied):
|
|
67
|
+
best_offset = candidate
|
|
68
|
+
found = True
|
|
69
|
+
break
|
|
70
|
+
|
|
71
|
+
# Last resort: just use a random position within bounds
|
|
72
|
+
if not found:
|
|
73
|
+
best_offset = rng.uniform(-max_offset, max_offset)
|
|
74
|
+
|
|
75
|
+
offsets[idx] = best_offset
|
|
76
|
+
occupied.append(best_offset)
|
|
61
77
|
|
|
62
78
|
return offsets
|
|
63
79
|
|
|
@@ -126,13 +142,13 @@ def prediction_intervals(df, figure, x_col):
|
|
|
126
142
|
|
|
127
143
|
# Sort dataframe by x_col for connected lines
|
|
128
144
|
sorted_df = df.sort_values(by=x_col)
|
|
129
|
-
# Add outer band (q_025 to q_975) - more transparent
|
|
145
|
+
# Add outer band (q_025 to q_975) - desaturated blue-gray, more transparent
|
|
130
146
|
figure.add_trace(
|
|
131
147
|
go.Scatter(
|
|
132
148
|
x=sorted_df[x_col],
|
|
133
149
|
y=sorted_df["q_025"],
|
|
134
150
|
mode="lines",
|
|
135
|
-
line=dict(width=1, color="rgba(
|
|
151
|
+
line=dict(width=1, color="rgba(120, 130, 180, 0.3)"),
|
|
136
152
|
name="2.5 Percentile",
|
|
137
153
|
hoverinfo="skip",
|
|
138
154
|
showlegend=False,
|
|
@@ -143,21 +159,21 @@ def prediction_intervals(df, figure, x_col):
|
|
|
143
159
|
x=sorted_df[x_col],
|
|
144
160
|
y=sorted_df["q_975"],
|
|
145
161
|
mode="lines",
|
|
146
|
-
line=dict(width=1, color="rgba(
|
|
162
|
+
line=dict(width=1, color="rgba(120, 130, 180, 0.3)"),
|
|
147
163
|
name="97.5 Percentile",
|
|
148
164
|
hoverinfo="skip",
|
|
149
165
|
showlegend=False,
|
|
150
166
|
fill="tonexty",
|
|
151
|
-
fillcolor="rgba(
|
|
167
|
+
fillcolor="rgba(120, 130, 180, 0.15)",
|
|
152
168
|
)
|
|
153
169
|
)
|
|
154
|
-
# Add inner band (q_25 to q_75) -
|
|
170
|
+
# Add inner band (q_25 to q_75) - desaturated green-gray, slightly more visible
|
|
155
171
|
figure.add_trace(
|
|
156
172
|
go.Scatter(
|
|
157
173
|
x=sorted_df[x_col],
|
|
158
174
|
y=sorted_df["q_25"],
|
|
159
175
|
mode="lines",
|
|
160
|
-
line=dict(width=1, color="rgba(
|
|
176
|
+
line=dict(width=1, color="rgba(130, 180, 140, 0.3)"),
|
|
161
177
|
name="25 Percentile",
|
|
162
178
|
hoverinfo="skip",
|
|
163
179
|
showlegend=False,
|
|
@@ -168,12 +184,12 @@ def prediction_intervals(df, figure, x_col):
|
|
|
168
184
|
x=sorted_df[x_col],
|
|
169
185
|
y=sorted_df["q_75"],
|
|
170
186
|
mode="lines",
|
|
171
|
-
line=dict(width=1, color="rgba(
|
|
187
|
+
line=dict(width=1, color="rgba(130, 180, 140, 0.3)"),
|
|
172
188
|
name="75 Percentile",
|
|
173
189
|
hoverinfo="skip",
|
|
174
190
|
showlegend=False,
|
|
175
191
|
fill="tonexty",
|
|
176
|
-
fillcolor="rgba(
|
|
192
|
+
fillcolor="rgba(130, 180, 140, 0.18)",
|
|
177
193
|
)
|
|
178
194
|
)
|
|
179
195
|
return figure
|