eegpp3-beta 1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54) hide show
  1. eegpp3_beta-1.0/LICENSE +1 -0
  2. eegpp3_beta-1.0/PKG-INFO +69 -0
  3. eegpp3_beta-1.0/README.md +43 -0
  4. eegpp3_beta-1.0/pyproject.toml +38 -0
  5. eegpp3_beta-1.0/setup.cfg +4 -0
  6. eegpp3_beta-1.0/src/eegpp2/TOKEN.txt +3 -0
  7. eegpp3_beta-1.0/src/eegpp2/__init__.py +6 -0
  8. eegpp3_beta-1.0/src/eegpp2/__main__.py +98 -0
  9. eegpp3_beta-1.0/src/eegpp2/configs/__init__.py +7 -0
  10. eegpp3_beta-1.0/src/eegpp2/configs/cnn1d_config.yml +20 -0
  11. eegpp3_beta-1.0/src/eegpp2/configs/fft2cwout_config.yml +0 -0
  12. eegpp3_beta-1.0/src/eegpp2/configs/fftcnn1dnc_config.yml +48 -0
  13. eegpp3_beta-1.0/src/eegpp2/configs/ffttransnc_config.yml +7 -0
  14. eegpp3_beta-1.0/src/eegpp2/configs/stftcnn1dnc_config.yml +24 -0
  15. eegpp3_beta-1.0/src/eegpp2/configs/stfttransnc_config.yml +10 -0
  16. eegpp3_beta-1.0/src/eegpp2/configs/transformer_config.yml +7 -0
  17. eegpp3_beta-1.0/src/eegpp2/configs/wtcnn1dnc_config.yml +49 -0
  18. eegpp3_beta-1.0/src/eegpp2/configs/wtresnet501dnc_config.yml +3 -0
  19. eegpp3_beta-1.0/src/eegpp2/data/__init__.py +39 -0
  20. eegpp3_beta-1.0/src/eegpp2/dataloader.py +177 -0
  21. eegpp3_beta-1.0/src/eegpp2/dataset.py +114 -0
  22. eegpp3_beta-1.0/src/eegpp2/inference.py +311 -0
  23. eegpp3_beta-1.0/src/eegpp2/logger.py +67 -0
  24. eegpp3_beta-1.0/src/eegpp2/models/__init__.py +0 -0
  25. eegpp3_beta-1.0/src/eegpp2/models/baseline/__init__.py +0 -0
  26. eegpp3_beta-1.0/src/eegpp2/models/baseline/cnn1d.py +101 -0
  27. eegpp3_beta-1.0/src/eegpp2/models/baseline/fft.py +23 -0
  28. eegpp3_beta-1.0/src/eegpp2/models/baseline/resnet.py +148 -0
  29. eegpp3_beta-1.0/src/eegpp2/models/baseline/transformer.py +196 -0
  30. eegpp3_beta-1.0/src/eegpp2/models/cnn1d2c.py +123 -0
  31. eegpp3_beta-1.0/src/eegpp2/models/fft2c.py +59 -0
  32. eegpp3_beta-1.0/src/eegpp2/models/fftcnn1dnc.py +158 -0
  33. eegpp3_beta-1.0/src/eegpp2/models/ffttransnc.py +138 -0
  34. eegpp3_beta-1.0/src/eegpp2/models/stft2c.py +13 -0
  35. eegpp3_beta-1.0/src/eegpp2/models/stftcnn1dnc.py +166 -0
  36. eegpp3_beta-1.0/src/eegpp2/models/stfttransnc.py +186 -0
  37. eegpp3_beta-1.0/src/eegpp2/models/wtcnn1dnc.py +136 -0
  38. eegpp3_beta-1.0/src/eegpp2/models/wtresnet1dnc.py +105 -0
  39. eegpp3_beta-1.0/src/eegpp2/out/__init__.py +8 -0
  40. eegpp3_beta-1.0/src/eegpp2/overwrite_template.py +65 -0
  41. eegpp3_beta-1.0/src/eegpp2/params.py +27 -0
  42. eegpp3_beta-1.0/src/eegpp2/trainer.py +415 -0
  43. eegpp3_beta-1.0/src/eegpp2/trainer_backup.py +148 -0
  44. eegpp3_beta-1.0/src/eegpp2/utils/__init__.py +0 -0
  45. eegpp3_beta-1.0/src/eegpp2/utils/common_utils.py +85 -0
  46. eegpp3_beta-1.0/src/eegpp2/utils/config_utils.py +11 -0
  47. eegpp3_beta-1.0/src/eegpp2/utils/data_utils.py +265 -0
  48. eegpp3_beta-1.0/src/eegpp2/utils/model_utils.py +74 -0
  49. eegpp3_beta-1.0/src/eegpp2/visualization.py +209 -0
  50. eegpp3_beta-1.0/src/eegpp3_beta.egg-info/PKG-INFO +69 -0
  51. eegpp3_beta-1.0/src/eegpp3_beta.egg-info/SOURCES.txt +52 -0
  52. eegpp3_beta-1.0/src/eegpp3_beta.egg-info/dependency_links.txt +1 -0
  53. eegpp3_beta-1.0/src/eegpp3_beta.egg-info/requires.txt +13 -0
  54. eegpp3_beta-1.0/src/eegpp3_beta.egg-info/top_level.txt +1 -0
