accusleepy 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42) hide show
  1. accusleepy/__init__.py +0 -0
  2. accusleepy/__main__.py +4 -0
  3. accusleepy/bouts.py +142 -0
  4. accusleepy/brain_state_set.py +89 -0
  5. accusleepy/classification.py +285 -0
  6. accusleepy/config.json +24 -0
  7. accusleepy/constants.py +46 -0
  8. accusleepy/fileio.py +179 -0
  9. accusleepy/gui/__init__.py +0 -0
  10. accusleepy/gui/icons/brightness_down.png +0 -0
  11. accusleepy/gui/icons/brightness_up.png +0 -0
  12. accusleepy/gui/icons/double_down_arrow.png +0 -0
  13. accusleepy/gui/icons/double_up_arrow.png +0 -0
  14. accusleepy/gui/icons/down_arrow.png +0 -0
  15. accusleepy/gui/icons/home.png +0 -0
  16. accusleepy/gui/icons/question.png +0 -0
  17. accusleepy/gui/icons/save.png +0 -0
  18. accusleepy/gui/icons/up_arrow.png +0 -0
  19. accusleepy/gui/icons/zoom_in.png +0 -0
  20. accusleepy/gui/icons/zoom_out.png +0 -0
  21. accusleepy/gui/images/primary_window.png +0 -0
  22. accusleepy/gui/images/viewer_window.png +0 -0
  23. accusleepy/gui/images/viewer_window_annotated.png +0 -0
  24. accusleepy/gui/main.py +1494 -0
  25. accusleepy/gui/manual_scoring.py +1096 -0
  26. accusleepy/gui/mplwidget.py +386 -0
  27. accusleepy/gui/primary_window.py +2577 -0
  28. accusleepy/gui/primary_window.ui +3831 -0
  29. accusleepy/gui/resources.qrc +16 -0
  30. accusleepy/gui/resources_rc.py +6710 -0
  31. accusleepy/gui/text/config_guide.txt +27 -0
  32. accusleepy/gui/text/main_guide.md +167 -0
  33. accusleepy/gui/text/manual_scoring_guide.md +23 -0
  34. accusleepy/gui/viewer_window.py +610 -0
  35. accusleepy/gui/viewer_window.ui +926 -0
  36. accusleepy/models.py +108 -0
  37. accusleepy/multitaper.py +661 -0
  38. accusleepy/signal_processing.py +469 -0
  39. accusleepy/temperature_scaling.py +157 -0
  40. accusleepy-0.6.0.dist-info/METADATA +106 -0
  41. accusleepy-0.6.0.dist-info/RECORD +42 -0
  42. accusleepy-0.6.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,1096 @@
1
+ # AccuSleePy manual scoring GUI
2
+ # Icon sources:
3
+ # Arkinasi, https://www.flaticon.com/authors/arkinasi
4
+ # kendis lasman, https://www.flaticon.com/packs/ui-79
5
+
6
+
7
+ import copy
8
+ import os
9
+ from dataclasses import dataclass
10
+ from functools import partial
11
+ from types import SimpleNamespace
12
+
13
+ import matplotlib.pyplot as plt
14
+ import numpy as np
15
+ from PySide6.QtCore import (
16
+ QKeyCombination,
17
+ QRect,
18
+ Qt,
19
+ QUrl,
20
+ )
21
+ from PySide6.QtGui import (
22
+ QCloseEvent,
23
+ QKeySequence,
24
+ QShortcut,
25
+ )
26
+ from PySide6.QtWidgets import (
27
+ QDialog,
28
+ QMessageBox,
29
+ QTextBrowser,
30
+ QVBoxLayout,
31
+ QWidget,
32
+ )
33
+
34
+ from accusleepy.constants import UNDEFINED_LABEL
35
+ from accusleepy.fileio import load_config, save_labels
36
+ from accusleepy.gui.mplwidget import resample_x_ticks
37
+ from accusleepy.gui.viewer_window import Ui_ViewerWindow
38
+ from accusleepy.signal_processing import create_spectrogram, get_emg_power
39
+
40
# Colormap for displaying brain state labels.
# Row 0 represents the "undefined" state; the remaining ten rows are the
# tab10 colors, one per digit in "keyboard" order (1234567890).
LABEL_CMAP = np.vstack(
    (np.zeros((1, 4), dtype=int), plt.colormaps["tab10"](range(10)))
)
# Path to the user manual text file, relative to this module's directory
USER_MANUAL_FILE = os.path.normpath(r"text/manual_scoring_guide.md")

# ---- constants used by callback functions ----
# label formats
DISPLAY_FORMAT, DIGIT_FORMAT = "display", "digit"
# offset changes, and the amount added to a signal offset for each
OFFSET_UP, OFFSET_DOWN = "up", "down"
OFFSET_INCREMENTS = {OFFSET_UP: 0.02, OFFSET_DOWN: -0.02}
# changes to number of epochs
DIRECTION_PLUS, DIRECTION_MINUS = "plus", "minus"
# changes to selected epoch
DIRECTION_LEFT, DIRECTION_RIGHT = "left", "right"
# zoom directions, and the multiplier applied to a signal scale for each
ZOOM_IN, ZOOM_OUT, ZOOM_RESET = "in", "out", "reset"
SIGNAL_ZOOM_FACTORS = {ZOOM_IN: 1.08, ZOOM_OUT: 0.95}
# signal names
EEG_SIGNAL, EMG_SIGNAL = "eeg", "emg"
# spectrogram color changes
BRIGHTER, DIMMER = "brighter", "dimmer"
# next epoch target
DIFFERENT_STATE, UNDEFINED_STATE = "different", "undefined"
# How far from the edge of the upper plot the marker should be before
# the plot starts to scroll again - must be in (0, 0.5)
SCROLL_BOUNDARY = 0.35
# Max number of sequential undo actions allowed
UNDO_LIMIT = 1000
82
+
83
+
84
@dataclass
class StateChange:
    """Information about an event when brain state labels were changed

    One instance describes a single undoable edit: the labels that were
    replaced, the labels that replaced them, and where the edit starts.
    """

    # np.ndarray is the array type; np.array is only a factory function
    previous_labels: np.ndarray  # old brain state labels
    new_labels: np.ndarray  # new brain state labels
    epoch: int  # first epoch affected
