xttmp 2.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. xttmp/__init__.py +1 -0
  2. xttmp/api/__init__.py +5 -0
  3. xttmp/api/evaluate.py +163 -0
  4. xttmp/api/get_visualize_handle.py +29 -0
  5. xttmp/api/instancing_model.py +35 -0
  6. xttmp/core/__init__.py +0 -0
  7. xttmp/core/apgstmd_core.py +188 -0
  8. xttmp/core/apgstmdv2_core.py +79 -0
  9. xttmp/core/base_core.py +36 -0
  10. xttmp/core/dstmd_core.py +213 -0
  11. xttmp/core/estmd_backbone.py +110 -0
  12. xttmp/core/estmd_core.py +356 -0
  13. xttmp/core/feedbackstmd_core.py +61 -0
  14. xttmp/core/fracstmd_core.py +98 -0
  15. xttmp/core/fstmd_core.py +15 -0
  16. xttmp/core/fstmdv2_core.py +42 -0
  17. xttmp/core/haarstmd_core.py +140 -0
  18. xttmp/core/math_operator.py +307 -0
  19. xttmp/core/stfeedbackstmd_core.py +233 -0
  20. xttmp/core/stmdplus_core.py +187 -0
  21. xttmp/core/stmdplusv2_core.py +82 -0
  22. xttmp/core/vstmd_core.py +420 -0
  23. xttmp/demo/evaluate_model.py +92 -0
  24. xttmp/demo/inference_gui.py +148 -0
  25. xttmp/demo/inference_gui_single_process.py +134 -0
  26. xttmp/demo/inference_image_stream.py +67 -0
  27. xttmp/demo/inference_video.py +66 -0
  28. xttmp/main.py +14 -0
  29. xttmp/model/__init__.py +13 -0
  30. xttmp/model/backbone.py +514 -0
  31. xttmp/model/facilitated_model.py +230 -0
  32. xttmp/model/feedback_model.py +271 -0
  33. xttmp/model/haarstmd.py +61 -0
  34. xttmp/model/vstmd.py +457 -0
  35. xttmp/util/__init__.py +0 -0
  36. xttmp/util/compute_module.py +402 -0
  37. xttmp/util/create_kernel.py +363 -0
  38. xttmp/util/evaluate_module.py +697 -0
  39. xttmp/util/iostream.py +660 -0
  40. xttmp-2.3.0.dist-info/METADATA +85 -0
  41. xttmp-2.3.0.dist-info/RECORD +45 -0
  42. xttmp-2.3.0.dist-info/WHEEL +5 -0
  43. xttmp-2.3.0.dist-info/entry_points.txt +2 -0
  44. xttmp-2.3.0.dist-info/licenses/LICENSE +201 -0
  45. xttmp-2.3.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,514 @@
