xttmp 2.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. xttmp/__init__.py +1 -0
  2. xttmp/api/__init__.py +5 -0
  3. xttmp/api/evaluate.py +163 -0
  4. xttmp/api/get_visualize_handle.py +29 -0
  5. xttmp/api/instancing_model.py +35 -0
  6. xttmp/core/__init__.py +0 -0
  7. xttmp/core/apgstmd_core.py +188 -0
  8. xttmp/core/apgstmdv2_core.py +79 -0
  9. xttmp/core/base_core.py +36 -0
  10. xttmp/core/dstmd_core.py +213 -0
  11. xttmp/core/estmd_backbone.py +110 -0
  12. xttmp/core/estmd_core.py +356 -0
  13. xttmp/core/feedbackstmd_core.py +61 -0
  14. xttmp/core/fracstmd_core.py +98 -0
  15. xttmp/core/fstmd_core.py +15 -0
  16. xttmp/core/fstmdv2_core.py +42 -0
  17. xttmp/core/haarstmd_core.py +140 -0
  18. xttmp/core/math_operator.py +307 -0
  19. xttmp/core/stfeedbackstmd_core.py +233 -0
  20. xttmp/core/stmdplus_core.py +187 -0
  21. xttmp/core/stmdplusv2_core.py +82 -0
  22. xttmp/core/vstmd_core.py +420 -0
  23. xttmp/demo/evaluate_model.py +92 -0
  24. xttmp/demo/inference_gui.py +148 -0
  25. xttmp/demo/inference_gui_single_process.py +134 -0
  26. xttmp/demo/inference_image_stream.py +67 -0
  27. xttmp/demo/inference_video.py +66 -0
  28. xttmp/main.py +14 -0
  29. xttmp/model/__init__.py +13 -0
  30. xttmp/model/backbone.py +514 -0
  31. xttmp/model/facilitated_model.py +230 -0
  32. xttmp/model/feedback_model.py +271 -0
  33. xttmp/model/haarstmd.py +61 -0
  34. xttmp/model/vstmd.py +457 -0
  35. xttmp/util/__init__.py +0 -0
  36. xttmp/util/compute_module.py +402 -0
  37. xttmp/util/create_kernel.py +363 -0
  38. xttmp/util/evaluate_module.py +697 -0
  39. xttmp/util/iostream.py +660 -0
  40. xttmp-2.3.0.dist-info/METADATA +85 -0
  41. xttmp-2.3.0.dist-info/RECORD +45 -0
  42. xttmp-2.3.0.dist-info/WHEEL +5 -0
  43. xttmp-2.3.0.dist-info/entry_points.txt +2 -0
  44. xttmp-2.3.0.dist-info/licenses/LICENSE +201 -0
  45. xttmp-2.3.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,233 @@
