nntrf 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1358 @@
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+ from scipy.stats import pearsonr
5
+ from torch.nn.functional import pad, fold
6
+ from .linear import msec2Idxs, Idxs2msec, CPadOrCrop1D
7
+ try:
8
+ import skfda
9
+ except:
10
+ skfda = None
11
+
12
+ try:
13
+ from matplotlib import pyplot as plt
14
+ except:
15
+ plt = None
16
+
17
+ try:
18
+ from mtrf.model import TRF
19
+ except:
20
+ TRF = None
21
+
22
+
23
def fit_forward_mtrf(stim, resp, fs, tmin_ms, tmax_ms, regularization, k):
    '''
    Fit a forward (encoding) TRF with the optional ``mtrf`` package.

    Parameters
    ----------
    stim, resp :
        stimulus and response data, forwarded unchanged to ``TRF.train``.
    fs :
        sampling rate forwarded to ``TRF.train``.
    tmin_ms, tmax_ms :
        lag window in milliseconds (converted to seconds before training).
    regularization :
        regularization value(s) forwarded to ``TRF.train``.
    k :
        forwarded to ``TRF.train`` as its ``k`` keyword.

    Returns
    -------
    (weights, bias) of the trained ``TRF`` object.

    Raises
    ------
    ImportError
        if ``mtrf`` could not be imported when this module was loaded.
    '''
    if TRF is None:
        # the module-level import of mtrf failed; fail loudly instead of
        # "'NoneType' object is not callable"
        raise ImportError('mtrf should be installed to use fit_forward_mtrf')
    trf = TRF(direction=1)
    trf.train(stim, resp, fs, tmin_ms / 1e3, tmax_ms / 1e3, regularization, k=k)
    return trf.weights, trf.bias
27
+
28
def seqLast_pad_zero(seq, value = 0):
    '''
    Right-pad tensors along their last axis to a common length and stack them.

    Each tensor in ``seq`` is padded at the end of its last dimension with
    ``value`` until it is as long as the longest one; the padded tensors are
    stacked along a new leading dimension.
    '''
    targetLen = max(t.shape[-1] for t in seq)
    padded = [pad(t, (0, targetLen - t.shape[-1]), value=value) for t in seq]
    return torch.stack(padded, 0)
34
+
35
class CausalConv(torch.nn.Module):
    '''
    1-D convolution that only sees the current and past samples.

    The input is zero-padded on the left by ``dilation * (nKernel - 1)``
    samples before an unpadded ``Conv1d``. The output therefore has the same
    length as the input, and output sample t depends only on input samples <= t.
    '''

    def __init__(self, inDim, outDim, nKernel, dilation = 1):
        super().__init__()
        self.nKernel = nKernel
        self.dilation = dilation
        self.conv = torch.nn.Conv1d(inDim, outDim, nKernel, dilation=dilation)

    def forward(self, x):
        '''
        x: (nBatch, nChan, nSeq)
        returns: (nBatch, nOutChan, nSeq)
        '''
        leftPad = self.dilation * (self.nKernel - 1)
        padded = torch.nn.functional.pad(x, (leftPad, 0))
        return self.conv(padded)
58
+ ## experiment module end
59
+
60
class TRFAligner(torch.nn.Module):
    '''
    Place per-event TRFs on an output timeline and overlap-add them.

    For each event s of batch item b, the response ``TRFs[b, :, :, s]``
    (nWin samples per output channel) is written starting at time index
    ``sourceIdx[b, s]``. All responses are summed where they overlap.
    '''

    def __init__(self, device):
        super().__init__()
        self.device = device

    def forward(self, TRFs, sourceIdx, nRealLen):
        '''
        Parameters
        ----------
        TRFs : tensor, (nBatch, outDim, nWin, nSeq)
            per-event responses (e.g. the output of a TRF generator).
        sourceIdx : integer tensor, (nBatch, nSeq)
            onset index in the output timeline for each event.
        nRealLen : int or 0-d integer tensor
            requested output length; extended to ``sourceIdx.max() + 1`` if
            some onset falls beyond it.

        Returns
        -------
        tensor, (nBatch, outDim, nRealLen)
            the overlap-added responses.

        The scattered (pre-fold) tensor is also kept in ``self.cache`` with
        shape (nBatch, outDim * nWin, nRealLen).
        '''
        nBatch, outDim, nWin, nSeq = TRFs.shape
        nRealLen = int(nRealLen)
        if sourceIdx.numel() > 0:
            # use the global maximum: onsets of a row need not be sorted and
            # zero padding makes the last column unreliable
            nRealLen = max(nRealLen, int(sourceIdx.max()) + 1)

        cache = torch.zeros(
            (nBatch, outDim, nWin, nRealLen),
            device=self.device,
            dtype=TRFs.dtype,
        )
        idxDevice = cache.device
        idxBatch = torch.arange(nBatch, device=idxDevice)[:, None, None, None]
        idxChan = torch.arange(outDim, device=idxDevice)[:, None, None]
        idxWin = torch.arange(nWin, device=idxDevice)[:, None]
        idxSrc = sourceIdx.to(device=idxDevice, dtype=torch.long)[:, None, None, :]

        # accumulate=True: events sharing an onset (including zero-padded
        # entries pointing at index 0) are summed instead of overwriting
        # each other in an unspecified order
        cache = cache.index_put(
            (idxBatch, idxChan, idxWin, idxSrc), TRFs, accumulate=True
        )
        # (nBatch, outDim * nWin, nRealLen)
        cache = cache.view(nBatch, -1, nRealLen)
        self.cache = cache
        # fold adds kernel element w of block l to output position l + w,
        # which is exactly the overlap-add along time
        foldOutputSize = (nRealLen + nWin - 1, 1)
        foldKernelSize = (nWin, 1)
        # (nBatch, outDim, nRealLen + nWin - 1, 1)
        output = fold(cache, foldOutputSize, foldKernelSize)
        # (nBatch, outDim, nRealLen)
        return output[:, :, :nRealLen, 0]
