pytour 3.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytour-3.0.0.dist-info/METADATA +27 -0
- pytour-3.0.0.dist-info/RECORD +15 -0
- pytour-3.0.0.dist-info/WHEEL +5 -0
- pytour-3.0.0.dist-info/licenses/LICENSE +21 -0
- pytour-3.0.0.dist-info/top_level.txt +1 -0
- tour/__init__.py +1 -0
- tour/artifacts_removal.py +122 -0
- tour/backend.py +34 -0
- tour/dataclass/__init__.py +0 -0
- tour/dataclass/dataset.py +465 -0
- tour/dataclass/io.py +225 -0
- tour/dataclass/stim.py +33 -0
- tour/package_manage.py +13 -0
- tour/torch_trainer.py +339 -0
- tour/vis.py +201 -0
tour/dataclass/io.py
ADDED
@@ -0,0 +1,225 @@
|
|
1
|
+
import os
|
2
|
+
import mne
|
3
|
+
import h5py
|
4
|
+
import json
|
5
|
+
import numpy as np
|
6
|
+
from collections import OrderedDict
|
7
|
+
from typing import Union, List
|
8
|
+
"""
|
9
|
+
Helpers for saving/loading MNE montages (DigMontage) to/from HDF5
|
10
|
+
"""
|
11
|
+
|
12
|
+
def _validate_stimuli_dict(stimuli_dict:dict):
|
13
|
+
for k in stimuli_dict:
|
14
|
+
stim:dict = stimuli_dict[k]
|
15
|
+
if not isinstance(stim, dict):
|
16
|
+
raise ValueError(f'value for stim {k} should be a dict')
|
17
|
+
for feat_k, feat_v in stim.items():
|
18
|
+
if isinstance(feat_v, dict):
|
19
|
+
assert all([s in feat_v for s in ['x', 'timeinfo', 'tag']])
|
20
|
+
else:
|
21
|
+
pass
|
22
|
+
# pattern = r"_fs\d+$"
|
23
|
+
# assert re.search(pattern, feat_k)
|
24
|
+
return stimuli_dict
|
25
|
+
|
26
|
+
def mne_montage_to_h5py_group(montage:mne.channels.DigMontage, f:h5py.File):
    """Serialize an MNE ``DigMontage`` into the ``'montage'`` group of ``f``.

    Each entry of ``montage.get_positions()`` is written as follows:

    * ``'ch_pos'`` -> a dataset of the stacked channel coordinates, with the
      channel names (same order) stored as a JSON list in the dataset
      attribute ``'chs_json_str'``;
    * ``'coord_frame'`` -> an attribute of the ``'montage'`` group;
    * any other key -> a dataset; ``None`` is stored as an empty array so
      that ``mne_montage_from_h5py_group`` can restore it as ``None``.

    Args:
        montage: the montage to serialize.
        f: open HDF5 file (or group) to write into; the ``'montage'`` group
            is created if it does not exist.

    Returns:
        ``f``, to allow chaining.
    """
    montage_grp = f.require_group('montage')
    pos_dict = montage.get_positions()
    for k, v in pos_dict.items():
        if k == 'ch_pos':
            chs = list(v.keys())
            if chs:
                ch_coords = np.stack([np.asarray(v[ch]) for ch in chs])
            else:
                # np.stack (and tuple-unpacking an empty zip) cannot handle a
                # montage without channels; store an empty (0, 3) array so the
                # round trip still works.
                ch_coords = np.empty((0, 3))
            t_ds = montage_grp.create_dataset(k, data=ch_coords)
            t_ds.attrs['chs_json_str'] = json.dumps(chs)
        elif k == 'coord_frame':
            montage_grp.attrs['coord_frame'] = v
        else:
            if v is None:
                v = np.array([])
            montage_grp.create_dataset(k, data=v)
    return f
