pytour 3.0.0.dev4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2018 Jin Dou
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
@@ -0,0 +1,26 @@
+ Metadata-Version: 2.4
+ Name: pytour
+ Version: 3.0.0.dev4
+ Home-page: https://github.com/powerfulbean/pytour
+ Author: Powerfulbean
+ Author-email: powerfulbean@gmail.com
+ Classifier: Programming Language :: Python :: 3
+ Classifier: License :: OSI Approved :: MIT License
+ Classifier: Operating System :: OS Independent
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: mne
+ Requires-Dist: numpy
+ Requires-Dist: scipy
+ Requires-Dist: matplotlib
+ Dynamic: author
+ Dynamic: author-email
+ Dynamic: classifier
+ Dynamic: description
+ Dynamic: description-content-type
+ Dynamic: home-page
+ Dynamic: license-file
+ Dynamic: requires-dist
+
+ # tour
+ A framework for accelerating the implementation of stimulus-response research code in cognitive science and neuroscience
@@ -0,0 +1,2 @@
+ # tour
+ A framework for accelerating the implementation of stimulus-response research code in cognitive science and neuroscience
@@ -0,0 +1,26 @@
+ Metadata-Version: 2.4
+ Name: pytour
+ Version: 3.0.0.dev4
+ Home-page: https://github.com/powerfulbean/pytour
+ Author: Powerfulbean
+ Author-email: powerfulbean@gmail.com
+ Classifier: Programming Language :: Python :: 3
+ Classifier: License :: OSI Approved :: MIT License
+ Classifier: Operating System :: OS Independent
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: mne
+ Requires-Dist: numpy
+ Requires-Dist: scipy
+ Requires-Dist: matplotlib
+ Dynamic: author
+ Dynamic: author-email
+ Dynamic: classifier
+ Dynamic: description
+ Dynamic: description-content-type
+ Dynamic: home-page
+ Dynamic: license-file
+ Dynamic: requires-dist
+
+ # tour
+ A framework for accelerating the implementation of stimulus-response research code in cognitive science and neuroscience
@@ -0,0 +1,16 @@
+ LICENSE
+ README.md
+ setup.py
+ pytour.egg-info/PKG-INFO
+ pytour.egg-info/SOURCES.txt
+ pytour.egg-info/dependency_links.txt
+ pytour.egg-info/requires.txt
+ pytour.egg-info/top_level.txt
+ tests/test_dataset.py
+ tour/__init__.py
+ tour/artifacts_removal.py
+ tour/package_manage.py
+ tour/vis.py
+ tour/dataclass/__init__.py
+ tour/dataclass/dataset.py
+ tour/dataclass/io.py
@@ -0,0 +1,4 @@
+ mne
+ numpy
+ scipy
+ matplotlib
@@ -0,0 +1 @@
+ tour
@@ -0,0 +1,4 @@
+ [egg_info]
+ tag_build =
+ tag_date = 0
+
@@ -0,0 +1,40 @@
+ # -*- coding: utf-8 -*-
+ """
+ Created on Sat Aug 31 00:52:16 2019
+
+ @author: Jin Dou
+ """
+
+ import setuptools
+ import re
+
+ with open("./README.md", "r") as fh:
+     long_description = fh.read()
+
+ with open("tour/__init__.py") as file:
+     for line in file.readlines():
+         m = re.match("__version__ *= *['\"](.*)['\"]", line)
+         if m:
+             version = m.group(1)
+
+ setuptools.setup(
+     name="pytour",
+     version=version,
+     author="Powerfulbean",
+     author_email="powerfulbean@gmail.com",
+     long_description=long_description,
+     long_description_content_type="text/markdown",
+     url="https://github.com/powerfulbean/pytour",
+     packages=setuptools.find_packages(),
+     classifiers=[
+         "Programming Language :: Python :: 3",
+         "License :: OSI Approved :: MIT License",
+         "Operating System :: OS Independent",
+     ],
+     install_requires=[
+         "mne",
+         "numpy",
+         "scipy",
+         "matplotlib",
+     ],
+ )
@@ -0,0 +1,36 @@
+ import os
+ import mne
+ import h5py
+ import json
+ import numpy as np
+ from collections import OrderedDict
+ from StellarInfra import siIO
+ from tour.dataclass.io import (
+     mne_montage_to_h5py_group,
+     mne_montage_from_h5py_group
+ )
+
+ current_folder = os.path.dirname(os.path.abspath(__file__))
+
+ def test_save_mne_montage():
+     output_fd = os.environ['box_root']
+     montage = mne.channels.make_standard_montage('biosemi128')
+     fig = montage.plot(show = False)
+     fig.savefig(f"{current_folder}/target_montage.png")
+     pos_dict = montage.get_positions()
+     with h5py.File(f"{output_fd}/Collab-Project/CompiledDataset/biosemi128_montage.h5", "w") as f:
+         mne_montage_to_h5py_group(montage, f)
+
+ def test_load_montage_in_mne():
+     output_fd = os.environ['box_root']
+     with h5py.File(f"{output_fd}/Collab-Project/CompiledDataset/biosemi128_montage.h5", "r") as f:
+         montage = mne_montage_from_h5py_group(f)
+     fig = montage.plot(show = False)
+     fig.savefig(f"{current_folder}/loaded_montage.png")
+
+
+ data_path = f"{os.environ['box_root']}/Collab-Project/CompiledDataset/ns.pkl"
+ dataset = siIO.loadObject(data_path)
+ print(dataset)
+ # test_save_mne_montage()
+ # test_load_montage_in_mne()
@@ -0,0 +1 @@
+ __version__ = "3.0.0.dev4"
@@ -0,0 +1,122 @@
+ # -*- coding: utf-8 -*-
+ """
+ Created on Thu Jul 8 14:41:55 2021
+
+ @author: Jin Dou
+ """
+ import mne
+ import numpy as np
+ from scipy.stats import zscore
+
+ def mneWrap_lalorlab_detect_EEG_badChannels(mneraw:mne.io.RawArray, montage = None, nNearest = 10):
+     oRaw = mneraw.copy()
+     data = oRaw.get_data()
+     if montage is None:
+         badChansIdx = lalorlab_detect_EEG_badChannels(data,False)
+     else:
+         badChansIdx = lalorlab_detect_EEG_badChannels_covVarNear(data,montage, nNearest = nNearest)
+     oRaw.info['bads'] = [oRaw.info['ch_names'][i] for i in badChansIdx]
+     print(f'bad channels: {",".join(oRaw.info["bads"])}')
+     return oRaw
+
+
+ def lalorlab_detect_EEG_badChannels(eegarray,verbose = True):
+     '''
+     Detect bad EEG channels; the first dimension is assumed to be the channel dimension.
+
+     Parameters
+     ----------
+     eegarray : array-like, shape (nChannels, nSamples)
+         EEG data.
+
+     Returns
+     -------
+     badChansIdx : list of int
+         Indices of the detected bad channels.
+     '''
+     eegarray = np.array(eegarray)
+     assert len(eegarray.shape) == 2
+     stdChans = list()
+     badChansIdx = list()
+     for chan in eegarray:
+         stdChans.append(np.std(chan))
+
+     for idx,chan in enumerate(eegarray):
+         if np.std(chan) > 2.5 * np.mean(stdChans):
+             badChansIdx.append(idx)
+
+     stdChans.clear()
+
+     for idx,chan in enumerate(eegarray):
+         if idx not in badChansIdx:
+             stdChans.append(np.std(chan))
+
+     for idx,chan in enumerate(eegarray):
+         if np.std(chan) < np.mean(stdChans) / 2.5:
+             badChansIdx.append(idx)
+
+     if verbose:
+         print(badChansIdx)
+
+     return badChansIdx
+
+ def lalorlab_detect_EEG_badChannels_covVarNear(data, montage, th1 = 2, th2 = 2, nNearest = 6):
+     # data: (nChan, nSamples)
+     data = np.array(data)
+     assert data.ndim == 2
+
+     ### prepare the nearest channels
+     if nNearest > 0:
+         chanloc = montage.get_positions()['ch_pos']
+         chnames = []
+         poses = []
+         for n,pos in chanloc.items():
+             chnames.append(n)
+             poses.append(pos)
+
+         assert data.shape[0] == len(chnames)
+         chanDistMat = np.zeros((len(chanloc), len(chanloc)))
+
+         fDist = lambda pos1,pos2: np.sqrt(np.sum((pos1 - pos2)**2))
+
+         for i in range(len(chnames)):
+             for j in range(len(chnames)):
+                 chanDistMat[i,j] = fDist(poses[i], poses[j])
+
+         nearChanIdx = []
+         for i in range(len(chnames)):
+             nearChanIdx.append(np.argsort(chanDistMat[i])[1:nNearest+1])
+     else:
+         nearChanIdx = [None] * data.shape[1]
+     ### find the bad channels
+     dataz = zscore(data, axis = 1)
+     XTX = np.matmul(dataz ,dataz.T)
+     stdXTX = np.std(XTX, axis = 1)
+     stdEEG = np.std(data, axis = 1)
+
+     badChans = []
+     if nNearest <= 0:
+         badChans.append(np.where(stdXTX < np.mean(stdXTX) / th1))
+         badChans.append(np.where(stdEEG > np.mean(stdEEG) * th2))
+     else:
+         for chanIdx in range(data.shape[0]):
+             # print(stdXTX[chanIdx],
+             #       stdEEG[chanIdx],
+             #       np.mean(stdXTX[nearChanIdx[chanIdx]]) / th1,
+             #       np.mean(stdEEG[nearChanIdx[chanIdx]]) * th2)
+             if stdXTX[chanIdx] < np.mean(stdXTX[nearChanIdx[chanIdx]]) / th1:
+                 badChans.append(chanIdx)
+             if stdEEG[chanIdx] > np.mean(stdEEG[nearChanIdx[chanIdx]]) * th2:
+                 badChans.append(chanIdx)
+
+     return list(set(badChans))
+
+ def plotChanWithNamesAtIdx(montage, idxs):
+     chnames = montage.ch_names
+     montage.plot(show_names = [chnames[idx] for idx in idxs])
+
+
+ # def
+
+
+
File without changes
@@ -0,0 +1,404 @@
+ # -*- coding: utf-8 -*-
+ """
+ Created on Thu Jan 16 12:44:20 2025
+
+ @author: jdou3
+ """
+ import re
+ import io
+ import copy
+ import h5py
+ import json
+ import itertools
+ # import h5py
+ from typing import List
+ import numpy as np
+
+ from .io import (
+     data_record_from_h5py_group, data_record_to_h5py_group
+ )
+
+ #Note: please don't change the order, it matters for some functions using it
+ META_INFO_FORCED_FIELD = ['dataset_name', 'subj_id', 'trial_id']
+
+
+ def flatten_list_of_lists(list_of_lists:List[List]):
+     return list(itertools.chain.from_iterable(list_of_lists))
+
+ def k_folds(n_trials, n_folds):
+     id_trials = np.arange(n_trials)
+     splits = np.array_split(id_trials, n_folds)
+     for split_idx in range(len(splits)):
+         # print('cv fold', split_idx)
+         idx_val = splits[split_idx]
+         idx_train = np.concatenate(splits[:split_idx] + splits[split_idx + 1 :])
+         yield idx_train, idx_val
+
+ def _validate_stimuli_dict(stimuli_dict:dict):
+     for k in stimuli_dict:
+         stim:dict = stimuli_dict[k]
+         if not isinstance(stim, dict):
+             raise ValueError(f'value for stim {k} should be a dict')
+         for feat_k, feat_v in stim.items():
+             if isinstance(feat_v, dict):
+                 assert all([s in feat_v for s in ['x', 'timeinfo', 'tag']])
+             else:
+                 pass
+                 # pattern = r"_fs\d+$"
+                 # assert re.search(pattern, feat_k)
+     return stimuli_dict
+
+ def _validate_meta_info(info:dict):
+     assert all([k in info for k in META_INFO_FORCED_FIELD])
+     for k, v in info.items():
+         if isinstance(v, np.integer):
+             info[k] = int(v)
+     for k,v in info.items():
+         assert isinstance(v, (str, int, float, np.ndarray)), f"{k},{type(v)}"
+     return info
+
+ def align_data(*arrs):
+     arrs = list(arrs)
+     #assume arr have shape [nChannel, nSamples]
+     minLen = min([arr.shape[1] for arr in arrs])
+     for i, arr in enumerate(arrs):
+         arrs[i] = arr[:, :minLen]
+     return arrs
+
+ def i_split_kfold(tarList,cur_fold,n_folds, add_dev = True):
+     ''' curFold starts from zero '''
+     kfList = [i for i in k_folds(len(tarList), n_folds)]
+     curTrainDevIdx = kfList[cur_fold][0]
+     curTestIdx = kfList[cur_fold][1]
+     curDevIdx = kfList[(cur_fold + 1) % n_folds][1]
+     if add_dev:
+         curTrainIdx = [i for i in curTrainDevIdx if i not in curDevIdx]
+         curTrain = [tarList[i] for i in curTrainIdx]
+         curDev = [tarList[i] for i in curDevIdx]
+         curTest = [tarList[i] for i in curTestIdx]
+         return curTrain, curDev, curTest
+     else:
+         curTrainDev = [tarList[i] for i in curTrainDevIdx]
+         curTest = [tarList[i] for i in curTestIdx]
+         return curTrainDev, [], curTest
+
+ def k_fold(dataset:'Dataset', cur_fold, n_folds, split_by = 'trial_id', add_dev = True, if_shuffle = False, seed = 42):
+     info_sets = sorted(
+         list(set(
+             [i.meta_info[split_by] for i in dataset.records]
+         )))
+     if if_shuffle:
+         rng = np.random.default_rng(seed)
+         inf_idxs = np.arange(len(info_sets))
+         rng.shuffle(inf_idxs)
+         info_sets = [info_sets[idx_] for idx_ in inf_idxs]
+     info_train_list, info_dev_list, info_test_list = i_split_kfold(
+         info_sets, cur_fold, n_folds, add_dev)
+     print(info_train_list, info_dev_list, info_test_list)
+     output = {}
+     output['train'] = dataset.subset_by_info({split_by:info_train_list})
+     if len(info_dev_list) > 0:
+         output['dev'] = dataset.subset_by_info({split_by:info_dev_list})
+     output['test'] = dataset.subset_by_info({split_by:info_test_list})
+     return output
+
+
+ def dump_dict_contains_nparray(state_dict):
+     output = {}
+     for key, value in state_dict.items():
+         if isinstance(value, np.ndarray):
+             buffer = io.BytesIO()
+             np.save(buffer, value)
+             t_value = buffer.getvalue()
+         elif isinstance(value, dict):
+             # print(key)
+             t_value = dump_dict_contains_nparray(value)
+         else:
+             t_value = value
+         output[key] = t_value
+     return output
+
+ def load_dict_contains_nparray(state_dict):
+     new_state = {}
+     for k,v in state_dict.items():
+         if isinstance(v, bytes):
+             buffer = io.BytesIO(v)
+             new_state[k] = np.load(buffer)
+             # new_state[k] = np.frombuffer(v)
+         elif isinstance(v, dict):
+             # print(k)
+             new_state[k] = load_dict_contains_nparray(v)
+         else:
+             new_state[k] = v
+     return new_state
+
+ class DataRecord:
+
+     def __init__(self, data, stim_id, meta_info:dict, srate:int):
+         self.srate = srate
+         self.data = data
+         self.stim_id = stim_id
+         self.meta_info = _validate_meta_info(meta_info)
+
+     def dump_to_dict(self):
+         return dump_dict_contains_nparray(self.__dict__)
+
+     def dump(self):
+         record_key = "-".join(
+             [str(self.meta_info[k]) for k in META_INFO_FORCED_FIELD]
+         )
+         return dict(
+             key = record_key,
+             data = self.data,
+             stim_id = self.stim_id,
+             meta_info = self.meta_info,
+             srate = self.srate,
+         )
+
+     @classmethod
+     def load(cls, new_state:dict):
+         obj = cls(**new_state)
+         return obj
+
+     @classmethod
+     def load_from_dict(cls, state:dict):
+         new_state = load_dict_contains_nparray(state)
+         obj = cls(**new_state)
+         # for key in state:
+         #     obj.__dict__[key] = state[key]
+         return obj
+
+     def copy(self):
+         new = DataRecord(
+             self.data.copy(),
+             self.stim_id,
+             copy.deepcopy(self.meta_info),
+             self.srate
+         )
+         return new
+
+ class Dataset:
+
+     # data and stim have the shape (nChannels, nSamples)
+     # stim_id_cond: used when stimuli contains multiple conditions
+
+     def __init__(self, name:str, srate:int):
+         self.name = name
+         self.srate = srate
+         self.stim_feat_filter:list = []
+         self.resp_chan_filter:list = []
+         self.stim_id_cond:str|None = None
+         self.meta_info_filter:dict = {}
+         self._stimuli_dict:dict = {}
+         self._records:List[DataRecord] = []
+         self._preprocess_config = {}
+
+     def copy(self):
+         new_dataset = Dataset(
+             self.name,
+             self.srate
+         )
+         new_dataset.stim_feat_filter = copy.deepcopy(self.stim_feat_filter)
+         new_dataset.resp_chan_filter = copy.deepcopy(self.resp_chan_filter)
+         new_dataset.stim_id_cond = copy.deepcopy(self.stim_id_cond)
+         new_dataset.meta_info_filter = copy.deepcopy(self.meta_info_filter)
+         new_dataset._stimuli_dict = self._stimuli_dict
+         new_dataset._records = [r_.copy() for r_ in self._records]
+         return new_dataset
+
+     @property
+     def stimuli_dict(self):
+         return self._stimuli_dict
+
+     @stimuli_dict.setter
+     def stimuli_dict(self, x):
+         self._stimuli_dict = _validate_stimuli_dict(x)
+
+     @property
+     def records(self) -> List[DataRecord]:
+         if len(self.meta_info_filter) == 0:
+             return self._records
+         else:
+             return self._filter_records_by_info(self._records, self.meta_info_filter)
+
+     def _filter_records_by_info(self, records, meta_info_filter:dict):
+         output = list()
+         for record in records:
+             if all(
+                 [
+                     record.meta_info[k] == v if np.isscalar(v)
+                     else record.meta_info[k] in v
+                     for k,v in meta_info_filter.items()
+                 ]
+             ):
+                 output.append(record)
+         return output
+
+     def append(self, record:DataRecord):
+         assert record.srate == self.srate
+         self._records.append(record)
+
+     def _filter_stim_feat(self, stim_feat):
+         new_stim_feat = {}
+         if len(self.stim_feat_filter) == 0:
+             stim_feat_filter = stim_feat.keys()
+         else:
+             stim_feat_filter = self.stim_feat_filter
+         for i in stim_feat_filter:
+             new_stim_feat[i] = stim_feat[i]
+         return new_stim_feat
+
+     def _filter_resp_chan(self, resp):
+         if len(self.resp_chan_filter) > 0:
+             idxArr = np.array(self.resp_chan_filter)
+             output = resp[idxArr,:]
+         else:
+             output = resp
+         return output
+
+     def __getitem__(self, idx):
+         record:DataRecord = self.records[idx]
+         return self._unpack_record(record)
+
+     def _unpack_record(self, record:DataRecord):
+         stim_id, data = record.stim_id, record.data
+         if isinstance(stim_id, dict):
+             assert self.stim_id_cond is not None
+             stim_id = stim_id[self.stim_id_cond]
+         stim_feat = self._filter_stim_feat(self.stimuli_dict[stim_id])
+         data = self._filter_resp_chan(data)
+         return stim_feat, data, record.meta_info
+
+     def __len__(self):
+         return len(self.records)
+
+     def __iter__(self):
+         self.n = 0
+         return self
+
+     def __next__(self):
+         if self.n < len(self.records):
+             self.n += 1
+             return self.__getitem__(self.n-1)#self.records[self.n-1]
+         else:
+             raise StopIteration
+
+     def to_pairs(self, ifT = True):
+         allSubj = set([i.meta_info['subj_id'] for i in self.records])
+         filterKey = lambda x: x.meta_info['subj_id']
+         sortKey = lambda x : (
+             x.meta_info['dataset_name'],
+             x.meta_info['subj_id'],
+             x.meta_info['trial_id'],
+         )
+
+         records = sorted(self.records, key = sortKey)
+
+         transpose = lambda *arrs: [arr.T for arr in arrs]
+         def catstimarr(stim:dict):
+             keys = stim.keys()
+             # print(keys)
+             assert all([stim[k].shape[0] < stim[k].shape[1] for k in keys if isinstance(stim[k], np.ndarray)])
+             stim = [stim[k] for k in keys if isinstance(stim[k], np.ndarray)]
+             stim = align_data(*stim)
+             stim = np.concatenate(stim, axis = 0)
+             # print(stim.shape)
+             return stim
+
+         stims_subj = []
+         resps_subj = []
+         infoss = []
+         ks = []
+         for k, grp in itertools.groupby(records, filterKey):
+             stims, resps, infos = list(zip(*[self._unpack_record(g) for g in grp]))
+             # print(infos)
+             stims = list(map(catstimarr, stims))
+             stims, resps = list(zip(*map(align_data, stims, resps)))
+             if ifT:
+                 stims, resps = list(zip(*map(transpose, stims, resps)))
+             stims_subj.append(stims)
+             resps_subj.append(resps)
+             infoss.append(infos)
+             ks.append(k)
+
+         return stims_subj, resps_subj, ks, infoss
+
+     def to_pairs_iter(self,sortKey = None):
+         allSubj = set([i.meta_info['subj_id'] for i in self.records])
+
+         filterKey = lambda x: x.meta_info['subj_id']
+         if sortKey is None:
+             sortKey = lambda x : (
+                 x.meta_info['dataset_name'],
+                 x.meta_info['subj_id'],
+                 x.meta_info['trial_id'],
+             )
+
+         records = sorted(self.records, key = sortKey)
+
+         for k, grp in itertools.groupby(records, filterKey):
+             stims, resps, infos = list(zip(*[self._unpack_record(g) for g in grp]))
+             yield stims, resps, infos, k
+
+
+
+     def k_fold(self, cur_fold, n_folds, split_by, add_dev = True, if_shuffle = False):
+         return k_fold(self, cur_fold, n_folds, split_by, add_dev=add_dev, if_shuffle = if_shuffle)
+
+     def subset_by_info(self,meta_info_filter):
+         records = self._filter_records_by_info(
+             self._records, meta_info_filter)
+         state_dict = self.dump()
+         state_dict['_records'] = [l.dump() for l in records]
+         return self.__class__.load(state_dict)
+
+     def dump(self, file_path):
+         with h5py.File(file_path, "w") as f:
+             f.attrs["name"] = self.name
+             f.attrs["srate"] = self.srate
+             preprocess_config = json.dumps(self._preprocess_config)
+             f.attrs["preprocess_config_str"] = preprocess_config
+             for record in self._records:
+                 data_record_to_h5py_group(
+                     f = f,
+                     **record.dump(),
+                 )
+
+     @classmethod
+     def load(cls, file_path):
+         new_dataset = None
+         with h5py.File(file_path, "r") as f:
+             new_dataset = cls(
+                 name = str(f.attrs['name']),
+                 srate = int(f.attrs['srate']),
+             )
+             for k, grp in f['records'].items():
+                 record_dict = data_record_from_h5py_group(grp)
+                 new_record = DataRecord(**record_dict)
+                 new_dataset.append(new_record)
+         return new_dataset
+
+     def dump_to_dict(self):
+         output = {}
+         output['_records'] = [l.dump() for l in self._records]
+         for k,v in self.__dict__.items():
+             if k != '_records':
+                 if isinstance(v, dict):
+                     output[k] = dump_dict_contains_nparray(v)
+                 else:
+                     output[k] = v
+         return output
+
+     @classmethod
+     def load_from_dict(cls, state):
+         output = cls(name = state['name'], srate = state['srate'])
+         for k,v in state.items():
+             if k == '_records':
+                 output.__dict__['_records'] = [DataRecord.load(l) for l in state[k]]
+             else:
+                 if isinstance(v, dict):
+                     output.__dict__[k] = load_dict_contains_nparray(v)
+                 else:
+                     output.__dict__[k] = state[k]
+         return output
+
@@ -0,0 +1,101 @@
+ import os
+ import mne
+ import h5py
+ import json
+ import numpy as np
+ from collections import OrderedDict
+ from typing import Union
+ """
+ mne montage data class related
+ """
+
+ def mne_montage_to_h5py_group(montage:mne.channels.DigMontage, f:h5py.File):
+     montage_grp = f.require_group('montage')
+     pos_dict = montage.get_positions()
+     for k,v in pos_dict.items():
+         # print(k)
+         if k == 'ch_pos':
+             chs, ch_coords = list(zip(
+                 *[
+                     (ch, ch_coord)
+                     for ch, ch_coord in v.items()
+                 ]))
+             ch_coords = np.stack(ch_coords)
+             chs_json_str = json.dumps(chs)
+             # print(chs_json_str)
+             t_ds = montage_grp.create_dataset(k, data = ch_coords)
+             t_ds.attrs['chs_json_str'] = chs_json_str
+         elif k == 'coord_frame':
+             montage_grp.attrs['coord_frame'] = v
+         else:
+             if v is None:
+                 v = np.array([])
+             montage_grp.create_dataset(k, data = v)
+     return f
+
+ def mne_montage_from_h5py_group(f:h5py.File):
+     pos_dict = {}
+     montage_grp = f['montage']
+     for k, v in montage_grp.items():
+         if k == 'ch_pos':
+             t_dict = OrderedDict()
+             t_ds = montage_grp[k]
+             ch_coords = t_ds[:]
+             chs = json.loads(t_ds.attrs['chs_json_str'])
+             for i_ch, ch in enumerate(chs):
+                 t_dict[ch] = ch_coords[i_ch]
+             pos_dict[k] = t_dict
+         else:
+             # print(v.shape)
+             if v.shape == (0,):
+                 pos_dict[k] = None
+             else:
+                 pos_dict[k] = v[:]
+     pos_dict['coord_frame'] = montage_grp.attrs['coord_frame']
+     montage = mne.channels.make_dig_montage(**pos_dict)
+     return montage
+
+
+ """
+ tour DataRecord class related
+ """
+ def data_record_to_h5py_group(
+     key: str,
+     data: np.ndarray,
+     stim_id: Union[str, int],
+     meta_info:dict,
+     srate: int,
+     f:h5py.File
+ ):
+     root_grp = f.require_group(f'records/{key}')
+     root_grp.create_dataset('data', data = data)
+     root_grp.attrs['stim_id'] = stim_id
+     root_grp.attrs['srate'] = srate
+
+     meta_info_grp = root_grp.require_group('meta_info')
+     for k,v in meta_info.items():
+         if isinstance(v, np.ndarray):
+             meta_info_grp.create_dataset(k, data=v)
+         else:
+             meta_info_grp.attrs[k] = v
+
+     return f
+
+ def data_record_from_h5py_group(
+     f:h5py.File
+ ):
+     data = f['data'][:]
+     stim_id = f.attrs['stim_id']
+     srate = int(f.attrs['srate'])
+
+     meta_info_grp = f['meta_info']
+     meta_info = {}
+     for k,v in meta_info_grp.attrs.items():
+         meta_info[k] = v
+
+     for k,v in meta_info_grp.items():
+         meta_info[k] = v[:]
+
+     return dict(
+         data = data, stim_id = stim_id, meta_info = meta_info, srate = srate
+     )
@@ -0,0 +1,13 @@
+ # -*- coding: utf-8 -*-
+ """
+ Created on Thu Feb 6 08:33:37 2025
+
+ @author: ShiningStone
+ """
+
+
+ def check_import(package, name):
+     if package is None:
+         raise ImportError(f'failed to import {name}')
+     return package
+
@@ -0,0 +1,187 @@
+ import mne
+ import numpy as np
+ from matplotlib import pyplot as plt
+ from matplotlib import gridspec
+
+
+ def plot_data(
+     data,
+     fs = 64,
+     times = None,
+     title = '',
+     chan_idx = None,
+     time_intvl = None,
+     units = 'a.u.',
+     montage = None,
+     mode = 'joint',
+     tmin = 0,
+
+     **kwargs
+ ):
+     data = np.array(data)
+     ifAx = False
+     if montage is None:
+         montage = mne.channels.make_standard_montage('biosemi128')
+         chnames_map = dict(
+             C17 = 'Fpz',
+             # C21 = 'Fz',
+             # A1 = 'Cz',
+             D23 = 'T7',
+             B26 = 'T8',
+             # A19 = 'Pz',
+             A23 = 'Oz'
+         )
+
+         for k,v in chnames_map.items():
+             montage.ch_names[montage.ch_names.index(k)] = v
+
+     chNames = montage.ch_names
+
+     # print(chNames)
+     try:
+         info = mne.create_info(chNames, fs,ch_types = 'eeg', montage = montage)
+     except:
+         info = mne.create_info(chNames, fs,ch_types = 'eeg')
+         info.set_montage(montage = montage)
+
+     kwargs['sensors'] = False if 'sensors' not in kwargs else kwargs['sensors']
+     kwargs['res'] = 256 if 'res' not in kwargs else kwargs['res']
+     kwargs['outlines'] = 'head' if 'outlines' not in kwargs else kwargs['outlines']
+     show_names = kwargs.get('show_names', False)
+     if 'show_names' in kwargs:
+         del kwargs['show_names']
+     if show_names:
+         names = montage.ch_names
+     else:
+         names = None
+     ts_args = kwargs.get('ts_args', None)
+
+     if time_intvl is not None:
+         time_intvl = np.array(time_intvl)
+         if time_intvl.ndim == 1:
+             time_intvl = time_intvl[None,...]
+         # print(time_intvl)
+         # calculate the time points of interest and the averaging window
+         average_window = time_intvl[:,1] - time_intvl[:,0]
+         timepoint = time_intvl.mean(1)
+     else:
+         average_window = None
+
+     chanMask = None
+     if chan_idx is not None:
+         chanMask = np.zeros(data.shape,dtype = bool)
+         for i in chan_idx:
+             chanMask[i] = True
+
+     kwargs['cmap'] = plt.get_cmap("bwr") if 'cmap' not in kwargs else kwargs['cmap']
+     kwargs['show'] = False if 'show' not in kwargs else kwargs['show']
+
+     maskParam = dict(
+         marker='o',
+         markerfacecolor='w',
+         markeredgecolor='k',
+         linewidth=0,
+         markersize=8
+     )
+
+     maskParam2_default = dict(
+         marker='o',
+         markerfacecolor='w',
+         markeredgecolor='k',
+         linewidth=0,
+         markersize=4
+     )
+     maskParam2 = kwargs.get('maskParam', maskParam2_default)
+     if 'maskParam' in kwargs:
+         del kwargs['maskParam']
+
+     if data.ndim == 2:
+
+         default_ts_args={
+             "units": units,
+             "scalings": dict(eeg=1),
+             "highlight": time_intvl,
+         }
+
+         if ts_args is not None:
+             default_ts_args.update(ts_args)
+
+
+         mneW = mne.EvokedArray(data,info, tmin = tmin)
+         if montage is not None:
+             mneW.set_montage(montage)
+
+         if times is None:
+             if time_intvl is None:
+                 if mode == 'joint':
+                     times = 'peaks'
+                 else:
+                     times = 'auto'
+             else:
+                 times = timepoint
+
+         if mode == 'joint':
+             print(default_ts_args)
+             fig = mneW.plot_joint(
+                 times = times,
+                 topomap_args=dict(
+                     scalings = 1,
+                     mask = chanMask,
+                     mask_params= maskParam,
+                     average = average_window
+                 ),
+                 ts_args = default_ts_args,
+                 show = kwargs.get('show', True),
+                 title = title
+             )
+         else:
+             fig = mneW.plot_topomap(
+                 times = times,
+                 time_unit='s',
+                 scalings = 1,
+                 title = title,
+                 units = units,
+                 cbar_fmt='%3.3f',
+                 mask = chanMask,
+                 mask_params= maskParam2,
+                 colorbar=False,
+                 # names = None,
+                 **kwargs
+             )
+
+
+     elif data.ndim == 1:
+         if 'ax' in kwargs:
+             ax1 = kwargs['ax']
+             del kwargs['ax']
+             ifAx = True
+         else:
+             fig = plt.figure(tight_layout=True)
+             gridspec_kw={'width_ratios': [19, 1]}
+             gs = gridspec.GridSpec(4, 2,**gridspec_kw)
+             ax1 = fig.add_subplot(gs[:, 0])
+             ax2 = fig.add_subplot(gs[1:3, 1])
+         # print('contours')
+         im,cm = mne.viz.plot_topomap(
+             data.squeeze(),
+             info,
+             axes = ax1,
+             mask = chanMask,
+             names = names,
+             mask_params= maskParam2,
+             sphere = 'eeglab',
+             contours = 2,
+             **kwargs
+         )
+         # cbar_ax = fig.add_axes([ax_x_start, ax_y_start, ax_x_width, ax_y_height])
+         if not ifAx:
+             clb = fig.colorbar(im, cax=ax2)
+             clb.ax.set_title(units,fontsize=10) # title on top of colorbar
+             fig.suptitle(title)
+     else:
+         raise NotImplementedError
+
+     if not ifAx:
+         return fig
+     else:
+         return im