xttmp 2.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. xttmp/__init__.py +1 -0
  2. xttmp/api/__init__.py +5 -0
  3. xttmp/api/evaluate.py +163 -0
  4. xttmp/api/get_visualize_handle.py +29 -0
  5. xttmp/api/instancing_model.py +35 -0
  6. xttmp/core/__init__.py +0 -0
  7. xttmp/core/apgstmd_core.py +188 -0
  8. xttmp/core/apgstmdv2_core.py +79 -0
  9. xttmp/core/base_core.py +36 -0
  10. xttmp/core/dstmd_core.py +213 -0
  11. xttmp/core/estmd_backbone.py +110 -0
  12. xttmp/core/estmd_core.py +356 -0
  13. xttmp/core/feedbackstmd_core.py +61 -0
  14. xttmp/core/fracstmd_core.py +98 -0
  15. xttmp/core/fstmd_core.py +15 -0
  16. xttmp/core/fstmdv2_core.py +42 -0
  17. xttmp/core/haarstmd_core.py +140 -0
  18. xttmp/core/math_operator.py +307 -0
  19. xttmp/core/stfeedbackstmd_core.py +233 -0
  20. xttmp/core/stmdplus_core.py +187 -0
  21. xttmp/core/stmdplusv2_core.py +82 -0
  22. xttmp/core/vstmd_core.py +420 -0
  23. xttmp/demo/evaluate_model.py +92 -0
  24. xttmp/demo/inference_gui.py +148 -0
  25. xttmp/demo/inference_gui_single_process.py +134 -0
  26. xttmp/demo/inference_image_stream.py +67 -0
  27. xttmp/demo/inference_video.py +66 -0
  28. xttmp/main.py +14 -0
  29. xttmp/model/__init__.py +13 -0
  30. xttmp/model/backbone.py +514 -0
  31. xttmp/model/facilitated_model.py +230 -0
  32. xttmp/model/feedback_model.py +271 -0
  33. xttmp/model/haarstmd.py +61 -0
  34. xttmp/model/vstmd.py +457 -0
  35. xttmp/util/__init__.py +0 -0
  36. xttmp/util/compute_module.py +402 -0
  37. xttmp/util/create_kernel.py +363 -0
  38. xttmp/util/evaluate_module.py +697 -0
  39. xttmp/util/iostream.py +660 -0
  40. xttmp-2.3.0.dist-info/METADATA +85 -0
  41. xttmp-2.3.0.dist-info/RECORD +45 -0
  42. xttmp-2.3.0.dist-info/WHEEL +5 -0
  43. xttmp-2.3.0.dist-info/entry_points.txt +2 -0
  44. xttmp-2.3.0.dist-info/licenses/LICENSE +201 -0
  45. xttmp-2.3.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,213 @@
1
+ from collections import deque
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ from .base_core import BaseCore
7
+ from .estmd_core import Tm1, Mi1
8
+ from .math_operator import GammaDelay, SpatialInhibition
9
+ from ..util.create_kernel import create_direction_inhi_kernel
10
+
11
+
12
class Medulla(BaseCore):
    """Medulla layer of DSTMD.

    Splits the input into an ON channel (Tm3) and an OFF channel (Tm2) and
    produces gamma-delayed copies of them: one Mi1 delay on the ON channel
    and two Tm1 delays (with different parameters) on the OFF channel. The
    four signals feed the DSTMD lobula correlation.
    """

    def __init__(self):
        """Create the medulla neurons."""
        super().__init__()

        # ON pathway and its delayed copy.
        self.tm3 = Tm3()
        # Positional GammaDelay arguments; presumably (order, tau) as in
        # estmd_core's Tm1(order=..., tau=...) -- verify against GammaDelay.
        self.mi1_para4 = Mi1(3, 15)

        # OFF pathway and its two delayed copies.
        self.tm2 = Tm2()
        self.tm1_para5 = Tm1(5, 25)
        self.tm1_para6 = Tm1(8, 40)

    def setup(self):
        """Initialise the gamma-delay units.

        Tm2/Tm3 in this module only define forward(), so only the delay
        units are set up here.
        """
        self.mi1_para4.setup()
        self.tm1_para5.setup()
        self.tm1_para6.setup()

    def reset_buffer(self):
        """Clear the temporal state of the gamma-delay units.

        Without this, the delay history of one sequence would carry over
        into the next (mirrors estmd_core.Medulla.reset_buffer).
        """
        self.mi1_para4.reset_buffer()
        self.tm1_para5.reset_buffer()
        self.tm1_para6.reset_buffer()

    def forward(self, x):
        """Compute the medulla signals for one frame.

        Args:
            x: Lamina output.

        Returns:
            Tuple ``(tm3, mi1_para4, tm1_para5, tm1_para6)``:
            the ON signal, the delayed ON signal, and two delayed OFF signals.
            The tuple is also stored in ``self.output``.
        """
        tm3_output = self.tm3.forward(x)  # ON channel: max(x, 0)
        tm2_output = self.tm2.forward(x)  # OFF channel: max(-x, 0)

        mi1_para4_output = self.mi1_para4.forward(tm3_output)
        tm1_para5_output = self.tm1_para5.forward(tm2_output)
        tm1_para6_output = self.tm1_para6.forward(tm2_output)

        self.output = (tm3_output, mi1_para4_output, tm1_para5_output, tm1_para6_output)
        return self.output