1
+ from abc import ABC, abstractmethod
2
+ import warnings
3
+ import logging
4
+ import time
5
+
6
+ import torch
7
+
8
+ from ..core import estmd_core, estmd_backbone, fracstmd_core, dstmd_core
9
+ from ..util.compute_module import compute_response, compute_direction
10
+
11
+
12
class BaseModel(ABC, torch.nn.Module):
    """Base class for Small Target Motion Detector (STMD) models.

    Subclasses assemble their layer handles (``retina``, ``lamina``,
    ``medulla``, ``lobula``) in ``__init__`` and implement :meth:`forward`.

    A subclass may declare a private ``__paraMappingList`` class attribute that
    maps a short parameter name to a dotted attribute path inside the model
    (e.g. ``'retina.sigma'``), or to a tuple of such paths when one parameter
    drives several attributes. :meth:`set_para` and :meth:`print_para` read the
    mapping declared on the instance's own (concrete) class.

    ``__init_subclass__`` rejects subclasses that define ``reset``, ``setup``,
    ``print_para`` or ``set_para`` in their own namespace.
    """

    # Bind model parameters and their corresponding parameter pointers.
    # Name mangling stores this as ``_BaseModel__paraMappingList``; it only
    # documents the expected format and is never read for a subclass.
    __paraMappingList = {
        # here is just an example
        'sigma1': 'retina.gaussian_blur.sigma',
        'sigma2': 'lobula.gaussian_blur.sigma',
    }

    def __init_subclass__(cls, **kwargs):
        """Reject subclasses that redefine the protected base-class methods."""
        super().__init_subclass__(**kwargs)
        # Check whether the subclass's own namespace defines a protected name.
        # NOTE(review): this first check looks for 'reset' while its message
        # names 'reset_buffer'. It is kept as-is because other subclasses in the
        # package may legitimately override reset_buffer; confirm the intent.
        if 'reset' in cls.__dict__:
            raise TypeError(f"禁止重写: 子类 {cls.__name__} 不能覆盖 'reset_buffer' 方法。")
        if 'setup' in cls.__dict__:
            raise TypeError(f"禁止重写: 子类 {cls.__name__} 不能覆盖 'setup' 方法。")
        if 'print_para' in cls.__dict__:
            raise TypeError(f"禁止重写: 子类 {cls.__name__} 不能覆盖 'print_para' 方法。")
        if 'set_para' in cls.__dict__:
            raise TypeError(f"禁止重写: 子类 {cls.__name__} 不能覆盖 'set_para' 方法。")

    def __init__(self):
        """Create empty layer handles, the device-probe buffer and the output dict."""
        super().__init__()

        self.retina = None   # Handle for the retina layer
        self.lamina = None   # Handle for the lamina layer
        self.medulla = None  # Handle for the medulla layer
        self.lobula = None   # Handle for the lobula layer

        # Input frame rate; initialised to None and not read by BaseModel itself.
        self.input_fps = None

        # Zero-size buffer. Registered buffers follow the module on .to()/.cuda(),
        # so its ``.device`` tells process() where the model currently lives.
        self.register_buffer('_dummy_device', torch.empty(0))

        # Model output structure, reused (mutated in place) by every forward().
        self.model_output = {'response': None, 'direction': None}

    def setup(self):
        """Call ``setup()`` on every direct child module that defines it.

        Only ``self.children()`` (direct children) are visited; deeper modules
        are reached only if a child's own ``setup`` forwards the call.
        """
        for module in self.children():
            if hasattr(module, 'setup'):
                module.setup()

    def reset_buffer(self):
        """Call ``reset_buffer()`` on every direct child module that defines it.

        Only ``self.children()`` (direct children) are visited; deeper modules
        are reached only if a child's own ``reset_buffer`` forwards the call.
        """
        for module in self.children():
            if hasattr(module, 'reset_buffer'):
                module.reset_buffer()

    @abstractmethod
    def forward(self, img_tensor: torch.Tensor):
        """Forward one input frame through the model.

        Parameters:
            img_tensor: Input tensor, documented as shape [B, 1, H, W].
        Returns:
            model_output: The model output dict (``'response'``, ``'direction'``).
        """
        pass

    def process(self, img_tensor: torch.Tensor):
        """(Old API) Run :meth:`forward` and measure its wall-clock duration.

        Parameters:
            img_tensor: Input tensor for processing.
        Returns:
            A tuple ``(model_output, time_cost)``, where ``time_cost`` is the
            elapsed time of the forward call in seconds (``time.perf_counter``).
        """
        device = self._dummy_device.device
        # Compare the device *type*. Tensors on a GPU report an indexed device
        # such as 'cuda:0', which is not equal to torch.device('cuda'), so an
        # equality test would skip synchronisation and time only the launches.
        on_cuda = device.type == 'cuda'
        if on_cuda:
            torch.cuda.synchronize(device)  # Drain pending CUDA work before starting the timer
        start_time = time.perf_counter()
        model_output = self.forward(img_tensor)
        if on_cuda:
            torch.cuda.synchronize(device)  # Wait for the forward pass to finish before stopping the timer
        end_time = time.perf_counter()

        return model_output, end_time - start_time

    def _para_mapping(self) -> dict:
        """Return the ``__paraMappingList`` declared on this instance's class.

        Python mangles ``__paraMappingList`` in a class body to
        ``_<ClassName>__paraMappingList``, with leading underscores of the class
        name stripped. Returns an empty dict when the concrete class declares
        no mapping.
        """
        cls_name = type(self).__name__.lstrip('_')
        return getattr(self, f'_{cls_name}__paraMappingList', {})

    def print_para(self) -> None:
        """Log every mapped parameter together with the current value of its path(s)."""
        logger = logging.getLogger(__name__)

        para_list = self._para_mapping()

        if not para_list:
            logger.info(f'The parameters of <{self.__class__.__name__}> is empty.')
            return

        parts = [f'The parameters of <{self.__class__.__name__}> are:\n']
        for name, paths in para_list.items():
            parts.append(f' {name:6}')
            if isinstance(paths, tuple):
                # Continuation lines are indented to sit under the parameter name.
                pad = f'{" " * len(name):6}'
                last = len(paths) - 1
                for i, path in enumerate(paths):
                    val = self._get_nested_attr(path)
                    if i == 0:
                        parts.append(f' --> {path} = {val}\n')
                    elif i == last:
                        parts.append(f'{pad} \\---> {path} = {val}\n')
                    else:
                        parts.append(f'{pad} |---> {path} = {val}\n')
            else:
                val = self._get_nested_attr(paths)
                parts.append(f' --> {paths} = {val}\n')

        logger.info(''.join(parts))

    def set_para(self, **kwargs) -> None:
        """Set model parameters by their short names.

        Each keyword is looked up in the ``__paraMappingList`` of the instance's
        class. If the mapped value is a tuple of dotted paths, every path is set
        to the new value. Otherwise the single path is set.

        Parameters:
            **kwargs: Parameter name / new value pairs.

        Warns:
            UserWarning: For each keyword that is not in the mapping.

        Raises:
            AttributeError: If an intermediate object on a mapped path is missing.
        """
        para_list = self._para_mapping()

        for key, value in kwargs.items():
            if key in para_list:
                paths = para_list[key]
                if isinstance(paths, tuple):
                    for mapped_key in paths:
                        self._set_nested_attr(mapped_key, value)
                else:
                    self._set_nested_attr(paths, value)
            else:
                warnings.warn(f"Private variable '{key}' does not exist.", UserWarning, stacklevel=2)

    def _get_nested_attr(self, attr_str):
        """Resolve a dotted path such as ``'retina.gaussian_blur.sigma'`` from ``self``.

        Raises AttributeError if any segment is missing.
        """
        obj = self
        for attr in attr_str.split('.'):
            obj = getattr(obj, attr)
        return obj

    def _set_nested_attr(self, attr_str, value):
        """Set the attribute at a dotted path such as ``'retina.gaussian_blur.sigma'``.

        Intermediate segments must exist (AttributeError otherwise). The final
        segment is assigned with ``setattr``; it is not checked for existence.
        """
        obj = self
        attrs = attr_str.split('.')
        for attr in attrs[:-1]:
            obj = getattr(obj, attr)
        setattr(obj, attrs[-1], value)
