flashseg 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flashseg/__init__.py +20 -0
- flashseg/analytics/__init__.py +6 -0
- flashseg/analytics/benchmark.py +65 -0
- flashseg/analytics/profiler.py +42 -0
- flashseg/cfg/__init__.py +3 -0
- flashseg/cfg/config.py +113 -0
- flashseg/cli.py +216 -0
- flashseg/data/__init__.py +4 -0
- flashseg/data/dataset.py +93 -0
- flashseg/data/transforms.py +34 -0
- flashseg/engine/__init__.py +6 -0
- flashseg/engine/exporter.py +68 -0
- flashseg/engine/predictor.py +115 -0
- flashseg/engine/trainer.py +174 -0
- flashseg/engine/validator.py +65 -0
- flashseg/losses/__init__.py +3 -0
- flashseg/losses/seg_losses.py +75 -0
- flashseg/models/__init__.py +3 -0
- flashseg/models/backbone/__init__.py +3 -0
- flashseg/models/backbone/shufflenetv2.py +111 -0
- flashseg/models/build.py +45 -0
- flashseg/models/head/__init__.py +3 -0
- flashseg/models/head/seg_head.py +48 -0
- flashseg/models/neck/__init__.py +3 -0
- flashseg/models/neck/fpn.py +45 -0
- flashseg/nn/__init__.py +5 -0
- flashseg/nn/blocks.py +97 -0
- flashseg/solutions/__init__.py +8 -0
- flashseg/solutions/area_calculator.py +34 -0
- flashseg/solutions/background_remover.py +47 -0
- flashseg/solutions/lane_detector.py +45 -0
- flashseg/solutions/scene_parser.py +34 -0
- flashseg/utils/__init__.py +4 -0
- flashseg/utils/metrics.py +53 -0
- flashseg/utils/visualization.py +33 -0
- flashseg-1.0.0.dist-info/METADATA +307 -0
- flashseg-1.0.0.dist-info/RECORD +41 -0
- flashseg-1.0.0.dist-info/WHEEL +5 -0
- flashseg-1.0.0.dist-info/entry_points.txt +2 -0
- flashseg-1.0.0.dist-info/licenses/LICENSE +21 -0
- flashseg-1.0.0.dist-info/top_level.txt +1 -0
flashseg/__init__.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
"""FlashSeg - Ultra-lightweight real-time image segmentation."""
|
|
2
|
+
|
|
3
|
+
# Package version string; also imported by flashseg.cli for the `version` command.
__version__ = "1.0.0"

# Re-export the main entry points so they can be imported from the package root.
from flashseg.cfg.config import get_config
from flashseg.engine.trainer import Trainer
from flashseg.engine.predictor import Predictor
from flashseg.engine.exporter import Exporter
from flashseg.engine.validator import Validator
from flashseg.models.build import build_model