108
+
109
class LTITRFGen(torch.nn.Module):
    '''
    Linear time-invariant TRF generator.

    Holds a kernel ``weight`` (outDim, inDim, nWin) and a ``bias`` (outDim).
    For x of shape (nBatch, inDim, nSeq), forward returns the TRF scaled by
    the input at every time step and summed over input channels:
        TRFs[b, o, w, s] = sum_i weight[o, i, w] * x[b, i, s]
    plus ``inDim * bias[o]`` when ``ifAddBiasInForward`` is True.
    '''
    def __init__(self, inDim, nWin, outDim, ifAddBiasInForward = True):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(outDim, inDim, nWin))
        self.bias = torch.nn.Parameter(torch.ones(outDim))
        # uniform init in [-1/sqrt(inDim * nWin), 1/sqrt(inDim * nWin)]
        k = 1 / (inDim * nWin)
        lower = - np.sqrt(k)
        upper = np.sqrt(k)
        torch.nn.init.uniform_(self.weight, a = lower, b = upper)
        torch.nn.init.uniform_(self.bias, a = lower, b = upper)
        self.ifAddBiasInForward = ifAddBiasInForward

    @property
    def outDim(self):
        return self.weight.shape[0]

    @property
    def inDim(self):
        return self.weight.shape[1]

    @property
    def nWin(self):
        return self.weight.shape[2]

    def forward(self, x):
        '''
        x: (nBatch, inDim, nSeq)
        returns: (nBatch, outDim, nWin, nSeq)
        '''
        assert x.ndim == 3
        # contract over inDim directly instead of materializing the
        # (nBatch, outDim, inDim, nWin, nSeq) elementwise product;
        # promote dtypes like the broadcasted multiply would
        dtype = torch.promote_types(self.weight.dtype, x.dtype)
        TRFs = torch.einsum('oiw,bis->bows', self.weight.to(dtype), x.to(dtype))
        if self.ifAddBiasInForward:
            # the bias is added once per input channel before the sum over
            # inDim, which amounts to inDim * bias on every window/time step
            TRFs = TRFs + self.inDim * self.bias[:, None, None]
        return TRFs

    def load_mtrf_weights(self, w, b, fs, device):
        '''
        Load weights in mTRF layout, scaled by 1 / fs.

        w: (nInChan, nLag, nOutChan) array
        b: array whose first row holds the bias, e.g. shape (1, nOutChan)
        Returns self.
        '''
        b = b[0]
        w = w * 1 / fs
        b = b * 1 / fs
        w = torch.as_tensor(np.asarray(w), dtype=torch.float32, device=device)
        b = torch.as_tensor(np.asarray(b), dtype=torch.float32, device=device)
        # (nOutChan, nInChan, nLag)
        w = w.permute(2, 0, 1)
        with torch.no_grad():
            self.weight = torch.nn.Parameter(w)
            self.bias = torch.nn.Parameter(b)
        return self

    def export_mtrf_weights(self, fs):
        '''
        Export weights in mTRF layout, scaled by fs.

        Returns (w, b) as numpy arrays: w (nInChan, nLag, nOutChan), b (1, nOutChan).
        '''
        with torch.no_grad():
            # (nInChan, nLag, nOutChan)
            w = self.weight.cpu().detach().permute(1, 2, 0).numpy()
            b = self.bias.cpu().detach().numpy().reshape(1, -1)
            w = w * fs
            b = b * fs
        return w, b

    def stop_update_weights(self):
        '''Freeze weight and bias and drop their gradients.'''
        self.weight.requires_grad_(False)
        self.weight.grad = None
        self.bias.requires_grad_(False)
        self.bias.grad = None

    def enable_update_weights(self):
        '''Unfreeze all parameters and reset weight/bias gradients to zeros.'''
        self.requires_grad_(True)
        self.weight.grad = torch.zeros_like(self.weight)
        self.bias.grad = torch.zeros_like(self.bias)
180
+
181
class WordTRFEmbedGenTokenizer:
    '''
    Map batches of word sequences to LongTensors of word ids.

    ``wordsDict`` maps each word to its integer id (an unknown word raises
    KeyError). One tensor is produced per sequence and placed on ``device``.
    '''
    def __init__(self, wordsDict, device):
        self.wordsDict = wordsDict
        self.device = device

    def __call__(self, words):
        lookup = self.wordsDict
        return [
            torch.tensor(
                [lookup[w] for w in sentence],
                dtype=torch.long,
                device=self.device,
            )
            for sentence in words
        ]
196
+
197
class WordTRFEmbedGen(torch.nn.Module):
    '''
    Generate one TRF per word token from a learned embedding.

    Each token id is embedded into ``nWin * hiddenDim`` values, reshaped to
    (hiddenDim, nWin) and projected from hiddenDim to outDim by a linear
    layer. Id 0 is the embedding's padding index (presumably the ids in
    wordsDict start at 1 — confirm against the caller).
    '''

    def __init__(
        self,
        outDim,
        hiddenDim,
        tmin_ms,
        tmax_ms,
        fs,
        wordsDict,
        device
    ):
        super().__init__()
        self.outDim = outDim
        self.hiddenDim = hiddenDim
        self.tmin_ms = tmin_ms
        self.tmax_ms = tmax_ms
        self.fs = fs
        self.lagIdxs = msec2Idxs([tmin_ms, tmax_ms], fs)
        self.lagTimes = Idxs2msec(self.lagIdxs, fs)
        self.nWin = len(self.lagTimes)
        self.embedding_dim = self.nWin * hiddenDim

        self.device = device
        self.wordsDict = wordsDict
        # one extra row so that id 0 can serve as padding
        self.embedding = torch.nn.Embedding(
            len(wordsDict) + 1,
            self.embedding_dim,
            padding_idx=0,
        ).to(device)
        self.proj = torch.nn.Linear(self.hiddenDim, self.outDim, device=device)

    def forward(self, batchTokens):
        '''
        batchTokens: list of 1-D LongTensors of token ids (one per sequence);
            they are right-padded with 0 to a common length here.
        returns: (nBatch, outDim, nWin, nSeq)
        '''
        # (nBatch, nSeq)
        tokens = seqLast_pad_zero(batchTokens)
        # (nBatch, nSeq, hiddenDim * nWin)
        embedded = self.embedding(tokens)
        lead = embedded.shape[:2]
        # (nBatch, nSeq, hiddenDim, nWin) -> (nBatch, nSeq, nWin, hiddenDim)
        perWin = embedded.reshape(*lead, self.hiddenDim, self.nWin).transpose(2, 3)
        # (nBatch, nSeq, nWin, outDim) -> (nBatch, outDim, nWin, nSeq)
        return self.proj(perWin).permute(0, 3, 2, 1)
247
+
248
+
249
class CustomKernelCNNTRF(torch.nn.Module):
    '''
    Dilated 1-D convolution whose kernel is supplied from outside.

    The kernel is either passed to ``forward`` or obtained from the generator
    registered with ``setTRFGen`` (via its ``TRF()`` method). Before the
    convolution, the input goes through ``CPadOrCrop1D`` built from the lag range.
    '''

    def __init__(self, inDim, outDim, tmin_ms, tmax_ms, fs, groups = 1, dilation = 1):
        super().__init__()
        self.tmin_ms = tmin_ms
        self.tmax_ms = tmax_ms
        self.fs = fs
        self.lagIdxs = msec2Idxs([tmin_ms, tmax_ms], fs)
        self.lagTimes = Idxs2msec(self.lagIdxs, fs)
        self.tmin_idx = self.lagIdxs[0]
        self.tmax_idx = self.lagIdxs[-1]
        # kernel taps needed to span every lag at this dilation; must be whole
        nKernels = (len(self.lagTimes) - 1) / dilation + 1
        assert np.ceil(nKernels) == np.floor(nKernels)
        self.nWin = int(nKernels)
        self.oPadOrCrop = CPadOrCrop1D(self.tmin_idx, self.tmax_idx)
        self.groups = groups
        self.dilation = dilation
        self.inDim = inDim
        self.outDim = outDim

        bound = np.sqrt(1 / (inDim * self.nWin))
        self.bias = torch.nn.Parameter(torch.ones(outDim))
        torch.nn.init.uniform_(self.bias, a=-bound, b=bound)

    def setTRFGen(self, trfGen):
        '''Register the module whose ``TRF()`` provides the default kernel.'''
        self.trfGen = trfGen

    def forward(self, x, weight = None):
        '''
        x: input signal, passed through ``oPadOrCrop`` and then ``conv1d``
        weight: (outDim, inDim, nWin) kernel; defaults to ``self.trfGen.TRF()``
        '''
        kernel = self.trfGen.TRF() if weight is None else weight
        assert self.nWin == kernel.shape[2]
        assert self.outDim == kernel.shape[0]
        assert self.inDim == kernel.shape[1]
        return torch.nn.functional.conv1d(
            self.oPadOrCrop(x),
            kernel,
            bias=self.bias,
            dilation=self.dilation,
            groups=self.groups,
        )