179
+
180
+
181
class ESTMD(BaseModel):
    """ESTMD: Elementary Small Target Motion Detector.

    Ref:
        * Wiederman S D, Shoemaker P A, O'Carroll D C. A model for the detection of moving targets in visual clutter inspired by insect physiology[J]. PloS one, 2008, 3(7): e2784.
        * Wang H, Peng J, Yue S. A directionally selective small target motion detecting visual neural network in cluttered backgrounds[J]. IEEE transactions on cybernetics, 2018, 50(4): 1541-1555.

    Remark:
        The implementation and parameters in this code follow Ref [2].

    Parameters:
        Retina:
            - sigma1: Standard deviation of the retinal Gaussian blur (visual preprocessing). (Eq. 1)
        Lamina:
            - n1, tau1: Order and time constant of the first gamma band-pass delay. (Eq. 4)
            - n2, tau2: Order and time constant of the second gamma band-pass delay. (Eq. 4)
            - sigma2, sigma3: Standard deviations of the lamina lateral inhibition. (Eq. 8-9)
            - lambda1, lambda2: Lateral-inhibition intensity parameters. (Eq. 10-11)
        Medulla:
            - A, B: Second-inhibition parameters. (Eq. 20)
            - sigma4, sigma5: Spatial spread of the second inhibition. (Eq. 21)
            - e, rho: Non-linear interaction parameters of the second inhibition. (Eq. 21)
            - n3, tau3: Order and time constant of the Tm1 / Mi1 gamma delays. (Eq. 24)
    """

    # Short parameter name -> dotted attribute path(s) inside this model.
    __paraMappingList = {
        # retina
        'sigma1': 'retina.sigma',                                     # Eq. (1)
        # lamina
        'n1': 'lamina.gamma_BPF.order1',                              # Eq. (4)
        'tau1': 'lamina.gamma_BPF.tau1',
        'n2': 'lamina.gamma_BPF.order2',
        'tau2': 'lamina.gamma_BPF.tau2',
        'sigma2': 'lamina.spatial_inhibition.sigma1',                 # Eq. (8)(9)
        'sigma3': 'lamina.spatial_inhibition.sigma2',
        'lambda1': 'lamina.spatial_inhibition.lambda1',               # Eq. (10)(11)
        'lambda2': 'lamina.spatial_inhibition.lambda2',
        # medulla: the Tm2 and Tm3 inhibitions share one parameter set
        'A': ('medulla.tm2.spatial_inhibition.A',
              'medulla.tm3.spatial_inhibition.A'),                    # Eq. (20)
        'B': ('medulla.tm2.spatial_inhibition.B',
              'medulla.tm3.spatial_inhibition.B'),
        'sigma4': ('medulla.tm2.spatial_inhibition.sigma1',
                   'medulla.tm3.spatial_inhibition.sigma1'),          # Eq. (21)
        'sigma5': ('medulla.tm2.spatial_inhibition.sigma2',
                   'medulla.tm3.spatial_inhibition.sigma2'),
        'e': ('medulla.tm2.spatial_inhibition.e',
              'medulla.tm3.spatial_inhibition.e'),
        'rho': ('medulla.tm2.spatial_inhibition.rho',
                'medulla.tm3.spatial_inhibition.rho'),
        'n3': ('medulla.tm1.order', 'medulla.mi1.order'),             # Eq. (24)
        'tau3': ('medulla.tm1.tau', 'medulla.mi1.tau'),
    }

    def __init__(self):
        """Build the four ESTMD layers from the ``estmd_core`` components."""
        super().__init__()
        self.retina = estmd_core.Retina()
        self.lamina = estmd_core.Lamina()
        self.medulla = estmd_core.Medulla()
        self.lobula = estmd_core.Lobula()

    def forward(self, x):
        """Pass one frame through retina -> lamina -> medulla -> lobula.

        Only ``model_output['response']`` is written; ESTMD leaves
        ``'direction'`` untouched.
        """
        lamina_signal = self.lamina.forward(self.retina.forward(x))
        on_channel, off_channel = self.medulla.forward(lamina_signal)
        self.model_output['response'] = self.lobula.forward(on_channel, off_channel)
        return self.model_output
