flashseg 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. flashseg/__init__.py +20 -0
  2. flashseg/analytics/__init__.py +6 -0
  3. flashseg/analytics/benchmark.py +65 -0
  4. flashseg/analytics/profiler.py +42 -0
  5. flashseg/cfg/__init__.py +3 -0
  6. flashseg/cfg/config.py +113 -0
  7. flashseg/cli.py +216 -0
  8. flashseg/data/__init__.py +4 -0
  9. flashseg/data/dataset.py +93 -0
  10. flashseg/data/transforms.py +34 -0
  11. flashseg/engine/__init__.py +6 -0
  12. flashseg/engine/exporter.py +68 -0
  13. flashseg/engine/predictor.py +115 -0
  14. flashseg/engine/trainer.py +174 -0
  15. flashseg/engine/validator.py +65 -0
  16. flashseg/losses/__init__.py +3 -0
  17. flashseg/losses/seg_losses.py +75 -0
  18. flashseg/models/__init__.py +3 -0
  19. flashseg/models/backbone/__init__.py +3 -0
  20. flashseg/models/backbone/shufflenetv2.py +111 -0
  21. flashseg/models/build.py +45 -0
  22. flashseg/models/head/__init__.py +3 -0
  23. flashseg/models/head/seg_head.py +48 -0
  24. flashseg/models/neck/__init__.py +3 -0
  25. flashseg/models/neck/fpn.py +45 -0
  26. flashseg/nn/__init__.py +5 -0
  27. flashseg/nn/blocks.py +97 -0
  28. flashseg/solutions/__init__.py +8 -0
  29. flashseg/solutions/area_calculator.py +34 -0
  30. flashseg/solutions/background_remover.py +47 -0
  31. flashseg/solutions/lane_detector.py +45 -0
  32. flashseg/solutions/scene_parser.py +34 -0
  33. flashseg/utils/__init__.py +4 -0
  34. flashseg/utils/metrics.py +53 -0
  35. flashseg/utils/visualization.py +33 -0
  36. flashseg-1.0.0.dist-info/METADATA +307 -0
  37. flashseg-1.0.0.dist-info/RECORD +41 -0
  38. flashseg-1.0.0.dist-info/WHEEL +5 -0
  39. flashseg-1.0.0.dist-info/entry_points.txt +2 -0
  40. flashseg-1.0.0.dist-info/licenses/LICENSE +21 -0
  41. flashseg-1.0.0.dist-info/top_level.txt +1 -0
