xttmp 2.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xttmp/__init__.py +1 -0
- xttmp/api/__init__.py +5 -0
- xttmp/api/evaluate.py +163 -0
- xttmp/api/get_visualize_handle.py +29 -0
- xttmp/api/instancing_model.py +35 -0
- xttmp/core/__init__.py +0 -0
- xttmp/core/apgstmd_core.py +188 -0
- xttmp/core/apgstmdv2_core.py +79 -0
- xttmp/core/base_core.py +36 -0
- xttmp/core/dstmd_core.py +213 -0
- xttmp/core/estmd_backbone.py +110 -0
- xttmp/core/estmd_core.py +356 -0
- xttmp/core/feedbackstmd_core.py +61 -0
- xttmp/core/fracstmd_core.py +98 -0
- xttmp/core/fstmd_core.py +15 -0
- xttmp/core/fstmdv2_core.py +42 -0
- xttmp/core/haarstmd_core.py +140 -0
- xttmp/core/math_operator.py +307 -0
- xttmp/core/stfeedbackstmd_core.py +233 -0
- xttmp/core/stmdplus_core.py +187 -0
- xttmp/core/stmdplusv2_core.py +82 -0
- xttmp/core/vstmd_core.py +420 -0
- xttmp/demo/evaluate_model.py +92 -0
- xttmp/demo/inference_gui.py +148 -0
- xttmp/demo/inference_gui_single_process.py +134 -0
- xttmp/demo/inference_image_stream.py +67 -0
- xttmp/demo/inference_video.py +66 -0
- xttmp/main.py +14 -0
- xttmp/model/__init__.py +13 -0
- xttmp/model/backbone.py +514 -0
- xttmp/model/facilitated_model.py +230 -0
- xttmp/model/feedback_model.py +271 -0
- xttmp/model/haarstmd.py +61 -0
- xttmp/model/vstmd.py +457 -0
- xttmp/util/__init__.py +0 -0
- xttmp/util/compute_module.py +402 -0
- xttmp/util/create_kernel.py +363 -0
- xttmp/util/evaluate_module.py +697 -0
- xttmp/util/iostream.py +660 -0
- xttmp-2.3.0.dist-info/METADATA +85 -0
- xttmp-2.3.0.dist-info/RECORD +45 -0
- xttmp-2.3.0.dist-info/WHEEL +5 -0
- xttmp-2.3.0.dist-info/entry_points.txt +2 -0
- xttmp-2.3.0.dist-info/licenses/LICENSE +201 -0
- xttmp-2.3.0.dist-info/top_level.txt +1 -0
xttmp/model/vstmd.py
ADDED
|
@@ -0,0 +1,457 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import torch
|
|
3
|
+
import torch.nn.functional as F
|
|
4
|
+
import scipy.ndimage
|
|
5
|
+
|
|
6
|
+
from ..core import estmd_core, vstmd_core
|
|
7
|
+
from ..util.compute_module import AreaNMS
|
|
8
|
+
from .backbone import BaseModel
|
|
9
|
+
from copy import deepcopy
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class vSTMD(BaseModel):
    """ STMDNet: A Lightweight Directional Framework for Motion Pattern Recognition of Tiny Targets.

    Processing chain: retina -> lamina (split into ON / OFF signals) ->
    medulla (ON / OFF pathways) -> lobula, which returns both a response map
    and a direction map.

    Ref:
        * Xu M, Luan H, Hao Z D, et al. STMDNet: A Lightweight Directional Framework for Motion Pattern Recognition of Tiny Targets[J]. arXiv preprint arXiv:2501.13054, 2025.
    """

    # Public parameter name -> dotted attribute path inside the model.
    # A tuple value lists several attribute paths bound to the same name.
    __paraMappingList = {
        # retina
        'sigma1': 'retina.sigma',
        # lamina
        'alpha': 'lamina.alpha',
        'delta': 'lamina.delta',
        # medulla
        'g_leak': ('medulla.on_pathway.g_leak', 'medulla.off_pathway.g_leak'),
        'v_rest': ('medulla.on_pathway.v_rest', 'medulla.off_pathway.v_rest'),
        'vEx': ('medulla.on_pathway.v_exci', 'medulla.off_pathway.v_exci'),
        # lobula
        'A': 'lobula.spatial_inhibition.A',
        'B': 'lobula.spatial_inhibition.B',
        'e': 'lobula.spatial_inhibition.e',
        'rho': 'lobula.spatial_inhibition.rho',
        'sigma2': 'lobula.spatial_inhibition.sigma1',
        'sigma3': 'lobula.spatial_inhibition.sigma2',
    }

    def __init__(self):
        """ Build the four processing stages and apply the vSTMD defaults. """
        super().__init__()

        # Processing stages; the retina comes from estmd_core, the rest from vstmd_core.
        self.retina = estmd_core.Retina()
        self.lamina = vstmd_core.Lamina()
        self.medulla = vstmd_core.Medulla()
        self.lobula = vstmd_core.Lobula()

        # vSTMD-specific overrides of the component defaults.
        self.lamina.alpha = 0.25
        for pathway in (self.medulla.on_pathway, self.medulla.off_pathway):
            pathway.g_leak = 0.35

    def forward(self, modelIpt):
        """ Feed ``modelIpt`` through retina, lamina, medulla and lobula.

        The lobula's two outputs are stored in ``self.model_output`` under
        'response' and 'direction'; ``self.model_output`` is returned.
        """
        filtered = self.retina.forward(modelIpt)
        lam_on, lam_off = self.lamina.forward(filtered)
        med_on, med_off = self.medulla.forward(lam_on, lam_off)

        # The lobula also receives the lamina signals alongside the medulla outputs.
        lobula_result = self.lobula.forward(med_on, med_off, lam_on, lam_off)
        self.model_output['response'], self.model_output['direction'] = lobula_result

        return self.model_output
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class vSTMD_F(vSTMD):
    """ vSTMD_F: vSTMD with Feedback Mechanism.

    Same pipeline as vSTMD, but the lobula is a
    ``vstmd_core.Lobula_with_Feedback``. Its ``beta`` and ``sigma`` attributes
    are exposed as the 'beta' and 'sigma_4' parameters.
    """

    # Parameter bindings: all of vSTMD's, followed by the feedback lobula's.
    __paraMappingList = {
        **deepcopy(vSTMD._vSTMD__paraMappingList),
        'beta': 'lobula.beta',
        'sigma_4': 'lobula.sigma',
    }

    def __init__(self):
        """ Build a vSTMD, then replace its lobula with the feedback-enabled one. """
        super().__init__()

        # Swap in the lobula with the feedback mechanism.
        self.lobula = vstmd_core.Lobula_with_Feedback()
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
class vSTMD_L(vSTMD):
    """ vSTMD_L: vSTMD Location Module.

    Location-only variant. The medulla ON and OFF outputs are multiplied, and
    the product goes through the lobula's spatial inhibition only. ``forward``
    writes ``model_output['response']`` but not ``model_output['direction']``.
    """

    # Same parameter bindings as vSTMD.
    __paraMappingList = deepcopy(vSTMD._vSTMD__paraMappingList)

    def forward(self, x):
        """ DaC: Dynamics and Correlate.

        The product of the medulla ON/OFF outputs is kept in
        ``self.correlation_output``; its spatially inhibited version becomes
        ``model_output['response']``.
        """
        # Gaussian denoising (retina), then temporal difference and ON/OFF
        # separation (lamina), then the medulla pathways.
        denoised = self.retina.forward(x)
        lam_on, lam_off = self.lamina.forward(denoised)
        med_on, med_off = self.medulla.forward(lam_on, lam_off)

        # Location only: correlate ON with OFF, then apply spatial inhibition.
        self.correlation_output = med_on * med_off
        inhibited = self.lobula.spatial_inhibition.forward(self.correlation_output)
        self.model_output['response'] = inhibited

        return self.model_output
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
class vSTMD_F_L(vSTMD_F):
    """ vSTMD_F_L: location-only variant of vSTMD_F.

    ``forward`` follows the same path as vSTMD_L. The product of the medulla
    ON and OFF outputs goes through the spatial inhibition of this model's
    lobula (the feedback lobula installed by vSTMD_F). Only
    ``model_output['response']`` is written.
    """

    # Same parameter bindings as vSTMD_F.
    __paraMappingList = deepcopy(vSTMD_F._vSTMD_F__paraMappingList)

    def forward(self, x):
        """ DaC: Dynamics and Correlate.

        Stores the ON*OFF product in ``self.correlation_output`` and its
        spatially inhibited version in ``model_output['response']``.
        """
        # Retina (Gaussian denoising) -> lamina (temporal difference, ON/OFF split) -> medulla.
        denoised = self.retina.forward(x)
        lam_on, lam_off = self.lamina.forward(denoised)
        med_on, med_off = self.medulla.forward(lam_on, lam_off)

        # Location only: ON/OFF correlation followed by spatial inhibition.
        self.correlation_output = med_on * med_off
        inhibited = self.lobula.spatial_inhibition.forward(self.correlation_output)
        self.model_output['response'] = inhibited

        return self.model_output
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
class vSTMD_M(vSTMD_L):
    """ vSTMD_M: vSTMD_L plus direction estimation by matching detections.

    After the vSTMD_L forward pass, the response is thinned with
    ``AreaNMS(radio=8)``. The remaining positive locations are handed to a
    ``vstmd_core.FastEuclideanTracker``, whose tracks populate the direction map.
    """

    # Same parameter bindings as vSTMD_L.
    __paraMappingList = deepcopy(vSTMD_L._vSTMD_L__paraMappingList)

    def __init__(self):
        """ Build the vSTMD_L pipeline plus the tracker and the NMS stage. """
        super().__init__()
        self.direction_computer = vstmd_core.FastEuclideanTracker()
        self.torch_nms = AreaNMS(radio=8)

    def get_direction_by_matching(self, model_response):
        """ Build a direction map from the tracker's output.

        Args:
            model_response: response tensor; its positive entries are passed
                to the tracker as detections. The write below indexes the
                first two axes with 0, so a 4-D tensor such as (1, 1, H, W)
                is assumed — TODO confirm.

        Returns:
            A tensor shaped like ``model_response``. It is NaN everywhere
            except at the positions reported by the tracker, which hold the
            tracked direction. If there are no detections, the tracker is not
            updated and the all-NaN map is returned.
        """
        out = torch.full_like(model_response, float('nan'), device=model_response.device)

        detections = torch.argwhere(model_response > 0)
        if len(detections) == 0:
            return out

        tracks = self.direction_computer.update(detections)
        if len(tracks) > 0:
            # Each track row is read as (row index, column index, direction).
            track_tensor = torch.as_tensor(tracks, device=model_response.device)
            rows = track_tensor[:, 0].long()
            cols = track_tensor[:, 1].long()
            # Vectorized write via advanced indexing.
            out[0, 0, rows, cols] = track_tensor[:, 2].float()

        return out

    def forward(self, modelIpt):
        """ Run vSTMD_L, then derive ``model_output['direction']`` by matching.

        The direction map is computed from the NMS-thinned response; the NMS
        result itself is not stored in ``model_output``.
        """
        super().forward(modelIpt)

        thinned = self.torch_nms(self.model_output['response'])
        self.model_output['direction'] = self.get_direction_by_matching(thinned)

        return self.model_output
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
class vSTMD_F_M(vSTMD_F_L):
    """ vSTMD_F_M: vSTMD_F_L plus direction estimation by matching detections.

    Feedback counterpart of vSTMD_M. The location-only response of vSTMD_F_L
    is thinned with ``AreaNMS(radio=8)``, and the remaining positive
    locations are fed to a ``vstmd_core.FastEuclideanTracker`` to build the
    direction map.
    """

    # Parameter bindings copied from the direct parent vSTMD_F_L (as vSTMD_M
    # does with vSTMD_L), so the feedback lobula's 'beta' and 'sigma_4'
    # remain mapped for this model too.
    __paraMappingList = deepcopy(vSTMD_F_L._vSTMD_F_L__paraMappingList)

    def __init__(self):
        """ Build the vSTMD_F_L pipeline plus the tracker and the NMS stage. """
        super().__init__()
        self.direction_computer = vstmd_core.FastEuclideanTracker()

        self.torch_nms = AreaNMS(radio=8)

    def get_direction_by_matching(self, model_response):
        """ Build a direction map from the tracker's output.

        Args:
            model_response: response tensor; its positive entries are passed
                to the tracker as detections. The write below indexes the
                first two axes with 0, so a 4-D tensor such as (1, 1, H, W)
                is assumed — TODO confirm.

        Returns:
            A tensor shaped like ``model_response``. It is NaN everywhere
            except at the positions reported by the tracker, which hold the
            tracked direction. If there are no detections, the tracker is not
            updated and the all-NaN map is returned.
        """
        device = model_response.device

        direction_output = torch.full_like(model_response, float('nan'), device=device)

        responses = torch.argwhere(model_response > 0)
        if len(responses) == 0:
            return direction_output

        tracks = self.direction_computer.update(responses)

        if len(tracks) > 0:
            # Each track row is read as (row index, column index, direction).
            tracks_t = torch.as_tensor(tracks, device=device)
            dim_0 = tracks_t[:, 0].long()
            dim_1 = tracks_t[:, 1].long()
            directions = tracks_t[:, 2].float()

            # Vectorized assignment (advanced indexing).
            direction_output[0, 0, dim_0, dim_1] = directions

        return direction_output

    def forward(self, modelIpt):
        """ Run vSTMD_F_L, then derive ``model_output['direction']`` by matching.

        The direction map is computed from the NMS-thinned response; the NMS
        result itself is not stored in ``model_output``.
        """
        super().forward(modelIpt)

        response = self.torch_nms(self.model_output['response'])
        self.model_output['direction'] = self.get_direction_by_matching(response)

        return self.model_output
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
# ablation
class vSTMD_without_GF(vSTMD):
    """ Ablation of vSTMD with the retina (Gaussian filter, GF) stage removed.

    The input goes straight into the lamina. The retina created by the parent
    constructor is never called.
    """

    # Same bindings as vSTMD. 'sigma1' still points at retina.sigma even
    # though this variant does not run the retina.
    __paraMappingList = deepcopy(vSTMD._vSTMD__paraMappingList)

    def forward(self, x):
        """ Lamina -> medulla -> lobula, skipping the retina. """
        lam_on, lam_off = self.lamina.forward(x)
        med_on, med_off = self.medulla.forward(lam_on, lam_off)

        lobula_result = self.lobula.forward(med_on, med_off, lam_on, lam_off)
        self.model_output['response'], self.model_output['direction'] = lobula_result

        return self.model_output
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
class vSTMD_without_cIDP(vSTMD):
    """ Ablation of vSTMD without cIDP.

    The medulla OFF output is passed through ``GammaDelay(12, 25)`` before it
    reaches the lobula. Everything else matches vSTMD.
    """

    # Same parameter bindings as vSTMD (the delay's settings are not exposed).
    __paraMappingList = deepcopy(vSTMD._vSTMD__paraMappingList)

    def __init__(self):
        """ Constructor function """
        # Docstring must come first to be the method's __doc__; the local
        # import follows it. Only the cIDP ablations need GammaDelay.
        from ..core.math_operator import GammaDelay

        super().__init__()

        # Initialize components
        self.gamma_delay = GammaDelay(12, 25)

    def forward(self, x):
        """ vSTMD pipeline with the medulla OFF output delayed before the lobula. """
        retina_output = self.retina.forward(x)
        lamina_ON, lamina_OFF = self.lamina.forward(retina_output)
        medulla_ON, medulla_OFF = self.medulla.forward(lamina_ON, lamina_OFF)

        # Delay the OFF channel with the GammaDelay built in __init__.
        delayed_medulla_OFF = self.gamma_delay.forward(medulla_OFF)
        self.model_output['response'], self.model_output['direction'] \
            = self.lobula.forward(medulla_ON, delayed_medulla_OFF, lamina_ON, lamina_OFF)

        return self.model_output
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
class vSTMD_without_CDGC(BaseModel):
    """ Ablation of vSTMD without CDGC.

    The response map comes from a plain ``vSTMD``. The direction map comes
    from a separate ``DSTMD`` model and is then propagated onto responding
    pixels that lack a direction (see ``match_direction``).
    """

    # No tunable parameters are exposed for this composite model.
    __paraMappingList = {}

    def __init__(self):
        """ Constructor function """
        from ..model.backbone import DSTMD

        super().__init__()
        self.location_part = vSTMD()
        self.direction_part = DSTMD()

    def forward(self, modelIpt):
        """ Define the structure of the model """
        # Response: taken unchanged from the vSTMD sub-model.
        self.location_part.forward(modelIpt)
        self.model_output['response'] = self.location_part.model_output['response']

        # Direction: DSTMD's direction map, filled in at responding pixels.
        self.direction_part.forward(modelIpt)
        _direction = self.direction_part.model_output['direction']
        self.model_output['direction'] = self.match_direction(self.model_output['response'], _direction)

        return self.model_output

    @staticmethod
    def match_direction(response, direction):
        """
        Fill NaN directions at positively-responding locations from the nearest valid direction.

        For every element where ``response > 0`` and ``direction`` is NaN, the
        value of the nearest non-NaN element of ``direction`` is copied in.
        "Nearest" is Euclidean distance over all tensor axes, computed on the
        CPU with ``scipy.ndimage.distance_transform_edt``. All other elements
        are returned unchanged. If ``direction`` has no non-NaN element, there
        is nothing to copy from and the NaNs are kept.

        Args:
            response: response tensor, same shape as ``direction``.
            direction: direction tensor, NaN where no direction is known.

        Returns:
            A new tensor; ``direction`` itself is not modified.
        """
        mask_nan = torch.isnan(direction)
        # Only fill where response > 0 and direction is NaN.
        fill_mask = mask_nan & (response > 0)

        directionOpt = direction.clone()

        # Early exits:
        #  * nothing to fill: skip the CPU round-trip and the distance transform;
        #  * no valid direction at all: the transform has no background element
        #    to point at, so its indices would be meaningless. Keep the NaNs
        #    rather than filling them with an arbitrary value.
        if not bool(fill_mask.any()) or bool(mask_nan.all()):
            return directionOpt

        # NaN positions are the foreground (True). For each of them, the EDT
        # returns the coordinates of the closest non-NaN element; non-NaN
        # elements map to themselves. Only the indices are needed.
        # NOTE(review): distance is measured across every axis, including any
        # leading batch/channel axes. Confirm this is intended when those axes
        # have size > 1.
        nan_mask_np = mask_nan.cpu().numpy()
        indices_np = scipy.ndimage.distance_transform_edt(
            nan_mask_np, return_distances=False, return_indices=True
        )

        # Move the index arrays back to the tensor's original device.
        indices = torch.from_numpy(indices_np).to(device=direction.device, dtype=torch.long)

        # Direction of the nearest valid element, for every position.
        nearest_direction = direction[tuple(indices)]

        # Overwrite only the responding positions that had no direction.
        directionOpt[fill_mask] = nearest_direction[fill_mask]

        return directionOpt
