accusleepy 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1086 @@
1
+ # AccuSleePy manual scoring GUI
2
+ # Icon sources:
3
+ # Arkinasi, https://www.flaticon.com/authors/arkinasi
4
+ # kendis lasman, https://www.flaticon.com/packs/ui-79
5
+
6
+
7
+ import copy
8
+ import os
9
+ from dataclasses import dataclass
10
+ from functools import partial
11
+ from types import SimpleNamespace
12
+
13
+ import matplotlib.pyplot as plt
14
+ import numpy as np
15
+ from PySide6 import QtCore, QtGui, QtWidgets
16
+
17
+ from accusleepy.constants import UNDEFINED_LABEL
18
+ from accusleepy.fileio import load_config, save_labels
19
+ from accusleepy.gui.mplwidget import resample_x_ticks
20
+ from accusleepy.gui.viewer_window import Ui_ViewerWindow
21
+ from accusleepy.signal_processing import create_spectrogram, get_emg_power
22
+
23
# Colors used to draw brain state labels (RGBA rows).
# Row 0 is the "undefined" state and is fully transparent; rows 1-10 are the
# tab10 colors, assigned to the digit keys in keyboard order (1234567890).
LABEL_CMAP = np.vstack((np.zeros((1, 4)), plt.colormaps["tab10"](range(10))))
# user manual text file, relative to the directory containing this module
USER_MANUAL_FILE = "text/manual_scoring_guide.txt"

# ---------------------------------------------------------------------------
# string tokens passed (via functools.partial) to the callback methods
# ---------------------------------------------------------------------------

# output style requested from convert_labels
DISPLAY_FORMAT = "display"
DIGIT_FORMAT = "digit"

# vertical shift of the EEG / EMG traces, and how much one click moves them
OFFSET_UP = "up"
OFFSET_DOWN = "down"
OFFSET_INCREMENTS = {OFFSET_UP: 0.02, OFFSET_DOWN: -0.02}

# growing / shrinking the number of epochs in the lower plot
DIRECTION_PLUS = "plus"
DIRECTION_MINUS = "minus"

# moving the selected epoch
DIRECTION_LEFT = "left"
DIRECTION_RIGHT = "right"

# zoom actions, and the multiplicative y-scale step for the EEG / EMG traces
ZOOM_IN = "in"
ZOOM_OUT = "out"
ZOOM_RESET = "reset"
SIGNAL_ZOOM_FACTORS = {ZOOM_IN: 1.08, ZOOM_OUT: 0.95}

# which trace a signal control applies to
EEG_SIGNAL = "eeg"
EMG_SIGNAL = "emg"

# spectrogram color range adjustments
BRIGHTER = "brighter"
DIMMER = "dimmer"

# what jump_to_next_state searches for
DIFFERENT_STATE = "different"
UNDEFINED_STATE = "undefined"

# Fraction of the upper plot's width, measured from either edge, that the
# epoch marker may enter before the plot starts scrolling. Must lie in (0, 0.5).
SCROLL_BOUNDARY = 0.35

# maximum number of label changes remembered for undo
UNDO_LIMIT = 100
65
+
66
+
67
@dataclass
class StateChange:
    """Information about an event when brain state labels were changed

    Stored in the manual scoring window's undo / redo history. Applying the
    change writes ``new_labels`` into the label array starting at ``epoch``;
    reverting it writes ``previous_labels`` there instead.
    """

    previous_labels: np.ndarray  # old brain state labels
    new_labels: np.ndarray  # new brain state labels
    epoch: int  # first epoch affected