1
+ from .base_core import BaseCore
2
+ from .math_operator import SpatialInhibition, GammaDelay
3
+ from . import estmd_backbone
4
+ from ..util.create_kernel import *
5
+ from ..util.compute_module import slice_matrix_holding_size
6
+
7
+
8
class Medulla(estmd_backbone.Medulla):
    """Medulla layer of the STFeedbackSTMD motion detection system.

    Extends the ESTMD backbone medulla with two additional gamma delays
    (``hPara5Mi1`` and ``hPara5Tm1``) whose outputs are forwarded, together
    with the Tm1/Tm2/Tm3 signals, to the lobula.
    """

    def __init__(self):
        """Declare the extra components; they are built in :meth:`setup`."""
        super().__init__()

        self.hPara5Mi1 = None   # GammaDelay applied directly to the Tm3 output
        self.hPara5Tm1 = None   # GammaDelay applied to the buffered Tm1 input
        self.cellTm1Ipt = None  # CircularList of recent Tm1 inputs

    def setup(self):
        """Build the delay filters and the shared Tm1 input history."""
        super().setup()

        self.hTm1 = GammaDelay(5, 25)
        self.hPara5Mi1 = GammaDelay(25, 30)
        self.hPara5Tm1 = GammaDelay(25, 30)

        self.cellTm1Ipt = CircularList()

        # hTm1 and hPara5Tm1 read from cellTm1Ipt, hPara5Mi1 from a raw frame;
        # presumably setup(False) disables a filter-internal history for the
        # former -- TODO confirm against GammaDelay.
        self.hTm1.setup(False)
        self.hPara5Mi1.setup()
        self.hPara5Tm1.setup(False)

        # NOTE(review): hTm1 also reads cellTm1Ipt but its kernel length is not
        # part of this max, while hPara5Mi1 (which never reads it) is.
        if not self.cellTm1Ipt.initLen:
            self.cellTm1Ipt.initLen = max(self.hPara5Mi1.lenKernel, self.hPara5Tm1.lenKernel)

        self.cellTm1Ipt.reset()

    def forward(self, MedullaIpt):
        """Process one frame.

        Returns (and stores in ``self.Opt``) the list
        ``[tm3Signal, tm1Para3Signal, mi1Para5Signal, tm2Signal,
        tm1Para5Signal, self.hPara5Mi1.tau]``, the order in which
        ``Lobula.forward`` unpacks it.
        """
        tm2Signal = self.hTm2.forward(MedullaIpt)
        tm3Signal = self.hTm3.forward(MedullaIpt)

        # Tm1 is built from the Tm2 output: buffer Tm2 and apply both delays.
        self.cellTm1Ipt.record_next(tm2Signal)
        tm1Para3Signal = self.hTm1.forward(self.cellTm1Ipt)
        tm1Para5Signal = self.hPara5Tm1.forward(self.cellTm1Ipt)

        # Mi1 is built from the Tm3 output.
        mi1Para5Signal = self.hPara5Mi1.forward(tm3Signal)

        varageout = [tm3Signal, tm1Para3Signal, mi1Para5Signal, tm2Signal, tm1Para5Signal, self.hPara5Mi1.tau]
        self.Opt = varageout
        return varageout
59
+
60
+
61
class Lobula(BaseCore):
    """Lobula layer: LPTC background-motion estimate followed by STMD detection."""

    def __init__(self):
        """Declare the sub-cells; they are instantiated in :meth:`setup`."""
        super().__init__()

        self.hSTMD = None  # Stmdcell instance
        self.hLPTC = None  # Lptcell instance

    def setup(self):
        """Instantiate and configure the STMD and LPTC sub-cells."""
        stmd = Stmdcell()
        lptc = Lptcell()
        self.hSTMD = stmd
        self.hLPTC = lptc

        stmd.setup()
        # The LPTC velocity history is as long as the STMD delay kernel, since
        # Stmdcell indexes the returned shift arrays once per buffered frame.
        lptc.setup(stmd.hGammaDelay.lenKernel)

    def forward(self, varagein):
        """Estimate background motion, then detect small targets for one frame.

        ``varagein`` is the list produced by ``Medulla.forward``.
        Returns ``(lobulaOpt, [])``; ``lobulaOpt`` is also stored in ``self.Opt``.
        """
        (tm3Signal, tm1Para3Signal, mi1Para5Signal,
         tm2Signal, tm1Para5Signal, tau5) = varagein

        # Lptcell.forward returns (fai, psi); Stmdcell.forward takes them in
        # the same order as (faiList, psiList).
        faiShift, psiShift = self.hLPTC.forward(
            tm3Signal, mi1Para5Signal, tm2Signal, tm1Para5Signal, tau5
        )
        lobulaOpt = self.hSTMD.forward(tm3Signal, tm1Para3Signal, faiShift, psiShift)

        self.Opt = lobulaOpt
        return lobulaOpt, []
