xttmp 2.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45) hide show
  1. xttmp/__init__.py +1 -0
  2. xttmp/api/__init__.py +5 -0
  3. xttmp/api/evaluate.py +163 -0
  4. xttmp/api/get_visualize_handle.py +29 -0
  5. xttmp/api/instancing_model.py +35 -0
  6. xttmp/core/__init__.py +0 -0
  7. xttmp/core/apgstmd_core.py +188 -0
  8. xttmp/core/apgstmdv2_core.py +79 -0
  9. xttmp/core/base_core.py +36 -0
  10. xttmp/core/dstmd_core.py +213 -0
  11. xttmp/core/estmd_backbone.py +110 -0
  12. xttmp/core/estmd_core.py +356 -0
  13. xttmp/core/feedbackstmd_core.py +61 -0
  14. xttmp/core/fracstmd_core.py +98 -0
  15. xttmp/core/fstmd_core.py +15 -0
  16. xttmp/core/fstmdv2_core.py +42 -0
  17. xttmp/core/haarstmd_core.py +140 -0
  18. xttmp/core/math_operator.py +307 -0
  19. xttmp/core/stfeedbackstmd_core.py +233 -0
  20. xttmp/core/stmdplus_core.py +187 -0
  21. xttmp/core/stmdplusv2_core.py +82 -0
  22. xttmp/core/vstmd_core.py +420 -0
  23. xttmp/demo/evaluate_model.py +92 -0
  24. xttmp/demo/inference_gui.py +148 -0
  25. xttmp/demo/inference_gui_single_process.py +134 -0
  26. xttmp/demo/inference_image_stream.py +67 -0
  27. xttmp/demo/inference_video.py +66 -0
  28. xttmp/main.py +14 -0
  29. xttmp/model/__init__.py +13 -0
  30. xttmp/model/backbone.py +514 -0
  31. xttmp/model/facilitated_model.py +230 -0
  32. xttmp/model/feedback_model.py +271 -0
  33. xttmp/model/haarstmd.py +61 -0
  34. xttmp/model/vstmd.py +457 -0
  35. xttmp/util/__init__.py +0 -0
  36. xttmp/util/compute_module.py +402 -0
  37. xttmp/util/create_kernel.py +363 -0
  38. xttmp/util/evaluate_module.py +697 -0
  39. xttmp/util/iostream.py +660 -0
  40. xttmp-2.3.0.dist-info/METADATA +85 -0
  41. xttmp-2.3.0.dist-info/RECORD +45 -0
  42. xttmp-2.3.0.dist-info/WHEEL +5 -0
  43. xttmp-2.3.0.dist-info/entry_points.txt +2 -0
  44. xttmp-2.3.0.dist-info/licenses/LICENSE +201 -0
  45. xttmp-2.3.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,61 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ from .base_core import BaseCore