@@ -0,0 +1 @@
1
+ (C) Tuan Vu from DSLab, SoICT, HUST
@@ -0,0 +1,69 @@
1
+ Metadata-Version: 2.4
2
+ Name: eegpp3-beta
3
+ Version: 1.0
4
+ Summary: EEG Phrase Predictor ver 3
5
+ Author-email: "Duc Anh Nguyen, Vu An Tuan" <vuanhtuan1407@gmail.com>
6
+ Classifier: Programming Language :: Python :: 3
7
+ Classifier: License :: OSI Approved :: MIT License
8
+ Classifier: Operating System :: OS Independent
9
+ Requires-Python: >=3.10
10
+ Description-Content-Type: text/markdown
11
+ License-File: LICENSE
12
+ Requires-Dist: dropbox>=12.0.2
13
+ Requires-Dist: joblib>=1.4.2
14
+ Requires-Dist: numpy>=2.1.1
15
+ Requires-Dist: tqdm>=4.66.5
16
+ Requires-Dist: torchinfo>=1.8.0
17
+ Requires-Dist: scipy>=1.14.1
18
+ Requires-Dist: pyyaml>=6.0.2
19
+ Requires-Dist: pywavelets>=1.7.0
20
+ Requires-Dist: ptwt>=0.1.9
21
+ Requires-Dist: pandas>=2.2.2
22
+ Requires-Dist: lightning>=2.4.0
23
+ Requires-Dist: scikit-learn>=1.5.2
24
+ Requires-Dist: seaborn>=0.13.2
25
+ Dynamic: license-file
26
+
27
+ # EEG Phase Predictor ver 2
28
+
29
+ **Note**: This is a beta version; use it for training with the default dataset and for inference only.
30
+
31
+ ## Setup
32
+
33
+ - Requirements: python >= 3.10
34
+ - Installing:
35
+
36
+ ```aiignore
37
+ pip install eegpp3-beta
38
+ ```
39
+
40
+ ## Train with default dataset
41
+
42
+ ```aiignore
43
+ python -m eegpp3 --mode "train" --model_type <mode_type> --lr <learning_rate> --batch_size <batch_size> --n_epochs <num_epochs> --n_splits <num_folds> --resume_checkpoint <resume_from_checkpoint>
44
+ ```
45
+
46
+ ex: `python -m eegpp3 --mode "train" --model_type "stftcnn1dnc" --n_epochs 20 --n_splits 10 --resume_checkpoint False`
47
+
48
+ ## Inference
49
+
50
+ 1. Old version (**Recommended**)
51
+
52
+ ```aiignore
53
+ python -m eegpp3 -p <path_to_infer_config.yml>
54
+ ```
55
+
56
+ ex: `python -m eegpp3 -p "data_config_infer.yml"`
57
+
58
+ 2. New version
59
+
60
+ ```aiignore
61
+ python -m eegpp3 --mode "infer" --data_path <path_to_data_file> --infer_path <path_to_saving_file> --model_type <model_type>
62
+ ```
63
+
64
+ ex:
65
+ `python -m eegpp3 --mode "infer" --data_path "./dump_eeg_1.pkl" --infer_path "./inference_result.txt" --model_type "stftcnn1dnc"`
66
+
67
+ ## Model type
68
+
69
+ - stftcnn1dnc: Multi-channels STFT-CNN
@@ -0,0 +1,43 @@
1
+ # EEG Phase Predictor ver 2
2
+
3
+ **Note**: This is a beta version; use it for training with the default dataset and for inference only.
4
+
5
+ ## Setup
6
+
7
+ - Requirements: python >= 3.10
8
+ - Installing:
9
+
10
+ ```aiignore
11
+ pip install eegpp3-beta
12
+ ```
13
+
14
+ ## Train with default dataset
15
+
16
+ ```aiignore
17
+ python -m eegpp3 --mode "train" --model_type <mode_type> --lr <learning_rate> --batch_size <batch_size> --n_epochs <num_epochs> --n_splits <num_folds> --resume_checkpoint <resume_from_checkpoint>
18
+ ```
19
+
20
+ ex: `python -m eegpp3 --mode "train" --model_type "stftcnn1dnc" --n_epochs 20 --n_splits 10 --resume_checkpoint False`
21
+
22
+ ## Inference
23
+
24
+ 1. Old version (**Recommended**)
25
+
26
+ ```aiignore
27
+ python -m eegpp3 -p <path_to_infer_config.yml>
28
+ ```
29
+
30
+ ex: `python -m eegpp3 -p "data_config_infer.yml"`
31
+
32
+ 2. New version
33
+
34
+ ```aiignore
35
+ python -m eegpp3 --mode "infer" --data_path <path_to_data_file> --infer_path <path_to_saving_file> --model_type <model_type>
36
+ ```
37
+
38
+ ex:
39
+ `python -m eegpp3 --mode "infer" --data_path "./dump_eeg_1.pkl" --infer_path "./inference_result.txt" --model_type "stftcnn1dnc"`
40
+
41
+ ## Model type
42
+
43
+ - stftcnn1dnc: Multi-channels STFT-CNN
@@ -0,0 +1,38 @@
1
+ [build-system]
2
+ requires = ["setuptools>=61.0"]
3
+ build-backend = "setuptools.build_meta"
4
+ [project]
5
+ name = "eegpp3-beta"
6
+ version = "1.0"
7
+ authors = [
8
+ { name = "Duc Anh Nguyen, Vu An Tuan", email = "vuanhtuan1407@gmail.com" },
9
+ ]
10
+ description = "EEG Phrase Predictor ver 3"
11
+ readme = "README.md"
12
+ requires-python = ">=3.10"
13
+ classifiers = [
14
+ "Programming Language :: Python :: 3",
15
+ "License :: OSI Approved :: MIT License",
16
+ "Operating System :: OS Independent",
17
+ ]
18
+
19
+ dependencies = [
20
+ 'dropbox>=12.0.2',
21
+ 'joblib>=1.4.2',
22
+ 'numpy>=2.1.1',
23
+ 'tqdm>=4.66.5',
24
+ 'torchinfo>=1.8.0',
25
+ 'scipy>=1.14.1',
26
+ 'pyyaml>=6.0.2',
27
+ 'pywavelets>=1.7.0',
28
+ 'ptwt>=0.1.9',
29
+ 'pandas>=2.2.2',
30
+ 'lightning>=2.4.0',
31
+ 'scikit-learn>=1.5.2',
32
+ 'seaborn>=0.13.2',
33
+ ]
34
+ [tool.setuptools.packages.find]
35
+ where = ["src"]
36
+
37
+ [tool.setuptools.package-data]
38
+ eegpp2 = ["configs/*.yml", "TOKEN.txt"]
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,3 @@
1
+ fq7ro54b8vcrz85
2
+ i4foei4891bdnah
3
+ yFJ0pz4d4TcAAAAAAAAAAV38Cd7Q1DXwn4L5dvu2K5zohoAWpG8mQqiNkbcKUTfp
@@ -0,0 +1,6 @@
1
import os