96
+
97
+
98
class Stmdcell(BaseCore):
    """STMD cell: Tm3 x Tm1 correlation with motion-shifted temporal feedback,
    followed by spatial (surround) inhibition."""

    def __init__(self):
        """Declare parameters; components are built in :meth:`setup`."""
        super().__init__()

        self.hSubInhi = None      # SpatialInhibition component
        self.alpha = 0.1          # gain applied to the feedback signal
        self.gaussKernel = None   # kernel used for the correlationE term
        self.hGammaDelay = None   # temporal filter over the feedback history
        self.cellDPlusE = None    # CircularList of past correlationD + correlationE
        self.paraGaussKernel = {'size': 3, 'eta': 1.5}

    def setup(self):
        """Build the inhibition, delay filter, history buffer and kernel."""
        self.hSubInhi = SpatialInhibition()
        self.hGammaDelay = GammaDelay(6, 12)
        self.cellDPlusE = CircularList()

        self.hSubInhi.setup()
        self.hGammaDelay.setup()

        # The history holds one entry per tap of the gamma-delay kernel.
        if not self.cellDPlusE.initLen:
            self.cellDPlusE.initLen = self.hGammaDelay.lenKernel
        self.cellDPlusE.reset()

        # NOTE(review): if gaussian_filter is scipy.ndimage.gaussian_filter,
        # filtering an all-zero array gives an all-zero kernel and correlationE
        # is always 0 -- confirm which helper the wildcard import provides.
        self.gaussKernel = gaussian_filter(
            np.zeros((self.paraGaussKernel['size'], self.paraGaussKernel['size'])),
            self.paraGaussKernel['eta']
        )

    def forward(self, tm3Signal, tm1Signal, faiList, psiList):
        """Process one frame.

        Args:
            tm3Signal, tm1Signal: input maps that are correlated.
            faiList, psiList: per-history-slot shifts (indexed up to
                ``cellDPlusE.initLen - 1``) applied to the recorded history
                before temporal filtering.

        Returns:
            The spatially inhibited correlation map (also stored in ``self.Opt``).
        """
        historyLen = self.cellDPlusE.initLen
        convnIpt = [None] * historyLen

        # Walk the history backwards, starting at cellDPlusE.pointer
        # (presumably the most recent record -- TODO confirm). The pointer is
        # read once and decremented per slot; re-reading it inside the loop
        # would make every slot reuse the same record.
        pointer = self.cellDPlusE.pointer
        for idxT in range(historyLen - 1, -1, -1):
            record = self.cellDPlusE[pointer]
            if record is not None:
                convnIpt[idxT] = slice_matrix_holding_size(record, psiList[idxT], faiList[idxT])
            pointer = (pointer - 1) % historyLen

        feedbackSignal = self.hGammaDelay.forward_list(convnIpt)

        if feedbackSignal is not None:
            # Scale out of place: forward_list may return a buffer it reuses.
            feedbackSignal = feedbackSignal * self.alpha
            correlationD = np.maximum(tm3Signal - feedbackSignal, 0) * np.maximum(tm1Signal - feedbackSignal, 0)
        else:
            correlationD = np.maximum(tm3Signal, 0) * np.maximum(tm1Signal, 0)

        correlationE = filter2D(tm3Signal * tm1Signal, -1, self.gaussKernel, borderType=BORDER_CONSTANT)

        lateralInhiSTMDOpt = self.hSubInhi.forward(correlationD)

        # Record this frame so later frames can feed it back.
        self.cellDPlusE.record_next(correlationD + correlationE)

        self.Opt = lateralInhiSTMDOpt
        return lateralInhiSTMDOpt
