workbench 0.8.219__py3-none-any.whl → 0.8.224__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27)
  1. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  2. workbench/algorithms/dataframe/fingerprint_proximity.py +190 -31
  3. workbench/algorithms/dataframe/projection_2d.py +8 -2
  4. workbench/algorithms/dataframe/proximity.py +3 -0
  5. workbench/api/feature_set.py +0 -1
  6. workbench/core/artifacts/feature_set_core.py +183 -228
  7. workbench/core/transforms/features_to_model/features_to_model.py +2 -8
  8. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +2 -0
  9. workbench/model_scripts/chemprop/chemprop.template +193 -68
  10. workbench/model_scripts/chemprop/generated_model_script.py +198 -73
  11. workbench/model_scripts/pytorch_model/generated_model_script.py +3 -3
  12. workbench/model_scripts/xgb_model/generated_model_script.py +3 -3
  13. workbench/scripts/ml_pipeline_sqs.py +71 -2
  14. workbench/themes/light/custom.css +7 -1
  15. workbench/themes/midnight_blue/custom.css +34 -0
  16. workbench/utils/chem_utils/projections.py +16 -6
  17. workbench/utils/model_utils.py +0 -1
  18. workbench/utils/plot_utils.py +146 -28
  19. workbench/utils/theme_manager.py +95 -30
  20. workbench/web_interface/components/plugins/scatter_plot.py +152 -66
  21. workbench/web_interface/components/settings_menu.py +184 -0
  22. {workbench-0.8.219.dist-info → workbench-0.8.224.dist-info}/METADATA +4 -13
  23. {workbench-0.8.219.dist-info → workbench-0.8.224.dist-info}/RECORD +27 -25
  24. {workbench-0.8.219.dist-info → workbench-0.8.224.dist-info}/WHEEL +0 -0
  25. {workbench-0.8.219.dist-info → workbench-0.8.224.dist-info}/entry_points.txt +0 -0
  26. {workbench-0.8.219.dist-info → workbench-0.8.224.dist-info}/licenses/LICENSE +0 -0
  27. {workbench-0.8.219.dist-info → workbench-0.8.224.dist-info}/top_level.txt +0 -0
@@ -17,18 +17,28 @@ log = logging.getLogger("workbench")
17
17
 
18
18
  def fingerprints_to_matrix(fingerprints, dtype=np.uint8):
19
19
  """
20
- Convert bitstring fingerprints to numpy matrix.
20
+ Convert fingerprints to numpy matrix.
21
+
22
+ Supports two formats (auto-detected):
23
+ - Bitstrings: "10110010..." → matrix of 0s and 1s
24
+ - Count vectors: "0,3,0,1,5,..." → matrix of counts (or binary if dtype=np.bool_)
21
25
 
22
26
  Args:
23
- fingerprints: pandas Series or list of bitstring fingerprints
24
- dtype: numpy data type (uint8 is default: np.bool_ is good for Jaccard computations
27
+ fingerprints: pandas Series or list of fingerprints
28
+ dtype: numpy data type (uint8 is default; np.bool_ for Jaccard computations)
25
29
 
26
30
  Returns:
27
31
  dense numpy array of shape (n_molecules, n_bits)
28
32
  """
29
-
30
- # Dense matrix representation (we might support sparse in the future)
31
- return np.array([list(fp) for fp in fingerprints], dtype=dtype)
33
+ # Auto-detect format based on first fingerprint
34
+ sample = str(fingerprints.iloc[0] if hasattr(fingerprints, "iloc") else fingerprints[0])
35
+ if "," in sample:
36
+ # Count vector format: comma-separated integers
37
+ matrix = np.array([list(map(int, fp.split(","))) for fp in fingerprints], dtype=dtype)
38
+ else:
39
+ # Bitstring format: each character is a bit
40
+ matrix = np.array([list(fp) for fp in fingerprints], dtype=dtype)
41
+ return matrix
32
42
 