# Public names exposed by `from flashseg import *`.
__all__ = [
    "__version__",
    "get_config",
    "build_model",
    "Trainer",
    "Predictor",
    "Exporter",
    "Validator",
]
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
"""Model benchmarking for segmentation."""
|
|
2
|
+
|
|
3
|
+
import time
|
|
4
|
+
import logging
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
|
|
8
|
+
from flashseg.cfg.config import get_config
|
|
9
|
+
from flashseg.models.build import build_model
|
|
10
|
+
|
|
11
|
+
logger = logging.getLogger(__name__)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class Benchmark:
    """Measure forward-pass latency and parameter count of a FlashSeg model."""

    def __init__(self, model_path: str = None, model_size: str = "m", input_size: int = 512, num_classes: int = 21, device: str = "cuda"):
        """Build the model, optionally load weights, and switch to eval mode.

        Args:
            model_path: Optional path to a state dict loaded with ``torch.load``.
            model_size: Size key forwarded to ``get_config``.
            input_size: Side length of the square dummy input used by ``run``.
            num_classes: Number of classes forwarded to ``get_config``.
            device: Requested device string. Anything other than ``"cpu"``
                is replaced by ``"cpu"`` when CUDA is not available.
        """
        cuda_ready = torch.cuda.is_available()
        if not cuda_ready and device != "cpu":
            chosen = "cpu"
        else:
            chosen = device
        self.device = torch.device(chosen)

        cfg = get_config(model_size=model_size, input_size=input_size, num_classes=num_classes)
        net = build_model(cfg)
        self.model = net.to(self.device)

        if model_path:
            state = torch.load(model_path, map_location=self.device)
            self.model.load_state_dict(state)

        self.model.eval()
        self.input_size = input_size

    def run(self, warmup: int = 10, iterations: int = 100) -> dict:
        """Time repeated forward passes on a random input.

        Args:
            warmup: Number of untimed forward passes executed first.
            iterations: Number of timed forward passes.

        Returns:
            Dict with ``latency_ms``, ``fps``, ``params``, ``params_m``,
            ``size_mb`` (4 bytes per parameter), ``device`` and ``input_size``.
        """
        side = self.input_size
        on_cuda = self.device.type == "cuda"
        sample = torch.randn(1, 3, side, side).to(self.device)

        # Untimed passes so one-off start-up costs do not skew the timings.
        with torch.no_grad():
            for _ in range(warmup):
                self.model(sample)

        if on_cuda:
            torch.cuda.synchronize()

        # Timed passes; on CUDA we wait for the device before reading the clock.
        elapsed = []
        with torch.no_grad():
            for _ in range(iterations):
                tick = time.perf_counter()
                self.model(sample)
                if on_cuda:
                    torch.cuda.synchronize()
                elapsed.append(time.perf_counter() - tick)

        avg_ms = sum(elapsed) / len(elapsed) * 1000
        fps = 1000.0 / avg_ms

        params = 0
        for tensor in self.model.parameters():
            params += tensor.numel()

        results = {}
        results["latency_ms"] = round(avg_ms, 2)
        results["fps"] = round(fps, 1)
        results["params"] = params
        results["params_m"] = round(params / 1e6, 2)
        results["size_mb"] = round(params * 4 / 1024 / 1024, 2)
        results["device"] = str(self.device)
        results["input_size"] = self.input_size

        logger.info(f"Benchmark: {fps:.1f} FPS, {avg_ms:.2f}ms, {params / 1e6:.2f}M params")
        return results
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
"""Layer-by-layer profiling."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
from flashseg.cfg.config import get_config
|
|
8
|
+
from flashseg.models.build import build_model
|
|
9
|
+
|
|
10
|
+
logger = logging.getLogger(__name__)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class Profiler:
    """Report parameter counts for each top-level child module of a FlashSeg model.

    Only the direct children returned by ``named_children()`` are listed;
    nested sub-layers are counted inside their parent.
    """

    def __init__(self, model_path: str = None, model_size: str = "m", input_size: int = 512, num_classes: int = 21):
        """Build the model and optionally load weights onto the CPU.

        Args:
            model_path: Optional path to a state dict loaded with ``torch.load``.
            model_size: Size key forwarded to ``get_config``.
            input_size: Stored on the instance; not used by ``run``.
            num_classes: Number of classes forwarded to ``get_config``.
        """
        config = get_config(model_size=model_size, input_size=input_size, num_classes=num_classes)
        self.model = build_model(config)

        if model_path:
            self.model.load_state_dict(torch.load(model_path, map_location="cpu"))

        self.model.eval()
        self.input_size = input_size

    def run(self) -> dict:
        """Print and return per-module parameter statistics.

        Returns:
            Dict mapping each child-module name to ``{"params", "params_m"}``,
            plus a ``"total"`` entry for the whole model. A child that is
            itself named ``"total"`` would be overwritten by that entry.
        """
        # No forward pass is executed here, so no dummy input is needed.
        results = {}
        for name, module in self.model.named_children():
            params = sum(p.numel() for p in module.parameters())
            results[name] = {
                "params": params,
                "params_m": round(params / 1e6, 3),
            }
            print(f" {name:20s} | {params:>10,} params | {params / 1e6:.3f}M")

        total = sum(p.numel() for p in self.model.parameters())
        print(f" {'TOTAL':20s} | {total:>10,} params | {total / 1e6:.3f}M")
        results["total"] = {"params": total, "params_m": round(total / 1e6, 3)}
        return results
|
flashseg/cfg/__init__.py
ADDED
flashseg/cfg/config.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
"""Configuration management for FlashSeg."""
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any, Dict, List, Optional
|
|
6
|
+
|
|
7
|
+
import yaml
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
# Width/depth multipliers for each model-size key.
# get_config() looks the requested size up here and copies both values into
# the Config; a size that is not listed falls back to the "m" entry.
MODEL_SIZE_MAP = {
    "n": {"width_mult": 0.25, "depth_mult": 0.33},
    "s": {"width_mult": 0.50, "depth_mult": 0.33},
    "m": {"width_mult": 0.75, "depth_mult": 0.67},
    "l": {"width_mult": 1.00, "depth_mult": 1.00},
}
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@dataclass
class Config:
    """FlashSeg configuration.

    Flat, mutable container of settings grouped by the section comments
    below. ``get_config`` creates instances, fills ``width_mult`` and
    ``depth_mult`` from ``MODEL_SIZE_MAP``, and then applies keyword
    overrides with ``setattr``.
    """

    # Model
    model_size: str = "m"  # size key; MODEL_SIZE_MAP defines "n", "s", "m", "l"
    num_classes: int = 21
    input_size: int = 512
    width_mult: float = 0.75  # default equals the "m" entry of MODEL_SIZE_MAP
    depth_mult: float = 0.67  # default equals the "m" entry of MODEL_SIZE_MAP
    backbone: str = "shufflenetv2"
    neck: str = "fpn"
    head: str = "seg_head"

    # Training
    epochs: int = 100
    batch_size: int = 16
    lr: float = 0.01
    momentum: float = 0.9
    weight_decay: float = 5e-4
    warmup_epochs: int = 5
    scheduler: str = "cosine"
    amp: bool = False
    multi_gpu: bool = False

    # Data
    train_images: str = ""
    train_masks: str = ""
    val_images: str = ""
    val_masks: str = ""
    num_workers: int = 4
    augment: bool = True

    # LoRA
    use_lora: bool = False
    lora_rank: int = 8
    lora_alpha: int = 16
    lora_variant: str = "standard"

    # Knowledge Distillation
    use_kd: bool = False
    teacher_checkpoint: str = ""
    teacher_size: str = "l"  # presumably a MODEL_SIZE_MAP key for the teacher; verify against the trainer
    kd_temperature: float = 4.0
    kd_alpha: float = 0.5

    # Pretrained
    pretrained: bool = True

    # Paths
    save_dir: str = "workspace"
    device: str = "cuda"

    # Extra
    # default_factory gives every instance its own dict instead of a shared one.
    extra: Dict[str, Any] = field(default_factory=dict)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def get_config(
    model_size: str = "m",
    input_size: int = 512,
    num_classes: int = 21,
    **overrides,
) -> Config:
    """Build a ``Config`` for a model size and apply keyword overrides.

    Args:
        model_size: Key into ``MODEL_SIZE_MAP``; an unknown key uses the
            ``"m"`` multipliers while ``Config.model_size`` keeps the value
            that was passed in.
        input_size: Stored as ``Config.input_size``.
        num_classes: Stored as ``Config.num_classes``.
        **overrides: Extra settings. Keys that are attributes of the config
            are assigned; all other keys are ignored without error.

    Returns:
        The populated ``Config`` instance.
    """
    if model_size in MODEL_SIZE_MAP:
        multipliers = MODEL_SIZE_MAP[model_size]
    else:
        multipliers = MODEL_SIZE_MAP["m"]

    cfg = Config(
        model_size=model_size,
        input_size=input_size,
        num_classes=num_classes,
        width_mult=multipliers["width_mult"],
        depth_mult=multipliers["depth_mult"],
    )

    for name in overrides:
        # Unknown keys are skipped rather than rejected.
        if not hasattr(cfg, name):
            continue
        setattr(cfg, name, overrides[name])

    return cfg
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def load_yaml_config(path: str) -> Config:
    """Load configuration from a YAML file.

    The top level of the file must be a mapping. ``model_size``,
    ``input_size`` and ``num_classes`` are taken from it (defaulting to
    ``"m"``, ``512`` and ``21``); every remaining key is forwarded to
    ``get_config`` as an override.

    Args:
        path: Path of the YAML file to read.

    Returns:
        The ``Config`` produced by ``get_config``.

    Raises:
        TypeError: If the YAML document is not a mapping.
    """
    # YAML streams are UTF-8 by convention; do not depend on the locale default.
    with open(path, "r", encoding="utf-8") as f:
        # safe_load returns None for an empty document; treat that as "no settings".
        data = yaml.safe_load(f) or {}

    if not isinstance(data, dict):
        raise TypeError(
            f"Expected a mapping at the top level of {path}, got {type(data).__name__}"
        )

    model_size = data.pop("model_size", "m")
    input_size = data.pop("input_size", 512)
    num_classes = data.pop("num_classes", 21)

    return get_config(
        model_size=model_size,
        input_size=input_size,
        num_classes=num_classes,
        **data,
    )
|
flashseg/cli.py
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
1
|
+
"""FlashSeg CLI."""
|
|
2
|
+
|
|
3
|
+
import argparse
|
|
4
|
+
import logging
|
|
5
|
+
import sys
|
|
6
|
+
|
|
7
|
+
from flashseg import __version__
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def main():
    """Parse command-line arguments and run the selected sub-command.

    Sub-commands: ``train``, ``predict``, ``val``, ``export``, ``version``,
    ``check`` and ``settings``. With no sub-command the help text is printed.
    Engine classes are imported only inside the branch that needs them.
    """
    parser = argparse.ArgumentParser(
        prog="flashseg",
        description="FlashSeg — Ultra-lightweight real-time image segmentation",
    )
    commands = parser.add_subparsers(dest="command")

    def add_model_options(sub):
        # Trio shared, in this order, by predict / val / export.
        sub.add_argument("--model-size", default="m")
        sub.add_argument("--num-classes", type=int, default=21)
        sub.add_argument("--input-size", type=int, default=512)

    # train
    train = commands.add_parser("train", help="Train a segmentation model")
    train.add_argument("--model-size", default="m", choices=["n", "s", "m", "l"])
    for required_flag in ("--train-images", "--train-masks", "--val-images", "--val-masks"):
        train.add_argument(required_flag, required=True)
    train.add_argument("--num-classes", type=int, default=21)
    train.add_argument("--input-size", type=int, default=512)
    train.add_argument("--epochs", type=int, default=100)
    train.add_argument("--batch-size", type=int, default=16)
    train.add_argument("--lr", type=float, default=0.01)
    train.add_argument("--device", default="cuda")
    train.add_argument("--save-dir", default="workspace")
    train.add_argument("--amp", action="store_true")
    train.add_argument("--lora", action="store_true")
    train.add_argument("--config", type=str, help="YAML config file")

    # predict
    predict = commands.add_parser("predict", help="Run segmentation inference")
    predict.add_argument("--model", required=True)
    predict.add_argument("--source", required=True)
    add_model_options(predict)
    predict.add_argument("--device", default="cuda")
    predict.add_argument("--save-dir", default="output")

    # val
    val = commands.add_parser("val", help="Validate model")
    val.add_argument("--model", required=True)
    val.add_argument("--val-images", required=True)
    val.add_argument("--val-masks", required=True)
    add_model_options(val)
    val.add_argument("--device", default="cuda")

    # export
    export = commands.add_parser("export", help="Export to ONNX")
    export.add_argument("--model", required=True)
    export.add_argument("--output", default="model.onnx")
    add_model_options(export)
    export.add_argument("--simplify", action="store_true")

    # utility commands without options
    for command_name, help_text in (
        ("version", "Print version"),
        ("check", "Run health check"),
        ("settings", "Show system info"),
    ):
        commands.add_parser(command_name, help=help_text)

    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO, format="%(message)s")

    selected = args.command

    if selected == "version":
        print(f"flashseg {__version__}")
        return

    if selected == "check":
        _run_check()
        return

    if selected == "settings":
        _show_settings()
        return

    if selected == "train":
        from flashseg.engine.trainer import Trainer

        train_options = dict(
            model_size=args.model_size,
            train_images=args.train_images,
            train_masks=args.train_masks,
            val_images=args.val_images,
            val_masks=args.val_masks,
            num_classes=args.num_classes,
            input_size=args.input_size,
            epochs=args.epochs,
            batch_size=args.batch_size,
            lr=args.lr,
            device=args.device,
            save_dir=args.save_dir,
            amp=args.amp,
            use_lora=args.lora,
            config_path=args.config,
        )
        Trainer(**train_options).train()
        return

    if selected == "predict":
        from flashseg.engine.predictor import Predictor

        runner = Predictor(
            model_path=args.model,
            model_size=args.model_size,
            num_classes=args.num_classes,
            input_size=args.input_size,
            device=args.device,
        )
        runner.predict_directory(args.source, save_dir=args.save_dir)
        return

    if selected == "val":
        from flashseg.engine.validator import Validator

        checker = Validator(
            model_path=args.model,
            val_images=args.val_images,
            val_masks=args.val_masks,
            model_size=args.model_size,
            num_classes=args.num_classes,
            input_size=args.input_size,
            device=args.device,
        )
        checker.validate()
        return

    if selected == "export":
        from flashseg.engine.exporter import Exporter

        onnx_exporter = Exporter(
            model_path=args.model,
            model_size=args.model_size,
            num_classes=args.num_classes,
            input_size=args.input_size,
        )
        onnx_exporter.export(output=args.output, simplify=args.simplify)
        return

    # No sub-command given.
    parser.print_help()
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def _run_check():
|
|
143
|
+
"""Run health check."""
|
|
144
|
+
print("FlashSeg Health Check")
|
|
145
|
+
print("=" * 40)
|
|
146
|
+
checks = []
|
|
147
|
+
|
|
148
|
+
try:
|
|
149
|
+
import torch
|
|
150
|
+
checks.append(("PyTorch", f"{torch.__version__}"))
|
|
151
|
+
checks.append(("CUDA available", str(torch.cuda.is_available())))
|
|
152
|
+
if torch.cuda.is_available():
|
|
153
|
+
checks.append(("GPU", torch.cuda.get_device_name(0)))
|
|
154
|
+
except ImportError:
|
|
155
|
+
checks.append(("PyTorch", "NOT INSTALLED"))
|
|
156
|
+
|
|
157
|
+
try:
|
|
158
|
+
import cv2
|
|
159
|
+
checks.append(("OpenCV", cv2.__version__))
|
|
160
|
+
except ImportError:
|
|
161
|
+
checks.append(("OpenCV", "NOT INSTALLED"))
|
|
162
|
+
|
|
163
|
+
try:
|
|
164
|
+
import flashseg
|
|
165
|
+
checks.append(("FlashSeg", flashseg.__version__))
|
|
166
|
+
except Exception as e:
|
|
167
|
+
checks.append(("FlashSeg", f"ERROR: {e}"))
|
|
168
|
+
|
|
169
|
+
try:
|
|
170
|
+
from flashseg.models.build import build_model
|
|
171
|
+
from flashseg.cfg.config import get_config
|
|
172
|
+
config = get_config(model_size="m", input_size=512, num_classes=21)
|
|
173
|
+
model = build_model(config)
|
|
174
|
+
params = sum(p.numel() for p in model.parameters())
|
|
175
|
+
checks.append(("Model build", f"OK ({params:,} params)"))
|
|
176
|
+
except Exception as e:
|
|
177
|
+
checks.append(("Model build", f"FAILED: {e}"))
|
|
178
|
+
|
|
179
|
+
for name, status in checks:
|
|
180
|
+
print(f" {name:20s}: {status}")
|
|
181
|
+
|
|
182
|
+
print("=" * 40)
|
|
183
|
+
print("All checks passed!" if all("NOT INSTALLED" not in s and "FAILED" not in s for _, s in checks) else "Some checks failed.")
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def _show_settings():
|
|
187
|
+
"""Show system settings."""
|
|
188
|
+
import platform
|
|
189
|
+
print("FlashSeg System Info")
|
|
190
|
+
print("=" * 40)
|
|
191
|
+
print(f" Python: {platform.python_version()}")
|
|
192
|
+
print(f" Platform: {platform.platform()}")
|
|
193
|
+
|
|
194
|
+
try:
|
|
195
|
+
import torch
|
|
196
|
+
print(f" PyTorch: {torch.__version__}")
|
|
197
|
+
print(f" CUDA: {torch.version.cuda or 'N/A'}")
|
|
198
|
+
if torch.cuda.is_available():
|
|
199
|
+
print(f" GPU: {torch.cuda.get_device_name(0)}")
|
|
200
|
+
mem = torch.cuda.get_device_properties(0).total_mem / 1024**3
|
|
201
|
+
print(f" GPU RAM: {mem:.1f} GB")
|
|
202
|
+
else:
|
|
203
|
+
print(" GPU: Not available")
|
|
204
|
+
except ImportError:
|
|
205
|
+
print(" PyTorch: Not installed")
|
|
206
|
+
|
|
207
|
+
try:
|
|
208
|
+
import flashseg
|
|
209
|
+
print(f" FlashSeg: {flashseg.__version__}")
|
|
210
|
+
except ImportError:
|
|
211
|
+
pass
|
|
212
|
+
print("=" * 40)
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
# Run the CLI when this module is executed as a script rather than imported.
if __name__ == "__main__":
    main()
|
flashseg/data/dataset.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
"""Segmentation dataset classes."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Callable, List, Optional, Tuple
|
|
6
|
+
|
|
7
|
+
import cv2
|
|
8
|
+
import numpy as np
|
|
9
|
+
import torch
|
|
10
|
+
from torch.utils.data import Dataset
|
|
11
|
+
|
|
12
|
+
logger = logging.getLogger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class SegmentationDataset(Dataset):
    """Dataset for semantic segmentation with image-mask pairs.

    Both directories are listed, filtered by ``SUPPORTED_FORMATS`` and
    sorted. The i-th image is paired with the i-th mask purely by position,
    so the two directories must sort into the same order; file names are not
    compared.
    """

    SUPPORTED_FORMATS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff")

    def __init__(
        self,
        images_dir: str,
        masks_dir: str,
        input_size: int = 512,
        num_classes: int = 21,
        transform: Optional[Callable] = None,
        augment: bool = False,
    ):
        """Index the image and mask directories.

        Args:
            images_dir: Directory containing the input images.
            masks_dir: Directory containing the mask images.
            input_size: Side length both image and mask are resized to.
            num_classes: Stored on the instance; not used by this class.
            transform: Optional callable applied to the RGB image array.
                When omitted, the image is scaled to [0, 1] and converted
                to a channel-first tensor.
            augment: Whether to apply the random augmentations in ``_augment``.

        Raises:
            AssertionError: If the number of images and masks differ.
        """
        self.images_dir = Path(images_dir)
        self.masks_dir = Path(masks_dir)
        self.input_size = input_size
        self.num_classes = num_classes
        self.transform = transform
        self.augment = augment

        self.image_files = self._list_supported(self.images_dir)
        self.mask_files = self._list_supported(self.masks_dir)

        # NOTE(review): kept as an assert so the raised type stays AssertionError,
        # but it is skipped under `python -O`.
        assert len(self.image_files) == len(self.mask_files), (
            f"Mismatch: {len(self.image_files)} images vs {len(self.mask_files)} masks"
        )

        logger.info(f"Loaded {len(self.image_files)} image-mask pairs")

    def _list_supported(self, directory: Path) -> List[Path]:
        """Return the sorted files in *directory* whose suffix is supported."""
        return sorted(
            f for f in directory.iterdir() if f.suffix.lower() in self.SUPPORTED_FORMATS
        )

    def __len__(self) -> int:
        """Return the number of image-mask pairs."""
        return len(self.image_files)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Load, resize and convert the pair at position *idx*.

        Returns:
            ``(image, mask)``. ``mask`` is an int64 tensor of shape
            ``(input_size, input_size)``. ``image`` is whatever ``transform``
            returns, or a float32 channel-first tensor scaled to [0, 1] when
            no transform is set.

        Raises:
            OSError: If OpenCV cannot decode the image or the mask file.
        """
        image_path = self.image_files[idx]
        mask_path = self.mask_files[idx]

        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than deep inside OpenCV.
        image = cv2.imread(str(image_path))
        if image is None:
            raise OSError(f"Failed to read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise OSError(f"Failed to read mask: {mask_path}")

        image = cv2.resize(image, (self.input_size, self.input_size))
        # Nearest-neighbour keeps mask values as the original labels (no blending).
        mask = cv2.resize(mask, (self.input_size, self.input_size), interpolation=cv2.INTER_NEAREST)

        if self.augment:
            image, mask = self._augment(image, mask)

        if self.transform:
            image = self.transform(image)
        else:
            image = image.astype(np.float32) / 255.0
            image = np.transpose(image, (2, 0, 1))
            image = torch.from_numpy(image)

        mask = torch.from_numpy(mask.astype(np.int64))
        return image, mask

    def _augment(self, image: np.ndarray, mask: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Apply random augmentations.

        Each step fires independently with probability 0.5: horizontal flip,
        vertical flip, rotation by a random multiple of 90 degrees (applied
        to image and mask alike), and a brightness scale in [0.8, 1.2] on
        the image only.
        """
        if np.random.random() > 0.5:
            image = np.fliplr(image).copy()
            mask = np.fliplr(mask).copy()

        if np.random.random() > 0.5:
            image = np.flipud(image).copy()
            mask = np.flipud(mask).copy()

        if np.random.random() > 0.5:
            k = np.random.randint(1, 4)
            image = np.rot90(image, k).copy()
            mask = np.rot90(mask, k).copy()

        # Color jitter (image only)
        if np.random.random() > 0.5:
            factor = np.random.uniform(0.8, 1.2)
            image = np.clip(image * factor, 0, 255).astype(np.uint8)

        return image, mask
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
"""Data transforms for segmentation."""
|
|
2
|
+
|
|
3
|
+
from typing import Tuple
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import torch
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
# Per-channel RGB mean/std used for normalisation (the widely used ImageNet
# statistics). Built once at import time instead of on every transform call.
_NORM_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
_NORM_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def _normalize_to_tensor(image: np.ndarray) -> torch.Tensor:
    """Scale an HxWx3 image to [0, 1], normalise per channel, return a CxHxW tensor.

    Defined at module level (not as a closure) so it can be pickled, which
    DataLoader worker processes need under the "spawn" start method.
    """
    scaled = image.astype(np.float32) / 255.0
    normalized = (scaled - _NORM_MEAN) / _NORM_STD
    channel_first = np.transpose(normalized, (2, 0, 1))
    # copy() makes the transposed view contiguous before handing it to torch.
    return torch.from_numpy(channel_first.copy())


