nntrf 1.0.1__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
nntrf/models/composite.py CHANGED
@@ -1,6 +1,6 @@
1
1
  import torch
2
2
 
3
- class TwoMixedTRF(torch.nn.Module):
3
+ class MixedTRF(torch.nn.Module):
4
4
 
5
5
  def __init__(
6
6
  self,
nntrf/models/linear.py CHANGED
@@ -214,9 +214,9 @@ class CNNTRF(torch.nn.Module):
214
214
  b = b * 1/self.fs
215
215
  b = b[0]
216
216
  w = np.flip(w,axis = 1).copy()
217
- w = torch.FloatTensor(w).to(device)
217
+ w = torch.from_numpy(w).to(device)
218
218
  w = w.permute(2,0,1)
219
- b = torch.FloatTensor(b).to(device)
219
+ b = torch.from_numpy(b).to(device)
220
220
  with torch.no_grad():
221
221
  self.oCNN.weight = torch.nn.Parameter(w)
222
222
  self.oCNN.bias = torch.nn.Parameter(b)
nntrf/models/nonlinear.py CHANGED
@@ -146,8 +146,8 @@ class LTITRFGen(torch.nn.Module):
146
146
  b = b[0]
147
147
  w = w * 1 / fs
148
148
  b = b * 1/ fs
149
- w = torch.FloatTensor(w).to(device)
150
- b = torch.FloatTensor(b).to(device)
149
+ w = torch.from_numpy(w).to(device)
150
+ b = torch.from_numpy(b).to(device)
151
151
  w = w.permute(2, 0, 1) #(nOutChan, nInChan, nLag)
152
152
  with torch.no_grad():
153
153
  self.weight = torch.nn.Parameter(w)
@@ -359,7 +359,8 @@ def build_gaussian_response(x, mu, sigma):
359
359
  # output: (nBatch, nBasis, outDim, inDim, nWin, nSeq)
360
360
 
361
361
  # x: (nBatch, 1, 1, 1, nWin, nSeq)
362
- x = x[:, None, ...]
362
+ if x.ndim == 5:
363
+ x = x[:, None, ...]
363
364
  # mu: (nBasis, 1, 1, 1, 1)
364
365
  mu = mu[..., None, None, None, None]
365
366
  # sigma: (nBasis, outDim, inDim, 1, 1)
@@ -901,7 +902,8 @@ class FuncTRFsGen(torch.nn.Module):
901
902
  return inDim, outDim, device
902
903
 
903
904
  def fitFuncTRF(self, w):
904
- w = w * 1 / self.fs
905
+ # do the self.fs operation here because basisTRF doesn't keep fs info
906
+ w = w * 1 / self.fs
905
907
  with torch.no_grad():
906
908
  self.basisTRF.fitTRFs(w)
907
909
  return self
@@ -960,9 +962,11 @@ class FuncTRFsGen(torch.nn.Module):
960
962
  cIdx = midParamList.index('c')
961
963
  #(nBatch, 1, 1, 1, nSeq)
962
964
  cSeq = self.pickParam(paramSeqs, cIdx)
965
+ # print('c: ',cSeq.squeeze())
963
966
  #two reasons, cSeq must be larger than 0;
964
967
  #if 1 is the optimum, abs will have two x for the optimum,
965
968
  # which is not stable
969
+ # cSeq = torch.tanh(cSeq) #expr
966
970
  cSeq = 1 + cSeq
967
971
  cSeq = torch.maximum(cSeq, torch.tensor(0.5))
968
972
  cSeq = torch.minimum(cSeq, torch.tensor(1.28))
@@ -1030,7 +1034,8 @@ class ASTRF(torch.nn.Module):
1030
1034
  fs,
1031
1035
  trfsGen = None,
1032
1036
  device = 'cpu',
1033
- x_is_timeseries = False
1037
+ x_is_timeseries = False,
1038
+ verbose = True
1034
1039
  ):