52
+
53
+
54
class Tm2(BaseCore):
    """OFF-channel neuron of DSTMD (no spatial inhibition).

    Keeps only the negative part of the input, sign-flipped, so that
    decreases become non-negative responses.
    """

    def forward(self, iptMatrix):
        """Return ``max(-iptMatrix, 0)`` and cache it in ``self.output``."""
        off_signal = (-iptMatrix).clamp(min=0)
        self.output = off_signal
        return off_signal
63
+
64
+
65
class Tm3(BaseCore):
    """ON-channel neuron of DSTMD (no spatial inhibition).

    Keeps only the non-negative part of the input.
    """

    def forward(self, iptMatrix):
        """Return ``max(iptMatrix, 0)`` and cache it in ``self.output``."""
        on_signal = iptMatrix.clamp(min=0)
        self.output = on_signal
        return on_signal
75
+
76
+
77
class Lobula(BaseCore):
    """Lobula layer of DSTMD (direction-selective correlation).

    For each preferred direction theta in ``theta_list``, the ON signal at
    pixel (y, x) is combined with delayed signals sampled at the offset
    pixel (y + round(a1*sin(theta)), x - round(a1*cos(theta))):

        corr = tm3 * (tm1_p5 + mi1_p4[offset]) * tm1_p6[offset]

    The per-direction maps are then passed through lateral (spatial)
    inhibition followed by directional inhibition.
    """

    def __init__(self):
        """Create parameters, the direction buffer and inhibition sub-modules."""
        super().__init__()
        self.alpha1 = 3  # sampling distance (pixels) to the offset location

        # Eight preferred directions: 0, pi/4, ..., 7*pi/4.
        self.register_buffer('theta_list', torch.tensor([(i * torch.pi / 4) for i in range(8)]))
        self.hLateralInhi = SpatialInhibition()  # lateral inhibition component
        self.hDirectionInhi = DirectionInhibition()  # directional inhibition component

    def setup(self):
        """Initialise the lateral and directional inhibition components."""
        self.hLateralInhi.setup()
        self.hDirectionInhi.setup()

    def forward(self, tm3, mi1_p4, tm1_p5, tm1_p6):
        """Compute direction-selective lobula responses.

        Args:
            tm3: ON signal, shape [B, 1, H, W].
            mi1_p4: Delayed ON signal, shape [B, 1, H, W].
            tm1_p5: Delayed OFF signal, shape [B, 1, H, W].
            tm1_p6: Delayed OFF signal, shape [B, 1, H, W].

        Returns:
            Output of directional inhibition applied to the laterally
            inhibited correlation maps of shape [B, len(theta_list), H, W].
            Pixels closer than ``alpha1`` to the border are zero before
            inhibition. The result is also stored in ``self.output``.
        """
        batch, _, img_h, img_w = tm3.shape
        num_thetas = len(self.theta_list)
        a1 = self.alpha1

        # Integer pixel offsets per direction: [num_thetas].
        shifts_x = torch.round(a1 * torch.cos(self.theta_list)).long()
        shifts_y = torch.round(a1 * torch.sin(self.theta_list)).long()

        # Central ROI where every offset stays inside the image.
        y_s, y_e = a1, img_h - a1
        x_s, x_e = a1, img_w - a1

        # Unshifted signals restricted to the ROI: [B, 1, h_roi, w_roi].
        tm3_roi = tm3[:, :, y_s:y_e, x_s:x_e]
        tm1_p5_roi = tm1_p5[:, :, y_s:y_e, x_s:x_e]

        grid_y, grid_x = torch.meshgrid(
            torch.arange(y_s, y_e, device=tm3.device),
            torch.arange(x_s, x_e, device=tm3.device),
            indexing='ij',
        )

        # Source coordinates for every direction: [num_thetas, h_roi, w_roi].
        src_idx_x = grid_x.unsqueeze(0) - shifts_x.view(-1, 1, 1)
        src_idx_y = grid_y.unsqueeze(0) + shifts_y.view(-1, 1, 1)

        # Gather the shifted samples for every batch item. Indexing the last
        # two dims of a [B, H, W] view gives [B, num_thetas, h_roi, w_roi].
        # Channel 0 is used: the inputs are expected to be single-channel.
        mi1_p4_shifted = mi1_p4[:, 0][:, src_idx_y, src_idx_x]
        tm1_p6_shifted = tm1_p6[:, 0][:, src_idx_y, src_idx_x]

        # Broadcast [B, 1, h, w] against [B, num_thetas, h, w].
        corre_roi = tm3_roi * (tm1_p5_roi + mi1_p4_shifted) * tm1_p6_shifted

        # Allocate with tm3's dtype/device (not a hard-coded float32) so the
        # layer keeps working after .double()/.half(), and size the batch
        # dimension from the input so B > 1 no longer fails on assignment.
        correOutput = tm3.new_zeros((batch, num_thetas, img_h, img_w))
        correOutput[:, :, y_s:y_e, x_s:x_e] = corre_roi

        lateralInhioutput = self.hLateralInhi.forward(correOutput)
        self.output = self.hDirectionInhi.forward(lateralInhioutput)
        return self.output
