nntrf 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nntrf/__init__.py +0 -0
- nntrf/loss.py +31 -0
- nntrf/metrics.py +35 -0
- nntrf/models/__init__.py +3 -0
- nntrf/models/composite.py +63 -0
- nntrf/models/linear.py +269 -0
- nntrf/models/nonlinear.py +1358 -0
- nntrf/utils.py +12 -0
- nntrf-1.0.0.dist-info/LICENSE +21 -0
- nntrf-1.0.0.dist-info/METADATA +23 -0
- nntrf-1.0.0.dist-info/RECORD +13 -0
- nntrf-1.0.0.dist-info/WHEEL +5 -0
- nntrf-1.0.0.dist-info/top_level.txt +1 -0
nntrf/__init__.py
ADDED
File without changes
|
nntrf/loss.py
ADDED
@@ -0,0 +1,31 @@
|
|
1
|
+
# -*- coding: utf-8 -*-
|
2
|
+
"""
|
3
|
+
Created on Thu Dec 10 14:56:39 2020
|
4
|
+
|
5
|
+
@author: Jin Dou
|
6
|
+
"""
|
7
|
+
import torch
|
8
|
+
def pearsonCorrLoss(x,y):
    """
    Pearson-correlation loss: ``1 - mean(r)`` where ``r`` is computed per
    batch item along dimension 1.

    Only channel 0 (index 0 of the last dimension) of each input is used.
    The channel is selected *before* the mean over dimension 1 is taken,
    so inputs with more than one channel no longer break the broadcast in
    the centering step.

    Parameters
    ----------
    x, y : torch.Tensor
        3-D tensors indexed as ``[batch, step, channel]``.

    Returns
    -------
    torch.Tensor
        Scalar tensor ``1 - mean(r)``.

    Raises
    ------
    ValueError
        If both the batch-averaged numerator and denominator are exactly
        zero (e.g. constant inputs), in which case no useful gradient
        exists.
    """
    x = x[:, :, 0]
    y = y[:, :, 0]
    # keepdim so the per-item mean broadcasts against [batch, step]
    xm = x - torch.mean(x, 1, keepdim=True)
    ym = y - torch.mean(y, 1, keepdim=True)
    r_num = torch.sum(xm * ym, 1)
    r_den = torch.sqrt(torch.sum(xm * xm, 1) * torch.sum(ym * ym, 1))
    if torch.mean(r_num) == 0 and torch.mean(r_den) == 0:
        raise ValueError("gradient gone\n")
    r = r_num / r_den
    return 1 - torch.mean(r)
|
31
|
+
|
nntrf/metrics.py
ADDED
@@ -0,0 +1,35 @@
|
|
1
|
+
# -*- coding: utf-8 -*-
|
2
|
+
"""
|
3
|
+
Created on Thu Dec 10 14:55:06 2020
|
4
|
+
|
5
|
+
@author: Jin Dou
|
6
|
+
"""
|
7
|
+
import numpy as np
|
8
|
+
# from scipy import stats as spStats
|
9
|
+
|
10
|
+
def Pearsonr(x,y):
    """
    Pearson correlation between ``x`` and ``y`` computed along axis 0.

    Uses the raw-sum formulation: the centered cross-product divided by
    the square root of the product of the two centered sums of squares.
    ``len(x)`` is taken as the number of observations.

    Returns whatever shape NumPy yields after reducing axis 0 (one value
    per remaining column for 2-D inputs).
    """
    n_obs = len(x)
    total_x = np.sum(x, 0)
    total_y = np.sum(y, 0)

    # centered sums of squares for each input
    ss_x = np.sum(x**2, 0) - (total_x**2 / n_obs)
    ss_y = np.sum(y ** 2, 0) - (total_y ** 2) / n_obs

    # centered cross-product
    cross = np.sum(x * y, 0) - (total_x * total_y) / n_obs
    return cross / np.sqrt(ss_x * ss_y)