flashseg/__init__.py ADDED
@@ -0,0 +1,20 @@
1
+ """FlashSeg - Ultra-lightweight real-time image segmentation."""
2
+
3
+ __version__ = "1.0.0"
4
+
5
+ from flashseg.cfg.config import get_config
6
+ from flashseg.engine.trainer import Trainer
7
+ from flashseg.engine.predictor import Predictor
8
+ from flashseg.engine.exporter import Exporter
9
+ from flashseg.engine.validator import Validator
10
+ from flashseg.models.build import build_model
11
+
12
+ __all__ = [
13
+ "__version__",
14
+ "get_config",
15
+ "build_model",
16
+ "Trainer",
17
+ "Predictor",
18
+ "Exporter",
19
+ "Validator",
20
+ ]
flashseg/analytics/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ """Analytics and benchmarking tools."""
2
+
3
+ from flashseg.analytics.benchmark import Benchmark
4
+ from flashseg.analytics.profiler import Profiler
5
+
6
+ __all__ = ["Benchmark", "Profiler"]
flashseg/analytics/benchmark.py ADDED
@@ -0,0 +1,65 @@
1
+ """Model benchmarking for segmentation."""
2
+
3
+ import time
4
+ import logging
5
+
6
+ import torch
7
+
8
+ from flashseg.cfg.config import get_config
9
+ from flashseg.models.build import build_model
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
class Benchmark:
    """Benchmark FlashSeg model speed and efficiency.

    Builds a model through ``get_config`` / ``build_model``, optionally loads
    a state dict into it, and times forward passes on a random square input.
    """

    def __init__(self, model_path: str = None, model_size: str = "m", input_size: int = 512, num_classes: int = 21, device: str = "cuda"):
        """Build the model and move it to the benchmark device.

        Args:
            model_path: Optional path to a state-dict checkpoint to load.
            model_size: Model size code forwarded to ``get_config``.
            input_size: Side length of the square dummy input.
            num_classes: Number of classes forwarded to ``get_config``.
            device: Requested device string; anything other than ``"cpu"``
                falls back to CPU when CUDA is unavailable.
        """
        use_requested = torch.cuda.is_available() or device == "cpu"
        self.device = torch.device(device if use_requested else "cpu")
        config = get_config(model_size=model_size, input_size=input_size, num_classes=num_classes)
        self.model = build_model(config).to(self.device)

        if model_path:
            # NOTE(security): torch.load unpickles the file. Only load
            # checkpoints from trusted sources (or pass weights_only=True on
            # PyTorch versions that support it).
            self.model.load_state_dict(torch.load(model_path, map_location=self.device))

        self.model.eval()
        self.input_size = input_size

    def run(self, warmup: int = 10, iterations: int = 100) -> dict:
        """Run benchmark and return timing results.

        Args:
            warmup: Number of untimed forward passes executed first.
            iterations: Number of timed forward passes; must be at least 1.

        Returns:
            Dict with ``latency_ms``, ``fps``, ``params``, ``params_m``,
            ``size_mb``, ``device`` and ``input_size``.

        Raises:
            ValueError: If ``iterations`` is smaller than 1 (there would be
                nothing to average).
        """
        if iterations < 1:
            raise ValueError(f"iterations must be >= 1, got {iterations}")

        on_cuda = self.device.type == "cuda"
        dummy = torch.randn(1, 3, self.input_size, self.input_size).to(self.device)

        # Warmup
        with torch.no_grad():
            for _ in range(warmup):
                self.model(dummy)

        # Drain queued warmup kernels so they are not billed to the first
        # timed iteration.
        if on_cuda:
            torch.cuda.synchronize()

        # Benchmark
        times = []
        with torch.no_grad():
            for _ in range(iterations):
                start = time.perf_counter()
                self.model(dummy)
                # CUDA launches are asynchronous; wait for completion before
                # reading the clock.
                if on_cuda:
                    torch.cuda.synchronize()
                times.append(time.perf_counter() - start)

        avg_ms = sum(times) / len(times) * 1000
        fps = 1000.0 / avg_ms

        # Count parameters and their real storage size in one pass instead of
        # assuming 4 bytes (fp32) per parameter.
        params = 0
        param_bytes = 0
        for p in self.model.parameters():
            params += p.numel()
            param_bytes += p.numel() * p.element_size()

        results = {
            "latency_ms": round(avg_ms, 2),
            "fps": round(fps, 1),
            "params": params,
            "params_m": round(params / 1e6, 2),
            "size_mb": round(param_bytes / 1024 / 1024, 2),
            "device": str(self.device),
            "input_size": self.input_size,
        }

        logger.info(f"Benchmark: {fps:.1f} FPS, {avg_ms:.2f}ms, {params / 1e6:.2f}M params")
        return results
flashseg/analytics/profiler.py ADDED
@@ -0,0 +1,42 @@
1
+ """Layer-by-layer profiling."""
2
+
3
+ import logging
4
+
5
+ import torch
6
+
7
+ from flashseg.cfg.config import get_config
8
+ from flashseg.models.build import build_model
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
class Profiler:
    """Profile FlashSeg model layer-by-layer.

    The profile is a parameter count for each direct child module of the
    model, plus a grand total.
    """

    def __init__(self, model_path: str = None, model_size: str = "m", input_size: int = 512, num_classes: int = 21):
        """Build the model on CPU and optionally load a state dict.

        Args:
            model_path: Optional path to a state-dict checkpoint to load.
            model_size: Model size code forwarded to ``get_config``.
            input_size: Input side length forwarded to ``get_config``.
            num_classes: Number of classes forwarded to ``get_config``.
        """
        config = get_config(model_size=model_size, input_size=input_size, num_classes=num_classes)
        self.model = build_model(config)

        if model_path:
            # NOTE(security): torch.load unpickles the file. Only load
            # checkpoints from trusted sources.
            self.model.load_state_dict(torch.load(model_path, map_location="cpu"))

        self.model.eval()
        self.input_size = input_size

    def run(self) -> dict:
        """Profile model and print per-module statistics.

        Returns:
            Dict mapping each direct child module name (plus ``"total"``) to
            ``{"params": int, "params_m": float}``.
        """
        # No forward pass is performed, so no dummy input is allocated.
        results = {}
        for name, module in self.model.named_children():
            params = sum(p.numel() for p in module.parameters())
            results[name] = {
                "params": params,
                "params_m": round(params / 1e6, 3),
            }
            print(f" {name:20s} | {params:>10,} params | {params / 1e6:.3f}M")

        total = sum(p.numel() for p in self.model.parameters())
        print(f" {'TOTAL':20s} | {total:>10,} params | {total / 1e6:.3f}M")
        results["total"] = {"params": total, "params_m": round(total / 1e6, 3)}
        return results
