srforge 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
srforge-0.1.0/PKG-INFO ADDED
@@ -0,0 +1,71 @@
1
+ Metadata-Version: 2.4
2
+ Name: srforge
3
+ Version: 0.1.0
4
+ Summary: Super‑resolution research framework for PyTorch with a focus on simplicity and flexibility using config files.
5
+ Author-email: Tomasz Tarasiewicz <tarasiewicztomasz@gmail.com>
6
+ Project-URL: Homepage, https://gitlab.com/tarasiewicztomasz/sr-forge
7
+ Project-URL: Source, https://gitlab.com/tarasiewicztomasz/sr-forge/-/tree/main
8
+ Requires-Python: >=3.10
9
+ Description-Content-Type: text/markdown
10
+ Requires-Dist: torch>=2.7
11
+ Requires-Dist: torchvision>=0.22
12
+ Requires-Dist: torchaudio>=2.7
13
+ Requires-Dist: torch-geometric>=2.6
14
+ Requires-Dist: pyg-lib>=0.4
15
+ Requires-Dist: torch-scatter>=2.1
16
+ Requires-Dist: torch-sparse>=0.6
17
+ Requires-Dist: torch-cluster>=1.6
18
+ Requires-Dist: torch-spline-conv>=1.2
19
+ Requires-Dist: torchdatasets>=0.2
20
+ Requires-Dist: torchmetrics>=1.7
21
+ Requires-Dist: hydra-core>=1.3
22
+ Requires-Dist: einops>=0.8
23
+ Requires-Dist: wandb>=0.21
24
+ Requires-Dist: matplotlib>=3.10
25
+ Requires-Dist: scikit-image>=0.25
26
+ Requires-Dist: opencv-python>=4.12
27
+ Requires-Dist: pandas>=2.3
28
+ Requires-Dist: colorlog>=6.9
29
+ Provides-Extra: cuda128
30
+ Requires-Dist: torch==2.7.1+cu128; extra == "cuda128"
31
+ Requires-Dist: torchvision==0.22.1+cu128; extra == "cuda128"
32
+ Requires-Dist: torchaudio==2.7.1+cu128; extra == "cuda128"
33
+ Requires-Dist: torch-geometric==2.6.1; extra == "cuda128"
34
+ Requires-Dist: pyg-lib==0.4.0+pt27cu128; extra == "cuda128"
35
+ Requires-Dist: torch-scatter==2.1.2+pt27cu128; extra == "cuda128"
36
+ Requires-Dist: torch-sparse==0.6.18+pt27cu128; extra == "cuda128"
37
+ Requires-Dist: torch-cluster==1.6.3+pt27cu128; extra == "cuda128"
38
+ Requires-Dist: torch-spline-conv==1.2.2+pt27cu128; extra == "cuda128"
39
+
40
+ # SR FORGE
41
+ **Super-Resolution Framework for Oriented Restoration and Guided Enhancement**
42
+
43
+ ---
44
+
45
+ SR FORGE (**S**uper-**R**esolution **F**ramework for **O**riented **R**estoration & **G**uided **E**nhancement) is a unified, modular, and task-driven framework for training and evaluating deep learning models in the field of super-resolution.
46
+
47
+ ## Key Features
48
+
49
+ - **Structured Workflow**
50
+ SR FORGE provides an **organized** approach to super resolution. Every step—from data loading to final evaluation—follows a clear, modular structure.
51
+
52
+ - **Task-driven restoration**
53
+ Built-in utilities to help fine-tune models for specific tasks or objectives (e.g., OCR, remote sensing, medical imaging, etc.).
54
+
55
+ - **Config-Driven Experiments**
56
+ Simple YAML/JSON configuration files let you customize your pipeline without modifying code directly.
57
+
58
+ - **Flexible Model Plug-In**
59
+ Easily integrate popular SISR (EDSR, RCAN, ESRGAN, etc.) and MISR (RAMS, HighRes-net, PIUNET, TR-MISR, MagNAt) or your own custom architecture.
60
+
61
+ - **Unified Metrics**
62
+ Evaluate your models with a suite of standard metrics (PSNR, SSIM, LPIPS) and straightforward logging.
63
+
64
+ - **Visualization Tools**
65
+ Quickly visualize results (side-by-side comparisons, zoom-ins, or overlays) for interpretability and debugging.
66
+
67
+ ## Installation
68
+
69
+ 1. **Clone the Repository**
70
+ ```bash
71
+ git clone https://github.com/your-username/sr-forge.git
@@ -0,0 +1,32 @@
1
+ # SR FORGE
2
+ **Super-Resolution Framework for Oriented Restoration and Guided Enhancement**
3
+
4
+ ---
5
+
6
+ SR FORGE (**S**uper-**R**esolution **F**ramework for **O**riented **R**estoration & **G**uided **E**nhancement) is a unified, modular, and task-driven framework for training and evaluating deep learning models in the field of super-resolution.
7
+
8
+ ## Key Features
9
+
10
+ - **Structured Workflow**
11
+ SR FORGE provides an **organized** approach to super-resolution. Every step—from data loading to final evaluation—follows a clear, modular structure.
12
+
13
+ - **Task-Driven Restoration**
14
+ Built-in utilities to help fine-tune models for specific tasks or objectives (e.g., OCR, remote sensing, medical imaging, etc.).
15
+
16
+ - **Config-Driven Experiments**
17
+ Simple YAML/JSON configuration files let you customize your pipeline without modifying code directly.
18
+
19
+ - **Flexible Model Plug-In**
20
+ Easily integrate popular SISR (EDSR, RCAN, ESRGAN, etc.) and MISR (RAMS, HighRes-net, PIUNET, TR-MISR, MagNAt) models, or plug in your own custom architecture.
21
+
22
+ - **Unified Metrics**
23
+ Evaluate your models with a suite of standard metrics (PSNR, SSIM, LPIPS) and straightforward logging.
24
+
25
+ - **Visualization Tools**
26
+ Quickly visualize results (side-by-side comparisons, zoom-ins, or overlays) for interpretability and debugging.
27
+
28
+ ## Installation
29
+
30
+ 1. **Clone the Repository**
31
+ ```bash
32
+ git clone https://gitlab.com/tarasiewicztomasz/sr-forge.git
@@ -0,0 +1,77 @@
1
[build-system]
# Lower bounds for the build backend; raise them as newer setuptools
# releases are adopted.
requires = ["setuptools>=69", "wheel"]
build-backend = "setuptools.build_meta"

