opentau-0.1.1-py3-none-any.whl → opentau-0.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opentau/configs/default.py +16 -0
- opentau/configs/deployment.py +85 -0
- opentau/configs/train.py +5 -0
- opentau/datasets/factory.py +43 -10
- opentau/datasets/lerobot_dataset.py +19 -19
- opentau/datasets/video_utils.py +11 -6
- opentau/policies/pi05/configuration_pi05.py +9 -6
- opentau/policies/pi05/modeling_pi05.py +296 -30
- opentau/policies/pi05/paligemma_with_expert.py +20 -20
- opentau/scripts/grpc/__init__.py +19 -0
- opentau/scripts/grpc/client.py +601 -0
- opentau/scripts/grpc/robot_inference_pb2.py +61 -0
- opentau/scripts/grpc/robot_inference_pb2_grpc.py +210 -0
- opentau/scripts/grpc/server.py +313 -0
- opentau/scripts/launch.py +12 -4
- opentau/scripts/train.py +94 -17
- opentau/scripts/visualize_dataset.py +141 -38
- opentau/utils/transformers_patch.py +251 -20
- {opentau-0.1.1.dist-info → opentau-0.2.0.dist-info}/METADATA +37 -17
- {opentau-0.1.1.dist-info → opentau-0.2.0.dist-info}/RECORD +24 -21
- {opentau-0.1.1.dist-info → opentau-0.2.0.dist-info}/WHEEL +1 -1
- {opentau-0.1.1.dist-info → opentau-0.2.0.dist-info}/entry_points.txt +1 -0
- opentau/scripts/libero_simulation_parallel.py +0 -356
- opentau/scripts/libero_simulation_sequential.py +0 -122
- opentau/scripts/visualize_dataset_html.py +0 -507
- {opentau-0.1.1.dist-info → opentau-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {opentau-0.1.1.dist-info → opentau-0.2.0.dist-info}/top_level.txt +0 -0
opentau/scripts/grpc/client.py (new file)

@@ -0,0 +1,601 @@

````python
# Copyright 2026 Tensor Auto Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""gRPC client for robot policy inference.

This client runs on the robot and sends observations to a remote gRPC server
for ML inference. It subscribes to /joint_states for robot state, creates
fake images for inference, and publishes motor commands to
/motor_command_controller/motor_commands.

Usage:
    python src/opentau/scripts/grpc/client.py \
        --server_address 192.168.1.100:50051 \
        --prompt "pick up the red block"
"""

import argparse
import io
import logging
import time
from dataclasses import dataclass
from typing import Optional

import numpy as np
import rclpy

# Import ROS 2 message types
from interfaces.msg import MotorCommands, RawMotorCommand
from PIL import Image
from rclpy.node import Node
from sensor_msgs.msg import JointState

import grpc
from opentau.scripts.grpc import robot_inference_pb2, robot_inference_pb2_grpc

logger = logging.getLogger(__name__)

# Topics
MOTOR_COMMANDS_TOPIC = "/motor_command_controller/motor_commands"
JOINT_STATES_TOPIC = "/joint_states"

# Joint configuration (example robot).
JOINT_NAMES: list[str] = [
    "base_yaw_joint",
    "shoulder_pitch_joint",
    "shoulder_roll_joint",
    "elbow_flex_joint",
    "wrist_roll_joint",
    "wrist_yaw_joint",
    "hip_pitch_joint",
    "hip_roll_joint",
    "knee_joint",
    "ankle_pitch_joint",
    "ankle_roll_joint",
    "gripper_finger_joint",
]


@dataclass
class ClientConfig:
    """Configuration for the gRPC client."""

    server_address: str = "localhost:50051"
    timeout_seconds: float = 5.0
    max_retries: int = 3
    retry_delay_seconds: float = 0.1
    image_encoding: str = "jpeg"  # "jpeg", "png", or "raw"
    jpeg_quality: int = 85


class PolicyClient:
    """gRPC client for communicating with the policy inference server."""

    def __init__(self, config: ClientConfig):
        """Initialize the client.

        Args:
            config: Client configuration.
        """
        self.config = config
        self._channel: Optional[grpc.Channel] = None
        self._stub: Optional[robot_inference_pb2_grpc.RobotPolicyServiceStub] = None
        self._connected = False
        self._request_counter = 0

    def connect(self) -> bool:
        """Establish connection to the server.

        Returns:
            True if connection was successful, False otherwise.
        """
        try:
            self._channel = grpc.insecure_channel(
                self.config.server_address,
                options=[
                    ("grpc.max_send_message_length", 100 * 1024 * 1024),  # 100MB
                    ("grpc.max_receive_message_length", 100 * 1024 * 1024),  # 100MB
                    ("grpc.keepalive_time_ms", 10000),
                    ("grpc.keepalive_timeout_ms", 5000),
                ],
            )
            self._stub = robot_inference_pb2_grpc.RobotPolicyServiceStub(self._channel)

            # Test connection with health check
            response = self._stub.HealthCheck(
                robot_inference_pb2.HealthCheckRequest(),
                timeout=self.config.timeout_seconds,
            )
            self._connected = response.healthy
            logger.info(
                f"Connected to server: {self.config.server_address}, "
                f"model: {response.model_name}, device: {response.device}"
            )
            return self._connected

        except grpc.RpcError as e:
            logger.error(f"Failed to connect to server: {e}")
            self._connected = False
            return False

    def disconnect(self):
        """Close the connection to the server."""
        if self._channel:
            self._channel.close()
        self._channel = None
        self._stub = None
        self._connected = False
        logger.info("Disconnected from server")

    def is_connected(self) -> bool:
        """Check if the client is connected.

        Returns:
            True if connected, False otherwise.
        """
        return self._connected and self._channel is not None

    def _encode_image(self, image: np.ndarray) -> robot_inference_pb2.CameraImage:
        """Encode an image for transmission.

        Args:
            image: Image array of shape (H, W, C) with values in [0, 255] or [0, 1].

        Returns:
            CameraImage protobuf message.
        """
        # Normalize image to [0, 255] uint8
        if image.dtype == np.float32 or image.dtype == np.float64:
            image = (image * 255).astype(np.uint8) if image.max() <= 1.0 else image.astype(np.uint8)
        elif image.dtype != np.uint8:
            image = image.astype(np.uint8)

        camera_image = robot_inference_pb2.CameraImage()

        if self.config.image_encoding == "jpeg":
            pil_image = Image.fromarray(image)
            buffer = io.BytesIO()
            pil_image.save(buffer, format="JPEG", quality=self.config.jpeg_quality)
            camera_image.image_data = buffer.getvalue()
            camera_image.encoding = "jpeg"
        elif self.config.image_encoding == "png":
            pil_image = Image.fromarray(image)
            buffer = io.BytesIO()
            pil_image.save(buffer, format="PNG")
            camera_image.image_data = buffer.getvalue()
            camera_image.encoding = "png"
        else:  # raw
            camera_image.image_data = image.astype(np.float32).tobytes()
            camera_image.encoding = "raw"

        return camera_image

    def get_action_chunk(
        self,
        images: list[np.ndarray],
        state: np.ndarray,
        prompt: str,
    ) -> tuple[np.ndarray, float]:
        """Get action chunk from the policy server.

        Args:
            images: List of image arrays (H, W, C) for each camera.
            state: Robot state vector.
            prompt: Language instruction.

        Returns:
            Tuple of (action chunk array, inference time in ms).

        Raises:
            RuntimeError: If not connected or inference fails.
        """
        if not self.is_connected():
            raise RuntimeError("Client is not connected to server")

        self._request_counter += 1
        request = robot_inference_pb2.ObservationRequest()
        request.request_id = f"req_{self._request_counter}"
        request.timestamp_ns = time.time_ns()
        request.prompt = prompt

        # Add images
        for image in images:
            camera_image = self._encode_image(image)
            request.images.append(camera_image)

        # Add state
        request.robot_state.state.extend(state.flatten().tolist())

        # Make request with retries
        last_error = None
        for attempt in range(self.config.max_retries):
            try:
                response = self._stub.GetActionChunk(request, timeout=self.config.timeout_seconds)

                # Convert repeated ActionVector messages into a 2D numpy array:
                # shape = (chunk_length, action_dim)
                if not response.action_chunk:
                    action_chunk = np.zeros((0,), dtype=np.float32)
                else:
                    action_vectors = [
                        np.asarray(action_vector.values, dtype=np.float32)
                        for action_vector in response.action_chunk
                    ]
                    # Stack into (T, D) array where T is chunk length.
                    action_chunk = np.stack(action_vectors, axis=0)

                return action_chunk, response.inference_time_ms

            except grpc.RpcError as e:
                last_error = e
                logger.warning(f"Request failed (attempt {attempt + 1}): {e}")
                if attempt < self.config.max_retries - 1:
                    time.sleep(self.config.retry_delay_seconds)

        raise RuntimeError(f"Failed after {self.config.max_retries} retries: {last_error}")

    def health_check(self) -> dict:
        """Check server health.

        Returns:
            Dictionary with health status information.
        """
        if not self._stub:
            return {"healthy": False, "status": "Not connected"}

        try:
            response = self._stub.HealthCheck(
                robot_inference_pb2.HealthCheckRequest(),
                timeout=self.config.timeout_seconds,
            )
            return {
                "healthy": response.healthy,
                "status": response.status,
                "model_name": response.model_name,
                "device": response.device,
                "gpu_memory_used_gb": response.gpu_memory_used_gb,
                "gpu_memory_total_gb": response.gpu_memory_total_gb,
            }
        except grpc.RpcError as e:
            return {"healthy": False, "status": str(e)}


# =============================================================================
# ROS 2 Integration
# =============================================================================


@dataclass
class ROS2Config:
    """Configuration for the ROS 2 policy client node."""

    # gRPC settings
    server_address: str = "localhost:50051"
    timeout_seconds: float = 30.0  # Allow longer timeout for ML inference warmup

    # Topic names
    state_topic: str = JOINT_STATES_TOPIC
    motor_commands_topic: str = MOTOR_COMMANDS_TOPIC

    # Control settings
    control_frequency_hz: float = 10.0
    prompt: str = ""

    # Fake image settings
    num_cameras: int = 2
    image_height: int = 224
    image_width: int = 224


class ROS2PolicyClient(Node):
    """ROS 2 node that interfaces with the gRPC policy server.

    This node subscribes to joint states, creates fake images for inference,
    sends them to the gRPC server, and publishes the resulting actions as
    motor commands. The prompt is provided via command-line argument.

    Example usage:
    ```python
    import rclpy
    from motor_command_controller.grpc.client import ROS2PolicyClient, ROS2Config

    rclpy.init()
    config = ROS2Config(
        server_address="192.168.1.100:50051",
        control_frequency_hz=30.0,
        prompt="pick up the red block",
    )
    node = ROS2PolicyClient(config)
    rclpy.spin(node)
    ```
    """

    def __init__(self, config: ROS2Config):
        """Initialize the ROS 2 policy client node.

        Args:
            config: ROS 2 configuration.
        """
        super().__init__("policy_client")
        self.config = config
        # Fixed joint ordering for this example robot.
        self.joint_names: list[str] = JOINT_NAMES

        # Initialize gRPC client
        client_config = ClientConfig(
            server_address=config.server_address,
            timeout_seconds=config.timeout_seconds,
        )
        self.policy_client = PolicyClient(client_config)

        # State storage
        self._latest_positions: Optional[list[float]] = None
        self._latest_velocities: Optional[list[float]] = None
        self._prompt: str = config.prompt

        # Create subscriber for joint states
        self._state_sub = self.create_subscription(
            JointState,
            config.state_topic,
            self._state_callback,
            10,
        )
        self.get_logger().info(f"Subscribed to {config.state_topic}")
        self.get_logger().info(f"Using prompt: {self._prompt}")

        # Create publisher for motor commands
        self._motor_commands_pub = self.create_publisher(
            MotorCommands,
            config.motor_commands_topic,
            10,
        )
        self.get_logger().info(f"Publishing to {config.motor_commands_topic}")

        # Create control timer
        self._control_timer = self.create_timer(
            # 1.0 / config.control_frequency_hz,
            1,
            self._control_callback,
        )

        # Connect to server
        self.get_logger().info(f"Connecting to gRPC server at {config.server_address}")
        if not self.policy_client.connect():
            self.get_logger().error("Failed to connect to gRPC server")
        else:
            self.get_logger().info("Connected to gRPC server")

    def _create_fake_images(self) -> list[np.ndarray]:
        """Create fake images for inference.

        Returns:
            List of fake RGB images.
        """
        images = []
        for _ in range(self.config.num_cameras):
            # Create a random RGB image
            image = np.random.randint(
                0,
                255,
                (self.config.image_height, self.config.image_width, 3),
                dtype=np.uint8,
            )
            images.append(image)
        return images

    def _state_callback(self, msg: JointState):
        """Handle incoming joint state messages.

        Args:
            msg: JointState message.
        """
        # For this example script, we simply take the first N joints from the
        # incoming message, where N is the number of fake joints.
        num_joints = len(self.joint_names)
        if len(msg.position) < num_joints or len(msg.velocity) < num_joints:
            self.get_logger().warning(
                f"Received joint_states with fewer than {num_joints} joints; waiting for full state.",
                throttle_duration_sec=5.0,
            )
            return

        self._latest_positions = list(msg.position[:num_joints])
        self._latest_velocities = list(msg.velocity[:num_joints])

    def _control_callback(self):
        """Main control loop callback."""
        if not self.policy_client.is_connected():
            self.get_logger().warning(
                "Not connected to gRPC server",
                throttle_duration_sec=5.0,
            )
            return

        # Check if we have joint state data
        if self._latest_positions is None or self._latest_velocities is None:
            self.get_logger().warning(
                "Waiting for joint_states...",
                throttle_duration_sec=5.0,
            )
            return

        # Check if we have a prompt
        if not self._prompt:
            self.get_logger().warning(
                "No prompt provided. Use --prompt argument.",
                throttle_duration_sec=5.0,
            )
            return

        try:
            # Create state vector (positions + velocities)
            state = np.array(
                self._latest_positions + self._latest_velocities,
                dtype=np.float32,
            )

            # Create fake images
            images = self._create_fake_images()

            self.get_logger().info("Sending request to server")
            self.get_logger().info(f"Images list length: {len(images)}")
            if images:
                self.get_logger().info(f"First image shape: {images[0].shape}")
            else:
                self.get_logger().info("Images list is empty")
            self.get_logger().info(f"State shape: {state.shape}")

            # Get action chunk from server
            action_chunk, inference_time_ms = self.policy_client.get_action_chunk(
                images=images,
                state=state,
                prompt=self._prompt,
            )
            self.get_logger().info("Received action chunk from server")
            self.get_logger().info(f"Action chunk shape: {action_chunk.shape}")
            self.get_logger().info(f"Inference time: {inference_time_ms:.1f} ms")

            # Publish motor commands
            self._publish_motor_commands(action_chunk)

            self.get_logger().debug(
                f"Action chunk: {action_chunk[:3]}..., inference time: {inference_time_ms:.1f}ms"
            )

        except Exception as e:
            self.get_logger().error(f"Failed to get action chunk: {e}")

    def _publish_motor_commands(self, action_chunk: np.ndarray):
        """Publish motor commands from action chunk.

        Args:
            action_chunk: Action chunk from the policy server.
                Expected to be either:
                - A 2D array of shape (chunk_length, num_joints) where each row
                  is a full action vector for all joints, or
                - A 1D array of shape (num_joints,) for a single action vector.
        """
        msg = MotorCommands()
        msg.header.stamp = self.get_clock().now().to_msg()
        msg.joint_names = self.joint_names

        # Select the action vector to apply.
        # If we have a chunk of actions, use the most recent one.
        if action_chunk.ndim == 2:
            if action_chunk.shape[0] == 0:
                self.get_logger().error("Received empty action chunk (no timesteps)")
                return
            action = action_chunk[-1]
        elif action_chunk.ndim == 1:
            action = action_chunk
        else:
            self.get_logger().error(
                f"Unexpected action_chunk shape: {action_chunk.shape} (ndim={action_chunk.ndim})"
            )
            return

        # Create motor commands for each joint
        # Assumes action contains target positions for each joint
        num_joints = len(self.joint_names)
        if len(action) < num_joints:
            self.get_logger().error(f"Action vector too small: {len(action)} < {num_joints}")
            return

        msg.commands = [RawMotorCommand(q=float(action[i])) for i, joint_name in enumerate(self.joint_names)]

        self._motor_commands_pub.publish(msg)

    def publish_damping_command(self):
        """Publish a damping command to safely stop the robot."""
        if self._latest_positions is None:
            return

        msg = MotorCommands()
        msg.header.stamp = self.get_clock().now().to_msg()
        msg.joint_names = self.joint_names
        msg.commands = [
            RawMotorCommand(q=self._latest_positions[i]) for i, joint_name in enumerate(self.joint_names)
        ]
        self._motor_commands_pub.publish(msg)

    def destroy_node(self):
        """Clean up resources when node is destroyed."""
        # Send damping command before shutting down
        self.publish_damping_command()
        self.policy_client.disconnect()
        super().destroy_node()


def main():
    """Main entry point for the ROS 2 gRPC policy client."""
    parser = argparse.ArgumentParser(description="gRPC Robot Policy Client (ROS 2)")
    parser.add_argument(
        "--server_address",
        type=str,
        default="localhost:50051",
        help="Server address (host:port)",
    )
    parser.add_argument(
        "--control_frequency",
        type=float,
        default=30.0,
        help="Control loop frequency in Hz",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="",
        help="Language prompt for the policy (required)",
    )
    parser.add_argument(
        "--num_cameras",
        type=int,
        default=2,
        help="Number of fake cameras to simulate",
    )
    parser.add_argument(
        "--timeout",
        type=float,
        default=30.0,
        help="gRPC timeout in seconds (increase for slow inference)",
    )

    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)

    # Initialize ROS 2
    rclpy.init()

    config = ROS2Config(
        server_address=args.server_address,
        timeout_seconds=args.timeout,
        control_frequency_hz=args.control_frequency,
        prompt=args.prompt,
        num_cameras=args.num_cameras,
    )

    node = ROS2PolicyClient(config)

    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        pass
    finally:
        node.destroy_node()
        rclpy.shutdown()


if __name__ == "__main__":
    main()
````
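The PolicyClient shown above can also be exercised without a ROS 2 graph. The sketch below is editorial, not part of the release: it uses only the API visible in this hunk, and the server address, single fake camera frame, and 24-dimensional state (12 joint positions plus 12 velocities) are placeholder values. Note that the module imports rclpy and the robot's interfaces package at import time, so a ROS 2 environment must still be on the path.

```python
# Minimal sketch: driving PolicyClient directly, without the ROS 2 node.
# Assumes a RobotPolicyService server is reachable; the address, image size,
# and state dimension are placeholders, not opentau defaults.
import numpy as np

from opentau.scripts.grpc.client import ClientConfig, PolicyClient

client = PolicyClient(ClientConfig(server_address="localhost:50051"))
if client.connect():
    images = [np.zeros((224, 224, 3), dtype=np.uint8)]  # one fake camera frame
    state = np.zeros(24, dtype=np.float32)  # positions + velocities placeholder
    actions, latency_ms = client.get_action_chunk(images, state, "pick up the red block")
    print(actions.shape, latency_ms)
    client.disconnect()
```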
opentau/scripts/grpc/robot_inference_pb2.py (new file)

@@ -0,0 +1,61 @@

```python
# Copyright 2026 Tensor Auto Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# # -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: robot_inference.proto
# Protobuf Python Version: 6.31.1
"""Generated protocol buffer code."""

from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder

_runtime_version.ValidateProtobufRuntimeVersion(
    _runtime_version.Domain.PUBLIC, 6, 31, 1, "", "robot_inference.proto"
)
# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()


DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
    b'\n\x15robot_inference.proto\x12\x0frobot_inference"3\n\x0b\x43\x61meraImage\x12\x12\n\nimage_data\x18\x01 \x01(\x0c\x12\x10\n\x08\x65ncoding\x18\x02 \x01(\t"\x1b\n\nRobotState\x12\r\n\x05state\x18\x01 \x03(\x02"\xae\x01\n\x12ObservationRequest\x12,\n\x06images\x18\x01 \x03(\x0b\x32\x1c.robot_inference.CameraImage\x12\x30\n\x0brobot_state\x18\x02 \x01(\x0b\x32\x1b.robot_inference.RobotState\x12\x0e\n\x06prompt\x18\x03 \x01(\t\x12\x14\n\x0ctimestamp_ns\x18\x04 \x01(\x03\x12\x12\n\nrequest_id\x18\x05 \x01(\t"\x1e\n\x0c\x41\x63tionVector\x12\x0e\n\x06values\x18\x01 \x03(\x02"\x8f\x01\n\x13\x41\x63tionChunkResponse\x12\x33\n\x0c\x61\x63tion_chunk\x18\x01 \x03(\x0b\x32\x1d.robot_inference.ActionVector\x12\x14\n\x0ctimestamp_ns\x18\x02 \x01(\x03\x12\x12\n\nrequest_id\x18\x03 \x01(\t\x12\x19\n\x11inference_time_ms\x18\x04 \x01(\x02"\x14\n\x12HealthCheckRequest"\x93\x01\n\x13HealthCheckResponse\x12\x0f\n\x07healthy\x18\x01 \x01(\x08\x12\x0e\n\x06status\x18\x02 \x01(\t\x12\x12\n\nmodel_name\x18\x03 \x01(\t\x12\x0e\n\x06\x64\x65vice\x18\x04 \x01(\t\x12\x1a\n\x12gpu_memory_used_gb\x18\x05 \x01(\x02\x12\x1b\n\x13gpu_memory_total_gb\x18\x06 \x01(\x02\x32\xb0\x02\n\x12RobotPolicyService\x12[\n\x0eGetActionChunk\x12#.robot_inference.ObservationRequest\x1a$.robot_inference.ActionChunkResponse\x12\x63\n\x12StreamActionChunks\x12#.robot_inference.ObservationRequest\x1a$.robot_inference.ActionChunkResponse(\x01\x30\x01\x12X\n\x0bHealthCheck\x12#.robot_inference.HealthCheckRequest\x1a$.robot_inference.HealthCheckResponseb\x06proto3'
)

_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "robot_inference_pb2", _globals)
if not _descriptor._USE_C_DESCRIPTORS:
    DESCRIPTOR._loaded_options = None
    _globals["_CAMERAIMAGE"]._serialized_start = 42
    _globals["_CAMERAIMAGE"]._serialized_end = 93
    _globals["_ROBOTSTATE"]._serialized_start = 95
    _globals["_ROBOTSTATE"]._serialized_end = 122
    _globals["_OBSERVATIONREQUEST"]._serialized_start = 125
    _globals["_OBSERVATIONREQUEST"]._serialized_end = 299
    _globals["_ACTIONVECTOR"]._serialized_start = 301
    _globals["_ACTIONVECTOR"]._serialized_end = 331
    _globals["_ACTIONCHUNKRESPONSE"]._serialized_start = 334
    _globals["_ACTIONCHUNKRESPONSE"]._serialized_end = 477
    _globals["_HEALTHCHECKREQUEST"]._serialized_start = 479
    _globals["_HEALTHCHECKREQUEST"]._serialized_end = 499
    _globals["_HEALTHCHECKRESPONSE"]._serialized_start = 502
    _globals["_HEALTHCHECKRESPONSE"]._serialized_end = 649
    _globals["_ROBOTPOLICYSERVICE"]._serialized_start = 652
    _globals["_ROBOTPOLICYSERVICE"]._serialized_end = 956
# @@protoc_insertion_point(module_scope)
```