|
49
|
+
|
50
|
+
def mne_montage_from_h5py_group(f:h5py.File):
    """Rebuild an MNE ``DigMontage`` from the ``'montage'`` group of ``f``.

    ``'ch_pos'`` is turned back into an ordered ``{channel_name: coords}``
    mapping, using the JSON channel list stored in its ``'chs_json_str'``
    attribute. Any other dataset with shape ``(0,)`` becomes ``None``; the
    rest are read as arrays. ``coord_frame`` is taken from the group
    attributes. All entries are passed to ``mne.channels.make_dig_montage``.
    """
    grp = f['montage']
    kwargs = {}
    for name, ds in grp.items():
        if name != 'ch_pos':
            kwargs[name] = None if ds.shape == (0,) else ds[:]
            continue
        coords = ds[:]
        ch_names = json.loads(ds.attrs['chs_json_str'])
        kwargs[name] = OrderedDict(
            (ch_name, coords[idx]) for idx, ch_name in enumerate(ch_names)
        )
    kwargs['coord_frame'] = grp.attrs['coord_frame']
    return mne.channels.make_dig_montage(**kwargs)
|
71
|
+
|
72
|
+
|
73
|
+
"""
|
74
|
+
Helpers for saving/loading data records (DataRecord) to/from HDF5
|
75
|
+
"""
|
76
|
+
def data_record_to_h5py_group(
    key: str,
    data: np.ndarray,
    stim_id: Union[str, int],
    meta_info:dict,
    srate: int,
    f:h5py.File
):
    """Write one data record under ``records/<key>`` in ``f``.

    ``data`` becomes the ``'data'`` dataset, and ``stim_id`` and ``srate`` are
    stored as attributes of the record group. Entries of ``meta_info`` go into
    a ``'meta_info'`` sub-group: numpy arrays become datasets, and every other
    value becomes an attribute, so it must be a type h5py can store as one.

    Returns:
        ``f``, to allow chaining.
    """
    record = f.require_group(f'records/{key}')
    record.create_dataset('data', data = data)
    record.attrs['stim_id'] = stim_id
    record.attrs['srate'] = srate

    meta_grp = record.require_group('meta_info')
    for name, value in meta_info.items():
        if not isinstance(value, np.ndarray):
            meta_grp.attrs[name] = value
            continue
        meta_grp.create_dataset(name, data=value)

    return f
|
97
|
+
|
98
|
+
def data_record_from_h5py_group(
    f:h5py.File
):
    """Read back one record written by ``data_record_to_h5py_group``.

    Args:
        f: the record group itself (the ``records/<key>`` group holding the
            ``'data'`` dataset and the ``stim_id``/``srate`` attributes).

    Returns:
        dict with ``data`` (array), ``stim_id`` (as read from the attribute),
        ``meta_info`` (dict) and ``srate`` (int). ``meta_info`` merges the
        attributes of the ``'meta_info'`` sub-group with its datasets. The
        writer stores ndarray metadata as datasets, so both must be read.
    """
    data = f['data'][:]
    stim_id = f.attrs['stim_id']
    srate = int(f.attrs['srate'])

    meta_info_grp = f['meta_info']
    meta_info = dict(meta_info_grp.attrs.items())
    # Array-valued metadata is stored as datasets, not attributes; without this
    # pass it would be silently dropped on read. ``[()]`` also works for
    # 0-d (scalar) datasets, where ``[:]`` would raise.
    for k, v in meta_info_grp.items():
        if isinstance(v, h5py.Dataset):
            meta_info[k] = v[()]

    return dict(
        data = data, stim_id = stim_id, meta_info = meta_info, srate = srate
    )
|
116
|
+
|
117
|
+
"""
|
118
|
+
Helpers for saving/loading stimulus dictionaries to/from HDF5
|
119
|
+
"""
|
120
|
+
|
121
|
+
def check_list_of_string(data:List[str]):
    """Assert that ``data`` is a ``list`` whose items are all ``str``.

    Returns ``data`` itself (so an empty list is falsy).

    Raises:
        AssertionError: if ``data`` is not a list or contains a non-``str``.
    """
    assert isinstance(data, list)
    for item in data:
        assert isinstance(item, str)
    return data