5
+ from .math_operator import SpatialInhibition, GammaDelay
6
+ from ..util.create_kernel import create_2d_gaussian_kernel
7
+
8
+
9
class Lobula(BaseCore):
    """Lobula layer of the feedback STMD model.

    Subtracts a gamma-delayed feedback signal from the medulla ON/OFF
    channels, correlates the rectified results and applies surround
    inhibition to the correlation.
    """

    def __init__(self):
        """Create the sub-modules and build their kernels."""
        super().__init__()
        self.spatial_inhibition = SpatialInhibition()  # surround inhibition stage
        self.alpha = 1  # gain applied to the delayed feedback signal
        self.sigma = 1.5  # standard deviation of the Gaussian kernel

        self.gamma_delay = GammaDelay(10, 25)  # temporal delay of the feedback

        # Placeholder buffer; the real kernel is written by setup().
        self.register_buffer('gaussian_kernel', torch.empty(0))

        self.setup()

    def setup(self):
        """(Re)build the kernels of all components and clear temporal state."""
        self.spatial_inhibition.setup()
        self.gamma_delay.setup()

        self.gaussian_kernel.data = create_2d_gaussian_kernel(size=3, sigma=self.sigma)

        self.reset_buffer()

    def reset_buffer(self):
        """Clear the frame history held by the feedback delay."""
        self.gamma_delay.reset_buffer()

    def forward(self, medulla_ON, medulla_OFF):
        """Process one frame.

        Parameters:
        - medulla_ON: ON-channel tensor from the medulla layer.
        - medulla_OFF: OFF-channel tensor from the medulla layer.

        Returns:
        - The surround-inhibited correlation of the feedback-corrected
          ON and OFF channels (also stored in ``self.output``).
        """
        # Formula (9): a zero frame is pushed so the delay advances by one
        # step; that newest slot is overwritten with the real signal below.
        feedback_output = self.alpha * self.gamma_delay.forward(torch.zeros_like(medulla_ON))

        # Formula (8)
        ON_with_feedback = torch.clamp(medulla_ON - feedback_output, min=0)
        OFF_with_feedback = torch.clamp(medulla_OFF - feedback_output, min=0)
        correlation_D = ON_with_feedback * OFF_with_feedback

        # Formula (10): blur the raw ON/OFF product. conv2d needs a 4-D
        # weight, so the kernel is reshaped to (1, 1, kH, kW) and applied
        # per channel (groups=C), mirroring GaussianBlur in math_operator.
        raw_product = medulla_ON * medulla_OFF
        channels = raw_product.shape[-3]
        kernel_hw = self.gaussian_kernel.shape[-2:]
        weight = self.gaussian_kernel.reshape(1, 1, *kernel_hw).expand(channels, 1, -1, -1)
        correlation_E = F.conv2d(raw_product, weight, padding='same', groups=channels)

        # Only (correlation_D + correlation_E) is recorded for the next delay
        # step of Formula (9); it replaces the zero placeholder pushed above.
        self.gamma_delay.buffer[-1] = correlation_D + correlation_E

        # Formula (14)
        self.output = self.spatial_inhibition(correlation_D)

        return self.output
@@ -0,0 +1,98 @@
1
+ import math
2
+
3
+ import torch
4
+ from collections import deque
5
+
6
+ from .base_core import BaseCore
7
+ from ..util.create_kernel import create_fracdiff_kernel
8
+ from ..core.math_operator import compute_temporal_conv_inplace
9
+
10
+
11
class Lamina(BaseCore):
    """
    Lamina layer in ESTMD.
    Pure PyTorch implementation supporting both IIR (iteration) and FIR (convolution) modes.
    """

    def __init__(self, alpha=0.8, delta=20, mode='iteration'):
        """
        Constructor method.

        Parameters:
        - alpha: Fractional differential order (0 < alpha <= 1)
        - delta: Length of the historical buffer for convolution
        - mode: 'iteration' (IIR, fast) or 'conv' (FIR, accurate but slower)

        Raises:
        - ValueError: if alpha is outside (0, 1].
        """
        super().__init__()
        self.alpha = alpha
        self.delta = delta
        self.mode = mode

        # Placeholder buffer; the real kernel is written by setup().
        self.register_buffer('frac_kernel', torch.empty(0))

        self.setup()

    def setup(self):
        """Build the fractional-difference kernel and the IIR coefficients.

        Raises:
        - ValueError: if alpha is outside (0, 1].
        """
        # Validate before touching any state so an invalid alpha fails with
        # this message instead of reaching the kernel builder.
        if not 0.0 < self.alpha <= 1.0:
            raise ValueError("Invalid alpha value. Must be in (0, 1].")

        # 1. FIR kernel used by the 'conv' mode.
        _kernel = create_fracdiff_kernel(self.alpha, self.delta)
        self.frac_kernel.data = _kernel

        # 2. Coefficients of the iteration (IIR) mode.
        self.para_cur = _kernel[0].item()

        if self.alpha == 1.0:
            # Division by (1 - alpha) is undefined here; no recursive term.
            self.para_pre = 0.0
        else:
            self.para_pre = math.exp(-self.alpha / (1.0 - self.alpha))

        self.reset_buffer()

    def reset_buffer(self):
        """Drop all temporal state (previous input/output and frame history)."""
        # 3. Initialise the temporal state.
        self.state_ipt = None
        self.state_opt = None
        self.buffer = deque(maxlen=self.delta)

    def forward(self, x):
        """
        Processing method.
        x shape: (B, C, H, W)

        Returns the filtered first-order difference of the input stream
        (also stored in ``self.output``).

        Raises:
        - ValueError: if ``self.mode`` is neither 'iteration' nor 'conv'.
        """

        # --- 1. First-order difference (zero for the very first frame) ---
        if self.state_ipt is None:
            diff_x = torch.zeros_like(x)
        else:
            diff_x = x - self.state_ipt

        # .detach() cuts the autograd graph so that long videos do not
        # accumulate history on the GPU.
        self.state_ipt = x.detach()

        # --- 2. Dispatch on the computation mode ---
        if self.mode == 'iteration':
            self.output = self._compute_by_iteration(diff_x)
        elif self.mode == 'conv':
            self.output = self._compute_by_conv(diff_x)
        else:
            raise ValueError("Mode must be 'iteration' or 'conv'.")

        return self.output

    def _compute_by_iteration(self, diff_x):
        """IIR (infinite impulse response) update - the fast mode."""
        if self.state_opt is None:
            opt = diff_x
        else:
            opt = self.para_cur * diff_x + self.para_pre * self.state_opt

        # Detach here as well to cut the history graph.
        self.state_opt = opt.detach()

        return opt

    def _compute_by_conv(self, diff_x):
        """FIR (finite impulse response) convolution over the buffered history."""
        self.buffer.append(diff_x)

        return compute_temporal_conv_inplace(self.buffer, self.frac_kernel)
