flashdet 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. flashdet/__init__.py +23 -0
  2. flashdet/analytics/__init__.py +19 -0
  3. flashdet/analytics/benchmark.py +156 -0
  4. flashdet/analytics/plots.py +226 -0
  5. flashdet/analytics/profiler.py +199 -0
  6. flashdet/cfg/__init__.py +5 -0
  7. flashdet/cfg/config.py +236 -0
  8. flashdet/cli.py +296 -0
  9. flashdet/data/__init__.py +27 -0
  10. flashdet/data/dataloader.py +127 -0
  11. flashdet/data/dataset.py +210 -0
  12. flashdet/data/prepare.py +872 -0
  13. flashdet/data/transforms.py +345 -0
  14. flashdet/engine/__init__.py +6 -0
  15. flashdet/engine/callbacks.py +152 -0
  16. flashdet/engine/exporter.py +141 -0
  17. flashdet/engine/predictor.py +319 -0
  18. flashdet/engine/trainer.py +585 -0
  19. flashdet/engine/validator.py +176 -0
  20. flashdet/losses/__init__.py +21 -0
  21. flashdet/losses/chunked_loss.py +116 -0
  22. flashdet/losses/focal_loss.py +162 -0
  23. flashdet/losses/iou_loss.py +117 -0
  24. flashdet/losses/kd_loss.py +234 -0
  25. flashdet/models/__init__.py +50 -0
  26. flashdet/models/assignment/__init__.py +3 -0
  27. flashdet/models/assignment/dsl_assigner.py +187 -0
  28. flashdet/models/backbone/__init__.py +3 -0
  29. flashdet/models/backbone/shufflenet.py +168 -0
  30. flashdet/models/detector.py +437 -0
  31. flashdet/models/head/__init__.py +9 -0
  32. flashdet/models/head/aux_head.py +117 -0
  33. flashdet/models/head/nanodet_head.py +536 -0
  34. flashdet/models/lora.py +869 -0
  35. flashdet/models/neck/__init__.py +11 -0
  36. flashdet/models/neck/conv_module.py +78 -0
  37. flashdet/models/neck/ghost_pan.py +304 -0
  38. flashdet/nn/__init__.py +4 -0
  39. flashdet/registry.py +96 -0
  40. flashdet/solutions/__init__.py +19 -0
  41. flashdet/solutions/analytics_dashboard.py +262 -0
  42. flashdet/solutions/distance_calculator.py +214 -0
  43. flashdet/solutions/heatmap.py +146 -0
  44. flashdet/solutions/live_inference.py +270 -0
  45. flashdet/solutions/object_counter.py +152 -0
  46. flashdet/solutions/parking_manager.py +241 -0
  47. flashdet/solutions/queue_manager.py +178 -0
  48. flashdet/solutions/region_counter.py +156 -0
  49. flashdet/solutions/security_alarm.py +221 -0
  50. flashdet/solutions/speed_estimator.py +144 -0
  51. flashdet/solutions/workout_monitor.py +203 -0
  52. flashdet/trackers/__init__.py +7 -0
  53. flashdet/trackers/bot_sort.py +420 -0
  54. flashdet/trackers/byte_tracker.py +312 -0
  55. flashdet/trackers/sort_tracker.py +270 -0
  56. flashdet/utils/__init__.py +22 -0
  57. flashdet/utils/box_utils.py +203 -0
  58. flashdet/utils/checkpoint.py +276 -0
  59. flashdet/utils/logger.py +106 -0
  60. flashdet/utils/metrics.py +229 -0
  61. flashdet/utils/torchtune_optim.py +300 -0
  62. flashdet/utils/visualization.py +253 -0
  63. flashdet-1.0.0.dist-info/METADATA +473 -0
  64. flashdet-1.0.0.dist-info/RECORD +68 -0
  65. flashdet-1.0.0.dist-info/WHEEL +5 -0
  66. flashdet-1.0.0.dist-info/entry_points.txt +2 -0
  67. flashdet-1.0.0.dist-info/licenses/LICENSE +21 -0
  68. flashdet-1.0.0.dist-info/top_level.txt +1 -0