|
125
|
+
|
126
|
+
def stim_dict_to_hdf5(
    filename:str,
    stim_dict: dict,
    attrs:dict = None,
):
    """Append a stimulus dictionary to the HDF5 file ``filename``.

    Layout: one group per stimulus id. Inside it, each ndarray feature
    becomes a dataset, and each dict (discrete) feature becomes a sub-group
    written by ``discrete_stim_to_hdf5``.

    Args:
        filename: HDF5 file path; opened in append mode.
        stim_dict: ``{stim_id: {feat_name: ndarray | dict}}``.
        attrs: optional ``{stim_id: {feat_name: dict}}`` of HDF5 attributes to
            attach to each written feature. When given, it must contain an
            entry for every stimulus/feature written (``KeyError`` otherwise).

    Raises:
        ValueError: if a stimulus entry is not a dict.
        AssertionError: if a discrete feature lacks a required key, or if a
            feature already exists in the file.
        TypeError: if a feature is neither an ndarray nor a dict.
    """
    # Validate before opening the file so malformed input does not create or
    # touch the file on disk.
    _validate_stimuli_dict(stim_dict)
    with h5py.File(filename, 'a') as hdf5f:
        for stim_id, features in stim_dict.items():
            grp = hdf5f.require_group(stim_id)
            for feat_name, data in features.items():
                assert feat_name not in grp
                if isinstance(data, np.ndarray):
                    dataset = grp.create_dataset(feat_name, data = data)
                elif isinstance(data, dict):
                    dataset = discrete_stim_to_hdf5(
                        feat_name=feat_name,
                        feat_dict=data,
                        hdf5f=grp
                    )
                else:
                    raise TypeError(
                        f'feature {feat_name!r} of stim {stim_id!r} must be '
                        f'an np.ndarray or a dict, got {type(data).__name__}'
                    )

                if attrs is not None:
                    dataset.attrs.update(attrs[stim_id][feat_name])
|
151
|
+
|
152
|
+
def stim_dict_from_hdf5(
    filename:str,
) -> dict:
    """Load a stimulus dictionary written by ``stim_dict_to_hdf5``.

    Each top-level group of the file is one stimulus. Inside it, datasets are
    read as arrays and sub-groups are decoded with
    ``discrete_stim_from_hdf5``. The result is checked with
    ``_validate_stimuli_dict`` before it is returned.

    Raises:
        TypeError: on a member that is neither a dataset nor a group.
    """
    result = {}
    with h5py.File(filename, 'r') as hdf5f:
        for stim_id, stim_grp in hdf5f.items():
            features = result[stim_id] = {}
            for name, node in stim_grp.items():
                if isinstance(node, h5py.Dataset):
                    features[name] = node[:]
                    continue
                if isinstance(node, h5py.Group):
                    features[name] = discrete_stim_from_hdf5(node)
                    continue
                raise TypeError
    _validate_stimuli_dict(result)
    return result
|
168
|
+
|
169
|
+
def discrete_stim_to_hdf5(
    feat_name:str,
    feat_dict:dict,
    hdf5f:h5py.Group
) -> h5py.Group:
    """Write a discrete (event-based) feature as a sub-group of ``hdf5f``.

    ``feat_dict`` is expected to look like::

        {
            'x': ...,
            'tag': ...,
            'timeinfo': ...
        }

    Each ndarray value becomes a dataset. Each list of ``str`` (including an
    empty list) becomes a UTF-8 string dataset via ``string_list_to_hdf5``.

    Returns:
        the sub-group named ``feat_name`` (created if needed).

    Raises:
        TypeError: for any other value type.
    """
    grp = hdf5f.require_group(feat_name)
    for k, v in feat_dict.items():
        if isinstance(v, np.ndarray):
            grp.create_dataset(k, data = v)
        elif isinstance(v, list) and all(isinstance(s, str) for s in v):
            # A real boolean test. check_list_of_string raises AssertionError
            # instead of returning False (leaving the TypeError branch
            # unreachable) and returns the list itself (so [] was rejected).
            string_list_to_hdf5(k, v, grp)
        else:
            raise TypeError(
                f'unsupported type {type(v).__name__} for key {k!r} of '
                f'discrete feature {feat_name!r}'
            )
    return grp
|
190
|
+
|
191
|
+
def discrete_stim_from_hdf5(
    hdf5f:h5py.Group,
):
    """Read a discrete-feature group written by ``discrete_stim_to_hdf5``.

    String datasets are decoded into a list of ``str`` via
    ``string_list_from_hdf5``; every other dataset is read as an array.
    Members that are not datasets are skipped.
    """
    stim_dict = {}
    for k, v in hdf5f.items():
        if not isinstance(v, h5py.Dataset):
            continue
        # check_string_dtype returns None for non-string types and recognises
        # every HDF5 string type. The previous ``elif v.dtype:`` was a
        # truthiness test on a dtype, not a meaningful condition.
        if h5py.check_string_dtype(v.dtype) is not None:
            stim_dict[k] = string_list_from_hdf5(v)
        else:
            stim_dict[k] = v[:]
    return stim_dict
