xttmp-2.3.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xttmp/__init__.py +1 -0
- xttmp/api/__init__.py +5 -0
- xttmp/api/evaluate.py +163 -0
- xttmp/api/get_visualize_handle.py +29 -0
- xttmp/api/instancing_model.py +35 -0
- xttmp/core/__init__.py +0 -0
- xttmp/core/apgstmd_core.py +188 -0
- xttmp/core/apgstmdv2_core.py +79 -0
- xttmp/core/base_core.py +36 -0
- xttmp/core/dstmd_core.py +213 -0
- xttmp/core/estmd_backbone.py +110 -0
- xttmp/core/estmd_core.py +356 -0
- xttmp/core/feedbackstmd_core.py +61 -0
- xttmp/core/fracstmd_core.py +98 -0
- xttmp/core/fstmd_core.py +15 -0
- xttmp/core/fstmdv2_core.py +42 -0
- xttmp/core/haarstmd_core.py +140 -0
- xttmp/core/math_operator.py +307 -0
- xttmp/core/stfeedbackstmd_core.py +233 -0
- xttmp/core/stmdplus_core.py +187 -0
- xttmp/core/stmdplusv2_core.py +82 -0
- xttmp/core/vstmd_core.py +420 -0
- xttmp/demo/evaluate_model.py +92 -0
- xttmp/demo/inference_gui.py +148 -0
- xttmp/demo/inference_gui_single_process.py +134 -0
- xttmp/demo/inference_image_stream.py +67 -0
- xttmp/demo/inference_video.py +66 -0
- xttmp/main.py +14 -0
- xttmp/model/__init__.py +13 -0
- xttmp/model/backbone.py +514 -0
- xttmp/model/facilitated_model.py +230 -0
- xttmp/model/feedback_model.py +271 -0
- xttmp/model/haarstmd.py +61 -0
- xttmp/model/vstmd.py +457 -0
- xttmp/util/__init__.py +0 -0
- xttmp/util/compute_module.py +402 -0
- xttmp/util/create_kernel.py +363 -0
- xttmp/util/evaluate_module.py +697 -0
- xttmp/util/iostream.py +660 -0
- xttmp-2.3.0.dist-info/METADATA +85 -0
- xttmp-2.3.0.dist-info/RECORD +45 -0
- xttmp-2.3.0.dist-info/WHEEL +5 -0
- xttmp-2.3.0.dist-info/entry_points.txt +2 -0
- xttmp-2.3.0.dist-info/licenses/LICENSE +201 -0
- xttmp-2.3.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,233 @@
|
|
|
1
|
+
from .base_core import BaseCore
|
|
2
|
+
from .math_operator import SpatialInhibition, GammaDelay
|
|
3
|
+
from . import estmd_backbone
|
|
4
|
+
from ..util.create_kernel import *
|
|
5
|
+
from ..util.compute_module import slice_matrix_holding_size
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class Medulla(estmd_backbone.Medulla):
    """Medulla layer of the STFeedbackSTMD motion-detection system.

    Extends ``estmd_backbone.Medulla`` (which provides ``hTm2`` and ``hTm3``)
    with extra gamma-delay branches whose outputs feed both the STMD cell and
    the LPTC of the lobula.
    """

    def __init__(self):
        # Constructor: declare the extra delay filters and the shared Tm1 buffer.
        super().__init__()

        self.hPara5Mi1 = None   # GammaDelay(25, 30) applied to the Tm3 signal
        self.hPara5Tm1 = None   # GammaDelay(25, 30) applied to the Tm1 input history
        self.cellTm1Ipt = None  # CircularList of Tm1 inputs, read by hTm1 and hPara5Tm1

    def setup(self):
        """Create and initialise the delay filters and the Tm1 input buffer."""
        super().setup()

        self.hTm1 = GammaDelay(5, 25)
        self.hPara5Mi1 = GammaDelay(25, 30)
        self.hPara5Tm1 = GammaDelay(25, 30)

        self.cellTm1Ipt = CircularList()

        # hTm1 and hPara5Tm1 are fed the shared cellTm1Ipt buffer in forward();
        # NOTE(review): setup(False) presumably disables each filter's own
        # input buffer -- confirm in GammaDelay.
        self.hTm1.setup(False)
        self.hPara5Mi1.setup()
        self.hPara5Tm1.setup(False)

        if not self.cellTm1Ipt.initLen:
            # The buffer must hold enough history for every filter that reads
            # it; hTm1 reads it too, so include its kernel length.
            self.cellTm1Ipt.initLen = max(
                self.hTm1.lenKernel,
                self.hPara5Mi1.lenKernel,
                self.hPara5Tm1.lenKernel,
            )

        self.cellTm1Ipt.reset()

    def forward(self, MedullaIpt):
        """Process one frame through the medulla.

        Args:
            MedullaIpt: input frame, passed unchanged to ``hTm2`` and ``hTm3``.

        Returns:
            list: ``[tm3Signal, tm1Para3Signal, mi1Para5Signal, tm2Signal,
            tm1Para5Signal, tau5]``, where ``tau5`` is ``self.hPara5Mi1.tau``.
            The same list is stored in ``self.Opt``.
        """
        # Tm2 and Tm3 components.
        tm2Signal = self.hTm2.forward(MedullaIpt)
        tm3Signal = self.hTm3.forward(MedullaIpt)

        # Tm1 is derived from the Tm2 output. Downstream, the LPTC pairs
        # tm2Signal with shifted Tm1, and the STMD cell pairs Tm3 with Tm1.
        # The buffer previously recorded tm3Signal, which made Tm1 a second
        # delayed copy of Tm3.
        self.cellTm1Ipt.record_next(tm2Signal)
        tm1Para3Signal = self.hTm1.forward(self.cellTm1Ipt)
        tm1Para5Signal = self.hPara5Tm1.forward(self.cellTm1Ipt)

        # Mi1 is derived from the Tm3 output.
        mi1Para5Signal = self.hPara5Mi1.forward(tm3Signal)

        varageout = [tm3Signal, tm1Para3Signal, mi1Para5Signal, tm2Signal, tm1Para5Signal, self.hPara5Mi1.tau]
        self.Opt = varageout
        return varageout
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class Lobula(BaseCore):
    """Lobula layer: LPTC background-motion estimate followed by the STMD cell."""

    def __init__(self):
        # Constructor: sub-components are created in setup().
        super().__init__()

        self.hSTMD = None  # Stmdcell
        self.hLPTC = None  # Lptcell

    def setup(self):
        """Create the STMD cell and the LPTC and initialise them."""
        self.hSTMD = Stmdcell()
        self.hLPTC = Lptcell()
        self.hSTMD.setup()
        # The LPTC velocity history has one entry per tap of the STMD
        # feedback delay kernel, so the two per-time-step lists line up.
        self.hLPTC.setup(self.hSTMD.hGammaDelay.lenKernel)

    def forward(self, varagein):
        """Estimate background motion, then run the feedback STMD cell.

        Args:
            varagein: the list returned by ``Medulla.forward``:
                ``[tm3Signal, tm1Para3Signal, mi1Para5Signal, tm2Signal,
                tm1Para5Signal, tau5]``.

        Returns:
            tuple: ``(lobulaOpt, [])``; ``lobulaOpt`` is also stored in
            ``self.Opt``.
        """
        tm3Signal, tm1Para3Signal, mi1Para5Signal, tm2Signal, tm1Para5Signal, tau5 = varagein

        # Pass the LPTC inputs by keyword. Lptcell.forward's positional order is
        # (tm1Signal, tm2Signal, tm3Signal, mi1Signal, tau5); the old positional
        # call (tm3, mi1, tm2, tm1, tau5) fed every channel into the wrong slot.
        fai, psi = self.hLPTC.forward(
            tm1Signal=tm1Para5Signal,
            tm2Signal=tm2Signal,
            tm3Signal=tm3Signal,
            mi1Signal=mi1Para5Signal,
            tau5=tau5,
        )

        # Stmdcell.forward takes (tm3Signal, tm1Signal, faiList, psiList).
        lobulaOpt = self.hSTMD.forward(tm3Signal, tm1Para3Signal, fai, psi)

        self.Opt = lobulaOpt
        return lobulaOpt, []
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
class Stmdcell(BaseCore):
    """STMD cell with time-delayed, motion-compensated feedback.

    Past outputs (correlationD + correlationE) are shifted by the accumulated
    background displacement, passed through a gamma delay, scaled by
    ``alpha`` and subtracted from the inputs before correlation.
    """

    def __init__(self):
        # Constructor: components are created in setup().
        super().__init__()

        self.hSubInhi = None     # SpatialInhibition applied to correlationD
        self.alpha = 0.1         # gain on the delayed feedback signal
        self.gaussKernel = None  # kernel used to blur tm3Signal * tm1Signal
        self.hGammaDelay = None  # GammaDelay(6, 12) over the feedback history
        self.cellDPlusE = None   # CircularList of past correlationD + correlationE
        self.paraGaussKernel = {'size': 3, 'eta': 1.5}

    def setup(self):
        """Create the inhibition, delay filter, history buffer and kernel."""
        self.hSubInhi = SpatialInhibition()
        self.hGammaDelay = GammaDelay(6, 12)
        self.cellDPlusE = CircularList()

        self.hSubInhi.setup()
        self.hGammaDelay.setup()

        if not self.cellDPlusE.initLen:
            self.cellDPlusE.initLen = self.hGammaDelay.lenKernel
        self.cellDPlusE.reset()

        # NOTE(review): this filters an all-zero array. If `gaussian_filter` is
        # a linear filter (e.g. scipy.ndimage.gaussian_filter), the kernel is
        # all zeros and correlationE is always 0 -- confirm this is intended.
        self.gaussKernel = gaussian_filter(
            np.zeros((self.paraGaussKernel['size'], self.paraGaussKernel['size'])),
            self.paraGaussKernel['eta']
        )

    def forward(self, tm3Signal, tm1Signal, faiList, psiList):
        """Apply motion-compensated feedback, correlation and inhibition.

        Args:
            tm3Signal, tm1Signal: the two correlated input maps.
            faiList, psiList: per-time-step displacements (one entry per history
                slot); entry ``k`` shifts the ``k``-th history frame, with
                ``psiList`` used as the first and ``faiList`` as the second
                shift argument of ``slice_matrix_holding_size``.

        Returns:
            The spatially inhibited correlation map (also stored in ``self.Opt``).
        """
        history = self.cellDPlusE
        lenHistory = history.initLen
        convnIpt = [None] * lenHistory

        # Walk the ring buffer backwards from `pointer` (presumably the most
        # recently recorded slot -- confirm in CircularList), filling convnIpt
        # from its last entry to its first. The pointer used to be re-read
        # inside the loop, so the decrement had no effect and every slot got
        # the same frame.
        pointer = history.pointer
        for idxT in range(lenHistory - 1, -1, -1):
            frame = history[pointer]
            if frame is not None:
                convnIpt[idxT] = slice_matrix_holding_size(frame, psiList[idxT], faiList[idxT])
            pointer = (pointer - 1) % lenHistory

        feedbackSignal = self.hGammaDelay.forward_list(convnIpt)

        if feedbackSignal is not None:
            # Scale out of place so the delay filter's returned array is not mutated.
            feedbackSignal = feedbackSignal * self.alpha
            correlationD = np.maximum(tm3Signal - feedbackSignal, 0) * np.maximum(tm1Signal - feedbackSignal, 0)
        else:
            correlationD = np.maximum(tm3Signal, 0) * np.maximum(tm1Signal, 0)

        correlationE = filter2D(tm3Signal * tm1Signal, -1, self.gaussKernel, borderType=BORDER_CONSTANT)

        lateralInhiSTMDOpt = self.hSubInhi.forward(correlationD)

        # Store this frame's contribution for future feedback.
        history.record_next(correlationD + correlationE)

        self.Opt = lateralInhiSTMDOpt
        return lateralInhiSTMDOpt
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
class Lptcell(BaseCore):
    """Lobula plate tangential cell (LPTC) estimating background motion.

    Correlates each input with shifted delayed inputs for every
    (distance, direction) pair, decodes a velocity index per frame, and
    returns accumulated displacements along the preferred direction.
    """

    def __init__(self):
        # Constructor: tuning curves and velocity history are built in setup().
        super().__init__()

        self.bataList = list(range(2, 19, 2))                # shift distances 2, 4, ..., 18
        self.thetaList = np.arange(0, 2 * np.pi, np.pi / 4)  # 8 directions, in radians
        self.velocity = None      # decoded velocity history, newest entry last
        self.tuningCurvef = None  # one Gaussian tuning curve per shift distance

    def setup(self, lenVelocity):
        """Allocate the velocity history and build the tuning curves.

        Args:
            lenVelocity: number of past velocity samples to keep.
        """
        self.velocity = np.zeros(lenVelocity)

        lenBataList = len(self.bataList)
        # Gaussian with sigma = 50 samples over offsets -200..199 (peak at index 200).
        gaussianDistribution = np.exp(-0.5 * ((np.arange(-199, 201) - 1) / (100 / 2)) ** 2)
        gaussianDistribution /= np.max(gaussianDistribution)

        # Curve i is centred at column 100 * (i + 1), with 100 columns of margin
        # after the last centre. The width used to be computed from
        # len(thetaList) (8 * 100 + 200), which only matched for the default 9
        # shift distances.
        self.tuningCurvef = np.zeros((lenBataList, (lenBataList + 1) * 100))
        self.tuningCurvef[0, :300] = gaussianDistribution[100:400]
        self.tuningCurvef[-1, -300:] = gaussianDistribution[:300]
        for idx in range(1, lenBataList - 1):
            centre = (idx + 1) * 100
            self.tuningCurvef[idx, centre - 200:centre + 200] = gaussianDistribution

    def forward(self, tm1Signal, tm2Signal, tm3Signal, mi1Signal, tau5):
        """Estimate the background displacement history.

        Args:
            tm1Signal, tm2Signal, tm3Signal, mi1Signal: medulla maps; each
                direction response is ``tm3 * shift(mi1) + tm2 * shift(tm1)``.
            tau5: accepted for interface compatibility; not used here.

        Returns:
            tuple: ``(fai, psi)`` with ``fai = sumV * cos(theta)`` and
            ``psi = sumV * sin(theta)``, where ``sumV[k] = sum(velocity[k:])``.
            Also stored in ``self.Opt`` as ``[fai, psi]``.
        """
        sumLplcOptR = np.zeros((len(self.bataList), len(self.thetaList)))

        for idBata, bata in enumerate(self.bataList):
            for idTheta, theta in enumerate(self.thetaList):
                shiftX = np.round(bata * np.cos(theta + np.pi / 2)).astype(int)
                shiftY = np.round(bata * np.sin(theta + np.pi / 2)).astype(int)
                shiftMi1Signal = slice_matrix_holding_size(mi1Signal, shiftY, shiftX)
                shiftTm1Signal = slice_matrix_holding_size(tm1Signal, shiftY, shiftX)

                ltlcOpt = tm3Signal * shiftMi1Signal + tm2Signal * shiftTm1Signal
                sumLplcOptR[idBata, idTheta] = np.sum(ltlcOpt)

        # Preferred direction per shift distance.
        firingRate = np.max(sumLplcOptR, axis=1)
        preferTheta = np.argmax(sumLplcOptR, axis=1)
        # argmax yields an index into thetaList; map it to radians before
        # cos/sin (the raw index 0..7 used to be treated as an angle).
        maxTheta = self.thetaList[np.max(preferTheta)]

        # Background velocity: shift the history and append this frame's estimate.
        # NOTE(review): this decodes to an index into bataList (argmin over rows);
        # confirm this is the intended population decoding.
        self.velocity = np.roll(self.velocity, -1)
        temp = [np.sum((firingRate[i] - self.tuningCurvef[i, :]) ** 2) for i in range(len(firingRate))]
        self.velocity[-1] = np.argmin(temp)

        # Reverse cumulative sum: sumV[k] = sum(velocity[k:]).
        sumV = np.cumsum(self.velocity[::-1])[::-1]

        fai = sumV * np.cos(maxTheta)
        psi = sumV * np.sin(maxTheta)

        self.Opt = [fai, psi]
        return fai, psi