91
+
92
+
93
+ class ManualScoringWindow(QDialog):
94
+ """AccuSleePy manual scoring GUI"""
95
+
96
def __init__(
    self,
    eeg: np.ndarray,
    emg: np.ndarray,
    label_file: str,
    labels: np.ndarray,
    confidence_scores: np.ndarray,
    sampling_rate: int | float,
    epoch_length: int | float,
) -> None:
    """Initialize the manual scoring window

    Stores the inputs, builds the UI, precomputes the signals and images
    displayed in the upper and lower figures, registers the keyboard
    shortcuts, mouse callback and button callbacks, and then shows the
    window.

    :param eeg: EEG signal
    :param emg: EMG signal
    :param label_file: filename for labels
    :param labels: brain state labels
    :param confidence_scores: confidence scores
    :param sampling_rate: sampling rate, in Hz
    :param epoch_length: epoch length, in seconds
    """
    super(ManualScoringWindow, self).__init__()

    self.label_file = label_file
    self.eeg = eeg
    self.emg = emg
    self.labels = labels
    self.confidence_scores = confidence_scores
    self.sampling_rate = sampling_rate
    self.epoch_length = epoch_length

    # one label per epoch
    self.n_epochs = len(self.labels)

    # initialize the UI
    self.ui = Ui_ViewerWindow()
    self.ui.setupUi(self)
    self.setWindowTitle("AccuSleePy manual scoring window")

    # load set of valid brain states
    # (the other two values returned by load_config are not used here)
    self.brain_state_set, _, _ = load_config()

    # initial setting for number of epochs to show in the lower plot
    self.epochs_to_show = 5

    # find the set of y-axis locations of valid brain state labels
    self.label_display_options = convert_labels(
        np.array([b.digit for b in self.brain_state_set.brain_states]),
        style=DISPLAY_FORMAT,
    )
    # used later as the row offset when indexing into the label image
    self.smallest_display_label = np.min(self.label_display_options)

    self.ui.upperfigure.epoch_length = self.epoch_length
    self.ui.lowerfigure.epoch_length = self.epoch_length

    # get EEG spectrogram and its frequency axis
    spectrogram, spectrogram_frequencies = create_spectrogram(
        self.eeg, self.sampling_rate, self.epoch_length
    )

    # calculate RMS of EMG for each epoch and apply a ceiling
    self.upper_emg = create_upper_emg_signal(
        self.emg, self.sampling_rate, self.epoch_length
    )

    # center and scale the EEG and EMG signals to fit the display
    # NOTE: this overwrites self.eeg / self.emg, so it must come after
    # the spectrogram and upper EMG signal are computed from the raw data
    self.eeg, self.emg = transform_eeg_emg(self.eeg, self.emg)

    # convert labels to "display" format and make an image to display them
    self.display_labels = convert_labels(self.labels, DISPLAY_FORMAT)
    self.label_img = create_label_img(
        self.display_labels, self.label_display_options
    )
    # same sort of thing for confidence scores
    self.confidence_img = create_confidence_img(self.confidence_scores)

    # history of changes to the brain state labels
    self.history = list()
    # index of the change "ahead" of the current state
    # i.e., which change will be applied by a "redo" action
    self.history_index = 0

    # set up both figures
    self.ui.upperfigure.setup_upper_figure(
        self.n_epochs,
        self.label_img,
        self.confidence_scores,
        self.confidence_img,
        spectrogram,
        spectrogram_frequencies,
        self.upper_emg,
        self.epochs_to_show,
        self.label_display_options,
        self.brain_state_set,
        self.roi_callback,
    )
    self.ui.lowerfigure.setup_lower_figure(
        self.label_img,
        self.sampling_rate,
        self.epochs_to_show,
        self.brain_state_set,
        self.label_display_options,
    )

    # initialize values that can be changed by user input
    self.epoch = 0  # currently selected epoch
    # first / last epoch visible in the upper figure
    self.upper_left_epoch = 0
    self.upper_right_epoch = self.n_epochs - 1
    # first / last epoch visible in the lower figure
    self.lower_left_epoch = 0
    self.lower_right_epoch = self.epochs_to_show - 1
    # multiplicative scale and additive offset applied to the signals
    # when they are drawn in the lower figure
    self.eeg_signal_scale_factor = 1
    self.emg_signal_scale_factor = 1
    self.eeg_signal_offset = 0
    self.emg_signal_offset = 0
    # brain state that will be applied by the next ROI selection
    self.roi_brain_state = 0
    self.label_roi_mode = False
    self.autoscroll_state = False
    # keep track of save state to warn user when they quit
    self.last_saved_labels = copy.deepcopy(self.labels)

    # populate the lower figure
    self.update_lower_figure()

    # user input: keyboard shortcuts
    # every QShortcut below is created with this window as its parent
    # left / right arrows: move the selected epoch by one
    keypress_right = QShortcut(QKeySequence(Qt.Key.Key_Right), self)
    keypress_right.activated.connect(partial(self.shift_epoch, DIRECTION_RIGHT))

    keypress_left = QShortcut(QKeySequence(Qt.Key.Key_Left), self)
    keypress_left.activated.connect(partial(self.shift_epoch, DIRECTION_LEFT))

    # "+" and "=" both zoom in on the upper figure's x-axis
    keypress_zoom_in_x = list()
    for zoom_key in [Qt.Key.Key_Plus, Qt.Key.Key_Equal]:
        keypress_zoom_in_x.append(QShortcut(QKeySequence(zoom_key), self))
        keypress_zoom_in_x[-1].activated.connect(partial(self.zoom_x, ZOOM_IN))

    keypress_zoom_out_x = QShortcut(QKeySequence(Qt.Key.Key_Minus), self)
    keypress_zoom_out_x.activated.connect(partial(self.zoom_x, ZOOM_OUT))

    # digit keys: set the current epoch to the matching brain state
    keypress_modify_label = list()
    for brain_state in self.brain_state_set.brain_states:
        keypress_modify_label.append(
            QShortcut(
                QKeySequence(Qt.Key[f"Key_{brain_state.digit}"]),
                self,
            )
        )
        keypress_modify_label[-1].activated.connect(
            partial(self.modify_current_epoch_label, brain_state.digit)
        )

    # backspace: set the current epoch to the undefined label
    keypress_delete_label = QShortcut(QKeySequence(Qt.Key.Key_Backspace), self)
    keypress_delete_label.activated.connect(
        partial(self.modify_current_epoch_label, UNDEFINED_LABEL)
    )

    # Ctrl+W: close the window
    keypress_quit = QShortcut(
        QKeySequence(QKeyCombination(Qt.Modifier.CTRL, Qt.Key.Key_W)),
        self,
    )
    keypress_quit.activated.connect(self.close)

    # Ctrl+S: save labels
    keypress_save = QShortcut(
        QKeySequence(QKeyCombination(Qt.Modifier.CTRL, Qt.Key.Key_S)),
        self,
    )
    keypress_save.activated.connect(self.save)

    # Shift+digit: enter ROI mode for the matching brain state
    keypress_roi = list()
    for brain_state in self.brain_state_set.brain_states:
        keypress_roi.append(
            QShortcut(
                QKeySequence(
                    QKeyCombination(
                        Qt.Modifier.SHIFT,
                        Qt.Key[f"Key_{brain_state.digit}"],
                    )
                ),
                self,
            )
        )
        keypress_roi[-1].activated.connect(
            partial(self.enter_label_roi_mode, brain_state.digit)
        )
    # Shift+Backspace: enter ROI mode for the undefined label
    keypress_roi.append(
        QShortcut(
            QKeySequence(
                QKeyCombination(
                    Qt.Modifier.SHIFT,
                    Qt.Key.Key_Backspace,
                )
            ),
            self,
        )
    )
    keypress_roi[-1].activated.connect(
        partial(self.enter_label_roi_mode, UNDEFINED_LABEL)
    )

    # Escape: leave ROI mode
    keypress_esc = QShortcut(QKeySequence(Qt.Key.Key_Escape), self)
    keypress_esc.activated.connect(self.exit_label_roi_mode)

    # Space and Shift+Right: jump right to the next epoch whose state
    # differs from the current one; Shift+Left does the same to the left
    keypress_space = QShortcut(QKeySequence(Qt.Key.Key_Space), self)
    keypress_space.activated.connect(
        partial(self.jump_to_next_state, DIRECTION_RIGHT, DIFFERENT_STATE)
    )
    keypress_shift_right = QShortcut(
        QKeySequence(
            QKeyCombination(
                Qt.Modifier.SHIFT,
                Qt.Key.Key_Right,
            )
        ),
        self,
    )
    keypress_shift_right.activated.connect(
        partial(self.jump_to_next_state, DIRECTION_RIGHT, DIFFERENT_STATE)
    )
    keypress_shift_left = QShortcut(
        QKeySequence(
            QKeyCombination(
                Qt.Modifier.SHIFT,
                Qt.Key.Key_Left,
            )
        ),
        self,
    )
    keypress_shift_left.activated.connect(
        partial(self.jump_to_next_state, DIRECTION_LEFT, DIFFERENT_STATE)
    )
    # Ctrl+Right / Ctrl+Left: jump to the next undefined epoch
    keypress_ctrl_right = QShortcut(
        QKeySequence(
            QKeyCombination(
                Qt.Modifier.CTRL,
                Qt.Key.Key_Right,
            )
        ),
        self,
    )
    keypress_ctrl_right.activated.connect(
        partial(self.jump_to_next_state, DIRECTION_RIGHT, UNDEFINED_STATE)
    )
    keypress_ctrl_left = QShortcut(
        QKeySequence(
            QKeyCombination(
                Qt.Modifier.CTRL,
                Qt.Key.Key_Left,
            )
        ),
        self,
    )
    keypress_ctrl_left.activated.connect(
        partial(self.jump_to_next_state, DIRECTION_LEFT, UNDEFINED_STATE)
    )

    # Ctrl+Z / Ctrl+Y: undo / redo label changes
    keypress_undo = QShortcut(
        QKeySequence(QKeyCombination(Qt.Modifier.CTRL, Qt.Key.Key_Z)),
        self,
    )
    keypress_undo.activated.connect(self.undo)
    keypress_redo = QShortcut(
        QKeySequence(QKeyCombination(Qt.Modifier.CTRL, Qt.Key.Key_Y)),
        self,
    )
    keypress_redo.activated.connect(self.redo)

    # user input: clicks
    self.ui.upperfigure.canvas.mpl_connect("button_press_event", self.click_to_jump)

    # user input: buttons
    self.ui.savebutton.clicked.connect(self.save)
    self.ui.xzoomin.clicked.connect(partial(self.zoom_x, ZOOM_IN))
    self.ui.xzoomout.clicked.connect(partial(self.zoom_x, ZOOM_OUT))
    self.ui.xzoomreset.clicked.connect(partial(self.zoom_x, ZOOM_RESET))
    self.ui.autoscroll.stateChanged.connect(self.update_autoscroll_state)
    self.ui.eegzoomin.clicked.connect(
        partial(self.update_signal_zoom, EEG_SIGNAL, ZOOM_IN)
    )
    self.ui.eegzoomout.clicked.connect(
        partial(self.update_signal_zoom, EEG_SIGNAL, ZOOM_OUT)
    )
    self.ui.emgzoomin.clicked.connect(
        partial(self.update_signal_zoom, EMG_SIGNAL, ZOOM_IN)
    )
    self.ui.emgzoomout.clicked.connect(
        partial(self.update_signal_zoom, EMG_SIGNAL, ZOOM_OUT)
    )
    self.ui.eegshiftup.clicked.connect(
        partial(self.update_signal_offset, EEG_SIGNAL, OFFSET_UP)
    )
    self.ui.eegshiftdown.clicked.connect(
        partial(self.update_signal_offset, EEG_SIGNAL, OFFSET_DOWN)
    )
    self.ui.emgshiftup.clicked.connect(
        partial(self.update_signal_offset, EMG_SIGNAL, OFFSET_UP)
    )
    self.ui.emgshiftdown.clicked.connect(
        partial(self.update_signal_offset, EMG_SIGNAL, OFFSET_DOWN)
    )
    self.ui.shownepochsplus.clicked.connect(
        partial(self.update_epochs_shown, DIRECTION_PLUS)
    )
    self.ui.shownepochsminus.clicked.connect(
        partial(self.update_epochs_shown, DIRECTION_MINUS)
    )
    self.ui.specbrighter.clicked.connect(
        partial(self.update_spectrogram_brightness, BRIGHTER)
    )
    self.ui.specdimmer.clicked.connect(
        partial(self.update_spectrogram_brightness, DIMMER)
    )
    self.ui.helpbutton.clicked.connect(self.show_user_manual)

    self.show()