|
207
|
+
|
208
|
+
|
209
|
+
def string_list_to_hdf5(
    dataset_name:str,
    strings: List[str],
    f:h5py.Group
):
    """Store ``strings`` as a 1-D variable-length UTF-8 string dataset.

    Args:
        dataset_name: name of the dataset inside ``f``.
        strings: the strings to store.
        f: the group (or file) that will hold the dataset; callers in this
            module pass an ``h5py.Group``, not a dataset.

    Returns:
        ``f``.

    Note:
        ``require_dataset`` returns an already existing dataset with the same
        shape and a compatible dtype as-is, without writing ``strings`` into it.
    """
    dt = h5py.string_dtype(encoding='utf-8')
    f.require_dataset(dataset_name, (len(strings),), dtype=dt, data = strings)
    return f
|
221
|
+
|
222
|
+
def string_list_from_hdf5(
    dt:h5py.Dataset
):
    """Return the contents of the string dataset ``dt`` as a list of ``str``.

    ``asstr()`` decodes the stored bytes into Python strings while reading.
    """
    decoded = dt.asstr()[:]
    return decoded.tolist()
|
tour/dataclass/stim.py
ADDED
@@ -0,0 +1,33 @@
|
|
1
|
+
from typing import Dict
|
2
|
+
|
3
|
+
from ..backend import is_tensor, np, torch, Array
|
4
|
+
|
5
|
+
|
6
|
+
def to_impulses(x:Array, timeinfo:Array, f:float, padding_s:float = 0):
    '''
    Align event-level feature vectors into an impulse time series.

    Column ``x[:, i]`` is written at sample ``round(timeinfo[0][i] * f)`` of an
    otherwise-zero output. Events that round to the same sample overwrite
    each other.

    Args:
        x: ``(nDim, nEvents)`` values, as a numpy array or a torch tensor.
        timeinfo: start times (``timeinfo[0]``) and end times
            (``timeinfo[1]``) in seconds; must be the same kind
            (numpy/torch) as ``x``.
        f: output sampling rate in Hz.
        padding_s: extra seconds appended after the latest end time.

    Returns:
        ``(nDim, nLen)`` array/tensor with ``x``'s dtype (and, for tensors,
        ``x``'s device). ``nLen = ceil((max end time + padding_s) * f)``,
        enlarged if necessary so that every event index is in range.
    '''
    if is_tensor(x):
        assert is_tensor(timeinfo)
    else:
        assert not is_tensor(timeinfo)
    startTimes = timeinfo[0]
    endTimes = timeinfo[1]
    # Use the latest end time rather than the last entry, so unsorted events
    # cannot fall past the end of the output.
    secLen = endTimes.max() + padding_s
    nDim = x.shape[0]
    if is_tensor(x):
        timeIndices = torch.round(startTimes * f).long()
        nLen = int(torch.ceil(secLen * f).long())
        # An event starting exactly at the end (zero duration, no padding)
        # rounds to index nLen, which would be out of bounds.
        nLen = max(nLen, int(timeIndices.max()) + 1)
        # Allocate on x's device; a CPU buffer cannot receive GPU values.
        out = torch.zeros((nDim, nLen), dtype=x.dtype, device=x.device)
    else:
        timeIndices = np.round(startTimes * f).astype(int)
        nLen = int(np.ceil(secLen * f))
        nLen = max(nLen, int(timeIndices.max()) + 1)
        out = np.zeros((nDim, nLen), dtype=x.dtype)
    out[:,timeIndices] = x
    return out
|
28
|
+
|
29
|
+
def dictTensor_to(x:Dict[str, Array], device):
    '''
    Return a new dict in which every tensor value of ``x`` has been moved
    with ``.to(device)``; non-tensor values are passed through unchanged.
    '''
    moved = {}
    for key, value in x.items():
        moved[key] = value.to(device) if is_tensor(value) else value
    return moved
