nntrf 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nntrf/__init__.py +0 -0
- nntrf/loss.py +31 -0
- nntrf/metrics.py +35 -0
- nntrf/models/__init__.py +3 -0
- nntrf/models/composite.py +63 -0
- nntrf/models/linear.py +269 -0
- nntrf/models/nonlinear.py +1358 -0
- nntrf/utils.py +12 -0
- nntrf-1.0.0.dist-info/LICENSE +21 -0
- nntrf-1.0.0.dist-info/METADATA +23 -0
- nntrf-1.0.0.dist-info/RECORD +13 -0
- nntrf-1.0.0.dist-info/WHEEL +5 -0
- nntrf-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1358 @@
|
|
1
|
+
import math
|
2
|
+
import numpy as np
|
3
|
+
import torch
|
4
|
+
from scipy.stats import pearsonr
|
5
|
+
from torch.nn.functional import pad, fold
|
6
|
+
from .linear import msec2Idxs, Idxs2msec, CPadOrCrop1D
|
7
|
+
try:
|
8
|
+
import skfda
|
9
|
+
except:
|
10
|
+
skfda = None
|
11
|
+
|
12
|
+
try:
|
13
|
+
from matplotlib import pyplot as plt
|
14
|
+
except:
|
15
|
+
plt = None
|
16
|
+
|
17
|
+
try:
|
18
|
+
from mtrf.model import TRF
|
19
|
+
except:
|
20
|
+
TRF = None
|
21
|
+
|
22
|
+
|
23
|
+
def fit_forward_mtrf(stim, resp, fs, tmin_ms, tmax_ms, regularization, k):
    """Fit a forward mTRF model and return its weights and bias.

    The lag window is given in milliseconds and is converted to seconds
    before being handed to ``TRF.train``; ``k`` is forwarded to it as the
    ``k`` keyword argument.

    Returns
    -------
    tuple
        ``(trf.weights, trf.bias)`` of the trained model.

    Raises
    ------
    ImportError
        If the optional ``mtrf`` package could not be imported.
    """
    if TRF is None:
        # The guarded module-level import fell back to None; calling it would
        # only produce an opaque "'NoneType' object is not callable" error.
        raise ImportError('mtrf should be installed to use fit_forward_mtrf')
    trf = TRF(direction=1)
    trf.train(stim, resp, fs, tmin_ms / 1e3, tmax_ms / 1e3, regularization, k = k)
    return trf.weights, trf.bias
|
27
|
+
|
28
|
+
def seqLast_pad_zero(seq, value = 0):
    """Right-pad tensors along their last axis to a common length and stack them.

    Parameters
    ----------
    seq : iterable of torch.Tensor
        Tensors that may differ in the size of their last dimension.
    value : scalar, optional
        Fill value used for the padded positions (default 0).

    Returns
    -------
    torch.Tensor
        The padded tensors stacked along a new leading dimension.

    Raises
    ------
    ValueError
        If ``seq`` contains no tensors.
    """
    # Materialise once so that one-shot iterables (generators) also work;
    # the data is traversed twice below.
    seq = list(seq)
    if not seq:
        raise ValueError('seqLast_pad_zero requires at least one tensor')
    maxLen = max(t.shape[-1] for t in seq)
    padded = [pad(t, (0, maxLen - t.shape[-1]), value = value) for t in seq]
    return torch.stack(padded, 0)
|
34
|
+
|
35
|
+
class CausalConv(torch.nn.Module):
    """1-D convolution that is padded on the left only, so the output keeps
    the input length and each output sample depends on current and earlier
    input samples."""

    def __init__(self,inDim,outDim,nKernel,dilation = 1):
        super().__init__()
        self.nKernel = nKernel
        self.dilation = dilation
        self.conv = torch.nn.Conv1d(inDim, outDim, nKernel, dilation = dilation)

    def forward(self,x):
        '''
        x: (nBatch, nChan, nSeq)
        returns: (nBatch, nOutChan, nSeq)
        '''
        # number of zeros prepended on the time axis: the span of the
        # (dilated) kernel minus one
        nLeftPad = self.dilation * (self.nKernel - 1)
        padded = torch.nn.functional.pad(x, (nLeftPad, 0))
        return self.conv(padded)
|
58
|
+
## experiment module end
|
59
|
+
|
60
|
+
class TRFAligner(torch.nn.Module):
    """Place per-event TRFs on a common time axis and overlap-add them."""

    def __init__(self,device):
        super().__init__()
        self.device = device

    def forward(self,TRFs,sourceIdx,nRealLen):
        '''
        Parameters
        ----------
        TRFs : torch.Tensor, (nBatch, outDim, nWin, nSeq)
            one response of nWin lags per event.
        sourceIdx : torch.LongTensor, (nBatch, nSeq)
            index on the target time axis at which each event's response
            starts. Only the last column is inspected to find the largest
            index.
        nRealLen : int
            requested length of the target; it is enlarged to
            ``max index + 1`` when the largest index would not fit.

        Returns
        -------
        torch.Tensor, (nBatch, outDim, nRealLen)
            overlap-added responses, cropped to nRealLen samples.
        '''
        nBatch, outDim, nWin, nSeq = TRFs.shape
        # Work with plain Python ints so that tensor scalars never leak into
        # the size tuples built below.
        nRealLen = int(nRealLen)
        maxSrcIdx = int(torch.max(sourceIdx[:, -1]))
        if maxSrcIdx >= nRealLen:
            nRealLen = maxSrcIdx + 1

        # Scatter buffer; it has to share the dtype of TRFs because indexed
        # assignment refuses mismatching source / destination dtypes.
        self.cache = torch.zeros(
            (nBatch, outDim, nWin, nRealLen),
            dtype = TRFs.dtype,
            device = self.device
        )

        # broadcastable index grids, one per leading axis
        idxDevice = sourceIdx.device
        idxBatch = torch.arange(nBatch, device = idxDevice)[:, None, None, None]
        idxChan = torch.arange(outDim, device = idxDevice)[:, None, None]
        idxWin = torch.arange(nWin, device = idxDevice)[:, None]
        sourceIdx = sourceIdx[:, None, None, :]

        # cache[b, c, w, sourceIdx[b, s]] = TRFs[b, c, w, s]
        self.cache[idxBatch, idxChan, idxWin, sourceIdx] = TRFs
        # (nBatch, outDim * nWin, nRealLen): the layout expected by fold
        self.cache = self.cache.view(nBatch, -1, nRealLen)
        foldOutputSize = (nRealLen + nWin - 1, 1)
        foldKernelSize = (nWin, 1)
        # (nBatch, outDim, nRealLen + nWin - 1, 1): lag w of the column at
        # time t is accumulated into output row t + w
        output = fold(self.cache, foldOutputSize, foldKernelSize)
        # drop the tail that runs past the target length
        return output[:, :, :nRealLen, 0]