147
+
148
+
149
class DirectionInhibition(BaseCore):
    """Directional inhibition in DSTMD.

    At every pixel, the vector of per-direction responses is circularly
    padded and filtered with a 1-D kernel from
    ``create_direction_inhi_kernel`` (via ``F.conv1d``), then rectified.
    """

    def __init__(self):
        """Set the kernel parameters and build the kernel."""
        super().__init__()
        self.direction = 8  # number of directions
        self.sigma1 = 1.5   # first Gaussian sigma passed to the kernel factory
        self.sigma2 = 3.0   # second Gaussian sigma passed to the kernel factory

        # Registered as a buffer so it follows .to(device); filled by setup().
        self.register_buffer("diretional_inhi_kernel", torch.empty(0))
        self.setup()

    def setup(self):
        """(Re)build the directional inhibition kernel, shape [1, 1, kernel_size]."""
        kernel = create_direction_inhi_kernel(self.direction, self.sigma1, self.sigma2)
        self.diretional_inhi_kernel.data = kernel

    def forward(self, x):
        """Apply directional inhibition.

        Args:
            x: Tensor of shape [B, C, H, W], C being the direction axis.

        Returns:
            Rectified tensor of shape [B, C, H, W].
        """
        batch, channels, height, width = x.shape
        kernel = self.diretional_inhi_kernel

        # Split the kernel length around its centre tap for circular padding.
        kernel_len = kernel.shape[-1]
        left = kernel_len // 2
        right = kernel_len - left - 1

        # One length-C sequence per pixel: [B, C, H, W] -> [B*H*W, 1, C].
        per_pixel = x.permute(0, 2, 3, 1).reshape(batch * height * width, 1, channels)

        # Wrap around the direction axis, then slide the kernel along it.
        wrapped = F.pad(per_pixel, (left, right), mode='circular')
        inhibited = F.relu(F.conv1d(wrapped, kernel))

        # Back to [B, C, H, W].
        return inhibited.reshape(batch, height, width, channels).permute(0, 3, 1, 2)