1035
1040
  '''
1036
1041
  inDim: int, the number of columns of input
@@ -1060,8 +1065,9 @@ class ASTRF(torch.nn.Module):
1060
1065
  self.init_nonLinTRFs_bias(inDim, nWin, outDim, device)
1061
1066
 
1062
1067
  self.trfAligner = TRFAligner(device)
1063
- self._enableUserTRFGen = False
1068
+ self._enableUserTRFGen = True
1064
1069
  self.device = device
1070
+ self.verbose = verbose
1065
1071
 
1066
1072
  @property
1067
1073
  def inDim(self):
@@ -1126,7 +1132,8 @@ class ASTRF(torch.nn.Module):
1126
1132
  @if_enable_trfsGen.setter
1127
1133
  def if_enable_trfsGen(self,x):
1128
1134
  assert isinstance(x, bool)
1129
- print('set ifEnableNonLin',x)
1135
+ if self.verbose:
1136
+ print('set ifEnableNonLin',x)
1130
1137
  if x == True and self.trfsGen is None:
1131
1138
  raise ValueError('trfGen is None, cannot be enabled')
1132
1139
  self._enableUserTRFGen = x
@@ -1155,7 +1162,7 @@ class ASTRF(torch.nn.Module):
1155
1162
  if timeinfo[ix] is not None:
1156
1163
  # print(timeinfo[ix].shape)
1157
1164
  if not self.x_is_timeseries:
1158
- assert timeinfo[ix].shape[-1] == xi.shape[-1]
1165
+ assert timeinfo[ix].shape[-1] == xi.shape[-1], f"{timeinfo[ix].shape[-1]} != {xi.shape[-1]}"
1159
1166
  nLen = torch.ceil(
1160
1167
  timeinfo[ix][0][-1] * self.fs
1161
1168
  ).long() + self.nWin
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.2
1
+ Metadata-Version: 2.4
2
2
  Name: nntrf
3
- Version: 1.0.1
3
+ Version: 1.0.4
4
4
  Home-page: https://github.com/powerfulbean/nnTRF
5
5
  Author: Jin Dou
6
6
  Author-email: jindou.bci@gmail.com
@@ -10,9 +10,8 @@ Classifier: Operating System :: OS Independent
10
10
  Requires-Python: >=3.8
11
11
  Description-Content-Type: text/markdown
12
12
  License-File: LICENSE
13
- Requires-Dist: numpy>=1.20.1
14
- Requires-Dist: torch<2.0.0,>=1.12.1
15
- Requires-Dist: scikit-fda==0.7.1
13
+ Requires-Dist: numpy
14
+ Requires-Dist: torch
16
15
  Requires-Dist: mtrf
17
16
  Requires-Dist: scipy
18
17
  Dynamic: author
@@ -20,5 +19,6 @@ Dynamic: author-email
20
19
  Dynamic: classifier
21
20
  Dynamic: description-content-type
22
21
  Dynamic: home-page
22
+ Dynamic: license-file
23
23
  Dynamic: requires-dist
24
24
  Dynamic: requires-python
@@ -0,0 +1,13 @@
1
+ nntrf/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ nntrf/loss.py,sha256=h-ZzLoBd3S0uRENiqyfwTpb75x6Xp6WjRpNnD-_jraE,651
3
+ nntrf/metrics.py,sha256=TUkg2yVXhNTW96uxHu4nMYKWvt7jbh_HdHVo9ii3Lpw,782
4
+ nntrf/utils.py,sha256=h6vdrNvArn88cOhsCXz1_24a-HPiX36RfzNfq-jAdn0,224
5
+ nntrf/models/__init__.py,sha256=9RSC1DAp-d2ZgKndAk6myWn7ogyMj5ODkLPemBrStnE,71
6
+ nntrf/models/composite.py,sha256=qgni2yZgTUFAiGP96egSKg5m7PNCmawb6suroV7tuvo,2136
7
+ nntrf/models/linear.py,sha256=MrsVGnIIkuOOMXcBff1GQOiX2l5isF6_8f_2XOrgTrk,8203
8
+ nntrf/models/nonlinear.py,sha256=UM-xOgxCsgy9_oMv1Zc84Ep6C9bGnYUXELr_oY-wLZw,45853
9
+ nntrf-1.0.4.dist-info/licenses/LICENSE,sha256=BxgW39X9rivrb3TVuhpIgjctmWn3HsvEsN4dXtTDNgU,1069
10
+ nntrf-1.0.4.dist-info/METADATA,sha256=RUJ5wmeeYYQiOGW90CQZpXdcFEzqBhFbF8E00rpVcWw,644
11
+ nntrf-1.0.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
12
+ nntrf-1.0.4.dist-info/top_level.txt,sha256=amAdTZ4QBomZv41i3oWF5lnkU6187LpyDIqhpJSCw4g,6
13
+ nntrf-1.0.4.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.8.0)
2
+ Generator: setuptools (80.9.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -1,13 +0,0 @@
1
- nntrf/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- nntrf/loss.py,sha256=h-ZzLoBd3S0uRENiqyfwTpb75x6Xp6WjRpNnD-_jraE,651
3
- nntrf/metrics.py,sha256=TUkg2yVXhNTW96uxHu4nMYKWvt7jbh_HdHVo9ii3Lpw,782
4
- nntrf/utils.py,sha256=h6vdrNvArn88cOhsCXz1_24a-HPiX36RfzNfq-jAdn0,224
5
- nntrf/models/__init__.py,sha256=9RSC1DAp-d2ZgKndAk6myWn7ogyMj5ODkLPemBrStnE,71
6
- nntrf/models/composite.py,sha256=3nQIy0MZmgo9-OCj0mPD3hcUUkHG0lK1h9tyyZtIaiM,2139
7
- nntrf/models/linear.py,sha256=DPcQK6ol0zXaDo8p1S5zb-075fG1L0NbfwMlQzm_fyw,8205
8
- nntrf/models/nonlinear.py,sha256=E0bQAa78DMd8YgQ7JOSmyUbdfq_UwgXQTJ9UmzN15vg,45536
9
- nntrf-1.0.1.dist-info/LICENSE,sha256=BxgW39X9rivrb3TVuhpIgjctmWn3HsvEsN4dXtTDNgU,1069
10
- nntrf-1.0.1.dist-info/METADATA,sha256=kY-Zfhrj2GS5wElgNcX-sVw-n6v2HzZLRKo2_u1Zd8M,678
11
- nntrf-1.0.1.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
12
- nntrf-1.0.1.dist-info/top_level.txt,sha256=amAdTZ4QBomZv41i3oWF5lnkU6187LpyDIqhpJSCw4g,6
13
- nntrf-1.0.1.dist-info/RECORD,,