294
+
295
class FuncBasisTRF(torch.nn.Module):
    '''
    Base class for TRFs expressed through basis functions of the time lag.

    It stores two integer lag grids, each shaped (1, 1, 1, nWin, 1):
    ``time_embedding`` for [tmin_idx, tmax_idx], and ``time_embedding_ext``,
    which widens that range by ``timeshiftLimit_idx`` on both sides. It also
    stores a ``TRFs`` buffer (outDim, inDim, nWin_ext) that subclasses fill in
    ``fitTRFs``. Subclasses implement ``forward``, ``fitTRFs``, ``vis`` and
    ``nBasis``.
    '''

    def __init__(self, inDim, outDim, tmin_idx, tmax_idx, timeshiftLimit_idx ,device) -> None:
        super().__init__()
        self.timeshiftLimit_idx = torch.tensor(timeshiftLimit_idx)
        self.time_embedding = self.get_time_embedding(tmin_idx, tmax_idx, device)
        extMin = tmin_idx - timeshiftLimit_idx
        extMax = tmax_idx + timeshiftLimit_idx
        self.time_embedding_ext = self.get_time_embedding(extMin, extMax, device)
        nWinExt = self.time_embedding_ext.shape[-2]
        self.register_buffer('TRFs', torch.zeros((outDim, inDim, nWinExt), device=device))

    def TRF(self):
        '''
        Return the untransformed TRF, i.e. ``forward(1, 0, 1)`` at batch 0 and
        sequence position 0 — (outDim, inDim, nWin) for subclasses whose
        forward keeps the inDim axis.
        '''
        return self.forward(1, 0, 1)[0, ..., 0]

    @property
    def inDim(self):
        return self.TRFs.shape[1]

    @property
    def outDim(self):
        return self.TRFs.shape[0]

    @property
    def nWin(self):
        return self.TRFs.shape[2]

    @property
    def nBasis(self):
        raise NotImplementedError

    def vis(self):
        raise NotImplementedError

    def forward(self, a, b, c):
        raise NotImplementedError

    def fitTRFs(self, TRFs):
        raise NotImplementedError

    def get_time_embedding(self, tmin_idx, tmax_idx, device = 'cpu'):
        '''Integer lags tmin_idx..tmax_idx (inclusive), shaped (1, 1, 1, nWin, 1).'''
        lags = torch.arange(tmin_idx, tmax_idx + 1, device=device)
        return lags.view(1, 1, 1, -1, 1)

    @property
    def timelag_idx(self,):
        return self.time_embedding.detach().cpu().squeeze()

    @property
    def timelag_idx_ext(self,):
        return self.time_embedding_ext.detach().cpu().squeeze()
353
+
354
+
355
def build_gaussian_response(x, mu, sigma):
    '''
    Evaluate exp(-(x - mu)^2 / (2 * sigma^2)) for every basis function.

    x:     (nBatch, 1, 1, nWin, nSeq) lag positions
    mu:    (nBasis,) centres
    sigma: (nBasis, outDim, inDim) widths
    returns: (nBatch, nBasis, outDim, inDim, nWin, nSeq)
    '''
    # x -> (nBatch, 1, 1, 1, nWin, nSeq); mu -> (nBasis, 1, 1, 1, 1)
    diff = x.unsqueeze(1) - mu.reshape(*mu.shape, 1, 1, 1, 1)
    # sigma -> (nBasis, outDim, inDim, 1, 1)
    width = sigma.reshape(*sigma.shape, 1, 1)
    return torch.exp(-diff ** 2 / (2 * width ** 2))
369
+
370
def solve_coef(gaussresps, trf):
    '''
    Least-squares fit of ``trf`` by the basis responses plus a constant term.

    gaussresps: (nBasis, nWin); trf: (nWin,)
    returns: (nBasis + 1,) coefficients; the last one multiplies the constant.
    '''
    constantRow = np.ones((1, trf.shape[0]))
    design = np.concatenate((gaussresps, constantRow), axis=0).T
    solution, *_ = np.linalg.lstsq(design, trf, rcond=None)
    return solution
