euler-preprocess 1.8.0__tar.gz → 2.0.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. {euler_preprocess-1.8.0 → euler_preprocess-2.0.0}/PKG-INFO +1 -1
  2. {euler_preprocess-1.8.0 → euler_preprocess-2.0.0}/euler_preprocess/cli.py +17 -5
  3. {euler_preprocess-1.8.0 → euler_preprocess-2.0.0}/euler_preprocess/common/dataset.py +17 -25
  4. {euler_preprocess-1.8.0 → euler_preprocess-2.0.0}/euler_preprocess/common/output.py +229 -3
  5. {euler_preprocess-1.8.0 → euler_preprocess-2.0.0}/euler_preprocess/common/transform.py +11 -1
  6. {euler_preprocess-1.8.0 → euler_preprocess-2.0.0}/euler_preprocess/fog/models.py +53 -7
  7. {euler_preprocess-1.8.0 → euler_preprocess-2.0.0}/euler_preprocess/fog/transform.py +228 -22
  8. {euler_preprocess-1.8.0 → euler_preprocess-2.0.0}/euler_preprocess.egg-info/PKG-INFO +1 -1
  9. {euler_preprocess-1.8.0 → euler_preprocess-2.0.0}/euler_preprocess.egg-info/SOURCES.txt +1 -0
  10. {euler_preprocess-1.8.0 → euler_preprocess-2.0.0}/pyproject.toml +1 -1
  11. {euler_preprocess-1.8.0 → euler_preprocess-2.0.0}/tests/test_dcp_heuristic_airlight.py +2 -2
  12. euler_preprocess-2.0.0/tests/test_fog_aux_outputs.py +395 -0
  13. {euler_preprocess-1.8.0 → euler_preprocess-2.0.0}/README.md +0 -0
  14. {euler_preprocess-1.8.0 → euler_preprocess-2.0.0}/euler_preprocess/__init__.py +0 -0
  15. {euler_preprocess-1.8.0 → euler_preprocess-2.0.0}/euler_preprocess/common/__init__.py +0 -0
  16. {euler_preprocess-1.8.0 → euler_preprocess-2.0.0}/euler_preprocess/common/device.py +0 -0
  17. {euler_preprocess-1.8.0 → euler_preprocess-2.0.0}/euler_preprocess/common/intrinsics.py +0 -0
  18. {euler_preprocess-1.8.0 → euler_preprocess-2.0.0}/euler_preprocess/common/io.py +0 -0
  19. {euler_preprocess-1.8.0 → euler_preprocess-2.0.0}/euler_preprocess/common/logging.py +0 -0
  20. {euler_preprocess-1.8.0 → euler_preprocess-2.0.0}/euler_preprocess/common/noise.py +0 -0
  21. {euler_preprocess-1.8.0 → euler_preprocess-2.0.0}/euler_preprocess/common/normalize.py +0 -0
  22. {euler_preprocess-1.8.0 → euler_preprocess-2.0.0}/euler_preprocess/common/sampling.py +0 -0
  23. {euler_preprocess-1.8.0 → euler_preprocess-2.0.0}/euler_preprocess/fog/__init__.py +0 -0
  24. {euler_preprocess-1.8.0 → euler_preprocess-2.0.0}/euler_preprocess/fog/airlight_from_sky.py +0 -0
  25. {euler_preprocess-1.8.0 → euler_preprocess-2.0.0}/euler_preprocess/fog/dcp_airlight.py +0 -0
  26. {euler_preprocess-1.8.0 → euler_preprocess-2.0.0}/euler_preprocess/fog/dcp_airlight_torch.py +0 -0
  27. {euler_preprocess-1.8.0 → euler_preprocess-2.0.0}/euler_preprocess/fog/dcp_heuristic_airlight.py +0 -0
  28. {euler_preprocess-1.8.0 → euler_preprocess-2.0.0}/euler_preprocess/fog/dcp_heuristic_airlight_torch.py +0 -0
  29. {euler_preprocess-1.8.0 → euler_preprocess-2.0.0}/euler_preprocess/fog/foggify.py +0 -0
  30. {euler_preprocess-1.8.0 → euler_preprocess-2.0.0}/euler_preprocess/fog/foggify_logging.py +0 -0
  31. {euler_preprocess-1.8.0 → euler_preprocess-2.0.0}/euler_preprocess/fog/logging.py +0 -0
  32. {euler_preprocess-1.8.0 → euler_preprocess-2.0.0}/euler_preprocess/radial/__init__.py +0 -0
  33. {euler_preprocess-1.8.0 → euler_preprocess-2.0.0}/euler_preprocess/radial/transform.py +0 -0
  34. {euler_preprocess-1.8.0 → euler_preprocess-2.0.0}/euler_preprocess/sky_depth/__init__.py +0 -0
  35. {euler_preprocess-1.8.0 → euler_preprocess-2.0.0}/euler_preprocess/sky_depth/transform.py +0 -0
  36. {euler_preprocess-1.8.0 → euler_preprocess-2.0.0}/euler_preprocess.egg-info/dependency_links.txt +0 -0
  37. {euler_preprocess-1.8.0 → euler_preprocess-2.0.0}/euler_preprocess.egg-info/entry_points.txt +0 -0
  38. {euler_preprocess-1.8.0 → euler_preprocess-2.0.0}/euler_preprocess.egg-info/requires.txt +0 -0
  39. {euler_preprocess-1.8.0 → euler_preprocess-2.0.0}/euler_preprocess.egg-info/top_level.txt +0 -0
  40. {euler_preprocess-1.8.0 → euler_preprocess-2.0.0}/setup.cfg +0 -0
  41. {euler_preprocess-1.8.0 → euler_preprocess-2.0.0}/tests/test_airlight_fallback.py +0 -0
  42. {euler_preprocess-1.8.0 → euler_preprocess-2.0.0}/tests/test_foggify_integration.py +0 -0
  43. {euler_preprocess-1.8.0 → euler_preprocess-2.0.0}/tests/test_radial.py +0 -0
  44. {euler_preprocess-1.8.0 → euler_preprocess-2.0.0}/tests/test_sky_depth.py +0 -0
  45. {euler_preprocess-1.8.0 → euler_preprocess-2.0.0}/tests/test_source_backed_output.py +0 -0
  46. {euler_preprocess-1.8.0 → euler_preprocess-2.0.0}/tests/test_zip_output.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: euler-preprocess