# PEP 621 metadata
[project]
name = "srforge"  # pip install srforge
version = "0.1.0"
description = "Super‑resolution research framework for PyTorch with a focus on simplicity and flexibility using config files."
readme = "README.md"
requires-python = ">=3.10"
authors = [
    { name = "Tomasz Tarasiewicz", email = "tarasiewicztomasz@gmail.com" },
]
dependencies = [
    # PyTorch core: resolves to whatever wheel the configured index serves
    # for the platform. Use the `cuda128` extra for pinned CUDA 12.8 builds.
    "torch>=2.7",
    "torchvision>=0.22",
    "torchaudio>=2.7",

    # PyTorch Geometric and its compiled companions
    "torch-geometric>=2.6",
    "pyg-lib>=0.4",
    "torch-scatter>=2.1",
    "torch-sparse>=0.6",
    "torch-cluster>=1.6",
    "torch-spline-conv>=1.2",

    # Data pipeline and metrics
    "torchdatasets>=0.2",
    "torchmetrics>=1.7",

    # Configuration and experiment tracking
    "hydra-core>=1.3",
    "einops>=0.8",
    "wandb>=0.21",

    # Imaging, plotting and tabular utilities
    "matplotlib>=3.10",
    "scikit-image>=0.25",
    "opencv-python>=4.12",
    "pandas>=2.3",

    "colorlog>=6.9",
]

[project.urls]
Homepage = "https://gitlab.com/tarasiewicztomasz/sr-forge"
Source = "https://gitlab.com/tarasiewicztomasz/sr-forge/-/tree/main"

# CUDA 12.8 wheels.
# NOTE(review): the `+cu128` / `+pt27cu128` pins are local versions, which
# presumably are only served by the PyTorch and PyG wheel indexes; installing
# this extra likely needs `--extra-index-url` / `--find-links`. Confirm and
# document the exact command in the README.
[project.optional-dependencies]
cuda128 = [
    "torch==2.7.1+cu128",
    "torchvision==0.22.1+cu128",
    "torchaudio==2.7.1+cu128",

    "torch-geometric==2.6.1",
    "pyg-lib==0.4.0+pt27cu128",
    "torch-scatter==2.1.2+pt27cu128",
    "torch-sparse==0.6.18+pt27cu128",
    "torch-cluster==1.6.3+pt27cu128",
    "torch-spline-conv==1.2.2+pt27cu128",
]

# Console-script entry points.
# `train` and `test` are kept for backward compatibility. A bare `test` is
# shadowed by the shell builtin of the same name, so namespaced aliases are
# provided as well.
[project.scripts]
train = "scripts.train:main"
test = "scripts.test:main"
srforge-train = "scripts.train:main"
srforge-test = "scripts.test:main"

# Package discovery.
# An explicit `packages = ["srforge", "scripts"]` list is NOT recursive, so
# subpackages such as `srforge.config` or `srforge.training` were left out of
# the built distribution. The finder below includes both top-level packages
# together with all of their subpackages.
# NOTE(review): the scripts load Hydra configs from a `configs` directory next
# to them; non-Python files like those need package-data / MANIFEST handling
# to ship with the distribution. Verify against the repository layout.
[tool.setuptools.packages.find]
include = ["srforge*", "scripts*"]
File without changes
@@ -0,0 +1,104 @@
1
"""Benchmark entry point: builds model, metrics and dataset from a Hydra config and runs one test epoch.

All project imports go through the installed ``srforge`` package (the same
paths ``scripts/train.py`` uses). The previous top-level ``core``,
``training`` and ``utils`` modules are not part of this distribution.
"""
import os
import shutil
from pathlib import Path

