bec-widgets 0.44.5__py3-none-any.whl → 0.46.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,374 +0,0 @@
1
- # pylint: disable = no-name-in-module,missing-module-docstring
2
- import time
3
- from collections import defaultdict
4
-
5
- import numpy as np
6
- import pyqtgraph as pg
7
- from bec_lib import MessageEndpoints
8
- from qtpy.QtCore import Signal as pyqtSignal
9
- from qtpy.QtCore import Slot as pyqtSlot
10
- from qtpy.QtWidgets import QApplication, QVBoxLayout, QWidget
11
-
12
- from bec_widgets.utils import yaml_dialog
13
- from bec_widgets.utils.bec_dispatcher import BECDispatcher
14
-
15
# Example configuration used when the module is run as a script without
# --config/--config_file.
CONFIG_DEFAULT = {
    # Settings shared by all plots: colormap for the z values and grid width.
    "plot_settings": {"colormap": "CET-L4", "num_columns": 1},
    # One entry per scatter plot. x/y select the point position, z the colour.
    "waveform2D": [
        {
            "plot_name": "Waveform 2D Scatter (1)",
            "x_label": "Sam X",
            "y_label": "Sam Y",
            "signals": {
                "x": [{"name": "samx", "entry": "samx"}],
                "y": [{"name": "samy", "entry": "samy"}],
                "z": [{"name": "gauss_bpm", "entry": "gauss_bpm"}],
            },
        },
        {
            # Same data with the two motor axes swapped.
            "plot_name": "Waveform 2D Scatter (2)",
            "x_label": "Sam Y",
            "y_label": "Sam X",
            "signals": {
                "x": [{"name": "samy", "entry": "samy"}],
                "y": [{"name": "samx", "entry": "samx"}],
                "z": [{"name": "gauss_bpm", "entry": "gauss_bpm"}],
            },
        },
    ],
}
40
-
41
-
42
class BECMonitor2DScatter(QWidget):
    """Grid of 2D scatter plots whose point colours encode a third signal.

    Every entry of ``config["waveform2D"]`` describes one plot. Its ``x`` and ``y``
    signals place the points, and its ``z`` signal is normalised onto the colormap
    from ``config["plot_settings"]``. Readings arrive as scan segments through the
    BEC dispatcher and are accumulated per plot. The store is flushed whenever a
    segment with a new ``scan_id`` arrives.
    """

    # Emitted after every accepted scan segment; redraws go through a
    # rate-limited pg.SignalProxy created in __init__.
    update_signal = pyqtSignal()

    def __init__(
        self,
        parent=None,
        client=None,
        config: dict = None,
        enable_crosshair: bool = True,
        gui_id=None,
        skip_validation: bool = True,
        toolbar_enabled=True,
    ):
        """
        Args:
            parent: Parent Qt widget.
            client: BEC client; defaults to the BECDispatcher's client.
            config (dict): Widget configuration. If None, no plots are created until
                on_config_update is called.
            enable_crosshair (bool): Stored only; crosshairs are not implemented yet.
            gui_id: Identifier of this GUI; generated from class name and time if None.
            skip_validation (bool): Apply configs without validation (validation is
                not implemented yet).
            toolbar_enabled (bool): Whether to call the (placeholder) toolbar setup.
        """
        super().__init__(parent=parent)

        # Client and device manager from BEC
        bec_dispatcher = BECDispatcher()
        self.client = bec_dispatcher.client if client is None else client
        self.dev = self.client.device_manager.devices
        self.queue = self.client.queue

        self.validator = None  # TODO implement validator when ready
        self.gui_id = gui_id
        if self.gui_id is None:
            self.gui_id = self.__class__.__name__ + str(time.time())

        # Config related state. Empty containers (instead of None) keep
        # on_scan_segment, flush and update_plot harmless before a config is loaded.
        self.config = config
        self.skip_validation = skip_validation
        self.plot_settings = {}
        self.num_columns = None
        self.colormap = "viridis"
        self.plot_data = []
        self.database = {}
        self.plots = {}
        self.scatterPlots = {}
        self.imageItems = {}
        self.colorBars = {}
        self.grid_coordinates = []
        self.curves_data = {}

        # Crosshair support is not implemented yet; the flag is only stored.
        self.enable_crosshair = enable_crosshair
        self.crosshairs = None

        # Scan currently being displayed
        self.scan_id = None
        self.scan_data = None

        # Connect dispatcher slots once the state above exists.
        # TODO connect endpoints related to CLI
        bec_dispatcher.connect_slot(self.on_scan_segment, MessageEndpoints.scan_segment())

        # Redraw requests are funnelled through a rate-limited proxy (rateLimit=10).
        self.proxy_update_plot = pg.SignalProxy(
            self.update_signal, rateLimit=10, slot=self.update_plot
        )

        # Init UI. QVBoxLayout(self) already installs the layout on this widget.
        self.layout = QVBoxLayout(self)
        if toolbar_enabled:  # TODO implement toolbar when ready
            self._init_toolbar()

        self.glw = pg.GraphicsLayoutWidget()
        self.layout.addWidget(self.glw)

        if self.config is None:
            print(f"No initial config found for {self.__class__.__name__}")
        else:
            self.on_config_update(self.config)

    def _init_toolbar(self):
        """Initialize the toolbar (placeholder).

        TODO: add a ModularToolBar (bec_widgets.widgets) to self.layout when ready.
        """

    def _init_config(self):
        """Apply ``self.config``: read global settings, reset the data store and rebuild the plots."""
        # Global widget settings
        self._get_global_settings()

        # Plot data
        self.plot_data = self.config.get("waveform2D", [])

        # Initiate database
        self.database = self._init_database()

        # Lay out the plots with the configured column count (not _init_ui's default).
        self._init_ui(self.num_columns)

    def _get_global_settings(self):
        """Read the global widget settings (grid width, colormap) from ``self.config``."""
        # Copy so that _init_ui's num_columns bookkeeping does not mutate the caller's dict.
        self.plot_settings = dict(self.config.get("plot_settings", {}))

        self.num_columns = self.plot_settings.get("num_columns", 1)
        self.colormap = self.plot_settings.get("colormap", "viridis")

    def _init_database(self) -> dict:
        """
        Initialize the database to store the data for each plot.

        Returns:
            dict: Nested defaultdict ``database[plot_name][axis][signal_name] -> list``.
        """
        return defaultdict(lambda: defaultdict(lambda: defaultdict(list)))

    @staticmethod
    def _compute_grid_positions(num_plots: int, num_columns: int) -> list:
        """
        Compute ``(row, col, colspan)`` for each plot in a grid ``num_columns`` wide.

        Full rows use one cell per plot. The plots of an incomplete last row are
        stretched so that together they span exactly the full width; leftover
        columns go to the leftmost plots of that row.

        Args:
            num_plots (int): Number of plots to place.
            num_columns (int): Grid width, must be >= 1.

        Returns:
            list: One ``(row, col, colspan)`` tuple per plot, in plot order.
        """
        positions = []
        num_full_rows, last_row_count = divmod(num_plots, num_columns)
        for i in range(num_full_rows * num_columns):
            row, col = divmod(i, num_columns)
            positions.append((row, col, 1))

        if last_row_count:
            base_span, extra = divmod(num_columns, last_row_count)
            col = 0
            for j in range(last_row_count):
                span = base_span + (1 if j < extra else 0)
                positions.append((num_full_rows, col, span))
                col += span
        return positions

    def _init_ui(self, num_columns: int = 3) -> None:
        """
        Initialize the UI components, create plots and store their grid positions.

        Args:
            num_columns (int): Number of columns to wrap the layout. It is clamped to
                ``[1, number of plots]``, and the effective value is written back to
                ``self.plot_settings["num_columns"]``.

        Plots are stored in ``self.plots`` keyed by plot name. Plots of an incomplete
        last row are stretched to fill the remaining width.
        """
        self.glw.clear()
        self.plots = {}
        self.imageItems = {}
        self.grid_coordinates = []
        self.scatterPlots = {}
        self.colorBars = {}

        num_plots = len(self.plot_data)
        if num_plots == 0:
            # Nothing to lay out (and avoids dividing by a zero column count).
            return

        num_columns = max(1, num_columns)
        if num_columns > num_plots:
            num_columns = num_plots
            print(
                "Warning: num_columns in the YAML file was greater than the number of plots."
                f" Resetting num_columns to number of plots:{num_columns}."
            )
        self.plot_settings["num_columns"] = num_columns  # Update the settings

        positions = self._compute_grid_positions(num_plots, num_columns)
        for plot_config, (row, col, colspan) in zip(self.plot_data, positions):
            plot_name = plot_config.get("plot_name", "")

            plot = self.glw.addPlot(row=row, col=col, colspan=colspan, title=plot_name)
            plot.setLabel("bottom", plot_config.get("x_label", ""))
            plot.setLabel("left", plot_config.get("y_label", ""))
            plot.addLegend()

            self.plots[plot_name] = plot
            self.grid_coordinates.append((row, col))

        self._init_curves()

    def _init_curves(self):
        """(Re)create one empty ScatterPlotItem per configured plot."""
        self.scatterPlots = {}
        for plot_config in self.plot_data:
            plot_name = plot_config.get("plot_name", "")
            plot = self.plots[plot_name]
            plot.clear()

            scatter_plot = pg.ScatterPlotItem(size=10)
            plot.addItem(scatter_plot)
            self.scatterPlots[plot_name] = scatter_plot

    @pyqtSlot(dict)
    def on_config_update(self, config: dict):
        """
        Validate and update the configuration settings.

        Args:
            config (dict): Configuration settings. A config nested under a ``"config"``
                key is unwrapped first.
        """
        # TODO implement BEC CLI commands similar to BECPlotter
        # convert config from BEC CLI to correct formatting
        inner_config = config.get("config", None)
        if inner_config is not None:
            config = inner_config

        if self.skip_validation is True:
            self.config = config
            self._init_config()
        else:  # TODO implement validator
            print("Do validation")

    def flush(self):
        """Reset the collected data and clear all scatter plots."""
        self.database = self._init_database()
        self._init_curves()

    @pyqtSlot(dict, dict)
    def on_scan_segment(self, msg, metadata):
        """
        Handle new scan segments and save their data. Linked through bec_dispatcher.

        A segment carrying a new ``scan_id`` flushes the previous data, but only once
        the scan is found in the queue's scan storage. Until then the segment is
        dropped and the current ``scan_id`` is left unchanged, so the next segment
        retries the lookup instead of appending to the old scan's data.

        Args:
            msg (dict): Message received with scan data.
            metadata (dict): Metadata of the scan.
        """
        current_scan_id = msg.get("scan_id", None)
        if current_scan_id is None:
            return

        if current_scan_id != self.scan_id:
            self.scan_data = self.queue.scan_storage.find_scan_by_ID(current_scan_id)
            if not self.scan_data:
                print(f"No data found for scan_id: {current_scan_id}")  # TODO better error
                return
            self.scan_id = current_scan_id
            self.flush()

        # Update the database with new data
        self.update_database_with_scan_data(msg)

        # Emit signal to update plot
        # TODO could be emitted per plot name from update_database_with_scan_data
        self.update_signal.emit()

    @staticmethod
    def _find_reading(data: dict, signal: dict):
        """Return ``data[name][entry]`` for one signal config, or None if it is absent.

        ``entry`` defaults to the signal ``name`` when the config omits it.
        """
        name = signal["name"]
        return data.get(name, {}).get(signal.get("entry", name))

    def update_database_with_scan_data(self, msg):
        """
        Update the database with data from the new scan segment.

        A point is recorded for a plot only when its x, y and z readings are all
        present in the segment. This keeps the three value lists aligned.

        Args:
            msg (dict): Message containing the new scan data under ``"data"``.
        """
        data = msg.get("data", {})
        for plot_config in self.plot_data:
            plot_name = plot_config.get("plot_name", "")
            axis_signals = {axis: plot_config["signals"][axis][0] for axis in ("x", "y", "z")}
            readings = {
                axis: self._find_reading(data, signal) for axis, signal in axis_signals.items()
            }
            if any(reading is None for reading in readings.values()):
                continue

            for axis, signal in axis_signals.items():
                self.database[plot_name][axis][signal["name"]].append(readings[axis]["value"])

    def update_plot(self, *_args):
        """
        Update the plots with the latest data from the database.

        Args:
            *_args: Ignored; accepts the argument tuple forwarded by pg.SignalProxy.
        """
        for plot_name, scatter_plot in self.scatterPlots.items():
            plot_db = self.database[plot_name]
            x_values = next(iter(plot_db["x"].values()), [])
            y_values = next(iter(plot_db["y"].values()), [])
            z_values = next(iter(plot_db["z"].values()), [])

            if not (x_values and y_values and z_values):
                continue

            z_array = np.asarray(z_values, dtype=float)
            z_min, z_max = z_array.min(), z_array.max()
            if z_max == z_min:
                # No range to normalise over (avoid division by zero): default brush.
                scatter_plot.setData(x=x_values, y=y_values)
                continue

            # Normalise z to [0, 1] and map it through the configured colormap.
            z_norm = (z_array - z_min) / (z_max - z_min)
            colormap = pg.colormap.get(self.colormap)
            colors = [colormap.map(value) for value in z_norm]
            scatter_plot.setData(x=x_values, y=y_values, brush=colors)