|
108
|
+
|
109
|
+
class LTITRFGen(torch.nn.Module):
    """Time-invariant TRF generator: a single learned kernel of shape
    (outDim, inDim, nWin) that is scaled by every input sample."""

    def __init__(self,inDim,nWin,outDim,ifAddBiasInForward = True):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(outDim,inDim,nWin))
        self.bias = torch.nn.Parameter(torch.ones(outDim))
        # both parameters are drawn uniformly from [-sqrt(k), sqrt(k)]
        # with k = 1 / (inDim * nWin)
        bound = np.sqrt(1 / (inDim * nWin))
        torch.nn.init.uniform_(self.weight, a = -bound, b = bound)
        torch.nn.init.uniform_(self.bias, a = -bound, b = bound)
        self.ifAddBiasInForward = ifAddBiasInForward

    @property
    def outDim(self):
        return self.weight.shape[0]

    @property
    def inDim(self):
        return self.weight.shape[1]

    @property
    def nWin(self):
        return self.weight.shape[2]

    def forward(self,x):
        # x: (nBatch, inDim, nSeq)
        assert x.ndim == 3
        # (nBatch, 1, inDim, 1, nSeq) * (1, outDim, inDim, nWin, 1)
        #   -> (nBatch, outDim, inDim, nWin, nSeq)
        scaled = x[:, None, :, None, :] * self.weight[None, ..., None]
        if self.ifAddBiasInForward:
            # the bias is broadcast over (inDim, nWin, nSeq), i.e. it is
            # added before the reduction over inDim
            scaled = scaled + self.bias[..., None, None, None]
        # (nBatch, outDim, nWin, nSeq)
        return scaled.sum(2)

    def load_mtrf_weights(self, w, b, fs, device):
        # w: (nInChan, nLag, nOutChan); only the first row of b is used
        wScaled = torch.FloatTensor(w * 1 / fs).to(device)
        bScaled = torch.FloatTensor(b[0] * 1 / fs).to(device)
        with torch.no_grad():
            # stored as (nOutChan, nInChan, nLag)
            self.weight = torch.nn.Parameter(wScaled.permute(2, 0, 1))
            self.bias = torch.nn.Parameter(bScaled)
        return self

    def export_mtrf_weights(self, fs):
        # inverse of load_mtrf_weights: back to (nInChan, nLag, nOutChan)
        # and (1, nOutChan), multiplied by fs
        with torch.no_grad():
            w = self.weight.cpu().detach().permute(1, 2, 0).numpy()
            b = self.bias.cpu().detach().numpy().reshape(1,-1)
        return w * fs, b * fs

    def stop_update_weights(self):
        # freeze both parameters and drop any stored gradient
        for param in (self.weight, self.bias):
            param.requires_grad_(False)
            param.grad = None

    def enable_update_weights(self):
        # unfreeze the module and start from zeroed gradients
        self.requires_grad_(True)
        for param in (self.weight, self.bias):
            param.grad = torch.zeros_like(param)
|
180
|
+
|
181
|
+
class WordTRFEmbedGenTokenizer():
    """Turn batches of words into tensors of token ids looked up in wordsDict."""

    def __init__(self, wordsDict, device):
        self.wordsDict = wordsDict
        self.device = device

    def __call__(self, words):
        # words: iterable of word sequences; the result holds one long
        # tensor per sequence, created on self.device. A word missing from
        # wordsDict propagates the lookup error.
        lookup = self.wordsDict
        return [
            torch.tensor(
                [lookup[w] for w in ws],
                dtype = torch.long,
                device = self.device
            )
            for ws in words
        ]
|
196
|
+
|
197
|
+
class WordTRFEmbedGen(torch.nn.Module):
    """Produce one TRF per word token from a learned embedding table
    followed by a linear projection to the output channels."""

    def __init__(
        self,
        outDim,
        hiddenDim,
        tmin_ms,
        tmax_ms,
        fs,
        wordsDict,
        device
    ):
        super().__init__()
        self.outDim = outDim
        self.hiddenDim = hiddenDim
        self.tmin_ms = tmin_ms
        self.tmax_ms = tmax_ms
        self.fs = fs
        self.lagIdxs = msec2Idxs([tmin_ms,tmax_ms],fs)
        self.lagTimes = Idxs2msec(self.lagIdxs,fs)
        self.nWin = len(self.lagTimes)
        # every token embeds into nWin vectors of size hiddenDim
        self.embedding_dim = self.nWin * hiddenDim

        self.device = device
        self.wordsDict = wordsDict
        # one extra row: index 0 is reserved as the padding token
        nTokens = len(wordsDict) + 1
        self.embedding = torch.nn.Embedding(
            nTokens,
            self.embedding_dim,
            padding_idx = 0
        ).to(device)
        self.proj = torch.nn.Linear(self.hiddenDim, self.outDim, device = device)

    def forward(self, batchTokens):
        # (nBatch, nSeq): shorter token sequences are right-padded with 0
        tokens = seqLast_pad_zero(batchTokens)
        # (nBatch, nSeq, nWin * hiddenDim)
        embeds = self.embedding(tokens)
        nBatch, nSeq = embeds.shape[:2]
        # (nBatch, nSeq, hiddenDim, nWin) -> (nBatch, nSeq, nWin, hiddenDim)
        embeds = embeds.reshape(nBatch, nSeq, self.hiddenDim, self.nWin)
        embeds = embeds.permute(0, 1, 3, 2)
        # (nBatch, nSeq, nWin, outDim)
        projected = self.proj(embeds)
        # (nBatch, outDim, nWin, nSeq)
        return projected.permute(0, 3, 2, 1)
|
247
|
+
|
248
|
+
|
249
|
+
class CustomKernelCNNTRF(torch.nn.Module):
    """Dilated 1-D convolution whose kernel is passed in at call time or
    fetched from an attached TRF generator; only the bias is owned here."""

    def __init__(self,inDim,outDim,tmin_ms,tmax_ms,fs,groups = 1,dilation = 1):
        super().__init__()
        self.tmin_ms = tmin_ms
        self.tmax_ms = tmax_ms
        self.fs = fs
        self.lagIdxs = msec2Idxs([tmin_ms,tmax_ms],fs)
        self.lagTimes = Idxs2msec(self.lagIdxs,fs)
        self.tmin_idx = self.lagIdxs[0]
        self.tmax_idx = self.lagIdxs[-1]
        # the number of kernel taps must come out as a whole number for
        # the chosen dilation
        nKernels = (len(self.lagTimes) - 1) / dilation + 1
        assert np.ceil(nKernels) == np.floor(nKernels)
        self.nWin = int(nKernels)
        self.oPadOrCrop = CPadOrCrop1D(self.tmin_idx,self.tmax_idx)
        self.groups = groups
        self.dilation = dilation
        self.inDim = inDim
        self.outDim = outDim

        # bias drawn uniformly from [-sqrt(k), sqrt(k)], k = 1 / (inDim * nWin)
        bound = np.sqrt(1 / (inDim * self.nWin))
        self.bias = torch.nn.Parameter(torch.ones(outDim))
        torch.nn.init.uniform_(self.bias, a = -bound, b = bound)

    def setTRFGen(self, trfGen):
        # object whose TRF() supplies the kernel when forward gets none
        self.trfGen = trfGen

    def forward(self, x, weight = None):
        # kernel: (outDim, inDim, nWin)
        kernel = self.trfGen.TRF() if weight is None else weight
        assert self.nWin == kernel.shape[2]
        assert self.outDim == kernel.shape[0]
        assert self.inDim == kernel.shape[1]
        return torch.nn.functional.conv1d(
            self.oPadOrCrop(x),
            kernel,
            bias=self.bias,
            dilation=self.dilation,
            groups=self.groups
        )