|
tour/package_manage.py
ADDED
tour/torch_trainer.py
ADDED
@@ -0,0 +1,339 @@
|
|
1
|
+
import os
|
2
|
+
import sys
|
3
|
+
import torch
|
4
|
+
import logging
|
5
|
+
import numpy as np
|
6
|
+
from itertools import chain
|
7
|
+
from typing import Callable, List, Union, Protocol
|
8
|
+
|
9
|
+
def func_reduce_mean(values):
    """Average a list of per-batch metric tensors over dim 0.

    If the first tensor is 0-d, the tensors are stacked into a 1-D tensor;
    otherwise they are concatenated along dim 0. Only the first element is
    inspected to choose between the two.
    """
    first = values[0]
    joined = torch.stack(values) if first.ndim == 0 else torch.cat(values)
    return torch.mean(joined, dim = 0)
|
15
|
+
|
16
|
+
def get_logger(
    file_dir,
    console_level=logging.INFO,
    file_level=logging.DEBUG,
    file_name="logfile.log",
    if_print = True,
):
    """Configure and return the trainer logger named ``'tray/trainer'``.

    Handlers already attached to this logger are closed and removed first.
    Calling it again (e.g. for every new ``Context``) therefore neither
    duplicates output nor leaks the previous log file handle.

    Args:
        file_dir: directory of the log file; it must already exist.
        console_level: level of the stdout handler.
        file_level: level of the file handler.
        file_name: log file name inside ``file_dir`` (opened in append mode).
        if_print: whether to also log to stdout.

    Returns:
        the configured ``logging.Logger``.
    """
    file_path = os.path.join(file_dir, file_name)
    logger = logging.getLogger('tray/trainer')
    logger.setLevel(logging.DEBUG)  # master level: let the handlers filter
    # Close before removing: clearing the list alone leaves old FileHandlers
    # (and their open files) alive.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )

    if if_print:
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setLevel(console_level)
        console_handler.setFormatter(formatter)
        logger.addHandler(console_handler)

    file_handler = logging.FileHandler(file_path)
    file_handler.setLevel(file_level)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    return logger
|
49
|
+
|
50
|
+
class DependentModule(Protocol):
    """Structural interface for objects whose state is checkpointed with a
    ``Context``.

    ``Context.save_checkpoint`` stores ``get_state()`` under the object's
    class name; ``Context.load_checkpoint`` hands that value back to
    ``load_state``.
    """

    def load_state(self, state:dict) -> dict:
        """Restore internal state from ``state``."""
        ...

    def get_state(self) -> dict:
        """Return the internal state to be saved in the checkpoint."""
        ...
|
54
|
+
|
55
|
+
class BatchAccumulator:
    """Collect per-batch tensors and expose them joined along dim 0."""

    def __init__(self,):
        self._data:list = list()

    def append(self, output):
        """Store one batch output."""
        self._data += [output]

    @property
    def data(self):
        """All stored outputs concatenated along the batch (first) dimension."""
        return torch.cat(self._data)
|
67
|
+
|
68
|
+
class MetricsRecord:
    """Accumulate metric tensors per name, optionally prefixed by a tag.

    ``append`` stores each value (moved to CPU) under ``"<tag>/<name>"``, or
    under ``"<name>"`` when ``tag`` is empty. Iterating yields the stored
    names, and ``record[name]`` returns that name's values in insertion
    order.
    """

    def __init__(self,):
        self._data = {}

    def append(self, metricDict:dict, tag:str = ''):
        """Add one value per entry of ``metricDict`` (each must have ``.cpu()``)."""
        for name, value in metricDict.items():
            key = name if tag == '' else tag + '/' + name
            self._data.setdefault(key, []).append(value.cpu())

    def __iter__(self):
        return iter(self._data)

    def __getitem__(self, key):
        return self._data[key]

    def items(self):
        """Yield ``(name, list_of_values)`` pairs."""
        yield from self._data.items()
|
93
|
+
|
94
|
+
def ndarrays_to_tensors(*datas:List[np.ndarray]):
    """Convert each positional list of arrays into a list of tensors.

    Arrays go through ``torch.from_numpy``, so each resulting tensor shares
    memory with its source array. Scalars (per ``np.isscalar``) become new
    tensors of ``torch.get_default_dtype()``. Returns one list per argument.
    """
    def _convert(item):
        if np.isscalar(item):
            return torch.tensor(item, dtype=torch.get_default_dtype())
        return torch.from_numpy(item)

    return [list(map(_convert, group)) for group in datas]
