xttmp 2.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xttmp/__init__.py +1 -0
- xttmp/api/__init__.py +5 -0
- xttmp/api/evaluate.py +163 -0
- xttmp/api/get_visualize_handle.py +29 -0
- xttmp/api/instancing_model.py +35 -0
- xttmp/core/__init__.py +0 -0
- xttmp/core/apgstmd_core.py +188 -0
- xttmp/core/apgstmdv2_core.py +79 -0
- xttmp/core/base_core.py +36 -0
- xttmp/core/dstmd_core.py +213 -0
- xttmp/core/estmd_backbone.py +110 -0
- xttmp/core/estmd_core.py +356 -0
- xttmp/core/feedbackstmd_core.py +61 -0
- xttmp/core/fracstmd_core.py +98 -0
- xttmp/core/fstmd_core.py +15 -0
- xttmp/core/fstmdv2_core.py +42 -0
- xttmp/core/haarstmd_core.py +140 -0
- xttmp/core/math_operator.py +307 -0
- xttmp/core/stfeedbackstmd_core.py +233 -0
- xttmp/core/stmdplus_core.py +187 -0
- xttmp/core/stmdplusv2_core.py +82 -0
- xttmp/core/vstmd_core.py +420 -0
- xttmp/demo/evaluate_model.py +92 -0
- xttmp/demo/inference_gui.py +148 -0
- xttmp/demo/inference_gui_single_process.py +134 -0
- xttmp/demo/inference_image_stream.py +67 -0
- xttmp/demo/inference_video.py +66 -0
- xttmp/main.py +14 -0
- xttmp/model/__init__.py +13 -0
- xttmp/model/backbone.py +514 -0
- xttmp/model/facilitated_model.py +230 -0
- xttmp/model/feedback_model.py +271 -0
- xttmp/model/haarstmd.py +61 -0
- xttmp/model/vstmd.py +457 -0
- xttmp/util/__init__.py +0 -0
- xttmp/util/compute_module.py +402 -0
- xttmp/util/create_kernel.py +363 -0
- xttmp/util/evaluate_module.py +697 -0
- xttmp/util/iostream.py +660 -0
- xttmp-2.3.0.dist-info/METADATA +85 -0
- xttmp-2.3.0.dist-info/RECORD +45 -0
- xttmp-2.3.0.dist-info/WHEEL +5 -0
- xttmp-2.3.0.dist-info/entry_points.txt +2 -0
- xttmp-2.3.0.dist-info/licenses/LICENSE +201 -0
- xttmp-2.3.0.dist-info/top_level.txt +1 -0
xttmp/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
__all__ = ['api', 'model', 'util']
|
xttmp/api/__init__.py
ADDED
xttmp/api/evaluate.py
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
1
|
+
# demo_vidstream
|
|
2
|
+
import matplotlib.pyplot as plt
|
|
3
|
+
import numpy as np
|
|
4
|
+
import time
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
from . import instancing_model
|
|
8
|
+
from ..model import *
|
|
9
|
+
from ..util.evaluate_module import (get_ROC_curve_data, compute_AUC,
|
|
10
|
+
get_thres_recall_data, compute_AR,
|
|
11
|
+
get_P_R_curve_data, compute_AP, )
|
|
12
|
+
from ..util.compute_module import matrix_to_sparse_list
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def inference_task(modelName,
                   inputpath,
                   inputType = 'ImgstreamReader',
                   startFrame = 0,
                   endFrame = None,
                   device = 'cpu',
                   **kwargs):
    """
    Run a model over an input stream and collect the per-frame outputs.

    Parameters:
        modelName (str): Model name handed to ``instancing_model``.
        inputpath: Input location forwarded to the reader class.
        inputType (str): Name of the reader class; it is looked up in this
            module's globals.
        startFrame: Forwarded to the reader as its second argument.
        endFrame: Forwarded to the reader as its third argument.
        device (str): Device passed to ``instancing_model``. Any value other
            than 'cpu' converts each frame to a float tensor on that device
            and converts the model output back to NumPy arrays.
        **kwargs: Forwarded unchanged to the model's ``set_para``.

    Returns:
        tuple: ``(results, directions, totalRunningTime)``. ``results`` and
        ``directions`` both receive exactly one entry per frame read, so the
        same index addresses the same frame in both lists; a frame whose
        response is all zero contributes an empty list to each.
        ``totalRunningTime`` is the sum of the run times reported by the
        model's ``process`` method.

    Raises:
        ValueError: If ``inputType`` is not found in this module's globals.
    """
    # Instantiate the model on the requested device.
    objModel = instancing_model(modelName, device=device)

    # Resolve the reader class by name so callers can select the input type
    # with a plain string.
    inputModule = globals().get(inputType)
    if inputModule is None:
        raise ValueError(f"Unknown inputType: {inputType}")

    objIptStream = inputModule(inputpath, startFrame, endFrame)

    objNMS = MatrixNMS(15)

    # Apply the caller's parameters, then initialise the model.
    objModel.set_para(**kwargs)
    objModel.init_config()

    totalRunningTime = 0
    results = []
    directions = []

    while objIptStream.hasFrame:
        # Read the next frame from the input stream.
        grayImg, _ = objIptStream.get_next_frame()
        if device != 'cpu':
            # Add two leading singleton axes before handing the frame over.
            grayImg = torch.from_numpy(grayImg).to(device=device).float().unsqueeze(0).unsqueeze(0)

        # Perform inference using the model.
        result, runTime = objModel.process(grayImg)
        totalRunningTime += runTime

        # Post-processing: bring every output back to a NumPy array.
        if device != 'cpu':
            torch.cuda.synchronize()
            result = {k: v.squeeze(0).squeeze(0).cpu().numpy() for k, v in result.items()}

        response = result['response']
        if np.max(response) == 0:
            # Nothing responded on this frame. Record an empty entry in BOTH
            # lists; appending only to `results` would shift every later
            # entry of `directions` by one frame.
            results.append([])
            directions.append([])
            continue

        response = objNMS.nms(response)
        maxOpt = np.max(response)
        if maxOpt > 0:
            # Normalise so the strongest surviving response equals 1.
            response /= maxOpt
            responseListType = matrix_to_sparse_list(response.astype(np.float64))
        else:
            responseListType = []
        results.append(responseListType)

        # Direction values are sampled only where a response survived.
        # NOTE(review): entries are unpacked as (y, x, _) but the lookup is
        # direction[x, y] - confirm against matrix_to_sparse_list's ordering.
        direction = result['direction']
        if (direction is not None) and len(direction) and len(responseListType):
            directionListType = [[y, x, float(direction[x, y])] for y, x, _ in responseListType]
        else:
            directionListType = []
        directions.append(directionListType)

    return results, directions, totalRunningTime
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def evaluate_task(modelOpt, groundTruth, aucPara = 40, gTError = 1, startFrame = 0, endFrame = None, plotFigures=True):
    """
    Score model output against ground truth with three summary metrics.

    Parameters:
        modelOpt: Model output forwarded to the curve-data helpers.
        groundTruth: Ground truth forwarded to the curve-data helpers.
        aucPara: Upper bound of the FPPI range used for the ROC data, the
            AUC computation and the ROC plot's x-axis.
        gTError: Forwarded to every curve-data helper as ``gTError``.
        startFrame: Forwarded to every curve-data helper.
        endFrame: Forwarded to every curve-data helper.
        plotFigures (bool): When true, one matplotlib figure is created per
            curve and the figures are returned as well.

    Returns:
        tuple: ``(rocOfAUC, AR, AP, figHandle)`` when ``plotFigures`` is
        true, where ``figHandle`` maps 'ROC', 'mR' and 'PR' to the figures;
        otherwise ``(rocOfAUC, AR, AP)``.
    """

    def _draw_curve(xValues, yValues, xRange, xLabel, yLabel, title):
        # One figure per curve; the y-axis is always fixed to [0, 1].
        figure, axes = plt.subplots()
        axes.plot(xValues, yValues)
        axes.set_xlim(*xRange)
        axes.set_ylim(0, 1)
        axes.set_xlabel(xLabel)
        axes.set_ylabel(yLabel)
        axes.set_title(title)
        return figure

    # Arguments shared by all three curve-data helpers.
    commonArgs = dict(gTError=gTError, startFrame=startFrame, endFrame=endFrame)
    figures = {}

    # ---- ROC curve and its AUC ----
    RPIList, FPPIList, _ = get_ROC_curve_data(
        modelOpt, groundTruth, rangeOfFPPI=[0, aucPara], **commonArgs)
    rocOfAUC = compute_AUC(RPIList, FPPIList, rangeOfFPPI=[0, aucPara])
    if plotFigures:
        figures['ROC'] = _draw_curve(
            FPPIList, RPIList, (0, aucPara),
            'False Positive Rate (FPPI)', 'Recall (RPI)', 'ROC Curve')

    # ---- Threshold-recall curve and the average recall ----
    RPIList1, thresholdList1 = get_thres_recall_data(
        modelOpt, groundTruth, **commonArgs)
    AR = compute_AR(RPIList1, thresholdList1, rangeOfThreshold=[0.5, 1])
    if plotFigures:
        figures['mR'] = _draw_curve(
            thresholdList1, RPIList1, (0.5, 1),
            'Threshold', 'Recall', 'Threshold-Recall Curve')

    # ---- Precision-recall curve and the average precision ----
    rList2, pList2, _ = get_P_R_curve_data(
        modelOpt, groundTruth, intervalOfRecall=0.02, **commonArgs)
    AP = compute_AP(rList2, pList2)
    if plotFigures:
        figures['PR'] = _draw_curve(
            rList2, pList2, (0, 1),
            'Recall', 'Precision', 'P-R Curve')

    if not plotFigures:
        return rocOfAUC, AR, AP
    return rocOfAUC, AR, AP, figures