|
294
|
+
|
295
|
+
class FuncBasisTRF(torch.nn.Module):
    """Base class for TRF templates described on a set of basis functions.

    Subclasses provide ``nBasis``, ``vis``, ``forward`` and ``fitTRFs``;
    here they all raise NotImplementedError.
    """

    def __init__(self, inDim, outDim, tmin_idx, tmax_idx, timeshiftLimit_idx ,device) -> None:
        super().__init__()
        self.timeshiftLimit_idx = torch.tensor(timeshiftLimit_idx)
        # lag indices of the nominal window ...
        self.time_embedding = self.get_time_embedding(tmin_idx, tmax_idx, device)
        # ... and of the window widened by the shift limit on both sides
        extMin = tmin_idx - timeshiftLimit_idx
        extMax = tmax_idx + timeshiftLimit_idx
        self.time_embedding_ext = self.get_time_embedding(extMin, extMax, device)
        # the stored TRFs span the widened window
        nWinExt = self.time_embedding_ext.shape[-2]
        self.register_buffer('TRFs', torch.zeros((outDim, inDim, nWinExt),device=device))

    def TRF(self):
        # evaluate forward with a=1, b=0, c=1, then keep the first entry of
        # the leading axis and the first entry of the trailing axis
        return self.forward(1, 0, 1)[0,...,0]

    @property
    def inDim(self):
        return self.TRFs.size(1)

    @property
    def outDim(self):
        return self.TRFs.size(0)

    @property
    def nWin(self):
        return self.TRFs.size(2)

    @property
    def nBasis(self):
        raise NotImplementedError

    def vis(self):
        raise NotImplementedError

    def forward(self, a, b, c):
        raise NotImplementedError

    def fitTRFs(self, TRFs):
        raise NotImplementedError

    def get_time_embedding(self, tmin_idx, tmax_idx, device = 'cpu'):
        # consecutive indices tmin_idx..tmax_idx (inclusive),
        # shaped (1, 1, 1, nWin, 1)
        lagIdxs = torch.arange(tmin_idx,tmax_idx+1, device=device)
        return lagIdxs.view(1,1,1,-1,1)

    @property
    def timelag_idx(self,):
        return self.time_embedding.detach().cpu().squeeze()

    @property
    def timelag_idx_ext(self,):
        return self.time_embedding_ext.detach().cpu().squeeze()
|
353
|
+
|
354
|
+
|
355
|
+
def build_gaussian_response(x, mu, sigma):
    """Evaluate the unnormalised Gaussians exp(-(x - mu)^2 / (2 sigma^2)).

    x:      (nBatch, 1, 1, nWin, nSeq)
    mu:     (nBasis)
    sigma:  (nBasis, outDim, inDim)
    output: (nBatch, nBasis, outDim, inDim, nWin, nSeq)
    """
    # x gets a basis axis after the batch axis: (nBatch, 1, 1, 1, nWin, nSeq);
    # mu becomes (nBasis, 1, 1, 1, 1) so it broadcasts along that axis
    deviation = x[:, None, ...] - mu[..., None, None, None, None]
    # (nBasis, outDim, inDim, 1, 1)
    width = sigma[..., None, None]
    return torch.exp(-deviation**2 / (2*width**2))
|
369
|
+
|
370
|
+
def solve_coef(gaussresps, trf):
    """Least-squares fit of a TRF on basis responses plus a constant column.

    gaussresps: (nBasis, nWin)
    trf:        (nWin)
    returns:    (nBasis + 1) coefficients; the last one belongs to the
                all-ones column.
    """
    nWin = trf.shape[0]
    onesRow = np.ones((1, nWin))
    # design matrix (nWin, nBasis + 1): one column per basis, then the ones
    design = np.concatenate([gaussresps, onesRow], axis = 0).transpose()
    solution = np.linalg.lstsq(design, trf, rcond=None)
    return solution[0]