|
103
|
+
|
104
|
+
class StimRespDataset(torch.utils.data.Dataset):
    """Map-style dataset of paired (stimulus, response) trials.

    Args:
        stims: per-trial stimuli, as numpy arrays or torch tensors.
        resps: per-trial responses, as numpy arrays or torch tensors; must
            have the same number of trials as ``stims``.
        device: device each item is moved to in ``__getitem__``.

    Lists of numpy arrays are converted with ``ndarrays_to_tensors`` (the
    tensors share memory with the arrays). ``stims`` and ``resps`` are
    inspected and converted independently, so an ndarray list can be paired
    with a tensor list.
    """

    def __init__(self,
        stims:Union[List[np.ndarray], List[torch.Tensor]],
        resps:Union[List[np.ndarray], List[torch.Tensor]],
        device = 'cpu'
    ):
        assert len(stims) == len(resps)
        self.stims = self._to_tensors(stims)
        self.resps = self._to_tensors(resps)
        self.device = device

    @staticmethod
    def _to_tensors(items):
        """Convert ``items`` with ndarrays_to_tensors if it holds ndarrays."""
        # Only the first element is inspected; an empty list is returned
        # unchanged instead of raising IndexError.
        if len(items) > 0 and isinstance(items[0], np.ndarray):
            return ndarrays_to_tensors(items)[0]
        return items

    def __getitem__(self, index:int):
        return self.stims[index].to(self.device), self.resps[index].to(self.device)

    def __len__(self):
        return len(self.stims)