375
+
376
class GaussianBasisTRF(FuncBasisTRF):
    '''
    TRF represented as a weighted sum of Gaussian bumps plus an optional constant:

        TRF(t) = sum_k coefs[k] * exp(-(t - mu[k])^2 / (2 sigma[k]^2)) + coefs[-1]

    ``coefs`` (nBasis + 1, outDim, inDim) and ``sigma`` (nBasis, outDim, inDim)
    are trainable; sigma is clamped to [sigmaMin, sigmaMax] whenever it is
    used. ``mu`` (nBasis,) is a fixed buffer; by default it is evenly spaced
    strictly inside the extended lag range.
    '''

    def __init__(
        self,
        inDim,
        outDim,
        tmin_idx,
        tmax_idx,
        nBasis,
        timeshiftLimit_idx = 0,
        sigmaMin = 6.4,
        sigmaMax = 6.4,
        ifSumInDim = False,
        device = 'cpu',
        mu = None,
        sigma = None,
        include_constant_term = True
    ):
        super().__init__(inDim, outDim, tmin_idx, tmax_idx, timeshiftLimit_idx, device)
        ### Fittable Parameters
        # (nBasis + 1, outDim, inDim); the last entry is the constant term
        coefs = torch.ones((nBasis + 1, outDim, inDim), device = device, dtype = torch.float32)
        torch.nn.init.kaiming_uniform_(coefs, a=math.sqrt(5))
        self.coefs = torch.nn.Parameter(coefs)

        # (nBasis, outDim, inDim)
        if sigma is not None:
            assert len(sigma) == nBasis
            sigma = torch.as_tensor(sigma, dtype = torch.float32).to(device)
            if sigma.ndim == 1:
                # one width per basis: share it across all (outDim, inDim)
                # pairs, the layout build_gaussian_response expects
                sigma = sigma[:, None, None].expand(nBasis, outDim, inDim).clone()
        else:
            sigma = torch.ones(nBasis, outDim, inDim, device = device, dtype = torch.float32) * (sigmaMin + sigmaMax) / 2
        self.sigma = torch.nn.Parameter(sigma)

        ### Fixed Values
        time_embedding_ext = self.time_embedding_ext.reshape(-1)
        tmin_idx_ext, tmax_idx_ext = time_embedding_ext[0], time_embedding_ext[-1]
        if mu is not None:
            assert len(mu) == nBasis
            mu = torch.as_tensor(mu, dtype = torch.float32).to(device)
        else:
            # nBasis centres evenly spaced strictly inside the extended range
            mu = torch.linspace(
                tmin_idx_ext.item(), tmax_idx_ext.item(), nBasis + 2, device = device
            )[1:-1]
        self.register_buffer('mu', mu)

        self.register_buffer('sigmaMin', torch.tensor(sigmaMin, device = device))
        self.register_buffer('sigmaMax', torch.tensor(sigmaMax, device = device))
        self.ifSumInDim = ifSumInDim
        self.include_constant_term = include_constant_term
        self.device = device

    def _clamped_sigma(self):
        # sigma limited to [sigmaMin, sigmaMax]; used both for fitting and evaluation
        sigma = torch.maximum(self.sigma, self.sigmaMin)
        return torch.minimum(sigma, self.sigmaMax)

    def vec_sum(self, x):
        return self.vec_gauss_sum(x)

    def vec_gauss_sum(self, x):
        '''
        x: (nBatch, 1, 1, nWin, nSeq) lag positions (possibly warped)
        returns: (nBatch, outDim, inDim, nWin, nSeq)
        '''
        # (nBatch, nBasis, outDim, inDim, nWin, nSeq)
        gaussResps = build_gaussian_response(x, self.mu, self._clamped_sigma())
        # (nBasis + 1, outDim, inDim, 1, 1)
        coefs = self.coefs[..., None, None]
        # weighted sum over the basis axis -> (nBatch, outDim, inDim, nWin, nSeq)
        w_gauss_resps = (coefs[:-1] * gaussResps).sum(-5)
        if self.include_constant_term:
            # coefs[-1]: (outDim, inDim, 1, 1)
            w_gauss_resps = w_gauss_resps + coefs[-1]
        return w_gauss_resps

    def forward(self, a, b, c):
        '''
        a, b, c: amplitude, shift and scale of the lag axis (tensors
            broadcastable against the output, or plain numbers).
        returns: (nBatch, outDim, inDim, nWin, nSeq), or
                 (nBatch, outDim, nWin, nSeq) when ifSumInDim is set.
        '''
        # (nBatch, 1, 1, nWin, nSeq): transformed lag positions
        x = c * (self.time_embedding - b)
        wGaussResps = a * self.vec_gauss_sum(x)
        if self.ifSumInDim:
            wGaussResps = wGaussResps.sum(-3)
        return wGaussResps

    @property
    def nBasis(self):
        return self.sigma.shape[0]

    def fitTRFs(self, TRFs):
        '''
        Fit ``coefs`` by least squares so that the basis reproduces given TRFs.

        TRFs: numpy array of mtrf weights, shape (nInDim, nLags, nOutput);
            nLags must equal the extended window ``self.nWin``.
        The TRFs are also stored in the ``TRFs`` buffer.
        Raises AssertionError if a reconstruction correlates with its target
        below 0.99 (Pearson r, rounded to two decimals).
        '''
        TRFs = torch.from_numpy(TRFs)
        # (outDim, inDim, nWin)
        TRFs = TRFs.permute(2, 0, 1)
        self.TRFs[:, :, :] = TRFs.to(self.device)[:, :, :]
        TRFs_np = TRFs.numpy()
        with torch.no_grad():
            # (nBasis, outDim, inDim, nWin)
            gaussResps = build_gaussian_response(
                self.time_embedding_ext,
                self.mu,
                self._clamped_sigma(),
            )[0, ..., 0].cpu().numpy()

        nWin = TRFs.shape[2]
        assert nWin == self.nWin
        # (nBasis + 1, outDim, inDim)
        coefs = np.zeros(self.coefs.shape)
        for i in range(self.outDim):
            for j in range(self.inDim):
                t_trf = TRFs_np[i, j, :]
                # (nBasis, nWin)
                t_gauss = gaussResps[:, i, j, :]
                if self.include_constant_term:
                    coefs[:, i, j] = solve_coef(t_gauss, t_trf)
                else:
                    # no constant term in the model: fit the bumps only and
                    # leave the constant coefficient at 0
                    coefs[:-1, i, j] = np.linalg.lstsq(t_gauss.T, t_trf, rcond=None)[0]

        with torch.no_grad():
            self.coefs[:, :, :] = torch.from_numpy(coefs)
            # (outDim, inDim, nWin)
            torchTRFs = self.vec_gauss_sum(self.time_embedding_ext).cpu().numpy()[0, ..., 0]
        for j in range(self.outDim):
            for i in range(self.inDim):
                r = pearsonr(torchTRFs[j, i, :], TRFs_np[j, i, :])[0]
                # round to 2 decimals: np.around's default of 0 decimals
                # would accept any r >= 0.5
                assert np.around(r, 2) >= 0.99

    def vis(self, fs = None):
        '''Plot target TRFs (top) and their reconstructions (bottom); returns the figure.'''
        if plt is None:
            raise ValueError('matplotlib should be installed')
        with torch.no_grad():
            FTRFs = self.vec_gauss_sum(self.time_embedding_ext).cpu()[0, ..., 0]
        fig, axs = plt.subplots(2)
        fig.suptitle('top: original TRF, bottom: reconstructed TRF')
        if fs is None:
            timelag = self.timelag_idx_ext
        else:
            timelag = self.timelag_idx_ext.numpy() / fs
        for j in range(self.outDim):
            for i in range(self.inDim):
                axs[0].plot(timelag, self.TRFs[j, i, :].cpu())
                axs[1].plot(timelag, FTRFs[j, i, :].cpu())
        return fig