209
+
210
+
211
+
212
+
213
+
@@ -0,0 +1,110 @@
1
+ import torch
2
+
3
+ from .base_core import BaseCore
4
+ from .math_operator import GammaBandPassFilter, SpatialInhibition
5
+ from . import estmd_core
6
+
7
+
8
class Lamina(GammaBandPassFilter):
    """Lamina layer of the backbone: a GammaBandPassFilter with no additions."""
10
+
11
+
12
class Medulla(BaseCore):
    """Medulla layer of the motion detection backbone.

    Produces the ON signal (Tm3) and a delayed OFF signal (Tm1 applied to
    Tm2's output). An Mi1 delay unit is also created and kept in step with
    setup/reset, but forward() does not use it.
    """

    def __init__(self):
        """Create the medulla neurons."""
        super().__init__()
        self.tm1 = estmd_core.Tm1()
        self.mi1 = estmd_core.Mi1()
        self.tm2 = Tm2()
        self.tm3 = Tm3()

    def setup(self):
        """Initialise all medulla components."""
        self.tm1.setup()
        self.mi1.setup()
        self.tm2.setup()
        self.tm3.setup()

    def reset_buffer(self):
        """Clear the temporal state of the gamma-delay units.

        Without this, Tm1/Mi1 delay history carries over between sequences
        (mirrors estmd_core.Medulla.reset_buffer).
        """
        self.tm1.reset_buffer()
        self.mi1.reset_buffer()

    def forward(self, MedullaIpt):
        """Compute the medulla signals for one frame.

        Args:
            MedullaIpt: Lamina output.

        Returns:
            List ``[tm3_output, tm1_output]`` (ON signal, delayed OFF
            signal), also stored in ``self.output``.
        """
        tm2_output = self.tm2.forward(MedullaIpt)  # OFF: max(-x, 0)
        tm3_output = self.tm3.forward(MedullaIpt)  # ON: max(x, 0)

        # Delay the OFF channel.
        tm1_output = self.tm1.forward(tm2_output)

        self.output = [tm3_output, tm1_output]
        return self.output
48
+
49
+
50
class Lobula(BaseCore):
    """Lobula layer of the backbone.

    Recombines ON and OFF channels as ``a*ON + b*OFF + c*ON*OFF`` and then
    applies spatial inhibition.
    """

    def __init__(self):
        """Create the spatial inhibition module and recombination weights."""
        super().__init__()
        self.spatial_inhibition = SpatialInhibition()
        # Recombination weights; the defaults keep only the ON*OFF product.
        self.a = 0
        self.b = 0
        self.c = 1

    def setup(self):
        """Initialise the spatial inhibition component."""
        self.spatial_inhibition.setup()

    def forward(self, on_output, off_output):
        """Correlate ON/OFF channels and inhibit the result spatially.

        Returns:
            Tuple ``(inhibited, correlation)``; ``inhibited`` is also stored
            in ``self.output``.
        """
        linear_terms = self.a * on_output + self.b * off_output
        correlation = linear_terms + self.c * on_output * off_output

        inhibited = self.spatial_inhibition.forward(correlation)

        self.output = inhibited
        return inhibited, correlation
87
+
88
+
89
class Tm2(BaseCore):
    """Tm2 without spatial inhibition: the rectified OFF channel."""

    def forward(self, x):
        """Return ``max(-x, 0)`` and cache it in ``self.output``."""
        off_signal = (-x).clamp(min=0)
        self.output = off_signal
        return off_signal
96
+
97
+
98
class Tm3(BaseCore):
    """Tm3 without spatial inhibition: the rectified ON channel."""

    def forward(self, tm3outputIpt):
        """Return ``max(tm3outputIpt, 0)`` and cache it in ``self.output``."""
        on_signal = tm3outputIpt.clamp(min=0)
        self.output = on_signal
        return on_signal
