accusleepy 0.4.4__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -12,7 +12,24 @@ from types import SimpleNamespace
 
 import matplotlib.pyplot as plt
 import numpy as np
-from PySide6 import QtCore, QtGui, QtWidgets
+from PySide6.QtCore import (
+    QKeyCombination,
+    QRect,
+    Qt,
+    QUrl,
+)
+from PySide6.QtGui import (
+    QCloseEvent,
+    QKeySequence,
+    QShortcut,
+)
+from PySide6.QtWidgets import (
+    QDialog,
+    QMessageBox,
+    QTextBrowser,
+    QVBoxLayout,
+    QWidget,
+)
 
 from accusleepy.constants import UNDEFINED_LABEL
 from accusleepy.fileio import load_config, save_labels
@@ -73,7 +90,7 @@ class StateChange:
     epoch: int  # first epoch affected
 
 
-class ManualScoringWindow(QtWidgets.QDialog):
+class ManualScoringWindow(QDialog):
     """AccuSleePy manual scoring GUI"""
 
     def __init__(
@@ -191,33 +208,25 @@ class ManualScoringWindow(QtWidgets.QDialog):
         self.update_lower_figure()
 
         # user input: keyboard shortcuts
-        keypress_right = QtGui.QShortcut(
-            QtGui.QKeySequence(QtCore.Qt.Key.Key_Right), self
-        )
+        keypress_right = QShortcut(QKeySequence(Qt.Key.Key_Right), self)
         keypress_right.activated.connect(partial(self.shift_epoch, DIRECTION_RIGHT))
 
-        keypress_left = QtGui.QShortcut(
-            QtGui.QKeySequence(QtCore.Qt.Key.Key_Left), self
-        )
+        keypress_left = QShortcut(QKeySequence(Qt.Key.Key_Left), self)
         keypress_left.activated.connect(partial(self.shift_epoch, DIRECTION_LEFT))
 
         keypress_zoom_in_x = list()
-        for zoom_key in [QtCore.Qt.Key.Key_Plus, QtCore.Qt.Key.Key_Equal]:
-            keypress_zoom_in_x.append(
-                QtGui.QShortcut(QtGui.QKeySequence(zoom_key), self)
-            )
+        for zoom_key in [Qt.Key.Key_Plus, Qt.Key.Key_Equal]:
+            keypress_zoom_in_x.append(QShortcut(QKeySequence(zoom_key), self))
             keypress_zoom_in_x[-1].activated.connect(partial(self.zoom_x, ZOOM_IN))
 
-        keypress_zoom_out_x = QtGui.QShortcut(
-            QtGui.QKeySequence(QtCore.Qt.Key.Key_Minus), self
-        )
+        keypress_zoom_out_x = QShortcut(QKeySequence(Qt.Key.Key_Minus), self)
         keypress_zoom_out_x.activated.connect(partial(self.zoom_x, ZOOM_OUT))
 
         keypress_modify_label = list()
         for brain_state in self.brain_state_set.brain_states:
             keypress_modify_label.append(
-                QtGui.QShortcut(
-                    QtGui.QKeySequence(QtCore.Qt.Key[f"Key_{brain_state.digit}"]),
+                QShortcut(
+                    QKeySequence(Qt.Key[f"Key_{brain_state.digit}"]),
                     self,
                 )
             )
@@ -225,25 +234,19 @@ class ManualScoringWindow(QtWidgets.QDialog):
                 partial(self.modify_current_epoch_label, brain_state.digit)
             )
 
-        keypress_delete_label = QtGui.QShortcut(
-            QtGui.QKeySequence(QtCore.Qt.Key.Key_Backspace), self
-        )
+        keypress_delete_label = QShortcut(QKeySequence(Qt.Key.Key_Backspace), self)
         keypress_delete_label.activated.connect(
             partial(self.modify_current_epoch_label, UNDEFINED_LABEL)
         )
 
-        keypress_quit = QtGui.QShortcut(
-            QtGui.QKeySequence(
-                QtCore.QKeyCombination(QtCore.Qt.Modifier.CTRL, QtCore.Qt.Key.Key_W)
-            ),
+        keypress_quit = QShortcut(
+            QKeySequence(QKeyCombination(Qt.Modifier.CTRL, Qt.Key.Key_W)),
             self,
         )
         keypress_quit.activated.connect(self.close)
 
-        keypress_save = QtGui.QShortcut(
-            QtGui.QKeySequence(
-                QtCore.QKeyCombination(QtCore.Qt.Modifier.CTRL, QtCore.Qt.Key.Key_S)
-            ),
+        keypress_save = QShortcut(
+            QKeySequence(QKeyCombination(Qt.Modifier.CTRL, Qt.Key.Key_S)),
             self,
         )
         keypress_save.activated.connect(self.save)
@@ -251,11 +254,11 @@ class ManualScoringWindow(QtWidgets.QDialog):
         keypress_roi = list()
         for brain_state in self.brain_state_set.brain_states:
             keypress_roi.append(
-                QtGui.QShortcut(
-                    QtGui.QKeySequence(
-                        QtCore.QKeyCombination(
-                            QtCore.Qt.Modifier.SHIFT,
-                            QtCore.Qt.Key[f"Key_{brain_state.digit}"],
+                QShortcut(
+                    QKeySequence(
+                        QKeyCombination(
+                            Qt.Modifier.SHIFT,
+                            Qt.Key[f"Key_{brain_state.digit}"],
                         )
                     ),
                     self,
@@ -265,11 +268,11 @@ class ManualScoringWindow(QtWidgets.QDialog):
                 partial(self.enter_label_roi_mode, brain_state.digit)
             )
         keypress_roi.append(
-            QtGui.QShortcut(
-                QtGui.QKeySequence(
-                    QtCore.QKeyCombination(
-                        QtCore.Qt.Modifier.SHIFT,
-                        QtCore.Qt.Key.Key_Backspace,
+            QShortcut(
+                QKeySequence(
+                    QKeyCombination(
+                        Qt.Modifier.SHIFT,
+                        Qt.Key.Key_Backspace,
                     )
                 ),
                 self,
@@ -279,22 +282,18 @@ class ManualScoringWindow(QtWidgets.QDialog):
             partial(self.enter_label_roi_mode, UNDEFINED_LABEL)
         )
 
-        keypress_esc = QtGui.QShortcut(
-            QtGui.QKeySequence(QtCore.Qt.Key.Key_Escape), self
-        )
+        keypress_esc = QShortcut(QKeySequence(Qt.Key.Key_Escape), self)
         keypress_esc.activated.connect(self.exit_label_roi_mode)
 
-        keypress_space = QtGui.QShortcut(
-            QtGui.QKeySequence(QtCore.Qt.Key.Key_Space), self
-        )
+        keypress_space = QShortcut(QKeySequence(Qt.Key.Key_Space), self)
         keypress_space.activated.connect(
             partial(self.jump_to_next_state, DIRECTION_RIGHT, DIFFERENT_STATE)
         )
-        keypress_shift_right = QtGui.QShortcut(
-            QtGui.QKeySequence(
-                QtCore.QKeyCombination(
-                    QtCore.Qt.Modifier.SHIFT,
-                    QtCore.Qt.Key.Key_Right,
+        keypress_shift_right = QShortcut(
+            QKeySequence(
+                QKeyCombination(
+                    Qt.Modifier.SHIFT,
+                    Qt.Key.Key_Right,
                 )
             ),
             self,
@@ -302,11 +301,11 @@ class ManualScoringWindow(QtWidgets.QDialog):
         keypress_shift_right.activated.connect(
             partial(self.jump_to_next_state, DIRECTION_RIGHT, DIFFERENT_STATE)
         )
-        keypress_shift_left = QtGui.QShortcut(
-            QtGui.QKeySequence(
-                QtCore.QKeyCombination(
-                    QtCore.Qt.Modifier.SHIFT,
-                    QtCore.Qt.Key.Key_Left,
+        keypress_shift_left = QShortcut(
+            QKeySequence(
+                QKeyCombination(
+                    Qt.Modifier.SHIFT,
+                    Qt.Key.Key_Left,
                 )
             ),
             self,
@@ -314,11 +313,11 @@ class ManualScoringWindow(QtWidgets.QDialog):
         keypress_shift_left.activated.connect(
             partial(self.jump_to_next_state, DIRECTION_LEFT, DIFFERENT_STATE)
         )
-        keypress_ctrl_right = QtGui.QShortcut(
-            QtGui.QKeySequence(
-                QtCore.QKeyCombination(
-                    QtCore.Qt.Modifier.CTRL,
-                    QtCore.Qt.Key.Key_Right,
+        keypress_ctrl_right = QShortcut(
+            QKeySequence(
+                QKeyCombination(
+                    Qt.Modifier.CTRL,
+                    Qt.Key.Key_Right,
                 )
             ),
             self,
@@ -326,11 +325,11 @@ class ManualScoringWindow(QtWidgets.QDialog):
         keypress_ctrl_right.activated.connect(
             partial(self.jump_to_next_state, DIRECTION_RIGHT, UNDEFINED_STATE)
         )
-        keypress_ctrl_left = QtGui.QShortcut(
-            QtGui.QKeySequence(
-                QtCore.QKeyCombination(
-                    QtCore.Qt.Modifier.CTRL,
-                    QtCore.Qt.Key.Key_Left,
+        keypress_ctrl_left = QShortcut(
+            QKeySequence(
+                QKeyCombination(
+                    Qt.Modifier.CTRL,
+                    Qt.Key.Key_Left,
                 )
             ),
             self,
@@ -339,17 +338,13 @@ class ManualScoringWindow(QtWidgets.QDialog):
             partial(self.jump_to_next_state, DIRECTION_LEFT, UNDEFINED_STATE)
         )
 
-        keypress_undo = QtGui.QShortcut(
-            QtGui.QKeySequence(
-                QtCore.QKeyCombination(QtCore.Qt.Modifier.CTRL, QtCore.Qt.Key.Key_Z)
-            ),
+        keypress_undo = QShortcut(
+            QKeySequence(QKeyCombination(Qt.Modifier.CTRL, Qt.Key.Key_Z)),
             self,
         )
         keypress_undo.activated.connect(self.undo)
-        keypress_redo = QtGui.QShortcut(
-            QtGui.QKeySequence(
-                QtCore.QKeyCombination(QtCore.Qt.Modifier.CTRL, QtCore.Qt.Key.Key_Y)
-            ),
+        keypress_redo = QShortcut(
+            QKeySequence(QKeyCombination(Qt.Modifier.CTRL, Qt.Key.Key_Y)),
             self,
         )
         keypress_redo.activated.connect(self.redo)
@@ -483,34 +478,34 @@ class ManualScoringWindow(QtWidgets.QDialog):
         )
         self.click_to_jump(simulated_click)
 
-    def closeEvent(self, event: QtGui.QCloseEvent) -> None:
+    def closeEvent(self, event: QCloseEvent) -> None:
         """Check if there are unsaved changes before closing"""
         if not all(self.labels == self.last_saved_labels):
-            result = QtWidgets.QMessageBox.question(
+            result = QMessageBox.question(
                 self,
                 "Unsaved changes",
                 "You have unsaved changes. Really quit?",
-                QtWidgets.QMessageBox.Yes | QtWidgets.QMessageBox.No,
+                QMessageBox.Yes | QMessageBox.No,
             )
-            if result == QtWidgets.QMessageBox.Yes:
+            if result == QMessageBox.Yes:
                 event.accept()
             else:
                 event.ignore()
 
     def show_user_manual(self) -> None:
         """Show a popup window with the user manual"""
-        self.popup = QtWidgets.QWidget()
-        self.popup_vlayout = QtWidgets.QVBoxLayout(self.popup)
-        self.guide_textbox = QtWidgets.QTextBrowser(self.popup)
+        self.popup = QWidget()
+        self.popup_vlayout = QVBoxLayout(self.popup)
+        self.guide_textbox = QTextBrowser(self.popup)
         self.popup_vlayout.addWidget(self.guide_textbox)
 
-        url = QtCore.QUrl.fromLocalFile(
+        url = QUrl.fromLocalFile(
             os.path.join(os.path.dirname(os.path.abspath(__file__)), USER_MANUAL_FILE)
         )
         self.guide_textbox.setSource(url)
         self.guide_textbox.setOpenLinks(False)
 
-        self.popup.setGeometry(QtCore.QRect(100, 100, 830, 600))
+        self.popup.setGeometry(QRect(100, 100, 830, 600))
         self.popup.show()
 
     def jump_to_next_state(self, direction: str, target: str) -> None:
@@ -339,18 +339,23 @@ def resample_x_ticks(x_ticks: np.array) -> np.array:
     """Choose a subset of x_ticks to display
 
     The x-axis can get crowded if there are too many timestamps shown.
-    This function resamples the x-axis ticks by a factor of either
-    MAX_LOWER_X_TICK_N or MAX_LOWER_X_TICK_N - 2, whichever is closer
-    to being a factor of the number of ticks.
+    This function finds a subset of evenly spaced x-axis ticks that
+    includes the one at the beginning of the central epoch.
 
     :param x_ticks: full set of x_ticks
     :return: smaller subset of x_ticks
     """
-    # add one since the tick at the rightmost edge isn't shown
-    n_ticks = len(x_ticks) + 1
-    if n_ticks < MAX_LOWER_X_TICK_N:
+    if len(x_ticks) <= MAX_LOWER_X_TICK_N:
         return x_ticks
-    elif n_ticks % MAX_LOWER_X_TICK_N < n_ticks % (MAX_LOWER_X_TICK_N - 2):
-        return x_ticks[:: int(n_ticks / MAX_LOWER_X_TICK_N)]
-    else:
-        return x_ticks[:: int(n_ticks / (MAX_LOWER_X_TICK_N - 2))]
+
+    # number of ticks to the left of the central epoch
+    # this will always be an integer
+    nl = round((len(x_ticks) - 1) / 2)
+
+    # search for even tick spacings that include the central epoch
+    # if necessary, skip the leftmost tick
+    for offset in [0, 1]:
+        if (nl - offset) % 3 == 0:
+            return x_ticks[offset :: round((nl - offset) / 3)]
+        elif (nl - offset) % 2 == 0:
+            return x_ticks[offset :: round((nl - offset) / 2)]
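A quick worked example of the new tick-selection logic (illustrative only; it assumes MAX_LOWER_X_TICK_N is smaller than the 13 ticks used here, so the early return is skipped):

    import numpy as np

    x_ticks = np.arange(13)  # 13 ticks; the central epoch starts at index 6
    nl = round((len(x_ticks) - 1) / 2)  # nl = 6
    # offset 0: (6 - 0) % 3 == 0, so the stride is round(6 / 3) = 2
    x_ticks[0::2]  # array([ 0,  2,  4,  6,  8, 10, 12]) -- keeps the central tick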
@@ -30,6 +30,7 @@ from PySide6.QtWidgets import (
     QVBoxLayout,
     QWidget,
 )
+
 import accusleepy.gui.resources_rc  # noqa F401
 
 
accusleepy/models.py CHANGED
@@ -1,8 +1,9 @@
 import numpy as np
-import torch
-import torch.nn.functional as F
-from torch import nn
+from torch import device, flatten, nn
+from torch import load as torch_load
+from torch import save as torch_save
 
+from accusleepy.brain_state_set import BRAIN_STATES_KEY, BrainStateSet
 from accusleepy.constants import (
     DOWNSAMPLING_START_FREQ,
     EMG_COPIES,
@@ -41,8 +42,57 @@ class SSANN(nn.Module):
 
     def forward(self, x):
         x = x.float()
-        x = self.pool(F.relu(self.conv1_bn(self.conv1(x))))
-        x = self.pool(F.relu(self.conv2_bn(self.conv2(x))))
-        x = self.pool(F.relu(self.conv3_bn(self.conv3(x))))
-        x = torch.flatten(x, 1)  # flatten all dimensions except batch
+        x = self.pool(nn.functional.relu(self.conv1_bn(self.conv1(x))))
+        x = self.pool(nn.functional.relu(self.conv2_bn(self.conv2(x))))
+        x = self.pool(nn.functional.relu(self.conv3_bn(self.conv3(x))))
+        x = flatten(x, 1)  # flatten all dimensions except batch
         return self.fc1(x)
+
+
+def save_model(
+    model: SSANN,
+    filename: str,
+    epoch_length: int | float,
+    epochs_per_img: int,
+    model_type: str,
+    brain_state_set: BrainStateSet,
+) -> None:
+    """Save classification model and its metadata
+
+    :param model: classification model
+    :param epoch_length: epoch length used when training the model
+    :param epochs_per_img: number of epochs in each model input
+    :param model_type: default or real-time
+    :param brain_state_set: set of brain state options
+    :param filename: filename
+    """
+    state_dict = model.state_dict()
+    state_dict.update({"epoch_length": epoch_length})
+    state_dict.update({"epochs_per_img": epochs_per_img})
+    state_dict.update({"model_type": model_type})
+    state_dict.update(
+        {BRAIN_STATES_KEY: brain_state_set.to_output_dict()[BRAIN_STATES_KEY]}
+    )
+
+    torch_save(state_dict, filename)
+
+
+def load_model(filename: str) -> tuple[SSANN, int | float, int, str, dict]:
+    """Load classification model and its metadata
+
+    :param filename: filename
+    :return: model, epoch length used when training the model,
+        number of epochs in each model input, model type
+        (default or real-time), set of brain state options
+        used when training the model
+    """
+    state_dict = torch_load(filename, weights_only=True, map_location=device("cpu"))
+    epoch_length = state_dict.pop("epoch_length")
+    epochs_per_img = state_dict.pop("epochs_per_img")
+    model_type = state_dict.pop("model_type")
+    brain_states = state_dict.pop(BRAIN_STATES_KEY)
+    n_classes = len([b for b in brain_states if b["is_scored"]])
+
+    model = SSANN(n_classes=n_classes)
+    model.load_state_dict(state_dict)
+    return model, epoch_length, epochs_per_img, model_type, brain_states
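For orientation, a minimal usage sketch of the new helpers. The filename, epoch length, epochs-per-image count, and model type string below are assumptions for illustration, and brain_state_set stands in for an existing BrainStateSet instance; none of these values come from the package itself:

    from accusleepy.models import SSANN, load_model, save_model

    model = SSANN(n_classes=3)  # assume a model trained elsewhere
    save_model(
        model,
        filename="my_model.pth",  # hypothetical path
        epoch_length=2.5,  # assumed epoch length, in seconds
        epochs_per_img=9,  # assumed number of epochs per model input
        model_type="default",  # assumed; the real options live in accusleepy.constants
        brain_state_set=brain_state_set,  # an existing BrainStateSet instance
    )

    # round-trips the weights plus the metadata stored in the state dict
    model, epoch_length, epochs_per_img, model_type, brain_states = load_model(
        "my_model.pth"
    )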
accusleepy/multitaper.py CHANGED
@@ -15,8 +15,9 @@ import warnings
 
 import numpy as np
 from joblib import Parallel, cpu_count, delayed
-from scipy.signal import detrend
-from scipy.signal.windows import dpss
+
+# from scipy.signal import detrend  # unused by AccuSleePy
+# from scipy.signal.windows import dpss  # lazily loaded later
 
 
 # MULTITAPER SPECTROGRAM #
@@ -28,14 +29,14 @@ def spectrogram(
     num_tapers=None,
     window_params=None,
     min_nfft=0,
-    detrend_opt="linear",
+    detrend_opt="off",  # this functionality is disabled
     multiprocess=False,
     n_jobs=None,
     weighting="unity",
     plot_on=False,
     return_fig=False,
     clim_scale=True,
-    verbose=True,
+    verbose=False,
     xyflip=False,
     ax=None,
 ):
@@ -121,6 +122,7 @@ def spectrogram(
 
     __________________________________________________________________________________________________________________
     """
+    from scipy.signal.windows import dpss
 
     # Process user input
     [
@@ -618,9 +620,9 @@ def calc_mts_segment(
         ret.fill(np.nan)
         return ret
 
-    # Option to detrend data to remove low frequency DC component
-    if detrend_opt != "off":
-        data_segment = detrend(data_segment, type=detrend_opt)
+    # # Option to detrend data to remove low frequency DC component
+    # if detrend_opt != "off":
+    #     data_segment = detrend(data_segment, type=detrend_opt)
 
     # Multiply data by dpss tapers (STEP 2)
     tapered_data = np.multiply(np.asmatrix(data_segment).T, np.asmatrix(dpss_tapers.T))
@@ -1,17 +1,14 @@
 import os
-import re
 import warnings
-from dataclasses import dataclass
-from operator import attrgetter
 
 import numpy as np
 import pandas as pd
 from PIL import Image
-from scipy.signal import butter, filtfilt
 from tqdm import trange
 
 from accusleepy.brain_state_set import BrainStateSet
 from accusleepy.constants import (
+    ANNOTATIONS_FILENAME,
     DEFAULT_MODEL_TYPE,
     DOWNSAMPLING_START_FREQ,
     EMG_COPIES,
@@ -23,13 +20,13 @@ from accusleepy.constants import (
 from accusleepy.fileio import Recording, load_labels, load_recording
 from accusleepy.multitaper import spectrogram
 
+# note: scipy is lazily imported
+
 # clip mixture z-scores above and below this level
 # in the matlab implementation, I used 4.5
 ABS_MAX_Z_SCORE = 3.5
 # upper frequency limit when generating EEG spectrograms
 SPECTROGRAM_UPPER_FREQ = 64
-# filename used to store info about training image datasets
-ANNOTATIONS_FILENAME = "annotations.csv"
 
 
 def resample(
@@ -186,6 +183,8 @@ def get_emg_power(
     :param epoch_length: epoch length, in seconds
     :return: EMG "power" for each epoch
     """
+    from scipy.signal import butter, filtfilt
+
     # filter parameters
     order = 8
     bp_lower = 20
@@ -450,140 +449,3 @@ def create_training_images(
         )
 
     return failed_recordings
-
-
-@dataclass
-class Bout:
-    """Stores information about a brain state bout"""
-
-    length: int  # length, in number of epochs
-    start_index: int  # index where bout starts
-    end_index: int  # index where bout ends
-    surrounding_state: int  # brain state on both sides of the bout
-
-
-def find_last_adjacent_bout(sorted_bouts: list[Bout], bout_index: int) -> int:
-    """Find index of last consecutive same-length bout
-
-    When running the post-processing step that enforces a minimum duration
-    for brain state bouts, there is a special case when bouts below the
-    duration threshold occur consecutively. This function performs a
-    recursive search for the index of a bout at the end of such a sequence.
-    When initially called, bout_index will always be 0. If, for example, the
-    first three bouts in the list are consecutive, the function will return 2.
-
-    :param sorted_bouts: list of brain state bouts, sorted by start time
-    :param bout_index: index of the bout in question
-    :return: index of the last consecutive same-length bout
-    """
-    # if we're at the end of the bout list, stop
-    if bout_index == len(sorted_bouts) - 1:
-        return bout_index
-
-    # if there is an adjacent bout
-    if sorted_bouts[bout_index].end_index == sorted_bouts[bout_index + 1].start_index:
-        # look for more adjacent bouts using that one as a starting point
-        return find_last_adjacent_bout(sorted_bouts, bout_index + 1)
-    else:
-        return bout_index
-
-
-def enforce_min_bout_length(
-    labels: np.array, epoch_length: int | float, min_bout_length: int | float
-) -> np.array:
-    """Ensure brain state bouts meet the min length requirement
-
-    As a post-processing step for sleep scoring, we can require that any
-    bout (continuous period) of a brain state have a minimum duration.
-    This function sets any bout shorter than the minimum duration to the
-    surrounding brain state (if the states on the left and right sides
-    are the same). In the case where there are consecutive short bouts,
-    it either creates a transition at the midpoint or removes all short
-    bouts, depending on whether the number is even or odd. For example:
-    ...AAABABAAA... -> ...AAAAAAAAA...
-    ...AAABABABBB... -> ...AAAAABBBBB...
-
-    :param labels: brain state labels (digits in the 0-9 range)
-    :param epoch_length: epoch length, in seconds
-    :param min_bout_length: minimum bout length, in seconds
-    :return: updated brain state labels
-    """
-    # if recording is very short, don't change anything
-    if labels.size < 3:
-        return labels
-
-    if epoch_length == min_bout_length:
-        return labels
-
-    # get minimum number of epochs in a bout
-    min_epochs = int(np.ceil(min_bout_length / epoch_length))
-    # get set of states in the labels
-    brain_states = set(labels.tolist())
-
-    while True:  # so true
-        # convert labels to a string for regex search
-        # There is probably a regex that can find all patterns like ab+a
-        # without consuming each "a" but I haven't found it :(
-        label_string = "".join(labels.astype(str))
-
-        bouts = list()
-
-        for state in brain_states:
-            for other_state in brain_states:
-                if state == other_state:
-                    continue
-                # get start and end indices of each bout
-                expression = (
-                    f"(?<={other_state}){state}{{1,{min_epochs - 1}}}(?={other_state})"
-                )
-                matches = re.finditer(expression, label_string)
-                spans = [match.span() for match in matches]
-
-                # if some bouts were found
-                for span in spans:
-                    bouts.append(
-                        Bout(
-                            length=span[1] - span[0],
-                            start_index=span[0],
-                            end_index=span[1],
-                            surrounding_state=other_state,
-                        )
-                    )
-
-        if len(bouts) == 0:
-            break
-
-        # only keep the shortest bouts
-        min_length_in_list = np.min([bout.length for bout in bouts])
-        bouts = [i for i in bouts if i.length == min_length_in_list]
-        # sort by start index
-        sorted_bouts = sorted(bouts, key=attrgetter("start_index"))
-
-        while len(sorted_bouts) > 0:
-            # get row index of latest adjacent bout (of same length)
-            last_adjacent_bout_index = find_last_adjacent_bout(sorted_bouts, 0)
-            # if there's an even number of adjacent bouts
-            if (last_adjacent_bout_index + 1) % 2 == 0:
-                midpoint = sorted_bouts[
-                    round((last_adjacent_bout_index + 1) / 2)
-                ].start_index
-                labels[sorted_bouts[0].start_index : midpoint] = sorted_bouts[
-                    0
-                ].surrounding_state
-                labels[midpoint : sorted_bouts[last_adjacent_bout_index].end_index] = (
-                    sorted_bouts[last_adjacent_bout_index].surrounding_state
-                )
-            else:
-                labels[
-                    sorted_bouts[0].start_index : sorted_bouts[
-                        last_adjacent_bout_index
-                    ].end_index
-                ] = sorted_bouts[0].surrounding_state
-
-            # delete the bouts we just fixed
-            if last_adjacent_bout_index == len(sorted_bouts) - 1:
-                sorted_bouts = []
-            else:
-                sorted_bouts = sorted_bouts[(last_adjacent_bout_index + 1) :]
-
-    return labels