74
+
75
+
76
+ class ManualScoringWindow(QtWidgets.QDialog):
77
+ """AccuSleePy manual scoring GUI"""
78
+
79
+ def __init__(
80
+ self,
81
+ eeg: np.array,
82
+ emg: np.array,
83
+ label_file: str,
84
+ labels: np.array,
85
+ sampling_rate: int | float,
86
+ epoch_length: int | float,
87
+ ):
88
+ """Initialize the manual scoring window
89
+
90
+ :param eeg: EEG signal
91
+ :param emg: EMG signal
92
+ :param label_file: filename for labels
93
+ :param labels: brain state labels
94
+ :param sampling_rate: sampling rate, in Hz
95
+ :param epoch_length: epoch length, in seconds
96
+ """
97
+ super(ManualScoringWindow, self).__init__()
98
+
99
+ self.label_file = label_file
100
+ self.eeg = eeg
101
+ self.emg = emg
102
+ self.labels = labels
103
+ self.sampling_rate = sampling_rate
104
+ self.epoch_length = epoch_length
105
+
106
+ self.n_epochs = len(self.labels)
107
+
108
+ # initialize the UI
109
+ self.ui = Ui_ViewerWindow()
110
+ self.ui.setupUi(self)
111
+ self.setWindowTitle("AccuSleePy manual scoring window")
112
+
113
+ # load set of valid brain states
114
+ self.brain_state_set = load_config()
115
+
116
+ # initial setting for number of epochs to show in the lower plot
117
+ self.epochs_to_show = 5
118
+
119
+ # find the set of y-axis locations of valid brain state labels
120
+ self.label_display_options = convert_labels(
121
+ np.array([b.digit for b in self.brain_state_set.brain_states]),
122
+ style=DISPLAY_FORMAT,
123
+ )
124
+ self.smallest_display_label = np.min(self.label_display_options)
125
+
126
+ self.ui.upperfigure.epoch_length = self.epoch_length
127
+ self.ui.lowerfigure.epoch_length = self.epoch_length
128
+
129
+ # get EEG spectrogram and its frequency axis
130
+ spectrogram, spectrogram_frequencies = create_spectrogram(
131
+ self.eeg, self.sampling_rate, self.epoch_length
132
+ )
133
+
134
+ # calculate RMS of EMG for each epoch and apply a ceiling
135
+ self.upper_emg = create_upper_emg_signal(
136
+ self.emg, self.sampling_rate, self.epoch_length
137
+ )
138
+
139
+ # center and scale the EEG and EMG signals to fit the display
140
+ self.eeg, self.emg = transform_eeg_emg(self.eeg, self.emg)
141
+
142
+ # convert labels to "display" format and make an image to display them
143
+ self.display_labels = convert_labels(self.labels, DISPLAY_FORMAT)
144
+ self.label_img = create_label_img(
145
+ self.display_labels, self.label_display_options
146
+ )
147
+
148
+ # history of changes to the brain state labels
149
+ self.history = list()
150
+ # index of the change "ahead" of the current state
151
+ # i.e., which change will be applied by a "redo" action
152
+ self.history_index = 0
153
+
154
+ # set up both figures
155
+ self.ui.upperfigure.setup_upper_figure(
156
+ self.n_epochs,
157
+ self.label_img,
158
+ spectrogram,
159
+ spectrogram_frequencies,
160
+ self.upper_emg,
161
+ self.epochs_to_show,
162
+ self.label_display_options,
163
+ self.brain_state_set,
164
+ self.roi_callback,
165
+ )
166
+ self.ui.lowerfigure.setup_lower_figure(
167
+ self.label_img,
168
+ self.sampling_rate,
169
+ self.epochs_to_show,
170
+ self.brain_state_set,
171
+ self.label_display_options,
172
+ )
173
+
174
+ # initialize values that can be changed by user input
175
+ self.epoch = 0
176
+ self.upper_left_epoch = 0
177
+ self.upper_right_epoch = self.n_epochs - 1
178
+ self.lower_left_epoch = 0
179
+ self.lower_right_epoch = self.epochs_to_show - 1
180
+ self.eeg_signal_scale_factor = 1
181
+ self.emg_signal_scale_factor = 1
182
+ self.eeg_signal_offset = 0
183
+ self.emg_signal_offset = 0
184
+ self.roi_brain_state = 0
185
+ self.label_roi_mode = False
186
+ self.autoscroll_state = False
187
+ # keep track of save state to warn user when they quit
188
+ self.last_saved_labels = copy.deepcopy(self.labels)
189
+
190
+ # populate the lower figure
191
+ self.update_lower_figure()
192
+
193
+ # user input: keyboard shortcuts
194
+ keypress_right = QtGui.QShortcut(
195
+ QtGui.QKeySequence(QtCore.Qt.Key.Key_Right), self
196
+ )
197
+ keypress_right.activated.connect(partial(self.shift_epoch, DIRECTION_RIGHT))
198
+
199
+ keypress_left = QtGui.QShortcut(
200
+ QtGui.QKeySequence(QtCore.Qt.Key.Key_Left), self
201
+ )
202
+ keypress_left.activated.connect(partial(self.shift_epoch, DIRECTION_LEFT))
203
+
204
+ keypress_zoom_in_x = list()
205
+ for zoom_key in [QtCore.Qt.Key.Key_Plus, QtCore.Qt.Key.Key_Equal]:
206
+ keypress_zoom_in_x.append(
207
+ QtGui.QShortcut(QtGui.QKeySequence(zoom_key), self)
208
+ )
209
+ keypress_zoom_in_x[-1].activated.connect(partial(self.zoom_x, ZOOM_IN))
210
+
211
+ keypress_zoom_out_x = QtGui.QShortcut(
212
+ QtGui.QKeySequence(QtCore.Qt.Key.Key_Minus), self
213
+ )
214
+ keypress_zoom_out_x.activated.connect(partial(self.zoom_x, ZOOM_OUT))
215
+
216
+ keypress_modify_label = list()
217
+ for brain_state in self.brain_state_set.brain_states:
218
+ keypress_modify_label.append(
219
+ QtGui.QShortcut(
220
+ QtGui.QKeySequence(QtCore.Qt.Key[f"Key_{brain_state.digit}"]),
221
+ self,
222
+ )
223
+ )
224
+ keypress_modify_label[-1].activated.connect(
225
+ partial(self.modify_current_epoch_label, brain_state.digit)
226
+ )
227
+
228
+ keypress_delete_label = QtGui.QShortcut(
229
+ QtGui.QKeySequence(QtCore.Qt.Key.Key_Backspace), self
230
+ )
231
+ keypress_delete_label.activated.connect(
232
+ partial(self.modify_current_epoch_label, UNDEFINED_LABEL)
233
+ )
234
+
235
+ keypress_quit = QtGui.QShortcut(
236
+ QtGui.QKeySequence(
237
+ QtCore.QKeyCombination(QtCore.Qt.Modifier.CTRL, QtCore.Qt.Key.Key_W)
238
+ ),
239
+ self,
240
+ )
241
+ keypress_quit.activated.connect(self.close)
242
+
243
+ keypress_save = QtGui.QShortcut(
244
+ QtGui.QKeySequence(
245
+ QtCore.QKeyCombination(QtCore.Qt.Modifier.CTRL, QtCore.Qt.Key.Key_S)
246
+ ),
247
+ self,
248
+ )
249
+ keypress_save.activated.connect(self.save)
250
+
251
+ keypress_roi = list()
252
+ for brain_state in self.brain_state_set.brain_states:
253
+ keypress_roi.append(
254
+ QtGui.QShortcut(
255
+ QtGui.QKeySequence(
256
+ QtCore.QKeyCombination(
257
+ QtCore.Qt.Modifier.SHIFT,
258
+ QtCore.Qt.Key[f"Key_{brain_state.digit}"],
259
+ )
260
+ ),
261
+ self,
262
+ )
263
+ )
264
+ keypress_roi[-1].activated.connect(
265
+ partial(self.enter_label_roi_mode, brain_state.digit)
266
+ )
267
+ keypress_roi.append(
268
+ QtGui.QShortcut(
269
+ QtGui.QKeySequence(
270
+ QtCore.QKeyCombination(
271
+ QtCore.Qt.Modifier.SHIFT,
272
+ QtCore.Qt.Key.Key_Backspace,
273
+ )
274
+ ),
275
+ self,
276
+ )
277
+ )
278
+ keypress_roi[-1].activated.connect(
279
+ partial(self.enter_label_roi_mode, UNDEFINED_LABEL)
280
+ )
281
+
282
+ keypress_esc = QtGui.QShortcut(
283
+ QtGui.QKeySequence(QtCore.Qt.Key.Key_Escape), self
284
+ )
285
+ keypress_esc.activated.connect(self.exit_label_roi_mode)
286
+
287
+ keypress_space = QtGui.QShortcut(
288
+ QtGui.QKeySequence(QtCore.Qt.Key.Key_Space), self
289
+ )
290
+ keypress_space.activated.connect(
291
+ partial(self.jump_to_next_state, DIRECTION_RIGHT, DIFFERENT_STATE)
292
+ )
293
+ keypress_shift_right = QtGui.QShortcut(
294
+ QtGui.QKeySequence(
295
+ QtCore.QKeyCombination(
296
+ QtCore.Qt.Modifier.SHIFT,
297
+ QtCore.Qt.Key.Key_Right,
298
+ )
299
+ ),
300
+ self,
301
+ )
302
+ keypress_shift_right.activated.connect(
303
+ partial(self.jump_to_next_state, DIRECTION_RIGHT, DIFFERENT_STATE)
304
+ )
305
+ keypress_shift_left = QtGui.QShortcut(
306
+ QtGui.QKeySequence(
307
+ QtCore.QKeyCombination(
308
+ QtCore.Qt.Modifier.SHIFT,
309
+ QtCore.Qt.Key.Key_Left,
310
+ )
311
+ ),
312
+ self,
313
+ )
314
+ keypress_shift_left.activated.connect(
315
+ partial(self.jump_to_next_state, DIRECTION_LEFT, DIFFERENT_STATE)
316
+ )
317
+ keypress_ctrl_right = QtGui.QShortcut(
318
+ QtGui.QKeySequence(
319
+ QtCore.QKeyCombination(
320
+ QtCore.Qt.Modifier.CTRL,
321
+ QtCore.Qt.Key.Key_Right,
322
+ )
323
+ ),
324
+ self,
325
+ )
326
+ keypress_ctrl_right.activated.connect(
327
+ partial(self.jump_to_next_state, DIRECTION_RIGHT, UNDEFINED_STATE)
328
+ )
329
+ keypress_ctrl_left = QtGui.QShortcut(
330
+ QtGui.QKeySequence(
331
+ QtCore.QKeyCombination(
332
+ QtCore.Qt.Modifier.CTRL,
333
+ QtCore.Qt.Key.Key_Left,
334
+ )
335
+ ),
336
+ self,
337
+ )
338
+ keypress_ctrl_left.activated.connect(
339
+ partial(self.jump_to_next_state, DIRECTION_LEFT, UNDEFINED_STATE)
340
+ )
341
+
342
+ keypress_undo = QtGui.QShortcut(
343
+ QtGui.QKeySequence(
344
+ QtCore.QKeyCombination(QtCore.Qt.Modifier.CTRL, QtCore.Qt.Key.Key_Z)
345
+ ),
346
+ self,
347
+ )
348
+ keypress_undo.activated.connect(self.undo)
349
+ keypress_redo = QtGui.QShortcut(
350
+ QtGui.QKeySequence(
351
+ QtCore.QKeyCombination(QtCore.Qt.Modifier.CTRL, QtCore.Qt.Key.Key_Y)
352
+ ),
353
+ self,
354
+ )
355
+ keypress_redo.activated.connect(self.redo)
356
+
357
+ # user input: clicks
358
+ self.ui.upperfigure.canvas.mpl_connect("button_press_event", self.click_to_jump)
359
+
360
+ # user input: buttons
361
+ self.ui.savebutton.clicked.connect(self.save)
362
+ self.ui.xzoomin.clicked.connect(partial(self.zoom_x, ZOOM_IN))
363
+ self.ui.xzoomout.clicked.connect(partial(self.zoom_x, ZOOM_OUT))
364
+ self.ui.xzoomreset.clicked.connect(partial(self.zoom_x, ZOOM_RESET))
365
+ self.ui.autoscroll.stateChanged.connect(self.update_autoscroll_state)
366
+ self.ui.eegzoomin.clicked.connect(
367
+ partial(self.update_signal_zoom, EEG_SIGNAL, ZOOM_IN)
368
+ )
369
+ self.ui.eegzoomout.clicked.connect(
370
+ partial(self.update_signal_zoom, EEG_SIGNAL, ZOOM_OUT)
371
+ )
372
+ self.ui.emgzoomin.clicked.connect(
373
+ partial(self.update_signal_zoom, EMG_SIGNAL, ZOOM_IN)
374
+ )
375
+ self.ui.emgzoomout.clicked.connect(
376
+ partial(self.update_signal_zoom, EMG_SIGNAL, ZOOM_OUT)
377
+ )
378
+ self.ui.eegshiftup.clicked.connect(
379
+ partial(self.update_signal_offset, EEG_SIGNAL, OFFSET_UP)
380
+ )
381
+ self.ui.eegshiftdown.clicked.connect(
382
+ partial(self.update_signal_offset, EEG_SIGNAL, OFFSET_DOWN)
383
+ )
384
+ self.ui.emgshiftup.clicked.connect(
385
+ partial(self.update_signal_offset, EMG_SIGNAL, OFFSET_UP)
386
+ )
387
+ self.ui.emgshiftdown.clicked.connect(
388
+ partial(self.update_signal_offset, EMG_SIGNAL, OFFSET_DOWN)
389
+ )
390
+ self.ui.shownepochsplus.clicked.connect(
391
+ partial(self.update_epochs_shown, DIRECTION_PLUS)
392
+ )
393
+ self.ui.shownepochsminus.clicked.connect(
394
+ partial(self.update_epochs_shown, DIRECTION_MINUS)
395
+ )
396
+ self.ui.specbrighter.clicked.connect(
397
+ partial(self.update_spectrogram_brightness, BRIGHTER)
398
+ )
399
+ self.ui.specdimmer.clicked.connect(
400
+ partial(self.update_spectrogram_brightness, DIMMER)
401
+ )
402
+ self.ui.helpbutton.clicked.connect(self.show_user_manual)
403
+
404
+ self.show()
405
+
406
+ def add_to_history(self, state_change: StateChange) -> None:
407
+ """Add an event to the history of changes to brain state labels
408
+
409
+ This allows the user to undo / redo changes to the brain state
410
+ labels by navigating backwards or forwards through a list of
411
+ changes that have been made. If one or more changes are undone,
412
+ and then a new change is performed, it will not be possible to
413
+ redo the changes that were undone. At most UNDO_LIMIT changes
414
+ are stored.
415
+
416
+ :param state_change: description of the change to the labels
417
+ """
418
+ # if history is empty
419
+ if len(self.history) == 0:
420
+ self.history.append(state_change)
421
+ self.history_index = 1
422
+ return
423
+ # if we are not at the end of the history
424
+ if self.history_index < len(self.history):
425
+ # remove events after the most recent one
426
+ self.history = self.history[: self.history_index]
427
+ self.history.append(state_change)
428
+ self.history_index += 1
429
+ else:
430
+ self.history.append(state_change)
431
+ # if this would make the history list too long
432
+ if self.history_index == UNDO_LIMIT:
433
+ # remove the oldest entry
434
+ self.history = self.history[1:]
435
+ else:
436
+ self.history_index += 1
437
+
438
+ def force_modify_labels(self, epoch: int, new_labels: np.array) -> None:
439
+ """Change brain state labels for an undo/redo action"""
440
+ # make the change
441
+ self.labels[epoch : epoch + len(new_labels)] = new_labels
442
+ self.display_labels = convert_labels(
443
+ self.labels,
444
+ style=DISPLAY_FORMAT,
445
+ )
446
+ self.label_img = create_label_img(
447
+ self.display_labels, self.label_display_options
448
+ )
449
+ # update the plots
450
+ self.update_figures()
451
+
452
+ def redo(self) -> None:
453
+ """Redo the last change to brain state labels that was undone"""
454
+ # if there are no events to redo
455
+ if self.history_index == len(self.history):
456
+ return
457
+ # make the change
458
+ state_change = self.history[self.history_index]
459
+ self.force_modify_labels(
460
+ epoch=state_change.epoch, new_labels=state_change.new_labels
461
+ )
462
+ # update history index
463
+ self.history_index += 1
464
+ # shift the cursor
465
+ simulated_click = SimpleNamespace(
466
+ **{"xdata": state_change.epoch, "inaxes": None}
467
+ )
468
+ self.click_to_jump(simulated_click)
469
+
470
+ def undo(self) -> None:
471
+ """Undo the last change to the brain state labels"""
472
+ # if there are no events to undo
473
+ if self.history_index == 0:
474
+ return
475
+ # make the change
476
+ state_change = self.history[self.history_index - 1]
477
+ self.force_modify_labels(
478
+ epoch=state_change.epoch, new_labels=state_change.previous_labels
479
+ )
480
+ # update history index
481
+ self.history_index -= 1
482
+ # shift the cursor
483
+ simulated_click = SimpleNamespace(
484
+ **{"xdata": state_change.epoch, "inaxes": None}
485
+ )
486
+ self.click_to_jump(simulated_click)
487
+
488
+ def closeEvent(self, event: QtGui.QCloseEvent) -> None:
489
+ """Check if there are unsaved changes before closing"""
490
+ if not all(self.labels == self.last_saved_labels):
491
+ result = QtWidgets.QMessageBox.question(
492
+ self,
493
+ "Unsaved changes",
494
+ "You have unsaved changes. Really quit?",
495
+ QtWidgets.QMessageBox.Yes | QtWidgets.QMessageBox.No,
496
+ )
497
+ if result == QtWidgets.QMessageBox.Yes:
498
+ event.accept()
499
+ else:
500
+ event.ignore()
501
+
502
+ def show_user_manual(self) -> None:
503
+ """Show a popup window with the user manual"""
504
+ user_manual_file = open(
505
+ os.path.join(os.path.dirname(os.path.abspath(__file__)), USER_MANUAL_FILE),
506
+ "r",
507
+ )
508
+ user_manual_text = user_manual_file.read()
509
+ user_manual_file.close()
510
+ label_widget = QtWidgets.QLabel()
511
+ label_widget.setText(user_manual_text)
512
+ label_widget.setStyleSheet("background-color: white;")
513
+
514
+ self.popup = QtWidgets.QWidget()
515
+ grid = QtWidgets.QGridLayout()
516
+ grid.addWidget(label_widget)
517
+ self.popup.setLayout(grid)
518
+ self.popup.setGeometry(QtCore.QRect(50, 100, 350, 400))
519
+ self.popup.show()
520
+
521
+ def jump_to_next_state(self, direction: str, target: str) -> None:
522
+ """Jump to epoch based on a target brain state
523
+
524
+ This allows the user to jump to the next epoch in a given direction
525
+ (left or right) that has a given state (undefined, or different from
526
+ the current epoch). It's useful for reviewing state transitions or
527
+ locating unlabeled epochs.
528
+
529
+ :param direction: left or right
530
+ :param target: different or undefined
531
+ """
532
+ # create a simulated click so we can reuse click_to_jump
533
+ simulated_click = SimpleNamespace(**{"xdata": self.epoch, "inaxes": None})
534
+ if direction == DIRECTION_RIGHT:
535
+ if target == DIFFERENT_STATE:
536
+ matches = np.where(
537
+ self.labels[self.epoch + 1 :] != self.labels[self.epoch]
538
+ )[0]
539
+ else:
540
+ matches = np.where(self.labels[self.epoch + 1 :] == UNDEFINED_LABEL)[0]
541
+ if matches.size > 0:
542
+ simulated_click.xdata = int(matches[0]) + 1 + self.epoch
543
+ else:
544
+ if target == DIFFERENT_STATE:
545
+ matches = np.where(
546
+ self.labels[: self.epoch] != self.labels[self.epoch]
547
+ )[0]
548
+ else:
549
+ matches = np.where(self.labels[: self.epoch] == UNDEFINED_LABEL)[0]
550
+ if matches.size > 0:
551
+ simulated_click.xdata = int(matches[-1])
552
+ self.click_to_jump(simulated_click)
553
+
554
+ def roi_callback(self, eclick, erelease) -> None:
555
+ """Callback for ROI labeling widget
556
+
557
+ This is called by the RectangleSelector widget when the user finishes
558
+ drawing an ROI. It sets a range of epochs to the desired brain state.
559
+ The function signature is required to have this format.
560
+ """
561
+ # get range of epochs affected
562
+ epoch_start = int(np.ceil(eclick.xdata))
563
+ epoch_end = int(np.floor(erelease.xdata)) + 1
564
+
565
+ previous_labels = copy.deepcopy(self.labels[epoch_start:epoch_end])
566
+ new_labels = np.ones(epoch_end - epoch_start).astype(int) * self.roi_brain_state
567
+
568
+ # if something changed, track the change
569
+ if not np.array_equal(previous_labels, new_labels):
570
+ self.add_to_history(
571
+ StateChange(
572
+ previous_labels=previous_labels,
573
+ new_labels=new_labels,
574
+ epoch=epoch_start,
575
+ )
576
+ )
577
+ # make the change
578
+ self.labels[epoch_start:epoch_end] = self.roi_brain_state
579
+ self.display_labels = convert_labels(
580
+ self.labels,
581
+ style=DISPLAY_FORMAT,
582
+ )
583
+ self.label_img = create_label_img(
584
+ self.display_labels, self.label_display_options
585
+ )
586
+ # update the plots
587
+ self.update_figures()
588
+ self.exit_label_roi_mode()
589
+
590
+ def exit_label_roi_mode(self) -> None:
591
+ """Restore the normal GUI state after an ROI is drawn"""
592
+ self.ui.upperfigure.roi.set_active(False)
593
+ self.ui.upperfigure.roi.set_visible(False)
594
+ self.ui.upperfigure.roi.update()
595
+ self.label_roi_mode = False
596
+ self.ui.upperfigure.editing_patch.set_visible(False)
597
+
598
+ def enter_label_roi_mode(self, brain_state: int) -> None:
599
+ """Enter ROI drawing mode
600
+
601
+ In this mode, a user can draw an ROI on the upper brain state label
602
+ image to set a range of epochs to a new brain state.
603
+
604
+ :param brain_state: new brain state to set
605
+ """
606
+ self.label_roi_mode = True
607
+ self.roi_brain_state = brain_state
608
+ self.ui.upperfigure.roi_patch.set(
609
+ facecolor=LABEL_CMAP[
610
+ convert_labels(np.array([brain_state]), DISPLAY_FORMAT)
611
+ ]
612
+ )
613
+ self.ui.upperfigure.editing_patch.set_visible(True)
614
+ self.ui.upperfigure.canvas.draw()
615
+ self.ui.upperfigure.roi.set_active(True)
616
+
617
+ def save(self) -> None:
618
+ """Save brain state labels to file"""
619
+ save_labels(self.labels, self.label_file)
620
+ self.last_saved_labels = copy.deepcopy(self.labels)
621
+
622
+ def update_spectrogram_brightness(self, direction: str) -> None:
623
+ """Modify spectrogram color range based on button press
624
+
625
+ :param direction: brighter or dimmer
626
+ """
627
+ vmin, vmax = self.ui.upperfigure.spec_ref.get_clim()
628
+ if direction == BRIGHTER:
629
+ self.ui.upperfigure.spec_ref.set(clim=(vmin, vmax * 0.96))
630
+ else:
631
+ self.ui.upperfigure.spec_ref.set(clim=(vmin, vmax * 1.07))
632
+ self.ui.upperfigure.canvas.draw()
633
+
634
+ def update_epochs_shown(self, direction: str) -> None:
635
+ """Change the number of epochs shown based on button press
636
+
637
+ The user can change the number of epochs shown in the lower figure
638
+ via button presses. This requires extensive changes to both figures.
639
+ The number of epochs can only change in increments of 2 and should
640
+ always be an odd number >= 3.
641
+
642
+ :param direction: plus or minus
643
+ """
644
+ # if we are near the beginning or end of the recording, we need
645
+ # to change the epoch range differently.
646
+ if direction == DIRECTION_PLUS:
647
+ self.epochs_to_show += 2
648
+ if self.lower_left_epoch == 0:
649
+ self.lower_right_epoch += 2
650
+ elif self.lower_right_epoch == self.n_epochs - 1:
651
+ self.lower_left_epoch -= 2
652
+ else:
653
+ self.lower_left_epoch -= 1
654
+ self.lower_right_epoch += 1
655
+ else:
656
+ if self.epochs_to_show > 3:
657
+ self.epochs_to_show -= 2
658
+ if self.lower_left_epoch == 0:
659
+ self.lower_right_epoch -= 2
660
+ elif self.lower_right_epoch == self.n_epochs - 1:
661
+ self.lower_left_epoch += 2
662
+ else:
663
+ self.lower_left_epoch += 1
664
+ self.lower_right_epoch -= 1
665
+
666
+ self.ui.shownepochslabel.setText(str(self.epochs_to_show))
667
+
668
+ # rebuild lower figure from scratch
669
+ self.ui.lowerfigure.canvas.figure.clf()
670
+ self.ui.lowerfigure.setup_lower_figure(
671
+ self.label_img,
672
+ self.sampling_rate,
673
+ self.epochs_to_show,
674
+ self.brain_state_set,
675
+ self.label_display_options,
676
+ )
677
+ self.update_figures()
678
+
679
+ def update_signal_offset(self, signal: str, direction: str) -> None:
680
+ """Shift EEG or EMG up or down
681
+
682
+ :param signal: eeg or emg
683
+ :param direction: up or down
684
+ """
685
+ if signal == EEG_SIGNAL:
686
+ self.eeg_signal_offset += OFFSET_INCREMENTS[direction]
687
+ else:
688
+ self.emg_signal_offset += OFFSET_INCREMENTS[direction]
689
+ self.update_lower_figure()
690
+
691
+ def update_signal_zoom(self, signal: str, direction: str) -> None:
692
+ """Zoom EEG or EMG y-axis
693
+
694
+ :param signal: eeg or emg
695
+ :param direction: in or out
696
+ """
697
+ if signal == EEG_SIGNAL:
698
+ self.eeg_signal_scale_factor *= SIGNAL_ZOOM_FACTORS[direction]
699
+ else:
700
+ self.emg_signal_scale_factor *= SIGNAL_ZOOM_FACTORS[direction]
701
+ self.update_lower_figure()
702
+
703
+ def update_autoscroll_state(self, checked) -> None:
704
+ """Toggle autoscroll behavior
705
+
706
+ If autoscroll is enabled, setting the brain state of the current epoch
707
+ via a keypress will advance to the next epoch.
708
+
709
+ :param checked: state of the checkbox
710
+ """
711
+ self.autoscroll_state = checked
712
+
713
+ def adjust_upper_figure_x_limits(self) -> None:
714
+ """Update the x-axis limits of the upper figure subplots"""
715
+ for i in [0, 1, 3]:
716
+ self.ui.upperfigure.canvas.axes[i].set_xlim(
717
+ (self.upper_left_epoch - 0.5, self.upper_right_epoch + 0.5)
718
+ )
719
+ self.ui.upperfigure.canvas.axes[2].set_xlim(
720
+ (self.upper_left_epoch, self.upper_right_epoch + 1)
721
+ )
722
+ self.ui.upperfigure.canvas.draw()
723
+
724
+ def zoom_x(self, direction: str) -> None:
725
+ """Change upper figure x-axis zoom level
726
+
727
+ :param direction: in, out, or reset
728
+ """
729
+ zoom_in_factor = 0.45
730
+ zoom_out_factor = 1.017
731
+ epochs_shown = self.upper_right_epoch - self.upper_left_epoch + 1
732
+ if direction == ZOOM_IN:
733
+ self.upper_left_epoch = max(
734
+ [
735
+ self.upper_left_epoch,
736
+ round(self.epoch - zoom_in_factor * epochs_shown),
737
+ ]
738
+ )
739
+
740
+ self.upper_right_epoch = min(
741
+ [
742
+ self.upper_right_epoch,
743
+ round(self.epoch + zoom_in_factor * epochs_shown),
744
+ ]
745
+ )
746
+
747
+ elif direction == ZOOM_OUT:
748
+ self.upper_left_epoch = max(
749
+ [0, round(self.epoch - zoom_out_factor * epochs_shown)]
750
+ )
751
+
752
+ self.upper_right_epoch = min(
753
+ [self.n_epochs - 1, round(self.epoch + zoom_out_factor * epochs_shown)]
754
+ )
755
+
756
+ else: # reset
757
+ self.upper_left_epoch = 0
758
+ self.upper_right_epoch = self.n_epochs - 1
759
+ self.adjust_upper_figure_x_limits()
760
+
761
+ def modify_current_epoch_label(self, digit: int) -> None:
762
+ """Change the current epoch's brain state label
763
+
764
+ :param digit: new brain state label in "digit" format
765
+ """
766
+ previous_label = self.labels[self.epoch]
767
+ self.labels[self.epoch] = digit
768
+ # if something changed, track the change
769
+ if previous_label != digit:
770
+ self.add_to_history(
771
+ StateChange(
772
+ previous_labels=np.array([previous_label]),
773
+ new_labels=np.array([digit]),
774
+ epoch=self.epoch,
775
+ )
776
+ )
777
+
778
+ # make the change
779
+ display_label = convert_labels(
780
+ np.array([digit]),
781
+ style=DISPLAY_FORMAT,
782
+ )[0]
783
+ self.display_labels[self.epoch] = display_label
784
+ # update the label image
785
+ if display_label == 0:
786
+ self.label_img[:, self.epoch] = np.array([0, 0, 0, 1])
787
+ else:
788
+ self.label_img[:, self.epoch, :] = 1
789
+ self.label_img[
790
+ display_label - self.smallest_display_label, self.epoch, :
791
+ ] = LABEL_CMAP[display_label]
792
+ # autoscroll, if that is enabled
793
+ if self.autoscroll_state and self.epoch < self.n_epochs - 1:
794
+ self.shift_epoch(DIRECTION_RIGHT) # this calls update_figures()
795
+ else:
796
+ self.update_figures()
797
+
798
+ def shift_epoch(self, direction: str) -> None:
799
+ """Set the current epoch one step forward or backward
800
+
801
+ When the user presses the left or right arrow key, the previous
802
+ or next epoch will be selected. There are a variety of edge cases
803
+ that need to be handled separately for the upper and lower figures.
804
+
805
+ :param direction: left or right
806
+ """
807
+ shift_amount = {DIRECTION_LEFT: -1, DIRECTION_RIGHT: 1}[direction]
808
+ # prevent movement outside the data range
809
+ if not (0 <= (self.epoch + shift_amount) < self.n_epochs):
810
+ return
811
+
812
+ # shift to new epoch
813
+ self.epoch = self.epoch + shift_amount
814
+
815
+ # update upper plot if needed
816
+ upper_epochs_shown = self.upper_right_epoch - self.upper_left_epoch + 1
817
+ if (
818
+ self.epoch
819
+ > self.upper_left_epoch + (1 - SCROLL_BOUNDARY) * upper_epochs_shown
820
+ and self.upper_right_epoch < (self.n_epochs - 1)
821
+ and direction == DIRECTION_RIGHT
822
+ ):
823
+ self.upper_left_epoch += 1
824
+ self.upper_right_epoch += 1
825
+ self.adjust_upper_figure_x_limits()
826
+ elif (
827
+ self.epoch < self.upper_left_epoch + SCROLL_BOUNDARY * upper_epochs_shown
828
+ and self.upper_left_epoch > 0
829
+ and direction == DIRECTION_LEFT
830
+ ):
831
+ self.upper_left_epoch -= 1
832
+ self.upper_right_epoch -= 1
833
+ self.adjust_upper_figure_x_limits()
834
+
835
+ # update parts of lower plot
836
+ old_window_center = round(self.epochs_to_show / 2) + self.lower_left_epoch
837
+ # change the window bounds if needed
838
+ if self.epoch < old_window_center and self.lower_left_epoch > 0:
839
+ self.lower_left_epoch -= 1
840
+ self.lower_right_epoch -= 1
841
+ elif (
842
+ self.epoch > old_window_center
843
+ and self.lower_right_epoch < self.n_epochs - 1
844
+ ):
845
+ self.lower_left_epoch += 1
846
+ self.lower_right_epoch += 1
847
+
848
+ self.update_figures()
849
+
850
+ def update_upper_marker(self) -> None:
851
+ """Update location of the upper figure's epoch marker"""
852
+ epoch_padding = round((self.epochs_to_show - 1) / 2)
853
+ if self.epoch - epoch_padding < 0:
854
+ left_edge = 0
855
+ right_edge = self.epochs_to_show - 1
856
+ elif self.epoch + epoch_padding > self.n_epochs - 1:
857
+ right_edge = self.n_epochs - 1
858
+ left_edge = self.n_epochs - self.epochs_to_show
859
+ else:
860
+ left_edge = self.epoch - epoch_padding
861
+ right_edge = self.epoch + epoch_padding
862
+ self.ui.upperfigure.upper_marker[0].set_xdata(
863
+ [
864
+ left_edge - 0.5,
865
+ right_edge + 0.5,
866
+ ]
867
+ )
868
+ self.ui.upperfigure.upper_marker[1].set_xdata([self.epoch])
869
+
870
+ def update_lower_epoch_marker(self) -> None:
871
+ """Update location of the lower figure's epoch marker"""
872
+ marker_left = (
873
+ (self.epoch - self.lower_left_epoch)
874
+ * self.epoch_length
875
+ * self.sampling_rate
876
+ )
877
+ marker_right = (
878
+ (1 + self.epoch - self.lower_left_epoch)
879
+ * self.epoch_length
880
+ * self.sampling_rate
881
+ )
882
+ self.ui.lowerfigure.top_marker[0].set_xdata([marker_left, marker_left])
883
+ self.ui.lowerfigure.top_marker[1].set_xdata([marker_left, marker_right])
884
+ self.ui.lowerfigure.top_marker[2].set_xdata([marker_right, marker_right])
885
+ self.ui.lowerfigure.bottom_marker[0].set_xdata([marker_left, marker_left])
886
+ self.ui.lowerfigure.bottom_marker[1].set_xdata([marker_left, marker_right])
887
+ self.ui.lowerfigure.bottom_marker[2].set_xdata([marker_right, marker_right])
888
+
889
+ def update_figures(self) -> None:
890
+ """Update and redraw both figures"""
891
+ # upper figure
892
+ self.update_upper_marker()
893
+ # this step isn't always needed, but it's not too expensive
894
+ self.ui.upperfigure.label_img_ref.set(data=self.label_img)
895
+ self.ui.upperfigure.canvas.draw()
896
+ # lower figure
897
+ self.update_lower_figure()
898
+
899
def update_lower_figure(self) -> None:
    """Update and redraw the lower figure

    Steps:
    - Plot the EEG and EMG samples for epochs ``lower_left_epoch`` through
      ``lower_right_epoch`` (inclusive), after applying each signal's display
      scale factor and offset.
    - Move the lower epoch marker.
    - Refresh the brain state label image for the same epoch range.
    - Relabel the x-axis ticks of ``canvas.axes[1]`` as HH:MM:SS.ss
      timestamps, then redraw the canvas.
    """

    def format_timestamp(seconds) -> str:
        # Round to whole centiseconds *before* splitting into h/m/s. Formatting
        # the remainder (seconds % 60) with two decimals on its own could show
        # e.g. 59.999 s as "00:00:60.00"; this gives "00:01:00.00" instead.
        total_centiseconds = round(round(float(seconds), 2) * 100)
        total_minutes, centiseconds = divmod(total_centiseconds, 6000)
        hours, minutes = divmod(total_minutes, 60)
        return "{:02d}:{:02d}:{:05.2f}".format(hours, minutes, centiseconds / 100)

    # get subset of signals to plot
    first_sample = round(
        self.lower_left_epoch * self.sampling_rate * self.epoch_length
    )
    last_sample = round(
        (self.lower_right_epoch + 1) * self.sampling_rate * self.epoch_length
    )
    eeg = self.eeg[first_sample:last_sample]
    emg = self.emg[first_sample:last_sample]

    # scale and shift as needed
    eeg = eeg * self.eeg_signal_scale_factor + self.eeg_signal_offset
    emg = emg * self.emg_signal_scale_factor + self.emg_signal_offset

    self.update_lower_epoch_marker()

    # replot eeg and emg
    self.ui.lowerfigure.eeg_line.set_ydata(eeg)
    self.ui.lowerfigure.emg_line.set_ydata(emg)

    # replot brain state for the visible epochs only
    self.ui.lowerfigure.label_img_ref.set(
        data=self.label_img[
            :, self.lower_left_epoch : (self.lower_right_epoch + 1), :
        ]
    )
    # update timestamps (tick positions are in epochs; labels are in seconds)
    x_ticks = resample_x_ticks(
        np.arange(self.lower_left_epoch, self.lower_right_epoch + 1)
    )
    self.ui.lowerfigure.canvas.axes[1].set_xticklabels(
        [format_timestamp(x) for x in x_ticks * self.epoch_length]
    )

    self.ui.lowerfigure.canvas.draw()