flashdet/__init__.py ADDED
@@ -0,0 +1,23 @@
1
+ """FlashDet — Ultra-lightweight real-time object detection."""
2
+
3
+ __version__ = "1.0.0"
4
+
5
+ from flashdet.models.detector import FlashDet
6
+ from flashdet.models.lora import apply_lora, apply_qlora, merge_lora_weights
7
+ from flashdet.engine.trainer import Trainer
8
+ from flashdet.engine.validator import Validator
9
+ from flashdet.engine.predictor import Predictor
10
+ from flashdet.engine.exporter import Exporter
11
+ from flashdet.cfg import get_config
12
+ from flashdet.trackers import ByteTracker
13
+ from flashdet.solutions import ObjectCounter, SpeedEstimator, Heatmap, RegionCounter
14
+ from flashdet.analytics import Benchmark
15
+
16
+ __all__ = [
17
+ "FlashDet", "Trainer", "Validator", "Predictor", "Exporter",
18
+ "apply_lora", "apply_qlora", "merge_lora_weights", "get_config",
19
+ "ByteTracker",
20
+ "ObjectCounter", "SpeedEstimator", "Heatmap", "RegionCounter",
21
+ "Benchmark",
22
+ "__version__",
23
+ ]
flashdet/analytics/__init__.py ADDED
@@ -0,0 +1,19 @@
1
+ """Analytics — benchmarking, profiling and visualisation tools for FlashDet."""
2
+
3
+ from flashdet.analytics.benchmark import Benchmark
4
+ from flashdet.analytics.profiler import Profiler
5
+ from flashdet.analytics.plots import (
6
+ plot_training_curves,
7
+ plot_pr_curve,
8
+ plot_confusion_matrix,
9
+ plot_map_curve,
10
+ )
11
+
12
+ __all__ = [
13
+ "Benchmark",
14
+ "Profiler",
15
+ "plot_training_curves",
16
+ "plot_pr_curve",
17
+ "plot_confusion_matrix",
18
+ "plot_map_curve",
19
+ ]
flashdet/analytics/benchmark.py ADDED
@@ -0,0 +1,156 @@
1
+ """Benchmark — measure FlashDet model speed, size and parameter count."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import time
6
+ from pathlib import Path
7
+ from typing import Any, Dict, List, Optional, Union
8
+
9
+ import numpy as np
10
+
11
+
12
class Benchmark:
    """Benchmark FlashDet model speed and resource usage.

    Parameters
    ----------
    model_path : str | Path
        Path to a saved FlashDet checkpoint (``.pth`` / ``.onnx``).  A file
        with an ``.onnx`` suffix is run through ONNX Runtime; anything else
        is loaded with ``torch.load`` and must unpickle to a callable model.
    device : str
        ``"cuda"`` (optionally with an index such as ``"cuda:0"``) or
        ``"cpu"``.
    input_size : int | tuple[int, int]
        Network input resolution.  A single int is treated as
        ``(input_size, input_size)``.
    """

    def __init__(
        self,
        model_path: Union[str, Path],
        device: str = "cuda",
        input_size: Union[int, tuple] = 320,
    ):
        self.model_path = Path(model_path)
        self.device = device
        self.input_size = (input_size, input_size) if isinstance(input_size, int) else tuple(input_size)

        # The model is loaded lazily on first use and cached afterwards.
        self._model: Optional[Any] = None
        self._is_onnx: bool = self.model_path.suffix.lower() == ".onnx"

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def run(self, warmup: int = 10, iterations: int = 100) -> Dict[str, float]:
        """Run a speed benchmark.

        Parameters
        ----------
        warmup : int
            Number of untimed forward passes executed first.
        iterations : int
            Number of timed forward passes.  Must be at least 1.

        Returns
        -------
        dict
            ``{"fps": …, "latency_ms": …, "params": …, "model_size_mb": …}``

        Raises
        ------
        ValueError
            If ``iterations`` is smaller than 1.
        """
        # Validate before the (potentially slow) model load so a bad call
        # fails fast instead of ending in a ZeroDivisionError.
        if iterations < 1:
            raise ValueError(f"iterations must be >= 1, got {iterations}")

        model = self._load_model()
        dummy = self._make_dummy_input()

        bench = self._bench_onnx if self._is_onnx else self._bench_pytorch
        return bench(model, dummy, warmup, iterations)

    def compare(
        self,
        model_paths: List[Union[str, Path]],
        warmup: int = 10,
        iterations: int = 100,
    ) -> List[Dict[str, Any]]:
        """Compare multiple models side by side.

        Parameters
        ----------
        model_paths : list[str | Path]
            Paths to additional models to compare against the primary model.
        warmup, iterations : int
            Forwarded to :meth:`run` for every model (defaults match
            :meth:`run`).

        Returns
        -------
        list[dict]
            One result dict per model (primary model first).  Each dict is
            the output of :meth:`run` plus a ``"model"`` key holding the
            file name.
        """
        # Reuse this instance for the primary model so an already cached
        # model is not loaded a second time.
        benches = [self] + [
            Benchmark(Path(p), device=self.device, input_size=self.input_size)
            for p in model_paths
        ]
        results: List[Dict[str, Any]] = []
        for bm in benches:
            res = bm.run(warmup=warmup, iterations=iterations)
            res["model"] = bm.model_path.name
            results.append(res)
        return results

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _uses_cuda(self) -> bool:
        """Return True for ``"cuda"`` and indexed variants such as ``"cuda:0"``."""
        return str(self.device).lower().startswith("cuda")

    def _sync(self) -> None:
        """Wait for queued CUDA work so wall-clock timings are meaningful."""
        if self._uses_cuda():
            import torch
            torch.cuda.synchronize(torch.device(self.device))

    def _load_model(self):
        """Load the model on first call and return the cached object afterwards."""
        if self._model is not None:
            return self._model

        if self._is_onnx:
            import onnxruntime as ort
            providers = (
                ["CUDAExecutionProvider", "CPUExecutionProvider"]
                if self._uses_cuda()
                else ["CPUExecutionProvider"]
            )
            self._model = ort.InferenceSession(str(self.model_path), providers=providers)
        else:
            import torch
            # NOTE(security): torch.load unpickles arbitrary objects; only
            # benchmark checkpoints that come from a trusted source.
            loaded = torch.load(str(self.model_path), map_location=self.device)
            if not callable(loaded):
                # e.g. a state_dict or a training-checkpoint dict: report it
                # here instead of failing later with "'dict' is not callable".
                raise TypeError(
                    f"{self.model_path} did not load as a callable model "
                    f"(got {type(loaded).__name__}); a fully serialised module is required"
                )
            if hasattr(loaded, "eval"):
                loaded.eval()
            self._model = loaded
        return self._model

    def _make_dummy_input(self) -> np.ndarray:
        """Return a random float32 array of shape ``(1, 3, *input_size)``."""
        return np.random.rand(1, 3, *self.input_size).astype(np.float32)

    def _report(self, elapsed: float, iterations: int, params: int) -> Dict[str, float]:
        """Turn a total elapsed time into the result dict shared by both backends."""
        latency_ms = (elapsed / iterations) * 1000
        # Guard against a zero reading from a coarse clock.
        fps = iterations / elapsed if elapsed > 0 else float("inf")
        size_mb = self.model_path.stat().st_size / (1024 * 1024)
        return {
            "fps": round(fps, 2),
            "latency_ms": round(latency_ms, 3),
            "params": params,
            "model_size_mb": round(size_mb, 2),
        }

    def _bench_pytorch(self, model, dummy: np.ndarray, warmup: int, iterations: int) -> Dict[str, float]:
        """Time ``iterations`` forward passes of a PyTorch model."""
        import torch

        tensor = torch.from_numpy(dummy).to(self.device)

        with torch.no_grad():
            # Warmup
            for _ in range(warmup):
                model(tensor)
            self._sync()

            # Timed runs
            start = time.perf_counter()
            for _ in range(iterations):
                model(tensor)
            self._sync()
            elapsed = time.perf_counter() - start

        params = sum(p.numel() for p in model.parameters()) if hasattr(model, "parameters") else 0
        return self._report(elapsed, iterations, params)

    def _bench_onnx(self, session, dummy: np.ndarray, warmup: int, iterations: int) -> Dict[str, float]:
        """Time ``iterations`` runs of an ONNX Runtime session."""
        feed = {session.get_inputs()[0].name: dummy}

        for _ in range(warmup):
            session.run(None, feed)

        start = time.perf_counter()
        for _ in range(iterations):
            session.run(None, feed)
        elapsed = time.perf_counter() - start

        # ONNX doesn't expose param count easily
        return self._report(elapsed, iterations, 0)
flashdet/analytics/plots.py ADDED
@@ -0,0 +1,226 @@
1
+ """Plotting utilities — training curves, PR curves, mAP and confusion matrices."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Union
7
+
8
+ import numpy as np
9
+
10
+ if TYPE_CHECKING:
11
+ from matplotlib.figure import Figure
12
+
13
+
14
def _get_plt():
    """Import and return ``matplotlib.pyplot`` on demand.

    Importing here rather than at module level keeps matplotlib an optional
    dependency of the package.  The ``"Agg"`` backend is selected on every
    call before pyplot is imported.
    """
    import matplotlib

    matplotlib.use("Agg")

    from matplotlib import pyplot

    return pyplot
