xttmp 2.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xttmp/__init__.py +1 -0
- xttmp/api/__init__.py +5 -0
- xttmp/api/evaluate.py +163 -0
- xttmp/api/get_visualize_handle.py +29 -0
- xttmp/api/instancing_model.py +35 -0
- xttmp/core/__init__.py +0 -0
- xttmp/core/apgstmd_core.py +188 -0
- xttmp/core/apgstmdv2_core.py +79 -0
- xttmp/core/base_core.py +36 -0
- xttmp/core/dstmd_core.py +213 -0
- xttmp/core/estmd_backbone.py +110 -0
- xttmp/core/estmd_core.py +356 -0
- xttmp/core/feedbackstmd_core.py +61 -0
- xttmp/core/fracstmd_core.py +98 -0
- xttmp/core/fstmd_core.py +15 -0
- xttmp/core/fstmdv2_core.py +42 -0
- xttmp/core/haarstmd_core.py +140 -0
- xttmp/core/math_operator.py +307 -0
- xttmp/core/stfeedbackstmd_core.py +233 -0
- xttmp/core/stmdplus_core.py +187 -0
- xttmp/core/stmdplusv2_core.py +82 -0
- xttmp/core/vstmd_core.py +420 -0
- xttmp/demo/evaluate_model.py +92 -0
- xttmp/demo/inference_gui.py +148 -0
- xttmp/demo/inference_gui_single_process.py +134 -0
- xttmp/demo/inference_image_stream.py +67 -0
- xttmp/demo/inference_video.py +66 -0
- xttmp/main.py +14 -0
- xttmp/model/__init__.py +13 -0
- xttmp/model/backbone.py +514 -0
- xttmp/model/facilitated_model.py +230 -0
- xttmp/model/feedback_model.py +271 -0
- xttmp/model/haarstmd.py +61 -0
- xttmp/model/vstmd.py +457 -0
- xttmp/util/__init__.py +0 -0
- xttmp/util/compute_module.py +402 -0
- xttmp/util/create_kernel.py +363 -0
- xttmp/util/evaluate_module.py +697 -0
- xttmp/util/iostream.py +660 -0
- xttmp-2.3.0.dist-info/METADATA +85 -0
- xttmp-2.3.0.dist-info/RECORD +45 -0
- xttmp-2.3.0.dist-info/WHEEL +5 -0
- xttmp-2.3.0.dist-info/entry_points.txt +2 -0
- xttmp-2.3.0.dist-info/licenses/LICENSE +201 -0
- xttmp-2.3.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn.functional as F
|
|
3
|
+
|
|
4
|
+
from .base_core import BaseCore
|
|
5
|
+
from .math_operator import SpatialInhibition, GammaDelay
|
|
6
|
+
from ..util.create_kernel import create_2d_gaussian_kernel
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class Lobula(BaseCore):
|
|
10
|
+
""" Lobula layer of the motion detection system."""
|
|
11
|
+
|
|
12
|
+
def __init__(self):
|
|
13
|
+
"""Constructor method."""
|
|
14
|
+
# Initializes the Lobula object
|
|
15
|
+
super().__init__()
|
|
16
|
+
self.spatial_inhibition = SpatialInhibition() # SpatialInhibition component
|
|
17
|
+
self.alpha = 1 # Parameter alpha
|
|
18
|
+
self.sigma = 1.5 # Parameters for Gaussian kernel
|
|
19
|
+
|
|
20
|
+
self.gamma_delay = GammaDelay(10, 25) # GammaDelay component
|
|
21
|
+
|
|
22
|
+
self.register_buffer('gaussian_kernel', torch.empty(0)) # Buffer for Gaussian kernel
|
|
23
|
+
|
|
24
|
+
self.setup()
|
|
25
|
+
|
|
26
|
+
def setup(self):
|
|
27
|
+
""" Initialization method."""
|
|
28
|
+
# Initializes the Lobula layer component
|
|
29
|
+
self.spatial_inhibition.setup()
|
|
30
|
+
self.gamma_delay.setup()
|
|
31
|
+
|
|
32
|
+
self.gaussian_kernel.data = create_2d_gaussian_kernel(size=3, sigma=self.sigma)
|
|
33
|
+
|
|
34
|
+
self.reset_buffer()
|
|
35
|
+
|
|
36
|
+
def reset_buffer(self):
|
|
37
|
+
""" Resets the buffer of certain components. """
|
|
38
|
+
self.gamma_delay.reset_buffer()
|
|
39
|
+
|
|
40
|
+
def forward(self, medulla_ON, medulla_OFF):
|
|
41
|
+
""" Processing method. """
|
|
42
|
+
# Performs temporal convolution, correlation, and surround inhibition
|
|
43
|
+
|
|
44
|
+
# Formula (9)
|
|
45
|
+
feedback_output = self.alpha * self.gamma_delay.forward(torch.zeros_like(medulla_ON))
|
|
46
|
+
|
|
47
|
+
# Formula (8)
|
|
48
|
+
ON_with_feedback = torch.clamp(medulla_ON - feedback_output, min=0)
|
|
49
|
+
OFF_with_feedback = torch.clamp(medulla_OFF - feedback_output, min=0)
|
|
50
|
+
correlation_D = ON_with_feedback * OFF_with_feedback
|
|
51
|
+
|
|
52
|
+
# Formula (10)
|
|
53
|
+
correlation_E = F.conv2d(medulla_ON * medulla_OFF, self.gaussian_kernel, padding='same')
|
|
54
|
+
|
|
55
|
+
# Only record (correlationD + correlationE) for next delay in Formula (9)
|
|
56
|
+
self.gamma_delay.buffer[-1] = correlation_D + correlation_E
|
|
57
|
+
|
|
58
|
+
# Formula (14)
|
|
59
|
+
self.output = self.spatial_inhibition(correlation_D)
|
|
60
|
+
|
|
61
|
+
return self.output
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
import math
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from collections import deque
|
|
5
|
+
|
|
6
|
+
from .base_core import BaseCore
|
|
7
|
+
from ..util.create_kernel import create_fracdiff_kernel
|
|
8
|
+
from ..core.math_operator import compute_temporal_conv_inplace
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class Lamina(BaseCore):
|
|
12
|
+
"""
|
|
13
|
+
Lamina layer in ESTMD.
|
|
14
|
+
Pure PyTorch Implementation supporting both IIR (iteration) and FIR (convolution) modes.
|
|
15
|
+
"""
|
|
16
|
+
def __init__(self, alpha=0.8, delta=20, mode='iteration'):
|
|
17
|
+
"""
|
|
18
|
+
Constructor method.
|
|
19
|
+
|
|
20
|
+
Parameters:
|
|
21
|
+
- alpha: Fractional differential order (0 < alpha <= 1)
|
|
22
|
+
- delta: Length of the historical buffer for convolution
|
|
23
|
+
- mode: 'iteration' (IIR, fast) or 'conv' (FIR, accurate but slower)
|
|
24
|
+
"""
|
|
25
|
+
super().__init__()
|
|
26
|
+
self.alpha = alpha
|
|
27
|
+
self.delta = delta
|
|
28
|
+
self.mode = mode
|
|
29
|
+
|
|
30
|
+
self.register_buffer('frac_kernel', torch.empty(0))
|
|
31
|
+
|
|
32
|
+
self.setup()
|
|
33
|
+
|
|
34
|
+
def setup(self):
|
|
35
|
+
|
|
36
|
+
_kernel = create_fracdiff_kernel(self.alpha, self.delta)
|
|
37
|
+
self.frac_kernel.data = _kernel
|
|
38
|
+
|
|
39
|
+
# 2. 计算迭代模式 (IIR) 的系数
|
|
40
|
+
self.para_cur = _kernel[0].item()
|
|
41
|
+
|
|
42
|
+
if self.alpha == 1.0:
|
|
43
|
+
self.para_pre = 0.0
|
|
44
|
+
elif 0.0 < self.alpha < 1.0:
|
|
45
|
+
self.para_pre = math.exp(-self.alpha / (1.0 - self.alpha))
|
|
46
|
+
else:
|
|
47
|
+
raise ValueError("Invalid alpha value. Must be in (0, 1].")
|
|
48
|
+
|
|
49
|
+
self.reset_buffer()
|
|
50
|
+
|
|
51
|
+
def reset_buffer(self):
|
|
52
|
+
# 3. 初始化时序状态 (State)
|
|
53
|
+
self.state_ipt = None
|
|
54
|
+
self.state_opt = None
|
|
55
|
+
self.buffer = deque(maxlen=self.delta)
|
|
56
|
+
|
|
57
|
+
def forward(self, x):
|
|
58
|
+
"""
|
|
59
|
+
Processing method.
|
|
60
|
+
x shape: (B, C, H, W)
|
|
61
|
+
"""
|
|
62
|
+
|
|
63
|
+
# --- 1. 计算一阶差分 (First order difference) ---
|
|
64
|
+
if self.state_ipt is None:
|
|
65
|
+
diff_x = torch.zeros_like(x)
|
|
66
|
+
else:
|
|
67
|
+
diff_x = x - self.state_ipt
|
|
68
|
+
|
|
69
|
+
# 使用 .detach() 截断计算图,防止处理长视频时 GPU 显存爆炸
|
|
70
|
+
self.state_ipt = x.detach()
|
|
71
|
+
|
|
72
|
+
# --- 2. 选择计算模式 ---
|
|
73
|
+
if self.mode == 'iteration':
|
|
74
|
+
self.output = self._compute_by_iteration(diff_x)
|
|
75
|
+
elif self.mode == 'conv':
|
|
76
|
+
self.output = self._compute_by_conv(diff_x)
|
|
77
|
+
else:
|
|
78
|
+
raise ValueError("Mode must be 'iteration' or 'conv'.")
|
|
79
|
+
|
|
80
|
+
return self.output
|
|
81
|
+
|
|
82
|
+
def _compute_by_iteration(self, diff_x):
|
|
83
|
+
"""IIR (无限脉冲响应) 迭代计算法 - 极速模式"""
|
|
84
|
+
if self.state_opt is None:
|
|
85
|
+
opt = diff_x
|
|
86
|
+
else:
|
|
87
|
+
opt = self.para_cur * diff_x + self.para_pre * self.state_opt
|
|
88
|
+
|
|
89
|
+
# 同样使用 .detach() 截断历史图
|
|
90
|
+
self.state_opt = opt.detach()
|
|
91
|
+
|
|
92
|
+
return opt
|
|
93
|
+
|
|
94
|
+
def _compute_by_conv(self, diff_x):
|
|
95
|
+
"""FIR (有限脉冲响应) 卷积计算法 - 基于历史缓存"""
|
|
96
|
+
self.buffer.append(diff_x)
|
|
97
|
+
|
|
98
|
+
return compute_temporal_conv_inplace(self.buffer, self.frac_kernel)
|
xttmp/core/fstmd_core.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from .math_operator import GammaDelay
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class FeedbackPathway(GammaDelay):
|
|
5
|
+
"""FeedbackPathway class for the feedback pathway."""
|
|
6
|
+
|
|
7
|
+
def __init__(self):
|
|
8
|
+
"""Constructor method."""
|
|
9
|
+
# Initializes the FeedbackPathway object
|
|
10
|
+
super().__init__(5, 10)
|
|
11
|
+
|
|
12
|
+
self.feedback_coefficient = 0.22
|
|
13
|
+
|
|
14
|
+
def forward(self, x):
|
|
15
|
+
return self.feedback_coefficient * super().forward(x)
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
|
|
3
|
+
from . import fracstmd_core
|
|
4
|
+
|
|
5
|
+
class Lamina(fracstmd_core.Lamina):
|
|
6
|
+
"""Lamina class for the lamina layer."""
|
|
7
|
+
|
|
8
|
+
def __init__(self):
|
|
9
|
+
"""Constructor method."""
|
|
10
|
+
# Initializes the Lamina object
|
|
11
|
+
super().__init__()
|
|
12
|
+
self.loopLaminaOpt = None
|
|
13
|
+
self.isInLoop = False
|
|
14
|
+
|
|
15
|
+
def forward(self, LaminaIpt):
|
|
16
|
+
"""Processing method."""
|
|
17
|
+
# Processes the LaminaIpt to generate the lamina output
|
|
18
|
+
if self.preLaminaIpt is None:
|
|
19
|
+
diffLaminaIpt = np.zeros_like(LaminaIpt)
|
|
20
|
+
else:
|
|
21
|
+
# First order difference
|
|
22
|
+
diffLaminaIpt = LaminaIpt - self.preLaminaIpt
|
|
23
|
+
|
|
24
|
+
laminaOpt = self.compute_by_iteration(diffLaminaIpt)
|
|
25
|
+
self.Opt = laminaOpt
|
|
26
|
+
|
|
27
|
+
if not self.isInLoop:
|
|
28
|
+
self.preLaminaIpt = LaminaIpt
|
|
29
|
+
|
|
30
|
+
return laminaOpt
|
|
31
|
+
|
|
32
|
+
def compute_by_iteration(self, diffLaminaIpt):
|
|
33
|
+
"""Compute lamina output by iteration."""
|
|
34
|
+
if self.preLaminaOpt is None:
|
|
35
|
+
laminaopt = self.paraCur * diffLaminaIpt
|
|
36
|
+
else:
|
|
37
|
+
if not self.isInLoop:
|
|
38
|
+
self.preLaminaIpt = self.loopLaminaOpt
|
|
39
|
+
laminaopt = self.paraCur * diffLaminaIpt + self.paraPre * self.preLaminaOpt
|
|
40
|
+
|
|
41
|
+
self.loopLaminaOpt = laminaopt
|
|
42
|
+
return laminaopt
|
|
@@ -0,0 +1,140 @@
|
|
|
1
|
+
from collections import deque
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch.nn import functional as F
|
|
5
|
+
|
|
6
|
+
from .base_core import BaseCore
|
|
7
|
+
from .math_operator import (compute_temporal_conv_inplace,
|
|
8
|
+
SpatialInhibition, GammaDelay)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class Medulla(BaseCore):
|
|
13
|
+
def __init__(self):
|
|
14
|
+
super().__init__()
|
|
15
|
+
self.temporal_kernel_len = 15 # a
|
|
16
|
+
self.spatial_kernel_size = 7 # r
|
|
17
|
+
self.register_buffer('theta_list', torch.tensor([(i * torch.pi / 4) for i in range(8)]))
|
|
18
|
+
|
|
19
|
+
self.register_buffer('temporal_ON_kernel', torch.empty(0))
|
|
20
|
+
self.register_buffer('temporal_OFF_kernel', torch.empty(0))
|
|
21
|
+
|
|
22
|
+
self.delay_on = GammaDelay(10, 1)
|
|
23
|
+
self.delay_off = GammaDelay(10, 1)
|
|
24
|
+
|
|
25
|
+
self.medulla_input_buffer = deque(maxlen=self.temporal_kernel_len)
|
|
26
|
+
|
|
27
|
+
self.setup()
|
|
28
|
+
|
|
29
|
+
def setup(self):
|
|
30
|
+
|
|
31
|
+
# Temporal kernels
|
|
32
|
+
k1 = int(self.temporal_kernel_len / 2)
|
|
33
|
+
|
|
34
|
+
self.temporal_ON_kernel = torch.ones((k1, 1))
|
|
35
|
+
self.temporal_OFF_kernel = torch.vstack((torch.zeros((k1, 1)), -torch.ones((k1+1, 1))))
|
|
36
|
+
|
|
37
|
+
self.reset_buffer()
|
|
38
|
+
|
|
39
|
+
def reset_buffer(self):
|
|
40
|
+
# Allocate memory
|
|
41
|
+
self.medulla_input_buffer.clear()
|
|
42
|
+
self.delay_on.reset_buffer()
|
|
43
|
+
self.delay_off.reset_buffer()
|
|
44
|
+
|
|
45
|
+
@staticmethod
|
|
46
|
+
def direction_pooling(x, s=7):
|
|
47
|
+
"""
|
|
48
|
+
高效计算 8 个方向的感受野池化 (均值)
|
|
49
|
+
x: 输入特征图 [B, C, H, W]
|
|
50
|
+
s: 池化窗口大小 (建议为奇数,例如 3, 5, 7)
|
|
51
|
+
"""
|
|
52
|
+
H, W = x.shape[-2:]
|
|
53
|
+
|
|
54
|
+
# 1. 确定最大偏移量:窗口大小减 1
|
|
55
|
+
p = s - 1
|
|
56
|
+
|
|
57
|
+
# 2. 四周补齐 Padding
|
|
58
|
+
# 补齐后尺寸变为 (H + 2p, W + 2p)
|
|
59
|
+
padded_x = F.pad(x, (p, p, p, p))
|
|
60
|
+
|
|
61
|
+
# 3. 仅做一次全局 AvgPool
|
|
62
|
+
# 步长设为 1,输出尺寸会自动变成 (H + p, W + p)
|
|
63
|
+
pooled = F.avg_pool2d(padded_x, kernel_size=(s, s), stride=1, padding=0)
|
|
64
|
+
pooled_on = torch.clamp(pooled, min=0)
|
|
65
|
+
pooled_off = torch.clamp(-pooled, min=0)
|
|
66
|
+
|
|
67
|
+
# --- 4. 见证奇迹的切片时刻 ---
|
|
68
|
+
# 定义不同方向的起始坐标 (基于偏移量 p)
|
|
69
|
+
offset_min = 0 # 偏向上/左
|
|
70
|
+
offset_mid = p // 2 # 居中对齐
|
|
71
|
+
offset_max = p # 偏向下/右
|
|
72
|
+
|
|
73
|
+
spatial_ON_output = torch.cat([
|
|
74
|
+
pooled_on[..., offset_mid : offset_mid+H, offset_min : offset_min+W], # W
|
|
75
|
+
pooled_on[..., offset_max : offset_max+H, offset_min : offset_min+W], # SW
|
|
76
|
+
pooled_on[..., offset_max : offset_max+H, offset_mid : offset_mid+W], # S
|
|
77
|
+
pooled_on[..., offset_max : offset_max+H, offset_max : offset_max+W], # SE
|
|
78
|
+
pooled_on[..., offset_mid : offset_mid+H, offset_max : offset_max+W], # E
|
|
79
|
+
pooled_on[..., offset_min : offset_min+H, offset_max : offset_max+W], # NE
|
|
80
|
+
pooled_on[..., offset_min : offset_min+H, offset_mid : offset_mid+W], # N
|
|
81
|
+
pooled_on[..., offset_min : offset_min+H, offset_min : offset_min+W], # NW
|
|
82
|
+
], dim=1)
|
|
83
|
+
|
|
84
|
+
spatial_OFF_output = torch.cat([
|
|
85
|
+
pooled_off[..., offset_mid : offset_mid+H, offset_max : offset_max+W], # E
|
|
86
|
+
pooled_off[..., offset_min : offset_min+H, offset_max : offset_max+W], # NE
|
|
87
|
+
pooled_off[..., offset_min : offset_min+H, offset_mid : offset_mid+W], # N
|
|
88
|
+
pooled_off[..., offset_min : offset_min+H, offset_min : offset_min+W], # NW
|
|
89
|
+
pooled_off[..., offset_mid : offset_mid+H, offset_min : offset_min+W], # W
|
|
90
|
+
pooled_off[..., offset_max : offset_max+H, offset_min : offset_min+W], # SW
|
|
91
|
+
pooled_off[..., offset_max : offset_max+H, offset_mid : offset_mid+W], # S
|
|
92
|
+
pooled_off[..., offset_max : offset_max+H, offset_max : offset_max+W], # SE
|
|
93
|
+
], dim=1)
|
|
94
|
+
|
|
95
|
+
return spatial_ON_output, spatial_OFF_output
|
|
96
|
+
|
|
97
|
+
def forward(self, medullaIpt):
|
|
98
|
+
|
|
99
|
+
''' Compute temporal part '''
|
|
100
|
+
self.medulla_input_buffer.append(medullaIpt)
|
|
101
|
+
|
|
102
|
+
temporal_ON_output = compute_temporal_conv_inplace(self.medulla_input_buffer, self.temporal_ON_kernel)
|
|
103
|
+
temporal_OFF_output = compute_temporal_conv_inplace(self.medulla_input_buffer, self.temporal_OFF_kernel)
|
|
104
|
+
|
|
105
|
+
# There's no need for half-wave rectification here
|
|
106
|
+
correlated_temporal_output = temporal_ON_output * temporal_OFF_output
|
|
107
|
+
|
|
108
|
+
''' Compute spacial part '''
|
|
109
|
+
spatial_ON_output, spatial_OFF_output = self.direction_pooling(medullaIpt, self.spatial_kernel_size)
|
|
110
|
+
|
|
111
|
+
delayed_spatial_ON_output = self.delay_on.forward(spatial_ON_output)
|
|
112
|
+
delayed_spatial_OFF_output = self.delay_off.forward(spatial_OFF_output)
|
|
113
|
+
|
|
114
|
+
correlated_spatial_output = delayed_spatial_ON_output * delayed_spatial_OFF_output
|
|
115
|
+
|
|
116
|
+
# Store the output in output property
|
|
117
|
+
self.output = (correlated_spatial_output, correlated_temporal_output)
|
|
118
|
+
|
|
119
|
+
return self.output
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
class Lobula(BaseCore):
|
|
123
|
+
def __init__(self):
|
|
124
|
+
super().__init__()
|
|
125
|
+
self.tau = 1 # a parameter to align the spacialOpt and temporalOpt
|
|
126
|
+
self.spatial_inhibition = SpatialInhibition()
|
|
127
|
+
self.spatial_inhibition.B = 1
|
|
128
|
+
|
|
129
|
+
def setup(self):
|
|
130
|
+
self.spatial_inhibition.setup()
|
|
131
|
+
|
|
132
|
+
def forward(self, correlated_spatial_output, correlated_temporal_output):
|
|
133
|
+
correlated_spatiotemporal_output = correlated_spatial_output * correlated_temporal_output
|
|
134
|
+
|
|
135
|
+
# Apply surround inhibition
|
|
136
|
+
self.output = torch.clamp(self.spatial_inhibition(correlated_spatiotemporal_output), min=0)
|
|
137
|
+
|
|
138
|
+
return self.output
|
|
139
|
+
|
|
140
|
+
|
|
@@ -0,0 +1,307 @@
|
|
|
1
|
+
from collections import deque
|
|
2
|
+
from typing import Iterable
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
import torch.nn.functional as F
|
|
6
|
+
|
|
7
|
+
from .base_core import BaseCore
|
|
8
|
+
from ..util.create_kernel import create_2d_gaussian_kernel, create_gamma_kernel, create_spatial_inhibition_kernel
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def compute_temporal_conv_inplace(buffer_refs: Iterable[torch.Tensor],
|
|
12
|
+
time_kernel_tensor: torch.Tensor) -> torch.Tensor:
|
|
13
|
+
"""
|
|
14
|
+
Efficiently computes a 1D temporal convolution over a sequence of spatial feature maps
|
|
15
|
+
using in-place accumulation and automatic sequence truncation.
|
|
16
|
+
|
|
17
|
+
This function is designed for streaming video or continuous time-series processing
|
|
18
|
+
where historical frames are stored in a FIFO queue (like collections.deque). It maps
|
|
19
|
+
a 1D temporal kernel across the batch and spatial dimensions of the buffered frames.
|
|
20
|
+
|
|
21
|
+
Args:
|
|
22
|
+
buffer_refs (Iterable[torch.Tensor]): An iterable (e.g., list or deque) containing
|
|
23
|
+
historical feature map tensors.
|
|
24
|
+
- Expected shape of each tensor: [B, C, H, W]
|
|
25
|
+
- Temporal ordering: The rightmost (last) element is assumed to be the
|
|
26
|
+
newest/most recent frame.
|
|
27
|
+
time_kernel_tensor (torch.Tensor): A 1D tensor representing the temporal convolution
|
|
28
|
+
weights.
|
|
29
|
+
- Expected shape: [K], where K is the kernel size.
|
|
30
|
+
- Weight ordering: The leftmost (first) element corresponds to the newest frame.
|
|
31
|
+
|
|
32
|
+
Returns:
|
|
33
|
+
torch.Tensor: The result of the temporal convolution, with shape [B, C, H, W].
|
|
34
|
+
Returns None if `buffer_refs` is empty.
|
|
35
|
+
|
|
36
|
+
Notes:
|
|
37
|
+
- Memory Efficiency: Uses a clone for the initial base tensor, followed by in-place
|
|
38
|
+
additions (`add_`) for subsequent historical frames to minimize memory allocation.
|
|
39
|
+
- Automatic Truncation: The `zip` function automatically stops at the shortest
|
|
40
|
+
iterable. If the buffer has fewer frames than the kernel size (e.g., during the
|
|
41
|
+
initial "warm-up" phase of a stream), it safely computes a partial convolution
|
|
42
|
+
without out-of-bounds errors or requiring explicit padding.
|
|
43
|
+
"""
|
|
44
|
+
temporal_conv_out = None
|
|
45
|
+
|
|
46
|
+
# reversed(buffer_refs) iterates from the newest frame to the oldest frame.
|
|
47
|
+
# zip automatically aligns the newest frame with the first weight and truncates safely.
|
|
48
|
+
for t_tensor, weight in zip(reversed(buffer_refs), time_kernel_tensor):
|
|
49
|
+
|
|
50
|
+
weight_val = weight.item()
|
|
51
|
+
|
|
52
|
+
if temporal_conv_out is None:
|
|
53
|
+
# Initialize the base tensor using the newest frame
|
|
54
|
+
temporal_conv_out = t_tensor.clone().mul_(weight_val)
|
|
55
|
+
else:
|
|
56
|
+
# In-place accumulation of historical frames onto the base tensor
|
|
57
|
+
temporal_conv_out.add_(t_tensor, alpha=weight_val)
|
|
58
|
+
|
|
59
|
+
return temporal_conv_out
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class GaussianBlur(BaseCore):
|
|
63
|
+
"""
|
|
64
|
+
Gaussian blur filter: Pure PyTorch implementation.
|
|
65
|
+
"""
|
|
66
|
+
|
|
67
|
+
def __init__(self, kernel_size=3, sigma=1.0):
|
|
68
|
+
"""
|
|
69
|
+
Constructor.
|
|
70
|
+
Initializes the GaussianBlur module.
|
|
71
|
+
|
|
72
|
+
Parameters:
|
|
73
|
+
- kernel_size: Size of the filter kernel (int). Should be an odd number.
|
|
74
|
+
- sigma: Standard deviation of the Gaussian distribution (float).
|
|
75
|
+
"""
|
|
76
|
+
super().__init__()
|
|
77
|
+
self.kernel_size = kernel_size
|
|
78
|
+
self.sigma = sigma
|
|
79
|
+
self.register_buffer('blur_kernel', torch.empty(0))
|
|
80
|
+
|
|
81
|
+
self.setup()
|
|
82
|
+
|
|
83
|
+
def setup(self):
|
|
84
|
+
_kernel = create_2d_gaussian_kernel(self.kernel_size, self.sigma)
|
|
85
|
+
self.blur_kernel.data = _kernel.view(1, 1, self.kernel_size, self.kernel_size)
|
|
86
|
+
|
|
87
|
+
def forward(self, x):
|
|
88
|
+
"""
|
|
89
|
+
Processing method.
|
|
90
|
+
Applies the Gaussian filter to the input tensor.
|
|
91
|
+
|
|
92
|
+
Parameters:
|
|
93
|
+
- x: Input tensor of shape (B, C, H, W)
|
|
94
|
+
|
|
95
|
+
Returns:
|
|
96
|
+
- opt: Output after applying the Gaussian filter.
|
|
97
|
+
"""
|
|
98
|
+
|
|
99
|
+
C = x.shape[1]
|
|
100
|
+
|
|
101
|
+
# 动态将单通道高斯核扩展至与输入特征图通道数一致, expand 不占用额外显存
|
|
102
|
+
weight = self.blur_kernel.expand(C, 1, self.kernel_size, self.kernel_size)
|
|
103
|
+
|
|
104
|
+
# 使用深度可分离卷积(groups=C),每个通道独立进行高斯模糊
|
|
105
|
+
return F.conv2d(x, weight, padding='same', groups=C)
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
class GammaDelay(BaseCore):
|
|
109
|
+
"""
|
|
110
|
+
GammaDelay Class
|
|
111
|
+
|
|
112
|
+
Implements a gamma filter used in the lamina layer of the ESTMD neural network
|
|
113
|
+
using pure PyTorch and collections.deque for efficient temporal sliding windows.
|
|
114
|
+
"""
|
|
115
|
+
def __init__(self, order=1, tau=1.0, kernel_len=None):
|
|
116
|
+
"""
|
|
117
|
+
Constructor method.
|
|
118
|
+
|
|
119
|
+
Parameters:
|
|
120
|
+
order (int): Order of the gamma filter (n). Default is 1.
|
|
121
|
+
tau (float): Time constant of the filter (\tau).
|
|
122
|
+
kernel_len (int): Length of the filter kernel (T).
|
|
123
|
+
"""
|
|
124
|
+
super().__init__()
|
|
125
|
+
self.order = max(1, int(order))
|
|
126
|
+
self.tau = tau
|
|
127
|
+
# 如果未指定长度,默认使用 3 * tau (覆盖大部分有效权重)
|
|
128
|
+
self.kernel_len = int(3 * tau) if kernel_len is None else kernel_len
|
|
129
|
+
|
|
130
|
+
self.setup()
|
|
131
|
+
|
|
132
|
+
def setup(self):
|
|
133
|
+
# 1. 预计算 Gamma 滤波器的时域权重
|
|
134
|
+
kernel = create_gamma_kernel(self.order, self.tau, self.kernel_len)
|
|
135
|
+
# 注册为 buffer,随模型自动转移设备 (如 .cuda())
|
|
136
|
+
self.register_buffer('gamma_kernel', kernel)
|
|
137
|
+
|
|
138
|
+
# 2. 初始化双端队列作为时序状态缓存区
|
|
139
|
+
self.buffer = deque(maxlen=self.kernel_len)
|
|
140
|
+
|
|
141
|
+
def reset_buffer(self):
|
|
142
|
+
"""
|
|
143
|
+
Resets the internal buffer by clearing all stored frames.
|
|
144
|
+
"""
|
|
145
|
+
self.buffer.clear()
|
|
146
|
+
|
|
147
|
+
def forward(self, x, in_loop=False):
|
|
148
|
+
"""
|
|
149
|
+
Processing method.
|
|
150
|
+
Applies the gamma filter to the input tensor.
|
|
151
|
+
|
|
152
|
+
Parameters:
|
|
153
|
+
- x: Input tensor of shape (B, C, H, W)
|
|
154
|
+
- in_loop (bool): If True, replaces the last frame instead of appending.
|
|
155
|
+
(Equivalent to original isInLoop/cover logic)
|
|
156
|
+
"""
|
|
157
|
+
|
|
158
|
+
if in_loop and len(self.buffer) > 0:
|
|
159
|
+
# 替换队尾元素 (最新帧),保持缓存长度不变
|
|
160
|
+
self.buffer[-1] = x
|
|
161
|
+
else:
|
|
162
|
+
self.buffer.append(x)
|
|
163
|
+
|
|
164
|
+
return compute_temporal_conv_inplace(self.buffer, self.gamma_kernel)
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
class GammaBandPassFilter(BaseCore):
|
|
168
|
+
"""
|
|
169
|
+
GammaBandPassFilter: Temporal Band-pass filter for ESTMD.
|
|
170
|
+
|
|
171
|
+
Optimized pure PyTorch implementation. Uses a single deque buffer and
|
|
172
|
+
mathematically fuses the two Gamma filters into a single convolution kernel
|
|
173
|
+
to halve memory usage and computation time.
|
|
174
|
+
"""
|
|
175
|
+
|
|
176
|
+
def __init__(self,
|
|
177
|
+
order1=2, tau1=3.0,
|
|
178
|
+
order2=6, tau2=9.0,
|
|
179
|
+
kernel_len=None):
|
|
180
|
+
"""
|
|
181
|
+
Constructor method.
|
|
182
|
+
|
|
183
|
+
Parameters:
|
|
184
|
+
- order1, tau1: Parameters for the excitatory (positive) Gamma filter.
|
|
185
|
+
- order2, tau2: Parameters for the inhibitory (negative) Gamma filter.
|
|
186
|
+
- kernel_len: Temporal length of the filter. If None, auto-calculated.
|
|
187
|
+
"""
|
|
188
|
+
super().__init__()
|
|
189
|
+
|
|
190
|
+
self.order1 = max(1, int(order1))
|
|
191
|
+
self.tau1 = tau1
|
|
192
|
+
self.order2 = max(1, int(order2))
|
|
193
|
+
self.tau2 = tau2
|
|
194
|
+
|
|
195
|
+
# 自动计算所需的历史帧缓存最大长度
|
|
196
|
+
self.kernel_len = kernel_len if kernel_len is not None else max(int(3 * tau1), int(3 * tau2))
|
|
197
|
+
|
|
198
|
+
self.in_loop = False # 默认不覆盖历史帧,直接追加
|
|
199
|
+
|
|
200
|
+
self.setup()
|
|
201
|
+
|
|
202
|
+
def setup(self):
|
|
203
|
+
# 1. 预计算两个 Gamma 滤波器的权重,并补齐到相同的长度 self.T
|
|
204
|
+
k1 = create_gamma_kernel(self.order1, self.tau1, self.kernel_len)
|
|
205
|
+
k2 = create_gamma_kernel(self.order2, self.tau2, self.kernel_len)
|
|
206
|
+
|
|
207
|
+
# 2. 算子融合 (Operator Fusion):W_bandpass = W1 - W2
|
|
208
|
+
# 直接将差值注册为模型的 buffer,前向传播只需计算一次
|
|
209
|
+
bandpass_kernel = k1 - k2
|
|
210
|
+
self.register_buffer('bandpass_kernel', bandpass_kernel)
|
|
211
|
+
|
|
212
|
+
# 3. 初始化单一的高效时序状态缓存区
|
|
213
|
+
self.buffer = deque(maxlen=self.kernel_len)
|
|
214
|
+
|
|
215
|
+
def reset_buffer(self):
|
|
216
|
+
"""
|
|
217
|
+
Resets the internal buffer by clearing all stored frames.
|
|
218
|
+
"""
|
|
219
|
+
self.buffer.clear()
|
|
220
|
+
|
|
221
|
+
def forward(self, x):
|
|
222
|
+
"""
|
|
223
|
+
Processing method.
|
|
224
|
+
|
|
225
|
+
Parameters:
|
|
226
|
+
- x: Input tensor of shape (B, C, H, W) or (C, H, W) or (H, W)
|
|
227
|
+
|
|
228
|
+
Returns:
|
|
229
|
+
- opt_tensor: Processed band-pass output tensor
|
|
230
|
+
"""
|
|
231
|
+
|
|
232
|
+
# 1. 记录最新一帧
|
|
233
|
+
if self.in_loop and len(self.buffer) > 0:
|
|
234
|
+
# 替换队尾元素 (最新帧),保持缓存长度不变
|
|
235
|
+
self.buffer[-1] = x
|
|
236
|
+
else:
|
|
237
|
+
self.buffer.append(x)
|
|
238
|
+
|
|
239
|
+
return compute_temporal_conv_inplace(self.buffer, self.bandpass_kernel)
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
class SpatialInhibition(BaseCore):
|
|
243
|
+
"""
|
|
244
|
+
Gamma_Filter Gamma filter in lamina layer
|
|
245
|
+
Pure PyTorch implementation for Surround Inhibition.
|
|
246
|
+
"""
|
|
247
|
+
|
|
248
|
+
def __init__(self,
|
|
249
|
+
kernel_size=15,
|
|
250
|
+
sigma1=1.5,
|
|
251
|
+
sigma2=3.0,
|
|
252
|
+
e=1.0,
|
|
253
|
+
rho=0.0,
|
|
254
|
+
A=1.0,
|
|
255
|
+
B=3.0):
|
|
256
|
+
"""
|
|
257
|
+
Constructor
|
|
258
|
+
Initializes the SurroundInhibition module.
|
|
259
|
+
|
|
260
|
+
Parameters:
|
|
261
|
+
- kernel_size: Size of the filter kernel
|
|
262
|
+
- sigma1: Standard deviation for the first Gaussian (Center)
|
|
263
|
+
- sigma1: Standard deviation for the second Gaussian (Surround)
|
|
264
|
+
- e: Exponent for the weighting of the second Gaussian
|
|
265
|
+
- rho: Radius for circular integration / Center offset
|
|
266
|
+
- A: Amplitude of the positive center
|
|
267
|
+
- B: Amplitude of the negative surround
|
|
268
|
+
"""
|
|
269
|
+
super().__init__()
|
|
270
|
+
self.kernel_size = kernel_size
|
|
271
|
+
self.sigma1 = sigma1
|
|
272
|
+
self.sigma2 = sigma2
|
|
273
|
+
self.e = e
|
|
274
|
+
self.rho = rho
|
|
275
|
+
self.A = A
|
|
276
|
+
self.B = B
|
|
277
|
+
|
|
278
|
+
self.register_buffer('kernel', torch.empty(0))
|
|
279
|
+
|
|
280
|
+
self.setup()
|
|
281
|
+
|
|
282
|
+
def setup(self):
|
|
283
|
+
_spatial_inhibiiton_kernel = create_spatial_inhibition_kernel(self.kernel_size,
|
|
284
|
+
self.sigma1,
|
|
285
|
+
self.sigma2,
|
|
286
|
+
self.e,
|
|
287
|
+
self.rho,
|
|
288
|
+
self.A,
|
|
289
|
+
self.B)
|
|
290
|
+
self.kernel.data = _spatial_inhibiiton_kernel.view(1, 1, self.kernel_size, self.kernel_size)
|
|
291
|
+
|
|
292
|
+
def forward(self, x):
|
|
293
|
+
"""
|
|
294
|
+
Processing method
|
|
295
|
+
Applies the surround inhibition filter to the input tensor.
|
|
296
|
+
|
|
297
|
+
Parameters:
|
|
298
|
+
- x: Input tensor of shape (B, C, H, W)
|
|
299
|
+
"""
|
|
300
|
+
|
|
301
|
+
C = x.shape[1]
|
|
302
|
+
|
|
303
|
+
# .expand 不会真的在内存中复制数据,而是通过 stride 机制虚拟映射,极大地节省显存和耗时
|
|
304
|
+
weight = self.kernel.expand(C, 1, self.kernel_size, self.kernel_size)
|
|
305
|
+
|
|
306
|
+
# groups=C 表示进行深度可分离卷积(Depthwise Convolution),每个通道独立滤波
|
|
307
|
+
return F.relu(F.conv2d(x, weight, padding='same', groups=C))
|