xttmp 2.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xttmp/__init__.py +1 -0
- xttmp/api/__init__.py +5 -0
- xttmp/api/evaluate.py +163 -0
- xttmp/api/get_visualize_handle.py +29 -0
- xttmp/api/instancing_model.py +35 -0
- xttmp/core/__init__.py +0 -0
- xttmp/core/apgstmd_core.py +188 -0
- xttmp/core/apgstmdv2_core.py +79 -0
- xttmp/core/base_core.py +36 -0
- xttmp/core/dstmd_core.py +213 -0
- xttmp/core/estmd_backbone.py +110 -0
- xttmp/core/estmd_core.py +356 -0
- xttmp/core/feedbackstmd_core.py +61 -0
- xttmp/core/fracstmd_core.py +98 -0
- xttmp/core/fstmd_core.py +15 -0
- xttmp/core/fstmdv2_core.py +42 -0
- xttmp/core/haarstmd_core.py +140 -0
- xttmp/core/math_operator.py +307 -0
- xttmp/core/stfeedbackstmd_core.py +233 -0
- xttmp/core/stmdplus_core.py +187 -0
- xttmp/core/stmdplusv2_core.py +82 -0
- xttmp/core/vstmd_core.py +420 -0
- xttmp/demo/evaluate_model.py +92 -0
- xttmp/demo/inference_gui.py +148 -0
- xttmp/demo/inference_gui_single_process.py +134 -0
- xttmp/demo/inference_image_stream.py +67 -0
- xttmp/demo/inference_video.py +66 -0
- xttmp/main.py +14 -0
- xttmp/model/__init__.py +13 -0
- xttmp/model/backbone.py +514 -0
- xttmp/model/facilitated_model.py +230 -0
- xttmp/model/feedback_model.py +271 -0
- xttmp/model/haarstmd.py +61 -0
- xttmp/model/vstmd.py +457 -0
- xttmp/util/__init__.py +0 -0
- xttmp/util/compute_module.py +402 -0
- xttmp/util/create_kernel.py +363 -0
- xttmp/util/evaluate_module.py +697 -0
- xttmp/util/iostream.py +660 -0
- xttmp-2.3.0.dist-info/METADATA +85 -0
- xttmp-2.3.0.dist-info/RECORD +45 -0
- xttmp-2.3.0.dist-info/WHEEL +5 -0
- xttmp-2.3.0.dist-info/entry_points.txt +2 -0
- xttmp-2.3.0.dist-info/licenses/LICENSE +201 -0
- xttmp-2.3.0.dist-info/top_level.txt +1 -0
xttmp/core/dstmd_core.py
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
1
|
+
from collections import deque
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import torch.nn.functional as F
|
|
5
|
+
|
|
6
|
+
from .base_core import BaseCore
|
|
7
|
+
from .estmd_core import Tm1, Mi1
|
|
8
|
+
from .math_operator import GammaDelay, SpatialInhibition
|
|
9
|
+
from ..util.create_kernel import create_direction_inhi_kernel
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class Medulla(BaseCore):
    """Medulla layer of the DSTMD model.

    Splits the input into a half-wave rectified ON channel (Tm3) and OFF
    channel (Tm2), then produces delayed copies of them with Gamma delay
    filters: one Mi1 filter on the ON channel and two Tm1 filters with
    different parameters on the OFF channel.
    """

    def __init__(self):
        """Create the ON/OFF rectifiers and the three delay filters."""
        super().__init__()

        # Positional arguments of Mi1/Tm1 are forwarded to the Gamma delay
        # constructor -- presumably (order, tau), cf. Tm1(order=..., tau=...)
        # in estmd_core; TODO confirm.
        self.tm3 = Tm3()
        self.mi1_para4 = Mi1(3, 15)

        self.tm2 = Tm2()
        self.tm1_para5 = Tm1(5, 25)
        self.tm1_para6 = Tm1(8, 40)

    def setup(self):
        """Initialise the three delay filters."""
        self.mi1_para4.setup()
        self.tm1_para5.setup()
        self.tm1_para6.setup()

    def reset_buffer(self):
        """Clear the temporal state of the three delay filters.

        Without this, frames from a previous sequence would keep influencing
        the delayed outputs of the next one.
        """
        self.mi1_para4.reset_buffer()
        self.tm1_para5.reset_buffer()
        self.tm1_para6.reset_buffer()

    def forward(self, x):
        """Process one frame.

        Parameters:
            - x: input tensor (lamina output).

        Returns:
            Tuple ``(tm3, mi1_para4, tm1_para5, tm1_para6)``: the ON channel,
            the delayed ON channel and two differently delayed OFF channels.
            The tuple is also stored on ``self.output``.
        """
        tm3_output = self.tm3.forward(x)  # ON channel: max(x, 0)
        tm2_output = self.tm2.forward(x)  # OFF channel: max(-x, 0)

        mi1_para4_output = self.mi1_para4.forward(tm3_output)
        tm1_para5_output = self.tm1_para5.forward(tm2_output)
        tm1_para6_output = self.tm1_para6.forward(tm2_output)

        self.output = (tm3_output, mi1_para4_output, tm1_para5_output, tm1_para6_output)
        return self.output
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class Tm2(BaseCore):
    """Tm2 neuron of DSTMD: half-wave rectified OFF channel."""

    def forward(self, iptMatrix):
        """Return ``max(-iptMatrix, 0)`` and cache it on ``self.output``.

        Negative input values become positive responses; positive values
        are suppressed to zero.
        """
        off_signal = (-iptMatrix).clamp(min=0)
        self.output = off_signal
        return off_signal
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class Tm3(BaseCore):
    """Tm3 neuron of DSTMD: half-wave rectified ON channel."""

    def forward(self, iptMatrix):
        """Return ``max(iptMatrix, 0)`` and cache it on ``self.output``.

        Positive input values pass unchanged; negative values are set to zero.
        """
        on_signal = iptMatrix.clamp(min=0)
        self.output = on_signal
        return on_signal
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class Lobula(BaseCore):
    """Lobula layer of the DSTMD model.

    Correlates the medulla ON/OFF channels with spatially shifted, delayed
    copies along eight preferred directions, then applies lateral (spatial)
    inhibition followed by directional inhibition.
    """

    def __init__(self):
        """Create the direction table and the two inhibition stages."""
        super().__init__()
        # Pixel distance between a pixel and the shifted sample it is
        # correlated with.
        self.alpha1 = 3

        # Eight preferred directions: theta_i = i * pi / 4.
        self.register_buffer('theta_list', torch.tensor([(i * torch.pi / 4) for i in range(8)]))
        self.hLateralInhi = SpatialInhibition()
        self.hDirectionInhi = DirectionInhibition()

    def setup(self):
        """Initialise the lateral and directional inhibition stages."""
        self.hLateralInhi.setup()
        self.hDirectionInhi.setup()

    def forward(self, tm3, mi1_p4, tm1_p5, tm1_p6):
        """Compute the direction-selective lobula response.

        Parameters:
            - tm3: ON channel, expected shape [B, 1, H, W].
            - mi1_p4: delayed ON channel; only channel 0 is read.
            - tm1_p5: delayed OFF channel, expected shape [B, 1, H, W].
            - tm1_p6: second delayed OFF channel; only channel 0 is read.

        For direction i, with dx_i = round(alpha1 * cos(theta_i)) and
        dy_i = round(alpha1 * sin(theta_i)), the correlation at (y, x) is

            tm3[y, x] * (tm1_p5[y, x] + mi1_p4[y + dy_i, x - dx_i])
                      * tm1_p6[y + dy_i, x - dx_i]

        Pixels closer than alpha1 to the border are left at zero.

        Returns:
            Output of the directional inhibition stage, one channel per
            direction (presumably [B, 8, H, W]; the exact shape depends on
            SpatialInhibition). Also stored on ``self.output``.
        """
        _, _, imgH, imgW = tm3.shape
        device = tm3.device

        num_thetas = len(self.theta_list)
        a1 = self.alpha1

        # Per-direction integer pixel offsets, shape [num_thetas].
        shifts_x = torch.round(a1 * torch.cos(self.theta_list)).long()
        shifts_y = torch.round(a1 * torch.sin(self.theta_list)).long()

        # Interior region where every shifted index stays in bounds, because
        # |shift| <= a1. Clamp the end so tiny images yield an empty region
        # instead of an invalid torch.arange range.
        y_s, y_e = a1, max(imgH - a1, a1)
        x_s, x_e = a1, max(imgW - a1, a1)

        tm3_roi = tm3[:, :, y_s:y_e, x_s:x_e]        # [B, 1, h, w]
        tm1_p5_roi = tm1_p5[:, :, y_s:y_e, x_s:x_e]  # [B, 1, h, w]

        grid_y, grid_x = torch.meshgrid(
            torch.arange(y_s, y_e, device=device),
            torch.arange(x_s, x_e, device=device),
            indexing='ij',
        )

        # Source indices for every direction, shape [num_thetas, h, w].
        src_idx_x = grid_x.unsqueeze(0) - shifts_x.view(-1, 1, 1)
        src_idx_y = grid_y.unsqueeze(0) + shifts_y.view(-1, 1, 1)

        # Gather the shifted samples from channel 0 of every batch element:
        # [B, H, W] indexed by two [num_thetas, h, w] index tensors gives
        # [B, num_thetas, h, w].
        mi1_p4_shifted = mi1_p4[:, 0][:, src_idx_y, src_idx_x]
        tm1_p6_shifted = tm1_p6[:, 0][:, src_idx_y, src_idx_x]

        # Broadcasting [B, 1, h, w] against [B, num_thetas, h, w].
        corre_roi = tm3_roi * (tm1_p5_roi + mi1_p4_shifted) * tm1_p6_shifted

        # Place the interior result in a zero frame. new_zeros keeps the
        # input dtype and device, so no silent cast to float32.
        correOutput = corre_roi.new_zeros((corre_roi.shape[0], num_thetas, imgH, imgW))
        correOutput[:, :, y_s:y_e, x_s:x_e] = corre_roi

        lateralInhioutput = self.hLateralInhi.forward(correOutput)
        self.output = self.hDirectionInhi.forward(lateralInhioutput)

        return self.output
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
class DirectionInhibition(BaseCore):
    """Directional inhibition in DSTMD.

    For every pixel, the responses over the preferred directions (channel
    axis) are circularly convolved with a fixed 1-D kernel produced by
    ``create_direction_inhi_kernel`` and then half-wave rectified.
    """

    def __init__(self):
        """Set the kernel parameters and build the kernel."""
        super().__init__()
        self.direction = 8   # number of preferred directions
        self.sigma1 = 1.5    # first sigma passed to create_direction_inhi_kernel
        self.sigma2 = 3.0    # second sigma passed to create_direction_inhi_kernel

        # Placeholder filled by setup(). The (misspelled) name is kept because
        # it is the buffer's state_dict key.
        self.register_buffer("diretional_inhi_kernel", torch.empty(0))

        self.setup()

    def setup(self):
        """(Re)build the directional inhibition kernel.

        The new kernel is moved to the device the buffer currently lives on.
        Otherwise calling setup() after ``module.to(device)`` would silently
        put a CPU kernel back into a GPU module.
        """
        kernel = create_direction_inhi_kernel(
            self.direction, self.sigma1, self.sigma2
        )
        # conv1d in forward() uses this as its weight; the original comment
        # gives its shape as [1, 1, kernel_size].
        self.diretional_inhi_kernel.data = kernel.to(self.diretional_inhi_kernel.device)

    def forward(self, x):
        """Apply directional inhibition.

        Parameters:
            - x: tensor of shape [B, C, H, W], where C indexes the preferred
              directions (C = self.direction).

        Returns:
            Rectified tensor of the same shape [B, C, H, W].
        """
        b, c, h, w = x.shape

        # [B, C, H, W] -> [B, H, W, C] -> [B*H*W, 1, C]: each pixel becomes a
        # single-channel 1-D signal over the direction axis, so conv1d slides
        # along the directions.
        signals = x.permute(0, 2, 3, 1).reshape(b * h * w, 1, c)

        # Split (kernel_size - 1) padding around the kernel centre so the
        # output keeps length C. 'circular' wraps the direction axis, making
        # the last direction a neighbour of the first.
        # NOTE(review): circular padding requires each pad amount <= C.
        kernel_size = self.diretional_inhi_kernel.shape[-1]
        pad_left = kernel_size // 2
        pad_right = kernel_size - pad_left - 1
        padded = F.pad(signals, (pad_left, pad_right), mode='circular')

        # conv1d computes a cross-correlation (the kernel is not flipped).
        # Result shape: [B*H*W, 1, C].
        response = F.relu(F.conv1d(padded, self.diretional_inhi_kernel))

        # [B*H*W, 1, C] -> [B, H, W, C] -> [B, C, H, W]
        return response.reshape(b, h, w, c).permute(0, 3, 1, 2)
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
|
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
from .base_core import BaseCore
|
|
4
|
+
from .math_operator import GammaBandPassFilter, SpatialInhibition
|
|
5
|
+
from . import estmd_core
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class Lamina(GammaBandPassFilter):
    """Lamina layer of the ESTMD backbone: an unmodified Gamma band-pass filter."""
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class Medulla(BaseCore):
    """Medulla layer of the ESTMD backbone (no spatial inhibition in Tm2/Tm3)."""

    def __init__(self):
        """Create the Tm1/Mi1 delay filters and the Tm2/Tm3 rectifiers."""
        super().__init__()
        self.tm1 = estmd_core.Tm1()
        # NOTE: mi1 is created, set up and reset, but forward() does not use it.
        self.mi1 = estmd_core.Mi1()
        self.tm2 = Tm2()
        self.tm3 = Tm3()

    def setup(self):
        """Initialise all sub-modules."""
        self.tm1.setup()
        self.mi1.setup()
        self.tm2.setup()
        self.tm3.setup()

    def reset_buffer(self):
        """Clear the temporal state of the Gamma delay filters.

        This mirrors estmd_core.Medulla.reset_buffer, so delayed frames from a
        previous sequence do not leak into the next one.
        """
        self.tm1.reset_buffer()
        self.mi1.reset_buffer()

    def forward(self, MedullaIpt):
        """Process one frame.

        Parameters:
            - MedullaIpt: input tensor (lamina output).

        Returns:
            List ``[tm3_output, tm1_output]``: the rectified ON channel and the
            delayed rectified OFF channel. It is also stored on ``self.output``.
        """
        tm2_output = self.tm2.forward(MedullaIpt)  # OFF channel: max(-x, 0)
        tm3_output = self.tm3.forward(MedullaIpt)  # ON channel: max(x, 0)

        # Delay the OFF channel.
        tm1_output = self.tm1.forward(tm2_output)

        self.output = [tm3_output, tm1_output]

        return self.output
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class Lobula(BaseCore):
    """Lobula layer of the ESTMD backbone.

    Recombines the ON and OFF channels as ``a*ON + b*OFF + c*ON*OFF`` and
    passes the result through spatial (surround) inhibition.
    """

    def __init__(self):
        """Create the spatial inhibition stage and the recombination weights."""
        super().__init__()
        self.spatial_inhibition = SpatialInhibition()
        # With a = b = 0 and c = 1, only the multiplicative ON*OFF term remains.
        self.a, self.b, self.c = 0, 0, 1

    def setup(self):
        """Initialise the spatial inhibition stage."""
        self.spatial_inhibition.setup()

    def forward(self, on_output, off_output):
        """Correlate ON/OFF channels and apply surround inhibition.

        Returns:
            Tuple ``(inhibited, correlation)``. The inhibited result is also
            stored on ``self.output``.
        """
        linear_terms = self.a * on_output + self.b * off_output
        correlation = linear_terms + self.c * on_output * off_output

        inhibited = self.spatial_inhibition.forward(correlation)

        self.output = inhibited
        return inhibited, correlation
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
class Tm2(BaseCore):
    """Tm2 without spatial inhibition."""

    def forward(self, x):
        """Return the half-wave rectified OFF channel ``max(-x, 0)``.

        No surround inhibition is applied here. The result is cached on
        ``self.output``.
        """
        rectified_off = (-x).clamp(min=0)
        self.output = rectified_off
        return rectified_off
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
class Tm3(BaseCore):
    """Tm3 without spatial inhibition."""

    def forward(self, tm3outputIpt):
        """Return the half-wave rectified ON channel ``max(tm3outputIpt, 0)``.

        No surround inhibition is applied here. The result is cached on
        ``self.output``.
        """
        rectified_on = tm3outputIpt.clamp(min=0)
        self.output = rectified_on
        return rectified_on
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
|
xttmp/core/estmd_core.py
ADDED
|
@@ -0,0 +1,356 @@
|
|
|
1
|
+
from collections import deque
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import torch.nn.functional as F
|
|
5
|
+
|
|
6
|
+
from .base_core import BaseCore
|
|
7
|
+
from .math_operator import (compute_temporal_conv_inplace,
|
|
8
|
+
GaussianBlur, SpatialInhibition,
|
|
9
|
+
GammaDelay, GammaBandPassFilter)
|
|
10
|
+
from ..util.create_kernel import create_2d_gaussian_kernel
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class Retina(GaussianBlur):
    """Retina layer of the ESTMD: an unmodified Gaussian blur."""
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class Lamina(BaseCore):
    """
    LAMINA Lamina layer

    This class implements the Lamina layer of the ESTMD: a Gamma band-pass
    filter followed by lamina lateral inhibition.

    Author: Mingshuo Xu
    Date: 2024-04-29
    """

    def __init__(self):
        """
        Constructor
        Creates the GammaBandPassFilter and LaminaLateralInhibition sub-modules.
        """
        super().__init__()
        self.gamma_BPF = GammaBandPassFilter()
        self.spatial_inhibition = LaminaLateralInhibition()

    def setup(self):
        """Initialise both sub-modules."""
        self.gamma_BPF.setup()
        self.spatial_inhibition.setup()

    def reset_buffer(self):
        """Clear all temporal state of the layer.

        Both sub-modules are stateful. LaminaLateralInhibition keeps a sliding
        window of past frames, which must also be cleared. Otherwise frames
        from a previous sequence leak into the temporal convolution of the
        next one.
        """
        self.gamma_BPF.reset_buffer()
        self.spatial_inhibition.reset_buffer()

    def forward(self, laminaIpt):
        """
        Processing method
        Applies the GammaBandPassFilter and then LaminaLateralInhibition to the input.

        Parameters:
            - laminaIpt: Input tensor

        Returns:
            - Processed output tensor (also stored on ``self.output``)
        """
        signalWithBPF = self.gamma_BPF.forward(laminaIpt)
        self.output = self.spatial_inhibition.forward(signalWithBPF)

        return self.output
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class Medulla(BaseCore):
    """
    Medulla layer of the ESTMD.

    Rectifies the input into ON (Tm3) and OFF (Tm2) channels and delays the
    OFF channel with a Gamma delay filter (Tm1).
    """

    def __init__(self):
        """Create the Tm1/Tm2/Tm3/Mi1 sub-modules."""
        super().__init__()
        self.tm1 = Tm1(order=12, tau=25)  # delay for the OFF channel
        self.tm2 = Tm2()                  # rectified OFF channel + spatial inhibition
        self.tm3 = Tm3()                  # rectified ON channel + spatial inhibition
        # NOTE: mi1 is created, set up and reset, but forward() does not use it.
        self.mi1 = Mi1(order=12, tau=25)

    def setup(self):
        """Initialise the Tm1, Tm2, Tm3 and Mi1 sub-modules."""
        for submodule in (self.tm1, self.tm2, self.tm3, self.mi1):
            submodule.setup()

    def reset_buffer(self):
        """Clear the temporal state of the Tm1 and Mi1 delay filters."""
        for delay in (self.tm1, self.mi1):
            delay.reset_buffer()

    def forward(self, x):
        """
        Process one frame.

        Parameters:
            - x: Input tensor

        Returns:
            - Tuple (tm3 output, tm1 output): the ON channel and the delayed
              OFF channel. The same pair is stored on ``self.output``.
        """
        off_channel = self.tm2.forward(x)
        on_channel = self.tm3.forward(x)

        delayed_off = self.tm1.forward(off_channel)

        self.output = (on_channel, delayed_off)
        return on_channel, delayed_off
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
class Lobula(BaseCore):
    """
    Lobula layer of the ESTMD.

    Recombines the ON and OFF signals as ``a*ON + b*OFF + c*ON*OFF``.
    """

    def __init__(self):
        """Set the recombination weights (only the ON*OFF term by default)."""
        super().__init__()
        self.a, self.b, self.c = 0, 0, 1

    def setup(self):
        """No initialisation is needed."""
        pass

    def forward(self, onSignal, offSignal):
        """
        Combine the ON and OFF signals.

        Parameters:
            - onSignal: ON-channel tensor
            - offSignal: OFF-channel tensor

        Returns:
            - Lobula output (also stored on ``self.output``)
        """
        weighted_sum = self.a * onSignal + self.b * offSignal
        self.output = weighted_sum + self.c * onSignal * offSignal
        return self.output
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
class Mi1(GammaDelay):
    """Mi1 medulla neuron: an unmodified Gamma delay filter."""
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
class Tm1(GammaDelay):
    """Tm1 medulla neuron: an unmodified Gamma delay filter."""
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
class Tm2(BaseCore):
    """
    Tm2 medulla neuron of the ESTMD: rectified OFF channel followed by
    spatial inhibition.
    """

    def __init__(self, device='cpu'):
        """
        Create the Tm2 neuron.

        Parameters:
            - device: accepted for interface compatibility; it is not used.
        """
        super().__init__()
        self.spatial_inhibition = SpatialInhibition()

    def setup(self):
        """Initialise the spatial inhibition stage."""
        self.spatial_inhibition.setup()

    def forward(self, x):
        """
        Rectify the OFF part of the input and apply spatial inhibition.

        Parameters:
            - x: Input tensor

        Returns:
            - Tm2 output (also stored on ``self.output``)
        """
        off_part = (-x).clamp(min=0)
        inhibited = self.spatial_inhibition.forward(off_part)
        self.output = inhibited
        return inhibited
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
class Tm3(BaseCore):
    """
    Tm3 medulla neuron of the ESTMD: rectified ON channel followed by
    spatial inhibition.
    """

    def __init__(self, device='cpu'):
        """
        Create the Tm3 neuron.

        Parameters:
            - device: accepted for interface compatibility; it is not used.
        """
        super().__init__()
        self.spatial_inhibition = SpatialInhibition()

    def setup(self):
        """Initialise the spatial inhibition stage."""
        self.spatial_inhibition.setup()

    def forward(self, iptMatrix):
        """
        Rectify the ON part of the input and apply spatial inhibition.

        Parameters:
            - iptMatrix: Input tensor

        Returns:
            - Tm3 output (also stored on ``self.output``)
        """
        on_part = iptMatrix.clamp(min=0)
        inhibited = self.spatial_inhibition.forward(on_part)
        self.output = inhibited
        return inhibited
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
class LaminaLateralInhibition(BaseCore):
    """ LAMINALATERALINHIBITION Lateral inhibition in the Lamina layer

    This class implements the lateral inhibition mechanism in the Lamina layer
    of the ESTMD using pure PyTorch operations: a spatial Difference-of-Gaussian
    split into positive and negative parts, each followed by an exponentially
    decaying temporal kernel over a sliding window of past frames.

    References:
        * S. D. Wiederman, P. A. Shoemaker, D. C. O'Carroll, A model
          for the detection of moving targets in visual clutter inspired by
          insect physiology, PLoS ONE 3 (7) (2008) e2784.
        * Wang H, Peng J, Yue S. A directionally selective small target
          motion detecting visual neural network in cluttered backgrounds[J].
          IEEE transactions on cybernetics, 2018, 50(4): 1541-1555.
    """

    def __init__(self,
                 sizeW1=(11, 11, 7),
                 lambda1=3.0,
                 lambda2=9.0,
                 sigma1=1.5,
                 sigma2=None):
        """
        Constructor

        Parameters:
            - sizeW1: (kernel height, kernel width, T). T is the number of
              frames kept in the temporal window.
            - lambda1: time constant of the positive temporal kernel.
            - lambda2: time constant of the negative temporal kernel.
            - sigma1: std-dev of the narrow Gaussian of the DoG.
            - sigma2: std-dev of the wide Gaussian; defaults to 2 * sigma1.
        """
        super().__init__()
        # Copy into a fresh list so the instance never aliases a caller's
        # list (or a shared default).
        self.sizeW1 = list(sizeW1)
        self.lambda1 = lambda1
        self.lambda2 = lambda2
        self.sigma1 = sigma1
        self.sigma2 = 2.0 * sigma1 if sigma2 is None else sigma2
        self.T = self.sizeW1[2]  # temporal window length (frames)

        self.register_buffer('spatial_pos_kernel', torch.empty(0))
        self.register_buffer('spatial_neg_kernel', torch.empty(0))
        self.register_buffer('temporal_pos_kernel', torch.empty(0))
        self.register_buffer('temporal_neg_kernel', torch.empty(0))

        self.setup()  # build kernels and frame buffers

    def _create_spatial_kernels(self):
        """Build the positive and negative parts of the DoG spatial kernel.

        Returns:
            (pos_kernel, neg_kernel), each shaped (1, 1, H, W) for F.conv2d.
        """
        g_sigma2 = create_2d_gaussian_kernel(self.sizeW1[:2], self.sigma1)
        g_sigma3 = create_2d_gaussian_kernel(self.sizeW1[:2], self.sigma2)
        diff_of_gaussian = g_sigma2 - g_sigma3

        # W_S^P keeps the positive lobe, W_S^N the negative lobe.
        pos_kernel = torch.clamp(diff_of_gaussian, min=0)
        neg_kernel = torch.clamp(diff_of_gaussian, max=0)

        # (out_channels=1, in_channels=1, H, W) as required by F.conv2d.
        pos_kernel = pos_kernel.view(1, 1, *self.sizeW1[:2])
        neg_kernel = neg_kernel.view(1, 1, *self.sizeW1[:2])

        return pos_kernel, neg_kernel

    def _create_temporal_kernels(self):
        """Build the exponentially decaying temporal weights.

        Returns:
            (w_t_pos, w_t_neg), each of shape (T, 1, 1, 1), where
            w[t] = exp(-t / lambda) / lambda.
        """
        t = torch.arange(self.T, dtype=torch.float32)

        # W_T^P and W_T^N
        w_t_pos = torch.exp(-t / self.lambda1) / self.lambda1
        w_t_neg = torch.exp(-t / self.lambda2) / self.lambda2

        # How these broadcast against the frame history is defined by
        # compute_temporal_conv_inplace (not shown here).
        return w_t_pos.view(-1, 1, 1, 1), w_t_neg.view(-1, 1, 1, 1)

    def setup(self):
        """(Re)build all kernels and empty the frame buffers.

        Each new kernel is moved to the device its buffer currently lives on.
        Otherwise calling setup() after ``module.to(device)`` would put CPU
        kernels back into a GPU module.
        """
        # 1. Spatial kernels.
        spatial_pos, spatial_neg = self._create_spatial_kernels()
        self.spatial_pos_kernel.data = spatial_pos.to(self.spatial_pos_kernel.device)
        self.spatial_neg_kernel.data = spatial_neg.to(self.spatial_neg_kernel.device)

        # 2. Temporal kernels.
        temporal_pos, temporal_neg = self._create_temporal_kernels()
        self.temporal_pos_kernel.data = temporal_pos.to(self.temporal_pos_kernel.device)
        self.temporal_neg_kernel.data = temporal_neg.to(self.temporal_neg_kernel.device)

        # 3. Sliding windows of the last T spatially filtered frames.
        self.pos_buffer = deque(maxlen=self.T)
        self.neg_buffer = deque(maxlen=self.T)

    def reset_buffer(self):
        """Drop all frames held in the temporal windows."""
        self.pos_buffer.clear()
        self.neg_buffer.clear()

    def forward(self, x):
        """
        Apply spatio-temporal lateral inhibition to one frame.

        Parameters:
            - x: tensor of shape (B, 1, H, W), or unbatched (1, H, W).
              F.conv2d with the (1, 1, kH, kW) kernels requires a single
              input channel.

        Returns:
            - Sum of the positive and negative temporal responses (also
              stored on ``self.output``).
        """
        # 1. Spatial lateral inhibition.
        on_conv = F.conv2d(x, self.spatial_pos_kernel, padding='same')
        off_conv = F.conv2d(x, self.spatial_neg_kernel, padding='same')

        # Append the current frame; the newest frame is at the right end.
        self.pos_buffer.append(on_conv)
        self.neg_buffer.append(off_conv)

        # 2. Temporal convolution over the stored frames.
        pos_out = compute_temporal_conv_inplace(self.pos_buffer, self.temporal_pos_kernel)
        neg_out = compute_temporal_conv_inplace(self.neg_buffer, self.temporal_neg_kernel)

        self.output = pos_out + neg_out
        return self.output
|
|
350
|
+
|
|
351
|
+
|
|
352
|
+
|
|
353
|
+
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
|