|
18
|
+
|
19
|
+
#def Pearsonr(x,y):
|
20
|
+
## x = np.squeeze(x)
|
21
|
+
## y = np.squeeze(y)
|
22
|
+
## print(x.shape)
|
23
|
+
## print(y.shape)
|
24
|
+
#
|
25
|
+
# out = spStats.pearsonr(x,y)
|
26
|
+
## print('correlation',out)
|
27
|
+
# return out
|
28
|
+
|
29
|
+
def BatchPearsonr(pred,y):
    """
    Average ``Pearsonr`` over the leading (batch) dimension.

    ``Pearsonr(pred[i], y[i])`` is evaluated for every index ``i`` in
    ``range(len(pred))`` and the results are averaged along axis 0.
    """
    per_item = [Pearsonr(pred[idx], y[idx]) for idx in range(len(pred))]
    return np.mean(per_item, 0)
|
nntrf/models/__init__.py
ADDED
@@ -0,0 +1,63 @@
|
|
1
|
+
import torch
|
2
|
+
|
3
|
+
class TwoMixedTRF(torch.nn.Module):
    '''
    Sum the predictions of several TRF sub-models, each fed with its own
    selection of features.

    For every sub-model the features named in the matching entry of
    ``feats_keys`` are looked up in the ``feat_dict`` passed to
    ``forward``:

    * tensor features are cropped to the shortest last dimension,
      concatenated along dimension -2 and passed positionally;
    * dict features (with keys ``'x'`` and ``'timeinfo'``) are merged by
      concatenating their ``'x'`` along dimension -2 and passed to the
      sub-model as keyword arguments. Dict and tensor features cannot be
      mixed for one sub-model.
    '''

    def __init__(
        self,
        device,
        trfs, #list of trf models
        feats_keys, #list of feat key for each trf in the trfs
    ):
        super().__init__()
        self.trfs:torch.nn.Module = torch.nn.ModuleList([trf for trf in trfs])
        self.feats_keys = feats_keys
        self.device = device

    def forward(self,feat_dict:dict, y):
        '''
        Returns
        -------
        (pred, cropedY): the summed predictions and ``y``, both cropped on
        the last dimension to the shortest length among the predictions
        and ``y``.

        Raises
        ------
        ValueError
            If several dict features of one sub-model carry different
            ``'timeinfo'`` tensors.
        '''
        pred_list = []
        for trf_index, iTRF in enumerate(self.trfs):
            feats = [feat_dict[key] for key in self.feats_keys[trf_index]]
            n_dict_feat = sum(isinstance(feat, dict) for feat in feats)

            if n_dict_feat > 0:
                # dict features may not be mixed with tensor features
                assert len(feats) == n_dict_feat
                if len(feats) == 1:
                    merged = feats[0]
                else:
                    timeinfo_0 = feats[0]['timeinfo']
                    xs = []
                    for feat in feats:
                        xs.append(feat['x'])
                        # the original computed this comparison but dropped
                        # the result; enforce it so mismatched time info
                        # cannot be merged silently
                        if not torch.equal(timeinfo_0, feat['timeinfo']):
                            raise ValueError(
                                "dict features of one TRF must share the same 'timeinfo'"
                            )
                    merged = {
                        'x': torch.cat(xs, dim = -2),
                        'timeinfo': timeinfo_0,
                    }
                pred_list.append(iTRF(**merged))
            else:
                for feat in feats:
                    assert isinstance(feat, torch.Tensor)
                minLen = min([f.shape[-1] for f in feats])
                stacked = torch.cat([f[...,:minLen] for f in feats], dim = -2)
                pred_list.append(iTRF(stacked))

        # align every prediction and the target on the shortest length
        minLen = min([p.shape[-1] for p in pred_list] + [y.shape[-1]])
        pred_list = [p[...,:minLen] for p in pred_list]
        cropedY = y[:,:,:minLen]
        pred = sum(pred_list)
        return pred,cropedY