105
+
106
+
107
+
108
+
109
+
110
+
@@ -0,0 +1,356 @@
1
+ from collections import deque
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ from .base_core import BaseCore
7
+ from .math_operator import (compute_temporal_conv_inplace,
8
+ GaussianBlur, SpatialInhibition,
9
+ GammaDelay, GammaBandPassFilter)
10
+ from ..util.create_kernel import create_2d_gaussian_kernel
11
+
12
+
13
class Retina(GaussianBlur):
    """Retina layer of the ESTMD: a GaussianBlur with no additions."""
15
+
16
+
17
class Lamina(BaseCore):
    """Lamina layer of the ESTMD.

    Applies a GammaBandPassFilter to the input and then
    LaminaLateralInhibition to the band-passed signal.

    Author: Mingshuo Xu
    Date: 2024-04-29
    """

    def __init__(self):
        """Create the band-pass filter and the lateral inhibition module."""
        super().__init__()
        self.gamma_BPF = GammaBandPassFilter()
        self.spatial_inhibition = LaminaLateralInhibition()

    def setup(self):
        """Initialise both sub-modules."""
        self.gamma_BPF.setup()
        self.spatial_inhibition.setup()

    def reset_buffer(self):
        """Clear the temporal state of both sub-modules.

        LaminaLateralInhibition keeps its own frame history (pos/neg deques)
        with its own reset_buffer(). It must be cleared together with the
        band-pass filter, otherwise frames from a previous sequence leak into
        the next one.
        """
        self.gamma_BPF.reset_buffer()
        self.spatial_inhibition.reset_buffer()

    def forward(self, laminaIpt):
        """Band-pass filter the input, then apply lateral inhibition.

        Args:
            laminaIpt: Input matrix (retina output).

        Returns:
            Lamina output, also stored in ``self.output``.
        """
        signalWithBPF = self.gamma_BPF.forward(laminaIpt)
        self.output = self.spatial_inhibition.forward(signalWithBPF)
        return self.output
58
+
59
+
60
class Medulla(BaseCore):
    """Medulla layer of the ESTMD.

    Splits the input into ON (Tm3) and OFF (Tm2) channels and delays the
    OFF channel with Tm1. An Mi1 delay unit is also created and kept in
    step with setup/reset, but forward() does not use it.
    """

    def __init__(self):
        """Create the Tm1, Tm2, Tm3 and Mi1 neurons."""
        super().__init__()
        self.tm1 = Tm1(order=12, tau=25)
        self.tm2 = Tm2()
        self.tm3 = Tm3()
        self.mi1 = Mi1(order=12, tau=25)

    def setup(self):
        """Initialise Tm1, Tm2, Tm3 and Mi1 (in that order)."""
        for neuron in (self.tm1, self.tm2, self.tm3, self.mi1):
            neuron.setup()

    def reset_buffer(self):
        """Clear the temporal state of the gamma-delay units Tm1 and Mi1."""
        for delay_unit in (self.tm1, self.mi1):
            delay_unit.reset_buffer()

    def forward(self, x):
        """Compute the medulla signals for one frame.

        Args:
            x: Lamina output.

        Returns:
            Tuple ``(tm3_output, tm1_output)``: the ON signal and the delayed
            OFF signal. The same pair is stored in ``self.output``.
        """
        off_signal = self.tm2.forward(x)
        on_signal = self.tm3.forward(x)
        delayed_off = self.tm1.forward(off_signal)

        self.output = (on_signal, delayed_off)
        return on_signal, delayed_off
114
+
115
+
116
class Lobula(BaseCore):
    """Lobula layer of the ESTMD.

    Recombines the ON and OFF signals as ``a*ON + b*OFF + c*ON*OFF``.
    """

    def __init__(self):
        """Set the recombination weights (defaults keep only ON*OFF)."""
        super().__init__()
        self.a = 0
        self.b = 0
        self.c = 1

    def setup(self):
        """No state to initialise for this layer."""

    def forward(self, onSignal, offSignal):
        """Return ``a*onSignal + b*offSignal + c*onSignal*offSignal``.

        The result is also stored in ``self.output``.
        """
        linear_terms = self.a * onSignal + self.b * offSignal
        self.output = linear_terms + self.c * onSignal * offSignal
        return self.output