# Absolute path to the installed eegpp2 package directory.
EEGPP_DIR = os.path.dirname(os.path.abspath(__file__))

# Package-local cache directory; created eagerly on first import so callers
# can write to it without checking for existence.
CACHE_DIR = os.path.join(EEGPP_DIR, 'cache')
os.makedirs(CACHE_DIR, exist_ok=True)
@@ -0,0 +1,98 @@
1
+ import argparse
2
+
3
+ from . import params
4
+ from .inference import infer, infer2
5
+ from .trainer import EEGKFoldTrainer
6
+ from .visualization import visualize_results
7
+
8
+
9
+ def parse_arguments():
10
+ parser = argparse.ArgumentParser()
11
+ parser.add_argument("--mode", type=str, default='infer', choices=['infer', 'train'])
12
+ parser.add_argument("--model_type", type=str, default="stftcnn1dnc")
13
+ parser.add_argument("--lr", type=float, default=5e-4)
14
+ parser.add_argument("--batch_size", type=int, default=10)
15
+ parser.add_argument("--n_epochs", type=int, default=2)
16
+ parser.add_argument("--n_splits", type=int, default=2)
17
+ parser.add_argument("--n_workers", type=int, default=0)
18
+ parser.add_argument("--resume_checkpoint", type=bool, default=False)
19
+ parser.add_argument("--checkpoint_path", type=str, default=None)
20
+ parser.add_argument('--auto_visualize', type=bool, default=True)
21
+ parser.add_argument("--early_stopping", type=int, default=None)
22
+ parser.add_argument("--export_torchscript", type=bool, default=False)
23
+ parser.add_argument('--data_file', type=str, default='default',
24
+ help='Need to specify data file path when in infer mode')
25
+ parser.add_argument("--infer_path", type=str, default=None)
26
+ parser.add_argument("--remove_tmp", type=bool, default=True)
27
+ parser.add_argument('--accelerator', type=str, default='auto')
28
+ parser.add_argument("-p", "--yaml_config_path", type=str, default='../../data_config_infer2.yml')
29
+
30
+ return parser.parse_args()
31
+
32
+ #
33
+ # def parse_options():
34
+ # parser = argparse.ArgumentParser()
35
+ # parser.add_argument('--data_file', type=str, default='default')
36
+ # parser.add_argument("--infer_path", type=str, default=None)
37
+ # parser.add_argument("--model_type", type=str, default="stftcnn1dnc")
38
+ # parser.add_argument("--batch_size", type=int, default=10)
39
+ # parser.add_argument("--n_workers", type=int, default=0)
40
+ # parser.add_argument("--checkpoint_path", type=str, default=None)
41
+ # return parser.parse_args()
42
+ #
43
+
44
def run():
    """CLI entry point: dispatch to K-fold training or inference based on --mode."""
    args = parse_arguments()
    if args.mode == 'train':
        trainer = EEGKFoldTrainer(
            model_type=args.model_type,
            lr=args.lr,
            batch_size=args.batch_size,
            n_splits=args.n_splits,
            n_epochs=args.n_epochs,
            n_workers=args.n_workers,
            accelerator=args.accelerator,
            devices=params.DEVICES,
            early_stopping=None,  # Force not apply early stopping because of KFold training process
            export_torchscript=args.export_torchscript,
            resume_checkpoint=args.resume_checkpoint,
            checkpoint_path=args.checkpoint_path,
        )
        trainer.fit()
        trainer.test()

        if args.auto_visualize:
            visualize_results(args.model_type)
    else:
        # opts = parse_options()
        opts = args  # dummy code
        if args.data_file == 'default' and opts.yaml_config_path is None:
            raise ValueError('--data_file must be specified in infer mode')
        else:
            if opts.yaml_config_path is not None:
                # NOTE(review): --yaml_config_path defaults to a non-None
                # string, so this branch is always taken from the CLI and the
                # plain infer() path below is unreachable — confirm intended.
                params.DATA_CONFIG_PATH = opts.yaml_config_path
                infer2(opts)

            else:
                infer(
                    data_path=opts.data_file,
                    infer_path=opts.infer_path,
                    model_type=opts.model_type,
                    batch_size=opts.batch_size,
                    n_workers=opts.n_workers,
                    checkpoint_path=opts.checkpoint_path,
                    remove_tmp=opts.remove_tmp,
                )