|
nntrf/models/linear.py
ADDED
@@ -0,0 +1,269 @@
|
|
1
|
+
# -*- coding: utf-8 -*-
|
2
|
+
"""
|
3
|
+
Created on Wed Dec 9 10:11:27 2020
|
4
|
+
|
5
|
+
@author: Jin Dou
|
6
|
+
"""
|
7
|
+
|
8
|
+
|
9
|
+
import torch
|
10
|
+
import numpy as np
|
11
|
+
from ..metrics import Pearsonr, BatchPearsonr
|
12
|
+
from ..utils import TensorsToNumpy
|
13
|
+
try:
|
14
|
+
from matplotlib import pyplot as plt
|
15
|
+
except:
|
16
|
+
plt = None
|
17
|
+
|
18
|
+
def msec2Idxs(msecRange,fs):
    '''
    Turn a ``[start_ms, stop_ms]`` pair into the list of consecutive
    sample indexes it spans at sampling rate ``fs``.

    The start is rounded down and the stop is rounded up, and both end
    points are part of the returned list.
    '''
    assert len(msecRange) == 2

    start_sec = msecRange[0] / 1e3
    stop_sec = msecRange[1] / 1e3
    first_idx = int(np.floor(start_sec * fs))
    last_idx = int(np.ceil(stop_sec * fs))
    return list(range(first_idx, last_idx + 1))
|
29
|
+
|
30
|
+
def Idxs2msec(lags,fs):
    '''
    Convert sample indexes to times in milliseconds.

    Each index is divided by the sampling rate ``fs`` and scaled by 1e3;
    the result is returned as a list with one entry per input index.
    '''
    lag_array = np.array(lags)
    msec = lag_array / fs * 1e3
    return list(msec)
|
38
|
+
|
39
|
+
class LRTRF(torch.nn.Module):
    '''
    the TRF implemented with a linear layer and time lag of input

    The input of ``forward`` is expected as (nBatch, nTimeSteps, nChannels).
    Every lag in ``lagIdxs`` yields a zero-padded, time-shifted copy of the
    input; the copies are concatenated on the channel axis and fed to a
    single ``torch.nn.Linear`` layer.
    '''

    def __init__(self,inDim,outDim,tmin_ms,tmax_ms,fs,bias = True):
        super().__init__()
        self.tmin_ms = tmin_ms
        self.tmax_ms = tmax_ms
        self.fs = fs
        self.lagIdxs = msec2Idxs([tmin_ms,tmax_ms],fs)
        self.lagTimes = Idxs2msec(self.lagIdxs,fs)
        # one input column per (lag, input channel) pair
        self.realInDim = len(self.lagIdxs) * inDim
        self.oDense = torch.nn.Linear(self.realInDim,outDim,bias = bias)
        self.inDim = inDim
        self.outDim = outDim

    def timeLagging(self,tensor):
        '''
        Build the lagged design tensor.

        The second-to-last dimension is treated as time. For a negative
        lag the signal is shifted towards earlier steps (zeros appended at
        the end); for a positive lag it is shifted towards later steps
        (zeros prepended). The shifted copies are concatenated on the last
        dimension in the order of ``self.lagIdxs``.

        The whole batch is processed at once, and the time axis is sliced
        explicitly instead of going through ``Tensor.T``, which PyTorch
        deprecates for tensors that are not 2-D.
        '''
        pad = torch.nn.functional.pad
        lagged = []
        for lag in self.lagIdxs:
            if lag < 0:
                shifted = pad(tensor,(0,0,0,-lag))[..., -lag:, :]
            elif lag > 0:
                shifted = pad(tensor,(0,0,lag,0))[..., :-lag, :]
            else:
                shifted = tensor
            lagged.append(shifted)
        return torch.cat(lagged,-1)

    def forward(self,x):
        x = self.timeLagging(x)
        return self.oDense(x)

    @property
    def weights(self):
        '''Raw weight matrix of the linear layer as a numpy array.'''
        return self.state_dict()['oDense.weight'].cpu().detach().numpy()

    @property
    def w(self):
        '''
        function reproducing the definition of w in mTRF-toolbox

        Returns
        -------
        numpy array of shape (inDim, nLags, outDim)
        '''
        w = self.oDense.weight.T.cpu().detach()
        # columns of the linear layer are ordered lag-major, channel-minor
        w = w.reshape(len(self.lagIdxs),self.inDim,self.outDim)
        w = w.permute(1,0,2)
        return w.numpy()

    def loadFromMTRFpy(self,w,b,device):
        '''
        Replace the linear layer's parameters with the given weights.

        w: (nInChan, nLag, nOutChan); b: only its first entry ``b[0]`` is
        used as the bias vector. Both are divided by ``self.fs`` before
        being stored. Returns ``self``.
        '''
        w = w * 1/ self.fs
        b = b * 1/self.fs
        b = b[0]
        w = torch.FloatTensor(w).to(device)
        # (nInChan, nLag, nOutChan) -> (nLag, nInChan, nOutChan) -> (nOutChan, nLag * nInChan)
        w = w.permute(1,0,2)
        w = w.reshape(-1,w.shape[-1]).T
        b = torch.FloatTensor(b).to(device)
        with torch.no_grad():
            self.oDense.weight = torch.nn.Parameter(w)
            self.oDense.bias = torch.nn.Parameter(b)
        return self