154
+
155
+
156
class Mi1(GammaDelay):
    """Mi1 medulla neuron: a GammaDelay with no additions."""
158
+
159
+
160
class Tm1(GammaDelay):
    """Tm1 medulla neuron: a GammaDelay with no additions."""
162
+
163
+
164
class Tm2(BaseCore):
    """Tm2 medulla neuron of the ESTMD: OFF channel with spatial inhibition."""

    def __init__(self, device='cpu'):
        """Create the spatial inhibition module.

        ``device`` is accepted for API compatibility and is not used.
        """
        super().__init__()
        self.spatial_inhibition = SpatialInhibition()

    def setup(self):
        """Initialise the spatial inhibition module."""
        self.spatial_inhibition.setup()

    def forward(self, x):
        """Rectify the OFF part of ``x`` and apply spatial inhibition.

        Returns:
            ``spatial_inhibition(max(-x, 0))``, also stored in ``self.output``.
        """
        off_signal = (-x).clamp(min=0)
        self.output = self.spatial_inhibition.forward(off_signal)
        return self.output
202
+
203
+
204
class Tm3(BaseCore):
    """Tm3 medulla neuron of the ESTMD: ON channel with spatial inhibition."""

    def __init__(self, device='cpu'):
        """Create the spatial inhibition module.

        ``device`` is accepted for API compatibility and is not used.
        """
        super().__init__()
        self.spatial_inhibition = SpatialInhibition()

    def setup(self):
        """Initialise the spatial inhibition module."""
        self.spatial_inhibition.setup()

    def forward(self, iptMatrix):
        """Rectify the ON part of ``iptMatrix`` and apply spatial inhibition.

        Returns:
            ``spatial_inhibition(max(iptMatrix, 0))``, also stored in
            ``self.output``.
        """
        on_signal = iptMatrix.clamp(min=0)
        self.output = self.spatial_inhibition.forward(on_signal)
        return self.output
