nntrf 1.0.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nntrf-1.0.0/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2020 powerfulbean
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
nntrf-1.0.0/PKG-INFO ADDED
@@ -0,0 +1,23 @@
1
+ Metadata-Version: 2.2
2
+ Name: nntrf
3
+ Version: 1.0.0
4
+ Home-page: https://github.com/powerfulbean/nnTRF
5
+ Author: Jin Dou
6
+ Author-email: jindou.bci@gmail.com
7
+ Classifier: Programming Language :: Python :: 3
8
+ Classifier: License :: OSI Approved :: MIT License
9
+ Classifier: Operating System :: OS Independent
10
+ Requires-Python: >=3.8
11
+ Description-Content-Type: text/markdown
12
+ License-File: LICENSE
13
+ Requires-Dist: numpy>=1.20.1
14
+ Requires-Dist: torch<2.0.0,>=1.12.1
15
+ Requires-Dist: scikit-fda==0.7.1
16
+ Requires-Dist: mtrf
17
+ Dynamic: author
18
+ Dynamic: author-email
19
+ Dynamic: classifier
20
+ Dynamic: description-content-type
21
+ Dynamic: home-page
22
+ Dynamic: requires-dist
23
+ Dynamic: requires-python
nntrf-1.0.0/README.md ADDED
@@ -0,0 +1,8 @@
1
+ # nnTRF - neural network Temporal Response Function
2
+
3
+ This package is an artificial neural network implementation of temporal response function (TRF) modelling for brain signals. It implements the linear time-invariant TRF, which can be solved with ridge regression ([mTRF-Toolbox](https://github.com/mickcrosse/mTRF-Toolbox), [mTRFpy](https://github.com/powerfulbean/mTRFpy)) or boosting ([Eelbrain
4
+ ](https://github.com/christianbrodbeck/Eelbrain)), as well as the [dynamic TRF](https://doi.org/10.1101/2024.08.26.609779) framework, and more!
5
+
6
+ ## Installation
7
+
8
+ coming soon ....
File without changes
@@ -0,0 +1,31 @@
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Thu Dec 10 14:56:39 2020
4
+
5
+ @author: Jin Dou
6
+ """
7
+ import torch
8
def pearsonCorrLoss(x,y):
    """
    Negative-correlation loss: ``1 - mean Pearson r`` between ``x`` and ``y``.

    Only index 0 of the last dimension is used. The correlation is computed
    along dim 1 for every entry of dim 0, and the resulting coefficients are
    averaged. (Presumably x and y are (nBatch, nTimeSteps, nChannels) tensors
    of matching shape -- TODO confirm against callers.)

    Parameters
    ----------
    x, y : torch.Tensor
        3-D tensors; ``x[:, :, 0]`` is correlated with ``y[:, :, 0]``.

    Returns
    -------
    torch.Tensor
        Scalar ``1 - mean(r)``: 0 for perfect positive correlation, 2 for
        perfect negative correlation.

    Raises
    ------
    ValueError
        If both the mean covariance term and the mean normaliser are zero
        (e.g. constant inputs), so no useful gradient exists.
    """
    # Slice to the first channel *before* centring. Taking the mean over dim 1
    # of the full 3-D tensor yields one mean per channel, which only broadcasts
    # against the sliced 2-D tensor when there is a single channel.
    x = x[:, :, 0]
    y = y[:, :, 0]
    xm = x - torch.mean(x, 1, keepdim=True)
    ym = y - torch.mean(y, 1, keepdim=True)
    r_num = torch.sum(xm * ym, 1)
    r_den = torch.sqrt(torch.sum(xm * xm, 1) * torch.sum(ym * ym, 1))
    if torch.mean(r_num) == 0 and torch.mean(r_den) == 0:
        raise ValueError("gradient gone\n")
    # per-row correlation, then averaged over dim 0
    r = torch.mean(r_num / r_den)
    return 1 - r
31
+
@@ -0,0 +1,35 @@
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Thu Dec 10 14:55:06 2020
4
+
5
+ @author: Jin Dou
6
+ """
7
+ import numpy as np
8
+ # from scipy import stats as spStats
9
+
10
def Pearsonr(x,y):
    """
    Pearson correlation coefficient computed column-wise along axis 0.

    Parameters
    ----------
    x, y : array_like
        Arrays whose axis 0 indexes observations. Any remaining axes
        (e.g. channels) are correlated independently and must broadcast
        against each other.

    Returns
    -------
    numpy.ndarray or numpy scalar
        One coefficient per remaining position; NaN where a column has
        zero variance (0 / 0).
    """
    x = np.asarray(x)
    y = np.asarray(y)
    # Centre first. The textbook sum(x*y) - sum(x)*sum(y)/n form subtracts two
    # nearly equal large numbers when the data carry a large offset, which
    # destroys precision; working with deviations avoids that cancellation.
    xm = x - np.mean(x, 0)
    ym = y - np.mean(y, 0)
    sdXY = np.sqrt(np.sum(xm ** 2, 0) * np.sum(ym ** 2, 0))
    r = np.sum(xm * ym, 0) / sdXY
    return r
18
+
19
+ #def Pearsonr(x,y):
20
+ ## x = np.squeeze(x)
21
+ ## y = np.squeeze(y)
22
+ ## print(x.shape)
23
+ ## print(y.shape)
24
+ #
25
+ # out = spStats.pearsonr(x,y)
26
+ ## print('correlation',out)
27
+ # return out
28
+
29
def BatchPearsonr(pred,y):
    """
    Average ``Pearsonr`` over the leading (batch) axis.

    ``Pearsonr(pred[i], y[i])`` is evaluated for every index ``i`` of
    ``pred``, and the results are averaged along axis 0.
    """
    perItem = [Pearsonr(predItem, y[i]) for i, predItem in enumerate(pred)]
    return np.mean(perItem, 0)
@@ -0,0 +1,3 @@
1
+ from .linear import *
2
+ from .nonlinear import *
3
+ from .composite import *
@@ -0,0 +1,63 @@
1
+ import torch
2
+
3
class TwoMixedTRF(torch.nn.Module):
    '''
    Combine several TRF sub-models by summing their predictions.

    Each sub-model ``trfs[i]`` is fed the features stored under the keys in
    ``feats_keys[i]``:

    * if all of those features are tensors, they are cropped to the shortest
      last-dimension length, concatenated along dim -2, and passed
      positionally to the sub-model;
    * if all of them are dicts, a single dict is unpacked unchanged as
      keyword arguments. Several dicts must each provide ``'x'`` and
      ``'timeinfo'``; their ``'x'`` entries are concatenated along dim -2,
      and their ``'timeinfo'`` tensors must be equal.

    Mixing dict and tensor features for one sub-model is rejected.
    '''

    def __init__(
        self,
        device,
        trfs, #list of trf models
        feats_keys, #list of feat key for each trf in the trfs
    ):
        super().__init__()
        # ModuleList so the sub-models' parameters are registered on this module
        self.trfs:torch.nn.Module = torch.nn.ModuleList([trf for trf in trfs])
        self.feats_keys = feats_keys
        # stored for callers; not used inside this class
        self.device = device

    @staticmethod
    def _mergeDictFeats(feats):
        '''Merge dict features into one kwargs dict for a sub-model.'''
        if len(feats) == 1:
            return feats[0]
        timeinfo_0 = feats[0]['timeinfo']
        xs = []
        for feat in feats:
            # concatenating along dim -2 only makes sense on a shared time axis
            if not torch.equal(timeinfo_0, feat['timeinfo']):
                raise ValueError(
                    "features fed to the same TRF must share identical 'timeinfo'"
                )
            xs.append(feat['x'])
        return {
            'x': torch.cat(xs, dim = -2),
            'timeinfo': timeinfo_0,
        }

    @staticmethod
    def _mergeTensorFeats(feats):
        '''Crop tensor features to the shortest last dim and stack them on dim -2.'''
        minLen = min(f.shape[-1] for f in feats)
        return torch.cat([f[..., :minLen] for f in feats], dim = -2)

    def forward(self,feat_dict:dict, y):
        '''
        Run every sub-model on its features and sum the predictions.

        Parameters
        ----------
        feat_dict : dict
            Maps feature keys to tensors, or to dicts holding ``'x'`` and
            ``'timeinfo'``.
        y : torch.Tensor
            Target, cropped along its third axis (indexed as ``y[:, :, :n]``,
            so presumably 3-D -- confirm against callers).

        Returns
        -------
        (pred, cropedY)
            The summed prediction and the target, both cropped to the
            shortest last-dimension length among all predictions and ``y``.
        '''
        pred_list = []
        for trf_index, iTRF in enumerate(self.trfs):
            feats_key = self.feats_keys[trf_index]
            feats = [feat_dict[feat_key] for feat_key in feats_key]
            n_dict_feat = sum(isinstance(feat, dict) for feat in feats)
            if n_dict_feat > 0:
                # all features of this sub-model must then be dicts
                assert len(feats) == n_dict_feat
                pred_list.append(iTRF(**self._mergeDictFeats(feats)))
            else:
                assert all(isinstance(feat, torch.Tensor) for feat in feats)
                pred_list.append(iTRF(self._mergeTensorFeats(feats)))

        # align every prediction and the target on the shortest length
        minLen = min([p.shape[-1] for p in pred_list] + [y.shape[-1]])
        pred_list = [p[...,:minLen] for p in pred_list]
        cropedY = y[:,:,:minLen]
        pred = sum(pred_list)
        return pred,cropedY
@@ -0,0 +1,269 @@
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Wed Dec 9 10:11:27 2020
4
+
5
+ @author: Jin Dou
6
+ """
7
+
8
+
9
+ import torch
10
+ import numpy as np
11
+ from ..metrics import Pearsonr, BatchPearsonr
12
+ from ..utils import TensorsToNumpy
13
+ try:
14
+ from matplotlib import pyplot as plt
15
+ except:
16
+ plt = None
17
+
18
def msec2Idxs(msecRange,fs):
    '''
    Convert a [start, end] range in milliseconds into integer sample lags.

    The start is rounded down and the end rounded up to whole samples, and
    both ends are included in the returned list.

    NOTE(review): the bounds are converted to seconds before multiplying by
    ``fs``, so float rounding can push an exact boundary past an integer
    (e.g. 70 ms at 100 Hz gives 7.000000000000001, which rounds up to 8).
    '''
    assert len(msecRange) == 2
    startMs, endMs = msecRange[0], msecRange[1]
    firstLag = int(np.floor(startMs / 1e3 * fs))
    lastLag = int(np.ceil(endMs / 1e3 * fs))
    return list(range(firstLag, lastLag + 1))
29
+
30
def Idxs2msec(lags,fs):
    '''
    Convert sample lags to milliseconds, element by element.

    Each lag ``k`` becomes ``k / fs * 1000``; the result is a list with one
    entry per input lag.
    '''
    lagSamples = np.array(lags)
    lagMs = lagSamples / fs * 1e3
    return list(lagMs)
38
+
39
class LRTRF(torch.nn.Module):
    '''
    the TRF implemented with a linear layer and time lag of input

    ``forward`` expects input of shape (nBatch, nTimeSteps, nChannels). The
    input is expanded into time-lagged copies (see ``timeLagging``), and a
    single ``torch.nn.Linear`` maps the lagged features to ``outDim`` outputs,
    giving (nBatch, nTimeSteps, outDim).
    '''

    def __init__(self,inDim,outDim,tmin_ms,tmax_ms,fs,bias = True):
        super().__init__()
        self.tmin_ms = tmin_ms
        self.tmax_ms = tmax_ms
        self.fs = fs
        # integer sample lags covering [tmin_ms, tmax_ms], both ends included
        self.lagIdxs = msec2Idxs([tmin_ms,tmax_ms],fs)
        self.lagTimes = Idxs2msec(self.lagIdxs,fs)
        # the linear layer sees every input channel at every lag
        self.realInDim = len(self.lagIdxs) * inDim
        self.oDense = torch.nn.Linear(self.realInDim,outDim,bias = bias)
        self.inDim = inDim
        self.outDim = outDim

    def timeLagging(self,tensor):
        '''
        Stack time-shifted copies of ``tensor`` along its last axis.

        For each lag in ``self.lagIdxs`` (in order), the copy holds
        ``tensor[..., t - lag, :]`` at time step ``t``, with zeros where
        ``t - lag`` falls outside the input. Time is taken to be the
        second-to-last axis. The copies are concatenated lag-major, giving
        (..., nTimeSteps, nLags * nChannels).
        '''
        pad = torch.nn.functional.pad
        lagDataList = []
        for lag in self.lagIdxs:
            if lag < 0:
                # shift earlier in time: pad the end, drop the first -lag steps
                shifted = pad(tensor, (0, 0, 0, -lag))[..., -lag:, :]
            elif lag > 0:
                # shift later in time: pad the start, drop the last lag steps
                shifted = pad(tensor, (0, 0, lag, 0))[..., :-lag, :]
            else:
                shifted = tensor
            lagDataList.append(shifted)
        # layout: [channels at lagIdxs[0], channels at lagIdxs[1], ...]
        return torch.cat(lagDataList, -1)

    def forward(self,x):
        x = self.timeLagging(x)
        return self.oDense(x)

    @property
    def weights(self):
        '''Raw linear-layer weight as a numpy array, shape (outDim, nLags * inDim).'''
        return self.state_dict()['oDense.weight'].cpu().detach().numpy()

    @property
    def w(self):
        '''
        funtion reproduce the definition of w in mTRF-toolbox

        Returns
        -------
        numpy.ndarray
            Weights reshaped to (inDim, nLags, outDim).
        '''
        w = self.oDense.weight.T.cpu().detach()
        # rows of weight.T are ordered lag-major, matching timeLagging
        w = w.view(len(self.lagIdxs),self.inDim,self.outDim)
        w = w.permute(1,0,2)
        w = w.numpy()
        return w

    def loadFromMTRFpy(self,w,b,device):
        '''
        Load weights and bias exported by mTRFpy into the linear layer.

        Parameters
        ----------
        w : numpy.ndarray
            Shape (nInChan, nLag, nOutChan).
        b : numpy.ndarray
            Bias; only ``b[0]`` is used (presumably shape (1, nOutChan) --
            confirm against mTRFpy).
        device : torch device the new parameters are placed on.

        Both are multiplied by ``1 / self.fs`` before loading
        (NOTE(review): presumably undoes an fs scaling on the mTRFpy side).
        Returns ``self``.
        '''
        w = w * 1/ self.fs
        b = b * 1/self.fs
        b = b[0]
        w = torch.FloatTensor(w).to(device)
        # (nIn, nLag, nOut) -> (nLag, nIn, nOut) -> lag-major rows -> (nOut, nLag*nIn)
        w = w.permute(1,0,2)
        w = w.reshape(-1,w.shape[-1]).T
        b = torch.FloatTensor(b).to(device)
        with torch.no_grad():
            self.oDense.weight = torch.nn.Parameter(w)
            self.oDense.bias = torch.nn.Parameter(b)
        return self
118
+
119
class CPadOrCrop1D(torch.nn.Module):
    '''
    Zero-pad or crop the last axis of a 3-D tensor so that a convolution with
    a kernel spanning lags ``tmin_idx``..``tmax_idx`` keeps the time length.

    Trailing edge: pad ``-tmin_idx`` zeros when ``tmin_idx <= 0``, otherwise
    drop the last ``tmin_idx`` samples.
    Leading edge: pad ``tmax_idx`` zeros when ``tmax_idx >= 0``, otherwise
    drop the first ``-tmax_idx`` samples.
    '''

    def __init__(self,tmin_idx,tmax_idx):
        super().__init__()
        self.tmin_idx = tmin_idx
        self.tmax_idx = tmax_idx

    def forward(self,x):
        pad = torch.nn.functional.pad
        lo, hi = self.tmin_idx, self.tmax_idx
        # trailing edge of the time axis
        x = x[:, :, :-lo] if lo > 0 else pad(x, (0, -lo))
        # leading edge of the time axis
        x = pad(x, (hi, 0)) if hi >= 0 else x[:, :, -hi:]
        return x
136
+
137
+
138
class CNNTRF(torch.nn.Module):
    '''
    the TRF implemented with a convolutional layer,
    and zero padding of input

    ``forward`` expects input of shape (nBatch, nChannels, nTimeSteps). The
    input is padded or cropped (``CPadOrCrop1D``) so that the ``Conv1d``
    output keeps nTimeSteps, giving (nBatch, outDim, nTimeSteps).

    Be careful when computing correlations with this model: the metrics'
    ``Pearsonr`` treats its input as (nTimeSteps, nChannels). The
    ``BatchPearsonr`` method of this class transposes the last two axes first.
    '''

    def __init__(self,inDim,outDim,tmin_ms,tmax_ms,fs,groups = 1,enableBN = False, dilation = 1):
        super().__init__()
        self.tmin_ms = tmin_ms
        self.tmax_ms = tmax_ms
        self.fs = fs
        self.lagIdxs = msec2Idxs([tmin_ms,tmax_ms],fs)
        self.lagTimes = Idxs2msec(self.lagIdxs,fs)
        self.tmin_idx = self.lagIdxs[0]
        self.tmax_idx = self.lagIdxs[-1]
        nLags = len(self.lagTimes)
        # a kernel of nKernels taps with this dilation spans
        # (nKernels - 1) * dilation + 1 samples, which must equal nLags
        nKernels = (nLags - 1) / dilation + 1
        assert np.ceil(nKernels) == np.floor(nKernels), \
            'the number of lags minus one must be divisible by dilation'
        nKernels = int(nKernels)
        self.oCNN = torch.nn.Conv1d(
            inDim,
            outDim,
            nKernels,
            groups = groups,
            dilation = dilation
        )
        self.oPadOrCrop = CPadOrCrop1D(self.tmin_idx,self.tmax_idx)
        self.groups = groups
        self.enableBN = enableBN
        # created unconditionally; only applied in forward when enableBN is set
        self.oBN = torch.nn.BatchNorm1d(inDim,affine=False,track_running_stats=False)
        self.dilation = dilation

    def forward(self,x):
        if self.enableBN:
            x = self.oBN(x)
        x = self.oPadOrCrop(x)
        x = self.oCNN(x)
        return x

    @property
    def weights(self):
        '''
        Returns
        -------
        Formatted weights, with timeLag dimension flipped to conform to mTRF
        [outChannels, inChannels, timeLags]
        '''
        # Conv1d is a cross-correlation, so the kernel is reversed relative
        # to the lag axis
        return np.flip(self.state_dict()['oCNN.weight'].cpu().detach().numpy(),axis = -1)

    @property
    def w(self):
        '''
        funtion reproduce the definition of w in mTRF-toolbox

        Returns
        -------
        numpy.ndarray
            Shape (inChannels / groups, nKernels, outChannels), lag axis flipped.
        '''
        tensor = self.state_dict()['oCNN.weight']
        tensor = tensor.permute(1,2,0)
        return np.flip(tensor.cpu().detach().numpy(),axis = 1)

    @property
    def b(self):
        '''Convolution bias as a squeezed numpy array.'''
        return self.oCNN.bias.squeeze().detach().cpu().numpy()

    def loadFromMTRFpy(self,w,b,device):
        '''
        Load weights and bias exported by mTRFpy into the convolution.

        ``w`` has shape (nInChan, nLag, nOutChan); only ``b[0]`` of the bias
        is used. Both are multiplied by ``1 / self.fs`` first
        (NOTE(review): presumably undoes an fs scaling on the mTRFpy side).
        The loaded kernel has nLag taps, so this only matches the layer
        built in ``__init__`` when dilation == 1 and groups == 1.
        Returns ``self``.
        '''
        w = w * 1/self.fs
        b = b * 1/self.fs
        b = b[0]
        # flip the lag axis: Conv1d correlates rather than convolves
        w = np.flip(w,axis = 1).copy()
        w = torch.FloatTensor(w).to(device)
        # (nIn, nLag, nOut) -> (nOut, nIn, nLag), the Conv1d weight layout
        w = w.permute(2,0,1)
        b = torch.FloatTensor(b).to(device)
        with torch.no_grad():
            self.oCNN.weight = torch.nn.Parameter(w)
            self.oCNN.bias = torch.nn.Parameter(b)
        return self

    @property
    def t(self):
        '''Lag times in milliseconds, one per entry of ``lagIdxs``.'''
        return self.lagTimes

    def BatchPearsonr(self,pred,y):
        '''Batch Pearson r after moving time to axis -2, as ``Pearsonr`` expects.'''
        tensors = TensorsToNumpy(pred.transpose(-1,-2),y.transpose(-1,-2))
        return BatchPearsonr(*tensors)

    @property
    def readableWeights(self):
        '''
        Returns
        -------
        Readable formatted weights
        [timeLags, inChannels, outChannels]
        '''
        return self.weights.T

    def W(self,inIdx=None,outIdx=None,tIdx=None):
        '''
        Returns
        -------
        readable formatted weights of selected inChannel/outChannel/lagTime
        (None selects everything along that axis)
        '''
        inIdx = slice(inIdx) if inIdx is None else inIdx
        outIdx = slice(outIdx) if outIdx is None else outIdx
        tIdx = slice(tIdx) if tIdx is None else tIdx
        return self.readableWeights[tIdx,inIdx,outIdx]

    def load(self,path):
        '''
        Load a checkpoint dict holding a ``'state_dict'`` entry, then switch
        to eval mode. ``torch.load`` unpickles the file, so only load
        checkpoints from trusted sources.
        '''
        self.load_state_dict(torch.load(path,map_location='cpu')['state_dict'])
        self.eval()

    def _plottableWeights(self, outChan = None, inChan = None):
        '''
        Select weights and arrange them as (nKernels, nCurves) for ``ax.plot``.

        The lag axis is moved first and every remaining selected axis is
        flattened (output-channel-major), so the result is always 2-D.
        '''
        if outChan is None:
            outChan = slice(outChan)
        if inChan is None:
            inChan = slice(inChan)
        selected = np.asarray(self.weights[outChan, inChan])
        nTaps = selected.shape[-1]
        return np.moveaxis(selected, -1, 0).reshape(nTaps, -1)

    def plotWeights(self,outChan = None, inChan = None):
        '''
        Plot the (flipped) kernel weights against lag time.

        Selecting several output AND input channels (the default) used to hand
        a 3-D array to ``ax.plot``; the weights are now flattened to one curve
        per (outChan, inChan) pair. Returns ``(fig, ax)``.
        '''
        if plt is None:
            raise ImportError('matplotlib is required for plotWeights')
        fig,ax = plt.subplots()
        # with dilation d, the kernel taps sit on every d-th lag
        ax.plot(self.t[::self.dilation], self._plottableWeights(outChan, inChan))
        return fig, ax
268
+
269
+