3
- Version: 1.8.0
3
+ Version: 2.0.0
4
4
  Summary: Physics-based preprocessing (fog, etc.) for RGB+depth datasets
5
5
  Requires-Python: >=3.9
6
6
  Description-Content-Type: text/markdown
@@ -12,7 +12,7 @@ from pathlib import Path
12
12
 
13
13
  from euler_preprocess.common.dataset import build_dataset
14
14
  from euler_preprocess.common.logging import get_logger, log_dataset_info
15
- from euler_preprocess.common.output import prepare_output_backend
15
+ from euler_preprocess.common.output import prepare_output_backends
16
16
 
17
17
 
18
18
  # ---------------------------------------------------------------------------
@@ -55,7 +55,8 @@ def _run_transform(args: argparse.Namespace, transform_class: type) -> int:
55
55
  required_modalities = transform_class.REQUIRED_MODALITIES
56
56
  required_hierarchical = transform_class.REQUIRED_HIERARCHICAL_MODALITIES or None
57
57
  dataset = build_dataset(config, required_modalities, required_hierarchical)
58
- output_backend = prepare_output_backend(config, dataset, transform_class)
58
+ output_backends = prepare_output_backends(config, dataset, transform_class)
59
+ primary_backend = next(iter(output_backends.values()))
59
60
  dataset_name = config.get("dataset", "dataset")
60
61
 