240
+
241
+
242
class LaminaLateralInhibition(BaseCore):
    """ LAMINALATERALINHIBITION Lateral inhibition in the Lamina layer

    This class implements the lateral inhibition mechanism in the Lamina layer
    of the ESTMD using pure PyTorch operations. Each frame is filtered with the
    positive and the negative part of a difference-of-Gaussians kernel. Each
    filtered stream is kept in a sliding window of the last ``T`` frames and
    weighted with an exponentially decaying temporal kernel. The two weighted
    streams are summed.

    References:
        * S. D. Wiederman, P. A. Shoemarker, D. C. O'Carroll, A model
          for the detection of moving targets in visual clutter inspired by
          insect physiology, PLoS ONE 3 (7) (2008) e2784.
        * Wang H, Peng J, Yue S. A directionally selective small target
          motion detecting visual neural network in cluttered backgrounds[J].
          IEEE transactions on cybernetics, 2018, 50(4): 1541-1555.
    """

    def __init__(self,
                 sizeW1=(11, 11, 7),
                 lambda1=3.0,
                 lambda2=9.0,
                 sigma1=1.5,
                 sigma2=None):
        """Store the parameters and build the kernels and frame buffers.

        Args:
            sizeW1: (kernel height, kernel width, temporal length T). It is
                copied into a new list, so the instance never aliases the
                caller's object or a shared default.
            lambda1: Time constant of the positive temporal kernel.
            lambda2: Time constant of the negative temporal kernel.
            sigma1: Sigma of the first Gaussian of the DoG.
            sigma2: Sigma of the subtracted Gaussian; defaults to 2 * sigma1.
        """
        super().__init__()
        self.sizeW1 = list(sizeW1)
        self.lambda1 = lambda1
        self.lambda2 = lambda2
        self.sigma1 = sigma1
        self.sigma2 = 2.0 * sigma1 if sigma2 is None else sigma2
        self.T = self.sizeW1[2]  # temporal window length

        # Placeholders; filled by setup(). Buffers follow .to(device)/.cuda().
        self.register_buffer('spatial_pos_kernel', torch.empty(0))
        self.register_buffer('spatial_neg_kernel', torch.empty(0))
        self.register_buffer('temporal_pos_kernel', torch.empty(0))
        self.register_buffer('temporal_neg_kernel', torch.empty(0))

        self.setup()  # initialise kernels and buffers

    def _create_spatial_kernels(self):
        """Build the positive and negative parts of the DoG spatial kernel.

        Returns:
            ``(pos_kernel, neg_kernel)``, each of shape (1, 1, kH, kW) as
            required by ``F.conv2d``.
        """
        g_sigma2 = create_2d_gaussian_kernel(self.sizeW1[:2], self.sigma1)
        g_sigma3 = create_2d_gaussian_kernel(self.sizeW1[:2], self.sigma2)
        diff_of_gaussian = g_sigma2 - g_sigma3

        # W_S^P (positive lobe) and W_S^N (negative lobe).
        pos_kernel = torch.clamp(diff_of_gaussian, min=0)
        neg_kernel = torch.clamp(diff_of_gaussian, max=0)

        # (out_channels=1, in_channels=1, H, W) for F.conv2d.
        pos_kernel = pos_kernel.view(1, 1, *self.sizeW1[:2])
        neg_kernel = neg_kernel.view(1, 1, *self.sizeW1[:2])

        return pos_kernel, neg_kernel

    def _create_temporal_kernels(self):
        """Build the exponentially decaying temporal weights.

        Returns:
            ``(w_t_pos, w_t_neg)`` with ``w(t) = exp(-t / lambda) / lambda`` for
            ``t = 0 .. T-1``, each of shape (T, 1, 1, 1).
        """
        t = torch.arange(self.T, dtype=torch.float32)

        # W_T^P and W_T^N.
        w_t_pos = torch.exp(-t / self.lambda1) / self.lambda1
        w_t_neg = torch.exp(-t / self.lambda2) / self.lambda2

        return w_t_pos.view(-1, 1, 1, 1), w_t_neg.view(-1, 1, 1, 1)

    def setup(self):
        """(Re)build all kernels and reset the frame buffers."""
        # 1. Spatial kernels (stored in the registered buffers).
        spatial_pos, spatial_neg = self._create_spatial_kernels()
        self.spatial_pos_kernel.data = spatial_pos
        self.spatial_neg_kernel.data = spatial_neg

        # 2. Temporal kernels.
        temporal_pos, temporal_neg = self._create_temporal_kernels()
        self.temporal_pos_kernel.data = temporal_pos
        self.temporal_neg_kernel.data = temporal_neg

        # 3. Sliding windows of the last T filtered frames.
        self.pos_buffer = deque(maxlen=self.T)
        self.neg_buffer = deque(maxlen=self.T)

    def reset_buffer(self):
        """Drop all stored frames from both temporal windows."""
        self.pos_buffer.clear()
        self.neg_buffer.clear()

    def forward(self, x):
        """Apply spatial DoG filtering, then temporal weighting over recent frames.

        Args:
            x: Tensor of shape (B, 1, H, W) or unbatched (1, H, W). The
                spatial kernels are (1, 1, kH, kW), so ``F.conv2d`` needs
                exactly one input channel; a bare (H, W) tensor is rejected
                by ``F.conv2d``.

        Returns:
            Sum of the positive and negative streams, each weighted over the
            stored frames by ``compute_temporal_conv_inplace``.
        """
        # 1. Spatial lateral inhibition.
        on_conv = F.conv2d(x, self.spatial_pos_kernel, padding='same')
        off_conv = F.conv2d(x, self.spatial_neg_kernel, padding='same')

        # The newest frame is appended on the right of each window.
        self.pos_buffer.append(on_conv)
        self.neg_buffer.append(off_conv)

        # 2. Temporal weighting over the stored frames.
        pos_out = compute_temporal_conv_inplace(self.pos_buffer, self.temporal_pos_kernel)
        neg_out = compute_temporal_conv_inplace(self.neg_buffer, self.temporal_neg_kernel)

        return pos_out + neg_out
350
+
351
+
352
+
353
+
354
+
355
+
356
+