86
+
87
+
88
+ if __name__ == "__main__":
89
+ run()
90
+
91
+ # import torch
92
+ #
93
+ # t = torch.rand(2, 5120)
94
+ # window = torch.hamming_window(2048)
95
+ # rf = torch.fft.rfft(t)
96
+ # rs = torch.stft(t, n_fft=2048, win_length=2048, hop_length=512, normalized=True, return_complex=True, window=window)
97
+ # rt = torch.sqrt(rs.real ** 2 + rs.imag ** 2)
98
+ # print(rt, rt.shape)
@@ -0,0 +1,7 @@
1
import os

from .. import EEGPP_DIR

# Directory holding the packaged *_config.yml model configuration files.
CONFIG_DIR = os.path.join(EEGPP_DIR, 'configs')

# NOTE(review): ``tmp`` duplicates CONFIG_DIR's value and is not used in this
# module — confirm whether any caller imports it before removing.
tmp = os.path.abspath(os.path.dirname(__file__))
@@ -0,0 +1,20 @@
1
+ turnoff: null
2
+ conv_layers:
3
+ - name: conv1
4
+ out_channels: 4
5
+ kernel_size: 3
6
+ stride: 1
7
+ padding: 0
8
+ pooling_kernel_size: 2
9
+ pooling_stride: 2
10
+ pooling_padding: 0
11
+ dropout: 0.1
12
+ - name: conv2
13
+ out_channels: 3
14
+ kernel_size: 3
15
+ stride: 1
16
+ padding: 0
17
+ pooling_kernel_size: 2
18
+ pooling_stride: 2
19
+ pooling_padding: 0
20
+ dropout: 0.1
@@ -0,0 +1,48 @@
1
+ num_chains: 3
2
+ n_fft: 5120
3
+ conv_layers:
4
+ - name: conv1
5
+ out_channels: 4096
6
+ kernel_size: 11
7
+ stride: 4
8
+ padding: 0
9
+ pooling_kernel_size: 3
10
+ pooling_stride: 2
11
+ pooling_padding: 0
12
+ dropout: 0.1
13
+ - name: conv2
14
+ out_channels: 5120
15
+ kernel_size: 5
16
+ stride: 1
17
+ padding: 0
18
+ pooling_kernel_size: 3
19
+ pooling_stride: 2
20
+ pooling_padding: 0
21
+ dropout: 0.1
22
+ - name: conv3
23
+ out_channels: 6144
24
+ kernel_size: 5
25
+ stride: 1
26
+ padding: 0
27
+ pooling_kernel_size: 3
28
+ pooling_stride: 2
29
+ pooling_padding: 0
30
+ dropout: 0.1
31
+ - name: conv4
32
+ out_channels: 3072
33
+ kernel_size: 5
34
+ stride: 1
35
+ padding: 0
36
+ pooling_kernel_size: 3
37
+ pooling_stride: 2
38
+ pooling_padding: 0
39
+ dropout: 0.1
40
+ - name: conv5
41
+ out_channels: 1024
42
+ kernel_size: 5
43
+ stride: 1
44
+ padding: 0
45
+ pooling_kernel_size: 3
46
+ pooling_stride: 2
47
+ pooling_padding: 0
48
+ dropout: 0.1
@@ -0,0 +1,7 @@
1
+ num_chains: 3
2
+ n_fft: 1024
3
+ d_model: 1024
4
+ dropout: 0.1
5
+ nhead: 8
6
+ dim_feedforward: 2048
7
+ num_layers: 1
@@ -0,0 +1,24 @@
1
+ num_chains: 3
2
+ n_fft: 2048
3
+ win_length: 2048
4
+ hop_length: 512
5
+ normalized: True
6
+ conv_layers:
7
+ - name: conv1
8
+ out_channels: 4096
9
+ kernel_size: 3
10
+ stride: 1
11
+ padding: 0
12
+ pooling_kernel_size: 3
13
+ pooling_stride: 2
14
+ pooling_padding: 0
15
+ dropout: 0.1
16
+ - name: conv2
17
+ out_channels: 1024
18
+ kernel_size: 1
19
+ stride: 1
20
+ padding: 0
21
+ pooling_kernel_size: 2
22
+ pooling_stride: 2
23
+ pooling_padding: 0
24
+ dropout: 0.1
@@ -0,0 +1,10 @@
1
+ num_chains: 3
2
+ n_fft: 1024
3
+ win_length: 1024
4
+ hop_length: 512
5
+ return_complex: True
6
+ normalized: True
7
+ dropout: 0.1
8
+ nhead: 4
9
+ dim_feedforward: 2048
10
+ num_layers: 1
@@ -0,0 +1,7 @@
1
+ d_model: 512
2
+ dropout: 0.1
3
+ nhead_encoder: 4
4
+ nhead_decoder: 4
5
+ num_encoder_layers: 1
6
+ num_decoder_layers: 1
7
+ dim_feedforward: 2048
@@ -0,0 +1,49 @@
1
+ num_chains: 3
2
+ emb_size: 100
3
+ wavelet: "morl"
4
+ conv_layers:
5
+ - name: conv1
6
+ out_channels: 4096
7
+ kernel_size: 11
8
+ stride: 4
9
+ padding: 0
10
+ pooling_kernel_size: 3
11
+ pooling_stride: 2
12
+ pooling_padding: 0
13
+ dropout: 0.1
14
+ - name: conv2
15
+ out_channels: 5120
16
+ kernel_size: 5
17
+ stride: 1
18
+ padding: 0
19
+ pooling_kernel_size: 3
20
+ pooling_stride: 2
21
+ pooling_padding: 0
22
+ dropout: 0.1
23
+ - name: conv3
24
+ out_channels: 6144
25
+ kernel_size: 5
26
+ stride: 1
27
+ padding: 0
28
+ pooling_kernel_size: 3
29
+ pooling_stride: 2
30
+ pooling_padding: 0
31
+ dropout: 0.1
32
+ - name: conv4
33
+ out_channels: 3072
34
+ kernel_size: 5
35
+ stride: 1
36
+ padding: 0
37
+ pooling_kernel_size: 3
38
+ pooling_stride: 2
39
+ pooling_padding: 0
40
+ dropout: 0.1
41
+ - name: conv5
42
+ out_channels: 1024
43
+ kernel_size: 5
44
+ stride: 1
45
+ padding: 0
46
+ pooling_kernel_size: 3
47
+ pooling_stride: 2
48
+ pooling_padding: 0
49
+ dropout: 0.1
@@ -0,0 +1,3 @@
1
+ num_chains: 3
2
+ emb_size: 100
3
+ wavelet: "morl"
@@ -0,0 +1,39 @@
1
import os
from pathlib import Path