@@ -0,0 +1,15 @@
1
+ from .math_operator import GammaDelay
2
+
3
+
4
class FeedbackPathway(GammaDelay):
    """Gamma-delayed feedback pathway with a fixed output gain."""

    def __init__(self):
        """Configure the underlying gamma delay and the feedback gain."""
        # Positional arguments are GammaDelay's (order, tau).
        super().__init__(5, 10)

        # Scale factor applied to the delayed signal in forward().
        self.feedback_coefficient = 0.22

    def forward(self, x):
        """Return the gamma-delayed input scaled by ``feedback_coefficient``."""
        delayed = super().forward(x)
        return self.feedback_coefficient * delayed
@@ -0,0 +1,42 @@
1
+ import numpy as np
2
+
3
+ from . import fracstmd_core
4
+
5
class Lamina(fracstmd_core.Lamina):
    """Lamina layer that can be re-evaluated inside a feedback loop.

    While ``isInLoop`` is True the stored previous input/output are left
    untouched, so repeated calls for the same frame do not advance the
    temporal state.

    NOTE(review): this class previously read ``preLaminaIpt``,
    ``preLaminaOpt``, ``paraCur`` and ``paraPre``, none of which exist on
    ``fracstmd_core.Lamina`` any more (it exposes ``state_ipt``,
    ``state_opt``, ``para_cur`` and ``para_pre``). It has been ported to
    those names; please confirm the loop semantics against the model code.
    """

    def __init__(self):
        """Constructor method."""
        super().__init__()
        self.loopLaminaOpt = None  # most recent output, promoted on the next frame
        self.isInLoop = False  # set by the caller while iterating the feedback loop

    def reset_buffer(self):
        """Clear the parent's temporal state and the pending loop output."""
        super().reset_buffer()
        self.loopLaminaOpt = None

    def forward(self, LaminaIpt):
        """Process one input tensor and return the lamina output."""
        import torch  # local import: this module does not import torch at file level

        if self.state_ipt is None:
            diffLaminaIpt = torch.zeros_like(LaminaIpt)
        else:
            # First order difference
            diffLaminaIpt = LaminaIpt - self.state_ipt

        laminaOpt = self.compute_by_iteration(diffLaminaIpt)
        self.Opt = laminaOpt  # legacy attribute name, kept for existing readers
        self.output = laminaOpt  # attribute name used by the parent class

        # Only a regular (non-loop) call advances the stored previous input.
        if not self.isInLoop:
            self.state_ipt = LaminaIpt.detach()

        return laminaOpt

    def compute_by_iteration(self, diffLaminaIpt):
        """Compute lamina output by iteration."""
        # A regular call starts a new frame: the last output produced for the
        # previous frame becomes the recursive state.
        # NOTE(review): the original assigned the loop output to the
        # previous *input* and never set the previous output, so the
        # recursive branch could not be reached - confirm this is the intent.
        if not self.isInLoop and self.loopLaminaOpt is not None:
            self.state_opt = self.loopLaminaOpt.detach()

        if self.state_opt is None:
            laminaopt = self.para_cur * diffLaminaIpt
        else:
            laminaopt = self.para_cur * diffLaminaIpt + self.para_pre * self.state_opt

        self.loopLaminaOpt = laminaopt
        return laminaopt