162
+
163
+
164
class Lptcell(BaseCore):
    """Lobula Plate Tangential Cell (LPTC): estimates background motion.

    For every (distance ``bata``, direction ``theta``) pair the cell correlates
    its inputs with spatially shifted copies, derives a velocity index per
    frame from Gaussian tuning curves, and turns the velocity history into
    per-frame shift arrays ``(fai, psi)``.
    """

    def __init__(self):
        """Set the distance/direction grids; buffers are built in :meth:`setup`."""
        super().__init__()

        self.bataList = list(range(2, 19, 2))               # shift distances 2, 4, ..., 18
        self.thetaList = np.arange(0, 2 * np.pi, np.pi / 4)  # 8 directions, step pi/4
        self.velocity = None     # history of velocity indices, newest last
        self.tuningCurvef = None  # (len(bataList), len(thetaList) * 100 + 200)

    def setup(self, lenVelocity):
        """Allocate the velocity history and build the Gaussian tuning curves.

        Args:
            lenVelocity: number of velocity samples kept in the history.
        """
        self.velocity = np.zeros(lenVelocity)

        lenBataList = len(self.bataList)
        lenThetaList = len(self.thetaList)

        # 400-sample Gaussian (sigma = 50 samples, peak at index 200), scaled to peak 1.
        gaussianDistribution = np.exp(-0.5 * ((np.arange(-199, 201) - 1) / (100 / 2)) ** 2)
        gaussianDistribution /= np.max(gaussianDistribution)

        # With the default lists, row i carries the Gaussian peak at column
        # (i + 1) * 100; the first and last rows are truncated at the edges.
        self.tuningCurvef = np.zeros((lenBataList, lenThetaList * 100 + 200))
        self.tuningCurvef[0, :300] = gaussianDistribution[100:400]
        self.tuningCurvef[-1, -300:] = gaussianDistribution[:300]
        for idxBata in range(1, lenBataList - 1):
            start = (idxBata + 1) * 100 - 200
            self.tuningCurvef[idxBata, start:start + 400] = gaussianDistribution

    def forward(self, tm1Signal, tm2Signal, tm3Signal, mi1Signal, tau5):
        """Update the background-velocity estimate for one frame.

        Args:
            tm1Signal, tm2Signal, tm3Signal, mi1Signal: 2-D input maps. Note
                that ``Lobula.forward`` passes its signals positionally in a
                different order than these names suggest.
            tau5: unused; kept for interface compatibility.

        Returns:
            ``(fai, psi)``: arrays of length ``len(self.velocity)`` holding the
            accumulated velocity projected with cos / sin of ``maxTheta``.
            Also stored in ``self.Opt``.
        """
        lenBataList = len(self.bataList)
        lenThetaList = len(self.thetaList)
        sumLplcOptR = np.zeros((lenBataList, lenThetaList))

        for idBata, bata in enumerate(self.bataList):
            for idTheta, theta in enumerate(self.thetaList):
                shiftX = np.round(bata * np.cos(theta + np.pi / 2)).astype(int)
                shiftY = np.round(bata * np.sin(theta + np.pi / 2)).astype(int)
                shiftMi1Signal = slice_matrix_holding_size(mi1Signal, shiftY, shiftX)
                shiftTm1Signal = slice_matrix_holding_size(tm1Signal, shiftY, shiftX)

                ltlcOpt = tm3Signal * shiftMi1Signal + tm2Signal * shiftTm1Signal
                sumLplcOptR[idBata, idTheta] = np.sum(ltlcOpt)

        # Strongest response per distance and the direction index achieving it.
        firingRate = np.max(sumLplcOptR, axis=1)
        preferTheta = np.argmax(sumLplcOptR, axis=1)
        # NOTE(review): maxTheta is an index into self.thetaList, yet it is fed
        # to cos/sin below as if it were an angle in radians -- confirm whether
        # self.thetaList[maxTheta] was intended.
        maxTheta = np.max(preferTheta)

        # Shift the history left and append the distance index i whose tuning
        # curve row is closest (sum of squares) to firingRate[i].
        self.velocity = np.roll(self.velocity, -1)
        temp = [np.sum((firingRate[i] - self.tuningCurvef[i, :]) ** 2) for i in range(len(firingRate))]
        self.velocity[-1] = np.argmin(temp)

        # sumV[k] == sum(self.velocity[k:]): a reverse cumulative sum.
        sumV = np.cumsum(self.velocity[::-1])[::-1]

        fai = sumV * np.cos(maxTheta)
        psi = sumV * np.sin(maxTheta)

        self.Opt = [fai, psi]
        return fai, psi
