euler-train 1.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,83 @@
1
+ """runlog — lightweight file-based experiment logging."""
2
+ from __future__ import annotations
3
+
4
+ from .run import Run
5
+ from .architecture import export_architecture
6
+
7
+ __all__ = ["init", "Run", "export_architecture"]
8
+ __version__ = "0.1.0"
9
+
10
+
11
def init(
    dir: str | None = None,
    config=None,
    meta: dict | None = None,
    output_formats: dict[str, str] | None = None,
    run_id: str | None = None,
    datasets: dict | None = None,
    run_name: str | None = None,
    evaluations: dict[str, dict] | None = None,
    mode: str | None = None,
) -> Run:
    """Create a new run — or resume an existing one — and return the handle.

    Parameters
    ----------
    dir:
        Project / output directory; each call creates a unique run under
        ``{dir}/runs/{timestamp_id}/``. When *None* (the default), the
        directory resolves to ``$ET_HOME/<project>`` (if ``$ET_HOME`` is
        set), else ``~/euler_train/<project>``, where ``<project>`` is
        the git repository name, or the current working directory name
        when not inside a git repo.
    config:
        Hyperparameters — a *dict*, a path to a JSON / YAML file, an
        ``argparse.Namespace``, or a dataclass instance.
    meta:
        Extra user-defined fields merged into ``meta.json``
        (e.g. ``{"description": "baseline", "tags": ["v2"]}``).
    output_formats:
        Override auto-inferred save formats. Keys can be an output type
        (``"depth"``), a slot / aux name (``"transmission"``), or a
        dotted combination (``"depth.pred"``). Values are ``"png"``,
        ``"npy"``, or ``"npz"``.
    run_id:
        If given, resume an existing run instead of creating a new one;
        ``{dir}/runs/{run_id}/`` must already exist. Its ``config.json``
        is loaded automatically unless *config* is explicitly provided.
    datasets:
        Optional mapping of split name to
        ``euler_loading.MultiModalDataset`` instance (e.g.
        ``{"train": train_ds, "val": val_ds}``). Each split is logged
        into ``meta.json`` under ``datasets[split]`` with per-modality
        records: ``path`` plus inferred metadata (``used_as``, ``slot``,
        ``modality_type``; hierarchical modalities also include
        ``hierarchy_scope`` and ``applies_to``). A dataset's
        ``describe_for_runlog()`` contract is used directly when
        implemented; otherwise inference prefers ``ds-crawler`` config
        properties, then falls back to naming-based heuristics.
    run_name:
        Optional human-readable name for the run. Stored in ``meta.json``.
    evaluations:
        Optional mapping of evaluation key to evaluation entry
        (``datasets``, ``name``, ``status``, ``checkpoint``,
        ``metadata``). Typically used when resuming a run (via *run_id*)
        for evaluation. See also :meth:`Run.add_evaluation`.
    mode:
        Optional label for the current process context (``"train"``,
        ``"val"``, or ``"eval"``). When provided, lifecycle fields and
        crash details are mirrored into ``meta.json`` under
        ``modes[mode]``.
    """
    # Thin convenience wrapper — every argument is forwarded verbatim.
    return Run(
        dir=dir,
        config=config,
        meta=meta,
        output_formats=output_formats,
        run_id=run_id,
        datasets=datasets,
        run_name=run_name,
        evaluations=evaluations,
        mode=mode,
    )