flashseg/cfg/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from flashseg.cfg.config import get_config, load_yaml_config
2
+
3
+ __all__ = ["get_config", "load_yaml_config"]
flashseg/cfg/config.py ADDED
@@ -0,0 +1,113 @@
1
+ """Configuration management for FlashSeg."""
2
+
3
+ from dataclasses import dataclass, field
4
+ from pathlib import Path
5
+ from typing import Any, Dict, List, Optional
6
+
7
+ import yaml
8
+
9
+
10
# Scaling presets keyed by model-size code. get_config() looks the code up
# here (falling back to the "m" entry for unknown codes) and copies the
# width_mult / depth_mult pair into the Config it builds.
MODEL_SIZE_MAP = {
    "n": {"width_mult": 0.25, "depth_mult": 0.33},
    "s": {"width_mult": 0.50, "depth_mult": 0.33},
    "m": {"width_mult": 0.75, "depth_mult": 0.67},
    "l": {"width_mult": 1.00, "depth_mult": 1.00},
}
16
+
17
+
18
@dataclass
class Config:
    """FlashSeg configuration.

    A plain dataclass of settings. ``get_config`` constructs it, fills
    ``width_mult`` / ``depth_mult`` from ``MODEL_SIZE_MAP`` and then applies
    keyword overrides onto any attribute that already exists here.
    """

    # Model
    model_size: str = "m"
    num_classes: int = 21
    input_size: int = 512
    # The two multiplier defaults equal the "m" entry of MODEL_SIZE_MAP.
    width_mult: float = 0.75
    depth_mult: float = 0.67
    backbone: str = "shufflenetv2"
    neck: str = "fpn"
    head: str = "seg_head"

    # Training
    epochs: int = 100
    batch_size: int = 16
    lr: float = 0.01
    momentum: float = 0.9
    weight_decay: float = 5e-4
    warmup_epochs: int = 5
    scheduler: str = "cosine"
    amp: bool = False
    multi_gpu: bool = False

    # Data
    # NOTE(review): presumably directory paths, matching the CLI flags of the
    # same names; confirm against the trainer.
    train_images: str = ""
    train_masks: str = ""
    val_images: str = ""
    val_masks: str = ""
    num_workers: int = 4
    augment: bool = True

    # LoRA
    use_lora: bool = False
    lora_rank: int = 8
    lora_alpha: int = 16
    lora_variant: str = "standard"

    # Knowledge Distillation
    use_kd: bool = False
    teacher_checkpoint: str = ""
    teacher_size: str = "l"
    kd_temperature: float = 4.0
    kd_alpha: float = 0.5

    # Pretrained
    pretrained: bool = True

    # Paths
    save_dir: str = "workspace"
    device: str = "cuda"

    # Extra
    # default_factory gives every Config instance its own dict instead of a
    # shared mutable default.
    extra: Dict[str, Any] = field(default_factory=dict)
