eegpp3-beta 1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eegpp3_beta-1.0/LICENSE +1 -0
- eegpp3_beta-1.0/PKG-INFO +69 -0
- eegpp3_beta-1.0/README.md +43 -0
- eegpp3_beta-1.0/pyproject.toml +38 -0
- eegpp3_beta-1.0/setup.cfg +4 -0
- eegpp3_beta-1.0/src/eegpp2/TOKEN.txt +3 -0
- eegpp3_beta-1.0/src/eegpp2/__init__.py +6 -0
- eegpp3_beta-1.0/src/eegpp2/__main__.py +98 -0
- eegpp3_beta-1.0/src/eegpp2/configs/__init__.py +7 -0
- eegpp3_beta-1.0/src/eegpp2/configs/cnn1d_config.yml +20 -0
- eegpp3_beta-1.0/src/eegpp2/configs/fft2cwout_config.yml +0 -0
- eegpp3_beta-1.0/src/eegpp2/configs/fftcnn1dnc_config.yml +48 -0
- eegpp3_beta-1.0/src/eegpp2/configs/ffttransnc_config.yml +7 -0
- eegpp3_beta-1.0/src/eegpp2/configs/stftcnn1dnc_config.yml +24 -0
- eegpp3_beta-1.0/src/eegpp2/configs/stfttransnc_config.yml +10 -0
- eegpp3_beta-1.0/src/eegpp2/configs/transformer_config.yml +7 -0
- eegpp3_beta-1.0/src/eegpp2/configs/wtcnn1dnc_config.yml +49 -0
- eegpp3_beta-1.0/src/eegpp2/configs/wtresnet501dnc_config.yml +3 -0
- eegpp3_beta-1.0/src/eegpp2/data/__init__.py +39 -0
- eegpp3_beta-1.0/src/eegpp2/dataloader.py +177 -0
- eegpp3_beta-1.0/src/eegpp2/dataset.py +114 -0
- eegpp3_beta-1.0/src/eegpp2/inference.py +311 -0
- eegpp3_beta-1.0/src/eegpp2/logger.py +67 -0
- eegpp3_beta-1.0/src/eegpp2/models/__init__.py +0 -0
- eegpp3_beta-1.0/src/eegpp2/models/baseline/__init__.py +0 -0
- eegpp3_beta-1.0/src/eegpp2/models/baseline/cnn1d.py +101 -0
- eegpp3_beta-1.0/src/eegpp2/models/baseline/fft.py +23 -0
- eegpp3_beta-1.0/src/eegpp2/models/baseline/resnet.py +148 -0
- eegpp3_beta-1.0/src/eegpp2/models/baseline/transformer.py +196 -0
- eegpp3_beta-1.0/src/eegpp2/models/cnn1d2c.py +123 -0
- eegpp3_beta-1.0/src/eegpp2/models/fft2c.py +59 -0
- eegpp3_beta-1.0/src/eegpp2/models/fftcnn1dnc.py +158 -0
- eegpp3_beta-1.0/src/eegpp2/models/ffttransnc.py +138 -0
- eegpp3_beta-1.0/src/eegpp2/models/stft2c.py +13 -0
- eegpp3_beta-1.0/src/eegpp2/models/stftcnn1dnc.py +166 -0
- eegpp3_beta-1.0/src/eegpp2/models/stfttransnc.py +186 -0
- eegpp3_beta-1.0/src/eegpp2/models/wtcnn1dnc.py +136 -0
- eegpp3_beta-1.0/src/eegpp2/models/wtresnet1dnc.py +105 -0
- eegpp3_beta-1.0/src/eegpp2/out/__init__.py +8 -0
- eegpp3_beta-1.0/src/eegpp2/overwrite_template.py +65 -0
- eegpp3_beta-1.0/src/eegpp2/params.py +27 -0
- eegpp3_beta-1.0/src/eegpp2/trainer.py +415 -0
- eegpp3_beta-1.0/src/eegpp2/trainer_backup.py +148 -0
- eegpp3_beta-1.0/src/eegpp2/utils/__init__.py +0 -0
- eegpp3_beta-1.0/src/eegpp2/utils/common_utils.py +85 -0
- eegpp3_beta-1.0/src/eegpp2/utils/config_utils.py +11 -0
- eegpp3_beta-1.0/src/eegpp2/utils/data_utils.py +265 -0
- eegpp3_beta-1.0/src/eegpp2/utils/model_utils.py +74 -0
- eegpp3_beta-1.0/src/eegpp2/visualization.py +209 -0
- eegpp3_beta-1.0/src/eegpp3_beta.egg-info/PKG-INFO +69 -0
- eegpp3_beta-1.0/src/eegpp3_beta.egg-info/SOURCES.txt +52 -0
- eegpp3_beta-1.0/src/eegpp3_beta.egg-info/dependency_links.txt +1 -0
- eegpp3_beta-1.0/src/eegpp3_beta.egg-info/requires.txt +13 -0
- eegpp3_beta-1.0/src/eegpp3_beta.egg-info/top_level.txt +1 -0
eegpp3_beta-1.0/LICENSE
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
(C) Tuan Vu from DSLab, SoICT, HUST
|
eegpp3_beta-1.0/PKG-INFO
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: eegpp3-beta
|
|
3
|
+
Version: 1.0
|
|
4
|
+
Summary: EEG Phrase Predictor ver 3
|
|
5
|
+
Author-email: "Duc Anh Nguyen, Vu An Tuan" <vuanhtuan1407@gmail.com>
|
|
6
|
+
Classifier: Programming Language :: Python :: 3
|
|
7
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
8
|
+
Classifier: Operating System :: OS Independent
|
|
9
|
+
Requires-Python: >=3.10
|
|
10
|
+
Description-Content-Type: text/markdown
|
|
11
|
+
License-File: LICENSE
|
|
12
|
+
Requires-Dist: dropbox>=12.0.2
|
|
13
|
+
Requires-Dist: joblib>=1.4.2
|
|
14
|
+
Requires-Dist: numpy>=2.1.1
|
|
15
|
+
Requires-Dist: tqdm>=4.66.5
|
|
16
|
+
Requires-Dist: torchinfo>=1.8.0
|
|
17
|
+
Requires-Dist: scipy>=1.14.1
|
|
18
|
+
Requires-Dist: pyyaml>=6.0.2
|
|
19
|
+
Requires-Dist: pywavelets>=1.7.0
|
|
20
|
+
Requires-Dist: ptwt>=0.1.9
|
|
21
|
+
Requires-Dist: pandas>=2.2.2
|
|
22
|
+
Requires-Dist: lightning>=2.4.0
|
|
23
|
+
Requires-Dist: scikit-learn>=1.5.2
|
|
24
|
+
Requires-Dist: seaborn>=0.13.2
|
|
25
|
+
Dynamic: license-file
|
|
26
|
+
|
|
27
|
+
# EEG Phase Predictor ver 2
|
|
28
|
+
|
|
29
|
+
**Note**: This is a beta version; use it for training with the default dataset and for inference only.
|
|
30
|
+
|
|
31
|
+
## Setup
|
|
32
|
+
|
|
33
|
+
- Requirements: python >= 3.10
|
|
34
|
+
- Installing:
|
|
35
|
+
|
|
36
|
+
```aiignore
|
|
37
|
+
pip install eegpp3-beta
|
|
38
|
+
```
|
|
39
|
+
|
|
40
|
+
## Train with default dataset
|
|
41
|
+
|
|
42
|
+
```aiignore
|
|
43
|
+
python -m eegpp3 --mode "train" --model_type <model_type> --lr <learning_rate> --batch_size <batch_size> --n_epochs <num_epochs> --n_splits <num_folds> --resume_checkpoint <resume_from_checkpoint>
|
|
44
|
+
```
|
|
45
|
+
|
|
46
|
+
ex: `python -m eegpp3 --mode "train" --model_type "stftcnn1dnc" --n_epochs 20 --n_splits 10 --resume_checkpoint False`
|
|
47
|
+
|
|
48
|
+
## Inference
|
|
49
|
+
|
|
50
|
+
1. Old version (**Recommend**)
|
|
51
|
+
|
|
52
|
+
```aiignore
|
|
53
|
+
python -m eegpp3 -p <path_to_infer_config.yml>
|
|
54
|
+
```
|
|
55
|
+
|
|
56
|
+
ex: `python -m eegpp3 -p "data_config_infer.yml"`
|
|
57
|
+
|
|
58
|
+
2. New version
|
|
59
|
+
|
|
60
|
+
```aiignore
|
|
61
|
+
python -m eegpp3 --mode "infer" --data_file <path_to_data_file> --infer_path <path_to_saving_file> --model_type <model_type>
|
|
62
|
+
```
|
|
63
|
+
|
|
64
|
+
ex:
|
|
65
|
+
`python -m eegpp3 --mode "infer" --data_file "./dump_eeg_1.pkl" --infer_path "./inference_result.txt" --model_type "stftcnn1dnc"`
|
|
66
|
+
|
|
67
|
+
## Model type
|
|
68
|
+
|
|
69
|
+
- stftcnn1dnc: Multi-channels STFT-CNN
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
# EEG Phase Predictor ver 2
|
|
2
|
+
|
|
3
|
+
**Note**: This is a beta version; use it for training with the default dataset and for inference only.
|
|
4
|
+
|
|
5
|
+
## Setup
|
|
6
|
+
|
|
7
|
+
- Requirements: python >= 3.10
|
|
8
|
+
- Installing:
|
|
9
|
+
|
|
10
|
+
```aiignore
|
|
11
|
+
pip install eegpp3-beta
|
|
12
|
+
```
|
|
13
|
+
|
|
14
|
+
## Train with default dataset
|
|
15
|
+
|
|
16
|
+
```aiignore
|
|
17
|
+
python -m eegpp3 --mode "train" --model_type <model_type> --lr <learning_rate> --batch_size <batch_size> --n_epochs <num_epochs> --n_splits <num_folds> --resume_checkpoint <resume_from_checkpoint>
|
|
18
|
+
```
|
|
19
|
+
|
|
20
|
+
ex: `python -m eegpp3 --mode "train" --model_type "stftcnn1dnc" --n_epochs 20 --n_splits 10 --resume_checkpoint False`
|
|
21
|
+
|
|
22
|
+
## Inference
|
|
23
|
+
|
|
24
|
+
1. Old version (**Recommend**)
|
|
25
|
+
|
|
26
|
+
```aiignore
|
|
27
|
+
python -m eegpp3 -p <path_to_infer_config.yml>
|
|
28
|
+
```
|
|
29
|
+
|
|
30
|
+
ex: `python -m eegpp3 -p "data_config_infer.yml"`
|
|
31
|
+
|
|
32
|
+
2. New version
|
|
33
|
+
|
|
34
|
+
```aiignore
|
|
35
|
+
python -m eegpp3 --mode "infer" --data_file <path_to_data_file> --infer_path <path_to_saving_file> --model_type <model_type>
|
|
36
|
+
```
|
|
37
|
+
|
|
38
|
+
ex:
|
|
39
|
+
`python -m eegpp3 --mode "infer" --data_file "./dump_eeg_1.pkl" --infer_path "./inference_result.txt" --model_type "stftcnn1dnc"`
|
|
40
|
+
|
|
41
|
+
## Model type
|
|
42
|
+
|
|
43
|
+
- stftcnn1dnc: Multi-channels STFT-CNN
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["setuptools>=61.0"]
|
|
3
|
+
build-backend = "setuptools.build_meta"
|
|
4
|
+
[project]
|
|
5
|
+
name = "eegpp3-beta"
|
|
6
|
+
version = "1.0"
|
|
7
|
+
authors = [
|
|
8
|
+
{ name = "Duc Anh Nguyen, Vu An Tuan", email = "vuanhtuan1407@gmail.com" },
|
|
9
|
+
]
|
|
10
|
+
description = "EEG Phrase Predictor ver 3"
|
|
11
|
+
readme = "README.md"
|
|
12
|
+
requires-python = ">=3.10"
|
|
13
|
+
classifiers = [
|
|
14
|
+
"Programming Language :: Python :: 3",
|
|
15
|
+
"License :: OSI Approved :: MIT License",
|
|
16
|
+
"Operating System :: OS Independent",
|
|
17
|
+
]
|
|
18
|
+
|
|
19
|
+
dependencies = [
|
|
20
|
+
'dropbox>=12.0.2',
|
|
21
|
+
'joblib>=1.4.2',
|
|
22
|
+
'numpy>=2.1.1',
|
|
23
|
+
'tqdm>=4.66.5',
|
|
24
|
+
'torchinfo>=1.8.0',
|
|
25
|
+
'scipy>=1.14.1',
|
|
26
|
+
'pyyaml>=6.0.2',
|
|
27
|
+
'pywavelets>=1.7.0',
|
|
28
|
+
'ptwt>=0.1.9',
|
|
29
|
+
'pandas>=2.2.2',
|
|
30
|
+
'lightning>=2.4.0',
|
|
31
|
+
'scikit-learn>=1.5.2',
|
|
32
|
+
'seaborn>=0.13.2',
|
|
33
|
+
]
|
|
34
|
+
[tool.setuptools.packages.find]
|
|
35
|
+
where = ["src"]
|
|
36
|
+
|
|
37
|
+
[tool.setuptools.package-data]
|
|
38
|
+
eegpp2 = ["configs/*.yml", "TOKEN.txt"]
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
|
|
3
|
+
from . import params
|
|
4
|
+
from .inference import infer, infer2
|
|
5
|
+
from .trainer import EEGKFoldTrainer
|
|
6
|
+
from .visualization import visualize_results
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def _str2bool(value):
    """Parse a CLI boolean flag value.

    argparse's ``type=bool`` is a trap: ``bool("False")`` is ``True`` because
    any non-empty string is truthy, so ``--resume_checkpoint False`` used to
    enable resuming.  This helper accepts the usual spellings of true/false
    (case-insensitive) and rejects anything else.

    :param value: raw string from the command line (or an actual bool)
    :return: the parsed boolean
    :raises argparse.ArgumentTypeError: if *value* is not a recognized boolean
    """
    if isinstance(value, bool):
        return value
    lowered = value.lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")