@@ -0,0 +1,125 @@
1
+ """Export a PyTorch model to a lightweight ONNX graph for Netron visualization."""
2
+ from __future__ import annotations
3
+
4
+ import logging
5
+ import tempfile
6
+ from pathlib import Path
7
+ from typing import Any
8
+
9
+ log = logging.getLogger("euler_train")
10
+
11
+ _MISSING_DEPS_MSG = (
12
+ "Architecture export requires optional dependencies: onnx, onnxruntime, onnxsim. "
13
+ "Install them with: pip install euler-train[architecture]"
14
+ )
15
+
16
+
17
def export_architecture(
    model: Any,
    dummy_input: Any,
    output_path: str | Path = "architecture.onnx",
) -> Path:
    """Export a PyTorch model to a simplified, weightless ONNX graph.

    The resulting file is optimized for visual inspection in Netron:
    redundant nodes are removed, operator fusions are applied, and
    weight tensors are stripped so only the graph topology remains.

    Parameters
    ----------
    model:
        A PyTorch ``nn.Module``. Temporarily set to eval mode for
        export; the original training/eval state is restored afterward.
    dummy_input:
        Example input tensor(s) matching the model's forward signature.
    output_path:
        Where to write the final ``.onnx`` file.

    Returns
    -------
    Path
        The written output path.

    Raises
    ------
    ImportError
        When the optional dependencies (onnx, onnxruntime, onnxsim)
        are missing.
    """
    import torch

    try:
        import onnx
        import onnxruntime as ort
        from onnxsim import simplify
    except ImportError as exc:
        # Chain the original error so the missing package name stays visible.
        raise ImportError(_MISSING_DEPS_MSG) from exc

    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Remember the training flag so it is restored even if export fails.
    was_training = model.training
    model.eval()

    try:
        return _do_export(model, dummy_input, output_path, onnx, ort, simplify, torch)
    finally:
        if was_training:
            model.train()
63
+
64
+
65
def _do_export(model, dummy_input, output_path, onnx, ort, simplify, torch):
    """Pipeline: export → onnxsim simplify → ORT optimize → strip weights."""
    with tempfile.TemporaryDirectory() as workdir:
        exported = Path(workdir) / "raw.onnx"
        optimized = Path(workdir) / "ort.onnx"

        # Weights must be present at first: the optimizer passes below
        # need them to fold constants and fuse blocks.
        _export_onnx(model, dummy_input, exported, torch=torch)

        # Pass 1 — onnxsim removes redundant glue nodes.
        log.info("Simplifying ONNX graph with onnxsim...")
        slim, ok = simplify(onnx.load(str(exported)))
        if not ok:
            log.warning("onnxsim validation failed, continuing anyway.")
        onnx.save(slim, str(exported))

        # Pass 2 — ONNX Runtime fuses standard blocks (Conv+BN+ReLU, etc.).
        # Constructing the session with `optimized_model_filepath` set
        # writes the optimized graph to disk as a side effect.
        log.info("Applying ONNX Runtime graph optimizations...")
        opts = ort.SessionOptions()
        opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        opts.optimized_model_filepath = str(optimized)
        ort.InferenceSession(str(exported), opts)

        # Pass 3 — drop every initializer so only the topology remains.
        log.info("Stripping weights...")
        final = onnx.load(str(optimized))
        while final.graph.initializer:
            final.graph.initializer.pop()

        onnx.save(final, str(output_path))

        log.info("Architecture exported to %s", output_path)
        return output_path
100
+
101
+
102
def _export_onnx(model: Any, dummy_input: Any, path: Path, *, torch: Any) -> None:
    """Export using dynamo (PyTorch >= 2.1) or legacy torch.onnx.export."""
    # The dynamo-based exporter produces a cleaner functional graph, but is
    # not supported by every model/PyTorch build — fall back on any failure.
    dynamo = getattr(torch.onnx, "dynamo_export", None)
    if dynamo is not None:
        try:
            log.info("Exporting ONNX via torch.onnx.dynamo_export (PyTorch 2.x)...")
            dynamo(model, dummy_input).save(str(path))
            return
        except Exception as exc:
            log.warning(
                "dynamo_export failed (%s), falling back to legacy export.", exc
            )

    # Legacy path: PyTorch 1.x / 2.0, or the dynamo fallback above.
    log.info("Exporting ONNX via torch.onnx.export (legacy)...")
    torch.onnx.export(
        model,
        dummy_input,
        str(path),
        export_params=True,
        opset_version=14,
        do_constant_folding=True,
    )
@@ -0,0 +1,137 @@
1
+ """Collect runtime environment metadata for run_environment.json."""
2
+ from __future__ import annotations
3
+
4
+ import os
5
+ import platform
6
+ import subprocess
7
+ import sys
8
+
9
+
10
def get_run_environment() -> dict:
    """Return a dict matching the run_environments schema."""
    env = {
        "name": _hostname(),
        "python_version": sys.version.split()[0],
        "cuda_version": _cuda_version(),
        "gpu_type": _gpu_type(),
        "gpu_count": _gpu_count(),
        "packages_snapshot": _packages_snapshot(),
    }
    # Container fields are not detected here; callers may fill them in.
    env.update(docker_image=None, docker_digest=None, metadata=None)
    return env
