vlalab 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
@@ -0,0 +1,374 @@
+"""
+VLA-Lab Open-Loop Evaluation Viewer
+
+Visualize and compare predicted vs ground-truth actions from open-loop evaluation.
+"""
+
+import streamlit as st
+import numpy as np
+import matplotlib.pyplot as plt
+from pathlib import Path
+import json
+from typing import Dict, List, Any, Optional
+
+# Setup matplotlib fonts
+try:
+    from vlalab.viz.mpl_fonts import setup_matplotlib_fonts
+    setup_matplotlib_fonts(verbose=False)
+except Exception:
+    pass
+
+
+def load_eval_results(results_path: Path) -> Dict[str, Any]:
+    """Load evaluation results from JSON file."""
+    with open(results_path, "r") as f:
+        return json.load(f)
+
+
+def load_trajectory_arrays(results_dir: Path, traj_id: int) -> Dict[str, np.ndarray]:
+    """
+    Load trajectory arrays (GT and pred actions) if saved as .npy files.
+
+    This is a placeholder - actual implementation depends on how arrays are saved.
+    """
+    arrays = {}
+
+    gt_path = results_dir / f"traj_{traj_id}_gt.npy"
+    pred_path = results_dir / f"traj_{traj_id}_pred.npy"
+
+    if gt_path.exists():
+        arrays["gt_actions"] = np.load(gt_path)
+    if pred_path.exists():
+        arrays["pred_actions"] = np.load(pred_path)
+
+    return arrays
+
+
+def plot_action_comparison(
+    gt_actions: np.ndarray,
+    pred_actions: np.ndarray,
+    dim_idx: int,
+    action_horizon: int = 16,
+    dim_label: str = "",
+) -> plt.Figure:
+    """Plot GT vs predicted actions for a single dimension."""
+    fig, ax = plt.subplots(figsize=(12, 4))
+
+    steps = np.arange(len(gt_actions))
+
+    ax.plot(steps, gt_actions[:, dim_idx], label="Ground Truth", alpha=0.8, linewidth=2)
+    ax.plot(steps, pred_actions[:, dim_idx], label="Predicted", alpha=0.8, linewidth=2, linestyle="--")
+
+    # Mark inference points
+    for j in range(0, len(gt_actions), action_horizon):
+        ax.axvline(x=j, color='gray', linestyle=':', alpha=0.3)
+
+    # Error band
+    error = np.abs(gt_actions[:, dim_idx] - pred_actions[:, dim_idx])
+    ax.fill_between(steps,
+                    gt_actions[:, dim_idx] - error,
+                    gt_actions[:, dim_idx] + error,
+                    alpha=0.2, color='red', label='Error')
+
+    ax.set_xlabel("Step")
+    ax.set_ylabel("Value")
+    ax.set_title(f"Action Dimension: {dim_label or dim_idx}")
+    ax.legend()
+    ax.grid(True, alpha=0.3)
+
+    return fig
+
+
+def plot_all_dimensions(
+    gt_actions: np.ndarray,
+    pred_actions: np.ndarray,
+    action_horizon: int = 16,
+    action_labels: Optional[List[str]] = None,
+) -> plt.Figure:
+    """Plot all action dimensions in a grid."""
+    num_dims = gt_actions.shape[1]
+    n_cols = min(3, num_dims)
+    n_rows = (num_dims + n_cols - 1) // n_cols
+
+    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 3 * n_rows))
+    axes = np.atleast_2d(axes)
+
+    steps = np.arange(len(gt_actions))
+
+    for i in range(num_dims):
+        row, col = i // n_cols, i % n_cols
+        ax = axes[row, col]
+
+        ax.plot(steps, gt_actions[:, i], label="GT", alpha=0.8)
+        ax.plot(steps, pred_actions[:, i], label="Pred", alpha=0.8, linestyle="--")
+
+        label = action_labels[i] if action_labels and i < len(action_labels) else f"Dim {i}"
+        ax.set_title(label, fontsize=10)
+        ax.grid(True, alpha=0.3)
+
+        if i == 0:
+            ax.legend(fontsize=8)
+
+    # Hide empty subplots
+    for i in range(num_dims, n_rows * n_cols):
+        row, col = i // n_cols, i % n_cols
+        axes[row, col].axis('off')
+
+    plt.tight_layout()
+    return fig
+
+
+def plot_3d_trajectory(
+    gt_actions: np.ndarray,
+    pred_actions: np.ndarray,
+    xyz_dims: List[int] = [0, 1, 2],
+) -> plt.Figure:
+    """Plot 3D trajectory comparison (if actions contain position)."""
+    fig = plt.figure(figsize=(10, 8))
+    ax = fig.add_subplot(111, projection='3d')
+
+    x, y, z = xyz_dims
+
+    # GT trajectory
+    ax.plot(gt_actions[:, x], gt_actions[:, y], gt_actions[:, z],
+            'b-', label='Ground Truth', alpha=0.8, linewidth=2)
+    ax.scatter(gt_actions[0, x], gt_actions[0, y], gt_actions[0, z],
+               c='green', s=100, label='Start')
+    ax.scatter(gt_actions[-1, x], gt_actions[-1, y], gt_actions[-1, z],
+               c='red', s=100, label='End')
+
+    # Pred trajectory
+    ax.plot(pred_actions[:, x], pred_actions[:, y], pred_actions[:, z],
+            'r--', label='Predicted', alpha=0.6, linewidth=2)
+
+    ax.set_xlabel('X')
+    ax.set_ylabel('Y')
+    ax.set_zlabel('Z')
+    ax.legend()
+    ax.set_title('3D Trajectory Comparison')
+
+    return fig
+
+
+def plot_error_histogram(
+    gt_actions: np.ndarray,
+    pred_actions: np.ndarray,
+) -> plt.Figure:
+    """Plot error distribution histogram."""
+    errors = (gt_actions - pred_actions).flatten()
+
+    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
+
+    # Error histogram
+    axes[0].hist(errors, bins=50, alpha=0.7, edgecolor='black')
+    axes[0].axvline(x=0, color='r', linestyle='--', label='Zero')
+    axes[0].set_xlabel('Error')
+    axes[0].set_ylabel('Frequency')
+    axes[0].set_title('Error Distribution')
+    axes[0].legend()
+
+    # Absolute error histogram
+    abs_errors = np.abs(errors)
+    axes[1].hist(abs_errors, bins=50, alpha=0.7, edgecolor='black', color='orange')
+    axes[1].axvline(x=np.mean(abs_errors), color='r', linestyle='--',
+                    label=f'Mean: {np.mean(abs_errors):.4f}')
+    axes[1].set_xlabel('Absolute Error')
+    axes[1].set_ylabel('Frequency')
+    axes[1].set_title('Absolute Error Distribution')
+    axes[1].legend()
+
+    plt.tight_layout()
+    return fig
+
+
+def render():
+    """Render the evaluation viewer page."""
+    st.title("📊 Open-Loop Evaluation Results")
+
+    st.markdown("""
+    Visualize open-loop evaluation results for VLA models and compare predicted actions against ground-truth actions.
+
+    **How to use:**
+    1. Upload an evaluation results JSON file
+    2. Or specify a directory containing evaluation plot images
+    """)
+
+    # Sidebar options
+    st.sidebar.markdown("### Data Source")
+
+    source = st.sidebar.radio(
+        "Select data source",
+        ["Upload JSON", "Browse directory", "Demo data"],
+    )
+
+    results = None
+    gt_actions = None
+    pred_actions = None
+
+    if source == "Upload JSON":
+        uploaded_file = st.sidebar.file_uploader(
+            "Upload evaluation results JSON",
+            type=["json"],
+        )
+
+        if uploaded_file:
+            results = json.load(uploaded_file)
+            st.success(f"Loaded evaluation results: {len(results.get('results', []))} trajectories")
+
+    elif source == "Browse directory":
+        results_dir = st.sidebar.text_input(
+            "Evaluation results directory",
+            value="",
+            placeholder="/path/to/eval_results/",
+        )
+
+        if results_dir and Path(results_dir).exists():
+            # Find JSON files
+            json_files = list(Path(results_dir).glob("*.json"))
+            if json_files:
+                selected_json = st.sidebar.selectbox(
+                    "Select results file",
+                    json_files,
+                    format_func=lambda p: p.name,
+                )
+                if selected_json:
+                    results = load_eval_results(selected_json)
+
+            # Find plot images
+            png_files = sorted(Path(results_dir).glob("*.png"))
+            if png_files:
+                st.markdown("### Generated Evaluation Plots")
+                for png_file in png_files:
+                    st.image(str(png_file), caption=png_file.name)
+
+    elif source == "Demo data":
+        st.info("Using randomly generated demo data")
+
+        # Generate demo data
+        np.random.seed(42)
+        num_steps = 200
+        action_dim = 8
+
+        # Simulate GT actions (smooth trajectory)
+        t = np.linspace(0, 4 * np.pi, num_steps)
+        gt_actions = np.zeros((num_steps, action_dim))
+        gt_actions[:, 0] = 0.5 * np.sin(t) + 0.1 * np.random.randn(num_steps)
+        gt_actions[:, 1] = 0.3 * np.cos(t) + 0.1 * np.random.randn(num_steps)
+        gt_actions[:, 2] = 0.2 * np.sin(0.5 * t) + 0.1 * np.random.randn(num_steps)
+        for i in range(3, action_dim):
+            gt_actions[:, i] = 0.1 * np.sin(t * (i - 2)) + 0.05 * np.random.randn(num_steps)
+
+        # Simulate pred actions (GT + noise + slight bias)
+        pred_actions = gt_actions + 0.05 * np.random.randn(num_steps, action_dim)
+        pred_actions += 0.02 * np.ones_like(pred_actions)  # slight bias
+
+        # Calculate metrics
+        mse = float(np.mean((gt_actions - pred_actions) ** 2))
+        mae = float(np.mean(np.abs(gt_actions - pred_actions)))
+
+        results = {
+            "results": [{"trajectory_id": 0, "mse": mse, "mae": mae, "num_steps": num_steps}],
+            "avg_mse": mse,
+            "avg_mae": mae,
+            "num_trajectories": 1,
+        }
+
+    # Display results
+    if results:
+        st.markdown("---")
+
+        # Summary metrics
+        st.markdown("### 📈 Evaluation Metrics")
+        col1, col2, col3 = st.columns(3)
+
+        col1.metric("Avg MSE", f"{results.get('avg_mse', 0):.6f}")
+        col2.metric("Avg MAE", f"{results.get('avg_mae', 0):.6f}")
+        col3.metric("Trajectories evaluated", results.get('num_trajectories', 0))
+
+        # Per-trajectory results
+        if results.get("results"):
+            st.markdown("### 📋 Trajectory Details")
+
+            traj_data = []
+            for r in results["results"]:
+                traj_data.append({
+                    "Trajectory ID": r["trajectory_id"],
+                    "MSE": f"{r['mse']:.6f}",
+                    "MAE": f"{r['mae']:.6f}",
+                    "Steps": r["num_steps"],
+                })
+
+            st.dataframe(traj_data, use_container_width=True)
+
+    # Visualization tabs
+    if gt_actions is not None and pred_actions is not None:
+        st.markdown("---")
+        st.markdown("### 🎨 Visualizations")
+
+        tab1, tab2, tab3, tab4 = st.tabs([
+            "📊 Time-Series Comparison",
+            "🗺️ 3D Trajectory",
+            "📈 Error Distribution",
+            "🔍 Per-Dimension Analysis",
+        ])
+
+        with tab1:
+            action_horizon = st.slider("Action Horizon", 4, 32, 16)
+            fig = plot_all_dimensions(gt_actions, pred_actions, action_horizon)
+            st.pyplot(fig)
+            plt.close(fig)
+
+        with tab2:
+            if gt_actions.shape[1] >= 3:
+                st.markdown("Assumes the first three dimensions are XYZ position")
+                fig = plot_3d_trajectory(gt_actions, pred_actions)
+                st.pyplot(fig)
+                plt.close(fig)
+            else:
+                st.warning("Not enough action dimensions to plot a 3D trajectory")
+
+        with tab3:
+            fig = plot_error_histogram(gt_actions, pred_actions)
+            st.pyplot(fig)
+            plt.close(fig)
+
+        with tab4:
+            dim_idx = st.selectbox(
+                "Select dimension",
+                range(gt_actions.shape[1]),
+                format_func=lambda x: f"Dimension {x}",
+            )
+            fig = plot_action_comparison(gt_actions, pred_actions, dim_idx)
+            st.pyplot(fig)
+            plt.close(fig)
+
+    # Usage instructions
+    with st.expander("💡 How to generate evaluation results"):
+        st.markdown("""
+        ```python
+        from vlalab.eval import OpenLoopEvaluator
+        from vlalab.eval.adapters import GR00TAdapter
+
+        # 1. Create a policy adapter
+        adapter = GR00TAdapter(policy)
+
+        # 2. Create the evaluator
+        evaluator = OpenLoopEvaluator(
+            policy=adapter,
+            dataset_path="/path/to/dataset.zarr",
+        )
+
+        # 3. Run the evaluation
+        results = evaluator.evaluate(
+            traj_ids=[0, 1, 2],
+            max_steps=200,
+            save_plots_dir="eval_outputs/",
+        )
+
+        # 4. Save the results
+        evaluator.evaluate_and_save(
+            "eval_outputs/results.json",
+            traj_ids=[0, 1, 2],
+        )
+        ```
+        """)
vlalab/cli.py CHANGED
@@ -44,7 +44,7 @@ def view(port: int, run_dir: str):
     package_dir = Path(vlalab.__file__).parent
     app_path = package_dir / "apps" / "streamlit" / "app.py"
 
-    cmd = [sys.executable, "-m", "streamlit", "run", str(app_path), "--server.port", str(port)]
+    cmd = [sys.executable, "-m", "streamlit", "run", str(app_path), "--server.port", str(port), "--server.address", "0.0.0.0"]
 
     if run_dir:
         cmd.extend(["--", "--run-dir", run_dir])
@@ -0,0 +1,15 @@
+"""
+VLA-Lab Evaluation Module
+
+Provides open-loop evaluation tools for VLA models.
+"""
+
+from vlalab.eval.policy_interface import EvalPolicy, ModalityConfig
+from vlalab.eval.open_loop_eval import OpenLoopEvaluator, evaluate_trajectory
+
+__all__ = [
+    "EvalPolicy",
+    "ModalityConfig",
+    "OpenLoopEvaluator",
+    "evaluate_trajectory",
+]
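
Nothing in this diff shows the bodies of `EvalPolicy` or `ModalityConfig`, but the `DiffusionPolicyAdapter` further below implements `get_action()`, `get_modality_config()`, and `reset()`, and constructs `ModalityConfig` with the keyword arguments used there. Under those assumptions only, a throwaway random policy for smoke-testing `OpenLoopEvaluator` might look like the following sketch (the class is purely illustrative).

```python
import numpy as np
from typing import Any, Dict, Optional

from vlalab.eval import EvalPolicy, ModalityConfig


class RandomEvalPolicy(EvalPolicy):
    """Illustrative stand-in policy that emits random action chunks."""

    def __init__(self, action_horizon: int = 8, action_dim: int = 7):
        self.action_horizon = action_horizon
        self.action_dim = action_dim
        # Same ModalityConfig fields that DiffusionPolicyAdapter fills in below
        self._config = ModalityConfig(
            state_keys=["robot_state"],
            action_keys=["action"],
            image_keys=["front"],
            language_keys=[],
            action_horizon=action_horizon,
            action_dim=action_dim,
        )

    def get_action(self, obs: Dict[str, Any], task_description: Optional[str] = None) -> np.ndarray:
        # Random (action_horizon, action_dim) chunk; ignores the observation
        return np.random.uniform(-1.0, 1.0, size=(self.action_horizon, self.action_dim))

    def get_modality_config(self) -> ModalityConfig:
        return self._config

    def reset(self) -> None:
        pass
```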
@@ -0,0 +1,14 @@
+"""
+VLA-Lab Policy Adapters
+
+Adapters that wrap specific VLA model implementations to conform to
+the unified EvalPolicy interface.
+"""
+
+from vlalab.eval.adapters.groot_adapter import GR00TAdapter
+from vlalab.eval.adapters.dp_adapter import DiffusionPolicyAdapter
+
+__all__ = [
+    "GR00TAdapter",
+    "DiffusionPolicyAdapter",
+]
@@ -0,0 +1,279 @@
+"""
+Diffusion Policy Adapter
+
+Wraps Diffusion Policy implementations to conform to the unified EvalPolicy interface.
+This adapter handles the conversion between VLA-Lab's standardized observation
+format and Diffusion Policy's specific input/output formats.
+"""
+
+from typing import Any, Dict, List, Optional, Callable
+import numpy as np
+
+from vlalab.eval.policy_interface import EvalPolicy, ModalityConfig
+
+
+def parse_observation_dp(
+    obs: Dict[str, Any],
+    state_key: str = "state",
+    image_key: str = "image",
+) -> Dict[str, Any]:
+    """
+    Convert standardized observation to Diffusion Policy's expected format.
+
+    Diffusion Policy typically expects:
+        {
+            "state": np.ndarray (state_dim,) or (T, state_dim),
+            "image": np.ndarray (H, W, C) or (T, H, W, C),
+        }
+
+    Args:
+        obs: Standardized observation dict with:
+            - "state": Dict[str, np.ndarray] - state vectors
+            - "images": Dict[str, np.ndarray] - images (H, W, C)
+        state_key: Key to use for concatenated state
+        image_key: Key to use for primary image
+
+    Returns:
+        DP-formatted observation dict
+    """
+    dp_obs = {}
+
+    # Concatenate all state keys into single state vector
+    if "state" in obs:
+        state_parts = []
+        for key, arr in obs["state"].items():
+            arr = np.atleast_1d(arr).astype(np.float32)
+            state_parts.append(arr)
+        if state_parts:
+            dp_obs[state_key] = np.concatenate(state_parts, axis=-1)
+
+    # Use first image as primary image
+    if "images" in obs:
+        for cam_name, img in obs["images"].items():
+            dp_obs[image_key] = img
+            break  # Use first image
+
+    return dp_obs
+
+
+def parse_action_dp(
+    action: Any,
+) -> np.ndarray:
+    """
+    Convert Diffusion Policy action output to standardized array format.
+
+    Diffusion Policy outputs action chunks as:
+    - np.ndarray of shape (action_horizon, action_dim)
+    - or torch.Tensor of same shape
+
+    Args:
+        action: DP action output
+
+    Returns:
+        Action array of shape (action_horizon, action_dim)
+    """
+    # Handle torch tensors
+    if hasattr(action, "cpu"):
+        action = action.cpu().numpy()
+
+    action = np.asarray(action)
+
+    # Ensure 2D
+    if action.ndim == 1:
+        action = action[None, :]
+
+    return action
+
+
+class DiffusionPolicyAdapter(EvalPolicy):
+    """
+    Adapter for Diffusion Policy.
+
+    This adapter is designed to work with various Diffusion Policy implementations.
+    You can either:
+    1. Pass a policy object with a predict() or get_action() method
+    2. Pass a callable inference function directly
+
+    Usage:
+        # Option 1: Wrap policy object
+        from dp_server import DPPolicy
+        policy = DPPolicy.load(checkpoint_path)
+        adapter = DiffusionPolicyAdapter(policy)
+
+        # Option 2: Wrap inference client
+        from dp_client import DPClient
+        client = DPClient(host="localhost", port=5000)
+        adapter = DiffusionPolicyAdapter(
+            client,
+            inference_fn=lambda p, obs: p.predict(obs["state"], obs["image"])
+        )
+
+        # Option 3: Wrap callable
+        adapter = DiffusionPolicyAdapter(
+            inference_fn=my_inference_function,
+            action_horizon=8,
+            action_dim=7,
+        )
+    """
+
+    def __init__(
+        self,
+        policy: Any = None,
+        inference_fn: Optional[Callable] = None,
+        action_horizon: int = 8,
+        action_dim: int = 7,
+        state_keys: Optional[List[str]] = None,
+        image_keys: Optional[List[str]] = None,
+    ):
+        """
+        Initialize the Diffusion Policy adapter.
+
+        Args:
+            policy: DP policy instance (optional if inference_fn provided)
+            inference_fn: Custom inference function (policy, obs) -> action
+            action_horizon: Number of future actions predicted
+            action_dim: Dimension of action space
+            state_keys: List of state modality keys
+            image_keys: List of image modality keys
+        """
+        self.policy = policy
+        self._inference_fn = inference_fn
+
+        self._modality_config = ModalityConfig(
+            state_keys=state_keys or ["robot_state"],
+            action_keys=["action"],
+            image_keys=image_keys or ["front"],
+            language_keys=[],  # DP typically doesn't use language
+            action_horizon=action_horizon,
+            action_dim=action_dim,
+        )
+
+    def _default_inference(self, obs: Dict[str, Any]) -> np.ndarray:
+        """Default inference logic when no custom inference_fn provided."""
+        dp_obs = parse_observation_dp(obs)
+
+        # Try different method names
+        if hasattr(self.policy, "predict"):
+            action = self.policy.predict(dp_obs)
+        elif hasattr(self.policy, "get_action"):
+            action = self.policy.get_action(dp_obs)
+        elif hasattr(self.policy, "__call__"):
+            action = self.policy(dp_obs)
+        else:
+            raise ValueError(
+                "Policy must have predict(), get_action(), or __call__() method. "
+                "Alternatively, provide inference_fn parameter."
+            )
+
+        return parse_action_dp(action)
+
+    def get_action(
+        self,
+        obs: Dict[str, Any],
+        task_description: Optional[str] = None,
+    ) -> np.ndarray:
+        """
+        Get action from Diffusion Policy.
+
+        Args:
+            obs: Standardized observation dict
+            task_description: Ignored (DP doesn't use language)
+
+        Returns:
+            Action array of shape (action_horizon, action_dim)
+        """
+        if self._inference_fn is not None:
+            if self.policy is not None:
+                action = self._inference_fn(self.policy, obs)
+            else:
+                action = self._inference_fn(obs)
+            return parse_action_dp(action)
+
+        return self._default_inference(obs)
+
+    def get_modality_config(self) -> ModalityConfig:
+        """Get modality configuration."""
+        return self._modality_config
+
+    def reset(self) -> None:
+        """Reset the policy."""
+        if self.policy is not None and hasattr(self.policy, "reset"):
+            self.policy.reset()
+
+
+class DiffusionPolicyClientAdapter(DiffusionPolicyAdapter):
+    """
+    Adapter for Diffusion Policy inference client (e.g., ZMQ client).
+
+    This is a convenience class for wrapping remote DP inference servers.
+
+    Usage:
+        adapter = DiffusionPolicyClientAdapter(
+            host="localhost",
+            port=5000,
+            action_horizon=8,
+            action_dim=7,
+        )
+        action = adapter.get_action(obs)
+    """
+
+    def __init__(
+        self,
+        host: str = "localhost",
+        port: int = 5000,
+        action_horizon: int = 8,
+        action_dim: int = 7,
+        state_keys: Optional[List[str]] = None,
+        image_keys: Optional[List[str]] = None,
+    ):
+        """
+        Initialize client adapter.
+
+        Args:
+            host: Server hostname
+            port: Server port
+            action_horizon: Number of future actions
+            action_dim: Action dimension
+            state_keys: State modality keys
+            image_keys: Image modality keys
+        """
+        super().__init__(
+            policy=None,
+            action_horizon=action_horizon,
+            action_dim=action_dim,
+            state_keys=state_keys,
+            image_keys=image_keys,
+        )
+
+        self.host = host
+        self.port = port
+        self._client = None
+
+    def _get_client(self):
+        """Lazy initialization of ZMQ client."""
+        if self._client is None:
+            try:
+                import zmq
+                context = zmq.Context()
+                self._client = context.socket(zmq.REQ)
+                self._client.connect(f"tcp://{self.host}:{self.port}")
+            except ImportError:
+                raise ImportError("pyzmq required for DiffusionPolicyClientAdapter")
+        return self._client
+
+    def get_action(
+        self,
+        obs: Dict[str, Any],
+        task_description: Optional[str] = None,
+    ) -> np.ndarray:
+        """
+        Get action from remote DP server.
+
+        Note: This is a placeholder. Actual implementation depends on
+        the specific DP server protocol used.
+        """
+        raise NotImplementedError(
+            "DiffusionPolicyClientAdapter.get_action() requires implementation "
+            "specific to your DP server protocol. Override this method or use "
+            "DiffusionPolicyAdapter with a custom inference_fn."
+        )
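
As the `NotImplementedError` says, the wire protocol is left to the integrator. Below is a minimal sketch of one way to complete it, assuming (purely for illustration) that the server accepts a pickled observation dict on the REQ/REP socket that `_get_client()` opens and replies with a pickled `(action_horizon, action_dim)` array; the subclass name and protocol are not part of the package.

```python
import numpy as np
from typing import Any, Dict, Optional

from vlalab.eval.adapters.dp_adapter import (
    DiffusionPolicyClientAdapter,
    parse_action_dp,
    parse_observation_dp,
)


class PickledZMQClientAdapter(DiffusionPolicyClientAdapter):
    """Illustrative client assuming a pickled-dict REQ/REP protocol."""

    def get_action(
        self,
        obs: Dict[str, Any],
        task_description: Optional[str] = None,
    ) -> np.ndarray:
        socket = self._get_client()                   # lazily connects to tcp://host:port
        socket.send_pyobj(parse_observation_dp(obs))  # assumed request: DP-formatted obs dict
        action = socket.recv_pyobj()                  # assumed reply: (action_horizon, action_dim) array
        return parse_action_dp(action)
```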