# Locations of raw text data and pickled ("dump") datasets inside the package.
DATA_DIR = os.path.abspath(os.path.dirname(__file__))
RAW_DATA_DIR = os.path.join(DATA_DIR, 'raw')
DUMP_DATA_DIR = os.path.join(DATA_DIR, 'dump')

os.makedirs(DUMP_DATA_DIR, exist_ok=True)

# Raw EEG signal files; entries pair index-for-index with LABEL_FILES below.
SEQ_FILES = [
    str(Path(RAW_DATA_DIR, "raw_K3_EEG3_11h.txt")),
    str(Path(RAW_DATA_DIR, "raw_RS2_EEG1_23 hr.txt")),
    #str(Path(RAW_DATA_DIR, "raw_S1_EEG1_23 hr.txt")),
    #str(Path(RAW_DATA_DIR, "K1_EEG1_SAL.csv")),
    #str(Path(RAW_DATA_DIR, "K1_EEG7_SAL.csv")),
    #str(Path(RAW_DATA_DIR, "K2_EEG4_SAL.csv")),
    #str(Path(RAW_DATA_DIR, "K2_EEG5_SAL.csv")),
    #str(Path(RAW_DATA_DIR, "K4_EEG7_SAL.csv")),
]

# Label files, one per entry in SEQ_FILES (same order).
LABEL_FILES = [
    str(Path(RAW_DATA_DIR, "K3_EEG3_11h.txt")),
    str(Path(RAW_DATA_DIR, "RS2_EEG1_23 hr.txt")),
    #str(Path(RAW_DATA_DIR, "S1_EEG1_23 hr.txt")),
    #str(Path(RAW_DATA_DIR, "K1_EEG1_11h.txt")),
    #str(Path(RAW_DATA_DIR, "K1_EEG7_11h.txt")),
    #str(Path(RAW_DATA_DIR, "K2_EEG4_11h.txt")),
    #str(Path(RAW_DATA_DIR, "K2_EEG5_11h.txt")),
    #str(Path(RAW_DATA_DIR, "K4_EEG7_11h.txt")),
]