23
+
24
+
25
+ def _hostname() -> str | None:
26
+ try:
27
+ return platform.node() or None
28
+ except Exception:
29
+ return None
30
+
31
+
32
def _cuda_version() -> str | None:
    """Detect the CUDA version: torch build info, then nvcc, then env var."""
    # 1. The CUDA version torch was compiled against (None for CPU builds).
    try:
        import torch
        built_against = torch.version.cuda
        if built_against:
            return built_against
    except Exception:
        pass

    # 2. Parse `nvcc --version`, whose banner contains a line like
    #    "Cuda compilation tools, release 12.1, V12.1.66".
    try:
        proc = subprocess.run(
            ["nvcc", "--version"],
            capture_output=True, text=True, timeout=10,
        )
        if proc.returncode == 0:
            for line in proc.stdout.splitlines():
                if "release" in line.lower():
                    after = line.split("release")[-1].strip()
                    return after.split(",")[0].strip()
    except Exception:
        pass

    # 3. Last resort: the CUDA_VERSION environment variable (may be unset).
    return os.environ.get("CUDA_VERSION")
58
+
59
+
60
def _gpu_type() -> str | None:
    """Name of GPU 0 via pynvml, falling back to nvidia-smi; None if neither works."""
    # 1. NVML bindings, when installed.
    try:
        import pynvml
        pynvml.nvmlInit()
        raw = pynvml.nvmlDeviceGetName(pynvml.nvmlDeviceGetHandleByIndex(0))
        # Older pynvml returns bytes; normalise to str.
        return raw.decode() if isinstance(raw, bytes) else raw
    except Exception:
        pass

    # 2. Shell out to nvidia-smi and take the first reported GPU.
    try:
        proc = subprocess.run(
            ["nvidia-smi", "--query-gpu=gpu_name", "--format=csv,noheader"],
            capture_output=True, text=True, timeout=10,
        )
        if proc.returncode == 0:
            name = proc.stdout.strip().splitlines()[0].strip()
            if name:
                return name
    except Exception:
        pass

    return None
86
+
87
+
88
+ def _gpu_count() -> int | None:
89
+ # 1. torch.cuda.device_count()
90
+ try:
91
+ import torch
92
+ count = torch.cuda.device_count()
93
+ if count > 0:
94
+ return count
95
+ except Exception:
96
+ pass
97
+
98
+ # 2. pynvml
99
+ try:
100
+ import pynvml
101
+ pynvml.nvmlInit()
102
+ return pynvml.nvmlDeviceGetCount()
103
+ except Exception:
104
+ pass
105
+
106
+ # 3. CUDA_VISIBLE_DEVICES
107
+ visible = os.environ.get("CUDA_VISIBLE_DEVICES")
108
+ if visible is not None:
109
+ devices = [d.strip() for d in visible.split(",") if d.strip()]
110
+ if devices:
111
+ return len(devices)
112
+
113
+ return None
114
+
115
+
116
def _packages_snapshot() -> dict[str, str] | None:
    """Snapshot installed packages, preferring `uv pip freeze` over `pip freeze`."""
    for command in (["uv", "pip", "freeze"], ["pip", "freeze"]):
        try:
            proc = subprocess.run(
                command, capture_output=True, text=True, timeout=30,
            )
            if proc.returncode == 0 and proc.stdout.strip():
                return _parse_freeze(proc.stdout)
        except Exception:
            # Missing binary, timeout, etc. — try the next command.
            continue
    return None
128
+
129
+
130
+ def _parse_freeze(output: str) -> dict[str, str]:
131
+ packages = {}
132
+ for line in output.strip().splitlines():
133
+ line = line.strip()
134
+ if "==" in line:
135
+ name, _, version = line.partition("==")
136
+ packages[name.strip()] = version.strip()
137
+ return packages
@@ -0,0 +1,43 @@
1
+ """Collect git repository metadata for code_ref.json."""
2
+ from __future__ import annotations
3
+
4
+ import subprocess
5
+
6
+
7
def get_code_ref() -> dict:
    """Return a dict matching the code_refs schema (all fields best-effort).

    Every field is None when git is unavailable or the cwd is not a repo.
    """
    ref: dict = {
        "repo_url": _git("config", "--get", "remote.origin.url"),
        "branch": _git("rev-parse", "--abbrev-ref", "HEAD"),
        "commit_sha": _git("rev-parse", "HEAD"),
    }
    ref["is_dirty"] = _is_dirty()
    ref["dirty_diff"] = _dirty_diff()
    ref["commit_message"] = _git("log", "-1", "--format=%B")
    ref["committed_at"] = _git("log", "-1", "--format=%aI")
    return ref