20
+
21
+
22
+ # ------------------------------------------------------------------
23
+ # Training curves
24
+ # ------------------------------------------------------------------
25
+
26
def plot_training_curves(
    log: Dict[str, List[float]],
    keys: Optional[Sequence[str]] = None,
    save_path: Optional[Union[str, Path]] = None,
    title: str = "Training Curves",
) -> "Figure":
    """Plot one or more scalar metrics from a training log dict.

    Each selected metric gets its own subplot, laid out in a single row.

    Parameters
    ----------
    log : dict[str, list[float]]
        ``{"loss": [...], "lr": [...], "mAP": [...], ...}``
    keys : sequence of str | None
        Which keys to plot.  *None* (or an empty sequence) plots everything.
    save_path : str | Path | None
        If given, save figure to this path.
    title : str
        Plot title.

    Returns
    -------
    matplotlib.figure.Figure
        The caller owns the figure and is responsible for closing it.

    Raises
    ------
    ValueError
        If there is nothing to plot (``log`` is empty and no keys were given).
    KeyError
        If a requested key is not present in ``log``.
    """
    keys = list(keys or log.keys())

    # Validate before creating a figure: a zero-column subplot grid fails
    # inside matplotlib with an unhelpful message, and a missing key would
    # otherwise abort half-way through plotting.
    if not keys:
        raise ValueError("plot_training_curves: `log` is empty, nothing to plot")
    missing = [key for key in keys if key not in log]
    if missing:
        raise KeyError(f"plot_training_curves: keys not found in log: {missing}")

    plt = _get_plt()
    n = len(keys)
    # squeeze=False keeps `axes` a 2-D array even for a single subplot.
    fig, axes = plt.subplots(1, n, figsize=(5 * n, 4), squeeze=False)

    for ax, key in zip(axes.flatten(), keys):
        ax.plot(log[key], linewidth=1.5)
        ax.set_title(key)
        ax.set_xlabel("Epoch")
        ax.grid(True, alpha=0.3)

    fig.suptitle(title, fontsize=14)
    fig.tight_layout()

    if save_path is not None:
        fig.savefig(str(save_path), dpi=150, bbox_inches="tight")

    return fig