# Expected pickle paths, one per raw sequence file, for training and inference.
DUMP_DATA_FILES = {
    "train": [
        str(Path(DUMP_DATA_DIR, f"dump_eeg_{i+1}.pkl")) for i in range(len(SEQ_FILES))
    ],
    "infer": [
        str(Path(DUMP_DATA_DIR, f"dump_eeg_{i+1}_infer.pkl")) for i in range(len(SEQ_FILES))
    ],
}
@@ -0,0 +1,177 @@
1
+ import os
2
+
3
+ import numpy as np
4
+ import torch
5
+ from sklearn.model_selection import KFold
6
+ from torch.utils.data import Subset, ConcatDataset, DataLoader, random_split, Sampler, RandomSampler
7
+
8
+ from .data import DUMP_DATA_FILES, DUMP_DATA_DIR
9
+ from . import params
10
+ from .dataset import EEGDataset
11
+ from .utils.data_utils import get_dataset_train
12
+
13
# First token of the column-header line that terminates a label file's header.
LABEL_MARKER = "EpochNo"


def fread_header_labels(inp):
    """Read the header section of a label file.

    Consumes lines up to and including the line starting with LABEL_MARKER
    (or to EOF if the marker never appears).

    Note: the removed ``global SEP_CHECKED, SEPERATOR`` declaration referenced
    names this function never read or assigned, so it was dead code.

    :param inp: path to the label file.
    :return: tuple ``(header_text, fin)`` — the concatenated header lines and
        the still-open file object positioned at the first data row.  The
        caller owns ``fin`` and must close it.
    """
    fin = open(inp, errors='ignore')

    headers = []

    while True:
        line = fin.readline()
        if line == "":  # EOF before the marker line
            break
        headers.append(line)
        if line.startswith(LABEL_MARKER):
            break

    return "".join(headers), fin