|
|
232
|
+
|
|
233
|
+
|
|
@@ -0,0 +1,187 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch.nn import functional as F
|
|
3
|
+
|
|
4
|
+
from .base_core import BaseCore
|
|
5
|
+
from ..util.create_kernel import create_T1_kernels
|
|
6
|
+
from ..util.compute_module import AreaNMS
|
|
7
|
+
from ..util.compute_module import compute_response
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class ContrastPathway(BaseCore):
    """ContrastPathway for STMDPlus: a bank of oriented T1 filters (PyTorch)."""

    def __init__(self):
        """Set the filter parameters and register an empty ``T1_kernel`` buffer."""
        super().__init__()

        self.theta = torch.tensor([0, torch.pi/4, torch.pi/2, 3*torch.pi/4])  # one kernel per orientation
        self.alpha2 = 1.5
        self.eta = 3
        self.sizeT1 = 11  # spatial size passed to create_T1_kernels

        # A buffer so the kernel bank follows the module across .to(device).
        self.register_buffer('T1_kernel', torch.empty(0))

    def setup(self):
        """Build the T1 kernel bank and store it as a conv2d weight tensor."""
        kernels = create_T1_kernels(len(self.theta), self.alpha2, self.eta, self.sizeT1)

        # create_T1_kernels may return a tensor, a NumPy array or a list of
        # kernels. Assigning a non-tensor to `.data` raises, so normalise first.
        if isinstance(kernels, (list, tuple)):
            kernels = torch.stack([torch.as_tensor(k) for k in kernels])
        if isinstance(kernels, torch.Tensor):
            kernels = kernels.to(self.T1_kernel.device)
        else:
            kernels = torch.as_tensor(kernels, dtype=torch.get_default_dtype(),
                                      device=self.T1_kernel.device)
        if kernels.dim() == 3:
            # (out_channels, k, k) -> (out_channels, 1, k, k), as conv2d expects.
            kernels = kernels.unsqueeze(1)

        # Re-assigning a registered buffer name replaces the buffer.
        self.T1_kernel = kernels

    def forward(self, x):
        """Apply the T1 kernel bank.

        Args:
            x: input tensor, either (H, W) or (N, 1, H, W).

        Returns:
            The conv2d output with one channel per kernel (e.g. (1, 4, H, W)
            for an (H, W) or (1, 1, H, W) input); also stored in ``self.Opt``.
        """
        if x.dim() == 2:
            # conv2d needs batch and channel dims; add them for a bare (H, W) image.
            x = x.unsqueeze(0).unsqueeze(0)

        self.Opt = F.conv2d(x, self.T1_kernel, padding='same')
        return self.Opt
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class MushroomBody(BaseCore):
    """Mushroom-body stage of STMDPlus (vectorised PyTorch version).

    Keeps the NMS peaks of the lobula response, tracks them between frames by
    nearest-neighbour matching, and erases tracked peaks whose contrast history
    is too flat (maximum per-channel sample std below ``SDThres``).
    """

    def __init__(self):
        super().__init__()
        self.nms_size = 5       # window size passed to AreaNMS

        self.DBSCANDist = 5.0   # max distance for a track to claim a new detection
        self.lenDBSCAN = 100    # capacity of each track's contrast ring buffer
        self.SDThres = 5.0      # tracks whose max contrast std is below this are erased

        self.torch_nms = None

        # Track state as pre-allocated tensors; row n of each refers to track n.
        self.C = None           # number of contrast channels (inferred on the first frame)
        self.trackID = None     # [N, D] coordinates of each track's latest detection (float)
        self.trackInfo = None   # [N, C, lenDBSCAN] contrast ring buffer
        self.trackLens = None   # [N] number of valid samples in each buffer
        self.trackPtr = None    # [N] next write position in each buffer

    def setup(self):
        """Create the area non-maximum-suppression operator."""
        self.torch_nms = AreaNMS(self.nms_size)

    def _reset_tracks(self):
        """Drop all track state (used when a frame has no detections)."""
        self.trackID = None
        self.trackInfo = None
        self.trackLens = None
        self.trackPtr = None

    def _match_tracks(self, newID):
        """Greedily pair each existing track with its nearest new detection.

        A track is matched only if its nearest detection lies within
        ``DBSCANDist`` and no earlier track has already claimed that detection.

        Returns:
            tuple: ``(matched_old, matched_new)``, parallel lists of track and
            detection indices.
        """
        matched_old, matched_new = [], []
        if self.trackID is None or len(self.trackID) == 0:
            return matched_old, matched_new

        dist = torch.cdist(self.trackID[:, -2:], newID[:, -2:])
        min_dist, min_idx = torch.min(dist, dim=1)

        # Greedy conflict resolution is inherently sequential; do it on the CPU.
        claimed = set()
        for i, (d, j) in enumerate(zip(min_dist.tolist(), min_idx.tolist())):
            if d <= self.DBSCANDist and j not in claimed:
                claimed.add(j)
                matched_old.append(i)
                matched_new.append(j)
        return matched_old, matched_new

    def _update_tracks(self, newID, all_new_contrasts, matched_old, matched_new, device):
        """Extend matched tracks, start new ones for unmatched detections, drop the rest."""
        claimed = set(matched_new)
        unmatched_new = [j for j in range(len(newID)) if j not in claimed]

        if matched_old:
            matched_old_ts = torch.tensor(matched_old, dtype=torch.long, device=device)
            matched_new_ts = torch.tensor(matched_new, dtype=torch.long, device=device)

            next_trackID = newID[matched_new_ts]
            # Advanced indexing copies, so the writes below leave self.trackInfo intact.
            next_trackInfo = self.trackInfo[matched_old_ts]
            next_trackLens = self.trackLens[matched_old_ts]
            next_trackPtr = self.trackPtr[matched_old_ts]

            # Write each track's new contrast sample at its own ring-buffer position.
            rows = torch.arange(len(matched_old), device=device)
            next_trackInfo[rows, :, next_trackPtr] = all_new_contrasts[:, matched_new_ts].T
            next_trackLens = torch.clamp(next_trackLens + 1, max=self.lenDBSCAN)
            next_trackPtr = (next_trackPtr + 1) % self.lenDBSCAN
        else:
            # Empty tensors so the concatenation below works uniformly.
            next_trackID = torch.empty((0, newID.shape[1]), device=device)
            next_trackInfo = torch.empty((0, self.C, self.lenDBSCAN), device=device)
            next_trackLens = torch.empty((0,), dtype=torch.long, device=device)
            next_trackPtr = torch.empty((0,), dtype=torch.long, device=device)

        if unmatched_new:
            unmatched_new_ts = torch.tensor(unmatched_new, dtype=torch.long, device=device)
            N_add = len(unmatched_new)
            add_trackInfo = torch.zeros((N_add, self.C, self.lenDBSCAN), device=device)
            add_trackInfo[:, :, 0] = all_new_contrasts[:, unmatched_new_ts].T
            # A new track holds one sample; its next write goes to slot 1.
            add_trackLens = torch.ones(N_add, dtype=torch.long, device=device)
            add_trackPtr = torch.ones(N_add, dtype=torch.long, device=device)

            self.trackID = torch.cat([next_trackID, newID[unmatched_new_ts]], dim=0)
            self.trackInfo = torch.cat([next_trackInfo, add_trackInfo], dim=0)
            self.trackLens = torch.cat([next_trackLens, add_trackLens], dim=0)
            self.trackPtr = torch.cat([next_trackPtr, add_trackPtr], dim=0)
        else:
            self.trackID = next_trackID
            self.trackInfo = next_trackInfo
            self.trackLens = next_trackLens
            self.trackPtr = next_trackPtr

    def _erase_static_tracks(self, mushroomBodyOpt, device):
        """Zero output pixels of tracks whose contrast history barely varies."""
        if self.trackID is None or len(self.trackID) == 0:
            return

        # The sample std needs at least two samples.
        valid_mask = self.trackLens > 1
        if not valid_mask.any():
            return

        chk_info = self.trackInfo[valid_mask]   # [K, C, lenDBSCAN]
        chk_lens = self.trackLens[valid_mask]   # [K]
        chk_coords = self.trackID[valid_mask]   # [K, D]

        # Slots [0, len) hold valid samples: writes start at slot 0 and wrap
        # only once the buffer is full. The std does not depend on sample order.
        time_idx = torch.arange(self.lenDBSCAN, device=device).view(1, 1, -1)
        data_mask = time_idx < chk_lens.view(-1, 1, 1)  # [K, 1, lenDBSCAN]
        n = chk_lens.unsqueeze(1).float()               # [K, 1]

        # Masked sample standard deviation: sqrt(sum((x - mean)^2) / (n - 1)).
        mean_val = (chk_info * data_mask).sum(dim=-1) / n
        diff_sq = ((chk_info - mean_val.unsqueeze(-1)) * data_mask) ** 2
        std_val = (diff_sq.sum(dim=-1) / (n - 1.0)).sqrt()  # [K, C]

        erase_mask = std_val.max(dim=1).values < self.SDThres  # [K]
        if erase_mask.any():
            erase_coords = chk_coords[erase_mask]
            e_y = erase_coords[:, -2].long()
            e_x = erase_coords[:, -1].long()
            mushroomBodyOpt[..., e_y, e_x] = 0

    def forward(self, lobulaOpt, contrast_tensor):
        """Filter one frame of lobula output.

        Args:
            lobulaOpt: lobula response tensor, indexed by (y, x) in its last
                two dims.
            contrast_tensor: contrast maps, indexed as ``[c, 0, y, x]`` here
                (presumably shape (C, 1, H, W) -- confirm against the caller).

        Returns:
            ``lobulaOpt`` masked to its NMS peaks, with peaks on
            low-contrast-variance tracks set to zero. Also stored in
            ``self.Opt``.
        """
        device = lobulaOpt.device

        maxLobulaOpt = compute_response(lobulaOpt)
        nmsLobulaOpt = self.torch_nms(maxLobulaOpt)
        peak_mask = nmsLobulaOpt > 0

        # Keep only the NMS peaks. The previous mask (`nmsLobulaOpt == 0`) kept
        # every non-peak pixel and zeroed the detections, so the track-based
        # erasure below had no effect.
        mushroomBodyOpt = lobulaOpt * peak_mask

        newID = torch.nonzero(peak_mask).float()
        if len(newID) == 0:
            self._reset_tracks()
            self.Opt = mushroomBodyOpt
            return mushroomBodyOpt

        curr_y, curr_x = newID[:, -2].long(), newID[:, -1].long()
        all_new_contrasts = contrast_tensor[:, 0, curr_y, curr_x]  # [C, M]

        if self.C is None:
            self.C = contrast_tensor.shape[0]

        matched_old, matched_new = self._match_tracks(newID)
        self._update_tracks(newID, all_new_contrasts, matched_old, matched_new, device)
        self._erase_static_tracks(mushroomBodyOpt, device)

        self.Opt = mushroomBodyOpt
        return mushroomBodyOpt