|
375
|
+
|
376
|
+
class GaussianBasisTRF(FuncBasisTRF):
    """TRF template written as a weighted sum of Gaussians plus an optional
    constant term.

    ``coefs`` (nBasis + 1, outDim, inDim) and ``sigma`` (nBasis, outDim, inDim)
    are parameters; the centres ``mu`` (nBasis) are a fixed buffer.
    """

    def __init__(
        self,
        inDim,
        outDim,
        tmin_idx,
        tmax_idx,
        nBasis,
        timeshiftLimit_idx = 0,
        sigmaMin = 6.4,
        sigmaMax = 6.4,
        ifSumInDim = False,
        device = 'cpu',
        mu = None,
        sigma = None,
        include_constant_term = True
    ):
        super().__init__(inDim, outDim, tmin_idx, tmax_idx, timeshiftLimit_idx, device)
        ### Fittable Parameters
        # one coefficient per Gaussian plus a trailing one for the constant term
        coefs = torch.ones((nBasis + 1, outDim, inDim), device = device, dtype = torch.float32)
        torch.nn.init.kaiming_uniform_(coefs, a=math.sqrt(5))
        self.coefs = torch.nn.Parameter(coefs)

        if sigma is not None:
            assert len(sigma) == nBasis
            # built on `device` as float32, like coefs and a user-given mu
            sigma = torch.tensor(sigma, device = device, dtype = torch.float32)
        else:
            # default width: midpoint of the allowed range
            sigma = torch.ones(nBasis, outDim, inDim, device = device, dtype = torch.float32) * (sigmaMin + sigmaMax) / 2
        self.sigma = torch.nn.Parameter(sigma)

        ### Fixed Values
        time_embedding_ext = self.time_embedding_ext.squeeze()
        tmin_idx_ext, tmax_idx_ext = time_embedding_ext[0], time_embedding_ext[-1]
        if mu is not None:
            assert len(mu) == nBasis
            mu = torch.tensor(mu, device = device, dtype = torch.float32)
        else:
            # evenly spaced centres strictly inside the extended lag window
            mu = torch.linspace(
                tmin_idx_ext.item(),
                tmax_idx_ext.item(),
                nBasis + 2,
                device = device
            )[1:-1]
        self.register_buffer('mu', mu)

        # bounds applied to sigma whenever the Gaussians are evaluated
        self.register_buffer('sigmaMin', torch.tensor(sigmaMin, device = device))
        self.register_buffer('sigmaMax', torch.tensor(sigmaMax, device = device))
        self.ifSumInDim = ifSumInDim
        self.include_constant_term = include_constant_term
        self.device = device

    def _clamped_sigma(self):
        # widths limited to [sigmaMin, sigmaMax]
        sigma = torch.maximum(self.sigma, self.sigmaMin)
        return torch.minimum(sigma, self.sigmaMax)

    def vec_sum(self, x):
        return self.vec_gauss_sum(x)

    def vec_gauss_sum(self, x):
        # x: (nBatch, 1, 1, nWin, nSeq)
        # gaussResps: (nBatch, nBasis, outDim, inDim, nWin, nSeq)
        gaussResps = build_gaussian_response(x, self.mu, self._clamped_sigma())
        # coefs: (nBasis + 1, outDim, inDim, 1, 1)
        coefs = self.coefs[..., None, None]
        # weight every Gaussian and reduce the basis axis:
        # (nBatch, outDim, inDim, nWin, nSeq)
        w_gauss_resps = (coefs[:-1] * gaussResps).sum(-5)
        if self.include_constant_term:
            # last coefficient, (outDim, inDim, 1, 1), acts as an offset
            w_gauss_resps = w_gauss_resps + coefs[-1]
        return w_gauss_resps

    def forward(self, a, b, c):
        # a scales the amplitude, b shifts and c scales the lag axis:
        # a * f(c * (t - b)) with t the nominal lag indices
        x = c * (self.time_embedding - b)
        wGaussResps = a * self.vec_gauss_sum(x)
        if self.ifSumInDim:
            # (nBatch, outDim, nWin, nSeq)
            wGaussResps = wGaussResps.sum((-3))
        return wGaussResps

    @property
    def nBasis(self):
        return self.sigma.shape[0]

    def fitTRFs(self, TRFs):
        '''
        Fit the coefficients to given TRFs by least squares.

        TRFs is the numpy array of mtrf weights
        Shape: [nInDim, nLags, nOutput]

        self.coefs: (nBasis+1, outDim, inDim)
        sigma: (nBasis, outDim, inDim)

        Raises AssertionError when the number of lags differs from self.nWin
        or when a reconstructed TRF fails the correlation check below.
        '''
        TRFs = torch.from_numpy(TRFs)
        # (outDim, inDim, nWin)
        TRFs = TRFs.permute(2, 0, 1)
        self.TRFs[:,:,:] = TRFs.to(self.device)[:,:,:]

        # (nBasis, outDim, inDim, nWin); evaluated with the same clamped
        # sigma that vec_gauss_sum uses, so fit and reconstruction agree
        with torch.no_grad():
            gaussResps = build_gaussian_response(
                self.time_embedding_ext,
                self.mu,
                self._clamped_sigma()
            )[0, ..., 0].cpu().numpy()

        nWin = TRFs.shape[2]
        assert nWin == self.nWin
        # (nBasis+1, outDim, inDim)
        coefs = np.zeros(self.coefs.shape)
        for i in range(self.outDim):
            for j in range(self.inDim):
                coefs[:, i, j] = solve_coef(gaussResps[:, i, j, :], TRFs[i, j, :])

        with torch.no_grad():
            self.coefs[:,:,:] = torch.from_numpy(coefs)
            # (outDim, inDim, nWin)
            torchTRFs = self.vec_gauss_sum(
                self.time_embedding_ext,
            ).cpu().numpy()[0, ..., 0]

        for j in range(self.outDim):
            for i in range(self.inDim):
                curFTRF = torchTRFs[j, i,:]
                TRF = TRFs[j, i, :]
                # NOTE(review): np.around rounds the correlation to the
                # nearest integer, so this only fails for r below about 0.5.
                # Confirm whether a `decimals` argument was intended.
                assert np.around(pearsonr(curFTRF, TRF)[0]) >= 0.99

    def vis(self, fs = None):
        '''
        Plot the stored TRFs (top) and their reconstruction (bottom) and
        return the figure. The x axis holds lag indices, or lag indices
        divided by fs when fs is given.
        '''
        if plt is None:
            raise ValueError('matplotlib should be installed')
        with torch.no_grad():
            FTRFs = self.vec_gauss_sum(
                self.time_embedding_ext,
            ).cpu()[0, ..., 0]
        fig, axs = plt.subplots(2)
        fig.suptitle('top: original TRF, bottom: reconstructed TRF')
        if fs is None:
            timelag = self.timelag_idx_ext
        else:
            timelag = self.timelag_idx_ext.numpy() / fs
        for j in range(self.outDim):
            for i in range(self.inDim):
                axs[0].plot(timelag, self.TRFs[j,i,:].cpu())
                axs[1].plot(timelag, FTRFs[j,i,:].cpu())
        return fig
