nntrf 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nntrf/__init__.py ADDED
File without changes
nntrf/loss.py ADDED
@@ -0,0 +1,31 @@
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Thu Dec 10 14:56:39 2020
4
+
5
+ @author: Jin Dou
6
+ """
7
+ import torch
8
def pearsonCorrLoss(x, y):
    """Return ``1 - mean(r)``, a Pearson-correlation loss between ``x`` and ``y``.

    Only channel 0 of each input is used (``x[:, :, 0]``). Pearson's r is
    computed along dim 1 separately for every index of dim 0, and the loss is
    one minus the average r. Perfectly correlated inputs therefore give 0 and
    perfectly anti-correlated inputs give 2.

    Parameters
    ----------
    x, y : torch.Tensor
        3-D tensors of the same shape. Presumably laid out as
        (batch, time, channel) -- TODO confirm against callers.

    Returns
    -------
    torch.Tensor
        0-dim tensor holding the loss.

    Raises
    ------
    ValueError
        If both the covariance sums and the normalisers average to zero
        (e.g. constant inputs), i.e. no gradient signal is left.
    """
    # Select the channel first, so that the mean removed from each series is
    # the mean of that same series. Averaging over every channel before
    # slicing only broadcast correctly for single-channel input.
    x = x[:, :, 0]
    y = y[:, :, 0]
    xm = x - torch.mean(x, 1, keepdim=True)
    ym = y - torch.mean(y, 1, keepdim=True)

    r_num = torch.sum(xm * ym, 1)
    r_den = torch.sqrt(torch.sum(xm * xm, 1) * torch.sum(ym * ym, 1))
    if torch.mean(r_num) == 0 and torch.mean(r_den) == 0:
        raise ValueError("gradient gone\n")

    r = torch.mean(r_num / r_den)
    return 1 - r
31
+
nntrf/metrics.py ADDED
@@ -0,0 +1,35 @@
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Thu Dec 10 14:55:06 2020
4
+
5
+ @author: Jin Dou
6
+ """
7
+ import numpy as np
8
+ # from scipy import stats as spStats
9
+
10
def Pearsonr(x,y):
    """Return Pearson's correlation coefficient of ``x`` and ``y`` along axis 0.

    Parameters
    ----------
    x, y : array_like
        Arrays of the same shape. Axis 0 holds the observations; any
        remaining axes are correlated independently column by column.

    Returns
    -------
    numpy.ndarray or numpy scalar
        Correlation per column (a scalar for 1-D input). A constant column
        yields ``nan`` because its standard deviation is zero.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    # Centre before correlating. The raw-sum form
    # sum(x**2) - sum(x)**2 / n suffers catastrophic cancellation when the
    # mean is large relative to the spread, producing garbage or NaN.
    xm = x - np.mean(x, 0)
    ym = y - np.mean(y, 0)
    sdXY = np.sqrt(np.sum(xm ** 2, 0) * np.sum(ym ** 2, 0))
    return np.sum(xm * ym, 0) / sdXY
18
+
19
+ #def Pearsonr(x,y):
20
+ ## x = np.squeeze(x)
21
+ ## y = np.squeeze(y)
22
+ ## print(x.shape)
23
+ ## print(y.shape)
24
+ #
25
+ # out = spStats.pearsonr(x,y)
26
+ ## print('correlation',out)
27
+ # return out
28
+
29
def BatchPearsonr(pred,y):
    """Average ``Pearsonr`` over the leading (batch) dimension.

    For every ``i`` in ``range(len(pred))``, ``pred[i]`` is correlated with
    ``y[i]``. The per-item results are then averaged along axis 0.
    """
    perItem = [Pearsonr(pred[i], y[i]) for i in range(len(pred))]
    return np.mean(perItem, 0)
@@ -0,0 +1,3 @@
1
+ from .linear import *
2
+ from .nonlinear import *
3
+ from .composite import *
@@ -0,0 +1,63 @@
1
+ import torch
2
+
3
class TwoMixedTRF(torch.nn.Module):
    """Sum the predictions of several TRF sub-models, each fed its own features.

    Parameters
    ----------
    device :
        Stored as ``self.device``; not used inside this class.
    trfs : iterable of torch.nn.Module
        The TRF models, wrapped in a ``torch.nn.ModuleList``.
    feats_keys : list of list of str
        ``feats_keys[i]`` lists the keys of ``feat_dict`` whose features are
        fed to ``trfs[i]``.
    """

    def __init__(
        self,
        device,
        trfs, #list of trf models
        feats_keys, #list of feat key for each trf in the trfs
    ):
        super().__init__()
        self.trfs: torch.nn.ModuleList = torch.nn.ModuleList(list(trfs))
        self.feats_keys = feats_keys
        self.device = device

    @staticmethod
    def _merge_dict_feats(feats):
        """Combine dict features into a single kwargs dict for one TRF.

        A single feature dict is returned unchanged. Otherwise the ``'x'``
        tensors are concatenated along dim -2, and the shared ``'timeinfo'``
        of the first feature is kept.

        Raises
        ------
        ValueError
            If the features do not all carry an identical ``'timeinfo'``.
        """
        if len(feats) == 1:
            return feats[0]
        timeinfo_0 = feats[0]['timeinfo']
        for feat in feats[1:]:
            # Concatenating features with different timing would silently
            # misalign them, so fail loudly instead.
            if not torch.equal(timeinfo_0, feat['timeinfo']):
                raise ValueError(
                    "all dict features of one TRF must share the same 'timeinfo'"
                )
        xs = torch.cat([feat['x'] for feat in feats], dim=-2)
        return {'x': xs, 'timeinfo': timeinfo_0}

    def forward(self, feat_dict: dict, y):
        """Run every sub-TRF on its features and return the summed prediction.

        Parameters
        ----------
        feat_dict : dict
            Maps feature keys to either tensors, or dicts that are passed to
            the TRF as keyword arguments (``'x'``/``'timeinfo'`` when several
            are merged). The features used by one TRF must all be dicts or
            all be tensors.
        y : torch.Tensor
            Target; indexed as ``y[:, :, :n]``, so it must be 3-D.

        Returns
        -------
        (torch.Tensor, torch.Tensor)
            The summed prediction and ``y``, both cropped along the last axis
            to the shortest length among all predictions and ``y``.

        Raises
        ------
        ValueError
            If merged dict features carry different ``'timeinfo'`` values.
        """
        pred_list = []
        for trf_index, iTRF in enumerate(self.trfs):
            feats = [feat_dict[key] for key in self.feats_keys[trf_index]]
            for feat in feats:
                assert isinstance(feat, (dict, torch.Tensor))
            n_dict_feat = sum(isinstance(feat, dict) for feat in feats)
            if n_dict_feat > 0:
                # dict features cannot be mixed with plain tensors
                assert len(feats) == n_dict_feat
                pred_list.append(iTRF(**self._merge_dict_feats(feats)))
            else:
                # crop to the common length (last axis), then stack along -2
                minLen = min(f.shape[-1] for f in feats)
                merged = torch.cat([f[..., :minLen] for f in feats], dim=-2)
                pred_list.append(iTRF(merged))

        minLen = min([p.shape[-1] for p in pred_list] + [y.shape[-1]])
        cropped = [p[..., :minLen] for p in pred_list]
        return sum(cropped), y[:, :, :minLen]
nntrf/models/linear.py ADDED
@@ -0,0 +1,269 @@
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Wed Dec 9 10:11:27 2020
4
+
5
+ @author: Jin Dou
6
+ """
7
+
8
+
9
+ import torch
10
+ import numpy as np
11
+ from ..metrics import Pearsonr, BatchPearsonr
12
+ from ..utils import TensorsToNumpy
13
+ try:
14
+ from matplotlib import pyplot as plt
15
+ except:
16
+ plt = None
17
+
18
def msec2Idxs(msecRange,fs):
    '''
    Convert a millisecond range to a list of sample indexes.

    Both ends are included. The start is rounded down and the end rounded up
    to whole samples, so the returned lags cover ``msecRange`` completely.

    Parameters
    ----------
    msecRange : sequence of two numbers
        ``(tmin_ms, tmax_ms)`` in milliseconds.
    fs : number
        Sampling rate in Hz.

    Returns
    -------
    list of int
        Consecutive sample indexes from ``floor(tmin)`` to ``ceil(tmax)``.
    '''
    assert len(msecRange) == 2

    # Multiply before dividing, so that integer ms * integer fs stays exact
    # (70 ms at 100 Hz -> 7.0, whereas 0.07 * 100 == 7.000000000000001, which
    # ceil() would turn into a spurious extra lag). Rounding to 9 decimals
    # absorbs residual float noise from non-integer inputs.
    tmin = np.round(msecRange[0] * fs / 1e3, 9)
    tmax = np.round(msecRange[1] * fs / 1e3, 9)
    return list(range(int(np.floor(tmin)), int(np.ceil(tmax)) + 1))
