vlalab 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vlalab/__init__.py +82 -0
- vlalab/adapters/__init__.py +10 -0
- vlalab/adapters/converter.py +146 -0
- vlalab/adapters/dp_adapter.py +181 -0
- vlalab/adapters/groot_adapter.py +148 -0
- vlalab/apps/__init__.py +1 -0
- vlalab/apps/streamlit/__init__.py +1 -0
- vlalab/apps/streamlit/app.py +103 -0
- vlalab/apps/streamlit/pages/__init__.py +1 -0
- vlalab/apps/streamlit/pages/dataset_viewer.py +322 -0
- vlalab/apps/streamlit/pages/inference_viewer.py +360 -0
- vlalab/apps/streamlit/pages/latency_viewer.py +256 -0
- vlalab/cli.py +137 -0
- vlalab/core.py +672 -0
- vlalab/logging/__init__.py +10 -0
- vlalab/logging/jsonl_writer.py +114 -0
- vlalab/logging/run_loader.py +216 -0
- vlalab/logging/run_logger.py +343 -0
- vlalab/schema/__init__.py +17 -0
- vlalab/schema/run.py +162 -0
- vlalab/schema/step.py +177 -0
- vlalab/viz/__init__.py +9 -0
- vlalab/viz/mpl_fonts.py +161 -0
- vlalab-0.1.0.dist-info/METADATA +443 -0
- vlalab-0.1.0.dist-info/RECORD +29 -0
- vlalab-0.1.0.dist-info/WHEEL +5 -0
- vlalab-0.1.0.dist-info/entry_points.txt +2 -0
- vlalab-0.1.0.dist-info/licenses/LICENSE +21 -0
- vlalab-0.1.0.dist-info/top_level.txt +1 -0

vlalab/apps/streamlit/pages/dataset_viewer.py

@@ -0,0 +1,322 @@
"""
VLA-Lab Dataset Viewer

Browse and analyze training/evaluation datasets in Zarr format.
"""

import streamlit as st
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from pathlib import Path
from typing import Optional, List

# Setup matplotlib fonts
try:
    from vlalab.viz.mpl_fonts import setup_matplotlib_fonts
    setup_matplotlib_fonts(verbose=False)
except Exception:
    pass


class ZarrDatasetViewer:
    """Viewer for Zarr datasets (Diffusion Policy format)."""

    def __init__(self, zarr_path: str):
        self.zarr_path = zarr_path
        self.valid = False
        self._load_data()

    def _load_data(self):
        """Load Zarr dataset."""
        try:
            import zarr
            self.root = zarr.open(self.zarr_path, mode='r')
            self.data = self.root['data']
            self.meta = self.root['meta']
            self.episode_ends = self.meta['episode_ends'][:]
            self.valid = True
        except ImportError:
            st.error("Please install zarr: pip install zarr")
            self.valid = False
            return  # without zarr there is no self.data to analyze below
        except Exception as e:
            st.error(f"Failed to load Zarr file: {e}")
            self.valid = False
            return

        # Analyze fields
        self.image_keys = []
        self.lowdim_keys = []
        self.action_key = 'action'

        for key in self.data.keys():
            arr = self.data[key]
            if key == 'action':
                self.action_key = key
            elif len(arr.shape) == 4:  # (T, H, W, C)
                self.image_keys.append(key)
            else:
                self.lowdim_keys.append(key)

    def get_episode_slice(self, episode_idx: int) -> slice:
        """Get slice for an episode."""
        start_idx = 0 if episode_idx == 0 else self.episode_ends[episode_idx - 1]
        end_idx = self.episode_ends[episode_idx]
        return slice(int(start_idx), int(end_idx))

    def get_episode_data(self, episode_idx: int) -> dict:
        """Get all data for an episode."""
        s = self.get_episode_slice(episode_idx)
        data = {}
        for key in self.data.keys():
            data[key] = self.data[key][s]
        return data
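
Aside: `meta/episode_ends` stores one cumulative end index per episode, so an episode's boundaries fall out of two lookups. A minimal sketch of the convention, with made-up numbers:

```python
import numpy as np

# Hypothetical buffer with three episodes of 120, 130, and 150 steps:
episode_ends = np.array([120, 250, 400])  # cumulative end index per episode

def episode_slice(episode_idx: int) -> slice:
    # Same convention as ZarrDatasetViewer.get_episode_slice above.
    start = 0 if episode_idx == 0 else int(episode_ends[episode_idx - 1])
    return slice(start, int(episode_ends[episode_idx]))

assert episode_slice(0) == slice(0, 120)
assert episode_slice(1) == slice(120, 250)
assert episode_slice(2) == slice(250, 400)
```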

    def plot_images_grid(self, episode_idx: int, step_interval: int = 5, max_frames: int = 20):
        """Plot image grid for an episode."""
        if not self.image_keys:
            st.warning("No image data in this dataset")
            return

        episode_data = self.get_episode_data(episode_idx)
        image_key = self.image_keys[0]
        images = episode_data[image_key]

        total_frames = len(images)
        frame_indices = list(range(0, total_frames, step_interval))[:max_frames]
        n_frames = len(frame_indices)

        n_cols = 5
        n_rows = (n_frames + n_cols - 1) // n_cols

        fig, axes = plt.subplots(n_rows, n_cols, figsize=(3*n_cols, 3*n_rows))
        st.write(f"**Episode {episode_idx} - image overview ({image_key})**")

        axes = np.atleast_2d(axes)

        for i, frame_idx in enumerate(frame_indices):
            row, col = i // n_cols, i % n_cols
            ax = axes[row, col]
            image = images[frame_idx]
            if image.shape[0] in [1, 3] and len(image.shape) == 3:
                # Convert channel-first (C, H, W) to channel-last (H, W, C)
                image = np.transpose(image, (1, 2, 0))
            ax.imshow(image)
            ax.set_title(f'Step {frame_idx}', fontsize=8)
            ax.axis('off')

        for i in range(n_frames, n_rows * n_cols):
            row, col = i // n_cols, i % n_cols
            axes[row, col].axis('off')

        plt.tight_layout()
        st.pyplot(fig)
        plt.close(fig)
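
The grid subsamples frames by stride and then caps the count. For instance, at the defaults (`step_interval=5`, `max_frames=20`) a 300-step episode renders steps 0, 5, ..., 95 in a 4x5 grid:

```python
# Worked example: a 300-step episode at the default settings.
total_frames, step_interval, max_frames, n_cols = 300, 5, 20, 5

frame_indices = list(range(0, total_frames, step_interval))[:max_frames]
n_rows = (len(frame_indices) + n_cols - 1) // n_cols  # ceiling division

assert frame_indices[0] == 0 and frame_indices[-1] == 95  # steps 0, 5, ..., 95
assert n_rows == 4                                        # a 4x5 grid
```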

    def plot_actions_summary(self, episode_idx: int):
        """Plot action summary for an episode."""
        episode_data = self.get_episode_data(episode_idx)
        actions = episode_data['action']
        T = len(actions)
        action_dim = actions.shape[1]

        fig = plt.figure(figsize=(12, 8))
        gs = GridSpec(2, 2, figure=fig)

        st.write(f"**Episode {episode_idx} - action overview**")

        # Position time series
        ax1 = fig.add_subplot(gs[0, :])
        for i, label in enumerate(['x', 'y', 'z']):
            if i < action_dim:
                ax1.plot(actions[:, i], label=label, alpha=0.8)
        ax1.set_title('Position')
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        # 3D trajectory
        ax2 = fig.add_subplot(gs[1, 0], projection='3d')
        if action_dim >= 3:
            ax2.plot(actions[:, 0], actions[:, 1], actions[:, 2], 'b-', alpha=0.6)
            ax2.scatter(actions[0, 0], actions[0, 1], actions[0, 2], c='g', label='Start')
            ax2.scatter(actions[-1, 0], actions[-1, 1], actions[-1, 2], c='r', label='End')
            ax2.set_title('3D trajectory')
            ax2.set_xlabel('X')
            ax2.set_ylabel('Y')
            ax2.set_zlabel('Z')

        # Gripper
        ax3 = fig.add_subplot(gs[1, 1])
        if action_dim >= 8:
            ax3.plot(actions[:, 7], 'k-', linewidth=2)
            ax3.set_title('Gripper')
            ax3.grid(True, alpha=0.3)
        else:
            ax3.text(0.5, 0.5, "No gripper data", ha='center')

        plt.tight_layout()
        st.pyplot(fig)
        plt.close(fig)
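
The column indexing in this method and the next implies an 8-D end-effector action layout: position first, then a 4-D orientation, then a scalar gripper. A sketch of that split (the exact quaternion component order is an assumption; the code only relies on the 3/4/1 grouping):

```python
import numpy as np

action = np.zeros(8)      # one 8-D action vector, as the plots above slice it
position = action[0:3]    # x, y, z: time-series and 3D-trajectory panels
quaternion = action[3:7]  # orientation, shown when action_dim >= 7
gripper = action[7]       # scalar; treated as open when > 0.5
```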

    def plot_interactive_step(self, episode_idx: int, step_idx: int):
        """Plot interactive step view."""
        episode_data = self.get_episode_data(episode_idx)
        actions = episode_data['action']

        # Get image
        image = None
        if self.image_keys:
            image_key = self.image_keys[0]
            raw_img = episode_data[image_key][step_idx]
            if raw_img.shape[0] in [1, 3] and len(raw_img.shape) == 3:
                image = np.transpose(raw_img, (1, 2, 0))
            else:
                image = raw_img

        # Get action
        current_action = actions[step_idx]
        action_dim = actions.shape[1]

        # Layout
        c1, c2 = st.columns([1.5, 1])

        with c1:
            st.markdown(f"#### 📸 Camera view (Step {step_idx})")
            if image is not None:
                st.image(image, use_container_width=True)
            else:
                st.warning("No image data")

        with c2:
            st.markdown("#### 🤖 Robot state")
            st.info(f"Step: **{step_idx}** / {len(actions)-1}")

            mc1, mc2, mc3 = st.columns(3)
            mc1.metric("X", f"{current_action[0]:.3f}")
            mc2.metric("Y", f"{current_action[1]:.3f}")
            mc3.metric("Z", f"{current_action[2]:.3f}")

            if action_dim >= 7:
                st.markdown("**Orientation (Quaternion):**")
                st.code(f"[{current_action[3]:.3f}, {current_action[4]:.3f}, {current_action[5]:.3f}, {current_action[6]:.3f}]")

            if action_dim >= 8:
                g_val = current_action[7]
                g_state = "OPEN" if g_val > 0.5 else "CLOSED"
                st.metric("Gripper", f"{g_val:.3f}", delta=g_state)

        st.divider()

        # Trajectory view
        st.markdown("#### 📍 Synchronized trajectory view")

        fig = plt.figure(figsize=(14, 5))
        gs = GridSpec(1, 2, figure=fig)

        # 3D trajectory
        ax1 = fig.add_subplot(gs[0, 0], projection='3d')
        ax1.plot(actions[:,0], actions[:,1], actions[:,2], 'b-', alpha=0.2, linewidth=1, label='Path')
        ax1.scatter(actions[0,0], actions[0,1], actions[0,2], c='g', s=20, alpha=0.5)
        ax1.scatter(actions[-1,0], actions[-1,1], actions[-1,2], c='gray', s=20, alpha=0.5)
        ax1.scatter(current_action[0], current_action[1], current_action[2],
                    c='r', s=150, edgecolors='k', label='Current', zorder=100)
        ax1.set_title("3D position (red = current)", fontsize=10)
        ax1.set_xlabel('X')
        ax1.set_ylabel('Y')
        ax1.set_zlabel('Z')

        # Time series
        ax2 = fig.add_subplot(gs[0, 1])
        steps = np.arange(len(actions))
        for i, label in enumerate(['x', 'y', 'z']):
            ax2.plot(steps, actions[:, i], label=label, alpha=0.6)
        ax2.axvline(x=step_idx, color='r', linestyle='--', linewidth=2, label='Current Step')
        ax2.set_title("XYZ over time (red line = current)", fontsize=10)
        ax2.legend(loc='upper right', fontsize=8)
        ax2.grid(True, alpha=0.3)

        plt.tight_layout()
        st.pyplot(fig)
        plt.close(fig)

    def plot_workspace_3d(self, sample_ratio: float = 0.1):
        """Plot 3D workspace distribution."""
        if not self.valid:
            return

        st.write("Sampling and rendering the 3D workspace...")
        actions = self.data['action'][:]
        n_samples = int(len(actions) * sample_ratio)
        indices = np.random.choice(len(actions), n_samples, replace=False)
        sampled = actions[indices]

        fig = plt.figure(figsize=(10, 6))
        ax = fig.add_subplot(111, projection='3d')
        img = ax.scatter(sampled[:, 0], sampled[:, 1], sampled[:, 2],
                         c=np.arange(n_samples), cmap='viridis', s=1, alpha=0.3)
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')
        ax.set_title(f'Global Workspace ({n_samples} points)')
        plt.colorbar(img, ax=ax, label='Time step order')
        st.pyplot(fig)
        plt.close(fig)


def render():
    """Render the dataset viewer page."""
    st.title("📊 Training Data Visualization")

    # Path input
    default_path = st.sidebar.text_input(
        "Zarr dataset path",
        value="/data0/vla-data/processed/Diffusion_Policy/data/001_assembly_chocolate/assembly_chocolate_300.zarr"
    )

    # Initialize session state
    if 'zarr_viz' not in st.session_state:
        st.session_state.zarr_viz = None

    # Load button
    if st.sidebar.button("Load/Reload dataset", type="primary"):
        if Path(default_path).exists():
            st.session_state.zarr_viz = ZarrDatasetViewer(default_path)
            st.success(f"Loaded: {Path(default_path).name}")
        else:
            st.error("Path does not exist!")

    viz = st.session_state.zarr_viz

    if viz and viz.valid:
        st.sidebar.markdown("---")
        st.sidebar.info(f"Episodes: {len(viz.episode_ends)}\nTotal Steps: {viz.episode_ends[-1]}")

        # Episode selection
        selected_ep = st.sidebar.selectbox("Select episode", range(len(viz.episode_ends)))

        # Tabs
        tab1, tab2, tab3 = st.tabs(["🔍 Step detail", "📊 Overview", "🧊 Spatial distribution"])

        with tab1:
            st.markdown(f"### Episode {selected_ep} - frame-by-frame analysis")
            ep_data = viz.get_episode_data(selected_ep)
            max_step = len(ep_data['action']) - 1
            step_idx = st.slider("⏱️ Timeline", 0, max_step, 0)
            viz.plot_interactive_step(selected_ep, step_idx)

        with tab2:
            col1, col2 = st.columns([1, 4])
            with col1:
                st.caption("Settings")
                interval = st.slider("Frame interval", 1, 20, 5)
                max_frames = st.slider("Max frames", 5, 50, 20)
            with col2:
                viz.plot_images_grid(selected_ep, step_interval=interval, max_frames=max_frames)
                st.divider()
                viz.plot_actions_summary(selected_ep)

        with tab3:
            ratio = st.slider("Sample ratio", 0.01, 1.0, 0.1, 0.01)
            if st.button("Generate 3D distribution"):
                viz.plot_workspace_3d(ratio)
    else:
        st.info('👈 Enter a path on the left and click "Load/Reload dataset" first')
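
As a usage note, the same Zarr layout can be sanity-checked outside Streamlit. A minimal sketch using only the group names `_load_data` reads (the dataset path is a placeholder):

```python
import zarr

root = zarr.open("path/to/dataset.zarr", mode="r")  # placeholder path
episode_ends = root["meta"]["episode_ends"][:]
print("episodes:", len(episode_ends), "| total steps:", int(episode_ends[-1]))

data = root["data"]
for key in data.keys():
    arr = data[key]
    kind = "image" if len(arr.shape) == 4 else "low-dim"
    print(f"  data/{key}: shape={arr.shape} ({kind})")
```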

vlalab/apps/streamlit/pages/inference_viewer.py

@@ -0,0 +1,360 @@
"""
VLA-Lab Inference Run Viewer

Step-by-step replay of inference sessions with multi-camera support.
"""

import streamlit as st
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import json
import base64
import cv2
from typing import Optional, Dict, Any

import vlalab

# Setup matplotlib fonts
try:
    from vlalab.viz.mpl_fonts import setup_matplotlib_fonts
    setup_matplotlib_fonts(verbose=False)
except Exception:
    pass


class InferenceRunViewer:
    """Viewer for VLA-Lab inference runs."""

    def __init__(self, run_path: str):
        self.run_path = Path(run_path)
        self.valid = False
        self.steps = []
        self.meta = {}
        self._load_data()

    def _load_data(self):
        """Load run data."""
        try:
            meta_path = self.run_path / "meta.json"
            steps_path = self.run_path / "steps.jsonl"

            if meta_path.exists():
                with open(meta_path, "r") as f:
                    self.meta = json.load(f)

            if steps_path.exists():
                self.steps = []
                with open(steps_path, "r") as f:
                    for line in f:
                        if line.strip():
                            self.steps.append(json.loads(line))

            self.valid = True
        except Exception as e:
            st.error(f"Failed to load run: {e}")
            self.valid = False

    def _get_latency_ms(self, timing_dict: Dict, key_base: str) -> float:
        """Get latency value in ms."""
        new_key = f"{key_base}_ms"
        if new_key in timing_dict and timing_dict[new_key] is not None:
            return timing_dict[new_key]
        if key_base in timing_dict and timing_dict[key_base] is not None:
            return timing_dict[key_base] * 1000
        return 0.0
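
`_get_latency_ms` normalizes two generations of timing records: newer `*_ms` keys already in milliseconds, and legacy bare keys in seconds that get scaled by 1000. Both sketched records below (invented values) would read back identically:

```python
# New-style timing record: values already in milliseconds.
timing_new = {"transport_latency_ms": 8.5, "inference_latency_ms": 42.0,
              "total_latency_ms": 61.0}

# Legacy record: bare keys in seconds, multiplied by 1000 on read.
timing_old = {"transport_latency": 0.0085, "inference_latency": 0.042,
              "total_latency": 0.061}
```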

    def load_image_from_ref(self, image_ref: Dict) -> Optional[np.ndarray]:
        """Load image from image reference."""
        if not image_ref:
            return None

        image_path = self.run_path / image_ref.get("path", "")
        if not image_path.exists():
            return None

        img = cv2.imread(str(image_path))
        if img is None:
            return None

        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    def get_step_image(self, step_idx: int) -> Optional[np.ndarray]:
        """Get image for a step."""
        if step_idx >= len(self.steps):
            return None

        step = self.steps[step_idx]
        obs = step.get("obs", {})
        images = obs.get("images", [])
        if images:
            return self.load_image_from_ref(images[0])
        return None

    def get_step_images(self, step_idx: int) -> Dict[str, np.ndarray]:
        """Get all camera images for a step as {camera_name: image_rgb}."""
        if step_idx >= len(self.steps):
            return {}
        step = self.steps[step_idx]
        obs = step.get("obs", {})
        images = obs.get("images", [])
        out: Dict[str, np.ndarray] = {}
        for ref in images or []:
            if not isinstance(ref, dict):
                continue
            cam = ref.get("camera_name", "default")
            img = self.load_image_from_ref(ref)
            if img is not None:
                out[str(cam)] = img
        return out

    def get_step_state(self, step_idx: int) -> np.ndarray:
        """Get state for a step."""
        if step_idx >= len(self.steps):
            return np.array([])

        step = self.steps[step_idx]
        obs = step.get("obs", {})
        state = obs.get("state", [])
        return np.array(state) if state else np.array([])

    def get_step_action(self, step_idx: int) -> np.ndarray:
        """Get action for a step."""
        if step_idx >= len(self.steps):
            return np.array([])

        step = self.steps[step_idx]
        action_data = step.get("action", {})
        values = action_data.get("values", [])
        return np.array(values) if values else np.array([])
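
Taken together, the accessors imply that each line of `steps.jsonl` is a JSON object shaped roughly like the sketch below. The field names come from the `get()` calls above; the concrete values and camera names are invented, and image paths resolve relative to the run directory that also holds `meta.json`:

```python
# Hypothetical steps.jsonl record; field names follow the accessors above,
# values are invented for illustration.
step_record = {
    "obs": {
        "images": [  # per-camera refs, paths relative to the run directory
            {"camera_name": "wrist", "path": "images/000042_wrist.png"},
            {"camera_name": "front", "path": "images/000042_front.png"},
        ],
        "state": [0.31, -0.05, 0.22],  # >= 3 dims enables the 3D panels
    },
    "action": {"values": [[0.32, -0.04, 0.23]]},  # predicted action (chunk)
    "timing": {"transport_latency_ms": 8.5, "inference_latency_ms": 42.0,
               "total_latency_ms": 61.0},
}
```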

    def get_all_states(self) -> np.ndarray:
        """Get all states as array."""
        states = [self.get_step_state(i) for i in range(len(self.steps))]
        valid_states = [s for s in states if len(s) > 0]
        return np.array(valid_states) if valid_states else np.array([])

    def plot_replay_frame(self, step_idx: int):
        """Plot replay frame for a step."""
        if not self.valid or step_idx >= len(self.steps):
            return

        step = self.steps[step_idx]
        current_state = self.get_step_state(step_idx)
        pred_action = self.get_step_action(step_idx)
        imgs = self.get_step_images(step_idx)
        timing = step.get("timing", {})

        # Layout
        c1, c2 = st.columns([1, 1.5])

        with c1:
            st.markdown("#### 👁️ Model visual observations")
            if imgs:
                # Display in a grid (up to 3 columns) to support multi-camera runs
                cam_names = list(imgs.keys())
                n_cols = min(3, max(1, len(cam_names)))
                cols = st.columns(n_cols)
                for i, cam in enumerate(cam_names):
                    with cols[i % n_cols]:
                        img = imgs[cam]
                        st.image(
                            img,
                            caption=f"{cam} | Step {step_idx} | {img.shape}",
                            use_container_width=True,
                        )
            else:
                st.warning("No image data")

            # Timing metrics
            if timing:
                t_transport = self._get_latency_ms(timing, "transport_latency")
                t_infer = self._get_latency_ms(timing, "inference_latency")
                total = self._get_latency_ms(timing, "total_latency")

                st.markdown("#### ⏱️ Latency diagnostics")
                col_t1, col_t2, col_t3 = st.columns(3)
                col_t1.metric("Transport", f"{t_transport:.0f} ms")
                col_t2.metric("Inference", f"{t_infer:.0f} ms")
                col_t3.metric("Total loop", f"{total:.0f} ms")

                if t_transport > 100:
                    st.error(f"⚠️ Transport latency too high ({t_transport:.0f}ms)! Check the network or SSH tunnel")

        with c2:
            st.markdown("#### 🗺️ 3D action planning")

            if len(current_state) >= 3:
                fig = plt.figure(figsize=(8, 6))
                ax = fig.add_subplot(111, projection='3d')

                # Plot history
                all_states = self.get_all_states()
                if len(all_states) > 0 and all_states.shape[1] >= 3:
                    start = max(0, step_idx - 50)
                    hist = all_states[start:step_idx+1]
                    if len(hist) > 1:
                        ax.plot(hist[:,0], hist[:,1], hist[:,2], 'k-', alpha=0.3, label='History')

                # Current position
                ax.scatter(current_state[0], current_state[1], current_state[2],
                           c='b', s=100, label='Current')

                # Predicted trajectory
                if len(pred_action) > 0 and pred_action.ndim == 2 and pred_action.shape[1] >= 3:
                    ax.plot(pred_action[:,0], pred_action[:,1], pred_action[:,2],
                            'r--', linewidth=2, label='Pred')
                    ax.scatter(pred_action[-1,0], pred_action[-1,1], pred_action[-1,2],
                               c='r', marker='x', s=100)

                ax.set_xlabel('X')
                ax.set_ylabel('Y')
                ax.set_zlabel('Z')
                ax.legend()

                # Set axis limits
                if len(all_states) > 0 and all_states.shape[1] >= 3:
                    margin = 0.1
                    ax.set_xlim(all_states[:,0].min()-margin, all_states[:,0].max()+margin)
                    ax.set_ylim(all_states[:,1].min()-margin, all_states[:,1].max()+margin)
                    ax.set_zlim(all_states[:,2].min()-margin, all_states[:,2].max()+margin)

                st.pyplot(fig)
                plt.close(fig)
            else:
                st.warning("State has too few dimensions to draw a 3D trajectory")

    def plot_latency_analysis(self):
        """Plot latency analysis chart."""
        if not self.steps:
            st.warning("This log contains no step data")
            return

        steps_range = range(len(self.steps))
        trans_lats = []
        infer_lats = []
        total_lats = []

        for step in self.steps:
            timing = step.get("timing", {})
            trans_lats.append(self._get_latency_ms(timing, "transport_latency"))
            infer_lats.append(self._get_latency_ms(timing, "inference_latency"))
            total_lats.append(self._get_latency_ms(timing, "total_latency"))

        if not any(total_lats):
            st.warning("This log contains no detailed latency data")
            return

        fig, ax = plt.subplots(figsize=(12, 5))
        ax.plot(steps_range, total_lats, color='gray', alpha=0.3, label='Total Loop')
        ax.plot(steps_range, trans_lats, color='orange', label='Transport (Network)')
        ax.plot(steps_range, infer_lats, color='blue', label='Inference (GPU)')

        ax.set_title("Latency breakdown (ms)")
        ax.set_xlabel("Step")
        ax.set_ylabel("Latency (ms)")
        ax.legend()
        ax.grid(True, alpha=0.3)
        ax.axhline(100, color='r', linestyle='--', alpha=0.5)
        ax.text(0, 105, '100ms Alert', color='r', fontsize=8)

        st.pyplot(fig)
        plt.close(fig)

        # Statistics
        col1, col2, col3, col4 = st.columns(4)
        valid_total = [t for t in total_lats if t > 0]
        valid_trans = [t for t in trans_lats if t > 0]
        valid_infer = [t for t in infer_lats if t > 0]

        if valid_total:
            col1.metric("Avg total latency", f"{np.mean(valid_total):.1f} ms")
        if valid_trans:
            col2.metric("Avg transport latency", f"{np.mean(valid_trans):.1f} ms")
        if valid_infer:
            col3.metric("Avg inference latency", f"{np.mean(valid_infer):.1f} ms")
        if valid_total:
            col4.metric("Max total latency", f"{np.max(valid_total):.1f} ms")


def render():
    """Render the inference viewer page."""
    st.title("🔬 Inference Run Replay")

    # Sidebar: show current runs directory
    runs_dir = vlalab.get_runs_dir()
    st.sidebar.markdown("### Logs directory")
    st.sidebar.code(str(runs_dir))

    # List projects
    projects = vlalab.list_projects()

    if not projects:
        st.info(f"No projects found. Create a run with `vlalab.init()` first.\n\nLogs directory: `{runs_dir}`")
        st.markdown("""
        **Tip**: set the `$VLALAB_DIR` environment variable to change where logs are stored.

        ```bash
        export VLALAB_DIR=/path/to/your/logs
        ```
        """)
        return

    # Project filter
    selected_project = st.sidebar.selectbox(
        "Select project",
        ["All"] + projects,
    )

    # List runs
    if selected_project == "All":
        run_paths = vlalab.list_runs()
    else:
        run_paths = vlalab.list_runs(project=selected_project)

    if not run_paths:
        st.info("No runs under this project.")
        return

    # Select run
    selected_path = st.sidebar.selectbox(
        "Select run",
        run_paths,
        format_func=lambda p: f"{p.name} ({p.parent.name})"
    )

    if selected_path is None:
        return

    # Load viewer
    if 'viewer' not in st.session_state or st.session_state.get('last_run') != str(selected_path):
        st.session_state.viewer = InferenceRunViewer(str(selected_path))
        st.session_state.last_run = str(selected_path)

    viewer = st.session_state.viewer

    if not viewer.valid:
        return

    # Show metadata
    st.sidebar.markdown("---")
    st.sidebar.markdown("### Run info")
    st.sidebar.info(f"Steps: {len(viewer.steps)}")
    if viewer.meta:
        model = viewer.meta.get("model_name", "unknown")
        if isinstance(model, str) and len(model) > 30:
            model = "..." + model[-30:]
        st.sidebar.info(f"Model: {model}")

    # Tabs
    tab1, tab2 = st.tabs(["📺 Step replay", "📈 Performance"])

    with tab1:
        if viewer.steps:
            step_idx = st.slider("Step", 0, len(viewer.steps)-1, 0)
            viewer.plot_replay_frame(step_idx)
        else:
            st.warning("No step data")

    with tab2:
        viewer.plot_latency_analysis()
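
Finally, a minimal sketch of inspecting a run outside Streamlit, using only the `vlalab` helpers this page already calls (`get_runs_dir`, `list_runs`) plus the plain-JSONL layout assumed above:

```python
import json
import vlalab

print("runs dir:", vlalab.get_runs_dir())

run_paths = vlalab.list_runs()  # or vlalab.list_runs(project="my_project")
if run_paths:
    run = run_paths[0]          # a Path to one run directory
    with open(run / "steps.jsonl") as f:
        steps = [json.loads(line) for line in f if line.strip()]
    print(run.name, "->", len(steps), "steps")
```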