def parse_arguments():
    """Build and parse the command-line arguments for training and inference.

    Training-only options: ``--lr``, ``--n_epochs``, ``--n_splits``,
    ``--resume_checkpoint``, ``--auto_visualize``, ``--export_torchscript``.
    Inference-only options: ``--data_file``, ``--infer_path``,
    ``--remove_tmp``, ``-p/--yaml_config_path``.

    :return: the parsed :class:`argparse.Namespace`
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", type=str, default='infer', choices=['infer', 'train'])
    parser.add_argument("--model_type", type=str, default="stftcnn1dnc")
    parser.add_argument("--lr", type=float, default=5e-4)
    parser.add_argument("--batch_size", type=int, default=10)
    parser.add_argument("--n_epochs", type=int, default=2)
    parser.add_argument("--n_splits", type=int, default=2)
    parser.add_argument("--n_workers", type=int, default=0)
    # Boolean flags use _str2bool so "--flag False" really means False.
    parser.add_argument("--resume_checkpoint", type=_str2bool, default=False)
    parser.add_argument("--checkpoint_path", type=str, default=None)
    parser.add_argument('--auto_visualize', type=_str2bool, default=True)
    parser.add_argument("--early_stopping", type=int, default=None)
    parser.add_argument("--export_torchscript", type=_str2bool, default=False)
    parser.add_argument('--data_file', type=str, default='default',
                        help='Need to specify data file path when in infer mode')
    parser.add_argument("--infer_path", type=str, default=None)
    parser.add_argument("--remove_tmp", type=_str2bool, default=True)
    parser.add_argument('--accelerator', type=str, default='auto')
    parser.add_argument("-p", "--yaml_config_path", type=str, default='../../data_config_infer2.yml')

    return parser.parse_args()
|
|
43
|
+
|
|
44
|
+
def run():
    """CLI entry point: dispatch to training or inference based on ``--mode``.

    Train mode runs the full k-fold fit/test cycle and optionally renders the
    result plots.  Infer mode prefers a YAML config file (``infer2``) and
    falls back to the individual CLI options (``infer``).
    """
    args = parse_arguments()
    if args.mode == 'train':
        trainer = EEGKFoldTrainer(
            model_type=args.model_type,
            lr=args.lr,
            batch_size=args.batch_size,
            n_splits=args.n_splits,
            n_epochs=args.n_epochs,
            n_workers=args.n_workers,
            accelerator=args.accelerator,
            devices=params.DEVICES,
            # Force not apply early stopping because of KFold training process
            early_stopping=None,
            export_torchscript=args.export_torchscript,
            resume_checkpoint=args.resume_checkpoint,
            checkpoint_path=args.checkpoint_path,
        )
        trainer.fit()
        trainer.test()

        if args.auto_visualize:
            visualize_results(args.model_type)
    else:
        # Inference mode.  NOTE(review): --yaml_config_path has a non-None
        # default, so this guard only fires if that default is removed —
        # presumably intentional while the YAML path is the recommended flow.
        if args.data_file == 'default' and args.yaml_config_path is None:
            raise ValueError('--data_file must be specified in infer mode')
        if args.yaml_config_path is not None:
            # YAML config takes precedence over the individual CLI options.
            params.DATA_CONFIG_PATH = args.yaml_config_path
            infer2(args)
        else:
            infer(
                data_path=args.data_file,
                infer_path=args.infer_path,
                model_type=args.model_type,
                batch_size=args.batch_size,
                n_workers=args.n_workers,
                checkpoint_path=args.checkpoint_path,
                remove_tmp=args.remove_tmp,
            )


if __name__ == "__main__":
    run()
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
turnoff: null
|
|
2
|
+
conv_layers:
|
|
3
|
+
- name: conv1
|
|
4
|
+
out_channels: 4
|
|
5
|
+
kernel_size: 3
|
|
6
|
+
stride: 1
|
|
7
|
+
padding: 0
|
|
8
|
+
pooling_kernel_size: 2
|
|
9
|
+
pooling_stride: 2
|
|
10
|
+
pooling_padding: 0
|
|
11
|
+
dropout: 0.1
|
|
12
|
+
- name: conv2
|
|
13
|
+
out_channels: 3
|
|
14
|
+
kernel_size: 3
|
|
15
|
+
stride: 1
|
|
16
|
+
padding: 0
|
|
17
|
+
pooling_kernel_size: 2
|
|
18
|
+
pooling_stride: 2
|
|
19
|
+
pooling_padding: 0
|
|
20
|
+
dropout: 0.1
|
|
File without changes
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
num_chains: 3
|
|
2
|
+
n_fft: 5120
|
|
3
|
+
conv_layers:
|
|
4
|
+
- name: conv1
|
|
5
|
+
out_channels: 4096
|
|
6
|
+
kernel_size: 11
|
|
7
|
+
stride: 4
|
|
8
|
+
padding: 0
|
|
9
|
+
pooling_kernel_size: 3
|
|
10
|
+
pooling_stride: 2
|
|
11
|
+
pooling_padding: 0
|
|
12
|
+
dropout: 0.1
|
|
13
|
+
- name: conv2
|
|
14
|
+
out_channels: 5120
|
|
15
|
+
kernel_size: 5
|
|
16
|
+
stride: 1
|
|
17
|
+
padding: 0
|
|
18
|
+
pooling_kernel_size: 3
|
|
19
|
+
pooling_stride: 2
|
|
20
|
+
pooling_padding: 0
|
|
21
|
+
dropout: 0.1
|
|
22
|
+
- name: conv3
|
|
23
|
+
out_channels: 6144
|
|
24
|
+
kernel_size: 5
|
|
25
|
+
stride: 1
|
|
26
|
+
padding: 0
|
|
27
|
+
pooling_kernel_size: 3
|
|
28
|
+
pooling_stride: 2
|
|
29
|
+
pooling_padding: 0
|
|
30
|
+
dropout: 0.1
|
|
31
|
+
- name: conv4
|
|
32
|
+
out_channels: 3072
|
|
33
|
+
kernel_size: 5
|
|
34
|
+
stride: 1
|
|
35
|
+
padding: 0
|
|
36
|
+
pooling_kernel_size: 3
|
|
37
|
+
pooling_stride: 2
|
|
38
|
+
pooling_padding: 0
|
|
39
|
+
dropout: 0.1
|
|
40
|
+
- name: conv5
|
|
41
|
+
out_channels: 1024
|
|
42
|
+
kernel_size: 5
|
|
43
|
+
stride: 1
|
|
44
|
+
padding: 0
|
|
45
|
+
pooling_kernel_size: 3
|
|
46
|
+
pooling_stride: 2
|
|
47
|
+
pooling_padding: 0
|
|
48
|
+
dropout: 0.1
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
num_chains: 3
|
|
2
|
+
n_fft: 2048
|
|
3
|
+
win_length: 2048
|
|
4
|
+
hop_length: 512
|
|
5
|
+
normalized: True
|
|
6
|
+
conv_layers:
|
|
7
|
+
- name: conv1
|
|
8
|
+
out_channels: 4096
|
|
9
|
+
kernel_size: 3
|
|
10
|
+
stride: 1
|
|
11
|
+
padding: 0
|
|
12
|
+
pooling_kernel_size: 3
|
|
13
|
+
pooling_stride: 2
|
|
14
|
+
pooling_padding: 0
|
|
15
|
+
dropout: 0.1
|
|
16
|
+
- name: conv2
|
|
17
|
+
out_channels: 1024
|
|
18
|
+
kernel_size: 1
|
|
19
|
+
stride: 1
|
|
20
|
+
padding: 0
|
|
21
|
+
pooling_kernel_size: 2
|
|
22
|
+
pooling_stride: 2
|
|
23
|
+
pooling_padding: 0
|
|
24
|
+
dropout: 0.1
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
num_chains: 3
|
|
2
|
+
emb_size: 100
|
|
3
|
+
wavelet: "morl"
|
|
4
|
+
conv_layers:
|
|
5
|
+
- name: conv1
|
|
6
|
+
out_channels: 4096
|
|
7
|
+
kernel_size: 11
|
|
8
|
+
stride: 4
|
|
9
|
+
padding: 0
|
|
10
|
+
pooling_kernel_size: 3
|
|
11
|
+
pooling_stride: 2
|
|
12
|
+
pooling_padding: 0
|
|
13
|
+
dropout: 0.1
|
|
14
|
+
- name: conv2
|
|
15
|
+
out_channels: 5120
|
|
16
|
+
kernel_size: 5
|
|
17
|
+
stride: 1
|
|
18
|
+
padding: 0
|
|
19
|
+
pooling_kernel_size: 3
|
|
20
|
+
pooling_stride: 2
|
|
21
|
+
pooling_padding: 0
|
|
22
|
+
dropout: 0.1
|
|
23
|
+
- name: conv3
|
|
24
|
+
out_channels: 6144
|
|
25
|
+
kernel_size: 5
|
|
26
|
+
stride: 1
|
|
27
|
+
padding: 0
|
|
28
|
+
pooling_kernel_size: 3
|
|
29
|
+
pooling_stride: 2
|
|
30
|
+
pooling_padding: 0
|
|
31
|
+
dropout: 0.1
|
|
32
|
+
- name: conv4
|
|
33
|
+
out_channels: 3072
|
|
34
|
+
kernel_size: 5
|
|
35
|
+
stride: 1
|
|
36
|
+
padding: 0
|
|
37
|
+
pooling_kernel_size: 3
|
|
38
|
+
pooling_stride: 2
|
|
39
|
+
pooling_padding: 0
|
|
40
|
+
dropout: 0.1
|
|
41
|
+
- name: conv5
|
|
42
|
+
out_channels: 1024
|
|
43
|
+
kernel_size: 5
|
|
44
|
+
stride: 1
|
|
45
|
+
padding: 0
|
|
46
|
+
pooling_kernel_size: 3
|
|
47
|
+
pooling_stride: 2
|
|
48
|
+
pooling_padding: 0
|
|
49
|
+
dropout: 0.1
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
|
|
4
|
+
# Directory layout: raw recordings live in <package>/raw, pickled dumps in
# <package>/dump.
DATA_DIR = os.path.abspath(os.path.dirname(__file__))
RAW_DATA_DIR = os.path.join(DATA_DIR, 'raw')
DUMP_DATA_DIR = os.path.join(DATA_DIR, 'dump')

# Ensure the dump directory exists before anything tries to write into it.
os.makedirs(DUMP_DATA_DIR, exist_ok=True)

# Raw EEG signal recordings bundled with the default dataset.
SEQ_FILES = [
    str(Path(RAW_DATA_DIR, basename))
    for basename in (
        "raw_K3_EEG3_11h.txt",
        "raw_RS2_EEG1_23 hr.txt",
        # Excluded recordings, kept for reference:
        # "raw_S1_EEG1_23 hr.txt",
        # "K1_EEG1_SAL.csv",
        # "K1_EEG7_SAL.csv",
        # "K2_EEG4_SAL.csv",
        # "K2_EEG5_SAL.csv",
        # "K4_EEG7_SAL.csv",
    )
]

# Label files, index-aligned with SEQ_FILES.
LABEL_FILES = [
    str(Path(RAW_DATA_DIR, basename))
    for basename in (
        "K3_EEG3_11h.txt",
        "RS2_EEG1_23 hr.txt",
        # Excluded labels, kept for reference:
        # "S1_EEG1_23 hr.txt",
        # "K1_EEG1_11h.txt",
        # "K1_EEG7_11h.txt",
        # "K2_EEG4_11h.txt",
        # "K2_EEG5_11h.txt",
        # "K4_EEG7_11h.txt",
    )
]

# Pickled dump files, one per raw recording: "train" dumps carry labels,
# "infer" dumps are suffixed with "_infer".
DUMP_DATA_FILES = {
    mode: [
        str(Path(DUMP_DATA_DIR, f"dump_eeg_{idx + 1}{suffix}.pkl"))
        for idx in range(len(SEQ_FILES))
    ]
    for mode, suffix in (("train", ""), ("infer", "_infer"))
}
|
|
@@ -0,0 +1,177 @@
|
|
|
1
|
+
import os
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import torch
|
|
5
|
+
from sklearn.model_selection import KFold
|
|
6
|
+
from torch.utils.data import Subset, ConcatDataset, DataLoader, random_split, Sampler, RandomSampler
|
|
7
|
+
|
|
8
|
+
from .data import DUMP_DATA_FILES, DUMP_DATA_DIR
|
|
9
|
+
from . import params
|
|
10
|
+
from .dataset import EEGDataset
|
|
11
|
+
from .utils.data_utils import get_dataset_train
|
|
12
|
+
|
|
13
|
+
# A label file's header ends at the line that starts with this column marker.
LABEL_MARKER = "EpochNo"


def fread_header_labels(inp):
    """Read the header section of a label file.

    Consumes lines from *inp* up to and including the first line starting
    with ``LABEL_MARKER`` ("EpochNo"), or until EOF if no such line exists.

    Note: the original declared ``global SEP_CHECKED, SEPERATOR`` but never
    referenced either name; that dead statement has been removed.

    :param inp: path to the label file
    :return: tuple ``(header_text, fin)`` where *header_text* is the
        concatenated header lines and *fin* is the still-open file handle
        positioned at the first data row — the caller owns and must close it
    """
    # errors='ignore' because raw label files may contain stray non-UTF-8 bytes.
    fin = open(inp, errors='ignore')

    headers = []

    while True:
        line = fin.readline()
        if line == "":
            # EOF reached without seeing the marker.
            break
        headers.append(line)
        if line.startswith(LABEL_MARKER):
            break

    return "".join(headers), fin
|
|
30
|
+
|
|
31
|
+
class EEGKFoldSampler(Sampler):
    """Placeholder sampler intended to drive k-fold batch ordering.

    NOTE(review): this class is an unimplemented stub — the constructor
    ignores its arguments and both ``__iter__`` and ``__len__`` return
    ``None``, so instances cannot actually be iterated or measured yet.
    """

    def __init__(self, dataset, k_fold):
        # `dataset` and `k_fold` describe the planned API but are unused.
        super().__init__()

    def __iter__(self):
        pass

    def __len__(self):
        pass
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class EEGKFoldDataLoader:
    """K-fold dataloader factory over one or more dumped EEG datasets.

    Each dump file becomes one ``EEGDataset``.  Per dataset, 10% is held out
    as test data; the remaining 90% is split into ``n_splits`` KFold folds.
    Per-dataset folds are concatenated so every fold mixes all recordings.
    """

    def __init__(
            self,
            dataset_files='default',
            n_splits=5,
            n_workers=0,
            batch_size=4,
            minmax_normalized=True,
    ):
        """
        :param dataset_files: must be list of Path to dump file or "default"
        :param n_splits: number of KFold folds over the train/val portion
        :param n_workers: number of DataLoader worker processes
        :param batch_size: DataLoader batch size
        :param minmax_normalized: forwarded to ``EEGDataset``
        """

        self.minmax_normalized = minmax_normalized

        self.val_dataset = None
        self.train_dataset = None
        self.test_dataset = None
        self.datasets = None
        self.all_splits = None
        self.train_val_datasets = None
        self.current_fold = None
        self.current_epoch = None

        if isinstance(dataset_files, list):
            self.dataset_files = dataset_files
        else:
            # Any non-list value means "default": fetch the bundled dumps.
            self.prepare_default_data()
            self.dataset_files = DUMP_DATA_FILES['train']

        self.n_splits = n_splits
        self.n_workers = n_workers
        self.batch_size = batch_size
        # Fixed seed so the 90/10 train-val/test split is reproducible.
        self.split_generator = torch.Generator().manual_seed(params.RD_SEED)
        self.dataloader_generator = torch.Generator().manual_seed(0)
        self.setup()

    @staticmethod
    def prepare_default_data():
        """Download the default dump files if they are not all present.

        BUGFIX: the original walked the whole tree with ``os.walk`` and could
        trigger the download once per subdirectory; only the top level of
        DUMP_DATA_DIR is relevant, so check it directly.
        """
        os.makedirs(DUMP_DATA_DIR, exist_ok=True)
        if len(os.listdir(DUMP_DATA_DIR)) != len(DUMP_DATA_FILES['train']):
            get_dataset_train(remote_type='dump')

    def setup(self):
        """Load every dump file and precompute the test split and fold indices."""
        print("Loading data...")
        datasets = []
        train_val_dts, test_dts = [], []
        for dump_file in self.dataset_files:
            # BUGFIX: load the configured file itself; the original indexed
            # DUMP_DATA_FILES['train'][i] and silently ignored a
            # user-supplied dataset_files list.
            print("Loading dump file {}".format(dump_file))
            i_dataset = EEGDataset(dump_file, w_out=params.W_OUT, minmax_normalized=self.minmax_normalized)
            datasets.append(i_dataset)
            train_val_dt, test_dt = random_split(i_dataset, [0.9, 0.1], generator=self.split_generator)
            train_val_dts.append(train_val_dt)
            test_dts.append(test_dt)

        self.datasets = datasets
        self.train_val_datasets = train_val_dts
        self.test_dataset = ConcatDataset(test_dts)
        kf = KFold(n_splits=self.n_splits, shuffle=True, random_state=params.RD_SEED)
        # One list of (train_ids, val_ids) pairs per dataset.
        self.all_splits = [list(kf.split(train_val_dt)) for train_val_dt in train_val_dts]

    def set_fold(self, k):
        """Select fold *k* and rebuild the concatenated train/val datasets.

        Invalid *k* (``None``, negative, or >= n_splits) falls back to fold 0.
        """
        if k is None or k < 0 or k >= self.n_splits:
            print("Fold value is invalid. Set to default value: 0")
            self.current_fold = 0
        else:
            self.current_fold = k

        train_dts, val_dts = [], []
        for i, (dataset, splits) in enumerate(zip(self.datasets, self.all_splits)):
            # BUGFIX: index with the sanitized current_fold; the original
            # used the raw `k`, which crashed on set_fold(None) and ignored
            # the fallback-to-0 above.
            train_ids, val_ids = splits[self.current_fold]
            # KFold indices are relative to the train/val subset; map them
            # back to indices of the underlying dataset.
            train_subset_ids = np.array(self.train_val_datasets[i].indices)[train_ids]
            val_subset_ids = np.array(self.train_val_datasets[i].indices)[val_ids]
            train_dts.append(Subset(dataset, train_subset_ids))
            val_dts.append(Subset(dataset, val_subset_ids))

        self.train_dataset = ConcatDataset(train_dts)
        self.val_dataset = ConcatDataset(val_dts)

    def set_epoch(self, epoch):
        """Record *epoch* and reseed shuffling per (fold, epoch) pair.

        NOTE(review): ``(epoch + 1) * (fold + 1)`` is not unique across pairs
        (e.g. 1*4 == 2*2), so distinct fold/epoch combinations can share a
        shuffle order — preserved as-is to keep behavior identical.
        """
        if self.current_fold is None:
            print("Fold attribute is None. Set to default value: 0")
            self.current_fold = 0
        self.current_epoch = epoch
        dataloader_generator_seed = (self.current_epoch + 1) * (self.current_fold + 1)
        self.dataloader_generator = self.dataloader_generator.manual_seed(dataloader_generator_seed)

    def train_dataloader(self, epoch):
        """Return a shuffled DataLoader for the current fold's training split."""
        self.set_epoch(epoch)
        return DataLoader(
            dataset=self.train_dataset,
            batch_size=self.batch_size,
            # shuffle must be False when an explicit sampler is supplied;
            # shuffling comes from the seeded RandomSampler.
            shuffle=False,
            sampler=RandomSampler(self.train_dataset, generator=self.dataloader_generator),
            num_workers=self.n_workers,
            drop_last=True
        )

    def val_dataloader(self):
        """Return an unshuffled DataLoader for the current fold's validation split."""
        return DataLoader(
            dataset=self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.n_workers,
            drop_last=True,
        )

    def test_dataloader(self):
        """Return an unshuffled DataLoader for the held-out test split."""
        return DataLoader(
            dataset=self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.n_workers,
            drop_last=True,
        )
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
if __name__ == '__main__':
    # Smoke test: build the default dataloaders, select fold 0, and stream
    # one epoch of training batches to stdout.
    loader = EEGKFoldDataLoader()
    loader.set_fold(0)
    for batch_idx, batch in enumerate(loader.train_dataloader(epoch=0)):
        print(batch)
|