407
+
408
def add_to_history(self, state_change: StateChange) -> None:
    """Record a change to the brain state labels for undo / redo

    The history is a list of changes plus an index pointing at the
    change a "redo" would apply. Recording a new change after one or
    more undos discards the undone changes, so they can no longer be
    redone. When the list is full (UNDO_LIMIT entries) and a change is
    appended at its end, the oldest entry is dropped.

    :param state_change: description of the change to the labels
    """
    if not self.history:
        # very first change: start the history from scratch
        self.history.append(state_change)
        self.history_index = 1
        return

    # decide this before the list is modified below
    at_end = self.history_index >= len(self.history)
    if not at_end:
        # some changes were undone: forget them before recording
        self.history = self.history[: self.history_index]
    self.history.append(state_change)

    if at_end and self.history_index == UNDO_LIMIT:
        # the list was already full: drop the oldest entry and keep
        # the index pointing just past the newest one
        self.history = self.history[1:]
        return
    self.history_index += 1
439
+
440
def force_modify_labels(self, epoch: int, new_labels: np.array) -> None:
    """Overwrite a run of brain state labels for an undo/redo action

    Unlike the interactive editing paths, this does not touch the
    history; it only writes the labels and rebuilds the display-format
    labels and the label image from them.

    :param epoch: index of the first epoch to overwrite
    :param new_labels: labels to write, starting at that epoch
    """
    stop = epoch + len(new_labels)
    self.labels[epoch:stop] = new_labels
    # regenerate everything derived from the labels
    refreshed = convert_labels(self.labels, style=DISPLAY_FORMAT)
    self.display_labels = refreshed
    self.label_img = create_label_img(refreshed, self.label_display_options)