|
|
361
|
+
|
|
362
|
+
|
|
363
|
+
class vSTMD_F_without_GF(vSTMD_F):
    """ Ablation of vSTMD_F with the retina (Gaussian filter, GF) stage removed.

    The input goes straight into the lamina. The retina created by the parent
    constructor is never called.
    """

    # Start from vSTMD_F's bindings so the feedback lobula's 'beta' and
    # 'sigma_4' stay mapped, as they are for vSTMD_F itself. 'sigma1' still
    # points at retina.sigma even though the retina is unused here.
    __paraMappingList = deepcopy(vSTMD_F._vSTMD_F__paraMappingList)

    def forward(self, x):
        """ Lamina -> medulla -> feedback lobula, skipping the retina. """
        lamina_ON, lamina_OFF = self.lamina.forward(x)
        medulla_ON, medulla_OFF = self.medulla.forward(lamina_ON, lamina_OFF)
        self.model_output['response'], self.model_output['direction'] \
            = self.lobula.forward(medulla_ON, medulla_OFF, lamina_ON, lamina_OFF)

        return self.model_output
|
|
393
|
+
|
|
394
|
+
|
|
395
|
+
class vSTMD_F_without_cIDP(vSTMD_F):
    """ Ablation of vSTMD_F without cIDP.

    The medulla OFF output is passed through ``GammaDelay(12, 25)`` before it
    reaches the feedback lobula. Everything else matches vSTMD_F.
    """

    # Start from vSTMD_F's bindings so the feedback lobula's 'beta' and
    # 'sigma_4' stay mapped (the delay's settings are not exposed).
    __paraMappingList = deepcopy(vSTMD_F._vSTMD_F__paraMappingList)

    def __init__(self):
        """ Constructor function """
        # Docstring must come first to be the method's __doc__; the local
        # import follows it. Only the cIDP ablations need GammaDelay.
        from ..core.math_operator import GammaDelay

        super().__init__()

        self.gamma_delay = GammaDelay(12, 25)

    def forward(self, x):
        """ vSTMD_F pipeline with the medulla OFF output delayed before the lobula. """
        retina_output = self.retina.forward(x)
        lamina_ON, lamina_OFF = self.lamina.forward(retina_output)
        medulla_ON, medulla_OFF = self.medulla.forward(lamina_ON, lamina_OFF)

        # Delay the OFF channel with the GammaDelay built in __init__.
        delayed_medulla_OFF = self.gamma_delay.forward(medulla_OFF)
        self.model_output['response'], self.model_output['direction'] \
            = self.lobula.forward(medulla_ON, delayed_medulla_OFF, lamina_ON, lamina_OFF)

        return self.model_output