|
574
|
+
|
575
|
+
|
576
|
+
|
577
|
+
class FourierBasisTRF(FuncBasisTRF):
    """TRF template written as a truncated Fourier series.

    ``coefs`` has shape (nOutChan, nInChan, nBasis): entry 0 weights the
    constant function and the remaining entries weight the interleaved
    sin/cos pairs sin_1, cos_1, sin_2, cos_2, ... Because ``nBasis // 2``
    pairs are generated, nBasis is expected to be odd.
    """

    def __init__(
        self,
        nInChan,
        nOutChan,
        tmin_idx,
        tmax_idx,
        nBasis,
        timeshiftLimit_idx,
        device = 'cpu',
        if_fit_coefs = False
    ):
        super().__init__(nInChan, nOutChan, tmin_idx, tmax_idx, timeshiftLimit_idx, device)
        # zero-filled rather than uninitialised, so the buffer never exposes
        # arbitrary memory before fitTRFs has run
        coefs = torch.zeros((nOutChan, nInChan, nBasis),device=device)
        if if_fit_coefs:
            torch.nn.init.kaiming_uniform_(coefs, a=math.sqrt(5))
            self.coefs = torch.nn.Parameter(coefs)
        else:
            self.register_buffer('coefs', coefs)
        # period of the series: the length of the extended lag window
        self.T = self.nWin - 1
        self.device = device
        maxN = nBasis // 2
        # harmonic numbers 1..maxN
        self.seqN = torch.arange(1,maxN+1,device = self.device)

    @property
    def nBasis(self):
        return self.coefs.shape[2]

    def fitTRFs(self,TRFs):
        '''
        Project the given TRFs on a Fourier basis with scikit-fda and store
        the resulting coefficients.

        TRFs is the numpy array of mtrf weights
        Shape: [nInDim, nLags, nOutput]

        Raises ImportError when scikit-fda is unavailable and AssertionError
        when the basis period differs from self.T or a consistency check on
        the reconstruction fails.
        '''
        if skfda is None:
            raise ImportError('scikit-fda should be installed to fit a Fourier basis')

        TRFs = torch.from_numpy(TRFs)
        # (outDim, inDim, nWin)
        TRFs = TRFs.permute(2, 0, 1)
        self.TRFs[:,:,:] = TRFs.to(self.device)[:,:,:]
        fd_basis_s = []
        grid_points = self.time_embedding_ext.squeeze().cpu().numpy()
        # in-place writes into coefs must not be tracked when it is a parameter
        with torch.no_grad():
            for j in range(self.outDim):
                for i in range(self.inDim):
                    TRF = TRFs[j, i, :]
                    fd = skfda.FDataGrid(
                        data_matrix=TRF,
                        grid_points=grid_points,
                    )
                    basis = skfda.representation.basis.Fourier(n_basis = self.nBasis)
                    fd_basis = fd.to_basis(basis)
                    coef = fd_basis.coefficients[0]
                    self.coefs[j, i, :] = torch.from_numpy(coef).to(self.device)

                    T = fd_basis.basis.period
                    assert T == self.T
                    fd_basis_s.append(fd_basis)

            # (outDim, inDim, nWin)
            out = self.vec_fourier_sum(
                self.nBasis,
                self.T,
                self.time_embedding_ext,
                self.coefs
            )[0, ..., 0]

        for j in range(self.outDim):
            for i in range(self.inDim):
                fd_basis = fd_basis_s[j*self.inDim + i]
                temp = fd_basis(grid_points).squeeze()
                curFTRF = out[j, i,:].detach().cpu().numpy()
                TRF = TRFs[j, i,:]
                # NOTE(review): np.around rounds the correlation to the
                # nearest integer, so this only fails for r below about 0.5.
                # Confirm whether a `decimals` argument was intended.
                assert np.around(pearsonr(TRF, temp)[0]) >= 0.99
                try:
                    # the torch evaluation has to match scikit-fda's
                    assert np.allclose(curFTRF,temp,atol = 1e-6)
                except AssertionError:
                    print(TRF, curFTRF,temp)
                    raise

    def phi0(self,T):
        # constant basis function; algebraically 1 / sqrt(T)
        return 1 / ((2 ** 0.5) * ((T/2) ** 0.5))

    def phi2n_1(self,n,T,t):
        # sine basis functions
        #n: (maxN)
        #t: (nBatch, 1, 1, nWin, nSeq, 1)
        #return: (nBatch, 1, 1, nWin, nSeq, maxN)
        return torch.sin(2 * torch.pi * t * n / T) / (T/2)**0.5

    def phi2n(self,n,T,t):
        # cosine basis functions
        #n: (maxN)
        #t: (nBatch, 1, 1, nWin, nSeq, 1)
        #return: (nBatch, 1, 1, nWin, nSeq, maxN)
        return torch.cos(2 * torch.pi * t * n / T) / (T/2)**0.5

    def vec_sum(self, x):
        return self.vec_fourier_sum(self.nBasis, self.T, x, self.coefs)

    def vec_fourier_sum(self,nBasis, T, t,coefs):
        #coefs: (nOutChan, nInChan, nBasis)
        #t: (nBatch, 1, 1, nWin, nSeq)
        #   a singleton channel axis in t means all channels share the same
        #   time-axis transformation
        #return: (nBatch, outDim, inDim, nWin, nSeq)

        #(nBatch, 1, 1, nWin, nSeq, 1)
        t = t[..., None]
        const0 = self.phi0(T)
        maxN = nBasis // 2
        # (maxN)
        seqN = self.seqN
        # (nBatch, 1, 1, nWin, nSeq, maxN)
        constSin = self.phi2n_1(seqN, T, t)
        constCos = self.phi2n(seqN, T, t)

        # stacking on a new last axis and flattening it interleaves the two:
        # (nBatch, 1, 1, nWin, nSeq, 2 * maxN)
        constN = torch.stack(
            [constSin,constCos],
            dim = -1
        ).reshape(*constSin.shape[:5], 2*maxN)

        #(nOutChan, nInChan, 1, 1, nBasis)
        coefs = coefs[:, :, None, None, :]
        # constant term plus the weighted harmonics, reduced over the basis axis
        out = const0 * coefs[...,0] + (constN * coefs[...,1:]).sum(-1)
        # (nBatch, outDim, inDim, nWin, nSeq)
        return out

    def forward(self,a, b, c):
        # a,b,c in the most strict case: (nBatch, 1, 1, 1, nSeq)
        # loosely: a,b,c can be (nBatch, nOut, nIn, nWin, nSeq)

        # self.time_embedding: (1, 1, 1, nWin, 1)
        # x: (nBatch, 1, 1, nWin, nSeq), the shifted and scaled lag axis
        x = c * (self.time_embedding - b)
        #(nBatch, outDim, inDim, nWin, nSeq)
        return a * self.vec_fourier_sum(self.nBasis,self.T,x,self.coefs)

    def vis(self ,fs = None):
        '''
        Plot the stored TRFs (top) and their reconstruction (bottom) and
        return the figure. The x axis holds lag indices, or lag indices
        divided by fs when fs is given.
        '''
        if fs is None:
            timelag = self.timelag_idx_ext
        else:
            timelag = self.timelag_idx_ext.numpy() / fs
        if plt is None:
            raise ValueError('matplotlib should be installed')
        with torch.no_grad():
            FTRFs = self.vec_fourier_sum(
                self.nBasis,
                self.T,
                self.time_embedding_ext,
                self.coefs
            )[0, ..., 0]
        fig, axs = plt.subplots(2)
        fig.suptitle('top: original TRF, bottom: reconstructed TRF')
        for j in range(self.outDim):
            for i in range(self.inDim):
                axs[0].plot(timelag, self.TRFs[j,i,:].cpu())
                axs[1].plot(timelag, FTRFs[j,i,:].cpu())
        return fig
|
781
|
+
|
782
|
+
# basis-TRF classes keyed by their short name
basisTRFNameMap = dict(
    gauss = GaussianBasisTRF,
    fourier = FourierBasisTRF,
)
|
786
|
+
|
787
|
+
|
788
|
+
class TRFsGen(torch.nn.Module):
    """Interface stub for modules that derive TRFs from an input."""

    def forward(self, x, featOnsetIdx):
        '''
        Placeholder without an implementation; it returns None.

        input:
            x: input to be used to derive the
                transformation parameters TRFs
            featOnsetIdx: time index of the items in the transformed x
                that are to be picked as the real transformation parameters
        '''
        return None
