torch-nvidia-of-sdk 5.0.0 (torch_nvidia_of_sdk-5.0.0-py3-none-manylinux_2_34_x86_64.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- of/__init__.py +3 -0
- of/datasets.py +229 -0
- of/io.py +80 -0
- of/methods/__init__.py +3 -0
- of/methods/nvidia_sdk.py +69 -0
- of/methods/nvof_torch.cpython-313-x86_64-linux-gnu.so +0 -0
- of/metrics.py +174 -0
- of/py.typed +0 -0
- of/visualization.py +378 -0
- torch_nvidia_of_sdk-5.0.0.dist-info/METADATA +218 -0
- torch_nvidia_of_sdk-5.0.0.dist-info/RECORD +15 -0
- torch_nvidia_of_sdk-5.0.0.dist-info/WHEEL +5 -0
- torch_nvidia_of_sdk-5.0.0.dist-info/sboms/auditwheel.cdx.json +1 -0
- torch_nvidia_of_sdk.libs/libcuda-df918354.so.580.105.08 +0 -0
- torch_nvidia_of_sdk.libs/libcudart-381c0faa.so.12.9.37 +0 -0
of/__init__.py
ADDED
of/datasets.py
ADDED
@@ -0,0 +1,229 @@
"""
Dataset loaders for optical flow benchmarks.
"""

import numpy as np
from pathlib import Path
from typing import List, Dict, Optional, Union
from dataclasses import dataclass
import imageio.v3 as imageio

from .io import read_flo


@dataclass
class FlowSample:
    """A single optical flow sample with metadata."""

    current_frame: Path
    reference_frame: Path
    fw_flow: Optional[Path] = None
    bw_flow: Optional[Path] = None
    fw_valid_mask: Optional[Path] = None
    bw_valid_mask: Optional[Path] = None
    name: str = ""
    metadata: Optional[Dict] = None

    def __post_init__(self):
        if self.metadata is None:
            self.metadata = {}


class OpticalFlowDataset:
    """Base class for optical flow datasets."""

    def __init__(self, root_dir: Union[str, Path]):
        self.root_dir = Path(root_dir)
        self.samples: List[FlowSample] = []
        self.setup()

    def setup(self):
        raise NotImplementedError("Subclasses should implement this method.")

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> FlowSample:
        raise NotImplementedError("Subclasses should implement this method.")


class CSEMDataset(OpticalFlowDataset):
    """
    Dataset with the following directory layout:

    project_root/
    ├── images/
    │   ├── frame_0001.png
    │   ├── frame_0002.png
    │   ├── ...
    │   └── frame_N.png
    ├── flow/
    │   ├── forward/   # Flow from t to t+1
    │   │   ├── flow_0001.flo
    │   │   ├── flow_0002.flo
    │   │   ├── ...
    │   │   └── flow_<N-1>.flo
    │   └── backward/  # Flow from t to t-1 (for occlusion checks)
    │       ├── flow_0002.flo
    │       ├── flow_0003.flo
    │       ├── ...
    │       └── flow_<N>.flo
    """

    def setup(self):
        frame_paths = sorted((self.root_dir / "images").glob("*.png"))
        for i in range(len(frame_paths) - 1):
            current_frame = frame_paths[i]
            reference_frame = frame_paths[i + 1]
            name = current_frame.stem.split("_")[-1]
            frame_number = int(name)
            fw_flow_path = self.root_dir / "flow" / "forward" / f"flow_{name}.flo"
            bw_flow_path = self.root_dir / "flow" / "backward" / f"flow_{name}.flo"
            sample = FlowSample(
                current_frame=current_frame,
                reference_frame=reference_frame,
                fw_flow=fw_flow_path if fw_flow_path.exists() else None,
                bw_flow=bw_flow_path if bw_flow_path.exists() else None,
                name=name,
                metadata={"frame_number": frame_number},
            )
            self.samples.append(sample)

    def __getitem__(self, idx: int) -> Dict[str, np.ndarray]:
        sample = self.samples[idx]
        data = {
            "current_frame": imageio.imread(sample.current_frame),
            "reference_frame": imageio.imread(sample.reference_frame),
            "name": sample.name,
            "metadata": sample.metadata,
        }
        if sample.fw_flow:
            data["fw_flow"] = read_flo(sample.fw_flow)
        if sample.bw_flow:
            data["bw_flow"] = read_flo(sample.bw_flow)
        return data


class MPISintelDataset(OpticalFlowDataset):
    """
    MPI-Sintel optical flow dataset.

    The dataset has two passes: 'clean' and 'final'.
    Directory structure:
        root/split/pass_name/sequence/frame_xxxx.png
        root/split/flow/sequence/frame_xxxx.flo
        root/split/invalid/sequence/frame_xxxx.png
    """

    def __init__(
        self,
        root_dir: Union[str, Path],
        split: str = "training",
        pass_name: str = "clean",
    ):
        """
        Initialize MPI-Sintel dataset.

        Args:
            root_dir: Path to MPI-Sintel root directory
            split: Either 'training' or 'test'
            pass_name: Either 'clean' or 'final'
        """
        # Set attributes before super().__init__ because it calls setup()
        self.split = split
        self.pass_name = pass_name

        super().__init__(root_dir)

    def setup(self):
        """Build list of samples from directory structure."""
        images_dir = self.root_dir / self.split / self.pass_name

        if not images_dir.exists():
            raise ValueError(f"Images directory not found: {images_dir}")

        flow_dir = self.root_dir / self.split / "flow"
        # Sintel provides 'invalid' masks (occlusions + out of bounds)
        # We map this to fw_valid_mask (loading logic should handle the inversion if needed)
        invalid_dir = self.root_dir / self.split / "invalid"

        # Iterate through sequences
        for seq_dir in sorted(images_dir.iterdir()):
            if not seq_dir.is_dir():
                continue

            seq_name = seq_dir.name
            # Get all image pairs in the sequence
            image_files = sorted(seq_dir.glob("*.png"))

            for i in range(len(image_files) - 1):
                img1 = image_files[i]
                img2 = image_files[i + 1]

                fw_flow_path = None
                fw_valid_mask_path = None

                # Training split has ground truth flow and masks
                if self.split == "training":
                    # Flow filename matches image filename
                    cand_flow = flow_dir / seq_name / img1.name.replace(".png", ".flo")
                    if cand_flow.exists():
                        fw_flow_path = cand_flow

                    cand_mask = invalid_dir / seq_name / img1.name
                    if cand_mask.exists():
                        fw_valid_mask_path = cand_mask

                sample = FlowSample(
                    current_frame=img1,
                    reference_frame=img2,
                    fw_flow=fw_flow_path,
                    bw_flow=None,  # Sintel standard structure doesn't provide backward flow
                    fw_valid_mask=fw_valid_mask_path,
                    bw_valid_mask=None,
                    name=f"{seq_name}_{img1.stem}",
                    metadata={
                        "sequence": seq_name,
                        "pass": self.pass_name,
                        "frame_index": i,
                    },
                )

                self.samples.append(sample)

    def __getitem__(self, idx: int) -> Dict[str, np.ndarray]:
        sample = self.samples[idx]
        data = {
            "current_frame": imageio.imread(sample.current_frame),
            "reference_frame": imageio.imread(sample.reference_frame),
            "name": sample.name,
            "metadata": sample.metadata,
        }

        if sample.fw_flow:
            data["fw_flow"] = read_flo(sample.fw_flow)

        if sample.fw_valid_mask:
            # Sintel masks are PNGs where logic 1 usually means invalid.
            # Reading as is; downstream transforms can invert/threshold.
            data["fw_valid_mask"] = imageio.imread(sample.fw_valid_mask)

        return data


def get_dataset(name: str, root_dir: Union[str, Path], **kwargs) -> OpticalFlowDataset:
    """
    Factory function to get a dataset by name.

    Args:
        name: Dataset name ('mpi-sintel')
        root_dir: Root directory of the dataset
        **kwargs: Additional arguments for the dataset

    Returns:
        OpticalFlowDataset instance
    """
    name = name.lower()

    if name == "mpi-sintel":
        return MPISintelDataset(root_dir, **kwargs)
    else:
        raise ValueError(f"Unknown dataset: {name}. Only 'mpi-sintel' is supported.")
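A minimal usage sketch for the loaders above (not shipped in the wheel; the dataset path is a placeholder, and ground-truth keys depend on the split):

```python
from of.datasets import get_dataset

# Hypothetical local copy of the MPI-Sintel training data
dataset = get_dataset("mpi-sintel", "/data/MPI-Sintel",
                      split="training", pass_name="clean")

print(len(dataset))                   # number of consecutive frame pairs
sample = dataset[0]                   # dict of arrays plus name/metadata
print(sample["current_frame"].shape)  # (H, W, 3) uint8 image
if "fw_flow" in sample:               # present when ground truth exists
    print(sample["fw_flow"].shape)    # (H, W, 2) float32 flow
```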
of/io.py
ADDED
@@ -0,0 +1,80 @@
"""
Input/Output utilities for optical flow data.
Supports reading and writing optical flow in various formats.
"""

import struct
from pathlib import Path
from typing import Union

import numpy as np
import torch
from numpy.typing import NDArray
from torch import Tensor

FlowType = Union[NDArray, Tensor]


def read_flo(filepath: Union[str, Path]) -> NDArray:  # noqa
    """
    Read optical flow from .flo file (Middlebury format).

    Args:
        filepath: Path to the .flo file

    Returns:
        Optical flow as numpy array of shape (H, W, 2) with dtype float32

    Raises:
        ValueError: If file format is invalid
    """
    filepath = Path(filepath)

    with open(filepath, "rb") as f:
        # Check magic number
        magic = struct.unpack("f", f.read(4))[0]
        if magic != 202021.25:
            raise ValueError(f"Invalid .flo file format. Magic number: {magic}")

        # Read dimensions
        width = struct.unpack("i", f.read(4))[0]
        height = struct.unpack("i", f.read(4))[0]

        # Read flow data
        data = np.fromfile(f, np.float32)

        # Reshape to (height, width, 2)
        flow = data.reshape((height, width, 2))
    return flow


def write_flo(filepath: Union[str, Path], flow: FlowType) -> None:
    """
    Write optical flow to .flo file (Middlebury format).

    Args:
        filepath: Path to save the .flo file
        flow: Optical flow as numpy array of shape (H, W, 2)

    Raises:
        ValueError: If flow array shape is invalid
    """
    filepath = Path(filepath)
    if isinstance(flow, torch.Tensor):
        flow = flow.detach().cpu().numpy()

    if flow.ndim != 3 or flow.shape[2] != 2:
        raise ValueError(f"Flow must have shape (H, W, 2), got {flow.shape}")

    height, width, _ = flow.shape

    with open(filepath, "wb") as f:
        # Write magic number
        f.write(struct.pack("f", 202021.25))

        # Write dimensions
        f.write(struct.pack("i", width))
        f.write(struct.pack("i", height))

        # Write flow data
        flow.astype(np.float32).tofile(f)
of/methods/__init__.py
ADDED
of/methods/nvidia_sdk.py
ADDED
@@ -0,0 +1,69 @@
import torch
import numpy as np
from typing import Literal
from .nvof_torch import TorchNVOpticalFlow as _C_TorchNVOpticalFlow


class TorchNVOpticalFlow(_C_TorchNVOpticalFlow):
    @classmethod
    def from_tensor(
        cls,
        input: torch.Tensor,
        preset: Literal["slow", "medium", "fast"] = "fast",
        grid_size: Literal[1, 2, 4] = 1,
    ):
        """
        Factory method to create engine from a tensor.

        Args:
            input: Tensor of shape (H, W, C) to infer dimensions from
            preset: Performance preset ("slow", "medium", "fast")
            grid_size: Grid size for optical flow (1, 2, or 4)
            bidirectional: Whether to compute bidirectional flow  # TODO: not yet exposed here

        Returns:
            TorchNVOpticalFlow instance
        """
        h = input.size(0)
        w = input.size(1)
        gpu = input.get_device()
        return cls(w, h, gpu, preset, grid_size, False)

    @torch.no_grad()
    def compute_flow(
        self,
        input: torch.Tensor,
        reference: torch.Tensor,
        upsample: bool = True,
    ) -> torch.Tensor:
        """
        Compute optical flow between input and reference frames.

        Args:
            input: Input tensor of shape (H, W, 3) or (H, W, 4), uint8, CUDA
            reference: Reference tensor of shape (H, W, 3) or (H, W, 4), uint8, CUDA
            upsample: Whether to upsample the output to full resolution

        Returns:
            Flow tensor of shape (H, W, 2), float32 (the SDK's fixed-point
            int16 output divided by 32)
        """
        # NVOpticalFlow requires ABGR uint8
        if isinstance(input, np.ndarray):
            input = torch.from_numpy(input)
        if isinstance(reference, np.ndarray):
            reference = torch.from_numpy(reference)

        if input.shape[-1] == 3:
            alpha_input = input.sum(dim=-1, keepdim=True).div(3).clamp(0, 255).byte()
            input = torch.cat([alpha_input, input], dim=-1).to(f"cuda:{self.gpu_id()}")

        if reference.shape[-1] == 3:
            alpha_reference = (
                reference.sum(dim=-1, keepdim=True).div(3).clamp(0, 255).byte()
            )
            reference = torch.cat([alpha_reference, reference], dim=-1).to(
                f"cuda:{self.gpu_id()}"
            )

        # Checks performed in C++
        return super().compute_flow(input, reference, upsample=upsample).float() / 32.0
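A usage sketch for the wrapper above (not shipped in the wheel); it assumes an NVOF-capable NVIDIA GPU and uses random frames as stand-ins for real video:

```python
import torch
from of.methods.nvidia_sdk import TorchNVOpticalFlow

# Two random RGB frames on the GPU (real frames would come from a video)
frame0 = torch.randint(0, 256, (480, 640, 3), dtype=torch.uint8, device="cuda")
frame1 = torch.randint(0, 256, (480, 640, 3), dtype=torch.uint8, device="cuda")

# Width, height and device are inferred from the tensor itself
engine = TorchNVOpticalFlow.from_tensor(frame0, preset="fast", grid_size=1)

# 3-channel inputs get an alpha channel prepended internally
flow = engine.compute_flow(frame0, frame1, upsample=True)
print(flow.shape, flow.dtype)  # torch.Size([480, 640, 2]) torch.float32
```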
of/methods/nvof_torch.cpython-313-x86_64-linux-gnu.so
Binary file
of/metrics.py
ADDED
@@ -0,0 +1,174 @@
"""
Error metrics for optical flow evaluation.
"""

import numpy as np
from typing import Optional, Dict, Tuple


def endpoint_error(flow_pred: np.ndarray, flow_gt: np.ndarray,
                   valid_mask: Optional[np.ndarray] = None) -> np.ndarray:
    """
    Compute End-Point Error (EPE) between predicted and ground truth flow.

    EPE = sqrt((u_pred - u_gt)^2 + (v_pred - v_gt)^2)

    Args:
        flow_pred: Predicted flow of shape (H, W, 2)
        flow_gt: Ground truth flow of shape (H, W, 2)
        valid_mask: Optional boolean mask of valid pixels (H, W)

    Returns:
        EPE map of shape (H, W)
    """
    diff = flow_pred - flow_gt
    epe = np.sqrt(np.sum(diff**2, axis=2))

    if valid_mask is not None:
        epe = epe * valid_mask

    return epe


def average_endpoint_error(flow_pred: np.ndarray, flow_gt: np.ndarray,
                           valid_mask: Optional[np.ndarray] = None) -> float:
    """
    Compute Average End-Point Error (AEE).

    Args:
        flow_pred: Predicted flow of shape (H, W, 2)
        flow_gt: Ground truth flow of shape (H, W, 2)
        valid_mask: Optional boolean mask of valid pixels (H, W)

    Returns:
        Average EPE as a scalar float
    """
    epe = endpoint_error(flow_pred, flow_gt, valid_mask)

    if valid_mask is not None:
        return epe[valid_mask].mean()
    else:
        return epe.mean()


def angular_error(flow_pred: np.ndarray, flow_gt: np.ndarray,
                  valid_mask: Optional[np.ndarray] = None) -> np.ndarray:
    """
    Compute Angular Error (AE) between predicted and ground truth flow.

    The angular error is computed as the angle between the 3D vectors
    (u, v, 1) for predicted and ground truth flows.

    Args:
        flow_pred: Predicted flow of shape (H, W, 2)
        flow_gt: Ground truth flow of shape (H, W, 2)
        valid_mask: Optional boolean mask of valid pixels (H, W)

    Returns:
        Angular error map in degrees of shape (H, W)
    """
    # Convert to 3D vectors
    u_pred, v_pred = flow_pred[:, :, 0], flow_pred[:, :, 1]
    u_gt, v_gt = flow_gt[:, :, 0], flow_gt[:, :, 1]

    # Compute dot product of normalized vectors
    num = 1 + u_pred * u_gt + v_pred * v_gt
    denom = np.sqrt(1 + u_pred**2 + v_pred**2) * np.sqrt(1 + u_gt**2 + v_gt**2)

    # Avoid division by zero
    denom = np.maximum(denom, 1e-10)

    # Compute angle
    cos_angle = np.clip(num / denom, -1.0, 1.0)
    ae = np.arccos(cos_angle) * 180.0 / np.pi

    if valid_mask is not None:
        ae = ae * valid_mask

    return ae


def average_angular_error(flow_pred: np.ndarray, flow_gt: np.ndarray,
                          valid_mask: Optional[np.ndarray] = None) -> float:
    """
    Compute Average Angular Error (AAE).

    Args:
        flow_pred: Predicted flow of shape (H, W, 2)
        flow_gt: Ground truth flow of shape (H, W, 2)
        valid_mask: Optional boolean mask of valid pixels (H, W)

    Returns:
        Average angular error in degrees as a scalar float
    """
    ae = angular_error(flow_pred, flow_gt, valid_mask)

    if valid_mask is not None:
        return ae[valid_mask].mean()
    else:
        return ae.mean()


def fl_all(flow_pred: np.ndarray, flow_gt: np.ndarray,
           valid_mask: Optional[np.ndarray] = None,
           threshold: float = 3.0) -> float:
    """
    Compute FL-all metric (percentage of outliers).

    An outlier is defined as EPE > threshold OR angular error > threshold degrees.

    Args:
        flow_pred: Predicted flow of shape (H, W, 2)
        flow_gt: Ground truth flow of shape (H, W, 2)
        valid_mask: Optional boolean mask of valid pixels (H, W)
        threshold: Threshold for outlier detection (default: 3.0 pixels)

    Returns:
        Percentage of outliers (0-100)
    """
    epe = endpoint_error(flow_pred, flow_gt, valid_mask)
    ae = angular_error(flow_pred, flow_gt, valid_mask)

    outliers = (epe > threshold) | (ae > threshold)

    if valid_mask is not None:
        return 100.0 * outliers[valid_mask].sum() / valid_mask.sum()
    else:
        return 100.0 * outliers.sum() / outliers.size


def compute_all_metrics(flow_pred: np.ndarray, flow_gt: np.ndarray,
                        valid_mask: Optional[np.ndarray] = None,
                        thresholds: Tuple[float, ...] = (1.0, 3.0, 5.0)) -> Dict[str, float]:
    """
    Compute all standard optical flow metrics.

    Args:
        flow_pred: Predicted flow of shape (H, W, 2)
        flow_gt: Ground truth flow of shape (H, W, 2)
        valid_mask: Optional boolean mask of valid pixels (H, W)
        thresholds: Thresholds for outlier metrics

    Returns:
        Dictionary containing all metrics
    """
    metrics = {}

    # End-point error
    metrics['EPE'] = average_endpoint_error(flow_pred, flow_gt, valid_mask)

    # Angular error
    metrics['AE'] = average_angular_error(flow_pred, flow_gt, valid_mask)

    # Outlier percentages at different thresholds
    for threshold in thresholds:
        metrics[f'FL-{threshold}px'] = fl_all(flow_pred, flow_gt, valid_mask, threshold)

    # Root mean square error
    epe = endpoint_error(flow_pred, flow_gt, valid_mask)
    if valid_mask is not None:
        metrics['RMSE'] = np.sqrt((epe[valid_mask]**2).mean())
    else:
        metrics['RMSE'] = np.sqrt((epe**2).mean())

    return metrics
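A sketch (not part of the wheel) evaluating a noisy prediction against synthetic ground truth with the functions above:

```python
import numpy as np
from of.metrics import compute_all_metrics

rng = np.random.default_rng(0)
flow_gt = rng.normal(size=(64, 64, 2)).astype(np.float32)
flow_pred = flow_gt + 0.1 * rng.normal(size=(64, 64, 2)).astype(np.float32)
valid = np.ones((64, 64), dtype=bool)  # e.g. derived from a dataset mask

metrics = compute_all_metrics(flow_pred, flow_gt, valid_mask=valid)
for name, value in metrics.items():
    print(f"{name}: {value:.3f}")  # EPE, AE, FL-1.0px/3.0px/5.0px, RMSE
```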
of/py.typed
ADDED
File without changes
of/visualization.py
ADDED
@@ -0,0 +1,378 @@
"""
Visualization utilities for optical flow.

INTERPRETATION GUIDE:
---------------------
To correctly read these visualizations, you must distinguish between the "Color Wheel"
logic and "Masking" logic.

1. PIXEL COLORS (Valid Flow):
   - WHITE: Zero Motion (Stationary).
   - COLOR: Motion Direction (Hue) and Speed (Saturation).
     (Red=Right, Green=Up, Cyan=Left, Blue=Down).
   - BRIGHT/VIVID: Fast motion.
   - PALE/PASTEL: Slow motion.

2. BLACK PIXELS (Invalid/Occluded):
   - BLACK: Invalid data, NaN, Occluded, or Unknown.

*Note on KITTI*: In KITTI ground truth, the sky and upper image are invalid.
They appear BLACK. The stationary road appears WHITE. If you see a visualization
that is Black where the road should be, the data is likely NaN/Invalid, not stationary.

3. CONVENTIONS (Normalization):
   - 'middlebury': Adapts to the image's max speed. Good for seeing details in
     slow scenes.
   - 'kitti': Uses fixed scaling (usually saturates at 3px or similar).
     Good for comparing speed between different clips (fast scenes
     look vivid, slow scenes look white/pale).
"""

from typing import List, Optional, Tuple, Union, Dict, Any

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes
from matplotlib.figure import Figure


def _get_color_wheel() -> np.ndarray:
    """
    Generates the standard optical flow color wheel (Middlebury/KITTI standard).

    Returns:
        np.ndarray: Color wheel palette of shape (55, 3)
    """
    RY = 15
    YG = 6
    GC = 4
    CB = 11
    BM = 13
    MR = 6

    ncols = RY + YG + GC + CB + BM + MR
    colorwheel = np.zeros((ncols, 3))

    col = 0
    # RY
    colorwheel[0:RY, 0] = 255
    colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY, 1) / RY)
    col += RY

    # YG
    colorwheel[col : col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG, 1) / YG)
    colorwheel[col : col + YG, 1] = 255
    col += YG

    # GC
    colorwheel[col : col + GC, 1] = 255
    colorwheel[col : col + GC, 2] = np.floor(255 * np.arange(0, GC, 1) / GC)
    col += GC

    # CB
    colorwheel[col : col + CB, 1] = 255 - np.floor(255 * np.arange(0, CB, 1) / CB)
    colorwheel[col : col + CB, 2] = 255
    col += CB

    # BM
    colorwheel[col : col + BM, 2] = 255
    colorwheel[col : col + BM, 0] = np.floor(255 * np.arange(0, BM, 1) / BM)
    col += BM

    # MR
    colorwheel[col : col + MR, 2] = 255 - np.floor(255 * np.arange(0, MR, 1) / MR)
    colorwheel[col : col + MR, 0] = 255

    return colorwheel / 255.0


def flow_to_color(
    flow: np.ndarray, max_flow: Optional[float] = None, convention: str = "middlebury"
) -> np.ndarray:
    """
    Convert optical flow to an RGB image using the standard color wheel.

    Args:
        flow: Optical flow array of shape (H, W, 2).

        max_flow: Normalization factor.
            - If None: Defaults to the max magnitude in the flow array.
            - If float: Any flow larger than this is clamped/desaturated.

        convention: Controls the default normalization behavior if max_flow is None.
            - 'middlebury': Always normalizes to the current image max.
            - 'kitti': Defaults to a fixed scale (e.g. 3.0) if max_flow is None.

    Returns:
        np.ndarray: RGB image of shape (H, W, 3).
            Valid flow is colored (0 motion = White).
            Invalid/NaN flow is Black.
    """

    assert flow.ndim == 3 and flow.shape[2] == 2

    u = flow[:, :, 0]
    v = flow[:, :, 1]

    # 1. Detect Invalid Flow (NaN, Inf, or > 1e9)
    # Middlebury uses 1e9 as a magic number for "unknown"
    idx_unknown = (np.abs(u) > 1e9) | (np.abs(v) > 1e9) | np.isnan(u) | np.isnan(v)

    # Temporarily clean data for calculation (avoid warnings)
    u = np.where(idx_unknown, 0, u)
    v = np.where(idx_unknown, 0, v)

    mag = np.sqrt(u**2 + v**2)
    angle = np.arctan2(-v, -u) / np.pi

    fk = (angle + 1) / 2 * (55 - 1)
    k0 = np.floor(fk).astype(int)
    k1 = k0 + 1
    k1[k1 == 55] = 0
    f = fk - k0

    colorwheel = _get_color_wheel()
    col0 = colorwheel[k0]
    col1 = colorwheel[k1]

    col = (1 - f)[:, :, None] * col0 + f[:, :, None] * col1

    if max_flow is None:
        if convention.lower() == "kitti":
            max_flow = 3.0
        else:
            max_flow = np.max(mag)
            if max_flow == 0:
                max_flow = 1.0

    # Blend from white (zero motion) to the full wheel color at max_flow,
    # matching the interpretation guide above (stationary pixels = white).
    rad = np.clip(mag / max_flow, 0.0, 1.0)
    col = 1 - rad[:, :, None] * (1 - col)

    # Handle saturation: desaturate pixels moving faster than max_flow
    idx_sat = mag > max_flow
    if np.any(idx_sat):
        col[idx_sat] = col[idx_sat] * 0.75 + 0.25

    col = np.clip(col, 0, 1)

    # 2. Apply Black for Unknown/Invalid pixels
    if np.any(idx_unknown):
        col[idx_unknown] = np.array([0, 0, 0])

    return (col * 255).astype(np.uint8)


def visualize_flow(
    flow: np.ndarray,
    ax: Optional[Axes] = None,
    title: str = "Optical Flow",
    max_flow: Optional[float] = None,
    convention: str = "middlebury",
    figsize: Tuple[int, int] = (8, 6),
) -> Tuple[Figure, Axes]:
    """
    Visualize dense optical flow.
    """
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.get_figure()

    rgb = flow_to_color(flow, max_flow=max_flow, convention=convention)

    ax.imshow(rgb)
    ax.set_title(title)
    ax.axis("off")

    return fig, ax


def visualize_flow_arrows(
    flow: np.ndarray,
    ax: Optional[Axes] = None,
    image: Optional[np.ndarray] = None,
    step: int = 16,
    scale: float = 1.0,
    color: str = "r",
    title: str = "Flow Vectors",
    figsize: Tuple[int, int] = (8, 6),
) -> Tuple[Figure, Axes]:
    """
    Visualize flow as sparse arrows (quiver plot).
    """
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.get_figure()

    h, w = flow.shape[:2]

    if image is not None:
        ax.imshow(image, cmap="gray" if image.ndim == 2 else None)
    else:
        ax.imshow(np.zeros((h, w, 3), dtype=np.uint8))

    y, x = np.mgrid[step // 2 : h : step, step // 2 : w : step]
    u = flow[y, x, 0]
    v = flow[y, x, 1]

    ax.quiver(
        x,
        y,
        u,
        -v,
        color=color,
        angles="xy",
        scale_units="xy",
        scale=1 / scale,
        width=0.003,
    )

    ax.set_title(title)
    ax.set_xlim(0, w)
    ax.set_ylim(h, 0)
    ax.axis("off")

    return fig, ax


def compare_flows(
    flow_pred: np.ndarray,
    flow_gt: np.ndarray,
    axes: Optional[Union[np.ndarray, List[Axes]]] = None,
    max_flow: Optional[float] = None,
    convention: str = "middlebury",
    titles: Tuple[str, str, str] = ("Prediction", "Ground Truth", "EPE Error"),
    figsize: Tuple[int, int] = (16, 5),
) -> Tuple[Figure, np.ndarray]:
    """
    Compare prediction vs ground truth and visualize error.
    """
    if axes is None:
        fig, axes = plt.subplots(1, 3, figsize=figsize)
    else:
        fig = axes[0].get_figure()
        if len(axes) != 3:
            raise ValueError("compare_flows requires 3 axes")

    # 1. Visualize Prediction
    visualize_flow(
        flow_pred, ax=axes[0], title=titles[0], max_flow=max_flow, convention=convention
    )

    # 2. Visualize GT
    visualize_flow(
        flow_gt, ax=axes[1], title=titles[1], max_flow=max_flow, convention=convention
    )

    # 3. Visualize Endpoint Error (EPE)
    epe = np.linalg.norm(flow_pred - flow_gt, axis=2)

    # We use a valid mask to compute mean error only on valid pixels
    # Otherwise NaN errors will break the visualization
    mask = (np.abs(flow_gt[:, :, 0]) < 1e9) & (np.abs(flow_gt[:, :, 1]) < 1e9)
    if mask.sum() > 0:
        mean_epe = epe[mask].mean()
    else:
        mean_epe = 0.0

    im_err = axes[2].imshow(epe, cmap="magma")
    axes[2].set_title(f"{titles[2]} (Mean: {mean_epe:.2f})")
    axes[2].axis("off")

    plt.colorbar(im_err, ax=axes[2], fraction=0.046, pad=0.04)

    plt.tight_layout()
    return fig, axes


def concatenate_flows(
    arrays: List[np.ndarray],
    positions: List[Tuple[int, int]],
    kwargs_list: Optional[List[Dict[str, Any]]] = None,
) -> np.ndarray:
    """
    Concatenates strictly typed images (RGB) and flow fields (Float) into a grid.

    The function handles color space conversions internally so you can use standard
    RGB colors for text and input images, while leveraging OpenCV for text rendering.

    Args:
        arrays: List of numpy arrays.
            - 3 Channels: MUST be RGB uint8 (Standard Image).
            - 2 Channels: MUST be Float Optical Flow.
        positions: List of (col, row) tuples determining grid placement.
        kwargs_list: Optional list of dictionaries (one per array) for customization.
            Supported keys:

            [Visual Params]
            - 'caption': str -> Text to write on the image.
            - 'font_scale': float -> Size of text (default: 1.0).
            - 'font_thickness': int -> Thickness of text (default: 2).
            - 'font_color': tuple -> RGB tuple (default: (255, 255, 255)).
            - 'text_pos': tuple -> (x, y) bottom-left position (default: (20, 40)).

            [Flow Params (Only used if array has 2 channels)]
            - 'max_flow': float -> Normalization max magnitude.
            - 'convention': str -> 'middlebury' or 'kitti'.

    Returns:
        np.ndarray: RGB image (H_grid, W_grid, 3) ready for Matplotlib/TensorBoard.
    """
    # OpenCV is an optional dependency (installed via the "full" extra)
    import cv2

    H, W = arrays[0].shape[:2]
    N = len(arrays)

    if kwargs_list is None:
        kwargs_list = [{}] * N

    # Canvas Size calculation
    max_u = max(p[0] for p in positions)
    max_v = max(p[1] for p in positions)

    # Create internal BGR canvas for OpenCV operations
    frame_bgr = np.zeros(((max_v + 1) * H, (max_u + 1) * W, 3), dtype=np.uint8)

    for arr, (u, v), kw in zip(arrays, positions, kwargs_list):
        # --- 1. Process Content (Strict Types -> BGR) ---
        if arr.shape[2] == 2:
            # TYPE: FLOW (Float) -> RGB -> BGR
            flow_args = {k: v for k, v in kw.items() if k in ["max_flow", "convention"]}
            rgb_flow = flow_to_color(arr, **flow_args)
            img_chunk_bgr = cv2.cvtColor(rgb_flow, cv2.COLOR_RGB2BGR)

        elif arr.shape[2] == 3:
            # TYPE: IMAGE (RGB uint8) -> BGR
            img_chunk_bgr = cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)

        else:
            raise ValueError(
                f"Array has {arr.shape[2]} channels. Expected 2 (Flow) or 3 (RGB Image)."
            )

        # --- 2. Add Text (OpenCV draws on BGR) ---
        caption = kw.get("caption", None)
        if caption:
            # User specifies color in RGB (e.g., Red=(255,0,0))
            # We flip to BGR because we are drawing on a BGR image
            font_color_rgb = kw.get("font_color", (255, 255, 255))
            font_color_bgr = font_color_rgb[::-1]

            cv2.putText(
                img_chunk_bgr,
                str(caption),
                kw.get("text_pos", (20, 40)),
                fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                fontScale=kw.get("font_scale", 1.0),
                color=font_color_bgr,
                thickness=kw.get("font_thickness", 2),
                lineType=cv2.LINE_AA,
            )

        # --- 3. Place in Grid ---
        y, x = v * H, u * W
        frame_bgr[y : y + H, x : x + W] = img_chunk_bgr

    # Convert final result back to RGB for return
    return cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
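A sketch (not part of the wheel) rendering a synthetic flow field with the utilities above; `concatenate_flows` needs OpenCV, which the `full` extra provides:

```python
import numpy as np
from of.visualization import flow_to_color, concatenate_flows

# Horizontal motion growing from -5 px (left) to +5 px (right)
h, w = 120, 160
u = np.tile(np.linspace(-5, 5, w, dtype=np.float32), (h, 1))
flow = np.stack([u, np.zeros_like(u)], axis=-1)

rgb = flow_to_color(flow, convention="middlebury")  # (120, 160, 3) uint8

# Grid with the raw flow (auto-rendered) next to the precomputed RGB
grid = concatenate_flows(
    [flow, rgb],
    positions=[(0, 0), (1, 0)],  # (col, row)
    kwargs_list=[{"caption": "flow"}, {"caption": "rgb"}],
)
print(grid.shape)  # (120, 320, 3)
```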
torch_nvidia_of_sdk-5.0.0.dist-info/METADATA
ADDED

@@ -0,0 +1,218 @@
Metadata-Version: 2.2
Name: torch-nvidia-of-sdk
Version: 5.0.0
Summary: PyTorch bindings for the NVIDIA Optical Flow SDK, providing hardware-accelerated optical flow computation with end-to-end PyTorch integration on NVIDIA GPUs.
Author-Email: Juan Montesinos <jfmontgar@gmail.com>
Requires-Python: >=3.13
Requires-Dist: imageio[ffmpeg]>=2.37.2
Requires-Dist: matplotlib>=3.10.8
Requires-Dist: numpy
Requires-Dist: torch>=2.10.0
Requires-Dist: tqdm>=4.67.2
Provides-Extra: full
Requires-Dist: opencv-contrib-python-headless>=4.13.0.90; extra == "full"
Description-Content-Type: text/markdown

# Torch Optical Flow

PyTorch bindings for the NVIDIA Optical Flow SDK, providing hardware-accelerated optical flow computation with
end-to-end PyTorch integration on NVIDIA GPUs.

Please read more about the NVIDIA Optical Flow SDK here: [https://developer.nvidia.com/optical-flow-sdk](https://developer.nvidia.com/optical-flow-sdk)

# What's this repo about?
- Hardware-accelerated optical flow using a dedicated processor on NVIDIA GPUs. No gradients are computed; this is for inference only.
- Frame interpolation, ROI, and other additional features of the SDK are not supported.
- Configurable speed presets (slow, medium, fast) and grid sizes (1, 2, 4).
- Support for the ABGR8 input format; plain RGB images are converted automatically.
- End-to-end GPU processing with PyTorch.
- Bidirectional optical flow computation (forward and backward) in a single call. This is supported by the SDK, but not exposed in the wrappers NVIDIA provides.

The package comes with basic functionality for optical flow:
- .flo reader and writer
- Common optical flow metrics
- Visualization utilities


# Requirements
### System Requirements
- NVIDIA GPU with Optical Flow SDK support (Turing, Ampere or Ada)
- Tested on Linux (Ubuntu); the SDK is compatible with Windows too. Read optical_flow_sdk>Read_Me.pdf for Windows instructions.

### Software Requirements
- CUDA toolkit >= 10.2
- Linux drivers ("nvidia-smi") >= 528.85
- GCC >= 5.1
- CMake >= 3.14
- When you pip install torch, it comes with its own CUDA binaries. Get the same or a higher CUDA toolkit version than your PyTorch installation.

# Installation
## pip / uv

## Build from Source
This repository uses uv.
A one-shot command to build, install, and test the package would be:

```bash
rm -rf build _skbuild .venv && CC=gcc CXX=g++ uv sync --extra full --reinstall-package torch-nvidia-of-sdk && uv run examples/minimal_example.py
```
`--reinstall-package` forces `uv` to re-compile the package. Clearing caches is not really needed but I'm paranoid.
`--extra full` is analogous to the pip extra `pip install torch-nvidia-of-sdk[full]`. It just adds headless OpenCV for visualization.
## Compiling your own wheel
`CC=gcc CXX=g++ uv build --wheel --package torch-nvidia-of-sdk` will build a wheel in `dist/` that you can install with pip.

# Quick Start

Try the minimal example to get started quickly:

```bash
# Run the minimal example (uses sample frames from assets/)
uv run examples/minimal_example.py
```

This will:
1. Load two sample frames from the `assets/` directory
2. Compute optical flow using NVOF
3. Generate visualizations and save results to `output/`

See [`examples/README.md`](examples/README.md) for more examples and tutorials.

# Basic Usage

```python
import torch
import numpy as np
from of import TorchNVOpticalFlow
from of.io import read_flo, write_flo
from of.visualization import flow_to_color

# Load your images (RGB format, uint8)
img1 = torch.from_numpy(np.array(...)).cuda()  # Shape: (H, W, 3)
img2 = torch.from_numpy(np.array(...)).cuda()

# Initialize optical flow engine
flow_engine = TorchNVOpticalFlow(
    width=img1.shape[1],
    height=img1.shape[0],
    gpu_id=0,
    preset="medium",  # "slow", "medium", or "fast"
    grid_size=1,      # 1, 2, or 4
)

# Compute optical flow
flow = flow_engine.compute_flow(img1, img2, upsample=True)

# Flow is a (H, W, 2) tensor where flow[..., 0] is x-displacement, flow[..., 1] is y-displacement
print(f"Flow shape: {flow.shape}")

# Visualize flow as RGB image
flow_rgb = flow_to_color(flow.cpu().numpy())

# Save flow to .flo file
write_flo("output_flow.flo", flow)
```

# API Reference

## Core Class: `TorchNVOpticalFlow`

### Constructor

```python
TorchNVOpticalFlow(
    width: int,
    height: int,
    gpu_id: int = 0,
    preset: str = "medium",
    grid_size: int = 1,
    bidirectional: bool = False
)
```

**Parameters:**
- `width`: Width of input images in pixels
- `height`: Height of input images in pixels
- `gpu_id`: CUDA device ID (default: 0)
- `preset`: Speed/quality preset. Options:
  - `"slow"`: Highest quality, slowest
  - `"medium"`: Balanced (recommended)
  - `"fast"`: Fastest, lower quality
- `grid_size`: Output grid size. Options: 1, 2, or 4
  - 1: Full resolution output (default)
  - 2/4: Downsampled output (faster, use with `upsample=True` to restore resolution)
- `bidirectional`: Enable bidirectional flow computation (forward and backward)

### Methods

#### `compute_flow(input, reference, upsample=True)`

Compute forward optical flow between two frames.

**Parameters:**
- `input`: First frame as CUDA tensor of shape `(H, W, 4)`, dtype `uint8`, RGBA format
- `reference`: Second frame as CUDA tensor of shape `(H, W, 4)`, dtype `uint8`, RGBA format
- `upsample`: If True and grid_size > 1, upsample flow to full resolution (default: True)

**Returns:**
- `torch.Tensor`: Optical flow of shape `(H, W, 2)`, dtype `float32`
  - `flow[..., 0]`: Horizontal displacement (x)
  - `flow[..., 1]`: Vertical displacement (y)

**Example:**
```python
flow = flow_engine.compute_flow(img1_rgba, img2_rgba, upsample=True)
```

#### `compute_flow_bidirectional(input, reference, upsample=True)`

Compute both forward and backward optical flow.

**Parameters:**
- `input`: First frame as CUDA tensor of shape `(H, W, 4)`, dtype `uint8`, RGBA format
- `reference`: Second frame as CUDA tensor of shape `(H, W, 4)`, dtype `uint8`, RGBA format
- `upsample`: If True and grid_size > 1, upsample flows to full resolution (default: True)

**Returns:**
- `Tuple[torch.Tensor, torch.Tensor]`: Forward and backward flows, each of shape `(H, W, 2)`

**Example:**
```python
forward_flow, backward_flow = flow_engine.compute_flow_bidirectional(
    img1_rgba, img2_rgba, upsample=True
)
```

#### `output_shape()`

Get the output shape for the current configuration.

**Returns:**
- `List[int]`: Output shape as `[height, width, 2]`

---

## I/O Utilities (`of.io`)

### `read_flo(filepath)`

Read optical flow from `.flo` file (Middlebury format).

**Parameters:**
- `filepath`: Path to `.flo` file (str or Path)

**Returns:**
- `np.ndarray`: Flow array of shape `(H, W, 2)`, dtype `float32`

### `write_flo(filepath, flow)`

Write optical flow to `.flo` file (Middlebury format).

**Parameters:**
- `filepath`: Output file path (str or Path)
- `flow`: Flow array of shape `(H, W, 2)` (numpy array or torch tensor)

# Examples

This repository includes several examples in the `examples/` directory:

See [`examples/README.md`](examples/README.md) for detailed documentation and usage instructions.
torch_nvidia_of_sdk-5.0.0.dist-info/RECORD
ADDED

@@ -0,0 +1,15 @@
of/__init__.py,sha256=4UKSViAFdZWuiGfxIbactVbQkQeA7hnoGZd2_PpdveM,105
of/datasets.py,sha256=Yf2l0vNEkzkEVo0Yn3dkV7RSPoTxos_WOIPlOHap5Nc,7629
of/io.py,sha256=AXMlM2zh6zAt5l4hbwTsHq3lpXXGA_8kX3mtTWZR3hs,2075
of/metrics.py,sha256=djeP2XkYt5FcGkQsdKhtynhSbBOLpemm9iz6I3Td3mg,5497
of/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
of/visualization.py,sha256=Wc-d7-QvTzvKIBSGC0eE62qEI9R95E1MbispRM_2uvI,11621
of/methods/__init__.py,sha256=wabPZD0BntWyXUcrAvbjY0hfMg1ArtHtD-bMTyvmgS8,77
of/methods/nvidia_sdk.py,sha256=0nZgsHCZrJXiFsPn6RhVdhrPWoHKvrQr2Pru0uLyRW4,2363
of/methods/nvof_torch.cpython-313-x86_64-linux-gnu.so,sha256=54urcZT1reYkrnfYxHK2nROZOa98otjsHk2A0R0Xuqw,313808
torch_nvidia_of_sdk.libs/libcuda-df918354.so.580.105.08,sha256=b022zSCMvDgYt5zAe0-ijrqSxG2JFnJMfFzmm063juc,96295392
torch_nvidia_of_sdk.libs/libcudart-381c0faa.so.12.9.37,sha256=aSNYYYNLW27faxYhvRRkFbGQDgOiSvsmnkYOVnP7oyc,753656
torch_nvidia_of_sdk-5.0.0.dist-info/METADATA,sha256=I6OgY1-CAzrcjVdF1tRAVzxrhmVfeeo77f2qbTPxVPE,7104
torch_nvidia_of_sdk-5.0.0.dist-info/WHEEL,sha256=JVXwG5HcbWr_fVpHgfrVgL0hYKvnp2fH_zuuklvZsD4,115
torch_nvidia_of_sdk-5.0.0.dist-info/RECORD,,
torch_nvidia_of_sdk-5.0.0.dist-info/sboms/auditwheel.cdx.json,sha256=3Gc-fQ9TGDEcEAu3DXWD50ff7jSXkLzkbsTPjFwfImY,1956
torch_nvidia_of_sdk-5.0.0.dist-info/sboms/auditwheel.cdx.json
ADDED

@@ -0,0 +1 @@
{"bomFormat": "CycloneDX", "specVersion": "1.4", "version": 1, "metadata": {"component": {"type": "library", "bom-ref": "pkg:pypi/torch_nvidia_of_sdk@5.0.0?file_name=torch_nvidia_of_sdk-5.0.0-py3-none-manylinux_2_34_x86_64.whl", "name": "torch_nvidia_of_sdk", "version": "5.0.0", "purl": "pkg:pypi/torch_nvidia_of_sdk@5.0.0?file_name=torch_nvidia_of_sdk-5.0.0-py3-none-manylinux_2_34_x86_64.whl"}, "tools": [{"name": "auditwheel", "version": "6.6.0"}]}, "components": [{"type": "library", "bom-ref": "pkg:pypi/torch_nvidia_of_sdk@5.0.0?file_name=torch_nvidia_of_sdk-5.0.0-py3-none-manylinux_2_34_x86_64.whl", "name": "torch_nvidia_of_sdk", "version": "5.0.0", "purl": "pkg:pypi/torch_nvidia_of_sdk@5.0.0?file_name=torch_nvidia_of_sdk-5.0.0-py3-none-manylinux_2_34_x86_64.whl"}, {"type": "library", "bom-ref": "pkg:deb/ubuntu/libnvidia-compute-580@580.105.08-0ubuntu1#4bd77591991c90dfe7081d8439e3bbfc934181610ce8d8a0ff77e83466d8866b", "name": "libnvidia-compute-580", "version": "580.105.08-0ubuntu1", "purl": "pkg:deb/ubuntu/libnvidia-compute-580@580.105.08-0ubuntu1"}, {"type": "library", "bom-ref": "pkg:deb/ubuntu/cuda-cudart-12-9@12.9.37-1#52366d11334f0bdf1ec18cb5b3389421d9afd26e46a0ea03841c7a6067e9ea15", "name": "cuda-cudart-12-9", "version": "12.9.37-1", "purl": "pkg:deb/ubuntu/cuda-cudart-12-9@12.9.37-1"}], "dependencies": [{"ref": "pkg:pypi/torch_nvidia_of_sdk@5.0.0?file_name=torch_nvidia_of_sdk-5.0.0-py3-none-manylinux_2_34_x86_64.whl", "dependsOn": ["pkg:deb/ubuntu/libnvidia-compute-580@580.105.08-0ubuntu1#4bd77591991c90dfe7081d8439e3bbfc934181610ce8d8a0ff77e83466d8866b", "pkg:deb/ubuntu/cuda-cudart-12-9@12.9.37-1#52366d11334f0bdf1ec18cb5b3389421d9afd26e46a0ea03841c7a6067e9ea15"]}, {"ref": "pkg:deb/ubuntu/libnvidia-compute-580@580.105.08-0ubuntu1#4bd77591991c90dfe7081d8439e3bbfc934181610ce8d8a0ff77e83466d8866b"}, {"ref": "pkg:deb/ubuntu/cuda-cudart-12-9@12.9.37-1#52366d11334f0bdf1ec18cb5b3389421d9afd26e46a0ea03841c7a6067e9ea15"}]}

torch_nvidia_of_sdk.libs/libcuda-df918354.so.580.105.08
Binary file

torch_nvidia_of_sdk.libs/libcudart-381c0faa.so.12.9.37
Binary file