|
118
|
+
|
119
|
+
class CPadOrCrop1D(torch.nn.Module):
    '''
    Zero-pad or crop the last dimension of a 3-D tensor according to a
    pair of lag indexes.

    * ``tmin_idx <= 0``: append ``-tmin_idx`` zeros at the end;
      otherwise drop the last ``tmin_idx`` steps.
    * ``tmax_idx >= 0``: prepend ``tmax_idx`` zeros at the start;
      otherwise drop the first ``-tmax_idx`` steps.
    '''

    def __init__(self,tmin_idx,tmax_idx):
        super().__init__()
        self.tmin_idx = tmin_idx
        self.tmax_idx = tmax_idx

    def forward(self,x):
        pad = torch.nn.functional.pad
        lo, hi = self.tmin_idx, self.tmax_idx

        # trailing edge first (governed by the smallest lag) ...
        x = x[:,:,:-lo] if lo > 0 else pad(x,(0,-lo))
        # ... then the leading edge (governed by the largest lag)
        x = pad(x,(hi,0)) if hi >= 0 else x[:,:,-hi:]
        return x
|
136
|
+
|
137
|
+
|
138
|
+
class CNNTRF(torch.nn.Module):
    '''
    the TRF implemented with a convolutional layer,
    and zero padding of input

    The shape of the input for the forward should be
    (nBatch, nChannels, nTimeSteps).
    Be careful when calculating the correlation with this model, because
    the Pearsonr metric treats its input as (nTimeSteps, nChannels); the
    ``BatchPearsonr`` method of this class transposes the last two
    dimensions before delegating to the metric.
    '''

    def __init__(self,inDim,outDim,tmin_ms,tmax_ms,fs,groups = 1,enableBN = False, dilation = 1):
        super().__init__()
        self.tmin_ms = tmin_ms
        self.tmax_ms = tmax_ms
        self.fs = fs
        self.lagIdxs = msec2Idxs([tmin_ms,tmax_ms],fs)
        self.lagTimes = Idxs2msec(self.lagIdxs,fs)
        self.tmin_idx = self.lagIdxs[0]
        self.tmax_idx = self.lagIdxs[-1]
        nLags = len(self.lagTimes)
        # the dilated kernel must span the lag window exactly
        nKernels = (nLags - 1) / dilation + 1
        assert np.ceil(nKernels) == np.floor(nKernels)
        nKernels = int(nKernels)
        self.oCNN = torch.nn.Conv1d(
            inDim,
            outDim,
            nKernels,
            groups = groups,
            dilation = dilation
        )
        self.oPadOrCrop = CPadOrCrop1D(self.tmin_idx,self.tmax_idx)
        self.groups = groups
        self.enableBN = enableBN
        self.oBN = torch.nn.BatchNorm1d(inDim,affine=False,track_running_stats=False)
        self.dilation = dilation
        #if both lagMin and lagMax > 0, more complex operation

    def forward(self,x):
        if self.enableBN:
            x = self.oBN(x)
        x = self.oPadOrCrop(x)
        x = self.oCNN(x)
        return x

    @property
    def weights(self):
        '''
        Returns
        -------
        Formatted weights, with timeLag dimension flipped to conform to mTRF
        [outChannels, inChannels, timeLags]
        '''
        return np.flip(self.state_dict()['oCNN.weight'].cpu().detach().numpy(),axis = -1)

    @property
    def w(self):
        '''
        function reproducing the definition of w in mTRF-toolbox

        Returns
        -------
        numpy array laid out as [inChannels, timeLags, outChannels], with
        the timeLag axis flipped relative to the stored kernel.
        '''
        tensor = self.state_dict()['oCNN.weight']
        tensor = tensor.permute(1,2,0)
        return np.flip(tensor.cpu().detach().numpy(),axis = 1)

    @property
    def b(self):
        '''Bias of the convolution as a numpy array.'''
        return self.oCNN.bias.squeeze().detach().cpu().numpy()

    def loadFromMTRFpy(self,w,b,device):
        '''
        Replace the convolution's parameters with the given weights.

        w: (nInChan, nLag, nOutChan); b: only its first entry ``b[0]`` is
        used as the bias vector. Both are divided by ``self.fs``; the lag
        axis of ``w`` is flipped before it is stored as the kernel.
        Returns ``self``.
        '''
        w = w * 1/self.fs
        b = b * 1/self.fs
        b = b[0]
        # copy(): torch cannot wrap the negative-stride view made by np.flip
        w = np.flip(w,axis = 1).copy()
        w = torch.FloatTensor(w).to(device)
        w = w.permute(2,0,1)
        b = torch.FloatTensor(b).to(device)
        with torch.no_grad():
            self.oCNN.weight = torch.nn.Parameter(w)
            self.oCNN.bias = torch.nn.Parameter(b)
        return self

    @property
    def t(self):
        '''Lag times in milliseconds, as computed at construction.'''
        return self.lagTimes

    def BatchPearsonr(self,pred,y):
        '''
        Batch-averaged Pearson correlation of ``pred`` against ``y``; the
        last two dimensions of both are swapped before they are converted
        and handed to the module-level ``BatchPearsonr``.
        '''
        tensors = TensorsToNumpy(pred.transpose(-1,-2),y.transpose(-1,-2))
        return BatchPearsonr(*tensors)

    @property
    def readableWeights(self):
        '''
        Returns
        -------
        Readable formatted weights
        [timeLags, inChannels, outChannels]
        '''
        return self.weights.T

    def W(self,inIdx=None,outIdx=None,tIdx=None):
        '''
        Returns
        -------
        readable formatted weights of selected inChannel/outChannel/lagTime
        (an index left as None selects the whole axis)
        '''
        inIdx = slice(None) if inIdx is None else inIdx
        outIdx = slice(None) if outIdx is None else outIdx
        tIdx = slice(None) if tIdx is None else tIdx
        return self.readableWeights[tIdx,inIdx,outIdx]

    def load(self,path):
        '''
        Load the ``'state_dict'`` entry of the checkpoint at ``path``
        (mapped to CPU) and switch the module to eval mode.
        '''
        self.load_state_dict(torch.load(path,map_location='cpu')['state_dict'])
        self.eval()

    def plotWeights(self,outChan = None, inChan = None):
        '''
        Plot the weights of the selected output/input channels against the
        lag times (every ``dilation``-th lag) and return ``(fig, ax)``.

        Raises
        ------
        ImportError
            If matplotlib could not be imported when this module was loaded.
        '''
        if plt is None:
            # matplotlib is optional at import time; fail with a clear
            # message instead of an AttributeError on None
            raise ImportError("matplotlib is required for plotWeights")
        fig,ax = plt.subplots()
        if outChan is None:
            outChan = slice(None)
        if inChan is None:
            inChan = slice(None)
        ax.plot(self.t[::self.dilation], self.weights[outChan, inChan].T)
        return fig, ax
|
268
|
+
|
269
|
+
|