232
+
233
+
@@ -0,0 +1,187 @@
1
+ import torch
2
+ from torch.nn import functional as F
3
+
4
+ from .base_core import BaseCore
5
+ from ..util.create_kernel import create_T1_kernels
6
+ from ..util.compute_module import AreaNMS
7
+ from ..util.compute_module import compute_response
8
+
9
+
10
class ContrastPathway(BaseCore):
    """Contrast pathway for ApgSTMD (PyTorch version).

    Convolves the input with one T1 kernel per orientation in ``self.theta``.
    """

    def __init__(self):
        """Set the kernel parameters; kernels are built in :meth:`setup`."""
        super().__init__()

        self.theta = torch.tensor([0, torch.pi/4, torch.pi/2, 3*torch.pi/4])  # kernel orientations
        self.alpha2 = 1.5
        self.eta = 3
        self.sizeT1 = 11  # kernel size passed to create_T1_kernels

        # Filled in by setup(); registered as a buffer so it moves with the module.
        self.register_buffer('T1_kernel', torch.empty(0))

    def setup(self):
        """Build the T1 kernels and store them in the ``T1_kernel`` buffer."""
        kernels = create_T1_kernels(len(self.theta), self.alpha2, self.eta, self.sizeT1)

        # create_T1_kernels may return a list of kernels or a NumPy array, but
        # ``.data`` only accepts a tensor, so normalise the result first.
        if isinstance(kernels, (list, tuple)) and kernels and all(isinstance(k, torch.Tensor) for k in kernels):
            kernels = torch.stack(list(kernels))
        elif not isinstance(kernels, torch.Tensor):
            kernels = torch.as_tensor(kernels, dtype=self.T1_kernel.dtype, device=self.T1_kernel.device)

        # F.conv2d needs weights shaped (out_channels, in_channels, kH, kW);
        # add the single input-channel dimension if it is missing.
        if kernels.dim() == 3:
            kernels = kernels.unsqueeze(1)

        self.T1_kernel.data = kernels

    def forward(self, x):
        """Apply the T1 kernels to ``x``.

        Args:
            x: tensor of shape (H, W) or (1, 1, H, W). A 2-D input is expanded
                to (1, 1, H, W), because ``F.conv2d`` rejects 2-D inputs.

        Returns:
            The convolution result, also stored in ``self.Opt``; assuming
            ``T1_kernel`` is (len(theta), 1, k, k), its shape is
            (1, len(theta), H, W).
        """
        if x.dim() == 2:
            x = x.unsqueeze(0).unsqueeze(0)

        self.Opt = F.conv2d(x, self.T1_kernel, padding='same')
        return self.Opt