451
+
452
def redo(self) -> None:
    """Re-apply the most recently undone change to the labels

    Does nothing when the history index already points past the last
    recorded change.
    """
    if self.history_index == len(self.history):
        return  # nothing left to redo

    change = self.history[self.history_index]
    self.force_modify_labels(epoch=change.epoch, new_labels=change.new_labels)
    self.history_index += 1

    # move the cursor to the affected epoch by reusing the click handler
    # with a stand-in for a mouse event
    self.click_to_jump(SimpleNamespace(xdata=change.epoch, inaxes=None))
469
+
470
def undo(self) -> None:
    """Revert the most recent change to the brain state labels

    Does nothing when the history index is already at the start of the
    history.
    """
    if self.history_index == 0:
        return  # nothing to undo

    change = self.history[self.history_index - 1]
    # restore the labels that were in place before the change
    self.force_modify_labels(
        epoch=change.epoch, new_labels=change.previous_labels
    )
    self.history_index -= 1

    # move the cursor to the affected epoch by reusing the click handler
    # with a stand-in for a mouse event
    self.click_to_jump(SimpleNamespace(xdata=change.epoch, inaxes=None))
487
+
488
def closeEvent(self, event: QCloseEvent) -> None:
    """Ask for confirmation before closing if there are unsaved changes

    When the current labels match the last saved labels the event is
    left untouched. Otherwise the user is asked whether to quit; the
    event is accepted on "Yes" and ignored on "No".

    :param event: the close event delivered to this window
    """
    # np.array_equal compares the arrays in a single vectorized call and
    # simply returns False if their shapes differ, whereas the builtin
    # all() over an elementwise comparison iterates in Python
    if np.array_equal(self.labels, self.last_saved_labels):
        return

    result = QMessageBox.question(
        self,
        "Unsaved changes",
        "You have unsaved changes. Really quit?",
        QMessageBox.Yes | QMessageBox.No,
    )
    if result == QMessageBox.Yes:
        event.accept()
    else:
        event.ignore()
501
+
502
def show_user_manual(self) -> None:
    """Open a separate window that displays the user manual

    The manual is loaded from USER_MANUAL_FILE, resolved relative to the
    directory containing this module. The popup, its layout and its text
    browser are stored on the instance.
    """
    module_dir = os.path.dirname(os.path.abspath(__file__))
    manual_url = QUrl.fromLocalFile(os.path.join(module_dir, USER_MANUAL_FILE))

    # build the popup: a plain widget holding one text browser
    self.popup = popup = QWidget()
    self.popup_vlayout = layout = QVBoxLayout(popup)
    self.guide_textbox = browser = QTextBrowser(popup)
    layout.addWidget(browser)

    browser.setSource(manual_url)
    browser.setOpenLinks(False)

    popup.setGeometry(QRect(100, 100, 830, 600))
    popup.show()