-
348
-
349
if __name__ == "__main__":  # pragma: no cover
    import argparse
    import json
    import sys

    parser = argparse.ArgumentParser()
    parser.add_argument("--config_file", help="Path to the config file.")
    parser.add_argument("--config", help="Config as a JSON string.")
    parser.add_argument("--id", help="GUI ID.")
    args = parser.parse_args()

    if args.config is not None:
        # Inline JSON config takes precedence over a config file.
        config = json.loads(args.config)
    elif args.config_file is not None:
        # Load config from file
        config = yaml_dialog.load_yaml(args.config_file)
    else:
        config = CONFIG_DEFAULT

    client = BECDispatcher().client
    client.start()
    app = QApplication(sys.argv)
    monitor = BECMonitor2DScatter(config=config, gui_id=args.id, skip_validation=True)
    monitor.show()
    sys.exit(app.exec())
@@ -1,162 +0,0 @@
1
- # pylint: disable=missing-module-docstring, missing-function-docstring
2
- from collections import defaultdict
3
- from unittest.mock import MagicMock
4
-
5
- import pytest
6
- from qtpy import QtGui
7
-
8
- from bec_widgets.widgets import BECMonitor2DScatter
9
-
10
# Two plots of the same data. Plot 2 swaps the motors, so its axis labels are
# swapped too.
CONFIG_DEFAULT = {
    "plot_settings": {"colormap": "CET-L4", "num_columns": 1},
    "waveform2D": [
        {
            "plot_name": "Waveform 2D Scatter (1)",
            "x_label": "Sam X",
            "y_label": "Sam Y",
            "signals": {
                "x": [{"name": "samx", "entry": "samx"}],
                "y": [{"name": "samy", "entry": "samy"}],
                "z": [{"name": "gauss_bpm", "entry": "gauss_bpm"}],
            },
        },
        {
            "plot_name": "Waveform 2D Scatter (2)",
            "x_label": "Sam Y",
            "y_label": "Sam X",
            "signals": {
                "x": [{"name": "samy", "entry": "samy"}],
                "y": [{"name": "samx", "entry": "samx"}],
                "z": [{"name": "gauss_bpm", "entry": "gauss_bpm"}],
            },
        },
    ],
}
35
-
36
# Single plot of the aperture motors; axis labels name the plotted signals.
CONFIG_ONE_PLOT = {
    "plot_settings": {"colormap": "CET-L4", "num_columns": 1},
    "waveform2D": [
        {
            "plot_name": "Waveform 2D Scatter (1)",
            "x_label": "Aptr X",
            "y_label": "Aptr Y",
            "signals": {
                "x": [{"name": "aptrx", "entry": "aptrx"}],
                "y": [{"name": "aptry", "entry": "aptry"}],
                "z": [{"name": "gauss_bpm", "entry": "gauss_bpm"}],
            },
        }
    ],
}
51
-
52
-
53
@pytest.fixture(scope="function")
def monitor_2Dscatter(qtbot):
    """Yield a BECMonitor2DScatter created without a config and with a mocked BEC client."""
    client = MagicMock()
    widget = BECMonitor2DScatter(client=client)
    qtbot.addWidget(widget)
    # waitExposed is a context manager: it waits for the show() issued inside it.
    # Called bare it neither shows the widget nor waits.
    with qtbot.waitExposed(widget):
        widget.show()
    yield widget