61
62
  raw_modalities = {
@@ -69,14 +70,25 @@ def _run_transform(args: argparse.Namespace, transform_class: type) -> int:
69
70
  else:
70
71
  modality_info[name] = entry
71
72
  log_dataset_info(logger, dataset_name, len(dataset), modality_info, use_gpu)
72
- logger.info("Output path: %s", output_backend.root)
73
+ for slot, backend in output_backends.items():
74
+ logger.info("Output path [%s]: %s", slot, backend.root)
73
75
 
74
76
  transform_kwargs: dict = {
75
77
  "config_path": str(transform_config_path),
76
- "out_path": str(output_backend.root),
77
- "output_backend": output_backend,
78
+ "out_path": str(primary_backend.root),
78
79
  }
79
80
  init_params = inspect.signature(transform_class.__init__).parameters
81
+ if "output_backends" in init_params:
82
+ transform_kwargs["output_backends"] = output_backends
83
+ else:
84
+ transform_kwargs["output_backend"] = primary_backend
85
+ if len(output_backends) > 1:
86
+ extra = [s for s in output_backends if s != next(iter(output_backends))]
87
+ logger.warning(
88
+ "%s does not accept output_backends; ignoring auxiliary slots: %s",
89
+ transform_class.__name__,
90
+ extra,
91
+ )
80
92
  if "strict" in init_params:
81
93
  transform_kwargs["strict"] = bool(getattr(args, "strict", False))
82
94
  elif getattr(args, "strict", False):
@@ -1,11 +1,12 @@
1
1
  from __future__ import annotations
2
2
 
3
3
 
4
- def _parse_modality_entry(entry: str | dict) -> dict:
5
- """Normalise a modality config entry to ``{path, split}``."""
4
+ def _make_modality(entry: str | dict):
5
+ from euler_loading import Modality
6
+
6
7
  if isinstance(entry, str):
7
- return {"path": entry}
8
- return entry
8
+ return Modality(entry)
9
+ return Modality(entry["path"], split=entry.get("split"))
9
10
 
10
11
 
11
12
  def build_dataset(
@@ -15,19 +16,15 @@ def build_dataset(
15
16
  ):
16
17
  """Build a ``MultiModalDataset`` from a config dict.
17
18
 
18
- Args:
19
- config: Top-level dataset config containing ``modalities`` and
20
- optionally ``hierarchical_modalities`` mappings. Each modality
21
- value may be a plain path string or a dict with ``path`` and
22
- an optional ``split`` key.
23
- required_modalities: Set of modality names that must be present.
24
- required_hierarchical: Optional set of hierarchical modality names
25
- that must be present.
26
-
27
- Returns:
28
- A ``MultiModalDataset`` instance.
19
+ Each modality entry is either a plain path string or a dict with
20
+ ``path`` and optional ``split``. Loader resolution (which function to
21
+ call, which module to use) is handled by euler-loading via the
22
+ ds-crawler index at each path — point the config at a path whose
23
+ index declares the function you want (e.g. a ``sky_mask`` index for
24
+ boolean sky masks vs. a ``class_segmentation`` index for raw class
25
+ id maps).
29
26
  """
30
- from euler_loading import Modality, MultiModalDataset
27
+ from euler_loading import MultiModalDataset
31
28
 
32
29
  raw_modalities = config.get("modalities", {})
33
30
  raw_hierarchical = config.get("hierarchical_modalities", {})
@@ -48,15 +45,10 @@ def build_dataset(
48
45
  f"contain at least: {', '.join(sorted(required_hierarchical))}"
49
46
  )
50
47
 
51
- modalities = {}
52
- for name, entry in raw_modalities.items():
53
- parsed = _parse_modality_entry(entry)
54
- modalities[name] = Modality(parsed["path"], split=parsed.get("split"))
55
-
56
- hierarchical_modalities = {}
57
- for name, entry in raw_hierarchical.items():
58
- parsed = _parse_modality_entry(entry)
59
- hierarchical_modalities[name] = Modality(parsed["path"], split=parsed.get("split"))
48
+ modalities = {name: _make_modality(entry) for name, entry in raw_modalities.items()}
49
+ hierarchical_modalities = {
50
+ name: _make_modality(entry) for name, entry in raw_hierarchical.items()
51
+ }
60
52
 
61
53
  return MultiModalDataset(
62
54
  modalities=modalities,
@@ -2,6 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  import json
4
4
  import tempfile
5
+ from collections.abc import Callable
5
6
  from dataclasses import dataclass, field
6
7
  from pathlib import Path
7
8
  from typing import Any
@@ -19,6 +20,39 @@ _PIPELINE_OUTPUT_STORAGE_KINDS = {"directory", "zip", "file"}
19
20
  _IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff"}
20
21
 
21
22
 
23
+ @dataclass(frozen=True)
24
+ class OutputSlotSpec:
25
+ """Auxiliary-slot spec for transforms producing more than one modality.
26
+
27
+ Auxiliary slots reuse the source modality's hierarchy/indexing so the
28
+ written files line up with the input dataset, but supply their own writer
29
+ and ds-crawler metadata so the resulting on-disk dataset advertises the
30
+ correct modality type and loader.
31
+
32
+ Attributes:
33
+ source_modality: Name of the input modality whose hierarchy and
34
+ per-sample basenames are mirrored when writing auxiliary outputs.
35
+ writer: Writer callable invoked as ``writer(target, value, meta)``.
36
+ ``target`` is either a filesystem path (``str``/``PathLike``) or a
37
+ binary stream (when the writer is marked stream-supported and the
38
+ output is a zip).
39
+ index_overlay: Mapping merged on top of the source modality's
40
+ ``index_output`` to produce the ds-crawler head metadata for this
41
+ slot. Use this to override ``name``/``type``/``euler_train``/
42
+ ``euler_loading``/``meta`` while inheriting indexing/hierarchy.
43
+ output_extension: When set (e.g. ``".npy"``), source basenames are
44
+ rewritten with this extension before writing.
45
+ meta: Optional ``modality_meta`` passed to the writer. Defaults to
46
+ the ``meta`` field from the merged ``index_overlay`` when set there.
47
+ """
48
+
49
+ source_modality: str
50
+ writer: Callable[..., None]
51
+ index_overlay: dict[str, Any]
52
+ output_extension: str | None = None
53
+ meta: dict[str, Any] | None = None
54
+
55
+
22
56
  @dataclass(frozen=True)
23
57
  class PipelineOutputTargetConfig:
24
58
  """Runtime-resolved pipeline output target for a single transform output."""
@@ -218,6 +252,7 @@ class SourceBackedOutputBackend:
218
252
  dataset_writer: DatasetWriter | ZipDatasetWriter,
219
253
  modality_writer: Any,
220
254
  modality_meta: dict[str, Any] | None,
255
+ output_extension: str | None = None,
221
256
  pipeline_manifest_path: Path | None = None,
222
257
  pipeline_manifest_targets: list[PipelineOutputTargetConfig] | None = None,
223
258
  ) -> None:
@@ -226,6 +261,7 @@ class SourceBackedOutputBackend:
226
261
  self.dataset_writer = dataset_writer
227
262
  self.modality_writer = modality_writer
228
263
  self.modality_meta = modality_meta
264
+ self.output_extension = output_extension
229
265
  self.pipeline_manifest_path = pipeline_manifest_path
230
266
  self.pipeline_manifest_targets = pipeline_manifest_targets or []
231
267
 
@@ -254,8 +290,13 @@ class SourceBackedOutputBackend:
254
290
  "requires sample['meta'][source_modality]['path']."
255
291
  )
256
292
 
257
- basename = Path(str(source_meta["path"])).name
258
- relative_path = str(source_meta["path"])
293
+ source_path = Path(str(source_meta["path"]))
294
+ if self.output_extension is not None:
295
+ basename = source_path.stem + self.output_extension
296
+ relative_path = str(source_path.with_suffix(self.output_extension))
297
+ else:
298
+ basename = source_path.name
299
+ relative_path = str(source_path)
259
300
  source_meta_copy = dict(source_meta)
260
301
 
261
302
  if isinstance(self.dataset_writer, ZipDatasetWriter):
@@ -380,7 +421,11 @@ def prepare_output_backend(
380
421
  dataset: MultiModalDataset,
381
422
  transform_class: type,
382
423
  ) -> SourceBackedOutputBackend:
383
- """Create a source-backed output backend for a transform run."""
424
+ """Create a source-backed output backend for a transform run.
425
+
426
+ Used for the *primary* output slot. Transforms with auxiliary outputs
427
+ should use :func:`prepare_output_backends` (plural).
428
+ """
384
429
 
385
430
  source_modality = getattr(transform_class, "SOURCE_MODALITY", None)
386
431
  if not isinstance(source_modality, str) or not source_modality:
@@ -436,3 +481,184 @@ def prepare_output_backend(
436
481
  pipeline_manifest_path=pipeline_manifest_path,
437
482
  pipeline_manifest_targets=[pipeline_target] if pipeline_target else [],
438
483
  )
484
+
485
+
486
+ def _build_auxiliary_backend(
487
+ *,
488
+ spec: OutputSlotSpec,
489
+ pipeline_target: PipelineOutputTargetConfig,
490
+ dataset: MultiModalDataset,
491
+ ) -> SourceBackedOutputBackend:
492
+ """Create a backend for an auxiliary slot using its OutputSlotSpec.
493
+
494
+ The auxiliary backend does not own the pipeline manifest — that is
495
+ aggregated and written by the primary backend.
496
+ """
497
+
498
+ if pipeline_target.storage == "file":
499
+ raise ValueError(
500
+ f"Pipeline output target '{pipeline_target.slot}' uses "
501
+ "unsupported storage='file'"
502
+ )
503
+
504
+ root = Path(pipeline_target.path)
505
+ zip_mode = pipeline_target.storage == "zip"
506
+
507
+ base_index = dataset.get_modality_index(spec.source_modality)
508
+ index_output = _build_auxiliary_index(base_index, spec)
509
+
510
+ dataset_writer = create_dataset_writer_from_index(
511
+ index_output=index_output,
512
+ root=root,
513
+ zip=zip_mode,
514
+ )
515
+
516
+ modality_meta = spec.meta
517
+ if modality_meta is None:
518
+ head = index_output.get("head") or {}
519
+ modality_meta = (head.get("modality") or {}).get("meta")
520
+ if modality_meta is None:
521
+ modality_meta = index_output.get("meta")
522
+
523
+ return SourceBackedOutputBackend(
524
+ source_modality=spec.source_modality,
525
+ root=root,
526
+ dataset_writer=dataset_writer,
527
+ modality_writer=spec.writer,
528
+ modality_meta=modality_meta,
529
+ output_extension=spec.output_extension,
530
+ pipeline_manifest_path=None,
531
+ pipeline_manifest_targets=[],
532
+ )
533
+
534
+
535
+ def _build_auxiliary_index(
536
+ base_index: dict[str, Any], spec: OutputSlotSpec
537
+ ) -> dict[str, Any]:
538
+ """Apply ``spec.index_overlay`` to a copy of ``base_index``.
539
+
540
+ The overlay's recognised keys map to fields used by ds-crawler's writer
541
+ construction. ``name`` / ``type`` rewrite the dataset id+name and
542
+ modality key on both the contract head and the legacy top-level fields;
543
+ ``meta`` overrides the modality's meta dict; ``euler_train`` /
544
+ ``euler_loading`` replace those addon entries. Any other overlay keys
545
+ are passed through as top-level fields for the legacy writer path.
546
+ """
547
+
548
+ overlay = dict(spec.index_overlay)
549
+ index_output: dict[str, Any] = {**base_index}
550
+
551
+ head = base_index.get("head")
552
+ if isinstance(head, dict):
553
+ new_head = json.loads(json.dumps(head)) # deep copy via JSON
554
+ new_head.setdefault("dataset", {})
555
+ new_head.setdefault("modality", {})
556
+ new_head.setdefault("addons", {})
557
+
558
+ if "name" in overlay:
559
+ name = overlay["name"]
560
+ new_head["dataset"]["id"] = name
561
+ new_head["dataset"]["name"] = name
562
+ if "type" in overlay:
563
+ new_head["modality"]["key"] = overlay["type"]
564
+ if "meta" in overlay:
565
+ new_head["modality"]["meta"] = dict(overlay["meta"])
566
+ if "euler_train" in overlay:
567
+ new_head["addons"]["euler_train"] = dict(overlay["euler_train"])
568
+ if "euler_loading" in overlay:
569
+ new_head["addons"]["euler_loading"] = dict(overlay["euler_loading"])
570
+
571
+ index_output["head"] = new_head
572
+
573
+ # Legacy top-level fields used by the non-contract writer construction
574
+ # path. Preserved alongside the head for compatibility.
575
+ for key, value in overlay.items():
576
+ if isinstance(value, dict):
577
+ index_output[key] = dict(value)
578
+ else:
579
+ index_output[key] = value
580
+
581
+ return index_output
582
+
583
+
584
+ def _resolve_primary_slot(transform_class: type) -> str:
585
+ """Return the *primary* slot name declared by *transform_class*."""
586
+
587
+ output_slots = getattr(transform_class, "OUTPUT_SLOTS", None)
588
+ if output_slots:
589
+ return output_slots[0]
590
+
591
+ output_slot = getattr(transform_class, "OUTPUT_SLOT", None)
592
+ if isinstance(output_slot, str) and output_slot:
593
+ return output_slot
594
+
595
+ source_modality = getattr(transform_class, "SOURCE_MODALITY", None)
596
+ if isinstance(source_modality, str) and source_modality:
597
+ return source_modality
598
+
599
+ raise ValueError(
600
+ f"{transform_class.__name__} declares no output slot "
601
+ "(set OUTPUT_SLOT, OUTPUT_SLOTS, or SOURCE_MODALITY)"
602
+ )
603
+
604
+
605
+ def prepare_output_backends(
606
+ config: dict[str, Any],
607
+ dataset: MultiModalDataset,
608
+ transform_class: type,
609
+ ) -> dict[str, SourceBackedOutputBackend]:
610
+ """Create per-slot output backends for *transform_class*.
611
+
612
+ Returns ``{slot_name: backend}``. The primary slot (the first entry of
613
+ ``OUTPUT_SLOTS``, falling back to ``OUTPUT_SLOT`` / ``SOURCE_MODALITY``) is
614
+ always present. Auxiliary slots declared in
615
+ :attr:`Transform.OUTPUT_SLOT_SPECS` are included only when the dataset
616
+ config's ``pipeline.output_targets`` contains a matching entry; otherwise
617
+ the slot is silently omitted (auxiliary outputs are opt-in).
618
+
619
+ The returned dict's iteration order matches the declared
620
+ ``OUTPUT_SLOTS`` order.
621
+ """
622
+
623
+ primary_slot = _resolve_primary_slot(transform_class)
624
+ primary_backend = prepare_output_backend(config, dataset, transform_class)
625
+ backends: dict[str, SourceBackedOutputBackend] = {primary_slot: primary_backend}
626
+
627
+ slot_specs = getattr(transform_class, "OUTPUT_SLOT_SPECS", None) or {}
628
+ pipeline = parse_pipeline_config(config)
629
+
630
+ declared_slots = getattr(transform_class, "OUTPUT_SLOTS", ()) or ()
631
+ aux_slots = [s for s in declared_slots if s != primary_slot]
632
+ if pipeline is not None and slot_specs:
633
+ for slot in aux_slots:
634
+ spec = slot_specs.get(slot)
635
+ if spec is None:
636
+ continue
637
+ target = pipeline.get_output_target(slot)
638
+ if target is None:
639
+ continue
640
+ backends[slot] = _build_auxiliary_backend(
641
+ spec=spec,
642
+ pipeline_target=target,
643
+ dataset=dataset,
644
+ )
645
+
646
+ # Aggregate every slot we actually wrote into the manifest the primary
647
+ # backend will emit, so a single manifest documents the full set.
648
+ if (
649
+ pipeline is not None
650
+ and pipeline.outputs_manifest_path
651
+ and len(backends) > 1
652
+ ):
653
+ manifest_targets: list[PipelineOutputTargetConfig] = list(
654
+ primary_backend.pipeline_manifest_targets
655
+ )
656
+ for slot in aux_slots:
657
+ if slot not in backends:
658
+ continue
659
+ target = pipeline.get_output_target(slot)
660
+ if target is not None:
661
+ manifest_targets.append(target)
662
+ primary_backend.pipeline_manifest_targets = manifest_targets
663
+
664
+ return backends
@@ -3,7 +3,7 @@ from __future__ import annotations
3
3
  from abc import ABC, abstractmethod
4
4
  from collections.abc import Iterable
5
5
  from pathlib import Path
6
- from typing import ClassVar
6
+ from typing import Any, ClassVar
7
7
 
8
8
 
9
9
  class Transform(ABC):
@@ -11,12 +11,22 @@ class Transform(ABC):
11
11
 
12
12
  Subclasses declare the modalities they need via class variables and
13
13
  implement :meth:`run` to process samples.
14
+
15
+ Output slots:
16
+ Most transforms produce a single output (the *primary* slot, declared
17
+ via :attr:`OUTPUT_SLOT`). Transforms that produce additional auxiliary
18
+ outputs (e.g. fog β / L_s maps) declare them in :attr:`OUTPUT_SLOTS`
19
+ together with per-slot specs in :attr:`OUTPUT_SLOT_SPECS`. Auxiliary
20
+ slots are opt-in: they are only written when the dataset config's
21
+ ``pipeline.output_targets`` includes a matching entry.
14
22
  """
15
23
 
16
24
  REQUIRED_MODALITIES: ClassVar[set[str]] = set()
17
25
  REQUIRED_HIERARCHICAL_MODALITIES: ClassVar[set[str]] = set()
18
26
  SOURCE_MODALITY: ClassVar[str | None] = None
19
27
  OUTPUT_SLOT: ClassVar[str | None] = None
28
+ OUTPUT_SLOTS: ClassVar[tuple[str, ...]] = ()
29
+ OUTPUT_SLOT_SPECS: ClassVar[dict[str, Any]] = {}
20
30
  OUTPUT_INDEX_META_OVERRIDES: ClassVar[dict[str, object]] = {}
21
31
 
22
32
  @abstractmethod
@@ -243,6 +243,33 @@ def uses_estimated_airlight(al_spec) -> bool:
243
243
  return al_spec is None or al_spec in AIRLIGHT_METHODS
244
244
 
245
245
 
246
+ def broadcast_k_field(k_field: Any, height: int, width: int) -> np.ndarray:
247
+ """Return ``k_field`` as a ``(H, W)`` float32 map (broadcasting if scalar)."""
248
+ arr = np.asarray(k_field, dtype=np.float32)
249
+ if arr.ndim == 0:
250
+ return np.broadcast_to(arr, (height, width)).astype(np.float32, copy=True)
251
+ if arr.shape == (height, width):
252
+ return arr.astype(np.float32, copy=False)
253
+ raise ValueError(
254
+ f"k_field must be scalar or shape ({height}, {width}); got {arr.shape}"
255
+ )
256
+
257
+
258
+ def broadcast_ls_field(ls_field: Any, height: int, width: int) -> np.ndarray:
259
+ """Return ``ls_field`` as a ``(H, W, 3)`` float32 map (broadcasting if needed)."""
260
+ arr = np.asarray(ls_field, dtype=np.float32)
261
+ if arr.shape == (3,):
262
+ return np.broadcast_to(arr, (height, width, 3)).astype(np.float32, copy=True)
263
+ if arr.shape == (1, 1, 3):
264
+ return np.broadcast_to(arr, (height, width, 3)).astype(np.float32, copy=True)
265
+ if arr.shape == (height, width, 3):
266
+ return arr.astype(np.float32, copy=False)
267
+ raise ValueError(
268
+ f"ls_field must have shape (3,), (1, 1, 3), or "
269
+ f"({height}, {width}, 3); got {arr.shape}"
270
+ )
271
+
272
+
246
273
  def apply_model(
247
274
  rgb: np.ndarray,
248
275
  depth_m: np.ndarray,
@@ -251,7 +278,18 @@ def apply_model(
251
278
  rng: np.random.Generator,
252
279
  contrast_threshold_default: float,
253
280
  estimated_airlight: np.ndarray,
254
- ) -> tuple[np.ndarray, float, np.ndarray]:
281
+ ) -> tuple[np.ndarray, float, np.ndarray, np.ndarray, np.ndarray]:
282
+ """Apply a fog model to ``rgb``.
283
+
284
+ Returns:
285
+ Tuple ``(foggy, k_mean, ls_base, k_map, ls_map)``:
286
+
287
+ * ``foggy``: ``(H, W, 3)`` foggy RGB image.
288
+ * ``k_mean``: scalar mean scattering coefficient (for filenames/logs).
289
+ * ``ls_base``: ``(3,)`` base atmospheric light (for filenames/logs).
290
+ * ``k_map``: ``(H, W)`` β-field actually used (broadcast for uniform).
291
+ * ``ls_map``: ``(H, W, 3)`` L_s-field actually used (broadcast for uniform).
292
+ """
255
293
  if model_name not in DEFAULT_MODEL_CONFIGS:
256
294
  raise ValueError(f"Unsupported fog model: {model_name}")
257
295
  visibility = float(sample_value(model_cfg.get("visibility_m"), rng))
@@ -269,14 +307,19 @@ def apply_model(
269
307
  sampled_al = sample_value(al_spec, rng)
270
308
  ls_base = normalize_atmospheric_light(np.asarray(sampled_al))
271
309
 
310
+ height, width = depth_m.shape
311
+
272
312
  if model_name == "uniform":
273
313
  ls_field = ls_base.reshape(1, 1, 3)
274
- return apply_fog(rgb, depth_m, k_mean, ls_field), k_mean, ls_base
314
+ foggy = apply_fog(rgb, depth_m, k_mean, ls_field)
315
+ k_map = broadcast_k_field(k_mean, height, width)
316
+ ls_map = broadcast_ls_field(ls_base, height, width)
317
+ return foggy, k_mean, ls_base, k_map, ls_map
275
318
 
276
319
  if model_name in ("heterogeneous_k", "heterogeneous_k_ls"):
277
320
  k_cfg = model_cfg.get("k_hetero", {})
278
- k_scales = resolve_scales(k_cfg, depth_m.shape[0], depth_m.shape[1], rng)
279
- k_noise = perlin_fbm(depth_m.shape[0], depth_m.shape[1], k_scales, rng)
321
+ k_scales = resolve_scales(k_cfg, height, width, rng)
322
+ k_noise = perlin_fbm(height, width, k_scales, rng)
280
323
  min_factor = float(sample_value(k_cfg.get("min_factor", 1.0), rng))
281
324
  max_factor = float(sample_value(k_cfg.get("max_factor", 1.0), rng))
282
325
  k_field = modulate_with_noise(
@@ -291,8 +334,8 @@ def apply_model(
291
334
 
292
335
  if model_name in ("heterogeneous_ls", "heterogeneous_k_ls"):
293
336
  ls_cfg = model_cfg.get("ls_hetero", {})
294
- ls_scales = resolve_scales(ls_cfg, depth_m.shape[0], depth_m.shape[1], rng)
295
- ls_noise = perlin_fbm(depth_m.shape[0], depth_m.shape[1], ls_scales, rng)
337
+ ls_scales = resolve_scales(ls_cfg, height, width, rng)
338
+ ls_noise = perlin_fbm(height, width, ls_scales, rng)
296
339
  min_factor = float(sample_value(ls_cfg.get("min_factor", 1.0), rng))
297
340
  max_factor = float(sample_value(ls_cfg.get("max_factor", 1.0), rng))
298
341
  ls_field = modulate_with_noise(
@@ -306,4 +349,7 @@ def apply_model(
306
349
  else:
307
350
  ls_field = ls_base.reshape(1, 1, 3)
308
351
 
309
- return apply_fog(rgb, depth_m, k_field, ls_field), k_mean, ls_base
352
+ foggy = apply_fog(rgb, depth_m, k_field, ls_field)
353
+ k_map = broadcast_k_field(k_field, height, width)
354
+ ls_map = broadcast_ls_field(ls_field, height, width)
355
+ return foggy, k_mean, ls_base, k_map, ls_map