73
+
74
+
75
def get_config(
    model_size: str = "m",
    input_size: int = 512,
    num_classes: int = 21,
    **overrides,
) -> Config:
    """Create a config with sensible defaults for the given model size.

    Args:
        model_size: Key into ``MODEL_SIZE_MAP``. An unknown key still falls
            back to the ``"m"`` multipliers, but now emits a warning because
            the returned config records the unknown code.
        input_size: Stored on the config as ``input_size``.
        num_classes: Stored on the config as ``num_classes``.
        **overrides: Extra settings applied with ``setattr`` onto attributes
            that already exist on the config. Unknown keys are still ignored,
            but reported in a warning so typos do not go unnoticed.

    Returns:
        The populated ``Config`` instance.
    """
    # Local import keeps this block self-contained.
    import warnings

    if model_size in MODEL_SIZE_MAP:
        size_params = MODEL_SIZE_MAP[model_size]
    else:
        warnings.warn(
            f"Unknown model_size {model_size!r}; using the 'm' width/depth multipliers",
            stacklevel=2,
        )
        size_params = MODEL_SIZE_MAP["m"]

    config = Config(
        model_size=model_size,
        input_size=input_size,
        num_classes=num_classes,
        width_mult=size_params["width_mult"],
        depth_mult=size_params["depth_mult"],
    )

    ignored = []
    for key, value in overrides.items():
        if hasattr(config, key):
            setattr(config, key, value)
        else:
            ignored.append(key)

    if ignored:
        warnings.warn(
            f"Ignoring unknown config option(s): {', '.join(sorted(map(str, ignored)))}",
            stacklevel=2,
        )

    return config
97
+
98
+
99
def load_yaml_config(path: str) -> Config:
    """Load configuration from a YAML file.

    ``model_size``, ``input_size`` and ``num_classes`` are taken from the
    document (with the same defaults as ``get_config``). Every remaining
    top-level key is forwarded to ``get_config`` as an override.

    Args:
        path: Path of the YAML file to read.

    Returns:
        The ``Config`` built by ``get_config``.

    Raises:
        ValueError: If the document's top level is not a mapping.
    """
    # YAML is read as UTF-8 rather than the platform's locale encoding.
    with open(path, "r", encoding="utf-8") as f:
        data = yaml.safe_load(f)

    if data is None:
        # An empty file parses to None; treat it as "all defaults".
        data = {}
    elif not isinstance(data, dict):
        raise ValueError(
            f"Top level of {path} must be a mapping, got {type(data).__name__}"
        )

    model_size = data.pop("model_size", "m")
    input_size = data.pop("input_size", 512)
    num_classes = data.pop("num_classes", 21)

    return get_config(
        model_size=model_size,
        input_size=input_size,
        num_classes=num_classes,
        **data,
    )