18
+
19
+
20
+ def _git(*args: str) -> str | None:
21
+ try:
22
+ result = subprocess.run(
23
+ ["git", *args],
24
+ capture_output=True,
25
+ text=True,
26
+ timeout=10,
27
+ )
28
+ if result.returncode != 0:
29
+ return None
30
+ return result.stdout.strip() or None
31
+ except Exception:
32
+ return None
33
+
34
+
35
def _is_dirty() -> bool:
    """True when ``git status --porcelain`` produced any output.

    _git maps both "no output" (clean tree) and "command failed" (not a
    repo / git missing) to None, so this also reports False outside a
    repository.
    """
    return _git("status", "--porcelain") is not None
38
+
39
+
40
+ def _dirty_diff() -> str | None:
41
+ if not _is_dirty():
42
+ return None
43
+ return _git("diff", "HEAD")
euler_train/outputs.py ADDED
@@ -0,0 +1,194 @@
1
+ """Save prediction / ground-truth / auxiliary outputs to disk."""
2
+ from __future__ import annotations
3
+
4
+ from pathlib import Path
5
+ from typing import Any
6
+
7
+ import numpy as np
8
+
9
+
10
+ # ---------------------------------------------------------------------------
11
+ # Public entry point
12
+ # ---------------------------------------------------------------------------
13
+
14
def save_output_tree(
    type_dir: Path,
    slots: dict[str, Any],
    format_overrides: dict[str, str],
    output_type: str,
) -> None:
    """Persist all slots (pred, gt, input, aux/…) for one *output_type*.

    *slots* example::

        {
            "pred": array,
            "gt": array,
            "aux": {"transmission": array, "attention": array},
        }

    None-valued slots (and None aux entries) are skipped.
    """
    for slot_name, payload in slots.items():
        if payload is None:
            continue
        # The "aux" slot is a nested mapping: one sub-directory per aux name.
        if slot_name == "aux" and isinstance(payload, dict):
            for aux_name, aux_payload in payload.items():
                if aux_payload is not None:
                    _save_slot(
                        type_dir / "aux" / aux_name, aux_payload,
                        output_type, aux_name, format_overrides,
                    )
        else:
            _save_slot(
                type_dir / slot_name, payload,
                output_type, slot_name, format_overrides,
            )
46
+
47
+
48
+ # ---------------------------------------------------------------------------
49
+ # Internal helpers
50
+ # ---------------------------------------------------------------------------
51
+
52
def _save_slot(
    slot_dir: Path,
    data: Any,
    output_type: str,
    leaf_name: str,
    format_overrides: dict[str, str],
) -> None:
    """Write every item of *data* into *slot_dir* as zero-padded indexed files."""
    slot_dir.mkdir(parents=True, exist_ok=True)
    for index, item in enumerate(_unpack(data)):
        fmt = _resolve_format(item, output_type, leaf_name, format_overrides)
        _save_item(slot_dir / f"{index:04d}.{fmt}", item, fmt)
64
+
65
+
66
+ # ---- normalisation -------------------------------------------------------
67
+
68
def _unpack(data: Any) -> list:
    """Normalise *data* into a flat list of saveable items."""
    if isinstance(data, (list, tuple)):
        return [_prepare(entry) for entry in data]

    item = _prepare(data)
    # A 4-D numpy array is treated as a batch: one item per leading index.
    if isinstance(item, np.ndarray) and item.ndim == 4:
        return list(item)
    return [item]