574
+
575
+
576
+
577
class FourierBasisTRF(FuncBasisTRF):
    '''
    TRF represented by a truncated Fourier series over the extended lag range.

        TRF(t) = coefs[0] * phi0
                 + sum_n [coefs[2n-1] sin(2 pi n t / T) + coefs[2n] cos(2 pi n t / T)] / sqrt(T / 2)

    with T = nWin_ext - 1, phi0 = 1 / sqrt(T) and ``coefs`` of shape
    (nOutChan, nInChan, nBasis). ``coefs`` is a trainable Parameter when
    ``if_fit_coefs`` is True, otherwise a buffer filled by ``fitTRFs``.
    '''

    def __init__(
        self,
        nInChan,
        nOutChan,
        tmin_idx,
        tmax_idx,
        nBasis,
        timeshiftLimit_idx,
        device = 'cpu',
        if_fit_coefs = False
    ):
        super().__init__(nInChan, nOutChan, tmin_idx, tmax_idx, timeshiftLimit_idx, device)
        # zeros rather than torch.empty: an unfitted buffer would otherwise
        # contain uninitialized memory
        coefs = torch.zeros((nOutChan, nInChan, nBasis), device=device)
        if if_fit_coefs:
            torch.nn.init.kaiming_uniform_(coefs, a=math.sqrt(5))
            self.coefs = torch.nn.Parameter(coefs)
        else:
            self.register_buffer('coefs', coefs)
        # period of the series: span of the extended lag grid
        self.T = self.nWin - 1
        self.device = device
        maxN = nBasis // 2
        # harmonic numbers 1..maxN; a non-persistent buffer follows .to(device)
        # without adding a state_dict entry
        self.register_buffer('seqN', torch.arange(1, maxN + 1, device=self.device), persistent=False)

    @property
    def nBasis(self):
        return self.coefs.shape[2]

    def fitTRFs(self, TRFs):
        '''
        Project given TRFs onto the Fourier basis with scikit-fda.

        TRFs: numpy array of mtrf weights, shape (nInDim, nLags, nOutput);
            nLags must equal the extended window ``self.nWin``.
        Stores the TRFs in the ``TRFs`` buffer, writes the fitted coefficients
        to ``coefs`` and checks that the reconstruction matches.

        Raises ImportError if skfda is unavailable, AssertionError if a fit
        correlates with its target below 0.99 (rounded to two decimals) or
        disagrees with skfda's own evaluation.
        '''
        if skfda is None:
            raise ImportError('scikit-fda (skfda) should be installed to fit Fourier basis TRFs')
        TRFs = torch.from_numpy(TRFs)
        # (outDim, inDim, nWin)
        TRFs = TRFs.permute(2, 0, 1)
        TRFs_np = TRFs.numpy()
        grid_points = self.time_embedding_ext.squeeze().cpu().numpy()
        fd_basis_s = []
        # no_grad: coefs may be a leaf Parameter that is written in place
        with torch.no_grad():
            self.TRFs[:, :, :] = TRFs.to(self.device)[:, :, :]
            for j in range(self.outDim):
                for i in range(self.inDim):
                    fd = skfda.FDataGrid(
                        data_matrix=TRFs_np[j, i, :],
                        grid_points=grid_points,
                    )
                    basis = skfda.representation.basis.Fourier(n_basis = self.nBasis)
                    fd_basis = fd.to_basis(basis)
                    coef = fd_basis.coefficients[0]
                    self.coefs[j, i, :] = torch.from_numpy(coef).to(self.device)
                    assert fd_basis.basis.period == self.T
                    fd_basis_s.append(fd_basis)

            # (outDim, inDim, nWin)
            out = self.vec_fourier_sum(
                self.nBasis,
                self.T,
                self.time_embedding_ext,
                self.coefs
            )[0, ..., 0]

        for j in range(self.outDim):
            for i in range(self.inDim):
                fd_basis = fd_basis_s[j * self.inDim + i]
                temp = fd_basis(grid_points).squeeze()
                curFTRF = out[j, i, :].cpu().numpy()
                TRF = TRFs_np[j, i, :]
                # round to 2 decimals: np.around's default of 0 decimals
                # would accept any r >= 0.5
                assert np.around(pearsonr(TRF, temp)[0], 2) >= 0.99
                if not np.allclose(curFTRF, temp, atol = 1e-6):
                    # target, our reconstruction and skfda's evaluation
                    print(TRF, curFTRF, temp)
                    raise AssertionError('Fourier reconstruction differs from the skfda evaluation')

    def phi0(self, T):
        # normalised constant basis function: 1 / sqrt(T)
        return 1 / ((2 ** 0.5) * ((T / 2) ** 0.5))

    def phi2n_1(self, n, T, t):
        '''
        Sine basis functions.
        n: (maxN); t: (nBatch, 1, 1, nWin, nSeq, 1)
        returns: (nBatch, 1, 1, nWin, nSeq, maxN)
        '''
        return torch.sin(2 * torch.pi * t * n / T) / (T / 2) ** 0.5

    def phi2n(self, n, T, t):
        '''
        Cosine basis functions.
        n: (maxN); t: (nBatch, 1, 1, nWin, nSeq, 1)
        returns: (nBatch, 1, 1, nWin, nSeq, maxN)
        '''
        return torch.cos(2 * torch.pi * t * n / T) / (T / 2) ** 0.5

    def vec_sum(self, x):
        return self.vec_fourier_sum(self.nBasis, self.T, x, self.coefs)

    def vec_fourier_sum(self, nBasis, T, t, coefs):
        '''
        coefs: (nOutChan, nInChan, nBasis)
        t: (nBatch, 1, 1, nWin, nSeq); the same lag transformation is shared
            by all channels
        returns: (nBatch, nOutChan, nInChan, nWin, nSeq)
        '''
        # (nBatch, 1, 1, nWin, nSeq, 1)
        t = t[..., None]
        const0 = self.phi0(T)
        maxN = nBasis // 2
        # (nBatch, 1, 1, nWin, nSeq, maxN)
        constSin = self.phi2n_1(self.seqN, T, t)
        constCos = self.phi2n(self.seqN, T, t)
        # interleaved sin1, cos1, sin2, cos2, ...: (nBatch, 1, 1, nWin, nSeq, 2 * maxN)
        constN = torch.stack([constSin, constCos], dim=-1).reshape(*constSin.shape[:5], 2 * maxN)
        # (nOutChan, nInChan, 1, 1, nBasis)
        coefs = coefs[:, :, None, None, :]
        return const0 * coefs[..., 0] + (constN * coefs[..., 1:]).sum(-1)

    def forward(self, a, b, c):
        '''
        a, b, c: amplitude, shift and scale of the lag axis. In the strictest
            case they are (nBatch, 1, 1, 1, nSeq); more loosely they may be
            broadcastable up to (nBatch, nOut, nIn, nWin, nSeq). Plain numbers
            are allowed.
        returns: (nBatch, outDim, inDim, nWin, nSeq)
        '''
        # (nBatch, 1, 1, nWin, nSeq): transformed lag positions
        x = c * (self.time_embedding - b)
        return a * self.vec_fourier_sum(self.nBasis, self.T, x, self.coefs)

    def vis(self, fs = None):
        '''Plot target TRFs (top) and their reconstructions (bottom); returns the figure.'''
        if plt is None:
            raise ValueError('matplotlib should be installed')
        if fs is None:
            timelag = self.timelag_idx_ext
        else:
            timelag = self.timelag_idx_ext.numpy() / fs
        with torch.no_grad():
            FTRFs = self.vec_fourier_sum(
                self.nBasis,
                self.T,
                self.time_embedding_ext,
                self.coefs
            )[0, ..., 0]
        fig, axs = plt.subplots(2)
        fig.suptitle('top: original TRF, bottom: reconstructed TRF')
        for j in range(self.outDim):
            for i in range(self.inDim):
                axs[0].plot(timelag, self.TRFs[j, i, :].cpu())
                axs[1].plot(timelag, FTRFs[j, i, :].cpu())
        return fig
781
+
782
# Basis-TRF implementations selectable by name (FuncTRFsGen looks up its
# ``basisTRFName`` argument here).
basisTRFNameMap = dict(
    gauss=GaussianBasisTRF,
    fourier=FourierBasisTRF,
)
786
+
787
+
788
class TRFsGen(torch.nn.Module):
    '''Placeholder interface for TRF generators; its ``forward`` does nothing and returns None.'''

    def forward(self, x, featOnsetIdx):
        '''
        input:
            x: input used to derive the transformation parameters of the TRFs
            featOnsetIdx: time indices of the items in the transformed x that
                are picked as the actual transformation parameters
        '''
        return None