def click_to_jump(self, event) -> None:
    """Jump to a new epoch when the user clicks on the upper figure

    This is the callback for mouse clicks on the upper figure. Clicking on
    any of the subplots will jump to the nearest epoch.

    The number of epochs visible in the upper figure, and ``epochs_to_show``
    for the lower figure, are preserved. Each window is centered on the new
    epoch where possible and shifted to stay within the recording near its
    ends. Clicks are ignored while in label ROI mode or when the click has
    no x coordinate.

    :param event: a MouseEvent containing the click data
    """

    def centered_window(width: int) -> tuple[int, int]:
        # Place exactly `width` epochs around self.epoch, shifted so the
        # window does not run past either end of the recording. Integer
        # division keeps even widths exact; round((width - 1) / 2) uses
        # banker's rounding and could shrink/grow the window by one epoch.
        # If width exceeds the recording length, the window starts at 0
        # rather than at a negative epoch.
        left = self.epoch - (width - 1) // 2
        left = max(0, min(left, self.n_epochs - width))
        return left, left + width - 1

    # make sure click location is valid
    # and we are not in label ROI mode
    if event.xdata is None or self.label_roi_mode:
        return

    # get click location
    x = event.xdata
    # if it's a real click, and not one we simulated
    if event.inaxes is not None:
        # if it's on the spectrogram, we have to adjust it slightly
        # since that uses a different x-axis range
        ax_index = self.ui.upperfigure.canvas.axes.index(event.inaxes)
        if ax_index == 2:
            x -= 0.5

    # get the "zoom level" so we can preserve that
    upper_epochs_shown = self.upper_right_epoch - self.upper_left_epoch + 1
    # update epoch (plain int so it can be used directly as an index)
    self.epoch = int(round(np.clip(x, 0, self.n_epochs - 1)))

    # update upper figure x-axis limits and marker
    self.upper_left_epoch, self.upper_right_epoch = centered_window(
        upper_epochs_shown
    )
    self.update_upper_marker()
    self.adjust_upper_figure_x_limits()

    # update lower figure x-axis range
    self.lower_left_epoch, self.lower_right_epoch = centered_window(
        self.epochs_to_show
    )
    self.update_lower_figure()