@@ -0,0 +1,140 @@
1
+ from collections import deque
2
+
3
+ import torch
4
+ from torch.nn import functional as F
5
+
6
+ from .base_core import BaseCore
7
+ from .math_operator import (compute_temporal_conv_inplace,
8
+ SpatialInhibition, GammaDelay)
9
+
10
+
11
+
12
class Medulla(BaseCore):
    """Medulla layer of the Haar STMD model.

    Produces a temporal correlation (two box-shaped temporal kernels applied
    to the frame history) and a spatial correlation (eight shifted average
    pools, split into ON/OFF parts and gamma-delayed).
    """

    def __init__(self):
        """Register the buffers, create the delays and build the kernels."""
        super().__init__()
        self.temporal_kernel_len = 15  # a
        self.spatial_kernel_size = 7  # r
        self.register_buffer('theta_list', torch.tensor([k * torch.pi / 4 for k in range(8)]))

        # Placeholders; the real kernels are assigned in setup().
        self.register_buffer('temporal_ON_kernel', torch.empty(0))
        self.register_buffer('temporal_OFF_kernel', torch.empty(0))

        self.delay_on = GammaDelay(10, 1)
        self.delay_off = GammaDelay(10, 1)

        # Sliding window holding the most recent input frames.
        self.medulla_input_buffer = deque(maxlen=self.temporal_kernel_len)

        self.setup()

    def setup(self):
        """Build the two temporal kernels and clear all temporal state."""
        half_len = int(self.temporal_kernel_len / 2)

        # ON kernel: +1 on the first `half_len` taps.
        self.temporal_ON_kernel = torch.ones((half_len, 1))
        # OFF kernel: 0 on the first `half_len` taps, -1 on the remaining ones.
        leading_zeros = torch.zeros((half_len, 1))
        trailing_neg_ones = -torch.ones((half_len + 1, 1))
        self.temporal_OFF_kernel = torch.vstack((leading_zeros, trailing_neg_ones))

        self.reset_buffer()

    def reset_buffer(self):
        """Empty the input history and the buffers of both delays."""
        self.delay_on.reset_buffer()
        self.delay_off.reset_buffer()
        self.medulla_input_buffer.clear()

    @staticmethod
    def direction_pooling(x, s=7):
        """
        Compute eight shifted mean-pooled views of ``x`` with one pooling pass.

        x: input feature map [B, C, H, W]
        s: pooling window size (an odd value such as 3, 5 or 7 is suggested)

        Returns (spatial_ON_output, spatial_OFF_output); each concatenates
        eight H x W windows along dim 1.
        """
        H, W = x.shape[-2:]

        # Largest shift is the window size minus one; pad every side by it so
        # the stride-1 pooling below yields a (H + p, W + p) map.
        p = s - 1
        pooled = F.avg_pool2d(F.pad(x, (p, p, p, p)), kernel_size=(s, s), stride=1, padding=0)

        # Split into positive (ON) and negative (OFF) parts.
        pooled_on = torch.clamp(pooled, min=0)
        pooled_off = torch.clamp(-pooled, min=0)

        # Start offsets of the H x W windows inside the pooled map.
        lo, mid, hi = 0, p // 2, p

        # (row offset, column offset) per direction, in the ON channel order.
        on_offsets = [
            (mid, lo),  # W
            (hi, lo),   # SW
            (hi, mid),  # S
            (hi, hi),   # SE
            (mid, hi),  # E
            (lo, hi),   # NE
            (lo, mid),  # N
            (lo, lo),   # NW
        ]
        # The OFF channels use the same windows rotated by four positions:
        # E, NE, N, NW, W, SW, S, SE.
        off_offsets = on_offsets[4:] + on_offsets[:4]

        spatial_ON_output = torch.cat(
            [pooled_on[..., r:r + H, c:c + W] for r, c in on_offsets], dim=1)
        spatial_OFF_output = torch.cat(
            [pooled_off[..., r:r + H, c:c + W] for r, c in off_offsets], dim=1)

        return spatial_ON_output, spatial_OFF_output

    def forward(self, medullaIpt):
        """Process one frame.

        Returns the tuple (correlated_spatial_output, correlated_temporal_output),
        which is also stored in ``self.output``.
        """
        # Temporal part: both kernels are applied to the same frame history.
        self.medulla_input_buffer.append(medullaIpt)
        history = self.medulla_input_buffer
        temporal_on = compute_temporal_conv_inplace(history, self.temporal_ON_kernel)
        temporal_off = compute_temporal_conv_inplace(history, self.temporal_OFF_kernel)

        # Spatial part: directional pooling followed by a gamma delay.
        spatial_on, spatial_off = self.direction_pooling(medullaIpt, self.spatial_kernel_size)
        delayed_on = self.delay_on.forward(spatial_on)
        delayed_off = self.delay_off.forward(spatial_off)

        # No half-wave rectification is applied to the temporal product.
        self.output = (delayed_on * delayed_off, temporal_on * temporal_off)

        return self.output
