nntrf 1.0.1__py3-none-any.whl → 1.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nntrf/models/composite.py +1 -1
- nntrf/models/linear.py +2 -2
- nntrf/models/nonlinear.py +15 -8
- {nntrf-1.0.1.dist-info → nntrf-1.0.4.dist-info}/METADATA +5 -5
- nntrf-1.0.4.dist-info/RECORD +13 -0
- {nntrf-1.0.1.dist-info → nntrf-1.0.4.dist-info}/WHEEL +1 -1
- nntrf-1.0.1.dist-info/RECORD +0 -13
- {nntrf-1.0.1.dist-info → nntrf-1.0.4.dist-info/licenses}/LICENSE +0 -0
- {nntrf-1.0.1.dist-info → nntrf-1.0.4.dist-info}/top_level.txt +0 -0
nntrf/models/composite.py
CHANGED
nntrf/models/linear.py
CHANGED
@@ -214,9 +214,9 @@ class CNNTRF(torch.nn.Module):
|
|
214
214
|
b = b * 1/self.fs
|
215
215
|
b = b[0]
|
216
216
|
w = np.flip(w,axis = 1).copy()
|
217
|
-
w = torch.
|
217
|
+
w = torch.from_numpy(w).to(device)
|
218
218
|
w = w.permute(2,0,1)
|
219
|
-
b = torch.
|
219
|
+
b = torch.from_numpy(b).to(device)
|
220
220
|
with torch.no_grad():
|
221
221
|
self.oCNN.weight = torch.nn.Parameter(w)
|
222
222
|
self.oCNN.bias = torch.nn.Parameter(b)
|
nntrf/models/nonlinear.py
CHANGED
@@ -146,8 +146,8 @@ class LTITRFGen(torch.nn.Module):
|
|
146
146
|
b = b[0]
|
147
147
|
w = w * 1 / fs
|
148
148
|
b = b * 1/ fs
|
149
|
-
w = torch.
|
150
|
-
b = torch.
|
149
|
+
w = torch.from_numpy(w).to(device)
|
150
|
+
b = torch.from_numpy(b).to(device)
|
151
151
|
w = w.permute(2, 0, 1) #(nOutChan, nInChan, nLag)
|
152
152
|
with torch.no_grad():
|
153
153
|
self.weight = torch.nn.Parameter(w)
|
@@ -359,7 +359,8 @@ def build_gaussian_response(x, mu, sigma):
|
|
359
359
|
# output: (nBatch, nBasis, outDim, inDim, nWin, nSeq)
|
360
360
|
|
361
361
|
# x: (nBatch, 1, 1, 1, nWin, nSeq)
|
362
|
-
|
362
|
+
if x.ndim == 5:
|
363
|
+
x = x[:, None, ...]
|
363
364
|
# mu: (nBasis, 1, 1, 1, 1)
|
364
365
|
mu = mu[..., None, None, None, None]
|
365
366
|
# sigma: (nBasis, outDim, inDim, 1, 1)
|
@@ -901,7 +902,8 @@ class FuncTRFsGen(torch.nn.Module):
|
|
901
902
|
return inDim, outDim, device
|
902
903
|
|
903
904
|
def fitFuncTRF(self, w):
|
904
|
-
|
905
|
+
# do the self.fs operation here because basisTRF doesn't keep fs info
|
906
|
+
w = w * 1 / self.fs
|
905
907
|
with torch.no_grad():
|
906
908
|
self.basisTRF.fitTRFs(w)
|
907
909
|
return self
|
@@ -960,9 +962,11 @@ class FuncTRFsGen(torch.nn.Module):
|
|
960
962
|
cIdx = midParamList.index('c')
|
961
963
|
#(nBatch, 1, 1, 1, nSeq)
|
962
964
|
cSeq = self.pickParam(paramSeqs, cIdx)
|
965
|
+
# print('c: ',cSeq.squeeze())
|
963
966
|
#two reasons, cSeq must be larger than 0;
|
964
967
|
#if 1 is the optimum, abs will have two x for the optimum,
|
965
968
|
# which is not stable
|
969
|
+
# cSeq = torch.tanh(cSeq) #expr
|
966
970
|
cSeq = 1 + cSeq
|
967
971
|
cSeq = torch.maximum(cSeq, torch.tensor(0.5))
|
968
972
|
cSeq = torch.minimum(cSeq, torch.tensor(1.28))
|
@@ -1030,7 +1034,8 @@ class ASTRF(torch.nn.Module):
|
|
1030
1034
|
fs,
|
1031
1035
|
trfsGen = None,
|
1032
1036
|
device = 'cpu',
|
1033
|
-
x_is_timeseries = False
|
1037
|
+
x_is_timeseries = False,
|
1038
|
+
verbose = True
|
1034
1039
|
):
|
1035
1040
|
'''
|
1036
1041
|
inDim: int, the number of columns of input
|
@@ -1060,8 +1065,9 @@ class ASTRF(torch.nn.Module):
|
|
1060
1065
|
self.init_nonLinTRFs_bias(inDim, nWin, outDim, device)
|
1061
1066
|
|
1062
1067
|
self.trfAligner = TRFAligner(device)
|
1063
|
-
self._enableUserTRFGen =
|
1068
|
+
self._enableUserTRFGen = True
|
1064
1069
|
self.device = device
|
1070
|
+
self.verbose = verbose
|
1065
1071
|
|
1066
1072
|
@property
|
1067
1073
|
def inDim(self):
|
@@ -1126,7 +1132,8 @@ class ASTRF(torch.nn.Module):
|
|
1126
1132
|
@if_enable_trfsGen.setter
|
1127
1133
|
def if_enable_trfsGen(self,x):
|
1128
1134
|
assert isinstance(x, bool)
|
1129
|
-
|
1135
|
+
if self.verbose:
|
1136
|
+
print('set ifEnableNonLin',x)
|
1130
1137
|
if x == True and self.trfsGen is None:
|
1131
1138
|
raise ValueError('trfGen is None, cannot be enabled')
|
1132
1139
|
self._enableUserTRFGen = x
|
@@ -1155,7 +1162,7 @@ class ASTRF(torch.nn.Module):
|
|
1155
1162
|
if timeinfo[ix] is not None:
|
1156
1163
|
# print(timeinfo[ix].shape)
|
1157
1164
|
if not self.x_is_timeseries:
|
1158
|
-
assert timeinfo[ix].shape[-1] == xi.shape[-1]
|
1165
|
+
assert timeinfo[ix].shape[-1] == xi.shape[-1], f"{timeinfo[ix].shape[-1]} != {xi.shape[-1]}"
|
1159
1166
|
nLen = torch.ceil(
|
1160
1167
|
timeinfo[ix][0][-1] * self.fs
|
1161
1168
|
).long() + self.nWin
|
@@ -1,6 +1,6 @@
|
|
1
|
-
Metadata-Version: 2.
|
1
|
+
Metadata-Version: 2.4
|
2
2
|
Name: nntrf
|
3
|
-
Version: 1.0.
|
3
|
+
Version: 1.0.4
|
4
4
|
Home-page: https://github.com/powerfulbean/nnTRF
|
5
5
|
Author: Jin Dou
|
6
6
|
Author-email: jindou.bci@gmail.com
|
@@ -10,9 +10,8 @@ Classifier: Operating System :: OS Independent
|
|
10
10
|
Requires-Python: >=3.8
|
11
11
|
Description-Content-Type: text/markdown
|
12
12
|
License-File: LICENSE
|
13
|
-
Requires-Dist: numpy
|
14
|
-
Requires-Dist: torch
|
15
|
-
Requires-Dist: scikit-fda==0.7.1
|
13
|
+
Requires-Dist: numpy
|
14
|
+
Requires-Dist: torch
|
16
15
|
Requires-Dist: mtrf
|
17
16
|
Requires-Dist: scipy
|
18
17
|
Dynamic: author
|
@@ -20,5 +19,6 @@ Dynamic: author-email
|
|
20
19
|
Dynamic: classifier
|
21
20
|
Dynamic: description-content-type
|
22
21
|
Dynamic: home-page
|
22
|
+
Dynamic: license-file
|
23
23
|
Dynamic: requires-dist
|
24
24
|
Dynamic: requires-python
|
@@ -0,0 +1,13 @@
|
|
1
|
+
nntrf/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
|
+
nntrf/loss.py,sha256=h-ZzLoBd3S0uRENiqyfwTpb75x6Xp6WjRpNnD-_jraE,651
|
3
|
+
nntrf/metrics.py,sha256=TUkg2yVXhNTW96uxHu4nMYKWvt7jbh_HdHVo9ii3Lpw,782
|
4
|
+
nntrf/utils.py,sha256=h6vdrNvArn88cOhsCXz1_24a-HPiX36RfzNfq-jAdn0,224
|
5
|
+
nntrf/models/__init__.py,sha256=9RSC1DAp-d2ZgKndAk6myWn7ogyMj5ODkLPemBrStnE,71
|
6
|
+
nntrf/models/composite.py,sha256=qgni2yZgTUFAiGP96egSKg5m7PNCmawb6suroV7tuvo,2136
|
7
|
+
nntrf/models/linear.py,sha256=MrsVGnIIkuOOMXcBff1GQOiX2l5isF6_8f_2XOrgTrk,8203
|
8
|
+
nntrf/models/nonlinear.py,sha256=UM-xOgxCsgy9_oMv1Zc84Ep6C9bGnYUXELr_oY-wLZw,45853
|
9
|
+
nntrf-1.0.4.dist-info/licenses/LICENSE,sha256=BxgW39X9rivrb3TVuhpIgjctmWn3HsvEsN4dXtTDNgU,1069
|
10
|
+
nntrf-1.0.4.dist-info/METADATA,sha256=RUJ5wmeeYYQiOGW90CQZpXdcFEzqBhFbF8E00rpVcWw,644
|
11
|
+
nntrf-1.0.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
12
|
+
nntrf-1.0.4.dist-info/top_level.txt,sha256=amAdTZ4QBomZv41i3oWF5lnkU6187LpyDIqhpJSCw4g,6
|
13
|
+
nntrf-1.0.4.dist-info/RECORD,,
|
nntrf-1.0.1.dist-info/RECORD
DELETED
@@ -1,13 +0,0 @@
|
|
1
|
-
nntrf/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
|
-
nntrf/loss.py,sha256=h-ZzLoBd3S0uRENiqyfwTpb75x6Xp6WjRpNnD-_jraE,651
|
3
|
-
nntrf/metrics.py,sha256=TUkg2yVXhNTW96uxHu4nMYKWvt7jbh_HdHVo9ii3Lpw,782
|
4
|
-
nntrf/utils.py,sha256=h6vdrNvArn88cOhsCXz1_24a-HPiX36RfzNfq-jAdn0,224
|
5
|
-
nntrf/models/__init__.py,sha256=9RSC1DAp-d2ZgKndAk6myWn7ogyMj5ODkLPemBrStnE,71
|
6
|
-
nntrf/models/composite.py,sha256=3nQIy0MZmgo9-OCj0mPD3hcUUkHG0lK1h9tyyZtIaiM,2139
|
7
|
-
nntrf/models/linear.py,sha256=DPcQK6ol0zXaDo8p1S5zb-075fG1L0NbfwMlQzm_fyw,8205
|
8
|
-
nntrf/models/nonlinear.py,sha256=E0bQAa78DMd8YgQ7JOSmyUbdfq_UwgXQTJ9UmzN15vg,45536
|
9
|
-
nntrf-1.0.1.dist-info/LICENSE,sha256=BxgW39X9rivrb3TVuhpIgjctmWn3HsvEsN4dXtTDNgU,1069
|
10
|
-
nntrf-1.0.1.dist-info/METADATA,sha256=kY-Zfhrj2GS5wElgNcX-sVw-n6v2HzZLRKo2_u1Zd8M,678
|
11
|
-
nntrf-1.0.1.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
|
12
|
-
nntrf-1.0.1.dist-info/top_level.txt,sha256=amAdTZ4QBomZv41i3oWF5lnkU6187LpyDIqhpJSCw4g,6
|
13
|
-
nntrf-1.0.1.dist-info/RECORD,,
|
File without changes
|
File without changes
|