517
+
518
def jump_to_next_state(self, direction: str, target: str) -> None:
    """Move the cursor to the nearest epoch matching a target state

    Searches from the current epoch in the given direction for the
    nearest epoch that is either undefined or labeled differently from
    the current epoch. If there is no such epoch, the cursor is sent to
    the current epoch again. Useful for reviewing state transitions or
    finding unlabeled epochs.

    :param direction: left or right
    :param target: different or undefined
    """
    here = self.epoch
    searching_right = direction == DIRECTION_RIGHT

    # epochs strictly after (or strictly before) the current one
    if searching_right:
        candidates = self.labels[here + 1 :]
    else:
        candidates = self.labels[:here]

    if target == DIFFERENT_STATE:
        matches = np.where(candidates != self.labels[here])[0]
    else:
        matches = np.where(candidates == UNDEFINED_LABEL)[0]

    destination = here
    if matches.size > 0:
        if searching_right:
            # nearest match to the right; indices are relative to here + 1
            destination = int(matches[0]) + 1 + here
        else:
            # nearest match to the left is the last one found
            destination = int(matches[-1])

    # reuse click_to_jump by handing it a stand-in for a mouse event
    self.click_to_jump(SimpleNamespace(xdata=destination, inaxes=None))
550
+
551
def roi_callback(self, eclick, erelease) -> None:
    """Apply the ROI brain state to the epochs spanned by a drawn ROI

    This is called by the RectangleSelector widget when the user
    finishes drawing an ROI, which is why the signature has this form.
    The x-coordinates of the press and release events determine the
    range of epochs that is set to the pending ROI brain state.

    :param eclick: event for the start of the ROI (its xdata is used)
    :param erelease: event for the end of the ROI (its xdata is used)
    """
    # range of epochs affected: [first, stop)
    first = int(np.ceil(eclick.xdata))
    stop = int(np.floor(erelease.xdata)) + 1

    before = copy.deepcopy(self.labels[first:stop])
    after = np.ones(stop - first).astype(int) * self.roi_brain_state

    # only record the edit if it actually changes something
    changed = not np.array_equal(before, after)
    if changed:
        self.add_to_history(
            StateChange(previous_labels=before, new_labels=after, epoch=first)
        )

    # write the labels and rebuild what is derived from them
    self.labels[first:stop] = self.roi_brain_state
    self.display_labels = convert_labels(self.labels, style=DISPLAY_FORMAT)
    self.label_img = create_label_img(
        self.display_labels, self.label_display_options
    )

    # redraw, then return to the normal GUI state
    self.update_figures()
    self.exit_label_roi_mode()
586
+
587
def exit_label_roi_mode(self) -> None:
    """Leave ROI drawing mode and restore the normal GUI state

    Deactivates and hides the ROI selector on the upper figure, clears
    the ROI-mode flag and hides the "editing" patch.
    """
    upper = self.ui.upperfigure
    selector = upper.roi
    selector.set_active(False)
    selector.set_visible(False)
    selector.update()
    self.label_roi_mode = False
    upper.editing_patch.set_visible(False)
594
+
595
def enter_label_roi_mode(self, brain_state: int) -> None:
    """Switch to ROI drawing mode for a given brain state

    In this mode, the user can draw an ROI on the upper brain state
    label image to set a range of epochs to a new brain state. The ROI
    patch is colored to match that state and the selector is activated.

    :param brain_state: new brain state to set
    """
    self.label_roi_mode = True
    self.roi_brain_state = brain_state

    upper = self.ui.upperfigure
    # color the ROI with the colormap entry for this brain state
    color_index = convert_labels(np.array([brain_state]), DISPLAY_FORMAT)
    upper.roi_patch.set(facecolor=LABEL_CMAP[color_index])
    upper.editing_patch.set_visible(True)
    upper.canvas.draw()
    upper.roi.set_active(True)
613
+
614
def save(self) -> None:
    """Write the brain state labels to the label file

    Also snapshots the labels so that closing the window can tell
    whether anything has changed since this save.
    """
    current = self.labels
    save_labels(current, self.label_file)
    self.last_saved_labels = copy.deepcopy(current)
618
+
619
def update_spectrogram_brightness(self, direction: str) -> None:
    """Adjust the spectrogram's color range in response to a button press

    Only the upper color limit is changed: it is scaled down to make
    the image brighter and scaled up to make it dimmer.

    :param direction: brighter or dimmer
    """
    upper = self.ui.upperfigure
    low, high = upper.spec_ref.get_clim()
    scale = 0.96 if direction == BRIGHTER else 1.07
    upper.spec_ref.set(clim=(low, high * scale))
    upper.canvas.draw()
630
+
631
def update_epochs_shown(self, direction: str) -> None:
    """Change the number of epochs shown based on button press

    The user can change the number of epochs shown in the lower figure
    via button presses. This requires extensive changes to both figures.
    The number of epochs can only change in increments of 2 and should
    always be an odd number >= 3. It is never increased beyond the
    number of epochs in the recording.

    After the count changes, the lower window is re-centered on the
    selected epoch and clamped to the recording. Moving only one side of
    the window at a recording edge could otherwise leave the selected
    epoch off-center or outside the window after repeated presses.

    :param direction: plus or minus
    """
    if direction == DIRECTION_PLUS:
        # growing past the recording length would push a window bound
        # outside the valid range of epochs
        if self.epochs_to_show + 2 <= self.n_epochs:
            self.epochs_to_show += 2
    elif self.epochs_to_show > 3:
        self.epochs_to_show -= 2

    # center the window on the selected epoch, then clamp it so that it
    # stays within the recording (same rule as update_upper_marker)
    half_window = (self.epochs_to_show - 1) // 2
    left = min(self.epoch - half_window, self.n_epochs - self.epochs_to_show)
    self.lower_left_epoch = max(0, left)
    self.lower_right_epoch = self.lower_left_epoch + self.epochs_to_show - 1

    self.ui.shownepochslabel.setText(str(self.epochs_to_show))

    # rebuild lower figure from scratch
    self.ui.lowerfigure.canvas.figure.clf()
    self.ui.lowerfigure.setup_lower_figure(
        self.label_img,
        self.sampling_rate,
        self.epochs_to_show,
        self.brain_state_set,
        self.label_display_options,
    )
    self.update_figures()