30
+
31
class EEGKFoldSampler(Sampler):
    """Placeholder for a fold-aware torch Sampler.

    NOTE(review): every method is an unimplemented stub and both ``dataset``
    and ``k_fold`` are ignored; EEGKFoldDataLoader uses RandomSampler instead.
    """

    def __init__(self, dataset, k_fold):
        super().__init__()

    def __iter__(self):
        # Stub — no iteration implemented.
        pass

    def __len__(self):
        # Stub — no length implemented.
        pass
40
+
41
+
42
class EEGKFoldDataLoader:
    """K-fold cross-validation dataloader factory over EEG dump datasets.

    Every dump file is loaded as an EEGDataset; 10% of each dataset is held
    out for the shared test set and the remaining 90% is split into
    ``n_splits`` folds.  Call ``set_fold`` before requesting the train or
    validation dataloaders.
    """

    def __init__(
            self,
            dataset_files='default',
            n_splits=5,
            n_workers=0,
            batch_size=4,
            minmax_normalized=True,
    ):
        """
        :param dataset_files: must be list of Path to dump file or "default"
        :param n_splits: number of cross-validation folds
        :param n_workers: number of DataLoader worker processes
        :param batch_size: batch size shared by train/val/test dataloaders
        :param minmax_normalized: forwarded to EEGDataset
        """

        self.minmax_normalized = minmax_normalized

        # Populated by setup() and set_fold().
        self.val_dataset = None
        self.train_dataset = None
        self.test_dataset = None
        self.datasets = None
        self.all_splits = None
        self.train_val_datasets = None
        self.current_fold = None
        self.current_epoch = None

        if isinstance(dataset_files, list):
            self.dataset_files = dataset_files
        else:
            # Any non-list value means "default": fetch the bundled dump
            # files if they are not already present.
            self.prepare_default_data()
            self.dataset_files = DUMP_DATA_FILES['train']

        self.n_splits = n_splits
        self.n_workers = n_workers
        self.batch_size = batch_size
        # Fixed seed keeps the 90/10 train-val/test partition reproducible.
        self.split_generator = torch.Generator().manual_seed(params.RD_SEED)
        # Reseeded per (fold, epoch) in set_epoch() for deterministic shuffling.
        self.dataloader_generator = torch.Generator().manual_seed(0)
        self.setup()

    @staticmethod
    def prepare_default_data():
        """Download the default dump files when the dump directory is incomplete."""
        os.makedirs(DUMP_DATA_DIR, exist_ok=True)
        # NOTE(review): os.walk visits every directory level, so the count
        # check assumes DUMP_DATA_DIR is flat — confirm.
        for _, _, files in os.walk(DUMP_DATA_DIR):
            if len(files) != len(DUMP_DATA_FILES['train']):
                get_dataset_train(remote_type='dump')

    def setup(self):
        """Load all dump files and pre-compute the K-fold index splits."""
        print("Loading data...")
        datasets = []
        train_val_dts, test_dts = [], []
        # Fix: iterate the configured files.  The original always indexed
        # DUMP_DATA_FILES['train'][i], silently ignoring a user-supplied list
        # despite the documented contract.
        for dump_file in self.dataset_files:
            print("Loading dump file {}".format(dump_file))
            i_dataset = EEGDataset(dump_file, w_out=params.W_OUT, minmax_normalized=self.minmax_normalized)
            datasets.append(i_dataset)
            # Hold out 10% of every dataset for the final test set.
            train_val_dt, test_dt = random_split(i_dataset, [0.9, 0.1], generator=self.split_generator)
            train_val_dts.append(train_val_dt)
            test_dts.append(test_dt)

        self.datasets = datasets
        self.train_val_datasets = train_val_dts
        self.test_dataset = ConcatDataset(test_dts)
        kf = KFold(n_splits=self.n_splits, shuffle=True, random_state=params.RD_SEED)
        # all_splits[i][k] -> (train_indices, val_indices) of dataset i, fold k.
        self.all_splits = [list(kf.split(train_val_dt)) for train_val_dt in train_val_dts]

    def set_fold(self, k):
        """Select fold ``k`` and build its train/val ConcatDatasets."""
        if k is None or k < 0 or k >= self.n_splits:
            print("Fold value is invalid. Set to default value: 0")
            self.current_fold = 0
        else:
            self.current_fold = k
        # Fix: index with the validated fold.  The original used the raw
        # argument (``splits[k]``), which raised TypeError/IndexError for the
        # very values the guard above claims to handle.
        k = self.current_fold

        train_dts, val_dts = [], []
        for i, (dataset, splits) in enumerate(zip(self.datasets, self.all_splits)):
            train_ids, val_ids = splits[k]
            # Fold indices are relative to the 90% train/val subset; map them
            # back to absolute positions in the full dataset.
            train_subset_ids = np.array(self.train_val_datasets[i].indices)[train_ids]
            val_subset_ids = np.array(self.train_val_datasets[i].indices)[val_ids]
            train_dts.append(Subset(dataset, train_subset_ids))
            val_dts.append(Subset(dataset, val_subset_ids))

        self.train_dataset = ConcatDataset(train_dts)
        self.val_dataset = ConcatDataset(val_dts)

    def set_epoch(self, epoch):
        """Reseed the shuffling generator deterministically per (fold, epoch)."""
        if self.current_fold is None:
            print("Fold attribute is None. Set to default value: 0")
            self.current_fold = 0
        self.current_epoch = epoch
        dataloader_generator_seed = (self.current_epoch + 1) * (self.current_fold + 1)
        self.dataloader_generator = self.dataloader_generator.manual_seed(dataloader_generator_seed)

    def train_dataloader(self, epoch):
        """Training DataLoader with epoch-deterministic shuffling."""
        self.set_epoch(epoch)
        return DataLoader(
            dataset=self.train_dataset,
            batch_size=self.batch_size,
            shuffle=False,  # shuffling is delegated to the seeded sampler
            sampler=RandomSampler(self.train_dataset, generator=self.dataloader_generator),
            num_workers=self.n_workers,
            drop_last=True
        )

    def val_dataloader(self):
        """Validation DataLoader for the currently selected fold."""
        return DataLoader(
            dataset=self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.n_workers,
            drop_last=True,
        )

    def test_dataloader(self):
        """Test DataLoader over the held-out 10% of every dataset."""
        return DataLoader(
            dataset=self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.n_workers,
            drop_last=True,
        )
170
+
171
+
172
# Manual smoke test: load the default datasets and print training batches.
if __name__ == '__main__':
    dataloader = EEGKFoldDataLoader()
    dataloader.set_fold(0)
    train_dataloader = dataloader.train_dataloader(epoch=0)
    for i, data in enumerate(train_dataloader):
        print(data)