braindecode 0.8__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of braindecode might be problematic. Click here for more details.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +50 -0
- braindecode/augmentation/base.py +222 -0
- braindecode/augmentation/functional.py +1096 -0
- braindecode/augmentation/transforms.py +1274 -0
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +34 -0
- braindecode/datasets/base.py +840 -0
- braindecode/datasets/bbci.py +694 -0
- braindecode/datasets/bcicomp.py +194 -0
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +172 -0
- braindecode/datasets/moabb.py +209 -0
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +588 -0
- braindecode/datasets/xy.py +95 -0
- braindecode/datautil/__init__.py +49 -0
- braindecode/datautil/serialization.py +342 -0
- braindecode/datautil/util.py +41 -0
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +10 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +52 -0
- braindecode/models/atcnet.py +652 -0
- braindecode/models/attentionbasenet.py +550 -0
- braindecode/models/base.py +296 -0
- braindecode/models/biot.py +483 -0
- braindecode/models/contrawr.py +296 -0
- braindecode/models/ctnet.py +450 -0
- braindecode/models/deep4.py +322 -0
- braindecode/models/deepsleepnet.py +295 -0
- braindecode/models/eegconformer.py +372 -0
- braindecode/models/eeginception_erp.py +304 -0
- braindecode/models/eeginception_mi.py +371 -0
- braindecode/models/eegitnet.py +301 -0
- braindecode/models/eegminer.py +255 -0
- braindecode/models/eegnet.py +473 -0
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +362 -0
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +325 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1166 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +182 -0
- braindecode/models/shallow_fbcsp.py +208 -0
- braindecode/models/signal_jepa.py +1012 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +167 -0
- braindecode/models/sleep_stager_chambon_2018.py +157 -0
- braindecode/models/sleep_stager_eldele_2021.py +536 -0
- braindecode/models/sparcnet.py +378 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +273 -0
- braindecode/models/tidnet.py +395 -0
- braindecode/models/tsinception.py +258 -0
- braindecode/models/usleep.py +340 -0
- braindecode/models/util.py +133 -0
- braindecode/modules/__init__.py +38 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +632 -0
- braindecode/modules/layers.py +133 -0
- braindecode/modules/linear.py +50 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +77 -0
- braindecode/modules/wrapper.py +75 -0
- braindecode/preprocessing/__init__.py +37 -0
- braindecode/preprocessing/mne_preprocess.py +77 -0
- braindecode/preprocessing/preprocess.py +478 -0
- braindecode/preprocessing/windowers.py +1031 -0
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +401 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +483 -0
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +57 -0
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
- braindecode-1.0.0.dist-info/RECORD +101 -0
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-0.8.dist-info/RECORD +0 -11
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,251 @@
|
|
|
1
|
+
# Authors: Robin Schirrmeister <robintibor@gmail.com>
|
|
2
|
+
#
|
|
3
|
+
# License: BSD (3-clause)
|
|
4
|
+
import math
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
import torch.nn.functional as F
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def square(x):
    """Return the element-wise square of ``x``, computed as ``x * x``."""
    squared = x * x
    return squared
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def safe_log(x, eps: float = 1e-6) -> torch.Tensor:
    """Natural logarithm that never evaluates :math:`log(0)`.

    Values of ``x`` below ``eps`` are clamped up to ``eps`` before taking the
    logarithm, i.e. the result is :math:`log(max(x, eps))`.
    """
    floored = torch.clamp(x, min=eps)
    return torch.log(floored)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def identity(x):
    """Return the argument unchanged (the very same object, not a copy)."""
    result = x
    return result
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def drop_path(
    x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
):
    """Drop paths (Stochastic Depth) per sample.

    During training, each sample along the first dimension of ``x`` is
    independently either kept or zeroed out entirely. It is zeroed with
    probability ``drop_prob``.

    Notes: This implementation is taken from timm library.

    All credit goes to Ross Wightman.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor. The first dimension is treated as the sample dimension;
        any number of trailing dimensions is supported.
    drop_prob : float, optional
        Probability of dropping (zeroing) a sample's path, by default 0.0.
        The survival (keep) probability is ``1 - drop_prob``.
    training : bool, optional
        Whether the model is in training mode, by default False. When False,
        ``x`` is returned unchanged.
    scale_by_keep : bool, optional
        Whether to divide kept samples by the keep probability
        ``1 - drop_prob`` during training, by default True.

    Returns
    -------
    torch.Tensor
        ``x`` itself when ``drop_prob == 0`` or ``training`` is False.
        Otherwise, a new tensor with the same shape as ``x``, in which every
        sample is either all zeros or a (possibly rescaled) copy of the input
        sample.

    Notes from Ross Wightman:
    (when applied in main path of residual blocks)
    This is the same as the DropConnect impl I created for EfficientNet,
    etc. networks, however,
    the original name is misleading as 'Drop Connect' is a different form
    of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956
    ... I've opted for changing the layer and argument names to 'drop path'
    rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.

    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dimensions
    # so this works with tensors of any rank, not just 2D ConvNet activations.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    # Skip rescaling when keep_prob == 0 to avoid dividing by zero; the mask
    # is then all zeros anyway.
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor.div_(keep_prob)
    return x * random_tensor
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> torch.Tensor:
    """Generate a normalized 1-dimensional Gaussian kernel.

    The Gaussian :math:`\\exp(-x^2 / (2 \\sigma^2))` is sampled at
    ``kernel_size`` evenly spaced points with unit spacing, centered on zero
    (from ``-(kernel_size - 1) / 2`` to ``(kernel_size - 1) / 2``).
    The samples are then divided by their sum, so the weights add up to 1.
    This makes the kernel suitable for Gaussian smoothing or filtering.

    Parameters
    ----------
    kernel_size : int
        Number of taps (samples) in the kernel.
    sigma : float
        Standard deviation of the Gaussian, in units of samples.

    Returns
    -------
    kernel1d : torch.Tensor
        Tensor of shape ``(kernel_size,)``, in the dtype produced by
        ``torch.linspace`` (the default floating-point dtype), whose entries
        sum to 1.

    Notes
    -----
    Code copied and modified from TorchVision:
    https://github.com/pytorch/vision/blob/main/torchvision/transforms/_functional_tensor.py#L725-L732
    All rights reserved.

    LICENSE in https://github.com/pytorch/vision/blob/main/LICENSE

    """
    ksize_half = (kernel_size - 1) * 0.5
    x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
    pdf = torch.exp(-0.5 * (x / sigma).pow(2))
    # Normalize by the sum (not a softmax) so the taps add up to exactly 1.
    kernel1d = pdf / pdf.sum()
    return kernel1d
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def hilbert_freq(x, forward_fourier=True):
    r"""
    Compute the analytic signal (Hilbert transform) using PyTorch, returning
    the real and imaginary parts separately.

    The analytic signal :math:`x_a(t)` of a real-valued signal :math:`x(t)`
    of length :math:`N` is defined as:

    .. math::

        x_a(t) = x(t) + i y(t) = \mathcal{F}^{-1} \{ h(f) \mathcal{F}\{x(t)\} \}

    where:
    - :math:`\mathcal{F}` is the Fourier transform,
    - :math:`h(f)` is 1 at the DC term (and at the Nyquist term when
      :math:`N` is even), 2 at strictly positive frequencies and 0 at
      negative frequencies,
    - :math:`y(t)` is the Hilbert transform of :math:`x(t)`.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor. The expected shape depends on the `forward_fourier`
        parameter:

        - If `forward_fourier` is True: a real time-domain signal of shape
          (..., seq_len).
        - If `forward_fourier` is False: a one-sided spectrum with separate
          real and imaginary parts, as returned by
          ``torch.view_as_real(torch.fft.rfft(signal))``. Its shape is
          (..., n_freqs, 2). The original length cannot be recovered from a
          one-sided spectrum, so it is assumed to be even:
          ``seq_len = 2 * (n_freqs - 1)``.

    forward_fourier : bool, optional
        Determines the format of the input tensor.
        - If True, the input is in the time domain and the forward (real)
          FFT is computed here.
        - If False, the input is already a one-sided Fourier spectrum with
          separate real and imaginary parts.
        Default is True.

    Returns
    -------
    torch.Tensor
        Output tensor with shape (..., seq_len, 2). The last dimension holds
        the real part (the original signal) and the imaginary part (its
        Hilbert transform) of the analytic signal.

    Examples
    --------
    >>> import torch
    >>> x = torch.randn(10, 100)  # Example input tensor
    >>> hilbert_freq(x).shape
    torch.Size([10, 100, 2])

    Notes
    -----
    The implementation matches ``scipy.signal.hilbert``, but uses torch.
    https://github.com/scipy/scipy/blob/v1.14.1/scipy/signal/_signaltools.py#L2287-L2394

    """
    if forward_fourier:
        n_times = x.shape[-1]
        x = torch.fft.rfft(x, norm=None, dim=-1)
        x = torch.view_as_real(x)
    else:
        # Same even-length assumption as torch.fft.irfft's default.
        n_times = 2 * (x.shape[-2] - 1)
    n_freqs = x.shape[-2]

    # One-sided spectral weights, as in scipy:
    # - the DC term is kept once;
    # - the Nyquist term (present only for even lengths) is kept once;
    # - every strictly positive frequency is doubled.
    weights = torch.full((n_freqs,), 2.0, dtype=x.dtype, device=x.device)
    weights[0] = 1.0
    if n_times % 2 == 0:
        weights[-1] = 1.0
    x = x * weights.unsqueeze(-1)

    # Zero-fill the negative frequencies so the inverse FFT has n_times points.
    x = F.pad(x, [0, 0, 0, n_times - n_freqs])
    x = torch.view_as_complex(x)
    x = torch.fft.ifft(x, norm=None, dim=-1)  # returns complex signal
    return torch.view_as_real(x)
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def plv_time(x, forward_fourier=True, epsilon: float = 1e-6):
    r"""Compute the Phase Locking Value (PLV) between every pair of channels.

    The PLV quantifies how consistent the phase difference between two
    channels stays over time. For channels :math:`i` and :math:`j` with
    instantaneous phases :math:`\phi_i(t)` and :math:`\phi_j(t)`:

    .. math::

        \mathrm{PLV}_{ij} = \frac{1}{T} \left| \sum_{t=1}^{T}
        e^{i(\phi_i(t) - \phi_j(t))} \right|

    Ideally it ranges from 0 (no synchronization) to 1 (perfect
    synchronization) [1]_. The small stabilizing constants used below mean
    the computed value only approximates this quantity.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor containing the signal data.
        - If `forward_fourier` is `True`, the shape should be
          `(..., channels, time)`.
        - If `forward_fourier` is `False`, the shape should be
          `(..., channels, freqs, 2)`, where the last dimension holds the
          real and imaginary parts.
    forward_fourier : bool, optional
        Specifies the format of the input tensor `x`.
        - If `True`, `x` is assumed to be in the time domain.
        - If `False`, `x` is assumed to be in the Fourier domain with
          separate real and imaginary components.
        Default is `True`.
    epsilon : float, default 1e-6
        Small constant added inside the final square root (the magnitude of
        the complex correlation) to keep it strictly positive.

    Returns
    -------
    plv : torch.Tensor
        The Phase Locking Value matrix with shape `(..., channels, channels)`.
        Element `[i, j]` is the PLV between channel `i` and channel `j`.

    References
    ----------
    [1] Lachaux, J. P., Rodriguez, E., Martinerie, J., & Varela, F. J. (1999).
        Measuring phase synchrony in brain signals. Human brain mapping,
        8(4), 194-208.
    """
    # Analytic signal with separate real/imaginary parts in the last axis.
    analytic = hilbert_freq(x, forward_fourier)

    # Instantaneous amplitude. The fixed 1e-6 keeps the normalization below
    # finite where the analytic signal vanishes.
    magnitude = torch.sqrt(analytic[..., 0] ** 2 + analytic[..., 1] ** 2 + 1e-6)

    # (Approximately) unit phasors: cos(phi) and sin(phi) per channel/time.
    phasor = analytic / magnitude.unsqueeze(-1)
    cos_phi = phasor[..., 0]
    sin_phi = phasor[..., 1]

    def _pairwise(a, b):
        # For every channel pair (i, j): sum over time of a[..., i, t] * b[..., j, t].
        return torch.matmul(a, b.transpose(-2, -1))

    # Real and imaginary parts of sum_t exp(i * (phi_i - phi_j)).
    corr_real = _pairwise(cos_phi, cos_phi) + _pairwise(sin_phi, sin_phi)
    corr_imag = _pairwise(sin_phi, cos_phi) - _pairwise(cos_phi, sin_phi)

    # hilbert_freq always returns a time-domain signal, so this is the
    # number of time samples in both input modes.
    n_times = magnitude.shape[-1]
    return (1 / n_times) * torch.sqrt(corr_real**2 + corr_imag**2 + epsilon)
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
import math
|
|
2
|
+
|
|
3
|
+
from torch import nn
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def glorot_weight_zero_bias(model):
    """Set BatchNorm weights to one and all biases to zero, in place.

    Iterates over every module returned by ``model.modules()`` and applies
    two rules:

    * a module whose class name contains ``"BatchNorm"`` and that has a
      non-``None`` ``weight`` gets that weight filled with 1;
    * a module with a non-``None`` ``bias`` gets that bias filled with 0.

    Weights of all other modules are left untouched and keep their
    existing initialization.

    NOTE(review): despite the function name, no Glorot/Xavier initialization
    is applied to any weight here. Confirm this is intended before relying on
    it.

    Parameters
    ----------
    model : torch.nn.Module
        Model whose parameters are (re)initialized in place.
    """
    for module in model.modules():
        weight = getattr(module, "weight", None)
        # Norm layers built without affine parameters register ``weight`` as
        # None. Skip them instead of crashing inside ``nn.init.constant_``.
        if weight is not None and "BatchNorm" in module.__class__.__name__:
            nn.init.constant_(weight, 1)
        bias = getattr(module, "bias", None)
        if bias is not None:
            nn.init.constant_(bias, 0)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def rescale_parameter(param, layer_id):
    r"""Rescale a parameter of the ``layer_id``-th transformer layer in place.

    Divides ``param`` in place by :math:`\sqrt{2 \cdot \text{layer\_id}}`,
    i.e. multiplies it by :math:`\frac{1}{\sqrt{2 \cdot \text{layer\_id}}}`
    [Beit2022].

    In the labram, this is used to rescale the output matrices
    (i.e., the last linear projection within each sub-layer) of the
    self-attention module.

    Parameters
    ----------
    param : :class:`torch.Tensor`
        Tensor to be rescaled; it is modified in place.
    layer_id : int
        Layer index in the neural network; must be strictly positive.

    Raises
    ------
    ValueError
        If ``layer_id`` is not strictly positive. Zero would divide the
        tensor by zero (filling it with inf), and a negative value has no
        real square root.

    References
    ----------
    [Beit2022] Hangbo Bao, Li Dong, Songhao Piao, Furu Wei (2022). BEIT: BERT
    Pre-Training of Image Transformers.
    """
    if layer_id <= 0:
        raise ValueError(f"layer_id must be strictly positive, got {layer_id}.")
    param.div_(math.sqrt(2.0 * layer_id))
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Some predefined network architectures for EEG decoding.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from .atcnet import ATCNet
|
|
6
|
+
from .attentionbasenet import AttentionBaseNet
|
|
7
|
+
from .base import EEGModuleMixin
|
|
8
|
+
from .biot import BIOT
|
|
9
|
+
from .contrawr import ContraWR
|
|
10
|
+
from .ctnet import CTNet
|
|
11
|
+
from .deep4 import Deep4Net
|
|
12
|
+
from .deepsleepnet import DeepSleepNet
|
|
13
|
+
from .eegconformer import EEGConformer
|
|
14
|
+
from .eeginception_erp import EEGInceptionERP
|
|
15
|
+
from .eeginception_mi import EEGInceptionMI
|
|
16
|
+
from .eegitnet import EEGITNet
|
|
17
|
+
from .eegminer import EEGMiner
|
|
18
|
+
from .eegnet import EEGNetv1, EEGNetv4
|
|
19
|
+
from .eegnex import EEGNeX
|
|
20
|
+
from .eegresnet import EEGResNet
|
|
21
|
+
from .eegsimpleconv import EEGSimpleConv
|
|
22
|
+
from .eegtcnet import EEGTCNet
|
|
23
|
+
from .fbcnet import FBCNet
|
|
24
|
+
from .fblightconvnet import FBLightConvNet
|
|
25
|
+
from .fbmsnet import FBMSNet
|
|
26
|
+
from .hybrid import HybridNet
|
|
27
|
+
from .ifnet import IFNet
|
|
28
|
+
from .labram import Labram
|
|
29
|
+
from .msvtnet import MSVTNet
|
|
30
|
+
from .sccnet import SCCNet
|
|
31
|
+
from .shallow_fbcsp import ShallowFBCSPNet
|
|
32
|
+
from .signal_jepa import (
|
|
33
|
+
SignalJEPA,
|
|
34
|
+
SignalJEPA_Contextual,
|
|
35
|
+
SignalJEPA_PostLocal,
|
|
36
|
+
SignalJEPA_PreLocal,
|
|
37
|
+
)
|
|
38
|
+
from .sinc_shallow import SincShallowNet
|
|
39
|
+
from .sleep_stager_blanco_2020 import SleepStagerBlanco2020
|
|
40
|
+
from .sleep_stager_chambon_2018 import SleepStagerChambon2018
|
|
41
|
+
from .sleep_stager_eldele_2021 import SleepStagerEldele2021
|
|
42
|
+
from .sparcnet import SPARCNet
|
|
43
|
+
from .syncnet import SyncNet
|
|
44
|
+
from .tcn import BDTCN, TCN
|
|
45
|
+
from .tidnet import TIDNet
|
|
46
|
+
from .tsinception import TSceptionV1
|
|
47
|
+
from .usleep import USleep
|
|
48
|
+
from .util import _init_models_dict, models_mandatory_parameters
|
|
49
|
+
|
|
50
|
+
# Call this last, so that the models dictionary built by ``_init_models_dict``
# (imported from ``.util``) is populated with all models imported in this file.
_init_models_dict()
|