69
+
70
+
71
+ # ------------------------------------------------------------------
72
+ # Precision-Recall curve
73
+ # ------------------------------------------------------------------
74
+
75
def plot_pr_curve(
    precisions: np.ndarray,
    recalls: np.ndarray,
    ap: Optional[float] = None,
    class_name: str = "all",
    save_path: Optional[Union[str, Path]] = None,
) -> "Figure":
    """Draw precision against recall for a single class.

    Parameters
    ----------
    precisions, recalls : np.ndarray
        1-D arrays of matched length; recall goes on the x-axis and
        precision on the y-axis.
    ap : float | None
        Average Precision.  When given it is appended to the legend entry,
        formatted to three decimals.
    class_name : str
        Legend label of the curve.
    save_path : str | Path | None
        Where to write the figure; nothing is written when *None*.

    Returns
    -------
    matplotlib.figure.Figure
    """
    pyplot = _get_plt()
    figure, axis = pyplot.subplots(figsize=(6, 5))

    legend_text = f"{class_name}" if ap is None else f"{class_name} (AP={ap:.3f})"
    axis.plot(recalls, precisions, linewidth=1.5, label=legend_text)

    # Axis cosmetics
    axis.set_title("Precision-Recall Curve")
    axis.set_xlabel("Recall")
    axis.set_ylabel("Precision")
    axis.set_xlim(0, 1)
    axis.set_ylim(0, 1.05)
    axis.grid(True, alpha=0.3)
    axis.legend(loc="lower left")

    figure.tight_layout()

    if save_path is None:
        return figure

    figure.savefig(str(save_path), dpi=150, bbox_inches="tight")
    return figure