|
|
163
|
+
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
from ..util.iostream import Visualization
|
|
2
|
+
|
|
3
|
+
def get_visualize_handle(className=None,
                         showThreshold=None,
                         width = 8,
                         height = 5,
                         dpi = 100):
    """
    Build a Visualization object and create its figure handle.

    Parameters:
        className (str): Passed to ``Visualization`` as its first positional
            argument. When empty or None, ``Visualization`` is constructed
            without arguments.
        showThreshold: Passed as the second positional argument, but only
            when ``className`` is given and this value is not None.
        width: Forwarded to ``create_fig_handle``. Default is 8.
        height: Forwarded to ``create_fig_handle``. Default is 5.
        dpi: Forwarded to ``create_fig_handle``. Default is 100.

    Returns:
        Visualization: The object after ``create_fig_handle`` has been called.
    """
    # Collect only the positional arguments the caller actually supplied.
    ctorArgs = ()
    if className:
        if showThreshold is None:
            ctorArgs = (className,)
        else:
            ctorArgs = (className, showThreshold)

    handle = Visualization(*ctorArgs)
    handle.create_fig_handle(width=width, height=height, dpi=dpi)
    return handle
|
|
29
|
+
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
from ..model import * # Import all models
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def instancing_model(model_name, device = 'cpu', model_para=None):
    """
    Instantiate a model object based on the given model name.

    Parameters:
        model_name (str): Name of the model class; it is looked up in this
            module's globals.
        device: Target device handed to the model's ``to``. Default is 'cpu'.
        model_para: Accepted for forward compatibility; currently unused.

    Returns:
        The instantiated model, after ``setup()`` and ``to(device=device)``
        have been called on it.

    Raises:
        ValueError: If ``model_name`` does not name a callable in this
            module's globals.
    """
    model_cls = globals().get(model_name)
    if not callable(model_cls):
        # Fail here with a clear message rather than continuing to
        # model.setup() with no model object bound.
        raise ValueError(f"Class {model_name} not found.")

    model = model_cls()

    # model_para is intentionally not applied yet: no handling is implemented.

    model.setup()
    model.to(device=device)  # Move the model to the specified device

    return model
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
|
xttmp/core/__init__.py
ADDED
|
File without changes
|
|
@@ -0,0 +1,188 @@
|
|
|
1
|
+
from collections import deque
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch.nn import functional as F
|
|
5
|
+
|
|
6
|
+
from .base_core import BaseCore
|
|
7
|
+
from ..util.create_kernel import create_attention_kernel, create_prediction_kernel
|
|
8
|
+
from ..core.math_operator import compute_temporal_conv_inplace
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class AttentionModule(BaseCore):
    """
    Attention mechanism module of the ApgSTMD.

    Holds a bank of ``r * s`` attention kernels (``r`` entries in
    ``zeta_list`` times ``s`` entries in ``theta_list``) and uses them to
    add an attention term to the retina output.
    """

    def __init__(self):
        """
        Set the default parameters and build the kernel bank.
        """
        super().__init__()
        self.kernal_size = 17
        self.zeta_list = [2, 2.5, 3, 3.5]
        self.theta_list = torch.tensor(
            [0, torch.pi/4, torch.pi/2, 3*torch.pi/4]
        )
        self.alpha = 1
        # Registered empty; setup() fills in the actual kernel data.
        self.register_buffer('attention_kernel', torch.empty(0))
        self.setup()

    def setup(self):
        """
        (Re)build the attention kernel bank from the current parameters.
        """
        self.r, self.s = len(self.zeta_list), len(self.theta_list)

        kernel_grid = create_attention_kernel(
            self.kernal_size, self.zeta_list, self.theta_list)

        # Nested kernels -> one tensor, then flattened to conv2d weight
        # layout: (r*s output channels, 1 input channel, k, k).
        stacked = torch.stack([torch.stack(row) for row in kernel_grid])
        self.attention_kernel.data = stacked.reshape(
            self.r * self.s, 1, self.kernal_size, self.kernal_size)

    def forward(self, retina_opt, prediction_map):
        """
        Add the attention response to the retina output.

        When ``prediction_map`` is None the retina output is passed through
        unchanged. Otherwise the retina output is multiplied by the
        prediction map, convolved with every attention kernel, reduced with
        a min over the ``s`` axis followed by a max over the ``r`` axis, and
        added back scaled by ``alpha``.

        Returns:
            The result, which is also stored in ``self.Opt``.
        """
        if prediction_map is None:
            self.Opt = retina_opt
            return self.Opt

        # 1. Prepare the input.
        gated = retina_opt * prediction_map
        B, C, H, W = gated.shape

        # Fold batch and channel together so each channel is convolved
        # independently with all r*s kernels: (B*C, 1, H, W).
        flat = gated.reshape(B * C, 1, H, W)

        # 2. A single conv2d call evaluates every kernel: (B*C, r*s, H, W).
        responses = F.conv2d(flat, self.attention_kernel, padding='same')

        # 3. Split the kernel axis into (r, s) so each can be reduced.
        responses = responses.view(B * C, self.r, self.s, H, W)
        min_over_s = responses.min(dim=2).values      # (B*C, r, H, W)
        max_over_r = min_over_s.max(dim=1).values     # (B*C, H, W)

        # 4. Restore the original layout and combine with the input.
        attention_response = max_over_r.view(B, C, H, W)
        self.Opt = retina_opt + self.alpha * attention_response
        return self.Opt
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
class PredictionModule(BaseCore):
    """
    Prediction module of the ApgSTMD.

    Keeps bounded histories of prediction gains and prediction maps and uses
    them to produce the facilitated output for each new lobula input.
    """

    def __init__(self):
        """
        Set the default parameters, register the kernel buffers and run setup().
        """
        super().__init__()
        self.velocity = None      # filled in by setup() when left as None
        self.intDeltaT = 25
        self.sizeFilter = 25
        self.numFilter = 8
        self.zeta = 2
        self.eta = 2.5
        self.kappa = 0.02
        self.mu = 0.75
        self.beta = 1
        # Registered empty; setup() fills in the actual kernel data.
        self.register_buffer('time_attenuation_kernel', torch.empty(0))
        self.register_buffer('prediction_kernel', torch.empty(0))

        self.setup()

    def setup(self):
        """
        Build the prediction and time-attenuation kernels, then clear the history.
        """
        # The history length must be a positive integer.
        self.intDeltaT = max(int(self.intDeltaT), 1)

        if self.velocity is None:
            self.velocity = 25 / 4 / self.intDeltaT

        kernels = create_prediction_kernel(
            self.velocity, self.intDeltaT, self.sizeFilter,
            self.numFilter, self.zeta, self.eta)
        # One kernel per output channel, with a singleton input-channel axis.
        self.prediction_kernel.data = torch.stack(kernels).unsqueeze(1)

        self.time_attenuation_kernel = torch.exp(
            self.kappa * torch.arange(-self.intDeltaT, 1))

        self.reset()  # start with empty frame histories

    def reset(self):
        """
        Replace both history buffers with empty deques of length ``intDeltaT``.
        """
        self.prediction_gain_buffer = deque(maxlen=self.intDeltaT)
        self.prediction_map_buffer = deque(maxlen=self.intDeltaT)

    def forward(self, lobula_opt):
        """
        Compute the facilitated output and the current prediction map.

        Parameters:
            lobula_opt: Input tensor; axis 1 is used as the direction
                (channel) axis.

        Returns:
            tuple: ``(self.Opt, prediction_map)`` where ``prediction_map`` is
            the oldest entry still held in the prediction-map history.
        """
        num_direction = lobula_opt.shape[1]

        # Blend the input with the oldest stored gain once history exists;
        # broadcasting handles every direction channel at the same time.
        if self.prediction_gain_buffer:
            filter_input = (self.mu * lobula_opt
                            + (1 - self.mu) * self.prediction_gain_buffer[0])
        else:
            filter_input = lobula_opt

        # Grouped (depthwise) convolution: one conv2d call covers all
        # num_direction channels, so no per-channel loop is needed.
        prediction_gain = F.conv2d(
            filter_input,
            self.prediction_kernel,
            padding='same',
            groups=num_direction,
        )

        # Record the newest gain in the history.
        self.prediction_gain_buffer.append(prediction_gain)

        # ---- Prediction map: sum over the direction axis, axis kept ----
        tobe_prediction_map = torch.sum(prediction_gain, dim=1, keepdim=True)

        # ---- Facilitated STMD output ----
        temporal_conv_out = compute_temporal_conv_inplace(
            self.prediction_gain_buffer, self.time_attenuation_kernel)

        # All direction channels are combined in a single step.
        self.Opt = lobula_opt + self.beta * temporal_conv_out

        # ---- Memorizer update: keep pixels above 20% of the peak ----
        peak = torch.max(tobe_prediction_map)
        self.prediction_map_buffer.append(
            (tobe_prediction_map > peak * 2e-1).squeeze())

        return self.Opt, self.prediction_map_buffer[0]
