pytour 3.0.0__py3-none-any.whl

This diff shows the contents of publicly available package versions as they appear in their respective public registries, and is provided for informational purposes only.
tour/dataclass/io.py ADDED
@@ -0,0 +1,225 @@
+ import mne
+ import h5py
+ import json
+ import numpy as np
+ from collections import OrderedDict
+ from typing import Union, List
+ 
+ """
+ mne montage data class related
+ """
+ 
+ def _validate_stimuli_dict(stimuli_dict: dict):
+     for k in stimuli_dict:
+         stim: dict = stimuli_dict[k]
+         if not isinstance(stim, dict):
+             raise ValueError(f'value for stim {k} should be a dict')
+         for feat_k, feat_v in stim.items():
+             if isinstance(feat_v, dict):
+                 # a discrete feature must provide values, timing and tags
+                 assert all(s in feat_v for s in ['x', 'timeinfo', 'tag'])
+     return stimuli_dict
+ 
+ def mne_montage_to_h5py_group(montage: mne.channels.DigMontage, f: h5py.File):
+     montage_grp = f.require_group('montage')
+     pos_dict = montage.get_positions()
+     for k, v in pos_dict.items():
+         if k == 'ch_pos':
+             # store the coordinates as one (n_channels, 3) dataset and keep
+             # the channel names as a JSON string attribute
+             chs = list(v.keys())
+             ch_coords = np.stack(list(v.values()))
+             t_ds = montage_grp.create_dataset(k, data=ch_coords)
+             t_ds.attrs['chs_json_str'] = json.dumps(chs)
+         elif k == 'coord_frame':
+             montage_grp.attrs['coord_frame'] = v
+         else:
+             # fiducials etc. may be None; an empty array is the sentinel
+             if v is None:
+                 v = np.array([])
+             montage_grp.create_dataset(k, data=v)
+     return f
+ 
+ def mne_montage_from_h5py_group(f: h5py.File):
+     pos_dict = {}
+     montage_grp = f['montage']
+     for k, v in montage_grp.items():
+         if k == 'ch_pos':
+             t_dict = OrderedDict()
+             ch_coords = v[:]
+             chs = json.loads(v.attrs['chs_json_str'])
+             for i_ch, ch in enumerate(chs):
+                 t_dict[ch] = ch_coords[i_ch]
+             pos_dict[k] = t_dict
+         else:
+             # an empty array marks a position that was None when saved
+             pos_dict[k] = None if v.shape == (0,) else v[:]
+     pos_dict['coord_frame'] = montage_grp.attrs['coord_frame']
+     montage = mne.channels.make_dig_montage(**pos_dict)
+     return montage
+ 
+ 
+ """
+ DataRecord class related
+ """
+ def data_record_to_h5py_group(
+     key: str,
+     data: np.ndarray,
+     stim_id: Union[str, int],
+     meta_info: dict,
+     srate: int,
+     f: h5py.File
+ ):
+     root_grp = f.require_group(f'records/{key}')
+     root_grp.create_dataset('data', data=data)
+     root_grp.attrs['stim_id'] = stim_id
+     root_grp.attrs['srate'] = srate
+ 
+     meta_info_grp = root_grp.require_group('meta_info')
+     for k, v in meta_info.items():
+         # arrays become datasets, everything else is stored as an attribute
+         if isinstance(v, np.ndarray):
+             meta_info_grp.create_dataset(k, data=v)
+         else:
+             meta_info_grp.attrs[k] = v
+ 
+     return f
+ 
+ def data_record_from_h5py_group(
+     f: h5py.File
+ ):
+     data = f['data'][:]
+     stim_id = f.attrs['stim_id']
+     srate = int(f.attrs['srate'])
+ 
+     meta_info_grp = f['meta_info']
+     meta_info = dict(meta_info_grp.attrs)
+     # also recover the values that were saved as datasets rather than attributes
+     for k, v in meta_info_grp.items():
+         meta_info[k] = v[:]
+ 
+     return dict(
+         data=data, stim_id=stim_id, meta_info=meta_info, srate=srate
+     )
+ 
+ """
+ Stim Dict related
+ """
+ 
+ def check_list_of_string(data) -> bool:
+     # a predicate rather than an assertion, so it can be used as a branch condition
+     return isinstance(data, list) and all(isinstance(i, str) for i in data)
+ 
+ def stim_dict_to_hdf5(
+     filename: str,
+     stim_dict: dict,
+     attrs: dict = None,
+ ):
+     with h5py.File(filename, 'a') as hdf5f:
+         _validate_stimuli_dict(stim_dict)
+         for stim_id in stim_dict:
+             grp = hdf5f.require_group(stim_id)
+             for feat_name in stim_dict[stim_id]:
+                 assert feat_name not in grp
+                 data = stim_dict[stim_id][feat_name]
+                 if isinstance(data, np.ndarray):
+                     dataset = grp.create_dataset(feat_name, data=data)
+                 elif isinstance(data, dict):
+                     dataset = discrete_stim_to_hdf5(
+                         feat_name=feat_name,
+                         feat_dict=data,
+                         hdf5f=grp
+                     )
+                 else:
+                     raise TypeError(f'unsupported feature type: {type(data)}')
+ 
+                 if attrs is not None:
+                     dataset.attrs.update(attrs[stim_id][feat_name])
+ 
+ def stim_dict_from_hdf5(
+     filename: str,
+ ) -> dict:
+     stim_dict = {}
+     with h5py.File(filename, 'r') as hdf5f:
+         for stim_id, stim_grp in hdf5f.items():
+             stim_dict[stim_id] = {}
+             for k, v in stim_grp.items():
+                 if isinstance(v, h5py.Dataset):
+                     stim_dict[stim_id][k] = v[:]
+                 elif isinstance(v, h5py.Group):
+                     stim_dict[stim_id][k] = discrete_stim_from_hdf5(v)
+                 else:
+                     raise TypeError(f'unsupported node type: {type(v)}')
+     _validate_stimuli_dict(stim_dict)
+     return stim_dict
+ 
+ def discrete_stim_to_hdf5(
+     feat_name: str,
+     feat_dict: dict,
+     hdf5f: h5py.Group
+ ) -> h5py.Group:
+     """
+     Expected feat_dict layout:
+     {
+         'x': ...,
+         'tag': ...,
+         'timeinfo': ...
+     }
+     """
+     grp = hdf5f.require_group(feat_name)
+     for k, v in feat_dict.items():
+         if isinstance(v, np.ndarray):
+             grp.create_dataset(k, data=v)
+         elif check_list_of_string(v):
+             string_list_to_hdf5(k, v, grp)
+         else:
+             raise TypeError(f'unsupported value type for {k}: {type(v)}')
+     return grp
+ 
+ def discrete_stim_from_hdf5(
+     hdf5f: h5py.Group,
+ ):
+     stim_dict = {}
+     for k, v in hdf5f.items():
+         if h5py.check_string_dtype(v.dtype):
+             # variable-length utf-8 strings are decoded back into a list
+             stim_dict[k] = string_list_from_hdf5(v)
+         else:
+             stim_dict[k] = v[:]
+     return stim_dict
+ 
+ 
+ def string_list_to_hdf5(
+     dataset_name: str,
+     strings: List[str],
+     f: h5py.Group
+ ):
+     '''
+     adapted from ChatGPT
+     '''
+     # variable-length utf-8 string dtype
+     dt = h5py.string_dtype(encoding='utf-8')
+     f.require_dataset(dataset_name, (len(strings),), dtype=dt, data=strings)
+     return f
+ 
+ def string_list_from_hdf5(
+     dt: h5py.Dataset
+ ):
+     return dt.asstr()[:].tolist()
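
A minimal round-trip sketch for the stim dict helpers above; the file name, feature names and array shapes are illustrative assumptions, not part of the package:

    import numpy as np
    from tour.dataclass.io import stim_dict_to_hdf5, stim_dict_from_hdf5

    # one continuous feature plus one discrete feature for a single stimulus
    stim_dict = {
        'stim_0': {
            'envelope_fs64': np.random.randn(1, 640),       # continuous feature
            'words': {                                       # discrete feature
                'x': np.random.randn(2, 10),
                'timeinfo': np.stack([np.arange(10.0), np.arange(10.0) + 0.5]),
                'tag': [f'word_{i}' for i in range(10)],
            },
        }
    }
    stim_dict_to_hdf5('stims.h5', stim_dict)   # file must not already contain these groups
    loaded = stim_dict_from_hdf5('stims.h5')
    assert loaded['stim_0']['words']['tag'] == stim_dict['stim_0']['words']['tag']
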
tour/dataclass/stim.py ADDED
@@ -0,0 +1,33 @@
+ from typing import Dict
+ 
+ from ..backend import is_tensor, np, torch, Array
+ 
+ 
+ def to_impulses(x: Array, timeinfo: Array, f: float, padding_s: float = 0):
+     '''
+     Align event values onto an impulse train with sampling rate f.
+ 
+     x: (n_dim, n_events) feature values, one column per event
+     timeinfo: (2, n_events) onset and offset times in seconds
+     '''
+     # x and timeinfo must come from the same backend
+     if is_tensor(x):
+         assert is_tensor(timeinfo)
+     else:
+         assert not is_tensor(timeinfo)
+     startTimes = timeinfo[0]
+     endTimes = timeinfo[1]
+     secLen = endTimes[-1] + padding_s
+     nDim = x.shape[0]
+     if is_tensor(x):
+         nLen = torch.ceil(secLen * f).long()
+         out = torch.zeros((nDim, nLen), dtype=x.dtype)
+         timeIndices = torch.round(startTimes * f).long()
+     else:
+         nLen = np.ceil(secLen * f).astype(int)
+         out = np.zeros((nDim, nLen), dtype=x.dtype)
+         timeIndices = np.round(startTimes * f).astype(int)
+     out[:, timeIndices] = x
+     return out
+ 
+ def dictTensor_to(x: Dict[str, Array], device):
+     # move tensor values to the target device, pass everything else through
+     output = {
+         k: v.to(device) if is_tensor(v) else v for k, v in x.items()
+     }
+     return output
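
A small numeric example of to_impulses on the numpy path of the backend; the values are chosen for illustration:

    import numpy as np
    from tour.dataclass.stim import to_impulses

    x = np.array([[1.0, 2.0, 3.0]])            # one feature dim, three events
    timeinfo = np.array([[0.0, 0.5, 1.25],     # onsets in seconds
                         [0.4, 0.9, 1.50]])    # offsets in seconds
    out = to_impulses(x, timeinfo, f=10)       # 10 Hz impulse train
    # out.shape == (1, 15); the event values land at samples 0, 5 and 12
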
tour/package_manage.py ADDED
@@ -0,0 +1,13 @@
+ # -*- coding: utf-8 -*-
+ """
+ Created on Thu Feb 6 08:33:37 2025
+ 
+ @author: ShiningStone
+ """
+ 
+ 
+ def check_import(package, name):
+     # raise a clear error when an optional dependency failed to import
+     if package is None:
+         raise ImportError(f'{name} failed to import')
+     return package
+ 
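
check_import pairs with the usual try/except import pattern for optional dependencies; a hypothetical caller:

    try:
        import h5py
    except ImportError:
        h5py = None

    from tour.package_manage import check_import

    def save_file(filename):
        f = check_import(h5py, 'h5py')   # raises ImportError only when h5py is actually needed
        ...
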
tour/torch_trainer.py ADDED
@@ -0,0 +1,339 @@
+ import os
+ import sys
+ import torch
+ import logging
+ import numpy as np
+ from typing import Callable, List, Union, Protocol
+ 
+ def func_reduce_mean(values):
+     # scalars are stacked, batched tensors are concatenated, then averaged
+     if values[0].ndim == 0:
+         return torch.mean(torch.stack(values), dim=0)
+     else:
+         return torch.mean(torch.cat(values), dim=0)
+ 
+ def get_logger(
+     file_dir,
+     console_level=logging.INFO,
+     file_level=logging.DEBUG,
+     file_name="logfile.log",
+     if_print=True,
+ ):
+     # adapted from ChatGPT
+     file_path = f"{file_dir}/{file_name}"
+     logger = logging.getLogger('tray/trainer')
+     logger.setLevel(logging.DEBUG)  # master level: allow all through to handlers
+     logger.handlers.clear()  # prevent duplicate handlers on re-run
+ 
+     formatter = logging.Formatter(
+         '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+         datefmt='%Y-%m-%d %H:%M:%S'
+     )
+ 
+     if if_print:
+         # Console handler
+         console_handler = logging.StreamHandler(sys.stdout)
+         console_handler.setLevel(console_level)
+         console_handler.setFormatter(formatter)
+         logger.addHandler(console_handler)
+ 
+     # File handler
+     file_handler = logging.FileHandler(file_path)
+     file_handler.setLevel(file_level)
+     file_handler.setFormatter(formatter)
+     logger.addHandler(file_handler)
+ 
+     return logger
+ 
+ class DependentModule(Protocol):
+     # anything that can serialize its state into the Context checkpoint
+     def load_state(self, state: dict) -> None: ...
+     def get_state(self) -> dict: ...
+ 
+ class BatchAccumulator:
+ 
+     def __init__(self):
+         self._data: list = []
+ 
+     def append(self, output):
+         self._data.append(output)
+ 
+     @property
+     def data(self):
+         # concatenate along the batch dimension
+         return torch.cat(self._data)
+ 
+ class MetricsRecord:
+ 
+     def __init__(self):
+         self._data = {}
+ 
+     def append(self, metricDict: dict, tag: str = ''):
+         data = self._data
+         for k in metricDict:
+             real_k = f'{tag}/{k}' if tag != '' else k
+             if real_k not in data:
+                 data[real_k] = []
+             data[real_k].append(metricDict[k].cpu())
+ 
+     def __iter__(self):
+         return iter(self._data.keys())
+ 
+     def __getitem__(self, key):
+         return self._data[key]
+ 
+     def items(self):
+         yield from self._data.items()
+ 
+ def ndarrays_to_tensors(*datas: List[np.ndarray]):
+     # the resulting tensors share memory with the arrays where possible
+     return [
+         [
+             torch.from_numpy(d) if not np.isscalar(d) else torch.tensor(d, dtype=torch.get_default_dtype())
+             for d in data
+         ]
+         for data in datas
+     ]
+ 
+ class StimRespDataset(torch.utils.data.Dataset):
+ 
+     def __init__(self,
+         stims: Union[List[np.ndarray], List[torch.Tensor]],
+         resps: Union[List[np.ndarray], List[torch.Tensor]],
+         device='cpu'
+     ):
+         if isinstance(stims[0], np.ndarray):
+             stims, resps = ndarrays_to_tensors(stims, resps)
+         assert len(stims) == len(resps)
+         self.stims = stims
+         self.resps = resps
+         self.device = device
+ 
+     def __getitem__(self, index: int):
+         return self.stims[index].to(self.device), self.resps[index].to(self.device)
+ 
+     def __len__(self):
+         return len(self.stims)
+ 
+ class Context:
+ 
+     def __init__(
+         self,
+         model: torch.nn.Module,
+         optimizer: torch.optim.Optimizer,
+         func_metrics: Callable,
+         checkpoint_folder: str,
+         checkpoint_file="checkpoint.pt",
+         custom_config=None,
+         if_print_metric=True,
+     ):
+         self.model = model
+         self.optimizer = optimizer
+         self.state_current_epoch = -1
+         self.func_metrics = func_metrics
+         self.checkpoint_folder = checkpoint_folder
+         self.checkpoint_file = checkpoint_file
+         self.metrics_log = MetricsRecord()
+         # avoid a shared mutable default argument
+         self.custom_config = {} if custom_config is None else custom_config
+         self.logger = get_logger(checkpoint_folder, if_print=if_print_metric)
+ 
+         self.dependents: List[DependentModule] = []
+ 
+     def add_dependent(self, module: DependentModule):
+         self.dependents.append(module)
+ 
+     def new_epochs(self):
+         self.state_current_epoch += 1
+ 
+     def checkpoint_exists(self):
+         return os.path.exists(self.checkpoint_path)
+ 
+     def save_checkpoint(self):
+         checkpoint = {}
+         for module in self.dependents:
+             checkpoint[module.__class__.__name__] = module.get_state()
+         checkpoint['context'] = self.get_state()
+         torch.save(checkpoint, self.checkpoint_path)
+ 
+     def load_checkpoint(self):
+         checkpoint = torch.load(self.checkpoint_path)
+         self.load_state(checkpoint['context'])
+         for module in self.dependents:
+             module.load_state(checkpoint[module.__class__.__name__])
+ 
+     @property
+     def checkpoint_path(self):
+         return f'{self.checkpoint_folder}/{self.checkpoint_file}'
+ 
+     def log_metrics(self, metrics, tag=''):
+         scalar_metrics = {k: v.item() for k, v in metrics.items() if v.numel() == 1}
+         metrics_log = ' '.join(f'{k}:{v}' for k, v in scalar_metrics.items())
+         self.logger.info(f"epochs:{self.state_current_epoch} - {tag} - {metrics_log}")
+         self.metrics_log.append(metrics, tag)
+ 
+     def new_metrics_record(self):
+         return MetricsRecord()
+ 
+     def evaluate_dataloader(
+         self,
+         tag: str,
+         dtldr: torch.utils.data.DataLoader,
+         forward_function: Callable,
+         f_reduce_metrics_records=func_reduce_mean,
+         save_in_context=False,
+         batch_hook: List[Callable] = [],
+         output_hook: List[Callable] = []
+     ):
+         new_log = MetricsRecord()
+         is_model_training = self.model.training
+         self.model.eval()
+         with torch.no_grad():
+             for batch in dtldr:
+                 for f_batch in batch_hook:
+                     f_batch(batch)
+                 output = forward_function(self.model, batch)
+                 for f_output in output_hook:
+                     f_output(output)
+                 metrics_dict = self.func_metrics(batch, output)
+                 new_log.append(metrics_dict)
+         # restore the training flag that was active before evaluation
+         if is_model_training:
+             self.model.train()
+ 
+         reduced_record = {k: f_reduce_metrics_records(v) for k, v in new_log.items()}
+         if save_in_context:
+             self.log_metrics(reduced_record, tag)
+ 
+         output_record = {}
+         for k, v in reduced_record.items():
+             real_k = f'{tag}/{k}' if tag != '' else k
+             output_record[real_k] = v.cpu()
+         scalar_metrics = {k: v.item() for k, v in output_record.items() if v.numel() == 1}
+         return output_record, scalar_metrics
+ 
+     def get_state(self):
+         state = {
+             'model_state_dict': self.model.state_dict(),
+             'optim_state_dict': self.optimizer.state_dict(),
+             'state_current_epoch': self.state_current_epoch,
+             'custom_config': self.custom_config
+         }
+         return state
+ 
+     def load_state(self, state):
+         self.model.load_state_dict(state['model_state_dict'])
+         self.optimizer.load_state_dict(state['optim_state_dict'])
+         self.state_current_epoch = state['state_current_epoch']
+         self.custom_config = state['custom_config']
+ 
+ class SaveBest:
+     def __init__(
+         self,
+         ctx: Context,
+         state_metric_name,
+         op=lambda old, new: new > old,
+         tol=None,
+         ifLog=True,
+         file_name="save_best.pt"
+     ):
+         self.ctx = ctx
+         ctx.add_dependent(self)
+         self.state_cnt = 0
+         self.state_best_cnt = -1
+         self.state_best_metric = None
+         self.state_metric_name = state_metric_name
+ 
+         self.op = op
+         self.tol = tol
+         self.saved_checkpoint = None
+         self.ifLog = ifLog
+         self.file_name = file_name
+ 
+     @property
+     def target_path(self):
+         return f'{self.ctx.checkpoint_folder}/{self.file_name}'
+ 
+     def get_state(self):
+         # everything prefixed with `state_` is checkpointed
+         return {k: v for k, v in self.__dict__.items() if k.startswith('state_')}
+ 
+     def load_state(self, state):
+         for k in list(self.__dict__):
+             if k.startswith('state_'):
+                 self.__dict__[k] = state[k]
+ 
+     def step(self):
+         t_metric = self.ctx.metrics_log[self.state_metric_name][-1]
+         assert t_metric.ndim == 0 or (t_metric.ndim == 1 and t_metric.shape[0] == 1), t_metric.shape
+         t_metric = t_metric.item()
+         t_cnt = self.state_cnt
+         ifStop = False
+         if self.state_best_metric is None:
+             ifUpdate = True
+         else:
+             ifUpdate = self.op(self.state_best_metric, t_metric)
+         if ifUpdate:
+             self.state_best_metric = t_metric
+             self.state_best_cnt = t_cnt
+             checkpoint = {}
+             checkpoint.update(self.ctx.get_state())
+             checkpoint.update(self.get_state())
+ 
+             if self.ifLog:
+                 msg = f'save_best --- cnt: {self.state_best_cnt}, {self.state_metric_name}: {self.state_best_metric}'
+                 self.ctx.logger.info(msg)
+             torch.save(checkpoint, self.target_path)
+             self.saved_checkpoint = checkpoint
+ 
+         if self.tol is not None:
+             if self.state_cnt - self.state_best_cnt > self.tol:
+                 ifStop = True
+                 msg = f'early_stop --- epoch: {self.state_best_cnt}, metric: {self.state_best_metric}'
+                 self.ctx.logger.info(msg)
+         self.state_cnt += 1
+         return ifUpdate, ifStop
+ 
+ 
+ def pearsonr(y, y_pred):
+     """
+     Compute Pearson's correlation coefficient between predicted
+     and observed data.
+ 
+     y: (..., n_samples, n_chans)
+     y_pred: (..., n_samples, n_chans)
+     """
+     # use the biased std (divide by n) so numerator and denominator
+     # share the same normalization
+     r = torch.mean(
+         (y - y.mean(-2, keepdim=True)) * (y_pred - y_pred.mean(-2, keepdim=True)),
+         -2
+     ) / (
+         y.std(-2, unbiased=False) * y_pred.std(-2, unbiased=False)
+     )
+     return r
+ 
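
A sketch of how Context, SaveBest, StimRespDataset and pearsonr fit together in a training loop; the model, data shapes, metric choice and checkpoint paths are illustrative assumptions, not the package's prescribed usage:

    import torch
    from tour.torch_trainer import Context, SaveBest, StimRespDataset, pearsonr

    # toy data: 20 trials of (n_samples, n_chans) stimulus/response pairs
    stims = [torch.randn(100, 1) for _ in range(20)]
    resps = [torch.randn(100, 2) for _ in range(20)]
    dtldr = torch.utils.data.DataLoader(StimRespDataset(stims, resps), batch_size=4)

    model = torch.nn.Linear(1, 2)
    optimizer = torch.optim.Adam(model.parameters())

    def forward_function(model, batch):
        stim, resp = batch
        return model(stim)

    def func_metrics(batch, output):
        _, resp = batch
        return {'r': pearsonr(resp, output).mean()}   # one scalar metric per batch

    ctx = Context(model, optimizer, func_metrics, checkpoint_folder='.')
    saver = SaveBest(ctx, state_metric_name='val/r', tol=3)  # default op keeps the highest value

    for epoch in range(10):
        ctx.new_epochs()
        for stim, resp in dtldr:
            optimizer.zero_grad()
            loss = torch.nn.functional.mse_loss(model(stim), resp)
            loss.backward()
            optimizer.step()
        # logs 'val/r' into ctx.metrics_log, which saver.step() then reads
        ctx.evaluate_dataloader('val', dtldr, forward_function, save_in_context=True)
        if_update, if_stop = saver.step()
        if if_stop:   # more than `tol` epochs without improvement
            break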