euler-train 1.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- euler_train/__init__.py +83 -0
- euler_train/architecture.py +125 -0
- euler_train/environment.py +137 -0
- euler_train/git_info.py +43 -0
- euler_train/outputs.py +194 -0
- euler_train/run.py +1249 -0
- euler_train/serialization.py +86 -0
- euler_train/slurm.py +27 -0
- euler_train-1.3.1.dist-info/METADATA +22 -0
- euler_train-1.3.1.dist-info/RECORD +12 -0
- euler_train-1.3.1.dist-info/WHEEL +5 -0
- euler_train-1.3.1.dist-info/top_level.txt +1 -0
euler_train/__init__.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
"""runlog — lightweight file-based experiment logging."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from .run import Run
|
|
5
|
+
from .architecture import export_architecture
|
|
6
|
+
|
|
7
|
+
__all__ = ["init", "Run", "export_architecture"]
# Must match the version in the distribution metadata (euler_train-1.3.1.dist-info);
# the previous value "0.1.0" was stale.
__version__ = "1.3.1"
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def init(
    dir: str | None = None,
    config=None,
    meta: dict | None = None,
    output_formats: dict[str, str] | None = None,
    run_id: str | None = None,
    datasets: dict | None = None,
    run_name: str | None = None,
    evaluations: dict[str, dict] | None = None,
    mode: str | None = None,
) -> Run:
    """Open a run — new or resumed — and hand back its :class:`Run` handle.

    Every argument is forwarded by keyword, unchanged, to :class:`Run`.

    Parameters
    ----------
    dir:
        Project / output directory.  Each call produces a unique run at
        ``{dir}/runs/{timestamp_id}/``.  With the default *None* the
        directory is resolved in this order:

        1. ``$ET_HOME/<project>`` (when ``$ET_HOME`` is set),
        2. ``~/euler_train/<project>``,

        ``<project>`` being the git repository name, or the name of the
        current working directory outside of a git repo.
    config:
        Hyperparameters.  May be a *dict*, a path to a JSON / YAML file,
        an ``argparse.Namespace``, or a dataclass instance.
    meta:
        Additional user-defined fields merged into ``meta.json``
        (e.g. ``{"description": "baseline", "tags": ["v2"]}``).
    output_formats:
        Overrides for the auto-inferred save formats.  A key is either an
        output type (``"depth"``), a slot / aux name
        (``"transmission"``), or a dotted combination
        (``"depth.pred"``).  A value is ``"png"``, ``"npy"``, or
        ``"npz"``.
    run_id:
        When given, an existing run is resumed rather than a new one
        created.  The run directory ``{dir}/runs/{run_id}/`` must already
        exist.  Its ``config.json`` is loaded automatically unless
        *config* is passed explicitly to override it.
    datasets:
        Optional mapping from split name to a
        ``euler_loading.MultiModalDataset`` instance
        (e.g. ``{"train": train_ds, "val": val_ds}``).  Each split is
        logged into ``meta.json`` under ``datasets[split]`` with
        per-modality records: ``path`` plus inferred metadata
        (``used_as``, ``slot``, ``modality_type``).  Hierarchical
        modalities additionally carry ``hierarchy_scope`` and
        ``applies_to``.  A dataset that implements
        ``describe_for_runlog()`` has that contract used directly;
        otherwise inference prefers ``ds-crawler`` config properties when
        available and then falls back to naming-based heuristics.
    run_name:
        Optional human-readable run name, stored in ``meta.json``.
    evaluations:
        Optional mapping from evaluation key to evaluation entry.  An
        entry may contain ``datasets`` (the same dataset objects accepted
        by *datasets*), ``name``, ``status``, ``checkpoint``, and
        ``metadata``.  Typically supplied when resuming a run (via
        *run_id*) for evaluation.  See also :meth:`Run.add_evaluation`.
    mode:
        Optional label for the current process context (for example
        ``"train"``, ``"val"``, or ``"eval"``).  When set, lifecycle
        fields and crash details are mirrored into ``meta.json`` under
        ``modes[mode]``.
    """
    run_kwargs = {
        "dir": dir,
        "config": config,
        "meta": meta,
        "output_formats": output_formats,
        "run_id": run_id,
        "datasets": datasets,
        "run_name": run_name,
        "evaluations": evaluations,
        "mode": mode,
    }
    return Run(**run_kwargs)
|
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
"""Export a PyTorch model to a lightweight ONNX graph for Netron visualization."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
import logging
|
|
5
|
+
import tempfile
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
log = logging.getLogger("euler_train")
|
|
10
|
+
|
|
11
|
+
_MISSING_DEPS_MSG = (
|
|
12
|
+
"Architecture export requires optional dependencies: onnx, onnxruntime, onnxsim. "
|
|
13
|
+
"Install them with: pip install euler-train[architecture]"
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def export_architecture(
    model: Any,
    dummy_input: Any,
    output_path: str | Path = "architecture.onnx",
) -> Path:
    """Export a PyTorch model to a simplified, weightless ONNX graph.

    The resulting file is meant for visual inspection in Netron:
    redundant nodes are removed, operator fusions are applied, and
    weight tensors are stripped so only the graph topology remains.

    Parameters
    ----------
    model:
        A PyTorch ``nn.Module``.  The whole module tree is switched to
        eval mode for the export; afterwards the ``training`` flag of
        *every* submodule is put back to what it was, so submodules that
        were deliberately left in eval mode stay in eval mode.
    dummy_input:
        Example input tensor(s) matching the model's forward signature.
    output_path:
        Where to write the final ``.onnx`` file.  Missing parent
        directories are created.

    Returns
    -------
    Path
        The written output path.

    Raises
    ------
    ImportError
        If ``onnx``, ``onnxruntime`` or ``onnxsim`` cannot be imported
        (or if ``torch`` itself is missing).
    """
    import torch

    try:
        import onnx
        import onnxruntime as ort
        from onnxsim import simplify
    except ImportError as exc:
        # Chain explicitly so the traceback still names the missing package.
        raise ImportError(_MISSING_DEPS_MSG) from exc

    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Snapshot per-module flags: ``model.train()`` would force *all*
    # submodules into training mode, including ones that were in eval mode
    # before the export.
    training_flags = [(module, module.training) for module in model.modules()]
    model.eval()

    try:
        return _do_export(model, dummy_input, output_path, onnx, ort, simplify, torch)
    finally:
        for module, was_training in training_flags:
            module.training = was_training
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def _do_export(model, dummy_input, output_path, onnx, ort, simplify, torch):
    """Run the export pipeline and write the weightless graph to *output_path*.

    The already-imported ``onnx``, ``onnxruntime`` (*ort*), ``onnxsim.simplify``
    (*simplify*) and ``torch`` objects are passed in by the caller.
    Intermediate files live in a temporary directory that is removed when
    the function returns.  Returns *output_path*.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        raw_path = Path(tmpdir) / "raw.onnx"
        ort_path = Path(tmpdir) / "ort.onnx"

        # Step 1: Export to ONNX (with weights, needed for optimizer passes)
        _export_onnx(model, dummy_input, raw_path, torch=torch)

        # Step 2: Simplify — removes redundant glue nodes
        log.info("Simplifying ONNX graph with onnxsim...")
        raw_model = onnx.load(str(raw_path))
        simplified_model, check = simplify(raw_model)
        if not check:
            log.warning("onnxsim validation failed, continuing anyway.")
        onnx.save(simplified_model, str(raw_path))

        # Step 3: ORT optimization — fuses standard blocks (Conv+BN+ReLU, etc.)
        log.info("Applying ONNX Runtime graph optimizations...")
        sess_options = ort.SessionOptions()
        sess_options.graph_optimization_level = (
            ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        )
        sess_options.optimized_model_filepath = str(ort_path)
        # The session is created only for its side effect of writing the
        # optimized model to ``ort_path``.  Name the CPU provider explicitly:
        # onnxruntime builds that ship several execution providers refuse to
        # construct a session without an explicit ``providers`` list.
        ort.InferenceSession(
            str(raw_path),
            sess_options,
            providers=["CPUExecutionProvider"],
        )

        # Step 4: Strip weights for a lightweight file
        log.info("Stripping weights...")
        fused_model = onnx.load(str(ort_path))
        del fused_model.graph.initializer[:]

        onnx.save(fused_model, str(output_path))

    log.info("Architecture exported to %s", output_path)
    return output_path
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def _export_onnx(model: Any, dummy_input: Any, path: Path, *, torch: Any) -> None:
    """Write *model* to *path* as ONNX, preferring the dynamo exporter.

    ``torch.onnx.dynamo_export`` is attempted when the attribute exists;
    any exception it raises is logged as a warning and the legacy
    ``torch.onnx.export`` (opset 14, constant folding on) is used instead.
    """
    dynamo_available = hasattr(torch.onnx, "dynamo_export")
    if dynamo_available:
        # The dynamo path yields a cleaner functional graph when it works.
        try:
            log.info("Exporting ONNX via torch.onnx.dynamo_export (PyTorch 2.x)...")
            torch.onnx.dynamo_export(model, dummy_input).save(str(path))
        except Exception as exc:
            log.warning(
                "dynamo_export failed (%s), falling back to legacy export.", exc
            )
        else:
            return

    # Legacy exporter: reached when dynamo is unavailable or has failed.
    log.info("Exporting ONNX via torch.onnx.export (legacy)...")
    legacy_options = {
        "export_params": True,
        "opset_version": 14,
        "do_constant_folding": True,
    }
    torch.onnx.export(model, dummy_input, str(path), **legacy_options)
|
|
@@ -0,0 +1,137 @@
|
|
|
1
|
+
"""Collect runtime environment metadata for run_environment.json."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
import os
|
|
5
|
+
import platform
|
|
6
|
+
import subprocess
|
|
7
|
+
import sys
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def get_run_environment() -> dict:
    """Assemble the record for the run_environments schema.

    Host name, Python version, CUDA version, GPU type / count and the
    installed-package snapshot are probed via the module's helpers.  The
    ``docker_image``, ``docker_digest`` and ``metadata`` fields are always
    ``None``.
    """
    env: dict = {}
    env["name"] = _hostname()
    env["python_version"] = sys.version.split()[0]
    env["cuda_version"] = _cuda_version()
    env["gpu_type"] = _gpu_type()
    env["gpu_count"] = _gpu_count()
    env["packages_snapshot"] = _packages_snapshot()
    # Not collected here; kept so the record has every schema field.
    for unset_field in ("docker_image", "docker_digest", "metadata"):
        env[unset_field] = None
    return env
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _hostname() -> str | None:
|
|
26
|
+
try:
|
|
27
|
+
return platform.node() or None
|
|
28
|
+
except Exception:
|
|
29
|
+
return None
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _cuda_version() -> str | None:
|
|
33
|
+
# 1. torch.version.cuda
|
|
34
|
+
try:
|
|
35
|
+
import torch
|
|
36
|
+
if torch.version.cuda:
|
|
37
|
+
return torch.version.cuda
|
|
38
|
+
except Exception:
|
|
39
|
+
pass
|
|
40
|
+
|
|
41
|
+
# 2. nvcc --version
|
|
42
|
+
try:
|
|
43
|
+
result = subprocess.run(
|
|
44
|
+
["nvcc", "--version"],
|
|
45
|
+
capture_output=True, text=True, timeout=10,
|
|
46
|
+
)
|
|
47
|
+
if result.returncode == 0:
|
|
48
|
+
for line in result.stdout.splitlines():
|
|
49
|
+
if "release" in line.lower():
|
|
50
|
+
# e.g. "Cuda compilation tools, release 12.1, V12.1.66"
|
|
51
|
+
parts = line.split("release")[-1].strip().split(",")
|
|
52
|
+
return parts[0].strip()
|
|
53
|
+
except Exception:
|
|
54
|
+
pass
|
|
55
|
+
|
|
56
|
+
# 3. CUDA_VERSION env var
|
|
57
|
+
return os.environ.get("CUDA_VERSION")
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def _gpu_type() -> str | None:
|
|
61
|
+
# 1. pynvml
|
|
62
|
+
try:
|
|
63
|
+
import pynvml
|
|
64
|
+
pynvml.nvmlInit()
|
|
65
|
+
handle = pynvml.nvmlDeviceGetHandleByIndex(0)
|
|
66
|
+
name = pynvml.nvmlDeviceGetName(handle)
|
|
67
|
+
if isinstance(name, bytes):
|
|
68
|
+
name = name.decode()
|
|
69
|
+
return name
|
|
70
|
+
except Exception:
|
|
71
|
+
pass
|
|
72
|
+
|
|
73
|
+
# 2. nvidia-smi
|
|
74
|
+
try:
|
|
75
|
+
result = subprocess.run(
|
|
76
|
+
["nvidia-smi", "--query-gpu=gpu_name", "--format=csv,noheader"],
|
|
77
|
+
capture_output=True, text=True, timeout=10,
|
|
78
|
+
)
|
|
79
|
+
if result.returncode == 0:
|
|
80
|
+
first_line = result.stdout.strip().splitlines()[0].strip()
|
|
81
|
+
return first_line or None
|
|
82
|
+
except Exception:
|
|
83
|
+
pass
|
|
84
|
+
|
|
85
|
+
return None
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def _gpu_count() -> int | None:
|
|
89
|
+
# 1. torch.cuda.device_count()
|
|
90
|
+
try:
|
|
91
|
+
import torch
|
|
92
|
+
count = torch.cuda.device_count()
|
|
93
|
+
if count > 0:
|
|
94
|
+
return count
|
|
95
|
+
except Exception:
|
|
96
|
+
pass
|
|
97
|
+
|
|
98
|
+
# 2. pynvml
|
|
99
|
+
try:
|
|
100
|
+
import pynvml
|
|
101
|
+
pynvml.nvmlInit()
|
|
102
|
+
return pynvml.nvmlDeviceGetCount()
|
|
103
|
+
except Exception:
|
|
104
|
+
pass
|
|
105
|
+
|
|
106
|
+
# 3. CUDA_VISIBLE_DEVICES
|
|
107
|
+
visible = os.environ.get("CUDA_VISIBLE_DEVICES")
|
|
108
|
+
if visible is not None:
|
|
109
|
+
devices = [d.strip() for d in visible.split(",") if d.strip()]
|
|
110
|
+
if devices:
|
|
111
|
+
return len(devices)
|
|
112
|
+
|
|
113
|
+
return None
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def _packages_snapshot() -> dict[str, str] | None:
    """Return ``{package: version}`` from the first freeze command that works.

    Commands are tried in order: ``uv pip freeze``, then pip run through
    the *current* interpreter (``sys.executable -m pip freeze``), then
    whatever ``pip`` is on ``PATH``.  A command counts as working when it
    exits with status 0 and prints something.  Returns ``None`` when every
    command fails.
    """
    candidates = (
        ["uv", "pip", "freeze"],
        # A bare ``pip`` on PATH may belong to a different environment than
        # the interpreter running this code, so ask this interpreter first.
        [sys.executable, "-m", "pip", "freeze"],
        ["pip", "freeze"],
    )
    for cmd in candidates:
        try:
            result = subprocess.run(
                cmd, capture_output=True, text=True, timeout=30,
            )
            if result.returncode == 0 and result.stdout.strip():
                return _parse_freeze(result.stdout)
        except Exception:
            # Missing executable, timeout, unusable sys.executable, ...:
            # best effort, move on to the next candidate.
            continue
    return None
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def _parse_freeze(output: str) -> dict[str, str]:
|
|
131
|
+
packages = {}
|
|
132
|
+
for line in output.strip().splitlines():
|
|
133
|
+
line = line.strip()
|
|
134
|
+
if "==" in line:
|
|
135
|
+
name, _, version = line.partition("==")
|
|
136
|
+
packages[name.strip()] = version.strip()
|
|
137
|
+
return packages
|
euler_train/git_info.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
"""Collect git repository metadata for code_ref.json."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
import subprocess
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def get_code_ref() -> dict:
    """Assemble the record for the code_refs schema.

    Each string field is the trimmed output of one git command, or
    ``None`` when that command fails or prints nothing.
    """
    ref: dict = {}
    ref["repo_url"] = _git("config", "--get", "remote.origin.url")
    ref["branch"] = _git("rev-parse", "--abbrev-ref", "HEAD")
    ref["commit_sha"] = _git("rev-parse", "HEAD")
    ref["is_dirty"] = _is_dirty()
    ref["dirty_diff"] = _dirty_diff()
    # %B is the raw commit message, %aI the author date in strict ISO 8601.
    ref["commit_message"] = _git("log", "-1", "--format=%B")
    ref["committed_at"] = _git("log", "-1", "--format=%aI")
    return ref
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _git(*args: str) -> str | None:
|
|
21
|
+
try:
|
|
22
|
+
result = subprocess.run(
|
|
23
|
+
["git", *args],
|
|
24
|
+
capture_output=True,
|
|
25
|
+
text=True,
|
|
26
|
+
timeout=10,
|
|
27
|
+
)
|
|
28
|
+
if result.returncode != 0:
|
|
29
|
+
return None
|
|
30
|
+
return result.stdout.strip() or None
|
|
31
|
+
except Exception:
|
|
32
|
+
return None
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _is_dirty() -> bool:
    """Report whether ``git status --porcelain`` printed anything.

    ``False`` is also returned when the git command itself fails, since
    the helper yields ``None`` in that case.
    """
    return _git("status", "--porcelain") is not None
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _dirty_diff() -> str | None:
    """Return the output of ``git diff HEAD`` for a dirty tree, else ``None``."""
    if _is_dirty():
        return _git("diff", "HEAD")
    return None
|
euler_train/outputs.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
1
|
+
"""Save prediction / ground-truth / auxiliary outputs to disk."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
# ---------------------------------------------------------------------------
|
|
11
|
+
# Public entry point
|
|
12
|
+
# ---------------------------------------------------------------------------
|
|
13
|
+
|
|
14
|
+
def save_output_tree(
    type_dir: Path,
    slots: dict[str, Any],
    format_overrides: dict[str, str],
    output_type: str,
) -> None:
    """Write every slot of one *output_type* beneath *type_dir*.

    Each slot is saved in ``type_dir/<slot_name>``.  An ``"aux"`` slot
    holding a dict is expanded: each of its entries goes to
    ``type_dir/aux/<aux_name>``.  Slots and aux entries whose value is
    ``None`` are skipped.

    *slots* example::

        {
            "pred": array,
            "gt": array,
            "aux": {"transmission": array, "attention": array},
        }
    """
    for slot_name, payload in slots.items():
        if payload is None:
            continue

        is_aux_mapping = slot_name == "aux" and isinstance(payload, dict)
        if not is_aux_mapping:
            _save_slot(
                type_dir / slot_name, payload,
                output_type, slot_name, format_overrides,
            )
            continue

        for aux_name, aux_payload in payload.items():
            if aux_payload is not None:
                _save_slot(
                    type_dir / "aux" / aux_name, aux_payload,
                    output_type, aux_name, format_overrides,
                )
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
# ---------------------------------------------------------------------------
|
|
49
|
+
# Internal helpers
|
|
50
|
+
# ---------------------------------------------------------------------------
|
|
51
|
+
|
|
52
|
+
def _save_slot(
    slot_dir: Path,
    data: Any,
    output_type: str,
    leaf_name: str,
    format_overrides: dict[str, str],
) -> None:
    """Save the items of one slot as ``0000.<fmt>``, ``0001.<fmt>``, ...

    *slot_dir* is created if needed.  The format is resolved per item, so
    one slot can contain files with different extensions.
    """
    slot_dir.mkdir(parents=True, exist_ok=True)
    for position, entry in enumerate(_unpack(data)):
        extension = _resolve_format(entry, output_type, leaf_name, format_overrides)
        target = slot_dir / f"{position:04d}.{extension}"
        _save_item(target, entry, extension)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
# ---- normalisation -------------------------------------------------------
|
|
67
|
+
|
|
68
|
+
def _unpack(data: Any) -> list:
    """Turn *data* into a flat list of saveable items.

    A list / tuple is prepared element by element.  Anything else is
    prepared as a whole; if that yields a 4-D numpy array it is treated as
    a batch and split along its first axis, otherwise it becomes a
    one-element list.
    """
    if isinstance(data, (list, tuple)):
        return list(map(_prepare, data))

    single = _prepare(data)
    is_batch = isinstance(single, np.ndarray) and single.ndim == 4
    if not is_batch:
        return [single]
    # 4-D numpy → one item per entry of the leading (batch) axis
    return list(single)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def _prepare(data: Any) -> Any:
|
|
80
|
+
"""Convert torch tensors → numpy; pass PIL images through unchanged."""
|
|
81
|
+
# PIL Image — return as-is
|
|
82
|
+
try:
|
|
83
|
+
from PIL import Image as _PIL
|
|
84
|
+
if isinstance(data, _PIL.Image):
|
|
85
|
+
return data
|
|
86
|
+
except ImportError:
|
|
87
|
+
pass
|
|
88
|
+
|
|
89
|
+
# torch Tensor → numpy, channels-first → channels-last
|
|
90
|
+
if hasattr(data, "detach"):
|
|
91
|
+
arr: np.ndarray = data.detach().cpu().numpy()
|
|
92
|
+
# (C, H, W) → (H, W, C) when C looks like a channel dim
|
|
93
|
+
if (
|
|
94
|
+
arr.ndim == 3
|
|
95
|
+
and arr.shape[0] in (1, 3, 4)
|
|
96
|
+
and min(arr.shape[1:]) > 4
|
|
97
|
+
):
|
|
98
|
+
arr = np.transpose(arr, (1, 2, 0))
|
|
99
|
+
# (B, C, H, W) → (B, H, W, C)
|
|
100
|
+
elif (
|
|
101
|
+
arr.ndim == 4
|
|
102
|
+
and arr.shape[1] in (1, 3, 4)
|
|
103
|
+
and min(arr.shape[2:]) > 4
|
|
104
|
+
):
|
|
105
|
+
arr = np.transpose(arr, (0, 2, 3, 1))
|
|
106
|
+
return arr
|
|
107
|
+
|
|
108
|
+
return np.asarray(data)
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
# ---- format inference ----------------------------------------------------
|
|
112
|
+
|
|
113
|
+
def _is_image_like(arr: np.ndarray) -> bool:
|
|
114
|
+
"""Heuristic: does this array look like it should be saved as a PNG?"""
|
|
115
|
+
if arr.ndim == 2 and arr.dtype == np.uint8:
|
|
116
|
+
return True # grayscale uint8
|
|
117
|
+
if arr.ndim == 3 and arr.shape[2] in (1, 3, 4):
|
|
118
|
+
return True # HxWx{1,3,4}
|
|
119
|
+
return False
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def _resolve_format(
|
|
123
|
+
item: Any,
|
|
124
|
+
output_type: str,
|
|
125
|
+
leaf_name: str,
|
|
126
|
+
overrides: dict[str, str],
|
|
127
|
+
) -> str:
|
|
128
|
+
"""Pick save format: check overrides (most-specific first), then infer."""
|
|
129
|
+
# "rgb.pred" > "rgb" > "pred"
|
|
130
|
+
specific = f"{output_type}.{leaf_name}"
|
|
131
|
+
if specific in overrides:
|
|
132
|
+
return overrides[specific]
|
|
133
|
+
if output_type in overrides:
|
|
134
|
+
return overrides[output_type]
|
|
135
|
+
if leaf_name in overrides:
|
|
136
|
+
return overrides[leaf_name]
|
|
137
|
+
|
|
138
|
+
# PIL Image
|
|
139
|
+
try:
|
|
140
|
+
from PIL import Image as _PIL
|
|
141
|
+
if isinstance(item, _PIL.Image):
|
|
142
|
+
return "png"
|
|
143
|
+
except ImportError:
|
|
144
|
+
pass
|
|
145
|
+
|
|
146
|
+
if isinstance(item, np.ndarray) and _is_image_like(item):
|
|
147
|
+
return "png"
|
|
148
|
+
return "npy"
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
# ---- writers -------------------------------------------------------------
|
|
152
|
+
|
|
153
|
+
def _save_item(path: Path, item: Any, fmt: str) -> None:
|
|
154
|
+
if fmt == "png":
|
|
155
|
+
_save_png(path, item)
|
|
156
|
+
elif fmt == "npy":
|
|
157
|
+
np.save(str(path), item if isinstance(item, np.ndarray) else np.asarray(item))
|
|
158
|
+
elif fmt == "npz":
|
|
159
|
+
np.savez_compressed(
|
|
160
|
+
str(path),
|
|
161
|
+
data=item if isinstance(item, np.ndarray) else np.asarray(item),
|
|
162
|
+
)
|
|
163
|
+
else:
|
|
164
|
+
raise ValueError(f"Unsupported format: {fmt!r}")
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def _save_png(path: Path, item: Any) -> None:
    """Write *item* to *path* through Pillow.

    PIL images are saved directly.  Arrays are first brought to
    ``uint8``: floating-point data is clipped to [0, 1] and scaled by 255
    (truncating), any other non-``uint8`` dtype is cast as-is.  2-D arrays
    are written in mode ``L``; 3-D arrays with 1, 3 or 4 channels in mode
    ``L``, ``RGB`` or ``RGBA``.

    Raises
    ------
    ValueError
        For a 3-D array with another channel count, or an array that is
        neither 2-D nor 3-D.
    """
    from PIL import Image

    target = str(path)

    # PIL Image — save directly
    if isinstance(item, Image.Image):
        item.save(target)
        return

    pixels: np.ndarray = item

    # float → [0,1] → uint8
    if np.issubdtype(pixels.dtype, np.floating):
        pixels = (np.clip(pixels, 0.0, 1.0) * 255).astype(np.uint8)
    elif pixels.dtype != np.uint8:
        pixels = pixels.astype(np.uint8)

    if pixels.ndim == 2:
        Image.fromarray(pixels, mode="L").save(target)
        return
    if pixels.ndim != 3:
        raise ValueError(f"Cannot save {pixels.ndim}D array as PNG")

    channels = pixels.shape[2]
    mode = {1: "L", 3: "RGB", 4: "RGBA"}.get(channels)
    if mode is None:
        raise ValueError(f"Cannot save array with {channels} channels as PNG")
    # A single-channel HxWx1 array is squeezed to HxW for mode "L".
    plane = pixels[:, :, 0] if channels == 1 else pixels
    Image.fromarray(plane, mode=mode).save(target)
|