675
+
676
def update_signal_offset(self, signal: str, direction: str) -> None:
    """Move the EEG or EMG trace up or down in the lower figure

    Adds the increment for the given direction to the chosen signal's
    offset and redraws the lower figure.

    :param signal: eeg or emg
    :param direction: up or down
    """
    step = OFFSET_INCREMENTS[direction]
    if signal == EEG_SIGNAL:
        self.eeg_signal_offset += step
    else:
        self.emg_signal_offset += step
    self.update_lower_figure()
687
+
688
def update_signal_zoom(self, signal: str, direction: str) -> None:
    """Zoom the EEG or EMG trace along the y-axis

    Multiplies the chosen signal's scale factor by the factor for the
    given direction and redraws the lower figure.

    :param signal: eeg or emg
    :param direction: in or out
    """
    factor = SIGNAL_ZOOM_FACTORS[direction]
    if signal == EEG_SIGNAL:
        self.eeg_signal_scale_factor *= factor
    else:
        self.emg_signal_scale_factor *= factor
    self.update_lower_figure()
699
+
700
def update_autoscroll_state(self, checked) -> None:
    """Remember whether autoscroll is switched on

    While autoscroll is enabled, setting the brain state of the current
    epoch via a keypress also advances to the next epoch.

    :param checked: state of the checkbox
    """
    # stored as received; it is only ever tested for truthiness
    self.autoscroll_state = checked
709
+
710
def adjust_upper_figure_x_limits(self) -> None:
    """Apply the current upper epoch range to the upper figure's x-axes

    Subplots 0, 1, 2 and 4 get limits padded by half an epoch on each
    side; subplot 3 uses limits running from the left epoch to one past
    the right epoch.
    """
    left = self.upper_left_epoch
    right = self.upper_right_epoch
    for index in (0, 1, 2, 4):
        self.ui.upperfigure.canvas.axes[index].set_xlim(
            (left - 0.5, right + 0.5)
        )
    # subplot 3 is the one whose x-axis is not offset by half an epoch
    self.ui.upperfigure.canvas.axes[3].set_xlim((left, right + 1))
719
+
720
def zoom_x(self, direction: str) -> None:
    """Zoom the upper figure's x-axis in or out, or reset it

    Zooming is centered on the selected epoch. Zooming in can only
    shrink the current range; zooming out is limited to the extent of
    the recording; reset shows the whole recording.

    :param direction: in, out, or reset
    """
    zoom_in_factor = 0.45
    zoom_out_factor = 1.017
    span = self.upper_right_epoch - self.upper_left_epoch + 1

    if direction == ZOOM_IN:
        reach = zoom_in_factor * span
        # never extend beyond the range that is currently shown
        new_left = max(self.upper_left_epoch, round(self.epoch - reach))
        new_right = min(self.upper_right_epoch, round(self.epoch + reach))
    elif direction == ZOOM_OUT:
        reach = zoom_out_factor * span
        # never extend beyond the recording
        new_left = max(0, round(self.epoch - reach))
        new_right = min(self.n_epochs - 1, round(self.epoch + reach))
    else:  # reset
        new_left, new_right = 0, self.n_epochs - 1

    self.upper_left_epoch = new_left
    self.upper_right_epoch = new_right
    self.adjust_upper_figure_x_limits()
    self.ui.upperfigure.canvas.draw()
757
+
758
def modify_current_epoch_label(self, digit: int) -> None:
    """Set the brain state label of the selected epoch

    Records the edit in the undo history if the label actually changed,
    updates the display-format labels and the label image in place, and
    then either advances to the next epoch (when autoscroll is enabled
    and this is not the last epoch) or redraws the figures.

    :param digit: new brain state label in "digit" format
    """
    epoch = self.epoch
    old_digit = self.labels[epoch]
    self.labels[epoch] = digit
    # if something changed, track the change
    if old_digit != digit:
        self.add_to_history(
            StateChange(
                previous_labels=np.array([old_digit]),
                new_labels=np.array([digit]),
                epoch=epoch,
            )
        )

    # mirror the change in the display-format labels
    shown = convert_labels(np.array([digit]), style=DISPLAY_FORMAT)[0]
    self.display_labels[epoch] = shown

    # repaint this epoch's column of the label image
    if shown == 0:
        self.label_img[:, epoch] = np.array([0, 0, 0, 1])
    else:
        self.label_img[:, epoch, :] = 1
        row = shown - self.smallest_display_label
        self.label_img[row, epoch, :] = LABEL_CMAP[shown]

    # autoscroll, if that is enabled
    advance = self.autoscroll_state and self.epoch < self.n_epochs - 1
    if advance:
        self.shift_epoch(DIRECTION_RIGHT)  # this calls update_figures()
    else:
        self.update_figures()