118
+
119
+
120
+ # ------------------------------------------------------------------
121
+ # mAP over IoU thresholds
122
+ # ------------------------------------------------------------------
123
+
124
def plot_map_curve(
    iou_thresholds: np.ndarray,
    map_values: np.ndarray,
    save_path: Optional[Union[str, Path]] = None,
) -> "Figure":
    """Draw a bar chart of mAP per IoU threshold.

    Parameters
    ----------
    iou_thresholds : np.ndarray
        1-D array of IoU thresholds (e.g. 0.50, 0.55, …, 0.95); used as the
        bar positions on the x-axis.
    map_values : np.ndarray
        mAP value for each threshold; used as the bar heights.
    save_path : str | Path | None
        Where to write the figure; nothing is written when *None*.

    Returns
    -------
    matplotlib.figure.Figure
    """
    pyplot = _get_plt()
    figure, axis = pyplot.subplots(figsize=(6, 4))

    bar_style = {"width": 0.03, "color": "steelblue", "edgecolor": "white"}
    axis.bar(iou_thresholds, map_values, **bar_style)

    # Axis cosmetics
    axis.set_title("mAP @ IoU Thresholds")
    axis.set_xlabel("IoU Threshold")
    axis.set_ylabel("mAP")
    axis.set_ylim(0, 1.0)
    axis.grid(True, axis="y", alpha=0.3)

    figure.tight_layout()

    if save_path is None:
        return figure

    figure.savefig(str(save_path), dpi=150, bbox_inches="tight")
    return figure
158
+
159
+
160
+ # ------------------------------------------------------------------
161
+ # Confusion matrix
162
+ # ------------------------------------------------------------------
163
+
164
def plot_confusion_matrix(
    matrix: np.ndarray,
    class_names: Optional[List[str]] = None,
    normalize: bool = True,
    save_path: Optional[Union[str, Path]] = None,
    title: str = "Confusion Matrix",
) -> "Figure":
    """Plot a confusion matrix as a heatmap.

    Rows are labelled "True" and columns "Predicted"; every cell is annotated
    with its value.

    Parameters
    ----------
    matrix : np.ndarray
        Square confusion matrix of shape ``(n_classes, n_classes)``.  Any
        array-like that ``np.asarray`` accepts is allowed.
    class_names : list[str] | None
        Tick labels.  Auto-generated indices when *None*.
    normalize : bool
        Row-normalise the matrix before plotting.  Rows that sum to zero are
        left as all zeros.
    save_path : str | Path | None
        Optional save path.
    title : str
        Plot title.

    Returns
    -------
    matplotlib.figure.Figure
        The caller owns the figure and is responsible for closing it.

    Raises
    ------
    ValueError
        If ``matrix`` is not a non-empty square 2-D array, or if
        ``class_names`` does not have one entry per class.
    """
    matrix = np.asarray(matrix)

    # Validate up front: a non-square matrix would otherwise raise an
    # IndexError in the annotation loop (or be annotated only partially),
    # and an empty one fails in ``matrix.max()``.
    if matrix.ndim != 2 or matrix.shape[0] != matrix.shape[1] or matrix.shape[0] == 0:
        raise ValueError(
            f"matrix must be a non-empty square 2-D array, got shape {matrix.shape}"
        )
    n = matrix.shape[0]

    if class_names is None:
        class_names = [str(i) for i in range(n)]
    elif len(class_names) != n:
        raise ValueError(
            f"class_names has {len(class_names)} entries but matrix has {n} classes"
        )

    plt = _get_plt()

    if normalize:
        row_sums = matrix.sum(axis=1, keepdims=True)
        # Divide empty rows by 1 instead of 0 so they stay all-zero.
        row_sums = np.where(row_sums == 0, 1, row_sums)
        matrix = matrix.astype(np.float64) / row_sums

    # Grow the figure with the number of classes so labels stay legible.
    fig, ax = plt.subplots(figsize=(max(6, n * 0.6), max(5, n * 0.5)))
    im = ax.imshow(matrix, interpolation="nearest", cmap="Blues")
    fig.colorbar(im, ax=ax)

    ax.set_xticks(range(n))
    ax.set_yticks(range(n))
    ax.set_xticklabels(class_names, rotation=45, ha="right", fontsize=8)
    ax.set_yticklabels(class_names, fontsize=8)
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    ax.set_title(title)

    # Cells above half of the maximum get white text, the rest black.
    thresh = matrix.max() / 2
    for i, j in np.ndindex(n, n):
        val = matrix[i, j]
        text = f"{val:.2f}" if normalize else f"{int(val)}"
        ax.text(
            j, i, text, ha="center", va="center",
            color="white" if val > thresh else "black", fontsize=7,
        )

    fig.tight_layout()
    if save_path is not None:
        fig.savefig(str(save_path), dpi=150, bbox_inches="tight")

    return fig