799
+
800
+
801
class FuncTRFsGen(torch.nn.Module):
    '''
    Functional TRF generator: produces dynamically warped TRFs by transforming
    a functional TRF template.

    A ``transformer`` maps x (nBatch, inDim, nSeq) to per-time-step
    parameters. ``mode`` is a comma-separated subset of 'a', 'b', 'c'; '+-a'
    keeps the sign of the amplitude. The selected parameters act as amplitude
    a, lag shift b and lag scale c in ``a * basisTRF(c * (lag - b))``. When a
    parameter is absent, 'a' falls back to x itself, 'b' to 0 and 'c' to 1.
    '''

    def __init__(
        self,
        inDim,
        outDim,
        tmin_ms,
        tmax_ms,
        fs,
        basisTRFName = 'fourier',
        limitOfShift_idx = 7,
        nBasis = 21,
        mode = '',
        transformer = None,
        device = 'cpu',
    ):
        super().__init__()
        assert mode.replace('+-','') in ['','a','b','a,b','a,b,c','a,c']
        self.fs = fs
        self.lagIdxs = msec2Idxs([tmin_ms,tmax_ms],fs)
        self.lagIdxs_ts = torch.Tensor(self.lagIdxs).float().to(device)
        self.lagTimes = Idxs2msec(self.lagIdxs,fs)
        self.mode = mode
        self.transformer:torch.nn.Module = transformer
        # note: mode '' still yields one (unused) transform parameter
        self.n_transform_params = len(mode.split(','))
        self.device = device
        if isinstance(basisTRFName, str):
            self.basisTRF:FuncBasisTRF = basisTRFNameMap[basisTRFName](
                inDim,
                outDim,
                self.lagIdxs[0],
                self.lagIdxs[-1],
                nBasis,
                timeshiftLimit_idx = limitOfShift_idx,
                device=device
            )
        elif isinstance(basisTRFName, torch.nn.Module):
            self.basisTRF:FuncBasisTRF = basisTRFName
        else:
            raise ValueError(
                'basisTRFName should be a key of basisTRFNameMap or a torch.nn.Module, '
                f'got {type(basisTRFName)}'
            )

        if transformer is None:
            transInDim, transOutDim, device = self.get_default_transformer_param()
            self.transformer:torch.nn.Module = CausalConv(transInDim, transOutDim, 2).to(device)

        self.limitOfShift_idx = torch.tensor(limitOfShift_idx)

    @classmethod
    def parse_trans_params(cls, mode):
        return mode.split(',')

    @property
    def tmin_ms(self):
        return self.lagTimes[0]

    @property
    def tmax_ms(self):
        return self.lagTimes[-1]

    @property
    def inDim(self):
        return self.basisTRF.inDim

    @property
    def outDim(self):
        return self.basisTRF.outDim

    @property
    def nWin(self):
        # window of the un-extended lag range
        return self.basisTRF.nWin - 2 * self.limitOfShift_idx

    @property
    def nBasis(self):
        return self.basisTRF.nBasis

    @property
    def extendedTimeLagRange(self):
        '''(first, last) lag in ms of the lag range widened by limitOfShift_idx.'''
        shift = int(self.limitOfShift_idx)
        minLagIdx = self.lagIdxs[0]
        maxLagIdx = self.lagIdxs[-1]
        left = np.arange(minLagIdx - shift, minLagIdx)
        right = np.arange(maxLagIdx + 1, maxLagIdx + 1 + shift)
        extLag_idx = np.concatenate([left, self.lagIdxs, right])
        expectedLen = self.nWin + 2 * shift
        if len(extLag_idx) != expectedLen:
            # best-effort diagnostic, kept non-fatal as before
            print(len(extLag_idx), expectedLen)
        timelags = Idxs2msec(extLag_idx, self.fs)
        return timelags[0], timelags[-1]

    def get_default_transformer_param(self):
        '''(inDim, number of transform parameters, device) for the default CausalConv.'''
        inDim = self.inDim
        device = self.device
        outDim = self.n_transform_params
        return inDim, outDim, device

    def fitFuncTRF(self, w):
        '''Fit the basis template to mTRF weights ``w`` (scaled by 1 / fs); returns self.'''
        w = w * 1 / self.fs
        with torch.no_grad():
            self.basisTRF.fitTRFs(w)
        return self

    def pickParam(self, paramSeqs, idx):
        # paramSeqs: (nBatch, nMiddleParam, nOut/1, 1, 1, nSeq)
        return paramSeqs[:, idx, ...]

    def getTransformParams(self, x, startIdx = None):
        '''
        x: (nBatch, inDim, nSeq)
        startIdx: optional (nBatch, nOnset) integer time indices; when given,
            the transform parameters (and x itself, when used as the
            amplitude) are taken only at these time steps.
        returns (aSeq, bSeq, cSeq). Each is a tensor broadcastable to
            (nBatch, nOut/1, 1, 1, nSeq') — aSeq falls back to
            (nBatch, 1, inDim, 1, nSeq') — or the plain number 0 / 1 for b / c.
        '''
        # (nBatch, nMiddleParam, nSeq)
        paramSeqs = self.transformer(x)
        nBatch, nMiddleParam, _ = paramSeqs.shape
        if startIdx is not None:
            startIdx = startIdx.to(paramSeqs.device).long()
            idxBatch = torch.arange(nBatch, device=paramSeqs.device)[:, None, None]
            idxMiddleParam = torch.arange(nMiddleParam, device=paramSeqs.device)[:, None]
            # (nBatch, nMiddleParam, nOnset)
            paramSeqs = paramSeqs[idxBatch, idxMiddleParam, startIdx[:, None, :]]
        # length after the optional gather above (not the transformer's length)
        nSeq = paramSeqs.shape[-1]

        # (nBatch, n_transform_params, 1, 1, 1, nSeq) in the strictest case;
        # it can also be (nBatch, n_transform_params, nOut, 1, 1, nSeq) for a
        # different transformation per output channel
        paramSeqs = paramSeqs.reshape(nBatch, self.n_transform_params, -1, 1, 1, nSeq)
        if paramSeqs.shape[2] != 1:
            assert paramSeqs.shape[2] == self.outDim
        midParamList = self.mode.split(',')
        if midParamList == ['']:
            midParamList = []
        nParamMiss = 0
        if 'a' in midParamList:
            # (nBatch, nOut/1, 1, 1, nSeq)
            aSeq = torch.abs(self.pickParam(paramSeqs, midParamList.index('a')))
        elif '+-a' in midParamList:
            # (nBatch, nOut/1, 1, 1, nSeq)
            aSeq = self.pickParam(paramSeqs, midParamList.index('+-a'))
        else:
            nParamMiss += 1
            xSel = x
            if startIdx is not None:
                # keep the amplitude aligned with the gathered parameters
                gatherIdx = startIdx.to(x.device)[:, None, :].expand(-1, x.shape[1], -1)
                xSel = torch.gather(x, 2, gatherIdx)
            # (nBatch, 1, inDim, 1, nSeq)
            aSeq = xSel[:, None, :, None, :]
        if 'b' in midParamList:
            # (nBatch, nOut/1, 1, 1, nSeq), clamped to +-limitOfShift_idx
            bSeq = self.pickParam(paramSeqs, midParamList.index('b'))
            bSeq = torch.maximum(bSeq, - self.limitOfShift_idx)
            bSeq = torch.minimum(bSeq, self.limitOfShift_idx)
        else:
            nParamMiss += 1
            bSeq = 0

        if 'c' in midParamList:
            # (nBatch, nOut/1, 1, 1, nSeq)
            cSeq = self.pickParam(paramSeqs, midParamList.index('c'))
            # predict an offset from 1 rather than using abs(): c must stay
            # positive, and abs() would create two optima around 1
            cSeq = 1 + cSeq
            cSeq = torch.maximum(cSeq, torch.tensor(0.5))
            cSeq = torch.minimum(cSeq, torch.tensor(1.28))
        else:
            nParamMiss += 1
            cSeq = 1

        assert (len(midParamList) + nParamMiss) == 3
        return aSeq, bSeq, cSeq

    def forward(self, x, featOnsetIdx = None):
        '''
        x: (nBatch, inDim, nSeq)
        featOnsetIdx: optional (nBatch, nOnset) time indices at which the
            transform parameters are picked
        output: TRFs (nBatch, outDim, nWin, nSeq')
        '''
        aSeq, bSeq, cSeq = self.getTransformParams(x, featOnsetIdx)
        # (nBatch, outDim, inDim, nWin, nSeq') for a basis that keeps inDim
        nonLinTRFs = self.basisTRF(aSeq, bSeq, cSeq)
        # sum over inDim: (nBatch, outDim, nWin, nSeq')
        TRFs = nonLinTRFs.sum(2)
        return TRFs