|
799
|
+
|
800
|
+
|
801
|
+
class FuncTRFsGen(torch.nn.Module):
|
802
|
+
'''
|
803
|
+
Implement the functional TRF generator, generate dynamically
|
804
|
+
warped TRF by transform the functional TRF template
|
805
|
+
'''
|
806
|
+
|
807
|
+
def __init__(
|
808
|
+
self,
|
809
|
+
inDim,
|
810
|
+
outDim,
|
811
|
+
tmin_ms,
|
812
|
+
tmax_ms,
|
813
|
+
fs,
|
814
|
+
basisTRFName = 'fourier',
|
815
|
+
limitOfShift_idx = 7,
|
816
|
+
nBasis = 21,
|
817
|
+
mode = '',
|
818
|
+
transformer = None,
|
819
|
+
device = 'cpu',
|
820
|
+
# if_trans_per_outChan = False
|
821
|
+
):
|
822
|
+
super().__init__()
|
823
|
+
assert mode.replace('+-','') in ['','a','b','a,b','a,b,c','a,c']
|
824
|
+
self.fs = fs
|
825
|
+
self.lagIdxs = msec2Idxs([tmin_ms,tmax_ms],fs)
|
826
|
+
self.lagIdxs_ts = torch.Tensor(self.lagIdxs).float().to(device)
|
827
|
+
self.lagTimes = Idxs2msec(self.lagIdxs,fs)
|
828
|
+
nWin = len(self.lagTimes)
|
829
|
+
self.mode = mode
|
830
|
+
self.transformer:torch.nn.Module = transformer
|
831
|
+
self.n_transform_params = len(mode.split(','))
|
832
|
+
self.device = device
|
833
|
+
if isinstance(basisTRFName, str):
|
834
|
+
self.basisTRF:FuncBasisTRF = basisTRFNameMap[basisTRFName](
|
835
|
+
inDim,
|
836
|
+
outDim,
|
837
|
+
self.lagIdxs[0],
|
838
|
+
self.lagIdxs[-1],
|
839
|
+
nBasis,
|
840
|
+
timeshiftLimit_idx = limitOfShift_idx,
|
841
|
+
device=device
|
842
|
+
)
|
843
|
+
elif isinstance(basisTRFName, torch.nn.Module):
|
844
|
+
self.basisTRF:FuncBasisTRF = basisTRFName
|
845
|
+
else:
|
846
|
+
raise ValueError()
|
847
|
+
# self.if_trans_per_outChan = if_trans_per_outChan
|
848
|
+
|
849
|
+
if transformer is None:
|
850
|
+
transInDim, transOutDim, device = self.get_default_transformer_param()
|
851
|
+
self.transformer:torch.nn.Module = CausalConv(transInDim, transOutDim, 2).to(device)
|
852
|
+
|
853
|
+
self.limitOfShift_idx = torch.tensor(limitOfShift_idx)
|
854
|
+
|
855
|
+
@classmethod
|
856
|
+
def parse_trans_params(cls,mode):
|
857
|
+
return mode.split(',')
|
858
|
+
|
859
|
+
@property
|
860
|
+
def tmin_ms(self):
|
861
|
+
return self.lagTimes[0]
|
862
|
+
|
863
|
+
@property
|
864
|
+
def tmax_ms(self):
|
865
|
+
return self.lagTimes[-1]
|
866
|
+
|
867
|
+
@property
|
868
|
+
def inDim(self):
|
869
|
+
return self.basisTRF.inDim
|
870
|
+
|
871
|
+
@property
|
872
|
+
def outDim(self):
|
873
|
+
return self.basisTRF.outDim
|
874
|
+
|
875
|
+
@property
|
876
|
+
def nWin(self):
|
877
|
+
return self.basisTRF.nWin - 2 * self.limitOfShift_idx
|
878
|
+
|
879
|
+
@property
|
880
|
+
def nBasis(self):
|
881
|
+
return self.basisTRF.nBasis
|
882
|
+
|
883
|
+
@property
def extendedTimeLagRange(self):
    '''
    first and last time lag of the lag window after it is extended
    by limitOfShift_idx samples on both sides.

    the lag indices are extended to the left of lagIdxs[0] and to the
    right of lagIdxs[-1], converted with Idxs2msec(..., self.fs), and
    the two extreme values are returned as a tuple (first, last).
    '''
    nShift = self.limitOfShift_idx
    minLagIdx = self.lagIdxs[0]
    maxLagIdx = self.lagIdxs[-1]
    left = np.arange(minLagIdx - nShift, minLagIdx)
    right = np.arange(maxLagIdx + 1, maxLagIdx + 1 + nShift)
    extLag_idx = np.concatenate([left, self.lagIdxs, right])

    #sanity check only: a mismatch is reported but not fatal.
    #written as an explicit comparison (instead of assert inside a
    #bare except) so it still runs under `python -O` and does not
    #swallow KeyboardInterrupt / SystemExit
    nExpected = self.nWin + 2 * nShift
    if len(extLag_idx) != nExpected:
        print(len(extLag_idx), nExpected)

    timelags = Idxs2msec(extLag_idx, self.fs)
    return timelags[0], timelags[-1]
|
896
|
+
|
897
|
+
def get_default_transformer_param(self):
    '''
    arguments used to build the default transformer:
    (input dim, number of transform params as output dim, device)
    '''
    return self.inDim, self.n_transform_params, self.device
|
902
|
+
|
903
|
+
def fitFuncTRF(self, w):
    '''
    scale the weights w by 1 / fs and fit the basisTRF to them
    (no gradient is recorded); returns self for chaining
    '''
    scaledW = w * 1 / self.fs
    with torch.no_grad():
        self.basisTRF.fitTRFs(scaledW)
    return self
|
908
|
+
|
909
|
+
def pickParam(self,paramSeqs,idx):
    '''
    select entry idx along dim 1 of paramSeqs, that dim is dropped
    paramSeqs: (nBatch, nMiddleParam, nOut/1, 1, 1, nSeq)
    '''
    selector = (slice(None), idx, Ellipsis)
    return paramSeqs[selector]
|
912
|
+
|
913
|
+
|
914
|
+
def getTransformParams(self, x, startIdx = None):
    '''
    run the transformer on x and split its output into the three
    transform parameters (aSeq, bSeq, cSeq) that are passed to basisTRF.

    x: input given to self.transformer
    startIdx: optional index tensor, when given, the transformer output
        is gathered along its last dim at these indices, i.e.
        out[b, m, s] = paramSeqs[b, m, startIdx[b, s]]

    which parameters are produced is decided by the tokens in self.mode:
        'a'   -> aSeq is the absolute value of the picked param
        '+-a' -> aSeq is the picked param, sign kept
        no a  -> aSeq is x itself, viewed as x[:, None, :, None, :]
        'b'   -> bSeq, limited to [-limitOfShift_idx, limitOfShift_idx],
                 otherwise 0
        'c'   -> cSeq = 1 + picked param, limited to [0.5, 1.28],
                 otherwise 1

    return: aSeq, bSeq, cSeq
    '''
    paramSeqs = self.transformer(x) #(nBatch, nMiddleParam, nSeq)
    nBatch, nMiddleParam, nSeq = paramSeqs.shape
    if startIdx is not None:
        idxBatch = torch.arange(nBatch)[:, None, None]
        idxMiddleParam = torch.arange(nMiddleParam)[:, None]
        paramSeqs = paramSeqs[idxBatch, idxMiddleParam, startIdx[:, None, :]]
        #the gathered tensor has as many time points as startIdx,
        #which can differ from the length of the transformer output,
        #so the length used by the view below must be refreshed
        nSeq = paramSeqs.shape[-1]

    #(nBatch, n_transform_params, 1, 1, 1, nSeq), this is the most strict case
    #however, it can also be (nBatch, n_transform_params, nOut, 1, 1, nSeq)
    # for different transformation for different channel
    paramSeqs = paramSeqs.view(nBatch, self.n_transform_params, -1, 1, 1, nSeq)
    if paramSeqs.shape[2] != 1:
        assert paramSeqs.shape[2] == self.outDim

    midParamList = self.mode.split(',')
    if midParamList == ['']:
        midParamList = []
    nParamMiss = 0

    if 'a' in midParamList:
        #(nBatch, nOut/1, 1, 1, nSeq)
        aSeq = self.pickParam(paramSeqs, midParamList.index('a'))
        aSeq = torch.abs(aSeq)
    elif '+-a' in midParamList:
        #(nBatch, nOut/1, 1, 1, nSeq)
        aSeq = self.pickParam(paramSeqs, midParamList.index('+-a'))
    else:
        nParamMiss += 1
        #(nBatch, 1, inDim, 1, nSeq)
        aSeq = x[:, None, :, None, :]

    if 'b' in midParamList:
        bSeq = self.pickParam(paramSeqs, midParamList.index('b'))
        #keep the shift within [-limitOfShift_idx, limitOfShift_idx]
        bSeq = torch.maximum(bSeq, - self.limitOfShift_idx)
        bSeq = torch.minimum(bSeq, self.limitOfShift_idx)
    else:
        nParamMiss += 1
        bSeq = 0

    if 'c' in midParamList:
        cSeq = self.pickParam(paramSeqs, midParamList.index('c'))
        #two reasons, cSeq must be larger than 0;
        #if 1 is the optimum, abs will have two x for the optimum,
        # which is not stable
        cSeq = 1 + cSeq
        cSeq = torch.maximum(cSeq, torch.tensor(0.5))
        cSeq = torch.minimum(cSeq, torch.tensor(1.28))
    else:
        nParamMiss += 1
        cSeq = 1

    #every one of a / b / c is either provided by mode or defaulted
    assert (len(midParamList) + nParamMiss) == 3
    return aSeq, bSeq, cSeq