|
|
187
|
+
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from scipy.spatial.distance import cdist
|
|
3
|
+
|
|
4
|
+
from . import stmdplus_core
|
|
5
|
+
|
|
6
|
+
class MushroomBody(stmdplus_core.MushroomBody):
    """NumPy variant of the STMDPlus mushroom body.

    Keeps the NMS peaks of the lobula output, tracks them across frames by
    nearest-neighbour matching, and erases tracked peaks whose contrast history
    has a maximum per-channel standard deviation below ``SDThres``.
    """

    def __init__(self):
        # Constructor: all parameters and state come from the parent class.
        super().__init__()

    def setup(self):
        # Delegates to the parent setup, which builds `self.torch_nms`.
        super().setup()

    def forward(self, lobulaOpt, contrastOpt):
        """Filter one frame of lobula output.

        Args:
            lobulaOpt: 2-D lobula response map.
            contrastOpt: sequence of 2-D contrast maps, one per contrast channel.

        Returns:
            ``lobulaOpt`` masked to its NMS peaks, with peaks on
            low-contrast-variance tracks set to zero. Also stored in
            ``self.Opt``.
        """
        # NOTE(review): `self.hNMS` is not created by the parent's setup(),
        # which only builds `self.torch_nms`; confirm where hNMS is assigned.
        nmsLobulaOpt = self.hNMS.nms(lobulaOpt)
        isPeak = nmsLobulaOpt > 0
        mushroomBodyOpt = lobulaOpt * isPeak

        if np.max(nmsLobulaOpt) <= 0:
            self.trackID = None
            self.trackInfo = []
            self.Opt = mushroomBodyOpt
            return mushroomBodyOpt

        newID = np.column_stack(np.where(isPeak))  # (M, 2) detection coordinates
        numContrast = len(contrastOpt)

        def sample_contrast(point):
            # (numContrast, 1) column of contrast values at `point`.
            return np.array([[contrastOpt[idCont][point[0], point[1]]] for idCont in range(numContrast)])

        # Start from an explicit empty 2-D state. The parent initialises
        # trackInfo to None, which made len()/append() fail on the first frame.
        # A lone new target used to be stored as a 1-D trackID, which breaks
        # cdist on the next frame.
        if self.trackID is None:
            self.trackID = np.empty((0, 2), dtype=newID.dtype)
            self.trackInfo = []

        isStale = np.ones(len(self.trackID), dtype=bool)  # True: track got no match this frame
        isNew = np.ones(len(newID), dtype=bool)           # True: detection starts a new track

        if len(self.trackID) > 0:
            DD = cdist(self.trackID, newID)
            nearest = np.argmin(DD, axis=1)
            for idxI, idxJ in enumerate(nearest):
                if DD[idxI, idxJ] <= self.DBSCANDist and isNew[idxJ]:
                    self.trackID[idxI] = newID[idxJ]
                    self.trackInfo[idxI] = np.hstack((self.trackInfo[idxI], sample_contrast(newID[idxJ])))
                    isStale[idxI] = False
                    isNew[idxJ] = False

            keep = ~isStale
            self.trackID = self.trackID[keep]
            self.trackInfo = [info for info, kept in zip(self.trackInfo, keep) if kept]

        oldTrackNum = len(self.trackInfo)

        idxNew = np.flatnonzero(isNew)
        if idxNew.size:
            self.trackID = np.vstack((self.trackID, newID[idxNew]))
            self.trackInfo.extend(sample_contrast(newID[kk]) for kk in idxNew)

        # Only continued tracks (length >= 2) are checked and trimmed.
        for idx in range(oldTrackNum):
            info = self.trackInfo[idx]
            if np.max(np.std(info, axis=1)) < self.SDThres:
                row, col = self.trackID[idx]
                mushroomBodyOpt[row, col] = 0

            if info.shape[1] > self.lenDBSCAN:
                self.trackInfo[idx] = info[:, 1:]

        self.Opt = mushroomBodyOpt
        return mushroomBodyOpt
|