def convert_labels(labels: np.array, style: str) -> np.array:
    """Convert labels between "display" and "digit" formats

    It's useful to represent brain state labels in two ways:
    Digit format: this is how labels are represented in files. It matches the digit
    attribute of the BrainState class as well as the number pressed on the
    keyboard to set an epoch to that brain state.
    Display format: the y-axis value associated with a brain state when brain state
    labels are displayed as an image. This is also the index of the brain state
    in the colormap. Undefined epochs are mapped to 0, and digits are mapped to
    the numbers 1-10 in keyboard order (1234567890).

    :param labels: brain state labels (any array-like; the input is not modified)
    :param style: target format for the output
    :return: formatted labels, as a new array
    :raises ValueError: if style is neither DISPLAY_FORMAT nor DIGIT_FORMAT
        (ValueError subclasses Exception, so existing handlers still catch it)
    """
    labels = np.asarray(labels)
    if style == DISPLAY_FORMAT:
        # digit 0 is the 10th key in keyboard order, so it becomes display 10
        converted = np.where(labels == 0, 10, labels)
        # undefined epochs occupy display row 0
        return np.where(converted == UNDEFINED_LABEL, 0, converted)
    if style == DIGIT_FORMAT:
        # display 0 means undefined; display 10 is the "0" key
        converted = np.where(labels == 0, UNDEFINED_LABEL, labels)
        return np.where(converted == 10, 0, converted)
    raise ValueError(f"style must be '{DISPLAY_FORMAT}' or '{DIGIT_FORMAT}'")