29
+
30
def Idxs2msec(lags,fs):
    '''
    Convert sample indexes (lags) to milliseconds.

    Each index is divided by the sampling rate ``fs`` (Hz) and scaled by
    1e3. The result is a list with one entry per input lag.
    '''
    return list(np.array(lags) / fs * 1e3)
38
+
39
class LRTRF(torch.nn.Module):
    '''
    the TRF implemented with a linear layer and time lag of input

    The input for ``forward`` should be shaped (nBatch, nTimeSteps, nChannels).
    Every lag in ``lagIdxs`` produces a time-shifted, zero-filled copy of the
    input. The copies are concatenated along the channel axis, and a single
    ``torch.nn.Linear`` maps the lagged features to ``outDim`` outputs.
    '''

    def __init__(self,inDim,outDim,tmin_ms,tmax_ms,fs,bias = True):
        super().__init__()
        self.tmin_ms = tmin_ms
        self.tmax_ms = tmax_ms
        self.fs = fs
        # integer sample lags covering [tmin_ms, tmax_ms] and their times in ms
        self.lagIdxs = msec2Idxs([tmin_ms,tmax_ms],fs)
        self.lagTimes = Idxs2msec(self.lagIdxs,fs)
        # one copy of every input channel per lag
        self.realInDim = len(self.lagIdxs) * inDim
        self.oDense = torch.nn.Linear(self.realInDim,outDim,bias = bias)
        self.inDim = inDim
        self.outDim = outDim

    def timeLagging(self,tensor):
        '''
        Build the lagged design tensor.

        The second-to-last dimension is treated as time.
        - Negative lag: the signal is advanced (zeros appended at the end).
        - Positive lag: the signal is delayed (zeros prepended).

        Each shifted copy keeps the original length. The copies are
        concatenated along the last dimension in ``lagIdxs`` order, so feature
        ``k * inDim + c`` is channel ``c`` at lag ``lagIdxs[k]``.
        '''
        lagged = []
        for lag in self.lagIdxs:
            if lag < 0:
                padded = torch.nn.functional.pad(tensor, (0, 0, 0, -lag))
                lagged.append(padded[..., -lag:, :])
            elif lag > 0:
                padded = torch.nn.functional.pad(tensor, (0, 0, lag, 0))
                lagged.append(padded[..., :-lag, :])
            else:
                lagged.append(tensor)
        # Padding and slicing act only on the time axis, so the whole batch is
        # handled at once. Indexing with ``...`` replaces the ``.T`` round
        # trip, which PyTorch deprecates for tensors that are not 2-D.
        return torch.cat(lagged, -1)

    def forward(self,x):
        '''Apply the time lagging, then the linear layer.'''
        x = self.timeLagging(x)
        return self.oDense(x)

    @property
    def weights(self):
        '''Raw ``oDense`` weight as a NumPy array shaped (outDim, realInDim).'''
        return self.state_dict()['oDense.weight'].cpu().detach().numpy()

    @property
    def w(self):
        '''
        Weights in the layout used by mTRF-toolbox.

        Returns
        -------
        numpy.ndarray
            Array shaped (inDim, nLags, outDim).
        '''
        w = self.oDense.weight.T.cpu().detach()
        w = w.view(len(self.lagIdxs),self.inDim,self.outDim)
        w = w.permute(1,0,2)
        w = w.numpy()
        return w

    def loadFromMTRFpy(self,w,b,device):
        '''
        Load weights exported by mTRF-py into ``oDense``.

        Parameters
        ----------
        w : numpy.ndarray
            Weights shaped (nInChan, nLag, nOutChan).
        b : numpy.ndarray
            Bias; only ``b[0]`` is used (presumably shaped (1, nOutChan) --
            TODO confirm).
        device : torch.device or str
            Device on which the new parameters are created.

        Returns
        -------
        LRTRF
            ``self``, to allow chaining.
        '''
        # both weights and bias are scaled by 1/fs before loading
        w = w * 1/ self.fs
        b = b * 1/self.fs
        b = b[0]
        w = torch.FloatTensor(w).to(device)
        # (nIn, nLag, nOut) -> (nLag, nIn, nOut) -> (nLag*nIn, nOut) -> (nOut, nLag*nIn),
        # matching the feature order produced by timeLagging
        w = w.permute(1,0,2)
        w = w.reshape(-1,w.shape[-1]).T
        b = torch.FloatTensor(b).to(device)
        with torch.no_grad():
            self.oDense.weight = torch.nn.Parameter(w)
            self.oDense.bias = torch.nn.Parameter(b)
        return self
