vlalab 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,322 @@
+ """
+ VLA-Lab Dataset Viewer
+
+ Browse and analyze training/evaluation datasets in Zarr format.
+ """
+
+ import streamlit as st
+ import numpy as np
+ import matplotlib.pyplot as plt
+ from matplotlib.gridspec import GridSpec
+ from pathlib import Path
+
+ # Set up matplotlib fonts
+ try:
+     from vlalab.viz.mpl_fonts import setup_matplotlib_fonts
+     setup_matplotlib_fonts(verbose=False)
+ except Exception:
+     pass
+
+
+ class ZarrDatasetViewer:
+     """Viewer for Zarr datasets (Diffusion Policy format)."""
+
+     def __init__(self, zarr_path: str):
+         self.zarr_path = zarr_path
+         self.valid = False
+         self._load_data()
+
+     def _load_data(self):
+         """Load Zarr dataset."""
+         try:
+             import zarr
+             self.root = zarr.open(self.zarr_path, mode='r')
+             self.data = self.root['data']
+             self.meta = self.root['meta']
+             self.episode_ends = self.meta['episode_ends'][:]
+             self.valid = True
+         except ImportError:
+             st.error("Please install zarr: pip install zarr")
+             self.valid = False
+             return
+         except Exception as e:
+             st.error(f"Failed to load Zarr file: {e}")
+             self.valid = False
+             return
+
+         # Classify fields: 4-D arrays are treated as image streams
+         self.image_keys = []
+         self.lowdim_keys = []
+         self.action_key = 'action'
+
+         for key in self.data.keys():
+             arr = self.data[key]
+             if key == 'action':
+                 self.action_key = key
+             elif len(arr.shape) == 4:  # (T, H, W, C)
+                 self.image_keys.append(key)
+             else:
+                 self.lowdim_keys.append(key)
+
+     def get_episode_slice(self, episode_idx: int) -> slice:
+         """Get the flat-index slice for an episode."""
+         # episode_ends stores cumulative end indices, so episode i spans
+         # [episode_ends[i-1], episode_ends[i]) in the concatenated arrays.
+         start_idx = 0 if episode_idx == 0 else self.episode_ends[episode_idx - 1]
+         end_idx = self.episode_ends[episode_idx]
+         return slice(int(start_idx), int(end_idx))
+
+     def get_episode_data(self, episode_idx: int) -> dict:
+         """Get all data for an episode."""
+         s = self.get_episode_slice(episode_idx)
+         data = {}
+         for key in self.data.keys():
+             data[key] = self.data[key][s]
+         return data
+
+     def plot_images_grid(self, episode_idx: int, step_interval: int = 5, max_frames: int = 20):
+         """Plot an image grid for an episode."""
+         if not self.image_keys:
+             st.warning("No image data in this dataset")
+             return
+
+         episode_data = self.get_episode_data(episode_idx)
+         image_key = self.image_keys[0]
+         images = episode_data[image_key]
+
+         total_frames = len(images)
+         frame_indices = list(range(0, total_frames, step_interval))[:max_frames]
+         n_frames = len(frame_indices)
+
+         n_cols = 5
+         n_rows = (n_frames + n_cols - 1) // n_cols
+
+         fig, axes = plt.subplots(n_rows, n_cols, figsize=(3*n_cols, 3*n_rows))
+         st.write(f"**Episode {episode_idx} - image overview ({image_key})**")
+
+         axes = np.atleast_2d(axes)
+
+         for i, frame_idx in enumerate(frame_indices):
+             row, col = i // n_cols, i % n_cols
+             ax = axes[row, col]
+             image = images[frame_idx]
+             # Convert channel-first (C, H, W) frames to channel-last for imshow
+             if len(image.shape) == 3 and image.shape[0] in (1, 3):
+                 image = np.transpose(image, (1, 2, 0))
+             ax.imshow(image)
+             ax.set_title(f'Step {frame_idx}', fontsize=8)
+             ax.axis('off')
+
+         # Hide unused grid cells
+         for i in range(n_frames, n_rows * n_cols):
+             row, col = i // n_cols, i % n_cols
+             axes[row, col].axis('off')
+
+         plt.tight_layout()
+         st.pyplot(fig)
+         plt.close(fig)
+
+     def plot_actions_summary(self, episode_idx: int):
+         """Plot an action summary for an episode."""
+         episode_data = self.get_episode_data(episode_idx)
+         actions = episode_data['action']
+         action_dim = actions.shape[1]
+
+         fig = plt.figure(figsize=(12, 8))
+         gs = GridSpec(2, 2, figure=fig)
+
+         st.write(f"**Episode {episode_idx} - global action analysis**")
+
+         # Position time series
+         ax1 = fig.add_subplot(gs[0, :])
+         for i, label in enumerate(['x', 'y', 'z']):
+             if i < action_dim:
+                 ax1.plot(actions[:, i], label=label, alpha=0.8)
+         ax1.set_title('Position')
+         ax1.legend()
+         ax1.grid(True, alpha=0.3)
+
+         # 3D trajectory
+         ax2 = fig.add_subplot(gs[1, 0], projection='3d')
+         if action_dim >= 3:
+             ax2.plot(actions[:, 0], actions[:, 1], actions[:, 2], 'b-', alpha=0.6)
+             ax2.scatter(actions[0, 0], actions[0, 1], actions[0, 2], c='g', label='Start')
+             ax2.scatter(actions[-1, 0], actions[-1, 1], actions[-1, 2], c='r', label='End')
+         ax2.set_title('3D trajectory')
+         ax2.set_xlabel('X')
+         ax2.set_ylabel('Y')
+         ax2.set_zlabel('Z')
+
+         # Gripper
+         ax3 = fig.add_subplot(gs[1, 1])
+         if action_dim >= 8:
+             ax3.plot(actions[:, 7], 'k-', linewidth=2)
+             ax3.set_title('Gripper')
+             ax3.grid(True, alpha=0.3)
+         else:
+             ax3.text(0.5, 0.5, "No gripper data", ha='center')
+
+         plt.tight_layout()
+         st.pyplot(fig)
+         plt.close(fig)
+
+     def plot_interactive_step(self, episode_idx: int, step_idx: int):
+         """Plot the interactive per-step view."""
+         episode_data = self.get_episode_data(episode_idx)
+         actions = episode_data['action']
+
+         # Get image (convert channel-first frames to channel-last)
+         image = None
+         if self.image_keys:
+             image_key = self.image_keys[0]
+             raw_img = episode_data[image_key][step_idx]
+             if len(raw_img.shape) == 3 and raw_img.shape[0] in (1, 3):
+                 image = np.transpose(raw_img, (1, 2, 0))
+             else:
+                 image = raw_img
+
+         # Get action
+         current_action = actions[step_idx]
+         action_dim = actions.shape[1]
+
+         # Layout
+         c1, c2 = st.columns([1.5, 1])
+
+         with c1:
+             st.markdown(f"#### 📸 Camera view (Step {step_idx})")
+             if image is not None:
+                 st.image(image, use_container_width=True)
+             else:
+                 st.warning("No image data")
+
+         with c2:
+             st.markdown("#### 🤖 Robot state")
+             st.info(f"Step: **{step_idx}** / {len(actions)-1}")
+
+             mc1, mc2, mc3 = st.columns(3)
+             mc1.metric("X", f"{current_action[0]:.3f}")
+             mc2.metric("Y", f"{current_action[1]:.3f}")
+             mc3.metric("Z", f"{current_action[2]:.3f}")
+
+             if action_dim >= 7:
+                 st.markdown("**Orientation (quaternion):**")
+                 st.code(f"[{current_action[3]:.3f}, {current_action[4]:.3f}, {current_action[5]:.3f}, {current_action[6]:.3f}]")
+
+             if action_dim >= 8:
+                 g_val = current_action[7]
+                 g_state = "OPEN" if g_val > 0.5 else "CLOSED"
+                 st.metric("Gripper", f"{g_val:.3f}", delta=g_state)
+
+         st.divider()
+
+         # Trajectory view
+         st.markdown("#### 📍 Synchronized trajectory view")
+
+         fig = plt.figure(figsize=(14, 5))
+         gs = GridSpec(1, 2, figure=fig)
+
+         # 3D trajectory
+         ax1 = fig.add_subplot(gs[0, 0], projection='3d')
+         ax1.plot(actions[:, 0], actions[:, 1], actions[:, 2], 'b-', alpha=0.2, linewidth=1, label='Path')
+         ax1.scatter(actions[0, 0], actions[0, 1], actions[0, 2], c='g', s=20, alpha=0.5)
+         ax1.scatter(actions[-1, 0], actions[-1, 1], actions[-1, 2], c='gray', s=20, alpha=0.5)
+         ax1.scatter(current_action[0], current_action[1], current_action[2],
+                     c='r', s=150, edgecolors='k', label='Current', zorder=100)
+         ax1.set_title("3D position (red dot = current)", fontsize=10)
+         ax1.set_xlabel('X')
+         ax1.set_ylabel('Y')
+         ax1.set_zlabel('Z')
+
+         # Time series
+         ax2 = fig.add_subplot(gs[0, 1])
+         steps = np.arange(len(actions))
+         for i, label in enumerate(['x', 'y', 'z']):
+             ax2.plot(steps, actions[:, i], label=label, alpha=0.6)
+         ax2.axvline(x=step_idx, color='r', linestyle='--', linewidth=2, label='Current Step')
+         ax2.set_title("XYZ over time (red line = current)", fontsize=10)
+         ax2.legend(loc='upper right', fontsize=8)
+         ax2.grid(True, alpha=0.3)
+
+         plt.tight_layout()
+         st.pyplot(fig)
+         plt.close(fig)
+
+     def plot_workspace_3d(self, sample_ratio: float = 0.1):
+         """Plot the 3D workspace distribution."""
+         if not self.valid:
+             return
+
+         st.write("Sampling and rendering the 3D workspace...")
+         actions = self.data['action'][:]
+         n_samples = max(1, int(len(actions) * sample_ratio))
+         # Sort the sampled indices so the colormap reflects actual time order
+         indices = np.sort(np.random.choice(len(actions), n_samples, replace=False))
+         sampled = actions[indices]
+
+         fig = plt.figure(figsize=(10, 6))
+         ax = fig.add_subplot(111, projection='3d')
+         img = ax.scatter(sampled[:, 0], sampled[:, 1], sampled[:, 2],
+                          c=indices, cmap='viridis', s=1, alpha=0.3)
+         ax.set_xlabel('X')
+         ax.set_ylabel('Y')
+         ax.set_zlabel('Z')
+         ax.set_title(f'Global Workspace ({n_samples} points)')
+         plt.colorbar(img, ax=ax, label='Time step order')
+         st.pyplot(fig)
+         plt.close(fig)
+
+
+ def render():
+     """Render the dataset viewer page."""
+     st.title("📊 Training Data Visualization")
+
+     # Path input
+     default_path = st.sidebar.text_input(
+         "Zarr dataset path",
+         value="/data0/vla-data/processed/Diffusion_Policy/data/001_assembly_chocolate/assembly_chocolate_300.zarr"
+     )
+
+     # Initialize session state
+     if 'zarr_viz' not in st.session_state:
+         st.session_state.zarr_viz = None
+
+     # Load button
+     if st.sidebar.button("Load/Reload dataset", type="primary"):
+         if Path(default_path).exists():
+             st.session_state.zarr_viz = ZarrDatasetViewer(default_path)
+             st.success(f"Loaded: {Path(default_path).name}")
+         else:
+             st.error("Path does not exist!")
+
+     viz = st.session_state.zarr_viz
+
+     if viz and viz.valid:
+         st.sidebar.markdown("---")
+         st.sidebar.info(f"Episodes: {len(viz.episode_ends)}\nTotal Steps: {viz.episode_ends[-1]}")
+
+         # Episode selection
+         selected_ep = st.sidebar.selectbox("Select episode", range(len(viz.episode_ends)))
+
+         # Tabs
+         tab1, tab2, tab3 = st.tabs(["🔍 Interactive detail", "📊 Global overview", "🧊 Spatial distribution"])
+
+         with tab1:
+             st.markdown(f"### Episode {selected_ep} - frame-by-frame analysis")
+             ep_data = viz.get_episode_data(selected_ep)
+             max_step = len(ep_data['action']) - 1
+             step_idx = st.slider("⏱️ Timeline", 0, max_step, 0)
+             viz.plot_interactive_step(selected_ep, step_idx)
+
+         with tab2:
+             col1, col2 = st.columns([1, 4])
+             with col1:
+                 st.caption("Settings")
+                 interval = st.slider("Sampling interval", 1, 20, 5)
+                 max_frames = st.slider("Max frames", 5, 50, 20)
+             with col2:
+                 viz.plot_images_grid(selected_ep, step_interval=interval, max_frames=max_frames)
+                 st.divider()
+                 viz.plot_actions_summary(selected_ep)
+
+         with tab3:
+             ratio = st.slider("Sampling ratio", 0.01, 1.0, 0.1, 0.01)
+             if st.button("Generate 3D distribution plot"):
+                 viz.plot_workspace_3d(ratio)
+     else:
+         st.info('👈 Enter a path on the left and click "Load/Reload dataset" first')
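For reference, `ZarrDatasetViewer` assumes the Diffusion Policy replay-buffer layout: a `data` group holding one array per field (`action`, plus 4-D `(T, H, W, C)` image arrays) and a `meta/episode_ends` array of cumulative episode end indices. A minimal sketch of that slicing convention in plain NumPy, with made-up lengths (this mirrors `get_episode_slice` above; it is not an official schema):

```python
import numpy as np

# Hypothetical flat dataset: 3 episodes of lengths 4, 3, and 5 steps.
episode_ends = np.array([4, 7, 12])   # cumulative end indices
actions = np.zeros((12, 8))           # (total_steps, action_dim) stand-in

def episode_slice(episode_idx: int) -> slice:
    # Same rule as ZarrDatasetViewer.get_episode_slice.
    start = 0 if episode_idx == 0 else int(episode_ends[episode_idx - 1])
    return slice(start, int(episode_ends[episode_idx]))

for ep in range(len(episode_ends)):
    s = episode_slice(ep)
    print(ep, s, len(actions[s]))  # -> episode lengths 4, 3, 5
```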
@@ -0,0 +1,360 @@
+ """
+ VLA-Lab Inference Run Viewer
+
+ Step-by-step replay of inference sessions with multi-camera support.
+ """
+
+ import streamlit as st
+ import numpy as np
+ import matplotlib.pyplot as plt
+ from pathlib import Path
+ import json
+ import cv2
+ from typing import Optional, Dict
+
+ import vlalab
+
+ # Set up matplotlib fonts
+ try:
+     from vlalab.viz.mpl_fonts import setup_matplotlib_fonts
+     setup_matplotlib_fonts(verbose=False)
+ except Exception:
+     pass
+
+
+ class InferenceRunViewer:
+     """Viewer for VLA-Lab inference runs."""
+
+     def __init__(self, run_path: str):
+         self.run_path = Path(run_path)
+         self.valid = False
+         self.steps = []
+         self.meta = {}
+         self._load_data()
+
+     def _load_data(self):
+         """Load run data (meta.json and steps.jsonl)."""
+         try:
+             meta_path = self.run_path / "meta.json"
+             steps_path = self.run_path / "steps.jsonl"
+
+             if meta_path.exists():
+                 with open(meta_path, "r") as f:
+                     self.meta = json.load(f)
+
+             if steps_path.exists():
+                 self.steps = []
+                 with open(steps_path, "r") as f:
+                     for line in f:
+                         if line.strip():
+                             self.steps.append(json.loads(line))
+
+             self.valid = True
+         except Exception as e:
+             st.error(f"Failed to load run: {e}")
+             self.valid = False
+
+     def _get_latency_ms(self, timing_dict: Dict, key_base: str) -> float:
+         """Get a latency value in ms.
+
+         Prefers the '<key>_ms' field; falls back to the bare '<key>'
+         field, which stored seconds, and converts it to ms.
+         """
+         new_key = f"{key_base}_ms"
+         if new_key in timing_dict and timing_dict[new_key] is not None:
+             return timing_dict[new_key]
+         if key_base in timing_dict and timing_dict[key_base] is not None:
+             return timing_dict[key_base] * 1000
+         return 0.0
+
+     def load_image_from_ref(self, image_ref: Dict) -> Optional[np.ndarray]:
+         """Load an image from an image reference dict."""
+         if not image_ref:
+             return None
+
+         image_path = self.run_path / image_ref.get("path", "")
+         if not image_path.exists():
+             return None
+
+         img = cv2.imread(str(image_path))
+         if img is None:
+             return None
+
+         # OpenCV loads BGR; convert to RGB for display
+         return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+
+     def get_step_image(self, step_idx: int) -> Optional[np.ndarray]:
+         """Get image for a step."""
+         if step_idx >= len(self.steps):
+             return None
+
+         step = self.steps[step_idx]
+         obs = step.get("obs", {})
+         images = obs.get("images", [])
+         if images:
+             return self.load_image_from_ref(images[0])
+         return None
+
+     def get_step_images(self, step_idx: int) -> Dict[str, np.ndarray]:
+         """Get all camera images for a step as {camera_name: image_rgb}."""
+         if step_idx >= len(self.steps):
+             return {}
+         step = self.steps[step_idx]
+         obs = step.get("obs", {})
+         images = obs.get("images", [])
+         out: Dict[str, np.ndarray] = {}
+         for ref in images or []:
+             if not isinstance(ref, dict):
+                 continue
+             cam = ref.get("camera_name", "default")
+             img = self.load_image_from_ref(ref)
+             if img is not None:
+                 out[str(cam)] = img
+         return out
+
+     def get_step_state(self, step_idx: int) -> np.ndarray:
+         """Get state for a step."""
+         if step_idx >= len(self.steps):
+             return np.array([])
+
+         step = self.steps[step_idx]
+         obs = step.get("obs", {})
+         state = obs.get("state", [])
+         return np.array(state) if state else np.array([])
+
+     def get_step_action(self, step_idx: int) -> np.ndarray:
+         """Get action for a step."""
+         if step_idx >= len(self.steps):
+             return np.array([])
+
+         step = self.steps[step_idx]
+         action_data = step.get("action", {})
+         values = action_data.get("values", [])
+         return np.array(values) if values else np.array([])
+
+     def get_all_states(self) -> np.ndarray:
+         """Get all states as a single array."""
+         states = [self.get_step_state(i) for i in range(len(self.steps))]
+         valid_states = [s for s in states if len(s) > 0]
+         return np.array(valid_states) if valid_states else np.array([])
+
+     def plot_replay_frame(self, step_idx: int):
+         """Plot the replay frame for a step."""
+         if not self.valid or step_idx >= len(self.steps):
+             return
+
+         step = self.steps[step_idx]
+         current_state = self.get_step_state(step_idx)
+         pred_action = self.get_step_action(step_idx)
+         imgs = self.get_step_images(step_idx)
+         timing = step.get("timing", {})
+
+         # Layout
+         c1, c2 = st.columns([1, 1.5])
+
+         with c1:
+             st.markdown("#### 👁️ Model visual observations")
+             if imgs:
+                 # Display in a grid (up to 3 columns) to support multi-camera runs
+                 cam_names = list(imgs.keys())
+                 n_cols = min(3, max(1, len(cam_names)))
+                 cols = st.columns(n_cols)
+                 for i, cam in enumerate(cam_names):
+                     with cols[i % n_cols]:
+                         img = imgs[cam]
+                         st.image(
+                             img,
+                             caption=f"{cam} | Step {step_idx} | {img.shape}",
+                             use_container_width=True,
+                         )
+             else:
+                 st.warning("No image data")
+
+             # Timing metrics
+             if timing:
+                 t_transport = self._get_latency_ms(timing, "transport_latency")
+                 t_infer = self._get_latency_ms(timing, "inference_latency")
+                 total = self._get_latency_ms(timing, "total_latency")
+
+                 st.markdown("#### ⏱️ Latency diagnostics")
+                 col_t1, col_t2, col_t3 = st.columns(3)
+                 col_t1.metric("Transport latency", f"{t_transport:.0f} ms")
+                 col_t2.metric("Inference time", f"{t_infer:.0f} ms")
+                 col_t3.metric("Total loop", f"{total:.0f} ms")
+
+                 if t_transport > 100:
+                     st.error(f"⚠️ Transport latency too high ({t_transport:.0f} ms)! Check the network or SSH tunnel")
+
+         with c2:
+             st.markdown("#### 🗺️ 3D action plan")
+
+             if len(current_state) >= 3:
+                 fig = plt.figure(figsize=(8, 6))
+                 ax = fig.add_subplot(111, projection='3d')
+
+                 # Plot recent history (guard against empty or ragged state logs)
+                 all_states = self.get_all_states()
+                 if all_states.ndim == 2 and all_states.shape[1] >= 3:
+                     start = max(0, step_idx - 50)
+                     hist = all_states[start:step_idx+1]
+                     if len(hist) > 1:
+                         ax.plot(hist[:, 0], hist[:, 1], hist[:, 2], 'k-', alpha=0.3, label='History')
+
+                 # Current position
+                 ax.scatter(current_state[0], current_state[1], current_state[2],
+                            c='b', s=100, label='Current')
+
+                 # Predicted trajectory
+                 if len(pred_action) > 0 and pred_action.ndim == 2 and pred_action.shape[1] >= 3:
+                     ax.plot(pred_action[:, 0], pred_action[:, 1], pred_action[:, 2],
+                             'r--', linewidth=2, label='Pred')
+                     ax.scatter(pred_action[-1, 0], pred_action[-1, 1], pred_action[-1, 2],
+                                c='r', marker='x', s=100)
+
+                 ax.set_xlabel('X')
+                 ax.set_ylabel('Y')
+                 ax.set_zlabel('Z')
+                 ax.legend()
+
+                 # Set axis limits from the full state history
+                 if all_states.ndim == 2 and all_states.shape[1] >= 3:
+                     margin = 0.1
+                     ax.set_xlim(all_states[:, 0].min()-margin, all_states[:, 0].max()+margin)
+                     ax.set_ylim(all_states[:, 1].min()-margin, all_states[:, 1].max()+margin)
+                     ax.set_zlim(all_states[:, 2].min()-margin, all_states[:, 2].max()+margin)
+
+                 st.pyplot(fig)
+                 plt.close(fig)
+             else:
+                 st.warning("State dimension too low to plot a 3D trajectory")
+
+     def plot_latency_analysis(self):
+         """Plot the latency analysis chart."""
+         if not self.steps:
+             st.warning("This log contains no step data")
+             return
+
+         steps_range = range(len(self.steps))
+         trans_lats = []
+         infer_lats = []
+         total_lats = []
+
+         for step in self.steps:
+             timing = step.get("timing", {})
+             trans_lats.append(self._get_latency_ms(timing, "transport_latency"))
+             infer_lats.append(self._get_latency_ms(timing, "inference_latency"))
+             total_lats.append(self._get_latency_ms(timing, "total_latency"))
+
+         if not any(total_lats):
+             st.warning("This log contains no detailed latency data")
+             return
+
+         fig, ax = plt.subplots(figsize=(12, 5))
+         ax.plot(steps_range, total_lats, color='gray', alpha=0.3, label='Total Loop')
+         ax.plot(steps_range, trans_lats, color='orange', label='Transport (Network)')
+         ax.plot(steps_range, infer_lats, color='blue', label='Inference (GPU)')
+
+         ax.set_title("Latency breakdown (ms)")
+         ax.set_xlabel("Step")
+         ax.set_ylabel("Latency (ms)")
+         ax.legend()
+         ax.grid(True, alpha=0.3)
+         ax.axhline(100, color='r', linestyle='--', alpha=0.5)
+         ax.text(0, 105, '100ms Alert', color='r', fontsize=8)
+
+         st.pyplot(fig)
+         plt.close(fig)
+
+         # Statistics (ignore steps with no recorded latency)
+         col1, col2, col3, col4 = st.columns(4)
+         valid_total = [t for t in total_lats if t > 0]
+         valid_trans = [t for t in trans_lats if t > 0]
+         valid_infer = [t for t in infer_lats if t > 0]
+
+         if valid_total:
+             col1.metric("Mean total latency", f"{np.mean(valid_total):.1f} ms")
+         if valid_trans:
+             col2.metric("Mean transport latency", f"{np.mean(valid_trans):.1f} ms")
+         if valid_infer:
+             col3.metric("Mean inference latency", f"{np.mean(valid_infer):.1f} ms")
+         if valid_total:
+             col4.metric("Max total latency", f"{np.max(valid_total):.1f} ms")
+
+
+ def render():
+     """Render the inference viewer page."""
+     st.title("🔬 Inference Run Replay")
+
+     # Sidebar: show the current runs directory
+     runs_dir = vlalab.get_runs_dir()
+     st.sidebar.markdown("### Log directory")
+     st.sidebar.code(str(runs_dir))
+
+     # List projects
+     projects = vlalab.list_projects()
+
+     if not projects:
+         st.info(f"No projects found. Use `vlalab.init()` to create a run first.\n\nLog directory: `{runs_dir}`")
+         st.markdown("""
+         **Tip**: set the `$VLALAB_DIR` environment variable to change where logs are stored.
+
+         ```bash
+         export VLALAB_DIR=/path/to/your/logs
+         ```
+         """)
+         return
+
+     # Project filter
+     selected_project = st.sidebar.selectbox(
+         "Select project",
+         ["All"] + projects,
+     )
+
+     # List runs
+     if selected_project == "All":
+         run_paths = vlalab.list_runs()
+     else:
+         run_paths = vlalab.list_runs(project=selected_project)
+
+     if not run_paths:
+         st.info("No runs under this project.")
+         return
+
+     # Select run
+     selected_path = st.sidebar.selectbox(
+         "Select run",
+         run_paths,
+         format_func=lambda p: f"{p.name} ({p.parent.name})"
+     )
+
+     if selected_path is None:
+         return
+
+     # Load (or reload) the viewer when the selected run changes
+     if 'viewer' not in st.session_state or st.session_state.get('last_run') != str(selected_path):
+         st.session_state.viewer = InferenceRunViewer(str(selected_path))
+         st.session_state.last_run = str(selected_path)
+
+     viewer = st.session_state.viewer
+
+     if not viewer.valid:
+         return
+
+     # Show metadata
+     st.sidebar.markdown("---")
+     st.sidebar.markdown("### Run info")
+     st.sidebar.info(f"Steps: {len(viewer.steps)}")
+     if viewer.meta:
+         model = viewer.meta.get("model_name", "unknown")
+         if isinstance(model, str) and len(model) > 30:
+             model = "..." + model[-30:]
+         st.sidebar.info(f"Model: {model}")
+
+     # Tabs
+     tab1, tab2 = st.tabs(["📺 Replay", "📈 Performance"])
+
+     with tab1:
+         if viewer.steps:
+             step_idx = st.slider("Step", 0, len(viewer.steps)-1, 0)
+             viewer.plot_replay_frame(step_idx)
+         else:
+             st.warning("No step data")
+
+     with tab2:
+         viewer.plot_latency_analysis()
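For reference, the run layout `InferenceRunViewer` reads can be inferred from the fields above: a run directory containing `meta.json`, a `steps.jsonl` file with one JSON object per step (`obs.images[].camera_name`/`path`, `obs.state`, `action.values`, `timing.*`), and image files referenced by relative path. A minimal sketch of a compatible run with made-up values (reconstructed from this reader code; not an official vlalab schema):

```python
import json
from pathlib import Path

run = Path("/tmp/vlalab_demo_run")  # hypothetical run directory
(run / "images").mkdir(parents=True, exist_ok=True)

with open(run / "meta.json", "w") as f:
    json.dump({"model_name": "demo-policy"}, f)

step = {
    "obs": {
        "images": [{"camera_name": "wrist", "path": "images/step0_wrist.jpg"}],
        "state": [0.1, 0.2, 0.3],
    },
    "action": {"values": [[0.11, 0.21, 0.31]]},  # (horizon, dim) chunk
    # "_ms" keys are preferred; a bare key is treated as legacy seconds
    "timing": {"inference_latency_ms": 42.0, "total_latency": 0.095},
}
with open(run / "steps.jsonl", "w") as f:
    f.write(json.dumps(step) + "\n")
```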