flashseg/cli.py ADDED
@@ -0,0 +1,216 @@
1
+ """FlashSeg CLI."""
2
+
3
+ import argparse
4
+ import logging
5
+ import sys
6
+
7
+ from flashseg import __version__
8
+
9
+
10
def main(argv=None):
    """FlashSeg CLI entry point.

    Args:
        argv: Optional list of argument strings (without the program name).
            ``None`` (the default) makes argparse read ``sys.argv[1:]``, so
            calling ``main()`` with no arguments behaves as before.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)

    logging.basicConfig(level=logging.INFO, format="%(message)s")

    if args.command == "version":
        print(f"flashseg {__version__}")

    elif args.command == "check":
        _run_check()

    elif args.command == "settings":
        _show_settings()

    elif args.command == "train":
        # Engine imports are deferred so the utility commands above do not
        # load them.
        from flashseg.engine.trainer import Trainer
        trainer = Trainer(
            model_size=args.model_size,
            train_images=args.train_images,
            train_masks=args.train_masks,
            val_images=args.val_images,
            val_masks=args.val_masks,
            num_classes=args.num_classes,
            input_size=args.input_size,
            epochs=args.epochs,
            batch_size=args.batch_size,
            lr=args.lr,
            device=args.device,
            save_dir=args.save_dir,
            amp=args.amp,
            use_lora=args.lora,
            config_path=args.config,
        )
        trainer.train()

    elif args.command == "predict":
        from flashseg.engine.predictor import Predictor
        predictor = Predictor(
            model_path=args.model,
            model_size=args.model_size,
            num_classes=args.num_classes,
            input_size=args.input_size,
            device=args.device,
        )
        predictor.predict_directory(args.source, save_dir=args.save_dir)

    elif args.command == "val":
        from flashseg.engine.validator import Validator
        validator = Validator(
            model_path=args.model,
            val_images=args.val_images,
            val_masks=args.val_masks,
            model_size=args.model_size,
            num_classes=args.num_classes,
            input_size=args.input_size,
            device=args.device,
        )
        validator.validate()

    elif args.command == "export":
        from flashseg.engine.exporter import Exporter
        exporter = Exporter(
            model_path=args.model,
            model_size=args.model_size,
            num_classes=args.num_classes,
            input_size=args.input_size,
        )
        exporter.export(output=args.output, simplify=args.simplify)

    else:
        # No sub-command given.
        parser.print_help()


def _build_parser():
    """Construct the ``flashseg`` argument parser with all sub-commands."""
    parser = argparse.ArgumentParser(
        prog="flashseg",
        description="FlashSeg — Ultra-lightweight real-time image segmentation",
    )
    subparsers = parser.add_subparsers(dest="command")

    # Train
    train_parser = subparsers.add_parser("train", help="Train a segmentation model")
    train_parser.add_argument("--model-size", default="m", choices=["n", "s", "m", "l"])
    train_parser.add_argument("--train-images", required=True)
    train_parser.add_argument("--train-masks", required=True)
    train_parser.add_argument("--val-images", required=True)
    train_parser.add_argument("--val-masks", required=True)
    train_parser.add_argument("--num-classes", type=int, default=21)
    train_parser.add_argument("--input-size", type=int, default=512)
    train_parser.add_argument("--epochs", type=int, default=100)
    train_parser.add_argument("--batch-size", type=int, default=16)
    train_parser.add_argument("--lr", type=float, default=0.01)
    train_parser.add_argument("--device", default="cuda")
    train_parser.add_argument("--save-dir", default="workspace")
    train_parser.add_argument("--amp", action="store_true")
    train_parser.add_argument("--lora", action="store_true")
    train_parser.add_argument("--config", type=str, help="YAML config file")

    # Predict
    pred_parser = subparsers.add_parser("predict", help="Run segmentation inference")
    pred_parser.add_argument("--model", required=True)
    pred_parser.add_argument("--source", required=True)
    pred_parser.add_argument("--model-size", default="m")
    pred_parser.add_argument("--num-classes", type=int, default=21)
    pred_parser.add_argument("--input-size", type=int, default=512)
    pred_parser.add_argument("--device", default="cuda")
    pred_parser.add_argument("--save-dir", default="output")

    # Validate
    val_parser = subparsers.add_parser("val", help="Validate model")
    val_parser.add_argument("--model", required=True)
    val_parser.add_argument("--val-images", required=True)
    val_parser.add_argument("--val-masks", required=True)
    val_parser.add_argument("--model-size", default="m")
    val_parser.add_argument("--num-classes", type=int, default=21)
    val_parser.add_argument("--input-size", type=int, default=512)
    val_parser.add_argument("--device", default="cuda")

    # Export
    export_parser = subparsers.add_parser("export", help="Export to ONNX")
    export_parser.add_argument("--model", required=True)
    export_parser.add_argument("--output", default="model.onnx")
    export_parser.add_argument("--model-size", default="m")
    export_parser.add_argument("--num-classes", type=int, default=21)
    export_parser.add_argument("--input-size", type=int, default=512)
    export_parser.add_argument("--simplify", action="store_true")

    # Utility commands
    subparsers.add_parser("version", help="Print version")
    subparsers.add_parser("check", help="Run health check")
    subparsers.add_parser("settings", help="Show system info")

    return parser
140
+
141
+
142
+ def _run_check():
143
+ """Run health check."""
144
+ print("FlashSeg Health Check")
145
+ print("=" * 40)
146
+ checks = []
147
+
148
+ try:
149
+ import torch
150
+ checks.append(("PyTorch", f"{torch.__version__}"))
151
+ checks.append(("CUDA available", str(torch.cuda.is_available())))
152
+ if torch.cuda.is_available():
153
+ checks.append(("GPU", torch.cuda.get_device_name(0)))
154
+ except ImportError:
155
+ checks.append(("PyTorch", "NOT INSTALLED"))
156
+
157
+ try:
158
+ import cv2
159
+ checks.append(("OpenCV", cv2.__version__))
160
+ except ImportError:
161
+ checks.append(("OpenCV", "NOT INSTALLED"))
162
+
163
+ try:
164
+ import flashseg
165
+ checks.append(("FlashSeg", flashseg.__version__))
166
+ except Exception as e:
167
+ checks.append(("FlashSeg", f"ERROR: {e}"))
168
+
169
+ try:
170
+ from flashseg.models.build import build_model
171
+ from flashseg.cfg.config import get_config
172
+ config = get_config(model_size="m", input_size=512, num_classes=21)
173
+ model = build_model(config)
174
+ params = sum(p.numel() for p in model.parameters())
175
+ checks.append(("Model build", f"OK ({params:,} params)"))
176
+ except Exception as e:
177
+ checks.append(("Model build", f"FAILED: {e}"))
178
+
179
+ for name, status in checks:
180
+ print(f" {name:20s}: {status}")
181
+
182
+ print("=" * 40)
183
+ print("All checks passed!" if all("NOT INSTALLED" not in s and "FAILED" not in s for _, s in checks) else "Some checks failed.")
184
+
185
+
186
+ def _show_settings():
187
+ """Show system settings."""
188
+ import platform
189
+ print("FlashSeg System Info")
190
+ print("=" * 40)
191
+ print(f" Python: {platform.python_version()}")
192
+ print(f" Platform: {platform.platform()}")
193
+
194
+ try:
195
+ import torch
196
+ print(f" PyTorch: {torch.__version__}")
197
+ print(f" CUDA: {torch.version.cuda or 'N/A'}")
198
+ if torch.cuda.is_available():
199
+ print(f" GPU: {torch.cuda.get_device_name(0)}")
200
+ mem = torch.cuda.get_device_properties(0).total_mem / 1024**3
201
+ print(f" GPU RAM: {mem:.1f} GB")
202
+ else:
203
+ print(" GPU: Not available")
204
+ except ImportError:
205
+ print(" PyTorch: Not installed")
206
+
207
+ try:
208
+ import flashseg
209
+ print(f" FlashSeg: {flashseg.__version__}")
210
+ except ImportError:
211
+ pass
212
+ print("=" * 40)
213
+
214
+
215
+ if __name__ == "__main__":
216
+ main()
flashseg/data/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from flashseg.data.dataset import SegmentationDataset
2
+ from flashseg.data.transforms import get_train_transforms, get_val_transforms
3
+
4
+ __all__ = ["SegmentationDataset", "get_train_transforms", "get_val_transforms"]
flashseg/data/dataset.py ADDED
@@ -0,0 +1,93 @@
1
+ """Segmentation dataset classes."""
2
+
3
+ import logging
4
+ from pathlib import Path
5
+ from typing import Callable, List, Optional, Tuple
6
+
7
+ import cv2
8
+ import numpy as np
9
+ import torch
10
+ from torch.utils.data import Dataset
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
class SegmentationDataset(Dataset):
    """Dataset for semantic segmentation with image-mask pairs.

    Images and masks are listed from two directories, sorted, and paired by
    position. Only the counts are checked, so file names must sort into the
    same order in both directories.
    """

    SUPPORTED_FORMATS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff")

    def __init__(
        self,
        images_dir: str,
        masks_dir: str,
        input_size: int = 512,
        num_classes: int = 21,
        transform: Optional[Callable] = None,
        augment: bool = False,
    ):
        """Index the image and mask directories.

        Args:
            images_dir: Directory containing the input images.
            masks_dir: Directory containing the mask images.
            input_size: Side length both image and mask are resized to.
            num_classes: Stored on the instance; not used for validation here.
            transform: Optional callable applied to the (augmented) image
                array; when omitted the image is scaled to [0, 1] and
                converted to a channel-first tensor.
            augment: Whether ``_augment`` is applied in ``__getitem__``.

        Raises:
            ValueError: If the two directories hold a different number of
                supported files.
        """
        self.images_dir = Path(images_dir)
        self.masks_dir = Path(masks_dir)
        self.input_size = input_size
        self.num_classes = num_classes
        self.transform = transform
        self.augment = augment

        self.image_files = self._list_supported(self.images_dir)
        self.mask_files = self._list_supported(self.masks_dir)

        # Raised explicitly (not via `assert`) so the check survives `python -O`.
        if len(self.image_files) != len(self.mask_files):
            raise ValueError(
                f"Mismatch: {len(self.image_files)} images vs {len(self.mask_files)} masks"
            )

        logger.info(f"Loaded {len(self.image_files)} image-mask pairs")

    @classmethod
    def _list_supported(cls, directory: Path) -> List[Path]:
        """Return the sorted files in *directory* with a supported suffix."""
        return sorted(
            f for f in directory.iterdir() if f.suffix.lower() in cls.SUPPORTED_FORMATS
        )

    def __len__(self) -> int:
        """Return the number of image-mask pairs."""
        return len(self.image_files)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Load, resize and convert the pair at position *idx*.

        Returns:
            ``(image, mask)`` where ``mask`` is an int64 tensor built from the
            grayscale mask, and ``image`` is either the output of
            ``transform`` or a float32 channel-first tensor scaled to [0, 1].

        Raises:
            OSError: If OpenCV cannot decode the image or the mask file.
        """
        image_path = self.image_files[idx]
        mask_path = self.mask_files[idx]

        # cv2.imread signals failure by returning None instead of raising.
        image = cv2.imread(str(image_path))
        if image is None:
            raise OSError(f"Failed to read image file: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise OSError(f"Failed to read mask file: {mask_path}")

        size = (self.input_size, self.input_size)
        image = cv2.resize(image, size)
        # Nearest-neighbour keeps mask values discrete (no blended labels).
        mask = cv2.resize(mask, size, interpolation=cv2.INTER_NEAREST)

        if self.augment:
            image, mask = self._augment(image, mask)

        if self.transform:
            image = self.transform(image)
        else:
            image = image.astype(np.float32) / 255.0
            image = np.transpose(image, (2, 0, 1))
            image = torch.from_numpy(image)

        mask = torch.from_numpy(mask.astype(np.int64))
        return image, mask

    def _augment(self, image: np.ndarray, mask: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Apply random augmentations.

        Each of horizontal flip, vertical flip, rotation by a multiple of 90
        degrees and brightness scaling is applied independently with
        probability 0.5. Geometric changes are applied to image and mask
        alike; the brightness change touches the image only.
        """
        if np.random.random() > 0.5:
            image = np.fliplr(image).copy()
            mask = np.fliplr(mask).copy()

        if np.random.random() > 0.5:
            image = np.flipud(image).copy()
            mask = np.flipud(mask).copy()

        if np.random.random() > 0.5:
            k = np.random.randint(1, 4)  # 1, 2 or 3 quarter turns
            image = np.rot90(image, k).copy()
            mask = np.rot90(mask, k).copy()

        # Color jitter (image only)
        if np.random.random() > 0.5:
            factor = np.random.uniform(0.8, 1.2)
            image = np.clip(image * factor, 0, 255).astype(np.uint8)

        return image, mask
flashseg/data/transforms.py ADDED
@@ -0,0 +1,34 @@
1
+ """Data transforms for segmentation."""
2
+
3
+ from typing import Tuple
4
+
5
+ import numpy as np
6
+ import torch
7
+
8
+
9
def get_train_transforms(input_size: int = 512):
    """Get training transforms.

    Args:
        input_size: Accepted for API compatibility; the transform does not
            resize, so the value is unused.

    Returns:
        A callable mapping an HxWx3 image array to a normalized, channel-first
        float32 tensor. It is a module-level function rather than a closure,
        so it can be pickled for multi-process data loading.
    """
    return _to_normalized_tensor


def get_val_transforms(input_size: int = 512):
    """Get validation transforms.

    Currently identical to the training transform.

    Args:
        input_size: Accepted for API compatibility; unused.

    Returns:
        The same kind of callable as ``get_train_transforms``.
    """
    return _to_normalized_tensor


def _to_normalized_tensor(image: np.ndarray) -> torch.Tensor:
    """Scale to [0, 1], normalize per channel and convert HWC -> CHW tensor."""
    # Per-channel mean/std constants (the values widely used with
    # ImageNet-pretrained backbones); they broadcast over the last axis.
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)

    scaled = image.astype(np.float32) / 255.0
    normalized = (scaled - mean) / std
    chw = np.transpose(normalized, (2, 0, 1))
    # The transpose is a strided view; copy so the tensor owns contiguous memory.
    return torch.from_numpy(chw.copy())
flashseg/engine/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ from flashseg.engine.trainer import Trainer
2
+ from flashseg.engine.predictor import Predictor
3
+ from flashseg.engine.exporter import Exporter
4
+ from flashseg.engine.validator import Validator
5
+
6
+ __all__ = ["Trainer", "Predictor", "Exporter", "Validator"]
flashseg/engine/exporter.py ADDED
@@ -0,0 +1,68 @@
1
+ """FlashSeg ONNX exporter."""
2
+
3
+ import logging
4
+ from pathlib import Path
5
+
6
+ import torch
7
+
8
+ from flashseg.cfg.config import get_config
9
+ from flashseg.models.build import build_model
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
class Exporter:
    """Export FlashSeg models to ONNX format."""

    def __init__(
        self,
        model_path: str,
        model_size: str = "m",
        num_classes: int = 21,
        input_size: int = 512,
    ):
        """Store the export settings; the model is built lazily on export.

        Args:
            model_path: Path to the state-dict checkpoint to export.
            model_size: Model size code forwarded to ``get_config``.
            num_classes: Number of classes forwarded to ``get_config``.
            input_size: Side length of the square dummy input used for tracing.
        """
        self.model_path = model_path
        self.model_size = model_size
        self.num_classes = num_classes
        self.input_size = input_size

    def export(self, output: str = "model.onnx", simplify: bool = True, opset: int = 11) -> str:
        """Export model to ONNX (alias of ``export_onnx``)."""
        return self.export_onnx(output, simplify, opset)

    def export_onnx(self, output: str = "model.onnx", simplify: bool = True, opset: int = 11) -> str:
        """Export to ONNX format.

        Args:
            output: Destination file; missing parent directories are created.
            simplify: When true, try to simplify the exported graph with
                ``onnxsim``. This is skipped with a warning if ``onnx`` or
                ``onnxsim`` cannot be imported.
            opset: ONNX opset version passed to ``torch.onnx.export``.

        Returns:
            ``output`` unchanged.
        """
        config = get_config(model_size=self.model_size, input_size=self.input_size, num_classes=self.num_classes)
        model = build_model(config)
        # NOTE(security): torch.load unpickles the file. Only export
        # checkpoints from trusted sources.
        model.load_state_dict(torch.load(self.model_path, map_location="cpu"))
        model.eval()

        # Avoid running the whole export only to fail on a missing folder.
        Path(output).parent.mkdir(parents=True, exist_ok=True)

        dummy_input = torch.randn(1, 3, self.input_size, self.input_size)

        torch.onnx.export(
            model,
            dummy_input,
            output,
            opset_version=opset,
            input_names=["images"],
            output_names=["output"],
            # Only the batch dimension is dynamic.
            dynamic_axes={"images": {0: "batch"}, "output": {0: "batch"}},
        )
        logger.info(f"Exported ONNX model to {output}")

        if simplify:
            try:
                import onnx
                from onnxsim import simplify as onnx_simplify
            except ImportError as exc:
                # Name whichever package is actually missing (onnx or onnxsim).
                missing = exc.name or "onnxsim"
                logger.warning(f"{missing} not installed, skipping simplification")
            else:
                model_onnx = onnx.load(output)
                model_simple, check = onnx_simplify(model_onnx)
                if check:
                    onnx.save(model_simple, output)
                    logger.info("ONNX model simplified")
                else:
                    logger.warning("Simplified ONNX model failed validation, keeping original export")

        file_size = Path(output).stat().st_size / 1024 / 1024
        logger.info(f"ONNX model size: {file_size:.2f} MB")
        return output