250
+
251
+
252
class ESTMDBackbone(BaseModel):
    """ESTMDBackbone: a non-directional backbone derived from ESTMD.

    Ref:
        * Wiederman S D, Shoemaker P A, O'Carroll D C. A model for the detection of moving targets in visual clutter inspired by insect physiology[J]. PloS one, 2008, 3(7): e2784.
        * Wang H, Peng J, Yue S. A directionally selective small target motion detecting visual neural network in cluttered backgrounds[J]. IEEE transactions on cybernetics, 2018, 50(4): 1541-1555.
    """

    # Short parameter name -> dotted attribute path(s) inside this model.
    # NOTE(review): the Tm1/Mi1 order is exposed as 'order3' here but as 'n3'
    # in ESTMD; kept unchanged because callers may depend on the key.
    __paraMappingList = {
        'sigma1': 'retina.sigma',
        'n1': 'lamina.order1',
        'tau1': 'lamina.tau1',
        'n2': 'lamina.order2',
        'tau2': 'lamina.tau2',
        'A': 'lobula.spatial_inhibition.A',
        'B': 'lobula.spatial_inhibition.B',
        'e': 'lobula.spatial_inhibition.e',
        'rho': 'lobula.spatial_inhibition.rho',
        'sigma4': 'lobula.spatial_inhibition.sigma1',
        'sigma5': 'lobula.spatial_inhibition.sigma2',
        'order3': ('medulla.tm1.order', 'medulla.mi1.order'),
        'tau3': ('medulla.tm1.tau', 'medulla.mi1.tau'),
    }

    def __init__(self):
        """Build the backbone: ESTMD retina plus ``estmd_backbone`` lamina/medulla/lobula."""
        super().__init__()
        self.retina = estmd_core.Retina()
        self.lamina = estmd_backbone.Lamina()
        self.medulla = estmd_backbone.Medulla()
        self.lobula = estmd_backbone.Lobula()

    def forward(self, img_tensor):
        """Pass one frame through the backbone and publish the lobula response.

        The lobula returns a pair. Its first element is stored on
        ``self.lobula_output`` and as ``model_output['response']``; the second
        element is discarded.
        """
        lamina_signal = self.lamina.forward(self.retina.forward(img_tensor))
        on_channel, off_channel = self.medulla.forward(lamina_signal)
        response, _discarded = self.lobula.forward(on_channel, off_channel)
        self.lobula_output = response

        self.model_output['response'] = self.lobula_output
        return self.model_output