33
43
 
34
44
  def project_fingerprints(df: pd.DataFrame, projection: str = "UMAP") -> pd.DataFrame:
@@ -173,7 +173,6 @@ def fingerprint_prox_model_local(
173
173
  include_all_columns=include_all_columns,
174
174
  radius=radius,
175
175
  n_bits=n_bits,
176
- counts=counts,
177
176
  )
178
177
 
179
178
 
@@ -1,14 +1,18 @@
1
1
  """Plot Utilities for Workbench"""
2
2
 
3
+ import logging
3
4
  import numpy as np
4
5
  import pandas as pd
5
6
  import plotly.graph_objects as go
7
+ from dash import html
8
+
9
+ log = logging.getLogger("workbench")
6
10
 
7
11
 
8
12
  # For approximating beeswarm effect
9
13
  def beeswarm_offsets(values, point_size=0.05, precision=2, max_offset=0.3):
10
14
  """
11
- Generate optimal beeswarm offsets with a maximum limit.
15
+ Generate beeswarm offsets using random jitter with collision avoidance.
12
16
 
13
17
  Args:
14
18
  values: Array of positions to be adjusted
@@ -22,42 +26,55 @@ def beeswarm_offsets(values, point_size=0.05, precision=2, max_offset=0.3):
22
26
  values = np.asarray(values)
23
27
  rounded = np.round(values, precision)
24
28
  offsets = np.zeros_like(values, dtype=float)
25
-
26
- # Sort indices by original values
27
- sorted_idx = np.argsort(values)
29
+ rng = np.random.default_rng(42) # Fixed seed for reproducibility
28
30
 
29
31
  for val in np.unique(rounded):
30
32
  # Get indices belonging to this group
31
- group_idx = sorted_idx[np.isin(sorted_idx, np.where(rounded == val)[0])]
33
+ group_mask = rounded == val
34
+ group_idx = np.where(group_mask)[0]
32
35
 
33
36
  if len(group_idx) > 1:
34
37
  # Track occupied positions for collision detection
35
38
  occupied = []
36
39
 
37
40
  for idx in group_idx:
38
- # Find best position with no collision
39
- offset = 0
40
- direction = 1
41
- step = 0
42
-
43
- while True:
44
- # Check if current offset position is free
45
- collision = any(abs(offset - pos) < point_size for pos in occupied)
46
-
47
- if not collision or abs(offset) >= max_offset:
48
- # Accept position if no collision or max offset reached
49
- if abs(offset) > max_offset:
50
- # Clamp to maximum
51
- offset = max_offset * (1 if offset > 0 else -1)
52
- break
53
-
54
- # Switch sides with increasing distance
55
- step += 0.25
56
- direction *= -1
57
- offset = direction * step * point_size
58
-
59
- offsets[idx] = offset
60
- occupied.append(offset)
41
+ # Try random positions, starting near center and expanding outward
42
+ best_offset = 0
43
+ found = False
44
+
45
+ # First point goes to center
46
+ if not occupied:
47
+ found = True
48
+ else:
49
+ # Try random positions with increasing spread
50
+ for attempt in range(50):
51
+ # Gradually increase the range of random offsets
52
+ spread = min(max_offset, point_size * (1 + attempt * 0.5))
53
+ offset = rng.uniform(-spread, spread)
54
+
55
+ # Check for collision with occupied positions
56
+ if not any(abs(offset - pos) < point_size for pos in occupied):
57
+ best_offset = offset
58
+ found = True
59
+ break
60
+
61
+ # If no free position found after attempts, find the least crowded spot
62
+ if not found:
63
+ # Try a grid of positions and pick one with most space
64
+ candidates = np.linspace(-max_offset, max_offset, 20)
65
+ rng.shuffle(candidates)
66
+ for candidate in candidates:
67
+ if not any(abs(candidate - pos) < point_size * 0.8 for pos in occupied):
68
+ best_offset = candidate
69
+ found = True
70
+ break
71
+
72
+ # Last resort: just use a random position within bounds
73
+ if not found:
74
+ best_offset = rng.uniform(-max_offset, max_offset)
75
+
76
+ offsets[idx] = best_offset
77
+ occupied.append(best_offset)
61
78
 
62
79
  return offsets
63
80
 
@@ -179,6 +196,107 @@ def prediction_intervals(df, figure, x_col):
179
196
  return figure
180
197
 
181
198
 
199
+ def molecule_hover_tooltip(smiles: str, mol_id: str = None, width: int = 300, height: int = 200) -> list:
200
+ """Generate a molecule hover tooltip from a SMILES string.
201
+
202
+ This function creates a visually appealing tooltip with a dark background
203
+ that displays the molecule ID at the top and structure below when hovering
204
+ over scatter plot points.
205
+
206
+ Args:
207
+ smiles: SMILES string representing the molecule
208
+ mol_id: Optional molecule ID to display at the top of the tooltip
209
+ width: Width of the molecule image in pixels (default: 300)
210
+ height: Height of the molecule image in pixels (default: 200)
211
+
212
+ Returns:
213
+ list: A list containing an html.Div with the ID header and molecule SVG,
214
+ or an html.Div with an error message if rendering fails
215
+ """
216
+ try:
217
+ from workbench.utils.chem_utils.vis import svg_from_smiles
218
+ from workbench.utils.theme_manager import ThemeManager
219
+
220
+ # Get the background color from the current theme
221
+ background = ThemeManager().background()
222
+
223
+ # Generate the SVG image from SMILES
224
+ img = svg_from_smiles(smiles, width, height, background=background)
225
+
226
+ if img is None:
227
+ log.warning(f"Could not render molecule for SMILES: {smiles}")
228
+ return [
229
+ html.Div(
230
+ "Invalid SMILES",
231
+ className="custom-tooltip",
232
+ style={
233
+ "padding": "10px",
234
+ "color": "rgb(255, 140, 140)",
235
+ "width": f"{width}px",
236
+ "height": f"{height}px",
237
+ "display": "flex",
238
+ "alignItems": "center",
239
+ "justifyContent": "center",
240
+ },
241
+ )
242
+ ]
243
+
244
+ # Build the tooltip with ID header and molecule image
245
+ children = []
246
+
247
+ # Add ID header if provided
248
+ if mol_id is not None:
249
+ children.append(
250
+ html.Div(
251
+ str(mol_id),
252
+ style={
253
+ "textAlign": "center",
254
+ "padding": "8px",
255
+ "color": "rgb(200, 200, 200)",
256
+ "fontSize": "14px",
257
+ "fontWeight": "bold",
258
+ "borderBottom": "1px solid rgba(128, 128, 128, 0.5)",
259
+ },
260
+ )
261
+ )
262
+
263
+ # Add molecule image
264
+ children.append(
265
+ html.Img(
266
+ src=img,
267
+ style={"padding": "0px", "margin": "0px", "display": "block"},
268
+ width=str(width),
269
+ height=str(height),
270
+ )
271
+ )
272
+
273
+ return [
274
+ html.Div(
275
+ children,
276
+ className="custom-tooltip",
277
+ style={"padding": "0px", "margin": "0px"},
278
+ )
279
+ ]
280
+
281
+ except ImportError as e:
282
+ log.error(f"RDKit not available for molecule rendering: {e}")
283
+ return [
284
+ html.Div(
285
+ "RDKit not installed",
286
+ className="custom-tooltip",
287
+ style={
288
+ "padding": "10px",
289
+ "color": "rgb(255, 195, 140)",
290
+ "width": f"{width}px",
291
+ "height": f"{height}px",
292
+ "display": "flex",
293
+ "alignItems": "center",
294
+ "justifyContent": "center",
295
+ },
296
+ )
297
+ ]
298
+
299
+
182
300
  if __name__ == "__main__":
183
301
  """Exercise the Plot Utilities"""
184
302
  import plotly.express as px
@@ -76,10 +76,28 @@ class ThemeManager:
76
76
  def set_theme(cls, theme_name: str):
77
77
  """Set the current theme."""
78
78
 
79
- # For 'auto', we try to grab a theme from the Parameter Store
80
- # if we can't find one, we'll set the theme to the default
79
+ # For 'auto', we check multiple sources in priority order:
80
+ # 1. Browser cookie (from localStorage, for per-user preference)
81
+ # 2. Parameter Store (for org-wide default)
82
+ # 3. Default theme
81
83
  if theme_name == "auto":
82
- theme_name = cls.ps.get("/workbench/dashboard/theme", warn=False) or cls.default_theme
84
+ theme_name = None
85
+
86
+ # 1. Check Flask request cookie (set from localStorage)
87
+ try:
88
+ from flask import request, has_request_context
89
+
90
+ if has_request_context():
91
+ theme_name = request.cookies.get("wb_theme")
92
+ except Exception:
93
+ pass
94
+
95
+ # 2. Fall back to ParameterStore
96
+ if not theme_name:
97
+ theme_name = cls.ps.get("/workbench/dashboard/theme", warn=False)
98
+
99
+ # 3. Fall back to default
100
+ theme_name = theme_name or cls.default_theme
83
101
 
84
102
  # Check if the theme is in our available themes
85
103
  if theme_name not in cls.available_themes:
@@ -104,9 +122,27 @@ class ThemeManager:
104
122
  cls.current_theme_name = theme_name
105
123
  cls.log.info(f"Theme set to '{theme_name}'")
106
124
 
125
+ # Bootstrap themes that are dark mode (from Bootswatch)
126
+ _dark_bootstrap_themes = {"DARKLY", "CYBORG", "SLATE", "SOLAR", "SUPERHERO", "VAPOR"}
127
+
107
128
  @classmethod
108
129
  def dark_mode(cls) -> bool:
109
- """Check if the current theme is a dark mode theme."""
130
+ """Check if the current theme is a dark mode theme.
131
+
132
+ Determines dark mode by checking if the Bootstrap base theme is a known dark theme.
133
+ Falls back to checking if 'dark' is in the theme name.
134
+ """
135
+ theme = cls.available_themes.get(cls.current_theme_name, {})
136
+ base_css = theme.get("base_css", "")
137
+
138
+ # Check if the base CSS URL contains a known dark Bootstrap theme
139
+ if base_css:
140
+ base_css_upper = base_css.upper()
141
+ for dark_theme in cls._dark_bootstrap_themes:
142
+ if dark_theme in base_css_upper:
143
+ return True
144
+
145
+ # Fallback: check if 'dark' is in the theme name
110
146
  return "dark" in cls.current_theme().lower()
111
147
 
112
148
  @classmethod
@@ -184,30 +220,57 @@ class ThemeManager:
184
220
 
185
221
  @classmethod
186
222
  def css_files(cls) -> list[str]:
187
- """Get the list of CSS files for the current theme."""
188
- theme = cls.available_themes[cls.current_theme_name]
223
+ """Get the list of CSS files for the current theme.
224
+
225
+ Note: Uses /base.css route for dynamic theme switching instead of CDN URLs.
226
+ """
189
227
  css_files = []
190
228
 
191
- # Add base.css or its CDN URL
192
- if theme["base_css"]:
193
- css_files.append(theme["base_css"])
229
+ # Use Flask route for base CSS (allows dynamic theme switching)
230
+ css_files.append("/base.css")
194
231
 
195
232
  # Add the DBC template CSS
196
233
  css_files.append(cls.dbc_css)
197
234
 
198
235
  # Add custom.css if it exists
199
- if theme["custom_css"]:
200
- css_files.append("/custom.css")
236
+ css_files.append("/custom.css")
201
237
 
202
238
  return css_files
203
239
 
240
+ @classmethod
241
+ def _get_theme_from_cookie(cls):
242
+ """Get the theme dict based on the wb_theme cookie, falling back to current theme."""
243
+ from flask import request
244
+
245
+ theme_name = request.cookies.get("wb_theme")
246
+ if theme_name and theme_name in cls.available_themes:
247
+ return cls.available_themes[theme_name], theme_name
248
+ return cls.available_themes[cls.current_theme_name], cls.current_theme_name
249
+
204
250
  @classmethod
205
251
  def register_css_route(cls, app):
206
- """Register Flask route for custom.css."""
252
+ """Register Flask routes for CSS and before_request hook for theme switching."""
253
+ from flask import redirect
254
+
255
+ @app.server.before_request
256
+ def check_theme_cookie():
257
+ """Check for theme cookie on each request and update theme if needed."""
258
+ _, theme_name = cls._get_theme_from_cookie()
259
+ if theme_name != cls.current_theme_name:
260
+ cls.set_theme(theme_name)
261
+
262
+ @app.server.route("/base.css")
263
+ def serve_base_css():
264
+ """Redirect to the appropriate Bootstrap theme CSS based on cookie."""
265
+ theme, _ = cls._get_theme_from_cookie()
266
+ if theme["base_css"]:
267
+ return redirect(theme["base_css"])
268
+ return "", 404
207
269
 
208
270
  @app.server.route("/custom.css")
209
271
  def serve_custom_css():
210
- theme = cls.available_themes[cls.current_theme_name]
272
+ """Serve the custom.css file based on cookie."""
273
+ theme, _ = cls._get_theme_from_cookie()
211
274
  if theme["custom_css"]:
212
275
  return send_from_directory(theme["custom_css"].parent, theme["custom_css"].name)
213
276
  return "", 404
@@ -250,23 +313,25 @@ class ThemeManager:
250
313
  # Loop over each path in the theme path
251
314
  for theme_path in cls.theme_path_list:
252
315
  for theme_dir in theme_path.iterdir():
253
- if theme_dir.is_dir():
254
- theme_name = theme_dir.name
255
-
256
- # Grab the base.css URL
257
- base_css_url = cls._get_base_css_url(theme_dir)
258
-
259
- # Grab the plotly template json, custom.css, and branding json
260
- plotly_template = theme_dir / "plotly.json"
261
- custom_css = theme_dir / "custom.css"
262
- branding = theme_dir / "branding.json"
263
-
264
- cls.available_themes[theme_name] = {
265
- "base_css": base_css_url,
266
- "plotly_template": plotly_template,
267
- "custom_css": custom_css if custom_css.exists() else None,
268
- "branding": branding if branding.exists() else None,
269
- }
316
+ # Skip hidden directories (e.g., .idea, .git)
317
+ if not theme_dir.is_dir() or theme_dir.name.startswith("."):
318
+ continue
319
+ theme_name = theme_dir.name
320
+
321
+ # Grab the base.css URL
322
+ base_css_url = cls._get_base_css_url(theme_dir)
323
+
324
+ # Grab the plotly template json, custom.css, and branding json
325
+ plotly_template = theme_dir / "plotly.json"
326
+ custom_css = theme_dir / "custom.css"
327
+ branding = theme_dir / "branding.json"
328
+
329
+ cls.available_themes[theme_name] = {
330
+ "base_css": base_css_url,
331
+ "plotly_template": plotly_template,
332
+ "custom_css": custom_css if custom_css.exists() else None,
333
+ "branding": branding if branding.exists() else None,
334
+ }
270
335
 
271
336
  if not cls.available_themes:
272
337
  cls.log.warning(f"No themes found in '{cls.theme_path_list}'...")