120
+
121
+
122
class Lobula(BaseCore):
    """Lobula layer of the Haar STMD model.

    Multiplies the spatial and temporal correlations from the medulla and
    applies surround inhibition to the product.
    """

    def __init__(self):
        """Create the surround-inhibition stage with B = 1."""
        super().__init__()
        self.tau = 1  # a parameter to align the spatial and temporal outputs

        # B is passed to the constructor because SpatialInhibition builds its
        # kernel inside __init__; assigning `.B` afterwards would leave the
        # kernel built with the default B until setup() is called again.
        self.spatial_inhibition = SpatialInhibition(B=1)

    def setup(self):
        """Rebuild the surround-inhibition kernel from its current parameters."""
        self.spatial_inhibition.setup()

    def forward(self, correlated_spatial_output, correlated_temporal_output):
        """Combine both correlations and return the inhibited, non-negative result."""
        correlated_spatiotemporal_output = correlated_spatial_output * correlated_temporal_output

        # Apply surround inhibition
        self.output = torch.clamp(self.spatial_inhibition(correlated_spatiotemporal_output), min=0)

        return self.output
139
+
140
+
@@ -0,0 +1,307 @@
1
+ from collections import deque
2
+ from typing import Iterable
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+
7
+ from .base_core import BaseCore
8
+ from ..util.create_kernel import create_2d_gaussian_kernel, create_gamma_kernel, create_spatial_inhibition_kernel
9
+
10
+
11
def compute_temporal_conv_inplace(buffer_refs: Iterable[torch.Tensor],
                                  time_kernel_tensor: torch.Tensor) -> torch.Tensor:
    """
    Compute a 1D temporal convolution over a sequence of buffered frames.

    The newest frame is paired with the first kernel weight, the second
    newest with the second weight, and so on; the weighted frames are summed.

    Args:
        buffer_refs (Iterable[torch.Tensor]): Historical frames, oldest first
            (the last element is the newest frame). All frames must have the
            same shape. Any iterable is accepted, e.g. a list, a deque or a
            generator.
        time_kernel_tensor (torch.Tensor): Temporal weights of shape [K] or
            [K, 1]. The first element corresponds to the newest frame.

    Returns:
        torch.Tensor: The weighted sum, with the shape of a single frame.
        Returns None if `buffer_refs` or the kernel is empty.

    Notes:
        - The input frames are never modified: the accumulator starts as a
          clone of the newest frame and is updated in place afterwards.
        - Truncation: if fewer frames than weights are available (warm-up
          phase of a stream), only the leading weights are used; if fewer
          weights than frames are available, the oldest frames are ignored.
    """
    # Materialise the references so that non-reversible iterables work too.
    frames = list(buffer_refs)
    if not frames:
        return None

    num_taps = min(len(frames), time_kernel_tensor.shape[0])

    # Convert the needed weights to Python numbers with a single transfer
    # instead of calling .item() once per tap inside the loop.
    weights = time_kernel_tensor[:num_taps].reshape(num_taps).tolist()

    temporal_conv_out = None

    # reversed(frames) walks from the newest frame to the oldest one.
    for t_tensor, weight_val in zip(reversed(frames), weights):
        if temporal_conv_out is None:
            # Start from a scaled copy of the newest frame.
            temporal_conv_out = t_tensor.clone().mul_(weight_val)
        else:
            # Accumulate older frames onto the copy in place.
            temporal_conv_out.add_(t_tensor, alpha=weight_val)

    return temporal_conv_out