306
+
307
+
308
class FracSTMD(ESTMDBackbone):
    """FracSTMD: Fractional-order Small Target Motion Detector.

    Ref:
        * Xu M, Wang H, Chen H, et al. A fractional-order visual neural model for small target motion detection[J]. Neurocomputing, 2023, 550: 126459.

    Description:
        FracSTMD uses a fractional-order operator to improve the precision of
        small target motion detection at low sampling frequencies. It captures
        the instantaneous luminance change and integrates it with memory
        information, with the instantaneous part dominating the integrated
        signal. The fast instantaneous response, supplemented by memory
        information, lets the model locate small moving targets accurately and
        robustly at low sampling frequencies.

    Parameters:
        Retina:
            - sigma1: Standard deviation of the retinal Gaussian blur, which suppresses noise and high-frequency artifacts before further processing. (Eq. 2)
        Lamina:
            - alpha: Order of the fractional-difference operator. (Eq. 5)
            - delta: Time constant of the fractional-order operator.
        Medulla:
            - n1: Order of the gamma delay in the Tm1 pathway. (Eq. 10)
            - tau1: Time constant of the Tm1 gamma delay.
        Lobula:
            - A, B: Amplitude parameters of the lateral inhibition. (Eq. 14)
            - e, rho, sigma2, sigma3: Parameters shaping the lateral inhibition. (Eq. 15)
    """

    # Short parameter name -> dotted attribute path inside this model.
    __paraMappingList = {
        # retina
        'sigma1': 'retina.sigma',                       # Eq. (2)
        # lamina
        'alpha': 'lamina.alpha',                        # Eq. (5)
        'delta': 'lamina.delta',
        # medulla
        'n1': 'medulla.tm1.order',                      # Eq. (10)
        'tau1': 'medulla.tm1.tau',
        # lobula
        'A': 'lobula.spatial_inhibition.A',             # Eq. (14)
        'B': 'lobula.spatial_inhibition.B',
        'e': 'lobula.spatial_inhibition.e',             # Eq. (15)
        'rho': 'lobula.spatial_inhibition.rho',
        'sigma2': 'lobula.spatial_inhibition.sigma1',
        'sigma3': 'lobula.spatial_inhibition.sigma2',
    }

    def __init__(self):
        """Build the ESTMD backbone, then specialise it for FracSTMD.

        The backbone lamina is replaced by the fractional-order lamina, and two
        inherited defaults are overridden: the Tm1 gamma order (100) and the
        lobula inhibition exponent ``e`` (1.8).
        """
        super().__init__()

        self.lamina = fracstmd_core.Lamina()

        tm1 = self.medulla.tm1
        tm1.order = 100
        inhibition = self.lobula.spatial_inhibition
        inhibition.e = 1.8