def create_label_img(labels: np.array, label_display_options: np.array) -> np.array:
    """Build the RGBA image used to display brain state labels

    Image layout:
    - One row per display value from ``min(label_display_options)`` to
      ``max(label_display_options)``. Rows outside the range of valid brain
      states would always be empty, so they are not included.
    - One column per epoch.
    - The background is opaque white (all ones).
    - A defined label (> 0) paints its row in the column with
      ``LABEL_CMAP[label]``.
    - An undefined label (<= 0) paints the whole column opaque black.

    :param labels: brain state labels, in "display" format
    :param label_display_options: y-axis locations of valid brain state labels
    :return: brain state label image with shape (n_rows, len(labels), 4)
    """
    lowest_row_label = np.min(label_display_options)
    row_count = np.max(label_display_options) - lowest_row_label + 1
    label_array = np.asarray(labels)

    image = np.ones([row_count, len(label_array), 4])

    is_defined = label_array > 0
    defined_columns = np.flatnonzero(is_defined)
    if defined_columns.size:
        defined_labels = label_array[is_defined]
        image[defined_labels - lowest_row_label, defined_columns, :] = LABEL_CMAP[
            defined_labels
        ]
    # undefined epochs: fill the entire column with opaque black
    image[:, ~is_defined] = np.array([0, 0, 0, 1])
    return image