40
+
41
+
42
class MushroomBody(BaseCore):
    """MushroomBody for STMDPlus (vectorised PyTorch version).

    Keeps the lobula response only at non-maximum-suppression peaks, tracks
    those peaks across frames by nearest-neighbour matching, records the
    contrast seen along each track in a fixed-size ring buffer, and zeroes
    peaks whose contrast history varies too little (max per-channel sample
    std below ``SDThres``).
    """

    def __init__(self):
        super().__init__()
        self.nms_size = 5        # window size passed to AreaNMS

        self.DBSCANDist = 5.0    # max distance for matching a track to a detection
        self.lenDBSCAN = 100     # ring-buffer length (frames of contrast history)
        self.SDThres = 5.0       # tracks with max std below this are suppressed

        self.torch_nms = None

        # Track state, one row per track, created lazily on the first detection.
        self.C = None            # number of contrast channels (inferred on first frame)
        self.trackID = None      # [N, D] coordinates of each track's latest detection
        self.trackInfo = None    # [N, C, lenDBSCAN] contrast ring buffer
        self.trackLens = None    # [N] number of valid samples per buffer
        self.trackPtr = None     # [N] next write position per buffer

    def setup(self):
        self.torch_nms = AreaNMS(self.nms_size)

    def forward(self, lobulaOpt, contrast_tensor):
        """Suppress detections whose contrast history is too stable.

        Args:
            lobulaOpt: lobula output tensor whose last two dims are (H, W).
            contrast_tensor: indexed as ``contrast_tensor[:, 0, y, x]``; dim 0
                is taken as the channel count ``C``. NOTE(review):
                ContrastPathway's documented output is (1, n_theta, H, W),
                which would give C == 1 here -- confirm the caller permutes it.

        Returns:
            ``lobulaOpt`` kept only where the NMS output is positive, with
            peaks on low-variance tracks zeroed. Also stored in ``self.Opt``.
        """
        device = lobulaOpt.device

        maxLobulaOpt = compute_response(lobulaOpt)
        nmsLobulaOpt = self.torch_nms(maxLobulaOpt)

        # Keep the response only at NMS peaks (positive NMS output).
        detections = nmsLobulaOpt > 0
        mushroomBodyOpt = lobulaOpt * detections

        newID = torch.nonzero(detections).float()
        if len(newID) == 0:
            self._reset_tracks()
            self.Opt = mushroomBodyOpt
            return mushroomBodyOpt

        curr_y, curr_x = newID[:, -2].long(), newID[:, -1].long()
        all_new_contrasts = contrast_tensor[:, 0, curr_y, curr_x]  # [C, M]

        if self.C is None:
            self.C = contrast_tensor.shape[0]

        matched_old, matched_new = self._match_tracks(newID)
        self._update_tracks(newID, all_new_contrasts, matched_old, matched_new, device)
        self._erase_static_tracks(mushroomBodyOpt, device)

        self.Opt = mushroomBodyOpt
        return mushroomBodyOpt

    def _reset_tracks(self):
        """Drop all track state."""
        self.trackID = None
        self.trackInfo = None
        self.trackLens = None
        self.trackPtr = None

    def _match_tracks(self, newID):
        """Greedily match existing tracks to their nearest new detection.

        Tracks are visited in order; each claims its nearest detection if it
        lies within ``DBSCANDist`` and has not already been claimed.

        Returns:
            (matched_old, matched_new): parallel lists of track / detection indices.
        """
        matched_old, matched_new = [], []
        if self.trackID is None or len(self.trackID) == 0:
            return matched_old, matched_new

        dist = torch.cdist(self.trackID[:, -2:], newID[:, -2:])
        nearest_dist, nearest_idx = torch.min(dist, dim=1)

        # The collision check is sequential, so it runs on the CPU.
        used_new = set()
        for i, (d, j) in enumerate(zip(nearest_dist.cpu().tolist(), nearest_idx.cpu().tolist())):
            if d <= self.DBSCANDist and j not in used_new:
                used_new.add(j)
                matched_old.append(i)
                matched_new.append(j)
        return matched_old, matched_new

    def _update_tracks(self, newID, all_new_contrasts, matched_old, matched_new, device):
        """Advance matched tracks, start tracks for unmatched detections and
        drop tracks that found no detection."""
        matched_set = set(matched_new)
        matched_old_ts = torch.tensor(matched_old, dtype=torch.long, device=device)
        matched_new_ts = torch.tensor(matched_new, dtype=torch.long, device=device)
        unmatched_new_ts = torch.tensor(
            [j for j in range(len(newID)) if j not in matched_set], dtype=torch.long, device=device
        )

        if matched_old:
            next_trackID = newID[matched_new_ts]
            next_trackInfo = self.trackInfo[matched_old_ts]
            next_trackLens = self.trackLens[matched_old_ts]
            next_trackPtr = self.trackPtr[matched_old_ts]

            # Write each matched detection's contrasts at its track's pointer.
            batch_idx = torch.arange(len(matched_old), device=device)
            next_trackInfo[batch_idx, :, next_trackPtr] = all_new_contrasts[:, matched_new_ts].T
            next_trackLens = torch.clamp(next_trackLens + 1, max=self.lenDBSCAN)
            next_trackPtr = (next_trackPtr + 1) % self.lenDBSCAN
        else:
            # Empty tensors so the concatenation below works uniformly.
            next_trackID = torch.empty((0, newID.shape[1]), device=device)
            next_trackInfo = torch.empty((0, self.C, self.lenDBSCAN), device=device)
            next_trackLens = torch.empty((0,), dtype=torch.long, device=device)
            next_trackPtr = torch.empty((0,), dtype=torch.long, device=device)

        if len(unmatched_new_ts) > 0:
            # New tracks start with one sample in slot 0; next write goes to slot 1.
            N_add = len(unmatched_new_ts)
            add_trackID = newID[unmatched_new_ts]
            add_trackInfo = torch.zeros((N_add, self.C, self.lenDBSCAN), device=device)
            add_trackInfo[:, :, 0] = all_new_contrasts[:, unmatched_new_ts].T
            add_trackLens = torch.ones(N_add, dtype=torch.long, device=device)
            add_trackPtr = torch.ones(N_add, dtype=torch.long, device=device)

            self.trackID = torch.cat([next_trackID, add_trackID], dim=0)
            self.trackInfo = torch.cat([next_trackInfo, add_trackInfo], dim=0)
            self.trackLens = torch.cat([next_trackLens, add_trackLens], dim=0)
            self.trackPtr = torch.cat([next_trackPtr, add_trackPtr], dim=0)
        else:
            self.trackID, self.trackInfo, self.trackLens, self.trackPtr = (
                next_trackID, next_trackInfo, next_trackLens, next_trackPtr
            )

    def _erase_static_tracks(self, mushroomBodyOpt, device):
        """Zero ``mushroomBodyOpt`` (in place) at tracks whose contrast is too stable.

        Only tracks with at least two samples are tested. The std is the
        sample std (N - 1 denominator) over the valid ring-buffer entries.
        """
        if self.trackID is None or len(self.trackID) == 0:
            return
        valid_mask = self.trackLens > 1
        if not valid_mask.any():
            return

        chk_info = self.trackInfo[valid_mask]   # [K, C, lenDBSCAN]
        chk_lens = self.trackLens[valid_mask]   # [K]
        chk_coords = self.trackID[valid_mask]   # [K, D]

        # Buffers fill from slot 0 and, once full, every slot is valid, so
        # slots [0, len) are exactly the valid data. The variance does not
        # depend on sample order, so ring-buffer order is irrelevant.
        time_idx = torch.arange(self.lenDBSCAN, device=device).view(1, 1, -1)
        data_mask = time_idx < chk_lens.view(-1, 1, 1)  # [K, 1, lenDBSCAN]

        n = chk_lens.unsqueeze(1).float()
        mean_val = (chk_info * data_mask).sum(dim=-1) / n
        diff_sq = ((chk_info - mean_val.unsqueeze(-1)) * data_mask) ** 2
        std_val = (diff_sq.sum(dim=-1) / (n - 1.0)).sqrt()  # [K, C]

        max_std, _ = std_val.max(dim=1)
        erase_mask = max_std < self.SDThres
        if erase_mask.any():
            erase_coords = chk_coords[erase_mask]
            mushroomBodyOpt[..., erase_coords[:, -2].long(), erase_coords[:, -1].long()] = 0