def get_train_transforms(input_size: int = 512):
    """Get training transforms.

    Args:
        input_size: Accepted for API symmetry; currently unused because the
            transform does not resize.

    Returns:
        A callable mapping an HxWx3 image array to a normalised float32
        CxHxW tensor.
    """
    return _normalize_to_tensor


def get_val_transforms(input_size: int = 512):
    """Get validation transforms.

    Args:
        input_size: Accepted for API symmetry; currently unused because the
            transform does not resize.

    Returns:
        A callable mapping an HxWx3 image array to a normalised float32
        CxHxW tensor (identical to the training transform).
    """
    return _normalize_to_tensor
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
"""FlashSeg ONNX exporter."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
|
|
8
|
+
from flashseg.cfg.config import get_config
|
|
9
|
+
from flashseg.models.build import build_model
|
|
10
|
+
|
|
11
|
+
logger = logging.getLogger(__name__)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class Exporter:
    """Export FlashSeg models to ONNX format."""

    def __init__(
        self,
        model_path: str,
        model_size: str = "m",
        num_classes: int = 21,
        input_size: int = 512,
    ):
        """Store the export settings; the model is built only when exporting.

        Args:
            model_path: Path to the state dict loaded with ``torch.load``.
            model_size: Size key forwarded to ``get_config``.
            num_classes: Number of classes forwarded to ``get_config``.
            input_size: Side length of the square dummy input used for tracing.
        """
        self.model_path = model_path
        self.model_size = model_size
        self.num_classes = num_classes
        self.input_size = input_size

    def export(self, output: str = "model.onnx", simplify: bool = True, opset: int = 11) -> str:
        """Export model to ONNX (thin alias for ``export_onnx``)."""
        return self.export_onnx(output, simplify, opset)

    def export_onnx(self, output: str = "model.onnx", simplify: bool = True, opset: int = 11) -> str:
        """Build the model, load its weights and write an ONNX file.

        Args:
            output: Destination path of the ONNX file.
            simplify: If true, try to simplify the exported graph with
                ``onnxsim``; skipped with a warning when ``onnx`` or
                ``onnxsim`` cannot be imported.
            opset: ONNX opset version passed to ``torch.onnx.export``.

        Returns:
            The ``output`` path.
        """
        config = get_config(model_size=self.model_size, input_size=self.input_size, num_classes=self.num_classes)
        model = build_model(config)
        model.load_state_dict(torch.load(self.model_path, map_location="cpu"))
        model.eval()

        dummy_input = torch.randn(1, 3, self.input_size, self.input_size)

        torch.onnx.export(
            model,
            dummy_input,
            output,
            opset_version=opset,
            input_names=["images"],
            output_names=["output"],
            # Only the batch dimension is dynamic; spatial size is fixed by the dummy input.
            dynamic_axes={"images": {0: "batch"}, "output": {0: "batch"}},
        )
        logger.info(f"Exported ONNX model to {output}")

        if simplify:
            try:
                import onnx
                from onnxsim import simplify as onnx_simplify

                model_onnx = onnx.load(output)
                model_simple, check = onnx_simplify(model_onnx)
                if check:
                    onnx.save(model_simple, output)
                    logger.info("ONNX model simplified")
                else:
                    # Keep the unsimplified export, but say so instead of staying silent.
                    logger.warning("ONNX simplification check failed, keeping unsimplified model")
            except ImportError:
                # Either import in the try block may be the missing one.
                logger.warning("onnx/onnxsim not installed, skipping simplification")

        file_size = Path(output).stat().st_size / 1024 / 1024
        logger.info(f"ONNX model size: {file_size:.2f} MB")
        return output
|