|
975
|
+
|
976
|
+
def forward(self, x, featOnsetIdx = None):
    '''
    x: (nBatch, inDim, nSeq)
    output: TRFs (nBatch, outDim, nWin, nSeq)
    '''
    #transform parameters predicted from x
    transformParams = self.getTransformParams(x, featOnsetIdx)
    aSeq, bSeq, cSeq = transformParams
    #(nBatch, outDim, inDim, nWin, nSeq), then summed over dim 2
    return self.basisTRF(aSeq, bSeq, cSeq).sum(2)
|
996
|
+
|
997
|
+
class ASTRF(torch.nn.Module):
    '''
    the TRF implemented as the convolution sum of temporal responses,
    (i.e., time-aligning the temporal responses at their
    corresponding location, and point-wise summing them).
    It requires a module to generate temporal responses to each
    individual stimulus, and also requires time information to
    displace/align the temporal responses at the right
    indices/location

    limitation: can't do TRF for zscored input, under this condition
    location with no stimulus will be non-zero.


    the core mechanism of this module follows these steps:
    1. generate TRFs using the input param 'x',
    2. determine the onset time of these TRFs within the output time series, using input param 'timeinfo',
    3. align and sum the generated TRFs at their corresponding time location in the output.

    Note:
    1. currently, there are two types of x supported,
        type 1: discrete type, means there is a single x for each time point in the 'timestamp'
        type 2: continuous type, means there is a timeseries of x which has the same length as the output timeseries,
            and when generating TRFs, only part of the x will actually contribute to this process.

    '''

    def __init__(
        self,
        inDim,
        outDim,
        tmin_ms,
        tmax_ms,
        fs,
        trfsGen = None,
        device = 'cpu',
        x_is_timeseries = False
    ):
        '''
        inDim: int, the number of columns of input
        outDim: int, the number of columns of output of ltiTRFGen and trfsGen
        tmin_ms, tmax_ms: the time lag range, tmin_ms must be >= 0
        fs: sampling rate used to convert times to sample indices
        trfsGen: optional user module generating TRFs, moved to device
        device: device of the sub-modules and the bias
        x_is_timeseries: True for the continuous type of x (see class doc)
        '''
        super().__init__()
        assert tmin_ms >= 0
        self.x_is_timeseries = x_is_timeseries
        self.lagIdxs = msec2Idxs([tmin_ms,tmax_ms],fs)
        self.lagTimes = Idxs2msec(self.lagIdxs,fs)
        nWin = len(self.lagTimes)
        self.ltiTRFsGen = LTITRFGen(
            inDim,
            nWin,
            outDim,
            ifAddBiasInForward=False
        ).to(device)
        self.trfsGen:FuncTRFsGen = trfsGen if trfsGen is None else trfsGen.to(device)
        self.fs = fs

        self.bias = None
        #also train bias for the trfsGen provided by the user
        if self.trfsGen is not None:
            self.init_nonLinTRFs_bias(inDim, nWin, outDim, device)

        self.trfAligner = TRFAligner(device)
        self._enableUserTRFGen = False
        self.device = device

    @property
    def inDim(self):
        return self.ltiTRFsGen.inDim

    @property
    def outDim(self):
        return self.ltiTRFsGen.outDim

    @property
    def nWin(self):
        return self.ltiTRFsGen.nWin

    @property
    def tmin_ms(self):
        return self.lagTimes[0]

    @property
    def tmax_ms(self):
        return self.lagTimes[-1]

    def init_nonLinTRFs_bias(self, inDim, nWin, outDim, device):
        '''
        create the trainable bias (one value per output channel) used
        when the user TRF generator is enabled, drawn uniformly from
        [-1/sqrt(inDim * nWin), 1/sqrt(inDim * nWin)]
        '''
        #the bias must live on the requested device, otherwise it
        #cannot be added to the output computed on that device
        self.bias = torch.nn.Parameter(torch.ones(outDim, device = device))
        fan_in = inDim * nWin
        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
        torch.nn.init.uniform_(self.bias, -bound, bound)

    def set_trfs_gen(self, trfsGen):
        '''replace the user TRF generator and re-initialise its bias'''
        self.trfsGen = trfsGen.to(self.device)
        self.init_nonLinTRFs_bias(
            self.inDim,
            self.nWin,
            self.outDim,
            self.device
        )

    def get_params_for_train(self):
        '''parameters of the trfsGen transformer plus the bias'''
        return list(self.trfsGen.transformer.parameters()) + [self.bias]

    def set_linear_weights(self, w, b):
        #w: (nInChan, nLag, nOutChan)
        self.ltiTRFsGen.load_mtrf_weights(
            w,
            b,
            self.fs,
            self.device
        )
        return self

    def get_linear_weights(self):
        w, b = self.ltiTRFsGen.export_mtrf_weights(
            self.fs
        )
        return w, b

    @property
    def if_enable_trfsGen(self):
        return self._enableUserTRFGen

    @if_enable_trfsGen.setter
    def if_enable_trfsGen(self,x):
        assert isinstance(x, bool)
        print('set ifEnableNonLin',x)
        if x and self.trfsGen is None:
            raise ValueError('trfGen is None, cannot be enabled')
        self._enableUserTRFGen = x

    def stop_update_linear(self):
        self.ltiTRFsGen.stop_update_weights()

    def enable_update_linear(self):
        self.ltiTRFsGen.enable_update_weights()

    def forward(self, x, timeinfo):
        '''
        input:
            x: nBatch * [nChan, nSeq],
            timeinfo: nBatch * [2, nSeq]
        output: targetTensor
        '''

        ### record the necessary information of each item in the batch
        ### for x and targetTensor
        nRealLens = [] # length of real output in the batch
        trfOnsetIdxs = [] # the corresponding index of timepoint in the timeinfo in the batch
        for ix, xi in enumerate(x):
            nLenXi = xi.shape[-1]
            if timeinfo[ix] is not None:
                if not self.x_is_timeseries:
                    #discrete x: one x per time point
                    assert timeinfo[ix].shape[-1] == xi.shape[-1]
                #the last onset plus one full TRF window
                nLen = torch.ceil(
                    timeinfo[ix][0][-1] * self.fs
                ).long() + self.nWin
                onsetIdx = torch.round(
                    timeinfo[ix][0,:] * self.fs
                ).long() + self.lagIdxs[0]
            else:
                #no time information: one TRF per sample of x
                nLen = nLenXi
                onsetIdx = torch.tensor(np.arange(nLen)) + self.lagIdxs[0]
            nRealLens.append(nLen)
            trfOnsetIdxs.append(onsetIdx)

        nGlobLen = max(nRealLens)
        x = seqLast_pad_zero(x)
        #-1 marks the padded (non-existing) onsets
        trfOnsetIdxs = seqLast_pad_zero(trfOnsetIdxs, value = -1)

        # if x is time series
        featOnsetIdxs = None
        if self.x_is_timeseries:
            # featIdxs is the index where we get the feat for generating TRFs
            featOnsetIdxs = trfOnsetIdxs.detach().clone()
            validMask = featOnsetIdxs != -1
            featOnsetIdxs[validMask] = featOnsetIdxs[validMask] - self.lagIdxs[0]

        #TRFs shape: (nBatch, outDim, nWin, nSeq)
        TRFs = self.get_trfs(x, featOnsetIdxs)

        #targetTensor shape: (nBatch,outDim,nRealLen)
        targetTensor = self.trfAligner(TRFs,trfOnsetIdxs,nGlobLen)

        if self.if_enable_trfsGen:
            bias = self.bias
        else:
            bias = self.ltiTRFsGen.bias
        return targetTensor + bias.view(-1,1)

    def get_trfs(self, x, featOnsetIdxs = None):
        '''TRFs from the user generator if it is enabled, else from the linear one'''
        if self.if_enable_trfsGen:
            return self.trfsGen(x, featOnsetIdxs)
        return self.ltiTRFsGen(x)