60
+
61
+
62
class GaussianBlur(BaseCore):
    """
    Gaussian blur applied independently to every channel (pure PyTorch).
    """

    def __init__(self, kernel_size=3, sigma=1.0):
        """
        Constructor.

        Parameters:
        - kernel_size: Side length of the square filter kernel (int). Should be an odd number.
        - sigma: Standard deviation of the Gaussian distribution (float).
        """
        super().__init__()
        self.kernel_size = kernel_size
        self.sigma = sigma

        # Placeholder buffer; the real kernel is written by setup().
        self.register_buffer('blur_kernel', torch.empty(0))

        self.setup()

    def setup(self):
        """Build the kernel and store it with shape (1, 1, k, k)."""
        k = self.kernel_size
        self.blur_kernel.data = create_2d_gaussian_kernel(k, self.sigma).view(1, 1, k, k)

    def forward(self, x):
        """
        Apply the Gaussian filter to the input tensor.

        Parameters:
        - x: Input tensor of shape (B, C, H, W)

        Returns:
        - The blurred tensor, same spatial size as the input.
        """
        channels = x.shape[1]
        k = self.kernel_size

        # Broadcast the single-channel kernel to one filter per channel;
        # expand() creates a view rather than copying the data.
        per_channel_weight = self.blur_kernel.expand(channels, 1, k, k)

        # groups=channels: every channel is filtered on its own.
        return F.conv2d(x, per_channel_weight, padding='same', groups=channels)
106
+
107
+
108
class GammaDelay(BaseCore):
    """
    GammaDelay Class

    Temporal gamma filter implemented in pure PyTorch. The most recent frames
    are kept in a collections.deque that acts as a sliding window and are
    combined with precomputed gamma weights.
    """

    def __init__(self, order=1, tau=1.0, kernel_len=None):
        """
        Constructor method.

        Parameters:
            order (int): Order of the gamma filter (n). Values below 1 are raised to 1.
            tau (float): Time constant of the filter.
            kernel_len (int): Length of the filter kernel (T). Defaults to
                int(3 * tau), but never less than 1.
        """
        super().__init__()
        self.order = max(1, int(order))
        self.tau = tau
        # Default length 3 * tau covers most of the effective weights. It is
        # clamped to 1 so a small tau cannot produce an empty kernel and an
        # empty buffer, which would make forward() return None.
        self.kernel_len = max(1, int(3 * tau)) if kernel_len is None else kernel_len

        self.setup()

    def setup(self):
        """(Re)build the gamma kernel and start with an empty frame buffer."""
        # 1. Precompute the temporal weights of the gamma filter.
        kernel = create_gamma_kernel(self.order, self.tau, self.kernel_len)
        # Registered as a buffer so it follows the module across devices (e.g. .cuda()).
        self.register_buffer('gamma_kernel', kernel)

        # 2. Sliding window holding at most kernel_len frames.
        self.buffer = deque(maxlen=self.kernel_len)

    def reset_buffer(self):
        """
        Resets the internal buffer by clearing all stored frames.
        """
        self.buffer.clear()

    def forward(self, x, in_loop=False):
        """
        Processing method.
        Applies the gamma filter to the input tensor.

        Parameters:
        - x: Input tensor of shape (B, C, H, W)
        - in_loop (bool): If True, replaces the last frame instead of appending.
                          (Equivalent to original isInLoop/cover logic)

        Returns:
        - The weighted sum of the buffered frames.
        """

        if in_loop and len(self.buffer) > 0:
            # Replace the newest frame so the buffer length stays unchanged.
            self.buffer[-1] = x
        else:
            self.buffer.append(x)

        return compute_temporal_conv_inplace(self.buffer, self.gamma_kernel)