def create_upper_emg_signal(
    emg: np.array, sampling_rate: int | float, epoch_length: int | float
) -> np.array:
    """Reduce the EMG to one value per epoch and cap outliers

    The per-epoch values come from ``get_emg_power`` (RMS, per the original
    description). They are clipped to the range [0, mean + 2.5 * std] so
    that a few large epochs do not dominate the plot.

    :param emg: EMG signal
    :param sampling_rate: sampling rate, in Hz
    :param epoch_length: epoch length, in seconds
    :return: processed EMG signal
    """
    per_epoch_rms = get_emg_power(emg, sampling_rate, epoch_length)
    ceiling = np.mean(per_epoch_rms) + np.std(per_epoch_rms) * 2.5
    return np.clip(per_epoch_rms, 0, ceiling)

def transform_eeg_emg(
    eeg: np.array, emg: np.array
) -> tuple[np.array, np.array]:
    """Center and scale the EEG and EMG signals

    A heuristic approach to fitting the EEG and EMG signals in the plot.
    Each signal is processed independently:
    - Subtract its mean.
    - Divide by the 95th percentile of the centered signal, and then by 2.2.
    - If that percentile is 0 (e.g. a flat or disconnected channel), the
      signal is only centered. This avoids a division by zero that would
      fill the plot with NaN/inf.

    :param eeg: EEG signal
    :param emg: EMG signal
    :return: centered and scaled signals, as (eeg, emg)
    """

    def center_and_scale(signal: np.array) -> np.array:
        centered = signal - np.mean(signal)
        scale = np.percentile(centered, 95)
        if scale == 0:
            return centered
        return centered / scale / 2.2

    return center_and_scale(eeg), center_and_scale(emg)