os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"  # IMPORT CORE AFTER THIS LINE !!!
import logging
from srforge.utils.logging import configure_logger
configure_logger(logging.INFO)
import hydra
from hydra.core.hydra_config import HydraConfig
from tqdm import tqdm

import wandb
import omegaconf
from srforge import GlobalSettings
import srforge.config.utils
from srforge.config.legacy import ConfigParser
from srforge.data.loader import DataLoaderFactory
from srforge.training import runners, observers
from srforge.training.observers import ImageProcessor

import torch
import torchdatasets as td

logger = logging.getLogger(__name__)


@hydra.main(config_path="configs", config_name="test-cfg", version_base=None)
def main(cfg) -> None:
    """Run a single benchmark pass described by the Hydra config ``cfg``.

    Steps, in order: publish ``cfg`` and the Hydra output directory through
    ``GlobalSettings``, start a W&B run with ``job_type='test'``, build the
    model, metrics, dataset and postprocessor via ``ConfigParser``, optionally
    wrap the dataset in an on-disk pickle cache, attach observers to a
    ``BenchmarkRunner`` and execute one epoch.

    Raises:
        NotImplementedError: if ``cfg.system.device`` is a list/tuple of devices.
    """
    srforge.config.utils.clear_defaults(cfg)

    # Set global settings
    hydra_out_dir = HydraConfig.get().runtime.output_dir
    GlobalSettings().config = cfg
    GlobalSettings().output_directory = hydra_out_dir

    # Second call on purpose: Hydra overrides the global logger configuration.
    configure_logger(cfg.system.debug_level)
    wandb.init(project=cfg.wandb.project, entity=cfg.wandb.entity, name=cfg.wandb.run_name,
               id=cfg.wandb.run_id, group=cfg.wandb.group_name, notes=cfg.wandb.notes,
               dir=hydra_out_dir, tags=cfg.wandb.tags, job_type='test', resume='never',
               config=omegaconf.OmegaConf.to_container(cfg, resolve=False, throw_on_missing=True),
               mode=cfg.wandb.mode, settings=wandb.Settings(x_disable_stats=True))

    print('Configuring benchmark...')
    print(f'Run ID: {wandb.run.id}')
    print(f'Run Name: {wandb.run.name}')
    print(f'Run path: {wandb.run.path}')

    if isinstance(cfg.system.device, (list, tuple, omegaconf.ListConfig)):
        raise NotImplementedError("Multi-GPU inference is not supported yet.")
    device = torch.device(cfg.system.device)
    print(f"Using device: {device}")

    parser = ConfigParser(cfg, wandb.run)
    model = parser(cfg.model)
    model.to(device)
    metrics = parser(cfg.test_metrics)
    test_dataset = parser(cfg.dataset)
    postprocessor = parser(cfg.postprocessing)

    if cfg.system.cache_dir is not None:
        cache_path = Path(cfg.system.cache_dir) / 'cache' / f"{device.type}_{device.index}"
        print(cache_path, cache_path.absolute())
        # Drop a stale cache first, then (re)create the directory, so it
        # always exists when the cacher is constructed.
        if cfg.system.recache and os.path.isdir(cache_path):
            shutil.rmtree(cache_path)
        os.makedirs(cache_path, exist_ok=True)
        test_dataset = test_dataset.cache(td.cachers.Pickle(cache_path / 'test'))

    test_loader = DataLoaderFactory(test_dataset, batch_size=1, shuffle=False,
                                    num_workers=0, device=cfg.system.device).get_loader()

    # Saved images go to the configured output directory, not the Hydra run directory.
    out_dir = Path(cfg.system.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    benchmark_runner = runners.BenchmarkRunner(metrics, device, postprocessor=postprocessor)
    benchmark_runner.add_observer(observers.ProgressLogger("Testing"))
    benchmark_runner.add_observer(observers.BatchImageSaver(
        image_processor=ImageProcessor(65535.0, border=0),
        output_path=out_dir,
        dtype='uint16',
    ))
    benchmark_runner.add_observer(observers.BatchImageLogger(
        list(range(0, 50)),
        observers.ImageProcessor(65535.0, border=3, equalize_mode='clahe'),
        # ref_key='hr',
        # img_key='sr',
    ))
    # NOTE(review): 'clahe' is passed positionally here, while the other call
    # sites pass it as equalize_mode= and use border= for the next argument.
    # Confirm the second positional parameter of ImageProcessor.
    benchmark_runner.add_observer(observers.TableLogger(
        'test_benchmark',
        'sr',
        processor=observers.ImageProcessor(65535.0, 'clahe')
    ))

    benchmark_runner.run_epoch(model=model, data_loader=test_loader, epoch=0)
    wandb.finish(0)


if __name__ == "__main__":
    main()
@@ -0,0 +1,140 @@
1
+ import os
2
+ import shutil
3
+ from pathlib import Path
4
+
5
+ os.environ["C10D_TCP_STORE_USE_LIBUV"] = "0"
6
+ os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" # IMPORT CORE AFTER THIS LINE !!!
7
+ import logging
8
+ from srforge.utils.logging import configure_logger
9
+ configure_logger(logging.INFO)
10
+ import hydra
11
+ from hydra.core.hydra_config import HydraConfig
12
+
13
+ import wandb
14
+ import omegaconf
15
+ from srforge import GlobalSettings
16
+ import srforge.config.utils
17
+ from srforge.config.legacy import ConfigParser
18
+ from srforge.config import TrainConfig
19
+ from srforge.dataset import Dataset
20
+ from srforge.data.loader import DataLoaderFactory
21
+ from srforge.data import Entry, GraphEntry
22
+
23
+ from srforge.loss import Loss
24
+ from srforge.training import runners, trainers, observers
25
+
26
+
27
+ import torch
28
+ import torch_geometric as tg
29
+ import torchdatasets as td
30
+
31
+ logger = logging.getLogger(__name__)
32
+
33
@hydra.main(config_path="configs", config_name="train-cfg", version_base=None)
def main(cfg) -> None:
    """Train a model as described by the Hydra config ``cfg``.

    Steps, in order: publish ``cfg`` and the Hydra output directory through
    ``GlobalSettings``, start (or resume) a W&B run, resolve the target
    device(s), build model / optimizer / scheduler / loss / metrics / datasets
    / postprocessors via ``ConfigParser``, optionally wrap both datasets in an
    on-disk pickle cache, create the data loaders, wrap the model in a
    ``DataParallel`` variant when several devices are configured, and hand
    everything to ``PyTorchTrainer``.
    """
    cfg: TrainConfig
    srforge.config.utils.clear_defaults(cfg)

    # Set global settings
    GlobalSettings().config = cfg
    GlobalSettings().output_directory = HydraConfig.get().runtime.output_dir

    # Second call on purpose: Hydra overrides the global logger configuration.
    configure_logger(cfg.system.debug_level)
    out_dir = HydraConfig.get().runtime.output_dir
    wandb.init(project=cfg.wandb.project, entity=cfg.wandb.entity, name=cfg.wandb.run_name,
               id=cfg.wandb.run_id, group=cfg.wandb.group_name, notes=cfg.wandb.notes,
               dir=out_dir, tags=cfg.wandb.tags, job_type='training', resume='allow',
               config=omegaconf.OmegaConf.to_container(cfg, resolve=False, throw_on_missing=True),
               mode=cfg.wandb.mode)
    print('Configuring training...')
    print(f'Run ID: {wandb.run.id}')
    print(f'Run Name: {wandb.run.name}')
    print(f'Run path: {wandb.run.path}')
    if wandb.run.resumed:
        print(f'Training resumed: {wandb.run.resumed}\n')

    # Resolve the primary device once. For a list of devices the first entry
    # is the primary one and more than one entry enables multi-GPU training.
    multigpu_training = False
    if isinstance(cfg.system.device, (list, tuple, omegaconf.ListConfig)):
        device = torch.device(cfg.system.device[0])
        if len(cfg.system.device) > 1:
            multigpu_training = True
        print(f"Using devices: {[torch.device(x) for x in cfg.system.device]}")
    else:
        device = torch.device(cfg.system.device)
        print(f"Using device: {device}")

    parser = ConfigParser(cfg, wandb.run)
    initial_epoch, best_losses = parser.get_training_data()
    model = parser.get_model()
    model.to(device)
    optimizer = parser.get_optimizer()
    lr_scheduler = parser.get_lr_scheduler()
    loss: Loss = parser(cfg.loss)
    metrics = parser(cfg.validation_metrics)
    train_dataset: Dataset = parser(cfg.dataset.training)
    valid_dataset = parser(cfg.dataset.validation)
    train_postprocessor = parser(cfg.postprocessing.training)
    valid_postprocessor = parser(cfg.postprocessing.validation)

    # train_dataset = train_dataset.take(12)
    # valid_dataset = valid_dataset.take(12)

    if cfg.system.cache_dir is not None:
        cache_path = Path(cfg.system.cache_dir) / 'cache' / f"{device.type}_{device.index}"
        print(cache_path, cache_path.absolute())
        # Drop a stale cache first, then (re)create the directory, so it
        # always exists when the cachers are constructed.
        if cfg.system.recache and os.path.isdir(cache_path):
            shutil.rmtree(cache_path)
        os.makedirs(cache_path, exist_ok=True)
        train_dataset = train_dataset.cache(td.cachers.Pickle(cache_path / 'train'))
        valid_dataset = valid_dataset.cache(td.cachers.Pickle(cache_path / 'val'))

    train_loader = DataLoaderFactory(train_dataset, batch_size=cfg.training.batch_size, shuffle=True,
                                     num_workers=cfg.training.num_workers, device=cfg.system.device,
                                     pin_memory_device=str(device), pin_memory=True).get_loader()
    val_loader = DataLoaderFactory(valid_dataset, batch_size=1, shuffle=False,
                                   num_workers=0, device=cfg.system.device,
                                   pin_memory_device=str(device), pin_memory=True).get_loader()

    if multigpu_training:
        # The wrapper is chosen from the type of the first training sample.
        device_ids = [torch.device(x).index for x in cfg.system.device]
        first_element = train_dataset[0]
        if isinstance(first_element, Entry):
            model = torch.nn.DataParallel(model, device_ids=device_ids)
        elif isinstance(first_element, GraphEntry):
            model = tg.nn.DataParallel(model, device_ids=device_ids)
        else:
            logger.warning(
                "Multi-GPU training requested but dataset items are of type %s; "
                "the model is left unwrapped and runs on %s only.",
                type(first_element).__name__, device)
    # `device` is deliberately not recomputed for the single-device case.
    # Rebuilding it as f'cuda:{cfg.system.device}' yields an invalid device
    # string for values such as "cpu" or "cuda:0", and it was already
    # resolved from the config above.

    wandb.watch(model if not isinstance(model, (tg.nn.DataParallel, torch.nn.DataParallel)) else model.module,
                log='all', log_graph=True, log_freq=len(train_loader), criterion=loss)
    print('Training started...')
    train_runner = runners.TrainingEpochRunner(optimizer, loss,
                                               postprocessor=train_postprocessor,
                                               mixed_precision=cfg.system.mixed_precision,
                                               device=device,
                                               gradient_accumulation_steps=cfg.training.gradient_accumulation_steps)
    train_runner.add_observer(observers.ProgressLogger("T"))
    val_runner = runners.ValidationEpochRunner(loss_fn=metrics,
                                               postprocessor=valid_postprocessor,
                                               mixed_precision=cfg.system.mixed_precision,
                                               device=device)
    val_runner.add_observer(observers.ProgressLogger("V"))
    val_runner.add_observer(observers.LogitsLogger(batch_id=0, ref_key='target_shifts', img_key='dynamic_filters',
                                                   log_dir='dynamic_filters'))
    trainer = trainers.PyTorchTrainer(model, train_runner, val_runner, scheduler=lr_scheduler,
                                      initial_epoch=initial_epoch,)
    # trainer.add_observer(observers.LossScheduler(loss_fn))
    trainer.add_observer(observers.LossLogger())
    trainer.add_observer(observers.PyTorchModelSaver(best_loss=best_losses['total'] if best_losses is not None else None))
    trainer.train(cfg.training.epochs, train_loader, val_loader)
    wandb.finish(0)


if __name__ == "__main__":
    main()
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,23 @@
1
+ from srforge.registry import ClassRegistry
2
+ import logging
3
+ logger = logging.getLogger(__name__)
4
+
5
+
6
class Singleton(type):
    """Metaclass that creates each class at most once and then reuses that instance.

    The first call ``SomeClass(*args, **kwargs)`` constructs the object and
    stores it. Every later call returns the stored object and ignores the
    arguments it was given.
    """

    # Shared by every class that uses this metaclass: {class: its sole instance}.
    _instances = {}

    def __call__(cls, *args, **kwargs):
        cache = cls._instances
        if cls in cache:
            return cache[cls]
        instance = super().__call__(*args, **kwargs)
        cache[cls] = instance
        return instance
12
+
13
class GlobalSettings(metaclass=Singleton):
    """Process-wide settings holder; every ``GlobalSettings()`` call yields the same object.

    Attributes set on construction:
        debug_mode: starts as ``False``.
        output_directory: starts as ``'output'``.
        config: starts as ``None``.
    """

    def __init__(self):
        initial_values = (
            ("debug_mode", False),
            ("output_directory", 'output'),
            ("config", None),
        )
        for attribute, default in initial_values:
            setattr(self, attribute, default)

    def __setattr__(self, key, value):
        # Single interception point for every attribute write; currently it
        # only forwards to the default behaviour (logging here is disabled).
        super().__setattr__(key, value)
22
+
23
+ ClassRegistry().discover(__name__)
@@ -0,0 +1,150 @@
1
+ import logging
2
+ logger = logging.getLogger(__name__)
3
+ import threading
4
+ from typing import Type, Union
5
+ import pkgutil
6
+ import importlib
7
+ from pathlib import Path
8
+ import json
9
+ import inspect
10
+ import ast
11
+
12
+
13
+
14
class ClassRegistry:
    """Process-wide singleton mapping a public class name to ``(module, class_name)``.

    Entries come from two sources:

    * :meth:`discover` parses source files with ``ast`` (nothing is imported)
      and records every top-level class decorated with ``@register_class``;
    * :meth:`register` is called with a real class object.

    ``classes`` is the ``{public_name: (module, class_name)}`` mapping and
    ``_discovered`` holds the package names that were already walked.
    """

    _instance = None
    _lock = threading.Lock()
    CLASS_REGISTRY_FILENAME = "class_registry.json"

    # ------------------------------------------------------------ singleton
    def __new__(cls, *args, **kwargs):
        with cls._lock:
            if cls._instance is None:
                cls._instance = super().__new__(cls)
                cls._instance.classes = {}         # {name: (module, qualname)}
                cls._instance._discovered = set()  # module names already walked
        return cls._instance

    # ----------------------------------------------------------- public API
    def discover(self, root: Union[str, Path]):
        """Register every ``@register_class``-decorated class found under *root*.

        *root* may be:

        * a dotted import name - ``"srforge"`` or ``"srforge.transform"``
        * a Path object - ``Path("src/srforge")``

        A root that was already processed is skipped. After walking, the
        registry is dumped as JSON into a ``metadata`` directory next to this
        module's parent directory. That dump is best effort: an ``OSError``
        (for example a read-only installation) is logged and does not
        interrupt discovery.
        """
        root_path, root_pkg = self._to_path_and_package(root)
        if root_pkg in self._discovered:
            return  # already done

        self._walk_source_tree(root_path, root_pkg)
        self._discovered.add(root_pkg)
        try:
            self._save_metadata(Path(__file__).parent.parent / "metadata")
        except OSError as exc:
            # discover() runs at package import time; a failed write of the
            # informational dump must not make the import itself fail.
            logging.getLogger(__name__).warning(
                "Could not write class registry metadata: %s", exc)

    def register(self, name: str, cls: Type):
        """Register a class by the *public* name it should be looked up with.

        Nothing happens when the same ``(module, class name)`` pair is already
        present under any public name.
        """
        target = (cls.__module__, cls.__name__)
        if any(known == target for known in self.classes.values()):
            return  # identical class already known
        self.classes[name] = target

    # ------------------------------------------------------------ internals
    @staticmethod
    def _to_path_and_package(root: Union[str, Path]):
        """Return ``(Path_to_folder, canonical_package_name)`` for *root*.

        Raises:
            RuntimeError: a ``Path`` root from which no package name can be derived.
            ImportError: a dotted name with no import spec or no origin.
        """
        if isinstance(root, Path):
            resolved = root.resolve()
            # Convert src/srforge/... -> srforge...
            # Ancestors are scanned from the filesystem root downwards; the
            # first one containing __init__.py is taken as the top-level package.
            for parent in reversed(resolved.parents):
                if (parent / "__init__.py").exists():
                    pkg = ".".join(resolved.relative_to(parent).parts)
                    full_pkg = f"{parent.name}.{pkg}" if pkg else parent.name
                    return resolved, full_pkg
            # No ancestor is a package: *root* itself may be the top-level
            # package (the documented ``Path("src/srforge")`` case).
            if (resolved / "__init__.py").exists():
                return resolved, resolved.name
            raise RuntimeError(f"Cannot derive package name from {root}")

        # dotted import name: find its spec
        import importlib.util  # `import importlib` alone does not guarantee the submodule is loaded

        spec = importlib.util.find_spec(root)
        if spec is None or spec.origin is None:
            raise ImportError(f"Cannot import '{root}'")
        if spec.submodule_search_locations:  # it's a package
            folder = Path(next(iter(spec.submodule_search_locations))).resolve()
        else:                                # it's a single module file
            folder = Path(spec.origin).resolve().parent
        return folder, spec.name             # spec.name is the canonical pkg

    @staticmethod
    def _public_name(deco, node):
        """Return the public name *deco* registers class *node* under, or ``None``.

        ``None`` means *deco* is not a ``register_class`` decorator, or its
        name argument is not a literal that can be read without importing.
        """
        if isinstance(deco, ast.Name):
            # @register_class
            return node.name if deco.id == "register_class" else None
        if not (isinstance(deco, ast.Call)
                and isinstance(deco.func, ast.Name)
                and deco.func.id == "register_class"):
            return None
        # @register_class("foo") / @register_class(cls_or_name="foo") / @register_class()
        if deco.args:
            explicit = deco.args[0]
        else:
            explicit = next((kw.value for kw in deco.keywords if kw.arg == "cls_or_name"), None)
            if explicit is None:
                return node.name
        if isinstance(explicit, ast.Constant):
            if isinstance(explicit.value, str):
                return explicit.value
            if explicit.value is None:
                return node.name
        return None

    def _walk_source_tree(self, folder: Path, package_root: str):
        """
        Parse every .py file underneath *folder* and register top-level
        classes that carry a ``register_class`` decorator.

        Files that cannot be decoded as UTF-8 or parsed are skipped. Existing
        entries are never overwritten (first one wins). Files are visited in
        sorted order so the winner of a name clash does not depend on
        filesystem enumeration order.
        """
        for py_file in sorted(folder.rglob("*.py")):
            try:
                tree = ast.parse(py_file.read_text(encoding="utf-8"), filename=str(py_file))
            except (SyntaxError, ValueError):
                # ValueError also covers UnicodeDecodeError.
                continue

            parts = py_file.relative_to(folder).with_suffix("").parts
            if parts and parts[-1] == "__init__":
                # pkg/__init__.py is the module "pkg", not "pkg.__init__".
                parts = parts[:-1]
            rel_mod = ".".join(parts)
            module_name = f"{package_root}.{rel_mod}" if rel_mod else package_root

            for node in tree.body:
                if not isinstance(node, ast.ClassDef):
                    continue
                for deco in node.decorator_list:
                    public = self._public_name(deco, node)
                    if public is not None:
                        self.classes.setdefault(public, (module_name, node.name))
                        break  # stop scanning decorators

    def _save_metadata(self, path: Path):
        """Write the registry to ``path / CLASS_REGISTRY_FILENAME`` as JSON.

        The file has two objects: ``by_class`` (``{public_name: module}``,
        names sorted case-insensitively) and ``by_module``
        (``{module: [public_name, ...]}``, modules sorted).
        """
        path.mkdir(parents=True, exist_ok=True)

        ordered_names = sorted(self.classes, key=lambda x: x.lower())
        by_class = {name: self.classes[name][0] for name in ordered_names}

        grouped = {}
        for name, entry in self.classes.items():
            grouped.setdefault(entry[0], []).append(name)
        # A plain dict: a stray trailing comma here used to wrap this mapping
        # in a 1-tuple, which was serialised as a JSON array.
        by_module = {module: grouped[module] for module in sorted(grouped)}

        result = {'by_class': by_class, 'by_module': by_module}
        with open(path / self.CLASS_REGISTRY_FILENAME, 'w', encoding='utf-8') as f:
            json.dump(result, f, indent=4)

    def __contains__(self, name):
        return self.has(name)

    def get(self, name):
        """Return the ``(module, class_name)`` entry for *name*, or ``None``."""
        return self.classes.get(name)

    def has(self, name):
        """Return whether *name* is a registered public name."""
        return name in self.classes

    def __repr__(self):
        return f'ClassRegistry({self.classes})'
136
+
137
def register_class(cls_or_name: Union[type, str] = None):
    """
    Class decorator that records a class in ``ClassRegistry`` so that a yaml config can refer to it through the
    _target keyword by name alone, without spelling out the full module path.

    Usable as ``@register_class``, ``@register_class()`` or ``@register_class("public_name")``. The class is
    registered under its own ``__name__`` unless a string name is supplied. The decorated class is returned
    unchanged.
    """
    applied_directly = isinstance(cls_or_name, type)

    def wrapper(cls):
        # Use the explicit string when one was given; otherwise the class's own name.
        if applied_directly or cls_or_name is None:
            class_name = cls.__name__
        else:
            class_name = cls_or_name
        ClassRegistry().register(class_name, cls)
        return cls

    # Bare usage hands us the class itself, so decorate it right away.
    return wrapper(cls_or_name) if applied_directly else wrapper
@@ -0,0 +1,16 @@
1
+ import collections
2
+
3
# Record of per-epoch training state. The typename matches the module-level
# binding so that the repr is accurate and instances can be pickled (pickle
# resolves the class by looking up its name in this module).
AfterEpochDataToSave = collections.namedtuple(
    "AfterEpochDataToSave",
    [
        "optimizer",
        "lr_scheduler",
        "model",
        "train_loss",
        "val_loss",
        "epoch",
        "scaler",
    ],
)

# A single named value tied to an epoch.
LogPoint = collections.namedtuple("LogPoint", ["name", "value", "epoch"])

# Per-batch validation record together with batch- and epoch-level metrics.
EpochValidationResults = collections.namedtuple(
    "EpochValidationResults",
    [
        "epoch",
        "batch",
        "output",
        "input",
        "batch_metrics",
        "criterion",
        "epoch_metrics",
    ],
)
@@ -0,0 +1,71 @@
1
+ Metadata-Version: 2.4
2
+ Name: srforge
3
+ Version: 0.1.0
4
+ Summary: Super‑resolution research framework for PyTorch with a focus on simplicity and flexibility using config files.
5
+ Author-email: Tomasz Tarasiewicz <tarasiewicztomasz@gmail.com>
6
+ Project-URL: Homepage, https://gitlab.com/tarasiewicztomasz/sr-forge
7
+ Project-URL: Source, https://gitlab.com/tarasiewicztomasz/sr-forge/-/tree/main
8
+ Requires-Python: >=3.10
9
+ Description-Content-Type: text/markdown
10
+ Requires-Dist: torch>=2.7
11
+ Requires-Dist: torchvision>=0.22
12
+ Requires-Dist: torchaudio>=2.7
13
+ Requires-Dist: torch-geometric>=2.6
14
+ Requires-Dist: pyg-lib>=0.4
15
+ Requires-Dist: torch-scatter>=2.1
16
+ Requires-Dist: torch-sparse>=0.6
17
+ Requires-Dist: torch-cluster>=1.6
18
+ Requires-Dist: torch-spline-conv>=1.2
19
+ Requires-Dist: torchdatasets>=0.2
20
+ Requires-Dist: torchmetrics>=1.7
21
+ Requires-Dist: hydra-core>=1.3
22
+ Requires-Dist: einops>=0.8
23
+ Requires-Dist: wandb>=0.21
24
+ Requires-Dist: matplotlib>=3.10
25
+ Requires-Dist: scikit-image>=0.25
26
+ Requires-Dist: opencv-python>=4.12
27
+ Requires-Dist: pandas>=2.3
28
+ Requires-Dist: colorlog>=6.9
29
+ Provides-Extra: cuda128
30
+ Requires-Dist: torch==2.7.1+cu128; extra == "cuda128"
31
+ Requires-Dist: torchvision==0.22.1+cu128; extra == "cuda128"
32
+ Requires-Dist: torchaudio==2.7.1+cu128; extra == "cuda128"
33
+ Requires-Dist: torch-geometric==2.6.1; extra == "cuda128"
34
+ Requires-Dist: pyg-lib==0.4.0+pt27cu128; extra == "cuda128"
35
+ Requires-Dist: torch-scatter==2.1.2+pt27cu128; extra == "cuda128"
36
+ Requires-Dist: torch-sparse==0.6.18+pt27cu128; extra == "cuda128"
37
+ Requires-Dist: torch-cluster==1.6.3+pt27cu128; extra == "cuda128"
38
+ Requires-Dist: torch-spline-conv==1.2.2+pt27cu128; extra == "cuda128"
39
+
40
+ # SR FORGE
41
+ **Super-Resolution Framework for Oriented Restoration and Guided Enhancement**
42
+
43
+ ---
44
+
45
+ SR FORGE (**S**uper-**R**esolution **F**ramework for **O**riented **R**estoration & **G**uided **E**nhancement) is a unified, modular, and task-driven framework for training and evaluating deep learning models in the field of super-resolution.
46
+
47
+ ## Key Features
48
+
49
+ - **Structured Workflow**
50
+ SR FORGE provides an **organized** approach to super resolution. Every step—from data loading to final evaluation—follows a clear, modular structure.
51
+
52
+ - **Task-driven restoration**
53
+ Built-in utilities to help fine-tune models for specific tasks or objectives (e.g., OCR, remote sensing, medical imaging, etc.).
54
+
55
+ - **Config-Driven Experiments**
56
+ Simple YAML/JSON configuration files let you customize your pipeline without modifying code directly.
57
+
58
+ - **Flexible Model Plug-In**
59
+ Easily integrate popular SISR (EDSR, RCAN, ESRGAN, etc.) and MISR (RAMS, HighRes-net, PIUNET, TR-MISR, MagNAt) or your own custom architecture.
60
+
61
+ - **Unified Metrics**
62
+ Evaluate your models with a suite of standard metrics (PSNR, SSIM, LPIPS) and straightforward logging.
63
+
64
+ - **Visualization Tools**
65
+ Quickly visualize results (side-by-side comparisons, zoom-ins, or overlays) for interpretability and debugging.
66
+
67
+ ## Installation
68
+
69
+ 1. **Clone the Repository**
70
+ ```bash
71
+ git clone https://github.com/your-username/sr-forge.git
@@ -0,0 +1,14 @@
1
+ README.md
2
+ pyproject.toml
3
+ scripts/__init__.py
4
+ scripts/test.py
5
+ scripts/train.py
6
+ srforge/__init__.py
7
+ srforge/registry.py
8
+ srforge/structs.py
9
+ srforge.egg-info/PKG-INFO
10
+ srforge.egg-info/SOURCES.txt
11
+ srforge.egg-info/dependency_links.txt
12
+ srforge.egg-info/entry_points.txt
13
+ srforge.egg-info/requires.txt
14
+ srforge.egg-info/top_level.txt
@@ -0,0 +1,3 @@
1
+ [console_scripts]
2
+ test = scripts.test:main
3
+ train = scripts.train:main
@@ -0,0 +1,30 @@
1
+ torch>=2.7
2
+ torchvision>=0.22
3
+ torchaudio>=2.7
4
+ torch-geometric>=2.6
5
+ pyg-lib>=0.4
6
+ torch-scatter>=2.1
7
+ torch-sparse>=0.6
8
+ torch-cluster>=1.6
9
+ torch-spline-conv>=1.2
10
+ torchdatasets>=0.2
11
+ torchmetrics>=1.7
12
+ hydra-core>=1.3
13
+ einops>=0.8
14
+ wandb>=0.21
15
+ matplotlib>=3.10
16
+ scikit-image>=0.25
17
+ opencv-python>=4.12
18
+ pandas>=2.3
19
+ colorlog>=6.9
20
+
21
+ [cuda128]
22
+ torch==2.7.1+cu128
23
+ torchvision==0.22.1+cu128
24
+ torchaudio==2.7.1+cu128
25
+ torch-geometric==2.6.1
26
+ pyg-lib==0.4.0+pt27cu128
27
+ torch-scatter==2.1.2+pt27cu128
28
+ torch-sparse==0.6.18+pt27cu128
29
+ torch-cluster==1.6.3+pt27cu128
30
+ torch-spline-conv==1.2.2+pt27cu128
@@ -0,0 +1,2 @@
1
+ scripts
2
+ srforge