165
+
166
+
167
class GammaBandPassFilter(BaseCore):
    """
    GammaBandPassFilter: Temporal band-pass filter for ESTMD.

    Pure PyTorch implementation. The two gamma kernels are subtracted once in
    setup(), so a single deque buffer and a single temporal convolution per
    frame are needed.
    """

    def __init__(self,
                 order1=2, tau1=3.0,
                 order2=6, tau2=9.0,
                 kernel_len=None):
        """
        Constructor method.

        Parameters:
        - order1, tau1: Parameters for the excitatory (positive) Gamma filter.
        - order2, tau2: Parameters for the inhibitory (negative) Gamma filter.
        - kernel_len: Temporal length of the filter. If None, it is
          max(int(3 * tau1), int(3 * tau2)), but never less than 1.
        """
        super().__init__()

        self.order1 = max(1, int(order1))
        self.tau1 = tau1
        self.order2 = max(1, int(order2))
        self.tau2 = tau2

        # Maximum number of historical frames to keep. The automatic value is
        # clamped to 1 so small time constants cannot produce an empty kernel
        # and an empty buffer, which would make forward() return None.
        if kernel_len is not None:
            self.kernel_len = kernel_len
        else:
            self.kernel_len = max(1, int(3 * tau1), int(3 * tau2))

        self.in_loop = False  # by default frames are appended, not overwritten

        self.setup()

    def setup(self):
        """(Re)build the fused kernel and start with an empty frame buffer."""
        # 1. Both gamma kernels are created with the same length, kernel_len.
        k1 = create_gamma_kernel(self.order1, self.tau1, self.kernel_len)
        k2 = create_gamma_kernel(self.order2, self.tau2, self.kernel_len)

        # 2. Operator fusion: W_bandpass = W1 - W2. The difference is
        #    registered as a buffer so forward() convolves only once.
        bandpass_kernel = k1 - k2
        self.register_buffer('bandpass_kernel', bandpass_kernel)

        # 3. Single sliding window holding at most kernel_len frames.
        self.buffer = deque(maxlen=self.kernel_len)

    def reset_buffer(self):
        """
        Resets the internal buffer by clearing all stored frames.
        """
        self.buffer.clear()

    def forward(self, x):
        """
        Processing method.

        Parameters:
        - x: Input tensor of shape (B, C, H, W) or (C, H, W) or (H, W)

        Returns:
        - opt_tensor: Processed band-pass output tensor
        """

        # Record the newest frame.
        if self.in_loop and len(self.buffer) > 0:
            # Replace the newest frame so the buffer length stays unchanged.
            self.buffer[-1] = x
        else:
            self.buffer.append(x)

        return compute_temporal_conv_inplace(self.buffer, self.bandpass_kernel)
240
+
241
+
242
class SpatialInhibition(BaseCore):
    """
    Surround inhibition filter (pure PyTorch).

    Convolves every channel with one spatial inhibition kernel and keeps only
    the non-negative part of the response.
    """

    def __init__(self,
                 kernel_size=15,
                 sigma1=1.5,
                 sigma2=3.0,
                 e=1.0,
                 rho=0.0,
                 A=1.0,
                 B=3.0):
        """
        Constructor
        Initializes the SpatialInhibition module. All parameters are stored
        and forwarded unchanged to create_spatial_inhibition_kernel.

        Parameters:
        - kernel_size: Size of the filter kernel
        - sigma1: Standard deviation for the first Gaussian (Center)
        - sigma2: Standard deviation for the second Gaussian (Surround)
        - e: Exponent for the weighting of the second Gaussian
        - rho: Radius for circular integration / Center offset
        - A: Amplitude of the positive center
        - B: Amplitude of the negative surround
        """
        super().__init__()
        self.kernel_size = kernel_size
        self.sigma1 = sigma1
        self.sigma2 = sigma2
        self.e = e
        self.rho = rho
        self.A = A
        self.B = B

        # Placeholder buffer; the real kernel is written by setup().
        self.register_buffer('kernel', torch.empty(0))

        self.setup()

    def setup(self):
        """Build the kernel from the current parameters, stored as (1, 1, k, k)."""
        k = self.kernel_size
        inhibition_kernel = create_spatial_inhibition_kernel(
            k, self.sigma1, self.sigma2, self.e, self.rho, self.A, self.B)
        self.kernel.data = inhibition_kernel.view(1, 1, k, k)

    def forward(self, x):
        """
        Processing method
        Applies the surround inhibition filter to the input tensor.

        Parameters:
        - x: Input tensor of shape (B, C, H, W)

        Returns:
        - The filtered tensor after ReLU.
        """
        channels = x.shape[1]
        k = self.kernel_size

        # expand() maps the single kernel onto every channel as a view,
        # without copying the data.
        per_channel_weight = self.kernel.expand(channels, 1, k, k)

        # groups=channels: every channel is filtered on its own.
        filtered = F.conv2d(x, per_channel_weight, padding='same', groups=channels)
        return F.relu(filtered)