nntrf-1.0.0.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nntrf-1.0.0/LICENSE +21 -0
- nntrf-1.0.0/PKG-INFO +23 -0
- nntrf-1.0.0/README.md +8 -0
- nntrf-1.0.0/nntrf/__init__.py +0 -0
- nntrf-1.0.0/nntrf/loss.py +31 -0
- nntrf-1.0.0/nntrf/metrics.py +35 -0
- nntrf-1.0.0/nntrf/models/__init__.py +3 -0
- nntrf-1.0.0/nntrf/models/composite.py +63 -0
- nntrf-1.0.0/nntrf/models/linear.py +269 -0
- nntrf-1.0.0/nntrf/models/nonlinear.py +1358 -0
- nntrf-1.0.0/nntrf/utils.py +12 -0
- nntrf-1.0.0/nntrf.egg-info/PKG-INFO +23 -0
- nntrf-1.0.0/nntrf.egg-info/SOURCES.txt +17 -0
- nntrf-1.0.0/nntrf.egg-info/dependency_links.txt +1 -0
- nntrf-1.0.0/nntrf.egg-info/requires.txt +4 -0
- nntrf-1.0.0/nntrf.egg-info/top_level.txt +1 -0
- nntrf-1.0.0/setup.cfg +4 -0
- nntrf-1.0.0/setup.py +23 -0
- nntrf-1.0.0/tests/testNonLinear.py +247 -0
nntrf-1.0.0/LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2020 powerfulbean
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
nntrf-1.0.0/PKG-INFO
ADDED
@@ -0,0 +1,23 @@
+Metadata-Version: 2.2
+Name: nntrf
+Version: 1.0.0
+Home-page: https://github.com/powerfulbean/nnTRF
+Author: Jin Dou
+Author-email: jindou.bci@gmail.com
+Classifier: Programming Language :: Python :: 3
+Classifier: License :: OSI Approved :: MIT License
+Classifier: Operating System :: OS Independent
+Requires-Python: >=3.8
+Description-Content-Type: text/markdown
+License-File: LICENSE
+Requires-Dist: numpy>=1.20.1
+Requires-Dist: torch<2.0.0,>=1.12.1
+Requires-Dist: scikit-fda==0.7.1
+Requires-Dist: mtrf
+Dynamic: author
+Dynamic: author-email
+Dynamic: classifier
+Dynamic: description-content-type
+Dynamic: home-page
+Dynamic: requires-dist
+Dynamic: requires-python
nntrf-1.0.0/README.md
ADDED
@@ -0,0 +1,8 @@
+# nnTRF - neural network Temporal Response Function
+
+This package is an artificial neural network implementation for temporal response function (TRF) modelling of brain signals. It implements the linear time-invariant TRF, which can be solved with ridge regression ([mTRF-Toolbox](https://github.com/mickcrosse/mTRF-Toolbox), [mTRFpy](https://github.com/powerfulbean/mTRFpy)) or boosting ([Eelbrain](https://github.com/christianbrodbeck/Eelbrain)), as well as the [dynamic TRF](https://doi.org/10.1101/2024.08.26.609779) framework, and more!
+
+## Installation
+
+coming soon ....
nntrf-1.0.0/nntrf/__init__.py
File without changes
nntrf-1.0.0/nntrf/loss.py
ADDED
@@ -0,0 +1,31 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Thu Dec 10 14:56:39 2020
+
+@author: Jin Dou
+"""
+import torch
+def pearsonCorrLoss(x,y):
+    # x and y are expected to be of shape (nBatch, nTimeSteps, 1)
+    mx = torch.mean(x,1)
+    my = torch.mean(y,1)
+    x = x[:,:,0]
+    y = y[:,:,0]
+    xm, ym = x-mx, y-my
+    r_num = torch.sum(xm * ym,1)
+    r_den = torch.sqrt(torch.sum(xm*xm,1) * torch.sum(ym*ym,1))
+    # print(torch.sum(r_num_sub)/r_den)
+    if (torch.mean(r_num) == 0 and torch.mean(r_den) == 0):
+        raise ValueError("gradient gone\n")
+        r = None
+    else:
+        r = r_num / r_den
+    # return 1 - r**2
+
+    # print(r)
+    r = torch.mean(r)
+    # loss = torch.exp(r)
+    loss = 1 - r
+
+    return loss
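
Usage note (not part of the release): a minimal sketch of how pearsonCorrLoss might be wired into a training step. The stand-in model, optimizer settings, and tensor shapes are assumptions for illustration only; the loss itself expects predictions and targets shaped (nBatch, nTimeSteps, 1).

import torch
from nntrf.loss import pearsonCorrLoss

# stand-in single-channel model; nntrf's own TRF modules would normally go here
model = torch.nn.Linear(16, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

x = torch.randn(4, 128, 16)      # (nBatch, nTimeSteps, nFeatures)
y = torch.randn(4, 128, 1)       # (nBatch, nTimeSteps, 1)

optimizer.zero_grad()
pred = model(x)                  # (4, 128, 1)
loss = pearsonCorrLoss(pred, y)  # 1 - mean Pearson r across the batch
loss.backward()
optimizer.step()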
nntrf-1.0.0/nntrf/metrics.py
ADDED
@@ -0,0 +1,35 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Thu Dec 10 14:55:06 2020
+
+@author: Jin Dou
+"""
+import numpy as np
+# from scipy import stats as spStats
+
+def Pearsonr(x,y):
+    nObs = len(x)
+    sumX = np.sum(x,0)
+    sumY = np.sum(y,0)
+    sdXY = np.sqrt((np.sum(x**2,0) - (sumX**2/nObs)) * (np.sum(y ** 2, 0) - (sumY ** 2)/nObs))
+
+    r = (np.sum(x*y,0) - (sumX * sumY)/nObs) / sdXY
+    return r
+
+#def Pearsonr(x,y):
+##    x = np.squeeze(x)
+##    y = np.squeeze(y)
+##    print(x.shape)
+##    print(y.shape)
+#
+#    out = spStats.pearsonr(x,y)
+##    print('correlation',out)
+#    return out
+
+def BatchPearsonr(pred,y):
+    result = list()
+    for i in range(len(pred)):
+        out1 = Pearsonr(pred[i],y[i])
+        # print(out1)
+        result.append(out1)
+    return np.mean(result,0)
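
Usage note (not part of the release): a quick sanity check of these metrics against scipy.stats.pearsonr. scipy is an assumption here, not a declared dependency of the package; Pearsonr treats its inputs as (nTimeSteps, nChannels).

import numpy as np
from scipy import stats
from nntrf.metrics import Pearsonr, BatchPearsonr

rng = np.random.default_rng(0)
x = rng.standard_normal((256, 2))   # (nTimeSteps, nChannels)
y = rng.standard_normal((256, 2))

r = Pearsonr(x, y)                  # per-channel correlation, shape (2,)
assert np.allclose(r[0], stats.pearsonr(x[:, 0], y[:, 0])[0])

# BatchPearsonr averages Pearsonr over the leading batch dimension
batch_r = BatchPearsonr(x[None], y[None])
assert np.allclose(batch_r, r)      # mean over a batch of one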
nntrf-1.0.0/nntrf/models/composite.py
ADDED
@@ -0,0 +1,63 @@
+import torch
+
+class TwoMixedTRF(torch.nn.Module):
+
+    def __init__(
+        self,
+        device,
+        trfs, #list of trf models
+        feats_keys, #list of feat keys for each trf in the trfs
+    ):
+        super().__init__()
+        self.trfs:torch.nn.ModuleList = torch.nn.ModuleList([trf for trf in trfs])
+        self.feats_keys = feats_keys
+        self.device = device
+
+    def forward(self,feat_dict:dict, y):
+
+        pred_list = []
+        for trf_index, iTRF in enumerate(self.trfs):
+            feats_key = self.feats_keys[trf_index]
+            feats = []
+            n_dict_feat = 0
+            for feat_key in feats_key:
+                # print(feat_dict.keys())
+                feat = feat_dict[feat_key]
+                if isinstance(feat, dict):
+                    feats.append(feat)
+                    n_dict_feat += 1
+                else:
+                    assert isinstance(feat, torch.Tensor)
+                    feats.append(feat)
+            if n_dict_feat > 0:
+                assert len(feats) == n_dict_feat
+                # concatenate
+                if len(feats) == 1:
+                    feats = feats[0]
+                else:
+                    # raise NotImplementedError
+                    timeinfo_0 = feats[0]['timeinfo']
+                    xs = []
+                    for feat in feats:
+                        xs.append(feat['x'])
+                        assert torch.equal(timeinfo_0, feat['timeinfo'])
+                    xs = torch.cat(xs, dim = -2)
+                    feats = {
+                        'x':xs,
+                        'timeinfo':timeinfo_0
+                    }
+                pred_list.append(iTRF(**feats))
+            else:
+                minLen = min([f.shape[-1] for f in feats])
+                feats = torch.cat([f[...,:minLen] for f in feats], axis = -2)
+                pred_list.append(iTRF(feats))
+
+        minLen = min([p.shape[-1] for p in pred_list] + [y.shape[-1]])
+
+        pred_list = [p[...,:minLen] for p in pred_list]
+        cropedY = y[:,:,:minLen]
+        pred = sum(pred_list)
+        return pred,cropedY
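
Usage note (not part of the release): a hypothetical two-branch setup exercising the tensor branch of TwoMixedTRF.forward, built on CNNTRF from linear.py below. The feature names and shapes are invented for the example.

import torch
from nntrf.models.linear import CNNTRF
from nntrf.models.composite import TwoMixedTRF

fs = 64
trf_a = CNNTRF(1, 8, tmin_ms=0, tmax_ms=400, fs=fs)   # envelope branch
trf_b = CNNTRF(2, 8, tmin_ms=0, tmax_ms=400, fs=fs)   # word-feature branch

model = TwoMixedTRF(
    device='cpu',
    trfs=[trf_a, trf_b],
    feats_keys=[['envelope'], ['onset', 'surprisal']],  # one key list per TRF
)

feat_dict = {
    'envelope':  torch.randn(1, 1, 320),   # (nBatch, nChannels, nTimeSteps)
    'onset':     torch.randn(1, 1, 320),
    'surprisal': torch.randn(1, 1, 320),
}
y = torch.randn(1, 8, 320)                 # EEG-like target

pred, cropped_y = model(feat_dict, y)      # branch predictions are summed
print(pred.shape, cropped_y.shape)         # both torch.Size([1, 8, 320])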
nntrf-1.0.0/nntrf/models/linear.py
ADDED
@@ -0,0 +1,269 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Wed Dec 9 10:11:27 2020
+
+@author: Jin Dou
+"""
+
+
+import torch
+import numpy as np
+from ..metrics import Pearsonr, BatchPearsonr
+from ..utils import TensorsToNumpy
+try:
+    from matplotlib import pyplot as plt
+except ImportError:
+    plt = None
+
+def msec2Idxs(msecRange,fs):
+    '''
+    convert a millisecond range to a list of sample indexes
+
+    the left and right endpoints will both be included
+    '''
+    assert len(msecRange) == 2
+
+    tmin = msecRange[0]/1e3
+    tmax = msecRange[1]/1e3
+    return list(range(int(np.floor(tmin*fs)),int(np.ceil(tmax*fs)) + 1))
+
+def Idxs2msec(lags,fs):
+    '''
+    convert a list of sample indexes to a list of millisecond lag times
+
+    the left and right endpoints will both be included
+    '''
+    temp = np.array(lags)
+    return list(temp/fs * 1e3)
+
+class LRTRF(torch.nn.Module):
+    '''
+    the TRF implemented with a linear layer and time-lagged input
+    '''
+    # the shape of the input for forward should be (nBatch,nTimeSteps,nChannels)
+
+    def __init__(self,inDim,outDim,tmin_ms,tmax_ms,fs,bias = True):
+        super().__init__()
+        self.tmin_ms = tmin_ms
+        self.tmax_ms = tmax_ms
+        self.fs = fs
+        self.lagIdxs = msec2Idxs([tmin_ms,tmax_ms],fs)
+        self.lagTimes = Idxs2msec(self.lagIdxs,fs)
+        self.realInDim = len(self.lagIdxs) * inDim
+        self.oDense = torch.nn.Linear(self.realInDim,outDim,bias = bias)
+        self.inDim = inDim
+        self.outDim = outDim
+
+    def timeLagging(self,tensor):
+        x = tensor
+        nBatch = x.shape[0]
+        batchList = []
+        for batchId in range(nBatch):
+            batch = x[batchId:batchId+1]
+            lagDataList = []
+            for idx,lag in enumerate(self.lagIdxs):
+                # we assume the second-to-last dimension indicates time steps
+                if lag < 0:
+                    temp = torch.nn.functional.pad(batch,(0,0,0,-lag))
+                    # lagDataList.append(temp[:,-lag:,:])
+                    lagDataList.append((temp.T)[:,-lag:].T)
+                elif lag > 0:
+                    temp = torch.nn.functional.pad(batch,(0,0,lag,0))
+                    # lagDataList.append(temp[:,0:-lag,:])
+                    lagDataList.append((temp.T)[:,0:-lag].T)
+                else:
+                    lagDataList.append(batch)
+            batchList.append(torch.cat(lagDataList,-1))
+        x3 = torch.cat(batchList,0)
+        return x3
+
+    def forward(self,x):
+        x = self.timeLagging(x)
+        return self.oDense(x)
+
+    @property
+    def weights(self):
+        return self.state_dict()['oDense.weight'].cpu().detach().numpy()
+
+    @property
+    def w(self):
+        '''
+        function reproducing the definition of w in mTRF-Toolbox
+
+        Returns
+        -------
+        w : ndarray of shape (nInChan, nLags, nOutChan)
+
+        '''
+        w = self.oDense.weight.T.cpu().detach()
+        w = w.view(len(self.lagIdxs),self.inDim,self.outDim)
+        w = w.permute(1,0,2)
+        w = w.numpy()
+        return w
+
+    def loadFromMTRFpy(self,w,b,device):
+        #w: (nInChan, nLag, nOutChan)
+        w = w * 1/self.fs
+        b = b * 1/self.fs
+        b = b[0]
+        w = torch.FloatTensor(w).to(device)
+        w = w.permute(1,0,2)
+        w = w.reshape(-1,w.shape[-1]).T
+        b = torch.FloatTensor(b).to(device)
+        with torch.no_grad():
+            self.oDense.weight = torch.nn.Parameter(w)
+            self.oDense.bias = torch.nn.Parameter(b)
+        return self
+
+class CPadOrCrop1D(torch.nn.Module):
+    def __init__(self,tmin_idx,tmax_idx):
+        super().__init__()
+        self.tmin_idx = tmin_idx
+        self.tmax_idx = tmax_idx
+
+    def forward(self,x):
+        # pad or crop at the end of the time axis, then at the start
+        if (self.tmin_idx <= 0):
+            x = torch.nn.functional.pad(x,(0,-self.tmin_idx))
+        else:
+            x = x[:,:,:-self.tmin_idx]
+        if (self.tmax_idx < 0):
+            x = x[:,:,-self.tmax_idx:]
+        else:
+            x = torch.nn.functional.pad(x,(self.tmax_idx,0))
+        return x
+
+
+class CNNTRF(torch.nn.Module):
+    '''
+    the TRF implemented with a convolutional layer,
+    and zero padding of input
+    '''
+
+    # the shape of the input for forward should be (nBatch,nChannels,nTimeSteps)
+    # Be careful with the calculation of correlation when using this model,
+    # because nnTRF.Metrics.Pearsonr treats the input data as having the shape
+    # (nTimeSteps, nChannels)
+    def __init__(self,inDim,outDim,tmin_ms,tmax_ms,fs,groups = 1,enableBN = False, dilation = 1):
+        super().__init__()
+        self.tmin_ms = tmin_ms
+        self.tmax_ms = tmax_ms
+        self.fs = fs
+        self.lagIdxs = msec2Idxs([tmin_ms,tmax_ms],fs)
+        self.lagTimes = Idxs2msec(self.lagIdxs,fs)
+        self.tmin_idx = self.lagIdxs[0]
+        self.tmax_idx = self.lagIdxs[-1]
+        nLags = len(self.lagTimes)
+        nKernels = (nLags - 1) / dilation + 1
+        assert np.ceil(nKernels) == np.floor(nKernels)
+        nKernels = int(nKernels)
+        self.oCNN = torch.nn.Conv1d(
+            inDim,
+            outDim,
+            nKernels,
+            groups = groups,
+            dilation = dilation
+        )
+        self.oPadOrCrop = CPadOrCrop1D(self.tmin_idx,self.tmax_idx)
+        self.groups = groups
+        self.enableBN = enableBN
+        self.oBN = torch.nn.BatchNorm1d(inDim,affine=False,track_running_stats=False)
+        self.dilation = dilation
+        #if both lagMin and lagMax > 0, a more complex operation is needed
+
+    def forward(self,x):
+        if self.enableBN:
+            x = self.oBN(x)
+        x = self.oPadOrCrop(x)
+        x = self.oCNN(x)
+        return x
+
+    @property
+    def weights(self):
+        '''
+        Returns
+        -------
+        Formatted weights, with the timeLag dimension flipped to conform to mTRF:
+        [outChannels, inChannels, timeLags]
+        '''
+        return np.flip(self.state_dict()['oCNN.weight'].cpu().detach().numpy(),axis = -1)
+
+    @property
+    def w(self):
+        '''
+        function reproducing the definition of w in mTRF-Toolbox
+
+        Returns
+        -------
+        w : ndarray of shape (nInChan, nLags, nOutChan)
+
+        '''
+        tensor = self.state_dict()['oCNN.weight']
+        tensor = tensor.permute(1,2,0)
+        return np.flip(tensor.cpu().detach().numpy(),axis = 1)
+
+
+    @property
+    def b(self):
+        return self.oCNN.bias.squeeze().detach().cpu().numpy()
+
+    def loadFromMTRFpy(self,w,b,device):
+        #w: (nInChan, nLag, nOutChan)
+        w = w * 1/self.fs
+        b = b * 1/self.fs
+        b = b[0]
+        w = np.flip(w,axis = 1).copy()
+        w = torch.FloatTensor(w).to(device)
+        w = w.permute(2,0,1)
+        b = torch.FloatTensor(b).to(device)
+        with torch.no_grad():
+            self.oCNN.weight = torch.nn.Parameter(w)
+            self.oCNN.bias = torch.nn.Parameter(b)
+        return self
+
+    @property
+    def t(self):
+        return self.lagTimes
+
+    def BatchPearsonr(self,pred,y):
+        tensors = TensorsToNumpy(pred.transpose(-1,-2),y.transpose(-1,-2))
+        return BatchPearsonr(*tensors)
+
+    @property
+    def readableWeights(self):
+        '''
+        Returns
+        -------
+        Readable formatted weights:
+        [timeLags, inChannels, outChannels]
+        '''
+        return self.weights.T
+
+    def W(self,inIdx=None,outIdx=None,tIdx=None):
+        '''
+        Returns
+        -------
+        readable formatted weights of the selected inChannel/outChannel/lagTime
+        '''
+
+        inIdx = slice(inIdx) if inIdx is None else inIdx
+        outIdx = slice(outIdx) if outIdx is None else outIdx
+        tIdx = slice(tIdx) if tIdx is None else tIdx
+        return self.readableWeights[tIdx,inIdx,outIdx]
+
+    def load(self,path):
+        self.load_state_dict(torch.load(path,map_location='cpu')['state_dict'])
+        self.eval()
+
+
+    def plotWeights(self,outChan = None, inChan = None):
+        fig,ax = plt.subplots()
+        if outChan is None:
+            outChan = slice(outChan)
+        if inChan is None:
+            inChan = slice(inChan)
+        ax.plot(self.t[::self.dilation], self.weights[outChan, inChan].T)
+        return fig, ax
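
Usage note (not part of the release): a short shape sketch contrasting the two linear TRF variants; the sampling rate and lag window are arbitrary example values.

import torch
from nntrf.models.linear import LRTRF, CNNTRF

fs, nIn, nOut = 64, 3, 8
x_time_major = torch.randn(2, 320, nIn)          # LRTRF expects (nBatch, nTimeSteps, nChannels)

lr_trf = LRTRF(nIn, nOut, tmin_ms=-100, tmax_ms=400, fs=fs)
print(lr_trf(x_time_major).shape)                # torch.Size([2, 320, 8])

cnn_trf = CNNTRF(nIn, nOut, tmin_ms=-100, tmax_ms=400, fs=fs)
x_chan_major = x_time_major.transpose(-1, -2)    # CNNTRF expects (nBatch, nChannels, nTimeSteps)
print(cnn_trf(x_chan_major).shape)               # torch.Size([2, 8, 320])

# both expose w in the mTRF-Toolbox layout (nInChan, nLags, nOutChan)
print(lr_trf.w.shape, cnn_trf.w.shape)           # (3, 34, 8) (3, 34, 8)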