srforge 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- srforge-0.1.0/PKG-INFO +71 -0
- srforge-0.1.0/README.md +32 -0
- srforge-0.1.0/pyproject.toml +77 -0
- srforge-0.1.0/scripts/__init__.py +0 -0
- srforge-0.1.0/scripts/test.py +104 -0
- srforge-0.1.0/scripts/train.py +140 -0
- srforge-0.1.0/setup.cfg +4 -0
- srforge-0.1.0/srforge/__init__.py +23 -0
- srforge-0.1.0/srforge/registry.py +150 -0
- srforge-0.1.0/srforge/structs.py +16 -0
- srforge-0.1.0/srforge.egg-info/PKG-INFO +71 -0
- srforge-0.1.0/srforge.egg-info/SOURCES.txt +14 -0
- srforge-0.1.0/srforge.egg-info/dependency_links.txt +1 -0
- srforge-0.1.0/srforge.egg-info/entry_points.txt +3 -0
- srforge-0.1.0/srforge.egg-info/requires.txt +30 -0
- srforge-0.1.0/srforge.egg-info/top_level.txt +2 -0
srforge-0.1.0/PKG-INFO
ADDED
@@ -0,0 +1,71 @@
+Metadata-Version: 2.4
+Name: srforge
+Version: 0.1.0
+Summary: Super-resolution research framework for PyTorch with a focus on simplicity and flexibility using config files.
+Author-email: Tomasz Tarasiewicz <tarasiewicztomasz@gmail.com>
+Project-URL: Homepage, https://gitlab.com/tarasiewicztomasz/sr-forge
+Project-URL: Source, https://gitlab.com/tarasiewicztomasz/sr-forge/-/tree/main
+Requires-Python: >=3.10
+Description-Content-Type: text/markdown
+Requires-Dist: torch>=2.7
+Requires-Dist: torchvision>=0.22
+Requires-Dist: torchaudio>=2.7
+Requires-Dist: torch-geometric>=2.6
+Requires-Dist: pyg-lib>=0.4
+Requires-Dist: torch-scatter>=2.1
+Requires-Dist: torch-sparse>=0.6
+Requires-Dist: torch-cluster>=1.6
+Requires-Dist: torch-spline-conv>=1.2
+Requires-Dist: torchdatasets>=0.2
+Requires-Dist: torchmetrics>=1.7
+Requires-Dist: hydra-core>=1.3
+Requires-Dist: einops>=0.8
+Requires-Dist: wandb>=0.21
+Requires-Dist: matplotlib>=3.10
+Requires-Dist: scikit-image>=0.25
+Requires-Dist: opencv-python>=4.12
+Requires-Dist: pandas>=2.3
+Requires-Dist: colorlog>=6.9
+Provides-Extra: cuda128
+Requires-Dist: torch==2.7.1+cu128; extra == "cuda128"
+Requires-Dist: torchvision==0.22.1+cu128; extra == "cuda128"
+Requires-Dist: torchaudio==2.7.1+cu128; extra == "cuda128"
+Requires-Dist: torch-geometric==2.6.1; extra == "cuda128"
+Requires-Dist: pyg-lib==0.4.0+pt27cu128; extra == "cuda128"
+Requires-Dist: torch-scatter==2.1.2+pt27cu128; extra == "cuda128"
+Requires-Dist: torch-sparse==0.6.18+pt27cu128; extra == "cuda128"
+Requires-Dist: torch-cluster==1.6.3+pt27cu128; extra == "cuda128"
+Requires-Dist: torch-spline-conv==1.2.2+pt27cu128; extra == "cuda128"
+
+# SR FORGE
+**Super-Resolution Framework for Oriented Restoration and Guided Enhancement**
+
+---
+
+SR FORGE (**S**uper-**R**esolution **F**ramework for **O**riented **R**estoration & **G**uided **E**nhancement) is a unified, modular, and task-driven framework for training and evaluating deep learning models in the field of super-resolution.
+
+## Key Features
+
+- **Structured Workflow**
+  SR FORGE provides an **organized** approach to super-resolution. Every step—from data loading to final evaluation—follows a clear, modular structure.
+
+- **Task-Driven Restoration**
+  Built-in utilities to help fine-tune models for specific tasks or objectives (e.g., OCR, remote sensing, medical imaging).
+
+- **Config-Driven Experiments**
+  Simple YAML/JSON configuration files let you customize your pipeline without modifying code directly.
+
+- **Flexible Model Plug-In**
+  Easily integrate popular SISR models (EDSR, RCAN, ESRGAN, etc.) and MISR models (RAMS, HighRes-net, PIUNET, TR-MISR, MagNAt), or plug in your own custom architecture.
+
+- **Unified Metrics**
+  Evaluate your models with a suite of standard metrics (PSNR, SSIM, LPIPS) and straightforward logging.
+
+- **Visualization Tools**
+  Quickly visualize results (side-by-side comparisons, zoom-ins, or overlays) for interpretability and debugging.
+
+## Installation
+
+1. **Clone the Repository**
+   ```bash
+   git clone https://github.com/your-username/sr-forge.git
srforge-0.1.0/README.md
ADDED
@@ -0,0 +1,32 @@
+# SR FORGE
+**Super-Resolution Framework for Oriented Restoration and Guided Enhancement**
+
+---
+
+SR FORGE (**S**uper-**R**esolution **F**ramework for **O**riented **R**estoration & **G**uided **E**nhancement) is a unified, modular, and task-driven framework for training and evaluating deep learning models in the field of super-resolution.
+
+## Key Features
+
+- **Structured Workflow**
+  SR FORGE provides an **organized** approach to super-resolution. Every step—from data loading to final evaluation—follows a clear, modular structure.
+
+- **Task-Driven Restoration**
+  Built-in utilities to help fine-tune models for specific tasks or objectives (e.g., OCR, remote sensing, medical imaging).
+
+- **Config-Driven Experiments**
+  Simple YAML/JSON configuration files let you customize your pipeline without modifying code directly.
+
+- **Flexible Model Plug-In**
+  Easily integrate popular SISR models (EDSR, RCAN, ESRGAN, etc.) and MISR models (RAMS, HighRes-net, PIUNET, TR-MISR, MagNAt), or plug in your own custom architecture.
+
+- **Unified Metrics**
+  Evaluate your models with a suite of standard metrics (PSNR, SSIM, LPIPS) and straightforward logging.
+
+- **Visualization Tools**
+  Quickly visualize results (side-by-side comparisons, zoom-ins, or overlays) for interpretability and debugging.
+
+## Installation
+
+1. **Clone the Repository**
+   ```bash
+   git clone https://github.com/your-username/sr-forge.git
srforge-0.1.0/pyproject.toml
ADDED
@@ -0,0 +1,77 @@
+[build-system]
+requires = ["setuptools>=69", "wheel"]  # ↑ keep these up-to-date
+build-backend = "setuptools.build_meta"
+
+
+
+[project.urls]
+Homepage = "https://gitlab.com/tarasiewicztomasz/sr-forge"
+Source = "https://gitlab.com/tarasiewicztomasz/sr-forge/-/tree/main"
+
+
+[project]  # PEP 621 metadata
+name = "srforge"  # pip install srforge
+version = "0.1.0"
+description = "Super-resolution research framework for PyTorch with a focus on simplicity and flexibility using config files."
+readme = "README.md"
+requires-python = ">=3.10"
+
+authors = [
+    { name = "Tomasz Tarasiewicz", email = "tarasiewicztomasz@gmail.com" }
+]
+
+dependencies = [
+    "torch>=2.7",  # gets the CPU wheel by default
+    "torchvision>=0.22",
+    "torchaudio>=2.7",
+
+    "torch-geometric>=2.6",
+    "pyg-lib>=0.4",
+    "torch-scatter>=2.1",
+    "torch-sparse>=0.6",
+    "torch-cluster>=1.6",
+    "torch-spline-conv>=1.2",
+
+    "torchdatasets>=0.2",
+    "torchmetrics>=1.7",
+
+    "hydra-core>=1.3",
+    "einops>=0.8",
+    "wandb>=0.21",
+
+    "matplotlib>=3.10",
+    "scikit-image>=0.25",
+    "opencv-python>=4.12",
+    "pandas>=2.3",
+
+    "colorlog>=6.9"
+]
+
+# ────────────────────────────── CUDA 12.8 wheels ───────────────────────────────
+[project.optional-dependencies]
+cuda128 = [
+    "torch==2.7.1+cu128",
+    "torchvision==0.22.1+cu128",
+    "torchaudio==2.7.1+cu128",
+
+    "torch-geometric==2.6.1",
+    "pyg-lib==0.4.0+pt27cu128",
+    "torch-scatter==2.1.2+pt27cu128",
+    "torch-sparse==0.6.18+pt27cu128",
+    "torch-cluster==1.6.3+pt27cu128",
+    "torch-spline-conv==1.2.2+pt27cu128"
+]
+
+# 1) tell setuptools "only package the 'srforge' directory"
+[tool.setuptools]
+packages = ["srforge", "scripts"]  # easiest & explicit
+
+# alternatively you could use the automatic finder:
+# [tool.setuptools.packages.find]
+# include = ["srforge*"]  # but explicit is clearer
+
+
+# 2) expose console-script entry-points (optional)
+[project.scripts]
+train = "scripts.train:main"
+test = "scripts.test:main"
srforge-0.1.0/scripts/__init__.py
ADDED
File without changes
srforge-0.1.0/scripts/test.py
ADDED
@@ -0,0 +1,104 @@
+import os
+import shutil
+from pathlib import Path
+
+from training.observers import ImageProcessor
+
+os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"  # IMPORT CORE AFTER THIS LINE !!!
+import logging
+from utils.logging import configure_logger
+configure_logger(logging.INFO)
+import hydra
+from hydra.core.hydra_config import HydraConfig
+from tqdm import tqdm
+
+import wandb
+import omegaconf
+from core import GlobalSettings
+import core.config.utils
+from core.config.legacy import ConfigParser
+from core.data.loader import DataLoaderFactory
+from core.training import runners, observers
+
+import torch
+import torchdatasets as td
+
+logger = logging.getLogger(__name__)
+
+@hydra.main(config_path="configs", config_name="test-cfg", version_base=None)
+def main(cfg) -> None:
+    core.config.utils.clear_defaults(cfg)
+    # Set global settings
+    GlobalSettings().config = cfg
+    GlobalSettings().output_directory = HydraConfig.get().runtime.output_dir
+
+    configure_logger(cfg.system.debug_level)  # second time because hydra overrides the current global logger configuration
+    out_dir = HydraConfig.get().runtime.output_dir
+    wandb.init(project=cfg.wandb.project, entity=cfg.wandb.entity, name=cfg.wandb.run_name,
+               id=cfg.wandb.run_id, group=cfg.wandb.group_name, notes=cfg.wandb.notes,
+               dir=out_dir, tags=cfg.wandb.tags, job_type='test', resume='never',
+               config=omegaconf.OmegaConf.to_container(cfg, resolve=False, throw_on_missing=True),
+               mode=cfg.wandb.mode, settings=wandb.Settings(x_disable_stats=True))
+
+    print('Configuring benchmark...')
+    print(f'Run ID: {wandb.run.id}')
+    print(f'Run Name: {wandb.run.name}')
+    print(f'Run path: {wandb.run.path}')
+
+    if isinstance(cfg.system.device, (list, tuple, omegaconf.ListConfig)):
+        raise NotImplementedError("Multi-GPU inference is not supported yet.")
+    else:
+        device = torch.device(cfg.system.device)
+        print(f"Using device: {device}")
+
+
+    parser = ConfigParser(cfg, wandb.run)
+    model = parser(cfg.model)
+    model.to(device)
+    metrics = parser(cfg.test_metrics)
+    test_dataset = parser(cfg.dataset)
+    postprocessor = parser(cfg.postprocessing)
+
+    if cfg.system.cache_dir is not None:
+        cache_path = Path(cfg.system.cache_dir) / 'cache' / f"{device.type}_{device.index}"
+        print(cache_path, cache_path.absolute())
+        os.makedirs(cache_path, exist_ok=True)
+        if os.path.isdir(cache_path) and cfg.system.recache:
+            shutil.rmtree(cache_path)
+        test_dataset = test_dataset.cache(td.cachers.Pickle(cache_path / 'test'))
+
+    test_loader = DataLoaderFactory(test_dataset, batch_size=1, shuffle=False,
+                                    num_workers=0, device=cfg.system.device).get_loader()
+
+    out_dir = Path(cfg.system.output_dir)
+    out_dir.mkdir(parents=True, exist_ok=True)
+
+    benchmark_runner = runners.BenchmarkRunner(metrics, device, postprocessor=postprocessor)
+    benchmark_runner.add_observer(observers.ProgressLogger("Testing"))
+    benchmark_runner.add_observer(observers.BatchImageSaver(
+        image_processor=ImageProcessor(65535.0, border=0),
+        output_path=out_dir,
+        dtype='uint16',
+    ))
+    benchmark_runner.add_observer(observers.BatchImageLogger(list(range(0, 50)),
+                                                            observers.ImageProcessor(
+                                                                65535.0,
+                                                                border=3,
+                                                                equalize_mode='clahe'),
+                                                            # ref_key='hr',
+                                                            # img_key='sr',
+                                                            ))
+    benchmark_runner.add_observer(observers.TableLogger(
+        'test_benchmark',
+        'sr',
+        processor=observers.ImageProcessor(65535.0, 'clahe')
+    ))
+
+    benchmark_runner.run_epoch(model=model, data_loader=test_loader, epoch=0)
+    wandb.finish(0)
+
+
+
+
+if __name__ == "__main__":
+    main()
srforge-0.1.0/scripts/train.py
ADDED
@@ -0,0 +1,140 @@
+import os
+import shutil
+from pathlib import Path
+
+os.environ["C10D_TCP_STORE_USE_LIBUV"] = "0"
+os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"  # IMPORT CORE AFTER THIS LINE !!!
+import logging
+from srforge.utils.logging import configure_logger
+configure_logger(logging.INFO)
+import hydra
+from hydra.core.hydra_config import HydraConfig
+
+import wandb
+import omegaconf
+from srforge import GlobalSettings
+import srforge.config.utils
+from srforge.config.legacy import ConfigParser
+from srforge.config import TrainConfig
+from srforge.dataset import Dataset
+from srforge.data.loader import DataLoaderFactory
+from srforge.data import Entry, GraphEntry
+
+from srforge.loss import Loss
+from srforge.training import runners, trainers, observers
+
+
+import torch
+import torch_geometric as tg
+import torchdatasets as td
+
+logger = logging.getLogger(__name__)
+
+@hydra.main(config_path="configs", config_name="train-cfg", version_base=None)
+def main(cfg) -> None:
+    cfg: TrainConfig
+    srforge.config.utils.clear_defaults(cfg)
+
+    # Set global settings
+    GlobalSettings().config = cfg
+    GlobalSettings().output_directory = HydraConfig.get().runtime.output_dir
+
+    configure_logger(cfg.system.debug_level)  # second time because hydra overrides the current global logger configuration
+    out_dir = HydraConfig.get().runtime.output_dir
+    wandb.init(project=cfg.wandb.project, entity=cfg.wandb.entity, name=cfg.wandb.run_name,
+               id=cfg.wandb.run_id, group=cfg.wandb.group_name, notes=cfg.wandb.notes,
+               dir=out_dir, tags=cfg.wandb.tags, job_type='training', resume='allow',
+               config=omegaconf.OmegaConf.to_container(cfg, resolve=False, throw_on_missing=True),
+               mode=cfg.wandb.mode)
+    print('Configuring training...')
+    print(f'Run ID: {wandb.run.id}')
+    print(f'Run Name: {wandb.run.name}')
+    print(f'Run path: {wandb.run.path}')
+    if wandb.run.resumed:
+        print(f'Training resumed: {wandb.run.resumed}\n')
+
+    multigpu_training = False
+    if isinstance(cfg.system.device, (list, tuple, omegaconf.ListConfig)):
+        device = torch.device(cfg.system.device[0])
+        if len(cfg.system.device) > 1:
+            multigpu_training = True
+            print(f"Using devices: {[torch.device(x) for x in cfg.system.device]}")
+    else:
+        device = torch.device(cfg.system.device)
+        print(f"Using device: {device}")
+
+
+    parser = ConfigParser(cfg, wandb.run)
+    initial_epoch, best_losses = parser.get_training_data()
+    model = parser.get_model()
+    model.to(device)
+    optimizer = parser.get_optimizer()
+    lr_scheduler = parser.get_lr_scheduler()
+    loss: Loss = parser(cfg.loss)
+    metrics = parser(cfg.validation_metrics)
+    train_dataset: Dataset = parser(cfg.dataset.training)
+    valid_dataset = parser(cfg.dataset.validation)
+    train_postprocessor = parser(cfg.postprocessing.training)
+    valid_postprocessor = parser(cfg.postprocessing.validation)
+
+    # train_dataset = train_dataset.take(12)
+    # valid_dataset = valid_dataset.take(12)
+
+    if cfg.system.cache_dir is not None:
+        cache_path = Path(cfg.system.cache_dir) / 'cache' / f"{device.type}_{device.index}"
+        print(cache_path, cache_path.absolute())
+        os.makedirs(cache_path, exist_ok=True)
+        if os.path.isdir(cache_path) and cfg.system.recache:
+            shutil.rmtree(cache_path)
+        train_dataset = train_dataset.cache(td.cachers.Pickle(cache_path / 'train'))
+        valid_dataset = valid_dataset.cache(td.cachers.Pickle(cache_path / 'val'))
+
+    train_loader = DataLoaderFactory(train_dataset, batch_size=cfg.training.batch_size, shuffle=True,
+                                     num_workers=cfg.training.num_workers, device=cfg.system.device,
+                                     pin_memory_device=str(device), pin_memory=True).get_loader()
+    val_loader = DataLoaderFactory(valid_dataset, batch_size=1, shuffle=False,
+                                   num_workers=0, device=cfg.system.device,
+                                   pin_memory_device=str(device), pin_memory=True).get_loader()
+
+
+
+    if multigpu_training:
+        first_element = train_dataset[0]
+        if isinstance(first_element, Entry):
+            model = torch.nn.DataParallel(model, device_ids=[torch.device(x).index for x in cfg.system.device])
+        elif isinstance(first_element, GraphEntry):
+            model = tg.nn.DataParallel(model, device_ids=[torch.device(x).index for x in cfg.system.device])
+
+    else:
+        device = torch.device(f'cuda:{cfg.system.device}')
+
+
+    wandb.watch(model if not isinstance(model, (tg.nn.DataParallel, torch.nn.DataParallel)) else model.module,
+                log='all', log_graph=True, log_freq=len(train_loader), criterion=loss)
+    print('Training started...')
+    train_runner = runners.TrainingEpochRunner(optimizer, loss,
+                                               postprocessor=train_postprocessor,
+                                               mixed_precision=cfg.system.mixed_precision,
+                                               device=device,
+                                               gradient_accumulation_steps=cfg.training.gradient_accumulation_steps)
+    train_runner.add_observer(observers.ProgressLogger("T"))
+    val_runner = runners.ValidationEpochRunner(loss_fn=metrics,
+                                               postprocessor=valid_postprocessor,
+                                               mixed_precision=cfg.system.mixed_precision,
+                                               device=device)
+    val_runner.add_observer(observers.ProgressLogger("V"))
+    val_runner.add_observer(observers.LogitsLogger(batch_id=0, ref_key='target_shifts', img_key='dynamic_filters',
+                                                   log_dir='dynamic_filters'))
+    trainer = trainers.PyTorchTrainer(model, train_runner, val_runner, scheduler=lr_scheduler,
+                                      initial_epoch=initial_epoch)
+    # trainer.add_observer(observers.LossScheduler(loss_fn))
+    trainer.add_observer(observers.LossLogger())
+    trainer.add_observer(observers.PyTorchModelSaver(best_loss=best_losses['total'] if best_losses is not None else None))
+    trainer.train(cfg.training.epochs, train_loader, val_loader)
+    wandb.finish(0)
+
+
+
+
+if __name__ == "__main__":
+    main()
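The device-selection branch at the top of `main()` reduces to a small pure function: a list/tuple of devices enables multi-GPU mode and its first entry becomes the primary device. A minimal standalone sketch (plain strings instead of `torch.device`; the helper name `resolve_devices` is hypothetical, not part of srforge):

```python
# Hypothetical helper mirroring train.py's device handling: a list/tuple of
# device specs enables multi-GPU training; a single spec stays single-device.
def resolve_devices(device_cfg):
    if isinstance(device_cfg, (list, tuple)):
        primary = device_cfg[0]          # first entry hosts the model
        multigpu = len(device_cfg) > 1   # more than one entry -> DataParallel path
    else:
        primary = device_cfg
        multigpu = False
    return primary, multigpu

print(resolve_devices(["cuda:0", "cuda:1"]))  # → ('cuda:0', True)
print(resolve_devices("cuda:0"))              # → ('cuda:0', False)
```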
srforge-0.1.0/setup.cfg
ADDED
srforge-0.1.0/srforge/__init__.py
ADDED
@@ -0,0 +1,23 @@
+from srforge.registry import ClassRegistry
+import logging
+logger = logging.getLogger(__name__)
+
+
+class Singleton(type):
+    _instances = {}
+    def __call__(cls, *args, **kwargs):
+        if cls not in cls._instances:
+            cls._instances[cls] = super().__call__(*args, **kwargs)
+        return cls._instances[cls]
+
+class GlobalSettings(metaclass=Singleton):
+    def __init__(self):
+        self.debug_mode = False
+        self.output_directory = 'output'
+        self.config = None
+
+    def __setattr__(self, key, value):
+        # logger.info(f"GlobalSettings: setting {key} to {value}")
+        super().__setattr__(key, value)
+
+ClassRegistry().discover(__name__)
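The `Singleton` metaclass above caches one instance per class, so every `GlobalSettings()` call returns the same object and `__init__` runs only on the first call. A standalone sketch of that behavior (without the srforge imports):

```python
# Standalone sketch of the Singleton metaclass used in srforge/__init__.py.
class Singleton(type):
    _instances = {}

    def __call__(cls, *args, **kwargs):
        # First call constructs and caches the instance; later calls skip
        # super().__call__, so __init__ is never re-run.
        if cls not in cls._instances:
            cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]


class GlobalSettings(metaclass=Singleton):
    def __init__(self):
        self.debug_mode = False
        self.output_directory = 'output'
        self.config = None


a = GlobalSettings()
a.output_directory = '/tmp/run'
b = GlobalSettings()                 # same object, state is preserved
print(a is b, b.output_directory)    # → True /tmp/run
```

This is why scripts can write `GlobalSettings().config = cfg` once and read it back anywhere else in the process.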
srforge-0.1.0/srforge/registry.py
ADDED
@@ -0,0 +1,150 @@
+import logging
+logger = logging.getLogger(__name__)
+import threading
+from typing import Type, Union
+import pkgutil
+import importlib
+from pathlib import Path
+import json
+import inspect
+import ast
+
+
+
+class ClassRegistry:
+    _instance = None
+    _lock = threading.Lock()
+    CLASS_REGISTRY_FILENAME = "class_registry.json"
+
+    # ────────────────────────────────────────────────────────── singleton ──
+    def __new__(cls, *args, **kwargs):
+        with cls._lock:
+            if cls._instance is None:
+                cls._instance = super().__new__(cls)
+                cls._instance.classes = {}         # {name: (module, qualname)}
+                cls._instance._discovered = set()  # module names already walked
+        return cls._instance
+
+    # ──────────────────────────────────────────────── public API ──────────
+    def discover(self, root: Union[str, Path]):
+        """
+        Walk every .py file inside *root* and auto-register any class decorated
+        with @register_class. *root* may be:
+
+        * a dotted import name – "srforge" or "srforge.transform"
+        * a Path object – Path("src/srforge")
+        """
+        root_path, root_pkg = self._to_path_and_package(root)
+        if root_pkg in self._discovered:
+            return  # already done
+
+        self._walk_source_tree(root_path, root_pkg)
+        self._discovered.add(root_pkg)
+        self._save_metadata(Path(__file__).parent.parent / "metadata")
+
+    def register(self, name: str, cls: Type):
+        """Register a class by the *public* name it should be looked up with."""
+        if any((m == cls.__module__ and n == cls.__name__)
+               for m, n in self.classes.values()):
+            return  # identical class already known
+        self.classes[name] = (cls.__module__, cls.__name__)
+
+    # ──────────────────────────────────────────── internals ───────────────
+    @staticmethod
+    def _to_path_and_package(root: Union[str, Path]):
+        """Return (Path_to_folder, canonical_package_name)."""
+        if isinstance(root, Path):
+            # Convert src/srforge/... → srforge...
+            # The package name is the last folder that *contains* __init__.py
+            for parent in reversed(root.resolve().parents):
+                if (parent / "__init__.py").exists():
+                    pkg = ".".join(root.resolve().relative_to(parent).parts)
+                    full_pkg = f"{parent.name}.{pkg}" if pkg else parent.name
+                    return root.resolve(), full_pkg
+            raise RuntimeError(f"Cannot derive package name from {root}")
+
+        # dotted import name: find its spec
+        spec = importlib.util.find_spec(root)
+        if spec is None or spec.origin is None:
+            raise ImportError(f"Cannot import '{root}'")
+        if spec.submodule_search_locations:  # it's a package
+            folder = Path(next(iter(spec.submodule_search_locations))).resolve()
+        else:  # it's a single module file
+            folder = Path(spec.origin).resolve().parent
+        return folder, spec.name  # spec.name is the canonical pkg
+
+    # --------------------------------------------------------------------
+    def _walk_source_tree(self, folder: Path, package_root: str):
+        """
+        Parse every .py file underneath *folder* and register classes that have
+        a @register_class decorator.
+        """
+        for py_file in folder.rglob("*.py"):
+            try:
+                tree = ast.parse(py_file.read_text("utf-8"), filename=str(py_file))
+            except SyntaxError:
+                continue
+
+            rel_mod = ".".join(py_file.relative_to(folder).with_suffix("").parts)
+            module_name = f"{package_root}.{rel_mod}" if rel_mod else package_root
+
+            for node in tree.body:
+                if not isinstance(node, ast.ClassDef):
+                    continue
+                for deco in node.decorator_list:
+                    # @register_class        → ast.Name
+                    # @register_class("foo") → ast.Call(func=ast.Name)
+                    if (isinstance(deco, ast.Name) and deco.id == "register_class") \
+                            or (isinstance(deco, ast.Call) and
+                                isinstance(deco.func, ast.Name) and deco.func.id == "register_class"):
+                        public = (node.name if isinstance(deco, ast.Name)
+                                  else (deco.args[0].value if deco.args else node.name))
+                        self.classes.setdefault(public, (module_name, node.name))
+                        break  # stop scanning decorators
+
+    # --------------------------------------------------------------------
+    def _save_metadata(self, path: Path):
+        path.mkdir(parents=True, exist_ok=True)
+        result = dict()
+
+        classes = list(self.classes.keys())
+        classes = sorted(classes, key=lambda x: x.lower())
+        result['by_class'] = {k: self.classes[k][0] for k in classes}
+
+        modules = sorted([self.classes[k][0] for k in self.classes])
+        x = {module: [] for module in modules}
+        for k, v in self.classes.items():
+            x[v[0]].append(k)
+        keys = sorted(x.keys())
+        result['by_module'] = {k: x[k] for k in keys}
+
+        with open(path / self.CLASS_REGISTRY_FILENAME, 'w') as f:
+            json.dump(result, f, indent=4)
+
+    def __contains__(self, name):
+        return self.has(name)
+
+    def get(self, name):
+        return self.classes.get(name)
+
+    def has(self, name):
+        return name in self.classes
+
+    def __repr__(self):
+        # pretty print
+        return f'ClassRegistry({self.classes})'
+
+def register_class(cls_or_name: Union[type, str] = None):
+    """
+    Decorator to register a class with its module path so that when loading from the yaml config file through the
+    _target keyword, the class can be instantiated without the need to specify the full module path.
+    """
+    def wrapper(cls):
+        # Get the full module path
+        class_name = cls.__name__ if isinstance(cls_or_name, type) or cls_or_name is None else cls_or_name
+        ClassRegistry().register(class_name, cls)
+        return cls
+
+    if isinstance(cls_or_name, type):
+        return wrapper(cls_or_name)
+    return wrapper
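The dual-form decorator at the bottom of registry.py (bare `@register_class`, or `@register_class("alias")` with an explicit public name) can be sketched standalone with a plain dict in place of the `ClassRegistry` singleton. The model class names below are illustrative only:

```python
# Standalone sketch of the @register_class decorator pattern from registry.py,
# with a plain dict instead of the ClassRegistry singleton.
from typing import Union

REGISTRY = {}  # {public name: (module, qualname)}

def register_class(cls_or_name: Union[type, str] = None):
    def wrapper(cls):
        # Bare @register_class registers under the class's own name;
        # @register_class("alias") registers under the given public name.
        name = cls.__name__ if isinstance(cls_or_name, type) or cls_or_name is None else cls_or_name
        REGISTRY[name] = (cls.__module__, cls.__name__)
        return cls

    if isinstance(cls_or_name, type):  # used bare: cls_or_name IS the class
        return wrapper(cls_or_name)
    return wrapper                     # used as @register_class(...) factory


@register_class
class EDSR: ...

@register_class("rcan")
class RCAN: ...

print(sorted(REGISTRY))  # → ['EDSR', 'rcan']
```

In srforge itself the registered name is what a config file's `_target`-style lookup resolves, so the full module path never has to appear in YAML.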
srforge-0.1.0/srforge/structs.py
ADDED
@@ -0,0 +1,16 @@
+import collections
+
+AfterEpochDataToSave = collections.namedtuple("DataToSave",
+                                              ["optimizer",
+                                               "lr_scheduler",
+                                               "model",
+                                               "train_loss",
+                                               "val_loss",
+                                               "epoch",
+                                               "scaler"])
+
+LogPoint = collections.namedtuple("LogPoint", ["name", "value", "epoch"])
+
+EpochValidationResults = collections.namedtuple("EpochValidationResults",
+                                                ["epoch", "batch", "output", "input",
+                                                 "batch_metrics", "criterion", "epoch_metrics"])
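These namedtuples act as lightweight record types: named, immutable fields that still unpack like plain tuples. A quick sketch using `LogPoint` (the metric name and values below are made up for illustration):

```python
# Sketch of how the namedtuples in structs.py are consumed.
import collections

LogPoint = collections.namedtuple("LogPoint", ["name", "value", "epoch"])

point = LogPoint(name="val/psnr", value=31.7, epoch=12)
print(point.name, point.value)           # field access by name
name, value, epoch = point               # unpacks like a plain tuple
print(point._replace(epoch=13).epoch)    # → 13 (returns a new tuple; originals are immutable)
```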
srforge-0.1.0/srforge.egg-info/PKG-INFO
ADDED
@@ -0,0 +1,71 @@
(content identical to srforge-0.1.0/PKG-INFO above)
srforge-0.1.0/srforge.egg-info/SOURCES.txt
ADDED
@@ -0,0 +1,14 @@
+README.md
+pyproject.toml
+scripts/__init__.py
+scripts/test.py
+scripts/train.py
+srforge/__init__.py
+srforge/registry.py
+srforge/structs.py
+srforge.egg-info/PKG-INFO
+srforge.egg-info/SOURCES.txt
+srforge.egg-info/dependency_links.txt
+srforge.egg-info/entry_points.txt
+srforge.egg-info/requires.txt
+srforge.egg-info/top_level.txt
srforge-0.1.0/srforge.egg-info/dependency_links.txt
ADDED
@@ -0,0 +1 @@
+
srforge-0.1.0/srforge.egg-info/requires.txt
ADDED
@@ -0,0 +1,30 @@
+torch>=2.7
+torchvision>=0.22
+torchaudio>=2.7
+torch-geometric>=2.6
+pyg-lib>=0.4
+torch-scatter>=2.1
+torch-sparse>=0.6
+torch-cluster>=1.6
+torch-spline-conv>=1.2
+torchdatasets>=0.2
+torchmetrics>=1.7
+hydra-core>=1.3
+einops>=0.8
+wandb>=0.21
+matplotlib>=3.10
+scikit-image>=0.25
+opencv-python>=4.12
+pandas>=2.3
+colorlog>=6.9
+
+[cuda128]
+torch==2.7.1+cu128
+torchvision==0.22.1+cu128
+torchaudio==2.7.1+cu128
+torch-geometric==2.6.1
+pyg-lib==0.4.0+pt27cu128
+torch-scatter==2.1.2+pt27cu128
+torch-sparse==0.6.18+pt27cu128
+torch-cluster==1.6.3+pt27cu128
+torch-spline-conv==1.2.2+pt27cu128