996
+
997
class ASTRF(torch.nn.Module):
    '''
    The TRF implemented as the convolution sum of temporal responses
    (i.e., time-aligning the temporal responses at their corresponding
    location, and point-wise summing them).
    It requires a module to generate temporal responses to each individual
    stimulus, and also requires time information to displace/align the
    temporal responses at the right indices/location.

    Limitation: can't do TRF for z-scored input; under this condition,
    locations with no stimulus will be non-zero.

    The core mechanism of this module follows these steps:
    1. generate TRFs using the input param 'x',
    2. determine the onset time of these TRFs within the output time series,
       using input param 'timeinfo',
    3. align and sum the generated TRFs at their corresponding time location
       in the output.

    Note:
    1. currently, there are two types of x supported,
        type 1: discrete type, means there is a single x for each time point
            in the 'timestamp'
        type 2: continuous type, means there is a timeseries of x which has
            the same length as the output timeseries, and when generating
            TRFs, only part of the x will actually contribute to this process.
    '''

    def __init__(
        self,
        inDim,
        outDim,
        tmin_ms,
        tmax_ms,
        fs,
        trfsGen = None,
        device = 'cpu',
        x_is_timeseries = False
    ):
        '''
        inDim: int, the number of columns of input
        outDim: int, the number of columns of output of ltiTRFGen and trfsGen
        tmin_ms, tmax_ms: lag window in milliseconds (tmin_ms must be >= 0)
        fs: sampling rate used to convert times to sample indices
        trfsGen: optional user-provided TRF generator (moved to `device`)
        device: device for the sub-modules and the trfsGen bias
        x_is_timeseries: True when x is a continuous timeseries (type 2)
        '''
        super().__init__()
        assert tmin_ms >= 0
        self.x_is_timeseries = x_is_timeseries
        self.lagIdxs = msec2Idxs([tmin_ms, tmax_ms], fs)
        self.lagTimes = Idxs2msec(self.lagIdxs, fs)
        nWin = len(self.lagTimes)
        self.ltiTRFsGen = LTITRFGen(
            inDim,
            nWin,
            outDim,
            ifAddBiasInForward=False
        ).to(device)
        self.trfsGen:FuncTRFsGen = trfsGen if trfsGen is None else trfsGen.to(device)
        self.fs = fs

        self.bias = None
        # also train a bias for the trfsGen provided by the user
        if self.trfsGen is not None:
            self.init_nonLinTRFs_bias(inDim, nWin, outDim, device)

        self.trfAligner = TRFAligner(device)
        self._enableUserTRFGen = False
        self.device = device

    @property
    def inDim(self):
        return self.ltiTRFsGen.inDim

    @property
    def outDim(self):
        return self.ltiTRFsGen.outDim

    @property
    def nWin(self):
        return self.ltiTRFsGen.nWin

    @property
    def tmin_ms(self):
        return self.lagTimes[0]

    @property
    def tmax_ms(self):
        return self.lagTimes[-1]

    def init_nonLinTRFs_bias(self, inDim, nWin, outDim, device):
        '''
        Create the trainable (outDim,) bias that forward() adds when the user
        TRF generator is enabled, initialised uniformly in
        [-1/sqrt(inDim*nWin), 1/sqrt(inDim*nWin)].

        The bias is allocated on `device` (previously this argument was
        ignored and the bias always landed on the CPU, causing a device
        mismatch with the targetTensor on non-CPU devices).
        '''
        self.bias = torch.nn.Parameter(torch.ones(outDim, device = device))
        fan_in = inDim * nWin
        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
        torch.nn.init.uniform_(self.bias, -bound, bound)

    def set_trfs_gen(self, trfsGen):
        '''Install a user TRF generator and (re)initialise its output bias.'''
        self.trfsGen = trfsGen.to(self.device)
        self.init_nonLinTRFs_bias(self.inDim, self.nWin, self.outDim, self.device)

    def get_params_for_train(self):
        '''Parameters of trfsGen.transformer plus the trfsGen bias.'''
        return [i for i in self.trfsGen.transformer.parameters()] + [self.bias]

    def set_linear_weights(self, w, b):
        # w: (nInChan, nLag, nOutChan)
        self.ltiTRFsGen.load_mtrf_weights(
            w,
            b,
            self.fs,
            self.device
        )
        return self

    def get_linear_weights(self):
        w, b = self.ltiTRFsGen.export_mtrf_weights(
            self.fs
        )
        return w, b

    @property
    def if_enable_trfsGen(self):
        return self._enableUserTRFGen

    @if_enable_trfsGen.setter
    def if_enable_trfsGen(self, x):
        assert isinstance(x, bool)
        print('set ifEnableNonLin',x)
        if x and self.trfsGen is None:
            raise ValueError('trfGen is None, cannot be enabled')
        self._enableUserTRFGen = x

    def stop_update_linear(self):
        self.ltiTRFsGen.stop_update_weights()

    def enable_update_linear(self):
        self.ltiTRFsGen.enable_update_weights()

    def forward(self, x, timeinfo):
        '''
        input:
            x: nBatch * [nChan, nSeq],
            timeinfo: nBatch * [2, nSeq]; row 0 holds the onset times that
                are multiplied by fs. An item may be None, in which case
                every sample of x is used as a TRF onset.
        output: targetTensor (nBatch, outDim, nRealLen)
        '''

        ### record the necessary information of each item in the batch
        nRealLens = []  # length of real output in the batch
        trfOnsetIdxs = []  # index of each TRF onset in the output, per item
        for ix, xi in enumerate(x):
            nLenXi = xi.shape[-1]
            if timeinfo[ix] is not None:
                if not self.x_is_timeseries:
                    # discrete x: exactly one feature vector per timestamp
                    assert timeinfo[ix].shape[-1] == nLenXi
                nLen = torch.ceil(
                    timeinfo[ix][0][-1] * self.fs
                ).long() + self.nWin
                onsetIdx = torch.round(
                    timeinfo[ix][0, :] * self.fs
                ).long() + self.lagIdxs[0]
            else:
                nLen = nLenXi
                # torch.arange is always int64, matching the .long() branch
                # (np.arange is int32 on some platforms)
                onsetIdx = torch.arange(nLen) + self.lagIdxs[0]
            nRealLens.append(nLen)
            trfOnsetIdxs.append(onsetIdx)

        nGlobLen = max(nRealLens)
        x = seqLast_pad_zero(x)
        # -1 marks padded (non-existent) onsets
        trfOnsetIdxs = seqLast_pad_zero(trfOnsetIdxs, value = -1)

        # if x is a time series, featOnsetIdxs is the index into x where the
        # features used for generating each TRF are taken
        featOnsetIdxs = None
        if self.x_is_timeseries:
            featOnsetIdxs = trfOnsetIdxs.detach().clone()
            valid = featOnsetIdxs != -1
            featOnsetIdxs[valid] = featOnsetIdxs[valid] - self.lagIdxs[0]

        # TRFs shape: (nBatch, outDim, nWin, nSeq)
        TRFs = self.get_trfs(x, featOnsetIdxs)

        # targetTensor shape: (nBatch, outDim, nRealLen)
        targetTensor = self.trfAligner(TRFs, trfOnsetIdxs, nGlobLen)

        if self.if_enable_trfsGen:
            targetTensor = targetTensor + self.bias.view(-1, 1)
        else:
            ltiTRFBias = self.ltiTRFsGen.bias
            targetTensor = targetTensor + ltiTRFBias.view(-1, 1)

        return targetTensor

    def get_trfs(self, x, featOnsetIdxs = None):
        '''Dispatch to the user TRF generator when enabled, else the LTI one.'''
        if self.if_enable_trfsGen:
            return self.trfsGen(x, featOnsetIdxs)
        else:
            return self.ltiTRFsGen(x)