77
+
78
+
79
+ def _prepare(data: Any) -> Any:
80
+ """Convert torch tensors → numpy; pass PIL images through unchanged."""
81
+ # PIL Image — return as-is
82
+ try:
83
+ from PIL import Image as _PIL
84
+ if isinstance(data, _PIL.Image):
85
+ return data
86
+ except ImportError:
87
+ pass
88
+
89
+ # torch Tensor → numpy, channels-first → channels-last
90
+ if hasattr(data, "detach"):
91
+ arr: np.ndarray = data.detach().cpu().numpy()
92
+ # (C, H, W) → (H, W, C) when C looks like a channel dim
93
+ if (
94
+ arr.ndim == 3
95
+ and arr.shape[0] in (1, 3, 4)
96
+ and min(arr.shape[1:]) > 4
97
+ ):
98
+ arr = np.transpose(arr, (1, 2, 0))
99
+ # (B, C, H, W) → (B, H, W, C)
100
+ elif (
101
+ arr.ndim == 4
102
+ and arr.shape[1] in (1, 3, 4)
103
+ and min(arr.shape[2:]) > 4
104
+ ):
105
+ arr = np.transpose(arr, (0, 2, 3, 1))
106
+ return arr
107
+
108
+ return np.asarray(data)
109
+
110
+
111
+ # ---- format inference ----------------------------------------------------
112
+
113
+ def _is_image_like(arr: np.ndarray) -> bool:
114
+ """Heuristic: does this array look like it should be saved as a PNG?"""
115
+ if arr.ndim == 2 and arr.dtype == np.uint8:
116
+ return True # grayscale uint8
117
+ if arr.ndim == 3 and arr.shape[2] in (1, 3, 4):
118
+ return True # HxWx{1,3,4}
119
+ return False
120
+
121
+
122
+ def _resolve_format(
123
+ item: Any,
124
+ output_type: str,
125
+ leaf_name: str,
126
+ overrides: dict[str, str],
127
+ ) -> str:
128
+ """Pick save format: check overrides (most-specific first), then infer."""
129
+ # "rgb.pred" > "rgb" > "pred"
130
+ specific = f"{output_type}.{leaf_name}"
131
+ if specific in overrides:
132
+ return overrides[specific]
133
+ if output_type in overrides:
134
+ return overrides[output_type]
135
+ if leaf_name in overrides:
136
+ return overrides[leaf_name]
137
+
138
+ # PIL Image
139
+ try:
140
+ from PIL import Image as _PIL
141
+ if isinstance(item, _PIL.Image):
142
+ return "png"
143
+ except ImportError:
144
+ pass
145
+
146
+ if isinstance(item, np.ndarray) and _is_image_like(item):
147
+ return "png"
148
+ return "npy"
149
+
150
+
151
+ # ---- writers -------------------------------------------------------------
152
+
153
+ def _save_item(path: Path, item: Any, fmt: str) -> None:
154
+ if fmt == "png":
155
+ _save_png(path, item)
156
+ elif fmt == "npy":
157
+ np.save(str(path), item if isinstance(item, np.ndarray) else np.asarray(item))
158
+ elif fmt == "npz":
159
+ np.savez_compressed(
160
+ str(path),
161
+ data=item if isinstance(item, np.ndarray) else np.asarray(item),
162
+ )
163
+ else:
164
+ raise ValueError(f"Unsupported format: {fmt!r}")
165
+
166
+
167
def _save_png(path: Path, item: Any) -> None:
    """Write *item* (PIL image or array) as a PNG.

    Float arrays are assumed to lie in [0, 1]: values are clipped and
    scaled to uint8. Other non-uint8 dtypes are cast directly (values
    outside 0..255 will wrap).
    """
    from PIL import Image

    # Already a PIL image — write it untouched.
    if isinstance(item, Image.Image):
        item.save(str(path))
        return

    arr: np.ndarray = item

    # Normalise dtype to uint8.
    if np.issubdtype(arr.dtype, np.floating):
        arr = (np.clip(arr, 0.0, 1.0) * 255).astype(np.uint8)
    elif arr.dtype != np.uint8:
        arr = arr.astype(np.uint8)

    if arr.ndim == 2:
        Image.fromarray(arr, mode="L").save(str(path))
        return
    if arr.ndim != 3:
        raise ValueError(f"Cannot save {arr.ndim}D array as PNG")

    channels = arr.shape[2]
    mode = {1: "L", 3: "RGB", 4: "RGBA"}.get(channels)
    if mode is None:
        raise ValueError(f"Cannot save array with {channels} channels as PNG")
    # HxWx1 must be squeezed to HxW for mode "L".
    plane = arr[:, :, 0] if channels == 1 else arr
    Image.fromarray(plane, mode=mode).save(str(path))