|
|
187
|
+
|
|
188
|
+
|
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from cv2 import filter2D, BORDER_CONSTANT
|
|
3
|
+
|
|
4
|
+
from .base_core import BaseCore
|
|
5
|
+
from ..util.create_kernel import create_prediction_kernel
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class PredictionModule(BaseCore):
    """Prediction module for the ApgSTMD, filtering each channel with cv2."""

    def __init__(self):
        """Set the default parameters; the kernel itself is built by setup()."""
        super().__init__()

        # Tunable parameters
        self.velocity = 0.1       # velocity (v_opt)
        self.intDeltaT = 0        # delta time
        self.sizeFilter = 20      # size of each filter
        self.numFilter = 8        # number of filters
        self.zeta = 2
        self.eta = 2.5
        self.beta = 1             # weight of the prediction gain in the output

        # Internal state
        self.predictionKernel = None   # created by setup()

    def setup(self):
        """Create the prediction kernels from the current parameters."""
        self.predictionKernel = create_prediction_kernel(
            self.velocity, self.intDeltaT, self.sizeFilter,
            self.numFilter, self.zeta, self.eta)

    def forward(self, lobulaOpt):
        """
        Filter every channel and derive the facilitated output.

        Parameters:
            lobulaOpt: Sequence of 2-D arrays, one per direction channel.

        Returns:
            tuple: ``(facilitatedOpt, predictionMap)``. ``facilitatedOpt`` is
            a list with one array per channel and is also stored in
            ``self.Opt``; ``predictionMap`` is a boolean array marking the
            pixels whose summed gain exceeds 20% of the maximum.
        """
        imgH, imgW = lobulaOpt[0].shape

        # Filter each channel with the kernel that shares its index.
        predictionGain = [
            filter2D(channel, -1, self.predictionKernel[idx],
                     borderType=BORDER_CONSTANT)
            for idx, channel in enumerate(lobulaOpt)
        ]

        # Prediction map: accumulate every channel's gain into one image.
        summedGain = np.zeros((imgH, imgW))
        for gain in predictionGain:
            summedGain += gain

        # Facilitated STMD output: input plus the weighted gain, per channel.
        facilitatedOpt = [
            channel + self.beta * gain
            for channel, gain in zip(lobulaOpt, predictionGain)
        ]

        # Memorizer update: threshold into a logical matrix.
        predictionMap = (summedGain > np.max(summedGain) * 2e-1)

        self.Opt = facilitatedOpt
        return facilitatedOpt, predictionMap
|
xttmp/core/base_core.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class BaseCore(ABC, torch.nn.Module):
    """
    Common base for the core processing components.

    Combines ``abc.ABC`` with ``torch.nn.Module``; a subclass must implement
    ``forward`` before it can be instantiated.
    """

    def __init__(self):
        """
        Initialise the underlying module and clear the stored output.
        """
        super().__init__()

        self.output = None

    def setup(self, *args, **kwargs):
        """
        Optional initialization hook; does nothing unless overridden.
        """
        return None

    def reset(self):
        """
        Optional state-reset hook; does nothing unless overridden.
        """
        return None

    @abstractmethod
    def forward(self, *args, **kwargs):
        """
        Processing entry point; must be overridden by subclasses.
        """
        return None
|