|
1204
|
+
|
1205
|
+
|
1206
|
+
class ASCNNTRF(ASTRF):
    '''
    perform CNNTRF within intervals, and change the weights
    between intervals.

    the input is cut into segments at the switch onsets, each segment
    is convolved with its own TRF, and the results are summed into
    the output at the location of the segment.
    '''

    def __init__(
        self,
        inDim,
        outDim,
        tmin_ms,
        tmax_ms,
        fs,
        trfsGen = None,
        device = 'cpu'
    ):
        '''
        inDim: int, the number of channels of input
        outDim: int, the number of channels of output
        tmin_ms, tmax_ms: the time lag range (tmin_ms may be negative here)
        fs: sampling rate used to convert times to sample indices
        trfsGen: optional user module generating TRFs, moved to device
        device: device of the sub-modules, the bias and the output
        '''
        torch.nn.Module.__init__(self)
        #inDim / outDim / tmin_ms / tmax_ms are read-only properties
        #inherited from ASTRF (served by ltiTRFsGen and lagTimes),
        #assigning to them raises AttributeError, so they are not
        #stored here
        self.lagIdxs = msec2Idxs([tmin_ms,tmax_ms],fs)
        self.lagTimes = Idxs2msec(self.lagIdxs,fs)
        self.tmin_idx = self.lagIdxs[0]
        self.tmax_idx = self.lagIdxs[-1]
        nWin = len(self.lagTimes)
        self._nWin = nWin
        self.ltiTRFsGen = LTITRFGen(
            inDim,
            nWin,
            outDim,
            ifAddBiasInForward=False
        ).to(device)
        self.trfsGen = trfsGen if trfsGen is None else trfsGen.to(device)
        self.fs = fs

        self.bias = None
        #also train bias for the trfsGen provided by the user
        if self.trfsGen is not None:
            self.bias = torch.nn.Parameter(torch.ones(outDim, device = device))
            fan_in = inDim * nWin
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            torch.nn.init.uniform_(self.bias, -bound, bound)

        self._enableUserTRFGen = False
        self.device = device

    @property
    def nWin(self):
        return self._nWin

    def getTRFs(self, ctx):
        '''
        one TRF (conv kernel) per segment.

        ctx: the context of the segments, either the list returned by
            defaultCtx (one item per segment) or a tensor
            (nBatch, inDim, nSeq) whose last dim counts the TRFs
        '''
        #note the difference between the ASTRF and ASCNNTRF at here
        #the LTITRF is not multiplied with x
        if self.if_enable_trfsGen:
            #how to decide how much trfs to return????
            #NOTE(review): ctx is handed to trfsGen as it is, a list
            # ctx (the default) is presumably not supported by it
            TRFs = self.trfsGen(ctx)
            #TRFs (nBatch, outDim, nWin, nSeq)
            TRFs = TRFs[0] #(outDim, nWin, nSeq)
            #(nSeq, outDim, 1, nWin)
            TRFs = TRFs.permute(2, 0, 1)[..., None, :]
        else:
            if isinstance(ctx, (list, tuple)):
                nTRFs = len(ctx)
            else:
                nTRFs = ctx.shape[-1]
            # TRFs nTRFs * (nChanOut, nChanIn, nWin)
            TRFs = [self.ltiTRFsGen.weight] * nTRFs
        return TRFs

    def getTRFSwitchOnsets(self, x):
        #X: nBatch * [nChan, nSeq]
        #fixed onsets, x is not used; a new list is returned each call
        return [0, 200, 300, 400, 500, 600, 800, 1000]

    def defaultCtx(self, switchOnsets, x):
        '''
        cut x (nBatch, nChan, nSeq) along the last dim at switchOnsets,
        the last segment runs until the end of x
        '''
        #None (not -1) as the last bound, so that the final sample is
        #kept, same as the segmentation done in forward
        bounds = list(switchOnsets) + [None]
        return [
            x[:, :, start:end]
            for start, end in zip(bounds[:-1], bounds[1:])
        ]

    def forward(self, x, timeInfo = None, ctx = None):
        '''
        currently only support single batch

        x: (nBatch, nChan, nSeq)
        timeInfo: optional list of the start index of each segment,
            it is not modified
        ctx: optional context for getTRFs, defaultCtx is used if None
        '''
        if timeInfo is None:
            TRFSwitchOnsets = self.getTRFSwitchOnsets(x)
        else:
            #work on a copy, the list of the caller must stay untouched
            TRFSwitchOnsets = list(timeInfo)
        if ctx is None:
            ctx = self.defaultCtx(TRFSwitchOnsets, x)
        nBatch, nChan, nSeq = x.shape
        TRFs = self.getTRFs(ctx)
        TRFsFlip = [trf.flip([-1]) for trf in TRFs]
        #None closes the last segment at the end of x
        segmentBounds = TRFSwitchOnsets + [None]

        nPaddedOutput = nSeq \
            + max(-self.tmin_idx, 0) \
            + max( self.tmax_idx, 0)

        output = torch.zeros(
            nBatch,
            self.outDim,
            nPaddedOutput,
            device = self.device
        )

        #segment startIdx offset
        startOffset = self.tmin_idx
        #global startIdx offset
        offset = min(0, self.tmin_idx)

        realOffset = startOffset - offset

        for idx, TRFFlip in enumerate(TRFsFlip):
            t_start = segmentBounds[idx]
            t_end = segmentBounds[idx+1]
            t_x = x[..., t_start : t_end]
            segment = self.trfCNN(t_x,TRFFlip)
            t_startReal = t_start + realOffset
            t_endReal = t_startReal + segment.shape[-1]
            output[:,:,t_startReal : t_endReal] += segment

        #decide how to crop the output based on tmin and tmax
        startIdx = max(-self.tmin_idx, 0)
        lenOutput = output.shape[-1]
        endIdx = lenOutput - max( self.tmax_idx, 0)

        #same rule as ASTRF.forward: the bias trained with the user
        #trfsGen is used when that generator is enabled
        if self.if_enable_trfsGen:
            bias = self.bias
        else:
            bias = self.ltiTRFsGen.bias
        return output[..., startIdx: endIdx] + bias.view(-1, 1)

    def trfCNN(self, x, TRFFlip):
        '''
        full convolution of one segment with one TRF.

        x: (nBatch, nChan, nSubSeq)
        TRFFlip: the TRF with its kernel dimension flipped, for conv1d
        output: last dim is nSubSeq + nWin - 1
        '''
        #need to first padding
        #timelag doesn't influence how much to pad
        # but the offset of the startIdx
        nWin = TRFFlip.shape[-1]
        x = torch.nn.functional.pad(x, (nWin-1, nWin-1))
        #then do the conv
        return torch.nn.functional.conv1d(x, TRFFlip)
|
1358
|
+
|