|
123
|
+
|
124
|
+
class Context:
    """Training context: model, optimizer, metric history, logging and
    checkpointing.

    Objects registered with ``add_dependent`` have their state saved and
    restored together with the context's own state in ``save_checkpoint`` /
    ``load_checkpoint``.

    Args:
        model: the model being trained/evaluated.
        optimizer: its optimizer.
        func_metrics: called as ``func_metrics(batch, output)`` by
            ``evaluate_dataloader``; must return a dict of tensors.
        checkpoint_folder: directory for the checkpoint and the log file;
            created if it does not exist.
        checkpoint_file: checkpoint file name inside ``checkpoint_folder``.
        custom_config: arbitrary user config stored in the checkpoint. When
            omitted, each instance gets its own new empty dict (a ``{}``
            default would be one dict shared by every instance).
        if_print_metric: also print log messages to stdout.
    """

    def __init__(
        self,
        model:torch.nn.Module,
        optimizer:torch.optim.Optimizer,
        func_metrics: Callable,
        checkpoint_folder: str,
        checkpoint_file = "checkpoint.pt",
        custom_config = None,
        if_print_metric = True,
    ):
        self.model = model
        self.optimizer = optimizer
        self.state_current_epoch = -1
        self.func_metrics = func_metrics
        self.checkpoint_folder = checkpoint_folder
        self.checkpoint_file = checkpoint_file
        self.metrics_log = MetricsRecord()
        self.custom_config = {} if custom_config is None else custom_config
        # The log file (and later the checkpoints) are written into this
        # folder; logging.FileHandler raises if it does not exist yet.
        if checkpoint_folder:
            os.makedirs(checkpoint_folder, exist_ok=True)
        self.logger = get_logger(checkpoint_folder, if_print=if_print_metric)

        self.dependents:List[DependentModule] = []

    def add_dependent(self, module:DependentModule):
        """Register ``module`` so that its state is included in checkpoints."""
        self.dependents.append(module)

    def new_epochs(self):
        """Advance the epoch counter (it starts at -1, so the first call gives 0)."""
        self.state_current_epoch += 1

    def checkpoint_exists(self):
        return os.path.exists(self.checkpoint_path)

    def save_checkpoint(self):
        """Save the context state and every dependent's state to ``checkpoint_path``.

        Dependents are keyed by class name, so two dependents of the same
        class overwrite each other.
        """
        checkpoint = {}
        for module in self.dependents:
            checkpoint[module.__class__.__name__] = module.get_state()
        checkpoint['context'] = self.get_state()
        torch.save(checkpoint, self.checkpoint_path)

    def load_checkpoint(self):
        """Restore the context and every registered dependent from ``checkpoint_path``."""
        checkpoint = torch.load(self.checkpoint_path)
        self.load_state(checkpoint['context'])
        for module in self.dependents:
            module.load_state(checkpoint[module.__class__.__name__])

    @property
    def checkpoint_path(self):
        return f'{self.checkpoint_folder}/{self.checkpoint_file}'

    def log_metrics(
        self,
        metrics,
        tag = ''
    ):
        """Log the single-element entries of ``metrics``, then append all of
        them to ``metrics_log`` under ``tag``."""
        scalar_metrics = {k:v.item() for k,v in metrics.items() if v.numel() == 1}
        metrics_log = ''.join(f'{k}:{v} ' for k, v in scalar_metrics.items())
        self.logger.info(f"epochs:{self.state_current_epoch} - {tag} - {metrics_log}")
        self.metrics_log.append(metrics,tag)

    def new_metrics_record(self):
        return MetricsRecord()

    def evaluate_dataloader(
        self,
        tag:str,
        dtldr:torch.utils.data.DataLoader,
        forward_function: Callable,
        f_reduce_metrics_records = func_reduce_mean,
        save_in_context = False,
        batch_hook:List[Callable] = None,
        output_hook:List[Callable] = None
    ):
        """Evaluate the model on ``dtldr`` without gradient tracking.

        For each batch: call every ``batch_hook`` with the batch, put the model
        in eval mode, compute ``output = forward_function(model, batch)``, call
        every ``output_hook`` with the output, and collect
        ``func_metrics(batch, output)``. Each metric is then reduced over the
        batches with ``f_reduce_metrics_records``. The model's previous
        train/eval mode is restored afterwards, even if an exception is raised.

        Args:
            tag: prefix of the returned metric names (``"<tag>/<name>"``); also
                the tag passed to ``log_metrics`` when ``save_in_context``.
            save_in_context: also record the reduced metrics via
                ``log_metrics``.
            batch_hook / output_hook: optional lists of callables (``None``
                means no hooks).

        Returns:
            ``(output_record, scalar_metrics)``: every reduced metric moved to
            CPU and keyed by tagged name, and the single-element ones as
            Python numbers.
        """
        batch_hook = [] if batch_hook is None else batch_hook
        output_hook = [] if output_hook is None else output_hook
        new_log = MetricsRecord()
        is_model_training = self.model.training
        try:
            with torch.no_grad():
                for batch in dtldr:
                    for f_batch in batch_hook:
                        f_batch(batch)
                    self.model.eval()
                    output = forward_function(self.model, batch)
                    for f_output in output_hook:
                        f_output(output)
                    metrics_dict = self.func_metrics(
                        batch,
                        output
                    )
                    new_log.append(
                        metrics_dict,
                    )
        finally:
            # restore the caller's mode even if evaluation failed
            self.model.train(is_model_training)

        reduced_record = {k: f_reduce_metrics_records(v) for k, v in new_log.items()}
        if save_in_context:
            self.log_metrics(reduced_record, tag)

        output_record = {}
        for k, v in reduced_record.items():
            real_k = tag + '/' + k if tag != '' else k
            output_record[real_k] = v.cpu()
        scalar_metrics = {k:v.item() for k,v in output_record.items() if v.numel() == 1}
        return output_record, scalar_metrics

    def get_state(self):
        state = {
            'model_state_dict': self.model.state_dict(),
            'optim_state_dict': self.optimizer.state_dict(),
            'state_current_epoch': self.state_current_epoch,
            'custom_config':self.custom_config
        }
        return state

    def load_state(self, state):
        self.model.load_state_dict(state['model_state_dict'])
        self.optimizer.load_state_dict(state['optim_state_dict'])
        self.state_current_epoch = state['state_current_epoch']
        self.custom_config = state['custom_config']