|
|
434
|
+
|
|
435
|
+
|
|
436
|
+
class vSTMD_F_without_CDGC(BaseModel):
    """ Ablation of vSTMD_F without CDGC.

    The response map comes from ``vSTMD_F``. The direction map comes from a
    separate ``DSTMD`` model and is propagated onto responding pixels with
    ``vSTMD_without_CDGC.match_direction``.
    """

    # No tunable parameters are exposed for this composite model.
    __paraMappingList = {}

    def __init__(self):
        """ Build the vSTMD_F location model and the DSTMD direction model. """
        from ..model.backbone import DSTMD

        super().__init__()
        self.location_part = vSTMD_F()
        self.direction_part = DSTMD()

    def forward(self, modelIpt):
        """ Run both sub-models and combine their outputs into ``model_output``. """
        locator = self.location_part
        director = self.direction_part

        # Response: taken unchanged from vSTMD_F.
        locator.forward(modelIpt)
        self.model_output['response'] = locator.model_output['response']

        # Direction: DSTMD's map, filled in at responding pixels.
        director.forward(modelIpt)
        raw_direction = director.model_output['direction']
        self.model_output['direction'] = vSTMD_without_CDGC.match_direction(
            self.model_output['response'], raw_direction
        )

        return self.model_output
|
xttmp/util/__init__.py
ADDED
|
File without changes
|