60
-
61
-
62
@pytest.mark.parametrize("config, number_of_plots", [(CONFIG_DEFAULT, 2), (CONFIG_ONE_PLOT, 1)])
def test_initialization(monitor_2Dscatter, config, number_of_plots):
    """Loading a config stores it and yields one plot_data entry per configured plot."""
    widget = monitor_2Dscatter
    widget.on_config_update(config)

    assert isinstance(widget, BECMonitor2DScatter)
    assert widget.client is not None
    assert widget.config == config
    assert len(widget.plot_data) == number_of_plots
70
-
71
-
72
@pytest.mark.parametrize("config", [CONFIG_DEFAULT, CONFIG_ONE_PLOT])
def test_database_initialization(monitor_2Dscatter, config):
    """The database is a nested defaultdict that yields an empty list per plot/axis/signal."""
    monitor_2Dscatter.on_config_update(config)
    database = monitor_2Dscatter.database
    assert isinstance(database, defaultdict)

    # Accessing any plot/axis/signal path must auto-create an empty list.
    for plot_config in config["waveform2D"]:
        plot_name = plot_config["plot_name"]
        for axis in ["x", "y", "z"]:
            for signal in plot_config["signals"][axis]:
                signal_name = signal["name"]
                assert not database[plot_name][axis][signal_name]
                assert isinstance(database[plot_name][axis][signal_name], list)

    # Check the nesting only now. Right after on_config_update the database is
    # empty, so these loops would run zero times and assert nothing.
    assert set(database) == {plot_config["plot_name"] for plot_config in config["waveform2D"]}
    for axes_by_name in database.values():
        assert isinstance(axes_by_name, defaultdict)
        for signals_by_name in axes_by_name.values():
            assert isinstance(signals_by_name, defaultdict)