187
+
@@ -0,0 +1,82 @@
1
+ import numpy as np
2
+ from scipy.spatial.distance import cdist
3
+
4
+ from . import stmdplus_core
5
+
6
class MushroomBody(stmdplus_core.MushroomBody):
    """MushroomBody for STMDPlus v2 (NumPy implementation).

    Tracks NMS peaks across frames and zeroes peaks whose contrast history
    varies too little. Per-track histories are a Python list of
    (numContrast, T) arrays kept in the same order as the rows of ``trackID``.
    """

    def __init__(self):
        super().__init__()

    def setup(self):
        # Delegates to the parent setup.
        super().setup()

    def forward(self, lobulaOpt, contrastOpt):
        """Suppress detections on tracks with a low-variance contrast history.

        Args:
            lobulaOpt: 2-D array of lobula responses.
            contrastOpt: sequence of 2-D contrast maps, one per channel,
                indexed as ``contrastOpt[c][row, col]``.

        Returns:
            ``lobulaOpt`` kept only where the NMS output is positive, with
            peaks on tracks whose max per-channel std is below ``SDThres``
            zeroed. Also stored in ``self.Opt``.
        """
        # NOTE(review): the parent's setup() creates ``self.torch_nms``, not
        # ``self.hNMS``; confirm where ``hNMS`` is assigned, otherwise this
        # line raises AttributeError.
        nmsLobulaOpt = self.hNMS.nms(lobulaOpt)

        mushroomBodyOpt = lobulaOpt * (nmsLobulaOpt > 0)

        if np.max(nmsLobulaOpt) <= 0:
            self.trackID = None
            self.trackInfo = []
            self.Opt = mushroomBodyOpt
            return mushroomBodyOpt

        # The parent initialises trackInfo to None; this implementation needs
        # a list whenever there are no live tracks.
        if self.trackID is None:
            self.trackInfo = []

        newID = np.column_stack(np.where(nmsLobulaOpt > 0))  # [M, 2] (row, col)
        numContrast = len(contrastOpt)

        def contrast_at(row, col):
            # (numContrast, 1) column of contrast values at one pixel.
            return np.array([[contrastOpt[idCont][row, col]] for idCont in range(numContrast)])

        shouldAddNewID = np.ones(len(newID), dtype=bool)

        if self.trackID is not None:
            # Greedy matching: each track (in order) claims its nearest
            # detection if it is close enough and still unclaimed.
            keepTrack = np.zeros(len(self.trackID), dtype=bool)
            DD = cdist(self.trackID, newID)
            for idxI, idxJ in enumerate(np.argmin(DD, axis=1)):
                if DD[idxI, idxJ] <= self.DBSCANDist and shouldAddNewID[idxJ]:
                    self.trackID[idxI] = newID[idxJ]
                    self.trackInfo[idxI] = np.hstack((self.trackInfo[idxI], contrast_at(*newID[idxJ])))
                    keepTrack[idxI] = True
                    shouldAddNewID[idxJ] = False

            # Tracks that found no detection this frame are dropped.
            self.trackID = self.trackID[keepTrack]
            self.trackInfo = [info for info, keep in zip(self.trackInfo, keepTrack) if keep]

        oldTractNum = len(self.trackInfo)

        newTracks = newID[shouldAddNewID]
        if len(newTracks):
            # Always 2-D, so a single new track cannot leave trackID 1-D.
            self.trackID = newTracks if self.trackID is None else np.vstack((self.trackID, newTracks))
            self.trackInfo.extend(contrast_at(row, col) for row, col in newTracks)

        # Only continued tracks (>= 2 samples) are tested.
        # NOTE(review): np.std uses ddof=0 here, while the parent's torch
        # version uses the N - 1 sample std.
        for idx in range(oldTractNum):
            if np.max(np.std(self.trackInfo[idx], axis=1)) < self.SDThres:
                row, col = self.trackID[idx]
                mushroomBodyOpt[row, col] = 0

            # Keep at most lenDBSCAN samples of history per track.
            if self.trackInfo[idx].shape[1] > self.lenDBSCAN:
                self.trackInfo[idx] = self.trackInfo[idx][:, 1:]

        self.Opt = mushroomBodyOpt
        return mushroomBodyOpt