118
+
119
class CPadOrCrop1D(torch.nn.Module):
    '''
    Zero-pad or crop the last axis of a 3-D tensor.

    ``tmin_idx`` adjusts the end of the sequence:
    - non-positive: append ``-tmin_idx`` zeros;
    - positive: drop the last ``tmin_idx`` samples.

    ``tmax_idx`` adjusts the start of the sequence:
    - non-negative: prepend ``tmax_idx`` zeros;
    - negative: drop the first ``-tmax_idx`` samples.
    '''
    def __init__(self,tmin_idx,tmax_idx):
        super().__init__()
        self.tmin_idx = tmin_idx
        self.tmax_idx = tmax_idx

    def forward(self,x):
        '''Adjust the end (from ``tmin_idx``), then the start (from ``tmax_idx``).'''
        pad = torch.nn.functional.pad
        tail = self.tmin_idx
        head = self.tmax_idx
        # end of the sequence
        x = x[:, :, :-tail] if tail > 0 else pad(x, (0, -tail))
        # start of the sequence
        x = pad(x, (head, 0)) if head >= 0 else x[:, :, -head:]
        return x
136
+
137
+
138
class CNNTRF(torch.nn.Module):
    '''
    the TRF implemented with a convolutional layer,
    and zero padding of input

    The input for ``forward`` should be shaped (nBatch, nChannels, nTimeSteps).
    Be careful when computing correlations with this model: the metrics
    ``Pearsonr`` treats its input as (nTimeSteps, nChannels). Use the
    ``BatchPearsonr`` method, which swaps the last two axes first.
    '''

    def __init__(self,inDim,outDim,tmin_ms,tmax_ms,fs,groups = 1,enableBN = False, dilation = 1):
        super().__init__()
        self.tmin_ms = tmin_ms
        self.tmax_ms = tmax_ms
        self.fs = fs
        self.lagIdxs = msec2Idxs([tmin_ms,tmax_ms],fs)
        self.lagTimes = Idxs2msec(self.lagIdxs,fs)
        self.tmin_idx = self.lagIdxs[0]
        self.tmax_idx = self.lagIdxs[-1]
        nLags = len(self.lagTimes)
        # A kernel of size k with this dilation spans (k - 1) * dilation + 1
        # lags, so the lag count must be reachable exactly.
        nKernels = (nLags - 1) / dilation + 1
        assert np.ceil(nKernels) == np.floor(nKernels)
        nKernels = int(nKernels)
        self.oCNN = torch.nn.Conv1d(
            inDim,
            outDim,
            nKernels,
            groups = groups,
            dilation = dilation
        )
        # Pad/crop so that the convolution output has as many time steps as
        # the input: after it the length is T + tmax_idx - tmin_idx, which the
        # kernel spanning tmax_idx - tmin_idx + 1 lags reduces back to T.
        self.oPadOrCrop = CPadOrCrop1D(self.tmin_idx,self.tmax_idx)
        self.groups = groups
        self.enableBN = enableBN
        # always created; only applied in forward when enableBN is set
        self.oBN = torch.nn.BatchNorm1d(inDim,affine=False,track_running_stats=False)
        self.dilation = dilation
        #if both lagMin and lagMax > 0, more complex operation

    def forward(self,x):
        '''Optionally batch-normalise, pad/crop the time axis, then convolve.'''
        if self.enableBN:
            x = self.oBN(x)
        x = self.oPadOrCrop(x)
        x = self.oCNN(x)
        return x

    @property
    def weights(self):
        '''
        Returns
        -------
        numpy.ndarray
            Convolution weights with the kernel (time-lag) axis flipped to
            conform to mTRF, shaped [outChannels, inChannels // groups, nKernels].
        '''
        return np.flip(self.state_dict()['oCNN.weight'].cpu().detach().numpy(),axis = -1)

    @property
    def w(self):
        '''
        Weights in the layout used by mTRF-toolbox.

        Returns
        -------
        numpy.ndarray
            Array shaped [inChannels // groups, nKernels, outChannels], with
            the kernel axis flipped.
        '''
        tensor = self.state_dict()['oCNN.weight']
        tensor = tensor.permute(1,2,0)
        return np.flip(tensor.cpu().detach().numpy(),axis = 1)

    @property
    def b(self):
        '''Convolution bias as a NumPy array, with size-1 axes squeezed.'''
        return self.oCNN.bias.squeeze().detach().cpu().numpy()

    def loadFromMTRFpy(self,w,b,device):
        '''
        Load weights exported by mTRF-py into ``oCNN``.

        Parameters
        ----------
        w : numpy.ndarray
            Weights shaped (nInChan, nLag, nOutChan).
        b : numpy.ndarray
            Bias; only ``b[0]`` is used (presumably shaped (1, nOutChan) --
            TODO confirm).
        device : torch.device or str
            Device on which the new parameters are created.

        Returns
        -------
        CNNTRF
            ``self``, to allow chaining.
        '''
        # both weights and bias are scaled by 1/fs before loading
        w = w * 1/self.fs
        b = b * 1/self.fs
        b = b[0]
        # flip the lag axis back into convolution order, then
        # (nIn, nLag, nOut) -> (nOut, nIn, nLag) as Conv1d expects
        w = np.flip(w,axis = 1).copy()
        w = torch.FloatTensor(w).to(device)
        w = w.permute(2,0,1)
        b = torch.FloatTensor(b).to(device)
        with torch.no_grad():
            self.oCNN.weight = torch.nn.Parameter(w)
            self.oCNN.bias = torch.nn.Parameter(b)
        return self

    @property
    def t(self):
        '''Lag times in milliseconds, one per lag in ``lagIdxs``.'''
        return self.lagTimes

    def BatchPearsonr(self,pred,y):
        '''
        Batch Pearson correlation for (nBatch, nChannels, nTimeSteps) tensors.

        The last two axes are swapped before the data is handed to the
        module-level ``BatchPearsonr``, which expects time on the first axis
        of each batch item.
        '''
        tensors = TensorsToNumpy(pred.transpose(-1,-2),y.transpose(-1,-2))
        return BatchPearsonr(*tensors)

    @property
    def readableWeights(self):
        '''
        Returns
        -------
        numpy.ndarray
            ``weights`` transposed to [nKernels, inChannels // groups, outChannels].
        '''
        return self.weights.T

    def W(self,inIdx=None,outIdx=None,tIdx=None):
        '''
        Select from ``readableWeights``.

        Parameters
        ----------
        inIdx, outIdx, tIdx : int, slice, array-like or None
            Index into the input-channel, output-channel and lag axes;
            ``None`` keeps the whole axis.

        Returns
        -------
        numpy.ndarray
            ``readableWeights[tIdx, inIdx, outIdx]``.
        '''
        everything = slice(None)
        inIdx = everything if inIdx is None else inIdx
        outIdx = everything if outIdx is None else outIdx
        tIdx = everything if tIdx is None else tIdx
        return self.readableWeights[tIdx,inIdx,outIdx]

    def load(self,path):
        '''Load the checkpoint's ``'state_dict'`` entry (mapped to CPU) and switch to eval mode.'''
        self.load_state_dict(torch.load(path,map_location='cpu')['state_dict'])
        self.eval()

    def plotWeights(self,outChan = None, inChan = None):
        '''
        Plot the kernel weights against their lag times.

        Parameters
        ----------
        outChan, inChan : int, slice or None
            Output / input channel selection; ``None`` selects all.

        Returns
        -------
        (matplotlib.figure.Figure, matplotlib.axes.Axes)

        Raises
        ------
        ImportError
            If matplotlib could not be imported (``plt`` is ``None``).
        '''
        # The module falls back to plt = None when matplotlib is missing;
        # fail with a clear message instead of an AttributeError on None.
        if plt is None:
            raise ImportError("matplotlib is required for CNNTRF.plotWeights")
        if outChan is None:
            outChan = slice(None)
        if inChan is None:
            inChan = slice(None)
        fig,ax = plt.subplots()
        # with dilation, every dilation-th lag time corresponds to one kernel tap
        ax.plot(self.t[::self.dilation], self.weights[outChan, inChan].T)
        return fig, ax
268
+
269
+