794
+
795
def shift_epoch(self, direction: str) -> None:
    """Select the previous or next epoch

    Called when the user presses the left or right arrow key. The move
    is ignored if it would leave the recording. The upper figure scrolls
    by one epoch once the marker gets within SCROLL_BOUNDARY of the edge
    it is moving towards, and the lower figure's window follows the
    selected epoch, except where either plot has already reached the
    start or end of the recording.

    :param direction: left or right
    """
    step = {DIRECTION_LEFT: -1, DIRECTION_RIGHT: 1}[direction]
    target = self.epoch + step
    # prevent movement outside the data range
    if target < 0 or target >= self.n_epochs:
        return
    self.epoch = target

    # upper plot: scroll by one epoch if the marker is too close to the
    # edge it is moving towards and there is more data on that side
    span = self.upper_right_epoch - self.upper_left_epoch + 1
    upper_scroll = 0
    if (
        self.epoch > self.upper_left_epoch + (1 - SCROLL_BOUNDARY) * span
        and self.upper_right_epoch < (self.n_epochs - 1)
        and direction == DIRECTION_RIGHT
    ):
        upper_scroll = 1
    elif (
        self.epoch < self.upper_left_epoch + SCROLL_BOUNDARY * span
        and self.upper_left_epoch > 0
        and direction == DIRECTION_LEFT
    ):
        upper_scroll = -1
    if upper_scroll:
        self.upper_left_epoch += upper_scroll
        self.upper_right_epoch += upper_scroll
        self.adjust_upper_figure_x_limits()

    # lower plot: keep the selected epoch at the center of the window,
    # unless the window already touches that end of the recording
    center = round((self.epochs_to_show - 1) / 2) + self.lower_left_epoch
    if self.epoch < center and self.lower_left_epoch > 0:
        lower_scroll = -1
    elif self.epoch > center and self.lower_right_epoch < self.n_epochs - 1:
        lower_scroll = 1
    else:
        lower_scroll = 0
    self.lower_left_epoch += lower_scroll
    self.lower_right_epoch += lower_scroll

    self.update_figures()
846
+
847
def update_upper_marker(self) -> None:
    """Move the upper figure's epoch marker to the selected epoch

    The first marker element spans the epochs shown in the lower figure:
    a window of epochs_to_show epochs centered on the selected epoch and
    shifted as needed to stay inside the recording. The second element
    marks the selected epoch itself.
    """
    half = round((self.epochs_to_show - 1) / 2)
    if self.epoch - half < 0:
        # window would start before the recording: pin it to the start
        left, right = 0, self.epochs_to_show - 1
    elif self.epoch + half > self.n_epochs - 1:
        # window would run past the recording: pin it to the end
        left, right = self.n_epochs - self.epochs_to_show, self.n_epochs - 1
    else:
        left, right = self.epoch - half, self.epoch + half

    marker = self.ui.upperfigure.upper_marker
    marker[0].set_xdata([left - 0.5, right + 0.5])
    marker[1].set_xdata([self.epoch])
866
+
867
def update_lower_epoch_marker(self) -> None:
    """Move the lower figure's epoch marker to the selected epoch

    The marker positions are computed in samples relative to the left
    edge of the lower figure. The top and bottom markers each consist of
    three line segments: one at the left edge of the epoch, one spanning
    the epoch, and one at its right edge.
    """
    x_left = (
        (self.epoch - self.lower_left_epoch)
        * self.epoch_length
        * self.sampling_rate
    )
    x_right = (
        (1 + self.epoch - self.lower_left_epoch)
        * self.epoch_length
        * self.sampling_rate
    )
    segments = ((x_left, x_left), (x_left, x_right), (x_right, x_right))

    lower = self.ui.lowerfigure
    for marker in (lower.top_marker, lower.bottom_marker):
        for index, (start, end) in enumerate(segments):
            marker[index].set_xdata([start, end])
885
+
886
def update_figures(self) -> None:
    """Refresh and redraw the upper figure, then the lower figure"""
    upper = self.ui.upperfigure

    # upper figure: move the marker and reload the label image
    self.update_upper_marker()
    # reloading the label image is not always necessary, but it is cheap
    upper.label_img_ref.set(data=self.label_img)
    upper.canvas.draw()

    # lower figure
    self.update_lower_figure()