|
251
|
+
|
252
|
+
class SaveBest:
    """Track the best value of a logged metric, checkpoint on improvement and
    signal early stopping.

    On construction it registers itself as a dependent of ``ctx``, so its
    ``state_*`` attributes are saved and restored with the context checkpoint.

    Args:
        ctx: the training context whose ``metrics_log`` is read.
        state_metric_name: key in ``ctx.metrics_log`` whose latest value is
            compared in ``step``.
        op: ``op(old_best, new) -> bool``; truthy means ``new`` is better.
            The default treats larger as better.
        tol: number of steps without improvement tolerated before ``step``
            reports a stop; ``None`` disables early stopping.
        ifLog: log a message whenever a new best is saved.
        file_name: file name of the best checkpoint inside
            ``ctx.checkpoint_folder``.
    """

    def __init__(
        self,
        ctx:Context,
        state_metric_name,
        op = lambda old, new: new > old,
        tol = None,
        ifLog = True,
        file_name = "save_best.pt"
    ):
        self.ctx = ctx
        ctx.add_dependent(self)
        self.state_cnt = 0
        self.state_best_cnt = -1
        self.state_best_metric = None
        self.state_metric_name = state_metric_name

        self.op = op
        self.tol = tol
        self.saved_checkpoint= None
        self.ifLog = ifLog
        self.file_name = file_name

    @property
    def target_path(self):
        return f'{self.ctx.checkpoint_folder}/{self.file_name}'

    def get_state(self):
        """Return every attribute whose name starts with ``state_``."""
        output = {}
        for k,v in self.__dict__.items():
            if k.startswith('state_'):
                output[k] = v
        return output

    def load_state(self, state):
        """Overwrite every ``state_*`` attribute with the value from ``state``."""
        for k in list(self.__dict__):
            if k.startswith('state_'):
                self.__dict__[k] = state[k]

    def step(self,):
        """Compare the latest value of the tracked metric with the best so far.

        On improvement, the context state plus this object's state are saved
        to ``target_path``, and an independent in-memory copy is kept in
        ``saved_checkpoint``.

        Returns:
            ``(ifUpdate, ifStop)``: whether a new best was recorded, and
            whether more than ``tol`` steps have passed since the last best.
        """
        import copy

        t_metric = self.ctx.metrics_log[self.state_metric_name][-1]
        assert t_metric.ndim == 0 or (t_metric.ndim == 1 and t_metric.shape[0] == 1), t_metric.shape
        t_metric = t_metric.item()
        t_cnt = self.state_cnt
        ifStop = False
        if self.state_best_metric is None:
            ifUpdate = True
        else:
            ifUpdate = self.op(self.state_best_metric, t_metric)
        if ifUpdate:
            self.state_best_metric = t_metric
            self.state_best_cnt = t_cnt
            checkpoint = {}
            checkpoint.update(self.ctx.get_state())
            checkpoint.update(self.get_state())

            if self.ifLog:
                msg = f'save_best --- cnt: {self.state_best_cnt}, {self.state_metric_name}: {self.state_best_metric}'
                self.ctx.logger.info(msg)
            torch.save(checkpoint, self.target_path)
            # model.state_dict() is a shallow copy that references the live
            # parameters, which keep changing as training continues. Deep-copy
            # so saved_checkpoint really holds the best weights (the copy stays
            # on the same device as the originals).
            self.saved_checkpoint = copy.deepcopy(checkpoint)

        if self.tol is not None:
            if self.state_cnt - self.state_best_cnt > self.tol:
                ifStop = True
                msg = f'early_stop --- epoch: {self.state_best_cnt}, metric: {self.state_best_metric}'
                self.ctx.logger.info(msg)
        self.state_cnt += 1
        return ifUpdate, ifStop
|
322
|
+
|
323
|
+
|
324
|
+
def pearsonr(y, y_pred):
    """
    Compute Pearson's correlation coefficient between predicted
    and observed data, along the samples axis (-2).

    y: (..., n_samples, n_chans)
    y_pred: (..., n_samples, n_chans)

    Returns a tensor of shape (..., n_chans).

    The covariance and both standard deviations use the same 1/n
    normalisation. Combining a mean-based covariance with torch's default
    unbiased ``std`` (1/(n-1)) would shrink every coefficient by (n-1)/n.
    Constant inputs give nan (0/0).
    """
    y_c = y - y.mean(-2, keepdim=True)
    y_pred_c = y_pred - y_pred.mean(-2, keepdim=True)
    cov = torch.mean(y_c * y_pred_c, -2)
    std_y = torch.sqrt(torch.mean(y_c ** 2, -2))
    std_pred = torch.sqrt(torch.mean(y_pred_c ** 2, -2))
    return cov / (std_y * std_pred)
|
339
|
+
|