flashdet/analytics/profiler.py ADDED
@@ -0,0 +1,199 @@
1
+ """Profiler — layer-wise latency and memory analysis for FlashDet models."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import time
6
+ from collections import OrderedDict
7
+ from pathlib import Path
8
+ from typing import Any, Dict, List, Optional, Union
9
+
10
+ import numpy as np
11
+
12
+
13
class Profiler:
    """Profile a FlashDet PyTorch model layer-by-layer.

    Registers a forward hook on every leaf ``nn.Module`` (a module without
    children) and records a timestamp each time one finishes its forward
    pass.  The time between two consecutive timestamps is attributed to the
    module that produced the later one, so work done outside of leaf modules
    (functional ops, tensor reshapes, ...) is charged to the next leaf that
    runs.

    Parameters
    ----------
    model_path : str | Path
        Path to a FlashDet ``.pth`` checkpoint.  It is loaded with
        ``torch.load`` and must unpickle to a callable model.
    device : str
        ``"cuda"`` (optionally with an index such as ``"cuda:0"``) or
        ``"cpu"``.
    input_size : int | tuple[int, int]
        Network input resolution.  A single int is treated as
        ``(input_size, input_size)``.
    """

    def __init__(
        self,
        model_path: Union[str, Path],
        device: str = "cuda",
        input_size: Union[int, tuple] = 320,
    ):
        self.model_path = Path(model_path)
        self.device = device
        self.input_size = (input_size, input_size) if isinstance(input_size, int) else tuple(input_size)

        # The model is loaded lazily on first use and cached afterwards.
        self._model: Optional[Any] = None

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def run(self, warmup: int = 5, iterations: int = 20) -> List[Dict[str, Any]]:
        """Profile the model and return per-layer statistics.

        Parameters
        ----------
        warmup : int
            Number of untimed forward passes executed first.
        iterations : int
            Number of timed forward passes; per-layer times are averaged
            over them.

        Returns
        -------
        list[dict]
            One dict per leaf module, in ``named_modules()`` order, with keys
            ``{"name", "type", "time_ms", "time_pct", "params", "output_shape"}``.
            Modules that never ran report ``time_ms == 0.0`` and
            ``output_shape is None``.
        """
        import torch

        model = self._load_model()
        dummy = torch.randn(1, 3, *self.input_size, device=self.device)
        on_cuda = self._uses_cuda()
        cuda_device = torch.device(self.device) if on_cuda else None

        def _sync() -> None:
            # Flush queued CUDA kernels so timestamps reflect finished work.
            if on_cuda:
                torch.cuda.synchronize(cuda_device)

        # Build the name -> module map once instead of once per layer.
        modules = dict(model.named_modules())
        leaves = [name for name, mod in modules.items() if next(mod.children(), None) is None]

        events: List[tuple] = []  # (layer name, timestamp) in call order
        shapes: Dict[str, tuple] = {}
        hooks = []

        def _make_hook(name: str):
            def _hook(module, inp, out):
                _sync()
                events.append((name, time.perf_counter()))
                # For tuple/list outputs only the first element is inspected.
                first = out[0] if isinstance(out, (tuple, list)) and out else out
                if isinstance(first, torch.Tensor):
                    shapes[name] = tuple(first.shape)
            return _hook

        layer_times: Dict[str, List[float]] = {name: [] for name in leaves}

        try:
            for name in leaves:
                hooks.append(modules[name].register_forward_hook(_make_hook(name)))

            with torch.no_grad():
                # Warmup
                for _ in range(warmup):
                    model(dummy)

                # Timed runs
                for _ in range(iterations):
                    events.clear()
                    _sync()
                    # Start stamp: lets the first layer get a real duration.
                    t_prev = time.perf_counter()
                    model(dummy)

                    # A module called several times in one pass accumulates
                    # all of its calls for this iteration.
                    spent: Dict[str, float] = {}
                    for name, stamp in events:
                        spent[name] = spent.get(name, 0.0) + (stamp - t_prev) * 1000
                        t_prev = stamp
                    for name, ms in spent.items():
                        layer_times[name].append(ms)
        finally:
            # Always detach the hooks, even if the forward pass raised;
            # otherwise they would pile up on the cached model.
            for h in hooks:
                h.remove()

        # Aggregate
        means = {name: (float(np.mean(vals)) if vals else 0.0) for name, vals in layer_times.items()}
        total_ms = sum(means.values())

        results: List[Dict[str, Any]] = []
        for name, mean_ms in means.items():
            module = modules[name]
            results.append({
                "name": name,
                "type": type(module).__name__,
                "time_ms": round(mean_ms, 4),
                "time_pct": round(mean_ms / total_ms * 100, 2) if total_ms > 0 else 0.0,
                "params": sum(p.numel() for p in module.parameters()),
                "output_shape": shapes.get(name),
            })

        return results

    def summary(self, warmup: int = 5, iterations: int = 20) -> str:
        """Return a human-readable profiling summary table.

        Runs :meth:`run` with the given arguments and formats one line per
        layer followed by a ``TOTAL`` line.  The parameter total is the sum
        over the listed leaf modules.
        """
        rows = self.run(warmup=warmup, iterations=iterations)

        lines = [
            f"{'Layer':<50} {'Type':<18} {'Time(ms)':>10} {'%':>7} {'Params':>10}",
            "-" * 100,
        ]
        lines.extend(
            f"{r['name']:<50} {r['type']:<18} {r['time_ms']:>10.4f} "
            f"{r['time_pct']:>6.2f}% {r['params']:>10,}"
            for r in rows
        )
        total_ms = sum(r["time_ms"] for r in rows)
        total_params = sum(r["params"] for r in rows)
        lines.append("-" * 100)
        lines.append(
            f"{'TOTAL':<50} {'':<18} {total_ms:>10.4f} {'100.00%':>7} {total_params:>10,}"
        )
        return "\n".join(lines)

    def memory_report(self) -> Dict[str, float]:
        """Return GPU memory usage summary (CUDA only).

        Runs a single forward pass and reads the CUDA allocator statistics
        of the configured device.  All values are zero when the profiler is
        not configured for CUDA or CUDA is unavailable.

        Returns
        -------
        dict
            ``{"allocated_mb", "reserved_mb", "peak_mb"}``
        """
        import torch

        if not (self._uses_cuda() and torch.cuda.is_available()):
            return {"allocated_mb": 0.0, "reserved_mb": 0.0, "peak_mb": 0.0}

        device = torch.device(self.device)
        model = self._load_model()
        dummy = torch.randn(1, 3, *self.input_size, device=device)
        # Reset so "peak_mb" reflects this forward pass, not earlier work.
        torch.cuda.reset_peak_memory_stats(device)

        with torch.no_grad():
            model(dummy)
        torch.cuda.synchronize(device)

        return {
            "allocated_mb": round(torch.cuda.memory_allocated(device) / 1024 ** 2, 2),
            "reserved_mb": round(torch.cuda.memory_reserved(device) / 1024 ** 2, 2),
            "peak_mb": round(torch.cuda.max_memory_allocated(device) / 1024 ** 2, 2),
        }

    # ------------------------------------------------------------------
    # Internal
    # ------------------------------------------------------------------

    def _uses_cuda(self) -> bool:
        """Return True for ``"cuda"`` and indexed variants such as ``"cuda:0"``."""
        return str(self.device).lower().startswith("cuda")

    def _load_model(self):
        """Load the checkpoint on first call and return the cached model afterwards."""
        if self._model is not None:
            return self._model
        import torch
        # NOTE(security): torch.load unpickles arbitrary objects; only
        # profile checkpoints that come from a trusted source.
        loaded = torch.load(str(self.model_path), map_location=self.device)
        if not callable(loaded):
            # e.g. a state_dict or a training-checkpoint dict: report it
            # here instead of failing later with "'dict' is not callable".
            raise TypeError(
                f"{self.model_path} did not load as a callable model "
                f"(got {type(loaded).__name__}); a fully serialised module is required"
            )
        if hasattr(loaded, "eval"):
            loaded.eval()
        self._model = loaded
        return self._model
flashdet/cfg/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from flashdet.cfg.config import get_config, ModelConfig, DataConfig, TrainConfig
2
+ from flashdet.cfg.config import Config as FlashDetConfig
3
+ from flashdet.cfg.config import load_yaml_config
4
+
5
+ __all__ = ["get_config", "load_yaml_config", "ModelConfig", "DataConfig", "TrainConfig", "FlashDetConfig"]