1204
+
1205
+
1206
class ASCNNTRF(ASTRF):
    '''
    Perform a CNN-style TRF within intervals, switching the TRF weights at
    each interval onset.

    Unlike ASTRF, the TRFs returned by getTRFs are convolution kernels that
    are not multiplied with x: each interval of x is fully convolved with its
    own kernel, and the results are summed into one output series.
    '''

    def __init__(
        self,
        inDim,
        outDim,
        tmin_ms,
        tmax_ms,
        fs,
        trfsGen = None,
        device = 'cpu'
    ):
        # ASTRF.__init__ is deliberately skipped (it would assert tmin_ms >= 0
        # and build a TRFAligner this class does not use).
        torch.nn.Module.__init__(self)
        # ASTRF exposes inDim/outDim/tmin_ms/tmax_ms as read-only properties,
        # so `self.inDim = inDim` raises AttributeError. Store the values
        # privately and expose them via the overriding properties below.
        self._inDim = inDim
        self._outDim = outDim
        self._tmin_ms = tmin_ms
        self._tmax_ms = tmax_ms
        self.lagIdxs = msec2Idxs([tmin_ms, tmax_ms], fs)
        self.lagTimes = Idxs2msec(self.lagIdxs, fs)
        self.tmin_idx = self.lagIdxs[0]
        self.tmax_idx = self.lagIdxs[-1]
        nWin = len(self.lagTimes)
        self._nWin = nWin
        self.ltiTRFsGen = LTITRFGen(
            inDim,
            nWin,
            outDim,
            ifAddBiasInForward=False
        ).to(device)
        self.trfsGen = trfsGen if trfsGen is None else trfsGen.to(device)
        self.fs = fs

        self.bias = None
        # also train a bias for the trfsGen provided by the user
        if self.trfsGen is not None:
            self.bias = torch.nn.Parameter(torch.ones(outDim, device = device))
            fan_in = inDim * nWin
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            torch.nn.init.uniform_(self.bias, -bound, bound)

        self._enableUserTRFGen = False
        self.device = device

    @property
    def inDim(self):
        return self._inDim

    @property
    def outDim(self):
        return self._outDim

    @property
    def tmin_ms(self):
        return self._tmin_ms

    @property
    def tmax_ms(self):
        return self._tmax_ms

    @property
    def nWin(self):
        return self._nWin

    def getTRFs(self, ctx):
        '''
        Return one convolution kernel per interval.

        ctx: per-interval context; a list (as built by defaultCtx) or any
            sequence whose length is the number of intervals.
        '''
        # note the difference between ASTRF and ASCNNTRF here:
        # the LTI TRF is not multiplied with x
        if self.if_enable_trfsGen:
            # TODO(review): how many TRFs trfsGen should return is undecided
            TRFs = self.trfsGen(ctx)
            # TRFs (nBatch, outDim, nWin, nSeq); only the first batch is used
            TRFs = TRFs[0]  # (outDim, nWin, nSeq)
            TRFs = TRFs.permute(2, 0, 1)[..., None, :]
        else:
            # one kernel per interval; ctx is a list, so use len() rather
            # than .shape
            nTRFs = len(ctx)
            # TRFs: nTRFs * (nChanOut, nChanIn, nWin)
            TRFs = [self.ltiTRFsGen.weight] * nTRFs
        return TRFs

    def getTRFSwitchOnsets(self, x):
        # x: nBatch * [nChan, nSeq]
        # NOTE(review): placeholder -- returns fixed sample indices regardless
        # of x
        return [0, 200, 300, 400, 500, 600, 800, 1000]

    def defaultCtx(self, switchOnsets, x):
        '''Slice x along its last axis into the intervals [onset_i, onset_{i+1}).'''
        # None (not -1) so the final interval runs to the end of x, matching
        # the segmentation used in forward()
        bounds = list(switchOnsets) + [None]
        return [
            x[:, :, bounds[i]:bounds[i + 1]]
            for i in range(len(switchOnsets))
        ]

    def forward(self, x, timeInfo = None, ctx = None):
        '''
        x: tensor (nBatch, nChan, nSeq); currently only a single batch is
            supported
        timeInfo: list of interval start indices (samples) along the last
            axis of x; defaults to getTRFSwitchOnsets(x). Not modified.
        ctx: optional per-interval context; defaults to defaultCtx(...)
        output: (nBatch, outDim, nSeq)
        '''
        if timeInfo is None:
            TRFSwitchOnsets = self.getTRFSwitchOnsets(x)
        else:
            TRFSwitchOnsets = timeInfo
        if ctx is None:
            ctx = self.defaultCtx(TRFSwitchOnsets, x)
        nBatch, nChan, nSeq = x.shape
        TRFs = self.getTRFs(ctx)
        # conv1d computes cross-correlation, so flip the kernels' lag axis
        TRFsFlip = [trf.flip([-1]) for trf in TRFs]
        # interval i spans [bounds[i], bounds[i+1]); the trailing None runs
        # the last interval to the end. Build a new list rather than
        # appending so a caller-supplied timeInfo is not mutated.
        bounds = list(TRFSwitchOnsets) + [None]

        nPaddedOutput = nSeq \
            + max(-self.tmin_idx, 0) \
            + max( self.tmax_idx, 0)

        output = torch.zeros(
            nBatch,
            self.outDim,
            nPaddedOutput,
            device = self.device
        )

        # segment startIdx offset
        startOffset = self.tmin_idx
        # global startIdx offset
        offset = min(0, self.tmin_idx)
        # position in `output` of the first sample of each segment's
        # full-convolution result
        realOffset = startOffset - offset

        for idx, TRFFlip in enumerate(TRFsFlip):
            t_start = bounds[idx]
            t_end = bounds[idx + 1]
            segment = self.trfCNN(x[..., t_start:t_end], TRFFlip)
            t_startReal = t_start + realOffset
            t_endReal = t_startReal + segment.shape[-1]
            output[:, :, t_startReal:t_endReal] += segment

        # crop the padding added for negative tmin / positive tmax
        startIdx = max(-self.tmin_idx, 0)
        endIdx = output.shape[-1] - max(self.tmax_idx, 0)

        # same bias selection as ASTRF.forward: the trfsGen bias is used
        # when the user generator is enabled
        if self.if_enable_trfsGen:
            bias = self.bias
        else:
            bias = self.ltiTRFsGen.bias
        return output[..., startIdx:endIdx] + bias.view(-1, 1)

    def trfCNN(self, x, TRFFlip):
        '''
        Full 1-D convolution of one interval with one kernel.

        x: (nBatch, nChan, nSubSeq)
        TRFFlip: kernel with its lag (last) dimension already flipped,
            as required by conv1d
        returns: (nBatch, nOutChan, nSubSeq + nWin - 1)
        '''
        # time lag does not influence how much to pad, only the offset of the
        # start index (handled in forward)
        nWin = TRFFlip.shape[-1]
        x = torch.nn.functional.pad(x, (nWin - 1, nWin - 1))
        return torch.nn.functional.conv1d(x, TRFFlip)
1358
+