895
+
896
def update_lower_figure(self) -> None:
    """Refresh the lower figure's signals, labels and timestamps, then redraw

    Only the epochs between ``lower_left_epoch`` and ``lower_right_epoch``
    (inclusive) are shown.
    """

    def as_timestamp(seconds) -> str:
        # hours:minutes:seconds, with seconds shown to two decimal places
        return "{:02d}:{:02d}:{:05.2f}".format(
            int(seconds // 3600), int(seconds // 60) % 60, (seconds % 60)
        )

    lower = self.ui.lowerfigure
    left_epoch = self.lower_left_epoch
    right_epoch = self.lower_right_epoch

    # range of samples covered by the visible epochs
    start = round(left_epoch * self.sampling_rate * self.epoch_length)
    stop = round((right_epoch + 1) * self.sampling_rate * self.epoch_length)

    # take the visible part of each signal, then apply the current scale and offset
    eeg_trace = (
        self.eeg[start:stop] * self.eeg_signal_scale_factor + self.eeg_signal_offset
    )
    emg_trace = (
        self.emg[start:stop] * self.emg_signal_scale_factor + self.emg_signal_offset
    )

    self.update_lower_epoch_marker()

    # redraw the EEG and EMG traces
    lower.eeg_line.set_ydata(eeg_trace)
    lower.emg_line.set_ydata(emg_trace)

    # show the slice of the brain state image that matches the visible epochs
    lower.label_img_ref.set(data=self.label_img[:, left_epoch : (right_epoch + 1), :])

    # relabel the x-axis ticks with the time at each selected epoch
    tick_epochs = resample_x_ticks(np.arange(left_epoch, right_epoch + 1))
    lower.canvas.axes[1].set_xticklabels(
        [as_timestamp(t) for t in tick_epochs * self.epoch_length]
    )

    lower.canvas.draw()
938
+
939
def click_to_jump(self, event) -> None:
    """Jump to the epoch nearest to a click on the upper figure

    This is the callback for mouse clicks on the upper figure. A click on
    any of its subplots selects the nearest epoch, and both figures are
    re-centered on that epoch while keeping their current widths.

    :param event: a MouseEvent containing the click data
    """
    # ignore clicks without a valid location, and all clicks in label ROI mode
    if self.label_roi_mode or event.xdata is None:
        return

    def window_around_epoch(n_shown: int) -> tuple:
        """Bounds of an n_shown-wide window on the current epoch, kept in range"""
        padding = round((n_shown - 1) / 2)
        if self.epoch - padding < 0:
            return 0, n_shown - 1
        if self.epoch + padding > self.n_epochs - 1:
            return self.n_epochs - n_shown, self.n_epochs - 1
        return self.epoch - padding, self.epoch + padding

    click_x = event.xdata
    # simulated clicks have no axes; real clicks on the spectrogram (axes
    # index 3) need a small shift because that plot uses a different x-axis range
    if (
        event.inaxes is not None
        and self.ui.upperfigure.canvas.axes.index(event.inaxes) == 3
    ):
        click_x -= 0.5

    # remember the current "zoom level" of the upper figure so it is preserved
    n_upper_shown = self.upper_right_epoch - self.upper_left_epoch + 1

    # select the nearest valid epoch
    self.epoch = round(np.clip(click_x, 0, self.n_epochs - 1))

    # re-center the upper figure
    self.upper_left_epoch, self.upper_right_epoch = window_around_epoch(n_upper_shown)
    self.adjust_upper_figure_x_limits()

    # re-center the lower figure
    self.lower_left_epoch, self.lower_right_epoch = window_around_epoch(
        self.epochs_to_show
    )

    self.update_figures()
992
+
993
+
994
def convert_labels(labels: np.array, style: str) -> np.array:
    """Translate brain state labels into "display" or "digit" format

    Brain state labels are represented in two ways:
        Digit format: how labels are stored in files. It matches the digit
            attribute of the BrainState class, and the number key pressed to
            assign that brain state to an epoch.
        Display format: the y-axis value used when labels are drawn as an
            image, which is also the brain state's index in the colormap.
            Undefined epochs become 0, and the digits become 1-10 in
            keyboard order (1234567890).

    :param labels: brain state labels
    :param style: target format for the output
    :return: formatted labels
    """
    if style not in (DISPLAY_FORMAT, DIGIT_FORMAT):
        raise Exception(f"style must be '{DISPLAY_FORMAT}' or '{DIGIT_FORMAT}'")

    def substitute(values, old, new) -> list:
        """Copy values, replacing every element equal to old with new"""
        return [value if value != old else new for value in values]

    # the two substitutions are applied one after the other, in this order
    if style == DISPLAY_FORMAT:
        # 0 becomes 10, then undefined becomes 0
        first, second = (0, 10), (UNDEFINED_LABEL, 0)
    else:
        # 0 becomes undefined, then 10 becomes 0
        first, second = (0, UNDEFINED_LABEL), (10, 0)

    return np.array(substitute(substitute(labels, *first), *second))
1020
+
1021
+
1022
def create_label_img(labels: np.array, label_display_options: np.array) -> np.array:
    """Build the image used to display brain state labels

    The image has one column per epoch and one row per display value between
    the smallest and largest entries of label_display_options. Up to 10
    brain states can exist, but rows outside that range would always be
    empty, so they are left out.

    :param labels: brain state labels, in "display" format
    :param label_display_options: y-axis locations of valid brain state labels
    :return: brain state label image
    """
    lowest_option = np.min(label_display_options)
    n_rows = np.max(label_display_options) - lowest_option + 1

    # every pixel starts out as 1 in all four channels (a white background)
    image = np.ones([n_rows, len(labels), 4])

    for column, label in enumerate(labels):
        if not label > 0:
            # undefined label: fill the entire column with [0, 0, 0, 1]
            image[:, column] = np.array([0, 0, 0, 1])
            continue
        # color only the row that belongs to this brain state
        image[label - lowest_option, column, :] = LABEL_CMAP[label]

    return image
1048
+
1049
+
1050
+ def create_confidence_img(confidence_scores: np.array) -> np.array:
1051
+ """Create an image to display confidence scores
1052
+
1053
+ :param confidence_scores: confidence scores
1054
+ :return: confidence score image
1055
+ """
1056
+ if confidence_scores is None:
1057
+ return None
1058
+
1059
+ confidence_img = np.ones([1, len(confidence_scores), 3])
1060
+ for i, c in enumerate(confidence_scores):
1061
+ confidence_img[0, i, 1:] = c
1062
+ return confidence_img
1063
+
1064
+
1065
def create_upper_emg_signal(
    emg: np.array, sampling_rate: int | float, epoch_length: int | float
) -> np.array:
    """Compute the per-epoch EMG signal for the upper figure and cap it

    The per-epoch values come from get_emg_power. They are then clipped to
    the range from 0 to their mean plus 2.5 standard deviations.

    :param emg: EMG signal
    :param sampling_rate: sampling rate, in Hz
    :param epoch_length: epoch length, in seconds
    :return: processed EMG signal
    """
    per_epoch_emg = get_emg_power(emg, sampling_rate, epoch_length)
    # ceiling: mean plus 2.5 standard deviations of the per-epoch values
    ceiling = np.mean(per_epoch_emg) + np.std(per_epoch_emg) * 2.5
    return np.clip(per_epoch_emg, 0, ceiling)
1081
+
1082
+
1083
def _center_and_scale(signal: np.ndarray) -> np.ndarray:
    """Subtract the mean, then divide by the 95th percentile and by 2.2"""
    centered = signal - np.mean(signal)
    scale = np.percentile(centered, 95)
    if scale == 0:
        # e.g. a constant signal: dividing by zero would fill the trace with
        # NaN or inf, so return the centered signal unscaled instead
        return centered
    return centered / scale / 2.2


def transform_eeg_emg(
    eeg: np.ndarray, emg: np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
    """Center and scale the EEG and EMG signals

    A heuristic approach to fitting the EEG and EMG signals in the plot.
    Each signal is handled independently: its mean is subtracted, and the
    result is divided by its 95th percentile and then by 2.2. A signal whose
    95th percentile is 0 after centering is only centered, not scaled.

    :param eeg: EEG signal
    :param emg: EMG signal
    :return: centered and scaled signals, as (eeg, emg)
    """
    return _center_and_scale(eeg), _center_and_scale(emg)