366
+
367
+
368
class DSTMD(BaseModel):
    """DSTMD: Directionally selective Small Target Motion Detector.

    Ref:
        * Wang H, Peng J, Yue S. A directionally selective small target motion detecting visual neural network in cluttered backgrounds[J]. IEEE transactions on cybernetics, 2018, 50(4): 1541-1555.

    Parameters:
        Retina:
            - sigma1: Standard deviation of the retinal Gaussian blur (noise pre-filter). (Eq. 1)
        Lamina:
            - n1, tau1: Order and time constant of the first gamma band-pass filter. (Eq. 4)
            - n2, tau2: Order and time constant of the second gamma band-pass filter. (Eq. 4)
            - sigma2, sigma3: Standard deviations of the lamina lateral inhibition. (Eq. 8-9)
            - lambda1, lambda2: Lateral-inhibition strength parameters. (Eq. 10-11)
        Medulla:
            - n4, tau4: Order and time constant of the Mi1 gamma delay. (Eq. 25)
            - n5, tau5: Order and time constant of the Tm1 gamma delay. (Eq. 25)
            - n6, tau6: Order and time constant of the second Tm1 gamma delay. (Eq. 25)
        Lobula:
            - alpha1: Lobula signal parameter related to directional selectivity. (Eq. 26)
            - A, B: Lateral-inhibition parameters. (Eq. 20)
            - e, rho: Non-linear lateral-inhibition parameters. (Eq. 21)
            - sigma4, sigma5: Spatial spread of the lateral inhibition. (Eq. 21)
            - sigma6, sigma7: Directional-inhibition parameters. (Eq. 29)
    """

    # Short parameter name -> dotted attribute path inside this model.
    __paraMappingList = {
        # retina
        'sigma1': 'retina.sigma',                         # Eq. (1)
        # lamina
        'n1': 'lamina.gamma_BPF.order1',                  # Eq. (4)
        'tau1': 'lamina.gamma_BPF.tau1',
        'n2': 'lamina.gamma_BPF.order2',
        'tau2': 'lamina.gamma_BPF.tau2',
        'sigma2': 'lamina.spatial_inhibition.sigma1',     # Eq. (8)(9)
        'sigma3': 'lamina.spatial_inhibition.sigma2',
        'lambda1': 'lamina.spatial_inhibition.lambda1',   # Eq. (10)(11)
        'lambda2': 'lamina.spatial_inhibition.lambda2',
        # medulla
        'n4': 'medulla.mi1_para4.order',                  # Eq. (25)
        'tau4': 'medulla.mi1_para4.tau',
        'n5': 'medulla.tm1_para5.order',
        'tau5': 'medulla.tm1_para5.tau',
        'n6': 'medulla.tm1_para6.order',
        'tau6': 'medulla.tm1_para6.tau',
        # lobula
        'alpha1': 'lobula.alpha1',                        # Eq. (26)
        'A': 'lobula.hLateralInhi.A',                     # Eq. (20)
        'B': 'lobula.hLateralInhi.B',
        'e': 'lobula.hLateralInhi.e',                     # Eq. (21)
        'rho': 'lobula.hLateralInhi.rho',
        'sigma4': 'lobula.hLateralInhi.sigma1',
        'sigma5': 'lobula.hLateralInhi.sigma2',
        'sigma6': 'lobula.hDirectionInhi.sigma1',         # Eq. (29)
        'sigma7': 'lobula.hDirectionInhi.sigma2',
    }

    def __init__(self):
        """Build the layers: ESTMD retina/lamina plus the DSTMD medulla/lobula."""
        super().__init__()
        self.retina = estmd_core.Retina()
        self.lamina = estmd_core.Lamina()
        self.medulla = dstmd_core.Medulla()
        self.lobula = dstmd_core.Lobula()

    def forward(self, x):
        """Pass one frame through the model and derive response and direction.

        The medulla yields four signals (Tm3, Mi1 with para4, Tm1 with para5,
        Tm1 with para6), which are fed to the lobula in that order.
        ``compute_response`` and ``compute_direction`` are then applied to the
        lobula output.
        """
        lamina_signal = self.lamina.forward(self.retina.forward(x))
        tm3, mi1_p4, tm1_p5, tm1_p6 = self.medulla.forward(lamina_signal)
        lobula_signal = self.lobula.forward(tm3, mi1_p4, tm1_p5, tm1_p6)

        output = self.model_output
        output['response'] = compute_response(lobula_signal)
        output['direction'] = compute_direction(lobula_signal)
        return output
463
+
464
+
465
class DSTMDBackbone(DSTMD):
    """DSTMDBackbone: a directional backbone derived from DSTMD.

    It reuses the whole DSTMD pipeline but swaps the lamina for the
    ``estmd_backbone`` lamina. The directional-inhibition parameters
    (sigma6/sigma7) are not exposed through the parameter mapping.

    Ref:
        * Wang H, Peng J, Yue S. A directionally selective small target motion detecting visual neural network in cluttered backgrounds[J]. IEEE transactions on cybernetics, 2018, 50(4): 1541-1555.
    """

    # Short parameter name -> dotted attribute path inside this model.
    __paraMappingList = {
        # retina
        'sigma1': 'retina.sigma',
        # lamina (backbone variant: orders/time constants live directly on it)
        'n1': 'lamina.order1',
        'tau1': 'lamina.tau1',
        'n2': 'lamina.order2',
        'tau2': 'lamina.tau2',
        # medulla
        'n4': 'medulla.mi1_para4.order',
        'tau4': 'medulla.mi1_para4.tau',
        'n5': 'medulla.tm1_para5.order',
        'tau5': 'medulla.tm1_para5.tau',
        'n6': 'medulla.tm1_para6.order',
        'tau6': 'medulla.tm1_para6.tau',
        # lobula
        'alpha1': 'lobula.alpha1',
        'A': 'lobula.hLateralInhi.A',
        'B': 'lobula.hLateralInhi.B',
        'e': 'lobula.hLateralInhi.e',
        'rho': 'lobula.hLateralInhi.rho',
        'sigma4': 'lobula.hLateralInhi.sigma1',
        'sigma5': 'lobula.hLateralInhi.sigma2',
    }

    def __init__(self):
        """Build a DSTMD, then replace its lamina with the backbone lamina."""
        super().__init__()
        backbone_lamina = estmd_backbone.Lamina()
        self.lamina = backbone_lamina
508
+
509
+
510
+
511
+
512
+
513
+
514
+