91
-
92
-
93
@pytest.mark.parametrize("config", [CONFIG_DEFAULT, CONFIG_ONE_PLOT])
def test_ui_initialization(monitor_2Dscatter, config):
    """One titled plot and one scatter item are created per configured plot."""
    monitor_2Dscatter.on_config_update(config)
    assert len(monitor_2Dscatter.plots) == len(config["waveform2D"])
    for plot_config in config["waveform2D"]:
        plot_name = plot_config["plot_name"]
        assert plot_name in monitor_2Dscatter.plots
        plot = monitor_2Dscatter.plots[plot_name]
        assert plot.titleLabel.text == plot_name
    # Every plot must also have its scatter item container.
    assert set(monitor_2Dscatter.scatterPlots) == set(monitor_2Dscatter.plots)
102
-
103
-
104
def simulate_scan_data(monitor, x_value, y_value, z_value):
    """Send one scan segment (scan_id 1) with samx/samy/gauss_bpm readings to *monitor*."""
    readings = {"samx": x_value, "samy": y_value, "gauss_bpm": z_value}
    segment = {
        "data": {device: {device: {"value": value}} for device, value in readings.items()},
        "scan_id": 1,
    }
    monitor.on_scan_segment(segment, {})
115
-
116
-
117
def test_data_update_and_plotting(monitor_2Dscatter, qtbot):
    """Scan segments are stored per axis and every point reaches the scatter plot."""
    monitor_2Dscatter.on_config_update(CONFIG_DEFAULT)
    data_sets = [(1, 4, 7), (2, 5, 8), (3, 6, 9)]  # (x, y, z) tuples
    plot_name = "Waveform 2D Scatter (1)"

    for point in data_sets:
        simulate_scan_data(monitor_2Dscatter, *point)
        qtbot.wait(100)  # Wait for the plot to update

    # The scatter item must hold one spot per sent point.
    scatter = monitor_2Dscatter.scatterPlots[plot_name]
    assert len(scatter.data) == len(data_sets)

    def stored(axis):
        # Flatten all signal lists recorded for one axis of the plot.
        by_signal = monitor_2Dscatter.database[plot_name][axis]
        return [value for values in by_signal.values() for value in values]

    stored_x, stored_y, stored_z = stored("x"), stored("y"), stored("z")
    expected_x, expected_y, expected_z = (list(column) for column in zip(*data_sets))

    assert stored_x == expected_x
    assert stored_y == expected_y
    assert stored_z == expected_z
150
-
151
-
152
def test_color_mapping(monitor_2Dscatter, qtbot):
    """Distinct z values are drawn with distinct colormap colours."""
    monitor_2Dscatter.on_config_update(CONFIG_DEFAULT)
    data_sets = [(1, 4, 7), (2, 5, 8), (3, 6, 9)]  # (x, y, z) tuples
    for x, y, z in data_sets:
        simulate_scan_data(monitor_2Dscatter, x, y, z)
        qtbot.wait(100)  # Wait for the plot to update

    scatterPlot = monitor_2Dscatter.scatterPlots["Waveform 2D Scatter (1)"]
    points = scatterPlot.points()
    assert len(points) == len(data_sets)

    colors = [point.brush().color() for point in points]
    assert all(isinstance(color, QtGui.QColor) for color in colors)
    # brush().color() is always a QColor, so the check above proves little on its own.
    # z spans 7..9 and normalises to 0, 0.5, 1, so the colours must all differ.
    assert len({color.rgba() for color in colors}) == len(data_sets)